From 26d685b2b7091f7201f70f8c52d83056eb44f33f Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 18 Dec 2025 09:26:34 +0100 Subject: [PATCH 01/28] Modified versions. --- pyproject.toml | 5 +++-- uv.lock | 26 +++++++++++++------------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 029a046517..618bbf31f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -354,8 +354,9 @@ name = "test.pypi" url = "https://test.pypi.org/simple/" [tool.uv.sources] -dace = {git = "https://github.com/GridTools/dace", tag = "__gt4py-next-integration_2025_11_26"} -ghex = {git = "https://github.com/msimberg/GHEX.git", branch = "async-mpi"} +dace = {git = "https://github.com/GridTools/dace", tag = "__gt4py-next-integration_2025_12_11"} +# ghex = {git = "https://github.com/msimberg/GHEX.git", branch = "async-mpi"} +ghex = {git = "https://github.com/philip-paul-mueller/GHEX/", branch = "phimuell__async-mpi-2"} # gt4py = {git = "https://github.com/GridTools/gt4py", branch = "main"} # gt4py = {index = "test.pypi"} icon4py-atmosphere-advection = {workspace = true} diff --git a/uv.lock b/uv.lock index 05f868c07f..7378a2e14a 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10" resolution-markers = [ "python_full_version < '3.11'", @@ -175,8 +175,8 @@ dependencies = [ { name = "pathspec" }, { name = "platformdirs" }, { name = "pytokens" }, - { name = "tomli", marker = "python_full_version < '3.11'" }, - { name = "typing-extensions", marker = "python_full_version < '3.11'" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-7-icon4py-cuda11' and extra == 'extra-7-icon4py-cuda12')" }, + { name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-7-icon4py-cuda11' and extra == 'extra-7-icon4py-cuda12')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/8c/ad/33adf4708633d047950ff2dfdea2e215d84ac50ef95aff14a614e4b6e9b2/black-25.11.0.tar.gz", hash = "sha256:9a323ac32f5dc75ce7470501b887250be5005a01602e931a15e45593f70f6e08", size = 655669, upload-time = "2025-11-10T01:53:50.558Z" } wheels = [ @@ -568,7 +568,7 @@ name = "click" version = "8.1.7" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-7-icon4py-cuda11' and extra == 'extra-7-icon4py-cuda12')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121, upload-time = "2023-08-17T17:29:11.868Z" } wheels = [ @@ -912,8 +912,8 @@ wheels = [ [[package]] name = "dace" -version = "2025.11.26" -source = { git = "https://github.com/GridTools/dace?tag=__gt4py-next-integration_2025_11_26#3eec6f6dae18ac90e2f967ab8098505bd972b92f" } +version = "2025.12.11" +source = { git = "https://github.com/GridTools/dace?tag=__gt4py-next-integration_2025_12_11#ab9eaef558b4961058c98930be6a4026597ea6c9" } dependencies = [ { name = "aenum" }, { name = "astunparse" }, @@ -923,7 +923,7 @@ dependencies = [ { name = "numpy" }, { name = "packaging" }, { name = "ply" }, - { name = "pyreadline", marker = "sys_platform == 'win32'" }, + { name = "pyreadline", marker = "sys_platform == 'win32' or (extra == 'extra-7-icon4py-cuda11' and extra == 'extra-7-icon4py-cuda12')" }, { name = "pyyaml" }, { name = "sympy" }, ] @@ -1359,7 +1359,7 @@ wheels = [ [[package]] name = "ghex" version = "0.4.1" -source = { git = "https://github.com/msimberg/GHEX.git?branch=async-mpi#6d896166994cedbcfc50da1873239a5edb212e3f" } +source = { git = "https://github.com/philip-paul-mueller/GHEX/?branch=phimuell__async-mpi-2#0f1e15dc81682baf5c1be9b2e082f85a9c7a136b" } dependencies = [ { name = "mpi4py" }, { name = "numpy" }, @@ -1882,9 +1882,9 @@ requires-dist = [ { name = "cftime", marker = "extra == 'io'", specifier = ">=1.6.3" }, { name = "cupy-cuda11x", marker = "extra == 'cuda11'", specifier = ">=13.0" }, { name = "cupy-cuda12x", marker = "extra == 'cuda12'", specifier = ">=13.0" }, - { name = "dace", git = "https://github.com/GridTools/dace?tag=__gt4py-next-integration_2025_11_26" }, + { name = "dace", git = "https://github.com/GridTools/dace?tag=__gt4py-next-integration_2025_12_11" }, { name = "datashader", marker = "extra == 'io'", specifier = ">=0.16.1" }, - { name = "ghex", marker = "extra == 'distributed'", git = "https://github.com/msimberg/GHEX.git?branch=async-mpi" }, + { name = "ghex", marker = "extra == 'distributed'", git = "https://github.com/philip-paul-mueller/GHEX/?branch=phimuell__async-mpi-2" }, { name = "gt4py", specifier = "==1.1.2" }, { name = "gt4py", extras = ["cuda11"], marker = "extra == 'cuda11'" }, { name = "gt4py", extras = ["cuda12"], marker = "extra == 'cuda12'" }, @@ -1951,7 +1951,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "devtools", specifier = ">=0.12" }, - { name = "gt4py", specifier = "==1.1.0" }, + { name = "gt4py", specifier = "==1.1.2" }, { name = "icon4py-atmosphere-diffusion", editable = "model/atmosphere/diffusion" }, { name = "icon4py-atmosphere-dycore", editable = "model/atmosphere/dycore" }, { name = "icon4py-common", editable = "model/common" }, @@ -3973,7 +3973,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "packaging" }, { name = "setuptools" }, - { name = "tomli", marker = "python_full_version < '3.11'" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-7-icon4py-cuda11' and extra == 'extra-7-icon4py-cuda12')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/4f/a4/00a9ac1b555294710d4a68d2ce8dfdf39d72aa4d769a7395d05218d88a42/setuptools_scm-8.1.0.tar.gz", hash = "sha256:42dea1b65771cba93b7a515d65a65d8246e560768a66b9106a592c8e7f26c8a7", size = 76465, upload-time = "2024-05-06T15:07:56.934Z" } wheels = [ @@ -4598,7 +4598,7 @@ version = "3.1.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "packaging" }, - { name = "tomli", marker = "python_full_version < '3.11'" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-7-icon4py-cuda11' and extra == 'extra-7-icon4py-cuda12')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/5c/9b/941647e9e3616b5da7bbc4601ed9920f44a886704100fa8151406c07c149/versioningit-3.1.2.tar.gz", hash = "sha256:4db83ed99f56b07d83940bee3445ca46ca120d13b6b304cdb5fb44e5aa4edec0", size = 213047, upload-time = "2024-07-20T12:41:07.927Z" } wheels = [ From 6518ce97d6b0a07a8bfd559f521176be672bb1df Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 18 Dec 2025 13:04:20 +0100 Subject: [PATCH 02/28] Made some addaptions towards the asynchronous exchange. --- .../model/common/decomposition/definitions.py | 131 ++++++++++++++---- .../common/decomposition/mpi_decomposition.py | 49 +++++-- 2 files changed, 143 insertions(+), 37 deletions(-) diff --git a/model/common/src/icon4py/model/common/decomposition/definitions.py b/model/common/src/icon4py/model/common/decomposition/definitions.py index c23f1ad111..b38782640a 100644 --- a/model/common/src/icon4py/model/common/decomposition/definitions.py +++ b/model/common/src/icon4py/model/common/decomposition/definitions.py @@ -34,6 +34,18 @@ class ProcessProperties(Protocol): comm_size: int +class NoStream: + """Used in `exchange_and_wait()`, `exchange()` to indicate that no stream + synchronization is requested. + + If an instance of this type is given as `stream` argument , then the normal + `exchange()` and `wait()` functions are using. Otherwise the + `schedule_exchange()` and `schedule_wait()` are used. + """ + + pass + + @dataclasses.dataclass(frozen=True, init=False) class SingleNodeProcessProperties(ProcessProperties): comm: None @@ -149,24 +161,81 @@ def global_index( class ExchangeResult(Protocol): - def wait(self) -> None: ... + def wait(self, stream: Any | type[NoStream] | None = NoStream) -> None: + """Perform a wait. + + The function will wait for the communication to have finished and then + perform the unpacking. The condition under which the function returns + depends on `stream` if it is `NoStream`, the default, the function will + wait until all unpacking has completed. Otherwise, the function will + return after all unpacking has been scheduled. Thus, all further work + submitted to `stream` will wait until the unpacking has finished. + """ + ... - def is_ready(self) -> bool: ... + def is_ready(self) -> bool: + """Check if communication has been finished.""" + ... @runtime_checkable class ExchangeRuntime(Protocol): @overload - def exchange(self, dim: gtx.Dimension, *fields: gtx.Field) -> ExchangeResult: ... + def exchange( + self, + dim: gtx.Dimension, + *buffers: data_alloc.NDArray, + stream: Any | type[NoStream] | None = NoStream, + ) -> ExchangeResult: ... @overload - def exchange(self, dim: gtx.Dimension, *buffers: data_alloc.NDArray) -> ExchangeResult: ... + def exchange( + self, dim: gtx.Dimension, *fields: gtx.Field, stream: Any | type[NoStream] | None = NoStream + ) -> ExchangeResult: + """Perform halo exchanges. + + The exact behaviour depends on if the optional argument `stream` is supplied. + If it is given and set to `NoStream` then packing will start immediately. + However, if it is given and a GPU stream, then packing will wait until all + work on `stream` has been done. + """ + ... @overload - def exchange_and_wait(self, dim: gtx.Dimension, *fields: gtx.Field) -> None: ... + def exchange_and_wait( + self, + dim: gtx.Dimension, + *buffers: data_alloc.NDArray, + stream: Any | type[NoStream] | None = NoStream, + ) -> None: ... @overload - def exchange_and_wait(self, dim: gtx.Dimension, *buffers: data_alloc.NDArray) -> None: ... + def exchange_and_wait( + self, dim: gtx.Dimension, *fields: gtx.Field, stream: Any | type[NoStream] | None = NoStream + ) -> None: + """Exchange and wait in one go.""" + ... + + def __call__( + self, + *args: Any, + dim: gtx.Dimension, + wait: bool = True, + stream: Any | type[NoStream] | None = NoStream, + ) -> None | ExchangeResult: + """Perform a halo exchange operation. + + Args: + args: The fields to be exchanged. + + Keyword Args: + dim: The dimension along which the exchange is performed. + wait: If True, the operation will block until the exchange is completed (default: True). + stream: Use the asynchronous exchange mode using stream `stream`. If `NoStream` is + passed, the default, then the normal exchange function are used. For more + information refer to the GHEX documentation. + """ + ... def get_size(self) -> int: ... @@ -179,12 +248,18 @@ def __str__(self) -> str: @dataclasses.dataclass class SingleNodeExchange: def exchange( - self, dim: gtx.Dimension, *fields: gtx.Field | data_alloc.NDArray + self, + dim: gtx.Dimension, + *fields: gtx.Field | data_alloc.NDArray, + stream: Any | type[NoStream] | None = NoStream, ) -> ExchangeResult: return SingleNodeResult() def exchange_and_wait( - self, dim: gtx.Dimension, *fields: gtx.Field | data_alloc.NDArray + self, + dim: gtx.Dimension, + *fields: gtx.Field | data_alloc.NDArray, + stream: Any | type[NoStream] | None = NoStream, ) -> None: return None @@ -194,20 +269,16 @@ def my_rank(self) -> int: def get_size(self) -> int: return 1 - def __call__(self, *args: Any, dim: gtx.Dimension, wait: bool = True) -> ExchangeResult | None: # type: ignore[return] # return statment in else condition - """Perform a halo exchange operation. - - Args: - args: The fields to be exchanged. - - Keyword Args: - dim: The dimension along which the exchange is performed. - wait: If True, the operation will block until the exchange is completed (default: True). - """ - - res = self.exchange(dim, *args) + def __call__( # type: ignore[return] # return statment in else condition + self, + *args: Any, + dim: gtx.Dimension, + wait: bool = True, + stream: Any | type[NoStream] | None = NoStream, + ) -> ExchangeResult | None: + res = self.exchange(dim, *args, stream=stream) if wait: - res.wait() + res.wait(stream=stream) else: return res @@ -234,8 +305,15 @@ def dace__sdfg_signature__(self) -> tuple[Sequence[str], Sequence[str]]: class HaloExchangeWaitRuntime(Protocol): """Protocol for halo exchange wait.""" - def __call__(self, communication_handle: ExchangeResult) -> None: - """Wait on the communication handle.""" + def __call__( + self, communication_handle: ExchangeResult, stream: Any | type[NoStream] | None = NoStream + ) -> None: + """Wait on the communication handle. + + If `stream` is given then perform a scheduled wait. This means that the when + the function returns the unpacking has not necessarily finished. However, + every work that is submitted to `stream` will wait for the unpacking. + """ ... def __sdfg__(self, *args: Any, **kwargs: dict[str, Any]) -> dace.sdfg.sdfg.SDFG: @@ -255,7 +333,10 @@ def __sdfg_signature__(self) -> tuple[Sequence[str], Sequence[str]]: class HaloExchangeWait: exchange_object: SingleNodeExchange # maintain the same interface with the MPI counterpart - def __call__(self, communication_handle: SingleNodeResult) -> None: + def __call__( + self, communication_handle: SingleNodeResult, stream: Any | type[NoStream] | None = NoStream + ) -> None: + # Stream is ignored. communication_handle.wait() # Implementation of DaCe SDFGConvertible interface @@ -288,7 +369,7 @@ def create_single_node_halo_exchange_wait(runtime: SingleNodeExchange) -> HaloEx class SingleNodeResult: - def wait(self) -> None: + def wait(self, stream: Any | type[NoStream] | None = NoStream) -> None: pass def is_ready(self) -> bool: diff --git a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py index cd18d6259d..b0f55da379 100644 --- a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py +++ b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py @@ -242,7 +242,10 @@ def _get_applied_pattern(self, dim: gtx.Dimension, f: gtx.Field | data_alloc.NDA return self._patterns[dim](self._make_field_descriptor(dim, f)) def exchange( - self, dim: gtx.Dimension, *fields: gtx.Field | data_alloc.NDArray + self, + dim: gtx.Dimension, + *fields: gtx.Field | data_alloc.NDArray, + stream: Any | type[definitions.NoStream] | None = definitions.NoStream, ) -> MultiNodeResult: """ Exchange method that slices the fields based on the dimension and then performs halo exchange. @@ -254,18 +257,30 @@ def exchange( applied_patterns = [self._get_applied_pattern(dim, f) for f in fields] # With https://github.com/ghex-org/GHEX/pull/186, ghex will schedule/sync work on the default stream, # otherwise we need an explicit device synchronize here. - handle = self._comm.exchange(applied_patterns) + if stream is definitions.NoStream: + handle = self._comm.exchange(*applied_patterns) + else: + handle = self._comm.schedule_exchange(*applied_patterns, stream=stream) log.debug(f"exchange for {len(fields)} fields of dimension ='{dim.value}' initiated.") return MultiNodeResult(handle, applied_patterns) def exchange_and_wait( - self, dim: gtx.Dimension, *fields: gtx.Field | data_alloc.NDArray + self, + dim: gtx.Dimension, + *fields: gtx.Field | data_alloc.NDArray, + stream: Any | type[definitions.NoStream] | None = definitions.NoStream, ) -> None: - res = self.exchange(dim, *fields) - res.wait() + res = self.exchange(dim, *fields, stream=stream) + res.wait(stream=stream) log.debug(f"exchange for {len(fields)} fields of dimension ='{dim.value}' done.") - def __call__(self, *args: Any, dim: gtx.Dimension, wait: bool = True) -> MultiNodeResult | None: # type: ignore[return] # return statment in else condition + def __call__( # type: ignore[return] # return statment in else condition + self, + *args: Any, + dim: gtx.Dimension, + wait: bool = True, + stream: Any | type[definitions.NoStream] | None = definitions.NoStream, + ) -> MultiNodeResult | None: """Perform a halo exchange operation. Args: @@ -274,13 +289,16 @@ def __call__(self, *args: Any, dim: gtx.Dimension, wait: bool = True) -> MultiNo Keyword Args: dim: The dimension along which the exchange is performed. wait: If True, the operation will block until the exchange is completed (default: True). + stream: Use the asynchronous exchange mode using stream `stream`. If `NoStream` is + passed, the default, then the normal exchange function are used. For more + information refer to the GHEX documentation. """ if dim is None: raise ValueError("Need to define a dimension.") - res = self.exchange(dim, *args) + res = self.exchange(dim, *args, stream=stream) if wait: - res.wait() + res.wait(stream=stream) else: return res @@ -334,9 +352,13 @@ class HaloExchangeWait: buffer_name: ClassVar[str] = "communication_handle" # DaCe-related - def __call__(self, communication_handle: MultiNodeResult) -> None: + def __call__( + self, + communication_handle: MultiNodeResult, + stream: Any | type[definitions.NoStream] | None = definitions.NoStream, + ) -> None: """Wait on the communication handle.""" - communication_handle.wait() + communication_handle.wait(stream=stream) # Implementation of DaCe SDFGConvertible interface def dace__sdfg__( @@ -407,8 +429,11 @@ class MultiNodeResult: handle: Any pattern_refs: Any - def wait(self) -> None: - self.handle.wait() + def wait(self, stream: Any | type[definitions.NoStream] | None = definitions.NoStream) -> None: + if stream is definitions.NoStream: + self.handle.wait() + else: + self.handle.schedule_wait(stream=stream) del self.pattern_refs def is_ready(self) -> bool: From ec7fca2d79d31d49b34fbd3c029cdfec29622cfc Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 18 Dec 2025 13:33:46 +0100 Subject: [PATCH 03/28] More uniformity. --- .../src/icon4py/model/common/decomposition/definitions.py | 3 +-- .../icon4py/model/common/decomposition/mpi_decomposition.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/model/common/src/icon4py/model/common/decomposition/definitions.py b/model/common/src/icon4py/model/common/decomposition/definitions.py index b38782640a..98e8107816 100644 --- a/model/common/src/icon4py/model/common/decomposition/definitions.py +++ b/model/common/src/icon4py/model/common/decomposition/definitions.py @@ -336,8 +336,7 @@ class HaloExchangeWait: def __call__( self, communication_handle: SingleNodeResult, stream: Any | type[NoStream] | None = NoStream ) -> None: - # Stream is ignored. - communication_handle.wait() + communication_handle.wait(stream=stream) # Implementation of DaCe SDFGConvertible interface def dace__sdfg__( diff --git a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py index b0f55da379..ecfe9f4eaf 100644 --- a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py +++ b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py @@ -257,6 +257,7 @@ def exchange( applied_patterns = [self._get_applied_pattern(dim, f) for f in fields] # With https://github.com/ghex-org/GHEX/pull/186, ghex will schedule/sync work on the default stream, # otherwise we need an explicit device synchronize here. + # TODO(phimuell): Switch to passing the list, when bindings have catch up. if stream is definitions.NoStream: handle = self._comm.exchange(*applied_patterns) else: From 1f5e9e66117efab976a30951f743181cadba8dbd Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 18 Dec 2025 15:40:16 +0100 Subject: [PATCH 04/28] Updated ghex version. --- uv.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uv.lock b/uv.lock index 7378a2e14a..c1cbb39d90 100644 --- a/uv.lock +++ b/uv.lock @@ -1359,7 +1359,7 @@ wheels = [ [[package]] name = "ghex" version = "0.4.1" -source = { git = "https://github.com/philip-paul-mueller/GHEX/?branch=phimuell__async-mpi-2#0f1e15dc81682baf5c1be9b2e082f85a9c7a136b" } +source = { git = "https://github.com/philip-paul-mueller/GHEX/?branch=phimuell__async-mpi-2#7d47080c299707f5312fe7525646905212cf97a8" } dependencies = [ { name = "mpi4py" }, { name = "numpy" }, From e69cb82fc0482e8894b5be06266fcee36983ffbb Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 18 Dec 2025 15:44:46 +0100 Subject: [PATCH 05/28] Fixed at least that issue. --- .../icon4py/model/common/decomposition/mpi_decomposition.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py index ecfe9f4eaf..5b8e6b55f0 100644 --- a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py +++ b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py @@ -257,11 +257,10 @@ def exchange( applied_patterns = [self._get_applied_pattern(dim, f) for f in fields] # With https://github.com/ghex-org/GHEX/pull/186, ghex will schedule/sync work on the default stream, # otherwise we need an explicit device synchronize here. - # TODO(phimuell): Switch to passing the list, when bindings have catch up. if stream is definitions.NoStream: - handle = self._comm.exchange(*applied_patterns) + handle = self._comm.exchange(applied_patterns) else: - handle = self._comm.schedule_exchange(*applied_patterns, stream=stream) + handle = self._comm.schedule_exchange(applied_patterns, stream=stream) log.debug(f"exchange for {len(fields)} fields of dimension ='{dim.value}' initiated.") return MultiNodeResult(handle, applied_patterns) From f60a1f87873a6212df9cfa789972d9b7be571366 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 18 Dec 2025 16:19:47 +0100 Subject: [PATCH 06/28] Made the components aware of async stuff. **NOTE:** This commit still follows the old nomoclature, where `None` means default stream. Most likely this will change such that `None` means "not using `schedule_*()` functions and another sigelton is used for it. --- .../model/atmosphere/advection/advection.py | 4 ++-- .../model/atmosphere/diffusion/diffusion.py | 18 ++++++++++++++---- .../model/atmosphere/dycore/solve_nonhydro.py | 15 +++++++++++---- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection.py b/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection.py index 3e03c2201a..02b6b02799 100644 --- a/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection.py +++ b/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection.py @@ -271,7 +271,7 @@ def run( log.debug("advection run - start") log.debug("communication of prep_adv cell field: mass_flx_ic - start") - self._exchange.exchange_and_wait(dims.CellDim, prep_adv.mass_flx_ic) + self._exchange.exchange_and_wait(dims.CellDim, prep_adv.mass_flx_ic, stream=None) log.debug("communication of prep_adv cell field: mass_flx_ic - end") # reintegrate density for conservation of mass @@ -364,7 +364,7 @@ def run( # exchange updated tracer values, originally happens only if iforcing /= inwp log.debug("communication of advection cell field: p_tracer_new - start") - self._exchange.exchange_and_wait(dims.CellDim, p_tracer_new) + self._exchange.exchange_and_wait(dims.CellDim, p_tracer_new, stream=None) log.debug("communication of advection cell field: p_tracer_new - end") # finalize step diff --git a/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py b/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py index 4e87d3de17..9258784a44 100644 --- a/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py +++ b/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py @@ -386,7 +386,8 @@ def __init__( self._cell_params = cell_params self.halo_exchange_wait = decomposition.create_halo_exchange_wait( - self._exchange + self._exchange, + stream=None, ) # wait on a communication handle self.rd_o_cvd: float = constants.GAS_CONSTANT_DRY_AIR / ( constants.CPD - constants.GAS_CONSTANT_DRY_AIR @@ -736,6 +737,7 @@ def _sync_cell_fields(self, prognostic_state): prognostic_state.w, prognostic_state.theta_v, prognostic_state.exner, + stream=None, ) log.debug("communication of prognostic cell fields: theta, w, exner - done") @@ -772,12 +774,14 @@ def _do_diffusion_step( log.debug("rbf interpolation 1: end") # 2. HALO EXCHANGE -- CALL sync_patch_array_mult u_vert and v_vert + # TODO(phimuell, muellch): Is asynchronous mode okay here. log.debug("communication rbf extrapolation of vn - start") self._exchange( self.u_vert, self.v_vert, dim=dims.VertexDim, wait=True, + stream=None, ) log.debug("communication rbf extrapolation of vn - end") @@ -817,7 +821,8 @@ def _do_diffusion_step( # TODO(halungge): move this up and do asynchronous exchange if self.config.type_vn_diffu > 1: log.debug("communication rbf extrapolation of z_nable2_e - start") - self._exchange(self.z_nabla2_e, dim=dims.EdgeDim, wait=True) + # TODO(phimuell, muellch): Is asynchronous mode okay here. + self._exchange(self.z_nabla2_e, dim=dims.EdgeDim, wait=True, stream=None) log.debug("communication rbf extrapolation of z_nable2_e - end") log.debug("2nd rbf interpolation: start") @@ -827,12 +832,14 @@ def _do_diffusion_step( log.debug("2nd rbf interpolation: end") # 6. HALO EXCHANGE -- CALL sync_patch_array_mult (Vertex Fields) + # TODO(phimuell, muellch): Is asynchronous mode okay here. log.debug("communication rbf extrapolation of z_nable2_e - start") self._exchange( self.u_vert, self.v_vert, dim=dims.VertexDim, wait=True, + stream=None, ) log.debug("communication rbf extrapolation of z_nable2_e - end") @@ -848,7 +855,9 @@ def _do_diffusion_step( log.debug("running stencils 04 05 06 (apply_diffusion_to_vn): end") log.debug("communication of prognistic.vn : start") - handle_edge_comm = self._exchange(prognostic_state.vn, dim=dims.EdgeDim, wait=False) + handle_edge_comm = self._exchange( + prognostic_state.vn, dim=dims.EdgeDim, wait=False, stream=None + ) log.debug( "running stencils 07 08 09 10 (apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence): start" @@ -894,7 +903,8 @@ def _do_diffusion_step( log.debug("running stencil 13 to 16 apply_diffusion_to_theta_and_exner: end") self.halo_exchange_wait( - handle_edge_comm + handle_edge_comm, + stream=None, ) # need to do this here, since we currently only use 1 communication object. log.debug("communication of prognogistic.vn - end") diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py index 8b939ed6d6..4d85a736e5 100644 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py @@ -1190,7 +1190,10 @@ def run_predictor_step( log.debug("exchanging prognostic field 'vn' and local field 'rho_at_edges_on_model_levels'") self._exchange.exchange_and_wait( - dims.EdgeDim, prognostic_states.next.vn, z_fields.rho_at_edges_on_model_levels + dims.EdgeDim, + prognostic_states.next.vn, + z_fields.rho_at_edges_on_model_levels, + stream=None, ) self._compute_horizontal_velocity_quantities_and_fluxes( @@ -1262,11 +1265,14 @@ def run_predictor_step( "exchanging prognostic field 'w' and local field 'dwdz_at_cells_on_model_levels'" ) self._exchange.exchange_and_wait( - dims.CellDim, prognostic_states.next.w, z_fields.dwdz_at_cells_on_model_levels + dims.CellDim, + prognostic_states.next.w, + z_fields.dwdz_at_cells_on_model_levels, + stream=None, ) else: log.debug("exchanging prognostic field 'w'") - self._exchange.exchange_and_wait(dims.CellDim, prognostic_states.next.w) + self._exchange.exchange_and_wait(dims.CellDim, prognostic_states.next.w, stream=None) def run_corrector_step( self, @@ -1361,7 +1367,7 @@ def run_corrector_step( ) log.debug("exchanging prognostic field 'vn'") - self._exchange.exchange_and_wait(dims.EdgeDim, (prognostic_states.next.vn)) + self._exchange.exchange_and_wait(dims.EdgeDim, prognostic_states.next.vn, stream=None) self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection( spatially_averaged_vn=self.z_vn_avg, @@ -1433,4 +1439,5 @@ def run_corrector_step( prognostic_states.next.rho, prognostic_states.next.exner, prognostic_states.next.w, + stream=None, ) From e11da41c09cf5487baf0e0bdb7889014a4f830d4 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 19 Dec 2025 07:30:35 +0100 Subject: [PATCH 07/28] Fixed some stray `stream` argument. --- .../src/icon4py/model/atmosphere/diffusion/diffusion.py | 1 - 1 file changed, 1 deletion(-) diff --git a/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py b/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py index 9258784a44..46d7bf79a9 100644 --- a/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py +++ b/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py @@ -387,7 +387,6 @@ def __init__( self.halo_exchange_wait = decomposition.create_halo_exchange_wait( self._exchange, - stream=None, ) # wait on a communication handle self.rd_o_cvd: float = constants.GAS_CONSTANT_DRY_AIR / ( constants.CPD - constants.GAS_CONSTANT_DRY_AIR From 6636cea001fa67a975b640766d53e38eb48549e6 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 19 Dec 2025 07:36:55 +0100 Subject: [PATCH 08/28] Updated the annotations and the meaning of `stream`. - There are now two protocols that describes how to extract the underlying address. They are probably at the wrong location. - `stream=None` no longer means "default stream" but is not equivalent to "do not use scheduled version". - To indicate the default stream the singelton `DefaultStream` is used. The `cupy.cuda.Stream.null` singelton was not used, because it would require that `cupy` is present. - However, use the default stream is still the default behaviour. --- .../model/common/decomposition/definitions.py | 127 +++++++++++------- .../common/decomposition/mpi_decomposition.py | 47 +++---- 2 files changed, 105 insertions(+), 69 deletions(-) diff --git a/model/common/src/icon4py/model/common/decomposition/definitions.py b/model/common/src/icon4py/model/common/decomposition/definitions.py index 98e8107816..22693d4a9c 100644 --- a/model/common/src/icon4py/model/common/decomposition/definitions.py +++ b/model/common/src/icon4py/model/common/decomposition/definitions.py @@ -13,7 +13,7 @@ import logging from collections.abc import Sequence from enum import Enum -from typing import Any, Literal, Protocol, overload, runtime_checkable +from typing import Any, Literal, Protocol, TypeAlias, overload, runtime_checkable import dace # type: ignore[import-untyped] import gt4py.next as gtx @@ -26,24 +26,49 @@ log = logging.getLogger(__name__) +# TODO(reviewer): I am pretty sure that the protocols I added should go +# somewhere else, but I have no plan where. -class ProcessProperties(Protocol): - comm: Any - rank: int - comm_name: str - comm_size: int + +class DefaultStream: + """Used in `exchange_and_wait()`, `exchange()` to indicate that synchronization + with the default stream is requested. + """ -class NoStream: - """Used in `exchange_and_wait()`, `exchange()` to indicate that no stream - synchronization is requested. +class CupyLikeStream(Protocol): + """The type follows the CuPy convention of a stream. - If an instance of this type is given as `stream` argument , then the normal - `exchange()` and `wait()` functions are using. Otherwise the - `schedule_exchange()` and `schedule_wait()` are used. + This means they have an attribute `ptr` that returns the address of the + underlying GPU stream. + See: https://docs.cupy.dev/en/stable/reference/generated/cupy.cuda.Stream.html#cupy-cuda-stream """ - pass + @property + def ptr(self) -> int: ... + + +class CudaStreamProtocolLike(Protocol): + """The type follows the CUDA stream protocol. + + This means it provides a method called `__cuda_stream__()` returning a pair of + integers. The first is the protocol version and the second value is the + address of the stream. + See: https://nvidia.github.io/cuda-python/cuda-core/latest/interoperability.html#cuda-stream-protocol + """ + + def __cuda_stream__(self) -> tuple[int, int]: ... + + +#: Types that are supported as streams. +StreamLike: TypeAlias = type[DefaultStream] | CupyLikeStream | CudaStreamProtocolLike + + +class ProcessProperties(Protocol): + comm: Any + rank: int + comm_name: str + comm_size: int @dataclasses.dataclass(frozen=True, init=False) @@ -161,15 +186,18 @@ def global_index( class ExchangeResult(Protocol): - def wait(self, stream: Any | type[NoStream] | None = NoStream) -> None: - """Perform a wait. - - The function will wait for the communication to have finished and then - perform the unpacking. The condition under which the function returns - depends on `stream` if it is `NoStream`, the default, the function will - wait until all unpacking has completed. Otherwise, the function will - return after all unpacking has been scheduled. Thus, all further work - submitted to `stream` will wait until the unpacking has finished. + def wait( + self, + stream: StreamLike | None = DefaultStream, + ) -> None: + """Performs a wait. + + The function will wait until the communication has ended and then start the + unpacking of the data. If `stream` is `None` then the function will wait + until the unpacking has completed. If it is a CUDA stream or the + `DefaultStream` singleton, then the function will return after the unpacking + has been scheduled. It will add synchronizations, such that all work + that will be submitted to `stream` will wait until the unpacking has finished. """ ... @@ -185,19 +213,24 @@ def exchange( self, dim: gtx.Dimension, *buffers: data_alloc.NDArray, - stream: Any | type[NoStream] | None = NoStream, + stream: StreamLike | None = DefaultStream, ) -> ExchangeResult: ... @overload def exchange( - self, dim: gtx.Dimension, *fields: gtx.Field, stream: Any | type[NoStream] | None = NoStream + self, + dim: gtx.Dimension, + *fields: gtx.Field, + stream: StreamLike | None = DefaultStream, ) -> ExchangeResult: """Perform halo exchanges. - The exact behaviour depends on if the optional argument `stream` is supplied. - If it is given and set to `NoStream` then packing will start immediately. - However, if it is given and a GPU stream, then packing will wait until all - work on `stream` has been done. + The exact behaviour depends on the optional argument `stream` is supplied. + If it is a GPU stream, by default it is the default stream, then the + exchange will wait until the work previously submitted to the stream has + concluded. If it is `None` then the exchange will start immediately. + + It is important that `wait` is called on the returned handle. """ ... @@ -206,12 +239,15 @@ def exchange_and_wait( self, dim: gtx.Dimension, *buffers: data_alloc.NDArray, - stream: Any | type[NoStream] | None = NoStream, + stream: StreamLike | None = DefaultStream, ) -> None: ... @overload def exchange_and_wait( - self, dim: gtx.Dimension, *fields: gtx.Field, stream: Any | type[NoStream] | None = NoStream + self, + dim: gtx.Dimension, + *fields: gtx.Field, + stream: StreamLike | None = DefaultStream, ) -> None: """Exchange and wait in one go.""" ... @@ -221,7 +257,7 @@ def __call__( *args: Any, dim: gtx.Dimension, wait: bool = True, - stream: Any | type[NoStream] | None = NoStream, + stream: StreamLike | None = DefaultStream, ) -> None | ExchangeResult: """Perform a halo exchange operation. @@ -231,9 +267,7 @@ def __call__( Keyword Args: dim: The dimension along which the exchange is performed. wait: If True, the operation will block until the exchange is completed (default: True). - stream: Use the asynchronous exchange mode using stream `stream`. If `NoStream` is - passed, the default, then the normal exchange function are used. For more - information refer to the GHEX documentation. + stream: How stream synchronization works, see `self.exchange()` for more. """ ... @@ -251,7 +285,7 @@ def exchange( self, dim: gtx.Dimension, *fields: gtx.Field | data_alloc.NDArray, - stream: Any | type[NoStream] | None = NoStream, + stream: StreamLike | None = DefaultStream, ) -> ExchangeResult: return SingleNodeResult() @@ -259,7 +293,7 @@ def exchange_and_wait( self, dim: gtx.Dimension, *fields: gtx.Field | data_alloc.NDArray, - stream: Any | type[NoStream] | None = NoStream, + stream: StreamLike | None = DefaultStream, ) -> None: return None @@ -274,7 +308,7 @@ def __call__( # type: ignore[return] # return statment in else condition *args: Any, dim: gtx.Dimension, wait: bool = True, - stream: Any | type[NoStream] | None = NoStream, + stream: StreamLike | None = DefaultStream, ) -> ExchangeResult | None: res = self.exchange(dim, *args, stream=stream) if wait: @@ -284,6 +318,7 @@ def __call__( # type: ignore[return] # return statment in else condition # Implementation of DaCe SDFGConvertible interface # For more see [dace repo]/dace/frontend/python/common.py#[class SDFGConvertible] + # TODO(phimuell): Add the `stream` keyword as well. def dace__sdfg__( self, *args: Any, dim: gtx.Dimension, wait: bool = True ) -> dace.sdfg.sdfg.SDFG: @@ -306,14 +341,11 @@ class HaloExchangeWaitRuntime(Protocol): """Protocol for halo exchange wait.""" def __call__( - self, communication_handle: ExchangeResult, stream: Any | type[NoStream] | None = NoStream + self, + communication_handle: ExchangeResult, + stream: StreamLike | None = DefaultStream, ) -> None: - """Wait on the communication handle. - - If `stream` is given then perform a scheduled wait. This means that the when - the function returns the unpacking has not necessarily finished. However, - every work that is submitted to `stream` will wait for the unpacking. - """ + """Calls `wait()` on the provided communication handle, `stream` is forwarded.""" ... def __sdfg__(self, *args: Any, **kwargs: dict[str, Any]) -> dace.sdfg.sdfg.SDFG: @@ -334,11 +366,14 @@ class HaloExchangeWait: exchange_object: SingleNodeExchange # maintain the same interface with the MPI counterpart def __call__( - self, communication_handle: SingleNodeResult, stream: Any | type[NoStream] | None = NoStream + self, + communication_handle: SingleNodeResult, + stream: StreamLike | None = DefaultStream, ) -> None: communication_handle.wait(stream=stream) # Implementation of DaCe SDFGConvertible interface + # TODO(phimuell): Add `stream` argument. def dace__sdfg__( self, *args: Any, dim: gtx.Dimension, wait: bool = True ) -> dace.sdfg.sdfg.SDFG: @@ -368,7 +403,7 @@ def create_single_node_halo_exchange_wait(runtime: SingleNodeExchange) -> HaloEx class SingleNodeResult: - def wait(self, stream: Any | type[NoStream] | None = NoStream) -> None: + def wait(self, stream: StreamLike | None = DefaultStream) -> None: pass def is_ready(self) -> bool: diff --git a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py index 5b8e6b55f0..ed30a7ddbb 100644 --- a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py +++ b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py @@ -245,7 +245,7 @@ def exchange( self, dim: gtx.Dimension, *fields: gtx.Field | data_alloc.NDArray, - stream: Any | type[definitions.NoStream] | None = definitions.NoStream, + stream: definitions.StreamLike | None = definitions.DefaultStream, ) -> MultiNodeResult: """ Exchange method that slices the fields based on the dimension and then performs halo exchange. @@ -255,12 +255,16 @@ def exchange( ), f"first dimension must be one of ({dims.MAIN_HORIZONTAL_DIMENSIONS.values()})" applied_patterns = [self._get_applied_pattern(dim, f) for f in fields] - # With https://github.com/ghex-org/GHEX/pull/186, ghex will schedule/sync work on the default stream, - # otherwise we need an explicit device synchronize here. - if stream is definitions.NoStream: + + if stream is None: + # Normal exchange. handle = self._comm.exchange(applied_patterns) else: - handle = self._comm.schedule_exchange(applied_patterns, stream=stream) + # Stream given, perform a scheduled exchange.. + # NOTE: GHEX interprets `None` as default stream. + handle = self._comm.schedule_exchange( + applied_patterns, stream=(None if stream is definitions.DefaultStream else stream) + ) log.debug(f"exchange for {len(fields)} fields of dimension ='{dim.value}' initiated.") return MultiNodeResult(handle, applied_patterns) @@ -268,7 +272,7 @@ def exchange_and_wait( self, dim: gtx.Dimension, *fields: gtx.Field | data_alloc.NDArray, - stream: Any | type[definitions.NoStream] | None = definitions.NoStream, + stream: definitions.StreamLike | None = definitions.DefaultStream, ) -> None: res = self.exchange(dim, *fields, stream=stream) res.wait(stream=stream) @@ -279,20 +283,8 @@ def __call__( # type: ignore[return] # return statment in else condition *args: Any, dim: gtx.Dimension, wait: bool = True, - stream: Any | type[definitions.NoStream] | None = definitions.NoStream, + stream: definitions.StreamLike | None = definitions.DefaultStream, ) -> MultiNodeResult | None: - """Perform a halo exchange operation. - - Args: - args: The fields to be exchanged. - - Keyword Args: - dim: The dimension along which the exchange is performed. - wait: If True, the operation will block until the exchange is completed (default: True). - stream: Use the asynchronous exchange mode using stream `stream`. If `NoStream` is - passed, the default, then the normal exchange function are used. For more - information refer to the GHEX documentation. - """ if dim is None: raise ValueError("Need to define a dimension.") @@ -355,7 +347,7 @@ class HaloExchangeWait: def __call__( self, communication_handle: MultiNodeResult, - stream: Any | type[definitions.NoStream] | None = definitions.NoStream, + stream: definitions.StreamLike | None = definitions.DefaultStream, ) -> None: """Wait on the communication handle.""" communication_handle.wait(stream=stream) @@ -429,11 +421,20 @@ class MultiNodeResult: handle: Any pattern_refs: Any - def wait(self, stream: Any | type[definitions.NoStream] | None = definitions.NoStream) -> None: - if stream is definitions.NoStream: + def wait( + self, + stream: definitions.StreamLike | None = definitions.DefaultStream, + ) -> None: + if stream is None: + # No stream given, perform full blocking wait. self.handle.wait() else: - self.handle.schedule_wait(stream=stream) + # Stream given, perform a scheduled wait. + # NOTE: GHEX interprets `None` as default stream. + self.handle.schedule_wait( + stream=(None if stream is definitions.DefaultStream else stream) + ) + # TODO(reviewer, phimuell): Is it safe to delete that here, even in the scheduled mode? del self.pattern_refs def is_ready(self) -> bool: From 383f9590cdab8b03f1ec69fcbd249cc15119be5f Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 19 Dec 2025 08:10:27 +0100 Subject: [PATCH 09/28] Realized that the strams are disabled. --- .../model/atmosphere/advection/advection.py | 12 +++++++++-- .../model/atmosphere/diffusion/diffusion.py | 20 +++++++++++++------ .../model/atmosphere/dycore/solve_nonhydro.py | 18 ++++++++++++----- 3 files changed, 37 insertions(+), 13 deletions(-) diff --git a/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection.py b/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection.py index 02b6b02799..e3a90e7189 100644 --- a/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection.py +++ b/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection.py @@ -271,7 +271,11 @@ def run( log.debug("advection run - start") log.debug("communication of prep_adv cell field: mass_flx_ic - start") - self._exchange.exchange_and_wait(dims.CellDim, prep_adv.mass_flx_ic, stream=None) + self._exchange.exchange_and_wait( + dims.CellDim, + prep_adv.mass_flx_ic, + stream=decomposition.DefaultStream, + ) log.debug("communication of prep_adv cell field: mass_flx_ic - end") # reintegrate density for conservation of mass @@ -364,7 +368,11 @@ def run( # exchange updated tracer values, originally happens only if iforcing /= inwp log.debug("communication of advection cell field: p_tracer_new - start") - self._exchange.exchange_and_wait(dims.CellDim, p_tracer_new, stream=None) + self._exchange.exchange_and_wait( + dims.CellDim, + p_tracer_new, + stream=decomposition.DefaultStream, + ) log.debug("communication of advection cell field: p_tracer_new - end") # finalize step diff --git a/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py b/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py index 46d7bf79a9..4ec6367e0f 100644 --- a/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py +++ b/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py @@ -736,7 +736,7 @@ def _sync_cell_fields(self, prognostic_state): prognostic_state.w, prognostic_state.theta_v, prognostic_state.exner, - stream=None, + stream=decomposition.DefaultStream, ) log.debug("communication of prognostic cell fields: theta, w, exner - done") @@ -780,7 +780,7 @@ def _do_diffusion_step( self.v_vert, dim=dims.VertexDim, wait=True, - stream=None, + stream=decomposition.DefaultStream, ) log.debug("communication rbf extrapolation of vn - end") @@ -821,7 +821,12 @@ def _do_diffusion_step( if self.config.type_vn_diffu > 1: log.debug("communication rbf extrapolation of z_nable2_e - start") # TODO(phimuell, muellch): Is asynchronous mode okay here. - self._exchange(self.z_nabla2_e, dim=dims.EdgeDim, wait=True, stream=None) + self._exchange( + self.z_nabla2_e, + dim=dims.EdgeDim, + wait=True, + stream=decomposition.DefaultStream, + ) log.debug("communication rbf extrapolation of z_nable2_e - end") log.debug("2nd rbf interpolation: start") @@ -838,7 +843,7 @@ def _do_diffusion_step( self.v_vert, dim=dims.VertexDim, wait=True, - stream=None, + stream=decomposition.DefaultStream, ) log.debug("communication rbf extrapolation of z_nable2_e - end") @@ -855,7 +860,10 @@ def _do_diffusion_step( log.debug("communication of prognistic.vn : start") handle_edge_comm = self._exchange( - prognostic_state.vn, dim=dims.EdgeDim, wait=False, stream=None + prognostic_state.vn, + dim=dims.EdgeDim, + wait=False, + stream=decomposition.DefaultStream, ) log.debug( @@ -903,7 +911,7 @@ def _do_diffusion_step( self.halo_exchange_wait( handle_edge_comm, - stream=None, + stream=decomposition.DefaultStream, ) # need to do this here, since we currently only use 1 communication object. log.debug("communication of prognogistic.vn - end") diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py index 4d85a736e5..5ce64b908e 100644 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py @@ -1193,7 +1193,7 @@ def run_predictor_step( dims.EdgeDim, prognostic_states.next.vn, z_fields.rho_at_edges_on_model_levels, - stream=None, + stream=decomposition.DefaultStream, ) self._compute_horizontal_velocity_quantities_and_fluxes( @@ -1268,11 +1268,15 @@ def run_predictor_step( dims.CellDim, prognostic_states.next.w, z_fields.dwdz_at_cells_on_model_levels, - stream=None, + stream=decomposition.DefaultStream, ) else: log.debug("exchanging prognostic field 'w'") - self._exchange.exchange_and_wait(dims.CellDim, prognostic_states.next.w, stream=None) + self._exchange.exchange_and_wait( + dims.CellDim, + prognostic_states.next.w, + stream=decomposition.DefaultStream, + ) def run_corrector_step( self, @@ -1367,7 +1371,11 @@ def run_corrector_step( ) log.debug("exchanging prognostic field 'vn'") - self._exchange.exchange_and_wait(dims.EdgeDim, prognostic_states.next.vn, stream=None) + self._exchange.exchange_and_wait( + dims.EdgeDim, + prognostic_states.next.vn, + stream=decomposition.DefaultStream, + ) self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection( spatially_averaged_vn=self.z_vn_avg, @@ -1439,5 +1447,5 @@ def run_corrector_step( prognostic_states.next.rho, prognostic_states.next.exner, prognostic_states.next.w, - stream=None, + stream=decomposition.DefaultStream, ) From 41322f745a1c7656f47943635d45751a20e2c497 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 19 Dec 2025 08:49:31 +0100 Subject: [PATCH 10/28] Let's see if that help, but it is strange that it takes longer now. --- ci/dace.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/dace.yml b/ci/dace.yml index f876048fd1..5bfbb01278 100644 --- a/ci/dace.yml +++ b/ci/dace.yml @@ -39,7 +39,7 @@ test_model_stencils_aarch64: - when: on_success variables: NUM_PROCESSES: 8 - SLURM_TIMELIMIT: '00:45:00' + SLURM_TIMELIMIT: '01:30:00' # test_model_datatests_x86_64: # extends: [.test_model_datatests, .test_template_x86_64] test_model_datatests_aarch64: From ae6db3963b25fb31c68177f11c9bf61d935f18f4 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 19 Dec 2025 15:09:44 +0100 Subject: [PATCH 11/28] Updated ghex version. --- uv.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uv.lock b/uv.lock index c1cbb39d90..382bc2eef0 100644 --- a/uv.lock +++ b/uv.lock @@ -1359,7 +1359,7 @@ wheels = [ [[package]] name = "ghex" version = "0.4.1" -source = { git = "https://github.com/philip-paul-mueller/GHEX/?branch=phimuell__async-mpi-2#7d47080c299707f5312fe7525646905212cf97a8" } +source = { git = "https://github.com/philip-paul-mueller/GHEX/?branch=phimuell__async-mpi-2#5fec2d116c123d459fb002b713b21f8914a3f290" } dependencies = [ { name = "mpi4py" }, { name = "numpy" }, From 815fc460ce88dbaab885785ec445332283c33237 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 19 Dec 2025 15:54:43 +0100 Subject: [PATCH 12/28] This should fix the argument names, but I do not understant why it is not working. --- .../model/common/decomposition/mpi_decomposition.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py index ed30a7ddbb..02cab7e82e 100644 --- a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py +++ b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py @@ -262,8 +262,10 @@ def exchange( else: # Stream given, perform a scheduled exchange.. # NOTE: GHEX interprets `None` as default stream. + # TODO(phimuell): Fix named arguments in GHEX. handle = self._comm.schedule_exchange( - applied_patterns, stream=(None if stream is definitions.DefaultStream else stream) + None if stream is definitions.DefaultStream else stream, + applied_patterns, ) log.debug(f"exchange for {len(fields)} fields of dimension ='{dim.value}' initiated.") return MultiNodeResult(handle, applied_patterns) @@ -431,9 +433,8 @@ def wait( else: # Stream given, perform a scheduled wait. # NOTE: GHEX interprets `None` as default stream. - self.handle.schedule_wait( - stream=(None if stream is definitions.DefaultStream else stream) - ) + # TODO(phimuell): Fixing named arguments in GHEX. + self.handle.schedule_wait(None if stream is definitions.DefaultStream else stream) # TODO(reviewer, phimuell): Is it safe to delete that here, even in the scheduled mode? del self.pattern_refs From 296c7b1a3a858499c80c5eb384f1de20cf8c959f Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 19 Dec 2025 16:03:26 +0100 Subject: [PATCH 13/28] Renambled named arguments. --- .../model/common/decomposition/mpi_decomposition.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py index 02cab7e82e..792313ffcc 100644 --- a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py +++ b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py @@ -264,8 +264,8 @@ def exchange( # NOTE: GHEX interprets `None` as default stream. # TODO(phimuell): Fix named arguments in GHEX. handle = self._comm.schedule_exchange( - None if stream is definitions.DefaultStream else stream, - applied_patterns, + patterns=applied_patterns, + stream=(None if stream is definitions.DefaultStream else stream), ) log.debug(f"exchange for {len(fields)} fields of dimension ='{dim.value}' initiated.") return MultiNodeResult(handle, applied_patterns) @@ -434,7 +434,9 @@ def wait( # Stream given, perform a scheduled wait. # NOTE: GHEX interprets `None` as default stream. # TODO(phimuell): Fixing named arguments in GHEX. - self.handle.schedule_wait(None if stream is definitions.DefaultStream else stream) + self.handle.schedule_wait( + stream=(None if stream is definitions.DefaultStream else stream), + ) # TODO(reviewer, phimuell): Is it safe to delete that here, even in the scheduled mode? del self.pattern_refs From 310a5e5bd65f90467b6616af220b0143427486fb Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 19 Dec 2025 16:06:25 +0100 Subject: [PATCH 14/28] There it is not \(yet\) possible to use the name. --- .../src/icon4py/model/common/decomposition/mpi_decomposition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py index 792313ffcc..643540f672 100644 --- a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py +++ b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py @@ -435,7 +435,7 @@ def wait( # NOTE: GHEX interprets `None` as default stream. # TODO(phimuell): Fixing named arguments in GHEX. self.handle.schedule_wait( - stream=(None if stream is definitions.DefaultStream else stream), + None if stream is definitions.DefaultStream else stream, ) # TODO(reviewer, phimuell): Is it safe to delete that here, even in the scheduled mode? del self.pattern_refs From 9a8975ff8c419cb404d0c4cffa1b4383cc22ac83 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 23 Dec 2025 11:30:37 +0100 Subject: [PATCH 15/28] Made it mandatory to pass the stream. --- .../model/common/decomposition/definitions.py | 31 ++++++++++--------- .../common/decomposition/mpi_decomposition.py | 16 +++++----- .../icon4py/model/common/states/factory.py | 3 +- .../mpi_tests/test_mpi_decomposition.py | 2 +- 4 files changed, 28 insertions(+), 24 deletions(-) diff --git a/model/common/src/icon4py/model/common/decomposition/definitions.py b/model/common/src/icon4py/model/common/decomposition/definitions.py index 22693d4a9c..8a0297123f 100644 --- a/model/common/src/icon4py/model/common/decomposition/definitions.py +++ b/model/common/src/icon4py/model/common/decomposition/definitions.py @@ -188,7 +188,7 @@ def global_index( class ExchangeResult(Protocol): def wait( self, - stream: StreamLike | None = DefaultStream, + stream: StreamLike | None, ) -> None: """Performs a wait. @@ -213,7 +213,7 @@ def exchange( self, dim: gtx.Dimension, *buffers: data_alloc.NDArray, - stream: StreamLike | None = DefaultStream, + stream: StreamLike | None, ) -> ExchangeResult: ... @overload @@ -221,14 +221,15 @@ def exchange( self, dim: gtx.Dimension, *fields: gtx.Field, - stream: StreamLike | None = DefaultStream, + stream: StreamLike | None, ) -> ExchangeResult: """Perform halo exchanges. The exact behaviour depends on the optional argument `stream` is supplied. - If it is a GPU stream, by default it is the default stream, then the - exchange will wait until the work previously submitted to the stream has - concluded. If it is `None` then the exchange will start immediately. + If it is a GPU stream then the exchange will wait until the work previously + submitted to the stream has concluded. To select the default stream use the + special value `DefaultStream`. + If it is `None` then the exchange will start immediately. It is important that `wait` is called on the returned handle. """ @@ -239,7 +240,7 @@ def exchange_and_wait( self, dim: gtx.Dimension, *buffers: data_alloc.NDArray, - stream: StreamLike | None = DefaultStream, + stream: StreamLike | None, ) -> None: ... @overload @@ -247,7 +248,7 @@ def exchange_and_wait( self, dim: gtx.Dimension, *fields: gtx.Field, - stream: StreamLike | None = DefaultStream, + stream: StreamLike | None, ) -> None: """Exchange and wait in one go.""" ... @@ -257,7 +258,7 @@ def __call__( *args: Any, dim: gtx.Dimension, wait: bool = True, - stream: StreamLike | None = DefaultStream, + stream: StreamLike | None, ) -> None | ExchangeResult: """Perform a halo exchange operation. @@ -285,7 +286,7 @@ def exchange( self, dim: gtx.Dimension, *fields: gtx.Field | data_alloc.NDArray, - stream: StreamLike | None = DefaultStream, + stream: StreamLike | None, ) -> ExchangeResult: return SingleNodeResult() @@ -293,7 +294,7 @@ def exchange_and_wait( self, dim: gtx.Dimension, *fields: gtx.Field | data_alloc.NDArray, - stream: StreamLike | None = DefaultStream, + stream: StreamLike | None, ) -> None: return None @@ -308,7 +309,7 @@ def __call__( # type: ignore[return] # return statment in else condition *args: Any, dim: gtx.Dimension, wait: bool = True, - stream: StreamLike | None = DefaultStream, + stream: StreamLike | None, ) -> ExchangeResult | None: res = self.exchange(dim, *args, stream=stream) if wait: @@ -343,7 +344,7 @@ class HaloExchangeWaitRuntime(Protocol): def __call__( self, communication_handle: ExchangeResult, - stream: StreamLike | None = DefaultStream, + stream: StreamLike | None, ) -> None: """Calls `wait()` on the provided communication handle, `stream` is forwarded.""" ... @@ -368,7 +369,7 @@ class HaloExchangeWait: def __call__( self, communication_handle: SingleNodeResult, - stream: StreamLike | None = DefaultStream, + stream: StreamLike | None, ) -> None: communication_handle.wait(stream=stream) @@ -403,7 +404,7 @@ def create_single_node_halo_exchange_wait(runtime: SingleNodeExchange) -> HaloEx class SingleNodeResult: - def wait(self, stream: StreamLike | None = DefaultStream) -> None: + def wait(self, stream: StreamLike | None) -> None: pass def is_ready(self) -> bool: diff --git a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py index 643540f672..263e7ca912 100644 --- a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py +++ b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py @@ -245,7 +245,7 @@ def exchange( self, dim: gtx.Dimension, *fields: gtx.Field | data_alloc.NDArray, - stream: definitions.StreamLike | None = definitions.DefaultStream, + stream: definitions.StreamLike | None, ) -> MultiNodeResult: """ Exchange method that slices the fields based on the dimension and then performs halo exchange. @@ -261,7 +261,8 @@ def exchange( handle = self._comm.exchange(applied_patterns) else: # Stream given, perform a scheduled exchange.. - # NOTE: GHEX interprets `None` as default stream. + # NOTE: GHEX interprets `None` as default stream. Furthermore, if no + # GPU is present, passing `None` is mandatory. # TODO(phimuell): Fix named arguments in GHEX. handle = self._comm.schedule_exchange( patterns=applied_patterns, @@ -274,7 +275,7 @@ def exchange_and_wait( self, dim: gtx.Dimension, *fields: gtx.Field | data_alloc.NDArray, - stream: definitions.StreamLike | None = definitions.DefaultStream, + stream: definitions.StreamLike | None, ) -> None: res = self.exchange(dim, *fields, stream=stream) res.wait(stream=stream) @@ -285,7 +286,7 @@ def __call__( # type: ignore[return] # return statment in else condition *args: Any, dim: gtx.Dimension, wait: bool = True, - stream: definitions.StreamLike | None = definitions.DefaultStream, + stream: definitions.StreamLike | None, ) -> MultiNodeResult | None: if dim is None: raise ValueError("Need to define a dimension.") @@ -349,7 +350,7 @@ class HaloExchangeWait: def __call__( self, communication_handle: MultiNodeResult, - stream: definitions.StreamLike | None = definitions.DefaultStream, + stream: definitions.StreamLike | None, ) -> None: """Wait on the communication handle.""" communication_handle.wait(stream=stream) @@ -425,14 +426,15 @@ class MultiNodeResult: def wait( self, - stream: definitions.StreamLike | None = definitions.DefaultStream, + stream: definitions.StreamLike | None, ) -> None: if stream is None: # No stream given, perform full blocking wait. self.handle.wait() else: # Stream given, perform a scheduled wait. - # NOTE: GHEX interprets `None` as default stream. + # NOTE: GHEX interprets `None` as default stream. Furthermore, if no + # GPU is present, passing `None` is mandatory. # TODO(phimuell): Fixing named arguments in GHEX. self.handle.schedule_wait( None if stream is definitions.DefaultStream else stream, diff --git a/model/common/src/icon4py/model/common/states/factory.py b/model/common/src/icon4py/model/common/states/factory.py index 626c6950a1..3a65f51bef 100644 --- a/model/common/src/icon4py/model/common/states/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -112,7 +112,8 @@ def exchange( first_dim in dims.MAIN_HORIZONTAL_DIMENSIONS.values() ), f"1st dimension {first_dim} needs to be one of (CellDim, EdgeDim, VertexDim) for exchange" with as_exchangeable_field(field) as buffer: - exchange.exchange_and_wait(first_dim, buffer) + # Synchronous exchange. + exchange.exchange_and_wait(first_dim, buffer, stream=None) log.debug(f"exchanged buffer for {name}") diff --git a/model/common/tests/common/decomposition/mpi_tests/test_mpi_decomposition.py b/model/common/tests/common/decomposition/mpi_tests/test_mpi_decomposition.py index b7a695ce82..582a011aed 100644 --- a/model/common/tests/common/decomposition/mpi_tests/test_mpi_decomposition.py +++ b/model/common/tests/common/decomposition/mpi_tests/test_mpi_decomposition.py @@ -272,7 +272,7 @@ def test_exchange_on_dummy_data( dimension, definitions.DecompositionInfo.EntryType.OWNED ) assert np.all(input_field.asnumpy() == number) - exchange.exchange_and_wait(dimension, input_field) + exchange.exchange_and_wait(dimension, input_field, stream=None) result = input_field.asnumpy() print(f"rank={processor_props.rank} - num of halo points ={halo_points.shape}") print( From a985a5988064ea2bdf7788e047ff87298269e6f1 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 23 Dec 2025 12:01:35 +0100 Subject: [PATCH 16/28] Undo something. --- ci/dace.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/dace.yml b/ci/dace.yml index 5bfbb01278..f876048fd1 100644 --- a/ci/dace.yml +++ b/ci/dace.yml @@ -39,7 +39,7 @@ test_model_stencils_aarch64: - when: on_success variables: NUM_PROCESSES: 8 - SLURM_TIMELIMIT: '01:30:00' + SLURM_TIMELIMIT: '00:45:00' # test_model_datatests_x86_64: # extends: [.test_model_datatests, .test_template_x86_64] test_model_datatests_aarch64: From e04e057893ab2c5675a44ea5666082dbe8e82d72 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 23 Dec 2025 12:25:22 +0100 Subject: [PATCH 17/28] Updated GHEX. --- uv.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uv.lock b/uv.lock index 382bc2eef0..890cb8ea13 100644 --- a/uv.lock +++ b/uv.lock @@ -1359,7 +1359,7 @@ wheels = [ [[package]] name = "ghex" version = "0.4.1" -source = { git = "https://github.com/philip-paul-mueller/GHEX/?branch=phimuell__async-mpi-2#5fec2d116c123d459fb002b713b21f8914a3f290" } +source = { git = "https://github.com/philip-paul-mueller/GHEX/?branch=phimuell__async-mpi-2#473ebd0959027f77fd35c2f13f7a4e113197d818" } dependencies = [ { name = "mpi4py" }, { name = "numpy" }, From 7fc5154643ae01bb911467669fe544de2ee51b6f Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 23 Dec 2025 12:36:28 +0100 Subject: [PATCH 18/28] Forgot one. --- .../src/icon4py/model/common/interpolation/rbf_interpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/common/src/icon4py/model/common/interpolation/rbf_interpolation.py b/model/common/src/icon4py/model/common/interpolation/rbf_interpolation.py index 9af69817a8..ce0fe6a2d4 100644 --- a/model/common/src/icon4py/model/common/interpolation/rbf_interpolation.py +++ b/model/common/src/icon4py/model/common/interpolation/rbf_interpolation.py @@ -355,7 +355,7 @@ def index_offset(f): rbf_vec_coeff[j][horizontal_start:] /= array_ns.sum( nxnx[j] * rbf_vec_coeff[j][horizontal_start:], axis=1 )[:, array_ns.newaxis] - exchange(*rbf_vec_coeff) + exchange(*rbf_vec_coeff, stream=None) return rbf_vec_coeff From 9a68a54bc089071deeff2cf6e84d9fa2b8039d54 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 23 Dec 2025 12:54:42 +0100 Subject: [PATCH 19/28] Let's hope that is enough. --- .../model/atmosphere/advection/advection.py | 4 +- .../advection/advection_horizontal.py | 4 +- .../interpolation/interpolation_factory.py | 44 ++++++++++++++----- .../model/common/metrics/metrics_factory.py | 20 ++++++--- .../mpi_tests/test_mpi_decomposition.py | 2 +- 5 files changed, 55 insertions(+), 19 deletions(-) diff --git a/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection.py b/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection.py index e3a90e7189..5256f54fe5 100644 --- a/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection.py +++ b/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection.py @@ -188,7 +188,9 @@ def run( log.debug("advection run - start") log.debug("communication of prep_adv cell field: mass_flx_ic - start") - self._exchange.exchange_and_wait(dims.CellDim, prep_adv.mass_flx_ic) + self._exchange.exchange_and_wait( + dims.CellDim, prep_adv.mass_flx_ic, stream=decomposition.DefaultStream + ) log.debug("communication of prep_adv cell field: mass_flx_ic - end") log.debug("running stencil copy_cell_kdim_field - start") diff --git a/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection_horizontal.py b/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection_horizontal.py index fae19d65bf..992b13b74f 100644 --- a/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection_horizontal.py +++ b/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection_horizontal.py @@ -139,7 +139,9 @@ def apply_flux_limiter( ) log.debug("communication of advection cell field: r_m - start") - self._exchange.exchange_and_wait(dims.CellDim, self._r_m) + self._exchange.exchange_and_wait( + dims.CellDim, self._r_m, stream=decomposition.DefaultStream + ) log.debug("communication of advection cell field: r_m - end") # limit outward fluxes diff --git a/model/common/src/icon4py/model/common/interpolation/interpolation_factory.py b/model/common/src/icon4py/model/common/interpolation/interpolation_factory.py index 79cd0d7823..8355160076 100644 --- a/model/common/src/icon4py/model/common/interpolation/interpolation_factory.py +++ b/model/common/src/icon4py/model/common/interpolation/interpolation_factory.py @@ -155,7 +155,9 @@ def _register_computed_fields(self) -> None: func=functools.partial( interpolation_fields.compute_geofac_n2s, array_ns=self._xp, - exchange=functools.partial(self._exchange.exchange_and_wait, dims.CellDim), + exchange=functools.partial( + self._exchange.exchange_and_wait, dims.CellDim, stream=None + ), ), fields=(attrs.GEOFAC_N2S,), domain=(dims.CellDim, dims.C2E2CODim), @@ -175,7 +177,9 @@ def _register_computed_fields(self) -> None: geofac_grdiv = factory.NumpyDataProvider( func=functools.partial( interpolation_fields.compute_geofac_grdiv, - exchange=functools.partial(self._exchange.exchange_and_wait, dims.EdgeDim), + exchange=functools.partial( + self._exchange.exchange_and_wait, dims.EdgeDim, stream=None + ), array_ns=self._xp, ), fields=(attrs.GEOFAC_GRDIV,), @@ -199,7 +203,9 @@ def _register_computed_fields(self) -> None: func=functools.partial( interpolation_fields.compute_mass_conserving_bilinear_cell_average_weight, array_ns=self._xp, - exchange=functools.partial(self._exchange.exchange_and_wait, dims.CellDim), + exchange=functools.partial( + self._exchange.exchange_and_wait, dims.CellDim, stream=None + ), ), fields=(attrs.C_BLN_AVG,), domain=(dims.CellDim, dims.C2E2CODim), @@ -226,7 +232,9 @@ def _register_computed_fields(self) -> None: func=functools.partial( interpolation_fields.compute_c_lin_e, array_ns=self._xp, - exchange=functools.partial(self._exchange.exchange_and_wait, dims.EdgeDim), + exchange=functools.partial( + self._exchange.exchange_and_wait, dims.EdgeDim, stream=None + ), ), fields=(attrs.C_LIN_E,), domain=(dims.EdgeDim, dims.E2CDim), @@ -247,7 +255,9 @@ def _register_computed_fields(self) -> None: func=functools.partial( interpolation_fields.compute_geofac_grg, array_ns=self._xp, - exchange=functools.partial(self._exchange.exchange_and_wait, dims.CellDim), + exchange=functools.partial( + self._exchange.exchange_and_wait, dims.CellDim, stream=None + ), ), fields=(attrs.GEOFAC_GRG_X, attrs.GEOFAC_GRG_Y), domain=(dims.CellDim, dims.C2E2CODim), @@ -271,7 +281,9 @@ def _register_computed_fields(self) -> None: func=functools.partial( interpolation_fields.compute_e_flx_avg, array_ns=self._xp, - exchange=functools.partial(self._exchange.exchange_and_wait, dims.EdgeDim), + exchange=functools.partial( + self._exchange.exchange_and_wait, dims.EdgeDim, stream=None + ), ), fields=(attrs.E_FLX_AVG,), domain=(dims.EdgeDim, dims.E2C2EODim), @@ -319,7 +331,9 @@ def _register_computed_fields(self) -> None: func=functools.partial( interpolation_fields.compute_pos_on_tplane_e_x_y, array_ns=self._xp, - exchange=functools.partial(self._exchange.exchange_and_wait, dims.EdgeDim), + exchange=functools.partial( + self._exchange.exchange_and_wait, dims.EdgeDim, stream=None + ), ), fields=(attrs.POS_ON_TPLANE_E_X, attrs.POS_ON_TPLANE_E_Y), domain=(dims.EdgeDim, dims.E2CDim), @@ -348,7 +362,9 @@ def _register_computed_fields(self) -> None: func=functools.partial( interpolation_fields.compute_cells_aw_verts, array_ns=self._xp, - exchange=functools.partial(self._exchange.exchange_and_wait, dims.VertexDim), + exchange=functools.partial( + self._exchange.exchange_and_wait, dims.VertexDim, stream=None + ), ), fields=(attrs.CELL_AW_VERTS,), domain=(dims.VertexDim, dims.V2CDim), @@ -375,7 +391,9 @@ def _register_computed_fields(self) -> None: func=functools.partial( rbf.compute_rbf_interpolation_coeffs_cell, array_ns=self._xp, - exchange=functools.partial(self._exchange.exchange_and_wait, dims.CellDim), + exchange=functools.partial( + self._exchange.exchange_and_wait, dims.CellDim, stream=None + ), ), fields=(attrs.RBF_VEC_COEFF_C1, attrs.RBF_VEC_COEFF_C2), domain=(dims.CellDim, dims.C2E2C2EDim), @@ -407,7 +425,9 @@ def _register_computed_fields(self) -> None: func=functools.partial( rbf.compute_rbf_interpolation_coeffs_edge, array_ns=self._xp, - exchange=functools.partial(self._exchange.exchange_and_wait, dims.EdgeDim), + exchange=functools.partial( + self._exchange.exchange_and_wait, dims.EdgeDim, stream=None + ), ), fields=(attrs.RBF_VEC_COEFF_E,), domain=(dims.EdgeDim, dims.E2C2EDim), @@ -438,7 +458,9 @@ def _register_computed_fields(self) -> None: func=functools.partial( rbf.compute_rbf_interpolation_coeffs_vertex, array_ns=self._xp, - exchange=functools.partial(self._exchange.exchange_and_wait, dims.VertexDim), + exchange=functools.partial( + self._exchange.exchange_and_wait, dims.VertexDim, stream=None + ), ), fields=(attrs.RBF_VEC_COEFF_V1, attrs.RBF_VEC_COEFF_V2), domain=(dims.VertexDim, dims.V2EDim), diff --git a/model/common/src/icon4py/model/common/metrics/metrics_factory.py b/model/common/src/icon4py/model/common/metrics/metrics_factory.py index 9d7e9a7a01..a761eff89d 100644 --- a/model/common/src/icon4py/model/common/metrics/metrics_factory.py +++ b/model/common/src/icon4py/model/common/metrics/metrics_factory.py @@ -147,7 +147,9 @@ def _register_computed_fields(self) -> None: # noqa: PLR0915 [too-many-statemen func=functools.partial( v_grid.compute_vertical_coordinate, array_ns=self._xp, - exchange=functools.partial(self._exchange.exchange_and_wait, dims.CellDim), + exchange=functools.partial( + self._exchange.exchange_and_wait, dims.CellDim, stream=None + ), ), fields=(attrs.CELL_HEIGHT_ON_HALF_LEVEL,), domain=(dims.CellDim, dims.KHalfDim), @@ -636,7 +638,9 @@ def _register_computed_fields(self) -> None: # noqa: PLR0915 [too-many-statemen max_flat_index_provider = factory.NumpyDataProvider( func=functools.partial( mf.compute_flat_max_idx, - exchange=functools.partial(self._exchange.exchange_and_wait, dims.EdgeDim), + exchange=functools.partial( + self._exchange.exchange_and_wait, dims.EdgeDim, stream=None + ), array_ns=self._xp, ), deps={ @@ -748,7 +752,9 @@ def _register_computed_fields(self) -> None: # noqa: PLR0915 [too-many-statemen func=functools.partial( compute_zdiff_gradp_dsl.compute_zdiff_gradp_dsl, array_ns=self._xp, - exchange=functools.partial(self._exchange.exchange_and_wait, dims.EdgeDim), + exchange=functools.partial( + self._exchange.exchange_and_wait, dims.EdgeDim, stream=None + ), ), deps={ "z_mc": attrs.Z_MC, @@ -808,7 +814,9 @@ def _register_computed_fields(self) -> None: # noqa: PLR0915 [too-many-statemen func=functools.partial( weight_factors.compute_wgtfacq_e_dsl, array_ns=self._xp, - exchange=functools.partial(self._exchange.exchange_and_wait, dims.EdgeDim), + exchange=functools.partial( + self._exchange.exchange_and_wait, dims.EdgeDim, stream=None + ), ), deps={ "z_ifc": attrs.CELL_HEIGHT_ON_HALF_LEVEL, @@ -870,7 +878,9 @@ def _register_computed_fields(self) -> None: # noqa: PLR0915 [too-many-statemen func=functools.partial( compute_diffusion_metrics.compute_max_nbhgt_array_ns, array_ns=self._xp, - exchange=functools.partial(self._exchange.exchange_and_wait, dims.CellDim), + exchange=functools.partial( + self._exchange.exchange_and_wait, dims.CellDim, stream=None + ), ), deps={ "z_mc": attrs.Z_MC, diff --git a/model/common/tests/common/decomposition/mpi_tests/test_mpi_decomposition.py b/model/common/tests/common/decomposition/mpi_tests/test_mpi_decomposition.py index 582a011aed..74f7ed0de5 100644 --- a/model/common/tests/common/decomposition/mpi_tests/test_mpi_decomposition.py +++ b/model/common/tests/common/decomposition/mpi_tests/test_mpi_decomposition.py @@ -321,6 +321,6 @@ def test_halo_exchange_for_sparse_field( f"{processor_props.rank}/{processor_props.comm_size}: size of computed field {result.asnumpy().shape}" ) - exchange.exchange_and_wait(dims.CellDim, result) + exchange.exchange_and_wait(dims.CellDim, result, stream=None) assert test_helpers.dallclose(result.asnumpy(), field_ref.asnumpy()) From 0d3a4001620920c539568a0c289a1d223a23fcdf Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 24 Dec 2025 10:12:24 +0100 Subject: [PATCH 20/28] Added warnings. --- .../common/decomposition/mpi_decomposition.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py index 263e7ca912..e21f001c68 100644 --- a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py +++ b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py @@ -10,6 +10,7 @@ import functools import logging +import warnings from collections.abc import Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Any, ClassVar, Final, Union @@ -256,8 +257,14 @@ def exchange( applied_patterns = [self._get_applied_pattern(dim, f) for f in fields] - if stream is None: - # Normal exchange. + if stream is None or (not ghex.__config__["gpu"]): + if stream is not None: + warnings.warn( + "Requested 'scheduled exchange' mode in GHEX, which is only available" + " if GHEX was compiled with GPU support, but it was not." + " Falling back to normal exchange.", + stacklevel=0, + ) handle = self._comm.exchange(applied_patterns) else: # Stream given, perform a scheduled exchange.. @@ -429,6 +436,13 @@ def wait( stream: definitions.StreamLike | None, ) -> None: if stream is None: + if stream is not None: + warnings.warn( + "Requested 'scheduled wait' mode in GHEX, which is only available" + " if GHEX was compiled with GPU support, but it was not." + " Falling back to normal exchange.", + stacklevel=0, + ) # No stream given, perform full blocking wait. self.handle.wait() else: From 6f4c6da0c0791a92da9ce04fb33125e137520a0f Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 24 Dec 2025 10:13:32 +0100 Subject: [PATCH 21/28] Updated GHEX. --- uv.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uv.lock b/uv.lock index 890cb8ea13..ebc4cf0146 100644 --- a/uv.lock +++ b/uv.lock @@ -1359,7 +1359,7 @@ wheels = [ [[package]] name = "ghex" version = "0.4.1" -source = { git = "https://github.com/philip-paul-mueller/GHEX/?branch=phimuell__async-mpi-2#473ebd0959027f77fd35c2f13f7a4e113197d818" } +source = { git = "https://github.com/philip-paul-mueller/GHEX/?branch=phimuell__async-mpi-2#f974c7b51280cfe958c9bf46a1669cbba8e5bc2b" } dependencies = [ { name = "mpi4py" }, { name = "numpy" }, From aae7b06586e77ab360a6823e7df262f01788c61c Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 24 Dec 2025 10:37:55 +0100 Subject: [PATCH 22/28] Starnge things are going on. --- .../model/common/decomposition/mpi_decomposition.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py index e21f001c68..7bf1bde85f 100644 --- a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py +++ b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py @@ -257,7 +257,11 @@ def exchange( applied_patterns = [self._get_applied_pattern(dim, f) for f in fields] - if stream is None or (not ghex.__config__["gpu"]): + # NOTE: We should actually also test if `ghex.__config__["gpu"]` is `False` + # and then use the normal `exchange` call, because in that case + # `scheduled_exchange()` is not present. However, for some reason GHEX is + # compiled with `-DGHEX_USE_GPU=OFF` by `uv`. + if stream is None: # or (not ghex.__config__["gpu"]): if stream is not None: warnings.warn( "Requested 'scheduled exchange' mode in GHEX, which is only available" @@ -436,7 +440,11 @@ def wait( stream: definitions.StreamLike | None, ) -> None: if stream is None: - if stream is not None: + # NOTE: We should actually also test if `ghex.__config__["gpu"]` is `False` + # and then use the normal `exchange` call, because in that case + # `scheduled_wait()` is not present. However, for some reason GHEX is + # compiled with `-DGHEX_USE_GPU=OFF` by `uv`. + if stream is not None: # or (not ghex.__config__["gpu"]): warnings.warn( "Requested 'scheduled wait' mode in GHEX, which is only available" " if GHEX was compiled with GPU support, but it was not." From d6dbc8c106225946930e861b8aea4339a20dd5dd Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 24 Dec 2025 11:21:55 +0100 Subject: [PATCH 23/28] The problem was me installing GHEX wrong. --- .../common/decomposition/mpi_decomposition.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py index 7bf1bde85f..68f2400ac7 100644 --- a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py +++ b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py @@ -257,11 +257,7 @@ def exchange( applied_patterns = [self._get_applied_pattern(dim, f) for f in fields] - # NOTE: We should actually also test if `ghex.__config__["gpu"]` is `False` - # and then use the normal `exchange` call, because in that case - # `scheduled_exchange()` is not present. However, for some reason GHEX is - # compiled with `-DGHEX_USE_GPU=OFF` by `uv`. - if stream is None: # or (not ghex.__config__["gpu"]): + if stream is None or (not ghex.__config__["gpu"]): if stream is not None: warnings.warn( "Requested 'scheduled exchange' mode in GHEX, which is only available" @@ -439,19 +435,15 @@ def wait( self, stream: definitions.StreamLike | None, ) -> None: - if stream is None: - # NOTE: We should actually also test if `ghex.__config__["gpu"]` is `False` - # and then use the normal `exchange` call, because in that case - # `scheduled_wait()` is not present. However, for some reason GHEX is - # compiled with `-DGHEX_USE_GPU=OFF` by `uv`. - if stream is not None: # or (not ghex.__config__["gpu"]): + if stream is None or (not ghex.__config__["gpu"]): + # No stream given, perform full blocking wait. + if stream is not None: warnings.warn( "Requested 'scheduled wait' mode in GHEX, which is only available" " if GHEX was compiled with GPU support, but it was not." " Falling back to normal exchange.", stacklevel=0, ) - # No stream given, perform full blocking wait. self.handle.wait() else: # Stream given, perform a scheduled wait. From c8d58f84bbfb80da0bc147ff18ebe485130a0221 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 1 Jan 2026 08:56:39 +0100 Subject: [PATCH 24/28] The no streaming case is no longer `None` but has a named constant. --- .../model/common/decomposition/definitions.py | 80 +++++++++++++------ .../common/decomposition/mpi_decomposition.py | 15 ++-- .../interpolation/interpolation_factory.py | 26 +++--- .../common/interpolation/rbf_interpolation.py | 3 +- .../model/common/metrics/metrics_factory.py | 10 +-- .../icon4py/model/common/states/factory.py | 2 +- .../mpi_tests/test_mpi_decomposition.py | 4 +- 7 files changed, 86 insertions(+), 54 deletions(-) diff --git a/model/common/src/icon4py/model/common/decomposition/definitions.py b/model/common/src/icon4py/model/common/decomposition/definitions.py index 8a0297123f..1feb1cd566 100644 --- a/model/common/src/icon4py/model/common/decomposition/definitions.py +++ b/model/common/src/icon4py/model/common/decomposition/definitions.py @@ -32,7 +32,15 @@ class DefaultStream: """Used in `exchange_and_wait()`, `exchange()` to indicate that synchronization - with the default stream is requested. + with the default stream is requested, see there for more information. If there + is no GPU, or the data is stored on the host, then the behaviour falls back to + `NoStreaming`, see there for more. + """ + + +class NoStreaming: + """Used in `exchange_and_wait()`, `exchange()` to indicate that no streaming + support is requested, see there for more information. """ @@ -188,16 +196,25 @@ def global_index( class ExchangeResult(Protocol): def wait( self, - stream: StreamLike | None, + stream: StreamLike | type[NoStreaming], ) -> None: - """Performs a wait. + """Wait on the halo exchange. The function will wait until the communication has ended and then start the - unpacking of the data. If `stream` is `None` then the function will wait - until the unpacking has completed. If it is a CUDA stream or the - `DefaultStream` singleton, then the function will return after the unpacking - has been scheduled. It will add synchronizations, such that all work - that will be submitted to `stream` will wait until the unpacking has finished. + unpacking of the data. Depending on `stream` the behaviour when the function + returns are different. If it is a CUDA stream, see `StreamLike`, then the + function will return as soon as the unpacking has been scheduled. Furthermore, + the unpacking will synchronize with `stream`, i.e. all work that is submitted + to `stream`, after this function returns will not start before the unpacking + has finished. If `stream` is the special constant `NoStreaming`, then the + function will only return once the unpacking has finished. + + Note: + - To select the default stream the special constant `DefaultStream` can be used. + - If there is no GPU, then using `DefaultStream` is the same as `NoStreaming`. + - If `stream` is used then "scheduling exchange" in GHEX are used. + - For data located on the host the behaviour is always the same as if + specifying `NoStreaming`. """ ... @@ -213,7 +230,7 @@ def exchange( self, dim: gtx.Dimension, *buffers: data_alloc.NDArray, - stream: StreamLike | None, + stream: StreamLike | type[NoStreaming], ) -> ExchangeResult: ... @overload @@ -221,17 +238,28 @@ def exchange( self, dim: gtx.Dimension, *fields: gtx.Field, - stream: StreamLike | None, + stream: StreamLike | type[NoStreaming], ) -> ExchangeResult: """Perform halo exchanges. - The exact behaviour depends on the optional argument `stream` is supplied. - If it is a GPU stream then the exchange will wait until the work previously - submitted to the stream has concluded. To select the default stream use the - special value `DefaultStream`. - If it is `None` then the exchange will start immediately. - - It is important that `wait` is called on the returned handle. + The function packs the data and transmit it to the neighboring nodes, on the + returned handle a user must call `wait()`, to complete the process, see + `ExchangeResult.wait()` for more. + The function will only return once the data has been send. + + The exact behaviour depends on the `stream` argument. If `stream` is the + constant `NoStreaming` then the function will start to pack the data + immediately, this means the caller must make sure that the computation has + been completed. + If it is a CUDA stream, see `StreamLike` or the constant `DefaultStream`, + then the packing will wait until all work, that has been submitted to `stream` + before this function was called, has been completed. + + Note: + - If there is no GPU then specifying `DefaultStream` is the same as + `NoStreaming`. + - For data located on the host the behaviour is always the same as if + specifying `NoStreaming`. """ ... @@ -240,7 +268,7 @@ def exchange_and_wait( self, dim: gtx.Dimension, *buffers: data_alloc.NDArray, - stream: StreamLike | None, + stream: StreamLike | type[NoStreaming], ) -> None: ... @overload @@ -248,7 +276,7 @@ def exchange_and_wait( self, dim: gtx.Dimension, *fields: gtx.Field, - stream: StreamLike | None, + stream: StreamLike | type[NoStreaming], ) -> None: """Exchange and wait in one go.""" ... @@ -258,7 +286,7 @@ def __call__( *args: Any, dim: gtx.Dimension, wait: bool = True, - stream: StreamLike | None, + stream: StreamLike | type[NoStreaming], ) -> None | ExchangeResult: """Perform a halo exchange operation. @@ -286,7 +314,7 @@ def exchange( self, dim: gtx.Dimension, *fields: gtx.Field | data_alloc.NDArray, - stream: StreamLike | None, + stream: StreamLike | type[NoStreaming], ) -> ExchangeResult: return SingleNodeResult() @@ -294,7 +322,7 @@ def exchange_and_wait( self, dim: gtx.Dimension, *fields: gtx.Field | data_alloc.NDArray, - stream: StreamLike | None, + stream: StreamLike | type[NoStreaming], ) -> None: return None @@ -309,7 +337,7 @@ def __call__( # type: ignore[return] # return statment in else condition *args: Any, dim: gtx.Dimension, wait: bool = True, - stream: StreamLike | None, + stream: StreamLike | type[NoStreaming], ) -> ExchangeResult | None: res = self.exchange(dim, *args, stream=stream) if wait: @@ -344,7 +372,7 @@ class HaloExchangeWaitRuntime(Protocol): def __call__( self, communication_handle: ExchangeResult, - stream: StreamLike | None, + stream: StreamLike | type[NoStreaming], ) -> None: """Calls `wait()` on the provided communication handle, `stream` is forwarded.""" ... @@ -369,7 +397,7 @@ class HaloExchangeWait: def __call__( self, communication_handle: SingleNodeResult, - stream: StreamLike | None, + stream: StreamLike | type[NoStreaming], ) -> None: communication_handle.wait(stream=stream) @@ -404,7 +432,7 @@ def create_single_node_halo_exchange_wait(runtime: SingleNodeExchange) -> HaloEx class SingleNodeResult: - def wait(self, stream: StreamLike | None) -> None: + def wait(self, stream: StreamLike | type[NoStreaming]) -> None: pass def is_ready(self) -> bool: diff --git a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py index 68f2400ac7..ff8ab11d74 100644 --- a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py +++ b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py @@ -246,7 +246,7 @@ def exchange( self, dim: gtx.Dimension, *fields: gtx.Field | data_alloc.NDArray, - stream: definitions.StreamLike | None, + stream: definitions.StreamLike | type[definitions.NoStreaming], ) -> MultiNodeResult: """ Exchange method that slices the fields based on the dimension and then performs halo exchange. @@ -257,8 +257,8 @@ def exchange( applied_patterns = [self._get_applied_pattern(dim, f) for f in fields] - if stream is None or (not ghex.__config__["gpu"]): - if stream is not None: + if stream is definitions.NoStreaming or (not ghex.__config__["gpu"]): + if stream is not definitions.NoStreaming: warnings.warn( "Requested 'scheduled exchange' mode in GHEX, which is only available" " if GHEX was compiled with GPU support, but it was not." @@ -270,7 +270,6 @@ def exchange( # Stream given, perform a scheduled exchange.. # NOTE: GHEX interprets `None` as default stream. Furthermore, if no # GPU is present, passing `None` is mandatory. - # TODO(phimuell): Fix named arguments in GHEX. handle = self._comm.schedule_exchange( patterns=applied_patterns, stream=(None if stream is definitions.DefaultStream else stream), @@ -282,7 +281,7 @@ def exchange_and_wait( self, dim: gtx.Dimension, *fields: gtx.Field | data_alloc.NDArray, - stream: definitions.StreamLike | None, + stream: definitions.StreamLike | type[definitions.NoStreaming], ) -> None: res = self.exchange(dim, *fields, stream=stream) res.wait(stream=stream) @@ -293,7 +292,7 @@ def __call__( # type: ignore[return] # return statment in else condition *args: Any, dim: gtx.Dimension, wait: bool = True, - stream: definitions.StreamLike | None, + stream: definitions.StreamLike | type[definitions.NoStreaming], ) -> MultiNodeResult | None: if dim is None: raise ValueError("Need to define a dimension.") @@ -357,7 +356,7 @@ class HaloExchangeWait: def __call__( self, communication_handle: MultiNodeResult, - stream: definitions.StreamLike | None, + stream: definitions.StreamLike | type[definitions.NoStreaming], ) -> None: """Wait on the communication handle.""" communication_handle.wait(stream=stream) @@ -433,7 +432,7 @@ class MultiNodeResult: def wait( self, - stream: definitions.StreamLike | None, + stream: definitions.StreamLike | type[definitions.NoStreaming], ) -> None: if stream is None or (not ghex.__config__["gpu"]): # No stream given, perform full blocking wait. diff --git a/model/common/src/icon4py/model/common/interpolation/interpolation_factory.py b/model/common/src/icon4py/model/common/interpolation/interpolation_factory.py index 8355160076..97740c8139 100644 --- a/model/common/src/icon4py/model/common/interpolation/interpolation_factory.py +++ b/model/common/src/icon4py/model/common/interpolation/interpolation_factory.py @@ -156,7 +156,7 @@ def _register_computed_fields(self) -> None: interpolation_fields.compute_geofac_n2s, array_ns=self._xp, exchange=functools.partial( - self._exchange.exchange_and_wait, dims.CellDim, stream=None + self._exchange.exchange_and_wait, dims.CellDim, stream=decomposition.NoStreaming ), ), fields=(attrs.GEOFAC_N2S,), @@ -178,7 +178,7 @@ def _register_computed_fields(self) -> None: func=functools.partial( interpolation_fields.compute_geofac_grdiv, exchange=functools.partial( - self._exchange.exchange_and_wait, dims.EdgeDim, stream=None + self._exchange.exchange_and_wait, dims.EdgeDim, stream=decomposition.NoStreaming ), array_ns=self._xp, ), @@ -204,7 +204,7 @@ def _register_computed_fields(self) -> None: interpolation_fields.compute_mass_conserving_bilinear_cell_average_weight, array_ns=self._xp, exchange=functools.partial( - self._exchange.exchange_and_wait, dims.CellDim, stream=None + self._exchange.exchange_and_wait, dims.CellDim, stream=decomposition.NoStreaming ), ), fields=(attrs.C_BLN_AVG,), @@ -233,7 +233,7 @@ def _register_computed_fields(self) -> None: interpolation_fields.compute_c_lin_e, array_ns=self._xp, exchange=functools.partial( - self._exchange.exchange_and_wait, dims.EdgeDim, stream=None + self._exchange.exchange_and_wait, dims.EdgeDim, stream=decomposition.NoStreaming ), ), fields=(attrs.C_LIN_E,), @@ -256,7 +256,7 @@ def _register_computed_fields(self) -> None: interpolation_fields.compute_geofac_grg, array_ns=self._xp, exchange=functools.partial( - self._exchange.exchange_and_wait, dims.CellDim, stream=None + self._exchange.exchange_and_wait, dims.CellDim, stream=decomposition.NoStreaming ), ), fields=(attrs.GEOFAC_GRG_X, attrs.GEOFAC_GRG_Y), @@ -282,7 +282,7 @@ def _register_computed_fields(self) -> None: interpolation_fields.compute_e_flx_avg, array_ns=self._xp, exchange=functools.partial( - self._exchange.exchange_and_wait, dims.EdgeDim, stream=None + self._exchange.exchange_and_wait, dims.EdgeDim, stream=decomposition.NoStreaming ), ), fields=(attrs.E_FLX_AVG,), @@ -332,7 +332,7 @@ def _register_computed_fields(self) -> None: interpolation_fields.compute_pos_on_tplane_e_x_y, array_ns=self._xp, exchange=functools.partial( - self._exchange.exchange_and_wait, dims.EdgeDim, stream=None + self._exchange.exchange_and_wait, dims.EdgeDim, stream=decomposition.NoStreaming ), ), fields=(attrs.POS_ON_TPLANE_E_X, attrs.POS_ON_TPLANE_E_Y), @@ -363,7 +363,9 @@ def _register_computed_fields(self) -> None: interpolation_fields.compute_cells_aw_verts, array_ns=self._xp, exchange=functools.partial( - self._exchange.exchange_and_wait, dims.VertexDim, stream=None + self._exchange.exchange_and_wait, + dims.VertexDim, + stream=decomposition.NoStreaming, ), ), fields=(attrs.CELL_AW_VERTS,), @@ -392,7 +394,7 @@ def _register_computed_fields(self) -> None: rbf.compute_rbf_interpolation_coeffs_cell, array_ns=self._xp, exchange=functools.partial( - self._exchange.exchange_and_wait, dims.CellDim, stream=None + self._exchange.exchange_and_wait, dims.CellDim, stream=decomposition.NoStreaming ), ), fields=(attrs.RBF_VEC_COEFF_C1, attrs.RBF_VEC_COEFF_C2), @@ -426,7 +428,7 @@ def _register_computed_fields(self) -> None: rbf.compute_rbf_interpolation_coeffs_edge, array_ns=self._xp, exchange=functools.partial( - self._exchange.exchange_and_wait, dims.EdgeDim, stream=None + self._exchange.exchange_and_wait, dims.EdgeDim, stream=decomposition.NoStreaming ), ), fields=(attrs.RBF_VEC_COEFF_E,), @@ -459,7 +461,9 @@ def _register_computed_fields(self) -> None: rbf.compute_rbf_interpolation_coeffs_vertex, array_ns=self._xp, exchange=functools.partial( - self._exchange.exchange_and_wait, dims.VertexDim, stream=None + self._exchange.exchange_and_wait, + dims.VertexDim, + stream=decomposition.NoStreaming, ), ), fields=(attrs.RBF_VEC_COEFF_V1, attrs.RBF_VEC_COEFF_V2), diff --git a/model/common/src/icon4py/model/common/interpolation/rbf_interpolation.py b/model/common/src/icon4py/model/common/interpolation/rbf_interpolation.py index ce0fe6a2d4..8758cde161 100644 --- a/model/common/src/icon4py/model/common/interpolation/rbf_interpolation.py +++ b/model/common/src/icon4py/model/common/interpolation/rbf_interpolation.py @@ -16,6 +16,7 @@ import scipy.linalg as sla from icon4py.model.common import dimension as dims, type_alias as ta +from icon4py.model.common.decomposition import definitions as decomposition from icon4py.model.common.grid import base as base_grid from icon4py.model.common.utils import data_allocation as data_alloc @@ -355,7 +356,7 @@ def index_offset(f): rbf_vec_coeff[j][horizontal_start:] /= array_ns.sum( nxnx[j] * rbf_vec_coeff[j][horizontal_start:], axis=1 )[:, array_ns.newaxis] - exchange(*rbf_vec_coeff, stream=None) + exchange(*rbf_vec_coeff, stream=decomposition.NoStreaming) return rbf_vec_coeff diff --git a/model/common/src/icon4py/model/common/metrics/metrics_factory.py b/model/common/src/icon4py/model/common/metrics/metrics_factory.py index a761eff89d..823ff03263 100644 --- a/model/common/src/icon4py/model/common/metrics/metrics_factory.py +++ b/model/common/src/icon4py/model/common/metrics/metrics_factory.py @@ -148,7 +148,7 @@ def _register_computed_fields(self) -> None: # noqa: PLR0915 [too-many-statemen v_grid.compute_vertical_coordinate, array_ns=self._xp, exchange=functools.partial( - self._exchange.exchange_and_wait, dims.CellDim, stream=None + self._exchange.exchange_and_wait, dims.CellDim, stream=decomposition.NoStreaming ), ), fields=(attrs.CELL_HEIGHT_ON_HALF_LEVEL,), @@ -639,7 +639,7 @@ def _register_computed_fields(self) -> None: # noqa: PLR0915 [too-many-statemen func=functools.partial( mf.compute_flat_max_idx, exchange=functools.partial( - self._exchange.exchange_and_wait, dims.EdgeDim, stream=None + self._exchange.exchange_and_wait, dims.EdgeDim, stream=decomposition.NoStreaming ), array_ns=self._xp, ), @@ -753,7 +753,7 @@ def _register_computed_fields(self) -> None: # noqa: PLR0915 [too-many-statemen compute_zdiff_gradp_dsl.compute_zdiff_gradp_dsl, array_ns=self._xp, exchange=functools.partial( - self._exchange.exchange_and_wait, dims.EdgeDim, stream=None + self._exchange.exchange_and_wait, dims.EdgeDim, stream=decomposition.NoStreaming ), ), deps={ @@ -815,7 +815,7 @@ def _register_computed_fields(self) -> None: # noqa: PLR0915 [too-many-statemen weight_factors.compute_wgtfacq_e_dsl, array_ns=self._xp, exchange=functools.partial( - self._exchange.exchange_and_wait, dims.EdgeDim, stream=None + self._exchange.exchange_and_wait, dims.EdgeDim, stream=decomposition.NoStreaming ), ), deps={ @@ -879,7 +879,7 @@ def _register_computed_fields(self) -> None: # noqa: PLR0915 [too-many-statemen compute_diffusion_metrics.compute_max_nbhgt_array_ns, array_ns=self._xp, exchange=functools.partial( - self._exchange.exchange_and_wait, dims.CellDim, stream=None + self._exchange.exchange_and_wait, dims.CellDim, stream=decomposition.NoStreaming ), ), deps={ diff --git a/model/common/src/icon4py/model/common/states/factory.py b/model/common/src/icon4py/model/common/states/factory.py index 3a65f51bef..bb34d1eade 100644 --- a/model/common/src/icon4py/model/common/states/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -113,7 +113,7 @@ def exchange( ), f"1st dimension {first_dim} needs to be one of (CellDim, EdgeDim, VertexDim) for exchange" with as_exchangeable_field(field) as buffer: # Synchronous exchange. - exchange.exchange_and_wait(first_dim, buffer, stream=None) + exchange.exchange_and_wait(first_dim, buffer, stream=decomposition.NoStreaming) log.debug(f"exchanged buffer for {name}") diff --git a/model/common/tests/common/decomposition/mpi_tests/test_mpi_decomposition.py b/model/common/tests/common/decomposition/mpi_tests/test_mpi_decomposition.py index 74f7ed0de5..a0d34b20ff 100644 --- a/model/common/tests/common/decomposition/mpi_tests/test_mpi_decomposition.py +++ b/model/common/tests/common/decomposition/mpi_tests/test_mpi_decomposition.py @@ -272,7 +272,7 @@ def test_exchange_on_dummy_data( dimension, definitions.DecompositionInfo.EntryType.OWNED ) assert np.all(input_field.asnumpy() == number) - exchange.exchange_and_wait(dimension, input_field, stream=None) + exchange.exchange_and_wait(dimension, input_field, stream=definitions.NoStreaming) result = input_field.asnumpy() print(f"rank={processor_props.rank} - num of halo points ={halo_points.shape}") print( @@ -321,6 +321,6 @@ def test_halo_exchange_for_sparse_field( f"{processor_props.rank}/{processor_props.comm_size}: size of computed field {result.asnumpy().shape}" ) - exchange.exchange_and_wait(dims.CellDim, result, stream=None) + exchange.exchange_and_wait(dims.CellDim, result, stream=definitions.NoStreaming) assert test_helpers.dallclose(result.asnumpy(), field_ref.asnumpy()) From 89a94aa4a073ef23a81a9aef23eae785a7407dbb Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 1 Jan 2026 09:00:18 +0100 Subject: [PATCH 25/28] Forgot to add a warning. --- .../icon4py/model/common/decomposition/mpi_decomposition.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py index ff8ab11d74..1582ed8161 100644 --- a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py +++ b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py @@ -267,6 +267,12 @@ def exchange( ) handle = self._comm.exchange(applied_patterns) else: + if stream is None: + warnings.warn( + "Passed `None` as `stream` argument. This is discouraged but allowed" + " `stream` is interpreted as `DefaultStream`.", + stacklevel=0, + ) # Stream given, perform a scheduled exchange.. # NOTE: GHEX interprets `None` as default stream. Furthermore, if no # GPU is present, passing `None` is mandatory. From 0b8307b8031293ad33ae98d7a7e72bbbeedfcff9 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 1 Jan 2026 14:13:33 +0100 Subject: [PATCH 26/28] I am not fully like it, but for this it seems appropriate. --- model/common/tests/common/utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/model/common/tests/common/utils.py b/model/common/tests/common/utils.py index f3d689155d..9d14bf5fba 100644 --- a/model/common/tests/common/utils.py +++ b/model/common/tests/common/utils.py @@ -6,9 +6,14 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from typing import Any from icon4py.model.common.utils import data_allocation as data_alloc -def dummy_exchange(*field: data_alloc.NDArray) -> None: - return None +def dummy_exchange(*field: data_alloc.NDArray, **kwargs: Any) -> None: + # The real exchange function takes a `stream` argument, for the scheduled + # exchange. We have to ignore it as we never do an exchange. + # TODO(phimuell): Is this the best way? + assert len(kwargs) <= 1 + assert len(kwargs) == 1 and "stream" in kwargs From 0e2fbde936fbd369efcecb5b5c3047d2234064f4 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 2 Jan 2026 12:20:35 +0100 Subject: [PATCH 27/28] Forgot that. --- model/common/tests/common/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model/common/tests/common/utils.py b/model/common/tests/common/utils.py index 9d14bf5fba..2d866d6bf6 100644 --- a/model/common/tests/common/utils.py +++ b/model/common/tests/common/utils.py @@ -15,5 +15,5 @@ def dummy_exchange(*field: data_alloc.NDArray, **kwargs: Any) -> None: # The real exchange function takes a `stream` argument, for the scheduled # exchange. We have to ignore it as we never do an exchange. # TODO(phimuell): Is this the best way? - assert len(kwargs) <= 1 - assert len(kwargs) == 1 and "stream" in kwargs + assert len(kwargs) == 0 or (len(kwargs) == 1 and "stream" in kwargs) + return None From a570aee6ef4dc2b157d258ef56e9916369ee84cd Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 2 Jan 2026 12:27:56 +0100 Subject: [PATCH 28/28] Removed useless return. --- model/common/tests/common/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/model/common/tests/common/utils.py b/model/common/tests/common/utils.py index 2d866d6bf6..da9736b982 100644 --- a/model/common/tests/common/utils.py +++ b/model/common/tests/common/utils.py @@ -16,4 +16,3 @@ def dummy_exchange(*field: data_alloc.NDArray, **kwargs: Any) -> None: # exchange. We have to ignore it as we never do an exchange. # TODO(phimuell): Is this the best way? assert len(kwargs) == 0 or (len(kwargs) == 1 and "stream" in kwargs) - return None