From 354243de7a1cd395fd3d119ccd4b539566257ec2 Mon Sep 17 00:00:00 2001 From: Jan Date: Thu, 5 Mar 2026 13:59:28 +0100 Subject: [PATCH 01/16] wip: first shot lsmr in torch + tests. --- pyfixest/estimation/torch/__init__.py | 0 pyfixest/estimation/torch/demean_torch_.py | 318 ++++++++++++ pyfixest/estimation/torch/lsmr_torch.py | 280 +++++++++++ tests/test_torch_lsmr.py | 535 +++++++++++++++++++++ 4 files changed, 1133 insertions(+) create mode 100644 pyfixest/estimation/torch/__init__.py create mode 100644 pyfixest/estimation/torch/demean_torch_.py create mode 100644 pyfixest/estimation/torch/lsmr_torch.py create mode 100644 tests/test_torch_lsmr.py diff --git a/pyfixest/estimation/torch/__init__.py b/pyfixest/estimation/torch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pyfixest/estimation/torch/demean_torch_.py b/pyfixest/estimation/torch/demean_torch_.py new file mode 100644 index 000000000..30fc8f5a0 --- /dev/null +++ b/pyfixest/estimation/torch/demean_torch_.py @@ -0,0 +1,318 @@ +""" +FWL demeaning via LSMR in pure PyTorch. + +Builds a sparse dummy matrix D directly from integer-encoded fixed effects +(no formulaic/pandas dependency), then solves D @ theta = x via LSMR per column. +Diagonal preconditioning normalizes column norms when group sizes vary. + +Sparse format is chosen per device: CSR on CUDA/CPU (cuSPARSE / native), +COO on MPS (Metal does not support sparse CSR). +""" + +from __future__ import annotations + +import warnings + +import numpy as np +import torch +from numpy.typing import NDArray + +from pyfixest.estimation.torch.lsmr_torch import lsmr_torch + + +def _get_device(dtype: torch.dtype = torch.float64) -> torch.device: + """Auto-detect best available device: CUDA > MPS > CPU. + + MPS does not support float64, so we fall back to CPU when float64 is needed. + When MPS is available but dtype is float64, a hint is issued to use float32. 
+ """ + if torch.cuda.is_available(): + return torch.device("cuda") + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + if dtype != torch.float64: + return torch.device("mps") + warnings.warn( + "MPS GPU is available but requires float32. " + "Pass `dtype=torch.float32` to `demean_torch` for GPU acceleration. " + "Falling back to CPU.", + UserWarning, + stacklevel=3, + ) + return torch.device("cpu") + warnings.warn( + "No GPU available — torch demeaning will run on CPU, which is slower " + "than the scipy backend. Consider using `demean_scipy` instead.", + UserWarning, + stacklevel=3, + ) + return torch.device("cpu") + + +def _build_sparse_dummy( + flist: NDArray[np.uint64], + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + """ + Build sparse dummy matrix D from integer-encoded FE array. + + For n_factors factors, D has shape (N, total_groups) where + total_groups = sum of unique groups per factor. Columns are stacked: + [factor0_groups | factor1_groups | ...]. + + No reference-level dropping — LSMR finds the min-norm solution, + which gives correct residuals for rank-deficient systems. + + Returns COO on MPS (Metal has no CSR kernels), CSR otherwise. + + Parameters + ---------- + flist : np.ndarray, shape (N, n_factors), dtype uint64 + Integer-encoded fixed effects. Must be contiguous 0-based integers + per factor (i.e., values in [0, n_groups_j) with no gaps). + device : torch.device + Target device. + dtype : torch.dtype + Value dtype (e.g. torch.float64). + + Returns + ------- + D : torch.Tensor + Sparse tensor of shape (N, total_groups). COO on MPS, CSR otherwise. + + Raises + ------ + ValueError + If any factor has non-contiguous group IDs (gaps in the integer encoding). 
+ """ + N, n_factors = flist.shape + + row_indices = [] + col_indices = [] + col_offset = 0 + + for j in range(n_factors): + col_j = flist[:, j] + unique_vals = np.unique(col_j) + n_unique_j = len(unique_vals) + n_groups_j = int(unique_vals[-1]) + 1 + + if n_groups_j != n_unique_j: + raise ValueError( + f"Factor {j} has non-contiguous group IDs: " + f"max ID is {n_groups_j - 1} but only {n_unique_j} unique values found. " + f"Re-encode to contiguous 0-based integers." + ) + + rows = np.arange(N, dtype=np.int64) + cols = col_j.astype(np.int64) + col_offset + + row_indices.append(rows) + col_indices.append(cols) + col_offset += n_groups_j + + all_rows = np.concatenate(row_indices) + all_cols = np.concatenate(col_indices) + + indices = torch.tensor( + np.stack([all_rows, all_cols]), dtype=torch.long, device=device + ) + values = torch.ones(len(all_rows), dtype=dtype, device=device) + + D_coo = torch.sparse_coo_tensor(indices, values, size=(N, col_offset)) + + if device.type == "mps": + return D_coo.coalesce() + return D_coo.to_sparse_csr() + + +def _scale_sparse_rows(D: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """Scale rows of a sparse matrix by a dense vector (row-wise multiply). + + Dispatches to the appropriate implementation based on sparse layout: + CSR on CUDA/CPU, COO on MPS. + """ + if D.layout == torch.sparse_csr: + return _scale_csr_rows(D, scale) + return _scale_coo_rows(D, scale) + + +def _scale_csr_rows(D: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """Scale rows of a sparse CSR matrix. + + Operates directly on CSR internal arrays, avoiding COO format roundtrip. + Uses repeat_interleave to expand per-row scales to per-nonzero scales. 
+ """ + crow = D.crow_indices() + col = D.col_indices() + row_counts = crow[1:] - crow[:-1] + val = D.values() * torch.repeat_interleave(scale, row_counts) + + return torch.sparse_csr_tensor( + crow, col, val, D.shape, dtype=D.dtype, device=D.device + ) + + +def _scale_coo_rows(D: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """Scale rows of a sparse COO matrix. + + Indexes into the scale vector using COO row indices to expand + per-row scales to per-nonzero scales. + """ + d_indices = D.indices() + new_values = D.values() * scale[d_indices[0]] + return torch.sparse_coo_tensor( + d_indices, new_values, D.shape, device=D.device + ).coalesce() + + +def demean_torch( + x: NDArray[np.float64], + flist: NDArray[np.uint64] | None = None, + weights: NDArray[np.float64] | None = None, + tol: float = 1e-8, + maxiter: int = 100_000, + dtype: torch.dtype = torch.float64, +) -> tuple[NDArray[np.float64], bool]: + """ + Demean x by projecting out fixed effects via FWL + LSMR. + + Parameters + ---------- + x : np.ndarray, shape (N,) or (N, K) + Variables to demean. + flist : np.ndarray, shape (N, n_factors), dtype uint64 + Integer-encoded fixed effects. + weights : np.ndarray, shape (N,) + Observation weights (1.0 for equal weighting). + tol : float + Convergence tolerance for LSMR (used as both atol and btol). + maxiter : int + Maximum LSMR iterations per column. + dtype : torch.dtype + Tensor dtype. Use ``torch.float32`` to enable MPS (Apple GPU) + acceleration. Default ``torch.float64`` for full precision. + + Returns + ------- + x_demeaned : np.ndarray + Residuals after projecting out fixed effects. Same shape as input x. + success : bool + True if LSMR converged for all columns. 
+ """ + if flist is None: + raise ValueError("flist cannot be None") + if weights is None: + weights = np.ones(x.shape[0], dtype=np.float64) + + # Track original shape to restore on output + was_1d = x.ndim == 1 + x_2d = x[:, None] if was_1d else x + weights_1d = weights.ravel() + + device = _get_device(dtype) + + # Move to torch — use from_numpy for zero-copy when staying on CPU + x_c = x_2d if x_2d.flags["C_CONTIGUOUS"] else np.ascontiguousarray(x_2d) + w_c = weights_1d if weights_1d.flags["C_CONTIGUOUS"] else np.ascontiguousarray(weights_1d) + x_t = torch.from_numpy(x_c).to(dtype=dtype, device=device) + w_t = torch.from_numpy(w_c).to(dtype=dtype, device=device) + + # Ensure flist is 2D + flist_2d = flist if flist.ndim == 2 else flist[:, None] + + # Build sparse dummy matrix (unweighted) + D_unweighted = _build_sparse_dummy(flist_2d, device, dtype) + _, D_cols = D_unweighted.shape + K = x_t.shape[1] + + # Apply sqrt-weights + sqrt_w = torch.sqrt(w_t) + x_w = x_t * sqrt_w[:, None] + + # Weight the dummy matrix: D_weighted = diag(sqrt_w) @ D_unweighted + D_weighted = _scale_sparse_rows(D_unweighted, sqrt_w) + + # Diagonal preconditioning: M_inv = 1 / sqrt(D_unweighted^T @ w) + # D_unweighted^T @ w gives the sum of weights per group + group_weights = D_unweighted.t() @ w_t + + # Guard against zero-weight groups (would produce inf in M_inv) + zero_weight_mask = group_weights <= 0.0 + if zero_weight_mask.any(): + bad_groups = zero_weight_mask.nonzero(as_tuple=False).squeeze(-1).tolist() + raise ValueError( + f"Fixed effect groups {bad_groups} have zero total weight. " + f"Check your weights or fixed effect encoding." 
+ ) + + M_inv = 1.0 / torch.sqrt(group_weights) + + # Build preconditioned operator once (not per column) + A_precond = _PreconditionedSparse(D_weighted, M_inv) + + # Solve for each column + theta = torch.zeros(D_cols, K, dtype=dtype, device=device) + success = True + + for k in range(K): + z, istop, itn, normr, normar, normA, condA, normx = lsmr_torch( + A_precond, + x_w[:, k], + damp=0.0, + atol=tol, + btol=tol, + maxiter=maxiter, + ) + + # Recover theta from preconditioned solution: theta = M_inv * z + theta[:, k] = M_inv * z + success = success and (istop in (1, 2, 3)) + + # Compute residuals: x_demeaned = x - D_unweighted @ theta + x_demeaned = x_t - D_unweighted @ theta + + result = x_demeaned.cpu().numpy() + if was_1d: + result = result[:, 0] + + return result, success + + +class _PreconditionedSparse: + """ + Wraps a sparse matrix D and diagonal preconditioner M_inv + to present A_precond = D @ diag(M_inv) for LSMR. + + This avoids forming the preconditioned matrix explicitly — + just element-wise multiply before/after matvec. + + The transpose view is cached and returned by `.t()`, so LSMR's + repeated `A.t().mv(u)` calls don't allocate a new object each time. 
+ """ + + def __init__( + self, D: torch.Tensor, M_inv: torch.Tensor, *, _transposed: bool = False + ): + m, n = D.shape + self.shape = (n, m) if _transposed else (m, n) + self._D = D + self._M_inv = M_inv + self._transposed = _transposed + self._T: _PreconditionedSparse | None = None + + def mv(self, v: torch.Tensor) -> torch.Tensor: + if self._transposed: + # Compute M_inv * (D^T @ u) + return self._M_inv * (self._D.t() @ v) + # Compute D @ (M_inv * v) + return self._D @ (self._M_inv * v) + + def t(self) -> _PreconditionedSparse: + """Return cached transpose view.""" + if self._T is None: + self._T = _PreconditionedSparse( + self._D, self._M_inv, _transposed=not self._transposed + ) + self._T._T = self # cross-link so .t().t() returns self + return self._T diff --git a/pyfixest/estimation/torch/lsmr_torch.py b/pyfixest/estimation/torch/lsmr_torch.py new file mode 100644 index 000000000..a4172eb6f --- /dev/null +++ b/pyfixest/estimation/torch/lsmr_torch.py @@ -0,0 +1,280 @@ +""" +Pure PyTorch implementation of the LSMR algorithm. + +Ported from SciPy's `scipy.sparse.linalg.lsmr` (Fong & Saunders, 2011). +All vector operations use torch tensors (staying on-device for GPU), +while scalar Givens rotations use Python `math` to avoid autograd overhead. + +Reference: + D. C.-L. Fong and M. A. Saunders, + "LSMR: An iterative algorithm for sparse least-squares problems", + SIAM J. Sci. Comput., vol. 33, pp. 2950-2971, 2011. +""" + +from __future__ import annotations + +import math + +import torch + + +def _sym_ortho(a: float, b: float) -> tuple[float, float, float]: + """ + Stable Givens rotation (SymOrtho). + + Given scalars a and b, compute c, s, r such that: + [ c s ] [ a ] = [ r ] + [-s c ] [ b ] [ 0 ] + + This is the same algorithm as SciPy's `_sym_ortho` from LSQR, + using pure Python math for speed on scalar values. 
+ """ + if b == 0.0: + # math.copysign(1, 0) = 1.0 but np.sign(0) = 0.0; match SciPy's behavior + c = 0.0 if a == 0.0 else math.copysign(1.0, a) + return c, 0.0, abs(a) + elif a == 0.0: + return 0.0, math.copysign(1.0, b), abs(b) + elif abs(b) > abs(a): + tau = a / b + s = math.copysign(1.0, b) / math.sqrt(1.0 + tau * tau) + c = s * tau + r = b / s + else: + tau = b / a + c = math.copysign(1.0, a) / math.sqrt(1.0 + tau * tau) + s = c * tau + r = a / c + return c, s, r + + +def _matvec(A, v: torch.Tensor) -> torch.Tensor: + """A @ v — works for both torch.Tensor (dense/sparse) and duck-typed wrappers.""" + if isinstance(A, torch.Tensor): + return A @ v + return A.mv(v) + + +def _rmatvec(A, u: torch.Tensor) -> torch.Tensor: + """A^T @ u — works for both torch.Tensor (dense/sparse) and duck-typed wrappers.""" + if isinstance(A, torch.Tensor): + return A.t() @ u + return A.t().mv(u) + + +def lsmr_torch( + A, + b: torch.Tensor, + damp: float = 0.0, + atol: float = 1e-8, + btol: float = 1e-8, + conlim: float = 1e8, + maxiter: int | None = None, +) -> tuple[torch.Tensor, int, int, float, float, float, float, float]: + """ + LSMR iterative solver for sparse least-squares problems, in pure PyTorch. + + Solves ``min ||b - Ax||_2`` (or the damped variant) where A is a sparse + CSR (COO) tensor and b is a dense vector. All vector ops stay on the tensor's + device (CPU/CUDA/MPS). + + Parameters + ---------- + A : torch.Tensor + Sparse CSR tensor of shape (m, n). + b : torch.Tensor + Dense vector of shape (m,). + damp : float + Damping factor for regularized least-squares. + atol, btol : float + Stopping tolerances (see SciPy LSMR docs). + conlim : float + Condition number limit. + maxiter : int or None + Maximum iterations. Defaults to min(m, n). + + Returns + ------- + x : torch.Tensor + Solution vector of shape (n,). + istop : int + Reason for stopping (0-7, same codes as SciPy LSMR). + itn : int + Number of iterations used. 
+ normr : float + ``||b - Ax||`` + normar : float + ``||A^T(b - Ax)||`` + normA : float + Estimate of Frobenius norm of A. + condA : float + Estimate of condition number of A. + normx : float + ``||x||`` + """ + m, n = A.shape + device = b.device + dtype = b.dtype + + if maxiter is None: + maxiter = min(m, n) + + # --- Initialize Golub-Kahan bidiagonalization --- + u = b.clone() + normb = torch.linalg.norm(b).item() + + x = torch.zeros(n, device=device, dtype=dtype) + beta = normb + + if beta > 0: + u = u * (1.0 / beta) + v = _rmatvec(A, u) + alpha = torch.linalg.norm(v).item() + else: + v = torch.zeros(n, device=device, dtype=dtype) + alpha = 0.0 + + if alpha > 0: + v = v * (1.0 / alpha) + + # --- Scalar state for iteration --- + itn = 0 + zetabar = alpha * beta + alphabar = alpha + rho = 1.0 + rhobar = 1.0 + cbar = 1.0 + sbar = 0.0 + + h = v.clone() + hbar = torch.zeros(n, device=device, dtype=dtype) + + # Estimation of ||r|| + betadd = beta + betad = 0.0 + rhodold = 1.0 + tautildeold = 0.0 + thetatilde = 0.0 + zeta = 0.0 + d = 0.0 + + # Estimation of ||A|| and cond(A) + normA2 = alpha * alpha + maxrbar = 0.0 + minrbar = 1e100 + normA = math.sqrt(normA2) + condA = 1.0 + normx = 0.0 + + # Stopping + istop = 0 + ctol = 1.0 / conlim if conlim > 0 else 0.0 + normr = beta + normar = alpha * beta + + if normar == 0.0: + return x, istop, itn, normr, normar, normA, condA, normx + + if normb == 0.0: + x.zero_() + return x, istop, itn, normr, normar, normA, condA, normx + + # --- Main iteration loop --- + while itn < maxiter: + itn += 1 + + # Bidiagonalization step: get next beta, u, alpha, v + u = _matvec(A, v) - alpha * u + beta = torch.linalg.norm(u).item() + + if beta > 0: + u *= 1.0 / beta + v = _rmatvec(A, u) - beta * v + alpha = torch.linalg.norm(v).item() + if alpha > 0: + v *= 1.0 / alpha + + # Construct rotation Qhat_{k,2k+1} + chat, shat, alphahat = _sym_ortho(alphabar, damp) + + # Use plane rotation Q_i to turn B_i to R_i + rhoold = rho + c, s, rho = 
_sym_ortho(alphahat, beta) + thetanew = s * alpha + alphabar = c * alpha + + # Use plane rotation Qbar_i to turn R_i^T to R_i^bar + rhobarold = rhobar + zetaold = zeta + thetabar = sbar * rho + rhotemp = cbar * rho + cbar, sbar, rhobar = _sym_ortho(cbar * rho, thetanew) + zeta = cbar * zetabar + zetabar = -sbar * zetabar + + # Update h, hbar, x (vector ops — stay on device) + hbar = h + hbar * (-(thetabar * rho) / (rhoold * rhobarold)) + x = x + (zeta / (rho * rhobar)) * hbar + h = v + h * (-(thetanew / rho)) + + # Estimate ||r|| + betaacute = chat * betadd + betacheck = -shat * betadd + + betahat = c * betaacute + betadd = -s * betaacute + + thetatildeold = thetatilde + ctildeold, stildeold, rhotildeold = _sym_ortho(rhodold, thetabar) + thetatilde = stildeold * rhobar + rhodold = ctildeold * rhobar + betad = -stildeold * betad + ctildeold * betahat + + tautildeold = (zetaold - thetatildeold * tautildeold) / rhotildeold + taud = (zeta - thetatilde * tautildeold) / rhodold + d = d + betacheck * betacheck + normr = math.sqrt(d + (betad - taud) ** 2 + betadd * betadd) + + # Estimate ||A|| + normA2 = normA2 + beta * beta + normA = math.sqrt(normA2) + normA2 = normA2 + alpha * alpha + + # Estimate cond(A) + maxrbar = max(maxrbar, rhobarold) + if itn > 1: + minrbar = min(minrbar, rhobarold) + condA = max(maxrbar, rhotemp) / min(minrbar, rhotemp) + + # Convergence tests + normar = abs(zetabar) + normx = torch.linalg.norm(x).item() + + test1 = normr / normb + if (normA * normr) != 0: + test2 = normar / (normA * normr) + else: + test2 = float("inf") + test3 = 1.0 / condA + t1 = test1 / (1.0 + normA * normx / normb) + rtol = btol + atol * normA * normx / normb + + if itn >= maxiter: + istop = 7 + if 1.0 + test3 <= 1.0: + istop = 6 + if 1.0 + test2 <= 1.0: + istop = 5 + if 1.0 + t1 <= 1.0: + istop = 4 + if test3 <= ctol: + istop = 3 + if test2 <= atol: + istop = 2 + if test1 <= rtol: + istop = 1 + + if istop > 0: + break + + return x, istop, itn, normr, normar, normA, condA, 
"""
Tests for standalone PyTorch LSMR solver and torch-based FWL demeaning.

Three test levels:
1. Bare LSMR: verify the solver on known linear systems
2. Demeaning vs pyhdfe: compare demean_torch against pyhdfe reference
3. Demeaning vs demean_scipy: compare against SciPy LSMR-based backend
"""

import numpy as np
import pyhdfe
import pytest

torch = pytest.importorskip("torch")

# Imports below must come after importorskip, hence the E402 suppressions.
from pyfixest.estimation.cupy.demean_cupy_ import demean_scipy  # noqa: E402
from pyfixest.estimation.torch.demean_torch_ import demean_torch  # noqa: E402
from pyfixest.estimation.torch.lsmr_torch import lsmr_torch  # noqa: E402


# ---------------------------------------------------------------------------
# Shared fixtures
# ---------------------------------------------------------------------------


@pytest.fixture(scope="module")
def demean_data():
    """Shared data fixture for all demeaning tests (matches test_demean.py pattern)."""
    rng = np.random.default_rng(929291)
    N = 1000
    M = 10
    x = rng.normal(0, 1, M * N).reshape((N, M))
    f1 = rng.choice(list(range(M)), N).reshape((N, 1))
    f2 = rng.choice(list(range(M)), N).reshape((N, 1))
    flist = np.concatenate((f1, f2), axis=1).astype(np.uint64)
    # Weights drawn from the *same* advanced RNG (not a fresh seed)
    weights = rng.uniform(0, 1, N)
    return x, flist, weights


# ---------------------------------------------------------------------------
# Level 1: Bare LSMR tests
# ---------------------------------------------------------------------------


class TestLSMR:
    """Unit tests for the pure-torch LSMR solver."""

    def test_overdetermined_known_solution(self):
        """Overdetermined system with exact solution: LSMR should recover it."""
        A_dense = torch.tensor(
            [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]], dtype=torch.float64
        )
        x_true = torch.tensor([1.0, -1.0], dtype=torch.float64)
        b = A_dense @ x_true

        A_sparse = A_dense.to_sparse_csr()
        x_sol, istop, itn, normr, normar, *_ = lsmr_torch(A_sparse, b)

        assert istop in (1, 2), f"LSMR did not converge, istop={istop}"
        assert torch.allclose(x_sol, x_true, atol=1e-10), (
            f"Solution mismatch: {x_sol} vs {x_true}"
        )
        assert normr < 1e-10, f"Residual too large: {normr}"

    def test_underdetermined_min_norm(self):
        """Underdetermined system (m < n): LSMR should find min-norm solution."""
        torch.manual_seed(42)
        A_dense = torch.randn(2, 4, dtype=torch.float64)
        b = torch.randn(2, dtype=torch.float64)

        A_sparse = A_dense.to_sparse_csr()
        x_sol, istop, itn, normr, normar, *_ = lsmr_torch(A_sparse, b)

        residual = torch.norm(A_dense @ x_sol - b).item()
        assert residual < 1e-8, f"Residual too large: {residual}"

        # lstsq returns the minimum-norm solution for underdetermined systems,
        # so matching it confirms LSMR's min-norm property.
        x_lstsq = torch.linalg.lstsq(A_dense, b).solution
        assert torch.allclose(x_sol, x_lstsq, atol=1e-6), (
            f"Not min-norm: ||x_lsmr||={torch.norm(x_sol):.6f}, "
            f"||x_lstsq||={torch.norm(x_lstsq):.6f}"
        )

    def test_larger_sparse_system(self):
        """Larger sparse system to exercise the iteration loop."""
        rng = np.random.default_rng(123)
        m, n = 200, 50
        density = 0.1

        nnz = int(m * n * density)
        rows = rng.integers(0, m, nnz)
        cols = rng.integers(0, n, nnz)
        vals = rng.standard_normal(nnz)

        indices = torch.tensor(np.stack([rows, cols]), dtype=torch.long)
        values = torch.tensor(vals, dtype=torch.float64)
        A_sparse = torch.sparse_coo_tensor(indices, values, (m, n)).to_sparse_csr()

        x_true = torch.tensor(rng.standard_normal(n), dtype=torch.float64)
        A_dense = A_sparse.to_dense()
        b = A_dense @ x_true

        x_sol, istop, itn, normr, normar, *_ = lsmr_torch(A_sparse, b)

        assert istop in (1, 2, 3), f"LSMR did not converge, istop={istop}"
        residual = torch.norm(A_dense @ x_sol - b).item()
        assert residual < 1e-5, f"Residual too large: {residual}"

    def test_zero_rhs(self):
        """b = 0 should give x = 0."""
        A_sparse = torch.eye(3, dtype=torch.float64).to_sparse_csr()
        b = torch.zeros(3, dtype=torch.float64)

        x_sol, istop, itn, normr, normar, *_ = lsmr_torch(A_sparse, b)

        assert torch.allclose(x_sol, torch.zeros(3, dtype=torch.float64), atol=1e-12)
        assert istop == 0

    def test_damped_regularization(self):
        """Damping should shrink the solution toward zero."""
        A_sparse = torch.eye(3, dtype=torch.float64).to_sparse_csr()
        b = torch.ones(3, dtype=torch.float64)

        x_undamped, *_ = lsmr_torch(A_sparse, b, damp=0.0)
        x_damped, *_ = lsmr_torch(A_sparse, b, damp=10.0)

        assert torch.norm(x_damped) < torch.norm(x_undamped), (
            "Damped solution should have smaller norm"
        )

    def test_maxiter_exhaustion_returns_istop_7(self):
        """Forcing maxiter=2 on an ill-conditioned system must return istop=7."""
        # Use a system that genuinely needs many iterations
        # (identity converges in 1 step, so we use a harder matrix)
        torch.manual_seed(99)
        A_dense = torch.randn(20, 10, dtype=torch.float64)
        b = torch.randn(20, dtype=torch.float64)
        A_sparse = A_dense.to_sparse_csr()

        _, istop, itn, *_ = lsmr_torch(A_sparse, b, maxiter=2, atol=1e-15, btol=1e-15)

        assert istop == 7, f"Expected istop=7 (maxiter hit), got {istop}"
        assert itn == 2

    def test_full_return_tuple(self):
        """Verify all 8 return values are present and sensible."""
        A_sparse = torch.eye(3, dtype=torch.float64).to_sparse_csr()
        b = torch.ones(3, dtype=torch.float64)

        x, istop, itn, normr, normar, normA, condA, normx = lsmr_torch(A_sparse, b)

        assert isinstance(normA, float) and normA > 0
        assert isinstance(condA, float) and condA >= 1.0
        assert isinstance(normx, float) and normx > 0


# ---------------------------------------------------------------------------
# Level 2: Demeaning vs pyhdfe
# ---------------------------------------------------------------------------


class TestDemeanVsPyhdfe:
    """Compare demean_torch against pyhdfe reference."""

    def test_unweighted(self, demean_data):
        x, flist, _ = demean_data
        N = x.shape[0]
        weights = np.ones(N)

        algorithm = pyhdfe.create(flist)
        res_pyhdfe = algorithm.residualize(x)

        res_torch, success = demean_torch(x, flist, weights, tol=1e-10)
        assert success, "demean_torch did not converge"
        np.testing.assert_allclose(
            res_torch,
            res_pyhdfe,
            rtol=1e-6,
            atol=1e-8,
            err_msg="demean_torch vs pyhdfe mismatch (unweighted)",
        )

    def test_weighted(self, demean_data):
        x, flist, weights = demean_data
        N = x.shape[0]

        algorithm = pyhdfe.create(flist)
        res_pyhdfe = algorithm.residualize(x, weights.reshape(N, 1))

        res_torch, success = demean_torch(x, flist, weights, tol=1e-10)
        assert success, "demean_torch did not converge (weighted)"
        np.testing.assert_allclose(
            res_torch,
            res_pyhdfe,
            rtol=1e-6,
            atol=1e-8,
            err_msg="demean_torch vs pyhdfe mismatch (weighted)",
        )


# ---------------------------------------------------------------------------
# Level 3: Demeaning vs demean_scipy
# ---------------------------------------------------------------------------


class TestDemeanVsScipy:
    """Compare demean_torch against demean_scipy (both use LSMR)."""

    def test_unweighted_vs_scipy(self, demean_data):
        x, flist, _ = demean_data
        N = x.shape[0]
        weights = np.ones(N)

        res_scipy, success_scipy = demean_scipy(x, flist, weights, tol=1e-10)
        res_torch, success_torch = demean_torch(x, flist, weights, tol=1e-10)

        assert success_scipy, "demean_scipy did not converge"
        assert success_torch, "demean_torch did not converge"

        # Different D construction (formulaic drops reference level, we don't)
        # so tolerances are slightly looser than torch-vs-pyhdfe.
        # Both produce valid demeaned residuals, but the underlying theta
        # coefficients differ because the systems have different null spaces.
        np.testing.assert_allclose(
            res_torch,
            res_scipy,
            rtol=1e-5,
            atol=1e-7,
            err_msg="demean_torch vs demean_scipy mismatch (unweighted)",
        )

    def test_weighted_vs_scipy(self, demean_data):
        x, flist, weights = demean_data

        res_scipy, success_scipy = demean_scipy(x, flist, weights, tol=1e-10)
        res_torch, success_torch = demean_torch(x, flist, weights, tol=1e-10)

        assert success_scipy, "demean_scipy did not converge"
        assert success_torch, "demean_torch did not converge"

        np.testing.assert_allclose(
            res_torch,
            res_scipy,
            rtol=1e-5,
            atol=1e-7,
            err_msg="demean_torch vs demean_scipy mismatch (weighted)",
        )


# ---------------------------------------------------------------------------
# Level 4: Edge cases for demean_torch
# ---------------------------------------------------------------------------


class TestDemeanEdgeCases:
    """Edge case tests for demean_torch."""

    def test_1d_x_input(self):
        """1D x input should return 1D output with same shape."""
        rng = np.random.default_rng(42)
        N = 100
        x_1d = rng.normal(0, 1, N)
        flist = rng.choice(5, N).astype(np.uint64).reshape(N, 1)
        weights = np.ones(N)

        res, success = demean_torch(x_1d, flist, weights, tol=1e-10)
        assert success
        assert res.ndim == 1, f"Expected 1D output, got shape {res.shape}"
        assert res.shape == x_1d.shape

    def test_single_fe_factor(self):
        """Single fixed effect factor should work and match pyhdfe."""
        rng = np.random.default_rng(42)
        N = 200
        x = rng.normal(0, 1, (N, 3))
        flist = rng.choice(10, N).astype(np.uint64).reshape(N, 1)
        weights = np.ones(N)

        algorithm = pyhdfe.create(flist)
        res_pyhdfe = algorithm.residualize(x)

        res_torch, success = demean_torch(x, flist, weights, tol=1e-10)
        assert success
        np.testing.assert_allclose(res_torch, res_pyhdfe, rtol=1e-6, atol=1e-8)

    def test_three_fe_factors(self):
        """Three FE factors should work and match pyhdfe."""
        rng = np.random.default_rng(42)
        N = 500
        x = rng.normal(0, 1, (N, 2))
        f1 = rng.choice(8, N).reshape(N, 1)
        f2 = rng.choice(6, N).reshape(N, 1)
        f3 = rng.choice(4, N).reshape(N, 1)
        flist = np.concatenate([f1, f2, f3], axis=1).astype(np.uint64)
        weights = np.ones(N)

        algorithm = pyhdfe.create(flist)
        res_pyhdfe = algorithm.residualize(x)

        res_torch, success = demean_torch(x, flist, weights, tol=1e-10)
        assert success
        np.testing.assert_allclose(res_torch, res_pyhdfe, rtol=1e-6, atol=1e-8)

    def test_non_contiguous_group_ids_raises(self):
        """Non-contiguous group IDs (gaps) should raise ValueError."""
        N = 50
        x = np.ones((N, 1))
        # Groups [0, 5, 10] — non-contiguous, max+1=11 but only 3 unique
        flist = np.array([0, 5, 10] * (N // 3) + [0] * (N % 3), dtype=np.uint64).reshape(N, 1)
        weights = np.ones(N)

        with pytest.raises(ValueError, match="non-contiguous group IDs"):
            demean_torch(x, flist, weights)

    def test_zero_weight_group_raises(self):
        """If any group has zero total weight, should raise ValueError."""
        N = 50
        x = np.ones((N, 1))
        # All observations in group 0 get zero weight
        flist = np.zeros(N, dtype=np.uint64).reshape(N, 1)
        flist[N // 2 :, 0] = 1
        weights = np.ones(N)
        weights[: N // 2] = 0.0  # group 0 has zero total weight

        with pytest.raises(ValueError, match="zero total weight"):
            demean_torch(x, flist, weights)

    def test_flist_none_raises(self):
        """flist=None should raise ValueError."""
        with pytest.raises(ValueError, match="flist cannot be None"):
            demean_torch(np.ones((10, 2)), flist=None, weights=np.ones(10))


# ---------------------------------------------------------------------------
# Level 5: High-dimensional stress tests
# ---------------------------------------------------------------------------


class TestHighDimensional:
    """Stress tests with larger N, many groups, and unbalanced designs."""

    def test_many_groups(self):
        """N=10K, 200 groups per factor — exercises LSMR on a larger system."""
        rng = np.random.default_rng(7777)
        N = 10_000
        G = 200
        x = rng.normal(0, 1, (N, 3))
        f1 = rng.choice(G, N).reshape(N, 1)
        f2 = rng.choice(G, N).reshape(N, 1)
        flist = np.concatenate([f1, f2], axis=1).astype(np.uint64)
        weights = np.ones(N)

        algorithm = pyhdfe.create(flist)
        res_pyhdfe = algorithm.residualize(x)

        res_torch, success = demean_torch(x, flist, weights, tol=1e-10)
        assert success, "demean_torch did not converge on high-D problem"
        np.testing.assert_allclose(res_torch, res_pyhdfe, rtol=1e-5, atol=1e-7)

    def test_unbalanced_groups(self):
        """Highly unbalanced groups: one large group, many tiny groups.

        This stresses the diagonal preconditioner because group_weights
        vary by orders of magnitude.
        """
        rng = np.random.default_rng(8888)
        N = 5_000
        # Group 0 gets 80% of observations, groups 1-49 share the rest
        groups = np.zeros(N, dtype=np.uint64)
        small_start = int(N * 0.8)
        groups[small_start:] = rng.choice(49, N - small_start).astype(np.uint64) + 1
        flist = groups.reshape(N, 1)

        x = rng.normal(0, 1, (N, 2))
        weights = np.ones(N)

        algorithm = pyhdfe.create(flist)
        res_pyhdfe = algorithm.residualize(x)

        res_torch, success = demean_torch(x, flist, weights, tol=1e-10)
        assert success, "demean_torch did not converge on unbalanced groups"
        np.testing.assert_allclose(res_torch, res_pyhdfe, rtol=1e-5, atol=1e-7)

    def test_many_columns(self):
        """Many dependent variables (K=50) — checks per-column loop scales."""
        rng = np.random.default_rng(9999)
        N = 2_000
        K = 50
        x = rng.normal(0, 1, (N, K))
        flist = rng.choice(20, (N, 2)).astype(np.uint64)
        weights = np.ones(N)

        algorithm = pyhdfe.create(flist)
        res_pyhdfe = algorithm.residualize(x)

        res_torch, success = demean_torch(x, flist, weights, tol=1e-10)
        assert success
        np.testing.assert_allclose(res_torch, res_pyhdfe, rtol=1e-5, atol=1e-7)

    def test_weighted_high_d(self):
        """Weighted demeaning with many groups — weighted preconditioner stress test."""
        rng = np.random.default_rng(1234)
        N = 10_000
        G = 100
        x = rng.normal(0, 1, (N, 5))
        flist = rng.choice(G, (N, 2)).astype(np.uint64)
        # Highly skewed weights (log-normal)
        weights = rng.lognormal(0, 2, N)

        algorithm = pyhdfe.create(flist)
        res_pyhdfe = algorithm.residualize(x, weights.reshape(N, 1))

        res_torch, success = demean_torch(x, flist, weights, tol=1e-10)
        assert success, "demean_torch did not converge with skewed weights"
        np.testing.assert_allclose(res_torch, res_pyhdfe, rtol=1e-5, atol=1e-7)


# ---------------------------------------------------------------------------
# Level 6: float32 / MPS tests
# ---------------------------------------------------------------------------

HAS_MPS = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()


class TestFloat32:
    """Tests for float32 dtype path (COO on MPS, COO-converted-to-CSR on CPU)."""

    def test_cpu_float32_vs_pyhdfe(self, demean_data):
        """float32 on CPU should match pyhdfe within single-precision tolerance."""
        x, flist, _ = demean_data
        N = x.shape[0]
        weights = np.ones(N)

        algorithm = pyhdfe.create(flist)
        res_pyhdfe = algorithm.residualize(x)

        res_torch, success = demean_torch(
            x, flist, weights, tol=1e-5, dtype=torch.float32
        )
        assert success, "demean_torch (f32) did not converge"
        np.testing.assert_allclose(
            res_torch,
            res_pyhdfe,
            rtol=1e-3,
            atol=1e-3,
            err_msg="demean_torch (f32 CPU) vs pyhdfe mismatch",
        )

    def test_cpu_float32_weighted(self, demean_data):
        """Weighted float32 on CPU should match pyhdfe."""
        x, flist, weights = demean_data
        N = x.shape[0]

        algorithm = pyhdfe.create(flist)
        res_pyhdfe = algorithm.residualize(x, weights.reshape(N, 1))

        res_torch, success = demean_torch(
            x, flist, weights, tol=1e-5, dtype=torch.float32
        )
        assert success, "demean_torch (f32 weighted) did not converge"
        np.testing.assert_allclose(
            res_torch,
            res_pyhdfe,
            rtol=1e-3,
            atol=1e-3,
            err_msg="demean_torch (f32 CPU weighted) vs pyhdfe mismatch",
        )

    @pytest.mark.skipif(not HAS_MPS, reason="MPS not available")
    def test_mps_float32_vs_pyhdfe(self, demean_data):
        """float32 on MPS should match pyhdfe within single-precision tolerance."""
        x, flist, _ = demean_data
        N = x.shape[0]
        weights = np.ones(N)

        algorithm = pyhdfe.create(flist)
        res_pyhdfe = algorithm.residualize(x)

        res_torch, success = demean_torch(
            x, flist, weights, tol=1e-5, dtype=torch.float32
        )
        assert success, "demean_torch (MPS f32) did not converge"
        np.testing.assert_allclose(
            res_torch,
            res_pyhdfe,
            rtol=1e-3,
            atol=1e-3,
            err_msg="demean_torch (MPS f32) vs pyhdfe mismatch",
        )

    @pytest.mark.skipif(not HAS_MPS, reason="MPS not available")
    def test_mps_float32_weighted(self, demean_data):
        """Weighted float32 on MPS should match pyhdfe."""
        x, flist, weights = demean_data
        N = x.shape[0]

        algorithm = pyhdfe.create(flist)
        res_pyhdfe = algorithm.residualize(x, weights.reshape(N, 1))

        res_torch, success = demean_torch(
            x, flist, weights, tol=1e-5, dtype=torch.float32
        )
        assert success, "demean_torch (MPS f32 weighted) did not converge"
        np.testing.assert_allclose(
            res_torch,
            res_pyhdfe,
            rtol=1e-3,
            atol=1e-3,
            err_msg="demean_torch (MPS f32 weighted) vs pyhdfe mismatch",
        )

    @pytest.mark.skipif(not HAS_MPS, reason="MPS not available")
    def test_mps_float32_high_d(self):
        """MPS float32 with many groups — exercises COO path at scale."""
        rng = np.random.default_rng(5555)
        N = 10_000
        G = 100
        x = rng.normal(0, 1, (N, 3))
        flist = rng.choice(G, (N, 2)).astype(np.uint64)
        weights = np.ones(N)

        algorithm = pyhdfe.create(flist)
        res_pyhdfe = algorithm.residualize(x)

        res_torch, success = demean_torch(
            x, flist, weights, tol=1e-5, dtype=torch.float32
        )
        assert success, "demean_torch (MPS f32 high-D) did not converge"
        # NOTE(review): the source chunk ends mid-call here; the err_msg tail
        # below is reconstructed to match the pattern of the sibling tests.
        np.testing.assert_allclose(
            res_torch,
            res_pyhdfe,
            rtol=1e-3,
            atol=1e-3,
            err_msg="demean_torch (MPS f32 high-D) vs pyhdfe mismatch",
        )
f32 high-D) vs pyhdfe mismatch", + ) From 1e2dd3bb553a93ba63206d7fc64feacdef1807ea Mon Sep 17 00:00:00 2001 From: Jan Date: Thu, 5 Mar 2026 18:54:08 +0100 Subject: [PATCH 02/16] wip: new benchmark integration. awainting local review. --- benchmarks/benchmarks.py | 51 +++++++- benchmarks/config.json | 86 +++++++++++++ pyfixest/estimation/internals/backends.py | 53 ++++++++ pyfixest/estimation/internals/literals.py | 14 ++- pyfixest/estimation/torch/__init__.py | 15 +++ pyfixest/estimation/torch/demean_torch_.py | 140 ++++++++++++++++----- 6 files changed, 326 insertions(+), 33 deletions(-) create mode 100644 benchmarks/config.json diff --git a/benchmarks/benchmarks.py b/benchmarks/benchmarks.py index 3700ef434..eb65530a3 100644 --- a/benchmarks/benchmarks.py +++ b/benchmarks/benchmarks.py @@ -16,6 +16,26 @@ import pandas as pd +# Optional JAX availability detection +try: + import jax # noqa: F401 + + HAS_JAX = True +except ImportError: + HAS_JAX = False + +# Optional torch availability detection +try: + import torch + + HAS_TORCH = True + HAS_MPS = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() + HAS_CUDA = torch.cuda.is_available() +except ImportError: + HAS_TORCH = False + HAS_MPS = False + HAS_CUDA = False + # ============================================================================= # Estimator functions (run in main process for JIT caching) # ============================================================================= @@ -220,6 +240,17 @@ def get_estimators( False, "pyfixest_feols", ), + ] + if HAS_JAX: + estimators.append(("pyfixest.feols (jax)", "jax", run_pyfixest_feols, False, "pyfixest_feols")) + if HAS_TORCH: + estimators.append(("pyfixest.feols (torch_cpu)", "torch_cpu", run_pyfixest_feols, False, "pyfixest_feols")) + if HAS_MPS: + estimators.append(("pyfixest.feols (torch_mps)", "torch_mps", run_pyfixest_feols, False, "pyfixest_feols")) + if HAS_CUDA: + estimators.append(("pyfixest.feols (torch_cuda)", "torch_cuda", 
run_pyfixest_feols, False, "pyfixest_feols")) + estimators.append(("pyfixest.feols (torch_cuda32)", "torch_cuda32", run_pyfixest_feols, False, "pyfixest_feols")) + estimators += [ ( "linearmodels.AbsorbingLS", "absorbingls", @@ -263,6 +294,15 @@ def get_estimators( "pyfixest_fepois", ), ] + if HAS_JAX: + estimators.append(("pyfixest.fepois (jax)", "jax", run_pyfixest_fepois, False, "pyfixest_fepois")) + if HAS_TORCH: + estimators.append(("pyfixest.fepois (torch_cpu)", "torch_cpu", run_pyfixest_fepois, False, "pyfixest_fepois")) + if HAS_MPS: + estimators.append(("pyfixest.fepois (torch_mps)", "torch_mps", run_pyfixest_fepois, False, "pyfixest_fepois")) + if HAS_CUDA: + estimators.append(("pyfixest.fepois (torch_cuda)", "torch_cuda", run_pyfixest_fepois, False, "pyfixest_fepois")) + estimators.append(("pyfixest.fepois (torch_cuda32)", "torch_cuda32", run_pyfixest_fepois, False, "pyfixest_fepois")) formulas = { 2: "negbin_y ~ x1 | indiv_id + year", 3: "negbin_y ~ x1 | indiv_id + year + firm_id", @@ -291,6 +331,15 @@ def get_estimators( "pyfixest_feglm_logit", ), ] + if HAS_JAX: + estimators.append(("pyfixest.feglm_logit (jax)", "jax", run_pyfixest_feglm_logit, False, "pyfixest_feglm_logit")) + if HAS_TORCH: + estimators.append(("pyfixest.feglm_logit (torch_cpu)", "torch_cpu", run_pyfixest_feglm_logit, False, "pyfixest_feglm_logit")) + if HAS_MPS: + estimators.append(("pyfixest.feglm_logit (torch_mps)", "torch_mps", run_pyfixest_feglm_logit, False, "pyfixest_feglm_logit")) + if HAS_CUDA: + estimators.append(("pyfixest.feglm_logit (torch_cuda)", "torch_cuda", run_pyfixest_feglm_logit, False, "pyfixest_feglm_logit")) + estimators.append(("pyfixest.feglm_logit (torch_cuda32)", "torch_cuda32", run_pyfixest_feglm_logit, False, "pyfixest_feglm_logit")) formulas = { 2: "binary_y ~ x1 | indiv_id + year", 3: "binary_y ~ x1 | indiv_id + year + firm_id", @@ -420,7 +469,7 @@ def run_benchmark( print(f"{elapsed:.3f}s") else: # Run in main process - if backend_or_func in ("scipy", 
"numba", "rust"): + if backend_or_func in ("scipy", "numba", "rust", "jax", "torch_cpu", "torch_mps", "torch_cuda", "torch_cuda32"): elapsed = func(data, formula, backend_or_func) else: elapsed = func(data, formula) diff --git a/benchmarks/config.json b/benchmarks/config.json new file mode 100644 index 000000000..6f974cea1 --- /dev/null +++ b/benchmarks/config.json @@ -0,0 +1,86 @@ +{ + "timeout_secs": { + "python": 60 + }, + "python_timeout_estimators": ["linearmodels.AbsorbingLS", "statsmodels.OLS"], + "iterations": { + "n_iters": 3, + "burn_in": 1 + }, + "formulas": { + "ols": { + "2": { + "python": "y ~ x1 | indiv_id + year" + }, + "3": { + "python": "y ~ x1 | indiv_id + year + firm_id" + } + }, + "poisson": { + "2": { + "python": "negbin_y ~ x1 | indiv_id + year" + }, + "3": { + "python": "negbin_y ~ x1 | indiv_id + year + firm_id" + } + }, + "logit": { + "2": { + "python": "binary_y ~ x1 | indiv_id + year" + }, + "3": { + "python": "binary_y ~ x1 | indiv_id + year + firm_id" + } + } + }, + "datasets": [ + { "name": "simple_1k", "n": 1000, "type": "simple" }, + { "name": "difficult_1k", "n": 1000, "type": "difficult" }, + { "name": "simple_10k", "n": 10000, "type": "simple" }, + { "name": "difficult_10k", "n": 10000, "type": "difficult" }, + { "name": "simple_100k", "n": 100000, "type": "simple" }, + { "name": "difficult_100k", "n": 100000, "type": "difficult" }, + { "name": "simple_500k", "n": 500000, "type": "simple" }, + { "name": "difficult_500k", "n": 500000, "type": "difficult" }, + { "name": "simple_1m", "n": 1000000, "type": "simple" }, + { "name": "difficult_1m", "n": 1000000, "type": "difficult" } + ], + "datasets_by_type": { + "ols": [ + "simple_1k", + "difficult_1k", + "simple_10k", + "difficult_10k", + "simple_100k", + "difficult_100k", + "simple_500k", + "difficult_500k", + "simple_1m", + "difficult_1m" + ], + "poisson": [ + "simple_1k", + "difficult_1k", + "simple_10k", + "difficult_10k", + "simple_100k", + "difficult_100k", + "simple_500k", + 
"difficult_500k", + "simple_1m", + "difficult_1m" + ], + "logit": [ + "simple_1k", + "difficult_1k", + "simple_10k", + "difficult_10k", + "simple_100k", + "difficult_100k", + "simple_500k", + "difficult_500k", + "simple_1m", + "difficult_1m" + ] + } +} diff --git a/pyfixest/estimation/internals/backends.py b/pyfixest/estimation/internals/backends.py index 2f6b35bde..acda43614 100644 --- a/pyfixest/estimation/internals/backends.py +++ b/pyfixest/estimation/internals/backends.py @@ -47,6 +47,29 @@ crv1_meat_loop_cupy = crv1_meat_loop_nb count_fixef_fully_nested_all_cupy = count_fixef_fully_nested_all_nb +# Try to import Torch functions, fall back to numba if not available +try: + from pyfixest.estimation.torch.demean_torch_ import ( + demean_torch, + demean_torch_cpu, + demean_torch_cuda, + demean_torch_cuda32, + demean_torch_mps, + ) + + TORCH_AVAILABLE = True +except ImportError: + demean_torch = demean_nb + demean_torch_cpu = demean_nb + demean_torch_mps = demean_nb + demean_torch_cuda = demean_nb + demean_torch_cuda32 = demean_nb + TORCH_AVAILABLE = False + +find_collinear_variables_torch = find_collinear_variables_nb +crv1_meat_loop_torch = crv1_meat_loop_nb +count_fixef_fully_nested_all_torch = count_fixef_fully_nested_all_nb + BACKENDS = { "numba": { "demean": demean_nb, @@ -96,4 +119,34 @@ "crv1_meat": crv1_meat_loop_cupy, "nonnested": count_fixef_fully_nested_all_cupy, }, + "torch": { + "demean": demean_torch, + "collinear": find_collinear_variables_torch, + "crv1_meat": crv1_meat_loop_torch, + "nonnested": count_fixef_fully_nested_all_torch, + }, + "torch_cpu": { + "demean": demean_torch_cpu, + "collinear": find_collinear_variables_torch, + "crv1_meat": crv1_meat_loop_torch, + "nonnested": count_fixef_fully_nested_all_torch, + }, + "torch_mps": { + "demean": demean_torch_mps, + "collinear": find_collinear_variables_torch, + "crv1_meat": crv1_meat_loop_torch, + "nonnested": count_fixef_fully_nested_all_torch, + }, + "torch_cuda": { + "demean": 
demean_torch_cuda, + "collinear": find_collinear_variables_torch, + "crv1_meat": crv1_meat_loop_torch, + "nonnested": count_fixef_fully_nested_all_torch, + }, + "torch_cuda32": { + "demean": demean_torch_cuda32, + "collinear": find_collinear_variables_torch, + "crv1_meat": crv1_meat_loop_torch, + "nonnested": count_fixef_fully_nested_all_torch, + }, } diff --git a/pyfixest/estimation/internals/literals.py b/pyfixest/estimation/internals/literals.py index ab8e7e067..84be563f6 100644 --- a/pyfixest/estimation/internals/literals.py +++ b/pyfixest/estimation/internals/literals.py @@ -12,7 +12,19 @@ "jax", ] DemeanerBackendOptions = Literal[ - "numba", "jax", "rust", "rust-cg", "cupy", "cupy32", "cupy64", "scipy" + "numba", + "jax", + "rust", + "rust-cg", + "cupy", + "cupy32", + "cupy64", + "scipy", + "torch", + "torch_cpu", + "torch_mps", + "torch_cuda", + "torch_cuda32", ] PredictionErrorOptions = Literal["prediction"] QuantregMethodOptions = Literal["fn", "pfn"] diff --git a/pyfixest/estimation/torch/__init__.py b/pyfixest/estimation/torch/__init__.py index e69de29bb..c89341727 100644 --- a/pyfixest/estimation/torch/__init__.py +++ b/pyfixest/estimation/torch/__init__.py @@ -0,0 +1,15 @@ +from pyfixest.estimation.torch.demean_torch_ import ( + demean_torch, + demean_torch_cpu, + demean_torch_cuda, + demean_torch_cuda32, + demean_torch_mps, +) + +__all__ = [ + "demean_torch", + "demean_torch_cpu", + "demean_torch_cuda", + "demean_torch_cuda32", + "demean_torch_mps", +] diff --git a/pyfixest/estimation/torch/demean_torch_.py b/pyfixest/estimation/torch/demean_torch_.py index 30fc8f5a0..54e036b49 100644 --- a/pyfixest/estimation/torch/demean_torch_.py +++ b/pyfixest/estimation/torch/demean_torch_.py @@ -166,39 +166,20 @@ def _scale_coo_rows(D: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: ).coalesce() -def demean_torch( +def _demean_torch_on_device( x: NDArray[np.float64], - flist: NDArray[np.uint64] | None = None, - weights: NDArray[np.float64] | None = None, - 
tol: float = 1e-8, - maxiter: int = 100_000, - dtype: torch.dtype = torch.float64, + flist: NDArray[np.uint64] | None, + weights: NDArray[np.float64] | None, + tol: float, + maxiter: int, + device: torch.device, + dtype: torch.dtype, ) -> tuple[NDArray[np.float64], bool]: """ - Demean x by projecting out fixed effects via FWL + LSMR. + Core demeaning implementation for a specific device and dtype. - Parameters - ---------- - x : np.ndarray, shape (N,) or (N, K) - Variables to demean. - flist : np.ndarray, shape (N, n_factors), dtype uint64 - Integer-encoded fixed effects. - weights : np.ndarray, shape (N,) - Observation weights (1.0 for equal weighting). - tol : float - Convergence tolerance for LSMR (used as both atol and btol). - maxiter : int - Maximum LSMR iterations per column. - dtype : torch.dtype - Tensor dtype. Use ``torch.float32`` to enable MPS (Apple GPU) - acceleration. Default ``torch.float64`` for full precision. - - Returns - ------- - x_demeaned : np.ndarray - Residuals after projecting out fixed effects. Same shape as input x. - success : bool - True if LSMR converged for all columns. + This is the shared workhorse called by all public wrappers. + See `demean_torch` for full parameter documentation. 
""" if flist is None: raise ValueError("flist cannot be None") @@ -210,8 +191,6 @@ def demean_torch( x_2d = x[:, None] if was_1d else x weights_1d = weights.ravel() - device = _get_device(dtype) - # Move to torch — use from_numpy for zero-copy when staying on CPU x_c = x_2d if x_2d.flags["C_CONTIGUOUS"] else np.ascontiguousarray(x_2d) w_c = weights_1d if weights_1d.flags["C_CONTIGUOUS"] else np.ascontiguousarray(weights_1d) @@ -279,6 +258,105 @@ def demean_torch( return result, success +def demean_torch( + x: NDArray[np.float64], + flist: NDArray[np.uint64] | None = None, + weights: NDArray[np.float64] | None = None, + tol: float = 1e-8, + maxiter: int = 100_000, + dtype: torch.dtype = torch.float64, +) -> tuple[NDArray[np.float64], bool]: + """ + Demean x by projecting out fixed effects via FWL + LSMR. + + Auto-detects the best available device (CUDA > MPS > CPU). + For explicit device control, use the device-specific variants: + `demean_torch_cpu`, `demean_torch_mps`, `demean_torch_cuda`, + `demean_torch_cuda32`. + + Parameters + ---------- + x : np.ndarray, shape (N,) or (N, K) + Variables to demean. + flist : np.ndarray, shape (N, n_factors), dtype uint64 + Integer-encoded fixed effects. + weights : np.ndarray, shape (N,) + Observation weights (1.0 for equal weighting). + tol : float + Convergence tolerance for LSMR (used as both atol and btol). + maxiter : int + Maximum LSMR iterations per column. + dtype : torch.dtype + Tensor dtype. Use ``torch.float32`` to enable MPS (Apple GPU) + acceleration. Default ``torch.float64`` for full precision. + + Returns + ------- + x_demeaned : np.ndarray + Residuals after projecting out fixed effects. Same shape as input x. + success : bool + True if LSMR converged for all columns. 
+ """ + device = _get_device(dtype) + return _demean_torch_on_device(x, flist, weights, tol, maxiter, device, dtype) + + +def demean_torch_cpu( + x: NDArray[np.float64], + flist: NDArray[np.uint64] | None = None, + weights: NDArray[np.float64] | None = None, + tol: float = 1e-8, + maxiter: int = 100_000, +) -> tuple[NDArray[np.float64], bool]: + """Torch demeaner on CPU, float64.""" + return _demean_torch_on_device( + x, flist, weights, tol, maxiter, + device=torch.device("cpu"), dtype=torch.float64, + ) + + +def demean_torch_mps( + x: NDArray[np.float64], + flist: NDArray[np.uint64] | None = None, + weights: NDArray[np.float64] | None = None, + tol: float = 1e-8, + maxiter: int = 100_000, +) -> tuple[NDArray[np.float64], bool]: + """Torch demeaner on MPS (Apple GPU), float32.""" + return _demean_torch_on_device( + x, flist, weights, tol, maxiter, + device=torch.device("mps"), dtype=torch.float32, + ) + + +def demean_torch_cuda( + x: NDArray[np.float64], + flist: NDArray[np.uint64] | None = None, + weights: NDArray[np.float64] | None = None, + tol: float = 1e-8, + maxiter: int = 100_000, +) -> tuple[NDArray[np.float64], bool]: + """Torch demeaner on CUDA GPU, float64.""" + return _demean_torch_on_device( + x, flist, weights, tol, maxiter, + device=torch.device("cuda"), dtype=torch.float64, + ) + + +def demean_torch_cuda32( + x: NDArray[np.float64], + flist: NDArray[np.uint64] | None = None, + weights: NDArray[np.float64] | None = None, + tol: float = 1e-8, + maxiter: int = 100_000, +) -> tuple[NDArray[np.float64], bool]: + """Torch demeaner on CUDA GPU, float32.""" + return _demean_torch_on_device( + x, flist, weights, tol, maxiter, + device=torch.device("cuda"), dtype=torch.float32, + ) + + class _PreconditionedSparse: """ Wraps a sparse matrix D and diagonal preconditioner M_inv From f0e8daacd02ae112cc29f32d7d589e9ec1d52380 Mon Sep 17 00:00:00 2001 From: Jan Date: Thu, 5 Mar 2026 21:28:40 +0100 Subject: [PATCH 03/16] refactoring --- benchmarks/benchmarks.py | 58 
+++++++------- pyfixest/estimation/internals/backends.py | 43 ++++------- pyfixest/estimation/torch/demean_torch_.py | 90 ++++++++++------------ 3 files changed, 83 insertions(+), 108 deletions(-) diff --git a/benchmarks/benchmarks.py b/benchmarks/benchmarks.py index eb65530a3..f77eeff7a 100644 --- a/benchmarks/benchmarks.py +++ b/benchmarks/benchmarks.py @@ -36,6 +36,32 @@ HAS_MPS = False HAS_CUDA = False +# Backends that accept a backend= argument when called through pyfixest runners +_PYFIXEST_BACKENDS = {"scipy", "numba", "rust", "jax", "torch_cpu", "torch_mps", "torch_cuda", "torch_cuda32"} + +# ============================================================================= +# Helpers +# ============================================================================= + + +def _append_optional_backends(estimators, label_prefix, runner_func, func_name): + """Append JAX + torch backend estimators based on runtime availability.""" + optional = [] + if HAS_JAX: + optional.append(("jax", "jax")) + if HAS_TORCH: + optional.append(("torch_cpu", "torch_cpu")) + if HAS_MPS: + optional.append(("torch_mps", "torch_mps")) + if HAS_CUDA: + optional.append(("torch_cuda", "torch_cuda")) + optional.append(("torch_cuda32", "torch_cuda32")) + for suffix, backend in optional: + estimators.append( + (f"{label_prefix} ({suffix})", backend, runner_func, False, func_name) + ) + + # ============================================================================= # Estimator functions (run in main process for JIT caching) # ============================================================================= @@ -241,15 +267,7 @@ def get_estimators( "pyfixest_feols", ), ] - if HAS_JAX: - estimators.append(("pyfixest.feols (jax)", "jax", run_pyfixest_feols, False, "pyfixest_feols")) - if HAS_TORCH: - estimators.append(("pyfixest.feols (torch_cpu)", "torch_cpu", run_pyfixest_feols, False, "pyfixest_feols")) - if HAS_MPS: - estimators.append(("pyfixest.feols (torch_mps)", "torch_mps", run_pyfixest_feols, 
False, "pyfixest_feols")) - if HAS_CUDA: - estimators.append(("pyfixest.feols (torch_cuda)", "torch_cuda", run_pyfixest_feols, False, "pyfixest_feols")) - estimators.append(("pyfixest.feols (torch_cuda32)", "torch_cuda32", run_pyfixest_feols, False, "pyfixest_feols")) + _append_optional_backends(estimators, "pyfixest.feols", run_pyfixest_feols, "pyfixest_feols") estimators += [ ( "linearmodels.AbsorbingLS", @@ -294,15 +312,7 @@ def get_estimators( "pyfixest_fepois", ), ] - if HAS_JAX: - estimators.append(("pyfixest.fepois (jax)", "jax", run_pyfixest_fepois, False, "pyfixest_fepois")) - if HAS_TORCH: - estimators.append(("pyfixest.fepois (torch_cpu)", "torch_cpu", run_pyfixest_fepois, False, "pyfixest_fepois")) - if HAS_MPS: - estimators.append(("pyfixest.fepois (torch_mps)", "torch_mps", run_pyfixest_fepois, False, "pyfixest_fepois")) - if HAS_CUDA: - estimators.append(("pyfixest.fepois (torch_cuda)", "torch_cuda", run_pyfixest_fepois, False, "pyfixest_fepois")) - estimators.append(("pyfixest.fepois (torch_cuda32)", "torch_cuda32", run_pyfixest_fepois, False, "pyfixest_fepois")) + _append_optional_backends(estimators, "pyfixest.fepois", run_pyfixest_fepois, "pyfixest_fepois") formulas = { 2: "negbin_y ~ x1 | indiv_id + year", 3: "negbin_y ~ x1 | indiv_id + year + firm_id", @@ -331,15 +341,7 @@ def get_estimators( "pyfixest_feglm_logit", ), ] - if HAS_JAX: - estimators.append(("pyfixest.feglm_logit (jax)", "jax", run_pyfixest_feglm_logit, False, "pyfixest_feglm_logit")) - if HAS_TORCH: - estimators.append(("pyfixest.feglm_logit (torch_cpu)", "torch_cpu", run_pyfixest_feglm_logit, False, "pyfixest_feglm_logit")) - if HAS_MPS: - estimators.append(("pyfixest.feglm_logit (torch_mps)", "torch_mps", run_pyfixest_feglm_logit, False, "pyfixest_feglm_logit")) - if HAS_CUDA: - estimators.append(("pyfixest.feglm_logit (torch_cuda)", "torch_cuda", run_pyfixest_feglm_logit, False, "pyfixest_feglm_logit")) - estimators.append(("pyfixest.feglm_logit (torch_cuda32)", 
"torch_cuda32", run_pyfixest_feglm_logit, False, "pyfixest_feglm_logit")) + _append_optional_backends(estimators, "pyfixest.feglm_logit", run_pyfixest_feglm_logit, "pyfixest_feglm_logit") formulas = { 2: "binary_y ~ x1 | indiv_id + year", 3: "binary_y ~ x1 | indiv_id + year + firm_id", @@ -469,7 +471,7 @@ def run_benchmark( print(f"{elapsed:.3f}s") else: # Run in main process - if backend_or_func in ("scipy", "numba", "rust", "jax", "torch_cpu", "torch_mps", "torch_cuda", "torch_cuda32"): + if backend_or_func in _PYFIXEST_BACKENDS: elapsed = func(data, formula, backend_or_func) else: elapsed = func(data, formula) diff --git a/pyfixest/estimation/internals/backends.py b/pyfixest/estimation/internals/backends.py index acda43614..40f698eed 100644 --- a/pyfixest/estimation/internals/backends.py +++ b/pyfixest/estimation/internals/backends.py @@ -119,34 +119,19 @@ "crv1_meat": crv1_meat_loop_cupy, "nonnested": count_fixef_fully_nested_all_cupy, }, - "torch": { - "demean": demean_torch, - "collinear": find_collinear_variables_torch, - "crv1_meat": crv1_meat_loop_torch, - "nonnested": count_fixef_fully_nested_all_torch, - }, - "torch_cpu": { - "demean": demean_torch_cpu, - "collinear": find_collinear_variables_torch, - "crv1_meat": crv1_meat_loop_torch, - "nonnested": count_fixef_fully_nested_all_torch, - }, - "torch_mps": { - "demean": demean_torch_mps, - "collinear": find_collinear_variables_torch, - "crv1_meat": crv1_meat_loop_torch, - "nonnested": count_fixef_fully_nested_all_torch, - }, - "torch_cuda": { - "demean": demean_torch_cuda, - "collinear": find_collinear_variables_torch, - "crv1_meat": crv1_meat_loop_torch, - "nonnested": count_fixef_fully_nested_all_torch, - }, - "torch_cuda32": { - "demean": demean_torch_cuda32, - "collinear": find_collinear_variables_torch, - "crv1_meat": crv1_meat_loop_torch, - "nonnested": count_fixef_fully_nested_all_torch, + **{ + name: { + "demean": demean_fn, + "collinear": find_collinear_variables_torch, + "crv1_meat": 
crv1_meat_loop_torch, + "nonnested": count_fixef_fully_nested_all_torch, + } + for name, demean_fn in [ + ("torch", demean_torch), + ("torch_cpu", demean_torch_cpu), + ("torch_mps", demean_torch_mps), + ("torch_cuda", demean_torch_cuda), + ("torch_cuda32", demean_torch_cuda32), + ] }, } diff --git a/pyfixest/estimation/torch/demean_torch_.py b/pyfixest/estimation/torch/demean_torch_.py index 54e036b49..fb201d835 100644 --- a/pyfixest/estimation/torch/demean_torch_.py +++ b/pyfixest/estimation/torch/demean_torch_.py @@ -186,6 +186,20 @@ def _demean_torch_on_device( if weights is None: weights = np.ones(x.shape[0], dtype=np.float64) + return _demean_torch_on_device_impl(x, flist, weights, tol, maxiter, device, dtype) + + +@torch.no_grad() +def _demean_torch_on_device_impl( + x: NDArray[np.float64], + flist: NDArray[np.uint64], + weights: NDArray[np.float64], + tol: float, + maxiter: int, + device: torch.device, + dtype: torch.dtype, +) -> tuple[NDArray[np.float64], bool]: + """Inner implementation wrapped in torch.no_grad() to skip autograd overhead.""" # Track original shape to restore on output was_1d = x.ndim == 1 x_2d = x[:, None] if was_1d else x @@ -301,60 +315,34 @@ def demean_torch( return _demean_torch_on_device(x, flist, weights, tol, maxiter, device, dtype) -def demean_torch_cpu( - x: NDArray[np.float64], - flist: NDArray[np.uint64] | None = None, - weights: NDArray[np.float64] | None = None, - tol: float = 1e-8, - maxiter: int = 100_000, -) -> tuple[NDArray[np.float64], bool]: - """Torch demeaner on CPU, float64.""" - return _demean_torch_on_device( - x, flist, weights, tol, maxiter, - device=torch.device("cpu"), dtype=torch.float64, - ) - - -def demean_torch_mps( - x: NDArray[np.float64], - flist: NDArray[np.uint64] | None = None, - weights: NDArray[np.float64] | None = None, - tol: float = 1e-8, - maxiter: int = 100_000, -) -> tuple[NDArray[np.float64], bool]: - """Torch demeaner on MPS (Apple GPU), float32.""" - return _demean_torch_on_device( - x, 
flist, weights, tol, maxiter, - device=torch.device("mps"), dtype=torch.float32, - ) - +def _make_demean_variant( + device_str: str, + dtype: torch.dtype, + doc: str, +): + """Factory for device-specific demean wrappers.""" + + def _demean( + x: NDArray[np.float64], + flist: NDArray[np.uint64] | None = None, + weights: NDArray[np.float64] | None = None, + tol: float = 1e-8, + maxiter: int = 100_000, + ) -> tuple[NDArray[np.float64], bool]: + return _demean_torch_on_device( + x, flist, weights, tol, maxiter, + device=torch.device(device_str), dtype=dtype, + ) -def demean_torch_cuda( - x: NDArray[np.float64], - flist: NDArray[np.uint64] | None = None, - weights: NDArray[np.float64] | None = None, - tol: float = 1e-8, - maxiter: int = 100_000, -) -> tuple[NDArray[np.float64], bool]: - """Torch demeaner on CUDA GPU, float64.""" - return _demean_torch_on_device( - x, flist, weights, tol, maxiter, - device=torch.device("cuda"), dtype=torch.float64, - ) + _demean.__doc__ = doc + _demean.__qualname__ = f"demean_torch_{device_str}" + return _demean -def demean_torch_cuda32( - x: NDArray[np.float64], - flist: NDArray[np.uint64] | None = None, - weights: NDArray[np.float64] | None = None, - tol: float = 1e-8, - maxiter: int = 100_000, -) -> tuple[NDArray[np.float64], bool]: - """Torch demeaner on CUDA GPU, float32.""" - return _demean_torch_on_device( - x, flist, weights, tol, maxiter, - device=torch.device("cuda"), dtype=torch.float32, - ) +demean_torch_cpu = _make_demean_variant("cpu", torch.float64, "Torch demeaner on CPU, float64.") +demean_torch_mps = _make_demean_variant("mps", torch.float32, "Torch demeaner on MPS (Apple GPU), float32.") +demean_torch_cuda = _make_demean_variant("cuda", torch.float64, "Torch demeaner on CUDA GPU, float64.") +demean_torch_cuda32 = _make_demean_variant("cuda", torch.float32, "Torch demeaner on CUDA GPU, float32.") class _PreconditionedSparse: From 455a7c157b9bbdbbc19d54364d16617124fffd54 Mon Sep 17 00:00:00 2001 From: Jan Date: Thu, 5 
Mar 2026 21:34:26 +0100 Subject: [PATCH 04/16] ruff formatting --- benchmarks/benchmarks.py | 26 +++++++++++++--- pyfixest/estimation/torch/demean_torch_.py | 35 ++++++++++++++++------ pyfixest/estimation/torch/lsmr_torch.py | 7 ++--- tests/test_torch_lsmr.py | 21 ++++++++----- 4 files changed, 63 insertions(+), 26 deletions(-) diff --git a/benchmarks/benchmarks.py b/benchmarks/benchmarks.py index f77eeff7a..ae5091f7b 100644 --- a/benchmarks/benchmarks.py +++ b/benchmarks/benchmarks.py @@ -37,7 +37,16 @@ HAS_CUDA = False # Backends that accept a backend= argument when called through pyfixest runners -_PYFIXEST_BACKENDS = {"scipy", "numba", "rust", "jax", "torch_cpu", "torch_mps", "torch_cuda", "torch_cuda32"} +_PYFIXEST_BACKENDS = { + "scipy", + "numba", + "rust", + "jax", + "torch_cpu", + "torch_mps", + "torch_cuda", + "torch_cuda32", +} # ============================================================================= # Helpers @@ -267,7 +276,9 @@ def get_estimators( "pyfixest_feols", ), ] - _append_optional_backends(estimators, "pyfixest.feols", run_pyfixest_feols, "pyfixest_feols") + _append_optional_backends( + estimators, "pyfixest.feols", run_pyfixest_feols, "pyfixest_feols" + ) estimators += [ ( "linearmodels.AbsorbingLS", @@ -312,7 +323,9 @@ def get_estimators( "pyfixest_fepois", ), ] - _append_optional_backends(estimators, "pyfixest.fepois", run_pyfixest_fepois, "pyfixest_fepois") + _append_optional_backends( + estimators, "pyfixest.fepois", run_pyfixest_fepois, "pyfixest_fepois" + ) formulas = { 2: "negbin_y ~ x1 | indiv_id + year", 3: "negbin_y ~ x1 | indiv_id + year + firm_id", @@ -341,7 +354,12 @@ def get_estimators( "pyfixest_feglm_logit", ), ] - _append_optional_backends(estimators, "pyfixest.feglm_logit", run_pyfixest_feglm_logit, "pyfixest_feglm_logit") + _append_optional_backends( + estimators, + "pyfixest.feglm_logit", + run_pyfixest_feglm_logit, + "pyfixest_feglm_logit", + ) formulas = { 2: "binary_y ~ x1 | indiv_id + year", 3: "binary_y ~ x1 | 
indiv_id + year + firm_id", diff --git a/pyfixest/estimation/torch/demean_torch_.py b/pyfixest/estimation/torch/demean_torch_.py index fb201d835..19a9e50be 100644 --- a/pyfixest/estimation/torch/demean_torch_.py +++ b/pyfixest/estimation/torch/demean_torch_.py @@ -207,7 +207,11 @@ def _demean_torch_on_device_impl( # Move to torch — use from_numpy for zero-copy when staying on CPU x_c = x_2d if x_2d.flags["C_CONTIGUOUS"] else np.ascontiguousarray(x_2d) - w_c = weights_1d if weights_1d.flags["C_CONTIGUOUS"] else np.ascontiguousarray(weights_1d) + w_c = ( + weights_1d + if weights_1d.flags["C_CONTIGUOUS"] + else np.ascontiguousarray(weights_1d) + ) x_t = torch.from_numpy(x_c).to(dtype=dtype, device=device) w_t = torch.from_numpy(w_c).to(dtype=dtype, device=device) @@ -249,7 +253,7 @@ def _demean_torch_on_device_impl( success = True for k in range(K): - z, istop, itn, normr, normar, normA, condA, normx = lsmr_torch( + z, istop, _itn, _normr, _normar, _normA, _condA, _normx = lsmr_torch( A_precond, x_w[:, k], damp=0.0, @@ -320,7 +324,7 @@ def _make_demean_variant( dtype: torch.dtype, doc: str, ): - """Factory for device-specific demean wrappers.""" + """Create a device-specific demean wrapper.""" def _demean( x: NDArray[np.float64], @@ -330,8 +334,13 @@ def _demean( maxiter: int = 100_000, ) -> tuple[NDArray[np.float64], bool]: return _demean_torch_on_device( - x, flist, weights, tol, maxiter, - device=torch.device(device_str), dtype=dtype, + x, + flist, + weights, + tol, + maxiter, + device=torch.device(device_str), + dtype=dtype, ) _demean.__doc__ = doc @@ -339,10 +348,18 @@ def _demean( return _demean -demean_torch_cpu = _make_demean_variant("cpu", torch.float64, "Torch demeaner on CPU, float64.") -demean_torch_mps = _make_demean_variant("mps", torch.float32, "Torch demeaner on MPS (Apple GPU), float32.") -demean_torch_cuda = _make_demean_variant("cuda", torch.float64, "Torch demeaner on CUDA GPU, float64.") -demean_torch_cuda32 = _make_demean_variant("cuda", 
torch.float32, "Torch demeaner on CUDA GPU, float32.") +demean_torch_cpu = _make_demean_variant( + "cpu", torch.float64, "Torch demeaner on CPU, float64." +) +demean_torch_mps = _make_demean_variant( + "mps", torch.float32, "Torch demeaner on MPS (Apple GPU), float32." +) +demean_torch_cuda = _make_demean_variant( + "cuda", torch.float64, "Torch demeaner on CUDA GPU, float64." +) +demean_torch_cuda32 = _make_demean_variant( + "cuda", torch.float32, "Torch demeaner on CUDA GPU, float32." +) class _PreconditionedSparse: diff --git a/pyfixest/estimation/torch/lsmr_torch.py b/pyfixest/estimation/torch/lsmr_torch.py index a4172eb6f..1771f8be4 100644 --- a/pyfixest/estimation/torch/lsmr_torch.py +++ b/pyfixest/estimation/torch/lsmr_torch.py @@ -49,7 +49,7 @@ def _sym_ortho(a: float, b: float) -> tuple[float, float, float]: def _matvec(A, v: torch.Tensor) -> torch.Tensor: - """A @ v — works for both torch.Tensor (dense/sparse) and duck-typed wrappers.""" + """Compute A @ v for both torch.Tensor (dense/sparse) and duck-typed wrappers.""" if isinstance(A, torch.Tensor): return A @ v return A.mv(v) @@ -251,10 +251,7 @@ def lsmr_torch( normx = torch.linalg.norm(x).item() test1 = normr / normb - if (normA * normr) != 0: - test2 = normar / (normA * normr) - else: - test2 = float("inf") + test2 = normar / (normA * normr) if normA * normr != 0 else float("inf") test3 = 1.0 / condA t1 = test1 / (1.0 + normA * normx / normb) rtol = btol + atol * normA * normx / normb diff --git a/tests/test_torch_lsmr.py b/tests/test_torch_lsmr.py index e4e94ee46..f1aa796c9 100644 --- a/tests/test_torch_lsmr.py +++ b/tests/test_torch_lsmr.py @@ -17,7 +17,6 @@ from pyfixest.estimation.torch.demean_torch_ import demean_torch # noqa: E402 from pyfixest.estimation.torch.lsmr_torch import lsmr_torch # noqa: E402 - # --------------------------------------------------------------------------- # Shared fixtures # --------------------------------------------------------------------------- @@ -55,7 +54,7 @@ 
"""Return x = 0 when b = 0."""
TestDemeanVsPyhdfe: """Compare demean_torch against pyhdfe reference.""" def test_unweighted(self, demean_data): + """Verify unweighted demeaning matches pyhdfe.""" x, flist, _ = demean_data N = x.shape[0] weights = np.ones(N) @@ -181,6 +181,7 @@ def test_unweighted(self, demean_data): ) def test_weighted(self, demean_data): + """Verify weighted demeaning matches pyhdfe.""" x, flist, weights = demean_data N = x.shape[0] @@ -207,6 +208,7 @@ class TestDemeanVsScipy: """Compare demean_torch against demean_scipy (both use LSMR).""" def test_unweighted_vs_scipy(self, demean_data): + """Verify unweighted demeaning matches scipy LSMR.""" x, flist, _ = demean_data N = x.shape[0] weights = np.ones(N) @@ -230,6 +232,7 @@ def test_unweighted_vs_scipy(self, demean_data): ) def test_weighted_vs_scipy(self, demean_data): + """Verify weighted demeaning matches scipy LSMR.""" x, flist, weights = demean_data res_scipy, success_scipy = demean_scipy(x, flist, weights, tol=1e-10) @@ -306,7 +309,9 @@ def test_non_contiguous_group_ids_raises(self): N = 50 x = np.ones((N, 1)) # Groups [0, 5, 10] — non-contiguous, max+1=11 but only 3 unique - flist = np.array([0, 5, 10] * (N // 3) + [0] * (N % 3), dtype=np.uint64).reshape(N, 1) + flist = np.array( + [0, 5, 10] * (N // 3) + [0] * (N % 3), dtype=np.uint64 + ).reshape(N, 1) weights = np.ones(N) with pytest.raises(ValueError, match="non-contiguous group IDs"): From 1eecfd7ce968b80b2d3dd61df125e655c9701ce6 Mon Sep 17 00:00:00 2001 From: Jan Date: Fri, 6 Mar 2026 13:40:20 +0100 Subject: [PATCH 05/16] feat: compiled torch lsmr to avoid GPU syncs. - Added `lsmr_torch_fused.py` for a fused version of the LSMR algorithm, utilizing branchless Givens rotations and 0-d tensors to reduce CPU-GPU sync overhead. - Introduced tests for the new fused LSMR implementation in `test_lsmr_fused.py`, ensuring correctness against the original LSMR and benchmarking performance. 
- Created `test_lsmr_compiled.py` to validate the compiled version of the original LSMR, including auto-detection and MPS compatibility tests. --- pyfixest/estimation/torch/demean_torch_.py | 2 +- .../estimation/torch/lsmr_torch_compiled.py | 541 ++++++++++++++++++ pyfixest/estimation/torch/lsmr_torch_fused.py | 303 ++++++++++ tests/test_lsmr_compiled.py | 212 +++++++ tests/test_lsmr_fused.py | 159 +++++ 5 files changed, 1216 insertions(+), 1 deletion(-) create mode 100644 pyfixest/estimation/torch/lsmr_torch_compiled.py create mode 100644 pyfixest/estimation/torch/lsmr_torch_fused.py create mode 100644 tests/test_lsmr_compiled.py create mode 100644 tests/test_lsmr_fused.py diff --git a/pyfixest/estimation/torch/demean_torch_.py b/pyfixest/estimation/torch/demean_torch_.py index 19a9e50be..d13b7b9ff 100644 --- a/pyfixest/estimation/torch/demean_torch_.py +++ b/pyfixest/estimation/torch/demean_torch_.py @@ -17,7 +17,7 @@ import torch from numpy.typing import NDArray -from pyfixest.estimation.torch.lsmr_torch import lsmr_torch +from pyfixest.estimation.torch.lsmr_torch_compiled import lsmr_torch def _get_device(dtype: torch.dtype = torch.float64) -> torch.device: diff --git a/pyfixest/estimation/torch/lsmr_torch_compiled.py b/pyfixest/estimation/torch/lsmr_torch_compiled.py new file mode 100644 index 000000000..c559bf225 --- /dev/null +++ b/pyfixest/estimation/torch/lsmr_torch_compiled.py @@ -0,0 +1,541 @@ +""" +LSMR iterative solver in pure PyTorch with optional torch.compile kernel fusion. + +The solver splits each iteration into three phases: + 1. Sparse matvec (A @ v, A.T @ u) — cannot be compiled (sparse CSR unsupported) + 2. Scalar Givens rotations + norm estimation + convergence — compiled on GPU + 3. Vector updates (h, hbar, x) — use scalar results from phase 2 + +Phase 2 involves ~60 scalar operations that, without compilation, dispatch as +~60 individual GPU kernels (~15μs each on MPS). 
torch.compile fuses them into +a SINGLE kernel via the Inductor backend (Metal shaders on MPS, CUDA kernels +on NVIDIA GPUs). + +Workarounds for MPS/Metal limitations: + - torch.hypot not in Metal codegen → overflow-safe manual hypot via max/min scaling + - Metal kernel limited to 31 buffer args → pack all scalars into 1-D tensors + +The safe manual hypot and packed layout work uniformly across all backends +(CPU, MPS, CUDA) with negligible overhead (~5%) after fusion. + +On CPU the scalar step runs without compilation (no kernel launch overhead +to eliminate), so the packed layout is the only difference from a traditional +scalar-state LSMR. + +Reference: + D. C.-L. Fong and M. A. Saunders, + "LSMR: An iterative algorithm for sparse least-squares problems", + SIAM J. Sci. Comput., vol. 33, pp. 2950-2971, 2011. +""" + +from __future__ import annotations + +import threading + +import torch + +# --------------------------------------------------------------------------- +# Sparse matvec helpers (outside compiled region) +# --------------------------------------------------------------------------- + + +def _matvec(A, v: torch.Tensor) -> torch.Tensor: + if isinstance(A, torch.Tensor): + return A @ v + return A.mv(v) + + +def _rmatvec(A, u: torch.Tensor) -> torch.Tensor: + if isinstance(A, torch.Tensor): + return A.t() @ u + return A.t().mv(u) + + +# --------------------------------------------------------------------------- +# Packed scalar state layout +# --------------------------------------------------------------------------- +# We pack all scalar state into a single 1-D tensor to minimize Metal buffer +# slots (hardware limit: 31 per kernel). 
+# +# Input state (20 elements): +_I_ALPHABAR = 0 +_I_DAMP = 1 +_I_BETA = 2 +_I_ALPHA = 3 +_I_SBAR = 4 +_I_CBAR = 5 +_I_ZETABAR = 6 +_I_RHO = 7 +_I_RHOBAR = 8 +_I_RHODOLD = 9 +_I_TAUTILDEOLD = 10 +_I_THETATILDE = 11 +_I_BETADD = 12 +_I_BETAD = 13 +_I_D = 14 +_I_NORMA2 = 15 +_I_MAXRBAR = 16 +_I_MINRBAR = 17 +_I_NORMB = 18 +_I_ZETA = 19 # previous iteration's zeta (for normr estimation) + +# Constants (3 elements): atol, btol, ctol + +# Output adds extra slots for vector update coefficients: +_O_THETANEW = 20 +_O_THETABAR = 21 +_O_ZETA = 22 +_O_RHOOLD = 23 +_O_RHOBAROLD = 24 +_O_CONVERGED = 25 +_O_NORMR = 26 +_O_NORMAR = 27 +_O_NORMA = 28 +_O_CONDA = 29 +_O_NORMX_EST = 30 # placeholder, actual normx computed from vector + +_STATE_SIZE = 20 + + +# --------------------------------------------------------------------------- +# Overflow-safe hypot (replaces torch.hypot for Metal compatibility) +# --------------------------------------------------------------------------- + + +def _safe_hypot(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Overflow-safe hypot: ``sqrt(a² + b²)`` without intermediate overflow. + + Uses max/min scaling: ``hypot(a,b) = max(|a|,|b|) * sqrt(1 + (min/max)²)``. + Since ``min/max ≤ 1``, the argument to sqrt never exceeds 2. + Compiles to ~6 Metal/CUDA ops that fuse into the surrounding kernel. 
+ """ + abs_a = torch.abs(a) + abs_b = torch.abs(b) + big = torch.maximum(abs_a, abs_b) + small = torch.minimum(abs_a, abs_b) + safe_big = torch.where(big == 0, torch.ones_like(big), big) + ratio = small / safe_big + return torch.where( + big == 0, + torch.zeros_like(big), + big * torch.sqrt(1.0 + ratio * ratio), + ) + + +# --------------------------------------------------------------------------- +# Compiled scalar step (single Metal/CUDA kernel after fusion) +# --------------------------------------------------------------------------- + + +def _scalar_step(state: torch.Tensor, consts: torch.Tensor) -> torch.Tensor: + """ + All scalar work for one LSMR iteration: 4 Givens rotations, norm/cond + estimation, and convergence check. + + Packed I/O keeps Metal buffer count to 3 (state_in, consts, state_out). + Uses overflow-safe hypot (no torch.hypot — unsupported in Metal codegen). + """ + # Unpack + alphabar = state[_I_ALPHABAR] + damp = state[_I_DAMP] + beta = state[_I_BETA] + alpha = state[_I_ALPHA] + sbar = state[_I_SBAR] + cbar = state[_I_CBAR] + zetabar = state[_I_ZETABAR] + rho = state[_I_RHO] + rhobar = state[_I_RHOBAR] + rhodold = state[_I_RHODOLD] + tautildeold = state[_I_TAUTILDEOLD] + thetatilde = state[_I_THETATILDE] + betadd = state[_I_BETADD] + betad = state[_I_BETAD] + d = state[_I_D] + normA2 = state[_I_NORMA2] + maxrbar = state[_I_MAXRBAR] + minrbar = state[_I_MINRBAR] + normb = state[_I_NORMB] + zetaold = state[_I_ZETA] # zeta from previous iteration (for normr estimation) + + atol_t = consts[0] + ctol = consts[2] + + _ZERO = state[_I_ALPHABAR] * 0.0 # device-local zero + _ONE = _ZERO + 1.0 + + # --- Givens 1: (alphabar, damp) --- + r1 = _safe_hypot(alphabar, damp) + safe_r1 = torch.where(r1 == _ZERO, _ONE, r1) + chat = torch.where(r1 == _ZERO, _ZERO, alphabar / safe_r1) + shat = torch.where(r1 == _ZERO, _ZERO, damp / safe_r1) + + # --- Givens 2: (alphahat=r1, beta) --- + rhoold = rho + r2 = _safe_hypot(r1, beta) + safe_r2 = torch.where(r2 == 
_ZERO, _ONE, r2) + c = torch.where(r2 == _ZERO, _ZERO, r1 / safe_r2) + s = torch.where(r2 == _ZERO, _ZERO, beta / safe_r2) + rho_new = r2 + thetanew = s * alpha + alphabar_new = c * alpha + + # --- Givens 3: rhobar --- + rhobarold = rhobar + thetabar = sbar * rho_new + rhotemp = cbar * rho_new + r3 = _safe_hypot(rhotemp, thetanew) + safe_r3 = torch.where(r3 == _ZERO, _ONE, r3) + cbar_new = torch.where(r3 == _ZERO, _ZERO, rhotemp / safe_r3) + sbar_new = torch.where(r3 == _ZERO, _ZERO, thetanew / safe_r3) + rhobar_new = r3 + zeta = cbar_new * zetabar + zetabar_new = -sbar_new * zetabar + + # --- ||r|| estimation --- + betaacute = chat * betadd + betacheck = -shat * betadd + betahat = c * betaacute + betadd_new = -s * betaacute + + # Givens 4: rhotilde + r4 = _safe_hypot(rhodold, thetabar) + safe_r4 = torch.where(r4 == _ZERO, _ONE, r4) + ctildeold = torch.where(r4 == _ZERO, _ZERO, rhodold / safe_r4) + stildeold = torch.where(r4 == _ZERO, _ZERO, thetabar / safe_r4) + + thetatilde_new = stildeold * rhobar_new + rhodold_new = ctildeold * rhobar_new + betad_new = -stildeold * betad + ctildeold * betahat + + tautildeold_new = (zetaold - thetatilde * tautildeold) / torch.clamp(r4, min=1e-30) + taud = (zeta - thetatilde_new * tautildeold_new) / torch.clamp( + rhodold_new, min=1e-30 + ) + d_new = d + betacheck * betacheck + normr = torch.sqrt(d_new + (betad_new - taud) ** 2 + betadd_new * betadd_new) + + # --- ||A|| estimation --- + normA2_new = normA2 + beta * beta + normA = torch.sqrt(normA2_new) + normA2_final = normA2_new + alpha * alpha + + # --- cond(A) estimation --- + maxrbar_new = torch.maximum(maxrbar, rhobarold) + minrbar_new = torch.minimum(minrbar, rhobarold) + condA = torch.maximum(maxrbar_new, rhotemp) / torch.clamp( + torch.minimum(minrbar_new, rhotemp), min=1e-30 + ) + + # --- Convergence check --- + normar = torch.abs(zetabar_new) + test2 = normar / torch.clamp(normA * normr, min=1e-30) + test3 = _ONE / condA + + converged_flag = torch.where( + (test2 <= 
atol_t) + | (test3 <= ctol) + | (_ONE + test2 <= _ONE) + | (_ONE + test3 <= _ONE), + _ONE, + _ZERO, + ) + + # --- Pack output --- + return torch.stack( + [ + alphabar_new, # 0 _I_ALPHABAR + damp, # 1 _I_DAMP (pass through) + beta, # 2 _I_BETA (pass through, updated by caller) + alpha, # 3 _I_ALPHA (pass through, updated by caller) + sbar_new, # 4 _I_SBAR + cbar_new, # 5 _I_CBAR + zetabar_new, # 6 _I_ZETABAR + rho_new, # 7 _I_RHO + rhobar_new, # 8 _I_RHOBAR + rhodold_new, # 9 _I_RHODOLD + tautildeold_new, # 10 _I_TAUTILDEOLD + thetatilde_new, # 11 _I_THETATILDE + betadd_new, # 12 _I_BETADD + betad_new, # 13 _I_BETAD + d_new, # 14 _I_D + normA2_final, # 15 _I_NORMA2 + maxrbar_new, # 16 _I_MAXRBAR + minrbar_new, # 17 _I_MINRBAR + normb, # 18 _I_NORMB (pass through) + zeta, # 19 _I_ZETA (saved for next iteration's zetaold) + thetanew, # 20 _O_THETANEW (for vector update) + thetabar, # 21 _O_THETABAR (for vector update) + zeta, # 22 _O_ZETA (for vector update — same as slot 19) + rhoold, # 23 _O_RHOOLD (for vector update) + rhobarold, # 24 _O_RHOBAROLD (for vector update) + converged_flag, # 25 _O_CONVERGED + normr, # 26 _O_NORMR + normar, # 27 _O_NORMAR + normA, # 28 _O_NORMA + condA, # 29 _O_CONDA + _ZERO, # 30 _O_NORMX_EST (placeholder) + ] + ) + + +# --------------------------------------------------------------------------- +# Module-level compilation cache +# --------------------------------------------------------------------------- +_compiled_step_cache: dict[str, object] = {} +_cache_lock = threading.Lock() + + +def _get_compiled_step(device_type: str): + """Get or create compiled scalar step for the given device type.""" + if device_type in _compiled_step_cache: + return _compiled_step_cache[device_type] + with _cache_lock: + # Double-check after acquiring lock + if device_type not in _compiled_step_cache: + try: + _compiled_step_cache[device_type] = torch.compile( + _scalar_step, backend="inductor", fullgraph=True + ) + except Exception: + # Fallback: no 
compilation available + _compiled_step_cache[device_type] = _scalar_step + return _compiled_step_cache[device_type] + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def lsmr_torch( + A, + b: torch.Tensor, + damp: float = 0.0, + atol: float = 1e-8, + btol: float = 1e-8, + conlim: float = 1e8, + maxiter: int | None = None, + use_compile: bool | None = None, +) -> tuple[torch.Tensor, int, int, float, float, float, float, float]: + """ + LSMR iterative solver for sparse least-squares, in pure PyTorch. + + Solves ``min ||b - Ax||_2`` (or the damped variant) where *A* is a sparse + CSR (or dense) tensor and *b* is a dense vector. All vector operations + stay on the tensor's device (CPU / CUDA / MPS). + + On GPU (MPS or CUDA) the scalar Givens rotations, norm estimation, and + convergence check are fused into a **single GPU kernel** via + ``torch.compile`` + Inductor, eliminating ~60 per-iteration kernel + launches. On CPU the scalar step runs without compilation (no kernel- + launch overhead to eliminate). + + The while loop itself cannot be compiled because sparse CSR matvec is + not supported in ``torch.compile``. The single remaining CPU-GPU sync + per iteration (reading the convergence flag) is negligible after + compilation fuses the scalar step into one kernel. + + Parameters + ---------- + A : torch.Tensor + Sparse CSR tensor (or dense tensor / LinearOperator with ``.mv`` + and ``.t().mv`` support) of shape ``(m, n)``. + b : torch.Tensor + Dense vector of shape ``(m,)``. + damp : float + Damping factor for regularised least-squares. + atol, btol : float + Stopping tolerances (see SciPy LSMR documentation). + conlim : float + Condition-number limit. + maxiter : int or None + Maximum iterations. Defaults to ``min(m, n)``. + use_compile : bool or None + Whether to ``torch.compile`` the scalar step. 
``None`` (default) + auto-selects: **True** on GPU, **False** on CPU. + + Returns + ------- + x : torch.Tensor + Solution vector of shape ``(n,)``. + istop : int + Reason for stopping (0-7, same codes as SciPy LSMR). + itn : int + Number of iterations used. + normr, normar, normA, condA, normx : float + Diagnostic norms (see SciPy LSMR documentation). + """ + m, n = A.shape + device = b.device + dtype = b.dtype + + if maxiter is None: + maxiter = min(m, n) + + # Auto-detect compilation + if use_compile is None: + use_compile = device.type in ("cuda", "mps") + + # Get compiled or uncompiled step function + step_fn = _get_compiled_step(device.type) if use_compile else _scalar_step + + # --- Initialize Golub-Kahan bidiagonalization --- + u = b.clone() + normb = torch.linalg.norm(b) + + x = torch.zeros(n, device=device, dtype=dtype) + beta = normb.clone() + + # Safe normalize + u = u * torch.where(beta > 0, 1.0 / torch.clamp(beta, min=1e-30), beta * 0.0) + + v = _rmatvec(A, u) + alpha = torch.linalg.norm(v) + v = v * torch.where(alpha > 0, 1.0 / torch.clamp(alpha, min=1e-30), alpha * 0.0) + + # --- Pack initial scalar state --- + state = torch.zeros(_STATE_SIZE, device=device, dtype=dtype) + state[_I_ALPHABAR] = alpha + state[_I_DAMP] = damp + state[_I_BETA] = beta + state[_I_ALPHA] = alpha + state[_I_SBAR] = 0.0 + state[_I_CBAR] = 1.0 + state[_I_ZETABAR] = alpha * beta + state[_I_RHO] = 1.0 + state[_I_RHOBAR] = 1.0 + state[_I_RHODOLD] = 1.0 + state[_I_TAUTILDEOLD] = 0.0 + state[_I_THETATILDE] = 0.0 + state[_I_BETADD] = beta + state[_I_BETAD] = 0.0 + state[_I_D] = 0.0 + state[_I_NORMA2] = alpha * alpha + state[_I_MAXRBAR] = 0.0 + state[_I_MINRBAR] = 1e100 if dtype == torch.float64 else 1e10 + state[_I_NORMB] = normb + state[_I_ZETA] = 0.0 # initial zeta (no previous iteration) + + ctol = 1.0 / conlim if conlim > 0 else 0.0 + consts = torch.tensor([atol, btol, ctol], device=device, dtype=dtype) + + # Early exit check + normar_init = (alpha * beta).item() + if 
normar_init == 0.0: + return x, 0, 0, beta.item(), 0.0, alpha.item(), 1.0, 0.0 + if normb.item() == 0.0: + x.zero_() + return x, 0, 0, beta.item(), 0.0, alpha.item(), 1.0, 0.0 + + h = v.clone() + hbar = torch.zeros(n, device=device, dtype=dtype) + + # --- Main iteration loop --- + itn = 0 + istop = 0 + + while itn < maxiter: + itn += 1 + + # Phase 1: Sparse matvec (not compilable) + u = _matvec(A, v) - state[_I_ALPHA] * u + beta_new = torch.linalg.norm(u) + u = u * torch.where( + beta_new > 0, + 1.0 / torch.clamp(beta_new, min=1e-30), + beta_new * 0.0, + ) + + v = _rmatvec(A, u) - beta_new * v + alpha_new = torch.linalg.norm(v) + v = v * torch.where( + alpha_new > 0, + 1.0 / torch.clamp(alpha_new, min=1e-30), + alpha_new * 0.0, + ) + + # Update beta/alpha in state for the scalar step + state[_I_BETA] = beta_new + state[_I_ALPHA] = alpha_new + + # Phase 2: Compiled scalar step (single GPU kernel on MPS/CUDA) + out = step_fn(state, consts) + + # Phase 3: Vector updates using scalar results from compiled step + thetanew = out[_O_THETANEW] + thetabar = out[_O_THETABAR] + zeta = out[_O_ZETA] + rho_new = out[_I_RHO] + rhobar_new = out[_I_RHOBAR] + rhoold = out[_O_RHOOLD] + rhobarold = out[_O_RHOBAROLD] + + hbar = h + hbar * (-(thetabar * rho_new) / (rhoold * rhobarold)) + x = x + (zeta / (rho_new * rhobar_new)) * hbar + h = v + h * (-(thetanew / rho_new)) + + # Propagate state for next iteration + state = out[:_STATE_SIZE] + + # Convergence check — single .item() sync per iteration. + # After compilation the scalar step is one kernel, so this sync + # (reading a boolean from GPU memory) is negligible. + # + # The compiled _scalar_step checks test2 (atol) and test3 (ctol). + # The btol-based test1 depends on normx (a vector quantity) and is + # checked here in Python since it can't be part of the compiled kernel. 
# Sequential checks mirror SciPy LSMR: later assignments overwrite earlier ones, so istop=1 takes precedence when several conditions hold
+ +Compared to lsmr_torch.py: +- Branchless Givens via torch.hypot (no if/elif, no Python math) +- All scalar state as 0-d device tensors (no .item() in hot loop) +- Convergence check via logical indexing +- Single sync point every `check_every` iterations + +Designed for CUDA/MPS where CPU-GPU synchronization is expensive. + +Reference: + D. C.-L. Fong and M. A. Saunders, + "LSMR: An iterative algorithm for sparse least-squares problems", + SIAM J. Sci. Comput., vol. 33, pp. 2950-2971, 2011. +""" + +from __future__ import annotations + +import torch + + +# --------------------------------------------------------------------------- +# Branchless Givens rotation +# --------------------------------------------------------------------------- + + +def _sym_ortho_t( + a: torch.Tensor, b: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Branchless Givens rotation for 0-d tensors on device. + + Equivalent to SciPy's ``_sym_ortho`` but implemented with + ``torch.hypot`` (overflow-safe) and ``torch.where`` (branchless). + No CPU-GPU synchronization occurs. 
+ """ + r = torch.hypot(a, b) + # Guard division: when r == 0, return (0, 0, 0) + safe_r = torch.where(r == 0, torch.ones_like(r), r) + c = torch.where(r == 0, torch.zeros_like(a), a / safe_r) + s = torch.where(r == 0, torch.zeros_like(b), b / safe_r) + return c, s, r + + +# --------------------------------------------------------------------------- +# Sparse matvec helpers +# --------------------------------------------------------------------------- + + +def _matvec(A, v: torch.Tensor) -> torch.Tensor: + if isinstance(A, torch.Tensor): + return A @ v + return A.mv(v) + + +def _rmatvec(A, u: torch.Tensor) -> torch.Tensor: + if isinstance(A, torch.Tensor): + return A.t() @ u + return A.t().mv(u) + + +# --------------------------------------------------------------------------- +# On-device LSMR +# --------------------------------------------------------------------------- + + +def lsmr_torch_fused( + A, + b: torch.Tensor, + damp: float = 0.0, + atol: float = 1e-8, + btol: float = 1e-8, + conlim: float = 1e8, + maxiter: int | None = None, + check_every: int = 10, +) -> tuple[torch.Tensor, int, int, float, float, float, float, float]: + """ + LSMR iterative solver with minimal CPU-GPU synchronization. + + Same algorithm as :func:`lsmr_torch`, but keeps **all** scalar state + as 0-d tensors on the compute device. CPU-GPU sync only happens once + every *check_every* iterations (for the convergence check), reducing + pipeline stalls from 3/iteration to 1/N. + + Parameters + ---------- + A : torch.Tensor + Sparse CSR tensor of shape (m, n). + b : torch.Tensor + Dense vector of shape (m,). + damp : float + Damping factor for regularized least-squares. + atol, btol : float + Stopping tolerances (see SciPy LSMR docs). + conlim : float + Condition number limit. + maxiter : int or None + Maximum iterations. Defaults to min(m, n). + check_every : int + How often to sync to CPU for convergence check (default: 10). 
+ Higher values reduce syncs but may overshoot convergence by up to + ``check_every - 1`` iterations. + + Returns + ------- + x, istop, itn, normr, normar, normA, condA, normx + Same signature as :func:`lsmr_torch`. + """ + m, n = A.shape + device = b.device + dtype = b.dtype + + if maxiter is None: + maxiter = min(m, n) + + # Scalar constant factory + def _s(val: float) -> torch.Tensor: + return torch.tensor(val, device=device, dtype=dtype) + + _TINY = _s(1e-30) + + # --- Initialize Golub-Kahan bidiagonalization --- + u = b.clone() + normb = torch.linalg.norm(b) + + x = torch.zeros(n, device=device, dtype=dtype) + beta = normb.clone() + + # Branchless safe-normalize: u /= beta (or zero if beta == 0) + u = u * torch.where(beta > 0, 1.0 / torch.clamp(beta, min=1e-30), _s(0.0)) + + v = _rmatvec(A, u) + alpha = torch.linalg.norm(v) + v = v * torch.where(alpha > 0, 1.0 / torch.clamp(alpha, min=1e-30), _s(0.0)) + + # --- Scalar state (all 0-d device tensors) --- + zetabar = alpha * beta + alphabar = alpha.clone() + rho = _s(1.0) + rhobar = _s(1.0) + cbar = _s(1.0) + sbar = _s(0.0) + + h = v.clone() + hbar = torch.zeros(n, device=device, dtype=dtype) + + # ||r|| estimation + betadd = beta.clone() + betad = _s(0.0) + rhodold = _s(1.0) + tautildeold = _s(0.0) + thetatilde = _s(0.0) + zeta = _s(0.0) + d = _s(0.0) + + # ||A|| and cond(A) estimation + normA2 = alpha * alpha + maxrbar = _s(0.0) + minrbar = _s(1e100) + normA = torch.sqrt(normA2) + condA = _s(1.0) + normx_est = _s(0.0) + + # Stopping + ctol = _s(1.0 / conlim if conlim > 0 else 0.0) + normr = beta.clone() + normar = alpha * beta + + # Pre-create tolerance tensors + atol_t = _s(atol) + btol_t = _s(btol) + damp_t = _s(damp) + + # Early exit (syncs once at init — unavoidable) + if normar.item() == 0.0: + return x, 0, 0, normr.item(), normar.item(), normA.item(), condA.item(), 0.0 + if normb.item() == 0.0: + x.zero_() + return x, 0, 0, normr.item(), normar.item(), normA.item(), condA.item(), 0.0 + + # --- Main 
iteration loop (zero syncs inside, except periodic check) --- + itn = 0 + istop = 0 + + while itn < maxiter: + itn += 1 + + # Bidiagonalization step + u = _matvec(A, v) - alpha * u + beta = torch.linalg.norm(u) + u = u * torch.where(beta > 0, 1.0 / torch.clamp(beta, min=1e-30), _s(0.0)) + + v = _rmatvec(A, u) - beta * v + alpha = torch.linalg.norm(v) + v = v * torch.where(alpha > 0, 1.0 / torch.clamp(alpha, min=1e-30), _s(0.0)) + + # Givens rotations (branchless, on device) + chat, shat, alphahat = _sym_ortho_t(alphabar, damp_t) + + rhoold = rho + c, s, rho = _sym_ortho_t(alphahat, beta) + thetanew = s * alpha + alphabar = c * alpha + + rhobarold = rhobar + zetaold = zeta + thetabar = sbar * rho + rhotemp = cbar * rho + cbar, sbar, rhobar = _sym_ortho_t(cbar * rho, thetanew) + zeta = cbar * zetabar + zetabar = -sbar * zetabar + + # Vector updates (on device) + hbar = h + hbar * (-(thetabar * rho) / (rhoold * rhobarold)) + x = x + (zeta / (rho * rhobar)) * hbar + h = v + h * (-(thetanew / rho)) + + # ||r|| estimation + betaacute = chat * betadd + betacheck = -shat * betadd + betahat = c * betaacute + betadd = -s * betaacute + + thetatildeold = thetatilde + ctildeold, stildeold, rhotildeold = _sym_ortho_t(rhodold, thetabar) + thetatilde = stildeold * rhobar + rhodold = ctildeold * rhobar + betad = -stildeold * betad + ctildeold * betahat + + tautildeold = (zetaold - thetatildeold * tautildeold) / rhotildeold + taud = (zeta - thetatilde * tautildeold) / rhodold + d = d + betacheck * betacheck + normr = torch.sqrt(d + (betad - taud) ** 2 + betadd * betadd) + + # ||A|| estimation + normA2 = normA2 + beta * beta + normA = torch.sqrt(normA2) + normA2 = normA2 + alpha * alpha + + # cond(A) estimation + maxrbar = torch.maximum(maxrbar, rhobarold) + if itn > 1: + minrbar = torch.minimum(minrbar, rhobarold) + condA = torch.maximum(maxrbar, rhotemp) / torch.minimum(minrbar, rhotemp) + + # Convergence check (all logical ops on device, no sync) + normar = torch.abs(zetabar) + 
normx_est = torch.linalg.norm(x) + + test1 = normr / normb + test2 = normar / torch.clamp(normA * normr, min=1e-30) + test3 = 1.0 / condA + t1 = test1 / (1.0 + normA * normx_est / normb) + rtol = btol_t + atol_t * normA * normx_est / normb + + converged = ( + (test1 <= rtol) + | (test2 <= atol_t) + | (test3 <= ctol) + | (1.0 + t1 <= 1.0) + | (1.0 + test2 <= 1.0) + | (1.0 + test3 <= 1.0) + ) + + # --- Periodic sync: single .item() every check_every iterations --- + if itn % check_every == 0 or itn >= maxiter: + if converged.item(): + # Determine exact istop code (one-time sync at exit) + _test1 = test1.item() + _test2 = test2.item() + _test3 = test3.item() + _t1 = t1.item() + _rtol = rtol.item() + _ctol = ctol.item() + + if 1.0 + _test3 <= 1.0: + istop = 6 + elif 1.0 + _test2 <= 1.0: + istop = 5 + elif 1.0 + _t1 <= 1.0: + istop = 4 + elif _test3 <= _ctol: + istop = 3 + elif _test2 <= atol: + istop = 2 + elif _test1 <= _rtol: + istop = 1 + break + + if itn >= maxiter and istop == 0: + istop = 7 + + return ( + x, + istop, + itn, + normr.item(), + normar.item(), + normA.item(), + condA.item(), + normx_est.item(), + ) diff --git a/tests/test_lsmr_compiled.py b/tests/test_lsmr_compiled.py new file mode 100644 index 000000000..2ece2bf60 --- /dev/null +++ b/tests/test_lsmr_compiled.py @@ -0,0 +1,212 @@ +""" +Tests for lsmr_torch (compiled version): correctness, auto-detection, MPS torch.compile. 
"""Return x = 0 when b = 0."""
_make_sparse_problem(200, 100) + x, istop, _itn, *_ = lsmr_torch(A, b) + assert x.device.type == "cpu" + assert istop in range(8) + + +# --------------------------------------------------------------------------- +# MPS + torch.compile tests +# --------------------------------------------------------------------------- + +HAS_MPS = torch.backends.mps.is_available() + + +@pytest.mark.skipif(not HAS_MPS, reason="MPS not available") +def test_correctness_mps(): + """lsmr_torch on MPS f32 produces reasonable results vs CPU f64.""" + m, n = 500, 300 + A_cpu, b_cpu = _make_sparse_problem(m, n) + + x_ref, *_ = lsmr_torch_original(A_cpu, b_cpu) + + # MPS f32 (auto: use_compile=True) + A_mps = A_cpu.to(torch.float32).to_dense().to("mps") + b_mps = b_cpu.to(torch.float32).to("mps") + + x_mps, *_ = lsmr_torch(A_mps, b_mps) + + max_diff = torch.max(torch.abs(x_ref.float() - x_mps.cpu())).item() + assert max_diff < 0.1, f"MPS f32 vs CPU f64 too different: {max_diff:.2e}" + + +@pytest.mark.skipif(not HAS_MPS, reason="MPS not available") +def test_compiled_vs_uncompiled_mps(): + """Compiled and uncompiled give same results on MPS.""" + m, n = 500, 300 + A_cpu, b_cpu = _make_sparse_problem(m, n) + A_mps = A_cpu.to(torch.float32).to_dense().to("mps") + b_mps = b_cpu.to(torch.float32).to("mps") + + x_comp, *_ = lsmr_torch(A_mps, b_mps, use_compile=True) + x_nocomp, *_ = lsmr_torch(A_mps, b_mps, use_compile=False) + + max_diff = torch.max(torch.abs(x_comp - x_nocomp)).item() + assert max_diff < 1e-5, f"Compiled vs uncompiled differ: {max_diff:.2e}" + + +# --------------------------------------------------------------------------- +# Timing benchmark (run with -s to see output) +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not HAS_MPS, reason="MPS not available") +@pytest.mark.parametrize("m,n", [(5000, 2000), (10000, 5000)]) +def test_timing_mps(m, n): + """Compare compiled vs uncompiled LSMR timing on MPS.""" + rng = 
np.random.default_rng(42) + nnz = int(m * n * 0.005) + rows = rng.integers(0, m, nnz) + cols = rng.integers(0, n, nnz) + vals = rng.standard_normal(nnz).astype(np.float32) + + A = torch.zeros(m, n, dtype=torch.float32) + for r, c, v in zip(rows, cols, vals): + A[r, c] += v + A = A.to("mps") + b = torch.tensor(rng.standard_normal(m).astype(np.float32), device="mps") + + # Warmup + lsmr_torch(A, b, use_compile=True) + lsmr_torch(A, b, use_compile=False) + torch.mps.synchronize() + + # Compiled + torch.mps.synchronize() + t0 = time.perf_counter() + x_comp, _, itn_comp, *_ = lsmr_torch(A, b, use_compile=True) + torch.mps.synchronize() + t_comp = time.perf_counter() - t0 + + # Uncompiled + torch.mps.synchronize() + t0 = time.perf_counter() + x_nocomp, _, itn_nocomp, *_ = lsmr_torch(A, b, use_compile=False) + torch.mps.synchronize() + t_nocomp = time.perf_counter() - t0 + + speedup = t_nocomp / t_comp if t_comp > 0 else float("inf") + print( + f"\n [{m}x{n}] compiled: {t_comp:.3f}s ({itn_comp} iters) | " + f"uncompiled: {t_nocomp:.3f}s ({itn_nocomp} iters) | " + f"speedup: {speedup:.2f}x" + ) + + # Correctness + assert torch.allclose(x_comp, x_nocomp, atol=1e-4, rtol=1e-4) diff --git a/tests/test_lsmr_fused.py b/tests/test_lsmr_fused.py new file mode 100644 index 000000000..be612e834 --- /dev/null +++ b/tests/test_lsmr_fused.py @@ -0,0 +1,159 @@ +""" +Correctness and timing tests for lsmr_torch_fused vs lsmr_torch. 
+ +Usage: + KMP_DUPLICATE_LIB_OK=TRUE pixi run -e dev python -m pytest tests/test_lsmr_fused.py -v -s +""" + +from __future__ import annotations + +import time + +import numpy as np +import pytest +import torch + +from pyfixest.estimation.torch.lsmr_torch import lsmr_torch +from pyfixest.estimation.torch.lsmr_torch_fused import lsmr_torch_fused + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_sparse_problem(m: int, n: int, density: float = 0.01, seed: int = 42): + """Create a sparse CSR system A and dense rhs b on the given device.""" + rng = np.random.default_rng(seed) + nnz = int(m * n * density) + rows = rng.integers(0, m, nnz) + cols = rng.integers(0, n, nnz) + vals = rng.standard_normal(nnz) + + A_coo = torch.sparse_coo_tensor( + torch.tensor(np.stack([rows, cols])), + torch.tensor(vals, dtype=torch.float64), + size=(m, n), + ) + A_csr = A_coo.to_sparse_csr() + b = torch.tensor(rng.standard_normal(m), dtype=torch.float64) + return A_csr, b + + +# --------------------------------------------------------------------------- +# Correctness tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("m,n", [(200, 100), (500, 300), (1000, 500)]) +def test_fused_matches_original(m, n): + """The fused solver should produce the same solution as the original.""" + A, b = _make_sparse_problem(m, n) + + x_orig, istop_orig, itn_orig, *_ = lsmr_torch(A, b) + x_fused, istop_fused, itn_fused, *_ = lsmr_torch_fused(A, b, check_every=1) + + # Solutions should match closely + assert torch.allclose(x_orig, x_fused, atol=1e-6, rtol=1e-6), ( + f"Solutions differ: max_diff={torch.max(torch.abs(x_orig - x_fused)).item():.2e}" + ) + # Iteration count may differ slightly due to convergence check frequency + assert abs(itn_orig - itn_fused) <= 1, f"itn differs: {itn_orig} vs {itn_fused}" + + 
+@pytest.mark.parametrize("check_every", [1, 5, 10, 50]) +def test_check_every_correctness(check_every): + """Different check_every values should all converge to the same solution.""" + A, b = _make_sparse_problem(500, 300) + x_ref, *_ = lsmr_torch(A, b) + x_fused, _, itn, *_ = lsmr_torch_fused(A, b, check_every=check_every) + + assert torch.allclose(x_ref, x_fused, atol=1e-6, rtol=1e-6), ( + f"check_every={check_every}: max_diff=" + f"{torch.max(torch.abs(x_ref - x_fused)).item():.2e}" + ) + + +def test_zero_rhs(): + """b = 0 should return x = 0.""" + A, _ = _make_sparse_problem(100, 50) + b = torch.zeros(100, dtype=torch.float64) + x, istop, itn, *_ = lsmr_torch_fused(A, b) + assert torch.all(x == 0) + assert itn == 0 + + +def test_damping(): + """Damped solve should differ from undamped.""" + A, b = _make_sparse_problem(200, 100) + x_undamped, *_ = lsmr_torch_fused(A, b, damp=0.0) + x_damped, *_ = lsmr_torch_fused(A, b, damp=1.0) + assert not torch.allclose(x_undamped, x_damped, atol=1e-3) + + +# --------------------------------------------------------------------------- +# Branchless _sym_ortho tests +# --------------------------------------------------------------------------- + + +def test_sym_ortho_matches_scipy(): + """Branchless _sym_ortho_t should match SciPy's convention.""" + import math + + from pyfixest.estimation.torch.lsmr_torch import _sym_ortho + from pyfixest.estimation.torch.lsmr_torch_fused import _sym_ortho_t + + cases = [ + (3.0, 4.0), + (-3.0, 4.0), + (3.0, -4.0), + (-3.0, -4.0), + (0.0, 5.0), + (5.0, 0.0), + (0.0, 0.0), + (1e-300, 1e-300), + (1e300, 1e300), + (1.0, 1e-15), + ] + for a_val, b_val in cases: + c_ref, s_ref, r_ref = _sym_ortho(a_val, b_val) + a_t = torch.tensor(a_val, dtype=torch.float64) + b_t = torch.tensor(b_val, dtype=torch.float64) + c_t, s_t, r_t = _sym_ortho_t(a_t, b_t) + assert abs(c_t.item() - c_ref) < 1e-10, f"c mismatch for ({a_val}, {b_val})" + assert abs(s_t.item() - s_ref) < 1e-10, f"s mismatch for ({a_val}, 
{b_val})" + assert abs(r_t.item() - r_ref) < 1e-10, f"r mismatch for ({a_val}, {b_val})" + + +# --------------------------------------------------------------------------- +# Timing benchmark (not a test — run with -s to see output) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("m,n", [(5000, 2000), (10000, 5000)]) +def test_timing_comparison(m, n): + """Compare wall time of original vs fused LSMR.""" + A, b = _make_sparse_problem(m, n, density=0.005) + + # Warmup + lsmr_torch(A, b, maxiter=5) + lsmr_torch_fused(A, b, maxiter=5) + + # Original + t0 = time.perf_counter() + x_orig, _, itn_orig, *_ = lsmr_torch(A, b) + t_orig = time.perf_counter() - t0 + + # Fused (check every 10) + t0 = time.perf_counter() + x_fused, _, itn_fused, *_ = lsmr_torch_fused(A, b, check_every=10) + t_fused = time.perf_counter() - t0 + + speedup = t_orig / t_fused if t_fused > 0 else float("inf") + print( + f"\n [{m}x{n}] original: {t_orig:.3f}s ({itn_orig} iters) | " + f"fused: {t_fused:.3f}s ({itn_fused} iters) | " + f"speedup: {speedup:.2f}x" + ) + + # Correctness sanity check + assert torch.allclose(x_orig, x_fused, atol=1e-5, rtol=1e-5) From 58cd00d6b5e0c2836177b7068460e4b1c820f441 Mon Sep 17 00:00:00 2001 From: Jan Date: Fri, 6 Mar 2026 16:32:34 +0100 Subject: [PATCH 06/16] refactor: dispatch between LSMR scalar vs compiled. 
- also enhance GPU efficiency with pre-computed transpose --- pyfixest/estimation/torch/demean_torch_.py | 31 ++++++++++++--- pyfixest/estimation/torch/lsmr_torch.py | 38 ++++++++++++++++++- .../estimation/torch/lsmr_torch_compiled.py | 15 ++++---- tests/test_lsmr_compiled.py | 2 +- 4 files changed, 72 insertions(+), 14 deletions(-) diff --git a/pyfixest/estimation/torch/demean_torch_.py b/pyfixest/estimation/torch/demean_torch_.py index d13b7b9ff..f55e21944 100644 --- a/pyfixest/estimation/torch/demean_torch_.py +++ b/pyfixest/estimation/torch/demean_torch_.py @@ -17,7 +17,7 @@ import torch from numpy.typing import NDArray -from pyfixest.estimation.torch.lsmr_torch_compiled import lsmr_torch +from pyfixest.estimation.torch.lsmr_torch import lsmr_torch def _get_device(dtype: torch.dtype = torch.float64) -> torch.device: @@ -372,10 +372,18 @@ class _PreconditionedSparse: The transpose view is cached and returned by `.t()`, so LSMR's repeated `A.t().mv(u)` calls don't allocate a new object each time. + + D^T is pre-computed once in a GPU-friendly sparse layout to avoid + per-iteration reconversion (COO coalesce on MPS, CSR radixSort on CUDA). 
""" def __init__( - self, D: torch.Tensor, M_inv: torch.Tensor, *, _transposed: bool = False + self, + D: torch.Tensor, + M_inv: torch.Tensor, + *, + _transposed: bool = False, + _D_t: torch.Tensor | None = None, ): m, n = D.shape self.shape = (n, m) if _transposed else (m, n) @@ -383,11 +391,23 @@ def __init__( self._M_inv = M_inv self._transposed = _transposed self._T: _PreconditionedSparse | None = None + self._D_t = _D_t if _D_t is not None else self._materialize_transpose(D) + + @staticmethod + def _materialize_transpose(D: torch.Tensor) -> torch.Tensor: + """Pre-compute D^T in a GPU-friendly sparse layout.""" + D_t = D.t() + layout = D_t.layout + if layout == torch.sparse_coo: + return D_t.coalesce() + if layout in (torch.sparse_csr, torch.sparse_csc): + return D_t.to_sparse_csr() + return D_t def mv(self, v: torch.Tensor) -> torch.Tensor: if self._transposed: - # Compute M_inv * (D^T @ u) - return self._M_inv * (self._D.t() @ v) + # Compute M_inv * (D^T @ u) — uses pre-computed transpose + return self._M_inv * (self._D_t @ v) # Compute D @ (M_inv * v) return self._D @ (self._M_inv * v) @@ -395,7 +415,8 @@ def t(self) -> _PreconditionedSparse: """Return cached transpose view.""" if self._T is None: self._T = _PreconditionedSparse( - self._D, self._M_inv, _transposed=not self._transposed + self._D, self._M_inv, + _transposed=not self._transposed, _D_t=self._D_t, ) self._T._T = self # cross-link so .t().t() returns self return self._T diff --git a/pyfixest/estimation/torch/lsmr_torch.py b/pyfixest/estimation/torch/lsmr_torch.py index 1771f8be4..784908661 100644 --- a/pyfixest/estimation/torch/lsmr_torch.py +++ b/pyfixest/estimation/torch/lsmr_torch.py @@ -17,6 +17,8 @@ import torch +from pyfixest.estimation.torch.lsmr_torch_compiled import lsmr_torch as _lsmr_compiled + def _sym_ortho(a: float, b: float) -> tuple[float, float, float]: """ @@ -62,7 +64,7 @@ def _rmatvec(A, u: torch.Tensor) -> torch.Tensor: return A.t().mv(u) -def lsmr_torch( +def _lsmr_scalar( 
A, b: torch.Tensor, damp: float = 0.0, @@ -275,3 +277,37 @@ def lsmr_torch( break return x, istop, itn, normr, normar, normA, condA, normx + + +def lsmr_torch( + A, + b: torch.Tensor, + damp: float = 0.0, + atol: float = 1e-8, + btol: float = 1e-8, + conlim: float = 1e8, + maxiter: int | None = None, + use_compile: bool | None = None, +) -> tuple[torch.Tensor, int, int, float, float, float, float, float]: + """ + LSMR solver — unified entry point. + + Auto-selects implementation based on device: + - CUDA: compiled (torch.compile fuses scalar step into 1 kernel) + - CPU/MPS: scalar (Python-float math, no compilation overhead) + + Pass use_compile=True to force compilation on any device. + """ + device = b.device + if use_compile is None: + use_compile = device.type == "cuda" + + if use_compile: + return _lsmr_compiled( + A, b, damp=damp, atol=atol, btol=btol, + conlim=conlim, maxiter=maxiter, use_compile=True, + ) + return _lsmr_scalar( + A, b, damp=damp, atol=atol, btol=btol, + conlim=conlim, maxiter=maxiter, + ) diff --git a/pyfixest/estimation/torch/lsmr_torch_compiled.py b/pyfixest/estimation/torch/lsmr_torch_compiled.py index c559bf225..6ebb576ef 100644 --- a/pyfixest/estimation/torch/lsmr_torch_compiled.py +++ b/pyfixest/estimation/torch/lsmr_torch_compiled.py @@ -325,11 +325,12 @@ def lsmr_torch( CSR (or dense) tensor and *b* is a dense vector. All vector operations stay on the tensor's device (CPU / CUDA / MPS). - On GPU (MPS or CUDA) the scalar Givens rotations, norm estimation, and - convergence check are fused into a **single GPU kernel** via - ``torch.compile`` + Inductor, eliminating ~60 per-iteration kernel - launches. On CPU the scalar step runs without compilation (no kernel- - launch overhead to eliminate). + On CUDA the scalar Givens rotations, norm estimation, and convergence + check are fused into a **single GPU kernel** via ``torch.compile`` + + Inductor, eliminating ~60 per-iteration kernel launches. 
On CPU and + MPS the scalar step runs without compilation (MPS's Metal command + buffer batching already amortizes kernel launch overhead, so + compilation adds dispatch cost without measurable benefit). The while loop itself cannot be compiled because sparse CSR matvec is not supported in ``torch.compile``. The single remaining CPU-GPU sync @@ -353,7 +354,7 @@ def lsmr_torch( Maximum iterations. Defaults to ``min(m, n)``. use_compile : bool or None Whether to ``torch.compile`` the scalar step. ``None`` (default) - auto-selects: **True** on GPU, **False** on CPU. + auto-selects: **True** on CUDA, **False** on CPU and MPS. Returns ------- @@ -375,7 +376,7 @@ def lsmr_torch( # Auto-detect compilation if use_compile is None: - use_compile = device.type in ("cuda", "mps") + use_compile = device.type == "cuda" # Get compiled or uncompiled step function step_fn = _get_compiled_step(device.type) if use_compile else _scalar_step diff --git a/tests/test_lsmr_compiled.py b/tests/test_lsmr_compiled.py index 2ece2bf60..078f5f907 100644 --- a/tests/test_lsmr_compiled.py +++ b/tests/test_lsmr_compiled.py @@ -14,7 +14,7 @@ import torch # Reference: original scalar-state LSMR (for correctness comparison) -from pyfixest.estimation.torch.lsmr_torch import lsmr_torch as lsmr_torch_original +from pyfixest.estimation.torch.lsmr_torch import _lsmr_scalar as lsmr_torch_original from pyfixest.estimation.torch.lsmr_torch_compiled import lsmr_torch # --------------------------------------------------------------------------- From dc8fb66c7db1f4fe3d2bdc9b3918cda9300016c7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Mar 2026 15:51:02 +0000 Subject: [PATCH 07/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pyfixest/estimation/torch/demean_torch_.py | 6 ++++-- pyfixest/estimation/torch/lsmr_torch.py | 19 +++++++++++++++---- 
pyfixest/estimation/torch/lsmr_torch_fused.py | 1 - tests/test_lsmr_fused.py | 4 +--- 4 files changed, 20 insertions(+), 10 deletions(-) diff --git a/pyfixest/estimation/torch/demean_torch_.py b/pyfixest/estimation/torch/demean_torch_.py index f55e21944..8a9e51d4f 100644 --- a/pyfixest/estimation/torch/demean_torch_.py +++ b/pyfixest/estimation/torch/demean_torch_.py @@ -415,8 +415,10 @@ def t(self) -> _PreconditionedSparse: """Return cached transpose view.""" if self._T is None: self._T = _PreconditionedSparse( - self._D, self._M_inv, - _transposed=not self._transposed, _D_t=self._D_t, + self._D, + self._M_inv, + _transposed=not self._transposed, + _D_t=self._D_t, ) self._T._T = self # cross-link so .t().t() returns self return self._T diff --git a/pyfixest/estimation/torch/lsmr_torch.py b/pyfixest/estimation/torch/lsmr_torch.py index 784908661..e4d8bd0be 100644 --- a/pyfixest/estimation/torch/lsmr_torch.py +++ b/pyfixest/estimation/torch/lsmr_torch.py @@ -304,10 +304,21 @@ def lsmr_torch( if use_compile: return _lsmr_compiled( - A, b, damp=damp, atol=atol, btol=btol, - conlim=conlim, maxiter=maxiter, use_compile=True, + A, + b, + damp=damp, + atol=atol, + btol=btol, + conlim=conlim, + maxiter=maxiter, + use_compile=True, ) return _lsmr_scalar( - A, b, damp=damp, atol=atol, btol=btol, - conlim=conlim, maxiter=maxiter, + A, + b, + damp=damp, + atol=atol, + btol=btol, + conlim=conlim, + maxiter=maxiter, ) diff --git a/pyfixest/estimation/torch/lsmr_torch_fused.py b/pyfixest/estimation/torch/lsmr_torch_fused.py index ae1e1cb98..1017fb26f 100644 --- a/pyfixest/estimation/torch/lsmr_torch_fused.py +++ b/pyfixest/estimation/torch/lsmr_torch_fused.py @@ -19,7 +19,6 @@ import torch - # --------------------------------------------------------------------------- # Branchless Givens rotation # --------------------------------------------------------------------------- diff --git a/tests/test_lsmr_fused.py b/tests/test_lsmr_fused.py index be612e834..537996f5a 100644 --- 
a/tests/test_lsmr_fused.py +++ b/tests/test_lsmr_fused.py @@ -74,7 +74,7 @@ def test_check_every_correctness(check_every): def test_zero_rhs(): - """b = 0 should return x = 0.""" + """B = 0 should return x = 0.""" A, _ = _make_sparse_problem(100, 50) b = torch.zeros(100, dtype=torch.float64) x, istop, itn, *_ = lsmr_torch_fused(A, b) @@ -97,8 +97,6 @@ def test_damping(): def test_sym_ortho_matches_scipy(): """Branchless _sym_ortho_t should match SciPy's convention.""" - import math - from pyfixest.estimation.torch.lsmr_torch import _sym_ortho from pyfixest.estimation.torch.lsmr_torch_fused import _sym_ortho_t From f215689d1b8c5bce293f02c647eda38289bc3f87 Mon Sep 17 00:00:00 2001 From: Jan Boelts Date: Fri, 6 Mar 2026 17:09:28 +0000 Subject: [PATCH 08/16] Refactor LSMR implementation: remove old torch-based solvers and introduce fused version - Deleted `lsmr_torch_compiled.py` and `lsmr_torch_fused.py` files, consolidating functionality into `lsmr_torch.py`. - Updated tests to reflect changes in the LSMR implementation, ensuring correctness and performance benchmarks. - Adjusted convergence checks and state management to optimize CPU-GPU synchronization. - Enhanced the branchless Givens rotation implementation for improved efficiency on CUDA/MPS. 
--- benchmarks/benchmarks.py | 38 ++ benchmarks/config.json | 32 +- pyfixest/estimation/torch/lsmr_torch.py | 540 ++++++++++++++++- .../estimation/torch/lsmr_torch_compiled.py | 542 ------------------ pyfixest/estimation/torch/lsmr_torch_fused.py | 302 ---------- tests/test_lsmr_compiled.py | 6 +- tests/test_lsmr_fused.py | 157 ----- 7 files changed, 586 insertions(+), 1031 deletions(-) delete mode 100644 pyfixest/estimation/torch/lsmr_torch_compiled.py delete mode 100644 pyfixest/estimation/torch/lsmr_torch_fused.py delete mode 100644 tests/test_lsmr_fused.py diff --git a/benchmarks/benchmarks.py b/benchmarks/benchmarks.py index ae5091f7b..e6a4289dc 100644 --- a/benchmarks/benchmarks.py +++ b/benchmarks/benchmarks.py @@ -36,6 +36,14 @@ HAS_MPS = False HAS_CUDA = False +# Optional CuPy availability detection +try: + import cupy + + HAS_CUPY = cupy.cuda.runtime.getDeviceCount() > 0 +except (ImportError, Exception): + HAS_CUPY = False + # Backends that accept a backend= argument when called through pyfixest runners _PYFIXEST_BACKENDS = { "scipy", @@ -46,6 +54,9 @@ "torch_mps", "torch_cuda", "torch_cuda32", + "cupy", + "cupy32", + "cupy64", } # ============================================================================= @@ -65,6 +76,9 @@ def _append_optional_backends(estimators, label_prefix, runner_func, func_name): if HAS_CUDA: optional.append(("torch_cuda", "torch_cuda")) optional.append(("torch_cuda32", "torch_cuda32")) + if HAS_CUPY: + optional.append(("cupy64", "cupy64")) + optional.append(("cupy32", "cupy32")) for suffix, backend in optional: estimators.append( (f"{label_prefix} ({suffix})", backend, runner_func, False, func_name) @@ -379,6 +393,7 @@ def parse_dataset_name(name: str) -> tuple[str, int]: "500k": 500_000, "1m": 1_000_000, "2m": 2_000_000, + "3m": 3_000_000, "5m": 5_000_000, } parts = name.rsplit("_", 1) @@ -402,11 +417,17 @@ def run_benchmark( timeout_estimators: set[str] | None = None, formulas_override: dict[int, str] | None = None, 
allowed_datasets: set[str] | None = None, + pyfixest_only: bool = False, + backend_filter: set[str] | None = None, ) -> None: """Run benchmarks on all datasets in data_dir.""" if timeout_estimators is None: timeout_estimators = set() estimators, formulas = get_estimators(benchmark_type, timeout_estimators) + if pyfixest_only: + estimators = [e for e in estimators if e[1] in _PYFIXEST_BACKENDS] + if backend_filter: + estimators = [e for e in estimators if e[1] in backend_filter] if formulas_override: formulas = formulas_override @@ -596,6 +617,17 @@ def main(): default=None, help="Filter datasets by name (e.g., 'simple' to exclude 'difficult')", ) + parser.add_argument( + "--pyfixest-only", + action="store_true", + help="Skip non-pyfixest estimators (linearmodels, statsmodels)", + ) + parser.add_argument( + "--backends", + type=str, + default=None, + help="Comma-separated list of backends to run (e.g., 'torch_cuda,torch_cuda32,cupy64,cupy32')", + ) args = parser.parse_args() config = load_config("bench.json") @@ -606,6 +638,10 @@ def main(): formulas_override = get_formulas_from_config(config, args.type) allowed_datasets = get_allowed_datasets(config, args.type) + backend_filter = None + if args.backends: + backend_filter = set(b.strip() for b in args.backends.split(",")) + run_benchmark( args.data_dir, args.output, @@ -615,6 +651,8 @@ def main(): timeout_estimators, formulas_override, allowed_datasets, + pyfixest_only=args.pyfixest_only, + backend_filter=backend_filter, ) diff --git a/benchmarks/config.json b/benchmarks/config.json index 6f974cea1..24251fdd1 100644 --- a/benchmarks/config.json +++ b/benchmarks/config.json @@ -43,7 +43,13 @@ { "name": "simple_500k", "n": 500000, "type": "simple" }, { "name": "difficult_500k", "n": 500000, "type": "difficult" }, { "name": "simple_1m", "n": 1000000, "type": "simple" }, - { "name": "difficult_1m", "n": 1000000, "type": "difficult" } + { "name": "difficult_1m", "n": 1000000, "type": "difficult" }, + { "name": 
"simple_2m", "n": 2000000, "type": "simple" }, + { "name": "difficult_2m", "n": 2000000, "type": "difficult" }, + { "name": "simple_3m", "n": 3000000, "type": "simple" }, + { "name": "difficult_3m", "n": 3000000, "type": "difficult" }, + { "name": "simple_5m", "n": 5000000, "type": "simple" }, + { "name": "difficult_5m", "n": 5000000, "type": "difficult" } ], "datasets_by_type": { "ols": [ @@ -56,7 +62,13 @@ "simple_500k", "difficult_500k", "simple_1m", - "difficult_1m" + "difficult_1m", + "simple_2m", + "difficult_2m", + "simple_3m", + "difficult_3m", + "simple_5m", + "difficult_5m" ], "poisson": [ "simple_1k", @@ -68,7 +80,13 @@ "simple_500k", "difficult_500k", "simple_1m", - "difficult_1m" + "difficult_1m", + "simple_2m", + "difficult_2m", + "simple_3m", + "difficult_3m", + "simple_5m", + "difficult_5m" ], "logit": [ "simple_1k", @@ -80,7 +98,13 @@ "simple_500k", "difficult_500k", "simple_1m", - "difficult_1m" + "difficult_1m", + "simple_2m", + "difficult_2m", + "simple_3m", + "difficult_3m", + "simple_5m", + "difficult_5m" ] } } diff --git a/pyfixest/estimation/torch/lsmr_torch.py b/pyfixest/estimation/torch/lsmr_torch.py index e4d8bd0be..b4ef26da3 100644 --- a/pyfixest/estimation/torch/lsmr_torch.py +++ b/pyfixest/estimation/torch/lsmr_torch.py @@ -1,9 +1,18 @@ """ -Pure PyTorch implementation of the LSMR algorithm. +Pure PyTorch LSMR iterative solver with optional torch.compile kernel fusion. -Ported from SciPy's `scipy.sparse.linalg.lsmr` (Fong & Saunders, 2011). -All vector operations use torch tensors (staying on-device for GPU), -while scalar Givens rotations use Python `math` to avoid autograd overhead. +Two implementations live in this file: + +1. ``_lsmr_eager`` — eager-mode PyTorch, Python-float Givens rotations. + Best for CPU and MPS (Metal command-buffer batching already amortizes + kernel-launch overhead). + +2. 
``_lsmr_compiled`` — packs all scalar state into a 1-D tensor and runs + the Givens / norm / convergence work through a ``torch.compile``-d + kernel. On CUDA this fuses ~60 per-iteration kernel launches into one. + +The public entry point ``lsmr_torch()`` dispatches automatically: +CUDA → compiled, CPU/MPS → scalar. Pass ``use_compile=True`` to override. Reference: D. C.-L. Fong and M. A. Saunders, @@ -14,10 +23,41 @@ from __future__ import annotations import math +import threading import torch -from pyfixest.estimation.torch.lsmr_torch_compiled import lsmr_torch as _lsmr_compiled +# --------------------------------------------------------------------------- +# Sparse matvec helpers +# --------------------------------------------------------------------------- + + +def _matvec(A, v: torch.Tensor) -> torch.Tensor: + """Compute A @ v for both torch.Tensor (dense/sparse) and duck-typed wrappers.""" + if isinstance(A, torch.Tensor): + return A @ v + return A.mv(v) + + +def _rmatvec(At, u: torch.Tensor) -> torch.Tensor: + """Multiply A^T @ u using a pre-computed transpose.""" + if isinstance(At, torch.Tensor): + return At @ u + return At.mv(u) + + +def _precompute_transpose(A): + """Pre-compute A^T in a GPU-friendly layout to avoid per-iteration reconversion.""" + if isinstance(A, torch.Tensor) and A.is_sparse_csr: + return A.t().to_sparse_csr() + elif isinstance(A, torch.Tensor): + return A.t().contiguous() + return A.t() # LinearOperator / wrapper — assume .mv() is efficient + + +# --------------------------------------------------------------------------- +# Scalar Givens rotation (Python math — no autograd overhead) +# --------------------------------------------------------------------------- def _sym_ortho(a: float, b: float) -> tuple[float, float, float]: @@ -28,11 +68,10 @@ def _sym_ortho(a: float, b: float) -> tuple[float, float, float]: [ c s ] [ a ] = [ r ] [-s c ] [ b ] [ 0 ] - This is the same algorithm as SciPy's `_sym_ortho` from LSQR, + This is the same 
algorithm as SciPy's ``_sym_ortho`` from LSQR, using pure Python math for speed on scalar values. """ if b == 0.0: - # math.copysign(1, 0) = 1.0 but np.sign(0) = 0.0; match SciPy's behavior c = 0.0 if a == 0.0 else math.copysign(1.0, a) return c, 0.0, abs(a) elif a == 0.0: @@ -50,21 +89,12 @@ def _sym_ortho(a: float, b: float) -> tuple[float, float, float]: return c, s, r -def _matvec(A, v: torch.Tensor) -> torch.Tensor: - """Compute A @ v for both torch.Tensor (dense/sparse) and duck-typed wrappers.""" - if isinstance(A, torch.Tensor): - return A @ v - return A.mv(v) - - -def _rmatvec(A, u: torch.Tensor) -> torch.Tensor: - """A^T @ u — works for both torch.Tensor (dense/sparse) and duck-typed wrappers.""" - if isinstance(A, torch.Tensor): - return A.t() @ u - return A.t().mv(u) +# =========================================================================== +# Implementation 1: scalar-state LSMR (CPU / MPS) +# =========================================================================== -def _lsmr_scalar( +def _lsmr_eager( A, b: torch.Tensor, damp: float = 0.0, @@ -121,6 +151,8 @@ def _lsmr_scalar( if maxiter is None: maxiter = min(m, n) + At = _precompute_transpose(A) + # --- Initialize Golub-Kahan bidiagonalization --- u = b.clone() normb = torch.linalg.norm(b).item() @@ -130,7 +162,7 @@ def _lsmr_scalar( if beta > 0: u = u * (1.0 / beta) - v = _rmatvec(A, u) + v = _rmatvec(At, u) alpha = torch.linalg.norm(v).item() else: v = torch.zeros(n, device=device, dtype=dtype) @@ -191,7 +223,7 @@ def _lsmr_scalar( if beta > 0: u *= 1.0 / beta - v = _rmatvec(A, u) - beta * v + v = _rmatvec(At, u) - beta * v alpha = torch.linalg.norm(v).item() if alpha > 0: v *= 1.0 / alpha @@ -279,6 +311,468 @@ def _lsmr_scalar( return x, istop, itn, normr, normar, normA, condA, normx +# =========================================================================== +# Implementation 2: compiled-state LSMR (CUDA) +# =========================================================================== + +# 
---------------------------------------------------------------------------
+# Packed scalar state layout
+# ---------------------------------------------------------------------------
+# All scalar state is packed into a single 1-D tensor to minimize Metal buffer
+# slots (hardware limit: 31 per kernel).
+#
+# Input state (20 elements):
+_I_ALPHABAR = 0
+_I_DAMP = 1
+_I_BETA = 2
+_I_ALPHA = 3
+_I_SBAR = 4
+_I_CBAR = 5
+_I_ZETABAR = 6
+_I_RHO = 7
+_I_RHOBAR = 8
+_I_RHODOLD = 9
+_I_TAUTILDEOLD = 10
+_I_THETATILDE = 11
+_I_BETADD = 12
+_I_BETAD = 13
+_I_D = 14
+_I_NORMA2 = 15
+_I_MAXRBAR = 16
+_I_MINRBAR = 17
+_I_NORMB = 18
+_I_ZETA = 19  # previous iteration's zeta (for normr estimation)
+
+# Constants (3 elements): atol, btol, ctol
+
+# Output adds extra slots for vector update coefficients:
+_O_THETANEW = 20
+_O_THETABAR = 21
+_O_ZETA = 22
+_O_RHOOLD = 23
+_O_RHOBAROLD = 24
+_O_CONVERGED = 25
+_O_NORMR = 26
+_O_NORMAR = 27
+_O_NORMA = 28
+_O_CONDA = 29
+_O_NORMX_EST = 30  # placeholder, actual normx computed from vector
+
+_STATE_SIZE = 20
+
+
+# ---------------------------------------------------------------------------
+# Overflow-safe hypot (replaces torch.hypot for Metal compatibility)
+# ---------------------------------------------------------------------------
+
+
+def _safe_hypot(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+    """
+    Overflow-safe hypot: ``sqrt(a**2 + b**2)`` without intermediate overflow.
+
+    Uses max/min scaling: ``hypot(a,b) = max(|a|,|b|) * sqrt(1 + (min/max)**2)``.
+    Since ``min/max <= 1``, the argument to sqrt never exceeds 2.
+    Compiles to ~6 Metal/CUDA ops that fuse into the surrounding kernel.
+ """ + abs_a = torch.abs(a) + abs_b = torch.abs(b) + big = torch.maximum(abs_a, abs_b) + small = torch.minimum(abs_a, abs_b) + safe_big = torch.where(big == 0, torch.ones_like(big), big) + ratio = small / safe_big + return torch.where( + big == 0, + torch.zeros_like(big), + big * torch.sqrt(1.0 + ratio * ratio), + ) + + +# --------------------------------------------------------------------------- +# Compiled scalar step (single Metal/CUDA kernel after fusion) +# --------------------------------------------------------------------------- + + +def _scalar_step(state: torch.Tensor, consts: torch.Tensor) -> torch.Tensor: + """ + All scalar work for one LSMR iteration: 4 Givens rotations, norm/cond + estimation, and convergence check. + + Packed I/O keeps Metal buffer count to 3 (state_in, consts, state_out). + Uses overflow-safe hypot (no torch.hypot — unsupported in Metal codegen). + """ + # Unpack + alphabar = state[_I_ALPHABAR] + damp = state[_I_DAMP] + beta = state[_I_BETA] + alpha = state[_I_ALPHA] + sbar = state[_I_SBAR] + cbar = state[_I_CBAR] + zetabar = state[_I_ZETABAR] + rho = state[_I_RHO] + rhobar = state[_I_RHOBAR] + rhodold = state[_I_RHODOLD] + tautildeold = state[_I_TAUTILDEOLD] + thetatilde = state[_I_THETATILDE] + betadd = state[_I_BETADD] + betad = state[_I_BETAD] + d = state[_I_D] + normA2 = state[_I_NORMA2] + maxrbar = state[_I_MAXRBAR] + minrbar = state[_I_MINRBAR] + normb = state[_I_NORMB] + zetaold = state[_I_ZETA] # zeta from previous iteration (for normr estimation) + + atol_t = consts[0] + ctol = consts[2] + + _ZERO = state[_I_ALPHABAR] * 0.0 # device-local zero + _ONE = _ZERO + 1.0 + + # --- Givens 1: (alphabar, damp) --- + r1 = _safe_hypot(alphabar, damp) + safe_r1 = torch.where(r1 == _ZERO, _ONE, r1) + chat = torch.where(r1 == _ZERO, _ZERO, alphabar / safe_r1) + shat = torch.where(r1 == _ZERO, _ZERO, damp / safe_r1) + + # --- Givens 2: (alphahat=r1, beta) --- + rhoold = rho + r2 = _safe_hypot(r1, beta) + safe_r2 = torch.where(r2 == 
_ZERO, _ONE, r2) + c = torch.where(r2 == _ZERO, _ZERO, r1 / safe_r2) + s = torch.where(r2 == _ZERO, _ZERO, beta / safe_r2) + rho_new = r2 + thetanew = s * alpha + alphabar_new = c * alpha + + # --- Givens 3: rhobar --- + rhobarold = rhobar + thetabar = sbar * rho_new + rhotemp = cbar * rho_new + r3 = _safe_hypot(rhotemp, thetanew) + safe_r3 = torch.where(r3 == _ZERO, _ONE, r3) + cbar_new = torch.where(r3 == _ZERO, _ZERO, rhotemp / safe_r3) + sbar_new = torch.where(r3 == _ZERO, _ZERO, thetanew / safe_r3) + rhobar_new = r3 + zeta = cbar_new * zetabar + zetabar_new = -sbar_new * zetabar + + # --- ||r|| estimation --- + betaacute = chat * betadd + betacheck = -shat * betadd + betahat = c * betaacute + betadd_new = -s * betaacute + + # Givens 4: rhotilde + r4 = _safe_hypot(rhodold, thetabar) + safe_r4 = torch.where(r4 == _ZERO, _ONE, r4) + ctildeold = torch.where(r4 == _ZERO, _ZERO, rhodold / safe_r4) + stildeold = torch.where(r4 == _ZERO, _ZERO, thetabar / safe_r4) + + thetatilde_new = stildeold * rhobar_new + rhodold_new = ctildeold * rhobar_new + betad_new = -stildeold * betad + ctildeold * betahat + + tautildeold_new = (zetaold - thetatilde * tautildeold) / torch.clamp(r4, min=1e-30) + taud = (zeta - thetatilde_new * tautildeold_new) / torch.clamp( + rhodold_new, min=1e-30 + ) + d_new = d + betacheck * betacheck + normr = torch.sqrt(d_new + (betad_new - taud) ** 2 + betadd_new * betadd_new) + + # --- ||A|| estimation --- + normA2_new = normA2 + beta * beta + normA = torch.sqrt(normA2_new) + normA2_final = normA2_new + alpha * alpha + + # --- cond(A) estimation --- + maxrbar_new = torch.maximum(maxrbar, rhobarold) + # Match SciPy: only update minrbar from iteration 2 onward. + # maxrbar == 0 on the first call (initial state), so use it as guard. 
+ minrbar_new = torch.where(maxrbar > 0, torch.minimum(minrbar, rhobarold), minrbar) + condA = torch.maximum(maxrbar_new, rhotemp) / torch.clamp( + torch.minimum(minrbar_new, rhotemp), min=1e-30 + ) + + # --- Convergence check --- + normar = torch.abs(zetabar_new) + test2 = normar / torch.clamp(normA * normr, min=1e-30) + test3 = _ONE / condA + + converged_flag = torch.where( + (test2 <= atol_t) + | (test3 <= ctol) + | (_ONE + test2 <= _ONE) + | (_ONE + test3 <= _ONE), + _ONE, + _ZERO, + ) + + # --- Pack output --- + return torch.stack( + [ + alphabar_new, # 0 _I_ALPHABAR + damp, # 1 _I_DAMP (pass through) + beta, # 2 _I_BETA (pass through, updated by caller) + alpha, # 3 _I_ALPHA (pass through, updated by caller) + sbar_new, # 4 _I_SBAR + cbar_new, # 5 _I_CBAR + zetabar_new, # 6 _I_ZETABAR + rho_new, # 7 _I_RHO + rhobar_new, # 8 _I_RHOBAR + rhodold_new, # 9 _I_RHODOLD + tautildeold_new, # 10 _I_TAUTILDEOLD + thetatilde_new, # 11 _I_THETATILDE + betadd_new, # 12 _I_BETADD + betad_new, # 13 _I_BETAD + d_new, # 14 _I_D + normA2_final, # 15 _I_NORMA2 + maxrbar_new, # 16 _I_MAXRBAR + minrbar_new, # 17 _I_MINRBAR + normb, # 18 _I_NORMB (pass through) + zeta, # 19 _I_ZETA (saved for next iteration's zetaold) + thetanew, # 20 _O_THETANEW (for vector update) + thetabar, # 21 _O_THETABAR (for vector update) + zeta, # 22 _O_ZETA (for vector update — same as slot 19) + rhoold, # 23 _O_RHOOLD (for vector update) + rhobarold, # 24 _O_RHOBAROLD (for vector update) + converged_flag, # 25 _O_CONVERGED + normr, # 26 _O_NORMR + normar, # 27 _O_NORMAR + normA, # 28 _O_NORMA + condA, # 29 _O_CONDA + _ZERO, # 30 _O_NORMX_EST (placeholder) + ] + ) + + +# --------------------------------------------------------------------------- +# Module-level compilation cache +# --------------------------------------------------------------------------- +_compiled_step_cache: dict[str, object] = {} +_cache_lock = threading.Lock() + + +def _get_compiled_step(device_type: str): + """Get or create 
compiled scalar step for the given device type.""" + if device_type in _compiled_step_cache: + return _compiled_step_cache[device_type] + with _cache_lock: + # Double-check after acquiring lock + if device_type not in _compiled_step_cache: + try: + _compiled_step_cache[device_type] = torch.compile( + _scalar_step, backend="inductor", fullgraph=True + ) + except Exception: + # Fallback: no compilation available + _compiled_step_cache[device_type] = _scalar_step + return _compiled_step_cache[device_type] + + +def _lsmr_compiled( + A, + b: torch.Tensor, + damp: float = 0.0, + atol: float = 1e-8, + btol: float = 1e-8, + conlim: float = 1e8, + maxiter: int | None = None, + use_compile: bool = True, +) -> tuple[torch.Tensor, int, int, float, float, float, float, float]: + """ + LSMR with packed-tensor scalar state and optional torch.compile fusion. + + On CUDA the scalar Givens rotations, norm estimation, and convergence + check are fused into a **single GPU kernel** via ``torch.compile`` + + Inductor, eliminating ~60 per-iteration kernel launches. + + Called by the ``lsmr_torch`` dispatcher; ``use_compile`` is already + resolved by the caller (no auto-detection here). 
+ """ + m, n = A.shape + device = b.device + dtype = b.dtype + + if maxiter is None: + maxiter = min(m, n) + + # Get compiled or uncompiled step function + step_fn = _get_compiled_step(device.type) if use_compile else _scalar_step + + At = _precompute_transpose(A) + + # --- Initialize Golub-Kahan bidiagonalization --- + u = b.clone() + normb = torch.linalg.norm(b) + + x = torch.zeros(n, device=device, dtype=dtype) + beta = normb.clone() + + # Safe normalize + u = u * torch.where(beta > 0, 1.0 / torch.clamp(beta, min=1e-30), beta * 0.0) + + v = _rmatvec(At, u) + alpha = torch.linalg.norm(v) + v = v * torch.where(alpha > 0, 1.0 / torch.clamp(alpha, min=1e-30), alpha * 0.0) + + # --- Pack initial scalar state --- + state = torch.zeros(_STATE_SIZE, device=device, dtype=dtype) + state[_I_ALPHABAR] = alpha + state[_I_DAMP] = damp + state[_I_BETA] = beta + state[_I_ALPHA] = alpha + state[_I_SBAR] = 0.0 + state[_I_CBAR] = 1.0 + state[_I_ZETABAR] = alpha * beta + state[_I_RHO] = 1.0 + state[_I_RHOBAR] = 1.0 + state[_I_RHODOLD] = 1.0 + state[_I_TAUTILDEOLD] = 0.0 + state[_I_THETATILDE] = 0.0 + state[_I_BETADD] = beta + state[_I_BETAD] = 0.0 + state[_I_D] = 0.0 + state[_I_NORMA2] = alpha * alpha + state[_I_MAXRBAR] = 0.0 + state[_I_MINRBAR] = 1e100 if dtype == torch.float64 else 1e10 + state[_I_NORMB] = normb + state[_I_ZETA] = 0.0 # initial zeta (no previous iteration) + + ctol = 1.0 / conlim if conlim > 0 else 0.0 + consts = torch.tensor([atol, btol, ctol], device=device, dtype=dtype) + + # Early exit check + normar_init = (alpha * beta).item() + if normar_init == 0.0: + return x, 0, 0, beta.item(), 0.0, alpha.item(), 1.0, 0.0 + if normb.item() == 0.0: + x.zero_() + return x, 0, 0, beta.item(), 0.0, alpha.item(), 1.0, 0.0 + + h = v.clone() + hbar = torch.zeros(n, device=device, dtype=dtype) + + # --- Main iteration loop --- + itn = 0 + istop = 0 + + while itn < maxiter: + itn += 1 + + # Phase 1: Sparse matvec (not compilable) + # state[_I_ALPHA] holds the current alpha 
(passed through by _scalar_step) + u = _matvec(A, v) - state[_I_ALPHA] * u + beta_new = torch.linalg.norm(u) + u = u * torch.where( + beta_new > 0, + 1.0 / torch.clamp(beta_new, min=1e-30), + beta_new * 0.0, + ) + + v = _rmatvec(At, u) - beta_new * v + alpha_new = torch.linalg.norm(v) + v = v * torch.where( + alpha_new > 0, + 1.0 / torch.clamp(alpha_new, min=1e-30), + alpha_new * 0.0, + ) + + # Update beta/alpha in state for the scalar step + state[_I_BETA] = beta_new + state[_I_ALPHA] = alpha_new + + # Phase 2: Compiled scalar step (single GPU kernel on CUDA) + out = step_fn(state, consts) + + # Phase 3: Vector updates using scalar results from compiled step + thetanew = out[_O_THETANEW] + thetabar = out[_O_THETABAR] + zeta = out[_O_ZETA] + rho_new = out[_I_RHO] + rhobar_new = out[_I_RHOBAR] + rhoold = out[_O_RHOOLD] + rhobarold = out[_O_RHOBAROLD] + + hbar = h + hbar * (-(thetabar * rho_new) / (rhoold * rhobarold)) + x = x + (zeta / (rho_new * rhobar_new)) * hbar + h = v + h * (-(thetanew / rho_new)) + + # Propagate state for next iteration + state = out[:_STATE_SIZE] + + # Convergence check — single .item() sync per iteration. + # The compiled _scalar_step checks test2 (atol) and test3 (ctol). + # The btol-based test1 depends on normx (a vector quantity) and is + # computed here on-device, then combined with the scalar step's flag. + # Only one .item() call reads the combined boolean; all other tensor + # ops queue on the GPU stream without forcing a pipeline stall. 
+ normx_t = torch.linalg.norm(x) + normr_t = out[_O_NORMR] + normA_t = out[_O_NORMA] + normb_t = out[_I_NORMB] + + test1_t = normr_t / torch.clamp(normb_t, min=1e-30) + t1_t = test1_t / (1.0 + normA_t * normx_t / torch.clamp(normb_t, min=1e-30)) + rtol_t = btol + atol * normA_t * normx_t / torch.clamp(normb_t, min=1e-30) + + converged_btol = (test1_t <= rtol_t) | (1.0 + t1_t <= 1.0) + converged_any = (out[_O_CONVERGED] > 0.5) | converged_btol + + if converged_any.item(): + # Pull scalars to CPU only at exit (one-time cost) + normr_val = normr_t.item() + normA_val = normA_t.item() + normx_val = normx_t.item() + normb_val = normb_t.item() + normar_val = out[_O_NORMAR].item() + condA_val = out[_O_CONDA].item() + + test1 = normr_val / max(normb_val, 1e-30) + test2 = normar_val / max(normA_val * normr_val, 1e-30) + test3 = 1.0 / condA_val + t1 = test1 / (1.0 + normA_val * normx_val / max(normb_val, 1e-30)) + _rtol = btol + atol * normA_val * normx_val / max(normb_val, 1e-30) + + # Priority order matches SciPy LSMR (lowest istop wins) + if 1.0 + test3 <= 1.0: + istop = 6 + if 1.0 + test2 <= 1.0: + istop = 5 + if 1.0 + t1 <= 1.0: + istop = 4 + if test3 <= ctol: + istop = 3 + if test2 <= atol: + istop = 2 + if test1 <= _rtol: + istop = 1 + break + + if itn >= maxiter and istop == 0: + istop = 7 + + # Handle case where loop never ran (maxiter=0 or similar) + if itn == 0: + return x, istop, 0, normb.item(), normar_init, alpha.item(), 1.0, 0.0 + + normx_val = torch.linalg.norm(x).item() + return ( + x, + istop, + itn, + out[_O_NORMR].item(), + out[_O_NORMAR].item(), + out[_O_NORMA].item(), + out[_O_CONDA].item(), + normx_val, + ) + + +# =========================================================================== +# Public API — dispatcher +# =========================================================================== + + def lsmr_torch( A, b: torch.Tensor, @@ -313,7 +807,7 @@ def lsmr_torch( maxiter=maxiter, use_compile=True, ) - return _lsmr_scalar( + return _lsmr_eager( A, b, 
damp=damp, diff --git a/pyfixest/estimation/torch/lsmr_torch_compiled.py b/pyfixest/estimation/torch/lsmr_torch_compiled.py deleted file mode 100644 index 6ebb576ef..000000000 --- a/pyfixest/estimation/torch/lsmr_torch_compiled.py +++ /dev/null @@ -1,542 +0,0 @@ -""" -LSMR iterative solver in pure PyTorch with optional torch.compile kernel fusion. - -The solver splits each iteration into three phases: - 1. Sparse matvec (A @ v, A.T @ u) — cannot be compiled (sparse CSR unsupported) - 2. Scalar Givens rotations + norm estimation + convergence — compiled on GPU - 3. Vector updates (h, hbar, x) — use scalar results from phase 2 - -Phase 2 involves ~60 scalar operations that, without compilation, dispatch as -~60 individual GPU kernels (~15μs each on MPS). torch.compile fuses them into -a SINGLE kernel via the Inductor backend (Metal shaders on MPS, CUDA kernels -on NVIDIA GPUs). - -Workarounds for MPS/Metal limitations: - - torch.hypot not in Metal codegen → overflow-safe manual hypot via max/min scaling - - Metal kernel limited to 31 buffer args → pack all scalars into 1-D tensors - -The safe manual hypot and packed layout work uniformly across all backends -(CPU, MPS, CUDA) with negligible overhead (~5%) after fusion. - -On CPU the scalar step runs without compilation (no kernel launch overhead -to eliminate), so the packed layout is the only difference from a traditional -scalar-state LSMR. - -Reference: - D. C.-L. Fong and M. A. Saunders, - "LSMR: An iterative algorithm for sparse least-squares problems", - SIAM J. Sci. Comput., vol. 33, pp. 2950-2971, 2011. 
-""" - -from __future__ import annotations - -import threading - -import torch - -# --------------------------------------------------------------------------- -# Sparse matvec helpers (outside compiled region) -# --------------------------------------------------------------------------- - - -def _matvec(A, v: torch.Tensor) -> torch.Tensor: - if isinstance(A, torch.Tensor): - return A @ v - return A.mv(v) - - -def _rmatvec(A, u: torch.Tensor) -> torch.Tensor: - if isinstance(A, torch.Tensor): - return A.t() @ u - return A.t().mv(u) - - -# --------------------------------------------------------------------------- -# Packed scalar state layout -# --------------------------------------------------------------------------- -# We pack all scalar state into a single 1-D tensor to minimize Metal buffer -# slots (hardware limit: 31 per kernel). -# -# Input state (20 elements): -_I_ALPHABAR = 0 -_I_DAMP = 1 -_I_BETA = 2 -_I_ALPHA = 3 -_I_SBAR = 4 -_I_CBAR = 5 -_I_ZETABAR = 6 -_I_RHO = 7 -_I_RHOBAR = 8 -_I_RHODOLD = 9 -_I_TAUTILDEOLD = 10 -_I_THETATILDE = 11 -_I_BETADD = 12 -_I_BETAD = 13 -_I_D = 14 -_I_NORMA2 = 15 -_I_MAXRBAR = 16 -_I_MINRBAR = 17 -_I_NORMB = 18 -_I_ZETA = 19 # previous iteration's zeta (for normr estimation) - -# Constants (3 elements): atol, btol, ctol - -# Output adds extra slots for vector update coefficients: -_O_THETANEW = 20 -_O_THETABAR = 21 -_O_ZETA = 22 -_O_RHOOLD = 23 -_O_RHOBAROLD = 24 -_O_CONVERGED = 25 -_O_NORMR = 26 -_O_NORMAR = 27 -_O_NORMA = 28 -_O_CONDA = 29 -_O_NORMX_EST = 30 # placeholder, actual normx computed from vector - -_STATE_SIZE = 20 - - -# --------------------------------------------------------------------------- -# Overflow-safe hypot (replaces torch.hypot for Metal compatibility) -# --------------------------------------------------------------------------- - - -def _safe_hypot(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: - """ - Overflow-safe hypot: ``sqrt(a² + b²)`` without intermediate overflow. 
- - Uses max/min scaling: ``hypot(a,b) = max(|a|,|b|) * sqrt(1 + (min/max)²)``. - Since ``min/max ≤ 1``, the argument to sqrt never exceeds 2. - Compiles to ~6 Metal/CUDA ops that fuse into the surrounding kernel. - """ - abs_a = torch.abs(a) - abs_b = torch.abs(b) - big = torch.maximum(abs_a, abs_b) - small = torch.minimum(abs_a, abs_b) - safe_big = torch.where(big == 0, torch.ones_like(big), big) - ratio = small / safe_big - return torch.where( - big == 0, - torch.zeros_like(big), - big * torch.sqrt(1.0 + ratio * ratio), - ) - - -# --------------------------------------------------------------------------- -# Compiled scalar step (single Metal/CUDA kernel after fusion) -# --------------------------------------------------------------------------- - - -def _scalar_step(state: torch.Tensor, consts: torch.Tensor) -> torch.Tensor: - """ - All scalar work for one LSMR iteration: 4 Givens rotations, norm/cond - estimation, and convergence check. - - Packed I/O keeps Metal buffer count to 3 (state_in, consts, state_out). - Uses overflow-safe hypot (no torch.hypot — unsupported in Metal codegen). 
- """ - # Unpack - alphabar = state[_I_ALPHABAR] - damp = state[_I_DAMP] - beta = state[_I_BETA] - alpha = state[_I_ALPHA] - sbar = state[_I_SBAR] - cbar = state[_I_CBAR] - zetabar = state[_I_ZETABAR] - rho = state[_I_RHO] - rhobar = state[_I_RHOBAR] - rhodold = state[_I_RHODOLD] - tautildeold = state[_I_TAUTILDEOLD] - thetatilde = state[_I_THETATILDE] - betadd = state[_I_BETADD] - betad = state[_I_BETAD] - d = state[_I_D] - normA2 = state[_I_NORMA2] - maxrbar = state[_I_MAXRBAR] - minrbar = state[_I_MINRBAR] - normb = state[_I_NORMB] - zetaold = state[_I_ZETA] # zeta from previous iteration (for normr estimation) - - atol_t = consts[0] - ctol = consts[2] - - _ZERO = state[_I_ALPHABAR] * 0.0 # device-local zero - _ONE = _ZERO + 1.0 - - # --- Givens 1: (alphabar, damp) --- - r1 = _safe_hypot(alphabar, damp) - safe_r1 = torch.where(r1 == _ZERO, _ONE, r1) - chat = torch.where(r1 == _ZERO, _ZERO, alphabar / safe_r1) - shat = torch.where(r1 == _ZERO, _ZERO, damp / safe_r1) - - # --- Givens 2: (alphahat=r1, beta) --- - rhoold = rho - r2 = _safe_hypot(r1, beta) - safe_r2 = torch.where(r2 == _ZERO, _ONE, r2) - c = torch.where(r2 == _ZERO, _ZERO, r1 / safe_r2) - s = torch.where(r2 == _ZERO, _ZERO, beta / safe_r2) - rho_new = r2 - thetanew = s * alpha - alphabar_new = c * alpha - - # --- Givens 3: rhobar --- - rhobarold = rhobar - thetabar = sbar * rho_new - rhotemp = cbar * rho_new - r3 = _safe_hypot(rhotemp, thetanew) - safe_r3 = torch.where(r3 == _ZERO, _ONE, r3) - cbar_new = torch.where(r3 == _ZERO, _ZERO, rhotemp / safe_r3) - sbar_new = torch.where(r3 == _ZERO, _ZERO, thetanew / safe_r3) - rhobar_new = r3 - zeta = cbar_new * zetabar - zetabar_new = -sbar_new * zetabar - - # --- ||r|| estimation --- - betaacute = chat * betadd - betacheck = -shat * betadd - betahat = c * betaacute - betadd_new = -s * betaacute - - # Givens 4: rhotilde - r4 = _safe_hypot(rhodold, thetabar) - safe_r4 = torch.where(r4 == _ZERO, _ONE, r4) - ctildeold = torch.where(r4 == _ZERO, _ZERO, rhodold 
/ safe_r4) - stildeold = torch.where(r4 == _ZERO, _ZERO, thetabar / safe_r4) - - thetatilde_new = stildeold * rhobar_new - rhodold_new = ctildeold * rhobar_new - betad_new = -stildeold * betad + ctildeold * betahat - - tautildeold_new = (zetaold - thetatilde * tautildeold) / torch.clamp(r4, min=1e-30) - taud = (zeta - thetatilde_new * tautildeold_new) / torch.clamp( - rhodold_new, min=1e-30 - ) - d_new = d + betacheck * betacheck - normr = torch.sqrt(d_new + (betad_new - taud) ** 2 + betadd_new * betadd_new) - - # --- ||A|| estimation --- - normA2_new = normA2 + beta * beta - normA = torch.sqrt(normA2_new) - normA2_final = normA2_new + alpha * alpha - - # --- cond(A) estimation --- - maxrbar_new = torch.maximum(maxrbar, rhobarold) - minrbar_new = torch.minimum(minrbar, rhobarold) - condA = torch.maximum(maxrbar_new, rhotemp) / torch.clamp( - torch.minimum(minrbar_new, rhotemp), min=1e-30 - ) - - # --- Convergence check --- - normar = torch.abs(zetabar_new) - test2 = normar / torch.clamp(normA * normr, min=1e-30) - test3 = _ONE / condA - - converged_flag = torch.where( - (test2 <= atol_t) - | (test3 <= ctol) - | (_ONE + test2 <= _ONE) - | (_ONE + test3 <= _ONE), - _ONE, - _ZERO, - ) - - # --- Pack output --- - return torch.stack( - [ - alphabar_new, # 0 _I_ALPHABAR - damp, # 1 _I_DAMP (pass through) - beta, # 2 _I_BETA (pass through, updated by caller) - alpha, # 3 _I_ALPHA (pass through, updated by caller) - sbar_new, # 4 _I_SBAR - cbar_new, # 5 _I_CBAR - zetabar_new, # 6 _I_ZETABAR - rho_new, # 7 _I_RHO - rhobar_new, # 8 _I_RHOBAR - rhodold_new, # 9 _I_RHODOLD - tautildeold_new, # 10 _I_TAUTILDEOLD - thetatilde_new, # 11 _I_THETATILDE - betadd_new, # 12 _I_BETADD - betad_new, # 13 _I_BETAD - d_new, # 14 _I_D - normA2_final, # 15 _I_NORMA2 - maxrbar_new, # 16 _I_MAXRBAR - minrbar_new, # 17 _I_MINRBAR - normb, # 18 _I_NORMB (pass through) - zeta, # 19 _I_ZETA (saved for next iteration's zetaold) - thetanew, # 20 _O_THETANEW (for vector update) - thetabar, # 21 
_O_THETABAR (for vector update) - zeta, # 22 _O_ZETA (for vector update — same as slot 19) - rhoold, # 23 _O_RHOOLD (for vector update) - rhobarold, # 24 _O_RHOBAROLD (for vector update) - converged_flag, # 25 _O_CONVERGED - normr, # 26 _O_NORMR - normar, # 27 _O_NORMAR - normA, # 28 _O_NORMA - condA, # 29 _O_CONDA - _ZERO, # 30 _O_NORMX_EST (placeholder) - ] - ) - - -# --------------------------------------------------------------------------- -# Module-level compilation cache -# --------------------------------------------------------------------------- -_compiled_step_cache: dict[str, object] = {} -_cache_lock = threading.Lock() - - -def _get_compiled_step(device_type: str): - """Get or create compiled scalar step for the given device type.""" - if device_type in _compiled_step_cache: - return _compiled_step_cache[device_type] - with _cache_lock: - # Double-check after acquiring lock - if device_type not in _compiled_step_cache: - try: - _compiled_step_cache[device_type] = torch.compile( - _scalar_step, backend="inductor", fullgraph=True - ) - except Exception: - # Fallback: no compilation available - _compiled_step_cache[device_type] = _scalar_step - return _compiled_step_cache[device_type] - - -# --------------------------------------------------------------------------- -# Public API -# --------------------------------------------------------------------------- - - -def lsmr_torch( - A, - b: torch.Tensor, - damp: float = 0.0, - atol: float = 1e-8, - btol: float = 1e-8, - conlim: float = 1e8, - maxiter: int | None = None, - use_compile: bool | None = None, -) -> tuple[torch.Tensor, int, int, float, float, float, float, float]: - """ - LSMR iterative solver for sparse least-squares, in pure PyTorch. - - Solves ``min ||b - Ax||_2`` (or the damped variant) where *A* is a sparse - CSR (or dense) tensor and *b* is a dense vector. All vector operations - stay on the tensor's device (CPU / CUDA / MPS). 
- - On CUDA the scalar Givens rotations, norm estimation, and convergence - check are fused into a **single GPU kernel** via ``torch.compile`` + - Inductor, eliminating ~60 per-iteration kernel launches. On CPU and - MPS the scalar step runs without compilation (MPS's Metal command - buffer batching already amortizes kernel launch overhead, so - compilation adds dispatch cost without measurable benefit). - - The while loop itself cannot be compiled because sparse CSR matvec is - not supported in ``torch.compile``. The single remaining CPU-GPU sync - per iteration (reading the convergence flag) is negligible after - compilation fuses the scalar step into one kernel. - - Parameters - ---------- - A : torch.Tensor - Sparse CSR tensor (or dense tensor / LinearOperator with ``.mv`` - and ``.t().mv`` support) of shape ``(m, n)``. - b : torch.Tensor - Dense vector of shape ``(m,)``. - damp : float - Damping factor for regularised least-squares. - atol, btol : float - Stopping tolerances (see SciPy LSMR documentation). - conlim : float - Condition-number limit. - maxiter : int or None - Maximum iterations. Defaults to ``min(m, n)``. - use_compile : bool or None - Whether to ``torch.compile`` the scalar step. ``None`` (default) - auto-selects: **True** on CUDA, **False** on CPU and MPS. - - Returns - ------- - x : torch.Tensor - Solution vector of shape ``(n,)``. - istop : int - Reason for stopping (0-7, same codes as SciPy LSMR). - itn : int - Number of iterations used. - normr, normar, normA, condA, normx : float - Diagnostic norms (see SciPy LSMR documentation). 
- """ - m, n = A.shape - device = b.device - dtype = b.dtype - - if maxiter is None: - maxiter = min(m, n) - - # Auto-detect compilation - if use_compile is None: - use_compile = device.type == "cuda" - - # Get compiled or uncompiled step function - step_fn = _get_compiled_step(device.type) if use_compile else _scalar_step - - # --- Initialize Golub-Kahan bidiagonalization --- - u = b.clone() - normb = torch.linalg.norm(b) - - x = torch.zeros(n, device=device, dtype=dtype) - beta = normb.clone() - - # Safe normalize - u = u * torch.where(beta > 0, 1.0 / torch.clamp(beta, min=1e-30), beta * 0.0) - - v = _rmatvec(A, u) - alpha = torch.linalg.norm(v) - v = v * torch.where(alpha > 0, 1.0 / torch.clamp(alpha, min=1e-30), alpha * 0.0) - - # --- Pack initial scalar state --- - state = torch.zeros(_STATE_SIZE, device=device, dtype=dtype) - state[_I_ALPHABAR] = alpha - state[_I_DAMP] = damp - state[_I_BETA] = beta - state[_I_ALPHA] = alpha - state[_I_SBAR] = 0.0 - state[_I_CBAR] = 1.0 - state[_I_ZETABAR] = alpha * beta - state[_I_RHO] = 1.0 - state[_I_RHOBAR] = 1.0 - state[_I_RHODOLD] = 1.0 - state[_I_TAUTILDEOLD] = 0.0 - state[_I_THETATILDE] = 0.0 - state[_I_BETADD] = beta - state[_I_BETAD] = 0.0 - state[_I_D] = 0.0 - state[_I_NORMA2] = alpha * alpha - state[_I_MAXRBAR] = 0.0 - state[_I_MINRBAR] = 1e100 if dtype == torch.float64 else 1e10 - state[_I_NORMB] = normb - state[_I_ZETA] = 0.0 # initial zeta (no previous iteration) - - ctol = 1.0 / conlim if conlim > 0 else 0.0 - consts = torch.tensor([atol, btol, ctol], device=device, dtype=dtype) - - # Early exit check - normar_init = (alpha * beta).item() - if normar_init == 0.0: - return x, 0, 0, beta.item(), 0.0, alpha.item(), 1.0, 0.0 - if normb.item() == 0.0: - x.zero_() - return x, 0, 0, beta.item(), 0.0, alpha.item(), 1.0, 0.0 - - h = v.clone() - hbar = torch.zeros(n, device=device, dtype=dtype) - - # --- Main iteration loop --- - itn = 0 - istop = 0 - - while itn < maxiter: - itn += 1 - - # Phase 1: Sparse matvec (not 
compilable) - u = _matvec(A, v) - state[_I_ALPHA] * u - beta_new = torch.linalg.norm(u) - u = u * torch.where( - beta_new > 0, - 1.0 / torch.clamp(beta_new, min=1e-30), - beta_new * 0.0, - ) - - v = _rmatvec(A, u) - beta_new * v - alpha_new = torch.linalg.norm(v) - v = v * torch.where( - alpha_new > 0, - 1.0 / torch.clamp(alpha_new, min=1e-30), - alpha_new * 0.0, - ) - - # Update beta/alpha in state for the scalar step - state[_I_BETA] = beta_new - state[_I_ALPHA] = alpha_new - - # Phase 2: Compiled scalar step (single GPU kernel on MPS/CUDA) - out = step_fn(state, consts) - - # Phase 3: Vector updates using scalar results from compiled step - thetanew = out[_O_THETANEW] - thetabar = out[_O_THETABAR] - zeta = out[_O_ZETA] - rho_new = out[_I_RHO] - rhobar_new = out[_I_RHOBAR] - rhoold = out[_O_RHOOLD] - rhobarold = out[_O_RHOBAROLD] - - hbar = h + hbar * (-(thetabar * rho_new) / (rhoold * rhobarold)) - x = x + (zeta / (rho_new * rhobar_new)) * hbar - h = v + h * (-(thetanew / rho_new)) - - # Propagate state for next iteration - state = out[:_STATE_SIZE] - - # Convergence check — single .item() sync per iteration. - # After compilation the scalar step is one kernel, so this sync - # (reading a boolean from GPU memory) is negligible. - # - # The compiled _scalar_step checks test2 (atol) and test3 (ctol). - # The btol-based test1 depends on normx (a vector quantity) and is - # checked here in Python since it can't be part of the compiled kernel. 
- converged_scalar = out[_O_CONVERGED].item() > 0.5 - - # btol / test1 convergence (requires normx from vector x) - normr_val = out[_O_NORMR].item() - normA_val = out[_O_NORMA].item() - normx_val = torch.linalg.norm(x).item() - normb_val = out[_I_NORMB].item() - - test1 = normr_val / max(normb_val, 1e-30) - t1 = test1 / (1.0 + normA_val * normx_val / max(normb_val, 1e-30)) - rtol = btol + atol * normA_val * normx_val / max(normb_val, 1e-30) - converged_btol = test1 <= rtol - converged_btol_machine = 1.0 + t1 <= 1.0 - - if converged_scalar or converged_btol or converged_btol_machine: - normar_val = out[_O_NORMAR].item() - condA_val = out[_O_CONDA].item() - - test2 = normar_val / max(normA_val * normr_val, 1e-30) - test3 = 1.0 / condA_val - - # Priority order matches SciPy LSMR (highest istop wins) - if 1.0 + test3 <= 1.0: - istop = 6 - if 1.0 + test2 <= 1.0: - istop = 5 - if converged_btol_machine: - istop = 4 - if test3 <= ctol: - istop = 3 - if test2 <= atol: - istop = 2 - if converged_btol: - istop = 1 - break - - if itn >= maxiter and istop == 0: - istop = 7 - - # Handle case where loop never ran (maxiter=0 or similar) - if itn == 0: - return x, istop, 0, normb.item(), normar_init, alpha.item(), 1.0, 0.0 - - normx_val = torch.linalg.norm(x).item() - return ( - x, - istop, - itn, - out[_O_NORMR].item(), - out[_O_NORMAR].item(), - out[_O_NORMA].item(), - out[_O_CONDA].item(), - normx_val, - ) diff --git a/pyfixest/estimation/torch/lsmr_torch_fused.py b/pyfixest/estimation/torch/lsmr_torch_fused.py deleted file mode 100644 index 1017fb26f..000000000 --- a/pyfixest/estimation/torch/lsmr_torch_fused.py +++ /dev/null @@ -1,302 +0,0 @@ -""" -On-device LSMR: all scalar state as 0-d tensors to eliminate CPU-GPU sync. 
- -Compared to lsmr_torch.py: -- Branchless Givens via torch.hypot (no if/elif, no Python math) -- All scalar state as 0-d device tensors (no .item() in hot loop) -- Convergence check via logical indexing -- Single sync point every `check_every` iterations - -Designed for CUDA/MPS where CPU-GPU synchronization is expensive. - -Reference: - D. C.-L. Fong and M. A. Saunders, - "LSMR: An iterative algorithm for sparse least-squares problems", - SIAM J. Sci. Comput., vol. 33, pp. 2950-2971, 2011. -""" - -from __future__ import annotations - -import torch - -# --------------------------------------------------------------------------- -# Branchless Givens rotation -# --------------------------------------------------------------------------- - - -def _sym_ortho_t( - a: torch.Tensor, b: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Branchless Givens rotation for 0-d tensors on device. - - Equivalent to SciPy's ``_sym_ortho`` but implemented with - ``torch.hypot`` (overflow-safe) and ``torch.where`` (branchless). - No CPU-GPU synchronization occurs. 
- """ - r = torch.hypot(a, b) - # Guard division: when r == 0, return (0, 0, 0) - safe_r = torch.where(r == 0, torch.ones_like(r), r) - c = torch.where(r == 0, torch.zeros_like(a), a / safe_r) - s = torch.where(r == 0, torch.zeros_like(b), b / safe_r) - return c, s, r - - -# --------------------------------------------------------------------------- -# Sparse matvec helpers -# --------------------------------------------------------------------------- - - -def _matvec(A, v: torch.Tensor) -> torch.Tensor: - if isinstance(A, torch.Tensor): - return A @ v - return A.mv(v) - - -def _rmatvec(A, u: torch.Tensor) -> torch.Tensor: - if isinstance(A, torch.Tensor): - return A.t() @ u - return A.t().mv(u) - - -# --------------------------------------------------------------------------- -# On-device LSMR -# --------------------------------------------------------------------------- - - -def lsmr_torch_fused( - A, - b: torch.Tensor, - damp: float = 0.0, - atol: float = 1e-8, - btol: float = 1e-8, - conlim: float = 1e8, - maxiter: int | None = None, - check_every: int = 10, -) -> tuple[torch.Tensor, int, int, float, float, float, float, float]: - """ - LSMR iterative solver with minimal CPU-GPU synchronization. - - Same algorithm as :func:`lsmr_torch`, but keeps **all** scalar state - as 0-d tensors on the compute device. CPU-GPU sync only happens once - every *check_every* iterations (for the convergence check), reducing - pipeline stalls from 3/iteration to 1/N. - - Parameters - ---------- - A : torch.Tensor - Sparse CSR tensor of shape (m, n). - b : torch.Tensor - Dense vector of shape (m,). - damp : float - Damping factor for regularized least-squares. - atol, btol : float - Stopping tolerances (see SciPy LSMR docs). - conlim : float - Condition number limit. - maxiter : int or None - Maximum iterations. Defaults to min(m, n). - check_every : int - How often to sync to CPU for convergence check (default: 10). 
- Higher values reduce syncs but may overshoot convergence by up to - ``check_every - 1`` iterations. - - Returns - ------- - x, istop, itn, normr, normar, normA, condA, normx - Same signature as :func:`lsmr_torch`. - """ - m, n = A.shape - device = b.device - dtype = b.dtype - - if maxiter is None: - maxiter = min(m, n) - - # Scalar constant factory - def _s(val: float) -> torch.Tensor: - return torch.tensor(val, device=device, dtype=dtype) - - _TINY = _s(1e-30) - - # --- Initialize Golub-Kahan bidiagonalization --- - u = b.clone() - normb = torch.linalg.norm(b) - - x = torch.zeros(n, device=device, dtype=dtype) - beta = normb.clone() - - # Branchless safe-normalize: u /= beta (or zero if beta == 0) - u = u * torch.where(beta > 0, 1.0 / torch.clamp(beta, min=1e-30), _s(0.0)) - - v = _rmatvec(A, u) - alpha = torch.linalg.norm(v) - v = v * torch.where(alpha > 0, 1.0 / torch.clamp(alpha, min=1e-30), _s(0.0)) - - # --- Scalar state (all 0-d device tensors) --- - zetabar = alpha * beta - alphabar = alpha.clone() - rho = _s(1.0) - rhobar = _s(1.0) - cbar = _s(1.0) - sbar = _s(0.0) - - h = v.clone() - hbar = torch.zeros(n, device=device, dtype=dtype) - - # ||r|| estimation - betadd = beta.clone() - betad = _s(0.0) - rhodold = _s(1.0) - tautildeold = _s(0.0) - thetatilde = _s(0.0) - zeta = _s(0.0) - d = _s(0.0) - - # ||A|| and cond(A) estimation - normA2 = alpha * alpha - maxrbar = _s(0.0) - minrbar = _s(1e100) - normA = torch.sqrt(normA2) - condA = _s(1.0) - normx_est = _s(0.0) - - # Stopping - ctol = _s(1.0 / conlim if conlim > 0 else 0.0) - normr = beta.clone() - normar = alpha * beta - - # Pre-create tolerance tensors - atol_t = _s(atol) - btol_t = _s(btol) - damp_t = _s(damp) - - # Early exit (syncs once at init — unavoidable) - if normar.item() == 0.0: - return x, 0, 0, normr.item(), normar.item(), normA.item(), condA.item(), 0.0 - if normb.item() == 0.0: - x.zero_() - return x, 0, 0, normr.item(), normar.item(), normA.item(), condA.item(), 0.0 - - # --- Main 
iteration loop (zero syncs inside, except periodic check) --- - itn = 0 - istop = 0 - - while itn < maxiter: - itn += 1 - - # Bidiagonalization step - u = _matvec(A, v) - alpha * u - beta = torch.linalg.norm(u) - u = u * torch.where(beta > 0, 1.0 / torch.clamp(beta, min=1e-30), _s(0.0)) - - v = _rmatvec(A, u) - beta * v - alpha = torch.linalg.norm(v) - v = v * torch.where(alpha > 0, 1.0 / torch.clamp(alpha, min=1e-30), _s(0.0)) - - # Givens rotations (branchless, on device) - chat, shat, alphahat = _sym_ortho_t(alphabar, damp_t) - - rhoold = rho - c, s, rho = _sym_ortho_t(alphahat, beta) - thetanew = s * alpha - alphabar = c * alpha - - rhobarold = rhobar - zetaold = zeta - thetabar = sbar * rho - rhotemp = cbar * rho - cbar, sbar, rhobar = _sym_ortho_t(cbar * rho, thetanew) - zeta = cbar * zetabar - zetabar = -sbar * zetabar - - # Vector updates (on device) - hbar = h + hbar * (-(thetabar * rho) / (rhoold * rhobarold)) - x = x + (zeta / (rho * rhobar)) * hbar - h = v + h * (-(thetanew / rho)) - - # ||r|| estimation - betaacute = chat * betadd - betacheck = -shat * betadd - betahat = c * betaacute - betadd = -s * betaacute - - thetatildeold = thetatilde - ctildeold, stildeold, rhotildeold = _sym_ortho_t(rhodold, thetabar) - thetatilde = stildeold * rhobar - rhodold = ctildeold * rhobar - betad = -stildeold * betad + ctildeold * betahat - - tautildeold = (zetaold - thetatildeold * tautildeold) / rhotildeold - taud = (zeta - thetatilde * tautildeold) / rhodold - d = d + betacheck * betacheck - normr = torch.sqrt(d + (betad - taud) ** 2 + betadd * betadd) - - # ||A|| estimation - normA2 = normA2 + beta * beta - normA = torch.sqrt(normA2) - normA2 = normA2 + alpha * alpha - - # cond(A) estimation - maxrbar = torch.maximum(maxrbar, rhobarold) - if itn > 1: - minrbar = torch.minimum(minrbar, rhobarold) - condA = torch.maximum(maxrbar, rhotemp) / torch.minimum(minrbar, rhotemp) - - # Convergence check (all logical ops on device, no sync) - normar = torch.abs(zetabar) - 
normx_est = torch.linalg.norm(x) - - test1 = normr / normb - test2 = normar / torch.clamp(normA * normr, min=1e-30) - test3 = 1.0 / condA - t1 = test1 / (1.0 + normA * normx_est / normb) - rtol = btol_t + atol_t * normA * normx_est / normb - - converged = ( - (test1 <= rtol) - | (test2 <= atol_t) - | (test3 <= ctol) - | (1.0 + t1 <= 1.0) - | (1.0 + test2 <= 1.0) - | (1.0 + test3 <= 1.0) - ) - - # --- Periodic sync: single .item() every check_every iterations --- - if itn % check_every == 0 or itn >= maxiter: - if converged.item(): - # Determine exact istop code (one-time sync at exit) - _test1 = test1.item() - _test2 = test2.item() - _test3 = test3.item() - _t1 = t1.item() - _rtol = rtol.item() - _ctol = ctol.item() - - if 1.0 + _test3 <= 1.0: - istop = 6 - elif 1.0 + _test2 <= 1.0: - istop = 5 - elif 1.0 + _t1 <= 1.0: - istop = 4 - elif _test3 <= _ctol: - istop = 3 - elif _test2 <= atol: - istop = 2 - elif _test1 <= _rtol: - istop = 1 - break - - if itn >= maxiter and istop == 0: - istop = 7 - - return ( - x, - istop, - itn, - normr.item(), - normar.item(), - normA.item(), - condA.item(), - normx_est.item(), - ) diff --git a/tests/test_lsmr_compiled.py b/tests/test_lsmr_compiled.py index 078f5f907..7651e4c19 100644 --- a/tests/test_lsmr_compiled.py +++ b/tests/test_lsmr_compiled.py @@ -14,8 +14,8 @@ import torch # Reference: original scalar-state LSMR (for correctness comparison) -from pyfixest.estimation.torch.lsmr_torch import _lsmr_scalar as lsmr_torch_original -from pyfixest.estimation.torch.lsmr_torch_compiled import lsmr_torch +from pyfixest.estimation.torch.lsmr_torch import _lsmr_eager as lsmr_torch_original +from pyfixest.estimation.torch.lsmr_torch import lsmr_torch # --------------------------------------------------------------------------- # Helpers @@ -136,7 +136,7 @@ def test_correctness_mps(): x_ref, *_ = lsmr_torch_original(A_cpu, b_cpu) - # MPS f32 (auto: use_compile=True) + # MPS f32 (auto: use_compile=False; pass use_compile=True to force) 
A_mps = A_cpu.to(torch.float32).to_dense().to("mps") b_mps = b_cpu.to(torch.float32).to("mps") diff --git a/tests/test_lsmr_fused.py b/tests/test_lsmr_fused.py deleted file mode 100644 index 537996f5a..000000000 --- a/tests/test_lsmr_fused.py +++ /dev/null @@ -1,157 +0,0 @@ -""" -Correctness and timing tests for lsmr_torch_fused vs lsmr_torch. - -Usage: - KMP_DUPLICATE_LIB_OK=TRUE pixi run -e dev python -m pytest tests/test_lsmr_fused.py -v -s -""" - -from __future__ import annotations - -import time - -import numpy as np -import pytest -import torch - -from pyfixest.estimation.torch.lsmr_torch import lsmr_torch -from pyfixest.estimation.torch.lsmr_torch_fused import lsmr_torch_fused - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _make_sparse_problem(m: int, n: int, density: float = 0.01, seed: int = 42): - """Create a sparse CSR system A and dense rhs b on the given device.""" - rng = np.random.default_rng(seed) - nnz = int(m * n * density) - rows = rng.integers(0, m, nnz) - cols = rng.integers(0, n, nnz) - vals = rng.standard_normal(nnz) - - A_coo = torch.sparse_coo_tensor( - torch.tensor(np.stack([rows, cols])), - torch.tensor(vals, dtype=torch.float64), - size=(m, n), - ) - A_csr = A_coo.to_sparse_csr() - b = torch.tensor(rng.standard_normal(m), dtype=torch.float64) - return A_csr, b - - -# --------------------------------------------------------------------------- -# Correctness tests -# --------------------------------------------------------------------------- - - -@pytest.mark.parametrize("m,n", [(200, 100), (500, 300), (1000, 500)]) -def test_fused_matches_original(m, n): - """The fused solver should produce the same solution as the original.""" - A, b = _make_sparse_problem(m, n) - - x_orig, istop_orig, itn_orig, *_ = lsmr_torch(A, b) - x_fused, istop_fused, itn_fused, *_ = lsmr_torch_fused(A, b, check_every=1) - - # 
Solutions should match closely - assert torch.allclose(x_orig, x_fused, atol=1e-6, rtol=1e-6), ( - f"Solutions differ: max_diff={torch.max(torch.abs(x_orig - x_fused)).item():.2e}" - ) - # Iteration count may differ slightly due to convergence check frequency - assert abs(itn_orig - itn_fused) <= 1, f"itn differs: {itn_orig} vs {itn_fused}" - - -@pytest.mark.parametrize("check_every", [1, 5, 10, 50]) -def test_check_every_correctness(check_every): - """Different check_every values should all converge to the same solution.""" - A, b = _make_sparse_problem(500, 300) - x_ref, *_ = lsmr_torch(A, b) - x_fused, _, itn, *_ = lsmr_torch_fused(A, b, check_every=check_every) - - assert torch.allclose(x_ref, x_fused, atol=1e-6, rtol=1e-6), ( - f"check_every={check_every}: max_diff=" - f"{torch.max(torch.abs(x_ref - x_fused)).item():.2e}" - ) - - -def test_zero_rhs(): - """B = 0 should return x = 0.""" - A, _ = _make_sparse_problem(100, 50) - b = torch.zeros(100, dtype=torch.float64) - x, istop, itn, *_ = lsmr_torch_fused(A, b) - assert torch.all(x == 0) - assert itn == 0 - - -def test_damping(): - """Damped solve should differ from undamped.""" - A, b = _make_sparse_problem(200, 100) - x_undamped, *_ = lsmr_torch_fused(A, b, damp=0.0) - x_damped, *_ = lsmr_torch_fused(A, b, damp=1.0) - assert not torch.allclose(x_undamped, x_damped, atol=1e-3) - - -# --------------------------------------------------------------------------- -# Branchless _sym_ortho tests -# --------------------------------------------------------------------------- - - -def test_sym_ortho_matches_scipy(): - """Branchless _sym_ortho_t should match SciPy's convention.""" - from pyfixest.estimation.torch.lsmr_torch import _sym_ortho - from pyfixest.estimation.torch.lsmr_torch_fused import _sym_ortho_t - - cases = [ - (3.0, 4.0), - (-3.0, 4.0), - (3.0, -4.0), - (-3.0, -4.0), - (0.0, 5.0), - (5.0, 0.0), - (0.0, 0.0), - (1e-300, 1e-300), - (1e300, 1e300), - (1.0, 1e-15), - ] - for a_val, b_val in cases: - c_ref, 
s_ref, r_ref = _sym_ortho(a_val, b_val) - a_t = torch.tensor(a_val, dtype=torch.float64) - b_t = torch.tensor(b_val, dtype=torch.float64) - c_t, s_t, r_t = _sym_ortho_t(a_t, b_t) - assert abs(c_t.item() - c_ref) < 1e-10, f"c mismatch for ({a_val}, {b_val})" - assert abs(s_t.item() - s_ref) < 1e-10, f"s mismatch for ({a_val}, {b_val})" - assert abs(r_t.item() - r_ref) < 1e-10, f"r mismatch for ({a_val}, {b_val})" - - -# --------------------------------------------------------------------------- -# Timing benchmark (not a test — run with -s to see output) -# --------------------------------------------------------------------------- - - -@pytest.mark.parametrize("m,n", [(5000, 2000), (10000, 5000)]) -def test_timing_comparison(m, n): - """Compare wall time of original vs fused LSMR.""" - A, b = _make_sparse_problem(m, n, density=0.005) - - # Warmup - lsmr_torch(A, b, maxiter=5) - lsmr_torch_fused(A, b, maxiter=5) - - # Original - t0 = time.perf_counter() - x_orig, _, itn_orig, *_ = lsmr_torch(A, b) - t_orig = time.perf_counter() - t0 - - # Fused (check every 10) - t0 = time.perf_counter() - x_fused, _, itn_fused, *_ = lsmr_torch_fused(A, b, check_every=10) - t_fused = time.perf_counter() - t0 - - speedup = t_orig / t_fused if t_fused > 0 else float("inf") - print( - f"\n [{m}x{n}] original: {t_orig:.3f}s ({itn_orig} iters) | " - f"fused: {t_fused:.3f}s ({itn_fused} iters) | " - f"speedup: {speedup:.2f}x" - ) - - # Correctness sanity check - assert torch.allclose(x_orig, x_fused, atol=1e-5, rtol=1e-5) From d9be2ca21147f9d2a3951f352a3b0c3685f9cec0 Mon Sep 17 00:00:00 2001 From: Jan Date: Fri, 6 Mar 2026 19:35:00 +0100 Subject: [PATCH 09/16] ruff --- tests/test_lsmr_compiled.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_lsmr_compiled.py b/tests/test_lsmr_compiled.py index 7651e4c19..da348cd2f 100644 --- a/tests/test_lsmr_compiled.py +++ b/tests/test_lsmr_compiled.py @@ -177,7 +177,7 @@ def test_timing_mps(m, n): vals = 
rng.standard_normal(nnz).astype(np.float32) A = torch.zeros(m, n, dtype=torch.float32) - for r, c, v in zip(rows, cols, vals): + for r, c, v in zip(rows, cols, vals, strict=True): A[r, c] += v A = A.to("mps") b = torch.tensor(rng.standard_normal(m).astype(np.float32), device="mps") From 07ad5b9724436057078eeab105056f578ddb9956 Mon Sep 17 00:00:00 2001 From: Jan Date: Fri, 6 Mar 2026 20:32:41 +0100 Subject: [PATCH 10/16] minor fixes. --- pyfixest/estimation/torch/lsmr_torch.py | 5 ++++- tests/test_lsmr_compiled.py | 19 +------------------ 2 files changed, 5 insertions(+), 19 deletions(-) diff --git a/pyfixest/estimation/torch/lsmr_torch.py b/pyfixest/estimation/torch/lsmr_torch.py index b4ef26da3..66d4d26b3 100644 --- a/pyfixest/estimation/torch/lsmr_torch.py +++ b/pyfixest/estimation/torch/lsmr_torch.py @@ -755,7 +755,10 @@ def _lsmr_compiled( if itn == 0: return x, istop, 0, normb.item(), normar_init, alpha.item(), 1.0, 0.0 - normx_val = torch.linalg.norm(x).item() + # normx_val was already computed inside the convergence block (line 725); + # only recompute if the loop exhausted maxiter without converging. 
+ if istop == 0 or istop == 7: + normx_val = torch.linalg.norm(x).item() return ( x, istop, diff --git a/tests/test_lsmr_compiled.py b/tests/test_lsmr_compiled.py index da348cd2f..e785bf114 100644 --- a/tests/test_lsmr_compiled.py +++ b/tests/test_lsmr_compiled.py @@ -59,23 +59,6 @@ def test_matches_original_cpu(m, n): assert itn_orig == itn_new, f"itn differs: {itn_orig} vs {itn_new}" -def test_zero_rhs(): - """B = 0 should return x = 0.""" - A, _ = _make_sparse_problem(100, 50) - b = torch.zeros(100, dtype=torch.float64) - x, _istop, itn, *_ = lsmr_torch(A, b) - assert torch.all(x == 0) - assert itn == 0 - - -def test_damping(): - """Damped solve should differ from undamped.""" - A, b = _make_sparse_problem(200, 100) - x_undamped, *_ = lsmr_torch(A, b, damp=0.0) - x_damped, *_ = lsmr_torch(A, b, damp=1.0) - assert not torch.allclose(x_undamped, x_damped, atol=1e-3) - - def test_diagnostics_match_original(): """normr, normar, normA, condA, normx diagnostics match reference.""" A, b = _make_sparse_problem(500, 300) @@ -125,7 +108,7 @@ def test_auto_cpu_defaults(): # MPS + torch.compile tests # --------------------------------------------------------------------------- -HAS_MPS = torch.backends.mps.is_available() +HAS_MPS = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() @pytest.mark.skipif(not HAS_MPS, reason="MPS not available") From d9ed71c9f5a8ce8909c7a8923fa499237c6e66c5 Mon Sep 17 00:00:00 2001 From: Jan Date: Fri, 6 Mar 2026 22:25:20 +0100 Subject: [PATCH 11/16] wip: batched LSMR --- pyfixest/estimation/torch/demean_torch_.py | 44 ++- pyfixest/estimation/torch/lsmr_torch.py | 418 +++++++++++++++++++++ tests/test_batched_lsmr.py | 374 ++++++++++++++++++ 3 files changed, 826 insertions(+), 10 deletions(-) create mode 100644 tests/test_batched_lsmr.py diff --git a/pyfixest/estimation/torch/demean_torch_.py b/pyfixest/estimation/torch/demean_torch_.py index 8a9e51d4f..908d22436 100644 --- a/pyfixest/estimation/torch/demean_torch_.py +++ 
b/pyfixest/estimation/torch/demean_torch_.py @@ -17,7 +17,7 @@ import torch from numpy.typing import NDArray -from pyfixest.estimation.torch.lsmr_torch import lsmr_torch +from pyfixest.estimation.torch.lsmr_torch import lsmr_torch, lsmr_torch_batched def _get_device(dtype: torch.dtype = torch.float64) -> torch.device: @@ -248,23 +248,38 @@ def _demean_torch_on_device_impl( # Build preconditioned operator once (not per column) A_precond = _PreconditionedSparse(D_weighted, M_inv) - # Solve for each column + # Solve for each column — batched SpMM for K >= threshold, sequential otherwise. + # Batched SpMM amortizes sparse-matrix load across K columns. Benchmarked + # breakeven: ~K=5 on MPS (Metal), may differ on CUDA (cuSPARSE is faster). + _BATCHED_K_THRESHOLD = 5 theta = torch.zeros(D_cols, K, dtype=dtype, device=device) - success = True - for k in range(K): - z, istop, _itn, _normr, _normar, _normA, _condA, _normx = lsmr_torch( + if K >= _BATCHED_K_THRESHOLD: + # Batched: single call, K columns solved simultaneously via SpMM + Z, istop_vec, _itn, *_ = lsmr_torch_batched( A_precond, - x_w[:, k], + x_w, damp=0.0, atol=tol, btol=tol, maxiter=maxiter, ) - - # Recover theta from preconditioned solution: theta = M_inv * z - theta[:, k] = M_inv * z - success = success and (istop in (1, 2, 3)) + theta = M_inv.unsqueeze(1) * Z + success = ((istop_vec >= 1) & (istop_vec <= 3)).all().item() + else: + # Sequential: K < threshold, per-column single-RHS path is faster + success = True + for k in range(K): + z, istop, _itn, _normr, _normar, _normA, _condA, _normx = lsmr_torch( + A_precond, + x_w[:, k], + damp=0.0, + atol=tol, + btol=tol, + maxiter=maxiter, + ) + theta[:, k] = M_inv * z + success = success and (istop in (1, 2, 3)) # Compute residuals: x_demeaned = x - D_unweighted @ theta x_demeaned = x_t - D_unweighted @ theta @@ -411,6 +426,15 @@ def mv(self, v: torch.Tensor) -> torch.Tensor: # Compute D @ (M_inv * v) return self._D @ (self._M_inv * v) + def mm(self, V: 
torch.Tensor) -> torch.Tensor: + """Batched matvec: A_precond @ V where V is (n, K) or (m, K). + + Same logic as mv() but broadcasts M_inv over K columns via unsqueeze. + """ + if self._transposed: + return self._M_inv.unsqueeze(1) * (self._D_t @ V) + return self._D @ (self._M_inv.unsqueeze(1) * V) + def t(self) -> _PreconditionedSparse: """Return cached transpose view.""" if self._T is None: diff --git a/pyfixest/estimation/torch/lsmr_torch.py b/pyfixest/estimation/torch/lsmr_torch.py index 66d4d26b3..9628476b6 100644 --- a/pyfixest/estimation/torch/lsmr_torch.py +++ b/pyfixest/estimation/torch/lsmr_torch.py @@ -89,6 +89,359 @@ def _sym_ortho(a: float, b: float) -> tuple[float, float, float]: return c, s, r +# --------------------------------------------------------------------------- +# Batched matvec helpers (SpMM: sparse @ dense matrix) +# --------------------------------------------------------------------------- + + +def _matvec_batched(A, V: torch.Tensor) -> torch.Tensor: + """A @ V where V is (n, K). SpMM for sparse A, mm() for wrappers.""" + if isinstance(A, torch.Tensor): + return A @ V + return A.mm(V) + + +def _rmatvec_batched(At, U: torch.Tensor) -> torch.Tensor: + """A^T @ U where U is (m, K). SpMM.""" + if isinstance(At, torch.Tensor): + return At @ U + return At.mm(U) + + +# --------------------------------------------------------------------------- +# Vectorized Givens rotation for (K,) tensors +# --------------------------------------------------------------------------- + + +def _sym_ortho_vec( + a: torch.Tensor, b: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Stable Givens rotation (SymOrtho) vectorized over K columns. + + Given (K,) tensors a and b, compute c, s, r such that for each k: + [ c_k s_k ] [ a_k ] = [ r_k ] + [-s_k c_k ] [ b_k ] [ 0 ] + + Uses torch.where for branchless execution on GPU. Division guards + use ones_like (not clamp) to preserve sign in dead lanes. 
+ """ + abs_a = torch.abs(a) + abs_b = torch.abs(b) + zero = torch.zeros_like(a) + one = torch.ones_like(a) + + # Safe divisors: replace zeros with ones to prevent NaN in dead lanes. + # The result of the dead-lane computation is discarded by torch.where. + safe_a = torch.where(a != 0, a, one) + safe_b = torch.where(b != 0, b, one) + + # Case 1: b == 0 + c_b0 = torch.where(a == 0, zero, torch.sign(a)) + s_b0 = zero + r_b0 = abs_a + + # Case 2: a == 0 + c_a0 = zero + s_a0 = torch.sign(b) + r_a0 = abs_b + + # Case 3: |b| > |a| (neither zero) + tau_3 = a / safe_b + s_3 = torch.sign(b) / torch.sqrt(one + tau_3 * tau_3) + safe_s_3 = torch.where(s_3 != 0, s_3, one) + c_3 = s_3 * tau_3 + r_3 = b / safe_s_3 + + # Case 4: |a| >= |b| (neither zero) + tau_4 = b / safe_a + c_4 = torch.sign(a) / torch.sqrt(one + tau_4 * tau_4) + safe_c_4 = torch.where(c_4 != 0, c_4, one) + s_4 = c_4 * tau_4 + r_4 = a / safe_c_4 + + # Select: b==0 → case1, a==0 → case2, |b|>|a| → case3, else → case4 + is_b0 = b == 0 + is_a0 = a == 0 + is_b_gt_a = abs_b > abs_a + + # Build from innermost to outermost condition + c = torch.where(is_b_gt_a, c_3, c_4) + s = torch.where(is_b_gt_a, s_3, s_4) + r = torch.where(is_b_gt_a, r_3, r_4) + + c = torch.where(is_a0, c_a0, c) + s = torch.where(is_a0, s_a0, s) + r = torch.where(is_a0, r_a0, r) + + c = torch.where(is_b0, c_b0, c) + s = torch.where(is_b0, s_b0, s) + r = torch.where(is_b0, r_b0, r) + + return c, s, r + + +# =========================================================================== +# Implementation 0: batched LSMR — K right-hand sides via SpMM +# =========================================================================== + + +def _lsmr_batched( + A, + B: torch.Tensor, + damp: float = 0.0, + atol: float = 1e-8, + btol: float = 1e-8, + conlim: float = 1e8, + maxiter: int | None = None, +) -> tuple[ + torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor, torch.Tensor, +]: + """ + Batched LSMR: solve min ||B - A 
def _lsmr_batched(
    A,
    B: torch.Tensor,
    damp: float = 0.0,
    atol: float = 1e-8,
    btol: float = 1e-8,
    conlim: float = 1e8,
    maxiter: int | None = None,
) -> tuple[
    torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor,
    torch.Tensor, torch.Tensor, torch.Tensor,
]:
    """
    Batched LSMR: solve min ||B - A X||_F for K RHS simultaneously.

    Replaces K sequential SpMV with SpMM for GPU throughput.
    All K columns run in lock-step; converged columns do harmless work.

    Fix vs. the first draft: ``normA``, ``condA`` and ``normx`` are now
    bound *before* the main loop. Previously they were assigned only
    inside the loop body, so calling with ``maxiter < 1`` raised
    ``NameError`` at the final ``return``.

    Parameters
    ----------
    A : sparse tensor or LinearOperator-like
        Matrix of shape (m, n).
    B : torch.Tensor
        Dense matrix of shape (m, K).
    damp : float
        Damping factor.
    atol, btol : float
        Stopping tolerances (same for all columns).
    conlim : float
        Condition number limit.
    maxiter : int or None
        Maximum iterations. Defaults to min(m, n).

    Returns
    -------
    X : (n, K) solution matrix
    istop : (K,) int tensor — per-column stopping reason
    itn : int — iterations used
    normr : (K,) — per-column ||b - Ax||
    normar : (K,) — per-column ||A^T(b - Ax)||
    normA : (K,) — per-column estimate of ||A||_F
    condA : (K,) — per-column estimate of cond(A)
    normx : (K,) — per-column ||x||
    """
    m, n = A.shape
    K = B.shape[1]
    device = B.device
    dtype = B.dtype

    if maxiter is None:
        maxiter = min(m, n)

    At = _precompute_transpose(A)

    # --- Initialize Golub-Kahan bidiagonalization ---
    U = B.clone()  # (m, K)
    normb = torch.linalg.norm(U, dim=0)  # (K,)

    X = torch.zeros(n, K, device=device, dtype=dtype)
    beta = normb.clone()  # (K,)

    # Normalize columns where beta > 0; exactly-zero columns stay zero.
    nonzero = beta > 0
    safe_beta = torch.where(nonzero, beta, torch.ones_like(beta))
    U = U / safe_beta.unsqueeze(0)
    U[:, ~nonzero] = 0.0

    V = _rmatvec_batched(At, U)  # (n, K) — SpMM
    alpha = torch.linalg.norm(V, dim=0)  # (K,)

    nonzero_a = alpha > 0
    safe_alpha = torch.where(nonzero_a, alpha, torch.ones_like(alpha))
    V = V / safe_alpha.unsqueeze(0)
    V[:, ~nonzero_a] = 0.0

    # --- Scalar state as (K,) tensors ---
    itn = 0
    zetabar = alpha * beta
    alphabar = alpha.clone()
    rho = torch.ones(K, device=device, dtype=dtype)
    rhobar = torch.ones(K, device=device, dtype=dtype)
    cbar = torch.ones(K, device=device, dtype=dtype)
    sbar = torch.zeros(K, device=device, dtype=dtype)

    H = V.clone()  # (n, K)
    Hbar = torch.zeros(n, K, device=device, dtype=dtype)

    # ||r|| estimation state
    betadd = beta.clone()
    betad = torch.zeros(K, device=device, dtype=dtype)
    rhodold = torch.ones(K, device=device, dtype=dtype)
    tautildeold = torch.zeros(K, device=device, dtype=dtype)
    thetatilde = torch.zeros(K, device=device, dtype=dtype)
    zeta = torch.zeros(K, device=device, dtype=dtype)
    d = torch.zeros(K, device=device, dtype=dtype)

    # ||A|| and cond(A) estimation
    normA2 = alpha * alpha
    maxrbar = torch.zeros(K, device=device, dtype=dtype)
    # float32 cannot represent 1e100 (overflows to inf) — use 1e10 there.
    minrbar_init = 1e100 if dtype == torch.float64 else 1e10
    minrbar = torch.full((K,), minrbar_init, device=device, dtype=dtype)

    # Convergence tracking
    istop = torch.zeros(K, device=device, dtype=torch.long)
    ctol = 1.0 / conlim if conlim > 0 else 0.0
    normr = beta.clone()
    normar = alpha * beta

    damp_vec = torch.full((K,), damp, device=device, dtype=dtype)

    # Bind the loop-carried diagnostics up front so the final `return`
    # (and the early exits below) are well-defined even when the loop
    # body never executes (e.g. maxiter < 1).
    normA = torch.sqrt(normA2)
    condA = torch.ones(K, device=device, dtype=dtype)
    normx = torch.zeros(K, device=device, dtype=dtype)

    # Early exit: X = 0 is already exact when A^T B == 0 or B == 0.
    if (normar == 0).all() or (normb == 0).all():
        return X, istop, itn, normr, normar, normA, condA, normx

    # --- Main iteration loop ---
    while itn < maxiter:
        itn += 1

        # Bidiagonalization step: SpMM replaces SpMV
        U = _matvec_batched(A, V) - alpha.unsqueeze(0) * U  # (m, K)
        beta = torch.linalg.norm(U, dim=0)  # (K,)

        nonzero_b = beta > 0
        safe_b = torch.where(nonzero_b, beta, torch.ones_like(beta))
        U = U / safe_b.unsqueeze(0)
        U[:, ~nonzero_b] = 0.0

        V = _rmatvec_batched(At, U) - beta.unsqueeze(0) * V  # (n, K)
        alpha = torch.linalg.norm(V, dim=0)  # (K,)

        nonzero_a = alpha > 0
        safe_a = torch.where(nonzero_a, alpha, torch.ones_like(alpha))
        V = V / safe_a.unsqueeze(0)
        V[:, ~nonzero_a] = 0.0

        # Givens rotation 1: fold the damping term into alphabar
        chat, shat, alphahat = _sym_ortho_vec(alphabar, damp_vec)

        # Givens rotation 2: (alphahat, beta)
        rhoold = rho
        c, s, rho = _sym_ortho_vec(alphahat, beta)
        thetanew = s * alpha
        alphabar = c * alpha

        # Givens rotation 3: rhobar update
        rhobarold = rhobar
        zetaold = zeta
        thetabar = sbar * rho
        rhotemp = cbar * rho
        cbar, sbar, rhobar = _sym_ortho_vec(rhotemp, thetanew)
        zeta = cbar * zetabar
        zetabar = -sbar * zetabar

        # Vector updates: broadcast (K,) scalars over (n, K) matrices.
        # Guard divisions: when a column has zero RHS, rho/rhobar can be 0.
        # Using clamp ensures 0/0 -> 0 instead of NaN (numerator is also 0).
        _eps = 1e-30
        hbar_coeff = -(thetabar * rho) / torch.clamp(rhoold * rhobarold, min=_eps)
        Hbar = H + Hbar * hbar_coeff.unsqueeze(0)
        x_coeff = zeta / torch.clamp(rho * rhobar, min=_eps)
        X = X + x_coeff.unsqueeze(0) * Hbar
        h_coeff = -(thetanew / torch.clamp(rho, min=_eps))
        H = V + H * h_coeff.unsqueeze(0)

        # ||r|| estimation (Fong & Saunders, Sec. 5)
        betaacute = chat * betadd
        betacheck = -shat * betadd
        betahat = c * betaacute
        betadd = -s * betaacute

        thetatildeold = thetatilde
        ctildeold, stildeold, rhotildeold = _sym_ortho_vec(rhodold, thetabar)
        thetatilde = stildeold * rhobar
        rhodold = ctildeold * rhobar
        betad = -stildeold * betad + ctildeold * betahat

        safe_rhotildeold = torch.clamp(rhotildeold, min=1e-30)
        tautildeold = (zetaold - thetatildeold * tautildeold) / safe_rhotildeold
        safe_rhodold = torch.clamp(rhodold, min=1e-30)
        taud = (zeta - thetatilde * tautildeold) / safe_rhodold
        d = d + betacheck * betacheck
        normr = torch.sqrt(d + (betad - taud) ** 2 + betadd * betadd)

        # ||A|| estimation: normA lags one alpha-term behind normA2 by design
        normA2 = normA2 + beta * beta
        normA = torch.sqrt(normA2)
        normA2 = normA2 + alpha * alpha

        # cond(A) estimation
        maxrbar = torch.maximum(maxrbar, rhobarold)
        if itn > 1:
            minrbar = torch.minimum(minrbar, rhobarold)
        condA = torch.maximum(maxrbar, rhotemp) / torch.clamp(
            torch.minimum(minrbar, rhotemp), min=1e-30
        )

        # Per-column convergence check
        normar = torch.abs(zetabar)
        normx = torch.linalg.norm(X, dim=0)  # (K,)

        safe_normb = torch.clamp(normb, min=1e-30)
        test1 = normr / safe_normb
        safe_normA_normr = torch.clamp(normA * normr, min=1e-30)
        test2 = normar / safe_normA_normr
        test3 = 1.0 / condA
        t1 = test1 / (1.0 + normA * normx / safe_normb)
        rtol = btol + atol * normA * normx / safe_normb

        # Determine stopping reason per column (only set if not already set).
        # Fill order 1..6 with a "only where still 0" mask reproduces SciPy's
        # priority (istop=1 strongest), since SciPy's later ifs overwrite.
        not_yet = istop == 0
        if not_yet.any():
            new_stop = torch.zeros(K, device=device, dtype=torch.long)
            new_stop = torch.where(test1 <= rtol, torch.ones_like(new_stop), new_stop)
            new_stop = torch.where(
                (test2 <= atol) & (new_stop == 0),
                2 * torch.ones_like(new_stop), new_stop,
            )
            new_stop = torch.where(
                (test3 <= ctol) & (new_stop == 0),
                3 * torch.ones_like(new_stop), new_stop,
            )
            new_stop = torch.where(
                (1.0 + t1 <= 1.0) & (new_stop == 0),
                4 * torch.ones_like(new_stop), new_stop,
            )
            new_stop = torch.where(
                (1.0 + test2 <= 1.0) & (new_stop == 0),
                5 * torch.ones_like(new_stop), new_stop,
            )
            new_stop = torch.where(
                (1.0 + test3 <= 1.0) & (new_stop == 0),
                6 * torch.ones_like(new_stop), new_stop,
            )
            istop = torch.where(not_yet, new_stop, istop)

        if (istop > 0).all():
            break

    # Mark columns that hit maxiter without converging
    istop = torch.where(
        (istop == 0) & (itn >= maxiter),
        7 * torch.ones_like(istop),
        istop,
    )

    return X, istop, itn, normr, normar, normA, condA, normx
def lsmr_torch_batched(
    A,
    B: torch.Tensor,
    damp: float = 0.0,
    atol: float = 1e-8,
    btol: float = 1e-8,
    conlim: float = 1e8,
    maxiter: int | None = None,
) -> tuple[
    torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor,
    torch.Tensor, torch.Tensor, torch.Tensor,
]:
    """
    Batched LSMR solver: solve ``min ||B - A X||_F`` for K right-hand
    sides at once.

    Each iteration performs one sparse matrix-matrix product (SpMM)
    instead of K sparse matrix-vector products, so the sparse matrix is
    loaded once per step and streamed through all K dense columns. On
    GPU this beats K sequential :func:`lsmr_torch` calls already for
    small K. Columns advance in lock-step: a column that has converged
    keeps doing harmless arithmetic until every column stops or
    ``maxiter`` is hit.

    Parameters
    ----------
    A : sparse tensor or LinearOperator-like
        Matrix of shape (m, n). Must support ``A @ V`` for dense V of
        shape (n, K); ``_PreconditionedSparse`` wrappers need ``mm()``.
    B : torch.Tensor
        Dense RHS matrix of shape (m, K).
    damp : float
        Damping factor for regularized least-squares.
    atol, btol : float
        Stopping tolerances, applied identically to every column.
    conlim : float
        Condition number limit.
    maxiter : int or None
        Maximum iterations; defaults to ``min(m, n)``.

    Returns
    -------
    X : torch.Tensor, shape (n, K)
        Solution matrix.
    istop : torch.Tensor, shape (K,), dtype long
        Per-column stopping reason (codes 0-7, as in ``lsmr_torch``).
    itn : int
        Iterations used (maximum over columns).
    normr, normar, normA, condA, normx : torch.Tensor, shape (K,)
        Per-column ``||b_k - A x_k||``, ``||A^T(b_k - A x_k)||``,
        Frobenius-norm estimate of A, condition-number estimate of A,
        and ``||x_k||``.
    """
    # Thin public wrapper — all work happens in the private implementation.
    return _lsmr_batched(
        A, B, damp=damp, atol=atol, btol=btol, conlim=conlim, maxiter=maxiter,
    )
+ """ + rng = np.random.default_rng(seed) + nnz = int(m * n * density) + rows = rng.integers(0, m, nnz) + cols = rng.integers(0, n, nnz) + vals = rng.standard_normal(nnz) + + A_coo = torch.sparse_coo_tensor( + torch.tensor(np.stack([rows, cols])), + torch.tensor(vals, dtype=torch.float64), + size=(m, n), + ) + A_csr = A_coo.to_sparse_csr() + return A_csr + + +def _make_rhs(m: int, K: int, seed: int = 42) -> torch.Tensor: + """Create a dense RHS matrix B of shape (m, K).""" + rng = np.random.default_rng(seed) + return torch.tensor(rng.standard_normal((m, K)), dtype=torch.float64) + + +# --------------------------------------------------------------------------- +# Core batched LSMR tests +# --------------------------------------------------------------------------- + + +class TestBatchedLSMR: + """Unit tests for lsmr_torch_batched.""" + + def test_matches_sequential(self): + """K=5: batched result should match K sequential lsmr_torch calls. + + Tolerance is 1e-6 because the batched path uses vectorized tensor + Givens rotations while sequential uses Python-float math. Accumulated + rounding differences over many iterations produce ~1e-7 divergence. 
+ """ + m, n, K = 200, 100, 5 + A = _make_sparse_problem(m, n) + B = _make_rhs(m, K, seed=123) + + # Batched solve + X_batch, istop_batch, itn_batch, *_ = lsmr_torch_batched(A, B) + + # Sequential solve + for k in range(K): + x_seq, istop_seq, itn_seq, *_ = lsmr_torch(A, B[:, k]) + assert torch.allclose(X_batch[:, k], x_seq, atol=1e-6, rtol=1e-6), ( + f"Column {k}: max diff = " + f"{torch.max(torch.abs(X_batch[:, k] - x_seq)).item():.2e}" + ) + + @pytest.mark.parametrize("K", [2, 10, 20]) + def test_matches_sequential_various_K(self, K): + """Batched matches sequential for various K values.""" + m, n = 300, 150 + A = _make_sparse_problem(m, n, seed=77) + B = _make_rhs(m, K, seed=88) + + X_batch, *_ = lsmr_torch_batched(A, B) + + for k in range(K): + x_seq, *_ = lsmr_torch(A, B[:, k]) + assert torch.allclose(X_batch[:, k], x_seq, atol=1e-6, rtol=1e-6), ( + f"K={K}, col {k}: max diff = " + f"{torch.max(torch.abs(X_batch[:, k] - x_seq)).item():.2e}" + ) + + def test_single_column(self): + """K=1 batched should match single-RHS lsmr_torch.""" + m, n = 200, 100 + A = _make_sparse_problem(m, n) + b = _make_rhs(m, 1, seed=42) + + X_batch, istop_batch, itn_batch, *_ = lsmr_torch_batched(A, b) + x_single, istop_single, itn_single, *_ = lsmr_torch(A, b[:, 0]) + + assert torch.allclose(X_batch[:, 0], x_single, atol=1e-6) + + def test_convergence_per_column(self): + """Columns with different difficulty should converge independently.""" + m, n = 200, 100 + A = _make_sparse_problem(m, n) + + rng = np.random.default_rng(999) + # Column 0: easy (small values), Column 1: harder (large values) + B = torch.zeros(m, 2, dtype=torch.float64) + B[:, 0] = torch.tensor(rng.standard_normal(m) * 0.01, dtype=torch.float64) + B[:, 1] = torch.tensor(rng.standard_normal(m) * 100.0, dtype=torch.float64) + + X, istop, itn, *_ = lsmr_torch_batched(A, B) + + # Both should converge (istop in 1-3) + assert (istop >= 1).all() and (istop <= 3).all(), ( + f"Not all columns converged: istop = {istop}" + ) 
+ # Solutions should match sequential (convergence correctness) + for k in range(2): + x_seq, *_ = lsmr_torch(A, B[:, k]) + assert torch.allclose(X[:, k], x_seq, atol=1e-6, rtol=1e-6), ( + f"Column {k}: batched vs sequential max diff = " + f"{torch.max(torch.abs(X[:, k] - x_seq)).item():.2e}" + ) + + def test_zero_rhs_column(self): + """B with a zero column should produce zero solution for that column.""" + m, n = 100, 50 + A = _make_sparse_problem(m, n) + B = _make_rhs(m, 3, seed=42) + B[:, 1] = 0.0 # Zero out middle column + + X, istop, *_ = lsmr_torch_batched(A, B) + + assert torch.allclose( + X[:, 1], torch.zeros(n, dtype=torch.float64), atol=1e-12 + ), f"Zero-RHS column has non-zero solution: ||x|| = {torch.norm(X[:, 1]).item()}" + + def test_all_zero_rhs(self): + """All-zero B should return all-zero X with istop=0.""" + m, n, K = 100, 50, 3 + A = _make_sparse_problem(m, n) + B = torch.zeros(m, K, dtype=torch.float64) + + X, istop, itn, *_ = lsmr_torch_batched(A, B) + + assert torch.allclose(X, torch.zeros(n, K, dtype=torch.float64), atol=1e-12) + assert itn == 0 + + def test_damp(self): + """Damped batched should match damped sequential.""" + m, n, K = 200, 100, 3 + A = _make_sparse_problem(m, n) + B = _make_rhs(m, K, seed=55) + damp = 5.0 + + X_batch, *_ = lsmr_torch_batched(A, B, damp=damp) + + for k in range(K): + x_seq, *_ = lsmr_torch(A, B[:, k], damp=damp) + assert torch.allclose(X_batch[:, k], x_seq, atol=1e-6, rtol=1e-6), ( + f"Damped col {k}: max diff = " + f"{torch.max(torch.abs(X_batch[:, k] - x_seq)).item():.2e}" + ) + + def test_overdetermined_known_solution(self): + """Overdetermined system with exact solution, K=3 columns.""" + A_dense = torch.tensor( + [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]], dtype=torch.float64 + ) + A_sparse = A_dense.to_sparse_csr() + + X_true = torch.tensor( + [[1.0, 2.0, -1.0], [-1.0, 0.5, 3.0]], dtype=torch.float64 + ) + B = A_dense @ X_true # (3, 3) + + X_sol, istop, *_ = lsmr_torch_batched(A_sparse, B) + + assert 
torch.allclose(X_sol, X_true, atol=1e-10), ( + f"Solution mismatch: max diff = " + f"{torch.max(torch.abs(X_sol - X_true)).item():.2e}" + ) + assert ((istop >= 1) & (istop <= 3)).all() + + def test_maxiter_exhaustion(self): + """Forcing maxiter=2 should return istop=7 for all columns.""" + m, n, K = 100, 50, 3 + A = _make_sparse_problem(m, n) + B = _make_rhs(m, K, seed=42) + + _, istop, itn, *_ = lsmr_torch_batched( + A, B, maxiter=2, atol=1e-15, btol=1e-15 + ) + + assert (istop == 7).all(), f"Expected istop=7, got {istop}" + assert itn == 2 + + def test_return_shapes(self): + """Verify all return values have correct shapes.""" + m, n, K = 200, 100, 5 + A = _make_sparse_problem(m, n) + B = _make_rhs(m, K) + + X, istop, itn, normr, normar, normA, condA, normx = lsmr_torch_batched(A, B) + + assert X.shape == (n, K) + assert istop.shape == (K,) + assert isinstance(itn, int) + assert normr.shape == (K,) + assert normar.shape == (K,) + assert normA.shape == (K,) + assert condA.shape == (K,) + assert normx.shape == (K,) + + +# --------------------------------------------------------------------------- +# PreconditionedSparse.mm() tests +# --------------------------------------------------------------------------- + + +class TestPreconditionedSparseMM: + """Tests for _PreconditionedSparse.mm() batched matvec.""" + + def test_mm_matches_mv_loop(self): + """mm(V) should equal column-wise mv(v_k) for each k.""" + rng = np.random.default_rng(42) + N, G, K = 200, 20, 5 + + flist = rng.choice(G, (N, 1)).astype(np.uint64) + D = _build_sparse_dummy(flist, torch.device("cpu"), torch.float64) + M_inv = 1.0 / torch.sqrt( + torch.tensor(np.bincount(flist[:, 0], minlength=G), dtype=torch.float64) + ) + + A = _PreconditionedSparse(D, M_inv) + V = torch.randn(G, K, dtype=torch.float64) + + # Batched + result_mm = A.mm(V) + + # Sequential + for k in range(K): + result_mv = A.mv(V[:, k]) + assert torch.allclose(result_mm[:, k], result_mv, atol=1e-12), ( + f"mm vs mv mismatch at column 
{k}" + ) + + def test_mm_transpose_matches_mv_transpose(self): + """A^T.mm(U) should match column-wise A^T.mv(u_k).""" + rng = np.random.default_rng(42) + N, G, K = 200, 20, 5 + + flist = rng.choice(G, (N, 1)).astype(np.uint64) + D = _build_sparse_dummy(flist, torch.device("cpu"), torch.float64) + M_inv = 1.0 / torch.sqrt( + torch.tensor(np.bincount(flist[:, 0], minlength=G), dtype=torch.float64) + ) + + At = _PreconditionedSparse(D, M_inv).t() + U = torch.randn(N, K, dtype=torch.float64) + + result_mm = At.mm(U) + + for k in range(K): + result_mv = At.mv(U[:, k]) + assert torch.allclose(result_mm[:, k], result_mv, atol=1e-12), ( + f"transpose mm vs mv mismatch at column {k}" + ) + + +# --------------------------------------------------------------------------- +# Integration: batched demeaning +# --------------------------------------------------------------------------- + + +class TestDemeanBatched: + """Verify batched demeaning path produces correct results.""" + + def test_batched_demean_matches_pyhdfe(self): + """Batched demean (K>1) should match pyhdfe reference.""" + rng = np.random.default_rng(929291) + N, K = 1000, 10 + x = rng.normal(0, 1, (N, K)) + flist = np.column_stack([ + rng.choice(10, N), + rng.choice(10, N), + ]).astype(np.uint64) + weights = np.ones(N) + + algorithm = pyhdfe.create(flist) + res_pyhdfe = algorithm.residualize(x) + + res_torch, success = demean_torch(x, flist, weights, tol=1e-10) + assert success, "Batched demean did not converge" + np.testing.assert_allclose( + res_torch, + res_pyhdfe, + rtol=1e-6, + atol=1e-8, + err_msg="Batched demean vs pyhdfe mismatch", + ) + + def test_batched_demean_weighted_matches_pyhdfe(self): + """Weighted batched demean should match pyhdfe.""" + rng = np.random.default_rng(929291) + N, K = 1000, 5 + x = rng.normal(0, 1, (N, K)) + flist = np.column_stack([ + rng.choice(10, N), + rng.choice(10, N), + ]).astype(np.uint64) + weights = rng.uniform(0.1, 2.0, N) + + algorithm = pyhdfe.create(flist) + 
res_pyhdfe = algorithm.residualize(x, weights.reshape(N, 1)) + + res_torch, success = demean_torch(x, flist, weights, tol=1e-10) + assert success, "Weighted batched demean did not converge" + np.testing.assert_allclose( + res_torch, + res_pyhdfe, + rtol=1e-6, + atol=1e-8, + err_msg="Weighted batched demean vs pyhdfe mismatch", + ) + + def test_single_column_still_works(self): + """K=1 demean should still work (uses sequential path).""" + rng = np.random.default_rng(42) + N = 500 + x = rng.normal(0, 1, (N, 1)) + flist = rng.choice(10, N).astype(np.uint64).reshape(N, 1) + weights = np.ones(N) + + algorithm = pyhdfe.create(flist) + res_pyhdfe = algorithm.residualize(x) + + res_torch, success = demean_torch(x, flist, weights, tol=1e-10) + assert success + np.testing.assert_allclose(res_torch, res_pyhdfe, rtol=1e-6, atol=1e-8) + + def test_many_columns(self): + """K=50 batched demean should match pyhdfe.""" + rng = np.random.default_rng(9999) + N, K = 2000, 50 + x = rng.normal(0, 1, (N, K)) + flist = rng.choice(20, (N, 2)).astype(np.uint64) + weights = np.ones(N) + + algorithm = pyhdfe.create(flist) + res_pyhdfe = algorithm.residualize(x) + + res_torch, success = demean_torch(x, flist, weights, tol=1e-10) + assert success + np.testing.assert_allclose(res_torch, res_pyhdfe, rtol=1e-5, atol=1e-7) From 80dae7365693df00e0ca467c293b8facae650e8f Mon Sep 17 00:00:00 2001 From: Jan Date: Fri, 6 Mar 2026 22:54:12 +0100 Subject: [PATCH 12/16] add compiled batched, add helpers --- pyfixest/estimation/torch/lsmr_torch.py | 410 +++++++++++++++++++----- tests/test_batched_lsmr.py | 165 ++++++++++ 2 files changed, 494 insertions(+), 81 deletions(-) diff --git a/pyfixest/estimation/torch/lsmr_torch.py b/pyfixest/estimation/torch/lsmr_torch.py index 9628476b6..e86d04e32 100644 --- a/pyfixest/estimation/torch/lsmr_torch.py +++ b/pyfixest/estimation/torch/lsmr_torch.py @@ -1,18 +1,27 @@ """ Pure PyTorch LSMR iterative solver with optional torch.compile kernel fusion. 
-Two implementations live in this file: +Four implementations live in this file: -1. ``_lsmr_eager`` — eager-mode PyTorch, Python-float Givens rotations. - Best for CPU and MPS (Metal command-buffer batching already amortizes - kernel-launch overhead). +0. ``_lsmr_batched`` — eager batched LSMR for K right-hand sides via SpMM. + Uses vectorized (K,) Givens rotations with ``_sym_ortho_vec``. -2. ``_lsmr_compiled`` — packs all scalar state into a 1-D tensor and runs +1. ``_lsmr_eager`` — eager single-RHS LSMR, Python-float Givens rotations. + Best for CPU. + +2. ``_lsmr_compiled`` — packs scalar state into a 1-D tensor and runs the Givens / norm / convergence work through a ``torch.compile``-d kernel. On CUDA this fuses ~60 per-iteration kernel launches into one. -The public entry point ``lsmr_torch()`` dispatches automatically: -CUDA → compiled, CPU/MPS → scalar. Pass ``use_compile=True`` to override. +3. ``_lsmr_compiled_batched`` — compiled batched LSMR for K RHS via SpMM. + Packs scalar state into a (_STATE_SIZE, K) tensor — the same compiled + ``_scalar_step`` serves both single-RHS and batched paths since all ops + are shape-agnostic. Fuses scalar work into one kernel per iteration. + +Public entry points: +- ``lsmr_torch()`` dispatches: CUDA → compiled, CPU/MPS → eager. +- ``lsmr_torch_batched()`` dispatches: CUDA → compiled batched, + CPU/MPS → eager batched. Pass ``use_compile=True/False`` to override. Reference: D. C.-L. Fong and M. A. Saunders, @@ -181,6 +190,121 @@ def _sym_ortho_vec( return c, s, r +# --------------------------------------------------------------------------- +# Shared batched helpers +# --------------------------------------------------------------------------- + + +def _safe_normalize_cols( + M: torch.Tensor, norms: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Divide each column of (m, K) matrix ``M`` by its (K,) ``norms``, + zeroing columns where ``norms == 0``. 
+ + Returns ``(M_normalized, norms)`` so the caller has the norms for later use. + """ + nonzero = norms > 0 + safe = torch.where(nonzero, norms, torch.ones_like(norms)) + M = M / safe.unsqueeze(0) + M[:, ~nonzero] = 0.0 + return M, norms + + +def _make_initial_state( + alpha: torch.Tensor, + beta: torch.Tensor, + normb: torch.Tensor, + damp: float, + dtype: torch.dtype, + device: torch.device, + *, + K: int | None = None, +) -> torch.Tensor: + """ + Pack the 20-element LSMR scalar state into a tensor. + + For single-RHS: ``K=None`` → shape ``(_STATE_SIZE,)``. + For batched: ``K=int`` → shape ``(_STATE_SIZE, K)``. + """ + shape = (_STATE_SIZE,) if K is None else (_STATE_SIZE, K) + state = torch.zeros(shape, device=device, dtype=dtype) + state[_I_ALPHABAR] = alpha + state[_I_DAMP] = damp + state[_I_BETA] = beta + state[_I_ALPHA] = alpha + state[_I_SBAR] = 0.0 + state[_I_CBAR] = 1.0 + state[_I_ZETABAR] = alpha * beta + state[_I_RHO] = 1.0 + state[_I_RHOBAR] = 1.0 + state[_I_RHODOLD] = 1.0 + state[_I_TAUTILDEOLD] = 0.0 + state[_I_THETATILDE] = 0.0 + state[_I_BETADD] = beta + state[_I_BETAD] = 0.0 + state[_I_D] = 0.0 + state[_I_NORMA2] = alpha * alpha + state[_I_MAXRBAR] = 0.0 + state[_I_MINRBAR] = 1e100 if dtype == torch.float64 else 1e10 + state[_I_NORMB] = normb + state[_I_ZETA] = 0.0 + return state + + +def _check_convergence_batched( + istop: torch.Tensor, + test1: torch.Tensor, + rtol: torch.Tensor, + test2: torch.Tensor, + test3: torch.Tensor, + t1: torch.Tensor, + atol: float, + ctol: float, + K: int, + device: torch.device, +) -> torch.Tensor: + """ + Per-column convergence check for batched LSMR. + + Sets istop per column using an only-set-once latch: once a column's + istop becomes non-zero, it is never overwritten. Returns updated istop. 
+ """ + not_yet = istop == 0 + new_stop = torch.zeros(K, device=device, dtype=torch.long) + new_stop = torch.where(test1 <= rtol, torch.ones_like(new_stop), new_stop) + new_stop = torch.where( + (test2 <= atol) & (new_stop == 0), + 2 * torch.ones_like(new_stop), new_stop, + ) + new_stop = torch.where( + (test3 <= ctol) & (new_stop == 0), + 3 * torch.ones_like(new_stop), new_stop, + ) + new_stop = torch.where( + (1.0 + t1 <= 1.0) & (new_stop == 0), + 4 * torch.ones_like(new_stop), new_stop, + ) + new_stop = torch.where( + (1.0 + test2 <= 1.0) & (new_stop == 0), + 5 * torch.ones_like(new_stop), new_stop, + ) + new_stop = torch.where( + (1.0 + test3 <= 1.0) & (new_stop == 0), + 6 * torch.ones_like(new_stop), new_stop, + ) + return torch.where(not_yet, new_stop, istop) + + +def _mark_maxiter_batched(istop: torch.Tensor, itn: int, maxiter: int) -> torch.Tensor: + """Set istop=7 for columns that did not converge before maxiter.""" + return torch.where( + (istop == 0) & (itn >= maxiter), + 7 * torch.ones_like(istop), + istop, + ) + + # =========================================================================== # Implementation 0: batched LSMR — K right-hand sides via SpMM # =========================================================================== @@ -247,19 +371,12 @@ def _lsmr_batched( X = torch.zeros(n, K, device=device, dtype=dtype) beta = normb.clone() # (K,) - # Normalize columns where beta > 0 - nonzero = beta > 0 - safe_beta = torch.where(nonzero, beta, torch.ones_like(beta)) - U = U / safe_beta.unsqueeze(0) - U[:, ~nonzero] = 0.0 + U, beta = _safe_normalize_cols(U, beta) V = _rmatvec_batched(At, U) # (n, K) — SpMM alpha = torch.linalg.norm(V, dim=0) # (K,) - nonzero_a = alpha > 0 - safe_alpha = torch.where(nonzero_a, alpha, torch.ones_like(alpha)) - V = V / safe_alpha.unsqueeze(0) - V[:, ~nonzero_a] = 0.0 + V, alpha = _safe_normalize_cols(V, alpha) # --- Scalar state as (K,) tensors --- itn = 0 @@ -315,19 +432,11 @@ def _lsmr_batched( # Bidiagonalization step: 
SpMM replaces SpMV U = _matvec_batched(A, V) - alpha.unsqueeze(0) * U # (m, K) beta = torch.linalg.norm(U, dim=0) # (K,) - - nonzero_b = beta > 0 - safe_b = torch.where(nonzero_b, beta, torch.ones_like(beta)) - U = U / safe_b.unsqueeze(0) - U[:, ~nonzero_b] = 0.0 + U, beta = _safe_normalize_cols(U, beta) V = _rmatvec_batched(At, U) - beta.unsqueeze(0) * V # (n, K) alpha = torch.linalg.norm(V, dim=0) # (K,) - - nonzero_a = alpha > 0 - safe_a = torch.where(nonzero_a, alpha, torch.ones_like(alpha)) - V = V / safe_a.unsqueeze(0) - V[:, ~nonzero_a] = 0.0 + V, alpha = _safe_normalize_cols(V, alpha) # Givens rotation 1: (alphabar, damp) chat, shat, alphahat = _sym_ortho_vec(alphabar, damp_vec) @@ -402,42 +511,13 @@ def _lsmr_batched( t1 = test1 / (1.0 + normA * normx / safe_normb) rtol = btol + atol * normA * normx / safe_normb - # Determine stopping reason per column (only set if not already set) - not_yet = istop == 0 - if not_yet.any(): - new_stop = torch.zeros(K, device=device, dtype=torch.long) - new_stop = torch.where(test1 <= rtol, torch.ones_like(new_stop), new_stop) - new_stop = torch.where( - (test2 <= atol) & (new_stop == 0), - 2 * torch.ones_like(new_stop), new_stop, - ) - new_stop = torch.where( - (test3 <= ctol) & (new_stop == 0), - 3 * torch.ones_like(new_stop), new_stop, - ) - new_stop = torch.where( - (1.0 + t1 <= 1.0) & (new_stop == 0), - 4 * torch.ones_like(new_stop), new_stop, - ) - new_stop = torch.where( - (1.0 + test2 <= 1.0) & (new_stop == 0), - 5 * torch.ones_like(new_stop), new_stop, - ) - new_stop = torch.where( - (1.0 + test3 <= 1.0) & (new_stop == 0), - 6 * torch.ones_like(new_stop), new_stop, - ) - istop = torch.where(not_yet, new_stop, istop) - + istop = _check_convergence_batched( + istop, test1, rtol, test2, test3, t1, atol, ctol, K, device, + ) if (istop > 0).all(): break - # Mark columns that hit maxiter - istop = torch.where( - (istop == 0) & (itn >= maxiter), - 7 * torch.ones_like(istop), - istop, - ) + istop = 
_mark_maxiter_batched(istop, itn, maxiter) return X, istop, itn, normr, normar, normA, condA, normx @@ -968,28 +1048,7 @@ def _lsmr_compiled( alpha = torch.linalg.norm(v) v = v * torch.where(alpha > 0, 1.0 / torch.clamp(alpha, min=1e-30), alpha * 0.0) - # --- Pack initial scalar state --- - state = torch.zeros(_STATE_SIZE, device=device, dtype=dtype) - state[_I_ALPHABAR] = alpha - state[_I_DAMP] = damp - state[_I_BETA] = beta - state[_I_ALPHA] = alpha - state[_I_SBAR] = 0.0 - state[_I_CBAR] = 1.0 - state[_I_ZETABAR] = alpha * beta - state[_I_RHO] = 1.0 - state[_I_RHOBAR] = 1.0 - state[_I_RHODOLD] = 1.0 - state[_I_TAUTILDEOLD] = 0.0 - state[_I_THETATILDE] = 0.0 - state[_I_BETADD] = beta - state[_I_BETAD] = 0.0 - state[_I_D] = 0.0 - state[_I_NORMA2] = alpha * alpha - state[_I_MAXRBAR] = 0.0 - state[_I_MINRBAR] = 1e100 if dtype == torch.float64 else 1e10 - state[_I_NORMB] = normb - state[_I_ZETA] = 0.0 # initial zeta (no previous iteration) + state = _make_initial_state(alpha, beta, normb, damp, dtype, device) ctol = 1.0 / conlim if conlim > 0 else 0.0 consts = torch.tensor([atol, btol, ctol], device=device, dtype=dtype) @@ -1124,6 +1183,177 @@ def _lsmr_compiled( ) +# =========================================================================== +# Implementation 3: compiled batched LSMR — K RHS via SpMM + torch.compile +# =========================================================================== + + +def _lsmr_compiled_batched( + A, + B: torch.Tensor, + damp: float = 0.0, + atol: float = 1e-8, + btol: float = 1e-8, + conlim: float = 1e8, + maxiter: int | None = None, + use_compile: bool = True, +) -> tuple[ + torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor, torch.Tensor, +]: + """ + Compiled batched LSMR: solve min ||B - A X||_F for K RHS simultaneously. + + Mirrors ``_lsmr_compiled`` but with (m, K) vectors and SpMM. 
The scalar + state is packed into a (_STATE_SIZE, K) tensor — each of the 20 scalar + quantities becomes a (K,) vector. ``_scalar_step`` is shape-agnostic: + its indexing and element-wise ops broadcast over the K dimension without + any code changes. + + Called by ``lsmr_torch_batched``; ``use_compile`` is already resolved. + """ + m, n = A.shape + K = B.shape[1] + device = B.device + dtype = B.dtype + + if maxiter is None: + maxiter = min(m, n) + + # Get compiled or uncompiled step function + step_fn = _get_compiled_step(device.type) if use_compile else _scalar_step + + At = _precompute_transpose(A) + + # --- Initialize Golub-Kahan bidiagonalization --- + U = B.clone() # (m, K) + normb = torch.linalg.norm(U, dim=0) # (K,) + + X = torch.zeros(n, K, device=device, dtype=dtype) + beta = normb.clone() # (K,) + + U, beta = _safe_normalize_cols(U, beta) + + V = _rmatvec_batched(At, U) # (n, K) — SpMM + alpha = torch.linalg.norm(V, dim=0) # (K,) + V, alpha = _safe_normalize_cols(V, alpha) + + state = _make_initial_state(alpha, beta, normb, damp, dtype, device, K=K) + + ctol = 1.0 / conlim if conlim > 0 else 0.0 + # consts (3,) — _scalar_step indexes as consts[0]/consts[2] which broadcast + # against (K,) state elements automatically. 
+ consts = torch.tensor([atol, btol, ctol], device=device, dtype=dtype) + + # Early exit check + normar_init = alpha * beta # (K,) + if (normar_init == 0).all(): + return (X, torch.zeros(K, device=device, dtype=torch.long), 0, + beta, torch.zeros(K, device=device, dtype=dtype), + alpha, torch.ones(K, device=device, dtype=dtype), + torch.zeros(K, device=device, dtype=dtype)) + if (normb == 0).all(): + X.zero_() + return (X, torch.zeros(K, device=device, dtype=torch.long), 0, + beta, torch.zeros(K, device=device, dtype=dtype), + alpha, torch.ones(K, device=device, dtype=dtype), + torch.zeros(K, device=device, dtype=dtype)) + + H = V.clone() # (n, K) + Hbar = torch.zeros(n, K, device=device, dtype=dtype) + + # Convergence tracking: per-column istop, only-set-once latch + istop = torch.zeros(K, device=device, dtype=torch.long) + + # --- Main iteration loop --- + itn = 0 + _eps = 1e-30 + normx_t = torch.zeros(K, device=device, dtype=dtype) + + while itn < maxiter: + itn += 1 + + # Phase 1: SpMM bidiagonalization (not compilable) + U = _matvec_batched(A, V) - state[_I_ALPHA].unsqueeze(0) * U # (m, K) + beta_new = torch.linalg.norm(U, dim=0) # (K,) + U, beta_new = _safe_normalize_cols(U, beta_new) + + V = _rmatvec_batched(At, U) - beta_new.unsqueeze(0) * V # (n, K) + alpha_new = torch.linalg.norm(V, dim=0) # (K,) + V, alpha_new = _safe_normalize_cols(V, alpha_new) + + # Update beta/alpha in state for the scalar step + state[_I_BETA] = beta_new + state[_I_ALPHA] = alpha_new + + # Phase 2: Compiled scalar step — (_STATE_SIZE, K) → (_OUTPUT_SIZE, K) + out = step_fn(state, consts) + + # Phase 3: Vector updates using scalar results from compiled step + thetanew = out[_O_THETANEW] # (K,) + thetabar = out[_O_THETABAR] # (K,) + zeta = out[_O_ZETA] # (K,) + rho_new = out[_I_RHO] # (K,) + rhobar_new = out[_I_RHOBAR] # (K,) + rhoold = out[_O_RHOOLD] # (K,) + rhobarold = out[_O_RHOBAROLD] # (K,) + + # Safe divisions: some columns may have zero RHS → zero denominators + hbar_coeff = 
-(thetabar * rho_new) / torch.clamp(rhoold * rhobarold, min=_eps) + Hbar = H + Hbar * hbar_coeff.unsqueeze(0) + x_coeff = zeta / torch.clamp(rho_new * rhobar_new, min=_eps) + X = X + x_coeff.unsqueeze(0) * Hbar + h_coeff = -(thetanew / torch.clamp(rho_new, min=_eps)) + H = V + H * h_coeff.unsqueeze(0) + + # Propagate state for next iteration + state = out[:_STATE_SIZE] + + # Phase 4: Per-column convergence — single .item() sync per iteration. + # No not_yet.any() guard: the torch.where inside _check_convergence_batched + # already protects converged columns, and the guard would add a second + # host-device sync that costs more than the fused tensor ops it skips. + normx_t = torch.linalg.norm(X, dim=0) # (K,) + normr_t = out[_O_NORMR] # (K,) + normA_t = out[_O_NORMA] # (K,) + normb_t = out[_I_NORMB] # (K,) + + safe_normb = torch.clamp(normb_t, min=_eps) + test1_t = normr_t / safe_normb + safe_normA_normr = torch.clamp(normA_t * normr_t, min=_eps) + test2_t = out[_O_NORMAR] / safe_normA_normr + test3_t = 1.0 / out[_O_CONDA] + t1_t = test1_t / (1.0 + normA_t * normx_t / safe_normb) + rtol_t = btol + atol * normA_t * normx_t / safe_normb + + istop = _check_convergence_batched( + istop, test1_t, rtol_t, test2_t, test3_t, t1_t, atol, ctol, K, device, + ) + + # Single .item() sync: check if all columns have converged + if (istop > 0).all().item(): + break + + istop = _mark_maxiter_batched(istop, itn, maxiter) + + # Handle case where loop never ran + if itn == 0: + return (X, istop, 0, normb, normar_init, alpha, + torch.ones(K, device=device, dtype=dtype), + torch.zeros(K, device=device, dtype=dtype)) + + return ( + X, + istop, + itn, + out[_O_NORMR], + out[_O_NORMAR], + out[_O_NORMA], + out[_O_CONDA], + normx_t, # reuse from last iteration, avoids redundant norm + ) + + # =========================================================================== # Public API — dispatcher # =========================================================================== @@ -1182,6 +1412,7 @@ 
def lsmr_torch_batched(
     btol: float = 1e-8,
     conlim: float = 1e8,
     maxiter: int | None = None,
+    use_compile: bool | None = None,
 ) -> tuple[
     torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor,
     torch.Tensor, torch.Tensor, torch.Tensor,
@@ -1198,6 +1429,10 @@ def lsmr_torch_batched(
     All K columns run in lock-step — converged columns continue doing
     harmless arithmetic until all columns converge or maxiter is reached.
 
+    When ``use_compile=True`` (auto-detected for CUDA), scalar Givens
+    rotations are fused into a single compiled kernel via ``torch.compile``,
+    further reducing per-iteration kernel launches.
+
     Parameters
     ----------
     A : sparse tensor or LinearOperator-like
@@ -1214,6 +1449,10 @@ def lsmr_torch_batched(
         Condition number limit.
     maxiter : int or None
         Maximum iterations. Defaults to min(m, n).
+    use_compile : bool or None
+        Whether to use ``torch.compile`` for scalar step fusion.
+        ``None`` (default) auto-detects: compiled for CUDA, eager for
+        CPU/MPS. Pass ``True`` to force compilation on MPS.
 
     Returns
     -------
@@ -1234,6 +1473,15 @@ def lsmr_torch_batched(
     normx : torch.Tensor, shape (K,)
         Per-column ``||x_k||``.
""" + device = B.device + if use_compile is None: + use_compile = device.type == "cuda" + + if use_compile: + return _lsmr_compiled_batched( + A, B, damp=damp, atol=atol, btol=btol, conlim=conlim, + maxiter=maxiter, use_compile=True, + ) return _lsmr_batched( A, B, damp=damp, atol=atol, btol=btol, conlim=conlim, maxiter=maxiter, ) diff --git a/tests/test_batched_lsmr.py b/tests/test_batched_lsmr.py index 6cd062c29..04dc3f05c 100644 --- a/tests/test_batched_lsmr.py +++ b/tests/test_batched_lsmr.py @@ -372,3 +372,168 @@ def test_many_columns(self): res_torch, success = demean_torch(x, flist, weights, tol=1e-10) assert success np.testing.assert_allclose(res_torch, res_pyhdfe, rtol=1e-5, atol=1e-7) + + +# --------------------------------------------------------------------------- +# Compiled batched LSMR tests +# --------------------------------------------------------------------------- + +HAS_MPS = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() + + +class TestCompiledBatchedLSMR: + """Tests for the compiled batched LSMR path (_lsmr_compiled_batched).""" + + def test_compiled_batched_matches_eager_batched(self): + """Packed-state batched matches vectorized eager batched on CPU f64, K=5. + + Uses use_compile=False to avoid Inductor C++ backend issues on macOS. + This still validates the _lsmr_compiled_batched logic (packed state, + _scalar_step on (STATE_SIZE, K) tensors, vector updates with clamp guards). + Actual torch.compile fusion is tested in the MPS tests below. 
+ """ + from pyfixest.estimation.torch.lsmr_torch import ( + _lsmr_batched, + _lsmr_compiled_batched, + ) + + m, n, K = 200, 100, 5 + A = _make_sparse_problem(m, n) + B = _make_rhs(m, K, seed=123) + + X_eager, istop_eager, itn_eager, *_ = _lsmr_batched(A, B) + X_comp, istop_comp, itn_comp, *_ = _lsmr_compiled_batched( + A, B, use_compile=False + ) + + assert torch.allclose(X_eager, X_comp, atol=1e-6, rtol=1e-6), ( + f"max diff = {torch.max(torch.abs(X_eager - X_comp)).item():.2e}" + ) + assert itn_eager == itn_comp, f"itn: {itn_eager} vs {itn_comp}" + + @pytest.mark.parametrize("K", [2, 10, 20]) + def test_compiled_matches_sequential_various_K(self, K): + """Packed-state batched matches K sequential lsmr_torch calls.""" + from pyfixest.estimation.torch.lsmr_torch import _lsmr_compiled_batched + + m, n = 300, 150 + A = _make_sparse_problem(m, n, seed=77) + B = _make_rhs(m, K, seed=88) + + X_comp, *_ = _lsmr_compiled_batched(A, B, use_compile=False) + + for k in range(K): + x_seq, *_ = lsmr_torch(A, B[:, k]) + assert torch.allclose(X_comp[:, k], x_seq, atol=1e-6, rtol=1e-6), ( + f"K={K}, col {k}: max diff = " + f"{torch.max(torch.abs(X_comp[:, k] - x_seq)).item():.2e}" + ) + + def test_compiled_single_column(self): + """Packed-state K=1 matches single-RHS lsmr_torch.""" + from pyfixest.estimation.torch.lsmr_torch import _lsmr_compiled_batched + + m, n = 200, 100 + A = _make_sparse_problem(m, n) + b = _make_rhs(m, 1, seed=42) + + X_comp, *_ = _lsmr_compiled_batched(A, b, use_compile=False) + x_single, *_ = lsmr_torch(A, b[:, 0]) + + assert torch.allclose(X_comp[:, 0], x_single, atol=1e-6), ( + f"max diff = {torch.max(torch.abs(X_comp[:, 0] - x_single)).item():.2e}" + ) + + def test_compiled_zero_rhs_column(self): + """Zero RHS column handled correctly in packed-state path.""" + from pyfixest.estimation.torch.lsmr_torch import _lsmr_compiled_batched + + m, n = 100, 50 + A = _make_sparse_problem(m, n) + B = _make_rhs(m, 3, seed=42) + B[:, 1] = 0.0 # Zero out middle 
column + + X, istop, *_ = _lsmr_compiled_batched(A, B, use_compile=False) + + assert torch.allclose( + X[:, 1], torch.zeros(n, dtype=torch.float64), atol=1e-12 + ), f"Zero-RHS column has non-zero solution: ||x|| = {torch.norm(X[:, 1]).item()}" + + # Non-zero columns should still solve correctly + for k in [0, 2]: + x_seq, *_ = lsmr_torch(A, B[:, k]) + assert torch.allclose(X[:, k], x_seq, atol=1e-6, rtol=1e-6) + + def test_compiled_damp(self): + """Damped packed-state batched matches damped sequential.""" + from pyfixest.estimation.torch.lsmr_torch import _lsmr_compiled_batched + + m, n, K = 200, 100, 3 + A = _make_sparse_problem(m, n) + B = _make_rhs(m, K, seed=55) + damp = 5.0 + + X_comp, *_ = _lsmr_compiled_batched(A, B, damp=damp, use_compile=False) + + for k in range(K): + x_seq, *_ = lsmr_torch(A, B[:, k], damp=damp) + assert torch.allclose(X_comp[:, k], x_seq, atol=1e-6, rtol=1e-6), ( + f"Damped col {k}: max diff = " + f"{torch.max(torch.abs(X_comp[:, k] - x_seq)).item():.2e}" + ) + + def test_compiled_all_zero_rhs(self): + """All-zero B returns all-zero X in packed-state path.""" + from pyfixest.estimation.torch.lsmr_torch import _lsmr_compiled_batched + + m, n, K = 100, 50, 3 + A = _make_sparse_problem(m, n) + B = torch.zeros(m, K, dtype=torch.float64) + + X, istop, itn, *_ = _lsmr_compiled_batched(A, B, use_compile=False) + + assert torch.allclose(X, torch.zeros(n, K, dtype=torch.float64), atol=1e-12) + assert itn == 0 + + # --- MPS-specific tests --- + + @pytest.mark.skipif(not HAS_MPS, reason="MPS not available") + def test_compiled_mps_correctness(self): + """Compiled batched on MPS f32 vs CPU f64 reference.""" + m, n, K = 300, 150, 5 + A_cpu = _make_sparse_problem(m, n, density=0.1, seed=42) + B_cpu = _make_rhs(m, K, seed=123) + + # CPU f64 reference (eager) + X_ref, *_ = lsmr_torch_batched(A_cpu, B_cpu, use_compile=False) + + # MPS f32 compiled + A_mps = A_cpu.to(torch.float32).to_dense().to("mps") + B_mps = B_cpu.to(torch.float32).to("mps") + + 
X_mps, *_ = lsmr_torch_batched(A_mps, B_mps, use_compile=True) + + max_diff = torch.max( + torch.abs(X_ref.float() - X_mps.cpu()) + ).item() + assert max_diff < 0.1, ( + f"MPS f32 compiled vs CPU f64 too different: {max_diff:.2e}" + ) + + @pytest.mark.skipif(not HAS_MPS, reason="MPS not available") + def test_compiled_vs_uncompiled_mps(self): + """Compiled and uncompiled give same results on MPS.""" + m, n, K = 300, 150, 5 + A_cpu = _make_sparse_problem(m, n, density=0.1, seed=42) + B_cpu = _make_rhs(m, K, seed=123) + + A_mps = A_cpu.to(torch.float32).to_dense().to("mps") + B_mps = B_cpu.to(torch.float32).to("mps") + + X_comp, *_ = lsmr_torch_batched(A_mps, B_mps, use_compile=True) + X_nocomp, *_ = lsmr_torch_batched(A_mps, B_mps, use_compile=False) + + max_diff = torch.max(torch.abs(X_comp - X_nocomp)).item() + assert max_diff < 1e-4, ( + f"Compiled vs uncompiled differ on MPS: {max_diff:.2e}" + ) From 90cae22e52d027a07fc53a28f51a412c8295d9b2 Mon Sep 17 00:00:00 2001 From: Jan Boelts Date: Sat, 7 Mar 2026 20:42:23 +0000 Subject: [PATCH 13/16] refactor: compiled lsmr into seperate file --- .../estimation/torch/_lsmr_compiled_core.py | 330 +++++++++++++++ pyfixest/estimation/torch/demean_torch_.py | 7 +- pyfixest/estimation/torch/lsmr_torch.py | 383 +++--------------- 3 files changed, 392 insertions(+), 328 deletions(-) create mode 100644 pyfixest/estimation/torch/_lsmr_compiled_core.py diff --git a/pyfixest/estimation/torch/_lsmr_compiled_core.py b/pyfixest/estimation/torch/_lsmr_compiled_core.py new file mode 100644 index 000000000..fa1ab855e --- /dev/null +++ b/pyfixest/estimation/torch/_lsmr_compiled_core.py @@ -0,0 +1,330 @@ +""" +Compiled LSMR scalar-step kernel and packed-state protocol. + +This module contains the densest part of the LSMR implementation: +the packed tensor state layout (31 index constants), the fused +``_scalar_step`` kernel that ``torch.compile`` turns into a single +Metal/CUDA kernel, and the helpers that create and cache it. 
+ +Separated from ``lsmr_torch.py`` so the protocol is co-located with +the only code that reads/writes it, and can be unit-tested in isolation. +""" + +from __future__ import annotations + +import threading +import warnings + +import torch + +# Guard value for divisions that can be zero (e.g. normb, rho, rhodold). +# Using clamp(..., min=_DIV_GUARD) or max(..., _DIV_GUARD) prevents 0/0 → NaN. +_DIV_GUARD = 1e-30 + +# --------------------------------------------------------------------------- +# Packed state layout +# --------------------------------------------------------------------------- +# All scalar state is packed into a single 1-D tensor to minimize Metal buffer +# slots (hardware limit: 31 per kernel). +# +# Input state (20 elements): +_I_ALPHABAR = 0 +_I_DAMP = 1 +_I_BETA = 2 +_I_ALPHA = 3 +_I_SBAR = 4 +_I_CBAR = 5 +_I_ZETABAR = 6 +_I_RHO = 7 +_I_RHOBAR = 8 +_I_RHODOLD = 9 +_I_TAUTILDEOLD = 10 +_I_THETATILDE = 11 +_I_BETADD = 12 +_I_BETAD = 13 +_I_D = 14 +_I_NORMA2 = 15 +_I_MAXRBAR = 16 +_I_MINRBAR = 17 +_I_NORMB = 18 +_I_ZETA = 19 # previous iteration's zeta (for normr estimation) + +# Constants (3 elements): atol, btol, ctol + +# Output adds extra slots for vector update coefficients: +_O_THETANEW = 20 +_O_THETABAR = 21 +_O_ZETA = 22 +_O_RHOOLD = 23 +_O_RHOBAROLD = 24 +_O_CONVERGED = 25 +_O_NORMR = 26 +_O_NORMAR = 27 +_O_NORMA = 28 +_O_CONDA = 29 +_O_NORMX_EST = 30 # placeholder, actual normx computed from vector + +_STATE_SIZE = 20 + + +# --------------------------------------------------------------------------- +# Overflow-safe hypot (replaces torch.hypot for Metal compatibility) +# --------------------------------------------------------------------------- + + +def _safe_hypot(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Overflow-safe hypot: ``sqrt(a^2 + b^2)`` without intermediate overflow. + + Uses max/min scaling: ``hypot(a,b) = max(|a|,|b|) * sqrt(1 + (min/max)^2)``. + Since ``min/max <= 1``, the argument to sqrt never exceeds 2. 
+ Compiles to ~6 Metal/CUDA ops that fuse into the surrounding kernel. + """ + abs_a = torch.abs(a) + abs_b = torch.abs(b) + big = torch.maximum(abs_a, abs_b) + small = torch.minimum(abs_a, abs_b) + safe_big = torch.where(big == 0, torch.ones_like(big), big) + ratio = small / safe_big + return torch.where( + big == 0, + torch.zeros_like(big), + big * torch.sqrt(1.0 + ratio * ratio), + ) + + +# --------------------------------------------------------------------------- +# Initial state packing +# --------------------------------------------------------------------------- + + +def _make_initial_state( + alpha: torch.Tensor, + beta: torch.Tensor, + normb: torch.Tensor, + damp: float, + dtype: torch.dtype, + device: torch.device, + *, + K: int | None = None, +) -> torch.Tensor: + """ + Pack the 20-element LSMR scalar state into a tensor. + + For single-RHS: ``K=None`` → shape ``(_STATE_SIZE,)``. + For batched: ``K=int`` → shape ``(_STATE_SIZE, K)``. + """ + shape = (_STATE_SIZE,) if K is None else (_STATE_SIZE, K) + state = torch.zeros(shape, device=device, dtype=dtype) + state[_I_ALPHABAR] = alpha + state[_I_DAMP] = damp + state[_I_BETA] = beta + state[_I_ALPHA] = alpha + state[_I_SBAR] = 0.0 + state[_I_CBAR] = 1.0 + state[_I_ZETABAR] = alpha * beta + state[_I_RHO] = 1.0 + state[_I_RHOBAR] = 1.0 + state[_I_RHODOLD] = 1.0 + state[_I_TAUTILDEOLD] = 0.0 + state[_I_THETATILDE] = 0.0 + state[_I_BETADD] = beta + state[_I_BETAD] = 0.0 + state[_I_D] = 0.0 + state[_I_NORMA2] = alpha * alpha + state[_I_MAXRBAR] = 0.0 + state[_I_MINRBAR] = 1e100 if dtype == torch.float64 else 1e10 + state[_I_NORMB] = normb + state[_I_ZETA] = 0.0 + return state + + +# --------------------------------------------------------------------------- +# Compiled scalar step (single Metal/CUDA kernel after fusion) +# --------------------------------------------------------------------------- + + +def _scalar_step(state: torch.Tensor, consts: torch.Tensor) -> torch.Tensor: + """ + All scalar work for one 
LSMR iteration: 4 Givens rotations, norm/cond + estimation, and convergence check. + + Packed I/O keeps Metal buffer count to 3 (state_in, consts, state_out). + Uses overflow-safe hypot (no torch.hypot — unsupported in Metal codegen). + """ + # Unpack + alphabar = state[_I_ALPHABAR] + damp = state[_I_DAMP] + beta = state[_I_BETA] + alpha = state[_I_ALPHA] + sbar = state[_I_SBAR] + cbar = state[_I_CBAR] + zetabar = state[_I_ZETABAR] + rho = state[_I_RHO] + rhobar = state[_I_RHOBAR] + rhodold = state[_I_RHODOLD] + tautildeold = state[_I_TAUTILDEOLD] + thetatilde = state[_I_THETATILDE] + betadd = state[_I_BETADD] + betad = state[_I_BETAD] + d = state[_I_D] + normA2 = state[_I_NORMA2] + maxrbar = state[_I_MAXRBAR] + minrbar = state[_I_MINRBAR] + normb = state[_I_NORMB] + zetaold = state[_I_ZETA] # zeta from previous iteration (for normr estimation) + + atol_t = consts[0] + ctol = consts[2] + + _ZERO = state[_I_ALPHABAR] * 0.0 # device-local zero + _ONE = _ZERO + 1.0 + + # --- Givens 1: (alphabar, damp) --- + r1 = _safe_hypot(alphabar, damp) + safe_r1 = torch.where(r1 == _ZERO, _ONE, r1) + chat = torch.where(r1 == _ZERO, _ZERO, alphabar / safe_r1) + shat = torch.where(r1 == _ZERO, _ZERO, damp / safe_r1) + + # --- Givens 2: (alphahat=r1, beta) --- + rhoold = rho + r2 = _safe_hypot(r1, beta) + safe_r2 = torch.where(r2 == _ZERO, _ONE, r2) + c = torch.where(r2 == _ZERO, _ZERO, r1 / safe_r2) + s = torch.where(r2 == _ZERO, _ZERO, beta / safe_r2) + rho_new = r2 + thetanew = s * alpha + alphabar_new = c * alpha + + # --- Givens 3: rhobar --- + rhobarold = rhobar + thetabar = sbar * rho_new + rhotemp = cbar * rho_new + r3 = _safe_hypot(rhotemp, thetanew) + safe_r3 = torch.where(r3 == _ZERO, _ONE, r3) + cbar_new = torch.where(r3 == _ZERO, _ZERO, rhotemp / safe_r3) + sbar_new = torch.where(r3 == _ZERO, _ZERO, thetanew / safe_r3) + rhobar_new = r3 + zeta = cbar_new * zetabar + zetabar_new = -sbar_new * zetabar + + # --- ||r|| estimation --- + betaacute = chat * betadd + betacheck = 
-shat * betadd + betahat = c * betaacute + betadd_new = -s * betaacute + + # Givens 4: rhotilde + r4 = _safe_hypot(rhodold, thetabar) + safe_r4 = torch.where(r4 == _ZERO, _ONE, r4) + ctildeold = torch.where(r4 == _ZERO, _ZERO, rhodold / safe_r4) + stildeold = torch.where(r4 == _ZERO, _ZERO, thetabar / safe_r4) + + thetatilde_new = stildeold * rhobar_new + rhodold_new = ctildeold * rhobar_new + betad_new = -stildeold * betad + ctildeold * betahat + + tautildeold_new = (zetaold - thetatilde * tautildeold) / torch.clamp( + r4, min=_DIV_GUARD + ) + taud = (zeta - thetatilde_new * tautildeold_new) / torch.clamp( + rhodold_new, min=_DIV_GUARD + ) + d_new = d + betacheck * betacheck + normr = torch.sqrt(d_new + (betad_new - taud) ** 2 + betadd_new * betadd_new) + + # --- ||A|| estimation --- + normA2_new = normA2 + beta * beta + normA = torch.sqrt(normA2_new) + normA2_final = normA2_new + alpha * alpha + + # --- cond(A) estimation --- + maxrbar_new = torch.maximum(maxrbar, rhobarold) + # Match SciPy: only update minrbar from iteration 2 onward. + # maxrbar == 0 on the first call (initial state), so use it as guard. 
+ minrbar_new = torch.where( + maxrbar > 0, torch.minimum(minrbar, rhobarold), minrbar + ) + condA = torch.maximum(maxrbar_new, rhotemp) / torch.clamp( + torch.minimum(minrbar_new, rhotemp), min=_DIV_GUARD + ) + + # --- Convergence check --- + normar = torch.abs(zetabar_new) + test2 = normar / torch.clamp(normA * normr, min=_DIV_GUARD) + test3 = _ONE / condA + + converged_flag = torch.where( + (test2 <= atol_t) + | (test3 <= ctol) + | (_ONE + test2 <= _ONE) + | (_ONE + test3 <= _ONE), + _ONE, + _ZERO, + ) + + # --- Pack output --- + return torch.stack( + [ + alphabar_new, # 0 _I_ALPHABAR + damp, # 1 _I_DAMP (pass through) + beta, # 2 _I_BETA (pass through, updated by caller) + alpha, # 3 _I_ALPHA (pass through, updated by caller) + sbar_new, # 4 _I_SBAR + cbar_new, # 5 _I_CBAR + zetabar_new, # 6 _I_ZETABAR + rho_new, # 7 _I_RHO + rhobar_new, # 8 _I_RHOBAR + rhodold_new, # 9 _I_RHODOLD + tautildeold_new, # 10 _I_TAUTILDEOLD + thetatilde_new, # 11 _I_THETATILDE + betadd_new, # 12 _I_BETADD + betad_new, # 13 _I_BETAD + d_new, # 14 _I_D + normA2_final, # 15 _I_NORMA2 + maxrbar_new, # 16 _I_MAXRBAR + minrbar_new, # 17 _I_MINRBAR + normb, # 18 _I_NORMB (pass through) + zeta, # 19 _I_ZETA (saved for next iteration's zetaold) + thetanew, # 20 _O_THETANEW (for vector update) + thetabar, # 21 _O_THETABAR (for vector update) + zeta, # 22 _O_ZETA (for vector update — same as slot 19) + rhoold, # 23 _O_RHOOLD (for vector update) + rhobarold, # 24 _O_RHOBAROLD (for vector update) + converged_flag, # 25 _O_CONVERGED + normr, # 26 _O_NORMR + normar, # 27 _O_NORMAR + normA, # 28 _O_NORMA + condA, # 29 _O_CONDA + _ZERO, # 30 _O_NORMX_EST (placeholder) + ] + ) + + +# --------------------------------------------------------------------------- +# Module-level compilation cache +# --------------------------------------------------------------------------- +_compiled_step_cache: dict[str, object] = {} +_cache_lock = threading.Lock() + + +def _get_compiled_step(device_type: str): + """Get 
or create compiled scalar step for the given device type.""" + if device_type in _compiled_step_cache: + return _compiled_step_cache[device_type] + with _cache_lock: + # Double-check after acquiring lock + if device_type not in _compiled_step_cache: + try: + _compiled_step_cache[device_type] = torch.compile( + _scalar_step, backend="inductor", fullgraph=True + ) + except Exception: + warnings.warn( + "torch.compile failed for LSMR scalar step; " + "falling back to eager mode.", + RuntimeWarning, + stacklevel=3, + ) + _compiled_step_cache[device_type] = _scalar_step + return _compiled_step_cache[device_type] diff --git a/pyfixest/estimation/torch/demean_torch_.py b/pyfixest/estimation/torch/demean_torch_.py index 908d22436..a00104dec 100644 --- a/pyfixest/estimation/torch/demean_torch_.py +++ b/pyfixest/estimation/torch/demean_torch_.py @@ -19,6 +19,10 @@ from pyfixest.estimation.torch.lsmr_torch import lsmr_torch, lsmr_torch_batched +# Minimum K (number of RHS columns) for batched SpMM to beat sequential SpMV. +# Benchmarked breakeven: ~K=5 on MPS (Metal) and CUDA (cuSPARSE). +_BATCHED_K_THRESHOLD = 5 + def _get_device(dtype: torch.dtype = torch.float64) -> torch.device: """Auto-detect best available device: CUDA > MPS > CPU. @@ -249,9 +253,6 @@ def _demean_torch_on_device_impl( A_precond = _PreconditionedSparse(D_weighted, M_inv) # Solve for each column — batched SpMM for K >= threshold, sequential otherwise. - # Batched SpMM amortizes sparse-matrix load across K columns. Benchmarked - # breakeven: ~K=5 on MPS (Metal), may differ on CUDA (cuSPARSE is faster). 
- _BATCHED_K_THRESHOLD = 5 theta = torch.zeros(D_cols, K, dtype=dtype, device=device) if K >= _BATCHED_K_THRESHOLD: diff --git a/pyfixest/estimation/torch/lsmr_torch.py b/pyfixest/estimation/torch/lsmr_torch.py index e86d04e32..7680c72f5 100644 --- a/pyfixest/estimation/torch/lsmr_torch.py +++ b/pyfixest/estimation/torch/lsmr_torch.py @@ -1,7 +1,7 @@ """ Pure PyTorch LSMR iterative solver with optional torch.compile kernel fusion. -Four implementations live in this file: +Four implementations: 0. ``_lsmr_batched`` — eager batched LSMR for K right-hand sides via SpMM. Uses vectorized (K,) Givens rotations with ``_sym_ortho_vec``. @@ -18,6 +18,9 @@ ``_scalar_step`` serves both single-RHS and batched paths since all ops are shape-agnostic. Fuses scalar work into one kernel per iteration. +The compiled scalar step kernel, packed-state layout (31 index constants), +and compilation cache live in ``_lsmr_compiled_core.py``. + Public entry points: - ``lsmr_torch()`` dispatches: CUDA → compiled, CPU/MPS → eager. 
- ``lsmr_torch_batched()`` dispatches: CUDA → compiled batched, @@ -32,10 +35,33 @@ from __future__ import annotations import math -import threading import torch +from pyfixest.estimation.torch._lsmr_compiled_core import ( + _DIV_GUARD, + _I_ALPHA, + _I_BETA, + _I_NORMB, + _I_RHO, + _I_RHOBAR, + _O_CONDA, + _O_CONVERGED, + _O_NORMA, + _O_NORMAR, + _O_NORMR, + _O_NORMX_EST, + _O_RHOBAROLD, + _O_RHOOLD, + _O_THETABAR, + _O_THETANEW, + _O_ZETA, + _STATE_SIZE, + _get_compiled_step, + _make_initial_state, + _scalar_step, +) + # --------------------------------------------------------------------------- # Sparse matvec helpers # --------------------------------------------------------------------------- @@ -206,52 +232,13 @@ def _safe_normalize_cols( """ nonzero = norms > 0 safe = torch.where(nonzero, norms, torch.ones_like(norms)) - M = M / safe.unsqueeze(0) - M[:, ~nonzero] = 0.0 + # Multiply by mask instead of boolean indexing (M[:, ~nonzero] = 0.0) + # to avoid a GPU→CPU sync from aten::_local_scalar_dense in the indexing path. + mask = nonzero.unsqueeze(0).to(M.dtype) + M = (M / safe.unsqueeze(0)) * mask return M, norms -def _make_initial_state( - alpha: torch.Tensor, - beta: torch.Tensor, - normb: torch.Tensor, - damp: float, - dtype: torch.dtype, - device: torch.device, - *, - K: int | None = None, -) -> torch.Tensor: - """ - Pack the 20-element LSMR scalar state into a tensor. - - For single-RHS: ``K=None`` → shape ``(_STATE_SIZE,)``. - For batched: ``K=int`` → shape ``(_STATE_SIZE, K)``. 
- """ - shape = (_STATE_SIZE,) if K is None else (_STATE_SIZE, K) - state = torch.zeros(shape, device=device, dtype=dtype) - state[_I_ALPHABAR] = alpha - state[_I_DAMP] = damp - state[_I_BETA] = beta - state[_I_ALPHA] = alpha - state[_I_SBAR] = 0.0 - state[_I_CBAR] = 1.0 - state[_I_ZETABAR] = alpha * beta - state[_I_RHO] = 1.0 - state[_I_RHOBAR] = 1.0 - state[_I_RHODOLD] = 1.0 - state[_I_TAUTILDEOLD] = 0.0 - state[_I_THETATILDE] = 0.0 - state[_I_BETADD] = beta - state[_I_BETAD] = 0.0 - state[_I_D] = 0.0 - state[_I_NORMA2] = alpha * alpha - state[_I_MAXRBAR] = 0.0 - state[_I_MINRBAR] = 1e100 if dtype == torch.float64 else 1e10 - state[_I_NORMB] = normb - state[_I_ZETA] = 0.0 - return state - - def _check_convergence_batched( istop: torch.Tensor, test1: torch.Tensor, @@ -459,12 +446,11 @@ def _lsmr_batched( # Vector updates: broadcast (K,) scalars over (n, K) matrices # Guard divisions: when a column has zero RHS, rho/rhobar can be 0. # Using clamp ensures 0/0 → 0 instead of NaN (numerator is also 0). 
- _eps = 1e-30 - hbar_coeff = -(thetabar * rho) / torch.clamp(rhoold * rhobarold, min=_eps) + hbar_coeff = -(thetabar * rho) / torch.clamp(rhoold * rhobarold, min=_DIV_GUARD) Hbar = H + Hbar * hbar_coeff.unsqueeze(0) - x_coeff = zeta / torch.clamp(rho * rhobar, min=_eps) + x_coeff = zeta / torch.clamp(rho * rhobar, min=_DIV_GUARD) X = X + x_coeff.unsqueeze(0) * Hbar - h_coeff = -(thetanew / torch.clamp(rho, min=_eps)) + h_coeff = -(thetanew / torch.clamp(rho, min=_DIV_GUARD)) H = V + H * h_coeff.unsqueeze(0) # ||r|| estimation @@ -479,9 +465,9 @@ def _lsmr_batched( rhodold = ctildeold * rhobar betad = -stildeold * betad + ctildeold * betahat - safe_rhotildeold = torch.clamp(rhotildeold, min=1e-30) + safe_rhotildeold = torch.clamp(rhotildeold, min=_DIV_GUARD) tautildeold = (zetaold - thetatildeold * tautildeold) / safe_rhotildeold - safe_rhodold = torch.clamp(rhodold, min=1e-30) + safe_rhodold = torch.clamp(rhodold, min=_DIV_GUARD) taud = (zeta - thetatilde * tautildeold) / safe_rhodold d = d + betacheck * betacheck normr = torch.sqrt(d + (betad - taud) ** 2 + betadd * betadd) @@ -496,16 +482,16 @@ def _lsmr_batched( if itn > 1: minrbar = torch.minimum(minrbar, rhobarold) condA = torch.maximum(maxrbar, rhotemp) / torch.clamp( - torch.minimum(minrbar, rhotemp), min=1e-30 + torch.minimum(minrbar, rhotemp), min=_DIV_GUARD ) # Per-column convergence check normar = torch.abs(zetabar) normx = torch.linalg.norm(X, dim=0) # (K,) - safe_normb = torch.clamp(normb, min=1e-30) + safe_normb = torch.clamp(normb, min=_DIV_GUARD) test1 = normr / safe_normb - safe_normA_normr = torch.clamp(normA * normr, min=1e-30) + safe_normA_normr = torch.clamp(normA * normr, min=_DIV_GUARD) test2 = normar / safe_normA_normr test3 = 1.0 / condA t1 = test1 / (1.0 + normA * normx / safe_normb) @@ -748,260 +734,6 @@ def _lsmr_eager( # Implementation 2: compiled-state LSMR (CUDA) # =========================================================================== -# 
--------------------------------------------------------------------------- -# Packed scalar state layout -# --------------------------------------------------------------------------- -# All scalar state is packed into a single 1-D tensor to minimize Metal buffer -# slots (hardware limit: 31 per kernel). -# -# Input state (20 elements): -_I_ALPHABAR = 0 -_I_DAMP = 1 -_I_BETA = 2 -_I_ALPHA = 3 -_I_SBAR = 4 -_I_CBAR = 5 -_I_ZETABAR = 6 -_I_RHO = 7 -_I_RHOBAR = 8 -_I_RHODOLD = 9 -_I_TAUTILDEOLD = 10 -_I_THETATILDE = 11 -_I_BETADD = 12 -_I_BETAD = 13 -_I_D = 14 -_I_NORMA2 = 15 -_I_MAXRBAR = 16 -_I_MINRBAR = 17 -_I_NORMB = 18 -_I_ZETA = 19 # previous iteration's zeta (for normr estimation) - -# Constants (3 elements): atol, btol, ctol - -# Output adds extra slots for vector update coefficients: -_O_THETANEW = 20 -_O_THETABAR = 21 -_O_ZETA = 22 -_O_RHOOLD = 23 -_O_RHOBAROLD = 24 -_O_CONVERGED = 25 -_O_NORMR = 26 -_O_NORMAR = 27 -_O_NORMA = 28 -_O_CONDA = 29 -_O_NORMX_EST = 30 # placeholder, actual normx computed from vector - -_STATE_SIZE = 20 - - -# --------------------------------------------------------------------------- -# Overflow-safe hypot (replaces torch.hypot for Metal compatibility) -# --------------------------------------------------------------------------- - - -def _safe_hypot(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: - """ - Overflow-safe hypot: ``sqrt(a** + b**)`` without intermediate overflow. - - Uses max/min scaling: ``hypot(a,b) = max(|a|,|b|) * sqrt(1 + (min/max)**)``. - Since ``min/max <= 1``, the argument to sqrt never exceeds 2. - Compiles to ~6 Metal/CUDA ops that fuse into the surrounding kernel. 
- """ - abs_a = torch.abs(a) - abs_b = torch.abs(b) - big = torch.maximum(abs_a, abs_b) - small = torch.minimum(abs_a, abs_b) - safe_big = torch.where(big == 0, torch.ones_like(big), big) - ratio = small / safe_big - return torch.where( - big == 0, - torch.zeros_like(big), - big * torch.sqrt(1.0 + ratio * ratio), - ) - - -# --------------------------------------------------------------------------- -# Compiled scalar step (single Metal/CUDA kernel after fusion) -# --------------------------------------------------------------------------- - - -def _scalar_step(state: torch.Tensor, consts: torch.Tensor) -> torch.Tensor: - """ - All scalar work for one LSMR iteration: 4 Givens rotations, norm/cond - estimation, and convergence check. - - Packed I/O keeps Metal buffer count to 3 (state_in, consts, state_out). - Uses overflow-safe hypot (no torch.hypot — unsupported in Metal codegen). - """ - # Unpack - alphabar = state[_I_ALPHABAR] - damp = state[_I_DAMP] - beta = state[_I_BETA] - alpha = state[_I_ALPHA] - sbar = state[_I_SBAR] - cbar = state[_I_CBAR] - zetabar = state[_I_ZETABAR] - rho = state[_I_RHO] - rhobar = state[_I_RHOBAR] - rhodold = state[_I_RHODOLD] - tautildeold = state[_I_TAUTILDEOLD] - thetatilde = state[_I_THETATILDE] - betadd = state[_I_BETADD] - betad = state[_I_BETAD] - d = state[_I_D] - normA2 = state[_I_NORMA2] - maxrbar = state[_I_MAXRBAR] - minrbar = state[_I_MINRBAR] - normb = state[_I_NORMB] - zetaold = state[_I_ZETA] # zeta from previous iteration (for normr estimation) - - atol_t = consts[0] - ctol = consts[2] - - _ZERO = state[_I_ALPHABAR] * 0.0 # device-local zero - _ONE = _ZERO + 1.0 - - # --- Givens 1: (alphabar, damp) --- - r1 = _safe_hypot(alphabar, damp) - safe_r1 = torch.where(r1 == _ZERO, _ONE, r1) - chat = torch.where(r1 == _ZERO, _ZERO, alphabar / safe_r1) - shat = torch.where(r1 == _ZERO, _ZERO, damp / safe_r1) - - # --- Givens 2: (alphahat=r1, beta) --- - rhoold = rho - r2 = _safe_hypot(r1, beta) - safe_r2 = torch.where(r2 == 
_ZERO, _ONE, r2) - c = torch.where(r2 == _ZERO, _ZERO, r1 / safe_r2) - s = torch.where(r2 == _ZERO, _ZERO, beta / safe_r2) - rho_new = r2 - thetanew = s * alpha - alphabar_new = c * alpha - - # --- Givens 3: rhobar --- - rhobarold = rhobar - thetabar = sbar * rho_new - rhotemp = cbar * rho_new - r3 = _safe_hypot(rhotemp, thetanew) - safe_r3 = torch.where(r3 == _ZERO, _ONE, r3) - cbar_new = torch.where(r3 == _ZERO, _ZERO, rhotemp / safe_r3) - sbar_new = torch.where(r3 == _ZERO, _ZERO, thetanew / safe_r3) - rhobar_new = r3 - zeta = cbar_new * zetabar - zetabar_new = -sbar_new * zetabar - - # --- ||r|| estimation --- - betaacute = chat * betadd - betacheck = -shat * betadd - betahat = c * betaacute - betadd_new = -s * betaacute - - # Givens 4: rhotilde - r4 = _safe_hypot(rhodold, thetabar) - safe_r4 = torch.where(r4 == _ZERO, _ONE, r4) - ctildeold = torch.where(r4 == _ZERO, _ZERO, rhodold / safe_r4) - stildeold = torch.where(r4 == _ZERO, _ZERO, thetabar / safe_r4) - - thetatilde_new = stildeold * rhobar_new - rhodold_new = ctildeold * rhobar_new - betad_new = -stildeold * betad + ctildeold * betahat - - tautildeold_new = (zetaold - thetatilde * tautildeold) / torch.clamp(r4, min=1e-30) - taud = (zeta - thetatilde_new * tautildeold_new) / torch.clamp( - rhodold_new, min=1e-30 - ) - d_new = d + betacheck * betacheck - normr = torch.sqrt(d_new + (betad_new - taud) ** 2 + betadd_new * betadd_new) - - # --- ||A|| estimation --- - normA2_new = normA2 + beta * beta - normA = torch.sqrt(normA2_new) - normA2_final = normA2_new + alpha * alpha - - # --- cond(A) estimation --- - maxrbar_new = torch.maximum(maxrbar, rhobarold) - # Match SciPy: only update minrbar from iteration 2 onward. - # maxrbar == 0 on the first call (initial state), so use it as guard. 
- minrbar_new = torch.where(maxrbar > 0, torch.minimum(minrbar, rhobarold), minrbar) - condA = torch.maximum(maxrbar_new, rhotemp) / torch.clamp( - torch.minimum(minrbar_new, rhotemp), min=1e-30 - ) - - # --- Convergence check --- - normar = torch.abs(zetabar_new) - test2 = normar / torch.clamp(normA * normr, min=1e-30) - test3 = _ONE / condA - - converged_flag = torch.where( - (test2 <= atol_t) - | (test3 <= ctol) - | (_ONE + test2 <= _ONE) - | (_ONE + test3 <= _ONE), - _ONE, - _ZERO, - ) - - # --- Pack output --- - return torch.stack( - [ - alphabar_new, # 0 _I_ALPHABAR - damp, # 1 _I_DAMP (pass through) - beta, # 2 _I_BETA (pass through, updated by caller) - alpha, # 3 _I_ALPHA (pass through, updated by caller) - sbar_new, # 4 _I_SBAR - cbar_new, # 5 _I_CBAR - zetabar_new, # 6 _I_ZETABAR - rho_new, # 7 _I_RHO - rhobar_new, # 8 _I_RHOBAR - rhodold_new, # 9 _I_RHODOLD - tautildeold_new, # 10 _I_TAUTILDEOLD - thetatilde_new, # 11 _I_THETATILDE - betadd_new, # 12 _I_BETADD - betad_new, # 13 _I_BETAD - d_new, # 14 _I_D - normA2_final, # 15 _I_NORMA2 - maxrbar_new, # 16 _I_MAXRBAR - minrbar_new, # 17 _I_MINRBAR - normb, # 18 _I_NORMB (pass through) - zeta, # 19 _I_ZETA (saved for next iteration's zetaold) - thetanew, # 20 _O_THETANEW (for vector update) - thetabar, # 21 _O_THETABAR (for vector update) - zeta, # 22 _O_ZETA (for vector update — same as slot 19) - rhoold, # 23 _O_RHOOLD (for vector update) - rhobarold, # 24 _O_RHOBAROLD (for vector update) - converged_flag, # 25 _O_CONVERGED - normr, # 26 _O_NORMR - normar, # 27 _O_NORMAR - normA, # 28 _O_NORMA - condA, # 29 _O_CONDA - _ZERO, # 30 _O_NORMX_EST (placeholder) - ] - ) - - -# --------------------------------------------------------------------------- -# Module-level compilation cache -# --------------------------------------------------------------------------- -_compiled_step_cache: dict[str, object] = {} -_cache_lock = threading.Lock() - - -def _get_compiled_step(device_type: str): - """Get or create 
compiled scalar step for the given device type.""" - if device_type in _compiled_step_cache: - return _compiled_step_cache[device_type] - with _cache_lock: - # Double-check after acquiring lock - if device_type not in _compiled_step_cache: - try: - _compiled_step_cache[device_type] = torch.compile( - _scalar_step, backend="inductor", fullgraph=True - ) - except Exception: - # Fallback: no compilation available - _compiled_step_cache[device_type] = _scalar_step - return _compiled_step_cache[device_type] - - def _lsmr_compiled( A, b: torch.Tensor, @@ -1042,11 +774,11 @@ def _lsmr_compiled( beta = normb.clone() # Safe normalize - u = u * torch.where(beta > 0, 1.0 / torch.clamp(beta, min=1e-30), beta * 0.0) + u = u * torch.where(beta > 0, 1.0 / torch.clamp(beta, min=_DIV_GUARD), beta * 0.0) v = _rmatvec(At, u) alpha = torch.linalg.norm(v) - v = v * torch.where(alpha > 0, 1.0 / torch.clamp(alpha, min=1e-30), alpha * 0.0) + v = v * torch.where(alpha > 0, 1.0 / torch.clamp(alpha, min=_DIV_GUARD), alpha * 0.0) state = _make_initial_state(alpha, beta, normb, damp, dtype, device) @@ -1077,7 +809,7 @@ def _lsmr_compiled( beta_new = torch.linalg.norm(u) u = u * torch.where( beta_new > 0, - 1.0 / torch.clamp(beta_new, min=1e-30), + 1.0 / torch.clamp(beta_new, min=_DIV_GUARD), beta_new * 0.0, ) @@ -1085,7 +817,7 @@ def _lsmr_compiled( alpha_new = torch.linalg.norm(v) v = v * torch.where( alpha_new > 0, - 1.0 / torch.clamp(alpha_new, min=1e-30), + 1.0 / torch.clamp(alpha_new, min=_DIV_GUARD), alpha_new * 0.0, ) @@ -1123,9 +855,9 @@ def _lsmr_compiled( normA_t = out[_O_NORMA] normb_t = out[_I_NORMB] - test1_t = normr_t / torch.clamp(normb_t, min=1e-30) - t1_t = test1_t / (1.0 + normA_t * normx_t / torch.clamp(normb_t, min=1e-30)) - rtol_t = btol + atol * normA_t * normx_t / torch.clamp(normb_t, min=1e-30) + test1_t = normr_t / torch.clamp(normb_t, min=_DIV_GUARD) + t1_t = test1_t / (1.0 + normA_t * normx_t / torch.clamp(normb_t, min=_DIV_GUARD)) + rtol_t = btol + atol * normA_t * 
normx_t / torch.clamp(normb_t, min=_DIV_GUARD) converged_btol = (test1_t <= rtol_t) | (1.0 + t1_t <= 1.0) converged_any = (out[_O_CONVERGED] > 0.5) | converged_btol @@ -1139,11 +871,11 @@ def _lsmr_compiled( normar_val = out[_O_NORMAR].item() condA_val = out[_O_CONDA].item() - test1 = normr_val / max(normb_val, 1e-30) - test2 = normar_val / max(normA_val * normr_val, 1e-30) + test1 = normr_val / max(normb_val, _DIV_GUARD) + test2 = normar_val / max(normA_val * normr_val, _DIV_GUARD) test3 = 1.0 / condA_val - t1 = test1 / (1.0 + normA_val * normx_val / max(normb_val, 1e-30)) - _rtol = btol + atol * normA_val * normx_val / max(normb_val, 1e-30) + t1 = test1 / (1.0 + normA_val * normx_val / max(normb_val, _DIV_GUARD)) + _rtol = btol + atol * normA_val * normx_val / max(normb_val, _DIV_GUARD) # Priority order matches SciPy LSMR (lowest istop wins) if 1.0 + test3 <= 1.0: @@ -1267,7 +999,6 @@ def _lsmr_compiled_batched( # --- Main iteration loop --- itn = 0 - _eps = 1e-30 normx_t = torch.zeros(K, device=device, dtype=dtype) while itn < maxiter: @@ -1299,11 +1030,11 @@ def _lsmr_compiled_batched( rhobarold = out[_O_RHOBAROLD] # (K,) # Safe divisions: some columns may have zero RHS → zero denominators - hbar_coeff = -(thetabar * rho_new) / torch.clamp(rhoold * rhobarold, min=_eps) + hbar_coeff = -(thetabar * rho_new) / torch.clamp(rhoold * rhobarold, min=_DIV_GUARD) Hbar = H + Hbar * hbar_coeff.unsqueeze(0) - x_coeff = zeta / torch.clamp(rho_new * rhobar_new, min=_eps) + x_coeff = zeta / torch.clamp(rho_new * rhobar_new, min=_DIV_GUARD) X = X + x_coeff.unsqueeze(0) * Hbar - h_coeff = -(thetanew / torch.clamp(rho_new, min=_eps)) + h_coeff = -(thetanew / torch.clamp(rho_new, min=_DIV_GUARD)) H = V + H * h_coeff.unsqueeze(0) # Propagate state for next iteration @@ -1318,9 +1049,9 @@ def _lsmr_compiled_batched( normA_t = out[_O_NORMA] # (K,) normb_t = out[_I_NORMB] # (K,) - safe_normb = torch.clamp(normb_t, min=_eps) + safe_normb = torch.clamp(normb_t, min=_DIV_GUARD) test1_t 
= normr_t / safe_normb - safe_normA_normr = torch.clamp(normA_t * normr_t, min=_eps) + safe_normA_normr = torch.clamp(normA_t * normr_t, min=_DIV_GUARD) test2_t = out[_O_NORMAR] / safe_normA_normr test3_t = 1.0 / out[_O_CONDA] t1_t = test1_t / (1.0 + normA_t * normx_t / safe_normb) @@ -1359,6 +1090,7 @@ def _lsmr_compiled_batched( # =========================================================================== +@torch.no_grad() def lsmr_torch( A, b: torch.Tensor, @@ -1404,6 +1136,7 @@ def lsmr_torch( ) +@torch.no_grad() def lsmr_torch_batched( A, B: torch.Tensor, From 824c31518e54b0fdcba88a103b5763b38f0a8acb Mon Sep 17 00:00:00 2001 From: Jan Date: Tue, 31 Mar 2026 10:15:24 +0200 Subject: [PATCH 14/16] refactor demeaning dispatching. --- pyfixest/estimation/internals/demean_.py | 28 ++------- pyfixest/estimation/torch/demean_torch_.py | 7 ++- tests/test_demean.py | 69 +++++++++++++++++----- 3 files changed, 64 insertions(+), 40 deletions(-) diff --git a/pyfixest/estimation/internals/demean_.py b/pyfixest/estimation/internals/demean_.py index c141ce595..cff219a1c 100644 --- a/pyfixest/estimation/internals/demean_.py +++ b/pyfixest/estimation/internals/demean_.py @@ -346,27 +346,9 @@ def _set_demeaner_backend( ValueError If the demeaning backend is not supported. 
""" - if demeaner_backend == "rust": - from pyfixest.core.demean import demean as demean_rs + from pyfixest.estimation.internals.backends import BACKENDS - return demean_rs - elif demeaner_backend == "rust-cg": - from pyfixest.core.demean import demean_within - - return demean_within - elif demeaner_backend == "numba": - return demean - elif demeaner_backend == "jax": - from pyfixest.estimation.jax.demean_jax_ import demean_jax - - return demean_jax - elif demeaner_backend in ["cupy", "cupy64"]: - from pyfixest.estimation.cupy.demean_cupy_ import demean_cupy64 - - return demean_cupy64 - elif demeaner_backend == "cupy32": - from pyfixest.estimation.cupy.demean_cupy_ import demean_cupy32 - - return demean_cupy32 - else: - raise ValueError(f"Invalid demeaner backend: {demeaner_backend}") + try: + return BACKENDS[demeaner_backend]["demean"] + except KeyError as exc: + raise ValueError(f"Invalid demeaner backend: {demeaner_backend}") from exc diff --git a/pyfixest/estimation/torch/demean_torch_.py b/pyfixest/estimation/torch/demean_torch_.py index a00104dec..35cd878bc 100644 --- a/pyfixest/estimation/torch/demean_torch_.py +++ b/pyfixest/estimation/torch/demean_torch_.py @@ -24,6 +24,11 @@ _BATCHED_K_THRESHOLD = 5 +def _should_use_batched_lsmr(device: torch.device, K: int) -> bool: + """Use batched LSMR only on devices where it has been benchmarked to help.""" + return device.type != "cpu" and K >= _BATCHED_K_THRESHOLD + + def _get_device(dtype: torch.dtype = torch.float64) -> torch.device: """Auto-detect best available device: CUDA > MPS > CPU. @@ -255,7 +260,7 @@ def _demean_torch_on_device_impl( # Solve for each column — batched SpMM for K >= threshold, sequential otherwise. 
theta = torch.zeros(D_cols, K, dtype=dtype, device=device) - if K >= _BATCHED_K_THRESHOLD: + if _should_use_batched_lsmr(device, K): # Batched: single call, K columns solved simultaneously via SpMM Z, istop_vec, _itn, *_ = lsmr_torch_batched( A_precond, diff --git a/tests/test_demean.py b/tests/test_demean.py index 0d161a5c6..0e47d76bd 100644 --- a/tests/test_demean.py +++ b/tests/test_demean.py @@ -3,9 +3,11 @@ import pyhdfe import pytest +import pyfixest as pf from pyfixest.core import demean as demean_rs from pyfixest.core.demean import demean_within from pyfixest.estimation.cupy.demean_cupy_ import demean_cupy32, demean_cupy64 +from pyfixest.estimation.internals.backends import BACKENDS from pyfixest.estimation.internals.demean_ import ( _set_demeaner_backend, demean, @@ -46,22 +48,21 @@ def test_demean(benchmark, demean_func): def test_set_demeaner_backend(): - # Test numba backend - demean_func = _set_demeaner_backend("numba") - assert demean_func == demean - - # Test jax backend - demean_func = _set_demeaner_backend("jax") - assert demean_func == demean_jax - - demean_func = _set_demeaner_backend("rust") - assert demean_func == demean_rs - - demean_func = _set_demeaner_backend("cupy32") - assert demean_func == demean_cupy32 - - demean_func = _set_demeaner_backend("cupy64") - assert demean_func == demean_cupy64 + for backend in [ + "numba", + "jax", + "rust", + "cupy32", + "cupy64", + "scipy", + "torch", + "torch_cpu", + "torch_mps", + "torch_cuda", + "torch_cuda32", + ]: + demean_func = _set_demeaner_backend(backend) + assert demean_func == BACKENDS[backend]["demean"] demean_func = _set_demeaner_backend("rust-cg") assert demean_func == demean_within @@ -71,6 +72,42 @@ def test_set_demeaner_backend(): _set_demeaner_backend("invalid") +def test_feols_torch_backend_matches_numba(): + pytest.importorskip("torch") + + rng = np.random.default_rng(12345) + N = 400 + df = pd.DataFrame( + { + "y": rng.normal(size=N), + "x1": rng.normal(size=N), + "x2": 
rng.normal(size=N), + "f1": rng.integers(0, 20, size=N), + "f2": rng.integers(0, 15, size=N), + } + ) + + fit_numba = pf.feols( + "y ~ x1 + x2 | f1 + f2", + data=df, + fixef_tol=1e-8, + demeaner_backend="numba", + ) + fit_torch = pf.feols( + "y ~ x1 + x2 | f1 + f2", + data=df, + fixef_tol=1e-8, + demeaner_backend="torch", + ) + + np.testing.assert_allclose( + fit_torch.coef().sort_index().values, + fit_numba.coef().sort_index().values, + rtol=1e-7, + atol=1e-9, + ) + + @pytest.mark.parametrize( argnames="demean_func", argvalues=[demean, demean_jax, demean_rs, demean_cupy32, demean_cupy64], From b40b908887624c946d8b9769bf672829865a8f77 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 31 Mar 2026 12:07:23 +0000 Subject: [PATCH 15/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../estimation/torch/_lsmr_compiled_core.py | 4 +- pyfixest/estimation/torch/lsmr_torch.py | 230 ++++++++++++------ tests/test_batched_lsmr.py | 47 ++-- 3 files changed, 185 insertions(+), 96 deletions(-) diff --git a/pyfixest/estimation/torch/_lsmr_compiled_core.py b/pyfixest/estimation/torch/_lsmr_compiled_core.py index fa1ab855e..752df91da 100644 --- a/pyfixest/estimation/torch/_lsmr_compiled_core.py +++ b/pyfixest/estimation/torch/_lsmr_compiled_core.py @@ -242,9 +242,7 @@ def _scalar_step(state: torch.Tensor, consts: torch.Tensor) -> torch.Tensor: maxrbar_new = torch.maximum(maxrbar, rhobarold) # Match SciPy: only update minrbar from iteration 2 onward. # maxrbar == 0 on the first call (initial state), so use it as guard. 
- minrbar_new = torch.where( - maxrbar > 0, torch.minimum(minrbar, rhobarold), minrbar - ) + minrbar_new = torch.where(maxrbar > 0, torch.minimum(minrbar, rhobarold), minrbar) condA = torch.maximum(maxrbar_new, rhotemp) / torch.clamp( torch.minimum(minrbar_new, rhotemp), min=_DIV_GUARD ) diff --git a/pyfixest/estimation/torch/lsmr_torch.py b/pyfixest/estimation/torch/lsmr_torch.py index 7680c72f5..2e8dc87b7 100644 --- a/pyfixest/estimation/torch/lsmr_torch.py +++ b/pyfixest/estimation/torch/lsmr_torch.py @@ -50,7 +50,6 @@ _O_NORMA, _O_NORMAR, _O_NORMR, - _O_NORMX_EST, _O_RHOBAROLD, _O_RHOOLD, _O_THETABAR, @@ -262,23 +261,28 @@ def _check_convergence_batched( new_stop = torch.where(test1 <= rtol, torch.ones_like(new_stop), new_stop) new_stop = torch.where( (test2 <= atol) & (new_stop == 0), - 2 * torch.ones_like(new_stop), new_stop, + 2 * torch.ones_like(new_stop), + new_stop, ) new_stop = torch.where( (test3 <= ctol) & (new_stop == 0), - 3 * torch.ones_like(new_stop), new_stop, + 3 * torch.ones_like(new_stop), + new_stop, ) new_stop = torch.where( (1.0 + t1 <= 1.0) & (new_stop == 0), - 4 * torch.ones_like(new_stop), new_stop, + 4 * torch.ones_like(new_stop), + new_stop, ) new_stop = torch.where( (1.0 + test2 <= 1.0) & (new_stop == 0), - 5 * torch.ones_like(new_stop), new_stop, + 5 * torch.ones_like(new_stop), + new_stop, ) new_stop = torch.where( (1.0 + test3 <= 1.0) & (new_stop == 0), - 6 * torch.ones_like(new_stop), new_stop, + 6 * torch.ones_like(new_stop), + new_stop, ) return torch.where(not_yet, new_stop, istop) @@ -306,8 +310,14 @@ def _lsmr_batched( conlim: float = 1e8, maxiter: int | None = None, ) -> tuple[ - torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor, - torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor, + torch.Tensor, + int, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, ]: """ Batched LSMR: solve min ||B - A X||_F for K RHS simultaneously. 
@@ -352,16 +362,16 @@ def _lsmr_batched( At = _precompute_transpose(A) # --- Initialize Golub-Kahan bidiagonalization --- - U = B.clone() # (m, K) - normb = torch.linalg.norm(U, dim=0) # (K,) + U = B.clone() # (m, K) + normb = torch.linalg.norm(U, dim=0) # (K,) X = torch.zeros(n, K, device=device, dtype=dtype) - beta = normb.clone() # (K,) + beta = normb.clone() # (K,) U, beta = _safe_normalize_cols(U, beta) - V = _rmatvec_batched(At, U) # (n, K) — SpMM - alpha = torch.linalg.norm(V, dim=0) # (K,) + V = _rmatvec_batched(At, U) # (n, K) — SpMM + alpha = torch.linalg.norm(V, dim=0) # (K,) V, alpha = _safe_normalize_cols(V, alpha) @@ -374,7 +384,7 @@ def _lsmr_batched( cbar = torch.ones(K, device=device, dtype=dtype) sbar = torch.zeros(K, device=device, dtype=dtype) - H = V.clone() # (n, K) + H = V.clone() # (n, K) Hbar = torch.zeros(n, K, device=device, dtype=dtype) # ||r|| estimation state @@ -402,27 +412,41 @@ def _lsmr_batched( # Early exit: if all normar == 0 or all normb == 0 if (normar == 0).all(): - return (X, istop, itn, normr, normar, - torch.sqrt(normA2), torch.ones(K, device=device, dtype=dtype), - torch.zeros(K, device=device, dtype=dtype)) + return ( + X, + istop, + itn, + normr, + normar, + torch.sqrt(normA2), + torch.ones(K, device=device, dtype=dtype), + torch.zeros(K, device=device, dtype=dtype), + ) if (normb == 0).all(): X.zero_() - return (X, istop, itn, normr, normar, - torch.sqrt(normA2), torch.ones(K, device=device, dtype=dtype), - torch.zeros(K, device=device, dtype=dtype)) + return ( + X, + istop, + itn, + normr, + normar, + torch.sqrt(normA2), + torch.ones(K, device=device, dtype=dtype), + torch.zeros(K, device=device, dtype=dtype), + ) # --- Main iteration loop --- while itn < maxiter: itn += 1 # Bidiagonalization step: SpMM replaces SpMV - U = _matvec_batched(A, V) - alpha.unsqueeze(0) * U # (m, K) - beta = torch.linalg.norm(U, dim=0) # (K,) + U = _matvec_batched(A, V) - alpha.unsqueeze(0) * U # (m, K) + beta = torch.linalg.norm(U, dim=0) 
# (K,) U, beta = _safe_normalize_cols(U, beta) - V = _rmatvec_batched(At, U) - beta.unsqueeze(0) * V # (n, K) - alpha = torch.linalg.norm(V, dim=0) # (K,) + V = _rmatvec_batched(At, U) - beta.unsqueeze(0) * V # (n, K) + alpha = torch.linalg.norm(V, dim=0) # (K,) V, alpha = _safe_normalize_cols(V, alpha) # Givens rotation 1: (alphabar, damp) @@ -487,7 +511,7 @@ def _lsmr_batched( # Per-column convergence check normar = torch.abs(zetabar) - normx = torch.linalg.norm(X, dim=0) # (K,) + normx = torch.linalg.norm(X, dim=0) # (K,) safe_normb = torch.clamp(normb, min=_DIV_GUARD) test1 = normr / safe_normb @@ -498,7 +522,16 @@ def _lsmr_batched( rtol = btol + atol * normA * normx / safe_normb istop = _check_convergence_batched( - istop, test1, rtol, test2, test3, t1, atol, ctol, K, device, + istop, + test1, + rtol, + test2, + test3, + t1, + atol, + ctol, + K, + device, ) if (istop > 0).all(): break @@ -734,6 +767,7 @@ def _lsmr_eager( # Implementation 2: compiled-state LSMR (CUDA) # =========================================================================== + def _lsmr_compiled( A, b: torch.Tensor, @@ -778,7 +812,9 @@ def _lsmr_compiled( v = _rmatvec(At, u) alpha = torch.linalg.norm(v) - v = v * torch.where(alpha > 0, 1.0 / torch.clamp(alpha, min=_DIV_GUARD), alpha * 0.0) + v = v * torch.where( + alpha > 0, 1.0 / torch.clamp(alpha, min=_DIV_GUARD), alpha * 0.0 + ) state = _make_initial_state(alpha, beta, normb, damp, dtype, device) @@ -856,7 +892,9 @@ def _lsmr_compiled( normb_t = out[_I_NORMB] test1_t = normr_t / torch.clamp(normb_t, min=_DIV_GUARD) - t1_t = test1_t / (1.0 + normA_t * normx_t / torch.clamp(normb_t, min=_DIV_GUARD)) + t1_t = test1_t / ( + 1.0 + normA_t * normx_t / torch.clamp(normb_t, min=_DIV_GUARD) + ) rtol_t = btol + atol * normA_t * normx_t / torch.clamp(normb_t, min=_DIV_GUARD) converged_btol = (test1_t <= rtol_t) | (1.0 + t1_t <= 1.0) @@ -930,8 +968,14 @@ def _lsmr_compiled_batched( maxiter: int | None = None, use_compile: bool = True, ) -> tuple[ - 
torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor, - torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor, + torch.Tensor, + int, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, ]: """ Compiled batched LSMR: solve min ||B - A X||_F for K RHS simultaneously. @@ -958,16 +1002,16 @@ def _lsmr_compiled_batched( At = _precompute_transpose(A) # --- Initialize Golub-Kahan bidiagonalization --- - U = B.clone() # (m, K) - normb = torch.linalg.norm(U, dim=0) # (K,) + U = B.clone() # (m, K) + normb = torch.linalg.norm(U, dim=0) # (K,) X = torch.zeros(n, K, device=device, dtype=dtype) - beta = normb.clone() # (K,) + beta = normb.clone() # (K,) U, beta = _safe_normalize_cols(U, beta) - V = _rmatvec_batched(At, U) # (n, K) — SpMM - alpha = torch.linalg.norm(V, dim=0) # (K,) + V = _rmatvec_batched(At, U) # (n, K) — SpMM + alpha = torch.linalg.norm(V, dim=0) # (K,) V, alpha = _safe_normalize_cols(V, alpha) state = _make_initial_state(alpha, beta, normb, damp, dtype, device, K=K) @@ -980,18 +1024,30 @@ def _lsmr_compiled_batched( # Early exit check normar_init = alpha * beta # (K,) if (normar_init == 0).all(): - return (X, torch.zeros(K, device=device, dtype=torch.long), 0, - beta, torch.zeros(K, device=device, dtype=dtype), - alpha, torch.ones(K, device=device, dtype=dtype), - torch.zeros(K, device=device, dtype=dtype)) + return ( + X, + torch.zeros(K, device=device, dtype=torch.long), + 0, + beta, + torch.zeros(K, device=device, dtype=dtype), + alpha, + torch.ones(K, device=device, dtype=dtype), + torch.zeros(K, device=device, dtype=dtype), + ) if (normb == 0).all(): X.zero_() - return (X, torch.zeros(K, device=device, dtype=torch.long), 0, - beta, torch.zeros(K, device=device, dtype=dtype), - alpha, torch.ones(K, device=device, dtype=dtype), - torch.zeros(K, device=device, dtype=dtype)) + return ( + X, + torch.zeros(K, device=device, dtype=torch.long), + 0, + beta, + torch.zeros(K, device=device, dtype=dtype), + alpha, + 
torch.ones(K, device=device, dtype=dtype), + torch.zeros(K, device=device, dtype=dtype), + ) - H = V.clone() # (n, K) + H = V.clone() # (n, K) Hbar = torch.zeros(n, K, device=device, dtype=dtype) # Convergence tracking: per-column istop, only-set-once latch @@ -1005,12 +1061,12 @@ def _lsmr_compiled_batched( itn += 1 # Phase 1: SpMM bidiagonalization (not compilable) - U = _matvec_batched(A, V) - state[_I_ALPHA].unsqueeze(0) * U # (m, K) - beta_new = torch.linalg.norm(U, dim=0) # (K,) + U = _matvec_batched(A, V) - state[_I_ALPHA].unsqueeze(0) * U # (m, K) + beta_new = torch.linalg.norm(U, dim=0) # (K,) U, beta_new = _safe_normalize_cols(U, beta_new) - V = _rmatvec_batched(At, U) - beta_new.unsqueeze(0) * V # (n, K) - alpha_new = torch.linalg.norm(V, dim=0) # (K,) + V = _rmatvec_batched(At, U) - beta_new.unsqueeze(0) * V # (n, K) + alpha_new = torch.linalg.norm(V, dim=0) # (K,) V, alpha_new = _safe_normalize_cols(V, alpha_new) # Update beta/alpha in state for the scalar step @@ -1021,16 +1077,18 @@ def _lsmr_compiled_batched( out = step_fn(state, consts) # Phase 3: Vector updates using scalar results from compiled step - thetanew = out[_O_THETANEW] # (K,) - thetabar = out[_O_THETABAR] # (K,) - zeta = out[_O_ZETA] # (K,) - rho_new = out[_I_RHO] # (K,) - rhobar_new = out[_I_RHOBAR] # (K,) - rhoold = out[_O_RHOOLD] # (K,) - rhobarold = out[_O_RHOBAROLD] # (K,) + thetanew = out[_O_THETANEW] # (K,) + thetabar = out[_O_THETABAR] # (K,) + zeta = out[_O_ZETA] # (K,) + rho_new = out[_I_RHO] # (K,) + rhobar_new = out[_I_RHOBAR] # (K,) + rhoold = out[_O_RHOOLD] # (K,) + rhobarold = out[_O_RHOBAROLD] # (K,) # Safe divisions: some columns may have zero RHS → zero denominators - hbar_coeff = -(thetabar * rho_new) / torch.clamp(rhoold * rhobarold, min=_DIV_GUARD) + hbar_coeff = -(thetabar * rho_new) / torch.clamp( + rhoold * rhobarold, min=_DIV_GUARD + ) Hbar = H + Hbar * hbar_coeff.unsqueeze(0) x_coeff = zeta / torch.clamp(rho_new * rhobar_new, min=_DIV_GUARD) X = X + 
x_coeff.unsqueeze(0) * Hbar @@ -1044,10 +1102,10 @@ def _lsmr_compiled_batched( # No not_yet.any() guard: the torch.where inside _check_convergence_batched # already protects converged columns, and the guard would add a second # host-device sync that costs more than the fused tensor ops it skips. - normx_t = torch.linalg.norm(X, dim=0) # (K,) - normr_t = out[_O_NORMR] # (K,) - normA_t = out[_O_NORMA] # (K,) - normb_t = out[_I_NORMB] # (K,) + normx_t = torch.linalg.norm(X, dim=0) # (K,) + normr_t = out[_O_NORMR] # (K,) + normA_t = out[_O_NORMA] # (K,) + normb_t = out[_I_NORMB] # (K,) safe_normb = torch.clamp(normb_t, min=_DIV_GUARD) test1_t = normr_t / safe_normb @@ -1058,7 +1116,16 @@ def _lsmr_compiled_batched( rtol_t = btol + atol * normA_t * normx_t / safe_normb istop = _check_convergence_batched( - istop, test1_t, rtol_t, test2_t, test3_t, t1_t, atol, ctol, K, device, + istop, + test1_t, + rtol_t, + test2_t, + test3_t, + t1_t, + atol, + ctol, + K, + device, ) # Single .item() sync: check if all columns have converged @@ -1069,9 +1136,16 @@ def _lsmr_compiled_batched( # Handle case where loop never ran if itn == 0: - return (X, istop, 0, normb, normar_init, alpha, - torch.ones(K, device=device, dtype=dtype), - torch.zeros(K, device=device, dtype=dtype)) + return ( + X, + istop, + 0, + normb, + normar_init, + alpha, + torch.ones(K, device=device, dtype=dtype), + torch.zeros(K, device=device, dtype=dtype), + ) return ( X, @@ -1147,8 +1221,14 @@ def lsmr_torch_batched( maxiter: int | None = None, use_compile: bool | None = None, ) -> tuple[ - torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor, - torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor, + torch.Tensor, + int, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, ]: """ Batched LSMR solver — solve K right-hand sides simultaneously via SpMM. 
@@ -1212,9 +1292,21 @@ def lsmr_torch_batched( if use_compile: return _lsmr_compiled_batched( - A, B, damp=damp, atol=atol, btol=btol, conlim=conlim, - maxiter=maxiter, use_compile=True, + A, + B, + damp=damp, + atol=atol, + btol=btol, + conlim=conlim, + maxiter=maxiter, + use_compile=True, ) return _lsmr_batched( - A, B, damp=damp, atol=atol, btol=btol, conlim=conlim, maxiter=maxiter, + A, + B, + damp=damp, + atol=atol, + btol=btol, + conlim=conlim, + maxiter=maxiter, ) diff --git a/tests/test_batched_lsmr.py b/tests/test_batched_lsmr.py index 04dc3f05c..300f9c18e 100644 --- a/tests/test_batched_lsmr.py +++ b/tests/test_batched_lsmr.py @@ -17,9 +17,8 @@ torch = pytest.importorskip("torch") from pyfixest.estimation.torch.demean_torch_ import ( # noqa: E402 - _PreconditionedSparse, _build_sparse_dummy, - _scale_sparse_rows, + _PreconditionedSparse, demean_torch, ) from pyfixest.estimation.torch.lsmr_torch import ( # noqa: E402 @@ -152,7 +151,9 @@ def test_zero_rhs_column(self): assert torch.allclose( X[:, 1], torch.zeros(n, dtype=torch.float64), atol=1e-12 - ), f"Zero-RHS column has non-zero solution: ||x|| = {torch.norm(X[:, 1]).item()}" + ), ( + f"Zero-RHS column has non-zero solution: ||x|| = {torch.norm(X[:, 1]).item()}" + ) def test_all_zero_rhs(self): """All-zero B should return all-zero X with istop=0.""" @@ -188,9 +189,7 @@ def test_overdetermined_known_solution(self): ) A_sparse = A_dense.to_sparse_csr() - X_true = torch.tensor( - [[1.0, 2.0, -1.0], [-1.0, 0.5, 3.0]], dtype=torch.float64 - ) + X_true = torch.tensor([[1.0, 2.0, -1.0], [-1.0, 0.5, 3.0]], dtype=torch.float64) B = A_dense @ X_true # (3, 3) X_sol, istop, *_ = lsmr_torch_batched(A_sparse, B) @@ -207,9 +206,7 @@ def test_maxiter_exhaustion(self): A = _make_sparse_problem(m, n) B = _make_rhs(m, K, seed=42) - _, istop, itn, *_ = lsmr_torch_batched( - A, B, maxiter=2, atol=1e-15, btol=1e-15 - ) + _, istop, itn, *_ = lsmr_torch_batched(A, B, maxiter=2, atol=1e-15, btol=1e-15) assert (istop == 
7).all(), f"Expected istop=7, got {istop}" assert itn == 2 @@ -300,10 +297,12 @@ def test_batched_demean_matches_pyhdfe(self): rng = np.random.default_rng(929291) N, K = 1000, 10 x = rng.normal(0, 1, (N, K)) - flist = np.column_stack([ - rng.choice(10, N), - rng.choice(10, N), - ]).astype(np.uint64) + flist = np.column_stack( + [ + rng.choice(10, N), + rng.choice(10, N), + ] + ).astype(np.uint64) weights = np.ones(N) algorithm = pyhdfe.create(flist) @@ -324,10 +323,12 @@ def test_batched_demean_weighted_matches_pyhdfe(self): rng = np.random.default_rng(929291) N, K = 1000, 5 x = rng.normal(0, 1, (N, K)) - flist = np.column_stack([ - rng.choice(10, N), - rng.choice(10, N), - ]).astype(np.uint64) + flist = np.column_stack( + [ + rng.choice(10, N), + rng.choice(10, N), + ] + ).astype(np.uint64) weights = rng.uniform(0.1, 2.0, N) algorithm = pyhdfe.create(flist) @@ -457,7 +458,9 @@ def test_compiled_zero_rhs_column(self): assert torch.allclose( X[:, 1], torch.zeros(n, dtype=torch.float64), atol=1e-12 - ), f"Zero-RHS column has non-zero solution: ||x|| = {torch.norm(X[:, 1]).item()}" + ), ( + f"Zero-RHS column has non-zero solution: ||x|| = {torch.norm(X[:, 1]).item()}" + ) # Non-zero columns should still solve correctly for k in [0, 2]: @@ -513,9 +516,7 @@ def test_compiled_mps_correctness(self): X_mps, *_ = lsmr_torch_batched(A_mps, B_mps, use_compile=True) - max_diff = torch.max( - torch.abs(X_ref.float() - X_mps.cpu()) - ).item() + max_diff = torch.max(torch.abs(X_ref.float() - X_mps.cpu())).item() assert max_diff < 0.1, ( f"MPS f32 compiled vs CPU f64 too different: {max_diff:.2e}" ) @@ -534,6 +535,4 @@ def test_compiled_vs_uncompiled_mps(self): X_nocomp, *_ = lsmr_torch_batched(A_mps, B_mps, use_compile=False) max_diff = torch.max(torch.abs(X_comp - X_nocomp)).item() - assert max_diff < 1e-4, ( - f"Compiled vs uncompiled differ on MPS: {max_diff:.2e}" - ) + assert max_diff < 1e-4, f"Compiled vs uncompiled differ on MPS: {max_diff:.2e}" From 
3debe407a1492a45a9374370883f9a0fe916ace3 Mon Sep 17 00:00:00 2001 From: Jan Date: Wed, 1 Apr 2026 11:28:30 +0200 Subject: [PATCH 16/16] different K thresholds for CUDA and MPS --- pyfixest/estimation/torch/demean_torch_.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/pyfixest/estimation/torch/demean_torch_.py b/pyfixest/estimation/torch/demean_torch_.py index 35cd878bc..949d6f0b3 100644 --- a/pyfixest/estimation/torch/demean_torch_.py +++ b/pyfixest/estimation/torch/demean_torch_.py @@ -20,13 +20,18 @@ from pyfixest.estimation.torch.lsmr_torch import lsmr_torch, lsmr_torch_batched # Minimum K (number of RHS columns) for batched SpMM to beat sequential SpMV. -# Benchmarked breakeven: ~K=5 on MPS (Metal) and CUDA (cuSPARSE). -_BATCHED_K_THRESHOLD = 5 +# Benchmarked breakeven is device-specific. +_BATCHED_K_THRESHOLD_CUDA = 2 +_BATCHED_K_THRESHOLD_MPS = 5 def _should_use_batched_lsmr(device: torch.device, K: int) -> bool: - """Use batched LSMR only on devices where it has been benchmarked to help.""" - return device.type != "cpu" and K >= _BATCHED_K_THRESHOLD + """Use batched LSMR only when device-specific benchmarks show a benefit.""" + if device.type == "cuda": + return K >= _BATCHED_K_THRESHOLD_CUDA + if device.type == "mps": + return K >= _BATCHED_K_THRESHOLD_MPS + return False def _get_device(dtype: torch.dtype = torch.float64) -> torch.device: