
feat: implement LSMR demeaning in torch for CUDA MPS support#1220

Open
janfb wants to merge 16 commits into py-econometrics:master from janfb:feature/torch-lsmr

Conversation


@janfb janfb commented Mar 5, 2026

This PR implements LSMR in torch and adds a PyTorch-based fixed-effect demeaning backend to pyfixest.

Users can access it through the existing demeaner_backend argument, for example:

pf.feols("Y ~ X1 | f1 + f2", data, demeaner_backend="torch_mps")

The main motivation is enabling GPU support on consumer Apple laptops via MPS, which is relevant for a large part of the user base. The same backend family also supports CUDA and CPU, and includes a batched LSMR variant for problems with multiple columns to demean.

User-facing API

This PR extends the existing demeaner_backend option with:

  • torch
  • torch_cpu
  • torch_mps
  • torch_cuda
  • torch_cuda32

No new top-level estimation API is introduced. Internally, we additionally choose between compiled and batched versions of LSMR.

Main caveat

MPS supports only float32 for this implementation, not pyfixest's current default of float64. The torch MPS backend is therefore effectively a float32 path.
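A quick numpy illustration (not pyfixest code) of why the float32 restriction matters: float32 carries roughly 7 decimal digits of precision, so small updates to large values can vanish entirely, which is the kind of error users should keep in mind on the float32-only MPS path.

```python
import numpy as np

# float32 has a 24-bit mantissa (~7 decimal digits), so adding 1.0 to 1e8
# is lost entirely; in float64 the update survives.
a64 = np.float64(1e8) + np.float64(1.0)
a32 = np.float32(1e8) + np.float32(1.0)

print(a64 - 1e8)                # 1.0
print(a32 == np.float32(1e8))   # True: the +1.0 was absorbed
```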

Core implementation choices

  • Sparse matrix layout is device-specific (due to MPS limitations): MPS uses COO, CPU and CUDA use CSR.

  • Torch LSMR is integrated into the existing demeaning pipeline:
    the backend builds the sparse FE dummy matrix directly from encoded fixed effects, applies optional weighting and diagonal preconditioning, and solves the FWL system with LSMR.

  • Dispatch is device-specific:

    • CPU defaults to the eager (not compiled) single-RHS (no batching) path
    • CUDA defaults to the compiled torch path, then switches to batched LSMR for K>=2
    • MPS defaults to the eager path (compiled is slower on MPS), then switches to batched LSMR for K>=5
  • Batched LSMR is included for multi-column demeaning:
    lsmr_torch_batched() is used internally when solving several RHS columns jointly on devices where batched sparse matmul is beneficial.

  • Scalar-step handling changed during development: the early version relied heavily on .item() / CPU scalar math for Givens rotations. The current implementation is more nuanced:

    • eager single-RHS still uses Python scalar math where that helps
    • CUDA additionally supports a compiled path
    • compiled-state logic is factored into _lsmr_compiled_core.py
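To make the pipeline concrete, here is a hypothetical dense numpy sketch of the core idea (the actual backend builds sparse torch tensors and solves with LSMR): the demeaned column is the least-squares residual against the FE dummy matrix, which for a single fixed effect equals subtracting group means.

```python
import numpy as np

# Dense toy version of the backend's idea: build a dummy matrix D from the
# encoded fixed effect and take the least-squares residual x - D @ beta.
rng = np.random.default_rng(0)
fe = rng.integers(0, 3, size=12)   # encoded fixed effect with 3 levels
x = rng.normal(size=12)

D = np.eye(3)[fe]                  # dense N x 3 dummy matrix
beta, *_ = np.linalg.lstsq(D, x, rcond=None)
resid = x - D @ beta               # the demeaned column

# For a single FE this is exactly groupwise demeaning.
group_means = np.array([x[fe == g].mean() for g in range(3)])
assert np.allclose(resid, x - group_means[fe])
```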

Benchmarking

  • I ran a custom benchmark on the simpler problems already present in the tests.
  • I also integrated the torch backend into the recently merged difficult benchmarking suite.
  • benchmark results are posted below.


janfb commented Mar 5, 2026

Results on simple benchmarks

These simple benchmarks are taken from tests/. Note that pyhdfe dominates all of them, likely because these problems are trivial to solve, so LSMR's overhead never amortizes.

[Figures: bench_demean_torch, bench_demean_torch_profile]


janfb commented Mar 5, 2026

Results on new benchmarking suite

After #1211 was merged, I integrated the torch backend options into that framework and ran all the tasks and algos on my Mac (excluding all CUDA-related algos).

[Figures: bench_scaling, bench_bars_difficult, bench_speedup_heatmap]

Summary (by Claude Code)

  • Plot 1 — Scaling (log-log line chart): The bottom-right panel (Difficult 3FE) is the most dramatic — you can see the red torch MPS LSMR (f32) line crossing below
    the green numba line around 100K observations, then staying the fastest all the way to 1M. Meanwhile rust (brown) shoots up catastrophically. On the easy problems
    (top row), numba and JAX dominate with near-identical scaling.

  • Plot 2 — Bar chart (Difficult DGP focus): The right panel at 1M/3FE tells the key story visually — the red torch MPS bar (22.2s) is shorter than everything else,
    with rust towering at 303s. The value labels make exact comparisons easy.

  • Plot 3 — Heatmap (speedup vs numba): The rightmost column group (Difficult 3FE) has the only green cells for torch MPS — 1.0x at 500K and 1.1x at 1M, meaning it's
    faster than numba. Everything else is yellow-to-red (slower than numba).


janfb commented Mar 6, 2026

Benchmarks on CUDA

I ran the same benchmark tasks and settings on a CUDA machine using a Tesla V100. This shows only CUDA-related backends: JAX (MAP), CuPy (LSMR), and the new torch-lsmr.

Summary: On the simpler problems where alternating projections (MAP) work well, JAX dominates. On high-dimensional difficult FE problems the new torch LSMR dominates 🚀. It's unclear why it is so much faster than CuPy, which is based on LSMR as well.

[Figures: bench_cuda_scaling, bench_cuda_speedup_heatmap]


janfb commented Mar 6, 2026

LSMR Unification + better CUDA Results with torch.compile

TL;DR: torch.compile works well for CUDA and pays off for large N and complex problems (another 1.67x speedup on top!). For MPS it doesn't help because syncs are fast (shared MPS memory) and torch.compile can be applied only partially.

After discussing with @schroedk, I tried different ways of trading off few CPU-GPU syncs against few GPU kernel launches in the LSMR loop. For example, previously I performed Givens rotations on CPU to avoid scalar operations on GPU. The required syncs between GPU and CPU add overhead. But when we do everything on GPU, we have more kernel launches in the loop, which are costly as well. Compiling the entire loop would help because it results in a single kernel launch. However, torch.compile doesn't support sparse CSR matvec operations, so we can compile only the Givens rotation part of the loop. For CUDA this helps. For MPS it can be done as well (only for float32 and only for COO sparse matrices), but it actually doesn't help because syncing is so fast.

Design choices

So, overall, we had torch LSMR in two different solver implementations (lsmr_torch.py and lsmr_torch_compiled.py). These are now unified behind a single entry point, lsmr_torch(), in lsmr_torch.py. It auto-dispatches based on device:

| Device | Default path | Why |
| --- | --- | --- |
| CUDA | compiled (torch.compile + precomputed D^T) | Fuses ~60 per-iteration scalar kernels into 1; avoids sparse transpose reconversion |
| CPU | Givens rotations on CPU (Python math) | No kernel launch overhead to optimize away |
| MPS | Givens rotations on CPU (Python math) | Metal command buffer batching already amortizes launches (see below) |

Callers can override with use_compile=True/False. All existing imports (from pyfixest.estimation.torch.lsmr_torch import lsmr_torch) continue to work — tests pass unchanged.
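As an illustration only (the function name and exact rules are simplified from the table above, not the actual pyfixest code), the device-based dispatch with a caller override could be sketched as:

```python
# Hypothetical sketch of the dispatch rule described above; the real
# lsmr_torch() in pyfixest may differ in detail.
def choose_lsmr_path(device, use_compile=None):
    """Return which LSMR path to run for a given torch device string."""
    if use_compile is not None:       # explicit caller override wins
        return "compiled" if use_compile else "eager"
    if device.startswith("cuda"):
        return "compiled"             # fuse per-iteration scalar kernels
    return "eager"                    # CPU/MPS: scalar Givens math on CPU

print(choose_lsmr_path("cuda:0"))                  # compiled
print(choose_lsmr_path("mps"))                     # eager
print(choose_lsmr_path("mps", use_compile=True))   # compiled (override)
```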

CUDA benchmark results

[Figure: bench_cuda_old_vs_opt]

Claude's summary:

Benchmarked on DGX (NVIDIA A100), OLS with fixed effects, comparing old (uncompiled) vs optimized (compiled + precomputed D^T):

  • Small N (<100K): No measurable difference — few LSMR iterations, kernel launch overhead is negligible
  • 500K–1M: Optimized version pulls ahead, ~1.3–1.5x
  • 2M–5M: Consistent 1.5–1.8x speedup across both f32 and f64
  • Difficult 3FE (most LSMR iterations) shows the largest gains — the per-iteration savings compound over hundreds of iterations

The two optimizations are complementary:

  1. torch.compile fuses the scalar Givens rotation + norm estimation + convergence check into a single GPU kernel, eliminating ~60 tiny kernel launches per iteration
  2. Precomputed D^T materializes the sparse transpose once upfront instead of reconverting it every LSMR iteration (COO coalesce on MPS, CSR radixSort on CUDA)
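For context, the per-iteration scalar math being fused is essentially the numerically stable Givens rotation used inside LSMR (scipy's implementation calls it `_sym_ortho`). A plain-Python sketch of that routine, not the torch implementation:

```python
import math

def sym_ortho(a, b):
    """Stable Givens rotation: returns (c, s, r) with c*a + s*b = r and
    -s*a + c*b = 0. This is the kind of scalar work torch.compile fuses
    into one kernel on CUDA."""
    if b == 0.0:
        return math.copysign(1.0, a), 0.0, abs(a)
    if a == 0.0:
        return 0.0, math.copysign(1.0, b), abs(b)
    if abs(b) > abs(a):
        tau = a / b
        s = math.copysign(1.0, b) / math.sqrt(1.0 + tau * tau)
        return s * tau, s, b / s
    tau = b / a
    c = math.copysign(1.0, a) / math.sqrt(1.0 + tau * tau)
    return c, c * tau, a / c

c, s, r = sym_ortho(3.0, 4.0)
print(c, s, r)  # 0.6 0.8 5.0
```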

Why torch.compile doesn't help on MPS

On CUDA, each scalar operation (e.g., a Givens rotation on two floats) launches a separate kernel. At ~60 scalar ops per LSMR iteration, the CPU→GPU dispatch overhead dominates. torch.compile fixes this by fusing them into one kernel.

MPS (Metal) doesn't have this problem. Metal uses a command buffer model — scalar operations are batched into a command buffer on the CPU side and submitted to the GPU in bulk. The dispatch overhead is already amortized without compilation. What torch.compile adds on MPS is Python-side tracing and graph capture overhead, which is pure cost with no kernel-side benefit. Our A/B benchmarks on Apple Silicon confirmed this: the compiled path was slower than the scalar path at every dataset size on MPS.

This is why the dispatcher defaults to use_compile=False on MPS — it's not a missing optimization, it's the correct choice for the hardware.


janfb commented Mar 6, 2026

More tweaking?

Of course there are even more options for optimizing this on CUDA, e.g., for smaller problems (<100K) one could represent the dummy matrix as a plain dense matrix instead of a sparse one and then torch.compile the entire loop or function. But I think the payoff is marginal given that the big speed-ups happen at larger N.

Moving on

@s3alfisc, I suggest you review the LSMR part of this code and the tests. The changes to the benchmarking files are less relevant given that the benchmarking will be refactored anyway. I can push the current timings as CSV for later reference, shall I?

janfb and others added 14 commits March 31, 2026 10:23
- Added `lsmr_torch_fused.py` for a fused version of the LSMR algorithm, utilizing branchless Givens rotations and 0-d tensors to reduce CPU-GPU sync overhead.
- Introduced tests for the new fused LSMR implementation in `test_lsmr_fused.py`, ensuring correctness against the original LSMR and benchmarking performance.
- Created `test_lsmr_compiled.py` to validate the compiled version of the original LSMR, including auto-detection and MPS compatibility tests.
- also enhance GPU efficiency with pre-computed transpose
…oduce fused version

- Deleted `lsmr_torch_compiled.py` and `lsmr_torch_fused.py` files, consolidating functionality into `lsmr_torch.py`.
- Updated tests to reflect changes in the LSMR implementation, ensuring correctness and performance benchmarks.
- Adjusted convergence checks and state management to optimize CPU-GPU synchronization.
- Enhanced the branchless Givens rotation implementation for improved efficiency on CUDA/MPS.
janfb force-pushed the feature/torch-lsmr branch from f009def to 824c315 on March 31, 2026 at 12:07

janfb commented Apr 1, 2026

Batched LSMR benchmarks

I also benchmarked the new batched LSMR path, which is used internally when demeaning multiple RHS columns jointly, replacing repeated matrix-vector products with matrix-matrix products.

What is being compared?

For K right-hand sides, we compare:

  • sequential: K × lsmr_torch(A, b_k)
  • batched: one lsmr_torch_batched(A, B) call

So the question is: when does it pay off to solve multiple columns together instead of looping over the scalar solver?
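A toy dense numpy analogue (not the torch solver) of what "batched" means here: solving all K columns in one call gives the same answers as looping over them, while replacing K matrix-vector products with a single matrix-matrix product per step.

```python
import numpy as np

# Sequential vs. batched least squares for K = 3 right-hand sides.
rng = np.random.default_rng(1)
A = rng.normal(size=(50, 4))
B = rng.normal(size=(50, 3))

# sequential: one solve per column
X_seq = np.column_stack([np.linalg.lstsq(A, B[:, k], rcond=None)[0]
                         for k in range(3)])

# batched: one call for all columns at once
X_bat, *_ = np.linalg.lstsq(A, B, rcond=None)

assert np.allclose(X_seq, X_bat)
```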

MPS results

On Apple MPS (float32 only), the behavior is as expected:

  • for K=1, batched is worse than sequential due to overhead
  • around K=5, batched is around break-even
  • by K=10, batched is clearly faster
  • overall, the speedup dampens for larger N where the cost across rows dominates
[Figures: bench_batched_lsmr_mps_speedup, bench_batched_lsmr_mps_abs_timing, bench_batched_lsmr_mps_heatmap]

CUDA results

On CUDA, the crossover happens earlier and the gains are much larger.

Across the DGX runs:

  • K=1: batched is slower, about 0.66x
  • K=2: already slightly faster, about 1.3x
  • K=5: clearly faster, about 3.3x
  • K=50: up to 32x

These speedups decrease as we move from N=10K to N=1M but are still 2-12x even at 1M for float32; for float64 the effect at N=1M is about half as strong.

[Figures: bench_batched_lsmr_cuda_speedup, bench_batched_lsmr_cuda_heatmap, bench_batched_lsmr_cuda_abs_timing]

Summary

The batched LSMR story is:

  • not worthwhile for a single RHS
  • noticeable speedups for K>=2 on CUDA and K>=5 on MPS
  • speedup is strongest for small N and dampens for larger N where compute across N dominates.
  • especially strong on CUDA, where the throughput gains become very large for wider RHS batches
