Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 25 additions & 17 deletions pyfixest/estimation/ritest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import pandas as pd
import seaborn as sns
from joblib import Parallel, delayed
from lets_plot import (
LetsPlot,
aes,
Expand All @@ -24,7 +25,6 @@

LetsPlot.setup_html()


def _get_ritest_stats_slow(
data: pd.DataFrame,
resampvar: str,
Expand Down Expand Up @@ -77,28 +77,36 @@ def _get_ritest_stats_slow(
fit_ = getattr(fixest_module, model)

resampvar_arr = data_resampled[resampvar].to_numpy()

ri_stats = np.zeros(reps)

for i in tqdm(range(reps)):
D_treat = _resample(
resampvar_arr=resampvar_arr,
clustervar_arr=clustervar_arr,
rng=rng,
iterations=1,
).flatten()

data_resampled[f"{resampvar}_resampled"] = D_treat
results = Parallel(n_jobs=-1)(
delayed(lambda: (
# Create resampled treatment values
D_treat := _resample(
resampvar_arr=resampvar_arr,
clustervar_arr=clustervar_arr,
rng=rng,
iterations=1,
).flatten(),

# Add values to data
data_resampled.__setitem__(f"{resampvar}_resampled", D_treat),
fixest_fit := fit_(fml_update, data=data_resampled, vcov=vcov),

# Return appropriate statistic
fixest_fit.coef().xs(f"{resampvar}_resampled")
if type == "randomization-c"
else fixest_fit.tstat().xs(f"{resampvar}_resampled")
)[3])()
for _ in tqdm(range(reps)) # We use _ since we don't actually need the index
)

fixest_fit = fit_(fml_update, data=data_resampled, vcov=vcov)
if type == "randomization-c":
ri_stats[i] = fixest_fit.coef().xs(f"{resampvar}_resampled")
else:
ri_stats[i] = fixest_fit.tstat().xs(f"{resampvar}_resampled")
# Fill out the results array
for i, result in enumerate(results):
ri_stats[i] = result

return ri_stats


def _get_ritest_stats_fast(
Y: np.ndarray,
X: np.ndarray,
Expand Down