Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
57e610c
GEM tabulations
pbrubeck Aug 14, 2025
e078883
Remove sympy from point_evaluation
pbrubeck Aug 20, 2025
2f7fb1c
Reuse code for tabulation at known and unknown points
pbrubeck Aug 20, 2025
2c54e72
GEM barycentric interpolation
pbrubeck Aug 21, 2025
26113d8
Fix cellwise constant case
pbrubeck Aug 21, 2025
7243809
add tests
pbrubeck Aug 21, 2025
fdefbbd
Fix transpose
pbrubeck Aug 21, 2025
9770f83
Merge branch 'main' into pbrubeck/gem-tabulation
pbrubeck Aug 21, 2025
6868df5
Fixup
pbrubeck Aug 22, 2025
4d528b0
Node.__neg__
pbrubeck Aug 26, 2025
572af82
Do not use barycentric interpolation at unknown points
pbrubeck Aug 26, 2025
d82ba64
comments
pbrubeck Aug 27, 2025
d743cb8
Grab Variables from ps.expression
pbrubeck Aug 27, 2025
accabdb
Fix up
pbrubeck Aug 27, 2025
3225043
Evaluate FlexiblyIndexed
pbrubeck Aug 27, 2025
8ab7002
Merge branch 'main' into pbrubeck/gem-tabulation
pbrubeck Aug 28, 2025
4d00815
replace indices symbolically
pbrubeck Aug 28, 2025
6cfe146
docstring
pbrubeck Aug 28, 2025
259787d
Restore gem/interpreter.py
pbrubeck Aug 28, 2025
1044442
Implement Node.__pow__
pbrubeck Aug 28, 2025
5b91496
style
pbrubeck Aug 28, 2025
394e8e0
Fix cellwise constant for non-simplices
pbrubeck Aug 28, 2025
fd0ae42
Implement Literal.__bool__
pbrubeck Aug 29, 2025
08d410f
Merge branch 'main' into pbrubeck/gem-tabulation
pbrubeck Oct 15, 2025
36575b5
Merge branch 'main' into pbrubeck/gem-tabulation
pbrubeck Oct 16, 2025
e7ccbb0
Update FIAT/barycentric_interpolation.py
pbrubeck Oct 18, 2025
b290e56
Update FIAT/expansions.py
pbrubeck Oct 18, 2025
8868ea1
import partial
pbrubeck Oct 18, 2025
a76df13
Update gem/gem.py
pbrubeck Oct 18, 2025
2b5eb2a
simplify point_evaluation
pbrubeck Oct 23, 2025
a935cf2
rounding
pbrubeck Oct 23, 2025
26d1308
Merge branch 'main' into pbrubeck/gem-tabulation
pbrubeck Oct 23, 2025
72eb7a1
Merge branch 'main' into pbrubeck/gem-tabulation
pbrubeck Dec 15, 2025
48f1ef8
Merge branch 'main' into pbrubeck/gem-tabulation
pbrubeck Jan 30, 2026
e849774
Merge branch 'main' into pbrubeck/gem-tabulation
pbrubeck Jan 30, 2026
b7af1e1
Enforce constant DG0 in gem tabulation
pbrubeck Jan 30, 2026
fdc58d0
dtype
pbrubeck Jan 30, 2026
b500793
dtype
pbrubeck Jan 30, 2026
1d8b2d1
dtype
pbrubeck Jan 30, 2026
e64f161
dtype
pbrubeck Jan 31, 2026
c61b8e7
merge conflict
pbrubeck Feb 4, 2026
a78b8a4
Merge branch 'main' into pbrubeck/gem-tabulation
pbrubeck Feb 6, 2026
2d3277f
Merge branch 'main' into pbrubeck/gem-tabulation
pbrubeck Feb 6, 2026
21656e0
Merge branch 'main' into pbrubeck/gem-tabulation
pbrubeck Feb 11, 2026
d3bdf9d
Drop unrelated changes
pbrubeck Feb 11, 2026
166ad2c
restore barycentric interpolation
pbrubeck Feb 11, 2026
4293cbe
restore expansions
pbrubeck Feb 12, 2026
37896ac
bump tolerance
pbrubeck Feb 14, 2026
2e5737a
Merge branch 'pbrubeck/gem-tabulation' of github.com:firedrakeproject…
pbrubeck Feb 17, 2026
ac70b8f
Merge branch 'main' into pbrubeck/gem-tabulation
pbrubeck Feb 17, 2026
537f1fa
Merge branch 'main' into pbrubeck/gem-tabulation
pbrubeck Feb 18, 2026
5c231cc
Merge branch 'main' into pbrubeck/gem-tabulation
pbrubeck Feb 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions FIAT/barycentric_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,24 @@ def barycentric_interpolation(nodes, wts, dmat, pts, order=0):
https://doi.org/10.1137/S0036144502417715 Eq. (4.2) & (9.4)
"""
if pts.dtype == object:
from sympy import simplify
sp_simplify = numpy.vectorize(simplify)
# Do not use barycentric interpolation at unknown points
phi = numpy.add.outer(-nodes, pts.flatten())
phis = [wi * numpy.prod(phi[:i], axis=0) * numpy.prod(phi[i+1:], axis=0) for i, wi in enumerate(wts)]
phi = numpy.asarray(phis)
else:
sp_simplify = lambda x: x
phi = numpy.add.outer(-nodes, pts.flatten())
with numpy.errstate(divide='ignore', invalid='ignore'):
numpy.reciprocal(phi, out=phi)
numpy.multiply(phi, wts[:, None], out=phi)
numpy.multiply(1.0 / numpy.sum(phi, axis=0), phi, out=phi)
phi[phi != phi] = 1.0
phi = phi.reshape(-1, *pts.shape[:-1])
# Use the second barycentric interpolation formula
phi = numpy.add.outer(-nodes, pts.flatten())
with numpy.errstate(divide='ignore', invalid='ignore'):
numpy.reciprocal(phi, out=phi)
numpy.multiply(phi, wts[:, None], out=phi)
numpy.multiply(1.0 / numpy.sum(phi, axis=0), phi, out=phi)
# Replace nan with one
phi[phi != phi] = 1.0

phi = sp_simplify(phi)
phi = phi.reshape(-1, *pts.shape[:-1])
results = {(0,): phi}
for r in range(1, order+1):
phi = sp_simplify(numpy.dot(dmat, phi))
phi = numpy.dot(dmat, phi)
results[(r,)] = phi
return results

Expand Down
24 changes: 14 additions & 10 deletions FIAT/expansions.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,19 +86,19 @@ def dubiner_recurrence(dim, n, order, ref_pts, Jinv, scale, variant=None):
scale = -scale

num_members = math.comb(n + dim, dim)
results = tuple([None] * num_members for i in range(order+1))
phi, dphi, ddphi = results + (None,) * (2-order)

outer = lambda x, y: x[:, None, ...] * y[None, ...]
sym_outer = lambda x, y: outer(x, y) + outer(y, x)

pad_dim = dim + 2
dX = pad_jacobian(Jinv, pad_dim)
phi[0] = sum((ref_pts[i] - ref_pts[i] for i in range(dim)), scale)
if dphi is not None:
dphi[0] = (phi[0] - phi[0]) * dX[0]
if ddphi is not None:
ddphi[0] = outer(dphi[0], dX[0])

phi0 = numpy.array([sum((ref_pts[i] - ref_pts[i] for i in range(dim)), 0.0)])
results = [numpy.zeros((num_members,) + (dim,)*k + phi0.shape[1:], dtype=phi0.dtype)
for k in range(order+1)]

phi, dphi, ddphi = results + [None] * (2-order)
phi[0] += scale
if dim == 0 or n == 0:
return results
if dim > 3 or dim < 0:
Expand Down Expand Up @@ -183,7 +183,7 @@ def C0_basis(dim, n, tabulations):
# Recover facet bubbles
for phi in tabulations:
icur = 0
phi[icur] *= -1
phi[icur] *= -1.0
for inext in range(1, dim+1):
phi[icur] -= phi[inext]
if dim == 2:
Expand Down Expand Up @@ -723,11 +723,15 @@ def compute_partition_of_unity(ref_el, pt, unique=True, tol=1E-12):
:kwarg tol: the absolute tolerance.
:returns: a list of (weighted) characteristic functions for each subcell.
"""
from sympy import Piecewise
import gem
sd = ref_el.get_spatial_dimension()
top = ref_el.get_topology()
# assert singleton point
pt = pt.reshape((sd,))
if isinstance(pt[0], gem.Node):
import gem as backend
else:
import sympy as backend

