diff --git a/benchmarks/benchmarks.py b/benchmarks/benchmarks.py index 3700ef434..e6a4289dc 100644 --- a/benchmarks/benchmarks.py +++ b/benchmarks/benchmarks.py @@ -16,6 +16,75 @@ import pandas as pd +# Optional JAX availability detection +try: + import jax # noqa: F401 + + HAS_JAX = True +except ImportError: + HAS_JAX = False + +# Optional torch availability detection +try: + import torch + + HAS_TORCH = True + HAS_MPS = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() + HAS_CUDA = torch.cuda.is_available() +except ImportError: + HAS_TORCH = False + HAS_MPS = False + HAS_CUDA = False + +# Optional CuPy availability detection +try: + import cupy + + HAS_CUPY = cupy.cuda.runtime.getDeviceCount() > 0 +except (ImportError, Exception): + HAS_CUPY = False + +# Backends that accept a backend= argument when called through pyfixest runners +_PYFIXEST_BACKENDS = { + "scipy", + "numba", + "rust", + "jax", + "torch_cpu", + "torch_mps", + "torch_cuda", + "torch_cuda32", + "cupy", + "cupy32", + "cupy64", +} + +# ============================================================================= +# Helpers +# ============================================================================= + + +def _append_optional_backends(estimators, label_prefix, runner_func, func_name): + """Append JAX + torch backend estimators based on runtime availability.""" + optional = [] + if HAS_JAX: + optional.append(("jax", "jax")) + if HAS_TORCH: + optional.append(("torch_cpu", "torch_cpu")) + if HAS_MPS: + optional.append(("torch_mps", "torch_mps")) + if HAS_CUDA: + optional.append(("torch_cuda", "torch_cuda")) + optional.append(("torch_cuda32", "torch_cuda32")) + if HAS_CUPY: + optional.append(("cupy64", "cupy64")) + optional.append(("cupy32", "cupy32")) + for suffix, backend in optional: + estimators.append( + (f"{label_prefix} ({suffix})", backend, runner_func, False, func_name) + ) + + # ============================================================================= # Estimator 
functions (run in main process for JIT caching) # ============================================================================= @@ -220,6 +289,11 @@ def get_estimators( False, "pyfixest_feols", ), + ] + _append_optional_backends( + estimators, "pyfixest.feols", run_pyfixest_feols, "pyfixest_feols" + ) + estimators += [ ( "linearmodels.AbsorbingLS", "absorbingls", @@ -263,6 +337,9 @@ def get_estimators( "pyfixest_fepois", ), ] + _append_optional_backends( + estimators, "pyfixest.fepois", run_pyfixest_fepois, "pyfixest_fepois" + ) formulas = { 2: "negbin_y ~ x1 | indiv_id + year", 3: "negbin_y ~ x1 | indiv_id + year + firm_id", @@ -291,6 +368,12 @@ def get_estimators( "pyfixest_feglm_logit", ), ] + _append_optional_backends( + estimators, + "pyfixest.feglm_logit", + run_pyfixest_feglm_logit, + "pyfixest_feglm_logit", + ) formulas = { 2: "binary_y ~ x1 | indiv_id + year", 3: "binary_y ~ x1 | indiv_id + year + firm_id", @@ -310,6 +393,7 @@ def parse_dataset_name(name: str) -> tuple[str, int]: "500k": 500_000, "1m": 1_000_000, "2m": 2_000_000, + "3m": 3_000_000, "5m": 5_000_000, } parts = name.rsplit("_", 1) @@ -333,11 +417,17 @@ def run_benchmark( timeout_estimators: set[str] | None = None, formulas_override: dict[int, str] | None = None, allowed_datasets: set[str] | None = None, + pyfixest_only: bool = False, + backend_filter: set[str] | None = None, ) -> None: """Run benchmarks on all datasets in data_dir.""" if timeout_estimators is None: timeout_estimators = set() estimators, formulas = get_estimators(benchmark_type, timeout_estimators) + if pyfixest_only: + estimators = [e for e in estimators if e[1] in _PYFIXEST_BACKENDS] + if backend_filter: + estimators = [e for e in estimators if e[1] in backend_filter] if formulas_override: formulas = formulas_override @@ -420,7 +510,7 @@ def run_benchmark( print(f"{elapsed:.3f}s") else: # Run in main process - if backend_or_func in ("scipy", "numba", "rust"): + if backend_or_func in _PYFIXEST_BACKENDS: elapsed = func(data, 
formula, backend_or_func) else: elapsed = func(data, formula) @@ -527,6 +617,17 @@ def main(): default=None, help="Filter datasets by name (e.g., 'simple' to exclude 'difficult')", ) + parser.add_argument( + "--pyfixest-only", + action="store_true", + help="Skip non-pyfixest estimators (linearmodels, statsmodels)", + ) + parser.add_argument( + "--backends", + type=str, + default=None, + help="Comma-separated list of backends to run (e.g., 'torch_cuda,torch_cuda32,cupy64,cupy32')", + ) args = parser.parse_args() config = load_config("bench.json") @@ -537,6 +638,10 @@ def main(): formulas_override = get_formulas_from_config(config, args.type) allowed_datasets = get_allowed_datasets(config, args.type) + backend_filter = None + if args.backends: + backend_filter = set(b.strip() for b in args.backends.split(",")) + run_benchmark( args.data_dir, args.output, @@ -546,6 +651,8 @@ def main(): timeout_estimators, formulas_override, allowed_datasets, + pyfixest_only=args.pyfixest_only, + backend_filter=backend_filter, ) diff --git a/benchmarks/config.json b/benchmarks/config.json new file mode 100644 index 000000000..24251fdd1 --- /dev/null +++ b/benchmarks/config.json @@ -0,0 +1,110 @@ +{ + "timeout_secs": { + "python": 60 + }, + "python_timeout_estimators": ["linearmodels.AbsorbingLS", "statsmodels.OLS"], + "iterations": { + "n_iters": 3, + "burn_in": 1 + }, + "formulas": { + "ols": { + "2": { + "python": "y ~ x1 | indiv_id + year" + }, + "3": { + "python": "y ~ x1 | indiv_id + year + firm_id" + } + }, + "poisson": { + "2": { + "python": "negbin_y ~ x1 | indiv_id + year" + }, + "3": { + "python": "negbin_y ~ x1 | indiv_id + year + firm_id" + } + }, + "logit": { + "2": { + "python": "binary_y ~ x1 | indiv_id + year" + }, + "3": { + "python": "binary_y ~ x1 | indiv_id + year + firm_id" + } + } + }, + "datasets": [ + { "name": "simple_1k", "n": 1000, "type": "simple" }, + { "name": "difficult_1k", "n": 1000, "type": "difficult" }, + { "name": "simple_10k", "n": 10000, "type": 
"simple" }, + { "name": "difficult_10k", "n": 10000, "type": "difficult" }, + { "name": "simple_100k", "n": 100000, "type": "simple" }, + { "name": "difficult_100k", "n": 100000, "type": "difficult" }, + { "name": "simple_500k", "n": 500000, "type": "simple" }, + { "name": "difficult_500k", "n": 500000, "type": "difficult" }, + { "name": "simple_1m", "n": 1000000, "type": "simple" }, + { "name": "difficult_1m", "n": 1000000, "type": "difficult" }, + { "name": "simple_2m", "n": 2000000, "type": "simple" }, + { "name": "difficult_2m", "n": 2000000, "type": "difficult" }, + { "name": "simple_3m", "n": 3000000, "type": "simple" }, + { "name": "difficult_3m", "n": 3000000, "type": "difficult" }, + { "name": "simple_5m", "n": 5000000, "type": "simple" }, + { "name": "difficult_5m", "n": 5000000, "type": "difficult" } + ], + "datasets_by_type": { + "ols": [ + "simple_1k", + "difficult_1k", + "simple_10k", + "difficult_10k", + "simple_100k", + "difficult_100k", + "simple_500k", + "difficult_500k", + "simple_1m", + "difficult_1m", + "simple_2m", + "difficult_2m", + "simple_3m", + "difficult_3m", + "simple_5m", + "difficult_5m" + ], + "poisson": [ + "simple_1k", + "difficult_1k", + "simple_10k", + "difficult_10k", + "simple_100k", + "difficult_100k", + "simple_500k", + "difficult_500k", + "simple_1m", + "difficult_1m", + "simple_2m", + "difficult_2m", + "simple_3m", + "difficult_3m", + "simple_5m", + "difficult_5m" + ], + "logit": [ + "simple_1k", + "difficult_1k", + "simple_10k", + "difficult_10k", + "simple_100k", + "difficult_100k", + "simple_500k", + "difficult_500k", + "simple_1m", + "difficult_1m", + "simple_2m", + "difficult_2m", + "simple_3m", + "difficult_3m", + "simple_5m", + "difficult_5m" + ] + } +} diff --git a/pyfixest/estimation/internals/backends.py b/pyfixest/estimation/internals/backends.py index 2f6b35bde..40f698eed 100644 --- a/pyfixest/estimation/internals/backends.py +++ b/pyfixest/estimation/internals/backends.py @@ -47,6 +47,29 @@ crv1_meat_loop_cupy 
= crv1_meat_loop_nb count_fixef_fully_nested_all_cupy = count_fixef_fully_nested_all_nb +# Try to import Torch functions, fall back to numba if not available +try: + from pyfixest.estimation.torch.demean_torch_ import ( + demean_torch, + demean_torch_cpu, + demean_torch_cuda, + demean_torch_cuda32, + demean_torch_mps, + ) + + TORCH_AVAILABLE = True +except ImportError: + demean_torch = demean_nb + demean_torch_cpu = demean_nb + demean_torch_mps = demean_nb + demean_torch_cuda = demean_nb + demean_torch_cuda32 = demean_nb + TORCH_AVAILABLE = False + +find_collinear_variables_torch = find_collinear_variables_nb +crv1_meat_loop_torch = crv1_meat_loop_nb +count_fixef_fully_nested_all_torch = count_fixef_fully_nested_all_nb + BACKENDS = { "numba": { "demean": demean_nb, @@ -96,4 +119,19 @@ "crv1_meat": crv1_meat_loop_cupy, "nonnested": count_fixef_fully_nested_all_cupy, }, + **{ + name: { + "demean": demean_fn, + "collinear": find_collinear_variables_torch, + "crv1_meat": crv1_meat_loop_torch, + "nonnested": count_fixef_fully_nested_all_torch, + } + for name, demean_fn in [ + ("torch", demean_torch), + ("torch_cpu", demean_torch_cpu), + ("torch_mps", demean_torch_mps), + ("torch_cuda", demean_torch_cuda), + ("torch_cuda32", demean_torch_cuda32), + ] + }, } diff --git a/pyfixest/estimation/internals/demean_.py b/pyfixest/estimation/internals/demean_.py index c141ce595..cff219a1c 100644 --- a/pyfixest/estimation/internals/demean_.py +++ b/pyfixest/estimation/internals/demean_.py @@ -346,27 +346,9 @@ def _set_demeaner_backend( ValueError If the demeaning backend is not supported. 
""" - if demeaner_backend == "rust": - from pyfixest.core.demean import demean as demean_rs + from pyfixest.estimation.internals.backends import BACKENDS - return demean_rs - elif demeaner_backend == "rust-cg": - from pyfixest.core.demean import demean_within - - return demean_within - elif demeaner_backend == "numba": - return demean - elif demeaner_backend == "jax": - from pyfixest.estimation.jax.demean_jax_ import demean_jax - - return demean_jax - elif demeaner_backend in ["cupy", "cupy64"]: - from pyfixest.estimation.cupy.demean_cupy_ import demean_cupy64 - - return demean_cupy64 - elif demeaner_backend == "cupy32": - from pyfixest.estimation.cupy.demean_cupy_ import demean_cupy32 - - return demean_cupy32 - else: - raise ValueError(f"Invalid demeaner backend: {demeaner_backend}") + try: + return BACKENDS[demeaner_backend]["demean"] + except KeyError as exc: + raise ValueError(f"Invalid demeaner backend: {demeaner_backend}") from exc diff --git a/pyfixest/estimation/internals/literals.py b/pyfixest/estimation/internals/literals.py index ab8e7e067..84be563f6 100644 --- a/pyfixest/estimation/internals/literals.py +++ b/pyfixest/estimation/internals/literals.py @@ -12,7 +12,19 @@ "jax", ] DemeanerBackendOptions = Literal[ - "numba", "jax", "rust", "rust-cg", "cupy", "cupy32", "cupy64", "scipy" + "numba", + "jax", + "rust", + "rust-cg", + "cupy", + "cupy32", + "cupy64", + "scipy", + "torch", + "torch_cpu", + "torch_mps", + "torch_cuda", + "torch_cuda32", ] PredictionErrorOptions = Literal["prediction"] QuantregMethodOptions = Literal["fn", "pfn"] diff --git a/pyfixest/estimation/torch/__init__.py b/pyfixest/estimation/torch/__init__.py new file mode 100644 index 000000000..c89341727 --- /dev/null +++ b/pyfixest/estimation/torch/__init__.py @@ -0,0 +1,15 @@ +from pyfixest.estimation.torch.demean_torch_ import ( + demean_torch, + demean_torch_cpu, + demean_torch_cuda, + demean_torch_cuda32, + demean_torch_mps, +) + +__all__ = [ + "demean_torch", + "demean_torch_cpu", 
+ "demean_torch_cuda", + "demean_torch_cuda32", + "demean_torch_mps", +] diff --git a/pyfixest/estimation/torch/_lsmr_compiled_core.py b/pyfixest/estimation/torch/_lsmr_compiled_core.py new file mode 100644 index 000000000..752df91da --- /dev/null +++ b/pyfixest/estimation/torch/_lsmr_compiled_core.py @@ -0,0 +1,328 @@ +""" +Compiled LSMR scalar-step kernel and packed-state protocol. + +This module contains the densest part of the LSMR implementation: +the packed tensor state layout (31 index constants), the fused +``_scalar_step`` kernel that ``torch.compile`` turns into a single +Metal/CUDA kernel, and the helpers that create and cache it. + +Separated from ``lsmr_torch.py`` so the protocol is co-located with +the only code that reads/writes it, and can be unit-tested in isolation. +""" + +from __future__ import annotations + +import threading +import warnings + +import torch + +# Guard value for divisions that can be zero (e.g. normb, rho, rhodold). +# Using clamp(..., min=_DIV_GUARD) or max(..., _DIV_GUARD) prevents 0/0 → NaN. +_DIV_GUARD = 1e-30 + +# --------------------------------------------------------------------------- +# Packed state layout +# --------------------------------------------------------------------------- +# All scalar state is packed into a single 1-D tensor to minimize Metal buffer +# slots (hardware limit: 31 per kernel). 
+# +# Input state (20 elements): +_I_ALPHABAR = 0 +_I_DAMP = 1 +_I_BETA = 2 +_I_ALPHA = 3 +_I_SBAR = 4 +_I_CBAR = 5 +_I_ZETABAR = 6 +_I_RHO = 7 +_I_RHOBAR = 8 +_I_RHODOLD = 9 +_I_TAUTILDEOLD = 10 +_I_THETATILDE = 11 +_I_BETADD = 12 +_I_BETAD = 13 +_I_D = 14 +_I_NORMA2 = 15 +_I_MAXRBAR = 16 +_I_MINRBAR = 17 +_I_NORMB = 18 +_I_ZETA = 19 # previous iteration's zeta (for normr estimation) + +# Constants (3 elements): atol, btol, ctol + +# Output adds extra slots for vector update coefficients: +_O_THETANEW = 20 +_O_THETABAR = 21 +_O_ZETA = 22 +_O_RHOOLD = 23 +_O_RHOBAROLD = 24 +_O_CONVERGED = 25 +_O_NORMR = 26 +_O_NORMAR = 27 +_O_NORMA = 28 +_O_CONDA = 29 +_O_NORMX_EST = 30 # placeholder, actual normx computed from vector + +_STATE_SIZE = 20 + + +# --------------------------------------------------------------------------- +# Overflow-safe hypot (replaces torch.hypot for Metal compatibility) +# --------------------------------------------------------------------------- + + +def _safe_hypot(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Overflow-safe hypot: ``sqrt(a^2 + b^2)`` without intermediate overflow. + + Uses max/min scaling: ``hypot(a,b) = max(|a|,|b|) * sqrt(1 + (min/max)^2)``. + Since ``min/max <= 1``, the argument to sqrt never exceeds 2. + Compiles to ~6 Metal/CUDA ops that fuse into the surrounding kernel. 
+ """ + abs_a = torch.abs(a) + abs_b = torch.abs(b) + big = torch.maximum(abs_a, abs_b) + small = torch.minimum(abs_a, abs_b) + safe_big = torch.where(big == 0, torch.ones_like(big), big) + ratio = small / safe_big + return torch.where( + big == 0, + torch.zeros_like(big), + big * torch.sqrt(1.0 + ratio * ratio), + ) + + +# --------------------------------------------------------------------------- +# Initial state packing +# --------------------------------------------------------------------------- + + +def _make_initial_state( + alpha: torch.Tensor, + beta: torch.Tensor, + normb: torch.Tensor, + damp: float, + dtype: torch.dtype, + device: torch.device, + *, + K: int | None = None, +) -> torch.Tensor: + """ + Pack the 20-element LSMR scalar state into a tensor. + + For single-RHS: ``K=None`` → shape ``(_STATE_SIZE,)``. + For batched: ``K=int`` → shape ``(_STATE_SIZE, K)``. + """ + shape = (_STATE_SIZE,) if K is None else (_STATE_SIZE, K) + state = torch.zeros(shape, device=device, dtype=dtype) + state[_I_ALPHABAR] = alpha + state[_I_DAMP] = damp + state[_I_BETA] = beta + state[_I_ALPHA] = alpha + state[_I_SBAR] = 0.0 + state[_I_CBAR] = 1.0 + state[_I_ZETABAR] = alpha * beta + state[_I_RHO] = 1.0 + state[_I_RHOBAR] = 1.0 + state[_I_RHODOLD] = 1.0 + state[_I_TAUTILDEOLD] = 0.0 + state[_I_THETATILDE] = 0.0 + state[_I_BETADD] = beta + state[_I_BETAD] = 0.0 + state[_I_D] = 0.0 + state[_I_NORMA2] = alpha * alpha + state[_I_MAXRBAR] = 0.0 + state[_I_MINRBAR] = 1e100 if dtype == torch.float64 else 1e10 + state[_I_NORMB] = normb + state[_I_ZETA] = 0.0 + return state + + +# --------------------------------------------------------------------------- +# Compiled scalar step (single Metal/CUDA kernel after fusion) +# --------------------------------------------------------------------------- + + +def _scalar_step(state: torch.Tensor, consts: torch.Tensor) -> torch.Tensor: + """ + All scalar work for one LSMR iteration: 4 Givens rotations, norm/cond + estimation, and 
convergence check. + + Packed I/O keeps Metal buffer count to 3 (state_in, consts, state_out). + Uses overflow-safe hypot (no torch.hypot — unsupported in Metal codegen). + """ + # Unpack + alphabar = state[_I_ALPHABAR] + damp = state[_I_DAMP] + beta = state[_I_BETA] + alpha = state[_I_ALPHA] + sbar = state[_I_SBAR] + cbar = state[_I_CBAR] + zetabar = state[_I_ZETABAR] + rho = state[_I_RHO] + rhobar = state[_I_RHOBAR] + rhodold = state[_I_RHODOLD] + tautildeold = state[_I_TAUTILDEOLD] + thetatilde = state[_I_THETATILDE] + betadd = state[_I_BETADD] + betad = state[_I_BETAD] + d = state[_I_D] + normA2 = state[_I_NORMA2] + maxrbar = state[_I_MAXRBAR] + minrbar = state[_I_MINRBAR] + normb = state[_I_NORMB] + zetaold = state[_I_ZETA] # zeta from previous iteration (for normr estimation) + + atol_t = consts[0] + ctol = consts[2] + + _ZERO = state[_I_ALPHABAR] * 0.0 # device-local zero + _ONE = _ZERO + 1.0 + + # --- Givens 1: (alphabar, damp) --- + r1 = _safe_hypot(alphabar, damp) + safe_r1 = torch.where(r1 == _ZERO, _ONE, r1) + chat = torch.where(r1 == _ZERO, _ZERO, alphabar / safe_r1) + shat = torch.where(r1 == _ZERO, _ZERO, damp / safe_r1) + + # --- Givens 2: (alphahat=r1, beta) --- + rhoold = rho + r2 = _safe_hypot(r1, beta) + safe_r2 = torch.where(r2 == _ZERO, _ONE, r2) + c = torch.where(r2 == _ZERO, _ZERO, r1 / safe_r2) + s = torch.where(r2 == _ZERO, _ZERO, beta / safe_r2) + rho_new = r2 + thetanew = s * alpha + alphabar_new = c * alpha + + # --- Givens 3: rhobar --- + rhobarold = rhobar + thetabar = sbar * rho_new + rhotemp = cbar * rho_new + r3 = _safe_hypot(rhotemp, thetanew) + safe_r3 = torch.where(r3 == _ZERO, _ONE, r3) + cbar_new = torch.where(r3 == _ZERO, _ZERO, rhotemp / safe_r3) + sbar_new = torch.where(r3 == _ZERO, _ZERO, thetanew / safe_r3) + rhobar_new = r3 + zeta = cbar_new * zetabar + zetabar_new = -sbar_new * zetabar + + # --- ||r|| estimation --- + betaacute = chat * betadd + betacheck = -shat * betadd + betahat = c * betaacute + betadd_new = -s * 
betaacute + + # Givens 4: rhotilde + r4 = _safe_hypot(rhodold, thetabar) + safe_r4 = torch.where(r4 == _ZERO, _ONE, r4) + ctildeold = torch.where(r4 == _ZERO, _ZERO, rhodold / safe_r4) + stildeold = torch.where(r4 == _ZERO, _ZERO, thetabar / safe_r4) + + thetatilde_new = stildeold * rhobar_new + rhodold_new = ctildeold * rhobar_new + betad_new = -stildeold * betad + ctildeold * betahat + + tautildeold_new = (zetaold - thetatilde * tautildeold) / torch.clamp( + r4, min=_DIV_GUARD + ) + taud = (zeta - thetatilde_new * tautildeold_new) / torch.clamp( + rhodold_new, min=_DIV_GUARD + ) + d_new = d + betacheck * betacheck + normr = torch.sqrt(d_new + (betad_new - taud) ** 2 + betadd_new * betadd_new) + + # --- ||A|| estimation --- + normA2_new = normA2 + beta * beta + normA = torch.sqrt(normA2_new) + normA2_final = normA2_new + alpha * alpha + + # --- cond(A) estimation --- + maxrbar_new = torch.maximum(maxrbar, rhobarold) + # Match SciPy: only update minrbar from iteration 2 onward. + # maxrbar == 0 on the first call (initial state), so use it as guard. 
+ minrbar_new = torch.where(maxrbar > 0, torch.minimum(minrbar, rhobarold), minrbar) + condA = torch.maximum(maxrbar_new, rhotemp) / torch.clamp( + torch.minimum(minrbar_new, rhotemp), min=_DIV_GUARD + ) + + # --- Convergence check --- + normar = torch.abs(zetabar_new) + test2 = normar / torch.clamp(normA * normr, min=_DIV_GUARD) + test3 = _ONE / condA + + converged_flag = torch.where( + (test2 <= atol_t) + | (test3 <= ctol) + | (_ONE + test2 <= _ONE) + | (_ONE + test3 <= _ONE), + _ONE, + _ZERO, + ) + + # --- Pack output --- + return torch.stack( + [ + alphabar_new, # 0 _I_ALPHABAR + damp, # 1 _I_DAMP (pass through) + beta, # 2 _I_BETA (pass through, updated by caller) + alpha, # 3 _I_ALPHA (pass through, updated by caller) + sbar_new, # 4 _I_SBAR + cbar_new, # 5 _I_CBAR + zetabar_new, # 6 _I_ZETABAR + rho_new, # 7 _I_RHO + rhobar_new, # 8 _I_RHOBAR + rhodold_new, # 9 _I_RHODOLD + tautildeold_new, # 10 _I_TAUTILDEOLD + thetatilde_new, # 11 _I_THETATILDE + betadd_new, # 12 _I_BETADD + betad_new, # 13 _I_BETAD + d_new, # 14 _I_D + normA2_final, # 15 _I_NORMA2 + maxrbar_new, # 16 _I_MAXRBAR + minrbar_new, # 17 _I_MINRBAR + normb, # 18 _I_NORMB (pass through) + zeta, # 19 _I_ZETA (saved for next iteration's zetaold) + thetanew, # 20 _O_THETANEW (for vector update) + thetabar, # 21 _O_THETABAR (for vector update) + zeta, # 22 _O_ZETA (for vector update — same as slot 19) + rhoold, # 23 _O_RHOOLD (for vector update) + rhobarold, # 24 _O_RHOBAROLD (for vector update) + converged_flag, # 25 _O_CONVERGED + normr, # 26 _O_NORMR + normar, # 27 _O_NORMAR + normA, # 28 _O_NORMA + condA, # 29 _O_CONDA + _ZERO, # 30 _O_NORMX_EST (placeholder) + ] + ) + + +# --------------------------------------------------------------------------- +# Module-level compilation cache +# --------------------------------------------------------------------------- +_compiled_step_cache: dict[str, object] = {} +_cache_lock = threading.Lock() + + +def _get_compiled_step(device_type: str): + """Get or 
create compiled scalar step for the given device type.""" + if device_type in _compiled_step_cache: + return _compiled_step_cache[device_type] + with _cache_lock: + # Double-check after acquiring lock + if device_type not in _compiled_step_cache: + try: + _compiled_step_cache[device_type] = torch.compile( + _scalar_step, backend="inductor", fullgraph=True + ) + except Exception: + warnings.warn( + "torch.compile failed for LSMR scalar step; " + "falling back to eager mode.", + RuntimeWarning, + stacklevel=3, + ) + _compiled_step_cache[device_type] = _scalar_step + return _compiled_step_cache[device_type] diff --git a/pyfixest/estimation/torch/demean_torch_.py b/pyfixest/estimation/torch/demean_torch_.py new file mode 100644 index 000000000..949d6f0b3 --- /dev/null +++ b/pyfixest/estimation/torch/demean_torch_.py @@ -0,0 +1,459 @@ +""" +FWL demeaning via LSMR in pure PyTorch. + +Builds a sparse dummy matrix D directly from integer-encoded fixed effects +(no formulaic/pandas dependency), then solves D @ theta = x via LSMR per column. +Diagonal preconditioning normalizes column norms when group sizes vary. + +Sparse format is chosen per device: CSR on CUDA/CPU (cuSPARSE / native), +COO on MPS (Metal does not support sparse CSR). +""" + +from __future__ import annotations + +import warnings + +import numpy as np +import torch +from numpy.typing import NDArray + +from pyfixest.estimation.torch.lsmr_torch import lsmr_torch, lsmr_torch_batched + +# Minimum K (number of RHS columns) for batched SpMM to beat sequential SpMV. +# Benchmarked breakeven is device-specific. 
+_BATCHED_K_THRESHOLD_CUDA = 2 +_BATCHED_K_THRESHOLD_MPS = 5 + + +def _should_use_batched_lsmr(device: torch.device, K: int) -> bool: + """Use batched LSMR only when device-specific benchmarks show a benefit.""" + if device.type == "cuda": + return K >= _BATCHED_K_THRESHOLD_CUDA + if device.type == "mps": + return K >= _BATCHED_K_THRESHOLD_MPS + return False + + +def _get_device(dtype: torch.dtype = torch.float64) -> torch.device: + """Auto-detect best available device: CUDA > MPS > CPU. + + MPS does not support float64, so we fall back to CPU when float64 is needed. + When MPS is available but dtype is float64, a hint is issued to use float32. + """ + if torch.cuda.is_available(): + return torch.device("cuda") + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + if dtype != torch.float64: + return torch.device("mps") + warnings.warn( + "MPS GPU is available but requires float32. " + "Pass `dtype=torch.float32` to `demean_torch` for GPU acceleration. " + "Falling back to CPU.", + UserWarning, + stacklevel=3, + ) + return torch.device("cpu") + warnings.warn( + "No GPU available — torch demeaning will run on CPU, which is slower " + "than the scipy backend. Consider using `demean_scipy` instead.", + UserWarning, + stacklevel=3, + ) + return torch.device("cpu") + + +def _build_sparse_dummy( + flist: NDArray[np.uint64], + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + """ + Build sparse dummy matrix D from integer-encoded FE array. + + For n_factors factors, D has shape (N, total_groups) where + total_groups = sum of unique groups per factor. Columns are stacked: + [factor0_groups | factor1_groups | ...]. + + No reference-level dropping — LSMR finds the min-norm solution, + which gives correct residuals for rank-deficient systems. + + Returns COO on MPS (Metal has no CSR kernels), CSR otherwise. + + Parameters + ---------- + flist : np.ndarray, shape (N, n_factors), dtype uint64 + Integer-encoded fixed effects. 
Must be contiguous 0-based integers + per factor (i.e., values in [0, n_groups_j) with no gaps). + device : torch.device + Target device. + dtype : torch.dtype + Value dtype (e.g. torch.float64). + + Returns + ------- + D : torch.Tensor + Sparse tensor of shape (N, total_groups). COO on MPS, CSR otherwise. + + Raises + ------ + ValueError + If any factor has non-contiguous group IDs (gaps in the integer encoding). + """ + N, n_factors = flist.shape + + row_indices = [] + col_indices = [] + col_offset = 0 + + for j in range(n_factors): + col_j = flist[:, j] + unique_vals = np.unique(col_j) + n_unique_j = len(unique_vals) + n_groups_j = int(unique_vals[-1]) + 1 + + if n_groups_j != n_unique_j: + raise ValueError( + f"Factor {j} has non-contiguous group IDs: " + f"max ID is {n_groups_j - 1} but only {n_unique_j} unique values found. " + f"Re-encode to contiguous 0-based integers." + ) + + rows = np.arange(N, dtype=np.int64) + cols = col_j.astype(np.int64) + col_offset + + row_indices.append(rows) + col_indices.append(cols) + col_offset += n_groups_j + + all_rows = np.concatenate(row_indices) + all_cols = np.concatenate(col_indices) + + indices = torch.tensor( + np.stack([all_rows, all_cols]), dtype=torch.long, device=device + ) + values = torch.ones(len(all_rows), dtype=dtype, device=device) + + D_coo = torch.sparse_coo_tensor(indices, values, size=(N, col_offset)) + + if device.type == "mps": + return D_coo.coalesce() + return D_coo.to_sparse_csr() + + +def _scale_sparse_rows(D: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """Scale rows of a sparse matrix by a dense vector (row-wise multiply). + + Dispatches to the appropriate implementation based on sparse layout: + CSR on CUDA/CPU, COO on MPS. + """ + if D.layout == torch.sparse_csr: + return _scale_csr_rows(D, scale) + return _scale_coo_rows(D, scale) + + +def _scale_csr_rows(D: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """Scale rows of a sparse CSR matrix. 
+ + Operates directly on CSR internal arrays, avoiding COO format roundtrip. + Uses repeat_interleave to expand per-row scales to per-nonzero scales. + """ + crow = D.crow_indices() + col = D.col_indices() + row_counts = crow[1:] - crow[:-1] + val = D.values() * torch.repeat_interleave(scale, row_counts) + + return torch.sparse_csr_tensor( + crow, col, val, D.shape, dtype=D.dtype, device=D.device + ) + + +def _scale_coo_rows(D: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """Scale rows of a sparse COO matrix. + + Indexes into the scale vector using COO row indices to expand + per-row scales to per-nonzero scales. + """ + d_indices = D.indices() + new_values = D.values() * scale[d_indices[0]] + return torch.sparse_coo_tensor( + d_indices, new_values, D.shape, device=D.device + ).coalesce() + + +def _demean_torch_on_device( + x: NDArray[np.float64], + flist: NDArray[np.uint64] | None, + weights: NDArray[np.float64] | None, + tol: float, + maxiter: int, + device: torch.device, + dtype: torch.dtype, +) -> tuple[NDArray[np.float64], bool]: + """ + Core demeaning implementation for a specific device and dtype. + + This is the shared workhorse called by all public wrappers. + See `demean_torch` for full parameter documentation. 
+ """ + if flist is None: + raise ValueError("flist cannot be None") + if weights is None: + weights = np.ones(x.shape[0], dtype=np.float64) + + return _demean_torch_on_device_impl(x, flist, weights, tol, maxiter, device, dtype) + + +@torch.no_grad() +def _demean_torch_on_device_impl( + x: NDArray[np.float64], + flist: NDArray[np.uint64], + weights: NDArray[np.float64], + tol: float, + maxiter: int, + device: torch.device, + dtype: torch.dtype, +) -> tuple[NDArray[np.float64], bool]: + """Inner implementation wrapped in torch.no_grad() to skip autograd overhead.""" + # Track original shape to restore on output + was_1d = x.ndim == 1 + x_2d = x[:, None] if was_1d else x + weights_1d = weights.ravel() + + # Move to torch — use from_numpy for zero-copy when staying on CPU + x_c = x_2d if x_2d.flags["C_CONTIGUOUS"] else np.ascontiguousarray(x_2d) + w_c = ( + weights_1d + if weights_1d.flags["C_CONTIGUOUS"] + else np.ascontiguousarray(weights_1d) + ) + x_t = torch.from_numpy(x_c).to(dtype=dtype, device=device) + w_t = torch.from_numpy(w_c).to(dtype=dtype, device=device) + + # Ensure flist is 2D + flist_2d = flist if flist.ndim == 2 else flist[:, None] + + # Build sparse dummy matrix (unweighted) + D_unweighted = _build_sparse_dummy(flist_2d, device, dtype) + _, D_cols = D_unweighted.shape + K = x_t.shape[1] + + # Apply sqrt-weights + sqrt_w = torch.sqrt(w_t) + x_w = x_t * sqrt_w[:, None] + + # Weight the dummy matrix: D_weighted = diag(sqrt_w) @ D_unweighted + D_weighted = _scale_sparse_rows(D_unweighted, sqrt_w) + + # Diagonal preconditioning: M_inv = 1 / sqrt(D_unweighted^T @ w) + # D_unweighted^T @ w gives the sum of weights per group + group_weights = D_unweighted.t() @ w_t + + # Guard against zero-weight groups (would produce inf in M_inv) + zero_weight_mask = group_weights <= 0.0 + if zero_weight_mask.any(): + bad_groups = zero_weight_mask.nonzero(as_tuple=False).squeeze(-1).tolist() + raise ValueError( + f"Fixed effect groups {bad_groups} have zero total weight. 
" + f"Check your weights or fixed effect encoding." + ) + + M_inv = 1.0 / torch.sqrt(group_weights) + + # Build preconditioned operator once (not per column) + A_precond = _PreconditionedSparse(D_weighted, M_inv) + + # Solve for each column — batched SpMM for K >= threshold, sequential otherwise. + theta = torch.zeros(D_cols, K, dtype=dtype, device=device) + + if _should_use_batched_lsmr(device, K): + # Batched: single call, K columns solved simultaneously via SpMM + Z, istop_vec, _itn, *_ = lsmr_torch_batched( + A_precond, + x_w, + damp=0.0, + atol=tol, + btol=tol, + maxiter=maxiter, + ) + theta = M_inv.unsqueeze(1) * Z + success = ((istop_vec >= 1) & (istop_vec <= 3)).all().item() + else: + # Sequential: K < threshold, per-column single-RHS path is faster + success = True + for k in range(K): + z, istop, _itn, _normr, _normar, _normA, _condA, _normx = lsmr_torch( + A_precond, + x_w[:, k], + damp=0.0, + atol=tol, + btol=tol, + maxiter=maxiter, + ) + theta[:, k] = M_inv * z + success = success and (istop in (1, 2, 3)) + + # Compute residuals: x_demeaned = x - D_unweighted @ theta + x_demeaned = x_t - D_unweighted @ theta + + result = x_demeaned.cpu().numpy() + if was_1d: + result = result[:, 0] + + return result, success + + +def demean_torch( + x: NDArray[np.float64], + flist: NDArray[np.uint64] | None = None, + weights: NDArray[np.float64] | None = None, + tol: float = 1e-8, + maxiter: int = 100_000, + dtype: torch.dtype = torch.float64, +) -> tuple[NDArray[np.float64], bool]: + """ + Demean x by projecting out fixed effects via FWL + LSMR. + + Auto-detects the best available device (CUDA > MPS > CPU). + For explicit device control, use the device-specific variants: + `demean_torch_cpu`, `demean_torch_mps`, `demean_torch_cuda`, + `demean_torch_cuda32`. + + Parameters + ---------- + x : np.ndarray, shape (N,) or (N, K) + Variables to demean. + flist : np.ndarray, shape (N, n_factors), dtype uint64 + Integer-encoded fixed effects. 
    weights : np.ndarray, shape (N,)
        Observation weights (1.0 for equal weighting).
    tol : float
        Convergence tolerance for LSMR (used as both atol and btol).
    maxiter : int
        Maximum LSMR iterations per column.
    dtype : torch.dtype
        Tensor dtype. Use ``torch.float32`` to enable MPS (Apple GPU)
        acceleration. Default ``torch.float64`` for full precision.

    Returns
    -------
    x_demeaned : np.ndarray
        Residuals after projecting out fixed effects. Same shape as input x.
    success : bool
        True if LSMR converged for all columns.
    """
    device = _get_device(dtype)
    return _demean_torch_on_device(x, flist, weights, tol, maxiter, device, dtype)


def _make_demean_variant(
    device_str: str,
    dtype: torch.dtype,
    doc: str,
):
    """Create a device-specific demean wrapper (fixed device + dtype)."""

    def _demean(
        x: NDArray[np.float64],
        flist: NDArray[np.uint64] | None = None,
        weights: NDArray[np.float64] | None = None,
        tol: float = 1e-8,
        maxiter: int = 100_000,
    ) -> tuple[NDArray[np.float64], bool]:
        return _demean_torch_on_device(
            x,
            flist,
            weights,
            tol,
            maxiter,
            device=torch.device(device_str),
            dtype=dtype,
        )

    _demean.__doc__ = doc
    # NOTE(review): qualname is derived from the device string only, so the
    # two CUDA variants (float64/float32) share the qualname "demean_torch_cuda".
    _demean.__qualname__ = f"demean_torch_{device_str}"
    return _demean


demean_torch_cpu = _make_demean_variant(
    "cpu", torch.float64, "Torch demeaner on CPU, float64."
)
demean_torch_mps = _make_demean_variant(
    "mps", torch.float32, "Torch demeaner on MPS (Apple GPU), float32."
)
demean_torch_cuda = _make_demean_variant(
    "cuda", torch.float64, "Torch demeaner on CUDA GPU, float64."
)
demean_torch_cuda32 = _make_demean_variant(
    "cuda", torch.float32, "Torch demeaner on CUDA GPU, float32."
)


class _PreconditionedSparse:
    """
    Wraps a sparse matrix D and diagonal preconditioner M_inv
    to present A_precond = D @ diag(M_inv) for LSMR.

    This avoids forming the preconditioned matrix explicitly —
    just element-wise multiply before/after matvec.

    The transpose view is cached and returned by `.t()`, so LSMR's
    repeated `A.t().mv(u)` calls don't allocate a new object each time.

    D^T is pre-computed once in a GPU-friendly sparse layout to avoid
    per-iteration reconversion (COO coalesce on MPS, CSR radixSort on CUDA).
    """

    def __init__(
        self,
        D: torch.Tensor,
        M_inv: torch.Tensor,
        *,
        _transposed: bool = False,
        _D_t: torch.Tensor | None = None,
    ):
        # _transposed / _D_t are internal: used by t() to build the cached
        # transpose view without re-materializing D^T.
        m, n = D.shape
        self.shape = (n, m) if _transposed else (m, n)
        self._D = D
        self._M_inv = M_inv
        self._transposed = _transposed
        self._T: _PreconditionedSparse | None = None
        self._D_t = _D_t if _D_t is not None else self._materialize_transpose(D)

    @staticmethod
    def _materialize_transpose(D: torch.Tensor) -> torch.Tensor:
        """Pre-compute D^T in a GPU-friendly sparse layout."""
        D_t = D.t()
        layout = D_t.layout
        if layout == torch.sparse_coo:
            return D_t.coalesce()
        if layout in (torch.sparse_csr, torch.sparse_csc):
            return D_t.to_sparse_csr()
        return D_t

    def mv(self, v: torch.Tensor) -> torch.Tensor:
        if self._transposed:
            # Compute M_inv * (D^T @ u) — uses pre-computed transpose
            return self._M_inv * (self._D_t @ v)
        # Compute D @ (M_inv * v)
        return self._D @ (self._M_inv * v)

    def mm(self, V: torch.Tensor) -> torch.Tensor:
        """Batched matvec: A_precond @ V where V is (n, K) or (m, K).

        Same logic as mv() but broadcasts M_inv over K columns via unsqueeze.
        """
        if self._transposed:
            return self._M_inv.unsqueeze(1) * (self._D_t @ V)
        return self._D @ (self._M_inv.unsqueeze(1) * V)

    def t(self) -> _PreconditionedSparse:
        """Return cached transpose view."""
        if self._T is None:
            self._T = _PreconditionedSparse(
                self._D,
                self._M_inv,
                _transposed=not self._transposed,
                _D_t=self._D_t,
            )
            self._T._T = self  # cross-link so .t().t() returns self
        return self._T
diff --git a/pyfixest/estimation/torch/lsmr_torch.py b/pyfixest/estimation/torch/lsmr_torch.py
new file mode 100644
index 000000000..2e8dc87b7
--- /dev/null
+++ b/pyfixest/estimation/torch/lsmr_torch.py
@@ -0,0 +1,1312 @@
"""
Pure PyTorch LSMR iterative solver with optional torch.compile kernel fusion.

