Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
de501c1
GEM: simplify indexed
pbrubeck Aug 8, 2025
592061e
Fixes for more complicated expressions
pbrubeck Aug 11, 2025
03e9ae6
Merge branch 'main' into pbrubeck/simplify-indexed
pbrubeck Aug 11, 2025
158b0ec
small change
pbrubeck Aug 11, 2025
d2a4584
More simplification
pbrubeck Aug 11, 2025
b49eb5c
Do not replace free indices
pbrubeck Aug 12, 2025
340ec40
Simplify IndexSum
pbrubeck Aug 12, 2025
a7fda02
Refactor IndexSum unrolling
pbrubeck Aug 14, 2025
eab0d90
Flatten nested ComponentTensors
pbrubeck Aug 14, 2025
41981ec
use numpy.array_equal
pbrubeck Aug 14, 2025
7f58028
Simplify ListTensor(ComponentTensor(Indexed(...)))
pbrubeck Aug 14, 2025
7077cf9
Some indices fixed
pbrubeck Aug 14, 2025
55bb313
style
pbrubeck Aug 14, 2025
10a8604
Add tests
pbrubeck Aug 26, 2025
b738ec6
More simplify
pbrubeck Aug 27, 2025
830a0e3
Merge branch 'main' into pbrubeck/simplify-indexed
pbrubeck Aug 27, 2025
5dfd17d
Fix up
pbrubeck Aug 29, 2025
3fe803d
Merge branch 'main' into pbrubeck/simplify-indexed
pbrubeck Oct 23, 2025
c0a6e02
Merge branch 'main' into pbrubeck/simplify-indexed
pbrubeck Dec 13, 2025
f9c5d2b
Merge branch 'main' into pbrubeck/simplify-indexed
pbrubeck Feb 2, 2026
8d0dc41
expand_fixedindices DAG traverser
pbrubeck Feb 2, 2026
40a7dc7
fix
pbrubeck Feb 2, 2026
2fea3ff
restore Indexed.__new__
pbrubeck Feb 2, 2026
ca7f2f0
Remove fixed indices within remove_componenttensors
pbrubeck Feb 2, 2026
ff2c664
fix
pbrubeck Feb 2, 2026
aab3a94
fixes
pbrubeck Feb 3, 2026
f5d7959
Merge branch 'main' into pbrubeck/simplify-indexed
pbrubeck Feb 4, 2026
04f0cf3
review comments
pbrubeck Feb 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 64 additions & 5 deletions gem/gem.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def is_equal(self, other):
return False
if self.shape != other.shape:
return False
return tuple(self.array.flat) == tuple(other.array.flat)
return numpy.array_equal(self.array, other.array)

def get_hash(self):
return hash((type(self), self.shape, tuple(self.array.flat)))
Expand Down Expand Up @@ -684,8 +684,26 @@ def __new__(cls, aggregate, multiindex):
if isinstance(aggregate, Zero):
return Zero(dtype=aggregate.dtype)

# Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll)
# This pattern corresponds to an index replacement rule jj -> ii applied to
# the innermost multiindex kk to produce ll.
if isinstance(aggregate, ComponentTensor):
B, = aggregate.children
jj = aggregate.multiindex
ii = multiindex
if isinstance(B, Indexed):
C, = B.children
kk = B.multiindex
ff = C.free_indices
if not any((j in ff) for j in jj):
# Only replace indices that are not present in C
rep = dict(zip(jj, ii))
ll = tuple(rep.get(k, k) for k in kk)
aggregate = C
multiindex = ll

# All indices fixed
if all(isinstance(i, int) for i in multiindex):
if all(isinstance(i, Integral) for i in multiindex):
if isinstance(aggregate, Constant):
return Literal(aggregate.array[multiindex], dtype=aggregate.dtype)
elif isinstance(aggregate, ListTensor):
Expand Down Expand Up @@ -835,6 +853,11 @@ def __new__(cls, expression, multiindex):
if isinstance(expression, Zero):
return Zero(shape, dtype=expression.dtype)

# Index folding
if isinstance(expression, Indexed):
if multiindex == expression.multiindex:
return expression.children[0]

