Skip to content
Merged
201 changes: 0 additions & 201 deletions causalpy/data/ancova_generated.csv

This file was deleted.

92 changes: 57 additions & 35 deletions causalpy/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,39 +15,62 @@
Functions to load example datasets
"""

import pathlib
from collections.abc import Callable
from pathlib import Path

import pandas as pd

import causalpy as cp

DATASETS = {
"banks": {"filename": "banks.csv"},
"brexit": {"filename": "GDP_in_dollars_billions.csv"},
"covid": {"filename": "deaths_and_temps_england_wales.csv"},
"did": {"filename": "did.csv"},
"drinking": {"filename": "drinking.csv"},
"its": {"filename": "its.csv"},
"its simple": {"filename": "its_simple.csv"},
"rd": {"filename": "regression_discontinuity.csv"},
"sc": {"filename": "synthetic_control.csv"},
"anova1": {"filename": "ancova_generated.csv"},
"geolift1": {"filename": "geolift1.csv"},
"geolift_multi_cell": {"filename": "geolift_multi_cell.csv"},
"risk": {"filename": "AJR2001.csv"},
"nhefs": {"filename": "nhefs.csv"},
"schoolReturns": {"filename": "schoolingReturns.csv"},
"pisa18": {"filename": "PISA18sampleScale.csv"},
"nets": {"filename": "nets_df.csv"},
"lalonde": {"filename": "lalonde.csv"},
"zipcodes": {"filename": "zipcodes_data.csv"},
"nevo": {"filename": "data_nevo.csv"},
from .simulate_data import (
RANDOM_SEED,
generate_ancova_data,
generate_did,
generate_geolift_data,
generate_multicell_geolift_data,
generate_regression_discontinuity_data,
generate_synthetic_control_data,
generate_time_series_data_seasonal,
generate_time_series_data_simple,
)

# Directory containing this module; the packaged real-world CSV files are
# resolved relative to it in load_data().
_DATA_DIR = Path(__file__).parent

# Synthetic datasets are regenerated on demand from fixed-seed generators, so
# no generated CSV files need to be shipped and results stay reproducible.
# The ITS and geolift generators place the date on the index; calling
# .reset_index() restores the column layout the old CSV-based loader returned.


def _make_did() -> pd.DataFrame:
    """Difference-in-differences example data."""
    return generate_did(seed=RANDOM_SEED)


def _make_rd() -> pd.DataFrame:
    """Regression-discontinuity example data, treatment threshold at 0.5."""
    return generate_regression_discontinuity_data(
        true_treatment_threshold=0.5, seed=RANDOM_SEED
    )


def _make_sc() -> pd.DataFrame:
    """Synthetic-control example data (first element of the generator's
    return value — presumably the observed DataFrame; confirm in
    simulate_data)."""
    return generate_synthetic_control_data(seed=RANDOM_SEED)[0]


def _make_its() -> pd.DataFrame:
    """Seasonal interrupted time series; date moved back to a column."""
    frame = generate_time_series_data_seasonal(
        treatment_time=pd.to_datetime("2017-01-01"), seed=RANDOM_SEED
    )
    return frame.reset_index()


def _make_its_simple() -> pd.DataFrame:
    """Simple interrupted time series; date moved back to a column."""
    frame = generate_time_series_data_simple(
        treatment_time=pd.to_datetime("2015-01-01"), seed=RANDOM_SEED
    )
    return frame.reset_index()


def _make_anova1() -> pd.DataFrame:
    """ANCOVA example data."""
    return generate_ancova_data(seed=RANDOM_SEED)


def _make_geolift1() -> pd.DataFrame:
    """Single-cell geolift data; index restored as a column."""
    return generate_geolift_data(seed=RANDOM_SEED).reset_index()


def _make_geolift_multi_cell() -> pd.DataFrame:
    """Multi-cell geolift data; index restored as a column."""
    return generate_multicell_geolift_data(seed=RANDOM_SEED).reset_index()


# Mapping of dataset name -> zero-argument loader returning a DataFrame.
SYNTHETIC_DATASETS: dict[str, Callable[[], pd.DataFrame]] = {
    "did": _make_did,
    "rd": _make_rd,
    "sc": _make_sc,
    "its": _make_its,
    "its simple": _make_its_simple,
    "anova1": _make_anova1,
    "geolift1": _make_geolift1,
    "geolift_multi_cell": _make_geolift_multi_cell,
}


def _get_data_home() -> pathlib.Path:
    """Locate the directory that holds the packaged example CSV files.

    Resolves the installed ``causalpy`` package location and returns its
    ``causalpy/data`` subdirectory.
    """
    package_root = pathlib.Path(cp.__file__).parents[1]
    return package_root / "causalpy" / "data"
# Real-world example data ships with the package as CSV files.
# Each (name, filename) pair maps a load_data() key to a CSV located in the
# package data directory; insertion order is preserved by dict().
REAL_WORLD_DATASETS: dict[str, str] = dict(
    [
        ("banks", "banks.csv"),
        ("brexit", "GDP_in_dollars_billions.csv"),
        ("covid", "deaths_and_temps_england_wales.csv"),
        ("drinking", "drinking.csv"),
        ("risk", "AJR2001.csv"),
        ("nhefs", "nhefs.csv"),
        ("schoolReturns", "schoolingReturns.csv"),
        ("pisa18", "PISA18sampleScale.csv"),
        ("nets", "nets_df.csv"),
        ("lalonde", "lalonde.csv"),
        ("zipcodes", "zipcodes_data.csv"),
        ("nevo", "data_nevo.csv"),
    ]
)


def load_data(dataset: str) -> pd.DataFrame:
Expand Down Expand Up @@ -84,6 +107,7 @@ def load_data(dataset: str) -> pd.DataFrame:
- ``"zipcodes"`` - Geo-experimentation zipcode data for comparative interrupted time
series analysis. Based on synthetic data from Juan Orduz's blog post on
`time-based regression for geo-experiments <https://juanitorduz.github.io/time_based_regression_pymc/>`_.
- ``"nevo"`` - Berry, Levinsohn, and Pakes (1995) cereal data for BLP estimation

Returns
-------
Expand All @@ -106,11 +130,9 @@ def load_data(dataset: str) -> pd.DataFrame:

>>> df = cp.load_data("rd")
"""

if dataset in DATASETS:
data_dir = _get_data_home()
datafile = DATASETS[dataset]
file_path = data_dir / datafile["filename"]
return pd.read_csv(file_path)
if dataset in SYNTHETIC_DATASETS:
return SYNTHETIC_DATASETS[dataset]()
elif dataset in REAL_WORLD_DATASETS:
return pd.read_csv(_DATA_DIR / REAL_WORLD_DATASETS[dataset])
else:
raise ValueError(f"Dataset {dataset} not found!")
raise ValueError(f"Dataset {dataset!r} not found!")
41 changes: 0 additions & 41 deletions causalpy/data/did.csv

This file was deleted.

Loading