From 61a6cac5f42b1024fb601c34a9cec5fe04e927d6 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sat, 12 Mar 2022 13:00:21 -0600 Subject: [PATCH 1/7] CHERRY-PICK: SingleGridPytatoArrayContext Co-authored-by: Matthew Smith --- examples/simple-dg.py | 30 ++-- meshmode/array_context.py | 283 ++++++++++++++++++++++++++++++++++++++ meshmode/pytato_utils.py | 62 +++++++++ requirements.txt | 2 +- 4 files changed, 359 insertions(+), 18 deletions(-) create mode 100644 meshmode/pytato_utils.py diff --git a/examples/simple-dg.py b/examples/simple-dg.py index 964e0a12d..84bea31a4 100644 --- a/examples/simple-dg.py +++ b/examples/simple-dg.py @@ -33,7 +33,7 @@ from meshmode.mesh import BTAG_ALL, BTAG_NONE # noqa from meshmode.dof_array import DOFArray, flat_norm from meshmode.array_context import (PyOpenCLArrayContext, - PytatoPyOpenCLArrayContext) + SingleGridWorkBalancingPytatoArrayContext as PytatoPyOpenCLArrayContext) from arraycontext import ( ArrayContainer, map_array_container, @@ -455,11 +455,10 @@ def main(lazy=False): cl_ctx = cl.create_some_context() queue = cl.CommandQueue(cl_ctx) - actx_outer = PyOpenCLArrayContext(queue, force_device_scalars=True) if lazy: - actx_rhs = PytatoPyOpenCLArrayContext(queue) + actx = PytatoPyOpenCLArrayContext(queue) else: - actx_rhs = actx_outer + actx = PyOpenCLArrayContext(queue, force_device_scalars=True) nel_1d = 16 from meshmode.mesh.generation import generate_regular_rect_mesh @@ -475,37 +474,34 @@ def main(lazy=False): logger.info("%d elements", mesh.nelements) - discr = DGDiscretization(actx_outer, mesh, order=order) + discr = DGDiscretization(actx, mesh, order=order) fields = WaveState( - u=bump(actx_outer, discr), - v=make_obj_array([discr.zeros(actx_outer) for i in range(discr.dim)]), + u=bump(actx, discr), + v=make_obj_array([discr.zeros(actx) for i in range(discr.dim)]), ) from meshmode.discretization.visualization import make_visualizer - vis = make_visualizer(actx_outer, discr.volume_discr) + vis = 
make_visualizer(actx, discr.volume_discr) def rhs(t, q): - return wave_operator(actx_rhs, discr, c=1, q=q) + return wave_operator(actx, discr, c=1, q=q) - compiled_rhs = actx_rhs.compile(rhs) - - def rhs_wrapper(t, q): - r = compiled_rhs(t, actx_rhs.thaw(actx_outer.freeze(q))) - return actx_outer.thaw(actx_rhs.freeze(r)) + compiled_rhs = actx.compile(rhs) t = np.float64(0) t_final = 3 istep = 0 while t < t_final: - fields = rk4_step(fields, t, dt, rhs_wrapper) + fields = actx.thaw(actx.freeze(fields,)) + fields = rk4_step(fields, t, dt, compiled_rhs) if istep % 10 == 0: # FIXME: Maybe an integral function to go with the # DOFArray would be nice? assert len(fields.u) == 1 logger.info("[%05d] t %.5e / %.5e norm %.5e", - istep, t, t_final, actx_outer.to_numpy(flat_norm(fields.u, 2))) + istep, t, t_final, actx.to_numpy(flat_norm(fields.u, 2))) vis.write_vtk_file("fld-wave-min-%04d.vtu" % istep, [ ("q", fields), ]) @@ -513,7 +509,7 @@ def rhs_wrapper(t, q): t += dt istep += 1 - assert flat_norm(fields.u, 2) < 100 + assert actx.to_numpy(flat_norm(fields.u, 2)) < 100 if __name__ == "__main__": diff --git a/meshmode/array_context.py b/meshmode/array_context.py index 259ceabc3..071071dc4 100644 --- a/meshmode/array_context.py +++ b/meshmode/array_context.py @@ -26,6 +26,8 @@ """ import sys +import logging + from warnings import warn from arraycontext import PyOpenCLArrayContext as PyOpenCLArrayContextBase from arraycontext import PytatoPyOpenCLArrayContext as PytatoPyOpenCLArrayContextBase @@ -33,6 +35,9 @@ _PytestPyOpenCLArrayContextFactoryWithClass, _PytestPytatoPyOpenCLArrayContextFactory, register_pytest_array_context_factory) +from loopy.translation_unit import for_each_kernel + +logger = logging.getLogger(__name__) def thaw(actx, ary): @@ -345,4 +350,282 @@ def _import_names(): # }}} +@for_each_kernel +def _single_grid_work_group_transform(kernel, cl_device): + import loopy as lp + from meshmode.transform_metadata import (ConcurrentElementInameTag, + 
ConcurrentDOFInameTag) + + splayed_inames = set() + ngroups = cl_device.max_compute_units * 4 # '4' to overfill the device + l_one_size = 4 + l_zero_size = 16 + + for insn in kernel.instructions: + if insn.within_inames in splayed_inames: + continue + + if isinstance(insn, lp.CallInstruction): + # must be a callable kernel, don't touch. + pass + elif isinstance(insn, lp.Assignment): + bigger_loop = None + smaller_loop = None + + if len(insn.within_inames) == 0: + continue + + if len(insn.within_inames) == 1: + iname, = insn.within_inames + + kernel = lp.split_iname(kernel, iname, + ngroups * l_zero_size * l_one_size) + kernel = lp.split_iname(kernel, f"{iname}_inner", + l_zero_size, inner_tag="l.0") + kernel = lp.split_iname(kernel, f"{iname}_inner_outer", + l_one_size, inner_tag="l.1", + outer_tag="g.0") + + splayed_inames.add(insn.within_inames) + continue + + for iname in insn.within_inames: + if kernel.iname_tags_of_type(iname, + ConcurrentElementInameTag): + assert bigger_loop is None + bigger_loop = iname + elif kernel.iname_tags_of_type(iname, + ConcurrentDOFInameTag): + assert smaller_loop is None + smaller_loop = iname + else: + pass + + if bigger_loop or smaller_loop: + assert (bigger_loop is not None + and smaller_loop is not None) + else: + sorted_inames = sorted(tuple(insn.within_inames), + key=kernel.get_constant_iname_length) + smaller_loop = sorted_inames[0] + bigger_loop = sorted_inames[-1] + + kernel = lp.split_iname(kernel, f"{bigger_loop}", + l_one_size * ngroups) + kernel = lp.split_iname(kernel, f"{bigger_loop}_inner", + l_one_size, inner_tag="l.1", outer_tag="g.0") + kernel = lp.split_iname(kernel, smaller_loop, + l_zero_size, inner_tag="l.0") + splayed_inames.add(insn.within_inames) + elif isinstance(insn, lp.BarrierInstruction): + pass + else: + raise NotImplementedError(type(insn)) + + return kernel + + +def _alias_global_temporaries(t_unit): + """ + Returns a copy of *t_unit* with temporaries of that have disjoint live + intervals using 
the same :attr:`loopy.TemporaryVariable.base_storage`. + """ + from loopy.kernel.data import AddressSpace + from loopy.kernel import KernelState + from loopy.schedule import (RunInstruction, EnterLoop, LeaveLoop, + CallKernel, ReturnFromKernel, Barrier) + from loopy.schedule.tools import get_return_from_kernel_mapping + from pytools import UniqueNameGenerator + from collections import defaultdict + + kernel = t_unit.default_entrypoint + assert kernel.state == KernelState.LINEARIZED + temp_vars = frozenset(tv.name + for tv in kernel.temporary_variables.values() + if tv.address_space == AddressSpace.GLOBAL) + temp_to_live_interval_start = {} + temp_to_live_interval_end = {} + return_from_kernel_idxs = get_return_from_kernel_mapping(kernel) + + for sched_idx, sched_item in enumerate(kernel.linearization): + if isinstance(sched_item, RunInstruction): + for var in (kernel.id_to_insn[sched_item.insn_id].dependency_names() + & temp_vars): + if var not in temp_to_live_interval_start: + assert var not in temp_to_live_interval_end + temp_to_live_interval_start[var] = sched_idx + assert var in temp_to_live_interval_start + temp_to_live_interval_end[var] = return_from_kernel_idxs[sched_idx] + elif isinstance(sched_item, (EnterLoop, LeaveLoop, CallKernel, + ReturnFromKernel, Barrier)): + # no variables are accessed within these schedule items => do + # nothing. + pass + else: + raise NotImplementedError(type(sched_item)) + + vng = UniqueNameGenerator() + # a mapping from shape to the available base storages from temp variables + # that were dead. 
+ shape_to_available_base_storage = defaultdict(set) + + sched_idx_to_just_live_temp_vars = [set() for _ in kernel.linearization] + sched_idx_to_just_dead_temp_vars = [set() for _ in kernel.linearization] + + for tv, just_alive_idx in temp_to_live_interval_start.items(): + sched_idx_to_just_live_temp_vars[just_alive_idx].add(tv) + + for tv, just_dead_idx in temp_to_live_interval_end.items(): + sched_idx_to_just_dead_temp_vars[just_dead_idx].add(tv) + + new_tvs = {} + + for sched_idx, _ in enumerate(kernel.linearization): + just_dead_temps = sched_idx_to_just_dead_temp_vars[sched_idx] + to_be_allocated_temps = sched_idx_to_just_live_temp_vars[sched_idx] + for tv_name in sorted(just_dead_temps): + tv = new_tvs[tv_name] + assert tv.base_storage is not None + assert tv.base_storage not in shape_to_available_base_storage[tv.nbytes] + shape_to_available_base_storage[tv.nbytes].add(tv.base_storage) + + for tv_name in sorted(to_be_allocated_temps): + assert len(to_be_allocated_temps) <= 1 + tv = kernel.temporary_variables[tv_name] + assert tv.name not in new_tvs + assert tv.base_storage is None + if shape_to_available_base_storage[tv.nbytes]: + base_storage = sorted(shape_to_available_base_storage[tv.nbytes])[0] + shape_to_available_base_storage[tv.nbytes].remove(base_storage) + else: + base_storage = vng("_msh_actx_tmp_base") + + new_tvs[tv.name] = tv.copy(base_storage=base_storage) + + for name, tv in kernel.temporary_variables.items(): + if tv.address_space != AddressSpace.GLOBAL: + new_tvs[name] = tv + else: + assert name in new_tvs + + kernel = kernel.copy(temporary_variables=new_tvs) + + return t_unit.with_kernel(kernel) + + +def _can_be_eagerly_computed(ary) -> bool: + from pytato.transform import InputGatherer + from pytato.array import Placeholder + return all(not isinstance(inp, Placeholder) + for inp in InputGatherer()(ary)) + + +def deduplicate_data_wrappers(dag): + import pytato as pt + data_wrapper_cache = {} + data_wrappers_encountered = 0 + + def 
cached_data_wrapper_if_present(ary): + nonlocal data_wrappers_encountered + + if isinstance(ary, pt.DataWrapper): + + data_wrappers_encountered += 1 + cache_key = (ary.data.base_data.int_ptr, ary.data.offset, + ary.shape, ary.data.strides) + try: + result = data_wrapper_cache[cache_key] + except KeyError: + result = ary + data_wrapper_cache[cache_key] = result + + return result + else: + return ary + + dag = pt.transform.map_and_copy(dag, cached_data_wrapper_if_present) + + if data_wrappers_encountered: + logger.info("data wrapper de-duplication: " + "%d encountered, %d kept, %d eliminated", + data_wrappers_encountered, + len(data_wrapper_cache), + data_wrappers_encountered - len(data_wrapper_cache)) + + return dag + + +class SingleGridWorkBalancingPytatoArrayContext(PytatoPyOpenCLArrayContextBase): + """ + A :class:`PytatoPyOpenCLArrayContext` that parallelizes work in an OpenCL + kernel so that the work + """ + def transform_loopy_program(self, t_unit): + import loopy as lp + + t_unit = _single_grid_work_group_transform(t_unit, self.queue.device) + t_unit = lp.set_options(t_unit, "insert_gbarriers") + t_unit = lp.linearize(lp.preprocess_kernel(t_unit)) + t_unit = _alias_global_temporaries(t_unit) + + return t_unit + + def _get_fake_numpy_namespace(self): + from meshmode.pytato_utils import ( + EagerReduceComputingPytatoFakeNumpyNamespace) + return EagerReduceComputingPytatoFakeNumpyNamespace(self) + + def transform_dag(self, dag): + import pytato as pt + + # {{{ face_mass: materialize einsum args + + def materialize_face_mass_vec(expr): + if (isinstance(expr, pt.Einsum) + and pt.analysis.is_einsum_similar_to_subscript( + expr, "ifj,fej,fej->ei")): + mat, jac, vec = expr.args + return pt.einsum("ifj,fej,fej->ei", + mat, + jac, + vec.tagged(pt.tags.ImplStored())) + else: + return expr + + dag = pt.transform.map_and_copy(dag, materialize_face_mass_vec) + + # }}} + + # {{{ materialize all einsums + + def materialize_einsums(ary: pt.Array) -> pt.Array: + if 
isinstance(ary, pt.Einsum): + return ary.tagged(pt.tags.ImplStored()) + + return ary + + dag = pt.transform.map_and_copy(dag, materialize_einsums) + + # }}} + + dag = pt.transform.materialize_with_mpms(dag) + dag = deduplicate_data_wrappers(dag) + + # {{{ /!\ Remove tags from Loopy call results. + # See + + def untag_loopy_call_results(expr): + from pytato.loopy import LoopyCallResult + if isinstance(expr, LoopyCallResult): + return expr.copy(tags=frozenset(), + axes=(pt.Axis(frozenset()),)*expr.ndim) + else: + return expr + + dag = pt.transform.map_and_copy(dag, untag_loopy_call_results) + + # }}} + + return dag + # vim: foldmethod=marker diff --git a/meshmode/pytato_utils.py b/meshmode/pytato_utils.py new file mode 100644 index 000000000..b0724b1df --- /dev/null +++ b/meshmode/pytato_utils.py @@ -0,0 +1,62 @@ +from functools import partial, reduce +from arraycontext.impl.pytato.fake_numpy import PytatoFakeNumpyNamespace +from arraycontext import rec_map_reduce_array_container +import pyopencl.array as cl_array + + +def _can_be_eagerly_computed(ary) -> bool: + from pytato.transform import InputGatherer + from pytato.array import Placeholder + return all(not isinstance(inp, Placeholder) + for inp in InputGatherer()(ary)) + + +class EagerReduceComputingPytatoFakeNumpyNamespace(PytatoFakeNumpyNamespace): + """ + A Numpy-namespace that computes the reductions eagerly whenever possible. 
+ """ + def sum(self, a, axis=None, dtype=None): + if (rec_map_reduce_array_container(all, + _can_be_eagerly_computed, a) + and axis is None): + + def _pt_sum(ary): + return cl_array.sum(self._array_context.freeze(ary), + dtype=dtype, + queue=self._array_context.queue) + + return self._array_context.thaw(rec_map_reduce_array_container(sum, + _pt_sum, + a)) + else: + return super().sum(a, axis=axis, dtype=dtype) + + def min(self, a, axis=None): + if (rec_map_reduce_array_container(all, + _can_be_eagerly_computed, a) + and axis is None): + queue = self._array_context.queue + frozen_result = rec_map_reduce_array_container( + partial(reduce, partial(cl_array.minimum, queue=queue)), + lambda ary: cl_array.min(self._array_context.freeze(ary), + queue=queue), + a) + return self._array_context.thaw(frozen_result) + else: + return super().min(a, axis=axis) + + def max(self, a, axis=None): + if (rec_map_reduce_array_container(all, + _can_be_eagerly_computed, a) + and axis is None): + queue = self._array_context.queue + frozen_result = rec_map_reduce_array_container( + partial(reduce, partial(cl_array.maximum, queue=queue)), + lambda ary: cl_array.max(self._array_context.freeze(ary), + queue=queue), + a) + return self._array_context.thaw(frozen_result) + else: + return super().max(a, axis=axis) + +# vim: fdm=marker diff --git a/requirements.txt b/requirements.txt index 7ad3c51b2..35a924099 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ git+https://github.com/inducer/pytato.git#egg=pytato git+https://github.com/inducer/pymbolic.git#egg=pymbolic # also depends on pymbolic, so should come after it -git+https://github.com/inducer/loopy.git#egg=loopy +git+https://github.com/kaushikcfd/loopy.git#egg=loopy # depends on loopy, so should come after it. 
git+https://github.com/inducer/arraycontext.git#egg=arraycontext From 38e5b0e3edc348db11edd692c75a04a3ac2412e6 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sat, 12 Mar 2022 13:03:32 -0600 Subject: [PATCH 2/7] Adds more pytato utils: Axes types unification --- meshmode/pytato_utils.py | 575 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 574 insertions(+), 1 deletion(-) diff --git a/meshmode/pytato_utils.py b/meshmode/pytato_utils.py index b0724b1df..986c65fff 100644 --- a/meshmode/pytato_utils.py +++ b/meshmode/pytato_utils.py @@ -1,7 +1,23 @@ +import pyopencl.array as cl_array +import kanren +import pytato as pt +import unification +import logging + from functools import partial, reduce from arraycontext.impl.pytato.fake_numpy import PytatoFakeNumpyNamespace from arraycontext import rec_map_reduce_array_container -import pyopencl.array as cl_array +from meshmode.transform_metadata import DiscretizationEntityAxisTag +from pytato.loopy import LoopyCall +from pytato.array import EinsumElementwiseAxis, EinsumReductionAxis +from pytato.transform import ArrayOrNames +from arraycontext import ArrayContainer +from arraycontext.container.traversal import rec_map_array_container +from typing import Set, Mapping, Tuple, Union +logger = logging.getLogger(__name__) + + +MAX_UNIFY_RETRIES = 50 # used by unify_discretization_entity_tags def _can_be_eagerly_computed(ary) -> bool: @@ -59,4 +75,561 @@ def max(self, a, axis=None): else: return super().max(a, axis=axis) + +# {{{ solve for discretization metadata for arrays' axes + +class DiscretizationEntityConstraintCollector(pt.transform.Mapper): + """ + .. warning:: + + Instances of this mapper type store state that are only for visiting a + single DAG. Using a single instance for collecting the constraints on + multiple DAGs is undefined behavior. 
+ """ + def __init__(self): + super().__init__() + self._visited_ids: Set[int] = set() + + # axis_to_var: mapping from (array, iaxis) to the kanren variable to be + # used for unification. + self.axis_to_tag_var: Mapping[Tuple[pt.Array, int], + unification.variable.Var] = {} + self.variables_to_solve: Set[unification.variable.Var] = set() + self.constraints = [] + + # type-ignore reason: CachedWalkMapper.rec's type does not match + # WalkMapper.rec's type + def rec(self, expr: ArrayOrNames) -> None: # type: ignore + if id(expr) in self._visited_ids: + return + + # type-ignore reason: super().rec expects either 'Array' or + # 'AbstractResultWithNamedArrays', passed 'ArrayOrNames' + super().rec(expr) # type: ignore + self._visited_ids.add(id(expr)) + + def get_kanren_var_for_axis_tag(self, + expr: pt.Array, + iaxis: int + ) -> unification.variable.Var: + key = (expr, iaxis) + + if key not in self.axis_to_tag_var: + self.axis_to_tag_var[key] = kanren.var() + + return self.axis_to_tag_var[key] + + def _record_all_axes_to_be_solved_if_impl_stored(self, expr): + if expr.tags_of_type(pt.tags.ImplStored): + for iaxis in range(expr.ndim): + self.variables_to_solve.add(self.get_kanren_var_for_axis_tag(expr, + iaxis)) + + def _record_all_axes_to_be_solved(self, expr): + for iaxis in range(expr.ndim): + self.variables_to_solve.add(self.get_kanren_var_for_axis_tag(expr, + iaxis)) + + def record_constraint(self, lhs, rhs): + self.constraints.append((lhs, rhs)) + + def record_eq_constraints_from_tags(self, expr: pt.Array) -> None: + for iaxis, axis in enumerate(expr.axes): + if axis.tags_of_type(DiscretizationEntityAxisTag): + discr_tag, = axis.tags_of_type(DiscretizationEntityAxisTag) + axis_var = self.get_kanren_var_for_axis_tag(expr, iaxis) + self.record_constraint(axis_var, discr_tag) + + def _map_input_base(self, expr: pt.InputArgumentBase + ) -> None: + self.record_eq_constraints_from_tags(expr) + self._record_all_axes_to_be_solved_if_impl_stored(expr) + + for dim in 
expr.shape: + if isinstance(dim, pt.Array): + self.rec(dim) + + map_placeholder = _map_input_base + map_data_wrapper = _map_input_base + map_size_param = _map_input_base + + def map_index_lambda(self, expr: pt.IndexLambda) -> None: + from pytato.utils import are_shape_components_equal + from pytato.raising import index_lambda_to_high_level_op + from pytato.raising import (BinaryOp, FullOp, WhereOp, + BroadcastOp, C99CallOp, ReduceOp) + + # {{{ record constraints for expr and its subexprs. + + self.record_eq_constraints_from_tags(expr) + self._record_all_axes_to_be_solved_if_impl_stored(expr) + + for dim in expr.shape: + if isinstance(dim, pt.Array): + self.rec(dim) + + for bnd in expr.bindings.values(): + self.rec(bnd) + + # }}} + + hlo = index_lambda_to_high_level_op(expr) + + if isinstance(hlo, BinaryOp): + subexprs = (hlo.x1, hlo.x2) + elif isinstance(hlo, WhereOp): + subexprs = (hlo.condition, hlo.then, hlo.else_) + elif isinstance(hlo, FullOp): + # A full-op does not impose any constraints + subexprs = () + elif isinstance(hlo, BroadcastOp): + subexprs = (hlo.x,) + elif isinstance(hlo, C99CallOp): + subexprs = hlo.args + elif isinstance(hlo, ReduceOp): + # {{{ ReduceOp doesn't quite involve broadcasting + + i_out_axis = 0 + for i_in_axis in range(hlo.x.ndim): + if i_in_axis not in hlo.axes: + in_tag_var = self.get_kanren_var_for_axis_tag(hlo.x, + i_in_axis) + out_tag_var = self.get_kanren_var_for_axis_tag(expr, + i_out_axis) + self.record_constraint(in_tag_var, out_tag_var) + i_out_axis += 1 + + assert i_out_axis == expr.ndim + + # }}} + + for axis in hlo.axes: + self.variables_to_solve.add(self.get_kanren_var_for_axis_tag(hlo.x, + axis)) + return + + else: + raise NotImplementedError(type(hlo)) + + for subexpr in subexprs: + if isinstance(subexpr, pt.Array): + for i_in_axis, i_out_axis in zip( + range(subexpr.ndim), + range(expr.ndim-subexpr.ndim, expr.ndim)): + in_dim = subexpr.shape[i_in_axis] + out_dim = expr.shape[i_out_axis] + if 
are_shape_components_equal(in_dim, out_dim): + in_tag_var = self.get_kanren_var_for_axis_tag(subexpr, + i_in_axis) + out_tag_var = self.get_kanren_var_for_axis_tag(expr, + i_out_axis) + + self.record_constraint(in_tag_var, out_tag_var) + else: + # broadcasted axes, cannot belong to the same + # discretization entity. + assert are_shape_components_equal(in_dim, 1) + + def map_stack(self, expr: pt.Stack) -> None: + self.record_eq_constraints_from_tags(expr) + self._record_all_axes_to_be_solved_if_impl_stored(expr) + # TODO; I think the axis corresponding to 'axis' need not be solved. + for ary in expr.arrays: + self.rec(ary) + + for iaxis in range(expr.ndim): + for ary in expr.arrays: + if iaxis < expr.axis: + in_tag_var = self.get_kanren_var_for_axis_tag(ary, + iaxis) + out_tag_var = self.get_kanren_var_for_axis_tag(expr, + iaxis) + + self.record_constraint(in_tag_var, out_tag_var) + elif iaxis == expr.axis: + pass + elif iaxis > expr.axis: + in_tag_var = self.get_kanren_var_for_axis_tag(ary, + iaxis-1) + out_tag_var = self.get_kanren_var_for_axis_tag(expr, + iaxis) + + self.record_constraint(in_tag_var, out_tag_var) + else: + raise AssertionError + + def map_concatenate(self, expr: pt.Concatenate) -> None: + self.record_eq_constraints_from_tags(expr) + self._record_all_axes_to_be_solved_if_impl_stored(expr) + # TODO; I think the axis corresponding to 'axis' need not be solved. + for ary in expr.arrays: + self.rec(ary) + + for ary in expr.arrays: + assert ary.ndim == expr.ndim + for iaxis in range(expr.ndim): + if iaxis != expr.axis: + # non-concatenated axes share the dimensions. 
+ in_tag_var = self.get_kanren_var_for_axis_tag(ary, + iaxis) + out_tag_var = self.get_kanren_var_for_axis_tag(expr, + iaxis) + self.record_constraint(in_tag_var, out_tag_var) + + def map_axis_permutation(self, expr: pt.AxisPermutation + ) -> None: + self.record_eq_constraints_from_tags(expr) + self._record_all_axes_to_be_solved_if_impl_stored(expr) + self.rec(expr.array) + + assert expr.ndim == expr.array.ndim + + for out_axis in range(expr.ndim): + in_axis = expr.axis_permutation[out_axis] + out_tag = self.get_kanren_var_for_axis_tag(expr, out_axis) + in_tag = self.get_kanren_var_for_axis_tag(expr.array, in_axis) + self.record_constraint(out_tag, in_tag) + + def map_basic_index(self, expr: pt.IndexBase) -> None: + from pytato.array import NormalizedSlice + from pytato.utils import are_shape_components_equal + + self.record_eq_constraints_from_tags(expr) + self._record_all_axes_to_be_solved_if_impl_stored(expr) + self.rec(expr.array) + + i_out_axis = 0 + + assert len(expr.indices) == expr.array.ndim + + for i_in_axis, idx in enumerate(expr.indices): + if isinstance(idx, int): + pass + else: + assert isinstance(idx, NormalizedSlice) + if (idx.step == 1 + and are_shape_components_equal(idx.start, 0) + and are_shape_components_equal(idx.stop, + expr.array.shape[i_in_axis])): + + i_in_axis_tag = self.get_kanren_var_for_axis_tag(expr.array, + i_in_axis) + i_out_axis_tag = self.get_kanren_var_for_axis_tag(expr, + i_out_axis) + self.record_constraint(i_in_axis_tag, i_out_axis_tag) + + i_out_axis += 1 + + def map_contiguous_advanced_index(self, + expr: pt.AdvancedIndexInContiguousAxes + ) -> None: + from pytato.array import NormalizedSlice + from pytato.utils import (partition, get_shape_after_broadcasting, + are_shapes_equal, are_shape_components_equal) + + self.record_eq_constraints_from_tags(expr) + self._record_all_axes_to_be_solved_if_impl_stored(expr) + self.rec(expr.array) + for idx in expr.indices: + if isinstance(idx, pt.Array): + self.rec(idx) + + i_adv_indices,
i_basic_indices = partition( + lambda idx: isinstance(expr.indices[idx], NormalizedSlice), + range(len(expr.indices))) + npre_advanced_basic_indices = len([i_idx + for i_idx in i_basic_indices + if i_idx < i_adv_indices[0]]) + npost_advanced_basic_indices = len([i_idx + for i_idx in i_basic_indices + if i_idx > i_adv_indices[-1]]) + + indirection_arrays = [expr.indices[i_idx] for i_idx in i_adv_indices] + assert are_shapes_equal( + get_shape_after_broadcasting(indirection_arrays), + expr.shape[ + npre_advanced_basic_indices:expr.ndim-npost_advanced_basic_indices]) + + for subexpr in indirection_arrays: + if isinstance(subexpr, pt.Array): + for i_in_axis, i_out_axis in zip( + range(subexpr.ndim), + range(expr.ndim-subexpr.ndim+npre_advanced_basic_indices, + expr.ndim-npost_advanced_basic_indices)): + in_dim = subexpr.shape[i_in_axis] + out_dim = expr.shape[i_out_axis] + if are_shape_components_equal(in_dim, out_dim): + in_tag_var = self.get_kanren_var_for_axis_tag(subexpr, + i_in_axis) + out_tag_var = self.get_kanren_var_for_axis_tag(expr, + i_out_axis) + + self.record_constraint(in_tag_var, out_tag_var) + else: + # broadcasted axes, cannot belong to the same + # discretization entity. + assert are_shape_components_equal(in_dim, 1) + + def map_non_contiguous_advanced_index(self, + expr: pt.AdvancedIndexInNoncontiguousAxes + ) -> None: + self.record_eq_constraints_from_tags(expr) + self._record_all_axes_to_be_solved_if_impl_stored(expr) + self.rec(expr.array) + for idx in expr.indices: + if isinstance(idx, pt.Array): + self.rec(idx) + + def map_reshape(self, expr: pt.Reshape) -> None: + self.record_eq_constraints_from_tags(expr) + self._record_all_axes_to_be_solved_if_impl_stored(expr) + self.rec(expr.array) + # we can add constraints to reshape that only include new axes in its + # reshape. + # Other reshapes do not 'conserve' the types in our type-system. + # Well *what if*. Let's just say this type inference fails for + # non-trivial 'reshapes'. 
So, what are the 'trivial' reshapes? + # trivial reshapes: + # (x1, x2, ... xn) -> ((1,)*, x1, (1,)*, x2, (1,)*, x3, (1,)*, ..., xn, 1*) + # given all(x1!=1, x2!=1, x3!=1, .. xn!= 1) + if ((1 not in (expr.array.shape)) # leads to ambiguous newaxis + and (set(expr.shape) <= (set(expr.array.shape) | {1}))): + i_in_axis = 0 + for i_out_axis, dim in enumerate(expr.shape): + if dim != 1: + assert dim == expr.array.shape[i_in_axis] + i_in_axis_tag = self.get_kanren_var_for_axis_tag(expr.array, + i_in_axis) + i_out_axis_tag = self.get_kanren_var_for_axis_tag(expr, + i_out_axis) + self.record_constraint(i_in_axis_tag, i_out_axis_tag) + i_in_axis += 1 + else: + # print(f"Skipping: {expr.array.shape} -> {expr.shape}") + # Wacky reshape => bail. + return + + def map_einsum(self, expr: pt.Einsum) -> None: + + self.record_eq_constraints_from_tags(expr) + self._record_all_axes_to_be_solved_if_impl_stored(expr) + + for arg in expr.args: + self.rec(arg) + + descr_to_tag = {} + for iaxis in range(expr.ndim): + descr_to_tag[EinsumElementwiseAxis(iaxis)] = ( + self.get_kanren_var_for_axis_tag(expr, iaxis)) + + for access_descrs, arg in zip(expr.access_descriptors, + expr.args): + # if an einsum is stored => every argument's axes must + # also be inferred, even those that are getting reduced. 
+ for iarg_axis, descr in enumerate(access_descrs): + in_tag_var = self.get_kanren_var_for_axis_tag(arg, + iarg_axis) + + if descr in descr_to_tag: + self.record_constraint(descr_to_tag[descr], in_tag_var) + else: + descr_to_tag[descr] = in_tag_var + + if isinstance(descr, EinsumReductionAxis): + self.variables_to_solve.add(in_tag_var) + + def map_dict_of_named_arrays(self, expr: pt.DictOfNamedArrays + ) -> None: + for _, subexpr in sorted(expr._data.items()): + self.rec(subexpr) + self._record_all_axes_to_be_solved(subexpr) + + def map_loopy_call(self, expr: LoopyCall) -> None: + for _, subexpr in sorted(expr.bindings.items()): + if isinstance(subexpr, pt.Array): + if not isinstance(subexpr, pt.InputArgumentBase): + self._record_all_axes_to_be_solved(subexpr) + self.rec(subexpr) + + # there's really no good way to propagate the metadata in this case. + # One *could* raise the loopy kernel instruction expressions to + # high level ops, but that's really involved and probably not worth it. 
+ + def map_named_array(self, expr: pt.NamedArray) -> None: + self.record_eq_constraints_from_tags(expr) + self.rec(expr._container) + + def map_distributed_send_ref_holder(self, + expr: pt.DistributedSendRefHolder + ) -> None: + self.record_eq_constraints_from_tags(expr) + self.rec(expr.passthrough_data) + for idim in range(expr.ndim): + assert (expr.passthrough_data.shape[idim] + == expr.shape[idim]) + self.record_constraint( + self.get_kanren_var_for_axis_tag(expr.passthrough_data, + idim), + self.get_kanren_var_for_axis_tag(expr, idim) + ) + + def map_distributed_recv(self, + expr: pt.DistributedRecv) -> None: + self.record_eq_constraints_from_tags(expr) + + +def unify_discretization_entity_tags(expr: Union[ArrayContainer, ArrayOrNames] + ) -> ArrayOrNames: + if not isinstance(expr, (pt.Array, pt.DictOfNamedArrays)): + return rec_map_array_container(unify_discretization_entity_tags, + expr) + + from collections import defaultdict + discr_unification_helper = DiscretizationEntityConstraintCollector() + discr_unification_helper(expr) + tag_var_to_axis = {} + variables_to_solve = [] + + for (axis, var) in discr_unification_helper.axis_to_tag_var.items(): + tag_var_to_axis[var] = axis + if var in discr_unification_helper.variables_to_solve: + variables_to_solve.append(var) + + lhs = [cnstrnt[0] for cnstrnt in discr_unification_helper.constraints] + rhs = [cnstrnt[1] for cnstrnt in discr_unification_helper.constraints] + assert len(lhs) == len(rhs) + solutions = {} + + for i_retry in range(MAX_UNIFY_RETRIES): + old_solutions = solutions.copy() + solutions = unification.unify(lhs, rhs, + {l_expr: r_expr + for l_expr, r_expr in solutions.items() + if isinstance(r_expr, + DiscretizationEntityAxisTag)}) + if solutions == old_solutions: + logger.info(f"Unification converged after {i_retry} iterations.") + break + else: + logger.warning(f"Could not converge after {MAX_UNIFY_RETRIES} iterations.") + + # Ideally it might be better to enable this, but that would be too + #
restrictive as not all computation graphs result in DOFArray ouptuts + # if not (frozenset(variables_to_solve) <= frozenset(solutions)): + # raise RuntimeError("Unification failed.") + + # ary_to_axes_tags: mapping from array to a mapping from iaxis to the + # solved tag. + ary_to_axes_tags = defaultdict(dict) + for var in solutions: + ary, axis = tag_var_to_axis[var] + if isinstance(solutions[var], DiscretizationEntityAxisTag): + ary_to_axes_tags[ary][axis] = solutions[var] + if var in variables_to_solve and ( + not isinstance(solutions[var], DiscretizationEntityAxisTag)): + raise RuntimeError(f"Could not solve for {var}.") + + def attach_tags(expr: ArrayOrNames) -> ArrayOrNames: + if not isinstance(expr, pt.Array): + return expr + + for iaxis, solved_tag in ary_to_axes_tags[expr].items(): + if expr.axes[iaxis].tags_of_type(DiscretizationEntityAxisTag): + discr_tag, = (expr + .axes[iaxis] + .tags_of_type(DiscretizationEntityAxisTag)) + assert discr_tag == solved_tag + else: + if not isinstance(solved_tag, DiscretizationEntityAxisTag): + actual_tag = discr_unification_helper.axis_to_tag_var[(expr, + iaxis)] + assert actual_tag in discr_unification_helper.variables_to_solve + assert actual_tag in variables_to_solve + raise ValueError(f"In {expr!r}, axis={iaxis}'s type cannot be " + "inferred.") + expr = expr.with_tagged_axis(iaxis, solved_tag) + + if isinstance(expr, pt.Einsum): + redn_descr_to_entity_type = {} + for access_descrs, arg in zip(expr.access_descriptors, + expr.args): + for iaxis, access_descr in enumerate(access_descrs): + if isinstance(access_descr, EinsumReductionAxis): + redn_descr_to_entity_type[access_descr] = ( + ary_to_axes_tags[arg][iaxis]) + + if (frozenset(redn_descr_to_entity_type) + != frozenset(expr.redn_descr_to_redn_dim)): + raise ValueError + + for redn_descr, solved_tag in redn_descr_to_entity_type.items(): + if not isinstance(solved_tag, DiscretizationEntityAxisTag): + raise ValueError(f"In {expr!r}, redn_descr={redn_descr}'s" + " 
type cannot be inferred.") + expr = expr.with_tagged_redn_dim(redn_descr, solved_tag) + + if isinstance(expr, pt.IndexLambda): + from pytato.raising import (index_lambda_to_high_level_op, + ReduceOp) + + hlo = index_lambda_to_high_level_op(expr) + if isinstance(hlo, ReduceOp): + for iaxis in hlo.axes: + solved_tag = ary_to_axes_tags[hlo.x][iaxis] + if not isinstance(solved_tag, DiscretizationEntityAxisTag): + raise ValueError(f"In {expr!r}, redn_descr={iaxis}'s" + " type cannot be inferred.") + + expr = expr.with_tagged_redn_dim(iaxis, solved_tag) + + return expr + + return pt.transform.map_and_copy(expr, attach_tags) + +# }}} + + +class UnInferredStoredArrayCatcher(pt.transform.CachedWalkMapper): + """ + Raises a :class:`ValueError` if a stored array has axes without a + :class:`DiscretizationEntityAxisTag` tagged to it. + """ + def post_visit(self, expr: ArrayOrNames) -> None: + if (isinstance(expr, pt.Array) + and expr.tags_of_type(pt.tags.ImplStored)): + if any(len(axis.tags_of_type(DiscretizationEntityAxisTag)) != 1 + for axis in expr.axes): + raise ValueError(f"{expr!r} doesn't have all its axes inferred.") + + if (isinstance(expr, pt.IndexLambda) + and any(len(redn_dim.tags_of_type(DiscretizationEntityAxisTag)) != 1 + for redn_dim in expr.reduction_dims.values())): + raise ValueError(f"{expr!r} doesn't have all its redn axes inferred.") + + if (isinstance(expr, pt.Einsum) + and any(len(redn_dim.tags_of_type(DiscretizationEntityAxisTag)) != 1 + for redn_dim in expr.redn_descr_to_redn_dim.values())): + raise ValueError(f"{expr!r} doesn't have all its redn axes inferred.") + + if isinstance(expr, pt.DictOfNamedArrays): + if any(any(len(axis.tags_of_type(DiscretizationEntityAxisTag)) != 1 + for axis in subexpr.axes) + for subexpr in expr._data.values()): + raise ValueError(f"{expr!r} doesn't have all its axes inferred.") + + from pytato.loopy import LoopyCall + + if isinstance(expr, LoopyCall): + if any(any(len(axis.tags_of_type(DiscretizationEntityAxisTag)) != 
def are_all_stored_arrays_inferred(expr: ArrayOrNames) -> None:
    """
    Walk *expr* with a fresh :class:`UnInferredStoredArrayCatcher`.

    Despite the predicate-like name, nothing is returned: the result of the
    walk is discarded, so success is signalled only by the call completing.
    The catcher raises :class:`ValueError` when it visits a node whose axes
    or reduction dimensions do not each carry exactly one
    :class:`DiscretizationEntityAxisTag`.
    """
    UnInferredStoredArrayCatcher()(expr)
DiscretizationFlattenedDOFAxisTag, + DiscretizationEntityAxisTag) +from dataclasses import dataclass + +from pyrsistent import pmap logger = logging.getLogger(__name__) @@ -431,63 +448,60 @@ def _alias_global_temporaries(t_unit): intervals using the same :attr:`loopy.TemporaryVariable.base_storage`. """ from loopy.kernel.data import AddressSpace - from loopy.kernel import KernelState - from loopy.schedule import (RunInstruction, EnterLoop, LeaveLoop, - CallKernel, ReturnFromKernel, Barrier) - from loopy.schedule.tools import get_return_from_kernel_mapping - from pytools import UniqueNameGenerator from collections import defaultdict kernel = t_unit.default_entrypoint - assert kernel.state == KernelState.LINEARIZED + toposorted_iels = _get_element_loop_topo_sorted_order(kernel) + iel_order = {iel: i + for i, iel in enumerate(toposorted_iels)} + temp_vars = frozenset(tv.name for tv in kernel.temporary_variables.values() if tv.address_space == AddressSpace.GLOBAL) - temp_to_live_interval_start = {} - temp_to_live_interval_end = {} - return_from_kernel_idxs = get_return_from_kernel_mapping(kernel) - - for sched_idx, sched_item in enumerate(kernel.linearization): - if isinstance(sched_item, RunInstruction): - for var in (kernel.id_to_insn[sched_item.insn_id].dependency_names() - & temp_vars): - if var not in temp_to_live_interval_start: - assert var not in temp_to_live_interval_end - temp_to_live_interval_start[var] = sched_idx - assert var in temp_to_live_interval_start - temp_to_live_interval_end[var] = return_from_kernel_idxs[sched_idx] - elif isinstance(sched_item, (EnterLoop, LeaveLoop, CallKernel, - ReturnFromKernel, Barrier)): - # no variables are accessed within these schedule items => do - # nothing. 
- pass - else: - raise NotImplementedError(type(sched_item)) + temp_var_to_iels = {tv: set() for tv in temp_vars} + all_iels = { + iel + for iel in kernel.all_inames() + if kernel.inames[iel].tags_of_type((DiscretizationElementAxisTag, + DiscretizationFlattenedDOFAxisTag))} + + if not all_iels: + # no element loops => return the t_unit as is. + return t_unit - vng = UniqueNameGenerator() + for insn in kernel.instructions: + iel, = insn.within_inames & all_iels + + for tv in insn.dependency_names() & temp_vars: + temp_var_to_iels[tv].add(iel) + + temp_to_iel_start = {tv: min(iels, + key=lambda x: iel_order[x], + default=toposorted_iels[-1] + ) + for tv, iels in temp_var_to_iels.items()} + temp_to_iel_end = {tv: max(iels, + key=lambda x: iel_order[x], + default=toposorted_iels[0] + ) + for tv, iels in temp_var_to_iels.items()} + + iel_to_temps_to_allocate = {iel: set() for iel in all_iels} + iel_to_temps_to_free = {iel: set() for iel in all_iels} + for tv in temp_vars: + allocate_iel, free_iel = temp_to_iel_start[tv], temp_to_iel_end[tv] + iel_to_temps_to_allocate[allocate_iel].add(tv) + iel_to_temps_to_free[free_iel].add(tv) + + vng = kernel.get_var_name_generator() # a mapping from shape to the available base storages from temp variables # that were dead. 
shape_to_available_base_storage = defaultdict(set) - sched_idx_to_just_live_temp_vars = [set() for _ in kernel.linearization] - sched_idx_to_just_dead_temp_vars = [set() for _ in kernel.linearization] - - for tv, just_alive_idx in temp_to_live_interval_start.items(): - sched_idx_to_just_live_temp_vars[just_alive_idx].add(tv) - - for tv, just_dead_idx in temp_to_live_interval_end.items(): - sched_idx_to_just_dead_temp_vars[just_dead_idx].add(tv) - new_tvs = {} - for sched_idx, _ in enumerate(kernel.linearization): - just_dead_temps = sched_idx_to_just_dead_temp_vars[sched_idx] - to_be_allocated_temps = sched_idx_to_just_live_temp_vars[sched_idx] - for tv_name in sorted(just_dead_temps): - tv = new_tvs[tv_name] - assert tv.base_storage is not None - assert tv.base_storage not in shape_to_available_base_storage[tv.nbytes] - shape_to_available_base_storage[tv.nbytes].add(tv.base_storage) + for iel in toposorted_iels: + to_be_allocated_temps = iel_to_temps_to_allocate[iel] for tv_name in sorted(to_be_allocated_temps): assert len(to_be_allocated_temps) <= 1 @@ -502,14 +516,37 @@ def _alias_global_temporaries(t_unit): new_tvs[tv.name] = tv.copy(base_storage=base_storage) + just_dead_temps = iel_to_temps_to_free[iel] + for tv_name in sorted(just_dead_temps): + tv = new_tvs[tv_name] + assert tv.base_storage is not None + assert tv.base_storage not in shape_to_available_base_storage[tv.nbytes] + shape_to_available_base_storage[tv.nbytes].add(tv.base_storage) + for name, tv in kernel.temporary_variables.items(): if tv.address_space != AddressSpace.GLOBAL: new_tvs[name] = tv else: - assert name in new_tvs + # FIXME: Need tighter assertion condition (this doesn't work when + # zero-size arrays are present) + # assert name in new_tvs + pass kernel = kernel.copy(temporary_variables=new_tvs) + old_tmp_mem_requirement = sum( + tv.nbytes + for tv in kernel.temporary_variables.values()) + + new_tmp_mem_requirement = sum( + {tv.base_storage: tv.nbytes + for tv in 
def get_temps_not_to_contract(knl):
    """
    Return the names of the temporary variables of *knl* that must not be
    contracted.

    A temporary is left out of the result only if all of the following hold:

    - it is written by exactly one instruction,
    - every instruction reading it is among the instructions nested in at
      least one of the writer's ``within_inames``,
    - neither the writer nor any reader has reduction inames.

    :arg knl: an object providing ``writer_map()``, ``reader_map()``,
        ``iname_to_insns()``, ``temporary_variables`` and ``id_to_insn``.
    :returns: a :class:`set` of temporary variable names.
    """
    wmap = knl.writer_map()
    rmap = knl.reader_map()
    # Loop-invariant: look this up once rather than once per writer iname.
    iname_to_insns = knl.iname_to_insns()
    no_insns = frozenset()

    temps_not_to_contract = set()
    for tv in knl.temporary_variables:
        writers = wmap.get(tv, no_insns)
        if len(writers) != 1:
            # Unwritten or multiply-written temporaries are never contracted.
            temps_not_to_contract.add(tv)
            continue

        writer_id, = writers
        writer = knl.id_to_insn[writer_id]
        readers = rmap.get(tv, no_insns)

        # All instructions that share at least one iname with the writer.
        insns_in_writer_loop_nest = frozenset().union(
            *(iname_to_insns[iname] for iname in writer.within_inames))

        if (not (readers <= insns_in_writer_loop_nest)
                or len(writer.reduction_inames()) != 0
                or any(len(knl.id_to_insn[reader_id].reduction_inames()) != 0
                       for reader_id in readers)):
            temps_not_to_contract.add(tv)

    return temps_not_to_contract
