diff --git a/gem/gem.py b/gem/gem.py index b31fd950..02aeaaf5 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -293,7 +293,7 @@ def is_equal(self, other): return False if self.shape != other.shape: return False - return tuple(self.array.flat) == tuple(other.array.flat) + return numpy.array_equal(self.array, other.array) def get_hash(self): return hash((type(self), self.shape, tuple(self.array.flat))) @@ -684,8 +684,26 @@ def __new__(cls, aggregate, multiindex): if isinstance(aggregate, Zero): return Zero(dtype=aggregate.dtype) + # Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll) + # This pattern corresponds to an index replacement rule jj -> ii applied to + # the innermost multiindex kk to produce ll. + if isinstance(aggregate, ComponentTensor): + B, = aggregate.children + jj = aggregate.multiindex + ii = multiindex + if isinstance(B, Indexed): + C, = B.children + kk = B.multiindex + ff = C.free_indices + if not any((j in ff) for j in jj): + # Only replace indices that are not present in C + rep = dict(zip(jj, ii)) + ll = tuple(rep.get(k, k) for k in kk) + aggregate = C + multiindex = ll + # All indices fixed - if all(isinstance(i, int) for i in multiindex): + if all(isinstance(i, Integral) for i in multiindex): if isinstance(aggregate, Constant): return Literal(aggregate.array[multiindex], dtype=aggregate.dtype) elif isinstance(aggregate, ListTensor): @@ -835,6 +853,11 @@ def __new__(cls, expression, multiindex): if isinstance(expression, Zero): return Zero(shape, dtype=expression.dtype) + # Index folding + if isinstance(expression, Indexed): + if multiindex == expression.multiindex: + return expression.children[0] + self = super(ComponentTensor, cls).__new__(cls) self.children = (expression,) self.multiindex = multiindex @@ -871,6 +894,11 @@ def __new__(cls, summand, multiindex): if not multiindex: return summand + # Flatten nested sums + if isinstance(summand, IndexSum): + A, = summand.children + return IndexSum(A, summand.multiindex + multiindex) + self = super(IndexSum, cls).__new__(cls) self.children = (summand,) self.multiindex = multiindex @@ -891,15 +919,46 @@ def __new__(cls, array): dtype = Node.inherit_dtype_from_children(tuple(array.flat)) # Handle children with shape - child_shape = array.flat[0].shape + e0 = array.flat[0] + child_shape = e0.shape assert all(elem.shape == child_shape for elem in array.flat) + # Simplify [tensor[multiindex, j] for j in range(n)] -> partial_indexed(tensor, multiindex) + if all(isinstance(elem, Indexed) for elem in array.flat): + tensor = e0.children[0] + if all(elem.children[0] == tensor for elem in array.flat[1:]): + # Extract maximal subset of leading indices that is common over all array entries + multiindex = tuple(e0.multiindex) + for elem in array.flat[1:]: + while elem.multiindex[:len(multiindex)] != multiindex: + multiindex = multiindex[:-1] + if len(multiindex) == 0: + break + index_shape = tuple(i.extent if isinstance(i, Index) else 1 for i in multiindex) + if index_shape + array.shape + child_shape == tensor.shape: + if all(elem.multiindex[len(multiindex):] == idx for idx, elem in numpy.ndenumerate(array)): + return partial_indexed(tensor, multiindex) + + # Simplify [tensor[j, ...] for j in range(n)] -> tensor + if all(isinstance(elem, ComponentTensor) and isinstance(elem.children[0], Indexed) + for elem in array.flat): + tensor = e0.children[0].children[0] + if array.shape + child_shape == tensor.shape: + if all(elem.children[0].children[0] == tensor for elem in array.flat[1:]): + if all(elem.children[0].multiindex == idx + elem.multiindex + for idx, elem in numpy.ndenumerate(array)): + return tensor + + # Flatten nested ListTensors + if all(isinstance(elem, ListTensor) for elem in array.flat): + return ListTensor(asarray([elem.array for elem in array.flat]).reshape(array.shape + child_shape)) + if child_shape: # Destroy structure direct_array = numpy.empty(array.shape + child_shape, dtype=object) - for alpha in numpy.ndindex(array.shape): + for alpha, elem in numpy.ndenumerate(array): for beta in numpy.ndindex(child_shape): - direct_array[alpha + beta] = Indexed(array[alpha], beta) + direct_array[alpha + beta] = Indexed(elem, beta) array = direct_array # Constant folding diff --git a/gem/optimise.py b/gem/optimise.py index 289ccff7..9809cc2a 100644 --- a/gem/optimise.py +++ b/gem/optimise.py @@ -11,7 +11,7 @@ from gem.utils import groupby from gem.node import (Memoizer, MemoizerArg, reuse_if_untouched, reuse_if_untouched_arg, traversal) -from gem.gem import (Node, Failure, Identity, Literal, Zero, +from gem.gem import (Node, Failure, Identity, Constant, Literal, Zero, Product, Sum, Comparison, Conditional, Division, Index, VariableIndex, Indexed, FlexiblyIndexed, IndexSum, ComponentTensor, ListTensor, Delta, @@ -119,6 +119,7 @@ def replace_indices_delta(node, self, subst): def replace_indices_indexed(node, self, subst): multiindex = tuple(_replace_indices_atomic(i, self, subst) for i in node.multiindex) child, = node.children + if isinstance(child, ComponentTensor): # Indexing into ComponentTensor # Inline ComponentTensor and augment the substitution rules @@ -127,11 +128,27 @@ def replace_indices_indexed(node, self, subst): return self(child.children[0], tuple(sorted(substitute.items()))) else: # Replace indices - new_child = self(child, subst) - if multiindex == node.multiindex and new_child == child: + child = self(child, subst) + + # Remove fixed indices + if isinstance(child, (Constant, ListTensor)): + if all(isinstance(i, Integral) for i in multiindex): + # All indices fixed + sub = child.array[multiindex] + child = Literal(sub, dtype=child.dtype) if isinstance(child, Constant) else sub + multiindex = () + + elif any(isinstance(i, Integral) for i in multiindex): + # Some indices fixed + slices = tuple(i if isinstance(i, Integral) else slice(None) for i in multiindex) + sub = child.array[slices] + child = Literal(sub, dtype=child.dtype) if isinstance(child, Constant) else ListTensor(sub) + multiindex = tuple(i for i in multiindex if not isinstance(i, Integral)) + + if multiindex == node.multiindex and child == node.children[0]: return node else: - return Indexed(new_child, multiindex) + return Indexed(child, multiindex) @replace_indices.register(FlexiblyIndexed) @@ -177,7 +194,7 @@ def _constant_fold_zero(node, self): @_constant_fold_zero.register(Literal) def _constant_fold_zero_literal(node, self): - if (node.array == 0).all(): + if numpy.array_equal(node.array, 0): # All zeros, make symbolic zero return Zero(node.shape) else: @@ -663,8 +680,8 @@ def _(node, self): # Unrolling summand = self(node.children[0]) shape = tuple(index.extent for index in unroll) - unrolled = Sum(*(Indexed(ComponentTensor(summand, unroll), alpha) - for alpha in numpy.ndindex(shape))) + tensor = ComponentTensor(summand, unroll) + unrolled = Sum(*(Indexed(tensor, alpha) for alpha in numpy.ndindex(shape))) return IndexSum(unrolled, tuple(index for index in node.multiindex if index not in unroll)) else: diff --git a/test/gem/test_simplify.py b/test/gem/test_simplify.py new file mode 100644 index 00000000..fa224254 --- /dev/null +++ b/test/gem/test_simplify.py @@ -0,0 +1,86 @@ +import pytest +import gem +import numpy + + +@pytest.fixture +def A(): + a = gem.Variable("a", ()) + b = gem.Variable("b", ()) + c = gem.Variable("c", ()) + d = gem.Variable("d", ()) + array = [[a, b], [c, d]] + A = gem.ListTensor(array) + return A + + +@pytest.fixture +def X(): + return gem.Variable("X", (2, 2)) + + +def test_listtensor_from_indexed(X): + k = gem.Index() + elems = [gem.Indexed(X, (k, *i)) for i in numpy.ndindex(X.shape[1:])] + tensor = gem.ListTensor(numpy.reshape(elems, X.shape[1:])) + + assert isinstance(tensor, gem.ComponentTensor) + j = tensor.multiindex + expected = gem.partial_indexed(X, (k,)) + expected = gem.ComponentTensor(gem.Indexed(expected, j), j) + assert tensor == expected + + +def test_listtensor_from_fixed_indexed(A): + elems = [gem.Indexed(A, i) for i in numpy.ndindex(A.shape)] + tensor = gem.ListTensor(numpy.reshape(elems, A.shape)) + assert tensor == A + + +def test_listtensor_from_partial_indexed(A): + elems = [gem.partial_indexed(A, i) for i in numpy.ndindex(A.shape[:1])] + tensor = gem.ListTensor(elems) + assert tensor == A + + +def test_nested_partial_indexed(A): + i, j = gem.indices(2) + B = gem.partial_indexed(gem.partial_indexed(A, (i,)), (j,)) + assert B == gem.Indexed(A, (i, j)) + + +def test_componenttensor_from_indexed(A): + i, j = gem.indices(2) + Aij = gem.Indexed(A, (i, j)) + assert A == gem.ComponentTensor(Aij, (i, j)) + + +def test_indexed_transpose(A): + i, j = gem.indices(2) + ATij = gem.Indexed(A.T, (i, j)) + Aji = gem.Indexed(A, (j, i)) + assert ATij == Aji + + i, = gem.indices(1) + j = 1 + ATij = gem.Indexed(A.T, (i, j)) + Aji = gem.Indexed(A, (j, i)) + assert ATij == Aji + + i, j = (0, 1) + ATij = gem.Indexed(A.T, (i, j)) + Aji = gem.Indexed(A, (j, i)) + assert ATij == Aji + + +def test_double_transpose(A): + assert A.T.T == A + + +def test_flatten_indexsum(A): + i, j = gem.indices(2) + Aij = gem.Indexed(A, (i, j)) + + result = gem.IndexSum(gem.IndexSum(Aij, (i,)), (j,)) + expected = gem.IndexSum(Aij, (i, j)) + assert result == expected