diff --git a/umongo/fields.py b/umongo/fields.py index bebda3d..26b1e88 100644 --- a/umongo/fields.py +++ b/umongo/fields.py @@ -310,7 +310,7 @@ class ObjectIdField(BaseField, ma_bonus_fields.ObjectId): class ReferenceField(BaseField, ma_bonus_fields.Reference): - def __init__(self, document, *args, reference_cls=Reference, **kwargs): + def __init__(self, document, *args, reference_cls=None, **kwargs): """ :param document: Can be a :class:`umongo.embedded_document.DocumentTemplate`, another instance's :class:`umongo.embedded_document.DocumentImplementation` or diff --git a/umongo/frameworks/motor_asyncio.py b/umongo/frameworks/motor_asyncio.py index 30c8270..68fb1c5 100644 --- a/umongo/frameworks/motor_asyncio.py +++ b/umongo/frameworks/motor_asyncio.py @@ -430,7 +430,8 @@ def _patch_field(self, field): field.io_validate_recursive = _dict_io_validate if isinstance(field, ReferenceField): field.io_validate.append(_reference_io_validate) - field.reference_cls = MotorAsyncIOReference + if field.reference_cls is None: + field.reference_cls = MotorAsyncIOReference if isinstance(field, EmbeddedField): field.io_validate_recursive = _embedded_document_io_validate diff --git a/umongo/frameworks/pymongo.py b/umongo/frameworks/pymongo.py index 8815950..dc620d2 100644 --- a/umongo/frameworks/pymongo.py +++ b/umongo/frameworks/pymongo.py @@ -354,7 +354,8 @@ def _patch_field(self, field): field.io_validate_recursive = _dict_io_validate if isinstance(field, ReferenceField): field.io_validate.append(_reference_io_validate) - field.reference_cls = PyMongoReference + if field.reference_cls is None: + field.reference_cls = PyMongoReference if isinstance(field, EmbeddedField): field.io_validate_recursive = _embedded_document_io_validate diff --git a/umongo/frameworks/txmongo.py b/umongo/frameworks/txmongo.py index 9bbc074..8e0b1c6 100644 --- a/umongo/frameworks/txmongo.py +++ b/umongo/frameworks/txmongo.py @@ -367,7 +367,8 @@ def _patch_field(self, field): field.io_validate_recursive = _dict_io_validate if isinstance(field, ReferenceField): field.io_validate.append(_reference_io_validate) - field.reference_cls = TxMongoReference + if field.reference_cls is None: + field.reference_cls = TxMongoReference if isinstance(field, EmbeddedField): field.io_validate_recursive = _embedded_document_io_validate