@@ -116,11 +116,10 @@ def _dissolve_pol_for_dask(da_current, das_others, wrap, extend, weight):
116116 w_other = (1.0 - weight ) / n_others if n_others > 0 else 0.0
117117
118118 # Reindex das_others to match da_current coordinates
119- # Use interp for floating point coordinate matching instead of reindex
119+ # Grids are consistent with exactly matched coordinates in overlap areas
120120 das_reindexed = []
121121 for d in das_others :
122- # Use interp with nearest to avoid issues with floating point coordinate matching
123- das_reindexed .append (d .interp (y = ys , x = xs , method = 'nearest' , kwargs = {'fill_value' : np .nan }))
122+ das_reindexed .append (d .reindex (y = ys , x = xs , fill_value = np .nan ))
124123
125124 current_vals = da_current .values
126125 current_valid = np .isfinite (current_vals )
@@ -129,12 +128,12 @@ def _dissolve_pol_for_dask(da_current, das_others, wrap, extend, weight):
129128 warnings .simplefilter ('ignore' , RuntimeWarning )
130129
131130 if wrap :
132- weighted_sum = np .where (current_valid , np .exp (1j * current_vals ) * w_current , 0.0 )
131+ weighted_sum = np .where (current_valid , np .exp (1j * current_vals ). astype ( np . complex64 ) * w_current , np . complex64 ( 0 ) )
133132 weight_sum = np .where (current_valid , w_current , 0.0 )
134133 for d in das_reindexed :
135134 vals = d .values
136135 valid = np .isfinite (vals )
137- weighted_sum += np .where (valid , np .exp (1j * vals ) * w_other , 0.0 )
136+ weighted_sum += np .where (valid , np .exp (1j * vals ). astype ( np . complex64 ) * w_other , np . complex64 ( 0 ) )
138137 weight_sum += np .where (valid , w_other , 0.0 )
139138 valid_weights = weight_sum > 0
140139 normalized = np .divide (weighted_sum , weight_sum , out = np .zeros_like (weighted_sum ), where = valid_weights )
@@ -164,6 +163,40 @@ def _dissolve_pol_3d_for_dask(da_slice, das_others_slice, wrap, extend, weight):
164163 return _dissolve_pol_for_dask (da_slice , das_others_slice , wrap , extend , weight )[np .newaxis , ...]
165164
166165
def _dissolve_raw_for_dask(current_arr, current_y, current_x,
                           others_arrs, others_ys, others_xs,
                           wrap, extend, weight):
    """
    Dissolve using raw numpy arrays + coordinate vectors.

    Receives raw arrays (dask resolves them to numpy before calling this)
    plus 1-D numpy coordinate arrays. Reconstructs minimal xarray
    DataArrays so that _dissolve_pol_for_dask can align the grids by
    coordinate (it reindexes the other bursts onto the current burst's
    y/x coordinates), then delegates the actual blending to it.

    Parameters
    ----------
    current_arr : numpy.ndarray
        2D (y, x) or 3D (stack, y, x) data of the current burst.
    current_y, current_x : numpy.ndarray
        1D coordinate arrays of the current burst.
    others_arrs : sequence of numpy.ndarray
        Overlapping bursts' data, same rank as current_arr.
    others_ys, others_xs : sequences of numpy.ndarray
        1D coordinate arrays matching others_arrs element-wise.
    wrap, extend, weight :
        Passed through unchanged to _dissolve_pol_for_dask.

    Returns
    -------
    numpy.ndarray
        Dissolved array with the same shape as current_arr.
    """
    # Local import: keeps the module importable without xarray at top level
    # and avoids serialization of the module in distributed workers.
    import xarray as xr

    def _as_da(arr2d, ys, xs):
        # Minimal DataArray — just dims + coords needed for alignment.
        return xr.DataArray(arr2d, dims=['y', 'x'], coords={'y': ys, 'x': xs})

    def _dissolve_2d(cur2d, others_2d):
        # Shared path for both the 2D case and each 3D stack slice.
        da_c = _as_da(cur2d, current_y, current_x)
        das_o = [_as_da(a, y, x)
                 for a, y, x in zip(others_2d, others_ys, others_xs)]
        return _dissolve_pol_for_dask(da_c, das_o, wrap, extend, weight)

    if current_arr.ndim > 2:
        # 3D (stack, y, x): dissolve each stack slice independently.
        slices = [_dissolve_2d(current_arr[i], [a[i] for a in others_arrs])
                  for i in range(current_arr.shape[0])]
        return np.stack(slices, axis=0)
    return _dissolve_2d(current_arr, others_arrs)