+ # import loopy as lp + # from kanren.constraints import neq as kanren_neq + # + # tempo = lp.relations.get_tempo(knl) + # producero = lp.relations.get_producero(knl) + # consumero = lp.relations.get_consumero(knl) + # withino = lp.relations.get_withino(knl) + # reduce_insno = lp.relations.get_reduce_insno(knl) + # + # # temp_k: temporary variable that cannot be contracted + # temp_k = kanren.var() + # producer_insn_k = kanren.var() + # producer_loops_k = kanren.var() + # consumer_insn_k = kanren.var() + # consumer_loops_k = kanren.var() + + # temps_not_to_contract = kanren.run(0, + # temp_k, + # tempo(temp_k), + # producero(producer_insn_k, + # temp_k), + # consumero(consumer_insn_k, + # temp_k), + # withino(producer_insn_k, + # producer_loops_k), + # withino(consumer_insn_k, + # consumer_loops_k), + # kanren.lany( + # kanren_neq( + # producer_loops_k, + # consumer_loops_k), + # reduce_insno(consumer_insn_k)), + # results_filter=frozenset) + # return temps_not_to_contract + + +def _is_iel_loop_part_of_global_dof_loops(iel: str, knl) -> bool: + insn, = knl.iname_to_insns()[iel] + return any(iname + for iname in knl.id_to_insn[insn].within_inames + if knl.iname_tags_of_type(iname, DiscretizationDOFAxisTag)) + + +def _discr_entity_sort_key(discr_tag: DiscretizationEntityAxisTag + ) -> Tuple[Any, ...]: + + return type(discr_tag).__name__ + + +# {{{ define FEMEinsumTag + +@dataclass(frozen=True) +class EinsumIndex: + discr_entity: DiscretizationEntityAxisTag + length: int + + @classmethod + def from_iname(cls, iname, kernel): + discr_entity, = kernel.filter_iname_tags_by_type( + iname, DiscretizationEntityAxisTag) + length = kernel.get_constant_iname_length(iname) + return cls(discr_entity, length) + + +@dataclass(frozen=True) +class FreeEinsumIndex(EinsumIndex): + pass + + +@dataclass(frozen=True) +class SummationEinsumIndex(EinsumIndex): + pass + + +@dataclass(frozen=True) +class FEMEinsumTag(UniqueTag): + indices: Tuple[Tuple[EinsumIndex, ...], ...] 
def _do_inames_belong_to_different_einsum_types(iname1, iname2, kernel):
    """
    Return *True* if the instructions associated with *iname1* and *iname2*
    assign to variables carrying different :class:`FEMEinsumTag` tags.

    If ``kernel.iname_to_insns()`` has instructions for *iname1*, both inames
    are looked up there; otherwise both are looked up in the mapping returned
    by :func:`_get_redn_iname_to_insns`. In either case each iname must map
    to exactly one instruction, each of those instructions must have exactly
    one assignee, and each assignee must carry exactly one
    :class:`FEMEinsumTag`.
    """
    iname_to_insns = kernel.iname_to_insns()

    # Use the same mapping for both inames so that the single-instruction
    # check below is always made against a mapping that is actually bound.
    if iname_to_insns[iname1]:
        insns_of = iname_to_insns
    else:
        insns_of = _get_redn_iname_to_insns(kernel)

    assert len(insns_of[iname1]) == len(insns_of[iname2]) == 1
    insn1, = insns_of[iname1]
    insn2, = insns_of[iname2]

    var1_name, = kernel.id_to_insn[insn1].assignee_var_names()
    var2_name, = kernel.id_to_insn[insn2].assignee_var_names()
    var1 = kernel.get_var_descriptor(var1_name)
    var2 = kernel.get_var_descriptor(var2_name)

    ensm1, = var1.tags_of_type(FEMEinsumTag)
    ensm2, = var2.tags_of_type(FEMEinsumTag)

    return ensm1 != ensm2
