From 7b124d840cdfda769dcb8d37406923edd29fb014 Mon Sep 17 00:00:00 2001 From: Ben Denham Date: Sat, 31 May 2025 11:32:28 +1200 Subject: [PATCH] Add support for nested classes as task types. --- labtech/cache.py | 11 +++++++++-- labtech/serialization.py | 16 ++++++++++++---- tests/integration/test_e2e.py | 19 +++++++++++-------- 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/labtech/cache.py b/labtech/cache.py index 103a599..91accce 100644 --- a/labtech/cache.py +++ b/labtech/cache.py @@ -19,6 +19,13 @@ from .types import ResultT, Storage, Task, TaskT +def _format_class_name(cls): + """Format a qualified class name in a format that can be safely + represented in a filename.""" + # Nested classes may have periods that need to be replaced. + return cls.__qualname__.replace('.', '-') + + class NullCache(Cache): """Cache that never stores results in the storage provider.""" @@ -63,7 +70,7 @@ def cache_key(self, task: Task) -> str: # hashes, and security concerns with sha1 are not relevant to # our use case. 
hashed = hashlib.sha1(serialized_str).hexdigest() - return f'{self.KEY_PREFIX}{task.__class__.__qualname__}__{hashed}' + return f'{self.KEY_PREFIX}{_format_class_name(task.__class__)}__{hashed}' def is_cached(self, storage: Storage, task: Task) -> bool: return storage.exists(task.cache_key) @@ -91,7 +98,7 @@ def save(self, storage: Storage, task: Task[ResultT], task_result: TaskResult[Re self.save_result(storage, task, task_result.value) def load_metadata(self, storage: Storage, task_type: type[Task], key: str) -> dict[str, Any]: - if not key.startswith(f'{self.KEY_PREFIX}{task_type.__qualname__}'): + if not key.startswith(f'{self.KEY_PREFIX}{_format_class_name(task_type)}'): raise TaskNotFound with storage.file_handle(key, self.METADATA_FILENAME, mode='r') as metadata_file: metadata = json.load(metadata_file) diff --git a/labtech/serialization.py b/labtech/serialization.py index e18d91f..bdb94bd 100644 --- a/labtech/serialization.py +++ b/labtech/serialization.py @@ -122,9 +122,17 @@ def deserialize_enum(self, serialized: dict[str, jsonable]) -> Enum: return enum_cls[name] def serialize_class(self, cls: type) -> jsonable: - return f'{cls.__module__}.{cls.__qualname__}' + # Handle nested classes by delimiting the class nesting path with ">". + return f'{cls.__module__}.{cls.__qualname__.replace(".", ">")}' def deserialize_class(self, serialized_class: jsonable) -> type: - cls_module, cls_name = cast('str', serialized_class).rsplit('.', 1) - module = __import__(cls_module, fromlist=[cls_name]) - return getattr(module, cls_name) + cls_module, cls_qualname = cast('str', serialized_class).rsplit('.', 1) + cls_name_parts = cls_qualname.split('>') + module = __import__(cls_module, fromlist=[cls_name_parts[0]]) + + cls = getattr(module, cls_name_parts[0]) + for part in cls_name_parts[1:]: + # Navigate to nested class. 
+ cls = getattr(cls, part) + + return cls diff --git a/tests/integration/test_e2e.py b/tests/integration/test_e2e.py index 04ffe9a..24a972a 100644 --- a/tests/integration/test_e2e.py +++ b/tests/integration/test_e2e.py @@ -16,12 +16,15 @@ from labtech.types import Task -@labtech.task(cache=None) -class ClassifierTask: - n_estimators: int +class Classifiers: - def run(self) -> dict: - return {'n_estimators': self.n_estimators} + # Testing a nested class + @labtech.task(cache=None) + class ClassifierTask: + n_estimators: int + + def run(self) -> dict: + return {'n_estimators': self.n_estimators} class ExperimentTask(Protocol): @@ -33,7 +36,7 @@ def run(self) -> dict: @labtech.task class ClassifierExperiment(ExperimentTask): - classifier_task: ClassifierTask + classifier_task: Classifiers.ClassifierTask dataset_key: str def filter_context(self, context: dict[str, Any]) -> dict[str, Any]: @@ -97,7 +100,7 @@ class Evaluation(TypedDict): def basic_evaluation(context: dict[str, Any]) -> Evaluation: """Evaluation of a standard setup of multiple levels of dependency.""" classifier_tasks = [ - ClassifierTask( + Classifiers.ClassifierTask( n_estimators=n_estimators, ) for n_estimators in range(1, 3) @@ -142,7 +145,7 @@ def repeated_dependency_evaluation(context: dict[str, Any]) -> Evaluation: dependency task.""" def experiment_factory(): return ClassifierExperiment( - classifier_task=ClassifierTask(n_estimators=2), + classifier_task=Classifiers.ClassifierTask(n_estimators=2), dataset_key=list(context['DATASETS'].keys())[0], ) experiment = experiment_factory()