Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
7f2dd72
Boundary Quadrature element
pbrubeck Jan 9, 2025
3f871a6
Introduce MappedPointSet
pbrubeck Jan 10, 2025
9b26029
GEM: Syntax sugar
pbrubeck Jan 11, 2025
0a310e4
Merge branch 'master' into pbrubeck/finat-boundary-quadrature
pbrubeck Jan 16, 2025
24c6952
GEM: Simplify Indexed tensors
pbrubeck Jan 21, 2025
e8ff24b
Optimize zany-mapping matvec
pbrubeck Jan 21, 2025
783e0c4
WIP
pbrubeck Jan 21, 2025
27cf83d
Merge branch 'master' into pbrubeck/finat-boundary-quadrature
pbrubeck Jan 22, 2025
fdff8da
Fix for Delta(int, Index)
pbrubeck Jan 22, 2025
e77ed6f
Fix for Delta(int, Index)
pbrubeck Jan 22, 2025
07a685c
finat: Untangle value shape
pbrubeck Jan 26, 2025
eb35c89
finat: lazy basis transformation
pbrubeck Jan 26, 2025
6553375
Merge branch 'master' into pbrubeck/simplify-indexed
pbrubeck Jan 26, 2025
2acd16d
coffee: modernize
pbrubeck Jan 26, 2025
2cfa208
Merge branch 'pbrubeck/simplify-indexed' of github.com:firedrakeproje…
pbrubeck Jan 26, 2025
b6ec065
finat: untangle basis_transformation
pbrubeck Jan 27, 2025
6cb41db
Merge branch 'master' into pbrubeck/finat-boundary-quadrature
pbrubeck Feb 13, 2025
223669a
flake
pbrubeck Feb 13, 2025
f1d0422
Merge simplify_indexed
pbrubeck Mar 17, 2025
93dcb89
Merge branch 'master' into pbrubeck/simplify-indexed
pbrubeck Mar 17, 2025
3ee52fa
Merge branch 'pbrubeck/simplify-indexed' into pbrubeck/finat-boundary…
pbrubeck Mar 17, 2025
329dd46
FiniteElementBase: implement is_dg
pbrubeck Mar 17, 2025
83a8d7e
restore polynomial_set.py
pbrubeck Mar 17, 2025
2eadfb7
Revert "Merge simplify_indexed"
pbrubeck Mar 18, 2025
0ae1e19
fix Delta(int, Index)
pbrubeck Mar 18, 2025
b9af43e
Docs
pbrubeck Mar 24, 2025
b2e8563
Docs
pbrubeck Mar 26, 2025
b50031e
lint
pbrubeck Mar 27, 2025
be58cbf
lint
pbrubeck Mar 27, 2025
c52b024
lint
pbrubeck Mar 27, 2025
45a9f15
Apply suggestions from code review
pbrubeck Mar 27, 2025
f187e6a
Update finat/tensor_product.py
pbrubeck Mar 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions finat/element_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,14 @@ def convert(element, **kwargs):
@convert.register(finat.ufl.FiniteElement)
def convert_finiteelement(element, **kwargs):
cell = as_fiat_cell(element.cell)
if element.family() == "Quadrature":
if element.family() in {"Quadrature", "Boundary Quadrature"}:
degree = element.degree()
scheme = element.quadrature_scheme()
scheme = element.quadrature_scheme() or "default"
if degree is None or scheme is None:
raise ValueError("Quadrature scheme and degree must be specified!")

return finat.make_quadrature_element(cell, degree, scheme), set()
codim = 1 if element.family() == "Boundary Quadrature" else 0
return finat.make_quadrature_element(cell, degree, scheme, codim), set()
lmbda = supported_elements[element.family()]
if element.family() == "Real" and element.cell.cellname() in {"quadrilateral", "hexahedron"}:
lmbda = None
Expand Down
5 changes: 2 additions & 3 deletions finat/fiat_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,8 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
continue

derivative = sum(alpha)
table_roll = fiat_table.reshape(
space_dimension, value_size, len(ps.points)
).transpose(1, 2, 0)
shp = (space_dimension, value_size, *ps.points.shape[:-1])
table_roll = np.moveaxis(fiat_table.reshape(shp), 0, -1)