@memoize_on_disk
def fuse_same_discretization_entity_loops(knl):
    """
    Apply :func:`_fuse_loops_over_a_discr_entity` to *knl* once per
    discretization entity type and return the resulting kernel.

    The non-reduction loops are handled first (face, element, DOF, dim),
    followed by the reduction loops (face, DOF, dim). Every pass is given the
    untransformed kernel as well, so that it can consult the original inames
    and tags.
    """
    # the kernel as it was before any renaming, kept for iname/tag queries.
    untransformed_knl = knl

    # (entity tag, prefix of the fused loop, fuse reduction loops?)
    fusion_passes = (
        (DiscretizationFaceAxisTag, "iface", False),
        (DiscretizationElementAxisTag, "iel", False),
        (DiscretizationDOFAxisTag, "idof", False),
        (DiscretizationDimAxisTag, "idim", False),
        (DiscretizationFaceAxisTag, "iface", True),
        (DiscretizationDOFAxisTag, "idof", True),
        (DiscretizationDimAxisTag, "idim", True),
    )

    for entity_tag, loop_prefix, fuse_redn_loops in fusion_passes:
        knl = _fuse_loops_over_a_discr_entity(knl, entity_tag,
                                              loop_prefix,
                                              fuse_redn_loops,
                                              untransformed_knl)

    return knl
