229 changes: 148 additions & 81 deletions duckreg/estimators.py
@@ -53,37 +53,43 @@ def prepare_data(self):
def compress_data(self):
# Pre-compute expressions once to avoid repeated string operations
group_by_cols = ", ".join(self.strata_cols)

# Build aggregation expressions more efficiently
agg_parts = ["COUNT(*) as count"]
sum_expressions = []
sum_sq_expressions = []

for var in self.outcome_vars:
sum_expr = f"SUM({var}) as sum_{var}"
sum_sq_expr = f"SUM(POW({var}, 2)) as sum_{var}_sq"
sum_expressions.append(sum_expr)
sum_sq_expressions.append(sum_sq_expr)

# Single join operation instead of multiple concatenations
all_agg_expressions = ", ".join(agg_parts + sum_expressions + sum_sq_expressions)

all_agg_expressions = ", ".join(
agg_parts + sum_expressions + sum_sq_expressions
)

self.agg_query = f"""
SELECT {group_by_cols}, {all_agg_expressions}
FROM {self.table_name}
GROUP BY {group_by_cols}
"""

self.df_compressed = self.conn.execute(self.agg_query).fetchdf()

# Pre-compute column lists
sum_cols = [f"sum_{var}" for var in self.outcome_vars]
sum_sq_cols = [f"sum_{var}_sq" for var in self.outcome_vars]

self.df_compressed.columns = self.strata_cols + ["count"] + sum_cols + sum_sq_cols


self.df_compressed.columns = (
self.strata_cols + ["count"] + sum_cols + sum_sq_cols
)

# Single eval operation for all means
mean_expressions = [f"mean_{var} = sum_{var}/count" for var in self.outcome_vars]
mean_expressions = [
f"mean_{var} = sum_{var}/count" for var in self.outcome_vars
]
if mean_expressions:
self.df_compressed.eval("\n".join(mean_expressions), inplace=True)
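
For context on what this compression buys: the per-stratum count, sum, and sum of squares are sufficient statistics, so the group mean (and, presumably, a variance term downstream) can be recovered from the compressed frame alone. A minimal sketch of that identity, using a toy array rather than anything from this diff:

import numpy as np

y = np.array([1.0, 2.0, 4.0, 7.0])            # outcomes within one stratum
count, s, ss = len(y), y.sum(), (y ** 2).sum()

mean = s / count                              # what eval("mean_y = sum_y/count") computes
var = ss / count - mean ** 2                  # variance recovered from the same aggregates

assert np.isclose(mean, y.mean())
assert np.isclose(var, y.var())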

@@ -153,29 +159,43 @@ def bootstrap(self):
total_rows = self.conn.execute(
f"SELECT COUNT(DISTINCT {self.rowid_col}) FROM {self.table_name}"
).fetchone()[0]
unique_rows = total_rows
# unique_rows = total_rows
unique_groups = np.arange(total_rows) # Add this line
self.bootstrap_query = f"""
SELECT {", ".join(self.strata_cols)}, {", ".join(["COUNT(*) as count"] + [f"SUM({var}) as sum_{var}" for var in self.outcome_vars])}
FROM {self.table_name}
GROUP BY {", ".join(self.strata_cols)}
"""
else:
# Cluster bootstrap
# Cluster bootstrap - FIX
unique_groups = self.conn.execute(
f"SELECT DISTINCT {self.cluster_col} FROM {self.table_name}"
).fetchall()
unique_groups = [group[0] for group in unique_groups]
unique_rows = len(unique_groups)
self.bootstrap_query = f"""
SELECT {", ".join(self.strata_cols)}, {", ".join(["COUNT(*) as count"] + [f"SUM({var}) as sum_{var}" for var in self.outcome_vars])}
FROM {self.table_name}
WHERE {self.cluster_col} IN (SELECT unnest((?)))
WITH resampled AS (
SELECT cluster_id, COUNT(*) as mult
FROM (SELECT unnest(?) as cluster_id)
GROUP BY cluster_id
),
grouped_data AS (
SELECT {", ".join(self.strata_cols)}, {self.cluster_col},
COUNT(*) as count,
{", ".join([f"SUM({var}) as sum_{var}" for var in self.outcome_vars])}
FROM {self.table_name}
GROUP BY {", ".join(self.strata_cols)}, {self.cluster_col}
)
SELECT {", ".join(self.strata_cols)},
SUM(gd.count * r.mult) as count,
{", ".join([f"SUM(gd.sum_{var} * r.mult) as sum_{var}" for var in self.outcome_vars])}
FROM grouped_data gd
JOIN resampled r ON gd.{self.cluster_col} = r.cluster_id
GROUP BY {", ".join(self.strata_cols)}
"""

for b in tqdm(range(self.n_bootstraps)):
resampled_rows = self.rng.choice(
unique_rows, size=unique_rows, replace=True
unique_groups, size=len(unique_groups), replace=True
)
df_boot = pd.DataFrame(
self.conn.execute(
Expand Down Expand Up @@ -271,6 +291,7 @@ def prepare_data(self):
SELECT
t.{self.unit_col},
{f"t.{self.time_col}," if self.time_col is not None else ""}
{f"t.{self.cluster_col}," if self.cluster_col and self.cluster_col != self.unit_col else ""}
t.{self.outcome_var},
{", ".join([f"t.{cov}" for cov in self.covariates])},
{", ".join([f"u.avg_{cov}_unit" for cov in self.covariates])}
@@ -285,13 +306,17 @@ def compress_data(self):
# Pre-compute column lists to avoid repeated operations
cov_cols = [f"{cov}" for cov in self.covariates]
unit_avg_cols = [f"avg_{cov}_unit" for cov in self.covariates]
time_avg_cols = [f"avg_{cov}_time" for cov in self.covariates] if self.time_col is not None else []

time_avg_cols = (
[f"avg_{cov}_time" for cov in self.covariates]
if self.time_col is not None
else []
)

# Build SELECT and GROUP BY columns once
select_cols = cov_cols + unit_avg_cols + time_avg_cols
select_clause = ", ".join(select_cols)
group_by_clause = ", ".join(select_cols)

self.compress_query = f"""
SELECT
{select_clause},
@@ -363,27 +388,48 @@ def bootstrap(self):
total_samples = total_units
else:
# Cluster bootstrap
total_clusters = self.conn.execute(
f"SELECT COUNT(DISTINCT {self.cluster_col}) FROM {self.table_name}"
).fetchone()[0]
unique_clusters = self.conn.execute(
f"SELECT DISTINCT {self.cluster_col} FROM {self.table_name}"
).fetchall()
unique_clusters = [c[0] for c in unique_clusters]

self.bootstrap_query = f"""
WITH resampled AS (
SELECT cluster_id, COUNT(*) as mult
FROM (SELECT unnest(?) as cluster_id)
GROUP BY cluster_id
),
grouped_data AS (
SELECT
{", ".join([f"{cov}" for cov in self.covariates])},
{", ".join([f"avg_{cov}_unit" for cov in self.covariates])}
{", " + ", ".join([f"avg_{cov}_time" for cov in self.covariates]) if self.time_col is not None else ""},
{self.cluster_col},
COUNT(*) as count,
SUM({self.outcome_var}) as sum_{self.outcome_var}
FROM design_matrix
GROUP BY {", ".join([f"{cov}" for cov in self.covariates])},
{", ".join([f"avg_{cov}_unit" for cov in self.covariates])}
{", " + ", ".join([f"avg_{cov}_time" for cov in self.covariates]) if self.time_col is not None else ""},
{self.cluster_col}
)
SELECT
{", ".join([f"{cov}" for cov in self.covariates])},
{", ".join([f"avg_{cov}_unit" for cov in self.covariates])}
{", " + ", ".join([f"avg_{cov}_time" for cov in self.covariates]) if self.time_col is not None else ""},
COUNT(*) as count,
SUM({self.outcome_var}) as sum_{self.outcome_var}
FROM design_matrix
WHERE {self.cluster_col} IN (SELECT unnest((?)))
SUM(gd.count * r.mult) as count,
SUM(gd.sum_{self.outcome_var} * r.mult) as sum_{self.outcome_var}
FROM grouped_data gd
JOIN resampled r ON gd.{self.cluster_col} = r.cluster_id
GROUP BY {", ".join([f"{cov}" for cov in self.covariates])},
{", ".join([f"avg_{cov}_unit" for cov in self.covariates])}
{", " + ", ".join([f"avg_{cov}_time" for cov in self.covariates]) if self.time_col is not None else ""}
{", ".join([f"avg_{cov}_unit" for cov in self.covariates])}
{", " + ", ".join([f"avg_{cov}_time" for cov in self.covariates]) if self.time_col is not None else ""}
"""
total_samples = total_clusters
total_samples = unique_clusters

for b in tqdm(range(self.n_bootstraps)):
resampled_samples = self.rng.choice(
total_samples, size=total_samples, replace=True
total_samples, size=len(total_samples), replace=True
)
df_boot = self.conn.execute(
self.bootstrap_query, [resampled_samples.tolist()]
Expand Down Expand Up @@ -495,6 +541,7 @@ def prepare_data(self):
p.{self.time_col},
p.{self.treatment_col},
p.{self.outcome_var},
{f"p.{self.cluster_col}," if self.cluster_col != self.unit_col else ""}
-- Intercept (constant term)
1 AS intercept,
-- cohort intercepts
@@ -511,11 +558,15 @@ def compress_data(self):
# Pre-compute RHS columns to avoid repeated string operations
cohort_cols = [f"cohort_{cohort}" for cohort in self.cohorts]
time_cols = [f"time_{i}" for i in range(self.num_periods + 1)]
treatment_cols = [f"treatment_time_{cohort}_{i}" for cohort in self.cohorts for i in range(self.num_periods + 1)]

treatment_cols = [
f"treatment_time_{cohort}_{i}"
for cohort in self.cohorts
for i in range(self.num_periods + 1)
]

rhs_cols = ["intercept"] + cohort_cols + time_cols + treatment_cols
rhs_clause = ", ".join(rhs_cols)

# Use single query with CTE instead of temp table
self.compression_query = f"""
{self.design_matrix_cte}
Expand All @@ -526,12 +577,12 @@ def compress_data(self):
FROM transformed_panel_data
GROUP BY {rhs_clause}
"""

self.df_compressed = self.conn.execute(self.compression_query).fetchdf()
self.df_compressed[f"mean_{self.outcome_var}"] = (
self.df_compressed[f"sum_{self.outcome_var}"] / self.df_compressed["count"]
)

# Store for later use
self.rhs_cols = rhs_cols
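
To make the generated right-hand side concrete, here is what rhs_cols expands to for illustrative cohort and period values (the numbers below are assumptions, not from the diff):

cohorts, num_periods = [2004, 2006], 2

cohort_cols = [f"cohort_{c}" for c in cohorts]
time_cols = [f"time_{i}" for i in range(num_periods + 1)]
treatment_cols = [f"treatment_time_{c}_{i}" for c in cohorts for i in range(num_periods + 1)]

rhs_cols = ["intercept"] + cohort_cols + time_cols + treatment_cols
# ['intercept', 'cohort_2004', 'cohort_2006', 'time_0', 'time_1', 'time_2',
#  'treatment_time_2004_0', ..., 'treatment_time_2006_2']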

@@ -566,46 +617,49 @@ def estimate(self):

def bootstrap(self):
# list all clusters
total_clusters = self.conn.execute(
f"SELECT COUNT(DISTINCT {self.cluster_col}) FROM transformed_panel_data"
).fetchone()[0]
unique_clusters = self.conn.execute(
f"{self.design_matrix_cte} SELECT DISTINCT {self.cluster_col} FROM transformed_panel_data"
).fetchall()
unique_clusters = [c[0] for c in unique_clusters]

boot_coefs = {str(cohort): [] for cohort in self.cohorts}

rhs_clause = ", ".join(self.rhs_cols)

self.bootstrap_query = f"""
{self.design_matrix_cte},
resampled AS (
SELECT cluster_id, COUNT(*) as mult
FROM (SELECT unnest(?) as cluster_id)
GROUP BY cluster_id
),
grouped_data AS (
SELECT
{rhs_clause}, {self.cluster_col},
COUNT(*) as count,
SUM({self.outcome_var}) as sum_{self.outcome_var}
FROM transformed_panel_data
GROUP BY {rhs_clause}, {self.cluster_col}
)
SELECT
{rhs_clause},
SUM(gd.count * r.mult) as count,
SUM(gd.sum_{self.outcome_var} * r.mult) as sum_{self.outcome_var}
FROM grouped_data gd
JOIN resampled r ON gd.{self.cluster_col} = r.cluster_id
GROUP BY {rhs_clause}
"""

# bootstrap loop
for _ in tqdm(range(self.n_bootstraps)):
resampled_clusters = (
self.conn.execute(
f"SELECT UNNEST(ARRAY(SELECT {self.cluster_col} FROM transformed_panel_data ORDER BY RANDOM() LIMIT {total_clusters}))"
)
.fetchdf()
.values.flatten()
.tolist()
)

self.conn.execute(
f"""
CREATE TEMP TABLE resampled_transformed_panel_data AS
SELECT * FROM transformed_panel_data
WHERE {self.cluster_col} IN ({", ".join(map(str, resampled_clusters))})
"""
)

self.conn.execute(
f"""
CREATE TEMP TABLE resampled_compressed_panel_data AS
SELECT
{self.rhs.replace(";", "")},
COUNT(*) AS count,
SUM({self.outcome_var}) AS sum_{self.outcome_var}
FROM
resampled_transformed_panel_data
GROUP BY
{self.rhs.replace(";", "")}
"""
resampled_clusters = self.rng.choice(
unique_clusters, size=len(unique_clusters), replace=True
)

df_boot = self.conn.execute(
"SELECT * FROM resampled_compressed_panel_data"
self.bootstrap_query, [resampled_clusters.tolist()]
).fetchdf()

df_boot[f"mean_{self.outcome_var}"] = (
df_boot[f"sum_{self.outcome_var}"] / df_boot["count"]
)
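
One design point in the loop above: the resampled cluster ids are passed as a single list parameter and expanded with unnest(?) inside the CTE, so the query text stays fixed across bootstrap iterations. A minimal standalone sketch of that DuckDB pattern, with a made-up table and column names:

import duckdb

con = duckdb.connect()
con.execute(
    "CREATE TABLE t AS SELECT * FROM (VALUES (1, 10.0), (1, 20.0), (2, 5.0)) v(cid, y)"
)

draw = [1, 1, 2]  # e.g. rng.choice(unique_clusters, ...).tolist()
q = """
    WITH resampled AS (
        SELECT cluster_id, COUNT(*) AS mult
        FROM (SELECT unnest(?) AS cluster_id)
        GROUP BY cluster_id
    )
    SELECT SUM(t.y * r.mult) AS sum_y
    FROM t JOIN resampled r ON t.cid = r.cluster_id
"""
print(con.execute(q, [draw]).fetchdf())  # cluster 1 counted twice, cluster 2 once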
@@ -626,8 +680,6 @@ def bootstrap(self):
)
boot_coefs[c].append(event_study_coefs.values.flatten())

self.conn.execute("DROP TABLE resampled_transformed_panel_data")
self.conn.execute("DROP TABLE resampled_compressed_panel_data")
# Calculate the covariance matrix for each cohort
bootstrap_cov_matrix = {
cohort: np.cov(np.array(coefs).T) for cohort, coefs in boot_coefs.items()
@@ -715,6 +767,7 @@ def prepare_data(self):
SELECT
t.{self.unit_col},
t.{self.time_col},
{f"t.{self.cluster_col}," if self.cluster_col and self.cluster_col != self.unit_col else ""}
t.{self.outcome_var},
t.{self.treatment_var} - um.mean_{self.treatment_var}_unit - tm.mean_{self.treatment_var}_time + om.mean_{self.treatment_var} AS ddot_{self.treatment_var}
FROM {self.table_name} t
@@ -769,16 +822,30 @@ def bootstrap(self):
GROUP BY ddot_{self.treatment_var}
"""
else:
total_clusters = self.conn.execute(
f"SELECT COUNT(DISTINCT {self.cluster_col}) FROM {self.table_name}"
).fetchone()[0]
unique_clusters = self.conn.execute(
f"SELECT DISTINCT {self.cluster_col} FROM {self.table_name}"
).fetchall()
unique_clusters = [c[0] for c in unique_clusters]
self.bootstrap_query = f"""
WITH resampled AS (
SELECT cluster_id, COUNT(*) as mult
FROM (SELECT unnest(?) as cluster_id)
GROUP BY cluster_id
),
grouped_data AS (
SELECT
ddot_{self.treatment_var}, {self.cluster_col},
COUNT(*) as count,
SUM({self.outcome_var}) as sum_{self.outcome_var}
FROM double_demeaned
GROUP BY ddot_{self.treatment_var}, {self.cluster_col}
)
SELECT
ddot_{self.treatment_var},
COUNT(*) as count,
SUM({self.outcome_var}) as sum_{self.outcome_var}
FROM double_demeaned
WHERE {self.cluster_col} IN (SELECT unnest((?)))
SUM(gd.count * r.mult) as count,
SUM(gd.sum_{self.outcome_var} * r.mult) as sum_{self.outcome_var}
FROM grouped_data gd
JOIN resampled r ON gd.{self.cluster_col} = r.cluster_id
GROUP BY ddot_{self.treatment_var}
"""

@@ -789,7 +856,7 @@ def bootstrap(self):
)
else:
resampled_clusters = self.rng.choice(
total_clusters, size=total_clusters, replace=True
unique_clusters, size=len(unique_clusters), replace=True
)
resampled_units = resampled_clusters

Expand Down
Loading