exprs = []
for table in table_roll:
Expand Down
3 changes: 3 additions & 0 deletions finat/finiteelementbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def entity_closure_dofs(self):
element.'''
return self._entity_closure_dofs

def is_dg(self):
return self.entity_dofs() == self.entity_closure_dofs()

@cached_property
def _entity_support_dofs(self):
esd = {}
Expand Down
63 changes: 58 additions & 5 deletions finat/point_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ def points(self):
@property
def dimension(self):
"""Point dimension."""
_, dim = self.points.shape
return dim
return self.points.shape[-1]

@property
@abc.abstractmethod
Expand Down Expand Up @@ -129,8 +128,7 @@ def points(self):

@cached_property
def indices(self):
N, _ = self._points_expr.shape
return (gem.Index(extent=N),)
return tuple(gem.Index(extent=N) for N in self._points_expr.shape[:-1])

@cached_property
def expression(self):
Expand Down Expand Up @@ -159,7 +157,7 @@ def points(self):

@cached_property
def indices(self):
return (gem.Index(extent=len(self.points)),)
return tuple(gem.Index(extent=N) for N in self.points.shape[:-1])

@cached_property
def expression(self):
Expand Down Expand Up @@ -224,3 +222,58 @@ def almost_equal(self, other, tolerance=1e-12):
len(self.factors) == len(other.factors) and \
all(s.almost_equal(o, tolerance=tolerance)
for s, o in zip(self.factors, other.factors))


class FacetPointSet(AbstractPointSet):
"""A point set on facets.

A FacetPointSet is constructed by mapping a lower-dimensional PointSet
onto all facets of a higher-dimensional cell.

Parameters
----------
cell : FIAT.Cell
The cell.
ps : PointSet
A lower-dimensional point set.

"""
def __init__(self, cell, ps):
self.cell = cell
self.ps = ps

def __repr__(self):
return f"{type(self).__name__}({self.ps!r})"

@cached_property
def entities(self):
"""The list of all cell entites matching the reference point set dimension."""
to_int = lambda x: sum(x) if isinstance(x, tuple) else x
top = self.cell.topology
return [(dim, entity)
for dim in sorted(top)
for entity in sorted(top[dim])
if to_int(dim) == self.ps.dimension]

@cached_property
def points(self):
"""The array with the reference points mapped onto each facet."""
ref_pts = self.ps.points
pts = [self.cell.get_entity_transform(dim, entity)(ref_pts)
for dim, entity in self.entities]
return numpy.concatenate(pts)

@cached_property
def indices(self):
"""An Index tuple of the facet index and the reference point indices."""
return (gem.Index(extent=len(self.entities)), *self.ps.indices)

@cached_property
def expression(self):
raise NotImplementedError("Symbolic point expression is not yet implemented for FacetPointSet.")

def almost_equal(self, other, tolerance=1e-12):
"""Approximate numerical equality of point sets"""
return type(self) is type(other) and \
self.cell == other.cell and \
self.ps.almost_equal(other.ps, tolerance=tolerance)
4 changes: 2 additions & 2 deletions finat/quadrature.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import hashlib
from abc import ABCMeta, abstractmethod
from functools import cached_property, reduce
from functools import cached_property

import gem
import numpy
Expand Down Expand Up @@ -163,4 +163,4 @@ def point_set(self):

@cached_property
def weight_expression(self):
return reduce(gem.Product, (q.weight_expression for q in self.factors))
return gem.Product(*(q.weight_expression for q in self.factors))
66 changes: 45 additions & 21 deletions finat/quadrature_element.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from finat.point_set import UnknownPointSet
from functools import reduce
from finat.point_set import UnknownPointSet, FacetPointSet

import numpy

Expand All @@ -13,7 +12,7 @@
from finat.quadrature import make_quadrature, AbstractQuadratureRule


def make_quadrature_element(fiat_ref_cell, degree, scheme="default"):
def make_quadrature_element(fiat_ref_cell, degree, scheme="default", codim=0):
"""Construct a :class:`QuadratureElement` from a given a reference
element, degree and scheme.

Expand All @@ -23,9 +22,16 @@ def make_quadrature_element(fiat_ref_cell, degree, scheme="default"):
integrate exactly.
:param scheme: The quadrature scheme to use - e.g. "default",
"canonical" or "KMV".
:param codim: The codimension of the quadrature scheme.
:returns: The appropriate :class:`QuadratureElement`
"""
rule = make_quadrature(fiat_ref_cell, degree, scheme=scheme)
if codim > 0:
sd = fiat_ref_cell.get_spatial_dimension()
rule_ref_cell = fiat_ref_cell.construct_subelement(sd - codim)
else:
rule_ref_cell = fiat_ref_cell

rule = make_quadrature(rule_ref_cell, degree, scheme=scheme)
return QuadratureElement(fiat_ref_cell, rule)


Expand All @@ -42,8 +48,6 @@ def __init__(self, fiat_ref_cell, rule):
self.cell = fiat_ref_cell
if not isinstance(rule, AbstractQuadratureRule):
raise TypeError("rule is not an AbstractQuadratureRule")
if fiat_ref_cell.get_spatial_dimension() != rule.point_set.dimension:
raise ValueError("Cell dimension does not match rule's point set dimension")
self._rule = rule

@cached_property
Expand All @@ -64,10 +68,18 @@ def formdegree(self):

@cached_property
def _entity_dofs(self):
# Inspired by ffc/quadratureelement.py
top = self.cell.get_topology()
entity_dofs = {dim: {entity: [] for entity in entities}
for dim, entities in self.cell.get_topology().items()}
entity_dofs[self.cell.get_dimension()] = {0: list(range(self.space_dimension()))}
for dim, entities in top.items()}
ps = self._rule.point_set
num_pts = len(ps.points)
to_int = lambda x: sum(x) if isinstance(x, tuple) else x
cur = 0
for dim in sorted(top):
if to_int(dim) == ps.dimension:
for entity in sorted(top[dim]):
entity_dofs[dim][entity].extend(range(cur, cur + num_pts))
cur += num_pts
return entity_dofs

def entity_dofs(self):
Expand All @@ -76,9 +88,15 @@ def entity_dofs(self):
def space_dimension(self):
return numpy.prod(self.index_shape, dtype=int)

@cached_property
def _point_set(self):
ps = self._rule.point_set
sd = self.cell.get_spatial_dimension()
return ps if ps.dimension == sd else FacetPointSet(self.cell, ps)

@property
def index_shape(self):
ps = self._rule.point_set
ps = self._point_set
return tuple(index.extent for index in ps.indices)

@property
Expand All @@ -87,7 +105,7 @@ def value_shape(self):

@cached_property
def fiat_equivalent(self):
ps = self._rule.point_set
ps = self._point_set
if isinstance(ps, UnknownPointSet):
raise ValueError("A quadrature element with rule with runtime points has no fiat equivalent!")
weights = getattr(self._rule, 'weights', None)
Expand All @@ -107,8 +125,13 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
:param ps: the point set object.
:param entity: the cell entity on which to tabulate.
'''
if entity is not None and entity != (self.cell.get_dimension(), 0):
raise ValueError('QuadratureElement does not "tabulate" on subentities.')
rule_dim = self._rule.point_set.dimension
if entity is None:
entity = (rule_dim, 0)
entity_dim, entity_id = entity
if entity_dim != rule_dim:
raise ValueError(f"Cannot tabulate QuadratureElement of dimension {rule_dim}"
f" on subentities of dimension {entity_dim}.")

