Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 34 additions & 41 deletions finat/fiat_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
:param ps: the point set.
:param entity: the cell entity on which to tabulate.
'''
space_dimension = self._element.space_dimension()
value_size = np.prod(self._element.value_shape(), dtype=int)
value_shape = self.value_shape
value_size = np.prod(value_shape, dtype=int)
fiat_result = self._element.tabulate(order, ps.points, entity)
result = {}
# In almost all cases, we have
Expand All @@ -109,51 +109,44 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
# basis functions, and the additional 3 are for
# dealing with transformations between physical
# and reference space).
index_shape = (self._element.space_dimension(),)
space_dimension = self._element.space_dimension()
if self.space_dimension() == space_dimension:
beta = self.get_indices()
index_shape = tuple(index.extent for index in beta)
else:
index_shape = (space_dimension,)
beta = tuple(gem.Index(extent=i) for i in index_shape)
assert len(beta) == len(self.get_indices())

zeta = self.get_value_indices()
basis_indices = beta + zeta

for alpha, fiat_table in fiat_result.items():
if isinstance(fiat_table, Exception):
result[alpha] = gem.Failure(self.index_shape + self.value_shape, fiat_table)
result[alpha] = gem.Failure(index_shape + value_shape, fiat_table)
continue

derivative = sum(alpha)
shp = (space_dimension, value_size, *ps.points.shape[:-1])
table_roll = np.moveaxis(fiat_table.reshape(shp), 0, -1)

exprs = []
for table in table_roll:
if derivative == self.degree and not self.complex.is_macrocell():
# Make sure numerics satisfies theory
exprs.append(gem.Literal(table[0]))
elif derivative > self.degree:
# Make sure numerics satisfies theory
assert np.allclose(table, 0.0)
exprs.append(gem.Literal(np.zeros(self.index_shape)))
else:
point_indices = ps.indices
point_shape = tuple(index.extent for index in point_indices)

exprs.append(gem.partial_indexed(
gem.Literal(table.reshape(point_shape + index_shape)),
point_indices
))
if self.value_shape:
# As above, this extent may be different from that
# advertised by the finat element.
beta = tuple(gem.Index(extent=i) for i in index_shape)
assert len(beta) == len(self.get_indices())

zeta = self.get_value_indices()
result[alpha] = gem.ComponentTensor(
gem.Indexed(
gem.ListTensor(np.array(
[gem.Indexed(expr, beta) for expr in exprs]
).reshape(self.value_shape)),
zeta),
beta + zeta
)
fiat_table = fiat_table.reshape(space_dimension, value_size, -1)

point_indices = ()
if derivative == self.degree and not self.complex.is_macrocell():
# Make sure numerics satisfies theory
fiat_table = fiat_table[..., 0]
elif derivative > self.degree:
# Make sure numerics satisfies theory
assert np.allclose(fiat_table, 0.0)
fiat_table = np.zeros(fiat_table.shape[:-1])
else:
expr, = exprs
result[alpha] = expr
point_indices = ps.indices

point_shape = tuple(index.extent for index in point_indices)
table_shape = index_shape + value_shape + point_shape
table_indices = basis_indices + point_indices

expr = gem.Indexed(gem.Literal(fiat_table.reshape(table_shape)), table_indices)
expr = gem.ComponentTensor(expr, basis_indices)
result[alpha] = expr
return result

def point_evaluation(self, order, refcoords, entity=None, coordinate_mapping=None):
Expand Down
89 changes: 75 additions & 14 deletions gem/gem.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __radd__(self, other):
def __sub__(self, other):
return componentwise(
Sum, self,
componentwise(Product, Literal(-1), as_gem(other)))
componentwise(Product, minus, as_gem(other)))

def __rsub__(self, other):
return as_gem(other).__sub__(self)
Expand Down Expand Up @@ -293,7 +293,7 @@ def is_equal(self, other):
return False
if self.shape != other.shape:
return False
return tuple(self.array.flat) == tuple(other.array.flat)
return numpy.array_equal(self.array, other.array)

def get_hash(self):
return hash((type(self), self.shape, tuple(self.array.flat)))
Expand Down Expand Up @@ -684,12 +684,45 @@ def __new__(cls, aggregate, multiindex):
if isinstance(aggregate, Zero):
return Zero(dtype=aggregate.dtype)

# All indices fixed
if all(isinstance(i, int) for i in multiindex):
if isinstance(aggregate, Constant):
return Literal(aggregate.array[multiindex], dtype=aggregate.dtype)
elif isinstance(aggregate, ListTensor):
return aggregate.array[multiindex]
# Simplify Literal and ListTensor
if isinstance(aggregate, (Constant, ListTensor)):
if all(isinstance(i, int) for i in multiindex):
# All indices fixed
sub = aggregate.array[multiindex]
return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub

elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex):
# Some indices fixed
slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex)
sub = aggregate.array[slices]
sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub)
return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int)))

# Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll)
if isinstance(aggregate, ComponentTensor):
B, = aggregate.children
jj = aggregate.multiindex
ii = multiindex
if isinstance(B, Indexed):
C, = B.children
kk = B.multiindex
if not isinstance(C, ComponentTensor):
rep = dict(zip(jj, ii))
ll = tuple(rep.get(k, k) for k in kk)
B = Indexed(C, ll)
jj = tuple(j for j in jj if j not in kk)
ii = tuple(rep[j] for j in jj)
if not ii:
return B

if isinstance(B, Indexed):
C, = B.children
kk = B.multiindex
ff = C.free_indices
if all((j in kk) and (j not in ff) for j in jj):
rep = dict(zip(jj, ii))
ll = tuple(rep.get(k, k) for k in kk)
return Indexed(C, ll)

self = super(Indexed, cls).__new__(cls)
self.children = (aggregate,)
Expand Down Expand Up @@ -835,6 +868,11 @@ def __new__(cls, expression, multiindex):
if isinstance(expression, Zero):
return Zero(shape, dtype=expression.dtype)

# Index folding
if isinstance(expression, Indexed):
if multiindex == expression.multiindex:
return expression.children[0]

self = super(ComponentTensor, cls).__new__(cls)
self.children = (expression,)
self.multiindex = multiindex
Expand All @@ -857,6 +895,25 @@ def __new__(cls, summand, multiindex):
if isinstance(summand, Zero):
return summand

# No indices case
multiindex = tuple(multiindex)
if not multiindex:
return summand

# Flatten nested sums
if isinstance(summand, IndexSum):
A, = summand.children
return IndexSum(A, summand.multiindex + multiindex)

# Factor out common factors
if isinstance(summand, Product):
a, b = summand.children
if all(i not in a.free_indices for i in multiindex):
return Product(a, IndexSum(b, multiindex))

if all(i not in b.free_indices for i in multiindex):
return Product(IndexSum(a, multiindex), b)

# Unroll singleton sums
unroll = tuple(index for index in multiindex if index.extent <= 1)
if unroll:
Expand All @@ -866,11 +923,6 @@ def __new__(cls, summand, multiindex):
multiindex = tuple(index for index in multiindex
if index not in unroll)

# No indices case
multiindex = tuple(multiindex)
if not multiindex:
return summand

self = super(IndexSum, cls).__new__(cls)
self.children = (summand,)
self.multiindex = multiindex
Expand All @@ -891,9 +943,17 @@ def __new__(cls, array):
dtype = Node.inherit_dtype_from_children(tuple(array.flat))

# Handle children with shape
child_shape = array.flat[0].shape
e0 = array.flat[0]
child_shape = e0.shape
assert all(elem.shape == child_shape for elem in array.flat)

# Index folding
if child_shape == array.shape:
if all(isinstance(elem, Indexed) for elem in array.flat):
if all(elem.children == e0.children for elem in array.flat[1:]):
if all(elem.multiindex == idx for elem, idx in zip(array.flat, numpy.ndindex(array.shape))):
return e0.children[0]

if child_shape:
# Destroy structure
direct_array = numpy.empty(array.shape + child_shape, dtype=object)
Expand Down Expand Up @@ -1218,6 +1278,7 @@ def view(expression, *slices):

# Static one object for quicker constant folding
one = Literal(1)
minus = Literal(-1)


# Syntax sugar
Expand Down
Loading