-
Notifications
You must be signed in to change notification settings - Fork 41
Add Numba JIT acceleration for nearest-neighbor VCE (260x speedup) #11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,126 @@ | ||
| """ | ||
| Numba-accelerated nearest-neighbor residual computation for rdrobust. | ||
|
|
||
| Replaces the O(n) Python loop in the NN VCE path with JIT-compiled code. | ||
| The NN residual computation is the dominant bottleneck in rdrobust, called | ||
| 7-19 times per estimation. Each call loops over all n observations in | ||
| pure Python. This module provides ~100x speedup on that hot path. | ||
| """ | ||
|
|
||
| import numpy as np | ||
| from numba import njit, prange | ||
|
|
||
|
|
||
@njit(cache=True)
def nn_residuals(X, D, matches, dups, dupsid):
    """
    Nearest-neighbor residuals for every column of D (serial kernel).

    For each observation i, grows a symmetric-by-distance window around
    X[i] until it contains at least ``min(matches, n - 1)`` neighbors
    (tied X values are absorbed as whole groups), then computes the
    scaled leave-one-out residual for each data column.

    Parameters
    ----------
    X : float64 array (n,)
        Running variable, sorted ascending.
    D : float64 array (n, ncols)
        Stacked data columns: [y] or [y, T] or [y, T, Z1, Z2, ...].
    matches : int
        Minimum number of neighbors (nnmatch, default 3).
    dups : int64 array (n,)
        Number of duplicate X values at each position.
    dupsid : int64 array (n,)
        Cumulative duplicate ID at each position (1-based within group).

    Returns
    -------
    res : float64 array (n, ncols)
    """
    n = X.shape[0]
    width = D.shape[1]
    out = np.empty((n, width))
    need = min(matches, n - 1)

    for i in range(n):
        # Seed the window with the exact ties on each side of X[i].
        nr = dups[i] - dupsid[i]
        nl = dupsid[i] - 1

        # Expand until the window holds at least `need` neighbors.
        while nl + nr < need:
            il = i - nl - 1   # next candidate index on the left
            ir = i + nr + 1   # next candidate index on the right
            if il < 0:
                nr += dups[ir]
            elif ir >= n:
                nl += dups[il]
            else:
                gap_l = X[i] - X[il]
                gap_r = X[ir] - X[i]
                if gap_l > gap_r:
                    nr += dups[ir]
                elif gap_r > gap_l:
                    nl += dups[il]
                else:
                    # Equidistant on both sides: absorb both tie groups.
                    nr += dups[ir]
                    nl += dups[il]

        start = max(0, i - nl)
        stop = min(n, i + nr) + 1   # exclusive upper bound
        j = (stop - start) - 1      # neighbor count, excluding i itself
        adj = np.sqrt(j / (j + 1.0))

        for c in range(width):
            # Sum neighbors in column c, leaving observation i out.
            total = 0.0
            for k in range(start, stop):
                total += D[k, c]
            total -= D[i, c]
            out[i, c] = adj * (D[i, c] - total / j)

    return out
|
|
||
|
|
||
@njit(cache=True, parallel=True)
def nn_residuals_parallel(X, D, matches, dups, dupsid):
    """Parallel twin of ``nn_residuals`` using ``prange``.

    Same inputs and outputs as the serial kernel.  Each prange iteration
    reads shared arrays and writes only its own output row, so the loop
    is race-free.
    """
    n = X.shape[0]
    width = D.shape[1]
    out = np.empty((n, width))
    need = min(matches, n - 1)

    for i in prange(n):
        # Seed the window with the exact ties on each side of X[i].
        nr = dups[i] - dupsid[i]
        nl = dupsid[i] - 1

        # Expand until the window holds at least `need` neighbors.
        while nl + nr < need:
            il = i - nl - 1   # next candidate index on the left
            ir = i + nr + 1   # next candidate index on the right
            if il < 0:
                nr += dups[ir]
            elif ir >= n:
                nl += dups[il]
            else:
                gap_l = X[i] - X[il]
                gap_r = X[ir] - X[i]
                if gap_l > gap_r:
                    nr += dups[ir]
                elif gap_r > gap_l:
                    nl += dups[il]
                else:
                    # Equidistant on both sides: absorb both tie groups.
                    nr += dups[ir]
                    nl += dups[il]

        start = max(0, i - nl)
        stop = min(n, i + nr) + 1   # exclusive upper bound
        j = (stop - start) - 1      # neighbor count, excluding i itself
        adj = np.sqrt(j / (j + 1.0))

        for c in range(width):
            # Sum neighbors in column c, leaving observation i out.
            total = 0.0
            for k in range(start, stop):
                total += D[k, c]
            total -= D[i, c]
            out[i, c] = adj * (D[i, c] - total / j)

    return out