# The distance to the nearest cell is equal to the distance to the parent cell
best = ref_el.get_parent().distance_to_point_l1(pt, rescale=True)
Expand All @@ -739,7 +743,7 @@ def compute_partition_of_unity(ref_el, pt, unique=True, tol=1E-12):
for cell in sorted(top[sd]):
# Bin points based on l1 distance
pt_near_cell = ref_el.distance_to_point_l1(pt, entity=(sd, cell), rescale=True) < tol
masks.append(Piecewise(*otherwise, (1.0, pt_near_cell), (0.0, True)))
masks.append(backend.Piecewise(*otherwise, (1.0, pt_near_cell), (0.0, True)))
if unique:
otherwise.append((0.0, pt_near_cell))
# If the point is on a facet, divide the characteristic function by the facet multiplicity
Expand Down
6 changes: 3 additions & 3 deletions finat/argyris.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def _edge_transform(V, vorder, eorder, fiat_cell, coordinate_mapping, avg=False)
V[s, v1id] = P1 * Bnt
V[s, v0id] = P0 * Bnt
if k > 0:
V[s, s + eorder] = -1 * Bnt
V[s, s + eorder] = -Bnt


class Argyris(PhysicallyMappedElement, ScalarFiatElement):
Expand Down Expand Up @@ -167,7 +167,7 @@ def basis_transformation(self, coordinate_mapping):

