17 changes: 14 additions & 3 deletions src/dataclay/backend/api.py
@@ -10,7 +10,7 @@
 from threadpoolctl import threadpool_limits
 
 from dataclay import utils
-from dataclay.config import set_runtime, settings
+from dataclay.config import LEGACY_DEPS, set_runtime, settings
 from dataclay.event_loop import dc_to_thread_io
 from dataclay.exceptions import (
     DataClayException,
@@ -19,6 +19,7 @@
     ObjectWithWrongBackendIdError,
 )
 from dataclay.lock_manager import lock_manager
+from dataclay.metadata.kvdata import ObjectMetadata
 from dataclay.runtime import BackendRuntime
 from dataclay.utils.serialization import dcdumps, dcloads, recursive_dcloads
 from dataclay.utils.telemetry import trace
@@ -72,22 +73,32 @@ async def is_ready(self, timeout: Optional[float] = None, pause: float = 0.5):
     async def register_objects(self, serialized_objects: Iterable[bytes], make_replica: bool):
         logger.debug("Receiving (%d) objects to register", len(serialized_objects))
         for object_bytes in serialized_objects:
-            state, getstate = await dcloads(object_bytes)
-            instance = await self.runtime.get_object_by_id(state["_dc_meta"].id)
+            metadata_dict, dc_properties, getstate = await dcloads(object_bytes)
+
+            if LEGACY_DEPS:
+                dc_meta = ObjectMetadata.parse_obj(metadata_dict)
+            else:
+                dc_meta = ObjectMetadata.model_validate(metadata_dict)
+
+            instance = await self.runtime.get_object_by_id(dc_meta.id)
 
             if instance._dc_is_local:
                 assert instance._dc_is_replica
                 if make_replica:
                     logger.warning("Replica already exists with id=%s", instance._dc_meta.id)
                 continue
 
+            state = {"_dc_meta": dc_meta}
+
             async with lock_manager.get_lock(instance._dc_meta.id).writer_lock:
                 # Update object state and flags
                 state["_dc_is_loaded"] = True
                 state["_dc_is_local"] = True
                 vars(instance).update(state)
                 if getstate:
                     instance.__setstate__(getstate)
+                else:
+                    vars(instance).update(dc_properties)
                 self.runtime.data_manager.add_hard_reference(instance)
 
             if make_replica:
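The substantive change in register_objects is that dcloads now yields the metadata as a plain dict instead of a pickled ObjectMetadata instance, and the receiving backend re-validates it with whichever pydantic major version it runs (parse_obj on v1, model_validate on v2). A minimal sketch of that validate-side shim; the hard-coded LEGACY_DEPS flag, the stripped-down ObjectMetadata, and the metadata_from_dict helper are stand-ins for the real dataclay definitions:

```python
from uuid import UUID, uuid4

from pydantic import BaseModel

# In dataclay this flag reflects the installed pydantic major version;
# hard-coded here for the sketch.
LEGACY_DEPS = False


class ObjectMetadata(BaseModel):
    """Stand-in carrying just two of the real model's fields."""

    id: UUID
    class_name: str


def metadata_from_dict(metadata_dict: dict) -> ObjectMetadata:
    # pydantic v1 spells validation parse_obj(); v2 renamed it to model_validate()
    if LEGACY_DEPS:
        return ObjectMetadata.parse_obj(metadata_dict)
    return ObjectMetadata.model_validate(metadata_dict)


meta = metadata_from_dict({"id": str(uuid4()), "class_name": "Person"})
assert isinstance(meta.id, UUID)  # the str id is coerced back to a UUID
```

Since only plain dicts cross the wire, a v1 sender and a v2 receiver (or vice versa) can interoperate, which unpickling a pydantic model object would not guarantee.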
8 changes: 3 additions & 5 deletions src/dataclay/data_manager.py
@@ -141,20 +141,18 @@ async def load_object(self, instance: DataClayObject):
             path = f"{settings.storage_path}/{object_id}"
             # TODO: Is dc_to_thread_cpu necessary? Should this be blocking
             # to avoid bugs with parallel loads?
-            state, getstate = await dc_to_thread_cpu(pickle.load, open(path, "rb"))
+            metadata_dict, dc_properties, getstate = await dc_to_thread_cpu(pickle.load, open(path, "rb"))
             self.dataclay_stored_objects.dec()
         except Exception as e:
             raise ObjectNotFound(object_id) from e
 
-        # Delete outdated metadata (SSOT stored in Redis)
-        del state["_dc_meta"]
-        vars(instance).update(state)
-
         # NOTE: We need to set _dc_is_loaded before calling __setstate__
         # to avoid infinite recursion
         instance._dc_is_loaded = True
         if getstate is not None:
             instance.__setstate__(getstate)
+        else:
+            vars(instance).update(dc_properties)
 
         self.add_hard_reference(instance)
         logger.debug("(%s) Loaded '%s'", object_id, instance.__class__.__name__)
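load_object now follows the same convention as the serializers: when the stored tuple carries a getstate payload, the object restores itself through the pickle protocol's __setstate__; otherwise the stored property dict is written straight into the instance's __dict__. A toy illustration of the two restore paths, with hypothetical PlainObject and CustomStateObject classes in place of real dataclay objects:

```python
class PlainObject:
    """No custom pickle protocol: attributes are restored by dict update."""


class CustomStateObject:
    """Implements the pickle protocol and controls its own restore."""

    def __getstate__(self):
        return {"x": self.x}

    def __setstate__(self, state):
        self.x = state["x"]


def restore(instance, dc_properties, getstate):
    # Mirrors the branch added in load_object
    if getstate is not None:
        instance.__setstate__(getstate)
    else:
        vars(instance).update(dc_properties)


plain = PlainObject.__new__(PlainObject)  # allocate without running __init__
restore(plain, {"x": 1}, None)
assert plain.x == 1

custom = CustomStateObject.__new__(CustomStateObject)
restore(custom, {}, {"x": 2})
assert custom.x == 2
```

Note that the deleted `del state["_dc_meta"]` bookkeeping disappears entirely: with metadata travelling as a separate tuple element (and Redis remaining the single source of truth), there is nothing stale to strip out of the property dict.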
12 changes: 8 additions & 4 deletions src/dataclay/dataclay_object.py
@@ -313,12 +313,16 @@ def _dc_properties(self) -> dict[str, Any]:
     @property
-    def _dc_state(self) -> tuple[dict, Any]:
+    def _dc_state(self) -> tuple[dict, Any, Any]:
         """Returns the object state"""
-        state = {"_dc_meta": self._dc_meta}
 
+        if LEGACY_DEPS:
+            metadata_dict = self._dc_meta.dict()
+        else:
+            metadata_dict = self._dc_meta.model_dump()
+
         if hasattr(self, "__getstate__") and hasattr(self, "__setstate__"):
-            return state, self.__getstate__()
+            return metadata_dict, None, self.__getstate__()
         else:
-            state.update(self._dc_properties)
-            return state, None
+            return metadata_dict, self._dc_properties, None
 
     @property
     def _dc_all_backend_ids(self) -> set[UUID]:
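_dc_state is the dump-side half of the change: it now returns a (metadata_dict, properties, getstate) triple in which the metadata is a plain dict, produced with .dict() under pydantic v1 and .model_dump() under v2. A small round-trip sketch of why pickling the dict instead of the model matters; ObjectMetadata is again a stripped-down stand-in, not the real dataclay model:

```python
import pickle
from uuid import UUID, uuid4

from pydantic import BaseModel


class ObjectMetadata(BaseModel):
    id: UUID
    class_name: str


meta = ObjectMetadata(id=uuid4(), class_name="Person")

# Dump to a plain dict before pickling (v2 spelling; under LEGACY_DEPS this
# would be meta.dict()). The payload then contains only stdlib types and is
# independent of pydantic's internal representation, which differs between
# v1 and v2 and can break cross-version unpickling of model instances.
payload = pickle.dumps(meta.model_dump())

restored = ObjectMetadata.model_validate(pickle.loads(payload))
assert restored == meta
```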
18 changes: 13 additions & 5 deletions src/dataclay/utils/serialization.py
@@ -10,9 +10,10 @@
 from uuid import UUID
 
 from dataclay import utils
-from dataclay.config import get_runtime
+from dataclay.config import LEGACY_DEPS, get_runtime
 from dataclay.dataclay_object import DataClayObject
 from dataclay.event_loop import dc_to_thread_cpu, get_dc_event_loop
+from dataclay.metadata.kvdata import ObjectMetadata
 
 logger = logging.getLogger(__name__)

@@ -151,22 +152,29 @@ async def recursive_dcloads(object_binary, unserialized_objects: dict[UUID, Data
         unserialized_objects = {}
 
     # Use dc_to_thread_cpu to avoid blocking the event loop in `get_by_id_sync`
-    object_dict, state = await dc_to_thread_cpu(
+    object_metadata_dict, dc_properties, state = await dc_to_thread_cpu(
         RecursiveDataClayObjectUnpickler(io.BytesIO(object_binary), unserialized_objects).load
     )
 
-    object_id = object_dict["_dc_meta"].id
+    if LEGACY_DEPS:
+        dc_meta = ObjectMetadata.parse_obj(object_metadata_dict)
+    else:
+        dc_meta = ObjectMetadata.model_validate(object_metadata_dict)
+
+    object_id = dc_meta.id
     try:
         # In case it was already unserialized by a reference
         proxy_object = unserialized_objects[object_id]
     except KeyError:
-        cls: type[DataClayObject] = utils.get_class_by_name(object_dict["_dc_meta"].class_name)
+        cls: type[DataClayObject] = utils.get_class_by_name(dc_meta.class_name)
         proxy_object = cls.new_proxy_object()
         unserialized_objects[object_id] = proxy_object
 
-    vars(proxy_object).update(object_dict)
+    vars(proxy_object)["_dc_meta"] = dc_meta
     if state is not None:
         proxy_object.__setstate__(state)
+    else:
+        vars(proxy_object).update(dc_properties)
     return proxy_object


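Besides adopting the same metadata shim, recursive_dcloads keeps its two-phase strategy for reference cycles: an empty proxy is registered under the object id before any attributes are filled in, so a reference met mid-unpickling resolves to the already-created proxy instead of recursing forever. A toy version of that pattern, using __new__ and a hypothetical Node class in place of new_proxy_object and real DataClayObject instances:

```python
from uuid import uuid4


class Node:
    """Hypothetical placeholder for a DataClayObject subclass."""


unserialized = {}  # object id -> proxy, mirroring `unserialized_objects`


def get_or_create_proxy(object_id, cls):
    try:
        # Already created while resolving a reference to it
        return unserialized[object_id]
    except KeyError:
        proxy = cls.__new__(cls)  # allocate without running __init__
        unserialized[object_id] = proxy
        return proxy


a_id, b_id = uuid4(), uuid4()
a = get_or_create_proxy(a_id, Node)
b = get_or_create_proxy(b_id, Node)

# Phase two: fill in attributes, cycles and all
vars(a).update({"name": "a", "peer": b})
vars(b).update({"name": "b", "peer": a})

assert a.peer is b and b.peer.peer is b
```

On top of that pattern, the PR only changes the final population step: `_dc_meta` is now set from the re-validated model, and plain properties are applied only when the object does not restore itself via __setstate__.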