From 6a1948294deb894bb6a706edf87a9280b2a3496e Mon Sep 17 00:00:00 2001 From: Daniel Aharoni Date: Fri, 12 Dec 2025 23:16:13 -0800 Subject: [PATCH 1/6] Initial refactor attempt --- pdm.lock | 102 +- pyproject.toml | 1 + src/indeca/core/AR_kernel.py | 2 +- src/indeca/core/deconv/__init__.py | 3 +- src/indeca/core/deconv/config.py | 90 ++ src/indeca/core/deconv/deconv.py | 1749 +++++++++------------------- src/indeca/core/deconv/solver.py | 840 +++++++++++++ src/indeca/core/deconv/utils.py | 121 ++ tests/unit/test_deconv_G_matrix.py | 59 + 9 files changed, 1766 insertions(+), 1201 deletions(-) create mode 100644 src/indeca/core/deconv/config.py create mode 100644 src/indeca/core/deconv/solver.py create mode 100644 src/indeca/core/deconv/utils.py create mode 100644 tests/unit/test_deconv_G_matrix.py diff --git a/pdm.lock b/pdm.lock index 7c4ca53..c6816d0 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "test"] strategy = [] lock_version = "4.5.0" -content_hash = "sha256:4f5ed9e8f685d76eb0697a6fe96b8bc2a91790bca75dbe477816cf4ca61a32e4" +content_hash = "sha256:117c9d0d5cbb1cee84326ef33261c4f32df6ced8abecbaad6fd6816b380514eb" [[metadata.targets]] requires_python = ">=3.11,<3.13" @@ -86,6 +86,19 @@ files = [ {file = "aiosignal-1.3.2.tar.gz", hash = "sha256:a8c255c66fafb1e499c9351d0bf32ff2d8a0321595ebac3b93713656d2436f54"}, ] +[[package]] +name = "annotated-types" +version = "0.7.0" +requires_python = ">=3.8" +summary = "Reusable constraint types to use with typing.Annotated" +dependencies = [ + "typing-extensions>=4.0.0; python_version < \"3.9\"", +] +files = [ + {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, + {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, +] + [[package]] name = "anyio" version = "4.9.0" @@ -2560,6 +2573,70 @@ files = [ {file = "pycparser-2.22.tar.gz", hash = 
"sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"}, ] +[[package]] +name = "pydantic" +version = "2.12.5" +requires_python = ">=3.9" +summary = "Data validation using Python type hints" +dependencies = [ + "annotated-types>=0.6.0", + "pydantic-core==2.41.5", + "typing-extensions>=4.14.1", + "typing-inspection>=0.4.2", +] +files = [ + {file = "pydantic-2.12.5-py3-none-any.whl", hash = "sha256:e561593fccf61e8a20fc46dfc2dfe075b8be7d0188df33f221ad1f0139180f9d"}, + {file = "pydantic-2.12.5.tar.gz", hash = "sha256:4d351024c75c0f085a9febbb665ce8c0c6ec5d30e903bdb6394b7ede26aebb49"}, +] + +[[package]] +name = "pydantic-core" +version = "2.41.5" +requires_python = ">=3.9" +summary = "Core functionality for Pydantic validation and serialization" +dependencies = [ + "typing-extensions>=4.14.1", +] +files = [ + {file = "pydantic_core-2.41.5-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:a3a52f6156e73e7ccb0f8cced536adccb7042be67cb45f9562e12b319c119da6"}, + {file = "pydantic_core-2.41.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7f3bf998340c6d4b0c9a2f02d6a400e51f123b59565d74dc60d252ce888c260b"}, + {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:378bec5c66998815d224c9ca994f1e14c0c21cb95d2f52b6021cc0b2a58f2a5a"}, + {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e7b576130c69225432866fe2f4a469a85a54ade141d96fd396dffcf607b558f8"}, + {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6cb58b9c66f7e4179a2d5e0f849c48eff5c1fca560994d6eb6543abf955a149e"}, + {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:88942d3a3dff3afc8288c21e565e476fc278902ae4d6d134f1eeda118cc830b1"}, + {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:f31d95a179f8d64d90f6831d71fa93290893a33148d890ba15de25642c5d075b"}, + {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c1df3d34aced70add6f867a8cf413e299177e0c22660cc767218373d0779487b"}, + {file = "pydantic_core-2.41.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:4009935984bd36bd2c774e13f9a09563ce8de4abaa7226f5108262fa3e637284"}, + {file = "pydantic_core-2.41.5-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:34a64bc3441dc1213096a20fe27e8e128bd3ff89921706e83c0b1ac971276594"}, + {file = "pydantic_core-2.41.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c9e19dd6e28fdcaa5a1de679aec4141f691023916427ef9bae8584f9c2fb3b0e"}, + {file = "pydantic_core-2.41.5-cp311-cp311-win32.whl", hash = "sha256:2c010c6ded393148374c0f6f0bf89d206bf3217f201faa0635dcd56bd1520f6b"}, + {file = "pydantic_core-2.41.5-cp311-cp311-win_amd64.whl", hash = "sha256:76ee27c6e9c7f16f47db7a94157112a2f3a00e958bc626e2f4ee8bec5c328fbe"}, + {file = "pydantic_core-2.41.5-cp311-cp311-win_arm64.whl", hash = "sha256:4bc36bbc0b7584de96561184ad7f012478987882ebf9f9c389b23f432ea3d90f"}, + {file = "pydantic_core-2.41.5-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f41a7489d32336dbf2199c8c0a215390a751c5b014c2c1c5366e817202e9cdf7"}, + {file = "pydantic_core-2.41.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:070259a8818988b9a84a449a2a7337c7f430a22acc0859c6b110aa7212a6d9c0"}, + {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e96cea19e34778f8d59fe40775a7a574d95816eb150850a85a7a4c8f4b94ac69"}, + {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed2e99c456e3fadd05c991f8f437ef902e00eedf34320ba2b0842bd1c3ca3a75"}, + {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65840751b72fbfd82c3c640cff9284545342a4f1eb1586ad0636955b261b0b05"}, + {file = 
"pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e536c98a7626a98feb2d3eaf75944ef6f3dbee447e1f841eae16f2f0a72d8ddc"}, + {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eceb81a8d74f9267ef4081e246ffd6d129da5d87e37a77c9bde550cb04870c1c"}, + {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d38548150c39b74aeeb0ce8ee1d8e82696f4a4e16ddc6de7b1d8823f7de4b9b5"}, + {file = "pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c23e27686783f60290e36827f9c626e63154b82b116d7fe9adba1fda36da706c"}, + {file = "pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:482c982f814460eabe1d3bb0adfdc583387bd4691ef00b90575ca0d2b6fe2294"}, + {file = "pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:bfea2a5f0b4d8d43adf9d7b8bf019fb46fdd10a2e5cde477fbcb9d1fa08c68e1"}, + {file = "pydantic_core-2.41.5-cp312-cp312-win32.whl", hash = "sha256:b74557b16e390ec12dca509bce9264c3bbd128f8a2c376eaa68003d7f327276d"}, + {file = "pydantic_core-2.41.5-cp312-cp312-win_amd64.whl", hash = "sha256:1962293292865bca8e54702b08a4f26da73adc83dd1fcf26fbc875b35d81c815"}, + {file = "pydantic_core-2.41.5-cp312-cp312-win_arm64.whl", hash = "sha256:1746d4a3d9a794cacae06a5eaaccb4b8643a131d45fbc9af23e353dc0a5ba5c3"}, + {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b2379fa7ed44ddecb5bfe4e48577d752db9fc10be00a6b7446e9663ba143de26"}, + {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:266fb4cbf5e3cbd0b53669a6d1b039c45e3ce651fd5442eff4d07c2cc8d66808"}, + {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58133647260ea01e4d0500089a8c4f07bd7aa6ce109682b1426394988d8aaacc"}, + {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = 
"sha256:287dad91cfb551c363dc62899a80e9e14da1f0e2b6ebde82c806612ca2a13ef1"}, + {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:03b77d184b9eb40240ae9fd676ca364ce1085f203e1b1256f8ab9984dca80a84"}, + {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:a668ce24de96165bb239160b3d854943128f4334822900534f2fe947930e5770"}, + {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f14f8f046c14563f8eb3f45f499cc658ab8d10072961e07225e507adb700e93f"}, + {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:56121965f7a4dc965bff783d70b907ddf3d57f6eba29b6d2e5dabfaf07799c51"}, + {file = "pydantic_core-2.41.5.tar.gz", hash = "sha256:08daa51ea16ad373ffd5e7606252cc32f07bc72b28284b6bc9c6df804816476e"}, +] + [[package]] name = "pyflakes" version = "3.1.0" @@ -3437,12 +3514,25 @@ files = [ [[package]] name = "typing-extensions" -version = "4.13.2" -requires_python = ">=3.8" -summary = "Backported and Experimental Type Hints for Python 3.8+" +version = "4.15.0" +requires_python = ">=3.9" +summary = "Backported and Experimental Type Hints for Python 3.9+" +files = [ + {file = "typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548"}, + {file = "typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466"}, +] + +[[package]] +name = "typing-inspection" +version = "0.4.2" +requires_python = ">=3.9" +summary = "Runtime typing introspection tools" +dependencies = [ + "typing-extensions>=4.12.0", +] files = [ - {file = "typing_extensions-4.13.2-py3-none-any.whl", hash = "sha256:a439e7c04b49fec3e5d3e2beaa21755cadbbdc391694e28ccdd36ca4a1408f8c"}, - {file = "typing_extensions-4.13.2.tar.gz", hash = "sha256:e6c81219bd689f51865d9e372991c540bda33a0379d5573cddb9a3a23f7caaef"}, + {file = "typing_inspection-0.4.2-py3-none-any.whl", hash = 
"sha256:4ed1cacbdc298c220f1bd249ed5287caa16f34d44ef4e9c3d0cbad5b521545e7"}, + {file = "typing_inspection-0.4.2.tar.gz", hash = "sha256:ba561c48a67c5958007083d386c3295464928b01faa735ab8547c5692e87f464"}, ] [[package]] diff --git a/pyproject.toml b/pyproject.toml index ba6027a..f9e4600 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "line-profiler>=4.1.3", "scikit-image>=0.25.2", "numba>=0.61.2", + "pydantic>=2.0", ] requires-python = ">=3.11,<3.13" readme = "README.md" diff --git a/src/indeca/core/AR_kernel.py b/src/indeca/core/AR_kernel.py index 60ea07d..7ef33bb 100644 --- a/src/indeca/core/AR_kernel.py +++ b/src/indeca/core/AR_kernel.py @@ -9,7 +9,7 @@ from scipy.optimize import curve_fit from statsmodels.tsa.stattools import acovf -from indeca.core.deconv.deconv import construct_G, construct_R +from indeca.core.deconv import construct_G, construct_R from indeca.core.simulation import AR2tau, ar_pulse, solve_p, tau2AR diff --git a/src/indeca/core/deconv/__init__.py b/src/indeca/core/deconv/__init__.py index da46a0c..0753809 100644 --- a/src/indeca/core/deconv/__init__.py +++ b/src/indeca/core/deconv/__init__.py @@ -1,3 +1,4 @@ -from .deconv import DeconvBin, construct_R, construct_G, max_thres, sum_downsample +from .deconv import DeconvBin +from .utils import construct_R, construct_G, max_thres, sum_downsample __all__ = ["DeconvBin", "construct_R", "construct_G", "max_thres", "sum_downsample"] diff --git a/src/indeca/core/deconv/config.py b/src/indeca/core/deconv/config.py new file mode 100644 index 0000000..d0d8eaa --- /dev/null +++ b/src/indeca/core/deconv/config.py @@ -0,0 +1,90 @@ +"""Configuration for deconv module.""" + +from typing import Literal, Optional, Union + +from pydantic import BaseModel, Field, model_validator + + +class DeconvConfig(BaseModel): + """Configuration for DeconvBin.""" + + model_config = {"frozen": True} + + coef_len: int = Field( + 100, description="Length of the coefficient kernel (e.g. 
calcium response)." + ) + scale: float = Field(1.0, description="Global scaling factor.") + penal: str | None = Field("l1", description="Penalty type ('l1', 'l0', etc.).") + use_base: bool = Field(False, description="Whether to include a baseline term.") + upsamp: int = Field(1, description="Upsampling factor.") + norm: Literal["l1", "l2", "huber"] = Field( + "l2", description="Norm for data fidelity ('l2', 'l1', 'huber')." + ) + mixin: bool = Field( + False, description="Whether to use mixed-integer programming (boolean spikes)." + ) + backend: Literal["osqp", "cvxpy", "cuosqp"] = Field( + "osqp", description="Solver backend ('osqp', 'cvxpy', 'cuosqp'). Note: emosqp requires codegen and is not supported." + ) + free_kernel: bool = Field( + False, description="If True, use convolution constraint instead of AR constraint. Only supported with OSQP backends." + ) + nthres: int = Field(1000, description="Number of thresholds for thresholding step.") + err_weighting: Optional[str] = Field( + None, description="Error weighting method ('fft', 'corr', 'adaptive', or None)." + ) + wt_trunc_thres: float = Field( + 1e-2, description="Threshold for truncating error weights." + ) + masking_radius: Optional[int] = Field( + None, description="Radius for masking around spikes." + ) + pks_polish: bool = Field(True, description="Whether to polish peaks after solving.") + th_min: float = Field(0.0, description="Minimum threshold.") + th_max: float = Field(1.0, description="Maximum threshold.") + density_thres: Optional[float] = Field( + None, description="Max spike density threshold." + ) + ncons_thres: Union[int, Literal["auto"], None] = Field( + None, description="Max consecutive spikes threshold. If 'auto', upsamp + 1." + ) + min_rel_scl: Union[float, Literal["auto"], None] = Field( + "auto", description="Minimum relative scale. Use None to disable." 
+ ) + + max_iter_l0: int = 30 + max_iter_penal: int = 500 + max_iter_scal: int = 50 + delta_l0: float = 1e-4 + delta_penal: float = 1e-4 + atol: float = 1e-3 + rtol: float = 1e-3 + Hlim: int = 1e5 + + @model_validator(mode="before") + @classmethod + def resolve_auto_fields(cls, data): + # Resolve "auto" values before constructing the (frozen) model to avoid + # returning a new instance from an "after" validator (pydantic warns). + if not isinstance(data, dict): + return data + upsamp = data.get("upsamp", 1) + if data.get("min_rel_scl") == "auto": + data["min_rel_scl"] = 0.5 / upsamp + if data.get("ncons_thres") == "auto": + data["ncons_thres"] = upsamp + 1 + return data + + + @model_validator(mode="after") + def validate_penal(self): + allowed = {None, "l0", "l1"} + if self.penal not in allowed: + raise ValueError(f"Unsupported penal type: {self.penal}") + return self + + @model_validator(mode="after") + def validate_compat(self): + if self.free_kernel and self.backend == "cvxpy": + raise ValueError("free_kernel=True is not supported with backend='cvxpy'") + return self diff --git a/src/indeca/core/deconv/deconv.py b/src/indeca/core/deconv/deconv.py index 40fa2fc..fb04d43 100644 --- a/src/indeca/core/deconv/deconv.py +++ b/src/indeca/core/deconv/deconv.py @@ -1,139 +1,45 @@ +"""Main deconvolution module.""" + import itertools as itt +import math import warnings -from typing import Tuple +from typing import Tuple, Any, Optional -import cvxpy as cp import numpy as np -import osqp -import pandas as pd import scipy.sparse as sps -import xarray as xr -from numba import njit -from scipy.ndimage import label +import pandas as pd from scipy.optimize import direct -from scipy.signal import ShortTimeFFT, find_peaks -from scipy.special import huber +from scipy.signal import find_peaks +from scipy.ndimage import label from indeca.utils.logging_config import get_module_logger from indeca.core.simulation import AR2tau, ar_pulse, exp_pulse, solve_p, tau2AR from indeca.utils.utils 
import scal_lstsq +from .config import DeconvConfig +from .solver import DeconvSolver, CVXPYSolver, OSQPSolver +from .utils import max_thres, max_consecutive, sum_downsample # Initialize logger for this module logger = get_module_logger("deconv") logger.info("Deconv module initialized") -try: - import cuosqp -except ImportError: - logger.warning("No GPU solver support") - - -def construct_R(T: int, up_factor: int): - if up_factor > 1: - return sps.csc_matrix( - ( - np.ones(T * up_factor), - (np.repeat(np.arange(T), up_factor), np.arange(T * up_factor)), - ), - shape=(T, T * up_factor), - ) - else: - return sps.eye(T, format="csc") - - -def sum_downsample(a, factor): - return np.convolve(a, np.ones(factor), mode="full")[factor - 1 :: factor] - - -def construct_G(fac: np.ndarray, T: int, fromTau=False): - fac = np.array(fac) - assert fac.shape == (2,) - if fromTau: - fac = np.array(tau2AR(*fac)) - return sps.dia_matrix( - ( - np.tile(np.concatenate(([1], -fac)), (T, 1)).T, - -np.arange(len(fac) + 1), - ), - shape=(T, T), - ).tocsc() - - -def max_thres( - a: xr.DataArray, - nthres: int, - th_min=0.1, - th_max=0.9, - ds=None, - return_thres=False, - th_amplitude=False, - delta=1e-6, - reverse_thres=False, - nz_only: bool = False, -): - amax = a.max() - if reverse_thres: - thres = np.linspace(th_max, th_min, nthres) - else: - thres = np.linspace(th_min, th_max, nthres) - if th_amplitude: - S_ls = [np.floor_divide(a, (amax * th).clip(delta, None)) for th in thres] - else: - S_ls = [(a > (amax * th).clip(delta, None)) for th in thres] - if ds is not None: - S_ls = [sum_downsample(s, ds) for s in S_ls] - if nz_only: - Snz = [ss.sum() > 0 for ss in S_ls] - S_ls = [ss for ss, nz in zip(S_ls, Snz) if nz] - thres = [th for th, nz in zip(thres, Snz) if nz] - if return_thres: - return S_ls, thres - else: - return S_ls - - -@njit(nopython=True, nogil=True, cache=True) -def bin_convolve( - coef: np.ndarray, s: np.ndarray, nzidx_s: np.ndarray = None, s_len: int = None -): - 
coef_len = len(coef) - if s_len is None: - s_len = len(s) - out = np.zeros(s_len) - nzidx = np.where(s)[0] - if nzidx_s is not None: - nzidx = nzidx_s[nzidx].astype( - np.int64 - ) # astype to fix numpa issues on GPU on Windows - for i0 in nzidx: - i1 = min(i0 + coef_len, s_len) - clen = i1 - i0 - out[i0:i1] += coef[:clen] - return out - - -@njit(nopython=True, nogil=True, cache=True) -def max_consecutive(arr): - max_count = 0 - current_count = 0 - for value in arr: - if value: - current_count += 1 - max_count = max(max_count, current_count) - else: - current_count = 0 - return max_count - class DeconvBin: + """Deconvolution main class. + + This class wraps the solver backends and provides high-level methods + for spike inference including thresholding, penalty optimization, + and scale estimation. + """ + def __init__( self, - y: np.array = None, + y: np.ndarray = None, y_len: int = None, - theta: np.array = None, - tau: np.array = None, - ps: np.array = None, - coef: np.array = None, + theta: np.ndarray = None, + tau: np.ndarray = None, + ps: np.ndarray = None, + coef: np.ndarray = None, coef_len: int = 100, scale: float = 1, penal: str = "l1", @@ -142,6 +48,7 @@ def __init__( norm: str = "l2", mixin: bool = False, backend: str = "osqp", + free_kernel: bool = False, nthres: int = 1000, err_weighting: str = None, wt_trunc_thres: float = 1e-2, @@ -163,15 +70,25 @@ def __init__( dashboard=None, dashboard_uid=None, ) -> None: - # book-keeping + # Handle y input if y is not None: self.y_len = len(y) + self.y = y else: assert y_len is not None self.y_len = y_len + self.y = np.zeros(y_len) + if coef_len is not None and coef_len > self.y_len: warnings.warn("Coefficient length longer than data") coef_len = self.y_len + + # Store tau/theta/ps + self.theta = None + self.tau = None + self.ps = None + + # Compute coefficients from theta or tau if theta is not None: self.theta = np.array(theta) if tau is None: @@ -179,27 +96,19 @@ def __init__( self.tau = np.array([tau_d, 
tau_r]) self.ps = np.array([p, -p]) coef, _, _ = exp_pulse( - tau_d, - tau_r, - p_d=p, - p_r=-p, + tau_d, tau_r, p_d=p, p_r=-p, nsamp=coef_len * upsamp, kn_len=coef_len * upsamp, trunc_thres=atol, ) if tau is not None: - assert ( - ps is not None - ), "exp coefficients must be provided together with time constants." + assert ps is not None, "exp coefficients must be provided together with time constants." if theta is None: self.theta = np.array(tau2AR(tau[0], tau[1])) self.tau = np.array(tau) self.ps = ps coef, _, _ = exp_pulse( - tau[0], - tau[1], - p_d=ps[0], - p_r=ps[1], + tau[0], tau[1], p_d=ps[0], p_r=ps[1], nsamp=coef_len * upsamp, kn_len=coef_len * upsamp, trunc_thres=atol, @@ -207,164 +116,131 @@ def __init__( if coef is None: assert coef_len is not None coef = np.ones(coef_len * upsamp) + + # `coef_len` (config) is the *base* kernel length; the stored `coef` is + # already upsampled to length `coef_len * upsamp`. self.coef_len = len(coef) - self.T = self.y_len * upsamp - l0_penal = 0 - l1_penal = 0 - self.free_kernel = False - self.penal = penal - self.use_base = use_base - self.l0_penal = l0_penal - self.w_org = np.ones(self.T) - self.w = np.ones(self.T) - self.upsamp = upsamp - self.norm = norm - self.backend = backend - self.nthres = nthres - self.th_min = th_min - self.th_max = th_max - self.max_iter_l0 = max_iter_l0 - self.max_iter_penal = max_iter_penal - self.max_iter_scal = max_iter_scal - self.delta_l0 = delta_l0 - self.delta_penal = delta_penal - self.atol = atol - self.rtol = rtol - self.Hlim = Hlim + + # Create config (note: frozen after creation) + self.cfg = DeconvConfig( + coef_len=coef_len, + scale=scale, + penal=penal, + use_base=use_base, + upsamp=upsamp, + norm=norm, + mixin=mixin, + backend=backend, + free_kernel=free_kernel, + nthres=nthres, + err_weighting=err_weighting, + wt_trunc_thres=wt_trunc_thres, + masking_radius=masking_radius, + pks_polish=pks_polish, + th_min=th_min, + th_max=th_max, + density_thres=density_thres, + # 
IMPORTANT: preserve old semantics: None disables consecutive constraint, + # and "auto" enables the (upsamp + 1) default. + ncons_thres=ncons_thres, + min_rel_scl=min_rel_scl, + max_iter_l0=max_iter_l0, + max_iter_penal=max_iter_penal, + max_iter_scal=max_iter_scal, + delta_l0=delta_l0, + delta_penal=delta_penal, + atol=atol, + rtol=rtol, + Hlim=Hlim, + ) + + # Dashboard for visualization self.dashboard = dashboard self.dashboard_uid = dashboard_uid - self.nzidx_s = np.arange(self.T, dtype=np.int64) - self.nzidx_c = np.arange(self.T, dtype=np.int64) - self.x_cache = None - self.err_weighting = err_weighting - self.masking_r = masking_radius - self.pks_polish = pks_polish - self.err_wt = np.ones(self.y_len) - self.density_thres = density_thres - self.wt_trunc_thres = wt_trunc_thres - if ncons_thres == "auto": - self.ncons_thres = upsamp + 1 - else: - self.ncons_thres = ncons_thres - if min_rel_scl == "auto": - self.min_rel_scl = ( - 0.5 / self.upsamp - ) # 2 x upsamp number of spikes should be more than enough to explain highest peak - else: - self.min_rel_scl = min_rel_scl - if err_weighting == "fft": - self.stft = ShortTimeFFT(win=np.ones(self.coef_len), hop=1, fs=1) - self.yspec = self._get_stft_spec(y) - if y is not None: - self.huber_k = 0.5 * np.std(y) - self.err_total = self._res_err(y) - else: - self.huber_k = 0 - self.err_total = 0 - self._update_R() - # setup cvxpy - if self.backend == "cvxpy": - self.R = cp.Constant(self.R, name="R") - self.c = cp.Variable((self.T, 1), nonneg=True, name="c") - self.s = cp.Variable((self.T, 1), nonneg=True, name="s", boolean=mixin) - self.y = cp.Parameter(shape=(self.y_len, 1), name="y") - self.coef = cp.Parameter(value=coef, shape=self.coef_len, name="coef") - self.scale = cp.Parameter(value=scale, name="scale", nonneg=True) - self.l1_penal = cp.Parameter(value=l1_penal, name="l1_penal", nonneg=True) - self.l0_w = cp.Parameter( - shape=self.T, value=self.l0_penal * self.w, nonneg=True, name="w_l0" - ) # product of 
l0_penal * w! - if y is not None: - self.y.value = y.reshape((-1, 1)) - if coef is not None: - self.coef.value = coef - if use_base: - self.b = cp.Variable(nonneg=True, name="b") - else: - self.b = cp.Constant(value=0, name="b") - if norm == "l1": - self.err_term = cp.sum( - cp.abs(self.y - self.scale * self.R @ self.c - self.b) - ) - elif norm == "l2": - self.err_term = cp.sum_squares( - self.y - self.scale * self.R @ self.c - self.b - ) - elif norm == "huber": - self.err_term = cp.sum( - cp.huber(self.y - self.scale * self.R @ self.c - self.b) - ) - obj = cp.Minimize( - self.err_term - + self.l0_w.T @ cp.abs(self.s) - + self.l1_penal * cp.sum(cp.abs(self.s)) + + # Penalty tracking - solver tracks scale, we track penalty locally + self._l0_penal = 0.0 + self._l1_penal = 0.0 + + # Create solver + if self.cfg.backend == "cvxpy": + if self.cfg.free_kernel: + raise NotImplementedError("CVXPY backend does not support free_kernel mode") + self.solver = CVXPYSolver( + self.cfg, self.y_len, + y=self.y, coef=coef, theta=self.theta, + tau=self.tau, ps=self.ps + ) + elif self.cfg.backend in ["osqp", "cuosqp"]: + self.solver = OSQPSolver( + self.cfg, self.y_len, + y=self.y, coef=coef, theta=self.theta, + tau=self.tau, ps=self.ps ) - if self.free_kernel: - dcv_cons = [ - self.c[:, 0] == cp.convolve(self.coef, self.s[:, 0])[: self.T] - ] - else: - self.theta = cp.Parameter( - value=self.theta, shape=self.theta.shape, name="theta" - ) - G_diag = sps.eye(self.T - 1) + sum( - [ - cp.diag(cp.promote(-self.theta[i], (self.T - i - 2,)), -i - 1) - for i in range(self.theta.shape[0]) - ] - ) # diag part of unshifted G - G = cp.bmat( - [ - [np.zeros((self.T - 1, 1)), G_diag], - [np.zeros((1, 1)), np.zeros((1, self.T - 1))], - ] - ) - dcv_cons = [self.s == G @ self.c] - edge_cons = [self.c[0, 0] == 0, self.s[-1, 0] == 0] - amp_cons = [self.s <= 1] - self.prob_free = cp.Problem(obj, dcv_cons + edge_cons) - self.prob = cp.Problem(obj, dcv_cons + edge_cons + amp_cons) - self._update_HG() # 
self.H and self.G not used for cvxpy problems - elif self.backend in ["osqp", "emosqp", "cuosqp"]: - # book-keeping - if y is None: - self.y = np.ones(self.y_len) - else: - self.y = y - if coef is None: - self.coef = np.ones(self.coef_len) - else: - self.coef = coef - self.c = np.zeros(self.T * upsamp) - self.s = np.zeros(self.T * upsamp) - self.s_bin = None - self.b = 0 - self.l1_penal = l1_penal - self.scale = scale - self.H = None - self._update_Wt() - self._setup_prob_osqp() + else: + raise ValueError(f"Unknown backend: {self.cfg.backend}") + + # State + self.T = self.solver.T + self.s = np.zeros(self.T) + self.b = 0 + self.c_bin = None + self.s_bin = None + # NOTE: do not store an "err_total" here. `_res_err` expects a residual, + # not the raw `y`, and this value was misleading and unused. + + # Update dashboard with initial kernel if self.dashboard is not None: - self.dashboard.update( - h=self.coef.value if backend == "cvxpy" else self.coef, - uid=self.dashboard_uid, - ) - tr_exp, _, _ = exp_pulse( - self.tau[0], - self.tau[1], - p_d=self.ps[0], - p_r=self.ps[1], - nsamp=self.coef_len, - ) - theta = self.theta.value if self.backend == "cvxpy" else self.theta - tr_ar, _, _ = ar_pulse(theta[0], theta[1], nsamp=self.coef_len, shifted=True) - assert (~np.isnan(coef)).all() - assert np.isclose( - tr_exp, coef, atol=self.atol - ).all(), "exp time constant inconsistent" - assert np.isclose( - tr_ar, coef, atol=self.atol - ).all(), "ar coefficients inconsistent" + self.dashboard.update(h=coef, uid=self.dashboard_uid) + + # Validate coefficients + self.solver.validate_coefficients(atol=atol) + + @property + def scale(self) -> float: + """Current scale value (delegated to solver).""" + return self.solver.scale + + @property + def H(self): + """Convolution matrix H.""" + return self.solver.H + + @property + def R(self): + """Resampling matrix R.""" + return self.solver.R + + @property + def R_org(self): + """Original (full) resampling matrix.""" + return 
self.solver.R_org + + @property + def nzidx_s(self): + """Nonzero indices for s.""" + return self.solver.nzidx_s + + @property + def nzidx_c(self): + """Nonzero indices for c.""" + return self.solver.nzidx_c + + @property + def coef(self): + """Coefficient kernel.""" + return self.solver.coef + + @property + def err_wt(self): + """Error weighting vector.""" + return self.solver.err_wt + + @err_wt.setter + def err_wt(self, value): + """Allow direct assignment (used by tests/demo code).""" + self.solver.err_wt = np.array(value) + self.solver.Wt = sps.diags(self.solver.err_wt) def update( self, @@ -380,178 +256,105 @@ def update( clear_weighting: bool = False, scale_coef: bool = False, ) -> None: - logger.debug( - f"Updating parameters - backend: {self.backend}, tau: {tau}, scale: {scale}, scale_mul: {scale_mul}, l0_penal: {l0_penal}, l1_penal: {l1_penal}" + """Update parameters.""" + logger.debug(f"Updating parameters - backend: {self.cfg.backend}") + + theta_new = None + if tau is not None: + theta_new = np.array(tau2AR(tau[0], tau[1])) + p = solve_p(tau[0], tau[1]) + coef_new, _, _ = exp_pulse( + tau[0], tau[1], p_d=p, p_r=-p, + nsamp=self.cfg.coef_len * self.cfg.upsamp, + kn_len=self.cfg.coef_len * self.cfg.upsamp + ) + coef = coef_new + self.tau = tau + self.theta = theta_new + self.ps = np.array([p, -p]) + + if coef is not None and scale_coef: + current_coef = self.solver.coef if self.solver.coef is not None else np.ones_like(coef) + scale_mul = scal_lstsq(coef, current_coef).item() + + if l0_penal is not None: + self._l0_penal = l0_penal + if l1_penal is not None: + self._l1_penal = l1_penal + + # Forward updates to solver (solver handles scale directly) + self.solver.update( + y=y, coef=coef, + scale=scale, + scale_mul=scale_mul, + l1_penal=self._l1_penal if l1_penal is not None else None, + l0_penal=self._l0_penal if l0_penal is not None else None, + w=w, + theta=theta_new if theta_new is not None else self.theta, + update_weighting=update_weighting, + 
clear_weighting=clear_weighting, + scale_coef=scale_coef, ) - if self.backend == "cvxpy": - if y is not None: - self.y.value = y - if tau is not None: - theta_new = np.array(tau2AR(tau[0], tau[1])) - p = solve_p(tau[0], tau[1]) - coef, _, _ = exp_pulse( - tau[0], - tau[1], - p_d=p, - p_r=-p, - nsamp=self.coef_len, - kn_len=self.coef_len, - ) - self.coef.value = coef - self.theta.value = theta_new - self._update_HG() - self._update_wgt_len() - if coef is not None: - if scale_coef: - scale_mul = scal_lstsq(coef, self.coef).item() - self.coef.value = coef - self._update_HG() - self._update_wgt_len() - if scale is not None: - self.scale.value = scale - if scale_mul is not None: - self.scale.value = scale_mul * self.scale.value - if l1_penal is not None: - self.l1_penal.value = l1_penal - if l0_penal is not None: - self.l0_penal = l0_penal - if w is not None: - self._update_w(w) - if l0_penal is not None or w is not None: - self.l0_w.value = self.l0_penal * self.w - elif self.backend in ["osqp", "emosqp", "cuosqp"]: - # update input params - if y is not None: - self.y = y - if tau is not None: - theta_new = np.array(tau2AR(tau[0], tau[1])) - p = solve_p(tau[0], tau[1]) - coef, _, _ = exp_pulse( - tau[0], - tau[1], - p_d=p, - p_r=-p, - nsamp=self.coef_len, - kn_len=self.coef_len, - ) - self.tau = tau - self.ps = np.array([p, -p]) - self.theta = theta_new - if coef is not None: - if scale_coef: - scale_mul = scal_lstsq(coef, self.coef).item() - self.coef = coef - if scale is not None: - self.scale = scale - if scale_mul is not None: - self.scale = scale_mul * self.scale - if l1_penal is not None: - self.l1_penal = l1_penal - if l0_penal is not None: - self.l0_penal = l0_penal - if w is not None: - self._update_w(w) - # update internal variables - updt_HG, updt_P, updt_A, updt_q0, updt_q, updt_bounds, setup_prob = [ - False - ] * 7 - if coef is not None: - self._update_HG() - self._update_wgt_len() - updt_HG = True - if self.err_weighting is not None and update_weighting: 
- self._update_Wt(clear=clear_weighting) - if self.err_weighting == "adaptive": - setup_prob = True - else: - updt_P = True - updt_q0 = True - updt_q = True - if self.norm == "huber": - if any((scale is not None, scale_mul is not None, updt_HG)): - self._update_A() - updt_A = True - if any( - (w is not None, l0_penal is not None, l1_penal is not None, updt_HG) - ): - self._update_q() - updt_q = True - if y is not None: - self._update_bounds() - updt_bounds = True + + if y is not None: + self.y = y + + def _pad_s(self, s: np.ndarray = None) -> np.ndarray: + """Pad sparse s to full length.""" + return self.solver._pad_s(s) + + def _pad_c(self, c: np.ndarray = None) -> np.ndarray: + """Pad sparse c to full length.""" + return self.solver._pad_c(c) + + def _reset_cache(self) -> None: + """Reset solver cache.""" + self.solver.reset_cache() + + def _reset_mask(self) -> None: + """Reset solver mask to full range.""" + self.solver.reset_mask() + + def _update_mask(self, use_wt: bool = False, amp_constraint: bool = True) -> None: + """Update mask based on current solution.""" + # CVXPY doesn't support masking + if self.cfg.backend == "cvxpy": + return + + if self.cfg.backend in ["osqp", "cuosqp"]: + if use_wt: + nzidx_s = np.where(self.R.T @ self.err_wt)[0] else: - if any((updt_HG, updt_A)): - A_before = self.A.copy() - self._update_A() - assert self.A.shape == A_before.shape - assert (self.A.nonzero()[0] == A_before.nonzero()[0]).all() - assert (self.A.nonzero()[1] == A_before.nonzero()[1]).all() - updt_A = True - if any((scale is not None, scale_mul is not None, updt_HG, updt_P)): - P_before = self.P.copy() - self._update_P() - assert self.P.shape == P_before.shape - assert (self.P.nonzero()[0] == P_before.nonzero()[0]).all() - assert (self.P.nonzero()[1] == P_before.nonzero()[1]).all() - updt_P = True - if any( - ( - scale is not None, - scale_mul is not None, - y is not None, - updt_HG, - updt_q0, - ) - ): - q0_before = self.q0.copy() - self._update_q0() - assert 
self.q0.shape == q0_before.shape - updt_q0 = True - if any( - ( - w is not None, - l0_penal is not None, - l1_penal is not None, - updt_q0, - updt_q, - ) - ): - q_before = self.q.copy() - self._update_q() - assert self.q.shape == q_before.shape - updt_q = True - # update prob - logger.debug(f"Updating optimization problem with {self.backend}") - if self.backend == "emosqp": - if updt_P: - self.prob_free.update_P(self.P.data, None, 0) - self.prob.update_P(self.P.data, None, 0) - if updt_q: - self.prob_free.update_lin_cost(self.q) - self.prob.update_lin_cost(self.q) - elif self.backend in ["osqp", "cuosqp"] and any( - (updt_P, updt_q, updt_A, updt_bounds, setup_prob) - ): - if setup_prob: - self._setup_prob_osqp() + if self.cfg.masking_radius is not None: + mask = np.zeros(self.T) + for nzidx in np.where(self._pad_s(self.s_bin) > 0)[0]: + start = max(nzidx - self.cfg.masking_radius, 0) + end = min(nzidx + self.cfg.masking_radius, self.T) + mask[start:end] = 1 + nzidx_s = np.where(mask)[0] else: - self.prob_free.update( - Px=self.P.copy().data if updt_P else None, - q=self.q.copy() if updt_q else None, - Ax=self.A.copy().data if updt_A else None, - l=self.lb.copy() if updt_bounds else None, - u=self.ub_inf.copy() if updt_bounds else None, - ) - self.prob.update( - Px=self.P.copy().data if updt_P else None, - q=self.q.copy() if updt_q else None, - Ax=self.A.copy().data if updt_A else None, - l=self.lb.copy() if updt_bounds else None, - u=self.ub.copy() if updt_bounds else None, - ) - logger.debug("Optimization problem updated") + self._reset_mask() + opt_s, _ = self.solve(amp_constraint) + nzidx_s = np.where(opt_s > self.cfg.delta_penal)[0] + + if len(nzidx_s) == 0: + logger.warning("Empty mask, resetting") + self._reset_mask() + return + + self.solver.set_mask(nzidx_s) + + # Verify mask is valid + if not self.cfg.free_kernel and len(self.nzidx_c) < self.T: + res = self.solver.prob.solve() + if res.info.status == "primal infeasible": + logger.warning("Mask caused 
primal infeasibility, resetting") + self._reset_mask() + else: + raise NotImplementedError("Masking not supported for cvxpy backend") def _cut_pks_labs(self, s, labs, pks): + """Cut peak labels at valleys between peaks.""" pk_labs = np.full_like(labs, -1) lb = 0 for ilab in range(labs.max() + 1): @@ -564,7 +367,7 @@ def _cut_pks_labs(self, s, labs, pks): pk_labs[p_start:p_stop] = lb lb += 1 p_start = p_stop - pk_labs[p_stop : lb_idxs[-1] + 1] = lb + pk_labs[p_stop:lb_idxs[-1] + 1] = lb lb += 1 else: pk_labs[lb_idxs] = lb @@ -574,6 +377,8 @@ def _cut_pks_labs(self, s, labs, pks): def _merge_sparse_regs( self, s, regs, err_rtol=0, max_len=9, constraint_sum: bool = True ): + """Merge sparse regions to minimize error.""" + max_combos = 10_000 s_ret = s.copy() for r in range(regs.max() + 1): ridx = np.where(regs == r)[0] @@ -593,6 +398,9 @@ def _merge_sparse_regs( else: ns_vals = list(range(ns_min, rlen + 1)) for ns in ns_vals: + # Defensive guard: avoid combinatorial explosion. + if math.comb(rlen, ns) > max_combos: + continue for idxs in itt.combinations(ridx, ns): idxs = np.array(idxs) s_test = s_new.copy() @@ -600,28 +408,15 @@ def _merge_sparse_regs( err_after = self._compute_err(s=s_test[self.nzidx_s]) err_ls.append(err_after) idx_ls.append(idxs) - err_min_idx = np.argmin(err_ls) - err_min = err_ls[err_min_idx] - if err_min - err_before < err_rtol * err_before: - idx_min = idx_ls[err_min_idx] - s_new[idx_min] = rsum / len(idx_min) - s_ret = s_new - return s_ret - - def _pad_s(self, s=None): - if s is None: - s = self.s - s_ret = np.zeros(self.T) - s_ret[self.nzidx_s] = s + if len(err_ls) > 0: + err_min_idx = np.argmin(err_ls) + err_min = err_ls[err_min_idx] + if err_min - err_before < err_rtol * err_before: + idx_min = idx_ls[err_min_idx] + s_new[idx_min] = rsum / len(idx_min) + s_ret = s_new return s_ret - def _pad_c(self, c=None): - if c is None: - c = self.s - c_ret = np.zeros(self.T) - c_ret[self.nzidx_c] = c - return c_ret - def solve( self, amp_constraint: 
bool = True, @@ -630,57 +425,61 @@ def solve( pks_delta: float = 1e-5, pks_err_rtol: float = 10, pks_cut: bool = False, - ) -> np.ndarray: - if self.l0_penal == 0: - opt_s, opt_b = self._solve( - amp_constraint=amp_constraint, update_cache=update_cache - ) + ) -> Tuple[np.ndarray, float]: + """Solve main routine (l0 heuristic wrapper).""" + if self._l0_penal == 0: + opt_s, opt_b, _ = self.solver.solve(amp_constraint=amp_constraint) else: + # L0 heuristic via reweighted L1 metric_df = None - for i in range(self.max_iter_l0): - cur_s, cur_obj = self._solve(amp_constraint, return_obj=True) + for i in range(self.cfg.max_iter_l0): + cur_s, cur_b, _ = self.solver.solve(amp_constraint=amp_constraint) + # Compute objective explicitly since solver returns 0 + cur_obj = self._compute_err(s=cur_s, b=cur_b) + if metric_df is None: obj_best = np.inf obj_last = np.inf else: - obj_best = metric_df["obj"][1:].min() + obj_best = metric_df["obj"][1:].min() if len(metric_df) > 1 else np.inf obj_last = np.array(metric_df["obj"])[-1] - opt_s = np.where(cur_s > self.delta_l0, cur_s, 0) + + opt_s = np.where(cur_s > self.cfg.delta_l0, cur_s, 0) obj_gap = np.abs(cur_obj - obj_best) obj_delta = np.abs(cur_obj - obj_last) - cur_met = pd.DataFrame( - [ - { - "iter": i, - "obj": cur_obj, - "nnz": (opt_s > 0).sum(), - "obj_gap": obj_gap, - "obj_delta": obj_delta, - } - ] - ) + + cur_met = pd.DataFrame([{ + "iter": i, + "obj": cur_obj, + "nnz": (opt_s > 0).sum(), + "obj_gap": obj_gap, + "obj_delta": obj_delta, + }]) metric_df = pd.concat([metric_df, cur_met], ignore_index=True) - if any((obj_gap < self.rtol * np.obj_best, obj_delta < self.atol)): + + if any([ + obj_gap < self.cfg.rtol * obj_best, + obj_delta < self.cfg.atol + ]): break else: - self.update( - w=np.clip( - np.ones(self.T) / (self.delta_l0 * np.ones(self.T) + opt_s), - 0, - 1e5, - ) - ) # clip to avoid numerical issues - else: - warnings.warn( - "l0 heuristic did not converge in {} iterations".format( - self.max_iter_l0 + w_new = 
np.clip( + np.ones(self.T) / (self.cfg.delta_l0 * np.ones(self.T) + opt_s), + 0, 1e5 ) - ) + self.update(w=w_new) + else: + warnings.warn(f"l0 heuristic did not converge in {self.cfg.max_iter_l0} iterations") + + opt_s, opt_b, _ = self.solver.solve(amp_constraint=amp_constraint) + self.b = opt_b + + # Peak polishing if pks_polish is None: pks_polish = amp_constraint - if pks_polish and self.backend != "cvxpy": - s_pad = self._pad_s(s=opt_s) + if pks_polish and self.cfg.backend != "cvxpy": + s_pad = self._pad_s(opt_s) if len(opt_s) == len(self.nzidx_s) else opt_s s_ft = np.where(s_pad > pks_delta, s_pad, 0) labs, _ = label(s_ft) labs = labs - 1 @@ -688,37 +487,115 @@ def solve( pks_idx, _ = find_peaks(s_ft) labs = self._cut_pks_labs(s=s_ft, labs=labs, pks=pks_idx) opt_s = self._merge_sparse_regs(s=s_ft, regs=labs, err_rtol=pks_err_rtol) - opt_s = opt_s[self.nzidx_s] + if len(opt_s) == self.T: + opt_s = opt_s[self.nzidx_s] + self.s = np.abs(opt_s) return self.s, self.b + def _compute_c(self, s: np.ndarray = None) -> np.ndarray: + """Compute c from s via convolution.""" + if s is not None: + return self.solver.convolve(s) + else: + return self.solver.convolve(self.s) + + def _res_err(self, r: np.ndarray) -> float: + """Compute residual error.""" + if self.err_wt is not None: + r = self.err_wt * r + if self.cfg.norm == "l1": + return np.sum(np.abs(r)) + elif self.cfg.norm == "l2": + return np.sum(r ** 2) + elif self.cfg.norm == "huber": + # True Huber loss: + # 0.5*r^2 if |r| <= k + # k*(|r| - 0.5*k) otherwise + k = float(self.solver.huber_k) + ar = np.abs(r) + quad = 0.5 * (r ** 2) + lin = k * (ar - 0.5 * k) + return float(np.sum(np.where(ar <= k, quad, lin))) + + def _compute_err( + self, + y_fit: np.ndarray = None, + b: float = None, + c: np.ndarray = None, + s: np.ndarray = None, + res: np.ndarray = None, + obj_crit: str = None, + ) -> float: + """Compute error/objective value.""" + y = np.array(self.y) + if res is not None: + y = y - res + if b is None: + b = 
self.b + y = y - b + + if y_fit is None: + if c is None: + c = self._compute_c(s) + R = self.R + if sps.issparse(c): + y_fit = np.array((R @ c * self.scale).todense()).squeeze() + else: + y_fit = np.array(R @ c * self.scale).squeeze() + + r = y - y_fit + err = self._res_err(r) + + if obj_crit in [None, "spk_diff"]: + return float(err) + else: + nspk = (s > 0).sum() if s is not None else (self.s > 0).sum() + if obj_crit == "mean_spk": + err_total = self._res_err(y - y.mean()) + return float((err - err_total) / max(nspk, 1)) + elif obj_crit in ["aic", "bic"]: + T = len(r) + mu = r.mean() + sigma = max(((r - mu) ** 2).sum() / T, 1e-10) + logL = -0.5 * (T * np.log(2 * np.pi * sigma) + 1 / sigma * ((r - mu) ** 2).sum()) + if obj_crit == "aic": + return float(2 * (nspk - logL)) + elif obj_crit == "bic": + return float(nspk * np.log(T) - 2 * logL) + return float(err) + def _max_thres(self, s, nz_only=True): + """Apply max thresholding to solution.""" S_ls, thres = max_thres( - s, - nthres=self.nthres, - th_min=self.th_min, - th_max=self.th_max, + np.array(s), + nthres=self.cfg.nthres, + th_min=self.cfg.th_min, + th_max=self.cfg.th_max, reverse_thres=True, return_thres=True, nz_only=nz_only, ) - if self.density_thres is not None: + # Ensure we return numpy arrays (not xarray.DataArray) + S_ls = [np.array(ss) for ss in S_ls] + + # Apply density threshold + if self.cfg.density_thres is not None: Sden = [ss.sum() / self.T for ss in S_ls] - S_ls = [ss for ss, den in zip(S_ls, Sden) if den < self.density_thres] - thres = [th for th, den in zip(thres, Sden) if den < self.density_thres] - if self.ncons_thres is not None: + S_ls = [ss for ss, den in zip(S_ls, Sden) if den < self.cfg.density_thres] + thres = [th for th, den in zip(thres, Sden) if den < self.cfg.density_thres] + + # Apply consecutive threshold + if self.cfg.ncons_thres is not None: S_pad = [self._pad_s(ss) for ss in S_ls] - Sncons = [max_consecutive(ss) for ss in S_pad] - if min(Sncons) < self.ncons_thres: - S_ls = 
[ - ss for ss, ncons in zip(S_ls, Sncons) if ncons <= self.ncons_thres - ] - thres = [ - th for th, ncons in zip(thres, Sncons) if ncons <= self.ncons_thres - ] - else: + Sncons = [max_consecutive(np.array(ss)) for ss in S_pad] + if len(Sncons) > 0 and min(Sncons) < self.cfg.ncons_thres: + S_ls = [ss for ss, ncons in zip(S_ls, Sncons) if ncons <= self.cfg.ncons_thres] + thres = [th for th, ncons in zip(thres, Sncons) if ncons <= self.cfg.ncons_thres] + elif len(S_ls) > 0: S_ls = [S_ls[0]] thres = [thres[0]] + return S_ls, thres def solve_thres( @@ -729,17 +606,21 @@ def solve_thres( return_intm: bool = False, pks_polish: bool = None, obj_crit: str = None, - ) -> Tuple[np.ndarray]: - if self.backend == "cvxpy": - y = np.array(self.y.value.squeeze()) - elif self.backend in ["osqp", "emosqp", "cuosqp"]: - y = np.array(self.y) + ) -> Tuple[np.ndarray, ...]: + """Solve with thresholding.""" + y = np.array(self.y) opt_s, opt_b = self.solve(amp_constraint=amp_constraint, pks_polish=pks_polish) - R = self.R.value if self.backend == "cvxpy" else self.R + R = self.R + if ignore_res: - res = y - opt_b - self.scale * R @ self._compute_c(opt_s) + c = self._compute_c(opt_s) + if sps.issparse(c): + res = y - opt_b - self.scale * np.array((R @ c).todense()).squeeze() + else: + res = y - opt_b - self.scale * (R @ c).squeeze() else: res = np.zeros_like(y) + svals, thres = self._max_thres(opt_s) if not len(svals) > 0: if return_intm: @@ -751,31 +632,47 @@ def solve_thres( 0, np.inf, ) + cvals = [self._compute_c(s) for s in svals] - yfvals = [np.array((R @ c).todense()).squeeze() for c in cvals] + + def to_arr(m): + return np.array(m.todense()).squeeze() if sps.issparse(m) else np.array(m).squeeze() + + yfvals = [to_arr(R @ c) for c in cvals] + if scaling: scal_fit = [scal_lstsq(yf, y - res, fit_intercept=True) for yf in yfvals] scals = [sf[0] for sf in scal_fit] bs = [sf[1] for sf in scal_fit] - if self.min_rel_scl is not None: - scl_thres = np.max(y) * self.min_rel_scl - valid_idx = 
np.where(scals > scl_thres)[0] + + if self.cfg.min_rel_scl is not None: + scl_thres = np.max(y) * self.cfg.min_rel_scl + valid_idx = np.where(np.array(scals) > scl_thres)[0] if len(valid_idx) > 0: scals = [scals[i] for i in valid_idx] bs = [bs[i] for i in valid_idx] + svals = [svals[i] for i in valid_idx] + cvals = [cvals[i] for i in valid_idx] + yfvals = [yfvals[i] for i in valid_idx] else: max_idx = np.argmax(scals) scals = [scals[max_idx]] bs = [bs[max_idx]] + svals = [svals[max_idx]] + cvals = [cvals[max_idx]] + yfvals = [yfvals[max_idx]] else: scals = [self.scale] * len(yfvals) bs = [(y - res - scl * yf).mean() for scl, yf in zip(scals, yfvals)] + objs = [ self._compute_err(s=ss, y_fit=scl * yf, res=res, b=bb, obj_crit=obj_crit) for ss, scl, yf, bb in zip(svals, scals, yfvals, bs) ] + scals = np.array(scals).clip(0, None) objs = np.where(scals > 0, objs, np.inf) + if obj_crit == "spk_diff": err_null = self._compute_err( s=np.zeros_like(opt_s), res=res, b=opt_b, obj_crit=obj_crit @@ -784,23 +681,23 @@ def solve_thres( nspk = np.array([0] + [(ss > 0).sum() for ss in svals]) objs_diff = np.diff(objs_pad) nspk_diff = np.diff(nspk) + nspk_diff = np.where(nspk_diff == 0, 1, nspk_diff) # Avoid division by zero merr_diff = objs_diff / nspk_diff - avg_err = (objs_pad.min() - err_null) / nspk.max() - opt_idx = np.max(np.where(merr_diff < avg_err)) + avg_err = (objs_pad.min() - err_null) / max(nspk.max(), 1) + opt_idx = int(np.max(np.where(merr_diff < avg_err)[0])) if np.any(merr_diff < avg_err) else 0 objs = merr_diff else: - opt_idx = np.argmin(objs) + opt_idx = int(np.argmin(objs)) + s_bin = svals[opt_idx] self.s_bin = s_bin - assert len(s_bin) == len(self.nzidx_s) - self.c_bin = np.array(cvals[opt_idx].todense()).squeeze() + self.c_bin = to_arr(cvals[opt_idx]) self.b = bs[opt_idx] + self.solver.s_bin = s_bin # Update solver's s_bin for adaptive weighting + if return_intm: return ( - self.s_bin, - self.c_bin, - scals[opt_idx], - objs[opt_idx], + self.s_bin, 
self.c_bin, scals[opt_idx], objs[opt_idx], (opt_s, thres, svals, cvals, yfvals, scals, objs, opt_idx), ) else: @@ -808,86 +705,87 @@ def solve_thres( def solve_penal( self, masking=True, scaling=True, return_intm=False, pks_polish=None - ) -> Tuple[np.ndarray]: - if self.penal is None: + ) -> Tuple[np.ndarray, ...]: + """Solve with penalty optimization via DIRECT.""" + if self.cfg.penal is None: opt_s, opt_c, opt_scl, opt_obj = self.solve_thres( scaling=scaling, return_intm=return_intm, pks_polish=pks_polish ) opt_penal = 0 - elif self.penal in ["l0", "l1"]: - pn = "{}_penal".format(self.penal) - self.update(**{pn: 0}) - if masking: - self._reset_cache() - self._update_mask() - s_nopn, _, _, err_nopn, intm = self.solve_thres( - scaling=scaling, return_intm=True, pks_polish=pks_polish - ) - s_min = intm[0] - ymean = self.y.mean() - err_full = self._compute_err(s=np.zeros(len(self.nzidx_s)), b=ymean) - err_min = self._compute_err(s=s_min) - ub, ub_last = err_full, err_full - for _ in range(int(np.ceil(np.log2(ub)))): - self.update(**{pn: ub}) - s, b = self.solve(pks_polish=pks_polish) - cur_err = self._compute_err(s=s, b=b) - # DIRECT finds weird solutions with high penalty and baseline, - # so we want to eliminate those possibilities - if np.abs(cur_err - err_min) < 0.5 * np.abs(err_full - err_min): - ub = ub_last - break - else: - ub_last = ub - ub = ub / 2 - - def opt_fn(x): - self.update(**{pn: x.item()}) - _, _, _, obj = self.solve_thres(scaling=False, pks_polish=pks_polish) - if self.dashboard is not None: - self.dashboard.update( - uid=self.dashboard_uid, - penal_err={"penal": x.item(), "scale": self.scale, "err": obj}, - ) - if obj < err_full: - return obj - else: - return np.inf - + if return_intm: + return opt_s, opt_c, opt_scl, opt_obj, opt_penal, None + return opt_s, opt_c, opt_scl, opt_obj, opt_penal + + pn = f"{self.cfg.penal}_penal" + self.update(**{pn: 0}) + + if masking: + self._reset_cache() + self._update_mask() + + s_nopn, _, _, err_nopn, intm = 
self.solve_thres( + scaling=scaling, return_intm=True, pks_polish=pks_polish + ) + s_min = intm[0] + ymean = self.y.mean() + err_full = self._compute_err(s=np.zeros(len(self.nzidx_s)), b=ymean) + err_min = self._compute_err(s=s_min) + + # Find upper bound for penalty + ub, ub_last = err_full, err_full + for _ in range(int(np.ceil(np.log2(ub + 1)))): + self.update(**{pn: ub}) + s, b = self.solve(pks_polish=pks_polish) + cur_err = self._compute_err(s=s, b=b) + if np.abs(cur_err - err_min) < 0.5 * np.abs(err_full - err_min): + ub = ub_last + break + else: + ub_last = ub + ub = ub / 2 + + def opt_fn(x): + self.update(**{pn: float(x)}) + _, _, _, obj = self.solve_thres(scaling=False, pks_polish=pks_polish) + if self.dashboard is not None: + self.dashboard.update( + uid=self.dashboard_uid, + penal_err={"penal": float(x), "scale": self.scale, "err": obj} + ) + return obj if obj < err_full else np.inf + + try: res = direct( opt_fn, - bounds=[(0, ub)], - maxfun=self.max_iter_penal, + bounds=[(0, max(ub, 1e-6))], + maxfun=self.cfg.max_iter_penal, locally_biased=False, vol_tol=1e-2, ) direct_pn = res.x if not res.success: - logger.warning( - "could not find optimal penalty within {} iterations".format( - res.nfev - ) - ) + logger.warning(f"Could not find optimal penalty within {res.nfev} iterations") opt_penal = 0 elif err_nopn <= opt_fn(direct_pn): - # DIRECT seem to mistakenly report high penalty when 0 penalty attains better error opt_penal = 0 else: - opt_penal = direct_pn.item() - self.update(**{pn: opt_penal}) - if return_intm: - opt_s, opt_c, opt_scl, opt_obj, intm = self.solve_thres( - scaling=scaling, return_intm=return_intm, pks_polish=pks_polish - ) - else: - opt_s, opt_c, opt_scl, opt_obj = self.solve_thres( - scaling=scaling, return_intm=return_intm, pks_polish=pks_polish - ) - if opt_scl == 0: - logger.warning("could not find non-zero solution") + opt_penal = float(direct_pn) + except Exception as e: + logger.warning(f"DIRECT optimization failed: {e}") + 
opt_penal = 0 + + self.update(**{pn: opt_penal}) if return_intm: + opt_s, opt_c, opt_scl, opt_obj, intm = self.solve_thres( + scaling=scaling, return_intm=True, pks_polish=pks_polish + ) return opt_s, opt_c, opt_scl, opt_obj, opt_penal, intm else: + opt_s, opt_c, opt_scl, opt_obj = self.solve_thres( + scaling=scaling, pks_polish=pks_polish + ) + if opt_scl == 0: + logger.warning("Could not find non-zero solution") return opt_s, opt_c, opt_scl, opt_obj, opt_penal def solve_scale( @@ -898,32 +796,35 @@ def solve_scale( obj_crit: str = None, early_stop: bool = True, masking: bool = True, - ) -> Tuple[np.ndarray]: - if self.penal in ["l0", "l1"]: - pn = "{}_penal".format(self.penal) + ) -> Tuple[np.ndarray, ...]: + """Solve with iterative scale estimation.""" + if self.cfg.penal in ["l0", "l1"]: + pn = f"{self.cfg.penal}_penal" self.update(**{pn: 0}) + self._reset_cache() self._reset_mask() + if reset_scale: self.update(scale=1) s_free, _ = self.solve(amp_constraint=False) self.update(scale=np.ptp(s_free)) - else: - s_free = np.zeros(len(self.nzidx_s)) + metric_df = None - for i in range(self.max_iter_scal): + for i in range(self.cfg.max_iter_scal): if concur_penal: cur_s, cur_c, cur_scl, cur_obj_raw, cur_penal = self.solve_penal( scaling=i > 0, - pks_polish=self.pks_polish and (i > 1 or not reset_scale), + pks_polish=self.cfg.pks_polish and (i > 1 or not reset_scale), ) else: cur_penal = 0 cur_s, cur_c, cur_scl, cur_obj_raw = self.solve_thres( scaling=i > 0, - pks_polish=self.pks_polish and (i > 1 or not reset_scale), + pks_polish=self.cfg.pks_polish and (i > 1 or not reset_scale), obj_crit=obj_crit, ) + if self.dashboard is not None: pad_s = np.zeros(self.T) pad_s[self.nzidx_s] = cur_s @@ -933,6 +834,7 @@ def solve_scale( s=self.R_org @ pad_s, scale=cur_scl, ) + if metric_df is None: prev_scals = np.array([np.inf]) opt_obj = np.inf @@ -941,69 +843,75 @@ def solve_scale( last_scal = np.inf else: opt_idx = metric_df["obj"].idxmin() - opt_obj = metric_df.loc[opt_idx, 
"obj"].item() - opt_scal = metric_df.loc[opt_idx, "scale"].item() + opt_obj = metric_df.loc[opt_idx, "obj"] + opt_scal = metric_df.loc[opt_idx, "scale"] prev_scals = np.array(metric_df["scale"]) last_scal = prev_scals[-1] last_obj = np.array(metric_df["obj"])[-1] + y_wt = np.array(self.y * self.err_wt) err_tt = self._res_err(y_wt - y_wt.mean()) - cur_obj = (cur_obj_raw - err_tt) / err_tt - cur_met = pd.DataFrame( - [ - { - "iter": i, - "scale": cur_scl, - "obj_raw": cur_obj_raw, - "obj": cur_obj, - "penal": cur_penal, - "nnz": (cur_s > 0).sum(), - "density": (cur_s > 0).sum() / self.T, - } - ] - ) + cur_obj = (cur_obj_raw - err_tt) / max(err_tt, 1e-10) + + cur_met = pd.DataFrame([{ + "iter": i, + "scale": cur_scl, + "obj_raw": cur_obj_raw, + "obj": cur_obj, + "penal": cur_penal, + "nnz": (cur_s > 0).sum(), + "density": (cur_s > 0).sum() / self.T, + }]) metric_df = pd.concat([metric_df, cur_met], ignore_index=True) - if self.err_weighting == "adaptive" and i <= 1: + + if self.cfg.err_weighting == "adaptive" and i <= 1: self.update(update_weighting=True) if masking and i >= 1: self._update_mask() - if any( - ( - np.abs(cur_scl - opt_scal) < self.rtol * opt_scal, - np.abs(cur_obj - opt_obj) < self.rtol * opt_obj, - np.abs(cur_scl - last_scal) < self.atol, - np.abs(cur_obj - last_obj) < self.atol * 1e-3, - early_stop and cur_obj > last_obj, - ) - ): + + if any([ + np.abs(cur_scl - opt_scal) < self.cfg.rtol * opt_scal, + np.abs(cur_obj - opt_obj) < self.cfg.rtol * opt_obj, + np.abs(cur_scl - last_scal) < self.cfg.atol, + np.abs(cur_obj - last_obj) < self.cfg.atol * 1e-3, + early_stop and cur_obj > last_obj, + ]): break elif cur_scl == 0: - warnings.warn("exit with zero solution") + warnings.warn("Exit with zero solution") break - elif np.abs(cur_scl - prev_scals).min() < self.atol: + elif np.abs(cur_scl - prev_scals).min() < self.cfg.atol: self.update(scale=(cur_scl + last_scal) / 2) else: self.update(scale=cur_scl) else: - warnings.warn("max scale iterations reached") 
+ warnings.warn("Max scale iterations reached") + + # Final solve with optimal scale opt_idx = metric_df["obj"].idxmin() self.update(update_weighting=True, clear_weighting=True) self._reset_cache() self._reset_mask() - self.update(scale=metric_df.loc[opt_idx, "scale"]) + self.update(scale=float(metric_df.loc[opt_idx, "scale"])) + cur_s, cur_c, cur_scl, cur_obj, cur_penal = self.solve_penal( - scaling=False, masking=False, pks_polish=self.pks_polish + scaling=False, masking=False, pks_polish=self.cfg.pks_polish ) - opt_s, opt_c = np.zeros(self.T), np.zeros(self.T) + + opt_s = np.zeros(self.T) + opt_c = np.zeros(self.T) opt_s[self.nzidx_s] = cur_s - opt_c[self.nzidx_c] = cur_c + opt_c[self.nzidx_c] = cur_c if not sps.issparse(cur_c) else cur_c.toarray().squeeze() nnz = int(opt_s.sum()) + self.update(update_weighting=True) y_wt = np.array(self.y * self.err_wt) err_tt = self._res_err(y_wt - y_wt.mean()) err_cur = self._compute_err(s=opt_s) - err_rel = (err_cur - err_tt) / err_tt + err_rel = (err_cur - err_tt) / max(err_tt, 1e-10) + self.update(update_weighting=True, clear_weighting=True) + if self.dashboard is not None: self.dashboard.update( uid=self.dashboard_uid, @@ -1011,556 +919,11 @@ def solve_scale( s=self.R_org @ opt_s, scale=cur_scl, ) + self._reset_cache() self._reset_mask() + if return_met: return opt_s, opt_c, cur_scl, cur_obj, err_rel, nnz, cur_penal, metric_df else: return opt_s, opt_c, cur_scl, cur_obj, err_rel, nnz, cur_penal - - def _setup_prob_osqp(self) -> None: - logger.debug("Setting up OSQP problem") - self._update_HG() - self._update_wgt_len() - self._update_P() - self._update_q0() - self._update_q() - self._update_A() - self._update_bounds() - if self.backend == "emosqp": - m = osqp.OSQP() - m.setup( - P=self.P, - q=self.q, - A=self.A, - l=self.lb, - u=self.ub_inf, - check_termination=25, - eps_abs=self.atol * 1e-4, - eps_rel=1e-8, - ) - m.codegen( - "osqp-codegen-prob_free", - parameters="matrices", - python_ext_name="emosqp_free", - 
force_rewrite=True, - ) - m.update(u=self.ub) - m.codegen( - "osqp-codegen-prob", - parameters="matrices", - python_ext_name="emosqp", - force_rewrite=True, - ) - import emosqp - import emosqp_free - - self.prob_free = emosqp_free - self.prob = emosqp - elif self.backend in ["osqp", "cuosqp"]: - if self.backend == "osqp": - self.prob_free = osqp.OSQP() - self.prob = osqp.OSQP() - elif self.backend == "cuosqp": - self.prob_free = cuosqp.OSQP() - self.prob = cuosqp.OSQP() - P_copy = self.P.copy() - q_copy = self.q.copy() - A_copy = self.A.copy() - lb_copy = self.lb.copy() - ub_inf_copy = self.ub_inf.copy() - self.prob_free.setup( - P=P_copy, - q=q_copy, - A=A_copy, - l=lb_copy, - u=ub_inf_copy, - verbose=False, - polish=True, - warm_start=False, - # adaptive_rho=False, - eps_abs=1e-6, - eps_rel=1e-6, - eps_prim_inf=1e-7, - eps_dual_inf=1e-7, - # max_iter=int(1e5) if self.backend == "osqp" else None, - # eps_prim_inf=1e-8, - ) - P_copy = self.P.copy() - q_copy = self.q.copy() - A_copy = self.A.copy() - lb_copy = self.lb.copy() - ub_copy = self.ub.copy() - self.prob.setup( - P=P_copy, - q=q_copy, - A=A_copy, - l=lb_copy, - u=ub_copy, - verbose=False, - polish=True, - warm_start=False, - # adaptive_rho=False, - eps_abs=1e-6, - eps_rel=1e-6, - eps_prim_inf=1e-7, - eps_dual_inf=1e-7, - # max_iter=int(1e5) if self.backend == "osqp" else None, - # eps_prim_inf=1e-8, - ) - logger.debug(f"{self.backend} setup completed successfully") - - def _solve( - self, - amp_constraint: bool = True, - return_obj: bool = False, - update_cache: bool = False, - ) -> np.ndarray: - if amp_constraint: - prob = self.prob - else: - prob = self.prob_free - # if self.backend in ["osqp", "emosqp", "cuosqp"] and self.x_cache is not None: - # prob.warm_start(x=self.x_cache) - res = prob.solve() - if self.backend == "cvxpy": - opt_s = self.s.value.squeeze() - opt_b = 0 - elif self.backend in ["osqp", "emosqp", "cuosqp"]: - x = res[0] if self.backend == "emosqp" else res.x - if res.info.status not in 
["solved", "solved inaccurate"]: - warnings.warn("Problem not solved. status: {}".format(res.info.status)) - # osqp mistakenly report primal infeasibility when using masks - # with high l1 penalty. manually set solution to zero in such cases - if res.info.status in [ - "primal infeasible", - "primal infeasible inaccurate", - ]: - x = np.zeros_like(x, dtype=float) - else: - x = x.astype(float) - # if update_cache: - # self.x_cache = x - # prob.warm_start(x=x) - if self.norm == "huber": - xlen = len(self.nzidx_s) if self.free_kernel else len(self.nzidx_c) - sol = x[:xlen] - else: - sol = x - opt_b = sol[0] - if self.free_kernel: - opt_s = sol[1:] - else: - opt_s = self.G @ sol[1:] - if return_obj: - if self.backend == "cvxpy": - opt_obj = res - elif self.backend in ["osqp", "emosqp", "cuosqp"]: - opt_obj = self._compute_err() - return opt_s, opt_b, opt_obj - else: - return opt_s, opt_b - - def _compute_c(self, s: np.ndarray = None) -> np.ndarray: - if s is not None: - return self._convolve_s(s) - else: - if self.backend == "cvxpy": - return self.c.value.squeeze() - elif self.backend in ["osqp", "emosqp", "cuosqp"]: - return self._convolve_s(self.s) - - def _convolve_s(self, s: np.ndarray) -> sps.csc_array: - if self.H is not None: - return self.H @ sps.csc_matrix(s.reshape(-1, 1)) - else: - if s.dtype == np.bool_: - return sps.csc_matrix( - bin_convolve( - self.coef, s=s, nzidx_s=self.nzidx_s, s_len=self.T - ).reshape(-1, 1) - ) - else: - return sps.csc_matrix( - np.convolve(self.coef, self._pad_s(s))[self.nzidx_c].reshape(-1, 1) - ) - - def _compute_err( - self, - y_fit: np.ndarray = None, - b: np.ndarray = None, - c: np.ndarray = None, - s: np.ndarray = None, - res: np.ndarray = None, - obj_crit: str = None, - ) -> float: - if self.backend == "cvxpy": - # TODO: add support - raise NotImplementedError - elif self.backend in ["osqp", "emosqp", "cuosqp"]: - y = np.array(self.y) - if res is not None: - y = y - res - if b is None: - b = self.b - y = y - b - if y_fit is 
None: - if c is None: - c = self._compute_c(s) - y_fit = np.array((self.R @ c * self.scale).todense()).squeeze() - r = y - y_fit - err = self._res_err(r) - if obj_crit in [None, "spk_diff"]: - return np.array(err).item() - else: - nspk = (s > 0).sum() - if obj_crit == "mean_spk": - err_total = self._res_err(y - y.mean()) - return np.array((err - err_total) / nspk).item() - elif obj_crit in ["aic", "bic"]: - noise_model = "normal" - T = len(r) - if noise_model == "normal": - mu = r.mean() - sigma = ((r - mu) ** 2).sum() / T - logL = -0.5 * ( - T * np.log(2 * np.pi * sigma) - + 1 / sigma * ((r - mu) ** 2).sum() - ) - elif noise_model == "lognormal": - ymin = y.min() - logy = np.log(y - ymin + 1) - logy_hat = np.log(y_fit - ymin + 1) - logr = logy - logy_hat - mu = np.mean(logr) - sigma = ((logr - mu) ** 2).sum() / T - logL = np.sum( - -logy - - (logr - mu) ** 2 / (2 * sigma) - - 0.5 * np.log(2 * np.pi * sigma) - ) - if obj_crit == "aic": - return np.array(2 * (nspk - logL)).item() - elif obj_crit == "bic": - return np.array(nspk * np.log(T) - 2 * logL).item() - else: - raise ValueError("invalid objective criterion: {}".format(obj_crit)) - - def _res_err(self, r: np.ndarray): - if self.err_wt is not None: - r = self.err_wt * r - if self.norm == "l1": - return np.sum(np.abs(r)) - elif self.norm == "l2": - return np.sum((r) ** 2) - elif self.norm == "huber": - err_hub = huber(self.huber_k, r) - err_qud = r**2 / 2 - return np.sum(np.where(r >= 0, err_hub, err_qud)) - - def _reset_cache(self) -> None: - self.x_cache = None - - def _reset_mask(self) -> None: - self.nzidx_s = np.arange(self.T) - self.nzidx_c = np.arange(self.T) - self._update_R() - self._update_w() - if self.backend in ["osqp", "emosqp", "cuosqp"]: - self._setup_prob_osqp() - - def _update_mask(self, use_wt: bool = False, amp_constraint: bool = True) -> None: - if self.backend in ["osqp", "emosqp", "cuosqp"]: - if use_wt: - nzidx_s = np.where(self.R.T @ self.err_wt)[0] - else: - if self.masking_r is not 
None: - mask = np.zeros(self.T) - for nzidx in np.where(self._pad_s(self.s_bin) > 0)[0]: - mask[ - max(nzidx - self.masking_r, 0) : min( - nzidx + self.masking_r, self.T - ) - ] = 1 - nzidx_s = np.where(mask)[0] - else: - self._reset_mask() - opt_s, _ = self.solve(amp_constraint) - nzidx_s = np.where(opt_s > self.delta_penal)[0] - if len(nzidx_s) == 0: - raise ValueError - self.nzidx_s = nzidx_s - self._update_R() - self._update_w() - self._setup_prob_osqp() - if not self.free_kernel and len(self.nzidx_c) < self.T: - res = self.prob.solve() - # osqp mistakenly report primal infeasible in some cases - # disable masking in such cases - # potentially related: https://github.com/osqp/osqp/issues/485 - if res.info.status == "primal infeasible": - self._reset_mask() - else: - # TODO: add support - raise NotImplementedError("masking not supported for cvxpy backend") - - def _update_w(self, w_new=None) -> None: - if w_new is not None: - self.w_org = w_new - self.w = self.w_org[self.nzidx_s] - - def _update_R(self) -> None: - self.R_org = construct_R(self.y_len, self.upsamp) - self.R = self.R_org[:, self.nzidx_c] - - def _update_Wt(self, clear=False) -> None: - coef = self.coef.value if self.backend == "cvxpy" else self.coef - if clear: - logger.debug("Clearing error weighting") - self.err_wt = np.ones(self.y_len) - elif self.err_weighting == "fft": - logger.debug("Updating error weighting with fft") - hspec = self._get_stft_spec(coef)[:, int(self.coef_len / 2)] - self.err_wt = ( - (hspec.reshape(-1, 1) * self.yspec).sum(axis=0) - / np.linalg.norm(hspec) - / np.linalg.norm(self.yspec, axis=0) - ) - elif self.err_weighting == "corr": - logger.debug("Updating error weighting with corr") - for i in range(self.y_len): - yseg = self.y[i : i + self.coef_len] - if len(yseg) <= 1: - continue - cseg = coef[: len(yseg)] - with np.errstate(all="ignore"): - self.err_wt[i] = np.corrcoef(yseg, cseg)[0, 1].clip(0, 1) - self.err_wt = np.nan_to_num(self.err_wt) - elif self.err_weighting == 
"adaptive": - if self.s_bin is not None: - self.err_wt = np.zeros(self.y_len) - s_bin_R = self.R @ self._pad_s(self.s_bin) - for nzidx in np.where(s_bin_R > 0)[0]: - self.err_wt[nzidx : nzidx + self.wgt_len] = 1 - else: - self.err_wt = np.ones(self.y_len) - self.Wt = sps.diags(self.err_wt) - - def _update_HG(self) -> None: - coef = self.coef.value if self.backend == "cvxpy" else self.coef - if self.Hlim is None or self.T * self.coef_len < self.Hlim: - self.H_org = sps.diags( - [np.repeat(coef[i], self.T - i) for i in range(len(coef))], - offsets=-np.arange(len(coef)), - format="csc", - ) - try: - H_shape, H_nnz = self.H.shape, self.H.nnz - except AttributeError: - H_shape, H_nnz = None, None - self.H = self.H_org[:, self.nzidx_s][self.nzidx_c, :] - logger.debug( - f"Updating H matrix - shape before: {H_shape}, shape new: {self.H.shape}, nnz before: {H_nnz}, nnz new: {self.H.nnz}" - ) - if not self.free_kernel: - theta = self.theta.value if self.backend == "cvxpy" else self.theta - G_diag = sps.diags( - [np.ones(self.T - 1)] - + [np.repeat(-theta[i], self.T - 2 - i) for i in range(theta.shape[0])], - offsets=np.arange(0, -theta.shape[0] - 1, -1), - format="csc", - ) - self.G_org = sps.bmat( - [[None, G_diag], [np.zeros((1, 1)), None]], format="csc" - ) - try: - G_shape, G_nnz = self.G.shape, self.G.nnz - except AttributeError: - G_shape, G_nnz = None, None - self.G = self.G_org[:, self.nzidx_c][self.nzidx_s, :] - logger.debug( - f"Updating G matrix - shape before: {G_shape}, shape new: {self.G.shape}, nnz before: {G_nnz}, nnz new: {self.G.nnz}" - ) - # assert np.isclose( - # np.linalg.pinv(self.H.todense()), self.G.todense(), atol=self.atol - # ).all() - - def _update_wgt_len(self) -> None: - coef = self.coef.value if self.backend == "cvxpy" else self.coef - if self.wt_trunc_thres is not None: - trunc_len = int( - np.around(np.where(coef > self.wt_trunc_thres)[0][-1] / self.upsamp) - ) - if trunc_len == 0: - trunc_len = int(np.around(np.where(coef > 0)[0][-1] / 
self.upsamp)) - self.wgt_len = max(min(self.coef_len, trunc_len), 1) - else: - self.wgt_len = self.coef_len - - def _get_stft_spec(self, x: np.ndarray) -> np.ndarray: - spec = np.abs(self.stft.stft(x)) ** 2 - t = self.stft.t(len(x)) - t_mask = np.logical_and(t >= 0, t < len(x)) - return spec[:, t_mask] - - def _get_M(self) -> sps.csc_matrix: - if self.free_kernel: - return sps.hstack( - [ - np.ones((self.R.shape[0], 1)), - self.scale * self.R @ self.H, - ], - format="csc", - ) - else: - return sps.hstack( - [np.ones((self.R.shape[0], 1)), self.scale * self.R], format="csc" - ) - - def _update_P(self) -> None: - if self.norm == "l1": - # TODO: add support - raise NotImplementedError( - "l1 norm not yet supported with backend {}".format(self.backend) - ) - elif self.norm == "l2": - M = self._get_M() - P = M.T @ self.Wt.T @ self.Wt @ M - elif self.norm == "huber": - lc, ls, ly = len(self.nzidx_c), len(self.nzidx_s), self.y_len - if self.free_kernel: - P = sps.bmat( - [ - [sps.csc_matrix((ls, ls)), None, None], - [None, sps.csc_matrix((ly, ly)), None], - [None, None, sps.eye(ly, format="csc")], - ] - ) - else: - P = sps.bmat( - [ - [sps.csc_matrix((lc, lc)), None, None], - [None, sps.csc_matrix((ly, ly)), None], - [None, None, sps.eye(ly, format="csc")], - ] - ) - try: - P_shape, P_nnz = self.P.shape, self.P.nnz - except AttributeError: - P_shape, P_nnz = None, None - logger.debug( - f"Updating P matrix - shape before: {P_shape}, shape new: {P.shape}, nnz before: {P_nnz}, nnz new: {P.nnz}" - ) - self.P = sps.triu(P).tocsc() - - def _update_q0(self) -> None: - if self.norm == "l1": - # TODO: add support - raise NotImplementedError( - "l1 norm not yet supported with backend {}".format(self.backend) - ) - elif self.norm == "l2": - M = self._get_M() - self.q0 = -M.T @ self.Wt.T @ self.Wt @ self.y - elif self.norm == "huber": - ly = self.y_len - lx = len(self.nzidx_s) if self.free_kernel else len(self.nzidx_c) - self.q0 = ( - np.concatenate([np.zeros(lx), np.ones(ly), 
np.ones(ly)]) * self.huber_k - ) - - def _update_q(self) -> None: - if self.norm == "l1": - # TODO: add support - raise NotImplementedError( - "l1 norm not yet supported with backend {}".format(self.backend) - ) - elif self.norm == "l2": - if self.free_kernel: - ww = np.concatenate([np.zeros(1), self.w]) - qq = np.concatenate([np.zeros(1), np.ones_like(self.w)]) - self.q = self.q0 + self.l0_penal * ww + self.l1_penal * qq - else: - G_p = sps.hstack([np.zeros((self.G.shape[0], 1)), self.G], format="csc") - self.q = ( - self.q0 - + self.l0_penal * self.w @ G_p - + self.l1_penal * np.ones(self.G.shape[0]) @ G_p - ) - elif self.norm == "huber": - pad_k = np.zeros(self.y_len) - if self.free_kernel: - self.q = ( - self.q0 - + self.l0_penal * np.concatenate([self.w, pad_k, pad_k]) - + self.l1_penal - * np.concatenate([np.ones(len(self.nzidx_s)), pad_k, pad_k]) - ) - else: - self.q = ( - self.q0 - + self.l0_penal * np.concatenate([self.w @ self.G, pad_k, pad_k]) - + self.l1_penal - * np.concatenate([np.ones(self.G.shape[0]) @ self.G, pad_k, pad_k]) - ) - - def _update_A(self) -> None: - if self.free_kernel: - Ax = sps.eye(len(self.nzidx_s), format="csc") - Ar = self.scale * self.R @ self.H - else: - Ax = sps.csc_matrix(self.G_org[:, self.nzidx_c]) - # record spike terms that requires constraint - self.nzidx_A = np.where((Ax != 0).sum(axis=1))[0] - Ax = Ax[self.nzidx_A, :] - Ar = self.scale * self.R - try: - A_shape, A_nnz = self.A.shape, self.A.nnz - except AttributeError: - A_shape, A_nnz = None, None - if self.norm == "huber": - e = sps.eye(self.y_len, format="csc") - self.A = sps.bmat( - [ - [Ax, None, None], - [None, e, None], - [None, None, -e], - [Ar, e, e], - ], - format="csc", - ) - else: - self.A = sps.bmat([[np.ones((1, 1)), None], [None, Ax]], format="csc") - logger.debug( - f"Updating A matrix - shape before: {A_shape}, shape new: {self.A.shape}, nnz before: {A_nnz}, nnz new: {self.A.nnz}" - ) - - def _update_bounds(self) -> None: - if self.norm == "huber": - 
xlen = len(self.nzidx_s) if self.free_kernel else self.T - self.lb = np.concatenate( - [np.zeros(xlen + self.y_len * 2), self.y - self.huber_k] - ) - self.ub = np.concatenate( - [np.ones(xlen), np.full(self.y_len * 2, np.inf), self.y - self.huber_k] - ) - self.ub_inf = np.concatenate( - [np.full(xlen + self.y_len * 2, np.inf), self.y - self.huber_k] - ) - else: - bb = np.clip(self.y.mean(), 0, None) if self.use_base else 0 - if self.free_kernel: - self.lb = np.zeros(len(self.nzidx_s) + 1) - self.ub = np.concatenate([np.full(1, bb), np.ones(len(self.nzidx_s))]) - self.ub_inf = np.concatenate( - [np.full(1, bb), np.full(len(self.nzidx_s), np.inf)] - ) - else: - ub_pad, ub_inf_pad = np.zeros(self.T), np.zeros(self.T) - ub_pad[self.nzidx_s] = 1 - ub_inf_pad[self.nzidx_s] = np.inf - self.lb = np.zeros(len(self.nzidx_A) + 1) - self.ub = np.concatenate([np.full(1, bb), ub_pad[self.nzidx_A]]) - self.ub_inf = np.concatenate([np.full(1, bb), ub_inf_pad[self.nzidx_A]]) - assert (self.ub >= self.lb).all() - assert (self.ub_inf >= self.lb).all() diff --git a/src/indeca/core/deconv/solver.py b/src/indeca/core/deconv/solver.py new file mode 100644 index 0000000..12f2845 --- /dev/null +++ b/src/indeca/core/deconv/solver.py @@ -0,0 +1,840 @@ +"""Solver implementations for deconv.""" + +from abc import ABC, abstractmethod +from typing import Any, Tuple, Optional +import warnings + +import cvxpy as cp +import numpy as np +import osqp +import scipy.sparse as sps +from scipy.special import huber +from scipy.signal import ShortTimeFFT + +from indeca.utils.logging_config import get_module_logger +from indeca.core.simulation import tau2AR, solve_p, exp_pulse, ar_pulse +from indeca.utils.utils import scal_lstsq +from .config import DeconvConfig +from .utils import construct_R, bin_convolve, get_stft_spec + +logger = get_module_logger("deconv_solver") + +# Try to import GPU solver +try: + import cuosqp + HAS_CUOSQP = True +except ImportError: + HAS_CUOSQP = False + logger.debug("cuosqp not 
available") + + +class DeconvSolver(ABC): + """Abstract base class for deconvolution solvers.""" + + def __init__( + self, + config: DeconvConfig, + y_len: int, + y: np.ndarray | None = None, + coef: np.ndarray | None = None, + theta: np.ndarray | None = None, + tau: np.ndarray | None = None, + ps: np.ndarray | None = None, + ): + self.cfg = config + self.y_len = y_len + self.T = y_len * self.cfg.upsamp + self.y = y if y is not None else np.zeros(y_len) + self.coef = coef + self.coef_len = len(coef) if coef is not None else config.coef_len * config.upsamp + self.theta = theta + self.tau = tau + self.ps = ps + + # Scale tracking (mutable, since config is frozen) + self.scale = config.scale + + # Penalty tracking + self.l0_penal = 0.0 + self.l1_penal = 0.0 + + # Weight vectors + self.w_org = np.ones(self.T) + self.w = np.ones(self.T) + + # Masking indices + self.nzidx_s = np.arange(self.T) + self.nzidx_c = np.arange(self.T) + + # Matrices + self.R_org = construct_R(self.y_len, self.cfg.upsamp) + self.R = self.R_org + self.H = None + self.H_org = None + self.G = None + self.G_org = None + + # Cache + self.x_cache = None + self.s_bin = None # Binary spike solution from thresholding + + # Error weighting + self.err_wt = np.ones(self.y_len) + self.wgt_len = self.coef_len + self.Wt = sps.diags(self.err_wt) + + # Huber parameter + self.huber_k = 0.5 * np.std(self.y) if y is not None else 0 + + @abstractmethod + def update(self, **kwargs): + """Update solver parameters.""" + pass + + @abstractmethod + def solve(self, amp_constraint: bool = True) -> Tuple[np.ndarray, float, Any]: + """Solve the optimization problem.""" + pass + + def reset_cache(self) -> None: + """Reset solution cache.""" + self.x_cache = None + + def reset_mask(self) -> None: + """Reset masks to full range.""" + self.nzidx_s = np.arange(self.T) + self.nzidx_c = np.arange(self.T) + self._update_R() + self._update_w() + + def set_mask(self, nzidx_s: np.ndarray, nzidx_c: np.ndarray = None): + """Set mask 
indices. Override in subclasses that don't support masking.""" + self.nzidx_s = nzidx_s + # Old behavior (from `old deconv.py`): masking is applied to spike indices (s) + # while keeping the calcium state indices (c) unmasked unless explicitly provided. + if nzidx_c is not None: + self.nzidx_c = nzidx_c + self._update_R() + self._update_w() + + def _update_R(self) -> None: + """Update R matrix based on mask.""" + self.R = self.R_org[:, self.nzidx_c] + + def _update_w(self, w_new: np.ndarray = None) -> None: + """Update weight vector.""" + if w_new is not None: + self.w_org = w_new + self.w = self.w_org[self.nzidx_s] + + def _pad_s(self, s: np.ndarray = None) -> np.ndarray: + """Pad sparse s to full length.""" + if s is None: + s = np.zeros(len(self.nzidx_s)) + s_ret = np.zeros(self.T) + s_ret[self.nzidx_s] = s + return s_ret + + def _pad_c(self, c: np.ndarray = None) -> np.ndarray: + """Pad sparse c to full length.""" + if c is None: + c = np.zeros(len(self.nzidx_c)) + c_ret = np.zeros(self.T) + c_ret[self.nzidx_c] = c + return c_ret + + def _update_HG(self) -> None: + """Update H (convolution) and G (AR inverse) matrices.""" + coef = self.coef + if coef is None: + return + + # H matrix: convolution matrix + # IMPORTANT: in free-kernel mode the optimization uses R @ H explicitly, + # so H must always be materialized (do not drop it based on Hlim). 
+ if self.cfg.free_kernel or self.cfg.Hlim is None or self.T * len(coef) < self.cfg.Hlim: + self.H_org = sps.diags( + [np.repeat(coef[i], self.T - i) for i in range(len(coef))], + offsets=-np.arange(len(coef)), + format="csc", + ) + self.H = self.H_org[:, self.nzidx_s][self.nzidx_c, :] + logger.debug(f"Updated H matrix - shape: {self.H.shape}, nnz: {self.H.nnz}") + else: + self.H = None + self.H_org = None + + # G matrix: AR inverse (only if theta provided and not free_kernel) + if not self.cfg.free_kernel and self.theta is not None: + theta = self.theta + G_diag = sps.diags( + [np.ones(self.T - 1)] + + [np.repeat(-theta[i], self.T - 2 - i) for i in range(theta.shape[0])], + offsets=np.arange(0, -theta.shape[0] - 1, -1), + format="csc", + ) + self.G_org = sps.bmat( + [[None, G_diag], [np.zeros((1, 1)), None]], format="csc" + ) + self.G = self.G_org[:, self.nzidx_c][self.nzidx_s, :] + logger.debug(f"Updated G matrix - shape: {self.G.shape}, nnz: {self.G.nnz}") + else: + self.G = None + self.G_org = None + + def _update_wgt_len(self) -> None: + """Update error weighting length based on coefficient truncation.""" + coef = self.coef + if coef is None: + return + if self.cfg.wt_trunc_thres is not None: + trunc_idx = np.where(coef > self.cfg.wt_trunc_thres)[0] + if len(trunc_idx) > 0: + trunc_len = int(np.around(trunc_idx[-1] / self.cfg.upsamp)) + else: + trunc_len = int(np.around(np.where(coef > 0)[0][-1] / self.cfg.upsamp)) + if trunc_len == 0: + trunc_len = 1 + self.wgt_len = max(min(self.coef_len, trunc_len), 1) + else: + self.wgt_len = self.coef_len + + def convolve(self, s: np.ndarray) -> sps.csc_matrix: + """Convolve signal s with kernel. 
Returns sparse column matrix.""" + if self.cfg.free_kernel: + assert ( + self.H is not None + ), "Invariant violated: free_kernel=True requires a materialized H matrix" + if self.H is not None: + # Check if s is masked length or full length + if len(s) == len(self.nzidx_s): + result = self.H @ sps.csc_matrix(s.reshape(-1, 1)) + elif len(s) == self.T: + result = self.H @ sps.csc_matrix(s[self.nzidx_s].reshape(-1, 1)) + else: + logger.warning(f"Shape mismatch in convolve: s={len(s)}, nzidx_s={len(self.nzidx_s)}") + result = sps.csc_matrix(np.zeros((len(self.nzidx_c), 1))) + return result + else: + # Use bin_convolve for efficiency when H is not stored + if s.dtype == np.bool_: + out = bin_convolve(self.coef, s, nzidx_s=self.nzidx_s, s_len=self.T) + else: + s_pad = self._pad_s(s) if len(s) == len(self.nzidx_s) else s + out = np.convolve(self.coef, s_pad)[:self.T] + return sps.csc_matrix(out[self.nzidx_c].reshape(-1, 1)) + + def validate_coefficients(self, atol: float = 1e-3) -> bool: + """Validate that AR and exponential coefficients are consistent.""" + if self.tau is None or self.ps is None or self.theta is None: + logger.debug("Skipping coefficient validation - missing tau/ps/theta") + return True + + try: + # Generate exponential pulse + tr_exp, _, _ = exp_pulse( + self.tau[0], self.tau[1], + p_d=self.ps[0], p_r=self.ps[1], + nsamp=self.coef_len, + ) + + # Generate AR pulse + theta = self.theta + tr_ar, _, _ = ar_pulse(theta[0], theta[1], nsamp=self.coef_len, shifted=True) + + # Validate + if not (~np.isnan(self.coef)).all(): + logger.warning("Coefficient array contains NaN values") + return False + + if not np.isclose(tr_exp, self.coef[:len(tr_exp)], atol=atol).all(): + logger.warning("Exp time constant inconsistent with coefficients") + return False + + if not np.isclose(tr_ar, self.coef[:len(tr_ar)], atol=atol).all(): + logger.warning("AR coefficients inconsistent with coefficients") + return False + + logger.debug("Coefficient validation passed") + return True 
+ except Exception as e: + logger.warning(f"Coefficient validation failed: {e}") + return False + + +class CVXPYSolver(DeconvSolver): + """CVXPY backend solver.""" + + def __init__(self, config: DeconvConfig, y_len: int, **kwargs): + super().__init__(config, y_len, **kwargs) + self._update_HG() + self._update_wgt_len() + self._setup_problem() + + def set_mask(self, nzidx_s: np.ndarray, nzidx_c: np.ndarray = None): + """CVXPY does not support masking - raise error.""" + if len(nzidx_s) != self.T: + raise NotImplementedError( + "CVXPY backend does not support masking. Use OSQP backend instead." + ) + super().set_mask(nzidx_s, nzidx_c) + + def reset_mask(self) -> None: + """CVXPY does not support masking - no-op since problem is already full.""" + # CVXPY builds the full problem once, no mask support + # Just reset indices without rebuilding + self.nzidx_s = np.arange(self.T) + self.nzidx_c = np.arange(self.T) + + def _setup_problem(self): + """Setup CVXPY optimization problem.""" + # NOTE: `free_kernel=True` is forbidden with CVXPY backend (see `DeconvConfig`). 
+ self.cp_R = cp.Constant(self.R, name="R") + self.cp_c = cp.Variable((self.T, 1), nonneg=True, name="c") + self.cp_s = cp.Variable((self.T, 1), nonneg=True, name="s", boolean=self.cfg.mixin) + self.cp_y = cp.Parameter(shape=(self.y_len, 1), name="y") + self.cp_huber_k = cp.Parameter(value=float(self.huber_k), nonneg=True, name="huber_k") + + self.cp_scale = cp.Parameter(value=self.scale, name="scale", nonneg=True) + self.cp_l1_penal = cp.Parameter(value=0.0, name="l1_penal", nonneg=True) + self.cp_l0_w = cp.Parameter(shape=self.T, value=np.zeros(self.T), nonneg=True, name="w_l0") + + if self.y is not None: + self.cp_y.value = self.y.reshape((-1, 1)) + + if self.cfg.use_base: + self.cp_b = cp.Variable(nonneg=True, name="b") + else: + self.cp_b = cp.Constant(value=0, name="b") + + # Error term based on norm + term = self.cp_y - self.cp_scale * self.cp_R @ self.cp_c - self.cp_b + if self.cfg.norm == "l1": + self.err_term = cp.sum(cp.abs(term)) + elif self.cfg.norm == "l2": + self.err_term = cp.sum_squares(term) + elif self.cfg.norm == "huber": + # Keep huber parameter consistent with OSQP backend's `huber_k`. 
+ self.err_term = cp.sum(cp.huber(term, M=self.cp_huber_k)) + + # Objective + obj_expr = ( + self.err_term + + self.cp_l0_w.T @ cp.abs(self.cp_s) + + self.cp_l1_penal * cp.sum(cp.abs(self.cp_s)) + ) + obj = cp.Minimize(obj_expr) + + # Constraints + # AR constraint via G matrix + self.cp_theta = cp.Parameter( + value=self.theta, shape=self.theta.shape, name="theta" + ) + G_diag = sps.eye(self.T - 1) + sum( + [ + cp.diag(cp.promote(-self.cp_theta[i], (self.T - i - 2,)), -i - 1) + for i in range(self.theta.shape[0]) + ] + ) + G = cp.bmat( + [ + [np.zeros((self.T - 1, 1)), G_diag], + [np.zeros((1, 1)), np.zeros((1, self.T - 1))], + ] + ) + dcv_cons = [self.cp_s == G @ self.cp_c] + + edge_cons = [self.cp_c[0, 0] == 0, self.cp_s[-1, 0] == 0] + amp_cons = [self.cp_s <= 1] + + self.prob_free = cp.Problem(obj, dcv_cons + edge_cons) + self.prob = cp.Problem(obj, dcv_cons + edge_cons + amp_cons) + + def update( + self, + y: np.ndarray = None, + coef: np.ndarray = None, + scale: float = None, + scale_mul: float = None, + l1_penal: float = None, + l0_penal: float = None, + w: np.ndarray = None, + theta: np.ndarray = None, + **kwargs + ): + """Update CVXPY parameters.""" + if y is not None: + self.y = y + self.cp_y.value = y.reshape((-1, 1)) + # Keep huber_k consistent with OSQP backend / deconv objective. 
+ self.huber_k = 0.5 * np.std(self.y) + self.cp_huber_k.value = float(self.huber_k) + if coef is not None: + self.coef = coef + self._update_HG() + self._update_wgt_len() + if scale is not None: + self.scale = scale + self.cp_scale.value = scale + if scale_mul is not None: + self.scale *= scale_mul + self.cp_scale.value = self.scale + if l1_penal is not None: + self.l1_penal = l1_penal + self.cp_l1_penal.value = l1_penal + if l0_penal is not None: + self.l0_penal = l0_penal + if w is not None: + self._update_w(w) + if l0_penal is not None or w is not None: + self.cp_l0_w.value = self.l0_penal * self.w + if theta is not None and hasattr(self, 'cp_theta'): + self.theta = theta + self.cp_theta.value = theta + + def solve(self, amp_constraint: bool = True) -> Tuple[np.ndarray, float, Any]: + """Solve CVXPY problem.""" + prob = self.prob if amp_constraint else self.prob_free + try: + res = prob.solve() + except cp.error.SolverError as e: + logger.warning(f"CVXPY SolverError: {e}") + res = np.inf + + opt_s = self.cp_s.value.squeeze() if self.cp_s.value is not None else np.zeros(self.T) + opt_b = 0 + if self.cfg.use_base and hasattr(self.cp_b, 'value') and self.cp_b.value is not None: + opt_b = float(self.cp_b.value) + + return opt_s, opt_b, res + + +class OSQPSolver(DeconvSolver): + """OSQP backend solver (also handles cuosqp for GPU).""" + + def __init__(self, config: DeconvConfig, y_len: int, **kwargs): + super().__init__(config, y_len, **kwargs) + + # Additional state for OSQP + self.prob = None + self.prob_free = None + self.P = None + self.q = None + self.q0 = None + self.A = None + self.lb = None + self.ub = None + self.ub_inf = None + self.nzidx_A = None + + # STFT for FFT weighting + if self.cfg.err_weighting == "fft": + self.stft = ShortTimeFFT(win=np.ones(self.coef_len), hop=1, fs=1) + self.yspec = get_stft_spec(self.y, self.stft) + + # Initialize matrices and problem + self._update_HG() + self._update_wgt_len() + self._update_Wt() + self._setup_prob_osqp() + + 
def reset_mask(self) -> None: + """Reset masks to full range and rebuild OSQP problems.""" + super().reset_mask() + self._update_HG() + self._setup_prob_osqp() + + def set_mask(self, nzidx_s: np.ndarray, nzidx_c: np.ndarray = None): + """Set mask and rebuild problem.""" + super().set_mask(nzidx_s, nzidx_c) + self._update_HG() + self._setup_prob_osqp() + + def _update_Wt(self, clear: bool = False) -> None: + """Update error weighting matrix.""" + coef = self.coef + if clear: + logger.debug("Clearing error weighting") + self.err_wt = np.ones(self.y_len) + elif self.cfg.err_weighting == "fft" and hasattr(self, 'stft'): + logger.debug("Updating error weighting with fft") + hspec = get_stft_spec(coef, self.stft)[:, int(len(coef) / 2)] + self.err_wt = ( + (hspec.reshape(-1, 1) * self.yspec).sum(axis=0) + / np.linalg.norm(hspec) + / np.linalg.norm(self.yspec, axis=0) + ) + elif self.cfg.err_weighting == "corr": + logger.debug("Updating error weighting with corr") + self.err_wt = np.ones(self.y_len) + for i in range(self.y_len): + yseg = self.y[i : i + len(coef)] + if len(yseg) <= 1: + continue + cseg = coef[:len(yseg)] + with np.errstate(all="ignore"): + self.err_wt[i] = np.corrcoef(yseg, cseg)[0, 1].clip(0, 1) + self.err_wt = np.nan_to_num(self.err_wt) + elif self.cfg.err_weighting == "adaptive": + if self.s_bin is not None: + self.err_wt = np.zeros(self.y_len) + s_bin_R = self.R @ self._pad_s(self.s_bin) + for nzidx in np.where(s_bin_R > 0)[0]: + self.err_wt[nzidx : nzidx + self.wgt_len] = 1 + else: + self.err_wt = np.ones(self.y_len) + + self.Wt = sps.diags(self.err_wt) + + def _get_M(self) -> sps.csc_matrix: + """Get the combined model matrix M = [1, scale*R] or [1, scale*R@H].""" + if self.cfg.free_kernel: + return sps.hstack( + [np.ones((self.R.shape[0], 1)), self.scale * self.R @ self.H], + format="csc", + ) + else: + return sps.hstack( + [np.ones((self.R.shape[0], 1)), self.scale * self.R], + format="csc", + ) + + def _update_P(self) -> None: + """Update quadratic 
cost matrix P.""" + if self.cfg.norm == "l1": + raise NotImplementedError("l1 norm not yet supported with OSQP backend") + elif self.cfg.norm == "l2": + M = self._get_M() + P = M.T @ self.Wt.T @ self.Wt @ M + elif self.cfg.norm == "huber": + lc = len(self.nzidx_c) + ls = len(self.nzidx_s) + ly = self.y_len + if self.cfg.free_kernel: + P = sps.bmat([ + [sps.csc_matrix((ls + 1, ls + 1)), None, None], + [None, sps.csc_matrix((ly, ly)), None], + [None, None, sps.eye(ly, format="csc")], + ]) + else: + P = sps.bmat([ + [sps.csc_matrix((lc + 1, lc + 1)), None, None], + [None, sps.csc_matrix((ly, ly)), None], + [None, None, sps.eye(ly, format="csc")], + ]) + + self.P = sps.triu(P).tocsc() + logger.debug(f"Updated P matrix - shape: {self.P.shape}, nnz: {self.P.nnz}") + + def _update_q0(self) -> None: + """Update linear cost base q0.""" + if self.cfg.norm == "l1": + raise NotImplementedError("l1 norm not yet supported with OSQP backend") + elif self.cfg.norm == "l2": + M = self._get_M() + self.q0 = -M.T @ self.Wt.T @ self.Wt @ self.y + elif self.cfg.norm == "huber": + ly = self.y_len + lx = len(self.nzidx_s) + 1 if self.cfg.free_kernel else len(self.nzidx_c) + 1 + self.q0 = np.concatenate([np.zeros(lx), np.ones(ly), np.ones(ly)]) * self.huber_k + + def _update_q(self) -> None: + """Update linear cost vector q (including penalties).""" + if self.cfg.norm == "l1": + raise NotImplementedError("l1 norm not yet supported with OSQP backend") + elif self.cfg.norm == "l2": + if self.cfg.free_kernel: + ww = np.concatenate([np.zeros(1), self.w]) + qq = np.concatenate([np.zeros(1), np.ones_like(self.w)]) + self.q = self.q0 + self.l0_penal * ww + self.l1_penal * qq + else: + G_p = sps.hstack([np.zeros((self.G.shape[0], 1)), self.G], format="csc") + self.q = ( + self.q0 + + self.l0_penal * self.w @ G_p + + self.l1_penal * np.ones(self.G.shape[0]) @ G_p + ) + elif self.cfg.norm == "huber": + pad_k = np.zeros(self.y_len) + if self.cfg.free_kernel: + self.q = ( + self.q0 + + self.l0_penal * 
np.concatenate([[0], self.w, pad_k, pad_k]) + + self.l1_penal * np.concatenate([[0], np.ones(len(self.nzidx_s)), pad_k, pad_k]) + ) + else: + self.q = ( + self.q0 + + self.l0_penal * np.concatenate([[0], self.w @ self.G, pad_k, pad_k]) + + self.l1_penal * np.concatenate([[0], np.ones(self.G.shape[0]) @ self.G, pad_k, pad_k]) + ) + + def _update_A(self) -> None: + """Update constraint matrix A.""" + if self.cfg.free_kernel: + Ax = sps.eye(len(self.nzidx_s), format="csc") + Ar = self.scale * self.R @ self.H + else: + Ax = sps.csc_matrix(self.G_org[:, self.nzidx_c]) + # Record spike terms that require constraint + self.nzidx_A = np.where((Ax != 0).sum(axis=1))[0] + Ax = Ax[self.nzidx_A, :] + Ar = self.scale * self.R + + if self.cfg.norm == "huber": + e = sps.eye(self.y_len, format="csc") + self.A = sps.bmat([ + [sps.csc_matrix((Ax.shape[0], 1)), Ax, None, None], + [None, None, e, None], + [None, None, None, -e], + [np.ones((Ar.shape[0], 1)), Ar, e, e], + ], format="csc") + else: + self.A = sps.bmat([ + [np.ones((1, 1)), None], + [None, Ax] + ], format="csc") + + logger.debug(f"Updated A matrix - shape: {self.A.shape}, nnz: {self.A.nnz}") + + def _update_bounds(self) -> None: + """Update constraint bounds.""" + if self.cfg.norm == "huber": + xlen = len(self.nzidx_s) if self.cfg.free_kernel else len(self.nzidx_A) + self.lb = np.concatenate([ + np.zeros(xlen + self.y_len * 2), + self.y - self.huber_k + ]) + self.ub = np.concatenate([ + np.ones(xlen), + np.full(self.y_len * 2, np.inf), + self.y - self.huber_k + ]) + self.ub_inf = np.concatenate([ + np.full(xlen + self.y_len * 2, np.inf), + self.y - self.huber_k + ]) + else: + bb = np.clip(self.y.mean(), 0, None) if self.cfg.use_base else 0 + if self.cfg.free_kernel: + self.lb = np.zeros(len(self.nzidx_s) + 1) + self.ub = np.concatenate([np.full(1, bb), np.ones(len(self.nzidx_s))]) + self.ub_inf = np.concatenate([np.full(1, bb), np.full(len(self.nzidx_s), np.inf)]) + else: + ub_pad = np.zeros(self.T) + ub_inf_pad = 
np.zeros(self.T) + ub_pad[self.nzidx_s] = 1 + ub_inf_pad[self.nzidx_s] = np.inf + self.lb = np.zeros(len(self.nzidx_A) + 1) + self.ub = np.concatenate([np.full(1, bb), ub_pad[self.nzidx_A]]) + self.ub_inf = np.concatenate([np.full(1, bb), ub_inf_pad[self.nzidx_A]]) + + assert (self.ub >= self.lb).all(), "Upper bounds must be >= lower bounds" + assert (self.ub_inf >= self.lb).all(), "Upper bounds (inf) must be >= lower bounds" + + def _setup_prob_osqp(self) -> None: + """Setup OSQP problem instances.""" + logger.debug("Setting up OSQP problem") + + self._update_P() + self._update_q0() + self._update_q() + self._update_A() + self._update_bounds() + + # Choose solver backend + if self.cfg.backend == "cuosqp": + if not HAS_CUOSQP: + logger.warning("cuosqp not available, falling back to osqp") + self.prob = osqp.OSQP() + self.prob_free = osqp.OSQP() + else: + self.prob = cuosqp.OSQP() + self.prob_free = cuosqp.OSQP() + elif self.cfg.backend == "emosqp": + # Stub: emosqp requires codegen, not supported in this refactor + logger.warning("emosqp requires codegen, using osqp instead") + self.prob = osqp.OSQP() + self.prob_free = osqp.OSQP() + else: + self.prob = osqp.OSQP() + self.prob_free = osqp.OSQP() + + # Setup constrained problem + self.prob.setup( + P=self.P.copy(), + q=self.q.copy(), + A=self.A.copy(), + l=self.lb.copy(), + u=self.ub.copy(), + verbose=False, + polish=True, + warm_start=False, + eps_abs=1e-6, + eps_rel=1e-6, + eps_prim_inf=1e-7, + eps_dual_inf=1e-7, + ) + + # Setup unconstrained (free) problem + self.prob_free.setup( + P=self.P.copy(), + q=self.q.copy(), + A=self.A.copy(), + l=self.lb.copy(), + u=self.ub_inf.copy(), + verbose=False, + polish=True, + warm_start=False, + eps_abs=1e-6, + eps_rel=1e-6, + eps_prim_inf=1e-7, + eps_dual_inf=1e-7, + ) + + logger.debug(f"{self.cfg.backend} setup completed successfully") + + def update( + self, + y: np.ndarray = None, + coef: np.ndarray = None, + tau: np.ndarray = None, + theta: np.ndarray = None, + scale: 
float = None, + scale_mul: float = None, + l1_penal: float = None, + l0_penal: float = None, + w: np.ndarray = None, + update_weighting: bool = False, + clear_weighting: bool = False, + scale_coef: bool = False, + **kwargs + ): + """Update OSQP problem parameters.""" + logger.debug(f"Updating OSQP solver parameters") + + # Update input parameters + if y is not None: + self.y = y + # Match legacy behavior: huber_k tracks the current y (used in q0 and bounds) + self.huber_k = 0.5 * np.std(self.y) + if tau is not None: + theta_new = np.array(tau2AR(tau[0], tau[1])) + p = solve_p(tau[0], tau[1]) + coef_new, _, _ = exp_pulse( + tau[0], tau[1], p_d=p, p_r=-p, + nsamp=self.coef_len, kn_len=self.coef_len + ) + self.tau = tau + self.ps = np.array([p, -p]) + self.theta = theta_new + coef = coef_new + if theta is not None: + self.theta = theta + if coef is not None: + if scale_coef and self.coef is not None: + scale_mul = scal_lstsq(coef, self.coef).item() + self.coef = coef + if scale is not None: + self.scale = scale + if scale_mul is not None: + self.scale *= scale_mul + if l1_penal is not None: + self.l1_penal = l1_penal + if l0_penal is not None: + self.l0_penal = l0_penal + if w is not None: + self._update_w(w) + + # Track what needs updating + updt_HG = coef is not None + updt_P = False + updt_q0 = False + updt_q = False + updt_A = False + updt_bounds = False + setup_prob = False + + if updt_HG: + self._update_HG() + self._update_wgt_len() + + if self.cfg.err_weighting is not None and update_weighting: + self._update_Wt(clear=clear_weighting) + if self.cfg.err_weighting == "adaptive": + setup_prob = True + else: + updt_P = True + updt_q0 = True + updt_q = True + + if self.cfg.norm == "huber": + # huber_k changes require recomputing q and bounds + if y is not None: + self._update_q0() + updt_q0 = True + if any([scale is not None, scale_mul is not None, updt_HG]): + self._update_A() + updt_A = True + if any([w is not None, l0_penal is not None, l1_penal is not None, 
updt_HG]): + self._update_q() + updt_q = True + if y is not None: + self._update_bounds() + updt_bounds = True + else: + if any([updt_HG, updt_A]): + self._update_A() + updt_A = True + if any([scale is not None, scale_mul is not None, updt_HG, updt_P]): + self._update_P() + updt_P = True + if any([scale is not None, scale_mul is not None, y is not None, updt_HG, updt_q0]): + self._update_q0() + updt_q0 = True + if any([w is not None, l0_penal is not None, l1_penal is not None, updt_q0, updt_q]): + self._update_q() + updt_q = True + + # Apply updates to OSQP - conservative approach: + # Only q can be updated in-place safely. For P, A, bounds, rebuild. + if setup_prob or any([updt_P, updt_A, updt_bounds]): + self._setup_prob_osqp() + elif updt_q: + self.prob.update(q=self.q) + self.prob_free.update(q=self.q) + + logger.debug("OSQP problem updated") + + def solve(self, amp_constraint: bool = True) -> Tuple[np.ndarray, float, Any]: + """Solve OSQP problem.""" + prob = self.prob if amp_constraint else self.prob_free + res = prob.solve() + + if res.info.status not in ["solved", "solved inaccurate"]: + logger.warning(f"OSQP not solved: {res.info.status}") + if res.info.status in ["primal infeasible", "primal infeasible inaccurate"]: + x = np.zeros(self.P.shape[0], dtype=float) + else: + x = res.x.astype(float) if res.x is not None else np.zeros(self.P.shape[0], dtype=float) + else: + x = res.x + + if self.cfg.norm == "huber": + xlen = len(self.nzidx_s) + 1 if self.cfg.free_kernel else len(self.nzidx_c) + 1 + sol = x[:xlen] + opt_b = sol[0] + if self.cfg.free_kernel: + opt_s = sol[1:] + else: + opt_s = self.G @ sol[1:] + else: + opt_b = x[0] + if self.cfg.free_kernel: + opt_s = x[1:] + else: + c_sol = x[1:] + opt_s = self.G @ c_sol + + # Return 0 for objective - caller should use _compute_err for correct objective + return opt_s, opt_b, 0 diff --git a/src/indeca/core/deconv/utils.py b/src/indeca/core/deconv/utils.py new file mode 100644 index 0000000..e806446 --- /dev/null 
+++ b/src/indeca/core/deconv/utils.py @@ -0,0 +1,121 @@ +"""Utility functions for deconv module.""" + +import numpy as np +import scipy.sparse as sps +from numba import njit +from scipy.signal import ShortTimeFFT +from indeca.core.simulation import tau2AR + + +def get_stft_spec(x: np.ndarray, stft: ShortTimeFFT) -> np.ndarray: + """Compute STFT spectrogram.""" + spec = np.abs(stft.stft(x)) ** 2 + t = stft.t(len(x)) + t_mask = np.logical_and(t >= 0, t < len(x)) + return spec[:, t_mask] + + +def construct_R(T: int, up_factor: int): + """Construct the resampling matrix R.""" + if up_factor > 1: + return sps.csc_matrix( + ( + np.ones(T * up_factor), + (np.repeat(np.arange(T), up_factor), np.arange(T * up_factor)), + ), + shape=(T, T * up_factor), + ) + else: + return sps.eye(T, format="csc") + + +def sum_downsample(a, factor): + """Sum downsample array a by factor.""" + return np.convolve(a, np.ones(factor), mode="full")[factor - 1 :: factor] + + +def construct_G(fac: np.ndarray, T: int, fromTau=False): + """Construct the generator matrix G.""" + # I think we should be able to remove fromTau argument since we don't use it anywhere. + fac = np.array(fac) + assert fac.shape == (2,) + if fromTau: + fac = np.array(tau2AR(*fac)) + return sps.dia_matrix( + ( + np.tile(np.concatenate(([1], -fac)), (T, 1)).T, + -np.arange(len(fac) + 1), + ), + shape=(T, T), + ).tocsc() + + +def max_thres( + a: np.ndarray, + nthres: int, + th_min=0.1, + th_max=0.9, + ds=None, + return_thres=False, + th_amplitude=False, + delta=1e-6, + reverse_thres=False, + nz_only: bool = False, +): + """Threshold array a with nthres levels.""" + # Accept any array-like; normalized to numpy. 
+ a = np.asarray(a) + amax = a.max() + if reverse_thres: + thres = np.linspace(th_max, th_min, nthres) + else: + thres = np.linspace(th_min, th_max, nthres) + if th_amplitude: + S_ls = [np.floor_divide(a, (amax * th).clip(delta, None)) for th in thres] + else: + S_ls = [(a > (amax * th).clip(delta, None)) for th in thres] + if ds is not None: + S_ls = [sum_downsample(s, ds) for s in S_ls] + if nz_only: + Snz = [ss.sum() > 0 for ss in S_ls] + S_ls = [ss for ss, nz in zip(S_ls, Snz) if nz] + thres = [th for th, nz in zip(thres, Snz) if nz] + if return_thres: + return S_ls, thres + else: + return S_ls + + +@njit(nopython=True, nogil=True, cache=True) +def bin_convolve( + coef: np.ndarray, s: np.ndarray, nzidx_s: np.ndarray = None, s_len: int = None +): + """Binary convolution implemented in numba.""" + coef_len = len(coef) + if s_len is None: + s_len = len(s) + out = np.zeros(s_len) + nzidx = np.where(s)[0] + if nzidx_s is not None: + nzidx = nzidx_s[nzidx].astype( + np.int64 + ) # astype to fix numpa issues on GPU on Windows + for i0 in nzidx: + i1 = min(i0 + coef_len, s_len) + clen = i1 - i0 + out[i0:i1] += coef[:clen] + return out + + +@njit(nopython=True, nogil=True, cache=True) +def max_consecutive(arr): + """Find maximum consecutive ones.""" + max_count = 0 + current_count = 0 + for value in arr: + if value: + current_count += 1 + max_count = max(max_count, current_count) + else: + current_count = 0 + return max_count diff --git a/tests/unit/test_deconv_G_matrix.py b/tests/unit/test_deconv_G_matrix.py new file mode 100644 index 0000000..b64d5af --- /dev/null +++ b/tests/unit/test_deconv_G_matrix.py @@ -0,0 +1,59 @@ +import numpy as np + +from indeca.core.deconv import DeconvBin +from indeca.core.simulation import tau2AR + + +def test_G_matrix_matches_shifted_ar_difference_equation(): + """ + The legacy implementation constructs a (T x T) "G" operator such that: + - s = G @ c + - s[-1] == 0 (last row is zeros) + - for t >= 2: s[t] = c[t+1] - theta0*c[t] - 
theta1*c[t-1] + - for t == 1: s[1] = c[2] - theta0*c[1] + - for t == 0: s[0] = c[1] + + This test guards against accidental dimensional/shift regressions in `G_org`. + """ + T = 5 + theta = np.array(tau2AR(10.0, 3.0)) + + deconv = DeconvBin( + y_len=T, + theta=theta, + coef_len=3, + backend="osqp", + free_kernel=False, + use_base=False, + norm="l2", + penal=None, + ) + + # Use the full, unmasked operator for determinism. + deconv._reset_mask() + G = deconv.solver.G_org + assert G.shape == (T, T) + + # Build a c vector with the same boundary condition as the solver (c[0] == 0). + rng = np.random.default_rng(0) + c = rng.normal(size=T) + c[0] = 0.0 + + # `scipy.sparse_matrix @ np.ndarray` returns a dense np.ndarray (no `.todense()`). + s = np.asarray(G @ c.reshape(-1, 1)).squeeze() + + # Last element must be exactly 0 due to bottom row of zeros. + assert np.isclose(s[-1], 0.0) + + # Check the shifted AR-difference mapping on the first T-1 entries. + th0, th1 = float(theta[0]), float(theta[1]) + expected = np.zeros(T) + expected[0] = c[1] + expected[1] = c[2] - th0 * c[1] + expected[2] = c[3] - th0 * c[2] - th1 * c[1] + expected[3] = c[4] - th0 * c[3] - th1 * c[2] + expected[4] = 0.0 + + assert np.allclose(s, expected) + + From f8e9119bd0d345fed3b89d6d916abcae928d94e1 Mon Sep 17 00:00:00 2001 From: Daniel Aharoni Date: Fri, 12 Dec 2025 23:47:30 -0800 Subject: [PATCH 2/6] small fixes based on tests errors --- src/indeca/core/deconv/config.py | 2 +- src/indeca/core/deconv/deconv.py | 5 +++++ src/indeca/pipeline/pipeline.py | 2 +- tests/integration/test_deconv_masking.py | 4 ++-- tests/integration/test_deconv_solve.py | 2 +- 5 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/indeca/core/deconv/config.py b/src/indeca/core/deconv/config.py index d0d8eaa..4041613 100644 --- a/src/indeca/core/deconv/config.py +++ b/src/indeca/core/deconv/config.py @@ -59,7 +59,7 @@ class DeconvConfig(BaseModel): delta_penal: float = 1e-4 atol: float = 1e-3 rtol: float = 1e-3 
- Hlim: int = 1e5 + Hlim: Optional[int] = 1e5 @model_validator(mode="before") @classmethod diff --git a/src/indeca/core/deconv/deconv.py b/src/indeca/core/deconv/deconv.py index fb04d43..d616b83 100644 --- a/src/indeca/core/deconv/deconv.py +++ b/src/indeca/core/deconv/deconv.py @@ -236,6 +236,11 @@ def err_wt(self): """Error weighting vector.""" return self.solver.err_wt + @property + def wgt_len(self): + """Error weighting length.""" + return self.solver.wgt_len + @err_wt.setter def err_wt(self, value): """Allow direct assignment (used by tests/demo code).""" diff --git a/src/indeca/pipeline/pipeline.py b/src/indeca/pipeline/pipeline.py index 5c463d6..cfd4fd4 100644 --- a/src/indeca/pipeline/pipeline.py +++ b/src/indeca/pipeline/pipeline.py @@ -8,7 +8,7 @@ from indeca.core.AR_kernel import AR_upsamp_real, estimate_coefs, updateAR from indeca.dashboard.dashboard import Dashboard -from indeca.core.deconv.deconv import DeconvBin, construct_R +from indeca.core.deconv import DeconvBin, construct_R from indeca.utils.logging_config import get_module_logger from indeca.core.simulation import AR2tau, find_dhm, tau2AR from indeca.utils.utils import compute_dff diff --git a/tests/integration/test_deconv_masking.py b/tests/integration/test_deconv_masking.py index 5ec9cbd..ed55c57 100644 --- a/tests/integration/test_deconv_masking.py +++ b/tests/integration/test_deconv_masking.py @@ -18,10 +18,10 @@ def test_masking(self, taus, rand_seed, upsamp, eq_atol, test_fig_path_html): deconv, y, c, c_org, s, s_org, scale = fixt_deconv( taus=taus, rand_seed=rand_seed, upsamp=upsamp, deconv_kws={"Hlim": None} ) - s_nomsk, b_nomsk = deconv._solve(amp_constraint=False) + s_nomsk, b_nomsk = deconv.solve(amp_constraint=False, pks_polish=False) c_nomsk = deconv.H @ s_nomsk deconv._update_mask() - s_msk, b_msk = deconv._solve(amp_constraint=False) + s_msk, b_msk = deconv.solve(amp_constraint=False, pks_polish=False) c_msk = deconv.H @ s_msk s_msk = deconv._pad_s(s_msk) c_msk = 
deconv._pad_c(c_msk) diff --git a/tests/integration/test_deconv_solve.py b/tests/integration/test_deconv_solve.py index 1e50dea..47e799b 100644 --- a/tests/integration/test_deconv_solve.py +++ b/tests/integration/test_deconv_solve.py @@ -24,7 +24,7 @@ def test_solve(self, taus, rand_seed, backend, upsamp, eq_atol, test_fig_path_ht upsamp=upsamp, deconv_kws={"Hlim": None}, ) - R = deconv.R.value if backend == "cvxpy" else deconv.R + R = deconv.R s_solve, b_solve = deconv.solve(amp_constraint=False, pks_polish=True) c_solve = deconv.H @ s_solve c_solve_R = R @ c_solve From d654b1bd152688b4c8f3d6262dca1c5883ed1cfe Mon Sep 17 00:00:00 2001 From: Daniel Aharoni Date: Sat, 13 Dec 2025 09:49:08 -0800 Subject: [PATCH 3/6] add option to include legacy bug for comparing to old deconv code --- src/indeca/core/deconv/deconv.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/indeca/core/deconv/deconv.py b/src/indeca/core/deconv/deconv.py index d616b83..81d44e1 100644 --- a/src/indeca/core/deconv/deconv.py +++ b/src/indeca/core/deconv/deconv.py @@ -2,6 +2,7 @@ import itertools as itt import math +import os import warnings from typing import Tuple, Any, Optional @@ -654,18 +655,25 @@ def to_arr(m): scl_thres = np.max(y) * self.cfg.min_rel_scl valid_idx = np.where(np.array(scals) > scl_thres)[0] if len(valid_idx) > 0: + # Optional legacy compatibility: reproduce old deconv behavior where + # scale filtering shrinks `scals/bs` but does NOT shrink `svals/yfvals`. + # This leads to zip() truncation later and can change which candidate is selected. 
+ legacy_bug = os.environ.get("INDECA_LEGACY_SCALE_FILTER_BUG", "0") == "1" scals = [scals[i] for i in valid_idx] bs = [bs[i] for i in valid_idx] - svals = [svals[i] for i in valid_idx] - cvals = [cvals[i] for i in valid_idx] - yfvals = [yfvals[i] for i in valid_idx] + if not legacy_bug: + svals = [svals[i] for i in valid_idx] + cvals = [cvals[i] for i in valid_idx] + yfvals = [yfvals[i] for i in valid_idx] else: max_idx = np.argmax(scals) + legacy_bug = os.environ.get("INDECA_LEGACY_SCALE_FILTER_BUG", "0") == "1" scals = [scals[max_idx]] bs = [bs[max_idx]] - svals = [svals[max_idx]] - cvals = [cvals[max_idx]] - yfvals = [yfvals[max_idx]] + if not legacy_bug: + svals = [svals[max_idx]] + cvals = [cvals[max_idx]] + yfvals = [yfvals[max_idx]] else: scals = [self.scale] * len(yfvals) bs = [(y - res - scl * yf).mean() for scl, yf in zip(scals, yfvals)] From 474b6e03829a7448fdf724ae1b427bde0fd364be Mon Sep 17 00:00:00 2001 From: Daniel Aharoni Date: Sat, 13 Dec 2025 09:50:17 -0800 Subject: [PATCH 4/6] black formatting --- src/indeca/core/deconv/config.py | 7 +- src/indeca/core/deconv/deconv.py | 315 +++++++++++++++++------------ src/indeca/core/deconv/solver.py | 295 ++++++++++++++++----------- tests/unit/test_deconv_G_matrix.py | 2 - 4 files changed, 373 insertions(+), 246 deletions(-) diff --git a/src/indeca/core/deconv/config.py b/src/indeca/core/deconv/config.py index 4041613..bd4aff9 100644 --- a/src/indeca/core/deconv/config.py +++ b/src/indeca/core/deconv/config.py @@ -24,10 +24,12 @@ class DeconvConfig(BaseModel): False, description="Whether to use mixed-integer programming (boolean spikes)." ) backend: Literal["osqp", "cvxpy", "cuosqp"] = Field( - "osqp", description="Solver backend ('osqp', 'cvxpy', 'cuosqp'). Note: emosqp requires codegen and is not supported." + "osqp", + description="Solver backend ('osqp', 'cvxpy', 'cuosqp'). 
Note: emosqp requires codegen and is not supported.", ) free_kernel: bool = Field( - False, description="If True, use convolution constraint instead of AR constraint. Only supported with OSQP backends." + False, + description="If True, use convolution constraint instead of AR constraint. Only supported with OSQP backends.", ) nthres: int = Field(1000, description="Number of thresholds for thresholding step.") err_weighting: Optional[str] = Field( @@ -75,7 +77,6 @@ def resolve_auto_fields(cls, data): data["ncons_thres"] = upsamp + 1 return data - @model_validator(mode="after") def validate_penal(self): allowed = {None, "l0", "l1"} diff --git a/src/indeca/core/deconv/deconv.py b/src/indeca/core/deconv/deconv.py index 81d44e1..ea39772 100644 --- a/src/indeca/core/deconv/deconv.py +++ b/src/indeca/core/deconv/deconv.py @@ -27,7 +27,7 @@ class DeconvBin: """Deconvolution main class. - + This class wraps the solver backends and provides high-level methods for spike inference including thresholding, penalty optimization, and scale estimation. @@ -88,7 +88,7 @@ def __init__( self.theta = None self.tau = None self.ps = None - + # Compute coefficients from theta or tau if theta is not None: self.theta = np.array(theta) @@ -97,19 +97,27 @@ def __init__( self.tau = np.array([tau_d, tau_r]) self.ps = np.array([p, -p]) coef, _, _ = exp_pulse( - tau_d, tau_r, p_d=p, p_r=-p, + tau_d, + tau_r, + p_d=p, + p_r=-p, nsamp=coef_len * upsamp, kn_len=coef_len * upsamp, trunc_thres=atol, ) if tau is not None: - assert ps is not None, "exp coefficients must be provided together with time constants." + assert ( + ps is not None + ), "exp coefficients must be provided together with time constants." 
if theta is None: self.theta = np.array(tau2AR(tau[0], tau[1])) self.tau = np.array(tau) self.ps = ps coef, _, _ = exp_pulse( - tau[0], tau[1], p_d=ps[0], p_r=ps[1], + tau[0], + tau[1], + p_d=ps[0], + p_r=ps[1], nsamp=coef_len * upsamp, kn_len=coef_len * upsamp, trunc_thres=atol, @@ -121,7 +129,7 @@ def __init__( # `coef_len` (config) is the *base* kernel length; the stored `coef` is # already upsampled to length `coef_len * upsamp`. self.coef_len = len(coef) - + # Create config (note: frozen after creation) self.cfg = DeconvConfig( coef_len=coef_len, @@ -158,29 +166,39 @@ def __init__( # Dashboard for visualization self.dashboard = dashboard self.dashboard_uid = dashboard_uid - + # Penalty tracking - solver tracks scale, we track penalty locally self._l0_penal = 0.0 self._l1_penal = 0.0 - + # Create solver if self.cfg.backend == "cvxpy": if self.cfg.free_kernel: - raise NotImplementedError("CVXPY backend does not support free_kernel mode") + raise NotImplementedError( + "CVXPY backend does not support free_kernel mode" + ) self.solver = CVXPYSolver( - self.cfg, self.y_len, - y=self.y, coef=coef, theta=self.theta, - tau=self.tau, ps=self.ps + self.cfg, + self.y_len, + y=self.y, + coef=coef, + theta=self.theta, + tau=self.tau, + ps=self.ps, ) elif self.cfg.backend in ["osqp", "cuosqp"]: self.solver = OSQPSolver( - self.cfg, self.y_len, - y=self.y, coef=coef, theta=self.theta, - tau=self.tau, ps=self.ps + self.cfg, + self.y_len, + y=self.y, + coef=coef, + theta=self.theta, + tau=self.tau, + ps=self.ps, ) else: raise ValueError(f"Unknown backend: {self.cfg.backend}") - + # State self.T = self.solver.T self.s = np.zeros(self.T) @@ -189,11 +207,11 @@ def __init__( self.s_bin = None # NOTE: do not store an "err_total" here. `_res_err` expects a residual, # not the raw `y`, and this value was misleading and unused. 
- + # Update dashboard with initial kernel if self.dashboard is not None: self.dashboard.update(h=coef, uid=self.dashboard_uid) - + # Validate coefficients self.solver.validate_coefficients(atol=atol) @@ -264,33 +282,39 @@ def update( ) -> None: """Update parameters.""" logger.debug(f"Updating parameters - backend: {self.cfg.backend}") - + theta_new = None if tau is not None: theta_new = np.array(tau2AR(tau[0], tau[1])) p = solve_p(tau[0], tau[1]) coef_new, _, _ = exp_pulse( - tau[0], tau[1], p_d=p, p_r=-p, + tau[0], + tau[1], + p_d=p, + p_r=-p, nsamp=self.cfg.coef_len * self.cfg.upsamp, - kn_len=self.cfg.coef_len * self.cfg.upsamp + kn_len=self.cfg.coef_len * self.cfg.upsamp, ) coef = coef_new self.tau = tau self.theta = theta_new self.ps = np.array([p, -p]) - + if coef is not None and scale_coef: - current_coef = self.solver.coef if self.solver.coef is not None else np.ones_like(coef) + current_coef = ( + self.solver.coef if self.solver.coef is not None else np.ones_like(coef) + ) scale_mul = scal_lstsq(coef, current_coef).item() - + if l0_penal is not None: self._l0_penal = l0_penal if l1_penal is not None: self._l1_penal = l1_penal - + # Forward updates to solver (solver handles scale directly) self.solver.update( - y=y, coef=coef, + y=y, + coef=coef, scale=scale, scale_mul=scale_mul, l1_penal=self._l1_penal if l1_penal is not None else None, @@ -301,7 +325,7 @@ def update( clear_weighting=clear_weighting, scale_coef=scale_coef, ) - + if y is not None: self.y = y @@ -326,7 +350,7 @@ def _update_mask(self, use_wt: bool = False, amp_constraint: bool = True) -> Non # CVXPY doesn't support masking if self.cfg.backend == "cvxpy": return - + if self.cfg.backend in ["osqp", "cuosqp"]: if use_wt: nzidx_s = np.where(self.R.T @ self.err_wt)[0] @@ -342,14 +366,14 @@ def _update_mask(self, use_wt: bool = False, amp_constraint: bool = True) -> Non self._reset_mask() opt_s, _ = self.solve(amp_constraint) nzidx_s = np.where(opt_s > self.cfg.delta_penal)[0] - + if len(nzidx_s) 
== 0: logger.warning("Empty mask, resetting") self._reset_mask() return - + self.solver.set_mask(nzidx_s) - + # Verify mask is valid if not self.cfg.free_kernel and len(self.nzidx_c) < self.T: res = self.solver.prob.solve() @@ -373,7 +397,7 @@ def _cut_pks_labs(self, s, labs, pks): pk_labs[p_start:p_stop] = lb lb += 1 p_start = p_stop - pk_labs[p_stop:lb_idxs[-1] + 1] = lb + pk_labs[p_stop : lb_idxs[-1] + 1] = lb lb += 1 else: pk_labs[lb_idxs] = lb @@ -442,45 +466,51 @@ def solve( cur_s, cur_b, _ = self.solver.solve(amp_constraint=amp_constraint) # Compute objective explicitly since solver returns 0 cur_obj = self._compute_err(s=cur_s, b=cur_b) - + if metric_df is None: obj_best = np.inf obj_last = np.inf else: - obj_best = metric_df["obj"][1:].min() if len(metric_df) > 1 else np.inf + obj_best = ( + metric_df["obj"][1:].min() if len(metric_df) > 1 else np.inf + ) obj_last = np.array(metric_df["obj"])[-1] - + opt_s = np.where(cur_s > self.cfg.delta_l0, cur_s, 0) obj_gap = np.abs(cur_obj - obj_best) obj_delta = np.abs(cur_obj - obj_last) - - cur_met = pd.DataFrame([{ - "iter": i, - "obj": cur_obj, - "nnz": (opt_s > 0).sum(), - "obj_gap": obj_gap, - "obj_delta": obj_delta, - }]) + + cur_met = pd.DataFrame( + [ + { + "iter": i, + "obj": cur_obj, + "nnz": (opt_s > 0).sum(), + "obj_gap": obj_gap, + "obj_delta": obj_delta, + } + ] + ) metric_df = pd.concat([metric_df, cur_met], ignore_index=True) - - if any([ - obj_gap < self.cfg.rtol * obj_best, - obj_delta < self.cfg.atol - ]): + + if any([obj_gap < self.cfg.rtol * obj_best, obj_delta < self.cfg.atol]): break else: w_new = np.clip( np.ones(self.T) / (self.cfg.delta_l0 * np.ones(self.T) + opt_s), - 0, 1e5 + 0, + 1e5, ) self.update(w=w_new) else: - warnings.warn(f"l0 heuristic did not converge in {self.cfg.max_iter_l0} iterations") - + warnings.warn( + f"l0 heuristic did not converge in {self.cfg.max_iter_l0} iterations" + ) + opt_s, opt_b, _ = self.solver.solve(amp_constraint=amp_constraint) - + self.b = opt_b - + # 
Peak polishing if pks_polish is None: pks_polish = amp_constraint @@ -495,7 +525,7 @@ def solve( opt_s = self._merge_sparse_regs(s=s_ft, regs=labs, err_rtol=pks_err_rtol) if len(opt_s) == self.T: opt_s = opt_s[self.nzidx_s] - + self.s = np.abs(opt_s) return self.s, self.b @@ -513,14 +543,14 @@ def _res_err(self, r: np.ndarray) -> float: if self.cfg.norm == "l1": return np.sum(np.abs(r)) elif self.cfg.norm == "l2": - return np.sum(r ** 2) + return np.sum(r**2) elif self.cfg.norm == "huber": # True Huber loss: # 0.5*r^2 if |r| <= k # k*(|r| - 0.5*k) otherwise k = float(self.solver.huber_k) ar = np.abs(r) - quad = 0.5 * (r ** 2) + quad = 0.5 * (r**2) lin = k * (ar - 0.5 * k) return float(np.sum(np.where(ar <= k, quad, lin))) @@ -540,7 +570,7 @@ def _compute_err( if b is None: b = self.b y = y - b - + if y_fit is None: if c is None: c = self._compute_c(s) @@ -549,10 +579,10 @@ def _compute_err( y_fit = np.array((R @ c * self.scale).todense()).squeeze() else: y_fit = np.array(R @ c * self.scale).squeeze() - + r = y - y_fit err = self._res_err(r) - + if obj_crit in [None, "spk_diff"]: return float(err) else: @@ -564,7 +594,9 @@ def _compute_err( T = len(r) mu = r.mean() sigma = max(((r - mu) ** 2).sum() / T, 1e-10) - logL = -0.5 * (T * np.log(2 * np.pi * sigma) + 1 / sigma * ((r - mu) ** 2).sum()) + logL = -0.5 * ( + T * np.log(2 * np.pi * sigma) + 1 / sigma * ((r - mu) ** 2).sum() + ) if obj_crit == "aic": return float(2 * (nspk - logL)) elif obj_crit == "bic": @@ -584,24 +616,32 @@ def _max_thres(self, s, nz_only=True): ) # Ensure we return numpy arrays (not xarray.DataArray) S_ls = [np.array(ss) for ss in S_ls] - + # Apply density threshold if self.cfg.density_thres is not None: Sden = [ss.sum() / self.T for ss in S_ls] S_ls = [ss for ss, den in zip(S_ls, Sden) if den < self.cfg.density_thres] thres = [th for th, den in zip(thres, Sden) if den < self.cfg.density_thres] - + # Apply consecutive threshold if self.cfg.ncons_thres is not None: S_pad = [self._pad_s(ss) for 
ss in S_ls] Sncons = [max_consecutive(np.array(ss)) for ss in S_pad] if len(Sncons) > 0 and min(Sncons) < self.cfg.ncons_thres: - S_ls = [ss for ss, ncons in zip(S_ls, Sncons) if ncons <= self.cfg.ncons_thres] - thres = [th for th, ncons in zip(thres, Sncons) if ncons <= self.cfg.ncons_thres] + S_ls = [ + ss + for ss, ncons in zip(S_ls, Sncons) + if ncons <= self.cfg.ncons_thres + ] + thres = [ + th + for th, ncons in zip(thres, Sncons) + if ncons <= self.cfg.ncons_thres + ] elif len(S_ls) > 0: S_ls = [S_ls[0]] thres = [thres[0]] - + return S_ls, thres def solve_thres( @@ -617,7 +657,7 @@ def solve_thres( y = np.array(self.y) opt_s, opt_b = self.solve(amp_constraint=amp_constraint, pks_polish=pks_polish) R = self.R - + if ignore_res: c = self._compute_c(opt_s) if sps.issparse(c): @@ -626,7 +666,7 @@ def solve_thres( res = y - opt_b - self.scale * (R @ c).squeeze() else: res = np.zeros_like(y) - + svals, thres = self._max_thres(opt_s) if not len(svals) > 0: if return_intm: @@ -638,19 +678,23 @@ def solve_thres( 0, np.inf, ) - + cvals = [self._compute_c(s) for s in svals] - + def to_arr(m): - return np.array(m.todense()).squeeze() if sps.issparse(m) else np.array(m).squeeze() - + return ( + np.array(m.todense()).squeeze() + if sps.issparse(m) + else np.array(m).squeeze() + ) + yfvals = [to_arr(R @ c) for c in cvals] - + if scaling: scal_fit = [scal_lstsq(yf, y - res, fit_intercept=True) for yf in yfvals] scals = [sf[0] for sf in scal_fit] bs = [sf[1] for sf in scal_fit] - + if self.cfg.min_rel_scl is not None: scl_thres = np.max(y) * self.cfg.min_rel_scl valid_idx = np.where(np.array(scals) > scl_thres)[0] @@ -658,7 +702,9 @@ def to_arr(m): # Optional legacy compatibility: reproduce old deconv behavior where # scale filtering shrinks `scals/bs` but does NOT shrink `svals/yfvals`. # This leads to zip() truncation later and can change which candidate is selected. 
- legacy_bug = os.environ.get("INDECA_LEGACY_SCALE_FILTER_BUG", "0") == "1" + legacy_bug = ( + os.environ.get("INDECA_LEGACY_SCALE_FILTER_BUG", "0") == "1" + ) scals = [scals[i] for i in valid_idx] bs = [bs[i] for i in valid_idx] if not legacy_bug: @@ -667,7 +713,9 @@ def to_arr(m): yfvals = [yfvals[i] for i in valid_idx] else: max_idx = np.argmax(scals) - legacy_bug = os.environ.get("INDECA_LEGACY_SCALE_FILTER_BUG", "0") == "1" + legacy_bug = ( + os.environ.get("INDECA_LEGACY_SCALE_FILTER_BUG", "0") == "1" + ) scals = [scals[max_idx]] bs = [bs[max_idx]] if not legacy_bug: @@ -677,15 +725,15 @@ def to_arr(m): else: scals = [self.scale] * len(yfvals) bs = [(y - res - scl * yf).mean() for scl, yf in zip(scals, yfvals)] - + objs = [ self._compute_err(s=ss, y_fit=scl * yf, res=res, b=bb, obj_crit=obj_crit) for ss, scl, yf, bb in zip(svals, scals, yfvals, bs) ] - + scals = np.array(scals).clip(0, None) objs = np.where(scals > 0, objs, np.inf) - + if obj_crit == "spk_diff": err_null = self._compute_err( s=np.zeros_like(opt_s), res=res, b=opt_b, obj_crit=obj_crit @@ -697,20 +745,27 @@ def to_arr(m): nspk_diff = np.where(nspk_diff == 0, 1, nspk_diff) # Avoid division by zero merr_diff = objs_diff / nspk_diff avg_err = (objs_pad.min() - err_null) / max(nspk.max(), 1) - opt_idx = int(np.max(np.where(merr_diff < avg_err)[0])) if np.any(merr_diff < avg_err) else 0 + opt_idx = ( + int(np.max(np.where(merr_diff < avg_err)[0])) + if np.any(merr_diff < avg_err) + else 0 + ) objs = merr_diff else: opt_idx = int(np.argmin(objs)) - + s_bin = svals[opt_idx] self.s_bin = s_bin self.c_bin = to_arr(cvals[opt_idx]) self.b = bs[opt_idx] self.solver.s_bin = s_bin # Update solver's s_bin for adaptive weighting - + if return_intm: return ( - self.s_bin, self.c_bin, scals[opt_idx], objs[opt_idx], + self.s_bin, + self.c_bin, + scals[opt_idx], + objs[opt_idx], (opt_s, thres, svals, cvals, yfvals, scals, objs, opt_idx), ) else: @@ -728,14 +783,14 @@ def solve_penal( if return_intm: return opt_s, 
opt_c, opt_scl, opt_obj, opt_penal, None return opt_s, opt_c, opt_scl, opt_obj, opt_penal - + pn = f"{self.cfg.penal}_penal" self.update(**{pn: 0}) - + if masking: self._reset_cache() self._update_mask() - + s_nopn, _, _, err_nopn, intm = self.solve_thres( scaling=scaling, return_intm=True, pks_polish=pks_polish ) @@ -743,7 +798,7 @@ def solve_penal( ymean = self.y.mean() err_full = self._compute_err(s=np.zeros(len(self.nzidx_s)), b=ymean) err_min = self._compute_err(s=s_min) - + # Find upper bound for penalty ub, ub_last = err_full, err_full for _ in range(int(np.ceil(np.log2(ub + 1)))): @@ -756,17 +811,17 @@ def solve_penal( else: ub_last = ub ub = ub / 2 - + def opt_fn(x): self.update(**{pn: float(x)}) _, _, _, obj = self.solve_thres(scaling=False, pks_polish=pks_polish) if self.dashboard is not None: self.dashboard.update( uid=self.dashboard_uid, - penal_err={"penal": float(x), "scale": self.scale, "err": obj} + penal_err={"penal": float(x), "scale": self.scale, "err": obj}, ) return obj if obj < err_full else np.inf - + try: res = direct( opt_fn, @@ -777,7 +832,9 @@ def opt_fn(x): ) direct_pn = res.x if not res.success: - logger.warning(f"Could not find optimal penalty within {res.nfev} iterations") + logger.warning( + f"Could not find optimal penalty within {res.nfev} iterations" + ) opt_penal = 0 elif err_nopn <= opt_fn(direct_pn): opt_penal = 0 @@ -786,7 +843,7 @@ def opt_fn(x): except Exception as e: logger.warning(f"DIRECT optimization failed: {e}") opt_penal = 0 - + self.update(**{pn: opt_penal}) if return_intm: opt_s, opt_c, opt_scl, opt_obj, intm = self.solve_thres( @@ -814,15 +871,15 @@ def solve_scale( if self.cfg.penal in ["l0", "l1"]: pn = f"{self.cfg.penal}_penal" self.update(**{pn: 0}) - + self._reset_cache() self._reset_mask() - + if reset_scale: self.update(scale=1) s_free, _ = self.solve(amp_constraint=False) self.update(scale=np.ptp(s_free)) - + metric_df = None for i in range(self.cfg.max_iter_scal): if concur_penal: @@ -837,7 +894,7 @@ def 
solve_scale( pks_polish=self.cfg.pks_polish and (i > 1 or not reset_scale), obj_crit=obj_crit, ) - + if self.dashboard is not None: pad_s = np.zeros(self.T) pad_s[self.nzidx_s] = cur_s @@ -847,7 +904,7 @@ def solve_scale( s=self.R_org @ pad_s, scale=cur_scl, ) - + if metric_df is None: prev_scals = np.array([np.inf]) opt_obj = np.inf @@ -861,34 +918,40 @@ def solve_scale( prev_scals = np.array(metric_df["scale"]) last_scal = prev_scals[-1] last_obj = np.array(metric_df["obj"])[-1] - + y_wt = np.array(self.y * self.err_wt) err_tt = self._res_err(y_wt - y_wt.mean()) cur_obj = (cur_obj_raw - err_tt) / max(err_tt, 1e-10) - - cur_met = pd.DataFrame([{ - "iter": i, - "scale": cur_scl, - "obj_raw": cur_obj_raw, - "obj": cur_obj, - "penal": cur_penal, - "nnz": (cur_s > 0).sum(), - "density": (cur_s > 0).sum() / self.T, - }]) + + cur_met = pd.DataFrame( + [ + { + "iter": i, + "scale": cur_scl, + "obj_raw": cur_obj_raw, + "obj": cur_obj, + "penal": cur_penal, + "nnz": (cur_s > 0).sum(), + "density": (cur_s > 0).sum() / self.T, + } + ] + ) metric_df = pd.concat([metric_df, cur_met], ignore_index=True) - + if self.cfg.err_weighting == "adaptive" and i <= 1: self.update(update_weighting=True) if masking and i >= 1: self._update_mask() - - if any([ - np.abs(cur_scl - opt_scal) < self.cfg.rtol * opt_scal, - np.abs(cur_obj - opt_obj) < self.cfg.rtol * opt_obj, - np.abs(cur_scl - last_scal) < self.cfg.atol, - np.abs(cur_obj - last_obj) < self.cfg.atol * 1e-3, - early_stop and cur_obj > last_obj, - ]): + + if any( + [ + np.abs(cur_scl - opt_scal) < self.cfg.rtol * opt_scal, + np.abs(cur_obj - opt_obj) < self.cfg.rtol * opt_obj, + np.abs(cur_scl - last_scal) < self.cfg.atol, + np.abs(cur_obj - last_obj) < self.cfg.atol * 1e-3, + early_stop and cur_obj > last_obj, + ] + ): break elif cur_scl == 0: warnings.warn("Exit with zero solution") @@ -899,32 +962,34 @@ def solve_scale( self.update(scale=cur_scl) else: warnings.warn("Max scale iterations reached") - + # Final solve with optimal 
scale opt_idx = metric_df["obj"].idxmin() self.update(update_weighting=True, clear_weighting=True) self._reset_cache() self._reset_mask() self.update(scale=float(metric_df.loc[opt_idx, "scale"])) - + cur_s, cur_c, cur_scl, cur_obj, cur_penal = self.solve_penal( scaling=False, masking=False, pks_polish=self.cfg.pks_polish ) - + opt_s = np.zeros(self.T) opt_c = np.zeros(self.T) opt_s[self.nzidx_s] = cur_s - opt_c[self.nzidx_c] = cur_c if not sps.issparse(cur_c) else cur_c.toarray().squeeze() + opt_c[self.nzidx_c] = ( + cur_c if not sps.issparse(cur_c) else cur_c.toarray().squeeze() + ) nnz = int(opt_s.sum()) - + self.update(update_weighting=True) y_wt = np.array(self.y * self.err_wt) err_tt = self._res_err(y_wt - y_wt.mean()) err_cur = self._compute_err(s=opt_s) err_rel = (err_cur - err_tt) / max(err_tt, 1e-10) - + self.update(update_weighting=True, clear_weighting=True) - + if self.dashboard is not None: self.dashboard.update( uid=self.dashboard_uid, @@ -932,10 +997,10 @@ def solve_scale( s=self.R_org @ opt_s, scale=cur_scl, ) - + self._reset_cache() self._reset_mask() - + if return_met: return opt_s, opt_c, cur_scl, cur_obj, err_rel, nnz, cur_penal, metric_df else: diff --git a/src/indeca/core/deconv/solver.py b/src/indeca/core/deconv/solver.py index 12f2845..03279f1 100644 --- a/src/indeca/core/deconv/solver.py +++ b/src/indeca/core/deconv/solver.py @@ -22,6 +22,7 @@ # Try to import GPU solver try: import cuosqp + HAS_CUOSQP = True except ImportError: HAS_CUOSQP = False @@ -46,26 +47,28 @@ def __init__( self.T = y_len * self.cfg.upsamp self.y = y if y is not None else np.zeros(y_len) self.coef = coef - self.coef_len = len(coef) if coef is not None else config.coef_len * config.upsamp + self.coef_len = ( + len(coef) if coef is not None else config.coef_len * config.upsamp + ) self.theta = theta self.tau = tau self.ps = ps - + # Scale tracking (mutable, since config is frozen) self.scale = config.scale - + # Penalty tracking self.l0_penal = 0.0 self.l1_penal = 0.0 - 
+ # Weight vectors self.w_org = np.ones(self.T) self.w = np.ones(self.T) - + # Masking indices self.nzidx_s = np.arange(self.T) self.nzidx_c = np.arange(self.T) - + # Matrices self.R_org = construct_R(self.y_len, self.cfg.upsamp) self.R = self.R_org @@ -73,16 +76,16 @@ def __init__( self.H_org = None self.G = None self.G_org = None - + # Cache self.x_cache = None self.s_bin = None # Binary spike solution from thresholding - + # Error weighting self.err_wt = np.ones(self.y_len) self.wgt_len = self.coef_len self.Wt = sps.diags(self.err_wt) - + # Huber parameter self.huber_k = 0.5 * np.std(self.y) if y is not None else 0 @@ -148,11 +151,15 @@ def _update_HG(self) -> None: coef = self.coef if coef is None: return - + # H matrix: convolution matrix # IMPORTANT: in free-kernel mode the optimization uses R @ H explicitly, # so H must always be materialized (do not drop it based on Hlim). - if self.cfg.free_kernel or self.cfg.Hlim is None or self.T * len(coef) < self.cfg.Hlim: + if ( + self.cfg.free_kernel + or self.cfg.Hlim is None + or self.T * len(coef) < self.cfg.Hlim + ): self.H_org = sps.diags( [np.repeat(coef[i], self.T - i) for i in range(len(coef))], offsets=-np.arange(len(coef)), @@ -212,7 +219,9 @@ def convolve(self, s: np.ndarray) -> sps.csc_matrix: elif len(s) == self.T: result = self.H @ sps.csc_matrix(s[self.nzidx_s].reshape(-1, 1)) else: - logger.warning(f"Shape mismatch in convolve: s={len(s)}, nzidx_s={len(self.nzidx_s)}") + logger.warning( + f"Shape mismatch in convolve: s={len(s)}, nzidx_s={len(self.nzidx_s)}" + ) result = sps.csc_matrix(np.zeros((len(self.nzidx_c), 1))) return result else: @@ -221,7 +230,7 @@ def convolve(self, s: np.ndarray) -> sps.csc_matrix: out = bin_convolve(self.coef, s, nzidx_s=self.nzidx_s, s_len=self.T) else: s_pad = self._pad_s(s) if len(s) == len(self.nzidx_s) else s - out = np.convolve(self.coef, s_pad)[:self.T] + out = np.convolve(self.coef, s_pad)[: self.T] return sps.csc_matrix(out[self.nzidx_c].reshape(-1, 1)) def 
validate_coefficients(self, atol: float = 1e-3) -> bool: @@ -229,32 +238,36 @@ def validate_coefficients(self, atol: float = 1e-3) -> bool: if self.tau is None or self.ps is None or self.theta is None: logger.debug("Skipping coefficient validation - missing tau/ps/theta") return True - + try: # Generate exponential pulse tr_exp, _, _ = exp_pulse( - self.tau[0], self.tau[1], - p_d=self.ps[0], p_r=self.ps[1], + self.tau[0], + self.tau[1], + p_d=self.ps[0], + p_r=self.ps[1], nsamp=self.coef_len, ) - + # Generate AR pulse theta = self.theta - tr_ar, _, _ = ar_pulse(theta[0], theta[1], nsamp=self.coef_len, shifted=True) - + tr_ar, _, _ = ar_pulse( + theta[0], theta[1], nsamp=self.coef_len, shifted=True + ) + # Validate if not (~np.isnan(self.coef)).all(): logger.warning("Coefficient array contains NaN values") return False - - if not np.isclose(tr_exp, self.coef[:len(tr_exp)], atol=atol).all(): + + if not np.isclose(tr_exp, self.coef[: len(tr_exp)], atol=atol).all(): logger.warning("Exp time constant inconsistent with coefficients") return False - - if not np.isclose(tr_ar, self.coef[:len(tr_ar)], atol=atol).all(): + + if not np.isclose(tr_ar, self.coef[: len(tr_ar)], atol=atol).all(): logger.warning("AR coefficients inconsistent with coefficients") return False - + logger.debug("Coefficient validation passed") return True except Exception as e: @@ -291,22 +304,28 @@ def _setup_problem(self): # NOTE: `free_kernel=True` is forbidden with CVXPY backend (see `DeconvConfig`). 
self.cp_R = cp.Constant(self.R, name="R") self.cp_c = cp.Variable((self.T, 1), nonneg=True, name="c") - self.cp_s = cp.Variable((self.T, 1), nonneg=True, name="s", boolean=self.cfg.mixin) + self.cp_s = cp.Variable( + (self.T, 1), nonneg=True, name="s", boolean=self.cfg.mixin + ) self.cp_y = cp.Parameter(shape=(self.y_len, 1), name="y") - self.cp_huber_k = cp.Parameter(value=float(self.huber_k), nonneg=True, name="huber_k") - + self.cp_huber_k = cp.Parameter( + value=float(self.huber_k), nonneg=True, name="huber_k" + ) + self.cp_scale = cp.Parameter(value=self.scale, name="scale", nonneg=True) self.cp_l1_penal = cp.Parameter(value=0.0, name="l1_penal", nonneg=True) - self.cp_l0_w = cp.Parameter(shape=self.T, value=np.zeros(self.T), nonneg=True, name="w_l0") - + self.cp_l0_w = cp.Parameter( + shape=self.T, value=np.zeros(self.T), nonneg=True, name="w_l0" + ) + if self.y is not None: self.cp_y.value = self.y.reshape((-1, 1)) - + if self.cfg.use_base: self.cp_b = cp.Variable(nonneg=True, name="b") else: self.cp_b = cp.Constant(value=0, name="b") - + # Error term based on norm term = self.cp_y - self.cp_scale * self.cp_R @ self.cp_c - self.cp_b if self.cfg.norm == "l1": @@ -316,15 +335,15 @@ def _setup_problem(self): elif self.cfg.norm == "huber": # Keep huber parameter consistent with OSQP backend's `huber_k`. 
self.err_term = cp.sum(cp.huber(term, M=self.cp_huber_k)) - + # Objective obj_expr = ( - self.err_term - + self.cp_l0_w.T @ cp.abs(self.cp_s) + self.err_term + + self.cp_l0_w.T @ cp.abs(self.cp_s) + self.cp_l1_penal * cp.sum(cp.abs(self.cp_s)) ) obj = cp.Minimize(obj_expr) - + # Constraints # AR constraint via G matrix self.cp_theta = cp.Parameter( @@ -343,10 +362,10 @@ def _setup_problem(self): ] ) dcv_cons = [self.cp_s == G @ self.cp_c] - + edge_cons = [self.cp_c[0, 0] == 0, self.cp_s[-1, 0] == 0] amp_cons = [self.cp_s <= 1] - + self.prob_free = cp.Problem(obj, dcv_cons + edge_cons) self.prob = cp.Problem(obj, dcv_cons + edge_cons + amp_cons) @@ -360,7 +379,7 @@ def update( l0_penal: float = None, w: np.ndarray = None, theta: np.ndarray = None, - **kwargs + **kwargs, ): """Update CVXPY parameters.""" if y is not None: @@ -388,7 +407,7 @@ def update( self._update_w(w) if l0_penal is not None or w is not None: self.cp_l0_w.value = self.l0_penal * self.w - if theta is not None and hasattr(self, 'cp_theta'): + if theta is not None and hasattr(self, "cp_theta"): self.theta = theta self.cp_theta.value = theta @@ -400,12 +419,20 @@ def solve(self, amp_constraint: bool = True) -> Tuple[np.ndarray, float, Any]: except cp.error.SolverError as e: logger.warning(f"CVXPY SolverError: {e}") res = np.inf - - opt_s = self.cp_s.value.squeeze() if self.cp_s.value is not None else np.zeros(self.T) + + opt_s = ( + self.cp_s.value.squeeze() + if self.cp_s.value is not None + else np.zeros(self.T) + ) opt_b = 0 - if self.cfg.use_base and hasattr(self.cp_b, 'value') and self.cp_b.value is not None: + if ( + self.cfg.use_base + and hasattr(self.cp_b, "value") + and self.cp_b.value is not None + ): opt_b = float(self.cp_b.value) - + return opt_s, opt_b, res @@ -414,7 +441,7 @@ class OSQPSolver(DeconvSolver): def __init__(self, config: DeconvConfig, y_len: int, **kwargs): super().__init__(config, y_len, **kwargs) - + # Additional state for OSQP self.prob = None self.prob_free = None @@ 
-426,12 +453,12 @@ def __init__(self, config: DeconvConfig, y_len: int, **kwargs): self.ub = None self.ub_inf = None self.nzidx_A = None - + # STFT for FFT weighting if self.cfg.err_weighting == "fft": self.stft = ShortTimeFFT(win=np.ones(self.coef_len), hop=1, fs=1) self.yspec = get_stft_spec(self.y, self.stft) - + # Initialize matrices and problem self._update_HG() self._update_wgt_len() @@ -456,7 +483,7 @@ def _update_Wt(self, clear: bool = False) -> None: if clear: logger.debug("Clearing error weighting") self.err_wt = np.ones(self.y_len) - elif self.cfg.err_weighting == "fft" and hasattr(self, 'stft'): + elif self.cfg.err_weighting == "fft" and hasattr(self, "stft"): logger.debug("Updating error weighting with fft") hspec = get_stft_spec(coef, self.stft)[:, int(len(coef) / 2)] self.err_wt = ( @@ -471,7 +498,7 @@ def _update_Wt(self, clear: bool = False) -> None: yseg = self.y[i : i + len(coef)] if len(yseg) <= 1: continue - cseg = coef[:len(yseg)] + cseg = coef[: len(yseg)] with np.errstate(all="ignore"): self.err_wt[i] = np.corrcoef(yseg, cseg)[0, 1].clip(0, 1) self.err_wt = np.nan_to_num(self.err_wt) @@ -483,7 +510,7 @@ def _update_Wt(self, clear: bool = False) -> None: self.err_wt[nzidx : nzidx + self.wgt_len] = 1 else: self.err_wt = np.ones(self.y_len) - + self.Wt = sps.diags(self.err_wt) def _get_M(self) -> sps.csc_matrix: @@ -511,18 +538,22 @@ def _update_P(self) -> None: ls = len(self.nzidx_s) ly = self.y_len if self.cfg.free_kernel: - P = sps.bmat([ - [sps.csc_matrix((ls + 1, ls + 1)), None, None], - [None, sps.csc_matrix((ly, ly)), None], - [None, None, sps.eye(ly, format="csc")], - ]) + P = sps.bmat( + [ + [sps.csc_matrix((ls + 1, ls + 1)), None, None], + [None, sps.csc_matrix((ly, ly)), None], + [None, None, sps.eye(ly, format="csc")], + ] + ) else: - P = sps.bmat([ - [sps.csc_matrix((lc + 1, lc + 1)), None, None], - [None, sps.csc_matrix((ly, ly)), None], - [None, None, sps.eye(ly, format="csc")], - ]) - + P = sps.bmat( + [ + [sps.csc_matrix((lc + 
1, lc + 1)), None, None], + [None, sps.csc_matrix((ly, ly)), None], + [None, None, sps.eye(ly, format="csc")], + ] + ) + self.P = sps.triu(P).tocsc() logger.debug(f"Updated P matrix - shape: {self.P.shape}, nnz: {self.P.nnz}") @@ -535,8 +566,12 @@ def _update_q0(self) -> None: self.q0 = -M.T @ self.Wt.T @ self.Wt @ self.y elif self.cfg.norm == "huber": ly = self.y_len - lx = len(self.nzidx_s) + 1 if self.cfg.free_kernel else len(self.nzidx_c) + 1 - self.q0 = np.concatenate([np.zeros(lx), np.ones(ly), np.ones(ly)]) * self.huber_k + lx = ( + len(self.nzidx_s) + 1 if self.cfg.free_kernel else len(self.nzidx_c) + 1 + ) + self.q0 = ( + np.concatenate([np.zeros(lx), np.ones(ly), np.ones(ly)]) * self.huber_k + ) def _update_q(self) -> None: """Update linear cost vector q (including penalties).""" @@ -560,13 +595,18 @@ def _update_q(self) -> None: self.q = ( self.q0 + self.l0_penal * np.concatenate([[0], self.w, pad_k, pad_k]) - + self.l1_penal * np.concatenate([[0], np.ones(len(self.nzidx_s)), pad_k, pad_k]) + + self.l1_penal + * np.concatenate([[0], np.ones(len(self.nzidx_s)), pad_k, pad_k]) ) else: self.q = ( self.q0 - + self.l0_penal * np.concatenate([[0], self.w @ self.G, pad_k, pad_k]) - + self.l1_penal * np.concatenate([[0], np.ones(self.G.shape[0]) @ self.G, pad_k, pad_k]) + + self.l0_penal + * np.concatenate([[0], self.w @ self.G, pad_k, pad_k]) + + self.l1_penal + * np.concatenate( + [[0], np.ones(self.G.shape[0]) @ self.G, pad_k, pad_k] + ) ) def _update_A(self) -> None: @@ -580,46 +620,44 @@ def _update_A(self) -> None: self.nzidx_A = np.where((Ax != 0).sum(axis=1))[0] Ax = Ax[self.nzidx_A, :] Ar = self.scale * self.R - + if self.cfg.norm == "huber": e = sps.eye(self.y_len, format="csc") - self.A = sps.bmat([ - [sps.csc_matrix((Ax.shape[0], 1)), Ax, None, None], - [None, None, e, None], - [None, None, None, -e], - [np.ones((Ar.shape[0], 1)), Ar, e, e], - ], format="csc") + self.A = sps.bmat( + [ + [sps.csc_matrix((Ax.shape[0], 1)), Ax, None, None], + [None, 
None, e, None], + [None, None, None, -e], + [np.ones((Ar.shape[0], 1)), Ar, e, e], + ], + format="csc", + ) else: - self.A = sps.bmat([ - [np.ones((1, 1)), None], - [None, Ax] - ], format="csc") - + self.A = sps.bmat([[np.ones((1, 1)), None], [None, Ax]], format="csc") + logger.debug(f"Updated A matrix - shape: {self.A.shape}, nnz: {self.A.nnz}") def _update_bounds(self) -> None: """Update constraint bounds.""" if self.cfg.norm == "huber": xlen = len(self.nzidx_s) if self.cfg.free_kernel else len(self.nzidx_A) - self.lb = np.concatenate([ - np.zeros(xlen + self.y_len * 2), - self.y - self.huber_k - ]) - self.ub = np.concatenate([ - np.ones(xlen), - np.full(self.y_len * 2, np.inf), - self.y - self.huber_k - ]) - self.ub_inf = np.concatenate([ - np.full(xlen + self.y_len * 2, np.inf), - self.y - self.huber_k - ]) + self.lb = np.concatenate( + [np.zeros(xlen + self.y_len * 2), self.y - self.huber_k] + ) + self.ub = np.concatenate( + [np.ones(xlen), np.full(self.y_len * 2, np.inf), self.y - self.huber_k] + ) + self.ub_inf = np.concatenate( + [np.full(xlen + self.y_len * 2, np.inf), self.y - self.huber_k] + ) else: bb = np.clip(self.y.mean(), 0, None) if self.cfg.use_base else 0 if self.cfg.free_kernel: self.lb = np.zeros(len(self.nzidx_s) + 1) self.ub = np.concatenate([np.full(1, bb), np.ones(len(self.nzidx_s))]) - self.ub_inf = np.concatenate([np.full(1, bb), np.full(len(self.nzidx_s), np.inf)]) + self.ub_inf = np.concatenate( + [np.full(1, bb), np.full(len(self.nzidx_s), np.inf)] + ) else: ub_pad = np.zeros(self.T) ub_inf_pad = np.zeros(self.T) @@ -628,20 +666,22 @@ def _update_bounds(self) -> None: self.lb = np.zeros(len(self.nzidx_A) + 1) self.ub = np.concatenate([np.full(1, bb), ub_pad[self.nzidx_A]]) self.ub_inf = np.concatenate([np.full(1, bb), ub_inf_pad[self.nzidx_A]]) - + assert (self.ub >= self.lb).all(), "Upper bounds must be >= lower bounds" - assert (self.ub_inf >= self.lb).all(), "Upper bounds (inf) must be >= lower bounds" + assert ( + self.ub_inf >= 
self.lb + ).all(), "Upper bounds (inf) must be >= lower bounds" def _setup_prob_osqp(self) -> None: """Setup OSQP problem instances.""" logger.debug("Setting up OSQP problem") - + self._update_P() self._update_q0() self._update_q() self._update_A() self._update_bounds() - + # Choose solver backend if self.cfg.backend == "cuosqp": if not HAS_CUOSQP: @@ -659,7 +699,7 @@ def _setup_prob_osqp(self) -> None: else: self.prob = osqp.OSQP() self.prob_free = osqp.OSQP() - + # Setup constrained problem self.prob.setup( P=self.P.copy(), @@ -675,7 +715,7 @@ def _setup_prob_osqp(self) -> None: eps_prim_inf=1e-7, eps_dual_inf=1e-7, ) - + # Setup unconstrained (free) problem self.prob_free.setup( P=self.P.copy(), @@ -691,7 +731,7 @@ def _setup_prob_osqp(self) -> None: eps_prim_inf=1e-7, eps_dual_inf=1e-7, ) - + logger.debug(f"{self.cfg.backend} setup completed successfully") def update( @@ -708,11 +748,11 @@ def update( update_weighting: bool = False, clear_weighting: bool = False, scale_coef: bool = False, - **kwargs + **kwargs, ): """Update OSQP problem parameters.""" logger.debug(f"Updating OSQP solver parameters") - + # Update input parameters if y is not None: self.y = y @@ -722,8 +762,7 @@ def update( theta_new = np.array(tau2AR(tau[0], tau[1])) p = solve_p(tau[0], tau[1]) coef_new, _, _ = exp_pulse( - tau[0], tau[1], p_d=p, p_r=-p, - nsamp=self.coef_len, kn_len=self.coef_len + tau[0], tau[1], p_d=p, p_r=-p, nsamp=self.coef_len, kn_len=self.coef_len ) self.tau = tau self.ps = np.array([p, -p]) @@ -745,7 +784,7 @@ def update( self.l0_penal = l0_penal if w is not None: self._update_w(w) - + # Track what needs updating updt_HG = coef is not None updt_P = False @@ -754,11 +793,11 @@ def update( updt_A = False updt_bounds = False setup_prob = False - + if updt_HG: self._update_HG() self._update_wgt_len() - + if self.cfg.err_weighting is not None and update_weighting: self._update_Wt(clear=clear_weighting) if self.cfg.err_weighting == "adaptive": @@ -767,7 +806,7 @@ def update( 
updt_P = True updt_q0 = True updt_q = True - + if self.cfg.norm == "huber": # huber_k changes require recomputing q and bounds if y is not None: @@ -776,7 +815,9 @@ def update( if any([scale is not None, scale_mul is not None, updt_HG]): self._update_A() updt_A = True - if any([w is not None, l0_penal is not None, l1_penal is not None, updt_HG]): + if any( + [w is not None, l0_penal is not None, l1_penal is not None, updt_HG] + ): self._update_q() updt_q = True if y is not None: @@ -789,13 +830,29 @@ def update( if any([scale is not None, scale_mul is not None, updt_HG, updt_P]): self._update_P() updt_P = True - if any([scale is not None, scale_mul is not None, y is not None, updt_HG, updt_q0]): + if any( + [ + scale is not None, + scale_mul is not None, + y is not None, + updt_HG, + updt_q0, + ] + ): self._update_q0() updt_q0 = True - if any([w is not None, l0_penal is not None, l1_penal is not None, updt_q0, updt_q]): + if any( + [ + w is not None, + l0_penal is not None, + l1_penal is not None, + updt_q0, + updt_q, + ] + ): self._update_q() updt_q = True - + # Apply updates to OSQP - conservative approach: # Only q can be updated in-place safely. For P, A, bounds, rebuild. 
if setup_prob or any([updt_P, updt_A, updt_bounds]): @@ -803,25 +860,31 @@ def update( elif updt_q: self.prob.update(q=self.q) self.prob_free.update(q=self.q) - + logger.debug("OSQP problem updated") def solve(self, amp_constraint: bool = True) -> Tuple[np.ndarray, float, Any]: """Solve OSQP problem.""" prob = self.prob if amp_constraint else self.prob_free res = prob.solve() - + if res.info.status not in ["solved", "solved inaccurate"]: logger.warning(f"OSQP not solved: {res.info.status}") if res.info.status in ["primal infeasible", "primal infeasible inaccurate"]: x = np.zeros(self.P.shape[0], dtype=float) else: - x = res.x.astype(float) if res.x is not None else np.zeros(self.P.shape[0], dtype=float) + x = ( + res.x.astype(float) + if res.x is not None + else np.zeros(self.P.shape[0], dtype=float) + ) else: x = res.x - + if self.cfg.norm == "huber": - xlen = len(self.nzidx_s) + 1 if self.cfg.free_kernel else len(self.nzidx_c) + 1 + xlen = ( + len(self.nzidx_s) + 1 if self.cfg.free_kernel else len(self.nzidx_c) + 1 + ) sol = x[:xlen] opt_b = sol[0] if self.cfg.free_kernel: @@ -835,6 +898,6 @@ def solve(self, amp_constraint: bool = True) -> Tuple[np.ndarray, float, Any]: else: c_sol = x[1:] opt_s = self.G @ c_sol - + # Return 0 for objective - caller should use _compute_err for correct objective return opt_s, opt_b, 0 diff --git a/tests/unit/test_deconv_G_matrix.py b/tests/unit/test_deconv_G_matrix.py index b64d5af..08d3b7e 100644 --- a/tests/unit/test_deconv_G_matrix.py +++ b/tests/unit/test_deconv_G_matrix.py @@ -55,5 +55,3 @@ def test_G_matrix_matches_shifted_ar_difference_equation(): expected[4] = 0.0 assert np.allclose(s, expected) - - From f51f9d876c97b1d8f1d4abfc214327b7f640b60d Mon Sep 17 00:00:00 2001 From: Daniel Aharoni Date: Sat, 13 Dec 2025 16:47:20 -0800 Subject: [PATCH 5/6] refactor pipeline --- src/indeca/pipeline/__init__.py | 70 ++- src/indeca/pipeline/ar_update.py | 306 +++++++++++++ src/indeca/pipeline/binary_pursuit.py | 364 +++++++++++++++ 
src/indeca/pipeline/config.py | 281 ++++++++++++ src/indeca/pipeline/convergence.py | 112 +++++ src/indeca/pipeline/init.py | 220 +++++++++ src/indeca/pipeline/iteration.py | 79 ++++ src/indeca/pipeline/metrics.py | 132 ++++++ src/indeca/pipeline/pipeline.py | 618 ++++++++------------------ src/indeca/pipeline/preprocess.py | 57 +++ src/indeca/pipeline/types.py | 154 +++++++ 11 files changed, 1969 insertions(+), 424 deletions(-) create mode 100644 src/indeca/pipeline/ar_update.py create mode 100644 src/indeca/pipeline/binary_pursuit.py create mode 100644 src/indeca/pipeline/config.py create mode 100644 src/indeca/pipeline/convergence.py create mode 100644 src/indeca/pipeline/init.py create mode 100644 src/indeca/pipeline/iteration.py create mode 100644 src/indeca/pipeline/metrics.py create mode 100644 src/indeca/pipeline/preprocess.py create mode 100644 src/indeca/pipeline/types.py diff --git a/src/indeca/pipeline/__init__.py b/src/indeca/pipeline/__init__.py index 487e269..9af47cf 100644 --- a/src/indeca/pipeline/__init__.py +++ b/src/indeca/pipeline/__init__.py @@ -1,3 +1,69 @@ -from .pipeline import pipeline_bin +"""Binary pursuit deconvolution pipeline. -__all__ = ["pipeline_bin"] +This package provides the binary pursuit pipeline for spike inference +from calcium imaging traces. 
+ +Usage (recommended, config-based API):: + + from indeca.pipeline import pipeline_bin, DeconvPipelineConfig + + config = DeconvPipelineConfig( + up_factor=2, + convergence=ConvergenceConfig(max_iters=20), + ) + opt_C, opt_S, metrics = pipeline_bin(Y, config=config) + +Usage (legacy, deprecated):: + + from indeca.pipeline import pipeline_bin + + opt_C, opt_S, metrics = pipeline_bin(Y, up_factor=2, max_iters=20) + +""" + +# New config-based API (recommended) +from .binary_pursuit import pipeline_bin as pipeline_bin_new +from .config import ( + ARUpdateConfig, + ConvergenceConfig, + DeconvPipelineConfig, + DeconvStageConfig, + InitConfig, + PreprocessConfig, +) + +# Legacy API (deprecated, for backward compatibility) +from .pipeline import pipeline_bin, pipeline_bin_legacy + +# Type definitions +from .types import ( + ARParams, + ARUpdateResult, + ConvergenceResult, + DeconvStepResult, + IterationState, + PipelineResult, +) + +__all__ = [ + # Main entry point (legacy for backward compat, emits deprecation warning) + "pipeline_bin", + # New config-based entry point + "pipeline_bin_new", + # Legacy explicit name + "pipeline_bin_legacy", + # Configuration classes + "DeconvPipelineConfig", + "PreprocessConfig", + "InitConfig", + "DeconvStageConfig", + "ARUpdateConfig", + "ConvergenceConfig", + # Type definitions + "ARParams", + "DeconvStepResult", + "ARUpdateResult", + "ConvergenceResult", + "IterationState", + "PipelineResult", +] diff --git a/src/indeca/pipeline/ar_update.py b/src/indeca/pipeline/ar_update.py new file mode 100644 index 0000000..ef35136 --- /dev/null +++ b/src/indeca/pipeline/ar_update.py @@ -0,0 +1,306 @@ +"""AR parameter update functions. + +Handles spike selection, AR estimation, and parameter propagation. 
+""" + +from typing import Any, List, Optional, Tuple + +import numpy as np +import pandas as pd +from scipy.signal import find_peaks + +from indeca.core.AR_kernel import updateAR +from indeca.core.deconv import construct_R + +from .types import ARUpdateResult + + +def select_best_spikes( + S_ls: List[np.ndarray], + scal_ls: List[np.ndarray], + err_rel: np.ndarray, + metric_df: pd.DataFrame, + *, + n_best: Optional[int], + i_iter: int, + tau_init: Optional[Tuple[float, float]], +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, pd.DataFrame]: + """Select best spikes based on n_best iterations. + + Parameters + ---------- + S_ls : list of np.ndarray + Spike trains from all iterations + scal_ls : list of np.ndarray + Scale factors from all iterations + err_rel : np.ndarray + Relative errors from current iteration + metric_df : pd.DataFrame + Accumulated metrics + n_best : int or None + Number of best iterations to use + i_iter : int + Current iteration index + tau_init : tuple or None + Initial tau values (affects metric selection) + + Returns + ------- + S_best : np.ndarray + Best spike trains, shape (ncell, T * up_factor) + scal_best : np.ndarray + Best scale factors, shape (ncell,) + err_wt : np.ndarray + Error weights (negative err_rel), shape (ncell,) + metric_df : pd.DataFrame + Updated metric DataFrame with best_idx column + """ + S = S_ls[-1] # Current iteration spikes + scale = scal_ls[-1] # Current iteration scales + + metric_df = metric_df.set_index(["iter", "cell"]) + + if n_best is not None and i_iter >= n_best: + ncell = S.shape[0] + S_best = np.empty_like(S) + scal_best = np.empty_like(scale) + err_wt = np.empty_like(err_rel) + + if tau_init is not None: + metric_best = metric_df + else: + metric_best = metric_df.loc[1:, :] + + for icell, cell_met in metric_best.groupby("cell", sort=True): + cell_met = cell_met.reset_index().sort_values("obj", ascending=True) + cur_idx = np.array(cell_met["iter"][:n_best]) + metric_df.loc[(i_iter, icell), "best_idx"] = 
",".join( + cur_idx.astype(str) + ) + S_best[icell, :] = np.sum( + np.stack([S_ls[i][icell, :] for i in cur_idx], axis=0), axis=0 + ) > (n_best / 2) + scal_best[icell] = np.mean([scal_ls[i][icell] for i in cur_idx]) + err_wt[icell] = -np.mean( + [metric_df.loc[(i, icell), "err_rel"] for i in cur_idx] + ) + else: + S_best = S + scal_best = scale + err_wt = -err_rel + + metric_df = metric_df.reset_index() + return S_best, scal_best, err_wt, metric_df + + +def make_S_ar( + S_best: np.ndarray, + *, + est_nevt: Optional[int], + T: int, + up_factor: int, + ar_kn_len: int, +) -> np.ndarray: + """Create spike train for AR estimation with optional peak masking. + + Parameters + ---------- + S_best : np.ndarray + Best spike trains, shape (ncell, T * up_factor) + est_nevt : int or None + Number of top events to use. None uses all spikes. + T : int + Original trace length + up_factor : int + Upsampling factor + ar_kn_len : int + AR kernel length + + Returns + ------- + np.ndarray + Spike train for AR estimation, shape (ncell, T * up_factor) + """ + if est_nevt is not None: + S_ar = [] + R = construct_R(T, up_factor) + + for s in S_best: + Rs = R @ s + s_pks, pk_prop = find_peaks( + Rs, height=1, distance=ar_kn_len * up_factor + ) + pk_ht = pk_prop["peak_heights"] + top_idx = s_pks[np.argsort(pk_ht)[-est_nevt:]] + mask = np.zeros_like(Rs, dtype=bool) + mask[top_idx] = True + Rs_ma = Rs * mask + s_ma = np.zeros_like(s) + s_ma[::up_factor] = Rs_ma + S_ar.append(s_ma) + + S_ar = np.stack(S_ar, axis=0) + else: + S_ar = S_best + + return S_ar + + +def update_ar_parameters( + Y: np.ndarray, + S_ar: np.ndarray, + scal_best: np.ndarray, + err_wt: np.ndarray, + *, + ar_use_all: bool, + ar_kn_len: int, + ar_norm: str, + ar_prop_best: Optional[float], + up_factor: int, + p: int, + ncell: int, + dashboard: Any, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Update AR parameters based on current spike estimates. 
+ + Parameters + ---------- + Y : np.ndarray + Input traces, shape (ncell, T) + S_ar : np.ndarray + Spike trains for AR estimation, shape (ncell, T * up_factor) + scal_best : np.ndarray + Best scale factors, shape (ncell,) + err_wt : np.ndarray + Error weights, shape (ncell,) + ar_use_all : bool + Whether to use all cells for shared AR update + ar_kn_len : int + AR kernel length + ar_norm : str + Norm for AR fitting + ar_prop_best : float or None + Proportion of best cells to use + up_factor : int + Upsampling factor + p : int + AR model order + ncell : int + Number of cells + dashboard : Dashboard or None + Dashboard instance + + Returns + ------- + tau : np.ndarray + Updated time constants, shape (ncell, 2) + ps : np.ndarray + Updated peak coefficients + h : np.ndarray + Impulse response + h_fit : np.ndarray + Fitted impulse response + """ + if ar_use_all: + # Shared AR update across cells + if ar_prop_best is not None: + ar_nbest = max(int(np.round(ar_prop_best * ncell)), 1) + ar_best_idx = np.argsort(err_wt)[-ar_nbest:] + else: + ar_best_idx = slice(None) + + cur_tau, ps, ar_scal, h, h_fit = updateAR( + Y[ar_best_idx], + S_ar[ar_best_idx], + scal_best[ar_best_idx], + N=p, + h_len=ar_kn_len * up_factor, + norm=ar_norm, + up_factor=up_factor, + ) + + if dashboard is not None: + dashboard.update( + h=h[: ar_kn_len * up_factor], + h_fit=h_fit[: ar_kn_len * up_factor], + ) + + tau = np.tile(cur_tau, (ncell, 1)) + else: + # Per-cell AR update + tau = np.empty((ncell, p)) + + # NOTE: Original pipeline only retained the last cell's ps/h/h_fit + # when ar_use_all=False. We preserve this behavior explicitly. 
+ ps = None + h = None + h_fit = None + + for icell, (y, s) in enumerate(zip(Y, S_ar)): + cur_tau, cur_ps, ar_scal, cur_h, cur_h_fit = updateAR( + y, + s, + scal_best[icell], + N=p, + h_len=ar_kn_len, + norm=ar_norm, + up_factor=up_factor, + ) + + if dashboard is not None: + dashboard.update(uid=icell, h=cur_h, h_fit=cur_h_fit) + + tau[icell, :] = cur_tau + + # Overwrite on each iteration; only last cell's values are kept + ps = cur_ps + h = cur_h + h_fit = cur_h_fit + + return tau, ps, h, h_fit + + +def propagate_ar_update( + deconvolvers: List[Any], + tau: np.ndarray, + scal_best: np.ndarray, + *, + ar_use_all: bool, + da_client: Any, +) -> None: + """Propagate AR parameter updates to deconvolvers. + + Parameters + ---------- + deconvolvers : list + List of DeconvBin instances + tau : np.ndarray + Updated time constants, shape (ncell, 2) + scal_best : np.ndarray + Best scale factors, shape (ncell,) + ar_use_all : bool + Whether using shared AR (affects which tau to use) + da_client : Client or None + Dask client for distributed execution + """ + if ar_use_all: + # All cells share the same tau (use tau[0]) + cur_tau = tau[0] + for idx, d in enumerate(deconvolvers): + if da_client is not None: + da_client.submit( + lambda dd: dd.update(tau=cur_tau, scale=scal_best[idx]), d + ) + else: + d.update(tau=cur_tau, scale=scal_best[idx]) + else: + # Per-cell tau + for idx, d in enumerate(deconvolvers): + if da_client is not None: + da_client.submit( + lambda dd: dd.update(tau=tau[idx], scale=scal_best[idx]), + deconvolvers[idx], + ) + else: + d.update(tau=tau[idx], scale=scal_best[idx]) + diff --git a/src/indeca/pipeline/binary_pursuit.py b/src/indeca/pipeline/binary_pursuit.py new file mode 100644 index 0000000..d7f052c --- /dev/null +++ b/src/indeca/pipeline/binary_pursuit.py @@ -0,0 +1,364 @@ +"""Binary pursuit deconvolution pipeline. 
@profile
def pipeline_bin(
    Y: np.ndarray,
    *,
    config: DeconvPipelineConfig,
    da_client: Any = None,
    spawn_dashboard: bool = True,
    return_iter: bool = False,
) -> Union[
    Tuple[np.ndarray, np.ndarray, pd.DataFrame],
    Tuple[np.ndarray, np.ndarray, pd.DataFrame, list, list, list, list],
]:
    """Binary pursuit pipeline for spike inference.

    This is the main entry point for the deconvolution pipeline.
    It orchestrates preprocessing, initialization, iterative deconvolution,
    AR updates, and convergence checking.

    Parameters
    ----------
    Y : np.ndarray
        Input fluorescence traces, shape (ncell, T)
    config : DeconvPipelineConfig
        Pipeline configuration
    da_client : Client or None
        Dask client for distributed execution. None for local execution.
    spawn_dashboard : bool
        Whether to spawn a real-time dashboard
    return_iter : bool
        Whether to return per-iteration results

    Returns
    -------
    opt_C : np.ndarray
        Optimal calcium traces, shape (ncell, T * up_factor)
    opt_S : np.ndarray
        Optimal spike trains, shape (ncell, T * up_factor)
    metric_df : pd.DataFrame
        Per-iteration metrics
    C_ls : list (only if return_iter=True)
        Calcium traces per iteration
    S_ls : list (only if return_iter=True)
        Spike trains per iteration
    h_ls : list (only if return_iter=True)
        Impulse responses per iteration
    h_fit_ls : list (only if return_iter=True)
        Fitted impulse responses per iteration
    """
    logger.info("Starting binary pursuit pipeline")

    # Unpack config
    up_factor = config.up_factor
    p = config.p
    preprocess_cfg = config.preprocess
    init_cfg = config.init
    deconv_cfg = config.deconv
    ar_cfg = config.ar_update
    conv_cfg = config.convergence

    # 0. Housekeeping
    ncell, T = Y.shape
    logger.debug(
        f"Pipeline parameters: "
        f"up_factor={up_factor}, p={p}, max_iters={conv_cfg.max_iters}, "
        f"n_best={conv_cfg.n_best}, backend={deconv_cfg.backend}, "
        f"ar_use_all={ar_cfg.use_all}, ar_kn_len={ar_cfg.kn_len}, "
        f"{ncell} cells with {T} timepoints"
    )

    # 1. Preprocessing
    Y = preprocess_traces(
        Y,
        med_wnd=preprocess_cfg.med_wnd,
        dff=preprocess_cfg.dff,
        ar_kn_len=ar_cfg.kn_len,
    )

    # 2. Dashboard setup
    if spawn_dashboard:
        if da_client is not None:
            logger.debug("Using Dask client for distributed computation")
            dashboard = da_client.submit(
                Dashboard, Y=Y, kn_len=ar_cfg.kn_len, actor=True
            ).result()
        else:
            logger.debug("Running in single-machine mode")
            dashboard = Dashboard(Y=Y, kn_len=ar_cfg.kn_len)
    else:
        dashboard = None

    # 3. Initialize AR parameters
    ar_params = initialize_ar_params(
        Y,
        tau_init=init_cfg.tau_init,
        p=p,
        up_factor=up_factor,
        ar_kn_len=ar_cfg.kn_len,
        est_noise_freq=init_cfg.est_noise_freq,
        est_use_smooth=init_cfg.est_use_smooth,
        est_add_lag=init_cfg.est_add_lag,
    )
    theta = ar_params.theta
    tau = ar_params.tau

    # 4. Initialize deconvolvers
    dcv = initialize_deconvolvers(
        Y,
        ar_params,
        ar_kn_len=ar_cfg.kn_len,
        up_factor=up_factor,
        nthres=deconv_cfg.nthres,
        norm=deconv_cfg.norm,
        penal=deconv_cfg.penal,
        use_base=deconv_cfg.use_base,
        err_weighting=deconv_cfg.err_weighting,
        masking_radius=deconv_cfg.masking_radius,
        pks_polish=deconv_cfg.pks_polish,
        ncons_thres=deconv_cfg.ncons_thres,
        min_rel_scl=deconv_cfg.min_rel_scl,
        atol=deconv_cfg.atol,
        backend=deconv_cfg.backend,
        dashboard=dashboard,
        da_client=da_client,
    )

    # 5. Initialize iteration state
    state = IterationState.empty(T, up_factor)
    scale = np.empty(ncell)
    # h / h_fit are produced by the AR update at the end of each iteration;
    # initialize explicitly so they are never unbound when read below.
    h = None
    h_fit = None

    # 6. Main iteration loop
    for i_iter in trange(conv_cfg.max_iters, desc="iteration"):
        logger.info(f"Starting iteration {i_iter}/{conv_cfg.max_iters}")

        # 6.1 Deconvolution step
        deconv_result = run_deconv_step(
            Y,
            dcv,
            i_iter=i_iter,
            reset_scale=deconv_cfg.reset_scale,
            da_client=da_client,
        )
        scale = deconv_result.scale

        logger.debug(
            f"Iteration {i_iter} stats - "
            f"Mean error: {deconv_result.err.mean():.4f}, "
            f"Mean scale: {scale.mean():.4f}"
        )

        # 6.2 Update metrics
        cur_metric = make_cur_metric(
            i_iter=i_iter,
            ncell=ncell,
            theta=theta,
            tau=tau,
            scale=scale,
            deconv_result=deconv_result,
            deconvolvers=dcv,
            use_rel_err=conv_cfg.use_rel_err,
        )
        update_dashboard(dashboard, cur_metric, i_iter, conv_cfg.max_iters)
        state.metric_df = append_metrics(state.metric_df, cur_metric)

        # 6.3 Save iteration results
        state.C_ls.append(deconv_result.C)
        state.S_ls.append(deconv_result.S)
        state.scal_ls.append(scale)

        # Handle h_ls / h_fit_ls (no AR update has run yet on the first
        # iteration, so pad with NaN).
        if h is None:
            state.h_ls.append(np.full(T * up_factor, np.nan))
            state.h_fit_ls.append(np.full(T * up_factor, np.nan))
        else:
            state.h_ls.append(h)
            state.h_fit_ls.append(h_fit)

        # 6.4 Select best spikes for AR update
        S_best, scal_best, err_wt, state.metric_df = select_best_spikes(
            state.S_ls,
            state.scal_ls,
            deconv_result.err_rel,
            state.metric_df,
            n_best=conv_cfg.n_best,
            i_iter=i_iter,
            tau_init=init_cfg.tau_init,
        )

        # 6.5 Create spike train for AR estimation
        S_ar = make_S_ar(
            S_best,
            est_nevt=init_cfg.est_nevt,
            T=T,
            up_factor=up_factor,
            ar_kn_len=ar_cfg.kn_len,
        )

        # 6.6 Update AR parameters
        tau, ps, h, h_fit = update_ar_parameters(
            Y,
            S_ar,
            scal_best,
            err_wt,
            ar_use_all=ar_cfg.use_all,
            ar_kn_len=ar_cfg.kn_len,
            ar_norm=ar_cfg.norm,
            ar_prop_best=ar_cfg.prop_best,
            up_factor=up_factor,
            p=p,
            ncell=ncell,
            dashboard=dashboard,
        )

        # Update theta to match the new tau values
        # (required for correct metric reporting in make_cur_metric)
        theta = np.array([tau2AR(t[0], t[1]) for t in tau])

        if ar_cfg.use_all:
            logger.debug(
                f"Updating AR parameters for all cells: tau={tau[0]}"
            )
        else:
            # Plain string: no placeholders, so no f-prefix (F541).
            logger.debug("Updated AR parameters per-cell")

        # 6.7 Propagate AR update to deconvolvers
        propagate_ar_update(
            dcv,
            tau,
            scal_best,
            ar_use_all=ar_cfg.use_all,
            da_client=da_client,
        )

        # 6.8 Check convergence
        conv_result = check_convergence(
            state.metric_df,
            cur_metric,
            deconv_result.S,
            state.S_ls,
            i_iter=i_iter,
            err_atol=conv_cfg.err_atol,
            err_rtol=conv_cfg.err_rtol,
        )

        if conv_result.converged:
            if "trapped" in conv_result.reason.lower():
                logger.warning(conv_result.reason)
            else:
                logger.info(conv_result.reason)
            break
    else:
        logger.warning("Max iteration reached without convergence")

    # 7. Compute final results
    opt_C, opt_S = _finalize_results(
        state, ncell, T, up_factor, ar_cfg.use_all
    )

    # 8. Cleanup
    if dashboard is not None:
        dashboard.stop()

    logger.info("Pipeline completed successfully")

    if return_iter:
        return (
            opt_C,
            opt_S,
            state.metric_df,
            state.C_ls,
            state.S_ls,
            state.h_ls,
            state.h_fit_ls,
        )
    else:
        return opt_C, opt_S, state.metric_df


def _finalize_results(
    state: IterationState,
    ncell: int,
    T: int,
    up_factor: int,
    ar_use_all: bool,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute final optimal results from iteration history.

    Parameters
    ----------
    state : IterationState
        Accumulated iteration state
    ncell : int
        Number of cells
    T : int
        Original trace length
    up_factor : int
        Upsampling factor
    ar_use_all : bool
        Whether using shared AR (currently unused; kept for interface
        stability)

    Returns
    -------
    opt_C : np.ndarray
        Optimal calcium traces
    opt_S : np.ndarray
        Optimal spike trains
    """
    metric_df = state.metric_df
    C_ls = state.C_ls
    S_ls = state.S_ls

    opt_C = np.empty((ncell, T * up_factor))
    opt_S = np.empty((ncell, T * up_factor))

    # mobj = metric_df.groupby("iter")["obj"].median()
    # opt_idx_all = mobj.idxmin()
    # NOTE: Original pipeline always selected the last iteration (-1),
    # regardless of metric-based selection. We preserve that behavior here.
    # (The metric-based selection logic was present but unused in the original.)
    opt_idx = -1

    for icell in range(ncell):
        opt_C[icell, :] = C_ls[opt_idx][icell, :]
        opt_S[icell, :] = S_ls[opt_idx][icell, :]

    # Append optimal to lists (matching original behavior)
    C_ls.append(opt_C)
    S_ls.append(opt_S)

    return opt_C, opt_S
class PreprocessConfig(BaseModel):
    """Configuration for preprocessing traces.

    Consumed by ``preprocess_traces``: optional median-filter detrending
    followed by optional dF/F normalization.
    """

    # Pydantic v2 model config: frozen makes instances immutable after
    # construction.
    model_config = {"frozen": True}

    # Window size for median filtering. 'auto' resolves to the AR kernel
    # length downstream; None skips the median-filter step entirely.
    med_wnd: Optional[Union[int, Literal["auto"]]] = Field(
        None,
        description="Window size for median filtering. Use 'auto' to set to ar_kn_len, or None to skip.",
    )
    # Whether to compute dF/F normalization before deconvolution.
    dff: bool = Field(
        True,
        description="Whether to compute dF/F normalization.",
    )
class DeconvStageConfig(BaseModel):
    """Configuration for the deconvolution stage.

    These fields map one-to-one onto the ``deconv_*`` keyword arguments of
    the legacy flat-kwargs API (see ``DeconvPipelineConfig.from_legacy_kwargs``).
    """

    # Pydantic v2 model config: frozen makes instances immutable after
    # construction.
    model_config = {"frozen": True}

    nthres: int = Field(
        1000,
        description="Number of thresholds for thresholding step.",
    )
    norm: Literal["l1", "l2", "huber"] = Field(
        "l2",
        description="Norm for data fidelity.",
    )
    penal: Optional[Literal["l0", "l1"]] = Field(
        None,
        description="Penalty type for sparsity.",
    )
    backend: Literal["osqp", "cvxpy", "cuosqp"] = Field(
        "osqp",
        description="Solver backend.",
    )
    err_weighting: Optional[Literal["fft", "corr", "adaptive"]] = Field(
        None,
        description="Error weighting method.",
    )
    use_base: bool = Field(
        True,
        description="Whether to include a baseline term.",
    )
    reset_scale: bool = Field(
        True,
        description="Whether to reset scale at each iteration.",
    )
    masking_radius: Optional[int] = Field(
        None,
        description="Radius for masking around spikes.",
    )
    # NOTE(review): defaults to True here, but the pre-refactor legacy
    # signature used deconv_pks_polish=None — confirm True is the intended
    # new default.
    pks_polish: bool = Field(
        True,
        description="Whether to polish peaks after solving.",
    )
    ncons_thres: Optional[Union[int, Literal["auto"]]] = Field(
        None,
        description="Max consecutive spikes threshold. 'auto' = upsamp + 1.",
    )
    min_rel_scl: Optional[Union[float, Literal["auto"]]] = Field(
        None,
        description="Minimum relative scale. 'auto' = 0.5 / upsamp.",
    )
    atol: float = Field(
        1e-3,
        description="Absolute tolerance for solver.",
    )
class ConvergenceConfig(BaseModel):
    """Configuration for convergence criteria.

    Consumed by the main iteration loop and by ``check_convergence``.
    """

    # Pydantic v2 model config: frozen makes instances immutable after
    # construction.
    model_config = {"frozen": True}

    # Hard cap on outer-loop iterations.
    max_iters: int = Field(
        50,
        description="Maximum number of iterations.",
    )
    # Convergence when per-cell objective change falls below this absolute
    # tolerance.
    err_atol: float = Field(
        1e-4,
        description="Absolute error tolerance for convergence.",
    )
    # Convergence when per-cell objective change falls below this fraction
    # of the best objective seen so far.
    err_rtol: float = Field(
        5e-2,
        description="Relative error tolerance for convergence.",
    )
    # Selects whether the 'obj' metric column is err_rel or err.
    use_rel_err: bool = Field(
        True,
        description="Whether to use relative error for objective.",
    )
    n_best: Optional[int] = Field(
        3,
        description="Number of best iterations to average for spike selection.",
    )
def check_convergence(
    metric_df: pd.DataFrame,
    cur_metric: pd.DataFrame,
    S: np.ndarray,
    S_ls: List[np.ndarray],
    *,
    i_iter: int,
    err_atol: float,
    err_rtol: float,
) -> ConvergenceResult:
    """Check if the pipeline has converged or is trapped.

    Checks multiple convergence criteria:
    1. Absolute error tolerance
    2. Relative error tolerance
    3. Spike pattern stabilization
    4. Trapped in local optimum (error)
    5. Trapped in local optimum (spike pattern)

    The checks are evaluated in that order and the first one that fires
    wins, so genuine convergence is reported in preference to "trapped".

    Parameters
    ----------
    metric_df : pd.DataFrame
        Accumulated metrics from previous iterations
    cur_metric : pd.DataFrame
        Metrics from current iteration
    S : np.ndarray
        Current spike trains, shape (ncell, T * up_factor)
    S_ls : list of np.ndarray
        Spike trains from all iterations
    i_iter : int
        Current iteration index
    err_atol : float
        Absolute error tolerance
    err_rtol : float
        Relative error tolerance

    Returns
    -------
    ConvergenceResult
        Result indicating if converged and why
    """
    # Need at least one previous iteration
    metric_prev = metric_df[metric_df["iter"] < i_iter].dropna(
        subset=["obj", "scale"]
    )
    metric_last = metric_df[metric_df["iter"] == i_iter - 1].dropna(
        subset=["obj", "scale"]
    )

    if len(metric_prev) == 0:
        return ConvergenceResult(converged=False, reason="")

    # Indexing by "cell" lets the Series arithmetic below align per-cell
    # values even if the row order differs between iterations.
    err_cur = cur_metric.set_index("cell")["obj"]
    err_last = metric_last.set_index("cell")["obj"]
    err_best = metric_prev.groupby("cell")["obj"].min()
    ncell = S.shape[0]

    # Check 1: Converged by absolute error
    # .all(): every cell must satisfy the tolerance, not just the mean.
    if (np.abs(err_cur - err_last) < err_atol).all():
        return ConvergenceResult(
            converged=True, reason="Converged: absolute error tolerance reached"
        )

    # Check 2: Converged by relative error
    if (np.abs(err_cur - err_last) < err_rtol * err_best).all():
        return ConvergenceResult(
            converged=True, reason="Converged: relative error tolerance reached"
        )

    # Check 3: Converged by spike pattern stabilization
    # S_best holds, for each cell, the spike train from that cell's
    # best (lowest-obj) past iteration.
    T_up = S.shape[1]
    S_best = np.empty((ncell, T_up))
    for uid, udf in metric_prev.groupby("cell"):
        best_iter = udf.set_index("iter")["obj"].idxmin()
        S_best[uid, :] = S_ls[best_iter][uid, :]

    # A summed absolute difference below 1 means the trains are identical
    # (assumes S takes integer values — TODO confirm for up_factor > 1).
    if np.abs(S - S_best).sum() < 1:
        return ConvergenceResult(
            converged=True, reason="Converged: spike pattern stabilized"
        )

    # Check 4: Trapped by error (current error very close to some past error)
    # err_all is a cell x iteration matrix of past objectives; trapped when
    # every cell's current error matches *some* past error within atol.
    err_all = metric_prev.pivot(columns="iter", index="cell", values="obj")
    diff_all = np.abs(err_cur.values.reshape((-1, 1)) - err_all.values)
    if (diff_all.min(axis=1) < err_atol).all():
        return ConvergenceResult(
            converged=True, reason="Solution trapped in local optimal err"
        )

    # Check 5: Trapped by spike pattern (current pattern matches > 1 past pattern)
    # The caller appends the current S to S_ls before calling, hence the
    # [:-1] slice; more than one exact match with past patterns indicates
    # the solver is cycling.
    if len(S_ls) > 1:
        diff_all = np.array([np.abs(S - prev_s).sum() for prev_s in S_ls[:-1]])
        if (diff_all < 1).sum() > 1:
            return ConvergenceResult(
                converged=True, reason="Solution trapped in local optimal s"
            )

    return ConvergenceResult(converged=False, reason="")
def initialize_deconvolvers(
    Y: np.ndarray,
    ar_params: ARParams,
    *,
    ar_kn_len: int,
    up_factor: int,
    nthres: int,
    norm: str,
    penal: Optional[str],
    use_base: bool,
    err_weighting: Optional[str],
    masking_radius: Optional[int],
    pks_polish: bool,
    ncons_thres: Optional[int],
    min_rel_scl: Optional[float],
    atol: float,
    backend: str,
    dashboard: Any,
    da_client: Any,
) -> List[Any]:
    """Create DeconvBin instances for all cells.

    Parameters
    ----------
    Y : np.ndarray
        Input traces, shape (ncell, T)
    ar_params : ARParams
        Initialized AR parameters
    ar_kn_len : int
        AR kernel length
    up_factor : int
        Upsampling factor
    nthres : int
        Number of thresholds
    norm : str
        Norm type ("l1", "l2", "huber")
    penal : str or None
        Penalty type
    use_base : bool
        Whether to use baseline
    err_weighting : str or None
        Error weighting method
    masking_radius : int or None
        Masking radius
    pks_polish : bool
        Whether to polish peaks
    ncons_thres : int or None
        Consecutive spikes threshold
    min_rel_scl : float or None
        Minimum relative scale
    atol : float
        Absolute tolerance
    backend : str
        Solver backend
    dashboard : Dashboard or None
        Dashboard instance
    da_client : Client or None
        Dask client for distributed execution

    Returns
    -------
    list
        List of DeconvBin instances (or futures if using Dask)
    """
    theta = ar_params.theta
    tau = ar_params.tau
    ps = ar_params.ps

    if da_client is not None:
        # Distributed execution.
        # Per-cell data AND the cell index are passed as explicit submit()
        # arguments; closing over the loop variable `i` is a late-binding
        # hazard (flake8 B023). `dashboard` is loop-invariant, so closing
        # over it is safe.
        dcv = [
            da_client.submit(
                lambda yy, th, tau_i, ps_i, uid: DeconvBin(
                    y=yy,
                    theta=th,
                    tau=tau_i,
                    ps=ps_i,
                    coef_len=ar_kn_len,
                    upsamp=up_factor,
                    nthres=nthres,
                    norm=norm,
                    penal=penal,
                    use_base=use_base,
                    err_weighting=err_weighting,
                    masking_radius=masking_radius,
                    pks_polish=pks_polish,
                    ncons_thres=ncons_thres,
                    min_rel_scl=min_rel_scl,
                    atol=atol,
                    backend=backend,
                    dashboard=dashboard,
                    dashboard_uid=uid,
                ),
                y,
                theta[i],
                tau[i],
                ps[i],
                i,
            )
            for i, y in enumerate(Y)
        ]
    else:
        # Local execution
        dcv = [
            DeconvBin(
                y=y,
                theta=theta[i],
                tau=tau[i],
                ps=ps[i],
                coef_len=ar_kn_len,
                upsamp=up_factor,
                nthres=nthres,
                norm=norm,
                penal=penal,
                use_base=use_base,
                err_weighting=err_weighting,
                masking_radius=masking_radius,
                pks_polish=pks_polish,
                ncons_thres=ncons_thres,
                min_rel_scl=min_rel_scl,
                atol=atol,
                backend=backend,
                dashboard=dashboard,
                dashboard_uid=i,
            )
            for i, y in enumerate(Y)
        ]

    return dcv


def run_deconv_step(
    Y: np.ndarray,
    deconvolvers: List[Any],
    *,
    i_iter: int,
    reset_scale: bool,
    da_client: Any,
) -> DeconvStepResult:
    """Run one deconvolution iteration across all cells.

    Parameters
    ----------
    Y : np.ndarray
        Input traces, shape (ncell, T)
    deconvolvers : list
        List of DeconvBin instances (or futures)
    i_iter : int
        Current iteration index
    reset_scale : bool
        Whether to reset scale this iteration
    da_client : Client or None
        Dask client for distributed execution

    Returns
    -------
    DeconvStepResult
        Results from this deconvolution step
    """
    # Scale is always reset on the first two iterations; the flag is
    # invariant across cells, so compute it once outside the loop.
    reset = i_iter <= 1 or reset_scale

    res = []
    # Only the cell index is needed here, not the trace itself.
    for icell in tqdm(range(Y.shape[0]), desc="deconv", leave=False):
        d = deconvolvers[icell]
        if da_client is not None:
            # Pass the flag as an explicit submit() argument rather than
            # closing over enclosing-scope variables (flake8 B023).
            r = da_client.submit(
                lambda dd, rs: dd.solve_scale(reset_scale=rs), d, reset
            )
        else:
            r = d.solve_scale(reset_scale=reset)
        res.append(r)

    if da_client is not None:
        res = da_client.gather(res)

    # Unpack per-cell result tuples: (S, C, scale, err, err_rel, nnz, penal)
    S = np.stack([r[0].squeeze() for r in res], axis=0, dtype=float)
    C = np.stack([r[1].squeeze() for r in res], axis=0)
    scale = np.array([r[2] for r in res])
    err = np.array([r[3] for r in res])
    err_rel = np.array([r[4] for r in res])
    nnz = np.array([r[5] for r in res])
    penal = np.array([r[6] for r in res])

    return DeconvStepResult(
        S=S,
        C=C,
        scale=scale,
        err=err,
        err_rel=err_rel,
        nnz=nnz,
        penal=penal,
    )
def make_cur_metric(
    i_iter: int,
    ncell: int,
    theta: np.ndarray,
    tau: np.ndarray,
    scale: np.ndarray,
    deconv_result: DeconvStepResult,
    deconvolvers: List[Any],
    use_rel_err: bool,
) -> pd.DataFrame:
    """Construct the metrics DataFrame for the current iteration.

    Parameters
    ----------
    i_iter : int
        Current iteration index
    ncell : int
        Number of cells
    theta : np.ndarray
        AR coefficients, shape (ncell, p)
    tau : np.ndarray
        Time constants, shape (ncell, 2)
    scale : np.ndarray
        Scale factors, shape (ncell,)
    deconv_result : DeconvStepResult
        Results from deconvolution step
    deconvolvers : list
        List of DeconvBin instances
    use_rel_err : bool
        Whether to use relative error for objective

    Returns
    -------
    pd.DataFrame
        Metrics for the current iteration
    """
    # Compute half-max durations
    # find_dhm(...)[0] — presumably the pair of half-max durations for the
    # kernel defined by (tau_d, tau_r) and amplitudes (s, -s); verify
    # against find_dhm's documentation.
    dhm = np.stack(
        [
            np.array(find_dhm(True, (t0, t1), (s, -s))[0], dtype=float)
            for t0, t1, s in zip(tau.T[0], tau.T[1], scale)
        ],
        axis=0,
    )

    # One row per cell; "obj" is the column used by convergence checks.
    # NOTE(review): `d.wgt_len` is a plain attribute access — if the
    # pipeline was started with a Dask client, `deconvolvers` holds
    # futures and this would fail; confirm this path only receives local
    # DeconvBin instances.
    cur_metric = pd.DataFrame(
        {
            "iter": i_iter,
            "cell": np.arange(ncell),
            "g0": theta.T[0],
            "g1": theta.T[1],
            "tau_d": tau.T[0],
            "tau_r": tau.T[1],
            "dhm0": dhm.T[0],
            "dhm1": dhm.T[1],
            "err": deconv_result.err,
            "err_rel": deconv_result.err_rel,
            "scale": scale,
            "penal": deconv_result.penal,
            "nnz": deconv_result.nnz,
            "obj": deconv_result.err_rel if use_rel_err else deconv_result.err,
            "wgt_len": [d.wgt_len for d in deconvolvers],
        }
    )

    return cur_metric


def append_metrics(
    metric_df: pd.DataFrame,
    cur_metric: pd.DataFrame,
) -> pd.DataFrame:
    """Append current iteration metrics to the accumulated DataFrame.

    Parameters
    ----------
    metric_df : pd.DataFrame
        Accumulated metrics from previous iterations
    cur_metric : pd.DataFrame
        Metrics from current iteration

    Returns
    -------
    pd.DataFrame
        Updated metrics DataFrame
    """
    # Returns a new DataFrame; the caller rebinds state.metric_df.
    return pd.concat([metric_df, cur_metric], ignore_index=True)


def update_dashboard(
    dashboard: Any,
    cur_metric: pd.DataFrame,
    i_iter: int,
    max_iters: int,
) -> None:
    """Update the dashboard with current iteration metrics.

    No-op when dashboard is None.

    Parameters
    ----------
    dashboard : Dashboard or None
        Dashboard instance
    cur_metric : pd.DataFrame
        Current iteration metrics
    i_iter : int
        Current iteration index
    max_iters : int
        Maximum number of iterations
    """
    if dashboard is not None:
        dashboard.update(
            tau_d=cur_metric["tau_d"].squeeze(),
            tau_r=cur_metric["tau_r"].squeeze(),
            err=cur_metric["obj"].squeeze(),
            scale=cur_metric["scale"].squeeze(),
        )
        # Clamp so the dashboard's iteration pointer never exceeds the
        # last valid index.
        dashboard.set_iter(min(i_iter + 1, max_iters - 1))
+""" + import warnings +from typing import Literal, Optional, Tuple, Union import numpy as np import pandas as pd from line_profiler import profile -from scipy.signal import find_peaks, medfilt -from tqdm.auto import tqdm, trange -from indeca.core.AR_kernel import AR_upsamp_real, estimate_coefs, updateAR -from indeca.dashboard.dashboard import Dashboard -from indeca.core.deconv import DeconvBin, construct_R from indeca.utils.logging_config import get_module_logger -from indeca.core.simulation import AR2tau, find_dhm, tau2AR -from indeca.utils.utils import compute_dff -# Initialize logger for this module +from .binary_pursuit import pipeline_bin as _pipeline_bin_new +from .config import DeconvPipelineConfig + logger = get_module_logger("pipeline") -logger.info("Pipeline module initialized") # Test message on import +logger.info("Pipeline module initialized") @profile def pipeline_bin( - Y, - up_factor=1, - p=2, - tau_init=None, - return_iter=False, - max_iters=50, - n_best=3, - use_rel_err=True, - err_atol=1e-4, - err_rtol=5e-2, - est_noise_freq=None, - est_use_smooth=False, - est_add_lag=20, - est_nevt=10, - med_wnd=None, - dff=True, - deconv_nthres=1000, - deconv_norm="l2", - deconv_atol=1e-3, - deconv_penal=None, - deconv_backend="osqp", - deconv_err_weighting=None, - deconv_use_base=True, - deconv_reset_scl=True, - deconv_masking_radius=None, - deconv_pks_polish=None, - deconv_ncons_thres=None, - deconv_min_rel_scl=None, - ar_use_all=True, - ar_kn_len=100, - ar_norm="l2", - ar_prop_best=None, + Y: np.ndarray, + up_factor: int = 1, + p: int = 2, + tau_init: Optional[Tuple[float, float]] = None, + return_iter: bool = False, + max_iters: int = 50, + n_best: Optional[int] = 3, + use_rel_err: bool = True, + err_atol: float = 1e-4, + err_rtol: float = 5e-2, + est_noise_freq: Optional[float] = None, + est_use_smooth: bool = False, + est_add_lag: int = 20, + est_nevt: Optional[int] = 10, + med_wnd: Optional[Union[int, Literal["auto"]]] = None, + dff: bool = True, + 
deconv_nthres: int = 1000, + deconv_norm: Literal["l1", "l2", "huber"] = "l2", + deconv_atol: float = 1e-3, + deconv_penal: Optional[Literal["l0", "l1"]] = None, + deconv_backend: Literal["osqp", "cvxpy", "cuosqp"] = "osqp", + deconv_err_weighting: Optional[Literal["fft", "corr", "adaptive"]] = None, + deconv_use_base: bool = True, + deconv_reset_scl: bool = True, + deconv_masking_radius: Optional[int] = None, + deconv_pks_polish: bool = True, + deconv_ncons_thres: Optional[Union[int, Literal["auto"]]] = None, + deconv_min_rel_scl: Optional[Union[float, Literal["auto"]]] = None, + ar_use_all: bool = True, + ar_kn_len: int = 100, + ar_norm: Literal["l1", "l2"] = "l2", + ar_prop_best: Optional[float] = None, da_client=None, - spawn_dashboard=True, -): - """Binary pursuit pipeline for spike inference. + spawn_dashboard: bool = True, +) -> Union[ + Tuple[np.ndarray, np.ndarray, pd.DataFrame], + Tuple[np.ndarray, np.ndarray, pd.DataFrame, list, list, list, list], +]: + """Binary pursuit pipeline for spike inference (legacy interface). + + .. deprecated:: + This function signature is deprecated. Use the config-based API: + + >>> from indeca.pipeline import pipeline_bin, DeconvPipelineConfig + >>> config = DeconvPipelineConfig(up_factor=2, ...) + >>> opt_C, opt_S, metrics = pipeline_bin(Y, config=config) Parameters ---------- Y : array-like - Input fluorescence trace - ... + Input fluorescence trace, shape (ncell, T) + up_factor : int + Upsampling factor for spike times + p : int + AR model order + tau_init : tuple or None + Initial (tau_d, tau_r) values. If None, estimate from data. 
+ return_iter : bool + Whether to return per-iteration results + max_iters : int + Maximum number of iterations + n_best : int or None + Number of best iterations to average for spike selection + use_rel_err : bool + Whether to use relative error for objective + err_atol : float + Absolute error tolerance for convergence + err_rtol : float + Relative error tolerance for convergence + est_noise_freq : float or None + Frequency for noise estimation + est_use_smooth : bool + Whether to use smoothing during AR estimation + est_add_lag : int + Additional lag samples for AR estimation + est_nevt : int or None + Number of top spike events for AR update + med_wnd : int, "auto", or None + Window size for median filtering + dff : bool + Whether to compute dF/F normalization + deconv_nthres : int + Number of thresholds for thresholding step + deconv_norm : str + Norm for data fidelity + deconv_atol : float + Absolute tolerance for solver + deconv_penal : str or None + Penalty type for sparsity + deconv_backend : str + Solver backend + deconv_err_weighting : str or None + Error weighting method + deconv_use_base : bool + Whether to include a baseline term + deconv_reset_scl : bool + Whether to reset scale at each iteration + deconv_masking_radius : int or None + Radius for masking around spikes + deconv_pks_polish : bool + Whether to polish peaks after solving + deconv_ncons_thres : int, "auto", or None + Max consecutive spikes threshold + deconv_min_rel_scl : float, "auto", or None + Minimum relative scale + ar_use_all : bool + Whether to use all cells for AR update (shared tau) + ar_kn_len : int + Kernel length for AR fitting + ar_norm : str + Norm for AR fitting + ar_prop_best : float or None + Proportion of best cells to use for AR update + da_client : Client or None + Dask client for distributed execution + spawn_dashboard : bool + Whether to spawn a real-time dashboard Returns ------- - dict - Dictionary containing results of the pipeline + opt_C : np.ndarray + Optimal 
calcium traces + opt_S : np.ndarray + Optimal spike trains + metric_df : pd.DataFrame + Per-iteration metrics + C_ls : list (only if return_iter=True) + Calcium traces per iteration + S_ls : list (only if return_iter=True) + Spike trains per iteration + h_ls : list (only if return_iter=True) + Impulse responses per iteration + h_fit_ls : list (only if return_iter=True) + Fitted impulse responses per iteration """ - logger.info("Starting binary pursuit pipeline") - # 0. housekeeping - ncell, T = Y.shape - logger.debug( - "Pipeline parameters: " - f"up_factor={up_factor}, p={p}, max_iters={max_iters}, " - f"n_best={n_best}, deconv_backend={deconv_backend}, " - f"ar_use_all={ar_use_all}, ar_kn_len={ar_kn_len}" - f"{ncell} cells with {T} timepoints" + # Emit deprecation warning + warnings.warn( + "The flat-kwargs signature of pipeline_bin() is deprecated. " + "Use the config-based API instead:\n" + " from indeca.pipeline import pipeline_bin, DeconvPipelineConfig\n" + " config = DeconvPipelineConfig.from_legacy_kwargs(...)\n" + " result = pipeline_bin(Y, config=config)", + DeprecationWarning, + stacklevel=2, ) - if med_wnd is not None: - if med_wnd == "auto": - med_wnd = ar_kn_len - for iy, y in enumerate(Y): - Y[iy, :] = y - medfilt(y, med_wnd * 2 + 1) - if dff: - for iy, y in enumerate(Y): - Y[iy, :] = compute_dff(y, window_size=ar_kn_len * 5, q=0.2) - if spawn_dashboard: - if da_client is not None: - logger.debug("Using Dask client for distributed computation") - dashboard = da_client.submit( - Dashboard, Y=Y, kn_len=ar_kn_len, actor=True - ).result() - else: - logger.debug("Running in single-machine mode") - dashboard = Dashboard(Y=Y, kn_len=ar_kn_len) - else: - dashboard = None - # 1. 
estimate initial guess at convolution kernel - if tau_init is not None: - logger.debug(f"Using provided tau_init: {tau_init}") - theta = tau2AR(tau_init[0], tau_init[1]) - _, _, pp = AR2tau(theta[0], theta[1], solve_amp=True) - ps = np.array([pp, -pp]) - theta = np.tile(tau2AR(tau_init[0], tau_init[1]), (ncell, 1)) - tau = np.tile(tau_init, (ncell, 1)) - ps = np.tile(ps, (ncell, 1)) - else: - logger.debug("Computing initial tau values") - theta = np.empty((ncell, p)) - tau = np.empty((ncell, p)) - ps = np.empty((ncell, p)) - for icell, y in enumerate(Y): - cur_theta, _ = estimate_coefs( - y, - p=p, - noise_freq=est_noise_freq, - use_smooth=est_use_smooth, - add_lag=est_add_lag, - ) - cur_theta, cur_tau, cur_p = AR_upsamp_real( - cur_theta, upsamp=up_factor, fit_nsamp=ar_kn_len - ) - tau[icell, :] = cur_tau - theta[icell, :] = cur_theta - ps[icell, :] = cur_p - scale = np.empty(ncell) - # 2. iteration loop - C_ls = [] - S_ls = [] - scal_ls = [] - h_ls = [] - h_fit_ls = [] - metric_df = pd.DataFrame( - columns=[ - "iter", - "cell", - "g0", - "g1", - "tau_d", - "tau_r", - "err", - "err_rel", - "nnz", - "scale", - "best_idx", - "obj", - "wgt_len", - ] + + # Build config from legacy kwargs + config = DeconvPipelineConfig.from_legacy_kwargs( + up_factor=up_factor, + p=p, + tau_init=tau_init, + max_iters=max_iters, + n_best=n_best, + use_rel_err=use_rel_err, + err_atol=err_atol, + err_rtol=err_rtol, + est_noise_freq=est_noise_freq, + est_use_smooth=est_use_smooth, + est_add_lag=est_add_lag, + est_nevt=est_nevt, + med_wnd=med_wnd, + dff=dff, + deconv_nthres=deconv_nthres, + deconv_norm=deconv_norm, + deconv_atol=deconv_atol, + deconv_penal=deconv_penal, + deconv_backend=deconv_backend, + deconv_err_weighting=deconv_err_weighting, + deconv_use_base=deconv_use_base, + deconv_reset_scl=deconv_reset_scl, + deconv_masking_radius=deconv_masking_radius, + deconv_pks_polish=deconv_pks_polish, + deconv_ncons_thres=deconv_ncons_thres, + deconv_min_rel_scl=deconv_min_rel_scl, + 
ar_use_all=ar_use_all, + ar_kn_len=ar_kn_len, + ar_norm=ar_norm, + ar_prop_best=ar_prop_best, + ) + + # Delegate to new implementation + return _pipeline_bin_new( + Y, + config=config, + da_client=da_client, + spawn_dashboard=spawn_dashboard, + return_iter=return_iter, ) - if da_client is not None: - dcv = [ - da_client.submit( - lambda yy, th, tau, ps: DeconvBin( - y=yy, - theta=th, - tau=tau, - ps=ps, - coef_len=ar_kn_len, - upsamp=up_factor, - nthres=deconv_nthres, - norm=deconv_norm, - penal=deconv_penal, - use_base=deconv_use_base, - err_weighting=deconv_err_weighting, - masking_radius=deconv_masking_radius, - pks_polish=deconv_pks_polish, - ncons_thres=deconv_ncons_thres, - min_rel_scl=deconv_min_rel_scl, - atol=deconv_atol, - backend=deconv_backend, - dashboard=dashboard, - dashboard_uid=i, - ), - y, - theta[i], - tau[i], - ps[i], - ) - for i, y in enumerate(Y) - ] - else: - dcv = [ - DeconvBin( - y=y, - theta=theta[i], - tau=tau[i], - ps=ps[i], - coef_len=ar_kn_len, - upsamp=up_factor, - nthres=deconv_nthres, - norm=deconv_norm, - penal=deconv_penal, - use_base=deconv_use_base, - err_weighting=deconv_err_weighting, - masking_radius=deconv_masking_radius, - pks_polish=deconv_pks_polish, - ncons_thres=deconv_ncons_thres, - min_rel_scl=deconv_min_rel_scl, - atol=deconv_atol, - backend=deconv_backend, - dashboard=dashboard, - dashboard_uid=i, - ) - for i, y in enumerate(Y) - ] - for i_iter in trange(max_iters, desc="iteration"): - logger.info(f"Starting iteration {i_iter}/{max_iters}") - # 2.1 deconvolution - res = [] - for icell, y in tqdm( - enumerate(Y), total=Y.shape[0], desc="deconv", leave=False - ): - if da_client is not None: - r = da_client.submit( - lambda d: d.solve_scale( - reset_scale=i_iter <= 1 or deconv_reset_scl - ), - dcv[icell], - ) - else: - r = dcv[icell].solve_scale(reset_scale=i_iter <= 1 or deconv_reset_scl) - res.append(r) - if da_client is not None: - res = da_client.gather(res) - S = np.stack([r[0].squeeze() for r in res], axis=0, 
dtype=float) - C = np.stack([r[1].squeeze() for r in res], axis=0) - scale = np.array([r[2] for r in res]) - err = np.array([r[3] for r in res]) - err_rel = np.array([r[4] for r in res]) - nnz = np.array([r[5] for r in res]) - penal = np.array([r[6] for r in res]) - logger.debug( - f"Iteration {i_iter} stats - Mean error: {err.mean():.4f}, Mean scale: {scale.mean():.4f}" - ) - # 2.2 save iteration results - dhm = np.stack( - [ - find_dhm(True, (t0, t1), (s, -s))[0] - for t0, t1, s in zip(tau.T[0], tau.T[1], scale) - ] - ) - cur_metric = pd.DataFrame( - { - "iter": i_iter, - "cell": np.arange(ncell), - "g0": theta.T[0], - "g1": theta.T[1], - "tau_d": tau.T[0], - "tau_r": tau.T[1], - "dhm0": dhm.T[0], - "dhm1": dhm.T[1], - "err": err, - "err_rel": err_rel, - "scale": scale, - "penal": penal, - "nnz": nnz, - "obj": err_rel if use_rel_err else err, - "wgt_len": [d.wgt_len for d in dcv], - } - ) - if dashboard is not None: - dashboard.update( - tau_d=cur_metric["tau_d"].squeeze(), - tau_r=cur_metric["tau_r"].squeeze(), - err=cur_metric["obj"].squeeze(), - scale=cur_metric["scale"].squeeze(), - ) - dashboard.set_iter(min(i_iter + 1, max_iters - 1)) - metric_df = pd.concat([metric_df, cur_metric], ignore_index=True) - C_ls.append(C) - S_ls.append(S) - scal_ls.append(scale) - try: - h_ls.append(h) - h_fit_ls.append(h_fit) - except UnboundLocalError: - h_ls.append(np.full(T * up_factor, np.nan)) - h_fit_ls.append(np.full(T * up_factor, np.nan)) - # 2.3 update AR - metric_df = metric_df.set_index(["iter", "cell"]) - if n_best is not None and i_iter >= n_best: - S_best = np.empty_like(S) - scal_best = np.empty_like(scale) - err_wt = np.empty_like(err_rel) - if tau_init is not None: - metric_best = metric_df - else: - metric_best = metric_df.loc[1:, :] - for icell, cell_met in metric_best.groupby("cell", sort=True): - cell_met = cell_met.reset_index().sort_values("obj", ascending=True) - cur_idx = np.array(cell_met["iter"][:n_best]) - metric_df.loc[(i_iter, icell), "best_idx"] 
= ",".join( - cur_idx.astype(str) - ) - S_best[icell, :] = np.sum( - np.stack([S_ls[i][icell, :] for i in cur_idx], axis=0), axis=0 - ) > (n_best / 2) - scal_best[icell] = np.mean([scal_ls[i][icell] for i in cur_idx]) - err_wt[icell] = -np.mean( - [metric_df.loc[(i, icell), "err_rel"] for i in cur_idx] - ) - else: - S_best = S - scal_best = scale - err_wt = -err_rel - metric_df = metric_df.reset_index() - if est_nevt is not None: - S_ar = [] - R = construct_R(T, up_factor) - for s in S_best: - Rs = R @ s - s_pks, pk_prop = find_peaks( - Rs, height=1, distance=ar_kn_len * up_factor - ) - pk_ht = pk_prop["peak_heights"] - top_idx = s_pks[np.argsort(pk_ht)[-est_nevt:]] - mask = np.zeros_like(Rs, dtype=bool) - mask[top_idx] = True - Rs_ma = Rs * mask - s_ma = np.zeros_like(s) - s_ma[::up_factor] = Rs_ma - S_ar.append(s_ma) - S_ar = np.stack(S_ar, axis=0) - else: - S_ar = S_best - if ar_use_all: - if ar_prop_best is not None: - ar_nbest = max(int(np.round(ar_prop_best * ncell)), 1) - ar_best_idx = np.argsort(err_wt)[-ar_nbest:] - else: - ar_best_idx = slice(None) - cur_tau, ps, ar_scal, h, h_fit = updateAR( - Y[ar_best_idx], - S_ar[ar_best_idx], - scal_best[ar_best_idx], - N=p, - h_len=ar_kn_len * up_factor, - norm=ar_norm, - up_factor=up_factor, - ) - if dashboard is not None: - dashboard.update( - h=h[: ar_kn_len * up_factor], h_fit=h_fit[: ar_kn_len * up_factor] - ) - tau = np.tile(cur_tau, (ncell, 1)) - for idx, d in enumerate(dcv): - if da_client is not None: - da_client.submit( - lambda dd: dd.update(tau=cur_tau, scale=scal_best[idx]), d - ) - else: - d.update(tau=cur_tau, scale=scal_best[idx]) - logger.debug( - f"Updating AR parameters for all cells: tau:{tau}, ar_scal: {ar_scal}" - ) - else: - theta = np.empty((ncell, p)) - tau = np.empty((ncell, p)) - for icell, (y, s) in enumerate(zip(Y, S_ar)): - cur_tau, ps, ar_scal, h, h_fit = updateAR( - y, - s, - scal_best[icell], - N=p, - h_len=ar_kn_len, - norm=ar_norm, - up_factor=up_factor, - ) - if dashboard is not 
None: - dashboard.update(uid=icell, h=h, h_fit=h_fit) - tau[icell, :] = cur_tau - if da_client is not None: - da_client.submit( - lambda dd: dd.update(tau=cur_tau, scale=scal_best[icell]), - dcv[icell], - ) - else: - dcv[icell].update(tau=cur_tau, scale=scal_best[icell]) - logger.debug( - f"Updating AR parameters for cell {icell}: tau:{tau}, ar_scal: {ar_scal}" - ) - # 2.4 check convergence - metric_prev = metric_df[metric_df["iter"] < i_iter].dropna( - subset=["obj", "scale"] - ) - metric_last = metric_df[metric_df["iter"] == i_iter - 1].dropna( - subset=["obj", "scale"] - ) - if len(metric_prev) > 0: - err_cur = cur_metric.set_index("cell")["obj"] - err_last = metric_last.set_index("cell")["obj"] - err_best = metric_prev.groupby("cell")["obj"].min() - # converged by err - if (np.abs(err_cur - err_last) < err_atol).all(): - logger.info("Converged: absolute error tolerance reached") - break - # converged by relative err - if (np.abs(err_cur - err_last) < err_rtol * err_best).all(): - logger.info("Converged: relative error tolerance reached") - break - # converged by s - S_best = np.empty((ncell, T * up_factor)) - for uid, udf in metric_prev.groupby("cell"): - best_iter = udf.set_index("iter")["obj"].idxmin() - S_best[uid, :] = S_ls[best_iter][uid, :] - if np.abs(S - S_best).sum() < 1: - logger.info("Converged: spike pattern stabilized") - break - # trapped - err_all = metric_prev.pivot(columns="iter", index="cell", values="obj") - diff_all = np.abs(err_cur.values.reshape((-1, 1)) - err_all.values) - if (diff_all.min(axis=1) < err_atol).all(): - logger.warning("Solution trapped in local optimal err") - break - # trapped by s - diff_all = np.array([np.abs(S - prev_s).sum() for prev_s in S_ls[:-1]]) - if (diff_all < 1).sum() > 1: - logger.warning("Solution trapped in local optimal s") - break - else: - logger.warning("Max iteration reached without convergence") - # Compute final results - opt_C, opt_S = np.empty((ncell, T * up_factor)), np.empty((ncell, T * 
up_factor)) - mobj = metric_df.groupby("iter")["obj"].median() - opt_idx_all = mobj.idxmin() - for icell in range(ncell): - if ar_use_all: - opt_idx = opt_idx_all - else: - opt_idx = metric_df.loc[ - metric_df[metric_df["cell"] == icell]["obj"].idxmin(), "iter" - ] - opt_idx = -1 - opt_C[icell, :] = C_ls[opt_idx][icell, :] - opt_S[icell, :] = S_ls[opt_idx][icell, :] - C_ls.append(opt_C) - S_ls.append(opt_S) - if dashboard is not None: - dashboard.stop() - logger.info("Pipeline completed successfully") - if return_iter: - return opt_C, opt_S, metric_df, C_ls, S_ls, h_ls, h_fit_ls - else: - return opt_C, opt_S, metric_df + + +# Keep legacy name available for explicit imports +pipeline_bin_legacy = pipeline_bin diff --git a/src/indeca/pipeline/preprocess.py b/src/indeca/pipeline/preprocess.py new file mode 100644 index 0000000..f118dd5 --- /dev/null +++ b/src/indeca/pipeline/preprocess.py @@ -0,0 +1,57 @@ +"""Preprocessing functions for the binary pursuit pipeline. + +These are pure functions that transform input traces before deconvolution. +""" + +from typing import Optional, Union, Literal + +import numpy as np +from scipy.signal import medfilt + +from indeca.utils.utils import compute_dff + + +def preprocess_traces( + Y: np.ndarray, + *, + med_wnd: Optional[Union[int, Literal["auto"]]] = None, + dff: bool = True, + ar_kn_len: int = 100, +) -> np.ndarray: + """Preprocess fluorescence traces. + + This function applies median filtering and/or dF/F normalization + to the input traces. The input array is modified in place for + efficiency (matching the original pipeline behavior). + + Parameters + ---------- + Y : np.ndarray + Input fluorescence traces, shape (ncell, T) + med_wnd : int, "auto", or None + Window size for median filtering. If "auto", uses ar_kn_len. + If None, skips median filtering. + dff : bool + Whether to apply dF/F normalization. + ar_kn_len : int + AR kernel length, used for window sizing. 
+ + Returns + ------- + np.ndarray + Preprocessed traces, shape (ncell, T). + Note: This may be the same array as Y (modified in place). + """ + # Median filtering + if med_wnd is not None: + actual_wnd = ar_kn_len if med_wnd == "auto" else med_wnd + for iy, y in enumerate(Y): + Y[iy, :] = y - medfilt(y, actual_wnd * 2 + 1) + + # dF/F normalization + if dff: + for iy, y in enumerate(Y): + Y[iy, :] = compute_dff(y, window_size=ar_kn_len * 5, q=0.2) + + return Y + diff --git a/src/indeca/pipeline/types.py b/src/indeca/pipeline/types.py new file mode 100644 index 0000000..7e6bab6 --- /dev/null +++ b/src/indeca/pipeline/types.py @@ -0,0 +1,154 @@ +"""Shared type definitions for the binary pursuit pipeline. + +These types define the data structures passed between pipeline steps, +making the data flow explicit and typed. +""" + +from dataclasses import dataclass +from typing import List, Optional + +import numpy as np +import pandas as pd + + +@dataclass +class ARParams: + """AR model parameters for all cells. + + Attributes: + theta: AR coefficients, shape (ncell, p) + tau: Time constants (tau_d, tau_r), shape (ncell, 2) + ps: Peak coefficients, shape (ncell, p) + """ + + theta: np.ndarray + tau: np.ndarray + ps: np.ndarray + + +@dataclass +class DeconvStepResult: + """Result of a single deconvolution step. + + Attributes: + S: Spike train, shape (ncell, T * up_factor) + C: Calcium trace, shape (ncell, T * up_factor) + scale: Scale factors, shape (ncell,) + err: Absolute errors, shape (ncell,) + err_rel: Relative errors, shape (ncell,) + nnz: Non-zero counts, shape (ncell,) + penal: Penalty values, shape (ncell,) + """ + + S: np.ndarray + C: np.ndarray + scale: np.ndarray + err: np.ndarray + err_rel: np.ndarray + nnz: np.ndarray + penal: np.ndarray + + +@dataclass +class ARUpdateResult: + """Result of AR parameter update step. 
+ + Attributes: + tau: Updated time constants, shape (ncell, 2) or (1, 2) if use_all + ps: Updated peak coefficients + ar_scal: AR scale factor + h: Estimated impulse response + h_fit: Fitted impulse response + """ + + tau: np.ndarray + ps: np.ndarray + ar_scal: float + h: np.ndarray + h_fit: np.ndarray + + +@dataclass +class ConvergenceResult: + """Result of convergence check. + + Attributes: + converged: Whether convergence criteria are met + reason: Human-readable reason for convergence/non-convergence + """ + + converged: bool + reason: str + + +@dataclass +class IterationState: + """State accumulated across iterations. + + Attributes: + C_ls: List of calcium traces per iteration + S_ls: List of spike trains per iteration + scal_ls: List of scale factors per iteration + h_ls: List of impulse responses per iteration + h_fit_ls: List of fitted impulse responses per iteration + metric_df: DataFrame with per-iteration metrics + """ + + C_ls: List[np.ndarray] + S_ls: List[np.ndarray] + scal_ls: List[np.ndarray] + h_ls: List[np.ndarray] + h_fit_ls: List[np.ndarray] + metric_df: pd.DataFrame + + @classmethod + def empty(cls, T: int, up_factor: int) -> "IterationState": + """Create an empty iteration state.""" + return cls( + C_ls=[], + S_ls=[], + scal_ls=[], + h_ls=[], + h_fit_ls=[], + metric_df=pd.DataFrame( + columns=[ + "iter", + "cell", + "g0", + "g1", + "tau_d", + "tau_r", + "err", + "err_rel", + "nnz", + "scale", + "best_idx", + "obj", + "wgt_len", + ] + ), + ) + + +@dataclass +class PipelineResult: + """Final result of the pipeline. 
+ + Attributes: + opt_C: Optimal calcium traces, shape (ncell, T * up_factor) + opt_S: Optimal spike trains, shape (ncell, T * up_factor) + metric_df: DataFrame with all iteration metrics + C_ls: List of calcium traces per iteration (if return_iter=True) + S_ls: List of spike trains per iteration (if return_iter=True) + h_ls: List of impulse responses per iteration (if return_iter=True) + h_fit_ls: List of fitted impulse responses per iteration (if return_iter=True) + """ + + opt_C: np.ndarray + opt_S: np.ndarray + metric_df: pd.DataFrame + C_ls: Optional[List[np.ndarray]] = None + S_ls: Optional[List[np.ndarray]] = None + h_ls: Optional[List[np.ndarray]] = None + h_fit_ls: Optional[List[np.ndarray]] = None + From 85463b6e4eaaaca72d65de2dce8af94c28775518 Mon Sep 17 00:00:00 2001 From: Daniel Aharoni Date: Sat, 13 Dec 2025 17:04:07 -0800 Subject: [PATCH 6/6] black format --- src/indeca/pipeline/ar_update.py | 11 +++-------- src/indeca/pipeline/binary_pursuit.py | 9 ++------- src/indeca/pipeline/config.py | 1 - src/indeca/pipeline/convergence.py | 5 +---- src/indeca/pipeline/init.py | 1 - src/indeca/pipeline/iteration.py | 9 ++------- src/indeca/pipeline/metrics.py | 1 - src/indeca/pipeline/preprocess.py | 1 - src/indeca/pipeline/types.py | 1 - 9 files changed, 8 insertions(+), 31 deletions(-) diff --git a/src/indeca/pipeline/ar_update.py b/src/indeca/pipeline/ar_update.py index ef35136..e281946 100644 --- a/src/indeca/pipeline/ar_update.py +++ b/src/indeca/pipeline/ar_update.py @@ -74,9 +74,7 @@ def select_best_spikes( for icell, cell_met in metric_best.groupby("cell", sort=True): cell_met = cell_met.reset_index().sort_values("obj", ascending=True) cur_idx = np.array(cell_met["iter"][:n_best]) - metric_df.loc[(i_iter, icell), "best_idx"] = ",".join( - cur_idx.astype(str) - ) + metric_df.loc[(i_iter, icell), "best_idx"] = ",".join(cur_idx.astype(str)) S_best[icell, :] = np.sum( np.stack([S_ls[i][icell, :] for i in cur_idx], axis=0), axis=0 ) > (n_best / 2) @@ -127,9 
+125,7 @@ def make_S_ar( for s in S_best: Rs = R @ s - s_pks, pk_prop = find_peaks( - Rs, height=1, distance=ar_kn_len * up_factor - ) + s_pks, pk_prop = find_peaks(Rs, height=1, distance=ar_kn_len * up_factor) pk_ht = pk_prop["peak_heights"] top_idx = s_pks[np.argsort(pk_ht)[-est_nevt:]] mask = np.zeros_like(Rs, dtype=bool) @@ -229,7 +225,7 @@ def update_ar_parameters( else: # Per-cell AR update tau = np.empty((ncell, p)) - + # NOTE: Original pipeline only retained the last cell's ps/h/h_fit # when ar_use_all=False. We preserve this behavior explicitly. ps = None @@ -303,4 +299,3 @@ def propagate_ar_update( ) else: d.update(tau=tau[idx], scale=scal_best[idx]) - diff --git a/src/indeca/pipeline/binary_pursuit.py b/src/indeca/pipeline/binary_pursuit.py index d7f052c..902e2b7 100644 --- a/src/indeca/pipeline/binary_pursuit.py +++ b/src/indeca/pipeline/binary_pursuit.py @@ -249,9 +249,7 @@ def pipeline_bin( theta = np.array([tau2AR(t[0], t[1]) for t in tau]) if ar_cfg.use_all: - logger.debug( - f"Updating AR parameters for all cells: tau={tau[0]}" - ) + logger.debug(f"Updating AR parameters for all cells: tau={tau[0]}") else: logger.debug(f"Updated AR parameters per-cell") @@ -285,9 +283,7 @@ def pipeline_bin( logger.warning("Max iteration reached without convergence") # 7. Compute final results - opt_C, opt_S = _finalize_results( - state, ncell, T, up_factor, ar_cfg.use_all - ) + opt_C, opt_S = _finalize_results(state, ncell, T, up_factor, ar_cfg.use_all) # 8. 
Cleanup if dashboard is not None: @@ -361,4 +357,3 @@ def _finalize_results( S_ls.append(opt_S) return opt_C, opt_S - diff --git a/src/indeca/pipeline/config.py b/src/indeca/pipeline/config.py index 4eb0abe..f7682dd 100644 --- a/src/indeca/pipeline/config.py +++ b/src/indeca/pipeline/config.py @@ -278,4 +278,3 @@ def from_legacy_kwargs( n_best=n_best, ), ) - diff --git a/src/indeca/pipeline/convergence.py b/src/indeca/pipeline/convergence.py index ace1ca5..75f79ad 100644 --- a/src/indeca/pipeline/convergence.py +++ b/src/indeca/pipeline/convergence.py @@ -53,9 +53,7 @@ def check_convergence( Result indicating if converged and why """ # Need at least one previous iteration - metric_prev = metric_df[metric_df["iter"] < i_iter].dropna( - subset=["obj", "scale"] - ) + metric_prev = metric_df[metric_df["iter"] < i_iter].dropna(subset=["obj", "scale"]) metric_last = metric_df[metric_df["iter"] == i_iter - 1].dropna( subset=["obj", "scale"] ) @@ -109,4 +107,3 @@ def check_convergence( ) return ConvergenceResult(converged=False, reason="") - diff --git a/src/indeca/pipeline/init.py b/src/indeca/pipeline/init.py index 55c9b25..abfb06a 100644 --- a/src/indeca/pipeline/init.py +++ b/src/indeca/pipeline/init.py @@ -217,4 +217,3 @@ def initialize_deconvolvers( ] return dcv - diff --git a/src/indeca/pipeline/iteration.py b/src/indeca/pipeline/iteration.py index 79e434a..78d52e0 100644 --- a/src/indeca/pipeline/iteration.py +++ b/src/indeca/pipeline/iteration.py @@ -41,18 +41,14 @@ def run_deconv_step( """ res = [] - for icell, _ in tqdm( - enumerate(Y), total=Y.shape[0], desc="deconv", leave=False - ): + for icell, _ in tqdm(enumerate(Y), total=Y.shape[0], desc="deconv", leave=False): if da_client is not None: r = da_client.submit( lambda d: d.solve_scale(reset_scale=i_iter <= 1 or reset_scale), deconvolvers[icell], ) else: - r = deconvolvers[icell].solve_scale( - reset_scale=i_iter <= 1 or reset_scale - ) + r = deconvolvers[icell].solve_scale(reset_scale=i_iter <= 1 or 
reset_scale) res.append(r) if da_client is not None: @@ -76,4 +72,3 @@ def run_deconv_step( nnz=nnz, penal=penal, ) - diff --git a/src/indeca/pipeline/metrics.py b/src/indeca/pipeline/metrics.py index ed65249..f022abd 100644 --- a/src/indeca/pipeline/metrics.py +++ b/src/indeca/pipeline/metrics.py @@ -129,4 +129,3 @@ def update_dashboard( scale=cur_metric["scale"].squeeze(), ) dashboard.set_iter(min(i_iter + 1, max_iters - 1)) - diff --git a/src/indeca/pipeline/preprocess.py b/src/indeca/pipeline/preprocess.py index f118dd5..a088a5c 100644 --- a/src/indeca/pipeline/preprocess.py +++ b/src/indeca/pipeline/preprocess.py @@ -54,4 +54,3 @@ def preprocess_traces( Y[iy, :] = compute_dff(y, window_size=ar_kn_len * 5, q=0.2) return Y - diff --git a/src/indeca/pipeline/types.py b/src/indeca/pipeline/types.py index 7e6bab6..a3e07ec 100644 --- a/src/indeca/pipeline/types.py +++ b/src/indeca/pipeline/types.py @@ -151,4 +151,3 @@ class PipelineResult: S_ls: Optional[List[np.ndarray]] = None h_ls: Optional[List[np.ndarray]] = None h_fit_ls: Optional[List[np.ndarray]] = None -