diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py index 8224cae965..e604629879 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py @@ -21,7 +21,6 @@ subsets as dace_subsets, transformation as dace_transformation, ) -from dace.cli import progress as dace_cliprogress from dace.sdfg import nodes as dace_nodes, utils as dace_sdutils from dace.transformation import ( dataflow as dace_dataflow, @@ -104,7 +103,6 @@ def gt_simplify( if "InlineSDFGs" not in skip: inline_res = gt_inline_nested_sdfg( sdfg=sdfg, - multistate=True, permissive=False, validate=False, validate_all=validate_all, @@ -257,11 +255,9 @@ def gt_simplify( def gt_inline_nested_sdfg( sdfg: dace.SDFG, - multistate: bool = True, permissive: bool = False, validate: bool = True, validate_all: bool = False, - progress: Optional[bool] = None, ) -> Optional[dict[str, int]]: """Perform inlining of nested SDFG into their parent SDFG. @@ -272,24 +268,64 @@ def gt_inline_nested_sdfg( Args: sdfg: The SDFG that should be processed, will be modified in place and returned. - multistate: Allow inlining of multistate nested SDFG, defaults to `True`. permissive: Be less strict on the accepted SDFGs. validate: Perform validation after the transformation has finished. validate_all: Performs extensive validation. + + Note: + - This function guarantees a stable processing order, if the names of the nested + SDFGs and the names of the states they are located in are stable. + - The `no_inline` flag of the `NestedSDFG` node only affects the inlining + of that specific node. The cleaning transformations and the recursive + processing, i.e. inlining of NestedSDFGs inside the nested SDFG, are still + performed. """ + # NOTE: DaCe has three(!) inliners.
First `InlineMultistateSDFG`, which we employ, + # secondly `InlineSDFG`, which is only capable of inlining an SDFG with a single + # state, and thirdly `InlineSDFGs`, which combines the two. However, `InlineSDFG` has a + # bug and the processing order of `InlineSDFGs` is not stable. Thus GT4Py + # implements its own version. + + # Finding all nested SDFGs on this level. + nested_sdfgs_to_process: list[dace_nodes.NestedSDFG] = [] + for state in sdfg.states(): + nested_sdfgs_to_process.extend( + node for node in state.nodes() if isinstance(node, dace_nodes.NestedSDFG) + ) + + # If there are no SDFGs to process then we exit. + if len(nested_sdfgs_to_process) == 0: + return None + + # Now order them, such that we can process them in a stable way. + nested_sdfgs_to_process.sort(key=lambda nsdfg: (str(nsdfg.label), str(nsdfg.sdfg.parent.label))) + nb_preproccess_total = 0 nb_inlines_total = 0 - nsdfgs = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, dace_nodes.NestedSDFG)] - for nsdfg_node in dace_cliprogress.optional_progressbar( - reversed(nsdfgs), title="Inlining SDFGs", n=len(nsdfgs), progress=progress - ): - nsdfg: dace.SDFG = nsdfg_node.sdfg - parent_state = nsdfg.parent + + # Now we start inlining all the nested SDFGs. + # Before a nested SDFG is inlined the function first tries to inline all + # SDFGs that are nested inside it, i.e. they are processed in a stable + # DFS order. + for nsdfg_node in nested_sdfgs_to_process: + nested_sdfg: dace.SDFG = nsdfg_node.sdfg + parent_state = nested_sdfg.parent parent_sdfg = parent_state.sdfg parent_state_id = parent_state.block_id - # Clean the symbols and connectors of the nested SDFG. + # Recursive processing of nested SDFGs.
+ recursive_result = gt_inline_nested_sdfg( + sdfg=nsdfg_node.sdfg, + permissive=permissive, + validate=False, + validate_all=validate_all, + ) + if recursive_result is not None: + nb_preproccess_total += recursive_result.get("PruneSymbols|PruneConnectors", 0) + nb_inlines_total += recursive_result.get("InlineSDFGs", 0) + + # Now perform some cleaning on the nested SDFG. for xform in [dace_dataflow.PruneSymbols, dace_dataflow.PruneConnectors]: candidate = {xform.nsdfg: nsdfg_node} cleaner = xform() @@ -305,6 +341,17 @@ def gt_inline_nested_sdfg( cleaner.apply(parent_state, parent_sdfg) nb_preproccess_total += 1 + # Inlining an SDFG is only possible if the nested SDFG node is at global scope. + if parent_state.scope_dict()[nsdfg_node] is not None: + continue + + # Check the `no_inline` flag. Note that it has to be checked here and not + # before to ensure that the node is recursively processed and the pruning + # transformations are applied. + if nsdfg_node.no_inline: + continue + + # Now perform the actual inlining. # NOTE: In [PR#2178](https://github.com/GridTools/gt4py/pull/2178) this function was # modified to be more efficient. It also changed the order in which the inlining # transformations of DaCe were applied. Instead of trying `InlineMultistateSDFG` @@ -312,7 +359,6 @@ def gt_inline_nested_sdfg( # [issue#2108](https://github.com/spcl/dace/issues/2108) which lead to the removals # of some writes. As a temporary solution we no longer use `InlineSDFG` but only # the multistate version. - # TODO(phimuell): As soon as the DaCe issue is resolved start using `InlineSDFG` again. multi_state_candidate = {dace_interstate.InlineMultistateSDFG.nested_sdfg: nsdfg_node} multi_state_inliner = dace_interstate.InlineMultistateSDFG() multi_state_inliner.setup_match(