Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions albert/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@
}

ALLOW_NON_EINSTEIN_NOTATION = 0
INFER_ALGEBRA_SYMMETRIES = 0
87 changes: 68 additions & 19 deletions albert/algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,22 @@
import itertools
from collections import defaultdict
from functools import reduce
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING

from albert import ALLOW_NON_EINSTEIN_NOTATION
from albert import ALLOW_NON_EINSTEIN_NOTATION, INFER_ALGEBRA_SYMMETRIES
from albert.base import _INTERN_TABLE, Base, _matches_filter
from albert.scalar import Scalar
from albert.symmetry import infer_symmetry_add, infer_symmetry_mul

if TYPE_CHECKING:
from typing import Any, Callable, Iterable
from typing import Any, Callable, Iterable, Optional

from albert.base import TypeOrFilter
from albert.index import Index
from albert.symmetry import Symmetry
from albert.types import EvaluatorArrayDict, _AlgebraicJSON

T = TypeVar("T", bound=Base)
# T = TypeVar("T", bound=Base)


def _check_indices(children: Iterable[Base]) -> dict[Index, int]:
Expand Down Expand Up @@ -66,20 +68,39 @@ class Algebraic(Base):
children: Children to operate on.
"""

__slots__ = ("_hash", "_children")
__slots__ = ("_hash", "_children", "_symmetry")

_children: tuple[Base, ...]

def __init__(self, *children: Base):
def __init__(self, *children: Base, symmetry: Optional[Symmetry] = None):
"""Initialise the addition."""
self._hash = None
self._children = children
self._symmetry = symmetry

def copy(self, *children: Base) -> Algebraic:
"""Return a copy of the object with optionally updated attributes."""
def copy(self, *children: Base, symmetry: Optional[Symmetry] = None) -> Algebraic:
"""Return a copy of the object with optionally updated attributes.

Args:
children: New children.
symmetry: New symmetry.

Returns:
Copy of the object.

Note:
Since ``albert`` objects are immutable, the copy may be an interned object. If this
method is called without any arguments, the original object may be returned.
"""
if not children:
children = self.children
return self.__class__(*children)
if symmetry is None:
symmetry = self._symmetry
factory_copy = self.factory(*children, symmetry=symmetry)
if isinstance(factory_copy, self.__class__):
# factory method may return a different class, only return it if not
return factory_copy
return self.__class__(*children, symmetry=symmetry)

def map_indices(self, mapping: dict[Index, Index]) -> Algebraic:
"""Return a copy of the object with the indices mapped according to some dictionary.
Expand Down Expand Up @@ -211,27 +232,41 @@ class Add(Algebraic):
children: Children to add.
"""

__slots__ = ("_hash", "_children", "_internal_indices", "_external_indices")
__slots__ = ("_hash", "_children", "_symmetry", "_internal_indices", "_external_indices")

_score = 2

def __init__(self, *children: Base):
def __init__(self, *children: Base, symmetry: Optional[Symmetry] = None):
"""Initialise the addition."""
if len(set(tuple(sorted(child.external_indices)) for child in children)) > 1:
raise ValueError("External indices in additions must be equal.")
super().__init__(*children)
super().__init__(*children, symmetry=symmetry)

# Precompute indices
self._external_indices = children[0].external_indices
self._internal_indices = ()

# Try to infer symmetry if not provided
if (
symmetry is None
and all(child.symmetry is not None for child in children)
and INFER_ALGEBRA_SYMMETRIES
):
self._symmetry = infer_symmetry_add(self)

@classmethod
def factory(cls: type[Add], *children: Base, cls_scalar: type[Scalar] | None = None) -> Base:
def factory(
cls: type[Add],
*children: Base,
symmetry: Optional[Symmetry] = None,
cls_scalar: type[Scalar] | None = None,
) -> Base:
"""Factory method to create a new object.

Args:
cls: The class of the addition to create.
children: The children of the addition.
symmetry: Symmetry of the addition.
cls_scalar: Class to use for scalars.

