From 8586a7cdc7ad8177e37321ab16695945f3db96aa Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 8 Aug 2025 10:13:26 +0100 Subject: [PATCH 1/9] Cleanup FiatElement.basis_evaluation --- finat/fiat_elements.py | 75 +++++++++++++++++++----------------------- 1 file changed, 34 insertions(+), 41 deletions(-) diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index a5d61b3b5..f6ac0e685 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -98,8 +98,8 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): :param ps: the point set. :param entity: the cell entity on which to tabulate. ''' - space_dimension = self._element.space_dimension() - value_size = np.prod(self._element.value_shape(), dtype=int) + value_shape = self.value_shape + value_size = np.prod(value_shape, dtype=int) fiat_result = self._element.tabulate(order, ps.points, entity) result = {} # In almost all cases, we have @@ -109,51 +109,44 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): # basis functions, and the additional 3 are for # dealing with transformations between physical # and reference space). - index_shape = (self._element.space_dimension(),) + space_dimension = self._element.space_dimension() + if self.space_dimension() == space_dimension: + beta = self.get_indices() + index_shape = tuple(index.extent for index in beta) + else: + index_shape = (space_dimension,) + beta = tuple(gem.Index(extent=i) for i in index_shape) + assert len(beta) == len(self.get_indices()) + + zeta = self.get_value_indices() + basis_indices = beta + zeta + for alpha, fiat_table in fiat_result.items(): if isinstance(fiat_table, Exception): - result[alpha] = gem.Failure(self.index_shape + self.value_shape, fiat_table) + result[alpha] = gem.Failure(index_shape + value_shape, fiat_table) continue derivative = sum(alpha) - shp = (space_dimension, value_size, *ps.points.shape[:-1]) - table_roll = np.moveaxis(fiat_table.reshape(shp), 0, -1) - - exprs = [] - for table in table_roll: - if derivative == self.degree and not self.complex.is_macrocell(): - # Make sure numerics satisfies theory - exprs.append(gem.Literal(table[0])) - elif derivative > self.degree: - # Make sure numerics satisfies theory - assert np.allclose(table, 0.0) - exprs.append(gem.Literal(np.zeros(self.index_shape))) - else: - point_indices = ps.indices - point_shape = tuple(index.extent for index in point_indices) - - exprs.append(gem.partial_indexed( - gem.Literal(table.reshape(point_shape + index_shape)), - point_indices - )) - if self.value_shape: - # As above, this extent may be different from that - # advertised by the finat element. - beta = tuple(gem.Index(extent=i) for i in index_shape) - assert len(beta) == len(self.get_indices()) - - zeta = self.get_value_indices() - result[alpha] = gem.ComponentTensor( - gem.Indexed( - gem.ListTensor(np.array( - [gem.Indexed(expr, beta) for expr in exprs] - ).reshape(self.value_shape)), - zeta), - beta + zeta - ) + fiat_table = fiat_table.reshape(space_dimension, value_size, -1) + + point_indices = () + if derivative == self.degree and not self.complex.is_macrocell(): + # Make sure numerics satisfies theory + fiat_table = fiat_table[..., 0] + elif derivative > self.degree: + # Make sure numerics satisfies theory + assert np.allclose(fiat_table, 0.0) + fiat_table = np.zeros(fiat_table.shape[:-1]) else: - expr, = exprs - result[alpha] = expr + point_indices = ps.indices + + point_shape = tuple(index.extent for index in point_indices) + table_shape = index_shape + value_shape + point_shape + table_indices = basis_indices + point_indices + + expr = gem.Indexed(gem.Literal(fiat_table.reshape(table_shape)), table_indices) + expr = gem.ComponentTensor(expr, basis_indices) + result[alpha] = expr return result def point_evaluation(self, order, refcoords, entity=None, coordinate_mapping=None): From de501c10b9adee96ec3468f62b972a16dde97085 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 8 Aug 2025 10:29:01 +0100 Subject: [PATCH 2/9] GEM: simplify indexed --- gem/gem.py | 46 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 39 insertions(+), 7 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 974556754..f468747bd 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -685,12 +685,31 @@ def __new__(cls, aggregate, multiindex): if isinstance(aggregate, Zero): return Zero(dtype=aggregate.dtype) - # All indices fixed - if all(isinstance(i, int) for i in multiindex): - if isinstance(aggregate, Constant): - return Literal(aggregate.array[multiindex], dtype=aggregate.dtype) - elif isinstance(aggregate, ListTensor): - return aggregate.array[multiindex] + # Simplify Literal and ListTensor + if isinstance(aggregate, (Constant, ListTensor)): + if all(isinstance(i, int) for i in multiindex): + # All indices fixed + sub = aggregate.array[multiindex] + return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub + elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex): + # Some indices fixed + slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex) + sub = aggregate.array[slices] + sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub) + return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int))) + + # Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll) + if isinstance(aggregate, ComponentTensor): + B, = aggregate.children + jj = aggregate.multiindex + if isinstance(B, Indexed): + C, = B.children + kk = B.multiindex + if all(j in kk for j in jj): + ii = tuple(multiindex) + rep = dict(zip(jj, ii)) + ll = tuple(rep.get(k, k) for k in kk) + return Indexed(C, ll) self = super(Indexed, cls).__new__(cls) self.children = (aggregate,) @@ -836,6 +855,11 @@ def __new__(cls, expression, multiindex): if isinstance(expression, Zero): return Zero(shape, dtype=expression.dtype) + # Index folding + if isinstance(expression, Indexed): + if multiindex == expression.multiindex: + return expression.children[0] + self = super(ComponentTensor, cls).__new__(cls) self.children = (expression,) self.multiindex = multiindex @@ -892,9 +916,17 @@ def __new__(cls, array): dtype = Node.inherit_dtype_from_children(tuple(array.flat)) # Handle children with shape - child_shape = array.flat[0].shape + e0 = array.flat[0] + child_shape = e0.shape assert all(elem.shape == child_shape for elem in array.flat) + # Index folding + if child_shape == array.shape: + if all(isinstance(elem, Indexed) for elem in array.flat): + if all(elem.children == e0.children for elem in array.flat[1:]): + if all(elem.multiindex == idx for elem, idx in zip(array.flat, numpy.ndindex(array.shape))): + return e0.children[0] + if child_shape: # Destroy structure direct_array = numpy.empty(array.shape + child_shape, dtype=object) From 592061ea4f4e3416b3729f658cd6b2380de494a4 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 11 Aug 2025 13:30:05 +0100 Subject: [PATCH 3/9] Fixes for more complicated expressions --- gem/gem.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index f468747bd..79bcd2e79 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -274,6 +274,7 @@ class Literal(Constant): def __new__(cls, array, dtype=None): array = asarray(array) + return super(Literal, cls).__new__(cls) def __init__(self, array, dtype=None): @@ -702,14 +703,29 @@ def __new__(cls, aggregate, multiindex): if isinstance(aggregate, ComponentTensor): B, = aggregate.children jj = aggregate.multiindex + ii = multiindex + # Avoid recursion and just attempt to simplify some common patterns + # as the result of this method is not cached. if isinstance(B, Indexed): C, = B.children kk = B.multiindex - if all(j in kk for j in jj): - ii = tuple(multiindex) + if isinstance(C, ListTensor): rep = dict(zip(jj, ii)) ll = tuple(rep.get(k, k) for k in kk) - return Indexed(C, ll) + B = Indexed(C, ll) + jj = tuple(j for j in jj if j not in kk) + ii = tuple(rep[j] for j in jj) + if len(ii) == 0: + return B + + if isinstance(B, Indexed): + C, = B.children + kk = B.multiindex + if not isinstance(C, ComponentTensor) or all(isinstance(i, Index) for i in ii): + if all(j in kk for j in jj): + rep = dict(zip(jj, ii)) + ll = tuple(rep.get(k, k) for k in kk) + return Indexed(C, ll) self = super(Indexed, cls).__new__(cls) self.children = (aggregate,) @@ -722,6 +738,7 @@ def __new__(cls, aggregate, multiindex): new_indices.append(i) elif isinstance(i, VariableIndex): new_indices.extend(i.expression.free_indices) + self.free_indices = unique(aggregate.free_indices + tuple(new_indices)) return self From 158b0ec918a2100f496d9b59bbcce5235cb24c24 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 11 Aug 2025 13:41:42 +0100 Subject: [PATCH 4/9] small change --- gem/gem.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 111f4bb3b..cec66af68 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -273,7 +273,6 @@ class Literal(Constant): def __new__(cls, array, dtype=None): array = asarray(array) - return super(Literal, cls).__new__(cls) def __init__(self, array, dtype=None): @@ -294,7 +293,7 @@ def is_equal(self, other): return False if self.shape != other.shape: return False - return tuple(self.array.flat) == tuple(other.array.flat) + return numpy.array_equal(self.array, other.array) def get_hash(self): return hash((type(self), self.shape, tuple(self.array.flat))) @@ -737,7 +736,6 @@ def __new__(cls, aggregate, multiindex): new_indices.append(i) elif isinstance(i, VariableIndex): new_indices.extend(i.expression.free_indices) - self.free_indices = unique(aggregate.free_indices + tuple(new_indices)) return self From d2a45843f703db6e8ceff49e43541812785d6b93 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 11 Aug 2025 17:50:56 +0100 Subject: [PATCH 5/9] More simplification --- gem/gem.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index cec66af68..9eb00f4ac 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -273,6 +273,9 @@ class Literal(Constant): def __new__(cls, array, dtype=None): array = asarray(array) + if numpy.allclose(array, 0, 1e-14): + return Zero(array.shape) + return super(Literal, cls).__new__(cls) def __init__(self, array, dtype=None): @@ -690,6 +693,7 @@ def __new__(cls, aggregate, multiindex): # All indices fixed sub = aggregate.array[multiindex] return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub + elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex): # Some indices fixed slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex) @@ -719,12 +723,20 @@ def __new__(cls, aggregate, multiindex): if isinstance(B, Indexed): C, = B.children kk = B.multiindex - if not isinstance(C, ComponentTensor) or all(isinstance(i, Index) for i in ii): - if all(j in kk for j in jj): - rep = dict(zip(jj, ii)) - ll = tuple(rep.get(k, k) for k in kk) + if all(j in kk for j in jj): + rep = dict(zip(jj, ii)) + ll = tuple(rep.get(k, k) for k in kk) + if isinstance(C, ComponentTensor): + if (all(isinstance(i, Index) for i in ii) + or all(isinstance(l, Integral) or (l in C.multiindex) for l in ll)): + return Indexed(C, ll) + else: return Indexed(C, ll) + if len(ii) < len(multiindex): + aggregate = ComponentTensor(B, jj) + multiindex = ii + self = super(Indexed, cls).__new__(cls) self.children = (aggregate,) self.multiindex = multiindex From 19823731b41a00c66c1857f654cbce84dc65ee6c Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 12 Aug 2025 12:19:35 +0100 Subject: [PATCH 6/9] Do not replace free indices --- gem/gem.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 9eb00f4ac..3ec4cb1a1 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -97,7 +97,7 @@ def __radd__(self, other): def __sub__(self, other): return componentwise( Sum, self, - componentwise(Product, Literal(-1), as_gem(other))) + componentwise(Product, minus, as_gem(other))) def __rsub__(self, other): return as_gem(other).__sub__(self) @@ -706,8 +706,6 @@ def __new__(cls, aggregate, multiindex): B, = aggregate.children jj = aggregate.multiindex ii = multiindex - # Avoid recursion and just attempt to simplify some common patterns - # as the result of this method is not cached. if isinstance(B, Indexed): C, = B.children kk = B.multiindex @@ -723,15 +721,11 @@ def __new__(cls, aggregate, multiindex): if isinstance(B, Indexed): C, = B.children kk = B.multiindex - if all(j in kk for j in jj): + ff = C.free_indices + if all((j in kk) and (j not in ff) for j in jj): rep = dict(zip(jj, ii)) ll = tuple(rep.get(k, k) for k in kk) - if isinstance(C, ComponentTensor): - if (all(isinstance(i, Index) for i in ii) - or all(isinstance(l, Integral) or (l in C.multiindex) for l in ll)): - return Indexed(C, ll) - else: - return Indexed(C, ll) + return Indexed(C, ll) if len(ii) < len(multiindex): aggregate = ComponentTensor(B, jj) @@ -1277,6 +1271,7 @@ def view(expression, *slices): # Static one object for quicker constant folding one = Literal(1) +minus = Literal(-1) # Syntax sugar From 4ab1c0d1a5ebd652a7005ed7bacc86be2de1ebb4 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 12 Aug 2025 12:19:35 +0100 Subject: [PATCH 7/9] Do not replace free indices --- gem/gem.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 9eb00f4ac..8a56a6ef9 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -97,7 +97,7 @@ def __radd__(self, other): def __sub__(self, other): return componentwise( Sum, self, - componentwise(Product, Literal(-1), as_gem(other))) + componentwise(Product, minus, as_gem(other))) def __rsub__(self, other): return as_gem(other).__sub__(self) @@ -273,9 +273,6 @@ class Literal(Constant): def __new__(cls, array, dtype=None): array = asarray(array) - if numpy.allclose(array, 0, 1e-14): - return Zero(array.shape) - return super(Literal, cls).__new__(cls) def __init__(self, array, dtype=None): @@ -706,8 +703,6 @@ def __new__(cls, aggregate, multiindex): B, = aggregate.children jj = aggregate.multiindex ii = multiindex - # Avoid recursion and just attempt to simplify some common patterns - # as the result of this method is not cached. if isinstance(B, Indexed): C, = B.children kk = B.multiindex @@ -723,15 +718,11 @@ def __new__(cls, aggregate, multiindex): if isinstance(B, Indexed): C, = B.children kk = B.multiindex - if all(j in kk for j in jj): + ff = C.free_indices + if all((j in kk) and (j not in ff) for j in jj): rep = dict(zip(jj, ii)) ll = tuple(rep.get(k, k) for k in kk) - if isinstance(C, ComponentTensor): - if (all(isinstance(i, Index) for i in ii) - or all(isinstance(l, Integral) or (l in C.multiindex) for l in ll)): - return Indexed(C, ll) - else: - return Indexed(C, ll) + return Indexed(C, ll) if len(ii) < len(multiindex): aggregate = ComponentTensor(B, jj) @@ -1277,6 +1268,7 @@ def view(expression, *slices): # Static one object for quicker constant folding one = Literal(1) +minus = Literal(-1) # Syntax sugar From b49eb5cfd8bec337efbb323277df7f74ca4cfff5 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 12 Aug 2025 12:19:35 +0100 Subject: [PATCH 8/9] Do not replace free indices --- gem/gem.py | 27 ++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 9eb00f4ac..2ee2fb87f 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -97,7 +97,7 @@ def __radd__(self, other): def __sub__(self, other): return componentwise( Sum, self, - componentwise(Product, Literal(-1), as_gem(other))) + componentwise(Product, minus, as_gem(other))) def __rsub__(self, other): return as_gem(other).__sub__(self) @@ -273,9 +273,6 @@ class Literal(Constant): def __new__(cls, array, dtype=None): array = asarray(array) - if numpy.allclose(array, 0, 1e-14): - return Zero(array.shape) - return super(Literal, cls).__new__(cls) def __init__(self, array, dtype=None): @@ -706,36 +703,27 @@ def __new__(cls, aggregate, multiindex): B, = aggregate.children jj = aggregate.multiindex ii = multiindex - # Avoid recursion and just attempt to simplify some common patterns - # as the result of this method is not cached. + if isinstance(B, Indexed): C, = B.children kk = B.multiindex - if isinstance(C, ListTensor): + if not isinstance(C, ComponentTensor): rep = dict(zip(jj, ii)) ll = tuple(rep.get(k, k) for k in kk) B = Indexed(C, ll) jj = tuple(j for j in jj if j not in kk) ii = tuple(rep[j] for j in jj) - if len(ii) == 0: + if not ii: return B if isinstance(B, Indexed): C, = B.children kk = B.multiindex - if all(j in kk for j in jj): + ff = C.free_indices + if all((j in kk) and (j not in ff) for j in jj): rep = dict(zip(jj, ii)) ll = tuple(rep.get(k, k) for k in kk) - if isinstance(C, ComponentTensor): - if (all(isinstance(i, Index) for i in ii) - or all(isinstance(l, Integral) or (l in C.multiindex) for l in ll)): - return Indexed(C, ll) - else: - return Indexed(C, ll) - - if len(ii) < len(multiindex): - aggregate = ComponentTensor(B, jj) - multiindex = ii + return Indexed(C, ll) self = super(Indexed, cls).__new__(cls) self.children = (aggregate,) @@ -1277,6 +1265,7 @@ def view(expression, *slices): # Static one object for quicker constant folding one = Literal(1) +minus = Literal(-1) # Syntax sugar From 340ec400181d3fd1ae8469961d940488ee456ef0 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 12 Aug 2025 22:22:28 +0100 Subject: [PATCH 9/9] Simplify IndexSum --- gem/gem.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 2ee2fb87f..23e91ceec 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -896,6 +896,25 @@ def __new__(cls, summand, multiindex): if isinstance(summand, Zero): return summand + # No indices case + multiindex = tuple(multiindex) + if not multiindex: + return summand + + # Flatten nested sums + if isinstance(summand, IndexSum): + A, = summand.children + return IndexSum(A, summand.multiindex + multiindex) + + # Factor out common factors + if isinstance(summand, Product): + a, b = summand.children + if all(i not in a.free_indices for i in multiindex): + return Product(a, IndexSum(b, multiindex)) + + if all(i not in b.free_indices for i in multiindex): + return Product(IndexSum(a, multiindex), b) + # Unroll singleton sums unroll = tuple(index for index in multiindex if index.extent <= 1) if unroll: @@ -905,11 +924,6 @@ def __new__(cls, summand, multiindex): multiindex = tuple(index for index in multiindex if index not in unroll) - # No indices case - multiindex = tuple(multiindex) - if not multiindex: - return summand - self = super(IndexSum, cls).__new__(cls) self.children = (summand,) self.multiindex = multiindex