From 8586a7cdc7ad8177e37321ab16695945f3db96aa Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 8 Aug 2025 10:13:26 +0100 Subject: [PATCH] Cleanup FiatElement.basis_evaluation --- finat/fiat_elements.py | 75 +++++++++++++++++++----------------------- 1 file changed, 34 insertions(+), 41 deletions(-) diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index a5d61b3b5..f6ac0e685 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -98,8 +98,8 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): :param ps: the point set. :param entity: the cell entity on which to tabulate. ''' - space_dimension = self._element.space_dimension() - value_size = np.prod(self._element.value_shape(), dtype=int) + value_shape = self.value_shape + value_size = np.prod(value_shape, dtype=int) fiat_result = self._element.tabulate(order, ps.points, entity) result = {} # In almost all cases, we have @@ -109,51 +109,44 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): # basis functions, and the additional 3 are for # dealing with transformations between physical # and reference space). - index_shape = (self._element.space_dimension(),) + space_dimension = self._element.space_dimension() + if self.space_dimension() == space_dimension: + beta = self.get_indices() + index_shape = tuple(index.extent for index in beta) + else: + index_shape = (space_dimension,) + beta = tuple(gem.Index(extent=i) for i in index_shape) + assert len(beta) == len(self.get_indices()) + + zeta = self.get_value_indices() + basis_indices = beta + zeta + for alpha, fiat_table in fiat_result.items(): if isinstance(fiat_table, Exception): - result[alpha] = gem.Failure(self.index_shape + self.value_shape, fiat_table) + result[alpha] = gem.Failure(index_shape + value_shape, fiat_table) continue derivative = sum(alpha) - shp = (space_dimension, value_size, *ps.points.shape[:-1]) - table_roll = np.moveaxis(fiat_table.reshape(shp), 0, -1) - - exprs = [] - for table in table_roll: - if derivative == self.degree and not self.complex.is_macrocell(): - # Make sure numerics satisfies theory - exprs.append(gem.Literal(table[0])) - elif derivative > self.degree: - # Make sure numerics satisfies theory - assert np.allclose(table, 0.0) - exprs.append(gem.Literal(np.zeros(self.index_shape))) - else: - point_indices = ps.indices - point_shape = tuple(index.extent for index in point_indices) - - exprs.append(gem.partial_indexed( - gem.Literal(table.reshape(point_shape + index_shape)), - point_indices - )) - if self.value_shape: - # As above, this extent may be different from that - # advertised by the finat element. - beta = tuple(gem.Index(extent=i) for i in index_shape) - assert len(beta) == len(self.get_indices()) - - zeta = self.get_value_indices() - result[alpha] = gem.ComponentTensor( - gem.Indexed( - gem.ListTensor(np.array( - [gem.Indexed(expr, beta) for expr in exprs] - ).reshape(self.value_shape)), - zeta), - beta + zeta - ) + fiat_table = fiat_table.reshape(space_dimension, value_size, -1) + + point_indices = () + if derivative == self.degree and not self.complex.is_macrocell(): + # Make sure numerics satisfies theory + fiat_table = fiat_table[..., 0] + elif derivative > self.degree: + # Make sure numerics satisfies theory + assert np.allclose(fiat_table, 0.0) + fiat_table = np.zeros(fiat_table.shape[:-1]) else: - expr, = exprs - result[alpha] = expr + point_indices = ps.indices + + point_shape = tuple(index.extent for index in point_indices) + table_shape = index_shape + value_shape + point_shape + table_indices = basis_indices + point_indices + + expr = gem.Indexed(gem.Literal(fiat_table.reshape(table_shape)), table_indices) + expr = gem.ComponentTensor(expr, basis_indices) + result[alpha] = expr return result def point_evaluation(self, order, refcoords, entity=None, coordinate_mapping=None):