Add Numba JIT acceleration for nearest-neighbor VCE (260x speedup) #11
sankalpsharmaa wants to merge 2 commits into rdpackages:master from
Conversation
The nearest-neighbor variance estimator in rdrobust_res contains a pure Python for-loop over all n observations, each with an inner while-loop for neighbor matching. This is the dominant bottleneck, called 7-19 times per rdrobust() invocation. This commit adds optional Numba JIT compilation for this hot path:

- New _numba_core.py with serial and parallel (prange) JIT functions
- funs.py dispatches to Numba when available and falls back to the original Python code when numba is not installed
- Stacks [y, T, Z] into a contiguous 2D array for the JIT kernel

Performance (Apple M3, Python 3.11, numba 0.64):

- rdrobust_res NN path: 260x faster (0.24s -> 0.0009s at n=50K)
- Full rdrobust() end-to-end: 14-17x faster
- Numerical difference vs pure Python: exactly zero

Numba is an optional dependency. Without it, behavior is unchanged.
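The optional-dependency dispatch described above can be sketched as follows. This is an illustrative pattern, not the PR's exact code: the names `_nn_py`, `_nn_jit`, and `nn_residuals` are placeholders, and the kernel body is a trivial stand-in for the real NN residual loop.

```python
import numpy as np

# Import guard: use Numba when available, otherwise fall back to pure Python.
try:
    from numba import njit
    _HAS_NUMBA = True
except Exception:  # numba missing, or its LLVM runtime failed to load
    _HAS_NUMBA = False

def _nn_py(X, D):
    # Pure-Python fallback; stand-in for the real NN residual loop.
    out = np.empty_like(D)
    for i in range(X.shape[0]):
        out[i] = D[i]
    return out

if _HAS_NUMBA:
    # Compile the same function body; cache=True persists the compiled kernel.
    _nn_jit = njit(cache=True)(_nn_py)

def nn_residuals(y, T=None, Z=None, x=None):
    # Stack [y, T, Z] into one contiguous float64 2D array, as the commit does,
    # so a single kernel handles sharp, fuzzy, and covariate cases.
    D = np.asarray(y, dtype=np.float64).reshape(-1, 1)
    if T is not None:
        D = np.column_stack((D, np.asarray(T).ravel()))
    if Z is not None:
        D = np.column_stack((D, np.asarray(Z)))
    D = np.ascontiguousarray(D, dtype=np.float64)
    X = np.ascontiguousarray(np.asarray(x).ravel(), dtype=np.float64)
    kernel = _nn_jit if _HAS_NUMBA else _nn_py
    return kernel(X, D)
```

Either path returns identical results; the only difference is whether the loop runs compiled or interpreted.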
Pull request overview
This PR adds an optional Numba-accelerated implementation of the nearest-neighbor (NN) VCE residual computation in rdrobust_res, aiming to remove a Python-level O(n) hot loop and optionally enable parallel execution for large n.
Changes:
- Introduces Numba JIT kernels (serial + prange parallel) for NN residual computation.
- Adds an import guard and runtime dispatch in rdrobust_res to use the JIT kernels when available.
- Adds a standalone benchmarking script to compare correctness/performance vs the pure-Python path.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| Python/rdrobust/src/rdrobust/_numba_core.py | New Numba-compiled NN residual kernels (serial + parallel) and a parallelization threshold constant. |
| Python/rdrobust/src/rdrobust/funs.py | Adds optional Numba import and dispatch in the vce=="nn" branch of rdrobust_res. |
| benchmark.py | New benchmark/correctness script to compare Numba vs pure-Python behavior and speed. |
@njit(cache=True)
def nn_residuals(X, D, matches, dups, dupsid, n, ncols):
    """
Fixed in b22b55e. Removed n and ncols from both JIT signatures. They're now derived from X.shape[0] and D.shape[1] at the top of each kernel.
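The fix described here can be sketched as below. The kernel body (simple differencing against a precomputed match index) is hypothetical and stands in for the PR's actual NN residual formula; the point is that `n` and `ncols` come from the array shapes, not from parameters.

```python
import numpy as np

try:
    from numba import njit
except Exception:
    def njit(**kwargs):  # no-op fallback so this sketch runs without Numba
        def deco(fn):
            return fn
        return deco

@njit(cache=True)
def nn_residuals(X, D, matches):
    # Shapes are derived inside the kernel, so callers can no longer pass
    # mismatched n/ncols values.
    n = X.shape[0]
    ncols = D.shape[1]
    res = np.empty((n, ncols))
    for i in range(n):
        for j in range(ncols):
            # Illustrative: difference each column against its matched row.
            res[i, j] = D[i, j] - D[matches[i], j]
    return res
```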
X_flat = np.ascontiguousarray(np.asarray(X).ravel(), dtype=np.float64)
D = np.ascontiguousarray(np.asarray(y).reshape(-1, 1), dtype=np.float64)
if T is not None:
    D = np.column_stack((D, np.asarray(T).ravel()))
if Z is not None:
    D = np.column_stack((D, np.asarray(Z)))
D = np.ascontiguousarray(D, dtype=np.float64)
Good catch. The explicit float64 cast is intentional for Numba JIT type stability and cache reuse (avoids recompilation per dtype). In practice rdrobust's internal pipeline always produces float64 arrays, so the two paths are equivalent. Added a comment at the cast site explaining the rationale.
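A minimal illustration of why the explicit cast matters, using assumed inputs: `np.column_stack` promotes mixed dtypes, and the final `ascontiguousarray(..., dtype=np.float64)` guarantees one stable dtype/layout signature so the JIT cache is reused instead of recompiling.

```python
import numpy as np

# Assumed example inputs: int treatment indicator mixed with float arrays.
y = np.array([1.0, 2.0, 3.0])
T = np.array([0, 1, 1])
Z = np.array([[0.5], [0.1], [0.9]])

D = y.reshape(-1, 1)
D = np.column_stack((D, T.ravel()))
D = np.column_stack((D, Z))
# Explicit float64 + C-contiguity gives Numba a single stable signature.
D = np.ascontiguousarray(D, dtype=np.float64)

assert D.shape == (3, 3)
assert D.dtype == np.float64
assert D.flags["C_CONTIGUOUS"]
```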
benchmark.py
Outdated
print("Warming up Numba JIT...")
x_warm, y_warm = generate_rd_data(1000, seed=0)
sides = prepare_nn_inputs(x_warm, y_warm)
X_s, Y_s, dups, dupsid, n_s = sides[0]
_ = rdrobust_res(X_s, Y_s, None, None, 0, 0, "nn", 3, dups, dupsid, 2)
print("JIT warm-up done.\n")
Fixed in b22b55e. Warm-up block is now guarded behind if has_numba:, so it only runs (and prints) when Numba is actually available.
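The guarded warm-up can be sketched as below. The `has_numba` flag and `warm_up` helper are illustrative names, not the benchmark's exact code; the idea is that JIT compilation cost is paid once before timing, and only when Numba is actually present.

```python
try:
    import numba  # noqa: F401
    has_numba = True
except Exception:
    has_numba = False

def warm_up(fn, *small_args):
    """Call fn once on tiny inputs to trigger (and cache) JIT compilation,
    so the first timed run does not include compile time."""
    if has_numba:
        print("Warming up Numba JIT...")
        fn(*small_args)
        print("JIT warm-up done.")
```

Without Numba the function is a no-op, so the benchmark neither runs nor prints the warm-up step.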
Python/rdrobust/src/rdrobust/funs.py
Outdated
        PARALLEL_THRESHOLD,
    )
    _HAS_NUMBA = True
except ImportError:
Fixed in b22b55e. Import is now lazy via _try_numba(), which runs once on the first vce=="nn" call. Also catches broad Exception instead of just ImportError to handle LLVM/runtime failures gracefully.
Python/rdrobust/src/rdrobust/funs.py
Outdated
    ncols = D.shape[1]
    if n > PARALLEL_THRESHOLD:
        return _nn_res_parallel(X_flat, D, matches, dups_i, dupsid_i, n, ncols)
    return _nn_res_jit(X_flat, D, matches, dups_i, dupsid_i, n, ncols)
Fixed in b22b55e. n and ncols are no longer passed to the kernels. Both are derived from the array shapes inside the JIT functions, so mismatches are impossible.
def generate_rd_data(n, seed=42):
    """Generate synthetic sharp RD data."""
    rng = np.random.default_rng(seed)
    x = rng.uniform(-1, 1, n)
Fixed in b22b55e. Added generate_rd_data_with_masspoints() which rounds X to 1 decimal place, creating duplicate running-variable values. This is now correctness check #2 in the benchmark.
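A sketch of what such a mass-point generator could look like. The outcome equation below is an assumption for illustration (the PR's actual formula is not shown here); the key point is that rounding the running variable to one decimal guarantees duplicates, which exercises the dups/dupsid tie-handling branch of the NN loop.

```python
import numpy as np

def generate_rd_data_with_masspoints(n, seed=42):
    """Sharp RD data whose running variable has mass points (duplicates)."""
    rng = np.random.default_rng(seed)
    # Rounding to 1 decimal leaves only 21 possible values in [-1, 1],
    # so duplicates are guaranteed for n >> 21.
    x = np.round(rng.uniform(-1, 1, n), 1)
    # Assumed outcome model: linear trend plus a jump at the cutoff.
    y = 0.5 * x + 0.3 * (x >= 0) + rng.normal(0, 0.1, n)
    return x, y
```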
benchmark.py
Outdated
# ── Correctness check ──
print("=" * 60)
print("CORRECTNESS CHECK (n=5,000)")
print("=" * 60)
x_check, y_check = generate_rd_data(5000, seed=1)
sides = prepare_nn_inputs(x_check, y_check)
X_s, Y_s, dups, dupsid, n_s = sides[0]

_, res_python = time_pure_python_res(X_s, Y_s, dups, dupsid, n_s)
_, res_numba = time_rdrobust_res(X_s, Y_s, dups, dupsid, n_s)

max_diff = np.max(np.abs(res_python - res_numba))
print(f"Max absolute difference: {max_diff:.2e}")
if max_diff < 1e-10:
    print("PASS: Results are numerically identical.\n")
else:
    print("FAIL: Results differ!\n")
    return
Fixed in b22b55e. Correctness checks now cover all four cases:
- Sharp RD, continuous X
- Sharp RD, mass points (duplicates)
- Fuzzy RD (y + T)
- With covariates (y + T + Z)
All pass with zero numerical difference.
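The per-case comparison can be reduced to a small helper like the one below (an illustrative sketch; `report` is not a name from the benchmark). It mirrors the PASS/FAIL tolerance check shown in the diff above, where identical implementations should give a max absolute difference of exactly zero.

```python
import numpy as np

TOL = 1e-10  # tolerance used by the benchmark's PASS/FAIL check

def report(name, res_python, res_numba):
    """Print and return PASS/FAIL for one correctness case."""
    max_diff = np.max(np.abs(res_python - res_numba))
    status = "PASS" if max_diff < TOL else "FAIL"
    print(f"{name}: max |diff| = {max_diff:.2e} -> {status}")
    return status
```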
- Remove redundant n/ncols params from JIT signatures; derive from X.shape[0] and D.shape[1] inside the kernels
- Make Numba import lazy (deferred to first vce=="nn" call) and catch broad Exception instead of just ImportError (handles LLVM/runtime failures)
- Add comment explaining float64 cast rationale in Numba dispatch path
- Guard JIT warm-up behind _HAS_NUMBA check in benchmark
- Add mass-point correctness test (rounded X with duplicates)
- Add fuzzy RD (y + T) and covariate (y + T + Z) correctness checks
- Extend prepare_nn_inputs to pass T and Z through to each side
Summary
- JIT-compiles the NN residual computation in rdrobust_res using Numba, eliminating the pure Python `for pos in range(n)` loop that is the dominant bottleneck
- Parallel execution via prange for datasets with n > 50,000

Motivation

The NN VCE path in rdrobust_res (funs.py:294-319) contains a Python-level loop over all observations, each with an inner while loop for neighbor matching. This function is called 7-19 times per rdrobust() invocation (via rdrobust_bw in bandwidth selection). For n > 50K, this dominates total runtime.

Changes

- _numba_core.py (new): @njit functions that replace the NN loop
- funs.py: dispatch in the vce=="nn" branch
- benchmark.py (new)

Performance (Apple M3, Python 3.11, Numba 0.64)

- rdrobust_res NN path (left side):
- Full rdrobust() end-to-end:
- Numerical difference: exactly zero (verified at n=5K and n=50K; all coefficients, SEs, and CIs match to machine precision)

Design decisions

- Stacks [y, T, Z] into a single 2D array and loops over columns, avoiding separate functions for sharp/fuzzy/covariate cases.
- try/except ImportError around the Numba import means zero impact on users without Numba installed.

Test plan