diff --git a/finat/physically_mapped.py b/finat/physically_mapped.py index 658455e9..df0d598a 100644 --- a/finat/physically_mapped.py +++ b/finat/physically_mapped.py @@ -17,15 +17,22 @@ def dual_transformation(self, Q, coordinate_mapping=None): class MappedTabulation(Mapping): """A lazy tabulation dict that applies the basis transformation only - on the requested derivatives.""" + on the requested derivatives. - def __init__(self, M, ref_tabulation): + :arg M: a gem.ListTensor with the basis transformation matrix. + :arg ref_tabulation: a dict of tabulations on the reference cell. + :kwarg indices: an optional list of restriction indices on the basis functions. + """ + def __init__(self, M, ref_tabulation, indices=None): self.M = M self.ref_tabulation = ref_tabulation + if indices is None: + indices = list(range(M.shape[0])) + self.indices = indices # we expect M to be sparse with O(1) nonzeros per row # for each row, get the column index of each nonzero entry csr = [[j for j in range(M.shape[1]) if not isinstance(M.array[i, j], gem.Zero)] - for i in range(M.shape[0])] + for i in indices] self.csr = csr self._tabulation_cache = {} @@ -35,7 +42,7 @@ def matvec(self, table): phi = [gem.Indexed(table, (j, *ii)) for j in range(self.M.shape[1])] # the sum approach is faster than calling numpy.dot or gem.IndexSum exprs = [gem.ComponentTensor(gem.Sum(*(self.M.array[i, j] * phi[j] for j in js)), ii) - for i, js in enumerate(self.csr)] + for i, js in zip(self.indices, self.csr)] result = gem.ListTensor(exprs) result, = gem.optimise.unroll_indexsum((result,), lambda index: True) @@ -64,6 +71,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) cite("Kirby2018zany") cite("Kirby2019zany") + self.restriction_indices = None @abstractmethod def basis_transformation(self, coordinate_mapping): @@ -75,7 +83,7 @@ def basis_transformation(self, coordinate_mapping): def map_tabulation(self, ref_tabulation, coordinate_mapping): assert coordinate_mapping is not None M = self.basis_transformation(coordinate_mapping) - return MappedTabulation(M, ref_tabulation) + return MappedTabulation(M, ref_tabulation, indices=self.restriction_indices) def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): result = super().basis_evaluation(order, ps, entity=entity) @@ -89,9 +97,14 @@ def dual_transformation(self, Q, coordinate_mapping=None): M = self.basis_transformation(coordinate_mapping) M = M.array - if M.shape[0] != M.shape[1]: - M = M[:, :self.space_dimension()] - M_dual = gem.ListTensor(inverse(M.T)) + if M.shape[1] > M.shape[0]: + M = M[:, :M.shape[0]] + + M_dual = inverse(M.T) + if self.restriction_indices is not None: + indices = self.restriction_indices + M_dual = M_dual[numpy.ix_(indices, indices)] + M_dual = gem.ListTensor(M_dual) key = None return MappedTabulation(M_dual, {key: Q})[key] diff --git a/finat/restricted.py b/finat/restricted.py index a85e0d39..b5144be9 100644 --- a/finat/restricted.py +++ b/finat/restricted.py @@ -13,6 +13,45 @@ null_element = object() +class RestrictedPhysicallyMappedElement(PhysicallyMappedElement, FiatElement): + """A restricted PhysicallyMappedElement. + + :arg element: the finat.FiatElement to be restricted. + :arg indices: the indices of the degrees of freedom to be kept. + """ + def __init__(self, element, indices): + super().__init__(element._element) + # First sanitise the restriction indices + # Some finat elements, e.g. Bell, are the restriction of a FIAT element. + # Therefore, we need to compose the restrictions. + edofs = element.entity_dofs() + free_indices = set(chain.from_iterable(edofs[d][e] for d in edofs for e in edofs[d])) + indices = [i for i in indices if i in free_indices] + self.restriction_indices = indices + + # Restrict the entity_dofs dict + rdofs = {d: {e: [indices.index(i) for i in edofs[d][e] if i in indices] + for e in edofs[d]} for d in edofs} + self.restriction_entity_dofs = rdofs + + # Grab the basis transformation matrix from the parent element + if isinstance(element, PhysicallyMappedElement): + self.full_basis_transformation = element.basis_transformation + else: + self.full_basis_transformation = None + + def basis_transformation(self, coordinate_mapping): + if self.full_basis_transformation is None: + raise NotImplementedError("basis_transformation not implemented.") + return self.full_basis_transformation(coordinate_mapping) + + def space_dimension(self): + return len(self.restriction_indices) + + def entity_dofs(self): + return self.restriction_entity_dofs + + @singledispatch def restrict(element, domain, take_closure): """Restrict an element to a given subentity. @@ -39,16 +78,15 @@ def restrict_fiat(element, domain, take_closure): return null_element if element.space_dimension() == re.space_dimension(): - # FIAT.RestrictedElement wipes out entity_permuations. + # FIAT.RestrictedElement wipes out entity_permutations. # In case the restriction is trivial we return the original element # to avoid reconstructing the space with an undesired permutation. return element - return FiatElement(re) + if isinstance(element, PhysicallyMappedElement) and not (domain == "interior" and not take_closure): + return RestrictedPhysicallyMappedElement(element, re._indices) -@restrict.register(PhysicallyMappedElement) -def restrict_physically_mapped(element, domain, take_closure): - raise NotImplementedError("Can't restrict Physically Mapped things") + return FiatElement(re) @restrict.register(finat.FlattenedDimensions) diff --git a/finat/ufl/restrictedelement.py b/finat/ufl/restrictedelement.py index 921acd02..225bd33c 100644 --- a/finat/ufl/restrictedelement.py +++ b/finat/ufl/restrictedelement.py @@ -13,7 +13,6 @@ from finat.ufl.finiteelementbase import FiniteElementBase from finat.ufl.mixedelement import MixedElement, VectorElement, TensorElement -from ufl.sobolevspace import L2 valid_restriction_domains = ("interior", "facet", "ridge", "face", "edge", "vertex", "reduced") @@ -61,10 +60,7 @@ def __repr__(self): @property def sobolev_space(self): """Doc.""" - if self._restriction_domain == "interior": - return L2 - else: - return self._element.sobolev_space + return self._element.sobolev_space def is_cellwise_constant(self): """Return whether the basis functions of this element is spatially constant over each cell.""" diff --git a/test/finat/test_restriction.py b/test/finat/test_restriction.py index 443292a0..8ba91d7a 100644 --- a/test/finat/test_restriction.py +++ b/test/finat/test_restriction.py @@ -5,10 +5,18 @@ from finat.point_set import PointSet from finat.restricted import r_to_codim from gem.interpreter import evaluate +from finat.physically_mapped import NeedsCoordinateMappingElement +from .conftest import MyMapping def tabulate(element, ps): - tabulation, = element.basis_evaluation(0, ps).values() + coordinate_mapping = None + if isinstance(element, NeedsCoordinateMappingElement): + ref_el = element.cell + sd = ref_el.get_spatial_dimension() + phys_el = FIAT.reference_element.symmetric_simplex(sd) + coordinate_mapping = MyMapping(ref_el, phys_el) + tabulation, = element.basis_evaluation(0, ps, coordinate_mapping=coordinate_mapping).values() result, = evaluate([tabulation]) # Singleton point shape = (int(numpy.prod(element.index_shape)), ) + element.value_shape @@ -142,3 +150,15 @@ def test_hdiv_restriction(hdiv_element, restriction, ps): def test_hcurl_restriction(hcurl_element, restriction, ps): run_restriction(hcurl_element, restriction, ps) + + +@pytest.fixture +def zany_element(cell): + if len(cell) == 1: + return finat.Walkington(cell[0]) + else: + pytest.skip() + + +def test_zany_restriction(zany_element, restriction, ps): + run_restriction(zany_element, restriction, ps)