self = super(ComponentTensor, cls).__new__(cls)
self.children = (expression,)
self.multiindex = multiindex
Expand Down Expand Up @@ -871,6 +894,11 @@ def __new__(cls, summand, multiindex):
if not multiindex:
return summand

# Flatten nested sums
if isinstance(summand, IndexSum):
A, = summand.children
return IndexSum(A, summand.multiindex + multiindex)

self = super(IndexSum, cls).__new__(cls)
self.children = (summand,)
self.multiindex = multiindex
Expand All @@ -891,15 +919,46 @@ def __new__(cls, array):
dtype = Node.inherit_dtype_from_children(tuple(array.flat))

# Handle children with shape
child_shape = array.flat[0].shape
e0 = array.flat[0]
child_shape = e0.shape
assert all(elem.shape == child_shape for elem in array.flat)

# Simplify [tensor[multiindex, j] for j in range(n)] -> partial_indexed(tensor, multiindex)
if all(isinstance(elem, Indexed) for elem in array.flat):
tensor = e0.children[0]
if all(elem.children[0] == tensor for elem in array.flat[1:]):
# Extract maximal subset of leading indices that is common over all array entries
multiindex = tuple(e0.multiindex)
for elem in array.flat[1:]:
while elem.multiindex[:len(multiindex)] != multiindex:
multiindex = multiindex[:-1]
if len(multiindex) == 0:
break
index_shape = tuple(i.extent if isinstance(i, Index) else 1 for i in multiindex)
if index_shape + array.shape + child_shape == tensor.shape:
if all(elem.multiindex[len(multiindex):] == idx for idx, elem in numpy.ndenumerate(array)):
return partial_indexed(tensor, multiindex)

# Simplify [tensor[j, ...] for j in range(n)] -> tensor
if all(isinstance(elem, ComponentTensor) and isinstance(elem.children[0], Indexed)
for elem in array.flat):
tensor = e0.children[0].children[0]
if array.shape + child_shape == tensor.shape:
if all(elem.children[0].children[0] == tensor for elem in array.flat[1:]):
if all(elem.children[0].multiindex == idx + elem.multiindex
for idx, elem in numpy.ndenumerate(array)):
return tensor

# Flatten nested ListTensors
if all(isinstance(elem, ListTensor) for elem in array.flat):
return ListTensor(asarray([elem.array for elem in array.flat]).reshape(array.shape + child_shape))

if child_shape:
# Destroy structure
direct_array = numpy.empty(array.shape + child_shape, dtype=object)
for alpha in numpy.ndindex(array.shape):
for alpha, elem in numpy.ndenumerate(array):
for beta in numpy.ndindex(child_shape):
direct_array[alpha + beta] = Indexed(array[alpha], beta)
direct_array[alpha + beta] = Indexed(elem, beta)
array = direct_array

