diff --git a/causal_automl/TutorTask401_EIA_metadata_downloader_pipeline/eia_utils.py b/causal_automl/TutorTask401_EIA_metadata_downloader_pipeline/eia_utils.py index f907d8e76d..669853ffc0 100644 --- a/causal_automl/TutorTask401_EIA_metadata_downloader_pipeline/eia_utils.py +++ b/causal_automl/TutorTask401_EIA_metadata_downloader_pipeline/eia_utils.py @@ -5,8 +5,10 @@ """ import logging -from typing import Any, Dict, List, Tuple +import re +from typing import Any, Dict, List, Optional, Tuple, cast +import helpers.hdbg as hdbg import matplotlib.pyplot as plt import pandas as pd import requests @@ -118,12 +120,18 @@ def _get_api_request(self, route: str) -> Dict[str, Any]: # Build the full API request URL. url = f"{self._base_url}/{route}?api_key={self._api_key}" # Send HTTP GET request to the EIA API. + # TODO(alvino): Add error handling for the HTTP request to handle + # potential exceptions such as connection errors or timeouts. response = requests.get(url, timeout=20) # Parse JSON content. + # TODO(alvino): Check if the response is successful (e.g., + # `response.status_code == 200`) before attempting to parse the JSON + # content. json_data = response.json() # Get response from parsed payload. data: Dict[str, Any] = {} - data = json_data.get("response", {}) + # TODO(alvino): Add error handling for JSON parsing to manage potential parsing errors. 
+ data = json_data["response"] return data def _get_leaf_route_data(self) -> Dict[str, Dict[str, Any]]: @@ -242,19 +250,19 @@ def _extract_metadata( "url": url, "id": f"{route_clean}.{frequency_id}.{metric_id_clean}", "dataset_id": dataset_id_clean, - "name": data.get("name"), - "description": data.get("description"), - "frequency_id": frequency.get("id"), + "name": data["name"], + "description": data["description"], + "frequency_id": frequency["id"], "frequency_alias": frequency.get("alias"), - "frequency_description": frequency.get("description"), - "frequency_query": frequency.get("query"), - "frequency_format": frequency.get("format"), - "facets": data.get("facets"), + "frequency_description": frequency["description"], + "frequency_query": frequency["query"], + "frequency_format": frequency["format"], + "facets": data["facets"], "data": metric_id, "data_alias": metric_info.get("alias"), "data_units": metric_info.get("units"), - "start_period": data.get("startPeriod"), - "end_period": data.get("endPeriod"), + "start_period": data["startPeriod"], + "end_period": data["endPeriod"], "parameter_values_file": param_file_path, } flattened_metadata.append(metadata) @@ -270,6 +278,11 @@ def _get_facet_values( :param route: dataset route under the EIA v2 API :return: data containing all facet values """ + hdbg.dassert_in( + "facets", + metadata, + msg="Column 'facets' not found in metadata index." + ) facets = metadata["facets"] rows = [] for facet in facets: @@ -295,31 +308,84 @@ def _get_facet_values( def build_full_url( base_url: str, api_key: str, - facet_input: Dict[str, str], + *, + facet_input: Optional[Dict[str, str]] = None, + start_timestamp: Optional[pd.Timestamp] = None, + end_timestamp: Optional[pd.Timestamp] = None, ) -> str: """ - Build a full EIA v2 API URL by appending one facet value per facet type. + Build an EIA v2 API URL to data endpoint. - This modifies the base metadata URL to point to the actual time series - data endpoint. 
+ This function modifies the base metadata URL by: + - Replacing the metadata endpoint with the actual data endpoint + - Injecting the provided API key + - Appending optional facet filters + - Appending start and end timestamps formatted to match the series frequency :param base_url: base API URL with frequency and metric, excluding facet values, e.g., "https://api.eia.gov/v2/electricity/retail-sales?api_key={API_KEY}&frequency=monthly&data[0]=revenue" :param api_key: EIA API key, e.g., "abcd1234xyz" :param facet_input: specified facet values, e.g., {"stateid": "KS", "sectorid": "COM"} - :return: full EIA API URL with all required facet parameters, + :param start_timestamp: first observation date + :param end_timestamp: last observation date + :return: full EIA API URL to data endpoint, e.g., "https://api.eia.gov/v2/electricity/retail-sales/data?api_key=abcd1234xyz&frequency=monthly&data[0]=price&facets[stateid][]=KS&facets[sectorid][]=OTH" """ + match = cast(re.Match[str], re.search(r"frequency=([a-zA-Z\-]+)", base_url)) + frequency = match.group(1) base_url = base_url.replace("?", "/data?") url = base_url.replace("{API_KEY}", api_key) query_parts = [] - for facet_id, value in facet_input.items(): - query_parts.append(f"&facets[{facet_id}][]={value}") + if start_timestamp: + formatted_start = _format_timestamp(start_timestamp, frequency) + query_parts.append(f"&start={formatted_start}") + if end_timestamp: + formatted_end = _format_timestamp(end_timestamp, frequency) + query_parts.append(f"&end={formatted_end}") + if facet_input: + # Add facet values when specified. + for facet_id, value in facet_input.items(): + query_parts.append(f"&facets[{facet_id}][]={value}") full_url = url + "".join(query_parts) return full_url + +def _format_timestamp(timestamp: pd.Timestamp, frequency: str) -> str: + """ + Format a timestamp based on the EIA time series frequency. 
+ + Supported formats: + - "annual": "YYYY" + - "quarterly": "YYYY-QN" + - "monthly": "YYYY-MM" + - "daily": "YYYY-MM-DD" + - "hourly": "YYYY-MM-DDTHH" + - "local-hourly": "YYYY-MM-DDTHH-ZZ" (fixed timezone offset, e.g., "-00") + + :param timestamp: the timestamp to format + :param frequency: the frequency type (e.g., "monthly", "local-hourly") + :return: formatted timestamp + """ + result = "" + if frequency == "annual": + result = timestamp.strftime("%Y") + elif frequency == "monthly": + result = timestamp.strftime("%Y-%m") + elif frequency == "quarterly": + q = (timestamp.month - 1) // 3 + 1 + result = f"{timestamp.year}-Q{q}" + elif frequency == "daily": + result = timestamp.strftime("%Y-%m-%d") + elif frequency == "hourly": + result = timestamp.strftime("%Y-%m-%dT%H") + elif frequency == "local-hourly": + result = timestamp.strftime("%Y-%m-%dT%H") + "-00" + else: + raise ValueError(f"Unsupported frequency: {frequency}") + return result + + def plot_distribution(df_metadata: pd.DataFrame, column: str, title: str) -> None: """ Plot a distribution count for a specified metadata column. @@ -329,8 +395,11 @@ def plot_distribution(df_metadata: pd.DataFrame, column: str, title: str) -> Non 'frequency_id', 'data_units') :param title: title for the plot """ - if column not in df_metadata.columns: - raise ValueError(f"Column '{column}' not found in metadata index.") + hdbg.dassert_in( + column, + df_metadata.columns, + msg=f"Column '{column}' not found in metadata index." 
+ ) counts = df_metadata[column].value_counts() ax = counts.plot(kind="bar", figsize=(8, 4), title=title) ax.set_xlabel(column.replace("_", " ").title()) diff --git a/causal_automl/download_eia_data.py b/causal_automl/download_eia_data.py new file mode 100644 index 0000000000..6c32958d12 --- /dev/null +++ b/causal_automl/download_eia_data.py @@ -0,0 +1,249 @@ +""" +Import as: + +import causal_automl.download_eia_data as cadoeida +""" + +import io +import logging +import os +from typing import Dict, Optional, Tuple + +import helpers.hdbg as hdbg +import helpers.hs3 as hs3 +import myeia +import pandas as pd + +import causal_automl.TutorTask401_EIA_metadata_downloader_pipeline.eia_utils as catemdpeu + +_LOG = logging.getLogger(__name__) + + +# ############################################################################# +# EiaDataDownloader +# ############################################################################# + + +class EiaDataDownloader: + """ + Download historical data from EIA. + """ + + def __init__(self, *, aws_profile: str = "ck") -> None: + """ + Initialize the EIA data downloader with the API key and AWS profile. + + EIA API key is read from the environment variable. + + :param aws_profile: AWS CLI profile name used for authentication + """ + hdbg.dassert_in( + "EIA_API_KEY", + os.environ, + msg="EIA_API_KEY is not found in environment variables", + ) + self._api_key = os.getenv("EIA_API_KEY") + self._client = myeia.API(token=self._api_key) + self._aws_profile = aws_profile + self._metadata_index_by_category: Dict[str, pd.DataFrame] = {} + + def filter_series( + self, + df: pd.DataFrame, + id_: str, + facets: Dict[str, str], + ) -> pd.DataFrame: + """ + Filter and clean a single time series from an EIA dataset. 
+ + This function performs data post-processing: + - Filter by facet values (e.g., "stateid", "sectorid") + - Retain only the period and metric column + - Convert the period column to UTC datetime + - Set the period as the index and sort chronologically + + :param df: EIA series data + :param id_: EIA series ID, e.g., + "electricity.retail_sales.monthly.price" + :param facets: facet filters, + e.g., {"stateid": "WI", "sectorid": "ALL"} + :return: data of single time series with one facet value per + facet type + + Example output: + ``` + period price + 2001-01-01T00:00:00+00:00 5.9 + 2001-02-01T00:00:00+00:00 5.98 + 2001-03-01T00:00:00+00:00 5.93 + ``` + """ + # Filter data with given facet values. + for key, val in facets.items(): + hdbg.dassert_in( + key, + df.columns, + msg=( + f"Facet '{key}' not found in data columns={list(df.columns)}" + ), + ) + df = df[df[key] == val] + if df.empty: + _LOG.warning("No data remaining after applying facets.") + # Detect the metric column. + _, _, _, data_identifier = self._parse_id(id_) + df = df[["period", data_identifier]] + # Drop rows with missing value. + df = df.dropna(subset=[data_identifier]) + if df.empty: + _LOG.warning("No data remaining after dropping NaN values.") + # Convert to datetime and index. + df["period"] = pd.to_datetime(df["period"]).dt.tz_localize("UTC") + df = df.set_index("period") + df = df.sort_index() + return df + + def download_series( + self, + id_: str, + *, + start_timestamp: Optional[pd.Timestamp] = None, + end_timestamp: Optional[pd.Timestamp] = None, + max_rows_per_call: int = 5000, + ) -> pd.DataFrame: + """ + Download EIA historical series data. + + This method retrieves the full set of time series linked to an + EIA identifier, including all combinations of facet values + (e.g., `stateid`, `sectorid`). When no start and end timestamps are + passed, the entire time series is downloaded. + + Pagination is handled internally. 
The `max_rows_per_call` parameter + controls the page size for each API request, but the method will + continue fetching until all available data is retrieved. + + :param id_: EIA series ID, e.g., + "electricity.retail_sales.monthly.price" + :param start_timestamp: first observation date + :param end_timestamp: last observation date + :param max_rows_per_call: max data rows per API call + :return: full time series data with all facets + + Example output: + ``` + period stateid stateDescription sectorid sectorName + 2020-09 WI Wisconsin IND industrial + 2020-09 WY Wyoming ALL all sectors + 2020-09 IA Iowa RES Residential + + price price-units + 7.45 cents per kilowatt-hour + 8.55 cents per kilowatt-hour + 12.65 cents per kilowatt-hour + ``` + """ + # Get base url from metadata index. + base_url = self._get_metadata_url(id_) + # Build URL query with API key and timestamps. + url = catemdpeu.build_full_url( + base_url, + self._api_key, + start_timestamp=start_timestamp, + end_timestamp=end_timestamp, + ) + data_chunks = [] + offset = 0 + while True: + # Construct the paginated URL for the current offset. + paginated_url = f"{url}&offset={offset}&length={max_rows_per_call}" + data = self._client.get_response(paginated_url, self._client.header) + data_chunks.append(data) + if len(data) < max_rows_per_call: + # Exit loop when it's the final page of data. + break + offset += max_rows_per_call + if not data_chunks: + _LOG.warning("No data returned under given id.") + df = pd.concat(data_chunks, ignore_index=True) + _LOG.debug("Downloaded %d rows for id=%s", len(df), id_) + return df + + def _parse_id(self, id_: str) -> Tuple[str, str, str, str]: + """ + Parse an EIA time series ID into its components. + + EIA time series IDs follow the format: + ... + + Underscores are converted to dashes to match the EIA API format. 
+ + :param id_: EIA time series ID, + e.g., "electricity.retail_sales.monthly.price" + :return: + - top-level EIA category, e.g., "electricity" + - subroute in the category, e.g., "retail-sales" + - reporting frequency, e.g., "monthly" + - data identifier, e.g., "price" + """ + id_ = id_.replace("_", "-") + parts = id_.split(".") + category = parts[0] + frequency = parts[-2] + data_identifier = parts[-1] + route_parts = parts[1:-2] + subroute = "/".join(route_parts) + return category, subroute, frequency, data_identifier + + def _get_latest_metadata_from_s3(self, category: str) -> pd.DataFrame: + """ + Get the latest versioned metadata index file from S3 for a category. + + :param category: top-level EIA category, e.g., "electricity" + :return: latest versioned metadata index + """ + # Get file names from S3 bucket. + base_dir = "s3://causify-data-collaborators/causal_automl/metadata" + pattern = f"eia_{category}_metadata_original_v*" + files = hs3.listdir( + dir_name=base_dir, + pattern=pattern, + only_files=True, + use_relative_paths=False, + aws_profile=self._aws_profile, + maxdepth=1, + ) + if not files: + raise FileNotFoundError( + f"No metadata index file found for category: '{category}' in S3." + ) + # Get latest file version. + files.sort(reverse=True) + s3_path = f"s3://{files[0]}" + # Load latest metadata index file from S3. + csv_str = hs3.from_file(s3_path, aws_profile=self._aws_profile) + df = pd.read_csv(io.StringIO(csv_str)) + return df + + def _get_metadata_url(self, id_: str) -> str: + """ + Get base URL for given series ID from the metadata index. + + :param id_: EIA time series ID, + e.g., "electricity.retail_sales.monthly.price" + :return: base API URL with frequency and metric, excluding facet values, + e.g., "https://api.eia.gov/v2/electricity/retail-sales?api_key={API_KEY}&frequency=monthly&data[0]=revenue" + """ + category, _, _, _ = self._parse_id(id_) + # Load latest metadata index file from S3. 
+ if category not in self._metadata_index_by_category: + self._metadata_index_by_category[category] = ( + self._get_latest_metadata_from_s3(category) + ) + df = self._metadata_index_by_category[category] + # Filter for exact ID match. + match = df[df["id"] == id_] + if match.empty: + raise ValueError(f"Invalid ID: '{id_}'") + base_url: str = match.iloc[0]["url"] + return base_url