Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 34 additions & 41 deletions finat/fiat_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
:param ps: the point set.
:param entity: the cell entity on which to tabulate.
'''
space_dimension = self._element.space_dimension()
value_size = np.prod(self._element.value_shape(), dtype=int)
value_shape = self.value_shape
value_size = np.prod(value_shape, dtype=int)
fiat_result = self._element.tabulate(order, ps.points, entity)
result = {}
# In almost all cases, we have
Expand All @@ -109,51 +109,44 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
# basis functions, and the additional 3 are for
# dealing with transformations between physical
# and reference space).
index_shape = (self._element.space_dimension(),)
space_dimension = self._element.space_dimension()
if self.space_dimension() == space_dimension:
beta = self.get_indices()
index_shape = tuple(index.extent for index in beta)
else:
index_shape = (space_dimension,)
beta = tuple(gem.Index(extent=i) for i in index_shape)
assert len(beta) == len(self.get_indices())

zeta = self.get_value_indices()
basis_indices = beta + zeta

for alpha, fiat_table in fiat_result.items():
if isinstance(fiat_table, Exception):
result[alpha] = gem.Failure(self.index_shape + self.value_shape, fiat_table)
result[alpha] = gem.Failure(index_shape + value_shape, fiat_table)
continue

derivative = sum(alpha)
shp = (space_dimension, value_size, *ps.points.shape[:-1])
table_roll = np.moveaxis(fiat_table.reshape(shp), 0, -1)

exprs = []
for table in table_roll:
if derivative == self.degree and not self.complex.is_macrocell():
# Make sure numerics satisfies theory
exprs.append(gem.Literal(table[0]))
elif derivative > self.degree:
# Make sure numerics satisfies theory
assert np.allclose(table, 0.0)
exprs.append(gem.Literal(np.zeros(self.index_shape)))
else:
point_indices = ps.indices
point_shape = tuple(index.extent for index in point_indices)

exprs.append(gem.partial_indexed(
gem.Literal(table.reshape(point_shape + index_shape)),
point_indices
))
if self.value_shape:
# As above, this extent may be different from that
# advertised by the finat element.
beta = tuple(gem.Index(extent=i) for i in index_shape)
assert len(beta) == len(self.get_indices())

zeta = self.get_value_indices()
result[alpha] = gem.ComponentTensor(
gem.Indexed(
gem.ListTensor(np.array(
[gem.Indexed(expr, beta) for expr in exprs]
).reshape(self.value_shape)),
zeta),
beta + zeta
)
fiat_table = fiat_table.reshape(space_dimension, value_size, -1)

point_indices = ()
if derivative == self.degree and not self.complex.is_macrocell():
# Make sure numerics satisfies theory
fiat_table = fiat_table[..., 0]
elif derivative > self.degree:
# Make sure numerics satisfies theory
assert np.allclose(fiat_table, 0.0)
fiat_table = np.zeros(fiat_table.shape[:-1])
else:
expr, = exprs
result[alpha] = expr
point_indices = ps.indices

point_shape = tuple(index.extent for index in point_indices)
table_shape = index_shape + value_shape + point_shape
table_indices = basis_indices + point_indices

expr = gem.Indexed(gem.Literal(fiat_table.reshape(table_shape)), table_indices)
expr = gem.ComponentTensor(expr, basis_indices)
result[alpha] = expr
return result

def point_evaluation(self, order, refcoords, entity=None, coordinate_mapping=None):
Expand Down
Loading