From e8e95aa06672a4fe96dedde32def31a2bcbbe020 Mon Sep 17 00:00:00 2001
From: Scott Staniewicz
Date: Fri, 29 Aug 2025 18:19:05 -0400
Subject: [PATCH] Use `pinv` to solve for the cumulative displacement

---
 src/opera_utils/disp/_rebase.py | 128 +++++++++++++++++++++++---------
 1 file changed, 92 insertions(+), 36 deletions(-)

diff --git a/src/opera_utils/disp/_rebase.py b/src/opera_utils/disp/_rebase.py
index da913e5b..3c2f7e66 100644
--- a/src/opera_utils/disp/_rebase.py
+++ b/src/opera_utils/disp/_rebase.py
@@ -4,11 +4,15 @@
 from collections.abc import Sequence
 from datetime import datetime
 from enum import Enum
+from typing import TypeVar
 
+import dask.array
 import numpy as np
 import pandas as pd
 import xarray as xr
 
+from opera_utils._helpers import flatten
+
 from ._utils import _clamp_chunk_dict
 
 logger = logging.getLogger("opera_utils")
@@ -17,6 +21,8 @@
     "create_rebased_displacement",
 ]
 
+T = TypeVar("T")
+
 
 class NaNPolicy(str, Enum):
     """Policy for handling NaN values in rebase_timeseries."""
@@ -76,15 +82,26 @@ def create_rebased_displacement(
     }
     process_chunks = _clamp_chunk_dict(process_chunks, da_displacement.shape)
 
-    # Make the map_blocks-compatible function to accumulate the displacement
     def process_block(arr: xr.DataArray) -> xr.DataArray:
-        out = rebase_timeseries(
-            arr.to_numpy(), reference_datetimes, nan_policy=nan_policy
+        return xr.DataArray(
+            rebase_timeseries(
+                arr.to_numpy(),
+                reference_dates=reference_datetimes,
+                secondary_dates=da_displacement.time.to_index(),
+                nan_policy=nan_policy,
+            ),
+            coords=arr.coords,
+            dims=arr.dims,
         )
-        return xr.DataArray(out, coords=arr.coords, dims=arr.dims)
+
+    d_chunked = da_displacement.chunk(process_chunks)
+    template = xr.DataArray(
+        data=dask.array.empty_like(d_chunked.data),
+        coords=da_displacement.coords,
+    )
 
     # Process the dataset in blocks
-    rebased_da = da_displacement.chunk(process_chunks).map_blocks(process_block)
+    rebased_da = d_chunked.map_blocks(process_block, template=template)
 
     if add_reference_time:
         # Add initial reference epoch of zeros, and rechunk
@@ -101,6 +118,7 @@ def process_block(arr: xr.DataArray) -> xr.DataArray:
 def rebase_timeseries(
     raw_data: np.ndarray,
     reference_dates: Sequence[datetime],
+    secondary_dates: Sequence[datetime],
     nan_policy: str | NaNPolicy = NaNPolicy.propagate,
 ) -> np.ndarray:
     """Adjust for moving reference dates to create a continuous time series.
@@ -126,7 +144,9 @@ def rebase_timeseries(
         3D array of displacement values with moving reference dates
         shape = (time, rows, cols)
     reference_dates : Sequence[datetime]
-        Reference dates for each time step
+        Reference date for each time step
+    secondary_dates : Sequence[datetime]
+        Secondary date for each time step
     nan_policy : choices = ["propagate", "omit"]
         Whether to propagate or omit (zero out) NaNs in the data.
         By default "propagate", which means any ministack, or any "reference crossover"
@@ -140,33 +160,69 @@ def rebase_timeseries(
         Continuous displacement time series with consistent reference date
 
     """
-    if len(set(reference_dates)) == 1:
-        return raw_data.copy()
-
-    shape2d = raw_data.shape[1:]
-    cumulative_offset = np.zeros(shape2d, dtype=np.float32)
-    previous_displacement = np.zeros(shape2d, dtype=np.float32)
-
-    # Set initial reference date
-    current_reference_date = reference_dates[0]
-
-    output = np.zeros_like(raw_data)
-    # Process each time step
-    for cur_ref_date, current_displacement, out_layer in zip(
-        reference_dates, raw_data, output
-    ):
-        # Check for shift in temporal reference date
-        if cur_ref_date != current_reference_date:
-            # When reference date changes, accumulate the previous displacement
-            if nan_policy == NaNPolicy.omit:
-                np.nan_to_num(previous_displacement, copy=False)
-            cumulative_offset += previous_displacement
-            current_reference_date = cur_ref_date
-
-        # Store current displacement for next iteration
-        previous_displacement = current_displacement.copy()
-
-        # Add cumulative offset to get consistent reference
-        out_layer[:] = current_displacement + cumulative_offset
-
-    return output
+    A = get_incidence_matrix(list(zip(reference_dates, secondary_dates)))
+    pA = np.linalg.pinv(A)
+
+    pixels = np.nan_to_num(raw_data.reshape(raw_data.shape[0], -1))
+    pixels_rebased = pA @ pixels
+    out_stack = pixels_rebased.reshape(raw_data.shape)
+    if nan_policy == NaNPolicy.omit:
+        return out_stack
+
+    nan_mask = np.isnan(raw_data)
+    # Cumulatively sum the mask: once a pixel goes NaN, re-mask all later epochs
+    nans_propagated = np.cumsum(nan_mask.astype(int), axis=0)
+    out_stack[nans_propagated > 0] = np.nan
+    return out_stack
+
+
+def get_incidence_matrix(
+    ifg_pairs: Sequence[tuple[T, T]],
+    sar_idxs: Sequence[T] | None = None,
+    delete_first_date_column: bool = True,
+) -> np.ndarray:
+    """Build the incidence matrix from a list of ifg pairs (index 1, index 2).
+
+    Parameters
+    ----------
+    ifg_pairs : Sequence[tuple[T, T]]
+        List of ifg pairs represented as tuples of (day 1, day 2).
+        Can be ints, datetimes, etc.
+    sar_idxs : Sequence[T], optional
+        If provided, used as the total set of indexes which `ifg_pairs`
+        were formed from.
+        Otherwise, created from the unique entries in `ifg_pairs`.
+        Only provide if there are some dates which are not present in `ifg_pairs`.
+    delete_first_date_column : bool
+        If True, removes the first column of the matrix to make it full column rank.
+        Size will be `n_sar_dates - 1` columns.
+        Otherwise, the matrix will have `n_sar_dates` columns, but rank `n_sar_dates - 1`.
+
+    Returns
+    -------
+    A : np.ndarray, 2D
+        The incidence matrix for the system A @ phi = dphi.
+        Each row corresponds to an ifg, each column to a SAR date.
+        The value is -1 in the earlier (reference) date's column, +1 in the later (secondary),
+        since the ifg phase = (later - earlier).
+        Shape: (n_ifgs, n_sar_dates - 1)
+
+    """
+    if sar_idxs is None:
+        sar_idxs = sorted(set(flatten(ifg_pairs)))
+
+    M = len(ifg_pairs)
+    col_iter = sar_idxs[1:] if delete_first_date_column else sar_idxs
+    N = len(col_iter)
+    A = np.zeros((M, N))
+
+    # Create a dictionary mapping SAR dates to matrix columns.
+    # The first SAR acquisition is taken as time 0 and (by default) left out.
+    date_to_col = {date: i for i, date in enumerate(col_iter)}
+    for i, (early, later) in enumerate(ifg_pairs):
+        if early in date_to_col:
+            A[i, date_to_col[early]] = -1
+        if later in date_to_col:
+            A[i, date_to_col[later]] = +1
+
+    return A
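
A minimal standalone sketch of the `pinv` rebasing idea on toy data. The inline `incidence` helper and all epochs/values below are made up for illustration; it mirrors `get_incidence_matrix` with `delete_first_date_column=True` rather than importing the patched module:

    import numpy as np

    def incidence(pairs, sar_dates):
        # One row per ifg; one column per SAR date except the first (time 0)
        col = {d: j for j, d in enumerate(sar_dates[1:])}
        A = np.zeros((len(pairs), len(sar_dates) - 1))
        for i, (early, later) in enumerate(pairs):
            if early in col:
                A[i, col[early]] = -1
            if later in col:
                A[i, col[later]] = +1
        return A

    # SAR epochs 0..4; the temporal reference moves from epoch 0 to epoch 2,
    # as it would between ministacks in a DISP-style stack
    pairs = [(0, 1), (0, 2), (2, 3), (2, 4)]
    A = incidence(pairs, sar_dates=[0, 1, 2, 3, 4])

    # Measured ifg values for a true cumulative series phi = [1.0, 2.0, 3.5, 5.0]:
    # the last two are referenced to epoch 2, so they read 1.5 and 3.0
    dphi = np.array([1.0, 2.0, 1.5, 3.0])

    phi = np.linalg.pinv(A) @ dphi
    print(phi)  # [1.  2.  3.5 5. ] -- the series rebased to epoch 0

Because each secondary date appears exactly once in this moving-reference pattern, the system is square and full rank, so `pinv` reduces to an exact inverse; the payoff over the old accumulation loop is that the rebase becomes a single matrix multiply applied to all pixels at once.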