From 2b894f5a7d2ba917406c8bc260f2e9d8f94a19ae Mon Sep 17 00:00:00 2001 From: Nchinda Nchinda Date: Tue, 8 Jun 2021 22:43:09 -0400 Subject: [PATCH] Improve Motor GenericReference field --- tests/test_fields.py | 2 +- tests/test_marshmallow.py | 3 ++- umongo/fields.py | 2 +- umongo/frameworks/motor_asyncio.py | 6 +++++- 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/test_fields.py b/tests/test_fields.py index 6e81fa1f..e7337dfd 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -772,7 +772,7 @@ class ToRef2(Document): @self.instance.register class MyDoc(Document): - gref = fields.GenericReferenceField(attribute='in_mongo_gref', allow_none=True) + gref = fields.GenericReferenceField(attribute='in_mongo_gref', reference_cls=Reference, allow_none=True) MySchema = MyDoc.Schema diff --git a/tests/test_marshmallow.py b/tests/test_marshmallow.py index c3748b61..83d32b66 100644 --- a/tests/test_marshmallow.py +++ b/tests/test_marshmallow.py @@ -1,5 +1,6 @@ """Test marshmallow-related features""" import datetime as dt +from umongo.data_objects import Reference import pytest @@ -317,7 +318,7 @@ def test_marshmallow_bonus_fields(self): class Doc(Document): id = fields.ObjectIdField(attribute='_id') ref = fields.ReferenceField('Doc') - gen_ref = fields.GenericReferenceField() + gen_ref = fields.GenericReferenceField(reference_cls=Reference) for name, field_cls in ( ('id', ma_bonus_fields.ObjectId), diff --git a/umongo/fields.py b/umongo/fields.py index bebda3d4..c4730de6 100644 --- a/umongo/fields.py +++ b/umongo/fields.py @@ -374,7 +374,7 @@ def _deserialize_from_mongo(self, value): class GenericReferenceField(BaseField, ma_bonus_fields.GenericReference): - def __init__(self, *args, reference_cls=Reference, **kwargs): + def __init__(self, *args, reference_cls=None, **kwargs): super().__init__(*args, **kwargs) self.reference_cls = reference_cls self._document_implementation_cls = DocumentImplementation diff --git a/umongo/frameworks/motor_asyncio.py b/umongo/frameworks/motor_asyncio.py index 30c82709..dba01835 100644 --- a/umongo/frameworks/motor_asyncio.py +++ b/umongo/frameworks/motor_asyncio.py @@ -14,7 +14,7 @@ from ..document import DocumentImplementation from ..data_objects import Reference from ..exceptions import NotCreatedError, UpdateError, DeleteError, NoneReferenceError -from ..fields import ReferenceField, ListField, DictField, EmbeddedField +from ..fields import ReferenceField, GenericReferenceField, ListField, DictField, EmbeddedField from ..query_mapper import map_query from .tools import cook_find_filter, remove_cls_field_from_embedded_docs @@ -431,6 +431,10 @@ def _patch_field(self, field): if isinstance(field, ReferenceField): field.io_validate.append(_reference_io_validate) field.reference_cls = MotorAsyncIOReference + if isinstance(field, GenericReferenceField): + field.io_validate.append(_reference_io_validate) + if field.reference_cls is None: + field.reference_cls = MotorAsyncIOReference if isinstance(field, EmbeddedField): field.io_validate_recursive = _embedded_document_io_validate