Returns:
Expand Down Expand Up @@ -265,7 +300,7 @@ def factory(cls: type[Add], *children: Base, cls_scalar: type[Scalar] | None = N
other.append(child)

# Build a key for interning
key = (cls, value, tuple(other)) # Commutative but not canonical
key = (cls, value, tuple(other), symmetry) # Commutative but not canonical

def create() -> Base:
if not other:
Expand Down Expand Up @@ -383,26 +418,40 @@ class Mul(Algebraic):
children: Children to multiply
"""

__slots__ = ("_hash", "_children", "_internal_indices", "_external_indices")
__slots__ = ("_hash", "_children", "_symmetry", "_internal_indices", "_external_indices")

_score = 3

def __init__(self, *children: Base):
def __init__(self, *children: Base, symmetry: Optional[Symmetry] = None):
"""Initialise the multiplication."""
super().__init__(*children)
super().__init__(*children, symmetry=symmetry)

# Precompute indices
counts = _check_indices(children)
self._external_indices = tuple(index for index, count in counts.items() if count == 1)
self._internal_indices = tuple(index for index, count in counts.items() if count > 1)

# Try to infer symmetry if not provided
if (
self._symmetry is None
and all(child.symmetry is not None for child in children)
and INFER_ALGEBRA_SYMMETRIES
):
self._symmetry = infer_symmetry_mul(self)

@classmethod
def factory(cls: type[Mul], *children: Base, cls_scalar: type[Scalar] | None = None) -> Base:
def factory(
cls: type[Mul],
*children: Base,
symmetry: Optional[Symmetry] = None,
cls_scalar: type[Scalar] | None = None,
) -> Base:
"""Factory method to create a new object.

Args:
cls: The class of the multiplication to create.
children: The children of the multiplication.
symmetry: Symmetry of the multiplication.
cls_scalar: Class to use for scalars.

Returns:
Expand Down Expand Up @@ -440,7 +489,7 @@ def factory(cls: type[Mul], *children: Base, cls_scalar: type[Scalar] | None = N
return cls_scalar.factory(0.0)

# Build a key for interning
key = (cls, value, tuple(other)) # Commutative but not canonical
key = (cls, value, tuple(other), symmetry) # Commutative but not canonical

def create() -> Base:
if not other:
Expand Down
8 changes: 7 additions & 1 deletion albert/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing_extensions import Self

from albert.index import Index
from albert.symmetry import Permutation
from albert.symmetry import Permutation, Symmetry
from albert.types import EvaluatorArrayDict, SerialisedField

T = TypeVar("T", bound="Base")
Expand Down Expand Up @@ -146,6 +146,7 @@ class Base(Serialisable):

_score: int
_children: Optional[tuple[Base, ...]]
_symmetry: Optional[Symmetry] = None
_penalties: tuple[Callable[[Base], int], ...] = (_sign_penalty,)
_internal_indices: tuple[Index, ...]
_external_indices: tuple[Index, ...]
Expand Down Expand Up @@ -175,6 +176,11 @@ def children(self) -> tuple[Base, ...]:
"""Get the children of the node."""
return self._children or ()

@property
def symmetry(self) -> Optional[Symmetry]:
"""Get the symmetry of the object."""
return self._symmetry

def _search(
self,
level: int,
Expand Down
4 changes: 2 additions & 2 deletions albert/opt/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,8 +718,8 @@ def _optimise_biclique(

# Initialise intermediate tensors
n = len(intermediates)
left = Tensor(*partition.indices_left, name=intermediate_format.format(n))
right = Tensor(*partition.indices_right, name=intermediate_format.format(n + 1))
left = Tensor.factory(*partition.indices_left, name=intermediate_format.format(n))
right = Tensor.factory(*partition.indices_right, name=intermediate_format.format(n + 1))

# Build the intermediate expressions
intermediates[left], intermediates[right] = build_intermediates(
Expand Down
19 changes: 16 additions & 3 deletions albert/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import TYPE_CHECKING, TypeVar, cast

from albert.base import _INTERN_TABLE, Base, _matches_filter
from albert.symmetry import fully_symmetric_group

if TYPE_CHECKING:
from typing import Any, Callable, Optional
Expand All @@ -25,15 +26,23 @@ class Scalar(Base):
value: Value of the scalar.
"""

__slots__ = ("_value", "_hash", "_children", "_internal_indices", "_external_indices")
__slots__ = (
"_value",
"_hash",
"_children",
"_symmetry",
"_internal_indices",
"_external_indices",
)

_score = 1

def __init__(self, value: float = 0.0):
def __init__(self, value: float = 0.0) -> None:
"""Initialise the tensor."""
self._value = value
self._hash = None
self._children = None
self._symmetry = fully_symmetric_group(0)
self._internal_indices = ()
self._external_indices = ()

Expand Down Expand Up @@ -117,10 +126,14 @@ def copy(self, value: Optional[float] = None) -> Scalar:

Returns:
Copy of the object.

Note:
Since ``albert`` objects are immutable, the copy may be an interned object. If this
method is called without any arguments, the original object may be returned.
"""
if value is None:
value = self.value
return Scalar(value)
return self.__class__.factory(value)

def map_indices(self, mapping: dict[Index, Index]) -> Scalar:
"""Return a copy of the object with the indices mapped according to some dictionary.
Expand Down
72 changes: 71 additions & 1 deletion albert/symmetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from __future__ import annotations

import itertools
from math import prod
from typing import TYPE_CHECKING

from albert.base import Serialisable

if TYPE_CHECKING:
from typing import Iterable

from albert.algebra import Add, Mul
from albert.base import Base
from albert.types import SerialisedField, _PermutationJSON, _SymmetryJSON

Expand Down Expand Up @@ -210,7 +212,7 @@ def fully_symmetric_group(n: int) -> Symmetry:
Returns:
Symmetry group.
"""
return Symmetry(*[Permutation(perm, 1) for perm in itertools.permutations(range(n))])
return Symmetry(*sorted(Permutation(perm, 1) for perm in itertools.permutations(range(n))))


def fully_antisymmetric_group(n: int) -> Symmetry:
Expand Down Expand Up @@ -244,3 +246,71 @@ def _permutations(seq: list[int]) -> list[list[int]]:
]

return Symmetry(*permutations)


def infer_symmetry_add(add: Add) -> Symmetry:
"""Infer the symmetry of an addition from its children.

Args:
add: Addition node.

Returns:
Inferred symmetry.
"""
perms: set[Permutation] = set()
for child in add.children:
if child.symmetry is None:
raise ValueError("All children must have symmetry defined to infer symmetry.")
perms.update(child.symmetry.permutations)
return Symmetry(*sorted(perms))


def infer_symmetry_mul(mul: Mul) -> Symmetry:
"""Infer the symmetry of a multiplication from its children.

Args:
mul: Multiplication node.

Returns:
Inferred symmetry.
"""
from albert.scalar import Scalar

# Get all permutations
children = tuple(filter(lambda node: not isinstance(node, Scalar), mul.children))
perms: list[list[Permutation]] = []
for child in children:
if child.symmetry is None:
raise ValueError("All children must have symmetry defined to infer symmetry.")
perms.append(list(child.symmetry.permutations))

# Loop over permutations of each child
result_permutations: set[Permutation] = set()
for permutations in itertools.product(*perms):
index_maps = [
dict(zip(child.external_indices, (child.external_indices[p] for p in perm.permutation)))
for child, perm in zip(children, permutations)
]
sign = prod(perm.sign for perm in permutations)

# Skip if any internal and external indices are permuted in any child
if any(
(src in mul.external_indices) != (dst in mul.external_indices)
for index_map in index_maps
for src, dst in index_map.items()
):
continue

# Find the permutation of the external indices
perm: list[int] = []
for idx in mul.external_indices:
for index_map in index_maps:
if idx in index_map:
idx = index_map[idx]
break
perm.append(mul.external_indices.index(idx))

# Add the resulting permutation
result_permutations.add(Permutation(tuple(perm), sign))

return Symmetry(*sorted(result_permutations))
Loading