Four implementations:

0. ``_lsmr_batched`` — eager batched LSMR for K right-hand sides via SpMM.
   Uses vectorized (K,) Givens rotations with ``_sym_ortho_vec``.

1. ``_lsmr_eager`` — eager single-RHS LSMR, Python-float Givens rotations.
   Best for CPU.

2. ``_lsmr_compiled`` — packs scalar state into a 1-D tensor and runs
   the Givens / norm / convergence work through a ``torch.compile``-d
   kernel. On CUDA this fuses ~60 per-iteration kernel launches into one.

3. ``_lsmr_compiled_batched`` — compiled batched LSMR for K RHS via SpMM.
   Packs scalar state into a (_STATE_SIZE, K) tensor — the same compiled
   ``_scalar_step`` serves both single-RHS and batched paths since all ops
   are shape-agnostic. Fuses scalar work into one kernel per iteration.

The compiled scalar step kernel, packed-state layout (31 index constants),
and compilation cache live in ``_lsmr_compiled_core.py``.

Public entry points:
- ``lsmr_torch()`` dispatches: CUDA → compiled, CPU/MPS → eager.
- ``lsmr_torch_batched()`` dispatches: CUDA → compiled batched,
  CPU/MPS → eager batched. Pass ``use_compile=True/False`` to override.

Reference:
    D. C.-L. Fong and M. A. Saunders,
    "LSMR: An iterative algorithm for sparse least-squares problems",
    SIAM J.
Sci. Comput., vol. 33, pp. 2950-2971, 2011.
"""

from __future__ import annotations

import math

import torch

from pyfixest.estimation.torch._lsmr_compiled_core import (
    _DIV_GUARD,
    _I_ALPHA,
    _I_BETA,
    _I_NORMB,
    _I_RHO,
    _I_RHOBAR,
    _O_CONDA,
    _O_CONVERGED,
    _O_NORMA,
    _O_NORMAR,
    _O_NORMR,
    _O_RHOBAROLD,
    _O_RHOOLD,
    _O_THETABAR,
    _O_THETANEW,
    _O_ZETA,
    _STATE_SIZE,
    _get_compiled_step,
    _make_initial_state,
    _scalar_step,
)

# ---------------------------------------------------------------------------
# Sparse matvec helpers
# ---------------------------------------------------------------------------


def _matvec(A, v: torch.Tensor) -> torch.Tensor:
    """Compute A @ v for both torch.Tensor (dense/sparse) and duck-typed wrappers."""
    if isinstance(A, torch.Tensor):
        return A @ v
    return A.mv(v)


def _rmatvec(At, u: torch.Tensor) -> torch.Tensor:
    """Multiply A^T @ u using a pre-computed transpose."""
    if isinstance(At, torch.Tensor):
        return At @ u
    return At.mv(u)


def _precompute_transpose(A):
    """Pre-compute A^T in a GPU-friendly layout to avoid per-iteration reconversion."""
    if isinstance(A, torch.Tensor) and A.is_sparse_csr:
        return A.t().to_sparse_csr()
    elif isinstance(A, torch.Tensor):
        # NOTE(review): this branch assumes a dense tensor — .contiguous() on a
        # sparse COO tensor would not apply; confirm COO inputs always arrive
        # via a wrapper (e.g. _PreconditionedSparse) rather than raw tensors.
        return A.t().contiguous()
    return A.t()  # LinearOperator / wrapper — assume .mv() is efficient


# ---------------------------------------------------------------------------
# Scalar Givens rotation (Python math — no autograd overhead)
# ---------------------------------------------------------------------------


def _sym_ortho(a: float, b: float) -> tuple[float, float, float]:
    """
    Stable Givens rotation (SymOrtho).

    Given scalars a and b, compute c, s, r such that:
        [ c  s ] [ a ] = [ r ]
        [-s  c ] [ b ]   [ 0 ]

    This is the same algorithm as SciPy's ``_sym_ortho`` from LSQR,
    using pure Python math for speed on scalar values.
    """
    if b == 0.0:
        c = 0.0 if a == 0.0 else math.copysign(1.0, a)
        return c, 0.0, abs(a)
    elif a == 0.0:
        return 0.0, math.copysign(1.0, b), abs(b)
    elif abs(b) > abs(a):
        # Divide by the larger magnitude to avoid overflow in tau*tau.
        tau = a / b
        s = math.copysign(1.0, b) / math.sqrt(1.0 + tau * tau)
        c = s * tau
        r = b / s
    else:
        tau = b / a
        c = math.copysign(1.0, a) / math.sqrt(1.0 + tau * tau)
        s = c * tau
        r = a / c
    return c, s, r


# ---------------------------------------------------------------------------
# Batched matvec helpers (SpMM: sparse @ dense matrix)
# ---------------------------------------------------------------------------


def _matvec_batched(A, V: torch.Tensor) -> torch.Tensor:
    """A @ V where V is (n, K). SpMM for sparse A, mm() for wrappers."""
    if isinstance(A, torch.Tensor):
        return A @ V
    return A.mm(V)


def _rmatvec_batched(At, U: torch.Tensor) -> torch.Tensor:
    """A^T @ U where U is (m, K). SpMM."""
    if isinstance(At, torch.Tensor):
        return At @ U
    return At.mm(U)


# ---------------------------------------------------------------------------
# Vectorized Givens rotation for (K,) tensors
# ---------------------------------------------------------------------------


def _sym_ortho_vec(
    a: torch.Tensor, b: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Stable Givens rotation (SymOrtho) vectorized over K columns.

    Given (K,) tensors a and b, compute c, s, r such that for each k:
        [ c_k  s_k ] [ a_k ] = [ r_k ]
        [-s_k  c_k ] [ b_k ]   [ 0 ]

    Uses torch.where for branchless execution on GPU. Division guards
    use ones_like (not clamp) to preserve sign in dead lanes.
    """
    abs_a = torch.abs(a)
    abs_b = torch.abs(b)
    zero = torch.zeros_like(a)
    one = torch.ones_like(a)

    # Safe divisors: replace zeros with ones to prevent NaN in dead lanes.
    # The result of the dead-lane computation is discarded by torch.where.
    safe_a = torch.where(a != 0, a, one)
    safe_b = torch.where(b != 0, b, one)

    # Case 1: b == 0
    c_b0 = torch.where(a == 0, zero, torch.sign(a))
    s_b0 = zero
    r_b0 = abs_a

    # Case 2: a == 0
    c_a0 = zero
    s_a0 = torch.sign(b)
    r_a0 = abs_b

    # Case 3: |b| > |a| (neither zero)
    tau_3 = a / safe_b
    s_3 = torch.sign(b) / torch.sqrt(one + tau_3 * tau_3)
    safe_s_3 = torch.where(s_3 != 0, s_3, one)
    c_3 = s_3 * tau_3
    r_3 = b / safe_s_3

    # Case 4: |a| >= |b| (neither zero)
    tau_4 = b / safe_a
    c_4 = torch.sign(a) / torch.sqrt(one + tau_4 * tau_4)
    safe_c_4 = torch.where(c_4 != 0, c_4, one)
    s_4 = c_4 * tau_4
    r_4 = a / safe_c_4

    # Select: b==0 → case1, a==0 → case2, |b|>|a| → case3, else → case4
    is_b0 = b == 0
    is_a0 = a == 0
    is_b_gt_a = abs_b > abs_a

    # Build from innermost to outermost condition
    c = torch.where(is_b_gt_a, c_3, c_4)
    s = torch.where(is_b_gt_a, s_3, s_4)
    r = torch.where(is_b_gt_a, r_3, r_4)

    c = torch.where(is_a0, c_a0, c)
    s = torch.where(is_a0, s_a0, s)
    r = torch.where(is_a0, r_a0, r)

    c = torch.where(is_b0, c_b0, c)
    s = torch.where(is_b0, s_b0, s)
    r = torch.where(is_b0, r_b0, r)

    return c, s, r


# ---------------------------------------------------------------------------
# Shared batched helpers
# ---------------------------------------------------------------------------


