From de501c10b9adee96ec3468f62b972a16dde97085 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 8 Aug 2025 10:29:01 +0100 Subject: [PATCH 01/22] GEM: simplify indexed --- gem/gem.py | 46 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 39 insertions(+), 7 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 974556754..f468747bd 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -685,12 +685,31 @@ def __new__(cls, aggregate, multiindex): if isinstance(aggregate, Zero): return Zero(dtype=aggregate.dtype) - # All indices fixed - if all(isinstance(i, int) for i in multiindex): - if isinstance(aggregate, Constant): - return Literal(aggregate.array[multiindex], dtype=aggregate.dtype) - elif isinstance(aggregate, ListTensor): - return aggregate.array[multiindex] + # Simplify Literal and ListTensor + if isinstance(aggregate, (Constant, ListTensor)): + if all(isinstance(i, int) for i in multiindex): + # All indices fixed + sub = aggregate.array[multiindex] + return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub + elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex): + # Some indices fixed + slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex) + sub = aggregate.array[slices] + sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub) + return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int))) + + # Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll) + if isinstance(aggregate, ComponentTensor): + B, = aggregate.children + jj = aggregate.multiindex + if isinstance(B, Indexed): + C, = B.children + kk = B.multiindex + if all(j in kk for j in jj): + ii = tuple(multiindex) + rep = dict(zip(jj, ii)) + ll = tuple(rep.get(k, k) for k in kk) + return Indexed(C, ll) self = super(Indexed, cls).__new__(cls) self.children = (aggregate,) @@ -836,6 +855,11 @@ def __new__(cls, expression, multiindex): if isinstance(expression, Zero): return Zero(shape, dtype=expression.dtype) + # Index folding + if isinstance(expression, Indexed): + if multiindex == expression.multiindex: + return expression.children[0] + self = super(ComponentTensor, cls).__new__(cls) self.children = (expression,) self.multiindex = multiindex @@ -892,9 +916,17 @@ def __new__(cls, array): dtype = Node.inherit_dtype_from_children(tuple(array.flat)) # Handle children with shape - child_shape = array.flat[0].shape + e0 = array.flat[0] + child_shape = e0.shape assert all(elem.shape == child_shape for elem in array.flat) + # Index folding + if child_shape == array.shape: + if all(isinstance(elem, Indexed) for elem in array.flat): + if all(elem.children == e0.children for elem in array.flat[1:]): + if all(elem.multiindex == idx for elem, idx in zip(array.flat, numpy.ndindex(array.shape))): + return e0.children[0] + if child_shape: # Destroy structure direct_array = numpy.empty(array.shape + child_shape, dtype=object) From 592061ea4f4e3416b3729f658cd6b2380de494a4 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 11 Aug 2025 13:30:05 +0100 Subject: [PATCH 02/22] Fixes for more complicated expressions --- gem/gem.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index f468747bd..79bcd2e79 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -274,6 +274,7 @@ class Literal(Constant): def __new__(cls, array, dtype=None): array = asarray(array) + return super(Literal, cls).__new__(cls) def __init__(self, array, dtype=None): @@ -702,14 +703,29 @@ def __new__(cls, aggregate, multiindex): if isinstance(aggregate, ComponentTensor): B, = aggregate.children jj = aggregate.multiindex + ii = multiindex + # Avoid recursion and just attempt to simplify some common patterns + # as the result of this method is not cached. if isinstance(B, Indexed): C, = B.children kk = B.multiindex - if all(j in kk for j in jj): - ii = tuple(multiindex) + if isinstance(C, ListTensor): rep = dict(zip(jj, ii)) ll = tuple(rep.get(k, k) for k in kk) - return Indexed(C, ll) + B = Indexed(C, ll) + jj = tuple(j for j in jj if j not in kk) + ii = tuple(rep[j] for j in jj) + if len(ii) == 0: + return B + + if isinstance(B, Indexed): + C, = B.children + kk = B.multiindex + if not isinstance(C, ComponentTensor) or all(isinstance(i, Index) for i in ii): + if all(j in kk for j in jj): + rep = dict(zip(jj, ii)) + ll = tuple(rep.get(k, k) for k in kk) + return Indexed(C, ll) self = super(Indexed, cls).__new__(cls) self.children = (aggregate,) @@ -722,6 +738,7 @@ def __new__(cls, aggregate, multiindex): new_indices.append(i) elif isinstance(i, VariableIndex): new_indices.extend(i.expression.free_indices) + self.free_indices = unique(aggregate.free_indices + tuple(new_indices)) return self From 158b0ec918a2100f496d9b59bbcce5235cb24c24 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 11 Aug 2025 13:41:42 +0100 Subject: [PATCH 03/22] small change --- gem/gem.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 111f4bb3b..cec66af68 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -273,7 +273,6 @@ class Literal(Constant): def __new__(cls, array, dtype=None): array = asarray(array) - return super(Literal, cls).__new__(cls) def __init__(self, array, dtype=None): @@ -294,7 +293,7 @@ def is_equal(self, other): return False if self.shape != other.shape: return False - return tuple(self.array.flat) == tuple(other.array.flat) + return numpy.array_equal(self.array, other.array) def get_hash(self): return hash((type(self), self.shape, tuple(self.array.flat))) @@ -737,7 +736,6 @@ def __new__(cls, aggregate, multiindex): new_indices.append(i) elif isinstance(i, VariableIndex): new_indices.extend(i.expression.free_indices) - self.free_indices = unique(aggregate.free_indices + tuple(new_indices)) return self From d2a45843f703db6e8ceff49e43541812785d6b93 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 11 Aug 2025 17:50:56 +0100 Subject: [PATCH 04/22] More simplification --- gem/gem.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index cec66af68..9eb00f4ac 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -273,6 +273,9 @@ class Literal(Constant): def __new__(cls, array, dtype=None): array = asarray(array) + if numpy.allclose(array, 0, 1e-14): + return Zero(array.shape) + return super(Literal, cls).__new__(cls) def __init__(self, array, dtype=None): @@ -690,6 +693,7 @@ def __new__(cls, aggregate, multiindex): # All indices fixed sub = aggregate.array[multiindex] return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub + elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex): # Some indices fixed slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex) @@ -719,12 +723,20 @@ def __new__(cls, aggregate, multiindex): if isinstance(B, Indexed): C, = B.children kk = B.multiindex - if not isinstance(C, ComponentTensor) or all(isinstance(i, Index) for i in ii): - if all(j in kk for j in jj): - rep = dict(zip(jj, ii)) - ll = tuple(rep.get(k, k) for k in kk) + if all(j in kk for j in jj): + rep = dict(zip(jj, ii)) + ll = tuple(rep.get(k, k) for k in kk) + if isinstance(C, ComponentTensor): + if (all(isinstance(i, Index) for i in ii) + or all(isinstance(l, Integral) or (l in C.multiindex) for l in ll)): + return Indexed(C, ll) + else: return Indexed(C, ll) + if len(ii) < len(multiindex): + aggregate = ComponentTensor(B, jj) + multiindex = ii + self = super(Indexed, cls).__new__(cls) self.children = (aggregate,) self.multiindex = multiindex From b49eb5cfd8bec337efbb323277df7f74ca4cfff5 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 12 Aug 2025 12:19:35 +0100 Subject: [PATCH 05/22] Do not replace free indices --- gem/gem.py | 27 ++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 9eb00f4ac..2ee2fb87f 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -97,7 +97,7 @@ def __radd__(self, other): def __sub__(self, other): return componentwise( Sum, self, - componentwise(Product, Literal(-1), as_gem(other))) + componentwise(Product, minus, as_gem(other))) def __rsub__(self, other): return as_gem(other).__sub__(self) @@ -273,9 +273,6 @@ class Literal(Constant): def __new__(cls, array, dtype=None): array = asarray(array) - if numpy.allclose(array, 0, 1e-14): - return Zero(array.shape) - return super(Literal, cls).__new__(cls) def __init__(self, array, dtype=None): @@ -706,36 +703,27 @@ def __new__(cls, aggregate, multiindex): B, = aggregate.children jj = aggregate.multiindex ii = multiindex - # Avoid recursion and just attempt to simplify some common patterns - # as the result of this method is not cached. + if isinstance(B, Indexed): C, = B.children kk = B.multiindex - if isinstance(C, ListTensor): + if not isinstance(C, ComponentTensor): rep = dict(zip(jj, ii)) ll = tuple(rep.get(k, k) for k in kk) B = Indexed(C, ll) jj = tuple(j for j in jj if j not in kk) ii = tuple(rep[j] for j in jj) - if len(ii) == 0: + if not ii: return B if isinstance(B, Indexed): C, = B.children kk = B.multiindex - if all(j in kk for j in jj): + ff = C.free_indices + if all((j in kk) and (j not in ff) for j in jj): rep = dict(zip(jj, ii)) ll = tuple(rep.get(k, k) for k in kk) - if isinstance(C, ComponentTensor): - if (all(isinstance(i, Index) for i in ii) - or all(isinstance(l, Integral) or (l in C.multiindex) for l in ll)): - return Indexed(C, ll) - else: - return Indexed(C, ll) - - if len(ii) < len(multiindex): - aggregate = ComponentTensor(B, jj) - multiindex = ii + return Indexed(C, ll) self = super(Indexed, cls).__new__(cls) self.children = (aggregate,) @@ -1277,6 +1265,7 @@ def view(expression, *slices): # Static one object for quicker constant folding one = Literal(1) +minus = Literal(-1) # Syntax sugar From 340ec400181d3fd1ae8469961d940488ee456ef0 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 12 Aug 2025 22:22:28 +0100 Subject: [PATCH 06/22] Simplify IndexSum --- gem/gem.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 2ee2fb87f..23e91ceec 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -896,6 +896,25 @@ def __new__(cls, summand, multiindex): if isinstance(summand, Zero): return summand + # No indices case + multiindex = tuple(multiindex) + if not multiindex: + return summand + + # Flatten nested sums + if isinstance(summand, IndexSum): + A, = summand.children + return IndexSum(A, summand.multiindex + multiindex) + + # Factor out common factors + if isinstance(summand, Product): + a, b = summand.children + if all(i not in a.free_indices for i in multiindex): + return Product(a, IndexSum(b, multiindex)) + + if all(i not in b.free_indices for i in multiindex): + return Product(IndexSum(a, multiindex), b) + # Unroll singleton sums unroll = tuple(index for index in multiindex if index.extent <= 1) if unroll: @@ -905,11 +924,6 @@ def __new__(cls, summand, multiindex): multiindex = tuple(index for index in multiindex if index not in unroll) - # No indices case - multiindex = tuple(multiindex) - if not multiindex: - return summand - self = super(IndexSum, cls).__new__(cls) self.children = (summand,) self.multiindex = multiindex From a7fda022d2d0470cefcd66cd65d209cc5297853b Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 14 Aug 2025 10:43:05 +0100 Subject: [PATCH 07/22] Refactor IndexSum unrolling --- gem/optimise.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gem/optimise.py b/gem/optimise.py index 3c3c9bed7..69fbb8ce6 100644 --- a/gem/optimise.py +++ b/gem/optimise.py @@ -657,8 +657,8 @@ def _(node, self): # Unrolling summand = self(node.children[0]) shape = tuple(index.extent for index in unroll) - unrolled = Sum(*(Indexed(ComponentTensor(summand, unroll), alpha) - for alpha in numpy.ndindex(shape))) + tensor = ComponentTensor(summand, unroll) + unrolled = Sum(*(Indexed(tensor, alpha) for alpha in numpy.ndindex(shape))) return IndexSum(unrolled, tuple(index for index in node.multiindex if index not in unroll)) else: From eab0d90f9947ef1bace5ebc6e2906e6b2707f0d3 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 14 Aug 2025 10:46:08 +0100 Subject: [PATCH 08/22] Flatten nested ComponentTensors --- gem/gem.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gem/gem.py b/gem/gem.py index 23e91ceec..a2c8fe7f5 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -874,6 +874,11 @@ def __new__(cls, expression, multiindex): if multiindex == expression.multiindex: return expression.children[0] + # Flatten nested ComponentTensors + if isinstance(expression, ComponentTensor): + A, = expression.children + return ComponentTensor(A, expression.multiindex + multiindex) + self = super(ComponentTensor, cls).__new__(cls) self.children = (expression,) self.multiindex = multiindex From 41981ec30234a8d05cb296931a01553caf54042b Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 14 Aug 2025 16:32:01 +0100 Subject: [PATCH 09/22] use numpy.array_equal --- gem/optimise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gem/optimise.py b/gem/optimise.py index 69fbb8ce6..1aa2016b8 100644 --- a/gem/optimise.py +++ b/gem/optimise.py @@ -177,7 +177,7 @@ def _constant_fold_zero(node, self): @_constant_fold_zero.register(Literal) def _constant_fold_zero_literal(node, self): - if (node.array == 0).all(): + if numpy.array_equal(node.array, 0): # All zeros, make symbolic zero return Zero(node.shape) else: From 7f580284a69e8cf32e2de3856c3415ad5059a5f7 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 14 Aug 2025 17:36:52 +0100 Subject: [PATCH 10/22] Simplify ListTensor(ComponentTensor(Indexed(...))) --- gem/gem.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index a2c8fe7f5..1eb591a4e 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -691,13 +691,6 @@ def __new__(cls, aggregate, multiindex): sub = aggregate.array[multiindex] return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub - elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex): - # Some indices fixed - slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex) - sub = aggregate.array[slices] - sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub) - return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int))) - # Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll) if isinstance(aggregate, ComponentTensor): B, = aggregate.children @@ -953,12 +946,23 @@ def __new__(cls, array): child_shape = e0.shape assert all(elem.shape == child_shape for elem in array.flat) - # Index folding - if child_shape == array.shape: - if all(isinstance(elem, Indexed) for elem in array.flat): - if all(elem.children == e0.children for elem in array.flat[1:]): + # Simplify [v[j] for j in range(n)] -> v + if all(isinstance(elem, Indexed) for elem in array.flat): + tensor = e0.children[0] + if array.shape + child_shape == tensor.shape: + if all(elem.children[0] == tensor for elem in array.flat[1:]): if all(elem.multiindex == idx for elem, idx in zip(array.flat, numpy.ndindex(array.shape))): - return e0.children[0] + return tensor + + # Simplify [v[j, :] for j in range(n)] -> v + if all(isinstance(elem, ComponentTensor) and isinstance(elem.children[0], Indexed) + for elem in array.flat): + tensor = e0.children[0].children[0] + if array.shape + child_shape == tensor.shape: + if all(elem.children[0].children[0] == tensor for elem in array.flat[1:]): + if all(elem.children[0].multiindex == idx + elem.multiindex + for idx, elem in zip(numpy.ndindex(array.shape), array.flat)): + return tensor if child_shape: # Destroy structure From 7077cf9514cb971e67e27cf2c299ee1b04d47534 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 14 Aug 2025 17:56:27 +0100 Subject: [PATCH 11/22] Some indices fixed --- gem/gem.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/gem/gem.py b/gem/gem.py index 1eb591a4e..2963161ea 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -691,6 +691,13 @@ def __new__(cls, aggregate, multiindex): sub = aggregate.array[multiindex] return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub + elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex): + # Some indices fixed + slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex) + sub = aggregate.array[slices] + sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub) + return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int))) + # Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll) if isinstance(aggregate, ComponentTensor): B, = aggregate.children From 55bb3136bc70305e9d4a022e465c00c57c362a05 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 14 Aug 2025 23:39:22 +0100 Subject: [PATCH 12/22] style --- gem/gem.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 2963161ea..1b94f09cc 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -901,6 +901,15 @@ def __new__(cls, summand, multiindex): if isinstance(summand, Zero): return summand + # Unroll singleton sums + unroll = tuple(index for index in multiindex if index.extent <= 1) + if unroll: + assert numpy.prod([index.extent for index in unroll]) == 1 + summand = Indexed(ComponentTensor(summand, unroll), + (0,) * len(unroll)) + multiindex = tuple(index for index in multiindex + if index not in unroll) + # No indices case multiindex = tuple(multiindex) if not multiindex: @@ -920,15 +929,6 @@ def __new__(cls, summand, multiindex): if all(i not in b.free_indices for i in multiindex): return Product(IndexSum(a, multiindex), b) - # Unroll singleton sums - unroll = tuple(index for index in multiindex if index.extent <= 1) - if unroll: - assert numpy.prod([index.extent for index in unroll]) == 1 - summand = Indexed(ComponentTensor(summand, unroll), - (0,) * len(unroll)) - multiindex = tuple(index for index in multiindex - if index not in unroll) - self = super(IndexSum, cls).__new__(cls) self.children = (summand,) self.multiindex = multiindex @@ -958,7 +958,7 @@ def __new__(cls, array): tensor = e0.children[0] if array.shape + child_shape == tensor.shape: if all(elem.children[0] == tensor for elem in array.flat[1:]): - if all(elem.multiindex == idx for elem, idx in zip(array.flat, numpy.ndindex(array.shape))): + if all(elem.multiindex == idx for idx, elem in numpy.ndenumerate(array)): return tensor # Simplify [v[j, :] for j in range(n)] -> v @@ -968,15 +968,19 @@ def __new__(cls, array): if array.shape + child_shape == tensor.shape: if all(elem.children[0].children[0] == tensor for elem in array.flat[1:]): if all(elem.children[0].multiindex == idx + elem.multiindex - for idx, elem in zip(numpy.ndindex(array.shape), array.flat)): + for idx, elem in numpy.ndenumerate(array)): return tensor + # Flatten nested ListTensors + if all(isinstance(elem, ListTensor) for elem in array.flat): + return ListTensor(asarray([elem.array for elem in array.flat]).reshape(array.shape + child_shape)) + if child_shape: # Destroy structure direct_array = numpy.empty(array.shape + child_shape, dtype=object) - for alpha in numpy.ndindex(array.shape): + for alpha, elem in numpy.ndenumerate(array): for beta in numpy.ndindex(child_shape): - direct_array[alpha + beta] = Indexed(array[alpha], beta) + direct_array[alpha + beta] = Indexed(elem, beta) array = direct_array # Constant folding From 10a8604ea94b92d6eb87b502279dc14f978f2ae2 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 26 Aug 2025 17:52:12 +0100 Subject: [PATCH 13/22] Add tests --- gem/gem.py | 17 +------------- test/gem/test_simplify.py | 47 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 16 deletions(-) create mode 100644 test/gem/test_simplify.py diff --git a/gem/gem.py b/gem/gem.py index 1b94f09cc..2638effe4 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -97,7 +97,7 @@ def __radd__(self, other): def __sub__(self, other): return componentwise( Sum, self, - componentwise(Product, minus, as_gem(other))) + componentwise(Product, Literal(-1), as_gem(other))) def __rsub__(self, other): return as_gem(other).__sub__(self) @@ -874,11 +874,6 @@ def __new__(cls, expression, multiindex): if multiindex == expression.multiindex: return expression.children[0] - # Flatten nested ComponentTensors - if isinstance(expression, ComponentTensor): - A, = expression.children - return ComponentTensor(A, expression.multiindex + multiindex) - self = super(ComponentTensor, cls).__new__(cls) self.children = (expression,) self.multiindex = multiindex @@ -920,15 +915,6 @@ def __new__(cls, summand, multiindex): A, = summand.children return IndexSum(A, summand.multiindex + multiindex) - # Factor out common factors - if isinstance(summand, Product): - a, b = summand.children - if all(i not in a.free_indices for i in multiindex): - return Product(a, IndexSum(b, multiindex)) - - if all(i not in b.free_indices for i in multiindex): - return Product(IndexSum(a, multiindex), b) - self = super(IndexSum, cls).__new__(cls) self.children = (summand,) self.multiindex = multiindex @@ -1299,7 +1285,6 @@ def view(expression, *slices): # Static one object for quicker constant folding one = Literal(1) -minus = Literal(-1) # Syntax sugar diff --git a/test/gem/test_simplify.py b/test/gem/test_simplify.py new file mode 100644 index 000000000..0658a7f56 --- /dev/null +++ b/test/gem/test_simplify.py @@ -0,0 +1,47 @@ +import pytest +import gem +import numpy + + +@pytest.fixture +def A(): + a = gem.Variable("a", ()) + b = gem.Variable("b", ()) + c = gem.Variable("c", ()) + d = gem.Variable("d", ()) + array = [[a, b], [c, d]] + A = gem.ListTensor(array) + return A + + +def test_listtensor_from_indexed(A): + elems = [gem.Indexed(A, i) for i in numpy.ndindex(A.shape)] + tensor = gem.ListTensor(numpy.reshape(elems, A.shape)) + assert tensor == A + + +def test_listtensor_from_partial_indexed(A): + elems = [gem.partial_indexed(A, i) for i in numpy.ndindex(A.shape[:1])] + tensor = gem.ListTensor(elems) + assert tensor == A + + +def test_nested_partial_indexed(A): + i, j = gem.indices(2) + B = gem.partial_indexed(gem.partial_indexed(A, (i,)), (j,)) + assert B == gem.Indexed(A, (i, j)) + + +def test_componenttensor_from_indexed(A): + i, j = gem.indices(2) + Aij = gem.Indexed(A, (i, j)) + assert A == gem.ComponentTensor(Aij, (i, j)) + + +def test_flatten_indexsum(A): + i, j = gem.indices(2) + Aij = gem.Indexed(A, (i, j)) + + result = gem.IndexSum(gem.IndexSum(Aij, (i,)), (j,)) + expected = gem.IndexSum(Aij, (i, j)) + assert result == expected From b738ec64f0eb8e7ac9197fd255b27b4bc25722c6 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 27 Aug 2025 11:34:03 +0100 Subject: [PATCH 14/22] More simplify --- gem/gem.py | 12 +++++++----- test/gem/test_simplify.py | 19 ++++++++++++++++++- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 2638effe4..c10198a7e 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -939,15 +939,17 @@ def __new__(cls, array): child_shape = e0.shape assert all(elem.shape == child_shape for elem in array.flat) - # Simplify [v[j] for j in range(n)] -> v + # Simplify [v[multiindex, j] for j in range(n)] -> partial_indexed(v, multiindex) if all(isinstance(elem, Indexed) for elem in array.flat): tensor = e0.children[0] - if array.shape + child_shape == tensor.shape: + multiindex = tuple(i for i in e0.multiindex if not isinstance(i, Integral)) + index_shape = tuple(i.extent for i in multiindex) + if index_shape + array.shape + child_shape == tensor.shape: if all(elem.children[0] == tensor for elem in array.flat[1:]): - if all(elem.multiindex == idx for idx, elem in numpy.ndenumerate(array)): - return tensor + if all(elem.multiindex == multiindex + idx for idx, elem in numpy.ndenumerate(array)): + return partial_indexed(tensor, multiindex) - # Simplify [v[j, :] for j in range(n)] -> v + # Simplify [v[j, ...] for j in range(n)] -> v if all(isinstance(elem, ComponentTensor) and isinstance(elem.children[0], Indexed) for elem in array.flat): tensor = e0.children[0].children[0] diff --git a/test/gem/test_simplify.py b/test/gem/test_simplify.py index 0658a7f56..9d7b52a49 100644 --- a/test/gem/test_simplify.py +++ b/test/gem/test_simplify.py @@ -14,7 +14,24 @@ def A(): return A -def test_listtensor_from_indexed(A): +@pytest.fixture +def X(): + return gem.Variable("X", (2, 2)) + + +def test_listtensor_from_indexed(X): + k = gem.Index() + elems = [gem.Indexed(X, (k, *i)) for i in numpy.ndindex(X.shape[1:])] + tensor = gem.ListTensor(numpy.reshape(elems, X.shape[1:])) + + assert isinstance(tensor, gem.ComponentTensor) + j = tensor.multiindex + expected = gem.partial_indexed(X, (k,)) + expected = gem.ComponentTensor(gem.Indexed(expected, j), j) + assert tensor == expected + + +def test_listtensor_from_fixed_indexed(A): elems = [gem.Indexed(A, i) for i in numpy.ndindex(A.shape)] tensor = gem.ListTensor(numpy.reshape(elems, A.shape)) assert tensor == A From 5dfd17de4df9e3685b094877b3f31dee9563a2c0 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 29 Aug 2025 14:10:30 +0100 Subject: [PATCH 15/22] Fix up --- gem/gem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gem/gem.py b/gem/gem.py index 463a9111b..8601c1c20 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -943,7 +943,7 @@ def __new__(cls, array): if all(isinstance(elem, Indexed) for elem in array.flat): tensor = e0.children[0] multiindex = tuple(i for i in e0.multiindex if not isinstance(i, Integral)) - index_shape = tuple(i.extent for i in multiindex) + index_shape = tuple(i.extent for i in multiindex if isinstance(i, Index)) if index_shape + array.shape + child_shape == tensor.shape: if all(elem.children[0] == tensor for elem in array.flat[1:]): if all(elem.multiindex == multiindex + idx for idx, elem in numpy.ndenumerate(array)): From 8d0dc4198c8f191b21d9562bcba50899b10081f1 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 2 Feb 2026 12:20:56 +0000 Subject: [PATCH 16/22] expand_fixedindices DAG traverser --- gem/gem.py | 16 ---------------- gem/optimise.py | 44 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 43 insertions(+), 17 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 8601c1c20..9d68b6993 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -691,13 +691,6 @@ def __new__(cls, aggregate, multiindex): sub = aggregate.array[multiindex] return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub - elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex): - # Some indices fixed - slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex) - sub = aggregate.array[slices] - sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub) - return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int))) - # Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll) if isinstance(aggregate, ComponentTensor): B, = aggregate.children @@ -716,15 +709,6 @@ def __new__(cls, aggregate, multiindex): if not ii: return B - if isinstance(B, Indexed): - C, = B.children - kk = B.multiindex - ff = C.free_indices - if all((j in kk) and (j not in ff) for j in jj): - rep = dict(zip(jj, ii)) - ll = tuple(rep.get(k, k) for k in kk) - return Indexed(C, ll) - self = super(Indexed, cls).__new__(cls) self.children = (aggregate,) self.multiindex = multiindex diff --git a/gem/optimise.py b/gem/optimise.py index ad91e1f5e..c9d74782b 100644 --- a/gem/optimise.py +++ b/gem/optimise.py @@ -11,7 +11,7 @@ from gem.utils import groupby from gem.node import (Memoizer, MemoizerArg, reuse_if_untouched, reuse_if_untouched_arg, traversal) -from gem.gem import (Node, Failure, Identity, Literal, Zero, +from gem.gem import (Node, Failure, Identity, Constant, Literal, Zero, Product, Sum, Comparison, Conditional, Division, Index, VariableIndex, Indexed, FlexiblyIndexed, IndexSum, ComponentTensor, ListTensor, Delta, @@ -80,6 +80,47 @@ def replace_division(expressions): return list(map(mapper, expressions)) +@singledispatch +def _expand_fixedindices(node, self): + """Simplify Indexed(ListTensor(A), FixedIndex) -> A[FixedIndex] + + :param node: root of expression + :param self: function for recursive calls + """ + raise AssertionError("cannot handle type %s" % type(node)) + + +_expand_fixedindices.register(Node)(reuse_if_untouched) + + +@_expand_fixedindices.register(Indexed) +def expand_fixedindices_indexed(node, self): + aggregate, = node.children + multiindex = node.multiindex + + # Simplify Literal and ListTensor + if isinstance(aggregate, (Constant, ListTensor)): + if all(isinstance(i, int) for i in multiindex): + # All indices fixed + sub = aggregate.array[multiindex] + return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub + + elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex): + # Some indices fixed + slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex) + sub = aggregate.array[slices] + sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub) + return self(Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int)))) + + return reuse_if_untouched(node, self) + + +def expand_fixedindices(expressions): + """Expands indices in multi-root expression DAG.""" + mapper = Memoizer(_expand_fixedindices) + return list(map(mapper, expressions)) + + @singledispatch def replace_indices(node, self, subst): """Replace free indices in a GEM expression. @@ -163,6 +204,7 @@ def filtered_replace_indices(node, self, subst): def remove_componenttensors(expressions): """Removes all ComponentTensors in multi-root expression DAG.""" + expressions = expand_fixedindices(expressions) mapper = MemoizerArg(filtered_replace_indices) return [mapper(expression, ()) for expression in expressions] From 40a7dc78e2169a964d196dcb49b3aba7c8c2dd55 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 2 Feb 2026 12:58:08 +0000 Subject: [PATCH 17/22] fix --- gem/gem.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 9d68b6993..8a11e60a5 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -684,12 +684,12 @@ def __new__(cls, aggregate, multiindex): if isinstance(aggregate, Zero): return Zero(dtype=aggregate.dtype) - # Simplify Literal and ListTensor - if isinstance(aggregate, (Constant, ListTensor)): - if all(isinstance(i, int) for i in multiindex): - # All indices fixed - sub = aggregate.array[multiindex] - return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub + # All indices fixed + if all(isinstance(i, int) for i in multiindex): + if isinstance(aggregate, Constant): + return Literal(aggregate.array[multiindex], dtype=aggregate.dtype) + elif isinstance(aggregate, ListTensor): + return aggregate.array[multiindex] # Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll) if isinstance(aggregate, ComponentTensor): From 2fea3ff66de651f299286c8cabc2e8ce970800ce Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 2 Feb 2026 13:19:30 +0000 Subject: [PATCH 18/22] restore Indexed.__new__ --- gem/gem.py | 18 ------------------ gem/optimise.py | 2 +- 2 files changed, 1 insertion(+), 19 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 8a11e60a5..906984437 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -691,24 +691,6 @@ def __new__(cls, aggregate, multiindex): elif isinstance(aggregate, ListTensor): return aggregate.array[multiindex] - # Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll) - if isinstance(aggregate, ComponentTensor): - B, = aggregate.children - jj = aggregate.multiindex - ii = multiindex - - if isinstance(B, Indexed): - C, = B.children - kk = B.multiindex - if not isinstance(C, ComponentTensor): - rep = dict(zip(jj, ii)) - ll = tuple(rep.get(k, k) for k in kk) - B = Indexed(C, ll) - jj = tuple(j for j in jj if j not in kk) - ii = tuple(rep[j] for j in jj) - if not ii: - return B - self = super(Indexed, cls).__new__(cls) self.children = (aggregate,) self.multiindex = multiindex diff --git a/gem/optimise.py b/gem/optimise.py index c9d74782b..721ebdce9 100644 --- a/gem/optimise.py +++ b/gem/optimise.py @@ -110,7 +110,7 @@ def expand_fixedindices_indexed(node, self): slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex) sub = aggregate.array[slices] sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub) - return self(Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int)))) + return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int))) return reuse_if_untouched(node, self) From ca7f2f098216e2f113a350505cd38c988e7b4df5 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 2 Feb 2026 16:34:04 +0000 Subject: [PATCH 19/22] Remove fixed indices within remove_componenttensors --- gem/optimise.py | 58 ++++++++++++++----------------------------------- 1 file changed, 16 insertions(+), 42 deletions(-) diff --git a/gem/optimise.py b/gem/optimise.py index 721ebdce9..bcc162fc1 100644 --- a/gem/optimise.py +++ b/gem/optimise.py @@ -80,47 +80,6 @@ def replace_division(expressions): return list(map(mapper, expressions)) -@singledispatch -def _expand_fixedindices(node, self): - """Simplify Indexed(ListTensor(A), FixedIndex) -> A[FixedIndex] - - :param node: root of expression - :param self: function for recursive calls - """ - raise AssertionError("cannot handle type %s" % type(node)) - - -_expand_fixedindices.register(Node)(reuse_if_untouched) - - -@_expand_fixedindices.register(Indexed) -def expand_fixedindices_indexed(node, self): - aggregate, = node.children - multiindex = node.multiindex - - # Simplify Literal and ListTensor - if isinstance(aggregate, (Constant, ListTensor)): - if all(isinstance(i, int) for i in multiindex): - # All indices fixed - sub = aggregate.array[multiindex] - return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub - - elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex): - # Some indices fixed - slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex) - sub = aggregate.array[slices] - sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub) - return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int))) - - return reuse_if_untouched(node, self) - - -def expand_fixedindices(expressions): - """Expands indices in multi-root expression DAG.""" - mapper = Memoizer(_expand_fixedindices) - return list(map(mapper, expressions)) - - @singledispatch def replace_indices(node, self, subst): """Replace free indices in a GEM expression. @@ -160,6 +119,22 @@ def replace_indices_delta(node, self, subst): def replace_indices_indexed(node, self, subst): multiindex = tuple(_replace_indices_atomic(i, self, subst) for i in node.multiindex) child, = node.children + + # Remove fixed indices + if isinstance(child, (Constant, ListTensor)): + if all(isinstance(i, int) for i in multiindex): + # All indices fixed + sub = child.array[multiindex] + child = Literal(sub, dtype=child.dtype) if isinstance(child, Constant) else sub + multiindex = () + + elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex): + # Some indices fixed + slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex) + sub = child.array[slices] + child = Literal(sub, dtype=child.dtype) if isinstance(child, Constant) else ListTensor(sub) + multiindex = tuple(i for i in multiindex if not isinstance(i, int)) + if isinstance(child, ComponentTensor): # Indexing into ComponentTensor # Inline ComponentTensor and augment the substitution rules @@ -204,7 +179,6 @@ def filtered_replace_indices(node, self, subst): def remove_componenttensors(expressions): """Removes all ComponentTensors in multi-root expression DAG.""" - expressions = expand_fixedindices(expressions) mapper = MemoizerArg(filtered_replace_indices) return [mapper(expression, ()) for expression in expressions] From ff2c66472e2567638acd49a2c3da884fc681d197 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 2 Feb 2026 22:17:05 +0000 Subject: [PATCH 20/22] fix --- gem/optimise.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/gem/optimise.py b/gem/optimise.py index bcc162fc1..129f0afff 100644 --- a/gem/optimise.py +++ b/gem/optimise.py @@ -120,21 +120,6 @@ def replace_indices_indexed(node, self, subst): multiindex = tuple(_replace_indices_atomic(i, self, subst) for i in node.multiindex) child, = node.children - # Remove fixed indices - if isinstance(child, (Constant, ListTensor)): - if all(isinstance(i, int) for i in multiindex): - # All indices fixed - sub = child.array[multiindex] - child = Literal(sub, dtype=child.dtype) if isinstance(child, Constant) else sub - multiindex = () - - elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex): - # Some indices fixed - slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex) - sub = child.array[slices] - child = Literal(sub, dtype=child.dtype) if isinstance(child, Constant) else ListTensor(sub) - multiindex = tuple(i for i in multiindex if not isinstance(i, int)) - if isinstance(child, ComponentTensor): # Indexing into ComponentTensor # Inline ComponentTensor and augment the substitution rules @@ -143,11 +128,27 @@ def replace_indices_indexed(node, self, subst): return self(child.children[0], tuple(sorted(substitute.items()))) else: # Replace indices - new_child = self(child, subst) - if multiindex == node.multiindex and new_child == child: + child = self(child, subst) + + # Remove fixed indices + if isinstance(child, (Constant, ListTensor)): + if all(isinstance(i, Integral) for i in multiindex): + # All indices fixed + sub = child.array[multiindex] + child = Literal(sub, dtype=child.dtype) if isinstance(child, Constant) else sub + multiindex = tuple() + + elif any(isinstance(i, Integral) for i in multiindex): + # Some indices fixed + slices = tuple(i if isinstance(i, Integral) else slice(None) for i in multiindex) + sub = child.array[slices] + child = Literal(sub, dtype=child.dtype) if isinstance(child, Constant) else ListTensor(sub) + multiindex = tuple(i for i in multiindex if not isinstance(i, Integral)) + + if multiindex == node.multiindex and child == node.children[0]: return node else: - return Indexed(new_child, multiindex) + return Indexed(child, multiindex) @replace_indices.register(FlexiblyIndexed) From aab3a943161f9b1fd4fd62fac8a9fe7312cafe18 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 3 Feb 2026 10:59:07 +0000 Subject: [PATCH 21/22] fixes --- gem/gem.py | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 906984437..19bd7a44a 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -684,6 +684,22 @@ def __new__(cls, aggregate, multiindex): if isinstance(aggregate, Zero): return Zero(dtype=aggregate.dtype) + # Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll) + if isinstance(aggregate, ComponentTensor): + B, = aggregate.children + jj = aggregate.multiindex + ii = multiindex + if isinstance(B, Indexed): + C, = B.children + kk = B.multiindex + if not isinstance(C, ComponentTensor): + rep = dict(zip(jj, ii)) + ll = tuple(rep.get(k, k) for k in kk) + jj = tuple(j for j in jj if j not in kk) + ii = tuple(rep[j] for j in jj) + if not ii: + return Indexed(C, ll) + # All indices fixed if all(isinstance(i, int) for i in multiindex): if isinstance(aggregate, Constant): @@ -905,17 +921,22 @@ def __new__(cls, array): child_shape = e0.shape assert all(elem.shape == child_shape for elem in array.flat) - # Simplify [v[multiindex, j] for j in range(n)] -> partial_indexed(v, multiindex) + # Simplify [tensor[multiindex, j] for j in range(n)] -> partial_indexed(tensor, multiindex) if all(isinstance(elem, Indexed) for elem in array.flat): tensor = e0.children[0] - multiindex = tuple(i for i in e0.multiindex if not isinstance(i, Integral)) - index_shape = tuple(i.extent for i in multiindex if isinstance(i, Index)) - if index_shape + array.shape + child_shape == tensor.shape: - if all(elem.children[0] == tensor for elem in array.flat[1:]): - if all(elem.multiindex == multiindex + idx for idx, elem in numpy.ndenumerate(array)): + if all(elem.children[0] == tensor for elem in array.flat[1:]): + multiindex = tuple(e0.multiindex) + for elem in array.flat[1:]: + while elem.multiindex[:len(multiindex)] != multiindex: + multiindex = multiindex[:-1] + if len(multiindex) == 0: + break + index_shape = tuple(i.extent if isinstance(i, Index) else 1 for i in multiindex) + if index_shape + array.shape + child_shape == tensor.shape: + if all(elem.multiindex[len(multiindex):] == idx for idx, elem in numpy.ndenumerate(array)): return partial_indexed(tensor, multiindex) - # Simplify [v[j, ...] for j in range(n)] -> v + # Simplify [tensor[j, ...] for j in range(n)] -> tensor if all(isinstance(elem, ComponentTensor) and isinstance(elem.children[0], Indexed) for elem in array.flat): tensor = e0.children[0].children[0] From 04f0cf3fc86d76762fc3dce18cfcfdc63bd20120 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 4 Feb 2026 13:31:16 +0000 Subject: [PATCH 22/22] review comments --- gem/gem.py | 15 +++++++++------ gem/optimise.py | 2 +- test/gem/test_simplify.py | 22 ++++++++++++++++++++++ 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 19bd7a44a..02aeaaf55 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -685,6 +685,8 @@ def __new__(cls, aggregate, multiindex): return Zero(dtype=aggregate.dtype) # Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll) + # This pattern corresponds to an index replacement rule jj -> ii applied to + # the innermost multiindex kk to produce ll. if isinstance(aggregate, ComponentTensor): B, = aggregate.children jj = aggregate.multiindex @@ -692,16 +694,16 @@ def __new__(cls, aggregate, multiindex): if isinstance(B, Indexed): C, = B.children kk = B.multiindex - if not isinstance(C, ComponentTensor): + ff = C.free_indices + if not any((j in ff) for j in jj): + # Only replace indices that are not present in C rep = dict(zip(jj, ii)) ll = tuple(rep.get(k, k) for k in kk) - jj = tuple(j for j in jj if j not in kk) - ii = tuple(rep[j] for j in jj) - if not ii: - return Indexed(C, ll) + aggregate = C + multiindex = ll # All indices fixed - if all(isinstance(i, int) for i in multiindex): + if all(isinstance(i, Integral) for i in multiindex): if isinstance(aggregate, Constant): return Literal(aggregate.array[multiindex], dtype=aggregate.dtype) elif isinstance(aggregate, ListTensor): @@ -925,6 +927,7 @@ def __new__(cls, array): if all(isinstance(elem, Indexed) for elem in array.flat): tensor = e0.children[0] if all(elem.children[0] == tensor for elem in array.flat[1:]): + # Extract maximal subset of leading indices that is common over all array entries multiindex = tuple(e0.multiindex) for elem in array.flat[1:]: while elem.multiindex[:len(multiindex)] != multiindex: diff --git a/gem/optimise.py b/gem/optimise.py index 129f0afff..9809cc2aa 100644 --- a/gem/optimise.py +++ b/gem/optimise.py @@ -136,7 +136,7 @@ def replace_indices_indexed(node, self, subst): # All indices fixed sub = child.array[multiindex] child = Literal(sub, dtype=child.dtype) if isinstance(child, Constant) else sub - multiindex = tuple() + multiindex = () elif any(isinstance(i, Integral) for i in multiindex): # Some indices fixed diff --git a/test/gem/test_simplify.py b/test/gem/test_simplify.py index 9d7b52a49..fa2242546 100644 --- a/test/gem/test_simplify.py +++ b/test/gem/test_simplify.py @@ -55,6 +55,28 @@ def test_componenttensor_from_indexed(A): assert A == gem.ComponentTensor(Aij, (i, j)) +def test_indexed_transpose(A): + i, j = gem.indices(2) + ATij = gem.Indexed(A.T, (i, j)) + Aji = gem.Indexed(A, (j, i)) + assert ATij == Aji + + i, = gem.indices(1) + j = 1 + ATij = gem.Indexed(A.T, (i, j)) + Aji = gem.Indexed(A, (j, i)) + assert ATij == Aji + + i, j = (0, 1) + ATij = gem.Indexed(A.T, (i, j)) + Aji = gem.Indexed(A, (j, i)) + assert ATij == Aji + + +def test_double_transpose(A): + assert A.T.T == A + + def test_flatten_indexsum(A): i, j = gem.indices(2) Aij = gem.Indexed(A, (i, j))