Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 10 additions & 17 deletions albert/algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from albert.base import TypeOrFilter
from albert.index import Index
from albert.tensor import Tensor
from albert.types import EvaluatorArrayDict, _AlgebraicJSON

T = TypeVar("T", bound=Base)
Expand Down Expand Up @@ -491,25 +490,19 @@ def evaluate(
Returns:
Evaluated node, as an array.
"""
# Find the scalar factor
factor = 1.0
if self.find(Scalar):
for scalar in self.search(Scalar):
factor *= scalar.evaluate(arrays, einsum)

# Get the arrays and indices
child: Tensor | Algebraic
index: Index
child_index_map: dict[Index, int] = {}
factor = 1.0
args: list[Any] = []
for child in self.search(lambda node: node is not self and not isinstance(node, Scalar)):
for index in child.external_indices:
if index not in child_index_map:
child_index_map[index] = len(child_index_map)
args.append(child.evaluate(arrays, einsum))
args.append(
tuple(child_index_map[index] for index in child.external_indices) # type: ignore
)
for child in self.children:
if isinstance(child, Scalar):
factor *= child.evaluate(arrays, einsum)
else:
for index in child.external_indices:
if index not in child_index_map:
child_index_map[index] = len(child_index_map)
args.append(child.evaluate(arrays, einsum))
args.append(tuple(child_index_map[index] for index in child.external_indices))

# Call the einsum function
output_indices = tuple(child_index_map[index] for index in self.external_indices)
Expand Down
2 changes: 1 addition & 1 deletion albert/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def factory(cls: type[Base], *args: Any, **kwargs: Any) -> Base:
@property
def is_leaf(self) -> bool:
"""Get whether the object is a leaf in a tree."""
return self.children is None
return not bool(self.children)

@property
def children(self) -> tuple[Base, ...]:
Expand Down
File renamed without changes.