198+
199+
167200def _apply_gaussian_for_dask (block , weight_block , sigmas , threshold , device , pixel_sizes , out_dtype ):
168201 """
169202 Module-level function for gaussian blockwise operation (DEPRECATED - use _apply_gaussian_2d_for_dask).
@@ -5686,59 +5719,60 @@ def process_phase_diff(diff_np, x_coords, id1, id2, pair_idx):
56865719 if degree == 1 and cross_subswath_skipped > 0 :
56875720 print (f' (skipped { cross_subswath_skipped } same-path cross-subswath pairs for ramp estimation)' , flush = True )
56885721
5689- # Build delayed overlap statistics using to_delayed pattern.
5690- # Each overlap is pre-selected via .sel() so only overlap chunks
5691- # enter the dask graph — NOT the full burst pipeline.
5722+ # Pass raw burst data arrays (not pre-computed diffs) to a single
5723+ # delayed task. This creates N_bursts graph dependencies instead of
5724+ # N_overlaps*3 layers from xarray diff operations, keeping the graph
5725+ # minimal for downstream dissolve().
56925726 import dask .array as _da
56935727
5694- delayed_stats = []
5695- for id1 , id2 in all_overlap_pairs :
5696- e1 , e2 = extents [id1 ], extents [id2 ]
5697- # Overlap bounding box
5698- y_min , y_max = max (e1 [0 ], e2 [0 ]), min (e1 [1 ], e2 [1 ])
5699- x_min , x_max = max (e1 [2 ], e2 [2 ]), min (e1 [3 ], e2 [3 ])
5700- y_slice = slice (y_max , y_min ) if _y_descending else slice (y_min , y_max )
5701- x_slice = slice (x_min , x_max )
5702-
5703- # Select only the overlap region — restricts dask graph to overlap chunks
5704- i1 = self [id1 ][polarization ].sel (y = y_slice , x = x_slice )
5705- i2 = self [id2 ][polarization ].sel (y = y_slice , x = x_slice )
5706-
5707- for pair_idx in range (n_pairs ):
5708- i1_p = i1 .isel (pair = pair_idx ) if has_pair_dim else i1
5709- i2_p = i2 .isel (pair = pair_idx ) if has_pair_dim else i2
5710- # Lazy difference — xarray aligns to common overlap coordinates
5711- diff = i2_p - i1_p
5712- x_coords = diff .coords ['x' ].values
5713- # Convert to delayed numpy via to_delayed
5714- diff_delayed = diff .data .rechunk (- 1 , - 1 ).to_delayed ().ravel ()[0 ]
5715- # Delayed numpy processing
5716- stat = dask .delayed (process_phase_diff )(
5717- diff_delayed , x_coords , id1 , id2 , pair_idx
5718- )
5719- delayed_stats .append (stat )
5728+ # Collect burst data + coordinates (coordinates are numpy, not dask)
5729+ burst_data = [self [bid ][polarization ].data for bid in ids ]
5730+ burst_y = [self [bid ][polarization ].y .values for bid in ids ]
5731+ burst_x = [self [bid ][polarization ].x .values for bid in ids ]
57205732
57215733 if debug :
5722- print (f'Building lazy graph for { len (delayed_stats )} overlap statistics ...' , flush = True )
5734+ print (f'Building lazy graph for { len (all_overlap_pairs )} overlap pairs, { len ( ids ) } bursts ...' , flush = True )
57235735
5724- # Solve function — runs inside dask.delayed, receives concrete stats.
5725- # Captures only small metadata (ids, id_to_idx, x_centers, etc.).
5726- def _fit_solve (stats_list ):
5736+ # Single delayed task: receives resolved burst numpy arrays,
5737+ # computes overlaps + diffs internally, then solves.
5738+ def _fit_all (* burst_data_arrays ):
5739+ import xarray as xr
57275740 from scipy import sparse as _sparse
57285741 from scipy .sparse .linalg import lsqr as _lsqr
57295742 from scipy .sparse .csgraph import connected_components as _cc
57305743
5731- valid_stats = [s for s in stats_list if s is not None ]
5732-
5744+ # Compute overlap diffs and process statistics
57335745 _pbp = {p : [] for p in range (n_pairs )}
5734- for st in valid_stats :
5735- _id1 , _id2 , _pidx , _off , _rv , _xc , _nu = st
5736- _w = np .sqrt (_nu )
5737- if degree == 0 :
5738- _pbp [_pidx ].append ((_id1 , _id2 , _off , _w ))
5739- else :
5740- if _rv is not None :
5741- _pbp [_pidx ].append ((_id1 , _id2 , _off , _rv , _xc , _w ))
5746+ for id1 , id2 in all_overlap_pairs :
5747+ i1_idx = id_to_idx [id1 ]
5748+ i2_idx = id_to_idx [id2 ]
5749+ d1 = np .asarray (burst_data_arrays [i1_idx ])
5750+ d2 = np .asarray (burst_data_arrays [i2_idx ])
5751+
5752+ for pair_idx in range (n_pairs ):
5753+ d1_p = d1 [pair_idx ] if has_pair_dim else d1
5754+ d2_p = d2 [pair_idx ] if has_pair_dim else d2
5755+
5756+ # Build xarray DataArrays for coordinate-aware overlap
5757+ da1 = xr .DataArray (d1_p , dims = ['y' , 'x' ],
5758+ coords = {'y' : burst_y [i1_idx ],
5759+ 'x' : burst_x [i1_idx ]})
5760+ da2 = xr .DataArray (d2_p , dims = ['y' , 'x' ],
5761+ coords = {'y' : burst_y [i2_idx ],
5762+ 'x' : burst_x [i2_idx ]})
5763+ diff = da2 - da1
5764+ stat = process_phase_diff (diff .values ,
5765+ diff .coords ['x' ].values ,
5766+ id1 , id2 , pair_idx )
5767+ if stat is None :
5768+ continue
5769+ _id1s , _id2s , _pidxs , _off , _rv , _xcent , _nu = stat
5770+ _w = np .sqrt (_nu )
5771+ if degree == 0 :
5772+ _pbp [_pidxs ].append ((_id1s , _id2s , _off , _w ))
5773+ else :
5774+ if _rv is not None :
5775+ _pbp [_pidxs ].append ((_id1s , _id2s , _off , _rv , _xcent , _w ))
57425776
57435777 def _solve_one (pidx ):
57445778 pairs = _pbp [pidx ]
@@ -5855,8 +5889,9 @@ def _solve_one(pidx):
58555889
58565890 return {'offsets' : offsets , 'residuals' : residuals }
58575891
5858- # Delayed solve — fully lazy, NO dask.compute()
5859- solve_result = dask .delayed (_fit_solve )(delayed_stats )
5892+ # Single delayed call — dask resolves burst data arrays before calling.
5893+ # Graph has ~N_bursts layers (not ~N_overlaps*3 from xarray diffs).
5894+ solve_result = dask .delayed (_fit_all , pure = True )(* burst_data )
58605895
58615896 # Extract per-burst dask 0-d arrays from delayed solve result
58625897 offsets_part = solve_result ['offsets' ]
@@ -6155,9 +6190,11 @@ def dissolve(self, extend: bool = False, weight: float = None, debug: bool = Fal
61556190 total_overlaps = sum (len (v ) for v in overlapping_map .values ())
61566191 print (f'dissolve: STRtree found { total_overlaps } burst overlaps' , flush = True )
61576192
6158- # Build output - per burst, replace pol variables with lazy arrays
6159- # Note: dissolve functions are defined at module level (_dissolve_pol_for_dask, _dissolve_pol_3d_for_dask)
6160- # to avoid dask serialization issues with nested function closures in distributed environments
6193+ # Build output — one dask.delayed task per burst per pol.
6194+ # Pass raw dask arrays (not xarray DataArrays) to avoid expensive
6195+ # xarray __dask_graph__() calls during dask.delayed graph construction.
6196+ # The _dissolve_raw_for_dask function receives numpy arrays (dask resolves
6197+ # them) and reconstructs minimal xarray DataArrays for coord matching.
61616198 output = {}
61626199 for burst_idx , bid in enumerate (burst_ids ):
61636200 overlapping_indices = overlapping_map [burst_idx ]
@@ -6169,52 +6206,42 @@ def dissolve(self, extend: bool = False, weight: float = None, debug: bool = Fal
61696206
61706207 ds_others = [self [burst_ids [idx ]] for idx in overlapping_indices ]
61716208
6172- # Copy dataset and replace each pol with lazy dissolved version
61736209 new_ds = ds_current .copy ()
61746210 for pol in polarizations :
61756211 da_current = ds_current [pol ]
61766212 das_others = [ds [pol ] for ds in ds_others ]
61776213
6178- # Check if 3D (has stack dimension like 'pair')
6179- if len (da_current .dims ) > 2 :
6180- # Ensure first dim chunked as 1 for per-slice processing
6181- if hasattr (da_current .data , 'chunks' ) and da_current .data .chunks [0 ][0 ] != 1 :
6182- da_current = da_current .chunk ({da_current .dims [0 ]: 1 })
6183- das_others = [d .chunk ({d .dims [0 ]: 1 }) for d in das_others ]
6184- stackvar = da_current .dims [0 ]
6185- n_stack = da_current .sizes [stackvar ]
6186- shape_2d = da_current .shape [1 :]
6187-
6188- # Create separate delayed array for each stack element for parallelization
6189- # Use module-level function to avoid dask serialization issues
6190- delayed_slices = []
6191- for i in range (n_stack ):
6192- da_slice = da_current .isel ({stackvar : i })
6193- das_others_slice = tuple (d .isel ({stackvar : i }) for d in das_others )
6194- # Create 3D delayed array with shape (1, y, x) and chunks (1, -1, -1)
6195- # Use pure=True for deterministic behavior in distributed environments
6196- delayed_slice = da .from_delayed (
6197- dask .delayed (_dissolve_pol_3d_for_dask , pure = True )(
6198- da_slice , das_others_slice , wrap , extend , weight
6199- ),
6200- shape = (1 ,) + shape_2d ,
6201- dtype = da_current .dtype
6202- )
6203- delayed_slices .append (delayed_slice )
6204-
6205- # Concatenate along axis 0 - each slice is already (1, y, x)
6206- delayed_array = da .concatenate (delayed_slices , axis = 0 )
6207- else :
6208- # 2D case - single delayed array
6209- # Use module-level function to avoid dask serialization issues
6210- # Use pure=True for deterministic behavior in distributed environments
6211- delayed_array = da .from_delayed (
6212- dask .delayed (_dissolve_pol_for_dask , pure = True )(
6213- da_current , tuple (das_others ), wrap , extend , weight
6214- ),
6215- shape = da_current .shape ,
6216- dtype = da_current .dtype
6217- )
6214+ # Extract raw arrays and numpy coordinates.
6215+ # Raw dask arrays have O(1) __dask_graph__() (direct attribute),
6216+ # vs xarray DataArrays which create temp Dataset each call.
6217+ current_arr = da_current .data
6218+ current_y = da_current .y .values
6219+ current_x = da_current .x .values
6220+
6221+ if not isinstance (current_arr , da .Array ):
6222+ current_arr = da .from_array (current_arr , chunks = current_arr .shape )
6223+
6224+ others_arrs = []
6225+ others_ys = []
6226+ others_xs = []
6227+ for d in das_others :
6228+ arr = d .data
6229+ if not isinstance (arr , da .Array ):
6230+ arr = da .from_array (arr , chunks = arr .shape )
6231+ others_arrs .append (arr )
6232+ others_ys .append (d .y .values )
6233+ others_xs .append (d .x .values )
6234+
6235+ delayed_result = dask .delayed (_dissolve_raw_for_dask , pure = True )(
6236+ current_arr , current_y , current_x ,
6237+ others_arrs , others_ys , others_xs ,
6238+ wrap , extend , weight
6239+ )
6240+ delayed_array = da .from_delayed (
6241+ delayed_result ,
6242+ shape = da_current .shape ,
6243+ dtype = da_current .dtype
6244+ )
62186245 new_ds [pol ] = da_current .copy (data = delayed_array )
62196246
62206247 output [bid ] = new_ds
0 commit comments