Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions Python/rdrobust/src/rdrobust/_numba_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""
Numba-accelerated nearest-neighbor residual computation for rdrobust.

Replaces the O(n) Python loop in the NN VCE path with JIT-compiled code.
The NN residual computation is the dominant bottleneck in rdrobust, called
7-19 times per estimation. Each call loops over all n observations in
pure Python. This module provides ~100x speedup on that hot path.
"""

import numpy as np
from numba import njit, prange


@njit(cache=True)
def nn_residuals(X, D, matches, dups, dupsid):
"""
Comment on lines +14 to +16
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in b22b55e. Removed n and ncols from both JIT signatures. They're now derived from X.shape[0] and D.shape[1] at the top of each kernel.

Compute nearest-neighbor residuals for all data columns.

Parameters
----------
X : float64 array (n,)
Running variable, sorted.
D : float64 array (n, ncols)
Stacked data columns: [y] or [y, T] or [y, T, Z1, Z2, ...].
matches : int
Minimum number of neighbors (nnmatch, default 3).
dups : int64 array (n,)
Number of duplicate X values at each position.
dupsid : int64 array (n,)
Cumulative duplicate ID at each position.

Returns
-------
res : float64 array (n, ncols)
"""
n = X.shape[0]
ncols = D.shape[1]
res = np.empty((n, ncols))
target = min(matches, n - 1)

for pos in range(n):
rpos = dups[pos] - dupsid[pos]
lpos = dupsid[pos] - 1

while lpos + rpos < target:
left_ok = pos - lpos - 1 >= 0
right_ok = pos + rpos + 1 < n

if not left_ok:
rpos += dups[pos + rpos + 1]
elif not right_ok:
lpos += dups[pos - lpos - 1]
else:
d_left = X[pos] - X[pos - lpos - 1]
d_right = X[pos + rpos + 1] - X[pos]
if d_left > d_right:
rpos += dups[pos + rpos + 1]
elif d_left < d_right:
lpos += dups[pos - lpos - 1]
else:
rpos += dups[pos + rpos + 1]
lpos += dups[pos - lpos - 1]

lo = max(0, pos - lpos)
hi = min(n, pos + rpos) + 1 # exclusive
Ji = (hi - lo) - 1
scale = np.sqrt(Ji / (Ji + 1.0))

for col in range(ncols):
col_sum = 0.0
for k in range(lo, hi):
col_sum += D[k, col]
col_sum -= D[pos, col]
res[pos, col] = scale * (D[pos, col] - col_sum / Ji)

return res


@njit(cache=True, parallel=True)
def nn_residuals_parallel(X, D, matches, dups, dupsid):
"""Parallel version of nn_residuals using prange."""
n = X.shape[0]
ncols = D.shape[1]
res = np.empty((n, ncols))
target = min(matches, n - 1)

for pos in prange(n):
rpos = dups[pos] - dupsid[pos]
lpos = dupsid[pos] - 1

while lpos + rpos < target:
left_ok = pos - lpos - 1 >= 0
right_ok = pos + rpos + 1 < n

if not left_ok:
rpos += dups[pos + rpos + 1]
elif not right_ok:
lpos += dups[pos - lpos - 1]
else:
d_left = X[pos] - X[pos - lpos - 1]
d_right = X[pos + rpos + 1] - X[pos]
if d_left > d_right:
rpos += dups[pos + rpos + 1]
elif d_left < d_right:
lpos += dups[pos - lpos - 1]
else:
rpos += dups[pos + rpos + 1]
lpos += dups[pos - lpos - 1]

lo = max(0, pos - lpos)
hi = min(n, pos + rpos) + 1
Ji = (hi - lo) - 1
scale = np.sqrt(Ji / (Ji + 1.0))

for col in range(ncols):
col_sum = 0.0
for k in range(lo, hi):
col_sum += D[k, col]
col_sum -= D[pos, col]
res[pos, col] = scale * (D[pos, col] - col_sum / Ji)

return res


# Threshold for switching to parallel version
PARALLEL_THRESHOLD = 50_000
41 changes: 39 additions & 2 deletions Python/rdrobust/src/rdrobust/funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,28 @@
import math
from scipy.linalg import qr

_HAS_NUMBA = None # lazy: None = not yet checked
_nn_res_jit = _nn_res_parallel = PARALLEL_THRESHOLD = None

def _try_numba():
"""Lazy Numba import. Only runs once, on first vce=='nn' call."""
global _HAS_NUMBA, _nn_res_jit, _nn_res_parallel, PARALLEL_THRESHOLD
if _HAS_NUMBA is not None:
return _HAS_NUMBA
try:
from rdrobust._numba_core import (
nn_residuals,
nn_residuals_parallel,
PARALLEL_THRESHOLD as pt,
)
_nn_res_jit = nn_residuals
_nn_res_parallel = nn_residuals_parallel
PARALLEL_THRESHOLD = pt
_HAS_NUMBA = True
except Exception:
_HAS_NUMBA = False
return _HAS_NUMBA

class rdrobust_output:
def __init__(self, Estimate, bws, coef, se, t, pv, ci, beta_p_l, beta_p_r,
V_cl_l, V_cl_r, V_rb_l, V_rb_r, N, N_h, N_b, M,
Expand Down Expand Up @@ -282,15 +304,30 @@ def rdrobust_kweight(X, c, h, kernel):
return w

def rdrobust_res(X, y, T, Z, m, hii, vce, matches, dups, dupsid, d):

n = len(y)
dT = dZ = 0
if T is not None:
dT = 1
if Z is not None:
dZ = ncol(Z)
dZ = ncol(Z)
res = nanmat(n,1+dT+dZ)
if vce=="nn":
if _try_numba():
# float64 cast: rdrobust's pipeline always produces float64;
# explicit here for Numba JIT type stability and cache reuse
X_flat = np.ascontiguousarray(np.asarray(X).ravel(), dtype=np.float64)
D = np.ascontiguousarray(np.asarray(y).reshape(-1, 1), dtype=np.float64)
if T is not None:
D = np.column_stack((D, np.asarray(T).ravel()))
if Z is not None:
D = np.column_stack((D, np.asarray(Z)))
D = np.ascontiguousarray(D, dtype=np.float64)
Comment on lines +319 to +325
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. The explicit float64 cast is intentional for Numba JIT type stability and cache reuse (avoids recompilation per dtype). In practice rdrobust's internal pipeline always produces float64 arrays, so the two paths are equivalent. Added a comment at the cast site explaining the rationale.

dups_i = np.ascontiguousarray(np.asarray(dups), dtype=np.int64)
dupsid_i = np.ascontiguousarray(np.asarray(dupsid), dtype=np.int64)
if n > PARALLEL_THRESHOLD:
return _nn_res_parallel(X_flat, D, matches, dups_i, dupsid_i)
return _nn_res_jit(X_flat, D, matches, dups_i, dupsid_i)
for pos in range(n):
rpos = dups[pos] - dupsid[pos]
lpos = dupsid[pos] - 1
Expand Down
Loading