diff --git a/pdm.lock b/pdm.lock index 7c4ca53..c6816d0 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "test"] strategy = [] lock_version = "4.5.0" -content_hash = "sha256:4f5ed9e8f685d76eb0697a6fe96b8bc2a91790bca75dbe477816cf4ca61a32e4" +content_hash = "sha256:117c9d0d5cbb1cee84326ef33261c4f32df6ced8abecbaad6fd6816b380514eb" [[metadata.targets]] requires_python = ">=3.11,<3.13" @@ -86,6 +86,19 @@ files = [ {file = "aiosignal-1.3.2.tar.gz", hash = "sha256:a8c255c66fafb1e499c9351d0bf32ff2d8a0321595ebac3b93713656d2436f54"}, ] +[[package]] +name = "annotated-types" +version = "0.7.0" +requires_python = ">=3.8" +summary = "Reusable constraint types to use with typing.Annotated" +dependencies = [ + "typing-extensions>=4.0.0; python_version < \"3.9\"", +] +files = [ + {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, + {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, +] + [[package]] name = "anyio" version = "4.9.0" @@ -2560,6 +2573,70 @@ files = [ {file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"}, ] +[[package]] +name = "pydantic" +version = "2.12.5" +requires_python = ">=3.9" +summary = "Data validation using Python type hints" +dependencies = [ + "annotated-types>=0.6.0", + "pydantic-core==2.41.5", + "typing-extensions>=4.14.1", + "typing-inspection>=0.4.2", +] +files = [ + {file = "pydantic-2.12.5-py3-none-any.whl", hash = "sha256:e561593fccf61e8a20fc46dfc2dfe075b8be7d0188df33f221ad1f0139180f9d"}, + {file = "pydantic-2.12.5.tar.gz", hash = "sha256:4d351024c75c0f085a9febbb665ce8c0c6ec5d30e903bdb6394b7ede26aebb49"}, +] + +[[package]] +name = "pydantic-core" +version = "2.41.5" +requires_python = ">=3.9" +summary = "Core functionality for Pydantic validation and serialization" +dependencies = [ + "typing-extensions>=4.14.1", 
+] +files = [ + {file = "pydantic_core-2.41.5-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:a3a52f6156e73e7ccb0f8cced536adccb7042be67cb45f9562e12b319c119da6"}, + {file = "pydantic_core-2.41.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7f3bf998340c6d4b0c9a2f02d6a400e51f123b59565d74dc60d252ce888c260b"}, + {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:378bec5c66998815d224c9ca994f1e14c0c21cb95d2f52b6021cc0b2a58f2a5a"}, + {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e7b576130c69225432866fe2f4a469a85a54ade141d96fd396dffcf607b558f8"}, + {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6cb58b9c66f7e4179a2d5e0f849c48eff5c1fca560994d6eb6543abf955a149e"}, + {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:88942d3a3dff3afc8288c21e565e476fc278902ae4d6d134f1eeda118cc830b1"}, + {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f31d95a179f8d64d90f6831d71fa93290893a33148d890ba15de25642c5d075b"}, + {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c1df3d34aced70add6f867a8cf413e299177e0c22660cc767218373d0779487b"}, + {file = "pydantic_core-2.41.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:4009935984bd36bd2c774e13f9a09563ce8de4abaa7226f5108262fa3e637284"}, + {file = "pydantic_core-2.41.5-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:34a64bc3441dc1213096a20fe27e8e128bd3ff89921706e83c0b1ac971276594"}, + {file = "pydantic_core-2.41.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c9e19dd6e28fdcaa5a1de679aec4141f691023916427ef9bae8584f9c2fb3b0e"}, + {file = "pydantic_core-2.41.5-cp311-cp311-win32.whl", hash = "sha256:2c010c6ded393148374c0f6f0bf89d206bf3217f201faa0635dcd56bd1520f6b"}, + {file = 
"pydantic_core-2.41.5-cp311-cp311-win_amd64.whl", hash = "sha256:76ee27c6e9c7f16f47db7a94157112a2f3a00e958bc626e2f4ee8bec5c328fbe"}, + {file = "pydantic_core-2.41.5-cp311-cp311-win_arm64.whl", hash = "sha256:4bc36bbc0b7584de96561184ad7f012478987882ebf9f9c389b23f432ea3d90f"}, + {file = "pydantic_core-2.41.5-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f41a7489d32336dbf2199c8c0a215390a751c5b014c2c1c5366e817202e9cdf7"}, + {file = "pydantic_core-2.41.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:070259a8818988b9a84a449a2a7337c7f430a22acc0859c6b110aa7212a6d9c0"}, + {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e96cea19e34778f8d59fe40775a7a574d95816eb150850a85a7a4c8f4b94ac69"}, + {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed2e99c456e3fadd05c991f8f437ef902e00eedf34320ba2b0842bd1c3ca3a75"}, + {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65840751b72fbfd82c3c640cff9284545342a4f1eb1586ad0636955b261b0b05"}, + {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e536c98a7626a98feb2d3eaf75944ef6f3dbee447e1f841eae16f2f0a72d8ddc"}, + {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eceb81a8d74f9267ef4081e246ffd6d129da5d87e37a77c9bde550cb04870c1c"}, + {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d38548150c39b74aeeb0ce8ee1d8e82696f4a4e16ddc6de7b1d8823f7de4b9b5"}, + {file = "pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c23e27686783f60290e36827f9c626e63154b82b116d7fe9adba1fda36da706c"}, + {file = "pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:482c982f814460eabe1d3bb0adfdc583387bd4691ef00b90575ca0d2b6fe2294"}, + {file = 
"pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:bfea2a5f0b4d8d43adf9d7b8bf019fb46fdd10a2e5cde477fbcb9d1fa08c68e1"}, + {file = "pydantic_core-2.41.5-cp312-cp312-win32.whl", hash = "sha256:b74557b16e390ec12dca509bce9264c3bbd128f8a2c376eaa68003d7f327276d"}, + {file = "pydantic_core-2.41.5-cp312-cp312-win_amd64.whl", hash = "sha256:1962293292865bca8e54702b08a4f26da73adc83dd1fcf26fbc875b35d81c815"}, + {file = "pydantic_core-2.41.5-cp312-cp312-win_arm64.whl", hash = "sha256:1746d4a3d9a794cacae06a5eaaccb4b8643a131d45fbc9af23e353dc0a5ba5c3"}, + {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b2379fa7ed44ddecb5bfe4e48577d752db9fc10be00a6b7446e9663ba143de26"}, + {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:266fb4cbf5e3cbd0b53669a6d1b039c45e3ce651fd5442eff4d07c2cc8d66808"}, + {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58133647260ea01e4d0500089a8c4f07bd7aa6ce109682b1426394988d8aaacc"}, + {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:287dad91cfb551c363dc62899a80e9e14da1f0e2b6ebde82c806612ca2a13ef1"}, + {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:03b77d184b9eb40240ae9fd676ca364ce1085f203e1b1256f8ab9984dca80a84"}, + {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:a668ce24de96165bb239160b3d854943128f4334822900534f2fe947930e5770"}, + {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f14f8f046c14563f8eb3f45f499cc658ab8d10072961e07225e507adb700e93f"}, + {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:56121965f7a4dc965bff783d70b907ddf3d57f6eba29b6d2e5dabfaf07799c51"}, + {file = "pydantic_core-2.41.5.tar.gz", hash = 
"sha256:08daa51ea16ad373ffd5e7606252cc32f07bc72b28284b6bc9c6df804816476e"}, +] + [[package]] name = "pyflakes" version = "3.1.0" @@ -3437,12 +3514,25 @@ files = [ [[package]] name = "typing-extensions" -version = "4.13.2" -requires_python = ">=3.8" -summary = "Backported and Experimental Type Hints for Python 3.8+" +version = "4.15.0" +requires_python = ">=3.9" +summary = "Backported and Experimental Type Hints for Python 3.9+" +files = [ + {file = "typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548"}, + {file = "typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466"}, +] + +[[package]] +name = "typing-inspection" +version = "0.4.2" +requires_python = ">=3.9" +summary = "Runtime typing introspection tools" +dependencies = [ + "typing-extensions>=4.12.0", +] files = [ - {file = "typing_extensions-4.13.2-py3-none-any.whl", hash = "sha256:a439e7c04b49fec3e5d3e2beaa21755cadbbdc391694e28ccdd36ca4a1408f8c"}, - {file = "typing_extensions-4.13.2.tar.gz", hash = "sha256:e6c81219bd689f51865d9e372991c540bda33a0379d5573cddb9a3a23f7caaef"}, + {file = "typing_inspection-0.4.2-py3-none-any.whl", hash = "sha256:4ed1cacbdc298c220f1bd249ed5287caa16f34d44ef4e9c3d0cbad5b521545e7"}, + {file = "typing_inspection-0.4.2.tar.gz", hash = "sha256:ba561c48a67c5958007083d386c3295464928b01faa735ab8547c5692e87f464"}, ] [[package]] diff --git a/pyproject.toml b/pyproject.toml index ba6027a..f9e4600 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "line-profiler>=4.1.3", "scikit-image>=0.25.2", "numba>=0.61.2", + "pydantic>=2.0", ] requires-python = ">=3.11,<3.13" readme = "README.md" diff --git a/src/indeca/core/AR_kernel.py b/src/indeca/core/AR_kernel.py index 60ea07d..7ef33bb 100644 --- a/src/indeca/core/AR_kernel.py +++ b/src/indeca/core/AR_kernel.py @@ -9,7 +9,7 @@ from scipy.optimize import curve_fit from 
statsmodels.tsa.stattools import acovf -from indeca.core.deconv.deconv import construct_G, construct_R +from indeca.core.deconv import construct_G, construct_R from indeca.core.simulation import AR2tau, ar_pulse, solve_p, tau2AR diff --git a/src/indeca/core/deconv/__init__.py b/src/indeca/core/deconv/__init__.py index da46a0c..0753809 100644 --- a/src/indeca/core/deconv/__init__.py +++ b/src/indeca/core/deconv/__init__.py @@ -1,3 +1,4 @@ -from .deconv import DeconvBin, construct_R, construct_G, max_thres, sum_downsample +from .deconv import DeconvBin +from .utils import construct_R, construct_G, max_thres, sum_downsample __all__ = ["DeconvBin", "construct_R", "construct_G", "max_thres", "sum_downsample"] diff --git a/src/indeca/core/deconv/config.py b/src/indeca/core/deconv/config.py new file mode 100644 index 0000000..bd4aff9 --- /dev/null +++ b/src/indeca/core/deconv/config.py @@ -0,0 +1,91 @@ +"""Configuration for deconv module.""" + +from typing import Literal, Optional, Union + +from pydantic import BaseModel, Field, model_validator + + +class DeconvConfig(BaseModel): + """Configuration for DeconvBin.""" + + model_config = {"frozen": True} + + coef_len: int = Field( + 100, description="Length of the coefficient kernel (e.g. calcium response)." + ) + scale: float = Field(1.0, description="Global scaling factor.") + penal: str | None = Field("l1", description="Penalty type ('l1', 'l0', etc.).") + use_base: bool = Field(False, description="Whether to include a baseline term.") + upsamp: int = Field(1, description="Upsampling factor.") + norm: Literal["l1", "l2", "huber"] = Field( + "l2", description="Norm for data fidelity ('l2', 'l1', 'huber')." + ) + mixin: bool = Field( + False, description="Whether to use mixed-integer programming (boolean spikes)." + ) + backend: Literal["osqp", "cvxpy", "cuosqp"] = Field( + "osqp", + description="Solver backend ('osqp', 'cvxpy', 'cuosqp'). 
Note: emosqp requires codegen and is not supported.", + ) + free_kernel: bool = Field( + False, + description="If True, use convolution constraint instead of AR constraint. Only supported with OSQP backends.", + ) + nthres: int = Field(1000, description="Number of thresholds for thresholding step.") + err_weighting: Optional[str] = Field( + None, description="Error weighting method ('fft', 'corr', 'adaptive', or None)." + ) + wt_trunc_thres: float = Field( + 1e-2, description="Threshold for truncating error weights." + ) + masking_radius: Optional[int] = Field( + None, description="Radius for masking around spikes." + ) + pks_polish: bool = Field(True, description="Whether to polish peaks after solving.") + th_min: float = Field(0.0, description="Minimum threshold.") + th_max: float = Field(1.0, description="Maximum threshold.") + density_thres: Optional[float] = Field( + None, description="Max spike density threshold." + ) + ncons_thres: Union[int, Literal["auto"], None] = Field( + None, description="Max consecutive spikes threshold. If 'auto', upsamp + 1." + ) + min_rel_scl: Union[float, Literal["auto"], None] = Field( + "auto", description="Minimum relative scale. Use None to disable." + ) + + max_iter_l0: int = 30 + max_iter_penal: int = 500 + max_iter_scal: int = 50 + delta_l0: float = 1e-4 + delta_penal: float = 1e-4 + atol: float = 1e-3 + rtol: float = 1e-3 + Hlim: Optional[int] = 1e5 + + @model_validator(mode="before") + @classmethod + def resolve_auto_fields(cls, data): + # Resolve "auto" values before constructing the (frozen) model to avoid + # returning a new instance from an "after" validator (pydantic warns). 
+ if not isinstance(data, dict): + return data + upsamp = data.get("upsamp", 1) + if data.get("min_rel_scl") == "auto": + data["min_rel_scl"] = 0.5 / upsamp + if data.get("ncons_thres") == "auto": + data["ncons_thres"] = upsamp + 1 + return data + + @model_validator(mode="after") + def validate_penal(self): + allowed = {None, "l0", "l1"} + if self.penal not in allowed: + raise ValueError(f"Unsupported penal type: {self.penal}") + return self + + @model_validator(mode="after") + def validate_compat(self): + if self.free_kernel and self.backend == "cvxpy": + raise ValueError("free_kernel=True is not supported with backend='cvxpy'") + return self diff --git a/src/indeca/core/deconv/deconv.py b/src/indeca/core/deconv/deconv.py index 40fa2fc..ea39772 100644 --- a/src/indeca/core/deconv/deconv.py +++ b/src/indeca/core/deconv/deconv.py @@ -1,139 +1,46 @@ +"""Main deconvolution module.""" + import itertools as itt +import math +import os import warnings -from typing import Tuple +from typing import Tuple, Any, Optional -import cvxpy as cp import numpy as np -import osqp -import pandas as pd import scipy.sparse as sps -import xarray as xr -from numba import njit -from scipy.ndimage import label +import pandas as pd from scipy.optimize import direct -from scipy.signal import ShortTimeFFT, find_peaks -from scipy.special import huber +from scipy.signal import find_peaks +from scipy.ndimage import label from indeca.utils.logging_config import get_module_logger from indeca.core.simulation import AR2tau, ar_pulse, exp_pulse, solve_p, tau2AR from indeca.utils.utils import scal_lstsq +from .config import DeconvConfig +from .solver import DeconvSolver, CVXPYSolver, OSQPSolver +from .utils import max_thres, max_consecutive, sum_downsample # Initialize logger for this module logger = get_module_logger("deconv") logger.info("Deconv module initialized") -try: - import cuosqp -except ImportError: - logger.warning("No GPU solver support") - -def construct_R(T: int, up_factor: int): - if 
up_factor > 1: - return sps.csc_matrix( - ( - np.ones(T * up_factor), - (np.repeat(np.arange(T), up_factor), np.arange(T * up_factor)), - ), - shape=(T, T * up_factor), - ) - else: - return sps.eye(T, format="csc") - - -def sum_downsample(a, factor): - return np.convolve(a, np.ones(factor), mode="full")[factor - 1 :: factor] - - -def construct_G(fac: np.ndarray, T: int, fromTau=False): - fac = np.array(fac) - assert fac.shape == (2,) - if fromTau: - fac = np.array(tau2AR(*fac)) - return sps.dia_matrix( - ( - np.tile(np.concatenate(([1], -fac)), (T, 1)).T, - -np.arange(len(fac) + 1), - ), - shape=(T, T), - ).tocsc() - - -def max_thres( - a: xr.DataArray, - nthres: int, - th_min=0.1, - th_max=0.9, - ds=None, - return_thres=False, - th_amplitude=False, - delta=1e-6, - reverse_thres=False, - nz_only: bool = False, -): - amax = a.max() - if reverse_thres: - thres = np.linspace(th_max, th_min, nthres) - else: - thres = np.linspace(th_min, th_max, nthres) - if th_amplitude: - S_ls = [np.floor_divide(a, (amax * th).clip(delta, None)) for th in thres] - else: - S_ls = [(a > (amax * th).clip(delta, None)) for th in thres] - if ds is not None: - S_ls = [sum_downsample(s, ds) for s in S_ls] - if nz_only: - Snz = [ss.sum() > 0 for ss in S_ls] - S_ls = [ss for ss, nz in zip(S_ls, Snz) if nz] - thres = [th for th, nz in zip(thres, Snz) if nz] - if return_thres: - return S_ls, thres - else: - return S_ls - - -@njit(nopython=True, nogil=True, cache=True) -def bin_convolve( - coef: np.ndarray, s: np.ndarray, nzidx_s: np.ndarray = None, s_len: int = None -): - coef_len = len(coef) - if s_len is None: - s_len = len(s) - out = np.zeros(s_len) - nzidx = np.where(s)[0] - if nzidx_s is not None: - nzidx = nzidx_s[nzidx].astype( - np.int64 - ) # astype to fix numpa issues on GPU on Windows - for i0 in nzidx: - i1 = min(i0 + coef_len, s_len) - clen = i1 - i0 - out[i0:i1] += coef[:clen] - return out - - -@njit(nopython=True, nogil=True, cache=True) -def max_consecutive(arr): - max_count = 0 
- current_count = 0 - for value in arr: - if value: - current_count += 1 - max_count = max(max_count, current_count) - else: - current_count = 0 - return max_count +class DeconvBin: + """Deconvolution main class. + This class wraps the solver backends and provides high-level methods + for spike inference including thresholding, penalty optimization, + and scale estimation. + """ -class DeconvBin: def __init__( self, - y: np.array = None, + y: np.ndarray = None, y_len: int = None, - theta: np.array = None, - tau: np.array = None, - ps: np.array = None, - coef: np.array = None, + theta: np.ndarray = None, + tau: np.ndarray = None, + ps: np.ndarray = None, + coef: np.ndarray = None, coef_len: int = 100, scale: float = 1, penal: str = "l1", @@ -142,6 +49,7 @@ def __init__( norm: str = "l2", mixin: bool = False, backend: str = "osqp", + free_kernel: bool = False, nthres: int = 1000, err_weighting: str = None, wt_trunc_thres: float = 1e-2, @@ -163,15 +71,25 @@ def __init__( dashboard=None, dashboard_uid=None, ) -> None: - # book-keeping + # Handle y input if y is not None: self.y_len = len(y) + self.y = y else: assert y_len is not None self.y_len = y_len + self.y = np.zeros(y_len) + if coef_len is not None and coef_len > self.y_len: warnings.warn("Coefficient length longer than data") coef_len = self.y_len + + # Store tau/theta/ps + self.theta = None + self.tau = None + self.ps = None + + # Compute coefficients from theta or tau if theta is not None: self.theta = np.array(theta) if tau is None: @@ -207,164 +125,146 @@ def __init__( if coef is None: assert coef_len is not None coef = np.ones(coef_len * upsamp) + + # `coef_len` (config) is the *base* kernel length; the stored `coef` is + # already upsampled to length `coef_len * upsamp`. 
self.coef_len = len(coef) - self.T = self.y_len * upsamp - l0_penal = 0 - l1_penal = 0 - self.free_kernel = False - self.penal = penal - self.use_base = use_base - self.l0_penal = l0_penal - self.w_org = np.ones(self.T) - self.w = np.ones(self.T) - self.upsamp = upsamp - self.norm = norm - self.backend = backend - self.nthres = nthres - self.th_min = th_min - self.th_max = th_max - self.max_iter_l0 = max_iter_l0 - self.max_iter_penal = max_iter_penal - self.max_iter_scal = max_iter_scal - self.delta_l0 = delta_l0 - self.delta_penal = delta_penal - self.atol = atol - self.rtol = rtol - self.Hlim = Hlim + + # Create config (note: frozen after creation) + self.cfg = DeconvConfig( + coef_len=coef_len, + scale=scale, + penal=penal, + use_base=use_base, + upsamp=upsamp, + norm=norm, + mixin=mixin, + backend=backend, + free_kernel=free_kernel, + nthres=nthres, + err_weighting=err_weighting, + wt_trunc_thres=wt_trunc_thres, + masking_radius=masking_radius, + pks_polish=pks_polish, + th_min=th_min, + th_max=th_max, + density_thres=density_thres, + # IMPORTANT: preserve old semantics: None disables consecutive constraint, + # and "auto" enables the (upsamp + 1) default. 
+ ncons_thres=ncons_thres, + min_rel_scl=min_rel_scl, + max_iter_l0=max_iter_l0, + max_iter_penal=max_iter_penal, + max_iter_scal=max_iter_scal, + delta_l0=delta_l0, + delta_penal=delta_penal, + atol=atol, + rtol=rtol, + Hlim=Hlim, + ) + + # Dashboard for visualization self.dashboard = dashboard self.dashboard_uid = dashboard_uid - self.nzidx_s = np.arange(self.T, dtype=np.int64) - self.nzidx_c = np.arange(self.T, dtype=np.int64) - self.x_cache = None - self.err_weighting = err_weighting - self.masking_r = masking_radius - self.pks_polish = pks_polish - self.err_wt = np.ones(self.y_len) - self.density_thres = density_thres - self.wt_trunc_thres = wt_trunc_thres - if ncons_thres == "auto": - self.ncons_thres = upsamp + 1 - else: - self.ncons_thres = ncons_thres - if min_rel_scl == "auto": - self.min_rel_scl = ( - 0.5 / self.upsamp - ) # 2 x upsamp number of spikes should be more than enough to explain highest peak - else: - self.min_rel_scl = min_rel_scl - if err_weighting == "fft": - self.stft = ShortTimeFFT(win=np.ones(self.coef_len), hop=1, fs=1) - self.yspec = self._get_stft_spec(y) - if y is not None: - self.huber_k = 0.5 * np.std(y) - self.err_total = self._res_err(y) - else: - self.huber_k = 0 - self.err_total = 0 - self._update_R() - # setup cvxpy - if self.backend == "cvxpy": - self.R = cp.Constant(self.R, name="R") - self.c = cp.Variable((self.T, 1), nonneg=True, name="c") - self.s = cp.Variable((self.T, 1), nonneg=True, name="s", boolean=mixin) - self.y = cp.Parameter(shape=(self.y_len, 1), name="y") - self.coef = cp.Parameter(value=coef, shape=self.coef_len, name="coef") - self.scale = cp.Parameter(value=scale, name="scale", nonneg=True) - self.l1_penal = cp.Parameter(value=l1_penal, name="l1_penal", nonneg=True) - self.l0_w = cp.Parameter( - shape=self.T, value=self.l0_penal * self.w, nonneg=True, name="w_l0" - ) # product of l0_penal * w! 
- if y is not None: - self.y.value = y.reshape((-1, 1)) - if coef is not None: - self.coef.value = coef - if use_base: - self.b = cp.Variable(nonneg=True, name="b") - else: - self.b = cp.Constant(value=0, name="b") - if norm == "l1": - self.err_term = cp.sum( - cp.abs(self.y - self.scale * self.R @ self.c - self.b) - ) - elif norm == "l2": - self.err_term = cp.sum_squares( - self.y - self.scale * self.R @ self.c - self.b - ) - elif norm == "huber": - self.err_term = cp.sum( - cp.huber(self.y - self.scale * self.R @ self.c - self.b) + + # Penalty tracking - solver tracks scale, we track penalty locally + self._l0_penal = 0.0 + self._l1_penal = 0.0 + + # Create solver + if self.cfg.backend == "cvxpy": + if self.cfg.free_kernel: + raise NotImplementedError( + "CVXPY backend does not support free_kernel mode" ) - obj = cp.Minimize( - self.err_term - + self.l0_w.T @ cp.abs(self.s) - + self.l1_penal * cp.sum(cp.abs(self.s)) + self.solver = CVXPYSolver( + self.cfg, + self.y_len, + y=self.y, + coef=coef, + theta=self.theta, + tau=self.tau, + ps=self.ps, + ) + elif self.cfg.backend in ["osqp", "cuosqp"]: + self.solver = OSQPSolver( + self.cfg, + self.y_len, + y=self.y, + coef=coef, + theta=self.theta, + tau=self.tau, + ps=self.ps, ) - if self.free_kernel: - dcv_cons = [ - self.c[:, 0] == cp.convolve(self.coef, self.s[:, 0])[: self.T] - ] - else: - self.theta = cp.Parameter( - value=self.theta, shape=self.theta.shape, name="theta" - ) - G_diag = sps.eye(self.T - 1) + sum( - [ - cp.diag(cp.promote(-self.theta[i], (self.T - i - 2,)), -i - 1) - for i in range(self.theta.shape[0]) - ] - ) # diag part of unshifted G - G = cp.bmat( - [ - [np.zeros((self.T - 1, 1)), G_diag], - [np.zeros((1, 1)), np.zeros((1, self.T - 1))], - ] - ) - dcv_cons = [self.s == G @ self.c] - edge_cons = [self.c[0, 0] == 0, self.s[-1, 0] == 0] - amp_cons = [self.s <= 1] - self.prob_free = cp.Problem(obj, dcv_cons + edge_cons) - self.prob = cp.Problem(obj, dcv_cons + edge_cons + amp_cons) - 
self._update_HG() # self.H and self.G not used for cvxpy problems - elif self.backend in ["osqp", "emosqp", "cuosqp"]: - # book-keeping - if y is None: - self.y = np.ones(self.y_len) - else: - self.y = y - if coef is None: - self.coef = np.ones(self.coef_len) - else: - self.coef = coef - self.c = np.zeros(self.T * upsamp) - self.s = np.zeros(self.T * upsamp) - self.s_bin = None - self.b = 0 - self.l1_penal = l1_penal - self.scale = scale - self.H = None - self._update_Wt() - self._setup_prob_osqp() + else: + raise ValueError(f"Unknown backend: {self.cfg.backend}") + + # State + self.T = self.solver.T + self.s = np.zeros(self.T) + self.b = 0 + self.c_bin = None + self.s_bin = None + # NOTE: do not store an "err_total" here. `_res_err` expects a residual, + # not the raw `y`, and this value was misleading and unused. + + # Update dashboard with initial kernel if self.dashboard is not None: - self.dashboard.update( - h=self.coef.value if backend == "cvxpy" else self.coef, - uid=self.dashboard_uid, - ) - tr_exp, _, _ = exp_pulse( - self.tau[0], - self.tau[1], - p_d=self.ps[0], - p_r=self.ps[1], - nsamp=self.coef_len, - ) - theta = self.theta.value if self.backend == "cvxpy" else self.theta - tr_ar, _, _ = ar_pulse(theta[0], theta[1], nsamp=self.coef_len, shifted=True) - assert (~np.isnan(coef)).all() - assert np.isclose( - tr_exp, coef, atol=self.atol - ).all(), "exp time constant inconsistent" - assert np.isclose( - tr_ar, coef, atol=self.atol - ).all(), "ar coefficients inconsistent" + self.dashboard.update(h=coef, uid=self.dashboard_uid) + + # Validate coefficients + self.solver.validate_coefficients(atol=atol) + + @property + def scale(self) -> float: + """Current scale value (delegated to solver).""" + return self.solver.scale + + @property + def H(self): + """Convolution matrix H.""" + return self.solver.H + + @property + def R(self): + """Resampling matrix R.""" + return self.solver.R + + @property + def R_org(self): + """Original (full) resampling matrix.""" + 
return self.solver.R_org + + @property + def nzidx_s(self): + """Nonzero indices for s.""" + return self.solver.nzidx_s + + @property + def nzidx_c(self): + """Nonzero indices for c.""" + return self.solver.nzidx_c + + @property + def coef(self): + """Coefficient kernel.""" + return self.solver.coef + + @property + def err_wt(self): + """Error weighting vector.""" + return self.solver.err_wt + + @property + def wgt_len(self): + """Error weighting length.""" + return self.solver.wgt_len + + @err_wt.setter + def err_wt(self, value): + """Allow direct assignment (used by tests/demo code).""" + self.solver.err_wt = np.array(value) + self.solver.Wt = sps.diags(self.solver.err_wt) def update( self, @@ -380,178 +280,111 @@ def update( clear_weighting: bool = False, scale_coef: bool = False, ) -> None: - logger.debug( - f"Updating parameters - backend: {self.backend}, tau: {tau}, scale: {scale}, scale_mul: {scale_mul}, l0_penal: {l0_penal}, l1_penal: {l1_penal}" + """Update parameters.""" + logger.debug(f"Updating parameters - backend: {self.cfg.backend}") + + theta_new = None + if tau is not None: + theta_new = np.array(tau2AR(tau[0], tau[1])) + p = solve_p(tau[0], tau[1]) + coef_new, _, _ = exp_pulse( + tau[0], + tau[1], + p_d=p, + p_r=-p, + nsamp=self.cfg.coef_len * self.cfg.upsamp, + kn_len=self.cfg.coef_len * self.cfg.upsamp, + ) + coef = coef_new + self.tau = tau + self.theta = theta_new + self.ps = np.array([p, -p]) + + if coef is not None and scale_coef: + current_coef = ( + self.solver.coef if self.solver.coef is not None else np.ones_like(coef) + ) + scale_mul = scal_lstsq(coef, current_coef).item() + + if l0_penal is not None: + self._l0_penal = l0_penal + if l1_penal is not None: + self._l1_penal = l1_penal + + # Forward updates to solver (solver handles scale directly) + self.solver.update( + y=y, + coef=coef, + scale=scale, + scale_mul=scale_mul, + l1_penal=self._l1_penal if l1_penal is not None else None, + l0_penal=self._l0_penal if l0_penal is not None 
else None, + w=w, + theta=theta_new if theta_new is not None else self.theta, + update_weighting=update_weighting, + clear_weighting=clear_weighting, + scale_coef=scale_coef, ) - if self.backend == "cvxpy": - if y is not None: - self.y.value = y - if tau is not None: - theta_new = np.array(tau2AR(tau[0], tau[1])) - p = solve_p(tau[0], tau[1]) - coef, _, _ = exp_pulse( - tau[0], - tau[1], - p_d=p, - p_r=-p, - nsamp=self.coef_len, - kn_len=self.coef_len, - ) - self.coef.value = coef - self.theta.value = theta_new - self._update_HG() - self._update_wgt_len() - if coef is not None: - if scale_coef: - scale_mul = scal_lstsq(coef, self.coef).item() - self.coef.value = coef - self._update_HG() - self._update_wgt_len() - if scale is not None: - self.scale.value = scale - if scale_mul is not None: - self.scale.value = scale_mul * self.scale.value - if l1_penal is not None: - self.l1_penal.value = l1_penal - if l0_penal is not None: - self.l0_penal = l0_penal - if w is not None: - self._update_w(w) - if l0_penal is not None or w is not None: - self.l0_w.value = self.l0_penal * self.w - elif self.backend in ["osqp", "emosqp", "cuosqp"]: - # update input params - if y is not None: - self.y = y - if tau is not None: - theta_new = np.array(tau2AR(tau[0], tau[1])) - p = solve_p(tau[0], tau[1]) - coef, _, _ = exp_pulse( - tau[0], - tau[1], - p_d=p, - p_r=-p, - nsamp=self.coef_len, - kn_len=self.coef_len, - ) - self.tau = tau - self.ps = np.array([p, -p]) - self.theta = theta_new - if coef is not None: - if scale_coef: - scale_mul = scal_lstsq(coef, self.coef).item() - self.coef = coef - if scale is not None: - self.scale = scale - if scale_mul is not None: - self.scale = scale_mul * self.scale - if l1_penal is not None: - self.l1_penal = l1_penal - if l0_penal is not None: - self.l0_penal = l0_penal - if w is not None: - self._update_w(w) - # update internal variables - updt_HG, updt_P, updt_A, updt_q0, updt_q, updt_bounds, setup_prob = [ - False - ] * 7 - if coef is not None: - 
self._update_HG() - self._update_wgt_len() - updt_HG = True - if self.err_weighting is not None and update_weighting: - self._update_Wt(clear=clear_weighting) - if self.err_weighting == "adaptive": - setup_prob = True - else: - updt_P = True - updt_q0 = True - updt_q = True - if self.norm == "huber": - if any((scale is not None, scale_mul is not None, updt_HG)): - self._update_A() - updt_A = True - if any( - (w is not None, l0_penal is not None, l1_penal is not None, updt_HG) - ): - self._update_q() - updt_q = True - if y is not None: - self._update_bounds() - updt_bounds = True + + if y is not None: + self.y = y + + def _pad_s(self, s: np.ndarray = None) -> np.ndarray: + """Pad sparse s to full length.""" + return self.solver._pad_s(s) + + def _pad_c(self, c: np.ndarray = None) -> np.ndarray: + """Pad sparse c to full length.""" + return self.solver._pad_c(c) + + def _reset_cache(self) -> None: + """Reset solver cache.""" + self.solver.reset_cache() + + def _reset_mask(self) -> None: + """Reset solver mask to full range.""" + self.solver.reset_mask() + + def _update_mask(self, use_wt: bool = False, amp_constraint: bool = True) -> None: + """Update mask based on current solution.""" + # CVXPY doesn't support masking + if self.cfg.backend == "cvxpy": + return + + if self.cfg.backend in ["osqp", "cuosqp"]: + if use_wt: + nzidx_s = np.where(self.R.T @ self.err_wt)[0] else: - if any((updt_HG, updt_A)): - A_before = self.A.copy() - self._update_A() - assert self.A.shape == A_before.shape - assert (self.A.nonzero()[0] == A_before.nonzero()[0]).all() - assert (self.A.nonzero()[1] == A_before.nonzero()[1]).all() - updt_A = True - if any((scale is not None, scale_mul is not None, updt_HG, updt_P)): - P_before = self.P.copy() - self._update_P() - assert self.P.shape == P_before.shape - assert (self.P.nonzero()[0] == P_before.nonzero()[0]).all() - assert (self.P.nonzero()[1] == P_before.nonzero()[1]).all() - updt_P = True - if any( - ( - scale is not None, - scale_mul is not 
None, - y is not None, - updt_HG, - updt_q0, - ) - ): - q0_before = self.q0.copy() - self._update_q0() - assert self.q0.shape == q0_before.shape - updt_q0 = True - if any( - ( - w is not None, - l0_penal is not None, - l1_penal is not None, - updt_q0, - updt_q, - ) - ): - q_before = self.q.copy() - self._update_q() - assert self.q.shape == q_before.shape - updt_q = True - # update prob - logger.debug(f"Updating optimization problem with {self.backend}") - if self.backend == "emosqp": - if updt_P: - self.prob_free.update_P(self.P.data, None, 0) - self.prob.update_P(self.P.data, None, 0) - if updt_q: - self.prob_free.update_lin_cost(self.q) - self.prob.update_lin_cost(self.q) - elif self.backend in ["osqp", "cuosqp"] and any( - (updt_P, updt_q, updt_A, updt_bounds, setup_prob) - ): - if setup_prob: - self._setup_prob_osqp() + if self.cfg.masking_radius is not None: + mask = np.zeros(self.T) + for nzidx in np.where(self._pad_s(self.s_bin) > 0)[0]: + start = max(nzidx - self.cfg.masking_radius, 0) + end = min(nzidx + self.cfg.masking_radius, self.T) + mask[start:end] = 1 + nzidx_s = np.where(mask)[0] else: - self.prob_free.update( - Px=self.P.copy().data if updt_P else None, - q=self.q.copy() if updt_q else None, - Ax=self.A.copy().data if updt_A else None, - l=self.lb.copy() if updt_bounds else None, - u=self.ub_inf.copy() if updt_bounds else None, - ) - self.prob.update( - Px=self.P.copy().data if updt_P else None, - q=self.q.copy() if updt_q else None, - Ax=self.A.copy().data if updt_A else None, - l=self.lb.copy() if updt_bounds else None, - u=self.ub.copy() if updt_bounds else None, - ) - logger.debug("Optimization problem updated") + self._reset_mask() + opt_s, _ = self.solve(amp_constraint) + nzidx_s = np.where(opt_s > self.cfg.delta_penal)[0] + + if len(nzidx_s) == 0: + logger.warning("Empty mask, resetting") + self._reset_mask() + return + + self.solver.set_mask(nzidx_s) + + # Verify mask is valid + if not self.cfg.free_kernel and len(self.nzidx_c) < self.T: + 
res = self.solver.prob.solve() + if res.info.status == "primal infeasible": + logger.warning("Mask caused primal infeasibility, resetting") + self._reset_mask() + else: + raise NotImplementedError("Masking not supported for cvxpy backend") def _cut_pks_labs(self, s, labs, pks): + """Cut peak labels at valleys between peaks.""" pk_labs = np.full_like(labs, -1) lb = 0 for ilab in range(labs.max() + 1): @@ -574,6 +407,8 @@ def _cut_pks_labs(self, s, labs, pks): def _merge_sparse_regs( self, s, regs, err_rtol=0, max_len=9, constraint_sum: bool = True ): + """Merge sparse regions to minimize error.""" + max_combos = 10_000 s_ret = s.copy() for r in range(regs.max() + 1): ridx = np.where(regs == r)[0] @@ -593,6 +428,9 @@ def _merge_sparse_regs( else: ns_vals = list(range(ns_min, rlen + 1)) for ns in ns_vals: + # Defensive guard: avoid combinatorial explosion. + if math.comb(rlen, ns) > max_combos: + continue for idxs in itt.combinations(ridx, ns): idxs = np.array(idxs) s_test = s_new.copy() @@ -600,28 +438,15 @@ def _merge_sparse_regs( err_after = self._compute_err(s=s_test[self.nzidx_s]) err_ls.append(err_after) idx_ls.append(idxs) - err_min_idx = np.argmin(err_ls) - err_min = err_ls[err_min_idx] - if err_min - err_before < err_rtol * err_before: - idx_min = idx_ls[err_min_idx] - s_new[idx_min] = rsum / len(idx_min) - s_ret = s_new - return s_ret - - def _pad_s(self, s=None): - if s is None: - s = self.s - s_ret = np.zeros(self.T) - s_ret[self.nzidx_s] = s + if len(err_ls) > 0: + err_min_idx = np.argmin(err_ls) + err_min = err_ls[err_min_idx] + if err_min - err_before < err_rtol * err_before: + idx_min = idx_ls[err_min_idx] + s_new[idx_min] = rsum / len(idx_min) + s_ret = s_new return s_ret - def _pad_c(self, c=None): - if c is None: - c = self.s - c_ret = np.zeros(self.T) - c_ret[self.nzidx_c] = c - return c_ret - def solve( self, amp_constraint: bool = True, @@ -630,24 +455,31 @@ def solve( pks_delta: float = 1e-5, pks_err_rtol: float = 10, pks_cut: bool = False, - ) 
-> np.ndarray: - if self.l0_penal == 0: - opt_s, opt_b = self._solve( - amp_constraint=amp_constraint, update_cache=update_cache - ) + ) -> Tuple[np.ndarray, float]: + """Solve main routine (l0 heuristic wrapper).""" + if self._l0_penal == 0: + opt_s, opt_b, _ = self.solver.solve(amp_constraint=amp_constraint) else: + # L0 heuristic via reweighted L1 metric_df = None - for i in range(self.max_iter_l0): - cur_s, cur_obj = self._solve(amp_constraint, return_obj=True) + for i in range(self.cfg.max_iter_l0): + cur_s, cur_b, _ = self.solver.solve(amp_constraint=amp_constraint) + # Compute objective explicitly since solver returns 0 + cur_obj = self._compute_err(s=cur_s, b=cur_b) + if metric_df is None: obj_best = np.inf obj_last = np.inf else: - obj_best = metric_df["obj"][1:].min() + obj_best = ( + metric_df["obj"][1:].min() if len(metric_df) > 1 else np.inf + ) obj_last = np.array(metric_df["obj"])[-1] - opt_s = np.where(cur_s > self.delta_l0, cur_s, 0) + + opt_s = np.where(cur_s > self.cfg.delta_l0, cur_s, 0) obj_gap = np.abs(cur_obj - obj_best) obj_delta = np.abs(cur_obj - obj_last) + cur_met = pd.DataFrame( [ { @@ -660,27 +492,30 @@ def solve( ] ) metric_df = pd.concat([metric_df, cur_met], ignore_index=True) - if any((obj_gap < self.rtol * np.obj_best, obj_delta < self.atol)): + + if any([obj_gap < self.cfg.rtol * obj_best, obj_delta < self.cfg.atol]): break else: - self.update( - w=np.clip( - np.ones(self.T) / (self.delta_l0 * np.ones(self.T) + opt_s), - 0, - 1e5, - ) - ) # clip to avoid numerical issues + w_new = np.clip( + np.ones(self.T) / (self.cfg.delta_l0 * np.ones(self.T) + opt_s), + 0, + 1e5, + ) + self.update(w=w_new) else: warnings.warn( - "l0 heuristic did not converge in {} iterations".format( - self.max_iter_l0 - ) + f"l0 heuristic did not converge in {self.cfg.max_iter_l0} iterations" ) + + opt_s, opt_b, _ = self.solver.solve(amp_constraint=amp_constraint) + self.b = opt_b + + # Peak polishing if pks_polish is None: pks_polish = amp_constraint - if 
pks_polish and self.backend != "cvxpy": - s_pad = self._pad_s(s=opt_s) + if pks_polish and self.cfg.backend != "cvxpy": + s_pad = self._pad_s(opt_s) if len(opt_s) == len(self.nzidx_s) else opt_s s_ft = np.where(s_pad > pks_delta, s_pad, 0) labs, _ = label(s_ft) labs = labs - 1 @@ -688,37 +523,125 @@ def solve( pks_idx, _ = find_peaks(s_ft) labs = self._cut_pks_labs(s=s_ft, labs=labs, pks=pks_idx) opt_s = self._merge_sparse_regs(s=s_ft, regs=labs, err_rtol=pks_err_rtol) - opt_s = opt_s[self.nzidx_s] + if len(opt_s) == self.T: + opt_s = opt_s[self.nzidx_s] + self.s = np.abs(opt_s) return self.s, self.b + def _compute_c(self, s: np.ndarray = None) -> np.ndarray: + """Compute c from s via convolution.""" + if s is not None: + return self.solver.convolve(s) + else: + return self.solver.convolve(self.s) + + def _res_err(self, r: np.ndarray) -> float: + """Compute residual error.""" + if self.err_wt is not None: + r = self.err_wt * r + if self.cfg.norm == "l1": + return np.sum(np.abs(r)) + elif self.cfg.norm == "l2": + return np.sum(r**2) + elif self.cfg.norm == "huber": + # True Huber loss: + # 0.5*r^2 if |r| <= k + # k*(|r| - 0.5*k) otherwise + k = float(self.solver.huber_k) + ar = np.abs(r) + quad = 0.5 * (r**2) + lin = k * (ar - 0.5 * k) + return float(np.sum(np.where(ar <= k, quad, lin))) + + def _compute_err( + self, + y_fit: np.ndarray = None, + b: float = None, + c: np.ndarray = None, + s: np.ndarray = None, + res: np.ndarray = None, + obj_crit: str = None, + ) -> float: + """Compute error/objective value.""" + y = np.array(self.y) + if res is not None: + y = y - res + if b is None: + b = self.b + y = y - b + + if y_fit is None: + if c is None: + c = self._compute_c(s) + R = self.R + if sps.issparse(c): + y_fit = np.array((R @ c * self.scale).todense()).squeeze() + else: + y_fit = np.array(R @ c * self.scale).squeeze() + + r = y - y_fit + err = self._res_err(r) + + if obj_crit in [None, "spk_diff"]: + return float(err) + else: + nspk = (s > 0).sum() if s is not 
None else (self.s > 0).sum() + if obj_crit == "mean_spk": + err_total = self._res_err(y - y.mean()) + return float((err - err_total) / max(nspk, 1)) + elif obj_crit in ["aic", "bic"]: + T = len(r) + mu = r.mean() + sigma = max(((r - mu) ** 2).sum() / T, 1e-10) + logL = -0.5 * ( + T * np.log(2 * np.pi * sigma) + 1 / sigma * ((r - mu) ** 2).sum() + ) + if obj_crit == "aic": + return float(2 * (nspk - logL)) + elif obj_crit == "bic": + return float(nspk * np.log(T) - 2 * logL) + return float(err) + def _max_thres(self, s, nz_only=True): + """Apply max thresholding to solution.""" S_ls, thres = max_thres( - s, - nthres=self.nthres, - th_min=self.th_min, - th_max=self.th_max, + np.array(s), + nthres=self.cfg.nthres, + th_min=self.cfg.th_min, + th_max=self.cfg.th_max, reverse_thres=True, return_thres=True, nz_only=nz_only, ) - if self.density_thres is not None: + # Ensure we return numpy arrays (not xarray.DataArray) + S_ls = [np.array(ss) for ss in S_ls] + + # Apply density threshold + if self.cfg.density_thres is not None: Sden = [ss.sum() / self.T for ss in S_ls] - S_ls = [ss for ss, den in zip(S_ls, Sden) if den < self.density_thres] - thres = [th for th, den in zip(thres, Sden) if den < self.density_thres] - if self.ncons_thres is not None: + S_ls = [ss for ss, den in zip(S_ls, Sden) if den < self.cfg.density_thres] + thres = [th for th, den in zip(thres, Sden) if den < self.cfg.density_thres] + + # Apply consecutive threshold + if self.cfg.ncons_thres is not None: S_pad = [self._pad_s(ss) for ss in S_ls] - Sncons = [max_consecutive(ss) for ss in S_pad] - if min(Sncons) < self.ncons_thres: + Sncons = [max_consecutive(np.array(ss)) for ss in S_pad] + if len(Sncons) > 0 and min(Sncons) < self.cfg.ncons_thres: S_ls = [ - ss for ss, ncons in zip(S_ls, Sncons) if ncons <= self.ncons_thres + ss + for ss, ncons in zip(S_ls, Sncons) + if ncons <= self.cfg.ncons_thres ] thres = [ - th for th, ncons in zip(thres, Sncons) if ncons <= self.ncons_thres + th + for th, ncons in 
zip(thres, Sncons) + if ncons <= self.cfg.ncons_thres ] - else: + elif len(S_ls) > 0: S_ls = [S_ls[0]] thres = [thres[0]] + return S_ls, thres def solve_thres( @@ -729,17 +652,21 @@ def solve_thres( return_intm: bool = False, pks_polish: bool = None, obj_crit: str = None, - ) -> Tuple[np.ndarray]: - if self.backend == "cvxpy": - y = np.array(self.y.value.squeeze()) - elif self.backend in ["osqp", "emosqp", "cuosqp"]: - y = np.array(self.y) + ) -> Tuple[np.ndarray, ...]: + """Solve with thresholding.""" + y = np.array(self.y) opt_s, opt_b = self.solve(amp_constraint=amp_constraint, pks_polish=pks_polish) - R = self.R.value if self.backend == "cvxpy" else self.R + R = self.R + if ignore_res: - res = y - opt_b - self.scale * R @ self._compute_c(opt_s) + c = self._compute_c(opt_s) + if sps.issparse(c): + res = y - opt_b - self.scale * np.array((R @ c).todense()).squeeze() + else: + res = y - opt_b - self.scale * (R @ c).squeeze() else: res = np.zeros_like(y) + svals, thres = self._max_thres(opt_s) if not len(svals) > 0: if return_intm: @@ -751,31 +678,62 @@ def solve_thres( 0, np.inf, ) + cvals = [self._compute_c(s) for s in svals] - yfvals = [np.array((R @ c).todense()).squeeze() for c in cvals] + + def to_arr(m): + return ( + np.array(m.todense()).squeeze() + if sps.issparse(m) + else np.array(m).squeeze() + ) + + yfvals = [to_arr(R @ c) for c in cvals] + if scaling: scal_fit = [scal_lstsq(yf, y - res, fit_intercept=True) for yf in yfvals] scals = [sf[0] for sf in scal_fit] bs = [sf[1] for sf in scal_fit] - if self.min_rel_scl is not None: - scl_thres = np.max(y) * self.min_rel_scl - valid_idx = np.where(scals > scl_thres)[0] + + if self.cfg.min_rel_scl is not None: + scl_thres = np.max(y) * self.cfg.min_rel_scl + valid_idx = np.where(np.array(scals) > scl_thres)[0] if len(valid_idx) > 0: + # Optional legacy compatibility: reproduce old deconv behavior where + # scale filtering shrinks `scals/bs` but does NOT shrink `svals/yfvals`. 
+ # This leads to zip() truncation later and can change which candidate is selected. + legacy_bug = ( + os.environ.get("INDECA_LEGACY_SCALE_FILTER_BUG", "0") == "1" + ) scals = [scals[i] for i in valid_idx] bs = [bs[i] for i in valid_idx] + if not legacy_bug: + svals = [svals[i] for i in valid_idx] + cvals = [cvals[i] for i in valid_idx] + yfvals = [yfvals[i] for i in valid_idx] else: max_idx = np.argmax(scals) + legacy_bug = ( + os.environ.get("INDECA_LEGACY_SCALE_FILTER_BUG", "0") == "1" + ) scals = [scals[max_idx]] bs = [bs[max_idx]] + if not legacy_bug: + svals = [svals[max_idx]] + cvals = [cvals[max_idx]] + yfvals = [yfvals[max_idx]] else: scals = [self.scale] * len(yfvals) bs = [(y - res - scl * yf).mean() for scl, yf in zip(scals, yfvals)] + objs = [ self._compute_err(s=ss, y_fit=scl * yf, res=res, b=bb, obj_crit=obj_crit) for ss, scl, yf, bb in zip(svals, scals, yfvals, bs) ] + scals = np.array(scals).clip(0, None) objs = np.where(scals > 0, objs, np.inf) + if obj_crit == "spk_diff": err_null = self._compute_err( s=np.zeros_like(opt_s), res=res, b=opt_b, obj_crit=obj_crit @@ -784,17 +742,24 @@ def solve_thres( nspk = np.array([0] + [(ss > 0).sum() for ss in svals]) objs_diff = np.diff(objs_pad) nspk_diff = np.diff(nspk) + nspk_diff = np.where(nspk_diff == 0, 1, nspk_diff) # Avoid division by zero merr_diff = objs_diff / nspk_diff - avg_err = (objs_pad.min() - err_null) / nspk.max() - opt_idx = np.max(np.where(merr_diff < avg_err)) + avg_err = (objs_pad.min() - err_null) / max(nspk.max(), 1) + opt_idx = ( + int(np.max(np.where(merr_diff < avg_err)[0])) + if np.any(merr_diff < avg_err) + else 0 + ) objs = merr_diff else: - opt_idx = np.argmin(objs) + opt_idx = int(np.argmin(objs)) + s_bin = svals[opt_idx] self.s_bin = s_bin - assert len(s_bin) == len(self.nzidx_s) - self.c_bin = np.array(cvals[opt_idx].todense()).squeeze() + self.c_bin = to_arr(cvals[opt_idx]) self.b = bs[opt_idx] + self.solver.s_bin = s_bin # Update solver's s_bin for adaptive weighting + if 
return_intm: return ( self.s_bin, @@ -808,86 +773,89 @@ def solve_thres( def solve_penal( self, masking=True, scaling=True, return_intm=False, pks_polish=None - ) -> Tuple[np.ndarray]: - if self.penal is None: + ) -> Tuple[np.ndarray, ...]: + """Solve with penalty optimization via DIRECT.""" + if self.cfg.penal is None: opt_s, opt_c, opt_scl, opt_obj = self.solve_thres( scaling=scaling, return_intm=return_intm, pks_polish=pks_polish ) opt_penal = 0 - elif self.penal in ["l0", "l1"]: - pn = "{}_penal".format(self.penal) - self.update(**{pn: 0}) - if masking: - self._reset_cache() - self._update_mask() - s_nopn, _, _, err_nopn, intm = self.solve_thres( - scaling=scaling, return_intm=True, pks_polish=pks_polish - ) - s_min = intm[0] - ymean = self.y.mean() - err_full = self._compute_err(s=np.zeros(len(self.nzidx_s)), b=ymean) - err_min = self._compute_err(s=s_min) - ub, ub_last = err_full, err_full - for _ in range(int(np.ceil(np.log2(ub)))): - self.update(**{pn: ub}) - s, b = self.solve(pks_polish=pks_polish) - cur_err = self._compute_err(s=s, b=b) - # DIRECT finds weird solutions with high penalty and baseline, - # so we want to eliminate those possibilities - if np.abs(cur_err - err_min) < 0.5 * np.abs(err_full - err_min): - ub = ub_last - break - else: - ub_last = ub - ub = ub / 2 - - def opt_fn(x): - self.update(**{pn: x.item()}) - _, _, _, obj = self.solve_thres(scaling=False, pks_polish=pks_polish) - if self.dashboard is not None: - self.dashboard.update( - uid=self.dashboard_uid, - penal_err={"penal": x.item(), "scale": self.scale, "err": obj}, - ) - if obj < err_full: - return obj - else: - return np.inf + if return_intm: + return opt_s, opt_c, opt_scl, opt_obj, opt_penal, None + return opt_s, opt_c, opt_scl, opt_obj, opt_penal + + pn = f"{self.cfg.penal}_penal" + self.update(**{pn: 0}) + + if masking: + self._reset_cache() + self._update_mask() + + s_nopn, _, _, err_nopn, intm = self.solve_thres( + scaling=scaling, return_intm=True, pks_polish=pks_polish + ) 
+ s_min = intm[0] + ymean = self.y.mean() + err_full = self._compute_err(s=np.zeros(len(self.nzidx_s)), b=ymean) + err_min = self._compute_err(s=s_min) + + # Find upper bound for penalty + ub, ub_last = err_full, err_full + for _ in range(int(np.ceil(np.log2(ub + 1)))): + self.update(**{pn: ub}) + s, b = self.solve(pks_polish=pks_polish) + cur_err = self._compute_err(s=s, b=b) + if np.abs(cur_err - err_min) < 0.5 * np.abs(err_full - err_min): + ub = ub_last + break + else: + ub_last = ub + ub = ub / 2 + + def opt_fn(x): + self.update(**{pn: float(x)}) + _, _, _, obj = self.solve_thres(scaling=False, pks_polish=pks_polish) + if self.dashboard is not None: + self.dashboard.update( + uid=self.dashboard_uid, + penal_err={"penal": float(x), "scale": self.scale, "err": obj}, + ) + return obj if obj < err_full else np.inf + try: res = direct( opt_fn, - bounds=[(0, ub)], - maxfun=self.max_iter_penal, + bounds=[(0, max(ub, 1e-6))], + maxfun=self.cfg.max_iter_penal, locally_biased=False, vol_tol=1e-2, ) direct_pn = res.x if not res.success: logger.warning( - "could not find optimal penalty within {} iterations".format( - res.nfev - ) + f"Could not find optimal penalty within {res.nfev} iterations" ) opt_penal = 0 elif err_nopn <= opt_fn(direct_pn): - # DIRECT seem to mistakenly report high penalty when 0 penalty attains better error opt_penal = 0 else: - opt_penal = direct_pn.item() - self.update(**{pn: opt_penal}) - if return_intm: - opt_s, opt_c, opt_scl, opt_obj, intm = self.solve_thres( - scaling=scaling, return_intm=return_intm, pks_polish=pks_polish - ) - else: - opt_s, opt_c, opt_scl, opt_obj = self.solve_thres( - scaling=scaling, return_intm=return_intm, pks_polish=pks_polish - ) - if opt_scl == 0: - logger.warning("could not find non-zero solution") + opt_penal = float(direct_pn) + except Exception as e: + logger.warning(f"DIRECT optimization failed: {e}") + opt_penal = 0 + + self.update(**{pn: opt_penal}) if return_intm: + opt_s, opt_c, opt_scl, opt_obj, intm = 
self.solve_thres( + scaling=scaling, return_intm=True, pks_polish=pks_polish + ) return opt_s, opt_c, opt_scl, opt_obj, opt_penal, intm else: + opt_s, opt_c, opt_scl, opt_obj = self.solve_thres( + scaling=scaling, pks_polish=pks_polish + ) + if opt_scl == 0: + logger.warning("Could not find non-zero solution") return opt_s, opt_c, opt_scl, opt_obj, opt_penal def solve_scale( @@ -898,32 +866,35 @@ def solve_scale( obj_crit: str = None, early_stop: bool = True, masking: bool = True, - ) -> Tuple[np.ndarray]: - if self.penal in ["l0", "l1"]: - pn = "{}_penal".format(self.penal) + ) -> Tuple[np.ndarray, ...]: + """Solve with iterative scale estimation.""" + if self.cfg.penal in ["l0", "l1"]: + pn = f"{self.cfg.penal}_penal" self.update(**{pn: 0}) + self._reset_cache() self._reset_mask() + if reset_scale: self.update(scale=1) s_free, _ = self.solve(amp_constraint=False) self.update(scale=np.ptp(s_free)) - else: - s_free = np.zeros(len(self.nzidx_s)) + metric_df = None - for i in range(self.max_iter_scal): + for i in range(self.cfg.max_iter_scal): if concur_penal: cur_s, cur_c, cur_scl, cur_obj_raw, cur_penal = self.solve_penal( scaling=i > 0, - pks_polish=self.pks_polish and (i > 1 or not reset_scale), + pks_polish=self.cfg.pks_polish and (i > 1 or not reset_scale), ) else: cur_penal = 0 cur_s, cur_c, cur_scl, cur_obj_raw = self.solve_thres( scaling=i > 0, - pks_polish=self.pks_polish and (i > 1 or not reset_scale), + pks_polish=self.cfg.pks_polish and (i > 1 or not reset_scale), obj_crit=obj_crit, ) + if self.dashboard is not None: pad_s = np.zeros(self.T) pad_s[self.nzidx_s] = cur_s @@ -933,6 +904,7 @@ def solve_scale( s=self.R_org @ pad_s, scale=cur_scl, ) + if metric_df is None: prev_scals = np.array([np.inf]) opt_obj = np.inf @@ -941,14 +913,16 @@ def solve_scale( last_scal = np.inf else: opt_idx = metric_df["obj"].idxmin() - opt_obj = metric_df.loc[opt_idx, "obj"].item() - opt_scal = metric_df.loc[opt_idx, "scale"].item() + opt_obj = metric_df.loc[opt_idx, "obj"] 
+ opt_scal = metric_df.loc[opt_idx, "scale"] prev_scals = np.array(metric_df["scale"]) last_scal = prev_scals[-1] last_obj = np.array(metric_df["obj"])[-1] + y_wt = np.array(self.y * self.err_wt) err_tt = self._res_err(y_wt - y_wt.mean()) - cur_obj = (cur_obj_raw - err_tt) / err_tt + cur_obj = (cur_obj_raw - err_tt) / max(err_tt, 1e-10) + cur_met = pd.DataFrame( [ { @@ -963,47 +937,59 @@ def solve_scale( ] ) metric_df = pd.concat([metric_df, cur_met], ignore_index=True) - if self.err_weighting == "adaptive" and i <= 1: + + if self.cfg.err_weighting == "adaptive" and i <= 1: self.update(update_weighting=True) if masking and i >= 1: self._update_mask() + if any( - ( - np.abs(cur_scl - opt_scal) < self.rtol * opt_scal, - np.abs(cur_obj - opt_obj) < self.rtol * opt_obj, - np.abs(cur_scl - last_scal) < self.atol, - np.abs(cur_obj - last_obj) < self.atol * 1e-3, + [ + np.abs(cur_scl - opt_scal) < self.cfg.rtol * opt_scal, + np.abs(cur_obj - opt_obj) < self.cfg.rtol * opt_obj, + np.abs(cur_scl - last_scal) < self.cfg.atol, + np.abs(cur_obj - last_obj) < self.cfg.atol * 1e-3, early_stop and cur_obj > last_obj, - ) + ] ): break elif cur_scl == 0: - warnings.warn("exit with zero solution") + warnings.warn("Exit with zero solution") break - elif np.abs(cur_scl - prev_scals).min() < self.atol: + elif np.abs(cur_scl - prev_scals).min() < self.cfg.atol: self.update(scale=(cur_scl + last_scal) / 2) else: self.update(scale=cur_scl) else: - warnings.warn("max scale iterations reached") + warnings.warn("Max scale iterations reached") + + # Final solve with optimal scale opt_idx = metric_df["obj"].idxmin() self.update(update_weighting=True, clear_weighting=True) self._reset_cache() self._reset_mask() - self.update(scale=metric_df.loc[opt_idx, "scale"]) + self.update(scale=float(metric_df.loc[opt_idx, "scale"])) + cur_s, cur_c, cur_scl, cur_obj, cur_penal = self.solve_penal( - scaling=False, masking=False, pks_polish=self.pks_polish + scaling=False, masking=False, 
pks_polish=self.cfg.pks_polish ) - opt_s, opt_c = np.zeros(self.T), np.zeros(self.T) + + opt_s = np.zeros(self.T) + opt_c = np.zeros(self.T) opt_s[self.nzidx_s] = cur_s - opt_c[self.nzidx_c] = cur_c + opt_c[self.nzidx_c] = ( + cur_c if not sps.issparse(cur_c) else cur_c.toarray().squeeze() + ) nnz = int(opt_s.sum()) + self.update(update_weighting=True) y_wt = np.array(self.y * self.err_wt) err_tt = self._res_err(y_wt - y_wt.mean()) err_cur = self._compute_err(s=opt_s) - err_rel = (err_cur - err_tt) / err_tt + err_rel = (err_cur - err_tt) / max(err_tt, 1e-10) + self.update(update_weighting=True, clear_weighting=True) + if self.dashboard is not None: self.dashboard.update( uid=self.dashboard_uid, @@ -1011,556 +997,11 @@ def solve_scale( s=self.R_org @ opt_s, scale=cur_scl, ) + self._reset_cache() self._reset_mask() + if return_met: return opt_s, opt_c, cur_scl, cur_obj, err_rel, nnz, cur_penal, metric_df else: return opt_s, opt_c, cur_scl, cur_obj, err_rel, nnz, cur_penal - - def _setup_prob_osqp(self) -> None: - logger.debug("Setting up OSQP problem") - self._update_HG() - self._update_wgt_len() - self._update_P() - self._update_q0() - self._update_q() - self._update_A() - self._update_bounds() - if self.backend == "emosqp": - m = osqp.OSQP() - m.setup( - P=self.P, - q=self.q, - A=self.A, - l=self.lb, - u=self.ub_inf, - check_termination=25, - eps_abs=self.atol * 1e-4, - eps_rel=1e-8, - ) - m.codegen( - "osqp-codegen-prob_free", - parameters="matrices", - python_ext_name="emosqp_free", - force_rewrite=True, - ) - m.update(u=self.ub) - m.codegen( - "osqp-codegen-prob", - parameters="matrices", - python_ext_name="emosqp", - force_rewrite=True, - ) - import emosqp - import emosqp_free - - self.prob_free = emosqp_free - self.prob = emosqp - elif self.backend in ["osqp", "cuosqp"]: - if self.backend == "osqp": - self.prob_free = osqp.OSQP() - self.prob = osqp.OSQP() - elif self.backend == "cuosqp": - self.prob_free = cuosqp.OSQP() - self.prob = cuosqp.OSQP() - P_copy = 
self.P.copy() - q_copy = self.q.copy() - A_copy = self.A.copy() - lb_copy = self.lb.copy() - ub_inf_copy = self.ub_inf.copy() - self.prob_free.setup( - P=P_copy, - q=q_copy, - A=A_copy, - l=lb_copy, - u=ub_inf_copy, - verbose=False, - polish=True, - warm_start=False, - # adaptive_rho=False, - eps_abs=1e-6, - eps_rel=1e-6, - eps_prim_inf=1e-7, - eps_dual_inf=1e-7, - # max_iter=int(1e5) if self.backend == "osqp" else None, - # eps_prim_inf=1e-8, - ) - P_copy = self.P.copy() - q_copy = self.q.copy() - A_copy = self.A.copy() - lb_copy = self.lb.copy() - ub_copy = self.ub.copy() - self.prob.setup( - P=P_copy, - q=q_copy, - A=A_copy, - l=lb_copy, - u=ub_copy, - verbose=False, - polish=True, - warm_start=False, - # adaptive_rho=False, - eps_abs=1e-6, - eps_rel=1e-6, - eps_prim_inf=1e-7, - eps_dual_inf=1e-7, - # max_iter=int(1e5) if self.backend == "osqp" else None, - # eps_prim_inf=1e-8, - ) - logger.debug(f"{self.backend} setup completed successfully") - - def _solve( - self, - amp_constraint: bool = True, - return_obj: bool = False, - update_cache: bool = False, - ) -> np.ndarray: - if amp_constraint: - prob = self.prob - else: - prob = self.prob_free - # if self.backend in ["osqp", "emosqp", "cuosqp"] and self.x_cache is not None: - # prob.warm_start(x=self.x_cache) - res = prob.solve() - if self.backend == "cvxpy": - opt_s = self.s.value.squeeze() - opt_b = 0 - elif self.backend in ["osqp", "emosqp", "cuosqp"]: - x = res[0] if self.backend == "emosqp" else res.x - if res.info.status not in ["solved", "solved inaccurate"]: - warnings.warn("Problem not solved. status: {}".format(res.info.status)) - # osqp mistakenly report primal infeasibility when using masks - # with high l1 penalty. 
manually set solution to zero in such cases - if res.info.status in [ - "primal infeasible", - "primal infeasible inaccurate", - ]: - x = np.zeros_like(x, dtype=float) - else: - x = x.astype(float) - # if update_cache: - # self.x_cache = x - # prob.warm_start(x=x) - if self.norm == "huber": - xlen = len(self.nzidx_s) if self.free_kernel else len(self.nzidx_c) - sol = x[:xlen] - else: - sol = x - opt_b = sol[0] - if self.free_kernel: - opt_s = sol[1:] - else: - opt_s = self.G @ sol[1:] - if return_obj: - if self.backend == "cvxpy": - opt_obj = res - elif self.backend in ["osqp", "emosqp", "cuosqp"]: - opt_obj = self._compute_err() - return opt_s, opt_b, opt_obj - else: - return opt_s, opt_b - - def _compute_c(self, s: np.ndarray = None) -> np.ndarray: - if s is not None: - return self._convolve_s(s) - else: - if self.backend == "cvxpy": - return self.c.value.squeeze() - elif self.backend in ["osqp", "emosqp", "cuosqp"]: - return self._convolve_s(self.s) - - def _convolve_s(self, s: np.ndarray) -> sps.csc_array: - if self.H is not None: - return self.H @ sps.csc_matrix(s.reshape(-1, 1)) - else: - if s.dtype == np.bool_: - return sps.csc_matrix( - bin_convolve( - self.coef, s=s, nzidx_s=self.nzidx_s, s_len=self.T - ).reshape(-1, 1) - ) - else: - return sps.csc_matrix( - np.convolve(self.coef, self._pad_s(s))[self.nzidx_c].reshape(-1, 1) - ) - - def _compute_err( - self, - y_fit: np.ndarray = None, - b: np.ndarray = None, - c: np.ndarray = None, - s: np.ndarray = None, - res: np.ndarray = None, - obj_crit: str = None, - ) -> float: - if self.backend == "cvxpy": - # TODO: add support - raise NotImplementedError - elif self.backend in ["osqp", "emosqp", "cuosqp"]: - y = np.array(self.y) - if res is not None: - y = y - res - if b is None: - b = self.b - y = y - b - if y_fit is None: - if c is None: - c = self._compute_c(s) - y_fit = np.array((self.R @ c * self.scale).todense()).squeeze() - r = y - y_fit - err = self._res_err(r) - if obj_crit in [None, "spk_diff"]: - 
return np.array(err).item() - else: - nspk = (s > 0).sum() - if obj_crit == "mean_spk": - err_total = self._res_err(y - y.mean()) - return np.array((err - err_total) / nspk).item() - elif obj_crit in ["aic", "bic"]: - noise_model = "normal" - T = len(r) - if noise_model == "normal": - mu = r.mean() - sigma = ((r - mu) ** 2).sum() / T - logL = -0.5 * ( - T * np.log(2 * np.pi * sigma) - + 1 / sigma * ((r - mu) ** 2).sum() - ) - elif noise_model == "lognormal": - ymin = y.min() - logy = np.log(y - ymin + 1) - logy_hat = np.log(y_fit - ymin + 1) - logr = logy - logy_hat - mu = np.mean(logr) - sigma = ((logr - mu) ** 2).sum() / T - logL = np.sum( - -logy - - (logr - mu) ** 2 / (2 * sigma) - - 0.5 * np.log(2 * np.pi * sigma) - ) - if obj_crit == "aic": - return np.array(2 * (nspk - logL)).item() - elif obj_crit == "bic": - return np.array(nspk * np.log(T) - 2 * logL).item() - else: - raise ValueError("invalid objective criterion: {}".format(obj_crit)) - - def _res_err(self, r: np.ndarray): - if self.err_wt is not None: - r = self.err_wt * r - if self.norm == "l1": - return np.sum(np.abs(r)) - elif self.norm == "l2": - return np.sum((r) ** 2) - elif self.norm == "huber": - err_hub = huber(self.huber_k, r) - err_qud = r**2 / 2 - return np.sum(np.where(r >= 0, err_hub, err_qud)) - - def _reset_cache(self) -> None: - self.x_cache = None - - def _reset_mask(self) -> None: - self.nzidx_s = np.arange(self.T) - self.nzidx_c = np.arange(self.T) - self._update_R() - self._update_w() - if self.backend in ["osqp", "emosqp", "cuosqp"]: - self._setup_prob_osqp() - - def _update_mask(self, use_wt: bool = False, amp_constraint: bool = True) -> None: - if self.backend in ["osqp", "emosqp", "cuosqp"]: - if use_wt: - nzidx_s = np.where(self.R.T @ self.err_wt)[0] - else: - if self.masking_r is not None: - mask = np.zeros(self.T) - for nzidx in np.where(self._pad_s(self.s_bin) > 0)[0]: - mask[ - max(nzidx - self.masking_r, 0) : min( - nzidx + self.masking_r, self.T - ) - ] = 1 - nzidx_s = 
np.where(mask)[0] - else: - self._reset_mask() - opt_s, _ = self.solve(amp_constraint) - nzidx_s = np.where(opt_s > self.delta_penal)[0] - if len(nzidx_s) == 0: - raise ValueError - self.nzidx_s = nzidx_s - self._update_R() - self._update_w() - self._setup_prob_osqp() - if not self.free_kernel and len(self.nzidx_c) < self.T: - res = self.prob.solve() - # osqp mistakenly report primal infeasible in some cases - # disable masking in such cases - # potentially related: https://github.com/osqp/osqp/issues/485 - if res.info.status == "primal infeasible": - self._reset_mask() - else: - # TODO: add support - raise NotImplementedError("masking not supported for cvxpy backend") - - def _update_w(self, w_new=None) -> None: - if w_new is not None: - self.w_org = w_new - self.w = self.w_org[self.nzidx_s] - - def _update_R(self) -> None: - self.R_org = construct_R(self.y_len, self.upsamp) - self.R = self.R_org[:, self.nzidx_c] - - def _update_Wt(self, clear=False) -> None: - coef = self.coef.value if self.backend == "cvxpy" else self.coef - if clear: - logger.debug("Clearing error weighting") - self.err_wt = np.ones(self.y_len) - elif self.err_weighting == "fft": - logger.debug("Updating error weighting with fft") - hspec = self._get_stft_spec(coef)[:, int(self.coef_len / 2)] - self.err_wt = ( - (hspec.reshape(-1, 1) * self.yspec).sum(axis=0) - / np.linalg.norm(hspec) - / np.linalg.norm(self.yspec, axis=0) - ) - elif self.err_weighting == "corr": - logger.debug("Updating error weighting with corr") - for i in range(self.y_len): - yseg = self.y[i : i + self.coef_len] - if len(yseg) <= 1: - continue - cseg = coef[: len(yseg)] - with np.errstate(all="ignore"): - self.err_wt[i] = np.corrcoef(yseg, cseg)[0, 1].clip(0, 1) - self.err_wt = np.nan_to_num(self.err_wt) - elif self.err_weighting == "adaptive": - if self.s_bin is not None: - self.err_wt = np.zeros(self.y_len) - s_bin_R = self.R @ self._pad_s(self.s_bin) - for nzidx in np.where(s_bin_R > 0)[0]: - self.err_wt[nzidx : nzidx + 
self.wgt_len] = 1 - else: - self.err_wt = np.ones(self.y_len) - self.Wt = sps.diags(self.err_wt) - - def _update_HG(self) -> None: - coef = self.coef.value if self.backend == "cvxpy" else self.coef - if self.Hlim is None or self.T * self.coef_len < self.Hlim: - self.H_org = sps.diags( - [np.repeat(coef[i], self.T - i) for i in range(len(coef))], - offsets=-np.arange(len(coef)), - format="csc", - ) - try: - H_shape, H_nnz = self.H.shape, self.H.nnz - except AttributeError: - H_shape, H_nnz = None, None - self.H = self.H_org[:, self.nzidx_s][self.nzidx_c, :] - logger.debug( - f"Updating H matrix - shape before: {H_shape}, shape new: {self.H.shape}, nnz before: {H_nnz}, nnz new: {self.H.nnz}" - ) - if not self.free_kernel: - theta = self.theta.value if self.backend == "cvxpy" else self.theta - G_diag = sps.diags( - [np.ones(self.T - 1)] - + [np.repeat(-theta[i], self.T - 2 - i) for i in range(theta.shape[0])], - offsets=np.arange(0, -theta.shape[0] - 1, -1), - format="csc", - ) - self.G_org = sps.bmat( - [[None, G_diag], [np.zeros((1, 1)), None]], format="csc" - ) - try: - G_shape, G_nnz = self.G.shape, self.G.nnz - except AttributeError: - G_shape, G_nnz = None, None - self.G = self.G_org[:, self.nzidx_c][self.nzidx_s, :] - logger.debug( - f"Updating G matrix - shape before: {G_shape}, shape new: {self.G.shape}, nnz before: {G_nnz}, nnz new: {self.G.nnz}" - ) - # assert np.isclose( - # np.linalg.pinv(self.H.todense()), self.G.todense(), atol=self.atol - # ).all() - - def _update_wgt_len(self) -> None: - coef = self.coef.value if self.backend == "cvxpy" else self.coef - if self.wt_trunc_thres is not None: - trunc_len = int( - np.around(np.where(coef > self.wt_trunc_thres)[0][-1] / self.upsamp) - ) - if trunc_len == 0: - trunc_len = int(np.around(np.where(coef > 0)[0][-1] / self.upsamp)) - self.wgt_len = max(min(self.coef_len, trunc_len), 1) - else: - self.wgt_len = self.coef_len - - def _get_stft_spec(self, x: np.ndarray) -> np.ndarray: - spec = 
np.abs(self.stft.stft(x)) ** 2 - t = self.stft.t(len(x)) - t_mask = np.logical_and(t >= 0, t < len(x)) - return spec[:, t_mask] - - def _get_M(self) -> sps.csc_matrix: - if self.free_kernel: - return sps.hstack( - [ - np.ones((self.R.shape[0], 1)), - self.scale * self.R @ self.H, - ], - format="csc", - ) - else: - return sps.hstack( - [np.ones((self.R.shape[0], 1)), self.scale * self.R], format="csc" - ) - - def _update_P(self) -> None: - if self.norm == "l1": - # TODO: add support - raise NotImplementedError( - "l1 norm not yet supported with backend {}".format(self.backend) - ) - elif self.norm == "l2": - M = self._get_M() - P = M.T @ self.Wt.T @ self.Wt @ M - elif self.norm == "huber": - lc, ls, ly = len(self.nzidx_c), len(self.nzidx_s), self.y_len - if self.free_kernel: - P = sps.bmat( - [ - [sps.csc_matrix((ls, ls)), None, None], - [None, sps.csc_matrix((ly, ly)), None], - [None, None, sps.eye(ly, format="csc")], - ] - ) - else: - P = sps.bmat( - [ - [sps.csc_matrix((lc, lc)), None, None], - [None, sps.csc_matrix((ly, ly)), None], - [None, None, sps.eye(ly, format="csc")], - ] - ) - try: - P_shape, P_nnz = self.P.shape, self.P.nnz - except AttributeError: - P_shape, P_nnz = None, None - logger.debug( - f"Updating P matrix - shape before: {P_shape}, shape new: {P.shape}, nnz before: {P_nnz}, nnz new: {P.nnz}" - ) - self.P = sps.triu(P).tocsc() - - def _update_q0(self) -> None: - if self.norm == "l1": - # TODO: add support - raise NotImplementedError( - "l1 norm not yet supported with backend {}".format(self.backend) - ) - elif self.norm == "l2": - M = self._get_M() - self.q0 = -M.T @ self.Wt.T @ self.Wt @ self.y - elif self.norm == "huber": - ly = self.y_len - lx = len(self.nzidx_s) if self.free_kernel else len(self.nzidx_c) - self.q0 = ( - np.concatenate([np.zeros(lx), np.ones(ly), np.ones(ly)]) * self.huber_k - ) - - def _update_q(self) -> None: - if self.norm == "l1": - # TODO: add support - raise NotImplementedError( - "l1 norm not yet supported with 
backend {}".format(self.backend) - ) - elif self.norm == "l2": - if self.free_kernel: - ww = np.concatenate([np.zeros(1), self.w]) - qq = np.concatenate([np.zeros(1), np.ones_like(self.w)]) - self.q = self.q0 + self.l0_penal * ww + self.l1_penal * qq - else: - G_p = sps.hstack([np.zeros((self.G.shape[0], 1)), self.G], format="csc") - self.q = ( - self.q0 - + self.l0_penal * self.w @ G_p - + self.l1_penal * np.ones(self.G.shape[0]) @ G_p - ) - elif self.norm == "huber": - pad_k = np.zeros(self.y_len) - if self.free_kernel: - self.q = ( - self.q0 - + self.l0_penal * np.concatenate([self.w, pad_k, pad_k]) - + self.l1_penal - * np.concatenate([np.ones(len(self.nzidx_s)), pad_k, pad_k]) - ) - else: - self.q = ( - self.q0 - + self.l0_penal * np.concatenate([self.w @ self.G, pad_k, pad_k]) - + self.l1_penal - * np.concatenate([np.ones(self.G.shape[0]) @ self.G, pad_k, pad_k]) - ) - - def _update_A(self) -> None: - if self.free_kernel: - Ax = sps.eye(len(self.nzidx_s), format="csc") - Ar = self.scale * self.R @ self.H - else: - Ax = sps.csc_matrix(self.G_org[:, self.nzidx_c]) - # record spike terms that requires constraint - self.nzidx_A = np.where((Ax != 0).sum(axis=1))[0] - Ax = Ax[self.nzidx_A, :] - Ar = self.scale * self.R - try: - A_shape, A_nnz = self.A.shape, self.A.nnz - except AttributeError: - A_shape, A_nnz = None, None - if self.norm == "huber": - e = sps.eye(self.y_len, format="csc") - self.A = sps.bmat( - [ - [Ax, None, None], - [None, e, None], - [None, None, -e], - [Ar, e, e], - ], - format="csc", - ) - else: - self.A = sps.bmat([[np.ones((1, 1)), None], [None, Ax]], format="csc") - logger.debug( - f"Updating A matrix - shape before: {A_shape}, shape new: {self.A.shape}, nnz before: {A_nnz}, nnz new: {self.A.nnz}" - ) - - def _update_bounds(self) -> None: - if self.norm == "huber": - xlen = len(self.nzidx_s) if self.free_kernel else self.T - self.lb = np.concatenate( - [np.zeros(xlen + self.y_len * 2), self.y - self.huber_k] - ) - self.ub = np.concatenate( 
- [np.ones(xlen), np.full(self.y_len * 2, np.inf), self.y - self.huber_k] - ) - self.ub_inf = np.concatenate( - [np.full(xlen + self.y_len * 2, np.inf), self.y - self.huber_k] - ) - else: - bb = np.clip(self.y.mean(), 0, None) if self.use_base else 0 - if self.free_kernel: - self.lb = np.zeros(len(self.nzidx_s) + 1) - self.ub = np.concatenate([np.full(1, bb), np.ones(len(self.nzidx_s))]) - self.ub_inf = np.concatenate( - [np.full(1, bb), np.full(len(self.nzidx_s), np.inf)] - ) - else: - ub_pad, ub_inf_pad = np.zeros(self.T), np.zeros(self.T) - ub_pad[self.nzidx_s] = 1 - ub_inf_pad[self.nzidx_s] = np.inf - self.lb = np.zeros(len(self.nzidx_A) + 1) - self.ub = np.concatenate([np.full(1, bb), ub_pad[self.nzidx_A]]) - self.ub_inf = np.concatenate([np.full(1, bb), ub_inf_pad[self.nzidx_A]]) - assert (self.ub >= self.lb).all() - assert (self.ub_inf >= self.lb).all() diff --git a/src/indeca/core/deconv/solver.py b/src/indeca/core/deconv/solver.py new file mode 100644 index 0000000..03279f1 --- /dev/null +++ b/src/indeca/core/deconv/solver.py @@ -0,0 +1,903 @@ +"""Solver implementations for deconv.""" + +from abc import ABC, abstractmethod +from typing import Any, Tuple, Optional +import warnings + +import cvxpy as cp +import numpy as np +import osqp +import scipy.sparse as sps +from scipy.special import huber +from scipy.signal import ShortTimeFFT + +from indeca.utils.logging_config import get_module_logger +from indeca.core.simulation import tau2AR, solve_p, exp_pulse, ar_pulse +from indeca.utils.utils import scal_lstsq +from .config import DeconvConfig +from .utils import construct_R, bin_convolve, get_stft_spec + +logger = get_module_logger("deconv_solver") + +# Try to import GPU solver +try: + import cuosqp + + HAS_CUOSQP = True +except ImportError: + HAS_CUOSQP = False + logger.debug("cuosqp not available") + + +class DeconvSolver(ABC): + """Abstract base class for deconvolution solvers.""" + + def __init__( + self, + config: DeconvConfig, + y_len: int, + y: 
np.ndarray | None = None, + coef: np.ndarray | None = None, + theta: np.ndarray | None = None, + tau: np.ndarray | None = None, + ps: np.ndarray | None = None, + ): + self.cfg = config + self.y_len = y_len + self.T = y_len * self.cfg.upsamp + self.y = y if y is not None else np.zeros(y_len) + self.coef = coef + self.coef_len = ( + len(coef) if coef is not None else config.coef_len * config.upsamp + ) + self.theta = theta + self.tau = tau + self.ps = ps + + # Scale tracking (mutable, since config is frozen) + self.scale = config.scale + + # Penalty tracking + self.l0_penal = 0.0 + self.l1_penal = 0.0 + + # Weight vectors + self.w_org = np.ones(self.T) + self.w = np.ones(self.T) + + # Masking indices + self.nzidx_s = np.arange(self.T) + self.nzidx_c = np.arange(self.T) + + # Matrices + self.R_org = construct_R(self.y_len, self.cfg.upsamp) + self.R = self.R_org + self.H = None + self.H_org = None + self.G = None + self.G_org = None + + # Cache + self.x_cache = None + self.s_bin = None # Binary spike solution from thresholding + + # Error weighting + self.err_wt = np.ones(self.y_len) + self.wgt_len = self.coef_len + self.Wt = sps.diags(self.err_wt) + + # Huber parameter + self.huber_k = 0.5 * np.std(self.y) if y is not None else 0 + + @abstractmethod + def update(self, **kwargs): + """Update solver parameters.""" + pass + + @abstractmethod + def solve(self, amp_constraint: bool = True) -> Tuple[np.ndarray, float, Any]: + """Solve the optimization problem.""" + pass + + def reset_cache(self) -> None: + """Reset solution cache.""" + self.x_cache = None + + def reset_mask(self) -> None: + """Reset masks to full range.""" + self.nzidx_s = np.arange(self.T) + self.nzidx_c = np.arange(self.T) + self._update_R() + self._update_w() + + def set_mask(self, nzidx_s: np.ndarray, nzidx_c: np.ndarray = None): + """Set mask indices. 
Override in subclasses that don't support masking.""" + self.nzidx_s = nzidx_s + # Old behavior (from `old deconv.py`): masking is applied to spike indices (s) + # while keeping the calcium state indices (c) unmasked unless explicitly provided. + if nzidx_c is not None: + self.nzidx_c = nzidx_c + self._update_R() + self._update_w() + + def _update_R(self) -> None: + """Update R matrix based on mask.""" + self.R = self.R_org[:, self.nzidx_c] + + def _update_w(self, w_new: np.ndarray = None) -> None: + """Update weight vector.""" + if w_new is not None: + self.w_org = w_new + self.w = self.w_org[self.nzidx_s] + + def _pad_s(self, s: np.ndarray = None) -> np.ndarray: + """Pad sparse s to full length.""" + if s is None: + s = np.zeros(len(self.nzidx_s)) + s_ret = np.zeros(self.T) + s_ret[self.nzidx_s] = s + return s_ret + + def _pad_c(self, c: np.ndarray = None) -> np.ndarray: + """Pad sparse c to full length.""" + if c is None: + c = np.zeros(len(self.nzidx_c)) + c_ret = np.zeros(self.T) + c_ret[self.nzidx_c] = c + return c_ret + + def _update_HG(self) -> None: + """Update H (convolution) and G (AR inverse) matrices.""" + coef = self.coef + if coef is None: + return + + # H matrix: convolution matrix + # IMPORTANT: in free-kernel mode the optimization uses R @ H explicitly, + # so H must always be materialized (do not drop it based on Hlim). 
+ if ( + self.cfg.free_kernel + or self.cfg.Hlim is None + or self.T * len(coef) < self.cfg.Hlim + ): + self.H_org = sps.diags( + [np.repeat(coef[i], self.T - i) for i in range(len(coef))], + offsets=-np.arange(len(coef)), + format="csc", + ) + self.H = self.H_org[:, self.nzidx_s][self.nzidx_c, :] + logger.debug(f"Updated H matrix - shape: {self.H.shape}, nnz: {self.H.nnz}") + else: + self.H = None + self.H_org = None + + # G matrix: AR inverse (only if theta provided and not free_kernel) + if not self.cfg.free_kernel and self.theta is not None: + theta = self.theta + G_diag = sps.diags( + [np.ones(self.T - 1)] + + [np.repeat(-theta[i], self.T - 2 - i) for i in range(theta.shape[0])], + offsets=np.arange(0, -theta.shape[0] - 1, -1), + format="csc", + ) + self.G_org = sps.bmat( + [[None, G_diag], [np.zeros((1, 1)), None]], format="csc" + ) + self.G = self.G_org[:, self.nzidx_c][self.nzidx_s, :] + logger.debug(f"Updated G matrix - shape: {self.G.shape}, nnz: {self.G.nnz}") + else: + self.G = None + self.G_org = None + + def _update_wgt_len(self) -> None: + """Update error weighting length based on coefficient truncation.""" + coef = self.coef + if coef is None: + return + if self.cfg.wt_trunc_thres is not None: + trunc_idx = np.where(coef > self.cfg.wt_trunc_thres)[0] + if len(trunc_idx) > 0: + trunc_len = int(np.around(trunc_idx[-1] / self.cfg.upsamp)) + else: + trunc_len = int(np.around(np.where(coef > 0)[0][-1] / self.cfg.upsamp)) + if trunc_len == 0: + trunc_len = 1 + self.wgt_len = max(min(self.coef_len, trunc_len), 1) + else: + self.wgt_len = self.coef_len + + def convolve(self, s: np.ndarray) -> sps.csc_matrix: + """Convolve signal s with kernel. 
Returns sparse column matrix.""" + if self.cfg.free_kernel: + assert ( + self.H is not None + ), "Invariant violated: free_kernel=True requires a materialized H matrix" + if self.H is not None: + # Check if s is masked length or full length + if len(s) == len(self.nzidx_s): + result = self.H @ sps.csc_matrix(s.reshape(-1, 1)) + elif len(s) == self.T: + result = self.H @ sps.csc_matrix(s[self.nzidx_s].reshape(-1, 1)) + else: + logger.warning( + f"Shape mismatch in convolve: s={len(s)}, nzidx_s={len(self.nzidx_s)}" + ) + result = sps.csc_matrix(np.zeros((len(self.nzidx_c), 1))) + return result + else: + # Use bin_convolve for efficiency when H is not stored + if s.dtype == np.bool_: + out = bin_convolve(self.coef, s, nzidx_s=self.nzidx_s, s_len=self.T) + else: + s_pad = self._pad_s(s) if len(s) == len(self.nzidx_s) else s + out = np.convolve(self.coef, s_pad)[: self.T] + return sps.csc_matrix(out[self.nzidx_c].reshape(-1, 1)) + + def validate_coefficients(self, atol: float = 1e-3) -> bool: + """Validate that AR and exponential coefficients are consistent.""" + if self.tau is None or self.ps is None or self.theta is None: + logger.debug("Skipping coefficient validation - missing tau/ps/theta") + return True + + try: + # Generate exponential pulse + tr_exp, _, _ = exp_pulse( + self.tau[0], + self.tau[1], + p_d=self.ps[0], + p_r=self.ps[1], + nsamp=self.coef_len, + ) + + # Generate AR pulse + theta = self.theta + tr_ar, _, _ = ar_pulse( + theta[0], theta[1], nsamp=self.coef_len, shifted=True + ) + + # Validate + if not (~np.isnan(self.coef)).all(): + logger.warning("Coefficient array contains NaN values") + return False + + if not np.isclose(tr_exp, self.coef[: len(tr_exp)], atol=atol).all(): + logger.warning("Exp time constant inconsistent with coefficients") + return False + + if not np.isclose(tr_ar, self.coef[: len(tr_ar)], atol=atol).all(): + logger.warning("AR coefficients inconsistent with coefficients") + return False + + logger.debug("Coefficient validation 
passed") + return True + except Exception as e: + logger.warning(f"Coefficient validation failed: {e}") + return False + + +class CVXPYSolver(DeconvSolver): + """CVXPY backend solver.""" + + def __init__(self, config: DeconvConfig, y_len: int, **kwargs): + super().__init__(config, y_len, **kwargs) + self._update_HG() + self._update_wgt_len() + self._setup_problem() + + def set_mask(self, nzidx_s: np.ndarray, nzidx_c: np.ndarray = None): + """CVXPY does not support masking - raise error.""" + if len(nzidx_s) != self.T: + raise NotImplementedError( + "CVXPY backend does not support masking. Use OSQP backend instead." + ) + super().set_mask(nzidx_s, nzidx_c) + + def reset_mask(self) -> None: + """CVXPY does not support masking - no-op since problem is already full.""" + # CVXPY builds the full problem once, no mask support + # Just reset indices without rebuilding + self.nzidx_s = np.arange(self.T) + self.nzidx_c = np.arange(self.T) + + def _setup_problem(self): + """Setup CVXPY optimization problem.""" + # NOTE: `free_kernel=True` is forbidden with CVXPY backend (see `DeconvConfig`). 
+ self.cp_R = cp.Constant(self.R, name="R") + self.cp_c = cp.Variable((self.T, 1), nonneg=True, name="c") + self.cp_s = cp.Variable( + (self.T, 1), nonneg=True, name="s", boolean=self.cfg.mixin + ) + self.cp_y = cp.Parameter(shape=(self.y_len, 1), name="y") + self.cp_huber_k = cp.Parameter( + value=float(self.huber_k), nonneg=True, name="huber_k" + ) + + self.cp_scale = cp.Parameter(value=self.scale, name="scale", nonneg=True) + self.cp_l1_penal = cp.Parameter(value=0.0, name="l1_penal", nonneg=True) + self.cp_l0_w = cp.Parameter( + shape=self.T, value=np.zeros(self.T), nonneg=True, name="w_l0" + ) + + if self.y is not None: + self.cp_y.value = self.y.reshape((-1, 1)) + + if self.cfg.use_base: + self.cp_b = cp.Variable(nonneg=True, name="b") + else: + self.cp_b = cp.Constant(value=0, name="b") + + # Error term based on norm + term = self.cp_y - self.cp_scale * self.cp_R @ self.cp_c - self.cp_b + if self.cfg.norm == "l1": + self.err_term = cp.sum(cp.abs(term)) + elif self.cfg.norm == "l2": + self.err_term = cp.sum_squares(term) + elif self.cfg.norm == "huber": + # Keep huber parameter consistent with OSQP backend's `huber_k`. 
+ self.err_term = cp.sum(cp.huber(term, M=self.cp_huber_k)) + + # Objective + obj_expr = ( + self.err_term + + self.cp_l0_w.T @ cp.abs(self.cp_s) + + self.cp_l1_penal * cp.sum(cp.abs(self.cp_s)) + ) + obj = cp.Minimize(obj_expr) + + # Constraints + # AR constraint via G matrix + self.cp_theta = cp.Parameter( + value=self.theta, shape=self.theta.shape, name="theta" + ) + G_diag = sps.eye(self.T - 1) + sum( + [ + cp.diag(cp.promote(-self.cp_theta[i], (self.T - i - 2,)), -i - 1) + for i in range(self.theta.shape[0]) + ] + ) + G = cp.bmat( + [ + [np.zeros((self.T - 1, 1)), G_diag], + [np.zeros((1, 1)), np.zeros((1, self.T - 1))], + ] + ) + dcv_cons = [self.cp_s == G @ self.cp_c] + + edge_cons = [self.cp_c[0, 0] == 0, self.cp_s[-1, 0] == 0] + amp_cons = [self.cp_s <= 1] + + self.prob_free = cp.Problem(obj, dcv_cons + edge_cons) + self.prob = cp.Problem(obj, dcv_cons + edge_cons + amp_cons) + + def update( + self, + y: np.ndarray = None, + coef: np.ndarray = None, + scale: float = None, + scale_mul: float = None, + l1_penal: float = None, + l0_penal: float = None, + w: np.ndarray = None, + theta: np.ndarray = None, + **kwargs, + ): + """Update CVXPY parameters.""" + if y is not None: + self.y = y + self.cp_y.value = y.reshape((-1, 1)) + # Keep huber_k consistent with OSQP backend / deconv objective. 
+ self.huber_k = 0.5 * np.std(self.y) + self.cp_huber_k.value = float(self.huber_k) + if coef is not None: + self.coef = coef + self._update_HG() + self._update_wgt_len() + if scale is not None: + self.scale = scale + self.cp_scale.value = scale + if scale_mul is not None: + self.scale *= scale_mul + self.cp_scale.value = self.scale + if l1_penal is not None: + self.l1_penal = l1_penal + self.cp_l1_penal.value = l1_penal + if l0_penal is not None: + self.l0_penal = l0_penal + if w is not None: + self._update_w(w) + if l0_penal is not None or w is not None: + self.cp_l0_w.value = self.l0_penal * self.w + if theta is not None and hasattr(self, "cp_theta"): + self.theta = theta + self.cp_theta.value = theta + + def solve(self, amp_constraint: bool = True) -> Tuple[np.ndarray, float, Any]: + """Solve CVXPY problem.""" + prob = self.prob if amp_constraint else self.prob_free + try: + res = prob.solve() + except cp.error.SolverError as e: + logger.warning(f"CVXPY SolverError: {e}") + res = np.inf + + opt_s = ( + self.cp_s.value.squeeze() + if self.cp_s.value is not None + else np.zeros(self.T) + ) + opt_b = 0 + if ( + self.cfg.use_base + and hasattr(self.cp_b, "value") + and self.cp_b.value is not None + ): + opt_b = float(self.cp_b.value) + + return opt_s, opt_b, res + + +class OSQPSolver(DeconvSolver): + """OSQP backend solver (also handles cuosqp for GPU).""" + + def __init__(self, config: DeconvConfig, y_len: int, **kwargs): + super().__init__(config, y_len, **kwargs) + + # Additional state for OSQP + self.prob = None + self.prob_free = None + self.P = None + self.q = None + self.q0 = None + self.A = None + self.lb = None + self.ub = None + self.ub_inf = None + self.nzidx_A = None + + # STFT for FFT weighting + if self.cfg.err_weighting == "fft": + self.stft = ShortTimeFFT(win=np.ones(self.coef_len), hop=1, fs=1) + self.yspec = get_stft_spec(self.y, self.stft) + + # Initialize matrices and problem + self._update_HG() + self._update_wgt_len() + self._update_Wt() + 
self._setup_prob_osqp() + + def reset_mask(self) -> None: + """Reset masks to full range and rebuild OSQP problems.""" + super().reset_mask() + self._update_HG() + self._setup_prob_osqp() + + def set_mask(self, nzidx_s: np.ndarray, nzidx_c: np.ndarray = None): + """Set mask and rebuild problem.""" + super().set_mask(nzidx_s, nzidx_c) + self._update_HG() + self._setup_prob_osqp() + + def _update_Wt(self, clear: bool = False) -> None: + """Update error weighting matrix.""" + coef = self.coef + if clear: + logger.debug("Clearing error weighting") + self.err_wt = np.ones(self.y_len) + elif self.cfg.err_weighting == "fft" and hasattr(self, "stft"): + logger.debug("Updating error weighting with fft") + hspec = get_stft_spec(coef, self.stft)[:, int(len(coef) / 2)] + self.err_wt = ( + (hspec.reshape(-1, 1) * self.yspec).sum(axis=0) + / np.linalg.norm(hspec) + / np.linalg.norm(self.yspec, axis=0) + ) + elif self.cfg.err_weighting == "corr": + logger.debug("Updating error weighting with corr") + self.err_wt = np.ones(self.y_len) + for i in range(self.y_len): + yseg = self.y[i : i + len(coef)] + if len(yseg) <= 1: + continue + cseg = coef[: len(yseg)] + with np.errstate(all="ignore"): + self.err_wt[i] = np.corrcoef(yseg, cseg)[0, 1].clip(0, 1) + self.err_wt = np.nan_to_num(self.err_wt) + elif self.cfg.err_weighting == "adaptive": + if self.s_bin is not None: + self.err_wt = np.zeros(self.y_len) + s_bin_R = self.R @ self._pad_s(self.s_bin) + for nzidx in np.where(s_bin_R > 0)[0]: + self.err_wt[nzidx : nzidx + self.wgt_len] = 1 + else: + self.err_wt = np.ones(self.y_len) + + self.Wt = sps.diags(self.err_wt) + + def _get_M(self) -> sps.csc_matrix: + """Get the combined model matrix M = [1, scale*R] or [1, scale*R@H].""" + if self.cfg.free_kernel: + return sps.hstack( + [np.ones((self.R.shape[0], 1)), self.scale * self.R @ self.H], + format="csc", + ) + else: + return sps.hstack( + [np.ones((self.R.shape[0], 1)), self.scale * self.R], + format="csc", + ) + + def _update_P(self) 
-> None: + """Update quadratic cost matrix P.""" + if self.cfg.norm == "l1": + raise NotImplementedError("l1 norm not yet supported with OSQP backend") + elif self.cfg.norm == "l2": + M = self._get_M() + P = M.T @ self.Wt.T @ self.Wt @ M + elif self.cfg.norm == "huber": + lc = len(self.nzidx_c) + ls = len(self.nzidx_s) + ly = self.y_len + if self.cfg.free_kernel: + P = sps.bmat( + [ + [sps.csc_matrix((ls + 1, ls + 1)), None, None], + [None, sps.csc_matrix((ly, ly)), None], + [None, None, sps.eye(ly, format="csc")], + ] + ) + else: + P = sps.bmat( + [ + [sps.csc_matrix((lc + 1, lc + 1)), None, None], + [None, sps.csc_matrix((ly, ly)), None], + [None, None, sps.eye(ly, format="csc")], + ] + ) + + self.P = sps.triu(P).tocsc() + logger.debug(f"Updated P matrix - shape: {self.P.shape}, nnz: {self.P.nnz}") + + def _update_q0(self) -> None: + """Update linear cost base q0.""" + if self.cfg.norm == "l1": + raise NotImplementedError("l1 norm not yet supported with OSQP backend") + elif self.cfg.norm == "l2": + M = self._get_M() + self.q0 = -M.T @ self.Wt.T @ self.Wt @ self.y + elif self.cfg.norm == "huber": + ly = self.y_len + lx = ( + len(self.nzidx_s) + 1 if self.cfg.free_kernel else len(self.nzidx_c) + 1 + ) + self.q0 = ( + np.concatenate([np.zeros(lx), np.ones(ly), np.ones(ly)]) * self.huber_k + ) + + def _update_q(self) -> None: + """Update linear cost vector q (including penalties).""" + if self.cfg.norm == "l1": + raise NotImplementedError("l1 norm not yet supported with OSQP backend") + elif self.cfg.norm == "l2": + if self.cfg.free_kernel: + ww = np.concatenate([np.zeros(1), self.w]) + qq = np.concatenate([np.zeros(1), np.ones_like(self.w)]) + self.q = self.q0 + self.l0_penal * ww + self.l1_penal * qq + else: + G_p = sps.hstack([np.zeros((self.G.shape[0], 1)), self.G], format="csc") + self.q = ( + self.q0 + + self.l0_penal * self.w @ G_p + + self.l1_penal * np.ones(self.G.shape[0]) @ G_p + ) + elif self.cfg.norm == "huber": + pad_k = np.zeros(self.y_len) + if 
self.cfg.free_kernel: + self.q = ( + self.q0 + + self.l0_penal * np.concatenate([[0], self.w, pad_k, pad_k]) + + self.l1_penal + * np.concatenate([[0], np.ones(len(self.nzidx_s)), pad_k, pad_k]) + ) + else: + self.q = ( + self.q0 + + self.l0_penal + * np.concatenate([[0], self.w @ self.G, pad_k, pad_k]) + + self.l1_penal + * np.concatenate( + [[0], np.ones(self.G.shape[0]) @ self.G, pad_k, pad_k] + ) + ) + + def _update_A(self) -> None: + """Update constraint matrix A.""" + if self.cfg.free_kernel: + Ax = sps.eye(len(self.nzidx_s), format="csc") + Ar = self.scale * self.R @ self.H + else: + Ax = sps.csc_matrix(self.G_org[:, self.nzidx_c]) + # Record spike terms that require constraint + self.nzidx_A = np.where((Ax != 0).sum(axis=1))[0] + Ax = Ax[self.nzidx_A, :] + Ar = self.scale * self.R + + if self.cfg.norm == "huber": + e = sps.eye(self.y_len, format="csc") + self.A = sps.bmat( + [ + [sps.csc_matrix((Ax.shape[0], 1)), Ax, None, None], + [None, None, e, None], + [None, None, None, -e], + [np.ones((Ar.shape[0], 1)), Ar, e, e], + ], + format="csc", + ) + else: + self.A = sps.bmat([[np.ones((1, 1)), None], [None, Ax]], format="csc") + + logger.debug(f"Updated A matrix - shape: {self.A.shape}, nnz: {self.A.nnz}") + + def _update_bounds(self) -> None: + """Update constraint bounds.""" + if self.cfg.norm == "huber": + xlen = len(self.nzidx_s) if self.cfg.free_kernel else len(self.nzidx_A) + self.lb = np.concatenate( + [np.zeros(xlen + self.y_len * 2), self.y - self.huber_k] + ) + self.ub = np.concatenate( + [np.ones(xlen), np.full(self.y_len * 2, np.inf), self.y - self.huber_k] + ) + self.ub_inf = np.concatenate( + [np.full(xlen + self.y_len * 2, np.inf), self.y - self.huber_k] + ) + else: + bb = np.clip(self.y.mean(), 0, None) if self.cfg.use_base else 0 + if self.cfg.free_kernel: + self.lb = np.zeros(len(self.nzidx_s) + 1) + self.ub = np.concatenate([np.full(1, bb), np.ones(len(self.nzidx_s))]) + self.ub_inf = np.concatenate( + [np.full(1, bb), 
np.full(len(self.nzidx_s), np.inf)] + ) + else: + ub_pad = np.zeros(self.T) + ub_inf_pad = np.zeros(self.T) + ub_pad[self.nzidx_s] = 1 + ub_inf_pad[self.nzidx_s] = np.inf + self.lb = np.zeros(len(self.nzidx_A) + 1) + self.ub = np.concatenate([np.full(1, bb), ub_pad[self.nzidx_A]]) + self.ub_inf = np.concatenate([np.full(1, bb), ub_inf_pad[self.nzidx_A]]) + + assert (self.ub >= self.lb).all(), "Upper bounds must be >= lower bounds" + assert ( + self.ub_inf >= self.lb + ).all(), "Upper bounds (inf) must be >= lower bounds" + + def _setup_prob_osqp(self) -> None: + """Setup OSQP problem instances.""" + logger.debug("Setting up OSQP problem") + + self._update_P() + self._update_q0() + self._update_q() + self._update_A() + self._update_bounds() + + # Choose solver backend + if self.cfg.backend == "cuosqp": + if not HAS_CUOSQP: + logger.warning("cuosqp not available, falling back to osqp") + self.prob = osqp.OSQP() + self.prob_free = osqp.OSQP() + else: + self.prob = cuosqp.OSQP() + self.prob_free = cuosqp.OSQP() + elif self.cfg.backend == "emosqp": + # Stub: emosqp requires codegen, not supported in this refactor + logger.warning("emosqp requires codegen, using osqp instead") + self.prob = osqp.OSQP() + self.prob_free = osqp.OSQP() + else: + self.prob = osqp.OSQP() + self.prob_free = osqp.OSQP() + + # Setup constrained problem + self.prob.setup( + P=self.P.copy(), + q=self.q.copy(), + A=self.A.copy(), + l=self.lb.copy(), + u=self.ub.copy(), + verbose=False, + polish=True, + warm_start=False, + eps_abs=1e-6, + eps_rel=1e-6, + eps_prim_inf=1e-7, + eps_dual_inf=1e-7, + ) + + # Setup unconstrained (free) problem + self.prob_free.setup( + P=self.P.copy(), + q=self.q.copy(), + A=self.A.copy(), + l=self.lb.copy(), + u=self.ub_inf.copy(), + verbose=False, + polish=True, + warm_start=False, + eps_abs=1e-6, + eps_rel=1e-6, + eps_prim_inf=1e-7, + eps_dual_inf=1e-7, + ) + + logger.debug(f"{self.cfg.backend} setup completed successfully") + + def update( + self, + y: np.ndarray = 
None, + coef: np.ndarray = None, + tau: np.ndarray = None, + theta: np.ndarray = None, + scale: float = None, + scale_mul: float = None, + l1_penal: float = None, + l0_penal: float = None, + w: np.ndarray = None, + update_weighting: bool = False, + clear_weighting: bool = False, + scale_coef: bool = False, + **kwargs, + ): + """Update OSQP problem parameters.""" + logger.debug(f"Updating OSQP solver parameters") + + # Update input parameters + if y is not None: + self.y = y + # Match legacy behavior: huber_k tracks the current y (used in q0 and bounds) + self.huber_k = 0.5 * np.std(self.y) + if tau is not None: + theta_new = np.array(tau2AR(tau[0], tau[1])) + p = solve_p(tau[0], tau[1]) + coef_new, _, _ = exp_pulse( + tau[0], tau[1], p_d=p, p_r=-p, nsamp=self.coef_len, kn_len=self.coef_len + ) + self.tau = tau + self.ps = np.array([p, -p]) + self.theta = theta_new + coef = coef_new + if theta is not None: + self.theta = theta + if coef is not None: + if scale_coef and self.coef is not None: + scale_mul = scal_lstsq(coef, self.coef).item() + self.coef = coef + if scale is not None: + self.scale = scale + if scale_mul is not None: + self.scale *= scale_mul + if l1_penal is not None: + self.l1_penal = l1_penal + if l0_penal is not None: + self.l0_penal = l0_penal + if w is not None: + self._update_w(w) + + # Track what needs updating + updt_HG = coef is not None + updt_P = False + updt_q0 = False + updt_q = False + updt_A = False + updt_bounds = False + setup_prob = False + + if updt_HG: + self._update_HG() + self._update_wgt_len() + + if self.cfg.err_weighting is not None and update_weighting: + self._update_Wt(clear=clear_weighting) + if self.cfg.err_weighting == "adaptive": + setup_prob = True + else: + updt_P = True + updt_q0 = True + updt_q = True + + if self.cfg.norm == "huber": + # huber_k changes require recomputing q and bounds + if y is not None: + self._update_q0() + updt_q0 = True + if any([scale is not None, scale_mul is not None, updt_HG]): + 
self._update_A() + updt_A = True + if any( + [w is not None, l0_penal is not None, l1_penal is not None, updt_HG] + ): + self._update_q() + updt_q = True + if y is not None: + self._update_bounds() + updt_bounds = True + else: + if any([updt_HG, updt_A]): + self._update_A() + updt_A = True + if any([scale is not None, scale_mul is not None, updt_HG, updt_P]): + self._update_P() + updt_P = True + if any( + [ + scale is not None, + scale_mul is not None, + y is not None, + updt_HG, + updt_q0, + ] + ): + self._update_q0() + updt_q0 = True + if any( + [ + w is not None, + l0_penal is not None, + l1_penal is not None, + updt_q0, + updt_q, + ] + ): + self._update_q() + updt_q = True + + # Apply updates to OSQP - conservative approach: + # Only q can be updated in-place safely. For P, A, bounds, rebuild. + if setup_prob or any([updt_P, updt_A, updt_bounds]): + self._setup_prob_osqp() + elif updt_q: + self.prob.update(q=self.q) + self.prob_free.update(q=self.q) + + logger.debug("OSQP problem updated") + + def solve(self, amp_constraint: bool = True) -> Tuple[np.ndarray, float, Any]: + """Solve OSQP problem.""" + prob = self.prob if amp_constraint else self.prob_free + res = prob.solve() + + if res.info.status not in ["solved", "solved inaccurate"]: + logger.warning(f"OSQP not solved: {res.info.status}") + if res.info.status in ["primal infeasible", "primal infeasible inaccurate"]: + x = np.zeros(self.P.shape[0], dtype=float) + else: + x = ( + res.x.astype(float) + if res.x is not None + else np.zeros(self.P.shape[0], dtype=float) + ) + else: + x = res.x + + if self.cfg.norm == "huber": + xlen = ( + len(self.nzidx_s) + 1 if self.cfg.free_kernel else len(self.nzidx_c) + 1 + ) + sol = x[:xlen] + opt_b = sol[0] + if self.cfg.free_kernel: + opt_s = sol[1:] + else: + opt_s = self.G @ sol[1:] + else: + opt_b = x[0] + if self.cfg.free_kernel: + opt_s = x[1:] + else: + c_sol = x[1:] + opt_s = self.G @ c_sol + + # Return 0 for objective - caller should use _compute_err for correct 
"""Utility functions for deconv module."""

import numpy as np
import scipy.sparse as sps
from numba import njit
from scipy.signal import ShortTimeFFT

from indeca.core.simulation import tau2AR


def get_stft_spec(x: np.ndarray, stft: ShortTimeFFT) -> np.ndarray:
    """Compute the STFT power spectrogram of ``x``.

    Parameters
    ----------
    x : np.ndarray
        Signal handed to ``stft.stft``.
    stft : ShortTimeFFT
        Configured short-time FFT object.

    Returns
    -------
    np.ndarray
        ``|stft.stft(x)| ** 2`` restricted to the time slices whose time
        stamp ``t`` satisfies ``0 <= t < len(x)``.
    """
    spec = np.abs(stft.stft(x)) ** 2
    t = stft.t(len(x))
    # NOTE(review): ``stft.t`` is compared against the sample count
    # ``len(x)``; this presumably relies on a sampling rate of 1 so that time
    # stamps and sample indices share units -- confirm against the callers.
    t_mask = np.logical_and(t >= 0, t < len(x))
    # The mask is applied on axis 1, i.e. ``spec`` is expected to be 2-D.
    return spec[:, t_mask]


def construct_R(T: int, up_factor: int):
    """Construct the resampling matrix R.

    Parameters
    ----------
    T : int
        Number of rows of the result.
    up_factor : int
        Upsampling factor.

    Returns
    -------
    scipy.sparse CSC matrix
        For ``up_factor > 1``: a ``(T, T * up_factor)`` matrix with ones at
        ``(i, j)`` for ``i * up_factor <= j < (i + 1) * up_factor``, so that
        ``R @ v`` sums each consecutive group of ``up_factor`` entries of
        ``v``. Otherwise the ``T x T`` identity.
    """
    if up_factor > 1:
        return sps.csc_matrix(
            (
                np.ones(T * up_factor),
                (np.repeat(np.arange(T), up_factor), np.arange(T * up_factor)),
            ),
            shape=(T, T * up_factor),
        )
    else:
        return sps.eye(T, format="csc")


def sum_downsample(a, factor):
    """Sum downsample array a by factor.

    Entry ``k`` of the result is ``a[k * factor:(k + 1) * factor].sum()``;
    a trailing group shorter than ``factor`` is summed as-is.

    Parameters
    ----------
    a : array-like
        1-D input.
    factor : int
        Number of consecutive samples summed into one output sample.

    Returns
    -------
    np.ndarray
        Group sums; an empty float array when ``a`` is empty.
    """
    a = np.asarray(a)
    if a.size == 0:
        # np.convolve raises on empty input; nothing to sum means no groups.
        return np.zeros(0)
    # The "full" convolution with a box of ones yields every running sum of
    # ``factor`` samples; keeping every ``factor``-th one, starting at the
    # first complete window, leaves the non-overlapping group sums.
    return np.convolve(a, np.ones(factor), mode="full")[factor - 1 :: factor]


def construct_G(fac: np.ndarray, T: int, fromTau=False):
    """Construct the generator matrix G.

    Parameters
    ----------
    fac : array-like of length 2
        Coefficients placed (negated) on the first and second sub-diagonals.
        When ``fromTau`` is true they are first converted with ``tau2AR``.
    T : int
        Size of the square matrix.
    fromTau : bool, optional
        Convert ``fac`` through ``tau2AR(*fac)`` before building the matrix.

    Returns
    -------
    scipy.sparse CSC matrix
        ``T x T`` banded matrix with ``1`` on the main diagonal, ``-fac[0]``
        on the first sub-diagonal and ``-fac[1]`` on the second.
    """
    # I think we should be able to remove fromTau argument since we don't use it anywhere.
    fac = np.array(fac)
    assert fac.shape == (2,)
    if fromTau:
        fac = np.array(tau2AR(*fac))
    return sps.dia_matrix(
        (
            # One row of constant values per diagonal: [1, -fac[0], -fac[1]].
            np.tile(np.concatenate(([1], -fac)), (T, 1)).T,
            -np.arange(len(fac) + 1),
        ),
        shape=(T, T),
    ).tocsc()


def max_thres(
    a: np.ndarray,
    nthres: int,
    th_min=0.1,
    th_max=0.9,
    ds=None,
    return_thres=False,
    th_amplitude=False,
    delta=1e-6,
    reverse_thres=False,
    nz_only: bool = False,
):
    """Threshold array a with nthres levels.

    The levels are ``nthres`` evenly spaced fractions of ``a.max()`` between
    ``th_min`` and ``th_max``; each resulting cut-off is bounded below by
    ``delta``.

    Parameters
    ----------
    a : array-like
        Values to threshold. An empty input is treated as having maximum 0.
    nthres : int
        Number of threshold levels.
    th_min, th_max : float, optional
        Smallest and largest level, as fractions of the maximum of ``a``.
    ds : int, optional
        When given, every thresholded array is passed through
        ``sum_downsample(s, ds)``.
    return_thres : bool, optional
        Also return the levels that were used.
    th_amplitude : bool, optional
        If true, each result is ``np.floor_divide(a, cutoff)`` instead of the
        boolean array ``a > cutoff``.
    delta : float, optional
        Lower bound applied to every cut-off.
    reverse_thres : bool, optional
        Order the levels from ``th_max`` down to ``th_min``.
    nz_only : bool, optional
        Drop levels whose result does not have a positive sum.

    Returns
    -------
    list of np.ndarray
        One array per retained level.
    thres
        Only when ``return_thres`` is true: the retained levels (a list when
        ``nz_only`` is set, otherwise the ``np.linspace`` array).
    """
    # Accept any array-like; normalized to numpy.
    a = np.asarray(a)
    # ``ndarray.max`` raises on empty input; with a maximum of 0 every
    # cut-off simply falls back to ``delta``.
    amax = a.max() if a.size else 0.0
    if reverse_thres:
        thres = np.linspace(th_max, th_min, nthres)
    else:
        thres = np.linspace(th_min, th_max, nthres)
    cutoffs = np.clip(amax * thres, delta, None)
    if th_amplitude:
        S_ls = [np.floor_divide(a, cut) for cut in cutoffs]
    else:
        S_ls = [a > cut for cut in cutoffs]
    if ds is not None:
        S_ls = [sum_downsample(s, ds) for s in S_ls]
    if nz_only:
        Snz = [ss.sum() > 0 for ss in S_ls]
        S_ls = [ss for ss, nz in zip(S_ls, Snz) if nz]
        thres = [th for th, nz in zip(thres, Snz) if nz]
    if return_thres:
        return S_ls, thres
    else:
        return S_ls


# ``njit`` always compiles in nopython mode; passing ``nopython=True`` is
# ignored and only triggers a warning, so it is not passed.
@njit(nogil=True, cache=True)
def bin_convolve(
    coef: np.ndarray, s: np.ndarray, nzidx_s: np.ndarray = None, s_len: int = None
):
    """Binary convolution implemented in numba.

    For every nonzero entry of ``s`` a copy of ``coef`` is added to the
    output starting at that entry's position. The value of the entry is not
    used as a weight, only whether it is nonzero.

    Parameters
    ----------
    coef : np.ndarray
        Kernel added at each event position; truncated at the end of the
        output.
    s : np.ndarray
        Event array; only its nonzero positions matter.
    nzidx_s : np.ndarray, optional
        Lookup table mapping positions in ``s`` to positions in the output.
    s_len : int, optional
        Output length; defaults to ``len(s)``.

    Returns
    -------
    np.ndarray
        Float array of length ``s_len``.
    """
    coef_len = len(coef)
    if s_len is None:
        s_len = len(s)
    out = np.zeros(s_len)
    nzidx = np.where(s)[0]
    if nzidx_s is not None:
        nzidx = nzidx_s[nzidx].astype(
            np.int64
        )  # astype to fix numba issues on GPU on Windows
    for i0 in nzidx:
        i1 = min(i0 + coef_len, s_len)
        clen = i1 - i0
        if clen <= 0:
            # Start position at or past the end of the output: nothing to
            # add (a negative ``clen`` would mis-slice ``coef``).
            continue
        out[i0:i1] += coef[:clen]
    return out


@njit(nogil=True, cache=True)
def max_consecutive(arr):
    """Find maximum consecutive ones.

    Returns the length of the longest run of consecutive truthy values in
    ``arr`` (0 when there is none).
    """
    max_count = 0
    current_count = 0
    for value in arr:
        if value:
            current_count += 1
            max_count = max(max_count, current_count)
        else:
            current_count = 0
    return max_count
import numpy as np

from indeca.core.deconv import DeconvBin
from indeca.core.simulation import tau2AR


def test_G_matrix_matches_shifted_ar_difference_equation():
    """
    Pin the layout of the legacy (T x T) "G" operator, which maps ``c`` to
    ``s = G @ c``:

    - row 0:        s[0] = c[1]
    - row 1:        s[1] = c[2] - theta0*c[1]
    - rows t >= 2:  s[t] = c[t+1] - theta0*c[t] - theta1*c[t-1]
    - last row:     all zeros, hence s[-1] == 0

    A failure here points at a dimensional/shift regression in `G_org`.
    """
    n_samples = 5
    ar_coefs = np.array(tau2AR(10.0, 3.0))

    model = DeconvBin(
        y_len=n_samples,
        theta=ar_coefs,
        coef_len=3,
        backend="osqp",
        free_kernel=False,
        use_base=False,
        norm="l2",
        penal=None,
    )

    # Work on the full, unmasked operator so the check is deterministic.
    model._reset_mask()
    G = model.solver.G_org
    assert G.shape == (n_samples, n_samples)

    # Seeded draw; c[0] is pinned to 0 to mirror the solver's boundary
    # condition.
    c = np.random.default_rng(0).normal(size=n_samples)
    c[0] = 0.0

    # Sparse @ dense already yields a dense ndarray; flatten the column.
    s = np.asarray(G @ c.reshape(-1, 1)).squeeze()

    # The bottom row of zeros forces the last output sample to 0.
    assert np.isclose(s[-1], 0.0)

    # Shifted AR-difference mapping, written out row by row.
    th0 = float(ar_coefs[0])
    th1 = float(ar_coefs[1])
    expected = np.array(
        [
            c[1],
            c[2] - th0 * c[1],
            c[3] - th0 * c[2] - th1 * c[1],
            c[4] - th0 * c[3] - th1 * c[2],
            0.0,
        ]
    )

    assert np.allclose(s, expected)