From 7f2dd72954980cf24a85849683d3c9096217881a Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 9 Jan 2025 20:36:51 +0000 Subject: [PATCH 01/24] Boundary Quadrature element --- FIAT/polynomial_set.py | 2 +- finat/element_factory.py | 7 ++-- finat/fiat_elements.py | 7 ++-- finat/point_set.py | 7 ++-- finat/quadrature_element.py | 67 ++++++++++++++++++++++++++++--------- 5 files changed, 62 insertions(+), 28 deletions(-) diff --git a/FIAT/polynomial_set.py b/FIAT/polynomial_set.py index 08e62b8d8..be97e43df 100644 --- a/FIAT/polynomial_set.py +++ b/FIAT/polynomial_set.py @@ -69,7 +69,7 @@ def tabulate_new(self, pts): def tabulate(self, pts, jet_order=0): """Returns the values of the polynomial set.""" base_vals = self.expansion_set._tabulate(self.embedded_degree, pts, order=jet_order) - result = {alpha: numpy.dot(self.coeffs, base_vals[alpha]) for alpha in base_vals} + result = {alpha: numpy.tensordot(self.coeffs, base_vals[alpha], (-1, 0)) for alpha in base_vals} return result def get_expansion_set(self): diff --git a/finat/element_factory.py b/finat/element_factory.py index 48db428d8..b95043b25 100644 --- a/finat/element_factory.py +++ b/finat/element_factory.py @@ -149,13 +149,14 @@ def convert(element, **kwargs): @convert.register(finat.ufl.FiniteElement) def convert_finiteelement(element, **kwargs): cell = as_fiat_cell(element.cell) - if element.family() == "Quadrature": + if element.family() in {"Quadrature", "Boundary Quadrature"}: degree = element.degree() - scheme = element.quadrature_scheme() + scheme = element.quadrature_scheme() or "default" if degree is None or scheme is None: raise ValueError("Quadrature scheme and degree must be specified!") - return finat.make_quadrature_element(cell, degree, scheme), set() + codim = 1 if element.family() == "Boundary Quadrature" else 0 + return finat.make_quadrature_element(cell, degree, scheme, codim), set() lmbda = supported_elements[element.family()] if element.family() == "Real" and element.cell.cellname() in {"quadrilateral", "hexahedron"}: lmbda = None diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index 0203a3a7f..6a1a2fc18 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -100,7 +100,7 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): ''' space_dimension = self._element.space_dimension() value_size = np.prod(self._element.value_shape(), dtype=int) - fiat_result = self._element.tabulate(order, ps.points, entity) + fiat_result = self._element.tabulate(order, ps.points.reshape(-1, ps.points.shape[-1]), entity) result = {} # In almost all cases, we have # self.space_dimension() == self._element.space_dimension() @@ -116,9 +116,8 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): continue derivative = sum(alpha) - table_roll = fiat_table.reshape( - space_dimension, value_size, len(ps.points) - ).transpose(1, 2, 0) + table = fiat_table.reshape(space_dimension, value_size, *ps.points.shape[:-1]) + table_roll = np.moveaxis(table, 0, -1) exprs = [] for table in table_roll: diff --git a/finat/point_set.py b/finat/point_set.py index 1497308c7..82424c04b 100644 --- a/finat/point_set.py +++ b/finat/point_set.py @@ -24,8 +24,7 @@ def points(self): @property def dimension(self): """Point dimension.""" - _, dim = self.points.shape - return dim + return self.points.shape[-1] @abstractproperty def indices(self): @@ -130,7 +129,7 @@ def __init__(self, points): :arg points: A vector of N points of shape (N, D) where D is the dimension of each point.""" points = numpy.asarray(points) - assert len(points.shape) == 2 + assert len(points.shape) > 1 self.points = points @cached_property @@ -139,7 +138,7 @@ def points(self): @cached_property def indices(self): - return (gem.Index(extent=len(self.points)),) + return tuple(gem.Index(extent=e) for e in self.points.shape[:-1]) @cached_property def expression(self): diff --git a/finat/quadrature_element.py b/finat/quadrature_element.py index 3f17ec399..a133ef642 100644 --- a/finat/quadrature_element.py +++ b/finat/quadrature_element.py @@ -1,4 +1,4 @@ -from finat.point_set import UnknownPointSet +from finat.point_set import UnknownPointSet, PointSet from functools import reduce import numpy @@ -13,7 +13,7 @@ from finat.quadrature import make_quadrature, AbstractQuadratureRule -def make_quadrature_element(fiat_ref_cell, degree, scheme="default"): +def make_quadrature_element(fiat_ref_cell, degree, scheme="default", codim=0): """Construct a :class:`QuadratureElement` from a given a reference element, degree and scheme. @@ -23,9 +23,16 @@ def make_quadrature_element(fiat_ref_cell, degree, scheme="default"): integrate exactly. :param scheme: The quadrature scheme to use - e.g. "default", "canonical" or "KMV". + :param codim: The codimension of the quadrature scheme. :returns: The appropriate :class:`QuadratureElement` """ - rule = make_quadrature(fiat_ref_cell, degree, scheme=scheme) + if codim: + sd = fiat_ref_cell.get_spatial_dimension() + rule_ref_cell = fiat_ref_cell.construct_subcomplex(sd - codim) + else: + rule_ref_cell = fiat_ref_cell + + rule = make_quadrature(rule_ref_cell, degree, scheme=scheme) return QuadratureElement(fiat_ref_cell, rule) @@ -42,8 +49,6 @@ def __init__(self, fiat_ref_cell, rule): self.cell = fiat_ref_cell if not isinstance(rule, AbstractQuadratureRule): raise TypeError("rule is not an AbstractQuadratureRule") - if fiat_ref_cell.get_spatial_dimension() != rule.point_set.dimension: - raise ValueError("Cell dimension does not match rule's point set dimension") self._rule = rule @cached_property @@ -64,10 +69,16 @@ def formdegree(self): @cached_property def _entity_dofs(self): - # Inspired by ffc/quadratureelement.py + top = self.cell.get_topology() entity_dofs = {dim: {entity: [] for entity in entities} - for dim, entities in self.cell.get_topology().items()} - entity_dofs[self.cell.get_dimension()] = {0: list(range(self.space_dimension()))} + for dim, entities in top.items()} + ps = self._rule.point_set + dim = ps.dimension + num_pts = len(ps.points) + cur = 0 + for entity in sorted(top[dim]): + entity_dofs[dim][entity] = list(range(cur, cur + num_pts)) + cur += num_pts return entity_dofs def entity_dofs(self): @@ -76,9 +87,22 @@ def entity_dofs(self): def space_dimension(self): return numpy.prod(self.index_shape, dtype=int) + @cached_property + def _point_set(self): + ps = self._rule.point_set + sd = self.cell.get_spatial_dimension() + dim = ps.dimension + if dim != sd: + # Tile the quadrature rule on each subentity + entity_ids = self.entity_dofs() + pts = [self.cell.get_entity_transform(dim, entity)(ps.points) + for entity in entity_ids[dim]] + ps = PointSet(numpy.stack(pts, axis=0)) + return ps + @property def index_shape(self): - ps = self._rule.point_set + ps = self._point_set return tuple(index.extent for index in ps.indices) @property @@ -87,7 +111,7 @@ def value_shape(self): @cached_property def fiat_equivalent(self): - ps = self._rule.point_set + ps = self._point_set if isinstance(ps, UnknownPointSet): raise ValueError("A quadrature element with rule with runtime points has no fiat equivalent!") weights = getattr(self._rule, 'weights', None) @@ -107,8 +131,13 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): :param ps: the point set object. :param entity: the cell entity on which to tabulate. ''' - if entity is not None and entity != (self.cell.get_dimension(), 0): - raise ValueError('QuadratureElement does not "tabulate" on subentities.') + rule_dim = self._rule.point_set.dimension + if entity is None: + entity = (rule_dim, 0) + entity_dim, entity_id = entity + if entity_dim != rule_dim: + raise ValueError(f"Cannot tabulate QuadratureElement of dimension {rule_dim}" + f" on subentities of dimension {entity_dim}.") if order: raise ValueError("Derivatives are not defined on a QuadratureElement.") @@ -119,17 +148,23 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): # Return an outer product of identity matrices multiindex = self.get_indices() product = reduce(gem.Product, [gem.Delta(q, r) - for q, r in zip(ps.indices, multiindex)]) + for q, r in zip(ps.indices, multiindex[-len(ps.indices):])]) - dim = self.cell.get_spatial_dimension() - return {(0,) * dim: gem.ComponentTensor(product, multiindex)} + sd = self.cell.get_spatial_dimension() + if sd != ps.dimension: + data = numpy.zeros(self.index_shape[:-1], dtype=object) + data[...] = gem.Zero() + data[entity_id] = gem.Literal(1) + product = gem.Product(product, gem.Indexed(gem.ListTensor(data), multiindex[:1])) + + return {(0,) * sd: gem.ComponentTensor(product, multiindex)} def point_evaluation(self, order, refcoords, entity=None): raise NotImplementedError("QuadratureElement cannot do point evaluation!") @property def dual_basis(self): - ps = self._rule.point_set + ps = self._point_set multiindex = self.get_indices() # Evaluation matrix is just an outer product of identity # matrices, evaluation points are just the quadrature points. From 3f871a6a5137060ff7dfa9cc4ae60448539ff4cc Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 10 Jan 2025 19:05:30 +0000 Subject: [PATCH 02/24] Introduce MappedPointSet --- finat/fiat_elements.py | 2 +- finat/point_set.py | 59 ++++++++++++++++++++++++++++++++++--- finat/quadrature_element.py | 22 ++++---------- finat/tensor_product.py | 11 ++++++- gem/gem.py | 10 +++++++ 5 files changed, 82 insertions(+), 22 deletions(-) diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index 6a1a2fc18..11f49842c 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -100,7 +100,7 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): ''' space_dimension = self._element.space_dimension() value_size = np.prod(self._element.value_shape(), dtype=int) - fiat_result = self._element.tabulate(order, ps.points.reshape(-1, ps.points.shape[-1]), entity) + fiat_result = self._element.tabulate(order, ps.points, entity) result = {} # In almost all cases, we have # self.space_dimension() == self._element.space_dimension() diff --git a/finat/point_set.py b/finat/point_set.py index 82424c04b..444f77de0 100644 --- a/finat/point_set.py +++ b/finat/point_set.py @@ -5,6 +5,7 @@ import gem from gem.utils import cached_property +from FIAT.reference_element import make_affine_mapping class AbstractPointSet(metaclass=ABCMeta): @@ -111,8 +112,7 @@ def points(self): @cached_property def indices(self): - N, _ = self._points_expr.shape - return (gem.Index(extent=N),) + return tuple(gem.Index(extent=N) for N in self._points_expr.shape[:-1]) @cached_property def expression(self): @@ -129,7 +129,7 @@ def __init__(self, points): :arg points: A vector of N points of shape (N, D) where D is the dimension of each point.""" points = numpy.asarray(points) - assert len(points.shape) > 1 + assert len(points.shape) == 2 self.points = points @cached_property @@ -138,7 +138,7 @@ def points(self): @cached_property def indices(self): - return tuple(gem.Index(extent=e) for e in self.points.shape[:-1]) + return tuple(gem.Index(extent=N) for N in self.points.shape[:-1]) @cached_property def expression(self): @@ -200,3 +200,54 @@ def almost_equal(self, other, tolerance=1e-12): len(self.factors) == len(other.factors) and \ all(s.almost_equal(o, tolerance=tolerance) for s, o in zip(self.factors, other.factors)) + + +class MappedPointSet(AbstractPointSet): + + def __init__(self, cell, ps): + self.cell = cell + self.ps = ps + + @cached_property + def transforms(self): + top = self.cell.topology + dim = self.ps.dimension + sd = self.cell.get_spatial_dimension() + A = numpy.zeros((len(top[dim]), sd, dim)) + b = numpy.zeros((len(top[dim]), sd)) + ref_verts = self.cell.construct_subelement(dim).vertices + for entity in sorted(top[dim]): + verts = self.cell.get_vertices_of_subcomplex(top[dim][entity]) + A[entity], b[entity] = make_affine_mapping(ref_verts, verts) + return A, b + + @cached_property + def points(self): + x = self.ps.points + A, b = self.transforms + pts = [numpy.add(numpy.dot(x, A[entity].T), b[entity]) + for entity in range(len(A))] + return numpy.concatenate(pts) + + @cached_property + def indices(self): + num_facets = len(self.cell.topology[self.ps.dimension]) + return (gem.Index(extent=num_facets), *self.ps.indices) + + @cached_property + def expression(self): + A, b = self.transforms + x = self.ps.expression + i, *p = self.indices + j, k = (gem.Index(extent=e) for e in A.shape[1:]) + + xpk = gem.Indexed(x, (*p, k)) + Aijk = gem.Indexed(gem.Literal(A), (i, j, k)) + bij = gem.Indexed(gem.Literal(b), (i, j)) + return gem.Sum(gem.IndexSum(Aijk, xpk, (k,)), bij) + + def almost_equal(self, other, tolerance=1e-12): + """Approximate numerical equality of point sets""" + return type(self) is type(other) and \ + self.cell == other.cell and \ + self.ps.almost_equal(other.ps, tolerance=tolerance) diff --git a/finat/quadrature_element.py b/finat/quadrature_element.py index a133ef642..6db7b3c44 100644 --- a/finat/quadrature_element.py +++ b/finat/quadrature_element.py @@ -1,4 +1,4 @@ -from finat.point_set import UnknownPointSet, PointSet +from finat.point_set import UnknownPointSet, MappedPointSet from functools import reduce import numpy @@ -91,14 +91,7 @@ def space_dimension(self): def _point_set(self): ps = self._rule.point_set sd = self.cell.get_spatial_dimension() - dim = ps.dimension - if dim != sd: - # Tile the quadrature rule on each subentity - entity_ids = self.entity_dofs() - pts = [self.cell.get_entity_transform(dim, entity)(ps.points) - for entity in entity_ids[dim]] - ps = PointSet(numpy.stack(pts, axis=0)) - return ps + return ps if ps.dimension == sd else MappedPointSet(self.cell, ps) @property def index_shape(self): @@ -147,16 +140,13 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): # Return an outer product of identity matrices multiindex = self.get_indices() + fid = ps.indices + if len(multiindex) > len(fid): + fid = (entity_id, *fid) product = reduce(gem.Product, [gem.Delta(q, r) - for q, r in zip(ps.indices, multiindex[-len(ps.indices):])]) + for q, r in zip(fid, multiindex)]) sd = self.cell.get_spatial_dimension() - if sd != ps.dimension: - data = numpy.zeros(self.index_shape[:-1], dtype=object) - data[...] = gem.Zero() - data[entity_id] = gem.Literal(1) - product = gem.Product(product, gem.Indexed(gem.ListTensor(data), multiindex[:1])) - return {(0,) * sd: gem.ComponentTensor(product, multiindex)} def point_evaluation(self, order, refcoords, entity=None): diff --git a/finat/tensor_product.py b/finat/tensor_product.py index f0fd58477..891c44ba1 100644 --- a/finat/tensor_product.py +++ b/finat/tensor_product.py @@ -13,7 +13,7 @@ from gem.utils import cached_property from finat.finiteelementbase import FiniteElementBase -from finat.point_set import PointSingleton, PointSet, TensorPointSet +from finat.point_set import PointSingleton, PointSet, TensorPointSet, MappedPointSet class TensorProductElement(FiniteElementBase): @@ -138,6 +138,15 @@ def _merge_evaluations(self, factor_results): return result def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): + if isinstance(ps, MappedPointSet): + top = self.cell.topology + evals = [self.basis_evaluation(order, ps.ps, entity=(dim, entity), + coordinate_mapping=coordinate_mapping) + for dim in sorted(top) + for entity in sorted(top[dim]) + if sum(dim) == ps.ps.dimension] + return {key: gem.ListTensor([e[key] for e in evals]) for key in evals[0]} + entities = self._factor_entity(entity) entity_dim, _ = zip(*entities) diff --git a/gem/gem.py b/gem/gem.py index 8369b6f75..f0c681b6f 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -971,6 +971,16 @@ def __new__(cls, i, j, dtype=None): if isinstance(i, int) and isinstance(j, int): return one if i == j else Zero() + if isinstance(i, int): + expr = numpy.full((j.extent), Zero(), dtype=object) + expr[i] = one + return Indexed(ListTensor(expr), (j,)) + + if isinstance(j, int): + expr = numpy.full((i.extent), Zero(), dtype=object) + expr[j] = one + return Indexed(ListTensor(expr), (i,)) + self = super(Delta, cls).__new__(cls) self.i = i self.j = j From 9b2602939be0b8815bb1e2f47c91851bf6fa440e Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sat, 11 Jan 2025 20:07:56 +0000 Subject: [PATCH 03/24] GEM: Syntax sugar --- finat/fiat_elements.py | 4 +-- finat/point_set.py | 37 ++++++++------------------- finat/quadrature.py | 3 +-- finat/quadrature_element.py | 19 +++++++------- finat/sympy2gem.py | 4 +-- finat/tensor_product.py | 49 +++++++++++++----------------------- finat/tensorfiniteelement.py | 7 ++---- gem/gem.py | 18 +++++++++++-- gem/optimise.py | 12 ++++----- 9 files changed, 65 insertions(+), 88 deletions(-) diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index 11f49842c..1a6c8b558 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -116,8 +116,8 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): continue derivative = sum(alpha) - table = fiat_table.reshape(space_dimension, value_size, *ps.points.shape[:-1]) - table_roll = np.moveaxis(table, 0, -1) + shp = (space_dimension, value_size, *ps.points.shape[:-1]) + table_roll = np.moveaxis(fiat_table.reshape(shp), 0, -1) exprs = [] for table in table_roll: diff --git a/finat/point_set.py b/finat/point_set.py index 444f77de0..c88093239 100644 --- a/finat/point_set.py +++ b/finat/point_set.py @@ -5,7 +5,6 @@ import gem from gem.utils import cached_property -from FIAT.reference_element import make_affine_mapping class AbstractPointSet(metaclass=ABCMeta): @@ -209,42 +208,28 @@ def __init__(self, cell, ps): self.ps = ps @cached_property - def transforms(self): + def entities(self): + to_int = lambda x: sum(x) if isinstance(x, tuple) else x top = self.cell.topology - dim = self.ps.dimension - sd = self.cell.get_spatial_dimension() - A = numpy.zeros((len(top[dim]), sd, dim)) - b = numpy.zeros((len(top[dim]), sd)) - ref_verts = self.cell.construct_subelement(dim).vertices - for entity in sorted(top[dim]): - verts = self.cell.get_vertices_of_subcomplex(top[dim][entity]) - A[entity], b[entity] = make_affine_mapping(ref_verts, verts) - return A, b + return [(dim, entity) + for dim in sorted(top) + for entity in sorted(top[dim]) + if to_int(dim) == self.ps.dimension] @cached_property def points(self): - x = self.ps.points - A, b = self.transforms - pts = [numpy.add(numpy.dot(x, A[entity].T), b[entity]) - for entity in range(len(A))] + ref_pts = self.ps.points + pts = [self.cell.get_entity_transform(dim, entity)(ref_pts) + for dim, entity in self.entities] return numpy.concatenate(pts) @cached_property def indices(self): - num_facets = len(self.cell.topology[self.ps.dimension]) - return (gem.Index(extent=num_facets), *self.ps.indices) + return (gem.Index(extent=len(self.entities)), *self.ps.indices) @cached_property def expression(self): - A, b = self.transforms - x = self.ps.expression - i, *p = self.indices - j, k = (gem.Index(extent=e) for e in A.shape[1:]) - - xpk = gem.Indexed(x, (*p, k)) - Aijk = gem.Indexed(gem.Literal(A), (i, j, k)) - bij = gem.Indexed(gem.Literal(b), (i, j)) - return gem.Sum(gem.IndexSum(Aijk, xpk, (k,)), bij) + raise NotImplementedError("Should not use MappedPointSet like this") def almost_equal(self, other, tolerance=1e-12): """Approximate numerical equality of point sets""" diff --git a/finat/quadrature.py b/finat/quadrature.py index ec6def127..f8aa33682 100644 --- a/finat/quadrature.py +++ b/finat/quadrature.py @@ -1,5 +1,4 @@ from abc import ABCMeta, abstractproperty -from functools import reduce import gem import numpy @@ -137,4 +136,4 @@ def point_set(self): @cached_property def weight_expression(self): - return reduce(gem.Product, (q.weight_expression for q in self.factors)) + return gem.Product(*(q.weight_expression for q in self.factors)) diff --git a/finat/quadrature_element.py b/finat/quadrature_element.py index 6db7b3c44..4e5c570ca 100644 --- a/finat/quadrature_element.py +++ b/finat/quadrature_element.py @@ -1,5 +1,4 @@ from finat.point_set import UnknownPointSet, MappedPointSet -from functools import reduce import numpy @@ -28,7 +27,7 @@ def make_quadrature_element(fiat_ref_cell, degree, scheme="default", codim=0): """ if codim: sd = fiat_ref_cell.get_spatial_dimension() - rule_ref_cell = fiat_ref_cell.construct_subcomplex(sd - codim) + rule_ref_cell = fiat_ref_cell.construct_subelement(sd - codim) else: rule_ref_cell = fiat_ref_cell @@ -73,12 +72,14 @@ def _entity_dofs(self): entity_dofs = {dim: {entity: [] for entity in entities} for dim, entities in top.items()} ps = self._rule.point_set - dim = ps.dimension num_pts = len(ps.points) + to_int = lambda x: sum(x) if isinstance(x, tuple) else x cur = 0 - for entity in sorted(top[dim]): - entity_dofs[dim][entity] = list(range(cur, cur + num_pts)) - cur += num_pts + for dim in sorted(top): + if to_int(dim) == ps.dimension: + for entity in sorted(top[dim]): + entity_dofs[dim][entity].extend(range(cur, cur + num_pts)) + cur += num_pts return entity_dofs def entity_dofs(self): @@ -143,8 +144,7 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): fid = ps.indices if len(multiindex) > len(fid): fid = (entity_id, *fid) - product = reduce(gem.Product, [gem.Delta(q, r) - for q, r in zip(fid, multiindex)]) + product = gem.Delta(fid, multiindex) sd = self.cell.get_spatial_dimension() return {(0,) * sd: gem.ComponentTensor(product, multiindex)} @@ -158,8 +158,7 @@ def dual_basis(self): multiindex = self.get_indices() # Evaluation matrix is just an outer product of identity # matrices, evaluation points are just the quadrature points. - Q = reduce(gem.Product, (gem.Delta(q, r) - for q, r in zip(ps.indices, multiindex))) + Q = gem.Delta(ps.indices, multiindex) Q = gem.ComponentTensor(Q, multiindex) return Q, ps diff --git a/finat/sympy2gem.py b/finat/sympy2gem.py index 29add8760..9d613a307 100644 --- a/finat/sympy2gem.py +++ b/finat/sympy2gem.py @@ -27,13 +27,13 @@ def sympy2gem_expr(node, self): @sympy2gem.register(sympy.Add) @sympy2gem.register(symengine.Add) def sympy2gem_add(node, self): - return reduce(gem.Sum, map(self, node.args)) + return gem.Sum(*map(self, node.args)) @sympy2gem.register(sympy.Mul) @sympy2gem.register(symengine.Mul) def sympy2gem_mul(node, self): - return reduce(gem.Product, map(self, node.args)) + return gem.Product(*map(self, node.args)) @sympy2gem.register(sympy.Pow) diff --git a/finat/tensor_product.py b/finat/tensor_product.py index 891c44ba1..46365f366 100644 --- a/finat/tensor_product.py +++ b/finat/tensor_product.py @@ -1,4 +1,3 @@ -from functools import reduce from itertools import chain, product from operator import methodcaller @@ -13,7 +12,7 @@ from gem.utils import cached_property from finat.finiteelementbase import FiniteElementBase -from finat.point_set import PointSingleton, PointSet, TensorPointSet, MappedPointSet +from finat.point_set import PointSingleton, PointSet, TensorPointSet class TensorProductElement(FiniteElementBase): @@ -32,11 +31,11 @@ def __init__(self, factors): @cached_property def cell(self): - return TensorProductCell(*[fe.cell for fe in self.factors]) + return TensorProductCell(*(fe.cell for fe in self.factors)) @cached_property def complex(self): - return TensorProductCell(*[fe.complex for fe in self.factors]) + return TensorProductCell(*(fe.complex for fe in self.factors)) @property def degree(self): @@ -69,7 +68,7 @@ def space_dimension(self): @property def index_shape(self): - return tuple(chain(*[fe.index_shape for fe in self.factors])) + return tuple(chain.from_iterable(fe.index_shape for fe in self.factors)) @property def value_shape(self): @@ -117,36 +116,23 @@ def _merge_evaluations(self, factor_results): # multiindex describing the value shape of the subelement. zetas = [fe.get_value_indices() for fe in self.factors] + multiindex = tuple(chain(*alphas, *zetas)) result = {} for derivative in range(order + 1): for Delta in mis(dimension, derivative): # Split the multiindex for the subelements deltas = [Delta[s] for s in dim_slices] - # GEM scalars (can have free indices) for collecting - # the contributions from the subelements. - scalars = [] - for fr, delta, alpha, zeta in zip(factor_results, deltas, alphas, zetas): - # Turn basis shape to free indices, select the - # right derivative entry, and collect the result. - scalars.append(gem.Indexed(fr[delta], alpha + zeta)) - # Multiply the values from the subelements and wrap up - # non-point indices into shape. - result[Delta] = gem.ComponentTensor( - reduce(gem.Product, scalars), - tuple(chain(*(alphas + zetas))) - ) + # Multiply the values from the subelements + # Turn basis shape to free indices, select the + # right derivative entry. + scalar = gem.Product(*(gem.Indexed(fr[delta], alpha + zeta) + for fr, delta, alpha, zeta + in zip(factor_results, deltas, alphas, zetas))) + # Wrap up non-point indices into shape. + result[Delta] = gem.ComponentTensor(scalar, multiindex) return result def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): - if isinstance(ps, MappedPointSet): - top = self.cell.topology - evals = [self.basis_evaluation(order, ps.ps, entity=(dim, entity), - coordinate_mapping=coordinate_mapping) - for dim in sorted(top) - for entity in sorted(top[dim]) - if sum(dim) == ps.ps.dimension] - return {key: gem.ListTensor([e[key] for e in evals]) for key in evals[0]} - entities = self._factor_entity(entity) entity_dim, _ = zip(*entities) @@ -188,12 +174,11 @@ def dual_basis(self): # Naming as _merge_evaluations above alphas = [factor.get_indices() for factor in self.factors] zetas = [factor.get_value_indices() for factor in self.factors] - # Index the factors by so that we can reshape into index-shape - # followed by value-shape - qis = [q[alpha + zeta] for q, alpha, zeta in zip(qs, alphas, zetas)] Q = gem.ComponentTensor( - reduce(gem.Product, qis), - tuple(chain(*(alphas + zetas))) + # Index the factors by so that we can reshape into index-shape + # followed by value-shape + gem.Product(*(q[alpha + zeta] for q, alpha, zeta in zip(qs, alphas, zetas))), + tuple(chain(*alphas, *zetas)) ) return Q, ps diff --git a/finat/tensorfiniteelement.py b/finat/tensorfiniteelement.py index c0a8aa91e..b8c50e158 100644 --- a/finat/tensorfiniteelement.py +++ b/finat/tensorfiniteelement.py @@ -1,4 +1,3 @@ -from functools import reduce from itertools import chain import numpy @@ -134,8 +133,7 @@ def _tensorise(self, scalar_evaluation): tensor_vi = tuple(gem.Index(extent=d) for d in self._shape) # Couple new basis function and value indices - deltas = reduce(gem.Product, (gem.Delta(j, k) - for j, k in zip(tensor_i, tensor_vi))) + deltas = gem.Delta(tensor_i, tensor_vi) if self._transpose: index_ordering = tensor_i + scalar_i + tensor_vi + scalar_vi @@ -163,8 +161,7 @@ def dual_basis(self): tensor_i = tuple(gem.Index(extent=d) for d in self._shape) tensor_vi = tuple(gem.Index(extent=d) for d in self._shape) # Couple new basis function and value indices - deltas = reduce(gem.Product, (gem.Delta(j, k) - for j, k in zip(tensor_i, tensor_vi))) + deltas = gem.Delta(tensor_i, tensor_vi) if self._transpose: index_ordering = tensor_i + scalar_i + tensor_vi + scalar_vi else: diff --git a/gem/gem.py b/gem/gem.py index f0c681b6f..f8f65ea96 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -16,6 +16,7 @@ from abc import ABCMeta from itertools import chain +from functools import reduce from operator import attrgetter from numbers import Integral, Number @@ -324,7 +325,12 @@ def __init__(self, name, shape, dtype=None): class Sum(Scalar): __slots__ = ('children',) - def __new__(cls, a, b): + def __new__(cls, *args): + try: + a, b = args + except ValueError: + # Handle more than two arguments + return reduce(Sum, args) assert not a.shape assert not b.shape @@ -345,7 +351,12 @@ def __new__(cls, a, b): class Product(Scalar): __slots__ = ('children',) - def __new__(cls, a, b): + def __new__(cls, *args): + try: + a, b = args + except ValueError: + # Handle more than two arguments + return reduce(Product, args) assert not a.shape assert not b.shape @@ -960,6 +971,9 @@ class Delta(Scalar, Terminal): __back__ = ('dtype',) def __new__(cls, i, j, dtype=None): + if isinstance(i, tuple) and isinstance(j, tuple): + # Handle multiindices + return Product(*map(Delta, i, j)) assert isinstance(i, IndexBase) assert isinstance(j, IndexBase) diff --git a/gem/optimise.py b/gem/optimise.py index 7d6c8ecd6..6f83271e1 100644 --- a/gem/optimise.py +++ b/gem/optimise.py @@ -2,7 +2,7 @@ expressions.""" from collections import OrderedDict, defaultdict -from functools import singledispatch, partial, reduce +from functools import singledispatch, partial from itertools import combinations, permutations, zip_longest from numbers import Integral @@ -374,7 +374,7 @@ def sum_factorise(sum_indices, factors): # Form groups by free indices groups = groupby(factors, key=lambda f: f.free_indices) - groups = [reduce(Product, terms) for _, terms in groups] + groups = [Product(*terms) for _, terms in groups] # Sum factorisation expression = None @@ -414,7 +414,7 @@ def sum_factorise(sum_indices, factors): def make_sum(summands): """Constructs an operation-minimal sum of GEM expressions.""" groups = groupby(summands, key=lambda f: f.free_indices) - summands = [reduce(Sum, terms) for _, terms in groups] + summands = [Sum(*terms) for _, terms in groups] result, flops = associate(Sum, summands) return result @@ -660,10 +660,8 @@ def _(node, self): # Unrolling summand = self(node.children[0]) shape = tuple(index.extent for index in unroll) - unrolled = reduce(Sum, - (Indexed(ComponentTensor(summand, unroll), alpha) - for alpha in numpy.ndindex(shape)), - Zero()) + unrolled = Sum(*(Indexed(ComponentTensor(summand, unroll), alpha) + for alpha in numpy.ndindex(shape))) return IndexSum(unrolled, tuple(index for index in node.multiindex if index not in unroll)) else: From 24c6952b2ebe71a554b140ad09519aab5f10b741 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 21 Jan 2025 15:03:27 +0000 Subject: [PATCH 04/24] GEM: Simplify Indexed tensors --- gem/gem.py | 57 ++++++++++++++++++++++++++++++++++++++++--------- gem/node.py | 35 +++++++++++++++--------------- gem/optimise.py | 16 ++++++-------- 3 files changed, 70 insertions(+), 38 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 8369b6f75..01bf2c8d1 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -117,9 +117,8 @@ def __matmul__(self, other): raise ValueError(f"Mismatching shapes {self.shape} and {other.shape} in matmul") *i, k = indices(len(self.shape)) _, *j = indices(len(other.shape)) - expr = Product(Indexed(self, tuple(i) + (k, )), - Indexed(other, (k, ) + tuple(j))) - return ComponentTensor(IndexSum(expr, (k, )), tuple(i) + tuple(j)) + expr = Product(Indexed(self, (*i, k)), Indexed(other, (k, *j))) + return ComponentTensor(IndexSum(expr, (k, )), (*i, *j)) def __rmatmul__(self, other): return as_gem(other).__matmul__(self) @@ -307,6 +306,9 @@ def value(self): def shape(self): return self.array.shape + def __getitem__(self, i): + return self.array[i] + class Variable(Terminal): """Symbolic variable tensor""" @@ -674,12 +676,31 @@ def __new__(cls, aggregate, multiindex): if isinstance(aggregate, Zero): return Zero(dtype=aggregate.dtype) - # All indices fixed - if all(isinstance(i, int) for i in multiindex): - if isinstance(aggregate, Constant): - return Literal(aggregate.array[multiindex], dtype=aggregate.dtype) - elif isinstance(aggregate, ListTensor): - return aggregate.array[multiindex] + # Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll) + if isinstance(aggregate, ComponentTensor): + B, = aggregate.children + jj = aggregate.multiindex + if isinstance(B, Indexed): + C, = B.children + kk = B.multiindex + if all(j in kk for j in jj): + ii = tuple(multiindex) + rep = dict(zip(jj, ii)) + ll = tuple(rep.get(k, k) for k in kk) + return Indexed(C, ll) + + # Simplify Constant and ListTensor + if isinstance(aggregate, (Constant, ListTensor)): + if all(isinstance(i, int) for i in multiindex): + # All indices fixed + sub = aggregate[multiindex] + return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub + elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex): + # Some indices fixed + slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex) + sub = aggregate[slices] + sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub) + return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int))) self = super(Indexed, cls).__new__(cls) self.children = (aggregate,) @@ -825,6 +846,11 @@ def __new__(cls, expression, multiindex): if isinstance(expression, Zero): return Zero(shape, dtype=expression.dtype) + # Index folding + if isinstance(expression, Indexed): + if multiindex == expression.multiindex: + return expression.children[0] + self = super(ComponentTensor, cls).__new__(cls) self.children = (expression,) self.multiindex = multiindex @@ -881,9 +907,17 @@ def __new__(cls, array): dtype = Node.inherit_dtype_from_children(tuple(array.flat)) # Handle children with shape - child_shape = array.flat[0].shape + e0 = array.flat[0] + child_shape = e0.shape assert all(elem.shape == child_shape for elem in array.flat) + # Index folding + if child_shape == array.shape: + if all(isinstance(elem, Indexed) for elem in array.flat): + if all(elem.children == e0.children for elem in array.flat[1:]): + if all(elem.multiindex == idx for elem, idx in zip(array.flat, numpy.ndindex(array.shape))): + return e0.children[0] + if child_shape: # Destroy structure direct_array = numpy.empty(array.shape + child_shape, dtype=object) @@ -911,6 +945,9 @@ def shape(self): def __reduce__(self): return type(self), (self.array,) + def __getitem__(self, i): + return self.array[i] + def reconstruct(self, *args): return ListTensor(asarray(args).reshape(self.array.shape)) diff --git a/gem/node.py b/gem/node.py index 5d9c5bf04..71f814638 100644 --- a/gem/node.py +++ b/gem/node.py @@ -3,6 +3,7 @@ import collections import gem +from itertools import repeat class Node(object): @@ -36,14 +37,17 @@ def _cons_args(self, children): Internally used utility function. """ - front_args = [getattr(self, name) for name in self.__front__] - back_args = [getattr(self, name) for name in self.__back__] + front_args = (getattr(self, name) for name in self.__front__) + back_args = (getattr(self, name) for name in self.__back__) - return tuple(front_args) + tuple(children) + tuple(back_args) + return (*front_args, *children, *back_args) + + def _arguments(self): + return self._cons_args(self.children) def __reduce__(self): # Gold version: - return type(self), self._cons_args(self.children) + return type(self), self._arguments() def reconstruct(self, *args): """Reconstructs the node with new children from @@ -54,8 +58,7 @@ def reconstruct(self, *args): return type(self)(*self._cons_args(args)) def __repr__(self): - cons_args = self._cons_args(self.children) - return "%s(%s)" % (type(self).__name__, ", ".join(map(repr, cons_args))) + return "%s(%s)" % (type(self).__name__, ", ".join(map(repr, self._arguments()))) def __eq__(self, other): """Provides equality testing with quick positive and negative @@ -87,9 +90,7 @@ def is_equal(self, other): """ if type(self) is not type(other): return False - self_consargs = self._cons_args(self.children) - other_consargs = other._cons_args(other.children) - return self_consargs == other_consargs + return self._arguments() == other._arguments() def get_hash(self): """Hash function. @@ -97,7 +98,7 @@ def get_hash(self): This is the method to potentially override in derived classes, not :meth:`__hash__`. """ - return hash((type(self),) + self._cons_args(self.children)) + return hash((type(self), *self._arguments())) def _make_traversal_children(node): @@ -235,8 +236,7 @@ def __call__(self, node): return self.cache[node] except KeyError: result = self.function(node, self) - self.cache[node] = result - return result + return self.cache.setdefault(node, result) class MemoizerArg(object): @@ -259,14 +259,13 @@ def __call__(self, node, arg): return self.cache[cache_key] except KeyError: result = self.function(node, self, arg) - self.cache[cache_key] = result - return result + return self.cache.setdefault(cache_key, result) def reuse_if_untouched(node, self): """Reuse if untouched recipe""" - new_children = list(map(self, node.children)) - if all(nc == c for nc, c in zip(new_children, node.children)): + new_children = tuple(map(self, node.children)) + if new_children == node.children: return node else: return node.reconstruct(*new_children) @@ -274,8 +273,8 @@ def reuse_if_untouched(node, self): def reuse_if_untouched_arg(node, self, arg): """Reuse if touched recipe propagating an extra argument""" - new_children = [self(child, arg) for child in node.children] - if all(nc == c for nc, c in zip(new_children, node.children)): + new_children = tuple(map(self, node.children, repeat(arg))) + if new_children == node.children: return node else: return node.reconstruct(*new_children) diff --git a/gem/optimise.py b/gem/optimise.py index 7d6c8ecd6..6206d360e 100644 --- a/gem/optimise.py +++ b/gem/optimise.py @@ -101,8 +101,7 @@ def _replace_indices_atomic(i, self, subst): new_expr = self(i.expression, subst) return i if new_expr == i.expression else VariableIndex(new_expr) else: - substitute = dict(subst) - return substitute.get(i, i) + return dict(subst).get(i, i) @replace_indices.register(Delta) @@ -117,20 +116,18 @@ def replace_indices_delta(node, self, subst): @replace_indices.register(Indexed) def replace_indices_indexed(node, self, subst): + multiindex = tuple(_replace_indices_atomic(i, self, subst) for i in node.multiindex) child, = node.children - substitute = dict(subst) - multiindex = [] - for i in node.multiindex: - multiindex.append(_replace_indices_atomic(i, self, subst)) if isinstance(child, ComponentTensor): # Indexing into ComponentTensor # Inline ComponentTensor and augment the substitution rules + substitute = dict(subst) substitute.update(zip(child.multiindex, multiindex)) return self(child.children[0], tuple(sorted(substitute.items()))) else: # Replace indices new_child = self(child, subst) - if new_child == child and multiindex == node.multiindex: + if multiindex == node.multiindex and new_child == child: return node else: return Indexed(new_child, multiindex) @@ -138,9 +135,6 @@ def replace_indices_indexed(node, self, subst): @replace_indices.register(FlexiblyIndexed) def replace_indices_flexiblyindexed(node, self, subst): - child, = node.children - assert not child.free_indices - dim2idxs = tuple( ( offset if isinstance(offset, Integral) else _replace_indices_atomic(offset, self, subst), @@ -149,6 +143,8 @@ def replace_indices_flexiblyindexed(node, self, subst): for offset, idxs in node.dim2idxs ) + child, = node.children + assert not child.free_indices if dim2idxs == node.dim2idxs: return node else: From e8ff24baee0711c9fe6cc286e3dd82c854996d39 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 21 Jan 2025 17:56:37 +0000 Subject: [PATCH 05/24] Optimize zany-mapping matvec --- finat/physically_mapped.py | 9 ++++++--- gem/gem.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/finat/physically_mapped.py b/finat/physically_mapped.py index 4b6c60896..80abc7758 100644 --- a/finat/physically_mapped.py +++ b/finat/physically_mapped.py @@ -270,15 +270,18 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): M = self.basis_transformation(coordinate_mapping) # we expect M to be sparse with O(1) nonzeros per row # for each row, get the column index of each nonzero entry - csr = [[j for j in range(M.shape[1]) if not isinstance(M.array[i, j], gem.Zero)] + csr = [[j for j in range(M.shape[1]) if not isinstance(M[i, j], gem.Zero)] for i in range(M.shape[0])] def matvec(table): # basis recombination using hand-rolled sparse-dense matrix multiplication - table = [gem.partial_indexed(table, (j,)) for j in range(M.shape[1])] + ii = gem.indices(len(table.shape)-1) + phi = [gem.Indexed(table, (j, *ii)) for j in range(M.shape[1])] # the sum approach is faster than calling numpy.dot or gem.IndexSum - expressions = [sum(M.array[i, j] * table[j] for j in js) for i, js in enumerate(csr)] + expressions = [gem.ComponentTensor(sum(M[i, j] * phi[j] for j in js), ii) + for i, js in enumerate(csr)] val = gem.ListTensor(expressions) + # val = M @ table return gem.optimise.aggressive_unroll(val) result = super().basis_evaluation(order, ps, entity=entity) diff --git a/gem/gem.py b/gem/gem.py index 01bf2c8d1..8371803ad 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -689,7 +689,7 @@ def __new__(cls, aggregate, multiindex): ll = tuple(rep.get(k, k) for k in kk) return Indexed(C, ll) - # Simplify Constant and ListTensor + # Simplify Literal and ListTensor if isinstance(aggregate, (Constant, ListTensor)): if all(isinstance(i, int) for i in multiindex): # All indices fixed From 783e0c472fac1bcb0a5d4c5c2e947b21102dc661 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 21 Jan 2025 23:52:30 +0000 Subject: [PATCH 06/24] WIP --- finat/physically_mapped.py | 9 +++--- gem/gem.py | 52 +++++++++++++++------------------ gem/optimise.py | 10 +++---- test/finat/test_zany_mapping.py | 5 ++-- 4 files changed, 36 insertions(+), 40 deletions(-) diff --git a/finat/physically_mapped.py b/finat/physically_mapped.py index 80abc7758..1e9248dd8 100644 --- a/finat/physically_mapped.py +++ b/finat/physically_mapped.py @@ -270,7 +270,7 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): M = self.basis_transformation(coordinate_mapping) # we expect M to be sparse with O(1) nonzeros per row # for each row, get the column index of each nonzero entry - csr = [[j for j in range(M.shape[1]) if not isinstance(M[i, j], gem.Zero)] + csr = [[j for j in range(M.shape[1]) if not isinstance(M.array[i, j], gem.Zero)] for i in range(M.shape[0])] def matvec(table): @@ -278,9 +278,10 @@ def matvec(table): ii = gem.indices(len(table.shape)-1) phi = [gem.Indexed(table, (j, *ii)) for j in range(M.shape[1])] # the sum approach is faster than calling numpy.dot or gem.IndexSum - expressions = [gem.ComponentTensor(sum(M[i, j] * phi[j] for j in js), ii) - for i, js in enumerate(csr)] - val = gem.ListTensor(expressions) + exprs = [gem.ComponentTensor(sum(M.array[i, j] * phi[j] for j in js), ii) + for i, js in enumerate(csr)] + + val = gem.ListTensor(exprs) # val = M @ table return gem.optimise.aggressive_unroll(val) diff --git a/gem/gem.py b/gem/gem.py index 8371803ad..6ca10419b 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -54,8 +54,8 @@ def __call__(self, *args, **kwargs): # Set free_indices if not set already if not hasattr(obj, 'free_indices'): - obj.free_indices = unique(chain(*[c.free_indices - for c in obj.children])) + obj.free_indices = unique(chain.from_iterable(c.free_indices + for c in obj.children)) # Set dtype if not set already. if not hasattr(obj, 'dtype'): obj.dtype = obj.inherit_dtype_from_children(obj.children) @@ -306,9 +306,6 @@ def value(self): def shape(self): return self.array.shape - def __getitem__(self, i): - return self.array[i] - class Variable(Terminal): """Symbolic variable tensor""" @@ -337,7 +334,7 @@ def __new__(cls, a, b): return a if isinstance(a, Constant) and isinstance(b, Constant): - return Literal(a.value + b.value, dtype=Node.inherit_dtype_from_children([a, b])) + return Literal(a.value + b.value, dtype=Node.inherit_dtype_from_children((a, b))) self = super(Sum, cls).__new__(cls) self.children = a, b @@ -361,7 +358,7 @@ def __new__(cls, a, b): return a if isinstance(a, Constant) and isinstance(b, Constant): - return Literal(a.value * b.value, dtype=Node.inherit_dtype_from_children([a, b])) + return Literal(a.value * b.value, dtype=Node.inherit_dtype_from_children((a, b))) self = super(Product, cls).__new__(cls) self.children = a, b @@ -385,7 +382,7 @@ def __new__(cls, a, b): return a if isinstance(a, Constant) and isinstance(b, Constant): - return Literal(a.value / b.value, dtype=Node.inherit_dtype_from_children([a, b])) + return Literal(a.value / b.value, dtype=Node.inherit_dtype_from_children((a, b))) self = super(Division, cls).__new__(cls) self.children = a, b @@ -398,7 +395,7 @@ class FloorDiv(Scalar): def __new__(cls, a, b): assert not a.shape assert not b.shape - dtype = Node.inherit_dtype_from_children([a, b]) + dtype = Node.inherit_dtype_from_children((a, b)) if dtype != uint_type: raise ValueError(f"dtype ({dtype}) != unit_type ({uint_type})") # Constant folding @@ -421,7 +418,7 @@ class Remainder(Scalar): def __new__(cls, a, b): assert not a.shape assert not b.shape - dtype = Node.inherit_dtype_from_children([a, b]) + dtype = Node.inherit_dtype_from_children((a, b)) if dtype != uint_type: raise ValueError(f"dtype ({dtype}) != uint_type ({uint_type})") # Constant folding @@ -444,7 +441,7 @@ class Power(Scalar): def __new__(cls, base, exponent): assert not base.shape assert not exponent.shape - dtype = Node.inherit_dtype_from_children([base, exponent]) + dtype = Node.inherit_dtype_from_children((base, exponent)) # Constant folding if isinstance(base, Zero): @@ -560,7 +557,7 @@ def __new__(cls, condition, then, else_): self = super(Conditional, cls).__new__(cls) self.children = condition, then, else_ self.shape = then.shape - self.dtype = Node.inherit_dtype_from_children([then, else_]) + self.dtype = Node.inherit_dtype_from_children((then, else_)) return self @@ -676,6 +673,19 @@ def __new__(cls, aggregate, multiindex): if isinstance(aggregate, Zero): return Zero(dtype=aggregate.dtype) + # Simplify Literal and ListTensor + if isinstance(aggregate, (Constant, ListTensor)): + if all(isinstance(i, int) for i in multiindex): + # All indices fixed + sub = aggregate.array[multiindex] + return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub + elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex): + # Some indices fixed + slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex) + sub = aggregate.array[slices] + sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub) + return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int))) + # Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll) if isinstance(aggregate, ComponentTensor): B, = aggregate.children @@ -689,19 +699,6 @@ def __new__(cls, aggregate, multiindex): ll = tuple(rep.get(k, k) for k in kk) return Indexed(C, ll) - # Simplify Literal and ListTensor - if isinstance(aggregate, (Constant, ListTensor)): - if all(isinstance(i, int) for i in multiindex): - # All indices fixed - sub = aggregate[multiindex] - return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub - elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex): - # Some indices fixed - slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex) - sub = aggregate[slices] - sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub) - return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int))) - self = super(Indexed, cls).__new__(cls) self.children = (aggregate,) self.multiindex = multiindex @@ -945,9 +942,6 @@ def shape(self): def __reduce__(self): return type(self), (self.array,) - def __getitem__(self, i): - return self.array[i] - def reconstruct(self, *args): return ListTensor(asarray(args).reshape(self.array.shape)) @@ -958,7 +952,7 @@ def is_equal(self, other): """Common subexpression eliminating equality predicate.""" if type(self) is not type(other): return False - if (self.array == other.array).all(): + if numpy.array_equal(self.array, other.array): self.array = other.array return True return False diff --git a/gem/optimise.py b/gem/optimise.py index 6206d360e..857e40f01 100644 --- a/gem/optimise.py +++ b/gem/optimise.py @@ -188,7 +188,7 @@ def _constant_fold_zero_listtensor(node, self): new_children = list(map(self, node.children)) if all(isinstance(nc, Zero) for nc in new_children): return Zero(node.shape) - elif all(nc == c for nc, c in zip(new_children, node.children)): + elif new_children == node.children: return node else: return node.reconstruct(*new_children) @@ -207,7 +207,7 @@ def constant_fold_zero(exprs): otherwise Literal `0`s would be reintroduced. """ mapper = Memoizer(_constant_fold_zero) - return [mapper(e) for e in exprs] + return list(map(mapper, exprs)) def _select_expression(expressions, index): @@ -252,9 +252,9 @@ def child(expression): assert all(len(e.children) == len(expr.children) for e in expressions) assert len(expr.children) > 0 - return expr.reconstruct(*[_select_expression(nth_children, index) - for nth_children in zip(*[e.children - for e in expressions])]) + return expr.reconstruct(*(_select_expression(nth_children, index) + for nth_children in zip(*(e.children + for e in expressions)))) raise NotImplementedError("No rule for factorising expressions of this kind.") diff --git a/test/finat/test_zany_mapping.py b/test/finat/test_zany_mapping.py index d5136e651..6406dbe9b 100644 --- a/test/finat/test_zany_mapping.py +++ b/test/finat/test_zany_mapping.py @@ -3,6 +3,7 @@ import numpy as np import pytest from gem.interpreter import evaluate +from finat.physically_mapped import PhysicallyMappedElement def make_unisolvent_points(element, interior=False): @@ -65,11 +66,11 @@ def check_zany_mapping(element, ref_to_phys, *args, **kwargs): # Zany map the results num_bfs = phys_element.space_dimension() num_dofs = finat_element.space_dimension() - try: + if isinstance(finat_element, PhysicallyMappedElement): Mgem = finat_element.basis_transformation(ref_to_phys) M = evaluate([Mgem])[0].arr ref_vals_zany = np.tensordot(M, ref_vals_piola, (-1, 0)) - except AttributeError: + else: M = np.eye(num_dofs, num_bfs) ref_vals_zany = ref_vals_piola From fdff8da506edd668609491b4df6dfa90f6fcd26e Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 22 Jan 2025 16:47:52 +0000 Subject: [PATCH 07/24] Fix for Delta(int, Index) --- gem/gem.py | 12 +----------- gem/optimise.py | 2 +- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index f8f65ea96..782c1e8fd 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -982,19 +982,9 @@ def __new__(cls, i, j, dtype=None): return one # Fixed indices - if isinstance(i, int) and isinstance(j, int): + if isinstance(i, Integral) and isinstance(j, Integral): return one if i == j else Zero() - if isinstance(i, int): - expr = numpy.full((j.extent), Zero(), dtype=object) - expr[i] = one - return Indexed(ListTensor(expr), (j,)) - - if isinstance(j, int): - expr = numpy.full((i.extent), Zero(), dtype=object) - expr[j] = one - return Indexed(ListTensor(expr), (i,)) - self = super(Delta, cls).__new__(cls) self.i = i self.j = j diff --git a/gem/optimise.py b/gem/optimise.py index 6f83271e1..c56ba7fbb 100644 --- a/gem/optimise.py +++ b/gem/optimise.py @@ -623,7 +623,7 @@ def _replace_delta_delta(node, self): return Indexed(Identity(size), (i, j)) else: def expression(index): - if isinstance(index, int): + if isinstance(index, Integral): return Literal(index) elif isinstance(index, VariableIndex): return index.expression From e77ed6f1ff5a1138ec95304685a35db34f45bb6c Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 22 Jan 2025 17:43:47 +0000 Subject: [PATCH 08/24] Fix for Delta(int, Index) --- gem/gem.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/gem/gem.py b/gem/gem.py index 782c1e8fd..cf223dcb0 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -985,6 +985,12 @@ def __new__(cls, i, j, dtype=None): if isinstance(i, Integral) and isinstance(j, Integral): return one if i == j else Zero() + if isinstance(i, Integral): + return Indexed(Identity(j.extent), (i, j)) + + if isinstance(j, Integral): + return Indexed(Identity(i.extent), (i, j)) + self = super(Delta, cls).__new__(cls) self.i = i self.j = j From 07a685c2b073309223a68e7233e77a09a4c1ee7a Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sun, 26 Jan 2025 15:01:27 +0000 Subject: [PATCH 09/24] finat: Untangle value shape --- finat/fiat_elements.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index 0efea5e00..05bf6fc42 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -132,29 +132,25 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): else: point_indices = ps.indices point_shape = tuple(index.extent for index in point_indices) - exprs.append(gem.partial_indexed( gem.Literal(table.reshape(point_shape + index_shape)), point_indices )) + if self.value_shape: - # As above, this extent may be different from that - # advertised by the finat element. - beta = tuple(gem.Index(extent=i) for i in index_shape) - assert len(beta) == len(self.get_indices()) + exprs = np.reshape(exprs, self.value_shape) + if self.space_dimension() == space_dimension: + beta = self.get_indices() + else: + beta = tuple(gem.Index(extent=i) for i in index_shape) + assert len(beta) == len(self.get_indices()) zeta = self.get_value_indices() - result[alpha] = gem.ComponentTensor( - gem.Indexed( - gem.ListTensor(np.array( - [gem.Indexed(expr, beta) for expr in exprs] - ).reshape(self.value_shape)), - zeta), - beta + zeta - ) + expr = gem.ComponentTensor(gem.Indexed(gem.ListTensor(exprs), + zeta + beta), beta + zeta) else: expr, = exprs - result[alpha] = expr + result[alpha] = expr return result def point_evaluation(self, order, refcoords, entity=None): From eb35c8956cf1df11b1dada07c311d4ac89cc1e3c Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sun, 26 Jan 2025 15:01:53 +0000 Subject: [PATCH 10/24] finat: lazy basis transformation --- finat/physically_mapped.py | 64 +++++++++++++++++++++++++------------- 1 file changed, 43 insertions(+), 21 deletions(-) diff --git a/finat/physically_mapped.py b/finat/physically_mapped.py index 1e9248dd8..d001543f4 100644 --- a/finat/physically_mapped.py +++ b/finat/physically_mapped.py @@ -1,4 +1,5 @@ from abc import ABCMeta, abstractmethod +from collections.abc import Mapping import gem import numpy @@ -247,6 +248,45 @@ class NeedsCoordinateMappingElement(metaclass=ABCMeta): pass +class MappedTabulation(Mapping): + """A lazy tabulation dict that applies the basis transformation only + on the requested derivatives.""" + + def __init__(self, M, ref_tabulation): + self.M = M + self.ref_tabulation = ref_tabulation + # we expect M to be sparse with O(1) nonzeros per row + # for each row, get the column index of each nonzero entry + csr = [[j for j in range(M.shape[1]) if not isinstance(M.array[i, j], gem.Zero)] + for i in range(M.shape[0])] + self.csr = csr + self._tabulation_cache = {} + + def matvec(self, table): + # basis recombination using hand-rolled sparse-dense matrix multiplication + ii = gem.indices(len(table.shape)-1) + phi = [gem.Indexed(table, (j, *ii)) for j in range(self.M.shape[1])] + # the sum approach is faster than calling numpy.dot or gem.IndexSum + exprs = [gem.ComponentTensor(sum(self.M.array[i, j] * phi[j] for j in js), ii) + for i, js in enumerate(self.csr)] + val = gem.ListTensor(exprs) + # val = self.M @ table + return gem.optimise.aggressive_unroll(val) + + def __getitem__(self, alpha): + try: + return self._tabulation_cache[alpha] + except KeyError: + result = self.matvec(self.ref_tabulation[alpha]) + return self._tabulation_cache.setdefault(alpha, result) + + def __iter__(self): + return iter(self.ref_tabulation) + + def __len__(self): + return len(self.ref_tabulation) + + class PhysicallyMappedElement(NeedsCoordinateMappingElement): """A mixin that applies a "physical" transformation to tabulated basis functions.""" @@ -267,28 +307,10 @@ def basis_transformation(self, coordinate_mapping): def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): assert coordinate_mapping is not None - M = self.basis_transformation(coordinate_mapping) - # we expect M to be sparse with O(1) nonzeros per row - # for each row, get the column index of each nonzero entry - csr = [[j for j in range(M.shape[1]) if not isinstance(M.array[i, j], gem.Zero)] - for i in range(M.shape[0])] - - def matvec(table): - # basis recombination using hand-rolled sparse-dense matrix multiplication - ii = gem.indices(len(table.shape)-1) - phi = [gem.Indexed(table, (j, *ii)) for j in range(M.shape[1])] - # the sum approach is faster than calling numpy.dot or gem.IndexSum - exprs = [gem.ComponentTensor(sum(M.array[i, j] * phi[j] for j in js), ii) - for i, js in enumerate(csr)] - - val = gem.ListTensor(exprs) - # val = M @ table - return gem.optimise.aggressive_unroll(val) + ref_tabulation = super().basis_evaluation(order, ps, entity=entity) - result = super().basis_evaluation(order, ps, entity=entity) - - return {alpha: matvec(table) - for alpha, table in result.items()} + M = self.basis_transformation(coordinate_mapping) + return MappedTabulation(M, ref_tabulation) def point_evaluation(self, order, refcoords, entity=None): raise NotImplementedError("TODO: not yet thought about it") From 2acd16de96629fcc41de273170bc355db452c8df Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sun, 26 Jan 2025 17:02:51 +0000 Subject: [PATCH 11/24] coffee: modernize --- finat/fiat_elements.py | 2 ++ gem/coffee.py | 7 +++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index 05bf6fc42..43dea5b68 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -142,6 +142,8 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): if self.space_dimension() == space_dimension: beta = self.get_indices() else: + # As above, this extent may be different from that + # advertised by the finat element. beta = tuple(gem.Index(extent=i) for i in index_shape) assert len(beta) == len(self.get_indices()) diff --git a/gem/coffee.py b/gem/coffee.py index f766a4890..12c4d0fcf 100644 --- a/gem/coffee.py +++ b/gem/coffee.py @@ -4,8 +4,7 @@ This file is NOT for code generation as a COFFEE AST. """ -from collections import OrderedDict -import itertools +from itertools import chain, repeat import logging import numpy @@ -58,10 +57,10 @@ def find_optimal_atomics(monomials, linear_indices): :returns: list of atomic GEM expressions """ - atomics = tuple(OrderedDict.fromkeys(itertools.chain(*(monomial.atomics for monomial in monomials)))) + atomics = tuple(dict.fromkeys(chain.from_iterable(monomial.atomics for monomial in monomials))) def cost(solution): - extent = sum(map(lambda atomic: index_extent(atomic, linear_indices), solution)) + extent = sum(map(index_extent, solution, repeat(linear_indices))) # Prefer shorter solutions, but larger extents return (len(solution), -extent) From b6ec0658000e1ad8db2d55c7b5f08a7284677119 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 27 Jan 2025 15:57:23 +0000 Subject: [PATCH 12/24] finat: untangle basis_transformation --- finat/fiat_elements.py | 72 +++++++++++++++++--------------------- finat/physically_mapped.py | 4 ++- 2 files changed, 36 insertions(+), 40 deletions(-) diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index 43dea5b68..19b99192d 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -98,8 +98,8 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): :param ps: the point set. :param entity: the cell entity on which to tabulate. ''' - space_dimension = self._element.space_dimension() - value_size = np.prod(self._element.value_shape(), dtype=int) + value_shape = self.value_shape + value_size = np.prod(value_shape, dtype=int) fiat_result = self._element.tabulate(order, ps.points, entity) result = {} # In almost all cases, we have @@ -109,49 +109,43 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): # basis functions, and the additional 3 are for # dealing with transformations between physical # and reference space). - index_shape = (self._element.space_dimension(),) + space_dimension = self._element.space_dimension() + if self.space_dimension() == space_dimension: + beta = self.get_indices() + index_shape = tuple(index.extent for index in beta) + else: + index_shape = (space_dimension,) + beta = tuple(gem.Index(extent=i) for i in index_shape) + assert len(beta) == len(self.get_indices()) + + zeta = self.get_value_indices() + result_indices = beta + zeta + for alpha, fiat_table in fiat_result.items(): if isinstance(fiat_table, Exception): - result[alpha] = gem.Failure(self.index_shape + self.value_shape, fiat_table) + result[alpha] = gem.Failure(index_shape + value_shape, fiat_table) continue derivative = sum(alpha) - table_roll = fiat_table.reshape( - space_dimension, value_size, len(ps.points) - ).transpose(1, 2, 0) - - exprs = [] - for table in table_roll: - if derivative == self.degree and not self.complex.is_macrocell(): - # Make sure numerics satisfies theory - exprs.append(gem.Literal(table[0])) - elif derivative > self.degree: - # Make sure numerics satisfies theory - assert np.allclose(table, 0.0) - exprs.append(gem.Literal(np.zeros(self.index_shape))) - else: - point_indices = ps.indices - point_shape = tuple(index.extent for index in point_indices) - exprs.append(gem.partial_indexed( - gem.Literal(table.reshape(point_shape + index_shape)), - point_indices - )) - - if self.value_shape: - exprs = np.reshape(exprs, self.value_shape) - if self.space_dimension() == space_dimension: - beta = self.get_indices() - else: - # As above, this extent may be different from that - # advertised by the finat element. - beta = tuple(gem.Index(extent=i) for i in index_shape) - assert len(beta) == len(self.get_indices()) - - zeta = self.get_value_indices() - expr = gem.ComponentTensor(gem.Indexed(gem.ListTensor(exprs), - zeta + beta), beta + zeta) + fiat_table = fiat_table.reshape(space_dimension, value_size, -1) + + point_indices = () + if derivative == self.degree and not self.complex.is_macrocell(): + # Make sure numerics satisfies theory + fiat_table = fiat_table[..., 0] + elif derivative > self.degree: + # Make sure numerics satisfies theory + assert np.allclose(fiat_table, 0.0) + fiat_table = np.zeros(fiat_table.shape[:-1]) else: - expr, = exprs + point_indices = ps.indices + + point_shape = tuple(index.extent for index in point_indices) + table_shape = index_shape + value_shape + point_shape + table_indices = beta + zeta + point_indices + + expr = gem.Indexed(gem.Literal(fiat_table.reshape(table_shape)), table_indices) + expr = gem.ComponentTensor(expr, result_indices) result[alpha] = expr return result diff --git a/finat/physically_mapped.py b/finat/physically_mapped.py index d001543f4..f08234508 100644 --- a/finat/physically_mapped.py +++ b/finat/physically_mapped.py @@ -1,5 +1,6 @@ from abc import ABCMeta, abstractmethod from collections.abc import Mapping +from functools import reduce import gem import numpy @@ -267,8 +268,9 @@ def matvec(self, table): ii = gem.indices(len(table.shape)-1) phi = [gem.Indexed(table, (j, *ii)) for j in range(self.M.shape[1])] # the sum approach is faster than calling numpy.dot or gem.IndexSum - exprs = [gem.ComponentTensor(sum(self.M.array[i, j] * phi[j] for j in js), ii) + exprs = [gem.ComponentTensor(reduce(gem.Sum, (self.M.array[i, j] * phi[j] for j in js)), ii) for i, js in enumerate(self.csr)] + val = gem.ListTensor(exprs) # val = self.M @ table return gem.optimise.aggressive_unroll(val) From 223669a53afff05e2a75be4cd9aa826777494eff Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 13 Feb 2025 17:17:51 +0000 Subject: [PATCH 13/24] flake --- finat/quadrature.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/finat/quadrature.py b/finat/quadrature.py index 66c7dad73..c182a8a9a 100644 --- a/finat/quadrature.py +++ b/finat/quadrature.py @@ -1,6 +1,6 @@ import hashlib from abc import ABCMeta, abstractmethod -from functools import cached_property, reduce +from functools import cached_property import gem import numpy From 329dd465e8ddb9d3f211cb3c199cd9e843616132 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 17 Mar 2025 12:10:38 +0000 Subject: [PATCH 14/24] FiniteElementBase: implement is_dg --- finat/finiteelementbase.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/finat/finiteelementbase.py b/finat/finiteelementbase.py index 64a6399d2..b7953510d 100644 --- a/finat/finiteelementbase.py +++ b/finat/finiteelementbase.py @@ -77,6 +77,9 @@ def entity_closure_dofs(self): element.''' return self._entity_closure_dofs + def is_dg(self): + return self.entity_dofs() == self.entity_closure_dofs() + @cached_property def _entity_support_dofs(self): esd = {} From 83a8d7ee87b388559af2a4ef0cc63378ae52f685 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 17 Mar 2025 14:25:00 +0000 Subject: [PATCH 15/24] restore polynomial_set.py --- FIAT/polynomial_set.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/FIAT/polynomial_set.py b/FIAT/polynomial_set.py index 6c0efe147..534c18642 100644 --- a/FIAT/polynomial_set.py +++ b/FIAT/polynomial_set.py @@ -69,7 +69,7 @@ def tabulate_new(self, pts): def tabulate(self, pts, jet_order=0): """Returns the values of the polynomial set.""" base_vals = self.expansion_set._tabulate(self.embedded_degree, pts, order=jet_order) - result = {alpha: numpy.tensordot(self.coeffs, base_vals[alpha], (-1, 0)) for alpha in base_vals} + result = {alpha: numpy.dot(self.coeffs, base_vals[alpha]) for alpha in base_vals} return result def get_expansion_set(self): From 2eadfb714b8b53f20f6f9edf3e6c26ade485b8eb Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 18 Mar 2025 13:49:14 +0000 Subject: [PATCH 16/24] Revert "Merge simplify_indexed" This reverts commit f1d0422ace320bcb3d7e9150029ad600a5ad6370, reversing changes made to 223669a53afff05e2a75be4cd9aa826777494eff. --- finat/fiat_elements.py | 75 ++++++++++++++++++--------------- finat/physically_mapped.py | 62 ++++++++------------------- finat/point_set.py | 5 +-- gem/coffee.py | 7 +-- gem/gem.py | 71 +++++++++---------------------- gem/node.py | 35 +++++++-------- gem/optimise.py | 26 +++++++----- test/finat/test_zany_mapping.py | 5 +-- 8 files changed, 118 insertions(+), 168 deletions(-) diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index 19b99192d..24f4a3d7d 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -98,8 +98,8 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): :param ps: the point set. :param entity: the cell entity on which to tabulate. ''' - value_shape = self.value_shape - value_size = np.prod(value_shape, dtype=int) + space_dimension = self._element.space_dimension() + value_size = np.prod(self._element.value_shape(), dtype=int) fiat_result = self._element.tabulate(order, ps.points, entity) result = {} # In almost all cases, we have @@ -109,44 +109,51 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): # basis functions, and the additional 3 are for # dealing with transformations between physical # and reference space). - space_dimension = self._element.space_dimension() - if self.space_dimension() == space_dimension: - beta = self.get_indices() - index_shape = tuple(index.extent for index in beta) - else: - index_shape = (space_dimension,) - beta = tuple(gem.Index(extent=i) for i in index_shape) - assert len(beta) == len(self.get_indices()) - - zeta = self.get_value_indices() - result_indices = beta + zeta - + index_shape = (self._element.space_dimension(),) for alpha, fiat_table in fiat_result.items(): if isinstance(fiat_table, Exception): - result[alpha] = gem.Failure(index_shape + value_shape, fiat_table) + result[alpha] = gem.Failure(self.index_shape + self.value_shape, fiat_table) continue derivative = sum(alpha) - fiat_table = fiat_table.reshape(space_dimension, value_size, -1) - - point_indices = () - if derivative == self.degree and not self.complex.is_macrocell(): - # Make sure numerics satisfies theory - fiat_table = fiat_table[..., 0] - elif derivative > self.degree: - # Make sure numerics satisfies theory - assert np.allclose(fiat_table, 0.0) - fiat_table = np.zeros(fiat_table.shape[:-1]) + shp = (space_dimension, value_size, *ps.points.shape[:-1]) + table_roll = np.moveaxis(fiat_table.reshape(shp), 0, -1) + + exprs = [] + for table in table_roll: + if derivative == self.degree and not self.complex.is_macrocell(): + # Make sure numerics satisfies theory + exprs.append(gem.Literal(table[0])) + elif derivative > self.degree: + # Make sure numerics satisfies theory + assert np.allclose(table, 0.0) + exprs.append(gem.Literal(np.zeros(self.index_shape))) + else: + point_indices = ps.indices + point_shape = tuple(index.extent for index in point_indices) + + exprs.append(gem.partial_indexed( + gem.Literal(table.reshape(point_shape + index_shape)), + point_indices + )) + if self.value_shape: + # As above, this extent may be different from that + # advertised by the finat element. + beta = tuple(gem.Index(extent=i) for i in index_shape) + assert len(beta) == len(self.get_indices()) + + zeta = self.get_value_indices() + result[alpha] = gem.ComponentTensor( + gem.Indexed( + gem.ListTensor(np.array( + [gem.Indexed(expr, beta) for expr in exprs] + ).reshape(self.value_shape)), + zeta), + beta + zeta + ) else: - point_indices = ps.indices - - point_shape = tuple(index.extent for index in point_indices) - table_shape = index_shape + value_shape + point_shape - table_indices = beta + zeta + point_indices - - expr = gem.Indexed(gem.Literal(fiat_table.reshape(table_shape)), table_indices) - expr = gem.ComponentTensor(expr, result_indices) - result[alpha] = expr + expr, = exprs + result[alpha] = expr return result def point_evaluation(self, order, refcoords, entity=None): diff --git a/finat/physically_mapped.py b/finat/physically_mapped.py index f08234508..4b6c60896 100644 --- a/finat/physically_mapped.py +++ b/finat/physically_mapped.py @@ -1,6 +1,4 @@ from abc import ABCMeta, abstractmethod -from collections.abc import Mapping -from functools import reduce import gem import numpy @@ -249,46 +247,6 @@ class NeedsCoordinateMappingElement(metaclass=ABCMeta): pass -class MappedTabulation(Mapping): - """A lazy tabulation dict that applies the basis transformation only - on the requested derivatives.""" - - def __init__(self, M, ref_tabulation): - self.M = M - self.ref_tabulation = ref_tabulation - # we expect M to be sparse with O(1) nonzeros per row - # for each row, get the column index of each nonzero entry - csr = [[j for j in range(M.shape[1]) if not isinstance(M.array[i, j], gem.Zero)] - for i in range(M.shape[0])] - self.csr = csr - self._tabulation_cache = {} - - def matvec(self, table): - # basis recombination using hand-rolled sparse-dense matrix multiplication - ii = gem.indices(len(table.shape)-1) - phi = [gem.Indexed(table, (j, *ii)) for j in range(self.M.shape[1])] - # the sum approach is faster than calling numpy.dot or gem.IndexSum - exprs = [gem.ComponentTensor(reduce(gem.Sum, (self.M.array[i, j] * phi[j] for j in js)), ii) - for i, js in enumerate(self.csr)] - - val = gem.ListTensor(exprs) - # val = self.M @ table - return gem.optimise.aggressive_unroll(val) - - def __getitem__(self, alpha): - try: - return self._tabulation_cache[alpha] - except KeyError: - result = self.matvec(self.ref_tabulation[alpha]) - return self._tabulation_cache.setdefault(alpha, result) - - def __iter__(self): - return iter(self.ref_tabulation) - - def __len__(self): - return len(self.ref_tabulation) - - class PhysicallyMappedElement(NeedsCoordinateMappingElement): """A mixin that applies a "physical" transformation to tabulated basis functions.""" @@ -309,10 +267,24 @@ def basis_transformation(self, coordinate_mapping): def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): assert coordinate_mapping is not None - ref_tabulation = super().basis_evaluation(order, ps, entity=entity) - M = self.basis_transformation(coordinate_mapping) - return MappedTabulation(M, ref_tabulation) + # we expect M to be sparse with O(1) nonzeros per row + # for each row, get the column index of each nonzero entry + csr = [[j for j in range(M.shape[1]) if not isinstance(M.array[i, j], gem.Zero)] + for i in range(M.shape[0])] + + def matvec(table): + # basis recombination using hand-rolled sparse-dense matrix multiplication + table = [gem.partial_indexed(table, (j,)) for j in range(M.shape[1])] + # the sum approach is faster than calling numpy.dot or gem.IndexSum + expressions = [sum(M.array[i, j] * table[j] for j in js) for i, js in enumerate(csr)] + val = gem.ListTensor(expressions) + return gem.optimise.aggressive_unroll(val) + + result = super().basis_evaluation(order, ps, entity=entity) + + return {alpha: matvec(table) + for alpha, table in result.items()} def point_evaluation(self, order, refcoords, entity=None): raise NotImplementedError("TODO: not yet thought about it") diff --git a/finat/point_set.py b/finat/point_set.py index 3e6737d2b..40ea34de1 100644 --- a/finat/point_set.py +++ b/finat/point_set.py @@ -230,9 +230,6 @@ def __init__(self, cell, ps): self.cell = cell self.ps = ps - def __repr__(self): - return f"{type(self).__name__}({self.ps!r})" - @cached_property def entities(self): to_int = lambda x: sum(x) if isinstance(x, tuple) else x @@ -255,7 +252,7 @@ def indices(self): @cached_property def expression(self): - raise NotImplementedError("The expression for MappedPointSet is not implemented yet.") + raise NotImplementedError("Should not use MappedPointSet like this") def almost_equal(self, other, tolerance=1e-12): """Approximate numerical equality of point sets""" diff --git a/gem/coffee.py b/gem/coffee.py index 12c4d0fcf..f766a4890 100644 --- a/gem/coffee.py +++ b/gem/coffee.py @@ -4,7 +4,8 @@ This file is NOT for code generation as a COFFEE AST. """ -from itertools import chain, repeat +from collections import OrderedDict +import itertools import logging import numpy @@ -57,10 +58,10 @@ def find_optimal_atomics(monomials, linear_indices): :returns: list of atomic GEM expressions """ - atomics = tuple(dict.fromkeys(chain.from_iterable(monomial.atomics for monomial in monomials))) + atomics = tuple(OrderedDict.fromkeys(itertools.chain(*(monomial.atomics for monomial in monomials)))) def cost(solution): - extent = sum(map(index_extent, solution, repeat(linear_indices))) + extent = sum(map(lambda atomic: index_extent(atomic, linear_indices), solution)) # Prefer shorter solutions, but larger extents return (len(solution), -extent) diff --git a/gem/gem.py b/gem/gem.py index b50f78be9..cf223dcb0 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -55,8 +55,8 @@ def __call__(self, *args, **kwargs): # Set free_indices if not set already if not hasattr(obj, 'free_indices'): - obj.free_indices = unique(chain.from_iterable(c.free_indices - for c in obj.children)) + obj.free_indices = unique(chain(*[c.free_indices + for c in obj.children])) # Set dtype if not set already. if not hasattr(obj, 'dtype'): obj.dtype = obj.inherit_dtype_from_children(obj.children) @@ -118,8 +118,9 @@ def __matmul__(self, other): raise ValueError(f"Mismatching shapes {self.shape} and {other.shape} in matmul") *i, k = indices(len(self.shape)) _, *j = indices(len(other.shape)) - expr = Product(Indexed(self, (*i, k)), Indexed(other, (k, *j))) - return ComponentTensor(IndexSum(expr, (k, )), (*i, *j)) + expr = Product(Indexed(self, tuple(i) + (k, )), + Indexed(other, (k, ) + tuple(j))) + return ComponentTensor(IndexSum(expr, (k, )), tuple(i) + tuple(j)) def __rmatmul__(self, other): return as_gem(other).__matmul__(self) @@ -340,7 +341,7 @@ def __new__(cls, *args): return a if isinstance(a, Constant) and isinstance(b, Constant): - return Literal(a.value + b.value, dtype=Node.inherit_dtype_from_children((a, b))) + return Literal(a.value + b.value, dtype=Node.inherit_dtype_from_children([a, b])) self = super(Sum, cls).__new__(cls) self.children = a, b @@ -369,7 +370,7 @@ def __new__(cls, *args): return a if isinstance(a, Constant) and isinstance(b, Constant): - return Literal(a.value * b.value, dtype=Node.inherit_dtype_from_children((a, b))) + return Literal(a.value * b.value, dtype=Node.inherit_dtype_from_children([a, b])) self = super(Product, cls).__new__(cls) self.children = a, b @@ -393,7 +394,7 @@ def __new__(cls, a, b): return a if isinstance(a, Constant) and isinstance(b, Constant): - return Literal(a.value / b.value, dtype=Node.inherit_dtype_from_children((a, b))) + return Literal(a.value / b.value, dtype=Node.inherit_dtype_from_children([a, b])) self = super(Division, cls).__new__(cls) self.children = a, b @@ -406,7 +407,7 @@ class FloorDiv(Scalar): def __new__(cls, a, b): assert not a.shape assert not b.shape - dtype = Node.inherit_dtype_from_children((a, b)) + dtype = Node.inherit_dtype_from_children([a, b]) if dtype != uint_type: raise ValueError(f"dtype ({dtype}) != unit_type ({uint_type})") # Constant folding @@ -429,7 +430,7 @@ class Remainder(Scalar): def __new__(cls, a, b): assert not a.shape assert not b.shape - dtype = Node.inherit_dtype_from_children((a, b)) + dtype = Node.inherit_dtype_from_children([a, b]) if dtype != uint_type: raise ValueError(f"dtype ({dtype}) != uint_type ({uint_type})") # Constant folding @@ -452,7 +453,7 @@ class Power(Scalar): def __new__(cls, base, exponent): assert not base.shape assert not exponent.shape - dtype = Node.inherit_dtype_from_children((base, exponent)) + dtype = Node.inherit_dtype_from_children([base, exponent]) # Constant folding if isinstance(base, Zero): @@ -568,7 +569,7 @@ def __new__(cls, condition, then, else_): self = super(Conditional, cls).__new__(cls) self.children = condition, then, else_ self.shape = then.shape - self.dtype = Node.inherit_dtype_from_children((then, else_)) + self.dtype = Node.inherit_dtype_from_children([then, else_]) return self @@ -684,31 +685,12 @@ def __new__(cls, aggregate, multiindex): if isinstance(aggregate, Zero): return Zero(dtype=aggregate.dtype) - # Simplify Literal and ListTensor - if isinstance(aggregate, (Constant, ListTensor)): - if all(isinstance(i, int) for i in multiindex): - # All indices fixed - sub = aggregate.array[multiindex] - return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub - elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex): - # Some indices fixed - slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex) - sub = aggregate.array[slices] - sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub) - return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int))) - - # Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll) - if isinstance(aggregate, ComponentTensor): - B, = aggregate.children - jj = aggregate.multiindex - if isinstance(B, Indexed): - C, = B.children - kk = B.multiindex - if all(j in kk for j in jj): - ii = tuple(multiindex) - rep = dict(zip(jj, ii)) - ll = tuple(rep.get(k, k) for k in kk) - return Indexed(C, ll) + # All indices fixed + if all(isinstance(i, int) for i in multiindex): + if isinstance(aggregate, Constant): + return Literal(aggregate.array[multiindex], dtype=aggregate.dtype) + elif isinstance(aggregate, ListTensor): + return aggregate.array[multiindex] self = super(Indexed, cls).__new__(cls) self.children = (aggregate,) @@ -854,11 +836,6 @@ def __new__(cls, expression, multiindex): if isinstance(expression, Zero): return Zero(shape, dtype=expression.dtype) - # Index folding - if isinstance(expression, Indexed): - if multiindex == expression.multiindex: - return expression.children[0] - self = super(ComponentTensor, cls).__new__(cls) self.children = (expression,) self.multiindex = multiindex @@ -915,17 +892,9 @@ def __new__(cls, array): dtype = Node.inherit_dtype_from_children(tuple(array.flat)) # Handle children with shape - e0 = array.flat[0] - child_shape = e0.shape + child_shape = array.flat[0].shape assert all(elem.shape == child_shape for elem in array.flat) - # Index folding - if child_shape == array.shape: - if all(isinstance(elem, Indexed) for elem in array.flat): - if all(elem.children == e0.children for elem in array.flat[1:]): - if all(elem.multiindex == idx for elem, idx in zip(array.flat, numpy.ndindex(array.shape))): - return e0.children[0] - if child_shape: # Destroy structure direct_array = numpy.empty(array.shape + child_shape, dtype=object) @@ -963,7 +932,7 @@ def is_equal(self, other): """Common subexpression eliminating equality predicate.""" if type(self) is not type(other): return False - if numpy.array_equal(self.array, other.array): + if (self.array == other.array).all(): self.array = other.array return True return False diff --git a/gem/node.py b/gem/node.py index 71f814638..5d9c5bf04 100644 --- a/gem/node.py +++ b/gem/node.py @@ -3,7 +3,6 @@ import collections import gem -from itertools import repeat class Node(object): @@ -37,17 +36,14 @@ def _cons_args(self, children): Internally used utility function. """ - front_args = (getattr(self, name) for name in self.__front__) - back_args = (getattr(self, name) for name in self.__back__) + front_args = [getattr(self, name) for name in self.__front__] + back_args = [getattr(self, name) for name in self.__back__] - return (*front_args, *children, *back_args) - - def _arguments(self): - return self._cons_args(self.children) + return tuple(front_args) + tuple(children) + tuple(back_args) def __reduce__(self): # Gold version: - return type(self), self._arguments() + return type(self), self._cons_args(self.children) def reconstruct(self, *args): """Reconstructs the node with new children from @@ -58,7 +54,8 @@ def reconstruct(self, *args): return type(self)(*self._cons_args(args)) def __repr__(self): - return "%s(%s)" % (type(self).__name__, ", ".join(map(repr, self._arguments()))) + cons_args = self._cons_args(self.children) + return "%s(%s)" % (type(self).__name__, ", ".join(map(repr, cons_args))) def __eq__(self, other): """Provides equality testing with quick positive and negative @@ -90,7 +87,9 @@ def is_equal(self, other): """ if type(self) is not type(other): return False - return self._arguments() == other._arguments() + self_consargs = self._cons_args(self.children) + other_consargs = other._cons_args(other.children) + return self_consargs == other_consargs def get_hash(self): """Hash function. @@ -98,7 +97,7 @@ def get_hash(self): This is the method to potentially override in derived classes, not :meth:`__hash__`. """ - return hash((type(self), *self._arguments())) + return hash((type(self),) + self._cons_args(self.children)) def _make_traversal_children(node): @@ -236,7 +235,8 @@ def __call__(self, node): return self.cache[node] except KeyError: result = self.function(node, self) - return self.cache.setdefault(node, result) + self.cache[node] = result + return result class MemoizerArg(object): @@ -259,13 +259,14 @@ def __call__(self, node, arg): return self.cache[cache_key] except KeyError: result = self.function(node, self, arg) - return self.cache.setdefault(cache_key, result) + self.cache[cache_key] = result + return result def reuse_if_untouched(node, self): """Reuse if untouched recipe""" - new_children = tuple(map(self, node.children)) - if new_children == node.children: + new_children = list(map(self, node.children)) + if all(nc == c for nc, c in zip(new_children, node.children)): return node else: return node.reconstruct(*new_children) @@ -273,8 +274,8 @@ def reuse_if_untouched(node, self): def reuse_if_untouched_arg(node, self, arg): """Reuse if touched recipe propagating an extra argument""" - new_children = tuple(map(self, node.children, repeat(arg))) - if new_children == node.children: + new_children = [self(child, arg) for child in node.children] + if all(nc == c for nc, c in zip(new_children, node.children)): return node else: return node.reconstruct(*new_children) diff --git a/gem/optimise.py b/gem/optimise.py index ecbb671d4..c56ba7fbb 100644 --- a/gem/optimise.py +++ b/gem/optimise.py @@ -101,7 +101,8 @@ def _replace_indices_atomic(i, self, subst): new_expr = self(i.expression, subst) return i if new_expr == i.expression else VariableIndex(new_expr) else: - return dict(subst).get(i, i) + substitute = dict(subst) + return substitute.get(i, i) @replace_indices.register(Delta) @@ -116,18 +117,20 @@ def replace_indices_delta(node, self, subst): @replace_indices.register(Indexed) def replace_indices_indexed(node, self, subst): - multiindex = tuple(_replace_indices_atomic(i, self, subst) for i in node.multiindex) child, = node.children + substitute = dict(subst) + multiindex = [] + for i in node.multiindex: + multiindex.append(_replace_indices_atomic(i, self, subst)) if isinstance(child, ComponentTensor): # Indexing into ComponentTensor # Inline ComponentTensor and augment the substitution rules - substitute = dict(subst) substitute.update(zip(child.multiindex, multiindex)) return self(child.children[0], tuple(sorted(substitute.items()))) else: # Replace indices new_child = self(child, subst) - if multiindex == node.multiindex and new_child == child: + if new_child == child and multiindex == node.multiindex: return node else: return Indexed(new_child, multiindex) @@ -135,6 +138,9 @@ def replace_indices_indexed(node, self, subst): @replace_indices.register(FlexiblyIndexed) def replace_indices_flexiblyindexed(node, self, subst): + child, = node.children + assert not child.free_indices + dim2idxs = tuple( ( offset if isinstance(offset, Integral) else _replace_indices_atomic(offset, self, subst), @@ -143,8 +149,6 @@ def replace_indices_flexiblyindexed(node, self, subst): for offset, idxs in node.dim2idxs ) - child, = node.children - assert not child.free_indices if dim2idxs == node.dim2idxs: return node else: @@ -188,7 +192,7 @@ def _constant_fold_zero_listtensor(node, self): new_children = list(map(self, node.children)) if all(isinstance(nc, Zero) for nc in new_children): return Zero(node.shape) - elif new_children == node.children: + elif all(nc == c for nc, c in zip(new_children, node.children)): return node else: return node.reconstruct(*new_children) @@ -207,7 +211,7 @@ def constant_fold_zero(exprs): otherwise Literal `0`s would be reintroduced. """ mapper = Memoizer(_constant_fold_zero) - return list(map(mapper, exprs)) + return [mapper(e) for e in exprs] def _select_expression(expressions, index): @@ -252,9 +256,9 @@ def child(expression): assert all(len(e.children) == len(expr.children) for e in expressions) assert len(expr.children) > 0 - return expr.reconstruct(*(_select_expression(nth_children, index) - for nth_children in zip(*(e.children - for e in expressions)))) + return expr.reconstruct(*[_select_expression(nth_children, index) + for nth_children in zip(*[e.children + for e in expressions])]) raise NotImplementedError("No rule for factorising expressions of this kind.") diff --git a/test/finat/test_zany_mapping.py b/test/finat/test_zany_mapping.py index 6406dbe9b..d5136e651 100644 --- a/test/finat/test_zany_mapping.py +++ b/test/finat/test_zany_mapping.py @@ -3,7 +3,6 @@ import numpy as np import pytest from gem.interpreter import evaluate -from finat.physically_mapped import PhysicallyMappedElement def make_unisolvent_points(element, interior=False): @@ -66,11 +65,11 @@ def check_zany_mapping(element, ref_to_phys, *args, **kwargs): # Zany map the results num_bfs = phys_element.space_dimension() num_dofs = finat_element.space_dimension() - if isinstance(finat_element, PhysicallyMappedElement): + try: Mgem = finat_element.basis_transformation(ref_to_phys) M = evaluate([Mgem])[0].arr ref_vals_zany = np.tensordot(M, ref_vals_piola, (-1, 0)) - else: + except AttributeError: M = np.eye(num_dofs, num_bfs) ref_vals_zany = ref_vals_piola From 0ae1e196ddf84f747a4f88e0314957782f2ff511 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 18 Mar 2025 15:09:22 +0000 Subject: [PATCH 17/24] fix Delta(int, Index) --- finat/point_set.py | 3 +++ gem/gem.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/finat/point_set.py b/finat/point_set.py index 40ea34de1..07397461d 100644 --- a/finat/point_set.py +++ b/finat/point_set.py @@ -230,6 +230,9 @@ def __init__(self, cell, ps): self.cell = cell self.ps = ps + def __repr__(self): + return f"{type(self).__name__}({self.ps!r})" + @cached_property def entities(self): to_int = lambda x: sum(x) if isinstance(x, tuple) else x diff --git a/gem/gem.py b/gem/gem.py index cf223dcb0..5e653e7b1 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -986,10 +986,10 @@ def __new__(cls, i, j, dtype=None): return one if i == j else Zero() if isinstance(i, Integral): - return Indexed(Identity(j.extent), (i, j)) + return Indexed(Literal(numpy.eye(j.extent)[i]), (j,)) if isinstance(j, Integral): - return Indexed(Identity(i.extent), (i, j)) + return Indexed(Literal(numpy.eye(i.extent)[j]), (i,)) self = super(Delta, cls).__new__(cls) self.i = i From b9af43ec114dc7893591f7a80ce52cf31e70df5d Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 24 Mar 2025 10:44:30 +0000 Subject: [PATCH 18/24] Docs --- finat/point_set.py | 12 ++++++++++-- finat/quadrature_element.py | 16 ++++++++-------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/finat/point_set.py b/finat/point_set.py index 07397461d..76254a0da 100644 --- a/finat/point_set.py +++ b/finat/point_set.py @@ -224,8 +224,13 @@ def almost_equal(self, other, tolerance=1e-12): for s, o in zip(self.factors, other.factors)) -class MappedPointSet(AbstractPointSet): +class FacetPointSet(AbstractPointSet): + """A point set constructed by mapping a lower-dimensional point set onto all facets + of a higher-dimensional cell. + :arg cell: The FIAT.Cell.` + :arg ps: The reference PointSet. + """ def __init__(self, cell, ps): self.cell = cell self.ps = ps @@ -235,6 +240,7 @@ def __repr__(self): @cached_property def entities(self): + """The list of all cell entites matching the reference point set dimension.""" to_int = lambda x: sum(x) if isinstance(x, tuple) else x top = self.cell.topology return [(dim, entity) @@ -244,6 +250,7 @@ def entities(self): @cached_property def points(self): + """The array with the reference points mapped onto each facet.""" ref_pts = self.ps.points pts = [self.cell.get_entity_transform(dim, entity)(ref_pts) for dim, entity in self.entities] @@ -251,11 +258,12 @@ def points(self): @cached_property def indices(self): + """An Index tuple of the facet index and the reference point indices.""" return (gem.Index(extent=len(self.entities)), *self.ps.indices) @cached_property def expression(self): - raise NotImplementedError("Should not use MappedPointSet like this") + raise NotImplementedError("Symbolic point expression is not yet implemented for FacetPointSet.") def almost_equal(self, other, tolerance=1e-12): """Approximate numerical equality of point sets""" diff --git a/finat/quadrature_element.py b/finat/quadrature_element.py index 4e5c570ca..eedbe67b6 100644 --- a/finat/quadrature_element.py +++ b/finat/quadrature_element.py @@ -1,4 +1,4 @@ -from finat.point_set import UnknownPointSet, MappedPointSet +from finat.point_set import UnknownPointSet, FacetPointSet import numpy @@ -92,7 +92,7 @@ def space_dimension(self): def _point_set(self): ps = self._rule.point_set sd = self.cell.get_spatial_dimension() - return ps if ps.dimension == sd else MappedPointSet(self.cell, ps) + return ps if ps.dimension == sd else FacetPointSet(self.cell, ps) @property def index_shape(self): @@ -140,14 +140,14 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): raise ValueError("Mismatch of quadrature points!") # Return an outer product of identity matrices - multiindex = self.get_indices() - fid = ps.indices - if len(multiindex) > len(fid): - fid = (entity_id, *fid) - product = gem.Delta(fid, multiindex) + basis_indices = self.get_indices() + point_indices = ps.indices + if len(basis_indices) > len(point_indices): + point_indices = (entity_id, *point_indices) + delta = gem.Delta(point_indices, basis_indices) sd = self.cell.get_spatial_dimension() - return {(0,) * sd: gem.ComponentTensor(product, multiindex)} + return {(0,) * sd: gem.ComponentTensor(delta, basis_indices)} def point_evaluation(self, order, refcoords, entity=None): raise NotImplementedError("QuadratureElement cannot do point evaluation!") From b2e85639a12f0ec89a74468eb8174b7045e24ac5 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 26 Mar 2025 15:13:35 +0000 Subject: [PATCH 19/24] Docs --- finat/point_set.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/finat/point_set.py b/finat/point_set.py index 76254a0da..5e7758f8e 100644 --- a/finat/point_set.py +++ b/finat/point_set.py @@ -225,11 +225,17 @@ def almost_equal(self, other, tolerance=1e-12): class FacetPointSet(AbstractPointSet): - """A point set constructed by mapping a lower-dimensional point set onto all facets - of a higher-dimensional cell. + """A point set on facets. - :arg cell: The FIAT.Cell.` - :arg ps: The reference PointSet. + A FacetPointSet is constructed by mapping a lower-dimensional PointSet + onto all facets of a higher-dimensional cell. + + Parameters + ---------- + cell : FIAT.Cell + The cell. + ps : PointSet + A lower-dimensional point set. """ def __init__(self, cell, ps): self.cell = cell From b50031e1ff25189973b19e9344c7c33ed99f3bea Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 27 Mar 2025 07:57:58 +0000 Subject: [PATCH 20/24] lint --- finat/point_set.py | 1 + 1 file changed, 1 insertion(+) diff --git a/finat/point_set.py b/finat/point_set.py index 5e7758f8e..ea14392bb 100644 --- a/finat/point_set.py +++ b/finat/point_set.py @@ -236,6 +236,7 @@ class FacetPointSet(AbstractPointSet): The cell. ps : PointSet A lower-dimensional point set. + """ def __init__(self, cell, ps): self.cell = cell From be58cbf6d1278ca08463536f4fcbd5f3f3e8a18d Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 27 Mar 2025 08:16:33 +0000 Subject: [PATCH 21/24] lint --- finat/point_set.py | 1 - 1 file changed, 1 deletion(-) diff --git a/finat/point_set.py b/finat/point_set.py index ea14392bb..5e7758f8e 100644 --- a/finat/point_set.py +++ b/finat/point_set.py @@ -236,7 +236,6 @@ class FacetPointSet(AbstractPointSet): The cell. ps : PointSet A lower-dimensional point set. - """ def __init__(self, cell, ps): self.cell = cell From c52b02403d3e83ad0a1402d487cc68ce765d0ddc Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 27 Mar 2025 08:35:11 +0000 Subject: [PATCH 22/24] lint --- finat/point_set.py | 1 + 1 file changed, 1 insertion(+) diff --git a/finat/point_set.py b/finat/point_set.py index 5e7758f8e..9e91faae4 100644 --- a/finat/point_set.py +++ b/finat/point_set.py @@ -236,6 +236,7 @@ class FacetPointSet(AbstractPointSet): The cell. ps : PointSet A lower-dimensional point set. + """ def __init__(self, cell, ps): self.cell = cell From 45a9f15d37d410beb9dcfa2d877ced2a34e6c0e4 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 27 Mar 2025 17:08:36 +0000 Subject: [PATCH 23/24] Apply suggestions from code review Co-authored-by: India Marsden <37078108+indiamai@users.noreply.github.com> --- finat/quadrature_element.py | 2 +- finat/tensor_product.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/finat/quadrature_element.py b/finat/quadrature_element.py index eedbe67b6..c7b784688 100644 --- a/finat/quadrature_element.py +++ b/finat/quadrature_element.py @@ -25,7 +25,7 @@ def make_quadrature_element(fiat_ref_cell, degree, scheme="default", codim=0): :param codim: The codimension of the quadrature scheme. :returns: The appropriate :class:`QuadratureElement` """ - if codim: + if codim > 0: sd = fiat_ref_cell.get_spatial_dimension() rule_ref_cell = fiat_ref_cell.construct_subelement(sd - codim) else: diff --git a/finat/tensor_product.py b/finat/tensor_product.py index 46365f366..0d49a7c3c 100644 --- a/finat/tensor_product.py +++ b/finat/tensor_product.py @@ -176,7 +176,7 @@ def dual_basis(self): zetas = [factor.get_value_indices() for factor in self.factors] Q = gem.ComponentTensor( # Index the factors by so that we can reshape into index-shape - # followed by value-shape + # into index-shape followed by value-shape gem.Product(*(q[alpha + zeta] for q, alpha, zeta in zip(qs, alphas, zetas))), tuple(chain(*alphas, *zetas)) ) From f187e6a8d73695755b6fd2c32d2243fa839f35f6 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 27 Mar 2025 17:10:24 +0000 Subject: [PATCH 24/24] Update finat/tensor_product.py --- finat/tensor_product.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/finat/tensor_product.py b/finat/tensor_product.py index 0d49a7c3c..959213d0e 100644 --- a/finat/tensor_product.py +++ b/finat/tensor_product.py @@ -175,7 +175,7 @@ def dual_basis(self): alphas = [factor.get_indices() for factor in self.factors] zetas = [factor.get_value_indices() for factor in self.factors] Q = gem.ComponentTensor( - # Index the factors by so that we can reshape into index-shape + # Index the factors by basis function and component so that we can reshape # into index-shape followed by value-shape gem.Product(*(q[alpha + zeta] for q, alpha, zeta in zip(qs, alphas, zetas))), tuple(chain(*alphas, *zetas))