diff --git a/albert/algebra.py b/albert/algebra.py index 1599c66..a3b9745 100644 --- a/albert/algebra.py +++ b/albert/algebra.py @@ -12,11 +12,11 @@ from albert.scalar import Scalar if TYPE_CHECKING: - from typing import Any, Iterable + from typing import Any, Callable, Iterable from albert.base import TypeOrFilter from albert.index import Index - from albert.types import _AlgebraicJSON + from albert.types import EvaluatorArrayDict, _AlgebraicJSON T = TypeVar("T", bound=Base) @@ -299,6 +299,31 @@ def delete( return self.factory(*children) return Scalar.factory(0.0) + def evaluate( + self, + arrays: EvaluatorArrayDict, + einsum: Callable[..., Any], + ) -> Any: + """Evaluate the node numerically. + + Args: + arrays: Mapping to provide numerical arrays for tensors. The mapping must be in one of + the following formats: + 1. ``{tensor_name: { (space1, space2, ...): array, ... }, ...}`` + 2. ``{tensor_name: { "space1space2...": array, ...}, ...}`` + 3. ``{tensor_name: array, ...}`` (only for tensors with no indices) + einsum: Function to perform tensor contraction. + + Returns: + Evaluated node, as an array. + """ + return sum( + child.evaluate(arrays, einsum).transpose( + tuple(child.external_indices.index(index) for index in self.external_indices) + ) + for child in self.children + ) + @property def disjoint(self) -> bool: """Return whether the object is disjoint.""" @@ -447,6 +472,44 @@ def delete( return self.factory(*children) return Scalar.factory(0.0) + def evaluate( + self, + arrays: EvaluatorArrayDict, + einsum: Callable[..., Any], + ) -> Any: + """Evaluate the node numerically. + + Args: + arrays: Mapping to provide numerical arrays for tensors. The mapping must be in one of + the following formats: + 1. ``{tensor_name: { (space1, space2, ...): array, ... }, ...}`` + 2. ``{tensor_name: { "space1space2...": array, ...}, ...}`` + 3. ``{tensor_name: array, ...}`` (only for tensors with no indices) + einsum: Function to perform tensor contraction. + + Returns: + Evaluated node, as an array. + """ + # Get the arrays and indices + child_index_map: dict[Index, int] = {} + factor = 1.0 + args: list[Any] = [] + for child in self.children: + if isinstance(child, Scalar): + factor *= child.evaluate(arrays, einsum) + else: + for index in child.external_indices: + if index not in child_index_map: + child_index_map[index] = len(child_index_map) + args.append(child.evaluate(arrays, einsum)) + args.append(tuple(child_index_map[index] for index in child.external_indices)) + + # Call the einsum function + output_indices = tuple(child_index_map[index] for index in self.external_indices) + result = einsum(*args, output_indices) + + return result * factor + @property def disjoint(self) -> bool: """Return whether the object is disjoint.""" diff --git a/albert/base.py b/albert/base.py index 273ea27..b7f31f7 100644 --- a/albert/base.py +++ b/albert/base.py @@ -15,7 +15,7 @@ from albert.index import Index from albert.symmetry import Permutation - from albert.types import SerialisedField + from albert.types import EvaluatorArrayDict, SerialisedField T = TypeVar("T", bound="Base") TypeOrFilter = Optional[type[T] | tuple[type[T], ...] | Callable[["Base"], bool]] @@ -274,6 +274,27 @@ def delete( """ pass + @abstractmethod + def evaluate( + self, + arrays: EvaluatorArrayDict, + einsum: Callable[..., Any], + ) -> Any: + """Evaluate the node numerically. + + Args: + arrays: Mapping to provide numerical arrays for tensors. The mapping must be in one of + the following formats: + 1. ``{tensor_name: { (space1, space2, ...): array, ... }, ...}`` + 2. ``{tensor_name: { "space1space2...": array, ...}, ...}`` + 3. ``{tensor_name: array, ...}`` (only for tensors with no indices) + einsum: Function to perform tensor contraction. + + Returns: + Evaluated node, as an array. + """ + pass + @property def external_indices(self) -> tuple[Index, ...]: """Get the external indices (those that are not summed over).""" diff --git a/albert/expression.py b/albert/expression.py index d65fc9d..ed82f00 100644 --- a/albert/expression.py +++ b/albert/expression.py @@ -8,10 +8,10 @@ from albert.tensor import Tensor if TYPE_CHECKING: - from typing import Iterable + from typing import Any, Callable, Iterable from albert.index import Index - from albert.types import SerialisedField, _ExpressionJSON + from albert.types import EvaluatorArrayDict, SerialisedField, _ExpressionJSON class Expression(Serialisable): @@ -71,6 +71,26 @@ def copy(self) -> Expression: """ return Expression(self._lhs.copy(), self._rhs.copy()) + def evaluate( + self, + arrays: EvaluatorArrayDict, + einsum: Callable[..., Any], + ) -> Any: + """Evaluate the node numerically. + + Args: + arrays: Mapping to provide numerical arrays for tensors. The mapping must be in one of + the following formats: + 1. ``{tensor_name: { (space1, space2, ...): array, ... }, ...}`` + 2. ``{tensor_name: { "space1space2...": array, ...}, ...}`` + 3. ``{tensor_name: array, ...}`` (only for tensors with no indices) + einsum: Function to perform tensor contraction. + + Returns: + Evaluated node, as an array. + """ + return self.rhs.evaluate(arrays, einsum) + def as_json(self) -> _ExpressionJSON: """Return a JSON representation of the object. diff --git a/albert/scalar.py b/albert/scalar.py index 6bab5ac..9f5a722 100644 --- a/albert/scalar.py +++ b/albert/scalar.py @@ -7,11 +7,11 @@ from albert.base import _INTERN_TABLE, Base, _matches_filter if TYPE_CHECKING: - from typing import Any, Optional + from typing import Any, Callable, Optional from albert.base import TypeOrFilter from albert.index import Index - from albert.types import _ScalarJSON + from albert.types import EvaluatorArrayDict, _ScalarJSON T = TypeVar("T", bound=Base) @@ -79,6 +79,26 @@ def delete( """ return Scalar.factory(0.0) if _matches_filter(self, type_filter) else self + def evaluate( + self, + arrays: EvaluatorArrayDict, + einsum: Callable[..., Any], + ) -> Any: + """Evaluate the node numerically. + + Args: + arrays: Mapping to provide numerical arrays for tensors. The mapping must be in one of + the following formats: + 1. ``{tensor_name: { (space1, space2, ...): array, ... }, ...}`` + 2. ``{tensor_name: { "space1space2...": array, ...}, ...}`` + 3. ``{tensor_name: array, ...}`` (only for tensors with no indices) + einsum: Function to perform tensor contraction. + + Returns: + Evaluated node, as an array. + """ + return self.value + @property def value(self) -> float: """Get the value of the scalar.""" diff --git a/albert/tensor.py b/albert/tensor.py index 2ede2a9..2bd8084 100644 --- a/albert/tensor.py +++ b/albert/tensor.py @@ -10,11 +10,11 @@ from albert.scalar import Scalar if TYPE_CHECKING: - from typing import Any, Optional + from typing import Any, Callable, Optional from albert.base import TypeOrFilter from albert.symmetry import Permutation, Symmetry - from albert.types import _TensorJSON + from albert.types import EvaluatorArrayDict, _TensorJSON T = TypeVar("T", bound=Base) @@ -99,6 +99,36 @@ def delete( """ return Scalar.factory(0.0) if _matches_filter(self, type_filter) else self + def evaluate( + self, + arrays: EvaluatorArrayDict, + einsum: Callable[..., Any], + ) -> Any: + """Evaluate the node numerically. + + Args: + arrays: Mapping to provide numerical arrays for tensors. The mapping must be in one of + the following formats: + 1. ``{tensor_name: { (space1, space2, ...): array, ... }, ...}`` + 2. ``{tensor_name: { "space1space2...": array, ...}, ...}`` + 3. ``{tensor_name: array, ...}`` (only for tensors with no indices) + einsum: Function to perform tensor contraction. + + Returns: + Evaluated node, as an array. + """ + spaces = tuple(index.space for index in self.indices) + if all(space is None for space in spaces): + return arrays[self.name] + elif any(space is None for space in spaces): + raise ValueError("Cannot evaluate tensor with some indices missing spaces") + spaces = cast(tuple[str, ...], spaces) + if spaces in arrays[self.name]: + return arrays[self.name][spaces] + elif "".join(spaces) in arrays[self.name]: + return arrays[self.name]["".join(spaces)] + raise ValueError(f"No tensor {self.name} with spaces {spaces} provided") + @property def indices(self) -> tuple[Index, ...]: """Get the indices of the object.""" diff --git a/albert/types.pyi b/albert/types.pyi index 0bbe0f9..893a30e 100644 --- a/albert/types.pyi +++ b/albert/types.pyi @@ -2,6 +2,9 @@ from __future__ import annotations from typing import Any, Hashable, Protocol, Optional, TypedDict +EvaluatorArrayDict = dict[str, dict[tuple[str, ...] | str, Any]] | dict[str, Any] + + class Comparable(Protocol): """Protocol for comparable objects."""