# vertex points
V[s, v1id] = 15/8 * Bnt
V[s, v0id] = -1 * V[s, v1id]
V[s, v0id] = -V[s, v1id]

# vertex derivatives
for i in range(sd):
Expand All @@ -178,7 +178,7 @@ def basis_transformation(self, coordinate_mapping):
tau = [Jt[0]*Jt[0], 2*Jt[0]*Jt[1], Jt[1]*Jt[1]]
for i in range(len(tau)):
V[s, v1id+3+i] = 1/32 * Bnt * tau[i]
V[s, v0id+3+i] = -1 * V[s, v1id+3+i]
V[s, v0id+3+i] = -V[s, v1id+3+i]

# Patch up conditioning
h = coordinate_mapping.cell_size()
Expand Down
4 changes: 2 additions & 2 deletions finat/bell.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def basis_transformation(self, coordinate_mapping):

# vertex points
V[s, v1id] = 1/21 * Bnt
V[s, v0id] = -1 * V[s, v1id]
V[s, v0id] = -V[s, v1id]

# vertex derivatives
for i in range(sd):
Expand All @@ -55,7 +55,7 @@ def basis_transformation(self, coordinate_mapping):
tau = [Jt[0]*Jt[0], 2*Jt[0]*Jt[1], Jt[1]*Jt[1]]
for i in range(len(tau)):
V[s, v1id+3+i] = 1/252 * Bnt * tau[i]
V[s, v0id+3+i] = -1 * V[s, v1id+3+i]
V[s, v0id+3+i] = -V[s, v1id+3+i]

# Patch up conditioning
h = coordinate_mapping.cell_size()
Expand Down
145 changes: 56 additions & 89 deletions finat/fiat_elements.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import FIAT
import gem
import numpy as np
import sympy as sp
from gem.utils import cached_property

from finat.finiteelementbase import FiniteElementBase
from finat.point_set import PointSet
from finat.sympy2gem import sympy2gem
from finat.point_set import PointSet, PointSingleton


