diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index a5d61b3b5..f6ac0e685 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -98,8 +98,8 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): :param ps: the point set. :param entity: the cell entity on which to tabulate. ''' - space_dimension = self._element.space_dimension() - value_size = np.prod(self._element.value_shape(), dtype=int) + value_shape = self.value_shape + value_size = np.prod(value_shape, dtype=int) fiat_result = self._element.tabulate(order, ps.points, entity) result = {} # In almost all cases, we have @@ -109,51 +109,44 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): # basis functions, and the additional 3 are for # dealing with transformations between physical # and reference space). - index_shape = (self._element.space_dimension(),) + space_dimension = self._element.space_dimension() + if self.space_dimension() == space_dimension: + beta = self.get_indices() + index_shape = tuple(index.extent for index in beta) + else: + index_shape = (space_dimension,) + beta = tuple(gem.Index(extent=i) for i in index_shape) + assert len(beta) == len(self.get_indices()) + + zeta = self.get_value_indices() + basis_indices = beta + zeta + for alpha, fiat_table in fiat_result.items(): if isinstance(fiat_table, Exception): - result[alpha] = gem.Failure(self.index_shape + self.value_shape, fiat_table) + result[alpha] = gem.Failure(index_shape + value_shape, fiat_table) continue derivative = sum(alpha) - shp = (space_dimension, value_size, *ps.points.shape[:-1]) - table_roll = np.moveaxis(fiat_table.reshape(shp), 0, -1) - - exprs = [] - for table in table_roll: - if derivative == self.degree and not self.complex.is_macrocell(): - # Make sure numerics satisfies theory - exprs.append(gem.Literal(table[0])) - elif derivative > self.degree: - # Make sure numerics satisfies theory - assert np.allclose(table, 0.0) - exprs.append(gem.Literal(np.zeros(self.index_shape))) - else: - point_indices = ps.indices - point_shape = tuple(index.extent for index in point_indices) - - exprs.append(gem.partial_indexed( - gem.Literal(table.reshape(point_shape + index_shape)), - point_indices - )) - if self.value_shape: - # As above, this extent may be different from that - # advertised by the finat element. - beta = tuple(gem.Index(extent=i) for i in index_shape) - assert len(beta) == len(self.get_indices()) - - zeta = self.get_value_indices() - result[alpha] = gem.ComponentTensor( - gem.Indexed( - gem.ListTensor(np.array( - [gem.Indexed(expr, beta) for expr in exprs] - ).reshape(self.value_shape)), - zeta), - beta + zeta - ) + fiat_table = fiat_table.reshape(space_dimension, value_size, -1) + + point_indices = () + if derivative == self.degree and not self.complex.is_macrocell(): + # Make sure numerics satisfies theory + fiat_table = fiat_table[..., 0] + elif derivative > self.degree: + # Make sure numerics satisfies theory + assert np.allclose(fiat_table, 0.0) + fiat_table = np.zeros(fiat_table.shape[:-1]) else: - expr, = exprs - result[alpha] = expr + point_indices = ps.indices + + point_shape = tuple(index.extent for index in point_indices) + table_shape = index_shape + value_shape + point_shape + table_indices = basis_indices + point_indices + + expr = gem.Indexed(gem.Literal(fiat_table.reshape(table_shape)), table_indices) + expr = gem.ComponentTensor(expr, basis_indices) + result[alpha] = expr return result def point_evaluation(self, order, refcoords, entity=None, coordinate_mapping=None): diff --git a/gem/gem.py b/gem/gem.py index 4bae48e89..8c7349dd6 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -97,7 +97,7 @@ def __radd__(self, other): def __sub__(self, other): return componentwise( Sum, self, - componentwise(Product, Literal(-1), as_gem(other))) + componentwise(Product, minus, as_gem(other))) def __rsub__(self, other): return as_gem(other).__sub__(self) @@ -293,7 +293,7 @@ def is_equal(self, other): return False if self.shape != other.shape: return False - return tuple(self.array.flat) == tuple(other.array.flat) + return numpy.array_equal(self.array, other.array) def get_hash(self): return hash((type(self), self.shape, tuple(self.array.flat))) @@ -684,12 +684,45 @@ def __new__(cls, aggregate, multiindex): if isinstance(aggregate, Zero): return Zero(dtype=aggregate.dtype) - # All indices fixed - if all(isinstance(i, int) for i in multiindex): - if isinstance(aggregate, Constant): - return Literal(aggregate.array[multiindex], dtype=aggregate.dtype) - elif isinstance(aggregate, ListTensor): - return aggregate.array[multiindex] + # Simplify Literal and ListTensor + if isinstance(aggregate, (Constant, ListTensor)): + if all(isinstance(i, int) for i in multiindex): + # All indices fixed + sub = aggregate.array[multiindex] + return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub + + elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex): + # Some indices fixed + slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex) + sub = aggregate.array[slices] + sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub) + return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int))) + + # Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll) + if isinstance(aggregate, ComponentTensor): + B, = aggregate.children + jj = aggregate.multiindex + ii = multiindex + if isinstance(B, Indexed): + C, = B.children + kk = B.multiindex + if not isinstance(C, ComponentTensor): + rep = dict(zip(jj, ii)) + ll = tuple(rep.get(k, k) for k in kk) + B = Indexed(C, ll) + jj = tuple(j for j in jj if j not in kk) + ii = tuple(rep[j] for j in jj) + if not ii: + return B + + if isinstance(B, Indexed): + C, = B.children + kk = B.multiindex + ff = C.free_indices + if all((j in kk) and (j not in ff) for j in jj): + rep = dict(zip(jj, ii)) + ll = tuple(rep.get(k, k) for k in kk) + return Indexed(C, ll) self = super(Indexed, cls).__new__(cls) self.children = (aggregate,) @@ -835,6 +868,11 @@ def __new__(cls, expression, multiindex): if isinstance(expression, Zero): return Zero(shape, dtype=expression.dtype) + # Index folding + if isinstance(expression, Indexed): + if multiindex == expression.multiindex: + return expression.children[0] + self = super(ComponentTensor, cls).__new__(cls) self.children = (expression,) self.multiindex = multiindex @@ -857,6 +895,25 @@ def __new__(cls, summand, multiindex): if isinstance(summand, Zero): return summand + # No indices case + multiindex = tuple(multiindex) + if not multiindex: + return summand + + # Flatten nested sums + if isinstance(summand, IndexSum): + A, = summand.children + return IndexSum(A, summand.multiindex + multiindex) + + # Factor out common factors + if isinstance(summand, Product): + a, b = summand.children + if all(i not in a.free_indices for i in multiindex): + return Product(a, IndexSum(b, multiindex)) + + if all(i not in b.free_indices for i in multiindex): + return Product(IndexSum(a, multiindex), b) + # Unroll singleton sums unroll = tuple(index for index in multiindex if index.extent <= 1) if unroll: @@ -866,11 +923,6 @@ def __new__(cls, summand, multiindex): multiindex = tuple(index for index in multiindex if index not in unroll) - # No indices case - multiindex = tuple(multiindex) - if not multiindex: - return summand - self = super(IndexSum, cls).__new__(cls) self.children = (summand,) self.multiindex = multiindex @@ -891,9 +943,17 @@ def __new__(cls, array): dtype = Node.inherit_dtype_from_children(tuple(array.flat)) # Handle children with shape - child_shape = array.flat[0].shape + e0 = array.flat[0] + child_shape = e0.shape assert all(elem.shape == child_shape for elem in array.flat) + # Index folding + if child_shape == array.shape: + if all(isinstance(elem, Indexed) for elem in array.flat): + if all(elem.children == e0.children for elem in array.flat[1:]): + if all(elem.multiindex == idx for elem, idx in zip(array.flat, numpy.ndindex(array.shape))): + return e0.children[0] + if child_shape: # Destroy structure direct_array = numpy.empty(array.shape + child_shape, dtype=object) @@ -1218,6 +1278,7 @@ def view(expression, *slices): # Static one object for quicker constant folding one = Literal(1) +minus = Literal(-1) # Syntax sugar