diff --git a/mongoengine_plus/aio/async_document.py b/mongoengine_plus/aio/async_document.py index abe8a9d..30c241f 100644 --- a/mongoengine_plus/aio/async_document.py +++ b/mongoengine_plus/aio/async_document.py @@ -1,6 +1,7 @@ from mongoengine import Document from .async_query_set import AsyncQuerySet +from .async_signals import post_save, pre_save from .utils import create_awaitable @@ -11,7 +12,15 @@ class AsyncDocument(Document): ) async def async_save(self, *args, **kwargs): - return await create_awaitable(self.save, *args, **kwargs) + signal_kwargs = kwargs.pop("signal_kwargs", {}) + await pre_save.send_async( + self.__class__, document=self, **signal_kwargs + ) + result = await create_awaitable(self.save, *args, **kwargs) + await post_save.send_async( + self.__class__, document=self, **signal_kwargs + ) + return result async def async_reload(self, *fields, **kwargs): return await create_awaitable(self.reload, *fields, **kwargs) diff --git a/mongoengine_plus/aio/async_signals.py b/mongoengine_plus/aio/async_signals.py new file mode 100644 index 0000000..8b59d96 --- /dev/null +++ b/mongoengine_plus/aio/async_signals.py @@ -0,0 +1,6 @@ +from mongoengine.signals import Namespace + +async_signals = Namespace() + +pre_save = async_signals.signal("pre_save") +post_save = async_signals.signal("post_save") diff --git a/mongoengine_plus/version.py b/mongoengine_plus/version.py index 1a72d32..58d478a 100644 --- a/mongoengine_plus/version.py +++ b/mongoengine_plus/version.py @@ -1 +1 @@ -__version__ = '1.1.0' +__version__ = '1.2.0' diff --git a/setup.py b/setup.py index abc9678..a339550 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ 'pymongo>=3.13.0,<4.0.0', 'pymongocrypt>=1.12.2,<2.0.0', 'boto3>=1.34.106,<2.0.0', + 'blinker>=1.9.0,<2.0.0', ], classifiers=[ 'Programming Language :: Python :: 3.9', diff --git a/tests/aio/test_async_signals.py b/tests/aio/test_async_signals.py new file mode 100644 index 0000000..29ee745 --- /dev/null +++ b/tests/aio/test_async_signals.py @@ -0,0 +1,43 @@ +import pytest +from mongoengine import StringField + +from mongoengine_plus.aio.async_document import AsyncDocument +from mongoengine_plus.aio.async_signals import post_save, pre_save +from mongoengine_plus.models.event_handlers import handler + + +@pytest.mark.asyncio +async def test_async_signal_handler_on_asyncdocument(): + pre_calls = [] + post_calls = [] + + @handler(pre_save) + async def my_async_pre_handler(cls, document, **kwargs): + pre_calls.append((document.name, getattr(document, "updated", False))) + + @handler(post_save) + async def my_async_post_handler(cls, document, **kwargs): + post_calls.append((document.name, getattr(document, "updated", False))) + + @my_async_pre_handler.apply + @my_async_post_handler.apply + class User(AsyncDocument): + name = StringField(required=True) + updated = StringField() + + user = User(name="Jane") + await user.async_save() + assert pre_calls[-1][0] == "Jane" + assert post_calls[-1][0] == "Jane" + + user.name = "John" + user.updated = "yes" + await user.async_save() + + # The handlers should have been called again with updated data + assert pre_calls[-1][0] == "John" + assert pre_calls[-1][1] == "yes" + assert post_calls[-1][0] == "John" + assert post_calls[-1][1] == "yes" + assert len(pre_calls) == 2 + assert len(post_calls) == 2