# Constant folding
Expand Down
31 changes: 24 additions & 7 deletions gem/optimise.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from gem.utils import groupby
from gem.node import (Memoizer, MemoizerArg, reuse_if_untouched,
reuse_if_untouched_arg, traversal)
from gem.gem import (Node, Failure, Identity, Literal, Zero,
from gem.gem import (Node, Failure, Identity, Constant, Literal, Zero,
Product, Sum, Comparison, Conditional, Division,
Index, VariableIndex, Indexed, FlexiblyIndexed,
IndexSum, ComponentTensor, ListTensor, Delta,
Expand Down Expand Up @@ -119,6 +119,7 @@ def replace_indices_delta(node, self, subst):
def replace_indices_indexed(node, self, subst):
multiindex = tuple(_replace_indices_atomic(i, self, subst) for i in node.multiindex)
child, = node.children

if isinstance(child, ComponentTensor):
# Indexing into ComponentTensor
# Inline ComponentTensor and augment the substitution rules
Expand All @@ -127,11 +128,27 @@ def replace_indices_indexed(node, self, subst):
return self(child.children[0], tuple(sorted(substitute.items())))
else:
# Replace indices
new_child = self(child, subst)
if multiindex == node.multiindex and new_child == child:
child = self(child, subst)

# Remove fixed indices
if isinstance(child, (Constant, ListTensor)):
if all(isinstance(i, Integral) for i in multiindex):
# All indices fixed
sub = child.array[multiindex]
child = Literal(sub, dtype=child.dtype) if isinstance(child, Constant) else sub
multiindex = ()

elif any(isinstance(i, Integral) for i in multiindex):
# Some indices fixed
slices = tuple(i if isinstance(i, Integral) else slice(None) for i in multiindex)
sub = child.array[slices]
child = Literal(sub, dtype=child.dtype) if isinstance(child, Constant) else ListTensor(sub)
multiindex = tuple(i for i in multiindex if not isinstance(i, Integral))

if multiindex == node.multiindex and child == node.children[0]:
return node
else:
return Indexed(new_child, multiindex)
return Indexed(child, multiindex)


@replace_indices.register(FlexiblyIndexed)
Expand Down Expand Up @@ -177,7 +194,7 @@ def _constant_fold_zero(node, self):

@_constant_fold_zero.register(Literal)
def _constant_fold_zero_literal(node, self):
if (node.array == 0).all():
if numpy.array_equal(node.array, 0):
# All zeros, make symbolic zero
return Zero(node.shape)
else:
Expand Down Expand Up @@ -663,8 +680,8 @@ def _(node, self):
# Unrolling
summand = self(node.children[0])
shape = tuple(index.extent for index in unroll)
unrolled = Sum(*(Indexed(ComponentTensor(summand, unroll), alpha)
for alpha in numpy.ndindex(shape)))
tensor = ComponentTensor(summand, unroll)
unrolled = Sum(*(Indexed(tensor, alpha) for alpha in numpy.ndindex(shape)))
return IndexSum(unrolled, tuple(index for index in node.multiindex
if index not in unroll))
else:
Expand Down
86 changes: 86 additions & 0 deletions test/gem/test_simplify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import pytest
import gem
import numpy


@pytest.fixture
def A():
a = gem.Variable("a", ())
b = gem.Variable("b", ())
c = gem.Variable("c", ())
d = gem.Variable("d", ())
array = [[a, b], [c, d]]
A = gem.ListTensor(array)
return A


@pytest.fixture
def X():
return gem.Variable("X", (2, 2))


def test_listtensor_from_indexed(X):
k = gem.Index()
elems = [gem.Indexed(X, (k, *i)) for i in numpy.ndindex(X.shape[1:])]
tensor = gem.ListTensor(numpy.reshape(elems, X.shape[1:]))

assert isinstance(tensor, gem.ComponentTensor)
j = tensor.multiindex
expected = gem.partial_indexed(X, (k,))
expected = gem.ComponentTensor(gem.Indexed(expected, j), j)
assert tensor == expected


def test_listtensor_from_fixed_indexed(A):
elems = [gem.Indexed(A, i) for i in numpy.ndindex(A.shape)]
tensor = gem.ListTensor(numpy.reshape(elems, A.shape))
assert tensor == A


def test_listtensor_from_partial_indexed(A):
elems = [gem.partial_indexed(A, i) for i in numpy.ndindex(A.shape[:1])]
tensor = gem.ListTensor(elems)
assert tensor == A


def test_nested_partial_indexed(A):
i, j = gem.indices(2)
B = gem.partial_indexed(gem.partial_indexed(A, (i,)), (j,))
assert B == gem.Indexed(A, (i, j))


def test_componenttensor_from_indexed(A):
i, j = gem.indices(2)
Aij = gem.Indexed(A, (i, j))
assert A == gem.ComponentTensor(Aij, (i, j))


def test_indexed_transpose(A):
i, j = gem.indices(2)
ATij = gem.Indexed(A.T, (i, j))
Aji = gem.Indexed(A, (j, i))
assert ATij == Aji

i, = gem.indices(1)
j = 1
ATij = gem.Indexed(A.T, (i, j))
Aji = gem.Indexed(A, (j, i))
assert ATij == Aji

i, j = (0, 1)
ATij = gem.Indexed(A.T, (i, j))
Aji = gem.Indexed(A, (j, i))
assert ATij == Aji


def test_double_transpose(A):
assert A.T.T == A


def test_flatten_indexsum(A):
i, j = gem.indices(2)
Aij = gem.Indexed(A, (i, j))

result = gem.IndexSum(gem.IndexSum(Aij, (i,)), (j,))
expected = gem.IndexSum(Aij, (i, j))
assert result == expected