class FiatElement(FiniteElementBase):
Expand Down Expand Up @@ -67,62 +65,61 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
:param ps: the point set.
:param entity: the cell entity on which to tabulate.
'''
space_dimension = self._element.space_dimension()
value_size = np.prod(self._element.value_shape(), dtype=int)
fiat_result = self._element.tabulate(order, ps.points, entity)
result = {}
fiat_element = self._element
fiat_result = fiat_element.tabulate(order, ps.points, entity)
# In almost all cases, we have
# self.space_dimension() == self._element.space_dimension()
# But for Bell, FIAT reports 21 basis functions,
# but FInAT only 18 (because there are actually 18
# basis functions, and the additional 3 are for
# dealing with transformations between physical
# and reference space).
index_shape = (self._element.space_dimension(),)
value_shape = self.value_shape
space_dimension = fiat_element.space_dimension()
if self.space_dimension() == space_dimension:
beta = self.get_indices()
index_shape = tuple(index.extent for index in beta)
else:
index_shape = (space_dimension,)
beta = tuple(gem.Index(extent=i) for i in index_shape)
assert len(beta) == len(self.get_indices())

zeta = self.get_value_indices()
basis_indices = beta + zeta

result = {}
for alpha, fiat_table in fiat_result.items():
if isinstance(fiat_table, Exception):
result[alpha] = gem.Failure(self.index_shape + self.value_shape, fiat_table)
result[alpha] = gem.Failure(index_shape + value_shape, fiat_table)
continue

point_indices = ()
replace_indices = ()
derivative = sum(alpha)
shp = (space_dimension, value_size, *ps.points.shape[:-1])
table_roll = np.moveaxis(fiat_table.reshape(shp), 0, -1)

exprs = []
for table in table_roll:
if derivative == self.degree and not self.complex.is_macrocell():
# Make sure numerics satisfies theory
exprs.append(gem.Literal(table[0]))
elif derivative > self.degree:
# Make sure numerics satisfies theory
assert np.allclose(table, 0.0)
exprs.append(gem.Literal(np.zeros(self.index_shape)))
if derivative == self.degree and self.complex.is_simplex():
# Ensure a cellwise constant tabulation
if fiat_table.dtype == object:
replace_indices = tuple((i, 0) for i in ps.expression.free_indices)
else:
point_indices = ps.indices
point_shape = tuple(index.extent for index in point_indices)

exprs.append(gem.partial_indexed(
gem.Literal(table.reshape(point_shape + index_shape)),
point_indices
))
if self.value_shape:
# As above, this extent may be different from that
# advertised by the finat element.
beta = tuple(gem.Index(extent=i) for i in index_shape)
assert len(beta) == len(self.get_indices())

zeta = self.get_value_indices()
result[alpha] = gem.ComponentTensor(
gem.Indexed(
gem.ListTensor(np.array(
[gem.Indexed(expr, beta) for expr in exprs]
).reshape(self.value_shape)),
zeta),
beta + zeta
)
fiat_table = fiat_table.reshape(*index_shape, *value_shape, -1)
assert np.allclose(fiat_table, fiat_table[..., 0, None])
fiat_table = fiat_table[..., 0]
elif derivative > self.degree:
# Ensure a zero tabulation
if fiat_table.dtype != object:
assert np.allclose(fiat_table, 0.0)
fiat_table = np.zeros(index_shape + value_shape)
else:
expr, = exprs
result[alpha] = expr
point_indices = ps.indices

point_shape = tuple(i.extent for i in point_indices)
fiat_table = fiat_table.reshape(index_shape + value_shape + point_shape)
gem_table = gem.as_gem(fiat_table)
expr = gem.Indexed(gem_table, basis_indices + point_indices)
expr = gem.ComponentTensor(expr, basis_indices)
if replace_indices:
expr, = gem.optimise.remove_componenttensors((expr,), subst=replace_indices)
result[alpha] = expr
return result

def point_evaluation(self, order, refcoords, entity=None, coordinate_mapping=None):
Expand All @@ -147,7 +144,20 @@ def point_evaluation(self, order, refcoords, entity=None, coordinate_mapping=Non
esd = self.cell.construct_subelement(entity_dim).get_spatial_dimension()
assert isinstance(refcoords, gem.Node) and refcoords.shape == (esd,)

return point_evaluation(self._element, order, refcoords, (entity_dim, entity_i))
# Coordinates on the reference entity (GEM)
Xi = tuple(gem.Indexed(refcoords, i) for i in np.ndindex(refcoords.shape))
ps = PointSingleton(Xi)
result = self.basis_evaluation(order, ps, entity=entity,
coordinate_mapping=coordinate_mapping)

# Apply symbolic simplification
vals = result.values()
vals = map(gem.optimise.ffc_rounding, vals, [1E-13]*len(vals))
vals = gem.optimise.constant_fold_zero(vals)
vals = map(gem.optimise.aggressive_unroll, vals)
vals = gem.optimise.remove_componenttensors(vals)
result = dict(zip(result.keys(), vals))
return result

@cached_property
def _dual_basis(self):
Expand Down Expand Up @@ -260,49 +270,6 @@ def mapping(self):
return result


def point_evaluation(fiat_element, order, refcoords, entity):
# Coordinates on the reference entity (SymPy)
esd, = refcoords.shape
Xi = sp.symbols('X Y Z')[:esd]

space_dimension = fiat_element.space_dimension()
value_size = np.prod(fiat_element.value_shape(), dtype=int)
fiat_result = fiat_element.tabulate(order, [Xi], entity)
result = {}
for alpha, fiat_table in fiat_result.items():
if isinstance(fiat_table, Exception):
result[alpha] = gem.Failure((space_dimension,) + fiat_element.value_shape(), fiat_table)
continue

# Convert SymPy expression to GEM
mapper = gem.node.Memoizer(sympy2gem)
mapper.bindings = {s: gem.Indexed(refcoords, (i,))
for i, s in enumerate(Xi)}
gem_table = np.vectorize(mapper)(fiat_table)

table_roll = gem_table.reshape(space_dimension, value_size).transpose()

exprs = []
for table in table_roll:
exprs.append(gem.ListTensor(table.reshape(space_dimension)))
if fiat_element.value_shape():
beta = (gem.Index(extent=space_dimension),)
zeta = tuple(gem.Index(extent=d)
for d in fiat_element.value_shape())
result[alpha] = gem.ComponentTensor(
gem.Indexed(
gem.ListTensor(np.array(
[gem.Indexed(expr, beta) for expr in exprs]
).reshape(fiat_element.value_shape())),
zeta),
beta + zeta
)
else:
expr, = exprs
result[alpha] = expr
return result


class Regge(FiatElement): # symmetric matrix valued
def __init__(self, cell, degree, **kwargs):
super().__init__(FIAT.Regge(cell, degree, **kwargs))
Expand Down
2 changes: 1 addition & 1 deletion finat/hct.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def basis_transformation(self, coordinate_mapping):

# vertex points
V[s, v0id] = 1/5 * Bnt
V[s, v1id] = -1 * V[s, v0id]
V[s, v1id] = -V[s, v0id]

# vertex derivatives
for i in range(sd):
Expand Down
4 changes: 0 additions & 4 deletions finat/physically_mapped.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,6 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
result = super().basis_evaluation(order, ps, entity=entity)
return self.map_tabulation(result, coordinate_mapping)

def point_evaluation(self, order, refcoords, entity=None, coordinate_mapping=None):
result = super().point_evaluation(order, refcoords, entity=entity)
return self.map_tabulation(result, coordinate_mapping)

def dual_transformation(self, Q, coordinate_mapping=None):
M = self.basis_transformation(coordinate_mapping)

Expand Down
2 changes: 1 addition & 1 deletion finat/point_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def points(self):

@cached_property
def expression(self):
return gem.Literal(self.point)
return gem.as_gem(self.point)


class UnknownPointsArray():
Expand Down
14 changes: 1 addition & 13 deletions finat/sympy2gem.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from functools import singledispatch, reduce

import numpy
import sympy
try:
import symengine
Expand Down Expand Up @@ -130,18 +129,7 @@ def sympy2gem_le(node, self):
@sympy2gem.register(sympy.Piecewise)
@sympy2gem.register(symengine.Piecewise)
def sympy2gem_conditional(node, self):
expr = None
pieces = []
for v, c in node.args:
if isinstance(c, (bool, numpy.bool, sympy.logic.boolalg.BooleanTrue)) and c:
expr = self(v)
break
pieces.append((v, c))
if expr is None:
expr = gem.Literal(float("nan"))
for v, c in reversed(pieces):
expr = gem.Conditional(self(c), self(v), expr)
return expr
return gem.Piecewise(*[(self(v), self(c)) for v, c in node.args])


@sympy2gem.register(sympy.ITE)
Expand Down
Loading