Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
31e767b
switch to gt4py main
edopao Oct 1, 2025
2f4692b
Merge branch 'main' into blueline_integration
edopao Oct 1, 2025
b4bd3e4
Enable custom backends for blueline
havogt Oct 1, 2025
055051d
improve typing
havogt Oct 1, 2025
559106e
fix diffusion
havogt Oct 1, 2025
03e6a46
cleanup and fix allocator/backend
havogt Oct 1, 2025
3b237a5
dace default, gtfn for vertically implicit
havogt Oct 1, 2025
6c6f9df
from measurement
havogt Oct 2, 2025
33fe48c
customize one
havogt Oct 3, 2025
4a3ad8f
fix forwarding
havogt Oct 3, 2025
cde438c
cleanup
havogt Oct 8, 2025
575cb1c
Merge remote-tracking branch 'upstream/main' into add_some_backend_cu…
havogt Oct 8, 2025
cede653
fix log message
havogt Oct 9, 2025
fe21f91
overlap experiment
havogt Oct 9, 2025
3880d1a
swap exchange<->wait
havogt Oct 9, 2025
43cc58b
fix comment
havogt Oct 9, 2025
9554aad
run exchange async
havogt Oct 9, 2025
8085549
Use experimental GHEX async scheduling
havogt Nov 10, 2025
51099e9
Merge remote-tracking branch 'upstream/main' into exchange_overlap
havogt Nov 11, 2025
621517c
cleanup
havogt Nov 11, 2025
6a93895
Merge remote-tracking branch 'upstream/main' into async_ghex
havogt Nov 11, 2025
e94560b
Merge branch 'async_ghex' into exchange_overlap_mpi
havogt Nov 11, 2025
64ba54c
fix domain construction
havogt Nov 11, 2025
8b6ba50
fix domain
havogt Nov 11, 2025
6cc3f53
cache half-fields
havogt Nov 11, 2025
b78106c
try set_sync_marker
havogt Nov 11, 2025
ca23bb4
fix return
havogt Nov 11, 2025
4356e7a
Revert "try set_sync_marker"
havogt Nov 11, 2025
293ab4e
Merge remote-tracking branch 'upstream/main' into exchange_overlap
havogt Nov 14, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ def __init__(
offset_provider=self._grid.connectivities,
)

