Skip to content

Commit d8bf784

Browse files
Make align() lazy
1 parent 7ce021b commit d8bf784

2 files changed

Lines changed: 122 additions & 95 deletions

File tree

insardev/insardev/BatchCore.py

Lines changed: 121 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,10 @@ def _dissolve_pol_for_dask(da_current, das_others, wrap, extend, weight):
116116
w_other = (1.0 - weight) / n_others if n_others > 0 else 0.0
117117

118118
# Reindex das_others to match da_current coordinates
119-
# Use interp for floating point coordinate matching instead of reindex
119+
# Grids are consistent with exactly matched coordinates in overlap areas
120120
das_reindexed = []
121121
for d in das_others:
122-
# Use interp with nearest to avoid issues with floating point coordinate matching
123-
das_reindexed.append(d.interp(y=ys, x=xs, method='nearest', kwargs={'fill_value': np.nan}))
122+
das_reindexed.append(d.reindex(y=ys, x=xs, fill_value=np.nan))
124123

125124
current_vals = da_current.values
126125
current_valid = np.isfinite(current_vals)
@@ -129,12 +128,12 @@ def _dissolve_pol_for_dask(da_current, das_others, wrap, extend, weight):
129128
warnings.simplefilter('ignore', RuntimeWarning)
130129

131130
if wrap:
132-
weighted_sum = np.where(current_valid, np.exp(1j * current_vals) * w_current, 0.0)
131+
weighted_sum = np.where(current_valid, np.exp(1j * current_vals).astype(np.complex64) * w_current, np.complex64(0))
133132
weight_sum = np.where(current_valid, w_current, 0.0)
134133
for d in das_reindexed:
135134
vals = d.values
136135
valid = np.isfinite(vals)
137-
weighted_sum += np.where(valid, np.exp(1j * vals) * w_other, 0.0)
136+
weighted_sum += np.where(valid, np.exp(1j * vals).astype(np.complex64) * w_other, np.complex64(0))
138137
weight_sum += np.where(valid, w_other, 0.0)
139138
valid_weights = weight_sum > 0
140139
normalized = np.divide(weighted_sum, weight_sum, out=np.zeros_like(weighted_sum), where=valid_weights)
@@ -164,6 +163,40 @@ def _dissolve_pol_3d_for_dask(da_slice, das_others_slice, wrap, extend, weight):
164163
return _dissolve_pol_for_dask(da_slice, das_others_slice, wrap, extend, weight)[np.newaxis, ...]
165164

166165

166+
def _dissolve_raw_for_dask(current_arr, current_y, current_x,
167+
others_arrs, others_ys, others_xs,
168+
wrap, extend, weight):
169+
"""
170+
Dissolve using raw numpy arrays + coordinates.
171+
172+
Receives raw arrays (dask resolves them to numpy before calling) and
173+
numpy coordinate arrays. Reconstructs minimal xarray DataArrays for
174+
the reindex-based coordinate matching, then delegates to _dissolve_pol_for_dask.
175+
176+
For 3D arrays (pair, y, x), iterates over first dim.
177+
"""
178+
import xarray as xr
179+
180+
if current_arr.ndim > 2:
181+
n_stack = current_arr.shape[0]
182+
slices = []
183+
for i in range(n_stack):
184+
da_c = xr.DataArray(current_arr[i], dims=['y', 'x'],
185+
coords={'y': current_y, 'x': current_x})
186+
das_o = [xr.DataArray(arr[i], dims=['y', 'x'],
187+
coords={'y': y, 'x': x})
188+
for arr, y, x in zip(others_arrs, others_ys, others_xs)]
189+
slices.append(_dissolve_pol_for_dask(da_c, das_o, wrap, extend, weight))
190+
return np.stack(slices, axis=0)
191+
else:
192+
da_c = xr.DataArray(current_arr, dims=['y', 'x'],
193+
coords={'y': current_y, 'x': current_x})
194+
das_o = [xr.DataArray(arr, dims=['y', 'x'],
195+
coords={'y': y, 'x': x})
196+
for arr, y, x in zip(others_arrs, others_ys, others_xs)]
197+
return _dissolve_pol_for_dask(da_c, das_o, wrap, extend, weight)
198+
199+
167200
def _apply_gaussian_for_dask(block, weight_block, sigmas, threshold, device, pixel_sizes, out_dtype):
168201
"""
169202
Module-level function for gaussian blockwise operation (DEPRECATED - use _apply_gaussian_2d_for_dask).
@@ -5686,59 +5719,60 @@ def process_phase_diff(diff_np, x_coords, id1, id2, pair_idx):
56865719
if degree == 1 and cross_subswath_skipped > 0:
56875720
print(f' (skipped {cross_subswath_skipped} same-path cross-subswath pairs for ramp estimation)', flush=True)
56885721

5689-
# Build delayed overlap statistics using to_delayed pattern.
5690-
# Each overlap is pre-selected via .sel() so only overlap chunks
5691-
# enter the dask graph — NOT the full burst pipeline.
5722+
# Pass raw burst data arrays (not pre-computed diffs) to a single
5723+
# delayed task. This creates N_bursts graph dependencies instead of
5724+
# N_overlaps*3 layers from xarray diff operations, keeping the graph
5725+
# minimal for downstream dissolve().
56925726
import dask.array as _da
56935727

5694-
delayed_stats = []
5695-
for id1, id2 in all_overlap_pairs:
5696-
e1, e2 = extents[id1], extents[id2]
5697-
# Overlap bounding box
5698-
y_min, y_max = max(e1[0], e2[0]), min(e1[1], e2[1])
5699-
x_min, x_max = max(e1[2], e2[2]), min(e1[3], e2[3])
5700-
y_slice = slice(y_max, y_min) if _y_descending else slice(y_min, y_max)
5701-
x_slice = slice(x_min, x_max)
5702-
5703-
# Select only the overlap region — restricts dask graph to overlap chunks
5704-
i1 = self[id1][polarization].sel(y=y_slice, x=x_slice)
5705-
i2 = self[id2][polarization].sel(y=y_slice, x=x_slice)
5706-
5707-
for pair_idx in range(n_pairs):
5708-
i1_p = i1.isel(pair=pair_idx) if has_pair_dim else i1
5709-
i2_p = i2.isel(pair=pair_idx) if has_pair_dim else i2
5710-
# Lazy difference — xarray aligns to common overlap coordinates
5711-
diff = i2_p - i1_p
5712-
x_coords = diff.coords['x'].values
5713-
# Convert to delayed numpy via to_delayed
5714-
diff_delayed = diff.data.rechunk(-1, -1).to_delayed().ravel()[0]
5715-
# Delayed numpy processing
5716-
stat = dask.delayed(process_phase_diff)(
5717-
diff_delayed, x_coords, id1, id2, pair_idx
5718-
)
5719-
delayed_stats.append(stat)
5728+
# Collect burst data + coordinates (coordinates are numpy, not dask)
5729+
burst_data = [self[bid][polarization].data for bid in ids]
5730+
burst_y = [self[bid][polarization].y.values for bid in ids]
5731+
burst_x = [self[bid][polarization].x.values for bid in ids]
57205732

57215733
if debug:
5722-
print(f'Building lazy graph for {len(delayed_stats)} overlap statistics...', flush=True)
5734+
print(f'Building lazy graph for {len(all_overlap_pairs)} overlap pairs, {len(ids)} bursts...', flush=True)
57235735

5724-
# Solve function — runs inside dask.delayed, receives concrete stats.
5725-
# Captures only small metadata (ids, id_to_idx, x_centers, etc.).
5726-
def _fit_solve(stats_list):
5736+
# Single delayed task: receives resolved burst numpy arrays,
5737+
# computes overlaps + diffs internally, then solves.
5738+
def _fit_all(*burst_data_arrays):
5739+
import xarray as xr
57275740
from scipy import sparse as _sparse
57285741
from scipy.sparse.linalg import lsqr as _lsqr
57295742
from scipy.sparse.csgraph import connected_components as _cc
57305743

5731-
valid_stats = [s for s in stats_list if s is not None]
5732-
5744+
# Compute overlap diffs and process statistics
57335745
_pbp = {p: [] for p in range(n_pairs)}
5734-
for st in valid_stats:
5735-
_id1, _id2, _pidx, _off, _rv, _xc, _nu = st
5736-
_w = np.sqrt(_nu)
5737-
if degree == 0:
5738-
_pbp[_pidx].append((_id1, _id2, _off, _w))
5739-
else:
5740-
if _rv is not None:
5741-
_pbp[_pidx].append((_id1, _id2, _off, _rv, _xc, _w))
5746+
for id1, id2 in all_overlap_pairs:
5747+
i1_idx = id_to_idx[id1]
5748+
i2_idx = id_to_idx[id2]
5749+
d1 = np.asarray(burst_data_arrays[i1_idx])
5750+
d2 = np.asarray(burst_data_arrays[i2_idx])
5751+
5752+
for pair_idx in range(n_pairs):
5753+
d1_p = d1[pair_idx] if has_pair_dim else d1
5754+
d2_p = d2[pair_idx] if has_pair_dim else d2
5755+
5756+
# Build xarray DataArrays for coordinate-aware overlap
5757+
da1 = xr.DataArray(d1_p, dims=['y', 'x'],
5758+
coords={'y': burst_y[i1_idx],
5759+
'x': burst_x[i1_idx]})
5760+
da2 = xr.DataArray(d2_p, dims=['y', 'x'],
5761+
coords={'y': burst_y[i2_idx],
5762+
'x': burst_x[i2_idx]})
5763+
diff = da2 - da1
5764+
stat = process_phase_diff(diff.values,
5765+
diff.coords['x'].values,
5766+
id1, id2, pair_idx)
5767+
if stat is None:
5768+
continue
5769+
_id1s, _id2s, _pidxs, _off, _rv, _xcent, _nu = stat
5770+
_w = np.sqrt(_nu)
5771+
if degree == 0:
5772+
_pbp[_pidxs].append((_id1s, _id2s, _off, _w))
5773+
else:
5774+
if _rv is not None:
5775+
_pbp[_pidxs].append((_id1s, _id2s, _off, _rv, _xcent, _w))
57425776

57435777
def _solve_one(pidx):
57445778
pairs = _pbp[pidx]
@@ -5855,8 +5889,9 @@ def _solve_one(pidx):
58555889

58565890
return {'offsets': offsets, 'residuals': residuals}
58575891

5858-
# Delayed solve — fully lazy, NO dask.compute()
5859-
solve_result = dask.delayed(_fit_solve)(delayed_stats)
5892+
# Single delayed call — dask resolves burst data arrays before calling.
5893+
# Graph has ~N_bursts layers (not ~N_overlaps*3 from xarray diffs).
5894+
solve_result = dask.delayed(_fit_all, pure=True)(*burst_data)
58605895

58615896
# Extract per-burst dask 0-d arrays from delayed solve result
58625897
offsets_part = solve_result['offsets']
@@ -6155,9 +6190,11 @@ def dissolve(self, extend: bool = False, weight: float = None, debug: bool = Fal
61556190
total_overlaps = sum(len(v) for v in overlapping_map.values())
61566191
print(f'dissolve: STRtree found {total_overlaps} burst overlaps', flush=True)
61576192

6158-
# Build output - per burst, replace pol variables with lazy arrays
6159-
# Note: dissolve functions are defined at module level (_dissolve_pol_for_dask, _dissolve_pol_3d_for_dask)
6160-
# to avoid dask serialization issues with nested function closures in distributed environments
6193+
# Build output — one dask.delayed task per burst per pol.
6194+
# Pass raw dask arrays (not xarray DataArrays) to avoid expensive
6195+
# xarray __dask_graph__() calls during dask.delayed graph construction.
6196+
# The _dissolve_raw_for_dask function receives numpy arrays (dask resolves
6197+
# them) and reconstructs minimal xarray DataArrays for coord matching.
61616198
output = {}
61626199
for burst_idx, bid in enumerate(burst_ids):
61636200
overlapping_indices = overlapping_map[burst_idx]
@@ -6169,52 +6206,42 @@ def dissolve(self, extend: bool = False, weight: float = None, debug: bool = Fal
61696206

61706207
ds_others = [self[burst_ids[idx]] for idx in overlapping_indices]
61716208

6172-
# Copy dataset and replace each pol with lazy dissolved version
61736209
new_ds = ds_current.copy()
61746210
for pol in polarizations:
61756211
da_current = ds_current[pol]
61766212
das_others = [ds[pol] for ds in ds_others]
61776213

6178-
# Check if 3D (has stack dimension like 'pair')
6179-
if len(da_current.dims) > 2:
6180-
# Ensure first dim chunked as 1 for per-slice processing
6181-
if hasattr(da_current.data, 'chunks') and da_current.data.chunks[0][0] != 1:
6182-
da_current = da_current.chunk({da_current.dims[0]: 1})
6183-
das_others = [d.chunk({d.dims[0]: 1}) for d in das_others]
6184-
stackvar = da_current.dims[0]
6185-
n_stack = da_current.sizes[stackvar]
6186-
shape_2d = da_current.shape[1:]
6187-
6188-
# Create separate delayed array for each stack element for parallelization
6189-
# Use module-level function to avoid dask serialization issues
6190-
delayed_slices = []
6191-
for i in range(n_stack):
6192-
da_slice = da_current.isel({stackvar: i})
6193-
das_others_slice = tuple(d.isel({stackvar: i}) for d in das_others)
6194-
# Create 3D delayed array with shape (1, y, x) and chunks (1, -1, -1)
6195-
# Use pure=True for deterministic behavior in distributed environments
6196-
delayed_slice = da.from_delayed(
6197-
dask.delayed(_dissolve_pol_3d_for_dask, pure=True)(
6198-
da_slice, das_others_slice, wrap, extend, weight
6199-
),
6200-
shape=(1,) + shape_2d,
6201-
dtype=da_current.dtype
6202-
)
6203-
delayed_slices.append(delayed_slice)
6204-
6205-
# Concatenate along axis 0 - each slice is already (1, y, x)
6206-
delayed_array = da.concatenate(delayed_slices, axis=0)
6207-
else:
6208-
# 2D case - single delayed array
6209-
# Use module-level function to avoid dask serialization issues
6210-
# Use pure=True for deterministic behavior in distributed environments
6211-
delayed_array = da.from_delayed(
6212-
dask.delayed(_dissolve_pol_for_dask, pure=True)(
6213-
da_current, tuple(das_others), wrap, extend, weight
6214-
),
6215-
shape=da_current.shape,
6216-
dtype=da_current.dtype
6217-
)
6214+
# Extract raw arrays and numpy coordinates.
6215+
# Raw dask arrays have O(1) __dask_graph__() (direct attribute),
6216+
# vs xarray DataArrays which create temp Dataset each call.
6217+
current_arr = da_current.data
6218+
current_y = da_current.y.values
6219+
current_x = da_current.x.values
6220+
6221+
if not isinstance(current_arr, da.Array):
6222+
current_arr = da.from_array(current_arr, chunks=current_arr.shape)
6223+
6224+
others_arrs = []
6225+
others_ys = []
6226+
others_xs = []
6227+
for d in das_others:
6228+
arr = d.data
6229+
if not isinstance(arr, da.Array):
6230+
arr = da.from_array(arr, chunks=arr.shape)
6231+
others_arrs.append(arr)
6232+
others_ys.append(d.y.values)
6233+
others_xs.append(d.x.values)
6234+
6235+
delayed_result = dask.delayed(_dissolve_raw_for_dask, pure=True)(
6236+
current_arr, current_y, current_x,
6237+
others_arrs, others_ys, others_xs,
6238+
wrap, extend, weight
6239+
)
6240+
delayed_array = da.from_delayed(
6241+
delayed_result,
6242+
shape=da_current.shape,
6243+
dtype=da_current.dtype
6244+
)
62186245
new_ds[pol] = da_current.copy(data=delayed_array)
62196246

62206247
output[bid] = new_ds

insardev/insardev/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# See the LICENSE file in the insardev directory for license terms.
99
# Professional use requires an active per-seat subscription at: https://patreon.com/pechnikov
1010
# ----------------------------------------------------------------------------
11-
__version__ = '2026.3.8.post1'
11+
__version__ = '2026.3.8.post2'
1212

1313
# processing functions
1414
from .Stack import Stack

0 commit comments

Comments (0)