dead code got eliminated :) + assert f"{temp}_subst" not in knl.substitutions + continue + knl = precompute_for_single_kernel( + knl, callables_table, f"{temp}_subst", + sweep_inames=(), + temporary_address_space=lp.AddressSpace.PRIVATE, + compute_insn_id=f"_mm_contract_{temp}", + ) + + knl = lp.map_instructions(knl, + f"id:_mm_contract_{temp}", + lambda x: x.tagged(ensm_tag)) + + return lp.remove_unused_inames(knl) + + +def _get_group_size_for_dof_array_loop(nunit_dofs): + """ + Returns the OpenCL workgroup size for a loop iterating over the global DOFs + of a discretization with *nunit_dofs* per cell. + """ + if nunit_dofs == {6}: + return 16, 6 + elif nunit_dofs == {10}: + return 16, 10 + elif nunit_dofs == {20}: + return 16, 10 + elif nunit_dofs == {1}: + return 32, 1 + elif nunit_dofs == {2}: + return 32, 2 + elif nunit_dofs == {4}: + return 16, 4 + elif nunit_dofs == {3}: + return 32, 3 + elif nunit_dofs == {35}: + return 9, 7 + elif nunit_dofs == {15}: + return 8, 8 + elif nunit_dofs == {7}: + return 9, 7 + else: + # /!\ not ideal performance-wise but better than raising. 
+ return 8, 4 + + +def _get_iel_to_idofs(kernel): + iel_inames = {iname + for iname in kernel.all_inames() + if (kernel + .inames[iname] + .tags_of_type((DiscretizationElementAxisTag, + DiscretizationFlattenedDOFAxisTag))) + } + idof_inames = {iname + for iname in kernel.all_inames() + if (kernel + .inames[iname] + .tags_of_type(DiscretizationDOFAxisTag)) + } + iface_inames = {iname + for iname in kernel.all_inames() + if (kernel + .inames[iname] + .tags_of_type(DiscretizationFaceAxisTag)) + } + idim_inames = {iname + for iname in kernel.all_inames() + if (kernel + .inames[iname] + .tags_of_type(DiscretizationDimAxisTag)) + } + + iel_to_idofs = {iel: set() for iel in iel_inames} + + for insn in kernel.instructions: + if (len(insn.within_inames) == 1 + and (insn.within_inames) <= iel_inames): + iel, = insn.within_inames + if all(kernel.id_to_insn[el_insn].within_inames == insn.within_inames + for el_insn in kernel.iname_to_insns()[iel]): + # the iel here doesn't interfere with any idof i.e. we + # support parallelizing such loops. 
def _get_element_loop_topo_sorted_order(knl):
    """
    Return the element-loop inames of *knl* in a topological order of the
    dependencies between them.

    A graph is built with one node per iname tagged with
    :class:`DiscretizationElementAxisTag`. For every instruction that is a
    :class:`loopy.MultiAssignmentBase`, an edge is added from the loop of
    each instruction it depends on to its own loop, unless both are the same
    loop. The graph is then ordered with
    :func:`pytools.graph.compute_topological_order`.
    """
    from loopy import MultiAssignmentBase
    # NOTE(review): the nodes are restricted to inames tagged with
    # DiscretizationElementAxisTag, whereas _get_iel_loop_from_insn also
    # returns inames tagged with DiscretizationFlattenedDOFAxisTag. A
    # dependency originating in such a loop would make 'dag[pred_iel]' below
    # raise KeyError -- confirm whether both tags should be admitted here.
    dag = {iel: set()
           for iel in knl.all_inames()
           if knl.inames[iel].tags_of_type(DiscretizationElementAxisTag)}

    for insn in knl.instructions:
        if isinstance(insn, MultiAssignmentBase):
            succ_iel = _get_iel_loop_from_insn(insn, knl)
            for dep_id in insn.depends_on:
                pred_iel = _get_iel_loop_from_insn(knl.id_to_insn[dep_id], knl)
                # dependencies within one and the same loop impose no
                # ordering between loops.
                if pred_iel != succ_iel:
                    dag[pred_iel].add(succ_iel)

    from pytools.graph import compute_topological_order
    # the identity key is handed to compute_topological_order; presumably it
    # is used to break ties between unordered loops -- TODO confirm.
    return compute_topological_order(dag, key=lambda x: x)