self._apply_divergence_damping_and_update_vn = setup_program(
self._apply_divergence_damping_and_update_vn_first_half = setup_program(
backend=backend,
program=compute_edge_diagnostics_for_dycore_and_update_vn.apply_divergence_damping_and_update_vn,
constant_args={
Expand All @@ -517,6 +517,35 @@ def __init__(
},
vertical_sizes={
"vertical_start": gtx.int32(0),
"vertical_end": gtx.int32(self._grid.num_levels // 2),
},
offset_provider=self._grid.connectivities,
)
self._apply_divergence_damping_and_update_vn_second_half = setup_program(
backend=backend,
program=compute_edge_diagnostics_for_dycore_and_update_vn.apply_divergence_damping_and_update_vn,
constant_args={
"horizontal_mask_for_3d_divdamp": self._metric_state_nonhydro.horizontal_mask_for_3d_divdamp,
"scaling_factor_for_3d_divdamp": self._metric_state_nonhydro.scaling_factor_for_3d_divdamp,
"inv_dual_edge_length": self._edge_geometry.inverse_dual_edge_lengths,
"nudgecoeff_e": self._interpolation_state.nudgecoeff_e,
"geofac_grdiv": self._interpolation_state.geofac_grdiv,
"advection_explicit_weight_parameter": self._params.advection_explicit_weight_parameter,
"advection_implicit_weight_parameter": self._params.advection_implicit_weight_parameter,
"iau_wgt_dyn": self._config.iau_wgt_dyn,
"is_iau_active": self._config.is_iau_active,
"limited_area": self._grid.limited_area,
},
variants={
"apply_2nd_order_divergence_damping": [False, True],
"apply_4th_order_divergence_damping": [False, True],
},
horizontal_sizes={
"horizontal_start": gtx.int32(self._start_edge_nudging_level_2),
"horizontal_end": self._end_edge_local,
},
vertical_sizes={
"vertical_start": gtx.int32(self._grid.num_levels // 2),
"vertical_end": gtx.int32(self._grid.num_levels),
},
offset_provider=self._grid.connectivities,
Expand Down Expand Up @@ -547,26 +576,51 @@ def __init__(
offset_provider=self._grid.connectivities,
)

self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection = setup_program(
backend=backend,
program=compute_averaged_vn_and_fluxes_and_prepare_tracer_advection,
constant_args={
"e_flx_avg": self._interpolation_state.e_flx_avg,
"ddqz_z_full_e": self._metric_state_nonhydro.ddqz_z_full_e,
},
variants={
"at_first_substep": [False, True],
"prepare_advection": [False, True],
},
horizontal_sizes={
"horizontal_start": gtx.int32(self._start_edge_lateral_boundary_level_5),
"horizontal_end": self._end_edge_halo_level_2,
},
vertical_sizes={
"vertical_start": gtx.int32(0),
"vertical_end": gtx.int32(self._grid.num_levels),
},
offset_provider=self._grid.connectivities,
self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection_first_half = (
setup_program(
backend=backend,
program=compute_averaged_vn_and_fluxes_and_prepare_tracer_advection,
constant_args={
"e_flx_avg": self._interpolation_state.e_flx_avg,
"ddqz_z_full_e": self._metric_state_nonhydro.ddqz_z_full_e,
},
variants={
"at_first_substep": [False, True],
"prepare_advection": [False, True],
},
horizontal_sizes={
"horizontal_start": gtx.int32(self._start_edge_lateral_boundary_level_5),
"horizontal_end": self._end_edge_halo_level_2,
},
vertical_sizes={
"vertical_start": gtx.int32(0),
"vertical_end": gtx.int32(self._grid.num_levels // 2),
},
offset_provider=self._grid.connectivities,
)
)
self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection_second_half = (
setup_program(
backend=backend,
program=compute_averaged_vn_and_fluxes_and_prepare_tracer_advection,
constant_args={
"e_flx_avg": self._interpolation_state.e_flx_avg,
"ddqz_z_full_e": self._metric_state_nonhydro.ddqz_z_full_e,
},
variants={
"at_first_substep": [False, True],
"prepare_advection": [False, True],
},
horizontal_sizes={
"horizontal_start": gtx.int32(self._start_edge_lateral_boundary_level_5),
"horizontal_end": self._end_edge_halo_level_2,
},
vertical_sizes={
"vertical_start": gtx.int32(self._grid.num_levels // 2),
"vertical_end": gtx.int32(self._grid.num_levels),
},
offset_provider=self._grid.connectivities,
)
)

self._vertically_implicit_solver_at_predictor_step = setup_program(
Expand Down Expand Up @@ -819,6 +873,8 @@ def __init__(

self.p_test_run = False

self._first_half_cache = {}
self._second_half_cache = {}
self._dtime_previous_substep: float = 0.0
"""
Dynamic substep length of previous substep in order to track if rayleigh damping coefficients need to be
Expand Down Expand Up @@ -1268,6 +1324,38 @@ def run_predictor_step(
log.debug("exchanging prognostic field 'w'")
self._exchange.exchange_and_wait(dims.CellDim, prognostic_states.next.w)

def get_first_half_vn(self, vn: gtx.Field):
try:
return self._first_half_cache[vn.__gt_buffer_info__.hash_key]
except KeyError:
self._first_half_cache[vn.__gt_buffer_info__.hash_key] = gtx_common._field(
vn.ndarray[:, : self._grid.num_levels // 2],
domain=gtx_common.Domain(
dims=vn.domain.dims,
ranges=(
vn.domain.ranges[0],
gtx_common.UnitRange(0, self._grid.num_levels // 2),
),
),
)
return self._first_half_cache[vn.__gt_buffer_info__.hash_key]

def get_second_half_vn(self, vn: gtx.Field):
try:
return self._second_half_cache[vn.__gt_buffer_info__.hash_key]
except KeyError:
self._second_half_cache[vn.__gt_buffer_info__.hash_key] = gtx_common._field(
vn.ndarray[:, self._grid.num_levels // 2 :],
domain=gtx_common.Domain(
dims=vn.domain.dims,
ranges=(
vn.domain.ranges[0],
gtx_common.UnitRange(self._grid.num_levels // 2, self._grid.num_levels),
),
),
)
return self._second_half_cache[vn.__gt_buffer_info__.hash_key]

def run_corrector_step(
self,
diagnostic_state_nh: dycore_states.DiagnosticStateNonHydro,
Expand Down Expand Up @@ -1341,7 +1429,32 @@ def run_corrector_step(
)
)

self._apply_divergence_damping_and_update_vn(
# EXCHANGE OVERLAP EXPERIMENT START
self._apply_divergence_damping_and_update_vn_first_half(
horizontal_gradient_of_normal_wind_divergence=z_fields.horizontal_gradient_of_normal_wind_divergence,
next_vn=prognostic_states.next.vn,
current_vn=prognostic_states.current.vn,
dwdz_at_cells_on_model_levels=z_fields.dwdz_at_cells_on_model_levels,
predictor_normal_wind_advective_tendency=diagnostic_state_nh.normal_wind_advective_tendency.predictor,
corrector_normal_wind_advective_tendency=diagnostic_state_nh.normal_wind_advective_tendency.corrector,
normal_wind_tendency_due_to_slow_physics_process=diagnostic_state_nh.normal_wind_tendency_due_to_slow_physics_process,
normal_wind_iau_increment=diagnostic_state_nh.normal_wind_iau_increment,
theta_v_at_edges_on_model_levels=z_fields.theta_v_at_edges_on_model_levels,
horizontal_pressure_gradient=z_fields.horizontal_pressure_gradient,
reduced_fourth_order_divdamp_coeff_at_nest_boundary=self.reduced_fourth_order_divdamp_coeff_at_nest_boundary,
fourth_order_divdamp_scaling_coeff=self.fourth_order_divdamp_scaling_coeff,
second_order_divdamp_scaling_coeff=second_order_divdamp_scaling_coeff,
dtime=dtime,
apply_2nd_order_divergence_damping=apply_2nd_order_divergence_damping,
apply_4th_order_divergence_damping=apply_4th_order_divergence_damping,
)

log.debug("exchanging prognostic field 'vn' first half")

first_half_vn = self.get_first_half_vn(prognostic_states.next.vn)
first_half_exchange = self._exchange.exchange(dims.EdgeDim, first_half_vn)

self._apply_divergence_damping_and_update_vn_second_half(
horizontal_gradient_of_normal_wind_divergence=z_fields.horizontal_gradient_of_normal_wind_divergence,
next_vn=prognostic_states.next.vn,
current_vn=prognostic_states.current.vn,
Expand All @@ -1360,10 +1473,27 @@ def run_corrector_step(
apply_4th_order_divergence_damping=apply_4th_order_divergence_damping,
)

log.debug("exchanging prognostic field 'vn'")
self._exchange.exchange_and_wait(dims.EdgeDim, (prognostic_states.next.vn))
log.debug("exchanging prognostic field 'vn' second half")
# TODO(havogt): this wait could be after the next exchange starts, but we need to duplicate the ghex communication object
first_half_exchange.wait()
second_half_vn = self.get_second_half_vn(prognostic_states.next.vn)
second_half_exchange = self._exchange.exchange(dims.EdgeDim, second_half_vn)
self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection_first_half(
spatially_averaged_vn=self.z_vn_avg,
mass_flux_at_edges_on_model_levels=diagnostic_state_nh.mass_flux_at_edges_on_model_levels,
theta_v_flux_at_edges_on_model_levels=self.theta_v_flux_at_edges_on_model_levels,
substep_and_spatially_averaged_vn=prep_adv.vn_traj,
substep_averaged_mass_flux=prep_adv.mass_flx_me,
vn=prognostic_states.next.vn,
rho_at_edges_on_model_levels=z_fields.rho_at_edges_on_model_levels,
theta_v_at_edges_on_model_levels=z_fields.theta_v_at_edges_on_model_levels,
prepare_advection=lprep_adv,
at_first_substep=at_first_substep,
r_nsubsteps=r_nsubsteps,
)

self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection(
second_half_exchange.wait()
self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection_second_half(
spatially_averaged_vn=self.z_vn_avg,
mass_flux_at_edges_on_model_levels=diagnostic_state_nh.mass_flux_at_edges_on_model_levels,
theta_v_flux_at_edges_on_model_levels=self.theta_v_flux_at_edges_on_model_levels,
Expand All @@ -1376,6 +1506,7 @@ def run_corrector_step(
at_first_substep=at_first_substep,
r_nsubsteps=r_nsubsteps,
)
# EXCHANGE OVERLAP EXPERIMENT END

self._vertically_implicit_solver_at_corrector_step(
next_w=prognostic_states.next.w,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,11 +239,8 @@ def exchange(self, dim: gtx.Dimension, *fields: gtx.Field) -> MultiNodeResult:
the granule context where fields otherwise have length nproma.
"""
applied_patterns = [self._get_applied_pattern(dim, f) for f in fields]
assert hasattr(fields[0], "array_ns")
if hasattr(fields[0].array_ns, "cuda"):
# TODO(havogt): this is a workaround as ghex does not know that it should synchronize
# the GPU before the exchange. This is necessary to ensure that all data is ready for the exchange.
fields[0].array_ns.cuda.runtime.deviceSynchronize()
# With https://github.com/ghex-org/GHEX/pull/186, ghex will schedule/sync work on the default stream,
# otherwise we need an explicit device synchronize here.
handle = self._comm.exchange(applied_patterns)
log.debug(f"exchange for {len(fields)} fields of dimension ='{dim.value}' initiated.")
return MultiNodeResult(handle, applied_patterns)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ url = "https://test.pypi.org/simple/"

[tool.uv.sources]
dace = {git = "https://github.com/GridTools/dace", tag = "__gt4py-next-integration_2025_11_05"}
# ghex = {git = "https://github.com/ghex-org/GHEX.git", branch = "master"}
ghex = {git = "https://github.com/msimberg/GHEX.git", branch = "async-mpi"}
# gt4py = {git = "https://github.com/GridTools/gt4py", branch = "main"}
# gt4py = {index = "test.pypi"}
icon4py-atmosphere-advection = {workspace = true}
Expand Down
7 changes: 3 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.