128 changes: 92 additions & 36 deletions src/opera_utils/disp/_rebase.py
@@ -4,11 +4,15 @@
from collections.abc import Sequence
from datetime import datetime
from enum import Enum
from typing import TypeVar

import dask.array
import numpy as np
import pandas as pd
import xarray as xr

from opera_utils._helpers import flatten

from ._utils import _clamp_chunk_dict

logger = logging.getLogger("opera_utils")
@@ -17,6 +21,8 @@
"create_rebased_displacement",
]

T = TypeVar("T")


class NaNPolicy(str, Enum):
"""Policy for handling NaN values in rebase_timeseries."""
@@ -76,15 +82,26 @@ def create_rebased_displacement(
}
process_chunks = _clamp_chunk_dict(process_chunks, da_displacement.shape)

# Define the map_blocks-compatible function that rebases each displacement block
def process_block(arr: xr.DataArray) -> xr.DataArray:
out = rebase_timeseries(
arr.to_numpy(), reference_datetimes, nan_policy=nan_policy
return xr.DataArray(
rebase_timeseries(
arr.to_numpy(),
reference_dates=reference_datetimes,
# pass plain datetimes so they hash and compare consistently with reference_datetimes
secondary_dates=pd.to_datetime(da_displacement.time.values).to_pydatetime(),
nan_policy=nan_policy,
),
coords=arr.coords,
dims=arr.dims,
)
return xr.DataArray(out, coords=arr.coords, dims=arr.dims)

d_chunked = da_displacement.chunk(process_chunks)
# A lazy template lets map_blocks construct the output metadata without a trial run
template = xr.DataArray(
data=dask.array.empty_like(d_chunked),
coords=da_displacement.coords,
)

# Process the dataset in blocks
rebased_da = da_displacement.chunk(process_chunks).map_blocks(process_block)
rebased_da = d_chunked.map_blocks(process_block, template=template)

if add_reference_time:
# Add initial reference epoch of zeros, and rechunk
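# A minimal sketch of the map_blocks + template pattern used above, run on a
# tiny synthetic DataArray with made-up dims (not the PR's data):
import dask.array
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.random.rand(4, 6, 6).astype("float32"),
    dims=("time", "y", "x"),
).chunk({"time": -1, "y": 3, "x": 3})

# A lazy array with matching shape/chunks/dtype serves as the template,
# so xarray can skip the trial computation it would otherwise run to
# infer the output's metadata.
template = da.copy(data=dask.array.empty_like(da.data))

doubled = da.map_blocks(lambda block: block * 2.0, template=template)
doubled.compute()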
@@ -101,6 +118,7 @@ def process_block(arr: xr.DataArray) -> xr.DataArray:
def rebase_timeseries(
raw_data: np.ndarray,
reference_dates: Sequence[datetime],
secondary_dates: Sequence[datetime],
nan_policy: str | NaNPolicy = NaNPolicy.propagate,
) -> np.ndarray:
"""Adjust for moving reference dates to create a continuous time series.
@@ -126,7 +144,9 @@
3D array of displacement values with moving reference dates
shape = (time, rows, cols)
reference_dates : Sequence[datetime]
Reference dates for each time step
Reference date for each time step
secondary_dates : Sequence[datetime]
Secondary date for each time step
nan_policy : choices = ["propagate", "omit"]
Whether to propagate or omit (zero out) NaNs in the data.
By default "propagate", which means any ministack, or any "reference crossover"
@@ -140,33 +160,69 @@
Continuous displacement time series with consistent reference date

"""
if len(set(reference_dates)) == 1:
return raw_data.copy()

shape2d = raw_data.shape[1:]
cumulative_offset = np.zeros(shape2d, dtype=np.float32)
previous_displacement = np.zeros(shape2d, dtype=np.float32)

# Set initial reference date
current_reference_date = reference_dates[0]

output = np.zeros_like(raw_data)
# Process each time step
for cur_ref_date, current_displacement, out_layer in zip(
reference_dates, raw_data, output
):
# Check for shift in temporal reference date
if cur_ref_date != current_reference_date:
# When reference date changes, accumulate the previous displacement
if nan_policy == NaNPolicy.omit:
np.nan_to_num(previous_displacement, copy=False)
cumulative_offset += previous_displacement
current_reference_date = cur_ref_date

# Store current displacement for next iteration
previous_displacement = current_displacement.copy()

# Add cumulative offset to get consistent reference
out_layer[:] = current_displacement + cumulative_offset

return output
# Solve A @ phi = dphi for every pixel at once via the pseudo-inverse
A = get_incidence_matrix(list(zip(reference_dates, secondary_dates)))
pA = np.linalg.pinv(A)

# Flatten to (time, pixels); zero out NaNs so the matrix solve stays finite
pixels = np.nan_to_num(raw_data.reshape(raw_data.shape[0], -1))
pixels_rebased = pA @ pixels
out_stack = pixels_rebased.reshape(raw_data.shape)
if nan_policy == NaNPolicy.omit:
return out_stack

nan_mask = np.isnan(raw_data)
# Cumulatively sum the NaN mask so that, once a pixel goes NaN, all later epochs stay NaN
nans_propagated = np.cumsum(nan_mask.astype(int), axis=0)
out_stack[nans_propagated > 0] = np.nan
return out_stack
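# A toy check of the incidence-matrix rebasing above, assuming made-up 12-day
# acquisition dates: two ministacks whose reference jumps from d[0] to d[2].
from datetime import datetime, timedelta

import numpy as np

d = [datetime(2020, 1, 1) + timedelta(days=12 * i) for i in range(5)]
reference_dates = [d[0], d[0], d[2], d[2]]
secondary_dates = [d[1], d[2], d[3], d[4]]

# One pixel, one relative displacement value per epoch
raw = np.array([1.0, 2.0, 0.5, 1.5], dtype=np.float32).reshape(4, 1, 1)
print(rebase_timeseries(raw, reference_dates, secondary_dates).ravel())
# -> approximately [1.0, 2.0, 2.5, 3.5], all now relative to d[0]

# With the default "propagate" policy, a NaN at one epoch masks every later epoch
raw_nan = raw.copy()
raw_nan[1] = np.nan
print(rebase_timeseries(raw_nan, reference_dates, secondary_dates).ravel())
# -> [1.0, nan, nan, nan]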


def get_incidence_matrix(
ifg_pairs: Sequence[tuple[T, T]],
sar_idxs: Sequence[T] | None = None,
delete_first_date_column: bool = True,
) -> np.ndarray:
"""Build the indicator matrix from a list of ifg pairs (index 1, index 2).

Parameters
----------
ifg_pairs : Sequence[tuple[T, T]]
List of ifg pairs represented as tuples of (day 1, day 2)
Can be ints, datetimes, etc.
sar_idxs : Sequence[T], optional
If provided, used as the full set of indexes from which `ifg_pairs`
were formed.
Otherwise, created from the unique entries in `ifg_pairs`.
Only provide this if some dates are not present in `ifg_pairs`.
delete_first_date_column : bool
If True, removes the first column of the matrix to make it full column rank.
The result will have `n_sar_dates - 1` columns.
Otherwise, the matrix will have `n_sar_dates` columns, but rank `n_sar_dates - 1`.

Returns
-------
A : np.ndarray, 2D
The incidence-like matrix for the system `A @ phi = dphi`.
Each row corresponds to an ifg, each column to a SAR date.
The value is -1 in the column of the earlier (reference) date and +1 in the
column of the later (secondary) date, since the ifg phase = (later - earlier).
Shape: (n_ifgs, n_sar_dates - 1) when `delete_first_date_column` is True.

"""
if sar_idxs is None:
sar_idxs = sorted(set(flatten(ifg_pairs)))

M = len(ifg_pairs)
col_iter = sar_idxs[1:] if delete_first_date_column else sar_idxs
N = len(col_iter)
A = np.zeros((M, N))

# Map each SAR date to its matrix column
# The first acquisition is taken as time 0 and, by default, has no column
date_to_col = {date: i for i, date in enumerate(col_iter)}
for i, (early, later) in enumerate(ifg_pairs):
if early in date_to_col:
A[i, date_to_col[early]] = -1
if later in date_to_col:
A[i, date_to_col[later]] = +1

return A
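# A small sanity check of get_incidence_matrix using integer "dates" (the
# function is generic over any hashable, orderable index type):
A = get_incidence_matrix([(0, 1), (1, 2), (0, 2)])
print(A)
# Columns correspond to dates 1 and 2 (date 0 is dropped as the reference):
# [[ 1.  0.]
#  [-1.  1.]
#  [ 0.  1.]]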