if order:
raise ValueError("Derivatives are not defined on a QuadratureElement.")
Expand All @@ -117,24 +140,25 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
raise ValueError("Mismatch of quadrature points!")

# Return an outer product of identity matrices
multiindex = self.get_indices()
product = reduce(gem.Product, [gem.Delta(q, r)
for q, r in zip(ps.indices, multiindex)])
basis_indices = self.get_indices()
point_indices = ps.indices
if len(basis_indices) > len(point_indices):
point_indices = (entity_id, *point_indices)
delta = gem.Delta(point_indices, basis_indices)

dim = self.cell.get_spatial_dimension()
return {(0,) * dim: gem.ComponentTensor(product, multiindex)}
sd = self.cell.get_spatial_dimension()
return {(0,) * sd: gem.ComponentTensor(delta, basis_indices)}

def point_evaluation(self, order, refcoords, entity=None):
raise NotImplementedError("QuadratureElement cannot do point evaluation!")

@property
def dual_basis(self):
ps = self._rule.point_set
ps = self._point_set
multiindex = self.get_indices()
# Evaluation matrix is just an outer product of identity
# matrices, evaluation points are just the quadrature points.
Q = reduce(gem.Product, (gem.Delta(q, r)
for q, r in zip(ps.indices, multiindex)))
Q = gem.Delta(ps.indices, multiindex)
Q = gem.ComponentTensor(Q, multiindex)
return Q, ps

