diff --git a/FIAT/barycentric_interpolation.py b/FIAT/barycentric_interpolation.py index 25862bf2f..45b0287e3 100644 --- a/FIAT/barycentric_interpolation.py +++ b/FIAT/barycentric_interpolation.py @@ -8,7 +8,6 @@ import numpy from FIAT import reference_element, expansions, polynomial_set -from FIAT.functional import index_iterator def get_lagrange_points(nodes): @@ -94,33 +93,28 @@ def _tabulate_on_cell(self, n, pts, order=0, cell=0, direction=None): class LagrangePolynomialSet(polynomial_set.PolynomialSet): - def __init__(self, ref_el, pts, shape=tuple()): + def __init__(self, ref_el, pts, shape=()): if ref_el.get_shape() != reference_element.LINE: raise ValueError("Invalid reference element type.") expansion_set = LagrangeLineExpansionSet(ref_el, pts) degree = expansion_set.degree - if shape == tuple(): - num_components = 1 - else: - flat_shape = numpy.ravel(shape) - num_components = numpy.prod(flat_shape) + num_components = numpy.prod(shape, dtype=int) num_exp_functions = expansion_set.get_num_members(degree) num_members = num_components * num_exp_functions embedded_degree = degree # set up coefficients - if shape == tuple(): + if shape == (): coeffs = numpy.eye(num_members, dtype="d") else: coeffs_shape = (num_members, *shape, num_exp_functions) coeffs = numpy.zeros(coeffs_shape, "d") - # use functional's index_iterator function - cur_bf = 0 - for idx in index_iterator(shape): - for exp_bf in range(num_exp_functions): - cur_idx = (cur_bf, *idx, exp_bf) - coeffs[cur_idx] = 1.0 - cur_bf += 1 + cur = 0 + exp_bf = range(num_exp_functions) + for idx in numpy.ndindex(shape): + cur_bf = range(cur, cur+num_exp_functions) + coeffs[(cur_bf, *idx, exp_bf)] = 1.0 + cur += num_exp_functions super().__init__(ref_el, degree, embedded_degree, expansion_set, coeffs) diff --git a/FIAT/functional.py b/FIAT/functional.py index 18de29942..26400b011 100644 --- a/FIAT/functional.py +++ b/FIAT/functional.py @@ -18,13 +18,6 @@ from FIAT import polynomial_set, jacobi, quadrature_schemes -def index_iterator(shp): - """Constructs a generator iterating over all indices in - shp in generalized column-major order So if shp = (2,2), then we - construct the sequence (0,0),(0,1),(1,0),(1,1)""" - return numpy.ndindex(shp) - - class Functional(object): r"""Abstract class representing a linear functional. All FIAT functionals are discrete in the sense that @@ -384,7 +377,7 @@ def __init__(self, ref_el, Q, f_at_qpts, nm=None): self.f_at_qpts = f_at_qpts qpts, qwts = Q.get_points(), Q.get_weights() weights = numpy.transpose(numpy.multiply(f_at_qpts, qwts), (-1,) + tuple(range(len(shp)))) - alphas = list(index_iterator(shp)) + alphas = list(numpy.ndindex(shp)) pt_dict = {tuple(pt): [(wt[alpha], alpha) for alpha in alphas] for pt, wt in zip(qpts, weights)} Functional.__init__(self, ref_el, shp, pt_dict, {}, nm or "FrobeniusIntegralMoment") @@ -494,7 +487,7 @@ def __init__(self, ref_el, Q, f_at_qpts): weights = numpy.multiply(f_at_qpts, Q.get_weights()).T alphas = tuple(map(tuple, numpy.eye(sd, dtype=int))) - dpt_dict = {tuple(pt): [(wt[i], alphas[j], (i, j)) for i, j in index_iterator(shp)] + dpt_dict = {tuple(pt): [(wt[i], alphas[j], (i, j)) for i, j in numpy.ndindex(shp)] for pt, wt in zip(points, weights)} super().__init__(ref_el, tuple(), {}, dpt_dict, "IntegralMomentOfDivergence") @@ -655,7 +648,7 @@ def __init__(self, ref_el, v, w, pt): wvT = numpy.outer(w, v) shp = wvT.shape - pt_dict = {tuple(pt): [(wvT[idx], idx) for idx in index_iterator(shp)]} + pt_dict = {tuple(pt): [(wvT[idx], idx) for idx in numpy.ndindex(shp)]} super().__init__(ref_el, shp, pt_dict, {}, "PointwiseInnerProductEval") diff --git a/FIAT/polynomial_set.py b/FIAT/polynomial_set.py index 534c18642..48fd9418c 100644 --- a/FIAT/polynomial_set.py +++ b/FIAT/polynomial_set.py @@ -18,7 +18,6 @@ import numpy from itertools import chain from FIAT import expansions -from FIAT.functional import index_iterator def mis(m, n): @@ -113,25 +112,21 @@ class ONPolynomialSet(PolynomialSet): identity matrix of coefficients. Can be used to specify ON bases for vector- and tensor-valued sets as well. """ - def __init__(self, ref_el, degree, shape=tuple(), **kwargs): + def __init__(self, ref_el, degree, shape=(), **kwargs): expansion_set = expansions.ExpansionSet(ref_el, **kwargs) - if shape == tuple(): - num_components = 1 - else: - flat_shape = numpy.ravel(shape) - num_components = numpy.prod(flat_shape) + num_components = numpy.prod(shape, dtype=int) num_exp_functions = expansion_set.get_num_members(degree) num_members = num_components * num_exp_functions embedded_degree = degree # set up coefficients - if shape == tuple(): + if shape == (): coeffs = numpy.eye(num_members) else: coeffs = numpy.zeros((num_members, *shape, num_exp_functions)) cur = 0 exp_bf = range(num_exp_functions) - for idx in index_iterator(shape): + for idx in numpy.ndindex(shape): cur_bf = range(cur, cur+num_exp_functions) coeffs[(cur_bf, *idx, exp_bf)] = 1.0 cur += num_exp_functions @@ -243,7 +238,7 @@ def __init__(self, ref_el, degree, size=None, **kwargs): coeffs = numpy.zeros((num_members, *shape, num_exp_functions)) cur = 0 exp_bf = range(num_exp_functions) - for i, j in index_iterator(shape): + for i, j in numpy.ndindex(shape): if i > j: continue cur_bf = range(cur, cur+num_exp_functions) @@ -275,7 +270,7 @@ def __init__(self, ref_el, degree, size=None, **kwargs): coeffs = numpy.zeros((num_members, *shape, num_exp_functions)) cur = 0 exp_bf = range(num_exp_functions) - for i, j in index_iterator(shape): + for i, j in numpy.ndindex(shape): if i == size-1 and j == size-1: continue cur_bf = range(cur, cur+num_exp_functions) diff --git a/finat/physically_mapped.py b/finat/physically_mapped.py index 3ce923844..54cce6d4e 100644 --- a/finat/physically_mapped.py +++ b/finat/physically_mapped.py @@ -1,4 +1,5 @@ from abc import ABCMeta, abstractmethod +from collections.abc import Mapping import gem import numpy @@ -260,6 +261,46 @@ class NeedsCoordinateMappingElement(metaclass=ABCMeta): pass +class MappedTabulation(Mapping): + """A lazy tabulation dict that applies the basis transformation only + on the requested derivatives.""" + + def __init__(self, M, ref_tabulation): + self.M = M + self.ref_tabulation = ref_tabulation + # we expect M to be sparse with O(1) nonzeros per row + # for each row, get the column index of each nonzero entry + csr = [[j for j in range(M.shape[1]) if not isinstance(M.array[i, j], gem.Zero)] + for i in range(M.shape[0])] + self.csr = csr + self._tabulation_cache = {} + + def matvec(self, table): + # basis recombination using hand-rolled sparse-dense matrix multiplication + ii = gem.indices(len(table.shape)-1) + phi = [gem.Indexed(table, (j, *ii)) for j in range(self.M.shape[1])] + # the sum approach is faster than calling numpy.dot or gem.IndexSum + exprs = [gem.ComponentTensor(gem.Sum(*(self.M.array[i, j] * phi[j] for j in js)), ii) + for i, js in enumerate(self.csr)] + + val = gem.ListTensor(exprs) + # val = self.M @ table + return gem.optimise.aggressive_unroll(val) + + def __getitem__(self, alpha): + try: + return self._tabulation_cache[alpha] + except KeyError: + result = self.matvec(self.ref_tabulation[alpha]) + return self._tabulation_cache.setdefault(alpha, result) + + def __iter__(self): + return iter(self.ref_tabulation) + + def __len__(self): + return len(self.ref_tabulation) + + class PhysicallyMappedElement(NeedsCoordinateMappingElement): """A mixin that applies a "physical" transformation to tabulated basis functions.""" @@ -277,24 +318,10 @@ def basis_transformation(self, coordinate_mapping): :arg coordinate_mapping: Object providing physical geometry.""" pass - def map_tabulation(self, tabulation, coordinate_mapping): + def map_tabulation(self, ref_tabulation, coordinate_mapping): assert coordinate_mapping is not None - M = self.basis_transformation(coordinate_mapping) - # we expect M to be sparse with O(1) nonzeros per row - # for each row, get the column index of each nonzero entry - csr = [[j for j in range(M.shape[1]) if not isinstance(M.array[i, j], gem.Zero)] - for i in range(M.shape[0])] - - def matvec(table): - # basis recombination using hand-rolled sparse-dense matrix multiplication - table = [gem.partial_indexed(table, (j,)) for j in range(M.shape[1])] - # the sum approach is faster than calling numpy.dot or gem.IndexSum - expressions = [sum(M.array[i, j] * table[j] for j in js) for i, js in enumerate(csr)] - val = gem.ListTensor(expressions) - return gem.optimise.aggressive_unroll(val) - - return {alpha: matvec(tabulation[alpha]) for alpha in tabulation} + return MappedTabulation(M, ref_tabulation) def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): result = super().basis_evaluation(order, ps, entity=entity) diff --git a/gem/coffee.py b/gem/coffee.py index f766a4890..12c4d0fcf 100644 --- a/gem/coffee.py +++ b/gem/coffee.py @@ -4,8 +4,7 @@ This file is NOT for code generation as a COFFEE AST. """ -from collections import OrderedDict -import itertools +from itertools import chain, repeat import logging import numpy @@ -58,10 +57,10 @@ def find_optimal_atomics(monomials, linear_indices): :returns: list of atomic GEM expressions """ - atomics = tuple(OrderedDict.fromkeys(itertools.chain(*(monomial.atomics for monomial in monomials)))) + atomics = tuple(dict.fromkeys(chain.from_iterable(monomial.atomics for monomial in monomials))) def cost(solution): - extent = sum(map(lambda atomic: index_extent(atomic, linear_indices), solution)) + extent = sum(map(index_extent, solution, repeat(linear_indices))) # Prefer shorter solutions, but larger extents return (len(solution), -extent) diff --git a/gem/gem.py b/gem/gem.py index 974556754..b31fd950d 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -15,7 +15,7 @@ """ from abc import ABCMeta -from itertools import chain +from itertools import chain, repeat from functools import reduce from operator import attrgetter from numbers import Integral, Number @@ -55,8 +55,8 @@ def __call__(self, *args, **kwargs): # Set free_indices if not set already if not hasattr(obj, 'free_indices'): - obj.free_indices = unique(chain(*[c.free_indices - for c in obj.children])) + obj.free_indices = unique(chain.from_iterable(c.free_indices + for c in obj.children)) # Set dtype if not set already. if not hasattr(obj, 'dtype'): obj.dtype = obj.inherit_dtype_from_children(obj.children) @@ -118,9 +118,8 @@ def __matmul__(self, other): raise ValueError(f"Mismatching shapes {self.shape} and {other.shape} in matmul") *i, k = indices(len(self.shape)) _, *j = indices(len(other.shape)) - expr = Product(Indexed(self, tuple(i) + (k, )), - Indexed(other, (k, ) + tuple(j))) - return ComponentTensor(IndexSum(expr, (k, )), tuple(i) + tuple(j)) + expr = Product(Indexed(self, (*i, k)), Indexed(other, (k, *j))) + return ComponentTensor(IndexSum(expr, (k, )), (*i, *j)) def __rmatmul__(self, other): return as_gem(other).__matmul__(self) @@ -341,7 +340,7 @@ def __new__(cls, *args): return a if isinstance(a, Constant) and isinstance(b, Constant): - return Literal(a.value + b.value, dtype=Node.inherit_dtype_from_children([a, b])) + return Literal(a.value + b.value, dtype=Node.inherit_dtype_from_children((a, b))) self = super(Sum, cls).__new__(cls) self.children = a, b @@ -370,7 +369,7 @@ def __new__(cls, *args): return a if isinstance(a, Constant) and isinstance(b, Constant): - return Literal(a.value * b.value, dtype=Node.inherit_dtype_from_children([a, b])) + return Literal(a.value * b.value, dtype=Node.inherit_dtype_from_children((a, b))) self = super(Product, cls).__new__(cls) self.children = a, b @@ -394,7 +393,7 @@ def __new__(cls, a, b): return a if isinstance(a, Constant) and isinstance(b, Constant): - return Literal(a.value / b.value, dtype=Node.inherit_dtype_from_children([a, b])) + return Literal(a.value / b.value, dtype=Node.inherit_dtype_from_children((a, b))) self = super(Division, cls).__new__(cls) self.children = a, b @@ -407,7 +406,7 @@ class FloorDiv(Scalar): def __new__(cls, a, b): assert not a.shape assert not b.shape - dtype = Node.inherit_dtype_from_children([a, b]) + dtype = Node.inherit_dtype_from_children((a, b)) if dtype != uint_type: raise ValueError(f"dtype ({dtype}) != unit_type ({uint_type})") # Constant folding @@ -430,7 +429,7 @@ class Remainder(Scalar): def __new__(cls, a, b): assert not a.shape assert not b.shape - dtype = Node.inherit_dtype_from_children([a, b]) + dtype = Node.inherit_dtype_from_children((a, b)) if dtype != uint_type: raise ValueError(f"dtype ({dtype}) != uint_type ({uint_type})") # Constant folding @@ -453,7 +452,7 @@ class Power(Scalar): def __new__(cls, base, exponent): assert not base.shape assert not exponent.shape - dtype = Node.inherit_dtype_from_children([base, exponent]) + dtype = Node.inherit_dtype_from_children((base, exponent)) # Constant folding if isinstance(base, Zero): @@ -569,7 +568,7 @@ def __new__(cls, condition, then, else_): self = super(Conditional, cls).__new__(cls) self.children = condition, then, else_ self.shape = then.shape - self.dtype = Node.inherit_dtype_from_children([then, else_]) + self.dtype = Node.inherit_dtype_from_children((then, else_)) return self @@ -932,7 +931,7 @@ def is_equal(self, other): """Common subexpression eliminating equality predicate.""" if type(self) is not type(other): return False - if (self.array == other.array).all(): + if numpy.array_equal(self.array, other.array): self.array = other.array return True return False @@ -973,7 +972,7 @@ class Delta(Scalar, Terminal): def __new__(cls, i, j, dtype=None): if isinstance(i, tuple) and isinstance(j, tuple): # Handle multiindices - return Product(*map(Delta, i, j)) + return Product(*map(Delta, i, j, repeat(dtype))) assert isinstance(i, IndexBase) assert isinstance(j, IndexBase) @@ -985,26 +984,18 @@ def __new__(cls, i, j, dtype=None): if isinstance(i, Integral) and isinstance(j, Integral): return one if i == j else Zero() - if isinstance(i, Integral): - return Indexed(Literal(numpy.eye(j.extent)[i]), (j,)) - - if isinstance(j, Integral): - return Indexed(Literal(numpy.eye(i.extent)[j]), (i,)) - self = super(Delta, cls).__new__(cls) self.i = i self.j = j # Set up free indices - free_indices = [] - for index in (i, j): - if isinstance(index, Index): - free_indices.append(index) - elif isinstance(index, VariableIndex): - raise NotImplementedError("Can not make Delta with VariableIndex") + free_indices = [index for index in (i, j) if isinstance(index, Index)] self.free_indices = tuple(unique(free_indices)) self._dtype = dtype return self + def reconstruct(self, *args): + return Delta(*args, dtype=self.dtype) + class Inverse(Node): """The inverse of a square matrix.""" diff --git a/gem/node.py b/gem/node.py index 5d9c5bf04..190fe6d40 100644 --- a/gem/node.py +++ b/gem/node.py @@ -3,6 +3,7 @@ import collections import gem +from itertools import repeat class Node(object): @@ -36,14 +37,18 @@ def _cons_args(self, children): Internally used utility function. """ - front_args = [getattr(self, name) for name in self.__front__] - back_args = [getattr(self, name) for name in self.__back__] + front_args = (getattr(self, name) for name in self.__front__) + back_args = (getattr(self, name) for name in self.__back__) - return tuple(front_args) + tuple(children) + tuple(back_args) + return (*front_args, *children, *back_args) + + @property + def _arguments(self): + return self._cons_args(self.children) def __reduce__(self): # Gold version: - return type(self), self._cons_args(self.children) + return type(self), self._arguments def reconstruct(self, *args): """Reconstructs the node with new children from @@ -54,8 +59,8 @@ def reconstruct(self, *args): return type(self)(*self._cons_args(args)) def __repr__(self): - cons_args = self._cons_args(self.children) - return "%s(%s)" % (type(self).__name__, ", ".join(map(repr, cons_args))) + repr_args = ', '.join(map(repr, self._arguments)) + return f"{type(self).__name__}({repr_args})" def __eq__(self, other): """Provides equality testing with quick positive and negative @@ -87,9 +92,7 @@ def is_equal(self, other): """ if type(self) is not type(other): return False - self_consargs = self._cons_args(self.children) - other_consargs = other._cons_args(other.children) - return self_consargs == other_consargs + return self._arguments == other._arguments def get_hash(self): """Hash function. @@ -97,7 +100,7 @@ def get_hash(self): This is the method to potentially override in derived classes, not :meth:`__hash__`. """ - return hash((type(self),) + self._cons_args(self.children)) + return hash((type(self), *self._arguments)) def _make_traversal_children(node): @@ -235,8 +238,7 @@ def __call__(self, node): return self.cache[node] except KeyError: result = self.function(node, self) - self.cache[node] = result - return result + return self.cache.setdefault(node, result) class MemoizerArg(object): @@ -259,14 +261,13 @@ def __call__(self, node, arg): return self.cache[cache_key] except KeyError: result = self.function(node, self, arg) - self.cache[cache_key] = result - return result + return self.cache.setdefault(cache_key, result) def reuse_if_untouched(node, self): """Reuse if untouched recipe""" - new_children = list(map(self, node.children)) - if all(nc == c for nc, c in zip(new_children, node.children)): + new_children = tuple(map(self, node.children)) + if new_children == node.children: return node else: return node.reconstruct(*new_children) @@ -274,8 +275,8 @@ def reuse_if_untouched(node, self): def reuse_if_untouched_arg(node, self, arg): """Reuse if touched recipe propagating an extra argument""" - new_children = [self(child, arg) for child in node.children] - if all(nc == c for nc, c in zip(new_children, node.children)): + new_children = tuple(map(self, node.children, repeat(arg))) + if new_children == node.children: return node else: return node.reconstruct(*new_children) diff --git a/gem/optimise.py b/gem/optimise.py index c56ba7fbb..289ccff7d 100644 --- a/gem/optimise.py +++ b/gem/optimise.py @@ -117,20 +117,18 @@ def replace_indices_delta(node, self, subst): @replace_indices.register(Indexed) def replace_indices_indexed(node, self, subst): + multiindex = tuple(_replace_indices_atomic(i, self, subst) for i in node.multiindex) child, = node.children - substitute = dict(subst) - multiindex = [] - for i in node.multiindex: - multiindex.append(_replace_indices_atomic(i, self, subst)) if isinstance(child, ComponentTensor): # Indexing into ComponentTensor # Inline ComponentTensor and augment the substitution rules + substitute = dict(subst) substitute.update(zip(child.multiindex, multiindex)) return self(child.children[0], tuple(sorted(substitute.items()))) else: # Replace indices new_child = self(child, subst) - if new_child == child and multiindex == node.multiindex: + if multiindex == node.multiindex and new_child == child: return node else: return Indexed(new_child, multiindex) @@ -138,9 +136,6 @@ def replace_indices_indexed(node, self, subst): @replace_indices.register(FlexiblyIndexed) def replace_indices_flexiblyindexed(node, self, subst): - child, = node.children - assert not child.free_indices - dim2idxs = tuple( ( offset if isinstance(offset, Integral) else _replace_indices_atomic(offset, self, subst), @@ -149,6 +144,8 @@ def replace_indices_flexiblyindexed(node, self, subst): for offset, idxs in node.dim2idxs ) + child, = node.children + assert not child.free_indices if dim2idxs == node.dim2idxs: return node else: @@ -192,7 +189,7 @@ def _constant_fold_zero_listtensor(node, self): new_children = list(map(self, node.children)) if all(isinstance(nc, Zero) for nc in new_children): return Zero(node.shape) - elif all(nc == c for nc, c in zip(new_children, node.children)): + elif new_children == node.children: return node else: return node.reconstruct(*new_children) @@ -211,7 +208,7 @@ def constant_fold_zero(exprs): otherwise Literal `0`s would be reintroduced. """ mapper = Memoizer(_constant_fold_zero) - return [mapper(e) for e in exprs] + return list(map(mapper, exprs)) def _select_expression(expressions, index): @@ -249,6 +246,12 @@ def child(expression): children = remove_componenttensors([Indexed(e, multiindex) for e in expressions]) return ComponentTensor(_select_expression(children, index), multiindex) + if types == {Delta}: + if all(e.i == k and e.j == expr.j for k, e in enumerate(expressions)): + return expr.reconstruct(index, expr.j) + elif all(e.j == k and e.i == expr.i for k, e in enumerate(expressions)): + return expr.reconstruct(expr.i, index) + if len(types) == 1: cls, = types if cls.__front__ or cls.__back__: @@ -256,9 +259,9 @@ def child(expression): assert all(len(e.children) == len(expr.children) for e in expressions) assert len(expr.children) > 0 - return expr.reconstruct(*[_select_expression(nth_children, index) - for nth_children in zip(*[e.children - for e in expressions])]) + return expr.reconstruct(*(_select_expression(nth_children, index) + for nth_children in zip(*(e.children + for e in expressions)))) raise NotImplementedError("No rule for factorising expressions of this kind.") diff --git a/test/finat/test_zany_mapping.py b/test/finat/test_zany_mapping.py index 54c66d827..ea638ce9d 100644 --- a/test/finat/test_zany_mapping.py +++ b/test/finat/test_zany_mapping.py @@ -3,6 +3,7 @@ import numpy as np import pytest from gem.interpreter import evaluate +from finat.physically_mapped import PhysicallyMappedElement def make_unisolvent_points(element, interior=False): @@ -65,11 +66,11 @@ def check_zany_mapping(element, ref_to_phys, *args, **kwargs): # Zany map the results num_bfs = phys_element.space_dimension() num_dofs = finat_element.space_dimension() - try: + if isinstance(finat_element, PhysicallyMappedElement): Mgem = finat_element.basis_transformation(ref_to_phys) M = evaluate([Mgem])[0].arr ref_vals_zany = np.tensordot(M, ref_vals_piola, (-1, 0)) - except AttributeError: + else: M = np.eye(num_dofs, num_bfs) ref_vals_zany = ref_vals_piola