From 31e767b14e1da6a721c8b08050a02d4152f28ff1 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 1 Oct 2025 13:22:22 +0200 Subject: [PATCH 01/23] switch to gt4py main --- pyproject.toml | 2 +- uv.lock | 46 ++++++++++++++++------------------------------ 2 files changed, 17 insertions(+), 31 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 062b3dbda1..ca06875e5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -319,7 +319,7 @@ url = "https://test.pypi.org/simple/" [tool.uv.sources] dace = {git = "https://github.com/GridTools/dace", tag = "__gt4py-next-integration_2025_09_25"} # ghex = {git = "https://github.com/ghex-org/GHEX.git", branch = "master"} -# gt4py = {git = "https://github.com/GridTools/gt4py", branch = "main"} +gt4py = {git = "https://github.com/GridTools/gt4py", branch = "main"} # gt4py = {index = "test.pypi"} icon4py-atmosphere-advection = {workspace = true} icon4py-atmosphere-diffusion = {workspace = true} diff --git a/uv.lock b/uv.lock index eb739d8a97..0fe3b5aa9c 100644 --- a/uv.lock +++ b/uv.lock @@ -1078,15 +1078,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/46/d1/e73b6ad76f0b1fb7f23c35c6d95dbc506a9c8804f43dda8cb5b0fa6331fd/dill-0.3.9-py3-none-any.whl", hash = "sha256:468dff3b89520b474c0397703366b7b95eebe6303f108adf9b19da1f702be87a", size = 119418, upload-time = "2024-09-29T00:03:19.344Z" }, ] -[[package]] -name = "diskcache" -version = "5.6.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/3f/21/1c1ffc1a039ddcc459db43cc108658f32c57d271d7289a2794e401d0fdb6/diskcache-5.6.3.tar.gz", hash = "sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc", size = 67916, upload-time = "2023-08-31T06:12:00.316Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3f/27/4570e78fc0bf5ea0ca45eb1de3818a23787af9b390c0b0a0033a1b8236f9/diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19", size = 45550, upload-time = "2023-08-31T06:11:58.822Z" }, -] - [[package]] name = "distlib" version = "0.3.9" @@ -1428,8 +1419,8 @@ wheels = [ [[package]] name = "gt4py" -version = "1.0.9" -source = { registry = "https://pypi.org/simple" } +version = "1.0.9.post15+bb7826e9" +source = { git = "https://github.com/GridTools/gt4py?branch=main#bb7826e9b3fdf87a2617798eeb7b239a6f57a37c" } dependencies = [ { name = "attrs" }, { name = "black" }, @@ -1440,7 +1431,6 @@ dependencies = [ { name = "cytoolz" }, { name = "deepdiff" }, { name = "devtools" }, - { name = "diskcache" }, { name = "factory-boy" }, { name = "filelock" }, { name = "frozendict" }, @@ -1460,10 +1450,6 @@ dependencies = [ { name = "versioningit" }, { name = "xxhash" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a9/1c/2577d3b2380dc3e5451432a96de730ce4fdf4b602f63b9b989d0373f9ed4/gt4py-1.0.9.tar.gz", hash = "sha256:8b7d1eab14b1d093d1db943de8d8a759e9b979464892533d31c9ff9d6abc53ca", size = 724634, upload-time = "2025-09-12T12:30:50.244Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/52/bc/e49d6dfc6169ea10dc10ed723b281aa841d7c644297c95a427455317638a/gt4py-1.0.9-py3-none-any.whl", hash = "sha256:1ef45657dd470e77bbe0f5cc9bf3c17493efc0df498ee74897069b2cbf6ac9cb", size = 925459, upload-time = "2025-09-12T12:30:48.731Z" }, -] [package.optional-dependencies] cuda11 = [ @@ -1805,7 +1791,7 @@ dependencies = [ [package.metadata] requires-dist = [ - { name = "gt4py", specifier = "==1.0.9" }, + { name = "gt4py", git = "https://github.com/GridTools/gt4py?branch=main" }, { name = "icon4py-common", editable = "model/common" }, { name = "packaging", specifier = ">=20.0" }, ] @@ -1822,7 +1808,7 @@ dependencies = [ [package.metadata] requires-dist = [ - { name = "gt4py", specifier = "==1.0.9" }, + { name = "gt4py", git = "https://github.com/GridTools/gt4py?branch=main" }, { name = "icon4py-common", editable = "model/common" }, { name = "packaging", specifier = ">=20.0" }, ] @@ -1839,7 +1825,7 @@ dependencies = [ [package.metadata] requires-dist = [ - { name = "gt4py", specifier = "==1.0.9" }, + { name = "gt4py", git = "https://github.com/GridTools/gt4py?branch=main" }, { name = "icon4py-common", editable = "model/common" }, { name = "packaging", specifier = ">=20.0" }, ] @@ -1856,7 +1842,7 @@ dependencies = [ [package.metadata] requires-dist = [ - { name = "gt4py", specifier = "==1.0.9" }, + { name = "gt4py", git = "https://github.com/GridTools/gt4py?branch=main" }, { name = "icon4py-common", editable = "model/common" }, { name = "packaging", specifier = ">=20.0" }, ] @@ -1873,7 +1859,7 @@ dependencies = [ [package.metadata] requires-dist = [ - { name = "gt4py", specifier = "==1.0.9" }, + { name = "gt4py", git = "https://github.com/GridTools/gt4py?branch=main" }, { name = "icon4py-common", extras = ["io"], editable = "model/common" }, { name = "packaging", specifier = ">=20.0" }, ] @@ -1943,10 +1929,10 @@ requires-dist = [ { name = "dace", marker = "extra == 'dace'", git = "https://github.com/GridTools/dace?tag=__gt4py-next-integration_2025_09_25" }, { name = "datashader", marker = "extra == 'io'", specifier = ">=0.16.1" }, { name = "ghex", marker = "extra == 'distributed'", specifier = ">=0.3.0" }, - { name = "gt4py", specifier = "==1.0.9" }, - { name = "gt4py", extras = ["cuda11"], marker = "extra == 'cuda11'" }, - { name = "gt4py", extras = ["cuda12"], marker = "extra == 'cuda12'" }, - { name = "gt4py", extras = ["next"], marker = "extra == 'dace'" }, + { name = "gt4py", git = "https://github.com/GridTools/gt4py?branch=main" }, + { name = "gt4py", extras = ["cuda11"], marker = "extra == 'cuda11'", git = "https://github.com/GridTools/gt4py?branch=main" }, + { name = "gt4py", extras = ["cuda12"], marker = "extra == 'cuda12'", git = "https://github.com/GridTools/gt4py?branch=main" }, + { name = "gt4py", extras = ["next"], marker = "extra == 'dace'", git = "https://github.com/GridTools/gt4py?branch=main" }, { name = "holoviews", marker = "extra == 'io'", specifier = ">=1.16.0" }, { name = "icon4py-common", extras = ["dace", "distributed", "io"], marker = "extra == 'all'", editable = "model/common" }, { name = "mpi4py", marker = "extra == 'distributed'", specifier = ">=3.1.5" }, @@ -1982,7 +1968,7 @@ dependencies = [ requires-dist = [ { name = "click", specifier = ">=8.0.1" }, { name = "devtools", specifier = ">=0.12" }, - { name = "gt4py", specifier = "==1.0.9" }, + { name = "gt4py", git = "https://github.com/GridTools/gt4py?branch=main" }, { name = "icon4py-atmosphere-diffusion", editable = "model/atmosphere/diffusion" }, { name = "icon4py-atmosphere-dycore", editable = "model/atmosphere/dycore" }, { name = "icon4py-common", editable = "model/common" }, @@ -2010,7 +1996,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "filelock", specifier = ">=3.18.0" }, - { name = "gt4py", specifier = "==1.0.9" }, + { name = "gt4py", git = "https://github.com/GridTools/gt4py?branch=main" }, { name = "icon4py-common", extras = ["io"], editable = "model/common" }, { name = "numpy", specifier = ">=1.23.3" }, { name = "packaging", specifier = ">=20.0" }, @@ -2056,9 +2042,9 @@ requires-dist = [ { name = "cupy-cuda11x", marker = "extra == 'cuda11'", specifier = ">=13.0" }, { name = "cupy-cuda12x", marker = "extra == 'cuda12'", specifier = ">=13.0" }, { name = "fprettify", specifier = ">=0.3.7" }, - { name = "gt4py", specifier = "==1.0.9" }, - { name = "gt4py", extras = ["cuda11"], marker = "extra == 'cuda11'" }, - { name = "gt4py", extras = ["cuda12"], marker = "extra == 'cuda12'" }, + { name = "gt4py", git = "https://github.com/GridTools/gt4py?branch=main" }, + { name = "gt4py", extras = ["cuda11"], marker = "extra == 'cuda11'", git = "https://github.com/GridTools/gt4py?branch=main" }, + { name = "gt4py", extras = ["cuda12"], marker = "extra == 'cuda12'", git = "https://github.com/GridTools/gt4py?branch=main" }, { name = "icon4py-atmosphere-advection", editable = "model/atmosphere/advection" }, { name = "icon4py-atmosphere-diffusion", editable = "model/atmosphere/diffusion" }, { name = "icon4py-atmosphere-dycore", editable = "model/atmosphere/dycore" }, From b4bd3e4de202be82635630c03cdaee6747a4b6dc Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 1 Oct 2025 21:39:04 +0200 Subject: [PATCH 02/23] Enable custom backends for blueline --- .../model/atmosphere/advection/advection.py | 2 +- .../advection/advection_horizontal.py | 14 +-- .../advection/advection_vertical.py | 24 ++-- .../integration_tests/test_advection.py | 2 +- .../advection/tests/advection/utils.py | 10 +- .../model/atmosphere/diffusion/diffusion.py | 65 +++++------ .../test_benchmark_diffusion.py | 18 +-- .../integration_tests/test_diffusion.py | 4 +- .../integration_tests/test_diffusion_utils.py | 16 +-- .../model/atmosphere/dycore/solve_nonhydro.py | 105 +++++++++--------- .../atmosphere/dycore/velocity_advection.py | 22 ++-- .../integration_tests/test_solve_nonhydro.py | 34 +++--- .../test_velocity_advection.py | 6 +- .../mpi_tests/test_parallel_solve_nonhydro.py | 2 +- .../dycore/stencil_tests/test_dycore_utils.py | 18 +-- model/atmosphere/dycore/tests/dycore/utils.py | 6 +- .../microphysics/saturation_adjustment.py | 10 +- .../single_moment_six_class_gscp_graupel.py | 26 ++--- .../test_saturation_adjustment.py | 6 +- ...st_single_moment_six_class_gscp_graupel.py | 14 +-- .../model/common/metrics/metrics_factory.py | 4 +- .../icon4py/model/common/model_backends.py | 13 +++ .../icon4py/model/common/states/factory.py | 4 +- .../model/common/utils/data_allocation.py | 30 ++--- .../test_diagnostic_calculations.py | 30 ++--- .../tests/common/grid/unit_tests/test_base.py | 2 +- .../common/grid/unit_tests/test_geometry.py | 6 +- .../common/grid/unit_tests/test_vertical.py | 2 +- .../unit_tests/test_compute_nudgecoeffs.py | 2 +- .../common/math/unit_tests/test_helpers.py | 18 +-- .../math/unit_tests/test_smagorinsky.py | 4 +- .../test_compute_diffusion_metrics.py | 20 ++-- .../unit_tests/test_compute_weight_factors.py | 7 +- .../test_compute_zdiff_gradp_dsl.py | 4 +- .../metrics/unit_tests/test_metric_fields.py | 46 ++++---- .../unit_tests/test_reference_atmosphere.py | 35 ++++-- .../model/driver/initialization_utils.py | 20 ++-- .../icon4py/model/driver/testcases/gauss3d.py | 4 +- .../testcases/jablonowski_williamson.py | 24 ++-- .../icon4py/model/driver/testcases/utils.py | 60 +++++----- .../driver/integration_tests/test_icon4py.py | 8 +- .../src/icon4py/model/testing/grid_utils.py | 2 +- .../icon4py/tools/py2fgen/wrappers/common.py | 36 ++++-- .../py2fgen/wrappers/diffusion_wrapper.py | 14 ++- .../tools/py2fgen/wrappers/dycore_wrapper.py | 14 ++- .../tools/py2fgen/wrappers/grid_wrapper.py | 5 +- 46 files changed, 436 insertions(+), 382 deletions(-) diff --git a/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection.py b/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection.py index 520a6e1ed4..3e03c2201a 100644 --- a/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection.py +++ b/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection.py @@ -249,7 +249,7 @@ def __init__( # density fields #: intermediate density times cell thickness, includes either the horizontal or vertical advective density increment [kg/m^2] self._rhodz_ast2 = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, backend=self._backend + self._grid, dims.CellDim, dims.KDim, allocator=self._backend ) # stencils diff --git a/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection_horizontal.py b/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection_horizontal.py index 77d64ce8c5..fae19d65bf 100644 --- a/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection_horizontal.py +++ b/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection_horizontal.py @@ -94,7 +94,7 @@ def __init__( # limiter fields self._r_m = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, backend=self._backend + self._grid, dims.CellDim, dims.KDim, allocator=self._backend ) # stencils @@ -207,13 +207,13 @@ def __init__( # reconstruction fields self._p_coeff_1 = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, backend=self._backend + self._grid, dims.CellDim, dims.KDim, allocator=self._backend ) self._p_coeff_2 = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, backend=self._backend + self._grid, dims.CellDim, dims.KDim, allocator=self._backend ) self._p_coeff_3 = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, backend=self._backend + self._grid, dims.CellDim, dims.KDim, allocator=self._backend ) # stencils @@ -464,13 +464,13 @@ def __init__( # backtrajectory fields self._z_real_vt = data_alloc.zero_field( - self._grid, dims.EdgeDim, dims.KDim, backend=self._backend + self._grid, dims.EdgeDim, dims.KDim, allocator=self._backend ) self._p_distv_bary_1 = data_alloc.zero_field( - self._grid, dims.EdgeDim, dims.KDim, backend=self._backend + self._grid, dims.EdgeDim, dims.KDim, allocator=self._backend ) self._p_distv_bary_2 = data_alloc.zero_field( - self._grid, dims.EdgeDim, dims.KDim, backend=self._backend + self._grid, dims.EdgeDim, dims.KDim, allocator=self._backend ) # stencils diff --git a/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection_vertical.py b/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection_vertical.py index b7c72ad723..d62d759e5a 100644 --- a/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection_vertical.py +++ b/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection_vertical.py @@ -187,7 +187,7 @@ def __init__(self, grid: icon_grid.IconGrid, backend: gtx_typing.Backend | None) # fields self._l_limit = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, backend=self._backend + self._grid, dims.CellDim, dims.KDim, allocator=self._backend ) # stencils @@ -255,10 +255,10 @@ def __init__(self, grid: icon_grid.IconGrid, backend: gtx_typing.Backend | None) # fields self._k_field = data_alloc.index_field( - self._grid, dims.KDim, extend={dims.KDim: 1}, dtype=gtx.int32, backend=self._backend + self._grid, dims.KDim, extend={dims.KDim: 1}, dtype=gtx.int32, allocator=self._backend ) # TODO(dastrm): should be KHalfDim self._l_limit = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=gtx.int32, backend=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=gtx.int32, allocator=self._backend ) # stencils @@ -538,7 +538,7 @@ def __init__( grid=self._grid, extend={dims.KDim: 1}, dtype=gtx.int32, - backend=self._backend, + allocator=self._backend, ) # TODO(dastrm): should be KHalfDim # stencils @@ -675,28 +675,28 @@ def __init__( # fields self._k_field = data_alloc.index_field( - self._grid, dims.KDim, extend={dims.KDim: 1}, dtype=gtx.int32, backend=self._backend + self._grid, dims.KDim, extend={dims.KDim: 1}, dtype=gtx.int32, allocator=self._backend ) # TODO(dastrm): should be KHalfDim self._z_cfl = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, backend=self._backend + self._grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, allocator=self._backend ) # TODO(dastrm): should be KHalfDim self._z_slope = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, backend=self._backend + self._grid, dims.CellDim, dims.KDim, allocator=self._backend ) self._z_face = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, backend=self._backend + self._grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, allocator=self._backend ) # TODO(dastrm): should be KHalfDim self._z_face_up = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, backend=self._backend + self._grid, dims.CellDim, dims.KDim, allocator=self._backend ) self._z_face_low = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, backend=self._backend + self._grid, dims.CellDim, dims.KDim, allocator=self._backend ) self._z_delta_q = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, backend=self._backend + self._grid, dims.CellDim, dims.KDim, allocator=self._backend ) self._z_a1 = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, backend=self._backend + self._grid, dims.CellDim, dims.KDim, allocator=self._backend ) # stencils diff --git a/model/atmosphere/advection/tests/advection/integration_tests/test_advection.py b/model/atmosphere/advection/tests/advection/integration_tests/test_advection.py index 0d5b9d96c7..7c5cef064d 100644 --- a/model/atmosphere/advection/tests/advection/integration_tests/test_advection.py +++ b/model/atmosphere/advection/tests/advection/integration_tests/test_advection.py @@ -149,7 +149,7 @@ def test_advection_run_single_step( ) prep_adv = construct_prep_adv(advection_init_savepoint) p_tracer_now = advection_init_savepoint.tracer(ntracer) - p_tracer_new = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, backend=backend) + p_tracer_new = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, allocator=backend) dtime = advection_init_savepoint.get_metadata("dtime").get("dtime") log_serialized(diagnostic_state, prep_adv, p_tracer_now, dtime) diff --git a/model/atmosphere/advection/tests/advection/utils.py b/model/atmosphere/advection/tests/advection/utils.py index 16cb6a0358..9e43a5ba4e 100644 --- a/model/atmosphere/advection/tests/advection/utils.py +++ b/model/atmosphere/advection/tests/advection/utils.py @@ -59,7 +59,7 @@ def construct_least_squares_state( def construct_metric_state( icon_grid, savepoint: sb.MetricSavepoint, backend: gtx_typing.Backend | None ) -> advection_states.AdvectionMetricState: - constant_f = data_alloc.constant_field(icon_grid, 1.0, dims.KDim, backend=backend) + constant_f = data_alloc.constant_field(icon_grid, 1.0, dims.KDim, allocator=backend) ddqz_z_full_np = np.reciprocal(savepoint.inv_ddqz_z_full().asnumpy()) return advection_states.AdvectionMetricState( deepatmo_divh=constant_f, @@ -79,9 +79,9 @@ def construct_diagnostic_init_state( airmass_now=savepoint.airmass_now(), airmass_new=savepoint.airmass_new(), grf_tend_tracer=savepoint.grf_tend_tracer(ntracer), - hfl_tracer=data_alloc.zero_field(icon_grid, dims.EdgeDim, dims.KDim, backend=backend), + hfl_tracer=data_alloc.zero_field(icon_grid, dims.EdgeDim, dims.KDim, allocator=backend), vfl_tracer=data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, backend=backend + icon_grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, allocator=backend ), ) @@ -93,8 +93,8 @@ def construct_diagnostic_exit_state( backend: gtx_typing.Backend | None, ) -> advection_states.AdvectionDiagnosticState: return advection_states.AdvectionDiagnosticState( - airmass_now=data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, backend=backend), - airmass_new=data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, backend=backend), + airmass_now=data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, allocator=backend), + airmass_new=data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, allocator=backend), grf_tend_tracer=data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim), hfl_tracer=savepoint.hfl_tracer(ntracer), vfl_tracer=savepoint.vfl_tracer(ntracer), diff --git a/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py b/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py index 5501e70328..7fb1fc98d8 100644 --- a/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py +++ b/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py @@ -16,6 +16,7 @@ import gt4py.next as gtx import gt4py.next.typing as gtx_typing +from gt4py.next import allocators as gtx_allocators import icon4py.model.common.grid.states as grid_states import icon4py.model.common.states.prognostic_state as prognostics @@ -370,7 +371,7 @@ def __init__( orchestration: bool = False, exchange: decomposition.ExchangeRuntime | None = None, ): - self._backend = backend + self._allocator = model_backends.get_allocator(backend) self._orchestration = orchestration self._exchange = exchange or decomposition.SingleNodeExchange() self.config = config @@ -420,7 +421,7 @@ def __init__( ) self.calculate_nabla2_and_smag_coefficients_for_vn = setup_program( - backend=self._backend, + backend=backend, program=calculate_nabla2_and_smag_coefficients_for_vn, constant_args={ "tangent_orientation": self._edge_params.tangent_orientation, @@ -440,7 +441,7 @@ def __init__( ) self.calculate_diagnostic_quantities_for_turbulence = setup_program( - backend=self._backend, + backend=backend, program=calculate_diagnostic_quantities_for_turbulence, constant_args={ "e_bln_c_s": self._interpolation_state.e_bln_c_s, @@ -455,7 +456,7 @@ def __init__( offset_provider=self._grid.connectivities, ) self.apply_diffusion_to_vn = setup_program( - backend=self._backend, + backend=backend, program=apply_diffusion_to_vn, constant_args={ "primal_normal_vert_v1": self._edge_params.primal_normal_vert[0], @@ -477,7 +478,7 @@ def __init__( offset_provider=self._grid.connectivities, ) self.apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence = setup_program( - backend=self._backend, + backend=backend, program=apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence, constant_args={ "geofac_n2s": self._interpolation_state.geofac_n2s, @@ -505,7 +506,7 @@ def __init__( offset_provider=self._grid.connectivities, ) self.calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools = setup_program( - backend=self._backend, + backend=backend, program=calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools, constant_args={ "theta_ref_mc": self._metric_state.theta_ref_mc, @@ -523,7 +524,7 @@ def __init__( offset_provider=self._grid.connectivities, ) self.apply_diffusion_to_theta_and_exner = setup_program( - backend=self._backend, + backend=backend, program=apply_diffusion_to_theta_and_exner, constant_args={ "geofac_div": self._interpolation_state.geofac_div, @@ -548,19 +549,19 @@ def __init__( }, offset_provider=self._grid.connectivities, ) - self.copy_field = setup_program(backend=self._backend, program=copy_field) - self.scale_k = setup_program(backend=self._backend, program=scale_k) + self.copy_field = setup_program(backend=backend, program=copy_field) + self.scale_k = setup_program(backend=backend, program=scale_k) self.setup_fields_for_initial_step = setup_program( - backend=self._backend, program=setup_fields_for_initial_step + backend=backend, program=setup_fields_for_initial_step ) self.init_diffusion_local_fields_for_regular_timestep = setup_program( - backend=self._backend, + backend=backend, program=init_diffusion_local_fields_for_regular_timestep, offset_provider={"Koff": dims.KDim}, ) - self._allocate_temporary_fields() + self._allocate_local_fields() self.init_diffusion_local_fields_for_regular_timestep( params.K4, @@ -574,7 +575,7 @@ def __init__( offset_provider={"Koff": dims.KDim}, ) setup_program( - backend=self._backend, + backend=backend, program=diffusion_utils.init_nabla2_factor_in_upper_damping_zone, constant_args={ "physical_heights": self._vertical_grid.interface_physical_height, @@ -595,42 +596,44 @@ def __init__( # but this requires some changes in gt4py domain inference. self.compile_time_connectivities = self._grid.connectivities - def _allocate_temporary_fields(self): - self.diff_multfac_vn = data_alloc.zero_field(self._grid, dims.KDim, backend=self._backend) - self.diff_multfac_n2w = data_alloc.zero_field(self._grid, dims.KDim, backend=self._backend) - self.smag_limit = data_alloc.zero_field(self._grid, dims.KDim, backend=self._backend) - self.enh_smag_fac = data_alloc.zero_field(self._grid, dims.KDim, backend=self._backend) + def _allocate_local_fields( + self, allocator: gtx_allocators.FieldBufferAllocationUtil | None = None + ): + self.diff_multfac_vn = data_alloc.zero_field(self._grid, dims.KDim, allocator=allocator) + self.diff_multfac_n2w = data_alloc.zero_field(self._grid, dims.KDim, allocator=allocator) + self.smag_limit = data_alloc.zero_field(self._grid, dims.KDim, allocator=allocator) + self.enh_smag_fac = data_alloc.zero_field(self._grid, dims.KDim, allocator=allocator) self.u_vert = data_alloc.zero_field( - self._grid, dims.VertexDim, dims.KDim, backend=self._backend + self._grid, dims.VertexDim, dims.KDim, allocator=allocator ) self.v_vert = data_alloc.zero_field( - self._grid, dims.VertexDim, dims.KDim, backend=self._backend + self._grid, dims.VertexDim, dims.KDim, allocator=allocator ) self.kh_smag_e = data_alloc.zero_field( - self._grid, dims.EdgeDim, dims.KDim, backend=self._backend + self._grid, dims.EdgeDim, dims.KDim, allocator=allocator ) self.kh_smag_ec = data_alloc.zero_field( - self._grid, dims.EdgeDim, dims.KDim, backend=self._backend + self._grid, dims.EdgeDim, dims.KDim, allocator=allocator ) self.z_nabla2_e = data_alloc.zero_field( - self._grid, dims.EdgeDim, dims.KDim, backend=self._backend + self._grid, dims.EdgeDim, dims.KDim, allocator=allocator ) - self.diff_multfac_smag = data_alloc.zero_field(self._grid, dims.KDim, backend=self._backend) + self.diff_multfac_smag = data_alloc.zero_field(self._grid, dims.KDim, allocator=allocator) # TODO(halungge): this is KHalfDim self.vertical_index = data_alloc.index_field( - self._grid, dims.KDim, extend={dims.KDim: 1}, backend=self._backend + self._grid, dims.KDim, extend={dims.KDim: 1}, allocator=allocator ) self.horizontal_cell_index = data_alloc.index_field( - self._grid, dims.CellDim, backend=self._backend + self._grid, dims.CellDim, allocator=allocator ) self.horizontal_edge_index = data_alloc.index_field( - self._grid, dims.EdgeDim, backend=self._backend + self._grid, dims.EdgeDim, allocator=allocator ) self.w_tmp = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, backend=self._backend + self._grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, allocator=allocator ) self.theta_v_tmp = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, backend=self._backend + self._grid, dims.CellDim, dims.KDim, allocator=allocator ) def _determine_horizontal_domains(self): @@ -685,8 +688,8 @@ def initial_run( This run uses special values for diff_multfac_vn, smag_limit and smag_offset """ - diff_multfac_vn = data_alloc.zero_field(self._grid, dims.KDim, backend=self._backend) - smag_limit = data_alloc.zero_field(self._grid, dims.KDim, backend=self._backend) + diff_multfac_vn = data_alloc.zero_field(self._grid, dims.KDim, allocator=self._allocator) + smag_limit = data_alloc.zero_field(self._grid, dims.KDim, allocator=self._allocator) self.setup_fields_for_initial_step( self._params.K4, diff --git a/model/atmosphere/diffusion/tests/diffusion/integration_tests/test_benchmark_diffusion.py b/model/atmosphere/diffusion/tests/diffusion/integration_tests/test_benchmark_diffusion.py index 8ecd3b7692..5241ae59ca 100644 --- a/model/atmosphere/diffusion/tests/diffusion/integration_tests/test_benchmark_diffusion.py +++ b/model/atmosphere/diffusion/tests/diffusion/integration_tests/test_benchmark_diffusion.py @@ -190,20 +190,20 @@ def test_run_diffusion_benchmark( ) # initialization of the diagnostic and prognostic state diagnostic_state = diffusion_states.DiffusionDiagnosticState( - hdef_ic=data_alloc.random_field(mesh, dims.CellDim, dims.KDim, backend=backend), - div_ic=data_alloc.random_field(mesh, dims.CellDim, dims.KDim, backend=backend), - dwdx=data_alloc.random_field(mesh, dims.CellDim, dims.KDim, backend=backend), - dwdy=data_alloc.random_field(mesh, dims.CellDim, dims.KDim, backend=backend), + hdef_ic=data_alloc.random_field(mesh, dims.CellDim, dims.KDim, allocator=backend), + div_ic=data_alloc.random_field(mesh, dims.CellDim, dims.KDim, allocator=backend), + dwdx=data_alloc.random_field(mesh, dims.CellDim, dims.KDim, allocator=backend), + dwdy=data_alloc.random_field(mesh, dims.CellDim, dims.KDim, allocator=backend), ) prognostic_state = prognostics.PrognosticState( w=data_alloc.random_field( - mesh, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, low=0.0, backend=backend + mesh, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, low=0.0, allocator=backend ), - vn=data_alloc.random_field(mesh, dims.EdgeDim, dims.KDim, backend=backend), - exner=data_alloc.random_field(mesh, dims.CellDim, dims.KDim, backend=backend), - theta_v=data_alloc.random_field(mesh, dims.CellDim, dims.KDim, backend=backend), - rho=data_alloc.random_field(mesh, dims.CellDim, dims.KDim, backend=backend), + vn=data_alloc.random_field(mesh, dims.EdgeDim, dims.KDim, allocator=backend), + exner=data_alloc.random_field(mesh, dims.CellDim, dims.KDim, allocator=backend), + theta_v=data_alloc.random_field(mesh, dims.CellDim, dims.KDim, allocator=backend), + rho=data_alloc.random_field(mesh, dims.CellDim, dims.KDim, allocator=backend), ) diffusion_granule = diffusion.Diffusion( diff --git a/model/atmosphere/diffusion/tests/diffusion/integration_tests/test_diffusion.py b/model/atmosphere/diffusion/tests/diffusion/integration_tests/test_diffusion.py index 709bbc1e87..630c560ff5 100644 --- a/model/atmosphere/diffusion/tests/diffusion/integration_tests/test_diffusion.py +++ b/model/atmosphere/diffusion/tests/diffusion/integration_tests/test_diffusion.py @@ -671,8 +671,8 @@ def test_verify_special_diffusion_inital_step_values_against_initial_savepoint( expected_smag_limit = savepoint.smag_limit() exptected_smag_offset = savepoint.smag_offset() - diff_multfac_vn = data_alloc.zero_field(icon_grid, dims.KDim, backend=backend) - smag_limit = data_alloc.zero_field(icon_grid, dims.KDim, backend=backend) + diff_multfac_vn = data_alloc.zero_field(icon_grid, dims.KDim, allocator=backend) + smag_limit = data_alloc.zero_field(icon_grid, dims.KDim, allocator=backend) diffusion_utils.setup_fields_for_initial_step.with_backend(backend)( params.K4, config.hdiff_efdt_ratio, diff --git a/model/atmosphere/diffusion/tests/diffusion/integration_tests/test_diffusion_utils.py b/model/atmosphere/diffusion/tests/diffusion/integration_tests/test_diffusion_utils.py index 939a77251d..c080b868ad 100644 --- a/model/atmosphere/diffusion/tests/diffusion/integration_tests/test_diffusion_utils.py +++ b/model/atmosphere/diffusion/tests/diffusion/integration_tests/test_diffusion_utils.py @@ -27,8 +27,8 @@ def initial_diff_multfac_vn_numpy(shape, k4, hdiff_efdt_ratio): def test_scale_k(backend): grid = simple_grid.simple_grid(backend=backend) - field = data_alloc.random_field(grid, dims.KDim, backend=backend) - scaled_field = data_alloc.zero_field(grid, dims.KDim, backend=backend) + field = data_alloc.random_field(grid, dims.KDim, allocator=backend) + scaled_field = data_alloc.zero_field(grid, dims.KDim, allocator=backend) factor = 2.0 diffusion_utils.scale_k.with_backend(backend)(field, factor, scaled_field, offset_provider={}) assert np.allclose(factor * field.asnumpy(), scaled_field.asnumpy()) @@ -36,8 +36,8 @@ def test_scale_k(backend): def test_diff_multfac_vn_and_smag_limit_for_initial_step(backend): grid = simple_grid.simple_grid(backend=backend) - diff_multfac_vn_init = data_alloc.zero_field(grid, dims.KDim, backend=backend) - smag_limit_init = data_alloc.zero_field(grid, dims.KDim, backend=backend) + diff_multfac_vn_init = data_alloc.zero_field(grid, dims.KDim, allocator=backend) + smag_limit_init = data_alloc.zero_field(grid, dims.KDim, allocator=backend) k4 = 1.0 efdt_ratio = 24.0 shape = diff_multfac_vn_init.asnumpy().shape @@ -57,8 +57,8 @@ def test_diff_multfac_vn_and_smag_limit_for_initial_step(backend): def test_diff_multfac_vn_smag_limit_for_time_step_with_const_value(backend): grid = simple_grid.simple_grid(backend=backend) - diff_multfac_vn = data_alloc.zero_field(grid, dims.KDim, backend=backend) - smag_limit = data_alloc.zero_field(grid, dims.KDim, backend=backend) + diff_multfac_vn = data_alloc.zero_field(grid, dims.KDim, allocator=backend) + smag_limit = data_alloc.zero_field(grid, dims.KDim, allocator=backend) k4 = 1.0 substeps = 5.0 efdt_ratio = 24.0 @@ -80,8 +80,8 @@ def test_diff_multfac_vn_smag_limit_for_time_step_with_const_value(backend): def test_diff_multfac_vn_smag_limit_for_loop_run_with_k4_substeps(backend): grid = simple_grid.simple_grid(backend=backend) - diff_multfac_vn = data_alloc.zero_field(grid, dims.KDim, backend=backend) - smag_limit = data_alloc.zero_field(grid, dims.KDim, backend=backend) + diff_multfac_vn = data_alloc.zero_field(grid, dims.KDim, allocator=backend) + smag_limit = data_alloc.zero_field(grid, dims.KDim, allocator=backend) k4 = 0.003 substeps = 1.0 diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py index 1e2f089c9e..2e1a80a0e4 100644 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py @@ -116,25 +116,25 @@ def allocate( ): return IntermediateFields( horizontal_pressure_gradient=data_alloc.zero_field( - grid, dims.EdgeDim, dims.KDim, backend=backend + grid, dims.EdgeDim, dims.KDim, allocator=backend ), rho_at_edges_on_model_levels=data_alloc.zero_field( - grid, dims.EdgeDim, dims.KDim, backend=backend + grid, dims.EdgeDim, dims.KDim, allocator=backend ), theta_v_at_edges_on_model_levels=data_alloc.zero_field( - grid, dims.EdgeDim, dims.KDim, backend=backend + grid, dims.EdgeDim, dims.KDim, allocator=backend ), horizontal_gradient_of_normal_wind_divergence=data_alloc.zero_field( - grid, dims.EdgeDim, dims.KDim, backend=backend + grid, dims.EdgeDim, dims.KDim, allocator=backend ), dwdz_at_cells_on_model_levels=data_alloc.zero_field( - grid, dims.CellDim, dims.KDim, backend=backend + grid, dims.CellDim, dims.KDim, allocator=backend ), horizontal_kinetic_energy_at_edges_on_model_levels=data_alloc.zero_field( - grid, dims.EdgeDim, dims.KDim, backend=backend + grid, dims.EdgeDim, dims.KDim, allocator=backend ), tangential_wind_on_half_levels=data_alloc.zero_field( - grid, dims.EdgeDim, dims.KDim, backend=backend + grid, dims.EdgeDim, dims.KDim, allocator=backend ), ) @@ -365,7 +365,6 @@ def __init__( exchange: decomposition.ExchangeRuntime = decomposition.SingleNodeExchange(), ): self._exchange = exchange - self._backend = backend self._grid = grid self._config = config @@ -378,7 +377,7 @@ def __init__( self._determine_local_domains() self._compute_theta_and_exner = setup_program( - backend=self._backend, + backend=backend, program=compute_theta_and_exner, constant_args={ "bdy_halo_c": self._metric_state_nonhydro.bdy_halo_c, @@ -396,7 +395,7 @@ def __init__( ) self._compute_exner_from_rhotheta = setup_program( - backend=self._backend, + backend=backend, program=compute_exner_from_rhotheta, constant_args={ "rd_o_cvd": constants.RD_O_CVD, @@ -413,7 +412,7 @@ def __init__( ) self._update_theta_v = setup_program( - backend=self._backend, + backend=backend, program=update_theta_v, constant_args={ "mask_prog_halo_c": self._metric_state_nonhydro.mask_prog_halo_c, @@ -429,7 +428,7 @@ def __init__( ) self._compute_hydrostatic_correction_term = setup_program( - backend=self._backend, + backend=backend, program=compute_hydrostatic_correction_term, constant_args={ "ikoffset": self._metric_state_nonhydro.vertoffset_gradp, @@ -450,7 +449,7 @@ def __init__( ) self._compute_theta_rho_face_values_and_pressure_gradient_and_update_vn = setup_program( - backend=self._backend, + backend=backend, program=compute_edge_diagnostics_for_dycore_and_update_vn.compute_theta_rho_face_values_and_pressure_gradient_and_update_vn, constant_args={ "reference_rho_at_edges_on_model_levels": self._metric_state_nonhydro.reference_rho_at_edges_on_model_levels, @@ -493,7 +492,7 @@ def __init__( ) self._apply_divergence_damping_and_update_vn = setup_program( - backend=self._backend, + backend=backend, program=compute_edge_diagnostics_for_dycore_and_update_vn.apply_divergence_damping_and_update_vn, constant_args={ "horizontal_mask_for_3d_divdamp": self._metric_state_nonhydro.horizontal_mask_for_3d_divdamp, @@ -523,7 +522,7 @@ def __init__( ) self._compute_horizontal_velocity_quantities_and_fluxes = setup_program( - backend=self._backend, + backend=backend, program=compute_horizontal_velocity_quantities_and_fluxes, constant_args={ "ddqz_z_full_e": self._metric_state_nonhydro.ddqz_z_full_e, @@ -548,7 +547,7 @@ def __init__( ) self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection = setup_program( - backend=self._backend, + backend=backend, program=compute_averaged_vn_and_fluxes_and_prepare_tracer_advection, constant_args={ "e_flx_avg": self._interpolation_state.e_flx_avg, @@ -570,7 +569,7 @@ def __init__( ) self._vertically_implicit_solver_at_predictor_step = setup_program( - backend=self._backend, + backend=backend, program=vertically_implicit_dycore_solver.vertically_implicit_solver_at_predictor_step, constant_args={ "geofac_div": self._interpolation_state.geofac_div, @@ -607,7 +606,7 @@ def __init__( ) self._vertically_implicit_solver_at_corrector_step = setup_program( - backend=self._backend, + backend=backend, program=vertically_implicit_dycore_solver.vertically_implicit_solver_at_corrector_step, constant_args={ "exner_w_explicit_weight_parameter": self._metric_state_nonhydro.exner_w_explicit_weight_parameter, @@ -640,7 +639,7 @@ def __init__( ) self._compute_dwdz_for_divergence_damping = setup_program( - backend=self._backend, + backend=backend, program=compute_dwdz_for_divergence_damping, constant_args={ "inv_ddqz_z_full": self._metric_state_nonhydro.inv_ddqz_z_full, @@ -657,7 +656,7 @@ def __init__( ) self._init_cell_kdim_field_with_zero_wp = setup_program( - backend=self._backend, + backend=backend, program=init_cell_kdim_field_with_zero_wp, horizontal_sizes={ "horizontal_start": self._start_cell_lateral_boundary, @@ -669,7 +668,7 @@ def __init__( }, ) self._update_mass_flux_weighted = setup_program( - backend=self._backend, + backend=backend, program=update_mass_flux_weighted, constant_args={ "vwind_expl_wgt": self._metric_state_nonhydro.exner_w_explicit_weight_parameter, @@ -685,7 +684,7 @@ def __init__( }, ) self._calculate_divdamp_fields = setup_program( - backend=self._backend, + backend=backend, program=dycore_utils.calculate_divdamp_fields, constant_args={ "divdamp_order": gtx.int32(self._config.divdamp_order), @@ -695,7 +694,7 @@ def __init__( }, ) self._compute_rayleigh_damping_factor = setup_program( - backend=self._backend, + backend=backend, program=dycore_utils.compute_rayleigh_damping_factor, constant_args={ "rayleigh_w": self._metric_state_nonhydro.rayleigh_w, @@ -703,7 +702,7 @@ def __init__( ) self._compute_perturbed_quantities_and_interpolation = setup_program( - backend=self._backend, + backend=backend, program=compute_cell_diagnostics_for_dycore.compute_perturbed_quantities_and_interpolation, constant_args={ "reference_rho_at_cells_on_model_levels": self._metric_state_nonhydro.reference_rho_at_cells_on_model_levels, @@ -742,7 +741,7 @@ def __init__( ) self._interpolate_rho_theta_v_to_half_levels_and_compute_pressure_buoyancy_acceleration = setup_program( - backend=self._backend, + backend=backend, program=compute_cell_diagnostics_for_dycore.interpolate_rho_theta_v_to_half_levels_and_compute_pressure_buoyancy_acceleration, constant_args={ "reference_theta_at_cells_on_model_levels": self._metric_state_nonhydro.reference_theta_at_cells_on_model_levels, @@ -764,7 +763,7 @@ def __init__( offset_provider=self._grid.connectivities, ) self._stencils_61_62 = setup_program( - backend=self._backend, + backend=backend, program=nhsolve_stencils.stencils_61_62, horizontal_sizes={ "horizontal_start": self._start_cell_lateral_boundary, @@ -776,7 +775,7 @@ def __init__( }, ) self._en_smag_fac_for_zero_nshift = setup_program( - backend=self._backend, + backend=backend, program=smagorinsky.en_smag_fac_for_zero_nshift, constant_args={ "vect_a": self._vertical_params.interface_physical_height, @@ -792,7 +791,7 @@ def __init__( offset_provider={"Koff": dims.KDim}, ) self._init_test_fields = setup_program( - backend=self._backend, + backend=backend, program=nhsolve_stencils.init_test_fields, horizontal_sizes={ "edges_start": self._start_edge_lateral_boundary, @@ -813,9 +812,9 @@ def __init__( vertical_params, edge_geometry, owner_mask, - backend=self._backend, + backend=backend, ) - self._allocate_local_fields() + self._allocate_local_fields(model_backends.get_allocator(backend)) self._en_smag_fac_for_zero_nshift( enh_smag_fac=self.interpolated_fourth_order_divdamp_factor, @@ -823,14 +822,14 @@ def __init__( self.p_test_run = True - def _allocate_local_fields(self): + def _allocate_local_fields(self, allocator): self.temporal_extrapolation_of_perturbed_exner = data_alloc.zero_field( self._grid, dims.CellDim, dims.KDim, dtype=ta.vpfloat, extend={dims.KDim: 1}, - backend=self._backend, + allocator=allocator, ) """ Declared as z_exner_ex_pr in ICON. @@ -841,14 +840,14 @@ def _allocate_local_fields(self): dims.KDim, dtype=ta.vpfloat, extend={dims.KDim: 1}, - backend=self._backend, + allocator=allocator, ) """ Declared as z_exner_ic in ICON. """ self.ddz_of_temporal_extrapolation_of_perturbed_exner_on_model_levels = ( data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.vpfloat, backend=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.vpfloat, allocator=allocator ) ) """ @@ -860,14 +859,14 @@ def _allocate_local_fields(self): dims.KDim, dtype=ta.vpfloat, extend={dims.KDim: 1}, - backend=self._backend, + allocator=allocator, ) """ Declared as z_theta_v_pr_ic in ICON. """ self.pressure_buoyancy_acceleration_at_cells_on_half_levels = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.vpfloat, backend=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.vpfloat, allocator=allocator ) """ Declared as z_th_ddz_exner_c in ICON. theta' dpi0/dz + theta (1 - eta_impl) dpi'/dz. @@ -876,85 +875,83 @@ def _allocate_local_fields(self): term for updating w, and w at model top/bottom is diagnosed. """ self.perturbed_rho_at_cells_on_model_levels = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.vpfloat, backend=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.vpfloat, allocator=allocator ) """ Declared as z_rth_pr_1 in ICON. """ self.perturbed_theta_v_at_cells_on_model_levels = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.vpfloat, backend=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.vpfloat, allocator=allocator ) """ Declared as z_rth_pr_2 in ICON. """ self.d2dz2_of_temporal_extrapolation_of_perturbed_exner_on_model_levels = ( data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.vpfloat, backend=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.vpfloat, allocator=allocator ) ) """ Declared as z_dexner_dz_c_2 in ICON. """ self.z_vn_avg = data_alloc.zero_field( - self._grid, dims.EdgeDim, dims.KDim, dtype=ta.wpfloat, backend=self._backend + self._grid, dims.EdgeDim, dims.KDim, dtype=ta.wpfloat, allocator=allocator ) self.theta_v_flux_at_edges_on_model_levels = data_alloc.zero_field( - self._grid, dims.EdgeDim, dims.KDim, dtype=ta.wpfloat, backend=self._backend + self._grid, dims.EdgeDim, dims.KDim, dtype=ta.wpfloat, allocator=allocator ) """ Declared as z_theta_v_fl_e in ICON. """ self.z_rho_v = data_alloc.zero_field( - self._grid, dims.VertexDim, dims.KDim, dtype=ta.wpfloat, backend=self._backend + self._grid, dims.VertexDim, dims.KDim, dtype=ta.wpfloat, allocator=allocator ) self.z_theta_v_v = data_alloc.zero_field( - self._grid, dims.VertexDim, dims.KDim, dtype=ta.wpfloat, backend=self._backend + self._grid, dims.VertexDim, dims.KDim, dtype=ta.wpfloat, allocator=allocator ) self.k_field = data_alloc.index_field( - self._grid, dims.KDim, extend={dims.KDim: 1}, backend=self._backend + self._grid, dims.KDim, extend={dims.KDim: 1}, allocator=allocator ) self._contravariant_correction_at_edges_on_model_levels = data_alloc.zero_field( - self._grid, dims.EdgeDim, dims.KDim, dtype=ta.vpfloat, backend=self._backend + self._grid, dims.EdgeDim, dims.KDim, dtype=ta.vpfloat, allocator=allocator ) """ Declared as z_w_concorr_me in ICON. vn dz/dn + vt dz/dt, z is topography height """ self.hydrostatic_correction_on_lowest_level = data_alloc.zero_field( - self._grid, dims.EdgeDim, dtype=ta.vpfloat, backend=self._backend + self._grid, dims.EdgeDim, dtype=ta.vpfloat, allocator=allocator ) self.hydrostatic_correction = data_alloc.zero_field( - self._grid, dims.EdgeDim, dims.KDim, dtype=ta.vpfloat, backend=self._backend + self._grid, dims.EdgeDim, dims.KDim, dtype=ta.vpfloat, allocator=allocator ) """ Declared as z_hydro_corr in ICON. Used for computation of horizontal pressure gradient over steep slope. """ self.rayleigh_damping_factor = data_alloc.zero_field( - self._grid, dims.KDim, dtype=ta.wpfloat, backend=self._backend + self._grid, dims.KDim, dtype=ta.wpfloat, allocator=allocator ) """ Declared as z_raylfac in ICON. """ self.interpolated_fourth_order_divdamp_factor = data_alloc.zero_field( - self._grid, dims.KDim, dtype=ta.wpfloat, backend=self._backend + self._grid, dims.KDim, dtype=ta.wpfloat, allocator=allocator ) """ Declared as enh_divdamp_fac in ICON. """ self.reduced_fourth_order_divdamp_coeff_at_nest_boundary = data_alloc.zero_field( - self._grid, dims.KDim, dtype=ta.wpfloat, backend=self._backend + self._grid, dims.KDim, dtype=ta.wpfloat, allocator=allocator ) """ Declared as bdy_divdamp in ICON. """ self.fourth_order_divdamp_scaling_coeff = data_alloc.zero_field( - self._grid, dims.KDim, dtype=ta.wpfloat, backend=self._backend + self._grid, dims.KDim, dtype=ta.wpfloat, allocator=allocator ) """ Declared as scal_divdamp in ICON. """ - self.intermediate_fields = IntermediateFields.allocate( - grid=self._grid, backend=self._backend - ) + self.intermediate_fields = IntermediateFields.allocate(grid=self._grid, allocator=allocator) def _determine_local_domains(self): vertex_domain = h_grid.domain(dims.VertexDim) diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/velocity_advection.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/velocity_advection.py index 0009168c06..a664c1034d 100644 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/velocity_advection.py +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/velocity_advection.py @@ -11,6 +11,7 @@ import gt4py.next as gtx import gt4py.next.typing as gtx_typing +from gt4py.next import allocators as gtx_allocators # TODO(havogt): expose in gtx_typing from icon4py.model.atmosphere.dycore import dycore_states from icon4py.model.atmosphere.dycore.stencils.compute_advection_in_horizontal_momentum_equation import ( @@ -60,15 +61,14 @@ def __init__( self.vertical_params = vertical_params self.edge_params = edge_params self.c_owner_mask = owner_mask - self._backend = backend self.cfl_w_limit: float = 0.65 self.scalfac_exdiff: float = 0.05 - self._allocate_local_fields() + self._allocate_local_fields(model_backends.get_allocator(backend)) self._determine_local_domains() self._compute_derived_horizontal_winds_and_ke_and_contravariant_correction = setup_program( - backend=self._backend, + backend=backend, program=compute_derived_horizontal_winds_and_ke_and_contravariant_correction, constant_args={ "rbf_vec_coeff_e": self.interpolation_state.rbf_vec_coeff_e, @@ -98,7 +98,7 @@ def __init__( # TODO(nfarabullini): add `skip_compute_predictor_vertical_advection` to `variants` once possible self._compute_contravariant_correction_and_advection_in_vertical_momentum_equation = setup_program( - backend=self._backend, + backend=backend, program=compute_contravariant_correction_and_advection_in_vertical_momentum_equation, constant_args={ "coeff1_dwdz": self.metric_state.coeff1_dwdz, @@ -123,7 +123,7 @@ def __init__( ) self._compute_advection_in_vertical_momentum_equation = setup_program( - backend=self._backend, + backend=backend, program=compute_advection_in_vertical_momentum_equation, constant_args={ "coeff1_dwdz": self.metric_state.coeff1_dwdz, @@ -150,7 +150,7 @@ def __init__( ) self._compute_advection_in_horizontal_momentum_equation = setup_program( - backend=self._backend, + backend=backend, program=compute_advection_in_horizontal_momentum_equation, constant_args={ "e_bln_c_s": self.interpolation_state.e_bln_c_s, @@ -179,23 +179,25 @@ def __init__( offset_provider=self.grid.connectivities, ) - def _allocate_local_fields(self): + def _allocate_local_fields( + self, allocator: gtx_allocators.FieldBufferAllocationUtil | None = None + ): self._horizontal_advection_of_w_at_edges_on_half_levels = data_alloc.zero_field( - self.grid, dims.EdgeDim, dims.KDim, backend=self._backend, dtype=ta.vpfloat + self.grid, dims.EdgeDim, dims.KDim, allocator=allocator, dtype=ta.vpfloat ) """ Declared as z_v_grad_w in ICON. vn dw/dn + vt dw/dt. NOTE THAT IT ONLY HAS nlev LEVELS because w[nlevp1-1] is diagnostic. """ self._contravariant_corrected_w_at_cells_on_model_levels = data_alloc.zero_field( - self.grid, dims.CellDim, dims.KDim, backend=self._backend, dtype=ta.vpfloat + self.grid, dims.CellDim, dims.KDim, allocator=allocator, dtype=ta.vpfloat ) """ Declared as z_w_con_c_full in ICON. w - (vn dz/dn + vt dz/dt), z is topography height """ self.vertical_cfl = data_alloc.zero_field( - self.grid, dims.CellDim, dims.KDim, backend=self._backend, dtype=ta.vpfloat + self.grid, dims.CellDim, dims.KDim, allocator=allocator, dtype=ta.vpfloat ) def _determine_local_domains(self): diff --git a/model/atmosphere/dycore/tests/dycore/integration_tests/test_solve_nonhydro.py b/model/atmosphere/dycore/tests/dycore/integration_tests/test_solve_nonhydro.py index de864464f7..507101684a 100644 --- a/model/atmosphere/dycore/tests/dycore/integration_tests/test_solve_nonhydro.py +++ b/model/atmosphere/dycore/tests/dycore/integration_tests/test_solve_nonhydro.py @@ -53,17 +53,17 @@ def test_validate_divdamp_fields_against_savepoint_values( interpolated_fourth_order_divdamp_factor = data_alloc.zero_field( icon_grid, dims.KDim, - backend=backend, + allocator=backend, ) fourth_order_divdamp_scaling_coeff = data_alloc.zero_field( icon_grid, dims.KDim, - backend=backend, + allocator=backend, ) reduced_fourth_order_divdamp_coeff_at_nest_boundary = data_alloc.zero_field( icon_grid, dims.KDim, - backend=backend, + allocator=backend, ) smagorinsky.en_smag_fac_for_zero_nshift.with_backend(backend)( grid_savepoint.vct_a(), @@ -526,7 +526,7 @@ def test_nonhydro_corrector_step( mass_flx_me=init_savepoint.mass_flx_me(), dynamical_vertical_mass_flux_at_cells_on_half_levels=init_savepoint.mass_flx_ic(), dynamical_vertical_volumetric_flux_at_cells_on_half_levels=data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, backend=backend + icon_grid, dims.CellDim, dims.KDim, allocator=backend ), ) @@ -736,7 +736,7 @@ def test_run_solve_nonhydro_single_step( mass_flx_me=sp.mass_flx_me(), dynamical_vertical_mass_flux_at_cells_on_half_levels=sp.mass_flx_ic(), dynamical_vertical_volumetric_flux_at_cells_on_half_levels=data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, backend=backend + icon_grid, dims.CellDim, dims.KDim, allocator=backend ), ) @@ -864,7 +864,7 @@ def test_run_solve_nonhydro_multi_step( mass_flx_me=sp.mass_flx_me(), dynamical_vertical_mass_flux_at_cells_on_half_levels=sp.mass_flx_ic(), dynamical_vertical_volumetric_flux_at_cells_on_half_levels=data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, backend=backend + icon_grid, dims.CellDim, dims.KDim, allocator=backend ), ) @@ -1060,28 +1060,28 @@ def test_compute_perturbed_quantities_and_interpolation( # local fields perturbed_rho_at_cells_on_model_levels = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, backend=backend + icon_grid, dims.CellDim, dims.KDim, allocator=backend ) perturbed_theta_v_at_cells_on_model_levels = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, backend=backend + icon_grid, dims.CellDim, dims.KDim, allocator=backend ) perturbed_theta_v_at_cells_on_half_levels = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, backend=backend + icon_grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, allocator=backend ) pressure_buoyancy_acceleration_at_cells_on_half_levels = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, backend=backend + icon_grid, dims.CellDim, dims.KDim, allocator=backend ) exner_at_cells_on_half_levels = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, backend=backend + icon_grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, allocator=backend ) temporal_extrapolation_of_perturbed_exner = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, backend=backend + icon_grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, allocator=backend ) ddz_of_temporal_extrapolation_of_perturbed_exner_on_model_levels = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, backend=backend + icon_grid, dims.CellDim, dims.KDim, allocator=backend ) d2dz2_of_temporal_extrapolation_of_perturbed_exner_on_model_levels = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, backend=backend + icon_grid, dims.CellDim, dims.KDim, allocator=backend ) limited_area = icon_grid.limited_area @@ -1279,10 +1279,10 @@ def test_interpolate_rho_theta_v_to_half_levels_and_compute_pressure_buoyancy_ac rhotheta_implicit_weight_parameter = sp_init.wgt_nnew_rth() perturbed_theta_v_at_cells_on_half_levels = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, backend=backend + icon_grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, allocator=backend ) pressure_buoyancy_acceleration_at_cells_on_half_levels = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, backend=backend + icon_grid, dims.CellDim, dims.KDim, allocator=backend ) cell_domain = h_grid.domain(dims.CellDim) @@ -1437,7 +1437,7 @@ def test_compute_theta_rho_face_values_and_pressure_gradient_and_update_vn( perturbed_rho_at_cells_on_model_levels = sp_stencil_init.z_rth_pr(0) perturbed_theta_v_at_cells_on_model_levels = sp_stencil_init.z_rth_pr(1) hydrostatic_correction = data_alloc.zero_field( - icon_grid, dims.EdgeDim, dims.KDim, backend=backend + icon_grid, dims.EdgeDim, dims.KDim, allocator=backend ) temporal_extrapolation_of_perturbed_exner = sp_stencil_init.z_exner_ex_pr() ddz_of_temporal_extrapolation_of_perturbed_exner_on_model_levels = ( diff --git a/model/atmosphere/dycore/tests/dycore/integration_tests/test_velocity_advection.py b/model/atmosphere/dycore/tests/dycore/integration_tests/test_velocity_advection.py index d2f3e8e141..e957d556fd 100644 --- a/model/atmosphere/dycore/tests/dycore/integration_tests/test_velocity_advection.py +++ b/model/atmosphere/dycore/tests/dycore/integration_tests/test_velocity_advection.py @@ -480,7 +480,7 @@ def test_compute_derived_horizontal_winds_and_ke_and_contravariant_correction( vn_on_half_levels = savepoint_velocity_init.vn_ie() horizontal_kinetic_energy_at_edges_on_model_levels = savepoint_velocity_init.z_kin_hor_e() horizontal_advection_of_w_at_edges_on_half_levels = data_alloc.zero_field( - icon_grid, dims.EdgeDim, dims.KDim, backend=backend + icon_grid, dims.EdgeDim, dims.KDim, allocator=backend ) vn = savepoint_velocity_init.vn() w = savepoint_velocity_init.w() @@ -622,7 +622,7 @@ def test_compute_contravariant_correction_and_advection_in_vertical_momentum_equ vertical_wind_advective_tendency = savepoint_velocity_init.ddt_w_adv_pc(istep_init - 1) contravariant_corrected_w_at_cells_on_model_levels = savepoint_velocity_init.z_w_con_c_full() vertical_cfl = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, dtype=ta.vpfloat, backend=backend + icon_grid, dims.CellDim, dims.KDim, dtype=ta.vpfloat, allocator=backend ) skip_compute_predictor_vertical_advection = savepoint_velocity_init.lvn_only() @@ -774,7 +774,7 @@ def test_compute_advection_in_vertical_momentum_equation( vertical_wind_advective_tendency = savepoint_velocity_init.ddt_w_adv_pc(istep_init - 1) contravariant_corrected_w_at_cells_on_model_levels = savepoint_velocity_init.z_w_con_c_full() vertical_cfl = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, dtype=ta.vpfloat, backend=backend + icon_grid, dims.CellDim, dims.KDim, dtype=ta.vpfloat, allocator=backend ) coeff1_dwdz = metrics_savepoint.coeff1_dwdz() diff --git a/model/atmosphere/dycore/tests/dycore/mpi_tests/test_parallel_solve_nonhydro.py b/model/atmosphere/dycore/tests/dycore/mpi_tests/test_parallel_solve_nonhydro.py index 0bffc0aecc..7e29d9e0c3 100644 --- a/model/atmosphere/dycore/tests/dycore/mpi_tests/test_parallel_solve_nonhydro.py +++ b/model/atmosphere/dycore/tests/dycore/mpi_tests/test_parallel_solve_nonhydro.py @@ -101,7 +101,7 @@ def test_run_solve_nonhydro_single_step( mass_flx_me=sp.mass_flx_me(), dynamical_vertical_mass_flux_at_cells_on_half_levels=sp.mass_flx_ic(), dynamical_vertical_volumetric_flux_at_cells_on_half_levels=data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, backend=backend + icon_grid, dims.CellDim, dims.KDim, allocator=backend ), ) diff --git a/model/atmosphere/dycore/tests/dycore/stencil_tests/test_dycore_utils.py b/model/atmosphere/dycore/tests/dycore/stencil_tests/test_dycore_utils.py index 048108ef93..803e5de30d 100644 --- a/model/atmosphere/dycore/tests/dycore/stencil_tests/test_dycore_utils.py +++ b/model/atmosphere/dycore/tests/dycore/stencil_tests/test_dycore_utils.py @@ -48,9 +48,9 @@ def test_calculate_fourth_order_divdamp_scaling_coeff_order_24( mean_cell_area = 1000.0 grid = simple_grid.simple_grid(backend=backend) interpolated_fourth_order_divdamp_factor = data_alloc.random_field( - grid, dims.KDim, backend=backend + grid, dims.KDim, allocator=backend ) - out = data_alloc.random_field(grid, dims.KDim, backend=backend) + out = data_alloc.random_field(grid, dims.KDim, allocator=backend) dycore_utils._calculate_fourth_order_divdamp_scaling_coeff.with_backend(backend)( interpolated_fourth_order_divdamp_factor=interpolated_fourth_order_divdamp_factor, @@ -77,9 +77,9 @@ def test_calculate_fourth_order_divdamp_scaling_coeff_any_order( mean_cell_area = 1000.0 grid = simple_grid.simple_grid(backend=backend) interpolated_fourth_order_divdamp_factor = data_alloc.random_field( - grid, dims.KDim, backend=backend + grid, dims.KDim, allocator=backend ) - out = data_alloc.random_field(grid, dims.KDim, backend=backend) + out = data_alloc.random_field(grid, dims.KDim, allocator=backend) dycore_utils._calculate_fourth_order_divdamp_scaling_coeff.with_backend(backend)( interpolated_fourth_order_divdamp_factor=interpolated_fourth_order_divdamp_factor, @@ -97,8 +97,8 @@ def test_calculate_reduced_fourth_order_divdamp_coeff_at_nest_boundary( backend: gtx_typing.Backend, ) -> None: grid = simple_grid.simple_grid(backend=backend) - fourth_order_divdamp_scaling_coeff = data_alloc.random_field(grid, dims.KDim, backend=backend) - out = data_alloc.zero_field(grid, dims.KDim, backend=backend) + fourth_order_divdamp_scaling_coeff = data_alloc.random_field(grid, dims.KDim, allocator=backend) + out = data_alloc.zero_field(grid, dims.KDim, allocator=backend) coeff = 0.3 dycore_utils._calculate_reduced_fourth_order_divdamp_coeff_at_nest_boundary.with_backend( backend @@ -113,10 +113,10 @@ def test_calculate_reduced_fourth_order_divdamp_coeff_at_nest_boundary( def test_calculate_divdamp_fields(backend: gtx_typing.Backend) -> None: grid = simple_grid.simple_grid(backend=backend) - divdamp_field = data_alloc.random_field(grid, dims.KDim, backend=backend) - fourth_order_divdamp_scaling_coeff = data_alloc.zero_field(grid, dims.KDim, backend=backend) + divdamp_field = data_alloc.random_field(grid, dims.KDim, allocator=backend) + fourth_order_divdamp_scaling_coeff = data_alloc.zero_field(grid, dims.KDim, allocator=backend) reduced_fourth_order_divdamp_coeff_at_nest_boundary = data_alloc.zero_field( - grid, dims.KDim, backend=backend + grid, dims.KDim, allocator=backend ) divdamp_order = gtx.int32(24) mean_cell_area = 1000.0 diff --git a/model/atmosphere/dycore/tests/dycore/utils.py b/model/atmosphere/dycore/tests/dycore/utils.py index 3ad48aad6e..438af72e2c 100644 --- a/model/atmosphere/dycore/tests/dycore/utils.py +++ b/model/atmosphere/dycore/tests/dycore/utils.py @@ -122,11 +122,11 @@ def construct_diagnostics( tangential_wind=init_savepoint.vt(), vn_on_half_levels=init_savepoint.vn_ie(), contravariant_correction_at_cells_on_half_levels=init_savepoint.w_concorr_c(), - rho_iau_increment=data_alloc.zero_field(grid, dims.CellDim, dims.KDim, backend=backend), + rho_iau_increment=data_alloc.zero_field(grid, dims.CellDim, dims.KDim, allocator=backend), normal_wind_iau_increment=data_alloc.zero_field( - grid, dims.EdgeDim, dims.KDim, backend=backend + grid, dims.EdgeDim, dims.KDim, allocator=backend ), - exner_iau_increment=data_alloc.zero_field(grid, dims.CellDim, dims.KDim, backend=backend), + exner_iau_increment=data_alloc.zero_field(grid, dims.CellDim, dims.KDim, allocator=backend), exner_dynamical_increment=init_savepoint.exner_dyn_incr(), ) diff --git a/model/atmosphere/subgrid_scale_physics/microphysics/src/icon4py/model/atmosphere/subgrid_scale_physics/microphysics/saturation_adjustment.py b/model/atmosphere/subgrid_scale_physics/microphysics/src/icon4py/model/atmosphere/subgrid_scale_physics/microphysics/saturation_adjustment.py index e9a2c0551b..a0d9b2fb45 100644 --- a/model/atmosphere/subgrid_scale_physics/microphysics/src/icon4py/model/atmosphere/subgrid_scale_physics/microphysics/saturation_adjustment.py +++ b/model/atmosphere/subgrid_scale_physics/microphysics/src/icon4py/model/atmosphere/subgrid_scale_physics/microphysics/saturation_adjustment.py @@ -123,23 +123,23 @@ def output_properties(self) -> dict[str, model.FieldMetaData]: def _allocate_local_variables(self): #: it was originally named as tworkold in ICON. Old temperature before iteration. self._temperature1 = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend ) #: it was originally named as twork in ICON. New temperature before iteration. self._temperature2 = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend ) #: A mask that indicates whether the grid cell is subsaturated or not. self._subsaturated_mask = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=bool, backend=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=bool, allocator=self._backend ) #: A mask that indicates whether next Newton iteration is required. self._newton_iteration_mask = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=bool, backend=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=bool, allocator=self._backend ) #: latent heat vaporization / dry air heat capacity at constant volume self._lwdocvd = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend ) def _initialize_gt4py_programs(self): diff --git a/model/atmosphere/subgrid_scale_physics/microphysics/src/icon4py/model/atmosphere/subgrid_scale_physics/microphysics/single_moment_six_class_gscp_graupel.py b/model/atmosphere/subgrid_scale_physics/microphysics/src/icon4py/model/atmosphere/subgrid_scale_physics/microphysics/single_moment_six_class_gscp_graupel.py index 44f2b35bd2..a62413314a 100644 --- a/model/atmosphere/subgrid_scale_physics/microphysics/src/icon4py/model/atmosphere/subgrid_scale_physics/microphysics/single_moment_six_class_gscp_graupel.py +++ b/model/atmosphere/subgrid_scale_physics/microphysics/src/icon4py/model/atmosphere/subgrid_scale_physics/microphysics/single_moment_six_class_gscp_graupel.py @@ -219,43 +219,43 @@ def _initialize_configurable_parameters(self): def _initialize_local_fields(self): self.rhoqrv_old_kup = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend ) self.rhoqsv_old_kup = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend ) self.rhoqgv_old_kup = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend ) self.rhoqiv_old_kup = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend ) self.vnew_r = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend ) self.vnew_s = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend ) self.vnew_g = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend ) self.vnew_i = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend ) self.rain_precipitation_flux = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend ) self.snow_precipitation_flux = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend ) self.graupel_precipitation_flux = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend ) self.ice_precipitation_flux = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend ) self.total_precipitation_flux = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend ) def _determine_horizontal_domains(self): diff --git a/model/atmosphere/subgrid_scale_physics/microphysics/tests/microphysics/integration_tests/test_saturation_adjustment.py b/model/atmosphere/subgrid_scale_physics/microphysics/tests/microphysics/integration_tests/test_saturation_adjustment.py index 0170e6364e..12af8065ae 100644 --- a/model/atmosphere/subgrid_scale_physics/microphysics/tests/microphysics/integration_tests/test_saturation_adjustment.py +++ b/model/atmosphere/subgrid_scale_physics/microphysics/tests/microphysics/integration_tests/test_saturation_adjustment.py @@ -70,10 +70,10 @@ def test_saturation_adjustement( vct_b=grid_savepoint.vct_b(), ) temperature_tendency = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, backend=backend + icon_grid, dims.CellDim, dims.KDim, allocator=backend ) - qv_tendency = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, backend=backend) - qc_tendency = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, backend=backend) + qv_tendency = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, allocator=backend) + qc_tendency = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, allocator=backend) metric_state = satad.MetricStateSaturationAdjustment( ddqz_z_full=metrics_savepoint.ddqz_z_full() diff --git a/model/atmosphere/subgrid_scale_physics/microphysics/tests/microphysics/integration_tests/test_single_moment_six_class_gscp_graupel.py b/model/atmosphere/subgrid_scale_physics/microphysics/tests/microphysics/integration_tests/test_single_moment_six_class_gscp_graupel.py index 9c0577905f..a41c893126 100644 --- a/model/atmosphere/subgrid_scale_physics/microphysics/tests/microphysics/integration_tests/test_single_moment_six_class_gscp_graupel.py +++ b/model/atmosphere/subgrid_scale_physics/microphysics/tests/microphysics/integration_tests/test_single_moment_six_class_gscp_graupel.py @@ -124,25 +124,25 @@ def test_graupel( qnc = entry_savepoint.qnc() temperature_tendency = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=backend + icon_grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=backend ) qv_tendency = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=backend + icon_grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=backend ) qc_tendency = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=backend + icon_grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=backend ) qr_tendency = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=backend + icon_grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=backend ) qi_tendency = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=backend + icon_grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=backend ) qs_tendency = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=backend + icon_grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=backend ) qg_tendency = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=backend + icon_grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=backend ) graupel_microphysics.run( diff --git a/model/common/src/icon4py/model/common/metrics/metrics_factory.py b/model/common/src/icon4py/model/common/metrics/metrics_factory.py index 1ebde3e79c..c1894e2072 100644 --- a/model/common/src/icon4py/model/common/metrics/metrics_factory.py +++ b/model/common/src/icon4py/model/common/metrics/metrics_factory.py @@ -97,9 +97,9 @@ def __init__( } k_index = data_alloc.index_field( - self._grid, dims.KDim, extend={dims.KDim: 1}, backend=self._backend + self._grid, dims.KDim, extend={dims.KDim: 1}, allocator=self._backend ) - e_lev = data_alloc.index_field(self._grid, dims.EdgeDim, backend=self._backend) + e_lev = data_alloc.index_field(self._grid, dims.EdgeDim, allocator=self._backend) e_owner_mask = gtx.as_field( (dims.EdgeDim,), self._decomposition_info.owner_mask(dims.EdgeDim) ) diff --git a/model/common/src/icon4py/model/common/model_backends.py b/model/common/src/icon4py/model/common/model_backends.py index bd2102be7c..c0365d48c9 100644 --- a/model/common/src/icon4py/model/common/model_backends.py +++ b/model/common/src/icon4py/model/common/model_backends.py @@ -9,6 +9,7 @@ import gt4py.next as gtx import gt4py.next.typing as gtx_typing +from gt4py.next import allocators as gtx_allocators, backend as gtx_backend from gt4py.next.program_processors.runners.gtfn import GTFNBackendFactory from icon4py.model.common import dimension as dims @@ -39,6 +40,18 @@ def is_backend_descriptor( return False +def get_allocator( + backend: gtx_typing.Backend | DeviceType | BackendDescriptor | None, +) -> gtx_typing.Backend | None: + if backend is None or isinstance(backend, gtx_backend.Backend): + return backend + if is_backend_descriptor(backend): + backend = backend["device"] + if isinstance(backend, DeviceType): + return gtx_allocators.device_allocators[backend] + raise ValueError(f"Cannot get allocator from {backend}") + + try: from gt4py.next.program_processors.runners.dace import make_dace_backend diff --git a/model/common/src/icon4py/model/common/states/factory.py b/model/common/src/icon4py/model/common/states/factory.py index f9e9a3c3f4..b24a816d66 100644 --- a/model/common/src/icon4py/model/common/states/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -331,7 +331,7 @@ def _compute(self, factory: FieldSource, grid_provider: GridProvider) -> None: log.debug(f"transferring dependencies to compute backend: {self._dependencies.keys()}") deps = { - k: data_alloc.as_field(factory.get(v), backend=compute_backend) + k: data_alloc.as_field(factory.get(v), allocator=compute_backend) for k, v in self._dependencies.items() } @@ -343,7 +343,7 @@ def _compute(self, factory: FieldSource, grid_provider: GridProvider) -> None: f"transferring result {k} to target backend: " f"{data_alloc.backend_name(factory.backend)}" ) - self._fields[k] = data_alloc.as_field(v, backend=factory.backend) + self._fields[k] = data_alloc.as_field(v, allocator=factory.backend) def _unravel_output_fields(self): out_fields = tuple(self._fields.values()) diff --git a/model/common/src/icon4py/model/common/utils/data_allocation.py b/model/common/src/icon4py/model/common/utils/data_allocation.py index 1eda6d5e7a..fd3ac89d94 100644 --- a/model/common/src/icon4py/model/common/utils/data_allocation.py +++ b/model/common/src/icon4py/model/common/utils/data_allocation.py @@ -70,12 +70,12 @@ def import_array_ns(allocator: gtx_allocators.FieldBufferAllocationUtil | None) def as_field( field: gtx.Field, - backend: gtx_typing.Backend | None = None, + allocator: gtx_typing.Backend | None = None, embedded_on_host: bool = False, ) -> gtx.Field: """Convenience function to transfer an existing Field to a given backend.""" data = field.asnumpy() if embedded_on_host else field.ndarray - return gtx.as_field(field.domain, data=data, allocator=backend) + return gtx.as_field(field.domain, data=data, allocator=allocator) def random_field( @@ -85,14 +85,14 @@ def random_field( high: float = 1.0, dtype: npt.DTypeLike | None = None, extend: dict[gtx.Dimension, int] | None = None, - backend=None, + allocator=None, ) -> gtx.Field: arr = np.random.default_rng().uniform( low=low, high=high, size=_shape(grid, *dims, extend=extend) ) if dtype: arr = arr.astype(dtype) - return gtx.as_field(dims, arr, allocator=backend) + return gtx.as_field(dims, arr, allocator=allocator) def random_sign( @@ -100,13 +100,13 @@ def random_sign( *dims, dtype: npt.DTypeLike | None = None, extend: dict[gtx.Dimension, int] | None = None, - backend=None, + allocator=None, ) -> gtx.Field: """Generate a random field with values -1 or 1.""" arr = np.random.default_rng().choice([-1, 1], size=_shape(grid, *dims, extend=extend)) if dtype: arr = arr.astype(dtype) - return gtx.as_field(dims, arr, allocator=backend) + return gtx.as_field(dims, arr, allocator=allocator) def random_mask( @@ -114,7 +114,7 @@ def random_mask( *dims: gtx.Dimension, dtype: npt.DTypeLike | None = None, extend: dict[gtx.Dimension, int] | None = None, - backend: gtx_typing.Backend | None = None, + allocator: gtx_typing.Backend | None = None, ) -> gtx.Field: rng = np.random.default_rng() shape = _shape(grid, *dims, extend=extend) @@ -125,7 +125,7 @@ def random_mask( arr = np.reshape(arr, newshape=shape) if dtype: arr = arr.astype(dtype) - return gtx.as_field(dims, arr, allocator=backend) + return gtx.as_field(dims, arr, allocator=allocator) def zero_field( @@ -133,10 +133,10 @@ def zero_field( *dims: gtx.Dimension, dtype=ta.wpfloat, extend: dict[gtx.Dimension, int] | None = None, - backend=None, + allocator=None, ) -> gtx.Field: field_domain = {dim: (0, stop) for dim, stop in zip(dims, _shape(grid, *dims, extend=extend))} - return gtx.constructors.zeros(field_domain, dtype=dtype, allocator=backend) + return gtx.constructors.zeros(field_domain, dtype=dtype, allocator=allocator) def constant_field( @@ -144,12 +144,12 @@ def constant_field( value: float, *dims: gtx.Dimension, dtype=ta.wpfloat, - backend=None, + allocator=None, ) -> gtx.Field: return gtx.as_field( dims, value * np.ones(shape=tuple(map(lambda x: grid.size[x], dims)), dtype=dtype), - allocator=backend, + allocator=allocator, ) @@ -167,8 +167,8 @@ def index_field( dim: gtx.Dimension, extend: dict[gtx.Dimension, int] | None = None, dtype=gtx.int32, - backend: gtx_typing.Backend | None = None, + allocator: gtx_typing.Backend | None = None, ) -> gtx.Field: - xp = import_array_ns(backend) + xp = import_array_ns(allocator) shapex = _shape(grid, dim, extend=extend)[0] - return gtx.as_field((dim,), xp.arange(shapex, dtype=dtype), allocator=backend) + return gtx.as_field((dim,), xp.arange(shapex, dtype=dtype), allocator=allocator) diff --git a/model/common/tests/common/diagnostic_calculations/unit_tests/test_diagnostic_calculations.py b/model/common/tests/common/diagnostic_calculations/unit_tests/test_diagnostic_calculations.py index aa836cd0c7..c021830fb3 100644 --- a/model/common/tests/common/diagnostic_calculations/unit_tests/test_diagnostic_calculations.py +++ b/model/common/tests/common/diagnostic_calculations/unit_tests/test_diagnostic_calculations.py @@ -58,18 +58,18 @@ def test_diagnose_temperature( theta_v = initial_prognostic_savepoint.theta_v_now() temperature = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, dtype=float, backend=backend + icon_grid, dims.CellDim, dims.KDim, dtype=float, allocator=backend ) virtual_temperature = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, dtype=float, backend=backend + icon_grid, dims.CellDim, dims.KDim, dtype=float, allocator=backend ) - qv = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float, backend=backend) - qc = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float, backend=backend) - qr = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float, backend=backend) - qi = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float, backend=backend) - qs = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float, backend=backend) - qg = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float, backend=backend) + qv = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float, allocator=backend) + qc = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float, allocator=backend) + qr = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float, allocator=backend) + qi = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float, allocator=backend) + qs = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float, allocator=backend) + qg = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float, allocator=backend) diagnose_temperature.diagnose_virtual_temperature_and_temperature.with_backend(backend)( qv=qv, @@ -117,8 +117,8 @@ def test_diagnose_meridional_and_zonal_winds( u_ref = diagnostics_reference_savepoint.zonal_wind().asnumpy() v_ref = diagnostics_reference_savepoint.meridional_wind().asnumpy() - u = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float, backend=backend) - v = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float, backend=backend) + u = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float, allocator=backend) + v = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float, allocator=backend) cell_domain = h_grid.domain(dims.CellDim) cell_end_lateral_boundary_level_2 = icon_grid.end_index( @@ -169,7 +169,7 @@ def test_diagnose_surface_pressure( ddqz_z_full = metrics_savepoint.ddqz_z_full() surface_pressure = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, dtype=float, extend={dims.KDim: 1}, backend=backend + icon_grid, dims.CellDim, dims.KDim, dtype=float, extend={dims.KDim: 1}, allocator=backend ) cell_domain = h_grid.domain(dims.CellDim) @@ -210,12 +210,12 @@ def test_diagnose_pressure( pressure_ref = diagnostics_reference_savepoint.pressure().asnumpy() pressure = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, dtype=float, backend=backend + icon_grid, dims.CellDim, dims.KDim, dtype=float, allocator=backend ) cell_domain = h_grid.domain(dims.CellDim) pressure_ifc = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, dtype=float, extend={dims.KDim: 1}, backend=backend + icon_grid, dims.CellDim, dims.KDim, dtype=float, extend={dims.KDim: 1}, allocator=backend ) pressure_ifc.ndarray[:, -1] = surface_pressure.ndarray @@ -274,9 +274,9 @@ def test_diagnostic_update_after_saturation_adjustement( vct_b=grid_savepoint.vct_b(), ) virtual_temperature_tendency = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, backend=backend + icon_grid, dims.CellDim, dims.KDim, allocator=backend ) - exner_tendency = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, backend=backend) + exner_tendency = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, allocator=backend) tracer_state = tracers.TracerState( qv=satad_exit.qv(), diff --git a/model/common/tests/common/grid/unit_tests/test_base.py b/model/common/tests/common/grid/unit_tests/test_base.py index a5576653ae..6a289340b0 100644 --- a/model/common/tests/common/grid/unit_tests/test_base.py +++ b/model/common/tests/common/grid/unit_tests/test_base.py @@ -20,7 +20,7 @@ def test_replace_skip_values(backend: gtx_typing.Backend) -> None: domain = (dims.CellDim, dims.C2E2CDim) xp = data_alloc.import_array_ns(backend) neighbor_table = data_alloc.random_field( - grid, *domain, low=0, high=grid.num_cells, dtype=gtx.int32, backend=backend + grid, *domain, low=0, high=grid.num_cells, dtype=gtx.int32, allocator=backend ).ndarray neighbor_table[0, 1:] = gridfile.GridFile.INVALID_INDEX # type: ignore[index] # NDArrayObject Protocol doesn't support this diff --git a/model/common/tests/common/grid/unit_tests/test_geometry.py b/model/common/tests/common/grid/unit_tests/test_geometry.py index f5e3cc8d06..ce0e40b70d 100644 --- a/model/common/tests/common/grid/unit_tests/test_geometry.py +++ b/model/common/tests/common/grid/unit_tests/test_geometry.py @@ -350,7 +350,7 @@ def test_cartesian_centers_edge( assert y.ndarray.shape == (grid.num_edges,) assert z.ndarray.shape == (grid.num_edges,) # those are coordinates on the unit sphere: hence norm should be 1 - norm = data_alloc.zero_field(grid, dims.EdgeDim, dtype=x.dtype, backend=backend) + norm = data_alloc.zero_field(grid, dims.EdgeDim, dtype=x.dtype, allocator=backend) math_helpers.norm2_on_edges(x, z, y, out=norm, offset_provider={}) assert test_utils.dallclose(norm.asnumpy(), 1.0) @@ -367,7 +367,7 @@ def test_cartesian_centers_cell( assert y.ndarray.shape == (grid.num_cells,) assert z.ndarray.shape == (grid.num_cells,) # those are coordinates on the unit sphere: hence norm should be 1 - norm = data_alloc.zero_field(grid, dims.CellDim, dtype=x.dtype, backend=backend) + norm = data_alloc.zero_field(grid, dims.CellDim, dtype=x.dtype, allocator=backend) math_helpers.norm2_on_cells(x, z, y, out=norm, offset_provider={}) assert test_utils.dallclose(norm.asnumpy(), 1.0) @@ -382,7 +382,7 @@ def test_vertex(backend: gtx_typing.Backend, experiment: definitions.Experiment) assert y.ndarray.shape == (grid.num_vertices,) assert z.ndarray.shape == (grid.num_vertices,) # those are coordinates on the unit sphere: hence norm should be 1 - norm = data_alloc.zero_field(grid, dims.VertexDim, dtype=x.dtype, backend=backend) + norm = data_alloc.zero_field(grid, dims.VertexDim, dtype=x.dtype, allocator=backend) math_helpers.norm2_on_vertices(x, z, y, out=norm, offset_provider={}) assert test_utils.dallclose(norm.asnumpy(), 1.0) diff --git a/model/common/tests/common/grid/unit_tests/test_vertical.py b/model/common/tests/common/grid/unit_tests/test_vertical.py index 5d45ab3098..9d3fd7b7e3 100644 --- a/model/common/tests/common/grid/unit_tests/test_vertical.py +++ b/model/common/tests/common/grid/unit_tests/test_vertical.py @@ -402,7 +402,7 @@ def test_compute_vertical_coordinate( topography = topography_savepoint.topo_c() elif experiment == definitions.Experiments.EXCLAIM_APE: topography = data_alloc.zero_field( - icon_grid, dims.CellDim, backend=backend, dtype=ta.wpfloat + icon_grid, dims.CellDim, allocator=backend, dtype=ta.wpfloat ) geofac_n2s = interpolation_savepoint.geofac_n2s() diff --git a/model/common/tests/common/interpolation/unit_tests/test_compute_nudgecoeffs.py b/model/common/tests/common/interpolation/unit_tests/test_compute_nudgecoeffs.py index d8b29d7d66..b1617ebba1 100644 --- a/model/common/tests/common/interpolation/unit_tests/test_compute_nudgecoeffs.py +++ b/model/common/tests/common/interpolation/unit_tests/test_compute_nudgecoeffs.py @@ -44,7 +44,7 @@ def test_compute_nudgecoeffs_e( icon_grid: base_grid.Grid, backend: gtx_typing.Backend, ) -> None: - nudgecoeff_e = data_alloc.zero_field(icon_grid, dims.EdgeDim, dtype=wpfloat, backend=backend) + nudgecoeff_e = data_alloc.zero_field(icon_grid, dims.EdgeDim, dtype=wpfloat, allocator=backend) nudgecoeff_e_ref = interpolation_savepoint.nudgecoeff_e() refin_ctrl = grid_savepoint.refin_ctrl(dims.EdgeDim) grf_nudge_start_e = refinement.get_nudging_refinement_value(dims.EdgeDim) diff --git a/model/common/tests/common/math/unit_tests/test_helpers.py b/model/common/tests/common/math/unit_tests/test_helpers.py index 344c6e1813..001f9aacb2 100644 --- a/model/common/tests/common/math/unit_tests/test_helpers.py +++ b/model/common/tests/common/math/unit_tests/test_helpers.py @@ -23,15 +23,15 @@ def test_cross_product(backend): mesh = simple.simple_grid(backend=backend) - x1 = data_alloc.random_field(mesh, dims.EdgeDim, backend=backend) - y1 = data_alloc.random_field(mesh, dims.EdgeDim, backend=backend) - z1 = data_alloc.random_field(mesh, dims.EdgeDim, backend=backend) - x2 = data_alloc.random_field(mesh, dims.EdgeDim, backend=backend) - y2 = data_alloc.random_field(mesh, dims.EdgeDim, backend=backend) - z2 = data_alloc.random_field(mesh, dims.EdgeDim, backend=backend) - x = data_alloc.zero_field(mesh, dims.EdgeDim, backend=backend) - y = data_alloc.zero_field(mesh, dims.EdgeDim, backend=backend) - z = data_alloc.zero_field(mesh, dims.EdgeDim, backend=backend) + x1 = data_alloc.random_field(mesh, dims.EdgeDim, allocator=backend) + y1 = data_alloc.random_field(mesh, dims.EdgeDim, allocator=backend) + z1 = data_alloc.random_field(mesh, dims.EdgeDim, allocator=backend) + x2 = data_alloc.random_field(mesh, dims.EdgeDim, allocator=backend) + y2 = data_alloc.random_field(mesh, dims.EdgeDim, allocator=backend) + z2 = data_alloc.random_field(mesh, dims.EdgeDim, allocator=backend) + x = data_alloc.zero_field(mesh, dims.EdgeDim, allocator=backend) + y = data_alloc.zero_field(mesh, dims.EdgeDim, allocator=backend) + z = data_alloc.zero_field(mesh, dims.EdgeDim, allocator=backend) helpers.cross_product_on_edges.with_backend(backend)( x1, x2, y1, y2, z1, z2, out=(x, y, z), offset_provider={} diff --git a/model/common/tests/common/math/unit_tests/test_smagorinsky.py b/model/common/tests/common/math/unit_tests/test_smagorinsky.py index 51858cbe60..b4b8f8485a 100644 --- a/model/common/tests/common/math/unit_tests/test_smagorinsky.py +++ b/model/common/tests/common/math/unit_tests/test_smagorinsky.py @@ -17,9 +17,9 @@ def test_init_enh_smag_fac(backend, grid): - enh_smag_fac = data_alloc.zero_field(grid, dims.KDim, backend=backend) + enh_smag_fac = data_alloc.zero_field(grid, dims.KDim, allocator=backend) a_vec = data_alloc.random_field( - grid, dims.KDim, low=1.0, high=10.0, extend={dims.KDim: 1}, backend=backend + grid, dims.KDim, low=1.0, high=10.0, extend={dims.KDim: 1}, allocator=backend ) fac = (0.67, 0.5, 1.3, 0.8) z = (0.1, 0.2, 0.3, 0.4) diff --git a/model/common/tests/common/metrics/unit_tests/test_compute_diffusion_metrics.py b/model/common/tests/common/metrics/unit_tests/test_compute_diffusion_metrics.py index af9c0e4eae..57330f033f 100644 --- a/model/common/tests/common/metrics/unit_tests/test_compute_diffusion_metrics.py +++ b/model/common/tests/common/metrics/unit_tests/test_compute_diffusion_metrics.py @@ -60,11 +60,11 @@ def test_compute_diffusion_mask_and_coeff( if experiment == definitions.Experiments.EXCLAIM_APE: pytest.skip(f"Fields not computed for {experiment}") - maxslp_avg = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, backend=backend) - maxhgtd_avg = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, backend=backend) - maxslp = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, backend=backend) - maxhgtd = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, backend=backend) - max_nbhgt = data_alloc.zero_field(icon_grid, dims.CellDim, backend=backend) + maxslp_avg = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, allocator=backend) + maxhgtd_avg = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, allocator=backend) + maxslp = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, allocator=backend) + maxhgtd = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, allocator=backend) + max_nbhgt = data_alloc.zero_field(icon_grid, dims.CellDim, allocator=backend) c2e2c = icon_grid.get_connectivity(dims.C2E2C).asnumpy() c_bln_avg = interpolation_savepoint.c_bln_avg() @@ -146,11 +146,11 @@ def test_compute_diffusion_intcoef_and_vertoffset( if experiment == definitions.Experiments.EXCLAIM_APE: pytest.skip(f"Fields not computed for {experiment}") - maxslp_avg = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, backend=backend) - maxhgtd_avg = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, backend=backend) - maxslp = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, backend=backend) - maxhgtd = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, backend=backend) - max_nbhgt = data_alloc.zero_field(icon_grid, dims.CellDim, backend=backend) + maxslp_avg = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, allocator=backend) + maxhgtd_avg = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, allocator=backend) + maxslp = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, allocator=backend) + maxhgtd = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, allocator=backend) + max_nbhgt = data_alloc.zero_field(icon_grid, dims.CellDim, allocator=backend) c2e2c = icon_grid.get_connectivity(dims.C2E2C).asnumpy() c_bln_avg = interpolation_savepoint.c_bln_avg() diff --git a/model/common/tests/common/metrics/unit_tests/test_compute_weight_factors.py b/model/common/tests/common/metrics/unit_tests/test_compute_weight_factors.py index 235ab596de..3ccf085268 100644 --- a/model/common/tests/common/metrics/unit_tests/test_compute_weight_factors.py +++ b/model/common/tests/common/metrics/unit_tests/test_compute_weight_factors.py @@ -44,7 +44,12 @@ def test_compute_wgtfac_c( backend: gtx_typing.Backend | None, ) -> None: wgtfac_c = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, extend={dims.KDim: 1}, backend=backend + icon_grid, + dims.CellDim, + dims.KDim, + dtype=ta.wpfloat, + extend={dims.KDim: 1}, + allocator=backend, ) wgtfac_c_ref = metrics_savepoint.wgtfac_c() z_ifc = metrics_savepoint.z_ifc() diff --git a/model/common/tests/common/metrics/unit_tests/test_compute_zdiff_gradp_dsl.py b/model/common/tests/common/metrics/unit_tests/test_compute_zdiff_gradp_dsl.py index d94cc5f7bc..0fb3f99c2f 100644 --- a/model/common/tests/common/metrics/unit_tests/test_compute_zdiff_gradp_dsl.py +++ b/model/common/tests/common/metrics/unit_tests/test_compute_zdiff_gradp_dsl.py @@ -55,9 +55,9 @@ def test_compute_zdiff_gradp_dsl( z_ifc = metrics_savepoint.z_ifc() z_ifc_ground_level = z_ifc.ndarray[:, icon_grid.num_levels] z_mc = metrics_savepoint.z_mc() - k_lev = data_alloc.index_field(icon_grid, dims.KDim, dtype=gtx.int32, backend=backend) + k_lev = data_alloc.index_field(icon_grid, dims.KDim, dtype=gtx.int32, allocator=backend) flat_idx = data_alloc.zero_field( - icon_grid, dims.EdgeDim, dims.KDim, dtype=gtx.int32, backend=backend + icon_grid, dims.EdgeDim, dims.KDim, dtype=gtx.int32, allocator=backend ) edge_domain = h_grid.domain(dims.EdgeDim) horizontal_start_edge = icon_grid.start_index(edge_domain(h_grid.Zone.LATERAL_BOUNDARY_LEVEL_2)) diff --git a/model/common/tests/common/metrics/unit_tests/test_metric_fields.py b/model/common/tests/common/metrics/unit_tests/test_metric_fields.py index 8842452e06..5c662900c8 100644 --- a/model/common/tests/common/metrics/unit_tests/test_metric_fields.py +++ b/model/common/tests/common/metrics/unit_tests/test_metric_fields.py @@ -57,7 +57,7 @@ def test_compute_ddq_z_half( nlevp1 = icon_grid.num_levels + 1 z_mc = metrics_savepoint.z_mc() ddqz_z_half = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, backend=backend + icon_grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, allocator=backend ) mf.compute_ddqz_z_half.with_backend(backend=backend)( @@ -84,8 +84,8 @@ def test_compute_ddqz_z_full_and_inverse( ) -> None: z_ifc = metrics_savepoint.z_ifc() inv_ddqz_full_ref = metrics_savepoint.inv_ddqz_z_full() - ddqz_z_full = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, backend=backend) - inv_ddqz_z_full = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, backend=backend) + ddqz_z_full = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, allocator=backend) + inv_ddqz_z_full = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, allocator=backend) mf.compute_ddqz_z_full_and_inverse.with_backend(backend)( z_ifc=z_ifc, @@ -110,7 +110,7 @@ def test_compute_scaling_factor_for_3d_divdamp( backend: gtx_typing.Backend, ) -> None: scalfac_dd3d_ref = metrics_savepoint.scalfac_dd3d() - scaling_factor_for_3d_divdamp = data_alloc.zero_field(icon_grid, dims.KDim, backend=backend) + scaling_factor_for_3d_divdamp = data_alloc.zero_field(icon_grid, dims.KDim, allocator=backend) divdamp_trans_start = 12500.0 divdamp_trans_end = 17500.0 divdamp_type = 3 @@ -143,7 +143,7 @@ def test_compute_rayleigh_w( rayleigh_w_ref = metrics_savepoint.rayleigh_w() vct_a_1 = grid_savepoint.vct_a().asnumpy()[0] rayleigh_w_full = data_alloc.zero_field( - icon_grid, dims.KDim, extend={dims.KDim: 1}, backend=backend + icon_grid, dims.KDim, extend={dims.KDim: 1}, allocator=backend ) rayleigh_type = 2 rayleigh_coeff = 0.1 if experiment == definitions.Experiments.EXCLAIM_APE else 5.0 @@ -172,8 +172,8 @@ def test_compute_coeff_dwdz( coeff1_dwdz_ref = metrics_savepoint.coeff1_dwdz() coeff2_dwdz_ref = metrics_savepoint.coeff2_dwdz() - coeff1_dwdz_full = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, backend=backend) - coeff2_dwdz_full = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, backend=backend) + coeff1_dwdz_full = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, allocator=backend) + coeff2_dwdz_full = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, allocator=backend) ddqz_z_full = gtx.as_field( (dims.CellDim, dims.KDim), 1 / metrics_savepoint.inv_ddqz_z_full().asnumpy(), @@ -202,7 +202,7 @@ def test_compute_exner_w_explicit_weight_parameter( icon_grid: base_grid.Grid, metrics_savepoint: sb.MetricSavepoint, backend: gtx_typing.Backend ) -> None: exner_w_explicit_weight_parameter_full = data_alloc.zero_field( - icon_grid, dims.CellDim, backend=backend + icon_grid, dims.CellDim, allocator=backend ) vwind_expl_wgt_ref = metrics_savepoint.vwind_expl_wgt() exner_w_implicit_weight_parameter = metrics_savepoint.vwind_impl_wgt() @@ -232,7 +232,7 @@ def test_compute_exner_exfac( ) -> None: horizontal_start = icon_grid.start_index(cell_domain(horizontal.Zone.LATERAL_BOUNDARY_LEVEL_2)) exner_expol = 0.333 if experiment == definitions.Experiments.MCH_CH_R04B09 else 0.3333333333333 - exner_exfac = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, backend=backend) + exner_exfac = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, allocator=backend) exner_exfac_ref = metrics_savepoint.exner_exfac() mf.compute_exner_exfac.with_backend(backend)( ddxn_z_full=metrics_savepoint.ddxn_z_full(), @@ -265,10 +265,10 @@ def test_compute_exner_w_implicit_weight_parameter( tangent_orientation = grid_savepoint.tangent_orientation() inv_primal_edge_length = grid_savepoint.inverse_primal_edge_lengths() z_ddxn_z_half_e = data_alloc.zero_field( - icon_grid, dims.EdgeDim, dims.KDim, extend={dims.KDim: 1}, backend=backend + icon_grid, dims.EdgeDim, dims.KDim, extend={dims.KDim: 1}, allocator=backend ) z_ddxt_z_half_e = data_alloc.zero_field( - icon_grid, dims.EdgeDim, dims.KDim, extend={dims.KDim: 1}, backend=backend + icon_grid, dims.EdgeDim, dims.KDim, extend={dims.KDim: 1}, allocator=backend ) horizontal_start = icon_grid.start_index(edge_domain(horizontal.Zone.LATERAL_BOUNDARY_LEVEL_2)) @@ -342,7 +342,7 @@ def test_compute_wgtfac_e( backend: gtx_typing.Backend, ) -> None: wgtfac_e = data_alloc.zero_field( - icon_grid, dims.EdgeDim, dims.KDim, extend={dims.KDim: 1}, backend=backend + icon_grid, dims.EdgeDim, dims.KDim, extend={dims.KDim: 1}, allocator=backend ) wgtfac_e_ref = metrics_savepoint.wgtfac_e() mf.compute_wgtfac_e.with_backend(backend)( @@ -378,16 +378,16 @@ def test_compute_pressure_gradient_downward_extrapolation_mask_distance( c_lin_e = interpolation_savepoint.c_lin_e() topography = gtx.as_field((dims.CellDim,), z_ifc.ndarray[:, nlev], allocator=backend) # type: ignore[arg-type] # TODO(havogt): needs fix in GT4Py - k = data_alloc.index_field(icon_grid, dim=dims.KDim, extend={dims.KDim: 1}, backend=backend) - edges = data_alloc.index_field(icon_grid, dim=dims.EdgeDim, backend=backend) + k = data_alloc.index_field(icon_grid, dim=dims.KDim, extend={dims.KDim: 1}, allocator=backend) + edges = data_alloc.index_field(icon_grid, dim=dims.EdgeDim, allocator=backend) flat_idx = data_alloc.zero_field( - icon_grid, dims.EdgeDim, dims.KDim, dtype=gtx.int32, backend=backend + icon_grid, dims.EdgeDim, dims.KDim, dtype=gtx.int32, allocator=backend ) edge_mask = data_alloc.zero_field( - icon_grid, dims.EdgeDim, dims.KDim, dtype=bool, backend=backend + icon_grid, dims.EdgeDim, dims.KDim, dtype=bool, allocator=backend ) - ex_distance = data_alloc.zero_field(icon_grid, dims.EdgeDim, dims.KDim, backend=backend) + ex_distance = data_alloc.zero_field(icon_grid, dims.EdgeDim, dims.KDim, allocator=backend) start_edge_nudging = icon_grid.end_index(edge_domain(horizontal.Zone.NUDGING)) start_edge_nudging_2 = icon_grid.start_index(edge_domain(horizontal.Zone.NUDGING_LEVEL_2)) @@ -451,7 +451,7 @@ def test_compute_mask_prog_halo_c( backend: gtx_typing.Backend, ) -> None: mask_prog_halo_c_full = data_alloc.zero_field( - icon_grid, dims.CellDim, dtype=bool, backend=backend + icon_grid, dims.CellDim, dtype=bool, allocator=backend ) c_refin_ctrl = grid_savepoint.refin_ctrl(dims.CellDim) mask_prog_halo_c_ref = metrics_savepoint.mask_prog_halo_c() @@ -476,7 +476,7 @@ def test_compute_bdy_halo_c( grid_savepoint: sb.IconGridSavepoint, backend: gtx_typing.Backend, ) -> None: - bdy_halo_c_full = data_alloc.zero_field(icon_grid, dims.CellDim, dtype=bool, backend=backend) + bdy_halo_c_full = data_alloc.zero_field(icon_grid, dims.CellDim, dtype=bool, allocator=backend) c_refin_ctrl = grid_savepoint.refin_ctrl(dims.CellDim) bdy_halo_c_ref = metrics_savepoint.bdy_halo_c() horizontal_start = icon_grid.start_index(cell_domain(horizontal.Zone.HALO)) @@ -501,7 +501,9 @@ def test_compute_horizontal_mask_for_3d_divdamp( grid_savepoint: sb.IconGridSavepoint, backend: gtx_typing.Backend, ) -> None: - horizontal_mask_for_3d_divdamp = data_alloc.zero_field(icon_grid, dims.EdgeDim, backend=backend) + horizontal_mask_for_3d_divdamp = data_alloc.zero_field( + icon_grid, dims.EdgeDim, allocator=backend + ) e_refin_ctrl = grid_savepoint.refin_ctrl(dims.EdgeDim) horizontal_start = icon_grid.start_index(edge_domain(horizontal.Zone.LATERAL_BOUNDARY_LEVEL_2)) hmask_dd3d_ref = metrics_savepoint.hmask_dd3d() @@ -527,8 +529,8 @@ def test_compute_theta_exner_ref_mc( icon_grid: base_grid.Grid, backend: gtx_typing.Backend, ) -> None: - exner_ref_mc_full = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, backend=backend) - theta_ref_mc_full = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, backend=backend) + exner_ref_mc_full = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, allocator=backend) + theta_ref_mc_full = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, allocator=backend) t0sl_bg = constants.SEA_LEVEL_TEMPERATURE del_t_bg = constants.DELTA_TEMPERATURE h_scal_bg = constants.HEIGHT_SCALE_FOR_REFERENCE_ATMOSPHERE diff --git a/model/common/tests/common/metrics/unit_tests/test_reference_atmosphere.py b/model/common/tests/common/metrics/unit_tests/test_reference_atmosphere.py index c655ebc942..871bd78b23 100644 --- a/model/common/tests/common/metrics/unit_tests/test_reference_atmosphere.py +++ b/model/common/tests/common/metrics/unit_tests/test_reference_atmosphere.py @@ -60,13 +60,13 @@ def test_compute_reference_atmosphere_fields_on_full_level_masspoints( z_mc = metrics_savepoint.z_mc() exner_ref_mc = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=backend + icon_grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=backend ) rho_ref_mc = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=backend + icon_grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=backend ) theta_ref_mc = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, backend=backend + icon_grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=backend ) compute_reference_atmosphere_cell_fields.with_backend(backend)( z_height=z_mc, @@ -103,13 +103,28 @@ def test_compute_reference_atmosphere_on_half_level_mass_points( z_ifc = metrics_savepoint.z_ifc() exner_ref_ic = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, dtype=ta.wpfloat, backend=backend + icon_grid, + dims.CellDim, + dims.KDim, + extend={dims.KDim: 1}, + dtype=ta.wpfloat, + allocator=backend, ) rho_ref_ic = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, dtype=ta.wpfloat, backend=backend + icon_grid, + dims.CellDim, + dims.KDim, + extend={dims.KDim: 1}, + dtype=ta.wpfloat, + allocator=backend, ) theta_ref_ic = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, dtype=ta.wpfloat, backend=backend + icon_grid, + dims.CellDim, + dims.KDim, + extend={dims.KDim: 1}, + dtype=ta.wpfloat, + allocator=backend, ) compute_reference_atmosphere_cell_fields.with_backend(backend=backend)( z_height=z_ifc, @@ -143,7 +158,7 @@ def test_compute_d_exner_dz_ref_ic( theta_ref_ic = metrics_savepoint.theta_ref_ic() d_exner_dz_ref_ic_ref = metrics_savepoint.d_exner_dz_ref_ic() d_exner_dz_ref_ic = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, backend=backend + icon_grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, allocator=backend ) compute_d_exner_dz_ref_ic.with_backend(backend)( theta_ref_ic=theta_ref_ic, @@ -171,7 +186,7 @@ def test_compute_reference_atmosphere_on_full_level_edge_fields( z_mc = metrics_savepoint.z_mc() z_me = data_alloc.zero_field( - icon_grid, dims.EdgeDim, dims.KDim, dtype=ta.wpfloat, backend=backend + icon_grid, dims.EdgeDim, dims.KDim, dtype=ta.wpfloat, allocator=backend ) horizontal_start = icon_grid.start_index( horizontal.domain(dims.EdgeDim)(horizontal.Zone.LATERAL_BOUNDARY_LEVEL_2) @@ -220,8 +235,8 @@ def test_compute_d2dexdz2_fac_mc( d2dexdz2_fac1_mc_ref = metrics_savepoint.d2dexdz2_fac1_mc() d2dexdz2_fac2_mc_ref = metrics_savepoint.d2dexdz2_fac2_mc() - d2dexdz2_fac1_mc = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, backend=backend) - d2dexdz2_fac2_mc = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, backend=backend) + d2dexdz2_fac1_mc = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, allocator=backend) + d2dexdz2_fac2_mc = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, allocator=backend) compute_d2dexdz2_fac_mc.with_backend(backend=backend)( theta_ref_mc=metrics_savepoint.theta_ref_mc(), diff --git a/model/driver/src/icon4py/model/driver/initialization_utils.py b/model/driver/src/icon4py/model/driver/initialization_utils.py index 11403c3956..e24f1302ee 100644 --- a/model/driver/src/icon4py/model/driver/initialization_utils.py +++ b/model/driver/src/icon4py/model/driver/initialization_utils.py @@ -152,11 +152,11 @@ def model_initialization_serialbox( tangential_wind=velocity_init_savepoint.vt(), vn_on_half_levels=velocity_init_savepoint.vn_ie(), contravariant_correction_at_cells_on_half_levels=velocity_init_savepoint.w_concorr_c(), - rho_iau_increment=data_alloc.zero_field(grid, dims.CellDim, dims.KDim, backend=backend), + rho_iau_increment=data_alloc.zero_field(grid, dims.CellDim, dims.KDim, allocator=backend), normal_wind_iau_increment=data_alloc.zero_field( - grid, dims.EdgeDim, dims.KDim, backend=backend + grid, dims.EdgeDim, dims.KDim, allocator=backend ), - exner_iau_increment=data_alloc.zero_field(grid, dims.CellDim, dims.KDim, backend=backend), + exner_iau_increment=data_alloc.zero_field(grid, dims.CellDim, dims.KDim, allocator=backend), exner_dynamical_increment=solve_nonhydro_init_savepoint.exner_dyn_incr(), ) @@ -165,34 +165,34 @@ def model_initialization_serialbox( grid, dims.CellDim, dims.KDim, - backend=backend, + allocator=backend, ), pressure_ifc=data_alloc.zero_field( - grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, backend=backend + grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, allocator=backend ), temperature=data_alloc.zero_field( grid, dims.CellDim, dims.KDim, - backend=backend, + allocator=backend, ), virtual_temperature=data_alloc.zero_field( grid, dims.CellDim, dims.KDim, - backend=backend, + allocator=backend, ), u=data_alloc.zero_field( grid, dims.CellDim, dims.KDim, - backend=backend, + allocator=backend, ), v=data_alloc.zero_field( grid, dims.CellDim, dims.KDim, - backend=backend, + allocator=backend, ), ) @@ -212,7 +212,7 @@ def model_initialization_serialbox( grid, dims.CellDim, dims.KDim, - backend=backend, + allocator=backend, ), ) diff --git a/model/driver/src/icon4py/model/driver/testcases/gauss3d.py b/model/driver/src/icon4py/model/driver/testcases/gauss3d.py index 97eb0d5a30..368aa6f588 100644 --- a/model/driver/src/icon4py/model/driver/testcases/gauss3d.py +++ b/model/driver/src/icon4py/model/driver/testcases/gauss3d.py @@ -162,7 +162,7 @@ def model_initialization_gauss3d( # noqa: PLR0915 [too-many-statements] log.info("Hydrostatic adjustment computation completed.") eta_v = gtx.as_field((dims.CellDim, dims.KDim), eta_v_ndarray, allocator=backend) - eta_v_e = data_alloc.zero_field(grid, dims.EdgeDim, dims.KDim, backend=backend) + eta_v_e = data_alloc.zero_field(grid, dims.EdgeDim, dims.KDim, allocator=backend) cell_2_edge_interpolation.cell_2_edge_interpolation.with_backend(backend)( eta_v, cell_2_edge_coeff, @@ -220,7 +220,7 @@ def model_initialization_gauss3d( # noqa: PLR0915 [too-many-statements] ) log.info("U, V computation completed.") - perturbed_exner = data_alloc.zero_field(grid, dims.CellDim, dims.KDim, backend=backend) + perturbed_exner = data_alloc.zero_field(grid, dims.CellDim, dims.KDim, allocator=backend) testcases_utils.compute_perturbed_exner.with_backend(backend)( exner, data_provider.from_metrics_savepoint().exner_ref_mc(), diff --git a/model/driver/src/icon4py/model/driver/testcases/jablonowski_williamson.py b/model/driver/src/icon4py/model/driver/testcases/jablonowski_williamson.py index 2797e68edb..26a012d906 100644 --- a/model/driver/src/icon4py/model/driver/testcases/jablonowski_williamson.py +++ b/model/driver/src/icon4py/model/driver/testcases/jablonowski_williamson.py @@ -71,25 +71,25 @@ def model_initialization_jabw( # noqa: PLR0915 [too-many-statements] xp = data_alloc.import_array_ns(backend) wgtfac_c = data_alloc.as_field( - data_provider.from_metrics_savepoint().wgtfac_c(), backend=backend + data_provider.from_metrics_savepoint().wgtfac_c(), allocator=backend ).ndarray ddqz_z_half = data_alloc.as_field( - data_provider.from_metrics_savepoint().ddqz_z_half(), backend=backend + data_provider.from_metrics_savepoint().ddqz_z_half(), allocator=backend ).ndarray theta_ref_mc = data_alloc.as_field( - data_provider.from_metrics_savepoint().theta_ref_mc(), backend=backend + data_provider.from_metrics_savepoint().theta_ref_mc(), allocator=backend ).ndarray theta_ref_ic = data_alloc.as_field( - data_provider.from_metrics_savepoint().theta_ref_ic(), backend=backend + data_provider.from_metrics_savepoint().theta_ref_ic(), allocator=backend ).ndarray exner_ref_mc = data_alloc.as_field( - data_provider.from_metrics_savepoint().exner_ref_mc(), backend=backend + data_provider.from_metrics_savepoint().exner_ref_mc(), allocator=backend ).ndarray d_exner_dz_ref_ic = data_alloc.as_field( - data_provider.from_metrics_savepoint().d_exner_dz_ref_ic(), backend=backend + data_provider.from_metrics_savepoint().d_exner_dz_ref_ic(), allocator=backend ).ndarray geopot = data_alloc.as_field( - data_provider.from_metrics_savepoint().geopot(), backend=backend + data_provider.from_metrics_savepoint().geopot(), allocator=backend ).ndarray cell_lat = cell_param.cell_center_lat.ndarray @@ -98,13 +98,13 @@ def model_initialization_jabw( # noqa: PLR0915 [too-many-statements] primal_normal_x = edge_param.primal_normal[0].ndarray cell_2_edge_coeff = data_alloc.as_field( - data_provider.from_interpolation_savepoint().c_lin_e(), backend=backend + data_provider.from_interpolation_savepoint().c_lin_e(), allocator=backend ) rbf_vec_coeff_c1 = data_alloc.as_field( - data_provider.from_interpolation_savepoint().rbf_vec_coeff_c1(), backend=backend + data_provider.from_interpolation_savepoint().rbf_vec_coeff_c1(), allocator=backend ) rbf_vec_coeff_c2 = data_alloc.as_field( - data_provider.from_interpolation_savepoint().rbf_vec_coeff_c2(), backend=backend + data_provider.from_interpolation_savepoint().rbf_vec_coeff_c2(), allocator=backend ) num_cells = grid.num_cells @@ -219,7 +219,7 @@ def model_initialization_jabw( # noqa: PLR0915 [too-many-statements] log.info("Newton iteration completed!") eta_v = gtx.as_field((dims.CellDim, dims.KDim), eta_v_ndarray, allocator=backend) - eta_v_e = data_alloc.zero_field(grid, dims.EdgeDim, dims.KDim, backend=backend) + eta_v_e = data_alloc.zero_field(grid, dims.EdgeDim, dims.KDim, allocator=backend) cell_2_edge_interpolation.cell_2_edge_interpolation.with_backend(backend)( eta_v, cell_2_edge_coeff, @@ -308,7 +308,7 @@ def model_initialization_jabw( # noqa: PLR0915 [too-many-statements] log.info("U, V computation completed.") - perturbed_exner = data_alloc.zero_field(grid, dims.CellDim, dims.KDim, backend=backend) + perturbed_exner = data_alloc.zero_field(grid, dims.CellDim, dims.KDim, allocator=backend) testcases_utils.compute_perturbed_exner.with_backend(backend)( exner, data_provider.from_metrics_savepoint().exner_ref_mc(), diff --git a/model/driver/src/icon4py/model/driver/testcases/utils.py b/model/driver/src/icon4py/model/driver/testcases/utils.py index df5b450fd9..2492cce7da 100644 --- a/model/driver/src/icon4py/model/driver/testcases/utils.py +++ b/model/driver/src/icon4py/model/driver/testcases/utils.py @@ -222,16 +222,16 @@ def initialize_diffusion_diagnostic_state( ) -> diffusion_states.DiffusionDiagnosticState: return diffusion_states.DiffusionDiagnosticState( hdef_ic=data_alloc.zero_field( - grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, backend=backend + grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, allocator=backend ), div_ic=data_alloc.zero_field( - grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, backend=backend + grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, allocator=backend ), dwdx=data_alloc.zero_field( - grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, backend=backend + grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, allocator=backend ), dwdy=data_alloc.zero_field( - grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, backend=backend + grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, allocator=backend ), ) @@ -242,57 +242,57 @@ def initialize_solve_nonhydro_diagnostic_state( backend: gtx_typing.Backend | None, ) -> dycore_states.DiagnosticStateNonHydro: normal_wind_advective_tendency = common_utils.PredictorCorrectorPair( - data_alloc.zero_field(grid, dims.EdgeDim, dims.KDim, backend=backend), - data_alloc.zero_field(grid, dims.EdgeDim, dims.KDim, backend=backend), + data_alloc.zero_field(grid, dims.EdgeDim, dims.KDim, allocator=backend), + data_alloc.zero_field(grid, dims.EdgeDim, dims.KDim, allocator=backend), ) vertical_wind_advective_tendency = common_utils.PredictorCorrectorPair( data_alloc.zero_field( - grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, backend=backend + grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, allocator=backend ), data_alloc.zero_field( - grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, backend=backend + grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, allocator=backend ), ) return dycore_states.DiagnosticStateNonHydro( max_vertical_cfl=0.0, theta_v_at_cells_on_half_levels=data_alloc.zero_field( - grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, backend=backend + grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, allocator=backend ), perturbed_exner_at_cells_on_model_levels=perturbed_exner_at_cells_on_model_levels, rho_at_cells_on_half_levels=data_alloc.zero_field( - grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, backend=backend + grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, allocator=backend ), exner_tendency_due_to_slow_physics=data_alloc.zero_field( - grid, dims.CellDim, dims.KDim, backend=backend + grid, dims.CellDim, dims.KDim, allocator=backend ), - grf_tend_rho=data_alloc.zero_field(grid, dims.CellDim, dims.KDim, backend=backend), - grf_tend_thv=data_alloc.zero_field(grid, dims.CellDim, dims.KDim, backend=backend), + grf_tend_rho=data_alloc.zero_field(grid, dims.CellDim, dims.KDim, allocator=backend), + grf_tend_thv=data_alloc.zero_field(grid, dims.CellDim, dims.KDim, allocator=backend), grf_tend_w=data_alloc.zero_field( - grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, backend=backend + grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, allocator=backend ), mass_flux_at_edges_on_model_levels=data_alloc.zero_field( - grid, dims.EdgeDim, dims.KDim, backend=backend + grid, dims.EdgeDim, dims.KDim, allocator=backend ), normal_wind_tendency_due_to_slow_physics_process=data_alloc.zero_field( - grid, dims.EdgeDim, dims.KDim, backend=backend + grid, dims.EdgeDim, dims.KDim, allocator=backend ), - grf_tend_vn=data_alloc.zero_field(grid, dims.EdgeDim, dims.KDim, backend=backend), + grf_tend_vn=data_alloc.zero_field(grid, dims.EdgeDim, dims.KDim, allocator=backend), normal_wind_advective_tendency=normal_wind_advective_tendency, vertical_wind_advective_tendency=vertical_wind_advective_tendency, - tangential_wind=data_alloc.zero_field(grid, dims.EdgeDim, dims.KDim, backend=backend), + tangential_wind=data_alloc.zero_field(grid, dims.EdgeDim, dims.KDim, allocator=backend), vn_on_half_levels=data_alloc.zero_field( - grid, dims.EdgeDim, dims.KDim, extend={dims.KDim: 1}, backend=backend + grid, dims.EdgeDim, dims.KDim, extend={dims.KDim: 1}, allocator=backend ), contravariant_correction_at_cells_on_half_levels=data_alloc.zero_field( - grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, backend=backend + grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, allocator=backend ), - rho_iau_increment=data_alloc.zero_field(grid, dims.CellDim, dims.KDim, backend=backend), + rho_iau_increment=data_alloc.zero_field(grid, dims.CellDim, dims.KDim, allocator=backend), normal_wind_iau_increment=data_alloc.zero_field( - grid, dims.EdgeDim, dims.KDim, backend=backend + grid, dims.EdgeDim, dims.KDim, allocator=backend ), - exner_iau_increment=data_alloc.zero_field(grid, dims.CellDim, dims.KDim, backend=backend), + exner_iau_increment=data_alloc.zero_field(grid, dims.CellDim, dims.KDim, allocator=backend), exner_dynamical_increment=data_alloc.zero_field( - grid, dims.CellDim, dims.KDim, backend=backend + grid, dims.CellDim, dims.KDim, allocator=backend ), ) @@ -301,13 +301,13 @@ def initialize_prep_advection( grid: icon_grid.IconGrid, backend: gtx_typing.Backend | None ) -> dycore_states.PrepAdvection: return dycore_states.PrepAdvection( - vn_traj=data_alloc.zero_field(grid, dims.EdgeDim, dims.KDim, backend=backend), - mass_flx_me=data_alloc.zero_field(grid, dims.EdgeDim, dims.KDim, backend=backend), + vn_traj=data_alloc.zero_field(grid, dims.EdgeDim, dims.KDim, allocator=backend), + mass_flx_me=data_alloc.zero_field(grid, dims.EdgeDim, dims.KDim, allocator=backend), dynamical_vertical_mass_flux_at_cells_on_half_levels=data_alloc.zero_field( - grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, backend=backend + grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, allocator=backend ), dynamical_vertical_volumetric_flux_at_cells_on_half_levels=data_alloc.zero_field( - grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, backend=backend + grid, dims.CellDim, dims.KDim, extend={dims.KDim: 1}, allocator=backend ), ) @@ -359,8 +359,8 @@ def create_gt4py_field_for_prognostic_and_diagnostic_variables( rho_next = gtx.as_field((dims.CellDim, dims.KDim), rho_ndarray, allocator=backend) theta_v_next = gtx.as_field((dims.CellDim, dims.KDim), theta_v_ndarray, allocator=backend) - u = data_alloc.zero_field(grid, dims.CellDim, dims.KDim, backend=backend) - v = data_alloc.zero_field(grid, dims.CellDim, dims.KDim, backend=backend) + u = data_alloc.zero_field(grid, dims.CellDim, dims.KDim, allocator=backend) + v = data_alloc.zero_field(grid, dims.CellDim, dims.KDim, allocator=backend) return ( vn, diff --git a/model/driver/tests/driver/integration_tests/test_icon4py.py b/model/driver/tests/driver/integration_tests/test_icon4py.py index fccd3ccd4c..6dad55cfcc 100644 --- a/model/driver/tests/driver/integration_tests/test_icon4py.py +++ b/model/driver/tests/driver/integration_tests/test_icon4py.py @@ -257,7 +257,7 @@ def test_run_timeloop_single_step( icon_grid, dims.CellDim, dims.KDim, - backend=backend, + allocator=backend, ), ) @@ -284,13 +284,13 @@ def test_run_timeloop_single_step( vn_on_half_levels=sp_v.vn_ie(), contravariant_correction_at_cells_on_half_levels=sp_v.w_concorr_c(), rho_iau_increment=data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, backend=backend + icon_grid, dims.CellDim, dims.KDim, allocator=backend ), # sp.rho_incr(), normal_wind_iau_increment=data_alloc.zero_field( - icon_grid, dims.EdgeDim, dims.KDim, backend=backend + icon_grid, dims.EdgeDim, dims.KDim, allocator=backend ), # sp.vn_incr(), exner_iau_increment=data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, backend=backend + icon_grid, dims.CellDim, dims.KDim, allocator=backend ), # sp.exner_incr(), exner_dynamical_increment=sp.exner_dyn_incr(), ) diff --git a/model/testing/src/icon4py/model/testing/grid_utils.py b/model/testing/src/icon4py/model/testing/grid_utils.py index d3ffc5779f..c2bc46a68c 100644 --- a/model/testing/src/icon4py/model/testing/grid_utils.py +++ b/model/testing/src/icon4py/model/testing/grid_utils.py @@ -109,7 +109,7 @@ def construct_decomposition_info( xp = data_alloc.array_ns(on_gpu) def _add_dimension(dim: gtx.Dimension) -> None: - indices = data_alloc.index_field(grid, dim, backend=backend) + indices = data_alloc.index_field(grid, dim, allocator=backend) owner_mask = xp.ones((grid.size[dim],), dtype=bool) decomposition_info.with_dimension(dim, indices.ndarray, owner_mask) diff --git a/tools/src/icon4py/tools/py2fgen/wrappers/common.py b/tools/src/icon4py/tools/py2fgen/wrappers/common.py index 3e199d659e..04c1f6c18f 100644 --- a/tools/src/icon4py/tools/py2fgen/wrappers/common.py +++ b/tools/src/icon4py/tools/py2fgen/wrappers/common.py @@ -19,6 +19,7 @@ import numpy as np from gt4py import eve from gt4py._core import definitions as gt4py_definitions +from gt4py.next import allocators as gtx_allocators from icon4py.model.common import dimension as dims, model_backends from icon4py.model.common.decomposition import definitions, mpi_decomposition @@ -67,15 +68,26 @@ class BackendIntEnum(eve.IntEnum): } -def select_backend(selector: BackendIntEnum, on_gpu: bool) -> gtx_typing.Backend: - default_cpu = BackendIntEnum._GTFN_CPU - default_gpu = BackendIntEnum._GTFN_GPU - if selector == BackendIntEnum.DEFAULT: - selector = BackendIntEnum.DEFAULT_GPU if on_gpu else BackendIntEnum.DEFAULT_CPU +def select_backend( + selector: BackendIntEnum, on_gpu: bool +) -> gtx_typing.Backend | model_backends.DeviceType: if selector == BackendIntEnum.DEFAULT_CPU: - selector = default_cpu - elif selector == BackendIntEnum.DEFAULT_GPU: - selector = default_gpu + if on_gpu: + raise ValueError( + f"Inconsistent backend selection: {selector.name} and on_gpu={on_gpu}." + ) + return model_backends.CPU + if selector == BackendIntEnum.DEFAULT_GPU: + if not on_gpu: + raise ValueError( + f"Inconsistent backend selection: {selector.name} and on_gpu={on_gpu}." + ) + assert isinstance(model_backends.GPU, model_backends.DeviceType) + return model_backends.GPU + if selector == BackendIntEnum.DEFAULT: + device_type = model_backends.GPU if on_gpu else model_backends.CPU + assert isinstance(device_type, model_backends.DeviceType) + return device_type if selector not in ( BackendIntEnum._GTFN_CPU, @@ -96,7 +108,7 @@ def select_backend(selector: BackendIntEnum, on_gpu: bool) -> gtx_typing.Backend def cached_dummy_field_factory( - allocator: gtx_typing.Backend, + allocator: gtx_allocators.FieldBufferAllocationUtil | None, ) -> Callable[[str, gtx.Domain, gt4py_definitions.DType], gtx.Field]: # curried to exclude non-hashable backend from cache @functools.lru_cache(maxsize=20) @@ -154,7 +166,7 @@ def construct_icon_grid( vertical_size: int, limited_area: bool, mean_cell_area: gtx.float64, # type:ignore[name-defined] # TODO(): fix type hint - backend: gtx_typing.Backend, + allocator: gtx_allocators.FieldBufferAllocationUtil | None, ) -> icon.IconGrid: log.debug("Constructing ICON Grid in Python...") log.debug("num_cells:%s", num_cells) @@ -164,7 +176,7 @@ def construct_icon_grid( log.debug("Offsetting Fortran connectivitity arrays by 1") - xp = data_alloc.import_array_ns(backend) + xp = data_alloc.import_array_ns(allocator) start_indices = { # TODO(halungge): ICON Fortran has 0 values in these arrays in some places possibly where they don't use them. # We should investigate where we access these values. @@ -228,7 +240,7 @@ def construct_icon_grid( return icon.icon_grid( id_=grid_id, - allocator=backend, + allocator=allocator, config=config, neighbor_tables=neighbor_tables, start_index=start_index, diff --git a/tools/src/icon4py/tools/py2fgen/wrappers/diffusion_wrapper.py b/tools/src/icon4py/tools/py2fgen/wrappers/diffusion_wrapper.py index c0209dd461..53dd2f9f2f 100644 --- a/tools/src/icon4py/tools/py2fgen/wrappers/diffusion_wrapper.py +++ b/tools/src/icon4py/tools/py2fgen/wrappers/diffusion_wrapper.py @@ -36,7 +36,7 @@ DiffusionInterpolationState, DiffusionMetricState, ) -from icon4py.model.common import dimension as dims, field_type_aliases as fa +from icon4py.model.common import dimension as dims, field_type_aliases as fa, model_backends from icon4py.model.common.grid.vertical import VerticalGrid, VerticalGridConfig from icon4py.model.common.states.prognostic_state import PrognosticState from icon4py.model.common.type_alias import wpfloat @@ -50,7 +50,7 @@ @dataclasses.dataclass class DiffusionGranule: diffusion: Diffusion - backend: gtx_typing.Backend + backend: gtx_typing.Backend | model_backends.DeviceType dummy_field_factory: Callable profiler: cProfile.Profile = dataclasses.field(default_factory=cProfile.Profile) @@ -117,10 +117,10 @@ def diffusion_init( actual_backend = wrapper_common.select_backend( wrapper_common.BackendIntEnum(backend), on_gpu=on_gpu ) - logger.info(f"{on_gpu=}") - logger.info( - f"Using Backend {wrapper_common.BackendIntEnum(backend).name} ({actual_backend.name})" + backend_name = ( + actual_backend.name if hasattr(actual_backend, "name") else actual_backend.__name__ ) + logger.info(f"Using Backend {backend_name} with on_gpu={on_gpu}") # Diffusion parameters config = DiffusionConfig( @@ -215,7 +215,9 @@ def diffusion_init( exchange=grid_wrapper.grid_state.exchange_runtime, ), backend=actual_backend, - dummy_field_factory=wrapper_common.cached_dummy_field_factory(actual_backend), + dummy_field_factory=wrapper_common.cached_dummy_field_factory( + model_backends.get_allocator(actual_backend) + ), ) diff --git a/tools/src/icon4py/tools/py2fgen/wrappers/dycore_wrapper.py b/tools/src/icon4py/tools/py2fgen/wrappers/dycore_wrapper.py index 8d9ae6927e..af0433ff71 100644 --- a/tools/src/icon4py/tools/py2fgen/wrappers/dycore_wrapper.py +++ b/tools/src/icon4py/tools/py2fgen/wrappers/dycore_wrapper.py @@ -29,7 +29,7 @@ from gt4py.next.type_system import type_specifications as ts from icon4py.model.atmosphere.dycore import dycore_states, solve_nonhydro -from icon4py.model.common import dimension as dims, utils as common_utils +from icon4py.model.common import dimension as dims, model_backends, utils as common_utils from icon4py.model.common.grid.vertical import VerticalGrid, VerticalGridConfig from icon4py.model.common.states.prognostic_state import PrognosticState from icon4py.tools import py2fgen @@ -43,7 +43,7 @@ @dataclasses.dataclass class SolveNonhydroGranule: solve_nh: solve_nonhydro.SolveNonhydro - backend: gtx_typing.Backend + backend: gtx_typing.Backend | model_backends.DeviceType dummy_field_factory: Callable profiler: cProfile.Profile = dataclasses.field(default_factory=cProfile.Profile) @@ -153,10 +153,10 @@ def solve_nh_init( actual_backend = wrapper_common.select_backend( wrapper_common.BackendIntEnum(backend), on_gpu=on_gpu ) - logger.info(f"{on_gpu=}") - logger.info( - f"Using Backend {wrapper_common.BackendIntEnum(backend).name} ({actual_backend.name})" + backend_name = ( + actual_backend.name if hasattr(actual_backend, "name") else actual_backend.__name__ ) + logger.info(f"Using Backend {backend_name} with on_gpu={on_gpu}") config = solve_nonhydro.NonHydrostaticConfig( itime_scheme=itime_scheme, @@ -269,7 +269,9 @@ def solve_nh_init( exchange=grid_wrapper.grid_state.exchange_runtime, ), backend=actual_backend, - dummy_field_factory=wrapper_common.cached_dummy_field_factory(actual_backend), + dummy_field_factory=wrapper_common.cached_dummy_field_factory( + model_backends.get_allocator(actual_backend) + ), ) diff --git a/tools/src/icon4py/tools/py2fgen/wrappers/grid_wrapper.py b/tools/src/icon4py/tools/py2fgen/wrappers/grid_wrapper.py index ffa8e8b7b4..312ddad55d 100644 --- a/tools/src/icon4py/tools/py2fgen/wrappers/grid_wrapper.py +++ b/tools/src/icon4py/tools/py2fgen/wrappers/grid_wrapper.py @@ -14,7 +14,7 @@ from gt4py.next.type_system import type_specifications as ts import icon4py.model.common.grid.states as grid_states -from icon4py.model.common import dimension as dims, field_type_aliases as fa +from icon4py.model.common import dimension as dims, field_type_aliases as fa, model_backends from icon4py.model.common.decomposition import definitions as decomposition_defs from icon4py.model.common.grid import icon as icon_grid from icon4py.model.common.type_alias import wpfloat @@ -114,6 +114,7 @@ def grid_init( actual_backend = wrapper_common.select_backend( wrapper_common.BackendIntEnum(backend), on_gpu=on_gpu ) + allocator = model_backends.get_allocator(actual_backend) grid = wrapper_common.construct_icon_grid( cell_starts=cell_starts, cell_ends=cell_ends, @@ -137,7 +138,7 @@ def grid_init( vertical_size=vertical_size, limited_area=limited_area, mean_cell_area=mean_cell_area, - backend=actual_backend, + allocator=allocator, ) # Edge geometry edge_params = grid_states.EdgeParams( From 055051daf9dc082bee9baadc3cf655b9ba5ce6b7 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 1 Oct 2025 21:55:34 +0200 Subject: [PATCH 03/23] improve typing --- .../model/common/utils/data_allocation.py | 20 +++++++++---------- .../model/common/utils/device_utils.py | 16 ++++++++++----- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/model/common/src/icon4py/model/common/utils/data_allocation.py b/model/common/src/icon4py/model/common/utils/data_allocation.py index fd3ac89d94..375c5b3471 100644 --- a/model/common/src/icon4py/model/common/utils/data_allocation.py +++ b/model/common/src/icon4py/model/common/utils/data_allocation.py @@ -70,7 +70,7 @@ def import_array_ns(allocator: gtx_allocators.FieldBufferAllocationUtil | None) def as_field( field: gtx.Field, - allocator: gtx_typing.Backend | None = None, + allocator: gtx_allocators.FieldBufferAllocationUtil | None = None, embedded_on_host: bool = False, ) -> gtx.Field: """Convenience function to transfer an existing Field to a given backend.""" @@ -79,13 +79,13 @@ def as_field( def random_field( - grid, - *dims, + grid: grid_base.Grid, + *dims: gtx.Dimension, low: float = -1.0, high: float = 1.0, dtype: npt.DTypeLike | None = None, extend: dict[gtx.Dimension, int] | None = None, - allocator=None, + allocator: gtx_allocators.FieldBufferAllocationUtil | None = None, ) -> gtx.Field: arr = np.random.default_rng().uniform( low=low, high=high, size=_shape(grid, *dims, extend=extend) @@ -96,11 +96,11 @@ def random_field( def random_sign( - grid, - *dims, + grid: grid_base.Grid, + *dims: gtx.Dimension, dtype: npt.DTypeLike | None = None, extend: dict[gtx.Dimension, int] | None = None, - allocator=None, + allocator: gtx_allocators.FieldBufferAllocationUtil | None = None, ) -> gtx.Field: """Generate a random field with values -1 or 1.""" arr = np.random.default_rng().choice([-1, 1], size=_shape(grid, *dims, extend=extend)) @@ -133,7 +133,7 @@ def zero_field( *dims: gtx.Dimension, dtype=ta.wpfloat, extend: dict[gtx.Dimension, int] | None = None, - allocator=None, + allocator: gtx_allocators.FieldBufferAllocationUtil | None = None, ) -> gtx.Field: field_domain = {dim: (0, stop) for dim, stop in zip(dims, _shape(grid, *dims, extend=extend))} return gtx.constructors.zeros(field_domain, dtype=dtype, allocator=allocator) @@ -144,7 +144,7 @@ def constant_field( value: float, *dims: gtx.Dimension, dtype=ta.wpfloat, - allocator=None, + allocator: gtx_allocators.FieldBufferAllocationUtil | None = None, ) -> gtx.Field: return gtx.as_field( dims, @@ -167,7 +167,7 @@ def index_field( dim: gtx.Dimension, extend: dict[gtx.Dimension, int] | None = None, dtype=gtx.int32, - allocator: gtx_typing.Backend | None = None, + allocator: gtx_allocators.FieldBufferAllocationUtil | None = None, ) -> gtx.Field: xp = import_array_ns(allocator) shapex = _shape(grid, dim, extend=extend)[0] diff --git a/model/common/src/icon4py/model/common/utils/device_utils.py b/model/common/src/icon4py/model/common/utils/device_utils.py index f131bdde0c..30b08083f5 100644 --- a/model/common/src/icon4py/model/common/utils/device_utils.py +++ b/model/common/src/icon4py/model/common/utils/device_utils.py @@ -8,7 +8,7 @@ import functools from collections.abc import Callable -from typing import Any +from typing import Any, ParamSpec, TypeVar import gt4py.next as gtx import gt4py.next.allocators as gtx_allocators @@ -21,9 +21,9 @@ cp = None -def is_cupy_device( - allocator: gtx_allocators.FieldBufferAllocationUtil | None, -) -> bool: +def is_cupy_device(allocator: gtx_allocators.FieldBufferAllocationUtil | None) -> bool: + if allocator is None: + return False return gtx_allocators.is_field_allocation_tool_for(allocator, gtx.CUPY_DEVICE_TYPE) @@ -37,7 +37,13 @@ def sync(backend: gtx_typing.Backend | None = None) -> None: cp.cuda.runtime.deviceSynchronize() -def synchronized_function(func: Callable[..., Any], *, backend: gtx_typing.Backend | None): +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +def synchronized_function( + func: Callable[_P, _R], *, backend: gtx_typing.Backend | None +) -> Callable[_P, _R]: """ Wraps a function and synchronizes after execution """ From 559106e5333ed867873fc8ff31891b9c9f3226f1 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 1 Oct 2025 22:02:28 +0200 Subject: [PATCH 04/23] fix diffusion --- .../src/icon4py/model/atmosphere/diffusion/diffusion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py b/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py index 7fb1fc98d8..b1574641cd 100644 --- a/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py +++ b/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py @@ -406,7 +406,7 @@ def __init__( self._determine_horizontal_domains() self.mo_intp_rbf_rbf_vec_interpol_vertex = setup_program( - backend=self._backend, + backend=backend, program=mo_intp_rbf_rbf_vec_interpol_vertex, constant_args={ "ptr_coeff_1": self._interpolation_state.rbf_coeff_1, @@ -911,7 +911,7 @@ def _do_diffusion_step( def orchestration_uid(self) -> str: """Unique id based on the runtime state of the Diffusion object. It is used for caching in DaCe Orchestration.""" members_to_disregard = [ - "_backend", + "_allocator", "_exchange", "_grid", "compile_time_connectivities", From 03e6a46448efca66f0b9e2e2d8a8fa301ad6d537 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 1 Oct 2025 22:13:50 +0200 Subject: [PATCH 05/23] cleanup and fix allocator/backend --- .../model/atmosphere/diffusion/diffusion.py | 6 ++---- .../model/atmosphere/dycore/solve_nonhydro.py | 19 ++++++++++--------- .../atmosphere/dycore/velocity_advection.py | 4 +--- .../py2fgen/wrappers/diffusion_wrapper.py | 3 --- .../tools/py2fgen/wrappers/dycore_wrapper.py | 3 --- 5 files changed, 13 insertions(+), 22 deletions(-) diff --git a/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py b/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py index b1574641cd..2f4fba2226 100644 --- a/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py +++ b/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py @@ -561,7 +561,7 @@ def __init__( offset_provider={"Koff": dims.KDim}, ) - self._allocate_local_fields() + self._allocate_local_fields(model_backends.get_allocator(backend)) self.init_diffusion_local_fields_for_regular_timestep( params.K4, @@ -596,9 +596,7 @@ def __init__( # but this requires some changes in gt4py domain inference. self.compile_time_connectivities = self._grid.connectivities - def _allocate_local_fields( - self, allocator: gtx_allocators.FieldBufferAllocationUtil | None = None - ): + def _allocate_local_fields(self, allocator: gtx_allocators.FieldBufferAllocationUtil | None): self.diff_multfac_vn = data_alloc.zero_field(self._grid, dims.KDim, allocator=allocator) self.diff_multfac_n2w = data_alloc.zero_field(self._grid, dims.KDim, allocator=allocator) self.smag_limit = data_alloc.zero_field(self._grid, dims.KDim, allocator=allocator) diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py index 2e1a80a0e4..cd38f57dfe 100644 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py @@ -13,6 +13,7 @@ import gt4py.next as gtx import gt4py.next.typing as gtx_typing +from gt4py.next import allocators as gtx_allocators import icon4py.model.atmosphere.dycore.solve_nonhydro_stencils as nhsolve_stencils import icon4py.model.common.grid.states as grid_states @@ -112,29 +113,29 @@ class IntermediateFields: def allocate( cls, grid: grid_def.Grid, - backend: gtx_typing.Backend | None = None, + allocator: gtx_allocators.FieldBufferAllocationUtil | None, ): return IntermediateFields( horizontal_pressure_gradient=data_alloc.zero_field( - grid, dims.EdgeDim, dims.KDim, allocator=backend + grid, dims.EdgeDim, dims.KDim, allocator=allocator ), rho_at_edges_on_model_levels=data_alloc.zero_field( - grid, dims.EdgeDim, dims.KDim, allocator=backend + grid, dims.EdgeDim, dims.KDim, allocator=allocator ), theta_v_at_edges_on_model_levels=data_alloc.zero_field( - grid, dims.EdgeDim, dims.KDim, allocator=backend + grid, dims.EdgeDim, dims.KDim, allocator=allocator ), horizontal_gradient_of_normal_wind_divergence=data_alloc.zero_field( - grid, dims.EdgeDim, dims.KDim, allocator=backend + grid, dims.EdgeDim, dims.KDim, allocator=allocator ), dwdz_at_cells_on_model_levels=data_alloc.zero_field( - grid, dims.CellDim, dims.KDim, allocator=backend + grid, dims.CellDim, dims.KDim, allocator=allocator ), horizontal_kinetic_energy_at_edges_on_model_levels=data_alloc.zero_field( - grid, dims.EdgeDim, dims.KDim, allocator=backend + grid, dims.EdgeDim, dims.KDim, allocator=allocator ), tangential_wind_on_half_levels=data_alloc.zero_field( - grid, dims.EdgeDim, dims.KDim, allocator=backend + grid, dims.EdgeDim, dims.KDim, allocator=allocator ), ) @@ -822,7 +823,7 @@ def __init__( self.p_test_run = True - def _allocate_local_fields(self, allocator): + def _allocate_local_fields(self, allocator: gtx_allocators.FieldBufferAllocationUtil | None): self.temporal_extrapolation_of_perturbed_exner = data_alloc.zero_field( self._grid, dims.CellDim, diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/velocity_advection.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/velocity_advection.py index a664c1034d..126f386d33 100644 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/velocity_advection.py +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/velocity_advection.py @@ -179,9 +179,7 @@ def __init__( offset_provider=self.grid.connectivities, ) - def _allocate_local_fields( - self, allocator: gtx_allocators.FieldBufferAllocationUtil | None = None - ): + def _allocate_local_fields(self, allocator: gtx_allocators.FieldBufferAllocationUtil | None): self._horizontal_advection_of_w_at_edges_on_half_levels = data_alloc.zero_field( self.grid, dims.EdgeDim, dims.KDim, allocator=allocator, dtype=ta.vpfloat ) diff --git a/tools/src/icon4py/tools/py2fgen/wrappers/diffusion_wrapper.py b/tools/src/icon4py/tools/py2fgen/wrappers/diffusion_wrapper.py index 53dd2f9f2f..0419be7e92 100644 --- a/tools/src/icon4py/tools/py2fgen/wrappers/diffusion_wrapper.py +++ b/tools/src/icon4py/tools/py2fgen/wrappers/diffusion_wrapper.py @@ -22,7 +22,6 @@ from collections.abc import Callable import gt4py.next as gtx -import gt4py.next.typing as gtx_typing import numpy as np from icon4py.model.atmosphere.diffusion.diffusion import ( @@ -50,7 +49,6 @@ @dataclasses.dataclass class DiffusionGranule: diffusion: Diffusion - backend: gtx_typing.Backend | model_backends.DeviceType dummy_field_factory: Callable profiler: cProfile.Profile = dataclasses.field(default_factory=cProfile.Profile) @@ -214,7 +212,6 @@ def diffusion_init( backend=actual_backend, exchange=grid_wrapper.grid_state.exchange_runtime, ), - backend=actual_backend, dummy_field_factory=wrapper_common.cached_dummy_field_factory( model_backends.get_allocator(actual_backend) ), diff --git a/tools/src/icon4py/tools/py2fgen/wrappers/dycore_wrapper.py b/tools/src/icon4py/tools/py2fgen/wrappers/dycore_wrapper.py index af0433ff71..cddfb2ea98 100644 --- a/tools/src/icon4py/tools/py2fgen/wrappers/dycore_wrapper.py +++ b/tools/src/icon4py/tools/py2fgen/wrappers/dycore_wrapper.py @@ -23,7 +23,6 @@ from typing import Annotated, TypeAlias import gt4py.next as gtx -import gt4py.next.typing as gtx_typing import numpy as np from gt4py.next import config as gtx_config, metrics as gtx_metrics from gt4py.next.type_system import type_specifications as ts @@ -43,7 +42,6 @@ @dataclasses.dataclass class SolveNonhydroGranule: solve_nh: solve_nonhydro.SolveNonhydro - backend: gtx_typing.Backend | model_backends.DeviceType dummy_field_factory: Callable profiler: cProfile.Profile = dataclasses.field(default_factory=cProfile.Profile) @@ -268,7 +266,6 @@ def solve_nh_init( backend=actual_backend, exchange=grid_wrapper.grid_state.exchange_runtime, ), - backend=actual_backend, dummy_field_factory=wrapper_common.cached_dummy_field_factory( model_backends.get_allocator(actual_backend) ), From 3b237a56af1b0936c2977ffe76388128d243a9ef Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 1 Oct 2025 22:39:17 +0200 Subject: [PATCH 06/23] dace default, gtfn for vertically implicit --- .../src/icon4py/model/common/model_options.py | 23 +++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/model/common/src/icon4py/model/common/model_options.py b/model/common/src/icon4py/model/common/model_options.py index dac4d116c1..09c09bba92 100644 --- a/model/common/src/icon4py/model/common/model_options.py +++ b/model/common/src/icon4py/model/common/model_options.py @@ -6,6 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause import functools +import logging import typing import gt4py.next as gtx @@ -14,18 +15,29 @@ from icon4py.model.common import model_backends +log = logging.getLogger(__name__) + + def dict_values_to_list(d: dict[str, typing.Any]) -> dict[str, list]: return {k: [v] for k, v in d.items()} +def get_options( + program_name: str, **backend_description: typing.Any +) -> model_backends.BackendDescriptor: + if program_name.startswith("vertically_implicit"): + backend_description["backend_factory"] = model_backends.make_custom_gtfn_backend + return backend_description + + def customize_backend( + program_name: str, backend: model_backends.DeviceType | model_backends.BackendDescriptor, ) -> gtx_typing.Backend: if isinstance(backend, model_backends.DeviceType): backend = {"device": backend} - # TODO(havogt): implement the lookup function as below - # options = get_options(program_name, arch, **backend) # noqa: ERA001 - backend_func = backend.get("backend_factory", model_backends.make_custom_gtfn_backend) + backend = get_options(program_name, **backend) + backend_func = backend.get("backend_factory", model_backends.make_custom_dace_backend) device = backend.get("device", model_backends.DeviceType.CPU) custom_backend = backend_func( device=device, @@ -65,7 +77,10 @@ def setup_program( offset_provider = {} if offset_provider is None else offset_provider if isinstance(backend, gtx.DeviceType) or model_backends.is_backend_descriptor(backend): - backend = customize_backend(backend) + backend = customize_backend(program.__name__, backend) + + backend_name = backend.name if backend is not None else "embedded" + log.info(f"Configured '{backend_name}' backend for {program.__name__}.") bound_static_args = {k: v for k, v in constant_args.items() if gtx.is_scalar_type(v)} static_args_program = program.with_backend(backend) From 6c6f9df73fcc5676f26b65f78c959c455d052c43 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 2 Oct 2025 07:57:19 +0200 Subject: [PATCH 07/23] from measurement --- .../src/icon4py/model/common/model_options.py | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/model/common/src/icon4py/model/common/model_options.py b/model/common/src/icon4py/model/common/model_options.py index 09c09bba92..cf1250d029 100644 --- a/model/common/src/icon4py/model/common/model_options.py +++ b/model/common/src/icon4py/model/common/model_options.py @@ -22,10 +22,34 @@ def dict_values_to_list(d: dict[str, typing.Any]) -> dict[str, list]: return {k: [v] for k, v in d.items()} +gtfn_programs = { + "mo_intp_rbf_rbf_vec_interpol_vertex", + "calculate_diagnostic_quantities_for_turbulence", + "apply_diffusion_to_vn", + "apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence", + "calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools", + "apply_diffusion_to_theta_and_exner", + "compute_advection_in_horizontal_momentum_equation", + "compute_rayleigh_damping_factor", + "compute_perturbed_quantities_and_interpolation", + "compute_hydrostatic_correction_term", + "vertically_implicit_solver_at_predictor_step", + "stencils_61_62", + "compute_dwdz_for_divergence_damping", + "calculate_divdamp_fields", + "compute_averaged_vn_and_fluxes_and_prepare_tracer_advection", + "vertically_implicit_solver_at_corrector_step", + "init_cell_kdim_field_with_zero_wp", + "update_mass_flux_weighted", + "compute_theta_and_exner", + "compute_exner_from_rhotheta", +} + + def get_options( program_name: str, **backend_description: typing.Any ) -> model_backends.BackendDescriptor: - if program_name.startswith("vertically_implicit"): + if program_name in gtfn_programs: backend_description["backend_factory"] = model_backends.make_custom_gtfn_backend return backend_description From 33fe48c2810df4d939c56cfe2f6f4376119528a0 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 3 Oct 2025 11:52:08 +0200 Subject: [PATCH 08/23] customize one --- model/common/src/icon4py/model/common/model_backends.py | 9 +++++++-- model/common/src/icon4py/model/common/model_options.py | 3 +++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/model/common/src/icon4py/model/common/model_backends.py b/model/common/src/icon4py/model/common/model_backends.py index c0365d48c9..fd1031df34 100644 --- a/model/common/src/icon4py/model/common/model_backends.py +++ b/model/common/src/icon4py/model/common/model_backends.py @@ -21,7 +21,6 @@ "embedded": None, "roundtrip": gtx.itir_python, "gtfn_cpu": gtx.gtfn_cpu, - "gtfn_gpu": gtx.gtfn_gpu, } # DeviceType should always be imported from here, as we might replace it by an ICON4Py internal implementation @@ -110,10 +109,16 @@ def make_custom_dace_backend(device: str, **options) -> gtx_typing.Backend: raise NotImplementedError("Depends on dace module, which is not installed.") -def make_custom_gtfn_backend(device: DeviceType, cached: bool = True, **_) -> gtx_typing.Backend: +def make_custom_gtfn_backend( + device: DeviceType, cached: bool = True, fuse_all_fieldops=False, **_ +) -> gtx_typing.Backend: on_gpu = device == GPU return GTFNBackendFactory( gpu=on_gpu, cached=cached, + fuse_all_fieldops=fuse_all_fieldops, otf_workflow__cached_translation=cached, ) + + +BACKENDS["gtfn_gpu"] = make_custom_gtfn_backend(device=GPU) diff --git a/model/common/src/icon4py/model/common/model_options.py b/model/common/src/icon4py/model/common/model_options.py index cf1250d029..636d5e786c 100644 --- a/model/common/src/icon4py/model/common/model_options.py +++ b/model/common/src/icon4py/model/common/model_options.py @@ -51,6 +51,9 @@ def get_options( ) -> model_backends.BackendDescriptor: if program_name in gtfn_programs: backend_description["backend_factory"] = model_backends.make_custom_gtfn_backend + if program_name == "compute_theta_rho_face_values_and_pressure_gradient_and_update_vn": + backend_description["backend_factory"] = model_backends.make_custom_gtfn_backend + backend_description["fuse_all_fieldops"] = True return backend_description From 4a3ad8f36cd3fbdd30fdfdbc3f9876004b450383 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 3 Oct 2025 12:55:34 +0200 Subject: [PATCH 09/23] fix forwarding --- .../src/icon4py/model/common/model_options.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/model/common/src/icon4py/model/common/model_options.py b/model/common/src/icon4py/model/common/model_options.py index 636d5e786c..cf2e20bb7d 100644 --- a/model/common/src/icon4py/model/common/model_options.py +++ b/model/common/src/icon4py/model/common/model_options.py @@ -64,11 +64,11 @@ def customize_backend( if isinstance(backend, model_backends.DeviceType): backend = {"device": backend} backend = get_options(program_name, **backend) - backend_func = backend.get("backend_factory", model_backends.make_custom_dace_backend) - device = backend.get("device", model_backends.DeviceType.CPU) - custom_backend = backend_func( - device=device, - ) + backend_factory = backend.pop("backend_factory", model_backends.make_custom_dace_backend) + backend["device"] = backend.get("device", model_backends.DeviceType.CPU) # set default device + custom_backend = backend_factory(**backend) + backend_name = custom_backend.name if custom_backend is not None else "embedded" + log.info(f"Using custom backend '{backend_name}' for '{program_name}' with options: {backend}.") return custom_backend @@ -105,9 +105,9 @@ def setup_program( if isinstance(backend, gtx.DeviceType) or model_backends.is_backend_descriptor(backend): backend = customize_backend(program.__name__, backend) - - backend_name = backend.name if backend is not None else "embedded" - log.info(f"Configured '{backend_name}' backend for {program.__name__}.") + else: + backend_name = backend.name if backend is not None else "embedded" + log.info(f"Using non-custom backend '{backend_name}' for '{program.__name__}'.") bound_static_args = {k: v for k, v in constant_args.items() if gtx.is_scalar_type(v)} static_args_program = program.with_backend(backend) From cde438cab6b49cf44a23ded79e410fbed4ad34be Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 8 Oct 2025 13:57:34 +0200 Subject: [PATCH 10/23] cleanup --- .../icon4py/model/common/model_backends.py | 5 +- .../src/icon4py/model/common/model_options.py | 82 +++++++++---------- .../model/testing/test_model_options.py | 16 ++-- .../icon4py/tools/py2fgen/wrappers/common.py | 59 ++++++------- 4 files changed, 76 insertions(+), 86 deletions(-) diff --git a/model/common/src/icon4py/model/common/model_backends.py b/model/common/src/icon4py/model/common/model_backends.py index fd1031df34..126a4379d6 100644 --- a/model/common/src/icon4py/model/common/model_backends.py +++ b/model/common/src/icon4py/model/common/model_backends.py @@ -109,14 +109,11 @@ def make_custom_dace_backend(device: str, **options) -> gtx_typing.Backend: raise NotImplementedError("Depends on dace module, which is not installed.") -def make_custom_gtfn_backend( - device: DeviceType, cached: bool = True, fuse_all_fieldops=False, **_ -) -> gtx_typing.Backend: +def make_custom_gtfn_backend(device: DeviceType, cached: bool = True, **_) -> gtx_typing.Backend: on_gpu = device == GPU return GTFNBackendFactory( gpu=on_gpu, cached=cached, - fuse_all_fieldops=fuse_all_fieldops, otf_workflow__cached_translation=cached, ) diff --git a/model/common/src/icon4py/model/common/model_options.py b/model/common/src/icon4py/model/common/model_options.py index cf2e20bb7d..029ccfef5c 100644 --- a/model/common/src/icon4py/model/common/model_options.py +++ b/model/common/src/icon4py/model/common/model_options.py @@ -7,7 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause import functools import logging -import typing +from collections.abc import Callable +from typing import Any import gt4py.next as gtx import gt4py.next.typing as gtx_typing @@ -18,57 +19,54 @@ log = logging.getLogger(__name__) -def dict_values_to_list(d: dict[str, typing.Any]) -> dict[str, list]: +def dict_values_to_list(d: dict[str, Any]) -> dict[str, list]: return {k: [v] for k, v in d.items()} -gtfn_programs = { - "mo_intp_rbf_rbf_vec_interpol_vertex", - "calculate_diagnostic_quantities_for_turbulence", - "apply_diffusion_to_vn", - "apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence", - "calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools", - "apply_diffusion_to_theta_and_exner", - "compute_advection_in_horizontal_momentum_equation", - "compute_rayleigh_damping_factor", - "compute_perturbed_quantities_and_interpolation", - "compute_hydrostatic_correction_term", - "vertically_implicit_solver_at_predictor_step", - "stencils_61_62", - "compute_dwdz_for_divergence_damping", - "calculate_divdamp_fields", - "compute_averaged_vn_and_fluxes_and_prepare_tracer_advection", - "vertically_implicit_solver_at_corrector_step", - "init_cell_kdim_field_with_zero_wp", - "update_mass_flux_weighted", - "compute_theta_and_exner", - "compute_exner_from_rhotheta", -} - - -def get_options( - program_name: str, **backend_description: typing.Any +def get_dace_options( + program_name: str, **backend_descriptor: Any ) -> model_backends.BackendDescriptor: - if program_name in gtfn_programs: - backend_description["backend_factory"] = model_backends.make_custom_gtfn_backend - if program_name == "compute_theta_rho_face_values_and_pressure_gradient_and_update_vn": - backend_description["backend_factory"] = model_backends.make_custom_gtfn_backend - backend_description["fuse_all_fieldops"] = True - return backend_description + return backend_descriptor + + +def get_gtfn_options( + program_name: str, **backend_descriptor: Any +) -> model_backends.BackendDescriptor: + return backend_descriptor + + +def get_options(program_name: str, **backend_descriptor: Any) -> model_backends.BackendDescriptor: + if "backend_factory" not in backend_descriptor: + # here we could set a backend_factory per program + backend_descriptor["backend_factory"] = model_backends.make_custom_dace_backend + if backend_descriptor["backend_factory"] == model_backends.make_custom_dace_backend: + backend_descriptor = get_dace_options(program_name, **backend_descriptor) + if backend_descriptor["backend_factory"] == model_backends.make_custom_gtfn_backend: + backend_descriptor = get_gtfn_options(program_name, **backend_descriptor) + + return backend_descriptor def customize_backend( program_name: str, backend: model_backends.DeviceType | model_backends.BackendDescriptor, ) -> gtx_typing.Backend: - if isinstance(backend, model_backends.DeviceType): - backend = {"device": backend} - backend = get_options(program_name, **backend) - backend_factory = backend.pop("backend_factory", model_backends.make_custom_dace_backend) - backend["device"] = backend.get("device", model_backends.DeviceType.CPU) # set default device - custom_backend = backend_factory(**backend) + backend_descriptor = ( + {"device": backend} if isinstance(backend, model_backends.DeviceType) else backend + ) + + backend_descriptor = get_options(program_name, **backend_descriptor) + backend_descriptor["device"] = backend_descriptor.get( + "device", model_backends.DeviceType.CPU + ) # set default device + backend_factory = backend_descriptor.pop( + "backend_factory", model_backends.make_custom_dace_backend + ) + custom_backend = backend_factory(**backend_descriptor) backend_name = custom_backend.name if custom_backend is not None else "embedded" - log.info(f"Using custom backend '{backend_name}' for '{program_name}' with options: {backend}.") + log.info( + f"Using custom backend '{backend_name}' for '{program_name}' with options: {backend_descriptor}." + ) return custom_backend @@ -83,7 +81,7 @@ def setup_program( horizontal_sizes: dict[str, gtx.int32] | None = None, vertical_sizes: dict[str, gtx.int32] | None = None, offset_provider: gtx_typing.OffsetProvider | None = None, -) -> typing.Callable[..., None]: +) -> Callable[..., None]: """ This function processes arguments to the GT4Py program. It - binds arguments that don't change during model run ('constant_args', 'horizontal_sizes', "vertical_sizes'); diff --git a/model/testing/src/icon4py/model/testing/test_model_options.py b/model/testing/src/icon4py/model/testing/test_model_options.py index 3e585fa671..33a3c268d8 100644 --- a/model/testing/src/icon4py/model/testing/test_model_options.py +++ b/model/testing/src/icon4py/model/testing/test_model_options.py @@ -38,7 +38,7 @@ def test_custom_backend_options(backend_factory: typing.Callable, expected_backe "backend_factory": backend_factory, "device": model_backends.CPU, } - backend = customize_backend(backend_options) + backend = customize_backend("foo", backend_options) backend_name = expected_backend + "_cpu" # TODO(havogt): test should be improved to work without string comparison assert repr(model_backends.BACKENDS[backend_name]) == repr(backend) @@ -46,8 +46,8 @@ def test_custom_backend_options(backend_factory: typing.Callable, expected_backe def test_custom_backend_device() -> None: device = model_backends.CPU - backend = customize_backend(device) - default_backend = "gtfn_cpu" + backend = customize_backend("foo", device) + default_backend = "dace_cpu" # TODO(havogt): test should be improved to work without string comparison assert repr(model_backends.BACKENDS[default_backend]) == repr(backend) @@ -55,10 +55,10 @@ def test_custom_backend_device() -> None: @pytest.mark.parametrize( "backend", [ - model_backends.BACKENDS["gtfn_cpu"], + model_backends.BACKENDS["dace_cpu"], model_backends.CPU, - {"backend_factory": model_backends.make_custom_gtfn_backend, "device": model_backends.CPU}, - {"backend_factory": model_backends.make_custom_gtfn_backend}, + {"backend_factory": model_backends.make_custom_dace_backend, "device": model_backends.CPU}, + {"backend_factory": model_backends.make_custom_dace_backend}, {"device": model_backends.CPU}, ], ) @@ -69,7 +69,7 @@ def test_setup_program_defaults( | None, ) -> None: partial_program = setup_program(backend=backend, program=program_return_field) - backend = model_backends.BACKENDS["gtfn_cpu"] + backend = model_backends.BACKENDS["dace_cpu"] expected_partial = functools.partial( program_return_field.with_backend(backend).compile( enable_jit=False, @@ -94,7 +94,7 @@ def test_setup_program_defaults( "dace_gpu", ), ({"backend_factory": model_backends.make_custom_dace_backend}, "dace_cpu"), - ({"device": model_backends.GPU}, "gtfn_gpu"), + ({"device": model_backends.GPU}, "dace_gpu"), ], ) def test_setup_program_specify_inputs( diff --git a/tools/src/icon4py/tools/py2fgen/wrappers/common.py b/tools/src/icon4py/tools/py2fgen/wrappers/common.py index 04c1f6c18f..41469973d9 100644 --- a/tools/src/icon4py/tools/py2fgen/wrappers/common.py +++ b/tools/src/icon4py/tools/py2fgen/wrappers/common.py @@ -49,8 +49,8 @@ class BackendIntEnum(eve.IntEnum): DEFAULT = 0 - DEFAULT_CPU = 1 - DEFAULT_GPU = 2 + DACE = 1 + GTFN = 2 _GTFN_CPU = 11 _GTFN_GPU = 12 _DACE_CPU = 21 @@ -70,41 +70,36 @@ class BackendIntEnum(eve.IntEnum): def select_backend( selector: BackendIntEnum, on_gpu: bool -) -> gtx_typing.Backend | model_backends.DeviceType: - if selector == BackendIntEnum.DEFAULT_CPU: - if on_gpu: - raise ValueError( - f"Inconsistent backend selection: {selector.name} and on_gpu={on_gpu}." - ) - return model_backends.CPU - if selector == BackendIntEnum.DEFAULT_GPU: - if not on_gpu: - raise ValueError( - f"Inconsistent backend selection: {selector.name} and on_gpu={on_gpu}." - ) - assert isinstance(model_backends.GPU, model_backends.DeviceType) - return model_backends.GPU - if selector == BackendIntEnum.DEFAULT: - device_type = model_backends.GPU if on_gpu else model_backends.CPU - assert isinstance(device_type, model_backends.DeviceType) - return device_type - - if selector not in ( +) -> gtx_typing.Backend | model_backends.BackendDescriptor: + if selector in ( BackendIntEnum._GTFN_CPU, BackendIntEnum._GTFN_GPU, BackendIntEnum._DACE_CPU, BackendIntEnum._DACE_GPU, ): - raise ValueError(f"Invalid backend selector: {selector.name}") - if on_gpu and selector in (BackendIntEnum._DACE_CPU, BackendIntEnum._GTFN_CPU): - raise ValueError(f"Inconsistent backend selection: {selector.name} and on_gpu=True") - if not on_gpu and selector in (BackendIntEnum._DACE_GPU, BackendIntEnum._GTFN_GPU): - raise ValueError(f"Inconsistent backend selection: {selector.name} and on_gpu=False") - - backend = _BACKEND_MAP.get(selector) - assert backend is not None - - return backend + # Concrete non-customizable backends. + # TODO(havogt): consider removing + if on_gpu and selector in (BackendIntEnum._DACE_CPU, BackendIntEnum._GTFN_CPU): + raise ValueError(f"Inconsistent backend selection: {selector.name} and on_gpu=True") + if not on_gpu and selector in (BackendIntEnum._DACE_GPU, BackendIntEnum._GTFN_GPU): + raise ValueError(f"Inconsistent backend selection: {selector.name} and on_gpu=False") + + backend = _BACKEND_MAP.get(selector) + assert backend is not None + return backend + + backend_descriptor: model_backends.BackendDescriptor = {} + backend_descriptor["device"] = model_backends.GPU if on_gpu else model_backends.CPU + if selector == BackendIntEnum.DEFAULT: + return backend_descriptor + if selector == BackendIntEnum.DACE: + backend_descriptor["backend_factory"] = model_backends.make_custom_dace_backend + return backend_descriptor + if selector == BackendIntEnum.GTFN: + backend_descriptor["backend_factory"] = model_backends.make_custom_gtfn_backend + return backend_descriptor + + raise ValueError(f"Invalid backend selector: {selector}") def cached_dummy_field_factory( From cede6532879ea3ef0b49a533c2f25fbc5e88afb3 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 9 Oct 2025 08:50:01 +0200 Subject: [PATCH 11/23] fix log message --- tools/src/icon4py/tools/py2fgen/wrappers/diffusion_wrapper.py | 4 +--- tools/src/icon4py/tools/py2fgen/wrappers/dycore_wrapper.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tools/src/icon4py/tools/py2fgen/wrappers/diffusion_wrapper.py b/tools/src/icon4py/tools/py2fgen/wrappers/diffusion_wrapper.py index 0419be7e92..b81f966c70 100644 --- a/tools/src/icon4py/tools/py2fgen/wrappers/diffusion_wrapper.py +++ b/tools/src/icon4py/tools/py2fgen/wrappers/diffusion_wrapper.py @@ -115,9 +115,7 @@ def diffusion_init( actual_backend = wrapper_common.select_backend( wrapper_common.BackendIntEnum(backend), on_gpu=on_gpu ) - backend_name = ( - actual_backend.name if hasattr(actual_backend, "name") else actual_backend.__name__ - ) + backend_name = actual_backend.name if hasattr(actual_backend, "name") else actual_backend logger.info(f"Using Backend {backend_name} with on_gpu={on_gpu}") # Diffusion parameters diff --git a/tools/src/icon4py/tools/py2fgen/wrappers/dycore_wrapper.py b/tools/src/icon4py/tools/py2fgen/wrappers/dycore_wrapper.py index cddfb2ea98..559fc40a5c 100644 --- a/tools/src/icon4py/tools/py2fgen/wrappers/dycore_wrapper.py +++ b/tools/src/icon4py/tools/py2fgen/wrappers/dycore_wrapper.py @@ -151,9 +151,7 @@ def solve_nh_init( actual_backend = wrapper_common.select_backend( wrapper_common.BackendIntEnum(backend), on_gpu=on_gpu ) - backend_name = ( - actual_backend.name if hasattr(actual_backend, "name") else actual_backend.__name__ - ) + backend_name = actual_backend.name if hasattr(actual_backend, "name") else actual_backend logger.info(f"Using Backend {backend_name} with on_gpu={on_gpu}") config = solve_nonhydro.NonHydrostaticConfig( From fe21f9104884f23bc61334acae17363515098318 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 9 Oct 2025 21:28:48 +0200 Subject: [PATCH 12/23] overlap experiment --- .../model/atmosphere/dycore/solve_nonhydro.py | 147 +++++++++++++++--- 1 file changed, 123 insertions(+), 24 deletions(-) diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py index 932466e146..bec76e6a06 100644 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py @@ -492,7 +492,7 @@ def __init__( offset_provider=self._grid.connectivities, ) - self._apply_divergence_damping_and_update_vn = setup_program( + self._apply_divergence_damping_and_update_vn_first_half = setup_program( backend=backend, program=compute_edge_diagnostics_for_dycore_and_update_vn.apply_divergence_damping_and_update_vn, constant_args={ @@ -517,6 +517,35 @@ def __init__( }, vertical_sizes={ "vertical_start": gtx.int32(0), + "vertical_end": gtx.int32(self._grid.num_levels // 2), + }, + offset_provider=self._grid.connectivities, + ) + self._apply_divergence_damping_and_update_vn_second_half = setup_program( + backend=backend, + program=compute_edge_diagnostics_for_dycore_and_update_vn.apply_divergence_damping_and_update_vn, + constant_args={ + "horizontal_mask_for_3d_divdamp": self._metric_state_nonhydro.horizontal_mask_for_3d_divdamp, + "scaling_factor_for_3d_divdamp": self._metric_state_nonhydro.scaling_factor_for_3d_divdamp, + "inv_dual_edge_length": self._edge_geometry.inverse_dual_edge_lengths, + "nudgecoeff_e": self._interpolation_state.nudgecoeff_e, + "geofac_grdiv": self._interpolation_state.geofac_grdiv, + "advection_explicit_weight_parameter": self._params.advection_explicit_weight_parameter, + "advection_implicit_weight_parameter": self._params.advection_implicit_weight_parameter, + "iau_wgt_dyn": self._config.iau_wgt_dyn, + "is_iau_active": self._config.is_iau_active, + "limited_area": self._grid.limited_area, + }, + variants={ + "apply_2nd_order_divergence_damping": [False, True], + "apply_4th_order_divergence_damping": [False, True], + }, + horizontal_sizes={ + "horizontal_start": gtx.int32(self._start_edge_nudging_level_2), + "horizontal_end": self._end_edge_local, + }, + vertical_sizes={ + "vertical_start": gtx.int32(self._grid.num_levels // 2), "vertical_end": gtx.int32(self._grid.num_levels), }, offset_provider=self._grid.connectivities, @@ -547,26 +576,51 @@ def __init__( offset_provider=self._grid.connectivities, ) - self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection = setup_program( - backend=backend, - program=compute_averaged_vn_and_fluxes_and_prepare_tracer_advection, - constant_args={ - "e_flx_avg": self._interpolation_state.e_flx_avg, - "ddqz_z_full_e": self._metric_state_nonhydro.ddqz_z_full_e, - }, - variants={ - "at_first_substep": [False, True], - "prepare_advection": [False, True], - }, - horizontal_sizes={ - "horizontal_start": gtx.int32(self._start_edge_lateral_boundary_level_5), - "horizontal_end": self._end_edge_halo_level_2, - }, - vertical_sizes={ - "vertical_start": gtx.int32(0), - "vertical_end": gtx.int32(self._grid.num_levels), - }, - offset_provider=self._grid.connectivities, + self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection_first_half = ( + setup_program( + backend=backend, + program=compute_averaged_vn_and_fluxes_and_prepare_tracer_advection, + constant_args={ + "e_flx_avg": self._interpolation_state.e_flx_avg, + "ddqz_z_full_e": self._metric_state_nonhydro.ddqz_z_full_e, + }, + variants={ + "at_first_substep": [False, True], + "prepare_advection": [False, True], + }, + horizontal_sizes={ + "horizontal_start": gtx.int32(self._start_edge_lateral_boundary_level_5), + "horizontal_end": self._end_edge_halo_level_2, + }, + vertical_sizes={ + "vertical_start": gtx.int32(0), + "vertical_end": gtx.int32(self._grid.num_levels // 2), + }, + offset_provider=self._grid.connectivities, + ) + ) + self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection_second_half = ( + setup_program( + backend=backend, + program=compute_averaged_vn_and_fluxes_and_prepare_tracer_advection, + constant_args={ + "e_flx_avg": self._interpolation_state.e_flx_avg, + "ddqz_z_full_e": self._metric_state_nonhydro.ddqz_z_full_e, + }, + variants={ + "at_first_substep": [False, True], + "prepare_advection": [False, True], + }, + horizontal_sizes={ + "horizontal_start": gtx.int32(self._start_edge_lateral_boundary_level_5), + "horizontal_end": self._end_edge_halo_level_2, + }, + vertical_sizes={ + "vertical_start": gtx.int32(self._grid.num_levels // 2), + "vertical_end": gtx.int32(self._grid.num_levels), + }, + offset_provider=self._grid.connectivities, + ) ) self._vertically_implicit_solver_at_predictor_step = setup_program( @@ -1336,7 +1390,33 @@ def run_corrector_step( ) ) - self._apply_divergence_damping_and_update_vn( + # HALO OVERLAP EXPERIMENT START + self._apply_divergence_damping_and_update_vn_first_half( + horizontal_gradient_of_normal_wind_divergence=z_fields.horizontal_gradient_of_normal_wind_divergence, + next_vn=prognostic_states.next.vn, + current_vn=prognostic_states.current.vn, + dwdz_at_cells_on_model_levels=z_fields.dwdz_at_cells_on_model_levels, + predictor_normal_wind_advective_tendency=diagnostic_state_nh.normal_wind_advective_tendency.predictor, + corrector_normal_wind_advective_tendency=diagnostic_state_nh.normal_wind_advective_tendency.corrector, + normal_wind_tendency_due_to_slow_physics_process=diagnostic_state_nh.normal_wind_tendency_due_to_slow_physics_process, + normal_wind_iau_increment=diagnostic_state_nh.normal_wind_iau_increment, + theta_v_at_edges_on_model_levels=z_fields.theta_v_at_edges_on_model_levels, + horizontal_pressure_gradient=z_fields.horizontal_pressure_gradient, + reduced_fourth_order_divdamp_coeff_at_nest_boundary=self.reduced_fourth_order_divdamp_coeff_at_nest_boundary, + fourth_order_divdamp_scaling_coeff=self.fourth_order_divdamp_scaling_coeff, + second_order_divdamp_scaling_coeff=second_order_divdamp_scaling_coeff, + dtime=dtime, + apply_2nd_order_divergence_damping=apply_2nd_order_divergence_damping, + apply_4th_order_divergence_damping=apply_4th_order_divergence_damping, + ) + + log.debug("exchanging prognostic field 'vn'") + # this exchange should wait for `_apply_divergence_damping_and_update_vn_first_half` + first_half_exchange = self._exchange.exchange( + dims.EdgeDim, (prognostic_states.next.vn[:, : self._grid.num_levels // 2]) + ) + + self._apply_divergence_damping_and_update_vn_second_half( horizontal_gradient_of_normal_wind_divergence=z_fields.horizontal_gradient_of_normal_wind_divergence, next_vn=prognostic_states.next.vn, current_vn=prognostic_states.current.vn, @@ -1356,9 +1436,27 @@ def run_corrector_step( ) log.debug("exchanging prognostic field 'vn'") - self._exchange.exchange_and_wait(dims.EdgeDim, (prognostic_states.next.vn)) + second_half_exchange = self._exchange.exchange( + dims.EdgeDim, (prognostic_states.next.vn[:, self._grid.num_levels // 2 :]) + ) + + first_half_exchange.wait() + self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection_first_half( + spatially_averaged_vn=self.z_vn_avg, + mass_flux_at_edges_on_model_levels=diagnostic_state_nh.mass_flux_at_edges_on_model_levels, + theta_v_flux_at_edges_on_model_levels=self.theta_v_flux_at_edges_on_model_levels, + substep_and_spatially_averaged_vn=prep_adv.vn_traj, + substep_averaged_mass_flux=prep_adv.mass_flx_me, + vn=prognostic_states.next.vn, + rho_at_edges_on_model_levels=z_fields.rho_at_edges_on_model_levels, + theta_v_at_edges_on_model_levels=z_fields.theta_v_at_edges_on_model_levels, + prepare_advection=lprep_adv, + at_first_substep=at_first_substep, + r_nsubsteps=r_nsubsteps, + ) - self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection( + second_half_exchange.wait() + self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection_second_half( spatially_averaged_vn=self.z_vn_avg, mass_flux_at_edges_on_model_levels=diagnostic_state_nh.mass_flux_at_edges_on_model_levels, theta_v_flux_at_edges_on_model_levels=self.theta_v_flux_at_edges_on_model_levels, @@ -1371,6 +1469,7 @@ def run_corrector_step( at_first_substep=at_first_substep, r_nsubsteps=r_nsubsteps, ) + # HALO OVERLAP EXPERIMENT END self._vertically_implicit_solver_at_corrector_step( next_w=prognostic_states.next.w, From 3880d1a3010fbc1b20224d821eadd93e88fae6da Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 9 Oct 2025 21:46:15 +0200 Subject: [PATCH 13/23] swap exchange<->wait --- .../icon4py/model/atmosphere/dycore/solve_nonhydro.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py index bec76e6a06..4f429211fc 100644 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py @@ -1410,8 +1410,10 @@ def run_corrector_step( apply_4th_order_divergence_damping=apply_4th_order_divergence_damping, ) - log.debug("exchanging prognostic field 'vn'") - # this exchange should wait for `_apply_divergence_damping_and_update_vn_first_half` + log.debug("exchanging prognostic field 'vn' first half") + # - this exchange should sync to `_apply_divergence_damping_and_update_vn_first_half` + # - the exchange should probably run fully asynchronously + # - to force MPI to make progress we could put a wait() in a Python future and resolve the future where we currently have the wait first_half_exchange = self._exchange.exchange( dims.EdgeDim, (prognostic_states.next.vn[:, : self._grid.num_levels // 2]) ) @@ -1435,12 +1437,13 @@ def run_corrector_step( apply_4th_order_divergence_damping=apply_4th_order_divergence_damping, ) - log.debug("exchanging prognostic field 'vn'") + log.debug("exchanging prognostic field 'vn' second half") + # TODO(havogt): this wait could be after the next exchange starts, but ghex doesn't like it: "earlier exchange operation was not finished" + first_half_exchange.wait() second_half_exchange = self._exchange.exchange( dims.EdgeDim, (prognostic_states.next.vn[:, self._grid.num_levels // 2 :]) ) - first_half_exchange.wait() self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection_first_half( spatially_averaged_vn=self.z_vn_avg, mass_flux_at_edges_on_model_levels=diagnostic_state_nh.mass_flux_at_edges_on_model_levels, From 43cc58bf76c2f0a5609b81d8ef2a9622af17cb02 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 9 Oct 2025 21:49:44 +0200 Subject: [PATCH 14/23] fix comment --- .../src/icon4py/model/atmosphere/dycore/solve_nonhydro.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py index 4f429211fc..1235436328 100644 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py @@ -7,6 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause # ruff: noqa: ERA001, B008 +import concurrent.futures import dataclasses import logging from typing import Final @@ -68,6 +69,8 @@ log = logging.getLogger(__name__) +_async_exchange_pool = concurrent.futures.ThreadPoolExecutor(max_workers=4) + @dataclasses.dataclass class IntermediateFields: @@ -1390,7 +1393,7 @@ def run_corrector_step( ) ) - # HALO OVERLAP EXPERIMENT START + # EXCHANGE OVERLAP EXPERIMENT START self._apply_divergence_damping_and_update_vn_first_half( horizontal_gradient_of_normal_wind_divergence=z_fields.horizontal_gradient_of_normal_wind_divergence, next_vn=prognostic_states.next.vn, @@ -1472,7 +1475,7 @@ def run_corrector_step( at_first_substep=at_first_substep, r_nsubsteps=r_nsubsteps, ) - # HALO OVERLAP EXPERIMENT END + # EXCHANGE OVERLAP EXPERIMENT END self._vertically_implicit_solver_at_corrector_step( next_w=prognostic_states.next.w, From 9554aad021ea07a3b991ceeaf49eddd722372265 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 9 Oct 2025 21:53:55 +0200 Subject: [PATCH 15/23] run exchange async --- .../model/atmosphere/dycore/solve_nonhydro.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py index 1235436328..a8bae54ecf 100644 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py @@ -1414,11 +1414,10 @@ def run_corrector_step( ) log.debug("exchanging prognostic field 'vn' first half") - # - this exchange should sync to `_apply_divergence_damping_and_update_vn_first_half` - # - the exchange should probably run fully asynchronously - # - to force MPI to make progress we could put a wait() in a Python future and resolve the future where we currently have the wait - first_half_exchange = self._exchange.exchange( - dims.EdgeDim, (prognostic_states.next.vn[:, : self._grid.num_levels // 2]) + first_half_exchange = _async_exchange_pool.submit( + self._exchange.exchange_and_wait, + dims.EdgeDim, + (prognostic_states.next.vn[:, : self._grid.num_levels // 2]), ) self._apply_divergence_damping_and_update_vn_second_half( @@ -1442,9 +1441,11 @@ def run_corrector_step( log.debug("exchanging prognostic field 'vn' second half") # TODO(havogt): this wait could be after the next exchange starts, but ghex doesn't like it: "earlier exchange operation was not finished" - first_half_exchange.wait() - second_half_exchange = self._exchange.exchange( - dims.EdgeDim, (prognostic_states.next.vn[:, self._grid.num_levels // 2 :]) + first_half_exchange.result() + second_half_exchange = _async_exchange_pool.submit( + self._exchange.exchange_and_wait, + dims.EdgeDim, + (prognostic_states.next.vn[:, self._grid.num_levels // 2 :]), ) self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection_first_half( @@ -1461,7 +1462,7 @@ def run_corrector_step( r_nsubsteps=r_nsubsteps, ) - second_half_exchange.wait() + second_half_exchange.result() self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection_second_half( spatially_averaged_vn=self.z_vn_avg, mass_flux_at_edges_on_model_levels=diagnostic_state_nh.mass_flux_at_edges_on_model_levels, From 8085549e14f8d9e0ee685b5e8517935f035de792 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 10 Nov 2025 10:14:48 +0100 Subject: [PATCH 16/23] Use experimental GHEX async scheduling --- .../model/common/decomposition/mpi_decomposition.py | 6 ++---- pyproject.toml | 2 +- uv.lock | 7 +++---- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py index 7fd04744bd..70df2f0501 100644 --- a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py +++ b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py @@ -237,10 +237,8 @@ def exchange(self, dim: gtx.Dimension, *fields: Sequence[gtx.Field]): ) for f in sliced_fields ] - if hasattr(fields[0].array_ns, "cuda"): - # TODO(havogt): this is a workaround as ghex does not know that it should synchronize - # the GPU before the exchange. This is necessary to ensure that all data is ready for the exchange. - fields[0].array_ns.cuda.runtime.deviceSynchronize() + # With https://github.com/ghex-org/GHEX/pull/186, ghex will schedule/sync work on the default stream, + # otherwise we need an explicit device synchronize here. handle = self._comm.exchange(applied_patterns) log.debug(f"exchange for {len(fields)} fields of dimension ='{dim.value}' initiated.") return MultiNodeResult(handle, applied_patterns) diff --git a/pyproject.toml b/pyproject.toml index a27e715e18..4dd9ea4f8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -317,7 +317,7 @@ url = "https://test.pypi.org/simple/" [tool.uv.sources] dace = {git = "https://github.com/GridTools/dace", tag = "__gt4py-next-integration_2025_11_05"} -# ghex = {git = "https://github.com/ghex-org/GHEX.git", branch = "master"} +ghex = {git = "https://github.com/msimberg/GHEX.git", branch = "async-mpi"} # gt4py = {git = "https://github.com/GridTools/gt4py", branch = "main"} # gt4py = {index = "test.pypi"} icon4py-atmosphere-advection = {workspace = true} diff --git a/uv.lock b/uv.lock index 3a31e4cf7c..9eeeff3fb2 100644 --- a/uv.lock +++ b/uv.lock @@ -1352,13 +1352,12 @@ wheels = [ [[package]] name = "ghex" -version = "0.4.0" -source = { registry = "https://pypi.org/simple" } +version = "0.4.1" +source = { git = "https://github.com/msimberg/GHEX.git?branch=async-mpi#6d896166994cedbcfc50da1873239a5edb212e3f" } dependencies = [ { name = "mpi4py" }, { name = "numpy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d4/4f/d6217b2afcecff78620c8d3df315b3a354820447ad48962889fe029a3b2c/ghex-0.4.0.tar.gz", hash = "sha256:65135fee88a0bea16bbcc6a48fda9065850db7af4340726c0ea804affed04890", size = 8309041, upload-time = "2024-12-18T14:40:05.407Z" } [[package]] name = "gitdb" @@ -1876,7 +1875,7 @@ requires-dist = [ { name = "cupy-cuda12x", marker = "extra == 'cuda12'", specifier = ">=13.0" }, { name = "dace", git = "https://github.com/GridTools/dace?tag=__gt4py-next-integration_2025_11_05" }, { name = "datashader", marker = "extra == 'io'", specifier = ">=0.16.1" }, - { name = "ghex", marker = "extra == 'distributed'", specifier = ">=0.3.0" }, + { name = "ghex", marker = "extra == 'distributed'", git = "https://github.com/msimberg/GHEX.git?branch=async-mpi" }, { name = "gt4py", specifier = "==1.1.0" }, { name = "gt4py", extras = ["cuda11"], marker = "extra == 'cuda11'" }, { name = "gt4py", extras = ["cuda12"], marker = "extra == 'cuda12'" }, From 621517cbe1f57e718f4d5c6f4cb51fb211736269 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 11 Nov 2025 20:52:13 +0100 Subject: [PATCH 17/23] cleanup --- .../model/atmosphere/dycore/solve_nonhydro.py | 41 +++++++++++-------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py index 991d60e30b..4cf6ba39fd 100644 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py @@ -7,14 +7,13 @@ # SPDX-License-Identifier: BSD-3-Clause # ruff: noqa: ERA001, B008 -import concurrent.futures import dataclasses import logging from typing import Final import gt4py.next as gtx import gt4py.next.typing as gtx_typing -from gt4py.next import allocators as gtx_allocators +from gt4py.next import allocators as gtx_allocators, common as gtx_common import icon4py.model.atmosphere.dycore.solve_nonhydro_stencils as nhsolve_stencils import icon4py.model.common.grid.states as grid_states @@ -69,8 +68,6 @@ log = logging.getLogger(__name__) -_async_exchange_pool = concurrent.futures.ThreadPoolExecutor(max_workers=4) - @dataclasses.dataclass class IntermediateFields: @@ -1411,11 +1408,18 @@ def run_corrector_step( ) log.debug("exchanging prognostic field 'vn' first half") - first_half_exchange = _async_exchange_pool.submit( - self._exchange.exchange_and_wait, - dims.EdgeDim, - (prognostic_states.next.vn[:, : self._grid.num_levels // 2]), + + first_half_vn = gtx_common._field( + prognostic_states.next.vn.ndarray[:, : self._grid.num_levels // 2], + domain=gtx_common.Domain( + prognostic_states.next.vn.domain.dims, + ( + prognostic_states.next.vn.domain.ranges[0], + self._grid.num_levels // 2, + ), + ), ) + first_half_exchange = self._exchange.exchange(dims.EdgeDim, first_half_vn) self._apply_divergence_damping_and_update_vn_second_half( horizontal_gradient_of_normal_wind_divergence=z_fields.horizontal_gradient_of_normal_wind_divergence, @@ -1437,14 +1441,19 @@ def run_corrector_step( ) log.debug("exchanging prognostic field 'vn' second half") - # TODO(havogt): this wait could be after the next exchange starts, but ghex doesn't like it: "earlier exchange operation was not finished" - first_half_exchange.result() - second_half_exchange = _async_exchange_pool.submit( - self._exchange.exchange_and_wait, - dims.EdgeDim, - (prognostic_states.next.vn[:, self._grid.num_levels // 2 :]), + # TODO(havogt): this wait could be after the next exchange starts, but we need to duplicate the ghex communication object + first_half_exchange.wait() + second_half_vn = gtx_common._field( + prognostic_states.next.vn.ndarray[:, : self._grid.num_levels // 2], + domain=gtx_common.Domain( + prognostic_states.next.vn.domain.dims, + ( + prognostic_states.next.vn.domain.ranges[0], + self._grid.num_levels - self._grid.num_levels // 2, + ), + ), ) - + second_half_exchange = self._exchange.exchange(dims.EdgeDim, second_half_vn) self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection_first_half( spatially_averaged_vn=self.z_vn_avg, mass_flux_at_edges_on_model_levels=diagnostic_state_nh.mass_flux_at_edges_on_model_levels, @@ -1459,7 +1468,7 @@ def run_corrector_step( r_nsubsteps=r_nsubsteps, ) - second_half_exchange.result() + second_half_exchange.wait() self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection_second_half( spatially_averaged_vn=self.z_vn_avg, mass_flux_at_edges_on_model_levels=diagnostic_state_nh.mass_flux_at_edges_on_model_levels, From 64ba54c7d2534155dad1aa66ab78a3e94d270bd8 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 11 Nov 2025 22:16:04 +0100 Subject: [PATCH 18/23] fix domain construction --- .../model/atmosphere/dycore/solve_nonhydro.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py index 4cf6ba39fd..c95b6e4cf1 100644 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py @@ -1412,10 +1412,10 @@ def run_corrector_step( first_half_vn = gtx_common._field( prognostic_states.next.vn.ndarray[:, : self._grid.num_levels // 2], domain=gtx_common.Domain( - prognostic_states.next.vn.domain.dims, - ( + dims=prognostic_states.next.vn.domain.dims, + ranges=( prognostic_states.next.vn.domain.ranges[0], - self._grid.num_levels // 2, + gtx_common.UnitRange(0, self._grid.num_levels // 2), ), ), ) @@ -1444,13 +1444,10 @@ def run_corrector_step( # TODO(havogt): this wait could be after the next exchange starts, but we need to duplicate the ghex communication object first_half_exchange.wait() second_half_vn = gtx_common._field( - prognostic_states.next.vn.ndarray[:, : self._grid.num_levels // 2], - domain=gtx_common.Domain( - prognostic_states.next.vn.domain.dims, - ( - prognostic_states.next.vn.domain.ranges[0], - self._grid.num_levels - self._grid.num_levels // 2, - ), + prognostic_states.next.vn.ndarray[:, self._grid.num_levels // 2 :], + ranges=( + prognostic_states.next.vn.domain.ranges[0], + gtx_common.UnitRange(self._grid.num_levels // 2, self._grid.num_levels), ), ) second_half_exchange = self._exchange.exchange(dims.EdgeDim, second_half_vn) From 8b6ba5046c19696cbf06c0616ba7c6801b541d27 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 11 Nov 2025 22:20:47 +0100 Subject: [PATCH 19/23] fix domain --- .../icon4py/model/atmosphere/dycore/solve_nonhydro.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py index c95b6e4cf1..e3a3792b34 100644 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py @@ -1445,9 +1445,12 @@ def run_corrector_step( first_half_exchange.wait() second_half_vn = gtx_common._field( prognostic_states.next.vn.ndarray[:, self._grid.num_levels // 2 :], - ranges=( - prognostic_states.next.vn.domain.ranges[0], - gtx_common.UnitRange(self._grid.num_levels // 2, self._grid.num_levels), + domain=gtx_common.Domain( + dims=prognostic_states.next.vn.domain.dims, + ranges=( + prognostic_states.next.vn.domain.ranges[0], + gtx_common.UnitRange(self._grid.num_levels // 2, self._grid.num_levels), + ), ), ) second_half_exchange = self._exchange.exchange(dims.EdgeDim, second_half_vn) From 6cc3f53d510df60890e2ca83fc92e546f13458a4 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 11 Nov 2025 23:00:54 +0100 Subject: [PATCH 20/23] cache half-fields --- .../model/atmosphere/dycore/solve_nonhydro.py | 55 ++++++++++++------- 1 file changed, 35 insertions(+), 20 deletions(-) diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py index e3a3792b34..1d15a24198 100644 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py @@ -873,6 +873,9 @@ def __init__( self.p_test_run = False + self._first_half_cache = {} + self._second_half_cache = {} + def _allocate_local_fields(self, allocator: gtx_allocators.FieldBufferAllocationUtil | None): self.temporal_extrapolation_of_perturbed_exner = data_alloc.zero_field( self._grid, @@ -1309,6 +1312,36 @@ def run_predictor_step( log.debug("exchanging prognostic field 'w'") self._exchange.exchange_and_wait(dims.CellDim, prognostic_states.next.w) + def get_first_half_vn(self, vn: gtx.Field): + try: + return self._first_half_cache[vn.__gt_buffer_info__.hash_key] + except KeyError: + self._first_half_cache[vn.__gt_buffer_info__.hash_key] = gtx_common._field( + vn.ndarray[:, : self._grid.num_levels // 2], + domain=gtx_common.Domain( + dims=vn.domain.dims, + ranges=( + vn.domain.ranges[0], + gtx_common.UnitRange(0, self._grid.num_levels // 2), + ), + ), + ) + + def get_second_half_vn(self, vn: gtx.Field): + try: + return self._second_half_cache[vn.__gt_buffer_info__.hash_key] + except KeyError: + self._second_half_cache[vn.__gt_buffer_info__.hash_key] = gtx_common._field( + vn.ndarray[:, self._grid.num_levels // 2 :], + domain=gtx_common.Domain( + dims=vn.domain.dims, + ranges=( + vn.domain.ranges[0], + gtx_common.UnitRange(self._grid.num_levels // 2, self._grid.num_levels), + ), + ), + ) + def run_corrector_step( self, diagnostic_state_nh: dycore_states.DiagnosticStateNonHydro, @@ -1409,16 +1442,7 @@ def run_corrector_step( log.debug("exchanging prognostic field 'vn' first half") - first_half_vn = gtx_common._field( - prognostic_states.next.vn.ndarray[:, : self._grid.num_levels // 2], - domain=gtx_common.Domain( - dims=prognostic_states.next.vn.domain.dims, - ranges=( - prognostic_states.next.vn.domain.ranges[0], - gtx_common.UnitRange(0, self._grid.num_levels // 2), - ), - ), - ) + first_half_vn = self.get_first_half_vn(prognostic_states.next.vn) first_half_exchange = self._exchange.exchange(dims.EdgeDim, first_half_vn) self._apply_divergence_damping_and_update_vn_second_half( @@ -1443,16 +1467,7 @@ def run_corrector_step( log.debug("exchanging prognostic field 'vn' second half") # TODO(havogt): this wait could be after the next exchange starts, but we need to duplicate the ghex communication object first_half_exchange.wait() - second_half_vn = gtx_common._field( - prognostic_states.next.vn.ndarray[:, self._grid.num_levels // 2 :], - domain=gtx_common.Domain( - dims=prognostic_states.next.vn.domain.dims, - ranges=( - prognostic_states.next.vn.domain.ranges[0], - gtx_common.UnitRange(self._grid.num_levels // 2, self._grid.num_levels), - ), - ), - ) + second_half_vn = self.get_second_half_vn(prognostic_states.next.vn) second_half_exchange = self._exchange.exchange(dims.EdgeDim, second_half_vn) self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection_first_half( spatially_averaged_vn=self.z_vn_avg, From b78106cd205468b1250428fe70029ecab9fdecb4 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 11 Nov 2025 23:09:43 +0100 Subject: [PATCH 21/23] try set_sync_marker --- .../src/icon4py/model/atmosphere/dycore/solve_nonhydro.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py index 1d15a24198..171b5c351d 100644 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py @@ -13,6 +13,7 @@ import gt4py.next as gtx import gt4py.next.typing as gtx_typing +import viztracer # type: ignore[import-not-found] from gt4py.next import allocators as gtx_allocators, common as gtx_common import icon4py.model.atmosphere.dycore.solve_nonhydro_stencils as nhsolve_stencils @@ -875,6 +876,7 @@ def __init__( self._first_half_cache = {} self._second_half_cache = {} + self._counter = 0 def _allocate_local_fields(self, allocator: gtx_allocators.FieldBufferAllocationUtil | None): self.temporal_extrapolation_of_perturbed_exner = data_alloc.zero_field( @@ -1421,6 +1423,9 @@ def run_corrector_step( ) # EXCHANGE OVERLAP EXPERIMENT START + if self._counter == 13: + viztracer.get_tracer().set_sync_marker() + self._counter += 1 self._apply_divergence_damping_and_update_vn_first_half( horizontal_gradient_of_normal_wind_divergence=z_fields.horizontal_gradient_of_normal_wind_divergence, next_vn=prognostic_states.next.vn, From ca23bb4c1976a53887c18d2fb60c4aad170d7957 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 11 Nov 2025 23:14:52 +0100 Subject: [PATCH 22/23] fix return --- .../src/icon4py/model/atmosphere/dycore/solve_nonhydro.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py index 171b5c351d..52a5d029fe 100644 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py @@ -1328,6 +1328,7 @@ def get_first_half_vn(self, vn: gtx.Field): ), ), ) + return self._first_half_cache[vn.__gt_buffer_info__.hash_key] def get_second_half_vn(self, vn: gtx.Field): try: @@ -1343,6 +1344,7 @@ def get_second_half_vn(self, vn: gtx.Field): ), ), ) + return self._second_half_cache[vn.__gt_buffer_info__.hash_key] def run_corrector_step( self, From 4356e7addbd787d83cebcba3f60d7fcf53130c9d Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 11 Nov 2025 23:21:02 +0100 Subject: [PATCH 23/23] Revert "try set_sync_marker" This reverts commit b78106cd205468b1250428fe70029ecab9fdecb4. --- .../src/icon4py/model/atmosphere/dycore/solve_nonhydro.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py index 52a5d029fe..b3c2a47355 100644 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py @@ -13,7 +13,6 @@ import gt4py.next as gtx import gt4py.next.typing as gtx_typing -import viztracer # type: ignore[import-not-found] from gt4py.next import allocators as gtx_allocators, common as gtx_common import icon4py.model.atmosphere.dycore.solve_nonhydro_stencils as nhsolve_stencils @@ -876,7 +875,6 @@ def __init__( self._first_half_cache = {} self._second_half_cache = {} - self._counter = 0 def _allocate_local_fields(self, allocator: gtx_allocators.FieldBufferAllocationUtil | None): self.temporal_extrapolation_of_perturbed_exner = data_alloc.zero_field( @@ -1425,9 +1423,6 @@ def run_corrector_step( ) # EXCHANGE OVERLAP EXPERIMENT START - if self._counter == 13: - viztracer.get_tracer().set_sync_marker() - self._counter += 1 self._apply_divergence_damping_and_update_vn_first_half( horizontal_gradient_of_normal_wind_divergence=z_fields.horizontal_gradient_of_normal_wind_divergence, next_vn=prognostic_states.next.vn,