def _safe_normalize_cols(
    M: torch.Tensor, norms: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Divide each column of (m, K) matrix ``M`` by its (K,) ``norms``,
    zeroing columns where ``norms == 0``.

    Returns ``(M_normalized, norms)`` so the caller has the norms for later use.
    """
    nonzero = norms > 0
    safe = torch.where(nonzero, norms, torch.ones_like(norms))
    # Multiply by mask instead of boolean indexing (M[:, ~nonzero] = 0.0)
    # to avoid a GPU→CPU sync from aten::_local_scalar_dense in the indexing path.
+ mask = nonzero.unsqueeze(0).to(M.dtype) + M = (M / safe.unsqueeze(0)) * mask + return M, norms + + +def _check_convergence_batched( + istop: torch.Tensor, + test1: torch.Tensor, + rtol: torch.Tensor, + test2: torch.Tensor, + test3: torch.Tensor, + t1: torch.Tensor, + atol: float, + ctol: float, + K: int, + device: torch.device, +) -> torch.Tensor: + """ + Per-column convergence check for batched LSMR. + + Sets istop per column using an only-set-once latch: once a column's + istop becomes non-zero, it is never overwritten. Returns updated istop. + """ + not_yet = istop == 0 + new_stop = torch.zeros(K, device=device, dtype=torch.long) + new_stop = torch.where(test1 <= rtol, torch.ones_like(new_stop), new_stop) + new_stop = torch.where( + (test2 <= atol) & (new_stop == 0), + 2 * torch.ones_like(new_stop), + new_stop, + ) + new_stop = torch.where( + (test3 <= ctol) & (new_stop == 0), + 3 * torch.ones_like(new_stop), + new_stop, + ) + new_stop = torch.where( + (1.0 + t1 <= 1.0) & (new_stop == 0), + 4 * torch.ones_like(new_stop), + new_stop, + ) + new_stop = torch.where( + (1.0 + test2 <= 1.0) & (new_stop == 0), + 5 * torch.ones_like(new_stop), + new_stop, + ) + new_stop = torch.where( + (1.0 + test3 <= 1.0) & (new_stop == 0), + 6 * torch.ones_like(new_stop), + new_stop, + ) + return torch.where(not_yet, new_stop, istop) + + +def _mark_maxiter_batched(istop: torch.Tensor, itn: int, maxiter: int) -> torch.Tensor: + """Set istop=7 for columns that did not converge before maxiter.""" + return torch.where( + (istop == 0) & (itn >= maxiter), + 7 * torch.ones_like(istop), + istop, + ) + + +# =========================================================================== +# Implementation 0: batched LSMR — K right-hand sides via SpMM +# =========================================================================== + + +def _lsmr_batched( + A, + B: torch.Tensor, + damp: float = 0.0, + atol: float = 1e-8, + btol: float = 1e-8, + conlim: float = 1e8, + maxiter: int | None = None, +) -> 
tuple[ + torch.Tensor, + torch.Tensor, + int, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + """ + Batched LSMR: solve min ||B - A X||_F for K RHS simultaneously. + + Replaces K sequential SpMV with SpMM for GPU throughput. + All K columns run in lock-step; converged columns do harmless work. + + Parameters + ---------- + A : sparse tensor or LinearOperator-like + Matrix of shape (m, n). + B : torch.Tensor + Dense matrix of shape (m, K). + damp : float + Damping factor. + atol, btol : float + Stopping tolerances (same for all columns). + conlim : float + Condition number limit. + maxiter : int or None + Maximum iterations. Defaults to min(m, n). + + Returns + ------- + X : (n, K) solution matrix + istop : (K,) int tensor — per-column stopping reason + itn : int — iterations used + normr : (K,) — per-column ||b - Ax|| + normar : (K,) — per-column ||A^T(b - Ax)|| + normA : (K,) — per-column estimate of ||A||_F + condA : (K,) — per-column estimate of cond(A) + normx : (K,) — per-column ||x|| + """ + m, n = A.shape + K = B.shape[1] + device = B.device + dtype = B.dtype + + if maxiter is None: + maxiter = min(m, n) + + At = _precompute_transpose(A) + + # --- Initialize Golub-Kahan bidiagonalization --- + U = B.clone() # (m, K) + normb = torch.linalg.norm(U, dim=0) # (K,) + + X = torch.zeros(n, K, device=device, dtype=dtype) + beta = normb.clone() # (K,) + + U, beta = _safe_normalize_cols(U, beta) + + V = _rmatvec_batched(At, U) # (n, K) — SpMM + alpha = torch.linalg.norm(V, dim=0) # (K,) + + V, alpha = _safe_normalize_cols(V, alpha) + + # --- Scalar state as (K,) tensors --- + itn = 0 + zetabar = alpha * beta + alphabar = alpha.clone() + rho = torch.ones(K, device=device, dtype=dtype) + rhobar = torch.ones(K, device=device, dtype=dtype) + cbar = torch.ones(K, device=device, dtype=dtype) + sbar = torch.zeros(K, device=device, dtype=dtype) + + H = V.clone() # (n, K) + Hbar = torch.zeros(n, K, device=device, dtype=dtype) + + # ||r|| 
estimation state + betadd = beta.clone() + betad = torch.zeros(K, device=device, dtype=dtype) + rhodold = torch.ones(K, device=device, dtype=dtype) + tautildeold = torch.zeros(K, device=device, dtype=dtype) + thetatilde = torch.zeros(K, device=device, dtype=dtype) + zeta = torch.zeros(K, device=device, dtype=dtype) + d = torch.zeros(K, device=device, dtype=dtype) + + # ||A|| and cond(A) estimation + normA2 = alpha * alpha + maxrbar = torch.zeros(K, device=device, dtype=dtype) + minrbar_init = 1e100 if dtype == torch.float64 else 1e10 + minrbar = torch.full((K,), minrbar_init, device=device, dtype=dtype) + + # Convergence tracking + istop = torch.zeros(K, device=device, dtype=torch.long) + ctol = 1.0 / conlim if conlim > 0 else 0.0 + normr = beta.clone() + normar = alpha * beta + + damp_vec = torch.full((K,), damp, device=device, dtype=dtype) + + # Early exit: if all normar == 0 or all normb == 0 + if (normar == 0).all(): + return ( + X, + istop, + itn, + normr, + normar, + torch.sqrt(normA2), + torch.ones(K, device=device, dtype=dtype), + torch.zeros(K, device=device, dtype=dtype), + ) + + if (normb == 0).all(): + X.zero_() + return ( + X, + istop, + itn, + normr, + normar, + torch.sqrt(normA2), + torch.ones(K, device=device, dtype=dtype), + torch.zeros(K, device=device, dtype=dtype), + ) + + # --- Main iteration loop --- + while itn < maxiter: + itn += 1 + + # Bidiagonalization step: SpMM replaces SpMV + U = _matvec_batched(A, V) - alpha.unsqueeze(0) * U # (m, K) + beta = torch.linalg.norm(U, dim=0) # (K,) + U, beta = _safe_normalize_cols(U, beta) + + V = _rmatvec_batched(At, U) - beta.unsqueeze(0) * V # (n, K) + alpha = torch.linalg.norm(V, dim=0) # (K,) + V, alpha = _safe_normalize_cols(V, alpha) + + # Givens rotation 1: (alphabar, damp) + chat, shat, alphahat = _sym_ortho_vec(alphabar, damp_vec) + + # Givens rotation 2: (alphahat, beta) + rhoold = rho + c, s, rho = _sym_ortho_vec(alphahat, beta) + thetanew = s * alpha + alphabar = c * alpha + + # Givens 
rotation 3: rhobar update + rhobarold = rhobar + zetaold = zeta + thetabar = sbar * rho + rhotemp = cbar * rho + cbar, sbar, rhobar = _sym_ortho_vec(rhotemp, thetanew) + zeta = cbar * zetabar + zetabar = -sbar * zetabar + + # Vector updates: broadcast (K,) scalars over (n, K) matrices + # Guard divisions: when a column has zero RHS, rho/rhobar can be 0. + # Using clamp ensures 0/0 → 0 instead of NaN (numerator is also 0). + hbar_coeff = -(thetabar * rho) / torch.clamp(rhoold * rhobarold, min=_DIV_GUARD) + Hbar = H + Hbar * hbar_coeff.unsqueeze(0) + x_coeff = zeta / torch.clamp(rho * rhobar, min=_DIV_GUARD) + X = X + x_coeff.unsqueeze(0) * Hbar + h_coeff = -(thetanew / torch.clamp(rho, min=_DIV_GUARD)) + H = V + H * h_coeff.unsqueeze(0) + + # ||r|| estimation + betaacute = chat * betadd + betacheck = -shat * betadd + betahat = c * betaacute + betadd = -s * betaacute + + thetatildeold = thetatilde + ctildeold, stildeold, rhotildeold = _sym_ortho_vec(rhodold, thetabar) + thetatilde = stildeold * rhobar + rhodold = ctildeold * rhobar + betad = -stildeold * betad + ctildeold * betahat + + safe_rhotildeold = torch.clamp(rhotildeold, min=_DIV_GUARD) + tautildeold = (zetaold - thetatildeold * tautildeold) / safe_rhotildeold + safe_rhodold = torch.clamp(rhodold, min=_DIV_GUARD) + taud = (zeta - thetatilde * tautildeold) / safe_rhodold + d = d + betacheck * betacheck + normr = torch.sqrt(d + (betad - taud) ** 2 + betadd * betadd) + + # ||A|| estimation + normA2 = normA2 + beta * beta + normA = torch.sqrt(normA2) + normA2 = normA2 + alpha * alpha + + # cond(A) estimation + maxrbar = torch.maximum(maxrbar, rhobarold) + if itn > 1: + minrbar = torch.minimum(minrbar, rhobarold) + condA = torch.maximum(maxrbar, rhotemp) / torch.clamp( + torch.minimum(minrbar, rhotemp), min=_DIV_GUARD + ) + + # Per-column convergence check + normar = torch.abs(zetabar) + normx = torch.linalg.norm(X, dim=0) # (K,) + + safe_normb = torch.clamp(normb, min=_DIV_GUARD) + test1 = normr / safe_normb + 
safe_normA_normr = torch.clamp(normA * normr, min=_DIV_GUARD) + test2 = normar / safe_normA_normr + test3 = 1.0 / condA + t1 = test1 / (1.0 + normA * normx / safe_normb) + rtol = btol + atol * normA * normx / safe_normb + + istop = _check_convergence_batched( + istop, + test1, + rtol, + test2, + test3, + t1, + atol, + ctol, + K, + device, + ) + if (istop > 0).all(): + break + + istop = _mark_maxiter_batched(istop, itn, maxiter) + + return X, istop, itn, normr, normar, normA, condA, normx + + +# =========================================================================== +# Implementation 1: scalar-state LSMR (CPU / MPS) +# =========================================================================== + + +def _lsmr_eager( + A, + b: torch.Tensor, + damp: float = 0.0, + atol: float = 1e-8, + btol: float = 1e-8, + conlim: float = 1e8, + maxiter: int | None = None, +) -> tuple[torch.Tensor, int, int, float, float, float, float, float]: + """ + LSMR iterative solver for sparse least-squares problems, in pure PyTorch. + + Solves ``min ||b - Ax||_2`` (or the damped variant) where A is a sparse + CSR (COO) tensor and b is a dense vector. All vector ops stay on the tensor's + device (CPU/CUDA/MPS). + + Parameters + ---------- + A : torch.Tensor + Sparse CSR tensor of shape (m, n). + b : torch.Tensor + Dense vector of shape (m,). + damp : float + Damping factor for regularized least-squares. + atol, btol : float + Stopping tolerances (see SciPy LSMR docs). + conlim : float + Condition number limit. + maxiter : int or None + Maximum iterations. Defaults to min(m, n). + + Returns + ------- + x : torch.Tensor + Solution vector of shape (n,). + istop : int + Reason for stopping (0-7, same codes as SciPy LSMR). + itn : int + Number of iterations used. + normr : float + ``||b - Ax||`` + normar : float + ``||A^T(b - Ax)||`` + normA : float + Estimate of Frobenius norm of A. + condA : float + Estimate of condition number of A. 
    normx : float
        ``||x||``
    """
    m, n = A.shape
    device = b.device
    dtype = b.dtype

    if maxiter is None:
        maxiter = min(m, n)

    At = _precompute_transpose(A)

    # --- Initialize Golub-Kahan bidiagonalization ---
    u = b.clone()
    normb = torch.linalg.norm(b).item()

    x = torch.zeros(n, device=device, dtype=dtype)
    beta = normb

    if beta > 0:
        u = u * (1.0 / beta)
        v = _rmatvec(At, u)
        alpha = torch.linalg.norm(v).item()
    else:
        v = torch.zeros(n, device=device, dtype=dtype)
        alpha = 0.0

    if alpha > 0:
        v = v * (1.0 / alpha)

    # --- Scalar state for iteration ---
    itn = 0
    zetabar = alpha * beta
    alphabar = alpha
    rho = 1.0
    rhobar = 1.0
    cbar = 1.0
    sbar = 0.0

    h = v.clone()
    hbar = torch.zeros(n, device=device, dtype=dtype)

    # Estimation of ||r||
    betadd = beta
    betad = 0.0
    rhodold = 1.0
    tautildeold = 0.0
    thetatilde = 0.0
    zeta = 0.0
    d = 0.0

    # Estimation of ||A|| and cond(A)
    normA2 = alpha * alpha
    maxrbar = 0.0
    minrbar = 1e100
    normA = math.sqrt(normA2)
    condA = 1.0
    normx = 0.0

    # Stopping
    istop = 0
    ctol = 1.0 / conlim if conlim > 0 else 0.0
    normr = beta
    normar = alpha * beta

    if normar == 0.0:
        return x, istop, itn, normr, normar, normA, condA, normx

    if normb == 0.0:
        x.zero_()
        return x, istop, itn, normr, normar, normA, condA, normx

    # --- Main iteration loop ---
    while itn < maxiter:
        itn += 1

        # Bidiagonalization step: get next beta, u, alpha, v
        u = _matvec(A, v) - alpha * u
        beta = torch.linalg.norm(u).item()

        if beta > 0:
            u *= 1.0 / beta
            v = _rmatvec(At, u) - beta * v
            alpha = torch.linalg.norm(v).item()
            if alpha > 0:
                v *= 1.0 / alpha

        # Construct rotation Qhat_{k,2k+1}
        chat, shat, alphahat = _sym_ortho(alphabar, damp)

        # Use plane rotation Q_i to turn B_i to R_i
        rhoold = rho
        c, s, rho = _sym_ortho(alphahat, beta)
        thetanew = s * alpha
        alphabar = c * alpha

        # Use plane rotation Qbar_i to turn R_i^T to R_i^bar
        rhobarold = rhobar
        zetaold = zeta
        thetabar = sbar * rho
        rhotemp = cbar * rho
        cbar, sbar, rhobar = _sym_ortho(cbar * rho, thetanew)
        zeta = cbar * zetabar
        zetabar = -sbar * zetabar

        # Update h, hbar, x (vector ops — stay on device)
        hbar = h + hbar * (-(thetabar * rho) / (rhoold * rhobarold))
        x = x + (zeta / (rho * rhobar)) * hbar
        h = v + h * (-(thetanew / rho))

        # Estimate ||r||
        betaacute = chat * betadd
        betacheck = -shat * betadd

        betahat = c * betaacute
        betadd = -s * betaacute

        thetatildeold = thetatilde
        ctildeold, stildeold, rhotildeold = _sym_ortho(rhodold, thetabar)
        thetatilde = stildeold * rhobar
        rhodold = ctildeold * rhobar
        betad = -stildeold * betad + ctildeold * betahat

        tautildeold = (zetaold - thetatildeold * tautildeold) / rhotildeold
        taud = (zeta - thetatilde * tautildeold) / rhodold
        d = d + betacheck * betacheck
        normr = math.sqrt(d + (betad - taud) ** 2 + betadd * betadd)

        # Estimate ||A||
        normA2 = normA2 + beta * beta
        normA = math.sqrt(normA2)
        normA2 = normA2 + alpha * alpha

        # Estimate cond(A)
        maxrbar = max(maxrbar, rhobarold)
        if itn > 1:
            minrbar = min(minrbar, rhobarold)
        condA = max(maxrbar, rhotemp) / min(minrbar, rhotemp)

        # Convergence tests
        normar = abs(zetabar)
        normx = torch.linalg.norm(x).item()

        test1 = normr / normb
        test2 = normar / (normA * normr) if normA * normr != 0 else float("inf")
        test3 = 1.0 / condA
        t1 = test1 / (1.0 + normA * normx / normb)
        rtol = btol + atol * normA * normx / normb

        # Priority order matches SciPy LSMR: later (lower-numbered) tests
        # overwrite earlier ones, so the lowest applicable istop wins.
        if itn >= maxiter:
            istop = 7
        if 1.0 + test3 <= 1.0:
            istop = 6
        if 1.0 + test2 <= 1.0:
            istop = 5
        if 1.0 + t1 <= 1.0:
            istop = 4
        if test3 <= ctol:
            istop = 3
        if test2 <= atol:
            istop = 2
        if test1 <= rtol:
            istop = 1

        if istop > 0:
            break

    return x, istop, itn, normr, normar, normA, condA, normx


# ===========================================================================
# Implementation 2: compiled-state LSMR (CUDA)
#
# ===========================================================================


def _lsmr_compiled(
    A,
    b: torch.Tensor,
    damp: float = 0.0,
    atol: float = 1e-8,
    btol: float = 1e-8,
    conlim: float = 1e8,
    maxiter: int | None = None,
    use_compile: bool = True,
) -> tuple[torch.Tensor, int, int, float, float, float, float, float]:
    """
    LSMR with packed-tensor scalar state and optional torch.compile fusion.

    On CUDA the scalar Givens rotations, norm estimation, and convergence
    check are fused into a **single GPU kernel** via ``torch.compile`` +
    Inductor, eliminating ~60 per-iteration kernel launches.

    Called by the ``lsmr_torch`` dispatcher; ``use_compile`` is already
    resolved by the caller (no auto-detection here).
    """
    m, n = A.shape
    device = b.device
    dtype = b.dtype

    if maxiter is None:
        maxiter = min(m, n)

    # Get compiled or uncompiled step function
    step_fn = _get_compiled_step(device.type) if use_compile else _scalar_step

    At = _precompute_transpose(A)

    # --- Initialize Golub-Kahan bidiagonalization ---
    u = b.clone()
    normb = torch.linalg.norm(b)

    x = torch.zeros(n, device=device, dtype=dtype)
    beta = normb.clone()

    # Safe normalize
    u = u * torch.where(beta > 0, 1.0 / torch.clamp(beta, min=_DIV_GUARD), beta * 0.0)

    v = _rmatvec(At, u)
    alpha = torch.linalg.norm(v)
    v = v * torch.where(
        alpha > 0, 1.0 / torch.clamp(alpha, min=_DIV_GUARD), alpha * 0.0
    )

    state = _make_initial_state(alpha, beta, normb, damp, dtype, device)

    ctol = 1.0 / conlim if conlim > 0 else 0.0
    consts = torch.tensor([atol, btol, ctol], device=device, dtype=dtype)

    # Early exit check
    normar_init = (alpha * beta).item()
    if normar_init == 0.0:
        return x, 0, 0, beta.item(), 0.0, alpha.item(), 1.0, 0.0
    if normb.item() == 0.0:
        x.zero_()
        return x, 0, 0, beta.item(), 0.0, alpha.item(), 1.0, 0.0

    h = v.clone()
    hbar = torch.zeros(n, device=device, dtype=dtype)

    # --- Main iteration loop ---
    itn = 0
    istop = 0

    while itn < maxiter:
        itn += 1

        # Phase 1: Sparse matvec (not compilable)
        # state[_I_ALPHA] holds the current alpha (passed through by _scalar_step)
        u = _matvec(A, v) - state[_I_ALPHA] * u
        beta_new = torch.linalg.norm(u)
        u = u * torch.where(
            beta_new > 0,
            1.0 / torch.clamp(beta_new, min=_DIV_GUARD),
            beta_new * 0.0,
        )

        v = _rmatvec(At, u) - beta_new * v
        alpha_new = torch.linalg.norm(v)
        v = v * torch.where(
            alpha_new > 0,
            1.0 / torch.clamp(alpha_new, min=_DIV_GUARD),
            alpha_new * 0.0,
        )

        # Update beta/alpha in state for the scalar step
        state[_I_BETA] = beta_new
        state[_I_ALPHA] = alpha_new

        # Phase 2: Compiled scalar step (single GPU kernel on CUDA)
        out = step_fn(state, consts)

        # Phase 3: Vector updates using scalar results from compiled step
        thetanew = out[_O_THETANEW]
        thetabar = out[_O_THETABAR]
        zeta = out[_O_ZETA]
        rho_new = out[_I_RHO]
        rhobar_new = out[_I_RHOBAR]
        rhoold = out[_O_RHOOLD]
        rhobarold = out[_O_RHOBAROLD]

        hbar = h + hbar * (-(thetabar * rho_new) / (rhoold * rhobarold))
        x = x + (zeta / (rho_new * rhobar_new)) * hbar
        h = v + h * (-(thetanew / rho_new))

        # Propagate state for next iteration
        state = out[:_STATE_SIZE]

        # Convergence check — single .item() sync per iteration.
        # The compiled _scalar_step checks test2 (atol) and test3 (ctol).
        # The btol-based test1 depends on normx (a vector quantity) and is
        # computed here on-device, then combined with the scalar step's flag.
        # Only one .item() call reads the combined boolean; all other tensor
        # ops queue on the GPU stream without forcing a pipeline stall.
        normx_t = torch.linalg.norm(x)
        normr_t = out[_O_NORMR]
        normA_t = out[_O_NORMA]
        normb_t = out[_I_NORMB]

        test1_t = normr_t / torch.clamp(normb_t, min=_DIV_GUARD)
        t1_t = test1_t / (
            1.0 + normA_t * normx_t / torch.clamp(normb_t, min=_DIV_GUARD)
        )
        rtol_t = btol + atol * normA_t * normx_t / torch.clamp(normb_t, min=_DIV_GUARD)

        converged_btol = (test1_t <= rtol_t) | (1.0 + t1_t <= 1.0)
        converged_any = (out[_O_CONVERGED] > 0.5) | converged_btol

        if converged_any.item():
            # Pull scalars to CPU only at exit (one-time cost)
            normr_val = normr_t.item()
            normA_val = normA_t.item()
            normx_val = normx_t.item()
            normb_val = normb_t.item()
            normar_val = out[_O_NORMAR].item()
            condA_val = out[_O_CONDA].item()

            test1 = normr_val / max(normb_val, _DIV_GUARD)
            test2 = normar_val / max(normA_val * normr_val, _DIV_GUARD)
            test3 = 1.0 / condA_val
            t1 = test1 / (1.0 + normA_val * normx_val / max(normb_val, _DIV_GUARD))
            _rtol = btol + atol * normA_val * normx_val / max(normb_val, _DIV_GUARD)

            # Priority order matches SciPy LSMR (lowest istop wins)
            if 1.0 + test3 <= 1.0:
                istop = 6
            if 1.0 + test2 <= 1.0:
                istop = 5
            if 1.0 + t1 <= 1.0:
                istop = 4
            if test3 <= ctol:
                istop = 3
            if test2 <= atol:
                istop = 2
            if test1 <= _rtol:
                istop = 1
            break

    if itn >= maxiter and istop == 0:
        istop = 7

    # Handle case where loop never ran (maxiter=0 or similar)
    if itn == 0:
        return x, istop, 0, normb.item(), normar_init, alpha.item(), 1.0, 0.0

    # normx_val was already computed inside the convergence block above;
    # only recompute if the loop exhausted maxiter without converging.
    if istop == 0 or istop == 7:
        normx_val = torch.linalg.norm(x).item()
    return (
        x,
        istop,
        itn,
        out[_O_NORMR].item(),
        out[_O_NORMAR].item(),
        out[_O_NORMA].item(),
        out[_O_CONDA].item(),
        normx_val,
    )


# ===========================================================================
# Implementation 3: compiled batched LSMR — K RHS via SpMM + torch.compile
# ===========================================================================


def _lsmr_compiled_batched(
    A,
    B: torch.Tensor,
    damp: float = 0.0,
    atol: float = 1e-8,
    btol: float = 1e-8,
    conlim: float = 1e8,
    maxiter: int | None = None,
    use_compile: bool = True,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    int,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """
    Compiled batched LSMR: solve min ||B - A X||_F for K RHS simultaneously.

    Mirrors ``_lsmr_compiled`` but with (m, K) vectors and SpMM. The scalar
    state is packed into a (_STATE_SIZE, K) tensor — each of the 20 scalar
    quantities becomes a (K,) vector. ``_scalar_step`` is shape-agnostic:
    its indexing and element-wise ops broadcast over the K dimension without
    any code changes.

    Called by ``lsmr_torch_batched``; ``use_compile`` is already resolved.
    """
    m, n = A.shape
    K = B.shape[1]
    device = B.device
    dtype = B.dtype

    if maxiter is None:
        maxiter = min(m, n)

    # Get compiled or uncompiled step function
    step_fn = _get_compiled_step(device.type) if use_compile else _scalar_step

    At = _precompute_transpose(A)

    # --- Initialize Golub-Kahan bidiagonalization ---
    U = B.clone()  # (m, K)
    normb = torch.linalg.norm(U, dim=0)  # (K,)

    X = torch.zeros(n, K, device=device, dtype=dtype)
    beta = normb.clone()  # (K,)

    U, beta = _safe_normalize_cols(U, beta)

    V = _rmatvec_batched(At, U)  # (n, K) — SpMM
    alpha = torch.linalg.norm(V, dim=0)  # (K,)
    V, alpha = _safe_normalize_cols(V, alpha)

    state = _make_initial_state(alpha, beta, normb, damp, dtype, device, K=K)

    ctol = 1.0 / conlim if conlim > 0 else 0.0
    # consts (3,) — _scalar_step indexes as consts[0]/consts[2] which broadcast
    # against (K,) state elements automatically.
    consts = torch.tensor([atol, btol, ctol], device=device, dtype=dtype)

    # Early exit check
    normar_init = alpha * beta  # (K,)
    if (normar_init == 0).all():
        return (
            X,
            torch.zeros(K, device=device, dtype=torch.long),
            0,
            beta,
            torch.zeros(K, device=device, dtype=dtype),
            alpha,
            torch.ones(K, device=device, dtype=dtype),
            torch.zeros(K, device=device, dtype=dtype),
        )
    if (normb == 0).all():
        X.zero_()
        return (
            X,
            torch.zeros(K, device=device, dtype=torch.long),
            0,
            beta,
            torch.zeros(K, device=device, dtype=dtype),
            alpha,
            torch.ones(K, device=device, dtype=dtype),
            torch.zeros(K, device=device, dtype=dtype),
        )

    H = V.clone()  # (n, K)
    Hbar = torch.zeros(n, K, device=device, dtype=dtype)

    # Convergence tracking: per-column istop, only-set-once latch
    istop = torch.zeros(K, device=device, dtype=torch.long)

    # --- Main iteration loop ---
    itn = 0
    normx_t = torch.zeros(K, device=device, dtype=dtype)

    while itn < maxiter:
        itn += 1

        # Phase 1: SpMM bidiagonalization (not compilable)
        U = _matvec_batched(A, V) - state[_I_ALPHA].unsqueeze(0) * U  # (m, K)
        beta_new = torch.linalg.norm(U, dim=0)  # (K,)
        U, beta_new = _safe_normalize_cols(U, beta_new)

        V = _rmatvec_batched(At, U) - beta_new.unsqueeze(0) * V  # (n, K)
        alpha_new = torch.linalg.norm(V, dim=0)  # (K,)
        V, alpha_new = _safe_normalize_cols(V, alpha_new)

        # Update beta/alpha in state for the scalar step
        state[_I_BETA] = beta_new
        state[_I_ALPHA] = alpha_new

        # Phase 2: Compiled scalar step — (_STATE_SIZE, K) → (_OUTPUT_SIZE, K)
        out = step_fn(state, consts)

        # Phase 3: Vector updates using scalar results from compiled step
        thetanew = out[_O_THETANEW]  # (K,)
        thetabar = out[_O_THETABAR]  # (K,)
        zeta = out[_O_ZETA]  # (K,)
        rho_new = out[_I_RHO]  # (K,)
        rhobar_new = out[_I_RHOBAR]  # (K,)
        rhoold = out[_O_RHOOLD]  # (K,)
        rhobarold = out[_O_RHOBAROLD]  # (K,)

        # Safe divisions: some columns may have zero RHS → zero denominators
        hbar_coeff = -(thetabar * rho_new) / torch.clamp(
            rhoold * rhobarold, min=_DIV_GUARD
        )
        Hbar = H + Hbar * hbar_coeff.unsqueeze(0)
        x_coeff = zeta / torch.clamp(rho_new * rhobar_new, min=_DIV_GUARD)
        X = X + x_coeff.unsqueeze(0) * Hbar
        h_coeff = -(thetanew / torch.clamp(rho_new, min=_DIV_GUARD))
        H = V + H * h_coeff.unsqueeze(0)

        # Propagate state for next iteration
        state = out[:_STATE_SIZE]

        # Phase 4: Per-column convergence — single .item() sync per iteration.
        # No not_yet.any() guard: the torch.where inside _check_convergence_batched
        # already protects converged columns, and the guard would add a second
        # host-device sync that costs more than the fused tensor ops it skips.
        normx_t = torch.linalg.norm(X, dim=0)  # (K,)
        normr_t = out[_O_NORMR]  # (K,)
        normA_t = out[_O_NORMA]  # (K,)
        normb_t = out[_I_NORMB]  # (K,)

        safe_normb = torch.clamp(normb_t, min=_DIV_GUARD)
        test1_t = normr_t / safe_normb
        safe_normA_normr = torch.clamp(normA_t * normr_t, min=_DIV_GUARD)
        test2_t = out[_O_NORMAR] / safe_normA_normr
        test3_t = 1.0 / out[_O_CONDA]
        t1_t = test1_t / (1.0 + normA_t * normx_t / safe_normb)
        rtol_t = btol + atol * normA_t * normx_t / safe_normb

        istop = _check_convergence_batched(
            istop,
            test1_t,
            rtol_t,
            test2_t,
            test3_t,
            t1_t,
            atol,
            ctol,
            K,
            device,
        )

        # Single .item() sync: check if all columns have converged
        if (istop > 0).all().item():
            break

    istop = _mark_maxiter_batched(istop, itn, maxiter)

    # Handle case where loop never ran
    if itn == 0:
        return (
            X,
            istop,
            0,
            normb,
            normar_init,
            alpha,
            torch.ones(K, device=device, dtype=dtype),
            torch.zeros(K, device=device, dtype=dtype),
        )

    return (
        X,
        istop,
        itn,
        out[_O_NORMR],
        out[_O_NORMAR],
        out[_O_NORMA],
        out[_O_CONDA],
        normx_t,  # reuse from last iteration, avoids redundant norm
    )


# ===========================================================================
# Public API — dispatcher
# ===========================================================================


@torch.no_grad()
def lsmr_torch(
    A,
    b: torch.Tensor,
    damp: float = 0.0,
    atol: float = 1e-8,
    btol: float = 1e-8,
    conlim: float = 1e8,
    maxiter: int | None = None,
    use_compile: bool | None = None,
) -> tuple[torch.Tensor, int, int, float, float, float, float, float]:
    """
    LSMR solver — unified entry point.

    Auto-selects implementation based on device:
    - CUDA: compiled (torch.compile fuses scalar step into 1 kernel)
    - CPU/MPS: scalar (Python-float math, no compilation overhead)

    Pass use_compile=True to force compilation on any device.
+ """ + device = b.device + if use_compile is None: + use_compile = device.type == "cuda" + + if use_compile: + return _lsmr_compiled( + A, + b, + damp=damp, + atol=atol, + btol=btol, + conlim=conlim, + maxiter=maxiter, + use_compile=True, + ) + return _lsmr_eager( + A, + b, + damp=damp, + atol=atol, + btol=btol, + conlim=conlim, + maxiter=maxiter, + ) + + +@torch.no_grad() +def lsmr_torch_batched( + A, + B: torch.Tensor, + damp: float = 0.0, + atol: float = 1e-8, + btol: float = 1e-8, + conlim: float = 1e8, + maxiter: int | None = None, + use_compile: bool | None = None, +) -> tuple[ + torch.Tensor, + torch.Tensor, + int, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + """ + Batched LSMR solver — solve K right-hand sides simultaneously via SpMM. + + Solves ``min ||B - A X||_F`` where B has K columns. Instead of K + sequential sparse matrix-vector products (SpMV), each iteration uses + a single sparse matrix-matrix product (SpMM), which loads the sparse + matrix once and streams through K dense columns. For K >= 2 on GPU, + this is significantly faster than K sequential ``lsmr_torch`` calls. + + All K columns run in lock-step — converged columns continue doing + harmless arithmetic until all columns converge or maxiter is reached. + + When ``use_compile=True`` (auto-detected for CUDA/MPS), scalar Givens + rotations are fused into a single compiled kernel via ``torch.compile``, + further reducing per-iteration kernel launches. + + Parameters + ---------- + A : sparse tensor or LinearOperator-like + Matrix of shape (m, n). Must support ``A @ V`` for dense V of + shape (n, K). For ``_PreconditionedSparse``, this requires an + ``mm()`` method. + B : torch.Tensor + Dense RHS matrix of shape (m, K). + damp : float + Damping factor for regularized least-squares. + atol, btol : float + Stopping tolerances (applied identically to all columns). + conlim : float + Condition number limit. + maxiter : int or None + Maximum iterations. 
Defaults to min(m, n). + use_compile : bool or None + Whether to use ``torch.compile`` for scalar step fusion. + ``None`` (default) auto-detects: compiled for CUDA, eager for + CPU/MPS. Pass ``True`` to force compilation on MPS. + + Returns + ------- + X : torch.Tensor, shape (n, K) + Solution matrix. + istop : torch.Tensor, shape (K,), dtype long + Per-column stopping reason (0-7, same codes as ``lsmr_torch``). + itn : int + Number of iterations used (max across all columns). + normr : torch.Tensor, shape (K,) + Per-column ``||b_k - A x_k||``. + normar : torch.Tensor, shape (K,) + Per-column ``||A^T(b_k - A x_k)||``. + normA : torch.Tensor, shape (K,) + Per-column estimate of Frobenius norm of A. + condA : torch.Tensor, shape (K,) + Per-column estimate of condition number of A. + normx : torch.Tensor, shape (K,) + Per-column ``||x_k||``. + """ + device = B.device + if use_compile is None: + use_compile = device.type == "cuda" + + if use_compile: + return _lsmr_compiled_batched( + A, + B, + damp=damp, + atol=atol, + btol=btol, + conlim=conlim, + maxiter=maxiter, + use_compile=True, + ) + return _lsmr_batched( + A, + B, + damp=damp, + atol=atol, + btol=btol, + conlim=conlim, + maxiter=maxiter, + ) diff --git a/tests/test_batched_lsmr.py b/tests/test_batched_lsmr.py new file mode 100644 index 000000000..300f9c18e --- /dev/null +++ b/tests/test_batched_lsmr.py @@ -0,0 +1,538 @@ +""" +Tests for batched LSMR solver (K right-hand sides via SpMM). + +Verifies that lsmr_torch_batched produces identical results to K sequential +lsmr_torch calls, and that the batched demeaning path matches the sequential one. 
+ +Usage: + KMP_DUPLICATE_LIB_OK=TRUE pixi run -e dev pytest tests/test_batched_lsmr.py -v -s +""" + +from __future__ import annotations + +import numpy as np +import pyhdfe +import pytest + +torch = pytest.importorskip("torch") + +from pyfixest.estimation.torch.demean_torch_ import ( # noqa: E402 + _build_sparse_dummy, + _PreconditionedSparse, + demean_torch, +) +from pyfixest.estimation.torch.lsmr_torch import ( # noqa: E402 + lsmr_torch, + lsmr_torch_batched, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_sparse_problem(m: int, n: int, density: float = 0.1, seed: int = 42): + """Create a sparse CSR system A and dense rhs b. + + Default density=0.1 ensures the system is well-conditioned enough + for LSMR to converge within min(m, n) iterations. + """ + rng = np.random.default_rng(seed) + nnz = int(m * n * density) + rows = rng.integers(0, m, nnz) + cols = rng.integers(0, n, nnz) + vals = rng.standard_normal(nnz) + + A_coo = torch.sparse_coo_tensor( + torch.tensor(np.stack([rows, cols])), + torch.tensor(vals, dtype=torch.float64), + size=(m, n), + ) + A_csr = A_coo.to_sparse_csr() + return A_csr + + +def _make_rhs(m: int, K: int, seed: int = 42) -> torch.Tensor: + """Create a dense RHS matrix B of shape (m, K).""" + rng = np.random.default_rng(seed) + return torch.tensor(rng.standard_normal((m, K)), dtype=torch.float64) + + +# --------------------------------------------------------------------------- +# Core batched LSMR tests +# --------------------------------------------------------------------------- + + +class TestBatchedLSMR: + """Unit tests for lsmr_torch_batched.""" + + def test_matches_sequential(self): + """K=5: batched result should match K sequential lsmr_torch calls. + + Tolerance is 1e-6 because the batched path uses vectorized tensor + Givens rotations while sequential uses Python-float math. 
Accumulated + rounding differences over many iterations produce ~1e-7 divergence. + """ + m, n, K = 200, 100, 5 + A = _make_sparse_problem(m, n) + B = _make_rhs(m, K, seed=123) + + # Batched solve + X_batch, istop_batch, itn_batch, *_ = lsmr_torch_batched(A, B) + + # Sequential solve + for k in range(K): + x_seq, istop_seq, itn_seq, *_ = lsmr_torch(A, B[:, k]) + assert torch.allclose(X_batch[:, k], x_seq, atol=1e-6, rtol=1e-6), ( + f"Column {k}: max diff = " + f"{torch.max(torch.abs(X_batch[:, k] - x_seq)).item():.2e}" + ) + + @pytest.mark.parametrize("K", [2, 10, 20]) + def test_matches_sequential_various_K(self, K): + """Batched matches sequential for various K values.""" + m, n = 300, 150 + A = _make_sparse_problem(m, n, seed=77) + B = _make_rhs(m, K, seed=88) + + X_batch, *_ = lsmr_torch_batched(A, B) + + for k in range(K): + x_seq, *_ = lsmr_torch(A, B[:, k]) + assert torch.allclose(X_batch[:, k], x_seq, atol=1e-6, rtol=1e-6), ( + f"K={K}, col {k}: max diff = " + f"{torch.max(torch.abs(X_batch[:, k] - x_seq)).item():.2e}" + ) + + def test_single_column(self): + """K=1 batched should match single-RHS lsmr_torch.""" + m, n = 200, 100 + A = _make_sparse_problem(m, n) + b = _make_rhs(m, 1, seed=42) + + X_batch, istop_batch, itn_batch, *_ = lsmr_torch_batched(A, b) + x_single, istop_single, itn_single, *_ = lsmr_torch(A, b[:, 0]) + + assert torch.allclose(X_batch[:, 0], x_single, atol=1e-6) + + def test_convergence_per_column(self): + """Columns with different difficulty should converge independently.""" + m, n = 200, 100 + A = _make_sparse_problem(m, n) + + rng = np.random.default_rng(999) + # Column 0: easy (small values), Column 1: harder (large values) + B = torch.zeros(m, 2, dtype=torch.float64) + B[:, 0] = torch.tensor(rng.standard_normal(m) * 0.01, dtype=torch.float64) + B[:, 1] = torch.tensor(rng.standard_normal(m) * 100.0, dtype=torch.float64) + + X, istop, itn, *_ = lsmr_torch_batched(A, B) + + # Both should converge (istop in 1-3) + assert (istop >= 
1).all() and (istop <= 3).all(), ( + f"Not all columns converged: istop = {istop}" + ) + # Solutions should match sequential (convergence correctness) + for k in range(2): + x_seq, *_ = lsmr_torch(A, B[:, k]) + assert torch.allclose(X[:, k], x_seq, atol=1e-6, rtol=1e-6), ( + f"Column {k}: batched vs sequential max diff = " + f"{torch.max(torch.abs(X[:, k] - x_seq)).item():.2e}" + ) + + def test_zero_rhs_column(self): + """B with a zero column should produce zero solution for that column.""" + m, n = 100, 50 + A = _make_sparse_problem(m, n) + B = _make_rhs(m, 3, seed=42) + B[:, 1] = 0.0 # Zero out middle column + + X, istop, *_ = lsmr_torch_batched(A, B) + + assert torch.allclose( + X[:, 1], torch.zeros(n, dtype=torch.float64), atol=1e-12 + ), ( + f"Zero-RHS column has non-zero solution: ||x|| = {torch.norm(X[:, 1]).item()}" + ) + + def test_all_zero_rhs(self): + """All-zero B should return all-zero X with istop=0.""" + m, n, K = 100, 50, 3 + A = _make_sparse_problem(m, n) + B = torch.zeros(m, K, dtype=torch.float64) + + X, istop, itn, *_ = lsmr_torch_batched(A, B) + + assert torch.allclose(X, torch.zeros(n, K, dtype=torch.float64), atol=1e-12) + assert itn == 0 + + def test_damp(self): + """Damped batched should match damped sequential.""" + m, n, K = 200, 100, 3 + A = _make_sparse_problem(m, n) + B = _make_rhs(m, K, seed=55) + damp = 5.0 + + X_batch, *_ = lsmr_torch_batched(A, B, damp=damp) + + for k in range(K): + x_seq, *_ = lsmr_torch(A, B[:, k], damp=damp) + assert torch.allclose(X_batch[:, k], x_seq, atol=1e-6, rtol=1e-6), ( + f"Damped col {k}: max diff = " + f"{torch.max(torch.abs(X_batch[:, k] - x_seq)).item():.2e}" + ) + + def test_overdetermined_known_solution(self): + """Overdetermined system with exact solution, K=3 columns.""" + A_dense = torch.tensor( + [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]], dtype=torch.float64 + ) + A_sparse = A_dense.to_sparse_csr() + + X_true = torch.tensor([[1.0, 2.0, -1.0], [-1.0, 0.5, 3.0]], dtype=torch.float64) + B = A_dense @ 
X_true # (3, 3) + + X_sol, istop, *_ = lsmr_torch_batched(A_sparse, B) + + assert torch.allclose(X_sol, X_true, atol=1e-10), ( + f"Solution mismatch: max diff = " + f"{torch.max(torch.abs(X_sol - X_true)).item():.2e}" + ) + assert ((istop >= 1) & (istop <= 3)).all() + + def test_maxiter_exhaustion(self): + """Forcing maxiter=2 should return istop=7 for all columns.""" + m, n, K = 100, 50, 3 + A = _make_sparse_problem(m, n) + B = _make_rhs(m, K, seed=42) + + _, istop, itn, *_ = lsmr_torch_batched(A, B, maxiter=2, atol=1e-15, btol=1e-15) + + assert (istop == 7).all(), f"Expected istop=7, got {istop}" + assert itn == 2 + + def test_return_shapes(self): + """Verify all return values have correct shapes.""" + m, n, K = 200, 100, 5 + A = _make_sparse_problem(m, n) + B = _make_rhs(m, K) + + X, istop, itn, normr, normar, normA, condA, normx = lsmr_torch_batched(A, B) + + assert X.shape == (n, K) + assert istop.shape == (K,) + assert isinstance(itn, int) + assert normr.shape == (K,) + assert normar.shape == (K,) + assert normA.shape == (K,) + assert condA.shape == (K,) + assert normx.shape == (K,) + + +# --------------------------------------------------------------------------- +# PreconditionedSparse.mm() tests +# --------------------------------------------------------------------------- + + +class TestPreconditionedSparseMM: + """Tests for _PreconditionedSparse.mm() batched matvec.""" + + def test_mm_matches_mv_loop(self): + """mm(V) should equal column-wise mv(v_k) for each k.""" + rng = np.random.default_rng(42) + N, G, K = 200, 20, 5 + + flist = rng.choice(G, (N, 1)).astype(np.uint64) + D = _build_sparse_dummy(flist, torch.device("cpu"), torch.float64) + M_inv = 1.0 / torch.sqrt( + torch.tensor(np.bincount(flist[:, 0], minlength=G), dtype=torch.float64) + ) + + A = _PreconditionedSparse(D, M_inv) + V = torch.randn(G, K, dtype=torch.float64) + + # Batched + result_mm = A.mm(V) + + # Sequential + for k in range(K): + result_mv = A.mv(V[:, k]) + assert 
torch.allclose(result_mm[:, k], result_mv, atol=1e-12), ( + f"mm vs mv mismatch at column {k}" + ) + + def test_mm_transpose_matches_mv_transpose(self): + """A^T.mm(U) should match column-wise A^T.mv(u_k).""" + rng = np.random.default_rng(42) + N, G, K = 200, 20, 5 + + flist = rng.choice(G, (N, 1)).astype(np.uint64) + D = _build_sparse_dummy(flist, torch.device("cpu"), torch.float64) + M_inv = 1.0 / torch.sqrt( + torch.tensor(np.bincount(flist[:, 0], minlength=G), dtype=torch.float64) + ) + + At = _PreconditionedSparse(D, M_inv).t() + U = torch.randn(N, K, dtype=torch.float64) + + result_mm = At.mm(U) + + for k in range(K): + result_mv = At.mv(U[:, k]) + assert torch.allclose(result_mm[:, k], result_mv, atol=1e-12), ( + f"transpose mm vs mv mismatch at column {k}" + ) + + +# --------------------------------------------------------------------------- +# Integration: batched demeaning +# --------------------------------------------------------------------------- + + +class TestDemeanBatched: + """Verify batched demeaning path produces correct results.""" + + def test_batched_demean_matches_pyhdfe(self): + """Batched demean (K>1) should match pyhdfe reference.""" + rng = np.random.default_rng(929291) + N, K = 1000, 10 + x = rng.normal(0, 1, (N, K)) + flist = np.column_stack( + [ + rng.choice(10, N), + rng.choice(10, N), + ] + ).astype(np.uint64) + weights = np.ones(N) + + algorithm = pyhdfe.create(flist) + res_pyhdfe = algorithm.residualize(x) + + res_torch, success = demean_torch(x, flist, weights, tol=1e-10) + assert success, "Batched demean did not converge" + np.testing.assert_allclose( + res_torch, + res_pyhdfe, + rtol=1e-6, + atol=1e-8, + err_msg="Batched demean vs pyhdfe mismatch", + ) + + def test_batched_demean_weighted_matches_pyhdfe(self): + """Weighted batched demean should match pyhdfe.""" + rng = np.random.default_rng(929291) + N, K = 1000, 5 + x = rng.normal(0, 1, (N, K)) + flist = np.column_stack( + [ + rng.choice(10, N), + rng.choice(10, N), + ] + 
).astype(np.uint64) + weights = rng.uniform(0.1, 2.0, N) + + algorithm = pyhdfe.create(flist) + res_pyhdfe = algorithm.residualize(x, weights.reshape(N, 1)) + + res_torch, success = demean_torch(x, flist, weights, tol=1e-10) + assert success, "Weighted batched demean did not converge" + np.testing.assert_allclose( + res_torch, + res_pyhdfe, + rtol=1e-6, + atol=1e-8, + err_msg="Weighted batched demean vs pyhdfe mismatch", + ) + + def test_single_column_still_works(self): + """K=1 demean should still work (uses sequential path).""" + rng = np.random.default_rng(42) + N = 500 + x = rng.normal(0, 1, (N, 1)) + flist = rng.choice(10, N).astype(np.uint64).reshape(N, 1) + weights = np.ones(N) + + algorithm = pyhdfe.create(flist) + res_pyhdfe = algorithm.residualize(x) + + res_torch, success = demean_torch(x, flist, weights, tol=1e-10) + assert success + np.testing.assert_allclose(res_torch, res_pyhdfe, rtol=1e-6, atol=1e-8) + + def test_many_columns(self): + """K=50 batched demean should match pyhdfe.""" + rng = np.random.default_rng(9999) + N, K = 2000, 50 + x = rng.normal(0, 1, (N, K)) + flist = rng.choice(20, (N, 2)).astype(np.uint64) + weights = np.ones(N) + + algorithm = pyhdfe.create(flist) + res_pyhdfe = algorithm.residualize(x) + + res_torch, success = demean_torch(x, flist, weights, tol=1e-10) + assert success + np.testing.assert_allclose(res_torch, res_pyhdfe, rtol=1e-5, atol=1e-7) + + +# --------------------------------------------------------------------------- +# Compiled batched LSMR tests +# --------------------------------------------------------------------------- + +HAS_MPS = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() + + +class TestCompiledBatchedLSMR: + """Tests for the compiled batched LSMR path (_lsmr_compiled_batched).""" + + def test_compiled_batched_matches_eager_batched(self): + """Packed-state batched matches vectorized eager batched on CPU f64, K=5. 
+ + Uses use_compile=False to avoid Inductor C++ backend issues on macOS. + This still validates the _lsmr_compiled_batched logic (packed state, + _scalar_step on (STATE_SIZE, K) tensors, vector updates with clamp guards). + Actual torch.compile fusion is tested in the MPS tests below. + """ + from pyfixest.estimation.torch.lsmr_torch import ( + _lsmr_batched, + _lsmr_compiled_batched, + ) + + m, n, K = 200, 100, 5 + A = _make_sparse_problem(m, n) + B = _make_rhs(m, K, seed=123) + + X_eager, istop_eager, itn_eager, *_ = _lsmr_batched(A, B) + X_comp, istop_comp, itn_comp, *_ = _lsmr_compiled_batched( + A, B, use_compile=False + ) + + assert torch.allclose(X_eager, X_comp, atol=1e-6, rtol=1e-6), ( + f"max diff = {torch.max(torch.abs(X_eager - X_comp)).item():.2e}" + ) + assert itn_eager == itn_comp, f"itn: {itn_eager} vs {itn_comp}" + + @pytest.mark.parametrize("K", [2, 10, 20]) + def test_compiled_matches_sequential_various_K(self, K): + """Packed-state batched matches K sequential lsmr_torch calls.""" + from pyfixest.estimation.torch.lsmr_torch import _lsmr_compiled_batched + + m, n = 300, 150 + A = _make_sparse_problem(m, n, seed=77) + B = _make_rhs(m, K, seed=88) + + X_comp, *_ = _lsmr_compiled_batched(A, B, use_compile=False) + + for k in range(K): + x_seq, *_ = lsmr_torch(A, B[:, k]) + assert torch.allclose(X_comp[:, k], x_seq, atol=1e-6, rtol=1e-6), ( + f"K={K}, col {k}: max diff = " + f"{torch.max(torch.abs(X_comp[:, k] - x_seq)).item():.2e}" + ) + + def test_compiled_single_column(self): + """Packed-state K=1 matches single-RHS lsmr_torch.""" + from pyfixest.estimation.torch.lsmr_torch import _lsmr_compiled_batched + + m, n = 200, 100 + A = _make_sparse_problem(m, n) + b = _make_rhs(m, 1, seed=42) + + X_comp, *_ = _lsmr_compiled_batched(A, b, use_compile=False) + x_single, *_ = lsmr_torch(A, b[:, 0]) + + assert torch.allclose(X_comp[:, 0], x_single, atol=1e-6), ( + f"max diff = {torch.max(torch.abs(X_comp[:, 0] - x_single)).item():.2e}" + ) + + def 
test_compiled_zero_rhs_column(self): + """Zero RHS column handled correctly in packed-state path.""" + from pyfixest.estimation.torch.lsmr_torch import _lsmr_compiled_batched + + m, n = 100, 50 + A = _make_sparse_problem(m, n) + B = _make_rhs(m, 3, seed=42) + B[:, 1] = 0.0 # Zero out middle column + + X, istop, *_ = _lsmr_compiled_batched(A, B, use_compile=False) + + assert torch.allclose( + X[:, 1], torch.zeros(n, dtype=torch.float64), atol=1e-12 + ), ( + f"Zero-RHS column has non-zero solution: ||x|| = {torch.norm(X[:, 1]).item()}" + ) + + # Non-zero columns should still solve correctly + for k in [0, 2]: + x_seq, *_ = lsmr_torch(A, B[:, k]) + assert torch.allclose(X[:, k], x_seq, atol=1e-6, rtol=1e-6) + + def test_compiled_damp(self): + """Damped packed-state batched matches damped sequential.""" + from pyfixest.estimation.torch.lsmr_torch import _lsmr_compiled_batched + + m, n, K = 200, 100, 3 + A = _make_sparse_problem(m, n) + B = _make_rhs(m, K, seed=55) + damp = 5.0 + + X_comp, *_ = _lsmr_compiled_batched(A, B, damp=damp, use_compile=False) + + for k in range(K): + x_seq, *_ = lsmr_torch(A, B[:, k], damp=damp) + assert torch.allclose(X_comp[:, k], x_seq, atol=1e-6, rtol=1e-6), ( + f"Damped col {k}: max diff = " + f"{torch.max(torch.abs(X_comp[:, k] - x_seq)).item():.2e}" + ) + + def test_compiled_all_zero_rhs(self): + """All-zero B returns all-zero X in packed-state path.""" + from pyfixest.estimation.torch.lsmr_torch import _lsmr_compiled_batched + + m, n, K = 100, 50, 3 + A = _make_sparse_problem(m, n) + B = torch.zeros(m, K, dtype=torch.float64) + + X, istop, itn, *_ = _lsmr_compiled_batched(A, B, use_compile=False) + + assert torch.allclose(X, torch.zeros(n, K, dtype=torch.float64), atol=1e-12) + assert itn == 0 + + # --- MPS-specific tests --- + + @pytest.mark.skipif(not HAS_MPS, reason="MPS not available") + def test_compiled_mps_correctness(self): + """Compiled batched on MPS f32 vs CPU f64 reference.""" + m, n, K = 300, 150, 5 + A_cpu = 
_make_sparse_problem(m, n, density=0.1, seed=42) + B_cpu = _make_rhs(m, K, seed=123) + + # CPU f64 reference (eager) + X_ref, *_ = lsmr_torch_batched(A_cpu, B_cpu, use_compile=False) + + # MPS f32 compiled + A_mps = A_cpu.to(torch.float32).to_dense().to("mps") + B_mps = B_cpu.to(torch.float32).to("mps") + + X_mps, *_ = lsmr_torch_batched(A_mps, B_mps, use_compile=True) + + max_diff = torch.max(torch.abs(X_ref.float() - X_mps.cpu())).item() + assert max_diff < 0.1, ( + f"MPS f32 compiled vs CPU f64 too different: {max_diff:.2e}" + ) + + @pytest.mark.skipif(not HAS_MPS, reason="MPS not available") + def test_compiled_vs_uncompiled_mps(self): + """Compiled and uncompiled give same results on MPS.""" + m, n, K = 300, 150, 5 + A_cpu = _make_sparse_problem(m, n, density=0.1, seed=42) + B_cpu = _make_rhs(m, K, seed=123) + + A_mps = A_cpu.to(torch.float32).to_dense().to("mps") + B_mps = B_cpu.to(torch.float32).to("mps") + + X_comp, *_ = lsmr_torch_batched(A_mps, B_mps, use_compile=True) + X_nocomp, *_ = lsmr_torch_batched(A_mps, B_mps, use_compile=False) + + max_diff = torch.max(torch.abs(X_comp - X_nocomp)).item() + assert max_diff < 1e-4, f"Compiled vs uncompiled differ on MPS: {max_diff:.2e}" diff --git a/tests/test_demean.py b/tests/test_demean.py index 0d161a5c6..0e47d76bd 100644 --- a/tests/test_demean.py +++ b/tests/test_demean.py @@ -3,9 +3,11 @@ import pyhdfe import pytest +import pyfixest as pf from pyfixest.core import demean as demean_rs from pyfixest.core.demean import demean_within from pyfixest.estimation.cupy.demean_cupy_ import demean_cupy32, demean_cupy64 +from pyfixest.estimation.internals.backends import BACKENDS from pyfixest.estimation.internals.demean_ import ( _set_demeaner_backend, demean, @@ -46,22 +48,21 @@ def test_demean(benchmark, demean_func): def test_set_demeaner_backend(): - # Test numba backend - demean_func = _set_demeaner_backend("numba") - assert demean_func == demean - - # Test jax backend - demean_func = _set_demeaner_backend("jax") 
- assert demean_func == demean_jax - - demean_func = _set_demeaner_backend("rust") - assert demean_func == demean_rs - - demean_func = _set_demeaner_backend("cupy32") - assert demean_func == demean_cupy32 - - demean_func = _set_demeaner_backend("cupy64") - assert demean_func == demean_cupy64 + for backend in [ + "numba", + "jax", + "rust", + "cupy32", + "cupy64", + "scipy", + "torch", + "torch_cpu", + "torch_mps", + "torch_cuda", + "torch_cuda32", + ]: + demean_func = _set_demeaner_backend(backend) + assert demean_func == BACKENDS[backend]["demean"] demean_func = _set_demeaner_backend("rust-cg") assert demean_func == demean_within @@ -71,6 +72,42 @@ def test_set_demeaner_backend(): _set_demeaner_backend("invalid") +def test_feols_torch_backend_matches_numba(): + pytest.importorskip("torch") + + rng = np.random.default_rng(12345) + N = 400 + df = pd.DataFrame( + { + "y": rng.normal(size=N), + "x1": rng.normal(size=N), + "x2": rng.normal(size=N), + "f1": rng.integers(0, 20, size=N), + "f2": rng.integers(0, 15, size=N), + } + ) + + fit_numba = pf.feols( + "y ~ x1 + x2 | f1 + f2", + data=df, + fixef_tol=1e-8, + demeaner_backend="numba", + ) + fit_torch = pf.feols( + "y ~ x1 + x2 | f1 + f2", + data=df, + fixef_tol=1e-8, + demeaner_backend="torch", + ) + + np.testing.assert_allclose( + fit_torch.coef().sort_index().values, + fit_numba.coef().sort_index().values, + rtol=1e-7, + atol=1e-9, + ) + + @pytest.mark.parametrize( argnames="demean_func", argvalues=[demean, demean_jax, demean_rs, demean_cupy32, demean_cupy64], diff --git a/tests/test_lsmr_compiled.py b/tests/test_lsmr_compiled.py new file mode 100644 index 000000000..e785bf114 --- /dev/null +++ b/tests/test_lsmr_compiled.py @@ -0,0 +1,195 @@ +""" +Tests for lsmr_torch (compiled version): correctness, auto-detection, MPS torch.compile. 
+ +Usage: + KMP_DUPLICATE_LIB_OK=TRUE pixi run -e dev python -m pytest tests/test_lsmr_compiled.py -v -s +""" + +from __future__ import annotations + +import time + +import numpy as np +import pytest +import torch + +# Reference: original scalar-state LSMR (for correctness comparison) +from pyfixest.estimation.torch.lsmr_torch import _lsmr_eager as lsmr_torch_original +from pyfixest.estimation.torch.lsmr_torch import lsmr_torch + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_sparse_problem(m: int, n: int, density: float = 0.01, seed: int = 42): + """Create a sparse CSR system A and dense rhs b.""" + rng = np.random.default_rng(seed) + nnz = int(m * n * density) + rows = rng.integers(0, m, nnz) + cols = rng.integers(0, n, nnz) + vals = rng.standard_normal(nnz) + + A_coo = torch.sparse_coo_tensor( + torch.tensor(np.stack([rows, cols])), + torch.tensor(vals, dtype=torch.float64), + size=(m, n), + ) + A_csr = A_coo.to_sparse_csr() + b = torch.tensor(rng.standard_normal(m), dtype=torch.float64) + return A_csr, b + + +# --------------------------------------------------------------------------- +# Correctness tests (CPU, f64 - auto-detects use_compile=False) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("m,n", [(200, 100), (500, 300), (1000, 500)]) +def test_matches_original_cpu(m, n): + """New lsmr_torch (auto CPU mode) matches original on CPU f64.""" + A, b = _make_sparse_problem(m, n) + + x_orig, _istop_orig, itn_orig, *_ = lsmr_torch_original(A, b) + x_new, _istop_new, itn_new, *_ = lsmr_torch(A, b) + + assert torch.allclose(x_orig, x_new, atol=1e-6, rtol=1e-6), ( + f"Solutions differ: max_diff={torch.max(torch.abs(x_orig - x_new)).item():.2e}" + ) + assert itn_orig == itn_new, f"itn differs: {itn_orig} vs {itn_new}" + + +def test_diagnostics_match_original(): + 
"""normr, normar, normA, condA, normx diagnostics match reference.""" + A, b = _make_sparse_problem(500, 300) + + _x_o, istop_orig, _itn_o, normr_o, normar_o, normA_o, _condA_o, _normx_o = ( + lsmr_torch_original(A, b) + ) + _x_n, istop_new, _itn_n, normr_n, normar_n, normA_n, _condA_n, _normx_n = ( + lsmr_torch(A, b) + ) + + assert istop_orig == istop_new, f"istop: {istop_orig} vs {istop_new}" + assert abs(normr_o - normr_n) / max(normr_o, 1e-30) < 1e-6, ( + f"normr: {normr_o:.6e} vs {normr_n:.6e}" + ) + assert abs(normar_o - normar_n) / max(normar_o, 1e-30) < 1e-6 + assert abs(normA_o - normA_n) / max(normA_o, 1e-30) < 1e-6 + + +def test_btol_convergence(): + """btol-based stopping (istop=1) should work and match reference.""" + A, b = _make_sparse_problem(200, 100) + + # Use tight btol but very loose atol - should converge via test1, not test2 + _, istop_orig, itn_orig, *_ = lsmr_torch_original(A, b, atol=1e-2, btol=1e-10) + _, istop_new, itn_new, *_ = lsmr_torch(A, b, atol=1e-2, btol=1e-10) + + # Both should converge; compiled should match reference istop + assert istop_new == istop_orig, f"istop: {istop_orig} vs {istop_new}" + assert itn_new == itn_orig, f"itn: {itn_orig} vs {itn_new}" + + +# --------------------------------------------------------------------------- +# Auto-detection tests +# --------------------------------------------------------------------------- + + +def test_auto_cpu_defaults(): + """On CPU tensors, auto-selects use_compile=False.""" + A, b = _make_sparse_problem(200, 100) + x, istop, _itn, *_ = lsmr_torch(A, b) + assert x.device.type == "cpu" + assert istop in range(8) + + +# --------------------------------------------------------------------------- +# MPS + torch.compile tests +# --------------------------------------------------------------------------- + +HAS_MPS = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() + + +@pytest.mark.skipif(not HAS_MPS, reason="MPS not available") +def test_correctness_mps(): + 
"""lsmr_torch on MPS f32 produces reasonable results vs CPU f64.""" + m, n = 500, 300 + A_cpu, b_cpu = _make_sparse_problem(m, n) + + x_ref, *_ = lsmr_torch_original(A_cpu, b_cpu) + + # MPS f32 (auto: use_compile=False; pass use_compile=True to force) + A_mps = A_cpu.to(torch.float32).to_dense().to("mps") + b_mps = b_cpu.to(torch.float32).to("mps") + + x_mps, *_ = lsmr_torch(A_mps, b_mps) + + max_diff = torch.max(torch.abs(x_ref.float() - x_mps.cpu())).item() + assert max_diff < 0.1, f"MPS f32 vs CPU f64 too different: {max_diff:.2e}" + + +@pytest.mark.skipif(not HAS_MPS, reason="MPS not available") +def test_compiled_vs_uncompiled_mps(): + """Compiled and uncompiled give same results on MPS.""" + m, n = 500, 300 + A_cpu, b_cpu = _make_sparse_problem(m, n) + A_mps = A_cpu.to(torch.float32).to_dense().to("mps") + b_mps = b_cpu.to(torch.float32).to("mps") + + x_comp, *_ = lsmr_torch(A_mps, b_mps, use_compile=True) + x_nocomp, *_ = lsmr_torch(A_mps, b_mps, use_compile=False) + + max_diff = torch.max(torch.abs(x_comp - x_nocomp)).item() + assert max_diff < 1e-5, f"Compiled vs uncompiled differ: {max_diff:.2e}" + + +# --------------------------------------------------------------------------- +# Timing benchmark (run with -s to see output) +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not HAS_MPS, reason="MPS not available") +@pytest.mark.parametrize("m,n", [(5000, 2000), (10000, 5000)]) +def test_timing_mps(m, n): + """Compare compiled vs uncompiled LSMR timing on MPS.""" + rng = np.random.default_rng(42) + nnz = int(m * n * 0.005) + rows = rng.integers(0, m, nnz) + cols = rng.integers(0, n, nnz) + vals = rng.standard_normal(nnz).astype(np.float32) + + A = torch.zeros(m, n, dtype=torch.float32) + for r, c, v in zip(rows, cols, vals, strict=True): + A[r, c] += v + A = A.to("mps") + b = torch.tensor(rng.standard_normal(m).astype(np.float32), device="mps") + + # Warmup + lsmr_torch(A, b, use_compile=True) + 
lsmr_torch(A, b, use_compile=False) + torch.mps.synchronize() + + # Compiled + torch.mps.synchronize() + t0 = time.perf_counter() + x_comp, _, itn_comp, *_ = lsmr_torch(A, b, use_compile=True) + torch.mps.synchronize() + t_comp = time.perf_counter() - t0 + + # Uncompiled + torch.mps.synchronize() + t0 = time.perf_counter() + x_nocomp, _, itn_nocomp, *_ = lsmr_torch(A, b, use_compile=False) + torch.mps.synchronize() + t_nocomp = time.perf_counter() - t0 + + speedup = t_nocomp / t_comp if t_comp > 0 else float("inf") + print( + f"\n [{m}x{n}] compiled: {t_comp:.3f}s ({itn_comp} iters) | " + f"uncompiled: {t_nocomp:.3f}s ({itn_nocomp} iters) | " + f"speedup: {speedup:.2f}x" + ) + + # Correctness + assert torch.allclose(x_comp, x_nocomp, atol=1e-4, rtol=1e-4) diff --git a/tests/test_torch_lsmr.py b/tests/test_torch_lsmr.py new file mode 100644 index 000000000..f1aa796c9 --- /dev/null +++ b/tests/test_torch_lsmr.py @@ -0,0 +1,540 @@ +""" +Tests for standalone PyTorch LSMR solver and torch-based FWL demeaning. + +Three test levels: +1. Bare LSMR: verify the solver on known linear systems +2. Demeaning vs pyhdfe: compare demean_torch against pyhdfe reference +3. 
Demeaning vs demean_scipy: compare against SciPy LSMR-based backend +""" + +import numpy as np +import pyhdfe +import pytest + +torch = pytest.importorskip("torch") + +from pyfixest.estimation.cupy.demean_cupy_ import demean_scipy # noqa: E402 +from pyfixest.estimation.torch.demean_torch_ import demean_torch # noqa: E402 +from pyfixest.estimation.torch.lsmr_torch import lsmr_torch # noqa: E402 + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def demean_data(): + """Shared data fixture for all demeaning tests (matches test_demean.py pattern).""" + rng = np.random.default_rng(929291) + N = 1000 + M = 10 + x = rng.normal(0, 1, M * N).reshape((N, M)) + f1 = rng.choice(list(range(M)), N).reshape((N, 1)) + f2 = rng.choice(list(range(M)), N).reshape((N, 1)) + flist = np.concatenate((f1, f2), axis=1).astype(np.uint64) + # Weights drawn from the *same* advanced RNG (not a fresh seed) + weights = rng.uniform(0, 1, N) + return x, flist, weights + + +# --------------------------------------------------------------------------- +# Level 1: Bare LSMR tests +# --------------------------------------------------------------------------- + + +class TestLSMR: + """Unit tests for the pure-torch LSMR solver.""" + + def test_overdetermined_known_solution(self): + """Overdetermined system with exact solution: LSMR should recover it.""" + A_dense = torch.tensor( + [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]], dtype=torch.float64 + ) + x_true = torch.tensor([1.0, -1.0], dtype=torch.float64) + b = A_dense @ x_true + + A_sparse = A_dense.to_sparse_csr() + x_sol, istop, _itn, normr, _normar, *_ = lsmr_torch(A_sparse, b) + + assert istop in (1, 2), f"LSMR did not converge, istop={istop}" + assert torch.allclose(x_sol, x_true, atol=1e-10), ( + f"Solution mismatch: {x_sol} vs {x_true}" + ) + assert normr < 1e-10, f"Residual too large: 
{normr}" + + def test_underdetermined_min_norm(self): + """Underdetermined system (m < n): LSMR should find min-norm solution.""" + torch.manual_seed(42) + A_dense = torch.randn(2, 4, dtype=torch.float64) + b = torch.randn(2, dtype=torch.float64) + + A_sparse = A_dense.to_sparse_csr() + x_sol, _istop, _itn, _normr, _normar, *_ = lsmr_torch(A_sparse, b) + + residual = torch.norm(A_dense @ x_sol - b).item() + assert residual < 1e-8, f"Residual too large: {residual}" + + x_lstsq = torch.linalg.lstsq(A_dense, b).solution + assert torch.allclose(x_sol, x_lstsq, atol=1e-6), ( + f"Not min-norm: ||x_lsmr||={torch.norm(x_sol):.6f}, " + f"||x_lstsq||={torch.norm(x_lstsq):.6f}" + ) + + def test_larger_sparse_system(self): + """Larger sparse system to exercise the iteration loop.""" + rng = np.random.default_rng(123) + m, n = 200, 50 + density = 0.1 + + nnz = int(m * n * density) + rows = rng.integers(0, m, nnz) + cols = rng.integers(0, n, nnz) + vals = rng.standard_normal(nnz) + + indices = torch.tensor(np.stack([rows, cols]), dtype=torch.long) + values = torch.tensor(vals, dtype=torch.float64) + A_sparse = torch.sparse_coo_tensor(indices, values, (m, n)).to_sparse_csr() + + x_true = torch.tensor(rng.standard_normal(n), dtype=torch.float64) + A_dense = A_sparse.to_dense() + b = A_dense @ x_true + + x_sol, istop, _itn, _normr, _normar, *_ = lsmr_torch(A_sparse, b) + + assert istop in (1, 2, 3), f"LSMR did not converge, istop={istop}" + residual = torch.norm(A_dense @ x_sol - b).item() + assert residual < 1e-5, f"Residual too large: {residual}" + + def test_zero_rhs(self): + """B = 0 should give x = 0.""" + A_sparse = torch.eye(3, dtype=torch.float64).to_sparse_csr() + b = torch.zeros(3, dtype=torch.float64) + + x_sol, istop, _itn, _normr, _normar, *_ = lsmr_torch(A_sparse, b) + + assert torch.allclose(x_sol, torch.zeros(3, dtype=torch.float64), atol=1e-12) + assert istop == 0 + + def test_damped_regularization(self): + """Damping should shrink the solution toward zero.""" + 
A_sparse = torch.eye(3, dtype=torch.float64).to_sparse_csr() + b = torch.ones(3, dtype=torch.float64) + + x_undamped, *_ = lsmr_torch(A_sparse, b, damp=0.0) + x_damped, *_ = lsmr_torch(A_sparse, b, damp=10.0) + + assert torch.norm(x_damped) < torch.norm(x_undamped), ( + "Damped solution should have smaller norm" + ) + + def test_maxiter_exhaustion_returns_istop_7(self): + """Forcing maxiter=2 on an ill-conditioned system must return istop=7.""" + # Use a system that genuinely needs many iterations + # (identity converges in 1 step, so we use a harder matrix) + torch.manual_seed(99) + A_dense = torch.randn(20, 10, dtype=torch.float64) + b = torch.randn(20, dtype=torch.float64) + A_sparse = A_dense.to_sparse_csr() + + _, istop, itn, *_ = lsmr_torch(A_sparse, b, maxiter=2, atol=1e-15, btol=1e-15) + + assert istop == 7, f"Expected istop=7 (maxiter hit), got {istop}" + assert itn == 2 + + def test_full_return_tuple(self): + """Verify all 8 return values are present and sensible.""" + A_sparse = torch.eye(3, dtype=torch.float64).to_sparse_csr() + b = torch.ones(3, dtype=torch.float64) + + _x, _istop, _itn, _normr, _normar, normA, condA, normx = lsmr_torch(A_sparse, b) + + assert isinstance(normA, float) and normA > 0 + assert isinstance(condA, float) and condA >= 1.0 + assert isinstance(normx, float) and normx > 0 + + +# --------------------------------------------------------------------------- +# Level 2: Demeaning vs pyhdfe +# --------------------------------------------------------------------------- + + +class TestDemeanVsPyhdfe: + """Compare demean_torch against pyhdfe reference.""" + + def test_unweighted(self, demean_data): + """Verify unweighted demeaning matches pyhdfe.""" + x, flist, _ = demean_data + N = x.shape[0] + weights = np.ones(N) + + algorithm = pyhdfe.create(flist) + res_pyhdfe = algorithm.residualize(x) + + res_torch, success = demean_torch(x, flist, weights, tol=1e-10) + assert success, "demean_torch did not converge" + np.testing.assert_allclose( 
+ res_torch, + res_pyhdfe, + rtol=1e-6, + atol=1e-8, + err_msg="demean_torch vs pyhdfe mismatch (unweighted)", + ) + + def test_weighted(self, demean_data): + """Verify weighted demeaning matches pyhdfe.""" + x, flist, weights = demean_data + N = x.shape[0] + + algorithm = pyhdfe.create(flist) + res_pyhdfe = algorithm.residualize(x, weights.reshape(N, 1)) + + res_torch, success = demean_torch(x, flist, weights, tol=1e-10) + assert success, "demean_torch did not converge (weighted)" + np.testing.assert_allclose( + res_torch, + res_pyhdfe, + rtol=1e-6, + atol=1e-8, + err_msg="demean_torch vs pyhdfe mismatch (weighted)", + ) + + +# --------------------------------------------------------------------------- +# Level 3: Demeaning vs demean_scipy +# --------------------------------------------------------------------------- + + +class TestDemeanVsScipy: + """Compare demean_torch against demean_scipy (both use LSMR).""" + + def test_unweighted_vs_scipy(self, demean_data): + """Verify unweighted demeaning matches scipy LSMR.""" + x, flist, _ = demean_data + N = x.shape[0] + weights = np.ones(N) + + res_scipy, success_scipy = demean_scipy(x, flist, weights, tol=1e-10) + res_torch, success_torch = demean_torch(x, flist, weights, tol=1e-10) + + assert success_scipy, "demean_scipy did not converge" + assert success_torch, "demean_torch did not converge" + + # Different D construction (formulaic drops reference level, we don't) + # so tolerances are slightly looser than torch-vs-pyhdfe. + # Both produce valid demeaned residuals, but the underlying theta + # coefficients differ because the systems have different null spaces. 
+ np.testing.assert_allclose( + res_torch, + res_scipy, + rtol=1e-5, + atol=1e-7, + err_msg="demean_torch vs demean_scipy mismatch (unweighted)", + ) + + def test_weighted_vs_scipy(self, demean_data): + """Verify weighted demeaning matches scipy LSMR.""" + x, flist, weights = demean_data + + res_scipy, success_scipy = demean_scipy(x, flist, weights, tol=1e-10) + res_torch, success_torch = demean_torch(x, flist, weights, tol=1e-10) + + assert success_scipy, "demean_scipy did not converge" + assert success_torch, "demean_torch did not converge" + + np.testing.assert_allclose( + res_torch, + res_scipy, + rtol=1e-5, + atol=1e-7, + err_msg="demean_torch vs demean_scipy mismatch (weighted)", + ) + + +# --------------------------------------------------------------------------- +# Level 4: Edge cases for demean_torch +# --------------------------------------------------------------------------- + + +class TestDemeanEdgeCases: + """Edge case tests for demean_torch.""" + + def test_1d_x_input(self): + """1D x input should return 1D output with same shape.""" + rng = np.random.default_rng(42) + N = 100 + x_1d = rng.normal(0, 1, N) + flist = rng.choice(5, N).astype(np.uint64).reshape(N, 1) + weights = np.ones(N) + + res, success = demean_torch(x_1d, flist, weights, tol=1e-10) + assert success + assert res.ndim == 1, f"Expected 1D output, got shape {res.shape}" + assert res.shape == x_1d.shape + + def test_single_fe_factor(self): + """Single fixed effect factor should work and match pyhdfe.""" + rng = np.random.default_rng(42) + N = 200 + x = rng.normal(0, 1, (N, 3)) + flist = rng.choice(10, N).astype(np.uint64).reshape(N, 1) + weights = np.ones(N) + + algorithm = pyhdfe.create(flist) + res_pyhdfe = algorithm.residualize(x) + + res_torch, success = demean_torch(x, flist, weights, tol=1e-10) + assert success + np.testing.assert_allclose(res_torch, res_pyhdfe, rtol=1e-6, atol=1e-8) + + def test_three_fe_factors(self): + """Three FE factors should work and match pyhdfe.""" + rng 
= np.random.default_rng(42) + N = 500 + x = rng.normal(0, 1, (N, 2)) + f1 = rng.choice(8, N).reshape(N, 1) + f2 = rng.choice(6, N).reshape(N, 1) + f3 = rng.choice(4, N).reshape(N, 1) + flist = np.concatenate([f1, f2, f3], axis=1).astype(np.uint64) + weights = np.ones(N) + + algorithm = pyhdfe.create(flist) + res_pyhdfe = algorithm.residualize(x) + + res_torch, success = demean_torch(x, flist, weights, tol=1e-10) + assert success + np.testing.assert_allclose(res_torch, res_pyhdfe, rtol=1e-6, atol=1e-8) + + def test_non_contiguous_group_ids_raises(self): + """Non-contiguous group IDs (gaps) should raise ValueError.""" + N = 50 + x = np.ones((N, 1)) + # Groups [0, 5, 10] — non-contiguous, max+1=11 but only 3 unique + flist = np.array( + [0, 5, 10] * (N // 3) + [0] * (N % 3), dtype=np.uint64 + ).reshape(N, 1) + weights = np.ones(N) + + with pytest.raises(ValueError, match="non-contiguous group IDs"): + demean_torch(x, flist, weights) + + def test_zero_weight_group_raises(self): + """If any group has zero total weight, should raise ValueError.""" + N = 50 + x = np.ones((N, 1)) + # All observations in group 0 get zero weight + flist = np.zeros(N, dtype=np.uint64).reshape(N, 1) + flist[N // 2 :, 0] = 1 + weights = np.ones(N) + weights[: N // 2] = 0.0 # group 0 has zero total weight + + with pytest.raises(ValueError, match="zero total weight"): + demean_torch(x, flist, weights) + + def test_flist_none_raises(self): + """flist=None should raise ValueError.""" + with pytest.raises(ValueError, match="flist cannot be None"): + demean_torch(np.ones((10, 2)), flist=None, weights=np.ones(10)) + + +# --------------------------------------------------------------------------- +# Level 5: High-dimensional stress tests +# --------------------------------------------------------------------------- + + +class TestHighDimensional: + """Stress tests with larger N, many groups, and unbalanced designs.""" + + def test_many_groups(self): + """N=10K, 200 groups per factor — exercises LSMR on 
a larger system.""" + rng = np.random.default_rng(7777) + N = 10_000 + G = 200 + x = rng.normal(0, 1, (N, 3)) + f1 = rng.choice(G, N).reshape(N, 1) + f2 = rng.choice(G, N).reshape(N, 1) + flist = np.concatenate([f1, f2], axis=1).astype(np.uint64) + weights = np.ones(N) + + algorithm = pyhdfe.create(flist) + res_pyhdfe = algorithm.residualize(x) + + res_torch, success = demean_torch(x, flist, weights, tol=1e-10) + assert success, "demean_torch did not converge on high-D problem" + np.testing.assert_allclose(res_torch, res_pyhdfe, rtol=1e-5, atol=1e-7) + + def test_unbalanced_groups(self): + """Highly unbalanced groups: one large group, many tiny groups. + + This stresses the diagonal preconditioner because group_weights + vary by orders of magnitude. + """ + rng = np.random.default_rng(8888) + N = 5_000 + # Group 0 gets 80% of observations, groups 1-49 share the rest + groups = np.zeros(N, dtype=np.uint64) + small_start = int(N * 0.8) + groups[small_start:] = rng.choice(49, N - small_start).astype(np.uint64) + 1 + flist = groups.reshape(N, 1) + + x = rng.normal(0, 1, (N, 2)) + weights = np.ones(N) + + algorithm = pyhdfe.create(flist) + res_pyhdfe = algorithm.residualize(x) + + res_torch, success = demean_torch(x, flist, weights, tol=1e-10) + assert success, "demean_torch did not converge on unbalanced groups" + np.testing.assert_allclose(res_torch, res_pyhdfe, rtol=1e-5, atol=1e-7) + + def test_many_columns(self): + """Many dependent variables (K=50) — checks per-column loop scales.""" + rng = np.random.default_rng(9999) + N = 2_000 + K = 50 + x = rng.normal(0, 1, (N, K)) + flist = rng.choice(20, (N, 2)).astype(np.uint64) + weights = np.ones(N) + + algorithm = pyhdfe.create(flist) + res_pyhdfe = algorithm.residualize(x) + + res_torch, success = demean_torch(x, flist, weights, tol=1e-10) + assert success + np.testing.assert_allclose(res_torch, res_pyhdfe, rtol=1e-5, atol=1e-7) + + def test_weighted_high_d(self): + """Weighted demeaning with many groups — weighted 
preconditioner stress test.""" + rng = np.random.default_rng(1234) + N = 10_000 + G = 100 + x = rng.normal(0, 1, (N, 5)) + flist = rng.choice(G, (N, 2)).astype(np.uint64) + # Highly skewed weights (log-normal) + weights = rng.lognormal(0, 2, N) + + algorithm = pyhdfe.create(flist) + res_pyhdfe = algorithm.residualize(x, weights.reshape(N, 1)) + + res_torch, success = demean_torch(x, flist, weights, tol=1e-10) + assert success, "demean_torch did not converge with skewed weights" + np.testing.assert_allclose(res_torch, res_pyhdfe, rtol=1e-5, atol=1e-7) + + +# --------------------------------------------------------------------------- +# Level 6: float32 / MPS tests +# --------------------------------------------------------------------------- + +HAS_MPS = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() + + +class TestFloat32: + """Tests for float32 dtype path (COO on MPS, COO-converted-to-CSR on CPU).""" + + def test_cpu_float32_vs_pyhdfe(self, demean_data): + """float32 on CPU should match pyhdfe within single-precision tolerance.""" + x, flist, _ = demean_data + N = x.shape[0] + weights = np.ones(N) + + algorithm = pyhdfe.create(flist) + res_pyhdfe = algorithm.residualize(x) + + res_torch, success = demean_torch( + x, flist, weights, tol=1e-5, dtype=torch.float32 + ) + assert success, "demean_torch (f32) did not converge" + np.testing.assert_allclose( + res_torch, + res_pyhdfe, + rtol=1e-3, + atol=1e-3, + err_msg="demean_torch (f32 CPU) vs pyhdfe mismatch", + ) + + def test_cpu_float32_weighted(self, demean_data): + """Weighted float32 on CPU should match pyhdfe.""" + x, flist, weights = demean_data + N = x.shape[0] + + algorithm = pyhdfe.create(flist) + res_pyhdfe = algorithm.residualize(x, weights.reshape(N, 1)) + + res_torch, success = demean_torch( + x, flist, weights, tol=1e-5, dtype=torch.float32 + ) + assert success, "demean_torch (f32 weighted) did not converge" + np.testing.assert_allclose( + res_torch, + res_pyhdfe, + rtol=1e-3, + 
atol=1e-3, + err_msg="demean_torch (f32 CPU weighted) vs pyhdfe mismatch", + ) + + @pytest.mark.skipif(not HAS_MPS, reason="MPS not available") + def test_mps_float32_vs_pyhdfe(self, demean_data): + """float32 on MPS should match pyhdfe within single-precision tolerance.""" + x, flist, _ = demean_data + N = x.shape[0] + weights = np.ones(N) + + algorithm = pyhdfe.create(flist) + res_pyhdfe = algorithm.residualize(x) + + res_torch, success = demean_torch( + x, flist, weights, tol=1e-5, dtype=torch.float32 + ) + assert success, "demean_torch (MPS f32) did not converge" + np.testing.assert_allclose( + res_torch, + res_pyhdfe, + rtol=1e-3, + atol=1e-3, + err_msg="demean_torch (MPS f32) vs pyhdfe mismatch", + ) + + @pytest.mark.skipif(not HAS_MPS, reason="MPS not available") + def test_mps_float32_weighted(self, demean_data): + """Weighted float32 on MPS should match pyhdfe.""" + x, flist, weights = demean_data + N = x.shape[0] + + algorithm = pyhdfe.create(flist) + res_pyhdfe = algorithm.residualize(x, weights.reshape(N, 1)) + + res_torch, success = demean_torch( + x, flist, weights, tol=1e-5, dtype=torch.float32 + ) + assert success, "demean_torch (MPS f32 weighted) did not converge" + np.testing.assert_allclose( + res_torch, + res_pyhdfe, + rtol=1e-3, + atol=1e-3, + err_msg="demean_torch (MPS f32 weighted) vs pyhdfe mismatch", + ) + + @pytest.mark.skipif(not HAS_MPS, reason="MPS not available") + def test_mps_float32_high_d(self): + """MPS float32 with many groups — exercises COO path at scale.""" + rng = np.random.default_rng(5555) + N = 10_000 + G = 100 + x = rng.normal(0, 1, (N, 3)) + flist = rng.choice(G, (N, 2)).astype(np.uint64) + weights = np.ones(N) + + algorithm = pyhdfe.create(flist) + res_pyhdfe = algorithm.residualize(x) + + res_torch, success = demean_torch( + x, flist, weights, tol=1e-5, dtype=torch.float32 + ) + assert success, "demean_torch (MPS f32 high-D) did not converge" + np.testing.assert_allclose( + res_torch, + res_pyhdfe, + rtol=1e-3, + 
atol=1e-3, + err_msg="demean_torch (MPS f32 high-D) vs pyhdfe mismatch", + )