From 4d2819fc6699cdb99ee26b6c21658b84df916bc7 Mon Sep 17 00:00:00 2001 From: phildong <35715936+phildong@users.noreply.github.com> Date: Wed, 31 Dec 2025 18:28:22 -0500 Subject: [PATCH 1/2] initial AI-generated draft of docstrings --- src/indeca/AR_kernel.py | 616 +++++++++++++++++++++++++++-- src/indeca/dashboard.py | 119 +++++- src/indeca/deconv.py | 342 +++++++++++++++-- src/indeca/pipeline.py | 196 ++++++++-- src/indeca/simulation.py | 810 ++++++++++++++++++++++++++++++++++++--- src/indeca/utils.py | 149 ++++++- 6 files changed, 2067 insertions(+), 165 deletions(-) diff --git a/src/indeca/AR_kernel.py b/src/indeca/AR_kernel.py index 94429ef..5f95c35 100644 --- a/src/indeca/AR_kernel.py +++ b/src/indeca/AR_kernel.py @@ -1,9 +1,26 @@ +""" +AR kernel estimation and manipulation utilities. + +This module provides functions for estimating and refining autoregressive (AR) +kernel parameters from calcium imaging data. It implements the kernel update +step of the InDeCa algorithm, which uses inferred spike trains to iteratively +estimate interpretable bi-exponential calcium dynamics. + +Key functionality: +- Solve for AR coefficients from data +- Fit sum-of-exponentials models to impulse responses +- Estimate noise levels from power spectral density +- Convert between AR coefficients and time constants +""" + import warnings +from typing import Literal, Optional, Tuple, Union import cvxpy as cp import numpy as np import pandas as pd import scipy.sparse as sps +from numpy.typing import NDArray from scipy.integrate import cumulative_trapezoid from scipy.linalg import lstsq, toeplitz from scipy.optimize import curve_fit @@ -13,13 +30,67 @@ from indeca.simulation import AR2tau, ar_pulse, solve_p, tau2AR -def convolve_g(s, g): +def convolve_g(s: NDArray, g: NDArray) -> NDArray: + """ + Convolve a signal with the inverse AR filter (i.e., the kernel H = G⁻¹). + + Constructs the AR matrix G from coefficients g, computes its inverse G⁻¹, + and applies it to the input signal. Since H = G⁻¹, this effectively + convolves with the impulse response kernel derived from the AR coefficients. + + For spike input s, this produces calcium c = G⁻¹s = Hs. + + Parameters + ---------- + s : NDArray + Input signal of shape (T,). + g : NDArray + AR coefficients of shape (p,), typically (2,) for AR(2). + + Returns + ------- + NDArray + Output signal G⁻¹s of shape (T,). + + See Also + -------- + convolve_h : Convolve with a general kernel h directly. + construct_G : Build the AR relationship matrix G. + + Notes + ----- + The AR relationship is Gc = s. This function computes c = G⁻¹s, + which is equivalent to convolving s with the kernel h where Hs = c. + """ G = construct_G(g, len(s)) Gi = sps.linalg.inv(G) return np.array(Gi @ s.reshape((-1, 1))).squeeze() -def convolve_h(s, h): +def convolve_h(s: NDArray, h: NDArray) -> NDArray: + """ + Convolve a signal with a general kernel using matrix multiplication. + + Builds the full convolution matrix H from kernel h and applies it to + the input signal s. The result is equivalent to np.convolve(h, s) + but uses explicit matrix construction. + + Parameters + ---------- + s : NDArray + Input signal of shape (T,). + h : NDArray + Convolution kernel (impulse response) of shape (T,). + + Returns + ------- + NDArray + Convolved output Hs of shape (T,). + + See Also + -------- + convolve_g : Convolve with the inverse AR filter G⁻¹. + """ T = len(s) H0 = h.reshape((-1, 1)) H1n = [ @@ -30,7 +101,39 @@ def convolve_h(s, h): return np.real(np.array(H @ s.reshape((-1, 1))).squeeze()) -def solve_g(y, s, norm="l1", masking=False): +def solve_g( + y: NDArray, s: NDArray, norm: str = "l1", masking: bool = False +) -> Tuple[float, float]: + """ + Solve for AR(2) coefficients that best relate signals y and s. + + Finds AR coefficients (θ₁, θ₂) that minimize the reconstruction + error ||G @ y - s|| subject to stability constraints, where G is + the AR matrix parameterized by θ₁ and θ₂. + + Parameters + ---------- + y : NDArray + First input signal of shape (T,). + s : NDArray + Second input signal of shape (T,). + norm : str, default="l1" + Error norm to minimize: "l1" or "l2". + masking : bool, default=False + If True, only consider time points where s > 0. + + Returns + ------- + theta_1 : float + First AR coefficient (θ₁), constrained to be non-negative. + theta_2 : float + Second AR coefficient (θ₂), constrained to be non-positive. + + Notes + ----- + The constraints θ₁ ≥ 0 and θ₂ ≤ 0 ensure the AR process has + appropriate decay characteristics (real, positive time constants). + """ T = len(s) theta_1, theta_2 = cp.Variable(), cp.Variable() G = ( @@ -55,7 +158,41 @@ def solve_g(y, s, norm="l1", masking=False): return theta_1.value, theta_2.value -def fit_sumexp(y, N, x=None, use_l1=False): +def fit_sumexp( + y: NDArray, N: int, x: Optional[NDArray] = None, use_l1: bool = False +) -> Tuple[NDArray, NDArray, NDArray]: + """ + Fit a sum of N exponentials to data using the Prony-like method. + + Uses cumulative integration and eigenvalue decomposition to extract + exponential rates and amplitudes from the input signal. + + Parameters + ---------- + y : NDArray + Input signal of shape (T,). + N : int + Number of exponential terms to fit. + x : NDArray, optional + Time points of shape (T,). If None, uses np.arange(T). + use_l1 : bool, default=False + If True, use L1 norm for fitting (more robust to outliers). + + Returns + ------- + lams : NDArray + Exponential rates (λ) of shape (N,). Sorted in descending order. + The exponential form is exp(λ * t), so λ < 0 for decay. + ps : NDArray + Amplitude coefficients of shape (N,). + y_fit : NDArray + Fitted signal of shape (T,). + + References + ---------- + .. [1] http://arxiv.org/abs/physics/0305019 + .. [2] https://github.juangburgos.com/FitSumExponentials/lab/index.html + """ # ref: http://arxiv.org/abs/physics/0305019 # ref: https://github.juangburgos.com/FitSumExponentials/lab/index.html T = len(y) @@ -86,7 +223,28 @@ def fit_sumexp(y, N, x=None, use_l1=False): return lams, ps, y_fit -def fit_sumexp_split(y): +def fit_sumexp_split(y: NDArray) -> Tuple[NDArray, NDArray, NDArray]: + """ + Fit sum of exponentials by splitting at the peak. + + Separately fits single exponentials to the rising and decaying + portions of the signal. Useful for signals with distinct rise + and decay phases. + + Parameters + ---------- + y : NDArray + Input signal of shape (T,). + + Returns + ------- + lams : NDArray + Exponential rates [λ_decay, λ_rise] of shape (2,). + ps : NDArray + Amplitude coefficients [p_decay, p_rise] of shape (2,). + y_fit : NDArray + Fitted signal of shape (T,). + """ T = len(y) x = np.arange(T) idx_split = np.argmax(y) @@ -99,7 +257,50 @@ def fit_sumexp_split(y): ) -def fit_sumexp_gd(y, x=None, y_weight=None, fit_amp=True, interp_factor=100): +def fit_sumexp_gd( + y: NDArray, + x: Optional[NDArray] = None, + y_weight: Optional[NDArray] = None, + fit_amp: Union[bool, str] = True, + interp_factor: int = 100, +) -> Tuple[NDArray, NDArray, float, NDArray]: + """ + Fit bi-exponential using gradient descent (curve_fit). + + Uses nonlinear least squares to fit a bi-exponential kernel to the + input signal. Initial guesses are estimated from the signal shape. + + Parameters + ---------- + y : NDArray + Input signal of shape (T,). + x : NDArray, optional + Time points of shape (T,). If None, uses np.arange(T). + y_weight : NDArray, optional + Per-point weights for fitting (inverse variance). + fit_amp : bool or str, default=True + How to handle amplitude: + - True: Fit time constants, compute amplitude from AR normalization + - False: Fit time constants only, use p = [1, -1] + - "scale": Fit time constants and an overall scaling factor + interp_factor : int, default=100 + Interpolation factor for initial guess estimation. + + Returns + ------- + lams : NDArray + Exponential rates [-1/τ_d, -1/τ_r] of shape (2,). + p : NDArray + Amplitude coefficients of shape (2,). + scal : float + Overall scaling factor (1.0 if fit_amp != "scale"). + y_fit : NDArray + Fitted signal of shape (T,). + + Warnings + -------- + Issues a warning if τ_d ≤ τ_r (decay faster than rise), and swaps them. + """ T = len(y) if x is None: x = np.arange(T) @@ -179,7 +380,43 @@ def fit_sumexp_gd(y, x=None, y_weight=None, fit_amp=True, interp_factor=100): ) -def fit_sumexp_iter(y, max_iters=50, atol=1e-3, **kwargs): +def fit_sumexp_iter( + y: NDArray, max_iters: int = 50, atol: float = 1e-3, **kwargs +) -> Tuple[NDArray, float, float, NDArray, pd.DataFrame]: + """ + Iteratively fit bi-exponential with amplitude refinement. + + Alternates between fitting time constants and updating the amplitude + normalization until convergence. + + Parameters + ---------- + y : NDArray + Input signal of shape (T,). + max_iters : int, default=50 + Maximum number of iterations. + atol : float, default=1e-3 + Absolute tolerance for amplitude convergence. + **kwargs + Additional arguments passed to fit_sumexp_gd. + + Returns + ------- + lams : NDArray + Final exponential rates of shape (2,). + p : float + Final amplitude normalization factor. + scal : float + Overall scaling factor. + y_fit : NDArray + Fitted signal of shape (T,). + coef_df : pd.DataFrame + Iteration history with columns: i_iter, p, tau_d, tau_r. + + Warnings + -------- + Issues a warning if max_iters is reached without convergence. + """ _, _, scal, y_fit = fit_sumexp_gd(y, fit_amp="scale") y_norm = y / scal p = 1 @@ -210,7 +447,29 @@ def fit_sumexp_iter(y, max_iters=50, atol=1e-3, **kwargs): return lams, p, scal, y_fit, coef_df -def lst_l1(A, b): +def lst_l1(A: NDArray, b: NDArray) -> NDArray: + """ + Solve least squares with L1 norm using convex optimization. + + Minimizes ||b - A @ x||_1 using CVXPY. + + Parameters + ---------- + A : NDArray + Design matrix of shape (m, n). + b : NDArray + Target vector of shape (m,). + + Returns + ------- + NDArray + Solution vector of shape (n,). + + Raises + ------ + AssertionError + If the optimization does not reach optimal status. + """ x = cp.Variable(A.shape[1]) obj = cp.Minimize(cp.norm(b - A @ x, 1)) prob = cp.Problem(obj) @@ -220,16 +479,56 @@ def lst_l1(A, b): def solve_h( - y, - s, - scal, - err_wt=None, - h_len=60, - norm="l2", - smth_penalty=0, - ignore_len=0, - up_factor=1, -): + y: NDArray, + s: NDArray, + scal: NDArray, + err_wt: Optional[NDArray] = None, + h_len: int = 60, + norm: str = "l2", + smth_penalty: float = 0, + ignore_len: int = 0, + up_factor: int = 1, +) -> NDArray: + """ + Solve for the convolution kernel h given observed data and spikes. + + Estimates an unconstrained kernel h that minimizes reconstruction error: + ||y - scale * R @ (h * s) - b|| + + where R is the downsampling matrix and b is a baseline offset. + + Parameters + ---------- + y : NDArray + Observed signal. Shape (T,) for single unit or (n_cells, T) + for multiple units with shared kernel. + s : NDArray + Input signal (e.g., spike trains). Shape matches y. + scal : NDArray + Amplitude scaling factors. Shape (1,) or (n_cells, 1). + err_wt : NDArray, optional + Per-timepoint error weights of shape matching y. + h_len : int, default=60 + Length of the kernel to estimate. + norm : str, default="l2" + Error norm: "l1" or "l2". + smth_penalty : float, default=0 + L1 penalty on kernel differences (smoothness regularization). + ignore_len : int, default=0 + Number of initial kernel samples to exclude from smoothness penalty. + up_factor : int, default=1 + Temporal upsampling factor. + + Returns + ------- + NDArray + Estimated kernel h, zero-padded to length T. + + Notes + ----- + Uses CLARABEL solver for convex optimization. The baseline offset b + is constrained to be non-negative. + """ y, s = y.squeeze(), s.squeeze() assert y.ndim == s.ndim multi_unit = y.ndim > 1 @@ -270,16 +569,58 @@ def solve_h( def solve_fit_h( - y, - s, - scal, - N=2, - s_len=60, - norm="l1", - tol=1e-3, + y: NDArray, + s: NDArray, + scal: NDArray, + N: int = 2, + s_len: int = 60, + norm: str = "l1", + tol: float = 1e-3, max_iters: int = 30, - verbose=False, -): + verbose: bool = False, +) -> Tuple[NDArray, NDArray, NDArray, NDArray, pd.DataFrame, pd.DataFrame]: + """ + Iteratively solve for kernel with smoothing to ensure real exponentials. + + Uses binary search on smoothing penalty to find the minimum regularization + that produces a kernel with real (not complex) exponential rates. + + Parameters + ---------- + y : NDArray + Observed signal. + s : NDArray + Input signal (e.g., spike trains). + scal : NDArray + Amplitude scaling factors. + N : int, default=2 + Number of exponential terms to fit. + s_len : int, default=60 + Kernel length. + norm : str, default="l1" + Error norm for solve_h. + tol : float, default=1e-3 + Tolerance for smoothing penalty binary search. + max_iters : int, default=30 + Maximum number of iterations. + verbose : bool, default=False + If True, print iteration progress. + + Returns + ------- + lams : NDArray + Exponential rates of shape (N,). + ps : NDArray + Amplitude coefficients of shape (N,). + h : NDArray + Estimated kernel. + h_fit : NDArray + Fitted exponential kernel. + metric_df : pd.DataFrame + Iteration metrics with columns: iter, smth_penal, isreal. + h_df : pd.DataFrame + Kernel history with columns: iter, smth_penal, h, h_fit, frame. + """ metric_df = None h_df = None smth_penal = 0 @@ -329,7 +670,54 @@ def solve_fit_h( return lams, ps, h, h_fit, metric_df, h_df -def solve_fit_h_num(y, s, scal, err_wt=None, N=2, h_len=60, norm="l2", up_factor=1): +def solve_fit_h_num( + y: NDArray, + s: NDArray, + scal: NDArray, + err_wt: Optional[NDArray] = None, + N: int = 2, + h_len: int = 60, + norm: str = "l2", + up_factor: int = 1, +) -> Tuple[NDArray, NDArray, float, NDArray, NDArray]: + """ + Solve for kernel and fit bi-exponential numerically. + + Combines solve_h and fit_sumexp_gd to estimate a kernel and fit + a bi-exponential model to it. + + Parameters + ---------- + y : NDArray + Observed signal. + s : NDArray + Input signal (e.g., spike trains). + scal : NDArray + Amplitude scaling factors. + err_wt : NDArray, optional + Per-timepoint error weights. + N : int, default=2 + Number of exponential terms. + h_len : int, default=60 + Kernel length. + norm : str, default="l2" + Error norm. + up_factor : int, default=1 + Temporal upsampling factor. + + Returns + ------- + lams : NDArray + Exponential rates of shape (N,). + p : NDArray + Amplitude coefficients of shape (N,). + scal : float + Fitted scaling factor. + h : NDArray + Estimated kernel. + h_fit_pad : NDArray + Fitted kernel, zero-padded to match h length. + """ if y.ndim == 1: ylen = len(y) else: @@ -352,8 +740,61 @@ def solve_fit_h_num(y, s, scal, err_wt=None, N=2, h_len=60, norm="l2", up_factor def updateAR( - y, s, scal, err_wt=None, N=2, h_len=60, norm="l2", up_factor=1, pre_agg=True -): + y: NDArray, + s: NDArray, + scal: NDArray, + err_wt: Optional[NDArray] = None, + N: int = 2, + h_len: int = 60, + norm: str = "l2", + up_factor: int = 1, + pre_agg: bool = True, +) -> Tuple[NDArray, NDArray, float, NDArray, NDArray]: + """ + Update AR parameters from data and inferred spikes. + + Main kernel update function for InDeCa. Estimates time constants + from the relationship between observed fluorescence and spike trains. + + Parameters + ---------- + y : NDArray + Observed fluorescence of shape (n_cells, T) or (T,). + s : NDArray + Spike trains matching y shape. + scal : NDArray + Amplitude scaling factors. + err_wt : NDArray, optional + Per-timepoint error weights. + N : int, default=2 + Number of exponential terms (typically 2 for rise and decay). + h_len : int, default=60 + Kernel length in frames. + norm : str, default="l2" + Error norm for kernel estimation. + up_factor : int, default=1 + Temporal upsampling factor. + pre_agg : bool, default=True + If True, aggregate spikes before fitting (more efficient). + + Returns + ------- + taus : NDArray + Time constants [τ_d, τ_r] of shape (2,), in original time units. + ps : NDArray + Amplitude coefficients of shape (2,). + ar_scal : float + Scaling factor from fit. + h : NDArray + Estimated kernel (zero-padded). + h_fit : NDArray + Fitted bi-exponential kernel (zero-padded). + + Notes + ----- + This implements the kernel update step from the InDeCa algorithm, + using the inferred spikes to estimate a denoised calcium dynamics kernel. + """ if not pre_agg: lams, ps, ar_scal, h, h_fit = solve_fit_h_num( y, s, scal, err_wt=err_wt, N=N, h_len=h_len, norm=norm, up_factor=up_factor @@ -381,7 +822,45 @@ def updateAR( ) -def solve_g_cons(y, s, lam_tol=1e-6, lam_start=1, max_iter=30): +def solve_g_cons( + y: NDArray, + s: NDArray, + lam_tol: float = 1e-6, + lam_start: float = 1, + max_iter: int = 30, +) -> Tuple[float, float]: + """ + Fit AR coefficients with constraint for real exponentials. + + Uses iterative penalty adjustment to find AR coefficients that + correspond to real (not complex) time constants. + + Parameters + ---------- + y : NDArray + First input signal of shape (T,). + s : NDArray + Second input signal of shape (T,). + lam_tol : float, default=1e-6 + Tolerance for penalty convergence. + lam_start : float, default=1 + Initial penalty value. + max_iter : int, default=30 + Maximum number of iterations. + + Returns + ------- + th1 : float + First AR coefficient. + th2 : float + Second AR coefficient. + + Notes + ----- + The characteristic equation θ₁² + 4θ₂ < 0 indicates complex roots + (oscillatory response). This function iteratively adjusts the penalty + to find coefficients on the boundary of the real/complex region. + """ T = len(s) i_iter = 0 lam = lam_start @@ -423,8 +902,45 @@ def solve_g_cons(y, s, lam_tol=1e-6, lam_start=1, max_iter=30): def estimate_coefs( - y: np.ndarray, p: int, noise_freq: tuple, use_smooth: bool, add_lag: int -): + y: NDArray, + p: int, + noise_freq: Optional[Tuple[float, float]], + use_smooth: bool, + add_lag: int, +) -> Tuple[NDArray, float]: + """ + Estimate AR coefficients from noisy data. + + Uses Yule-Walker equations with optional smoothing and noise estimation + to fit AR coefficients to the input signal. + + Parameters + ---------- + y : NDArray + Input fluorescence signal of shape (T,). + p : int + Order of the AR process. + noise_freq : tuple of float, optional + Frequency range (low, high) as fraction of Nyquist for noise estimation. + If None, assumes zero noise. + use_smooth : bool + If True, low-pass filter the signal before AR estimation. + add_lag : int + Additional lags to include in the Yule-Walker estimation. + + Returns + ------- + g : NDArray + AR coefficients of shape (p,). + tn : float + Estimated noise level. + + See Also + -------- + get_ar_coef : Core AR coefficient estimation. + filt_fft : FFT-based filtering. + noise_fft : Noise estimation from PSD. + """ if noise_freq is None: tn = 0 else: @@ -551,7 +1067,39 @@ def get_ar_coef( return g -def AR_upsamp_real(theta, upsamp: int = 1, fit_nsamp: int = 1000): +def AR_upsamp_real( + theta: Tuple[float, float], upsamp: int = 1, fit_nsamp: int = 1000 +) -> Tuple[Tuple[float, float], NDArray, NDArray]: + """ + Compute upsampled AR parameters ensuring real exponentials. + + Converts AR coefficients to time constants, scales for upsampling, + and converts back. Ensures the result corresponds to real (not complex) + bi-exponential dynamics. + + Parameters + ---------- + theta : tuple of float + AR coefficients (θ₁, θ₂) at original sampling rate. + upsamp : int, default=1 + Upsampling factor. + fit_nsamp : int, default=1000 + Number of samples to use for impulse response fitting. + + Returns + ------- + theta_up : tuple of float + Upsampled AR coefficients (θ₁', θ₂'). + tau_up : NDArray + Upsampled time constants [τ_d, τ_r] of shape (2,). + p_up : NDArray + Amplitude coefficients [p, -p] of shape (2,). + + Raises + ------ + AssertionError + If τ_d ≤ τ_r or if amplitude is invalid (NaN, inf, or ≤ 0). + """ tr = ar_pulse(*theta, nsamp=fit_nsamp, shifted=True)[0] lams, cur_p, scl, tr_fit = fit_sumexp_gd(tr, fit_amp=True) tau = -1 / lams diff --git a/src/indeca/dashboard.py b/src/indeca/dashboard.py index 68f7830..2bfb297 100644 --- a/src/indeca/dashboard.py +++ b/src/indeca/dashboard.py @@ -1,19 +1,86 @@ +""" +Interactive dashboard for real-time visualization of InDeCa optimization. + +This module provides a web-based dashboard for monitoring the InDeCa +algorithm during execution. It displays: +- Per-cell fluorescence traces, calcium fits, and spike trains +- Estimated kernels and their bi-exponential fits +- Iteration metrics (error, scale, time constants) +- Penalty search heatmaps + +The dashboard uses Panel and Plotly for interactive visualization and +can be accessed via a web browser during algorithm execution. +""" + +from typing import Any, Dict, Optional, Union + import numpy as np +from numpy.typing import NDArray import panel as pn import plotly.graph_objects as go from plotly.subplots import make_subplots class Dashboard: + """ + Interactive web dashboard for monitoring InDeCa optimization. + + Provides real-time visualization of the deconvolution process including + per-cell traces, kernels, and iteration metrics. Runs as a threaded + web server accessible via browser. + + Parameters + ---------- + Y : NDArray, optional + Input fluorescence data of shape (n_cells, n_timepoints). + Either Y or both ncell and T must be provided. + ncell : int, optional + Number of cells (required if Y not provided). + T : int, optional + Number of time points (required if Y not provided). + max_iters : int, default=20 + Maximum number of iterations to store. + kn_len : int, default=60 + Length of kernel for display. + port : int, default=54321 + Port number for the web server. + + Attributes + ---------- + it_vars : dict + Dictionary storing iteration data with keys: + - 'c': Calcium traces (max_iters, ncell, T) + - 's': Spike trains (max_iters, ncell, T) + - 'h': Kernels (max_iters, ncell, kn_len) + - 'h_fit': Fitted kernels (max_iters, ncell, kn_len) + - 'scale': Scaling factors (max_iters, ncell) + - 'tau_d': Decay time constants (max_iters, ncell) + - 'tau_r': Rise time constants (max_iters, ncell) + - 'err': Errors (max_iters, ncell) + - 'penal_err': Penalty search data + it_update : int + Current iteration being updated. + it_view : int + Current iteration being viewed. + + Examples + -------- + >>> dashboard = Dashboard(Y=fluorescence_data, kn_len=60) + >>> # Access at http://localhost:54321 + >>> dashboard.update(uid=0, s=spike_train, c=calcium_trace) + >>> dashboard.set_iter(1) + >>> dashboard.stop() + """ + def __init__( self, - Y: np.ndarray = None, - ncell: int = None, - T: int = None, + Y: Optional[NDArray] = None, + ncell: Optional[int] = None, + T: Optional[int] = None, max_iters: int = 20, kn_len: int = 60, port: int = 54321, - ): + ) -> None: super().__init__() self.title = "Dashboard" if Y is None: @@ -28,7 +95,7 @@ def __init__( self.max_iters = max_iters self.it_update = 0 self.it_view = 0 - self.it_vars = { + self.it_vars: Dict[str, NDArray] = { "c": np.full((max_iters, ncell, T), np.nan), "s": np.full((max_iters, ncell, T), np.nan), "h": np.full((max_iters, ncell, kn_len), np.nan), @@ -246,13 +313,43 @@ def _refresh_it_view(self): self._refresh_cells_fig(u) self._refresh_err_penal_fit(u) - def set_iter(self, it: int): + def set_iter(self, it: int) -> None: + """ + Set the current iteration for display and update. + + If the view is tracking the update iteration, it will advance + to show the new iteration. + + Parameters + ---------- + it : int + Iteration number to set. + """ if self.it_update == self.it_view: self.it_view = it self._refresh_it_view() self.it_update = it - def update(self, uid: int = None, **kwargs): + def update(self, uid: Optional[int] = None, **kwargs: Any) -> None: + """ + Update dashboard data for one or more cells. + + Parameters + ---------- + uid : int, optional + Cell ID to update. If None, updates all cells. + **kwargs : Any + Data to update. Supported keys: + - c : NDArray - Calcium trace + - s : NDArray - Spike train + - h : NDArray - Kernel + - h_fit : NDArray - Fitted kernel + - scale : float - Scaling factor + - tau_d : float - Decay time constant + - tau_r : float - Rise time constant + - err : float - Error value + - penal_err : dict - Penalty search data with keys 'penal', 'scale', 'err' + """ if uid is None: uids = np.arange(self.ncell) else: @@ -279,5 +376,11 @@ def update(self, uid: int = None, **kwargs): for v in ["penal", "scale", "err"]: self.it_vars[vname][self.it_update, u][v].append(dat[v]) - def stop(self): + def stop(self) -> None: + """ + Stop the dashboard web server. + + Should be called when the algorithm completes to cleanly + shut down the threaded server. + """ self.sv.stop() diff --git a/src/indeca/deconv.py b/src/indeca/deconv.py index f0b2421..9d8c703 100644 --- a/src/indeca/deconv.py +++ b/src/indeca/deconv.py @@ -1,6 +1,46 @@ +""" +Deconvolution solver for calcium imaging spike inference. + +This module implements the core DeconvBin class for solving the spike inference +problem from calcium imaging data. It formulates deconvolution as a convex +optimization problem and uses binary pursuit to recover sparse spike trains. + +The forward model is: + y = scale * R @ H @ s + b + noise + +Where: + - y: Observed fluorescence trace (y_len,) + - s: Spike train to infer (T,), where T = y_len * upsamp + - H: Convolution matrix encoding calcium dynamics (T × T) + - R: Downsampling matrix (y_len × T) + - scale: Amplitude scaling factor + - b: Baseline offset + +Key features: + - Multiple optimization backends: CVXPY, OSQP, EMOSQP, CUOSQP (GPU) + - L0/L1 sparsity penalties with automatic tuning via DIRECT algorithm + - Temporal upsampling for sub-frame spike timing resolution + - Adaptive error weighting (FFT-based, correlation-based) + - Amplitude constraints for physiologically realistic spikes + +Classes +------- +DeconvBin + Main solver class for binary pursuit deconvolution. + +Functions +--------- +construct_R + Build temporal downsampling matrix. +construct_G + Build AR relationship matrix from coefficients. +max_thres + Generate threshold series for binary spike detection. +""" + import itertools as itt import warnings -from typing import Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import cvxpy as cp import numpy as np @@ -9,6 +49,7 @@ import scipy.sparse as sps import xarray as xr from numba import njit +from numpy.typing import NDArray from scipy.ndimage import label from scipy.optimize import direct from scipy.signal import ShortTimeFFT, find_peaks @@ -28,7 +69,27 @@ logger.warning("No GPU solver support") -def construct_R(T: int, up_factor: int): +def construct_R(T: int, up_factor: int) -> sps.csc_matrix: + """ + Construct a temporal downsampling/upsampling matrix. + + Creates a sparse matrix R that relates upsampled spike times to + observed frame times. When up_factor > 1, R sums groups of up_factor + consecutive time points into single frames. + + Parameters + ---------- + T : int + Number of observed time frames. + up_factor : int + Temporal upsampling factor. The spike train has T * up_factor points. + + Returns + ------- + sps.csc_matrix + Sparse matrix of shape (T, T * up_factor) in CSC format. + When up_factor=1, returns identity matrix. + """ if up_factor > 1: return sps.csc_matrix( ( @@ -41,11 +102,49 @@ def construct_R(T: int, up_factor: int): return sps.eye(T, format="csc") -def sum_downsample(a, factor): +def sum_downsample(a: NDArray, factor: int) -> NDArray: + """ + Downsample an array by summing consecutive groups. + + Parameters + ---------- + a : NDArray + Input array to downsample. + factor : int + Downsampling factor. + + Returns + ------- + NDArray + Downsampled array where each element is the sum of `factor` + consecutive elements from the input. + """ return np.convolve(a, np.ones(factor), mode="full")[factor - 1 :: factor] -def construct_G(fac: np.ndarray, T: int, fromTau=False): +def construct_G(fac: NDArray, T: int, fromTau: bool = False) -> sps.csc_matrix: + """ + Construct the AR relationship matrix G. + + Builds a sparse lower-triangular matrix G that encodes the AR(2) + relationship: s = G @ c, where s is the spike train and c is calcium. + + Parameters + ---------- + fac : NDArray + AR(2) coefficients of shape (2,), or time constants if fromTau=True. + T : int + Number of time points. + fromTau : bool, default=False + If True, interpret `fac` as time constants [τ_d, τ_r] and convert + to AR coefficients. + + Returns + ------- + sps.csc_matrix + Sparse AR matrix of shape (T, T) in CSC format. + Structure: G[i,i] = 1, G[i,i-1] = -θ₁, G[i,i-2] = -θ₂ + """ fac = np.array(fac) assert fac.shape == (2,) if fromTau: @@ -60,17 +159,53 @@ def construct_G(fac: np.ndarray, T: int, fromTau=False): def max_thres( - a: xr.DataArray, + a: Union[NDArray, xr.DataArray], nthres: int, - th_min=0.1, - th_max=0.9, - ds=None, - return_thres=False, - th_amplitude=False, - delta=1e-6, - reverse_thres=False, + th_min: float = 0.1, + th_max: float = 0.9, + ds: Optional[int] = None, + return_thres: bool = False, + th_amplitude: bool = False, + delta: float = 1e-6, + reverse_thres: bool = False, nz_only: bool = False, -): +) -> Union[List[NDArray], Tuple[List[NDArray], List[float]]]: + """ + Generate a series of thresholded spike trains for binary pursuit. + + Creates multiple candidate spike trains by applying different amplitude + thresholds to a continuous spike estimate. + + Parameters + ---------- + a : NDArray or xr.DataArray + Continuous spike estimate to threshold. + nthres : int + Number of threshold levels to generate. + th_min : float, default=0.1 + Minimum threshold as fraction of maximum value. + th_max : float, default=0.9 + Maximum threshold as fraction of maximum value. + ds : int, optional + Downsampling factor to apply after thresholding. + return_thres : bool, default=False + If True, also return the threshold values. + th_amplitude : bool, default=False + If True, divide by threshold instead of binary comparison. + delta : float, default=1e-6 + Minimum value to avoid division by zero. + reverse_thres : bool, default=False + If True, start from high threshold and go to low. + nz_only : bool, default=False + If True, exclude threshold levels that produce all-zero outputs. + + Returns + ------- + S_ls : list of NDArray + List of thresholded spike trains. + thres : list of float, optional + Threshold values, returned if return_thres=True. + """ amax = a.max() if reverse_thres: thres = np.linspace(th_max, th_min, nthres) @@ -94,8 +229,33 @@ def max_thres( @njit(nopython=True, nogil=True, cache=True) def bin_convolve( - coef: np.ndarray, s: np.ndarray, nzidx_s: np.ndarray = None, s_len: int = None -): + coef: NDArray, + s: NDArray, + nzidx_s: Optional[NDArray] = None, + s_len: Optional[int] = None, +) -> NDArray: + """ + Fast convolution optimized for sparse binary spike trains. + + Computes the convolution of a kernel with a sparse spike train by + only processing non-zero spike locations. JIT-compiled with Numba. + + Parameters + ---------- + coef : NDArray + Convolution kernel of shape (coef_len,). + s : NDArray + Spike train (potentially sparse). + nzidx_s : NDArray, optional + Indices mapping positions in s to full time array. + s_len : int, optional + Output length. If None, uses len(s). + + Returns + ------- + NDArray + Convolution result of shape (s_len,). + """ coef_len = len(coef) if s_len is None: s_len = len(s) @@ -111,7 +271,22 @@ def bin_convolve( @njit(nopython=True, nogil=True, cache=True) -def max_consecutive(arr): +def max_consecutive(arr: NDArray) -> int: + """ + Find the maximum number of consecutive True values in a boolean array. + + JIT-compiled with Numba for performance. + + Parameters + ---------- + arr : NDArray + Boolean array. + + Returns + ------- + int + Length of the longest consecutive True sequence. + """ max_count = 0 current_count = 0 for value in arr: @@ -124,32 +299,137 @@ def max_consecutive(arr): class DeconvBin: + """ + Binary pursuit deconvolution solver for calcium imaging spike inference. + + Infers spike trains from calcium fluorescence traces by solving a + convex optimization problem with sparsity constraints. Supports multiple + optimization backends and various regularization strategies. + + The forward model is: + y = scale * R @ H @ s + b + + Where y is the observed fluorescence, s is the spike train, H is the + convolution matrix, R is the downsampling matrix, and b is baseline. + + Parameters + ---------- + y : NDArray, optional + Observed fluorescence trace. Either y or y_len must be provided. + y_len : int, optional + Length of fluorescence trace (if y not provided). + theta : NDArray, optional + AR(2) coefficients [θ₁, θ₂]. Either theta or tau must be provided. + tau : NDArray, optional + Time constants [τ_d, τ_r] in frames. + ps : NDArray, optional + Exponential amplitudes [p_d, p_r]. Required if tau is provided. + coef : NDArray, optional + Impulse response kernel. Computed from theta/tau if not provided. + coef_len : int, default=100 + Length of impulse response kernel. + scale : float, default=1 + Initial amplitude scaling factor. + penal : str, default="l1" + Sparsity penalty type: "l0", "l1", or None. + use_base : bool, default=False + If True, include baseline offset in optimization. + upsamp : int, default=1 + Temporal upsampling factor for sub-frame resolution. + norm : str, default="l2" + Error norm: "l1", "l2", or "huber". + mixin : bool, default=False + If True, use mixed-integer formulation (CVXPY only). + backend : str, default="osqp" + Optimization backend: "cvxpy", "osqp", "emosqp", or "cuosqp" (GPU). + nthres : int, default=1000 + Number of threshold levels for binary pursuit. + err_weighting : str, optional + Error weighting scheme: "fft", "corr", "adaptive", or None. + wt_trunc_thres : float, optional + Threshold for truncating error weights. + masking_radius : int, optional + Radius for search region masking. + pks_polish : bool, default=True + If True, refine spike locations after optimization. + th_min : float, default=0 + Minimum threshold as fraction of maximum. + th_max : float, default=1 + Maximum threshold as fraction of maximum. + density_thres : float, optional + Maximum allowed spike density. + ncons_thres : int or "auto", optional + Maximum consecutive spikes allowed. + min_rel_scl : float or "auto", optional + Minimum relative scale for valid solutions. + max_iter_l0 : int, default=30 + Maximum iterations for L0 penalty optimization. + max_iter_penal : int, default=500 + Maximum iterations for penalty search (DIRECT). + max_iter_scal : int, default=50 + Maximum iterations for scale optimization. + delta_l0 : float, default=1e-4 + Threshold for L0 penalty soft-thresholding. + delta_penal : float, default=1e-4 + Tolerance for penalty search. + atol : float, default=1e-3 + Absolute tolerance for convergence. + rtol : float, default=1e-3 + Relative tolerance for convergence. + Hlim : int, default=1e5 + Maximum size for dense H matrix before switching to sparse. + dashboard : Dashboard, optional + Dashboard object for real-time visualization. + dashboard_uid : any, optional + Unique identifier for this solver in the dashboard. + + Attributes + ---------- + s : NDArray + Inferred spike train of shape (T,) or (len(nzidx_s),). + c : NDArray + Inferred calcium trace of shape (T,) or (len(nzidx_c),). + b : float + Inferred baseline offset. + H : sparse matrix + Convolution matrix. + G : sparse matrix + AR relationship matrix. + R : sparse matrix + Downsampling matrix. + + Examples + -------- + >>> solver = DeconvBin(y=fluorescence, tau=[10, 2], ps=[1, -1]) + >>> s, c, scale, err = solver.solve_scale() + """ + def __init__( self, - y: np.array = None, - y_len: int = None, - theta: np.array = None, - tau: np.array = None, - ps: np.array = None, - coef: np.array = None, + y: Optional[NDArray] = None, + y_len: Optional[int] = None, + theta: Optional[NDArray] = None, + tau: Optional[NDArray] = None, + ps: Optional[NDArray] = None, + coef: Optional[NDArray] = None, coef_len: int = 100, scale: float = 1, - penal: str = "l1", + penal: Optional[str] = "l1", use_base: bool = False, upsamp: int = 1, norm: str = "l2", mixin: bool = False, backend: str = "osqp", nthres: int = 1000, - err_weighting: str = None, - wt_trunc_thres: float = None, - masking_radius: int = None, + err_weighting: Optional[str] = None, + wt_trunc_thres: Optional[float] = None, + masking_radius: Optional[int] = None, pks_polish: bool = True, th_min: float = 0, th_max: float = 1, - density_thres: float = None, - ncons_thres: int = None, - min_rel_scl: float = None, + density_thres: Optional[float] = None, + ncons_thres: Optional[Union[int, str]] = None, + min_rel_scl: Optional[Union[float, str]] = None, max_iter_l0: int = 30, max_iter_penal: int = 500, max_iter_scal: int = 50, @@ -158,8 +438,8 @@ def __init__( atol: float = 1e-3, rtol: float = 1e-3, Hlim: int = 1e5, - dashboard=None, - dashboard_uid=None, + dashboard: Optional[Any] = None, + dashboard_uid: Optional[Any] = None, ) -> None: # book-keeping if y is not None: diff --git a/src/indeca/pipeline.py b/src/indeca/pipeline.py index 74c4854..f1689f2 100644 --- a/src/indeca/pipeline.py +++ b/src/indeca/pipeline.py @@ -1,8 +1,22 @@ +""" +Main processing pipeline for InDeCa spike inference. + +This module provides the main entry point for running the InDeCa (Interpretable +Deconvolution for Calcium Imaging) algorithm. It implements an iterative +binary pursuit pipeline that alternates between: +1. Spike inference (deconvolution) using the current kernel estimate +2. Kernel estimation using the inferred spikes + +The algorithm converges when spike patterns stabilize or error criteria are met. +""" + import warnings +from typing import Any, Optional, Tuple, Union import numpy as np import pandas as pd from line_profiler import profile +from numpy.typing import NDArray from scipy.signal import find_peaks, medfilt from tqdm.auto import tqdm, trange @@ -20,53 +34,153 @@ @profile def pipeline_bin( - Y, - up_factor=1, - p=2, - tau_init=None, - return_iter=False, - max_iters=50, - n_best=3, - use_rel_err=True, - err_atol=1e-4, - err_rtol=5e-2, - est_noise_freq=None, - est_use_smooth=False, - est_add_lag=20, - est_nevt=10, - med_wnd=None, - dff=True, - deconv_nthres=1000, - deconv_norm="l2", - deconv_atol=1e-3, - deconv_penal=None, - deconv_backend="osqp", - deconv_err_weighting=None, - deconv_use_base=True, - deconv_reset_scl=True, - deconv_masking_radius=None, - deconv_pks_polish=None, - deconv_ncons_thres=None, - deconv_min_rel_scl=None, - ar_use_all=True, - ar_kn_len=100, - ar_norm="l2", - ar_prop_best=None, - da_client=None, - spawn_dashboard=True, -): - """Binary pursuit pipeline for spike inference. + Y: NDArray, + up_factor: int = 1, + p: int = 2, + tau_init: Optional[NDArray] = None, + return_iter: bool = False, + max_iters: int = 50, + n_best: int = 3, + use_rel_err: bool = True, + err_atol: float = 1e-4, + err_rtol: float = 5e-2, + est_noise_freq: Optional[Tuple[float, float]] = None, + est_use_smooth: bool = False, + est_add_lag: int = 20, + est_nevt: Optional[int] = 10, + med_wnd: Optional[Union[int, str]] = None, + dff: bool = True, + deconv_nthres: int = 1000, + deconv_norm: str = "l2", + deconv_atol: float = 1e-3, + deconv_penal: Optional[str] = None, + deconv_backend: str = "osqp", + deconv_err_weighting: Optional[str] = None, + deconv_use_base: bool = True, + deconv_reset_scl: bool = True, + deconv_masking_radius: Optional[int] = None, + deconv_pks_polish: Optional[bool] = None, + deconv_ncons_thres: Optional[Union[int, str]] = None, + deconv_min_rel_scl: Optional[Union[float, str]] = None, + ar_use_all: bool = True, + ar_kn_len: int = 100, + ar_norm: str = "l2", + ar_prop_best: Optional[float] = None, + da_client: Optional[Any] = None, + spawn_dashboard: bool = True, +) -> Union[ + Tuple[NDArray, NDArray, pd.DataFrame], + Tuple[NDArray, NDArray, pd.DataFrame, list, list, list, list], +]: + """ + Binary pursuit pipeline for calcium imaging spike inference. + + Implements the InDeCa algorithm for inferring spike trains from calcium + fluorescence traces. The algorithm iteratively refines both spike estimates + and calcium dynamics parameters until convergence. Parameters ---------- - Y : array-like - Input fluorescence trace - ... + Y : NDArray + Input fluorescence traces of shape (n_cells, n_timepoints). + up_factor : int, default=1 + Temporal upsampling factor for sub-frame spike resolution. + p : int, default=2 + Order of the AR process (typically 2 for bi-exponential). + tau_init : NDArray, optional + Initial time constants [τ_d, τ_r]. If None, estimated from data. + return_iter : bool, default=False + If True, return intermediate results from all iterations. + max_iters : int, default=50 + Maximum number of iterations. + n_best : int, default=3 + Number of best previous solutions to combine for kernel update. + use_rel_err : bool, default=True + If True, use relative error as optimization objective. + err_atol : float, default=1e-4 + Absolute error tolerance for convergence. + err_rtol : float, default=5e-2 + Relative error tolerance for convergence. + est_noise_freq : tuple of float, optional + Frequency range for noise estimation during initialization. + est_use_smooth : bool, default=False + If True, smooth data before initial AR estimation. + est_add_lag : int, default=20 + Additional lags for Yule-Walker AR estimation. + est_nevt : int, optional, default=10 + Number of top events to use for kernel update. None uses all. + med_wnd : int or "auto", optional + Median filter window size for preprocessing. "auto" uses ar_kn_len. + dff : bool, default=True + If True, compute ΔF/F₀ preprocessing. + deconv_nthres : int, default=1000 + Number of threshold levels for binary pursuit. + deconv_norm : str, default="l2" + Error norm for deconvolution: "l1", "l2", or "huber". + deconv_atol : float, default=1e-3 + Absolute tolerance for deconvolution solver. + deconv_penal : str, optional + Sparsity penalty type: "l0", "l1", or None. + deconv_backend : str, default="osqp" + Optimization backend: "cvxpy", "osqp", "emosqp", or "cuosqp". + deconv_err_weighting : str, optional + Error weighting scheme: "fft", "corr", "adaptive", or None. + deconv_use_base : bool, default=True + If True, estimate baseline offset. + deconv_reset_scl : bool, default=True + If True, reset scale each iteration. + deconv_masking_radius : int, optional + Radius for search region masking. + deconv_pks_polish : bool, optional + If True, refine spike locations after optimization. + deconv_ncons_thres : int or "auto", optional + Maximum consecutive spikes allowed. + deconv_min_rel_scl : float or "auto", optional + Minimum relative scale for valid solutions. + ar_use_all : bool, default=True + If True, use all cells for shared kernel estimation. + ar_kn_len : int, default=100 + Kernel length in frames. + ar_norm : str, default="l2" + Error norm for kernel estimation. + ar_prop_best : float, optional + Proportion of best cells to use for kernel update. + da_client : distributed.Client, optional + Dask client for distributed computation. + spawn_dashboard : bool, default=True + If True, create interactive visualization dashboard. Returns ------- - dict - Dictionary containing results of the pipeline + opt_C : NDArray + Optimal calcium traces of shape (n_cells, T * up_factor). + opt_S : NDArray + Optimal spike trains of shape (n_cells, T * up_factor). + metric_df : pd.DataFrame + Iteration metrics including errors, time constants, and scales. + + If return_iter=True, also returns: + C_ls : list of NDArray + Calcium traces from each iteration. + S_ls : list of NDArray + Spike trains from each iteration. + h_ls : list of NDArray + Kernels from each iteration. + h_fit_ls : list of NDArray + Fitted kernels from each iteration. + + Notes + ----- + The algorithm converges when any of these conditions are met: + - Absolute error change < err_atol + - Relative error change < err_rtol * best_error + - Spike pattern unchanged from previous iteration + - Solution trapped in local optimum + + Examples + -------- + >>> C, S, metrics = pipeline_bin(fluorescence_data, up_factor=4, ar_kn_len=60) + >>> # S contains inferred spike trains at 4x temporal resolution """ logger.info("Starting binary pursuit pipeline") # 0. housekeeping diff --git a/src/indeca/simulation.py b/src/indeca/simulation.py index 2a056b0..6399409 100644 --- a/src/indeca/simulation.py +++ b/src/indeca/simulation.py @@ -1,5 +1,18 @@ -# %% import and definitions +""" +Simulation utilities for generating synthetic calcium imaging data. + +This module provides functions for simulating calcium imaging data including: +- Spike train generation using Markov chain models +- Calcium dynamics using autoregressive (AR) or bi-exponential kernels +- Spatial footprint generation for simulated neurons +- Full video simulation with motion artifacts and background signals + +The simulations are used for algorithm validation and benchmarking, allowing +comparison of inferred spikes against known ground truth. +""" + import warnings +from typing import Optional, Tuple, Union import dask.array as darr import numpy as np @@ -7,6 +20,8 @@ import sparse import xarray as xr from numpy import random +from numpy.random import Generator +from numpy.typing import NDArray from scipy.ndimage import gaussian_filter1d from scipy.optimize import root_scalar from scipy.stats import multivariate_normal @@ -19,9 +34,39 @@ def gauss_cell( sz_mean: float, sz_sigma: float, sz_min: float, - cent=None, - norm=True, -): + cent: Optional[NDArray] = None, + norm: bool = True, +) -> NDArray: + """ + Generate 2D Gaussian spatial footprints for simulated neurons. + + Creates spatial footprints by placing 2D Gaussian distributions at specified + or random locations. Each neuron's footprint is independently sized based on + the provided size distribution parameters. + + Parameters + ---------- + height : int + Height of the spatial field in pixels. + width : int + Width of the spatial field in pixels. + sz_mean : float + Mean size (variance) of the Gaussian footprints. + sz_sigma : float + Standard deviation of the size distribution. + sz_min : float + Minimum allowed size (variance) for footprints. + cent : NDArray, optional + Centroids of shape (n_cells, 2) specifying [row, col] positions. + If None, centroids are randomly generated. + norm : bool, default=True + If True, normalize each footprint to [0, 1] range. + + Returns + ------- + NDArray + Spatial footprints of shape (n_cells, height, width). + """ # generate centroid if cent is None: cent = np.atleast_2d([random.randint(height), random.randint(width)]) @@ -44,8 +89,37 @@ def gauss_cell( return A -# @nb.jit(nopython=True, nogil=True, cache=True) -def apply_arcoef(s: np.ndarray, g: np.ndarray, shifted: bool = False): +def apply_arcoef(s: NDArray, g: NDArray, shifted: bool = False) -> NDArray: + """ + Apply AR(2) coefficients to a spike train to generate calcium dynamics. + + Implements the autoregressive relationship: + c[t] = s[t] + g[0] * c[t-1] + g[1] * c[t-2] + + This models calcium indicator dynamics where calcium concentration at each + time point depends on the current spike and previous calcium values. + + Parameters + ---------- + s : NDArray + Spike train of shape (n_timepoints,). Can be binary (0/1) or continuous. + g : NDArray + AR(2) coefficients of shape (2,), where g[0] = γ₁ and g[1] = γ₂. + These determine the decay characteristics of the calcium response. + shifted : bool, default=False + If True, use spike from previous time point (s[t-1]) instead of s[t]. + This models a delay between spike and calcium response. + + Returns + ------- + NDArray + Calcium trace of shape (n_timepoints,). + + See Also + -------- + tau2AR : Convert time constants to AR coefficients. + apply_exp : Apply bi-exponential kernel via convolution. + """ c = np.zeros(len(s), dtype=float) for i in range(len(s)): if shifted: @@ -65,14 +139,56 @@ def apply_arcoef(s: np.ndarray, g: np.ndarray, shifted: bool = False): def apply_exp( - s: np.ndarray, + s: NDArray, tau_d: float, tau_r: float, p_d: float = 1, p_r: float = -1, - kn_len: int = None, - trunc_thres: float = None, -): + kn_len: Optional[int] = None, + trunc_thres: Optional[float] = None, +) -> NDArray: + """ + Apply bi-exponential kernel to a spike train via convolution. + + Convolves the spike train with a bi-exponential kernel of the form: + h(t) = p_d * exp(-t/τ_d) + p_r * exp(-t/τ_r) + + This models calcium indicator dynamics with distinct rise and decay phases. + + Parameters + ---------- + s : NDArray + Spike train of shape (n_timepoints,). + tau_d : float + Decay time constant in frames. Must be positive and > tau_r. + tau_r : float + Rise time constant in frames. Must be positive and < tau_d. + p_d : float, default=1 + Amplitude coefficient for decay component. + p_r : float, default=-1 + Amplitude coefficient for rise component. Typically negative to create + the characteristic rising phase. + kn_len : int, optional + Length of the kernel. If None, uses len(s). + trunc_thres : float, optional + Truncate kernel when amplitude falls below this threshold. + Improves computational efficiency for long traces. + + Returns + ------- + NDArray + Calcium trace of shape (n_timepoints,). + + Raises + ------ + ValueError + If tau_d is not positive. + + See Also + -------- + apply_arcoef : Apply AR coefficients directly. + tau2AR : Convert time constants to AR coefficients. + """ if kn_len is None: kn_len = len(s) t = np.arange(kn_len).astype(float) @@ -94,13 +210,49 @@ def apply_exp( def ar_trace( frame: int, - P: np.ndarray, - g: np.ndarray = None, - tau_d: float = None, - tau_r: float = None, + P: NDArray, + g: Optional[NDArray] = None, + tau_d: Optional[float] = None, + tau_r: Optional[float] = None, shifted: bool = False, - rng=None, -): + rng: Optional[Generator] = None, +) -> Tuple[NDArray, NDArray]: + """ + Generate a calcium trace with Markovian spike train using AR dynamics. + + Generates a spike train using a 2-state Markov chain and applies AR(2) + dynamics to produce a calcium trace. + + Parameters + ---------- + frame : int + Number of time frames to simulate. + P : NDArray + Markov transition matrix of shape (2, 2). P[i, j] is the probability + of transitioning from state i to state j. + g : NDArray, optional + AR(2) coefficients of shape (2,). If None, computed from tau_d and tau_r. + tau_d : float, optional + Decay time constant. Required if g is None. + tau_r : float, optional + Rise time constant. Required if g is None. + shifted : bool, default=False + If True, apply one-frame delay between spike and calcium response. + rng : Generator, optional + NumPy random generator for reproducibility. + + Returns + ------- + C : NDArray + Calcium trace of shape (frame,). + S : NDArray + Binary spike train of shape (frame,). + + See Also + -------- + exp_trace : Generate trace using bi-exponential convolution. + markov_fire : Generate Markovian spike train. + """ if g is None: g = np.array(tau2AR(tau_d, tau_r)) S = markov_fire(frame, P, rng=rng).astype(float) @@ -108,7 +260,39 @@ def ar_trace( return C, S -def exp_trace(frame: int, P: np.ndarray, tau_d: float, tau_r: float, trunc_thres=1e-6): +def exp_trace( + frame: int, P: NDArray, tau_d: float, tau_r: float, trunc_thres: float = 1e-6 +) -> Tuple[NDArray, NDArray]: + """ + Generate a calcium trace with Markovian spike train using bi-exponential kernel. + + Uses a 2-state Markov model to generate bursty spike trains, then convolves + with a bi-exponential kernel to produce realistic calcium dynamics. + + Parameters + ---------- + frame : int + Number of time frames to simulate. + P : NDArray + Markov transition matrix of shape (2, 2). + tau_d : float + Decay time constant in frames. + tau_r : float + Rise time constant in frames. + trunc_thres : float, default=1e-6 + Truncate kernel when amplitude falls below this threshold. + + Returns + ------- + C : NDArray + Calcium trace of shape (frame,). + S : NDArray + Binary spike train of shape (frame,). + + See Also + -------- + ar_trace : Generate trace using AR dynamics. + """ # uses a 2 state markov model to generate more 'bursty' spike trains S = markov_fire(frame, P).astype(float) t = np.arange(0, frame) @@ -121,7 +305,39 @@ def exp_trace(frame: int, P: np.ndarray, tau_d: float, tau_r: float, trunc_thres return C, S -def markov_fire(frame: int, P: np.ndarray, rng=None): +def markov_fire( + frame: int, P: NDArray, rng: Optional[Generator] = None +) -> NDArray: + """ + Generate a binary spike train using a 2-state Markov chain. + + Simulates neural firing as a two-state process where the transition + probabilities determine burst characteristics. Ensures at least one + spike is generated. + + Parameters + ---------- + frame : int + Number of time frames to simulate. + P : NDArray + Markov transition matrix of shape (2, 2). P[0, 1] controls the + probability of starting a spike from quiescence, and P[1, 1] + controls burst continuation probability. + + rng : Generator, optional + NumPy random generator for reproducibility. If None, creates a + new default generator. + + Returns + ------- + NDArray + Binary spike train of shape (frame,) with dtype int. + + Raises + ------ + AssertionError + If P is not shape (2, 2) or rows don't sum to 1. + """ if rng is None: rng = np.random.default_rng() # makes sure markov probabilities are correct shape @@ -140,15 +356,45 @@ def markov_fire(frame: int, P: np.ndarray, rng=None): def random_walk( - n_stp, + n_stp: int, stp_var: float = 1, constrain_factor: float = 0, - ndim=1, - norm=False, - integer=True, - nn=False, - smooth_var=None, -): + ndim: int = 1, + norm: bool = False, + integer: bool = True, + nn: bool = False, + smooth_var: Optional[float] = None, +) -> NDArray: + """ + Generate a random walk with optional constraints and smoothing. + + Used for simulating motion artifacts and background signal fluctuations. + + Parameters + ---------- + n_stp : int + Number of time steps. + stp_var : float, default=1 + Variance of step sizes (standard deviation of Gaussian steps). + constrain_factor : float, default=0 + Mean-reversion strength. If > 0, steps are biased toward origin. + Higher values produce more constrained walks. + ndim : int, default=1 + Number of dimensions for the walk. + norm : bool, default=False + If True, normalize output to [0, 1] range per dimension. + integer : bool, default=True + If True, round walk values to integers. + nn : bool, default=False + If True, clip negative values to zero (non-negative). + smooth_var : float, optional + If provided, apply Gaussian smoothing with this sigma. + + Returns + ------- + NDArray + Random walk of shape (n_stp, ndim). + """ if constrain_factor > 0: walk = np.zeros(shape=(n_stp, ndim)) for i in range(n_stp): @@ -179,13 +425,46 @@ def random_walk( def simulate_traces( num_cells: int, length_in_sec: float, - tmp_P: np.ndarray, + tmp_P: NDArray, tmp_tau_d: float, tmp_tau_r: float, approx_fps: float = 30, - spike_sampling_rate=500, + spike_sampling_rate: int = 500, noise: float = 0.01, -): +) -> pd.DataFrame: + """ + Simulate calcium traces for multiple cells with configurable parameters. + + Parameters + ---------- + num_cells : int + Number of cells to simulate. + length_in_sec : float + Duration of simulation in seconds. + tmp_P : NDArray + Markov transition matrix of shape (2, 2) for spike generation. + tmp_tau_d : float + Decay time constant in seconds. + tmp_tau_r : float + Rise time constant in seconds. + approx_fps : float, default=30 + Approximate frames per second for the output. + spike_sampling_rate : int, default=500 + Internal sampling rate for spike generation in Hz. + noise : float, default=0.01 + Standard deviation of additive Gaussian noise. + + Returns + ------- + pd.DataFrame + DataFrame with columns: C_true, S_true, C, S, C_noisy, fps, + upsample_factor, spike_sampling_rate. + + Notes + ----- + This function is marked for future integration with exp_trace and the + rest of the simulation pipeline. + """ # TODO: make this compatible with exp_trace and incorporate this with rest # of the simulation pipeline upsample_factor = np.round(spike_sampling_rate / approx_fps).astype(int) @@ -220,7 +499,7 @@ def simulate_data( sz_mean: float, sz_sigma: float, sz_min: float, - tmp_P: np.ndarray, + tmp_P: NDArray, tmp_tau_d: float, tmp_tau_r: float, post_offset: float, @@ -231,11 +510,88 @@ def simulate_data( bg_smth_var: float, mo_stp_var: float, mo_cons_fac: float = 1, - cent=None, - zero_thres=1e-8, - chk_size=1000, + cent: Optional[NDArray] = None, + zero_thres: float = 1e-8, + chk_size: int = 1000, upsample: int = 1, -): +) -> Union[ + Tuple[xr.DataArray, xr.DataArray, xr.DataArray, xr.DataArray, xr.DataArray], + Tuple[ + xr.DataArray, + xr.DataArray, + xr.DataArray, + xr.DataArray, + xr.DataArray, + xr.DataArray, + xr.DataArray, + ], +]: + """ + Simulate complete calcium imaging data including video, spatial footprints, and signals. + + Generates synthetic calcium imaging data with realistic characteristics including + spatial footprints, temporal dynamics, background fluctuations, motion artifacts, + and noise. + + Parameters + ---------- + ncell : int + Number of cells to simulate. + dims : dict + Dictionary with keys 'frame', 'height', 'width' specifying video dimensions. + sig_scale : float + Signal amplitude scaling factor. + sz_mean : float + Mean size of cell spatial footprints. + sz_sigma : float + Standard deviation of cell sizes. + sz_min : float + Minimum cell size. + tmp_P : NDArray + Markov transition matrix of shape (2, 2) for spike generation. + tmp_tau_d : float + Decay time constant in frames. + tmp_tau_r : float + Rise time constant in frames. + post_offset : float + Baseline offset added to video. + post_gain : float + Gain factor applied to video (for converting to uint8). + bg_nsrc : int + Number of background signal sources. + bg_tmp_var : float + Temporal variance of background signals. + bg_cons_fac : float + Constraint factor for background temporal dynamics. + bg_smth_var : float + Smoothing variance for background signals. + mo_stp_var : float + Step variance for motion simulation. + mo_cons_fac : float, default=1 + Constraint factor for motion. + cent : NDArray, optional + Predefined cell centroids of shape (ncell, 2). + zero_thres : float, default=1e-8 + Threshold below which spatial footprint values are set to zero. + chk_size : int, default=1000 + Chunk size for Dask arrays. + upsample : int, default=1 + Temporal upsampling factor for higher resolution spike timing. + + Returns + ------- + tuple + If upsample == 1: (Y, A, C, S, shifts) + If upsample > 1: (Y, A, C, S, C_true, S_true, shifts) + + Where: + - Y: Video data (frame, height, width) + - A: Spatial footprints (unit_id, height, width) + - C: Calcium traces (frame, unit_id) + - S: Spike trains (frame, unit_id) + - C_true, S_true: High-resolution versions when upsampled + - shifts: Motion shifts (frame, shift_dim) + """ ff, hh, ww = ( dims["frame"], dims["height"], @@ -405,7 +761,27 @@ def simulate_data( return Y, A, C, S, shifts -def generate_data(dpath, save_Y=False, **kwargs): +def generate_data(dpath: str, save_Y: bool = False, **kwargs) -> xr.Dataset: + """ + Generate and save simulated calcium imaging data to a NetCDF file. + + Wrapper around simulate_data that saves the results to disk. + + Parameters + ---------- + dpath : str + Path to save the NetCDF file. + save_Y : bool, default=False + If True, include the video data Y in the saved dataset. + Video data can be large, so it's excluded by default. + **kwargs + Additional arguments passed to simulate_data. + + Returns + ------- + xr.Dataset + The merged dataset containing all simulation outputs. + """ dat_vars = simulate_data(**kwargs) if not save_Y: dat_vars = dat_vars[1:] @@ -414,7 +790,49 @@ def generate_data(dpath, save_Y=False, **kwargs): return ds -def computeY(A, C, A_bg, C_bg, shifts, sig_scale, noise_scale, post_offset, post_gain): +def computeY( + A: NDArray, + C: NDArray, + A_bg: NDArray, + C_bg: NDArray, + shifts: NDArray, + sig_scale: float, + noise_scale: float, + post_offset: float, + post_gain: float, +) -> NDArray: + """ + Compute fluorescence video from spatial and temporal components. + + Combines cell signals, background signals, motion shifts, and noise to + generate a realistic calcium imaging video. Used as a Dask blockwise function. + + Parameters + ---------- + A : NDArray + Cell spatial footprints of shape (n_cells, height, width). + C : NDArray + Cell temporal signals of shape (n_frames, n_cells). + A_bg : NDArray + Background spatial footprints of shape (n_bg, height, width). + C_bg : NDArray + Background temporal signals of shape (n_frames, n_bg). + shifts : NDArray + Motion shifts of shape (n_frames, 2) for [height, width] shifts. + sig_scale : float + Signal amplitude scaling factor. + noise_scale : float + Standard deviation of additive Gaussian noise. + post_offset : float + Baseline offset added after scaling. + post_gain : float + Gain factor for converting to uint8 range. + + Returns + ------- + NDArray + Video data of shape (n_frames, height, width) with dtype uint8. + """ A, C, A_bg, C_bg, shifts = A[0], C[0], A_bg[0], C_bg[0], shifts[0] Y = sparse.tensordot(C, A, axes=1) Y *= sig_scale @@ -432,7 +850,43 @@ def computeY(A, C, A_bg, C_bg, shifts, sig_scale, noise_scale, post_offset, post return Y.astype(np.uint8) -def tau2AR(tau_d, tau_r, p=1, return_scl=False): +def tau2AR( + tau_d: float, tau_r: float, p: float = 1, return_scl: bool = False +) -> Union[Tuple[float, float], Tuple[float, float, float]]: + """ + Convert bi-exponential time constants to AR(2) coefficients. + + Transforms decay and rise time constants (τ_d, τ_r) into autoregressive + coefficients (θ₁, θ₂) that produce equivalent dynamics. + + The relationship is: + z₁ = exp(-1/τ_d), z₂ = exp(-1/τ_r) + θ₁ = z₁ + z₂, θ₂ = -z₁ * z₂ + + Parameters + ---------- + tau_d : float + Decay time constant in frames. + tau_r : float + Rise time constant in frames. + p : float, default=1 + Amplitude scaling factor for the bi-exponential. + return_scl : bool, default=False + If True, also return the scaling factor. + + Returns + ------- + theta0 : float + First AR coefficient (γ₁). + theta1 : float + Second AR coefficient (γ₂). + scl : float, optional + Scaling factor, returned if return_scl=True. + + See Also + -------- + AR2tau : Inverse conversion from AR coefficients to time constants. + """ z1, z2 = np.exp(-1 / tau_d), np.exp(-1 / tau_r) theta0, theta1 = np.real(z1 + z2), np.real(-z1 * z2) if theta1 == 0: @@ -447,7 +901,43 @@ def tau2AR(tau_d, tau_r, p=1, return_scl=False): return theta0, theta1 -def AR2tau(theta1, theta2, solve_amp: bool = False): +def AR2tau( + theta1: float, theta2: float, solve_amp: bool = False +) -> Union[Tuple[float, float], Tuple[float, float, float]]: + """ + Convert AR(2) coefficients to bi-exponential time constants. + + Inverse of tau2AR. Finds the roots of the characteristic polynomial + and converts to time constants. + + Parameters + ---------- + theta1 : float + First AR coefficient (γ₁). + theta2 : float + Second AR coefficient (γ₂). + solve_amp : bool, default=False + If True, also compute and return the amplitude scaling factor. + + Returns + ------- + tau_d : float + Decay time constant in frames. May be complex if AR process is oscillatory. + tau_r : float + Rise time constant in frames. May be complex if AR process is oscillatory. + p : float, optional + Amplitude scaling factor, returned if solve_amp=True. + + Notes + ----- + If the AR coefficients correspond to an oscillatory (underdamped) system, + the returned time constants will be complex numbers. + + See Also + -------- + tau2AR : Forward conversion from time constants to AR coefficients. + AR2exp : Full conversion including amplitude coefficients. + """ rts = np.roots([1, -theta1, -theta2]) z1, z2 = rts if np.imag(z1) == 0 and np.isclose(z1, 0) and z1 < 0: @@ -462,13 +952,67 @@ def AR2tau(theta1, theta2, solve_amp: bool = False): return tau_d, tau_r -def solve_p(tau_d, tau_r): +def solve_p(tau_d: float, tau_r: float) -> float: + """ + Compute amplitude scaling factor for bi-exponential kernel. + + Calculates the scaling factor p such that the bi-exponential kernel + h(t) = p * (exp(-t/τ_d) - exp(-t/τ_r)) integrates properly with the + AR representation. + + Parameters + ---------- + tau_d : float + Decay time constant in frames. + tau_r : float + Rise time constant in frames. + + Returns + ------- + float + Amplitude scaling factor. + + Raises + ------ + AssertionError + If the result is NaN or infinite. + """ p = 1 / (np.exp(-1 / tau_d) - np.exp(-1 / tau_r)) assert not (np.isnan(p) or np.isinf(p)) return p -def AR2exp(theta1, theta2): +def AR2exp( + theta1: float, theta2: float +) -> Tuple[bool, NDArray, NDArray]: + """ + Convert AR(2) coefficients to exponential representation with coefficients. + + Determines whether the AR process corresponds to real bi-exponential + dynamics or complex (oscillatory) dynamics, and returns the appropriate + parameters. + + Parameters + ---------- + theta1 : float + First AR coefficient (γ₁). + theta2 : float + Second AR coefficient (γ₂). + + Returns + ------- + is_biexp : bool + True if the dynamics are real bi-exponential, False if oscillatory. + tconst : NDArray + Time constants of shape (2,). If is_biexp=True, contains [τ_d, τ_r]. + If is_biexp=False, contains [a, b] for exp(at) * (cos(bt) + sin(bt)). + coef : NDArray + Amplitude coefficients of shape (2,) for the exponential terms. + + See Also + -------- + eval_exp : Evaluate the exponential representation at given times. + """ tau_d, tau_r = AR2tau(theta1, theta2) if np.imag(tau_d) == 0 and np.imag(tau_r) == 0: # real exponentials L = np.array([[1, 1], [np.exp(-1 / tau_d), np.exp(-1 / tau_r)]]) @@ -483,28 +1027,140 @@ def AR2exp(theta1, theta2): return False, np.array([a, b]), coef -def generate_pulse(nsamp): +def generate_pulse(nsamp: int) -> Tuple[NDArray, NDArray]: + """ + Generate a unit impulse (delta function) for kernel analysis. + + Parameters + ---------- + nsamp : int + Number of samples. + + Returns + ------- + pulse : NDArray + Impulse signal of shape (nsamp,) with pulse[0]=1 and zeros elsewhere. + t : NDArray + Time indices of shape (nsamp,). + """ t = np.arange(nsamp).astype(float) pulse = np.zeros_like(t) pulse[0] = 1 return pulse, t -def ar_pulse(theta1, theta2, nsamp, shifted=False): +def ar_pulse( + theta1: float, theta2: float, nsamp: int, shifted: bool = False +) -> Tuple[NDArray, NDArray, NDArray]: + """ + Compute the impulse response of an AR(2) process. + + Parameters + ---------- + theta1 : float + First AR coefficient (γ₁). + theta2 : float + Second AR coefficient (γ₂). + nsamp : int + Number of samples for the response. + shifted : bool, default=False + If True, apply one-sample delay. + + Returns + ------- + ar : NDArray + Impulse response of shape (nsamp,). + t : NDArray + Time indices of shape (nsamp,). + pulse : NDArray + Input impulse of shape (nsamp,). + + See Also + -------- + exp_pulse : Impulse response using bi-exponential convolution. + """ pulse, t = generate_pulse(nsamp) ar = apply_arcoef(pulse, np.array([theta1, theta2]), shifted=shifted) return ar, t, pulse def exp_pulse( - tau_d, tau_r, nsamp, p_d=1, p_r=-1, kn_len: int = None, trunc_thres: float = None -): + tau_d: float, + tau_r: float, + nsamp: int, + p_d: float = 1, + p_r: float = -1, + kn_len: Optional[int] = None, + trunc_thres: Optional[float] = None, +) -> Tuple[NDArray, NDArray, NDArray]: + """ + Compute the impulse response using bi-exponential convolution. + + Parameters + ---------- + tau_d : float + Decay time constant in frames. + tau_r : float + Rise time constant in frames. + nsamp : int + Number of samples for the response. + p_d : float, default=1 + Decay amplitude coefficient. + p_r : float, default=-1 + Rise amplitude coefficient. + kn_len : int, optional + Kernel length for convolution. + trunc_thres : float, optional + Threshold for kernel truncation. + + Returns + ------- + exp : NDArray + Impulse response of shape (nsamp,). + t : NDArray + Time indices of shape (nsamp,). + pulse : NDArray + Input impulse of shape (nsamp,). + + See Also + -------- + ar_pulse : Impulse response using AR coefficients. + """ pulse, t = generate_pulse(nsamp) exp = apply_exp(pulse, tau_d, tau_r, p_d, p_r, kn_len, trunc_thres) return exp, t, pulse -def eval_exp(t, is_biexp, tconst, coefs): +def eval_exp( + t: NDArray, is_biexp: bool, tconst: NDArray, coefs: NDArray +) -> NDArray: + """ + Evaluate exponential response at given time points. + + Computes the value of an exponential kernel (either bi-exponential or + oscillatory) at specified times. + + Parameters + ---------- + t : NDArray + Time points at which to evaluate. + is_biexp : bool + If True, use bi-exponential form. If False, use oscillatory form. + tconst : NDArray + Time constants of shape (2,). For bi-exponential: [τ_d, τ_r]. + For oscillatory: [a, b] where response is exp(at) * (c1*cos(bt) + c2*sin(bt)). + coefs : NDArray + Amplitude coefficients of shape (2,). + + Returns + ------- + NDArray + Evaluated response values at each time point. + + See Also + -------- + AR2exp : Convert AR coefficients to exponential parameters. + """ if is_biexp: tau_d, tau_r = tconst c1, c2 = coefs @@ -518,7 +1174,47 @@ def eval_exp(t, is_biexp, tconst, coefs): return np.exp(a * t) * (c1 * np.cos(b * t) + c2 * np.sin(b * t)) -def find_dhm(is_biexp, tconst, coefs, verbose=False): +def find_dhm( + is_biexp: bool, tconst: NDArray, coefs: NDArray, verbose: bool = False +) -> Tuple[Tuple[float, float], float]: + """ + Find Distance to Half Maximum (DHM) metrics for a calcium kernel. + + Computes temporal metrics that characterize kernel dynamics: + - DHM_r: Time to rise from baseline to half-maximum + - DHM_d: Time to decay from peak to half-maximum + + These metrics are robust to oscillatory tails and provide interpretable + measures of calcium indicator dynamics. + + Parameters + ---------- + is_biexp : bool + If True, kernel is bi-exponential. If False, it's oscillatory. + tconst : NDArray + Time constants of shape (2,). See eval_exp for interpretation. + coefs : NDArray + Amplitude coefficients of shape (2,). + verbose : bool, default=False + If True, print intermediate values for debugging. + + Returns + ------- + dhm : Tuple[float, float] + (DHM_r, DHM_d) - rise and decay half-max times. + t_peak : float + Time of peak amplitude. + + Raises + ------ + AssertionError + If root finding does not converge. + + Notes + ----- + DHM metrics are computed based on the first threshold-crossing in each + direction, making them robust to oscillatory behavior in the kernel tail. + """ if is_biexp: tau_d, tau_r = tconst c1, c2 = coefs @@ -567,7 +1263,29 @@ def find_dhm(is_biexp, tconst, coefs, verbose=False): return (rt0.root, rt1.root), t_hat -def shift_frame(fm, sh, fill=np.nan): +def shift_frame( + fm: NDArray, sh: NDArray, fill: float = np.nan +) -> NDArray: + """ + Shift a frame by integer offsets and fill edges. + + Applies circular shift (roll) to a frame and fills the vacated edges + with a specified value. Used for simulating motion artifacts. + + Parameters + ---------- + fm : NDArray + Frame to shift, can be 2D or 3D. + sh : NDArray + Shift values for each dimension, shape (ndim,). + fill : float, default=np.nan + Value to fill vacated edges. + + Returns + ------- + NDArray + Shifted frame with same shape as input. + """ if np.isnan(fm).all(): return fm sh = np.around(sh).astype(int) diff --git a/src/indeca/utils.py b/src/indeca/utils.py index f7565d6..e5240e8 100644 --- a/src/indeca/utils.py +++ b/src/indeca/utils.py @@ -1,10 +1,40 @@ +""" +Utility functions for signal processing and array manipulation. + +This module provides helper functions for normalization, scaling, least squares +fitting, and fluorescence signal preprocessing used throughout the InDeCa package. +""" + import itertools as itt +from typing import Generator, Iterable, Tuple import numpy as np import pandas as pd +from numpy.typing import ArrayLike, NDArray + + +def norm(a: ArrayLike) -> NDArray: + """ + Normalize an array to the range [0, 1] using min-max normalization. + + Handles the case where all values are equal by returning zeros. + + Parameters + ---------- + a : ArrayLike + Input array to normalize. Can contain NaN values. + Returns + ------- + NDArray + Normalized array with values in [0, 1]. If all input values are + equal, returns an array of zeros with the same shape. -def norm(a): + Examples + -------- + >>> norm(np.array([1, 2, 3, 4, 5])) + array([0. , 0.25, 0.5 , 0.75, 1. ]) + """ amin, amax = np.nanmin(a), np.nanmax(a) diff = amax - amin if diff > 0: @@ -13,7 +43,37 @@ def norm(a): return a - amin -def scal_lstsq(a, b, fit_intercept=False): +def scal_lstsq( + a: NDArray, b: NDArray, fit_intercept: bool = False +) -> NDArray: + """ + Solve a least squares scaling problem to find coefficients. + + Finds coefficients that minimize ||a @ coef - b||_2. + + Parameters + ---------- + a : NDArray + Design matrix of shape (n_samples,) or (n_samples, n_features). + If 1D, will be reshaped to (n_samples, 1). + b : NDArray + Target vector of shape (n_samples,) or (n_samples, 1). + fit_intercept : bool, default=False + If True, adds a column of ones to ``a`` to fit an intercept term. + + Returns + ------- + NDArray + Solution coefficients. Shape is (n_features,) or (n_features + 1,) + if ``fit_intercept=True``, where the last element is the intercept. + + Examples + -------- + >>> a = np.array([1, 2, 3, 4]) + >>> b = np.array([2, 4, 6, 8]) + >>> scal_lstsq(a, b) + array([2.]) + """ if a.ndim == 1: a = a.reshape((-1, 1)) if fit_intercept: @@ -21,7 +81,32 @@ def scal_lstsq(a, b, fit_intercept=False): return np.linalg.lstsq(a, b.squeeze(), rcond=None)[0] -def scal_like(src: np.ndarray, tgt: np.ndarray, zero_center=True): +def scal_like(src: NDArray, tgt: NDArray, zero_center: bool = True) -> NDArray: + """ + Scale the source array to match the range of the target array. + + Parameters + ---------- + src : NDArray + Source array to be scaled. + tgt : NDArray + Target array whose range is used for scaling. + zero_center : bool, default=True + If True, scales only by the range ratio (preserves zero). + If False, also shifts to match the target's minimum. + + Returns + ------- + NDArray + Scaled array with the same shape as ``src``. + + Examples + -------- + >>> src = np.array([0, 1, 2]) + >>> tgt = np.array([0, 10, 20]) + >>> scal_like(src, tgt, zero_center=True) + array([ 0., 10., 20.]) + """ smin, smax = np.nanmin(src), np.nanmax(src) tmin, tmax = np.nanmin(tgt), np.nanmax(tgt) if zero_center: @@ -30,11 +115,65 @@ def scal_like(src: np.ndarray, tgt: np.ndarray, zero_center=True): return (src - smin) / (smax - smin) * (tmax - tmin) + tmin -def enumerated_product(*args): +def enumerated_product( + *args: Iterable, +) -> Generator[Tuple[Tuple[int, ...], Tuple], None, None]: + """ + Generate the Cartesian product of iterables with their indices. + + Yields tuples containing both the indices and values from the + Cartesian product of the input iterables. + + Parameters + ---------- + *args : Iterable + Variable number of iterables to compute the product over. + + Yields + ------ + Tuple[Tuple[int, ...], Tuple] + A tuple of (indices, values) where indices is a tuple of integer + indices into each input iterable, and values is a tuple of the + corresponding elements. + + Examples + -------- + >>> list(enumerated_product(['a', 'b'], [1, 2])) + [((0, 0), ('a', 1)), ((0, 1), ('a', 2)), ((1, 0), ('b', 1)), ((1, 1), ('b', 2))] + """ yield from zip(itt.product(*(range(len(x)) for x in args)), itt.product(*args)) -def compute_dff(s, window_size=100, q=0.10): +def compute_dff( + s: ArrayLike, window_size: int = 100, q: float = 0.10 +) -> NDArray: + """ + Compute ΔF/F₀ (change in fluorescence) for a calcium signal. + + Estimates baseline fluorescence F₀ using a rolling quantile and computes + the difference from the signal. This is a common preprocessing step for + calcium imaging data to remove baseline drift. + + Parameters + ---------- + s : ArrayLike + Raw fluorescence signal, shape (n_timepoints,). + window_size : int, default=100 + Size of the rolling window for baseline estimation in frames. + q : float, default=0.10 + Quantile to use for baseline estimation (0 to 1). Lower values + are more robust to transient calcium events. + + Returns + ------- + NDArray + Baseline-subtracted fluorescence signal (F - F₀), same shape as input. + + Notes + ----- + This implementation returns F - F₀ rather than (F - F₀) / F₀ to avoid + division by small baseline values which can amplify noise. + """ s = pd.Series(s).astype(float) f0 = s.rolling(window=window_size, min_periods=1).quantile(q) # dff = (s - f0) / f0 From 9b351677abe222d413613f235220e0eb394dc010 Mon Sep 17 00:00:00 2001 From: phildong Date: Wed, 31 Dec 2025 23:31:27 +0000 Subject: [PATCH 2/2] style: format code with black --- src/indeca/simulation.py | 16 ++++------------ src/indeca/utils.py | 8 ++------ 2 files changed, 6 insertions(+), 18 deletions(-) diff --git a/src/indeca/simulation.py b/src/indeca/simulation.py index 6399409..266058d 100644 --- a/src/indeca/simulation.py +++ b/src/indeca/simulation.py @@ -305,9 +305,7 @@ def exp_trace( return C, S -def markov_fire( - frame: int, P: NDArray, rng: Optional[Generator] = None -) -> NDArray: +def markov_fire(frame: int, P: NDArray, rng: Optional[Generator] = None) -> NDArray: """ Generate a binary spike train using a 2-state Markov chain. @@ -982,9 +980,7 @@ def solve_p(tau_d: float, tau_r: float) -> float: return p -def AR2exp( - theta1: float, theta2: float -) -> Tuple[bool, NDArray, NDArray]: +def AR2exp(theta1: float, theta2: float) -> Tuple[bool, NDArray, NDArray]: """ Convert AR(2) coefficients to exponential representation with coefficients. @@ -1131,9 +1127,7 @@ def exp_pulse( return exp, t, pulse -def eval_exp( - t: NDArray, is_biexp: bool, tconst: NDArray, coefs: NDArray -) -> NDArray: +def eval_exp(t: NDArray, is_biexp: bool, tconst: NDArray, coefs: NDArray) -> NDArray: """ Evaluate exponential response at given time points. @@ -1263,9 +1257,7 @@ def find_dhm( return (rt0.root, rt1.root), t_hat -def shift_frame( - fm: NDArray, sh: NDArray, fill: float = np.nan -) -> NDArray: +def shift_frame(fm: NDArray, sh: NDArray, fill: float = np.nan) -> NDArray: """ Shift a frame by integer offsets and fill edges. diff --git a/src/indeca/utils.py b/src/indeca/utils.py index e5240e8..68a414c 100644 --- a/src/indeca/utils.py +++ b/src/indeca/utils.py @@ -43,9 +43,7 @@ def norm(a: ArrayLike) -> NDArray: return a - amin -def scal_lstsq( - a: NDArray, b: NDArray, fit_intercept: bool = False -) -> NDArray: +def scal_lstsq(a: NDArray, b: NDArray, fit_intercept: bool = False) -> NDArray: """ Solve a least squares scaling problem to find coefficients. @@ -144,9 +142,7 @@ def enumerated_product( yield from zip(itt.product(*(range(len(x)) for x in args)), itt.product(*args)) -def compute_dff( - s: ArrayLike, window_size: int = 100, q: float = 0.10 -) -> NDArray: +def compute_dff(s: ArrayLike, window_size: int = 100, q: float = 0.10) -> NDArray: """ Compute ΔF/F₀ (change in fluorescence) for a calcium signal.