From 1d920f5ac53eaca9cf2e79948a37600d20c46d97 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 1 Dec 2025 11:25:15 +0100 Subject: [PATCH 01/28] edit --- .../runners/dace/__init__.py | 2 - .../runners/dace/lowering/__init__.py | 21 ++++++++ .../dace/{ => lowering}/gtir_dataflow.py | 26 ++++----- .../dace/{ => lowering}/gtir_domain.py | 2 +- .../{ => lowering}/gtir_python_codegen.py | 0 .../dace/{ => lowering}/gtir_to_sdfg.py | 36 ++++++------- .../gtir_to_sdfg_concat_where.py | 6 +-- .../{ => lowering}/gtir_to_sdfg_primitives.py | 22 ++++---- .../dace/{ => lowering}/gtir_to_sdfg_scan.py | 2 +- .../dace/{ => lowering}/gtir_to_sdfg_types.py | 13 ++--- .../dace/{ => lowering}/gtir_to_sdfg_utils.py | 28 +++++++++- .../runners/dace/program.py | 28 +++++----- .../runners/dace/{utils.py => sdfg_args.py} | 53 ++++--------------- .../runners/dace/sdfg_callable.py | 9 ++-- .../dace/transformations/loop_blocking.py | 4 +- .../transformations/map_fusion_extended.py | 4 +- .../dace/transformations/map_orderer.py | 4 +- .../runners/dace/workflow/bindings.py | 12 ++--- .../runners/dace/workflow/translation.py | 25 ++++----- .../dace_tests/test_dace_domain.py | 12 ++--- .../dace_tests/test_dace_utils.py | 5 +- .../dace_tests/test_gtir_to_sdfg.py | 8 +-- .../test_auto_optimizer_hooks.py | 1 - .../transformation_tests/test_map_promoter.py | 6 +-- 24 files changed, 166 insertions(+), 163 deletions(-) create mode 100644 src/gt4py/next/program_processors/runners/dace/lowering/__init__.py rename src/gt4py/next/program_processors/runners/dace/{ => lowering}/gtir_dataflow.py (98%) rename src/gt4py/next/program_processors/runners/dace/{ => lowering}/gtir_domain.py (98%) rename src/gt4py/next/program_processors/runners/dace/{ => lowering}/gtir_python_codegen.py (100%) rename src/gt4py/next/program_processors/runners/dace/{ => lowering}/gtir_to_sdfg.py (97%) rename src/gt4py/next/program_processors/runners/dace/{ => lowering}/gtir_to_sdfg_concat_where.py (98%) rename src/gt4py/next/program_processors/runners/dace/{ => lowering}/gtir_to_sdfg_primitives.py (97%) rename src/gt4py/next/program_processors/runners/dace/{ => lowering}/gtir_to_sdfg_scan.py (99%) rename src/gt4py/next/program_processors/runners/dace/{ => lowering}/gtir_to_sdfg_types.py (94%) rename src/gt4py/next/program_processors/runners/dace/{ => lowering}/gtir_to_sdfg_utils.py (84%) rename src/gt4py/next/program_processors/runners/dace/{utils.py => sdfg_args.py} (70%) diff --git a/src/gt4py/next/program_processors/runners/dace/__init__.py b/src/gt4py/next/program_processors/runners/dace/__init__.py index f3f672b651..0bb2c40dc3 100644 --- a/src/gt4py/next/program_processors/runners/dace/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace/__init__.py @@ -7,7 +7,6 @@ # SPDX-License-Identifier: BSD-3-Clause -from gt4py.next.program_processors.runners.dace.gtir_to_sdfg import build_sdfg_from_gtir from gt4py.next.program_processors.runners.dace.sdfg_callable import get_sdfg_args from gt4py.next.program_processors.runners.dace.workflow.backend import ( make_dace_backend, @@ -21,7 +20,6 @@ __all__ = [ - "build_sdfg_from_gtir", "get_sdfg_args", "make_dace_backend", "run_dace_cpu", diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/__init__.py b/src/gt4py/next/program_processors/runners/dace/lowering/__init__.py new file mode 100644 index 0000000000..1962d2a7e6 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/lowering/__init__.py @@ -0,0 +1,21 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, 
ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + + +from gt4py.next.program_processors.runners.dace.lowering.gtir_to_sdfg import build_sdfg_from_gtir +from gt4py.next.program_processors.runners.dace.lowering.gtir_to_sdfg_utils import ( + flatten_tuple_fields, + get_map_variable, +) + + +__all__ = [ + "build_sdfg_from_gtir", + "flatten_tuple_fields", + "get_map_variable", +] diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py similarity index 98% rename from src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py rename to src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index a075d36502..56b25103ee 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -34,12 +34,12 @@ from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import symbol_ref_utils -from gt4py.next.program_processors.runners.dace import ( +from gt4py.next.program_processors.runners.dace import sdfg_args as gtx_dace_args +from gt4py.next.program_processors.runners.dace.lowering import ( gtir_python_codegen, gtir_to_sdfg, gtir_to_sdfg_types, gtir_to_sdfg_utils, - utils as gtx_dace_utils, ) from gt4py.next.type_system import ( type_info as ti, @@ -346,7 +346,7 @@ def connect( def get_reduce_params(node: gtir.FunCall) -> tuple[str, SymbolExpr, SymbolExpr]: assert isinstance(node.type, ts.ScalarType) - dc_dtype = gtx_dace_utils.as_dace_type(node.type) + dc_dtype = gtx_dace_args.as_dace_type(node.type) assert isinstance(node.fun, gtir.FunCall) assert len(node.fun.args) == 2 @@ -530,7 +530,7 @@ def _construct_tasklet_result( src_connector: str, use_array: bool = False, ) -> ValueExpr: - data_type = gtx_dace_utils.as_itir_type(dc_dtype) + data_type = gtx_dace_args.as_itir_type(dc_dtype) if use_array: # In some cases, such as result data with list-type annotation, we want # that output data is represented as an array (single-element 1D array) @@ -1008,7 +1008,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: # initially, the storage for the connectivty tables is created as transient; # when the tables are used, the storage is changed to non-transient, # as the corresponding arrays are supposed to be allocated by the SDFG caller - conn_data = gtx_dace_utils.connectivity_identifier(offset) + conn_data = gtx_dace_args.connectivity_identifier(offset) conn_desc = self.sdfg.arrays[conn_data] conn_desc.transient = False @@ -1124,7 +1124,7 @@ def _visit_list_get(self, node: gtir.FunCall) -> ValueExpr: raise ValueError(f"Unexpected argument type {type(src_arg)} in 'list_get' expression.") assert isinstance(src_arg.gt_dtype.element_type, ts.ScalarType) - assert src_desc.dtype == gtx_dace_utils.as_dace_type(src_arg.gt_dtype.element_type) + assert src_desc.dtype == gtx_dace_args.as_dace_type(src_arg.gt_dtype.element_type) dst, _ = self.subgraph_builder.add_temp_scalar(self.sdfg, src_desc.dtype) dst_node = self.state.add_access(dst) @@ -1200,7 +1200,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: assert len(node.fun.args) == 1 # the operation to be mapped on the arguments assert isinstance(node.type.element_type, ts.ScalarType) - dc_dtype = 
gtx_dace_utils.as_dace_type(node.type.element_type) + dc_dtype = gtx_dace_args.as_dace_type(node.type.element_type) input_connectors = [f"__arg{i}" for i in range(len(node.args))] output_connector = "__out" @@ -1273,7 +1273,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: if conn_type.has_skip_values: # In case the `map_` input expressions contain skip values, we use # the connectivity-based offset provider as mask for map computation. - conn_data = gtx_dace_utils.connectivity_identifier(offset_type.value) + conn_data = gtx_dace_args.connectivity_identifier(offset_type.value) conn_desc = self.sdfg.arrays[conn_data] conn_desc.transient = False @@ -1358,7 +1358,7 @@ def _make_reduce_with_skip_values( and input_expr.gt_dtype.offset_type is not None ) offset_type = input_expr.gt_dtype.offset_type - connectivity = gtx_dace_utils.connectivity_identifier(offset_type.value) + connectivity = gtx_dace_args.connectivity_identifier(offset_type.value) connectivity_node = self.state.add_access(connectivity) connectivity_desc = connectivity_node.desc(self.sdfg) connectivity_desc.transient = False @@ -1715,7 +1715,7 @@ def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: # initially, the storage for the connectivity tables is created as transient; # when the tables are used, the storage is changed to non-transient, # so the corresponding arrays are supposed to be allocated by the SDFG caller - offset_table = gtx_dace_utils.connectivity_identifier(offset) + offset_table = gtx_dace_args.connectivity_identifier(offset) self.sdfg.arrays[offset_table].transient = False offset_table_node = self.state.add_access(offset_table) @@ -1783,13 +1783,13 @@ def _visit_generic_builtin(self, node: gtir.FunCall) -> ValueExpr: # Therefore we handle `ListType` as a single-element array with shape (1,) # that will be accessed in a map expression on a local domain. assert isinstance(node.type.element_type, ts.ScalarType) - dc_dtype = gtx_dace_utils.as_dace_type(node.type.element_type) + dc_dtype = gtx_dace_args.as_dace_type(node.type.element_type) # In order to ease the lowring of the parent expression on local dimension, # we represent the scalar value as a single-element 1D array. 
use_array = True else: assert isinstance(node.type, ts.ScalarType) - dc_dtype = gtx_dace_utils.as_dace_type(node.type) + dc_dtype = gtx_dace_args.as_dace_type(node.type) use_array = False return self._construct_tasklet_result( @@ -1894,7 +1894,7 @@ def _visit_Lambda_impl( ) def visit_Literal(self, node: gtir.Literal) -> SymbolExpr: - dc_dtype = gtx_dace_utils.as_dace_type(node.type) + dc_dtype = gtx_dace_args.as_dace_type(node.type) return SymbolExpr(node.value, dc_dtype) def visit_SymRef(self, node: gtir.SymRef) -> MaybeNestedInTuple[IteratorExpr | DataExpr]: diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_domain.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_domain.py similarity index 98% rename from src/gt4py/next/program_processors/runners/dace/gtir_domain.py rename to src/gt4py/next/program_processors/runners/dace/lowering/gtir_domain.py index 303bb5d936..bfdf873d71 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_domain.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_domain.py @@ -19,7 +19,7 @@ from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils -from gt4py.next.program_processors.runners.dace import gtir_to_sdfg_utils +from gt4py.next.program_processors.runners.dace.lowering import gtir_to_sdfg_utils @dataclasses.dataclass(frozen=True) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_python_codegen.py similarity index 100% rename from src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py rename to src/gt4py/next/program_processors/runners/dace/lowering/gtir_python_codegen.py diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py similarity index 97% rename from src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py rename to src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index d7f2bafba9..f3aa465b0f 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -29,13 +29,13 @@ from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils from gt4py.next.iterator.transforms import prune_casts as ir_prune_casts, symbol_ref_utils from gt4py.next.iterator.type_system import inference as gtir_type_inference -from gt4py.next.program_processors.runners.dace import ( +from gt4py.next.program_processors.runners.dace import sdfg_args as gtx_dace_args +from gt4py.next.program_processors.runners.dace.lowering import ( gtir_domain, gtir_to_sdfg_concat_where, gtir_to_sdfg_primitives, gtir_to_sdfg_types, gtir_to_sdfg_utils, - utils as gtx_dace_utils, ) from gt4py.next.type_system import type_specifications as ts, type_translation as tt @@ -270,7 +270,7 @@ def map_nsdfg_field( ) # Same applies to the symbols used as field origin (the domain range start) outer_origin = [ - gtx_dace_utils.safe_replace_symbolic(val, symbol_mapping) + gtir_to_sdfg_utils.safe_replace_symbolic(val, symbol_mapping) for val in nsdfg_field.origin ] @@ -521,7 +521,7 @@ def make_field( "Fields with more than one local dimension are not supported." 
) field_origin = tuple( - gtx_dace_utils.range_start_symbol(data_node.data, dim) for dim in field_type.dims + gtx_dace_args.range_start_symbol(data_node.data, dim) for dim in field_type.dims ) return gtir_to_sdfg_types.FieldopData(data_node, field_type, field_origin) @@ -671,8 +671,8 @@ def add_nested_sdfg( nsdfg_symbols_mapping |= arg.get_symbol_mapping(gt_symbol, outer_ctx.sdfg) connectivity_arrays = { - gtx_dace_utils.connectivity_identifier(offset) - for offset in gtx_dace_utils.filter_connectivity_types(self.offset_provider_type) + gtx_dace_args.connectivity_identifier(offset) + for offset in gtx_dace_args.filter_connectivity_types(self.offset_provider_type) } inner_ctx_globals = [ @@ -751,28 +751,28 @@ def _make_array_shape_and_strides( Returns: Two lists of symbols, one for the shape and the other for the strides of the array. """ - neighbor_table_types = gtx_dace_utils.filter_connectivity_types(self.offset_provider_type) + neighbor_table_types = gtx_dace_args.filter_connectivity_types(self.offset_provider_type) shape = [] for dim in dims: if dim.kind == gtx_common.DimensionKind.LOCAL: # for local dimension, the size is taken from the associated connectivity type shape.append(neighbor_table_types[dim.value].max_neighbors) - elif gtx_dace_utils.is_connectivity_identifier(name, self.offset_provider_type): + elif gtx_dace_args.is_connectivity_identifier(name, self.offset_provider_type): # we use symbolic size for the global dimension of a connectivity - shape.append(gtx_dace_utils.field_size_symbol(name, dim, neighbor_table_types)) + shape.append(gtx_dace_args.field_size_symbol(name, dim, neighbor_table_types)) else: # the size of global dimensions for a regular field is the symbolic # expression of domain range 'stop - start' shape.append( dace.symbolic.pystr_to_symbolic( "{} - {}".format( - gtx_dace_utils.range_stop_symbol(name, dim), - gtx_dace_utils.range_start_symbol(name, dim), + gtx_dace_args.range_stop_symbol(name, dim), + gtx_dace_args.range_start_symbol(name, dim), ) ) ) strides = [ - gtx_dace_utils.field_stride_symbol(name, dim, neighbor_table_types) for dim in dims + gtx_dace_args.field_stride_symbol(name, dim, neighbor_table_types) for dim in dims ] return shape, strides @@ -838,13 +838,13 @@ def _add_storage( transient=transient, ) if isinstance(gt_type.dtype, ts.ScalarType): - dc_dtype = gtx_dace_utils.as_dace_type(gt_type.dtype) + dc_dtype = gtx_dace_args.as_dace_type(gt_type.dtype) all_dims = gt_type.dims else: # for 'ts.ListType' use 'offset_type' as local dimension assert gt_type.dtype.offset_type is not None assert gt_type.dtype.offset_type.kind == gtx_common.DimensionKind.LOCAL assert isinstance(gt_type.dtype.element_type, ts.ScalarType) - dc_dtype = gtx_dace_utils.as_dace_type(gt_type.dtype.element_type) + dc_dtype = gtx_dace_args.as_dace_type(gt_type.dtype.element_type) all_dims = gtx_common.order_dimensions([*gt_type.dims, gt_type.dtype.offset_type]) # Use symbolic shape, which allows to invoke the program with fields of different size; @@ -854,7 +854,7 @@ def _add_storage( return [(name, gt_type)] elif isinstance(gt_type, ts.ScalarType): - dc_dtype = gtx_dace_utils.as_dace_type(gt_type) + dc_dtype = gtx_dace_args.as_dace_type(gt_type) if symbolic_params is None or name in symbolic_params: sdfg.add_symbol(name, dc_dtype) else: @@ -951,7 +951,7 @@ def _add_sdfg_params( ) # add SDFG storage for connectivity tables - for offset, connectivity_type in gtx_dace_utils.filter_connectivity_types( + for offset, connectivity_type in gtx_dace_args.filter_connectivity_types( 
self.offset_provider_type ).items(): gt_type = ts.FieldType( @@ -966,7 +966,7 @@ def _add_sdfg_params( self._add_storage( sdfg=sdfg, symbolic_params=symbolic_params, - name=gtx_dace_utils.connectivity_identifier(offset), + name=gtx_dace_args.connectivity_identifier(offset), gt_type=gt_type, transient=True, ) @@ -1020,7 +1020,7 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: unused_connectivities = [ data for data, datadesc in nsdfg.arrays.items() - if gtx_dace_utils.is_connectivity_identifier(data, self.offset_provider_type) + if gtx_dace_args.is_connectivity_identifier(data, self.offset_provider_type) and datadesc.transient ] for data in unused_connectivities: diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_concat_where.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_concat_where.py similarity index 98% rename from src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_concat_where.py rename to src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_concat_where.py index 602b8ba15c..bbc0b3c1c4 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_concat_where.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_concat_where.py @@ -26,12 +26,12 @@ domain_utils, ir_makers as im, ) -from gt4py.next.program_processors.runners.dace import ( +from gt4py.next.program_processors.runners.dace import sdfg_args as gtx_dace_args +from gt4py.next.program_processors.runners.dace.lowering import ( gtir_domain, gtir_to_sdfg, gtir_to_sdfg_types, gtir_to_sdfg_utils, - utils as gtx_dace_utils, ) from gt4py.next.type_system import type_specifications as ts @@ -243,7 +243,7 @@ def translate_concat_where( assert output_dims == node.type.dims if isinstance(node.type.dtype, ts.ScalarType): - dtype = gtx_dace_utils.as_dace_type(node.type.dtype) + dtype = gtx_dace_args.as_dace_type(node.type.dtype) else: # TODO(edopao): Refactor allocation of fields with local dimension and enable this. 
raise NotImplementedError("'concat_where' with list output is not supported") diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py similarity index 97% rename from src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py rename to src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py index 1b65acd8d0..be0ff6130d 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py @@ -22,24 +22,24 @@ ir_makers as im, ) from gt4py.next.iterator.transforms import infer_domain -from gt4py.next.program_processors.runners.dace import ( +from gt4py.next.program_processors.runners.dace import sdfg_args as gtx_dace_args +from gt4py.next.program_processors.runners.dace.lowering import ( gtir_dataflow, gtir_domain, gtir_python_codegen, gtir_to_sdfg, gtir_to_sdfg_types, gtir_to_sdfg_utils, - utils as gtx_dace_utils, ) -from gt4py.next.program_processors.runners.dace.gtir_to_sdfg_concat_where import ( +from gt4py.next.program_processors.runners.dace.lowering.gtir_to_sdfg_concat_where import ( translate_concat_where, ) -from gt4py.next.program_processors.runners.dace.gtir_to_sdfg_scan import translate_scan +from gt4py.next.program_processors.runners.dace.lowering.gtir_to_sdfg_scan import translate_scan from gt4py.next.type_system import type_info as ti, type_specifications as ts if TYPE_CHECKING: - from gt4py.next.program_processors.runners.dace import gtir_to_sdfg + from gt4py.next.program_processors.runners.dace.lowering import gtir_to_sdfg class PrimitiveTranslator(Protocol): @@ -311,7 +311,7 @@ def _construct_if_branch_output( out_type = true_br.gt_type if isinstance(out_type, ts.ScalarType): - dtype = gtx_dace_utils.as_dace_type(out_type) + dtype = gtx_dace_args.as_dace_type(out_type) out, _ = sdfg_builder.add_temp_scalar(ctx.sdfg, dtype) out_node = ctx.state.add_access(out) return gtir_to_sdfg_types.FieldopData(out_node, out_type, origin=()) @@ -322,12 +322,12 @@ def _construct_if_branch_output( assert dims == out_type.dims if isinstance(out_type.dtype, ts.ScalarType): - dtype = gtx_dace_utils.as_dace_type(out_type.dtype) + dtype = gtx_dace_args.as_dace_type(out_type.dtype) else: assert isinstance(out_type.dtype, ts.ListType) assert out_type.dtype.offset_type is not None assert isinstance(out_type.dtype.element_type, ts.ScalarType) - dtype = gtx_dace_utils.as_dace_type(out_type.dtype.element_type) + dtype = gtx_dace_args.as_dace_type(out_type.dtype.element_type) offset_provider_type = sdfg_builder.get_offset_provider_type( out_type.dtype.offset_type.value ) @@ -479,7 +479,7 @@ def translate_index( index_node = ctx.state.add_access(index_data) index_value = gtir_dataflow.ValueExpr( dc_node=index_node, - gt_dtype=gtx_dace_utils.as_itir_type(gtir_to_sdfg_types.INDEX_DTYPE), + gt_dtype=gtx_dace_args.as_itir_type(gtir_to_sdfg_types.INDEX_DTYPE), ) index_write_tasklet, connector_mapping = sdfg_builder.add_tasklet( name="index", @@ -557,7 +557,7 @@ def _get_symbolic_value( ) temp_name, _ = sdfg.add_scalar( temp_name or sdfg.temp_data_name(), - gtx_dace_utils.as_dace_type(scalar_type), + gtx_dace_args.as_dace_type(scalar_type), find_new_name=True, transient=True, ) @@ -703,7 +703,7 @@ def translate_scalar_expr( dace.Memlet(data=arg_node.data, subset="0"), ) # finally, create temporary for the result value - temp_name, _ = 
sdfg_builder.add_temp_scalar(ctx.sdfg, gtx_dace_utils.as_dace_type(node.type)) + temp_name, _ = sdfg_builder.add_temp_scalar(ctx.sdfg, gtx_dace_args.as_dace_type(node.type)) temp_node = ctx.state.add_access(temp_name) ctx.state.add_edge( tasklet_node, diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py similarity index 99% rename from src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py rename to src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py index f3862c60b2..874a3d63d2 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py @@ -37,7 +37,7 @@ ir_makers as im, ) from gt4py.next.iterator.transforms import infer_domain -from gt4py.next.program_processors.runners.dace import ( +from gt4py.next.program_processors.runners.dace.lowering import ( gtir_dataflow, gtir_domain, gtir_to_sdfg, diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_types.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py similarity index 94% rename from src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_types.py rename to src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py index 3878b55c6c..c19d8a4011 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_types.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py @@ -19,11 +19,8 @@ from gt4py.eve.extended_typing import MaybeNestedInTuple from gt4py.next import common as gtx_common from gt4py.next.iterator import builtins as gtir_builtins -from gt4py.next.program_processors.runners.dace import ( - gtir_dataflow, - gtir_domain, - utils as gtx_dace_utils, -) +from gt4py.next.program_processors.runners.dace import sdfg_args as gtx_dace_args +from gt4py.next.program_processors.runners.dace.lowering import gtir_dataflow, gtir_domain from gt4py.next.type_system import type_specifications as ts @@ -131,12 +128,12 @@ def get_symbol_mapping( symbol_mapping: dict[str, dace.symbolic.SymbolicType] = {} for dim, origin, size in zip(self.gt_type.dims, self.origin, globals_size, strict=True): symbol_mapping |= { - gtx_dace_utils.range_start_symbol(dataname, dim): origin, - gtx_dace_utils.range_stop_symbol(dataname, dim): (origin + size), + gtx_dace_args.range_start_symbol(dataname, dim): origin, + gtx_dace_args.range_stop_symbol(dataname, dim): (origin + size), } for dim, stride in zip(all_dims, outer_desc.strides, strict=True): symbol_mapping |= { - gtx_dace_utils.field_stride_symbol(dataname, dim): stride, + gtx_dace_args.field_stride_symbol(dataname, dim): stride, } return symbol_mapping diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_utils.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_utils.py similarity index 84% rename from src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_utils.py rename to src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_utils.py index 1795fcbb43..9469240d24 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_utils.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Dict, Final, Optional, TypeVar +from typing import Dict, Final, Mapping, Optional, TypeVar import 
dace @@ -17,7 +17,7 @@ from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.program_processors.runners.dace import gtir_python_codegen +from gt4py.next.program_processors.runners.dace.lowering import gtir_python_codegen from gt4py.next.type_system import type_specifications as ts @@ -162,3 +162,27 @@ def get_symbolic(ir: gtir.Expr) -> dace.symbolic.SymbolicType: """ python_source = gtir_python_codegen.get_source(ir) return dace.symbolic.pystr_to_symbolic(python_source) + + +def safe_replace_symbolic( + val: dace.symbolic.SymbolicType, + symbol_mapping: Mapping[dace.symbolic.SymbolicType | str, dace.symbolic.SymbolicType | str], +) -> dace.symbolic.SymbolicType: + """ + Replace free symbols in a dace symbolic expression, using `safe_replace()` + in order to avoid clashes in case the new symbol value is also a free symbol + in the original exoression. + + Args: + val: The symbolic expression where to apply the replacement. + symbol_mapping: The mapping table for symbol replacement. + + Returns: + A new symbolic expression as result of symbol replacement. + """ + # The list `x` is needed because `subs()` returns a new object and can not handle + # replacement dicts of the form `{'x': 'y', 'y': 'x'}`. + # The utility `safe_replace()` will call `subs()` twice in case of such dicts. + x = [val] + dace.symbolic.safe_replace(symbol_mapping, lambda m, xx=x: xx.append(xx[-1].subs(m))) + return x[-1] diff --git a/src/gt4py/next/program_processors/runners/dace/program.py b/src/gt4py/next/program_processors/runners/dace/program.py index 94740166df..8b5496989f 100644 --- a/src/gt4py/next/program_processors/runners/dace/program.py +++ b/src/gt4py/next/program_processors/runners/dace/program.py @@ -19,7 +19,7 @@ from gt4py.next.iterator import ir as itir, transforms as itir_transforms from gt4py.next.iterator.transforms import extractors as extractors from gt4py.next.otf import arguments, recipes, toolchain -from gt4py.next.program_processors.runners.dace import utils as gtx_dace_utils +from gt4py.next.program_processors.runners.dace import sdfg_args as gtx_dace_args from gt4py.next.type_system import type_specifications as ts @@ -157,7 +157,7 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ conn_id: conn for offset, conn in self.connectivities.items() if gtx_common.is_neighbor_table(conn) - and (conn_id := gtx_dace_utils.connectivity_identifier(offset)) + and (conn_id := gtx_dace_args.connectivity_identifier(offset)) in self.sdfg_closure_cache["arrays"] } @@ -171,27 +171,25 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ # Build the closure dictionary closure_dict: dict[str, dace.data.Array] = {} - connectivity_types = gtx_dace_utils.filter_connectivity_types( - gtx_common.offset_provider_to_type(self.connectivities) - ) + offset_provider_type = gtx_common.offset_provider_to_type(self.connectivities) for conn_id, conn in used_connectivities.items(): if conn_id not in self.connectivity_tables_data_descriptors: self.connectivity_tables_data_descriptors[conn_id] = dace.data.Array( dtype=dace.dtypes.dtype_to_typeclass(conn.dtype.dtype.type), shape=[ - gtx_dace_utils.field_size_symbol( - conn_id, conn.domain.dims[0], connectivity_types + gtx_dace_args.field_size_symbol( + conn_id, conn.domain.dims[0], offset_provider_type ), - gtx_dace_utils.field_size_symbol( - conn_id, conn.domain.dims[1], connectivity_types + 
gtx_dace_args.field_size_symbol( + conn_id, conn.domain.dims[1], offset_provider_type ), ], strides=[ - gtx_dace_utils.field_stride_symbol( - conn_id, conn.domain.dims[0], connectivity_types + gtx_dace_args.field_stride_symbol( + conn_id, conn.domain.dims[0], offset_provider_type ), - gtx_dace_utils.field_stride_symbol( - conn_id, conn.domain.dims[1], connectivity_types + gtx_dace_args.field_stride_symbol( + conn_id, conn.domain.dims[1], offset_provider_type ), ], storage=Program.connectivity_tables_data_descriptors["storage"], @@ -212,7 +210,7 @@ def _crosscheck_dace_parsing(dace_parsed_args: list[Any], gt4py_program_args: li ): match dace_parsed_arg: case dace.data.Scalar(): - assert dace_parsed_arg.dtype == gtx_dace_utils.as_dace_type(gt4py_program_arg) + assert dace_parsed_arg.dtype == gtx_dace_args.as_dace_type(gt4py_program_arg) case bool() | np.bool_(): assert isinstance(gt4py_program_arg, ts.ScalarType) assert gt4py_program_arg.kind == ts.ScalarKind.BOOL @@ -229,7 +227,7 @@ def _crosscheck_dace_parsing(dace_parsed_args: list[Any], gt4py_program_args: li assert isinstance(gt4py_program_arg, ts.FieldType) assert isinstance(gt4py_program_arg.dtype, ts.ScalarType) assert len(dace_parsed_arg.shape) == len(gt4py_program_arg.dims) - assert dace_parsed_arg.dtype == gtx_dace_utils.as_dace_type(gt4py_program_arg.dtype) + assert dace_parsed_arg.dtype == gtx_dace_args.as_dace_type(gt4py_program_arg.dtype) case dace.data.Structure() | dict() | collections.OrderedDict(): # offset provider pass diff --git a/src/gt4py/next/program_processors/runners/dace/utils.py b/src/gt4py/next/program_processors/runners/dace/sdfg_args.py similarity index 70% rename from src/gt4py/next/program_processors/runners/dace/utils.py rename to src/gt4py/next/program_processors/runners/dace/sdfg_args.py index cfc5290124..f33fbf8bb5 100644 --- a/src/gt4py/next/program_processors/runners/dace/utils.py +++ b/src/gt4py/next/program_processors/runners/dace/sdfg_args.py @@ -9,14 +9,14 @@ from __future__ import annotations import re -from typing import Final, Literal, Mapping, Union +from typing import Final, Literal import dace from gt4py.next import common as gtx_common from gt4py.next.iterator import builtins as gtir_builtins from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.program_processors.runners.dace import gtir_python_codegen +from gt4py.next.program_processors.runners.dace.lowering import gtir_python_codegen from gt4py.next.type_system import type_specifications as ts @@ -60,29 +60,21 @@ def connectivity_identifier(name: str) -> str: def is_connectivity_identifier( name: str, offset_provider_type: gtx_common.OffsetProviderType | None = None ) -> bool: - m = CONNECTIVITY_INDENTIFIER_RE.match(name) - if m is None: + if (m := CONNECTIVITY_INDENTIFIER_RE.match(name)) is None: return False - if offset_provider_type is None: + elif offset_provider_type is None: # If no offset provider type is provided, we assume there is a connectivity identifier # that matches the CONNECTIVITY_INDENTIFIER_RE. 
return True - return gtx_common.has_offset(offset_provider_type, m[1]) - - -def is_connectivity_symbol(name: str, offset_provider_type: gtx_common.OffsetProviderType) -> bool: - if (m_symbol := FIELD_SYMBOL_RE.match(name)) is None: - return False - if (m := CONNECTIVITY_INDENTIFIER_RE.match(m_symbol[1])) is None: - return False - return gtx_common.has_offset(offset_provider_type, m[1]) + else: + return gtx_common.has_offset(offset_provider_type, m[1]) def _field_symbol( field_name: str, dim: gtx_common.Dimension, sym: Literal["size", "stride"], - offset_provider_type: Mapping[str, gtx_common.NeighborConnectivityType] | None, + offset_provider_type: gtx_common.OffsetProviderType | None, ) -> dace.symbol: if (m := CONNECTIVITY_INDENTIFIER_RE.match(field_name)) is None: name = f"__{field_name}_{dim.value}_{sym}" @@ -91,6 +83,7 @@ def _field_symbol( assert m[1] in offset_provider_type offset = m[1] conn_type = offset_provider_type[offset] + assert isinstance(conn_type, gtx_common.NeighborConnectivityType) if dim == conn_type.source_dim: name = f"__{field_name}_source_{sym}" elif dim == conn_type.neighbor_dim: @@ -103,7 +96,7 @@ def _field_symbol( def field_size_symbol( field_name: str, dim: gtx_common.Dimension, - offset_provider_type: Mapping[str, gtx_common.NeighborConnectivityType], + offset_provider_type: gtx_common.OffsetProviderType, ) -> dace.symbol: return _field_symbol(field_name, dim, "size", offset_provider_type) @@ -111,7 +104,7 @@ def field_size_symbol( def field_stride_symbol( field_name: str, dim: gtx_common.Dimension, - offset_provider_type: Mapping[str, gtx_common.NeighborConnectivityType] | None = None, + offset_provider_type: gtx_common.OffsetProviderType | None = None, ) -> dace.symbol: return _field_symbol(field_name, dim, "stride", offset_provider_type) @@ -148,29 +141,3 @@ def filter_connectivity_types( for offset, conn in offset_provider_type.items() if isinstance(conn, gtx_common.NeighborConnectivityType) } - - -def safe_replace_symbolic( - val: dace.symbolic.SymbolicType, - symbol_mapping: Mapping[ - Union[dace.symbolic.SymbolicType, str], Union[dace.symbolic.SymbolicType, str] - ], -) -> dace.symbolic.SymbolicType: - """ - Replace free symbols in a dace symbolic expression, using `safe_replace()` - in order to avoid clashes in case the new symbol value is also a free symbol - in the original exoression. - - Args: - val: The symbolic expression where to apply the replacement. - symbol_mapping: The mapping table for symbol replacement. - - Returns: - A new symbolic expression as result of symbol replacement. - """ - # The list `x` is needed because `subs()` returns a new object and can not handle - # replacement dicts of the form `{'x': 'y', 'y': 'x'}`. - # The utility `safe_replace()` will call `subs()` twice in case of such dicts. - x = [val] - dace.symbolic.safe_replace(symbol_mapping, lambda m, xx=x: xx.append(xx[-1].subs(m))) - return x[-1] diff --git a/src/gt4py/next/program_processors/runners/dace/sdfg_callable.py b/src/gt4py/next/program_processors/runners/dace/sdfg_callable.py index 2cdd2325c0..735a78aab5 100644 --- a/src/gt4py/next/program_processors/runners/dace/sdfg_callable.py +++ b/src/gt4py/next/program_processors/runners/dace/sdfg_callable.py @@ -15,17 +15,16 @@ from gt4py._core import definitions as core_defs from gt4py.next import common as gtx_common, field_utils - -from . 
import utils as gtx_dace_utils +from gt4py.next.program_processors.runners.dace import sdfg_args as gtx_dace_args def get_field_domain_symbols(name: str, domain: gtx_common.Domain) -> dict[str, int]: assert gtx_common.Domain.is_finite(domain) return { - gtx_dace_utils.range_start_symbol(name, dim).name: r.start + gtx_dace_args.range_start_symbol(name, dim).name: r.start for dim, r in zip(domain.dims, domain.ranges, strict=True) } | { - gtx_dace_utils.range_stop_symbol(name, dim).name: r.stop + gtx_dace_args.range_stop_symbol(name, dim).name: r.stop for dim, r in zip(domain.dims, domain.ranges, strict=True) } @@ -96,7 +95,7 @@ def get_sdfg_conn_args( """ connectivity_args = {} for offset, connectivity in offset_provider.items(): - name = gtx_dace_utils.connectivity_identifier(offset) + name = gtx_dace_args.connectivity_identifier(offset) if name in sdfg.arrays: assert gtx_common.is_neighbor_connectivity(connectivity) assert field_utils.verify_device_field_type( diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py b/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py index 9b284f2f5b..4601c29b52 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py @@ -25,7 +25,7 @@ from dace.transformation import helpers as dace_helpers from gt4py.next import common as gtx_common -from gt4py.next.program_processors.runners.dace import gtir_to_sdfg_utils +from gt4py.next.program_processors.runners.dace import lowering as gtx_dace_lowering @dace_properties.make_properties @@ -89,7 +89,7 @@ def __init__( ) -> None: super().__init__() if isinstance(blocking_parameter, gtx_common.Dimension): - blocking_parameter = gtir_to_sdfg_utils.get_map_variable(blocking_parameter) + blocking_parameter = gtx_dace_lowering.get_map_variable(blocking_parameter) if blocking_parameter is not None: self.blocking_parameter = blocking_parameter if blocking_size is not None: diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_extended.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_extended.py index 0260e6aa8e..ed09312d4f 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_extended.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_extended.py @@ -26,8 +26,8 @@ from dace.transformation.passes import analysis as dace_analysis from gt4py.next.program_processors.runners.dace import ( + sdfg_args as gtx_dace_args, transformations as gtx_transformations, - utils as gtx_dace_utils, ) from gt4py.next.program_processors.runners.dace.transformations import ( map_fusion_utils as gtx_mfutils, @@ -698,7 +698,7 @@ def can_be_applied( edge.data.data for edge in second_map_subgraph.in_edges(second_map_node) ] if self.access_node.data in nested_map_input_edges_data and any( - gtx_dace_utils.is_connectivity_identifier(data) + gtx_dace_args.is_connectivity_identifier(data) for data in nested_map_input_edges_data ): return False diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_orderer.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_orderer.py index c018542c46..2f56671da8 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/map_orderer.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_orderer.py @@ -13,7 +13,7 @@ from dace.sdfg 
import nodes as dace_nodes from gt4py.next import common as gtx_common -from gt4py.next.program_processors.runners.dace import gtir_to_sdfg_utils +from gt4py.next.program_processors.runners.dace import lowering as gtx_dace_lowering def gt_set_iteration_order( @@ -130,7 +130,7 @@ def __init__( self.unit_strides_dims = [ unit_strides_dim if isinstance(unit_strides_dim, str) - else gtir_to_sdfg_utils.get_map_variable(unit_strides_dim) + else gtx_dace_lowering.get_map_variable(unit_strides_dim) for unit_strides_dim in unit_strides_dims ] elif unit_strides_kind is not None: diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/bindings.py b/src/gt4py/next/program_processors/runners/dace/workflow/bindings.py index 91f0efdd66..6fc233145f 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/bindings.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/bindings.py @@ -14,7 +14,7 @@ from gt4py.eve import codegen from gt4py.next.otf import languages, stages -from gt4py.next.program_processors.runners.dace import utils as gtx_dace_utils +from gt4py.next.program_processors.runners.dace import sdfg_args as gtx_dace_args from gt4py.next.type_system import type_specifications as ts @@ -140,12 +140,12 @@ def _parse_gt_param( # like 'range_stop - range_start', where 'range_start' and # 'range_stop' are the SDFG symbols for the domain range. arg_range = f"{arg}.domain.ranges[{i}]" - rstart = gtx_dace_utils.range_start_symbol(param_name, dim) - rstop = gtx_dace_utils.range_stop_symbol(param_name, dim) + rstart = gtx_dace_args.range_start_symbol(param_name, dim) + rstop = gtx_dace_args.range_stop_symbol(param_name, dim) for suffix, sdfg_range_symbol in [("start", rstart), ("stop", rstop)]: _parse_gt_param( param_name=sdfg_range_symbol.name, - param_type=gtx_dace_utils.as_itir_type(sdfg_range_symbol.dtype), + param_type=gtx_dace_args.as_itir_type(sdfg_range_symbol.dtype), arg=f"{arg_range}.{suffix}", code=code, sdfg_arglist=sdfg_arglist, @@ -160,13 +160,13 @@ def _parse_gt_param( f"assert {_cb_sdfg_argtypes}[{sdfg_arg_index}].strides[{i}] == {arg_stride}" ) else: - sdfg_stride_symbol = gtx_dace_utils.field_stride_symbol(param_name, dim) + sdfg_stride_symbol = gtx_dace_args.field_stride_symbol(param_name, dim) assert array_stride == sdfg_stride_symbol # The strides of a global array are defined by a sequence # of SDFG symbols. 
_parse_gt_param( param_name=sdfg_stride_symbol.name, - param_type=gtx_dace_utils.as_itir_type(sdfg_stride_symbol.dtype), + param_type=gtx_dace_args.as_itir_type(sdfg_stride_symbol.dtype), arg=arg_stride, code=code, sdfg_arglist=sdfg_arglist, diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index e5a556eb4f..e45132bc60 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -21,10 +21,9 @@ from gt4py.next.otf.binding import interface from gt4py.next.otf.languages import LanguageSettings from gt4py.next.program_processors.runners.dace import ( - gtir_to_sdfg, - gtir_to_sdfg_utils, + lowering as gtx_dace_lowering, + sdfg_args as gtx_dace_args, transformations as gtx_transformations, - utils as gtx_dace_utils, ) from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon from gt4py.next.type_system import type_specifications as ts @@ -52,16 +51,18 @@ def find_constant_symbols( raise NotImplementedError( f"Unsupported field with multiple horizontal dimensions '{p}'." ) - sdfg_stride_symbol = gtx_dace_utils.field_stride_symbol(str(p.id), dim) + sdfg_stride_symbol = gtx_dace_args.field_stride_symbol(str(p.id), dim) constant_symbols[sdfg_stride_symbol.name] = 1 # Same for connectivity tables, for which the first dimension is always horizontal - connectivity_types = gtx_dace_utils.filter_connectivity_types(offset_provider_type) - for offset, conn_type in connectivity_types.items(): - if (conn_id := gtx_dace_utils.connectivity_identifier(offset)) in sdfg.arrays: + for offset, conn_type in offset_provider_type.items(): + if ( + isinstance(conn_type, common.NeighborConnectivityType) + and (conn_id := gtx_dace_args.connectivity_identifier(offset)) in sdfg.arrays + ): assert not sdfg.arrays[conn_id].transient assert conn_type.source_dim.kind == common.DimensionKind.HORIZONTAL - sdfg_stride_symbol = gtx_dace_utils.field_stride_symbol( - conn_id, conn_type.source_dim, connectivity_types + sdfg_stride_symbol = gtx_dace_args.field_stride_symbol( + conn_id, conn_type.source_dim, offset_provider_type ) constant_symbols[sdfg_stride_symbol.name] = 1 @@ -71,7 +72,7 @@ def find_constant_symbols( if isinstance(p.type, ts.TupleType): psymbols = [ sym - for sym in gtir_to_sdfg_utils.flatten_tuple_fields(p.id, p.type) + for sym in gtx_dace_lowering.flatten_tuple_fields(p.id, p.type) if isinstance(sym.type, ts.FieldType) ] elif isinstance(p.type, ts.FieldType): @@ -85,7 +86,7 @@ def find_constant_symbols( continue # set all range start symbols to constant value 0 sdfg_origin_symbols = [ - gtx_dace_utils.range_start_symbol(str(psymbol.id), dim) + gtx_dace_args.range_start_symbol(str(psymbol.id), dim) for dim in psymbol.type.dims ] constant_symbols |= {sdfg_symbol.name: 0 for sdfg_symbol in sdfg_origin_symbols} @@ -374,7 +375,7 @@ def _generate_sdfg_without_configuring_dace( offset_provider_type = common.offset_provider_to_type(offset_provider) on_gpu = self.device_type != core_defs.DeviceType.CPU - sdfg = gtir_to_sdfg.build_sdfg_from_gtir(ir, offset_provider_type, column_axis) + sdfg = gtx_dace_lowering.build_sdfg_from_gtir(ir, offset_provider_type, column_axis) constant_symbols = find_constant_symbols( ir, sdfg, offset_provider_type, self.disable_field_origin_on_program_arguments diff --git 
a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_domain.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_domain.py index 9c3301645c..cd773e3c6e 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_domain.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_domain.py @@ -13,10 +13,10 @@ dace = pytest.importorskip("dace") from gt4py.next import common as gtx_common -from gt4py.next.program_processors.runners.dace import ( +from gt4py.next.program_processors.runners.dace.lowering import ( gtir_domain as gtx_dace_domain, - utils as gtx_dace_utils, ) +from gt4py.next.program_processors.runners.dace import sdfg_args as gtx_dace_args from gt4py.next.iterator.ir_utils import domain_utils, ir_makers as im from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( @@ -33,12 +33,12 @@ def test_symbolic_domain(): assert gtx_dace_domain.get_field_domain(domain) == [ gtx_dace_domain.FieldopDomainRange( Vertex, - dace.symbolic.SymExpr(gtx_dace_utils.range_start_symbol("arg", Vertex)), - dace.symbolic.SymExpr(gtx_dace_utils.range_stop_symbol("arg", Vertex)), + dace.symbolic.SymExpr(gtx_dace_args.range_start_symbol("arg", Vertex)), + dace.symbolic.SymExpr(gtx_dace_args.range_stop_symbol("arg", Vertex)), ), gtx_dace_domain.FieldopDomainRange( KDim, - dace.symbolic.SymExpr(gtx_dace_utils.range_start_symbol("arg", KDim)), - dace.symbolic.SymExpr(gtx_dace_utils.range_stop_symbol("arg", KDim)), + dace.symbolic.SymExpr(gtx_dace_args.range_start_symbol("arg", KDim)), + dace.symbolic.SymExpr(gtx_dace_args.range_stop_symbol("arg", KDim)), ), ] diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_utils.py index eec68a6486..c2529f3be3 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_utils.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_utils.py @@ -11,11 +11,10 @@ import pytest dace = pytest.importorskip("dace") - -from gt4py.next.program_processors.runners.dace import utils as gtx_dace_utils +from gt4py.next.program_processors.runners.dace.lowering import gtir_to_sdfg_utils def test_safe_replace_symbolic(): - assert gtx_dace_utils.safe_replace_symbolic( + assert gtir_to_sdfg_utils.safe_replace_symbolic( dace.symbolic.pystr_to_symbolic("x*x + y"), symbol_mapping={"x": "y", "y": "x"} ) == dace.symbolic.pystr_to_symbolic("y*y + x") diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 7b1280d753..0e776308fc 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -38,7 +38,7 @@ ) dace = pytest.importorskip("dace") -dace_backend = pytest.importorskip("gt4py.next.program_processors.runners.dace") +from gt4py.next.program_processors.runners.dace import lowering as dace_lowering @pytest.fixture @@ -121,16 +121,16 @@ def build_dace_sdfg( offset_provider: gtx_common.OffsetProvider, skip_domain_inference: bool = False, ) -> Callable[..., Any]: - """Wrapper of 
`dace_backend.build_sdfg_from_gtir()` to run domain inference. + """Wrapper of `dace_lowering.build_sdfg_from_gtir()` to run domain inference. - Before calling `dace_backend.build_sdfg_from_gtir()`, it will infer the domain + Before calling `dace_lowering.build_sdfg_from_gtir()`, it will infer the domain of the given `ir`, unless called with `skip_domain_inference=True`. """ if not skip_domain_inference: # run domain inference in order to add the domain annex information to the IR nodes ir = infer_domain.infer_program(ir, offset_provider=offset_provider) offset_provider_type = gtx_common.offset_provider_to_type(offset_provider) - return dace_backend.build_sdfg_from_gtir(ir, offset_provider_type, column_axis=KDim) + return dace_lowering.build_sdfg_from_gtir(ir, offset_provider_type, column_axis=KDim) def apply_margin_on_field_domain( diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_auto_optimizer_hooks.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_auto_optimizer_hooks.py index c7a34bca3f..3f686ae583 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_auto_optimizer_hooks.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_auto_optimizer_hooks.py @@ -14,7 +14,6 @@ from gt4py.next import common as gtx_common from gt4py.next.program_processors.runners.dace import ( - gtir_to_sdfg_utils as gtx_sdfg_utils, transformations as gtx_transformations, ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_promoter.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_promoter.py index 1fad74b8d7..435c37d629 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_promoter.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_promoter.py @@ -14,7 +14,7 @@ from gt4py.next import common as gtx_common from gt4py.next.program_processors.runners.dace import ( - gtir_to_sdfg_utils as gtx_sdfg_utils, + lowering as gtx_dace_lowering, transformations as gtx_transformations, ) @@ -314,8 +314,8 @@ def _make_horizontal_promoter_sdfg( sdfg = dace.SDFG(util.unique_name("serial_map_promoter_tester")) state = sdfg.add_state(is_start_block=True) - h_idx = gtx_sdfg_utils.get_map_variable(gtx_common.Dimension("boden")) - v_idx = gtx_sdfg_utils.get_map_variable( + h_idx = gtx_dace_lowering.get_map_variable(gtx_common.Dimension("boden")) + v_idx = gtx_dace_lowering.get_map_variable( gtx_common.Dimension("K", gtx_common.DimensionKind.VERTICAL) ) From 42e1d2fa10b40b49f381fafbbd2159b99dd2bbd8 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 5 Dec 2025 12:38:57 +0100 Subject: [PATCH 02/28] edit --- .../runners/dace/lowering/gtir_to_sdfg.py | 422 +++++++++--------- .../dace/lowering/gtir_to_sdfg_primitives.py | 47 +- .../dace/lowering/gtir_to_sdfg_scan.py | 44 +- 3 files changed, 261 insertions(+), 252 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index 07969bf1bc..7aa864f952 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ 
b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -15,8 +15,21 @@ from __future__ import annotations import abc +import contextlib import dataclasses -from typing import Any, Dict, Iterable, List, Mapping, Optional, Protocol, Sequence, Tuple, Union +from typing import ( + Any, + Dict, + Generator, + Iterable, + List, + Mapping, + Optional, + Protocol, + Sequence, + Tuple, + Union, +) import dace from dace import subsets as dace_subsets @@ -26,7 +39,11 @@ from gt4py.eve import concepts from gt4py.next import common as gtx_common, config as gtx_config, utils as gtx_utils from gt4py.next.iterator import ir as gtir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) from gt4py.next.iterator.transforms import prune_casts as ir_prune_casts, symbol_ref_utils from gt4py.next.iterator.type_system import inference as gtir_type_inference from gt4py.next.program_processors.runners.dace import sdfg_args as gtx_dace_args @@ -176,8 +193,26 @@ class SubgraphContext: sdfg: dace.SDFG state: dace.SDFGState + scope_symbols: dict[str, ts.DataType] + + def remove_isolated_nodes(self) -> None: + flat_scope_symbols: list[gtir.Sym] = [] + for sym_name, sym_type in self.scope_symbols.items(): + if isinstance(sym_type, ts.TupleType): + flat_scope_symbols.extend( + gtir_to_sdfg_utils.flatten_tuple_fields(sym_name, sym_type) + ) + else: + flat_scope_symbols.append(im.sym(sym_name, sym_type)) + scope_symbols = set(str(sym.id) for sym in flat_scope_symbols) + isolated_nodes = [ + access_node + for access_node in self.state.data_nodes() + if access_node.data in scope_symbols and self.state.degree(access_node) == 0 + ] + self.state.remove_nodes_from(isolated_nodes) - def copy_field( + def copy_data( self, src: gtir_to_sdfg_types.FieldopData, domain: gtir_domain.FieldopDomain | None, @@ -309,18 +344,13 @@ def make_field( """ ... - @abc.abstractmethod - def get_symbol_type(self, symbol_name: str) -> ts.DataType: - """Retrieve the GT4Py type of a symbol used in the SDFG.""" - ... - @abc.abstractmethod def is_column_axis(self, dim: gtx_common.Dimension) -> bool: """Check if the given dimension is the column axis.""" ... @abc.abstractmethod - def setup_nested_context( + def setup_nested_sdfg( self, expr: gtir.Lambda, sdfg_name: str, @@ -328,7 +358,7 @@ def setup_nested_context( params: Iterable[gtir.Sym], symbolic_inputs: set[str], capture_scope_symbols: bool, - ) -> tuple[SDFGBuilder, SubgraphContext]: + ) -> tuple[dace.SDFG, list[gtir.Sym]]: """ Create an nested SDFG context to lower a lambda expression, indipendent from the current context where the parent expression is being translated. @@ -349,10 +379,31 @@ def setup_nested_context( GTIR symbols defined in the parent scope. Returns: - A visitor object implementing the `SDFGBuilder` protocol. + The nested SDFG context. """ ... + @abc.abstractmethod + @contextlib.contextmanager + def setup_ctx( + self, + sdfg: dace.SDFG, + state: dace.SDFGState, + params: Iterable[gtir.Sym], + ) -> Generator[SubgraphContext, None, None]: + "Allocate scope symbols for given parameters in context." + ... + + @abc.abstractmethod + @contextlib.contextmanager + def setup_ctx_in_new_state( + self, + ctx: SubgraphContext, + state: dace.SDFGState, + ) -> Generator[SubgraphContext, Any, Any]: + "Allocate scope symbols of the given context inside a new state." + ... 
+ @abc.abstractmethod def add_nested_sdfg( self, @@ -483,7 +534,6 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): offset_provider_type: gtx_common.OffsetProviderType column_axis: Optional[gtx_common.Dimension] - scope_symbols: dict[str, ts.DataType] uids: gtx_utils.IDGeneratorPool = dataclasses.field( init=False, repr=False, default_factory=lambda: gtx_utils.IDGeneratorPool() ) @@ -522,14 +572,11 @@ def make_field( ) return gtir_to_sdfg_types.FieldopData(data_node, field_type, field_origin) - def get_symbol_type(self, symbol_name: str) -> ts.DataType: - return self.scope_symbols[symbol_name] - def is_column_axis(self, dim: gtx_common.Dimension) -> bool: assert self.column_axis return dim == self.column_axis - def setup_nested_context( + def setup_nested_sdfg( self, expr: gtir.Lambda, sdfg_name: str, @@ -537,7 +584,7 @@ def setup_nested_context( params: Iterable[gtir.Sym], symbolic_inputs: set[str], capture_scope_symbols: bool, - ) -> tuple[SDFGBuilder, SubgraphContext]: + ) -> tuple[dace.SDFG, list[gtir.Sym]]: assert symbolic_inputs.issubset(str(p.id) for p in params) and all( isinstance(p.type, ts.ScalarType) for p in params if str(p.id) in symbolic_inputs ) @@ -549,8 +596,10 @@ def setup_nested_context( # then we add or override symbols defined as lambda parameters. if capture_scope_symbols: lambda_symbols = { - sym: self.scope_symbols[sym] - for sym in symbol_ref_utils.collect_symbol_refs(expr, self.scope_symbols.keys()) + sym: parent_ctx.scope_symbols[sym] + for sym in symbol_ref_utils.collect_symbol_refs( + expr, parent_ctx.scope_symbols.keys() + ) } else: lambda_symbols = {} @@ -588,13 +637,6 @@ def setup_nested_context( sdfg = dace.SDFG(name=self.unique_nsdfg_name(parent_ctx.sdfg, sdfg_name)) sdfg.debuginfo = gtir_to_sdfg_utils.debug_info(expr, default=parent_ctx.sdfg.debuginfo) - state = sdfg.add_state(f"{sdfg_name}_entry") - nested_ctx = SubgraphContext(sdfg, state) - nsdfg_builder = GTIRToSDFG( - offset_provider_type=self.offset_provider_type, - column_axis=self.column_axis, - scope_symbols=lambda_symbols, - ) # All GTIR-symbols accessed in domain expressions by the lambda need to be # represented as dace symbols. @@ -615,13 +657,13 @@ def setup_nested_context( # When they are accessed, the corresponding data descriptor (scalar or array) # will be turned into global by setting `transient=False`. In this way, we remove # all unused (and possibly undefined) input arguments. - nsdfg_builder._add_sdfg_params( + self._add_sdfg_params( sdfg, node_params=input_params, symbolic_params=(domain_symbols | mapped_symbols | symbolic_inputs), use_transient_storage=True, ) - return nsdfg_builder, nested_ctx + return sdfg, input_params def add_nested_sdfg( self, @@ -719,6 +761,34 @@ def add_nested_sdfg( return nsdfg_node, input_memlets + @contextlib.contextmanager + def setup_ctx( + self, + sdfg: dace.SDFG, + state: dace.SDFGState, + params: Iterable[gtir.Sym], + ) -> Generator[SubgraphContext, Any, Any]: + "Create access nodes to global data for given parameters in context." 
+ scope_symbols = {str(p.id): p.type for p in params} + assert all(isinstance(sym_type, ts.DataType) for sym_type in scope_symbols.values()) + ctx = SubgraphContext(sdfg, state, scope_symbols) # type: ignore[arg-type] + try: + yield ctx + finally: + ctx.remove_isolated_nodes() + + @contextlib.contextmanager + def setup_ctx_in_new_state( + self, + ctx: SubgraphContext, + state: dace.SDFGState, + ) -> Generator[SubgraphContext, Any, Any]: + new_ctx = SubgraphContext(ctx.sdfg, state, ctx.scope_symbols.copy()) + try: + yield new_ctx + finally: + new_ctx.remove_isolated_nodes() + def unique_nsdfg_name(self, sdfg: dace.SDFG, prefix: str) -> str: nsdfg_list = [ nsdfg.label for nsdfg in sdfg.all_sdfgs_recursive() if nsdfg.label.startswith(prefix) @@ -861,19 +931,10 @@ def _add_storage( raise RuntimeError(f"Data type '{type(gt_type)}' not supported.") - def _add_storage_for_temporary(self, temp_decl: gtir.Temporary) -> dict[str, str]: - """ - Add temporary storage (aka transient) for data containers used as GTIR temporaries. - - Assume all temporaries to be fields, therefore represented as dace arrays. - """ - raise NotImplementedError("Temporaries not supported yet by GTIR DaCe backend.") - def _visit_expression( self, node: gtir.Expr, - sdfg: dace.SDFG, - head_state: dace.SDFGState, + ctx: SubgraphContext, use_temp: bool = True, ) -> gtir_to_sdfg_types.FieldopResult: """ @@ -895,20 +956,17 @@ def _visit_expression( The nodes are organized in tree form, in case of tuples. """ - ctx = SubgraphContext(sdfg, head_state) result = self.visit(node, ctx=ctx) # sanity check: each statement should preserve the property of single exit state (aka head state), # i.e. eventually only introduce internal branches, and keep the same head state - sink_states = sdfg.sink_nodes() + sink_states = ctx.sdfg.sink_nodes() assert len(sink_states) == 1 - assert sink_states[0] == head_state + assert sink_states[0] == ctx.state if use_temp: # copy the full shape of global data to temporary storage return gtx_utils.tree_map( - lambda x: x - if x.dc_node.desc(ctx.sdfg).transient - else ctx.copy_field(x, domain=None) + lambda x: x if x.dc_node.desc(ctx.sdfg).transient else ctx.copy_data(x, domain=None) )(result) else: return result @@ -984,23 +1042,13 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: sdfg.debuginfo = gtir_to_sdfg_utils.debug_info(node) # start block of the stateful graph - entry_state = sdfg.add_state("program_entry", is_start_block=True) - - # declarations of temporaries result in transient array definitions in the SDFG - if node.declarations: - temp_symbols: dict[str, str] = {} - for decl in node.declarations: - temp_symbols |= self._add_storage_for_temporary(decl) - - # define symbols for shape and offsets of temporary arrays as interstate edge symbols - head_state = sdfg.add_state_after(entry_state, "init_temps", assignments=temp_symbols) - else: - head_state = entry_state + head_state = sdfg.add_state("program_entry", is_start_block=True) # By passing `symbolic_arguments=None` all scalars are represented as dace symbols. # We do this to allow lowering of scalar expressions in let-statements, # that only depend on scalar parameters, as dace symbolic expressions # mapped to symbols on a nested SDFG. 
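+        # All program parameters are expected to carry a concrete GT4Py data type at this point.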
+ assert all(isinstance(p.type, ts.DataType) for p in node.params) sdfg_arg_names = self._add_sdfg_params( sdfg, node.params, symbolic_params=None, use_transient_storage=False ) @@ -1010,7 +1058,8 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: # include `debuginfo` only for `ir.Program` and `ir.Stmt` nodes: finer granularity would be too messy head_state = sdfg.add_state_after(head_state, f"stmt_{i}") head_state._debuginfo = gtir_to_sdfg_utils.debug_info(stmt, default=sdfg.debuginfo) - head_state = self.visit(stmt, sdfg=sdfg, state=head_state) + with self.setup_ctx(sdfg, head_state, node.params) as ctx: + head_state = self.visit(stmt, ctx=ctx) # remove unused connectivity tables (by design, arrays are marked as non-transient when they are used) for nsdfg in sdfg.all_sdfgs_recursive(): @@ -1030,12 +1079,9 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: # the arguments created by the translation process, must be passed as keyword arguments. sdfg.arg_names = sdfg_arg_names - sdfg.validate() return sdfg - def visit_SetAt( - self, stmt: gtir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState - ) -> dace.SDFGState: + def visit_SetAt(self, stmt: gtir.SetAt, ctx: SubgraphContext) -> dace.SDFGState: """Visits a `SetAt` statement expression and writes the local result to some external storage. Each statement expression results in some sort of dataflow gragh writing to temporary storage. @@ -1045,77 +1091,48 @@ def visit_SetAt( The SDFG head state, eventually updated if the target write requires a new state. """ - # Visit the domain expression. - domain = gtir_domain.extract_target_domain(stmt.domain) - - # Visit the field operator expression. - source_tree = self._visit_expression(stmt.expr, sdfg, state) - - # The target expression could be a `SymRef` to an output field or a `make_tuple` - # expression in case the statement returns more than one field. - target_tree = self._visit_expression(stmt.target, sdfg, state, use_temp=False) - - expr_input_args = { - sym_id - for sym in eve.walk_values(stmt.expr).if_isinstance(gtir.SymRef) - if (sym_id := str(sym.id)) in sdfg.arrays - } - state_input_data = { - node.data - for node in state.data_nodes() - if node.data in expr_input_args and state.degree(node) != 0 - } - - # For inout argument, write the result in separate next state - # this is needed to avoid undefined behavior for expressions like: X, Y = X + 1, X - # If this state is not used, we remove it before returning from the function. 
- target_state = sdfg.add_state_after(state, f"post_{state.label}") - def _visit_target( source: gtir_to_sdfg_types.FieldopData, target: gtir_to_sdfg_types.FieldopData, target_domain: domain_utils.SymbolicDomain, target_state: dace.SDFGState, ) -> None: - target_desc = sdfg.arrays[target.dc_node.data] - assert not target_desc.transient - assert source.gt_type == target.gt_type field_domain = gtir_domain.get_field_domain(target_domain) source_subset = _make_access_index_for_field(field_domain, source) target_subset = _make_access_index_for_field(field_domain, target) - if target.dc_node.data in state_input_data: - # create new access nodes in the target state - target_state.add_nedge( - target_state.add_access(source.dc_node.data), - target_state.add_access(target.dc_node.data), - dace.Memlet( - data=target.dc_node.data, subset=target_subset, other_subset=source_subset - ), - ) - # remove isolated access node - state.remove_node(target.dc_node) - else: - state.add_nedge( - source.dc_node, - target.dc_node, - dace.Memlet( - data=target.dc_node.data, subset=target_subset, other_subset=source_subset - ), - ) - - gtx_utils.tree_map( - lambda source, target, target_domain: _visit_target( - source, target, target_domain, target_state + target_state.add_nedge( + # create in the target state new access nodes to the field operator result + target_state.add_access(source.dc_node.data), + target.dc_node, + dace.Memlet( + data=target.dc_node.data, subset=target_subset, other_subset=source_subset + ), ) - )(source_tree, target_tree, domain) + if ctx.state.degree(source.dc_node) == 0: + ctx.state.remove_node(source.dc_node) - if target_state.is_empty(): - sdfg.remove_node(target_state) - return state - else: - return target_state + # Visit the domain expression. + domain = gtir_domain.extract_target_domain(stmt.domain) + + # Visit the field operator expression. + source_tree = self._visit_expression(stmt.expr, ctx) + + # In order to support inout argument, write the result in separate next state + # this is needed to avoid undefined behavior for expressions like: X, Y = X + 1, X + target_state = ctx.sdfg.add_state_after(ctx.state, f"post_{ctx.state.label}") + with self.setup_ctx_in_new_state(ctx, target_state) as target_ctx: + # The target expression could be a `SymRef` to an output field or a `make_tuple` + # expression in case the statement returns more than one field. + target_tree = self._visit_expression(stmt.target, target_ctx, use_temp=False) + gtx_utils.tree_map( + lambda source, target, target_domain: _visit_target( + source, target, target_domain, target_state + ) + )(source_tree, target_tree, domain) + + return target_state def visit_FunCall( self, @@ -1200,7 +1217,7 @@ def visit_Lambda( lambda_arg_nodes, symbolic_args = flatten_tuple_args(args.items()) # lower let-statement lambda node as a nested SDFG - lambda_translator, lambda_ctx = self.setup_nested_context( + lambda_sdfg, lambda_params = self.setup_nested_sdfg( expr=node, sdfg_name="lambda", parent_ctx=ctx, @@ -1208,94 +1225,83 @@ def visit_Lambda( symbolic_inputs=set(symbolic_args.keys()), capture_scope_symbols=True, ) - - lambda_result = lambda_translator.visit(node.expr, ctx=lambda_ctx) - - # A let-lambda is allowed to capture GTIR-symbols from the outer scope, - # therefore we call `add_nested_sdfg()` with `capture_outer_data=True`. 
- nsdfg_node, input_memlets = self.add_nested_sdfg( - node=node, - inner_ctx=lambda_ctx, - outer_ctx=ctx, - symbolic_args=symbolic_args, - data_args=lambda_arg_nodes, - inner_result=lambda_result, - capture_outer_data=True, - ) - - # In this loop we call `pop()`, whenever an argument is connected to an input - # connector on the nested SDFG, so the corresponding node is removed from the dictionary. - for input_connector, memlet in input_memlets.items(): - if input_connector in lambda_arg_nodes: - arg_node = lambda_arg_nodes.pop(input_connector) - assert arg_node is not None - src_node = arg_node.dc_node - else: - src_node = ctx.state.add_access(memlet.data) - - ctx.state.add_edge(src_node, None, nsdfg_node, input_connector, memlet) - - # We can now safely remove all remaining arguments, because unused. Note - # that we only consider global access nodes, because the only goal of this - # cleanup is to remove isolated nodes. At this stage, temporary input nodes - # should not appear as isolated nodes, because they are supposed to contain - # the result of some argument expression. - if unused_access_nodes := [ - arg_node.dc_node - for arg_node in lambda_arg_nodes.values() - if not (arg_node is None or arg_node.dc_node.desc(ctx.sdfg).transient) - ]: - assert all(ctx.state.degree(access_node) == 0 for access_node in unused_access_nodes) - ctx.state.remove_nodes_from(unused_access_nodes) - - def construct_output_for_nested_sdfg( - inner_data: gtir_to_sdfg_types.FieldopData, - ) -> gtir_to_sdfg_types.FieldopData: - """ - This function makes a data container that lives inside a nested SDFG, denoted by `inner_data`, - available in the parent SDFG. - In order to achieve this, the data container inside the nested SDFG is marked as non-transient - (in other words, externally allocated - a requirement of the SDFG IR) and a new data container - is created within the parent SDFG, with the same properties (shape, stride, etc.) of `inner_data` - but appropriatly remapped using the symbol mapping table. - For lambda arguments that are simply returned by the lambda, the `inner_data` was already mapped - to a parent SDFG data container, therefore it can be directly accessed in the parent SDFG. - The same happens to symbols available in the lambda context but not explicitly passed as lambda - arguments, that are simply returned by the lambda: it can be directly accessed in the parent SDFG. - """ - if not inner_data.dc_node.desc(lambda_ctx.sdfg).transient: - # This is inout data: some global associated to an input connector - # is also a sink node of the lambda dataflow. This can happen, for - # example, when the lambda constructs a tuple of some input fields. - # We copy this data to a new node, which we use as output. - nsdfg_node.remove_out_connector(inner_data.dc_node.data) - inner_data = lambda_ctx.copy_field(inner_data, domain=None) - nsdfg_node.add_out_connector(inner_data.dc_node.data) - elif lambda_ctx.state.degree(inner_data.dc_node) == 0: - # Isolated access node will make validation fail. - # Isolated access nodes can be found in the join-state of an if-expression. - lambda_ctx.state.remove_node(inner_data.dc_node) - # Transient data nodes only exist within the nested SDFG. In order to return some result data, - # the corresponding data container inside the nested SDFG has to be changed to non-transient, - # that is externally allocated, as required by the SDFG IR. An output edge will write the result - # from the nested-SDFG to a new intermediate data container allocated in the parent SDFG. 
- outer_data = ctx.map_nsdfg_field( - sdfg_builder=self, - nsdfg_field=inner_data, - nsdfg=lambda_ctx.sdfg, - symbol_mapping=nsdfg_node.symbol_mapping, - ) - ctx.state.add_edge( - nsdfg_node, - inner_data.dc_node.data, - outer_data.dc_node, - None, - ctx.sdfg.make_array_memlet(outer_data.dc_node.data), + lambda_state = lambda_sdfg.add_state("lambda") + + with self.setup_ctx(lambda_sdfg, lambda_state, lambda_params) as lambda_ctx: + lambda_result = self.visit(node.expr, ctx=lambda_ctx) + + # A let-lambda is allowed to capture GTIR-symbols from the outer scope, + # therefore we call `add_nested_sdfg()` with `capture_outer_data=True`. + nsdfg_node, input_memlets = self.add_nested_sdfg( + node=node, + inner_ctx=lambda_ctx, + outer_ctx=ctx, + symbolic_args=symbolic_args, + data_args=lambda_arg_nodes, + inner_result=lambda_result, + capture_outer_data=True, ) - return outer_data + # In this loop we call `pop()`, whenever an argument is connected to an input + # connector on the nested SDFG, so the corresponding node is removed from the dictionary. + for input_connector, memlet in input_memlets.items(): + if input_connector in lambda_arg_nodes: + arg_node = lambda_arg_nodes.pop(input_connector) + assert arg_node is not None + src_node = arg_node.dc_node + else: + src_node = ctx.state.add_access(memlet.data) + + ctx.state.add_edge(src_node, None, nsdfg_node, input_connector, memlet) + + def construct_output_for_nested_sdfg( + inner_data: gtir_to_sdfg_types.FieldopData, + ) -> gtir_to_sdfg_types.FieldopData: + """ + This function makes a data container that lives inside a nested SDFG, denoted by `inner_data`, + available in the parent SDFG. + In order to achieve this, the data container inside the nested SDFG is marked as non-transient + (in other words, externally allocated - a requirement of the SDFG IR) and a new data container + is created within the parent SDFG, with the same properties (shape, stride, etc.) of `inner_data` + but appropriatly remapped using the symbol mapping table. + For lambda arguments that are simply returned by the lambda, the `inner_data` was already mapped + to a parent SDFG data container, therefore it can be directly accessed in the parent SDFG. + The same happens to symbols available in the lambda context but not explicitly passed as lambda + arguments, that are simply returned by the lambda: it can be directly accessed in the parent SDFG. + """ + if not inner_data.dc_node.desc(lambda_sdfg).transient: + # This is inout data: some global associated to an input connector + # is also a sink node of the lambda dataflow. This can happen, for + # example, when the lambda constructs a tuple of some input fields. + # We copy this data to a new node, which we use as output. + nsdfg_node.remove_out_connector(inner_data.dc_node.data) + inner_data = lambda_ctx.copy_data(inner_data, domain=None) + nsdfg_node.add_out_connector(inner_data.dc_node.data) + elif lambda_ctx.state.degree(inner_data.dc_node) == 0: + # Isolated access node will make validation fail. + # Isolated access nodes can be found in the join-state of an if-expression. + lambda_ctx.state.remove_node(inner_data.dc_node) + # Transient data nodes only exist within the nested SDFG. In order to return some result data, + # the corresponding data container inside the nested SDFG has to be changed to non-transient, + # that is externally allocated, as required by the SDFG IR. An output edge will write the result + # from the nested-SDFG to a new intermediate data container allocated in the parent SDFG. 
+ outer_data = ctx.map_nsdfg_field( + sdfg_builder=self, + nsdfg_field=inner_data, + nsdfg=lambda_sdfg, + symbol_mapping=nsdfg_node.symbol_mapping, + ) + ctx.state.add_edge( + nsdfg_node, + inner_data.dc_node.data, + outer_data.dc_node, + None, + ctx.sdfg.make_array_memlet(outer_data.dc_node.data), + ) - return gtx_utils.tree_map(construct_output_for_nested_sdfg)(lambda_result) + return outer_data + + return gtx_utils.tree_map(construct_output_for_nested_sdfg)(lambda_result) def visit_Literal( self, @@ -1334,6 +1340,8 @@ def build_sdfg_from_gtir( if ir.function_definitions: raise NotImplementedError("Functions expected to be inlined as lambda calls.") + if ir.declarations: + raise NotImplementedError("Temporaries not supported yet by GTIR DaCe backend.") ir = gtir_type_inference.infer(ir, offset_provider_type=offset_provider_type) ir = ir_prune_casts.PruneCasts().visit(ir) @@ -1344,9 +1352,9 @@ def build_sdfg_from_gtir( # Here we find new names for invalid symbols present in the IR. ir = gtir_to_sdfg_utils.replace_invalid_symbols(ir) - global_symbols = {str(p.id): p.type for p in ir.params if isinstance(p.type, ts.DataType)} - sdfg_genenerator = GTIRToSDFG(offset_provider_type, column_axis, global_symbols) + sdfg_genenerator = GTIRToSDFG(offset_provider_type, column_axis) sdfg = sdfg_genenerator.visit(ir) assert isinstance(sdfg, dace.SDFG) + sdfg.validate() return sdfg diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py index be0ff6130d..766ea63fff 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py @@ -269,7 +269,7 @@ def translate_as_fieldop( # on the given domain. It copies a subset of the source field. arg = sdfg_builder.visit(node.args[0], ctx=ctx) assert isinstance(arg, gtir_to_sdfg_types.FieldopData) - return ctx.copy_field(arg, domain=field_domain) + return ctx.copy_data(arg, domain=field_domain) elif isinstance(fieldop_expr, gtir.Lambda): # Default case, handled below: the argument expression is a lambda function # representing the stencil operation to be computed over the field domain. 
@@ -426,34 +426,34 @@ def translate_if( # expect true branch as second argument true_state = ctx.sdfg.add_state(ctx.state.label + "_true_branch") - tbranch_ctx = gtir_to_sdfg.SubgraphContext(ctx.sdfg, true_state) ctx.sdfg.add_edge(cond_state, true_state, dace.InterstateEdge(condition=if_stmt)) ctx.sdfg.add_edge(true_state, ctx.state, dace.InterstateEdge()) # and false branch as third argument false_state = ctx.sdfg.add_state(ctx.state.label + "_false_branch") - fbranch_ctx = gtir_to_sdfg.SubgraphContext(ctx.sdfg, false_state) ctx.sdfg.add_edge(cond_state, false_state, dace.InterstateEdge(condition=f"not({if_stmt})")) ctx.sdfg.add_edge(false_state, ctx.state, dace.InterstateEdge()) - true_br_result = sdfg_builder.visit(true_expr, ctx=tbranch_ctx) - false_br_result = sdfg_builder.visit(false_expr, ctx=fbranch_ctx) - - node_output = gtx_utils.tree_map( - lambda domain, true_br, false_br: _construct_if_branch_output( - ctx, sdfg_builder, domain, true_br, false_br - ) - )( - node.annex.domain, - true_br_result, - false_br_result, - ) - gtx_utils.tree_map(lambda src, dst: _write_if_branch_output(tbranch_ctx, src, dst))( - true_br_result, node_output - ) - gtx_utils.tree_map(lambda src, dst: _write_if_branch_output(fbranch_ctx, src, dst))( - false_br_result, node_output - ) + with sdfg_builder.setup_ctx_in_new_state(ctx, true_state) as tbranch_ctx: + with sdfg_builder.setup_ctx_in_new_state(ctx, false_state) as fbranch_ctx: + true_br_result = sdfg_builder.visit(true_expr, ctx=tbranch_ctx) + false_br_result = sdfg_builder.visit(false_expr, ctx=fbranch_ctx) + + node_output = gtx_utils.tree_map( + lambda domain, true_br, false_br: _construct_if_branch_output( + ctx, sdfg_builder, domain, true_br, false_br + ) + )( + node.annex.domain, + true_br_result, + false_br_result, + ) + gtx_utils.tree_map(lambda src, dst: _write_if_branch_output(tbranch_ctx, src, dst))( + true_br_result, node_output + ) + gtx_utils.tree_map(lambda src, dst: _write_if_branch_output(fbranch_ctx, src, dst))( + false_br_result, node_output + ) return node_output @@ -659,7 +659,8 @@ def translate_scalar_expr( if isinstance(arg_expr, gtir.SymRef): try: # check if symbol is defined in the GT4Py program, throws `KeyError` exception if undefined - sdfg_builder.get_symbol_type(arg_expr.id) + sym_name = str(arg_expr.id) + ctx.scope_symbols[sym_name] except KeyError: # all `SymRef` should refer to symbols defined in the program, except in case of non-variable argument, # e.g. the type name `float64` used in casting expressions like `cast_(variable, float64)` @@ -726,7 +727,7 @@ def translate_symbol_ref( symbol_name = str(node.id) # we retrieve the type of the symbol in the GT4Py prgram - gt_symbol_type = sdfg_builder.get_symbol_type(symbol_name) + gt_symbol_type = ctx.scope_symbols[symbol_name] # Create new access node in current state. It is possible that multiple # access nodes are created in one state for the same data container. diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py index 874a3d63d2..9a56ef5117 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py @@ -347,7 +347,7 @@ def _lower_lambda_to_nested_sdfg( ) ) # the lambda expression, i.e. body of the scan, will be created inside a nested SDFG. 
- lambda_translator, lambda_ctx = sdfg_builder.setup_nested_context( + lambda_sdfg, _ = sdfg_builder.setup_nested_sdfg( lambda_node, "scan", ctx, @@ -359,10 +359,10 @@ def _lower_lambda_to_nested_sdfg( # We set `using_explicit_control_flow=True` because the vertical scan is lowered to a `LoopRegion`. # This property is used by pattern matching in SDFG transformation framework # to skip those transformations that do not yet support control flow blocks. - lambda_ctx.sdfg.using_explicit_control_flow = True + lambda_sdfg.using_explicit_control_flow = True # We use the entry state for initialization of the scan carry variable. - init_state = lambda_ctx.state + init_state = lambda_sdfg.add_state("init") # use the vertical dimension in the domain as scan dimension scan_domain = next(r for r in field_domain if sdfg_builder.is_column_axis(r.dim)) @@ -410,27 +410,27 @@ def get_scan_output_shape( update_expr=f"{scan_loop_var} = {scan_loop_var} - 1", inverted=False, ) - lambda_ctx.sdfg.add_node(scan_loop, ensure_unique_name=True) - lambda_ctx.sdfg.add_edge(init_state, scan_loop, dace.InterstateEdge()) + lambda_sdfg.add_node(scan_loop, ensure_unique_name=True) + lambda_sdfg.add_edge(init_state, scan_loop, dace.InterstateEdge()) # Inside the loop region, create a 'compute' and an 'update' state. # The body of the 'compute' state implements the stencil expression for one vertical level. # The 'update' state writes the value computed by the stencil into the scan carry variable, # in order to make it available to the next vertical level. compute_state = scan_loop.add_state("scan_compute") - compute_ctx = gtir_to_sdfg.SubgraphContext(lambda_ctx.sdfg, compute_state) update_state = scan_loop.add_state_after(compute_state, "scan_update") - # inside the 'compute' state, visit the list of arguments to be passed to the stencil - stencil_args = [ - _parse_scan_fieldop_arg(im.ref(p.id), compute_ctx, lambda_translator, field_domain) - for p in lambda_node.params - ] - # stil inside the 'compute' state, generate the dataflow representing the stencil - # to be applied on the horizontal domain - lambda_input_edges, lambda_result = gtir_dataflow.translate_lambda_to_dataflow( - compute_ctx.sdfg, compute_ctx.state, lambda_translator, lambda_node, stencil_args - ) + with sdfg_builder.setup_ctx(lambda_sdfg, compute_state, lambda_params) as compute_ctx: + # inside the 'compute' state, visit the list of arguments to be passed to the stencil + stencil_args = [ + _parse_scan_fieldop_arg(im.ref(p.id), compute_ctx, sdfg_builder, field_domain) + for p in lambda_node.params + ] + # stil inside the 'compute' state, generate the dataflow representing the stencil + # to be applied on the horizontal domain + lambda_input_edges, lambda_result = gtir_dataflow.translate_lambda_to_dataflow( + compute_ctx.sdfg, compute_ctx.state, sdfg_builder, lambda_node, stencil_args + ) # connect the dataflow input directly to the source data nodes, without passing through a map node; # the reason is that the map for horizontal domain is outside the scan loop region for edge in lambda_input_edges: @@ -451,15 +451,15 @@ def get_scan_output_shape( def init_scan_carry(sym: gtir.Sym) -> None: scan_carry_dataname = str(sym.id) - scan_carry_desc = lambda_ctx.sdfg.data(scan_carry_dataname) + scan_carry_desc = lambda_sdfg.data(scan_carry_dataname) input_scan_carry_dataname = _scan_input_name(scan_carry_dataname) input_scan_carry_desc = scan_carry_desc.clone() - lambda_ctx.sdfg.add_datadesc(input_scan_carry_dataname, input_scan_carry_desc) + 
lambda_sdfg.add_datadesc(input_scan_carry_dataname, input_scan_carry_desc) scan_carry_desc.transient = True init_state.add_nedge( init_state.add_access(input_scan_carry_dataname), init_state.add_access(scan_carry_dataname), - lambda_ctx.sdfg.make_array_memlet(input_scan_carry_dataname), + lambda_sdfg.make_array_memlet(input_scan_carry_dataname), ) if isinstance(scan_carry_input, tuple): @@ -484,7 +484,7 @@ def connect_scan_output( else: raise NotImplementedError("scan with list output is not supported.") scan_result_data = scan_result.dc_node.data - scan_result_desc = scan_result.dc_node.desc(lambda_ctx.sdfg) + scan_result_desc = scan_result.dc_node.desc(lambda_sdfg) scan_result_subset = dace_subsets.Range.from_array(scan_result_desc) # `sym` represents the global output data, that is the nested-SDFG output connector @@ -493,7 +493,7 @@ def connect_scan_output( # Note that we set `transient=True` because the lowering expects the dataflow # of nested SDDFG to write to some internal temporary nodes. These data elements # should be turned into globals by the caller and handled as output connections. - lambda_ctx.sdfg.add_array(output, scan_output_shape, scan_result_desc.dtype, transient=True) + lambda_sdfg.add_array(output, scan_output_shape, scan_result_desc.dtype, transient=True) output_node = compute_state.add_access(output) # in the 'compute' state, we write the current vertical level data to the output field @@ -534,7 +534,7 @@ def connect_scan_output( if compute_state.degree(state_node) == 0: compute_state.remove_node(state_node) - return lambda_ctx, lambda_output + return compute_ctx, lambda_output def _handle_dataflow_result_of_nested_sdfg( From 66ec368356b06cbb996a6583904225d2cd0b6863 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 5 Dec 2025 12:50:53 +0100 Subject: [PATCH 03/28] fix --- .../runners/dace/lowering/gtir_to_sdfg_scan.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py index 9a56ef5117..204ee4a39c 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py @@ -523,17 +523,6 @@ def connect_scan_output( lambda_result, lambda_result_shape, scan_carry_input ) - # Corner case where the scan computation, on one level, does not depend on - # the result from previous level. In this case, the state information from - # previous level is not used, therefore we could find isolated access nodes. - # In case of tuples, it might be that only some of the fields are used. - # In case of scalars, this is probably a misuse of scan in application code: - # it could have been represented as a pure field operator. 
- for arg in gtx_utils.flatten_nested_tuple((stencil_args[0],)): - state_node = arg.dc_node - if compute_state.degree(state_node) == 0: - compute_state.remove_node(state_node) - return compute_ctx, lambda_output From 076a2b8e6b7121339e87071a1bf46b36961003a6 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 5 Dec 2025 12:58:40 +0100 Subject: [PATCH 04/28] edit --- .../program_processors/runners/dace/lowering/gtir_to_sdfg.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index 7aa864f952..505e4f85b9 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -1277,10 +1277,7 @@ def construct_output_for_nested_sdfg( nsdfg_node.remove_out_connector(inner_data.dc_node.data) inner_data = lambda_ctx.copy_data(inner_data, domain=None) nsdfg_node.add_out_connector(inner_data.dc_node.data) - elif lambda_ctx.state.degree(inner_data.dc_node) == 0: - # Isolated access node will make validation fail. - # Isolated access nodes can be found in the join-state of an if-expression. - lambda_ctx.state.remove_node(inner_data.dc_node) + # Transient data nodes only exist within the nested SDFG. In order to return some result data, # the corresponding data container inside the nested SDFG has to be changed to non-transient, # that is externally allocated, as required by the SDFG IR. An output edge will write the result From 28673276e7ed9d0336df43be0712577bc91c5209 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 5 Dec 2025 13:16:20 +0100 Subject: [PATCH 05/28] Use unique temp name across let-lambda scopes --- .../runners/dace/lowering/gtir_to_sdfg.py | 26 +++++++++++++------ .../dace/lowering/gtir_to_sdfg_primitives.py | 2 +- .../dace/lowering/gtir_to_sdfg_scan.py | 4 ++- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index 505e4f85b9..144faf7272 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -83,23 +83,27 @@ def unique_map_name(self, name: str) -> str: ... @abc.abstractmethod def unique_tasklet_name(self, name: str) -> str: ... + @abc.abstractmethod + def unique_temp_name(self) -> str: ... 
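+    # The temporary-storage helpers below use unique_temp_name() to keep transient names unique across let-lambda scopes.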
+ def add_temp_array( self, sdfg: dace.SDFG, shape: Sequence[Any], dtype: dace.dtypes.typeclass ) -> tuple[str, dace.data.Scalar]: """Add a temporary array to the SDFG.""" - return sdfg.add_temp_transient(shape, dtype) + temp_name = self.unique_temp_name() + return sdfg.add_transient(temp_name, shape, dtype) def add_temp_array_like( self, sdfg: dace.SDFG, datadesc: dace.data.Array ) -> tuple[str, dace.data.Scalar]: """Add a temporary array to the SDFG.""" - return sdfg.add_temp_transient_like(datadesc) + return sdfg.add_temp_transient_like(datadesc, name=self.unique_temp_name()) def add_temp_scalar( self, sdfg: dace.SDFG, dtype: dace.dtypes.typeclass ) -> tuple[str, dace.data.Scalar]: """Add a temporary scalar to the SDFG.""" - temp_name = sdfg.temp_data_name() + temp_name = self.unique_temp_name() return sdfg.add_scalar(temp_name, dtype, transient=True) def add_map( @@ -214,6 +218,7 @@ def remove_isolated_nodes(self) -> None: def copy_data( self, + sdfg_builder: SDFGBuilder, src: gtir_to_sdfg_types.FieldopData, domain: gtir_domain.FieldopDomain | None, ) -> gtir_to_sdfg_types.FieldopData: @@ -236,13 +241,13 @@ def copy_data( data_desc = src.dc_node.desc(self.sdfg) if isinstance(src.gt_type, ts.FieldType): if domain is None: - out, out_desc = self.sdfg.add_temp_transient_like(data_desc) + out, out_desc = sdfg_builder.add_temp_array_like(self.sdfg, data_desc) out_origin = list(src.origin) src_subset = ",".join(f"0:{size}" for size in data_desc.shape) else: out_dims, out_origin, out_shape = gtir_domain.get_field_layout(domain) assert out_dims == src.gt_type.dims - out, out_desc = self.sdfg.add_temp_transient(out_shape, data_desc.dtype) + out, out_desc = sdfg_builder.add_temp_array(self.sdfg, out_shape, data_desc.dtype) src_subset = ",".join( f"{dst_origin - src_origin}:{dst_origin - src_origin + size}" for dst_origin, src_origin, size in zip( @@ -252,7 +257,7 @@ def copy_data( else: assert domain is None assert isinstance(data_desc, dace.data.Scalar) - out, out_desc = self.sdfg.add_temp_transient_like(data_desc) + out, out_desc = sdfg_builder.add_temp_array_like(self.sdfg, data_desc) out_origin = [] src_subset = "0" @@ -801,6 +806,9 @@ def unique_map_name(self, name: str) -> str: def unique_tasklet_name(self, name: str) -> str: return f"{next(self.uids['tlet'])}_{name}" + def unique_temp_name(self) -> str: + return f"{next(self.uids['gtir_tmp'])}" + def _make_array_shape_and_strides( self, name: str, dims: Sequence[gtx_common.Dimension] ) -> tuple[list[dace.symbolic.SymbolicType], list[dace.symbolic.SymbolicType]]: @@ -966,7 +974,9 @@ def _visit_expression( if use_temp: # copy the full shape of global data to temporary storage return gtx_utils.tree_map( - lambda x: x if x.dc_node.desc(ctx.sdfg).transient else ctx.copy_data(x, domain=None) + lambda x: x + if x.dc_node.desc(ctx.sdfg).transient + else ctx.copy_data(self, x, domain=None) )(result) else: return result @@ -1275,7 +1285,7 @@ def construct_output_for_nested_sdfg( # example, when the lambda constructs a tuple of some input fields. # We copy this data to a new node, which we use as output. nsdfg_node.remove_out_connector(inner_data.dc_node.data) - inner_data = lambda_ctx.copy_data(inner_data, domain=None) + inner_data = lambda_ctx.copy_data(self, inner_data, domain=None) nsdfg_node.add_out_connector(inner_data.dc_node.data) # Transient data nodes only exist within the nested SDFG. 
In order to return some result data, diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py index 766ea63fff..0e4f7de2f6 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py @@ -269,7 +269,7 @@ def translate_as_fieldop( # on the given domain. It copies a subset of the source field. arg = sdfg_builder.visit(node.args[0], ctx=ctx) assert isinstance(arg, gtir_to_sdfg_types.FieldopData) - return ctx.copy_data(arg, domain=field_domain) + return ctx.copy_data(sdfg_builder, arg, domain=field_domain) elif isinstance(fieldop_expr, gtir.Lambda): # Default case, handled below: the argument expression is a lambda function # representing the stencil operation to be computed over the field domain. diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py index 204ee4a39c..5a149fe065 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py @@ -527,6 +527,7 @@ def connect_scan_output( def _handle_dataflow_result_of_nested_sdfg( + sdfg_builder: gtir_to_sdfg.SDFGBuilder, nsdfg_node: dace.nodes.NestedSDFG, inner_ctx: gtir_to_sdfg.SubgraphContext, outer_ctx: gtir_to_sdfg.SubgraphContext, @@ -542,7 +543,7 @@ def _handle_dataflow_result_of_nested_sdfg( # The field is used outside the nested SDFG, therefore it needs to be copied # to a temporary array in the parent SDFG (outer context). inner_desc.transient = False - outer_dataname, outer_desc = outer_ctx.sdfg.add_temp_transient_like(inner_desc) + outer_dataname, outer_desc = sdfg_builder.add_temp_array_like(outer_ctx.sdfg, inner_desc) outer_node = outer_ctx.state.add_access(outer_dataname) outer_ctx.state.add_edge( nsdfg_node, @@ -684,6 +685,7 @@ def translate_scan( # results of a column slice for each point in the horizontal domain output_tree = gtx_utils.tree_map( lambda output_data, output_domain: _handle_dataflow_result_of_nested_sdfg( + sdfg_builder=sdfg_builder, nsdfg_node=nsdfg_node, inner_ctx=lambda_ctx, outer_ctx=ctx, From 4174aa731afeb803133f4182b9ad839ceba39b6f Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 5 Dec 2025 13:50:59 +0100 Subject: [PATCH 06/28] fix --- .../program_processors/runners/dace/lowering/gtir_to_sdfg.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index 144faf7272..d8b2898b2f 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -1287,6 +1287,10 @@ def construct_output_for_nested_sdfg( nsdfg_node.remove_out_connector(inner_data.dc_node.data) inner_data = lambda_ctx.copy_data(self, inner_data, domain=None) nsdfg_node.add_out_connector(inner_data.dc_node.data) + elif lambda_ctx.state.degree(inner_data.dc_node) == 0: + # Isolated access node will make validation fail. + # Isolated access nodes can be found in the join-state of an if-expression. + lambda_ctx.state.remove_node(inner_data.dc_node) # Transient data nodes only exist within the nested SDFG. 
In order to return some result data, # the corresponding data container inside the nested SDFG has to be changed to non-transient, From f9e304c1cf57145d04a039dbfca3969125fa60ee Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 5 Dec 2025 14:01:42 +0100 Subject: [PATCH 07/28] edit --- .../runners/dace/lowering/gtir_to_sdfg.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index d8b2898b2f..82fdf531dd 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -1264,6 +1264,19 @@ def visit_Lambda( ctx.state.add_edge(src_node, None, nsdfg_node, input_connector, memlet) + # We can now safely remove all remaining arguments, because unused. Note + # that we only consider global access nodes, because the only goal of this + # cleanup is to remove isolated nodes. At this stage, temporary input nodes + # should not appear as isolated nodes, because they are supposed to contain + # the result of some argument expression. + if unused_access_nodes := [ + arg_node.dc_node + for arg_node in lambda_arg_nodes.values() + if not (arg_node is None or arg_node.dc_node.desc(ctx.sdfg).transient) + ]: + assert all(ctx.state.degree(access_node) == 0 for access_node in unused_access_nodes) + ctx.state.remove_nodes_from(unused_access_nodes) + def construct_output_for_nested_sdfg( inner_data: gtir_to_sdfg_types.FieldopData, ) -> gtir_to_sdfg_types.FieldopData: From 98c09081bc088e7537d8523bb9784c9da4215fd5 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 5 Dec 2025 14:10:12 +0100 Subject: [PATCH 08/28] edit --- .../runners/dace/lowering/gtir_to_sdfg.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index 82fdf531dd..14c9777675 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -975,7 +975,7 @@ def _visit_expression( if use_temp: # copy the full shape of global data to temporary storage return gtx_utils.tree_map( lambda x: x - if x.dc_node.desc(ctx.sdfg).transient + if x is None or x.dc_node.desc(ctx.sdfg).transient else ctx.copy_data(self, x, domain=None) )(result) else: @@ -1264,19 +1264,6 @@ def visit_Lambda( ctx.state.add_edge(src_node, None, nsdfg_node, input_connector, memlet) - # We can now safely remove all remaining arguments, because unused. Note - # that we only consider global access nodes, because the only goal of this - # cleanup is to remove isolated nodes. At this stage, temporary input nodes - # should not appear as isolated nodes, because they are supposed to contain - # the result of some argument expression. 
- if unused_access_nodes := [ - arg_node.dc_node - for arg_node in lambda_arg_nodes.values() - if not (arg_node is None or arg_node.dc_node.desc(ctx.sdfg).transient) - ]: - assert all(ctx.state.degree(access_node) == 0 for access_node in unused_access_nodes) - ctx.state.remove_nodes_from(unused_access_nodes) - def construct_output_for_nested_sdfg( inner_data: gtir_to_sdfg_types.FieldopData, ) -> gtir_to_sdfg_types.FieldopData: From 854c22e2efcc1140b6f83a497d5d8debac42f816 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 5 Dec 2025 16:49:36 +0100 Subject: [PATCH 09/28] edit --- .../runners/dace/lowering/gtir_to_sdfg.py | 105 +++++++++--------- .../dace/lowering/gtir_to_sdfg_scan.py | 2 +- 2 files changed, 55 insertions(+), 52 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index 14c9777675..169d5dcc82 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -208,11 +208,12 @@ def remove_isolated_nodes(self) -> None: ) else: flat_scope_symbols.append(im.sym(sym_name, sym_type)) - scope_symbols = set(str(sym.id) for sym in flat_scope_symbols) + data_names = set(str(sym.id) for sym in flat_scope_symbols) + assert len(data_names) == len(flat_scope_symbols) isolated_nodes = [ access_node for access_node in self.state.data_nodes() - if access_node.data in scope_symbols and self.state.degree(access_node) == 0 + if access_node.data in data_names and self.state.degree(access_node) == 0 ] self.state.remove_nodes_from(isolated_nodes) @@ -384,7 +385,9 @@ def setup_nested_sdfg( GTIR symbols defined in the parent scope. Returns: - The nested SDFG context. + - The new SDFG in which to lower a nested expression. + - The list of input parameters, which includes both the parameter list + of the given expression and symbols captured from the current context. """ ... @@ -1235,7 +1238,7 @@ def visit_Lambda( symbolic_inputs=set(symbolic_args.keys()), capture_scope_symbols=True, ) - lambda_state = lambda_sdfg.add_state("lambda") + lambda_state = lambda_sdfg.add_state("lambda", is_start_block=True) with self.setup_ctx(lambda_sdfg, lambda_state, lambda_params) as lambda_ctx: lambda_result = self.visit(node.expr, ctx=lambda_ctx) @@ -1264,55 +1267,55 @@ def visit_Lambda( ctx.state.add_edge(src_node, None, nsdfg_node, input_connector, memlet) - def construct_output_for_nested_sdfg( - inner_data: gtir_to_sdfg_types.FieldopData, - ) -> gtir_to_sdfg_types.FieldopData: - """ - This function makes a data container that lives inside a nested SDFG, denoted by `inner_data`, - available in the parent SDFG. - In order to achieve this, the data container inside the nested SDFG is marked as non-transient - (in other words, externally allocated - a requirement of the SDFG IR) and a new data container - is created within the parent SDFG, with the same properties (shape, stride, etc.) of `inner_data` - but appropriatly remapped using the symbol mapping table. - For lambda arguments that are simply returned by the lambda, the `inner_data` was already mapped - to a parent SDFG data container, therefore it can be directly accessed in the parent SDFG. - The same happens to symbols available in the lambda context but not explicitly passed as lambda - arguments, that are simply returned by the lambda: it can be directly accessed in the parent SDFG. 
- """ - if not inner_data.dc_node.desc(lambda_sdfg).transient: - # This is inout data: some global associated to an input connector - # is also a sink node of the lambda dataflow. This can happen, for - # example, when the lambda constructs a tuple of some input fields. - # We copy this data to a new node, which we use as output. - nsdfg_node.remove_out_connector(inner_data.dc_node.data) - inner_data = lambda_ctx.copy_data(self, inner_data, domain=None) - nsdfg_node.add_out_connector(inner_data.dc_node.data) - elif lambda_ctx.state.degree(inner_data.dc_node) == 0: - # Isolated access node will make validation fail. - # Isolated access nodes can be found in the join-state of an if-expression. - lambda_ctx.state.remove_node(inner_data.dc_node) - - # Transient data nodes only exist within the nested SDFG. In order to return some result data, - # the corresponding data container inside the nested SDFG has to be changed to non-transient, - # that is externally allocated, as required by the SDFG IR. An output edge will write the result - # from the nested-SDFG to a new intermediate data container allocated in the parent SDFG. - outer_data = ctx.map_nsdfg_field( - sdfg_builder=self, - nsdfg_field=inner_data, - nsdfg=lambda_sdfg, - symbol_mapping=nsdfg_node.symbol_mapping, - ) - ctx.state.add_edge( - nsdfg_node, - inner_data.dc_node.data, - outer_data.dc_node, - None, - ctx.sdfg.make_array_memlet(outer_data.dc_node.data), - ) + def construct_output_for_nested_sdfg( + inner_data: gtir_to_sdfg_types.FieldopData, + ) -> gtir_to_sdfg_types.FieldopData: + """ + This function makes a data container that lives inside a nested SDFG, denoted by `inner_data`, + available in the parent SDFG. + In order to achieve this, the data container inside the nested SDFG is marked as non-transient + (in other words, externally allocated - a requirement of the SDFG IR) and a new data container + is created within the parent SDFG, with the same properties (shape, stride, etc.) of `inner_data` + but appropriatly remapped using the symbol mapping table. + For lambda arguments that are simply returned by the lambda, the `inner_data` was already mapped + to a parent SDFG data container, therefore it can be directly accessed in the parent SDFG. + The same happens to symbols available in the lambda context but not explicitly passed as lambda + arguments, that are simply returned by the lambda: it can be directly accessed in the parent SDFG. + """ + if not inner_data.dc_node.desc(lambda_sdfg).transient: + # This is inout data: some global associated to an input connector + # is also a sink node of the lambda dataflow. This can happen, for + # example, when the lambda constructs a tuple of some input fields. + # We copy this data to a new node, which we use as output. + nsdfg_node.remove_out_connector(inner_data.dc_node.data) + inner_data = lambda_ctx.copy_data(self, inner_data, domain=None) + nsdfg_node.add_out_connector(inner_data.dc_node.data) + elif lambda_ctx.state.degree(inner_data.dc_node) == 0: + # Isolated access node will make validation fail. + # Isolated access nodes can be found in the join-state of an if-expression. + lambda_ctx.state.remove_node(inner_data.dc_node) + + # Transient data nodes only exist within the nested SDFG. In order to return some result data, + # the corresponding data container inside the nested SDFG has to be changed to non-transient, + # that is externally allocated, as required by the SDFG IR. 
An output edge will write the result + # from the nested-SDFG to a new intermediate data container allocated in the parent SDFG. + outer_data = ctx.map_nsdfg_field( + sdfg_builder=self, + nsdfg_field=inner_data, + nsdfg=lambda_sdfg, + symbol_mapping=nsdfg_node.symbol_mapping, + ) + ctx.state.add_edge( + nsdfg_node, + inner_data.dc_node.data, + outer_data.dc_node, + None, + ctx.sdfg.make_array_memlet(outer_data.dc_node.data), + ) - return outer_data + return outer_data - return gtx_utils.tree_map(construct_output_for_nested_sdfg)(lambda_result) + return gtx_utils.tree_map(construct_output_for_nested_sdfg)(lambda_result) def visit_Literal( self, diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py index 5a149fe065..8769c3fd57 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py @@ -362,7 +362,7 @@ def _lower_lambda_to_nested_sdfg( lambda_sdfg.using_explicit_control_flow = True # We use the entry state for initialization of the scan carry variable. - init_state = lambda_sdfg.add_state("init") + init_state = lambda_sdfg.add_state("init", is_start_block=True) # use the vertical dimension in the domain as scan dimension scan_domain = next(r for r in field_domain if sdfg_builder.is_column_axis(r.dim)) From 3a788a708a4e49f9d44ee823593ec5fcf9d7c144 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 5 Dec 2025 17:25:18 +0100 Subject: [PATCH 10/28] edit --- .../runners/dace/lowering/gtir_to_sdfg.py | 8 +++----- .../dace/lowering/gtir_to_sdfg_primitives.py | 14 +++++--------- .../runners/dace/workflow/translation.py | 1 + 3 files changed, 9 insertions(+), 14 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index 169d5dcc82..cbb929321a 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -399,7 +399,7 @@ def setup_ctx( state: dace.SDFGState, params: Iterable[gtir.Sym], ) -> Generator[SubgraphContext, None, None]: - "Allocate scope symbols for given parameters in context." + "Setup a lowering context with the given parameters in scope." ... @abc.abstractmethod @@ -409,7 +409,7 @@ def setup_ctx_in_new_state( ctx: SubgraphContext, state: dace.SDFGState, ) -> Generator[SubgraphContext, Any, Any]: - "Allocate scope symbols of the given context inside a new state." + "Setup same lowering context inside a new state, by copying all symbols in current scope." ... @abc.abstractmethod @@ -776,9 +776,8 @@ def setup_ctx( state: dace.SDFGState, params: Iterable[gtir.Sym], ) -> Generator[SubgraphContext, Any, Any]: - "Create access nodes to global data for given parameters in context." + assert all(isinstance(p.type, ts.DataType) for p in params) scope_symbols = {str(p.id): p.type for p in params} - assert all(isinstance(sym_type, ts.DataType) for sym_type in scope_symbols.values()) ctx = SubgraphContext(sdfg, state, scope_symbols) # type: ignore[arg-type] try: yield ctx @@ -1294,7 +1293,6 @@ def construct_output_for_nested_sdfg( # Isolated access node will make validation fail. # Isolated access nodes can be found in the join-state of an if-expression. 
lambda_ctx.state.remove_node(inner_data.dc_node) - # Transient data nodes only exist within the nested SDFG. In order to return some result data, # the corresponding data container inside the nested SDFG has to be changed to non-transient, # that is externally allocated, as required by the SDFG IR. An output edge will write the result diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py index 0e4f7de2f6..7b0534160e 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py @@ -655,16 +655,12 @@ def translate_scalar_expr( scalar_expr_args = [] for i, arg_expr in enumerate(node.args): - visit_expr = True if isinstance(arg_expr, gtir.SymRef): - try: - # check if symbol is defined in the GT4Py program, throws `KeyError` exception if undefined - sym_name = str(arg_expr.id) - ctx.scope_symbols[sym_name] - except KeyError: - # all `SymRef` should refer to symbols defined in the program, except in case of non-variable argument, - # e.g. the type name `float64` used in casting expressions like `cast_(variable, float64)` - visit_expr = False + # all `SymRef` should refer to symbols defined in the program, except in case of non-variable argument, + # e.g. the type name `float64` used in casting expressions like `cast_(variable, float64)` + visit_expr = str(arg_expr.id) in ctx.scope_symbols + else: + visit_expr = True if visit_expr: # we visit the argument expression and obtain the access node to diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index e45132bc60..6edd93ae41 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -381,6 +381,7 @@ def _generate_sdfg_without_configuring_dace( ir, sdfg, offset_provider_type, self.disable_field_origin_on_program_arguments ) + return sdfg if self.auto_optimize: auto_optimize_args = {} if self.auto_optimize_args is None else self.auto_optimize_args From 859b43c80c3fe36d7c913b0ea67de78f55458283 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 5 Dec 2025 17:34:58 +0100 Subject: [PATCH 11/28] undo extra change --- .../next/program_processors/runners/dace/workflow/translation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index 6edd93ae41..e45132bc60 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -381,7 +381,6 @@ def _generate_sdfg_without_configuring_dace( ir, sdfg, offset_provider_type, self.disable_field_origin_on_program_arguments ) - return sdfg if self.auto_optimize: auto_optimize_args = {} if self.auto_optimize_args is None else self.auto_optimize_args From b3437c1adb424a4284b7401b428ffcc04d794a52 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Sat, 6 Dec 2025 10:27:17 +0100 Subject: [PATCH 12/28] undo extra change --- .../runners/dace/lowering/gtir_to_sdfg.py | 198 ++++++------------ .../dace/lowering/gtir_to_sdfg_primitives.py | 40 ++-- .../dace/lowering/gtir_to_sdfg_scan.py | 55 +++-- 3 files changed, 118 insertions(+), 
175 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index cbb929321a..9e17883ff0 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -15,21 +15,8 @@ from __future__ import annotations import abc -import contextlib import dataclasses -from typing import ( - Any, - Dict, - Generator, - Iterable, - List, - Mapping, - Optional, - Protocol, - Sequence, - Tuple, - Union, -) +from typing import Any, Dict, Iterable, List, Mapping, Optional, Protocol, Sequence, Tuple, Union import dace from dace import subsets as dace_subsets @@ -39,11 +26,7 @@ from gt4py.eve import concepts from gt4py.next import common as gtx_common, config as gtx_config, utils as gtx_utils from gt4py.next.iterator import ir as gtir -from gt4py.next.iterator.ir_utils import ( - common_pattern_matcher as cpm, - domain_utils, - ir_makers as im, -) +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils from gt4py.next.iterator.transforms import prune_casts as ir_prune_casts, symbol_ref_utils from gt4py.next.iterator.type_system import inference as gtir_type_inference from gt4py.next.program_processors.runners.dace import sdfg_args as gtx_dace_args @@ -199,24 +182,6 @@ class SubgraphContext: state: dace.SDFGState scope_symbols: dict[str, ts.DataType] - def remove_isolated_nodes(self) -> None: - flat_scope_symbols: list[gtir.Sym] = [] - for sym_name, sym_type in self.scope_symbols.items(): - if isinstance(sym_type, ts.TupleType): - flat_scope_symbols.extend( - gtir_to_sdfg_utils.flatten_tuple_fields(sym_name, sym_type) - ) - else: - flat_scope_symbols.append(im.sym(sym_name, sym_type)) - data_names = set(str(sym.id) for sym in flat_scope_symbols) - assert len(data_names) == len(flat_scope_symbols) - isolated_nodes = [ - access_node - for access_node in self.state.data_nodes() - if access_node.data in data_names and self.state.degree(access_node) == 0 - ] - self.state.remove_nodes_from(isolated_nodes) - def copy_data( self, sdfg_builder: SDFGBuilder, @@ -356,7 +321,7 @@ def is_column_axis(self, dim: gtx_common.Dimension) -> bool: ... @abc.abstractmethod - def setup_nested_sdfg( + def setup_nested_context( self, expr: gtir.Lambda, sdfg_name: str, @@ -364,7 +329,7 @@ def setup_nested_sdfg( params: Iterable[gtir.Sym], symbolic_inputs: set[str], capture_scope_symbols: bool, - ) -> tuple[dace.SDFG, list[gtir.Sym]]: + ) -> SubgraphContext: """ Create an nested SDFG context to lower a lambda expression, indipendent from the current context where the parent expression is being translated. @@ -385,33 +350,10 @@ def setup_nested_sdfg( GTIR symbols defined in the parent scope. Returns: - - The new SDFG in which to lower a nested expression. - - The list of input parameters, which includes both the parameter list - of the given expression and symbols captured from the current context. + The new SDFG context in which to lower a nested expression. """ ... - @abc.abstractmethod - @contextlib.contextmanager - def setup_ctx( - self, - sdfg: dace.SDFG, - state: dace.SDFGState, - params: Iterable[gtir.Sym], - ) -> Generator[SubgraphContext, None, None]: - "Setup a lowering context with the given parameters in scope." - ... 
- - @abc.abstractmethod - @contextlib.contextmanager - def setup_ctx_in_new_state( - self, - ctx: SubgraphContext, - state: dace.SDFGState, - ) -> Generator[SubgraphContext, Any, Any]: - "Setup same lowering context inside a new state, by copying all symbols in current scope." - ... - @abc.abstractmethod def add_nested_sdfg( self, @@ -584,7 +526,7 @@ def is_column_axis(self, dim: gtx_common.Dimension) -> bool: assert self.column_axis return dim == self.column_axis - def setup_nested_sdfg( + def setup_nested_context( self, expr: gtir.Lambda, sdfg_name: str, @@ -592,7 +534,7 @@ def setup_nested_sdfg( params: Iterable[gtir.Sym], symbolic_inputs: set[str], capture_scope_symbols: bool, - ) -> tuple[dace.SDFG, list[gtir.Sym]]: + ) -> SubgraphContext: assert symbolic_inputs.issubset(str(p.id) for p in params) and all( isinstance(p.type, ts.ScalarType) for p in params if str(p.id) in symbolic_inputs ) @@ -671,7 +613,9 @@ def setup_nested_sdfg( symbolic_params=(domain_symbols | mapped_symbols | symbolic_inputs), use_transient_storage=True, ) - return sdfg, input_params + + state = sdfg.add_state(f"{sdfg_name}_entry") + return SubgraphContext(sdfg, state, lambda_symbols) def add_nested_sdfg( self, @@ -769,33 +713,6 @@ def add_nested_sdfg( return nsdfg_node, input_memlets - @contextlib.contextmanager - def setup_ctx( - self, - sdfg: dace.SDFG, - state: dace.SDFGState, - params: Iterable[gtir.Sym], - ) -> Generator[SubgraphContext, Any, Any]: - assert all(isinstance(p.type, ts.DataType) for p in params) - scope_symbols = {str(p.id): p.type for p in params} - ctx = SubgraphContext(sdfg, state, scope_symbols) # type: ignore[arg-type] - try: - yield ctx - finally: - ctx.remove_isolated_nodes() - - @contextlib.contextmanager - def setup_ctx_in_new_state( - self, - ctx: SubgraphContext, - state: dace.SDFGState, - ) -> Generator[SubgraphContext, Any, Any]: - new_ctx = SubgraphContext(ctx.sdfg, state, ctx.scope_symbols.copy()) - try: - yield new_ctx - finally: - new_ctx.remove_isolated_nodes() - def unique_nsdfg_name(self, sdfg: dace.SDFG, prefix: str) -> str: nsdfg_list = [ nsdfg.label for nsdfg in sdfg.all_sdfgs_recursive() if nsdfg.label.startswith(prefix) @@ -1060,18 +977,17 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: # We do this to allow lowering of scalar expressions in let-statements, # that only depend on scalar parameters, as dace symbolic expressions # mapped to symbols on a nested SDFG. 
- assert all(isinstance(p.type, ts.DataType) for p in node.params) sdfg_arg_names = self._add_sdfg_params( sdfg, node.params, symbolic_params=None, use_transient_storage=False ) # visit one statement at a time and expand the SDFG from the current head state + scope_symbols = {str(p.id): p.type for p in node.params if isinstance(p.type, ts.DataType)} for i, stmt in enumerate(node.body): # include `debuginfo` only for `ir.Program` and `ir.Stmt` nodes: finer granularity would be too messy head_state = sdfg.add_state_after(head_state, f"stmt_{i}") head_state._debuginfo = gtir_to_sdfg_utils.debug_info(stmt, default=sdfg.debuginfo) - with self.setup_ctx(sdfg, head_state, node.params) as ctx: - head_state = self.visit(stmt, ctx=ctx) + head_state = self.visit(stmt, ctx=SubgraphContext(sdfg, head_state, scope_symbols)) # remove unused connectivity tables (by design, arrays are marked as non-transient when they are used) for nsdfg in sdfg.all_sdfgs_recursive(): @@ -1134,15 +1050,18 @@ def _visit_target( # In order to support inout argument, write the result in separate next state # this is needed to avoid undefined behavior for expressions like: X, Y = X + 1, X target_state = ctx.sdfg.add_state_after(ctx.state, f"post_{ctx.state.label}") - with self.setup_ctx_in_new_state(ctx, target_state) as target_ctx: - # The target expression could be a `SymRef` to an output field or a `make_tuple` - # expression in case the statement returns more than one field. - target_tree = self._visit_expression(stmt.target, target_ctx, use_temp=False) - gtx_utils.tree_map( - lambda source, target, target_domain: _visit_target( - source, target, target_domain, target_state - ) - )(source_tree, target_tree, domain) + # The target expression could be a `SymRef` to an output field or a `make_tuple` + # expression in case the statement returns more than one field. + target_tree = self._visit_expression( + stmt.target, + ctx=SubgraphContext(ctx.sdfg, target_state, ctx.scope_symbols), + use_temp=False, + ) + gtx_utils.tree_map( + lambda source, target, target_domain: _visit_target( + source, target, target_domain, target_state + ) + )(source_tree, target_tree, domain) return target_state @@ -1229,7 +1148,7 @@ def visit_Lambda( lambda_arg_nodes, symbolic_args = flatten_tuple_args(args.items()) # lower let-statement lambda node as a nested SDFG - lambda_sdfg, lambda_params = self.setup_nested_sdfg( + lambda_ctx = self.setup_nested_context( expr=node, sdfg_name="lambda", parent_ctx=ctx, @@ -1237,34 +1156,45 @@ def visit_Lambda( symbolic_inputs=set(symbolic_args.keys()), capture_scope_symbols=True, ) - lambda_state = lambda_sdfg.add_state("lambda", is_start_block=True) - - with self.setup_ctx(lambda_sdfg, lambda_state, lambda_params) as lambda_ctx: - lambda_result = self.visit(node.expr, ctx=lambda_ctx) - - # A let-lambda is allowed to capture GTIR-symbols from the outer scope, - # therefore we call `add_nested_sdfg()` with `capture_outer_data=True`. - nsdfg_node, input_memlets = self.add_nested_sdfg( - node=node, - inner_ctx=lambda_ctx, - outer_ctx=ctx, - symbolic_args=symbolic_args, - data_args=lambda_arg_nodes, - inner_result=lambda_result, - capture_outer_data=True, - ) - # In this loop we call `pop()`, whenever an argument is connected to an input - # connector on the nested SDFG, so the corresponding node is removed from the dictionary. 
- for input_connector, memlet in input_memlets.items(): - if input_connector in lambda_arg_nodes: - arg_node = lambda_arg_nodes.pop(input_connector) - assert arg_node is not None - src_node = arg_node.dc_node - else: - src_node = ctx.state.add_access(memlet.data) + lambda_result = self.visit(node.expr, ctx=lambda_ctx) + + # A let-lambda is allowed to capture GTIR-symbols from the outer scope, + # therefore we call `add_nested_sdfg()` with `capture_outer_data=True`. + nsdfg_node, input_memlets = self.add_nested_sdfg( + node=node, + inner_ctx=lambda_ctx, + outer_ctx=ctx, + symbolic_args=symbolic_args, + data_args=lambda_arg_nodes, + inner_result=lambda_result, + capture_outer_data=True, + ) - ctx.state.add_edge(src_node, None, nsdfg_node, input_connector, memlet) + # In this loop we call `pop()`, whenever an argument is connected to an input + # connector on the nested SDFG, so the corresponding node is removed from the dictionary. + for input_connector, memlet in input_memlets.items(): + if input_connector in lambda_arg_nodes: + arg_node = lambda_arg_nodes.pop(input_connector) + assert arg_node is not None + src_node = arg_node.dc_node + else: + src_node = ctx.state.add_access(memlet.data) + + ctx.state.add_edge(src_node, None, nsdfg_node, input_connector, memlet) + + # We can now safely remove all remaining arguments, because unused. Note + # that we only consider global access nodes, because the only goal of this + # cleanup is to remove isolated nodes. At this stage, temporary input nodes + # should not appear as isolated nodes, because they are supposed to contain + # the result of some argument expression. + if unused_access_nodes := [ + arg_node.dc_node + for arg_node in lambda_arg_nodes.values() + if not (arg_node is None or arg_node.dc_node.desc(ctx.sdfg).transient) + ]: + assert all(ctx.state.degree(access_node) == 0 for access_node in unused_access_nodes) + ctx.state.remove_nodes_from(unused_access_nodes) def construct_output_for_nested_sdfg( inner_data: gtir_to_sdfg_types.FieldopData, @@ -1281,7 +1211,7 @@ def construct_output_for_nested_sdfg( The same happens to symbols available in the lambda context but not explicitly passed as lambda arguments, that are simply returned by the lambda: it can be directly accessed in the parent SDFG. """ - if not inner_data.dc_node.desc(lambda_sdfg).transient: + if not inner_data.dc_node.desc(lambda_ctx.sdfg).transient: # This is inout data: some global associated to an input connector # is also a sink node of the lambda dataflow. This can happen, for # example, when the lambda constructs a tuple of some input fields. 
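As an aside on the cleanup of unused lambda arguments above, the sketch below (illustrative only, not part of the diff; it reuses the `ir_makers` helpers exactly as the unit tests added later in this series do) shows the kind of let-lambda whose bound-but-unreferenced argument would otherwise leave an isolated access node behind:

    # Hypothetical sketch: "yᐞ1" is bound by the let but never referenced in the
    # body, so the access node created for it in the parent state stays isolated
    # and is removed by the cleanup described above.
    from gt4py.next.iterator.ir_utils import ir_makers as im

    unused_arg_expr = im.let(
        ("xᐞ1", im.op_as_fieldop("multiplies")("x", 3.0)),  # referenced in the body
        ("yᐞ1", "y"),  # never referenced in the body
    )(im.op_as_fieldop("multiplies")("xᐞ1", 2.0))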
@@ -1300,7 +1230,7 @@ def construct_output_for_nested_sdfg( outer_data = ctx.map_nsdfg_field( sdfg_builder=self, nsdfg_field=inner_data, - nsdfg=lambda_sdfg, + nsdfg=lambda_ctx.sdfg, symbol_mapping=nsdfg_node.symbol_mapping, ) ctx.state.add_edge( diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py index 7b0534160e..4cd5ac0396 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py @@ -426,34 +426,34 @@ def translate_if( # expect true branch as second argument true_state = ctx.sdfg.add_state(ctx.state.label + "_true_branch") + tbranch_ctx = gtir_to_sdfg.SubgraphContext(ctx.sdfg, true_state, ctx.scope_symbols) ctx.sdfg.add_edge(cond_state, true_state, dace.InterstateEdge(condition=if_stmt)) ctx.sdfg.add_edge(true_state, ctx.state, dace.InterstateEdge()) # and false branch as third argument false_state = ctx.sdfg.add_state(ctx.state.label + "_false_branch") + fbranch_ctx = gtir_to_sdfg.SubgraphContext(ctx.sdfg, false_state, ctx.scope_symbols) ctx.sdfg.add_edge(cond_state, false_state, dace.InterstateEdge(condition=f"not({if_stmt})")) ctx.sdfg.add_edge(false_state, ctx.state, dace.InterstateEdge()) - with sdfg_builder.setup_ctx_in_new_state(ctx, true_state) as tbranch_ctx: - with sdfg_builder.setup_ctx_in_new_state(ctx, false_state) as fbranch_ctx: - true_br_result = sdfg_builder.visit(true_expr, ctx=tbranch_ctx) - false_br_result = sdfg_builder.visit(false_expr, ctx=fbranch_ctx) - - node_output = gtx_utils.tree_map( - lambda domain, true_br, false_br: _construct_if_branch_output( - ctx, sdfg_builder, domain, true_br, false_br - ) - )( - node.annex.domain, - true_br_result, - false_br_result, - ) - gtx_utils.tree_map(lambda src, dst: _write_if_branch_output(tbranch_ctx, src, dst))( - true_br_result, node_output - ) - gtx_utils.tree_map(lambda src, dst: _write_if_branch_output(fbranch_ctx, src, dst))( - false_br_result, node_output - ) + true_br_result = sdfg_builder.visit(true_expr, ctx=tbranch_ctx) + false_br_result = sdfg_builder.visit(false_expr, ctx=fbranch_ctx) + + node_output = gtx_utils.tree_map( + lambda domain, true_br, false_br: _construct_if_branch_output( + ctx, sdfg_builder, domain, true_br, false_br + ) + )( + node.annex.domain, + true_br_result, + false_br_result, + ) + gtx_utils.tree_map(lambda src, dst: _write_if_branch_output(tbranch_ctx, src, dst))( + true_br_result, node_output + ) + gtx_utils.tree_map(lambda src, dst: _write_if_branch_output(fbranch_ctx, src, dst))( + false_br_result, node_output + ) return node_output diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py index 8769c3fd57..dc60d352f9 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py @@ -347,7 +347,7 @@ def _lower_lambda_to_nested_sdfg( ) ) # the lambda expression, i.e. body of the scan, will be created inside a nested SDFG. - lambda_sdfg, _ = sdfg_builder.setup_nested_sdfg( + lambda_ctx = sdfg_builder.setup_nested_context( lambda_node, "scan", ctx, @@ -359,10 +359,10 @@ def _lower_lambda_to_nested_sdfg( # We set `using_explicit_control_flow=True` because the vertical scan is lowered to a `LoopRegion`. 
# This property is used by pattern matching in SDFG transformation framework # to skip those transformations that do not yet support control flow blocks. - lambda_sdfg.using_explicit_control_flow = True + lambda_ctx.sdfg.using_explicit_control_flow = True # We use the entry state for initialization of the scan carry variable. - init_state = lambda_sdfg.add_state("init", is_start_block=True) + init_state = lambda_ctx.state # use the vertical dimension in the domain as scan dimension scan_domain = next(r for r in field_domain if sdfg_builder.is_column_axis(r.dim)) @@ -410,27 +410,29 @@ def get_scan_output_shape( update_expr=f"{scan_loop_var} = {scan_loop_var} - 1", inverted=False, ) - lambda_sdfg.add_node(scan_loop, ensure_unique_name=True) - lambda_sdfg.add_edge(init_state, scan_loop, dace.InterstateEdge()) + lambda_ctx.sdfg.add_node(scan_loop, ensure_unique_name=True) + lambda_ctx.sdfg.add_edge(init_state, scan_loop, dace.InterstateEdge()) # Inside the loop region, create a 'compute' and an 'update' state. # The body of the 'compute' state implements the stencil expression for one vertical level. # The 'update' state writes the value computed by the stencil into the scan carry variable, # in order to make it available to the next vertical level. compute_state = scan_loop.add_state("scan_compute") + compute_ctx = gtir_to_sdfg.SubgraphContext( + lambda_ctx.sdfg, compute_state, lambda_ctx.scope_symbols + ) update_state = scan_loop.add_state_after(compute_state, "scan_update") - with sdfg_builder.setup_ctx(lambda_sdfg, compute_state, lambda_params) as compute_ctx: - # inside the 'compute' state, visit the list of arguments to be passed to the stencil - stencil_args = [ - _parse_scan_fieldop_arg(im.ref(p.id), compute_ctx, sdfg_builder, field_domain) - for p in lambda_node.params - ] - # stil inside the 'compute' state, generate the dataflow representing the stencil - # to be applied on the horizontal domain - lambda_input_edges, lambda_result = gtir_dataflow.translate_lambda_to_dataflow( - compute_ctx.sdfg, compute_ctx.state, sdfg_builder, lambda_node, stencil_args - ) + # inside the 'compute' state, visit the list of arguments to be passed to the stencil + stencil_args = [ + _parse_scan_fieldop_arg(im.ref(p.id), compute_ctx, sdfg_builder, field_domain) + for p in lambda_node.params + ] + # stil inside the 'compute' state, generate the dataflow representing the stencil + # to be applied on the horizontal domain + lambda_input_edges, lambda_result = gtir_dataflow.translate_lambda_to_dataflow( + compute_ctx.sdfg, compute_ctx.state, sdfg_builder, lambda_node, stencil_args + ) # connect the dataflow input directly to the source data nodes, without passing through a map node; # the reason is that the map for horizontal domain is outside the scan loop region for edge in lambda_input_edges: @@ -451,15 +453,15 @@ def get_scan_output_shape( def init_scan_carry(sym: gtir.Sym) -> None: scan_carry_dataname = str(sym.id) - scan_carry_desc = lambda_sdfg.data(scan_carry_dataname) + scan_carry_desc = lambda_ctx.sdfg.data(scan_carry_dataname) input_scan_carry_dataname = _scan_input_name(scan_carry_dataname) input_scan_carry_desc = scan_carry_desc.clone() - lambda_sdfg.add_datadesc(input_scan_carry_dataname, input_scan_carry_desc) + lambda_ctx.sdfg.add_datadesc(input_scan_carry_dataname, input_scan_carry_desc) scan_carry_desc.transient = True init_state.add_nedge( init_state.add_access(input_scan_carry_dataname), init_state.add_access(scan_carry_dataname), - 
lambda_sdfg.make_array_memlet(input_scan_carry_dataname),
+            lambda_ctx.sdfg.make_array_memlet(input_scan_carry_dataname),
         )
 
     if isinstance(scan_carry_input, tuple):
@@ -484,7 +486,7 @@ def connect_scan_output(
         else:
             raise NotImplementedError("scan with list output is not supported.")
         scan_result_data = scan_result.dc_node.data
-        scan_result_desc = scan_result.dc_node.desc(lambda_sdfg)
+        scan_result_desc = scan_result.dc_node.desc(lambda_ctx.sdfg)
         scan_result_subset = dace_subsets.Range.from_array(scan_result_desc)
 
         # `sym` represents the global output data, that is the nested-SDFG output connector
@@ -493,7 +495,7 @@
         # Note that we set `transient=True` because the lowering expects the dataflow
         # of nested SDDFG to write to some internal temporary nodes. These data elements
         # should be turned into globals by the caller and handled as output connections.
-        lambda_sdfg.add_array(output, scan_output_shape, scan_result_desc.dtype, transient=True)
+        lambda_ctx.sdfg.add_array(output, scan_output_shape, scan_result_desc.dtype, transient=True)
         output_node = compute_state.add_access(output)
 
         # in the 'compute' state, we write the current vertical level data to the output field
@@ -523,6 +525,17 @@
         lambda_result, lambda_result_shape, scan_carry_input
     )
 
+    # Corner case where the scan computation, on one level, does not depend on
+    # the result from the previous level. In this case, the state information from
+    # the previous level is not used, therefore we could find isolated access nodes.
+    # In case of tuples, it might be that only some of the fields are used.
+    # In case of scalars, this is probably a misuse of scan in application code:
+    # it could have been represented as a pure field operator.
+    for arg in gtx_utils.flatten_nested_tuple((stencil_args[0],)):
+        state_node = arg.dc_node
+        if compute_state.degree(state_node) == 0:
+            compute_state.remove_node(state_node)
+
     return compute_ctx, lambda_output


From 6114c1341245bafb9a42306d841abcecabcea4aa Mon Sep 17 00:00:00 2001
From: Edoardo Paone 
Date: Sat, 6 Dec 2025 10:32:31 +0100
Subject: [PATCH 13/28] edit doc string

---
 .../runners/dace/lowering/gtir_to_sdfg.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py
index 9e17883ff0..3db132574b 100644
--- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py
+++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py
@@ -1012,11 +1012,13 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG:
     def visit_SetAt(self, stmt: gtir.SetAt, ctx: SubgraphContext) -> dace.SDFGState:
         """Visits a `SetAt` statement expression and writes the local result to some external storage.
 
-        Each statement expression results in some sort of dataflow gragh writing to temporary storage.
-        The translation of `SetAt` ensures that the result is written back to the target external storage.
+        Each statement expression results in some sort of dataflow graph writing to temporary
+        storage, inside the current state. The translation of `SetAt` ensures that the result
+        is written back to the target external storage, in a new state after the current one,
+        which will become the new head state.
 
         Returns:
-            The SDFG head state, eventually updated if the target write requires a new state.
+            The new SDFG head state, where the target fields are written.
""" def _visit_target( From 2dbda35f0721d87a9ff3a369be2d2509c780d6c2 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 8 Dec 2025 21:27:02 +0100 Subject: [PATCH 14/28] skip dace state fusion --- .../runners/dace/transformations/simplify.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py index 8224cae965..f7f5bbfd11 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py @@ -126,7 +126,7 @@ def gt_simplify( validate=False, validate_all=validate_all, verbose=False, - skip=(skip | {"InlineSDFGs"}), + skip=(skip | {"FuseStates", "InlineSDFGs"}), ).apply_pass(sdfg, {}) if simplify_res is not None: @@ -134,10 +134,9 @@ def gt_simplify( result = result or {} result.update(simplify_res) - # Note that it is not nice that we run the state fusion twice, but to be fully - # effective there are some preparatory transformations that are run in DaCe - # simplify. So the GT4Py transformation is more like a clean up to handle - # the parts DaCe is not able to do. + # The DaCe state fusion pass was skipped above, and we rely on GT4Py state + # fusion, because of incorrect SDFG results (we observed that double buffer + # was removed, where needed). if "FuseStates" not in skip: fuse_state_res = sdfg.apply_transformations_repeated( [gtx_transformations.GT4PyStateFusion], From c7c6ba60e914c99d7c63cef06b247c0aa63b3c0f Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 9 Dec 2025 12:35:10 +0100 Subject: [PATCH 15/28] Revert "skip dace state fusion" This reverts commit 2dbda35f0721d87a9ff3a369be2d2509c780d6c2. --- .../runners/dace/transformations/simplify.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py index f7f5bbfd11..8224cae965 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py @@ -126,7 +126,7 @@ def gt_simplify( validate=False, validate_all=validate_all, verbose=False, - skip=(skip | {"FuseStates", "InlineSDFGs"}), + skip=(skip | {"InlineSDFGs"}), ).apply_pass(sdfg, {}) if simplify_res is not None: @@ -134,9 +134,10 @@ def gt_simplify( result = result or {} result.update(simplify_res) - # The DaCe state fusion pass was skipped above, and we rely on GT4Py state - # fusion, because of incorrect SDFG results (we observed that double buffer - # was removed, where needed). + # Note that it is not nice that we run the state fusion twice, but to be fully + # effective there are some preparatory transformations that are run in DaCe + # simplify. So the GT4Py transformation is more like a clean up to handle + # the parts DaCe is not able to do. 
if "FuseStates" not in skip: fuse_state_res = sdfg.apply_transformations_repeated( [gtx_transformations.GT4PyStateFusion], From 30d2436874fba9d66bb5f30a8f17569ea811f96f Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 15 Dec 2025 10:21:36 +0100 Subject: [PATCH 16/28] cleanup arg nodes in lambda SDFG --- .../runners/dace/lowering/gtir_to_sdfg.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index 91749d3760..b4aad18268 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -1185,6 +1185,19 @@ def visit_Lambda( ctx.state.add_edge(src_node, None, nsdfg_node, input_connector, memlet) + # We can now safely remove all remaining arguments, because unused. + if unused_access_nodes := [ + arg_node.dc_node for arg_node in lambda_arg_nodes.values() if arg_node is not None + ]: + # We expect only global access nodes, at this stage: temporary input nodes + # should not appear as isolated nodes because they are supposed to contain + # the result of some argument expression. + assert all( + ctx.state.degree(access_node) == 0 and not access_node.desc(ctx.sdfg).transient + for access_node in unused_access_nodes + ) + ctx.state.remove_nodes_from(unused_access_nodes) + def construct_output_for_nested_sdfg( inner_data: gtir_to_sdfg_types.FieldopData, ) -> gtir_to_sdfg_types.FieldopData: From 109fd3f2b0f32704a5f6aa3822ff4879adb08f76 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 15 Dec 2025 11:15:14 +0100 Subject: [PATCH 17/28] relax assert on unused args in lambda SDFG --- .../runners/dace/lowering/gtir_to_sdfg.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index b4aad18268..40efd20dcc 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -1189,13 +1189,7 @@ def visit_Lambda( if unused_access_nodes := [ arg_node.dc_node for arg_node in lambda_arg_nodes.values() if arg_node is not None ]: - # We expect only global access nodes, at this stage: temporary input nodes - # should not appear as isolated nodes because they are supposed to contain - # the result of some argument expression. 
- assert all( - ctx.state.degree(access_node) == 0 and not access_node.desc(ctx.sdfg).transient - for access_node in unused_access_nodes - ) + assert all(ctx.state.degree(access_node) == 0 for access_node in unused_access_nodes) ctx.state.remove_nodes_from(unused_access_nodes) def construct_output_for_nested_sdfg( From bac4ddcdd35fd90137c5ebd1c058237bf76dc7d8 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 15 Dec 2025 13:28:32 +0100 Subject: [PATCH 18/28] edit --- .../runners/dace/lowering/gtir_to_sdfg.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index 40efd20dcc..3db132574b 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -1185,9 +1185,15 @@ def visit_Lambda( ctx.state.add_edge(src_node, None, nsdfg_node, input_connector, memlet) - # We can now safely remove all remaining arguments, because unused. + # We can now safely remove all remaining arguments, because unused. Note + # that we only consider global access nodes, because the only goal of this + # cleanup is to remove isolated nodes. At this stage, temporary input nodes + # should not appear as isolated nodes, because they are supposed to contain + # the result of some argument expression. if unused_access_nodes := [ - arg_node.dc_node for arg_node in lambda_arg_nodes.values() if arg_node is not None + arg_node.dc_node + for arg_node in lambda_arg_nodes.values() + if not (arg_node is None or arg_node.dc_node.desc(ctx.sdfg).transient) ]: assert all(ctx.state.degree(access_node) == 0 for access_node in unused_access_nodes) ctx.state.remove_nodes_from(unused_access_nodes) From 980afb22f2c247836c10d3058381d4818265f9a9 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 16 Dec 2025 22:21:40 +0100 Subject: [PATCH 19/28] edit --- .../runners/dace/lowering/gtir_to_sdfg.py | 7 ++++ .../dace_tests/test_gtir_to_sdfg.py | 32 +++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index cab489f3f8..e915dc10d6 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -1235,6 +1235,13 @@ def visit_Lambda( ctx.state.add_edge(src_node, None, nsdfg_node, input_connector, memlet) + # We can now safely remove all remaining arguments, because unused. 
+ if unused_access_nodes := [ + arg_node.dc_node for arg_node in lambda_arg_nodes.values() if arg_node is not None + ]: + assert all(ctx.state.degree(access_node) == 0 for access_node in unused_access_nodes) + ctx.state.remove_nodes_from(unused_access_nodes) + def construct_output_for_nested_sdfg( inner_data: gtir_to_sdfg_types.FieldopData, ) -> gtir_to_sdfg_types.FieldopData: diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index cd19c98f5b..e935fca2c9 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -1647,6 +1647,38 @@ def test_gtir_let_lambda(): assert np.allclose(b, ref) +def test_gtir_let_lambda_unused_arg(): + testee = gtir.Program( + id="let_lambda_unused_arg", + function_definitions=[], + params=[ + gtir.Sym(id="x", type=IFTYPE), + gtir.Sym(id="y", type=IFTYPE), + gtir.Sym(id="z", type=IFTYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + # Arg 'xᐞ1' is used inside the let-lambda, 'yᐞ1' is not. + expr=im.let(("xᐞ1", im.op_as_fieldop("multiplies")("x", 3.0)), ("yᐞ1", "y"))( + im.op_as_fieldop("multiplies")("xᐞ1", 2.0) + ), + domain=im.get_field_domain(gtx_common.GridType.CARTESIAN, "z", [IDim]), + target=gtir.SymRef(id="z"), + ) + ], + ) + + a = np.random.rand(N) + b = np.random.rand(N) + c = np.random.rand(N) + + sdfg = build_dace_sdfg(testee, {}) + + sdfg(a, b, c, **FSYMBOLS) + assert np.allclose(c, a * 6.0) + + def test_gtir_let_lambda_scalar_expression(): domain_inner = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (1, "size_inner")}) domain_outer = im.get_field_domain( From 4e4fdea16685452dad09c5de4e5036cfebb7b216 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 17 Dec 2025 09:56:32 +0100 Subject: [PATCH 20/28] fix --- .../runners/dace/lowering/gtir_to_sdfg.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index e915dc10d6..40b04a96a6 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -1236,8 +1236,13 @@ def visit_Lambda( ctx.state.add_edge(src_node, None, nsdfg_node, input_connector, memlet) # We can now safely remove all remaining arguments, because unused. + # Note that we only consider global access nodes, because the goal of this + # cleanup is to remove isolated nodes. At this stage, temporary data nodes + # are supposed to contain the result of some argument expression. 
if unused_access_nodes := [ - arg_node.dc_node for arg_node in lambda_arg_nodes.values() if arg_node is not None + arg_node.dc_node + for arg_node in lambda_arg_nodes.values() + if not (arg_node is None or arg_node.dc_node.desc(ctx.sdfg).transient) ]: assert all(ctx.state.degree(access_node) == 0 for access_node in unused_access_nodes) ctx.state.remove_nodes_from(unused_access_nodes) From d07ce80e7312356107152f9887d8b3393ec07ae8 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 17 Dec 2025 10:03:21 +0100 Subject: [PATCH 21/28] undo extra change --- .../runners/dace/lowering/gtir_to_sdfg.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index 3db132574b..91749d3760 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -1185,19 +1185,6 @@ def visit_Lambda( ctx.state.add_edge(src_node, None, nsdfg_node, input_connector, memlet) - # We can now safely remove all remaining arguments, because unused. Note - # that we only consider global access nodes, because the only goal of this - # cleanup is to remove isolated nodes. At this stage, temporary input nodes - # should not appear as isolated nodes, because they are supposed to contain - # the result of some argument expression. - if unused_access_nodes := [ - arg_node.dc_node - for arg_node in lambda_arg_nodes.values() - if not (arg_node is None or arg_node.dc_node.desc(ctx.sdfg).transient) - ]: - assert all(ctx.state.degree(access_node) == 0 for access_node in unused_access_nodes) - ctx.state.remove_nodes_from(unused_access_nodes) - def construct_output_for_nested_sdfg( inner_data: gtir_to_sdfg_types.FieldopData, ) -> gtir_to_sdfg_types.FieldopData: From 5589af7532cb6c23b6fd53d1635495e7819a648b Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 18 Dec 2025 12:08:31 +0100 Subject: [PATCH 22/28] better fix --- .../runners/dace/lowering/gtir_to_sdfg.py | 12 ------------ .../runners/dace/lowering/gtir_to_sdfg_primitives.py | 2 ++ 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index 40b04a96a6..cab489f3f8 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -1235,18 +1235,6 @@ def visit_Lambda( ctx.state.add_edge(src_node, None, nsdfg_node, input_connector, memlet) - # We can now safely remove all remaining arguments, because unused. - # Note that we only consider global access nodes, because the goal of this - # cleanup is to remove isolated nodes. At this stage, temporary data nodes - # are supposed to contain the result of some argument expression. 
- if unused_access_nodes := [ - arg_node.dc_node - for arg_node in lambda_arg_nodes.values() - if not (arg_node is None or arg_node.dc_node.desc(ctx.sdfg).transient) - ]: - assert all(ctx.state.degree(access_node) == 0 for access_node in unused_access_nodes) - ctx.state.remove_nodes_from(unused_access_nodes) - def construct_output_for_nested_sdfg( inner_data: gtir_to_sdfg_types.FieldopData, ) -> gtir_to_sdfg_types.FieldopData: diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py index be0ff6130d..e6dde691c8 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py @@ -723,6 +723,8 @@ def translate_symbol_ref( ) -> gtir_to_sdfg_types.FieldopResult: """Generates the dataflow subgraph for a `ir.SymRef` node.""" assert isinstance(node, gtir.SymRef) + if getattr(node.annex, "domain", None) == infer_domain.DomainAccessDescriptor.NEVER: + return None symbol_name = str(node.id) # we retrieve the type of the symbol in the GT4Py prgram From 907eb48b94b42eb9fe19ec2bff5415cbf0f7c195 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 18 Dec 2025 16:35:41 +0100 Subject: [PATCH 23/28] minor edit --- .../runners/dace/lowering/gtir_to_sdfg_primitives.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py index e6dde691c8..d7a77a7348 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py @@ -618,8 +618,8 @@ def translate_tuple_get( index = int(node.args[0].value) data_nodes = sdfg_builder.visit(node.args[1], ctx=ctx) - if isinstance(data_nodes, gtir_to_sdfg_types.FieldopData): - raise ValueError(f"Invalid tuple expression {node}") + if not isinstance(data_nodes, tuple): + raise ValueError(f"Invalid tuple expression {node}.") # Now we remove the tuple fields that are not used, to avoid an SDFG validation # error because of isolated access nodes. unused_data_nodes = gtx_utils.flatten_nested_tuple( @@ -723,7 +723,10 @@ def translate_symbol_ref( ) -> gtir_to_sdfg_types.FieldopResult: """Generates the dataflow subgraph for a `ir.SymRef` node.""" assert isinstance(node, gtir.SymRef) - if getattr(node.annex, "domain", None) == infer_domain.DomainAccessDescriptor.NEVER: + + # If the symbol is not used, the domain is set to 'NEVER' in node annex. 
+ node_domain = getattr(node.annex, "domain", infer_domain.DomainAccessDescriptor.UNKNOWN) + if node_domain == infer_domain.DomainAccessDescriptor.NEVER: return None symbol_name = str(node.id) From 1d6bce1720595a89e346eb94d5403bf657c4e5e4 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 19 Dec 2025 10:06:28 +0100 Subject: [PATCH 24/28] re-introduce cleanup stage --- .../runners/dace/lowering/gtir_to_sdfg.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index cab489f3f8..9cc317992c 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -1235,6 +1235,20 @@ def visit_Lambda( ctx.state.add_edge(src_node, None, nsdfg_node, input_connector, memlet) + # We can now safely remove all remaining arguments, because unused. + # It can happen that a let-lambda argument is not used, in which case it + # should be `None` at this stage, but the symbol passed as argument is captured + # by the let-lambda expression and thus can be accessed directly inside the + # inner scope. The domain in the annex of this symbol is not `NEVER`, and + # an access node is created, so we find an isolated access node here. + if unused_access_nodes := [ + arg_node.dc_node + for arg_node in lambda_arg_nodes.values() + if not (arg_node is None or arg_node.dc_node.desc(ctx.sdfg).transient) + ]: + assert all(ctx.state.degree(access_node) == 0 for access_node in unused_access_nodes) + ctx.state.remove_nodes_from(unused_access_nodes) + def construct_output_for_nested_sdfg( inner_data: gtir_to_sdfg_types.FieldopData, ) -> gtir_to_sdfg_types.FieldopData: From 0ced0f1b62e2b869bd42309605e8ce1ca43137a5 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 19 Dec 2025 12:59:19 +0100 Subject: [PATCH 25/28] apply review comments --- .../runners/dace/lowering/gtir_to_sdfg.py | 20 +++++++++++-------- .../dace/lowering/gtir_to_sdfg_primitives.py | 2 +- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index b7322d767d..6fc1d83ae1 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -80,7 +80,8 @@ def add_temp_array_like( self, sdfg: dace.SDFG, datadesc: dace.data.Array ) -> tuple[str, dace.data.Scalar]: """Add a temporary array to the SDFG.""" - return sdfg.add_temp_transient_like(datadesc, name=self.unique_temp_name()) + temp_name = self.unique_temp_name() + return sdfg.add_temp_transient_like(datadesc, name=temp_name) def add_temp_scalar( self, sdfg: dace.SDFG, dtype: dace.dtypes.typeclass @@ -182,6 +183,9 @@ class SubgraphContext: state: dace.SDFGState scope_symbols: dict[str, ts.DataType] + def get_symbol_type(self, symbol_name: str) -> ts.DataType: + return self.scope_symbols[symbol_name] + def copy_data( self, sdfg_builder: SDFGBuilder, @@ -585,8 +589,8 @@ def setup_nested_context( gtir.Sym(id=name, type=lambda_symbols[name]) for name in sorted(lambda_symbols.keys()) ] - sdfg = dace.SDFG(name=self.unique_nsdfg_name(parent_ctx.sdfg, sdfg_name)) - sdfg.debuginfo = gtir_to_sdfg_utils.debug_info(expr, default=parent_ctx.sdfg.debuginfo) + nsdfg = 
dace.SDFG(name=self.unique_nsdfg_name(parent_ctx.sdfg, sdfg_name)) + nsdfg.debuginfo = gtir_to_sdfg_utils.debug_info(expr, default=parent_ctx.sdfg.debuginfo) # All GTIR-symbols accessed in domain expressions by the lambda need to be # represented as dace symbols. @@ -608,14 +612,14 @@ def setup_nested_context( # will be turned into global by setting `transient=False`. In this way, we remove # all unused (and possibly undefined) input arguments. self._add_sdfg_params( - sdfg, + nsdfg, node_params=input_params, symbolic_params=(domain_symbols | mapped_symbols | symbolic_inputs), use_transient_storage=True, ) - state = sdfg.add_state(f"{sdfg_name}_entry") - return SubgraphContext(sdfg, state, lambda_symbols) + nsdfg_entry_state = nsdfg.add_state(f"{sdfg_name}_entry", is_start_block=True) + return SubgraphContext(nsdfg, nsdfg_entry_state, lambda_symbols) def add_nested_sdfg( self, @@ -1018,7 +1022,7 @@ def visit_SetAt(self, stmt: gtir.SetAt, ctx: SubgraphContext) -> dace.SDFGState: which will become the new head state. Returns: - The new SDFG head state, where the target fields are written. + The new SDFG head state, where the writes to the AccessNodes of the field operator results are located. """ def _visit_target( @@ -1050,7 +1054,7 @@ def _visit_target( source_tree = self._visit_expression(stmt.expr, ctx) # In order to support inout argument, write the result in separate next state - # this is needed to avoid undefined behavior for expressions like: X, Y = X + 1, X + # this is needed to avoid indeterministic behavior for expressions like: X, Y = X + 1, X target_state = ctx.sdfg.add_state_after(ctx.state, f"post_{ctx.state.label}") # The target expression could be a `SymRef` to an output field or a `make_tuple` # expression in case the statement returns more than one field. diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py index 1b7806e22f..7a780e41d8 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py @@ -728,7 +728,7 @@ def translate_symbol_ref( symbol_name = str(node.id) # we retrieve the type of the symbol in the GT4Py prgram - gt_symbol_type = ctx.scope_symbols[symbol_name] + gt_symbol_type = ctx.get_symbol_type(symbol_name) # Create new access node in current state. It is possible that multiple # access nodes are created in one state for the same data container. From b3b6202f73578bfc3a0486558bf43ab31f2c5ea4 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 19 Dec 2025 17:14:03 +0100 Subject: [PATCH 26/28] edit doc string --- .../program_processors/runners/dace/lowering/gtir_to_sdfg.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index 6fc1d83ae1..c773d12aae 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -476,6 +476,10 @@ def flatten_tuple_args( class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """Provides translation capability from a GTIR program to a DaCe SDFG. + A single instance of this visitor is used for the entire lowering, across all + levels of nested SDFGs. 
For each nested level, a new `SubgraphContext` is setup + with the data symbols in scope. + This class is responsible for translation of `ir.Program`, that is the top level representation of a GT4Py program as a sequence of `ir.Stmt` (aka statement) expressions. Each statement is translated to a taskgraph inside a separate state. Statement states are chained From 9958b6a626512718bd9881c4b5cce2ee8998e9ee Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 19 Dec 2025 18:00:11 +0100 Subject: [PATCH 27/28] re-introduce check for inout fields and remove target state if not used --- .../runners/dace/lowering/gtir_to_sdfg.py | 63 +++++++++++++++---- 1 file changed, 50 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index c773d12aae..42db9becaa 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -26,7 +26,11 @@ from gt4py.eve import concepts from gt4py.next import common as gtx_common, config as gtx_config, utils as gtx_utils from gt4py.next.iterator import ir as gtir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) from gt4py.next.iterator.transforms import prune_casts as ir_prune_casts, symbol_ref_utils from gt4py.next.iterator.type_system import inference as gtir_type_inference from gt4py.next.program_processors.runners.dace import sdfg_args as gtx_dace_args @@ -186,6 +190,19 @@ class SubgraphContext: def get_symbol_type(self, symbol_name: str) -> ts.DataType: return self.scope_symbols[symbol_name] + def input_data(self) -> set[str]: + flat_symbols = [] + for sym_name, sym_type in self.scope_symbols.items(): + if isinstance(sym_type, ts.TupleType): + flat_symbols.extend(gtir_to_sdfg_utils.flatten_tuple_fields(sym_name, sym_type)) + else: + flat_symbols.append(im.sym(sym_name, sym_type)) + + used_data = {access_node.data for access_node in self.state.data_nodes()} + input_data = {data_name for sym in flat_symbols if (data_name := str(sym.id)) in used_data} + assert not any(self.sdfg.arrays[data_name].transient for data_name in input_data) + return input_data + def copy_data( self, sdfg_builder: SDFGBuilder, @@ -1034,28 +1051,40 @@ def _visit_target( target: gtir_to_sdfg_types.FieldopData, target_domain: domain_utils.SymbolicDomain, target_state: dace.SDFGState, + is_target_inout: bool, ) -> None: + assert source.dc_node.desc(ctx.sdfg).transient assert source.gt_type == target.gt_type field_domain = gtir_domain.get_field_domain(target_domain) source_subset = _make_access_index_for_field(field_domain, source) target_subset = _make_access_index_for_field(field_domain, target) - target_state.add_nedge( - # create in the target state new access nodes to the field operator result - target_state.add_access(source.dc_node.data), - target.dc_node, - dace.Memlet( - data=target.dc_node.data, subset=target_subset, other_subset=source_subset - ), - ) - if ctx.state.degree(source.dc_node) == 0: - ctx.state.remove_node(source.dc_node) + if is_target_inout: + # write the field operator result in the separate target state + target_state.add_nedge( + target_state.add_access(source.dc_node.data), + target.dc_node, + dace.Memlet( + data=target.dc_node.data, subset=target_subset, other_subset=source_subset + ), 
+ ) + else: + # do not use the target state, write the result inside the current context + ctx.state.add_nedge( + source.dc_node, + ctx.state.add_access(target.dc_node.data), + dace.Memlet( + data=target.dc_node.data, subset=target_subset, other_subset=source_subset + ), + ) + target_state.remove_node(target.dc_node) # Visit the domain expression. domain = gtir_domain.extract_target_domain(stmt.domain) # Visit the field operator expression. source_tree = self._visit_expression(stmt.expr, ctx) + ctx_input_data = ctx.input_data() # In order to support inout argument, write the result in separate next state # this is needed to avoid indeterministic behavior for expressions like: X, Y = X + 1, X @@ -1069,11 +1098,19 @@ def _visit_target( ) gtx_utils.tree_map( lambda source, target, target_domain: _visit_target( - source, target, target_domain, target_state + source, + target, + target_domain, + target_state, + is_target_inout=(target.dc_node.data in ctx_input_data), ) )(source_tree, target_tree, domain) - return target_state + if target_state.is_empty(): + ctx.sdfg.remove_node(target_state) + return ctx.state + else: + return target_state def visit_FunCall( self, From 7e4547c22ed33ef515aa4acb5446c035a87a3940 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 8 Jan 2026 12:07:23 +0100 Subject: [PATCH 28/28] address review comments --- .../runners/dace/lowering/gtir_to_sdfg.py | 28 ++++++++++++------- .../dace/lowering/gtir_to_sdfg_primitives.py | 1 + 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index 42db9becaa..cabd22aa5a 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -190,7 +190,13 @@ class SubgraphContext: def get_symbol_type(self, symbol_name: str) -> ts.DataType: return self.scope_symbols[symbol_name] - def input_data(self) -> set[str]: + def get_input_data(self) -> set[str]: + """Retrieve the names of arrays and scalars which are input data to this context. + + Note that only the data that is currently used in this context, in other words + accessed in this SDFG state, is included in the returned set. + """ + flat_symbols = [] for sym_name, sym_type in self.scope_symbols.items(): if isinstance(sym_type, ts.TupleType): @@ -494,8 +500,9 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """Provides translation capability from a GTIR program to a DaCe SDFG. A single instance of this visitor is used for the entire lowering, across all - levels of nested SDFGs. For each nested level, a new `SubgraphContext` is setup - with the data symbols in scope. + levels of nested SDFGs. The level-specific information, including the data symbols + available in the lowering scope, is stored inside a `SubgraphContext` object + that can be accessed by the visitor methods. This class is responsible for translation of `ir.Program`, that is the top level representation of a GT4Py program as a sequence of `ir.Stmt` (aka statement) expressions. 
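For intuition about the write-after-read handling changed in the next hunk, here is a plain-numpy sketch (illustrative only, neither DaCe nor part of this patch) of why `SetAt` results are staged in transients and copied to the targets only afterwards, for statements like `X, Y = X + 1, X` where both right-hand sides must observe the original `X`:

    import numpy as np

    x = np.arange(3.0)
    y = np.empty_like(x)

    # "compute state": evaluate both right-hand sides against the original x
    tmp_x = x + 1.0
    tmp_y = x.copy()

    # "target state": only now overwrite the targets, so y still sees the old x
    x[:] = tmp_x
    y[:] = tmp_y
    assert np.allclose(x, [1.0, 2.0, 3.0]) and np.allclose(y, [0.0, 1.0, 2.0])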
@@ -1051,7 +1058,7 @@ def _visit_target( target: gtir_to_sdfg_types.FieldopData, target_domain: domain_utils.SymbolicDomain, target_state: dace.SDFGState, - is_target_inout: bool, + target_is_also_used_as_input: bool, ) -> None: assert source.dc_node.desc(ctx.sdfg).transient assert source.gt_type == target.gt_type @@ -1059,13 +1066,14 @@ def _visit_target( source_subset = _make_access_index_for_field(field_domain, source) target_subset = _make_access_index_for_field(field_domain, target) - if is_target_inout: - # write the field operator result in the separate target state + if target_is_also_used_as_input: + # write the field operator result in the separate target state, + # in order to ensure the correct dataflow for write after read target_state.add_nedge( target_state.add_access(source.dc_node.data), target.dc_node, dace.Memlet( - data=target.dc_node.data, subset=target_subset, other_subset=source_subset + data=source.dc_node.data, subset=source_subset, other_subset=target_subset ), ) else: @@ -1074,7 +1082,7 @@ def _visit_target( source.dc_node, ctx.state.add_access(target.dc_node.data), dace.Memlet( - data=target.dc_node.data, subset=target_subset, other_subset=source_subset + data=source.dc_node.data, subset=source_subset, other_subset=target_subset ), ) target_state.remove_node(target.dc_node) @@ -1084,7 +1092,7 @@ def _visit_target( # Visit the field operator expression. source_tree = self._visit_expression(stmt.expr, ctx) - ctx_input_data = ctx.input_data() + ctx_input_data = ctx.get_input_data() # In order to support inout argument, write the result in separate next state # this is needed to avoid indeterministic behavior for expressions like: X, Y = X + 1, X @@ -1102,7 +1110,7 @@ def _visit_target( target, target_domain, target_state, - is_target_inout=(target.dc_node.data in ctx_input_data), + target_is_also_used_as_input=(target.dc_node.data in ctx_input_data), ) )(source_tree, target_tree, domain) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py index 7a780e41d8..9fc1ebab4c 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py @@ -660,6 +660,7 @@ def translate_scalar_expr( # e.g. the type name `float64` used in casting expressions like `cast_(variable, float64)` visit_expr = str(arg_expr.id) in ctx.scope_symbols else: + # any other kind of node should always be visited visit_expr = True if visit_expr: