Skip to content

Commit 3819545

Browse files
Fix batch to_dataframe() function. Add batch merge().
1 parent d8bf784 commit 3819545

2 files changed

Lines changed: 137 additions & 75 deletions

File tree

insardev/insardev/BatchCore.py

Lines changed: 136 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -3019,6 +3019,34 @@ def rename_vars(self, **kw):
30193019
def rename(self, **kw):
    """Rename variables in every per-burst dataset.

    Delegates to ``xarray.Dataset.rename`` on each contained dataset and
    wraps the results in a new batch of the same concrete type, so
    specialized subclasses keep their behavior.
    """
    renamed = {}
    for key, ds in self.items():
        renamed[key] = ds.rename(**kw)
    return type(self)(renamed)
30213021

3022+
def merge(self, other: 'BatchCore') -> 'Batch':
    """Merge variables from another Batch into this one (per burst xr.merge).

    Both Batches must share the same burst keys and compatible coordinates.
    Use rename() first to avoid variable name conflicts.
    Always returns Batch (plain real-valued) since mixed types should not
    support specialized operations (e.g., wrapped phase arithmetic).

    Parameters
    ----------
    other : BatchCore
        Batch with additional variables to merge.

    Returns
    -------
    Batch
        Merged batch containing variables from both.

    Raises
    ------
    ValueError
        If the two Batches do not contain exactly the same burst keys.

    Examples
    --------
    >>> corr_mean = mcorr.mean('pair').rename(VV='VV_cor')
    >>> combined = corr_mean.merge(rmse_mm.rename(VV='VV_rmse_mm'))
    >>> df = combined.to_dataframe()
    """
    import xarray as xr
    from .Batch import Batch
    # The contract above requires identical burst keys; fail loudly instead
    # of silently dropping bursts present in only one of the two Batches.
    mismatched = set(self.keys()) ^ set(other.keys())
    if mismatched:
        raise ValueError(f'Batches have mismatched burst keys: {sorted(mismatched)}')
    return Batch({k: xr.merge([ds, other[k]]) for k, ds in self.items()})
3049+
30223050
def reindex(self, **kw):
    """Reindex every per-burst dataset.

    Applies ``xarray.Dataset.reindex`` per dataset and returns a new batch
    of the same concrete type as ``self``.
    """
    pairs = ((key, ds.reindex(**kw)) for key, ds in self.items())
    return type(self)(dict(pairs))
30243052

@@ -3049,6 +3077,9 @@ def _agg(self, name: str, dim=None, **kwargs):
30493077
out[key] = fn(dim=dim, **kwargs)
30503078
else:
30513079
out[key] = fn(**kwargs)
3080+
# Preserve attrs (xarray aggregations drop them by default)
3081+
if hasattr(obj, 'attrs') and hasattr(out[key], 'attrs'):
3082+
out[key].attrs = obj.attrs
30523083

30533084
# filter out collapsed dimensions
30543085
sample = next(iter(out.values()), None)
@@ -3150,7 +3181,7 @@ def rmse(self, solution, weight=None):
31503181
else:
31513182
rmse_val = np.sqrt(err_sq.mean('pair'))
31523183
rmse_vars[var] = rmse_val.astype('float32')
3153-
out[key] = xr.Dataset(rmse_vars)
3184+
out[key] = xr.Dataset(rmse_vars, attrs=self[key].attrs)
31543185

31553186
return Batch(out)
31563187

@@ -3655,99 +3686,130 @@ def to_dataframe(self,
36553686
if not self:
36563687
return pd.DataFrame()
36573688

3658-
# Detect CRS from data if auto
3659-
if crs is not None and isinstance(crs, str) and crs == 'auto':
3660-
sample = next(iter(self.values()))
3661-
crs = sample.attrs.get('crs', 4326)
3662-
3663-
# Detect polarizations
3689+
# Detect native CRS from data
36643690
sample = next(iter(self.values()))
3665-
polarizations = [pol for pol in ['VV', 'VH', 'HH', 'HV'] if pol in sample.data_vars]
3666-
3667-
# Detect dimension: 'date' for BatchComplex, 'pair' for others
3668-
dim = 'date' if 'date' in sample.dims else 'pair'
3691+
native_crs = self.crs
3692+
if native_crs is None:
3693+
raise ValueError('Batch has no CRS. Check the processing pipeline that produced this Batch.')
3694+
if crs is not None and isinstance(crs, str) and crs == 'auto':
3695+
crs = native_crs
3696+
3697+
# Detect spatial data variables (skip 1D/0D vars like converted attributes)
3698+
spatial_vars = [v for v in sample.data_vars if sample[v].ndim >= 2]
3699+
ndims = {sample[v].ndim for v in spatial_vars}
3700+
if len(ndims) > 1:
3701+
raise ValueError(f'Mixed 2D and 3D variables not supported: {{{", ".join(f"{v}: {sample[v].ndim}D" for v in spatial_vars)}}}')
3702+
3703+
# Detect dimension: 'date' for BatchComplex, 'pair' for others, None for spatial-only
3704+
if 'date' in sample.dims:
3705+
dim = 'date'
3706+
elif 'pair' in sample.dims:
3707+
dim = 'pair'
3708+
else:
3709+
dim = None
36693710

36703711
# Define the attribute order matching Stack.to_dataframe
3671-
# Order: fullBurstID, burst, startTime, polarization, flightDirection, pathNumber, subswath, mission, beamModeType, BPR, geometry
3672-
attr_order = ['fullBurstID', 'burst', 'startTime', 'polarization', 'flightDirection',
3673-
'pathNumber', 'subswath', 'mission', 'beamModeType', 'BPR', 'geometry']
3712+
attr_order = ['fullBurstID', 'burst', 'startTime', 'polarization', 'flightDirection',
3713+
'pathNumber', 'subswath', 'mission', 'beamModeType', 'BPR']
36743714

3675-
# Make attributes dataframe from data
3676-
processed_attrs = []
3715+
# Spatial-only data (e.g., RMSE, elevation): one row per pixel with data values
3716+
if dim is None:
3717+
frames = []
3718+
for key, ds in self.items():
3719+
# Get spatial data variables (skip 1D/0D vars)
3720+
spatial_vars = [v for v in ds.data_vars if ds[v].ndim >= 2]
3721+
if not spatial_vars:
3722+
continue
3723+
df_burst = ds[spatial_vars].to_dataframe().reset_index()
3724+
# Drop all-NaN rows
3725+
df_burst = df_burst.dropna(subset=spatial_vars, how='all')
3726+
# Add burst metadata from attrs
3727+
for attr_name in attr_order:
3728+
if attr_name in ds.attrs:
3729+
value = ds.attrs[attr_name]
3730+
if attr_name == 'startTime':
3731+
value = pd.Timestamp(value)
3732+
df_burst[attr_name] = value
3733+
frames.append(df_burst)
3734+
3735+
if not frames:
3736+
return pd.DataFrame()
3737+
df = pd.concat(frames, ignore_index=True)
3738+
3739+
# Create Point geometry in data's native CRS
3740+
df['geometry'] = gpd.points_from_xy(df['x'], df['y'])
3741+
df = gpd.GeoDataFrame(df, crs=native_crs)
3742+
3743+
# Reorder: burst metadata first, then data, then geometry
3744+
meta_cols = [c for c in attr_order if c in df.columns]
3745+
data_cols = ['y', 'x'] + spatial_vars
3746+
ordered = meta_cols + data_cols + ['geometry']
3747+
df = df[[c for c in ordered if c in df.columns]]
3748+
3749+
if 'fullBurstID' in df.columns and 'burst' in df.columns:
3750+
df = df.sort_values(by=['fullBurstID', 'burst']).set_index(['fullBurstID', 'burst'])
3751+
3752+
if crs is not None and crs != native_crs:
3753+
df = df.to_crs(crs)
3754+
return df
3755+
3756+
# Date/pair-based data: one row per pixel per date/pair with data values
3757+
spatial_vars = [v for v in sample.data_vars
3758+
if 'y' in sample[v].dims and 'x' in sample[v].dims]
3759+
frames = []
36773760
for key, ds in self.items():
36783761
for idx in range(ds.dims[dim]):
3679-
processed_attr = {}
3680-
3681-
# Get ref/rep for pair dimension
3762+
# Extract 2D slice for this date/pair
3763+
ds_slice = ds[spatial_vars].isel({dim: idx})
3764+
df_slice = ds_slice.to_dataframe().reset_index()
3765+
# Drop all-NaN rows
3766+
df_slice = df_slice.dropna(subset=spatial_vars, how='all')
3767+
3768+
# Add date/pair info
36823769
if dim == 'pair':
36833770
if 'ref' in ds.coords:
3684-
processed_attr['ref'] = pd.Timestamp(ds['ref'].values[idx])
3771+
df_slice['ref'] = pd.Timestamp(ds['ref'].values[idx])
36853772
if 'rep' in ds.coords:
3686-
processed_attr['rep'] = pd.Timestamp(ds['rep'].values[idx])
3687-
else:
3688-
processed_attr['date'] = pd.Timestamp(ds[dim].values[idx])
3689-
3690-
# Extract attributes from ds.attrs
3773+
df_slice['rep'] = pd.Timestamp(ds['rep'].values[idx])
3774+
elif dim == 'date':
3775+
df_slice['date'] = pd.Timestamp(ds[dim].values[idx])
3776+
3777+
# Add burst metadata from attrs
36913778
for attr_name in attr_order:
3692-
if attr_name in ds.attrs and attr_name not in processed_attr:
3779+
if attr_name in ds.attrs:
36933780
value = ds.attrs[attr_name]
3694-
if attr_name == 'geometry' and isinstance(value, str):
3695-
processed_attr[attr_name] = wkt.loads(value)
3696-
elif attr_name == 'startTime':
3697-
processed_attr[attr_name] = pd.Timestamp(value)
3698-
else:
3699-
processed_attr[attr_name] = value
3700-
3701-
processed_attrs.append(processed_attr)
3781+
if attr_name == 'startTime':
3782+
value = pd.Timestamp(value)
3783+
df_slice[attr_name] = value
37023784

3703-
if not processed_attrs:
3704-
return pd.DataFrame()
3785+
frames.append(df_slice)
37053786

3706-
# Check if we have geometry column for GeoDataFrame
3707-
has_geometry = 'geometry' in processed_attrs[0]
3708-
3709-
if has_geometry:
3710-
df = gpd.GeoDataFrame(processed_attrs, crs=4326)
3711-
else:
3712-
df = pd.DataFrame(processed_attrs)
3787+
if not frames:
3788+
return pd.DataFrame()
3789+
df = pd.concat(frames, ignore_index=True)
37133790

3714-
# Add polarization info if not already present
3715-
if 'polarization' not in df.columns and polarizations:
3716-
df['polarization'] = ','.join(map(str, polarizations))
3791+
# Create Point geometry in data's native CRS
3792+
df['geometry'] = gpd.points_from_xy(df['x'], df['y'])
3793+
df = gpd.GeoDataFrame(df, crs=native_crs)
37173794

3718-
# Round BPR for readability
3719-
if 'BPR' in df.columns:
3720-
df['BPR'] = df['BPR'].round(1)
3795+
# Reorder columns: burst metadata, date/pair info, coordinates, data, geometry
3796+
if dim == 'pair':
3797+
time_cols = ['ref', 'rep']
3798+
else:
3799+
time_cols = ['date']
3800+
meta_cols = [c for c in attr_order if c in df.columns]
3801+
time_cols = [c for c in time_cols if c in df.columns]
3802+
data_cols = ['y', 'x'] + spatial_vars
3803+
ordered = meta_cols + time_cols + data_cols + ['geometry']
3804+
df = df[[c for c in ordered if c in df.columns]]
37213805

3722-
# Reorder columns to match Stack.to_dataframe format
3723-
# For pair data: fullBurstID, burst (index), then ref, rep, then rest
37243806
if 'fullBurstID' in df.columns and 'burst' in df.columns:
3725-
# Build column order
3726-
if dim == 'pair':
3727-
# ref, rep first after index, then startTime, polarization, etc.
3728-
first_cols = ['fullBurstID', 'burst', 'ref', 'rep']
3729-
else:
3730-
first_cols = ['fullBurstID', 'burst', 'date']
3731-
3732-
# Rest of columns in attr_order, excluding index columns and ref/rep/date
3733-
other_cols = [c for c in attr_order if c not in first_cols and c in df.columns]
3734-
3735-
# Reorder
3736-
ordered_cols = [c for c in first_cols if c in df.columns] + other_cols
3737-
df = df[ordered_cols]
3738-
3739-
# Sort and set index
37403807
df = df.sort_values(by=['fullBurstID', 'burst']).set_index(['fullBurstID', 'burst'])
3741-
3742-
# Move geometry to end if present
3743-
if has_geometry and 'geometry' in df.columns:
3744-
df = df.loc[:, df.columns.drop("geometry").tolist() + ["geometry"]]
37453808

3746-
# Convert CRS if requested and we have a GeoDataFrame
3747-
if has_geometry and crs is not None:
3748-
return df.to_crs(crs)
3809+
if crs is not None and crs != native_crs:
3810+
df = df.to_crs(crs)
37493811
return df
3750-
3812+
37513813
@property
37523814
def spacing(self) -> tuple[float, float]:
37533815
"""Return the (y, x) grid spacing."""

insardev/insardev/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# See the LICENSE file in the insardev directory for license terms.
99
# Professional use requires an active per-seat subscription at: https://patreon.com/pechnikov
1010
# ----------------------------------------------------------------------------
11-
__version__ = '2026.3.8.post2'
11+
__version__ = '2026.3.8.post3'
1212

1313
# processing functions
1414
from .Stack import Stack

0 commit comments

Comments
 (0)