Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions finat/physically_mapped.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,22 @@ def dual_transformation(self, Q, coordinate_mapping=None):

class MappedTabulation(Mapping):
"""A lazy tabulation dict that applies the basis transformation only
on the requested derivatives."""
on the requested derivatives.

def __init__(self, M, ref_tabulation):
:arg M: a gem.ListTensor with the basis transformation matrix.
:arg ref_tabulation: a dict of tabulations on the reference cell.
:kwarg indices: an optional list of restriction indices on the basis functions.
"""
def __init__(self, M, ref_tabulation, indices=None):
self.M = M
self.ref_tabulation = ref_tabulation
if indices is None:
indices = list(range(M.shape[0]))
self.indices = indices
# we expect M to be sparse with O(1) nonzeros per row
# for each row, get the column index of each nonzero entry
csr = [[j for j in range(M.shape[1]) if not isinstance(M.array[i, j], gem.Zero)]
for i in range(M.shape[0])]
for i in indices]
self.csr = csr
self._tabulation_cache = {}

Expand All @@ -35,7 +42,7 @@ def matvec(self, table):
phi = [gem.Indexed(table, (j, *ii)) for j in range(self.M.shape[1])]
# the sum approach is faster than calling numpy.dot or gem.IndexSum
exprs = [gem.ComponentTensor(gem.Sum(*(self.M.array[i, j] * phi[j] for j in js)), ii)
for i, js in enumerate(self.csr)]
for i, js in zip(self.indices, self.csr)]

result = gem.ListTensor(exprs)
result, = gem.optimise.unroll_indexsum((result,), lambda index: True)
Expand Down Expand Up @@ -64,6 +71,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
cite("Kirby2018zany")
cite("Kirby2019zany")
self.restriction_indices = None

@abstractmethod
def basis_transformation(self, coordinate_mapping):
Expand All @@ -75,7 +83,7 @@ def basis_transformation(self, coordinate_mapping):
def map_tabulation(self, ref_tabulation, coordinate_mapping):
assert coordinate_mapping is not None
M = self.basis_transformation(coordinate_mapping)
return MappedTabulation(M, ref_tabulation)
return MappedTabulation(M, ref_tabulation, indices=self.restriction_indices)

def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
result = super().basis_evaluation(order, ps, entity=entity)
Expand All @@ -89,9 +97,14 @@ def dual_transformation(self, Q, coordinate_mapping=None):
M = self.basis_transformation(coordinate_mapping)

M = M.array
if M.shape[0] != M.shape[1]:
M = M[:, :self.space_dimension()]
M_dual = gem.ListTensor(inverse(M.T))
if M.shape[1] > M.shape[0]:
M = M[:, :M.shape[0]]

M_dual = inverse(M.T)
if self.restriction_indices is not None:
indices = self.restriction_indices
M_dual = M_dual[numpy.ix_(indices, indices)]
M_dual = gem.ListTensor(M_dual)

key = None
return MappedTabulation(M_dual, {key: Q})[key]
Expand Down
48 changes: 43 additions & 5 deletions finat/restricted.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,45 @@
null_element = object()


class RestrictedPhysicallyMappedElement(PhysicallyMappedElement, FiatElement):
"""A restricted PhysicallyMappedElement.

:arg element: the finat.FiatElement to be restricted.
:arg indices: the indices of the degrees of freedom to be kept.
"""
def __init__(self, element, indices):
super().__init__(element._element)
# First sanitise the restriction indices
# Some finat elements, e.g. Bell, are the restriction of a FIAT element.
# Therefore, we need to compose the restrictions.
edofs = element.entity_dofs()
free_indices = set(chain.from_iterable(edofs[d][e] for d in edofs for e in edofs[d]))
indices = [i for i in indices if i in free_indices]
self.restriction_indices = indices

# Restrict the entity_dofs dict
rdofs = {d: {e: [indices.index(i) for i in edofs[d][e] if i in indices]
for e in edofs[d]} for d in edofs}
self.restriction_entity_dofs = rdofs

# Grab the basis transformation matrix from the parent element
if isinstance(element, PhysicallyMappedElement):
self.full_basis_transformation = element.basis_transformation
else:
self.full_basis_transformation = None

def basis_transformation(self, coordinate_mapping):
if self.full_basis_transformation is None:
raise NotImplementedError("basis_transformation not implemented.")
return self.full_basis_transformation(coordinate_mapping)

def space_dimension(self):
return len(self.restriction_indices)

def entity_dofs(self):
return self.restriction_entity_dofs


@singledispatch
def restrict(element, domain, take_closure):
"""Restrict an element to a given subentity.
Expand All @@ -39,16 +78,15 @@ def restrict_fiat(element, domain, take_closure):
return null_element

if element.space_dimension() == re.space_dimension():
# FIAT.RestrictedElement wipes out entity_permuations.
# FIAT.RestrictedElement wipes out entity_permutations.
# In case the restriction is trivial we return the original element
# to avoid reconstructing the space with an undesired permutation.
return element
return FiatElement(re)

if isinstance(element, PhysicallyMappedElement) and not (domain == "interior" and not take_closure):
return RestrictedPhysicallyMappedElement(element, re._indices)

@restrict.register(PhysicallyMappedElement)
def restrict_physically_mapped(element, domain, take_closure):
raise NotImplementedError("Can't restrict Physically Mapped things")
return FiatElement(re)


@restrict.register(finat.FlattenedDimensions)
Expand Down
6 changes: 1 addition & 5 deletions finat/ufl/restrictedelement.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from finat.ufl.finiteelementbase import FiniteElementBase
from finat.ufl.mixedelement import MixedElement, VectorElement, TensorElement
from ufl.sobolevspace import L2

valid_restriction_domains = ("interior", "facet", "ridge", "face", "edge", "vertex", "reduced")

Expand Down Expand Up @@ -61,10 +60,7 @@ def __repr__(self):
@property
def sobolev_space(self):
"""Doc."""
if self._restriction_domain == "interior":
return L2
else:
return self._element.sobolev_space
return self._element.sobolev_space

def is_cellwise_constant(self):
"""Return whether the basis functions of this element is spatially constant over each cell."""
Expand Down
22 changes: 21 additions & 1 deletion test/finat/test_restriction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,18 @@
from finat.point_set import PointSet
from finat.restricted import r_to_codim
from gem.interpreter import evaluate
from finat.physically_mapped import NeedsCoordinateMappingElement
from .conftest import MyMapping


def tabulate(element, ps):
tabulation, = element.basis_evaluation(0, ps).values()
coordinate_mapping = None
if isinstance(element, NeedsCoordinateMappingElement):
ref_el = element.cell
sd = ref_el.get_spatial_dimension()
phys_el = FIAT.reference_element.symmetric_simplex(sd)
coordinate_mapping = MyMapping(ref_el, phys_el)
tabulation, = element.basis_evaluation(0, ps, coordinate_mapping=coordinate_mapping).values()
result, = evaluate([tabulation])
# Singleton point
shape = (int(numpy.prod(element.index_shape)), ) + element.value_shape
Expand Down Expand Up @@ -142,3 +150,15 @@ def test_hdiv_restriction(hdiv_element, restriction, ps):

def test_hcurl_restriction(hcurl_element, restriction, ps):
run_restriction(hcurl_element, restriction, ps)


@pytest.fixture
def zany_element(cell):
if len(cell) == 1:
return finat.Walkington(cell[0])
else:
pytest.skip()


def test_zany_restriction(zany_element, restriction, ps):
run_restriction(zany_element, restriction, ps)