Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions albert/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from __future__ import annotations

import importlib
from abc import ABC, abstractmethod
from enum import Enum
from functools import cache
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, cast

from albert.hashing import InternTable
Expand Down Expand Up @@ -463,14 +465,13 @@ def _hashable_fields(self) -> Iterable[SerialisedField]:
yield from self._children

@classmethod
@abstractmethod
def from_json(cls, data: Any) -> Base:
"""Return an object loaded from a JSON representation.

Returns:
Object loaded from JSON representation.
"""
pass
return _get_class(data["_module"], data["_type"]).from_json(data)

@abstractmethod
def __repr__(self) -> str:
Expand Down Expand Up @@ -560,3 +561,17 @@ def __neg__(self) -> Base:


_INTERN_TABLE = InternTable[Base]()


@cache
def _get_class(module: str, type_: str) -> type[Base]:
"""Get a class from JSON data for deserialisation.

Args:
module: Module path of the class.
type_: Type name of the class.

Returns:
Class object.
"""
return cast(type[Base], getattr(importlib.import_module(module), type_))
33 changes: 13 additions & 20 deletions albert/types.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -23,63 +23,56 @@ class SerialisedField(Comparable, Hashable, Protocol):
pass


class _ScalarJSON(TypedDict):
"""Type for JSON representation of a scalar."""
class _BaseJSON(TypedDict):
"""Base type for JSON representation."""

_type: str
_module: str


class _ScalarJSON(_BaseJSON):
"""Type for JSON representation of a scalar."""

value: float


class _IndexJSON(TypedDict):
class _IndexJSON(_BaseJSON):
"""Type for JSON representation of an index."""

_type: str
_module: str
name: str
spin: Optional[str]
space: Optional[str]


class _PermutationJSON(TypedDict):
class _PermutationJSON(_BaseJSON):
"""Type for JSON representation of a permutation."""

_type: str
_module: str
permutation: tuple[int, ...]
sign: int


class _SymmetryJSON(TypedDict):
class _SymmetryJSON(_BaseJSON):
"""Type for JSON representation of a symmetry group."""

_type: str
_module: str
permutations: tuple[_PermutationJSON, ...]


class _TensorJSON(TypedDict):
class _TensorJSON(_BaseJSON):
"""Type for JSON representation of a tensor."""

_type: str
_module: str
indices: tuple[_IndexJSON, ...]
name: str
symmetry: Optional[_SymmetryJSON]


class _AlgebraicJSON(TypedDict):
class _AlgebraicJSON(_BaseJSON):
"""Type for JSON representation of an algebraic operation."""

_type: str
_module: str
children: tuple[_AlgebraicJSON | _TensorJSON, ...]


class _ExpressionJSON(TypedDict):
class _ExpressionJSON(_BaseJSON):
"""Type for JSON representation of an expression."""

_type: str
_module: str
lhs: _TensorJSON
rhs: _TensorJSON | _AlgebraicJSON