def _get_elementwise_einsum(t_unit, einsum_tag):
    """
    Build the :mod:`feinsum` expression describing the instructions of
    *t_unit*'s default entrypoint that are tagged with *einsum_tag*.

    Every matched instruction must have a single assignee. Assignees that are
    plain variables must be private temporaries; subscripted assignees must
    name a kernel argument or a global temporary, and the names of their
    indices are recorded. Exactly one distinct index tuple must result.

    The returned einsum multiplies two ``float64`` operands elementwise
    (``"ab, ab->ab"`` for two indices), with one axis per recorded index. An
    axis has length ``np.inf`` if its iname is tagged with
    :class:`DiscretizationElementAxisTag` and the iname's constant length
    otherwise.

    :raises NotImplementedError: if an assignee is a plain variable living in
        global memory, is neither a variable nor a subscript, or if the
        number of distinct index tuples is not one.
    """
    import loopy as lp
    import feinsum as fnsm
    from loopy.match import ObjTagged
    from pymbolic.primitives import Variable, Subscript

    knl = t_unit.default_entrypoint

    assert isinstance(einsum_tag, EinsumTag)
    is_part_of_einsum = ObjTagged(einsum_tag)

    global_temps = {tv.name
                    for tv in knl.temporary_variables.values()
                    if tv.address_space == lp.AddressSpace.GLOBAL}
    names_in_global_mem = global_temps | set(knl.arg_dict.keys())

    matched_insns = [insn
                     for insn in knl.instructions
                     if is_part_of_einsum(knl, insn)]

    # index-name tuples of all subscripted assignees
    output_indices = set()

    for insn in matched_insns:
        assert len(insn.assignees) == 1
        lhs = insn.assignee

        if isinstance(lhs, Variable):
            if lhs.name in names_in_global_mem:
                raise NotImplementedError(insn)
            assert (knl.temporary_variables[lhs.name].address_space
                    == lp.AddressSpace.PRIVATE)
        elif isinstance(lhs, Subscript):
            assert insn.assignee_name in names_in_global_mem
            output_indices.add(tuple(idx.name for idx in lhs.index_tuple))
        else:
            raise NotImplementedError(insn)

    if len(output_indices) != 1:
        raise NotImplementedError("Multiple einsums in the same loop nest =>"
                                  " not allowed.")

    idx_tuple, = output_indices

    # one lower-case letter per index, starting at 'a'
    index_letters = "".join(chr(97+i) for i in range(len(idx_tuple)))
    subscript = "{lhs}, {lhs}->{lhs}".format(lhs=index_letters)

    axis_lengths = []
    for idx in idx_tuple:
        if knl.iname_tags_of_type(idx, DiscretizationElementAxisTag):
            axis_lengths.append(np.inf)
        else:
            axis_lengths.append(knl.get_constant_iname_length(idx))
    arg_shape = tuple(axis_lengths)

    return fnsm.einsum(subscript,
                       fnsm.array(arg_shape, "float64"),
                       fnsm.array(arg_shape, "float64"))
