diff --git a/meshmode/discretization/connection/refinement.py b/meshmode/discretization/connection/refinement.py index 1124e5ad9..b4c6f6c1e 100644 --- a/meshmode/discretization/connection/refinement.py +++ b/meshmode/discretization/connection/refinement.py @@ -69,8 +69,8 @@ def _build_interpolation_batches_for_group( """ from meshmode.discretization.connection.direct import InterpolationBatch - num_children = len(record.tesselation.children) \ - if record.tesselation else 0 + num_children = len(record.el_tess_info.children) \ + if record.el_tess_info else 0 from_bins = [[] for i in range(1 + num_children)] to_bins = [[] for i in range(1 + num_children)] for elt_idx, refinement_result in enumerate(record.element_mapping): @@ -87,9 +87,11 @@ def _build_interpolation_batches_for_group( to_bin.append(child_idx) fine_unit_nodes = fine_discr_group.unit_nodes + fine_meg = fine_discr_group.mesh_el_group + from meshmode.mesh.refinement.utils import map_unit_nodes_to_children mapped_unit_nodes = map_unit_nodes_to_children( - fine_unit_nodes, record.tesselation) + fine_meg, fine_unit_nodes, record.el_tess_info) from itertools import chain for from_bin, to_bin, unit_nodes in zip( diff --git a/meshmode/mesh/__init__.py b/meshmode/mesh/__init__.py index ea298b185..bd400aa0c 100644 --- a/meshmode/mesh/__init__.py +++ b/meshmode/mesh/__init__.py @@ -914,10 +914,8 @@ def __eq__(self, other): and self.groups == other.groups and self.vertex_id_dtype == other.vertex_id_dtype and self.element_id_dtype == other.element_id_dtype - and (self._nodal_adjacency - == other._nodal_adjacency) - and (self._facial_adjacency_groups - == other._facial_adjacency_groups) + and self._nodal_adjacency == other._nodal_adjacency + and self._facial_adjacency_groups == other._facial_adjacency_groups and self.boundary_tags == other.boundary_tags and self.is_conforming == other.is_conforming) diff --git a/meshmode/mesh/refinement/__init__.py b/meshmode/mesh/refinement/__init__.py index be101b487..c3d258d20 100644 --- a/meshmode/mesh/refinement/__init__.py +++ b/meshmode/mesh/refinement/__init__.py @@ -20,13 +20,14 @@ THE SOFTWARE. """ +import itertools +from functools import partial import numpy as np -import itertools -from pytools import RecordWithoutPickling -from meshmode.mesh.refinement.no_adjacency import ( # noqa: F401 - RefinerWithoutAdjacency) +from meshmode.mesh.refinement.no_adjacency import RefinerWithoutAdjacency +from meshmode.mesh.refinement.tesselate import \ + ElementTesselationInfo, GroupRefinementRecord import logging logger = logging.getLogger(__name__) @@ -37,6 +38,43 @@ .. autofunction :: refine_uniformly """ +__all__ = [ + "Refiner", "RefinerWithoutAdjacency", "refine_uniformly" +] + + +# {{{ deprecated + +class SimplexResampler: + @staticmethod + def get_vertex_pair_to_midpoint_order(dim): + nmidpoints = dim * (dim + 1) // 2 + return dict(zip( + ((i, j) for j in range(dim + 1) for i in range(j)), + range(nmidpoints) + )) + + @staticmethod + def get_midpoints(group, el_tess_info, elements): + import meshmode.mesh.refinement.tesselate as tess + return tess.get_group_midpoints(group, el_tess_info, elements) + + @staticmethod + def get_tesselated_nodes(group, el_tess_info, elements): + import meshmode.mesh.refinement.tesselate as tess + return tess.get_group_tesselated_nodes(group, el_tess_info, elements) + + +def tesselate_simplex(dim): + import modepy as mp + shape = mp.Simplex(dim) + space = mp.space_for_shape(shape, 2) + + node_tuples = mp.node_tuples_for_space(space) + return node_tuples, mp.submesh_for_shape(shape, node_tuples) + +# }}} + class TreeRayNode: """Describes a ray as a tree, this class represents each node in this tree @@ -65,21 +103,6 @@ def __init__(self, left_vertex, right_vertex, adjacent_elements=[]): self.adjacent_add_diff = [] -class _Tesselation(RecordWithoutPickling): - - def __init__(self, children, ref_vertices): - RecordWithoutPickling.__init__(self, - children=children, - ref_vertices=ref_vertices,) - - -class _GroupRefinementRecord(RecordWithoutPickling): - - def __init__(self, tesselation, element_mapping): - RecordWithoutPickling.__init__(self, - tesselation=tesselation, element_mapping=element_mapping) - - class Refiner: """An older that mostly succeeds at preserving adjacency across non-conformal refinement. @@ -100,20 +123,22 @@ class Refiner: # {{{ constructor def __init__(self, mesh): + from warnings import warn + warn("Refiner is deprecated and will be removed in 2022.", + DeprecationWarning, stacklevel=2) + if mesh.is_conforming is not True: raise ValueError("Refiner can only be used with meshes that are known " "to be conforming. If you would like to refine non-conforming " "meshes and do not need adjacency information, consider " "using RefinerWithoutAdjacency.") - from meshmode.mesh.refinement.tesselate import \ - tesselateseg, tesselatetet, tesselatetri self.lazy = False self.seen_tuple = {} self.group_refinement_records = [] - seg_node_tuples, seg_result = tesselateseg() - tri_node_tuples, tri_result = tesselatetri() - tet_node_tuples, tet_result = tesselatetet() + seg_node_tuples, seg_result = tesselate_simplex(1) + tri_node_tuples, tri_result = tesselate_simplex(2) + tet_node_tuples, tet_result = tesselate_simplex(3) #quadrilateral_node_tuples = [ #print tri_result, tet_result self.simplex_node_tuples = [ @@ -597,11 +622,16 @@ def check_adjacent_elements(groups, new_hanging_vertex_elements, nelements_in_gr nelements_in_grp = grp.nelements del self.group_refinement_records[:] + from meshmode.mesh import SimplexElementGroup for grp_idx, grp in enumerate(self.last_mesh.groups): + if not isinstance(grp, SimplexElementGroup): + raise TypeError("refinement not supported for groups of type " + f"'{type(grp).__name__}'") + iel_base = grp.element_nr_base # List of lists mapping element number to new element number(s). element_mapping = [] - tesselation = None + el_tess_info = None # {{{ get midpoint coordinates for vertices @@ -613,19 +643,17 @@ def check_adjacent_elements(groups, new_hanging_vertex_elements, nelements_in_gr if len(grp.vertex_indices[iel_grp]) == grp.dim + 1: midpoints_to_find.append(iel_grp) if not resampler: - from meshmode.mesh.refinement.resampler import ( - SimplexResampler) resampler = SimplexResampler() - tesselation = _Tesselation( - self.simplex_result[grp.dim], - self.simplex_node_tuples[grp.dim]) + el_tess_info = ElementTesselationInfo( + children=self.simplex_result[grp.dim], + ref_vertices=self.simplex_node_tuples[grp.dim]) else: raise NotImplementedError("unimplemented: midpoint finding" "for non simplex elements") if midpoints_to_find: midpoints = resampler.get_midpoints( - grp, tesselation, midpoints_to_find) + grp, el_tess_info, midpoints_to_find) midpoint_order = resampler.get_vertex_pair_to_midpoint_order(grp.dim) del midpoints_to_find @@ -703,7 +731,10 @@ def check_adjacent_elements(groups, new_hanging_vertex_elements, nelements_in_gr # if len(cur_list[len(cur_list)-1]) self.group_refinement_records.append( - _GroupRefinementRecord(tesselation, element_mapping)) + GroupRefinementRecord( + el_tess_info=el_tess_info, + element_mapping=element_mapping) + ) #clear connectivity data for grp in self.last_mesh.groups: @@ -760,10 +791,9 @@ def check_adjacent_elements(groups, new_hanging_vertex_elements, nelements_in_gr if to_resample: # if simplex if is_simplex: - from meshmode.mesh.refinement.resampler import SimplexResampler resampler = SimplexResampler() new_nodes = resampler.get_tesselated_nodes( - prev_group, refinement_record.tesselation, to_resample) + prev_group, refinement_record.el_tess_info, to_resample) else: raise NotImplementedError( "unimplemented: node resampling for non simplex elements") @@ -948,9 +978,13 @@ def generate_nodal_adjacency(self, nelements, nvertices, groups): def refine_uniformly(mesh, iterations, with_adjacency=False): if with_adjacency: - refiner = Refiner(mesh) - else: - refiner = RefinerWithoutAdjacency(mesh) + # For conforming meshes, even RefinerWithoutAdjacency will reconstruct + # adjacency from vertex identity. + + if not mesh.is_conforming: + raise ValueError("mesh must be conforming if adjacency is desired") + + refiner = RefinerWithoutAdjacency(mesh) for _ in range(iterations): refiner.refine_uniformly() diff --git a/meshmode/mesh/refinement/no_adjacency.py b/meshmode/mesh/refinement/no_adjacency.py index adf1fc2f9..cc9caa37e 100644 --- a/meshmode/mesh/refinement/no_adjacency.py +++ b/meshmode/mesh/refinement/no_adjacency.py @@ -24,37 +24,12 @@ THE SOFTWARE. """ - import numpy as np -from pytools import RecordWithoutPickling - -from pytools import memoize_method - import logging logger = logging.getLogger(__name__) -class _TesselationInfo(RecordWithoutPickling): - - def __init__(self, children, ref_vertices, orig_vertex_indices, - midpoint_indices, midpoint_vertex_pairs, resampler): - RecordWithoutPickling.__init__(self, - children=children, - ref_vertices=ref_vertices, - orig_vertex_indices=orig_vertex_indices, - midpoint_indices=midpoint_indices, - midpoint_vertex_pairs=midpoint_vertex_pairs, - resampler=resampler) - - -class _GroupRefinementRecord(RecordWithoutPickling): - - def __init__(self, tesselation, element_mapping): - RecordWithoutPickling.__init__(self, - tesselation=tesselation, element_mapping=element_mapping) - - class RefinerWithoutAdjacency: """A refiner that may be applied to non-conforming :class:`meshmode.mesh.Mesh` instances. It does not generate adjacency @@ -81,61 +56,6 @@ def __init__(self, mesh): self.group_refinement_records = None self.global_vertex_pair_to_midpoint = {} - # {{{ build tesselation info - - @memoize_method - def _get_bisection_tesselation_info(self, group_type, dim): - from meshmode.mesh import SimplexElementGroup - if issubclass(group_type, SimplexElementGroup): - from meshmode.mesh.refinement.tesselate import \ - tesselate_simplex_bisection, add_tuples, halve_tuple - ref_vertices, children = tesselate_simplex_bisection(dim) - - orig_vertex_tuples = [(0,) * dim] + [ - (0,) * i + (2,) + (0,) * (dim-i-1) - for i in range(dim)] - node_dict = { - ituple: idx - for idx, ituple in enumerate(ref_vertices)} - orig_vertex_indices = [node_dict[vt] for vt in orig_vertex_tuples] - - from meshmode.mesh.refinement.resampler import SimplexResampler - resampler = SimplexResampler() - vertex_pair_to_midpoint_order = \ - resampler.get_vertex_pair_to_midpoint_order(dim) - - midpoint_idx_to_vertex_pair = {} - for vpair, mpoint_idx in vertex_pair_to_midpoint_order.items(): - midpoint_idx_to_vertex_pair[mpoint_idx] = vpair - - midpoint_vertex_pairs = [ - midpoint_idx_to_vertex_pair[i] - for i in range(len(midpoint_idx_to_vertex_pair))] - - midpoint_indices = [ - node_dict[ - halve_tuple( - add_tuples( - orig_vertex_tuples[v1], - orig_vertex_tuples[v2]))] - for v1, v2 in midpoint_vertex_pairs] - - return _TesselationInfo( - ref_vertices=ref_vertices, - children=np.array(children), - orig_vertex_indices=np.array(orig_vertex_indices), - midpoint_indices=np.array(midpoint_indices), - midpoint_vertex_pairs=midpoint_vertex_pairs, - resampler=resampler, - ) - - else: - raise NotImplementedError( - "bisection for elements groups of type %s" - % group_type.__name__) - - # }}} - def refine_uniformly(self): flags = np.ones(self._current_mesh.nelements, dtype=bool) return self.refine(flags) @@ -163,9 +83,13 @@ def refine(self, refine_flags): if perform_vertex_updates: inew_vertex = mesh.nvertices + from meshmode.mesh.refinement.tesselate import ( + get_group_tesselation_info, + get_group_midpoints, + get_group_tesselated_nodes) + for igrp, group in enumerate(mesh.groups): - bisection_info = self._get_bisection_tesselation_info( - type(group), group.dim) + el_tess_info = get_group_tesselation_info(group) # {{{ compute counts and index arrays @@ -173,7 +97,7 @@ def refine(self, refine_flags): group.element_nr_base: group.element_nr_base+group.nelements] - nchildren = len(bisection_info.children) + nchildren = len(el_tess_info.children) nchild_elements = np.ones(group.nelements, dtype=mesh.element_id_dtype) nchild_elements[grp_flags] = nchildren @@ -189,9 +113,10 @@ def refine(self, refine_flags): # }}} + from meshmode.mesh.refinement.tesselate import GroupRefinementRecord group_refinement_records.append( - _GroupRefinementRecord( - tesselation=bisection_info, + GroupRefinementRecord( + el_tess_info=el_tess_info, element_mapping=[ list(range( child_el_indices[iel], @@ -201,8 +126,8 @@ def refine(self, refine_flags): # {{{ get new vertices together if perform_vertex_updates: - midpoints = bisection_info.resampler.get_midpoints( - group, bisection_info, refining_el_old_indices) + midpoints = get_group_midpoints( + group, el_tess_info, refining_el_old_indices) new_vertex_indices = np.empty( (new_nelements, group.vertex_indices.shape[1]), @@ -216,17 +141,17 @@ def refine(self, refine_flags): for old_iel in refining_el_old_indices: new_iel_base = child_el_indices[old_iel] - refining_vertices = np.empty(len(bisection_info.ref_vertices), + refining_vertices = np.empty(len(el_tess_info.ref_vertices), dtype=mesh.vertex_id_dtype) refining_vertices.fill(-17) # carry over old vertices - refining_vertices[bisection_info.orig_vertex_indices] = \ + refining_vertices[el_tess_info.orig_vertex_indices] = \ group.vertex_indices[old_iel] for imidpoint, (iref_midpoint, (v1, v2)) in enumerate(zip( - bisection_info.midpoint_indices, - bisection_info.midpoint_vertex_pairs)): + el_tess_info.midpoint_indices, + el_tess_info.midpoint_vertex_pairs)): global_v1 = group.vertex_indices[old_iel, v1] global_v2 = group.vertex_indices[old_iel, v2] @@ -250,7 +175,7 @@ def refine(self, refine_flags): assert (refining_vertices >= 0).all() new_vertex_indices[new_iel_base:new_iel_base+nchildren] = \ - refining_vertices[bisection_info.children] + refining_vertices[el_tess_info.children] assert (new_vertex_indices >= 0).all() else: @@ -269,8 +194,8 @@ def refine(self, refine_flags): # copy over unchanged nodes new_nodes[:, unrefined_el_new_indices] = group.nodes[:, ~grp_flags] - tesselated_nodes = bisection_info.resampler.get_tesselated_nodes( - group, bisection_info, refining_el_old_indices) + tesselated_nodes = get_group_tesselated_nodes( + group, el_tess_info, refining_el_old_indices) for old_iel in refining_el_old_indices: new_iel_base = child_el_indices[old_iel] diff --git a/meshmode/mesh/refinement/resampler.py b/meshmode/mesh/refinement/resampler.py deleted file mode 100644 index fd6ae9396..000000000 --- a/meshmode/mesh/refinement/resampler.py +++ /dev/null @@ -1,137 +0,0 @@ -__copyright__ = "Copyright (C) 2016 Matt Wala" - -__license__ = """ -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. -""" - - -import numpy as np -import modepy as mp - -import logging -logger = logging.getLogger(__name__) - - -# {{{ resampling simplex points for refinement - -# NOTE: Class internal to refiner: do not make documentation public. -class SimplexResampler: - """ - Resampling of points on simplex elements for refinement. - - Most methods take a ``tesselation`` parameter. - The tesselation should follow the format of - :func:`meshmode.mesh.tesselate.tesselatetri()` or - :func:`meshmode.mesh.tesselate.tesselatetet()`. - """ - - @staticmethod - def get_vertex_pair_to_midpoint_order(dim): - """ - :arg dim: Dimension of the element - - :return: A :class:`dict` mapping the vertex pair :math:`(v1, v2)` (with - :math:`v1 < v2`) to the number of the midpoint in the tesselation - ordering (the numbering is restricted to the midpoints, so there - are no gaps in the numbering) - """ - nmidpoints = dim * (dim + 1) // 2 - return dict(zip( - ((i, j) for j in range(dim + 1) for i in range(j)), - range(nmidpoints) - )) - - @staticmethod - def get_midpoints(group, tesselation, elements): - """ - Compute the midpoints of the vertices of the specified elements. - - :arg group: An instance of :class:`meshmode.mesh.SimplexElementGroup` - :arg tesselation: With attributes `ref_vertices`, `children` - :arg elements: A list of (group-relative) element numbers - - :return: A :class:`dict` mapping element numbers to midpoint - coordinates, with each value in the map having shape - ``(ambient_dim, nmidpoints)``. The ordering of the midpoints - follows their ordering in the tesselation (see also - :meth:`SimplexResampler.get_vertex_pair_to_midpoint_order`) - """ - if group.vertex_indices is not None: - assert len(group.vertex_indices[0]) == group.dim + 1 - - # Get midpoints, converted to unit coordinates. - midpoints = -1 + np.array([vertex for vertex in - tesselation.ref_vertices if 1 in vertex], dtype=float) - - resamp_mat = mp.resampling_matrix( - mp.simplex_best_available_basis(group.dim, group.order), - midpoints.T, - group.unit_nodes) - - resamp_midpoints = np.einsum("mu,deu->edm", - resamp_mat, - group.nodes[:, elements]) - - return dict(zip(elements, resamp_midpoints)) - - @staticmethod - def get_tesselated_nodes(group, tesselation, elements): - """ - Compute the nodes of the child elements according to the tesselation. - - :arg group: An instance of :class:`meshmode.mesh.SimplexElementGroup` - :arg tesselation: With attributes `ref_vertices`, `children` - :arg elements: A list of (group-relative) element numbers - - :return: A :class:`dict` mapping element numbers to node - coordinates, with each value in the map having shape - ``(ambient_dim, nchildren, nunit_nodes)``. - The ordering of the child nodes follows the ordering - of ``tesselation.children.`` - """ - if group.vertex_indices is not None: - assert len(group.vertex_indices[0]) == group.dim + 1 - - from meshmode.mesh.refinement.utils import map_unit_nodes_to_children - - # Get child unit node coordinates. - child_unit_nodes = np.hstack(list( - map_unit_nodes_to_children(group.unit_nodes, tesselation))) - - resamp_mat = mp.resampling_matrix( - mp.simplex_best_available_basis(group.dim, group.order), - child_unit_nodes, - group.unit_nodes) - - resamp_unit_nodes = np.einsum("cu,deu->edc", - resamp_mat, - group.nodes[:, elements]) - - ambient_dim = len(group.nodes) - nunit_nodes = len(group.unit_nodes[0]) - - return {elem: - resamp_unit_nodes[ielem].reshape( - (ambient_dim, -1, nunit_nodes)) - for ielem, elem in enumerate(elements)} - -# }}} - - -# vim: foldmethod=marker diff --git a/meshmode/mesh/refinement/tesselate.py b/meshmode/mesh/refinement/tesselate.py index 211a87f08..13b50ee33 100644 --- a/meshmode/mesh/refinement/tesselate.py +++ b/meshmode/mesh/refinement/tesselate.py @@ -1,6 +1,7 @@ __copyright__ = """ Copyright (C) 2018 Andreas Kloeckner Copyright (C) 2014-6 Shivam Gupta +Copyright (C) 2020 Alexandru Fikl """ __license__ = """ @@ -23,105 +24,250 @@ THE SOFTWARE. """ +from dataclasses import dataclass +from functools import singledispatch -from pytools import generate_nonnegative_integer_tuples_summing_to_at_most \ - as gnitstam +import numpy as np +import modepy as mp +from meshmode.mesh import MeshElementGroup, _ModepyElementGroup -def add_tuples(a, b): - return tuple(ac+bc for ac, bc in zip(a, b)) +import logging +logger = logging.getLogger(__name__) +from typing import List, Tuple, Optional -def halve_tuple(a): - def halve(x): - d, r = divmod(x, 2) + +# {{{ interface + +@dataclass(frozen=True) +class ElementTesselationInfo: + """Describes how one element is split into multiple child elements. + + .. attribute:: children + + An array of shape ``(nchildren, nvertices)`` containing the vertices + of each child element the reference element was split into. + + .. attribute:: ref_vertices + + A list of tuples (similar to :func:`modepy.node_tuples_for_space`) + for the reference element containing midpoints. This is equivalent + to a second-order equidistant element. + + .. attribute:: orig_vertex_indices + + Indices into :attr:`ref_vertices` that select only the vertices, i.e. + without the midpoints. + + .. attribute:: midpoint_indices + + Indices into :attr:`ref_vertices` that select only the midpoints, i.e. + without :attr:`orig_vertex_indices`. + + .. attribute:: midpoint_vertex_pairs + + A list of tuples ``(v1, v2)`` of indices into :attr:`orig_vertex_indices` + that give for each midpoint the two vertices on the same line. + """ + + children: np.ndarray + ref_vertices: List[Tuple[int, ...]] + + orig_vertex_indices: Optional[np.ndarray] = None + midpoint_indices: Optional[np.ndarray] = None + midpoint_vertex_pairs: Optional[List[Tuple[int, int]]] = None + + +@dataclass(frozen=True) +class GroupRefinementRecord: + """ + .. attribute:: el_tess_info + + An instance of :class:`ElementTesselationInfo` that describes the + tesselation of a single element into multiple child elements. + + .. attribute:: element_mapping + + A mapping from the original elements to the refined child elements. + """ + + el_tess_info: ElementTesselationInfo + # FIXME: This should really be a CSR data structure. + element_mapping: List[List[int]] + + +@singledispatch +def get_group_midpoints(meg: MeshElementGroup, el_tess_info, elements): + """Compute the midpoints of the vertices of the specified elements. + + :arg group: an instance of :class:`meshmode.mesh.MeshElementGroup`. + :arg el_tess_info: a :class:`ElementTesselationInfo`. + :arg elements: a list of (group-relative) element numbers. + + :return: A :class:`dict` mapping element numbers to midpoint + coordinates, with each value in the map having shape + ``(ambient_dim, nmidpoints)``. The ordering of the midpoints + follows their ordering in the tesselation. + """ + raise NotImplementedError(type(meg).__name__) + + +@singledispatch +def get_group_tesselated_nodes(meg: MeshElementGroup, el_tess_info, elements): + """Compute the nodes of the child elements according to the tesselation. + + :arg group: An instance of :class:`meshmode.mesh.MeshElementGroup`. + :arg el_tess_info: a :class:`ElementTesselationInfo`. + :arg elements: A list of (group-relative) element numbers. + + :return: A :class:`dict` mapping element numbers to node + coordinates, with each value in the map having shape + ``(ambient_dim, nchildren, nunit_nodes)``. + The ordering of the child nodes follows the ordering + of ``el_tess_info.children.`` + """ + raise NotImplementedError(type(meg).__name__) + + +@singledispatch +def get_group_tesselation_info(meg: MeshElementGroup): + """ + :returns: a :class:`ElementTesselationInfo` for the element group *meg*. + """ + raise NotImplementedError(type(meg).__name__) + +# }}} + + +# {{{ helpers + +def _midpoint_tuples(a, b): + def midpoint(x, y): + d, r = divmod(x + y, 2) if r: raise ValueError("%s is not evenly divisible by two" % x) + return d - return tuple(halve(ac) for ac in a) - - -def tesselateseg(): - node_tuples = [(0,), (1,), (2,)] - result = [(0, 1), (1, 2)] - return node_tuples, result - - -def tesselatetri(): - result = [] - - node_tuples = list(gnitstam(2, 2)) - node_dict = { - ituple: idx - for idx, ituple in enumerate(node_tuples)} - - def try_add_tri(current, d1, d2, d3): - try: - result.append(( - node_dict[add_tuples(current, d1)], - node_dict[add_tuples(current, d2)], - node_dict[add_tuples(current, d3)], - )) - except KeyError: - pass - - if len(result) > 0: - return [node_tuples, result] - for current in node_tuples: - # this is a tesselation of a square into two triangles. - # subtriangles that fall outside of the master tet are simply not added. - - # positively oriented - try_add_tri(current, (0, 0), (1, 0), (0, 1)) - try_add_tri(current, (1, 0), (1, 1), (0, 1)) - return node_tuples, result - - -def tesselatetet(): - node_tuples = list(gnitstam(2, 3)) - - node_dict = { - ituple: idx - for idx, ituple in enumerate(node_tuples)} - - def try_add_tet(current, d1, d2, d3, d4): - try: - result.append(( - node_dict[add_tuples(current, d1)], - node_dict[add_tuples(current, d2)], - node_dict[add_tuples(current, d3)], - node_dict[add_tuples(current, d4)], - )) - except KeyError: - pass - - result = [] - - if len(result) > 0: - return [node_tuples, result] - for current in node_tuples: - # this is a tesselation of a cube into six tets. - # subtets that fall outside of the master tet are simply not added. - - # positively oriented - try_add_tet(current, (0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1)) - try_add_tet(current, (1, 0, 1), (1, 0, 0), (0, 0, 1), (0, 1, 0)) - try_add_tet(current, (1, 0, 1), (0, 1, 1), (0, 1, 0), (0, 0, 1)) - - try_add_tet(current, (1, 0, 0), (0, 1, 0), (1, 0, 1), (1, 1, 0)) - try_add_tet(current, (0, 1, 1), (0, 1, 0), (1, 1, 0), (1, 0, 1)) - try_add_tet(current, (0, 1, 1), (1, 1, 1), (1, 0, 1), (1, 1, 0)) - - return node_tuples, result - - -def tesselate_simplex_bisection(dim): - if dim == 1: - return tesselateseg() - elif dim == 2: - return tesselatetri() - elif dim == 3: - return tesselatetet() - else: - raise ValueError("cannot tesselate %d-simplex" % dim) + return tuple(midpoint(ai, bi) for ai, bi in zip(a, b)) + + +def _get_ref_midpoints(shape, ref_vertices): + r"""The reference element is considered to be, e.g. for a 2-simplex:: + + F + | \ + | \ + D----E + | /| \ + | / | \ + A----B----C + + where the midpoints are ``(B, E, D)``. The same applies to other shapes + and higher dimensions. + + :arg ref_vertices: a :class:`list` of node index :class:`tuple`\ s + on :math:`[0, 2]^d`. + """ + + from pytools import add_tuples + space = mp.space_for_shape(shape, 1) + orig_vertices = [ + add_tuples(vt, vt) for vt in mp.node_tuples_for_space(space) + ] + return [rv for rv in ref_vertices if rv not in orig_vertices] + +# }}} + + +# {{{ modepy.shape tesselation and resampling + +@get_group_midpoints.register(_ModepyElementGroup) +def _(meg: _ModepyElementGroup, el_tess_info, elements): + shape = meg._modepy_shape + space = meg._modepy_space + + # get midpoints in reference coordinates + midpoints = -1 + np.array(_get_ref_midpoints(shape, el_tess_info.ref_vertices)) + + # resample midpoints to ambient coordinates + resampling_mat = mp.resampling_matrix( + mp.basis_for_space(space, shape).functions, + midpoints.T, + meg.unit_nodes) + + resampled_midpoints = np.einsum("mu,deu->edm", + resampling_mat, meg.nodes[:, elements]) + + return dict(zip(elements, resampled_midpoints)) + + +@get_group_tesselated_nodes.register(_ModepyElementGroup) +def _(meg: _ModepyElementGroup, el_tess_info, elements): + shape = meg._modepy_shape + space = meg._modepy_space + + # get child unit node coordinates. + from meshmode.mesh.refinement.utils import map_unit_nodes_to_children + child_unit_nodes = np.hstack(list( + map_unit_nodes_to_children(meg, meg.unit_nodes, el_tess_info) + )) + + # resample child nodes to ambient coordinates + resampling_mat = mp.resampling_matrix( + mp.basis_for_space(space, shape).functions, + child_unit_nodes, + meg.unit_nodes) + + resampled_unit_nodes = np.einsum("cu,deu->edc", + resampling_mat, meg.nodes[:, elements]) + + ambient_dim = len(meg.nodes) + nunit_nodes = len(meg.unit_nodes[0]) + + return { + el: resampled_unit_nodes[iel].reshape((ambient_dim, -1, nunit_nodes)) + for iel, el in enumerate(elements) + } + + +@get_group_tesselation_info.register(_ModepyElementGroup) +def _(meg: _ModepyElementGroup): + shape = meg._modepy_shape + space = type(meg._modepy_space)(meg.dim, 2) + + ref_vertices = mp.node_tuples_for_space(space) + ref_vertices_to_index = {rv: i for i, rv in enumerate(ref_vertices)} + + from pytools import add_tuples + space = type(meg._modepy_space)(meg.dim, 1) + orig_vertices = tuple([ + add_tuples(vt, vt) for vt in mp.node_tuples_for_space(space) + ]) + orig_vertex_indices = [ref_vertices_to_index[vt] for vt in orig_vertices] + + midpoints = _get_ref_midpoints(shape, ref_vertices) + midpoint_indices = [ref_vertices_to_index[mp] for mp in midpoints] + + midpoint_to_vertex_pairs = { + midpoint: (i, j) + for i, ivt in enumerate(orig_vertices) + for j, jvt in enumerate(orig_vertices) + for midpoint in [_midpoint_tuples(ivt, jvt)] + if i < j and midpoint in midpoints + } + # ensure order matches the one in midpoint_indices + midpoint_vertex_pairs = [midpoint_to_vertex_pairs[m] for m in midpoints] + + return ElementTesselationInfo( + ref_vertices=ref_vertices, + children=np.array(mp.submesh_for_shape(shape, ref_vertices)), + orig_vertex_indices=np.array(orig_vertex_indices), + midpoint_indices=np.array(midpoint_indices), + midpoint_vertex_pairs=midpoint_vertex_pairs, + ) + +# }}} + +# vim: foldmethod=marker diff --git a/meshmode/mesh/refinement/utils.py b/meshmode/mesh/refinement/utils.py index be64eb07c..4ccde721f 100644 --- a/meshmode/mesh/refinement/utils.py +++ b/meshmode/mesh/refinement/utils.py @@ -20,48 +20,68 @@ THE SOFTWARE. """ +from functools import singledispatch import numpy as np +from meshmode.mesh import ( + MeshElementGroup, + SimplexElementGroup, + TensorProductElementGroup) + import logging logger = logging.getLogger(__name__) -# {{{ map unit nodes to children - +# {{{ map child unit nodes -def map_unit_nodes_to_children(unit_nodes, tesselation): +@singledispatch +def map_unit_nodes_to_children(meg: MeshElementGroup, + unit_nodes, el_tess_info) -> np.ndarray: + """ + :arg unit_nodes: an :class:`~numpy.ndarray` of unit nodes on the + element type described by *meg*. + :arg el_tess_info: a + :class:`~meshmode.mesh.refinement.tesselate.ElementTesselationInfo`. + :returns: an :class:`~numpy.ndarray` of mapped unit nodes for each + child in the tesselation. """ - Given a collection of unit nodes, return the coordinates of the - unit nodes mapped onto each of the children of the reference - element. + raise NotImplementedError(type(meg).__name__) - The tesselation should follow the format of - :func:`meshmode.mesh.tesselate.tesselatetri()` or - :func:`meshmode.mesh.tesselate.tesselatetet()`. - `unit_nodes` should be relative to the unit simplex coordinates in - :module:`modepy`. +@map_unit_nodes_to_children.register(SimplexElementGroup) +def _(meg: SimplexElementGroup, unit_nodes, el_tess_info): + ref_vertices = np.array(el_tess_info.ref_vertices, dtype=np.float).T + assert len(unit_nodes.shape) == 2 - :arg unit_nodes: shaped `(dim, nunit_nodes)` - :arg tesselation: With attributes `ref_vertices`, `children` - """ - ref_vertices = np.array(tesselation.ref_vertices, dtype=np.float) + for child in el_tess_info.children: + origin = ref_vertices[:, child[0]].reshape(-1, 1) + basis = ref_vertices[:, child[1:]] - origin + + # mapped nodes are on [0, 2], so we subtract 1 to get it to [-1, 1] + yield basis.dot((unit_nodes + 1.0) / 2.0) + origin - 1.0 + +@map_unit_nodes_to_children.register(TensorProductElementGroup) +def _(meg: TensorProductElementGroup, unit_nodes, el_tess_info): + ref_vertices = np.array(el_tess_info.ref_vertices, dtype=np.float).T assert len(unit_nodes.shape) == 2 - for child_element in tesselation.children: - center = np.vstack(ref_vertices[child_element[0]]) - # Scale by 1/2 since sides in the tesselation have length 2. - aff_mat = (ref_vertices.T[:, child_element[1:]] - center) / 2 - # (-1, -1, ...) in unit_nodes = (0, 0, ...) in ref_vertices. - # Hence the translation by +/- 1. - yield aff_mat.dot(unit_nodes + 1) + center - 1 + # NOTE: nodes indices in the unit hypercube that form the `e_i` basis + basis_indices = 2**np.arange(meg.dim) + + for child in el_tess_info.children: + child_arr = np.array(child) + origin = ref_vertices[:, child_arr[0]].reshape(-1, 1) + basis = ref_vertices[:, child_arr[basis_indices]] - origin + + # mapped nodes are on [0, 2], so we subtract 1 to get it to [-1, 1] + yield basis.dot((unit_nodes + 1.0) / 2.0) + origin - 1.0 # }}} -# {{{ test nodal adjacency against geometry +# {{{ test nodal adjacency against geometry def is_symmetric(relation, debug=False): for a, other_list in enumerate(relation): diff --git a/test/test_refinement.py b/test/test_refinement.py index afbda13ca..c88f7a148 100644 --- a/test/test_refinement.py +++ b/test/test_refinement.py @@ -26,20 +26,24 @@ import numpy as np import pytest -from meshmode.array_context import ( # noqa +from meshmode import _acf # noqa: F401 +from meshmode.array_context import ( # noqa: F401 pytest_generate_tests_for_pyopencl_array_context as pytest_generate_tests) from meshmode.dof_array import thaw -from meshmode.mesh.generation import ( # noqa +from meshmode.mesh.generation import ( # noqa: F401 generate_icosahedron, generate_box_mesh, make_curve_mesh, ellipse) from meshmode.mesh.refinement.utils import check_nodal_adj_against_geometry from meshmode.mesh.refinement import Refiner, RefinerWithoutAdjacency +from meshmode.mesh import SimplexElementGroup, TensorProductElementGroup from meshmode.discretization.poly_element import ( InterpolatoryQuadratureSimplexGroupFactory, PolynomialWarpAndBlendGroupFactory, PolynomialEquidistantSimplexGroupFactory, + LegendreGaussLobattoTensorProductGroupFactory, + GaussLegendreTensorProductGroupFactory, ) logger = logging.getLogger(__name__) @@ -155,14 +159,17 @@ def test_refinement(case_name, mesh_gen, flag_gen, num_generations): check_nodal_adj_against_geometry(mesh) -@pytest.mark.parametrize("refiner_cls", [ - Refiner, - RefinerWithoutAdjacency - ]) -@pytest.mark.parametrize("group_factory", [ - InterpolatoryQuadratureSimplexGroupFactory, - PolynomialWarpAndBlendGroupFactory, - PolynomialEquidistantSimplexGroupFactory +@pytest.mark.parametrize(("refiner_cls", "group_factory"), [ + (Refiner, InterpolatoryQuadratureSimplexGroupFactory), + (Refiner, PolynomialWarpAndBlendGroupFactory), + (Refiner, PolynomialEquidistantSimplexGroupFactory), + + (RefinerWithoutAdjacency, InterpolatoryQuadratureSimplexGroupFactory), + (RefinerWithoutAdjacency, PolynomialWarpAndBlendGroupFactory), + (RefinerWithoutAdjacency, PolynomialEquidistantSimplexGroupFactory), + + (RefinerWithoutAdjacency, LegendreGaussLobattoTensorProductGroupFactory), + (RefinerWithoutAdjacency, GaussLegendreTensorProductGroupFactory), ]) @pytest.mark.parametrize(("mesh_name", "dim", "mesh_pars"), [ ("circle", 1, [20, 30, 40]), @@ -181,14 +188,19 @@ def test_refinement(case_name, mesh_gen, flag_gen, num_generations): def test_refinement_connection( actx_factory, refiner_cls, group_factory, mesh_name, dim, mesh_pars, mesh_order, refine_flags, visualize=False): + group_cls = group_factory.mesh_group_class + if issubclass(group_cls, TensorProductElementGroup): + if mesh_name in ["circle", "blob"]: + pytest.skip("mesh does not have tensor product support") + from random import seed seed(13) - # Discretization order - order = 5 - actx = actx_factory() + # discretization order + order = 5 + from meshmode.discretization import Discretization from meshmode.discretization.connection import ( make_refinement_connection, check_connection) @@ -213,7 +225,8 @@ def test_refinement_connection( h = float(mesh_par) elif mesh_name == "warp": from meshmode.mesh.generation import generate_warped_rect_mesh - mesh = generate_warped_rect_mesh(dim, order=mesh_order, n=mesh_par) + mesh = generate_warped_rect_mesh(dim, order=mesh_order, n=mesh_par, + group_cls=group_cls) h = 1/mesh_par else: raise ValueError("mesh_name not recognized") @@ -289,12 +302,17 @@ def f(x): or eoc_rec.max_error() < 1e-14) -@pytest.mark.parametrize("with_adjacency", [True, False]) -def test_uniform_refinement(with_adjacency): +@pytest.mark.parametrize(("group_cls", "with_adjacency"), [ + (SimplexElementGroup, True), + (SimplexElementGroup, False), + (TensorProductElementGroup, False) + ]) +def test_uniform_refinement(group_cls, with_adjacency): make_mesh = partial(generate_box_mesh, ( np.linspace(0.0, 1.0, 2), np.linspace(0.0, 1.0, 3), - np.linspace(0.0, 1.0, 2)), order=4) + np.linspace(0.0, 1.0, 2)), + order=4, group_cls=group_cls) mesh = make_mesh() from meshmode.mesh.refinement import refine_uniformly