Expand Down
4 changes: 2 additions & 2 deletions finat/sympy2gem.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ def sympy2gem_expr(node, self):
@sympy2gem.register(sympy.Add)
@sympy2gem.register(symengine.Add)
def sympy2gem_add(node, self):
return reduce(gem.Sum, map(self, node.args))
return gem.Sum(*map(self, node.args))


@sympy2gem.register(sympy.Mul)
@sympy2gem.register(symengine.Mul)
def sympy2gem_mul(node, self):
return reduce(gem.Product, map(self, node.args))
return gem.Product(*map(self, node.args))


@sympy2gem.register(sympy.Pow)
Expand Down
38 changes: 16 additions & 22 deletions finat/tensor_product.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from functools import reduce
from itertools import chain, product
from operator import methodcaller

Expand Down Expand Up @@ -32,11 +31,11 @@ def __init__(self, factors):

@cached_property
def cell(self):
return TensorProductCell(*[fe.cell for fe in self.factors])
return TensorProductCell(*(fe.cell for fe in self.factors))

@cached_property
def complex(self):
return TensorProductCell(*[fe.complex for fe in self.factors])
return TensorProductCell(*(fe.complex for fe in self.factors))

@property
def degree(self):
Expand Down Expand Up @@ -69,7 +68,7 @@ def space_dimension(self):

@property
def index_shape(self):
return tuple(chain(*[fe.index_shape for fe in self.factors]))
return tuple(chain.from_iterable(fe.index_shape for fe in self.factors))

@property
def value_shape(self):
Expand Down Expand Up @@ -117,24 +116,20 @@ def _merge_evaluations(self, factor_results):
# multiindex describing the value shape of the subelement.
zetas = [fe.get_value_indices() for fe in self.factors]

multiindex = tuple(chain(*alphas, *zetas))
result = {}
for derivative in range(order + 1):
for Delta in mis(dimension, derivative):
# Split the multiindex for the subelements
deltas = [Delta[s] for s in dim_slices]
# GEM scalars (can have free indices) for collecting
# the contributions from the subelements.
scalars = []
for fr, delta, alpha, zeta in zip(factor_results, deltas, alphas, zetas):
# Turn basis shape to free indices, select the
# right derivative entry, and collect the result.
scalars.append(gem.Indexed(fr[delta], alpha + zeta))
# Multiply the values from the subelements and wrap up
# non-point indices into shape.
result[Delta] = gem.ComponentTensor(
reduce(gem.Product, scalars),
tuple(chain(*(alphas + zetas)))
)
# Multiply the values from the subelements
# Turn basis shape to free indices, select the
# right derivative entry.
scalar = gem.Product(*(gem.Indexed(fr[delta], alpha + zeta)
for fr, delta, alpha, zeta
in zip(factor_results, deltas, alphas, zetas)))
# Wrap up non-point indices into shape.
result[Delta] = gem.ComponentTensor(scalar, multiindex)
return result

def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
Expand Down Expand Up @@ -179,12 +174,11 @@ def dual_basis(self):
# Naming as _merge_evaluations above
alphas = [factor.get_indices() for factor in self.factors]
zetas = [factor.get_value_indices() for factor in self.factors]
# Index the factors by so that we can reshape into index-shape
# followed by value-shape
qis = [q[alpha + zeta] for q, alpha, zeta in zip(qs, alphas, zetas)]
Q = gem.ComponentTensor(
reduce(gem.Product, qis),
tuple(chain(*(alphas + zetas)))
# Index the factors by basis function and component so that we can reshape
# into index-shape followed by value-shape
gem.Product(*(q[alpha + zeta] for q, alpha, zeta in zip(qs, alphas, zetas))),
tuple(chain(*alphas, *zetas))
)
return Q, ps

Expand Down
Loading