diff --git a/src/dataclay/backend/api.py b/src/dataclay/backend/api.py
index a696f5e6..3b8041f1 100644
--- a/src/dataclay/backend/api.py
+++ b/src/dataclay/backend/api.py
@@ -10,7 +10,7 @@
 from threadpoolctl import threadpool_limits
 
 from dataclay import utils
-from dataclay.config import set_runtime, settings
+from dataclay.config import LEGACY_DEPS, set_runtime, settings
 from dataclay.event_loop import dc_to_thread_io
 from dataclay.exceptions import (
     DataClayException,
@@ -19,6 +19,7 @@
     ObjectWithWrongBackendIdError,
 )
 from dataclay.lock_manager import lock_manager
+from dataclay.metadata.kvdata import ObjectMetadata
 from dataclay.runtime import BackendRuntime
 from dataclay.utils.serialization import dcdumps, dcloads, recursive_dcloads
 from dataclay.utils.telemetry import trace
@@ -72,8 +73,14 @@ async def is_ready(self, timeout: Optional[float] = None, pause: float = 0.5):
 
     async def register_objects(self, serialized_objects: Iterable[bytes], make_replica: bool):
         logger.debug("Receiving (%d) objects to register", len(serialized_objects))
         for object_bytes in serialized_objects:
-            state, getstate = await dcloads(object_bytes)
-            instance = await self.runtime.get_object_by_id(state["_dc_meta"].id)
+            metadata_dict, dc_properties, getstate = await dcloads(object_bytes)
+
+            if LEGACY_DEPS:
+                dc_meta = ObjectMetadata.parse_obj(metadata_dict)
+            else:
+                dc_meta = ObjectMetadata.model_validate(metadata_dict)
+
+            instance = await self.runtime.get_object_by_id(dc_meta.id)
             if instance._dc_is_local:
                 assert instance._dc_is_replica
@@ -81,6 +88,8 @@ async def register_objects(self, serialized_objects: Iterable[bytes], make_repli
                 logger.warning("Replica already exists with id=%s", instance._dc_meta.id)
                 continue
 
+            state = {"_dc_meta": dc_meta}
+
             async with lock_manager.get_lock(instance._dc_meta.id).writer_lock:
                 # Update object state and flags
                 state["_dc_is_loaded"] = True
@@ -88,6 +97,8 @@ async def register_objects(self, serialized_objects: Iterable[bytes], make_repli
                 vars(instance).update(state)
                 if getstate:
                     instance.__setstate__(getstate)
+                else:
+                    vars(instance).update(dc_properties)
                 self.runtime.data_manager.add_hard_reference(instance)
 
                 if make_replica:
diff --git a/src/dataclay/data_manager.py b/src/dataclay/data_manager.py
index 8a7ab10a..feec1b9e 100644
--- a/src/dataclay/data_manager.py
+++ b/src/dataclay/data_manager.py
@@ -141,20 +141,18 @@ async def load_object(self, instance: DataClayObject):
             path = f"{settings.storage_path}/{object_id}"
             # TODO: Is it necessary dc_to_thread_cpu? Should be blocking
             # to avoid bugs with parallel loads?
-            state, getstate = await dc_to_thread_cpu(pickle.load, open(path, "rb"))
+            metadata_dict, dc_properties, getstate = await dc_to_thread_cpu(pickle.load, open(path, "rb"))
             self.dataclay_stored_objects.dec()
         except Exception as e:
             raise ObjectNotFound(object_id) from e
 
-        # Delete outdated metadata (SSOT stored in Redis)
-        del state["_dc_meta"]
-        vars(instance).update(state)
-
         # NOTE: We need to set _dc_is_loaded before calling __setstate__
         # to avoid infinite recursion
        instance._dc_is_loaded = True
         if getstate is not None:
             instance.__setstate__(getstate)
+        else:
+            vars(instance).update(dc_properties)
 
         self.add_hard_reference(instance)
         logger.debug("(%s) Loaded '%s'", object_id, instance.__class__.__name__)
diff --git a/src/dataclay/dataclay_object.py b/src/dataclay/dataclay_object.py
index 03a8ec4a..bb75cb44 100644
--- a/src/dataclay/dataclay_object.py
+++ b/src/dataclay/dataclay_object.py
@@ -313,12 +313,16 @@ def _dc_properties(self) -> dict[str, Any]:
     @property
     def _dc_state(self) -> tuple[dict, Any]:
         """Returns the object state"""
-        state = {"_dc_meta": self._dc_meta}
+
+        if LEGACY_DEPS:
+            metadata_dict = self._dc_meta.dict()
+        else:
+            metadata_dict = self._dc_meta.model_dump()
+
         if hasattr(self, "__getstate__") and hasattr(self, "__setstate__"):
-            return state, self.__getstate__()
+            return metadata_dict, None, self.__getstate__()
         else:
-            state.update(self._dc_properties)
-            return state, None
+            return metadata_dict, self._dc_properties, None
 
     @property
     def _dc_all_backend_ids(self) -> set[UUID]:
diff --git a/src/dataclay/utils/serialization.py b/src/dataclay/utils/serialization.py
index 68c3e57e..6b8058f0 100644
--- a/src/dataclay/utils/serialization.py
+++ b/src/dataclay/utils/serialization.py
@@ -10,9 +10,10 @@
 from uuid import UUID
 
 from dataclay import utils
-from dataclay.config import get_runtime
+from dataclay.config import LEGACY_DEPS, get_runtime
 from dataclay.dataclay_object import DataClayObject
 from dataclay.event_loop import dc_to_thread_cpu, get_dc_event_loop
+from dataclay.metadata.kvdata import ObjectMetadata
 
 logger = logging.getLogger(__name__)
 
@@ -151,22 +152,29 @@ async def recursive_dcloads(object_binary, unserialized_objects: dict[UUID, Data
         unserialized_objects = {}
 
     # Use dc_to_thread_cpu to avoid blocking the event loop in `get_by_id_sync`
-    object_dict, state = await dc_to_thread_cpu(
+    object_metadata_dict, dc_properties, state = await dc_to_thread_cpu(
         RecursiveDataClayObjectUnpickler(io.BytesIO(object_binary), unserialized_objects).load
     )
-    object_id = object_dict["_dc_meta"].id
+    if LEGACY_DEPS:
+        dc_meta = ObjectMetadata.parse_obj(object_metadata_dict)
+    else:
+        dc_meta = ObjectMetadata.model_validate(object_metadata_dict)
+
+    object_id = dc_meta.id
     try:
         # In case it was already unserialized by a reference
         proxy_object = unserialized_objects[object_id]
     except KeyError:
-        cls: type[DataClayObject] = utils.get_class_by_name(object_dict["_dc_meta"].class_name)
+        cls: type[DataClayObject] = utils.get_class_by_name(dc_meta.class_name)
         proxy_object = cls.new_proxy_object()
         unserialized_objects[object_id] = proxy_object
 
-    vars(proxy_object).update(object_dict)
+    vars(proxy_object)["_dc_meta"] = dc_meta
     if state is not None:
         proxy_object.__setstate__(state)
+    else:
+        vars(proxy_object).update(dc_properties)
 
     return proxy_object
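For reviewers: the net effect of this patch is that serialized payloads change from a 2-tuple `(state, getstate)`, where `state` embeds the live `ObjectMetadata` pydantic model, to a 3-tuple `(metadata_dict, dc_properties, getstate)`, where metadata travels as a plain dict and is re-validated on load. Below is a minimal, self-contained sketch of that round trip; `ObjectMetadata` here is a toy stand-in for `dataclay.metadata.kvdata.ObjectMetadata`, and `dump_state`/`load_state` are hypothetical helpers, not dataClay APIs.

```python
import pickle

import pydantic
from pydantic import BaseModel

# Stand-in for dataclay.config.LEGACY_DEPS: true when pydantic v1 is installed.
LEGACY_DEPS = pydantic.VERSION.startswith("1.")


class ObjectMetadata(BaseModel):
    # Toy model; the real one lives in dataclay.metadata.kvdata.
    id: str
    class_name: str


def dump_state(dc_meta: ObjectMetadata, dc_properties: dict) -> bytes:
    # Mirror of _dc_state: metadata goes over the wire as a plain dict.
    metadata_dict = dc_meta.dict() if LEGACY_DEPS else dc_meta.model_dump()
    # Third slot is getstate; None when the class has no __getstate__/__setstate__.
    return pickle.dumps((metadata_dict, dc_properties, None))


def load_state(payload: bytes) -> tuple[ObjectMetadata, dict]:
    # Mirror of register_objects/load_object: re-validate metadata on load.
    metadata_dict, dc_properties, getstate = pickle.loads(payload)
    if LEGACY_DEPS:
        dc_meta = ObjectMetadata.parse_obj(metadata_dict)
    else:
        dc_meta = ObjectMetadata.model_validate(metadata_dict)
    return dc_meta, dc_properties


meta, props = load_state(dump_state(ObjectMetadata(id="obj-1", class_name="Demo"), {"x": 1}))
assert meta.class_name == "Demo" and props == {"x": 1}
```

Presumably the point of dumping the model to a dict before pickling is that the stored payload no longer depends on pydantic internals, so an object written under pydantic v1 (`LEGACY_DEPS`) can be read under v2 and vice versa, at the cost of one validation pass per load.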