From dc4ddc9d4e900b0806f5923d619657109c516463 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Mon, 4 Apr 2022 17:02:41 -0500 Subject: [PATCH 1/6] CHERRY-PICK: Attach discretization entity tags --- meshmode/array_context.py | 38 +++++++++++ meshmode/discretization/__init__.py | 26 ++++++-- meshmode/discretization/connection/direct.py | 23 ++++++- meshmode/discretization/connection/face.py | 7 +- meshmode/discretization/connection/modal.py | 7 +- .../discretization/connection/same_mesh.py | 11 ++-- meshmode/discretization/visualization.py | 10 ++- meshmode/transform_metadata.py | 66 ++++++++++++++++++- 8 files changed, 169 insertions(+), 19 deletions(-) diff --git a/meshmode/array_context.py b/meshmode/array_context.py index 12b6e3f48..6f48f0377 100644 --- a/meshmode/array_context.py +++ b/meshmode/array_context.py @@ -27,12 +27,14 @@ import sys from warnings import warn +from typing import Mapping, Sequence, Union from arraycontext import PyOpenCLArrayContext as PyOpenCLArrayContextBase from arraycontext import PytatoPyOpenCLArrayContext as PytatoPyOpenCLArrayContextBase from arraycontext.pytest import ( _PytestPyOpenCLArrayContextFactoryWithClass, _PytestPytatoPyOpenCLArrayContextFactory, register_pytest_array_context_factory) +from pytools.tag import Tag def thaw(actx, ary): @@ -235,6 +237,27 @@ def transform_loopy_program(self, t_unit): # {{{ pytato pyopencl array context subclass class PytatoPyOpenCLArrayContext(PytatoPyOpenCLArrayContextBase): + def transform_dag(self, dag): + dag = super().transform_dag(dag) + + # {{{ /!\ Remove tags from NamedArrays + # See + + import pytato as pt + + def untag_loopy_call_results(expr): + if isinstance(expr, pt.NamedArray): + return expr.copy(tags=frozenset(), + axes=(pt.Axis(frozenset()),)*expr.ndim) + else: + return expr + + dag = pt.transform.map_and_copy(dag, untag_loopy_call_results) + + # }}} + + return dag + def transform_loopy_program(self, t_unit): # FIXME: Do not parallelize for now. 
return t_unit @@ -326,4 +349,19 @@ def _import_names(): # }}} +# {{{ tagging helpers + +def tag_axes(ary, actx, dim_to_tags: Mapping[int, Union[Sequence[Tag], Tag]] + ): + """ + Return a copy of *ary* with the axes in *dim_to_tags* tagged with their + corresponding tags. + """ + for iaxis, tags in dim_to_tags.items(): + ary = actx.tag_axis(iaxis, tags, ary) + + return ary + +# }}} + # vim: foldmethod=marker diff --git a/meshmode/discretization/__init__.py b/meshmode/discretization/__init__.py index 6ce8850a2..b11fefa9e 100644 --- a/meshmode/discretization/__init__.py +++ b/meshmode/discretization/__init__.py @@ -40,13 +40,16 @@ from pytools import memoize_in, memoize_method, keyed_memoize_in from pytools.obj_array import make_obj_array from meshmode.transform_metadata import ( - ConcurrentElementInameTag, ConcurrentDOFInameTag, FirstAxisIsElementsTag) + ConcurrentElementInameTag, ConcurrentDOFInameTag, + FirstAxisIsElementsTag, DiscretizationElementAxisTag, + DiscretizationDOFAxisTag) # underscored because it shouldn't be imported from here. 
from meshmode.dof_array import DOFArray as _DOFArray from meshmode.mesh import ( Mesh as _Mesh, MeshElementGroup as _MeshElementGroup) +from meshmode.array_context import tag_axes __doc__ = """ @@ -542,9 +545,13 @@ def _new_array(self, actx, creation_func, dtype=None): else: dtype = np.dtype(dtype) - return _DOFArray(actx, tuple( - creation_func(shape=(grp.nelements, grp.nunit_dofs), dtype=dtype) - for grp in self.groups)) + result = _DOFArray(actx, + tuple(creation_func(shape=(grp.nelements, + grp.nunit_dofs), + dtype=dtype) + for grp in self.groups)) + return tag_axes(result, actx, {0: DiscretizationElementAxisTag(), + 1: DiscretizationDOFAxisTag()}) def empty(self, actx: ArrayContext, dtype: Optional[np.dtype] = None) -> _DOFArray: @@ -644,6 +651,8 @@ def nodes(self, cached: bool = True) -> np.ndarray: def resample_mesh_nodes(grp, iaxis): # TODO: would be nice to have the mesh use an array context already nodes = actx.from_numpy(grp.mesh_el_group.nodes[iaxis]) + nodes = tag_axes(nodes, actx, {0: DiscretizationElementAxisTag(), + 1: DiscretizationDOFAxisTag()}) grp_unit_nodes = grp.unit_nodes.reshape(-1) meg_unit_nodes = grp.mesh_el_group.unit_nodes.reshape(-1) @@ -654,7 +663,10 @@ def resample_mesh_nodes(grp, iaxis): return nodes return actx.einsum("ij,ej->ei", - actx.from_numpy(grp.from_mesh_interp_matrix()), + actx.tag_axis( + 0, + DiscretizationDOFAxisTag(), + actx.from_numpy(grp.from_mesh_interp_matrix())), nodes, tagged=(FirstAxisIsElementsTag(),)) @@ -714,7 +726,9 @@ def get_mat(grp, gref_axes): return _DOFArray(actx, tuple( actx.einsum("ij,ej->ei", - get_mat(grp, ref_axes), + actx.tag_axis(0, + DiscretizationDOFAxisTag(), + get_mat(grp, ref_axes)), vec[igrp], tagged=(FirstAxisIsElementsTag(),)) for igrp, grp in enumerate(discr.groups))) diff --git a/meshmode/discretization/connection/direct.py b/meshmode/discretization/connection/direct.py index c5dcd108a..97eff8495 100644 --- a/meshmode/discretization/connection/direct.py +++ 
b/meshmode/discretization/connection/direct.py @@ -29,7 +29,8 @@ import loopy as lp from meshmode.transform_metadata import ( - ConcurrentElementInameTag, ConcurrentDOFInameTag) + ConcurrentElementInameTag, ConcurrentDOFInameTag, + DiscretizationElementAxisTag, DiscretizationDOFAxisTag) from pytools import memoize_in, keyed_memoize_method from arraycontext import ( ArrayContext, NotAnArrayContainerError, @@ -39,6 +40,7 @@ from meshmode.discretization import Discretization, ElementGroupBase from meshmode.dof_array import DOFArray +from meshmode.array_context import tag_axes from dataclasses import dataclass @@ -370,7 +372,11 @@ def _resample_matrix(self, actx: ArrayContext, to_group_index: int, from_grp_basis_fcts, ibatch.result_unit_nodes, from_grp.unit_nodes) - return actx.freeze(actx.from_numpy(result)) + thawed_ary = actx.from_numpy(result) + + # freeze, attach metadata + return actx.freeze(tag_axes(thawed_ary, actx, + {1: DiscretizationDOFAxisTag()})) # }}} @@ -698,6 +704,13 @@ def group_pick_knl(is_surjective: bool): grp_ary_contrib, 0) + # attach metadata + grp_ary_contrib = tag_axes( + grp_ary_contrib, + actx, + {0: DiscretizationElementAxisTag(), + 1: DiscretizationDOFAxisTag()}) + group_array_contributions.append(grp_ary_contrib) else: for fgpd in group_pick_info: @@ -774,6 +787,12 @@ def group_pick_knl(is_surjective: bool): self.to_discr.groups[i_tgrp].nunit_dofs) )["result"] + # attach metadata + batch_result = tag_axes(batch_result, + actx, + {0: DiscretizationElementAxisTag(), + 1: DiscretizationDOFAxisTag()}) + group_array_contributions.append(batch_result) if group_array_contributions: diff --git a/meshmode/discretization/connection/face.py b/meshmode/discretization/connection/face.py index 28a03a599..34f65d893 100644 --- a/meshmode/discretization/connection/face.py +++ b/meshmode/discretization/connection/face.py @@ -21,6 +21,7 @@ """ from dataclasses import dataclass +from meshmode.transform_metadata import DiscretizationElementAxisTag import numpy 
as np import modepy as mp @@ -445,8 +446,10 @@ def make_face_to_all_faces_embedding(actx, faces_connection, all_faces_discr, assert all_faces_grp.nelements == nfaces * vol_grp.nelements to_element_indices = actx.freeze( - vol_grp.nelements*iface - + actx.thaw(src_batch.from_element_indices)) + actx.tag_axis(0, + DiscretizationElementAxisTag(), + vol_grp.nelements*iface + + actx.thaw(src_batch.from_element_indices))) batches.append( InterpolationBatch( diff --git a/meshmode/discretization/connection/modal.py b/meshmode/discretization/connection/modal.py index c0178563d..596c72033 100644 --- a/meshmode/discretization/connection/modal.py +++ b/meshmode/discretization/connection/modal.py @@ -29,7 +29,8 @@ from arraycontext import ( NotAnArrayContainerError, serialize_container, deserialize_container) -from meshmode.transform_metadata import FirstAxisIsElementsTag +from meshmode.transform_metadata import (FirstAxisIsElementsTag, + DiscretizationDOFAxisTag) from meshmode.discretization import InterpolatoryElementGroupBase from meshmode.discretization.poly_element import QuadratureSimplexElementGroup from meshmode.discretization.connection.direct import DiscretizationConnection @@ -163,7 +164,9 @@ def vandermonde_inverse(grp): return actx.from_numpy(vdm_inv) return actx.einsum("ij,ej->ei", - vandermonde_inverse(grp), + actx.tag_axis(0, + DiscretizationDOFAxisTag(), + vandermonde_inverse(grp)), ary, tagged=(FirstAxisIsElementsTag(),)) diff --git a/meshmode/discretization/connection/same_mesh.py b/meshmode/discretization/connection/same_mesh.py index 638bcb598..ba79d056a 100644 --- a/meshmode/discretization/connection/same_mesh.py +++ b/meshmode/discretization/connection/same_mesh.py @@ -21,6 +21,7 @@ """ import numpy as np +from meshmode.transform_metadata import DiscretizationElementAxisTag # {{{ same-mesh constructor @@ -42,10 +43,12 @@ def make_same_mesh_connection(actx, to_discr, from_discr): groups = [] for igrp, (fgrp, tgrp) in enumerate(zip(from_discr.groups, 
to_discr.groups)): all_elements = actx.freeze( - actx.from_numpy( - np.arange( - fgrp.nelements, - dtype=np.intp))) + actx.tag_axis(0, + DiscretizationElementAxisTag(), + actx.from_numpy( + np.arange( + fgrp.nelements, + dtype=np.intp)))) ibatch = InterpolationBatch( from_group_index=igrp, from_element_indices=all_elements, diff --git a/meshmode/discretization/visualization.py b/meshmode/discretization/visualization.py index 406abbca8..f47fc9afb 100644 --- a/meshmode/discretization/visualization.py +++ b/meshmode/discretization/visualization.py @@ -33,6 +33,7 @@ from pytools.obj_array import make_obj_array from arraycontext import thaw, flatten from meshmode.dof_array import DOFArray +from meshmode.transform_metadata import DiscretizationMeshNodesAxisTag from modepy.shapes import Shape, Simplex, Hypercube @@ -144,7 +145,9 @@ def _resample_to_numpy(conn, vis_discr, vec, *, stack=False, by_group=False): from meshmode.dof_array import check_dofarray_against_discr check_dofarray_against_discr(vis_discr, vec) - return actx.to_numpy(flatten(vec, actx)) + return actx.to_numpy(actx.tag_axis(0, + DiscretizationMeshNodesAxisTag(), + flatten(vec, actx))) else: raise TypeError(f"unsupported array type: {type(vec).__name__}") @@ -550,7 +553,10 @@ def copy_with_same_connectivity(self, actx, discr, skip_tests=False): def _vis_nodes_numpy(self): actx = self.vis_discr._setup_actx return np.array([ - actx.to_numpy(flatten(thaw(ary, actx), actx)) + actx.to_numpy(actx.tag_axis( + 0, + DiscretizationMeshNodesAxisTag(), + flatten(thaw(ary, actx), actx))) for ary in self.vis_discr.nodes() ]) diff --git a/meshmode/transform_metadata.py b/meshmode/transform_metadata.py index 981685c60..651d58707 100644 --- a/meshmode/transform_metadata.py +++ b/meshmode/transform_metadata.py @@ -2,6 +2,12 @@ .. autoclass:: FirstAxisIsElementsTag .. autoclass:: ConcurrentElementInameTag .. autoclass:: ConcurrentDOFInameTag +.. autoclass:: DiscretizationEntityAxisTag +.. 
autoclass:: DiscretizationElementAxisTag
+.. autoclass:: DiscretizationFaceAxisTag
+.. autoclass:: DiscretizationDOFAxisTag
+.. autoclass:: DiscretizationPhysicalDimAxisTag
+.. autoclass:: DiscretizationRefDimAxisTag
 """
 
 __copyright__ = """
@@ -28,7 +34,7 @@ THE SOFTWARE.
 """
 
-from pytools.tag import Tag
+from pytools.tag import Tag, tag_dataclass, UniqueTag
 
 
 class FirstAxisIsElementsTag(Tag):
@@ -57,3 +63,61 @@ class ConcurrentDOFInameTag(Tag):
     computations for all DOFs within each element may be performed
     concurrently.
     """
+
+
+class DiscretizationEntityAxisTag(UniqueTag):
+    """
+    A tag applicable to an array's axis to describe which discretization entity
+    the axis indexes over.
+    """
+
+
+@tag_dataclass
+class DiscretizationElementAxisTag(DiscretizationEntityAxisTag):
+    """
+    Array dimensions tagged with this tag type describe an axis indexing over
+    the discretization's elements.
+    """
+
+
+@tag_dataclass
+class DiscretizationFaceAxisTag(DiscretizationEntityAxisTag):
+    """
+    Array dimensions tagged with this tag type describe an axis indexing over
+    the discretization's facets.
+    """
+
+
+@tag_dataclass
+class DiscretizationDOFAxisTag(DiscretizationEntityAxisTag):
+    """
+    Array dimensions tagged with this tag type describe an axis indexing over
+    the discretization's DoFs.
+    """
+
+
+@tag_dataclass
+class DiscretizationMeshNodesAxisTag(DiscretizationEntityAxisTag):
+    """
+    Array dimensions tagged with this tag type describe an axis indexing over
+    the discretization's mesh nodes.
+    """
+
+
+@tag_dataclass
+class DiscretizationDimAxisTag(DiscretizationEntityAxisTag):
+    pass
+
+
+class DiscretizationPhysicalDimAxisTag(DiscretizationDimAxisTag):
+    """
+    Array dimensions tagged with this tag type describe an axis indexing over
+    the discretization's physical coordinate dimensions.
+    """
+
+
+class DiscretizationRefDimAxisTag(DiscretizationDimAxisTag):
+    """
+    Array dimensions tagged with this tag type describe an axis indexing over
+    the discretization's reference coordinate dimensions.
+    """

From 563f30d83b66cf21ca6b7fe16e20476cb0ffa4fa Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni
Date: Sat, 12 Mar 2022 13:00:21 -0600
Subject: [PATCH 2/6] CHERRY-PICK: SingleGridPytatoArrayContext

---
 examples/simple-dg.py     |  30 ++--
 meshmode/array_context.py | 284 ++++++++++++++++++++++++++++++++++++++
 meshmode/pytato_utils.py  |  62 +++++++++
 requirements.txt          |   2 +-
 4 files changed, 360 insertions(+), 18 deletions(-)
 create mode 100644 meshmode/pytato_utils.py

diff --git a/examples/simple-dg.py b/examples/simple-dg.py
index 8a41712c9..4599474fe 100644
--- a/examples/simple-dg.py
+++ b/examples/simple-dg.py
@@ -33,7 +33,7 @@
 from meshmode.mesh import BTAG_ALL, BTAG_NONE  # noqa
 from meshmode.dof_array import DOFArray, flat_norm
 from meshmode.array_context import (PyOpenCLArrayContext,
-        PytatoPyOpenCLArrayContext)
+        SingleGridWorkBalancingPytatoArrayContext as PytatoPyOpenCLArrayContext)
 from arraycontext import (
         freeze, thaw,
         ArrayContainer,
@@ -456,11 +456,10 @@ def main(lazy=False):
     cl_ctx = cl.create_some_context()
     queue = cl.CommandQueue(cl_ctx)
 
-    actx_outer = PyOpenCLArrayContext(queue, force_device_scalars=True)
     if lazy:
-        actx_rhs = PytatoPyOpenCLArrayContext(queue)
+        actx = PytatoPyOpenCLArrayContext(queue)
     else:
-        actx_rhs = actx_outer
+        actx = PyOpenCLArrayContext(queue, force_device_scalars=True)
 
     nel_1d = 16
     from meshmode.mesh.generation import generate_regular_rect_mesh
@@ -476,37 +475,34 @@
 
     logger.info("%d elements", mesh.nelements)
 
-    discr = DGDiscretization(actx_outer, mesh, order=order)
+    discr = DGDiscretization(actx, mesh, order=order)
 
     fields = WaveState(
-        u=bump(actx_outer, discr),
-        v=make_obj_array([discr.zeros(actx_outer) for i in range(discr.dim)]),
+        u=bump(actx, discr),
+        v=make_obj_array([discr.zeros(actx)
for i in range(discr.dim)]), ) from meshmode.discretization.visualization import make_visualizer - vis = make_visualizer(actx_outer, discr.volume_discr) + vis = make_visualizer(actx, discr.volume_discr) def rhs(t, q): - return wave_operator(actx_rhs, discr, c=1, q=q) + return wave_operator(actx, discr, c=1, q=q) - compiled_rhs = actx_rhs.compile(rhs) - - def rhs_wrapper(t, q): - r = compiled_rhs(t, thaw(freeze(q, actx_outer), actx_rhs)) - return thaw(freeze(r, actx_rhs), actx_outer) + compiled_rhs = actx.compile(rhs) t = np.float64(0) t_final = 3 istep = 0 while t < t_final: - fields = rk4_step(fields, t, dt, rhs_wrapper) + fields = thaw(freeze(fields, actx), actx) + fields = rk4_step(fields, t, dt, compiled_rhs) if istep % 10 == 0: # FIXME: Maybe an integral function to go with the # DOFArray would be nice? assert len(fields.u) == 1 logger.info("[%05d] t %.5e / %.5e norm %.5e", - istep, t, t_final, actx_outer.to_numpy(flat_norm(fields.u, 2))) + istep, t, t_final, actx.to_numpy(flat_norm(fields.u, 2))) vis.write_vtk_file("fld-wave-min-%04d.vtu" % istep, [ ("q", fields), ]) @@ -514,7 +510,7 @@ def rhs_wrapper(t, q): t += dt istep += 1 - assert flat_norm(fields.u, 2) < 100 + assert actx.to_numpy(flat_norm(fields.u, 2)) < 100 if __name__ == "__main__": diff --git a/meshmode/array_context.py b/meshmode/array_context.py index 6f48f0377..fb2386fd9 100644 --- a/meshmode/array_context.py +++ b/meshmode/array_context.py @@ -26,6 +26,8 @@ """ import sys +import logging + from warnings import warn from typing import Mapping, Sequence, Union from arraycontext import PyOpenCLArrayContext as PyOpenCLArrayContextBase @@ -35,6 +37,9 @@ _PytestPytatoPyOpenCLArrayContextFactory, register_pytest_array_context_factory) from pytools.tag import Tag +from loopy.translation_unit import for_each_kernel + +logger = logging.getLogger(__name__) def thaw(actx, ary): @@ -364,4 +369,283 @@ def tag_axes(ary, actx, dim_to_tags: Mapping[int, Union[Sequence[Tag], Tag]] # }}} + +@for_each_kernel +def 
_single_grid_work_group_transform(kernel, cl_device): + import loopy as lp + from meshmode.transform_metadata import (ConcurrentElementInameTag, + ConcurrentDOFInameTag) + + splayed_inames = set() + ngroups = cl_device.max_compute_units * 4 # '4' to overfill the device + l_one_size = 4 + l_zero_size = 16 + + for insn in kernel.instructions: + if insn.within_inames in splayed_inames: + continue + + if isinstance(insn, lp.CallInstruction): + # must be a callable kernel, don't touch. + pass + elif isinstance(insn, lp.Assignment): + bigger_loop = None + smaller_loop = None + + if len(insn.within_inames) == 0: + continue + + if len(insn.within_inames) == 1: + iname, = insn.within_inames + + kernel = lp.split_iname(kernel, iname, + ngroups * l_zero_size * l_one_size) + kernel = lp.split_iname(kernel, f"{iname}_inner", + l_zero_size, inner_tag="l.0") + kernel = lp.split_iname(kernel, f"{iname}_inner_outer", + l_one_size, inner_tag="l.1", + outer_tag="g.0") + + splayed_inames.add(insn.within_inames) + continue + + for iname in insn.within_inames: + if kernel.iname_tags_of_type(iname, + ConcurrentElementInameTag): + assert bigger_loop is None + bigger_loop = iname + elif kernel.iname_tags_of_type(iname, + ConcurrentDOFInameTag): + assert smaller_loop is None + smaller_loop = iname + else: + pass + + if bigger_loop or smaller_loop: + assert (bigger_loop is not None + and smaller_loop is not None) + else: + sorted_inames = sorted(tuple(insn.within_inames), + key=kernel.get_constant_iname_length) + smaller_loop = sorted_inames[0] + bigger_loop = sorted_inames[-1] + + kernel = lp.split_iname(kernel, f"{bigger_loop}", + l_one_size * ngroups) + kernel = lp.split_iname(kernel, f"{bigger_loop}_inner", + l_one_size, inner_tag="l.1", outer_tag="g.0") + kernel = lp.split_iname(kernel, smaller_loop, + l_zero_size, inner_tag="l.0") + splayed_inames.add(insn.within_inames) + elif isinstance(insn, lp.BarrierInstruction): + pass + else: + raise NotImplementedError(type(insn)) + + return 
kernel + + +def _alias_global_temporaries(t_unit): + """ + Returns a copy of *t_unit* with temporaries of that have disjoint live + intervals using the same :attr:`loopy.TemporaryVariable.base_storage`. + """ + from loopy.kernel.data import AddressSpace + from loopy.kernel import KernelState + from loopy.schedule import (RunInstruction, EnterLoop, LeaveLoop, + CallKernel, ReturnFromKernel, Barrier) + from loopy.schedule.tools import get_return_from_kernel_mapping + from pytools import UniqueNameGenerator + from collections import defaultdict + + kernel = t_unit.default_entrypoint + assert kernel.state == KernelState.LINEARIZED + temp_vars = frozenset(tv.name + for tv in kernel.temporary_variables.values() + if tv.address_space == AddressSpace.GLOBAL) + temp_to_live_interval_start = {} + temp_to_live_interval_end = {} + return_from_kernel_idxs = get_return_from_kernel_mapping(kernel) + + for sched_idx, sched_item in enumerate(kernel.linearization): + if isinstance(sched_item, RunInstruction): + for var in (kernel.id_to_insn[sched_item.insn_id].dependency_names() + & temp_vars): + if var not in temp_to_live_interval_start: + assert var not in temp_to_live_interval_end + temp_to_live_interval_start[var] = sched_idx + assert var in temp_to_live_interval_start + temp_to_live_interval_end[var] = return_from_kernel_idxs[sched_idx] + elif isinstance(sched_item, (EnterLoop, LeaveLoop, CallKernel, + ReturnFromKernel, Barrier)): + # no variables are accessed within these schedule items => do + # nothing. + pass + else: + raise NotImplementedError(type(sched_item)) + + vng = UniqueNameGenerator() + # a mapping from shape to the available base storages from temp variables + # that were dead. 
+ shape_to_available_base_storage = defaultdict(set) + + sched_idx_to_just_live_temp_vars = [set() for _ in kernel.linearization] + sched_idx_to_just_dead_temp_vars = [set() for _ in kernel.linearization] + + for tv, just_alive_idx in temp_to_live_interval_start.items(): + sched_idx_to_just_live_temp_vars[just_alive_idx].add(tv) + + for tv, just_dead_idx in temp_to_live_interval_end.items(): + sched_idx_to_just_dead_temp_vars[just_dead_idx].add(tv) + + new_tvs = {} + + for sched_idx, _ in enumerate(kernel.linearization): + just_dead_temps = sched_idx_to_just_dead_temp_vars[sched_idx] + to_be_allocated_temps = sched_idx_to_just_live_temp_vars[sched_idx] + for tv_name in sorted(just_dead_temps): + tv = new_tvs[tv_name] + assert tv.base_storage is not None + assert tv.base_storage not in shape_to_available_base_storage[tv.nbytes] + shape_to_available_base_storage[tv.nbytes].add(tv.base_storage) + + for tv_name in sorted(to_be_allocated_temps): + assert len(to_be_allocated_temps) <= 1 + tv = kernel.temporary_variables[tv_name] + assert tv.name not in new_tvs + assert tv.base_storage is None + if shape_to_available_base_storage[tv.nbytes]: + base_storage = sorted(shape_to_available_base_storage[tv.nbytes])[0] + shape_to_available_base_storage[tv.nbytes].remove(base_storage) + else: + base_storage = vng("_msh_actx_tmp_base") + + new_tvs[tv.name] = tv.copy(base_storage=base_storage) + + for name, tv in kernel.temporary_variables.items(): + if tv.address_space != AddressSpace.GLOBAL: + new_tvs[name] = tv + else: + assert name in new_tvs + + kernel = kernel.copy(temporary_variables=new_tvs) + + return t_unit.with_kernel(kernel) + + +def _can_be_eagerly_computed(ary) -> bool: + from pytato.transform import InputGatherer + from pytato.array import Placeholder + return all(not isinstance(inp, Placeholder) + for inp in InputGatherer()(ary)) + + +def deduplicate_data_wrappers(dag): + import pytato as pt + data_wrapper_cache = {} + data_wrappers_encountered = 0 + + def 
cached_data_wrapper_if_present(ary): + nonlocal data_wrappers_encountered + + if isinstance(ary, pt.DataWrapper): + + data_wrappers_encountered += 1 + cache_key = (ary.data.base_data.int_ptr, ary.data.offset, + ary.shape, ary.data.strides) + try: + result = data_wrapper_cache[cache_key] + except KeyError: + result = ary + data_wrapper_cache[cache_key] = result + + return result + else: + return ary + + dag = pt.transform.map_and_copy(dag, cached_data_wrapper_if_present) + + if data_wrappers_encountered: + logger.info("data wrapper de-duplication: " + "%d encountered, %d kept, %d eliminated", + data_wrappers_encountered, + len(data_wrapper_cache), + data_wrappers_encountered - len(data_wrapper_cache)) + + return dag + + +class SingleGridWorkBalancingPytatoArrayContext(PytatoPyOpenCLArrayContextBase): + """ + A :class:`PytatoPyOpenCLArrayContext` that parallelizes work in an OpenCL + kernel so that the work + """ + def transform_loopy_program(self, t_unit): + import loopy as lp + + t_unit = _single_grid_work_group_transform(t_unit, self.queue.device) + t_unit = lp.set_options(t_unit, "insert_gbarriers") + t_unit = lp.linearize(lp.preprocess_kernel(t_unit)) + t_unit = _alias_global_temporaries(t_unit) + + return t_unit + + def _get_fake_numpy_namespace(self): + from meshmode.pytato_utils import ( + EagerReduceComputingPytatoFakeNumpyNamespace) + return EagerReduceComputingPytatoFakeNumpyNamespace(self) + + def transform_dag(self, dag): + import pytato as pt + + # {{{ face_mass: materialize einsum args + + def materialize_face_mass_vec(expr): + if (isinstance(expr, pt.Einsum) + and pt.analysis.is_einsum_similar_to_subscript( + expr, "ifj,fej,fej->ei")): + mat, jac, vec = expr.args + return pt.einsum("ifj,fej,fej->ei", + mat, + jac, + vec.tagged(pt.tags.ImplStored())) + else: + return expr + + dag = pt.transform.map_and_copy(dag, materialize_face_mass_vec) + + # }}} + + # {{{ materialize all einsums + + def materialize_einsums(ary: pt.Array) -> pt.Array: + if 
isinstance(ary, pt.Einsum): + return ary.tagged(pt.tags.ImplStored()) + + return ary + + dag = pt.transform.map_and_copy(dag, materialize_einsums) + + # }}} + + dag = pt.transform.materialize_with_mpms(dag) + dag = deduplicate_data_wrappers(dag) + + # {{{ /!\ Remove tags from Loopy call results. + # See + + def untag_loopy_call_results(expr): + from pytato.loopy import LoopyCallResult + if isinstance(expr, LoopyCallResult): + return expr.copy(tags=frozenset(), + axes=(pt.Axis(frozenset()),)*expr.ndim) + else: + return expr + + dag = pt.transform.map_and_copy(dag, untag_loopy_call_results) + + # }}} + + return dag + # vim: foldmethod=marker diff --git a/meshmode/pytato_utils.py b/meshmode/pytato_utils.py new file mode 100644 index 000000000..960decd4e --- /dev/null +++ b/meshmode/pytato_utils.py @@ -0,0 +1,62 @@ +from functools import partial, reduce +from arraycontext.impl.pytato.fake_numpy import PytatoFakeNumpyNamespace +from arraycontext import rec_map_reduce_array_container +import pyopencl.array as cl_array + + +def _can_be_eagerly_computed(ary) -> bool: + from pytato.transform import InputGatherer + from pytato.array import Placeholder + return all(not isinstance(inp, Placeholder) + for inp in InputGatherer()(ary)) + + +class EagerReduceComputingPytatoFakeNumpyNamespace(PytatoFakeNumpyNamespace): + """ + A Numpy-namespace that computes the reductions eagerly whenever possible. 
+ """ + def sum(self, a, axis=None, dtype=None): + if (rec_map_reduce_array_container(lambda x, y: x and y, + _can_be_eagerly_computed, a) + and axis is None): + + def _pt_sum(ary): + return cl_array.sum(self._array_context.freeze(ary), + dtype=dtype, + queue=self._array_context.queue) + + return self._array_context.thaw(rec_map_reduce_array_container(sum, + _pt_sum, + a)) + else: + return super().sum(a, axis=axis, dtype=dtype) + + def min(self, a, axis=None): + if (rec_map_reduce_array_container(lambda x, y: x and y, + _can_be_eagerly_computed, a) + and axis is None): + queue = self._array_context.queue + frozen_result = rec_map_reduce_array_container( + partial(reduce, partial(cl_array.minimum, queue=queue)), + lambda ary: cl_array.min(self._array_context.freeze(ary), + queue=queue), + a) + return self._array_context.thaw(frozen_result) + else: + return super().min(a, axis=axis) + + def max(self, a, axis=None): + if (rec_map_reduce_array_container(lambda x, y: x and y, + _can_be_eagerly_computed, a) + and axis is None): + queue = self._array_context.queue + frozen_result = rec_map_reduce_array_container( + partial(reduce, partial(cl_array.maximum, queue=queue)), + lambda ary: cl_array.max(self._array_context.freeze(ary), + queue=queue), + a) + return self._array_context.thaw(frozen_result) + else: + return super().max(a, axis=axis) + +# vim: fdm=marker diff --git a/requirements.txt b/requirements.txt index 7ad3c51b2..35a924099 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ git+https://github.com/inducer/pytato.git#egg=pytato git+https://github.com/inducer/pymbolic.git#egg=pymbolic # also depends on pymbolic, so should come after it -git+https://github.com/inducer/loopy.git#egg=loopy +git+https://github.com/kaushikcfd/loopy.git#egg=loopy # depends on loopy, so should come after it. 
git+https://github.com/inducer/arraycontext.git#egg=arraycontext From 4c7889cebe3b9cb0c3514ee6faf2680b6ffcef88 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sat, 12 Mar 2022 13:03:32 -0600 Subject: [PATCH 3/6] Adds more pytato utils: Axes types unification --- meshmode/pytato_utils.py | 575 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 574 insertions(+), 1 deletion(-) diff --git a/meshmode/pytato_utils.py b/meshmode/pytato_utils.py index 960decd4e..a7a272afa 100644 --- a/meshmode/pytato_utils.py +++ b/meshmode/pytato_utils.py @@ -1,7 +1,23 @@ +import pyopencl.array as cl_array +import kanren +import pytato as pt +import unification +import logging + from functools import partial, reduce from arraycontext.impl.pytato.fake_numpy import PytatoFakeNumpyNamespace from arraycontext import rec_map_reduce_array_container -import pyopencl.array as cl_array +from meshmode.transform_metadata import DiscretizationEntityAxisTag +from pytato.loopy import LoopyCall +from pytato.array import EinsumElementwiseAxis, EinsumReductionAxis +from pytato.transform import ArrayOrNames +from arraycontext import ArrayContainer +from arraycontext.container.traversal import rec_map_array_container +from typing import Set, Mapping, Tuple, Union +logger = logging.getLogger(__name__) + + +MAX_UNIFY_RETRIES = 50 # used by unify_discretization_entity_tags def _can_be_eagerly_computed(ary) -> bool: @@ -59,4 +75,561 @@ def max(self, a, axis=None): else: return super().max(a, axis=axis) + +# {{{ solve for discretization metadata for arrays' axes + +class DiscretizationEntityConstraintCollector(pt.transform.Mapper): + """ + .. warning:: + + Instances of this mapper type store state that are only for visiting a + single DAG. Using a single instance for collecting the constraints on + multiple DAGs is undefined behavior. 
+ """ + def __init__(self): + super().__init__() + self._visited_ids: Set[int] = set() + + # axis_to_var: mapping from (array, iaxis) to the kanren variable to be + # used for unification. + self.axis_to_tag_var: Mapping[Tuple[pt.Array, int], + unification.variable.Var] = {} + self.variables_to_solve: Set[unification.variable.Var] = set() + self.constraints = [] + + # type-ignore reason: CachedWalkMapper.rec's type does not match + # WalkMapper.rec's type + def rec(self, expr: ArrayOrNames) -> None: # type: ignore + if id(expr) in self._visited_ids: + return + + # type-ignore reason: super().rec expects either 'Array' or + # 'AbstractResultWithNamedArrays', passed 'ArrayOrNames' + super().rec(expr) # type: ignore + self._visited_ids.add(id(expr)) + + def get_kanren_var_for_axis_tag(self, + expr: pt.Array, + iaxis: int + ) -> unification.variable.Var: + key = (expr, iaxis) + + if key not in self.axis_to_tag_var: + self.axis_to_tag_var[key] = kanren.var() + + return self.axis_to_tag_var[key] + + def _record_all_axes_to_be_solved_if_impl_stored(self, expr): + if expr.tags_of_type(pt.tags.ImplStored): + for iaxis in range(expr.ndim): + self.variables_to_solve.add(self.get_kanren_var_for_axis_tag(expr, + iaxis)) + + def _record_all_axes_to_be_solved(self, expr): + for iaxis in range(expr.ndim): + self.variables_to_solve.add(self.get_kanren_var_for_axis_tag(expr, + iaxis)) + + def record_constraint(self, lhs, rhs): + self.constraints.append((lhs, rhs)) + + def record_eq_constraints_from_tags(self, expr: pt.Array) -> None: + for iaxis, axis in enumerate(expr.axes): + if axis.tags_of_type(DiscretizationEntityAxisTag): + discr_tag, = axis.tags_of_type(DiscretizationEntityAxisTag) + axis_var = self.get_kanren_var_for_axis_tag(expr, iaxis) + self.record_constraint(axis_var, discr_tag) + + def _map_input_base(self, expr: pt.InputArgumentBase + ) -> None: + self.record_eq_constraints_from_tags(expr) + self._record_all_axes_to_be_solved_if_impl_stored(expr) + + for dim in 
expr.shape: + if isinstance(dim, pt.Array): + self.rec(dim) + + map_placeholder = _map_input_base + map_data_wrapper = _map_input_base + map_size_param = _map_input_base + + def map_index_lambda(self, expr: pt.IndexLambda) -> None: + from pytato.utils import are_shape_components_equal + from pytato.raising import index_lambda_to_high_level_op + from pytato.raising import (BinaryOp, FullOp, WhereOp, + BroadcastOp, C99CallOp, ReduceOp) + + # {{{ record constraints for expr and its subexprs. + + self.record_eq_constraints_from_tags(expr) + self._record_all_axes_to_be_solved_if_impl_stored(expr) + + for dim in expr.shape: + if isinstance(dim, pt.Array): + self.rec(dim) + + for bnd in expr.bindings.values(): + self.rec(bnd) + + # }}} + + hlo = index_lambda_to_high_level_op(expr) + + if isinstance(hlo, BinaryOp): + subexprs = (hlo.x1, hlo.x2) + elif isinstance(hlo, WhereOp): + subexprs = (hlo.condition, hlo.then, hlo.else_) + elif isinstance(hlo, FullOp): + # A full-op does not impose any constraints + subexprs = () + elif isinstance(hlo, BroadcastOp): + subexprs = (hlo.x,) + elif isinstance(hlo, C99CallOp): + subexprs = hlo.args + elif isinstance(hlo, ReduceOp): + # {{{ ReduceOp doesn't quite involve broadcasting + + i_out_axis = 0 + for i_in_axis in range(hlo.x.ndim): + if i_in_axis not in hlo.axes: + in_tag_var = self.get_kanren_var_for_axis_tag(hlo.x, + i_in_axis) + out_tag_var = self.get_kanren_var_for_axis_tag(expr, + i_out_axis) + self.record_constraint(in_tag_var, out_tag_var) + i_out_axis += 1 + + assert i_out_axis == expr.ndim + + # }}} + + for axis in hlo.axes: + self.variables_to_solve.add(self.get_kanren_var_for_axis_tag(hlo.x, + axis)) + return + + else: + raise NotImplementedError(type(hlo)) + + for subexpr in subexprs: + if isinstance(subexpr, pt.Array): + for i_in_axis, i_out_axis in zip( + range(subexpr.ndim), + range(expr.ndim-subexpr.ndim, expr.ndim)): + in_dim = subexpr.shape[i_in_axis] + out_dim = expr.shape[i_out_axis] + if 
are_shape_components_equal(in_dim, out_dim): + in_tag_var = self.get_kanren_var_for_axis_tag(subexpr, + i_in_axis) + out_tag_var = self.get_kanren_var_for_axis_tag(expr, + i_out_axis) + + self.record_constraint(in_tag_var, out_tag_var) + else: + # broadcasted axes, cannot belong to the same + # discretization entity. + assert are_shape_components_equal(in_dim, 1) + + def map_stack(self, expr: pt.Stack) -> None: + self.record_eq_constraints_from_tags(expr) + self._record_all_axes_to_be_solved_if_impl_stored(expr) + # TODO; I think the axis corresponding to 'axis' need not be solved. + for ary in expr.arrays: + self.rec(ary) + + for iaxis in range(expr.ndim): + for ary in expr.arrays: + if iaxis < expr.axis: + in_tag_var = self.get_kanren_var_for_axis_tag(ary, + iaxis) + out_tag_var = self.get_kanren_var_for_axis_tag(expr, + iaxis) + + self.record_constraint(in_tag_var, out_tag_var) + elif iaxis == expr.axis: + pass + elif iaxis > expr.axis: + in_tag_var = self.get_kanren_var_for_axis_tag(ary, + iaxis-1) + out_tag_var = self.get_kanren_var_for_axis_tag(expr, + iaxis) + + self.record_constraint(in_tag_var, out_tag_var) + else: + raise AssertionError + + def map_concatenate(self, expr: pt.Concatenate) -> None: + self.record_eq_constraints_from_tags(expr) + self._record_all_axes_to_be_solved_if_impl_stored(expr) + # TODO; I think the axis corresponding to 'axis' need not be solved. + for ary in expr.arrays: + self.rec(ary) + + for ary in expr.arrays: + assert ary.ndim == expr.ndim + for iaxis in range(expr.ndim): + if iaxis != expr.axis: + # non-concatenated axes share the dimensions. 
+ in_tag_var = self.get_kanren_var_for_axis_tag(ary, + iaxis) + out_tag_var = self.get_kanren_var_for_axis_tag(expr, + iaxis) + self.record_constraint(in_tag_var, out_tag_var) + + def map_axis_permutation(self, expr: pt.AxisPermutation + ) -> None: + self.record_eq_constraints_from_tags(expr) + self._record_all_axes_to_be_solved_if_impl_stored(expr) + self.rec(expr.array) + + assert expr.ndim == expr.array.ndim + + for out_axis in range(expr.ndim): + in_axis = expr.axis_permutation[out_axis] + out_tag = self.get_kanren_var_for_axis_tag(expr, out_axis) + in_tag = self.get_kanren_var_for_axis_tag(expr, in_axis) + self.record_constraint(out_tag, in_tag) + + def map_basic_index(self, expr: pt.IndexBase) -> None: + from pytato.array import NormalizedSlice + from pytato.utils import are_shape_components_equal + + self.record_eq_constraints_from_tags(expr) + self._record_all_axes_to_be_solved_if_impl_stored(expr) + self.rec(expr.array) + + i_out_axis = 0 + + assert len(expr.indices) == expr.array.ndim + + for i_in_axis, idx in enumerate(expr.indices): + if isinstance(idx, int): + pass + else: + assert isinstance(idx, NormalizedSlice) + if (idx.step == 1 + and are_shape_components_equal(idx.start, 0) + and are_shape_components_equal(idx.stop, + expr.array.shape[i_in_axis])): + + i_in_axis_tag = self.get_kanren_var_for_axis_tag(expr.array, + i_in_axis) + i_out_axis_tag = self.get_kanren_var_for_axis_tag(expr, + i_out_axis) + self.record_constraint(i_in_axis_tag, i_out_axis_tag) + + i_out_axis += 1 + + def map_contiguous_advanced_index(self, + expr: pt.AdvancedIndexInContiguousAxes + ) -> None: + from pytato.array import NormalizedSlice + from pytato.utils import (partition, get_shape_after_broadcasting, + are_shapes_equal, are_shape_components_equal) + + self.record_eq_constraints_from_tags(expr) + self._record_all_axes_to_be_solved_if_impl_stored(expr) + self.rec(expr.array) + for idx in expr.indices: + if isinstance(idx, pt.Array): + self.rec(idx) + + i_adv_indices, 
i_basic_indices = partition( + lambda idx: isinstance(expr.indices[idx], NormalizedSlice), + range(len(expr.indices))) + npre_advanced_basic_indices = len([i_idx + for i_idx in i_basic_indices + if i_idx < i_adv_indices[0]]) + npost_advanced_basic_indices = len([i_idx + for i_idx in i_basic_indices + if i_idx > i_adv_indices[-1]]) + + indirection_arrays = [expr.indices[i_idx] for i_idx in i_adv_indices] + assert are_shapes_equal( + get_shape_after_broadcasting(indirection_arrays), + expr.shape[ + npre_advanced_basic_indices:expr.ndim-npost_advanced_basic_indices]) + + for subexpr in indirection_arrays: + if isinstance(subexpr, pt.Array): + for i_in_axis, i_out_axis in zip( + range(subexpr.ndim), + range(expr.ndim-subexpr.ndim+npre_advanced_basic_indices, + expr.ndim-npost_advanced_basic_indices)): + in_dim = subexpr.shape[i_in_axis] + out_dim = expr.shape[i_out_axis] + if are_shape_components_equal(in_dim, out_dim): + in_tag_var = self.get_kanren_var_for_axis_tag(subexpr, + i_in_axis) + out_tag_var = self.get_kanren_var_for_axis_tag(expr, + i_out_axis) + + self.record_constraint(in_tag_var, out_tag_var) + else: + # broadcasted axes, cannot belong to the same + # discretization entity. + assert are_shape_components_equal(in_dim, 1) + + def map_non_contiguous_advanced_index(self, + expr: pt.AdvancedIndexInNoncontiguousAxes + ) -> None: + self.record_eq_constraints_from_tags(expr) + self._record_all_axes_to_be_solved_if_impl_stored(expr) + self.rec(expr.array) + for idx in expr.indices: + if isinstance(idx, pt.Array): + self.rec(idx) + + def map_reshape(self, expr: pt.Reshape) -> None: + self.record_eq_constraints_from_tags(expr) + self._record_all_axes_to_be_solved_if_impl_stored(expr) + self.rec(expr.array) + # we can add constraints to reshape that only include new axes in its + # reshape. + # Other reshapes do not 'conserve' the types in our type-system. + # Well *what if*. Let's just say this type inference fails for + # non-trivial 'reshapes'. 
So, what are the 'trivial' reshapes? + # trivial reshapes: + # (x1, x2, ... xn) -> ((1,)*, x1, (1,)*, x2, (1,)*, x3, (1,)*, ..., xn, 1*) + # given all(x1!=1, x2!=1, x3!=1, .. xn!= 1) + if ((1 not in (expr.array.shape)) # leads to ambiguous newaxis + and (set(expr.shape) <= (set(expr.array.shape) | {1}))): + i_in_axis = 0 + for i_out_axis, dim in enumerate(expr.shape): + if dim != 1: + assert dim == expr.array.shape[i_in_axis] + i_in_axis_tag = self.get_kanren_var_for_axis_tag(expr.array, + i_in_axis) + i_out_axis_tag = self.get_kanren_var_for_axis_tag(expr, + i_out_axis) + self.record_constraint(i_in_axis_tag, i_out_axis_tag) + i_in_axis += 1 + else: + # print(f"Skipping: {expr.array.shape} -> {expr.shape}") + # Wacky reshape => bail. + return + + def map_einsum(self, expr: pt.Einsum) -> None: + + self.record_eq_constraints_from_tags(expr) + self._record_all_axes_to_be_solved_if_impl_stored(expr) + + for arg in expr.args: + self.rec(arg) + + descr_to_tag = {} + for iaxis in range(expr.ndim): + descr_to_tag[EinsumElementwiseAxis(iaxis)] = ( + self.get_kanren_var_for_axis_tag(expr, iaxis)) + + for access_descrs, arg in zip(expr.access_descriptors, + expr.args): + # if an einsum is stored => every argument's axes must + # also be inferred, even those that are getting reduced. 
+ for iarg_axis, descr in enumerate(access_descrs): + in_tag_var = self.get_kanren_var_for_axis_tag(arg, + iarg_axis) + + if descr in descr_to_tag: + self.record_constraint(descr_to_tag[descr], in_tag_var) + else: + descr_to_tag[descr] = in_tag_var + + if isinstance(descr, EinsumReductionAxis): + self.variables_to_solve.add(in_tag_var) + + def map_dict_of_named_arrays(self, expr: pt.DictOfNamedArrays + ) -> None: + for _, subexpr in sorted(expr._data.items()): + self.rec(subexpr) + self._record_all_axes_to_be_solved(subexpr) + + def map_loopy_call(self, expr: LoopyCall) -> None: + for _, subexpr in sorted(expr.bindings.items()): + if isinstance(subexpr, pt.Array): + if not isinstance(subexpr, pt.InputArgumentBase): + self._record_all_axes_to_be_solved(subexpr) + self.rec(subexpr) + + # there's really no good way to propagate the metadata in this case. + # One *could* raise the loopy kernel instruction expressions to + # high level ops, but that's really involved and probably not worth it. 
+ + def map_named_array(self, expr: pt.NamedArray) -> None: + self.record_eq_constraints_from_tags(expr) + self.rec(expr._container) + + def map_distributed_send_ref_holder(self, + expr: pt.DistributedSendRefHolder + ) -> None: + self.record_eq_constraints_from_tags(expr) + self.rec(expr.passthrough_data) + for idim in range(expr.ndim): + assert (expr.passthrough_data.shape[idim] + == expr.shape[idim]) + self.record_constraint( + self.get_kanren_var_for_axis_tag(expr.passthrough_data, + idim), + self.get_kanren_var_for_axis_tag(expr, idim) + ) + + def map_distributed_recv(self, + expr: pt.DistributedRecv) -> None: + self.record_eq_constraints_from_tags(expr) + + +def unify_discretization_entity_tags(expr: Union[ArrayContainer, ArrayOrNames] + ) -> ArrayOrNames: + if not isinstance(expr, (pt.Array, pt.DictOfNamedArrays)): + return rec_map_array_container(unify_discretization_entity_tags, + expr) + + from collections import defaultdict + discr_unification_helper = DiscretizationEntityConstraintCollector() + discr_unification_helper(expr) + tag_var_to_axis = {} + variables_to_solve = [] + + for (axis, var) in discr_unification_helper.axis_to_tag_var.items(): + tag_var_to_axis[var] = axis + if var in discr_unification_helper.variables_to_solve: + variables_to_solve.append(var) + + lhs = [cnstrnt[0] for cnstrnt in discr_unification_helper.constraints] + rhs = [cnstrnt[1] for cnstrnt in discr_unification_helper.constraints] + assert len(lhs) == len(rhs) + solutions = {} + + for i_retry in range(MAX_UNIFY_RETRIES): + old_solutions = solutions.copy() + solutions = unification.unify(lhs, rhs, + {l_expr: r_expr + for l_expr, r_expr in solutions.items() + if isinstance(r_expr, + DiscretizationEntityAxisTag)}) + if solutions == old_solutions: + logger.info(f"Unification converged after {i_retry} iterations.") + break + else: + logger.warn(f"Could not converge after {MAX_UNIFY_RETRIES} iterations.") + + # Ideally it might be better to enable this, but that would be too + # 
restrictive as not all computation graphs result in DOFArray ouptuts + # if not (frozenset(variables_to_solve) <= frozenset(solutions)): + # raise RuntimeError("Unification failed.") + + # ary_to_axes_tags: mapping from array to a mapping from iaxis to the + # solved tag. + ary_to_axes_tags = defaultdict(dict) + for var in solutions: + ary, axis = tag_var_to_axis[var] + if isinstance(solutions[var], DiscretizationEntityAxisTag): + ary_to_axes_tags[ary][axis] = solutions[var] + if var in variables_to_solve and ( + not isinstance(solutions[var], DiscretizationEntityAxisTag)): + raise RuntimeError(f"Could not solve for {var}.") + + def attach_tags(expr: ArrayOrNames) -> ArrayOrNames: + if not isinstance(expr, pt.Array): + return expr + + for iaxis, solved_tag in ary_to_axes_tags[expr].items(): + if expr.axes[iaxis].tags_of_type(DiscretizationEntityAxisTag): + discr_tag, = (expr + .axes[iaxis] + .tags_of_type(DiscretizationEntityAxisTag)) + assert discr_tag == solved_tag + else: + if not isinstance(solved_tag, DiscretizationEntityAxisTag): + actual_tag = discr_unification_helper.axis_to_tag_var[(expr, + iaxis)] + assert actual_tag in discr_unification_helper.variables_to_solve + assert actual_tag in variables_to_solve + raise ValueError(f"In {expr!r}, axis={iaxis}'s type cannot be " + "inferred.") + expr = expr.with_tagged_axis(iaxis, solved_tag) + + if isinstance(expr, pt.Einsum): + redn_descr_to_entity_type = {} + for access_descrs, arg in zip(expr.access_descriptors, + expr.args): + for iaxis, access_descr in enumerate(access_descrs): + if isinstance(access_descr, EinsumReductionAxis): + redn_descr_to_entity_type[access_descr] = ( + ary_to_axes_tags[arg][iaxis]) + + if (frozenset(redn_descr_to_entity_type) + != frozenset(expr.redn_descr_to_redn_dim)): + raise ValueError + + for redn_descr, solved_tag in redn_descr_to_entity_type.items(): + if not isinstance(solved_tag, DiscretizationEntityAxisTag): + raise ValueError(f"In {expr!r}, redn_descr={redn_descr}'s" + " 
type cannot be inferred.") + expr = expr.with_tagged_redn_dim(redn_descr, solved_tag) + + if isinstance(expr, pt.IndexLambda): + from pytato.raising import (index_lambda_to_high_level_op, + ReduceOp) + + hlo = index_lambda_to_high_level_op(expr) + if isinstance(hlo, ReduceOp): + for iaxis in hlo.axes: + solved_tag = ary_to_axes_tags[hlo.x][iaxis] + if not isinstance(solved_tag, DiscretizationEntityAxisTag): + raise ValueError(f"In {expr!r}, redn_descr={iaxis}'s" + " type cannot be inferred.") + + expr = expr.with_tagged_redn_dim(iaxis, solved_tag) + + return expr + + return pt.transform.map_and_copy(expr, attach_tags) + +# }}} + + +class UnInferredStoredArrayCatcher(pt.transform.CachedWalkMapper): + """ + Raises a :class:`ValueError` if a stored array has axes without a + :class:`DiscretizationEntityAxisTag` tagged to it. + """ + def post_visit(self, expr: ArrayOrNames) -> None: + if (isinstance(expr, pt.Array) + and expr.tags_of_type(pt.tags.ImplStored)): + if any(len(axis.tags_of_type(DiscretizationEntityAxisTag)) != 1 + for axis in expr.axes): + raise ValueError(f"{expr!r} doesn't have all its axes inferred.") + + if (isinstance(expr, pt.IndexLambda) + and any(len(redn_dim.tags_of_type(DiscretizationEntityAxisTag)) != 1 + for redn_dim in expr.reduction_dims.values())): + raise ValueError(f"{expr!r} doesn't have all its redn axes inferred.") + + if (isinstance(expr, pt.Einsum) + and any(len(redn_dim.tags_of_type(DiscretizationEntityAxisTag)) != 1 + for redn_dim in expr.redn_descr_to_redn_dim.values())): + raise ValueError(f"{expr!r} doesn't have all its redn axes inferred.") + + if isinstance(expr, pt.DictOfNamedArrays): + if any(any(len(axis.tags_of_type(DiscretizationEntityAxisTag)) != 1 + for axis in subexpr.axes) + for subexpr in expr._data.values()): + raise ValueError(f"{expr!r} doesn't have all its axes inferred.") + + from pytato.loopy import LoopyCall + + if isinstance(expr, LoopyCall): + if any(any(len(axis.tags_of_type(DiscretizationEntityAxisTag)) != 
1 + for axis in subexpr.axes) + for subexpr in expr.bindings.values() + if (isinstance(subexpr, pt.Array) + and not isinstance(subexpr, pt.InputArgumentBase) + and subexpr.ndim != 0)): + raise ValueError(f"{expr!r} doesn't have all its axes inferred.") + + +def are_all_stored_arrays_inferred(expr: ArrayOrNames): + UnInferredStoredArrayCatcher()(expr) + # vim: fdm=marker From d213ff61254648cf132c493617a04f1e96c02f81 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sat, 12 Mar 2022 13:05:08 -0600 Subject: [PATCH 4/6] Implements FusionContractorArrayContext * Performs Loop Fusion * Performs Array contraction * Splits kernels at the granularity of fused einsums * Transforms those einsums using recorded values from github.com/kaushikcfd/feinsum --- meshmode/array_context.py | 991 +++++++++++++++++++++++++++++++++++++- 1 file changed, 990 insertions(+), 1 deletion(-) diff --git a/meshmode/array_context.py b/meshmode/array_context.py index fb2386fd9..4d0d224ff 100644 --- a/meshmode/array_context.py +++ b/meshmode/array_context.py @@ -27,9 +27,10 @@ import sys import logging +import numpy as np from warnings import warn -from typing import Mapping, Sequence, Union +from typing import Mapping, Sequence, Union, FrozenSet, Tuple, Any from arraycontext import PyOpenCLArrayContext as PyOpenCLArrayContextBase from arraycontext import PytatoPyOpenCLArrayContext as PytatoPyOpenCLArrayContextBase from arraycontext.pytest import ( @@ -39,6 +40,20 @@ from pytools.tag import Tag from loopy.translation_unit import for_each_kernel +from loopy.tools import memoize_on_disk +from pytools import ProcessLogger +from pytools.tag import UniqueTag, tag_dataclass + +from meshmode.transform_metadata import (DiscretizationElementAxisTag, + DiscretizationDOFAxisTag, + DiscretizationFaceAxisTag, + DiscretizationDimAxisTag, + DiscretizationPhysicalDimAxisTag, + DiscretizationRefDimAxisTag, + DiscretizationMeshNodesAxisTag, + DiscretizationEntityAxisTag) + +from pyrsistent import pmap logger = 
logging.getLogger(__name__) @@ -648,4 +663,978 @@ def untag_loopy_call_results(expr): return dag + +def get_temps_not_to_contract(knl): + from functools import reduce + wmap = knl.writer_map() + rmap = knl.reader_map() + + temps_not_to_contract = set() + for tv in knl.temporary_variables: + if len(wmap[tv]) == 1: + writer_id, = wmap[tv] + writer_loop_nest = knl.id_to_insn[writer_id].within_inames + insns_in_writer_loop_nest = reduce(frozenset.union, + (knl.iname_to_insns()[iname] + for iname in writer_loop_nest), + frozenset()) + if ( + (not (rmap.get(tv, frozenset()) + <= insns_in_writer_loop_nest)) + or len(knl.id_to_insn[writer_id].reduction_inames()) != 0 + or any((len(knl.id_to_insn[reader_id].reduction_inames()) != 0) + for reader_id in rmap.get(tv, frozenset()))): + temps_not_to_contract.add(tv) + else: + temps_not_to_contract.add(tv) + return temps_not_to_contract + + # Better way to query it... + # import loopy as lp + # from kanren.constraints import neq as kanren_neq + # + # tempo = lp.relations.get_tempo(knl) + # producero = lp.relations.get_producero(knl) + # consumero = lp.relations.get_consumero(knl) + # withino = lp.relations.get_withino(knl) + # reduce_insno = lp.relations.get_reduce_insno(knl) + # + # # temp_k: temporary variable that cannot be contracted + # temp_k = kanren.var() + # producer_insn_k = kanren.var() + # producer_loops_k = kanren.var() + # consumer_insn_k = kanren.var() + # consumer_loops_k = kanren.var() + + # temps_not_to_contract = kanren.run(0, + # temp_k, + # tempo(temp_k), + # producero(producer_insn_k, + # temp_k), + # consumero(consumer_insn_k, + # temp_k), + # withino(producer_insn_k, + # producer_loops_k), + # withino(consumer_insn_k, + # consumer_loops_k), + # kanren.lany( + # kanren_neq( + # producer_loops_k, + # consumer_loops_k), + # reduce_insno(consumer_insn_k)), + # results_filter=frozenset) + # return temps_not_to_contract + + +def _is_iel_loop_part_of_global_dof_loops(iel: str, knl) -> bool: + insn, = 
knl.iname_to_insns()[iel] + return any(iname + for iname in knl.id_to_insn[insn].within_inames + if knl.iname_tags_of_type(iname, DiscretizationDOFAxisTag)) + + +def _discr_entity_sort_key(discr_tag: DiscretizationEntityAxisTag + ) -> Tuple[Any, ...]: + from dataclasses import fields + key = [type(discr_tag).__name__] + + for field in fields(discr_tag): + key.append(getattr(discr_tag, field.name)) + + return tuple(key) + + +def _fuse_loops_over_a_discr_entity(knl, + mesh_entity, + fused_loop_prefix, + should_fuse_redn_loops, + orig_knl): + import loopy as lp + import kanren + from functools import reduce + taggedo = lp.relations.get_taggedo_of_type(orig_knl, mesh_entity) + + redn_loops = reduce(frozenset.union, + (insn.reduction_inames() + for insn in orig_knl.instructions), + frozenset()) + + # tag_k: tag of type 'mesh_entity' + tag_k = kanren.var() + tags = kanren.run(0, + tag_k, + taggedo(kanren.var(), tag_k), + results_filter=frozenset) + for itag, tag in enumerate( + sorted(tags, key=lambda x: _discr_entity_sort_key(x))): + # iname_k: iname tagged with 'tag' + iname_k = kanren.var() + inames = kanren.run(0, + iname_k, + taggedo(iname_k, tag), + results_filter=frozenset) + inames = frozenset(inames) + if should_fuse_redn_loops: + inames = inames & redn_loops + else: + inames = inames - redn_loops + + length_to_inames = {} + for iname in inames: + length = knl.get_constant_iname_length(iname) + length_to_inames.setdefault(length, set()).add(iname) + + for i, (_, inames_to_fuse) in enumerate( + sorted(length_to_inames.items())): + knl = lp.rename_inames_in_batch( + knl, + lp.get_kennedy_unweighted_fusion_candidates( + knl, inames_to_fuse, prefix=f"{fused_loop_prefix}_{itag}_{i}_")) + knl = lp.tag_inames(knl, {f"{fused_loop_prefix}_{itag}_*": tag}) + + return knl + + +@memoize_on_disk +def fuse_same_discretization_entity_loops(knl): + # maintain an 'orig_knl' to keep the original iname and tags before + # transforming it. 
+ orig_knl = knl + + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationFaceAxisTag, + "iface", + False, + orig_knl) + + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationElementAxisTag, + "iel", + False, + orig_knl) + + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationDOFAxisTag, + "idof", + False, + orig_knl) + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationDimAxisTag, + "idim", + False, + orig_knl) + + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationFaceAxisTag, + "iface", + True, + orig_knl) + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationDOFAxisTag, + "idof", + True, + orig_knl) + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationDimAxisTag, + "idim", + True, + orig_knl) + + return knl + + +@memoize_on_disk +def contract_arrays(knl, callables_table): + import loopy as lp + from loopy.transform.precompute import precompute_for_single_kernel + + temps_not_to_contract = get_temps_not_to_contract(knl) + all_temps = frozenset(knl.temporary_variables) + + logger.info("Array Contraction: Contracting " + f"{len(all_temps-frozenset(temps_not_to_contract))} temps") + + wmap = knl.writer_map() + + for temp in sorted(all_temps - frozenset(temps_not_to_contract)): + writer_id, = wmap[temp] + rmap = knl.reader_map() + ensm_tag, = knl.id_to_insn[writer_id].tags_of_type(EinsumTag) + + knl = lp.assignment_to_subst(knl, temp, + remove_newly_unused_inames=False) + if temp not in rmap: + # no one was reading 'temp' i.e. 
dead code got eliminated :) + assert f"{temp}_subst" not in knl.substitutions + continue + knl = precompute_for_single_kernel( + knl, callables_table, f"{temp}_subst", + sweep_inames=(), + temporary_address_space=lp.AddressSpace.PRIVATE, + compute_insn_id=f"_mm_contract_{temp}", + ) + + knl = lp.map_instructions(knl, + f"id:_mm_contract_{temp}", + lambda x: x.tagged(ensm_tag)) + + return lp.remove_unused_inames(knl) + + +def _get_group_size_for_dof_array_loop(nunit_dofs): + """ + Returns the OpenCL workgroup size for a loop iterating over the global DOFs + of a discretization with *nunit_dofs* per cell. + """ + if nunit_dofs == {6}: + return 16, 6 + elif nunit_dofs == {10}: + return 16, 10 + elif nunit_dofs == {20}: + return 16, 10 + elif nunit_dofs == {1}: + return 32, 1 + elif nunit_dofs == {2}: + return 32, 2 + elif nunit_dofs == {4}: + return 16, 4 + elif nunit_dofs == {3}: + return 32, 3 + elif nunit_dofs == {35}: + return 9, 7 + elif nunit_dofs == {15}: + return 8, 8 + else: + raise NotImplementedError(nunit_dofs) + + +def _get_iel_to_idofs(kernel): + iel_inames = {iname + for iname in kernel.all_inames() + if (kernel + .inames[iname] + .tags_of_type((DiscretizationElementAxisTag, + DiscretizationMeshNodesAxisTag))) + } + idof_inames = {iname + for iname in kernel.all_inames() + if (kernel + .inames[iname] + .tags_of_type(DiscretizationDOFAxisTag)) + } + iface_inames = {iname + for iname in kernel.all_inames() + if (kernel + .inames[iname] + .tags_of_type(DiscretizationFaceAxisTag)) + } + idim_inames = {iname + for iname in kernel.all_inames() + if (kernel + .inames[iname] + .tags_of_type(DiscretizationDimAxisTag)) + } + + iel_to_idofs = {iel: set() for iel in iel_inames} + + for insn in kernel.instructions: + if (len(insn.within_inames) == 1 + and (insn.within_inames) <= iel_inames): + iel, = insn.within_inames + if all(kernel.id_to_insn[el_insn].within_inames == insn.within_inames + for el_insn in kernel.iname_to_insns()[iel]): + # the iel here doesn't 
interfere with any idof i.e. we + # support parallelizing such loops. + pass + else: + raise NotImplementedError(f"The loop {insn.within_inames}" + " does not appear as a singly nested" + " loop.") + elif ((len(insn.within_inames) == 2) + and (len(insn.within_inames & iel_inames) == 1) + and (len(insn.within_inames & idof_inames) == 1)): + iel, = insn.within_inames & iel_inames + idof, = insn.within_inames & idof_inames + iel_to_idofs[iel].add(idof) + if all((iel in kernel.id_to_insn[dof_insn].within_inames) + for dof_insn in kernel.iname_to_insns()[idof]): + pass + else: + raise NotImplementedError("The loop " + f"'{insn.within_inames}' has the idof-loop" + " that's not nested within the iel-loop.") + elif ((len(insn.within_inames) > 2) + and (len(insn.within_inames & iel_inames) == 1) + and (len(insn.within_inames & idof_inames) == 1) + and (len(insn.within_inames & (idim_inames | iface_inames)) + == (len(insn.within_inames) - 2))): + iel, = insn.within_inames & iel_inames + idof, = insn.within_inames & idof_inames + iel_to_idofs[iel].add(idof) + if all((all({iel, idof} <= kernel.id_to_insn[non_iel_insn].within_inames + for non_iel_insn in kernel.iname_to_insns()[non_iel_iname])) + for non_iel_iname in insn.within_inames - {iel}): + iel_to_idofs[iel].add(idof) + else: + raise NotImplementedError("Could not fit into " + " loop nest pattern.") + else: + raise NotImplementedError(f"Cannot fit loop nest '{insn.within_inames}'" + " into known set of loop-nest patterns.") + + return pmap({iel: frozenset(idofs) + for iel, idofs in iel_to_idofs.items()}) + + +def _get_iel_loop_from_insn(insn, knl): + iel, = {iname + for iname in insn.within_inames + if knl.inames[iname].tags_of_type((DiscretizationElementAxisTag, + DiscretizationMeshNodesAxisTag))} + return iel + + +def _get_element_loop_topo_sorted_order(knl): + dag = {iel: set() + for iel in knl.all_inames() + if knl.inames[iel].tags_of_type(DiscretizationElementAxisTag)} + + for insn in knl.instructions: + succ_iel = 
_get_iel_loop_from_insn(insn, knl) + for dep_id in insn.depends_on: + pred_iel = _get_iel_loop_from_insn(knl.id_to_insn[dep_id], knl) + if pred_iel != succ_iel: + dag[pred_iel].add(succ_iel) + + from pytools.graph import compute_topological_order + return compute_topological_order(dag, key=lambda x: x) + + +@tag_dataclass +class EinsumTag(UniqueTag): + orig_loop_nest: FrozenSet[str] + + +def _prepare_kernel_for_parallelization(kernel): + discr_tag_to_prefix = {DiscretizationElementAxisTag: "iel", + DiscretizationDOFAxisTag: "idof", + DiscretizationDimAxisTag: "idim", + DiscretizationPhysicalDimAxisTag: "idim", + DiscretizationRefDimAxisTag: "idim", + DiscretizationMeshNodesAxisTag: "imsh_nodes", + DiscretizationFaceAxisTag: "iface"} + import loopy as lp + from loopy.match import ObjTagged + + # A mapping from inames that the instruction accesss to + # the instructions ids within that iname. + ensm_buckets = {} + vng = kernel.get_var_name_generator() + + for insn in kernel.instructions: + inames = insn.within_inames | insn.reduction_inames() + ensm_buckets.setdefault(tuple(sorted(inames)), set()).add(insn.id) + + # FIXME: Dependency violation is a big concern here + # Waiting on the loopy feature: https://github.com/inducer/loopy/issues/550 + + for ieinsm, (loop_nest, insns) in enumerate(sorted(ensm_buckets.items())): + new_insns = [insn.tagged(EinsumTag(frozenset(loop_nest))) + if insn.id in insns + else insn + for insn in kernel.instructions] + kernel = kernel.copy(instructions=new_insns) + + new_inames = [] + for iname in loop_nest: + discr_tag, = kernel.iname_tags_of_type(iname, + DiscretizationEntityAxisTag) + new_iname = vng(f"{discr_tag_to_prefix[type(discr_tag)]}_ensm{ieinsm}") + new_inames.append(new_iname) + + kernel = lp.duplicate_inames( + kernel, + loop_nest, + within=ObjTagged(EinsumTag(frozenset(loop_nest))), + new_inames=new_inames, + tags=kernel.iname_to_tags) + + return kernel + + +def _get_elementwise_einsum(t_unit, einsum_tag): + import loopy as 
lp + import feinsum as fnsm + from loopy.match import ObjTagged + from pymbolic.primitives import Variable, Subscript + + kernel = t_unit.default_entrypoint + + assert isinstance(einsum_tag, EinsumTag) + insn_match = ObjTagged(einsum_tag) + + global_vars = ({tv.name + for tv in kernel.temporary_variables.values() + if tv.address_space == lp.AddressSpace.GLOBAL} + | set(kernel.arg_dict.keys())) + insns = [insn + for insn in kernel.instructions + if insn_match(kernel, insn)] + idx_tuples = set() + + for insn in insns: + assert len(insn.assignees) == 1 + if isinstance(insn.assignee, Variable): + if insn.assignee.name in global_vars: + raise NotImplementedError(insn) + else: + assert (kernel.temporary_variables[insn.assignee.name].address_space + == lp.AddressSpace.PRIVATE) + elif isinstance(insn.assignee, Subscript): + assert insn.assignee_name in global_vars + idx_tuples.add(tuple(idx.name + for idx in insn.assignee.index_tuple)) + else: + raise NotImplementedError(insn) + + if len(idx_tuples) != 1: + raise NotImplementedError("Multiple einsums in the same loop nest =>" + " not allowed.") + idx_tuple, = idx_tuples + subscript = "{lhs}, {lhs}->{lhs}".format( + lhs="".join(chr(97+i) + for i in range(len(idx_tuple)))) + arg_shape = tuple(np.inf + if kernel.iname_tags_of_type(idx, DiscretizationElementAxisTag) + else kernel.get_constant_iname_length(idx) + for idx in idx_tuple) + return fnsm.einsum(subscript, + fnsm.array(arg_shape, "float64"), + fnsm.array(arg_shape, "float64")) + + +def _combine_einsum_domains(knl): + import islpy as isl + from functools import reduce + + new_domains = [] + einsum_tags = reduce( + frozenset.union, + (insn.tags_of_type(EinsumTag) + for insn in knl.instructions), + frozenset()) + + for tag in sorted(einsum_tags, + key=lambda x: sorted(x.orig_loop_nest)): + insns = [insn + for insn in knl.instructions + if tag in insn.tags] + inames = reduce(frozenset.union, + ((insn.within_inames | insn.reduction_inames()) + for insn in insns), + 
frozenset()) + domain = knl.get_inames_domain(frozenset(inames)) + new_domains.append(domain.project_out_except(sorted(inames), + [isl.dim_type.set])) + + return knl.copy(domains=new_domains) + + +class FusionContractorArrayContext( + SingleGridWorkBalancingPytatoArrayContext): + + def transform_dag(self, dag): + import pytato as pt + + # {{{ CSE + + with ProcessLogger(logger, "transform_dag.mpms_materialization"): + dag = pt.transform.materialize_with_mpms(dag) + + def mark_materialized_nodes_as_cse( + ary: Union[pt.Array, + pt.AbstractResultWithNamedArrays]) -> pt.Array: + if isinstance(ary, pt.AbstractResultWithNamedArrays): + return ary + + if ary.tags_of_type(pt.tags.ImplStored): + return ary.tagged(pt.tags.PrefixNamed("cse")) + else: + return ary + + with ProcessLogger(logger, "transform_dag.naming_cse"): + dag = pt.transform.map_and_copy(dag, mark_materialized_nodes_as_cse) + + # }}} + + # {{{ indirect addressing are non-negative + + indirection_maps = set() + + class _IndirectionMapRecorder(pt.transform.CachedWalkMapper): + def post_visit(self, expr): + if isinstance(expr, pt.IndexBase): + for idx in expr.indices: + if isinstance(idx, pt.Array): + indirection_maps.add(idx) + + _IndirectionMapRecorder()(dag) + + def tag_indices_as_non_negative(ary): + if ary in indirection_maps: + return ary.tagged(pt.tags.AssumeNonNegative()) + else: + return ary + + with ProcessLogger(logger, "transform_dag.tag_indices_as_non_negative"): + dag = pt.transform.map_and_copy(dag, tag_indices_as_non_negative) + + # }}} + + dag = pt.transform.deduplicate_data_wrappers(dag) + + # {{{ get rid of copies for different views of a cl-array + + def eliminate_reshapes_of_data_wrappers(ary): + if (isinstance(ary, pt.Reshape) + and isinstance(ary.array, pt.DataWrapper)): + return pt.make_data_wrapper(ary.array.data.reshape(ary.shape), + tags=ary.tags, + axes=ary.axes) + else: + return ary + + dag = pt.transform.map_and_copy(dag, + eliminate_reshapes_of_data_wrappers) + + # }}} + + # {{{ 
face_mass: materialize einsum args + + def materialize_face_mass_input_and_output(expr): + if (isinstance(expr, pt.Einsum) + and pt.analysis.is_einsum_similar_to_subscript( + expr, + "ifj,fej,fej->ei")): + mat, jac, vec = expr.args + return (pt.einsum("ifj,fej,fej->ei", + mat, + jac, + vec.tagged(pt.tags.ImplStored())) + .tagged((pt.tags.ImplStored(), + pt.tags.PrefixNamed("face_mass")))) + else: + return expr + + with ProcessLogger(logger, + "transform_dag.materialize_face_mass_ins_and_outs"): + dag = pt.transform.map_and_copy(dag, + materialize_face_mass_input_and_output) + + # }}} + + # {{{ materialize inverse mass inputs + + def materialize_inverse_mass_inputs(expr): + if (isinstance(expr, pt.Einsum) + and pt.analysis.is_einsum_similar_to_subscript( + expr, + "ei,ij,ej->ei")): + arg1, arg2, arg3 = expr.args + if not arg3.tags_of_type(pt.tags.PrefixNamed): + arg3 = arg3.tagged(pt.tags.PrefixNamed("mass_inv_inp")) + if not arg3.tags_of_type(pt.tags.ImplStored): + arg3 = arg3.tagged(pt.tags.ImplStored()) + + return pt.Einsum(expr.access_descriptors, + (arg1, arg2, arg3), + expr.axes, + expr.redn_descr_to_redn_dim, + expr.index_to_access_descr, + expr.tags) + else: + return expr + + dag = pt.transform.map_and_copy(dag, materialize_inverse_mass_inputs) + + # }}} + + # {{{ materialize all einsums + + def materialize_all_einsums_or_reduces(expr): + from pytato.raising import (index_lambda_to_high_level_op, + ReduceOp) + + if isinstance(expr, pt.Einsum): + return expr.tagged(pt.tags.ImplStored()) + elif (isinstance(expr, pt.IndexLambda) + and isinstance(index_lambda_to_high_level_op(expr), ReduceOp)): + return expr.tagged(pt.tags.ImplStored()) + else: + return expr + + with ProcessLogger(logger, + "transform_dag.materialize_all_einsums_or_reduces"): + dag = pt.transform.map_and_copy(dag, materialize_all_einsums_or_reduces) + + # }}} + + # {{{ infer axis types + + from meshmode.pytato_utils import unify_discretization_entity_tags + + with ProcessLogger(logger, 
"transform_dag.infer_axes_tags"): + dag = unify_discretization_entity_tags(dag) + + # }}} + + # {{{ /!\ Remove tags from Loopy call results. + # See + + def untag_loopy_call_results(expr): + from pytato.loopy import LoopyCallResult + if isinstance(expr, LoopyCallResult): + return expr.copy(tags=frozenset(), + axes=(pt.Axis(frozenset()),)*expr.ndim) + else: + return expr + + dag = pt.transform.map_and_copy(dag, untag_loopy_call_results) + + # }}} + + # {{{ remove broadcasts from einsums: help feinsum + + ensm_arg_rewrite_cache = {} + + def _get_rid_of_broadcasts_from_einsum(expr): + # Helpful for matching against the available expressions + # in feinsum. + + from pytato.utils import (are_shape_components_equal, + are_shapes_equal) + if isinstance(expr, pt.Einsum): + from pytato.array import EinsumElementwiseAxis + idx_to_len = expr._access_descr_to_axis_len() + new_access_descriptors = [] + new_args = [] + inp_gatherer = pt.transform.InputGatherer() + access_descr_to_axes = dict(expr.redn_descr_to_redn_dim) + for iax, axis in enumerate(expr.axes): + access_descr_to_axes[EinsumElementwiseAxis(iax)] = axis + + for access_descrs, arg in zip(expr.access_descriptors, + expr.args): + new_shape = [] + new_access_descrs = [] + new_axes = [] + for iaxis, (access_descr, axis_len) in enumerate( + zip(access_descrs, + arg.shape)): + if not are_shape_components_equal(axis_len, + idx_to_len[access_descr]): + assert are_shape_components_equal(axis_len, 1) + if any(isinstance(inp, pt.Placeholder) + for inp in inp_gatherer(arg)): + # do not get rid of broadcasts from parameteric + # data. 
+ new_shape.append(axis_len) + new_access_descrs.append(access_descr) + new_axes.append(arg.axes[iaxis]) + else: + new_axes.append(arg.axes[iaxis]) + new_shape.append(axis_len) + new_access_descrs.append(access_descr) + + if not are_shapes_equal(new_shape, arg.shape): + assert len(new_axes) == len(new_shape) + arg_to_freeze = (arg.reshape(new_shape) + .copy(axes=tuple( + access_descr_to_axes[acc_descr] + for acc_descr in new_access_descrs))) + + try: + new_arg = ensm_arg_rewrite_cache[arg_to_freeze] + except KeyError: + new_arg = self.thaw(self.freeze(arg_to_freeze)) + ensm_arg_rewrite_cache[arg_to_freeze] = new_arg + + arg = new_arg + + assert arg.ndim == len(new_access_descrs) + new_args.append(arg) + new_access_descriptors.append(tuple(new_access_descrs)) + + return pt.Einsum(tuple(new_access_descriptors), + tuple(new_args), + tags=expr.tags, + axes=expr.axes, + redn_descr_to_redn_dim=expr.redn_descr_to_redn_dim, + index_to_access_descr=expr.index_to_access_descr) + else: + return expr + + dag = pt.transform.map_and_copy(dag, _get_rid_of_broadcasts_from_einsum) + + # }}} + + # {{{ get rid of 0-long arrays + + def replace_zero_size_arrays_with_zeros(expr): + if isinstance(expr, pt.Array) and expr.size == 0: + from pytato.array import _get_default_axes + return pt.IndexLambda(0, expr.shape, expr.dtype, + {}, + _get_default_axes(expr.ndim), + pmap(), + frozenset()) + else: + return expr + + dag = pt.transform.map_and_copy(dag, + replace_zero_size_arrays_with_zeros) + + # }}} + + # {{{ remove any PartID tags + + from pytato.distributed import PartIDTag + + def remove_part_id_tags(expr): + if isinstance(expr, pt.Array) and expr.tags_of_type(PartIDTag): + tag, = expr.tags_of_type(PartIDTag) + return expr.without_tags(tag) + else: + return expr + + dag = pt.transform.map_and_copy(dag, remove_part_id_tags) + + # }}} + + # {{{ untag outputs tagged from being tagged ImplStored + + def _untag_impl_stored(expr): + if isinstance(expr, pt.InputArgumentBase): + return expr + 
else: + return expr.without_tags(pt.tags.ImplStored(), + verify_existence=False) + + dag = pt.make_dict_of_named_arrays({ + name: _untag_impl_stored(named_ary.expr) + for name, named_ary in dag.items()}) + + # }}} + + return dag + + def transform_loopy_program(self, t_unit): + import loopy as lp + from functools import reduce + from arraycontext.impl.pytato.compile import FromArrayContextCompile + + original_t_unit = t_unit + + # from loopy.transform.instruction import simplify_indices + # t_unit = simplify_indices(t_unit) + + knl = t_unit.default_entrypoint + + # {{{ fallback: if the inames are not inferred which mesh entity they + # iterate over. + + for iname in knl.all_inames(): + if not knl.iname_tags_of_type(iname, DiscretizationEntityAxisTag): + warn("Falling back to a slower transformation strategy as some" + " loops are uninferred which mesh entity they belong to.", + stacklevel=2) + + return super().transform_loopy_program(original_t_unit) + + # }}} + + # {{{ hardcode offset to 0 (sorry humanity) + + knl = knl.copy(args=[arg.copy(offset=0) + for arg in knl.args]) + + # }}} + + # {{{ loop fusion + + with ProcessLogger(logger, "Loop Fusion"): + knl = fuse_same_discretization_entity_loops(knl) + + # }}} + + # {{{ align kernels for fused einsums + + knl = _prepare_kernel_for_parallelization(knl) + knl = _combine_einsum_domains(knl) + + # }}} + + # {{{ array contraction + + with ProcessLogger(logger, "Array Contraction"): + knl = contract_arrays(knl, t_unit.callables_table) + + # }}} + + # {{{ Stats Collection (Disabled) + + if 0: + with ProcessLogger(logger, "Counting Kernel Ops"): + from loopy.kernel.array import ArrayBase + from pytools import product + knl = knl.copy( + silenced_warnings=(knl.silenced_warnings + + ["insn_count_subgroups_upper_bound", + "summing_if_branches_ops"])) + + t_unit = t_unit.with_kernel(knl) + + op_map = lp.get_op_map(t_unit, subgroup_size=32) + + c64_ops = {op_type: (op_map.filter_by(dtype=[np.complex64], + name=op_type, + 
kernel_name=knl.name) + .eval_and_sum({})) + for op_type in ["add", "mul", "div"]} + c128_ops = {op_type: (op_map.filter_by(dtype=[np.complex128], + name=op_type, + kernel_name=knl.name) + .eval_and_sum({})) + for op_type in ["add", "mul", "div"]} + f32_ops = ((op_map.filter_by(dtype=[np.float32], + kernel_name=knl.name) + .eval_and_sum({})) + + (2 * c64_ops["add"] + + 6 * c64_ops["mul"] + + (6 + 3 + 2) * c64_ops["div"])) + f64_ops = ((op_map.filter_by(dtype=[np.float64], + kernel_name="_pt_kernel") + .eval_and_sum({})) + + (2 * c128_ops["add"] + + 6 * c128_ops["mul"] + + (6 + 3 + 2) * c128_ops["div"])) + + # {{{ footprint gathering + + nfootprint_bytes = 0 + + for ary in knl.args: + if (isinstance(ary, ArrayBase) + and ary.address_space == lp.AddressSpace.GLOBAL): + nfootprint_bytes += (product(ary.shape) + * ary.dtype.itemsize) + + for ary in knl.temporary_variables.values(): + if ary.address_space == lp.AddressSpace.GLOBAL: + # global temps would be written once and read once + nfootprint_bytes += (2 * product(ary.shape) + * ary.dtype.itemsize) + + # }}} + + if f32_ops: + logger.info(f"Single-prec. GFlOps: {f32_ops * 1e-9}") + if f64_ops: + logger.info(f"Double-prec. GFlOps: {f64_ops * 1e-9}") + logger.info(f"Footprint GBs: {nfootprint_bytes * 1e-9}") + + # }}} + + # {{{ check whether we can parallelize the kernel + + try: + iel_to_idofs = _get_iel_to_idofs(knl) + except NotImplementedError as err: + if knl.tags_of_type(FromArrayContextCompile): + raise err + else: + warn("FusionContractorArrayContext.transform_loopy_program not" + " broad enough (yet). 
Falling back to a possibly slower" + " transformation strategy.") + return super().transform_loopy_program(original_t_unit) + + # }}} + + # {{{ insert barriers between consecutive iel-loops + + toposorted_iels = _get_element_loop_topo_sorted_order(knl) + + for iel_pred, iel_succ in zip(toposorted_iels[:-1], + toposorted_iels[1:]): + knl = lp.add_barrier(knl, + insn_before=f"iname:{iel_pred}", + insn_after=f"iname:{iel_succ}") + + # }}} + + # {{{ Parallelization strategy: Use feinsum + + t_unit = t_unit.with_kernel(knl) + del knl + + if False and t_unit.default_entrypoint.tags_of_type(FromArrayContextCompile): + # FIXME: Enable this branch, WIP for now and hence disabled it. + from loopy.match import ObjTagged + import feinsum as fnsm + from meshmode.feinsum_transformations import FEINSUM_TO_TRANSFORMS + + assert all(insn.tags_of_type(EinsumTag) + for insn in t_unit.default_entrypoint.instructions + if isinstance(insn, lp.MultiAssignmentBase) + ) + + einsum_tags = reduce( + frozenset.union, + (insn.tags_of_type(EinsumTag) + for insn in t_unit.default_entrypoint.instructions), + frozenset()) + for ensm_tag in sorted(einsum_tags, + key=lambda x: sorted(x.orig_loop_nest)): + if reduce(frozenset.union, + (insn.reduction_inames() + for insn in (t_unit.default_entrypoint.instructions) + if ensm_tag in insn.tags), + frozenset()): + fused_einsum = fnsm.match_einsum(t_unit, ObjTagged(ensm_tag)) + else: + # elementwise loop + fused_einsum = _get_elementwise_einsum(t_unit, ensm_tag) + + try: + fnsm_transform = FEINSUM_TO_TRANSFORMS[ + fnsm.normalize_einsum(fused_einsum)] + except KeyError: + fnsm.query(fused_einsum, + self.queue.context, + err_if_no_results=True) + 1/0 + + t_unit = fnsm_transform(t_unit, + insn_match=ObjTagged(ensm_tag)) + else: + knl = t_unit.default_entrypoint + for iel, idofs in sorted(iel_to_idofs.items()): + if idofs: + nunit_dofs = {knl.get_constant_iname_length(idof) + for idof in idofs} + idof, = idofs + + l_one_size, l_zero_size = 
_get_group_size_for_dof_array_loop( + nunit_dofs) + + knl = lp.split_iname(knl, iel, l_one_size, + inner_tag="l.1", outer_tag="g.0") + knl = lp.split_iname(knl, idof, l_zero_size, + inner_tag="l.0", outer_tag="unr") + else: + knl = lp.split_iname(knl, iel, 32, + outer_tag="g.0", inner_tag="l.0") + + t_unit = t_unit.with_kernel(knl) + + # }}} + + t_unit = lp.linearize(lp.preprocess_kernel(t_unit)) + t_unit = _alias_global_temporaries(t_unit) + + return t_unit + # vim: foldmethod=marker From 1467d5a917f11a657117bb1c711580201c446525 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sat, 12 Mar 2022 13:25:57 -0600 Subject: [PATCH 5/6] adds feinsum, kanren to deps --- requirements.txt | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 35a924099..e99abecc1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ git+https://github.com/inducer/pyvisfile.git#egg=pyvisfile git+https://github.com/inducer/modepy.git#egg=modepy git+https://github.com/inducer/pyopencl.git#egg=pyopencl git+https://github.com/inducer/islpy.git#egg=islpy -git+https://github.com/inducer/pytato.git#egg=pytato +git+https://github.com/kaushikcfd/pytato.git#egg=pytato # required by pytential, which is in turn needed for some tests git+https://github.com/inducer/pymbolic.git#egg=pymbolic @@ -27,3 +27,8 @@ git+https://github.com/inducer/pymetis.git#egg=pymetis # for examples/tp-lagrange-stl.py numpy-stl + + +# for FusionContractorActx transforms +git+https://github.com/kaushikcfd/feinsum.git#egg=feinsum +git+https://github.com/pythological/kanren.git#egg=miniKanren From 23200ad1635490ea50bab9753fcf147be636ca01 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 22 Apr 2022 13:01:05 -0500 Subject: [PATCH 6/6] time dedup call --- meshmode/array_context.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/meshmode/array_context.py b/meshmode/array_context.py index 4d0d224ff..6f9e0a595 100644 --- 
a/meshmode/array_context.py +++ b/meshmode/array_context.py @@ -1184,7 +1184,8 @@ def tag_indices_as_non_negative(ary): # }}} - dag = pt.transform.deduplicate_data_wrappers(dag) + with ProcessLogger(logger, "transform_dag.deduplicate_data_wrappers"): + dag = pt.transform.deduplicate_data_wrappers(dag) # {{{ get rid of copies for different views of a cl-array