diff --git a/pynumdiff/total_variation_regularization.py b/pynumdiff/total_variation_regularization.py index a3f406e..1935d0f 100644 --- a/pynumdiff/total_variation_regularization.py +++ b/pynumdiff/total_variation_regularization.py @@ -52,7 +52,49 @@ def iterative_velocity(x, dt, params=None, options=None, num_iterations=None, ga return x_hat, dxdt_hat +#N-d case: +def tvrdiff(x, dt, order, gamma, huberM=float('inf'), solver=None, axis=0): + """ + Generalized total variation regularized derivatives (cvxpy). Supports multidimensionality by differentiating along + 'axis', independently for each vector obtained by fixing all other indices. + + :param np.array[float] x: data to differentiate + :param float dt: step size + :param int order: 1, 2, or 3, the derivative to regularize + :param float gamma: regularization parameter + :param float huberM: Huber loss parameter, in units of scaled median absolute deviation of input data. + :math:`M = \\infty` reduces to :math:`\\ell_2` loss squared on first, fidelity cost term, and + :math:`M = 0` reduces to :math:`\\ell_1` loss, which seeks sparse residuals. + :param str solver: Solver to use. Solver options include: 'MOSEK', 'CVXOPT', 'CLARABEL', 'ECOS'. + If not given, fall back to CVXPY's default. + :return: - **x_hat** (np.array) -- estimated (smoothed) x + - **dxdt_hat** (np.array) -- estimated derivative of x + """ + + x0 = np.moveaxis(x, axis, 0) + + # end quick if it's just 1d case + if x0.ndim == 1: + x_hat0, dxdt0 = tvrdiff(x0, dt, order, gamma, huberM, solver) + return x_hat0, dxdt0 + + x_hat0 = np.empty_like(x0, dtype=float) + dxdt0 = np.empty_like(x0, dtype=float) + rest = x0.shape[1:] + print(rest) + + # had to loop in python:( + for i in np.ndindex(rest): + slice = (slice(None),) + i + x_hat0[slice], dxdt0[slice] = tvrdiff(x0[slice], dt, order, gamma, huberM, solver) + + x_hat = np.moveaxis(x_hat0, 0, axis) + dxdt_hat = np.moveaxis(dxdt0, 0, axis) + + return x_hat, dxdt_hat + +# 1-d case: def tvrdiff(x, dt, order, gamma, huberM=float('inf'), solver=None): """Generalized total variation regularized derivatives. Use convex optimization (cvxpy) to solve for a total variation regularized derivative. Other convex-solver-based methods in this module call this function. @@ -70,6 +112,7 @@ def tvrdiff(x, dt, order, gamma, huberM=float('inf'), solver=None): :return: - **x_hat** (np.array) -- estimated (smoothed) x - **dxdt_hat** (np.array) -- estimated derivative of x """ + # Normalize for numerical consistency with convex solver mu = np.mean(x) sigma = median_abs_deviation(x, scale='normal') # robust alternative to std()