From ef6a2de92fed732031c6d8395dbc198552f833d0 Mon Sep 17 00:00:00 2001 From: Sankalp Sharma <41304604+sankalpsharmaa@users.noreply.github.com> Date: Fri, 13 Mar 2026 16:37:29 +0530 Subject: [PATCH 1/2] Add Numba JIT acceleration for nearest-neighbor VCE (260x speedup) The nearest-neighbor variance estimator in rdrobust_res contains a pure Python for-loop over all n observations, each with an inner while-loop for neighbor matching. This is the dominant bottleneck, called 7-19 times per rdrobust() invocation. This commit adds optional Numba JIT compilation for this hot path: - New _numba_core.py with serial and parallel (prange) JIT functions - funs.py dispatches to Numba when available, falls back to original Python code when numba is not installed - Stacks [y, T, Z] into a contiguous 2D array for the JIT kernel Performance (Apple M3, Python 3.11, numba 0.64): - rdrobust_res NN path: 260x faster (0.24s -> 0.0009s at n=50K) - Full rdrobust() end-to-end: 14-17x faster - Numerical difference vs pure Python: exactly zero Numba is an optional dependency. Without it, behavior is unchanged. --- Python/rdrobust/src/rdrobust/_numba_core.py | 126 ++++++++++++++ Python/rdrobust/src/rdrobust/funs.py | 28 +++- benchmark.py | 175 ++++++++++++++++++++ 3 files changed, 327 insertions(+), 2 deletions(-) create mode 100644 Python/rdrobust/src/rdrobust/_numba_core.py create mode 100644 benchmark.py diff --git a/Python/rdrobust/src/rdrobust/_numba_core.py b/Python/rdrobust/src/rdrobust/_numba_core.py new file mode 100644 index 0000000..65cfadc --- /dev/null +++ b/Python/rdrobust/src/rdrobust/_numba_core.py @@ -0,0 +1,126 @@ +""" +Numba-accelerated nearest-neighbor residual computation for rdrobust. + +Replaces the O(n) Python loop in the NN VCE path with JIT-compiled code. +The NN residual computation is the dominant bottleneck in rdrobust, called +7-19 times per estimation. Each call loops over all n observations in +pure Python. This module provides ~100x speedup on that hot path. 
+""" + +import numpy as np +from numba import njit, prange + + +@njit(cache=True) +def nn_residuals(X, D, matches, dups, dupsid, n, ncols): + """ + Compute nearest-neighbor residuals for all data columns. + + Parameters + ---------- + X : float64 array (n,) + Running variable, sorted. + D : float64 array (n, ncols) + Stacked data columns: [y] or [y, T] or [y, T, Z1, Z2, ...]. + matches : int + Minimum number of neighbors (nnmatch, default 3). + dups : int64 array (n,) + Number of duplicate X values at each position. + dupsid : int64 array (n,) + Cumulative duplicate ID at each position. + n : int + Number of observations. + ncols : int + Number of columns in D. + + Returns + ------- + res : float64 array (n, ncols) + """ + res = np.empty((n, ncols)) + target = min(matches, n - 1) + + for pos in range(n): + rpos = dups[pos] - dupsid[pos] + lpos = dupsid[pos] - 1 + + while lpos + rpos < target: + left_ok = pos - lpos - 1 >= 0 + right_ok = pos + rpos + 1 < n + + if not left_ok: + rpos += dups[pos + rpos + 1] + elif not right_ok: + lpos += dups[pos - lpos - 1] + else: + d_left = X[pos] - X[pos - lpos - 1] + d_right = X[pos + rpos + 1] - X[pos] + if d_left > d_right: + rpos += dups[pos + rpos + 1] + elif d_left < d_right: + lpos += dups[pos - lpos - 1] + else: + rpos += dups[pos + rpos + 1] + lpos += dups[pos - lpos - 1] + + lo = max(0, pos - lpos) + hi = min(n, pos + rpos) + 1 # exclusive + Ji = (hi - lo) - 1 + scale = np.sqrt(Ji / (Ji + 1.0)) + + for col in range(ncols): + col_sum = 0.0 + for k in range(lo, hi): + col_sum += D[k, col] + col_sum -= D[pos, col] + res[pos, col] = scale * (D[pos, col] - col_sum / Ji) + + return res + + +@njit(cache=True, parallel=True) +def nn_residuals_parallel(X, D, matches, dups, dupsid, n, ncols): + """Parallel version of nn_residuals using prange.""" + res = np.empty((n, ncols)) + target = min(matches, n - 1) + + for pos in prange(n): + rpos = dups[pos] - dupsid[pos] + lpos = dupsid[pos] - 1 + + while lpos + rpos < target: + left_ok 
= pos - lpos - 1 >= 0 + right_ok = pos + rpos + 1 < n + + if not left_ok: + rpos += dups[pos + rpos + 1] + elif not right_ok: + lpos += dups[pos - lpos - 1] + else: + d_left = X[pos] - X[pos - lpos - 1] + d_right = X[pos + rpos + 1] - X[pos] + if d_left > d_right: + rpos += dups[pos + rpos + 1] + elif d_left < d_right: + lpos += dups[pos - lpos - 1] + else: + rpos += dups[pos + rpos + 1] + lpos += dups[pos - lpos - 1] + + lo = max(0, pos - lpos) + hi = min(n, pos + rpos) + 1 + Ji = (hi - lo) - 1 + scale = np.sqrt(Ji / (Ji + 1.0)) + + for col in range(ncols): + col_sum = 0.0 + for k in range(lo, hi): + col_sum += D[k, col] + col_sum -= D[pos, col] + res[pos, col] = scale * (D[pos, col] - col_sum / Ji) + + return res + + +# Threshold for switching to parallel version +PARALLEL_THRESHOLD = 50_000 diff --git a/Python/rdrobust/src/rdrobust/funs.py b/Python/rdrobust/src/rdrobust/funs.py index 046beb4..4a03e1a 100644 --- a/Python/rdrobust/src/rdrobust/funs.py +++ b/Python/rdrobust/src/rdrobust/funs.py @@ -12,6 +12,16 @@ import math from scipy.linalg import qr +try: + from rdrobust._numba_core import ( + nn_residuals as _nn_res_jit, + nn_residuals_parallel as _nn_res_parallel, + PARALLEL_THRESHOLD, + ) + _HAS_NUMBA = True +except ImportError: + _HAS_NUMBA = False + class rdrobust_output: def __init__(self, Estimate, bws, coef, se, t, pv, ci, beta_p_l, beta_p_r, V_cl_l, V_cl_r, V_rb_l, V_rb_r, N, N_h, N_b, M, @@ -282,15 +292,29 @@ def rdrobust_kweight(X, c, h, kernel): return w def rdrobust_res(X, y, T, Z, m, hii, vce, matches, dups, dupsid, d): - + n = len(y) dT = dZ = 0 if T is not None: dT = 1 if Z is not None: - dZ = ncol(Z) + dZ = ncol(Z) res = nanmat(n,1+dT+dZ) if vce=="nn": + if _HAS_NUMBA: + X_flat = np.ascontiguousarray(np.asarray(X).ravel(), dtype=np.float64) + D = np.ascontiguousarray(np.asarray(y).reshape(-1, 1), dtype=np.float64) + if T is not None: + D = np.column_stack((D, np.asarray(T).ravel())) + if Z is not None: + D = np.column_stack((D, np.asarray(Z))) + 
D = np.ascontiguousarray(D, dtype=np.float64) + dups_i = np.ascontiguousarray(np.asarray(dups), dtype=np.int64) + dupsid_i = np.ascontiguousarray(np.asarray(dupsid), dtype=np.int64) + ncols = D.shape[1] + if n > PARALLEL_THRESHOLD: + return _nn_res_parallel(X_flat, D, matches, dups_i, dupsid_i, n, ncols) + return _nn_res_jit(X_flat, D, matches, dups_i, dupsid_i, n, ncols) for pos in range(n): rpos = dups[pos] - dupsid[pos] lpos = dupsid[pos] - 1 diff --git a/benchmark.py b/benchmark.py new file mode 100644 index 0000000..c61f960 --- /dev/null +++ b/benchmark.py @@ -0,0 +1,175 @@ +""" +Benchmark: rdrobust with Numba acceleration vs pure Python. + +Tests correctness (numerical equivalence) and measures speedup +at multiple dataset sizes. +""" + +import numpy as np +import time +import sys +import os + +# Add the source to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Python", "rdrobust", "src")) + +from rdrobust.funs import ( + _HAS_NUMBA, rdrobust_res, rdrobust_kweight, nanmat, ncol, qrXXinv, crossprod, +) + + +def generate_rd_data(n, seed=42): + """Generate synthetic sharp RD data.""" + rng = np.random.default_rng(seed) + x = rng.uniform(-1, 1, n) + y = 3 + 2 * x + 5 * (x >= 0) + rng.normal(0, 1, n) + return x, y + + +def prepare_nn_inputs(x, y, c=0): + """Prepare inputs for rdrobust_res NN path (mimics rdrobust internals).""" + import pandas as pd + + order = np.argsort(x) + x = x[order] + y = y[order] + + mask = x < c + X_l, Y_l = x[mask], y[mask] + X_r, Y_r = x[~mask], y[~mask] + + sides = [] + for X_s, Y_s in [(X_l, Y_l), (X_r, Y_r)]: + n_s = len(X_s) + aux = pd.DataFrame({"nn": np.ones(n_s), "X": X_s}) + dups = aux.groupby("X")["nn"].transform("sum").values.astype(int) + dupsid = aux.groupby("X")["nn"].transform("cumsum").values.astype(int) + sides.append((X_s, Y_s, dups, dupsid, n_s)) + + return sides + + +def time_rdrobust_res(X, y, dups, dupsid, n, matches=3, repeats=3): + """Time rdrobust_res on the NN path.""" + times = [] + for _ in 
range(repeats): + t0 = time.perf_counter() + res = rdrobust_res(X, y, None, None, 0, 0, "nn", matches, dups, dupsid, 2) + t1 = time.perf_counter() + times.append(t1 - t0) + return min(times), res + + +def time_pure_python_res(X, y, dups, dupsid, n, matches=3, repeats=1): + """Time the pure Python NN residuals (bypass Numba).""" + res = nanmat(n, 1) + t0 = time.perf_counter() + for pos in range(n): + rpos = dups[pos] - dupsid[pos] + lpos = dupsid[pos] - 1 + while lpos + rpos < min(matches, n - 1): + if pos - lpos - 1 < 0: + rpos += dups[pos + rpos + 1] + elif pos + rpos + 1 >= n: + lpos += dups[pos - lpos - 1] + elif (X[pos] - X[pos - lpos - 1]) > (X[pos + rpos + 1] - X[pos]): + rpos += dups[pos + rpos + 1] + elif (X[pos] - X[pos - lpos - 1]) < (X[pos + rpos + 1] - X[pos]): + lpos += dups[pos - lpos - 1] + else: + rpos += dups[pos + rpos + 1] + lpos += dups[pos - lpos - 1] + ind_J = np.arange(max(0, pos - lpos), min(n, pos + rpos) + 1) + y_J = sum(y[ind_J]) - y[pos] + Ji = len(ind_J) - 1 + res[pos, 0] = np.sqrt(Ji / (Ji + 1)) * (y[pos] - y_J / Ji) + t1 = time.perf_counter() + return t1 - t0, res + + +def run_full_rdrobust(x, y, repeats=3): + """Time a full rdrobust() call end-to-end.""" + from rdrobust.rdrobust import rdrobust as rd + + times = [] + for _ in range(repeats): + t0 = time.perf_counter() + result = rd(y, x) + t1 = time.perf_counter() + times.append(t1 - t0) + return min(times), result + + +def main(): + print(f"Numba available: {_HAS_NUMBA}") + if _HAS_NUMBA: + from rdrobust._numba_core import nn_residuals + print("Numba detected - JIT compilation on first call\n") + + # ── Warm up Numba JIT ── + print("Warming up Numba JIT...") + x_warm, y_warm = generate_rd_data(1000, seed=0) + sides = prepare_nn_inputs(x_warm, y_warm) + X_s, Y_s, dups, dupsid, n_s = sides[0] + _ = rdrobust_res(X_s, Y_s, None, None, 0, 0, "nn", 3, dups, dupsid, 2) + print("JIT warm-up done.\n") + + # ── Correctness check ── + print("=" * 60) + print("CORRECTNESS CHECK (n=5,000)") + 
print("=" * 60) + x_check, y_check = generate_rd_data(5000, seed=1) + sides = prepare_nn_inputs(x_check, y_check) + X_s, Y_s, dups, dupsid, n_s = sides[0] + + _, res_python = time_pure_python_res(X_s, Y_s, dups, dupsid, n_s) + _, res_numba = time_rdrobust_res(X_s, Y_s, dups, dupsid, n_s) + + max_diff = np.max(np.abs(res_python - res_numba)) + print(f"Max absolute difference: {max_diff:.2e}") + if max_diff < 1e-10: + print("PASS: Results are numerically identical.\n") + else: + print("FAIL: Results differ!\n") + return + + # ── Benchmark rdrobust_res NN path ── + print("=" * 60) + print("BENCHMARK: rdrobust_res (NN path, left side only)") + print("=" * 60) + print(f"{'n':>10} {'Python (s)':>12} {'Numba (s)':>12} {'Speedup':>10}") + print("-" * 50) + + for n in [5_000, 20_000, 50_000, 100_000, 200_000, 500_000]: + x_bench, y_bench = generate_rd_data(n * 2, seed=42) + sides = prepare_nn_inputs(x_bench, y_bench) + X_s, Y_s, dups, dupsid, n_s = sides[0] + + # Numba timing + t_numba, _ = time_rdrobust_res(X_s, Y_s, dups, dupsid, n_s, repeats=3) + + # Python timing (skip for large n - too slow) + if n <= 50_000: + t_python, _ = time_pure_python_res(X_s, Y_s, dups, dupsid, n_s, repeats=1) + speedup = t_python / t_numba + print(f"{n_s:>10,} {t_python:>12.4f} {t_numba:>12.4f} {speedup:>9.0f}x") + else: + # Pure-Python run skipped for large n; report the Numba time only + print(f"{n_s:>10,} {'(skipped)':>12} {t_numba:>12.4f} {'~':>9}") + + # ── End-to-end benchmark ── + print("\n" + "=" * 60) + print("BENCHMARK: Full rdrobust() end-to-end") + print("=" * 60) + print(f"{'n':>10} {'Time (s)':>12}") + print("-" * 25) + + for n in [10_000, 50_000, 100_000, 500_000]: + x_e2e, y_e2e = generate_rd_data(n, seed=42) + t_e2e, result = run_full_rdrobust(x_e2e, y_e2e, repeats=2) + coef = result.coef.iloc[0, 0] + print(f"{n:>10,} {t_e2e:>12.3f} tau={coef:.4f}") + + +if __name__ == "__main__": + main() From b22b55ef00257a93a913eaf97abfbed51ea24b39 Mon Sep 17 00:00:00 2001 From: Sankalp Sharma 
<41304604+sankalpsharmaa@users.noreply.github.com> Date: Fri, 13 Mar 2026 18:36:31 +0530 Subject: [PATCH 2/2] address PR review: lazy import, derive dims from arrays, extend tests - Remove redundant n/ncols params from JIT signatures; derive from X.shape[0] and D.shape[1] inside the kernels - Make Numba import lazy (deferred to first vce=="nn" call) and catch broad Exception instead of just ImportError (handles LLVM/runtime) - Add comment explaining float64 cast rationale in Numba dispatch path - Guard JIT warm-up behind _HAS_NUMBA check in benchmark - Add mass-point correctness test (rounded X with duplicates) - Add fuzzy RD (y + T) and covariate (y + T + Z) correctness checks - Extend prepare_nn_inputs to pass T and Z through to each side --- Python/rdrobust/src/rdrobust/_numba_core.py | 12 +- Python/rdrobust/src/rdrobust/funs.py | 39 ++++--- benchmark.py | 118 +++++++++++++++----- 3 files changed, 122 insertions(+), 47 deletions(-) diff --git a/Python/rdrobust/src/rdrobust/_numba_core.py b/Python/rdrobust/src/rdrobust/_numba_core.py index 65cfadc..673db71 100644 --- a/Python/rdrobust/src/rdrobust/_numba_core.py +++ b/Python/rdrobust/src/rdrobust/_numba_core.py @@ -12,7 +12,7 @@ @njit(cache=True) -def nn_residuals(X, D, matches, dups, dupsid, n, ncols): +def nn_residuals(X, D, matches, dups, dupsid): """ Compute nearest-neighbor residuals for all data columns. @@ -28,15 +28,13 @@ def nn_residuals(X, D, matches, dups, dupsid, n, ncols): Number of duplicate X values at each position. dupsid : int64 array (n,) Cumulative duplicate ID at each position. - n : int - Number of observations. - ncols : int - Number of columns in D. 
Returns ------- res : float64 array (n, ncols) """ + n = X.shape[0] + ncols = D.shape[1] res = np.empty((n, ncols)) target = min(matches, n - 1) @@ -79,8 +77,10 @@ def nn_residuals(X, D, matches, dups, dupsid, n, ncols): @njit(cache=True, parallel=True) -def nn_residuals_parallel(X, D, matches, dups, dupsid, n, ncols): +def nn_residuals_parallel(X, D, matches, dups, dupsid): """Parallel version of nn_residuals using prange.""" + n = X.shape[0] + ncols = D.shape[1] res = np.empty((n, ncols)) target = min(matches, n - 1) diff --git a/Python/rdrobust/src/rdrobust/funs.py b/Python/rdrobust/src/rdrobust/funs.py index 4a03e1a..28cd9cc 100644 --- a/Python/rdrobust/src/rdrobust/funs.py +++ b/Python/rdrobust/src/rdrobust/funs.py @@ -12,15 +12,27 @@ import math from scipy.linalg import qr -try: - from rdrobust._numba_core import ( - nn_residuals as _nn_res_jit, - nn_residuals_parallel as _nn_res_parallel, - PARALLEL_THRESHOLD, - ) - _HAS_NUMBA = True -except ImportError: - _HAS_NUMBA = False +_HAS_NUMBA = None # lazy: None = not yet checked +_nn_res_jit = _nn_res_parallel = PARALLEL_THRESHOLD = None + +def _try_numba(): + """Lazy Numba import. 
Only runs once, on first vce=='nn' call.""" + global _HAS_NUMBA, _nn_res_jit, _nn_res_parallel, PARALLEL_THRESHOLD + if _HAS_NUMBA is not None: + return _HAS_NUMBA + try: + from rdrobust._numba_core import ( + nn_residuals, + nn_residuals_parallel, + PARALLEL_THRESHOLD as pt, + ) + _nn_res_jit = nn_residuals + _nn_res_parallel = nn_residuals_parallel + PARALLEL_THRESHOLD = pt + _HAS_NUMBA = True + except Exception: + _HAS_NUMBA = False + return _HAS_NUMBA class rdrobust_output: def __init__(self, Estimate, bws, coef, se, t, pv, ci, beta_p_l, beta_p_r, @@ -301,7 +313,9 @@ def rdrobust_res(X, y, T, Z, m, hii, vce, matches, dups, dupsid, d): dZ = ncol(Z) res = nanmat(n,1+dT+dZ) if vce=="nn": - if _HAS_NUMBA: + if _try_numba(): + # float64 cast: rdrobust's pipeline always produces float64; + # explicit here for Numba JIT type stability and cache reuse X_flat = np.ascontiguousarray(np.asarray(X).ravel(), dtype=np.float64) D = np.ascontiguousarray(np.asarray(y).reshape(-1, 1), dtype=np.float64) if T is not None: @@ -311,10 +325,9 @@ def rdrobust_res(X, y, T, Z, m, hii, vce, matches, dups, dupsid, d): D = np.ascontiguousarray(D, dtype=np.float64) dups_i = np.ascontiguousarray(np.asarray(dups), dtype=np.int64) dupsid_i = np.ascontiguousarray(np.asarray(dupsid), dtype=np.int64) - ncols = D.shape[1] if n > PARALLEL_THRESHOLD: - return _nn_res_parallel(X_flat, D, matches, dups_i, dupsid_i, n, ncols) - return _nn_res_jit(X_flat, D, matches, dups_i, dupsid_i, n, ncols) + return _nn_res_parallel(X_flat, D, matches, dups_i, dupsid_i) + return _nn_res_jit(X_flat, D, matches, dups_i, dupsid_i) for pos in range(n): rpos = dups[pos] - dupsid[pos] lpos = dupsid[pos] - 1 diff --git a/benchmark.py b/benchmark.py index c61f960..e2ef3f7 100644 --- a/benchmark.py +++ b/benchmark.py @@ -14,7 +14,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Python", "rdrobust", "src")) from rdrobust.funs import ( - _HAS_NUMBA, rdrobust_res, rdrobust_kweight, nanmat, ncol, qrXXinv, 
crossprod, + _try_numba, rdrobust_res, rdrobust_kweight, nanmat, ncol, qrXXinv, crossprod, ) @@ -26,29 +26,50 @@ def generate_rd_data(n, seed=42): return x, y -def prepare_nn_inputs(x, y, c=0): +def generate_rd_data_with_masspoints(n, seed=42): + """Generate RD data with duplicate running-variable values (mass points).""" + rng = np.random.default_rng(seed) + x = np.round(rng.uniform(-1, 1, n), decimals=1) + y = 3 + 2 * x + 5 * (x >= 0) + rng.normal(0, 1, n) + return x, y + + +def prepare_nn_inputs(x, y, T=None, Z=None, c=0): """Prepare inputs for rdrobust_res NN path (mimics rdrobust internals).""" import pandas as pd order = np.argsort(x) x = x[order] y = y[order] + T_sorted = T[order] if T is not None else None + Z_sorted = Z[order] if Z is not None else None mask = x < c - X_l, Y_l = x[mask], y[mask] - X_r, Y_r = x[~mask], y[~mask] - sides = [] - for X_s, Y_s in [(X_l, Y_l), (X_r, Y_r)]: + for side_mask in [mask, ~mask]: + X_s = x[side_mask] + Y_s = y[side_mask] + T_s = T_sorted[side_mask] if T_sorted is not None else None + Z_s = Z_sorted[side_mask] if Z_sorted is not None else None n_s = len(X_s) aux = pd.DataFrame({"nn": np.ones(n_s), "X": X_s}) dups = aux.groupby("X")["nn"].transform("sum").values.astype(int) dupsid = aux.groupby("X")["nn"].transform("cumsum").values.astype(int) - sides.append((X_s, Y_s, dups, dupsid, n_s)) + sides.append((X_s, Y_s, T_s, Z_s, dups, dupsid, n_s)) return sides +def rdrobust_res_no_numba(X, y, T, Z, matches, dups, dupsid, d): + """Force pure Python path by temporarily disabling Numba.""" + import rdrobust.funs as _funs + saved = _funs._HAS_NUMBA + _funs._HAS_NUMBA = False + res = rdrobust_res(X, y, T, Z, 0, 0, "nn", matches, dups, dupsid, d) + _funs._HAS_NUMBA = saved + return res + + def time_rdrobust_res(X, y, dups, dupsid, n, matches=3, repeats=3): """Time rdrobust_res on the NN path.""" times = [] @@ -100,37 +121,78 @@ def run_full_rdrobust(x, y, repeats=3): return min(times), result +def check_correctness(label, 
res_python, res_numba): + """Compare Python and Numba results, print pass/fail.""" + max_diff = np.max(np.abs(res_python - res_numba)) + print(f" Max absolute difference: {max_diff:.2e}") + if max_diff < 1e-10: + print(" PASS\n") + return True + print(" FAIL\n") + return False + + def main(): - print(f"Numba available: {_HAS_NUMBA}") - if _HAS_NUMBA: - from rdrobust._numba_core import nn_residuals + has_numba = _try_numba() + print(f"Numba available: {has_numba}") + + if has_numba: print("Numba detected - JIT compilation on first call\n") - # ── Warm up Numba JIT ── - print("Warming up Numba JIT...") - x_warm, y_warm = generate_rd_data(1000, seed=0) - sides = prepare_nn_inputs(x_warm, y_warm) - X_s, Y_s, dups, dupsid, n_s = sides[0] - _ = rdrobust_res(X_s, Y_s, None, None, 0, 0, "nn", 3, dups, dupsid, 2) - print("JIT warm-up done.\n") + # ── Warm up Numba JIT ── + print("Warming up Numba JIT...") + x_warm, y_warm = generate_rd_data(1000, seed=0) + sides = prepare_nn_inputs(x_warm, y_warm) + X_s, Y_s, _, _, dups, dupsid, n_s = sides[0] + _ = rdrobust_res(X_s, Y_s, None, None, 0, 0, "nn", 3, dups, dupsid, 2) + print("JIT warm-up done.\n") - # ── Correctness check ── + # ── Correctness checks ── print("=" * 60) - print("CORRECTNESS CHECK (n=5,000)") + print("CORRECTNESS CHECKS") print("=" * 60) + + # 1. Sharp RD, continuous X + print("\n1. 
Sharp RD, continuous X (n=5,000)") x_check, y_check = generate_rd_data(5000, seed=1) sides = prepare_nn_inputs(x_check, y_check) - X_s, Y_s, dups, dupsid, n_s = sides[0] - + X_s, Y_s, _, _, dups, dupsid, n_s = sides[0] _, res_python = time_pure_python_res(X_s, Y_s, dups, dupsid, n_s) _, res_numba = time_rdrobust_res(X_s, Y_s, dups, dupsid, n_s) + if not check_correctness("sharp", res_python, res_numba): + return - max_diff = np.max(np.abs(res_python - res_numba)) - print(f"Max absolute difference: {max_diff:.2e}") - if max_diff < 1e-10: - print("PASS: Results are numerically identical.\n") - else: - print("FAIL: Results differ!\n") + # 2. Sharp RD with mass points (duplicate X values) + print("2. Sharp RD, mass points (n=5,000)") + x_mp, y_mp = generate_rd_data_with_masspoints(5000, seed=1) + sides_mp = prepare_nn_inputs(x_mp, y_mp) + X_s, Y_s, _, _, dups, dupsid, n_s = sides_mp[0] + _, res_py_mp = time_pure_python_res(X_s, Y_s, dups, dupsid, n_s) + _, res_nb_mp = time_rdrobust_res(X_s, Y_s, dups, dupsid, n_s) + if not check_correctness("mass points", res_py_mp, res_nb_mp): + return + + # 3. Fuzzy RD (y + T) + print("3. Fuzzy RD (n=5,000)") + rng = np.random.default_rng(1) + x_fz = rng.uniform(-1, 1, 5000) + T_fz = ((x_fz >= 0).astype(float) + rng.binomial(1, 0.1, 5000)).clip(0, 1) + y_fz = 3 + 2 * x_fz + 1.5 * T_fz + rng.normal(0, 1, 5000) + sides_fz = prepare_nn_inputs(x_fz, y_fz, T=T_fz) + X_s, Y_s, T_s, _, dups, dupsid, n_s = sides_fz[0] + res_py_fz = rdrobust_res_no_numba(X_s, Y_s, T_s, None, 3, dups, dupsid, 2) + res_nb_fz = rdrobust_res(X_s, Y_s, T_s, None, 0, 0, "nn", 3, dups, dupsid, 2) + if not check_correctness("fuzzy", res_py_fz, res_nb_fz): + return + + # 4. With covariates (y + T + Z) + print("4. 
With covariates (n=5,000)") + Z_fz = rng.normal(0, 1, (5000, 2)) + sides_cov = prepare_nn_inputs(x_fz, y_fz, T=T_fz, Z=Z_fz) + X_s, Y_s, T_s, Z_s, dups, dupsid, n_s = sides_cov[0] + res_py_cov = rdrobust_res_no_numba(X_s, Y_s, T_s, Z_s, 3, dups, dupsid, 2) + res_nb_cov = rdrobust_res(X_s, Y_s, T_s, Z_s, 0, 0, "nn", 3, dups, dupsid, 2) + if not check_correctness("covariates", res_py_cov, res_nb_cov): return # ── Benchmark rdrobust_res NN path ── @@ -143,7 +205,7 @@ def main(): for n in [5_000, 20_000, 50_000, 100_000, 200_000, 500_000]: x_bench, y_bench = generate_rd_data(n * 2, seed=42) sides = prepare_nn_inputs(x_bench, y_bench) - X_s, Y_s, dups, dupsid, n_s = sides[0] + X_s, Y_s, _, _, dups, dupsid, n_s = sides[0] # Numba timing t_numba, _ = time_rdrobust_res(X_s, Y_s, dups, dupsid, n_s, repeats=3)