From 2c4700d7b79a5ce9f984e08da33b738ba872d4ea Mon Sep 17 00:00:00 2001 From: daniil <6158938+richiq@users.noreply.github.com> Date: Mon, 3 May 2021 11:47:16 +0300 Subject: [PATCH] Fix #349 --- umongo/fields.py | 2 +- umongo/frameworks/motor_asyncio.py | 3 ++- umongo/frameworks/pymongo.py | 3 ++- umongo/frameworks/txmongo.py | 3 ++- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/umongo/fields.py b/umongo/fields.py index bebda3d4..26b1e88d 100644 --- a/umongo/fields.py +++ b/umongo/fields.py @@ -310,7 +310,7 @@ class ObjectIdField(BaseField, ma_bonus_fields.ObjectId): class ReferenceField(BaseField, ma_bonus_fields.Reference): - def __init__(self, document, *args, reference_cls=Reference, **kwargs): + def __init__(self, document, *args, reference_cls=None, **kwargs): """ :param document: Can be a :class:`umongo.embedded_document.DocumentTemplate`, another instance's :class:`umongo.embedded_document.DocumentImplementation` or diff --git a/umongo/frameworks/motor_asyncio.py b/umongo/frameworks/motor_asyncio.py index 30c82709..68fb1c53 100644 --- a/umongo/frameworks/motor_asyncio.py +++ b/umongo/frameworks/motor_asyncio.py @@ -430,7 +430,8 @@ def _patch_field(self, field): field.io_validate_recursive = _dict_io_validate if isinstance(field, ReferenceField): field.io_validate.append(_reference_io_validate) - field.reference_cls = MotorAsyncIOReference + if field.reference_cls is None: + field.reference_cls = MotorAsyncIOReference if isinstance(field, EmbeddedField): field.io_validate_recursive = _embedded_document_io_validate diff --git a/umongo/frameworks/pymongo.py b/umongo/frameworks/pymongo.py index 88159501..dc620d20 100644 --- a/umongo/frameworks/pymongo.py +++ b/umongo/frameworks/pymongo.py @@ -354,7 +354,8 @@ def _patch_field(self, field): field.io_validate_recursive = _dict_io_validate if isinstance(field, ReferenceField): field.io_validate.append(_reference_io_validate) - field.reference_cls = PyMongoReference + if field.reference_cls is None: + field.reference_cls = PyMongoReference if isinstance(field, EmbeddedField): field.io_validate_recursive = _embedded_document_io_validate diff --git a/umongo/frameworks/txmongo.py b/umongo/frameworks/txmongo.py index 9bbc074c..8e0b1c62 100644 --- a/umongo/frameworks/txmongo.py +++ b/umongo/frameworks/txmongo.py @@ -367,7 +367,8 @@ def _patch_field(self, field): field.io_validate_recursive = _dict_io_validate if isinstance(field, ReferenceField): field.io_validate.append(_reference_io_validate) - field.reference_cls = TxMongoReference + if field.reference_cls is None: + field.reference_cls = TxMongoReference if isinstance(field, EmbeddedField): field.io_validate_recursive = _embedded_document_io_validate