From 35ef7554ddd16ccab72e77a23381e636a893414e Mon Sep 17 00:00:00 2001 From: Rob Kirby Date: Thu, 27 Feb 2020 11:18:56 -0600 Subject: [PATCH 1/7] Add interface to trimmed serendipity --- tsfc/fiatinterface.py | 2 ++ tsfc/finatinterface.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tsfc/fiatinterface.py b/tsfc/fiatinterface.py index c35171f4..21998cff 100644 --- a/tsfc/fiatinterface.py +++ b/tsfc/fiatinterface.py @@ -62,6 +62,8 @@ "NCF": None, "DPC": FIAT.DPC, "S": FIAT.Serendipity, + "SminusFace": FIAT.TrimmedSerendipityFace, + "SminusEdge": FIAT.TrimmedSerendipityEdge, "DPC L2": FIAT.DPC, "Discontinuous Lagrange L2": FIAT.DiscontinuousLagrange, "Gauss-Legendre L2": FIAT.GaussLegendre, diff --git a/tsfc/finatinterface.py b/tsfc/finatinterface.py index 84c9fddf..10643c47 100644 --- a/tsfc/finatinterface.py +++ b/tsfc/finatinterface.py @@ -67,6 +67,8 @@ "Real": finat.DiscontinuousLagrange, "DPC": finat.DPC, "S": finat.Serendipity, + "SminusFace": finat.TrimmedSerendipityFace, + "SminusEdge": finat.TrimmedSerendipityEdge, "DPC L2": finat.DPC, "Discontinuous Lagrange L2": finat.DiscontinuousLagrange, "Gauss-Legendre L2": finat.GaussLegendre, From 7c17027ef5e5b51ff79296ac6f09d65b4cd05a5f Mon Sep 17 00:00:00 2001 From: Rob Kirby Date: Thu, 27 Feb 2020 12:43:04 -0600 Subject: [PATCH 2/7] Fix names --- tsfc/fiatinterface.py | 4 ++-- tsfc/finatinterface.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tsfc/fiatinterface.py b/tsfc/fiatinterface.py index 21998cff..909ece42 100644 --- a/tsfc/fiatinterface.py +++ b/tsfc/fiatinterface.py @@ -62,8 +62,8 @@ "NCF": None, "DPC": FIAT.DPC, "S": FIAT.Serendipity, - "SminusFace": FIAT.TrimmedSerendipityFace, - "SminusEdge": FIAT.TrimmedSerendipityEdge, + "SminusF": FIAT.TrimmedSerendipityFace, + "SminusE": FIAT.TrimmedSerendipityEdge, "DPC L2": FIAT.DPC, "Discontinuous Lagrange L2": FIAT.DiscontinuousLagrange, "Gauss-Legendre L2": FIAT.GaussLegendre, diff --git a/tsfc/finatinterface.py b/tsfc/finatinterface.py index 10643c47..a84fe662 100644 --- a/tsfc/finatinterface.py +++ b/tsfc/finatinterface.py @@ -67,8 +67,8 @@ "Real": finat.DiscontinuousLagrange, "DPC": finat.DPC, "S": finat.Serendipity, - "SminusFace": finat.TrimmedSerendipityFace, - "SminusEdge": finat.TrimmedSerendipityEdge, + "SminusF": finat.TrimmedSerendipityFace, + "SminusE": finat.TrimmedSerendipityEdge, "DPC L2": finat.DPC, "Discontinuous Lagrange L2": finat.DiscontinuousLagrange, "Gauss-Legendre L2": finat.GaussLegendre, From 3a353342367edda476ba527367a8c32a5c44c64d Mon Sep 17 00:00:00 2001 From: Justincrum Date: Tue, 14 Apr 2020 11:27:05 -0700 Subject: [PATCH 3/7] Plumbing for SminusDiv.py --- tsfc/fiatinterface.py | 1 + tsfc/finatinterface.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tsfc/fiatinterface.py b/tsfc/fiatinterface.py index 909ece42..993bb1bf 100644 --- a/tsfc/fiatinterface.py +++ b/tsfc/fiatinterface.py @@ -63,6 +63,7 @@ "DPC": FIAT.DPC, "S": FIAT.Serendipity, "SminusF": FIAT.TrimmedSerendipityFace, + "SminusDiv": FIAT.TrimmedSerendipityDiv, "SminusE": FIAT.TrimmedSerendipityEdge, "DPC L2": FIAT.DPC, "Discontinuous Lagrange L2": FIAT.DiscontinuousLagrange, diff --git a/tsfc/finatinterface.py b/tsfc/finatinterface.py index a84fe662..b4c0e749 100644 --- a/tsfc/finatinterface.py +++ b/tsfc/finatinterface.py @@ -68,6 +68,7 @@ "DPC": finat.DPC, "S": finat.Serendipity, "SminusF": finat.TrimmedSerendipityFace, + "SminusDiv": finat.TrimmedSerendipityDiv, "SminusE": finat.TrimmedSerendipityEdge, "DPC L2": finat.DPC, "Discontinuous Lagrange L2": finat.DiscontinuousLagrange, From 7f0a9c334f22fb67ab1f14813ee879d7c21fd81b Mon Sep 17 00:00:00 2001 From: Justincrum Date: Mon, 24 Aug 2020 12:42:01 -0700 Subject: [PATCH 4/7] Fixing an update for the init file. --- tsfc/__init__.py | 3 +- tsfc/driver.py | 138 ++++++++++++++++++++++++++++++++++------------ tsfc/ufl_utils.py | 94 +++++++++++++++++++++++++++++-- 3 files changed, 193 insertions(+), 42 deletions(-) diff --git a/tsfc/__init__.py b/tsfc/__init__.py index e69abac5..967457bd 100644 --- a/tsfc/__init__.py +++ b/tsfc/__init__.py @@ -1,4 +1,5 @@ -from tsfc.driver import compile_form, compile_expression_at_points # noqa: F401 +#from tsfc.driver import compile_form, compile_expression_at_points # noqa: F401 +from tsfc.driver import compile_form, compile_expression_dual_evaluation #noqa F401 from tsfc.parameters import default_parameters # noqa: F401 try: diff --git a/tsfc/driver.py b/tsfc/driver.py index 7297d19c..0fd79b4e 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -6,7 +6,7 @@ from functools import reduce from itertools import chain -from numpy import asarray +from numpy import asarray, allclose import ufl from ufl.algorithms import extract_arguments, extract_coefficients @@ -18,15 +18,18 @@ import gem import gem.impero_utils as impero_utils +import FIAT from FIAT.reference_element import TensorProductCell +from FIAT.functional import PointEvaluation from finat.point_set import PointSet -from finat.quadrature import AbstractQuadratureRule, make_quadrature +from finat.quadrature import AbstractQuadratureRule, make_quadrature, QuadratureRule from tsfc import fem, ufl_utils from tsfc.fiatinterface import as_fiat_cell from tsfc.logging import logger from tsfc.parameters import default_parameters, is_complex +from tsfc.ufl_utils import apply_mapping # To handle big forms. The various transformations might need a deeper stack sys.setrecursionlimit(3000) @@ -47,14 +50,7 @@ def compile_form(form, prefix="form", parameters=None, interface=None, coffee=Tr assert isinstance(form, Form) # Determine whether in complex mode: - # complex nodes would break the refactoriser. complex_mode = parameters and is_complex(parameters.get("scalar_type")) - if complex_mode: - logger.warning("Disabling whole expression optimisations" - " in GEM for supporting complex mode.") - parameters = parameters.copy() - parameters["mode"] = 'vanilla' - fd = ufl_utils.compute_form_data(form, complex_mode=complex_mode) logger.info(GREEN % "compute_form_data finished in %g seconds.", time.time() - cpu_time) @@ -95,6 +91,10 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, co # Delayed import, loopy is a runtime dependency import tsfc.kernel_interface.firedrake_loopy as firedrake_interface_loopy interface = firedrake_interface_loopy.KernelBuilder + if coffee: + scalar_type = parameters["scalar_type_c"] + else: + scalar_type = parameters["scalar_type"] # Remove these here, they're handled below. if parameters.get("quadrature_degree") in ["auto", "default", None, -1, "-1"]: @@ -120,7 +120,7 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, co domain_numbering = form_data.original_form.domain_numbering() builder = interface(integral_type, integral_data.subdomain_id, domain_numbering[integral_data.domain], - parameters["scalar_type"], + scalar_type, diagonal=diagonal) argument_multiindices = tuple(builder.create_element(arg.ufl_element()).get_indices() for arg in arguments) @@ -151,11 +151,11 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, co kernel_cfg = dict(interface=builder, ufl_cell=cell, integral_type=integral_type, - precision=parameters["precision"], integration_dim=integration_dim, entity_ids=entity_ids, argument_multiindices=argument_multiindices, - index_cache=index_cache) + index_cache=index_cache, + scalar_type=parameters["scalar_type"]) mode_irs = collections.OrderedDict() for integral in integral_data.integrals: @@ -264,24 +264,28 @@ def name_multiindex(multiindex, name): for multiindex, name in zip(argument_multiindices, ['j', 'k']): name_multiindex(multiindex, name) - return builder.construct_kernel(kernel_name, impero_c, parameters["precision"], index_names, quad_rule) + return builder.construct_kernel(kernel_name, impero_c, index_names, quad_rule) -def compile_expression_at_points(expression, points, coordinates, interface=None, - parameters=None, coffee=True): - """Compiles a UFL expression to be evaluated at compile-time known - reference points. Useful for interpolating UFL expressions onto - function spaces with only point evaluation nodes. +def compile_expression_dual_evaluation(expression, to_element, coordinates, *, + domain=None, interface=None, + parameters=None, coffee=False): + """Compile a UFL expression to be evaluated against a compile-time known reference element's dual basis. + + Useful for interpolating UFL expressions into e.g. N1curl spaces. :arg expression: UFL expression - :arg points: reference coordinates of the evaluation points + :arg to_element: A FIAT FiniteElement for the target space :arg coordinates: the coordinate function + :arg domain: optional UFL domain the expression is defined on (useful when expression contains no domain). :arg interface: backend module for the kernel interface :arg parameters: parameters object :arg coffee: compile coffee kernel instead of loopy kernel """ import coffee.base as ast import loopy as lp + if any(len(dual.deriv_dict) != 0 for dual in to_element.dual_basis()): + raise NotImplementedError("Can only interpolate onto dual basis functionals without derivative evaluation, sorry!") if parameters is None: parameters = default_parameters() @@ -293,6 +297,13 @@ def compile_expression_at_points(expression, points, coordinates, interface=None # Determine whether in complex mode complex_mode = is_complex(parameters["scalar_type"]) + # Find out which mapping to apply + try: + mapping, = set(to_element.mapping()) + except ValueError: + raise NotImplementedError("Don't know how to interpolate onto zany spaces, sorry") + expression = apply_mapping(expression, mapping, domain) + # Apply UFL preprocessing expression = ufl_utils.preprocess_expression(expression, complex_mode=complex_mode) @@ -329,22 +340,75 @@ def compile_expression_at_points(expression, points, coordinates, interface=None expression = ufl_utils.split_coefficients(expression, builder.coefficient_split) # Translate to GEM - point_set = PointSet(points) - config = dict(interface=builder, - ufl_cell=coordinates.ufl_domain().ufl_cell(), - precision=parameters["precision"], - point_set=point_set, - argument_multiindices=argument_multiindices) - ir, = fem.compile_ufl(expression, point_sum=False, **config) - - # Deal with non-scalar expressions - value_shape = ir.shape - tensor_indices = tuple(gem.Index() for s in value_shape) - if value_shape: - ir = gem.Indexed(ir, tensor_indices) + kernel_cfg = dict(interface=builder, + ufl_cell=coordinates.ufl_domain().ufl_cell(), + argument_multiindices=argument_multiindices, + index_cache={}, + scalar_type=parameters["scalar_type"]) + + if all(isinstance(dual, PointEvaluation) for dual in to_element.dual_basis()): + # This is an optimisation for point-evaluation nodes which + # should go away once FInAT offers the interface properly + qpoints = [] + # Everything is just a point evaluation. + for dual in to_element.dual_basis(): + ptdict = dual.get_point_dict() + qpoint, = ptdict.keys() + (qweight, component), = ptdict[qpoint] + assert allclose(qweight, 1.0) + assert component == () + qpoints.append(qpoint) + point_set = PointSet(qpoints) + config = kernel_cfg.copy() + config.update(point_set=point_set) + + # Allow interpolation onto QuadratureElements to refer to the quadrature + # rule they represent + if isinstance(to_element, FIAT.QuadratureElement): + assert allclose(asarray(qpoints), asarray(to_element._points)) + quad_rule = QuadratureRule(point_set, to_element._weights) + config["quadrature_rule"] = quad_rule + + expr, = fem.compile_ufl(expression, **config, point_sum=False) + shape_indices = tuple(gem.Index() for _ in expr.shape) + basis_indices = point_set.indices + ir = gem.Indexed(expr, shape_indices) + else: + # This is general code but is more unrolled than necssary. + dual_expressions = [] # one for each functional + broadcast_shape = len(expression.ufl_shape) - len(to_element.value_shape()) + shape_indices = tuple(gem.Index() for _ in expression.ufl_shape[:broadcast_shape]) + expr_cache = {} # Sharing of evaluation of the expression at points + for dual in to_element.dual_basis(): + pts = tuple(sorted(dual.get_point_dict().keys())) + try: + expr, point_set = expr_cache[pts] + except KeyError: + point_set = PointSet(pts) + config = kernel_cfg.copy() + config.update(point_set=point_set) + expr, = fem.compile_ufl(expression, **config, point_sum=False) + expr = gem.partial_indexed(expr, shape_indices) + expr_cache[pts] = expr, point_set + weights = collections.defaultdict(list) + for p in pts: + for (w, cmp) in dual.get_point_dict()[p]: + weights[cmp].append(w) + qexprs = gem.Zero() + for cmp in sorted(weights): + qweights = gem.Literal(weights[cmp]) + qexpr = gem.Indexed(expr, cmp) + qexpr = gem.index_sum(gem.Indexed(qweights, point_set.indices)*qexpr, + point_set.indices) + qexprs = gem.Sum(qexprs, qexpr) + assert qexprs.shape == () + assert set(qexprs.free_indices) == set(chain(shape_indices, *argument_multiindices)) + dual_expressions.append(qexprs) + basis_indices = (gem.Index(), ) + ir = gem.Indexed(gem.ListTensor(dual_expressions), basis_indices) # Build kernel body - return_indices = point_set.indices + tensor_indices + tuple(chain(*argument_multiindices)) + return_indices = basis_indices + shape_indices + tuple(chain(*argument_multiindices)) return_shape = tuple(i.extent for i in return_indices) return_var = gem.Variable('A', return_shape) if coffee: @@ -353,14 +417,16 @@ def compile_expression_at_points(expression, points, coordinates, interface=None return_arg = lp.GlobalArg("A", dtype=parameters["scalar_type"], shape=return_shape) return_expr = gem.Indexed(return_var, return_indices) + + # TODO: one should apply some GEM optimisations as in assembly, + # but we don't for now. ir, = impero_utils.preprocess_gem([ir]) impero_c = impero_utils.compile_gem([(return_expr, ir)], return_indices) - point_index, = point_set.indices - + index_names = dict((idx, "p%d" % i) for (i, idx) in enumerate(basis_indices)) # Handle kernel interface requirements builder.register_requirements([ir]) # Build kernel tuple - return builder.construct_kernel(return_arg, impero_c, parameters["precision"], {point_index: 'p'}) + return builder.construct_kernel(return_arg, impero_c, index_names) def lower_integral_type(fiat_cell, integral_type): diff --git a/tsfc/ufl_utils.py b/tsfc/ufl_utils.py index 888ac028..634b9d64 100644 --- a/tsfc/ufl_utils.py +++ b/tsfc/ufl_utils.py @@ -18,9 +18,10 @@ from ufl.corealg.map_dag import map_expr_dag from ufl.corealg.multifunction import MultiFunction from ufl.geometry import QuadratureWeight +from ufl.geometry import Jacobian, JacobianDeterminant, JacobianInverse from ufl.classes import (Abs, Argument, CellOrientation, Coefficient, ComponentTensor, Expr, FloatValue, Division, - MixedElement, MultiIndex, Product, + Indexed, MixedElement, MultiIndex, Product, ScalarValue, Sqrt, Zero, CellVolume, FacetArea) from gem.node import MemoizerArg @@ -272,8 +273,11 @@ def _simplify_abs_expr(o, self, in_abs): @_simplify_abs.register(Sqrt) def _simplify_abs_sqrt(o, self, in_abs): - # Square root is always non-negative - return ufl_reuse_if_untouched(o, self(o.ufl_operands[0], False)) + result = ufl_reuse_if_untouched(o, self(o.ufl_operands[0], False)) + if self.complex_mode and in_abs: + return Abs(result) + else: + return result @_simplify_abs.register(ScalarValue) @@ -325,8 +329,88 @@ def _simplify_abs_abs(o, self, in_abs): return self(o.ufl_operands[0], True) -def simplify_abs(expression): +def simplify_abs(expression, complex_mode): """Simplify absolute values in a UFL expression. Its primary purpose is to "neutralise" CellOrientation nodes that are surrounded by absolute values and thus not at all necessary.""" - return MemoizerArg(_simplify_abs)(expression, False) + mapper = MemoizerArg(_simplify_abs) + mapper.complex_mode = complex_mode + return mapper(expression, False) + + +def apply_mapping(expression, mapping, domain): + """ + This applies the appropriate transformation to the + given expression for interpolation to a specific + element, according to the manner in which it maps + from the reference cell. + + The following is borrowed from the UFC documentation: + + Let g be a field defined on a physical domain T with physical + coordinates x. Let T_0 be a reference domain with coordinates + X. Assume that F: T_0 -> T such that + + x = F(X) + + Let J be the Jacobian of F, i.e J = dx/dX and let K denote the + inverse of the Jacobian K = J^{-1}. Then we (currently) have the + following four types of mappings: + + 'affine' mapping for g: + + G(X) = g(x) + + For vector fields g: + + 'contravariant piola' mapping for g: + + G(X) = det(J) K g(x) i.e G_i(X) = det(J) K_ij g_j(x) + + 'covariant piola' mapping for g: + + G(X) = J^T g(x) i.e G_i(X) = J^T_ij g(x) = J_ji g_j(x) + + 'double covariant piola' mapping for g: + + G(X) = J^T g(x) J i.e. G_il(X) = J_ji g_jk(x) J_kl + + 'double contravariant piola' mapping for g: + + G(X) = det(J)^2 K g(x) K^T i.e. G_il(X)=(detJ)^2 K_ij g_jk K_lk + + If 'contravariant piola' or 'covariant piola' are applied to a + matrix-valued function, the appropriate mappings are applied row-by-row. + + :arg expression: UFL expression + :arg mapping: a string indicating the mapping to apply + """ + + mesh = expression.ufl_domain() + if mesh is None: + mesh = domain + if domain is not None and mesh != domain: + raise NotImplementedError("Multiple domains not supported") + rank = len(expression.ufl_shape) + if mapping == "affine": + return expression + elif mapping == "covariant piola": + J = Jacobian(mesh) + *i, j, k = indices(len(expression.ufl_shape) + 1) + expression = Indexed(expression, MultiIndex((*i, k))) + return as_tensor(J.T[j, k] * expression, (*i, j)) + elif mapping == "contravariant piola": + K = JacobianInverse(mesh) + detJ = JacobianDeterminant(mesh) + *i, j, k = indices(len(expression.ufl_shape) + 1) + expression = Indexed(expression, MultiIndex((*i, k))) + return as_tensor(detJ * K[j, k] * expression, (*i, j)) + elif mapping == "double covariant piola" and rank == 2: + J = Jacobian(mesh) + return J.T * expression * J + elif mapping == "double contravariant piola" and rank == 2: + K = JacobianInverse(mesh) + detJ = JacobianDeterminant(mesh) + return (detJ)**2 * K * expression * K.T + else: + raise NotImplementedError("Don't know how to handle mapping type %s for expression of rank %d" % (mapping, rank)) From 4530fbd7ca2eec9de15f73eff58dc11fc8bb54cb Mon Sep 17 00:00:00 2001 From: Justincrum Date: Tue, 8 Sep 2020 13:24:03 -0700 Subject: [PATCH 5/7] Changes to get SminusCurl implemented. --- tsfc/fiatinterface.py | 1 + tsfc/finatinterface.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tsfc/fiatinterface.py b/tsfc/fiatinterface.py index 3935ffab..b8b0b8d2 100644 --- a/tsfc/fiatinterface.py +++ b/tsfc/fiatinterface.py @@ -65,6 +65,7 @@ "SminusF": FIAT.TrimmedSerendipityFace, "SminusDiv": FIAT.TrimmedSerendipityDiv, "SminusE": FIAT.TrimmedSerendipityEdge, + "SminusCurl": FIAT.TrimmedSerendipityCurl, "DPC L2": FIAT.DPC, "Discontinuous Lagrange L2": FIAT.DiscontinuousLagrange, "Gauss-Legendre L2": FIAT.GaussLegendre, diff --git a/tsfc/finatinterface.py b/tsfc/finatinterface.py index facccf8a..5429a1bc 100644 --- a/tsfc/finatinterface.py +++ b/tsfc/finatinterface.py @@ -70,6 +70,7 @@ "SminusF": finat.TrimmedSerendipityFace, "SminusDiv": finat.TrimmedSerendipityDiv, "SminusE": finat.TrimmedSerendipityEdge, + "SminusCurl": finat.TrimmedSerendipityCurl, "DPC L2": finat.DPC, "Discontinuous Lagrange L2": finat.DiscontinuousLagrange, "Gauss-Legendre L2": finat.GaussLegendre, From 701225015f12b87a006fb3bd16fccc4b094a38ba Mon Sep 17 00:00:00 2001 From: Justincrum Date: Fri, 11 Sep 2020 11:14:44 -0700 Subject: [PATCH 6/7] More fixes from over the summer. --- tsfc/finatinterface.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tsfc/finatinterface.py b/tsfc/finatinterface.py index 5429a1bc..8932904c 100644 --- a/tsfc/finatinterface.py +++ b/tsfc/finatinterface.py @@ -31,7 +31,7 @@ from tsfc.fiatinterface import as_fiat_cell -__all__ = ("create_element", "supported_elements", "as_fiat_cell") +__all__ = ("create_element", "supported_elements", "as_fiat_cell", "create_base_element") supported_elements = { @@ -308,3 +308,14 @@ def _create_element(ufl_element, **kwargs): # Forward result return finat_element, deps + + +def create_base_element(ufl_element, **kwargs): + """Create a "scalar" base FInAT element given a UFL element. + Takes a UFL element and an unspecified set of parameter options, + and returns the converted element. + """ + finat_element = create_element(ufl_element, **kwargs) + if isinstance(finat_element, finat.TensorFiniteElement): + finat_element = finat_element.base_element + return finat_element From 6cdfc30cea375e186ed0674679d8d202cae7f75f Mon Sep 17 00:00:00 2001 From: Johnny Vogels <35307256+jmv2009@users.noreply.github.com> Date: Sat, 22 May 2021 17:26:18 +0200 Subject: [PATCH 7/7] Delete fiatinterface.py File is not present in master; Should not need to be added --- tsfc/fiatinterface.py | 257 ------------------------------------------ 1 file changed, 257 deletions(-) delete mode 100644 tsfc/fiatinterface.py diff --git a/tsfc/fiatinterface.py b/tsfc/fiatinterface.py deleted file mode 100644 index b8b0b8d2..00000000 --- a/tsfc/fiatinterface.py +++ /dev/null @@ -1,257 +0,0 @@ -# -*- coding: utf-8 -*- -# -# This file was modified from FFC -# (http://bitbucket.org/fenics-project/ffc), copyright notice -# reproduced below. -# -# Copyright (C) 2009-2013 Kristian B. Oelgaard and Anders Logg -# -# This file is part of FFC. -# -# FFC is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# FFC is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with FFC. If not, see . - -from functools import singledispatch, partial -import weakref - -import FIAT -from FIAT.tensor_product import FlattenedDimensions - -import ufl - - -__all__ = ("create_element", "supported_elements", "as_fiat_cell") - - -supported_elements = { - # These all map directly to FIAT elements - "Bernstein": FIAT.Bernstein, - "Brezzi-Douglas-Marini": FIAT.BrezziDouglasMarini, - "Brezzi-Douglas-Fortin-Marini": FIAT.BrezziDouglasFortinMarini, - "Bubble": FIAT.Bubble, - "FacetBubble": FIAT.FacetBubble, - "Crouzeix-Raviart": FIAT.CrouzeixRaviart, - "Discontinuous Lagrange": FIAT.DiscontinuousLagrange, - "Discontinuous Taylor": FIAT.DiscontinuousTaylor, - "Discontinuous Raviart-Thomas": FIAT.DiscontinuousRaviartThomas, - "Gauss-Lobatto-Legendre": FIAT.GaussLobattoLegendre, - "Gauss-Legendre": FIAT.GaussLegendre, - "Lagrange": FIAT.Lagrange, - "Nedelec 1st kind H(curl)": FIAT.Nedelec, - "Nedelec 2nd kind H(curl)": FIAT.NedelecSecondKind, - "Raviart-Thomas": FIAT.RaviartThomas, - "HDiv Trace": FIAT.HDivTrace, - "Regge": FIAT.Regge, - "Hellan-Herrmann-Johnson": FIAT.HellanHerrmannJohnson, - # These require special treatment below - "DQ": None, - "Q": None, - "RTCE": None, - "RTCF": None, - "NCE": None, - "NCF": None, - "DPC": FIAT.DPC, - "S": FIAT.Serendipity, - "SminusF": FIAT.TrimmedSerendipityFace, - "SminusDiv": FIAT.TrimmedSerendipityDiv, - "SminusE": FIAT.TrimmedSerendipityEdge, - "SminusCurl": FIAT.TrimmedSerendipityCurl, - "DPC L2": FIAT.DPC, - "Discontinuous Lagrange L2": FIAT.DiscontinuousLagrange, - "Gauss-Legendre L2": FIAT.GaussLegendre, - "DQ L2": None, -} -"""A :class:`.dict` mapping UFL element family names to their -FIAT-equivalent constructors. If the value is ``None``, the UFL -element is supported, but must be handled specially because it doesn't -have a direct FIAT equivalent.""" - - -def as_fiat_cell(cell): - """Convert a ufl cell to a FIAT cell. - - :arg cell: the :class:`ufl.Cell` to convert.""" - if not isinstance(cell, ufl.AbstractCell): - raise ValueError("Expecting a UFL Cell") - return FIAT.ufc_cell(cell) - - -@singledispatch -def convert(element, vector_is_mixed): - """Handler for converting UFL elements to FIAT elements. - - :arg element: The UFL element to convert. - :arg vector_is_mixed: Should Vector and Tensor elements be treated - as Mixed? If ``False``, then just look at the sub-element. - - Do not use this function directly, instead call - :func:`create_element`.""" - if element.family() in supported_elements: - raise ValueError("Element %s supported, but no handler provided" % element) - raise ValueError("Unsupported element type %s" % type(element)) - - -# Base finite elements first -@convert.register(ufl.FiniteElement) -def convert_finiteelement(element, vector_is_mixed): - if element.family() == "Real": - # Real element is just DG0 - cell = element.cell() - return create_element(ufl.FiniteElement("DG", cell, 0), vector_is_mixed) - cell = as_fiat_cell(element.cell()) - if element.family() == "Quadrature": - degree = element.degree() - scheme = element.quadrature_scheme() - if degree is None or scheme is None: - raise ValueError("Quadrature scheme and degree must be specified!") - - quad_rule = FIAT.create_quadrature(cell, degree, scheme) - return FIAT.QuadratureElement(cell, quad_rule.get_points(), weights=quad_rule.get_weights()) - lmbda = supported_elements[element.family()] - if lmbda is None: - if element.cell().cellname() == "quadrilateral": - # Handle quadrilateral short names like RTCF and RTCE. - element = element.reconstruct(cell=quadrilateral_tpc) - elif element.cell().cellname() == "hexahedron": - # Handle hexahedron short names like NCF and NCE. - element = element.reconstruct(cell=hexahedron_tpc) - else: - raise ValueError("%s is supported, but handled incorrectly" % - element.family()) - return FlattenedDimensions(create_element(element, vector_is_mixed)) - - kind = element.variant() - if kind is None: - kind = 'spectral' if element.cell().cellname() == 'interval' else 'equispaced' # default variant - - if element.family() == "Lagrange": - if kind == 'equispaced': - lmbda = FIAT.Lagrange - elif kind == 'spectral' and element.cell().cellname() == 'interval': - lmbda = FIAT.GaussLobattoLegendre - else: - raise ValueError("Variant %r not supported on %s" % (kind, element.cell())) - elif element.family() in {"Raviart-Thomas", "Nedelec 1st kind H(curl)", - "Brezzi-Douglas-Marini", "Nedelec 2nd kind H(curl)"}: - lmbda = partial(lmbda, variant=element.variant()) - elif element.family() in {"Discontinuous Lagrange", "Discontinuous Lagrange L2"}: - if kind == 'equispaced': - lmbda = FIAT.DiscontinuousLagrange - elif kind == 'spectral' and element.cell().cellname() == 'interval': - lmbda = FIAT.GaussLegendre - else: - raise ValueError("Variant %r not supported on %s" % (kind, element.cell())) - return lmbda(cell, element.degree()) - - -# Element modifiers -@convert.register(ufl.RestrictedElement) -def convert_restrictedelement(element, vector_is_mixed): - return FIAT.RestrictedElement(create_element(element.sub_element(), vector_is_mixed), - restriction_domain=element.restriction_domain()) - - -@convert.register(ufl.EnrichedElement) -def convert_enrichedelement(element, vector_is_mixed): - return FIAT.EnrichedElement(*(create_element(e, vector_is_mixed) - for e in element._elements)) - - -@convert.register(ufl.NodalEnrichedElement) -def convert_nodalenrichedelement(element, vector_is_mixed): - return FIAT.NodalEnrichedElement(*(create_element(e, vector_is_mixed) - for e in element._elements)) - - -@convert.register(ufl.BrokenElement) -def convert_brokenelement(element, vector_is_mixed): - return FIAT.DiscontinuousElement(create_element(element._element, vector_is_mixed)) - - -# Now for the TPE-specific stuff -@convert.register(ufl.TensorProductElement) -def convert_tensorproductelement(element, vector_is_mixed): - cell = element.cell() - if type(cell) is not ufl.TensorProductCell: - raise ValueError("TPE not on TPC?") - A, B = element.sub_elements() - return FIAT.TensorProductElement(create_element(A, vector_is_mixed), - create_element(B, vector_is_mixed)) - - -@convert.register(ufl.HDivElement) -def convert_hdivelement(element, vector_is_mixed): - return FIAT.Hdiv(create_element(element._element, vector_is_mixed)) - - -@convert.register(ufl.HCurlElement) -def convert_hcurlelement(element, vector_is_mixed): - return FIAT.Hcurl(create_element(element._element, vector_is_mixed)) - - -# Finally the MixedElement case -@convert.register(ufl.MixedElement) -def convert_mixedelement(element, vector_is_mixed): - # If we're just trying to get the scalar part of a vector element? - if not vector_is_mixed: - assert isinstance(element, (ufl.VectorElement, - ufl.TensorElement)) - return create_element(element.sub_elements()[0], vector_is_mixed) - - elements = [] - - def rec(eles): - for ele in eles: - if isinstance(ele, ufl.MixedElement): - rec(ele.sub_elements()) - else: - elements.append(ele) - - rec(element.sub_elements()) - fiat_elements = map(partial(create_element, vector_is_mixed=vector_is_mixed), - elements) - return FIAT.MixedElement(fiat_elements) - - -hexahedron_tpc = ufl.TensorProductCell(ufl.quadrilateral, ufl.interval) -quadrilateral_tpc = ufl.TensorProductCell(ufl.interval, ufl.interval) -_cache = weakref.WeakKeyDictionary() - - -def create_element(element, vector_is_mixed=True): - """Create a FIAT element (suitable for tabulating with) given a UFL element. - - :arg element: The UFL element to create a FIAT element from. - - :arg vector_is_mixed: indicate whether VectorElement (or - TensorElement) should be treated as a MixedElement. Maybe - useful if you want a FIAT element that tells you how many - "nodes" the finite element has. - """ - try: - cache = _cache[element] - except KeyError: - _cache[element] = {} - cache = _cache[element] - - try: - return cache[vector_is_mixed] - except KeyError: - pass - - if element.cell() is None: - raise ValueError("Don't know how to build element when cell is not given") - - fiat_element = convert(element, vector_is_mixed) - cache[vector_is_mixed] = fiat_element - return fiat_element