diff --git a/albert/base.py b/albert/base.py index b7f31f7..f426143 100644 --- a/albert/base.py +++ b/albert/base.py @@ -2,8 +2,10 @@ from __future__ import annotations +import importlib from abc import ABC, abstractmethod from enum import Enum +from functools import cache from typing import TYPE_CHECKING, Callable, Optional, TypeVar, cast from albert.hashing import InternTable @@ -463,14 +465,13 @@ def _hashable_fields(self) -> Iterable[SerialisedField]: yield from self._children @classmethod - @abstractmethod def from_json(cls, data: Any) -> Base: """Return an object loaded from a JSON representation. Returns: Object loaded from JSON representation. """ - pass + return _get_class(data["_module"], data["_type"]).from_json(data) @abstractmethod def __repr__(self) -> str: @@ -560,3 +561,17 @@ def __neg__(self) -> Base: _INTERN_TABLE = InternTable[Base]() + + +@cache +def _get_class(module: str, type_: str) -> type[Base]: + """Get a class from JSON data for deserialisation. + + Args: + module: Module path of the class. + type_: Type name of the class. + + Returns: + Class object. + """ + return cast(type[Base], getattr(importlib.import_module(module), type_)) diff --git a/albert/types.pyi b/albert/types.pyi index 893a30e..2b1a162 100644 --- a/albert/types.pyi +++ b/albert/types.pyi @@ -23,63 +23,56 @@ class SerialisedField(Comparable, Hashable, Protocol): pass -class _ScalarJSON(TypedDict): - """Type for JSON representation of a scalar.""" +class _BaseJSON(TypedDict): + """Base type for JSON representation.""" _type: str _module: str + + +class _ScalarJSON(_BaseJSON): + """Type for JSON representation of a scalar.""" + value: float -class _IndexJSON(TypedDict): +class _IndexJSON(_BaseJSON): """Type for JSON representation of an index.""" - _type: str - _module: str name: str spin: Optional[str] space: Optional[str] -class _PermutationJSON(TypedDict): +class _PermutationJSON(_BaseJSON): """Type for JSON representation of a permutation.""" - _type: str - _module: str permutation: tuple[int, ...] sign: int -class _SymmetryJSON(TypedDict): +class _SymmetryJSON(_BaseJSON): """Type for JSON representation of a symmetry group.""" - _type: str - _module: str permutations: tuple[_PermutationJSON, ...] -class _TensorJSON(TypedDict): +class _TensorJSON(_BaseJSON): """Type for JSON representation of a tensor.""" - _type: str - _module: str indices: tuple[_IndexJSON, ...] name: str symmetry: Optional[_SymmetryJSON] -class _AlgebraicJSON(TypedDict): +class _AlgebraicJSON(_BaseJSON): """Type for JSON representation of an algebraic operation.""" - _type: str - _module: str children: tuple[_AlgebraicJSON | _TensorJSON, ...] -class _ExpressionJSON(TypedDict): +class _ExpressionJSON(_BaseJSON): """Type for JSON representation of an expression.""" - _type: str - _module: str lhs: _TensorJSON rhs: _TensorJSON | _AlgebraicJSON