reduce(frozenset.union, + ((insn.within_inames | insn.reduction_inames()) + for insn in insns), + frozenset()) + domain = knl.get_inames_domain(frozenset(inames)) + new_domains.append(domain.project_out_except(sorted(inames), + [isl.dim_type.set])) + + return knl.copy(domains=new_domains) + + +class FusionContractorArrayContext( + SingleGridWorkBalancingPytatoArrayContext): + + def transform_dag(self, dag): + import pytato as pt + + # {{{ Remove FEMEinsumTags that might have been propagated + + # TODO: Is this too hacky? + + def remove_fem_einsum_tags(expr): + if isinstance(expr, pt.Array): + try: + fem_ensm_tag = next(iter(expr.tags_of_type(FEMEinsumTag))) + except StopIteration: + return expr + else: + assert isinstance(expr, pt.InputArgumentBase) + return expr.without_tags(fem_ensm_tag) + else: + return expr + + dag = pt.transform.map_and_copy(dag, remove_fem_einsum_tags) + + # }}} + + # {{{ CSE + + with ProcessLogger(logger, "transform_dag.mpms_materialization"): + dag = pt.transform.materialize_with_mpms(dag) + + def mark_materialized_nodes_as_cse( + ary: Union[pt.Array, + pt.AbstractResultWithNamedArrays]) -> pt.Array: + if isinstance(ary, pt.AbstractResultWithNamedArrays): + return ary + + if ary.tags_of_type(pt.tags.ImplStored): + return ary.tagged(pt.tags.PrefixNamed("cse")) + else: + return ary + + with ProcessLogger(logger, "transform_dag.naming_cse"): + dag = pt.transform.map_and_copy(dag, mark_materialized_nodes_as_cse) + + # }}} + + # {{{ indirect addressing are non-negative + + indirection_maps = set() + + class _IndirectionMapRecorder(pt.transform.CachedWalkMapper): + def post_visit(self, expr): + if isinstance(expr, pt.IndexBase): + for idx in expr.indices: + if isinstance(idx, pt.Array): + indirection_maps.add(idx) + + _IndirectionMapRecorder()(dag) + + def tag_indices_as_non_negative(ary): + if ary in indirection_maps: + return ary.tagged(pt.tags.AssumeNonNegative()) + else: + return ary + + with ProcessLogger(logger, 
"transform_dag.tag_indices_as_non_negative"): + dag = pt.transform.map_and_copy(dag, tag_indices_as_non_negative) + + # }}} + + with ProcessLogger(logger, "transform_dag.deduplicate_data_wrappers"): + dag = pt.transform.deduplicate_data_wrappers(dag) + + # {{{ get rid of copies for different views of a cl-array + + def eliminate_reshapes_of_data_wrappers(ary): + if (isinstance(ary, pt.Reshape) + and isinstance(ary.array, pt.DataWrapper)): + return pt.make_data_wrapper(ary.array.data.reshape(ary.shape), + tags=ary.tags, + axes=ary.axes) + else: + return ary + + dag = pt.transform.map_and_copy(dag, + eliminate_reshapes_of_data_wrappers) + + # }}} + + # {{{ face_mass: materialize einsum args + + def materialize_face_mass_input_and_output(expr): + if (isinstance(expr, pt.Einsum) + and pt.analysis.is_einsum_similar_to_subscript( + expr, + "ifj,fej,fej->ei")): + mat, jac, vec = expr.args + return (pt.einsum("ifj,fej,fej->ei", + mat, + jac, + vec.tagged(pt.tags.ImplStored())) + .tagged((pt.tags.ImplStored(), + pt.tags.PrefixNamed("face_mass")))) + else: + return expr + + with ProcessLogger(logger, + "transform_dag.materialize_face_mass_ins_and_outs"): + dag = pt.transform.map_and_copy(dag, + materialize_face_mass_input_and_output) + + # }}} + + # {{{ materialize inverse mass inputs + + def materialize_inverse_mass_inputs(expr): + if (isinstance(expr, pt.Einsum) + and pt.analysis.is_einsum_similar_to_subscript( + expr, + "ei,ij,ej->ei")): + arg1, arg2, arg3 = expr.args + if not arg3.tags_of_type(pt.tags.PrefixNamed): + arg3 = arg3.tagged(pt.tags.PrefixNamed("mass_inv_inp")) + if not arg3.tags_of_type(pt.tags.ImplStored): + arg3 = arg3.tagged(pt.tags.ImplStored()) + + return pt.Einsum(expr.access_descriptors, + (arg1, arg2, arg3), + expr.axes, + expr.redn_axis_to_redn_descr, + expr.index_to_access_descr, + expr.tags) + else: + return expr + + dag = pt.transform.map_and_copy(dag, materialize_inverse_mass_inputs) + + # }}} + + # {{{ materialize all einsums + + def 
materialize_all_einsums_or_reduces(expr): + from pytato.raising import (index_lambda_to_high_level_op, + ReduceOp) + + if isinstance(expr, pt.Einsum): + return expr.tagged(pt.tags.ImplStored()) + elif (isinstance(expr, pt.IndexLambda) + and isinstance(index_lambda_to_high_level_op(expr), ReduceOp)): + return expr.tagged(pt.tags.ImplStored()) + else: + return expr + + with ProcessLogger(logger, + "transform_dag.materialize_all_einsums_or_reduces"): + dag = pt.transform.map_and_copy(dag, materialize_all_einsums_or_reduces) + + # }}} + + # {{{ infer axis types + + from meshmode.pytato_utils import unify_discretization_entity_tags + + with ProcessLogger(logger, "transform_dag.infer_axes_tags"): + dag = unify_discretization_entity_tags(dag) + + # }}} + + # {{{ /!\ Remove tags from Loopy call results. + # See + + def untag_loopy_call_results(expr): + from pytato.loopy import LoopyCallResult + if isinstance(expr, LoopyCallResult): + return expr.copy(tags=frozenset(), + axes=(pt.Axis(frozenset()),)*expr.ndim) + else: + return expr + + dag = pt.transform.map_and_copy(dag, untag_loopy_call_results) + + # }}} + + # {{{ remove broadcasts from einsums: help feinsum + + ensm_arg_rewrite_cache = {} + + def _get_rid_of_broadcasts_from_einsum(expr): + # Helpful for matching against the available expressions + # in feinsum. 
+ + from pytato.utils import (are_shape_components_equal, + are_shapes_equal) + if isinstance(expr, pt.Einsum): + from pytato.array import EinsumElementwiseAxis + idx_to_len = expr._access_descr_to_axis_len() + new_access_descriptors = [] + new_args = [] + inp_gatherer = pt.transform.InputGatherer() + access_descr_to_axes = dict(expr.redn_axis_to_redn_descr) + for iax, axis in enumerate(expr.axes): + access_descr_to_axes[EinsumElementwiseAxis(iax)] = axis + + for access_descrs, arg in zip(expr.access_descriptors, + expr.args): + new_shape = [] + new_access_descrs = [] + new_axes = [] + for iaxis, (access_descr, axis_len) in enumerate( + zip(access_descrs, + arg.shape)): + if not are_shape_components_equal(axis_len, + idx_to_len[access_descr]): + assert are_shape_components_equal(axis_len, 1) + if any(isinstance(inp, pt.Placeholder) + for inp in inp_gatherer(arg)): + # do not get rid of broadcasts from parameteric + # data. + new_shape.append(axis_len) + new_access_descrs.append(access_descr) + new_axes.append(arg.axes[iaxis]) + else: + new_axes.append(arg.axes[iaxis]) + new_shape.append(axis_len) + new_access_descrs.append(access_descr) + + if not are_shapes_equal(new_shape, arg.shape): + assert len(new_axes) == len(new_shape) + arg_to_freeze = (arg.reshape(new_shape) + .copy(axes=tuple( + access_descr_to_axes[acc_descr] + for acc_descr in new_access_descrs))) + + try: + new_arg = ensm_arg_rewrite_cache[arg_to_freeze] + except KeyError: + new_arg = self.thaw(self.freeze(arg_to_freeze)) + ensm_arg_rewrite_cache[arg_to_freeze] = new_arg + + arg = new_arg + + assert arg.ndim == len(new_access_descrs) + new_args.append(arg) + new_access_descriptors.append(tuple(new_access_descrs)) + + return pt.Einsum(tuple(new_access_descriptors), + tuple(new_args), + tags=expr.tags, + axes=expr.axes, + redn_axis_to_redn_descr=(expr + .redn_axis_to_redn_descr), + index_to_access_descr=expr.index_to_access_descr) + else: + return expr + + dag = pt.transform.map_and_copy(dag, 
_get_rid_of_broadcasts_from_einsum) + + # }}} + + # {{{ remove any PartID tags + + from pytato.distributed import PartIDTag + + def remove_part_id_tags(expr): + if isinstance(expr, pt.Array) and expr.tags_of_type(PartIDTag): + tag, = expr.tags_of_type(PartIDTag) + return expr.without_tags(tag) + else: + return expr + + dag = pt.transform.map_and_copy(dag, remove_part_id_tags) + + # }}} + + # {{{ attach FEMEinsumTag tags + + dag_outputs = frozenset(dag._data.values()) + + def add_fem_einsum_tags(expr): + if isinstance(expr, pt.Einsum): + from pytato.array import (EinsumElementwiseAxis, + EinsumReductionAxis) + assert expr.tags_of_type(pt.tags.ImplStored) + ensm_indices = [] + for arg, access_descrs in zip(expr.args, + expr.access_descriptors): + arg_indices = [] + for iaxis, access_descr in enumerate(access_descrs): + try: + discr_tag = next( + iter(arg + .axes[iaxis] + .tags_of_type(DiscretizationEntityAxisTag))) + except StopIteration: + raise NotAnFEMEinsumError(expr) + else: + if isinstance(access_descr, EinsumElementwiseAxis): + arg_indices.append(FreeEinsumIndex(discr_tag, + arg.shape[iaxis])) + elif isinstance(access_descr, EinsumReductionAxis): + arg_indices.append(SummationEinsumIndex( + discr_tag, + arg.shape[iaxis])) + else: + raise NotImplementedError(access_descr) + ensm_indices.append(tuple(arg_indices)) + + return expr.tagged(FEMEinsumTag(tuple(ensm_indices))) + elif (isinstance(expr, pt.Array) + and (expr.tags_of_type(pt.tags.ImplStored) + or expr in dag_outputs)): + if (isinstance(expr, pt.IndexLambda) + and expr.var_to_reduction_descr + and expr.shape == ()): + raise NotImplementedError("all-reduce expressions not" + " supported") + else: + discr_tags = [] + for axis in expr.axes: + try: + discr_tag = next( + iter(axis.tags_of_type(DiscretizationEntityAxisTag))) + except StopIteration: + raise NotAnFEMEinsumError(expr) + else: + discr_tags.append(discr_tag) + + fem_ensm_tag = FEMEinsumTag( + (tuple(FreeEinsumIndex(discr_tag, dim) + for dim, 
discr_tag in zip(expr.shape, + discr_tags)),) * 2 + ) + + return expr.tagged(fem_ensm_tag) + + else: + return expr + + try: + dag = pt.transform.map_and_copy(dag, add_fem_einsum_tags) + except NotAnFEMEinsumError: + pass + + # }}} + + # {{{ remove the ImplStored tag from the (non-input) outputs + + def _untag_impl_stored(expr): + if isinstance(expr, pt.InputArgumentBase): + return expr + else: + return expr.without_tags(pt.tags.ImplStored(), + verify_existence=False) + + dag = pt.make_dict_of_named_arrays({ + name: _untag_impl_stored(named_ary.expr) + for name, named_ary in dag.items()}) + + # }}} + + return dag + + def transform_loopy_program(self, t_unit): + import loopy as lp + from functools import reduce + from arraycontext.impl.pytato.compile import FromArrayContextCompile + + original_t_unit = t_unit + + # from loopy.transform.instruction import simplify_indices + # t_unit = simplify_indices(t_unit) + + knl = t_unit.default_entrypoint + + logger.info(f"Transforming kernel with {len(knl.instructions)} statements.") + + # {{{ fallback: if we could not infer which mesh entity the inames + # iterate over. 
+ + for iname in knl.all_inames(): + if not knl.iname_tags_of_type(iname, DiscretizationEntityAxisTag): + warn(f"[{knl.name}]: Falling back to a slower transformation" + " strategy as some loops are uninferred which mesh entity" + " they belong to.", + stacklevel=2) + + return super().transform_loopy_program(original_t_unit) + + for insn in knl.instructions: + for assignee in insn.assignee_var_names(): + var = knl.get_var_descriptor(assignee) + if not var.tags_of_type(FEMEinsumTag): + warn(f"[{knl.name}]: Falling back to a slower transformation" + " strategy as some instructions couldn't be inferred as" + " einsums", + stacklevel=2) + + return super().transform_loopy_program(original_t_unit) + + # }}} + + # {{{ hardcode offset to 0 (sorry humanity) + + knl = knl.copy(args=[arg.copy(offset=0) + for arg in knl.args]) + + # }}} + + # {{{ loop fusion + + with ProcessLogger(logger, "Loop Fusion"): + knl = fuse_same_discretization_entity_loops(knl) + + # }}} + + # {{{ align kernels for fused einsums + + knl = _prepare_kernel_for_parallelization(knl) + knl = _combine_einsum_domains(knl) + + # }}} + + # {{{ array contraction + + with ProcessLogger(logger, "Array Contraction"): + knl = contract_arrays(knl, t_unit.callables_table) + + # }}} + + # {{{ Stats Collection (Disabled) + + if 0: + with ProcessLogger(logger, "Counting Kernel Ops"): + from loopy.kernel.array import ArrayBase + from pytools import product + knl = knl.copy( + silenced_warnings=(knl.silenced_warnings + + ["insn_count_subgroups_upper_bound", + "summing_if_branches_ops"])) + + t_unit = t_unit.with_kernel(knl) + + op_map = lp.get_op_map(t_unit, subgroup_size=32) + + c64_ops = {op_type: (op_map.filter_by(dtype=[np.complex64], + name=op_type, + kernel_name=knl.name) + .eval_and_sum({})) + for op_type in ["add", "mul", "div"]} + c128_ops = {op_type: (op_map.filter_by(dtype=[np.complex128], + name=op_type, + kernel_name=knl.name) + .eval_and_sum({})) + for op_type in ["add", "mul", "div"]} + f32_ops = 
((op_map.filter_by(dtype=[np.float32], + kernel_name=knl.name) + .eval_and_sum({})) + + (2 * c64_ops["add"] + + 6 * c64_ops["mul"] + + (6 + 3 + 2) * c64_ops["div"])) + f64_ops = ((op_map.filter_by(dtype=[np.float64], + kernel_name="_pt_kernel") + .eval_and_sum({})) + + (2 * c128_ops["add"] + + 6 * c128_ops["mul"] + + (6 + 3 + 2) * c128_ops["div"])) + + # {{{ footprint gathering + + nfootprint_bytes = 0 + + for ary in knl.args: + if (isinstance(ary, ArrayBase) + and ary.address_space == lp.AddressSpace.GLOBAL): + nfootprint_bytes += (product(ary.shape) + * ary.dtype.itemsize) + + for ary in knl.temporary_variables.values(): + if ary.address_space == lp.AddressSpace.GLOBAL: + # global temps would be written once and read once + nfootprint_bytes += (2 * product(ary.shape) + * ary.dtype.itemsize) + + # }}} + + if f32_ops: + logger.info(f"Single-prec. GFlOps: {f32_ops * 1e-9}") + if f64_ops: + logger.info(f"Double-prec. GFlOps: {f64_ops * 1e-9}") + logger.info(f"Footprint GBs: {nfootprint_bytes * 1e-9}") + + # }}} + + # {{{ check whether we can parallelize the kernel + + try: + iel_to_idofs = _get_iel_to_idofs(knl) + except NotImplementedError as err: + if knl.tags_of_type(FromArrayContextCompile): + raise err + else: + warn(f"[{knl.name}]: FusionContractorArrayContext." + "transform_loopy_program not broad enough (yet)." 
+ " Falling back to a possibly slower" + " transformation strategy.") + return super().transform_loopy_program(original_t_unit) + + # }}} + + # {{{ insert barriers between consecutive iel-loops + + toposorted_iels = _get_element_loop_topo_sorted_order(knl) + + for iel_pred, iel_succ in zip(toposorted_iels[:-1], + toposorted_iels[1:]): + knl = lp.add_barrier(knl, + insn_before=f"iname:{iel_pred}", + insn_after=f"iname:{iel_succ}") + + # }}} + + t_unit = _alias_global_temporaries(t_unit) + + # {{{ Parallelization strategy: Use feinsum + + t_unit = t_unit.with_kernel(knl) + del knl + + if False and t_unit.default_entrypoint.tags_of_type(FromArrayContextCompile): + # FIXME: Enable this branch, WIP for now and hence disabled it. + from loopy.match import ObjTagged + import feinsum as fnsm + from meshmode.feinsum_transformations import FEINSUM_TO_TRANSFORMS + + assert all(insn.tags_of_type(EinsumTag) + for insn in t_unit.default_entrypoint.instructions + if isinstance(insn, lp.MultiAssignmentBase) + ) + + einsum_tags = reduce( + frozenset.union, + (insn.tags_of_type(EinsumTag) + for insn in t_unit.default_entrypoint.instructions), + frozenset()) + for ensm_tag in sorted(einsum_tags, + key=lambda x: sorted(x.orig_loop_nest)): + if reduce(frozenset.union, + (insn.reduction_inames() + for insn in (t_unit.default_entrypoint.instructions) + if ensm_tag in insn.tags), + frozenset()): + fused_einsum = fnsm.match_einsum(t_unit, ObjTagged(ensm_tag)) + else: + # elementwise loop + fused_einsum = _get_elementwise_einsum(t_unit, ensm_tag) + + try: + fnsm_transform = FEINSUM_TO_TRANSFORMS[ + fnsm.normalize_einsum(fused_einsum)] + except KeyError: + fnsm.query(fused_einsum, + self.queue.context, + err_if_no_results=True) + 1/0 + + t_unit = fnsm_transform(t_unit, + insn_match=ObjTagged(ensm_tag)) + else: + knl = t_unit.default_entrypoint + for iel, idofs in sorted(iel_to_idofs.items()): + if idofs: + nunit_dofs = {knl.get_constant_iname_length(idof) + for idof in idofs} + idof, = idofs 
+ + l_one_size, l_zero_size = _get_group_size_for_dof_array_loop( + nunit_dofs) + + knl = lp.split_iname(knl, iel, l_one_size, + inner_tag="l.1", outer_tag="g.0") + knl = lp.split_iname(knl, idof, l_zero_size, + inner_tag="l.0", outer_tag="unr") + else: + knl = lp.split_iname(knl, iel, 32, + outer_tag="g.0", inner_tag="l.0") + + t_unit = t_unit.with_kernel(knl) + + # }}} + + return t_unit + # vim: foldmethod=marker diff --git a/meshmode/pytato_utils.py b/meshmode/pytato_utils.py index 986c65fff..046981678 100644 --- a/meshmode/pytato_utils.py +++ b/meshmode/pytato_utils.py @@ -1,19 +1,17 @@ import pyopencl.array as cl_array -import kanren import pytato as pt -import unification import logging from functools import partial, reduce from arraycontext.impl.pytato.fake_numpy import PytatoFakeNumpyNamespace from arraycontext import rec_map_reduce_array_container from meshmode.transform_metadata import DiscretizationEntityAxisTag -from pytato.loopy import LoopyCall -from pytato.array import EinsumElementwiseAxis, EinsumReductionAxis from pytato.transform import ArrayOrNames +from pytato.transform.metadata import ( + AxesTagsEquationCollector as BaseAxesTagsEquationCollector) from arraycontext import ArrayContainer from arraycontext.container.traversal import rec_map_array_container -from typing import Set, Mapping, Tuple, Union +from typing import Union logger = logging.getLogger(__name__) @@ -78,404 +76,28 @@ def max(self, a, axis=None): # {{{ solve for discretization metadata for arrays' axes -class DiscretizationEntityConstraintCollector(pt.transform.Mapper): - """ - .. warning:: - - Instances of this mapper type store state that are only for visiting a - single DAG. Using a single instance for collecting the constraints on - multiple DAGs is undefined behavior. - """ - def __init__(self): - super().__init__() - self._visited_ids: Set[int] = set() - - # axis_to_var: mapping from (array, iaxis) to the kanren variable to be - # used for unification. 
- self.axis_to_tag_var: Mapping[Tuple[pt.Array, int], - unification.variable.Var] = {} - self.variables_to_solve: Set[unification.variable.Var] = set() - self.constraints = [] - - # type-ignore reason: CachedWalkMapper.rec's type does not match - # WalkMapper.rec's type - def rec(self, expr: ArrayOrNames) -> None: # type: ignore - if id(expr) in self._visited_ids: - return - - # type-ignore reason: super().rec expects either 'Array' or - # 'AbstractResultWithNamedArrays', passed 'ArrayOrNames' - super().rec(expr) # type: ignore - self._visited_ids.add(id(expr)) - - def get_kanren_var_for_axis_tag(self, - expr: pt.Array, - iaxis: int - ) -> unification.variable.Var: - key = (expr, iaxis) - - if key not in self.axis_to_tag_var: - self.axis_to_tag_var[key] = kanren.var() - - return self.axis_to_tag_var[key] - - def _record_all_axes_to_be_solved_if_impl_stored(self, expr): - if expr.tags_of_type(pt.tags.ImplStored): - for iaxis in range(expr.ndim): - self.variables_to_solve.add(self.get_kanren_var_for_axis_tag(expr, - iaxis)) - - def _record_all_axes_to_be_solved(self, expr): - for iaxis in range(expr.ndim): - self.variables_to_solve.add(self.get_kanren_var_for_axis_tag(expr, - iaxis)) - - def record_constraint(self, lhs, rhs): - self.constraints.append((lhs, rhs)) - - def record_eq_constraints_from_tags(self, expr: pt.Array) -> None: - for iaxis, axis in enumerate(expr.axes): - if axis.tags_of_type(DiscretizationEntityAxisTag): - discr_tag, = axis.tags_of_type(DiscretizationEntityAxisTag) - axis_var = self.get_kanren_var_for_axis_tag(expr, iaxis) - self.record_constraint(axis_var, discr_tag) - - def _map_input_base(self, expr: pt.InputArgumentBase - ) -> None: - self.record_eq_constraints_from_tags(expr) - self._record_all_axes_to_be_solved_if_impl_stored(expr) - - for dim in expr.shape: - if isinstance(dim, pt.Array): - self.rec(dim) - - map_placeholder = _map_input_base - map_data_wrapper = _map_input_base - map_size_param = _map_input_base - - def 
map_index_lambda(self, expr: pt.IndexLambda) -> None: - from pytato.utils import are_shape_components_equal - from pytato.raising import index_lambda_to_high_level_op - from pytato.raising import (BinaryOp, FullOp, WhereOp, - BroadcastOp, C99CallOp, ReduceOp) - - # {{{ record constraints for expr and its subexprs. - - self.record_eq_constraints_from_tags(expr) - self._record_all_axes_to_be_solved_if_impl_stored(expr) - - for dim in expr.shape: - if isinstance(dim, pt.Array): - self.rec(dim) - - for bnd in expr.bindings.values(): - self.rec(bnd) - - # }}} - - hlo = index_lambda_to_high_level_op(expr) - - if isinstance(hlo, BinaryOp): - subexprs = (hlo.x1, hlo.x2) - elif isinstance(hlo, WhereOp): - subexprs = (hlo.condition, hlo.then, hlo.else_) - elif isinstance(hlo, FullOp): - # A full-op does not impose any constraints - subexprs = () - elif isinstance(hlo, BroadcastOp): - subexprs = (hlo.x,) - elif isinstance(hlo, C99CallOp): - subexprs = hlo.args - elif isinstance(hlo, ReduceOp): - # {{{ ReduceOp doesn't quite involve broadcasting - - i_out_axis = 0 - for i_in_axis in range(hlo.x.ndim): - if i_in_axis not in hlo.axes: - in_tag_var = self.get_kanren_var_for_axis_tag(hlo.x, - i_in_axis) - out_tag_var = self.get_kanren_var_for_axis_tag(expr, - i_out_axis) - self.record_constraint(in_tag_var, out_tag_var) - i_out_axis += 1 - - assert i_out_axis == expr.ndim - - # }}} - - for axis in hlo.axes: - self.variables_to_solve.add(self.get_kanren_var_for_axis_tag(hlo.x, - axis)) - return - - else: - raise NotImplementedError(type(hlo)) - - for subexpr in subexprs: - if isinstance(subexpr, pt.Array): - for i_in_axis, i_out_axis in zip( - range(subexpr.ndim), - range(expr.ndim-subexpr.ndim, expr.ndim)): - in_dim = subexpr.shape[i_in_axis] - out_dim = expr.shape[i_out_axis] - if are_shape_components_equal(in_dim, out_dim): - in_tag_var = self.get_kanren_var_for_axis_tag(subexpr, - i_in_axis) - out_tag_var = self.get_kanren_var_for_axis_tag(expr, - i_out_axis) - - 
self.record_constraint(in_tag_var, out_tag_var) - else: - # broadcasted axes, cannot belong to the same - # discretization entity. - assert are_shape_components_equal(in_dim, 1) - - def map_stack(self, expr: pt.Stack) -> None: - self.record_eq_constraints_from_tags(expr) - self._record_all_axes_to_be_solved_if_impl_stored(expr) - # TODO; I think the axis corresponding to 'axis' need not be solved. - for ary in expr.arrays: - self.rec(ary) - - for iaxis in range(expr.ndim): - for ary in expr.arrays: - if iaxis < expr.axis: - in_tag_var = self.get_kanren_var_for_axis_tag(ary, - iaxis) - out_tag_var = self.get_kanren_var_for_axis_tag(expr, - iaxis) - - self.record_constraint(in_tag_var, out_tag_var) - elif iaxis == expr.axis: - pass - elif iaxis > expr.axis: - in_tag_var = self.get_kanren_var_for_axis_tag(ary, - iaxis-1) - out_tag_var = self.get_kanren_var_for_axis_tag(expr, - iaxis) - - self.record_constraint(in_tag_var, out_tag_var) - else: - raise AssertionError - - def map_concatenate(self, expr: pt.Concatenate) -> None: - self.record_eq_constraints_from_tags(expr) - self._record_all_axes_to_be_solved_if_impl_stored(expr) - # TODO; I think the axis corresponding to 'axis' need not be solved. - for ary in expr.arrays: - self.rec(ary) - - for ary in expr.arrays: - assert ary.ndim == expr.ndim - for iaxis in range(expr.ndim): - if iaxis != expr.axis: - # non-concatenated axes share the dimensions. 
- in_tag_var = self.get_kanren_var_for_axis_tag(ary, - iaxis) - out_tag_var = self.get_kanren_var_for_axis_tag(expr, - iaxis) - self.record_constraint(in_tag_var, out_tag_var) - - def map_axis_permutation(self, expr: pt.AxisPermutation - ) -> None: - self.record_eq_constraints_from_tags(expr) - self._record_all_axes_to_be_solved_if_impl_stored(expr) - self.rec(expr.array) - - assert expr.ndim == expr.array.ndim - - for out_axis in range(expr.ndim): - in_axis = expr.axis_permutation[out_axis] - out_tag = self.get_kanren_var_for_axis_tag(expr, out_axis) - in_tag = self.get_kanren_var_for_axis_tag(expr, in_axis) - self.record_constraint(out_tag, in_tag) - - def map_basic_index(self, expr: pt.IndexBase) -> None: - from pytato.array import NormalizedSlice - from pytato.utils import are_shape_components_equal - - self.record_eq_constraints_from_tags(expr) - self._record_all_axes_to_be_solved_if_impl_stored(expr) - self.rec(expr.array) - - i_out_axis = 0 - - assert len(expr.indices) == expr.array.ndim - - for i_in_axis, idx in enumerate(expr.indices): - if isinstance(idx, int): - pass - else: - assert isinstance(idx, NormalizedSlice) - if (idx.step == 1 - and are_shape_components_equal(idx.start, 0) - and are_shape_components_equal(idx.stop, - expr.array.shape[i_in_axis])): - - i_in_axis_tag = self.get_kanren_var_for_axis_tag(expr.array, - i_in_axis) - i_out_axis_tag = self.get_kanren_var_for_axis_tag(expr, - i_out_axis) - self.record_constraint(i_in_axis_tag, i_out_axis_tag) - - i_out_axis += 1 - - def map_contiguous_advanced_index(self, - expr: pt.AdvancedIndexInContiguousAxes - ) -> None: - from pytato.array import NormalizedSlice - from pytato.utils import (partition, get_shape_after_broadcasting, - are_shapes_equal, are_shape_components_equal) - - self.record_eq_constraints_from_tags(expr) - self._record_all_axes_to_be_solved_if_impl_stored(expr) - self.rec(expr.array) - for idx in expr.indices: - if isinstance(idx, pt.Array): - self.rec(idx) - - i_adv_indices, 
i_basic_indices = partition( - lambda idx: isinstance(expr.indices[idx], NormalizedSlice), - range(len(expr.indices))) - npre_advanced_basic_indices = len([i_idx - for i_idx in i_basic_indices - if i_idx < i_adv_indices[0]]) - npost_advanced_basic_indices = len([i_idx - for i_idx in i_basic_indices - if i_idx > i_adv_indices[-1]]) - - indirection_arrays = [expr.indices[i_idx] for i_idx in i_adv_indices] - assert are_shapes_equal( - get_shape_after_broadcasting(indirection_arrays), - expr.shape[ - npre_advanced_basic_indices:expr.ndim-npost_advanced_basic_indices]) - - for subexpr in indirection_arrays: - if isinstance(subexpr, pt.Array): - for i_in_axis, i_out_axis in zip( - range(subexpr.ndim), - range(expr.ndim-subexpr.ndim+npre_advanced_basic_indices, - expr.ndim-npost_advanced_basic_indices)): - in_dim = subexpr.shape[i_in_axis] - out_dim = expr.shape[i_out_axis] - if are_shape_components_equal(in_dim, out_dim): - in_tag_var = self.get_kanren_var_for_axis_tag(subexpr, - i_in_axis) - out_tag_var = self.get_kanren_var_for_axis_tag(expr, - i_out_axis) - - self.record_constraint(in_tag_var, out_tag_var) - else: - # broadcasted axes, cannot belong to the same - # discretization entity. - assert are_shape_components_equal(in_dim, 1) - - def map_non_contiguous_advanced_index(self, - expr: pt.AdvancedIndexInNoncontiguousAxes - ) -> None: - self.record_eq_constraints_from_tags(expr) - self._record_all_axes_to_be_solved_if_impl_stored(expr) - self.rec(expr.array) - for idx in expr.indices: - if isinstance(idx, pt.Array): - self.rec(idx) - +class AxesTagsEquationCollector(BaseAxesTagsEquationCollector): def map_reshape(self, expr: pt.Reshape) -> None: - self.record_eq_constraints_from_tags(expr) - self._record_all_axes_to_be_solved_if_impl_stored(expr) - self.rec(expr.array) - # we can add constraints to reshape that only include new axes in its - # reshape. - # Other reshapes do not 'conserve' the types in our type-system. - # Well *what if*. 
Let's just say this type inference fails for - # non-trivial 'reshapes'. So, what are the 'trivial' reshapes? - # trivial reshapes: - # (x1, x2, ... xn) -> ((1,)*, x1, (1,)*, x2, (1,)*, x3, (1,)*, ..., xn, 1*) - # given all(x1!=1, x2!=1, x3!=1, .. xn!= 1) - if ((1 not in (expr.array.shape)) # leads to ambiguous newaxis + super().map_reshape(expr) + + if (expr.size > 0 + and (1 not in (expr.array.shape)) # leads to ambiguous newaxis and (set(expr.shape) <= (set(expr.array.shape) | {1}))): i_in_axis = 0 for i_out_axis, dim in enumerate(expr.shape): if dim != 1: assert dim == expr.array.shape[i_in_axis] - i_in_axis_tag = self.get_kanren_var_for_axis_tag(expr.array, - i_in_axis) - i_out_axis_tag = self.get_kanren_var_for_axis_tag(expr, - i_out_axis) - self.record_constraint(i_in_axis_tag, i_out_axis_tag) + self.record_equation( + self.get_var_for_axis(expr.array, + i_in_axis), + self.get_var_for_axis(expr, + i_out_axis) + ) i_in_axis += 1 else: # print(f"Skipping: {expr.array.shape} -> {expr.shape}") # Wacky reshape => bail. - return - - def map_einsum(self, expr: pt.Einsum) -> None: - - self.record_eq_constraints_from_tags(expr) - self._record_all_axes_to_be_solved_if_impl_stored(expr) - - for arg in expr.args: - self.rec(arg) - - descr_to_tag = {} - for iaxis in range(expr.ndim): - descr_to_tag[EinsumElementwiseAxis(iaxis)] = ( - self.get_kanren_var_for_axis_tag(expr, iaxis)) - - for access_descrs, arg in zip(expr.access_descriptors, - expr.args): - # if an einsum is stored => every argument's axes must - # also be inferred, even those that are getting reduced. 
- for iarg_axis, descr in enumerate(access_descrs): - in_tag_var = self.get_kanren_var_for_axis_tag(arg, - iarg_axis) - - if descr in descr_to_tag: - self.record_constraint(descr_to_tag[descr], in_tag_var) - else: - descr_to_tag[descr] = in_tag_var - - if isinstance(descr, EinsumReductionAxis): - self.variables_to_solve.add(in_tag_var) - - def map_dict_of_named_arrays(self, expr: pt.DictOfNamedArrays - ) -> None: - for _, subexpr in sorted(expr._data.items()): - self.rec(subexpr) - self._record_all_axes_to_be_solved(subexpr) - - def map_loopy_call(self, expr: LoopyCall) -> None: - for _, subexpr in sorted(expr.bindings.items()): - if isinstance(subexpr, pt.Array): - if not isinstance(subexpr, pt.InputArgumentBase): - self._record_all_axes_to_be_solved(subexpr) - self.rec(subexpr) - - # there's really no good way to propagate the metadata in this case. - # One *could* raise the loopy kernel instruction expressions to - # high level ops, but that's really involved and probably not worth it. 
- - def map_named_array(self, expr: pt.NamedArray) -> None: - self.record_eq_constraints_from_tags(expr) - self.rec(expr._container) - - def map_distributed_send_ref_holder(self, - expr: pt.DistributedSendRefHolder - ) -> None: - self.record_eq_constraints_from_tags(expr) - self.rec(expr.passthrough_data) - for idim in range(expr.ndim): - assert (expr.passthrough_data.shape[idim] - == expr.shape[idim]) - self.record_constraint( - self.get_kanren_var_for_axis_tag(expr.passthrough_data, - idim), - self.get_kanren_var_for_axis_tag(expr, idim) - ) - - def map_distributed_recv(self, - expr: pt.DistributedRecv) -> None: - self.record_eq_constraints_from_tags(expr) + pass def unify_discretization_entity_tags(expr: Union[ArrayContainer, ArrayOrNames] @@ -484,152 +106,11 @@ def unify_discretization_entity_tags(expr: Union[ArrayContainer, ArrayOrNames] return rec_map_array_container(unify_discretization_entity_tags, expr) - from collections import defaultdict - discr_unification_helper = DiscretizationEntityConstraintCollector() - discr_unification_helper(expr) - tag_var_to_axis = {} - variables_to_solve = [] - - for (axis, var) in discr_unification_helper.axis_to_tag_var.items(): - tag_var_to_axis[var] = axis - if var in discr_unification_helper.variables_to_solve: - variables_to_solve.append(var) - - lhs = [cnstrnt[0] for cnstrnt in discr_unification_helper.constraints] - rhs = [cnstrnt[1] for cnstrnt in discr_unification_helper.constraints] - assert len(lhs) == len(rhs) - solutions = {} - - for i_retry in range(MAX_UNIFY_RETRIES): - old_solutions = solutions.copy() - solutions = unification.unify(lhs, rhs, - {l_expr: r_expr - for l_expr, r_expr in solutions.items() - if isinstance(r_expr, - DiscretizationEntityAxisTag)}) - if solutions == old_solutions: - logger.info(f"Unification converged after {i_retry} iterations.") - break - else: - logger.warn(f"Could not converge after {MAX_UNIFY_RETRIES} iterations.") - - # Ideally it might be better to enable this, but that would 
be too - # restrictive as not all computation graphs result in DOFArray ouptuts - # if not (frozenset(variables_to_solve) <= frozenset(solutions)): - # raise RuntimeError("Unification failed.") - - # ary_to_axes_tags: mapping from array to a mapping from iaxis to the - # solved tag. - ary_to_axes_tags = defaultdict(dict) - for var in solutions: - ary, axis = tag_var_to_axis[var] - if isinstance(solutions[var], DiscretizationEntityAxisTag): - ary_to_axes_tags[ary][axis] = solutions[var] - if var in variables_to_solve and ( - not isinstance(solutions[var], DiscretizationEntityAxisTag)): - raise RuntimeError(f"Could not solve for {var}.") - - def attach_tags(expr: ArrayOrNames) -> ArrayOrNames: - if not isinstance(expr, pt.Array): - return expr - - for iaxis, solved_tag in ary_to_axes_tags[expr].items(): - if expr.axes[iaxis].tags_of_type(DiscretizationEntityAxisTag): - discr_tag, = (expr - .axes[iaxis] - .tags_of_type(DiscretizationEntityAxisTag)) - assert discr_tag == solved_tag - else: - if not isinstance(solved_tag, DiscretizationEntityAxisTag): - actual_tag = discr_unification_helper.axis_to_tag_var[(expr, - iaxis)] - assert actual_tag in discr_unification_helper.variables_to_solve - assert actual_tag in variables_to_solve - raise ValueError(f"In {expr!r}, axis={iaxis}'s type cannot be " - "inferred.") - expr = expr.with_tagged_axis(iaxis, solved_tag) - - if isinstance(expr, pt.Einsum): - redn_descr_to_entity_type = {} - for access_descrs, arg in zip(expr.access_descriptors, - expr.args): - for iaxis, access_descr in enumerate(access_descrs): - if isinstance(access_descr, EinsumReductionAxis): - redn_descr_to_entity_type[access_descr] = ( - ary_to_axes_tags[arg][iaxis]) - - if (frozenset(redn_descr_to_entity_type) - != frozenset(expr.redn_descr_to_redn_dim)): - raise ValueError - - for redn_descr, solved_tag in redn_descr_to_entity_type.items(): - if not isinstance(solved_tag, DiscretizationEntityAxisTag): - raise ValueError(f"In {expr!r}, 
redn_descr={redn_descr}'s" - " type cannot be inferred.") - expr = expr.with_tagged_redn_dim(redn_descr, solved_tag) - - if isinstance(expr, pt.IndexLambda): - from pytato.raising import (index_lambda_to_high_level_op, - ReduceOp) - - hlo = index_lambda_to_high_level_op(expr) - if isinstance(hlo, ReduceOp): - for iaxis in hlo.axes: - solved_tag = ary_to_axes_tags[hlo.x][iaxis] - if not isinstance(solved_tag, DiscretizationEntityAxisTag): - raise ValueError(f"In {expr!r}, redn_descr={iaxis}'s" - " type cannot be inferred.") - - expr = expr.with_tagged_redn_dim(iaxis, solved_tag) - - return expr - - return pt.transform.map_and_copy(expr, attach_tags) + return pt.unify_axes_tags(expr, + tag_t=DiscretizationEntityAxisTag, + equations_collector_t=AxesTagsEquationCollector) # }}} -class UnInferredStoredArrayCatcher(pt.transform.CachedWalkMapper): - """ - Raises a :class:`ValueError` if a stored array has axes without a - :class:`DiscretizationEntityAxisTag` tagged to it. - """ - def post_visit(self, expr: ArrayOrNames) -> None: - if (isinstance(expr, pt.Array) - and expr.tags_of_type(pt.tags.ImplStored)): - if any(len(axis.tags_of_type(DiscretizationEntityAxisTag)) != 1 - for axis in expr.axes): - raise ValueError(f"{expr!r} doesn't have all its axes inferred.") - - if (isinstance(expr, pt.IndexLambda) - and any(len(redn_dim.tags_of_type(DiscretizationEntityAxisTag)) != 1 - for redn_dim in expr.reduction_dims.values())): - raise ValueError(f"{expr!r} doesn't have all its redn axes inferred.") - - if (isinstance(expr, pt.Einsum) - and any(len(redn_dim.tags_of_type(DiscretizationEntityAxisTag)) != 1 - for redn_dim in expr.redn_descr_to_redn_dim.values())): - raise ValueError(f"{expr!r} doesn't have all its redn axes inferred.") - - if isinstance(expr, pt.DictOfNamedArrays): - if any(any(len(axis.tags_of_type(DiscretizationEntityAxisTag)) != 1 - for axis in subexpr.axes) - for subexpr in expr._data.values()): - raise ValueError(f"{expr!r} doesn't have all its axes 
inferred.") - - from pytato.loopy import LoopyCall - - if isinstance(expr, LoopyCall): - if any(any(len(axis.tags_of_type(DiscretizationEntityAxisTag)) != 1 - for axis in subexpr.axes) - for subexpr in expr.bindings.values() - if (isinstance(subexpr, pt.Array) - and not isinstance(subexpr, pt.InputArgumentBase) - and subexpr.ndim != 0)): - raise ValueError(f"{expr!r} doesn't have all its axes inferred.") - - -def are_all_stored_arrays_inferred(expr: ArrayOrNames): - UnInferredStoredArrayCatcher()(expr) - # vim: fdm=marker From 46c7302f69795d66307696d42ff88d6ffb1736a8 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sat, 12 Mar 2022 13:25:57 -0600 Subject: [PATCH 4/7] adds feinsum, kanren to deps --- requirements.txt | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 35a924099..e99abecc1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ git+https://github.com/inducer/pyvisfile.git#egg=pyvisfile git+https://github.com/inducer/modepy.git#egg=modepy git+https://github.com/inducer/pyopencl.git#egg=pyopencl git+https://github.com/inducer/islpy.git#egg=islpy -git+https://github.com/inducer/pytato.git#egg=pytato +git+https://github.com/kaushikcfd/pytato.git#egg=pytato # required by pytential, which is in turn needed for some tests git+https://github.com/inducer/pymbolic.git#egg=pymbolic @@ -27,3 +27,8 @@ git+https://github.com/inducer/pymetis.git#egg=pymetis # for examples/tp-lagrange-stl.py numpy-stl + + +# for FusionContractorActx transforms +git+https://github.com/kaushikcfd/feinsum.git#egg=feinsum +git+https://github.com/pythological/kanren.git#egg=miniKanren From faaeff15f68dbde690c026bca1c63f849ab67d43 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 15 Nov 2022 12:38:57 -0800 Subject: [PATCH 5/7] add missing get_cache_key to _IndirectionMapRecorder --- meshmode/array_context.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/meshmode/array_context.py 
b/meshmode/array_context.py index cfee0b977..73708a538 100644 --- a/meshmode/array_context.py +++ b/meshmode/array_context.py @@ -1279,6 +1279,10 @@ def mark_materialized_nodes_as_cse( indirection_maps = set() class _IndirectionMapRecorder(pt.transform.CachedWalkMapper): + # type-ignore-reason: dropped the extra `*args, **kwargs`. + def get_cache_key(self, expr) -> int: # type: ignore[override] + return id(expr) + def post_visit(self, expr): if isinstance(expr, pt.IndexBase): for idx in expr.indices: From 5c33900e1b60eb18ba0e5c25ea9e4a139090365f Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 3 Nov 2022 11:16:22 -0700 Subject: [PATCH 6/7] rearrange Einsum args --- meshmode/array_context.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/meshmode/array_context.py b/meshmode/array_context.py index 73708a538..9b06b1974 100644 --- a/meshmode/array_context.py +++ b/meshmode/array_context.py @@ -1360,10 +1360,10 @@ def materialize_inverse_mass_inputs(expr): return pt.Einsum(expr.access_descriptors, (arg1, arg2, arg3), - expr.axes, expr.redn_axis_to_redn_descr, expr.index_to_access_descr, - expr.tags) + axes=expr.axes, + tags=expr.tags) else: return expr From 431dbe68de518efb5f17759bd4027ca8c81a6c35 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 19 Jan 2023 08:50:22 -0800 Subject: [PATCH 7/7] make compatible with new distributed DAG partitioner --- meshmode/array_context.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/meshmode/array_context.py b/meshmode/array_context.py index 9b06b1974..def07608b 100644 --- a/meshmode/array_context.py +++ b/meshmode/array_context.py @@ -1493,16 +1493,21 @@ def _get_rid_of_broadcasts_from_einsum(expr): # {{{ remove any PartID tags - from pytato.distributed import PartIDTag + # FIXME: Remove after https://github.com/inducer/pytato/pull/393 goes in + try: + from pytato.distributed import PartIDTag - def remove_part_id_tags(expr): - if isinstance(expr, 
pt.Array) and expr.tags_of_type(PartIDTag): - tag, = expr.tags_of_type(PartIDTag) - return expr.without_tags(tag) - else: - return expr + def remove_part_id_tags(expr): + if isinstance(expr, pt.Array) and expr.tags_of_type(PartIDTag): + tag, = expr.tags_of_type(PartIDTag) + return expr.without_tags(tag) + else: + return expr + except ImportError: + remove_part_id_tags = None - dag = pt.transform.map_and_copy(dag, remove_part_id_tags) + if remove_part_id_tags is not None: + dag = pt.transform.map_and_copy(dag, remove_part_id_tags) # }}}