From db70613702472169dfab6d3c2061244346eb3078 Mon Sep 17 00:00:00 2001 From: Domenico Nappo Date: Wed, 31 Dec 2025 17:19:30 +0100 Subject: [PATCH 1/7] Adding tests for nats --- Makefile | 6 +- QUICK_START.md | 13 +- README.md | 60 +-- docs/source/index.rst | 54 +-- docs/source/intro.md | 60 +-- docs/source/quick_start.md | 11 +- protobunny/__init__.py | 50 +-- protobunny/__init__.py.j2 | 53 ++- protobunny/asyncio/__init__.py | 41 +- protobunny/asyncio/__init__.py.j2 | 42 +- protobunny/asyncio/backends/__init__.py | 20 +- .../asyncio/backends/mosquitto/connection.py | 18 - protobunny/asyncio/backends/nats/__init__.py | 1 + .../asyncio/backends/nats/connection.py | 408 ++++++++++++++++++ protobunny/asyncio/backends/nats/queues.py | 11 + .../asyncio/backends/python/connection.py | 16 - .../asyncio/backends/rabbitmq/connection.py | 22 +- .../asyncio/backends/rabbitmq/queues.py | 6 +- .../asyncio/backends/redis/connection.py | 131 +++--- protobunny/backends/__init__.py | 19 +- protobunny/backends/mosquitto/connection.py | 19 - protobunny/backends/nats/__init__.py | 1 + protobunny/backends/nats/connection.py | 32 ++ protobunny/backends/nats/queues.py | 35 ++ protobunny/backends/python/connection.py | 19 +- protobunny/backends/rabbitmq/connection.py | 18 - protobunny/backends/redis/connection.py | 18 - protobunny/config.py | 10 +- protobunny/helpers.py | 6 +- protobunny/models.py | 12 +- pyproject.toml | 18 +- setup.py | 2 +- tests/conftest.py | 62 ++- tests/test_connection.py | 39 +- tests/test_integration.py | 39 +- tests/test_queues.py | 46 +- tests/test_results.py | 1 - tests/test_tasks.py | 6 - tests/utils.py | 306 +++++++++---- uv.lock | 18 +- 40 files changed, 1171 insertions(+), 578 deletions(-) create mode 100644 protobunny/asyncio/backends/nats/__init__.py create mode 100644 protobunny/asyncio/backends/nats/connection.py create mode 100644 protobunny/asyncio/backends/nats/queues.py create mode 100644 protobunny/backends/nats/__init__.py create mode 100644 
protobunny/backends/nats/connection.py create mode 100644 protobunny/backends/nats/queues.py diff --git a/Makefile b/Makefile index 2ef6dde..5726b1a 100644 --- a/Makefile +++ b/Makefile @@ -36,7 +36,11 @@ t: PYTHONASYNCIODEBUG=1 PYTHONBREAKPOINT=ipdb.set_trace uv run pytest ${t} -s -vvvv --durations=0 integration-test: - uv run pytest tests/ -m "integration" ${t} + uv run pytest tests/ -m "integration" -k rabbitmq ${t} + uv run pytest tests/ -m "integration" -k redis ${t} + uv run pytest tests/ -m "integration" -k python ${t} + uv run pytest tests/ -m "integration" -k mosquitto ${t} + uv run pytest tests/ -m "integration" -k nats ${t} # run ./nats-server -js -sd nats_storage integration-test-py310: source .venv310/bin/activate diff --git a/QUICK_START.md b/QUICK_START.md index b0397a9..00a4690 100644 --- a/QUICK_START.md +++ b/QUICK_START.md @@ -21,7 +21,7 @@ You can also add it manually to pyproject.toml dependencies: ```toml dependencies = [ - "protobunny[rabbitmq, numpy]>=0.1.0", + "protobunny[rabbitmq, numpy]>=0.1.2a2", # your other dependencies ... ] ``` @@ -182,6 +182,7 @@ def worker2(task: mml.main.tasks.TaskMessage) -> None: pb.subscribe(mml.main.tasks.TaskMessage, worker1) pb.subscribe(mml.main.tasks.TaskMessage, worker2) + pb.publish(mml.main.tasks.TaskMessage(content="test1")) pb.publish(mml.main.tasks.TaskMessage(content="test2")) pb.publish(mml.main.tasks.TaskMessage(content="test3")) @@ -293,10 +294,14 @@ if conn.is_connected(): conn.close() ``` -If you set the `generated-package-root` folder option, you might need to add the path to your `sys.path`. +If you set the `generated-package-root` folder option, you might need to add that path to your `sys.path`. 
You can do it conveniently by calling `config_lib` on top of your module, before importing the library: + ```python -pb.config_lib() +import protobunny as pb +pb.config_lib() +# now you can import the library from the generated package root +import mymessagelib as mml ``` ## Complete example @@ -310,7 +315,7 @@ version = "0.1.0" description = "Project to test protobunny" requires-python = ">=3.10" dependencies = [ - "protobunny[rabbitmq,redis,numpy,mosquitto] >=0.1.2a1", + "protobunny[rabbitmq,redis,numpy,mosquitto]>=0.1.2a1", ] [tool.protobunny] diff --git a/README.md b/README.md index b901b62..59f8236 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ Note: The project is in early development. Protobunny is the open-source evolution of [AM-Flow](https://am-flow.com)'s internal messaging library. While the original was purpose-built for RabbitMQ, this version has been completely re-engineered to provide a unified, -type-safe interface for any message broker, including Redis and MQTT. +type-safe interface for several message brokers, including Redis and MQTT. It simplifies messaging for asynchronous tasks by providing: @@ -21,10 +21,10 @@ It simplifies messaging for asynchronous tasks by providing: * Transparently serialize "JSON-like" payload fields (numpy-friendly) -## Requirements +## Minimal requirements - Python >= 3.10, < 3.14 -- Backend message broker (e.g. RabbitMQ) + ## Project scope @@ -39,40 +39,44 @@ Protobunny is designed for teams who use messaging to coordinate work between mi - Optional validation of required fields - Builtin logging service ---- - -## Usage - -See the [Quick example on GitHub](https://github.com/am-flow/protobunny/blob/main/QUICK_START.md) for installation and quick start guide. +## Why Protobunny? -Full docs are available at [https://am-flow.github.io/protobunny/](https://am-flow.github.io/protobunny/). 
+While there are many messaging libraries for Python, Protobunny is built specifically for teams that treat **Protobuf as the single source of truth**. +* **Type-Safe by Design**: Built natively for `protobuf/betterproto`. +* **Semantic Routing**: Zero-config infrastructure. Protobunny uses your Protobuf package structure to decide if a message should be broadcast (Pub/Sub) or queued (Producer/Consumer). +* **Backend Agnostic**: Write your logic once. Switch between Redis, RabbitMQ, Mosquitto, or Local Queues by changing a single variable in configuration. +* **Sync & Async**: Support for both modern `asyncio` and traditional synchronous workloads. +* **Battle-Tested**: Derived from internal libraries used in production systems at AM-Flow. --- -## Development +### Feature Comparison with some existing libraries -### Run tests -```bash -make test -``` +| Feature | **Protobunny** | **FastStream** | **Celery** | +|:-----------------------|:-------------------------|:------------------------|:------------------------| +| **Multi-Backend** | ✅ Yes | ✅ Yes | ⚠️ (Tasks only) | +| **Typed Protobufs** | ✅ Native (Betterproto) | ⚠️ Manual/Pydantic | ❌ No | +| **Sync + Async** | ✅ Yes | ✅ Yes | ❌ Sync focus | +| **Pattern Routing** | ✅ Auto (`tasks` pkg) | ❌ Manual Config | ✅ Fixed | +| **Framework Agnostic** | ✅ Yes | ⚠️ FastAPI-like focus | ❌ Heavyweight | -### Integration tests (RabbitMQ required) -Integration tests expect RabbitMQ to be running (for example via Docker Compose in this repo): -```bash -docker compose up -d -make integration-test -``` ---- +## Usage -### Future work +See the [Quick example on GitHub](https://github.com/am-flow/protobunny/blob/main/QUICK_START.md) or on the [docs site](https://am-flow.github.io/protobunny/quickstart.html). 
-- Support grcp -- Support for RabbitMQ certificates (through `pika`) -- More backends: - - NATS - - Kafka - - Cloud providers (AWS SQS/SNS) +Documentation home page: [https://am-flow.github.io/protobunny/](https://am-flow.github.io/protobunny/). + +--- +### Roadmap + +- [x] **Core Support**: Redis, RabbitMQ, Mosquitto. +- [x] **Semantic Patterns**: Automatic `tasks` package routing. +- [x] **Arbistrary dictionary parsing**: Transparently parse JSON-like fields as dictionaries/lists by using protobunny JsonContent type. +- [x] **Result workflow**: Subscribe to results topics and receive protobunny `Result` messages produced by your callbacks. +- [ ] **Cloud-Native**: NATS (Core & JetStream) integration. +- [ ] **Cloud Providers**: AWS (SQS/SNS) and GCP Pub/Sub. +- [ ] **More backends**: Kafka support. --- diff --git a/docs/source/index.rst b/docs/source/index.rst index 6655250..bc12de3 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,57 +1,7 @@ Protobunny ========== - -.. warning:: The project is in early development. - -Protobunny is the open-source evolution of -`AM-Flow `__\ ’s internal messaging library. While -the original was purpose-built for RabbitMQ, this version has been -completely re-engineered to provide a unified, type-safe interface for -any message broker, including Redis and MQTT. - -It simplifies messaging for asynchronous tasks by providing: - -- A clean “message-first” API -- Python class generation from Protobuf messages using betterproto -- Connections facilities for backends -- Message publishing/subscribing with typed topics -- Support also “task-like” queues (shared/competing consumers) - vs. 
broadcast subscriptions -- Generate and consume ``Result`` messages (success/failure + optional - return payload) -- Transparent messages serialization/deserialization -- Support async and sync contexts -- Transparently serialize “JSON-like” payload fields (numpy-friendly) - -Requirements ------------- - -- Python >= 3.10, < 3.14 -- Backend message broker (e.g. RabbitMQ) - -Project scope -------------- - -Protobunny is designed for teams who use messaging to coordinate work -between microservices or different python processes and want: - -- A small API surface, easy to learn and use, both async and sync -- Typed messaging with protobuf messages as payloads -- Supports various backends by simple configuration: RabbitMQ, Redis, - Mosquitto, local in-process queues -- Consistent topic naming and routing -- Builtin task queue semantics and result messages -- Transparent handling of JSON-like payload fields as plain - dictionaries/lists -- Optional validation of required fields -- Builtin logging service - --------------- - -Usage ------ - -See the `Quick start guide `__. +.. include:: intro.md + :parser: myst_parser.sphinx_ -------------- diff --git a/docs/source/intro.md b/docs/source/intro.md index b901b62..59f8236 100644 --- a/docs/source/intro.md +++ b/docs/source/intro.md @@ -6,7 +6,7 @@ Note: The project is in early development. Protobunny is the open-source evolution of [AM-Flow](https://am-flow.com)'s internal messaging library. While the original was purpose-built for RabbitMQ, this version has been completely re-engineered to provide a unified, -type-safe interface for any message broker, including Redis and MQTT. +type-safe interface for several message brokers, including Redis and MQTT. 
It simplifies messaging for asynchronous tasks by providing: @@ -21,10 +21,10 @@ It simplifies messaging for asynchronous tasks by providing: * Transparently serialize "JSON-like" payload fields (numpy-friendly) -## Requirements +## Minimal requirements - Python >= 3.10, < 3.14 -- Backend message broker (e.g. RabbitMQ) + ## Project scope @@ -39,40 +39,44 @@ Protobunny is designed for teams who use messaging to coordinate work between mi - Optional validation of required fields - Builtin logging service ---- - -## Usage - -See the [Quick example on GitHub](https://github.com/am-flow/protobunny/blob/main/QUICK_START.md) for installation and quick start guide. +## Why Protobunny? -Full docs are available at [https://am-flow.github.io/protobunny/](https://am-flow.github.io/protobunny/). +While there are many messaging libraries for Python, Protobunny is built specifically for teams that treat **Protobuf as the single source of truth**. +* **Type-Safe by Design**: Built natively for `protobuf/betterproto`. +* **Semantic Routing**: Zero-config infrastructure. Protobunny uses your Protobuf package structure to decide if a message should be broadcast (Pub/Sub) or queued (Producer/Consumer). +* **Backend Agnostic**: Write your logic once. Switch between Redis, RabbitMQ, Mosquitto, or Local Queues by changing a single variable in configuration. +* **Sync & Async**: Support for both modern `asyncio` and traditional synchronous workloads. +* **Battle-Tested**: Derived from internal libraries used in production systems at AM-Flow. 
--- -## Development +### Feature Comparison with some existing libraries -### Run tests -```bash -make test -``` +| Feature | **Protobunny** | **FastStream** | **Celery** | +|:-----------------------|:-------------------------|:------------------------|:------------------------| +| **Multi-Backend** | ✅ Yes | ✅ Yes | ⚠️ (Tasks only) | +| **Typed Protobufs** | ✅ Native (Betterproto) | ⚠️ Manual/Pydantic | ❌ No | +| **Sync + Async** | ✅ Yes | ✅ Yes | ❌ Sync focus | +| **Pattern Routing** | ✅ Auto (`tasks` pkg) | ❌ Manual Config | ✅ Fixed | +| **Framework Agnostic** | ✅ Yes | ⚠️ FastAPI-like focus | ❌ Heavyweight | -### Integration tests (RabbitMQ required) -Integration tests expect RabbitMQ to be running (for example via Docker Compose in this repo): -```bash -docker compose up -d -make integration-test -``` ---- +## Usage -### Future work +See the [Quick example on GitHub](https://github.com/am-flow/protobunny/blob/main/QUICK_START.md) or on the [docs site](https://am-flow.github.io/protobunny/quickstart.html). -- Support grcp -- Support for RabbitMQ certificates (through `pika`) -- More backends: - - NATS - - Kafka - - Cloud providers (AWS SQS/SNS) +Documentation home page: [https://am-flow.github.io/protobunny/](https://am-flow.github.io/protobunny/). + +--- +### Roadmap + +- [x] **Core Support**: Redis, RabbitMQ, Mosquitto. +- [x] **Semantic Patterns**: Automatic `tasks` package routing. +- [x] **Arbistrary dictionary parsing**: Transparently parse JSON-like fields as dictionaries/lists by using protobunny JsonContent type. +- [x] **Result workflow**: Subscribe to results topics and receive protobunny `Result` messages produced by your callbacks. +- [ ] **Cloud-Native**: NATS (Core & JetStream) integration. +- [ ] **Cloud Providers**: AWS (SQS/SNS) and GCP Pub/Sub. +- [ ] **More backends**: Kafka support. 
--- diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md index b0397a9..87e348a 100644 --- a/docs/source/quick_start.md +++ b/docs/source/quick_start.md @@ -21,7 +21,7 @@ You can also add it manually to pyproject.toml dependencies: ```toml dependencies = [ - "protobunny[rabbitmq, numpy]>=0.1.0", + "protobunny[rabbitmq, numpy]>=0.1.2a2", # your other dependencies ... ] ``` @@ -182,6 +182,7 @@ def worker2(task: mml.main.tasks.TaskMessage) -> None: pb.subscribe(mml.main.tasks.TaskMessage, worker1) pb.subscribe(mml.main.tasks.TaskMessage, worker2) + pb.publish(mml.main.tasks.TaskMessage(content="test1")) pb.publish(mml.main.tasks.TaskMessage(content="test2")) pb.publish(mml.main.tasks.TaskMessage(content="test3")) @@ -295,8 +296,12 @@ if conn.is_connected(): If you set the `generated-package-root` folder option, you might need to add the path to your `sys.path`. You can do it conveniently by calling `config_lib` on top of your module, before importing the library: + ```python -pb.config_lib() +import protobunny as pb +pb.config_lib() +# now you can import the library from the generated package root +import mymessagelib as mml ``` ## Complete example @@ -310,7 +315,7 @@ version = "0.1.0" description = "Project to test protobunny" requires-python = ">=3.10" dependencies = [ - "protobunny[rabbitmq,redis,numpy,mosquitto] >=0.1.2a1", + "protobunny[rabbitmq,redis,numpy,mosquitto]>=0.1.2a1", ] [tool.protobunny] diff --git a/protobunny/__init__.py b/protobunny/__init__.py index db209ff..453f700 100644 --- a/protobunny/__init__.py +++ b/protobunny/__init__.py @@ -1,15 +1,8 @@ """ -A module providing support for messaging and communication using RabbitMQ as the backend. +A module providing support for messaging and communication using the configured broker as the backend. This module includes functionality for publishing, subscribing, and managing message queues, -as well as dynamically managing imports and configurations for RabbitMQ-based communication -logics. 
It enables both synchronous and asynchronous operations, while also supporting -connection resetting and management. - -Modules and functionality are primarily imported from the core RabbitMQ backend, dynamically -generated package-specific configurations, and other base utilities. Exports are adjusted -as per the backend configuration. - +as well as dynamically managing imports and configurations for the backend. """ __all__ = [ @@ -47,7 +40,7 @@ from importlib.metadata import version from types import FrameType, ModuleType -from .backends import BaseSyncQueue, LoggingSyncQueue +from .backends import BaseSyncConnection, BaseSyncQueue, LoggingSyncQueue from .config import ( # noqa GENERATED_PACKAGE_NAME, PACKAGE_NAME, @@ -67,31 +60,34 @@ SyncCallback, ) +__version__ = version(PACKAGE_NAME) + +log = logging.getLogger(PACKAGE_NAME) + ############################ # -- Sync top-level methods ############################ -def reset_connection(): - backend = get_backend() - return backend.connection.reset_connection() +def connect() -> "BaseSyncConnection": + """Get the singleton async connection.""" + connection_module = get_backend().connection + conn = connection_module.Connection.get_connection(vhost=connection_module.VHOST) + return conn -def connect(): - backend = get_backend() - return backend.connection.connect() +def disconnect() -> None: + connection_module = get_backend().connection + conn = connection_module.Connection.get_connection(vhost=connection_module.VHOST) + conn.disconnect() -def disconnect(): - backend = get_backend() - return backend.connection.disconnect() - - -__version__ = version(PACKAGE_NAME) - - -log = logging.getLogger(PACKAGE_NAME) +def reset_connection() -> "BaseSyncConnection": + """Reset the singleton connection.""" + connection = connect() + connection.disconnect() + return connect() def publish(message: "ProtoBunnyMessage") -> None: @@ -274,9 +270,9 @@ def shutdown(signum: int, _: FrameType | None) -> None: signal.signal(signal.SIGINT, 
shutdown) signal.signal(signal.SIGTERM, shutdown) - log.info("Protobunny Started. Press Ctrl+C to exit.") + log.info("Protobunny Started.") signal.pause() - log.info("Protobunny Stopped. Press Ctrl+C to exit.") + log.info("Protobunny Stopped.") def config_lib() -> None: diff --git a/protobunny/__init__.py.j2 b/protobunny/__init__.py.j2 index 71a0e5c..3703ba3 100644 --- a/protobunny/__init__.py.j2 +++ b/protobunny/__init__.py.j2 @@ -1,16 +1,10 @@ """ -A module providing support for messaging and communication using RabbitMQ as the backend. +A module providing support for messaging and communication using the configured broker as the backend. This module includes functionality for publishing, subscribing, and managing message queues, -as well as dynamically managing imports and configurations for RabbitMQ-based communication -logics. It enables both synchronous and asynchronous operations, while also supporting -connection resetting and management. - -Modules and functionality are primarily imported from the core RabbitMQ backend, dynamically -generated package-specific configurations, and other base utilities. Exports are adjusted -as per the backend configuration. - +as well as dynamically managing imports and configurations for the backend. 
""" + __all__ = [ "get_backend", "get_message_count", @@ -59,37 +53,40 @@ from .exceptions import RequeueMessage, ConnectionError from .registry import default_registry from .helpers import get_backend, get_queue -from .backends import LoggingSyncQueue, BaseSyncQueue +from .backends import LoggingSyncQueue, BaseSyncQueue, BaseSyncConnection if tp.TYPE_CHECKING: from .models import LoggerCallback, ProtoBunnyMessage, SyncCallback, IncomingMessageProtocol from .core.results import Result +__version__ = version(PACKAGE_NAME) + +log = logging.getLogger(PACKAGE_NAME) + ############################ # -- Sync top-level methods ############################ -def reset_connection(): - backend = get_backend() - return backend.connection.reset_connection() +def connect() -> "BaseSyncConnection": + """Get the singleton async connection.""" + connection_module = get_backend().connection + conn = connection_module.Connection.get_connection(vhost=connection_module.VHOST) + return conn -def connect(): - backend = get_backend() - return backend.connection.connect() +def disconnect() -> None: + connection_module = get_backend().connection + conn = connection_module.Connection.get_connection(vhost=connection_module.VHOST) + conn.disconnect() -def disconnect(): - backend = get_backend() - return backend.connection.disconnect() - - -__version__ = version(PACKAGE_NAME) - - -log = logging.getLogger(PACKAGE_NAME) +def reset_connection() -> "BaseSyncConnection": + """Reset the singleton connection.""" + connection = connect() + connection.disconnect() + return connect() def publish(message: "ProtoBunnyMessage") -> None: @@ -271,9 +268,9 @@ def run_forever() -> None: signal.signal(signal.SIGINT, shutdown) signal.signal(signal.SIGTERM, shutdown) - log.info("Protobunny Started. Press Ctrl+C to exit.") + log.info("Protobunny Started.") signal.pause() - log.info("Protobunny Stopped. 
Press Ctrl+C to exit.") + log.info("Protobunny Stopped.") def config_lib() -> None: @@ -290,4 +287,4 @@ from .{{ generated_package_name }} import ( # noqa {{ i }}, {% endfor %} ) -####################################################### \ No newline at end of file +####################################################### diff --git a/protobunny/asyncio/__init__.py b/protobunny/asyncio/__init__.py index 841b715..fcd15e7 100644 --- a/protobunny/asyncio/__init__.py +++ b/protobunny/asyncio/__init__.py @@ -1,6 +1,3 @@ -import asyncio -import signal - """ A module providing support for messaging and communication using RabbitMQ as the backend. @@ -14,6 +11,7 @@ as per the backend configuration. """ + __all__ = [ "get_message_count", "get_queue", @@ -36,14 +34,17 @@ "connect", "disconnect", "run_forever", + "config_lib", # from .core "commons", "results", ] +import asyncio import inspect import itertools import logging +import signal import textwrap import typing as tp from importlib.metadata import version @@ -70,31 +71,39 @@ ) +from .. 
import config_lib as config_lib from ..helpers import get_backend, get_queue -from .backends import BaseAsyncQueue, LoggingAsyncQueue +from .backends import BaseAsyncConnection, BaseAsyncQueue, LoggingAsyncQueue __version__ = version(PACKAGE_NAME) + +log = logging.getLogger(PACKAGE_NAME) + + ############################ # -- Async top-level methods ############################ -log = logging.getLogger(PACKAGE_NAME) +async def connect() -> "BaseAsyncConnection": + """Get the singleton async connection.""" + connection_module = get_backend().connection + conn = await connection_module.Connection.get_connection(vhost=connection_module.VHOST) + return conn -async def reset_connection(): - backend = get_backend() - return await backend.connection.reset_connection() +async def disconnect() -> None: + connection_module = get_backend().connection + conn = await connection_module.Connection.get_connection(vhost=connection_module.VHOST) + await conn.disconnect() -async def connect(): - backend = get_backend() - return await backend.connection.connect() - -async def disconnect(): - backend = get_backend() - return await backend.connection.disconnect() +async def reset_connection() -> "BaseAsyncConnection": + """Reset the singleton connection.""" + connection = await connect() + await connection.disconnect() + return await connect() async def publish(message: "ProtoBunnyMessage") -> None: @@ -297,3 +306,5 @@ def _handler(s: int) -> asyncio.Task[None]: commons, results, ) + +####################################################### diff --git a/protobunny/asyncio/__init__.py.j2 b/protobunny/asyncio/__init__.py.j2 index 43065da..9e5e8ee 100644 --- a/protobunny/asyncio/__init__.py.j2 +++ b/protobunny/asyncio/__init__.py.j2 @@ -1,7 +1,3 @@ -{# @formatter:off #} -{# language: python #}import signal -import asyncio - """ A module providing support for messaging and communication using RabbitMQ as the backend. 
@@ -37,12 +33,15 @@ __all__ = [ "connect", "disconnect", "run_forever", + "config_lib", # from .{{ generated_package_name }} {% for i in main_imports|sort %} "{{ i }}", {% endfor %} ] +import signal +import asyncio import inspect import itertools import logging @@ -67,32 +66,38 @@ if tp.TYPE_CHECKING: from types import ModuleType -from .backends import LoggingAsyncQueue, BaseAsyncQueue +from .backends import LoggingAsyncQueue, BaseAsyncQueue, BaseAsyncConnection from ..helpers import get_queue, get_backend - +from .. import config_lib as config_lib __version__ = version(PACKAGE_NAME) -############################ -# -- Async top-level methods -############################ log = logging.getLogger(PACKAGE_NAME) -async def reset_connection(): - backend = get_backend() - return await backend.connection.reset_connection() +############################ +# -- Async top-level methods +############################ + +async def connect() -> "BaseAsyncConnection": + """Get the singleton async connection.""" + connection_module = get_backend().connection + conn = await connection_module.Connection.get_connection(vhost=connection_module.VHOST) + return conn -async def connect(): - backend = get_backend() - return await backend.connection.connect() +async def disconnect() -> None: + connection_module = get_backend().connection + conn = await connection_module.Connection.get_connection(vhost=connection_module.VHOST) + await conn.disconnect() -async def disconnect(): - backend = get_backend() - return await backend.connection.disconnect() +async def reset_connection() -> "BaseAsyncConnection": + """Reset the singleton connection.""" + connection = await connect() + await connection.disconnect() + return await connect() async def publish(message: "ProtoBunnyMessage") -> None: @@ -295,3 +300,4 @@ from ..{{ generated_package_name }} import ( # noqa {{ i }}, {% endfor %} ) +####################################################### diff --git a/protobunny/asyncio/backends/__init__.py 
b/protobunny/asyncio/backends/__init__.py index 04c16a4..30786c5 100644 --- a/protobunny/asyncio/backends/__init__.py +++ b/protobunny/asyncio/backends/__init__.py @@ -4,8 +4,9 @@ import typing as tp from abc import ABC, abstractmethod +from protobunny import asyncio as pb + from ...exceptions import RequeueMessage -from ...helpers import get_backend from ...models import ( AsyncCallback, BaseQueue, @@ -99,7 +100,7 @@ def get_consumer_count(self, topic: str) -> int | tp.Awaitable[int]: ... @abstractmethod - def setup_queue(self, topic: str, shared: bool) -> tp.Any | tp.Awaitable[tp.Any]: + def setup_queue(self, topic: str, shared: bool, **kwargs) -> tp.Any | tp.Awaitable[tp.Any]: ... @@ -170,8 +171,7 @@ def is_connected(self) -> bool: class BaseAsyncQueue(BaseQueue, ABC): async def get_connection(self) -> "BaseAsyncConnection": - backend = get_backend() - return await backend.connection.connect() + return await pb.connect() async def publish(self, message: ProtoBunnyMessage) -> None: """Publish a message to the queue. 
@@ -344,12 +344,11 @@ async def send_message( Returns: """ - backend = get_backend() message = Envelope( body=body, correlation_id=correlation_id or b"", ) - conn = await backend.connection.connect() + conn = await pb.connect() await conn.publish(topic, message) @@ -434,6 +433,15 @@ async def _receive( def is_task(topic: str) -> bool: + """ + Use the backend configured delimiter to check if `tasks` is in it + + Args: + topic: the topic to check + + + Returns: True if tasks is in the topic, else False + """ delimiter = default_configuration.backend_config.topic_delimiter return "tasks" in topic.split(delimiter) diff --git a/protobunny/asyncio/backends/mosquitto/connection.py b/protobunny/asyncio/backends/mosquitto/connection.py index 4036dad..4256a8e 100644 --- a/protobunny/asyncio/backends/mosquitto/connection.py +++ b/protobunny/asyncio/backends/mosquitto/connection.py @@ -20,24 +20,6 @@ VHOST = os.environ.get("MOSQUITTO_VHOST", "/") -async def connect() -> "Connection": - """Get the singleton async connection.""" - conn = await Connection.get_connection(vhost=VHOST) - return conn - - -async def reset_connection() -> "Connection": - """Reset the singleton connection.""" - connection = await connect() - await connection.disconnect() - return await connect() - - -async def disconnect() -> None: - connection = await connect() - await connection.disconnect() - - class Connection(BaseAsyncConnection): """Async Mosquitto Connection wrapper using aiomqtt.""" diff --git a/protobunny/asyncio/backends/nats/__init__.py b/protobunny/asyncio/backends/nats/__init__.py new file mode 100644 index 0000000..d45bfe6 --- /dev/null +++ b/protobunny/asyncio/backends/nats/__init__.py @@ -0,0 +1 @@ +from . 
import connection, queues # noqa diff --git a/protobunny/asyncio/backends/nats/connection.py b/protobunny/asyncio/backends/nats/connection.py new file mode 100644 index 0000000..834b23b --- /dev/null +++ b/protobunny/asyncio/backends/nats/connection.py @@ -0,0 +1,408 @@ +"""Implements a NATS Connection""" +import asyncio +import functools +import logging +import os +import typing as tp +import urllib.parse +import uuid +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor + +import can_ada +import nats +from nats.aio.subscription import Subscription +from nats.errors import ConnectionClosedError, TimeoutError +from nats.js.errors import BadRequestError, NoStreamResponseError + +from ....config import default_configuration +from ....exceptions import ConnectionError, PublishError, RequeueMessage +from ....models import Envelope, IncomingMessageProtocol +from .. import BaseAsyncConnection, is_task + +log = logging.getLogger(__name__) + +VHOST = os.environ.get("NATS_VHOST", "/") + + +class Connection(BaseAsyncConnection): + """Async NATS Connection wrapper.""" + + _lock: asyncio.Lock | None = None + instance_by_vhost: dict[str, "Connection | None"] = {} + + def __init__( + self, + username: str | None = None, + password: str | None = None, + host: str | None = None, + port: int | None = None, + vhost: str = "", + url: str | None = None, + worker_threads: int = 2, + prefetch_count: int = 1, + requeue_delay: int = 3, + heartbeat: int = 1200, + ): + """Initialize NATS connection. + + Args: + username: NATS username + password: NATS password + host: NATS host + port: NATS port + url: NATS URL. 
It will override username, password, host and port + vhost: NATS virtual host (it's used as db number string) + worker_threads: number of concurrent callback workers to use + prefetch_count: how many messages to prefetch from the queue + requeue_delay: how long to wait before re-queueing a message (seconds) + """ + super().__init__() + uname = username or os.environ.get("NATS_USERNAME", "") + passwd = password or os.environ.get("NATS_PASSWORD", "") + host = host or os.environ.get("NATS_HOST", "localhost") + port = port or int(os.environ.get("NATS_PORT", "4222")) + # URL encode credentials and vhost to prevent injection + vhost = vhost or VHOST + self.vhost = vhost + username = urllib.parse.quote(uname, safe="") + password = urllib.parse.quote(passwd, safe="") + host = urllib.parse.quote(host, safe="") + # URL for connection + url = url or os.environ.get("NATS_URL", "") + if url: + # reconstruct url for safety + parsed = can_ada.parse(url) + url = f"{parsed.protocol}//{parsed.username}:{parsed.password}@{parsed.host}{parsed.pathname}{parsed.search}" + else: + # Build the URL based on what is available + if username and password: + url = f"nats://{username}:{password}@{host}:{port}{vhost}" + elif password: + url = f"nats://:{password}@{host}:{port}{vhost}" + elif username: + url = f"nats://{username}@{host}:{port}{vhost}" + else: + url = f"nats://{host}:{port}{vhost}" + + self._url = url + self._connection: nats.NATS | None = None + self.prefetch_count = prefetch_count + self.requeue_delay = requeue_delay + self.heartbeat = heartbeat + self.queues: dict[str, list[dict]] = defaultdict(list) + self.consumers: dict[str, dict] = {} + self.executor = ThreadPoolExecutor(max_workers=worker_threads) + self._instance_lock: asyncio.Lock | None = None + + self._delimiter = default_configuration.backend_config.topic_delimiter + self._exchange = default_configuration.backend_config.namespace + self._stream_name = f"{self._exchange.upper()}_TASKS" + + async def __aenter__(self) -> 
"Connection": + await self.connect() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool: + await self.disconnect() + return False + + @property + def lock(self) -> asyncio.Lock: + """Lazy instance lock.""" + if self._instance_lock is None: + self._instance_lock = asyncio.Lock() + return self._instance_lock + + @classmethod + def _get_class_lock(cls) -> asyncio.Lock: + """Ensure the class lock is bound to the current running loop.""" + if cls._lock is None: + cls._lock = asyncio.Lock() + return cls._lock + + def build_topic_key(self, topic: str) -> str: + return f"{self._exchange}.{topic}" + + @property + def is_connected_event(self) -> asyncio.Event: + """Lazily create the event in the current running loop.""" + if self._is_connected_event is None: + self._is_connected_event = asyncio.Event() + return self._is_connected_event + + @property + def connection(self) -> "nats.NATS": + """Get the connection object. + + Raises: + ConnectionError: If not connected + """ + if not self._connection: + raise ConnectionError("Connection not initialized. Call connect() first.") + return self._connection + + async def connect(self, timeout: float = 30.0) -> "Connection": + """Establish NATS connection. 
+ + Args: + timeout: Maximum time to wait for connection establishment (seconds) + + Raises: + ConnectionError: If connection fails + asyncio.TimeoutError: If connection times out + """ + async with self.lock: + if self.instance_by_vhost.get(self.vhost) and self.is_connected(): + return self.instance_by_vhost[self.vhost] + try: + log.info("Establishing NATS connection to %s", self._url.split("@")[-1]) + self._connection = await nats.connect( + self._url, connect_timeout=timeout, max_reconnect_attempts=3 + ) + self.is_connected_event.set() + log.info("Successfully connected to NATS") + self.instance_by_vhost[self.vhost] = self + if default_configuration.use_tasks_in_nats: + # Create the jetstream if not existing + js = self._connection.jetstream() + # For NATS, tasks package can only be at first level after main package library + # Warning: don't bury tasks messages after three levels of hierarchy + task_patterns = [ + f"{self._exchange}.*.tasks.>", + ] + try: + await js.add_stream( + name=self._stream_name, + subjects=task_patterns, + ) + except BadRequestError: + # This usually means the stream already exists with a different config + log.warning("Stream %s exists with different settings.", self._stream_name) + return self + + except asyncio.TimeoutError as e: + log.error("NATS connection timeout after %.1f seconds", timeout) + self.is_connected_event.clear() + self._connection = None + raise ConnectionError(f"Failed to connect to NATS: {e}") from e + except Exception as e: + self.is_connected_event.clear() + self._connection = None + log.exception("Failed to establish NATS connection") + raise ConnectionError(f"Failed to connect to NATS: {e}") from e + + async def disconnect(self, timeout: float = 10.0) -> None: + """Close NATS connection and cleanup resources. 
+ + Args: + timeout: Maximum time to wait for cleanup (seconds) + """ + async with self.lock: + if not self.is_connected(): + log.debug("Already disconnected from NATS") + return + + try: + log.info("Closing NATS connection") + # Cancel all subscriptions + for tag, consumer in self.consumers.items(): + subscription = consumer["subscription"] + try: + await subscription.unsubscribe() + # We give the task a moment to wrap up if needed + # await asyncio.sleep(0) # force context switching + # await asyncio.wait([task], timeout=2.0) + except Exception as e: + log.warning("Error stopping NATS subscription %s: %s", tag, e) + + # Shutdown Thread Executor (if used for sync callbacks) + self.executor.shutdown(wait=False, cancel_futures=True) + + # Close the NATS Connection Pool + if self._connection: + await asyncio.wait_for(self._connection.close(), timeout=timeout) + + except asyncio.TimeoutError: + log.warning("NATS connection close timeout after %.1f seconds", timeout) + except Exception: + log.exception("Error during NATS disconnect") + finally: + # Reset state + self._connection = None + self.queues.clear() # (Local queue metadata) + self.consumers.clear() + self.is_connected_event.clear() + # Remove from registry + Connection.instance_by_vhost.pop(self.vhost, None) + log.info("NATS connection closed") + + # Subscriptions methods + async def setup_queue( + self, topic: str, shared: bool, callback: tp.Callable | None = None + ) -> Subscription: + topic_key = self.build_topic_key(topic) + cb = functools.partial(self._nats_handler, callback) + if shared: + log.debug("Subscribing shared worker to JetStream: %s", topic_key) + js = self._connection.jetstream() + # We use a durable name so multiple instances share the same task state + group_name = f"{self._exchange}_{topic_key.replace('.', '_')}" + subscription = await js.subscribe( + subject=topic_key, + durable=group_name, + cb=cb, + manual_ack=True, + stream=self._stream_name, + ) + else: + log.debug("Subscribing 
broadcast listener to NATS Core: %s", topic_key) + subscription = await self._connection.subscribe(subject=topic_key, cb=cb) + return subscription + + async def subscribe(self, topic: str, callback: tp.Callable, shared: bool = False) -> str: + async with self.lock: + if not self.is_connected(): + raise ConnectionError("Not connected to NATS") + + topic_key = self.build_topic_key(topic) + sub_tag = f"{topic_key}_{uuid.uuid4().hex[:8]}" + subscription = await self.setup_queue(topic, shared, callback) + self.consumers[sub_tag] = { + "subscription": subscription, + "topic": topic_key, + "is_shared": shared, + } + return sub_tag + + async def _nats_handler(self, callback, msg): + topic = msg.subject + reply = msg.reply + body = msg.data + is_shared_queue = is_task(topic) + routing_key = msg.subject.removeprefix(f"{self._exchange}{self._delimiter}") + envelope = Envelope(body=body, correlation_id=reply, routing_key=routing_key) + try: + if asyncio.iscoroutinefunction(callback): + await callback(envelope) + else: + asyncio.run_coroutine_threadsafe(callback(envelope), self._loop) + if is_shared_queue: + await msg.ack() + except RequeueMessage: + log.warning("Requeuing message on topic '%s' after RequeueMessage exception", topic) + await asyncio.sleep(self.requeue_delay) + if not is_shared_queue: + await self._connection.publish(topic, body, reply=reply) + else: + # TODO check if NATS has a requeue logic + js = self._connection.jetstream() + await js.publish(topic, body) + await msg.ack() + except Exception: + log.exception("Callback failed for topic %s", topic) + # TODO check if NATS has a reject logic + await msg.ack() # avoid retry logic for potentially poisoning messages + + async def unsubscribe(self, tag: str, **kwargs) -> None: + if tag not in self.consumers: + return + sub_info = self.consumers[tag] + await sub_info["subscription"].unsubscribe() + del sub_info["subscription"] + log.info("Unsubscribed from %s", sub_info["topic"]) + self.consumers.pop(tag) + # TODO 
check if we need to handle self.queues[topic] cleanup here + + async def publish( + self, + topic: str, + message: "IncomingMessageProtocol", + **kwargs, + ) -> None: + if not self.is_connected(): + raise ConnectionError("Not connected to NATS") + + topic_key = self.build_topic_key(topic) + is_shared = is_task(topic) + + # Standardize headers + headers = {"correlation_id": message.correlation_id} if message.correlation_id else None + + try: + if is_shared: + # Persistent "Task" publishing via JetStream + log.debug("Publishing persistent task to NATS JetStream: %s", topic_key) + js = self._connection.jetstream() + await js.publish(subject=topic_key, payload=message.body, headers=headers) + else: + # Volatile "PubSub" publishing via NATS Core + log.debug("Publishing broadcast to NATS Core: %s", topic_key) + await self._connection.publish( + subject=topic_key, payload=message.body, headers=headers + ) + except (ConnectionClosedError, TimeoutError, NoStreamResponseError, Exception) as e: + log.error("NATS publish failed: %s", e) + raise PublishError(str(e)) from e + + async def purge(self, topic: str, reset_groups: bool = False) -> None: + if not is_task(topic): + raise ValueError("Purge only supported for tasks") + async with self.lock: + if not self.is_connected(): + raise ConnectionError("Not connected to NATS") + topic_key = self.build_topic_key(topic) + # NATS purges messages matching a subject within the stream + try: + jsm = self._connection.jsm() # Get JetStream Management context + + log.info("Purging NATS subject '%s' from stream %s", topic, self._stream_name) + await jsm.purge_stream(self._stream_name, subject=topic_key) + + if reset_groups: + # In NATS, we must find consumers specifically tied to this topic + # Protobunny convention: durable name includes the topic + group_name = f"{self._exchange}_{topic_key.replace('.', '_')}" + try: + await jsm.delete_consumer(self._stream_name, group_name) + log.debug("Deleted NATS durable consumer: %s", group_name) + 
except nats.js.errors.NotFoundError: + pass # Consumer already gone + + except Exception as e: + log.error("Failed to purge NATS subject %s: %s", topic, e) + raise ConnectionError(f"Purge failed: {e}") from e + + async def get_message_count(self, topic: str) -> int: + if not is_task(topic): + raise ValueError("Message count only supported for tasks") + async with self.lock: + if not self.is_connected(): + raise ConnectionError("Not connected to NATS") + topic_key = self.build_topic_key(topic) + try: + jsm = self._connection.jsm() + stream_info = await jsm.stream_info(self._stream_name, subjects_filter=topic_key) + return stream_info.state.messages + except nats.js.errors.NotFoundError: + return 0 + except Exception as e: + log.error("Failed to get NATS message count for %s: %s", topic_key, e) + return 0 + + async def get_consumer_count(self, topic: str) -> int: + topic_key = self.build_topic_key(topic) + if not is_task(topic): + raise ValueError("Consumer count only supported for tasks") + async with self.lock: + if not self.is_connected(): + raise ConnectionError("Not connected to NATS") + try: + jsm = self._connection.jsm() + stream_info = await jsm.stream_info(self._stream_name, subjects_filter=topic_key) + return stream_info.state.consumer_count + except nats.js.errors.NotFoundError: + return 0 + except Exception as e: + log.error("Failed to get NATS consumer count for %s: %s", topic_key, e) + return 0 diff --git a/protobunny/asyncio/backends/nats/queues.py b/protobunny/asyncio/backends/nats/queues.py new file mode 100644 index 0000000..4777c82 --- /dev/null +++ b/protobunny/asyncio/backends/nats/queues.py @@ -0,0 +1,11 @@ +import logging + +from ..
import ( + BaseAsyncQueue, +) + +log = logging.getLogger(__name__) + + +class AsyncQueue(BaseAsyncQueue): + pass diff --git a/protobunny/asyncio/backends/python/connection.py b/protobunny/asyncio/backends/python/connection.py index 5073282..01fbe31 100644 --- a/protobunny/asyncio/backends/python/connection.py +++ b/protobunny/asyncio/backends/python/connection.py @@ -327,19 +327,3 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): await self.disconnect() - - -# Convenience functions -async def connect() -> Connection: - return await Connection.get_connection(vhost=VHOST) - - -async def reset_connection() -> Connection: - connection = await connect() - await connection.disconnect() - return await connect() - - -async def disconnect() -> None: - connection = await connect() - await connection.disconnect() diff --git a/protobunny/asyncio/backends/rabbitmq/connection.py b/protobunny/asyncio/backends/rabbitmq/connection.py index 772051c..62bdbb3 100644 --- a/protobunny/asyncio/backends/rabbitmq/connection.py +++ b/protobunny/asyncio/backends/rabbitmq/connection.py @@ -9,6 +9,7 @@ from concurrent.futures import ThreadPoolExecutor import aio_pika +import can_ada from aio_pika.abc import ( AbstractChannel, AbstractExchange, @@ -24,24 +25,6 @@ VHOST = os.environ.get("RABBITMQ_VHOST", "/") -async def connect() -> "Connection": - """Get the singleton async connection.""" - conn = await Connection.get_connection(vhost=VHOST) - return conn - - -async def reset_connection() -> "Connection": - """Reset the singleton connection.""" - connection = await connect() - await connection.disconnect() - return await connect() - - -async def disconnect() -> None: - connection = await connect() - await connection.disconnect() - - class Connection(BaseAsyncConnection): """Async RabbitMQ Connection wrapper.""" @@ -97,7 +80,8 @@ def __init__( if not url: self._url = 
f"amqp://{username}:{password}@{host}:{port}/{clean_vhost}?heartbeat={heartbeat}&timeout={timeout}&fail_fast=no" else: - self._url = url + parsed = can_ada.parse(url) + self._url = f"amqp://{parsed.username}:{parsed.password}@{parsed.host}{parsed.pathname}{parsed.search}" self._exchange_name = exchange_name self._dl_exchange = dl_exchange self._dl_queue = dl_queue diff --git a/protobunny/asyncio/backends/rabbitmq/queues.py b/protobunny/asyncio/backends/rabbitmq/queues.py index e516b2f..9788d68 100644 --- a/protobunny/asyncio/backends/rabbitmq/queues.py +++ b/protobunny/asyncio/backends/rabbitmq/queues.py @@ -3,10 +3,12 @@ import aio_pika from aio_pika import DeliveryMode +from protobunny import asyncio as pb from protobunny.asyncio.backends import ( BaseAsyncQueue, ) -from protobunny.asyncio.backends.rabbitmq.connection import connect + +# from protobunny.asyncio.backends.rabbitmq.connection import connect log = logging.getLogger(__name__) @@ -35,5 +37,5 @@ async def send_message( correlation_id=correlation_id, delivery_mode=DeliveryMode.PERSISTENT if persistent else DeliveryMode.NOT_PERSISTENT, ) - conn = await connect() + conn = await pb.connect() await conn.publish(topic, message) diff --git a/protobunny/asyncio/backends/redis/connection.py b/protobunny/asyncio/backends/redis/connection.py index b1bc29a..a955e6f 100644 --- a/protobunny/asyncio/backends/redis/connection.py +++ b/protobunny/asyncio/backends/redis/connection.py @@ -9,6 +9,7 @@ import uuid from concurrent.futures import ThreadPoolExecutor +import can_ada import redis.asyncio as redis from redis import RedisError, ResponseError @@ -22,24 +23,6 @@ VHOST = os.environ.get("REDIS_VHOST") or os.environ.get("REDIS_DB", "0") -async def connect() -> "Connection": - """Get the singleton async connection.""" - conn = await Connection.get_connection(vhost=VHOST) - return conn - - -async def reset_connection() -> "Connection": - """Reset the singleton connection.""" - connection = await connect() - await 
connection.disconnect() - return await connect() - - -async def disconnect() -> None: - connection = await connect() - await connection.disconnect() - - class Connection(BaseAsyncConnection): """Async Redis Connection wrapper.""" @@ -98,8 +81,13 @@ def __init__( url = f"redis://{username}:{password}@{host}:{port}/{vhost}?protocol=3" elif password: url = f"redis://:{password}@{host}:{port}/{vhost}?protocol=3" + elif username: + url = f"redis://{username}@{host}:{port}/{vhost}?protocol=3" else: url = f"redis://{host}:{port}/{vhost}?protocol=3" + else: + parsed = can_ada.parse(url) + url = f"redis://{parsed.username}:{parsed.password}@{parsed.host}{parsed.pathname}{parsed.search}" self._url = url self._connection: redis.Redis | None = None @@ -139,57 +127,6 @@ def _get_class_lock(cls) -> asyncio.Lock: def build_topic_key(self, topic: str) -> str: return f"{self._exchange}:{topic}" - async def disconnect(self, timeout: float = 10.0) -> None: - """Close Redis connection and cleanup resources. - - Args: - timeout: Maximum time to wait for cleanup (seconds) - """ - async with self.lock: - if not self.is_connected(): - log.debug("Already disconnected from Redis") - return - - try: - log.info("Closing Redis connection") - - # In Redis, consumers are local asyncio Tasks. - # We cancel them here. Note: Redis doesn't have "exclusive queues" - # that auto-delete, so we just clear our local registry. 
- for tag, consumer in self.consumers.items(): - task = consumer["task"] - try: - task.cancel() - # We give the task a moment to wrap up if needed - await asyncio.sleep(0) # force context switching - await asyncio.wait([task], timeout=2.0) - except Exception as e: - log.warning("Error stopping Redis consumer %s: %s", tag, e) - - # Shutdown Thread Executor (if used for sync callbacks) - self.executor.shutdown(wait=False, cancel_futures=True) - - # Close the Redis Connection Pool - if self._connection: - # aclose() closes the connection pool and all underlying connections - await asyncio.wait_for(self._connection.aclose(), timeout=timeout) - - except asyncio.TimeoutError: - log.warning("Redis connection close timeout after %.1f seconds", timeout) - except Exception: - log.exception("Error during Redis disconnect") - finally: - # Reset state - self._connection = None - self._exchange = None # (Stream name/prefix) - self.queues.clear() # (Local queue metadata) - self.consumers.clear() - self.is_connected_event.clear() - - # 5. Remove from registry - Connection.instance_by_vhost.pop(self.vhost, None) - log.info("Redis connection closed") - @property def is_connected_event(self) -> asyncio.Event: """Lazily create the event in the current running loop.""" @@ -253,6 +190,54 @@ async def connect(self, timeout: float = 30.0) -> "Connection": log.exception("Failed to establish Redis connection") raise ConnectionError(f"Failed to connect to Redis: {e}") from e + async def disconnect(self, timeout: float = 10.0) -> None: + """Close Redis connection and cleanup resources. + + Args: + timeout: Maximum time to wait for cleanup (seconds) + """ + async with self.lock: + if not self.is_connected(): + log.debug("Already disconnected from Redis") + return + + try: + log.info("Closing Redis connection") + # In Redis, consumers are local asyncio Tasks. + # We cancel them here. Note: Redis doesn't have "exclusive queues" + # that auto-delete, so we just clear our local registry. 
+ for tag, consumer in self.consumers.items(): + task = consumer["task"] + try: + task.cancel() + # We give the task a moment to wrap up if needed + await asyncio.sleep(0) # force context switching + await asyncio.wait([task], timeout=2.0) + except Exception as e: + log.warning("Error stopping Redis consumer %s: %s", tag, e) + + # Shutdown Thread Executor (if used for sync callbacks) + self.executor.shutdown(wait=False, cancel_futures=True) + + # Close the Redis Connection Pool + if self._connection: + # aclose() closes the connection pool and all underlying connections + await asyncio.wait_for(self._connection.aclose(), timeout=timeout) + + except asyncio.TimeoutError: + log.warning("Redis connection close timeout after %.1f seconds", timeout) + except Exception: + log.exception("Error during Redis disconnect") + finally: + # Reset state + self._connection = None + self.queues.clear() # (Local queue metadata) + self.consumers.clear() + self.is_connected_event.clear() + # Remove from registry + Connection.instance_by_vhost.pop(self.vhost, None) + log.info("Redis connection closed") + async def subscribe(self, topic: str, callback: tp.Callable, shared: bool = False) -> str: """Subscribe to Redis. @@ -511,7 +496,7 @@ async def publish( # Tasks messages go to streams but the logger do a simple pubsub psubscription to .* # Send the message to the same topic with redis.publish so it appears there # Note: this should be used carefully (e.g. 
only for debugging) - # as it doubles the network calls for publishing tasks + # as it doubles the network calls when publishing tasks log.debug("Publishing message to topic: %s for logger", topic_key) await self._connection.publish(topic_key, message.body) else: @@ -523,7 +508,7 @@ async def publish( async def _on_message_task( self, stream_key: str, group_name: str, msg_id: str, payload: dict, callback: tp.Callable ): - """Wraps the user callback to simulate RabbitMQ behavior.""" + """Wraps the user callback.""" # Response is not decoded because we use bytes. But the keys will be bytes as well normalized_payload = { k.decode() if isinstance(k, bytes) else k: v for k, v in payload.items() @@ -556,12 +541,12 @@ async def _on_message_task( await asyncio.sleep(self.requeue_delay) await self._connection.xadd(name=stream_key, fields=payload) await self._connection.xack(stream_key, group_name, msg_id) - except Exception as e: + except Exception: log.exception("Callback failed for message %s", msg_id) - raise PublishError(f"Failed to publish message to topic {topic}: {e}") from e # Avoid poisoning messages # Note: In Redis, if you don't XACK, the message stays in the # Pending Entry List (PEL) for retry logic. + await self._connection.xack(stream_key, group_name, msg_id) async def unsubscribe(self, tag: str, if_unused: bool = True, if_empty: bool = True) -> None: task_to_cancel = None diff --git a/protobunny/backends/__init__.py b/protobunny/backends/__init__.py index 5e40154..a040ee9 100644 --- a/protobunny/backends/__init__.py +++ b/protobunny/backends/__init__.py @@ -5,8 +5,9 @@ import typing as tp from abc import ABC, abstractmethod +import protobunny as pb + from ..exceptions import RequeueMessage -from ..helpers import get_backend from ..models import ( BaseQueue, IncomingMessageProtocol, @@ -95,6 +96,10 @@ def get_consumer_count(self, topic: str) -> int | tp.Awaitable[int]: def setup_queue(self, topic: str, shared: bool) -> tp.Any | tp.Awaitable[tp.Any]: ... 
+ @abstractmethod + def build_topic_key(self, topic: str) -> str: + ... + class BaseAsyncConnection(BaseConnection, ABC): instance_by_vhost: dict[str, "BaseAsyncConnection"] @@ -134,6 +139,9 @@ def get_async_connection(self, **kwargs) -> "BaseAsyncConnection": return self._async_conn return self.async_class(**kwargs) + def build_topic_key(self, topic: str) -> str: + pass + def _run_loop(self) -> None: """Run the event loop in a dedicated thread.""" loop = None @@ -404,8 +412,7 @@ def disconnect(self, timeout: float = 10.0) -> None: class BaseSyncQueue(BaseQueue, ABC): def get_connection(self) -> BaseConnection: - backend = get_backend() - return backend.connection.connect() + return pb.connect() def publish(self, message: "ProtoBunnyMessage") -> None: """Publish a message to the queue. @@ -429,7 +436,7 @@ def publish_result( correlation_id: """ result_topic = topic or self.result_topic - log.info("Publishing result to: %s", result_topic) + log.debug("Publishing result to: %s", result_topic) self.send_message( result_topic, bytes(result), correlation_id=correlation_id, persistent=False ) @@ -453,10 +460,12 @@ def _receive( # In case the subscription has .# as binding key, # this method catches also results message for all the topics in that namespace. 
return - # msg: "ProtoBunnyMessage" = deserialize_message(message.routing_key, message.body) + try: callback(deserialize_message(message.routing_key, message.body)) except RequeueMessage: + # The callback raised a RequeueMessage exception + # a result message will be published by the specific Connection implementation raise except Exception as exc: # pylint: disable=W0703 log.exception("Could not process message: %s", str(message.body)) diff --git a/protobunny/backends/mosquitto/connection.py b/protobunny/backends/mosquitto/connection.py index c5556cd..58c08a2 100644 --- a/protobunny/backends/mosquitto/connection.py +++ b/protobunny/backends/mosquitto/connection.py @@ -16,25 +16,6 @@ from ...models import Envelope, IncomingMessageProtocol from .. import BaseConnection - -def connect() -> "Connection": - """Get the singleton async connection.""" - conn = Connection.get_connection(vhost=VHOST) - return conn - - -def reset_connection() -> "Connection": - """Reset the singleton connection.""" - connection = connect() - connection.disconnect() - return connect() - - -def disconnect() -> None: - connection = connect() - connection.disconnect() - - log = logging.getLogger(__name__) VHOST = os.environ.get("MOSQUITTO_VHOST", "/") diff --git a/protobunny/backends/nats/__init__.py b/protobunny/backends/nats/__init__.py new file mode 100644 index 0000000..d45bfe6 --- /dev/null +++ b/protobunny/backends/nats/__init__.py @@ -0,0 +1 @@ +from . import connection, queues # noqa diff --git a/protobunny/backends/nats/connection.py b/protobunny/backends/nats/connection.py new file mode 100644 index 0000000..0f89a8a --- /dev/null +++ b/protobunny/backends/nats/connection.py @@ -0,0 +1,32 @@ +"""Implements a NATS Connection with sync methods""" +import asyncio +import logging +import os +import threading + +from ...asyncio.backends.nats.connection import Connection as NATSConnection +from .. 
import BaseSyncConnection + +log = logging.getLogger(__name__) + +VHOST = os.environ.get("NATS_VHOST", "/") + + +class Connection(BaseSyncConnection): + """Synchronous wrapper around the async NATS Connection. + + Manages a dedicated event loop in a background thread to run async operations. + + Example: + .. code-block:: python + + with Connection() as conn: + conn.publish("my.topic", message) + tag = conn.subscribe("my.topic", callback) + + """ + + _lock = threading.RLock() + _stopped: asyncio.Event | None = None + instance_by_vhost: dict[str, "Connection"] = {} + async_class = NATSConnection diff --git a/protobunny/backends/nats/queues.py b/protobunny/backends/nats/queues.py new file mode 100644 index 0000000..bd0df90 --- /dev/null +++ b/protobunny/backends/nats/queues.py @@ -0,0 +1,35 @@ +import logging + +from protobunny.backends import ( + BaseSyncQueue, +) +from protobunny.models import Envelope + +log = logging.getLogger(__name__) + + +class SyncQueue(BaseSyncQueue): + """Message queue backed by NATS.""" + + def get_tag(self) -> str: + return self.subscription + + def send_message( + self, topic: str, body: bytes, correlation_id: str | None = None, persistent: bool = True + ): + """Low-level message sending implementation. + + Args: + topic: a topic name for direct routing or a routing key with special binding keys + body: serialized message (e.g.
a serialized protobuf message or a json string) + correlation_id: is present for result messages + persistent: kept for interface compatibility; not forwarded to the NATS envelope here + + Returns: + + """ + message = Envelope( + body=body, + correlation_id=correlation_id, + ) + self.get_connection().publish(topic, message) diff --git a/protobunny/backends/python/connection.py b/protobunny/backends/python/connection.py index 7317f0b..59878f4 100644 --- a/protobunny/backends/python/connection.py +++ b/protobunny/backends/python/connection.py @@ -118,6 +118,9 @@ def __init__(self, vhost: str = "/", requeue_delay: int = 3): self._subscriptions: dict[str, dict] = {} self.logger_prefix = default_configuration.logger_prefix + def build_topic_key(self, topic: str) -> str: + pass + class Connection(BaseLocalConnection): """Synchronous local connection using threads.""" @@ -280,19 +283,3 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.disconnect() return False - - -# Convenience functions -def connect() -> Connection: - return Connection.get_connection(vhost=VHOST) - - -def reset_connection() -> Connection: - connection = connect() - connection.disconnect() - return connect() - - -def disconnect() -> None: - connection = connect() - connection.disconnect() diff --git a/protobunny/backends/rabbitmq/connection.py b/protobunny/backends/rabbitmq/connection.py index 1896ba9..d4a6684 100644 --- a/protobunny/backends/rabbitmq/connection.py +++ b/protobunny/backends/rabbitmq/connection.py @@ -12,24 +12,6 @@ VHOST = os.environ.get("RABBITMQ_VHOST", "/") -def connect() -> "Connection": - """Get the singleton async connection.""" - conn = Connection.get_connection(vhost=VHOST) - return conn - - -def reset_connection() -> "Connection": - """Reset the singleton connection.""" - connection = connect() - connection.disconnect() - return connect() - - -def disconnect() -> None: - connection = connect() - connection.disconnect() - - class Connection(BaseSyncConnection): """Synchronous
wrapper around Async Rmq Connection. diff --git a/protobunny/backends/redis/connection.py b/protobunny/backends/redis/connection.py index e6e319b..a52e222 100644 --- a/protobunny/backends/redis/connection.py +++ b/protobunny/backends/redis/connection.py @@ -13,24 +13,6 @@ VHOST = os.environ.get("REDIS_VHOST") or os.environ.get("REDIS_DB", "0") -def connect() -> "Connection": - """Get the singleton async connection.""" - conn = Connection.get_connection(vhost=VHOST) - return conn - - -def reset_connection() -> "Connection": - """Reset the singleton connection.""" - connection = connect() - connection.disconnect() - return connect() - - -def disconnect() -> None: - connection = connect() - connection.disconnect() - - class Connection(BaseSyncConnection): """Synchronous wrapper around the async connection diff --git a/protobunny/config.py b/protobunny/config.py index 90457b9..7a27b67 100644 --- a/protobunny/config.py +++ b/protobunny/config.py @@ -24,7 +24,7 @@ ENV_PREFIX = "PROTOBUNNY_" INI_FILE = "protobunny.ini" -AvailableBackends = tp.Literal["rabbitmq", "python", "redis", "mosquitto"] +AvailableBackends = tp.Literal["rabbitmq", "python", "redis", "mosquitto", "nats"] log = logging.getLogger(__name__) @@ -45,12 +45,16 @@ mosquitto_backend_config = BackEndConfig( topic_delimiter="/", multi_wildcard_delimiter="#", namespace="protobunny" ) +nats_backend_config = BackEndConfig( + topic_delimiter=".", multi_wildcard_delimiter=">", namespace="protobunny" +) backend_configs = { "rabbitmq": rabbitmq_backend_config, "python": python_backend_config, "redis": redis_backend_config, "mosquitto": mosquitto_backend_config, + "nats": nats_backend_config, } @@ -67,7 +71,9 @@ class Config: backend: "AvailableBackends" = "rabbitmq" backend_config: BackEndConfig = rabbitmq_backend_config log_task_in_redis: bool = False - available_backends = ("rabbitmq", "python", "redis", "mosquitto") + use_tasks_in_nats: bool = True # needed to create the namespaced group stream on connect + + 
available_backends = ("rabbitmq", "python", "redis", "mosquitto", "nats") def __post_init__(self) -> None: if self.mode not in ("sync", "async"): diff --git a/protobunny/helpers.py b/protobunny/helpers.py index ea96d03..695c682 100644 --- a/protobunny/helpers.py +++ b/protobunny/helpers.py @@ -90,10 +90,10 @@ def get_queue( @functools.lru_cache(maxsize=100) def _build_routing_key(module: str, cls_name: str) -> str: # Build the routing key from the module and class name - backend = default_configuration.backend_config + config = default_configuration + backend = config.backend_config delimiter = backend.topic_delimiter routing_key = f"{module}.{cls_name}" - config = default_configuration if not routing_key.startswith(config.generated_package_name): raise ValueError( f"Invalid topic {routing_key}, must start with {config.generated_package_name}." @@ -113,7 +113,7 @@ def build_routing_key( """Returns a routing key based on a message instance, a message class, or a module. The string will be later composed with the configured message-prefix to build the exact topic name. - This is the main logic that builds keys strings for topics/streaming, adding wildcards when needed + This is the main logic that builds keys strings for topics/streaming, adding wildcards when needed. 
Examples: build_routing_key(mymessaginglib.vision.control) -> "vision.control.#" routing with binding key diff --git a/protobunny/models.py b/protobunny/models.py index 4b55ffe..27a3f0b 100644 --- a/protobunny/models.py +++ b/protobunny/models.py @@ -357,20 +357,22 @@ def get_message_class_from_topic(topic: str) -> "type[ProtoBunnyMessage] | None """Return the message class from a topic with lazy import of the user library Args: - topic: the RabbitMQ topic that represents the message queue + topic: the topic that represents the message queue, mapped to the message class + example for redis mylib:tasks:TaskMessage -> mylib.tasks.TaskMessage class - Returns: the message class + Returns: the message class for the topic or None if the topic is not recognized """ delimiter = default_configuration.backend_config.topic_delimiter if topic.endswith(f"{delimiter}result"): message_type = Result else: route = topic.removeprefix(f"{default_configuration.messages_prefix}{delimiter}") - if route == topic: - # Allow pb.* internal messages + if route == topic: # the prefix is not present in the topic + # Try if it's a protobunny class + # to allow pb.* internal messages like pb.results.Result route = topic.removeprefix(f"pb{delimiter}") codegen_module = importlib.import_module(default_configuration.generated_package_name) - # if route is not recognized, the message_type will be None + # if route is not recognized at this point, the message_type will be None message_type = _get_submodule(codegen_module, route.split(delimiter)) return message_type diff --git a/pyproject.toml b/pyproject.toml index fe58855..589c3ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,24 +3,23 @@ name = "protobunny" classifiers = [ "Development Status :: 3 - Alpha" ] -version = "0.1.2a1" -description = "A type-safe, sync/async Python messaging framework." +version = "0.1.2a2" +description = "A type-safe, sync/async Python messaging library." 
authors = [ {name = "Domenico Nappo", email = "domenico.nappo@am-flow.com"}, {name = "Sander Koelstra", email = "sander.koelstra@am-flow.com.com"}, {name = "Sem Mulder", email = "sem.mulder@am-flow.com"} ] keywords = [ + "messaging", "protobuf", "queues", "mqtt", - "messaging", + "amqp", "rabbitmq", "redis", - "amqp", - "aio-pika", - "betterproto", - "mosquitto" + "mosquitto", + "nats" ] requires-python = ">=3.10,<3.14" readme = "README.md" @@ -29,7 +28,6 @@ dependencies = [ "can-ada>=2.0.0", "click>=8.2.1", "grpcio-tools>=1.62.0,<2", - "pytest-env>=1.1.3", "tomli ; python_version < '3.11'", "typing-extensions; python_version < '3.11'" ] @@ -41,6 +39,7 @@ numpy = ["numpy>=1.26"] mosquitto = [ "aiomqtt>=2.4.0" ] +nats = ["nats-py>=2.12.0"] [project.scripts] protobunny = "protobunny.wrapper:main" @@ -60,6 +59,7 @@ dev = [ "waiting>=1.5.0,<2", "ipython>=8.11.0,<9", "pytest>=7.2.2,<8", + "pytest-env>=1.1.3", "pytest-cov>=4.0.0,<5", "pytest-mock>=3.14.0,<4", "pytest-asyncio>=0.23.7,<0.24", @@ -136,4 +136,4 @@ skip-magic-trailing-comma = false line-ending = "auto" [tool.setuptools] -package-data = {"protobunny" = ["protobunny/protobuf/*.proto", "protobunny/__init__.py.j2", "scripts/*.py"]} +package-data = {"protobunny" = ["protobunny/protobuf/*.proto", "protobunny/__init__.py.j2", "protobunny/asyncio/__init__.py.j2", "scripts/*.py"]} diff --git a/setup.py b/setup.py index 5362ffe..cec0c9b 100644 --- a/setup.py +++ b/setup.py @@ -97,7 +97,7 @@ def run(self) -> None: cmdclass={ "install": GenerateProtoCommand, }, - python_requires=">=3.10,<3.13", + python_requires=">=3.10,<3.14", description="Protobuf messages and python mqtt messaging toolkit", entry_points={ "console_scripts": [ diff --git a/tests/conftest.py b/tests/conftest.py index 2ea3c68..91cf566 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ import aio_pika import aiormq import fakeredis +import nats import pamqp import pytest import pytest_asyncio @@ -31,7 +32,8 @@ def test_config() -> 
protobunny.config.Config: force_required_fields=True, mode="async", backend="rabbitmq", - log_task_in_redis=False, + log_task_in_redis=True, + use_tasks_in_nats=True, ) return conf @@ -45,6 +47,22 @@ def __init__(self): self.subscribe = AsyncMock() self.unsubscribe = AsyncMock() self.publish = AsyncMock() + self._is_connected = False + + def is_connected(self) -> bool: + return self._is_connected + + def connect(self, host, port, keepalive): + self._is_connected = True + + def disconnect(self): + self._is_connected = False + + def loop_start(self): + return True + + def loop_stop(self): + return True async def __aenter__(self): return self @@ -71,6 +89,18 @@ async def mock_mosquitto(mocker) -> tp.AsyncGenerator[MockMQTTConnection, None]: yield mock_client +@pytest.fixture +async def mock_sync_mosquitto(mocker) -> tp.AsyncGenerator[MockMQTTConnection, None]: + mock_client = MockMQTTConnection() + mocker.spy(mock_client, "publish") + mocker.spy(mock_client, "subscribe") + mocker.spy(mock_client, "unsubscribe") + + with patch("paho.mqtt.client.Client") as mock_client_class: + mock_client_class.return_value = mock_client + yield mock_client + + @pytest.fixture async def mock_redis_client(mocker) -> tp.AsyncGenerator[fakeredis.FakeAsyncRedis, None]: server = fakeredis.FakeServer() @@ -85,6 +115,30 @@ async def mock_redis_client(mocker) -> tp.AsyncGenerator[fakeredis.FakeAsyncRedi await client.aclose() +@pytest.fixture +async def mock_nats(mocker): + # with patch("nats.NATS") as mock_nats_client: #, patch("nats.connect"): + # 1. Create the top-level Mock Client + mock_nc = AsyncMock(spec=nats.aio.client.Client) + + # 2. Mock the JetStream Context (.jetstream()) + mock_js = AsyncMock() + mock_nc.jetstream.return_value = mock_js + + # 3. Mock the Management API (.jsm()) + mock_jsm = AsyncMock() + mock_nc.jsm.return_value = mock_jsm + + # 4. 
PATCH: Intercept the nats.connect call + # Replace 'your_module.nats.connect' with the path where you import nats + mocker.patch("nats.connect", return_value=mock_nc) + yield { + "client": mock_nc, + "js": mock_js, + "jsm": mock_jsm, + } + + @pytest.fixture async def mock_aio_pika(): """Mocks the entire aio_pika connection chain.""" @@ -120,7 +174,7 @@ async def mock_aio_pika(): def mock_sync_rmq_connection(mocker: MockerFixture) -> tp.Generator[MagicMock, None, None]: mock = mocker.MagicMock(spec=rabbitmq_connection.Connection) mocker.patch("protobunny.backends.BaseSyncQueue.get_connection", return_value=mock) - mocker.patch("protobunny.backends.rabbitmq.connection.connect", return_value=mock) + mocker.patch("protobunny.connect", return_value=mock) yield mock @@ -135,7 +189,7 @@ async def mock_rmq_connection(mocker: MockerFixture) -> tp.AsyncGenerator[AsyncM def mock_sync_redis_connection(mocker: MockerFixture) -> tp.Generator[MagicMock, None, None]: mock = mocker.MagicMock(spec=redis_connection.Connection) mocker.patch("protobunny.backends.BaseSyncQueue.get_connection", return_value=mock) - mocker.patch("protobunny.backends.rabbitmq.connection.connect", return_value=mock) + mocker.patch("protobunny.connect", return_value=mock) yield mock @@ -150,7 +204,7 @@ async def mock_redis_connection(mocker: MockerFixture) -> tp.AsyncGenerator[Asyn def mock_sync_mqtt_connection(mocker: MockerFixture) -> tp.Generator[MagicMock, None, None]: mock = mocker.MagicMock(spec=mosquitto_connection.Connection) mocker.patch("protobunny.backends.BaseSyncQueue.get_connection", return_value=mock) - mocker.patch("protobunny.backends.mosquitto.connection.connect", return_value=mock) + mocker.patch("protobunny.connect", return_value=mock) yield mock diff --git a/tests/test_connection.py b/tests/test_connection.py index 0e832e4..d4b97ed 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -11,6 +11,7 @@ from protobunny import RequeueMessage from protobunny import asyncio as 
pb from protobunny.asyncio.backends import mosquitto as mosquitto_backend_aio +from protobunny.asyncio.backends import nats as nats_backend_aio from protobunny.asyncio.backends import python as python_backend_aio from protobunny.asyncio.backends import rabbitmq as rabbitmq_backend_aio from protobunny.asyncio.backends import redis as redis_backend_aio @@ -33,13 +34,27 @@ @pytest.mark.parametrize( - "backend", [rabbitmq_backend_aio, redis_backend_aio, python_backend_aio, mosquitto_backend_aio] + "backend", + [ + rabbitmq_backend_aio, + redis_backend_aio, + python_backend_aio, + mosquitto_backend_aio, + nats_backend_aio, + ], ) @pytest.mark.asyncio class TestConnection: @pytest.fixture(autouse=True) async def mock_connections( - self, backend, mocker, mock_redis_client, mock_aio_pika, mock_mosquitto, test_config + self, + backend, + mocker, + mock_redis_client, + mock_aio_pika, + mock_mosquitto, + mock_nats, + test_config, ) -> tp.AsyncGenerator[dict[str, AsyncMock | None], None]: backend_name = backend.__name__.split(".")[-1] @@ -60,24 +75,19 @@ async def mock_connections( assert pb.get_backend() == backend assert isinstance(get_queue(tests.tasks.TaskMessage), backend.queues.AsyncQueue) conn_with_fake_internal_conn = get_mocked_connection( - backend, mock_redis_client, mock_aio_pika, mocker, mock_mosquitto + backend, mock_redis_client, mock_aio_pika, mocker, mock_mosquitto, mock_nats ) - mocker.patch.object(pb, "connect", return_value=conn_with_fake_internal_conn) - mocker.patch.object(pb, "disconnect", side_effect=backend.connection.disconnect) mocker.patch( "protobunny.asyncio.backends.BaseAsyncQueue.get_connection", return_value=conn_with_fake_internal_conn, ) - mocker.patch( - f"protobunny.asyncio.backends.{backend_name}.connection.connect", - return_value=conn_with_fake_internal_conn, - ) yield { "rabbitmq": mock_aio_pika, "redis": mock_redis_client, "connection": conn_with_fake_internal_conn, "python": conn_with_fake_internal_conn._connection, "mosquitto": 
mock_mosquitto, + "nats": mock_nats, } connection_module.Connection.instance_by_vhost.clear() @@ -137,6 +147,9 @@ async def test_publish_tasks( async def test_publish(self, mock_connection: MagicMock, mock_internal_connection, backend): topic = "test.routing.key" + backend_name = backend.__name__.split(".")[-1] + delimiter = backend_configs[backend_name].topic_delimiter + topic = topic.replace(".", delimiter) conn = await mock_connection.connect() msg = None incoming = incoming_message_factory(backend) @@ -164,8 +177,8 @@ async def predicate(): @pytest.mark.asyncio async def test_singleton_logic(self, backend): - conn1 = await backend.connection.connect() - conn2 = await backend.connection.connect() + conn1 = await pb.connect() + conn2 = await pb.connect() assert conn1 is conn2 await conn1.disconnect() @@ -283,3 +296,7 @@ async def test_get_message_count(mock_redis_client): await conn.publish("test:tasks:topic", Envelope(body=b"test message 3")) count = await conn.get_message_count("test:tasks:topic") assert count == 3 + + +# --- Specific Tests for NATS --- +# TODO diff --git a/tests/test_integration.py b/tests/test_integration.py index d072092..784a0a4 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,4 +1,3 @@ -import gc import logging import typing as tp @@ -10,10 +9,12 @@ import protobunny as pb_sync from protobunny import asyncio as pb from protobunny.asyncio.backends import mosquitto as mosquitto_backend_aio +from protobunny.asyncio.backends import nats as nats_backend_aio from protobunny.asyncio.backends import python as python_backend_aio from protobunny.asyncio.backends import rabbitmq as rabbitmq_backend_aio from protobunny.asyncio.backends import redis as redis_backend_aio from protobunny.backends import mosquitto as mosquitto_backend +from protobunny.backends import nats as nats_backend from protobunny.backends import python as python_backend from protobunny.backends import rabbitmq as rabbitmq_backend from protobunny.backends 
import redis as redis_backend @@ -86,14 +87,24 @@ def log_callback(message: aio_pika.IncomingMessage, body: str) -> str: @pytest.mark.integration @pytest.mark.parametrize( - "backend", [rabbitmq_backend_aio, redis_backend_aio, python_backend_aio, mosquitto_backend_aio] + "backend", + [ + rabbitmq_backend_aio, + redis_backend_aio, + python_backend_aio, + mosquitto_backend_aio, + nats_backend_aio, + ], ) class TestIntegration: - """Integration tests (to run with RabbitMQ up)""" + """Integration tests (to run with all the backends up) + + For a specific backend, run python -m pytest tests/test_integration.py -k redis + """ msg = tests.TestMessage(content="test", number=123, color=tests.Color.GREEN) - @pytest.fixture(autouse=True) + # @pytest.fixture(autouse=True) async def setup_test_env( self, mocker: MockerFixture, test_config: Config, backend ) -> tp.AsyncGenerator[None, None]: @@ -103,12 +114,11 @@ async def setup_test_env( test_config.log_task_in_redis = True test_config.backend_config = backend_configs[backend_name] self.topic_delimiter = test_config.backend_config.topic_delimiter + # Patch global configuration for all modules that use it mocker.patch.object(pb_sync.config, "default_configuration", test_config) mocker.patch.object(pb_sync.models, "default_configuration", test_config) mocker.patch.object(pb_sync.helpers, "default_configuration", test_config) - - # patch the asyncio modules mocker.patch.object(pb.backends, "default_configuration", test_config) mocker.patch.object(pb, "default_configuration", test_config) if hasattr(backend.connection, "default_configuration"): @@ -117,11 +127,7 @@ async def setup_test_env( mocker.patch.object(backend.queues, "default_configuration", test_config) pb.backend = backend - mocker.patch("protobunny.asyncio.backends.get_backend", return_value=backend) - mocker.patch.object(pb, "connect", backend.connection.connect) - mocker.patch.object(pb, "disconnect", backend.connection.disconnect) mocker.patch.object(pb, 
"get_backend", return_value=backend) - # Assert the patching is working for setting the backend connection = await pb.connect() assert isinstance(connection, backend.connection.Connection) @@ -142,11 +148,10 @@ async def setup_test_env( "result": None, "task": None, } - await connection.disconnect() + await pb.disconnect() backend.connection.Connection.instance_by_vhost.clear() - gc.collect() - @pytest.mark.flaky(max_runs=3) + # @pytest.mark.flaky(max_runs=3) async def test_publish(self, backend) -> None: global received await pb.subscribe(self.msg.__class__, callback) @@ -159,7 +164,7 @@ async def predicate() -> bool: assert received["message"].number == self.msg.number assert received["message"].content == "test" - @pytest.mark.flaky(max_runs=3) + # @pytest.mark.flaky(max_runs=3) async def test_to_dict(self, backend) -> None: global received await pb.subscribe(self.msg.__class__, callback) @@ -452,7 +457,7 @@ async def predicate() -> bool: @pytest.mark.integration @pytest.mark.parametrize( - "backend", [rabbitmq_backend, redis_backend, python_backend, mosquitto_backend] + "backend", [rabbitmq_backend, redis_backend, python_backend, mosquitto_backend, nats_backend] ) class TestIntegrationSync: """Integration tests (to run with the backend server up)""" @@ -483,8 +488,8 @@ def setup_test_env( pb_sync.backend = backend # mocker.patch("protobunny.helpers.get_backend", return_value=backend) mocker.patch.object(pb_sync.helpers, "get_backend", return_value=backend) - mocker.patch.object(pb_sync, "connect", backend.connection.connect) - mocker.patch.object(pb_sync, "disconnect", backend.connection.disconnect) + # mocker.patch.object(pb_sync, "connect", backend.connection.connect) + # mocker.patch.object(pb_sync, "disconnect", backend.connection.disconnect) mocker.patch.object(pb_sync, "get_backend", return_value=backend) # Assert the patching is working for setting the backend diff --git a/tests/test_queues.py b/tests/test_queues.py index 644a1c6..df62412 100644 --- 
a/tests/test_queues.py +++ b/tests/test_queues.py @@ -10,11 +10,13 @@ from protobunny import asyncio as pb from protobunny.asyncio.backends import LoggingAsyncQueue from protobunny.asyncio.backends import mosquitto as mosquitto_backend_aio +from protobunny.asyncio.backends import nats as nats_backend_aio from protobunny.asyncio.backends import python as python_backend_aio from protobunny.asyncio.backends import rabbitmq as rabbitmq_backend_aio from protobunny.asyncio.backends import redis as redis_backend_aio from protobunny.backends import LoggingSyncQueue from protobunny.backends import mosquitto as mosquitto_backend +from protobunny.backends import nats as nats_backend from protobunny.backends import python as python_backend from protobunny.backends import rabbitmq as rabbitmq_backend from protobunny.backends import redis as redis_backend @@ -30,12 +32,26 @@ @pytest.mark.parametrize( - "backend", [rabbitmq_backend_aio, redis_backend_aio, python_backend_aio, mosquitto_backend_aio] + "backend", + [ + rabbitmq_backend_aio, + redis_backend_aio, + python_backend_aio, + mosquitto_backend_aio, + nats_backend_aio, + ], ) class TestQueue: @pytest.fixture(autouse=True) async def mock_connections( - self, backend, mocker, mock_redis_client, mock_aio_pika, test_config, mock_mosquitto + self, + backend, + mocker, + mock_redis_client, + mock_aio_pika, + test_config, + mock_mosquitto, + mock_nats, ) -> tp.AsyncGenerator[dict, None]: backend_name = backend.__name__.split(".")[-1] test_config.mode = "async" @@ -56,7 +72,6 @@ async def mock_connections( pb.backend = backend mocker.patch("protobunny.helpers.get_backend", return_value=backend) - mocker.patch.object(pb_base, "disconnect", backend.connection.disconnect) mocker.patch.object(pb_base, "get_backend", return_value=backend) assert pb_base.helpers.get_backend() == backend @@ -73,18 +88,17 @@ async def mock_connections( mock._connection = mock_aio_pika case "mosquitto": mock._connection = mock_mosquitto + case "nats": + 
mock._connection = mock_nats case "python": mock = backend.connection.Connection() mock.is_connected_event.set() mocker.patch.object(pb, "connect", return_value=mock) mocker.patch("protobunny.asyncio.backends.BaseAsyncQueue.get_connection", return_value=mock) - mocker.patch( - f"protobunny.asyncio.backends.{backend_name}.connection.connect", - return_value=mock, - ) - await pb.disconnect() + assert asyncio.iscoroutinefunction(pb.disconnect) + await pb.disconnect() yield { "rabbitmq": mock_aio_pika, @@ -92,12 +106,11 @@ async def mock_connections( "connection": mock, "python": mock._connection, "mosquitto": mock_mosquitto, + "nats": mock_nats, } connection_module = getattr(pb_base.backends, backend_name).connection connection_module.Connection.instance_by_vhost.clear() - connection_module.Connection.instance_by_vhost.clear() - @pytest.fixture async def mock_connection(self, mock_connections, backend, mocker): conn = mock_connections["connection"] @@ -230,12 +243,19 @@ async def test_logger(self, mock_connection: MagicMock, backend) -> None: @pytest.mark.parametrize( - "backend", [rabbitmq_backend, redis_backend, python_backend, mosquitto_backend] + "backend", [rabbitmq_backend, redis_backend, python_backend, mosquitto_backend, nats_backend] ) class TestSyncQueue: @pytest.fixture(autouse=True, scope="function") def mock_connection( - self, backend, mocker, mock_redis_client, mock_aio_pika, test_config + self, + backend, + mocker, + mock_redis_client, + mock_aio_pika, + test_config, + mock_nats, + mock_sync_mosquitto, ) -> tp.Generator[None, None, None]: backend_name = backend.__name__.split(".")[-1] test_config.mode = "sync" @@ -256,7 +276,6 @@ def mock_connection( pb_base.backend = backend mocker.patch("protobunny.helpers.get_backend", return_value=backend) - mocker.patch.object(pb_base, "disconnect", backend.connection.disconnect) mocker.patch.object(pb_base, "get_backend", return_value=backend) assert pb_base.helpers.get_backend() == backend @@ -268,7 +287,6 @@ 
def mock_connection( mock = mocker.MagicMock(spec=backend.connection.Connection) # should be a Sync connection mocker.patch.object(pb_base, "connect", return_value=mock) mocker.patch("protobunny.backends.BaseSyncQueue.get_connection", return_value=mock) - mocker.patch(f"protobunny.backends.{backend_name}.connection.connect", return_value=mock) assert not asyncio.iscoroutinefunction(pb_base.disconnect) pb_base.disconnect() yield mock diff --git a/tests/test_results.py b/tests/test_results.py index e65483b..551ba14 100644 --- a/tests/test_results.py +++ b/tests/test_results.py @@ -29,7 +29,6 @@ def setup_connections(mocker: MockerFixture, mock_sync_rmq_connection, test_conf mocker.patch.object(pb.backends, "default_configuration", test_config) mocker.patch.object(pb.helpers, "default_configuration", test_config) pb.backend = rabbitmq_backend - mocker.patch.object(pb, "connect", rabbitmq_backend.connection.connect) mocker.patch.object(pb.helpers.default_configuration, "backend", "rabbitmq") queue = pb.get_queue(tests.TestMessage) assert isinstance(queue, rabbitmq_backend.queues.SyncQueue) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index c1138ee..5701c7e 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -56,8 +56,6 @@ async def setup_test_env(self, mocker, test_config, backend) -> tp.AsyncGenerato pb.backend = backend mocker.patch("protobunny.helpers.get_backend", return_value=backend) - mocker.patch.object(pb, "connect", backend.connection.connect) - mocker.patch.object(pb, "disconnect", backend.connection.disconnect) mocker.patch.object(pb, "get_backend", return_value=backend) # Assert the patching is working for setting the backend @@ -151,8 +149,6 @@ def setup_test_env(self, mocker, test_config, backend) -> tp.Generator[None, Non pb_sync.backend = backend mocker.patch("protobunny.backends.get_backend", return_value=backend) mocker.patch("protobunny.helpers.get_backend", return_value=backend) - mocker.patch.object(pb_sync, "connect", 
backend.connection.connect) - mocker.patch.object(pb_sync, "disconnect", backend.connection.disconnect) # Assert the patching is working for setting the backend connection = pb_sync.connect() @@ -184,12 +180,10 @@ def predicate_2() -> bool: return self.received.get("task_2") is not None def callback_task_1(msg: "ProtoBunnyMessage") -> None: - log.debug("SYNC CALLBACK TASK 1 %s", msg) time.sleep(0.1) self.received["task_1"] = msg def callback_task_2(msg: "ProtoBunnyMessage") -> None: - log.debug("SYNC CALLBACK TASK 2 %s", msg) time.sleep(0.1) self.received["task_2"] = msg diff --git a/tests/utils.py b/tests/utils.py index 0cf9fae..3d130bd 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -111,103 +111,235 @@ async def predicate(): assert ( await internal_mock.get_message_count(topic) == count_in_queue ), f"count was {await internal_mock.get_message_count(topic)}" - case "mosquitto": - internal_mock.publish.assert_awaited_once() + if not shared_queue: + internal_mock.publish.assert_awaited_once_with( + "protobunny/test/routing/key", payload=b"Hello", qos=1, retain=False + ) + else: + internal_mock.publish.assert_awaited_once_with( + "protobunny/test/tasks/key", payload=b"Hello", qos=1, retain=False + ) + case "nats": + if not shared_queue: + internal_mock["client"].publish.assert_awaited_once_with( + subject="protobunny.test.routing.key", payload=b"Hello", headers=None + ) + else: + internal_mock["js"].publish.assert_awaited_once_with( + subject="protobunny.test.tasks.key", payload=b"Hello", headers=None + ) async def assert_backend_setup_queue( backend, internal_mock, topic: str, shared: bool, mock_connection ) -> None: backend_name = backend.__name__.split(".")[-1] - if backend_name == "rabbitmq": - internal_mock["channel"].declare_queue.assert_called_with( - topic, exclusive=not shared, durable=True, auto_delete=False, arguments=ANY - ) - elif backend_name == "redis": - if shared: - streams = await internal_mock.xinfo_groups(f"protobunny:{topic}") - assert 
len(streams) == 1 - assert ( - streams[0]["name"].decode() == "shared_group" - ), f"Expected 'shared_group', got '{streams[0]['name']}'" - else: - assert mock_connection.queues[topic] == {"is_shared": False} - - elif backend_name == "python": - if not shared: - assert len(internal_mock._exclusive_queues.get(topic)) == 1 - else: - assert internal_mock._shared_queues.get(topic) - elif backend_name == "mosquitto": - if not shared: - assert mock_connection.queues[topic] == { - "is_shared": False, - "group_name": "", - "sub_key": f"$share/shared_group/test/{topic}", - "tag": ANY, - "topic": topic, - "topic_key": f"test/{topic}", - } - else: - assert list(mock_connection.queues.values())[0] == { - "is_shared": True, - "group_name": "shared_group", - "sub_key": f"$share/shared_group/protobunny/{topic}", - "tag": ANY, - "topic": topic, - "topic_key": f"protobunny/{topic}", - }, mock_connection.queues + match backend_name: + case "rabbitmq": + internal_mock["channel"].declare_queue.assert_called_with( + topic, exclusive=not shared, durable=True, auto_delete=False, arguments=ANY + ) + case "redis": + if shared: + streams = await internal_mock.xinfo_groups(f"protobunny:{topic}") + assert len(streams) == 1 + assert ( + streams[0]["name"].decode() == "shared_group" + ), f"Expected 'shared_group', got '{streams[0]['name']}'" + else: + assert mock_connection.queues[topic] == {"is_shared": False} + case "python": + if not shared: + assert len(internal_mock._exclusive_queues.get(topic)) == 1 + else: + assert internal_mock._shared_queues.get(topic) + case "mosquitto": + if not shared: + assert mock_connection.queues[topic] == { + "is_shared": False, + "group_name": "", + "sub_key": f"$share/shared_group/test/{topic}", + "tag": ANY, + "topic": topic, + "topic_key": f"test/{topic}", + } + else: + assert list(mock_connection.queues.values())[0] == { + "is_shared": True, + "group_name": "shared_group", + "sub_key": f"$share/shared_group/protobunny/{topic}", + "tag": ANY, + "topic": 
topic, + "topic_key": f"protobunny/{topic}", + }, mock_connection.queues + case "nats": + internal_mock["js"].subscribe.assert_awaited_once_with( + subject="protobunny.mylib.tasks.TaskMessage", + durable="protobunny_protobunny_mylib_tasks_TaskMessage", + cb=ANY, + manual_ack=True, + stream="PROTOBUNNY_TASKS", + ) + + # if backend_name == "rabbitmq": + # internal_mock["channel"].declare_queue.assert_called_with( + # topic, exclusive=not shared, durable=True, auto_delete=False, arguments=ANY + # ) + # elif backend_name == "redis": + # if shared: + # streams = await internal_mock.xinfo_groups(f"protobunny:{topic}") + # assert len(streams) == 1 + # assert ( + # streams[0]["name"].decode() == "shared_group" + # ), f"Expected 'shared_group', got '{streams[0]['name']}'" + # else: + # assert mock_connection.queues[topic] == {"is_shared": False} + # elif backend_name == "nats": + # assert False # TODO + # + # elif backend_name == "python": + # if not shared: + # assert len(internal_mock._exclusive_queues.get(topic)) == 1 + # else: + # assert internal_mock._shared_queues.get(topic) + # elif backend_name == "mosquitto": + # if not shared: + # assert mock_connection.queues[topic] == { + # "is_shared": False, + # "group_name": "", + # "sub_key": f"$share/shared_group/test/{topic}", + # "tag": ANY, + # "topic": topic, + # "topic_key": f"test/{topic}", + # } + # else: + # assert list(mock_connection.queues.values())[0] == { + # "is_shared": True, + # "group_name": "shared_group", + # "sub_key": f"$share/shared_group/protobunny/{topic}", + # "tag": ANY, + # "topic": topic, + # "topic_key": f"protobunny/{topic}", + # }, mock_connection.queues async def assert_backend_connection(backend, internal_mock): backend_name = backend.__name__.split(".")[-1] - if backend_name == "rabbitmq": - # Verify aio_pika calls - internal_mock["connect"].assert_awaited_once() - assert internal_mock["channel"].set_qos.called - # Check if main and DLX exchanges were declared - assert 
internal_mock["channel"].declare_exchange.call_count == 2 - elif backend_name == "redis": - assert await internal_mock.ping() - elif backend_name == "mosquitto": - internal_mock.__aenter__.assert_awaited_once() - assert True - + match backend_name: + case "rabbitmq": + # Verify aio_pika calls + internal_mock["connect"].assert_awaited_once() + assert internal_mock["channel"].set_qos.called + # Check if main and DLX exchanges were declared + assert internal_mock["channel"].declare_exchange.call_count == 2 + case "redis": + assert await internal_mock.ping() + case "python": + assert True # python is always "connected" + case "mosquitto": + internal_mock.__aenter__.assert_awaited_once() + case "nats": + import nats -def get_mocked_connection(backend, redis_client, mock_aio_pika, mocker, mock_mosquitto): + nats.connect.assert_awaited_once_with( + "nats://localhost:4222/", connect_timeout=30.0, max_reconnect_attempts=3 + ) + # internal_mock["client"].connect.assert_awaited_once() + + # if backend_name == "rabbitmq": + # # Verify aio_pika calls + # internal_mock["connect"].assert_awaited_once() + # assert internal_mock["channel"].set_qos.called + # # Check if main and DLX exchanges were declared + # assert internal_mock["channel"].declare_exchange.call_count == 2 + # elif backend_name == "redis": + # assert await internal_mock.ping() + # elif backend_name == "mosquitto": + # internal_mock.__aenter__.assert_awaited_once() + # elif backend_name == "nats": + # internal_mock.connect.assert_awaited_once() + # assert True + + +def get_mocked_connection(backend, redis_client, mock_aio_pika, mocker, mock_mosquitto, mock_nats): backend_name = backend.__name__.split(".")[-1] - if backend_name == "redis": - real_conn_with_fake_redis = backend.connection.Connection(url="redis://localhost:6379/0") - assert ( - real_conn_with_fake_redis._exchange == "protobunny" - ), real_conn_with_fake_redis._exchange - real_conn_with_fake_redis._connection = redis_client - - def check_connected() -> 
bool: - return real_conn_with_fake_redis._connection is not None - - # Patch is_connected logic - mocker.patch.object(real_conn_with_fake_redis, "is_connected", side_effect=check_connected) - - return real_conn_with_fake_redis - elif backend_name == "rabbitmq": - real_conn_with_fake_aio_pika = backend.connection.Connection( - url="amqp://guest:guest@localhost:5672/" - ) - real_conn_with_fake_aio_pika._connection = mock_aio_pika["connection"] - real_conn_with_fake_aio_pika.is_connected_event.set() - real_conn_with_fake_aio_pika._channel = mock_aio_pika["channel"] - real_conn_with_fake_aio_pika._exchange = mock_aio_pika["exchange"] - real_conn_with_fake_aio_pika._queue = mock_aio_pika["queue"] - - return real_conn_with_fake_aio_pika - elif backend_name == "python": - python_conn = backend.connection.Connection() - python_conn.is_connected_event.set() - return python_conn - elif backend_name == "mosquitto": - real_conn_with_fake_aiomqtt = backend.connection.Connection() - real_conn_with_fake_aiomqtt._connection = mock_mosquitto - # real_conn_with_fake_aiomqtt.is_connected_event.set() - return real_conn_with_fake_aiomqtt + + match backend_name: + case "redis": + real_conn_with_fake_redis = backend.connection.Connection( + url="redis://localhost:6379/0" + ) + assert ( + real_conn_with_fake_redis._exchange == "protobunny" + ), real_conn_with_fake_redis._exchange + real_conn_with_fake_redis._connection = redis_client + + def check_connected() -> bool: + return real_conn_with_fake_redis._connection is not None + + # Patch is_connected logic + mocker.patch.object( + real_conn_with_fake_redis, "is_connected", side_effect=check_connected + ) + return real_conn_with_fake_redis + case "nats": + real_conn_with_fake_nats = backend.connection.Connection() + real_conn_with_fake_nats._connection = mock_nats["client"] + return real_conn_with_fake_nats + case "rabbitmq": + real_conn_with_fake_aio_pika = backend.connection.Connection( + url="amqp://guest:guest@localhost:5672/" + ) + 
real_conn_with_fake_aio_pika._connection = mock_aio_pika["connection"] + real_conn_with_fake_aio_pika.is_connected_event.set() + real_conn_with_fake_aio_pika._channel = mock_aio_pika["channel"] + real_conn_with_fake_aio_pika._exchange = mock_aio_pika["exchange"] + real_conn_with_fake_aio_pika._queue = mock_aio_pika["queue"] + return real_conn_with_fake_aio_pika + case "python": + python_conn = backend.connection.Connection() + python_conn.is_connected_event.set() + return python_conn + case "mosquitto": + real_conn_with_fake_aiomqtt = backend.connection.Connection() + real_conn_with_fake_aiomqtt._connection = mock_mosquitto + return real_conn_with_fake_aiomqtt + + # if backend_name == "redis": + # real_conn_with_fake_redis = backend.connection.Connection(url="redis://localhost:6379/0") + # assert ( + # real_conn_with_fake_redis._exchange == "protobunny" + # ), real_conn_with_fake_redis._exchange + # real_conn_with_fake_redis._connection = redis_client + # + # def check_connected() -> bool: + # return real_conn_with_fake_redis._connection is not None + # + # # Patch is_connected logic + # mocker.patch.object(real_conn_with_fake_redis, "is_connected", side_effect=check_connected) + # + # return real_conn_with_fake_redis + # elif backend_name == "rabbitmq": + # real_conn_with_fake_aio_pika = backend.connection.Connection( + # url="amqp://guest:guest@localhost:5672/" + # ) + # real_conn_with_fake_aio_pika._connection = mock_aio_pika["connection"] + # real_conn_with_fake_aio_pika.is_connected_event.set() + # real_conn_with_fake_aio_pika._channel = mock_aio_pika["channel"] + # real_conn_with_fake_aio_pika._exchange = mock_aio_pika["exchange"] + # real_conn_with_fake_aio_pika._queue = mock_aio_pika["queue"] + # + # return real_conn_with_fake_aio_pika + # elif backend_name == "python": + # python_conn = backend.connection.Connection() + # python_conn.is_connected_event.set() + # return python_conn + # elif backend_name == "mosquitto": + # real_conn_with_fake_aiomqtt = 
backend.connection.Connection() + # real_conn_with_fake_aiomqtt._connection = mock_mosquitto + # return real_conn_with_fake_aiomqtt + # elif backend_name == "nats": + # real_conn_with_fake_nats = backend.connection.Connection() + # real_conn_with_fake_nats._connection = mock_nats["client"] + # return real_conn_with_fake_nats diff --git a/uv.lock b/uv.lock index dd4283d..cc5b48d 100644 --- a/uv.lock +++ b/uv.lock @@ -1029,6 +1029,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5f/df/76d0321c3797b54b60fef9ec3bd6f4cfd124b9e422182156a1dd418722cf/myst_parser-4.0.1-py3-none-any.whl", hash = "sha256:9134e88959ec3b5780aedf8a99680ea242869d012e8821db3126d427edc9c95d", size = 84579, upload-time = "2025-02-12T10:53:02.078Z" }, ] +[[package]] +name = "nats-py" +version = "2.12.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/71/c5/2564d917503fe8d68fe630c74bf6b678fbc15c01b58f2565894761010f57/nats_py-2.12.0.tar.gz", hash = "sha256:2981ca4b63b8266c855573fa7871b1be741f1889fd429ee657e5ffc0971a38a1", size = 119821, upload-time = "2025-10-31T05:27:31.247Z" } + [[package]] name = "numpy" version = "1.26.4" @@ -1268,14 +1274,13 @@ wheels = [ [[package]] name = "protobunny" -version = "0.1.2a1" +version = "0.1.2a2" source = { editable = "." 
} dependencies = [ { name = "betterproto", extra = ["compiler"] }, { name = "can-ada" }, { name = "click" }, { name = "grpcio-tools" }, - { name = "pytest-env" }, { name = "tomli", marker = "python_full_version < '3.11'" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] @@ -1284,6 +1289,9 @@ dependencies = [ mosquitto = [ { name = "aiomqtt" }, ] +nats = [ + { name = "nats-py" }, +] numpy = [ { name = "numpy" }, ] @@ -1304,6 +1312,7 @@ dev = [ { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, + { name = "pytest-env" }, { name = "pytest-mock" }, { name = "ruff" }, { name = "toml-sort" }, @@ -1329,13 +1338,13 @@ requires-dist = [ { name = "can-ada", specifier = ">=2.0.0" }, { name = "click", specifier = ">=8.2.1" }, { name = "grpcio-tools", specifier = ">=1.62.0,<2" }, + { name = "nats-py", marker = "extra == 'nats'", specifier = ">=2.12.0" }, { name = "numpy", marker = "extra == 'numpy'", specifier = ">=1.26" }, - { name = "pytest-env", specifier = ">=1.1.3" }, { name = "redis", marker = "extra == 'redis'", specifier = "<8" }, { name = "tomli", marker = "python_full_version < '3.11'" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] -provides-extras = ["rabbitmq", "redis", "numpy", "mosquitto"] +provides-extras = ["rabbitmq", "redis", "numpy", "mosquitto", "nats"] [package.metadata.requires-dev] dev = [ @@ -1347,6 +1356,7 @@ dev = [ { name = "pytest", specifier = ">=7.2.2,<8" }, { name = "pytest-asyncio", specifier = ">=0.23.7,<0.24" }, { name = "pytest-cov", specifier = ">=4.0.0,<5" }, + { name = "pytest-env", specifier = ">=1.1.3" }, { name = "pytest-mock", specifier = ">=3.14.0,<4" }, { name = "ruff", specifier = ">=0.1.14,<0.2" }, { name = "toml-sort", specifier = ">=0.24.3" }, From 82368aa16b1c2ff90d5666605e7be5f8d5c489de Mon Sep 17 00:00:00 2001 From: Domenico Nappo Date: Wed, 31 Dec 2025 22:18:31 +0100 Subject: [PATCH 2/7] Cleaning code and tests, accepting kwargs for connect(), 
fixing tasks logic in nats --- .github/workflows/ci.yml | 28 ++ QUICK_START.md | 6 +- README.md | 40 ++- docs/source/quick_start.md | 6 +- protobunny/__init__.py | 52 ++-- protobunny/__init__.py.j2 | 58 ++-- protobunny/asyncio/__init__.py | 52 ++-- protobunny/asyncio/__init__.py.j2 | 56 ++-- protobunny/asyncio/backends/__init__.py | 19 +- .../asyncio/backends/mosquitto/connection.py | 10 +- .../asyncio/backends/nats/connection.py | 76 +++--- .../asyncio/backends/python/connection.py | 4 +- .../asyncio/backends/redis/connection.py | 30 +-- protobunny/backends/__init__.py | 17 +- protobunny/backends/mosquitto/connection.py | 11 +- protobunny/backends/python/connection.py | 6 +- protobunny/{config.py => conf.py} | 3 +- protobunny/helpers.py | 21 +- protobunny/logger.py | 2 +- protobunny/models.py | 24 +- protobunny/registry.py | 2 +- protobunny/wrapper.py | 2 +- pyproject.toml | 5 +- setup.py | 2 +- tests/conftest.py | 14 +- tests/test_base.py | 10 +- tests/test_config.py | 6 +- tests/test_connection.py | 59 ++-- tests/test_integration.py | 251 +++++++++--------- tests/test_publish.py | 8 +- tests/test_queues.py | 38 +-- tests/test_results.py | 12 +- tests/test_tasks.py | 153 ++++++----- tests/utils.py | 140 ++-------- 34 files changed, 574 insertions(+), 649 deletions(-) rename protobunny/{config.py => conf.py} (99%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d21b5d0..ee05c42 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -236,3 +236,31 @@ jobs: run: uv run python -m pytest tests/test_integration.py -k mosquitto -vvv -s env: MQTT_HOST: 127.0.0.1 + + integration_test_nats: + runs-on: ubuntu-latest + services: + redis: + image: nats:latest + ports: + - 4222:4222 + options: >- + -js + steps: + - uses: actions/checkout@v6 + - name: Install uv + uses: astral-sh/setup-uv@v7 + with: + enable-cache: true + cache-dependency-glob: "uv.lock" + - name: Set up Python 3.12 + run: uv python install 3.12 + - name: Create virtual 
environment
+        run: uv venv --python 3.12
+      - name: Install dependencies
+        run: uv sync --all-extras
+      - name: Run tests
+        run: uv run python -m pytest tests/test_integration.py -k nats -vvv -s
+        env:
+          NATS_HOST: localhost
+          NATS_PORT: 4222
diff --git a/QUICK_START.md b/QUICK_START.md
index 00a4690..07bb4ca 100644
--- a/QUICK_START.md
+++ b/QUICK_START.md
@@ -379,7 +379,6 @@ import asyncio
 import logging
 import sys
 
-
 import protobunny as pb
 from protobunny import asyncio as pb_asyncio
 
@@ -391,12 +390,11 @@ pb.config_lib()
 
 import mymessagelib as ml
 
-
 logging.basicConfig(
     level=logging.INFO, format="[%(asctime)s %(levelname)s] %(name)s - %(message)s"
 )
 log = logging.getLogger(__name__)
-conf = pb.default_configuration
+conf = pb.config
 
 
 class TestLibAsync:
@@ -418,7 +416,6 @@ class TestLibAsync:
     async def on_message_mymessage(self, message: ml.main.MyMessage) -> None:
         log.info("Got main message: %s", message)
 
-
     def run_forever(self):
         asyncio.run(self.main())
 
@@ -426,7 +423,6 @@ class TestLibAsync:
         log.info(f"LOG {incoming.routing_key}: {body}")
 
     async def main(self):
-
         await pb_asyncio.subscribe_logger(self.log_callback)
         await pb_asyncio.subscribe(ml.main.tasks.TaskMessage, self.worker1)
         await pb_asyncio.subscribe(ml.main.tasks.TaskMessage, self.worker2)
diff --git a/README.md b/README.md
index 59f8236..c692323 100644
--- a/README.md
+++ b/README.md
@@ -1,29 +1,47 @@
 # Protobunny
 
 ```{warning}
-Note: The project is in early development.
+The project is in early development.
 ```
 
 Protobunny is the open-source evolution of [AM-Flow](https://am-flow.com)'s internal messaging library.
 While the original was purpose-built for RabbitMQ, this version has been completely re-engineered to provide a unified,
-type-safe interface for several message brokers, including Redis and MQTT.
+type-safe interface for several message brokers, including Redis, NATS, and MQTT.
-It simplifies messaging for asynchronous tasks by providing:
+It simplifies messaging for asynchronous message handling by providing:
 
-* A clean “message-first” API
-* Python class generation from Protobuf messages using betterproto
-* Connections facilities for backends
+* A clean “message-first” API by using your protobuf definitions
 * Message publishing/subscribing with typed topics
-* Support also “task-like” queues (shared/competing consumers) vs. broadcast subscriptions
+* Supports “task-like” queues (shared/competing consumers) vs. broadcast subscriptions
 * Generate and consume `Result` messages (success/failure + optional return payload)
 * Transparent messages serialization/deserialization
- * Support async and sync contexts
-* Transparently serialize "JSON-like" payload fields (numpy-friendly)
+* Transparently serialize/deserialize custom "JSON-like" payload fields (numpy-friendly)
+* Support async and sync contexts
+
+Supported backends in the current version are:
+
+- RabbitMQ
+- Redis
+- NATS
+- Mosquitto
+- Python "backend" with Queue/asyncio.Queue for local in-process testing
+
+```{note}
+Protobunny handles backend-specific logic internally to provide a consistent experience and a lean interface.
+Direct access to the internal NATS or Redis clients is intentionally restricted.
+If your project depends on specialized backend parameters not covered by our API, you may find the abstraction too restrictive.
+```
 
 ## Minimal requirements
 
-- Python >= 3.10, < 3.14
+- Python >= 3.10, <= 3.13
+- Core Dependencies: betterproto 2.0.0b7, grpcio-tools>=1.62.0
+- Backend Drivers (Optional based on your usage):
+  - NATS: nats-py (Requires NATS Server v2.10+ for full JetStream support).
+  - Redis: redis (Requires Redis Server v6.2+ for Stream support).
+ - RabbitMQ: aio-pika + - Mosquitto: aiomqtt ## Project scope @@ -74,7 +92,7 @@ Documentation home page: [https://am-flow.github.io/protobunny/](https://am-flow - [x] **Semantic Patterns**: Automatic `tasks` package routing. - [x] **Arbistrary dictionary parsing**: Transparently parse JSON-like fields as dictionaries/lists by using protobunny JsonContent type. - [x] **Result workflow**: Subscribe to results topics and receive protobunny `Result` messages produced by your callbacks. -- [ ] **Cloud-Native**: NATS (Core & JetStream) integration. +- [x] **Cloud-Native**: NATS (Core & JetStream) integration. - [ ] **Cloud Providers**: AWS (SQS/SNS) and GCP Pub/Sub. - [ ] **More backends**: Kafka support. diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md index 87e348a..5baaff6 100644 --- a/docs/source/quick_start.md +++ b/docs/source/quick_start.md @@ -379,7 +379,6 @@ import asyncio import logging import sys - import protobunny as pb from protobunny import asyncio as pb_asyncio @@ -391,12 +390,11 @@ pb.config_lib() import mymessagelib as ml - logging.basicConfig( level=logging.INFO, format="[%(asctime)s %(levelname)s] %(name)s - %(message)s" ) log = logging.getLogger(__name__) -conf = pb.default_configuration +conf = pb.config class TestLibAsync: @@ -418,7 +416,6 @@ class TestLibAsync: async def on_message_mymessage(self, message: ml.main.MyMessage) -> None: log.info("Got main message: %s", message) - def run_forever(self): asyncio.run(self.main()) @@ -426,7 +423,6 @@ class TestLibAsync: log.info(f"LOG {incoming.routing_key}: {body}") async def main(self): - await pb_asyncio.subscribe_logger(self.log_callback) await pb_asyncio.subscribe(ml.main.tasks.TaskMessage, self.worker1) await pb_asyncio.subscribe(ml.main.tasks.TaskMessage, self.worker2) diff --git a/protobunny/__init__.py b/protobunny/__init__.py index 453f700..f040de1 100644 --- a/protobunny/__init__.py +++ b/protobunny/__init__.py @@ -20,7 +20,7 @@ "GENERATED_PACKAGE_NAME", "PACKAGE_NAME", 
"ROOT_GENERATED_PACKAGE_NAME", - "default_configuration", + "config", "RequeueMessage", "ConnectionError", "reset_connection", @@ -41,15 +41,15 @@ from types import FrameType, ModuleType from .backends import BaseSyncConnection, BaseSyncQueue, LoggingSyncQueue -from .config import ( # noqa +from .conf import ( # noqa GENERATED_PACKAGE_NAME, PACKAGE_NAME, ROOT_GENERATED_PACKAGE_NAME, - default_configuration, + config, ) from .exceptions import ConnectionError, RequeueMessage from .helpers import get_backend, get_queue -from .registry import default_registry +from .registry import registry if tp.TYPE_CHECKING: from .core.results import Result @@ -133,19 +133,19 @@ def subscribe( """ register_key = str(pkg_or_msg) - with default_registry.sync_lock: + with registry.sync_lock: queue = get_queue(pkg_or_msg) if queue.shared_queue: # It's a task. Handle multiple subscriptions # queue = get_queue(pkg_or_msg) queue.subscribe(callback) - default_registry.register_task(register_key, queue) + registry.register_task(register_key, queue) else: # exclusive queue - queue = default_registry.get_subscription(register_key) or queue + queue = registry.get_subscription(register_key) or queue queue.subscribe(callback) # register subscription to unsubscribe later - default_registry.register_subscription(register_key, queue) + registry.register_subscription(register_key, queue) return queue @@ -162,8 +162,8 @@ def subscribe_results( queue = get_queue(pkg) queue.subscribe_results(callback) # register subscription to unsubscribe later - with default_registry.sync_lock: - default_registry.register_results(pkg, queue) + with registry.sync_lock: + registry.register_results(pkg, queue) return queue @@ -174,27 +174,27 @@ def unsubscribe( ) -> None: """Remove a subscription for a message/package""" module_name = pkg.__module__ if hasattr(pkg, "__module__") else pkg.__name__ - registry_key = default_registry.get_key(pkg) + registry_key = registry.get_key(pkg) - with default_registry.sync_lock: + 
with registry.sync_lock: if is_module_tasks(module_name): - queues = default_registry.get_tasks(registry_key) + queues = registry.get_tasks(registry_key) for queue in queues: queue.unsubscribe() - default_registry.unregister_tasks(registry_key) + registry.unregister_tasks(registry_key) else: - queue = default_registry.get_subscription(registry_key) + queue = registry.get_subscription(registry_key) if queue: queue.unsubscribe(if_unused=if_unused, if_empty=if_empty) - default_registry.unregister_subscription(registry_key) + registry.unregister_subscription(registry_key) def unsubscribe_results( pkg: "type[ProtoBunnyMessage] | ModuleType", ) -> None: """Remove all in-process subscriptions for a message/package result topic""" - with default_registry.sync_lock: - queue = default_registry.unregister_results(pkg) + with registry.sync_lock: + queue = registry.unregister_results(pkg) if queue: queue.unsubscribe_results() @@ -206,19 +206,19 @@ def unsubscribe_all(if_unused: bool = True, if_empty: bool = True) -> None: This clears standard subscriptions, result subscriptions, and task subscriptions, effectively stopping all message consumption for this process. 
""" - with default_registry.sync_lock: - queues = default_registry.get_all_subscriptions() + with registry.sync_lock: + queues = registry.get_all_subscriptions() for q in queues: q.unsubscribe(if_unused=False, if_empty=False) - default_registry.unregister_all_subscriptions() - queues = default_registry.get_all_results() + registry.unregister_all_subscriptions() + queues = registry.get_all_results() for q in queues: q.unsubscribe_results() - default_registry.unregister_all_results() - queues = default_registry.get_all_tasks(flat=True) + registry.unregister_all_results() + queues = registry.get_all_tasks(flat=True) for q in queues: q.unsubscribe(if_unused=if_unused, if_empty=if_empty) - default_registry.unregister_all_tasks() + registry.unregister_all_tasks() def get_message_count( @@ -277,7 +277,7 @@ def shutdown(signum: int, _: FrameType | None) -> None: def config_lib() -> None: """Add the generated package root to the sys.path.""" - lib_root = default_configuration.generated_package_root + lib_root = config.generated_package_root if lib_root and lib_root not in sys.path: sys.path.append(lib_root) diff --git a/protobunny/__init__.py.j2 b/protobunny/__init__.py.j2 index 3703ba3..d4bbd7e 100644 --- a/protobunny/__init__.py.j2 +++ b/protobunny/__init__.py.j2 @@ -20,7 +20,7 @@ __all__ = [ "GENERATED_PACKAGE_NAME", "PACKAGE_NAME", "ROOT_GENERATED_PACKAGE_NAME", - "default_configuration", + "config", "RequeueMessage", "ConnectionError", "reset_connection", @@ -43,14 +43,14 @@ import typing as tp from importlib.metadata import version from types import ModuleType, FrameType -from .config import ( # noqa +from .conf import ( # noqa GENERATED_PACKAGE_NAME, PACKAGE_NAME, ROOT_GENERATED_PACKAGE_NAME, - default_configuration, + config, ) from .exceptions import RequeueMessage, ConnectionError -from .registry import default_registry +from .registry import registry from .helpers import get_backend, get_queue from .backends import LoggingSyncQueue, BaseSyncQueue, 
BaseSyncConnection @@ -69,10 +69,10 @@ log = logging.getLogger(PACKAGE_NAME) ############################ -def connect() -> "BaseSyncConnection": +def connect(**kwargs) -> "BaseSyncConnection": """Get the singleton async connection.""" connection_module = get_backend().connection - conn = connection_module.Connection.get_connection(vhost=connection_module.VHOST) + conn = connection_module.Connection.get_connection(vhost=connection_module.VHOST, **kwargs) return conn @@ -132,19 +132,17 @@ def subscribe( """ register_key = str(pkg_or_msg) - with default_registry.sync_lock: + with registry.sync_lock: queue = get_queue(pkg_or_msg) if queue.shared_queue: # It's a task. Handle multiple subscriptions - # queue = get_queue(pkg_or_msg) queue.subscribe(callback) - default_registry.register_task(register_key, queue) + registry.register_task(register_key, queue) else: # exclusive queue - queue = default_registry.get_subscription(register_key) or queue + queue = registry.get_subscription(register_key) or queue queue.subscribe(callback) - # register subscription to unsubscribe later - default_registry.register_subscription(register_key, queue) + registry.register_subscription(register_key, queue) return queue @@ -161,8 +159,8 @@ def subscribe_results( queue = get_queue(pkg) queue.subscribe_results(callback) # register subscription to unsubscribe later - with default_registry.sync_lock: - default_registry.register_results(pkg, queue) + with registry.sync_lock: + registry.register_results(pkg, queue) return queue @@ -173,27 +171,27 @@ def unsubscribe( ) -> None: """Remove a subscription for a message/package""" module_name = pkg.__module__ if hasattr(pkg, "__module__") else pkg.__name__ - registry_key = default_registry.get_key(pkg) + registry_key = registry.get_key(pkg) - with default_registry.sync_lock: + with registry.sync_lock: if is_module_tasks(module_name): - queues = default_registry.get_tasks(registry_key) + queues = registry.get_tasks(registry_key) for queue in queues: 
queue.unsubscribe() - default_registry.unregister_tasks(registry_key) + registry.unregister_tasks(registry_key) else: - queue = default_registry.get_subscription(registry_key) + queue = registry.get_subscription(registry_key) if queue: queue.unsubscribe(if_unused=if_unused, if_empty=if_empty) - default_registry.unregister_subscription(registry_key) + registry.unregister_subscription(registry_key) def unsubscribe_results( pkg: "type[ProtoBunnyMessage] | ModuleType", ) -> None: """Remove all in-process subscriptions for a message/package result topic""" - with default_registry.sync_lock: - queue = default_registry.unregister_results(pkg) + with registry.sync_lock: + queue = registry.unregister_results(pkg) if queue: queue.unsubscribe_results() @@ -205,19 +203,19 @@ def unsubscribe_all(if_unused: bool = True, if_empty: bool = True) -> None: This clears standard subscriptions, result subscriptions, and task subscriptions, effectively stopping all message consumption for this process. """ - with default_registry.sync_lock: - queues = default_registry.get_all_subscriptions() + with registry.sync_lock: + queues = registry.get_all_subscriptions() for q in queues: q.unsubscribe(if_unused=False, if_empty=False) - default_registry.unregister_all_subscriptions() - queues = default_registry.get_all_results() + registry.unregister_all_subscriptions() + queues = registry.get_all_results() for q in queues: q.unsubscribe_results() - default_registry.unregister_all_results() - queues = default_registry.get_all_tasks(flat=True) + registry.unregister_all_results() + queues = registry.get_all_tasks(flat=True) for q in queues: q.unsubscribe(if_unused=if_unused, if_empty=if_empty) - default_registry.unregister_all_tasks() + registry.unregister_all_tasks() def get_message_count( @@ -275,7 +273,7 @@ def run_forever() -> None: def config_lib() -> None: """Add the generated package root to the sys.path.""" - lib_root = default_configuration.generated_package_root + lib_root = 
config.generated_package_root if lib_root and lib_root not in sys.path: sys.path.append(lib_root) diff --git a/protobunny/asyncio/__init__.py b/protobunny/asyncio/__init__.py index fcd15e7..e420cef 100644 --- a/protobunny/asyncio/__init__.py +++ b/protobunny/asyncio/__init__.py @@ -27,7 +27,7 @@ "GENERATED_PACKAGE_NAME", "PACKAGE_NAME", "ROOT_GENERATED_PACKAGE_NAME", - "default_configuration", + "config", "RequeueMessage", "ConnectionError", "reset_connection", @@ -50,14 +50,14 @@ from importlib.metadata import version ####################################################### -from ..config import ( # noqa +from ..conf import ( # noqa GENERATED_PACKAGE_NAME, PACKAGE_NAME, ROOT_GENERATED_PACKAGE_NAME, - default_configuration, + config, ) from ..exceptions import ConnectionError, RequeueMessage -from ..registry import default_registry +from ..registry import registry if tp.TYPE_CHECKING: from types import ModuleType @@ -153,18 +153,18 @@ async def subscribe( # obj = type(pkg) if isinstance(pkg, betterproto.Message) else pkg module_name = pkg.__name__ if inspect.ismodule(pkg) else pkg.__module__ registry_key = str(pkg) - async with default_registry.lock: + async with registry.lock: if is_module_tasks(module_name): # It's a task. 
Handle multiple in-process subscriptions queue = get_queue(pkg) await queue.subscribe(callback) - default_registry.register_task(registry_key, queue) + registry.register_task(registry_key, queue) else: # exclusive queue - queue = default_registry.get_subscription(registry_key) or get_queue(pkg) + queue = registry.get_subscription(registry_key) or get_queue(pkg) # queue already exists, but not subscribed yet (otherwise raise ValueError) await queue.subscribe(callback) - default_registry.register_subscription(registry_key, queue) + registry.register_subscription(registry_key, queue) return queue @@ -177,29 +177,29 @@ async def unsubscribe( # obj = type(pkg) if isinstance(pkg, betterproto.Message) else pkg module_name = pkg.__name__ if inspect.ismodule(pkg) else pkg.__module__ - registry_key = default_registry.get_key(pkg) - async with default_registry.lock: + registry_key = registry.get_key(pkg) + async with registry.lock: if is_module_tasks(module_name): - queues = default_registry.get_tasks(registry_key) + queues = registry.get_tasks(registry_key) for q in queues: await q.unsubscribe(if_unused=if_unused) - default_registry.unregister_tasks(registry_key) + registry.unregister_tasks(registry_key) else: - queue = default_registry.get_subscription(registry_key) + queue = registry.get_subscription(registry_key) if queue: await queue.unsubscribe(if_unused=if_unused, if_empty=if_empty) - default_registry.unregister_subscription(registry_key) + registry.unregister_subscription(registry_key) async def unsubscribe_results( pkg: "type[ProtoBunnyMessage] | ModuleType", ) -> None: """Remove all in-process subscriptions for a message/package result topic""" - async with default_registry.lock: - queue = default_registry.get_results(pkg) + async with registry.lock: + queue = registry.get_results(pkg) if queue: await queue.unsubscribe_results() - default_registry.unregister_results(pkg) + registry.unregister_results(pkg) async def unsubscribe_all(if_unused: bool = True, if_empty: 
bool = True) -> None: @@ -209,19 +209,19 @@ async def unsubscribe_all(if_unused: bool = True, if_empty: bool = True) -> None This clears standard subscriptions, result subscriptions, and task subscriptions, effectively stopping all message consumption for this process. """ - async with default_registry.lock: + async with registry.lock: queues = itertools.chain( - default_registry.get_all_subscriptions(), - default_registry.get_all_tasks(flat=True), + registry.get_all_subscriptions(), + registry.get_all_tasks(flat=True), ) for queue in queues: await queue.unsubscribe(if_unused=False, if_empty=False) - default_registry.unregister_all_subscriptions() - default_registry.unregister_all_tasks() - queues = default_registry.get_all_results() + registry.unregister_all_subscriptions() + registry.unregister_all_tasks() + queues = registry.get_all_results() for queue in queues: await queue.unsubscribe_results() - default_registry.unregister_all_results() + registry.unregister_all_results() async def subscribe_results( @@ -237,8 +237,8 @@ async def subscribe_results( queue = get_queue(pkg) await queue.subscribe_results(callback) # register subscription to unsubscribe later - async with default_registry.lock: - default_registry.register_results(pkg, queue) + async with registry.lock: + registry.register_results(pkg, queue) return queue diff --git a/protobunny/asyncio/__init__.py.j2 b/protobunny/asyncio/__init__.py.j2 index 9e5e8ee..53b87c3 100644 --- a/protobunny/asyncio/__init__.py.j2 +++ b/protobunny/asyncio/__init__.py.j2 @@ -26,7 +26,7 @@ __all__ = [ "GENERATED_PACKAGE_NAME", "PACKAGE_NAME", "ROOT_GENERATED_PACKAGE_NAME", - "default_configuration", + "config", "RequeueMessage", "ConnectionError", "reset_connection", @@ -52,14 +52,14 @@ from importlib.metadata import version ####################################################### -from ..config import ( # noqa - GENERATED_PACKAGE_NAME, - PACKAGE_NAME, - ROOT_GENERATED_PACKAGE_NAME, - default_configuration, +from ..conf import 
( # noqa + GENERATED_PACKAGE_NAME, + PACKAGE_NAME, + ROOT_GENERATED_PACKAGE_NAME, + config, ) from ..exceptions import RequeueMessage, ConnectionError -from ..registry import default_registry +from ..registry import registry if tp.TYPE_CHECKING: from ..models import LoggerCallback, ProtoBunnyMessage, AsyncCallback, IncomingMessageProtocol from ..core.results import Result @@ -147,18 +147,18 @@ async def subscribe( # obj = type(pkg) if isinstance(pkg, betterproto.Message) else pkg module_name = pkg.__name__ if inspect.ismodule(pkg) else pkg.__module__ registry_key = str(pkg) - async with default_registry.lock: + async with registry.lock: if is_module_tasks(module_name): # It's a task. Handle multiple in-process subscriptions queue = get_queue(pkg) await queue.subscribe(callback) - default_registry.register_task(registry_key, queue) + registry.register_task(registry_key, queue) else: # exclusive queue - queue = default_registry.get_subscription(registry_key) or get_queue(pkg) + queue = registry.get_subscription(registry_key) or get_queue(pkg) # queue already exists, but not subscribed yet (otherwise raise ValueError) await queue.subscribe(callback) - default_registry.register_subscription(registry_key, queue) + registry.register_subscription(registry_key, queue) return queue @@ -171,29 +171,29 @@ async def unsubscribe( # obj = type(pkg) if isinstance(pkg, betterproto.Message) else pkg module_name = pkg.__name__ if inspect.ismodule(pkg) else pkg.__module__ - registry_key = default_registry.get_key(pkg) - async with default_registry.lock: + registry_key = registry.get_key(pkg) + async with registry.lock: if is_module_tasks(module_name): - queues = default_registry.get_tasks(registry_key) + queues = registry.get_tasks(registry_key) for q in queues: await q.unsubscribe(if_unused=if_unused) - default_registry.unregister_tasks(registry_key) + registry.unregister_tasks(registry_key) else: - queue = default_registry.get_subscription(registry_key) + queue = 
registry.get_subscription(registry_key) if queue: await queue.unsubscribe(if_unused=if_unused, if_empty=if_empty) - default_registry.unregister_subscription(registry_key) + registry.unregister_subscription(registry_key) async def unsubscribe_results( pkg: "type[ProtoBunnyMessage] | ModuleType", ) -> None: """Remove all in-process subscriptions for a message/package result topic""" - async with default_registry.lock: - queue = default_registry.get_results(pkg) + async with registry.lock: + queue = registry.get_results(pkg) if queue: await queue.unsubscribe_results() - default_registry.unregister_results(pkg) + registry.unregister_results(pkg) async def unsubscribe_all(if_unused: bool = True, if_empty: bool = True) -> None: @@ -203,18 +203,18 @@ async def unsubscribe_all(if_unused: bool = True, if_empty: bool = True) -> None This clears standard subscriptions, result subscriptions, and task subscriptions, effectively stopping all message consumption for this process. """ - async with default_registry.lock: + async with registry.lock: queues = itertools.chain( - default_registry.get_all_subscriptions(), default_registry.get_all_tasks(flat=True) + registry.get_all_subscriptions(), registry.get_all_tasks(flat=True) ) for queue in queues: await queue.unsubscribe(if_unused=False, if_empty=False) - default_registry.unregister_all_subscriptions() - default_registry.unregister_all_tasks() - queues = default_registry.get_all_results() + registry.unregister_all_subscriptions() + registry.unregister_all_tasks() + queues = registry.get_all_results() for queue in queues: await queue.unsubscribe_results() - default_registry.unregister_all_results() + registry.unregister_all_results() async def subscribe_results( @@ -230,8 +230,8 @@ async def subscribe_results( queue = get_queue(pkg) await queue.subscribe_results(callback) # register subscription to unsubscribe later - async with default_registry.lock: - default_registry.register_results(pkg, queue) + async with registry.lock: + 
registry.register_results(pkg, queue) return queue diff --git a/protobunny/asyncio/backends/__init__.py b/protobunny/asyncio/backends/__init__.py index 30786c5..ee4817f 100644 --- a/protobunny/asyncio/backends/__init__.py +++ b/protobunny/asyncio/backends/__init__.py @@ -15,7 +15,7 @@ LoggerCallback, ProtoBunnyMessage, SyncCallback, - default_configuration, + config, deserialize_message, deserialize_result_message, get_body, @@ -42,7 +42,6 @@ class BaseConnection(ABC): exchange_name: str | None dl_exchange: str | None dl_queue: str | None - heartbeat: int | None timeout: int | None url: str | None = None queues: dict[str, tp.Any] = {} @@ -109,11 +108,11 @@ class BaseAsyncConnection(BaseConnection, ABC): instance_by_vhost: dict[str, Self] = {} @abstractmethod - async def connect(self, timeout: float = 30) -> None: + async def connect(self, **kwargs) -> None: ... @abstractmethod - async def disconnect(self, timeout: float = 30) -> None: + async def disconnect(self, **kwargs) -> None: ... def __init__(self, **kwargs): @@ -141,7 +140,7 @@ def is_connected_event(self) -> asyncio.Event: ... @classmethod - async def get_connection(cls, vhost: str = "/") -> Self: + async def get_connection(cls, vhost: str = "/", **kwargs) -> Self: """Get singleton instance (async).""" current_loop = asyncio.get_running_loop() async with cls._get_class_lock(): @@ -159,7 +158,7 @@ async def get_connection(cls, vhost: str = "/") -> Self: log.debug("Creating fresh connection for %s", vhost) new_instance = cls(vhost=vhost) new_instance._loop = current_loop # Store the loop it was born in - await new_instance.connect() + await new_instance.connect(**kwargs) cls.instance_by_vhost[vhost] = new_instance instance = new_instance return instance @@ -190,7 +189,7 @@ async def _receive( callback: a callable accepting a message as only argument. message: the IncomingMessageProtocol object received from the queue. 
""" - delimiter = default_configuration.backend_config.topic_delimiter + delimiter = config.backend_config.topic_delimiter if not message.routing_key: raise ValueError("Routing key was not set. Invalid topic") if message.routing_key == self.result_topic or message.routing_key.endswith( @@ -379,10 +378,10 @@ class LoggingAsyncQueue(BaseAsyncQueue): """ def __init__(self, prefix: str) -> None: - backend = default_configuration.backend_config + backend = config.backend_config delimiter = backend.topic_delimiter wildcard = backend.multi_wildcard_delimiter - prefix = prefix or default_configuration.messages_prefix + prefix = prefix or config.messages_prefix super().__init__(f"{prefix}{delimiter}{wildcard}") def get_tag(self) -> str: @@ -442,7 +441,7 @@ def is_task(topic: str) -> bool: Returns: True if tasks is in the topic, else False """ - delimiter = default_configuration.backend_config.topic_delimiter + delimiter = config.backend_config.topic_delimiter return "tasks" in topic.split(delimiter) diff --git a/protobunny/asyncio/backends/mosquitto/connection.py b/protobunny/asyncio/backends/mosquitto/connection.py index 4256a8e..068ada7 100644 --- a/protobunny/asyncio/backends/mosquitto/connection.py +++ b/protobunny/asyncio/backends/mosquitto/connection.py @@ -9,7 +9,7 @@ import aiomqtt import can_ada -from ....config import default_configuration +from ....conf import config from ....exceptions import ConnectionError, PublishError, RequeueMessage from ....models import Envelope, IncomingMessageProtocol from .. import BaseAsyncConnection @@ -66,8 +66,8 @@ def __init__( self._connection: aiomqtt.Client | None = None self._instance_lock: asyncio.Lock | None = None - self._delimiter = default_configuration.backend_config.topic_delimiter # e.g. "/" - self._exchange = default_configuration.backend_config.namespace # used as root prefix + self._delimiter = config.backend_config.topic_delimiter # e.g. 
"/" + self._exchange = config.backend_config.namespace # used as root prefix @property def is_connected_event(self) -> asyncio.Event: @@ -86,7 +86,7 @@ def build_topic_key(self, topic: str) -> str: """Joins project prefix with topic using the configured delimiter.""" return f"{self._exchange}{self._delimiter}{topic}" - async def connect(self, timeout: float = 30.0) -> "Connection": + async def connect(self, **kwargs) -> "Connection": async with self.lock: if self.is_connected(): return self @@ -98,7 +98,7 @@ async def connect(self, timeout: float = 30.0) -> "Connection": port=self.port, username=self.username, password=self.password, - timeout=timeout, + **kwargs, ) # We enter the context manager to establish the connection await self._connection.__aenter__() diff --git a/protobunny/asyncio/backends/nats/connection.py b/protobunny/asyncio/backends/nats/connection.py index 834b23b..9d81a10 100644 --- a/protobunny/asyncio/backends/nats/connection.py +++ b/protobunny/asyncio/backends/nats/connection.py @@ -11,11 +11,12 @@ import can_ada import nats +from nats.aio.msg import Msg from nats.aio.subscription import Subscription from nats.errors import ConnectionClosedError, TimeoutError from nats.js.errors import BadRequestError, NoStreamResponseError -from ....config import default_configuration +from ....conf import config from ....exceptions import ConnectionError, PublishError, RequeueMessage from ....models import Envelope, IncomingMessageProtocol from .. 
import BaseAsyncConnection, is_task @@ -92,12 +93,13 @@ def __init__( self.heartbeat = heartbeat self.queues: dict[str, list[dict]] = defaultdict(list) self.consumers: dict[str, dict] = {} + # to run sync callbacks self.executor = ThreadPoolExecutor(max_workers=worker_threads) self._instance_lock: asyncio.Lock | None = None - - self._delimiter = default_configuration.backend_config.topic_delimiter - self._exchange = default_configuration.backend_config.namespace - self._stream_name = f"{self._exchange.upper()}_TASKS" + self._delimiter = config.backend_config.topic_delimiter + self._namespace = config.backend_config.namespace + self._tasks_subject_prefix = "TASKS" + self._stream_name = f"{self._namespace.upper()}_{self._tasks_subject_prefix}" async def __aenter__(self) -> "Connection": await self.connect() @@ -122,7 +124,7 @@ def _get_class_lock(cls) -> asyncio.Lock: return cls._lock def build_topic_key(self, topic: str) -> str: - return f"{self._exchange}.{topic}" + return f"{self._namespace}.{topic}" @property def is_connected_event(self) -> asyncio.Event: @@ -142,11 +144,10 @@ def connection(self) -> "nats.NATS": raise ConnectionError("Connection not initialized. Call connect() first.") return self._connection - async def connect(self, timeout: float = 30.0) -> "Connection": + async def connect(self, **kwargs) -> "Connection": """Establish NATS connection. 
Args: - timeout: Maximum time to wait for connection establishment (seconds) Raises: ConnectionError: If connection fails @@ -157,19 +158,17 @@ async def connect(self, timeout: float = 30.0) -> "Connection": return self.instance_by_vhost[self.vhost] try: log.info("Establishing NATS connection to %s", self._url.split("@")[-1]) - self._connection = await nats.connect( - self._url, connect_timeout=timeout, max_reconnect_attempts=3 - ) + self._connection = await nats.connect(self._url, **kwargs) self.is_connected_event.set() log.info("Successfully connected to NATS") self.instance_by_vhost[self.vhost] = self - if default_configuration.use_tasks_in_nats: + if config.use_tasks_in_nats: # Create the jetstream if not existing js = self._connection.jetstream() # For NATS, tasks package can only be at first level after main package library # Warning: don't bury tasks messages after three levels of hierarchy task_patterns = [ - f"{self._exchange}.*.tasks.>", + f"{self._tasks_subject_prefix}{self._delimiter}>", ] try: await js.add_stream( @@ -244,12 +243,17 @@ async def setup_queue( topic_key = self.build_topic_key(topic) cb = functools.partial(self._nats_handler, callback) if shared: - log.debug("Subscribing shared worker to JetStream: %s", topic_key) js = self._connection.jetstream() # We use a durable name so multiple instances share the same task state - group_name = f"{self._exchange}_{topic_key.replace('.', '_')}" + group_name = topic_key.replace(".", "_") + log.debug( + "Subscribing shared worker to JetStream group %s subject %s", group_name, topic_key + ) subscription = await js.subscribe( - subject=topic_key, + # the topic with prefixes + subject=f"{self._tasks_subject_prefix}{self._delimiter}{topic_key}", + # add queue parameter to flag it as a distributed queue + queue=group_name, durable=group_name, cb=cb, manual_ack=True, @@ -275,18 +279,25 @@ async def subscribe(self, topic: str, callback: tp.Callable, shared: bool = Fals } return sub_tag - async def 
_nats_handler(self, callback, msg): + async def _nats_handler(self, callback, msg: Msg): + """Callback that handles the Msg object pushed from NATS""" topic = msg.subject + is_shared_queue = is_task(topic) reply = msg.reply body = msg.data - is_shared_queue = is_task(topic) - routing_key = msg.subject.removeprefix(f"{self._exchange}{self._delimiter}") + # Remove the 'TASKS.' prefix that was added to match the filtering stream group + if is_shared_queue: + topic = topic.removeprefix(f"{self._tasks_subject_prefix}{self._delimiter}") + + # The routing key is the string used to match the protobuf python class fqn + routing_key = topic.removeprefix(f"{self._namespace}{self._delimiter}") envelope = Envelope(body=body, correlation_id=reply, routing_key=routing_key) try: if asyncio.iscoroutinefunction(callback): await callback(envelope) else: - asyncio.run_coroutine_threadsafe(callback(envelope), self._loop) + # Run the callback in a thread pool to avoid blocking the event loop + await asyncio.get_event_loop().run_in_executor(self.executor, callback, envelope) if is_shared_queue: await msg.ack() except RequeueMessage: @@ -295,14 +306,10 @@ async def _nats_handler(self, callback, msg): if not is_shared_queue: await self._connection.publish(topic, body, reply=reply) else: - # TODO check if NATS has a requeue logic - js = self._connection.jetstream() - await js.publish(topic, body) - await msg.ack() + await msg.nak(self.requeue_delay) except Exception: log.exception("Callback failed for topic %s", topic) - # TODO check if NATS has a reject logic - await msg.ack() # avoid retry logic for potentially poisoning messages + await msg.term() # avoid retry logic for potentially poisoning messages async def unsubscribe(self, tag: str, **kwargs) -> None: if tag not in self.consumers: @@ -312,7 +319,6 @@ async def unsubscribe(self, tag: str, **kwargs) -> None: del sub_info["subscription"] log.info("Unsubscribed from %s", sub_info["topic"]) self.consumers.pop(tag) - # TODO check if we 
need to handle self.queues[topic] cleanup here async def publish( self, @@ -322,7 +328,6 @@ async def publish( ) -> None: if not self.is_connected(): raise ConnectionError("Not connected to NATS") - topic_key = self.build_topic_key(topic) is_shared = is_task(topic) @@ -332,9 +337,17 @@ async def publish( try: if is_shared: # Persistent "Task" publishing via JetStream + stream_key = f"{self._tasks_subject_prefix}{self._delimiter}{topic_key}" log.debug("Publishing persistent task to NATS JetStream: %s", topic_key) js = self._connection.jetstream() - await js.publish(subject=topic_key, payload=message.body, headers=headers) + await js.publish(subject=stream_key, payload=message.body, headers=headers) + if config.log_task_in_nats: + # The logger service doesn't use jetstream so we re-publish on a normal pubsub + # (it won't be re-catched by the tasks consumer) + log.debug("Publishing logging message for task to NATS Core: %s", topic_key) + await self._connection.publish( + subject=topic_key, payload=message.body, headers=headers + ) else: # Volatile "PubSub" publishing via NATS Core log.debug("Publishing broadcast to NATS Core: %s", topic_key) @@ -352,17 +365,18 @@ async def purge(self, topic: str, reset_groups: bool = False) -> None: if not self.is_connected(): raise ConnectionError("Not connected to NATS") topic_key = self.build_topic_key(topic) + subject = f"{self._tasks_subject_prefix}{self._delimiter}{topic_key}" # NATS purges messages matching a subject within the stream try: jsm = self._connection.jsm() # Get JetStream Management context log.info("Purging NATS subject '%s' from stream %s", topic, self._stream_name) - await jsm.purge_stream(self._stream_name, subject=topic_key) + await jsm.purge_stream(self._stream_name, subject=subject) if reset_groups: # In NATS, we must find consumers specifically tied to this topic # Protobunny convention: durable name includes the topic - group_name = f"{self._exchange}_{topic_key.replace('.', '_')}" + group_name = 
f"{topic_key.replace('.', '_')}" try: await jsm.delete_consumer(self._stream_name, group_name) log.debug("Deleted NATS durable consumer: %s", group_name) diff --git a/protobunny/asyncio/backends/python/connection.py b/protobunny/asyncio/backends/python/connection.py index 01fbe31..c4027ee 100644 --- a/protobunny/asyncio/backends/python/connection.py +++ b/protobunny/asyncio/backends/python/connection.py @@ -7,7 +7,7 @@ from collections import defaultdict from queue import Empty, Queue -from ....config import default_configuration +from ....conf import config from ....models import AsyncCallback, Envelope from ... import RequeueMessage from .. import BaseConnection, is_task @@ -125,7 +125,7 @@ def __init__(self, vhost: str = "/", requeue_delay: int = 3): self.requeue_delay = requeue_delay self._is_connected = False self._subscriptions: dict[str, dict] = {} - self.logger_prefix = default_configuration.logger_prefix + self.logger_prefix = config.logger_prefix class Connection(BaseLocalConnection): diff --git a/protobunny/asyncio/backends/redis/connection.py b/protobunny/asyncio/backends/redis/connection.py index a955e6f..1aa166e 100644 --- a/protobunny/asyncio/backends/redis/connection.py +++ b/protobunny/asyncio/backends/redis/connection.py @@ -13,7 +13,7 @@ import redis.asyncio as redis from redis import RedisError, ResponseError -from ....config import default_configuration +from ....conf import config from ....exceptions import ConnectionError, PublishError, RequeueMessage from ....models import Envelope, IncomingMessageProtocol from .. import BaseAsyncConnection, is_task @@ -41,7 +41,7 @@ def __init__( worker_threads: int = 2, prefetch_count: int = 1, requeue_delay: int = 3, - heartbeat: int = 1200, + **kwargs, ): """Initialize Redis connection. @@ -51,7 +51,7 @@ def __init__( host: Redis host port: Redis port url: Redis URL. 
It will override username, password, host and port - vhost: Redis virtual host (it's used as db number string) + vhost: Redis virtual host (it's used as db number string if db not present) db: Redis database number worker_threads: number of concurrent callback workers to use prefetch_count: how many messages to prefetch from the queue @@ -93,17 +93,16 @@ def __init__( self._connection: redis.Redis | None = None self.prefetch_count = prefetch_count self.requeue_delay = requeue_delay - self.heartbeat = heartbeat self.queues: dict[str, dict] = {} self.consumers: dict[str, dict[str, tp.Any]] = {} self.executor = ThreadPoolExecutor(max_workers=worker_threads) self._instance_lock: asyncio.Lock | None = None - self._delimiter = default_configuration.backend_config.topic_delimiter - self._exchange = default_configuration.backend_config.namespace + self._delimiter = config.backend_config.topic_delimiter + self._exchange = config.backend_config.namespace - async def __aenter__(self) -> "Connection": - await self.connect() + async def __aenter__(self, **kwargs) -> "Connection": + await self.connect(**kwargs) return self async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool: @@ -145,11 +144,10 @@ def connection(self) -> "redis.Redis": raise ConnectionError("Connection not initialized. Call connect() first.") return self._connection - async def connect(self, timeout: float = 30.0) -> "Connection": + async def connect(self, **kwargs) -> "Connection": """Establish Redis connection. 
Args: - timeout: Maximum time to wait for connection establishment (seconds) Raises: ConnectionError: If connection fails @@ -161,23 +159,23 @@ async def connect(self, timeout: float = 30.0) -> "Connection": try: # Parsing URL for logging (removing credentials) log.info("Establishing Redis connection to %s", self._url.split("@")[-1]) - + # protobunny sends raw bytes with protobuf serialized payloads + kwargs.pop("decode_responses", None) # Using from_url handles connection pooling automatically self._connection = redis.from_url( self._url, decode_responses=False, - socket_connect_timeout=timeout, - health_check_interval=self.heartbeat, + **kwargs, ) - await asyncio.wait_for(self._connection.ping(), timeout=timeout) + await asyncio.wait_for(self._connection.ping(), timeout=30) self.is_connected_event.set() log.info("Successfully connected to Redis") self.instance_by_vhost[self.vhost] = self return self except asyncio.TimeoutError: - log.error("Redis connection timeout after %.1f seconds", timeout) + log.error("Redis connection timeout after %.1f seconds", 30) self.is_connected_event.clear() self._connection = None raise @@ -492,7 +490,7 @@ async def publish( "topic": topic, # add the topic here to implement topic exchange patterns } await self._connection.xadd(name=topic_key, fields=payload, maxlen=1000) - if default_configuration.log_task_in_redis: + if config.log_task_in_redis: # Tasks messages go to streams but the logger do a simple pubsub psubscription to .* # Send the message to the same topic with redis.publish so it appears there # Note: this should be used carefully (e.g. 
only for debugging) diff --git a/protobunny/backends/__init__.py b/protobunny/backends/__init__.py index a040ee9..5604490 100644 --- a/protobunny/backends/__init__.py +++ b/protobunny/backends/__init__.py @@ -14,7 +14,7 @@ LoggerCallback, ProtoBunnyMessage, SyncCallback, - default_configuration, + config, deserialize_message, deserialize_result_message, get_body, @@ -67,7 +67,7 @@ def is_connected(self) -> bool | tp.Awaitable[bool]: ... @abstractmethod - def connect(self, timeout: float = 30) -> None | tp.Awaitable[None]: + def connect(self, **kwargs) -> None | tp.Awaitable[None]: ... @abstractmethod @@ -133,6 +133,7 @@ def __init__(self, **kwargs): self.vhost = self._async_conn.vhost self._started = False self.instance_by_vhost = {} + self._timeout_coro = 10 def get_async_connection(self, **kwargs) -> "BaseAsyncConnection": if hasattr(self, "_async_conn"): @@ -370,7 +371,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.disconnect() return False - def connect(self, timeout: float = 10.0) -> None: + def connect(self, **kwargs) -> None: """Establish Sync connection. Args: @@ -380,7 +381,7 @@ def connect(self, timeout: float = 10.0) -> None: ConnectionError: If connection fails TimeoutError: If connection times out """ - self._run_coro(self._async_conn.connect(timeout), timeout=timeout) + self._run_coro(self._async_conn.connect(), timeout=self._timeout_coro) self.__class__.instance_by_vhost[self.vhost] = self def disconnect(self, timeout: float = 10.0) -> None: @@ -452,7 +453,7 @@ def _receive( """ if not message.routing_key: raise ValueError("Routing key was not set. 
Invalid topic") - delimiter = default_configuration.backend_config.topic_delimiter + delimiter = config.backend_config.topic_delimiter if message.routing_key == self.result_topic or message.routing_key.endswith( f"{delimiter}result" ): @@ -610,10 +611,10 @@ async def send_message( raise NotImplementedError() def __init__(self, prefix: str) -> None: - backend = default_configuration.backend_config + backend = config.backend_config delimiter = backend.topic_delimiter wildcard = backend.multi_wildcard_delimiter - prefix = prefix or default_configuration.messages_prefix + prefix = prefix or config.messages_prefix super().__init__(f"{prefix}{delimiter}{wildcard}") @property @@ -661,7 +662,7 @@ def get_tag(self) -> str: def is_task(topic: str) -> bool: - delimiter = default_configuration.backend_config.topic_delimiter + delimiter = config.backend_config.topic_delimiter return "tasks" in topic.split(delimiter) diff --git a/protobunny/backends/mosquitto/connection.py b/protobunny/backends/mosquitto/connection.py index 58c08a2..0ecc6b5 100644 --- a/protobunny/backends/mosquitto/connection.py +++ b/protobunny/backends/mosquitto/connection.py @@ -11,7 +11,7 @@ import can_ada import paho.mqtt.client as mqtt -from ...config import default_configuration +from ...conf import config from ...exceptions import ConnectionError, PublishError, RequeueMessage from ...models import Envelope, IncomingMessageProtocol from .. 
import BaseConnection @@ -104,20 +104,21 @@ def __init__( self._main_client: mqtt.Client | None = None - self._delimiter = default_configuration.backend_config.topic_delimiter - self._exchange = default_configuration.backend_config.namespace + self._delimiter = config.backend_config.topic_delimiter + self._exchange = config.backend_config.namespace def build_topic_key(self, topic: str) -> str: return f"{self._exchange}{self._delimiter}{topic}" - def connect(self, timeout: float = 10.0) -> "Connection": + def connect(self, **kwargs) -> "Connection": with self._lock: if self.is_connected(): return self try: # Use MQTT v5 for shared subscriptions support + kwargs.pop("callback_api_version", None) self._main_client = mqtt.Client( - callback_api_version=mqtt.CallbackAPIVersion.VERSION2 + callback_api_version=mqtt.CallbackAPIVersion.VERSION2, **kwargs ) if self.username: self._main_client.username_pw_set(self.username, self.password) diff --git a/protobunny/backends/python/connection.py b/protobunny/backends/python/connection.py index 59878f4..b42edef 100644 --- a/protobunny/backends/python/connection.py +++ b/protobunny/backends/python/connection.py @@ -9,7 +9,7 @@ from queue import Empty, Queue from ... import RequeueMessage -from ...config import default_configuration +from ...conf import config from ...models import Envelope, SyncCallback, is_task from .. 
import BaseConnection @@ -116,7 +116,7 @@ def __init__(self, vhost: str = "/", requeue_delay: int = 3): self.requeue_delay = requeue_delay self._is_connected = False self._subscriptions: dict[str, dict] = {} - self.logger_prefix = default_configuration.logger_prefix + self.logger_prefix = config.logger_prefix def build_topic_key(self, topic: str) -> str: pass @@ -179,7 +179,7 @@ def get_connection(cls, vhost: str = "/") -> "Connection": cls.instance_by_vhost[vhost] = instance return cls.instance_by_vhost[vhost] - def connect(self, timeout: float = 10.0) -> "Connection": + def connect(self, **kwargs) -> "Connection": with self.lock: self.is_connected_event.set() return self diff --git a/protobunny/config.py b/protobunny/conf.py similarity index 99% rename from protobunny/config.py rename to protobunny/conf.py index 7a27b67..5933f6f 100644 --- a/protobunny/config.py +++ b/protobunny/conf.py @@ -72,6 +72,7 @@ class Config: backend_config: BackEndConfig = rabbitmq_backend_config log_task_in_redis: bool = False use_tasks_in_nats: bool = True # needed to create the namespaced group stream on connect + log_task_in_nats: bool = False available_backends = ("rabbitmq", "python", "redis", "mosquitto", "nats") @@ -198,4 +199,4 @@ def get_project_version() -> str: VERSION = get_project_version() -default_configuration = load_config() +config = load_config() diff --git a/protobunny/helpers.py b/protobunny/helpers.py index 695c682..a733420 100644 --- a/protobunny/helpers.py +++ b/protobunny/helpers.py @@ -6,7 +6,7 @@ import betterproto -from .config import default_configuration +from .conf import config if typing.TYPE_CHECKING: from .asyncio.backends import BaseAsyncQueue @@ -25,8 +25,8 @@ def get_topic(pkg_or_msg: "ProtoBunnyMessage | type[ProtoBunnyMessage] | ModuleT Returns: topic string """ - delimiter = default_configuration.backend_config.topic_delimiter - return f"{default_configuration.messages_prefix}{delimiter}{build_routing_key(pkg_or_msg)}" + delimiter = 
config.backend_config.topic_delimiter + return f"{config.messages_prefix}{delimiter}{build_routing_key(pkg_or_msg)}" def get_backend(backend: str | None = None) -> ModuleType: @@ -44,8 +44,8 @@ def get_backend(backend: str | None = None) -> ModuleType: Returns: The imported backend module. """ - backend = backend or default_configuration.backend - module = ".asyncio" if default_configuration.use_async else "" + backend = backend or config.backend + module = ".asyncio" if config.use_async else "" module_name = f"protobunny{module}.backends.{backend}" if module_name in sys.modules: return sys.modules[module_name] @@ -53,8 +53,8 @@ def get_backend(backend: str | None = None) -> ModuleType: module = importlib.import_module(module_name) except ModuleNotFoundError as exc: suggestion = "" - if backend not in default_configuration.available_backends: - suggestion = f" Invalid backend or backend not supported.\nAvailable backends: {default_configuration.available_backends}" + if backend not in config.available_backends: + suggestion = f" Invalid backend or backend not supported.\nAvailable backends: {config.available_backends}" else: suggestion = ( f" Install the backend with pip install protobunny[{backend}]." @@ -82,15 +82,14 @@ def get_queue( Returns: Async/SyncQueue: A queue instance configured for the relevant topic. 
""" - backend_name = backend_name or default_configuration.backend - queue_type = "AsyncQueue" if default_configuration.use_async else "SyncQueue" + backend_name = backend_name or config.backend + queue_type = "AsyncQueue" if config.use_async else "SyncQueue" return getattr(get_backend(backend=backend_name).queues, queue_type)(get_topic(pkg_or_msg)) @functools.lru_cache(maxsize=100) def _build_routing_key(module: str, cls_name: str) -> str: # Build the routing key from the module and class name - config = default_configuration backend = config.backend_config delimiter = backend.topic_delimiter routing_key = f"{module}.{cls_name}" @@ -126,7 +125,7 @@ def build_routing_key( Returns: a routing key based on the type of message or package """ - backend = default_configuration.backend_config + backend = config.backend_config wildcard = backend.multi_wildcard_delimiter module_name = "" class_name = "" diff --git a/protobunny/logger.py b/protobunny/logger.py index 912522a..2ccc2f3 100644 --- a/protobunny/logger.py +++ b/protobunny/logger.py @@ -40,7 +40,7 @@ import protobunny as pb_sync from protobunny import asyncio as pb -from protobunny.config import load_config +from protobunny.conf import load_config from protobunny.models import IncomingMessageProtocol, LoggerCallback logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") diff --git a/protobunny/models.py b/protobunny/models.py index 27a3f0b..0ba34cf 100644 --- a/protobunny/models.py +++ b/protobunny/models.py @@ -12,7 +12,7 @@ import betterproto from betterproto.lib.std.google.protobuf import Any -from .config import default_configuration +from .conf import config from .helpers import get_topic from .utils import ProtobunnyJsonEncoder @@ -26,7 +26,7 @@ def is_task(topic: str) -> bool: - delimiter = default_configuration.backend_config.topic_delimiter + delimiter = config.backend_config.topic_delimiter return "tasks" in topic.split(delimiter) @@ -67,7 +67,7 @@ def __bytes__(self: 
"ProtoBunnyMessage") -> bytes: # Override Message.__bytes__ method # to support transparent serialization of dictionaries to JsonContent fields. # This method validates for required fields as well - if default_configuration.force_required_fields: + if config.force_required_fields: self.validate_required_fields() msg = self.serialize_json_content() with BytesIO() as stream: @@ -236,7 +236,7 @@ def result_topic(self: "ProtoBunnyMessage") -> str: """ Build the result topic name for the message. """ - return f"{get_topic(self)}{default_configuration.backend_config.topic_delimiter}result" + return f"{get_topic(self)}{config.backend_config.topic_delimiter}result" def make_result( self: "ProtoBunnyMessage", @@ -362,16 +362,16 @@ def get_message_class_from_topic(topic: str) -> "type[ProtoBunnyMessage] | None Returns: the message class for the topic or None if the topic is not recognized """ - delimiter = default_configuration.backend_config.topic_delimiter + delimiter = config.backend_config.topic_delimiter if topic.endswith(f"{delimiter}result"): message_type = Result else: - route = topic.removeprefix(f"{default_configuration.messages_prefix}{delimiter}") + route = topic.removeprefix(f"{config.messages_prefix}{delimiter}") if route == topic: # the prefix is not present in the topic # Try if it's a protobunny class # to allow pb.* internal messages like pb.results.Result route = topic.removeprefix(f"pb{delimiter}") - codegen_module = importlib.import_module(default_configuration.generated_package_name) + codegen_module = importlib.import_module(config.generated_package_name) # if route is not recognized at this point, the message_type will be None message_type = _get_submodule(codegen_module, route.split(delimiter)) return message_type @@ -387,9 +387,9 @@ def get_message_class_from_type_url(url: str) -> type["ProtoBunnyMessage"]: Returns: the message class """ module_path, clz = url.rsplit(".", 1) - if not 
module_path.startswith(default_configuration.generated_package_name): + if not module_path.startswith(config.generated_package_name): raise ValueError( - f"Invalid type url {url}, must start with {default_configuration.generated_package_name}." + f"Invalid type url {url}, must start with {config.generated_package_name}." ) module = importlib.import_module(module_path) message_type = getattr(module, clz) @@ -460,7 +460,7 @@ def __init__(self, topic: str): Args: topic: a Topic value object """ - delimiter = default_configuration.backend_config.topic_delimiter + delimiter = config.backend_config.topic_delimiter self.topic: str = topic.replace(".", delimiter) self.subscription: str | None = None self.result_subscription: str | None = None @@ -468,7 +468,7 @@ def __init__(self, topic: str): @property def result_topic(self) -> str: - return f"{self.topic}{default_configuration.backend_config.topic_delimiter}result" + return f"{self.topic}{config.backend_config.topic_delimiter}result" @abstractmethod def get_tag(self) -> str: @@ -561,7 +561,7 @@ def get_body(message: "IncomingMessageProtocol") -> str: """ msg: ProtoBunnyMessage | None body: str | bytes - delimiter = default_configuration.backend_config.topic_delimiter + delimiter = config.backend_config.topic_delimiter if message.routing_key and message.routing_key.endswith(f"{delimiter}result"): # log result message. 
Need to extract the source here result = deserialize_result_message(message.body) diff --git a/protobunny/registry.py b/protobunny/registry.py index 02cd9f0..912a928 100644 --- a/protobunny/registry.py +++ b/protobunny/registry.py @@ -77,4 +77,4 @@ def unregister_all_results(self): # Then create a default instance -default_registry = SubscriptionRegistry() +registry = SubscriptionRegistry() diff --git a/protobunny/wrapper.py b/protobunny/wrapper.py index 699aaf7..f8ea798 100644 --- a/protobunny/wrapper.py +++ b/protobunny/wrapper.py @@ -64,7 +64,7 @@ sys.exit(config_error) -from .config import load_config +from .conf import load_config from .logger import log_callback, start_logger, start_logger_sync diff --git a/pyproject.toml b/pyproject.toml index 589c3ef..a26ea89 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,9 +36,7 @@ dependencies = [ rabbitmq = ["aio-pika>=7.1.0"] redis = ["redis<8"] numpy = ["numpy>=1.26"] -mosquitto = [ - "aiomqtt>=2.4.0" -] +mosquitto = ["aiomqtt>=2.4.0"] nats = ["nats-py>=2.12.0"] [project.scripts] @@ -97,6 +95,7 @@ force-required-fields = true backend = "redis" mode = "async" log-task-in-redis = true +log-task-in-nats = true [tool.pytest.ini_options] addopts = "--junitxml=pytest_report.xml --cov=. 
--cov-report=term --cov-report=xml --cov-report html --no-success-flaky-report" diff --git a/setup.py b/setup.py index cec0c9b..6331cf7 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "protobunny"))) import typing as tp -from config import ( +from conf import ( GENERATED_PACKAGE_NAME, PACKAGE_NAME, PROJECT_NAME, diff --git a/tests/conftest.py b/tests/conftest.py index 91cf566..7a9fbc8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,8 +22,8 @@ @pytest_asyncio.fixture -def test_config() -> protobunny.config.Config: - conf = protobunny.config.Config( +def test_config() -> protobunny.conf.Config: + conf = protobunny.conf.Config( messages_directory="tests/proto", messages_prefix="acme", generated_package_name="tests", @@ -34,6 +34,7 @@ def test_config() -> protobunny.config.Config: backend="rabbitmq", log_task_in_redis=True, use_tasks_in_nats=True, + log_task_in_nats=True, ) return conf @@ -117,20 +118,11 @@ async def mock_redis_client(mocker) -> tp.AsyncGenerator[fakeredis.FakeAsyncRedi @pytest.fixture async def mock_nats(mocker): - # with patch("nats.NATS") as mock_nats_client: #, patch("nats.connect"): - # 1. Create the top-level Mock Client mock_nc = AsyncMock(spec=nats.aio.client.Client) - - # 2. Mock the JetStream Context (.jetstream()) mock_js = AsyncMock() mock_nc.jetstream.return_value = mock_js - - # 3. Mock the Management API (.jsm()) mock_jsm = AsyncMock() mock_nc.jsm.return_value = mock_jsm - - # 4. 
PATCH: Intercept the nats.connect call - # Replace 'your_module.nats.connect' with the path where you import nats mocker.patch("nats.connect", return_value=mock_nc) yield { "client": mock_nc, diff --git a/tests/test_base.py b/tests/test_base.py index ae437b8..8c52c24 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -23,11 +23,11 @@ def setup_config(mocker, test_config) -> None: test_config.mode = "async" test_config.backend = "python" - mocker.patch.object(pb_sync.config, "default_configuration", test_config) - mocker.patch.object(pb_sync.models, "default_configuration", test_config) - mocker.patch.object(pb_sync.helpers, "default_configuration", test_config) - mocker.patch.object(pb.backends, "default_configuration", test_config) - mocker.patch.object(python_backend.connection, "default_configuration", test_config) + mocker.patch.object(pb_sync.conf, "config", test_config) + mocker.patch.object(pb_sync.models, "config", test_config) + mocker.patch.object(pb_sync.helpers, "config", test_config) + mocker.patch.object(pb.backends, "config", test_config) + mocker.patch.object(python_backend.connection, "config", test_config) def test_json_serializer() -> None: diff --git a/tests/test_config.py b/tests/test_config.py index 63983be..1504f61 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -2,7 +2,7 @@ import pytest -from protobunny.config import load_config +from protobunny.conf import load_config @pytest.fixture(autouse=True) @@ -48,7 +48,7 @@ def test_config_from_ini(tmp_path, monkeypatch): (tmp_path / "protobunny.ini").write_text(ini_content) # Ensure pyproject doesn't interfere - with mock.patch("protobunny.config.get_config_from_pyproject", return_value={}): + with mock.patch("protobunny.conf.get_config_from_pyproject", return_value={}): config = load_config() assert config.backend == "python" assert config.mode == "async" @@ -69,7 +69,7 @@ def test_config_precedence(tmp_path, monkeypatch): # Mock pyproject to provide a base mock_pyproject 
= {"backend": "rabbitmq", "messages_prefix": "py_pref", "project-name": "test"} - with mock.patch("protobunny.config.get_config_from_pyproject", return_value=mock_pyproject): + with mock.patch("protobunny.conf.get_config_from_pyproject", return_value=mock_pyproject): config = load_config() # Overridden by ENV diff --git a/tests/test_connection.py b/tests/test_connection.py index d4b97ed..59c9d7f 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -10,14 +10,12 @@ import protobunny as pb_base from protobunny import RequeueMessage from protobunny import asyncio as pb -from protobunny.asyncio.backends import mosquitto as mosquitto_backend_aio -from protobunny.asyncio.backends import nats as nats_backend_aio -from protobunny.asyncio.backends import python as python_backend_aio from protobunny.asyncio.backends import rabbitmq as rabbitmq_backend_aio from protobunny.asyncio.backends import redis as redis_backend_aio from protobunny.backends import python as python_backend -from protobunny.config import backend_configs +from protobunny.conf import backend_configs from protobunny.helpers import ( + get_backend, get_queue, ) from protobunny.models import Envelope, IncomingMessageProtocol @@ -36,11 +34,11 @@ @pytest.mark.parametrize( "backend", [ - rabbitmq_backend_aio, - redis_backend_aio, - python_backend_aio, - mosquitto_backend_aio, - nats_backend_aio, + "rabbitmq", + "redis", + "python", + "mosquitto", + "nats", ], ) @pytest.mark.asyncio @@ -48,7 +46,7 @@ class TestConnection: @pytest.fixture(autouse=True) async def mock_connections( self, - backend, + backend: str, mocker, mock_redis_client, mock_aio_pika, @@ -56,26 +54,24 @@ async def mock_connections( mock_nats, test_config, ) -> tp.AsyncGenerator[dict[str, AsyncMock | None], None]: - backend_name = backend.__name__.split(".")[-1] - test_config.mode = "async" - test_config.backend = backend_name + test_config.backend = backend test_config.log_task_in_redis = True - test_config.backend_config = 
backend_configs[backend_name] - mocker.patch.object(pb_base.config, "default_configuration", test_config) - mocker.patch.object(pb_base.models, "default_configuration", test_config) - mocker.patch.object(pb_base.backends, "default_configuration", test_config) - mocker.patch.object(pb_base.helpers, "default_configuration", test_config) - mocker.patch.object(pb.backends, "default_configuration", test_config) - connection_module = getattr(pb.backends, backend_name).connection - if hasattr(connection_module, "default_configuration"): - mocker.patch.object(connection_module, "default_configuration", test_config) - - assert pb_base.helpers.get_backend() == backend - assert pb.get_backend() == backend - assert isinstance(get_queue(tests.tasks.TaskMessage), backend.queues.AsyncQueue) + test_config.backend_config = backend_configs[backend] + mocker.patch.object(pb_base.conf, "config", test_config) + mocker.patch.object(pb_base.models, "config", test_config) + mocker.patch.object(pb_base.backends, "config", test_config) + mocker.patch.object(pb_base.helpers, "config", test_config) + mocker.patch.object(pb.backends, "config", test_config) + connection_module = getattr(pb.backends, backend).connection + if hasattr(connection_module, "config"): + mocker.patch.object(connection_module, "config", test_config) + + backend_module = get_backend() + assert pb.get_backend() == backend_module + assert isinstance(get_queue(tests.tasks.TaskMessage), backend_module.queues.AsyncQueue) conn_with_fake_internal_conn = get_mocked_connection( - backend, mock_redis_client, mock_aio_pika, mocker, mock_mosquitto, mock_nats + backend_module, mock_redis_client, mock_aio_pika, mocker, mock_mosquitto, mock_nats ) mocker.patch( "protobunny.asyncio.backends.BaseAsyncQueue.get_connection", @@ -97,8 +93,7 @@ async def mock_connection(self, mock_connections, backend): @pytest.fixture async def mock_internal_connection(self, mock_connections, backend): - backend_name = backend.__name__.split(".")[-1] - 
yield mock_connections[backend_name] + yield mock_connections[backend] async def test_connection_success( self, mock_connection: MagicMock, mock_internal_connection, backend @@ -128,9 +123,8 @@ async def test_publish_tasks( conn = await mock_connection.connect() incoming = incoming_message_factory(backend) - backend_name = backend.__name__.split(".")[-1] topic = "test.tasks.key" - delimiter = backend_configs[backend_name].topic_delimiter + delimiter = backend_configs[backend].topic_delimiter topic = topic.replace(".", delimiter) await conn.publish(topic, incoming) @@ -147,8 +141,7 @@ async def test_publish_tasks( async def test_publish(self, mock_connection: MagicMock, mock_internal_connection, backend): topic = "test.routing.key" - backend_name = backend.__name__.split(".")[-1] - delimiter = backend_configs[backend_name].topic_delimiter + delimiter = backend_configs[backend].topic_delimiter topic = topic.replace(".", delimiter) conn = await mock_connection.connect() msg = None diff --git a/tests/test_integration.py b/tests/test_integration.py index 784a0a4..570ba6f 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,3 +1,5 @@ +import asyncio +import importlib import logging import typing as tp @@ -6,19 +8,10 @@ import pytest from pytest_mock import MockerFixture -import protobunny as pb_sync +import protobunny as pb_base from protobunny import asyncio as pb -from protobunny.asyncio.backends import mosquitto as mosquitto_backend_aio -from protobunny.asyncio.backends import nats as nats_backend_aio -from protobunny.asyncio.backends import python as python_backend_aio -from protobunny.asyncio.backends import rabbitmq as rabbitmq_backend_aio -from protobunny.asyncio.backends import redis as redis_backend_aio -from protobunny.backends import mosquitto as mosquitto_backend -from protobunny.backends import nats as nats_backend -from protobunny.backends import python as python_backend -from protobunny.backends import rabbitmq as rabbitmq_backend 
-from protobunny.backends import redis as redis_backend -from protobunny.config import Config, backend_configs +from protobunny import get_backend +from protobunny.conf import Config, backend_configs from protobunny.models import ProtoBunnyMessage from . import tests @@ -89,11 +82,11 @@ def log_callback(message: aio_pika.IncomingMessage, body: str) -> str: @pytest.mark.parametrize( "backend", [ - rabbitmq_backend_aio, - redis_backend_aio, - python_backend_aio, - mosquitto_backend_aio, - nats_backend_aio, + "rabbitmq", + "redis", + "python", + "mosquitto", + "nats", ], ) class TestIntegration: @@ -104,39 +97,40 @@ class TestIntegration: msg = tests.TestMessage(content="test", number=123, color=tests.Color.GREEN) - # @pytest.fixture(autouse=True) + @pytest.fixture(autouse=True) async def setup_test_env( - self, mocker: MockerFixture, test_config: Config, backend + self, mocker: MockerFixture, test_config: Config, backend: str ) -> tp.AsyncGenerator[None, None]: - backend_name = backend.__name__.split(".")[-1] test_config.mode = "async" - test_config.backend = backend_name + test_config.backend = backend test_config.log_task_in_redis = True - test_config.backend_config = backend_configs[backend_name] + test_config.backend_config = backend_configs[backend] self.topic_delimiter = test_config.backend_config.topic_delimiter - # Patch global configuration for all modules that use it - mocker.patch.object(pb_sync.config, "default_configuration", test_config) - mocker.patch.object(pb_sync.models, "default_configuration", test_config) - mocker.patch.object(pb_sync.helpers, "default_configuration", test_config) - mocker.patch.object(pb.backends, "default_configuration", test_config) - mocker.patch.object(pb, "default_configuration", test_config) - if hasattr(backend.connection, "default_configuration"): - mocker.patch.object(backend.connection, "default_configuration", test_config) - if hasattr(backend.queues, "default_configuration"): - mocker.patch.object(backend.queues, 
"default_configuration", test_config) - - pb.backend = backend - mocker.patch.object(pb, "get_backend", return_value=backend) + mocker.patch.object(pb_base.conf, "config", test_config) + mocker.patch.object(pb_base.models, "config", test_config) + mocker.patch.object(pb_base.helpers, "config", test_config) + mocker.patch.object(pb.backends, "config", test_config) + mocker.patch.object(pb, "config", test_config) + backend_module = get_backend() + assert backend_module.__name__.split(".")[-1] == backend + + if hasattr(backend_module.connection, "config"): + mocker.patch.object(backend_module.connection, "config", test_config) + if hasattr(backend_module.queues, "config"): + mocker.patch.object(backend_module.queues, "config", test_config) + + pb.backend = backend_module + mocker.patch.object(pb, "get_backend", return_value=backend_module) # Assert the patching is working for setting the backend connection = await pb.connect() - assert isinstance(connection, backend.connection.Connection) + assert isinstance(connection, backend_module.connection.Connection) queue = pb.get_queue(self.msg) assert queue.topic == "acme.tests.TestMessage".replace( ".", test_config.backend_config.topic_delimiter ) - assert isinstance(queue, backend.queues.AsyncQueue) - assert isinstance(await queue.get_connection(), backend.connection.Connection) + assert isinstance(queue, backend_module.queues.AsyncQueue) + assert isinstance(await queue.get_connection(), backend_module.connection.Connection) # start without pending subscriptions await pb.unsubscribe_all(if_unused=False, if_empty=False) yield @@ -149,9 +143,9 @@ async def setup_test_env( "task": None, } await pb.disconnect() - backend.connection.Connection.instance_by_vhost.clear() + backend_module.connection.Connection.instance_by_vhost.clear() - # @pytest.mark.flaky(max_runs=3) + @pytest.mark.flaky(max_runs=3) async def test_publish(self, backend) -> None: global received await pb.subscribe(self.msg.__class__, callback) @@ -164,7 +158,7 
@@ async def predicate() -> bool: assert received["message"].number == self.msg.number assert received["message"].content == "test" - # @pytest.mark.flaky(max_runs=3) + @pytest.mark.flaky(max_runs=3) async def test_to_dict(self, backend) -> None: global received await pb.subscribe(self.msg.__class__, callback) @@ -183,7 +177,7 @@ async def predicate() -> bool: ) == '{"content": "test", "number": 123, "detail": null, "options": null, "color": "GREEN"}' ) - await pb.subscribe(tests.tasks.TaskMessage, callback) + await pb.subscribe(tests.tasks.TaskMessage, callback_task) msg = tests.tasks.TaskMessage( content="test", bbox=[1, 2, 3, 4], @@ -191,10 +185,11 @@ async def predicate() -> bool: await pb.publish(msg) async def predicate() -> bool: - return received["message"] == msg + await asyncio.sleep(0) + return received["task"] == msg assert await async_wait(predicate, timeout=1, sleep=0.1) - assert received["message"].to_dict( + assert received["task"].to_dict( casing=betterproto.Casing.SNAKE, include_default_values=True ) == { "content": "test", @@ -202,8 +197,8 @@ async def predicate() -> bool: "weights": [], "options": None, } - # to_pydict uses enum names and don't stringyfies int64 - assert received["message"].to_pydict( + # to_pydict uses enum names and don't stringifies int64 + assert received["task"].to_pydict( casing=betterproto.Casing.SNAKE, include_default_values=True ) == { "content": "test", @@ -214,8 +209,7 @@ async def predicate() -> bool: @pytest.mark.flaky(max_runs=3) async def test_count_messages(self, backend) -> None: - backend_name = backend.__name__.split(".")[-1] - if backend_name == "mosquitto": + if backend == "mosquitto": pytest.skip("mosquitto backend doesn't support message counts") task_queue = await pb.subscribe(tests.tasks.TaskMessage, callback) msg = tests.tasks.TaskMessage(content="test", bbox=[1, 2, 3, 4]) @@ -229,7 +223,7 @@ async def predicate() -> bool: assert await async_wait( predicate - ), f"Messages were not in the queue: {await 
task_queue.get_message_count()}" + ), f"Message count should be 0: {await task_queue.get_message_count()}" # we unsubscribe so the published messages # won't be consumed and stay in the queue await pb.unsubscribe(tests.tasks.TaskMessage, if_unused=False, if_empty=False) @@ -456,9 +450,7 @@ async def predicate() -> bool: @pytest.mark.integration -@pytest.mark.parametrize( - "backend", [rabbitmq_backend, redis_backend, python_backend, mosquitto_backend, nats_backend] -) +@pytest.mark.parametrize("backend", ["rabbitmq", "redis", "python", "mosquitto", "nats"]) class TestIntegrationSync: """Integration tests (to run with the backend server up)""" @@ -466,44 +458,48 @@ class TestIntegrationSync: @pytest.fixture(autouse=True) def setup_test_env( - self, mocker: MockerFixture, test_config: Config, backend + self, mocker: MockerFixture, test_config: Config, backend: str ) -> tp.Generator[None, None, None]: - backend_name = backend.__name__.split(".")[-1] test_config.mode = "sync" - test_config.backend = backend_name + test_config.backend = backend test_config.log_task_in_redis = True - test_config.backend_config = backend_configs[backend_name] + test_config.backend_config = backend_configs[backend] self.topic_delimiter = test_config.backend_config.topic_delimiter # Patch global configuration for all modules that use it - mocker.patch.object(pb_sync.config, "default_configuration", test_config) - mocker.patch.object(pb_sync.models, "default_configuration", test_config) - mocker.patch.object(pb_sync.backends, "default_configuration", test_config) - mocker.patch.object(pb_sync.helpers, "default_configuration", test_config) - mocker.patch.object(pb.backends.redis.connection, "default_configuration", test_config) - if hasattr(backend.connection, "default_configuration"): - mocker.patch.object(backend.connection, "default_configuration", test_config) - if hasattr(backend.queues, "default_configuration"): - mocker.patch.object(backend.queues, "default_configuration", test_config) 
- - pb_sync.backend = backend - # mocker.patch("protobunny.helpers.get_backend", return_value=backend) - mocker.patch.object(pb_sync.helpers, "get_backend", return_value=backend) - # mocker.patch.object(pb_sync, "connect", backend.connection.connect) - # mocker.patch.object(pb_sync, "disconnect", backend.connection.disconnect) - mocker.patch.object(pb_sync, "get_backend", return_value=backend) + mocker.patch.object(pb_base.conf, "config", test_config) + mocker.patch.object(pb_base.models, "config", test_config) + mocker.patch.object(pb_base.backends, "config", test_config) + mocker.patch.object(pb.backends, "config", test_config) + mocker.patch.object(pb_base.helpers, "config", test_config) + + backend_module = get_backend() + assert backend_module.__name__.split(".")[-1] == backend + async_backend_module = importlib.import_module(f"protobunny.asyncio.backends.{backend}") + if hasattr(async_backend_module.connection, "config"): + # The sync connection is often implemented as a wrapper of the relative async module. 
Patch the config of the async module as well + mocker.patch.object(async_backend_module.connection, "config", test_config) + if hasattr(backend_module.connection, "config"): + mocker.patch.object(backend_module.connection, "config", test_config) + if hasattr(backend_module.queues, "config"): + mocker.patch.object(backend_module.queues, "config", test_config) + + mocker.patch.object(pb_base.helpers, "get_backend", return_value=backend_module) + mocker.patch.object(pb_base, "get_backend", return_value=backend_module) # Assert the patching is working for setting the backend - connection = pb_sync.connect() - assert isinstance(connection, backend.connection.Connection) - queue = pb_sync.get_queue(self.msg) + connection = pb_base.connect() + assert isinstance(connection, backend_module.connection.Connection) + queue = pb_base.get_queue(self.msg) assert queue.topic == "acme.tests.TestMessage".replace( ".", test_config.backend_config.topic_delimiter ) - assert isinstance(queue, backend.queues.SyncQueue) - assert isinstance(queue.get_connection(), backend.connection.Connection) + assert isinstance(queue, backend_module.queues.SyncQueue) + assert isinstance(queue.get_connection(), backend_module.connection.Connection) + # Setup # start without pending subscriptions - pb_sync.unsubscribe_all(if_unused=False, if_empty=False) + pb_base.unsubscribe_all(if_unused=False, if_empty=False) yield + # Teardown # reset the variables holding the messages received global received received = { @@ -513,22 +509,22 @@ def setup_test_env( "task": None, } connection.disconnect() - backend.connection.Connection.instance_by_vhost.clear() + backend_module.connection.Connection.instance_by_vhost.clear() @pytest.mark.flaky(max_runs=3) def test_publish(self, backend) -> None: global received - pb_sync.subscribe(tests.TestMessage, callback_sync) - pb_sync.publish(self.msg) + pb_base.subscribe(tests.TestMessage, callback_sync) + pb_base.publish(self.msg) assert sync_wait(lambda: received["message"] is 
not None) assert received["message"].number == self.msg.number @pytest.mark.flaky(max_runs=3) def test_to_dict(self, backend) -> None: global received - pb_sync.subscribe(tests.TestMessage, callback_sync) - pb_sync.subscribe(tests.tasks.TaskMessage, callback_task_sync) - pb_sync.publish(self.msg) + pb_base.subscribe(tests.TestMessage, callback_sync) + pb_base.subscribe(tests.tasks.TaskMessage, callback_task_sync) + pb_base.publish(self.msg) assert sync_wait(lambda: received["message"] == self.msg) assert received["message"].to_dict( casing=betterproto.Casing.SNAKE, include_default_values=True @@ -544,8 +540,8 @@ def test_to_dict(self, backend) -> None: content="test", bbox=[1, 2, 3, 4], ) - pb_sync.publish(msg) - assert sync_wait(lambda: received["task"] == msg) + pb_base.publish(msg) + assert sync_wait(lambda: received["task"] is not None) assert received["task"].to_dict( casing=betterproto.Casing.SNAKE, include_default_values=True ) == { @@ -566,32 +562,31 @@ def test_to_dict(self, backend) -> None: @pytest.mark.flaky(max_runs=3) def test_count_messages(self, backend) -> None: - backend_name = backend.__name__.split(".")[-1] - if backend_name == "mosquitto": + if backend == "mosquitto": pytest.skip("mosquitto backend doesn't support message counts") # Subscribe to a tasks topic (shared queue) - task_queue = pb_sync.subscribe(tests.tasks.TaskMessage, callback_task_sync) + task_queue = pb_base.subscribe(tests.tasks.TaskMessage, callback_task_sync) msg = tests.tasks.TaskMessage(content="test", bbox=[1, 2, 3, 4]) - connection = pb_sync.connect() + connection = pb_base.connect() # remove past messages connection.purge(task_queue.topic, reset_groups=True) # we unsubscribe so the published messages # won't be consumed and stay in the queue task_queue.unsubscribe() assert sync_wait(lambda: 0 == task_queue.get_consumer_count()) - pb_sync.publish(msg) - pb_sync.publish(msg) - pb_sync.publish(msg) + pb_base.publish(msg) + pb_base.publish(msg) + pb_base.publish(msg) # and we 
can count them assert sync_wait(lambda: 3 == task_queue.get_message_count()) @pytest.mark.flaky(max_runs=3) def test_logger_body(self, backend) -> None: - pb_sync.subscribe_logger(log_callback) + pb_base.subscribe_logger(log_callback) topic = "acme.tests.TestMessage".replace(".", self.topic_delimiter) topic_result = "acme.tests.TestMessage.result".replace(".", self.topic_delimiter) - pb_sync.publish(self.msg) + pb_base.publish(self.msg) assert sync_wait(lambda: isinstance(received["log"], str)) assert ( received["log"] @@ -599,7 +594,7 @@ def test_logger_body(self, backend) -> None: ) received["log"] = None result = self.msg.make_result() - pb_sync.publish_result(result) + pb_base.publish_result(result) assert sync_wait(lambda: isinstance(received["log"], str)) assert ( received["log"] @@ -609,7 +604,7 @@ def test_logger_body(self, backend) -> None: return_code=pb.results.ReturnCode.FAILURE, return_value={"test": "value"} ) received["log"] = None - pb_sync.publish_result(result) + pb_base.publish_result(result) assert sync_wait(lambda: isinstance(received["log"], str)) assert ( received["log"] @@ -627,8 +622,8 @@ def callback_log(message, body: str) -> None: nonlocal log_msg log_msg = log_callback(message, body) - pb_sync.subscribe_logger(callback_log) - pb_sync.publish(tests.TestMessage(number=63, content="test")) + pb_base.subscribe_logger(callback_log) + pb_base.publish(tests.TestMessage(number=63, content="test")) def predicate() -> bool: return log_msg is not None @@ -643,7 +638,7 @@ def predicate() -> bool: ), log_msg # Ensure that uint64/int64 values are not converted to strings in the LoggerQueue callbacks log_msg = None - pb_sync.publish( + pb_base.publish( tests.tasks.TaskMessage( content="test", bbox=[1, 2, 3, 4], weights=[1.0, 2.0, -100, -20] ) @@ -658,22 +653,22 @@ def predicate() -> bool: @pytest.mark.flaky(max_runs=3) def test_unsubscribe(self, backend) -> None: global received - pb_sync.subscribe(tests.TestMessage, callback_sync) - 
pb_sync.publish(self.msg) + pb_base.subscribe(tests.TestMessage, callback_sync) + pb_base.publish(self.msg) assert sync_wait(lambda: received["message"] is not None) assert received["message"] == self.msg received["message"] = None - pb_sync.unsubscribe(tests.TestMessage, if_unused=False, if_empty=False) - pb_sync.publish(self.msg) + pb_base.unsubscribe(tests.TestMessage, if_unused=False, if_empty=False) + pb_base.publish(self.msg) assert received["message"] is None # unsubscribe from a package-level topic - pb_sync.subscribe(tests, callback_sync) - pb_sync.publish(tests.TestMessage(number=63, content="test")) + pb_base.subscribe(tests, callback_sync) + pb_base.publish(tests.TestMessage(number=63, content="test")) assert sync_wait(lambda: received["message"] is not None) received["message"] = None - pb_sync.unsubscribe(tests, if_unused=False, if_empty=False) - pb_sync.publish(self.msg) + pb_base.unsubscribe(tests, if_unused=False, if_empty=False) + pb_base.publish(self.msg) assert received["message"] is None # subscribe/unsubscribe two callbacks for two topics @@ -683,15 +678,15 @@ def callback_2(m: "ProtoBunnyMessage") -> None: nonlocal received2 received2 = m - pb_sync.subscribe(tests.TestMessage, callback_sync) - pb_sync.subscribe(tests, callback_2) - pb_sync.publish(self.msg) # this will reach callback_2 as well + pb_base.subscribe(tests.TestMessage, callback_sync) + pb_base.subscribe(tests, callback_2) + pb_base.publish(self.msg) # this will reach callback_2 as well assert sync_wait(lambda: received["message"] and received2) assert received["message"] == received2 == self.msg - pb_sync.unsubscribe_all() + pb_base.unsubscribe_all() received["message"] = None received2 = None - pb_sync.publish(self.msg) + pb_base.publish(self.msg) assert received["message"] is None assert received2 is None @@ -708,18 +703,18 @@ def callback_results_2(m: pb.results.Result) -> None: nonlocal received_result received_result = m - pb_sync.unsubscribe_all() - 
pb_sync.subscribe(tests.TestMessage, callback_2) + pb_base.unsubscribe_all() + pb_base.subscribe(tests.TestMessage, callback_2) # subscribe to the result topic - pb_sync.subscribe_results(tests.TestMessage, callback_results_2) + pb_base.subscribe_results(tests.TestMessage, callback_results_2) msg = tests.TestMessage(number=63, content="test") - pb_sync.publish(msg) + pb_base.publish(msg) assert sync_wait(lambda: received_result is not None) assert received_result.source == msg assert received_result.return_code == pb.results.ReturnCode.FAILURE - pb_sync.unsubscribe_results(tests.TestMessage) + pb_base.unsubscribe_results(tests.TestMessage) received_result = None - pb_sync.publish(msg) + pb_base.publish(msg) assert received_result is None @pytest.mark.flaky(max_runs=3) @@ -740,25 +735,25 @@ def callback_results_2(m: pb.results.Result) -> None: nonlocal received_result received_result = m - pb_sync.unsubscribe_all() - q1 = pb_sync.subscribe(tests.TestMessage, callback_1) - q2 = pb_sync.subscribe(tests.tasks.TaskMessage, callback_2) + pb_base.unsubscribe_all() + q1 = pb_base.subscribe(tests.TestMessage, callback_1) + q2 = pb_base.subscribe(tests.tasks.TaskMessage, callback_2) assert q1.topic == "acme.tests.TestMessage".replace(".", self.topic_delimiter) assert q2.topic == "acme.tests.tasks.TaskMessage".replace(".", self.topic_delimiter) assert q1.subscription is not None assert q2.subscription is not None # subscribe to a result topic - pb_sync.subscribe_results(tests.TestMessage, callback_results_2) - pb_sync.publish(tests.TestMessage(number=2, content="test")) - pb_sync.publish(tests.tasks.TaskMessage(content="test", bbox=[1, 2, 3, 4])) + pb_base.subscribe_results(tests.TestMessage, callback_results_2) + pb_base.publish(tests.TestMessage(number=2, content="test")) + pb_base.publish(tests.tasks.TaskMessage(content="test", bbox=[1, 2, 3, 4])) assert sync_wait(lambda: received_message is not None) assert sync_wait(lambda: received_result is not None) assert 
received_result.source == tests.TestMessage(number=2, content="test") - pb_sync.unsubscribe_all() + pb_base.unsubscribe_all() received_result = None received_message = None - pb_sync.publish(tests.tasks.TaskMessage(content="test", bbox=[1, 2, 3, 4])) - pb_sync.publish(tests.TestMessage(number=2, content="test")) + pb_base.publish(tests.tasks.TaskMessage(content="test", bbox=[1, 2, 3, 4])) + pb_base.publish(tests.TestMessage(number=2, content="test")) assert received_message is None assert received_result is None diff --git a/tests/test_publish.py b/tests/test_publish.py index 8e5933d..d6d2e11 100644 --- a/tests/test_publish.py +++ b/tests/test_publish.py @@ -12,10 +12,10 @@ @pytest.fixture(autouse=True) def setup_config(mocker, test_config) -> None: test_config.mode = "sync" - mocker.patch.object(pb.config, "default_configuration", test_config) - mocker.patch.object(pb.models, "default_configuration", test_config) - mocker.patch.object(pb.backends, "default_configuration", test_config) - mocker.patch.object(pb.helpers, "default_configuration", test_config) + mocker.patch.object(pb.conf, "config", test_config) + mocker.patch.object(pb.models, "config", test_config) + mocker.patch.object(pb.backends, "config", test_config) + mocker.patch.object(pb.helpers, "config", test_config) def test_sync_send_message(mock_sync_rmq_connection: MagicMock) -> None: diff --git a/tests/test_queues.py b/tests/test_queues.py index df62412..f7ae5f5 100644 --- a/tests/test_queues.py +++ b/tests/test_queues.py @@ -20,7 +20,7 @@ from protobunny.backends import python as python_backend from protobunny.backends import rabbitmq as rabbitmq_backend from protobunny.backends import redis as redis_backend -from protobunny.config import backend_configs +from protobunny.conf import backend_configs from protobunny.helpers import ( get_queue, ) @@ -59,16 +59,16 @@ async def mock_connections( test_config.log_task_in_redis = True test_config.backend_config = backend_configs[backend_name] - 
mocker.patch.object(pb_base.config, "default_configuration", test_config) - mocker.patch.object(pb_base.models, "default_configuration", test_config) - mocker.patch.object(pb_base.backends, "default_configuration", test_config) - mocker.patch.object(pb_base.helpers, "default_configuration", test_config) - mocker.patch.object(pb.backends, "default_configuration", test_config) + mocker.patch.object(pb_base.conf, "config", test_config) + mocker.patch.object(pb_base.models, "config", test_config) + mocker.patch.object(pb_base.backends, "config", test_config) + mocker.patch.object(pb_base.helpers, "config", test_config) + mocker.patch.object(pb.backends, "config", test_config) - if hasattr(backend.connection, "default_configuration"): - mocker.patch.object(backend.connection, "default_configuration", test_config) - if hasattr(backend.queues, "default_configuration"): - mocker.patch.object(backend.queues, "default_configuration", test_config) + if hasattr(backend.connection, "config"): + mocker.patch.object(backend.connection, "config", test_config) + if hasattr(backend.queues, "config"): + mocker.patch.object(backend.queues, "config", test_config) pb.backend = backend mocker.patch("protobunny.helpers.get_backend", return_value=backend) @@ -263,16 +263,16 @@ def mock_connection( test_config.log_task_in_redis = True test_config.backend_config = backend_configs[backend_name] - mocker.patch.object(pb_base.config, "default_configuration", test_config) - mocker.patch.object(pb_base.models, "default_configuration", test_config) - mocker.patch.object(pb_base.backends, "default_configuration", test_config) - mocker.patch.object(pb_base.helpers, "default_configuration", test_config) - mocker.patch.object(pb.backends.redis.connection, "default_configuration", test_config) + mocker.patch.object(pb_base.conf, "config", test_config) + mocker.patch.object(pb_base.models, "config", test_config) + mocker.patch.object(pb_base.backends, "config", test_config) + 
mocker.patch.object(pb_base.helpers, "config", test_config) + mocker.patch.object(pb.backends.redis.connection, "config", test_config) - if hasattr(backend.connection, "default_configuration"): - mocker.patch.object(backend.connection, "default_configuration", test_config) - if hasattr(backend.queues, "default_configuration"): - mocker.patch.object(backend.queues, "default_configuration", test_config) + if hasattr(backend.connection, "config"): + mocker.patch.object(backend.connection, "config", test_config) + if hasattr(backend.queues, "config"): + mocker.patch.object(backend.queues, "config", test_config) pb_base.backend = backend mocker.patch("protobunny.helpers.get_backend", return_value=backend) diff --git a/tests/test_results.py b/tests/test_results.py index 551ba14..004ab2b 100644 --- a/tests/test_results.py +++ b/tests/test_results.py @@ -7,7 +7,7 @@ import protobunny as pb from protobunny.backends.rabbitmq.connection import Connection -from protobunny.config import backend_configs +from protobunny.conf import backend_configs from protobunny.models import ( deserialize_result_message, get_message_class_from_topic, @@ -24,12 +24,12 @@ def setup_connections(mocker: MockerFixture, mock_sync_rmq_connection, test_conf test_config.mode = "sync" test_config.backend = "rabbitmq" test_config.backend_config = backend_configs["rabbitmq"] - mocker.patch.object(pb.config, "default_configuration", test_config) - mocker.patch.object(pb.models, "default_configuration", test_config) - mocker.patch.object(pb.backends, "default_configuration", test_config) - mocker.patch.object(pb.helpers, "default_configuration", test_config) + mocker.patch.object(pb.conf, "config", test_config) + mocker.patch.object(pb.models, "config", test_config) + mocker.patch.object(pb.backends, "config", test_config) + mocker.patch.object(pb.helpers, "config", test_config) pb.backend = rabbitmq_backend - mocker.patch.object(pb.helpers.default_configuration, "backend", "rabbitmq") + 
mocker.patch.object(pb.helpers.config, "backend", "rabbitmq") queue = pb.get_queue(tests.TestMessage) assert isinstance(queue, rabbitmq_backend.queues.SyncQueue) assert isinstance(queue.get_connection(), Connection) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 5701c7e..583a9ab 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1,4 +1,5 @@ import asyncio +import importlib import logging import time import typing as tp @@ -7,15 +8,8 @@ import protobunny as pb_sync from protobunny import asyncio as pb -from protobunny.asyncio.backends import mosquitto as mosquitto_backend_aio -from protobunny.asyncio.backends import python as python_backend_aio -from protobunny.asyncio.backends import rabbitmq as rabbitmq_backend_aio -from protobunny.asyncio.backends import redis as redis_backend_aio -from protobunny.backends import mosquitto as mosquitto_backend -from protobunny.backends import python as python_backend -from protobunny.backends import rabbitmq as rabbitmq_backend -from protobunny.backends import redis as redis_backend -from protobunny.config import backend_configs +from protobunny import get_backend +from protobunny.conf import Config, backend_configs from protobunny.models import ProtoBunnyMessage from . 
import tests @@ -26,9 +20,7 @@ @pytest.mark.integration @pytest.mark.asyncio -@pytest.mark.parametrize( - "backend", [rabbitmq_backend_aio, redis_backend_aio, python_backend_aio, mosquitto_backend_aio] -) +@pytest.mark.parametrize("backend", ["rabbitmq", "redis", "python", "mosquitto", "nats"]) class TestTasks: msg = tests.tasks.TaskMessage(content="test", bbox=[1, 2, 3, 4]) received = { @@ -37,36 +29,38 @@ class TestTasks: } @pytest.fixture(autouse=True) - async def setup_test_env(self, mocker, test_config, backend) -> tp.AsyncGenerator[None, None]: - backend_name = backend.__name__.split(".")[-1] + async def setup_test_env( + self, mocker, test_config: Config, backend: str + ) -> tp.AsyncGenerator[None, None]: test_config.mode = "async" - test_config.backend = backend_name - test_config.backend_config = backend_configs[backend_name] + test_config.backend = backend + test_config.backend_config = backend_configs[backend] self.topic_delimiter = test_config.backend_config.topic_delimiter # Patch global configuration for all modules that use it - mocker.patch.object(pb_sync.config, "default_configuration", test_config) - mocker.patch.object(pb_sync.models, "default_configuration", test_config) - mocker.patch.object(pb_sync.helpers, "default_configuration", test_config) - mocker.patch.object(pb_sync.backends, "default_configuration", test_config) - mocker.patch.object(pb.backends, "default_configuration", test_config) - if hasattr(backend.connection, "default_configuration"): - mocker.patch.object(backend.connection, "default_configuration", test_config) - if hasattr(backend.queues, "default_configuration"): - mocker.patch.object(backend.queues, "default_configuration", test_config) - - pb.backend = backend - mocker.patch("protobunny.helpers.get_backend", return_value=backend) - mocker.patch.object(pb, "get_backend", return_value=backend) + mocker.patch.object(pb_sync.conf, "config", test_config) + mocker.patch.object(pb_sync.models, "config", test_config) + 
mocker.patch.object(pb_sync.helpers, "config", test_config) + mocker.patch.object(pb_sync.backends, "config", test_config) + mocker.patch.object(pb.backends, "config", test_config) + backend_module = get_backend() + if hasattr(backend_module.connection, "config"): + mocker.patch.object(backend_module.connection, "config", test_config) + if hasattr(backend_module.queues, "config"): + mocker.patch.object(backend_module.queues, "config", test_config) + + pb.backend = backend_module + mocker.patch("protobunny.helpers.get_backend", return_value=backend_module) + mocker.patch.object(pb, "get_backend", return_value=backend_module) # Assert the patching is working for setting the backend connection = await pb.connect() - assert isinstance(connection, backend.connection.Connection) + assert isinstance(connection, backend_module.connection.Connection) queue = pb.get_queue(self.msg) assert queue.topic == "acme.tests.tasks.TaskMessage".replace( ".", test_config.backend_config.topic_delimiter ) - assert isinstance(queue, backend.queues.AsyncQueue) - assert isinstance(await queue.get_connection(), backend.connection.Connection) + assert isinstance(queue, backend_module.queues.AsyncQueue) + assert isinstance(await queue.get_connection(), backend_module.connection.Connection) await queue.purge(reset_groups=True) # start without pending subscriptions await pb.unsubscribe_all(if_unused=False, if_empty=False) @@ -75,7 +69,7 @@ async def setup_test_env(self, mocker, test_config, backend) -> tp.AsyncGenerato yield await connection.disconnect() - backend.connection.Connection.instance_by_vhost.clear() + backend_module.connection.Connection.instance_by_vhost.clear() async def test_tasks(self, backend) -> None: async def predicate_1() -> bool: @@ -97,30 +91,28 @@ async def callback_task_2(msg: "ProtoBunnyMessage") -> None: await pb.subscribe(tests.tasks.TaskMessage, callback_task_1) await pb.subscribe(tests.tasks.TaskMessage, callback_task_2) await pb.publish(self.msg) - assert await 
async_wait(predicate_1) - - assert self.received.get("task_2") is None + assert await async_wait(predicate_1) or await async_wait(predicate_2) + assert self.received.get("task_2") is None or self.received.get("task_1") is None self.received["task_1"] = None + self.received["task_2"] = None + await pb.publish(self.msg) await pb.publish(self.msg) await pb.publish(self.msg) - assert await async_wait(predicate_1, timeout=2, sleep=0.1) - assert await async_wait(predicate_2) - assert self.received["task_1"] == self.msg - assert self.received["task_2"] == self.msg + assert await async_wait(predicate_1) or await async_wait(predicate_2) + assert self.received["task_1"] == self.msg or self.received["task_2"] == self.msg + + await pb.unsubscribe(tests.tasks.TaskMessage, if_unused=False, if_empty=False) self.received["task_1"] = None self.received["task_2"] = None await pb.publish(self.msg) await pb.publish(self.msg) - assert await async_wait(predicate_1) - assert await async_wait(predicate_2) - await pb.unsubscribe(tests.tasks.TaskMessage, if_unused=False, if_empty=False) + assert self.received["task_1"] is None + assert self.received["task_2"] is None @pytest.mark.integration -@pytest.mark.parametrize( - "backend", [rabbitmq_backend, redis_backend, python_backend, mosquitto_backend] -) +@pytest.mark.parametrize("backend", ["rabbitmq", "redis", "python", "mosquitto", "nats"]) class TestTasksSync: msg = tests.tasks.TaskMessage(content="test", bbox=[1, 2, 3, 4]) received = { @@ -129,49 +121,55 @@ class TestTasksSync: } @pytest.fixture(autouse=True) - def setup_test_env(self, mocker, test_config, backend) -> tp.Generator[None, None, None]: - backend_name = backend.__name__.split(".")[-1] + def setup_test_env( + self, mocker, test_config: Config, backend: str + ) -> tp.Generator[None, None, None]: test_config.mode = "sync" - test_config.backend = backend_name - test_config.backend_config = backend_configs[backend_name] + test_config.backend = backend + test_config.backend_config 
= backend_configs[backend] self.topic_delimiter = test_config.backend_config.topic_delimiter # Patch global configuration for all modules that use it - mocker.patch.object(pb_sync.config, "default_configuration", test_config) - mocker.patch.object(pb_sync.models, "default_configuration", test_config) - mocker.patch.object(pb_sync.backends, "default_configuration", test_config) - mocker.patch.object(pb_sync.helpers, "default_configuration", test_config) - mocker.patch.object(pb.backends.redis.connection, "default_configuration", test_config) - if hasattr(backend.connection, "default_configuration"): - mocker.patch.object(backend.connection, "default_configuration", test_config) - if hasattr(backend.queues, "default_configuration"): - mocker.patch.object(backend.queues, "default_configuration", test_config) - - pb_sync.backend = backend - mocker.patch("protobunny.backends.get_backend", return_value=backend) - mocker.patch("protobunny.helpers.get_backend", return_value=backend) - - # Assert the patching is working for setting the backend + mocker.patch.object(pb_sync.conf, "config", test_config) + mocker.patch.object(pb_sync.models, "config", test_config) + mocker.patch.object(pb_sync.backends, "config", test_config) + mocker.patch.object(pb_sync.helpers, "config", test_config) + mocker.patch.object(pb.backends, "config", test_config) + + backend_module = get_backend() + async_backend_module = importlib.import_module(f"protobunny.asyncio.backends.{backend}") + if hasattr(backend_module.connection, "config"): + mocker.patch.object(backend_module.connection, "config", test_config) + if hasattr(backend_module.queues, "config"): + mocker.patch.object(backend_module.queues, "config", test_config) + if hasattr(async_backend_module.connection, "config"): + # The sync connection is often implemented as a wrapper of the relative async module. 
+ # Patch the config of the async module as well + mocker.patch.object(async_backend_module.connection, "config", test_config) + + mocker.patch("protobunny.helpers.get_backend", return_value=backend_module) + + # Test the patching is working for setting the backend connection = pb_sync.connect() - assert isinstance(connection, backend.connection.Connection) + assert isinstance(connection, backend_module.connection.Connection) assert sync_wait(connection.is_connected) task_queue = pb_sync.get_queue(self.msg) assert task_queue.topic == "acme.tests.tasks.TaskMessage".replace( ".", test_config.backend_config.topic_delimiter ) - assert isinstance(task_queue, backend.queues.SyncQueue) - assert isinstance(task_queue.get_connection(), backend.connection.Connection) + assert isinstance(task_queue, backend_module.queues.SyncQueue) + assert isinstance(task_queue.get_connection(), backend_module.connection.Connection) task_queue.purge(reset_groups=True) # start without pending subscriptions pb_sync.unsubscribe_all(if_unused=False, if_empty=False) pb_sync.disconnect() # reset the variables holding the messages received self.received = {} - yield pb_sync.disconnect() - backend.connection.Connection.instance_by_vhost.clear() + backend_module.connection.Connection.instance_by_vhost.clear() + @pytest.mark.flaky(max_runs=3) def test_tasks(self, backend) -> None: def predicate_1() -> bool: return self.received.get("task_1") is not None @@ -191,22 +189,19 @@ def callback_task_2(msg: "ProtoBunnyMessage") -> None: pb_sync.subscribe(tests.tasks.TaskMessage, callback_task_2) pb_sync.publish(self.msg) - assert sync_wait(predicate_1) + assert sync_wait(predicate_1) or sync_wait(predicate_2) - assert self.received.get("task_2") is None + assert self.received.get("task_2") is None or self.received.get("task_1") is None self.received["task_1"] = None + self.received["task_2"] = None pb_sync.publish(self.msg) pb_sync.publish(self.msg) pb_sync.publish(self.msg) - assert sync_wait(predicate_1) - 
assert sync_wait(predicate_2) + assert sync_wait(predicate_1) or sync_wait(predicate_2) + assert sync_wait(predicate_2) or sync_wait(predicate_1) assert self.received["task_1"] == self.msg assert self.received["task_2"] == self.msg self.received["task_1"] = None self.received["task_2"] = None pb_sync.publish(self.msg) - pb_sync.publish(self.msg) - assert sync_wait(predicate_1) - assert sync_wait(predicate_2) - self.received["task_2"] = None - self.received["task_1"] = None + assert sync_wait(predicate_1) or sync_wait(predicate_2) diff --git a/tests/utils.py b/tests/utils.py index 3d130bd..88714d9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,5 @@ import asyncio import time -from types import ModuleType from unittest.mock import ANY, AsyncMock from aio_pika import Message @@ -43,16 +42,15 @@ async def tear_down(event_loop): def incoming_message_factory(backend, body: bytes = b"Hello"): - backend_name = backend.__name__.split(".")[-1] - if backend_name == "rabbitmq": + if backend == "rabbitmq": return Message(body=body) - elif backend_name == "redis": + elif backend == "redis": return Envelope(body=body, correlation_id="123") return Envelope(body=body) async def assert_backend_publish( - backend: ModuleType, + backend: str, internal_mock: dict | redis.Redis | AsyncMock, mock_connection: "BaseConnection", backend_msg, @@ -60,8 +58,7 @@ async def assert_backend_publish( count_in_queue: int = 1, shared_queue: bool = False, ): - backend_name = backend.__name__.split(".")[-1] - match backend_name: + match backend: case "rabbitmq": internal_mock["exchange"].publish.assert_awaited_with( backend_msg, routing_key=topic, mandatory=True, immediate=False @@ -127,15 +124,14 @@ async def predicate(): ) else: internal_mock["js"].publish.assert_awaited_once_with( - subject="protobunny.test.tasks.key", payload=b"Hello", headers=None + subject="TASKS.protobunny.test.tasks.key", payload=b"Hello", headers=None ) async def assert_backend_setup_queue( backend, internal_mock, 
topic: str, shared: bool, mock_connection ) -> None: - backend_name = backend.__name__.split(".")[-1] - match backend_name: + match backend: case "rabbitmq": internal_mock["channel"].declare_queue.assert_called_with( topic, exclusive=not shared, durable=True, auto_delete=False, arguments=ANY @@ -175,58 +171,17 @@ async def assert_backend_setup_queue( }, mock_connection.queues case "nats": internal_mock["js"].subscribe.assert_awaited_once_with( - subject="protobunny.mylib.tasks.TaskMessage", - durable="protobunny_protobunny_mylib_tasks_TaskMessage", + subject="TASKS.protobunny.mylib.tasks.TaskMessage", + queue="protobunny_mylib_tasks_TaskMessage", + durable="protobunny_mylib_tasks_TaskMessage", cb=ANY, manual_ack=True, stream="PROTOBUNNY_TASKS", ) - # if backend_name == "rabbitmq": - # internal_mock["channel"].declare_queue.assert_called_with( - # topic, exclusive=not shared, durable=True, auto_delete=False, arguments=ANY - # ) - # elif backend_name == "redis": - # if shared: - # streams = await internal_mock.xinfo_groups(f"protobunny:{topic}") - # assert len(streams) == 1 - # assert ( - # streams[0]["name"].decode() == "shared_group" - # ), f"Expected 'shared_group', got '{streams[0]['name']}'" - # else: - # assert mock_connection.queues[topic] == {"is_shared": False} - # elif backend_name == "nats": - # assert False # TODO - # - # elif backend_name == "python": - # if not shared: - # assert len(internal_mock._exclusive_queues.get(topic)) == 1 - # else: - # assert internal_mock._shared_queues.get(topic) - # elif backend_name == "mosquitto": - # if not shared: - # assert mock_connection.queues[topic] == { - # "is_shared": False, - # "group_name": "", - # "sub_key": f"$share/shared_group/test/{topic}", - # "tag": ANY, - # "topic": topic, - # "topic_key": f"test/{topic}", - # } - # else: - # assert list(mock_connection.queues.values())[0] == { - # "is_shared": True, - # "group_name": "shared_group", - # "sub_key": f"$share/shared_group/protobunny/{topic}", - # "tag": 
ANY, - # "topic": topic, - # "topic_key": f"protobunny/{topic}", - # }, mock_connection.queues - async def assert_backend_connection(backend, internal_mock): - backend_name = backend.__name__.split(".")[-1] - match backend_name: + match backend: case "rabbitmq": # Verify aio_pika calls internal_mock["connect"].assert_awaited_once() @@ -242,32 +197,16 @@ async def assert_backend_connection(backend, internal_mock): case "nats": import nats - nats.connect.assert_awaited_once_with( - "nats://localhost:4222/", connect_timeout=30.0, max_reconnect_attempts=3 - ) - # internal_mock["client"].connect.assert_awaited_once() + nats.connect.assert_awaited_once_with("nats://localhost:4222/") - # if backend_name == "rabbitmq": - # # Verify aio_pika calls - # internal_mock["connect"].assert_awaited_once() - # assert internal_mock["channel"].set_qos.called - # # Check if main and DLX exchanges were declared - # assert internal_mock["channel"].declare_exchange.call_count == 2 - # elif backend_name == "redis": - # assert await internal_mock.ping() - # elif backend_name == "mosquitto": - # internal_mock.__aenter__.assert_awaited_once() - # elif backend_name == "nats": - # internal_mock.connect.assert_awaited_once() - # assert True - - -def get_mocked_connection(backend, redis_client, mock_aio_pika, mocker, mock_mosquitto, mock_nats): - backend_name = backend.__name__.split(".")[-1] +def get_mocked_connection( + backend_module, redis_client, mock_aio_pika, mocker, mock_mosquitto, mock_nats +): + backend_name = backend_module.__name__.split(".")[-1] match backend_name: case "redis": - real_conn_with_fake_redis = backend.connection.Connection( + real_conn_with_fake_redis = backend_module.connection.Connection( url="redis://localhost:6379/0" ) assert ( @@ -284,11 +223,11 @@ def check_connected() -> bool: ) return real_conn_with_fake_redis case "nats": - real_conn_with_fake_nats = backend.connection.Connection() + real_conn_with_fake_nats = backend_module.connection.Connection() 
real_conn_with_fake_nats._connection = mock_nats["client"] return real_conn_with_fake_nats case "rabbitmq": - real_conn_with_fake_aio_pika = backend.connection.Connection( + real_conn_with_fake_aio_pika = backend_module.connection.Connection( url="amqp://guest:guest@localhost:5672/" ) real_conn_with_fake_aio_pika._connection = mock_aio_pika["connection"] @@ -298,48 +237,11 @@ def check_connected() -> bool: real_conn_with_fake_aio_pika._queue = mock_aio_pika["queue"] return real_conn_with_fake_aio_pika case "python": - python_conn = backend.connection.Connection() + python_conn = backend_module.connection.Connection() python_conn.is_connected_event.set() return python_conn case "mosquitto": - real_conn_with_fake_aiomqtt = backend.connection.Connection() + real_conn_with_fake_aiomqtt = backend_module.connection.Connection() real_conn_with_fake_aiomqtt._connection = mock_mosquitto return real_conn_with_fake_aiomqtt - - # if backend_name == "redis": - # real_conn_with_fake_redis = backend.connection.Connection(url="redis://localhost:6379/0") - # assert ( - # real_conn_with_fake_redis._exchange == "protobunny" - # ), real_conn_with_fake_redis._exchange - # real_conn_with_fake_redis._connection = redis_client - # - # def check_connected() -> bool: - # return real_conn_with_fake_redis._connection is not None - # - # # Patch is_connected logic - # mocker.patch.object(real_conn_with_fake_redis, "is_connected", side_effect=check_connected) - # - # return real_conn_with_fake_redis - # elif backend_name == "rabbitmq": - # real_conn_with_fake_aio_pika = backend.connection.Connection( - # url="amqp://guest:guest@localhost:5672/" - # ) - # real_conn_with_fake_aio_pika._connection = mock_aio_pika["connection"] - # real_conn_with_fake_aio_pika.is_connected_event.set() - # real_conn_with_fake_aio_pika._channel = mock_aio_pika["channel"] - # real_conn_with_fake_aio_pika._exchange = mock_aio_pika["exchange"] - # real_conn_with_fake_aio_pika._queue = mock_aio_pika["queue"] - # - # 
return real_conn_with_fake_aio_pika - # elif backend_name == "python": - # python_conn = backend.connection.Connection() - # python_conn.is_connected_event.set() - # return python_conn - # elif backend_name == "mosquitto": - # real_conn_with_fake_aiomqtt = backend.connection.Connection() - # real_conn_with_fake_aiomqtt._connection = mock_mosquitto - # return real_conn_with_fake_aiomqtt - # elif backend_name == "nats": - # real_conn_with_fake_nats = backend.connection.Connection() - # real_conn_with_fake_nats._connection = mock_nats["client"] - # return real_conn_with_fake_nats + return None From b59b5c5dc4923b1593c51a065ae4225e07c2ca3e Mon Sep 17 00:00:00 2001 From: Domenico Nappo Date: Wed, 31 Dec 2025 23:04:06 +0100 Subject: [PATCH 3/7] Debug CI --- .github/workflows/ci.yml | 50 ++++++++++++++++--- .../asyncio/backends/nats/connection.py | 2 +- 2 files changed, 43 insertions(+), 9 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ee05c42..6738185 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -239,28 +239,62 @@ jobs: integration_test_nats: runs-on: ubuntu-latest - services: - redis: - image: nats:latest - ports: - - 4222:4222 - options: >- - -js steps: - uses: actions/checkout@v6 + - name: Run NATS server + run: | + docker run -d -p 4222:4222 --name nats nats:2-alpine \ + nats --config /etc/nats/nats-server.conf -js + + - name: Log NATS container + run: | + sleep 3 + echo "=== Container logs ===" + docker logs nats + echo "=== Config file ===" + docker exec nats cat /etc/nats/nats-server.conf + echo "=== Processes ===" + docker exec nats ps aux + echo "=== Network ===" + docker exec nats netstat -tlnp || docker exec nats ss -tlnp + - name: Install uv uses: astral-sh/setup-uv@v7 with: enable-cache: true cache-dependency-glob: "uv.lock" + - name: Set up Python 3.12 run: uv python install 3.12 + - name: Create virtual environment run: uv venv --python 3.12 + - name: Install dependencies run: uv sync 
--all-extras + + - name: Smoke test NATS + run: | + # Wait for the port to be open + timeout 30s sh -c 'until nc -z localhost 4222; do sleep 1; done' + + uv run python -c " + import nats + import sys + import asyncio + + async def smoke_test(): + try: + await nats.connect() + print('NATS connected Successfully') + sys.exit(0) + except Exception as e: + print(f'Failed to connect: {e}') + sys.exit(1) + + asyncio.run(smoke_test()) + " - name: Run tests run: uv run python -m pytest tests/test_integration.py -k nats -vvv -s env: NATS_HOST: localhost - NATS_PORT: 6379 diff --git a/protobunny/asyncio/backends/nats/connection.py b/protobunny/asyncio/backends/nats/connection.py index 9d81a10..55baa4d 100644 --- a/protobunny/asyncio/backends/nats/connection.py +++ b/protobunny/asyncio/backends/nats/connection.py @@ -181,7 +181,7 @@ async def connect(self, **kwargs) -> "Connection": return self except asyncio.TimeoutError as e: - log.error("NATS connection timeout after %.1f seconds", timeout) + log.error("NATS connection timeout") self.is_connected_event.clear() self._connection = None raise ConnectionError(f"Failed to connect to NATS: {e}") from e From 62c65acfdc689b613dc8f14413e9425dd939019f Mon Sep 17 00:00:00 2001 From: Domenico Nappo Date: Wed, 31 Dec 2025 23:10:53 +0100 Subject: [PATCH 4/7] Debug CI --- .github/workflows/ci.yml | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6738185..51971b4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,8 +34,8 @@ - name: Install dependencies run: | if [ "${{ matrix.python-version }}" == "3.13" ]; then - # Sync everything EXCEPT the 'numpy' extra - uv sync --extra rabbitmq --extra redis --dev --extra mosquitto + # Sync everything EXCEPT the 'numpy' extra that currently has problems on github + uv sync --extra rabbitmq --extra redis --dev --extra mosquitto --extra nats else # Install everything for older versions uv sync
--all-extras @@ -243,8 +243,7 @@ - uses: actions/checkout@v6 - name: Run NATS server run: | - docker run -d -p 4222:4222 --name nats nats:2-alpine \ - nats --config /etc/nats/nats-server.conf -js + docker run -d -p 4222:4222 --name nats nats:2-alpine -js - name: Log NATS container run: | From 030c47286af5b37648154a433f77bec32fec59af Mon Sep 17 00:00:00 2001 From: Domenico Nappo Date: Wed, 31 Dec 2025 23:47:49 +0100 Subject: [PATCH 5/7] Minor edits --- protobunny/__init__.py | 27 ++++++++++++++++-- protobunny/__init__.py.j2 | 23 +++++++++++++------ protobunny/asyncio/__init__.py | 34 +++++++++++++++++------------ protobunny/asyncio/__init__.py.j2 | 29 +++++++++++++++----------- tests/test_integration.py | 26 +++++++++-------------- 5 files changed, 83 insertions(+), 56 deletions(-) diff --git a/protobunny/__init__.py b/protobunny/__init__.py index f040de1..503d943 100644 --- a/protobunny/__init__.py +++ b/protobunny/__init__.py @@ -70,10 +70,12 @@ ############################ -def connect() -> "BaseSyncConnection": - """Get the singleton async connection.""" +def connect(**kwargs: tp.Any) -> "BaseSyncConnection": + """Connect the backend and get the singleton async connection. + You can pass the specific keyword arguments that you would pass to the `connect` function of the configured backend.
+ """ connection_module = get_backend().connection - conn = connection_module.Connection.get_connection(vhost=connection_module.VHOST) + conn = connection_module.Connection.get_connection(vhost=connection_module.VHOST, **kwargs) return conn @@ -83,11 +85,12 @@ def disconnect() -> None: conn.disconnect() -def reset_connection() -> "BaseSyncConnection": +def reset_connection(**kwargs: tp.Any) -> "BaseSyncConnection": """Reset the singleton connection.""" - connection = connect() - connection.disconnect() - return connect() + connection_module = get_backend().connection + conn = connection_module.Connection.get_connection(vhost=connection_module.VHOST) + conn.disconnect() + return conn.connect(**kwargs) def publish(message: "ProtoBunnyMessage") -> None: @@ -137,14 +140,12 @@ def subscribe( queue = get_queue(pkg_or_msg) if queue.shared_queue: # It's a task. Handle multiple subscriptions - # queue = get_queue(pkg_or_msg) queue.subscribe(callback) registry.register_task(register_key, queue) else: # exclusive queue queue = registry.get_subscription(register_key) or queue queue.subscribe(callback) - # register subscription to unsubscribe later registry.register_subscription(register_key, queue) return queue @@ -229,6 +230,14 @@ def get_message_count( return count +def get_consumer_count( + msg_type: "ProtoBunnyMessage | type[ProtoBunnyMessage] | ModuleType", +) -> int | None: + q = get_queue(msg_type) + count = q.get_consumer_count() + return count + + def is_module_tasks(module_name: str) -> bool: return "tasks" in module_name.split(".") diff --git a/protobunny/__init__.py.j2 b/protobunny/__init__.py.j2 index d4bbd7e..1573132 100644 --- a/protobunny/__init__.py.j2 +++ b/protobunny/__init__.py.j2 @@ -69,8 +69,10 @@ log = logging.getLogger(PACKAGE_NAME) ############################ -def connect(**kwargs) -> "BaseSyncConnection": - """Get the singleton async connection.""" +def connect(**kwargs: tp.Any) -> "BaseSyncConnection": + """Connect the backend and get the singleton
async connection. + You can pass the specific keyword arguments that you would pass to the `connect` function of the configured backend. + """ connection_module = get_backend().connection conn = connection_module.Connection.get_connection(vhost=connection_module.VHOST, **kwargs) return conn @@ -82,11 +84,12 @@ def disconnect() -> None: conn.disconnect() -def reset_connection() -> "BaseSyncConnection": +def reset_connection(**kwargs: tp.Any) -> "BaseSyncConnection": """Reset the singleton connection.""" - connection = connect() - connection.disconnect() - return connect() + connection_module = get_backend().connection + conn = connection_module.Connection.get_connection(vhost=connection_module.VHOST) + conn.disconnect() + return conn.connect(**kwargs) def publish(message: "ProtoBunnyMessage") -> None: @@ -226,6 +229,14 @@ def get_message_count( return count +def get_consumer_count( + msg_type: "ProtoBunnyMessage | type[ProtoBunnyMessage] | ModuleType", +) -> int | None: + q = get_queue(msg_type) + count = q.get_consumer_count() + return count + + def is_module_tasks(module_name: str) -> bool: return "tasks" in module_name.split(".") diff --git a/protobunny/asyncio/__init__.py b/protobunny/asyncio/__init__.py index e420cef..9d4912c 100644 --- a/protobunny/asyncio/__init__.py +++ b/protobunny/asyncio/__init__.py @@ -86,10 +86,12 @@ ############################ -async def connect() -> "BaseAsyncConnection": +async def connect(**kwargs) -> "BaseAsyncConnection": """Get the singleton async connection.""" connection_module = get_backend().connection - conn = await connection_module.Connection.get_connection(vhost=connection_module.VHOST) + conn = await connection_module.Connection.get_connection( + vhost=connection_module.VHOST, **kwargs + ) return conn @@ -160,11 +162,11 @@ async def subscribe( await queue.subscribe(callback) registry.register_task(registry_key, queue) else: - # exclusive queue + # exclusive queue, cannot register more than one callback queue = 
registry.get_subscription(registry_key) or get_queue(pkg) - # queue already exists, but not subscribed yet (otherwise raise ValueError) - await queue.subscribe(callback) - registry.register_subscription(registry_key, queue) + if not queue.subscription: + await queue.subscribe(callback) + registry.register_subscription(registry_key, queue) return queue @@ -175,7 +177,6 @@ async def unsubscribe( ) -> None: """Remove a subscription for a message/package""" - # obj = type(pkg) if isinstance(pkg, betterproto.Message) else pkg module_name = pkg.__name__ if inspect.ismodule(pkg) else pkg.__module__ registry_key = registry.get_key(pkg) async with registry.lock: @@ -211,11 +212,10 @@ async def unsubscribe_all(if_unused: bool = True, if_empty: bool = True) -> None """ async with registry.lock: queues = itertools.chain( - registry.get_all_subscriptions(), - registry.get_all_tasks(flat=True), + registry.get_all_subscriptions(), registry.get_all_tasks(flat=True) ) for queue in queues: - await queue.unsubscribe(if_unused=False, if_empty=False) + await queue.unsubscribe(if_unused=if_unused, if_empty=if_empty) registry.unregister_all_subscriptions() registry.unregister_all_tasks() queues = registry.get_all_results() @@ -250,6 +250,14 @@ async def get_message_count( return count +async def get_consumer_count( + msg_type: "ProtoBunnyMessage | type[ProtoBunnyMessage] | ModuleType", +) -> int | None: + q = get_queue(msg_type) + count = await q.get_consumer_count() + return count + + def default_log_callback(message: "IncomingMessageProtocol", msg_content: str) -> None: """Default callback for the logging service""" log.info( @@ -288,15 +296,15 @@ async def shutdown(signum: int) -> None: stop_event.set() for sig in (signal.SIGINT, signal.SIGTERM): - # Note: add_signal_handler requires a callback, so we use a lambda + def _handler(s: int) -> asyncio.Task[None]: return asyncio.create_task(shutdown(s)) loop.add_signal_handler(sig, _handler, sig) - log.info("Started. 
Press Ctrl+C to exit.") - # Wait here forever (non-blocking) until shutdown() is called + log.info("Protobunny started") await main() + # Wait here forever (non-blocking) until shutdown() is called await stop_event.wait() diff --git a/protobunny/asyncio/__init__.py.j2 b/protobunny/asyncio/__init__.py.j2 index 53b87c3..cd25e02 100644 --- a/protobunny/asyncio/__init__.py.j2 +++ b/protobunny/asyncio/__init__.py.j2 @@ -80,10 +80,10 @@ log = logging.getLogger(PACKAGE_NAME) # -- Async top-level methods ############################ -async def connect() -> "BaseAsyncConnection": +async def connect(**kwargs) -> "BaseAsyncConnection": """Get the singleton async connection.""" connection_module = get_backend().connection - conn = await connection_module.Connection.get_connection(vhost=connection_module.VHOST) + conn = await connection_module.Connection.get_connection(vhost=connection_module.VHOST, **kwargs) return conn @@ -154,11 +154,11 @@ async def subscribe( await queue.subscribe(callback) registry.register_task(registry_key, queue) else: - # exclusive queue + # exclusive queue, cannot register more than one callback queue = registry.get_subscription(registry_key) or get_queue(pkg) - # queue already exists, but not subscribed yet (otherwise raise ValueError) - await queue.subscribe(callback) - registry.register_subscription(registry_key, queue) + if not queue.subscription: + await queue.subscribe(callback) + registry.register_subscription(registry_key, queue) return queue @@ -169,7 +169,6 @@ async def unsubscribe( ) -> None: """Remove a subscription for a message/package""" - # obj = type(pkg) if isinstance(pkg, betterproto.Message) else pkg module_name = pkg.__name__ if inspect.ismodule(pkg) else pkg.__module__ registry_key = registry.get_key(pkg) async with registry.lock: @@ -208,7 +207,7 @@ async def unsubscribe_all(if_unused: bool = True, if_empty: bool = True) -> None registry.get_all_subscriptions(), registry.get_all_tasks(flat=True) ) for queue in queues: - await 
queue.unsubscribe(if_unused=False, if_empty=False) + await queue.unsubscribe(if_unused=if_unused, if_empty=if_empty) registry.unregister_all_subscriptions() registry.unregister_all_tasks() queues = registry.get_all_results() @@ -243,6 +242,14 @@ async def get_message_count( return count +async def get_consumer_count( + msg_type: "ProtoBunnyMessage | type[ProtoBunnyMessage] | ModuleType", +) -> int | None: + q = get_queue(msg_type) + count = await q.get_consumer_count() + return count + + def default_log_callback(message: "IncomingMessageProtocol", msg_content: str) -> None: """Default callback for the logging service""" log.info( @@ -281,15 +288,13 @@ async def _run_forever(main: tp.Callable[..., tp.Awaitable[None]]) -> None: stop_event.set() for sig in (signal.SIGINT, signal.SIGTERM): - # Note: add_signal_handler requires a callback, so we use a lambda def _handler(s: int) -> asyncio.Task[None]: return asyncio.create_task(shutdown(s)) - loop.add_signal_handler(sig, _handler, sig) - log.info("Started. 
Press Ctrl+C to exit.") - # Wait here forever (non-blocking) until shutdown() is called + log.info("Protobunny started") await main() + # Wait here forever (non-blocking) until shutdown() is called await stop_event.wait() diff --git a/tests/test_integration.py b/tests/test_integration.py index 570ba6f..8724392 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -367,7 +367,7 @@ async def predicate() -> bool: assert await async_wait(predicate, timeout=2, sleep=0.1) assert received["message"] == received2 == self.msg - await pb.unsubscribe_all() + await pb.unsubscribe_all(if_empty=False, if_unused=False) received["message"] = None received2 = None await pb.publish(self.msg) @@ -422,7 +422,7 @@ async def callback_results(m: pb.results.Result) -> None: nonlocal received_result received_result = m - await pb.unsubscribe_all() + await pb.unsubscribe_all(if_unused=False, if_empty=False) q1 = await pb.subscribe(tests.TestMessage, callback_1) q2 = await pb.subscribe(tests.tasks.TaskMessage, callback_2) assert q1.topic == "acme.tests.TestMessage".replace(".", self.topic_delimiter) @@ -440,7 +440,7 @@ async def predicate() -> bool: assert await async_wait(predicate, timeout=2, sleep=0.2) assert received_result.source == tests.TestMessage(number=2, content="test") - await pb.unsubscribe_all() + await pb.unsubscribe_all(if_empty=False, if_unused=False) received_result = None received_message = None await pb.publish(tests.tasks.TaskMessage(content="test", bbox=[1, 2, 3, 4])) @@ -455,6 +455,9 @@ class TestIntegrationSync: """Integration tests (to run with the backend server up)""" msg = tests.TestMessage(content="test", number=123, color=tests.Color.GREEN) + expected_json = ( + '{"content": "test", "number": 123, "detail": null, "options": null, "color": "GREEN"}' + ) @pytest.fixture(autouse=True) def setup_test_env( @@ -533,7 +536,7 @@ def test_to_dict(self, backend) -> None: received["message"].to_json( casing=betterproto.Casing.SNAKE, 
include_default_values=True ) - == '{"content": "test", "number": 123, "detail": null, "options": null, "color": "GREEN"}' + == self.expected_json ) msg = tests.tasks.TaskMessage( @@ -588,28 +591,19 @@ def test_logger_body(self, backend) -> None: pb_base.publish(self.msg) assert sync_wait(lambda: isinstance(received["log"], str)) - assert ( - received["log"] - == f'{topic}: {{"content": "test", "number": 123, "detail": null, "options": null, "color": "GREEN"}}' - ) + assert received["log"] == f"{topic}: {self.expected_json}" received["log"] = None result = self.msg.make_result() pb_base.publish_result(result) assert sync_wait(lambda: isinstance(received["log"], str)) - assert ( - received["log"] - == f'{topic_result}: SUCCESS - {{"content": "test", "number": 123, "detail": null, "options": null, "color": "GREEN"}}' - ) + assert received["log"] == f"{topic_result}: SUCCESS - {self.expected_json}" result = self.msg.make_result( return_code=pb.results.ReturnCode.FAILURE, return_value={"test": "value"} ) received["log"] = None pb_base.publish_result(result) assert sync_wait(lambda: isinstance(received["log"], str)) - assert ( - received["log"] - == f'{topic_result}: FAILURE - error: [] - {{"content": "test", "number": 123, "detail": null, "options": null, "color": "GREEN"}}' - ) + assert received["log"] == f"{topic_result}: FAILURE - error: [] - {self.expected_json}" @pytest.mark.flaky(max_runs=3) def test_logger_int64(self, backend) -> None: From 36dccca917372cfbf8d145881935aa2767dca481 Mon Sep 17 00:00:00 2001 From: Domenico Nappo Date: Thu, 1 Jan 2026 01:35:15 +0100 Subject: [PATCH 6/7] Some progress on typing --- Makefile | 1 + QUICK_START.md | 6 ++-- README.md | 19 ++++++------ RECIPES.md | 19 ++++++++++++ docs/source/conf.py | 1 + docs/source/intro.md | 49 +++++++++++++++++++++---------- docs/source/quick_start.md | 11 ++++--- protobunny/__init__.py | 21 +++++-------- protobunny/__init__.py.j2 | 16 +++++----- protobunny/asyncio/__init__.py | 21 +++++-------- 
protobunny/asyncio/__init__.py.j2 | 16 +++++----- protobunny/models.py | 48 +++++++++++++++--------------- scripts/convert_md.py | 6 ---- tests/test_integration.py | 4 +-- tests/test_tasks.py | 9 +++--- 15 files changed, 137 insertions(+), 110 deletions(-) delete mode 100644 scripts/convert_md.py diff --git a/Makefile b/Makefile index 5726b1a..649fc61 100644 --- a/Makefile +++ b/Makefile @@ -103,6 +103,7 @@ test-py313: copy-md: cp ./README.md docs/source/intro.md cp ./QUICK_START.md docs/source/quick_start.md + cp ./RECIPES.md docs/source/recipes.md docs: copy-md uv run sphinx-build -b html docs/source docs/build/html diff --git a/QUICK_START.md b/QUICK_START.md index 07bb4ca..8b083b8 100644 --- a/QUICK_START.md +++ b/QUICK_START.md @@ -179,13 +179,15 @@ def worker1(task: mml.main.tasks.TaskMessage) -> None: def worker2(task: mml.main.tasks.TaskMessage) -> None: print("2- Working on:", task) - -pb.subscribe(mml.main.tasks.TaskMessage, worker1) +import mymessagelib as mml +pb.subscribe(mml.main.tasks.TaskMessage, worker1) pb.subscribe(mml.main.tasks.TaskMessage, worker2) pb.publish(mml.main.tasks.TaskMessage(content="test1")) pb.publish(mml.main.tasks.TaskMessage(content="test2")) pb.publish(mml.main.tasks.TaskMessage(content="test3")) +from protobunny.models import ProtoBunnyMessage +print(isinstance(mml.main.tasks.TaskMessage(), ProtoBunnyMessage)) ``` You can also introspect/manage an underlying shared queue: diff --git a/README.md b/README.md index c692323..7e3cadc 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ # Protobunny -```{warning} -The project is in early development. -``` +::: {warning} +**Note**: The project is in early development. +::: Protobunny is the open-source evolution of [AM-Flow](https://am-flow.com)'s internal messaging library.
While the original was purpose-built for RabbitMQ, this version has been completely re-engineered to provide a unified, @@ -26,11 +26,11 @@ Supported backends in the current version are: - Mosquitto - Python "backend" with Queue/asyncio.Queue for local in-processing testing -```{note} -Protobunny handles backend-specific logic internally to provide a consistent experience and a lean interface. -Direct access to the internal NATS or Redis clients is intentionally restricted. +::: {note} +**Note**: Protobunny handles backend-specific logic internally to provide a consistent experience and a lean interface. +Direct access to the internal NATS or Redis clients is intentionally restricted. If your project depends on specialized backend parameters not covered by our API, you may find the abstraction too restrictive. -``` +::: ## Minimal requirements @@ -63,8 +63,8 @@ While there are many messaging libraries for Python, Protobunny is built specifi * **Type-Safe by Design**: Built natively for `protobuf/betterproto`. * **Semantic Routing**: Zero-config infrastructure. Protobunny uses your Protobuf package structure to decide if a message should be broadcast (Pub/Sub) or queued (Producer/Consumer). -* **Backend Agnostic**: Write your logic once. Switch between Redis, RabbitMQ, Mosquitto, or Local Queues by changing a single variable in configuration. -* **Sync & Async**: Support for both modern `asyncio` and traditional synchronous workloads. +* **Backend Agnostic**: You can choose between RabbitMQ, Redis, NATS, and Mosquitto. Python for local testing. +* **Sync & Async**: Support for both `asyncio` and traditional synchronous workloads. * **Battle-Tested**: Derived from internal libraries used in production systems at AM-Flow. 
--- @@ -78,6 +78,7 @@ While there are many messaging libraries for Python, Protobunny is built specifi | **Pattern Routing** | ✅ Auto (`tasks` pkg) | ❌ Manual Config | ✅ Fixed | | **Framework Agnostic** | ✅ Yes | ⚠️ FastAPI-like focus | ❌ Heavyweight | +--- ## Usage diff --git a/RECIPES.md b/RECIPES.md index e69de29..2c04fe4 100644 --- a/RECIPES.md +++ b/RECIPES.md @@ -0,0 +1,19 @@ +# Recipes + + +## Subscribe to a queue + + +## Subscribe a task worker to a shared topic + + +## Publish + + +## Results workflow + + +## Requeuing + + + diff --git a/docs/source/conf.py b/docs/source/conf.py index 78b9db9..e287132 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -28,6 +28,7 @@ "sphinx.ext.viewcode", "myst_parser", ] +myst_enable_extensions = ["colon_fence"] templates_path = ["_templates"] diff --git a/docs/source/intro.md b/docs/source/intro.md index 59f8236..5dcc3b4 100644 --- a/docs/source/intro.md +++ b/docs/source/intro.md @@ -1,29 +1,47 @@ # Protobunny -```{warning} -Note: The project is in early development. -``` +::: {warning} +**Warning**: The project is in early development. +::: Protobunny is the open-source evolution of [AM-Flow](https://am-flow.com)'s internal messaging library. While the original was purpose-built for RabbitMQ, this version has been completely re-engineered to provide a unified, -type-safe interface for several message brokers, including Redis and MQTT. +type-safe interface for several message brokers, including Redis, NATS, and MQTT. -It simplifies messaging for asynchronous tasks by providing: +It simplifies messaging for asynchronous message handling by providing: -* A clean “message-first” API -* Python class generation from Protobuf messages using betterproto -* Connections facilities for backends +* A clean “message-first” API by using your protobuf definitions * Message publishing/subscribing with typed topics -* Support also “task-like” queues (shared/competing consumers) vs. 
broadcast subscriptions +* Supports "task-like” queues (shared/competing consumers) vs. broadcast subscriptions * Generate and consume `Result` messages (success/failure + optional return payload) * Transparent messages serialization/deserialization - * Support async and sync contexts -* Transparently serialize "JSON-like" payload fields (numpy-friendly) +* Transparently serialize/deserialize custom "JSON-like" payload fields (numpy-friendly) +* Support async and sync contexts + +Supported backends in the current version are: + +- RabbitMQ +- Redis +- NATS +- Mosquitto +- Python "backend" with Queue/asyncio.Queue for local in-processing testing + +::: {note} +**Note**: Protobunny handles backend-specific logic internally to provide a consistent experience and a lean interface. +Direct access to the internal NATS or Redis clients is intentionally restricted. +If your project depends on specialized backend parameters not covered by our API, you may find the abstraction too restrictive. +::: ## Minimal requirements -- Python >= 3.10, < 3.14 +- Python >= 3.10 <=3.13 +- Core Dependencies: betterproto 2.0.0b7, grpcio-tools>=1.62.0 +- Backend Drivers (Optional based on your usage): + - NATS: nats-py (Requires NATS Server v2.10+ for full JetStream support). + - Redis: redis (Requires Redis Server v6.2+ for Stream support). + - RabbitMQ: aio-pika + - Mosquitto: aiomqtt ## Project scope @@ -45,8 +63,8 @@ While there are many messaging libraries for Python, Protobunny is built specifi * **Type-Safe by Design**: Built natively for `protobuf/betterproto`. * **Semantic Routing**: Zero-config infrastructure. Protobunny uses your Protobuf package structure to decide if a message should be broadcast (Pub/Sub) or queued (Producer/Consumer). -* **Backend Agnostic**: Write your logic once. Switch between Redis, RabbitMQ, Mosquitto, or Local Queues by changing a single variable in configuration. 
-* **Sync & Async**: Support for both modern `asyncio` and traditional synchronous workloads. +* **Backend Agnostic**: You can choose between RabbitMQ, Redis, NATS, and Mosquitto. Python for local testing. +* **Sync & Async**: Support for both `asyncio` and traditional synchronous workloads. * **Battle-Tested**: Derived from internal libraries used in production systems at AM-Flow. --- @@ -60,6 +78,7 @@ While there are many messaging libraries for Python, Protobunny is built specifi | **Pattern Routing** | ✅ Auto (`tasks` pkg) | ❌ Manual Config | ✅ Fixed | | **Framework Agnostic** | ✅ Yes | ⚠️ FastAPI-like focus | ❌ Heavyweight | +--- ## Usage @@ -74,7 +93,7 @@ Documentation home page: [https://am-flow.github.io/protobunny/](https://am-flow - [x] **Semantic Patterns**: Automatic `tasks` package routing. - [x] **Arbistrary dictionary parsing**: Transparently parse JSON-like fields as dictionaries/lists by using protobunny JsonContent type. - [x] **Result workflow**: Subscribe to results topics and receive protobunny `Result` messages produced by your callbacks. -- [ ] **Cloud-Native**: NATS (Core & JetStream) integration. +- [x] **Cloud-Native**: NATS (Core & JetStream) integration. - [ ] **Cloud Providers**: AWS (SQS/SNS) and GCP Pub/Sub. - [ ] **More backends**: Kafka support. 
diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md index 5baaff6..aa2240f 100644 --- a/docs/source/quick_start.md +++ b/docs/source/quick_start.md @@ -33,7 +33,7 @@ messages-directory = "messages" messages-prefix = "acme" generated-package-name = "mymessagelib.codegen" mode = "async" # or "sync" -backend = "rabbitmq" # available backends are ['rabbitmq', 'redis', 'mosquitto', 'python'] +backend = "rabbitmq" # available backends are ['rabbitmq', 'redis', 'nats', 'mosquitto', 'python'] ``` ### Install the library with `uv`, `poetry` or `pip` @@ -294,7 +294,7 @@ if conn.is_connected(): conn.close() ``` -If you set the `generated-package-root` folder option, you might need to add the path to your `sys.path`. +If you set the `generated-package-root` folder option, you might need to add that path to your `sys.path`. You can do it conveniently by calling `config_lib` on top of your module, before importing the library: ```python @@ -429,10 +429,12 @@ class TestLibAsync: await pb_asyncio.subscribe(ml.tests.TestMessage, self.on_message) await pb_asyncio.subscribe_results(ml.tests.TestMessage, self.on_message_results) await pb_asyncio.subscribe(ml.main.MyMessage, self.on_message_mymessage) - + + # Send a simple message await pb_asyncio.publish(ml.main.MyMessage(content="test")) + # Send a message with arbitrary json content await pb_asyncio.publish(ml.tests.TestMessage(number=1, content="test", data={"test": 123})) - + # Send three messages to verify the load balancing between the workers await pb_asyncio.publish(ml.main.tasks.TaskMessage(content="test1")) await pb_asyncio.publish(ml.main.tasks.TaskMessage(content="test2")) await pb_asyncio.publish(ml.main.tasks.TaskMessage(content="test3")) @@ -441,6 +443,7 @@ class TestLibAsync: class TestLib: + # Sync version def on_message(self, message: ml.tests.TestMessage) -> None: log.info("Got: %s", message) result = message.make_result() diff --git a/protobunny/__init__.py b/protobunny/__init__.py index 
503d943..44d5827 100644 --- a/protobunny/__init__.py +++ b/protobunny/__init__.py @@ -53,12 +53,7 @@ if tp.TYPE_CHECKING: from .core.results import Result - from .models import ( - IncomingMessageProtocol, - LoggerCallback, - ProtoBunnyMessage, - SyncCallback, - ) + from .models import PBM, IncomingMessageProtocol, LoggerCallback, SyncCallback __version__ = version(PACKAGE_NAME) @@ -93,7 +88,7 @@ def reset_connection(**kwargs: tp.Any) -> "BaseSyncConnection": return conn.connect(**kwargs) -def publish(message: "ProtoBunnyMessage") -> None: +def publish(message: "PBM") -> None: """Synchronously publish a message to its corresponding queue. This method automatically determines the correct topic based on the @@ -122,7 +117,7 @@ def publish_result( def subscribe( - pkg_or_msg: "type[ProtoBunnyMessage] | ModuleType", + pkg_or_msg: "type[PBM] | ModuleType", callback: "SyncCallback", ) -> "BaseSyncQueue": """Subscribe a callback function to the topic. @@ -151,7 +146,7 @@ def subscribe( def subscribe_results( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", callback: "SyncCallback", ) -> "BaseSyncQueue": """Subscribe a callback function to the result topic. 
@@ -169,7 +164,7 @@ def subscribe_results( def unsubscribe( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", if_unused: bool = True, if_empty: bool = True, ) -> None: @@ -191,7 +186,7 @@ def unsubscribe( def unsubscribe_results( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", ) -> None: """Remove all in-process subscriptions for a message/package result topic""" with registry.sync_lock: @@ -223,7 +218,7 @@ def unsubscribe_all(if_unused: bool = True, if_empty: bool = True) -> None: def get_message_count( - msg_type: "ProtoBunnyMessage | type[ProtoBunnyMessage] | ModuleType", + msg_type: "PBM | type[PBM] | ModuleType", ) -> int | None: q = get_queue(msg_type) count = q.get_message_count() @@ -231,7 +226,7 @@ def get_message_count( def get_consumer_count( - msg_type: "ProtoBunnyMessage | type[ProtoBunnyMessage] | ModuleType", + msg_type: "PBM | type[PBM] | ModuleType", ) -> int | None: q = get_queue(msg_type) count = q.get_consumer_count() diff --git a/protobunny/__init__.py.j2 b/protobunny/__init__.py.j2 index 1573132..2654cda 100644 --- a/protobunny/__init__.py.j2 +++ b/protobunny/__init__.py.j2 @@ -56,7 +56,7 @@ from .helpers import get_backend, get_queue from .backends import LoggingSyncQueue, BaseSyncQueue, BaseSyncConnection if tp.TYPE_CHECKING: - from .models import LoggerCallback, ProtoBunnyMessage, SyncCallback, IncomingMessageProtocol + from .models import LoggerCallback, PBM, SyncCallback, IncomingMessageProtocol from .core.results import Result __version__ = version(PACKAGE_NAME) @@ -92,7 +92,7 @@ def reset_connection(**kwargs: tp.Any) -> "BaseSyncConnection": return conn.connect(**kwargs) -def publish(message: "ProtoBunnyMessage") -> None: +def publish(message: "PBM") -> None: """Synchronously publish a message to its corresponding queue. 
This method automatically determines the correct topic based on the @@ -121,7 +121,7 @@ def publish_result( def subscribe( - pkg_or_msg: "type[ProtoBunnyMessage] | ModuleType", + pkg_or_msg: "type[PBM] | ModuleType", callback: "SyncCallback", ) -> "BaseSyncQueue": """Subscribe a callback function to the topic. @@ -150,7 +150,7 @@ def subscribe( def subscribe_results( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", callback: "SyncCallback", ) -> "BaseSyncQueue": """Subscribe a callback function to the result topic. @@ -168,7 +168,7 @@ def subscribe_results( def unsubscribe( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", if_unused: bool = True, if_empty: bool = True, ) -> None: @@ -190,7 +190,7 @@ def unsubscribe( def unsubscribe_results( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", ) -> None: """Remove all in-process subscriptions for a message/package result topic""" with registry.sync_lock: @@ -222,7 +222,7 @@ def unsubscribe_all(if_unused: bool = True, if_empty: bool = True) -> None: def get_message_count( - msg_type: "ProtoBunnyMessage | type[ProtoBunnyMessage] | ModuleType", + msg_type: "PBM | type[PBM] | ModuleType", ) -> int | None: q = get_queue(msg_type) count = q.get_message_count() @@ -230,7 +230,7 @@ def get_message_count( def get_consumer_count( - msg_type: "ProtoBunnyMessage | type[ProtoBunnyMessage] | ModuleType", + msg_type: "PBM | type[PBM] | ModuleType", ) -> int | None: q = get_queue(msg_type) count = q.get_consumer_count() diff --git a/protobunny/asyncio/__init__.py b/protobunny/asyncio/__init__.py index 9d4912c..0d5d879 100644 --- a/protobunny/asyncio/__init__.py +++ b/protobunny/asyncio/__init__.py @@ -63,12 +63,7 @@ from types import ModuleType from ..core.results import Result - from ..models import ( - AsyncCallback, - IncomingMessageProtocol, - LoggerCallback, - ProtoBunnyMessage, - ) + from ..models import PBM, AsyncCallback, 
IncomingMessageProtocol, LoggerCallback from .. import config_lib as config_lib @@ -108,7 +103,7 @@ async def reset_connection() -> "BaseAsyncConnection": return await connect() -async def publish(message: "ProtoBunnyMessage") -> None: +async def publish(message: "PBM") -> None: """Asynchronously publish a message to its corresponding queue. Args: @@ -135,7 +130,7 @@ async def publish_result( async def subscribe( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", callback: "AsyncCallback", ) -> "BaseAsyncQueue": """ @@ -171,7 +166,7 @@ async def subscribe( async def unsubscribe( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", if_unused: bool = True, if_empty: bool = True, ) -> None: @@ -193,7 +188,7 @@ async def unsubscribe( async def unsubscribe_results( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", ) -> None: """Remove all in-process subscriptions for a message/package result topic""" async with registry.lock: @@ -225,7 +220,7 @@ async def unsubscribe_all(if_unused: bool = True, if_empty: bool = True) -> None async def subscribe_results( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", callback: "AsyncCallback", ) -> "BaseAsyncQueue": """Subscribe a callback function to the result topic. 
@@ -243,7 +238,7 @@ async def subscribe_results( async def get_message_count( - msg_type: "ProtoBunnyMessage | type[ProtoBunnyMessage] | ModuleType", + msg_type: "PBM | type[PBM] | ModuleType", ) -> int | None: q = get_queue(msg_type) count = await q.get_message_count() @@ -251,7 +246,7 @@ async def get_message_count( async def get_consumer_count( - msg_type: "ProtoBunnyMessage | type[ProtoBunnyMessage] | ModuleType", + msg_type: "PBM | type[PBM] | ModuleType", ) -> int | None: q = get_queue(msg_type) count = await q.get_consumer_count() diff --git a/protobunny/asyncio/__init__.py.j2 b/protobunny/asyncio/__init__.py.j2 index cd25e02..de61bf3 100644 --- a/protobunny/asyncio/__init__.py.j2 +++ b/protobunny/asyncio/__init__.py.j2 @@ -61,7 +61,7 @@ from ..conf import ( # noqa from ..exceptions import RequeueMessage, ConnectionError from ..registry import registry if tp.TYPE_CHECKING: - from ..models import LoggerCallback, ProtoBunnyMessage, AsyncCallback, IncomingMessageProtocol + from ..models import LoggerCallback, PBM, AsyncCallback, IncomingMessageProtocol from ..core.results import Result from types import ModuleType @@ -100,7 +100,7 @@ async def reset_connection() -> "BaseAsyncConnection": return await connect() -async def publish(message: "ProtoBunnyMessage") -> None: +async def publish(message: "PBM") -> None: """Asynchronously publish a message to its corresponding queue. 
Args: @@ -127,7 +127,7 @@ async def publish_result( async def subscribe( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", callback: "AsyncCallback", ) -> "BaseAsyncQueue": """ @@ -163,7 +163,7 @@ async def subscribe( async def unsubscribe( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", if_unused: bool = True, if_empty: bool = True, ) -> None: @@ -185,7 +185,7 @@ async def unsubscribe( async def unsubscribe_results( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", ) -> None: """Remove all in-process subscriptions for a message/package result topic""" async with registry.lock: @@ -217,7 +217,7 @@ async def unsubscribe_all(if_unused: bool = True, if_empty: bool = True) -> None async def subscribe_results( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", callback: "AsyncCallback", ) -> "BaseAsyncQueue": """Subscribe a callback function to the result topic. @@ -235,7 +235,7 @@ async def subscribe_results( async def get_message_count( - msg_type: "ProtoBunnyMessage | type[ProtoBunnyMessage] | ModuleType", + msg_type: "PBM | type[PBM] | ModuleType", ) -> int | None: q = get_queue(msg_type) count = await q.get_message_count() @@ -243,7 +243,7 @@ async def get_message_count( async def get_consumer_count( - msg_type: "ProtoBunnyMessage | type[ProtoBunnyMessage] | ModuleType", + msg_type: "PBM | type[PBM] | ModuleType", ) -> int | None: q = get_queue(msg_type) count = await q.get_consumer_count() diff --git a/protobunny/models.py b/protobunny/models.py index 0ba34cf..f245b5d 100644 --- a/protobunny/models.py +++ b/protobunny/models.py @@ -16,12 +16,6 @@ from .helpers import get_topic from .utils import ProtobunnyJsonEncoder -# - types -SyncCallback: tp.TypeAlias = tp.Callable[["ProtoBunnyMessage"], tp.Any] -AsyncCallback: tp.TypeAlias = tp.Callable[["ProtoBunnyMessage"], tp.Awaitable[tp.Any]] -ResultCallback: tp.TypeAlias = tp.Callable[["Result"], 
tp.Any] -LoggerCallback: tp.TypeAlias = tp.Callable[[tp.Any, str], tp.Any] - log = logging.getLogger(__name__) @@ -54,8 +48,9 @@ def validate_required_fields(self: "ProtoBunnyMessage") -> None: if missing: raise MissingRequiredFields(self, missing) - @functools.cached_property - def json_content_fields(self: "ProtoBunnyMessage") -> list[str]: + # @functools.cached_property + @property + def json_content_fields(self: "ProtoBunnyMessage") -> tp.Iterable[str]: """Returns: the list of fieldnames that are of type commons.JsonContent.""" return [ field_name @@ -74,12 +69,12 @@ def __bytes__(self: "ProtoBunnyMessage") -> bytes: betterproto.Message.dump(msg, stream) return stream.getvalue() - def from_dict(self: "ProtoBunnyMessage", value: dict) -> "ProtoBunnyMessage": + def from_dict(self: "ProtoBunnyMessage", value: dict) -> "PBM": json_fields = {field: value.pop(field, None) for field in self.json_content_fields} msg = betterproto.Message.from_dict(tp.cast(betterproto.Message, self), value) for field in json_fields: setattr(msg, field, json_fields[field]) - return msg + return tp.cast(PBM, msg) def to_dict( self: "ProtoBunnyMessage", @@ -109,7 +104,7 @@ def _to_dict_with_json_content(self, betterproto_func: tp.Callable[..., tp.Any]) return out_dict def to_pydict( - self: "ProtoBunnyMessage", + self: "PBM", casing: tp.Callable[[str, bool], str] = betterproto.Casing.CAMEL, include_default_values: bool = False, ) -> dict[str, tp.Any]: @@ -127,9 +122,7 @@ def to_pydict( out_dict = self._use_enum_names(casing, out_dict) return out_dict - def _use_enum_names( - self: "ProtoBunnyMessage", casing, out_dict: dict[str, tp.Any] - ) -> dict[str, tp.Any]: + def _use_enum_names(self: "PBM", casing, out_dict: dict[str, tp.Any]) -> dict[str, tp.Any]: """Used to reprocess betterproto.Message.to_pydict output to use names for Enum fields. Process only first level fields. 
@@ -181,7 +174,7 @@ def _process_enum_field(): return updated_out_enums def to_json( - self: "ProtoBunnyMessage", + self: "PBM", indent: None | int | str = None, include_default_values: bool = False, casing: tp.Callable[[str, bool], str] = betterproto.Casing.CAMEL, @@ -193,7 +186,7 @@ def to_json( cls=ProtobunnyJsonEncoder, ) - def parse(self: "ProtoBunnyMessage", data: bytes) -> "ProtoBunnyMessage": + def parse(self: "PBM", data: bytes) -> "PBM": # Override Message.parse() method # to support transparent deserialization of JsonContent fields json_content_fields = list(self.json_content_fields) @@ -209,12 +202,12 @@ def parse(self: "ProtoBunnyMessage", data: bytes) -> "ProtoBunnyMessage": return msg @property - def type_url(self: "ProtoBunnyMessage") -> str: + def type_url(self: "PBM") -> str: """Return the class fqn for this message.""" return f"{self.__class__.__module__}.{self.__class__.__name__}" @property - def source(self: "ProtoBunnyMessage") -> "ProtoBunnyMessage": + def source(self: "PBM") -> "PBM": """Return the source message from a Result The source message is stored as a protobuf.Any message, with its type info and serialized value. @@ -227,19 +220,19 @@ def source(self: "ProtoBunnyMessage") -> "ProtoBunnyMessage": return source_message @functools.cached_property - def topic(self: "ProtoBunnyMessage") -> str: + def topic(self: "PBM") -> str: """Build the topic name for the message.""" return get_topic(self) @functools.cached_property - def result_topic(self: "ProtoBunnyMessage") -> str: + def result_topic(self: "PBM") -> str: """ Build the result topic name for the message. 
""" return f"{get_topic(self)}{config.backend_config.topic_delimiter}result" def make_result( - self: "ProtoBunnyMessage", + self: "PBM", return_code: "ReturnCode | int | None" = None, error: str = "", return_value: dict[str, tp.Any] | None = None, @@ -289,7 +282,7 @@ class ProtoBunnyMessage(MessageMixin, betterproto.Message): class MissingRequiredFields(Exception): """Exception raised by MessageMixin.validate_required_fields when required fields are missing.""" - def __init__(self, msg: "ProtoBunnyMessage", missing_fields: list[str]) -> None: + def __init__(self, msg: "PBM", missing_fields: list[str]) -> None: self.missing_fields = missing_fields missing = ", ".join(missing_fields) super().__init__(f"Non optional fields for message {msg.topic} were not set: {missing}") @@ -328,9 +321,7 @@ def _deserialize_content(msg: "JsonContent") -> dict | None: return json.loads(msg.content.decode()) if msg.content else None -def _get_submodule( - package: ModuleType, paths: list[str] -) -> "type[ProtoBunnyMessage] | ModuleType | None": +def _get_submodule(package: ModuleType, paths: list[str]) -> "type[PBM] | ModuleType | None": """Import module/class from package Args: @@ -586,5 +577,12 @@ def get_body(message: "IncomingMessageProtocol") -> str: return str(body) +# - types +PBM = tp.TypeVar("PBM", bound=ProtoBunnyMessage) +SyncCallback: tp.TypeAlias = tp.Callable[[PBM], tp.Any] +AsyncCallback: tp.TypeAlias = tp.Callable[[PBM], tp.Coroutine[tp.Any, tp.Any, None] | tp.Any] +ResultCallback: tp.TypeAlias = tp.Callable[["Result"], tp.Any] +LoggerCallback: tp.TypeAlias = tp.Callable[[tp.Any, str], tp.Any] + from .core.commons import JsonContent from .core.results import Result, ReturnCode diff --git a/scripts/convert_md.py b/scripts/convert_md.py deleted file mode 100644 index f7bb5e6..0000000 --- a/scripts/convert_md.py +++ /dev/null @@ -1,6 +0,0 @@ -import pypandoc - - -if __name__ == "__main__": - pypandoc.convert_file("README.md", "rst", outputfile="docs/source/readme.rst") - 
pypandoc.convert_file("QUICK_START.md", "rst", outputfile="docs/source/quick_start_rst.rst") diff --git a/tests/test_integration.py b/tests/test_integration.py index 8724392..6614617 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -378,12 +378,12 @@ async def predicate() -> bool: async def test_unsubscribe_results(self, backend) -> None: received_result: pb.results.Result | None = None - def callback_test(_: tests.TestMessage) -> None: + async def callback_test(_: tests.TestMessage) -> None: # The receiver catches the error in callback and will send a Result.FAILURE message # to the result topic raise RuntimeError("error in callback") - def callback_results(m: pb.results.Result) -> None: + async def callback_results(m: pb.results.Result) -> None: nonlocal received_result received_result = m diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 583a9ab..61cc6eb 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -10,7 +10,6 @@ from protobunny import asyncio as pb from protobunny import get_backend from protobunny.conf import Config, backend_configs -from protobunny.models import ProtoBunnyMessage from . 
import tests from .utils import async_wait, sync_wait @@ -78,12 +77,12 @@ async def predicate_1() -> bool: async def predicate_2() -> bool: return self.received.get("task_2") is not None - async def callback_task_1(msg: "ProtoBunnyMessage") -> None: + async def callback_task_1(msg: tests.tasks.TaskMessage) -> None: log.debug("CALLBACK TASK 1 %s", msg) await asyncio.sleep(0.1) # simulate some work self.received["task_1"] = msg - async def callback_task_2(msg: "ProtoBunnyMessage") -> None: + async def callback_task_2(msg: tests.tasks.TaskMessage) -> None: log.debug("CALLBACK TASK 2 %s", msg) await asyncio.sleep(0.1) # simulate some work self.received["task_2"] = msg @@ -177,11 +176,11 @@ def predicate_1() -> bool: def predicate_2() -> bool: return self.received.get("task_2") is not None - def callback_task_1(msg: "ProtoBunnyMessage") -> None: + def callback_task_1(msg: tests.tasks.TaskMessage) -> None: time.sleep(0.1) self.received["task_1"] = msg - def callback_task_2(msg: "ProtoBunnyMessage") -> None: + def callback_task_2(msg: tests.tasks.TaskMessage) -> None: time.sleep(0.1) self.received["task_2"] = msg From e474ee8983030731b0270289a26093fc68c3ceaa Mon Sep 17 00:00:00 2001 From: Domenico Nappo Date: Thu, 1 Jan 2026 13:46:44 +0100 Subject: [PATCH 7/7] Minor edits --- .github/workflows/ci.yml | 1 - Makefile | 1 - README.md | 33 +++++----- RECIPES.md | 108 +++++++++++++++++++++++++++++++ docs/source/api.rst | 3 + docs/source/concepts.rst | 2 +- docs/source/index.rst | 1 + docs/source/intro.md | 21 +++--- docs/source/quick_start.md | 15 ++--- docs/source/recipes.md | 127 +++++++++++++++++++++++++++++++++++++ tests/test_integration.py | 5 +- tests/test_tasks.py | 11 +++- 12 files changed, 286 insertions(+), 42 deletions(-) create mode 100644 docs/source/recipes.md diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 51971b4..5aedf35 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -51,7 +51,6 @@ jobs: uses: 
codecov/codecov-action@v5 if: matrix.python-version == '3.12' with: - file: ./coverage.xml flags: unittests name: codecov-umbrella diff --git a/Makefile b/Makefile index 649fc61..7f8ed9f 100644 --- a/Makefile +++ b/Makefile @@ -101,7 +101,6 @@ test-py313: # Releasing .PHONY: docs clean build-package publish-test publish-pypi copy-md copy-md: - cp ./README.md docs/source/intro.md cp ./QUICK_START.md docs/source/quick_start.md cp ./RECIPES.md docs/source/recipes.md diff --git a/README.md b/README.md index 7e3cadc..2a28490 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ # Protobunny -::: {warning} -**Note**: The project is in early development. -::: +> [!WARNING] +> The project is in early development. + Protobunny is the open-source evolution of [AM-Flow](https://am-flow.com)'s internal messaging library. While the original was purpose-built for RabbitMQ, this version has been completely re-engineered to provide a unified, @@ -12,7 +12,7 @@ It simplifies messaging for asynchronous message handling by providing: * A clean “message-first” API by using your protobuf definitions * Message publishing/subscribing with typed topics -* Supports "task-like” queues (shared/competing consumers) vs. broadcast subscriptions +* Supports "task-like" queues (shared/competing consumers) vs. broadcast subscriptions * Generate and consume `Result` messages (success/failure + optional return payload) * Transparent messages serialization/deserialization * Transparently serialize/deserialize custom "JSON-like" payload fields (numpy-friendly) @@ -26,11 +26,11 @@ Supported backends in the current version are: - Mosquitto - Python "backend" with Queue/asyncio.Queue for local in-processing testing -::: {note} -**Note**: Protobunny handles backend-specific logic internally to provide a consistent experience and a lean interface. -Direct access to the internal NATS or Redis clients is intentionally restricted. 
-If your project depends on specialized backend parameters not covered by our API, you may find the abstraction too restrictive. -::: + +> [!NOTE] +> Protobunny handles backend-specific logic internally to provide a consistent experience and a lean interface. +> Direct access to the internal NATS or Redis clients is intentionally restricted. +> If your project depends on specialized backend parameters not covered by our API, you may find the abstraction too restrictive. ## Minimal requirements @@ -70,13 +70,13 @@ While there are many messaging libraries for Python, Protobunny is built specifi ### Feature Comparison with some existing libraries -| Feature | **Protobunny** | **FastStream** | **Celery** | -|:-----------------------|:-------------------------|:------------------------|:------------------------| -| **Multi-Backend** | ✅ Yes | ✅ Yes | ⚠️ (Tasks only) | -| **Typed Protobufs** | ✅ Native (Betterproto) | ⚠️ Manual/Pydantic | ❌ No | -| **Sync + Async** | ✅ Yes | ✅ Yes | ❌ Sync focus | -| **Pattern Routing** | ✅ Auto (`tasks` pkg) | ❌ Manual Config | ✅ Fixed | -| **Framework Agnostic** | ✅ Yes | ⚠️ FastAPI-like focus | ❌ Heavyweight | +| Feature | **Protobunny** | **FastStream** | **Celery** | +|:-----------------------|:-------------------------|:-------------------|:------------------------| +| **Multi-Backend** | ✅ Yes | ✅ Yes | ⚠️ (Tasks only) | +| **Typed Protobufs** | ✅ Native (Betterproto) | ⚠️ Manual/Pydantic | ❌ No | +| **Sync + Async** | ✅ Yes | ✅ Yes | ❌ Sync focus | +| **Pattern Routing** | ✅ Auto (`tasks` pkg) | ❌ Manual Config | ✅ Fixed | +| **Framework Agnostic** | ✅ Yes | ✅ Yes | ❌ Heavyweight | --- @@ -96,6 +96,7 @@ Documentation home page: [https://am-flow.github.io/protobunny/](https://am-flow - [x] **Cloud-Native**: NATS (Core & JetStream) integration. - [ ] **Cloud Providers**: AWS (SQS/SNS) and GCP Pub/Sub. - [ ] **More backends**: Kafka support. 
+- [ ] **gRPC** Direct Call support --- diff --git a/RECIPES.md b/RECIPES.md index 2c04fe4..c13109b 100644 --- a/RECIPES.md +++ b/RECIPES.md @@ -1,19 +1,127 @@ # Recipes +These examples are for sync context. +For async, import the asyncio module and the logic remains the same, just with `async/await`. + +```python +from protobunny import asyncio as pb +``` ## Subscribe to a queue +To subscribe to a specific message type, use the `subscribe` method. This creates an exclusive queue by default (one consumer per queue instance). + +```python +import protobunny as pb +import mymessagelib as mml + + +def on_message(message: mml.tests.TestMessage) -> None: + print("Received:", message.content) + + +# Subscribe to the message class +pb.subscribe(mml.tests.TestMessage, on_message) + +# Block and wait for messages +pb.run_forever() +``` + +For the async version, `run_forever` accepts your main async method as a coroutine, which will contain the `await pb.subscribe` calls. + +```python +from protobunny import asyncio as pb +import mymessagelib as mml + + +async def on_message(message: mml.tests.TestMessage) -> None: + print("Received:", message.content) + + +async def main(): + await pb.subscribe(mml.tests.TestMessage, on_message) + + +pb.run_forever(main) +``` + ## Subscribe a task worker to a shared topic +Protobunny treats any message defined within a `.tasks` package as a task. +Subscribing to these messages uses a shared queue, allowing multiple workers to balance the load (competing consumers). + +```python +import protobunny as pb +import mymessagelib.main.tasks as tasks + +def worker(task: tasks.TaskMessage) -> None: + print("Processing task:", task.content) + # Perform logic here... + +# Multiple instances of this script will share the load from the same queue +pb.subscribe(tasks.TaskMessage, worker) +pb.run_forever() +``` ## Publish +Publishing is straightforward. Protobunny automatically determines the correct topic and queue routing based on the message class. 
+ +```python +import protobunny as pb +import mymessagelib as mml + +# Create the message instance +msg = mml.tests.TestMessage(content="Hello World", number=42) + +# Publish it +pb.publish(msg) +``` ## Results workflow +The results workflow allows you to send and receive feedback for a specific message, using the built-in `Result` message type. + +### Publishing a Result +Inside a message handler, you can create and publish a result tied to the source message. + +```python +def on_message(message: mml.tests.TestMessage) -> None: + # ... process message ... + + # Create a result from the source message + result = message.make_result( + return_value={"status": "success", "processed_at": "12:00"} + ) + pb.publish_result(result) +``` + +### Subscribing to Results +To listen for results of a specific message type: + +```python +def on_result(res: pb.results.Result) -> None: + # Access the original message that triggered this result + print("Source message:", res.source) + print("Data:", res.return_value) + +pb.subscribe_results(mml.tests.TestMessage, on_result) +``` ## Requeuing +If message processing fails and you want the broker to requeue it for another attempt, raise the `RequeueMessage` exception. +```python +from protobunny import RequeueMessage +import mymessagelib as mml +def on_message(message: mml.tests.TestMessage) -> None: + try: + # Attempt processing + do_work(message) + except Exception: + # This tells the backend to put the message back in the queue + raise RequeueMessage("Service busy, retrying...") +``` diff --git a/docs/source/api.rst b/docs/source/api.rst index 835c75d..c8fe636 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -6,6 +6,9 @@ Core Package .. automodule:: protobunny :members: + .. 
automodule:: protobunny.asyncio + :members: + :no-index: Models ----------- diff --git a/docs/source/concepts.rst b/docs/source/concepts.rst index eb36186..17dcda8 100644 --- a/docs/source/concepts.rst +++ b/docs/source/concepts.rst @@ -57,7 +57,7 @@ A result typically contains: -------------- Task-style queues ------------------ +~~~~~~~~~~~~~~~~~ All messages that are under a ``tasks`` package are treated as shared queues. diff --git a/docs/source/index.rst b/docs/source/index.rst index bc12de3..75761f3 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -14,4 +14,5 @@ Protobunny intro quick_start concepts + recipes api diff --git a/docs/source/intro.md b/docs/source/intro.md index 5dcc3b4..3987c99 100644 --- a/docs/source/intro.md +++ b/docs/source/intro.md @@ -1,7 +1,7 @@ # Protobunny ::: {warning} -**Warning**: The project is in early development. +The project is in early development. ::: Protobunny is the open-source evolution of [AM-Flow](https://am-flow.com)'s internal messaging library. @@ -27,8 +27,8 @@ Supported backends in the current version are: - Python "backend" with Queue/asyncio.Queue for local in-processing testing ::: {note} -**Note**: Protobunny handles backend-specific logic internally to provide a consistent experience and a lean interface. -Direct access to the internal NATS or Redis clients is intentionally restricted. +Protobunny handles backend-specific logic internally to provide a consistent experience and a lean interface. +Direct access to the internal NATS or Redis clients is intentionally restricted. If your project depends on specialized backend parameters not covered by our API, you may find the abstraction too restrictive. 
::: @@ -70,13 +70,13 @@ While there are many messaging libraries for Python, Protobunny is built specifi ### Feature Comparison with some existing libraries -| Feature | **Protobunny** | **FastStream** | **Celery** | -|:-----------------------|:-------------------------|:------------------------|:------------------------| -| **Multi-Backend** | ✅ Yes | ✅ Yes | ⚠️ (Tasks only) | -| **Typed Protobufs** | ✅ Native (Betterproto) | ⚠️ Manual/Pydantic | ❌ No | -| **Sync + Async** | ✅ Yes | ✅ Yes | ❌ Sync focus | -| **Pattern Routing** | ✅ Auto (`tasks` pkg) | ❌ Manual Config | ✅ Fixed | -| **Framework Agnostic** | ✅ Yes | ⚠️ FastAPI-like focus | ❌ Heavyweight | +| Feature | **Protobunny** | **FastStream** | **Celery** | +|:-----------------------|:-------------------------|:-------------------|:------------------------| +| **Multi-Backend** | ✅ Yes | ✅ Yes | ⚠️ (Tasks only) | +| **Typed Protobufs** | ✅ Native (Betterproto) | ⚠️ Manual/Pydantic | ❌ No | +| **Sync + Async** | ✅ Yes | ✅ Yes | ❌ Sync focus | +| **Pattern Routing** | ✅ Auto (`tasks` pkg) | ❌ Manual Config | ✅ Fixed | +| **Framework Agnostic** | ✅ Yes | ✅ Yes | ❌ Heavyweight | --- @@ -96,6 +96,7 @@ Documentation home page: [https://am-flow.github.io/protobunny/](https://am-flow - [x] **Cloud-Native**: NATS (Core & JetStream) integration. - [ ] **Cloud Providers**: AWS (SQS/SNS) and GCP Pub/Sub. - [ ] **More backends**: Kafka support. 
+- [ ] **gRPC** Direct Call support --- diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md index aa2240f..8b083b8 100644 --- a/docs/source/quick_start.md +++ b/docs/source/quick_start.md @@ -33,7 +33,7 @@ messages-directory = "messages" messages-prefix = "acme" generated-package-name = "mymessagelib.codegen" mode = "async" # or "sync" -backend = "rabbitmq" # available backends are ['rabbitmq', 'redis', 'nats', 'mosquitto', 'python'] +backend = "rabbitmq" # available backends are ['rabbitmq', 'redis', 'mosquitto', 'python'] ``` ### Install the library with `uv`, `poetry` or `pip` @@ -179,13 +179,15 @@ def worker1(task: mml.main.tasks.TaskMessage) -> None: def worker2(task: mml.main.tasks.TaskMessage) -> None: print("2- Working on:", task) - -pb.subscribe(mml.main.tasks.TaskMessage, worker1) +import mymessagelib as mml +pb.subscribe(mml.main.tasks.TaskMessage, worker1) pb.subscribe(mml.main.tasks.TaskMessage, worker2) pb.publish(mml.main.tasks.TaskMessage(content="test1")) pb.publish(mml.main.tasks.TaskMessage(content="test2")) pb.publish(mml.main.tasks.TaskMessage(content="test3")) +from protobunny.models import ProtoBunnyMessage +print(isinstance(mml.main.tasks.TaskMessage(), ProtoBunnyMessage)) ``` You can also introspect/manage an underlying shared queue: @@ -429,12 +431,10 @@ class TestLibAsync: await pb_asyncio.subscribe(ml.tests.TestMessage, self.on_message) await pb_asyncio.subscribe_results(ml.tests.TestMessage, self.on_message_results) await pb_asyncio.subscribe(ml.main.MyMessage, self.on_message_mymessage) - - # Send a simple message + await pb_asyncio.publish(ml.main.MyMessage(content="test")) - # Send a message with arbitrary json content await pb_asyncio.publish(ml.tests.TestMessage(number=1, content="test", data={"test": 123})) - # Send three messages to verify the load balancing between the workers + await pb_asyncio.publish(ml.main.tasks.TaskMessage(content="test1")) await 
pb_asyncio.publish(ml.main.tasks.TaskMessage(content="test2")) await pb_asyncio.publish(ml.main.tasks.TaskMessage(content="test3")) @@ -443,7 +443,6 @@ class TestLib: - # Sync version def on_message(self, message: ml.tests.TestMessage) -> None: log.info("Got: %s", message) result = message.make_result() diff --git a/docs/source/recipes.md b/docs/source/recipes.md new file mode 100644 index 0000000..c13109b --- /dev/null +++ b/docs/source/recipes.md @@ -0,0 +1,127 @@ +# Recipes + +These examples are for sync context. +For async, import the asyncio module and the logic remains the same, just with `async/await`. + +```python +from protobunny import asyncio as pb +``` + +## Subscribe to a queue + +To subscribe to a specific message type, use the `subscribe` method. This creates an exclusive queue by default (one consumer per queue instance). + +```python +import protobunny as pb +import mymessagelib as mml + + +def on_message(message: mml.tests.TestMessage) -> None: + print("Received:", message.content) + + +# Subscribe to the message class +pb.subscribe(mml.tests.TestMessage, on_message) + +# Block and wait for messages +pb.run_forever() +``` + +For the async version, `run_forever` accepts your main async method as a coroutine, which will contain the `await pb.subscribe` calls. + +```python +from protobunny import asyncio as pb +import mymessagelib as mml + + +async def on_message(message: mml.tests.TestMessage) -> None: + print("Received:", message.content) + + +async def main(): + await pb.subscribe(mml.tests.TestMessage, on_message) + + +pb.run_forever(main) +``` + + +## Subscribe a task worker to a shared topic + +Protobunny treats any message defined within a `.tasks` package as a task. +Subscribing to these messages uses a shared queue, allowing multiple workers to balance the load (competing consumers). 
+ +```python +import protobunny as pb +import mymessagelib.main.tasks as tasks + +def worker(task: tasks.TaskMessage) -> None: + print("Processing task:", task.content) + # Perform logic here... + +# Multiple instances of this script will share the load from the same queue +pb.subscribe(tasks.TaskMessage, worker) +pb.run_forever() +``` + +## Publish + +Publishing is straightforward. Protobunny automatically determines the correct topic and queue routing based on the message class. + +```python +import protobunny as pb +import mymessagelib as mml + +# Create the message instance +msg = mml.tests.TestMessage(content="Hello World", number=42) + +# Publish it +pb.publish(msg) +``` + +## Results workflow + +The results workflow allows you to send and receive feedback for a specific message, using the built-in `Result` message type. + +### Publishing a Result +Inside a message handler, you can create and publish a result tied to the source message. + +```python +def on_message(message: mml.tests.TestMessage) -> None: + # ... process message ... + + # Create a result from the source message + result = message.make_result( + return_value={"status": "success", "processed_at": "12:00"} + ) + pb.publish_result(result) +``` + +### Subscribing to Results +To listen for results of a specific message type: + +```python +def on_result(res: pb.results.Result) -> None: + # Access the original message that triggered this result + print("Source message:", res.source) + print("Data:", res.return_value) + +pb.subscribe_results(mml.tests.TestMessage, on_result) +``` + +## Requeuing + +If message processing fails and you want the broker to requeue it for another attempt, raise the `RequeueMessage` exception. 
+ +```python +from protobunny import RequeueMessage +import mymessagelib as mml + +def on_message(message: mml.tests.TestMessage) -> None: + try: + # Attempt processing + do_work(message) + except Exception: + # This tells the backend to put the message back in the queue + raise RequeueMessage("Service busy, retrying...") +``` diff --git a/tests/test_integration.py b/tests/test_integration.py index 6614617..ebfca10 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -729,7 +729,6 @@ def callback_results_2(m: pb.results.Result) -> None: nonlocal received_result received_result = m - pb_base.unsubscribe_all() q1 = pb_base.subscribe(tests.TestMessage, callback_1) q2 = pb_base.subscribe(tests.tasks.TaskMessage, callback_2) assert q1.topic == "acme.tests.TestMessage".replace(".", self.topic_delimiter) @@ -740,8 +739,8 @@ def callback_results_2(m: pb.results.Result) -> None: pb_base.subscribe_results(tests.TestMessage, callback_results_2) pb_base.publish(tests.TestMessage(number=2, content="test")) pb_base.publish(tests.tasks.TaskMessage(content="test", bbox=[1, 2, 3, 4])) - assert sync_wait(lambda: received_message is not None) - assert sync_wait(lambda: received_result is not None) + assert sync_wait(lambda: received_message is not None, timeout=2) + assert sync_wait(lambda: received_result is not None, timeout=2) assert received_result.source == tests.TestMessage(number=2, content="test") pb_base.unsubscribe_all() diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 61cc6eb..5e4adcc 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -170,6 +170,11 @@ def setup_test_env( @pytest.mark.flaky(max_runs=3) def test_tasks(self, backend) -> None: + """ + Assert load balancing between worker callbacks + and that tasks callbacks don't receive duplicated messages + """ + def predicate_1() -> bool: return self.received.get("task_1") is not None @@ -196,8 +201,10 @@ def callback_task_2(msg: tests.tasks.TaskMessage) -> None: 
pb_sync.publish(self.msg) pb_sync.publish(self.msg) pb_sync.publish(self.msg) - assert sync_wait(predicate_1) or sync_wait(predicate_2) - assert sync_wait(predicate_2) or sync_wait(predicate_1) + pb_sync.publish(self.msg) + # this is a bit flaky because of the backend load balancing + assert sync_wait(predicate_1, timeout=2) or sync_wait(predicate_2, timeout=2) + assert sync_wait(predicate_2, timeout=2) or sync_wait(predicate_1, timeout=2) assert self.received["task_1"] == self.msg assert self.received["task_2"] == self.msg self.received["task_1"] = None