Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions labtech/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@
from .types import ResultT, Storage, Task, TaskT


def _format_class_name(cls):
"""Format a qualified class name in format that can be safely
represented in a filename."""
# Nested classes may have periods that need to be replaced.
return cls.__qualname__.replace('.', '-')


class NullCache(Cache):
"""Cache that never stores results in the storage provider."""

Expand Down Expand Up @@ -63,7 +70,7 @@ def cache_key(self, task: Task) -> str:
# hashes, and security concerns with sha1 are not relevant to
# our use case.
hashed = hashlib.sha1(serialized_str).hexdigest()
return f'{self.KEY_PREFIX}{task.__class__.__qualname__}__{hashed}'
return f'{self.KEY_PREFIX}{_format_class_name(task.__class__)}__{hashed}'

def is_cached(self, storage: Storage, task: Task) -> bool:
    """Report whether *storage* already holds an entry under the task's
    cache key."""
    cache_key = task.cache_key
    return storage.exists(cache_key)
Expand Down Expand Up @@ -91,7 +98,7 @@ def save(self, storage: Storage, task: Task[ResultT], task_result: TaskResult[Re
self.save_result(storage, task, task_result.value)

def load_metadata(self, storage: Storage, task_type: type[Task], key: str) -> dict[str, Any]:
if not key.startswith(f'{self.KEY_PREFIX}{task_type.__qualname__}'):
if not key.startswith(f'{self.KEY_PREFIX}{_format_class_name(task_type)}'):
raise TaskNotFound
with storage.file_handle(key, self.METADATA_FILENAME, mode='r') as metadata_file:
metadata = json.load(metadata_file)
Expand Down
16 changes: 12 additions & 4 deletions labtech/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,17 @@ def deserialize_enum(self, serialized: dict[str, jsonable]) -> Enum:
return enum_cls[name]

def serialize_class(self, cls: type) -> jsonable:
    """Serialize *cls* as ``'<module>.<qualname>'``.

    Any '.' inside the qualified name (the class-nesting separator) is
    rewritten as '>', so that the *last* '.' in the result always marks
    the boundary between the module path and the class path.
    """
    nesting_path = '>'.join(cls.__qualname__.split('.'))
    return f'{cls.__module__}.{nesting_path}'

def deserialize_class(self, serialized_class: jsonable) -> type:
cls_module, cls_name = cast('str', serialized_class).rsplit('.', 1)
module = __import__(cls_module, fromlist=[cls_name])
return getattr(module, cls_name)
cls_module, cls_qualname = cast('str', serialized_class).rsplit('.', 1)
cls_name_parts = cls_qualname.split('>')
module = __import__(cls_module, fromlist=[cls_name_parts[0]])

cls = getattr(module, cls_name_parts[0])
for part in cls_name_parts[1:]:
# Navigate to nested class.
cls = getattr(cls, part)

return cls
19 changes: 11 additions & 8 deletions tests/integration/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
from labtech.types import Task


class Classifiers:

    # A task class nested inside another class, exercising support for
    # nested-class cache keys and serialization.
    @labtech.task(cache=None)
    class ClassifierTask:
        n_estimators: int

        def run(self) -> dict:
            return {'n_estimators': self.n_estimators}


class ExperimentTask(Protocol):
Expand All @@ -33,7 +36,7 @@ def run(self) -> dict:

@labtech.task
class ClassifierExperiment(ExperimentTask):
classifier_task: ClassifierTask
classifier_task: Classifiers.ClassifierTask
dataset_key: str

def filter_context(self, context: dict[str, Any]) -> dict[str, Any]:
Expand Down Expand Up @@ -97,7 +100,7 @@ class Evaluation(TypedDict):
def basic_evaluation(context: dict[str, Any]) -> Evaluation:
"""Evaluation of a standard setup of multiple levels of dependency."""
classifier_tasks = [
ClassifierTask(
Classifiers.ClassifierTask(
n_estimators=n_estimators,
)
for n_estimators in range(1, 3)
Expand Down Expand Up @@ -142,7 +145,7 @@ def repeated_dependency_evaluation(context: dict[str, Any]) -> Evaluation:
dependency task."""
def experiment_factory():
return ClassifierExperiment(
classifier_task=ClassifierTask(n_estimators=2),
classifier_task=Classifiers.ClassifierTask(n_estimators=2),
dataset_key=list(context['DATASETS'].keys())[0],
)
experiment = experiment_factory()
Expand Down