diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d21b5d0..5aedf35 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,8 +34,8 @@ jobs: - name: Install dependencies run: | if [ "${{ matrix.python-version }}" == "3.13" ]; then - # Sync everything EXCEPT the 'numpy' extra - uv sync --extra rabbitmq --extra redis --dev --extra mosquitto + # Sync everything EXCEPT the 'numpy' extra that currently has problems on github + uv sync --extra rabbitmq --extra redis --dev --extra mosquitto --extra nats else # Install everything for older versions uv sync --all-extras @@ -51,7 +51,6 @@ jobs: uses: codecov/codecov-action@v5 if: matrix.python-version == '3.12' with: - file: ./coverage.xml flags: unittests name: codecov-umbrella @@ -236,3 +235,64 @@ jobs: run: uv run python -m pytest tests/test_integration.py -k mosquitto -vvv -s env: MQTT_HOST: 127.0.0.1 + + integration_test_nats: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - name: Run NATS server + run: | + docker run -d -p 4222:4222 --name nats nats:2-alpine -js + + - name: Log NATS container + run: | + sleep 3 + echo "=== Container logs ===" + docker logs nats + echo "=== Config file ===" + docker exec nats cat /etc/nats/nats-server.conf + echo "=== Processes ===" + docker exec nats ps aux + echo "=== Network ===" + docker exec nats netstat -tlnp || docker exec nats ss -tlnp + + - name: Install uv + uses: astral-sh/setup-uv@v7 + with: + enable-cache: true + cache-dependency-glob: "uv.lock" + + - name: Set up Python 3.12 + run: uv python install 3.12 + + - name: Create virtual environment + run: uv venv --python 3.12 + + - name: Install dependencies + run: uv sync --all-extras + + - name: Smoke test NATS + run: | + # Wait for the port to be open + timeout 30s sh -c 'until nc -z localhost 4222; do sleep 1; done' + + uv run python -c " + import nats + import sys + import asyncio + + async def smoke_test(): + try: + nats.connect() + print('NATS connected Successfully') + sys.exit(0) + except Exception as e: + print(f'Failed to connect: {e}') + sys.exit(1) + + asyncio.run(smoke_test()) + " + - name: Run tests + run: uv run python -m pytest tests/test_integration.py -k nats -vvv -s + env: + NATS_HOST: localhost diff --git a/Makefile b/Makefile index 2ef6dde..7f8ed9f 100644 --- a/Makefile +++ b/Makefile @@ -36,7 +36,11 @@ t: PYTHONASYNCIODEBUG=1 PYTHONBREAKPOINT=ipdb.set_trace uv run pytest ${t} -s -vvvv --durations=0 integration-test: - uv run pytest tests/ -m "integration" ${t} + uv run pytest tests/ -m "integration" -k rabbitmq ${t} + uv run pytest tests/ -m "integration" -k redis ${t} + uv run pytest tests/ -m "integration" -k python ${t} + uv run pytest tests/ -m "integration" -k mosquitto ${t} + uv run pytest tests/ -m "integration" -k nats ${t} # run ./nats-server -js -sd nats_storage integration-test-py310: source .venv310/bin/activate @@ -97,8 +101,8 @@ test-py313: # Releasing .PHONY: docs clean build-package publish-test publish-pypi copy-md copy-md: - cp ./README.md docs/source/intro.md cp ./QUICK_START.md docs/source/quick_start.md + cp ./RECIPES.md docs/source/recipes.md docs: copy-md uv run sphinx-build -b html docs/source docs/build/html diff --git a/QUICK_START.md b/QUICK_START.md index b0397a9..8b083b8 100644 --- a/QUICK_START.md +++ b/QUICK_START.md @@ -21,7 +21,7 @@ You can also add it manually to pyproject.toml dependencies: ```toml dependencies = [ - "protobunny[rabbitmq, numpy]>=0.1.0", + "protobunny[rabbitmq, numpy]>=0.1.2a2", # your other dependencies ... ] ``` @@ -179,12 +179,15 @@ def worker1(task: mml.main.tasks.TaskMessage) -> None: def worker2(task: mml.main.tasks.TaskMessage) -> None: print("2- Working on:", task) - -pb.subscribe(mml.main.tasks.TaskMessage, worker1) +import mymessagelib as mml +pb.subscribe(mml.main.tasks.TasqkMessage, worker1) pb.subscribe(mml.main.tasks.TaskMessage, worker2) + pb.publish(mml.main.tasks.TaskMessage(content="test1")) pb.publish(mml.main.tasks.TaskMessage(content="test2")) pb.publish(mml.main.tasks.TaskMessage(content="test3")) +from protobunny.models import ProtoBunnyMessage +print(isinstance(mml.main.tasks.TaskMessage(), ProtoBunnyMessage)) ``` You can also introspect/manage an underlying shared queue: @@ -293,10 +296,14 @@ if conn.is_connected(): conn.close() ``` -If you set the `generated-package-root` folder option, you might need to add the path to your `sys.path`. +If you set the `generated-package-root` folder option, you might need to add that path to your `sys.path`. You can do it conveniently by calling `config_lib` on top of your module, before importing the library: + ```python -pb.config_lib() +import protobunny as pb +pb.config_lib() +# now you can import the library from the generated package root +import mymessagelib as mml ``` ## Complete example @@ -310,7 +317,7 @@ version = "0.1.0" description = "Project to test protobunny" requires-python = ">=3.10" dependencies = [ - "protobunny[rabbitmq,redis,numpy,mosquitto] >=0.1.2a1", + "protobunny[rabbitmq,redis,numpy,mosquitto]>=0.1.2a1", ] [tool.protobunny] @@ -374,7 +381,6 @@ import asyncio import logging import sys - import protobunny as pb from protobunny import asyncio as pb_asyncio @@ -386,12 +392,11 @@ pb.config_lib() import mymessagelib as ml - logging.basicConfig( level=logging.INFO, format="[%(asctime)s %(levelname)s] %(name)s - %(message)s" ) log = logging.getLogger(__name__) -conf = pb.default_configuration +conf = pb.config class TestLibAsync: @@ -413,7 +418,6 @@ class TestLibAsync: async def on_message_mymessage(self, message: ml.main.MyMessage) -> None: log.info("Got main message: %s", message) - def run_forever(self): asyncio.run(self.main()) @@ -421,7 +425,6 @@ class TestLibAsync: log.info(f"LOG {incoming.routing_key}: {body}") async def main(self): - await pb_asyncio.subscribe_logger(self.log_callback) await pb_asyncio.subscribe(ml.main.tasks.TaskMessage, self.worker1) await pb_asyncio.subscribe(ml.main.tasks.TaskMessage, self.worker2) diff --git a/README.md b/README.md index b901b62..2a28490 100644 --- a/README.md +++ b/README.md @@ -1,30 +1,48 @@ # Protobunny -```{warning} -Note: The project is in early development. -``` +> [!WARNING] +> The project is in early development. + Protobunny is the open-source evolution of [AM-Flow](https://am-flow.com)'s internal messaging library. While the original was purpose-built for RabbitMQ, this version has been completely re-engineered to provide a unified, -type-safe interface for any message broker, including Redis and MQTT. +type-safe interface for several message brokers, including Redis, NATS, and MQTT. -It simplifies messaging for asynchronous tasks by providing: +It simplifies messaging for asynchronous message handling by providing: -* A clean “message-first” API -* Python class generation from Protobuf messages using betterproto -* Connections facilities for backends +* A clean “message-first” API by using your protobuf definitions * Message publishing/subscribing with typed topics -* Support also “task-like” queues (shared/competing consumers) vs. broadcast subscriptions +* Supports "task-like" queues (shared/competing consumers) vs. broadcast subscriptions * Generate and consume `Result` messages (success/failure + optional return payload) * Transparent messages serialization/deserialization - * Support async and sync contexts -* Transparently serialize "JSON-like" payload fields (numpy-friendly) +* Transparently serialize/deserialize custom "JSON-like" payload fields (numpy-friendly) +* Support async and sync contexts + +Supported backends in the current version are: + +- RabbitMQ +- Redis +- NATS +- Mosquitto +- Python "backend" with Queue/asyncio.Queue for local in-processing testing + +> [!NOTE] +> Protobunny handles backend-specific logic internally to provide a consistent experience and a lean interface. +> Direct access to the internal NATS or Redis clients is intentionally restricted. +> If your project depends on specialized backend parameters not covered by our API, you may find the abstraction too restrictive. -## Requirements -- Python >= 3.10, < 3.14 -- Backend message broker (e.g. RabbitMQ) +## Minimal requirements + +- Python >= 3.10 <=3.13 +- Core Dependencies: betterproto 2.0.0b7, grpcio-tools>=1.62.0 +- Backend Drivers (Optional based on your usage): + - NATS: nats-py (Requires NATS Server v2.10+ for full JetStream support). + - Redis: redis (Requires Redis Server v6.2+ for Stream support). + - RabbitMQ: aio-pika + - Mosquitto: aiomqtt + ## Project scope @@ -39,40 +57,46 @@ Protobunny is designed for teams who use messaging to coordinate work between mi - Optional validation of required fields - Builtin logging service ---- +## Why Protobunny? -## Usage +While there are many messaging libraries for Python, Protobunny is built specifically for teams that treat **Protobuf as the single source of truth**. + +* **Type-Safe by Design**: Built natively for `protobuf/betterproto`. +* **Semantic Routing**: Zero-config infrastructure. Protobunny uses your Protobuf package structure to decide if a message should be broadcast (Pub/Sub) or queued (Producer/Consumer). +* **Backend Agnostic**: You can choose between RabbitMQ, Redis, NATS, and Mosquitto. Python for local testing. +* **Sync & Async**: Support for both `asyncio` and traditional synchronous workloads. +* **Battle-Tested**: Derived from internal libraries used in production systems at AM-Flow. +--- -See the [Quick example on GitHub](https://github.com/am-flow/protobunny/blob/main/QUICK_START.md) for installation and quick start guide. +### Feature Comparison with some existing libraries -Full docs are available at [https://am-flow.github.io/protobunny/](https://am-flow.github.io/protobunny/). +| Feature | **Protobunny** | **FastStream** | **Celery** | +|:-----------------------|:-------------------------|:-------------------|:------------------------| +| **Multi-Backend** | ✅ Yes | ✅ Yes | ⚠️ (Tasks only) | +| **Typed Protobufs** | ✅ Native (Betterproto) | ⚠️ Manual/Pydantic | ❌ No | +| **Sync + Async** | ✅ Yes | ✅ Yes | ❌ Sync focus | +| **Pattern Routing** | ✅ Auto (`tasks` pkg) | ❌ Manual Config | ✅ Fixed | +| **Framework Agnostic** | ✅ Yes | ✅ Yes | ❌ Heavyweight | --- -## Development +## Usage -### Run tests -```bash -make test -``` +See the [Quick example on GitHub](https://github.com/am-flow/protobunny/blob/main/QUICK_START.md) or on the [docs site](https://am-flow.github.io/protobunny/quickstart.html). -### Integration tests (RabbitMQ required) +Documentation home page: [https://am-flow.github.io/protobunny/](https://am-flow.github.io/protobunny/). -Integration tests expect RabbitMQ to be running (for example via Docker Compose in this repo): -```bash -docker compose up -d -make integration-test -``` --- - -### Future work - -- Support grcp -- Support for RabbitMQ certificates (through `pika`) -- More backends: - - NATS - - Kafka - - Cloud providers (AWS SQS/SNS) +### Roadmap + +- [x] **Core Support**: Redis, RabbitMQ, Mosquitto. +- [x] **Semantic Patterns**: Automatic `tasks` package routing. +- [x] **Arbistrary dictionary parsing**: Transparently parse JSON-like fields as dictionaries/lists by using protobunny JsonContent type. +- [x] **Result workflow**: Subscribe to results topics and receive protobunny `Result` messages produced by your callbacks. +- [x] **Cloud-Native**: NATS (Core & JetStream) integration. +- [ ] **Cloud Providers**: AWS (SQS/SNS) and GCP Pub/Sub. +- [ ] **More backends**: Kafka support. +- [ ] **gRPC** Direct Call support --- diff --git a/RECIPES.md b/RECIPES.md index e69de29..c13109b 100644 --- a/RECIPES.md +++ b/RECIPES.md @@ -0,0 +1,127 @@ +# Recipes + +These examples are for sync context. +For async, imports the asyncio module and the logic remains the same, just with `async/await`. + +```python +from protobunny import asyncio as pb +``` + +## Subscribe to a queue + +To subscribe to a specific message type, use the `subscribe` method. This creates an exclusive queue by default (one consumer per queue instance). + +```python +import protobunny as pb +import mymessagelib as mml + + +def on_message(message: mml.tests.TestMessage) -> None: + print("Received:", message.content) + + +# Subscribe to the message class +pb.subscribe(mml.tests.TestMessage, on_message) + +# Block and wait for messages +pb.run_forever() +``` + +For the async version, `run_forever` accepts your main async method as coroutine, that will contain the `await pb.subscribe` calls. + +```python +from protobunny import asyncio as pb +import mymessagelib as mml + + +async def on_message(message: mml.tests.TestMessage) -> None: + print("Received:", message.content) + + +async def main(): + await pb.subscribe(mml.tests.TestMessage, on_message) + + +pb.run_forever(main) +``` + + +## Subscribe a task worker to a shared topic + +Protobunny treats any message defined within a `.tasks` package as a task. +Subscribing to these messages uses a shared queue, allowing multiple workers to balance the load (competing consumers). + +```python +import protobunny as pb +import mymessagelib.main.tasks as tasks + +def worker(task: tasks.TaskMessage) -> None: + print("Processing task:", task.content) + # Perform logic here... + +# Multiple instances of this script will share the load from the same queue +pb.subscribe(tasks.TaskMessage, worker) +pb.run_forever() +``` + +## Publish + +Publishing is straightforward. Protobunny automatically determines the correct topic and queue routing based on the message class. + +```python +import protobunny as pb +import mymessagelib as mml + +# Create the message instance +msg = mml.tests.TestMessage(content="Hello World", number=42) + +# Publish it +pb.publish(msg) +``` + +## Results workflow + +The results workflow allows you to send and receive feedback for a specific message, using the built-in `Result` message type. + +### Publishing a Result +Inside a message handler, you can create and publish a result tied to the source message. + +```python +def on_message(message: mml.tests.TestMessage) -> None: + # ... process message ... + + # Create a result from the source message + result = message.make_result( + return_value={"status": "success", "processed_at": "12:00"} + ) + pb.publish_result(result) +``` + +### Subscribing to Results +To listen for results of a specific message type: + +```python +def on_result(res: pb.results.Result) -> None: + # Access the original message that triggered this result + print("Source message:", res.source) + print("Data:", res.return_value) + +pb.subscribe_results(mml.tests.TestMessage, on_result) +``` + +## Requeuing + +If message processing fails and you want the broker to requeue it for another attempt, raise the `RequeueMessage` exception. + +```python +from protobunny import RequeueMessage +import mymessagelib as mml + +def on_message(message: mml.tests.TestMessage) -> None: + try: + # Attempt processing + do_work(message) + except Exception: + # This tells the backend to put the message back in the queue + raise RequeueMessage("Service busy, retrying...") +``` diff --git a/docs/source/api.rst b/docs/source/api.rst index 835c75d..c8fe636 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -6,6 +6,9 @@ Core Package .. automodule:: protobunny :members: + .. automodule:: protobunny.asyncio + :members: + :no-index: Models ----------- diff --git a/docs/source/concepts.rst b/docs/source/concepts.rst index eb36186..17dcda8 100644 --- a/docs/source/concepts.rst +++ b/docs/source/concepts.rst @@ -57,7 +57,7 @@ A result typically contains: -------------- Task-style queues ------------------ +~~~~~~~~~~~~~~~~~ All messages that are under a ``tasks`` package are treated as shared queues. diff --git a/docs/source/conf.py b/docs/source/conf.py index 78b9db9..e287132 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -28,6 +28,7 @@ "sphinx.ext.viewcode", "myst_parser", ] +myst_enable_extensions = ["colon_fence"] templates_path = ["_templates"] diff --git a/docs/source/index.rst b/docs/source/index.rst index 6655250..75761f3 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,57 +1,7 @@ Protobunny ========== - -.. warning:: The project is in early development. - -Protobunny is the open-source evolution of -`AM-Flow `__\ ’s internal messaging library. While -the original was purpose-built for RabbitMQ, this version has been -completely re-engineered to provide a unified, type-safe interface for -any message broker, including Redis and MQTT. - -It simplifies messaging for asynchronous tasks by providing: - -- A clean “message-first” API -- Python class generation from Protobuf messages using betterproto -- Connections facilities for backends -- Message publishing/subscribing with typed topics -- Support also “task-like” queues (shared/competing consumers) - vs. broadcast subscriptions -- Generate and consume ``Result`` messages (success/failure + optional - return payload) -- Transparent messages serialization/deserialization -- Support async and sync contexts -- Transparently serialize “JSON-like” payload fields (numpy-friendly) - -Requirements ------------- - -- Python >= 3.10, < 3.14 -- Backend message broker (e.g. RabbitMQ) - -Project scope -------------- - -Protobunny is designed for teams who use messaging to coordinate work -between microservices or different python processes and want: - -- A small API surface, easy to learn and use, both async and sync -- Typed messaging with protobuf messages as payloads -- Supports various backends by simple configuration: RabbitMQ, Redis, - Mosquitto, local in-process queues -- Consistent topic naming and routing -- Builtin task queue semantics and result messages -- Transparent handling of JSON-like payload fields as plain - dictionaries/lists -- Optional validation of required fields -- Builtin logging service - --------------- - -Usage ------ - -See the `Quick start guide `__. +.. include:: intro.md + :parser: myst_parser.sphinx_ -------------- @@ -64,4 +14,5 @@ See the `Quick start guide `__. intro quick_start concepts + recipes api diff --git a/docs/source/intro.md b/docs/source/intro.md index b901b62..3987c99 100644 --- a/docs/source/intro.md +++ b/docs/source/intro.md @@ -1,30 +1,48 @@ # Protobunny -```{warning} -Note: The project is in early development. -``` +::: {warning} +The project is in early development. +::: Protobunny is the open-source evolution of [AM-Flow](https://am-flow.com)'s internal messaging library. While the original was purpose-built for RabbitMQ, this version has been completely re-engineered to provide a unified, -type-safe interface for any message broker, including Redis and MQTT. +type-safe interface for several message brokers, including Redis, NATS, and MQTT. -It simplifies messaging for asynchronous tasks by providing: +It simplifies messaging for asynchronous message handling by providing: -* A clean “message-first” API -* Python class generation from Protobuf messages using betterproto -* Connections facilities for backends +* A clean “message-first” API by using your protobuf definitions * Message publishing/subscribing with typed topics -* Support also “task-like” queues (shared/competing consumers) vs. broadcast subscriptions +* Supports "task-like” queues (shared/competing consumers) vs. broadcast subscriptions * Generate and consume `Result` messages (success/failure + optional return payload) * Transparent messages serialization/deserialization - * Support async and sync contexts -* Transparently serialize "JSON-like" payload fields (numpy-friendly) +* Transparently serialize/deserialize custom "JSON-like" payload fields (numpy-friendly) +* Support async and sync contexts +Supported backends in the current version are: -## Requirements +- RabbitMQ +- Redis +- NATS +- Mosquitto +- Python "backend" with Queue/asyncio.Queue for local in-processing testing + +::: {note} +Protobunny handles backend-specific logic internally to provide a consistent experience and a lean interface. +Direct access to the internal NATS or Redis clients is intentionally restricted. +If your project depends on specialized backend parameters not covered by our API, you may find the abstraction too restrictive. +::: + + +## Minimal requirements + +- Python >= 3.10 <=3.13 +- Core Dependencies: betterproto 2.0.0b7, grpcio-tools>=1.62.0 +- Backend Drivers (Optional based on your usage): + - NATS: nats-py (Requires NATS Server v2.10+ for full JetStream support). + - Redis: redis (Requires Redis Server v6.2+ for Stream support). + - RabbitMQ: aio-pika + - Mosquitto: aiomqtt -- Python >= 3.10, < 3.14 -- Backend message broker (e.g. RabbitMQ) ## Project scope @@ -39,40 +57,46 @@ Protobunny is designed for teams who use messaging to coordinate work between mi - Optional validation of required fields - Builtin logging service ---- +## Why Protobunny? -## Usage +While there are many messaging libraries for Python, Protobunny is built specifically for teams that treat **Protobuf as the single source of truth**. -See the [Quick example on GitHub](https://github.com/am-flow/protobunny/blob/main/QUICK_START.md) for installation and quick start guide. +* **Type-Safe by Design**: Built natively for `protobuf/betterproto`. +* **Semantic Routing**: Zero-config infrastructure. Protobunny uses your Protobuf package structure to decide if a message should be broadcast (Pub/Sub) or queued (Producer/Consumer). +* **Backend Agnostic**: You can choose between RabbitMQ, Redis, NATS, and Mosquitto. Python for local testing. +* **Sync & Async**: Support for both `asyncio` and traditional synchronous workloads. +* **Battle-Tested**: Derived from internal libraries used in production systems at AM-Flow. +--- + +### Feature Comparison with some existing libraries -Full docs are available at [https://am-flow.github.io/protobunny/](https://am-flow.github.io/protobunny/). +| Feature | **Protobunny** | **FastStream** | **Celery** | +|:-----------------------|:-------------------------|:-------------------|:------------------------| +| **Multi-Backend** | ✅ Yes | ✅ Yes | ⚠️ (Tasks only) | +| **Typed Protobufs** | ✅ Native (Betterproto) | ⚠️ Manual/Pydantic | ❌ No | +| **Sync + Async** | ✅ Yes | ✅ Yes | ❌ Sync focus | +| **Pattern Routing** | ✅ Auto (`tasks` pkg) | ❌ Manual Config | ✅ Fixed | +| **Framework Agnostic** | ✅ Yes | ✅ Yes | ❌ Heavyweight | --- -## Development +## Usage -### Run tests -```bash -make test -``` +See the [Quick example on GitHub](https://github.com/am-flow/protobunny/blob/main/QUICK_START.md) or on the [docs site](https://am-flow.github.io/protobunny/quickstart.html). -### Integration tests (RabbitMQ required) +Documentation home page: [https://am-flow.github.io/protobunny/](https://am-flow.github.io/protobunny/). -Integration tests expect RabbitMQ to be running (for example via Docker Compose in this repo): -```bash -docker compose up -d -make integration-test -``` --- - -### Future work - -- Support grcp -- Support for RabbitMQ certificates (through `pika`) -- More backends: - - NATS - - Kafka - - Cloud providers (AWS SQS/SNS) +### Roadmap + +- [x] **Core Support**: Redis, RabbitMQ, Mosquitto. +- [x] **Semantic Patterns**: Automatic `tasks` package routing. +- [x] **Arbistrary dictionary parsing**: Transparently parse JSON-like fields as dictionaries/lists by using protobunny JsonContent type. +- [x] **Result workflow**: Subscribe to results topics and receive protobunny `Result` messages produced by your callbacks. +- [x] **Cloud-Native**: NATS (Core & JetStream) integration. +- [ ] **Cloud Providers**: AWS (SQS/SNS) and GCP Pub/Sub. +- [ ] **More backends**: Kafka support. +- [ ] **gRPC** Direct Call support --- diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md index b0397a9..8b083b8 100644 --- a/docs/source/quick_start.md +++ b/docs/source/quick_start.md @@ -21,7 +21,7 @@ You can also add it manually to pyproject.toml dependencies: ```toml dependencies = [ - "protobunny[rabbitmq, numpy]>=0.1.0", + "protobunny[rabbitmq, numpy]>=0.1.2a2", # your other dependencies ... ] ``` @@ -179,12 +179,15 @@ def worker1(task: mml.main.tasks.TaskMessage) -> None: def worker2(task: mml.main.tasks.TaskMessage) -> None: print("2- Working on:", task) - -pb.subscribe(mml.main.tasks.TaskMessage, worker1) +import mymessagelib as mml +pb.subscribe(mml.main.tasks.TasqkMessage, worker1) pb.subscribe(mml.main.tasks.TaskMessage, worker2) + pb.publish(mml.main.tasks.TaskMessage(content="test1")) pb.publish(mml.main.tasks.TaskMessage(content="test2")) pb.publish(mml.main.tasks.TaskMessage(content="test3")) +from protobunny.models import ProtoBunnyMessage +print(isinstance(mml.main.tasks.TaskMessage(), ProtoBunnyMessage)) ``` You can also introspect/manage an underlying shared queue: @@ -293,10 +296,14 @@ if conn.is_connected(): conn.close() ``` -If you set the `generated-package-root` folder option, you might need to add the path to your `sys.path`. +If you set the `generated-package-root` folder option, you might need to add that path to your `sys.path`. You can do it conveniently by calling `config_lib` on top of your module, before importing the library: + ```python -pb.config_lib() +import protobunny as pb +pb.config_lib() +# now you can import the library from the generated package root +import mymessagelib as mml ``` ## Complete example @@ -310,7 +317,7 @@ version = "0.1.0" description = "Project to test protobunny" requires-python = ">=3.10" dependencies = [ - "protobunny[rabbitmq,redis,numpy,mosquitto] >=0.1.2a1", + "protobunny[rabbitmq,redis,numpy,mosquitto]>=0.1.2a1", ] [tool.protobunny] @@ -374,7 +381,6 @@ import asyncio import logging import sys - import protobunny as pb from protobunny import asyncio as pb_asyncio @@ -386,12 +392,11 @@ pb.config_lib() import mymessagelib as ml - logging.basicConfig( level=logging.INFO, format="[%(asctime)s %(levelname)s] %(name)s - %(message)s" ) log = logging.getLogger(__name__) -conf = pb.default_configuration +conf = pb.config class TestLibAsync: @@ -413,7 +418,6 @@ class TestLibAsync: async def on_message_mymessage(self, message: ml.main.MyMessage) -> None: log.info("Got main message: %s", message) - def run_forever(self): asyncio.run(self.main()) @@ -421,7 +425,6 @@ class TestLibAsync: log.info(f"LOG {incoming.routing_key}: {body}") async def main(self): - await pb_asyncio.subscribe_logger(self.log_callback) await pb_asyncio.subscribe(ml.main.tasks.TaskMessage, self.worker1) await pb_asyncio.subscribe(ml.main.tasks.TaskMessage, self.worker2) diff --git a/docs/source/recipes.md b/docs/source/recipes.md new file mode 100644 index 0000000..c13109b --- /dev/null +++ b/docs/source/recipes.md @@ -0,0 +1,127 @@ +# Recipes + +These examples are for sync context. +For async, imports the asyncio module and the logic remains the same, just with `async/await`. + +```python +from protobunny import asyncio as pb +``` + +## Subscribe to a queue + +To subscribe to a specific message type, use the `subscribe` method. This creates an exclusive queue by default (one consumer per queue instance). + +```python +import protobunny as pb +import mymessagelib as mml + + +def on_message(message: mml.tests.TestMessage) -> None: + print("Received:", message.content) + + +# Subscribe to the message class +pb.subscribe(mml.tests.TestMessage, on_message) + +# Block and wait for messages +pb.run_forever() +``` + +For the async version, `run_forever` accepts your main async method as coroutine, that will contain the `await pb.subscribe` calls. + +```python +from protobunny import asyncio as pb +import mymessagelib as mml + + +async def on_message(message: mml.tests.TestMessage) -> None: + print("Received:", message.content) + + +async def main(): + await pb.subscribe(mml.tests.TestMessage, on_message) + + +pb.run_forever(main) +``` + + +## Subscribe a task worker to a shared topic + +Protobunny treats any message defined within a `.tasks` package as a task. +Subscribing to these messages uses a shared queue, allowing multiple workers to balance the load (competing consumers). + +```python +import protobunny as pb +import mymessagelib.main.tasks as tasks + +def worker(task: tasks.TaskMessage) -> None: + print("Processing task:", task.content) + # Perform logic here... + +# Multiple instances of this script will share the load from the same queue +pb.subscribe(tasks.TaskMessage, worker) +pb.run_forever() +``` + +## Publish + +Publishing is straightforward. Protobunny automatically determines the correct topic and queue routing based on the message class. + +```python +import protobunny as pb +import mymessagelib as mml + +# Create the message instance +msg = mml.tests.TestMessage(content="Hello World", number=42) + +# Publish it +pb.publish(msg) +``` + +## Results workflow + +The results workflow allows you to send and receive feedback for a specific message, using the built-in `Result` message type. + +### Publishing a Result +Inside a message handler, you can create and publish a result tied to the source message. + +```python +def on_message(message: mml.tests.TestMessage) -> None: + # ... process message ... + + # Create a result from the source message + result = message.make_result( + return_value={"status": "success", "processed_at": "12:00"} + ) + pb.publish_result(result) +``` + +### Subscribing to Results +To listen for results of a specific message type: + +```python +def on_result(res: pb.results.Result) -> None: + # Access the original message that triggered this result + print("Source message:", res.source) + print("Data:", res.return_value) + +pb.subscribe_results(mml.tests.TestMessage, on_result) +``` + +## Requeuing + +If message processing fails and you want the broker to requeue it for another attempt, raise the `RequeueMessage` exception. + +```python +from protobunny import RequeueMessage +import mymessagelib as mml + +def on_message(message: mml.tests.TestMessage) -> None: + try: + # Attempt processing + do_work(message) + except Exception: + # This tells the backend to put the message back in the queue + raise RequeueMessage("Service busy, retrying...") +``` diff --git a/protobunny/__init__.py b/protobunny/__init__.py index db209ff..44d5827 100644 --- a/protobunny/__init__.py +++ b/protobunny/__init__.py @@ -1,15 +1,8 @@ """ -A module providing support for messaging and communication using RabbitMQ as the backend. +A module providing support for messaging and communication using the configured broker as the backend. This module includes functionality for publishing, subscribing, and managing message queues, -as well as dynamically managing imports and configurations for RabbitMQ-based communication -logics. It enables both synchronous and asynchronous operations, while also supporting -connection resetting and management. - -Modules and functionality are primarily imported from the core RabbitMQ backend, dynamically -generated package-specific configurations, and other base utilities. Exports are adjusted -as per the backend configuration. - +as well as dynamically managing imports and configurations for the backend. """ __all__ = [ @@ -27,7 +20,7 @@ "GENERATED_PACKAGE_NAME", "PACKAGE_NAME", "ROOT_GENERATED_PACKAGE_NAME", - "default_configuration", + "config", "RequeueMessage", "ConnectionError", "reset_connection", @@ -47,25 +40,24 @@ from importlib.metadata import version from types import FrameType, ModuleType -from .backends import BaseSyncQueue, LoggingSyncQueue -from .config import ( # noqa +from .backends import BaseSyncConnection, BaseSyncQueue, LoggingSyncQueue +from .conf import ( # noqa GENERATED_PACKAGE_NAME, PACKAGE_NAME, ROOT_GENERATED_PACKAGE_NAME, - default_configuration, + config, ) from .exceptions import ConnectionError, RequeueMessage from .helpers import get_backend, get_queue -from .registry import default_registry +from .registry import registry if tp.TYPE_CHECKING: from .core.results import Result - from .models import ( - IncomingMessageProtocol, - LoggerCallback, - ProtoBunnyMessage, - SyncCallback, - ) + from .models import PBM, IncomingMessageProtocol, LoggerCallback, SyncCallback + +__version__ = version(PACKAGE_NAME) + +log = logging.getLogger(PACKAGE_NAME) ############################ @@ -73,28 +65,30 @@ ############################ -def reset_connection(): - backend = get_backend() - return backend.connection.reset_connection() - - -def connect(): - backend = get_backend() - return backend.connection.connect() - +def connect(**kwargs: tp.Any) -> "BaseSyncConnection": + """Connect teh backend and get the singleton async connection. + You can pass the specific keyword arguments that you would pass to the `connect` function of the configured backend. + """ + connection_module = get_backend().connection + conn = connection_module.Connection.get_connection(vhost=connection_module.VHOST, **kwargs) + return conn -def disconnect(): - backend = get_backend() - return backend.connection.disconnect() +def disconnect() -> None: + connection_module = get_backend().connection + conn = connection_module.Connection.get_connection(vhost=connection_module.VHOST) + conn.disconnect() -__version__ = version(PACKAGE_NAME) +def reset_connection(**kwargs: tp.Any) -> "BaseSyncConnection": + """Reset the singleton connection.""" + connection_module = get_backend().connection + conn = connection_module.Connection.get_connection(vhost=connection_module.VHOST) + conn.disconnect() + return conn.connect(**kwargs) -log = logging.getLogger(PACKAGE_NAME) - -def publish(message: "ProtoBunnyMessage") -> None: +def publish(message: "PBM") -> None: """Synchronously publish a message to its corresponding queue. This method automatically determines the correct topic based on the @@ -123,7 +117,7 @@ def publish_result( def subscribe( - pkg_or_msg: "type[ProtoBunnyMessage] | ModuleType", + pkg_or_msg: "type[PBM] | ModuleType", callback: "SyncCallback", ) -> "BaseSyncQueue": """Subscribe a callback function to the topic. @@ -137,24 +131,22 @@ def subscribe( """ register_key = str(pkg_or_msg) - with default_registry.sync_lock: + with registry.sync_lock: queue = get_queue(pkg_or_msg) if queue.shared_queue: # It's a task. Handle multiple subscriptions - # queue = get_queue(pkg_or_msg) queue.subscribe(callback) - default_registry.register_task(register_key, queue) + registry.register_task(register_key, queue) else: # exclusive queue - queue = default_registry.get_subscription(register_key) or queue + queue = registry.get_subscription(register_key) or queue queue.subscribe(callback) - # register subscription to unsubscribe later - default_registry.register_subscription(register_key, queue) + registry.register_subscription(register_key, queue) return queue def subscribe_results( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", callback: "SyncCallback", ) -> "BaseSyncQueue": """Subscribe a callback function to the result topic. @@ -166,39 +158,39 @@ def subscribe_results( queue = get_queue(pkg) queue.subscribe_results(callback) # register subscription to unsubscribe later - with default_registry.sync_lock: - default_registry.register_results(pkg, queue) + with registry.sync_lock: + registry.register_results(pkg, queue) return queue def unsubscribe( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", if_unused: bool = True, if_empty: bool = True, ) -> None: """Remove a subscription for a message/package""" module_name = pkg.__module__ if hasattr(pkg, "__module__") else pkg.__name__ - registry_key = default_registry.get_key(pkg) + registry_key = registry.get_key(pkg) - with default_registry.sync_lock: + with registry.sync_lock: if is_module_tasks(module_name): - queues = default_registry.get_tasks(registry_key) + queues = registry.get_tasks(registry_key) for queue in queues: queue.unsubscribe() - default_registry.unregister_tasks(registry_key) + registry.unregister_tasks(registry_key) else: - queue = default_registry.get_subscription(registry_key) + queue = registry.get_subscription(registry_key) if queue: queue.unsubscribe(if_unused=if_unused, if_empty=if_empty) - default_registry.unregister_subscription(registry_key) + registry.unregister_subscription(registry_key) def unsubscribe_results( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", ) -> None: """Remove all in-process subscriptions for a message/package result topic""" - with default_registry.sync_lock: - queue = default_registry.unregister_results(pkg) + with registry.sync_lock: + queue = registry.unregister_results(pkg) if queue: queue.unsubscribe_results() @@ -210,29 +202,37 @@ def unsubscribe_all(if_unused: bool = True, if_empty: bool = True) -> None: This clears standard subscriptions, result subscriptions, and task subscriptions, effectively stopping all message consumption for this process. """ - with default_registry.sync_lock: - queues = default_registry.get_all_subscriptions() + with registry.sync_lock: + queues = registry.get_all_subscriptions() for q in queues: q.unsubscribe(if_unused=False, if_empty=False) - default_registry.unregister_all_subscriptions() - queues = default_registry.get_all_results() + registry.unregister_all_subscriptions() + queues = registry.get_all_results() for q in queues: q.unsubscribe_results() - default_registry.unregister_all_results() - queues = default_registry.get_all_tasks(flat=True) + registry.unregister_all_results() + queues = registry.get_all_tasks(flat=True) for q in queues: q.unsubscribe(if_unused=if_unused, if_empty=if_empty) - default_registry.unregister_all_tasks() + registry.unregister_all_tasks() def get_message_count( - msg_type: "ProtoBunnyMessage | type[ProtoBunnyMessage] | ModuleType", + msg_type: "PBM | type[PBM] | ModuleType", ) -> int | None: q = get_queue(msg_type) count = q.get_message_count() return count +def get_consumer_count( + msg_type: "PBM | type[PBM] | ModuleType", +) -> int | None: + q = get_queue(msg_type) + count = q.get_consumer_count() + return count + + def is_module_tasks(module_name: str) -> bool: return "tasks" in module_name.split(".") @@ -274,14 +274,14 @@ def shutdown(signum: int, _: FrameType | None) -> None: signal.signal(signal.SIGINT, shutdown) signal.signal(signal.SIGTERM, shutdown) - log.info("Protobunny Started. Press Ctrl+C to exit.") + log.info("Protobunny Started.") signal.pause() - log.info("Protobunny Stopped. Press Ctrl+C to exit.") + log.info("Protobunny Stopped.") def config_lib() -> None: """Add the generated package root to the sys.path.""" - lib_root = default_configuration.generated_package_root + lib_root = config.generated_package_root if lib_root and lib_root not in sys.path: sys.path.append(lib_root) diff --git a/protobunny/__init__.py.j2 b/protobunny/__init__.py.j2 index 71a0e5c..2654cda 100644 --- a/protobunny/__init__.py.j2 +++ b/protobunny/__init__.py.j2 @@ -1,16 +1,10 @@ """ -A module providing support for messaging and communication using RabbitMQ as the backend. +A module providing support for messaging and communication using the configured broker as the backend. This module includes functionality for publishing, subscribing, and managing message queues, -as well as dynamically managing imports and configurations for RabbitMQ-based communication -logics. It enables both synchronous and asynchronous operations, while also supporting -connection resetting and management. - -Modules and functionality are primarily imported from the core RabbitMQ backend, dynamically -generated package-specific configurations, and other base utilities. Exports are adjusted -as per the backend configuration. - +as well as dynamically managing imports and configurations for the backend. """ + __all__ = [ "get_backend", "get_message_count", @@ -26,7 +20,7 @@ __all__ = [ "GENERATED_PACKAGE_NAME", "PACKAGE_NAME", "ROOT_GENERATED_PACKAGE_NAME", - "default_configuration", + "config", "RequeueMessage", "ConnectionError", "reset_connection", @@ -49,50 +43,56 @@ import typing as tp from importlib.metadata import version from types import ModuleType, FrameType -from .config import ( # noqa +from .conf import ( # noqa GENERATED_PACKAGE_NAME, PACKAGE_NAME, ROOT_GENERATED_PACKAGE_NAME, - default_configuration, + config, ) from .exceptions import RequeueMessage, ConnectionError -from .registry import default_registry +from .registry import registry from .helpers import get_backend, get_queue -from .backends import LoggingSyncQueue, BaseSyncQueue +from .backends import LoggingSyncQueue, BaseSyncQueue, BaseSyncConnection if tp.TYPE_CHECKING: - from .models import LoggerCallback, ProtoBunnyMessage, SyncCallback, IncomingMessageProtocol + from .models import LoggerCallback, PBM, SyncCallback, IncomingMessageProtocol from .core.results import Result +__version__ = version(PACKAGE_NAME) + +log = logging.getLogger(PACKAGE_NAME) + ############################ # -- Sync top-level methods ############################ -def reset_connection(): - backend = get_backend() - return backend.connection.reset_connection() - - -def connect(): - backend = get_backend() - return backend.connection.connect() - +def connect(**kwargs: tp.Any) -> "BaseSyncConnection": + """Connect teh backend and get the singleton async connection. + You can pass the specific keyword arguments that you would pass to the `connect` function of the configured backend. + """ + connection_module = get_backend().connection + conn = connection_module.Connection.get_connection(vhost=connection_module.VHOST, **kwargs) + return conn -def disconnect(): - backend = get_backend() - return backend.connection.disconnect() +def disconnect() -> None: + connection_module = get_backend().connection + conn = connection_module.Connection.get_connection(vhost=connection_module.VHOST) + conn.disconnect() -__version__ = version(PACKAGE_NAME) +def reset_connection(**kwargs: tp.Any) -> "BaseSyncConnection": + """Reset the singleton connection.""" + connection_module = get_backend().connection + conn = connection_module.Connection.get_connection(vhost=connection_module.VHOST) + conn.disconnect() + return conn.connect(**kwargs) -log = logging.getLogger(PACKAGE_NAME) - -def publish(message: "ProtoBunnyMessage") -> None: +def publish(message: "PBM") -> None: """Synchronously publish a message to its corresponding queue. This method automatically determines the correct topic based on the @@ -121,7 +121,7 @@ def publish_result( def subscribe( - pkg_or_msg: "type[ProtoBunnyMessage] | ModuleType", + pkg_or_msg: "type[PBM] | ModuleType", callback: "SyncCallback", ) -> "BaseSyncQueue": """Subscribe a callback function to the topic. @@ -135,24 +135,22 @@ def subscribe( """ register_key = str(pkg_or_msg) - with default_registry.sync_lock: + with registry.sync_lock: queue = get_queue(pkg_or_msg) if queue.shared_queue: # It's a task. Handle multiple subscriptions - # queue = get_queue(pkg_or_msg) queue.subscribe(callback) - default_registry.register_task(register_key, queue) + registry.register_task(register_key, queue) else: # exclusive queue - queue = default_registry.get_subscription(register_key) or queue + queue = registry.get_subscription(register_key) or queue queue.subscribe(callback) - # register subscription to unsubscribe later - default_registry.register_subscription(register_key, queue) + registry.register_subscription(register_key, queue) return queue def subscribe_results( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", callback: "SyncCallback", ) -> "BaseSyncQueue": """Subscribe a callback function to the result topic. @@ -164,39 +162,39 @@ def subscribe_results( queue = get_queue(pkg) queue.subscribe_results(callback) # register subscription to unsubscribe later - with default_registry.sync_lock: - default_registry.register_results(pkg, queue) + with registry.sync_lock: + registry.register_results(pkg, queue) return queue def unsubscribe( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", if_unused: bool = True, if_empty: bool = True, ) -> None: """Remove a subscription for a message/package""" module_name = pkg.__module__ if hasattr(pkg, "__module__") else pkg.__name__ - registry_key = default_registry.get_key(pkg) + registry_key = registry.get_key(pkg) - with default_registry.sync_lock: + with registry.sync_lock: if is_module_tasks(module_name): - queues = default_registry.get_tasks(registry_key) + queues = registry.get_tasks(registry_key) for queue in queues: queue.unsubscribe() - default_registry.unregister_tasks(registry_key) + registry.unregister_tasks(registry_key) else: - queue = default_registry.get_subscription(registry_key) + queue = registry.get_subscription(registry_key) if queue: queue.unsubscribe(if_unused=if_unused, if_empty=if_empty) - default_registry.unregister_subscription(registry_key) + registry.unregister_subscription(registry_key) def unsubscribe_results( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", ) -> None: """Remove all in-process subscriptions for a message/package result topic""" - with default_registry.sync_lock: - queue = default_registry.unregister_results(pkg) + with registry.sync_lock: + queue = registry.unregister_results(pkg) if queue: queue.unsubscribe_results() @@ -208,29 +206,37 @@ def unsubscribe_all(if_unused: bool = True, if_empty: bool = True) -> None: This clears standard subscriptions, result subscriptions, and task subscriptions, effectively stopping all message consumption for this process. """ - with default_registry.sync_lock: - queues = default_registry.get_all_subscriptions() + with registry.sync_lock: + queues = registry.get_all_subscriptions() for q in queues: q.unsubscribe(if_unused=False, if_empty=False) - default_registry.unregister_all_subscriptions() - queues = default_registry.get_all_results() + registry.unregister_all_subscriptions() + queues = registry.get_all_results() for q in queues: q.unsubscribe_results() - default_registry.unregister_all_results() - queues = default_registry.get_all_tasks(flat=True) + registry.unregister_all_results() + queues = registry.get_all_tasks(flat=True) for q in queues: q.unsubscribe(if_unused=if_unused, if_empty=if_empty) - default_registry.unregister_all_tasks() + registry.unregister_all_tasks() def get_message_count( - msg_type: "ProtoBunnyMessage | type[ProtoBunnyMessage] | ModuleType", + msg_type: "PBM | type[PBM] | ModuleType", ) -> int | None: q = get_queue(msg_type) count = q.get_message_count() return count +def get_consumer_count( + msg_type: "PBM | type[PBM] | ModuleType", +) -> int | None: + q = get_queue(msg_type) + count = q.get_consumer_count() + return count + + def is_module_tasks(module_name: str) -> bool: return "tasks" in module_name.split(".") @@ -271,14 +277,14 @@ def run_forever() -> None: signal.signal(signal.SIGINT, shutdown) signal.signal(signal.SIGTERM, shutdown) - log.info("Protobunny Started. Press Ctrl+C to exit.") + log.info("Protobunny Started.") signal.pause() - log.info("Protobunny Stopped. Press Ctrl+C to exit.") + log.info("Protobunny Stopped.") def config_lib() -> None: """Add the generated package root to the sys.path.""" - lib_root = default_configuration.generated_package_root + lib_root = config.generated_package_root if lib_root and lib_root not in sys.path: sys.path.append(lib_root) @@ -290,4 +296,4 @@ from .{{ generated_package_name }} import ( # noqa {{ i }}, {% endfor %} ) -####################################################### \ No newline at end of file +####################################################### diff --git a/protobunny/asyncio/__init__.py b/protobunny/asyncio/__init__.py index 841b715..0d5d879 100644 --- a/protobunny/asyncio/__init__.py +++ b/protobunny/asyncio/__init__.py @@ -1,6 +1,3 @@ -import asyncio -import signal - """ A module providing support for messaging and communication using RabbitMQ as the backend. @@ -14,6 +11,7 @@ as per the backend configuration. """ + __all__ = [ "get_message_count", "get_queue", @@ -29,75 +27,83 @@ "GENERATED_PACKAGE_NAME", "PACKAGE_NAME", "ROOT_GENERATED_PACKAGE_NAME", - "default_configuration", + "config", "RequeueMessage", "ConnectionError", "reset_connection", "connect", "disconnect", "run_forever", + "config_lib", # from .core "commons", "results", ] +import asyncio import inspect import itertools import logging +import signal import textwrap import typing as tp from importlib.metadata import version ####################################################### -from ..config import ( # noqa +from ..conf import ( # noqa GENERATED_PACKAGE_NAME, PACKAGE_NAME, ROOT_GENERATED_PACKAGE_NAME, - default_configuration, + config, ) from ..exceptions import ConnectionError, RequeueMessage -from ..registry import default_registry +from ..registry import registry if tp.TYPE_CHECKING: from types import ModuleType from ..core.results import Result - from ..models import ( - AsyncCallback, - IncomingMessageProtocol, - LoggerCallback, - ProtoBunnyMessage, - ) + from ..models import PBM, AsyncCallback, IncomingMessageProtocol, LoggerCallback +from .. import config_lib as config_lib from ..helpers import get_backend, get_queue -from .backends import BaseAsyncQueue, LoggingAsyncQueue +from .backends import BaseAsyncConnection, BaseAsyncQueue, LoggingAsyncQueue __version__ = version(PACKAGE_NAME) + +log = logging.getLogger(PACKAGE_NAME) + + ############################ # -- Async top-level methods ############################ -log = logging.getLogger(PACKAGE_NAME) - -async def reset_connection(): - backend = get_backend() - return await backend.connection.reset_connection() +async def connect(**kwargs) -> "BaseAsyncConnection": + """Get the singleton async connection.""" + connection_module = get_backend().connection + conn = await connection_module.Connection.get_connection( + vhost=connection_module.VHOST, **kwargs + ) + return conn -async def connect(): - backend = get_backend() - return await backend.connection.connect() +async def disconnect() -> None: + connection_module = get_backend().connection + conn = await connection_module.Connection.get_connection(vhost=connection_module.VHOST) + await conn.disconnect() -async def disconnect(): - backend = get_backend() - return await backend.connection.disconnect() +async def reset_connection() -> "BaseAsyncConnection": + """Reset the singleton connection.""" + connection = await connect() + await connection.disconnect() + return await connect() -async def publish(message: "ProtoBunnyMessage") -> None: +async def publish(message: "PBM") -> None: """Asynchronously publish a message to its corresponding queue. Args: @@ -124,7 +130,7 @@ async def publish_result( async def subscribe( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", callback: "AsyncCallback", ) -> "BaseAsyncQueue": """ @@ -144,53 +150,52 @@ async def subscribe( # obj = type(pkg) if isinstance(pkg, betterproto.Message) else pkg module_name = pkg.__name__ if inspect.ismodule(pkg) else pkg.__module__ registry_key = str(pkg) - async with default_registry.lock: + async with registry.lock: if is_module_tasks(module_name): # It's a task. Handle multiple in-process subscriptions queue = get_queue(pkg) await queue.subscribe(callback) - default_registry.register_task(registry_key, queue) + registry.register_task(registry_key, queue) else: - # exclusive queue - queue = default_registry.get_subscription(registry_key) or get_queue(pkg) - # queue already exists, but not subscribed yet (otherwise raise ValueError) - await queue.subscribe(callback) - default_registry.register_subscription(registry_key, queue) + # exclusive queue, cannot register more than one callback + queue = registry.get_subscription(registry_key) or get_queue(pkg) + if not queue.subscription: + await queue.subscribe(callback) + registry.register_subscription(registry_key, queue) return queue async def unsubscribe( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", if_unused: bool = True, if_empty: bool = True, ) -> None: """Remove a subscription for a message/package""" - # obj = type(pkg) if isinstance(pkg, betterproto.Message) else pkg module_name = pkg.__name__ if inspect.ismodule(pkg) else pkg.__module__ - registry_key = default_registry.get_key(pkg) - async with default_registry.lock: + registry_key = registry.get_key(pkg) + async with registry.lock: if is_module_tasks(module_name): - queues = default_registry.get_tasks(registry_key) + queues = registry.get_tasks(registry_key) for q in queues: await q.unsubscribe(if_unused=if_unused) - default_registry.unregister_tasks(registry_key) + registry.unregister_tasks(registry_key) else: - queue = default_registry.get_subscription(registry_key) + queue = registry.get_subscription(registry_key) if queue: await queue.unsubscribe(if_unused=if_unused, if_empty=if_empty) - default_registry.unregister_subscription(registry_key) + registry.unregister_subscription(registry_key) async def unsubscribe_results( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", ) -> None: """Remove all in-process subscriptions for a message/package result topic""" - async with default_registry.lock: - queue = default_registry.get_results(pkg) + async with registry.lock: + queue = registry.get_results(pkg) if queue: await queue.unsubscribe_results() - default_registry.unregister_results(pkg) + registry.unregister_results(pkg) async def unsubscribe_all(if_unused: bool = True, if_empty: bool = True) -> None: @@ -200,23 +205,22 @@ async def unsubscribe_all(if_unused: bool = True, if_empty: bool = True) -> None This clears standard subscriptions, result subscriptions, and task subscriptions, effectively stopping all message consumption for this process. """ - async with default_registry.lock: + async with registry.lock: queues = itertools.chain( - default_registry.get_all_subscriptions(), - default_registry.get_all_tasks(flat=True), + registry.get_all_subscriptions(), registry.get_all_tasks(flat=True) ) for queue in queues: - await queue.unsubscribe(if_unused=False, if_empty=False) - default_registry.unregister_all_subscriptions() - default_registry.unregister_all_tasks() - queues = default_registry.get_all_results() + await queue.unsubscribe(if_unused=if_unused, if_empty=if_empty) + registry.unregister_all_subscriptions() + registry.unregister_all_tasks() + queues = registry.get_all_results() for queue in queues: await queue.unsubscribe_results() - default_registry.unregister_all_results() + registry.unregister_all_results() async def subscribe_results( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", callback: "AsyncCallback", ) -> "BaseAsyncQueue": """Subscribe a callback function to the result topic. @@ -228,19 +232,27 @@ async def subscribe_results( queue = get_queue(pkg) await queue.subscribe_results(callback) # register subscription to unsubscribe later - async with default_registry.lock: - default_registry.register_results(pkg, queue) + async with registry.lock: + registry.register_results(pkg, queue) return queue async def get_message_count( - msg_type: "ProtoBunnyMessage | type[ProtoBunnyMessage] | ModuleType", + msg_type: "PBM | type[PBM] | ModuleType", ) -> int | None: q = get_queue(msg_type) count = await q.get_message_count() return count +async def get_consumer_count( + msg_type: "PBM | type[PBM] | ModuleType", +) -> int | None: + q = get_queue(msg_type) + count = await q.get_consumer_count() + return count + + def default_log_callback(message: "IncomingMessageProtocol", msg_content: str) -> None: """Default callback for the logging service""" log.info( @@ -279,15 +291,15 @@ async def shutdown(signum: int) -> None: stop_event.set() for sig in (signal.SIGINT, signal.SIGTERM): - # Note: add_signal_handler requires a callback, so we use a lambda + def _handler(s: int) -> asyncio.Task[None]: return asyncio.create_task(shutdown(s)) loop.add_signal_handler(sig, _handler, sig) - log.info("Started. Press Ctrl+C to exit.") - # Wait here forever (non-blocking) until shutdown() is called + log.info("Protobunny started") await main() + # Wait here forever (non-blocking) until shutdown() is called await stop_event.wait() @@ -297,3 +309,5 @@ def _handler(s: int) -> asyncio.Task[None]: commons, results, ) + +####################################################### diff --git a/protobunny/asyncio/__init__.py.j2 b/protobunny/asyncio/__init__.py.j2 index 43065da..de61bf3 100644 --- a/protobunny/asyncio/__init__.py.j2 +++ b/protobunny/asyncio/__init__.py.j2 @@ -1,7 +1,3 @@ -{# @formatter:off #} -{# language: python #}import signal -import asyncio - """ A module providing support for messaging and communication using RabbitMQ as the backend. @@ -30,19 +26,22 @@ __all__ = [ "GENERATED_PACKAGE_NAME", "PACKAGE_NAME", "ROOT_GENERATED_PACKAGE_NAME", - "default_configuration", + "config", "RequeueMessage", "ConnectionError", "reset_connection", "connect", "disconnect", "run_forever", + "config_lib", # from .{{ generated_package_name }} {% for i in main_imports|sort %} "{{ i }}", {% endfor %} ] +import signal +import asyncio import inspect import itertools import logging @@ -53,49 +52,55 @@ from importlib.metadata import version ####################################################### -from ..config import ( # noqa - GENERATED_PACKAGE_NAME, - PACKAGE_NAME, - ROOT_GENERATED_PACKAGE_NAME, - default_configuration, +from ..conf import ( # noqa + GENERATED_PACKAGE_NAME, + PACKAGE_NAME, + ROOT_GENERATED_PACKAGE_NAME, + config, ) from ..exceptions import RequeueMessage, ConnectionError -from ..registry import default_registry +from ..registry import registry if tp.TYPE_CHECKING: - from ..models import LoggerCallback, ProtoBunnyMessage, AsyncCallback, IncomingMessageProtocol + from ..models import LoggerCallback, PBM, AsyncCallback, IncomingMessageProtocol from ..core.results import Result from types import ModuleType -from .backends import LoggingAsyncQueue, BaseAsyncQueue +from .backends import LoggingAsyncQueue, BaseAsyncQueue, BaseAsyncConnection from ..helpers import get_queue, get_backend - +from .. import config_lib as config_lib __version__ = version(PACKAGE_NAME) -############################ -# -- Async top-level methods -############################ log = logging.getLogger(PACKAGE_NAME) -async def reset_connection(): - backend = get_backend() - return await backend.connection.reset_connection() +############################ +# -- Async top-level methods +############################ + +async def connect(**kwargs) -> "BaseAsyncConnection": + """Get the singleton async connection.""" + connection_module = get_backend().connection + conn = await connection_module.Connection.get_connection(vhost=connection_module.VHOST, **kwargs) + return conn -async def connect(): - backend = get_backend() - return await backend.connection.connect() +async def disconnect() -> None: + connection_module = get_backend().connection + conn = await connection_module.Connection.get_connection(vhost=connection_module.VHOST) + await conn.disconnect() -async def disconnect(): - backend = get_backend() - return await backend.connection.disconnect() +async def reset_connection() -> "BaseAsyncConnection": + """Reset the singleton connection.""" + connection = await connect() + await connection.disconnect() + return await connect() -async def publish(message: "ProtoBunnyMessage") -> None: +async def publish(message: "PBM") -> None: """Asynchronously publish a message to its corresponding queue. Args: @@ -122,7 +127,7 @@ async def publish_result( async def subscribe( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", callback: "AsyncCallback", ) -> "BaseAsyncQueue": """ @@ -142,53 +147,52 @@ async def subscribe( # obj = type(pkg) if isinstance(pkg, betterproto.Message) else pkg module_name = pkg.__name__ if inspect.ismodule(pkg) else pkg.__module__ registry_key = str(pkg) - async with default_registry.lock: + async with registry.lock: if is_module_tasks(module_name): # It's a task. Handle multiple in-process subscriptions queue = get_queue(pkg) await queue.subscribe(callback) - default_registry.register_task(registry_key, queue) + registry.register_task(registry_key, queue) else: - # exclusive queue - queue = default_registry.get_subscription(registry_key) or get_queue(pkg) - # queue already exists, but not subscribed yet (otherwise raise ValueError) - await queue.subscribe(callback) - default_registry.register_subscription(registry_key, queue) + # exclusive queue, cannot register more than one callback + queue = registry.get_subscription(registry_key) or get_queue(pkg) + if not queue.subscription: + await queue.subscribe(callback) + registry.register_subscription(registry_key, queue) return queue async def unsubscribe( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", if_unused: bool = True, if_empty: bool = True, ) -> None: """Remove a subscription for a message/package""" - # obj = type(pkg) if isinstance(pkg, betterproto.Message) else pkg module_name = pkg.__name__ if inspect.ismodule(pkg) else pkg.__module__ - registry_key = default_registry.get_key(pkg) - async with default_registry.lock: + registry_key = registry.get_key(pkg) + async with registry.lock: if is_module_tasks(module_name): - queues = default_registry.get_tasks(registry_key) + queues = registry.get_tasks(registry_key) for q in queues: await q.unsubscribe(if_unused=if_unused) - default_registry.unregister_tasks(registry_key) + registry.unregister_tasks(registry_key) else: - queue = default_registry.get_subscription(registry_key) + queue = registry.get_subscription(registry_key) if queue: await queue.unsubscribe(if_unused=if_unused, if_empty=if_empty) - default_registry.unregister_subscription(registry_key) + registry.unregister_subscription(registry_key) async def unsubscribe_results( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", ) -> None: """Remove all in-process subscriptions for a message/package result topic""" - async with default_registry.lock: - queue = default_registry.get_results(pkg) + async with registry.lock: + queue = registry.get_results(pkg) if queue: await queue.unsubscribe_results() - default_registry.unregister_results(pkg) + registry.unregister_results(pkg) async def unsubscribe_all(if_unused: bool = True, if_empty: bool = True) -> None: @@ -198,22 +202,22 @@ async def unsubscribe_all(if_unused: bool = True, if_empty: bool = True) -> None This clears standard subscriptions, result subscriptions, and task subscriptions, effectively stopping all message consumption for this process. """ - async with default_registry.lock: + async with registry.lock: queues = itertools.chain( - default_registry.get_all_subscriptions(), default_registry.get_all_tasks(flat=True) + registry.get_all_subscriptions(), registry.get_all_tasks(flat=True) ) for queue in queues: - await queue.unsubscribe(if_unused=False, if_empty=False) - default_registry.unregister_all_subscriptions() - default_registry.unregister_all_tasks() - queues = default_registry.get_all_results() + await queue.unsubscribe(if_unused=if_unused, if_empty=if_empty) + registry.unregister_all_subscriptions() + registry.unregister_all_tasks() + queues = registry.get_all_results() for queue in queues: await queue.unsubscribe_results() - default_registry.unregister_all_results() + registry.unregister_all_results() async def subscribe_results( - pkg: "type[ProtoBunnyMessage] | ModuleType", + pkg: "type[PBM] | ModuleType", callback: "AsyncCallback", ) -> "BaseAsyncQueue": """Subscribe a callback function to the result topic. @@ -225,19 +229,27 @@ async def subscribe_results( queue = get_queue(pkg) await queue.subscribe_results(callback) # register subscription to unsubscribe later - async with default_registry.lock: - default_registry.register_results(pkg, queue) + async with registry.lock: + registry.register_results(pkg, queue) return queue async def get_message_count( - msg_type: "ProtoBunnyMessage | type[ProtoBunnyMessage] | ModuleType", + msg_type: "PBM | type[PBM] | ModuleType", ) -> int | None: q = get_queue(msg_type) count = await q.get_message_count() return count +async def get_consumer_count( + msg_type: "PBM | type[PBM] | ModuleType", +) -> int | None: + q = get_queue(msg_type) + count = await q.get_consumer_count() + return count + + def default_log_callback(message: "IncomingMessageProtocol", msg_content: str) -> None: """Default callback for the logging service""" log.info( @@ -276,15 +288,13 @@ async def _run_forever(main: tp.Callable[..., tp.Awaitable[None]]) -> None: stop_event.set() for sig in (signal.SIGINT, signal.SIGTERM): - # Note: add_signal_handler requires a callback, so we use a lambda def _handler(s: int) -> asyncio.Task[None]: return asyncio.create_task(shutdown(s)) - loop.add_signal_handler(sig, _handler, sig) - log.info("Started. Press Ctrl+C to exit.") - # Wait here forever (non-blocking) until shutdown() is called + log.info("Protobunny started") await main() + # Wait here forever (non-blocking) until shutdown() is called await stop_event.wait() @@ -295,3 +305,4 @@ from ..{{ generated_package_name }} import ( # noqa {{ i }}, {% endfor %} ) +####################################################### diff --git a/protobunny/asyncio/backends/__init__.py b/protobunny/asyncio/backends/__init__.py index 04c16a4..ee4817f 100644 --- a/protobunny/asyncio/backends/__init__.py +++ b/protobunny/asyncio/backends/__init__.py @@ -4,8 +4,9 @@ import typing as tp from abc import ABC, abstractmethod +from protobunny import asyncio as pb + from ...exceptions import RequeueMessage -from ...helpers import get_backend from ...models import ( AsyncCallback, BaseQueue, @@ -14,7 +15,7 @@ LoggerCallback, ProtoBunnyMessage, SyncCallback, - default_configuration, + config, deserialize_message, deserialize_result_message, get_body, @@ -41,7 +42,6 @@ class BaseConnection(ABC): exchange_name: str | None dl_exchange: str | None dl_queue: str | None - heartbeat: int | None timeout: int | None url: str | None = None queues: dict[str, tp.Any] = {} @@ -99,7 +99,7 @@ def get_consumer_count(self, topic: str) -> int | tp.Awaitable[int]: ... @abstractmethod - def setup_queue(self, topic: str, shared: bool) -> tp.Any | tp.Awaitable[tp.Any]: + def setup_queue(self, topic: str, shared: bool, **kwargs) -> tp.Any | tp.Awaitable[tp.Any]: ... @@ -108,11 +108,11 @@ class BaseAsyncConnection(BaseConnection, ABC): instance_by_vhost: dict[str, Self] = {} @abstractmethod - async def connect(self, timeout: float = 30) -> None: + async def connect(self, **kwargs) -> None: ... @abstractmethod - async def disconnect(self, timeout: float = 30) -> None: + async def disconnect(self, **kwargs) -> None: ... def __init__(self, **kwargs): @@ -140,7 +140,7 @@ def is_connected_event(self) -> asyncio.Event: ... @classmethod - async def get_connection(cls, vhost: str = "/") -> Self: + async def get_connection(cls, vhost: str = "/", **kwargs) -> Self: """Get singleton instance (async).""" current_loop = asyncio.get_running_loop() async with cls._get_class_lock(): @@ -158,7 +158,7 @@ async def get_connection(cls, vhost: str = "/") -> Self: log.debug("Creating fresh connection for %s", vhost) new_instance = cls(vhost=vhost) new_instance._loop = current_loop # Store the loop it was born in - await new_instance.connect() + await new_instance.connect(**kwargs) cls.instance_by_vhost[vhost] = new_instance instance = new_instance return instance @@ -170,8 +170,7 @@ def is_connected(self) -> bool: class BaseAsyncQueue(BaseQueue, ABC): async def get_connection(self) -> "BaseAsyncConnection": - backend = get_backend() - return await backend.connection.connect() + return await pb.connect() async def publish(self, message: ProtoBunnyMessage) -> None: """Publish a message to the queue. @@ -190,7 +189,7 @@ async def _receive( callback: a callable accepting a message as only argument. message: the IncomingMessageProtocol object received from the queue. """ - delimiter = default_configuration.backend_config.topic_delimiter + delimiter = config.backend_config.topic_delimiter if not message.routing_key: raise ValueError("Routing key was not set. Invalid topic") if message.routing_key == self.result_topic or message.routing_key.endswith( @@ -344,12 +343,11 @@ async def send_message( Returns: """ - backend = get_backend() message = Envelope( body=body, correlation_id=correlation_id or b"", ) - conn = await backend.connection.connect() + conn = await pb.connect() await conn.publish(topic, message) @@ -380,10 +378,10 @@ class LoggingAsyncQueue(BaseAsyncQueue): """ def __init__(self, prefix: str) -> None: - backend = default_configuration.backend_config + backend = config.backend_config delimiter = backend.topic_delimiter wildcard = backend.multi_wildcard_delimiter - prefix = prefix or default_configuration.messages_prefix + prefix = prefix or config.messages_prefix super().__init__(f"{prefix}{delimiter}{wildcard}") def get_tag(self) -> str: @@ -434,7 +432,16 @@ async def _receive( def is_task(topic: str) -> bool: - delimiter = default_configuration.backend_config.topic_delimiter + """ + Use the backend configured delimiter to check if `tasks` is in it + + Args: + topic: the topic to check + + + Returns: True if tasks is in the topic, else False + """ + delimiter = config.backend_config.topic_delimiter return "tasks" in topic.split(delimiter) diff --git a/protobunny/asyncio/backends/mosquitto/connection.py b/protobunny/asyncio/backends/mosquitto/connection.py index 4036dad..068ada7 100644 --- a/protobunny/asyncio/backends/mosquitto/connection.py +++ b/protobunny/asyncio/backends/mosquitto/connection.py @@ -9,7 +9,7 @@ import aiomqtt import can_ada -from ....config import default_configuration +from ....conf import config from ....exceptions import ConnectionError, PublishError, RequeueMessage from ....models import Envelope, IncomingMessageProtocol from .. import BaseAsyncConnection @@ -20,24 +20,6 @@ VHOST = os.environ.get("MOSQUITTO_VHOST", "/") -async def connect() -> "Connection": - """Get the singleton async connection.""" - conn = await Connection.get_connection(vhost=VHOST) - return conn - - -async def reset_connection() -> "Connection": - """Reset the singleton connection.""" - connection = await connect() - await connection.disconnect() - return await connect() - - -async def disconnect() -> None: - connection = await connect() - await connection.disconnect() - - class Connection(BaseAsyncConnection): """Async Mosquitto Connection wrapper using aiomqtt.""" @@ -84,8 +66,8 @@ def __init__( self._connection: aiomqtt.Client | None = None self._instance_lock: asyncio.Lock | None = None - self._delimiter = default_configuration.backend_config.topic_delimiter # e.g. "/" - self._exchange = default_configuration.backend_config.namespace # used as root prefix + self._delimiter = config.backend_config.topic_delimiter # e.g. "/" + self._exchange = config.backend_config.namespace # used as root prefix @property def is_connected_event(self) -> asyncio.Event: @@ -104,7 +86,7 @@ def build_topic_key(self, topic: str) -> str: """Joins project prefix with topic using the configured delimiter.""" return f"{self._exchange}{self._delimiter}{topic}" - async def connect(self, timeout: float = 30.0) -> "Connection": + async def connect(self, **kwargs) -> "Connection": async with self.lock: if self.is_connected(): return self @@ -116,7 +98,7 @@ async def connect(self, timeout: float = 30.0) -> "Connection": port=self.port, username=self.username, password=self.password, - timeout=timeout, + **kwargs, ) # We enter the context manager to establish the connection await self._connection.__aenter__() diff --git a/protobunny/asyncio/backends/nats/__init__.py b/protobunny/asyncio/backends/nats/__init__.py new file mode 100644 index 0000000..d45bfe6 --- /dev/null +++ b/protobunny/asyncio/backends/nats/__init__.py @@ -0,0 +1 @@ +from . import connection, queues # noqa diff --git a/protobunny/asyncio/backends/nats/connection.py b/protobunny/asyncio/backends/nats/connection.py new file mode 100644 index 0000000..55baa4d --- /dev/null +++ b/protobunny/asyncio/backends/nats/connection.py @@ -0,0 +1,422 @@ +"""Implements a NATS Connection""" +import asyncio +import functools +import logging +import os +import typing as tp +import urllib.parse +import uuid +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor + +import can_ada +import nats +from nats.aio.msg import Msg +from nats.aio.subscription import Subscription +from nats.errors import ConnectionClosedError, TimeoutError +from nats.js.errors import BadRequestError, NoStreamResponseError + +from ....conf import config +from ....exceptions import ConnectionError, PublishError, RequeueMessage +from ....models import Envelope, IncomingMessageProtocol +from .. import BaseAsyncConnection, is_task + +log = logging.getLogger(__name__) + +VHOST = os.environ.get("NATS_VHOST", "/") + + +class Connection(BaseAsyncConnection): + """Async NATS Connection wrapper.""" + + _lock: asyncio.Lock | None = None + instance_by_vhost: dict[str, "Connection | None"] = {} + + def __init__( + self, + username: str | None = None, + password: str | None = None, + host: str | None = None, + port: int | None = None, + vhost: str = "", + url: str | None = None, + worker_threads: int = 2, + prefetch_count: int = 1, + requeue_delay: int = 3, + heartbeat: int = 1200, + ): + """Initialize NATS connection. + + Args: + username: NATS username + password: NATS password + host: NATS host + port: NATS port + url: NATS URL. It will override username, password, host and port + vhost: NATS virtual host (it's used as db number string) + worker_threads: number of concurrent callback workers to use + prefetch_count: how many messages to prefetch from the queue + requeue_delay: how long to wait before re-queueing a message (seconds) + """ + super().__init__() + uname = username or os.environ.get("NATS_USERNAME", "") + passwd = password or os.environ.get("NATS_PASSWORD", "") + host = host or os.environ.get("NATS_HOST", "localhost") + port = port or int(os.environ.get("NATS_PORT", "4222")) + # URL encode credentials and vhost to prevent injection + vhost = vhost or VHOST + self.vhost = vhost + username = urllib.parse.quote(uname, safe="") + password = urllib.parse.quote(passwd, safe="") + host = urllib.parse.quote(host, safe="") + # URL for connection + url = url or os.environ.get("NATS_URL", "") + if url: + # reconstruct url for safety + parsed = can_ada.parse(url) + url = f"{parsed.protocol}//{parsed.username}:{parsed.password}@{parsed.host}{parsed.pathname}{parsed.search}" + else: + # Build the URL based on what is available + if username and password: + url = f"nats://{username}:{password}@{host}:{port}{vhost}" + elif password: + url = f"nats://:{password}@{host}:{port}{vhost}" + elif username: + url = f"nats://{username}@{host}:{port}{vhost}" + else: + url = f"nats://{host}:{port}{vhost}" + + self._url = url + self._connection: nats.NATS | None = None + self.prefetch_count = prefetch_count + self.requeue_delay = requeue_delay + self.heartbeat = heartbeat + self.queues: dict[str, list[dict]] = defaultdict(list) + self.consumers: dict[str, dict] = {} + # to run sync callbacks + self.executor = ThreadPoolExecutor(max_workers=worker_threads) + self._instance_lock: asyncio.Lock | None = None + self._delimiter = config.backend_config.topic_delimiter + self._namespace = config.backend_config.namespace + self._tasks_subject_prefix = "TASKS" + self._stream_name = f"{self._namespace.upper()}_{self._tasks_subject_prefix}" + + async def __aenter__(self) -> "Connection": + await self.connect() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool: + await self.disconnect() + return False + + @property + def lock(self) -> asyncio.Lock: + """Lazy instance lock.""" + if self._instance_lock is None: + self._instance_lock = asyncio.Lock() + return self._instance_lock + + @classmethod + def _get_class_lock(cls) -> asyncio.Lock: + """Ensure the class lock is bound to the current running loop.""" + if cls._lock is None: + cls._lock = asyncio.Lock() + return cls._lock + + def build_topic_key(self, topic: str) -> str: + return f"{self._namespace}.{topic}" + + @property + def is_connected_event(self) -> asyncio.Event: + """Lazily create the event in the current running loop.""" + if self._is_connected_event is None: + self._is_connected_event = asyncio.Event() + return self._is_connected_event + + @property + def connection(self) -> "nats.NATS": + """Get the connection object. + + Raises: + ConnectionError: If not connected + """ + if not self._connection: + raise ConnectionError("Connection not initialized. Call connect() first.") + return self._connection + + async def connect(self, **kwargs) -> "Connection": + """Establish NATS connection. + + Args: + + Raises: + ConnectionError: If connection fails + asyncio.TimeoutError: If connection times out + """ + async with self.lock: + if self.instance_by_vhost.get(self.vhost) and self.is_connected(): + return self.instance_by_vhost[self.vhost] + try: + log.info("Establishing NATS connection to %s", self._url.split("@")[-1]) + self._connection = await nats.connect(self._url, **kwargs) + self.is_connected_event.set() + log.info("Successfully connected to NATS") + self.instance_by_vhost[self.vhost] = self + if config.use_tasks_in_nats: + # Create the jetstream if not existing + js = self._connection.jetstream() + # For NATS, tasks package can only be at first level after main package library + # Warning: don't bury tasks messages after three levels of hierarchy + task_patterns = [ + f"{self._tasks_subject_prefix}{self._delimiter}>", + ] + try: + await js.add_stream( + name=self._stream_name, + subjects=task_patterns, + ) + except BadRequestError: + # This usually means the stream already exists with a different config + log.warning("Stream %s exists with different settings.", self._stream_name) + return self + + except asyncio.TimeoutError as e: + log.error("NATS connection timeout") + self.is_connected_event.clear() + self._connection = None + raise ConnectionError(f"Failed to connect to NATS: {e}") from e + except Exception as e: + self.is_connected_event.clear() + self._connection = None + log.exception("Failed to establish NATS connection") + raise ConnectionError(f"Failed to connect to NATS: {e}") from e + + async def disconnect(self, timeout: float = 10.0) -> None: + """Close NATS connection and cleanup resources. + + Args: + timeout: Maximum time to wait for cleanup (seconds) + """ + async with self.lock: + if not self.is_connected(): + log.debug("Already disconnected from NATS") + return + + try: + log.info("Closing NATS connection") + # Cancel all subscriptions + for tag, consumer in self.consumers.items(): + subscription = consumer["subscription"] + try: + await subscription.unsubscribe() + # We give the task a moment to wrap up if needed + # await asyncio.sleep(0) # force context switching + # await asyncio.wait([task], timeout=2.0) + except Exception as e: + log.warning("Error stopping NATS subscription %s: %s", tag, e) + + # Shutdown Thread Executor (if used for sync callbacks) + self.executor.shutdown(wait=False, cancel_futures=True) + + # Close the NATS Connection Pool + if self._connection: + await asyncio.wait_for(self._connection.close(), timeout=timeout) + + except asyncio.TimeoutError: + log.warning("NATS connection close timeout after %.1f seconds", timeout) + except Exception: + log.exception("Error during NATS disconnect") + finally: + # Reset state + self._connection = None + self.queues.clear() # (Local queue metadata) + self.consumers.clear() + self.is_connected_event.clear() + # Remove from registry + Connection.instance_by_vhost.pop(self.vhost, None) + log.info("NATS connection closed") + + # Subscriptions methods + async def setup_queue( + self, topic: str, shared: bool, callback: tp.Callable | None = None + ) -> Subscription: + topic_key = self.build_topic_key(topic) + cb = functools.partial(self._nats_handler, callback) + if shared: + js = self._connection.jetstream() + # We use a durable name so multiple instances share the same task state + group_name = topic_key.replace(".", "_") + log.debug( + "Subscribing shared worker to JetStream group %s subject %s", group_name, topic_key + ) + subscription = await js.subscribe( + # the topic with prefixes + subject=f"{self._tasks_subject_prefix}{self._delimiter}{topic_key}", + # add queue parameter to flag it as a distributed queue + queue=group_name, + durable=group_name, + cb=cb, + manual_ack=True, + stream=self._stream_name, + ) + else: + log.debug("Subscribing broadcast listener to NATS Core: %s", topic_key) + subscription = await self._connection.subscribe(subject=topic_key, cb=cb) + return subscription + + async def subscribe(self, topic: str, callback: tp.Callable, shared: bool = False) -> str: + async with self.lock: + if not self.is_connected(): + raise ConnectionError("Not connected to NATS") + + topic_key = self.build_topic_key(topic) + sub_tag = f"{topic_key}_{uuid.uuid4().hex[:8]}" + subscription = await self.setup_queue(topic, shared, callback) + self.consumers[sub_tag] = { + "subscription": subscription, + "topic": topic_key, + "is_shared": shared, + } + return sub_tag + + async def _nats_handler(self, callback, msg: Msg): + """Callback that handles the Msg object pushed from NATS""" + topic = msg.subject + is_shared_queue = is_task(topic) + reply = msg.reply + body = msg.data + # Remove the 'TASKS.' prefix that was added to match the filtering stream group + if is_shared_queue: + topic = topic.removeprefix(f"{self._tasks_subject_prefix}{self._delimiter}") + + # The routing key is the string used to match the protobuf python class fqn + routing_key = topic.removeprefix(f"{self._namespace}{self._delimiter}") + envelope = Envelope(body=body, correlation_id=reply, routing_key=routing_key) + try: + if asyncio.iscoroutinefunction(callback): + await callback(envelope) + else: + # Run the callback in a thread pool to avoid blocking the event loop + await asyncio.get_event_loop().run_in_executor(self.executor, callback, envelope) + if is_shared_queue: + await msg.ack() + except RequeueMessage: + log.warning("Requeuing message on topic '%s' after RequeueMessage exception", topic) + await asyncio.sleep(self.requeue_delay) + if not is_shared_queue: + await self._connection.publish(topic, body, reply=reply) + else: + await msg.nak(self.requeue_delay) + except Exception: + log.exception("Callback failed for topic %s", topic) + await msg.term() # avoid retry logic for potentially poisoning messages + + async def unsubscribe(self, tag: str, **kwargs) -> None: + if tag not in self.consumers: + return + sub_info = self.consumers[tag] + await sub_info["subscription"].unsubscribe() + del sub_info["subscription"] + log.info("Unsubscribed from %s", sub_info["topic"]) + self.consumers.pop(tag) + + async def publish( + self, + topic: str, + message: "IncomingMessageProtocol", + **kwargs, + ) -> None: + if not self.is_connected(): + raise ConnectionError("Not connected to NATS") + topic_key = self.build_topic_key(topic) + is_shared = is_task(topic) + + # Standardize headers + headers = {"correlation_id": message.correlation_id} if message.correlation_id else None + + try: + if is_shared: + # Persistent "Task" publishing via JetStream + stream_key = f"{self._tasks_subject_prefix}{self._delimiter}{topic_key}" + log.debug("Publishing persistent task to NATS JetStream: %s", topic_key) + js = self._connection.jetstream() + await js.publish(subject=stream_key, payload=message.body, headers=headers) + if config.log_task_in_nats: + # The logger service doesn't use jetstream so we re-publish on a normal pubsub + # (it won't be re-catched by the tasks consumer) + log.debug("Publishing logging message for task to NATS Core: %s", topic_key) + await self._connection.publish( + subject=topic_key, payload=message.body, headers=headers + ) + else: + # Volatile "PubSub" publishing via NATS Core + log.debug("Publishing broadcast to NATS Core: %s", topic_key) + await self._connection.publish( + subject=topic_key, payload=message.body, headers=headers + ) + except (ConnectionClosedError, TimeoutError, NoStreamResponseError, Exception) as e: + log.error("NATS publish failed: %s", e) + raise PublishError(str(e)) from e + + async def purge(self, topic: str, reset_groups: bool = False) -> None: + if not is_task(topic): + raise ValueError("Purge only supported for tasks") + async with self.lock: + if not self.is_connected(): + raise ConnectionError("Not connected to NATS") + topic_key = self.build_topic_key(topic) + subject = f"{self._tasks_subject_prefix}{self._delimiter}{topic_key}" + # NATS purges messages matching a subject within the stream + try: + jsm = self._connection.jsm() # Get JetStream Management context + + log.info("Purging NATS subject '%s' from stream %s", topic, self._stream_name) + await jsm.purge_stream(self._stream_name, subject=subject) + + if reset_groups: + # In NATS, we must find consumers specifically tied to this topic + # Protobunny convention: durable name includes the topic + group_name = f"{topic_key.replace('.', '_')}" + try: + await jsm.delete_consumer(self._stream_name, group_name) + log.debug("Deleted NATS durable consumer: %s", group_name) + except nats.js.errors.NotFoundError: + pass # Consumer already gone + + except Exception as e: + log.error("Failed to purge NATS subject %s: %s", topic, e) + raise ConnectionError(f"Purge failed: {e}") + + async def get_message_count(self, topic: str) -> int: + if not is_task(topic): + raise ValueError("Purge only supported for tasks") + async with self.lock: + if not self.is_connected(): + raise ConnectionError("Not connected to NATS") + topic_key = self.build_topic_key(topic) + try: + jsm = self._connection.jsm() + stream_info = await jsm.stream_info(self._stream_name, subjects_filter=topic) + return stream_info.state.messages + except nats.js.errors.NotFoundError: + return 0 + except Exception as e: + log.error("Failed to get NATS message count for %s: %s", topic_key, e) + return 0 + + async def get_consumer_count(self, topic: str) -> int: + topic_key = self.build_topic_key(topic) + if not is_task(topic): + raise ValueError("Purge only supported for tasks") + async with self.lock: + if not self.is_connected(): + raise ConnectionError("Not connected to NATS") + try: + jsm = self._connection.jsm() + stream_info = await jsm.stream_info(self._stream_name, subjects_filter=topic) + return stream_info.state.consumer_count + except nats.js.errors.NotFoundError: + return 0 + except Exception as e: + log.error("Failed to get NATS consumer count for %s: %s", topic_key, e) + return 0 diff --git a/protobunny/asyncio/backends/nats/queues.py b/protobunny/asyncio/backends/nats/queues.py new file mode 100644 index 0000000..4777c82 --- /dev/null +++ b/protobunny/asyncio/backends/nats/queues.py @@ -0,0 +1,11 @@ +import logging + +from .. import ( + BaseAsyncQueue, +) + +log = logging.getLogger(__name__) + + +class AsyncQueue(BaseAsyncQueue): + pass diff --git a/protobunny/asyncio/backends/python/connection.py b/protobunny/asyncio/backends/python/connection.py index 5073282..c4027ee 100644 --- a/protobunny/asyncio/backends/python/connection.py +++ b/protobunny/asyncio/backends/python/connection.py @@ -7,7 +7,7 @@ from collections import defaultdict from queue import Empty, Queue -from ....config import default_configuration +from ....conf import config from ....models import AsyncCallback, Envelope from ... import RequeueMessage from .. import BaseConnection, is_task @@ -125,7 +125,7 @@ def __init__(self, vhost: str = "/", requeue_delay: int = 3): self.requeue_delay = requeue_delay self._is_connected = False self._subscriptions: dict[str, dict] = {} - self.logger_prefix = default_configuration.logger_prefix + self.logger_prefix = config.logger_prefix class Connection(BaseLocalConnection): @@ -327,19 +327,3 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): await self.disconnect() - - -# Convenience functions -async def connect() -> Connection: - return await Connection.get_connection(vhost=VHOST) - - -async def reset_connection() -> Connection: - connection = await connect() - await connection.disconnect() - return await connect() - - -async def disconnect() -> None: - connection = await connect() - await connection.disconnect() diff --git a/protobunny/asyncio/backends/rabbitmq/connection.py b/protobunny/asyncio/backends/rabbitmq/connection.py index 772051c..62bdbb3 100644 --- a/protobunny/asyncio/backends/rabbitmq/connection.py +++ b/protobunny/asyncio/backends/rabbitmq/connection.py @@ -9,6 +9,7 @@ from concurrent.futures import ThreadPoolExecutor import aio_pika +import can_ada from aio_pika.abc import ( AbstractChannel, AbstractExchange, @@ -24,24 +25,6 @@ VHOST = os.environ.get("RABBITMQ_VHOST", "/") -async def connect() -> "Connection": - """Get the singleton async connection.""" - conn = await Connection.get_connection(vhost=VHOST) - return conn - - -async def reset_connection() -> "Connection": - """Reset the singleton connection.""" - connection = await connect() - await connection.disconnect() - return await connect() - - -async def disconnect() -> None: - connection = await connect() - await connection.disconnect() - - class Connection(BaseAsyncConnection): """Async RabbitMQ Connection wrapper.""" @@ -97,7 +80,8 @@ def __init__( if not url: self._url = f"amqp://{username}:{password}@{host}:{port}/{clean_vhost}?heartbeat={heartbeat}&timeout={timeout}&fail_fast=no" else: - self._url = url + parsed = can_ada.parse(url) + self._url = f"amqp://{parsed.username}:{parsed.password}@{parsed.host}{parsed.pathname}{parsed.search}" self._exchange_name = exchange_name self._dl_exchange = dl_exchange self._dl_queue = dl_queue diff --git a/protobunny/asyncio/backends/rabbitmq/queues.py b/protobunny/asyncio/backends/rabbitmq/queues.py index e516b2f..9788d68 100644 --- a/protobunny/asyncio/backends/rabbitmq/queues.py +++ b/protobunny/asyncio/backends/rabbitmq/queues.py @@ -3,10 +3,12 @@ import aio_pika from aio_pika import DeliveryMode +from protobunny import asyncio as pb from protobunny.asyncio.backends import ( BaseAsyncQueue, ) -from protobunny.asyncio.backends.rabbitmq.connection import connect + +# from protobunny.asyncio.backends.rabbitmq.connection import connect log = logging.getLogger(__name__) @@ -35,5 +37,5 @@ async def send_message( correlation_id=correlation_id, delivery_mode=DeliveryMode.PERSISTENT if persistent else DeliveryMode.NOT_PERSISTENT, ) - conn = await connect() + conn = await pb.connect() await conn.publish(topic, message) diff --git a/protobunny/asyncio/backends/redis/connection.py b/protobunny/asyncio/backends/redis/connection.py index b1bc29a..1aa166e 100644 --- a/protobunny/asyncio/backends/redis/connection.py +++ b/protobunny/asyncio/backends/redis/connection.py @@ -9,10 +9,11 @@ import uuid from concurrent.futures import ThreadPoolExecutor +import can_ada import redis.asyncio as redis from redis import RedisError, ResponseError -from ....config import default_configuration +from ....conf import config from ....exceptions import ConnectionError, PublishError, RequeueMessage from ....models import Envelope, IncomingMessageProtocol from .. import BaseAsyncConnection, is_task @@ -22,24 +23,6 @@ VHOST = os.environ.get("REDIS_VHOST") or os.environ.get("REDIS_DB", "0") -async def connect() -> "Connection": - """Get the singleton async connection.""" - conn = await Connection.get_connection(vhost=VHOST) - return conn - - -async def reset_connection() -> "Connection": - """Reset the singleton connection.""" - connection = await connect() - await connection.disconnect() - return await connect() - - -async def disconnect() -> None: - connection = await connect() - await connection.disconnect() - - class Connection(BaseAsyncConnection): """Async Redis Connection wrapper.""" @@ -58,7 +41,7 @@ def __init__( worker_threads: int = 2, prefetch_count: int = 1, requeue_delay: int = 3, - heartbeat: int = 1200, + **kwargs, ): """Initialize Redis connection. @@ -68,7 +51,7 @@ def __init__( host: Redis host port: Redis port url: Redis URL. It will override username, password, host and port - vhost: Redis virtual host (it's used as db number string) + vhost: Redis virtual host (it's used as db number string if db not present) db: Redis database number worker_threads: number of concurrent callback workers to use prefetch_count: how many messages to prefetch from the queue @@ -98,24 +81,28 @@ def __init__( url = f"redis://{username}:{password}@{host}:{port}/{vhost}?protocol=3" elif password: url = f"redis://:{password}@{host}:{port}/{vhost}?protocol=3" + elif username: + url = f"redis://{username}@{host}:{port}/{vhost}?protocol=3" else: url = f"redis://{host}:{port}/{vhost}?protocol=3" + else: + parsed = can_ada.parse(url) + url = f"redis://{parsed.username}:{parsed.password}@{parsed.host}{parsed.pathname}{parsed.search}" self._url = url self._connection: redis.Redis | None = None self.prefetch_count = prefetch_count self.requeue_delay = requeue_delay - self.heartbeat = heartbeat self.queues: dict[str, dict] = {} self.consumers: dict[str, dict[str, tp.Any]] = {} self.executor = ThreadPoolExecutor(max_workers=worker_threads) self._instance_lock: asyncio.Lock | None = None - self._delimiter = default_configuration.backend_config.topic_delimiter - self._exchange = default_configuration.backend_config.namespace + self._delimiter = config.backend_config.topic_delimiter + self._exchange = config.backend_config.namespace - async def __aenter__(self) -> "Connection": - await self.connect() + async def __aenter__(self, **kwargs) -> "Connection": + await self.connect(**kwargs) return self async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool: @@ -139,57 +126,6 @@ def _get_class_lock(cls) -> asyncio.Lock: def build_topic_key(self, topic: str) -> str: return f"{self._exchange}:{topic}" - async def disconnect(self, timeout: float = 10.0) -> None: - """Close Redis connection and cleanup resources. - - Args: - timeout: Maximum time to wait for cleanup (seconds) - """ - async with self.lock: - if not self.is_connected(): - log.debug("Already disconnected from Redis") - return - - try: - log.info("Closing Redis connection") - - # In Redis, consumers are local asyncio Tasks. - # We cancel them here. Note: Redis doesn't have "exclusive queues" - # that auto-delete, so we just clear our local registry. - for tag, consumer in self.consumers.items(): - task = consumer["task"] - try: - task.cancel() - # We give the task a moment to wrap up if needed - await asyncio.sleep(0) # force context switching - await asyncio.wait([task], timeout=2.0) - except Exception as e: - log.warning("Error stopping Redis consumer %s: %s", tag, e) - - # Shutdown Thread Executor (if used for sync callbacks) - self.executor.shutdown(wait=False, cancel_futures=True) - - # Close the Redis Connection Pool - if self._connection: - # aclose() closes the connection pool and all underlying connections - await asyncio.wait_for(self._connection.aclose(), timeout=timeout) - - except asyncio.TimeoutError: - log.warning("Redis connection close timeout after %.1f seconds", timeout) - except Exception: - log.exception("Error during Redis disconnect") - finally: - # Reset state - self._connection = None - self._exchange = None # (Stream name/prefix) - self.queues.clear() # (Local queue metadata) - self.consumers.clear() - self.is_connected_event.clear() - - # 5. Remove from registry - Connection.instance_by_vhost.pop(self.vhost, None) - log.info("Redis connection closed") - @property def is_connected_event(self) -> asyncio.Event: """Lazily create the event in the current running loop.""" @@ -208,11 +144,10 @@ def connection(self) -> "redis.Redis": raise ConnectionError("Connection not initialized. Call connect() first.") return self._connection - async def connect(self, timeout: float = 30.0) -> "Connection": + async def connect(self, **kwargs) -> "Connection": """Establish Redis connection. Args: - timeout: Maximum time to wait for connection establishment (seconds) Raises: ConnectionError: If connection fails @@ -224,23 +159,23 @@ async def connect(self, timeout: float = 30.0) -> "Connection": try: # Parsing URL for logging (removing credentials) log.info("Establishing Redis connection to %s", self._url.split("@")[-1]) - + # protobunny sends raw bytes with protobuf serialized payloads + kwargs.pop("decode_responses", None) # Using from_url handles connection pooling automatically self._connection = redis.from_url( self._url, decode_responses=False, - socket_connect_timeout=timeout, - health_check_interval=self.heartbeat, + **kwargs, ) - await asyncio.wait_for(self._connection.ping(), timeout=timeout) + await asyncio.wait_for(self._connection.ping(), timeout=30) self.is_connected_event.set() log.info("Successfully connected to Redis") self.instance_by_vhost[self.vhost] = self return self except asyncio.TimeoutError: - log.error("Redis connection timeout after %.1f seconds", timeout) + log.error("Redis connection timeout after %.1f seconds", 30) self.is_connected_event.clear() self._connection = None raise @@ -253,6 +188,54 @@ async def connect(self, timeout: float = 30.0) -> "Connection": log.exception("Failed to establish Redis connection") raise ConnectionError(f"Failed to connect to Redis: {e}") from e + async def disconnect(self, timeout: float = 10.0) -> None: + """Close Redis connection and cleanup resources. + + Args: + timeout: Maximum time to wait for cleanup (seconds) + """ + async with self.lock: + if not self.is_connected(): + log.debug("Already disconnected from Redis") + return + + try: + log.info("Closing Redis connection") + # In Redis, consumers are local asyncio Tasks. + # We cancel them here. Note: Redis doesn't have "exclusive queues" + # that auto-delete, so we just clear our local registry. + for tag, consumer in self.consumers.items(): + task = consumer["task"] + try: + task.cancel() + # We give the task a moment to wrap up if needed + await asyncio.sleep(0) # force context switching + await asyncio.wait([task], timeout=2.0) + except Exception as e: + log.warning("Error stopping Redis consumer %s: %s", tag, e) + + # Shutdown Thread Executor (if used for sync callbacks) + self.executor.shutdown(wait=False, cancel_futures=True) + + # Close the Redis Connection Pool + if self._connection: + # aclose() closes the connection pool and all underlying connections + await asyncio.wait_for(self._connection.aclose(), timeout=timeout) + + except asyncio.TimeoutError: + log.warning("Redis connection close timeout after %.1f seconds", timeout) + except Exception: + log.exception("Error during Redis disconnect") + finally: + # Reset state + self._connection = None + self.queues.clear() # (Local queue metadata) + self.consumers.clear() + self.is_connected_event.clear() + # Remove from registry + Connection.instance_by_vhost.pop(self.vhost, None) + log.info("Redis connection closed") + async def subscribe(self, topic: str, callback: tp.Callable, shared: bool = False) -> str: """Subscribe to Redis. @@ -507,11 +490,11 @@ async def publish( "topic": topic, # add the topic here to implement topic exchange patterns } await self._connection.xadd(name=topic_key, fields=payload, maxlen=1000) - if default_configuration.log_task_in_redis: + if config.log_task_in_redis: # Tasks messages go to streams but the logger do a simple pubsub psubscription to .* # Send the message to the same topic with redis.publish so it appears there # Note: this should be used carefully (e.g. only for debugging) - # as it doubles the network calls for publishing tasks + # as it doubles the network calls when publishing tasks log.debug("Publishing message to topic: %s for logger", topic_key) await self._connection.publish(topic_key, message.body) else: @@ -523,7 +506,7 @@ async def publish( async def _on_message_task( self, stream_key: str, group_name: str, msg_id: str, payload: dict, callback: tp.Callable ): - """Wraps the user callback to simulate RabbitMQ behavior.""" + """Wraps the user callback.""" # Response is not decoded because we use bytes. But the keys will be bytes as well normalized_payload = { k.decode() if isinstance(k, bytes) else k: v for k, v in payload.items() @@ -556,12 +539,12 @@ async def _on_message_task( await asyncio.sleep(self.requeue_delay) await self._connection.xadd(name=stream_key, fields=payload) await self._connection.xack(stream_key, group_name, msg_id) - except Exception as e: + except Exception: log.exception("Callback failed for message %s", msg_id) - raise PublishError(f"Failed to publish message to topic {topic}: {e}") from e # Avoid poisoning messages # Note: In Redis, if you don't XACK, the message stays in the # Pending Entry List (PEL) for retry logic. + await self._connection.xack(stream_key, group_name, msg_id) async def unsubscribe(self, tag: str, if_unused: bool = True, if_empty: bool = True) -> None: task_to_cancel = None diff --git a/protobunny/backends/__init__.py b/protobunny/backends/__init__.py index 5e40154..5604490 100644 --- a/protobunny/backends/__init__.py +++ b/protobunny/backends/__init__.py @@ -5,15 +5,16 @@ import typing as tp from abc import ABC, abstractmethod +import protobunny as pb + from ..exceptions import RequeueMessage -from ..helpers import get_backend from ..models import ( BaseQueue, IncomingMessageProtocol, LoggerCallback, ProtoBunnyMessage, SyncCallback, - default_configuration, + config, deserialize_message, deserialize_result_message, get_body, @@ -66,7 +67,7 @@ def is_connected(self) -> bool | tp.Awaitable[bool]: ... @abstractmethod - def connect(self, timeout: float = 30) -> None | tp.Awaitable[None]: + def connect(self, **kwargs) -> None | tp.Awaitable[None]: ... @abstractmethod @@ -95,6 +96,10 @@ def get_consumer_count(self, topic: str) -> int | tp.Awaitable[int]: def setup_queue(self, topic: str, shared: bool) -> tp.Any | tp.Awaitable[tp.Any]: ... + @abstractmethod + def build_topic_key(self, topic: str) -> str: + ... + class BaseAsyncConnection(BaseConnection, ABC): instance_by_vhost: dict[str, "BaseAsyncConnection"] @@ -128,12 +133,16 @@ def __init__(self, **kwargs): self.vhost = self._async_conn.vhost self._started = False self.instance_by_vhost = {} + self._timeout_coro = 10 def get_async_connection(self, **kwargs) -> "BaseAsyncConnection": if hasattr(self, "_async_conn"): return self._async_conn return self.async_class(**kwargs) + def build_topic_key(self, topic: str) -> str: + pass + def _run_loop(self) -> None: """Run the event loop in a dedicated thread.""" loop = None @@ -362,7 +371,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.disconnect() return False - def connect(self, timeout: float = 10.0) -> None: + def connect(self, **kwargs) -> None: """Establish Sync connection. Args: @@ -372,7 +381,7 @@ def connect(self, timeout: float = 10.0) -> None: ConnectionError: If connection fails TimeoutError: If connection times out """ - self._run_coro(self._async_conn.connect(timeout), timeout=timeout) + self._run_coro(self._async_conn.connect(), timeout=self._timeout_coro) self.__class__.instance_by_vhost[self.vhost] = self def disconnect(self, timeout: float = 10.0) -> None: @@ -404,8 +413,7 @@ def disconnect(self, timeout: float = 10.0) -> None: class BaseSyncQueue(BaseQueue, ABC): def get_connection(self) -> BaseConnection: - backend = get_backend() - return backend.connection.connect() + return pb.connect() def publish(self, message: "ProtoBunnyMessage") -> None: """Publish a message to the queue. @@ -429,7 +437,7 @@ def publish_result( correlation_id: """ result_topic = topic or self.result_topic - log.info("Publishing result to: %s", result_topic) + log.debug("Publishing result to: %s", result_topic) self.send_message( result_topic, bytes(result), correlation_id=correlation_id, persistent=False ) @@ -445,7 +453,7 @@ def _receive( """ if not message.routing_key: raise ValueError("Routing key was not set. Invalid topic") - delimiter = default_configuration.backend_config.topic_delimiter + delimiter = config.backend_config.topic_delimiter if message.routing_key == self.result_topic or message.routing_key.endswith( f"{delimiter}result" ): @@ -453,10 +461,12 @@ def _receive( # In case the subscription has .# as binding key, # this method catches also results message for all the topics in that namespace. return - # msg: "ProtoBunnyMessage" = deserialize_message(message.routing_key, message.body) + try: callback(deserialize_message(message.routing_key, message.body)) except RequeueMessage: + # The callback raised a RequeueMessage exception + # a result message will be published by the specific Connection implementation raise except Exception as exc: # pylint: disable=W0703 log.exception("Could not process message: %s", str(message.body)) @@ -601,10 +611,10 @@ async def send_message( raise NotImplementedError() def __init__(self, prefix: str) -> None: - backend = default_configuration.backend_config + backend = config.backend_config delimiter = backend.topic_delimiter wildcard = backend.multi_wildcard_delimiter - prefix = prefix or default_configuration.messages_prefix + prefix = prefix or config.messages_prefix super().__init__(f"{prefix}{delimiter}{wildcard}") @property @@ -652,7 +662,7 @@ def get_tag(self) -> str: def is_task(topic: str) -> bool: - delimiter = default_configuration.backend_config.topic_delimiter + delimiter = config.backend_config.topic_delimiter return "tasks" in topic.split(delimiter) diff --git a/protobunny/backends/mosquitto/connection.py b/protobunny/backends/mosquitto/connection.py index c5556cd..0ecc6b5 100644 --- a/protobunny/backends/mosquitto/connection.py +++ b/protobunny/backends/mosquitto/connection.py @@ -11,30 +11,11 @@ import can_ada import paho.mqtt.client as mqtt -from ...config import default_configuration +from ...conf import config from ...exceptions import ConnectionError, PublishError, RequeueMessage from ...models import Envelope, IncomingMessageProtocol from .. import BaseConnection - -def connect() -> "Connection": - """Get the singleton async connection.""" - conn = Connection.get_connection(vhost=VHOST) - return conn - - -def reset_connection() -> "Connection": - """Reset the singleton connection.""" - connection = connect() - connection.disconnect() - return connect() - - -def disconnect() -> None: - connection = connect() - connection.disconnect() - - log = logging.getLogger(__name__) VHOST = os.environ.get("MOSQUITTO_VHOST", "/") @@ -123,20 +104,21 @@ def __init__( self._main_client: mqtt.Client | None = None - self._delimiter = default_configuration.backend_config.topic_delimiter - self._exchange = default_configuration.backend_config.namespace + self._delimiter = config.backend_config.topic_delimiter + self._exchange = config.backend_config.namespace def build_topic_key(self, topic: str) -> str: return f"{self._exchange}{self._delimiter}{topic}" - def connect(self, timeout: float = 10.0) -> "Connection": + def connect(self, **kwargs) -> "Connection": with self._lock: if self.is_connected(): return self try: # Use MQTT v5 for shared subscriptions support + kwargs.pop("callback_api_version", None) self._main_client = mqtt.Client( - callback_api_version=mqtt.CallbackAPIVersion.VERSION2 + callback_api_version=mqtt.CallbackAPIVersion.VERSION2, **kwargs ) if self.username: self._main_client.username_pw_set(self.username, self.password) diff --git a/protobunny/backends/nats/__init__.py b/protobunny/backends/nats/__init__.py new file mode 100644 index 0000000..d45bfe6 --- /dev/null +++ b/protobunny/backends/nats/__init__.py @@ -0,0 +1 @@ +from . import connection, queues # noqa diff --git a/protobunny/backends/nats/connection.py b/protobunny/backends/nats/connection.py new file mode 100644 index 0000000..0f89a8a --- /dev/null +++ b/protobunny/backends/nats/connection.py @@ -0,0 +1,32 @@ +"""Implements a NATS Connection with sync methods""" +import asyncio +import logging +import os +import threading + +from ...asyncio.backends.nats.connection import Connection as NATSConnection +from .. import BaseSyncConnection + +log = logging.getLogger(__name__) + +VHOST = os.environ.get("NATS_VHOST", "/") + + +class Connection(BaseSyncConnection): + """Synchronous wrapper around Async Rmq Connection. + + Manages a dedicated event loop in a background thread to run async operations. + + Example: + .. code-block:: python + + with Connection() as conn: + conn.publish("my.topic", message) + tag = conn.subscribe("my.topic", callback) + + """ + + _lock = threading.RLock() + _stopped: asyncio.Event | None = None + instance_by_vhost: dict[str, "Connection"] = {} + async_class = NATSConnection diff --git a/protobunny/backends/nats/queues.py b/protobunny/backends/nats/queues.py new file mode 100644 index 0000000..bd0df90 --- /dev/null +++ b/protobunny/backends/nats/queues.py @@ -0,0 +1,35 @@ +import logging + +from protobunny.backends import ( + BaseSyncQueue, +) +from protobunny.models import Envelope + +log = logging.getLogger(__name__) + + +class SyncQueue(BaseSyncQueue): + """Message queue backed by pika and RabbitMQ.""" + + def get_tag(self) -> str: + return self.subscription + + def send_message( + self, topic: str, body: bytes, correlation_id: str | None = None, persistent: bool = True + ): + """Low-level message sending implementation. + + Args: + topic: a topic name for direct routing or a routing key with special binding keys + body: serialized message (e.g. a serialized protobuf message or a json string) + correlation_id: is present for result messages + persistent: if true will use aio_pika.DeliveryMode.PERSISTENT + + Returns: + + """ + message = Envelope( + body=body, + correlation_id=correlation_id, + ) + self.get_connection().publish(topic, message) diff --git a/protobunny/backends/python/connection.py b/protobunny/backends/python/connection.py index 7317f0b..b42edef 100644 --- a/protobunny/backends/python/connection.py +++ b/protobunny/backends/python/connection.py @@ -9,7 +9,7 @@ from queue import Empty, Queue from ... import RequeueMessage -from ...config import default_configuration +from ...conf import config from ...models import Envelope, SyncCallback, is_task from .. import BaseConnection @@ -116,7 +116,10 @@ def __init__(self, vhost: str = "/", requeue_delay: int = 3): self.requeue_delay = requeue_delay self._is_connected = False self._subscriptions: dict[str, dict] = {} - self.logger_prefix = default_configuration.logger_prefix + self.logger_prefix = config.logger_prefix + + def build_topic_key(self, topic: str) -> str: + pass class Connection(BaseLocalConnection): @@ -176,7 +179,7 @@ def get_connection(cls, vhost: str = "/") -> "Connection": cls.instance_by_vhost[vhost] = instance return cls.instance_by_vhost[vhost] - def connect(self, timeout: float = 10.0) -> "Connection": + def connect(self, **kwargs) -> "Connection": with self.lock: self.is_connected_event.set() return self @@ -280,19 +283,3 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.disconnect() return False - - -# Convenience functions -def connect() -> Connection: - return Connection.get_connection(vhost=VHOST) - - -def reset_connection() -> Connection: - connection = connect() - connection.disconnect() - return connect() - - -def disconnect() -> None: - connection = connect() - connection.disconnect() diff --git a/protobunny/backends/rabbitmq/connection.py b/protobunny/backends/rabbitmq/connection.py index 1896ba9..d4a6684 100644 --- a/protobunny/backends/rabbitmq/connection.py +++ b/protobunny/backends/rabbitmq/connection.py @@ -12,24 +12,6 @@ VHOST = os.environ.get("RABBITMQ_VHOST", "/") -def connect() -> "Connection": - """Get the singleton async connection.""" - conn = Connection.get_connection(vhost=VHOST) - return conn - - -def reset_connection() -> "Connection": - """Reset the singleton connection.""" - connection = connect() - connection.disconnect() - return connect() - - -def disconnect() -> None: - connection = connect() - connection.disconnect() - - class Connection(BaseSyncConnection): """Synchronous wrapper around Async Rmq Connection. diff --git a/protobunny/backends/redis/connection.py b/protobunny/backends/redis/connection.py index e6e319b..a52e222 100644 --- a/protobunny/backends/redis/connection.py +++ b/protobunny/backends/redis/connection.py @@ -13,24 +13,6 @@ VHOST = os.environ.get("REDIS_VHOST") or os.environ.get("REDIS_DB", "0") -def connect() -> "Connection": - """Get the singleton async connection.""" - conn = Connection.get_connection(vhost=VHOST) - return conn - - -def reset_connection() -> "Connection": - """Reset the singleton connection.""" - connection = connect() - connection.disconnect() - return connect() - - -def disconnect() -> None: - connection = connect() - connection.disconnect() - - class Connection(BaseSyncConnection): """Synchronous wrapper around the async connection diff --git a/protobunny/config.py b/protobunny/conf.py similarity index 94% rename from protobunny/config.py rename to protobunny/conf.py index 90457b9..5933f6f 100644 --- a/protobunny/config.py +++ b/protobunny/conf.py @@ -24,7 +24,7 @@ ENV_PREFIX = "PROTOBUNNY_" INI_FILE = "protobunny.ini" -AvailableBackends = tp.Literal["rabbitmq", "python", "redis", "mosquitto"] +AvailableBackends = tp.Literal["rabbitmq", "python", "redis", "mosquitto", "nats"] log = logging.getLogger(__name__) @@ -45,12 +45,16 @@ mosquitto_backend_config = BackEndConfig( topic_delimiter="/", multi_wildcard_delimiter="#", namespace="protobunny" ) +nats_backend_config = BackEndConfig( + topic_delimiter=".", multi_wildcard_delimiter=">", namespace="protobunny" +) backend_configs = { "rabbitmq": rabbitmq_backend_config, "python": python_backend_config, "redis": redis_backend_config, "mosquitto": mosquitto_backend_config, + "nats": nats_backend_config, } @@ -67,7 +71,10 @@ class Config: backend: "AvailableBackends" = "rabbitmq" backend_config: BackEndConfig = rabbitmq_backend_config log_task_in_redis: bool = False - available_backends = ("rabbitmq", "python", "redis", "mosquitto") + use_tasks_in_nats: bool = True # needed to create the namespaced group stream on connect + log_task_in_nats: bool = False + + available_backends = ("rabbitmq", "python", "redis", "mosquitto", "nats") def __post_init__(self) -> None: if self.mode not in ("sync", "async"): @@ -192,4 +199,4 @@ def get_project_version() -> str: VERSION = get_project_version() -default_configuration = load_config() +config = load_config() diff --git a/protobunny/helpers.py b/protobunny/helpers.py index ea96d03..a733420 100644 --- a/protobunny/helpers.py +++ b/protobunny/helpers.py @@ -6,7 +6,7 @@ import betterproto -from .config import default_configuration +from .conf import config if typing.TYPE_CHECKING: from .asyncio.backends import BaseAsyncQueue @@ -25,8 +25,8 @@ def get_topic(pkg_or_msg: "ProtoBunnyMessage | type[ProtoBunnyMessage] | ModuleT Returns: topic string """ - delimiter = default_configuration.backend_config.topic_delimiter - return f"{default_configuration.messages_prefix}{delimiter}{build_routing_key(pkg_or_msg)}" + delimiter = config.backend_config.topic_delimiter + return f"{config.messages_prefix}{delimiter}{build_routing_key(pkg_or_msg)}" def get_backend(backend: str | None = None) -> ModuleType: @@ -44,8 +44,8 @@ def get_backend(backend: str | None = None) -> ModuleType: Returns: The imported backend module. """ - backend = backend or default_configuration.backend - module = ".asyncio" if default_configuration.use_async else "" + backend = backend or config.backend + module = ".asyncio" if config.use_async else "" module_name = f"protobunny{module}.backends.{backend}" if module_name in sys.modules: return sys.modules[module_name] @@ -53,8 +53,8 @@ def get_backend(backend: str | None = None) -> ModuleType: module = importlib.import_module(module_name) except ModuleNotFoundError as exc: suggestion = "" - if backend not in default_configuration.available_backends: - suggestion = f" Invalid backend or backend not supported.\nAvailable backends: {default_configuration.available_backends}" + if backend not in config.available_backends: + suggestion = f" Invalid backend or backend not supported.\nAvailable backends: {config.available_backends}" else: suggestion = ( f" Install the backend with pip install protobunny[{backend}]." @@ -82,18 +82,17 @@ def get_queue( Returns: Async/SyncQueue: A queue instance configured for the relevant topic. """ - backend_name = backend_name or default_configuration.backend - queue_type = "AsyncQueue" if default_configuration.use_async else "SyncQueue" + backend_name = backend_name or config.backend + queue_type = "AsyncQueue" if config.use_async else "SyncQueue" return getattr(get_backend(backend=backend_name).queues, queue_type)(get_topic(pkg_or_msg)) @functools.lru_cache(maxsize=100) def _build_routing_key(module: str, cls_name: str) -> str: # Build the routing key from the module and class name - backend = default_configuration.backend_config + backend = config.backend_config delimiter = backend.topic_delimiter routing_key = f"{module}.{cls_name}" - config = default_configuration if not routing_key.startswith(config.generated_package_name): raise ValueError( f"Invalid topic {routing_key}, must start with {config.generated_package_name}." @@ -113,7 +112,7 @@ def build_routing_key( """Returns a routing key based on a message instance, a message class, or a module. The string will be later composed with the configured message-prefix to build the exact topic name. - This is the main logic that builds keys strings for topics/streaming, adding wildcards when needed + This is the main logic that builds keys strings for topics/streaming, adding wildcards when needed. Examples: build_routing_key(mymessaginglib.vision.control) -> "vision.control.#" routing with binding key @@ -126,7 +125,7 @@ def build_routing_key( Returns: a routing key based on the type of message or package """ - backend = default_configuration.backend_config + backend = config.backend_config wildcard = backend.multi_wildcard_delimiter module_name = "" class_name = "" diff --git a/protobunny/logger.py b/protobunny/logger.py index 912522a..2ccc2f3 100644 --- a/protobunny/logger.py +++ b/protobunny/logger.py @@ -40,7 +40,7 @@ import protobunny as pb_sync from protobunny import asyncio as pb -from protobunny.config import load_config +from protobunny.conf import load_config from protobunny.models import IncomingMessageProtocol, LoggerCallback logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") diff --git a/protobunny/models.py b/protobunny/models.py index 4b55ffe..f245b5d 100644 --- a/protobunny/models.py +++ b/protobunny/models.py @@ -12,21 +12,15 @@ import betterproto from betterproto.lib.std.google.protobuf import Any -from .config import default_configuration +from .conf import config from .helpers import get_topic from .utils import ProtobunnyJsonEncoder -# - types -SyncCallback: tp.TypeAlias = tp.Callable[["ProtoBunnyMessage"], tp.Any] -AsyncCallback: tp.TypeAlias = tp.Callable[["ProtoBunnyMessage"], tp.Awaitable[tp.Any]] -ResultCallback: tp.TypeAlias = tp.Callable[["Result"], tp.Any] -LoggerCallback: tp.TypeAlias = tp.Callable[[tp.Any, str], tp.Any] - log = logging.getLogger(__name__) def is_task(topic: str) -> bool: - delimiter = default_configuration.backend_config.topic_delimiter + delimiter = config.backend_config.topic_delimiter return "tasks" in topic.split(delimiter) @@ -54,8 +48,9 @@ def validate_required_fields(self: "ProtoBunnyMessage") -> None: if missing: raise MissingRequiredFields(self, missing) - @functools.cached_property - def json_content_fields(self: "ProtoBunnyMessage") -> list[str]: + # @functools.cached_property + @property + def json_content_fields(self: "ProtoBunnyMessage") -> tp.Iterable[str]: """Returns: the list of fieldnames that are of type commons.JsonContent.""" return [ field_name @@ -67,19 +62,19 @@ def __bytes__(self: "ProtoBunnyMessage") -> bytes: # Override Message.__bytes__ method # to support transparent serialization of dictionaries to JsonContent fields. # This method validates for required fields as well - if default_configuration.force_required_fields: + if config.force_required_fields: self.validate_required_fields() msg = self.serialize_json_content() with BytesIO() as stream: betterproto.Message.dump(msg, stream) return stream.getvalue() - def from_dict(self: "ProtoBunnyMessage", value: dict) -> "ProtoBunnyMessage": + def from_dict(self: "ProtoBunnyMessage", value: dict) -> "PBM": json_fields = {field: value.pop(field, None) for field in self.json_content_fields} msg = betterproto.Message.from_dict(tp.cast(betterproto.Message, self), value) for field in json_fields: setattr(msg, field, json_fields[field]) - return msg + return tp.cast(PBM, msg) def to_dict( self: "ProtoBunnyMessage", @@ -109,7 +104,7 @@ def _to_dict_with_json_content(self, betterproto_func: tp.Callable[..., tp.Any]) return out_dict def to_pydict( - self: "ProtoBunnyMessage", + self: "PBM", casing: tp.Callable[[str, bool], str] = betterproto.Casing.CAMEL, include_default_values: bool = False, ) -> dict[str, tp.Any]: @@ -127,9 +122,7 @@ def to_pydict( out_dict = self._use_enum_names(casing, out_dict) return out_dict - def _use_enum_names( - self: "ProtoBunnyMessage", casing, out_dict: dict[str, tp.Any] - ) -> dict[str, tp.Any]: + def _use_enum_names(self: "PBM", casing, out_dict: dict[str, tp.Any]) -> dict[str, tp.Any]: """Used to reprocess betterproto.Message.to_pydict output to use names for Enum fields. Process only first level fields. @@ -181,7 +174,7 @@ def _process_enum_field(): return updated_out_enums def to_json( - self: "ProtoBunnyMessage", + self: "PBM", indent: None | int | str = None, include_default_values: bool = False, casing: tp.Callable[[str, bool], str] = betterproto.Casing.CAMEL, @@ -193,7 +186,7 @@ def to_json( cls=ProtobunnyJsonEncoder, ) - def parse(self: "ProtoBunnyMessage", data: bytes) -> "ProtoBunnyMessage": + def parse(self: "PBM", data: bytes) -> "PBM": # Override Message.parse() method # to support transparent deserialization of JsonContent fields json_content_fields = list(self.json_content_fields) @@ -209,12 +202,12 @@ def parse(self: "ProtoBunnyMessage", data: bytes) -> "ProtoBunnyMessage": return msg @property - def type_url(self: "ProtoBunnyMessage") -> str: + def type_url(self: "PBM") -> str: """Return the class fqn for this message.""" return f"{self.__class__.__module__}.{self.__class__.__name__}" @property - def source(self: "ProtoBunnyMessage") -> "ProtoBunnyMessage": + def source(self: "PBM") -> "PBM": """Return the source message from a Result The source message is stored as a protobuf.Any message, with its type info and serialized value. @@ -227,19 +220,19 @@ def source(self: "ProtoBunnyMessage") -> "ProtoBunnyMessage": return source_message @functools.cached_property - def topic(self: "ProtoBunnyMessage") -> str: + def topic(self: "PBM") -> str: """Build the topic name for the message.""" return get_topic(self) @functools.cached_property - def result_topic(self: "ProtoBunnyMessage") -> str: + def result_topic(self: "PBM") -> str: """ Build the result topic name for the message. """ - return f"{get_topic(self)}{default_configuration.backend_config.topic_delimiter}result" + return f"{get_topic(self)}{config.backend_config.topic_delimiter}result" def make_result( - self: "ProtoBunnyMessage", + self: "PBM", return_code: "ReturnCode | int | None" = None, error: str = "", return_value: dict[str, tp.Any] | None = None, @@ -289,7 +282,7 @@ class ProtoBunnyMessage(MessageMixin, betterproto.Message): class MissingRequiredFields(Exception): """Exception raised by MessageMixin.validate_required_fields when required fields are missing.""" - def __init__(self, msg: "ProtoBunnyMessage", missing_fields: list[str]) -> None: + def __init__(self, msg: "PBM", missing_fields: list[str]) -> None: self.missing_fields = missing_fields missing = ", ".join(missing_fields) super().__init__(f"Non optional fields for message {msg.topic} were not set: {missing}") @@ -328,9 +321,7 @@ def _deserialize_content(msg: "JsonContent") -> dict | None: return json.loads(msg.content.decode()) if msg.content else None -def _get_submodule( - package: ModuleType, paths: list[str] -) -> "type[ProtoBunnyMessage] | ModuleType | None": +def _get_submodule(package: ModuleType, paths: list[str]) -> "type[PBM] | ModuleType | None": """Import module/class from package Args: @@ -357,20 +348,22 @@ def get_message_class_from_topic(topic: str) -> "type[ProtoBunnyMessage] | None """Return the message class from a topic with lazy import of the user library Args: - topic: the RabbitMQ topic that represents the message queue + topic: the topic that represents the message queue, mapped to the message class + example for redis mylib:tasks:TaskMessage -> mylib.tasks.TaskMessage class - Returns: the message class + Returns: the message class for the topic or None if the topic is not recognized """ - delimiter = default_configuration.backend_config.topic_delimiter + delimiter = config.backend_config.topic_delimiter if topic.endswith(f"{delimiter}result"): message_type = Result else: - route = topic.removeprefix(f"{default_configuration.messages_prefix}{delimiter}") - if route == topic: - # Allow pb.* internal messages + route = topic.removeprefix(f"{config.messages_prefix}{delimiter}") + if route == topic: # the prefix is not present in the topic + # Try if it's a protobunny class + # to allow pb.* internal messages like pb.results.Result route = topic.removeprefix(f"pb{delimiter}") - codegen_module = importlib.import_module(default_configuration.generated_package_name) - # if route is not recognized, the message_type will be None + codegen_module = importlib.import_module(config.generated_package_name) + # if route is not recognized at this point, the message_type will be None message_type = _get_submodule(codegen_module, route.split(delimiter)) return message_type @@ -385,9 +378,9 @@ def get_message_class_from_type_url(url: str) -> type["ProtoBunnyMessage"]: Returns: the message class """ module_path, clz = url.rsplit(".", 1) - if not module_path.startswith(default_configuration.generated_package_name): + if not module_path.startswith(config.generated_package_name): raise ValueError( - f"Invalid type url {url}, must start with {default_configuration.generated_package_name}." + f"Invalid type url {url}, must start with {config.generated_package_name}." ) module = importlib.import_module(module_path) message_type = getattr(module, clz) @@ -458,7 +451,7 @@ def __init__(self, topic: str): Args: topic: a Topic value object """ - delimiter = default_configuration.backend_config.topic_delimiter + delimiter = config.backend_config.topic_delimiter self.topic: str = topic.replace(".", delimiter) self.subscription: str | None = None self.result_subscription: str | None = None @@ -466,7 +459,7 @@ def __init__(self, topic: str): @property def result_topic(self) -> str: - return f"{self.topic}{default_configuration.backend_config.topic_delimiter}result" + return f"{self.topic}{config.backend_config.topic_delimiter}result" @abstractmethod def get_tag(self) -> str: @@ -559,7 +552,7 @@ def get_body(message: "IncomingMessageProtocol") -> str: """ msg: ProtoBunnyMessage | None body: str | bytes - delimiter = default_configuration.backend_config.topic_delimiter + delimiter = config.backend_config.topic_delimiter if message.routing_key and message.routing_key.endswith(f"{delimiter}result"): # log result message. Need to extract the source here result = deserialize_result_message(message.body) @@ -584,5 +577,12 @@ def get_body(message: "IncomingMessageProtocol") -> str: return str(body) +# - types +PBM = tp.TypeVar("PBM", bound=ProtoBunnyMessage) +SyncCallback: tp.TypeAlias = tp.Callable[[PBM], tp.Any] +AsyncCallback: tp.TypeAlias = tp.Callable[[PBM], tp.Coroutine[tp.Any, tp.Any, None] | tp.Any] +ResultCallback: tp.TypeAlias = tp.Callable[["Result"], tp.Any] +LoggerCallback: tp.TypeAlias = tp.Callable[[tp.Any, str], tp.Any] + from .core.commons import JsonContent from .core.results import Result, ReturnCode diff --git a/protobunny/registry.py b/protobunny/registry.py index 02cd9f0..912a928 100644 --- a/protobunny/registry.py +++ b/protobunny/registry.py @@ -77,4 +77,4 @@ def unregister_all_results(self): # Then create a default instance -default_registry = SubscriptionRegistry() +registry = SubscriptionRegistry() diff --git a/protobunny/wrapper.py b/protobunny/wrapper.py index 699aaf7..f8ea798 100644 --- a/protobunny/wrapper.py +++ b/protobunny/wrapper.py @@ -64,7 +64,7 @@ sys.exit(config_error) -from .config import load_config +from .conf import load_config from .logger import log_callback, start_logger, start_logger_sync diff --git a/pyproject.toml b/pyproject.toml index fe58855..a26ea89 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,24 +3,23 @@ name = "protobunny" classifiers = [ "Development Status :: 3 - Alpha" ] -version = "0.1.2a1" -description = "A type-safe, sync/async Python messaging framework." +version = "0.1.2a2" +description = "A type-safe, sync/async Python messaging library." authors = [ {name = "Domenico Nappo", email = "domenico.nappo@am-flow.com"}, {name = "Sander Koelstra", email = "sander.koelstra@am-flow.com.com"}, {name = "Sem Mulder", email = "sem.mulder@am-flow.com"} ] keywords = [ + "messaging", "protobuf", "queues", "mqtt", - "messaging", + "amqp", "rabbitmq", "redis", - "amqp", - "aio-pika", - "betterproto", - "mosquitto" + "mosquitto", + "nats" ] requires-python = ">=3.10,<3.14" readme = "README.md" @@ -29,7 +28,6 @@ dependencies = [ "can-ada>=2.0.0", "click>=8.2.1", "grpcio-tools>=1.62.0,<2", - "pytest-env>=1.1.3", "tomli ; python_version < '3.11'", "typing-extensions; python_version < '3.11'" ] @@ -38,9 +36,8 @@ dependencies = [ rabbitmq = ["aio-pika>=7.1.0"] redis = ["redis<8"] numpy = ["numpy>=1.26"] -mosquitto = [ - "aiomqtt>=2.4.0" -] +mosquitto = ["aiomqtt>=2.4.0"] +nats = ["nats-py>=2.12.0"] [project.scripts] protobunny = "protobunny.wrapper:main" @@ -60,6 +57,7 @@ dev = [ "waiting>=1.5.0,<2", "ipython>=8.11.0,<9", "pytest>=7.2.2,<8", + "pytest-env>=1.1.3", "pytest-cov>=4.0.0,<5", "pytest-mock>=3.14.0,<4", "pytest-asyncio>=0.23.7,<0.24", @@ -97,6 +95,7 @@ force-required-fields = true backend = "redis" mode = "async" log-task-in-redis = true +log-task-in-nats = true [tool.pytest.ini_options] addopts = "--junitxml=pytest_report.xml --cov=. --cov-report=term --cov-report=xml --cov-report html --no-success-flaky-report" @@ -136,4 +135,4 @@ skip-magic-trailing-comma = false line-ending = "auto" [tool.setuptools] -package-data = {"protobunny" = ["protobunny/protobuf/*.proto", "protobunny/__init__.py.j2", "scripts/*.py"]} +package-data = {"protobunny" = ["protobunny/protobuf/*.proto", "protobunny/__init__.py.j2", "protobunny/asyncio/__init__.py.j2", "scripts/*.py"]} diff --git a/scripts/convert_md.py b/scripts/convert_md.py deleted file mode 100644 index f7bb5e6..0000000 --- a/scripts/convert_md.py +++ /dev/null @@ -1,6 +0,0 @@ -import pypandoc - - -if __name__ == "__main__": - pypandoc.convert_file("README.md", "rst", outputfile="docs/source/readme.rst") - pypandoc.convert_file("QUICK_START.md", "rst", outputfile="docs/source/quick_start_rst.rst") diff --git a/setup.py b/setup.py index 5362ffe..6331cf7 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "protobunny"))) import typing as tp -from config import ( +from conf import ( GENERATED_PACKAGE_NAME, PACKAGE_NAME, PROJECT_NAME, @@ -97,7 +97,7 @@ def run(self) -> None: cmdclass={ "install": GenerateProtoCommand, }, - python_requires=">=3.10,<3.13", + python_requires=">=3.10,<3.14", description="Protobuf messages and python mqtt messaging toolkit", entry_points={ "console_scripts": [ diff --git a/tests/conftest.py b/tests/conftest.py index 2ea3c68..7a9fbc8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ import aio_pika import aiormq import fakeredis +import nats import pamqp import pytest import pytest_asyncio @@ -21,8 +22,8 @@ @pytest_asyncio.fixture -def test_config() -> protobunny.config.Config: - conf = protobunny.config.Config( +def test_config() -> protobunny.conf.Config: + conf = protobunny.conf.Config( messages_directory="tests/proto", messages_prefix="acme", generated_package_name="tests", @@ -31,7 +32,9 @@ def test_config() -> protobunny.config.Config: force_required_fields=True, mode="async", backend="rabbitmq", - log_task_in_redis=False, + log_task_in_redis=True, + use_tasks_in_nats=True, + log_task_in_nats=True, ) return conf @@ -45,6 +48,22 @@ def __init__(self): self.subscribe = AsyncMock() self.unsubscribe = AsyncMock() self.publish = AsyncMock() + self._is_connected = False + + def is_connected(self) -> bool: + return self._is_connected + + def connect(self, host, port, keepalive): + self._is_connected = True + + def disconnect(self): + self._is_connected = False + + def loop_start(self): + return True + + def loop_stop(self): + return True async def __aenter__(self): return self @@ -71,6 +90,18 @@ async def mock_mosquitto(mocker) -> tp.AsyncGenerator[MockMQTTConnection, None]: yield mock_client +@pytest.fixture +async def mock_sync_mosquitto(mocker) -> tp.AsyncGenerator[MockMQTTConnection, None]: + mock_client = MockMQTTConnection() + mocker.spy(mock_client, "publish") + mocker.spy(mock_client, "subscribe") + mocker.spy(mock_client, "unsubscribe") + + with patch("paho.mqtt.client.Client") as mock_client_class: + mock_client_class.return_value = mock_client + yield mock_client + + @pytest.fixture async def mock_redis_client(mocker) -> tp.AsyncGenerator[fakeredis.FakeAsyncRedis, None]: server = fakeredis.FakeServer() @@ -85,6 +116,21 @@ async def mock_redis_client(mocker) -> tp.AsyncGenerator[fakeredis.FakeAsyncRedi await client.aclose() +@pytest.fixture +async def mock_nats(mocker): + mock_nc = AsyncMock(spec=nats.aio.client.Client) + mock_js = AsyncMock() + mock_nc.jetstream.return_value = mock_js + mock_jsm = AsyncMock() + mock_nc.jsm.return_value = mock_jsm + mocker.patch("nats.connect", return_value=mock_nc) + yield { + "client": mock_nc, + "js": mock_js, + "jsm": mock_jsm, + } + + @pytest.fixture async def mock_aio_pika(): """Mocks the entire aio_pika connection chain.""" @@ -120,7 +166,7 @@ async def mock_aio_pika(): def mock_sync_rmq_connection(mocker: MockerFixture) -> tp.Generator[MagicMock, None, None]: mock = mocker.MagicMock(spec=rabbitmq_connection.Connection) mocker.patch("protobunny.backends.BaseSyncQueue.get_connection", return_value=mock) - mocker.patch("protobunny.backends.rabbitmq.connection.connect", return_value=mock) + mocker.patch("protobunny.connect", return_value=mock) yield mock @@ -135,7 +181,7 @@ async def mock_rmq_connection(mocker: MockerFixture) -> tp.AsyncGenerator[AsyncM def mock_sync_redis_connection(mocker: MockerFixture) -> tp.Generator[MagicMock, None, None]: mock = mocker.MagicMock(spec=redis_connection.Connection) mocker.patch("protobunny.backends.BaseSyncQueue.get_connection", return_value=mock) - mocker.patch("protobunny.backends.rabbitmq.connection.connect", return_value=mock) + mocker.patch("protobunny.connect", return_value=mock) yield mock @@ -150,7 +196,7 @@ async def mock_redis_connection(mocker: MockerFixture) -> tp.AsyncGenerator[Asyn def mock_sync_mqtt_connection(mocker: MockerFixture) -> tp.Generator[MagicMock, None, None]: mock = mocker.MagicMock(spec=mosquitto_connection.Connection) mocker.patch("protobunny.backends.BaseSyncQueue.get_connection", return_value=mock) - mocker.patch("protobunny.backends.mosquitto.connection.connect", return_value=mock) + mocker.patch("protobunny.connect", return_value=mock) yield mock diff --git a/tests/test_base.py b/tests/test_base.py index ae437b8..8c52c24 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -23,11 +23,11 @@ def setup_config(mocker, test_config) -> None: test_config.mode = "async" test_config.backend = "python" - mocker.patch.object(pb_sync.config, "default_configuration", test_config) - mocker.patch.object(pb_sync.models, "default_configuration", test_config) - mocker.patch.object(pb_sync.helpers, "default_configuration", test_config) - mocker.patch.object(pb.backends, "default_configuration", test_config) - mocker.patch.object(python_backend.connection, "default_configuration", test_config) + mocker.patch.object(pb_sync.conf, "config", test_config) + mocker.patch.object(pb_sync.models, "config", test_config) + mocker.patch.object(pb_sync.helpers, "config", test_config) + mocker.patch.object(pb.backends, "config", test_config) + mocker.patch.object(python_backend.connection, "config", test_config) def test_json_serializer() -> None: diff --git a/tests/test_config.py b/tests/test_config.py index 63983be..1504f61 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -2,7 +2,7 @@ import pytest -from protobunny.config import load_config +from protobunny.conf import load_config @pytest.fixture(autouse=True) @@ -48,7 +48,7 @@ def test_config_from_ini(tmp_path, monkeypatch): (tmp_path / "protobunny.ini").write_text(ini_content) # Ensure pyproject doesn't interfere - with mock.patch("protobunny.config.get_config_from_pyproject", return_value={}): + with mock.patch("protobunny.conf.get_config_from_pyproject", return_value={}): config = load_config() assert config.backend == "python" assert config.mode == "async" @@ -69,7 +69,7 @@ def test_config_precedence(tmp_path, monkeypatch): # Mock pyproject to provide a base mock_pyproject = {"backend": "rabbitmq", "messages_prefix": "py_pref", "project-name": "test"} - with mock.patch("protobunny.config.get_config_from_pyproject", return_value=mock_pyproject): + with mock.patch("protobunny.conf.get_config_from_pyproject", return_value=mock_pyproject): config = load_config() # Overridden by ENV diff --git a/tests/test_connection.py b/tests/test_connection.py index 0e832e4..59c9d7f 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -10,13 +10,12 @@ import protobunny as pb_base from protobunny import RequeueMessage from protobunny import asyncio as pb -from protobunny.asyncio.backends import mosquitto as mosquitto_backend_aio -from protobunny.asyncio.backends import python as python_backend_aio from protobunny.asyncio.backends import rabbitmq as rabbitmq_backend_aio from protobunny.asyncio.backends import redis as redis_backend_aio from protobunny.backends import python as python_backend -from protobunny.config import backend_configs +from protobunny.conf import backend_configs from protobunny.helpers import ( + get_backend, get_queue, ) from protobunny.models import Envelope, IncomingMessageProtocol @@ -33,51 +32,58 @@ @pytest.mark.parametrize( - "backend", [rabbitmq_backend_aio, redis_backend_aio, python_backend_aio, mosquitto_backend_aio] + "backend", + [ + "rabbitmq", + "redis", + "python", + "mosquitto", + "nats", + ], ) @pytest.mark.asyncio class TestConnection: @pytest.fixture(autouse=True) async def mock_connections( - self, backend, mocker, mock_redis_client, mock_aio_pika, mock_mosquitto, test_config + self, + backend: str, + mocker, + mock_redis_client, + mock_aio_pika, + mock_mosquitto, + mock_nats, + test_config, ) -> tp.AsyncGenerator[dict[str, AsyncMock | None], None]: - backend_name = backend.__name__.split(".")[-1] - test_config.mode = "async" - test_config.backend = backend_name + test_config.backend = backend test_config.log_task_in_redis = True - test_config.backend_config = backend_configs[backend_name] - mocker.patch.object(pb_base.config, "default_configuration", test_config) - mocker.patch.object(pb_base.models, "default_configuration", test_config) - mocker.patch.object(pb_base.backends, "default_configuration", test_config) - mocker.patch.object(pb_base.helpers, "default_configuration", test_config) - mocker.patch.object(pb.backends, "default_configuration", test_config) - connection_module = getattr(pb.backends, backend_name).connection - if hasattr(connection_module, "default_configuration"): - mocker.patch.object(connection_module, "default_configuration", test_config) - - assert pb_base.helpers.get_backend() == backend - assert pb.get_backend() == backend - assert isinstance(get_queue(tests.tasks.TaskMessage), backend.queues.AsyncQueue) + test_config.backend_config = backend_configs[backend] + mocker.patch.object(pb_base.conf, "config", test_config) + mocker.patch.object(pb_base.models, "config", test_config) + mocker.patch.object(pb_base.backends, "config", test_config) + mocker.patch.object(pb_base.helpers, "config", test_config) + mocker.patch.object(pb.backends, "config", test_config) + connection_module = getattr(pb.backends, backend).connection + if hasattr(connection_module, "config"): + mocker.patch.object(connection_module, "config", test_config) + + backend_module = get_backend() + assert pb.get_backend() == backend_module + assert isinstance(get_queue(tests.tasks.TaskMessage), backend_module.queues.AsyncQueue) conn_with_fake_internal_conn = get_mocked_connection( - backend, mock_redis_client, mock_aio_pika, mocker, mock_mosquitto + backend_module, mock_redis_client, mock_aio_pika, mocker, mock_mosquitto, mock_nats ) - mocker.patch.object(pb, "connect", return_value=conn_with_fake_internal_conn) - mocker.patch.object(pb, "disconnect", side_effect=backend.connection.disconnect) mocker.patch( "protobunny.asyncio.backends.BaseAsyncQueue.get_connection", return_value=conn_with_fake_internal_conn, ) - mocker.patch( - f"protobunny.asyncio.backends.{backend_name}.connection.connect", - return_value=conn_with_fake_internal_conn, - ) yield { "rabbitmq": mock_aio_pika, "redis": mock_redis_client, "connection": conn_with_fake_internal_conn, "python": conn_with_fake_internal_conn._connection, "mosquitto": mock_mosquitto, + "nats": mock_nats, } connection_module.Connection.instance_by_vhost.clear() @@ -87,8 +93,7 @@ async def mock_connection(self, mock_connections, backend): @pytest.fixture async def mock_internal_connection(self, mock_connections, backend): - backend_name = backend.__name__.split(".")[-1] - yield mock_connections[backend_name] + yield mock_connections[backend] async def test_connection_success( self, mock_connection: MagicMock, mock_internal_connection, backend @@ -118,9 +123,8 @@ async def test_publish_tasks( conn = await mock_connection.connect() incoming = incoming_message_factory(backend) - backend_name = backend.__name__.split(".")[-1] topic = "test.tasks.key" - delimiter = backend_configs[backend_name].topic_delimiter + delimiter = backend_configs[backend].topic_delimiter topic = topic.replace(".", delimiter) await conn.publish(topic, incoming) @@ -137,6 +141,8 @@ async def test_publish_tasks( async def test_publish(self, mock_connection: MagicMock, mock_internal_connection, backend): topic = "test.routing.key" + delimiter = backend_configs[backend].topic_delimiter + topic = topic.replace(".", delimiter) conn = await mock_connection.connect() msg = None incoming = incoming_message_factory(backend) @@ -164,8 +170,8 @@ async def predicate(): @pytest.mark.asyncio async def test_singleton_logic(self, backend): - conn1 = await backend.connection.connect() - conn2 = await backend.connection.connect() + conn1 = await pb.connect() + conn2 = await pb.connect() assert conn1 is conn2 await conn1.disconnect() @@ -283,3 +289,7 @@ async def test_get_message_count(mock_redis_client): await conn.publish("test:tasks:topic", Envelope(body=b"test message 3")) count = await conn.get_message_count("test:tasks:topic") assert count == 3 + + +# --- Specific Tests for NATS --- +# TODO diff --git a/tests/test_integration.py b/tests/test_integration.py index d072092..ebfca10 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,4 +1,5 @@ -import gc +import asyncio +import importlib import logging import typing as tp @@ -7,17 +8,10 @@ import pytest from pytest_mock import MockerFixture -import protobunny as pb_sync +import protobunny as pb_base from protobunny import asyncio as pb -from protobunny.asyncio.backends import mosquitto as mosquitto_backend_aio -from protobunny.asyncio.backends import python as python_backend_aio -from protobunny.asyncio.backends import rabbitmq as rabbitmq_backend_aio -from protobunny.asyncio.backends import redis as redis_backend_aio -from protobunny.backends import mosquitto as mosquitto_backend -from protobunny.backends import python as python_backend -from protobunny.backends import rabbitmq as rabbitmq_backend -from protobunny.backends import redis as redis_backend -from protobunny.config import Config, backend_configs +from protobunny import get_backend +from protobunny.conf import Config, backend_configs from protobunny.models import ProtoBunnyMessage from . import tests @@ -86,51 +80,57 @@ def log_callback(message: aio_pika.IncomingMessage, body: str) -> str: @pytest.mark.integration @pytest.mark.parametrize( - "backend", [rabbitmq_backend_aio, redis_backend_aio, python_backend_aio, mosquitto_backend_aio] + "backend", + [ + "rabbitmq", + "redis", + "python", + "mosquitto", + "nats", + ], ) class TestIntegration: - """Integration tests (to run with RabbitMQ up)""" + """Integration tests (to run with all the backends up) + + For a specific backend, run python -m pytest tests/test_integration.py -k redis + """ msg = tests.TestMessage(content="test", number=123, color=tests.Color.GREEN) @pytest.fixture(autouse=True) async def setup_test_env( - self, mocker: MockerFixture, test_config: Config, backend + self, mocker: MockerFixture, test_config: Config, backend: str ) -> tp.AsyncGenerator[None, None]: - backend_name = backend.__name__.split(".")[-1] test_config.mode = "async" - test_config.backend = backend_name + test_config.backend = backend test_config.log_task_in_redis = True - test_config.backend_config = backend_configs[backend_name] + test_config.backend_config = backend_configs[backend] self.topic_delimiter = test_config.backend_config.topic_delimiter # Patch global configuration for all modules that use it - mocker.patch.object(pb_sync.config, "default_configuration", test_config) - mocker.patch.object(pb_sync.models, "default_configuration", test_config) - mocker.patch.object(pb_sync.helpers, "default_configuration", test_config) - - # patch the asyncio modules - mocker.patch.object(pb.backends, "default_configuration", test_config) - mocker.patch.object(pb, "default_configuration", test_config) - if hasattr(backend.connection, "default_configuration"): - mocker.patch.object(backend.connection, "default_configuration", test_config) - if hasattr(backend.queues, "default_configuration"): - mocker.patch.object(backend.queues, "default_configuration", test_config) - - pb.backend = backend - mocker.patch("protobunny.asyncio.backends.get_backend", return_value=backend) - mocker.patch.object(pb, "connect", backend.connection.connect) - mocker.patch.object(pb, "disconnect", backend.connection.disconnect) - mocker.patch.object(pb, "get_backend", return_value=backend) - + mocker.patch.object(pb_base.conf, "config", test_config) + mocker.patch.object(pb_base.models, "config", test_config) + mocker.patch.object(pb_base.helpers, "config", test_config) + mocker.patch.object(pb.backends, "config", test_config) + mocker.patch.object(pb, "config", test_config) + backend_module = get_backend() + assert backend_module.__name__.split(".")[-1] == backend + + if hasattr(backend_module.connection, "config"): + mocker.patch.object(backend_module.connection, "config", test_config) + if hasattr(backend_module.queues, "config"): + mocker.patch.object(backend_module.queues, "config", test_config) + + pb.backend = backend_module + mocker.patch.object(pb, "get_backend", return_value=backend_module) # Assert the patching is working for setting the backend connection = await pb.connect() - assert isinstance(connection, backend.connection.Connection) + assert isinstance(connection, backend_module.connection.Connection) queue = pb.get_queue(self.msg) assert queue.topic == "acme.tests.TestMessage".replace( ".", test_config.backend_config.topic_delimiter ) - assert isinstance(queue, backend.queues.AsyncQueue) - assert isinstance(await queue.get_connection(), backend.connection.Connection) + assert isinstance(queue, backend_module.queues.AsyncQueue) + assert isinstance(await queue.get_connection(), backend_module.connection.Connection) # start without pending subscriptions await pb.unsubscribe_all(if_unused=False, if_empty=False) yield @@ -142,9 +142,8 @@ async def setup_test_env( "result": None, "task": None, } - await connection.disconnect() - backend.connection.Connection.instance_by_vhost.clear() - gc.collect() + await pb.disconnect() + backend_module.connection.Connection.instance_by_vhost.clear() @pytest.mark.flaky(max_runs=3) async def test_publish(self, backend) -> None: @@ -178,7 +177,7 @@ async def predicate() -> bool: ) == '{"content": "test", "number": 123, "detail": null, "options": null, "color": "GREEN"}' ) - await pb.subscribe(tests.tasks.TaskMessage, callback) + await pb.subscribe(tests.tasks.TaskMessage, callback_task) msg = tests.tasks.TaskMessage( content="test", bbox=[1, 2, 3, 4], @@ -186,10 +185,11 @@ async def predicate() -> bool: await pb.publish(msg) async def predicate() -> bool: - return received["message"] == msg + await asyncio.sleep(0) + return received["task"] == msg assert await async_wait(predicate, timeout=1, sleep=0.1) - assert received["message"].to_dict( + assert received["task"].to_dict( casing=betterproto.Casing.SNAKE, include_default_values=True ) == { "content": "test", @@ -197,8 +197,8 @@ async def predicate() -> bool: "weights": [], "options": None, } - # to_pydict uses enum names and don't stringyfies int64 - assert received["message"].to_pydict( + # to_pydict uses enum names and don't stringifies int64 + assert received["task"].to_pydict( casing=betterproto.Casing.SNAKE, include_default_values=True ) == { "content": "test", @@ -209,8 +209,7 @@ async def predicate() -> bool: @pytest.mark.flaky(max_runs=3) async def test_count_messages(self, backend) -> None: - backend_name = backend.__name__.split(".")[-1] - if backend_name == "mosquitto": + if backend == "mosquitto": pytest.skip("mosquitto backend doesn't support message counts") task_queue = await pb.subscribe(tests.tasks.TaskMessage, callback) msg = tests.tasks.TaskMessage(content="test", bbox=[1, 2, 3, 4]) @@ -224,7 +223,7 @@ async def predicate() -> bool: assert await async_wait( predicate - ), f"Messages were not in the queue: {await task_queue.get_message_count()}" + ), f"Message count should be 0: {await task_queue.get_message_count()}" # we unsubscribe so the published messages # won't be consumed and stay in the queue await pb.unsubscribe(tests.tasks.TaskMessage, if_unused=False, if_empty=False) @@ -368,7 +367,7 @@ async def predicate() -> bool: assert await async_wait(predicate, timeout=2, sleep=0.1) assert received["message"] == received2 == self.msg - await pb.unsubscribe_all() + await pb.unsubscribe_all(if_empty=False, if_unused=False) received["message"] = None received2 = None await pb.publish(self.msg) @@ -379,12 +378,12 @@ async def predicate() -> bool: async def test_unsubscribe_results(self, backend) -> None: received_result: pb.results.Result | None = None - def callback_test(_: tests.TestMessage) -> None: + async def callback_test(_: tests.TestMessage) -> None: # The receiver catches the error in callback and will send a Result.FAILURE message # to the result topic raise RuntimeError("error in callback") - def callback_results(m: pb.results.Result) -> None: + async def callback_results(m: pb.results.Result) -> None: nonlocal received_result received_result = m @@ -423,7 +422,7 @@ async def callback_results(m: pb.results.Result) -> None: nonlocal received_result received_result = m - await pb.unsubscribe_all() + await pb.unsubscribe_all(if_unused=False, if_empty=False) q1 = await pb.subscribe(tests.TestMessage, callback_1) q2 = await pb.subscribe(tests.tasks.TaskMessage, callback_2) assert q1.topic == "acme.tests.TestMessage".replace(".", self.topic_delimiter) @@ -441,7 +440,7 @@ async def predicate() -> bool: assert await async_wait(predicate, timeout=2, sleep=0.2) assert received_result.source == tests.TestMessage(number=2, content="test") - await pb.unsubscribe_all() + await pb.unsubscribe_all(if_empty=False, if_unused=False) received_result = None received_message = None await pb.publish(tests.tasks.TaskMessage(content="test", bbox=[1, 2, 3, 4])) @@ -451,54 +450,59 @@ async def predicate() -> bool: @pytest.mark.integration -@pytest.mark.parametrize( - "backend", [rabbitmq_backend, redis_backend, python_backend, mosquitto_backend] -) +@pytest.mark.parametrize("backend", ["rabbitmq", "redis", "python", "mosquitto", "nats"]) class TestIntegrationSync: """Integration tests (to run with the backend server up)""" msg = tests.TestMessage(content="test", number=123, color=tests.Color.GREEN) + expected_json = ( + '{"content": "test", "number": 123, "detail": null, "options": null, "color": "GREEN"}' + ) @pytest.fixture(autouse=True) def setup_test_env( - self, mocker: MockerFixture, test_config: Config, backend + self, mocker: MockerFixture, test_config: Config, backend: str ) -> tp.Generator[None, None, None]: - backend_name = backend.__name__.split(".")[-1] test_config.mode = "sync" - test_config.backend = backend_name + test_config.backend = backend test_config.log_task_in_redis = True - test_config.backend_config = backend_configs[backend_name] + test_config.backend_config = backend_configs[backend] self.topic_delimiter = test_config.backend_config.topic_delimiter # Patch global configuration for all modules that use it - mocker.patch.object(pb_sync.config, "default_configuration", test_config) - mocker.patch.object(pb_sync.models, "default_configuration", test_config) - mocker.patch.object(pb_sync.backends, "default_configuration", test_config) - mocker.patch.object(pb_sync.helpers, "default_configuration", test_config) - mocker.patch.object(pb.backends.redis.connection, "default_configuration", test_config) - if hasattr(backend.connection, "default_configuration"): - mocker.patch.object(backend.connection, "default_configuration", test_config) - if hasattr(backend.queues, "default_configuration"): - mocker.patch.object(backend.queues, "default_configuration", test_config) - - pb_sync.backend = backend - # mocker.patch("protobunny.helpers.get_backend", return_value=backend) - mocker.patch.object(pb_sync.helpers, "get_backend", return_value=backend) - mocker.patch.object(pb_sync, "connect", backend.connection.connect) - mocker.patch.object(pb_sync, "disconnect", backend.connection.disconnect) - mocker.patch.object(pb_sync, "get_backend", return_value=backend) + mocker.patch.object(pb_base.conf, "config", test_config) + mocker.patch.object(pb_base.models, "config", test_config) + mocker.patch.object(pb_base.backends, "config", test_config) + mocker.patch.object(pb.backends, "config", test_config) + mocker.patch.object(pb_base.helpers, "config", test_config) + + backend_module = get_backend() + assert backend_module.__name__.split(".")[-1] == backend + async_backend_module = importlib.import_module(f"protobunny.asyncio.backends.{backend}") + if hasattr(async_backend_module.connection, "config"): + # The sync connection is often implemented as a wrapper of the relative async module. Patch the config of the async module as well + mocker.patch.object(async_backend_module.connection, "config", test_config) + if hasattr(backend_module.connection, "config"): + mocker.patch.object(backend_module.connection, "config", test_config) + if hasattr(backend_module.queues, "config"): + mocker.patch.object(backend_module.queues, "config", test_config) + + mocker.patch.object(pb_base.helpers, "get_backend", return_value=backend_module) + mocker.patch.object(pb_base, "get_backend", return_value=backend_module) # Assert the patching is working for setting the backend - connection = pb_sync.connect() - assert isinstance(connection, backend.connection.Connection) - queue = pb_sync.get_queue(self.msg) + connection = pb_base.connect() + assert isinstance(connection, backend_module.connection.Connection) + queue = pb_base.get_queue(self.msg) assert queue.topic == "acme.tests.TestMessage".replace( ".", test_config.backend_config.topic_delimiter ) - assert isinstance(queue, backend.queues.SyncQueue) - assert isinstance(queue.get_connection(), backend.connection.Connection) + assert isinstance(queue, backend_module.queues.SyncQueue) + assert isinstance(queue.get_connection(), backend_module.connection.Connection) + # Setup # start without pending subscriptions - pb_sync.unsubscribe_all(if_unused=False, if_empty=False) + pb_base.unsubscribe_all(if_unused=False, if_empty=False) yield + # Teardown # reset the variables holding the messages received global received received = { @@ -508,22 +512,22 @@ def setup_test_env( "task": None, } connection.disconnect() - backend.connection.Connection.instance_by_vhost.clear() + backend_module.connection.Connection.instance_by_vhost.clear() @pytest.mark.flaky(max_runs=3) def test_publish(self, backend) -> None: global received - pb_sync.subscribe(tests.TestMessage, callback_sync) - pb_sync.publish(self.msg) + pb_base.subscribe(tests.TestMessage, callback_sync) + pb_base.publish(self.msg) assert sync_wait(lambda: received["message"] is not None) assert received["message"].number == self.msg.number @pytest.mark.flaky(max_runs=3) def test_to_dict(self, backend) -> None: global received - pb_sync.subscribe(tests.TestMessage, callback_sync) - pb_sync.subscribe(tests.tasks.TaskMessage, callback_task_sync) - pb_sync.publish(self.msg) + pb_base.subscribe(tests.TestMessage, callback_sync) + pb_base.subscribe(tests.tasks.TaskMessage, callback_task_sync) + pb_base.publish(self.msg) assert sync_wait(lambda: received["message"] == self.msg) assert received["message"].to_dict( casing=betterproto.Casing.SNAKE, include_default_values=True @@ -532,15 +536,15 @@ def test_to_dict(self, backend) -> None: received["message"].to_json( casing=betterproto.Casing.SNAKE, include_default_values=True ) - == '{"content": "test", "number": 123, "detail": null, "options": null, "color": "GREEN"}' + == self.expected_json ) msg = tests.tasks.TaskMessage( content="test", bbox=[1, 2, 3, 4], ) - pb_sync.publish(msg) - assert sync_wait(lambda: received["task"] == msg) + pb_base.publish(msg) + assert sync_wait(lambda: received["task"] is not None) assert received["task"].to_dict( casing=betterproto.Casing.SNAKE, include_default_values=True ) == { @@ -561,55 +565,45 @@ def test_to_dict(self, backend) -> None: @pytest.mark.flaky(max_runs=3) def test_count_messages(self, backend) -> None: - backend_name = backend.__name__.split(".")[-1] - if backend_name == "mosquitto": + if backend == "mosquitto": pytest.skip("mosquitto backend doesn't support message counts") # Subscribe to a tasks topic (shared queue) - task_queue = pb_sync.subscribe(tests.tasks.TaskMessage, callback_task_sync) + task_queue = pb_base.subscribe(tests.tasks.TaskMessage, callback_task_sync) msg = tests.tasks.TaskMessage(content="test", bbox=[1, 2, 3, 4]) - connection = pb_sync.connect() + connection = pb_base.connect() # remove past messages connection.purge(task_queue.topic, reset_groups=True) # we unsubscribe so the published messages # won't be consumed and stay in the queue task_queue.unsubscribe() assert sync_wait(lambda: 0 == task_queue.get_consumer_count()) - pb_sync.publish(msg) - pb_sync.publish(msg) - pb_sync.publish(msg) + pb_base.publish(msg) + pb_base.publish(msg) + pb_base.publish(msg) # and we can count them assert sync_wait(lambda: 3 == task_queue.get_message_count()) @pytest.mark.flaky(max_runs=3) def test_logger_body(self, backend) -> None: - pb_sync.subscribe_logger(log_callback) + pb_base.subscribe_logger(log_callback) topic = "acme.tests.TestMessage".replace(".", self.topic_delimiter) topic_result = "acme.tests.TestMessage.result".replace(".", self.topic_delimiter) - pb_sync.publish(self.msg) + pb_base.publish(self.msg) assert sync_wait(lambda: isinstance(received["log"], str)) - assert ( - received["log"] - == f'{topic}: {{"content": "test", "number": 123, "detail": null, "options": null, "color": "GREEN"}}' - ) + assert received["log"] == f"{topic}: {self.expected_json}" received["log"] = None result = self.msg.make_result() - pb_sync.publish_result(result) + pb_base.publish_result(result) assert sync_wait(lambda: isinstance(received["log"], str)) - assert ( - received["log"] - == f'{topic_result}: SUCCESS - {{"content": "test", "number": 123, "detail": null, "options": null, "color": "GREEN"}}' - ) + assert received["log"] == f"{topic_result}: SUCCESS - {self.expected_json}" result = self.msg.make_result( return_code=pb.results.ReturnCode.FAILURE, return_value={"test": "value"} ) received["log"] = None - pb_sync.publish_result(result) + pb_base.publish_result(result) assert sync_wait(lambda: isinstance(received["log"], str)) - assert ( - received["log"] - == f'{topic_result}: FAILURE - error: [] - {{"content": "test", "number": 123, "detail": null, "options": null, "color": "GREEN"}}' - ) + assert received["log"] == f"{topic_result}: FAILURE - error: [] - {self.expected_json}" @pytest.mark.flaky(max_runs=3) def test_logger_int64(self, backend) -> None: @@ -622,8 +616,8 @@ def callback_log(message, body: str) -> None: nonlocal log_msg log_msg = log_callback(message, body) - pb_sync.subscribe_logger(callback_log) - pb_sync.publish(tests.TestMessage(number=63, content="test")) + pb_base.subscribe_logger(callback_log) + pb_base.publish(tests.TestMessage(number=63, content="test")) def predicate() -> bool: return log_msg is not None @@ -638,7 +632,7 @@ def predicate() -> bool: ), log_msg # Ensure that uint64/int64 values are not converted to strings in the LoggerQueue callbacks log_msg = None - pb_sync.publish( + pb_base.publish( tests.tasks.TaskMessage( content="test", bbox=[1, 2, 3, 4], weights=[1.0, 2.0, -100, -20] ) @@ -653,22 +647,22 @@ def predicate() -> bool: @pytest.mark.flaky(max_runs=3) def test_unsubscribe(self, backend) -> None: global received - pb_sync.subscribe(tests.TestMessage, callback_sync) - pb_sync.publish(self.msg) + pb_base.subscribe(tests.TestMessage, callback_sync) + pb_base.publish(self.msg) assert sync_wait(lambda: received["message"] is not None) assert received["message"] == self.msg received["message"] = None - pb_sync.unsubscribe(tests.TestMessage, if_unused=False, if_empty=False) - pb_sync.publish(self.msg) + pb_base.unsubscribe(tests.TestMessage, if_unused=False, if_empty=False) + pb_base.publish(self.msg) assert received["message"] is None # unsubscribe from a package-level topic - pb_sync.subscribe(tests, callback_sync) - pb_sync.publish(tests.TestMessage(number=63, content="test")) + pb_base.subscribe(tests, callback_sync) + pb_base.publish(tests.TestMessage(number=63, content="test")) assert sync_wait(lambda: received["message"] is not None) received["message"] = None - pb_sync.unsubscribe(tests, if_unused=False, if_empty=False) - pb_sync.publish(self.msg) + pb_base.unsubscribe(tests, if_unused=False, if_empty=False) + pb_base.publish(self.msg) assert received["message"] is None # subscribe/unsubscribe two callbacks for two topics @@ -678,15 +672,15 @@ def callback_2(m: "ProtoBunnyMessage") -> None: nonlocal received2 received2 = m - pb_sync.subscribe(tests.TestMessage, callback_sync) - pb_sync.subscribe(tests, callback_2) - pb_sync.publish(self.msg) # this will reach callback_2 as well + pb_base.subscribe(tests.TestMessage, callback_sync) + pb_base.subscribe(tests, callback_2) + pb_base.publish(self.msg) # this will reach callback_2 as well assert sync_wait(lambda: received["message"] and received2) assert received["message"] == received2 == self.msg - pb_sync.unsubscribe_all() + pb_base.unsubscribe_all() received["message"] = None received2 = None - pb_sync.publish(self.msg) + pb_base.publish(self.msg) assert received["message"] is None assert received2 is None @@ -703,18 +697,18 @@ def callback_results_2(m: pb.results.Result) -> None: nonlocal received_result received_result = m - pb_sync.unsubscribe_all() - pb_sync.subscribe(tests.TestMessage, callback_2) + pb_base.unsubscribe_all() + pb_base.subscribe(tests.TestMessage, callback_2) # subscribe to the result topic - pb_sync.subscribe_results(tests.TestMessage, callback_results_2) + pb_base.subscribe_results(tests.TestMessage, callback_results_2) msg = tests.TestMessage(number=63, content="test") - pb_sync.publish(msg) + pb_base.publish(msg) assert sync_wait(lambda: received_result is not None) assert received_result.source == msg assert received_result.return_code == pb.results.ReturnCode.FAILURE - pb_sync.unsubscribe_results(tests.TestMessage) + pb_base.unsubscribe_results(tests.TestMessage) received_result = None - pb_sync.publish(msg) + pb_base.publish(msg) assert received_result is None @pytest.mark.flaky(max_runs=3) @@ -735,25 +729,24 @@ def callback_results_2(m: pb.results.Result) -> None: nonlocal received_result received_result = m - pb_sync.unsubscribe_all() - q1 = pb_sync.subscribe(tests.TestMessage, callback_1) - q2 = pb_sync.subscribe(tests.tasks.TaskMessage, callback_2) + q1 = pb_base.subscribe(tests.TestMessage, callback_1) + q2 = pb_base.subscribe(tests.tasks.TaskMessage, callback_2) assert q1.topic == "acme.tests.TestMessage".replace(".", self.topic_delimiter) assert q2.topic == "acme.tests.tasks.TaskMessage".replace(".", self.topic_delimiter) assert q1.subscription is not None assert q2.subscription is not None # subscribe to a result topic - pb_sync.subscribe_results(tests.TestMessage, callback_results_2) - pb_sync.publish(tests.TestMessage(number=2, content="test")) - pb_sync.publish(tests.tasks.TaskMessage(content="test", bbox=[1, 2, 3, 4])) - assert sync_wait(lambda: received_message is not None) - assert sync_wait(lambda: received_result is not None) + pb_base.subscribe_results(tests.TestMessage, callback_results_2) + pb_base.publish(tests.TestMessage(number=2, content="test")) + pb_base.publish(tests.tasks.TaskMessage(content="test", bbox=[1, 2, 3, 4])) + assert sync_wait(lambda: received_message is not None, timeout=2) + assert sync_wait(lambda: received_result is not None, timeout=2) assert received_result.source == tests.TestMessage(number=2, content="test") - pb_sync.unsubscribe_all() + pb_base.unsubscribe_all() received_result = None received_message = None - pb_sync.publish(tests.tasks.TaskMessage(content="test", bbox=[1, 2, 3, 4])) - pb_sync.publish(tests.TestMessage(number=2, content="test")) + pb_base.publish(tests.tasks.TaskMessage(content="test", bbox=[1, 2, 3, 4])) + pb_base.publish(tests.TestMessage(number=2, content="test")) assert received_message is None assert received_result is None diff --git a/tests/test_publish.py b/tests/test_publish.py index 8e5933d..d6d2e11 100644 --- a/tests/test_publish.py +++ b/tests/test_publish.py @@ -12,10 +12,10 @@ @pytest.fixture(autouse=True) def setup_config(mocker, test_config) -> None: test_config.mode = "sync" - mocker.patch.object(pb.config, "default_configuration", test_config) - mocker.patch.object(pb.models, "default_configuration", test_config) - mocker.patch.object(pb.backends, "default_configuration", test_config) - mocker.patch.object(pb.helpers, "default_configuration", test_config) + mocker.patch.object(pb.conf, "config", test_config) + mocker.patch.object(pb.models, "config", test_config) + mocker.patch.object(pb.backends, "config", test_config) + mocker.patch.object(pb.helpers, "config", test_config) def test_sync_send_message(mock_sync_rmq_connection: MagicMock) -> None: diff --git a/tests/test_queues.py b/tests/test_queues.py index 644a1c6..f7ae5f5 100644 --- a/tests/test_queues.py +++ b/tests/test_queues.py @@ -10,15 +10,17 @@ from protobunny import asyncio as pb from protobunny.asyncio.backends import LoggingAsyncQueue from protobunny.asyncio.backends import mosquitto as mosquitto_backend_aio +from protobunny.asyncio.backends import nats as nats_backend_aio from protobunny.asyncio.backends import python as python_backend_aio from protobunny.asyncio.backends import rabbitmq as rabbitmq_backend_aio from protobunny.asyncio.backends import redis as redis_backend_aio from protobunny.backends import LoggingSyncQueue from protobunny.backends import mosquitto as mosquitto_backend +from protobunny.backends import nats as nats_backend from protobunny.backends import python as python_backend from protobunny.backends import rabbitmq as rabbitmq_backend from protobunny.backends import redis as redis_backend -from protobunny.config import backend_configs +from protobunny.conf import backend_configs from protobunny.helpers import ( get_queue, ) @@ -30,12 +32,26 @@ @pytest.mark.parametrize( - "backend", [rabbitmq_backend_aio, redis_backend_aio, python_backend_aio, mosquitto_backend_aio] + "backend", + [ + rabbitmq_backend_aio, + redis_backend_aio, + python_backend_aio, + mosquitto_backend_aio, + nats_backend_aio, + ], ) class TestQueue: @pytest.fixture(autouse=True) async def mock_connections( - self, backend, mocker, mock_redis_client, mock_aio_pika, test_config, mock_mosquitto + self, + backend, + mocker, + mock_redis_client, + mock_aio_pika, + test_config, + mock_mosquitto, + mock_nats, ) -> tp.AsyncGenerator[dict, None]: backend_name = backend.__name__.split(".")[-1] test_config.mode = "async" @@ -43,20 +59,19 @@ async def mock_connections( test_config.log_task_in_redis = True test_config.backend_config = backend_configs[backend_name] - mocker.patch.object(pb_base.config, "default_configuration", test_config) - mocker.patch.object(pb_base.models, "default_configuration", test_config) - mocker.patch.object(pb_base.backends, "default_configuration", test_config) - mocker.patch.object(pb_base.helpers, "default_configuration", test_config) - mocker.patch.object(pb.backends, "default_configuration", test_config) + mocker.patch.object(pb_base.conf, "config", test_config) + mocker.patch.object(pb_base.models, "config", test_config) + mocker.patch.object(pb_base.backends, "config", test_config) + mocker.patch.object(pb_base.helpers, "config", test_config) + mocker.patch.object(pb.backends, "config", test_config) - if hasattr(backend.connection, "default_configuration"): - mocker.patch.object(backend.connection, "default_configuration", test_config) - if hasattr(backend.queues, "default_configuration"): - mocker.patch.object(backend.queues, "default_configuration", test_config) + if hasattr(backend.connection, "config"): + mocker.patch.object(backend.connection, "config", test_config) + if hasattr(backend.queues, "config"): + mocker.patch.object(backend.queues, "config", test_config) pb.backend = backend mocker.patch("protobunny.helpers.get_backend", return_value=backend) - mocker.patch.object(pb_base, "disconnect", backend.connection.disconnect) mocker.patch.object(pb_base, "get_backend", return_value=backend) assert pb_base.helpers.get_backend() == backend @@ -73,18 +88,17 @@ async def mock_connections( mock._connection = mock_aio_pika case "mosquitto": mock._connection = mock_mosquitto + case "nats": + mock._connection = mock_nats case "python": mock = backend.connection.Connection() mock.is_connected_event.set() mocker.patch.object(pb, "connect", return_value=mock) mocker.patch("protobunny.asyncio.backends.BaseAsyncQueue.get_connection", return_value=mock) - mocker.patch( - f"protobunny.asyncio.backends.{backend_name}.connection.connect", - return_value=mock, - ) - await pb.disconnect() + assert asyncio.iscoroutinefunction(pb.disconnect) + await pb.disconnect() yield { "rabbitmq": mock_aio_pika, @@ -92,12 +106,11 @@ async def mock_connections( "connection": mock, "python": mock._connection, "mosquitto": mock_mosquitto, + "nats": mock_nats, } connection_module = getattr(pb_base.backends, backend_name).connection connection_module.Connection.instance_by_vhost.clear() - connection_module.Connection.instance_by_vhost.clear() - @pytest.fixture async def mock_connection(self, mock_connections, backend, mocker): conn = mock_connections["connection"] @@ -230,12 +243,19 @@ async def test_logger(self, mock_connection: MagicMock, backend) -> None: @pytest.mark.parametrize( - "backend", [rabbitmq_backend, redis_backend, python_backend, mosquitto_backend] + "backend", [rabbitmq_backend, redis_backend, python_backend, mosquitto_backend, nats_backend] ) class TestSyncQueue: @pytest.fixture(autouse=True, scope="function") def mock_connection( - self, backend, mocker, mock_redis_client, mock_aio_pika, test_config + self, + backend, + mocker, + mock_redis_client, + mock_aio_pika, + test_config, + mock_nats, + mock_sync_mosquitto, ) -> tp.Generator[None, None, None]: backend_name = backend.__name__.split(".")[-1] test_config.mode = "sync" @@ -243,20 +263,19 @@ def mock_connection( test_config.log_task_in_redis = True test_config.backend_config = backend_configs[backend_name] - mocker.patch.object(pb_base.config, "default_configuration", test_config) - mocker.patch.object(pb_base.models, "default_configuration", test_config) - mocker.patch.object(pb_base.backends, "default_configuration", test_config) - mocker.patch.object(pb_base.helpers, "default_configuration", test_config) - mocker.patch.object(pb.backends.redis.connection, "default_configuration", test_config) + mocker.patch.object(pb_base.conf, "config", test_config) + mocker.patch.object(pb_base.models, "config", test_config) + mocker.patch.object(pb_base.backends, "config", test_config) + mocker.patch.object(pb_base.helpers, "config", test_config) + mocker.patch.object(pb.backends.redis.connection, "config", test_config) - if hasattr(backend.connection, "default_configuration"): - mocker.patch.object(backend.connection, "default_configuration", test_config) - if hasattr(backend.queues, "default_configuration"): - mocker.patch.object(backend.queues, "default_configuration", test_config) + if hasattr(backend.connection, "config"): + mocker.patch.object(backend.connection, "config", test_config) + if hasattr(backend.queues, "config"): + mocker.patch.object(backend.queues, "config", test_config) pb_base.backend = backend mocker.patch("protobunny.helpers.get_backend", return_value=backend) - mocker.patch.object(pb_base, "disconnect", backend.connection.disconnect) mocker.patch.object(pb_base, "get_backend", return_value=backend) assert pb_base.helpers.get_backend() == backend @@ -268,7 +287,6 @@ def mock_connection( mock = mocker.MagicMock(spec=backend.connection.Connection) # should be a Sync connection mocker.patch.object(pb_base, "connect", return_value=mock) mocker.patch("protobunny.backends.BaseSyncQueue.get_connection", return_value=mock) - mocker.patch(f"protobunny.backends.{backend_name}.connection.connect", return_value=mock) assert not asyncio.iscoroutinefunction(pb_base.disconnect) pb_base.disconnect() yield mock diff --git a/tests/test_results.py b/tests/test_results.py index e65483b..004ab2b 100644 --- a/tests/test_results.py +++ b/tests/test_results.py @@ -7,7 +7,7 @@ import protobunny as pb from protobunny.backends.rabbitmq.connection import Connection -from protobunny.config import backend_configs +from protobunny.conf import backend_configs from protobunny.models import ( deserialize_result_message, get_message_class_from_topic, @@ -24,13 +24,12 @@ def setup_connections(mocker: MockerFixture, mock_sync_rmq_connection, test_conf test_config.mode = "sync" test_config.backend = "rabbitmq" test_config.backend_config = backend_configs["rabbitmq"] - mocker.patch.object(pb.config, "default_configuration", test_config) - mocker.patch.object(pb.models, "default_configuration", test_config) - mocker.patch.object(pb.backends, "default_configuration", test_config) - mocker.patch.object(pb.helpers, "default_configuration", test_config) + mocker.patch.object(pb.conf, "config", test_config) + mocker.patch.object(pb.models, "config", test_config) + mocker.patch.object(pb.backends, "config", test_config) + mocker.patch.object(pb.helpers, "config", test_config) pb.backend = rabbitmq_backend - mocker.patch.object(pb, "connect", rabbitmq_backend.connection.connect) - mocker.patch.object(pb.helpers.default_configuration, "backend", "rabbitmq") + mocker.patch.object(pb.helpers.config, "backend", "rabbitmq") queue = pb.get_queue(tests.TestMessage) assert isinstance(queue, rabbitmq_backend.queues.SyncQueue) assert isinstance(queue.get_connection(), Connection) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index c1138ee..5e4adcc 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1,4 +1,5 @@ import asyncio +import importlib import logging import time import typing as tp @@ -7,16 +8,8 @@ import protobunny as pb_sync from protobunny import asyncio as pb -from protobunny.asyncio.backends import mosquitto as mosquitto_backend_aio -from protobunny.asyncio.backends import python as python_backend_aio -from protobunny.asyncio.backends import rabbitmq as rabbitmq_backend_aio -from protobunny.asyncio.backends import redis as redis_backend_aio -from protobunny.backends import mosquitto as mosquitto_backend -from protobunny.backends import python as python_backend -from protobunny.backends import rabbitmq as rabbitmq_backend -from protobunny.backends import redis as redis_backend -from protobunny.config import backend_configs -from protobunny.models import ProtoBunnyMessage +from protobunny import get_backend +from protobunny.conf import Config, backend_configs from . import tests from .utils import async_wait, sync_wait @@ -26,9 +19,7 @@ @pytest.mark.integration @pytest.mark.asyncio -@pytest.mark.parametrize( - "backend", [rabbitmq_backend_aio, redis_backend_aio, python_backend_aio, mosquitto_backend_aio] -) +@pytest.mark.parametrize("backend", ["rabbitmq", "redis", "python", "mosquitto", "nats"]) class TestTasks: msg = tests.tasks.TaskMessage(content="test", bbox=[1, 2, 3, 4]) received = { @@ -37,38 +28,38 @@ class TestTasks: } @pytest.fixture(autouse=True) - async def setup_test_env(self, mocker, test_config, backend) -> tp.AsyncGenerator[None, None]: - backend_name = backend.__name__.split(".")[-1] + async def setup_test_env( + self, mocker, test_config: Config, backend: str + ) -> tp.AsyncGenerator[None, None]: test_config.mode = "async" - test_config.backend = backend_name - test_config.backend_config = backend_configs[backend_name] + test_config.backend = backend + test_config.backend_config = backend_configs[backend] self.topic_delimiter = test_config.backend_config.topic_delimiter # Patch global configuration for all modules that use it - mocker.patch.object(pb_sync.config, "default_configuration", test_config) - mocker.patch.object(pb_sync.models, "default_configuration", test_config) - mocker.patch.object(pb_sync.helpers, "default_configuration", test_config) - mocker.patch.object(pb_sync.backends, "default_configuration", test_config) - mocker.patch.object(pb.backends, "default_configuration", test_config) - if hasattr(backend.connection, "default_configuration"): - mocker.patch.object(backend.connection, "default_configuration", test_config) - if hasattr(backend.queues, "default_configuration"): - mocker.patch.object(backend.queues, "default_configuration", test_config) - - pb.backend = backend - mocker.patch("protobunny.helpers.get_backend", return_value=backend) - mocker.patch.object(pb, "connect", backend.connection.connect) - mocker.patch.object(pb, "disconnect", backend.connection.disconnect) - mocker.patch.object(pb, "get_backend", return_value=backend) + mocker.patch.object(pb_sync.conf, "config", test_config) + mocker.patch.object(pb_sync.models, "config", test_config) + mocker.patch.object(pb_sync.helpers, "config", test_config) + mocker.patch.object(pb_sync.backends, "config", test_config) + mocker.patch.object(pb.backends, "config", test_config) + backend_module = get_backend() + if hasattr(backend_module.connection, "config"): + mocker.patch.object(backend_module.connection, "config", test_config) + if hasattr(backend_module.queues, "config"): + mocker.patch.object(backend_module.queues, "config", test_config) + + pb.backend = backend_module + mocker.patch("protobunny.helpers.get_backend", return_value=backend_module) + mocker.patch.object(pb, "get_backend", return_value=backend_module) # Assert the patching is working for setting the backend connection = await pb.connect() - assert isinstance(connection, backend.connection.Connection) + assert isinstance(connection, backend_module.connection.Connection) queue = pb.get_queue(self.msg) assert queue.topic == "acme.tests.tasks.TaskMessage".replace( ".", test_config.backend_config.topic_delimiter ) - assert isinstance(queue, backend.queues.AsyncQueue) - assert isinstance(await queue.get_connection(), backend.connection.Connection) + assert isinstance(queue, backend_module.queues.AsyncQueue) + assert isinstance(await queue.get_connection(), backend_module.connection.Connection) await queue.purge(reset_groups=True) # start without pending subscriptions await pb.unsubscribe_all(if_unused=False, if_empty=False) @@ -77,7 +68,7 @@ async def setup_test_env(self, mocker, test_config, backend) -> tp.AsyncGenerato yield await connection.disconnect() - backend.connection.Connection.instance_by_vhost.clear() + backend_module.connection.Connection.instance_by_vhost.clear() async def test_tasks(self, backend) -> None: async def predicate_1() -> bool: @@ -86,12 +77,12 @@ async def predicate_1() -> bool: async def predicate_2() -> bool: return self.received.get("task_2") is not None - async def callback_task_1(msg: "ProtoBunnyMessage") -> None: + async def callback_task_1(msg: tests.tasks.TaskMessage) -> None: log.debug("CALLBACK TASK 1 %s", msg) await asyncio.sleep(0.1) # simulate some work self.received["task_1"] = msg - async def callback_task_2(msg: "ProtoBunnyMessage") -> None: + async def callback_task_2(msg: tests.tasks.TaskMessage) -> None: log.debug("CALLBACK TASK 2 %s", msg) await asyncio.sleep(0.1) # simulate some work self.received["task_2"] = msg @@ -99,30 +90,28 @@ async def callback_task_2(msg: "ProtoBunnyMessage") -> None: await pb.subscribe(tests.tasks.TaskMessage, callback_task_1) await pb.subscribe(tests.tasks.TaskMessage, callback_task_2) await pb.publish(self.msg) - assert await async_wait(predicate_1) - - assert self.received.get("task_2") is None + assert await async_wait(predicate_1) or await async_wait(predicate_2) + assert self.received.get("task_2") is None or self.received.get("task_1") is None self.received["task_1"] = None + self.received["task_2"] = None + await pb.publish(self.msg) await pb.publish(self.msg) await pb.publish(self.msg) - assert await async_wait(predicate_1, timeout=2, sleep=0.1) - assert await async_wait(predicate_2) - assert self.received["task_1"] == self.msg - assert self.received["task_2"] == self.msg + assert await async_wait(predicate_1) or await async_wait(predicate_2) + assert self.received["task_1"] == self.msg or self.received["task_2"] == self.msg + + await pb.unsubscribe(tests.tasks.TaskMessage, if_unused=False, if_empty=False) self.received["task_1"] = None self.received["task_2"] = None await pb.publish(self.msg) await pb.publish(self.msg) - assert await async_wait(predicate_1) - assert await async_wait(predicate_2) - await pb.unsubscribe(tests.tasks.TaskMessage, if_unused=False, if_empty=False) + assert self.received["task_1"] is None + assert self.received["task_2"] is None @pytest.mark.integration -@pytest.mark.parametrize( - "backend", [rabbitmq_backend, redis_backend, python_backend, mosquitto_backend] -) +@pytest.mark.parametrize("backend", ["rabbitmq", "redis", "python", "mosquitto", "nats"]) class TestTasksSync: msg = tests.tasks.TaskMessage(content="test", bbox=[1, 2, 3, 4]) received = { @@ -131,65 +120,72 @@ class TestTasksSync: } @pytest.fixture(autouse=True) - def setup_test_env(self, mocker, test_config, backend) -> tp.Generator[None, None, None]: - backend_name = backend.__name__.split(".")[-1] + def setup_test_env( + self, mocker, test_config: Config, backend: str + ) -> tp.Generator[None, None, None]: test_config.mode = "sync" - test_config.backend = backend_name - test_config.backend_config = backend_configs[backend_name] + test_config.backend = backend + test_config.backend_config = backend_configs[backend] self.topic_delimiter = test_config.backend_config.topic_delimiter # Patch global configuration for all modules that use it - mocker.patch.object(pb_sync.config, "default_configuration", test_config) - mocker.patch.object(pb_sync.models, "default_configuration", test_config) - mocker.patch.object(pb_sync.backends, "default_configuration", test_config) - mocker.patch.object(pb_sync.helpers, "default_configuration", test_config) - mocker.patch.object(pb.backends.redis.connection, "default_configuration", test_config) - if hasattr(backend.connection, "default_configuration"): - mocker.patch.object(backend.connection, "default_configuration", test_config) - if hasattr(backend.queues, "default_configuration"): - mocker.patch.object(backend.queues, "default_configuration", test_config) - - pb_sync.backend = backend - mocker.patch("protobunny.backends.get_backend", return_value=backend) - mocker.patch("protobunny.helpers.get_backend", return_value=backend) - mocker.patch.object(pb_sync, "connect", backend.connection.connect) - mocker.patch.object(pb_sync, "disconnect", backend.connection.disconnect) - - # Assert the patching is working for setting the backend + mocker.patch.object(pb_sync.conf, "config", test_config) + mocker.patch.object(pb_sync.models, "config", test_config) + mocker.patch.object(pb_sync.backends, "config", test_config) + mocker.patch.object(pb_sync.helpers, "config", test_config) + mocker.patch.object(pb.backends, "config", test_config) + + backend_module = get_backend() + async_backend_module = importlib.import_module(f"protobunny.asyncio.backends.{backend}") + if hasattr(backend_module.connection, "config"): + mocker.patch.object(backend_module.connection, "config", test_config) + if hasattr(backend_module.queues, "config"): + mocker.patch.object(backend_module.queues, "config", test_config) + if hasattr(async_backend_module.connection, "config"): + # The sync connection is often implemented as a wrapper of the relative async module. + # Patch the config of the async module as well + mocker.patch.object(async_backend_module.connection, "config", test_config) + + mocker.patch("protobunny.helpers.get_backend", return_value=backend_module) + + # Test the patching is working for setting the backend connection = pb_sync.connect() - assert isinstance(connection, backend.connection.Connection) + assert isinstance(connection, backend_module.connection.Connection) assert sync_wait(connection.is_connected) task_queue = pb_sync.get_queue(self.msg) assert task_queue.topic == "acme.tests.tasks.TaskMessage".replace( ".", test_config.backend_config.topic_delimiter ) - assert isinstance(task_queue, backend.queues.SyncQueue) - assert isinstance(task_queue.get_connection(), backend.connection.Connection) + assert isinstance(task_queue, backend_module.queues.SyncQueue) + assert isinstance(task_queue.get_connection(), backend_module.connection.Connection) task_queue.purge(reset_groups=True) # start without pending subscriptions pb_sync.unsubscribe_all(if_unused=False, if_empty=False) pb_sync.disconnect() # reset the variables holding the messages received self.received = {} - yield pb_sync.disconnect() - backend.connection.Connection.instance_by_vhost.clear() + backend_module.connection.Connection.instance_by_vhost.clear() + @pytest.mark.flaky(max_runs=3) def test_tasks(self, backend) -> None: + """ + Assert load balancing between worker callbacks + and that tasks callbacks don't receive duplicated messages + """ + def predicate_1() -> bool: return self.received.get("task_1") is not None def predicate_2() -> bool: return self.received.get("task_2") is not None - def callback_task_1(msg: "ProtoBunnyMessage") -> None: - log.debug("SYNC CALLBACK TASK 1 %s", msg) + def callback_task_1(msg: tests.tasks.TaskMessage) -> None: time.sleep(0.1) self.received["task_1"] = msg - def callback_task_2(msg: "ProtoBunnyMessage") -> None: - log.debug("SYNC CALLBACK TASK 2 %s", msg) + def callback_task_2(msg: tests.tasks.TaskMessage) -> None: time.sleep(0.1) self.received["task_2"] = msg @@ -197,22 +193,21 @@ def callback_task_2(msg: "ProtoBunnyMessage") -> None: pb_sync.subscribe(tests.tasks.TaskMessage, callback_task_2) pb_sync.publish(self.msg) - assert sync_wait(predicate_1) + assert sync_wait(predicate_1) or sync_wait(predicate_2) - assert self.received.get("task_2") is None + assert self.received.get("task_2") is None or self.received.get("task_1") is None self.received["task_1"] = None + self.received["task_2"] = None pb_sync.publish(self.msg) pb_sync.publish(self.msg) pb_sync.publish(self.msg) - assert sync_wait(predicate_1) - assert sync_wait(predicate_2) + pb_sync.publish(self.msg) + # this is a bit flaky because of the backend load balancing + assert sync_wait(predicate_1, timeout=2) or sync_wait(predicate_2, timeout=2) + assert sync_wait(predicate_2, timeout=2) or sync_wait(predicate_1, timeout=2) assert self.received["task_1"] == self.msg assert self.received["task_2"] == self.msg self.received["task_1"] = None self.received["task_2"] = None pb_sync.publish(self.msg) - pb_sync.publish(self.msg) - assert sync_wait(predicate_1) - assert sync_wait(predicate_2) - self.received["task_2"] = None - self.received["task_1"] = None + assert sync_wait(predicate_1) or sync_wait(predicate_2) diff --git a/tests/utils.py b/tests/utils.py index 0cf9fae..88714d9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,5 @@ import asyncio import time -from types import ModuleType from unittest.mock import ANY, AsyncMock from aio_pika import Message @@ -43,16 +42,15 @@ async def tear_down(event_loop): def incoming_message_factory(backend, body: bytes = b"Hello"): - backend_name = backend.__name__.split(".")[-1] - if backend_name == "rabbitmq": + if backend == "rabbitmq": return Message(body=body) - elif backend_name == "redis": + elif backend == "redis": return Envelope(body=body, correlation_id="123") return Envelope(body=body) async def assert_backend_publish( - backend: ModuleType, + backend: str, internal_mock: dict | redis.Redis | AsyncMock, mock_connection: "BaseConnection", backend_msg, @@ -60,8 +58,7 @@ async def assert_backend_publish( count_in_queue: int = 1, shared_queue: bool = False, ): - backend_name = backend.__name__.split(".")[-1] - match backend_name: + match backend: case "rabbitmq": internal_mock["exchange"].publish.assert_awaited_with( backend_msg, routing_key=topic, mandatory=True, immediate=False @@ -111,103 +108,140 @@ async def predicate(): assert ( await internal_mock.get_message_count(topic) == count_in_queue ), f"count was {await internal_mock.get_message_count(topic)}" - case "mosquitto": - internal_mock.publish.assert_awaited_once() + if not shared_queue: + internal_mock.publish.assert_awaited_once_with( + "protobunny/test/routing/key", payload=b"Hello", qos=1, retain=False + ) + else: + internal_mock.publish.assert_awaited_once_with( + "protobunny/test/tasks/key", payload=b"Hello", qos=1, retain=False + ) + case "nats": + if not shared_queue: + internal_mock["client"].publish.assert_awaited_once_with( + subject="protobunny.test.routing.key", payload=b"Hello", headers=None + ) + else: + internal_mock["js"].publish.assert_awaited_once_with( + subject="TASKS.protobunny.test.tasks.key", payload=b"Hello", headers=None + ) async def assert_backend_setup_queue( backend, internal_mock, topic: str, shared: bool, mock_connection ) -> None: - backend_name = backend.__name__.split(".")[-1] - if backend_name == "rabbitmq": - internal_mock["channel"].declare_queue.assert_called_with( - topic, exclusive=not shared, durable=True, auto_delete=False, arguments=ANY - ) - elif backend_name == "redis": - if shared: - streams = await internal_mock.xinfo_groups(f"protobunny:{topic}") - assert len(streams) == 1 - assert ( - streams[0]["name"].decode() == "shared_group" - ), f"Expected 'shared_group', got '{streams[0]['name']}'" - else: - assert mock_connection.queues[topic] == {"is_shared": False} - - elif backend_name == "python": - if not shared: - assert len(internal_mock._exclusive_queues.get(topic)) == 1 - else: - assert internal_mock._shared_queues.get(topic) - elif backend_name == "mosquitto": - if not shared: - assert mock_connection.queues[topic] == { - "is_shared": False, - "group_name": "", - "sub_key": f"$share/shared_group/test/{topic}", - "tag": ANY, - "topic": topic, - "topic_key": f"test/{topic}", - } - else: - assert list(mock_connection.queues.values())[0] == { - "is_shared": True, - "group_name": "shared_group", - "sub_key": f"$share/shared_group/protobunny/{topic}", - "tag": ANY, - "topic": topic, - "topic_key": f"protobunny/{topic}", - }, mock_connection.queues + match backend: + case "rabbitmq": + internal_mock["channel"].declare_queue.assert_called_with( + topic, exclusive=not shared, durable=True, auto_delete=False, arguments=ANY + ) + case "redis": + if shared: + streams = await internal_mock.xinfo_groups(f"protobunny:{topic}") + assert len(streams) == 1 + assert ( + streams[0]["name"].decode() == "shared_group" + ), f"Expected 'shared_group', got '{streams[0]['name']}'" + else: + assert mock_connection.queues[topic] == {"is_shared": False} + case "python": + if not shared: + assert len(internal_mock._exclusive_queues.get(topic)) == 1 + else: + assert internal_mock._shared_queues.get(topic) + case "mosquitto": + if not shared: + assert mock_connection.queues[topic] == { + "is_shared": False, + "group_name": "", + "sub_key": f"$share/shared_group/test/{topic}", + "tag": ANY, + "topic": topic, + "topic_key": f"test/{topic}", + } + else: + assert list(mock_connection.queues.values())[0] == { + "is_shared": True, + "group_name": "shared_group", + "sub_key": f"$share/shared_group/protobunny/{topic}", + "tag": ANY, + "topic": topic, + "topic_key": f"protobunny/{topic}", + }, mock_connection.queues + case "nats": + internal_mock["js"].subscribe.assert_awaited_once_with( + subject="TASKS.protobunny.mylib.tasks.TaskMessage", + queue="protobunny_mylib_tasks_TaskMessage", + durable="protobunny_mylib_tasks_TaskMessage", + cb=ANY, + manual_ack=True, + stream="PROTOBUNNY_TASKS", + ) async def assert_backend_connection(backend, internal_mock): - backend_name = backend.__name__.split(".")[-1] - if backend_name == "rabbitmq": - # Verify aio_pika calls - internal_mock["connect"].assert_awaited_once() - assert internal_mock["channel"].set_qos.called - # Check if main and DLX exchanges were declared - assert internal_mock["channel"].declare_exchange.call_count == 2 - elif backend_name == "redis": - assert await internal_mock.ping() - elif backend_name == "mosquitto": - internal_mock.__aenter__.assert_awaited_once() - assert True - - -def get_mocked_connection(backend, redis_client, mock_aio_pika, mocker, mock_mosquitto): - backend_name = backend.__name__.split(".")[-1] - if backend_name == "redis": - real_conn_with_fake_redis = backend.connection.Connection(url="redis://localhost:6379/0") - assert ( - real_conn_with_fake_redis._exchange == "protobunny" - ), real_conn_with_fake_redis._exchange - real_conn_with_fake_redis._connection = redis_client - - def check_connected() -> bool: - return real_conn_with_fake_redis._connection is not None - - # Patch is_connected logic - mocker.patch.object(real_conn_with_fake_redis, "is_connected", side_effect=check_connected) - - return real_conn_with_fake_redis - elif backend_name == "rabbitmq": - real_conn_with_fake_aio_pika = backend.connection.Connection( - url="amqp://guest:guest@localhost:5672/" - ) - real_conn_with_fake_aio_pika._connection = mock_aio_pika["connection"] - real_conn_with_fake_aio_pika.is_connected_event.set() - real_conn_with_fake_aio_pika._channel = mock_aio_pika["channel"] - real_conn_with_fake_aio_pika._exchange = mock_aio_pika["exchange"] - real_conn_with_fake_aio_pika._queue = mock_aio_pika["queue"] - - return real_conn_with_fake_aio_pika - elif backend_name == "python": - python_conn = backend.connection.Connection() - python_conn.is_connected_event.set() - return python_conn - elif backend_name == "mosquitto": - real_conn_with_fake_aiomqtt = backend.connection.Connection() - real_conn_with_fake_aiomqtt._connection = mock_mosquitto - # real_conn_with_fake_aiomqtt.is_connected_event.set() - return real_conn_with_fake_aiomqtt + match backend: + case "rabbitmq": + # Verify aio_pika calls + internal_mock["connect"].assert_awaited_once() + assert internal_mock["channel"].set_qos.called + # Check if main and DLX exchanges were declared + assert internal_mock["channel"].declare_exchange.call_count == 2 + case "redis": + assert await internal_mock.ping() + case "python": + assert True # python is always "connected" + case "mosquitto": + internal_mock.__aenter__.assert_awaited_once() + case "nats": + import nats + + nats.connect.assert_awaited_once_with("nats://localhost:4222/") + + +def get_mocked_connection( + backend_module, redis_client, mock_aio_pika, mocker, mock_mosquitto, mock_nats +): + backend_name = backend_module.__name__.split(".")[-1] + match backend_name: + case "redis": + real_conn_with_fake_redis = backend_module.connection.Connection( + url="redis://localhost:6379/0" + ) + assert ( + real_conn_with_fake_redis._exchange == "protobunny" + ), real_conn_with_fake_redis._exchange + real_conn_with_fake_redis._connection = redis_client + + def check_connected() -> bool: + return real_conn_with_fake_redis._connection is not None + + # Patch is_connected logic + mocker.patch.object( + real_conn_with_fake_redis, "is_connected", side_effect=check_connected + ) + return real_conn_with_fake_redis + case "nats": + real_conn_with_fake_nats = backend_module.connection.Connection() + real_conn_with_fake_nats._connection = mock_nats["client"] + return real_conn_with_fake_nats + case "rabbitmq": + real_conn_with_fake_aio_pika = backend_module.connection.Connection( + url="amqp://guest:guest@localhost:5672/" + ) + real_conn_with_fake_aio_pika._connection = mock_aio_pika["connection"] + real_conn_with_fake_aio_pika.is_connected_event.set() + real_conn_with_fake_aio_pika._channel = mock_aio_pika["channel"] + real_conn_with_fake_aio_pika._exchange = mock_aio_pika["exchange"] + real_conn_with_fake_aio_pika._queue = mock_aio_pika["queue"] + return real_conn_with_fake_aio_pika + case "python": + python_conn = backend_module.connection.Connection() + python_conn.is_connected_event.set() + return python_conn + case "mosquitto": + real_conn_with_fake_aiomqtt = backend_module.connection.Connection() + real_conn_with_fake_aiomqtt._connection = mock_mosquitto + return real_conn_with_fake_aiomqtt + return None diff --git a/uv.lock b/uv.lock index dd4283d..cc5b48d 100644 --- a/uv.lock +++ b/uv.lock @@ -1029,6 +1029,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5f/df/76d0321c3797b54b60fef9ec3bd6f4cfd124b9e422182156a1dd418722cf/myst_parser-4.0.1-py3-none-any.whl", hash = "sha256:9134e88959ec3b5780aedf8a99680ea242869d012e8821db3126d427edc9c95d", size = 84579, upload-time = "2025-02-12T10:53:02.078Z" }, ] +[[package]] +name = "nats-py" +version = "2.12.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/71/c5/2564d917503fe8d68fe630c74bf6b678fbc15c01b58f2565894761010f57/nats_py-2.12.0.tar.gz", hash = "sha256:2981ca4b63b8266c855573fa7871b1be741f1889fd429ee657e5ffc0971a38a1", size = 119821, upload-time = "2025-10-31T05:27:31.247Z" } + [[package]] name = "numpy" version = "1.26.4" @@ -1268,14 +1274,13 @@ wheels = [ [[package]] name = "protobunny" -version = "0.1.2a1" +version = "0.1.2a2" source = { editable = "." } dependencies = [ { name = "betterproto", extra = ["compiler"] }, { name = "can-ada" }, { name = "click" }, { name = "grpcio-tools" }, - { name = "pytest-env" }, { name = "tomli", marker = "python_full_version < '3.11'" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] @@ -1284,6 +1289,9 @@ dependencies = [ mosquitto = [ { name = "aiomqtt" }, ] +nats = [ + { name = "nats-py" }, +] numpy = [ { name = "numpy" }, ] @@ -1304,6 +1312,7 @@ dev = [ { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, + { name = "pytest-env" }, { name = "pytest-mock" }, { name = "ruff" }, { name = "toml-sort" }, @@ -1329,13 +1338,13 @@ requires-dist = [ { name = "can-ada", specifier = ">=2.0.0" }, { name = "click", specifier = ">=8.2.1" }, { name = "grpcio-tools", specifier = ">=1.62.0,<2" }, + { name = "nats-py", marker = "extra == 'nats'", specifier = ">=2.12.0" }, { name = "numpy", marker = "extra == 'numpy'", specifier = ">=1.26" }, - { name = "pytest-env", specifier = ">=1.1.3" }, { name = "redis", marker = "extra == 'redis'", specifier = "<8" }, { name = "tomli", marker = "python_full_version < '3.11'" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] -provides-extras = ["rabbitmq", "redis", "numpy", "mosquitto"] +provides-extras = ["rabbitmq", "redis", "numpy", "mosquitto", "nats"] [package.metadata.requires-dev] dev = [ @@ -1347,6 +1356,7 @@ dev = [ { name = "pytest", specifier = ">=7.2.2,<8" }, { name = "pytest-asyncio", specifier = ">=0.23.7,<0.24" }, { name = "pytest-cov", specifier = ">=4.0.0,<5" }, + { name = "pytest-env", specifier = ">=1.1.3" }, { name = "pytest-mock", specifier = ">=3.14.0,<4" }, { name = "ruff", specifier = ">=0.1.14,<0.2" }, { name = "toml-sort", specifier = ">=0.24.3" },