|
|
||
|
|
||
# Row count above which callers should dispatch to nn_residuals_parallel
# instead of the serial kernel; below it, prange thread start-up overhead
# presumably outweighs the parallel gain (tuning basis not shown here —
# TODO confirm with a benchmark on target hardware).
PARALLEL_THRESHOLD = 50_000
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,6 +12,28 @@ | |
| import math | ||
| from scipy.linalg import qr | ||
|
|
||
# Tri-state availability cache: None = not probed yet, True/False = probe
# outcome.  The kernel references below are filled in on a successful probe.
_HAS_NUMBA = None
_nn_res_jit = _nn_res_parallel = PARALLEL_THRESHOLD = None


def _try_numba():
    """Probe for the optional Numba kernels, at most once.

    Lazily imports ``rdrobust._numba_core`` on the first call, caches the
    JIT entry points and the parallel-dispatch threshold in module
    globals, and memoizes the outcome so every later call is a cheap
    flag check.  Returns True when the accelerated path is usable.
    """
    global _HAS_NUMBA, _nn_res_jit, _nn_res_parallel, PARALLEL_THRESHOLD
    if _HAS_NUMBA is None:
        try:
            from rdrobust._numba_core import (
                nn_residuals,
                nn_residuals_parallel,
                PARALLEL_THRESHOLD as pt,
            )
        except Exception:
            # Numba (or the compiled module) is unavailable: remember the
            # failure so the pure-Python fallback is used from now on.
            _HAS_NUMBA = False
        else:
            _nn_res_jit = nn_residuals
            _nn_res_parallel = nn_residuals_parallel
            PARALLEL_THRESHOLD = pt
            _HAS_NUMBA = True
    return _HAS_NUMBA
|
|
||
| class rdrobust_output: | ||
| def __init__(self, Estimate, bws, coef, se, t, pv, ci, beta_p_l, beta_p_r, | ||
| V_cl_l, V_cl_r, V_rb_l, V_rb_r, N, N_h, N_b, M, | ||
|
|
@@ -282,15 +304,30 @@ def rdrobust_kweight(X, c, h, kernel): | |
| return w | ||
|
|
||
| def rdrobust_res(X, y, T, Z, m, hii, vce, matches, dups, dupsid, d): | ||
|
|
||
| n = len(y) | ||
| dT = dZ = 0 | ||
| if T is not None: | ||
| dT = 1 | ||
| if Z is not None: | ||
| dZ = ncol(Z) | ||
| dZ = ncol(Z) | ||
| res = nanmat(n,1+dT+dZ) | ||
| if vce=="nn": | ||
| if _try_numba(): | ||
| # float64 cast: rdrobust's pipeline always produces float64; | ||
| # explicit here for Numba JIT type stability and cache reuse | ||
| X_flat = np.ascontiguousarray(np.asarray(X).ravel(), dtype=np.float64) | ||
| D = np.ascontiguousarray(np.asarray(y).reshape(-1, 1), dtype=np.float64) | ||
| if T is not None: | ||
| D = np.column_stack((D, np.asarray(T).ravel())) | ||
| if Z is not None: | ||
| D = np.column_stack((D, np.asarray(Z))) | ||
| D = np.ascontiguousarray(D, dtype=np.float64) | ||
|
Comment on lines
+319
to
+325
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good catch. The explicit |
||
| dups_i = np.ascontiguousarray(np.asarray(dups), dtype=np.int64) | ||
| dupsid_i = np.ascontiguousarray(np.asarray(dupsid), dtype=np.int64) | ||
| if n > PARALLEL_THRESHOLD: | ||
| return _nn_res_parallel(X_flat, D, matches, dups_i, dupsid_i) | ||
| return _nn_res_jit(X_flat, D, matches, dups_i, dupsid_i) | ||
| for pos in range(n): | ||
| rpos = dups[pos] - dupsid[pos] | ||
| lpos = dupsid[pos] - 1 | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in b22b55e. Removed `n` and `ncols` from both JIT signatures. They're now derived from `X.shape[0]` and `D.shape[1]` at the top of each kernel.