diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml
new file mode 100644
index 00000000..c4167a05
--- /dev/null
+++ b/.github/workflows/publish.yaml
@@ -0,0 +1,86 @@
+name: Upload Python Package
+
+on:
+  release:
+    types: [published]
+
+permissions:
+  contents: read
+
+jobs:
+  release-build:
+    name: Build release distribution
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - uses: actions/setup-python@v5
+        with:
+          python-version: "3.x"
+
+      - name: Build release distributions
+        run: |
+          python -m pip install build
+          python -m build
+
+      - name: Upload distributions
+        uses: actions/upload-artifact@v4
+        with:
+          name: release-dists
+          path: dist/
+
+  pypi-publish:
+    name: Publish release distribution to PyPI
+    runs-on: ubuntu-latest
+    if: startsWith(github.ref, 'refs/tags/')
+    needs:
+      - release-build
+    permissions:
+      # IMPORTANT: this permission is mandatory for trusted publishing
+      id-token: write
+
+    # Dedicated environments with protections for publishing are strongly recommended.
+    # For more information, see: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment#deployment-protection-rules
+    environment:
+      name: ewb-pypi-release
+      url: https://pypi.org/p/extremeweatherbench
+
+    steps:
+      - name: Retrieve release distributions
+        uses: actions/download-artifact@v4
+        with:
+          name: release-dists
+          path: dist/
+
+      - name: Publish release distributions to PyPI
+        uses: pypa/gh-action-pypi-publish@release/v1
+        with:
+          packages-dir: dist/
+
+  publish-to-testpypi:
+    name: Publish release distribution to TestPyPI
+    runs-on: ubuntu-latest
+    if: startsWith(github.ref, 'refs/tags/')
+    needs:
+      - release-build
+
+    permissions:
+      id-token: write
+
+    environment:
+      name: ewb-testpypi-release
+      url: https://test.pypi.org/p/extremeweatherbench
+
+    steps:
+      - name: Retrieve release distributions
+        uses: actions/download-artifact@v4
+        with:
+          name: release-dists
+          path: dist/
+
+      - name: Publish release distributions to TestPyPI
+        uses: pypa/gh-action-pypi-publish@release/v1
+        with:
+          repository-url: https://test.pypi.org/legacy/
+          packages-dir: dist/
diff --git a/.github/workflows/run-pre-commit.yaml b/.github/workflows/run-pre-commit.yaml
new file mode 100644
index 00000000..42e6056f
--- /dev/null
+++ b/.github/workflows/run-pre-commit.yaml
@@ -0,0 +1,36 @@
+name: Run pre-commit
+
+on:
+  pull_request:
+    branches: [main, develop]
+  push:
+    branches: [main]
+
+permissions:
+  contents: read
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10", "3.11", "3.12", "3.13"]
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python env with uv
+        uses: astral-sh/setup-uv@v4
+        with:
+          version: "0.5.6"
+          enable-cache: true
+
+      - name: "Set up Python"
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install the project
+        run: uv sync --all-extras --all-groups
+
+      - name: Run pre-commit hooks
+        run: uv run pre-commit run --all-files
diff --git a/.github/workflows/ci.yaml b/.github/workflows/run-tests.yaml
similarity index 50%
rename from .github/workflows/ci.yaml
rename to .github/workflows/run-tests.yaml
index e042a833..99baeb36 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/run-tests.yaml
@@ -1,23 +1,20 @@
-name: ci
+name: Run tests
 
 on:
   pull_request:
+    branches: [main, develop]
   push:
     branches: [main]
 
-jobs:
-  pre-commit:
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v3
-      - uses:
actions/setup-python@v3 - - uses: pre-commit/action@v3.0.1 +permissions: + contents: read - test: +jobs: + build: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.12"] + python-version: ["3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v4 @@ -30,24 +27,13 @@ jobs: - name: "Set up Python" uses: actions/setup-python@v5 with: - python-version-file: "pyproject.toml" + python-version: ${{ matrix.python-version }} - name: Install the project - run: uv sync --all-extras --dev + run: uv sync --all-extras --all-groups - name: Run tests run: uv run pytest - name: Generate Coverage Report run: uv run coverage report -m - - golden-tests: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: extractions/setup-just@v3 - with: - just-version: 1.43.1 - - - name: Run golden tests with just - run: just golden-tests diff --git a/Justfile b/Justfile index f0c45f50..cddbd2d1 100644 --- a/Justfile +++ b/Justfile @@ -1,7 +1,57 @@ +# NOTE: We automatically load a .env file containing the "GH_TOKEN" environment variable +# for use with semantic-release. If this isn't present, then those commands will likely fail. +set dotenv-load + # List all available recipes default: @just --list -# Placeholder for golden tests -golden-tests: - @just --list \ No newline at end of file +# Run the complete test suite +test: + @echo "Running tests" + uv run pytest + +# Serve a local build of the project documentation at http://localhost:8000 +serve-docs: + @echo "Serving docs at http://localhost:8000" + uv run --extra docs mkdocs serve + +# Build the project documentation +build-docs: + @echo "Building docs" + uv run --extra docs mkdocs build + +# Run the pre-commit hooks on all files in the repo +pre-commit: + @echo "Running pre-commit hooks" + uv run pre-commit run --all-files + +# Run the coverage report +coverage: + @echo "Running coverage report" + uv run coverage run -m pytest + uv run coverage report + +# Determine the next version number +next-version: + @echo "Determining next version" + uv run semantic-release version --print + +# Create a minor release +minor-release: + @echo "Creating minor release" + uv run semantic-release -vvv --noop version --minor --no-changelog + +# Create a patch release +patch-release: + @echo "Creating patch release" + uv run semantic-release -vvv --noop version --patch --no-changelog + +# Upload a release to PyPI +pypi-upload tag: + @echo "Uploading release {{tag}} to PyPI" + git checkout {{tag}} + rm -rf dist + uv run python -m build + uv run twine upload dist/* + git checkout - \ No newline at end of file diff --git a/data_prep/ar_bounds.py b/data_prep/ar_bounds.py index b642d334..1c4cf5cf 100644 --- a/data_prep/ar_bounds.py +++ b/data_prep/ar_bounds.py @@ -17,7 +17,11 @@ from dask.distributed import Client from matplotlib.patches import Rectangle -from extremeweatherbench import cases, derived, inputs, regions, utils +import extremeweatherbench.cases as cases +import extremeweatherbench.derived as derived +import extremeweatherbench.inputs as inputs +import extremeweatherbench.regions as regions +import extremeweatherbench.utils as utils from extremeweatherbench.events import atmospheric_river as ar logging.basicConfig() diff --git a/data_prep/ibtracs_bounds.py b/data_prep/ibtracs_bounds.py index 0a2ecc0a..0dc962d9 100644 --- a/data_prep/ibtracs_bounds.py +++ b/data_prep/ibtracs_bounds.py @@ -4,6 +4,7 @@ import logging import re from importlib import resources +from typing import TYPE_CHECKING import cartopy.crs as ccrs import cartopy.feature 
as cfeature @@ -14,8 +15,11 @@ import yaml from matplotlib.patches import Rectangle +import extremeweatherbench as ewb import extremeweatherbench.data -from extremeweatherbench import cases, inputs, regions, utils + +if TYPE_CHECKING: + from extremeweatherbench.regions import Region logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -67,7 +71,7 @@ def calculate_extent_bounds( bottom_lat: float, top_lat: float, extent_buffer: float = 250, -) -> regions.Region: +) -> Region: """Calculate extent bounds with buffer. Args: @@ -94,9 +98,9 @@ def calculate_extent_bounds( calculate_end_point(bottom_lat, right_lon, 90, extent_buffer), 1 ) - new_left_lon = np.round(utils.convert_longitude_to_360(new_left_lon), 1) - new_right_lon = np.round(utils.convert_longitude_to_360(new_right_lon), 1) - new_box = regions.BoundingBoxRegion( + new_left_lon = np.round(ewb.utils.convert_longitude_to_360(new_left_lon), 1) + new_right_lon = np.round(ewb.utils.convert_longitude_to_360(new_right_lon), 1) + new_box = ewb.regions.BoundingBoxRegion( new_bottom_lat, new_top_lat, new_left_lon, new_right_lon ) return new_box @@ -164,10 +168,10 @@ def load_and_process_ibtracs_data(): """ logger.info("Loading IBTrACS data...") - IBTRACS = inputs.IBTrACS( - source=inputs.IBTRACS_URI, + IBTRACS = ewb.inputs.IBTrACS( + source=ewb.inputs.IBTRACS_URI, variables=["vmax", "slp"], - variable_mapping=inputs.IBTrACS_metadata_variable_mapping, + variable_mapping=ewb.inputs.IBTrACS_metadata_variable_mapping, storage_options={}, ) @@ -177,7 +181,7 @@ def load_and_process_ibtracs_data(): # Get all storms from 2020 - 2025 seasons all_storms_2020_2025_lf = IBTRACS_lf.filter( (pl.col("SEASON").cast(pl.Int32) >= 2020) - ).select(inputs.IBTrACS_metadata_variable_mapping.values()) + ).select(ewb.inputs.IBTrACS_metadata_variable_mapping.values()) schema = all_storms_2020_2025_lf.collect_schema() # Convert pressure and surface wind columns to float, replacing " " with null @@ -464,7 +468,7 @@ def find_storm_bounds_for_case(storm_name, storm_bounds, all_storms_df): # If we found both, merge them by taking the bounding box that # encompasses both if bounds1 is not None and bounds2 is not None: - merged_bbox = regions.BoundingBoxRegion( + merged_bbox = ewb.regions.BoundingBoxRegion( latitude_min=min( bounds1.iloc[0].latitude_min, bounds2.iloc[0].latitude_min ), @@ -537,7 +541,7 @@ def update_cases_with_storm_bounds(storm_bounds, all_storms_df): """ logger.info("Updating cases with storm bounds...") - cases_all = cases.load_ewb_events_yaml_into_case_list() + cases_all = ewb.cases.load_ewb_events_yaml_into_case_list() cases_new = cases_all.copy() # Update the yaml cases with storm bounds from IBTrACS data diff --git a/data_prep/practically_perfect_hindcast_from_lsr.py b/data_prep/practically_perfect_hindcast_from_lsr.py index db114242..1e96cbda 100644 --- a/data_prep/practically_perfect_hindcast_from_lsr.py +++ b/data_prep/practically_perfect_hindcast_from_lsr.py @@ -11,7 +11,8 @@ from scipy.ndimage import gaussian_filter from tqdm.auto import tqdm -from extremeweatherbench import inputs, utils +import extremeweatherbench.inputs as inputs +import extremeweatherbench.utils as utils def sparse_practically_perfect_hindcast( diff --git a/data_prep/severe_convection_bounds.py b/data_prep/severe_convection_bounds.py index 12e1bcf9..11985632 100644 --- a/data_prep/severe_convection_bounds.py +++ b/data_prep/severe_convection_bounds.py @@ -17,7 +17,8 @@ import yaml from scipy.ndimage import label -from extremeweatherbench import 
calc, cases +import extremeweatherbench.calc as calc +import extremeweatherbench.cases as cases # Radius of Earth in km (mean radius) EARTH_RADIUS_KM = 6371.0 diff --git a/data_prep/subset_heat_cold_events.py b/data_prep/subset_heat_cold_events.py index e109889b..ee6b4add 100644 --- a/data_prep/subset_heat_cold_events.py +++ b/data_prep/subset_heat_cold_events.py @@ -13,7 +13,8 @@ from matplotlib import dates as mdates from mpl_toolkits.axes_grid1 import make_axes_locatable -from extremeweatherbench import cases, utils +import extremeweatherbench.cases as cases +import extremeweatherbench.utils as utils sns.set_theme(style="whitegrid", context="talk") diff --git a/docs/examples/applied_ar.py b/docs/examples/applied_ar.py index 75239f47..a99e5a0e 100644 --- a/docs/examples/applied_ar.py +++ b/docs/examples/applied_ar.py @@ -4,7 +4,7 @@ import numpy as np import xarray as xr -from extremeweatherbench import cases, derived, evaluate, inputs, metrics +import extremeweatherbench as ewb # %% @@ -38,85 +38,86 @@ def _preprocess_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: # Load case data from the default events.yaml -# Users can also define their own cases -case_yaml = cases.load_ewb_events_yaml_into_case_list() -case_yaml = [n for n in case_yaml if n.case_id_number == 114] -case_yaml[0].start_date = datetime.datetime(2022, 12, 27, 11, 0, 0) -case_yaml[0].end_date = datetime.datetime(2022, 12, 27, 13, 0, 0) +# Users can also define their own cases_dict structure +case_yaml = ewb.load_cases() +case_list = [n for n in case_yaml if n.case_id_number == 114] + +case_list[0].start_date = datetime.datetime(2022, 12, 27, 11, 0, 0) +case_list[0].end_date = datetime.datetime(2022, 12, 27, 13, 0, 0) # Define ERA5 target -era5_target = inputs.ERA5( +era5_target = ewb.targets.ERA5( variables=[ - derived.AtmosphericRiverVariables( + ewb.derived.AtmosphericRiverVariables( output_variables=["atmospheric_river_land_intersection"] ) ], ) # Define forecast (HRES) -hres_forecast = inputs.ZarrForecast( +hres_forecast = ewb.forecasts.ZarrForecast( source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr", name="HRES", variables=[ - derived.AtmosphericRiverVariables( + ewb.derived.AtmosphericRiverVariables( output_variables=["atmospheric_river_land_intersection"] ) ], - variable_mapping=inputs.HRES_metadata_variable_mapping, + variable_mapping=ewb.HRES_metadata_variable_mapping, ) -grap_forecast = inputs.KerchunkForecast( +grap_forecast = ewb.forecasts.KerchunkForecast( name="Graphcast", source="gs://extremeweatherbench/GRAP_v100_IFS.parq", variables=[ - derived.AtmosphericRiverVariables( + ewb.derived.AtmosphericRiverVariables( output_variables=["atmospheric_river_land_intersection"] ) ], - variable_mapping=inputs.CIRA_metadata_variable_mapping, + variable_mapping=ewb.CIRA_metadata_variable_mapping, storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, preprocess=_preprocess_cira_forecast_dataset, ) -pang_forecast = inputs.KerchunkForecast( +pang_forecast = ewb.forecasts.KerchunkForecast( name="Pangu", source="gs://extremeweatherbench/PANG_v100_IFS.parq", variables=[ - derived.AtmosphericRiverVariables( + ewb.derived.AtmosphericRiverVariables( output_variables=["atmospheric_river_land_intersection"] ) ], - variable_mapping=inputs.CIRA_metadata_variable_mapping, + variable_mapping=ewb.CIRA_metadata_variable_mapping, storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, preprocess=_preprocess_cira_forecast_dataset, ) # Create a list of evaluation 
objects for atmospheric river ar_evaluation_objects = [ - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="atmospheric_river", metric_list=[ - metrics.CriticalSuccessIndex(), - metrics.EarlySignal(), - metrics.SpatialDisplacement(), + ewb.metrics.CriticalSuccessIndex(), + ewb.metrics.EarlySignal(), + ewb.metrics.SpatialDisplacement(), ], target=era5_target, forecast=hres_forecast, ), - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="atmospheric_river", metric_list=[ - metrics.CriticalSuccessIndex(), - metrics.EarlySignal(), - metrics.SpatialDisplacement(), + ewb.metrics.CriticalSuccessIndex(), + ewb.metrics.EarlySignal(), + ewb.metrics.SpatialDisplacement(), ], target=era5_target, forecast=grap_forecast, ), - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="atmospheric_river", metric_list=[ - metrics.CriticalSuccessIndex(), - metrics.EarlySignal(), - metrics.SpatialDisplacement(), + ewb.metrics.CriticalSuccessIndex(), + ewb.metrics.EarlySignal(), + ewb.metrics.SpatialDisplacement(), ], target=era5_target, forecast=pang_forecast, @@ -126,7 +127,7 @@ def _preprocess_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: if __name__ == "__main__": # Initialize ExtremeWeatherBench; will only run on cases with event_type # atmospheric_river - ar_ewb = evaluate.ExtremeWeatherBench( + ar_ewb = ewb.evaluation( case_metadata=case_yaml, evaluation_objects=ar_evaluation_objects, ) diff --git a/docs/examples/applied_freeze.py b/docs/examples/applied_freeze.py index 8b2325cd..864d76f7 100644 --- a/docs/examples/applied_freeze.py +++ b/docs/examples/applied_freeze.py @@ -1,55 +1,55 @@ import logging import operator -from extremeweatherbench import cases, defaults, evaluate, inputs, metrics +import extremeweatherbench as ewb # Set the logger level to INFO logger = logging.getLogger("extremeweatherbench") logger.setLevel(logging.INFO) # Load case data from the default events.yaml -# Users can also define their own cases -case_yaml = cases.load_ewb_events_yaml_into_case_list() +# Users can also define their own cases_dict structure +case_yaml = ewb.load_cases() # Define targets # ERA5 target -era5_freeze_target = inputs.ERA5( +era5_freeze_target = ewb.targets.ERA5( variables=["surface_air_temperature"], chunks=None, ) # GHCN target -ghcn_freeze_target = inputs.GHCN(variables=["surface_air_temperature"]) +ghcn_freeze_target = ewb.targets.GHCN(variables=["surface_air_temperature"]) # Define forecast (FCNv2 CIRA Virtualizarr) -fcnv2_forecast = inputs.KerchunkForecast( +fcnv2_forecast = ewb.forecasts.KerchunkForecast( name="fcnv2_forecast", source="gs://extremeweatherbench/FOUR_v200_GFS.parq", variables=["surface_air_temperature"], - variable_mapping=inputs.CIRA_metadata_variable_mapping, + variable_mapping=ewb.CIRA_metadata_variable_mapping, storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, - preprocess=defaults._preprocess_cira_forecast_dataset, + preprocess=ewb.defaults._preprocess_bb_cira_forecast_dataset, ) # Load the climatology for DurationMeanError -climatology = defaults.get_climatology(quantile=0.85) +climatology = ewb.get_climatology(quantile=0.85) # Define the metrics metrics_list = [ - metrics.RootMeanSquaredError(), - metrics.MinimumMeanAbsoluteError(), - metrics.DurationMeanError(threshold_criteria=climatology, op_func=operator.le), + ewb.metrics.RootMeanSquaredError(), + ewb.metrics.MinimumMeanAbsoluteError(), + ewb.metrics.DurationMeanError(threshold_criteria=climatology, op_func=operator.le), ] # Create a list of evaluation objects 
for freeze freeze_evaluation_object = [ - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="freeze", metric_list=metrics_list, target=ghcn_freeze_target, forecast=fcnv2_forecast, ), - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="freeze", metric_list=metrics_list, target=era5_freeze_target, @@ -59,13 +59,13 @@ if __name__ == "__main__": # Initialize ExtremeWeatherBench runner instance - ewb = evaluate.ExtremeWeatherBench( + freeze_ewb = ewb.evaluation( case_metadata=case_yaml, evaluation_objects=freeze_evaluation_object, ) # Run the workflow - outputs = ewb.run(parallel_config={"backend": "loky", "n_jobs": 1}) + outputs = freeze_ewb.run(parallel_config={"backend": "loky", "n_jobs": 1}) # Print the outputs; can be saved if desired outputs.to_csv("freeze_outputs.csv") diff --git a/docs/examples/applied_heatwave.py b/docs/examples/applied_heatwave.py index 7f44b081..22c9b809 100644 --- a/docs/examples/applied_heatwave.py +++ b/docs/examples/applied_heatwave.py @@ -1,56 +1,56 @@ import logging import operator -from extremeweatherbench import cases, defaults, evaluate, inputs, metrics +import extremeweatherbench as ewb # Set the logger level to INFO logger = logging.getLogger("extremeweatherbench") logger.setLevel(logging.INFO) # Load case data from the default events.yaml -# Users can also define their own cases -case_yaml = cases.load_ewb_events_yaml_into_case_list() +# Users can also define their own cases_dict structure +case_yaml = ewb.load_cases() # Define targets # ERA5 target -era5_heatwave_target = inputs.ERA5( +era5_heatwave_target = ewb.targets.ERA5( variables=["surface_air_temperature"], chunks=None, ) # GHCN target -ghcn_heatwave_target = inputs.GHCN( +ghcn_heatwave_target = ewb.targets.GHCN( variables=["surface_air_temperature"], ) # Define forecast (HRES) -hres_forecast = inputs.ZarrForecast( +hres_forecast = ewb.forecasts.ZarrForecast( name="hres_forecast", source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr", variables=["surface_air_temperature"], - variable_mapping=inputs.HRES_metadata_variable_mapping, + variable_mapping=ewb.HRES_metadata_variable_mapping, ) # Load the climatology for DurationMeanError -climatology = defaults.get_climatology(quantile=0.85) +climatology = ewb.get_climatology(quantile=0.85) # Define the metrics metrics_list = [ - metrics.MaximumMeanAbsoluteError(), - metrics.RootMeanSquaredError(), - metrics.DurationMeanError(threshold_criteria=climatology, op_func=operator.ge), - metrics.MaximumLowestMeanAbsoluteError(), + ewb.metrics.MaximumMeanAbsoluteError(), + ewb.metrics.RootMeanSquaredError(), + ewb.metrics.DurationMeanError(threshold_criteria=climatology, op_func=operator.ge), + ewb.metrics.MaximumLowestMeanAbsoluteError(), ] # Create a list of evaluation objects for heatwave heatwave_evaluation_object = [ - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="heat_wave", metric_list=metrics_list, target=ghcn_heatwave_target, forecast=hres_forecast, ), - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="heat_wave", metric_list=metrics_list, target=era5_heatwave_target, @@ -59,11 +59,11 @@ ] if __name__ == "__main__": # Initialize ExtremeWeatherBench - ewb = evaluate.ExtremeWeatherBench( + heatwave_ewb = ewb.evaluation( case_metadata=case_yaml, evaluation_objects=heatwave_evaluation_object, ) # Run the workflow - outputs = ewb.run(parallel_config={"backend": "loky", "n_jobs": 2}) + outputs = heatwave_ewb.run(parallel_config={"backend": "loky", "n_jobs": 2}) 
outputs.to_csv("applied_heatwave_outputs.csv") diff --git a/docs/examples/applied_severe.py b/docs/examples/applied_severe.py index 6cdba3f9..f6a07003 100644 --- a/docs/examples/applied_severe.py +++ b/docs/examples/applied_severe.py @@ -1,6 +1,6 @@ import logging -from extremeweatherbench import cases, derived, evaluate, inputs, metrics +import extremeweatherbench as ewb # Set the logger level to INFO logger = logging.getLogger("extremeweatherbench") @@ -8,45 +8,45 @@ # Load case data from the default events.yaml -case_yaml = cases.load_ewb_events_yaml_into_case_list() -case_yaml = [n for n in case_yaml if n.case_id_number == 305] +case_yaml = ewb.load_cases() +case_list = [n for n in case_yaml if n.case_id_number == 305] # Define PPH target -pph_target = inputs.PPH( +pph_target = ewb.targets.PPH( variables=["practically_perfect_hindcast"], ) # Define LSR target -lsr_target = inputs.LSR() +lsr_target = ewb.targets.LSR() # Define HRES forecast -hres_forecast = inputs.ZarrForecast( +hres_forecast = ewb.forecasts.ZarrForecast( name="hres_forecast", source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr", - variables=[derived.CravenBrooksSignificantSevere(layer_depth=100)], - variable_mapping=inputs.HRES_metadata_variable_mapping, + variables=[ewb.derived.CravenBrooksSignificantSevere(layer_depth=100)], + variable_mapping=ewb.HRES_metadata_variable_mapping, storage_options={"remote_options": {"anon": True}}, ) # Define pph metrics as thresholdmetric to share scores contingency table pph_metrics = [ - metrics.ThresholdMetric( + ewb.metrics.ThresholdMetric( metrics=[ - metrics.CriticalSuccessIndex, - metrics.FalseAlarmRatio, + ewb.metrics.CriticalSuccessIndex, + ewb.metrics.FalseAlarmRatio, ], forecast_threshold=15000, target_threshold=0.3, ), - metrics.EarlySignal(threshold=15000), + ewb.metrics.EarlySignal(threshold=15000), ] # Define LSR metrics as thresholdmetric to share scores contingency table lsr_metrics = [ - metrics.ThresholdMetric( + ewb.metrics.ThresholdMetric( metrics=[ - metrics.TruePositives, - metrics.FalseNegatives, + ewb.metrics.TruePositives, + ewb.metrics.FalseNegatives, ], forecast_threshold=15000, target_threshold=0.5, @@ -56,7 +56,7 @@ # Define evaluation objects for severe convection: # One evaluation object for PPH pph_evaluation_objects = [ - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="severe_convection", metric_list=pph_metrics, target=pph_target, @@ -66,7 +66,7 @@ # One evaluation object for LSR lsr_evaluation_objects = [ - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="severe_convection", metric_list=lsr_metrics, target=lsr_target, @@ -76,14 +76,14 @@ if __name__ == "__main__": # Initialize ExtremeWeatherBench with both evaluation objects - ewb = evaluate.ExtremeWeatherBench( - case_metadata=case_yaml, + severe_ewb = ewb.evaluation( + case_metadata=case_list, evaluation_objects=lsr_evaluation_objects + pph_evaluation_objects, ) logger.info("Starting EWB run") # Run the workflow with parllel_config backend set to dask - outputs = ewb.run(parallel_config={"backend": "loky", "n_jobs": 3}) + outputs = severe_ewb.run(parallel_config={"backend": "loky", "n_jobs": 3}) # Save the results to a CSV file outputs.to_csv("applied_severe_convection_results.csv") diff --git a/docs/examples/applied_tc.py b/docs/examples/applied_tc.py index 79d4d3b4..e0d17d3c 100644 --- a/docs/examples/applied_tc.py +++ b/docs/examples/applied_tc.py @@ -1,56 +1,102 @@ import logging -from extremeweatherbench import cases, defaults, derived, evaluate, 
inputs, metrics
+import numpy as np
+import xarray as xr
+
+import extremeweatherbench as ewb
 
 # Set the logger level to INFO
 logger = logging.getLogger("extremeweatherbench")
 logger.setLevel(logging.INFO)
 
-# Load the case list from the YAML file
-case_yaml = cases.load_ewb_events_yaml_into_case_list()
+# Preprocessing function for CIRA data that includes geopotential thickness calculation
+# required for tropical cyclone tracks
+def _preprocess_bb_cira_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:
+    """An example preprocess function that renames the time coordinate to lead_time,
+    sets the lead time range and resolution not present in the original dataset, and
+    adds the geopotential thickness required for tropical cyclone tracks.
+
+    Args:
+        ds: The forecast dataset to preprocess.
+
+    Returns:
+        The preprocessed forecast dataset.
+    """
+    ds = ds.rename({"time": "lead_time"})
+    # The evaluation configuration is used to set the lead time range and resolution.
+    ds["lead_time"] = np.array(
+        [i for i in range(0, 241, 6)], dtype="timedelta64[h]"
+    ).astype("timedelta64[ns]")
+    ds["geopotential_thickness"] = ewb.calc.geopotential_thickness(
+        ds["z"], top_level_value=300, bottom_level_value=500
+    )
+    return ds
+
+
+# Preprocessing function for HRES data that includes geopotential thickness calculation
+# required for tropical cyclone tracks
+def _preprocess_hres_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:
+    """An example preprocess function that adds the geopotential thickness required
+    for tropical cyclone tracks, computed from the geopotential between the 500 hPa
+    and 300 hPa levels.
+
+    Args:
+        ds: The forecast dataset to preprocess.
+    """
+    ds["geopotential_thickness"] = ewb.calc.geopotential_thickness(
+        ds["geopotential"],
+        top_level_value=300,
+        bottom_level_value=500,
+        geopotential=True,
+    )
+    return ds
+
+
+# Load the case collection from the YAML file
+case_yaml = ewb.load_cases()
 # Select single case (TC Ida)
-case_yaml = [n for n in case_yaml if n.case_id_number == 220]
+case_list = [n for n in case_yaml if n.case_id_number == 220]
 
 # Define IBTrACS target, no arguments needed as defaults are sufficient
-ibtracs_target = inputs.IBTrACS()
+ibtracs_target = ewb.targets.IBTrACS()
 
 # Define HRES forecast
-hres_forecast = inputs.ZarrForecast(
+hres_forecast = ewb.forecasts.ZarrForecast(
     name="hres_forecast",
     source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr",
     # Define tropical cyclone track derivedvariable to include in the forecast
-    variables=[derived.TropicalCycloneTrackVariables()],
+    variables=[ewb.derived.TropicalCycloneTrackVariables()],
     # Define metadata variable mapping for HRES forecast
-    variable_mapping=inputs.HRES_metadata_variable_mapping,
+    variable_mapping=ewb.HRES_metadata_variable_mapping,
     storage_options={"remote_options": {"anon": True}},
     # Preprocess the HRES forecast to include geopotential thickness calculation
-    preprocess=defaults._preprocess_hres_tc_forecast_dataset,
+    preprocess=ewb.defaults._preprocess_hres_tc_forecast_dataset,
 )
 
-# Define FCNv2 forecast
-fcnv2_forecast = inputs.KerchunkForecast(
+# Define FCNv2 forecast, this is the old version for reference only
+fcnv2_forecast = ewb.forecasts.KerchunkForecast(
     name="fcn_forecast",
     source="gs://extremeweatherbench/FOUR_v200_GFS.parq",
-    variables=[derived.TropicalCycloneTrackVariables()],
+    variables=[ewb.derived.TropicalCycloneTrackVariables()],
    # Define metadata variable mapping for FCNv2 forecast
-    variable_mapping=inputs.CIRA_metadata_variable_mapping,
+    variable_mapping=ewb.CIRA_metadata_variable_mapping,
# Preprocess the FCNv2 forecast to include geopotential thickness calculation - preprocess=defaults._preprocess_cira_tc_forecast_dataset, + preprocess=ewb.defaults._preprocess_cira_tc_forecast_dataset, storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, ) # Define Pangu forecast -pangu_forecast = inputs.KerchunkForecast( +pangu_forecast = ewb.forecasts.KerchunkForecast( name="pangu_forecast", source="gs://extremeweatherbench/PANG_v100_GFS.parq", - variables=[derived.TropicalCycloneTrackVariables()], + variables=[ewb.derived.TropicalCycloneTrackVariables()], # Define metadata variable mapping for Pangu forecast - variable_mapping=inputs.CIRA_metadata_variable_mapping, + variable_mapping=ewb.CIRA_metadata_variable_mapping, # Preprocess the Pangu forecast to include geopotential thickness calculation # which uses the same preprocessing function as the FCNv2 forecast - preprocess=defaults._preprocess_cira_tc_forecast_dataset, + preprocess=ewb.defaults._preprocess_cira_tc_forecast_dataset, storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, ) @@ -60,11 +106,11 @@ # the evaluation to occur, in the case of multiple landfalls, for the next landfall in # time to be evaluated against composite_landfall_metrics = [ - metrics.LandfallMetric( + ewb.metrics.LandfallMetric( metrics=[ - metrics.LandfallIntensityMeanAbsoluteError, - metrics.LandfallTimeMeanError, - metrics.LandfallDisplacement, + ewb.metrics.LandfallIntensityMeanAbsoluteError, + ewb.metrics.LandfallTimeMeanError, + ewb.metrics.LandfallDisplacement, ], approach="next", # Set the intensity variable to use for the metric @@ -77,21 +123,21 @@ # the relevant cases inside the events YAML file tc_evaluation_object = [ # HRES forecast - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="tropical_cyclone", metric_list=composite_landfall_metrics, target=ibtracs_target, forecast=hres_forecast, ), # Pangu forecast - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="tropical_cyclone", metric_list=composite_landfall_metrics, target=ibtracs_target, forecast=pangu_forecast, ), # FCNv2 forecast - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="tropical_cyclone", metric_list=composite_landfall_metrics, target=ibtracs_target, @@ -101,13 +147,13 @@ if __name__ == "__main__": # Initialize ExtremeWeatherBench - ewb = evaluate.ExtremeWeatherBench( - case_metadata=case_yaml, + tc_ewb = ewb.evaluation( + case_metadata=case_list, evaluation_objects=tc_evaluation_object, ) logger.info("Starting EWB run") # Run the workflow with parallel_config backend set to dask - outputs = ewb.run( + outputs = tc_ewb.run( parallel_config={"backend": "loky", "n_jobs": 3}, ) outputs.to_csv("tc_metric_test_results.csv") diff --git a/docs/examples/example_config.py b/docs/examples/example_config.py index 3aff5967..f41d5e44 100644 --- a/docs/examples/example_config.py +++ b/docs/examples/example_config.py @@ -7,31 +7,29 @@ ewb --config-file example_config.py """ -from extremeweatherbench import cases, inputs, metrics +import extremeweatherbench as ewb # Define targets (observation data) -era5_heatwave_target = inputs.ERA5( +era5_heatwave_target = ewb.targets.ERA5( variables=["surface_air_temperature"], chunks=None, ) # Define forecasts -fcnv2_forecast = inputs.KerchunkForecast( +fcnv2_forecast = ewb.forecasts.KerchunkForecast( name="fcnv2_forecast", source="gs://extremeweatherbench/FOUR_v200_GFS.parq", variables=["surface_air_temperature"], - variable_mapping=inputs.CIRA_metadata_variable_mapping, + 
variable_mapping=ewb.CIRA_metadata_variable_mapping, ) # Define evaluation objects evaluation_objects = [ - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="heat_wave", metric_list=[ - metrics.MaximumMeanAbsoluteError(), - metrics.RootMeanSquaredError(), - metrics.OnsetMeanError(), - metrics.DurationMeanError(), + ewb.metrics.MaximumMeanAbsoluteError(), + ewb.metrics.RootMeanSquaredError(), ], target=era5_heatwave_target, forecast=fcnv2_forecast, @@ -39,8 +37,9 @@ ] # Load case data from the default events.yaml -# Users can also define their own cases -cases_list = cases.load_ewb_events_yaml_into_case_list() +# Users can also define their own cases_dict structure +cases_list = ewb.load_cases() + # Alternatively, users could define custom cases like this: # cases_list = [ # { diff --git a/docs/parallelism.md b/docs/parallelism.md index 98ceb4be..fae25fea 100644 --- a/docs/parallelism.md +++ b/docs/parallelism.md @@ -35,7 +35,7 @@ ewb = evaluate.ExtremeWeatherBench( # The larger the machine, the larger n_jobs can be (a bit of an oversimplification) parallel_config = {"backend":"loky","n_jobs":len(evaluation_objects)} -outputs = ewb.run(parallel_config=parallel_config) +outputs = ewb.run_evaluation(parallel_config=parallel_config) ``` The _safest_ approach is to run EWB in serial, with `n_jobs` set to 1. `Dask` will still be invoked during each `CaseOperator` when the case executes and computes the directed acyclic graph, only one at a time. That said, for evaluations with more cases this approach would likely be too time-consuming. \ No newline at end of file diff --git a/docs/usage.md b/docs/usage.md index d5e37778..ddd372fa 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -10,16 +10,17 @@ tropical cyclones, and atmospheric rivers: ```python -from extremeweatherbench import evaluate, defaults, cases +import extremeweatherbench as ewb -eval_objects = defaults.get_brightband_evaluation_objects() +eval_objects = ewb.get_brightband_evaluation_objects() +cases = ewb.load_cases() -cases = cases.load_ewb_events_yaml_into_case_list() -ewb = ExtremeWeatherBench(cases=cases, -evaluation_objects=eval_objects) - -outputs = ewb.run() +runner = ewb.evaluation( + case_metadata=cases, + evaluation_objects=eval_objects +) +outputs = runner.run() outputs.to_csv('your_outputs.csv') ``` @@ -28,6 +29,30 @@ or: ```bash ewb --default ``` + +## API Overview + +ExtremeWeatherBench provides a hierarchical API for accessing its components: + +```python +import extremeweatherbench as ewb + +# Main evaluation entry point +ewb.evaluation(...) # Alias for ExtremeWeatherBench class + +# Hierarchical access via namespaces +ewb.targets.ERA5(...) # Target classes +ewb.forecasts.ZarrForecast(...) # Forecast classes +ewb.metrics.MeanAbsoluteError() # Metric classes +ewb.derived.AtmosphericRiverVariables() # Derived variables +ewb.regions.BoundingBoxRegion(...) # Region classes +ewb.cases.IndividualCase # Case metadata classes + +# Also available at top level for convenience +ewb.ERA5(...) +ewb.ZarrForecast(...) +ewb.load_cases() +``` ## Running an Evaluation for a Single Event Type ExtremeWeatherBench has default event types and cases for heat waves, freezes, severe convection, tropical cyclones, and atmospheric rivers. @@ -39,20 +64,20 @@ ExtremeWeatherBench requires forecasts to have `init_time`, `lead_time`, `latitu Targets require at least a `valid_time` with at least one spatial dimension. Examples include `location`, `station`, or (`latitude`, `longitude`). 
Forecasts are aligned to targets during the steps immediately prior to evaluating a metric. ```python -from extremeweatherbench import inputs +import extremeweatherbench as ewb ``` There are three built-in `ForecastBase` classes to set up a forecast: `ZarrForecast`, `XarrayForecast`, and `KerchunkForecast`. Here is an example of a `ZarrForecast`, using Weatherbench2's HRES zarr store: ```python -hres_forecast = inputs.ZarrForecast( +hres_forecast = ewb.forecasts.ZarrForecast( source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr", name="HRES", variables=["surface_air_temperature"], - variable_mapping=inputs.HRES_metadata_variable_mapping, # built-in mapping available + variable_mapping=ewb.HRES_metadata_variable_mapping, # built-in mapping available storage_options={"remote_options": {"anon": True}}, - ) ``` + There are required arguments, namely: - `source` @@ -67,8 +92,8 @@ There are required arguments, namely: Next, a target dataset must be defined as well to evaluate against. For this evaluation, we'll use ERA5: ```python -era5_heatwave_target = inputs.ERA5( - source=inputs.ARCO_ERA5_FULL_URI, +era5_heatwave_target = ewb.targets.ERA5( + source=ewb.ARCO_ERA5_FULL_URI, variables=["surface_air_temperature"], storage_options={"remote_options": {"anon": True}}, chunks=None, @@ -87,48 +112,53 @@ Or (if defining variables as arguments to the metrics): era5_heatwave_target = inputs.ERA5() ``` -> **Detailed Explanation**: Similarly to forecasts, we need to define the `source`, which here is the ARCO ERA5 provided by Google. `variables` are used to subset `inputs.ERA5` in an evaluation; `variable_mapping` defaults to `inputs.ERA5_metadata_variable_mapping` for many existing variables and likely is not required to be set unless your use case is for less common variables. Both forecasts and targets, if relevant, have an optional `chunks` parameter which defaults to what should be the most efficient value - usually `None` or `'auto'`, but can be changed as seen above. *If using the ARCO ERA5 and setting `chunks=None`, it is critical to order your subsetting by variables -> time -> `.sel` or `.isel` latitude & longitude -> rechunk. [See this Github comment](https://github.com/pydata/xarray/issues/8902#issuecomment-2036435045). +> **Detailed Explanation**: Similarly to forecasts, we need to define the `source`, which here is the ARCO ERA5 provided by Google. `variables` are used to subset `ewb.inputs.ERA5` in an evaluation; `variable_mapping` defaults to `ewb.inputs.ERA5_metadata_variable_mapping` for many existing variables and likely is not required to be set unless your use case is for less common variables. Both forecasts and targets, if relevant, have an optional `chunks` parameter which defaults to what should be the most efficient value - usually `None` or `'auto'`, but can be changed as seen above. *If using the ARCO ERA5 and setting `chunks=None`, it is critical to order your subsetting by variables -> time -> `.sel` or `.isel` latitude & longitude -> rechunk. [See this Github comment](https://github.com/pydata/xarray/issues/8902#issuecomment-2036435045). 
We then set up an `EvaluationObject` list: ```python -from extremeweatherbench import metrics - heatwave_evaluation_list = [ - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="heat_wave", metric_list=[ - metrics.MaximumMeanAbsoluteError(), - metrics.RootMeanSquaredError(), - metrics.MaximumLowestMeanAbsoluteError() + ewb.metrics.MaximumMeanAbsoluteError(), + ewb.metrics.RootMeanSquaredError(), + ewb.metrics.MaximumLowestMeanAbsoluteError() ], target=era5_heatwave_target, forecast=hres_forecast, ), ] ``` + Which includes the event_type of interest (as defined in the case dictionary or YAML file used), the list of metrics to run, one target, and one forecast. There can be multiple `EvaluationObjects` which are used for an evaluation run. Plugging these all in: ```python -from extremeweatherbench import cases, evaluate -case_list = cases.load_ewb_events_yaml_into_case_list() - +case_yaml = ewb.load_cases() -ewb_instance = evaluate.ExtremeWeatherBench( - cases=case_list, +ewb_instance = ewb.evaluation( + case_metadata=case_yaml, evaluation_objects=heatwave_evaluation_list, ) outputs = ewb_instance.run() - outputs.to_csv('your_file_name.csv') ``` -Where the EWB default events YAML file is loaded in using a built-in utility helper function, then applied to an instance of `evaluate.ExtremeWeatherBench` along with the `EvaluationObject` list. Finally, we trigger the evaluation with the `.run()` method, where defaults are typically sufficient to run with a small to moderate-sized virtual machine. after subsetting and prior to metric calculation. +Where the EWB default events YAML file is loaded in using `ewb.load_cases()`, then applied to an instance of `ewb.evaluation` along with the `EvaluationObject` list. Finally, we run the evaluation with the `.run()` method, where defaults are typically sufficient to run with a small to moderate-sized virtual machine. Running locally is feasible but is typically bottlenecked heavily by IO and network bandwidth. Even on a gigabit connection, the rate of data access is significantly slower compared to within a cloud provider VM. -The outputs are returned as a pandas DataFrame and can be manipulated in the script, a notebook, or post-hoc after saving it. +The outputs are returned as a pandas DataFrame and can be manipulated in the script, a notebook, etc. 
+
+## Backward Compatibility
+
+All existing import patterns remain functional:
+
+```python
+from extremeweatherbench import evaluate, inputs, cases, metrics # Still works
+from extremeweatherbench.evaluate import ExtremeWeatherBench # Still works
+```
diff --git a/pyproject.toml b/pyproject.toml
index 5edb482d..887442c3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,8 +1,26 @@
 [project]
 name = "extremeweatherbench"
-version = "0.3.0"
+version = "1.0.0"
 description = "Benchmarking weather and weather AI models using extreme events"
+keywords = [
+    "weather",
+    "extreme events",
+    "benchmarking",
+    "forecasting",
+    "climate",
+]
+license = { file = "LICENSE" }
 readme = "README.md"
+classifiers = [
+    "Intended Audience :: Science/Research",
+    "License :: OSI Approved :: MIT License",
+    "Operating System :: OS Independent",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
+    "Topic :: Scientific/Engineering :: Atmospheric Science",
+]
 requires-python = ">=3.11,<3.14"
 dependencies = [
     "dacite>=1.8.1",
@@ -63,6 +81,8 @@ dev = [
     "types-pytz>=2025.2.0.20250809",
     "types-pyyaml>=6.0.12.20241230",
     "types-tqdm>=4.67.0.20250809",
+    "python-semantic-release>=10.3.0",
+    "twine>=5.1.1",
 ]
 docs = [
     "mkdocs>=1.6.1",
@@ -72,17 +92,22 @@ docs = [
     "pymdown-extensions>=10.19.1",
 ]
 
+complete = ["extremeweatherbench[data-prep,multiprocessing]"]
+
 [build-system]
-requires = ["setuptools", "wheel"]
-build-backend = "setuptools.build_meta"
+requires = ["hatchling >= 1.26"]
+build-backend = "hatchling.build"
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/extremeweatherbench"]
 
-[tool.setuptools]
-packages = ["extremeweatherbench"]
-package-dir = { "" = "src" }
-include-package-data = true
+[tool.hatch.build.targets.sdist]
+include = ["src/extremeweatherbench/**/*"]
+
+[project.urls]
+Documentation = "https://extremeweatherbench.readthedocs.io/"
+Repository = "https://github.com/brightbandtech/extremeweatherbench"
 
-[tool.setuptools.package-data]
-extremeweatherbench = ["data/**/*", "data/**/.*"]
 
 [project.scripts]
 ewb = "extremeweatherbench.evaluate_cli:cli_runner"
@@ -125,3 +150,32 @@ docstring-code-line-length = "dynamic"
 
 [tool.ruff.lint.isort]
 case-sensitive = true
+
+[tool.semantic_release]
+version_toml = ["pyproject.toml:project.version"]
+branch = "main"
+dist_path = "dist/"
+upload_to_pypi = false
+remote = { type = "github" }
+commit_author = "semantic-release "
+commit_parser = "conventional"
+commit_parser_options = { parse_squash_commits = "false", parse_merge_commits = "true" }
+minor_tag = "[minor]"
+patch_tag = "[patch]"
+major_tag = "[major]"
+build_command = """
+    uv lock --offline
+    git add uv.lock
+    uv build
+"""
+# Only create GitHub releases for the current version, not historical ones
+github_release_mode = "latest"
+# Ensure assets are only uploaded for the current release, not past ones
+upload_assets_for_all_releases = false
+
+[tool.pytest.ini_options]
+addopts = ["--ignore=tests/test_golden.py", "--cov=extremeweatherbench"]
+markers = [
+    "integration: marks tests as integration tests (may be slow)",
+    "slow: marks tests as slow (may take longer to complete)",
+]
diff --git a/pytest.ini b/pytest.ini
deleted file mode 100644
index 29aab096..00000000
--- a/pytest.ini
+++ /dev/null
@@ -1,5 +0,0 @@
-[pytest]
-addopts = --cov=extremeweatherbench
-markers =
-    integration: marks tests as integration tests (may be slow)
-    slow: marks tests as slow (may take longer to
complete) \ No newline at end of file diff --git a/scripts/brightband_evaluation.py b/scripts/brightband_evaluation.py index 4086ef51..1dc8103b 100644 --- a/scripts/brightband_evaluation.py +++ b/scripts/brightband_evaluation.py @@ -39,5 +39,5 @@ def configure_logger(level=logging.INFO): # Set up parallel configuration parallel_config = {"backend": "loky", "n_jobs": n_processes} - results = ewb.run(parallel_config=parallel_config) + results = ewb.run_evaluation(parallel_config=parallel_config) results.to_csv("brightband_evaluation_results.csv", index=False) diff --git a/src/extremeweatherbench/__init__.py b/src/extremeweatherbench/__init__.py index e69de29b..af479c2f 100644 --- a/src/extremeweatherbench/__init__.py +++ b/src/extremeweatherbench/__init__.py @@ -0,0 +1,339 @@ +"""ExtremeWeatherBench: A benchmarking framework for extreme weather forecasts. + +This module provides the public API for ExtremeWeatherBench. Users can import +the package and access all key functionality: + + import extremeweatherbench as ewb + + # Main entry point for evaluation + ewb.evaluation(case_metadata=..., evaluation_objects=...) + + # Hierarchical access via namespace submodules + ewb.targets.ERA5(...) + ewb.forecasts.ZarrForecast(...) + ewb.metrics.MeanAbsoluteError(...) + + # Also available at top level + ewb.ERA5(...) + ewb.load_cases() +""" + +from types import SimpleNamespace + +# Import actual modules for backwards compatibility +from extremeweatherbench import calc, cases, defaults, derived, metrics, regions, utils + +# Import specific items for top-level access +from extremeweatherbench.calc import ( + convert_from_cartesian_to_latlon, + geopotential_thickness, + great_circle_mask, + haversine_distance, + maybe_calculate_wind_speed, + mixing_ratio, + orography, + pressure_at_surface, + saturation_mixing_ratio, + saturation_vapor_pressure, + specific_humidity_from_relative_humidity, +) +from extremeweatherbench.cases import ( + CaseOperator, + IndividualCase, + build_case_operators, + load_ewb_events_yaml_into_case_list, + load_individual_cases, + load_individual_cases_from_yaml, + read_incoming_yaml, +) +from extremeweatherbench.defaults import ( + DEFAULT_COORDINATE_VARIABLES, + DEFAULT_VARIABLE_NAMES, + cira_fcnv2_atmospheric_river_forecast, + cira_fcnv2_freeze_forecast, + cira_fcnv2_heatwave_forecast, + cira_fcnv2_severe_convection_forecast, + cira_fcnv2_tropical_cyclone_forecast, + era5_atmospheric_river_target, + era5_freeze_target, + era5_heatwave_target, + get_brightband_evaluation_objects, + get_climatology, + ghcn_freeze_target, + ghcn_heatwave_target, + ibtracs_target, + lsr_target, + pph_target, +) +from extremeweatherbench.derived import ( + AtmosphericRiverVariables, + CravenBrooksSignificantSevere, + DerivedVariable, + TropicalCycloneTrackVariables, + maybe_derive_variables, + maybe_include_variables_from_derived_input, +) +from extremeweatherbench.evaluate import ExtremeWeatherBench +from extremeweatherbench.inputs import ( + ARCO_ERA5_FULL_URI, + DEFAULT_GHCN_URI, + ERA5, + GHCN, + IBTRACS_URI, + LSR, + LSR_URI, + PPH, + PPH_URI, + CIRA_metadata_variable_mapping, + ERA5_metadata_variable_mapping, + EvaluationObject, + ForecastBase, + HRES_metadata_variable_mapping, + IBTrACS, + IBTrACS_metadata_variable_mapping, + InputBase, + KerchunkForecast, + TargetBase, + XarrayForecast, + ZarrForecast, + align_forecast_to_target, + check_for_missing_data, + maybe_subset_variables, + open_kerchunk_reference, + zarr_target_subsetter, +) +from extremeweatherbench.metrics import ( + Accuracy, + 
BaseMetric, + CompositeMetric, + CriticalSuccessIndex, + DurationMeanError, + EarlySignal, + FalseAlarmRatio, + FalseNegatives, + FalsePositives, + LandfallDisplacement, + LandfallIntensityMeanAbsoluteError, + LandfallMetric, + LandfallTimeMeanError, + MaximumLowestMeanAbsoluteError, + MaximumMeanAbsoluteError, + MeanAbsoluteError, + MeanError, + MeanSquaredError, + MinimumMeanAbsoluteError, + RootMeanSquaredError, + SpatialDisplacement, + ThresholdMetric, + TrueNegatives, + TruePositives, +) +from extremeweatherbench.regions import ( + REGION_TYPES, + BoundingBoxRegion, + CenteredRegion, + Region, + RegionSubsetter, + ShapefileRegion, + map_to_create_region, + subset_cases_to_region, + subset_results_to_region, +) +from extremeweatherbench.utils import ( + check_for_vars, + convert_day_yearofday_to_time, + convert_init_time_to_valid_time, + convert_longitude_to_180, + convert_longitude_to_360, + convert_valid_time_to_init_time, + derive_indices_from_init_time_and_lead_time, + determine_temporal_resolution, + extract_tc_names, + filter_kwargs_for_callable, + find_common_init_times, + idx_to_coords, + interp_climatology_to_target, + is_valid_landfall, + load_land_geometry, + maybe_cache_and_compute, + maybe_densify_dataarray, + maybe_get_closest_timestamp_to_center_of_valid_times, + maybe_get_operator, + min_if_all_timesteps_present, + min_if_all_timesteps_present_forecast, + read_event_yaml, + remove_ocean_gridpoints, + stack_dataarray_from_dims, +) + +# Aliases +evaluation = ExtremeWeatherBench +load_cases = load_ewb_events_yaml_into_case_list + +# Namespace submodules for convenient grouping (these don't shadow actual modules) +targets = SimpleNamespace( + TargetBase=TargetBase, + ERA5=ERA5, + GHCN=GHCN, + IBTrACS=IBTrACS, + LSR=LSR, + PPH=PPH, +) + +forecasts = SimpleNamespace( + ForecastBase=ForecastBase, + ZarrForecast=ZarrForecast, + KerchunkForecast=KerchunkForecast, + XarrayForecast=XarrayForecast, +) + +__all__ = [ + # Core evaluation + "evaluation", + "ExtremeWeatherBench", + # Modules + "calc", + "cases", + "defaults", + "derived", + "metrics", + "regions", + "utils", + # Namespace submodules + "targets", + "forecasts", + # Aliases + "load_cases", + # calc + "convert_from_cartesian_to_latlon", + "geopotential_thickness", + "great_circle_mask", + "haversine_distance", + "maybe_calculate_wind_speed", + "mixing_ratio", + "orography", + "pressure_at_surface", + "saturation_mixing_ratio", + "saturation_vapor_pressure", + "specific_humidity_from_relative_humidity", + # cases + "CaseOperator", + "IndividualCase", + "build_case_operators", + "load_ewb_events_yaml_into_case_list", + "load_individual_cases", + "load_individual_cases_from_yaml", + "read_incoming_yaml", + # defaults + "DEFAULT_COORDINATE_VARIABLES", + "DEFAULT_VARIABLE_NAMES", + "cira_fcnv2_atmospheric_river_forecast", + "cira_fcnv2_freeze_forecast", + "cira_fcnv2_heatwave_forecast", + "cira_fcnv2_severe_convection_forecast", + "cira_fcnv2_tropical_cyclone_forecast", + "era5_atmospheric_river_target", + "era5_freeze_target", + "era5_heatwave_target", + "get_brightband_evaluation_objects", + "get_climatology", + "ghcn_freeze_target", + "ghcn_heatwave_target", + "ibtracs_target", + "lsr_target", + "pph_target", + # derived + "AtmosphericRiverVariables", + "CravenBrooksSignificantSevere", + "DerivedVariable", + "TropicalCycloneTrackVariables", + "maybe_derive_variables", + "maybe_include_variables_from_derived_input", + # inputs + "ARCO_ERA5_FULL_URI", + "CIRA_metadata_variable_mapping", + "DEFAULT_GHCN_URI", + "ERA5", + 
"ERA5_metadata_variable_mapping", + "EvaluationObject", + "ForecastBase", + "GHCN", + "HRES_metadata_variable_mapping", + "IBTrACS", + "IBTrACS_metadata_variable_mapping", + "IBTRACS_URI", + "InputBase", + "KerchunkForecast", + "LSR", + "LSR_URI", + "PPH", + "PPH_URI", + "TargetBase", + "XarrayForecast", + "ZarrForecast", + "align_forecast_to_target", + "check_for_missing_data", + "maybe_subset_variables", + "open_kerchunk_reference", + "zarr_target_subsetter", + # metrics + "Accuracy", + "BaseMetric", + "CompositeMetric", + "CriticalSuccessIndex", + "DurationMeanError", + "EarlySignal", + "FalseAlarmRatio", + "FalseNegatives", + "FalsePositives", + "LandfallDisplacement", + "LandfallIntensityMeanAbsoluteError", + "LandfallMetric", + "LandfallTimeMeanError", + "MaximumLowestMeanAbsoluteError", + "MaximumMeanAbsoluteError", + "MeanAbsoluteError", + "MeanError", + "MeanSquaredError", + "MinimumMeanAbsoluteError", + "RootMeanSquaredError", + "SpatialDisplacement", + "ThresholdMetric", + "TrueNegatives", + "TruePositives", + # regions + "BoundingBoxRegion", + "CenteredRegion", + "REGION_TYPES", + "Region", + "RegionSubsetter", + "ShapefileRegion", + "map_to_create_region", + "subset_cases_to_region", + "subset_results_to_region", + # utils + "check_for_vars", + "convert_day_yearofday_to_time", + "convert_init_time_to_valid_time", + "convert_longitude_to_180", + "convert_longitude_to_360", + "convert_valid_time_to_init_time", + "derive_indices_from_init_time_and_lead_time", + "determine_temporal_resolution", + "extract_tc_names", + "filter_kwargs_for_callable", + "find_common_init_times", + "idx_to_coords", + "interp_climatology_to_target", + "is_valid_landfall", + "load_land_geometry", + "maybe_cache_and_compute", + "maybe_densify_dataarray", + "maybe_get_closest_timestamp_to_center_of_valid_times", + "maybe_get_operator", + "min_if_all_timesteps_present", + "min_if_all_timesteps_present_forecast", + "read_event_yaml", + "remove_ocean_gridpoints", + "stack_dataarray_from_dims", +] diff --git a/src/extremeweatherbench/cases.py b/src/extremeweatherbench/cases.py index 3f2e858c..236104c9 100644 --- a/src/extremeweatherbench/cases.py +++ b/src/extremeweatherbench/cases.py @@ -14,10 +14,11 @@ import dacite import yaml # type: ignore[import] -from extremeweatherbench import regions +import extremeweatherbench.regions as regions if TYPE_CHECKING: - from extremeweatherbench import inputs, metrics + import extremeweatherbench.inputs as inputs + import extremeweatherbench.metrics as metrics logger = logging.getLogger(__name__) diff --git a/src/extremeweatherbench/defaults.py b/src/extremeweatherbench/defaults.py index 7dcc68d6..41d20e16 100644 --- a/src/extremeweatherbench/defaults.py +++ b/src/extremeweatherbench/defaults.py @@ -309,7 +309,7 @@ def get_brightband_evaluation_objects() -> list[inputs.EvaluationObject]: routine. 
""" # Import metrics here to avoid circular import - from extremeweatherbench import metrics + import extremeweatherbench.metrics as metrics heatwave_metric_list: list[metrics.BaseMetric] = [ metrics.MaximumMeanAbsoluteError(), diff --git a/src/extremeweatherbench/derived.py b/src/extremeweatherbench/derived.py index 0e609fbe..7517d232 100644 --- a/src/extremeweatherbench/derived.py +++ b/src/extremeweatherbench/derived.py @@ -10,7 +10,7 @@ from extremeweatherbench.events import tropical_cyclone if TYPE_CHECKING: - from extremeweatherbench import cases + import extremeweatherbench.cases as cases logger = logging.getLogger(__name__) diff --git a/src/extremeweatherbench/evaluate.py b/src/extremeweatherbench/evaluate.py index 64c0ba96..504d7694 100644 --- a/src/extremeweatherbench/evaluate.py +++ b/src/extremeweatherbench/evaluate.py @@ -15,10 +15,15 @@ from tqdm.contrib.logging import logging_redirect_tqdm from tqdm.dask import TqdmCallback -from extremeweatherbench import cases, derived, inputs, metrics, sources, utils +import extremeweatherbench.cases as cases +import extremeweatherbench.derived as derived +import extremeweatherbench.inputs as inputs +import extremeweatherbench.metrics as metrics +import extremeweatherbench.sources as sources +import extremeweatherbench.utils as utils if TYPE_CHECKING: - from extremeweatherbench import regions + import extremeweatherbench.regions as regions logger = logging.getLogger(__name__) @@ -101,11 +106,11 @@ def run( parallel_config: Optional[dict] = None, **kwargs, ) -> pd.DataFrame: - """Runs the ExtremeWeatherBench workflow. + """Runs the ExtremeWeatherBench evaluation workflow. - This method will run the workflow in the order of the case operators, optionally - caching the mid-flight outputs of the workflow if cache_dir was provided for - serial runs. + This method will run the evaluation workflow in the order of the case operators, + optionally caching the mid-flight outputs of the workflow if cache_dir was + provided for serial runs. Args: n_jobs: The number of jobs to run in parallel. If None, defaults to the @@ -113,16 +118,60 @@ def run( Ignored if parallel_config is provided. parallel_config: Optional dictionary of joblib parallel configuration. If provided, this takes precedence over n_jobs. If not provided and - n_jobs is specified, a default config with loky backend is used. + n_jobs is specified, a default config with the loky backend is used. + **kwargs: Additional arguments to pass to compute_case_operator. + Returns: + A concatenated dataframe of the evaluation results. + """ + logger.warning("The run method is deprecated. Use run_evaluation instead.") + logger.info("Running ExtremeWeatherBench evaluations...") + + # Check for serial or parallel configuration + parallel_config = _parallel_serial_config_check(n_jobs, parallel_config) + + run_results = _run_evaluation( + self.case_operators, + cache_dir=self.cache_dir, + parallel_config=parallel_config, + **kwargs, + ) + + # If there are results, concatenate them and return, else return an empty + # DataFrame with the expected columns + if run_results: + return _safe_concat(run_results, ignore_index=True) + else: + return pd.DataFrame(columns=OUTPUT_COLUMNS) + + def run_evaluation( + self, + n_jobs: Optional[int] = None, + parallel_config: Optional[dict] = None, + **kwargs, + ) -> pd.DataFrame: + """Runs the ExtremeWeatherBench evaluation workflow. 
+ + This method will run the evaluation workflow in the order of the case operators, + optionally caching the mid-flight outputs of the workflow if cache_dir was + provided for serial runs. + Args: + n_jobs: The number of jobs to run in parallel. If None, defaults to the + joblib backend default value. If 1, the workflow will run serially. + Ignored if parallel_config is provided. + parallel_config: Optional dictionary of joblib parallel configuration. + If provided, this takes precedence over n_jobs. If not provided and + n_jobs is specified, a default config with the loky backend is used. + **kwargs: Additional arguments to pass to compute_case_operator. Returns: A concatenated dataframe of the evaluation results. """ - logger.info("Running ExtremeWeatherBench workflow...") + logger.info("Running ExtremeWeatherBench evaluations...") # Check for serial or parallel configuration parallel_config = _parallel_serial_config_check(n_jobs, parallel_config) - run_results = _run_case_operators( + + run_results = _run_evaluation( self.case_operators, cache_dir=self.cache_dir, parallel_config=parallel_config, @@ -137,7 +186,48 @@ def run( return pd.DataFrame(columns=OUTPUT_COLUMNS) -def _run_case_operators( +def _parallel_serial_config_check( + n_jobs: Optional[int] = None, + parallel_config: Optional[dict] = None, +) -> Optional[dict]: + """Check if running in serial or parallel mode. + + Args: + n_jobs: The number of jobs to run in parallel. If None, defaults to the + joblib backend default value. If 1, the workflow will run serially. + parallel_config: Optional dictionary of joblib parallel configuration. If + provided, this takes precedence over n_jobs. If not provided and n_jobs is + specified, a default config with loky backend is used. + Returns: + None if running in serial mode, otherwise a dictionary of joblib parallel + configuration. + """ + # Determine if running in serial or parallel mode + # Serial: n_jobs=1 or (parallel_config with n_jobs=1) + # Parallel: n_jobs>1 or (parallel_config with n_jobs>1) + is_serial = ( + (n_jobs == 1) + or (parallel_config is not None and parallel_config.get("n_jobs") == 1) + or (n_jobs is None and parallel_config is None) + ) + logger.debug("Running in %s mode.", "serial" if is_serial else "parallel") + + if not is_serial: + # Build parallel_config if not provided + if parallel_config is None and n_jobs is not None: + logger.debug( + "No parallel_config provided, using loky backend and %s jobs.", + n_jobs, + ) + parallel_config = {"backend": "loky", "n_jobs": n_jobs} + # If running in serial mode, set parallel_config to None if it is not already + else: + parallel_config = None + # Return the (maybe updated) parallel_config + return parallel_config + + +def _run_evaluation( case_operators: list["cases.CaseOperator"], cache_dir: Optional[pathlib.Path] = None, parallel_config: Optional[dict] = None, @@ -154,37 +244,29 @@ def _run_case_operators( Returns: List of result DataFrames.
""" - with logging_redirect_tqdm(): - # Run in parallel if parallel_config exists and n_jobs != 1 - if parallel_config is not None: + if parallel_config is not None: + with logging_redirect_tqdm(): logger.info("Running case operators in parallel...") - return _run_parallel( + run_results = _run_parallel_evaluation( case_operators, cache_dir=cache_dir, parallel_config=parallel_config, **kwargs, ) - else: - logger.info("Running case operators in serial...") - return _run_serial(case_operators, cache_dir=cache_dir, **kwargs) - - -def _run_serial( - case_operators: list["cases.CaseOperator"], - cache_dir: Optional[pathlib.Path] = None, - **kwargs, -) -> list[pd.DataFrame]: - """Run the case operators in serial.""" - run_results = [] + else: + logger.info("Running case operators in serial...") + run_results = [] + for case_operator in tqdm(case_operators): + run_results.append( + compute_case_operator(case_operator, cache_dir, **kwargs) + ) - # Loop over the case operators - for case_operator in tqdm(case_operators): - run_results.append(compute_case_operator(case_operator, cache_dir, **kwargs)) return run_results -def _run_parallel( +def _run_parallel_evaluation( case_operators: list["cases.CaseOperator"], + parallel_config: dict, cache_dir: Optional[pathlib.Path] = None, **kwargs, ) -> list[pd.DataFrame]: @@ -197,11 +279,6 @@ def _run_parallel( Returns: List of result DataFrames. """ - parallel_config = kwargs.pop("parallel_config", None) - - if parallel_config is None: - raise ValueError("parallel_config must be provided to _run_parallel") - if parallel_config.get("n_jobs") is None: logger.warning("No number of jobs provided, using joblib backend default.") @@ -900,44 +977,3 @@ def _safe_concat( return pd.concat(valid_dfs, ignore_index=ignore_index) else: return pd.DataFrame(columns=OUTPUT_COLUMNS) - - -def _parallel_serial_config_check( - n_jobs: Optional[int] = None, - parallel_config: Optional[dict] = None, -) -> Optional[dict]: - """Check if running in serial or parallel mode. - - Args: - n_jobs: The number of jobs to run in parallel. If None, defaults to the - joblib backend default value. If 1, the workflow will run serially. - parallel_config: Optional dictionary of joblib parallel configuration. If - provided, this takes precedence over n_jobs. If not provided and n_jobs is - specified, a default config with loky backend is used. - Returns: - None if running in serial mode, otherwise a dictionary of joblib parallel - configuration. 
- """ - # Determine if running in serial or parallel mode - # Serial: n_jobs=1 or (parallel_config with n_jobs=1) - # Parallel: n_jobs>1 or (parallel_config with n_jobs>1) - is_serial = ( - (n_jobs == 1) - or (parallel_config is not None and parallel_config.get("n_jobs") == 1) - or (n_jobs is None and parallel_config is None) - ) - logger.debug("Running in %s mode.", "serial" if is_serial else "parallel") - - if not is_serial: - # Build parallel_config if not provided - if parallel_config is None and n_jobs is not None: - logger.debug( - "No parallel_config provided, using loky backend and %s jobs.", - n_jobs, - ) - parallel_config = {"backend": "loky", "n_jobs": n_jobs} - # If running in serial mode, set parallel_config to None if not already - else: - parallel_config = None - # Return the maybe updated kwargs - return parallel_config diff --git a/src/extremeweatherbench/evaluate_cli.py b/src/extremeweatherbench/evaluate_cli.py index 9b978269..1a96dfc1 100644 --- a/src/extremeweatherbench/evaluate_cli.py +++ b/src/extremeweatherbench/evaluate_cli.py @@ -7,7 +7,9 @@ import click import pandas as pd -from extremeweatherbench import cases, defaults, evaluate +import extremeweatherbench.cases as cases +import extremeweatherbench.defaults as defaults +import extremeweatherbench.evaluate as evaluate @click.command() @@ -152,7 +154,7 @@ def cli_runner( # Run evaluation click.echo("Running evaluation...") - results = ewb.run( + results = ewb.run_evaluation( n_jobs=n_jobs, parallel_config=parallel_config, ) diff --git a/src/extremeweatherbench/inputs.py b/src/extremeweatherbench/inputs.py index 8708dde9..37aba5a5 100644 --- a/src/extremeweatherbench/inputs.py +++ b/src/extremeweatherbench/inputs.py @@ -19,10 +19,13 @@ import polars as pl import xarray as xr -from extremeweatherbench import cases, derived, sources, utils +import extremeweatherbench.cases as cases +import extremeweatherbench.derived as derived +import extremeweatherbench.sources as sources +import extremeweatherbench.utils as utils if TYPE_CHECKING: - from extremeweatherbench import metrics + import extremeweatherbench.metrics as metrics logger = logging.getLogger(__name__) diff --git a/src/extremeweatherbench/regions.py b/src/extremeweatherbench/regions.py index 5a36dc0e..7f86189e 100644 --- a/src/extremeweatherbench/regions.py +++ b/src/extremeweatherbench/regions.py @@ -16,7 +16,7 @@ from extremeweatherbench import utils if TYPE_CHECKING: - from extremeweatherbench import cases + import extremeweatherbench.cases as cases logger = logging.getLogger(__name__) @@ -361,8 +361,14 @@ def mask(self, dataset: xr.Dataset, drop: bool = False) -> xr.Dataset: # Note: ShapefileRegion.mask uses slice which doesn't support # prime/antimeridian crossing with OR logic, but regionmask handles it + # Check if latitude is ascending or descending to handle slice correctly + lat_ascending = dataset.latitude[0] < dataset.latitude[-1] + if lat_ascending: + lat_slice = slice(latitude_min, latitude_max) + else: + lat_slice = slice(latitude_max, latitude_min) dataset = dataset.sel( - latitude=slice(latitude_max, latitude_min), + latitude=lat_slice, longitude=slice(longitude_min, longitude_max), drop=drop, ) diff --git a/src/extremeweatherbench/sources/base.py b/src/extremeweatherbench/sources/base.py index dd58641c..e7dbda6e 100644 --- a/src/extremeweatherbench/sources/base.py +++ b/src/extremeweatherbench/sources/base.py @@ -1,7 +1,7 @@ import datetime from typing import Any, Protocol, runtime_checkable -from extremeweatherbench import regions +import 
extremeweatherbench.regions as regions @runtime_checkable diff --git a/src/extremeweatherbench/sources/pandas_dataframe.py b/src/extremeweatherbench/sources/pandas_dataframe.py index 31bc4062..b6eb91a9 100644 --- a/src/extremeweatherbench/sources/pandas_dataframe.py +++ b/src/extremeweatherbench/sources/pandas_dataframe.py @@ -8,7 +8,7 @@ from extremeweatherbench import utils if TYPE_CHECKING: - from extremeweatherbench import regions + import extremeweatherbench.regions as regions def safely_pull_variables( @@ -43,7 +43,7 @@ def safely_pull_variables( >>> list(result.columns) ['temp'] """ - from extremeweatherbench import defaults + import extremeweatherbench.defaults as defaults # Get column names from DataFrame available_columns = list(data.columns) diff --git a/src/extremeweatherbench/sources/polars_lazyframe.py b/src/extremeweatherbench/sources/polars_lazyframe.py index e9e56cf4..f0caa41e 100644 --- a/src/extremeweatherbench/sources/polars_lazyframe.py +++ b/src/extremeweatherbench/sources/polars_lazyframe.py @@ -8,7 +8,7 @@ from extremeweatherbench import utils if TYPE_CHECKING: - from extremeweatherbench import regions + import extremeweatherbench.regions as regions def safely_pull_variables( @@ -47,7 +47,7 @@ def safely_pull_variables( >>> result.collect().columns ['temp'] """ - from extremeweatherbench import defaults + import extremeweatherbench.defaults as defaults # Get column names from LazyFrame available_columns = data.collect_schema().names() diff --git a/src/extremeweatherbench/sources/xarray_dataarray.py b/src/extremeweatherbench/sources/xarray_dataarray.py index e58d82d6..f3b4e734 100644 --- a/src/extremeweatherbench/sources/xarray_dataarray.py +++ b/src/extremeweatherbench/sources/xarray_dataarray.py @@ -5,7 +5,8 @@ import pandas as pd import xarray as xr -from extremeweatherbench import regions, utils +import extremeweatherbench.regions as regions +import extremeweatherbench.utils as utils def safely_pull_variables( diff --git a/src/extremeweatherbench/sources/xarray_dataset.py b/src/extremeweatherbench/sources/xarray_dataset.py index 56d52618..ae8e8b89 100644 --- a/src/extremeweatherbench/sources/xarray_dataset.py +++ b/src/extremeweatherbench/sources/xarray_dataset.py @@ -9,7 +9,7 @@ from extremeweatherbench import utils if TYPE_CHECKING: - from extremeweatherbench import regions + import extremeweatherbench.regions as regions def safely_pull_variables( diff --git a/tests/data/golden_tests.yaml b/tests/data/golden_tests.yaml new file mode 100644 index 00000000..35036e6b --- /dev/null +++ b/tests/data/golden_tests.yaml @@ -0,0 +1,57 @@ +- case_id_number: 1 + title: NYC Heat Wave + start_date: 2022-06-19 12:00:00 + end_date: 2022-06-24 00:00:00 + location: + type: bounded_region + parameters: + latitude_min: 40.5 + latitude_max: 41.5 + longitude_min: -75 + longitude_max: -73.5 + event_type: heat_wave +- case_id_number: 2 + title: Europe Freeze + start_date: 2022-12-14 06:00:00 + end_date: 2022-12-18 00:00:00 + location: + type: bounded_region + parameters: + latitude_min: 50 + latitude_max: 55 + longitude_min: -5 + longitude_max: 5 + event_type: freeze +- case_id_number: 3 + title: April 2022 South Carolina + start_date: 2022-04-05 12:00:00 + end_date: 2022-04-06 12:00:00 + location: + type: shapefile_region + parameters: + shapefile_path: tests/data/south_carolina_110m.shp + event_type: severe_convection +- case_id_number: 4 + title: Atmospheric River Alaska + start_date: 2021-06-24 00:00:00 + end_date: 2021-06-27 00:00:00 + location: + type: bounded_region + 
parameters: + latitude_min: 50 + latitude_max: 55 + longitude_min: 185 + longitude_max: 200 + event_type: atmospheric_river +- case_id_number: 5 + title: Tropical Cyclone Max + start_date: 2023-10-06 00:00:00 + end_date: 2023-10-12 00:00:00 + location: + type: bounded_region + parameters: + latitude_min: 11.6 + latitude_max: 20.1 + longitude_min: 256.1 + longitude_max: 262.4 + event_type: tropical_cyclone diff --git a/tests/data/south_carolina_110m.dbf b/tests/data/south_carolina_110m.dbf new file mode 100644 index 00000000..cdc66c52 Binary files /dev/null and b/tests/data/south_carolina_110m.dbf differ diff --git a/tests/data/south_carolina_110m.prj b/tests/data/south_carolina_110m.prj new file mode 100644 index 00000000..f45cbadf --- /dev/null +++ b/tests/data/south_carolina_110m.prj @@ -0,0 +1 @@ +GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137.0,298.257223563]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]] \ No newline at end of file diff --git a/tests/data/south_carolina_110m.shp b/tests/data/south_carolina_110m.shp new file mode 100644 index 00000000..b08f9227 Binary files /dev/null and b/tests/data/south_carolina_110m.shp differ diff --git a/tests/data/south_carolina_110m.shx b/tests/data/south_carolina_110m.shx new file mode 100644 index 00000000..5070f67f Binary files /dev/null and b/tests/data/south_carolina_110m.shx differ diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index 9220093f..18569e6b 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -332,10 +332,10 @@ def test_case_operators_property( # Check that the result is what the mock returned assert result == [sample_case_operator] - @mock.patch("extremeweatherbench.evaluate._run_case_operators") - def test_run_serial( + @mock.patch("extremeweatherbench.evaluate._run_evaluation") + def test_run_serial_evaluation( self, - mock_run_case_operators, + mock_run_evaluation, sample_cases_list, sample_evaluation_object, sample_case_operator, @@ -345,7 +345,7 @@ def test_run_serial( with mock.patch.object( evaluate.ExtremeWeatherBench, "case_operators", new=[sample_case_operator] ): - # Mock _run_case_operators to return a list of DataFrames + # Mock _run_evaluation to return a list of DataFrames mock_result = [ pd.DataFrame( { @@ -355,17 +355,17 @@ def test_run_serial( } ) ] - mock_run_case_operators.return_value = mock_result + mock_run_evaluation.return_value = mock_result ewb = evaluate.ExtremeWeatherBench( case_metadata=sample_cases_list, evaluation_objects=[sample_evaluation_object], ) - result = ewb.run(n_jobs=1) + result = ewb.run_evaluation(n_jobs=1) - # Serial mode passes parallel_config=None - mock_run_case_operators.assert_called_once_with( + # Serial mode should pass parallel_config=None + mock_run_evaluation.assert_called_once_with( [sample_case_operator], cache_dir=None, parallel_config=None, @@ -373,10 +373,10 @@ def test_run_serial( assert isinstance(result, pd.DataFrame) assert len(result) == 1 - @mock.patch("extremeweatherbench.evaluate._run_case_operators") - def test_run_parallel( + @mock.patch("extremeweatherbench.evaluate._run_evaluation") + def test_run_parallel_evaluation( self, - mock_run_case_operators, + mock_run_evaluation, sample_cases_list, sample_evaluation_object, sample_case_operator, @@ -394,16 +394,16 @@ def test_run_parallel( } ) ] - mock_run_case_operators.return_value = mock_result + mock_run_evaluation.return_value = mock_result ewb = evaluate.ExtremeWeatherBench( case_metadata=sample_cases_list, 
evaluation_objects=[sample_evaluation_object], ) - result = ewb.run(n_jobs=2) + result = ewb.run_evaluation(n_jobs=2) - mock_run_case_operators.assert_called_once_with( + mock_run_evaluation.assert_called_once_with( [sample_case_operator], cache_dir=None, parallel_config={"backend": "loky", "n_jobs": 2}, @@ -411,10 +411,10 @@ def test_run_parallel( assert isinstance(result, pd.DataFrame) assert len(result) == 1 - @mock.patch("extremeweatherbench.evaluate._run_case_operators") + @mock.patch("extremeweatherbench.evaluate._run_evaluation") def test_run_with_kwargs( self, - mock_run_case_operators, + mock_run_evaluation, sample_cases_list, sample_evaluation_object, sample_case_operator, @@ -424,37 +424,37 @@ def test_run_with_kwargs( evaluate.ExtremeWeatherBench, "case_operators", new=[sample_case_operator] ): mock_result = [pd.DataFrame({"value": [1.0]})] - mock_run_case_operators.return_value = mock_result + mock_run_evaluation.return_value = mock_result ewb = evaluate.ExtremeWeatherBench( case_metadata=sample_cases_list, evaluation_objects=[sample_evaluation_object], ) - result = ewb.run(n_jobs=1, threshold=0.5) + result = ewb.run_evaluation(n_jobs=1, threshold=0.5) # Check that kwargs were passed through - call_args = mock_run_case_operators.call_args + call_args = mock_run_evaluation.call_args assert call_args[1]["threshold"] == 0.5 assert isinstance(result, pd.DataFrame) - @mock.patch("extremeweatherbench.evaluate._run_case_operators") + @mock.patch("extremeweatherbench.evaluate._run_evaluation") def test_run_empty_results( self, - mock_run_case_operators, + mock_run_evaluation, sample_cases_list, sample_evaluation_object, ): """Test the run method handles empty results.""" with mock.patch.object(evaluate.ExtremeWeatherBench, "case_operators", new=[]): - mock_run_case_operators.return_value = [] + mock_run_evaluation.return_value = [] ewb = evaluate.ExtremeWeatherBench( case_metadata=sample_cases_list, evaluation_objects=[sample_evaluation_object], ) - result = ewb.run() + result = ewb.run_evaluation() assert isinstance(result, pd.DataFrame) assert len(result) == 0 @@ -504,7 +504,7 @@ def mock_compute_with_caching(case_operator, cache_dir_arg, **kwargs): cache_dir=cache_dir, ) - ewb.run(n_jobs=1) + ewb.run_evaluation(n_jobs=1) # Check that cache directory was created assert cache_dir.exists() @@ -538,7 +538,7 @@ def test_run_multiple_cases( evaluation_objects=[sample_evaluation_object], ) - result = ewb.run() + result = ewb.run_evaluation() assert mock_compute_case_operator.call_count == 2 assert len(result) == 2 @@ -546,107 +546,111 @@ def test_run_multiple_cases( class TestRunCaseOperators: - """Test the _run_case_operators function.""" + """Test the _run_evaluation function.""" - @mock.patch("extremeweatherbench.evaluate._run_serial") - def test_run_case_operators_serial(self, mock_run_serial, sample_case_operator): - """Test _run_case_operators routes to serial execution.""" - mock_results = [pd.DataFrame({"value": [1.0]})] - mock_run_serial.return_value = mock_results + @mock.patch("extremeweatherbench.evaluate.compute_case_operator") + @mock.patch("tqdm.auto.tqdm") + def test_run_evaluation_serial( + self, mock_tqdm, mock_compute_case_operator, sample_case_operator + ): + """Test _run_evaluation executes serially when parallel_config=None.""" + mock_tqdm.return_value = [sample_case_operator] + mock_results = pd.DataFrame({"value": [1.0]}) + mock_compute_case_operator.return_value = mock_results # Serial mode: don't pass parallel_config - result = 
evaluate._run_case_operators([sample_case_operator], cache_dir=None) + result = evaluate._run_evaluation([sample_case_operator], cache_dir=None) - mock_run_serial.assert_called_once_with([sample_case_operator], cache_dir=None) - assert result == mock_results + mock_compute_case_operator.assert_called_once_with(sample_case_operator, None) + assert len(result) == 1 + assert result[0].equals(mock_results) - @mock.patch("extremeweatherbench.evaluate._run_parallel") - def test_run_case_operators_parallel(self, mock_run_parallel, sample_case_operator): - """Test _run_case_operators routes to parallel execution.""" + @mock.patch("extremeweatherbench.evaluate._run_parallel_evaluation") + def test_run_evaluation_parallel( + self, mock_run_parallel_evaluation, sample_case_operator + ): + """Test _run_evaluation routes to parallel execution.""" mock_results = [pd.DataFrame({"value": [1.0]})] - mock_run_parallel.return_value = mock_results + mock_run_parallel_evaluation.return_value = mock_results - result = evaluate._run_case_operators( + result = evaluate._run_evaluation( [sample_case_operator], cache_dir=None, parallel_config={"backend": "threading", "n_jobs": 4}, ) - mock_run_parallel.assert_called_once_with( + mock_run_parallel_evaluation.assert_called_once_with( [sample_case_operator], cache_dir=None, parallel_config={"backend": "threading", "n_jobs": 4}, ) assert result == mock_results - @mock.patch("extremeweatherbench.evaluate._run_serial") - def test_run_case_operators_with_kwargs( - self, mock_run_serial, sample_case_operator + @mock.patch("extremeweatherbench.evaluate.compute_case_operator") + @mock.patch("tqdm.auto.tqdm") + def test_run_evaluation_with_kwargs( + self, mock_tqdm, mock_compute_case_operator, sample_case_operator ): - """Test _run_case_operators passes kwargs correctly.""" - mock_results = [pd.DataFrame({"value": [1.0]})] - mock_run_serial.return_value = mock_results + """Test _run_evaluation passes kwargs correctly in serial mode.""" + mock_tqdm.return_value = [sample_case_operator] + mock_results = pd.DataFrame({"value": [1.0]}) + mock_compute_case_operator.return_value = mock_results # Serial mode: don't pass parallel_config - result = evaluate._run_case_operators( + result = evaluate._run_evaluation( [sample_case_operator], cache_dir=None, threshold=0.5, ) - call_args = mock_run_serial.call_args - assert call_args[0][0] == [sample_case_operator] - assert call_args[1]["cache_dir"] is None + call_args = mock_compute_case_operator.call_args + assert call_args[0][0] == sample_case_operator + assert call_args[0][1] is None # cache_dir assert call_args[1]["threshold"] == 0.5 assert isinstance(result, list) - @mock.patch("extremeweatherbench.evaluate._run_parallel") - def test_run_case_operators_parallel_with_kwargs( - self, mock_run_parallel, sample_case_operator + @mock.patch("extremeweatherbench.evaluate._run_parallel_evaluation") + def test_run_evaluation_parallel_with_kwargs( + self, mock_run_parallel_evaluation, sample_case_operator ): - """Test _run_case_operators passes kwargs to parallel execution.""" + """Test _run_evaluation passes kwargs to parallel execution.""" mock_results = [pd.DataFrame({"value": [1.0]})] - mock_run_parallel.return_value = mock_results + mock_run_parallel_evaluation.return_value = mock_results - result = evaluate._run_case_operators( + result = evaluate._run_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": 2}, custom_param="test_value", ) - call_args = mock_run_parallel.call_args + call_args = 
mock_run_parallel_evaluation.call_args assert call_args[0][0] == [sample_case_operator] assert call_args[1]["parallel_config"] == {"backend": "threading", "n_jobs": 2} assert call_args[1]["custom_param"] == "test_value" assert isinstance(result, list) - def test_run_case_operators_empty_list(self): - """Test _run_case_operators with empty case operator list.""" - with mock.patch("extremeweatherbench.evaluate._run_serial") as mock_serial: - mock_serial.return_value = [] - - # Serial mode: don't pass parallel_config - result = evaluate._run_case_operators([], cache_dir=None) - - mock_serial.assert_called_once_with([], cache_dir=None) - assert result == [] + def test_run_evaluation_empty_list(self): + """Test _run_evaluation with empty case operator list.""" + # Serial mode: don't pass parallel_config + result = evaluate._run_evaluation([], cache_dir=None) + assert result == [] class TestRunSerial: - """Test the _run_serial function.""" + """Test the serial execution path of _run_evaluation.""" @mock.patch("extremeweatherbench.evaluate.compute_case_operator") @mock.patch("tqdm.auto.tqdm") - def test_run_serial_basic( + def test_run_serial_evaluation_basic( self, mock_tqdm, mock_compute_case_operator, sample_case_operator ): - """Test basic _run_serial functionality.""" + """Test basic serial execution functionality.""" # Setup mocks mock_tqdm.return_value = [sample_case_operator] # tqdm returns iterable mock_result = pd.DataFrame({"value": [1.0], "case_id_number": [1]}) mock_compute_case_operator.return_value = mock_result - result = evaluate._run_serial([sample_case_operator]) + result = evaluate._run_evaluation([sample_case_operator], parallel_config=None) mock_compute_case_operator.assert_called_once_with(sample_case_operator, None) assert len(result) == 1 @@ -654,8 +658,10 @@ def test_run_serial_basic( @mock.patch("extremeweatherbench.evaluate.compute_case_operator") @mock.patch("tqdm.auto.tqdm") - def test_run_serial_multiple_cases(self, mock_tqdm, mock_compute_case_operator): - """Test _run_serial with multiple case operators.""" + def test_run_serial_evaluation_multiple_cases( + self, mock_tqdm, mock_compute_case_operator + ): + """Test serial execution with multiple case operators.""" case_op_1 = mock.Mock() case_op_2 = mock.Mock() case_operators = [case_op_1, case_op_2] @@ -666,7 +672,7 @@ def test_run_serial_multiple_cases(self, mock_tqdm, mock_compute_case_operator): pd.DataFrame({"value": [2.0], "case_id_number": [2]}), ] - result = evaluate._run_serial(case_operators) + result = evaluate._run_evaluation(case_operators, parallel_config=None) assert mock_compute_case_operator.call_count == 2 assert len(result) == 2 @@ -675,16 +681,19 @@ def test_run_serial_multiple_cases(self, mock_tqdm, mock_compute_case_operator): @mock.patch("extremeweatherbench.evaluate.compute_case_operator") @mock.patch("tqdm.auto.tqdm") - def test_run_serial_with_kwargs( + def test_run_serial_evaluation_with_kwargs( self, mock_tqdm, mock_compute_case_operator, sample_case_operator ): - """Test _run_serial passes kwargs to compute_case_operator.""" + """Test serial execution passes kwargs to compute_case_operator.""" mock_tqdm.return_value = [sample_case_operator] mock_result = pd.DataFrame({"value": [1.0]}) mock_compute_case_operator.return_value = mock_result - result = evaluate._run_serial( - [sample_case_operator], threshold=0.7, custom_param="test" + result = evaluate._run_evaluation( + [sample_case_operator], + parallel_config=None, + threshold=0.7, + custom_param="test", ) call_args = 
mock_compute_case_operator.call_args @@ -693,22 +702,22 @@ def test_run_serial_with_kwargs( assert call_args[1]["custom_param"] == "test" assert isinstance(result, list) - def test_run_serial_empty_list(self): - """Test _run_serial with empty case operator list.""" - result = evaluate._run_serial([]) + def test_run_serial_evaluation_empty_list(self): + """Test serial execution with empty case operator list.""" + result = evaluate._run_evaluation([], parallel_config=None) assert result == [] class TestRunParallel: - """Test the _run_parallel function.""" + """Test the _run_parallel_evaluation function.""" @mock.patch("extremeweatherbench.utils.ParallelTqdm") @mock.patch("joblib.delayed") @mock.patch("tqdm.auto.tqdm") - def test_run_parallel_basic( + def test_run_parallel_evaluation_basic( self, mock_tqdm, mock_delayed, mock_parallel_class, sample_case_operator ): - """Test basic _run_parallel functionality.""" + """Test basic _run_parallel_evaluation functionality.""" # Setup mocks mock_tqdm.return_value = [sample_case_operator] mock_delayed_func = mock.Mock() @@ -719,7 +728,7 @@ def test_run_parallel_basic( mock_result = [pd.DataFrame({"value": [1.0], "case_id_number": [1]})] mock_parallel_instance.return_value = mock_result - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": 2}, ) @@ -735,10 +744,10 @@ def test_run_parallel_basic( @mock.patch("extremeweatherbench.utils.ParallelTqdm") @mock.patch("joblib.delayed") @mock.patch("tqdm.auto.tqdm") - def test_run_parallel_with_none_n_jobs( + def test_run_parallel_evaluation_with_none_n_jobs( self, mock_tqdm, mock_delayed, mock_parallel_class, sample_case_operator ): - """Test _run_parallel with n_jobs=None (should use all CPUs).""" + """Test _run_parallel_evaluation with n_jobs=None (should use all CPUs).""" mock_tqdm.return_value = [sample_case_operator] mock_delayed_func = mock.Mock() mock_delayed.return_value = mock_delayed_func @@ -749,7 +758,7 @@ def test_run_parallel_with_none_n_jobs( mock_parallel_instance.return_value = mock_result with mock.patch("extremeweatherbench.evaluate.logger.warning") as mock_warning: - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": None}, ) @@ -765,7 +774,7 @@ def test_run_parallel_with_none_n_jobs( @mock.patch("joblib.parallel_config") @mock.patch("extremeweatherbench.utils.ParallelTqdm") - def test_run_parallel_n_jobs_in_config( + def test_run_parallel_evaluation_n_jobs_in_config( self, mock_parallel_class, mock_parallel_config ): """Test that n_jobs is passed through parallel_config, not directly.""" @@ -782,7 +791,7 @@ def test_run_parallel_n_jobs_in_config( ) mock_parallel_config.return_value.__exit__ = mock.Mock(return_value=False) - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": 4}, ) @@ -800,10 +809,10 @@ def test_run_parallel_n_jobs_in_config( @mock.patch("extremeweatherbench.utils.ParallelTqdm") @mock.patch("joblib.delayed") @mock.patch("tqdm.auto.tqdm") - def test_run_parallel_multiple_cases( + def test_run_parallel_evaluation_multiple_cases( self, mock_tqdm, mock_delayed, mock_parallel_class ): - """Test _run_parallel with multiple case operators.""" + """Test _run_parallel_evaluation with multiple case operators.""" case_op_1 = mock.Mock() case_op_2 = mock.Mock() case_operators = 
[case_op_1, case_op_2] @@ -820,7 +829,7 @@ def test_run_parallel_multiple_cases( ] mock_parallel_instance.return_value = mock_result - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( case_operators, parallel_config={"backend": "threading", "n_jobs": 4} ) @@ -831,10 +840,10 @@ def test_run_parallel_multiple_cases( @mock.patch("extremeweatherbench.utils.ParallelTqdm") @mock.patch("joblib.delayed") @mock.patch("tqdm.auto.tqdm") - def test_run_parallel_with_kwargs( + def test_run_parallel_evaluation_with_kwargs( self, mock_tqdm, mock_delayed, mock_parallel_class, sample_case_operator ): - """Test _run_parallel passes kwargs correctly.""" + """Test _run_parallel_evaluation passes kwargs correctly.""" mock_tqdm.return_value = [sample_case_operator] mock_delayed_func = mock.Mock() mock_delayed.return_value = mock_delayed_func @@ -844,7 +853,7 @@ def test_run_parallel_with_kwargs( mock_result = [pd.DataFrame({"value": [1.0]})] mock_parallel_instance.return_value = mock_result - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": 2}, threshold=0.8, @@ -861,8 +870,8 @@ def test_run_parallel_with_kwargs( assert len(delayed_calls) == 1 assert isinstance(result, list) - def test_run_parallel_empty_list(self): - """Test _run_parallel with empty case operator list.""" + def test_run_parallel_evaluation_empty_list(self): + """Test _run_parallel_evaluation with empty case operator list.""" with mock.patch( "extremeweatherbench.utils.ParallelTqdm" ) as mock_parallel_class: @@ -872,7 +881,7 @@ def test_run_parallel_empty_list(self): mock_parallel_class.return_value = mock_parallel_instance mock_parallel_instance.return_value = [] - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [], parallel_config={"backend": "threading", "n_jobs": 2} ) @@ -883,10 +892,10 @@ def test_run_parallel_empty_list(self): ) @mock.patch("dask.distributed.Client") @mock.patch("dask.distributed.LocalCluster") - def test_run_parallel_dask_backend_auto_client( + def test_run_parallel_evaluation_dask_backend_auto_client( self, mock_local_cluster, mock_client_class, sample_case_operator ): - """Test _run_parallel with dask backend automatically creates client.""" + """Test _run_parallel_evaluation with dask backend automatically creates client.""" # Mock Client.current() to raise ValueError (no existing client) mock_client_class.current.side_effect = ValueError("No client found") @@ -905,7 +914,7 @@ def test_run_parallel_dask_backend_auto_client( mock_parallel_instance.return_value = [pd.DataFrame({"test": [1]})] with mock.patch("joblib.parallel_config"): - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "dask", "n_jobs": 2}, ) @@ -919,10 +928,10 @@ def test_run_parallel_dask_backend_auto_client( not HAS_DASK_DISTRIBUTED, reason="dask.distributed not installed" ) @mock.patch("dask.distributed.Client") - def test_run_parallel_dask_backend_existing_client( + def test_run_parallel_evaluation_dask_backend_existing_client( self, mock_client_class, sample_case_operator ): - """Test _run_parallel with dask backend uses existing client.""" + """Test _run_parallel_evaluation with dask backend uses existing client.""" # Mock existing client mock_existing_client = mock.Mock() mock_client_class.current.return_value = mock_existing_client @@ -934,7 +943,7 @@ def test_run_parallel_dask_backend_existing_client( 
mock_parallel_instance.return_value = [pd.DataFrame({"test": [1]})] with mock.patch("joblib.parallel_config"): - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "dask", "n_jobs": 2}, ) @@ -1385,7 +1394,7 @@ def test_run_pipeline_target( def test_run_pipeline_invalid_source(self, sample_case_operator): """Test run_pipeline function with invalid input source.""" with pytest.raises(AttributeError, match="'str' object has no attribute"): - evaluate.run_pipeline(sample_case_operator.case_metadata, "invalid") + evaluate.run_pipeline(sample_case_operator.case_metadata, "invalid") # type: ignore def test_maybe_cache_and_compute_with_cache_dir( self, sample_forecast_dataset, sample_target_dataset, sample_individual_case @@ -1600,9 +1609,14 @@ def test_extremeweatherbench_empty_cases(self, sample_evaluation_object): evaluation_objects=[sample_evaluation_object], ) - result = ewb.run() - assert isinstance(result, pd.DataFrame) - assert len(result) == 0 + with mock.patch("extremeweatherbench.cases.build_case_operators") as mock_build: + mock_build.return_value = [] + + result = ewb.run_evaluation() + + # Should return empty DataFrame when no cases + assert isinstance(result, pd.DataFrame) + assert len(result) == 0 def test_compute_case_operator_exception_handling(self, sample_case_operator): """Test exception handling in compute_case_operator.""" @@ -1645,49 +1659,53 @@ def test_evaluate_metric_computation_failure( case_operator=sample_case_operator, ) - @mock.patch("extremeweatherbench.evaluate._run_serial") - def test_run_case_operators_serial_exception( - self, mock_run_serial, sample_case_operator + @mock.patch("extremeweatherbench.evaluate.compute_case_operator") + @mock.patch("tqdm.auto.tqdm") + def test_run_evaluation_serial_exception( + self, mock_tqdm, mock_compute_case_operator, sample_case_operator ): - """Test _run_case_operators handles exceptions in serial execution.""" - mock_run_serial.side_effect = Exception("Serial execution failed") + """Test _run_evaluation handles exceptions in serial execution.""" + mock_tqdm.return_value = [sample_case_operator] + mock_compute_case_operator.side_effect = Exception("Serial execution failed") with pytest.raises(Exception, match="Serial execution failed"): # Serial mode: don't pass parallel_config - evaluate._run_case_operators([sample_case_operator], None) + evaluate._run_evaluation([sample_case_operator], parallel_config=None) - @mock.patch("extremeweatherbench.evaluate._run_parallel") - def test_run_case_operators_parallel_exception( - self, mock_run_parallel, sample_case_operator + @mock.patch("extremeweatherbench.evaluate._run_parallel_evaluation") + def test_run_evaluation_parallel_exception( + self, mock_run_parallel_evaluation, sample_case_operator ): - """Test _run_case_operators handles exceptions in parallel execution.""" - mock_run_parallel.side_effect = Exception("Parallel execution failed") + """Test _run_evaluation handles exceptions in parallel execution.""" + mock_run_parallel_evaluation.side_effect = Exception( + "Parallel execution failed" + ) with pytest.raises(Exception, match="Parallel execution failed"): - evaluate._run_case_operators( + evaluate._run_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": 2}, ) @mock.patch("extremeweatherbench.evaluate.compute_case_operator") @mock.patch("tqdm.auto.tqdm") - def test_run_serial_case_operator_exception( + def test_run_serial_evaluation_case_operator_exception( 
self, mock_tqdm, mock_compute_case_operator, sample_case_operator ): - """Test _run_serial handles exceptions from individual case operators.""" + """Test serial execution handles exceptions from individual case operators.""" mock_tqdm.return_value = [sample_case_operator] mock_compute_case_operator.side_effect = Exception("Case operator failed") with pytest.raises(Exception, match="Case operator failed"): - evaluate._run_serial([sample_case_operator]) + evaluate._run_evaluation([sample_case_operator], parallel_config=None) @mock.patch("extremeweatherbench.utils.ParallelTqdm") @mock.patch("joblib.delayed") @mock.patch("tqdm.auto.tqdm") - def test_run_parallel_joblib_exception( + def test_run_parallel_evaluation_joblib_exception( self, mock_tqdm, mock_delayed, mock_parallel_class, sample_case_operator ): - """Test _run_parallel handles joblib Parallel exceptions.""" + """Test _run_parallel_evaluation handles joblib Parallel exceptions.""" mock_tqdm.return_value = [sample_case_operator] mock_delayed_func = mock.Mock() mock_delayed.return_value = mock_delayed_func @@ -1697,7 +1715,7 @@ def test_run_parallel_joblib_exception( mock_parallel_instance.side_effect = Exception("Joblib parallel failed") with pytest.raises(Exception, match="Joblib parallel failed"): - evaluate._run_parallel( + evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": 2}, ) @@ -1705,10 +1723,10 @@ def test_run_parallel_joblib_exception( @mock.patch("extremeweatherbench.utils.ParallelTqdm") @mock.patch("joblib.delayed") @mock.patch("tqdm.auto.tqdm") - def test_run_parallel_delayed_function_exception( + def test_run_parallel_evaluation_delayed_function_exception( self, mock_tqdm, mock_delayed, mock_parallel_class, sample_case_operator ): - """Test _run_parallel handles exceptions in delayed functions.""" + """Test _run_parallel_evaluation handles exceptions in delayed functions.""" mock_tqdm.return_value = [sample_case_operator] # Mock delayed to raise an exception @@ -1724,12 +1742,12 @@ def consume_generator(generator): mock_parallel_instance.side_effect = consume_generator with pytest.raises(Exception, match="Delayed function creation failed"): - evaluate._run_parallel( + evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": 2}, ) - @mock.patch("extremeweatherbench.evaluate._run_case_operators") + @mock.patch("extremeweatherbench.evaluate._run_evaluation") def test_run_method_exception_propagation( self, mock_run_case_operators, sample_cases_list, sample_evaluation_object ): @@ -1742,12 +1760,14 @@ def test_run_method_exception_propagation( ) with pytest.raises(Exception, match="Execution failed"): - ewb.run() + ewb.run_evaluation() @mock.patch("extremeweatherbench.evaluate.compute_case_operator") @mock.patch("tqdm.auto.tqdm") - def test_run_serial_partial_failure(self, mock_tqdm, mock_compute_case_operator): - """Test _run_serial behavior when some case operators fail.""" + def test_run_serial_evaluation_partial_failure( + self, mock_tqdm, mock_compute_case_operator + ): + """Test serial execution behavior when some case operators fail.""" case_op_1 = mock.Mock() case_op_2 = mock.Mock() case_op_3 = mock.Mock() @@ -1764,7 +1784,7 @@ def test_run_serial_partial_failure(self, mock_tqdm, mock_compute_case_operator) # Should fail on the second case operator with pytest.raises(Exception, match="Case operator 2 failed"): - evaluate._run_serial(case_operators) + evaluate._run_evaluation(case_operators, 
parallel_config=None) # Should have tried only the first two assert mock_compute_case_operator.call_count == 2 @@ -1772,10 +1792,10 @@ def test_run_serial_partial_failure(self, mock_tqdm, mock_compute_case_operator) @mock.patch("extremeweatherbench.utils.ParallelTqdm") @mock.patch("joblib.delayed") @mock.patch("tqdm.auto.tqdm") - def test_run_parallel_invalid_n_jobs( + def test_run_parallel_evaluation_invalid_n_jobs( self, mock_tqdm, mock_delayed, mock_parallel_class, sample_case_operator ): - """Test _run_parallel with invalid n_jobs parameter.""" + """Test _run_parallel_evaluation with invalid n_jobs parameter.""" mock_tqdm.return_value = [sample_case_operator] mock_delayed_func = mock.Mock() mock_delayed.return_value = mock_delayed_func @@ -1784,7 +1804,7 @@ def test_run_parallel_invalid_n_jobs( mock_parallel_class.side_effect = ValueError("Invalid n_jobs parameter") with pytest.raises(ValueError, match="Invalid n_jobs parameter"): - evaluate._run_parallel( + evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": -5}, ) @@ -1869,7 +1889,7 @@ def test_end_to_end_workflow( evaluation_objects=[sample_evaluation_object], ) - result = ewb.run() + result = ewb.run_evaluation() # Verify the result structure assert isinstance(result, pd.DataFrame) @@ -1973,7 +1993,7 @@ def test_multiple_variables_and_metrics( evaluation_objects=[eval_obj], ) - result = ewb.run() + result = ewb.run_evaluation() # Should have results for each metric combination assert len(result) >= 2 # At least 2 metrics * 1 case @@ -2016,12 +2036,12 @@ def test_serial_vs_parallel_results_consistency( # Test serial execution mock_compute_case_operator.side_effect = [result_1, result_2] - serial_result = ewb.run(n_jobs=1) + serial_result = ewb.run_evaluation(n_jobs=1) # Reset mock and test parallel execution mock_compute_case_operator.reset_mock() mock_compute_case_operator.side_effect = [result_1, result_2] - parallel_result = ewb.run(n_jobs=2) + parallel_result = ewb.run_evaluation(n_jobs=2) # Both should produce valid DataFrames with same structure assert isinstance(serial_result, pd.DataFrame) @@ -2031,12 +2051,16 @@ def test_serial_vs_parallel_results_consistency( assert list(serial_result.columns) == list(parallel_result.columns) @mock.patch("extremeweatherbench.evaluate.compute_case_operator") - def test_execution_method_performance_comparison(self, mock_compute_case_operator): + @mock.patch("tqdm.auto.tqdm") + def test_execution_method_performance_comparison( + self, mock_tqdm, mock_compute_case_operator + ): """Test that both execution methods handle the same workload.""" import time # Create many case operators to simulate realistic workload case_operators = [mock.Mock() for _ in range(10)] + mock_tqdm.return_value = case_operators # Mock results mock_results = [ @@ -2051,13 +2075,13 @@ def test_execution_method_performance_comparison(self, mock_compute_case_operato for i in range(10) ] - # Test serial execution timing - call _run_serial directly + # Test serial execution timing - call _run_evaluation in serial mode mock_compute_case_operator.side_effect = mock_results start_time = time.time() - serial_result = evaluate._run_serial(case_operators) + serial_result = evaluate._run_evaluation(case_operators, parallel_config=None) serial_time = time.time() - start_time - # Test parallel execution timing - call _run_parallel directly with mocked + # Test parallel execution timing - call _run_parallel_evaluation directly with mocked # Parallel serial_call_count = 
mock_compute_case_operator.call_count mock_compute_case_operator.side_effect = mock_results @@ -2070,7 +2094,7 @@ def test_execution_method_performance_comparison(self, mock_compute_case_operato mock_parallel_instance.return_value = mock_results start_time = time.time() - parallel_result = evaluate._run_parallel( + parallel_result = evaluate._run_parallel_evaluation( case_operators, parallel_config={"backend": "threading", "n_jobs": 2} ) parallel_time = time.time() - start_time @@ -2087,9 +2111,11 @@ def test_execution_method_performance_comparison(self, mock_compute_case_operato assert parallel_time >= 0 @mock.patch("extremeweatherbench.evaluate.compute_case_operator") - def test_mixed_execution_parameters(self, mock_compute_case_operator): + @mock.patch("tqdm.auto.tqdm") + def test_mixed_execution_parameters(self, mock_tqdm, mock_compute_case_operator): """Test various parameter combinations for execution methods.""" case_operators = [mock.Mock(), mock.Mock()] + mock_tqdm.return_value = case_operators mock_results = [ pd.DataFrame({"value": [1.0], "case_id_number": [1]}), pd.DataFrame({"value": [2.0], "case_id_number": [2]}), @@ -2112,7 +2138,7 @@ def test_mixed_execution_parameters(self, mock_compute_case_operator): mock_compute_case_operator.side_effect = mock_results if config["method"] == "serial": - result = evaluate._run_serial(*config["args"]) + result = evaluate._run_evaluation(*config["args"], parallel_config=None) # All configurations should produce valid results assert isinstance(result, list) assert len(result) == 2 @@ -2135,7 +2161,9 @@ def test_mixed_execution_parameters(self, mock_compute_case_operator): "n_jobs": n_jobs, } - result = evaluate._run_parallel(*config["args"], **kwargs) + result = evaluate._run_parallel_evaluation( + *config["args"], **kwargs + ) # All configurations should produce valid results assert isinstance(result, list) @@ -2156,13 +2184,19 @@ def mock_compute_with_kwargs(case_op, cache_dir, **kwargs): mock_compute_with_kwargs.captured_kwargs = {} - with mock.patch( - "extremeweatherbench.evaluate.compute_case_operator", - side_effect=mock_compute_with_kwargs, + with ( + mock.patch( + "extremeweatherbench.evaluate.compute_case_operator", + side_effect=mock_compute_with_kwargs, + ), + mock.patch("tqdm.auto.tqdm", return_value=[case_operator]), ): # Test serial kwargs propagation - result = evaluate._run_serial( - [case_operator], custom_param="serial_test", threshold=0.9 + result = evaluate._run_evaluation( + [case_operator], + parallel_config=None, + custom_param="serial_test", + threshold=0.9, ) captured = mock_compute_with_kwargs.captured_kwargs @@ -2185,7 +2219,7 @@ def mock_compute_with_kwargs(case_op, cache_dir, **kwargs): # Reset captured kwargs mock_compute_with_kwargs.captured_kwargs = {} - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [case_operator], parallel_config={"backend": "threading", "n_jobs": 2}, custom_param="parallel_test", @@ -2198,20 +2232,20 @@ def mock_compute_with_kwargs(case_op, cache_dir, **kwargs): def test_empty_case_operators_all_methods(self): """Test that all execution methods handle empty case operator lists.""" - # Test _run_case_operators - result = evaluate._run_case_operators([], parallel_config={"n_jobs": 1}) + # Test _run_evaluation with parallel config + result = evaluate._run_evaluation([], parallel_config={"n_jobs": 1}) assert result == [] - result = evaluate._run_case_operators( + result = evaluate._run_evaluation( [], parallel_config={"backend": "threading", "n_jobs": 2} ) 
assert result == [] - # Test _run_serial - result = evaluate._run_serial([]) + # Test _run_evaluation in serial mode + result = evaluate._run_evaluation([], parallel_config=None) assert result == [] - # Test _run_parallel + # Test _run_parallel_evaluation with mock.patch( "extremeweatherbench.utils.ParallelTqdm" ) as mock_parallel_class: @@ -2219,17 +2253,21 @@ def test_empty_case_operators_all_methods(self): mock_parallel_class.return_value = mock_parallel_instance mock_parallel_instance.return_value = [] - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [], parallel_config={"backend": "threading", "n_jobs": 2} ) assert result == [] @mock.patch("extremeweatherbench.evaluate.compute_case_operator") - def test_large_case_operator_list_handling(self, mock_compute_case_operator): + @mock.patch("tqdm.auto.tqdm") + def test_large_case_operator_list_handling( + self, mock_tqdm, mock_compute_case_operator + ): """Test handling of large numbers of case operators.""" # Create a large list of case operators num_cases = 100 case_operators = [mock.Mock() for _ in range(num_cases)] + mock_tqdm.return_value = case_operators # Create mock results mock_results = [ @@ -2241,7 +2279,7 @@ def test_large_case_operator_list_handling(self, mock_compute_case_operator): # Test serial execution mock_compute_case_operator.side_effect = mock_results - serial_results = evaluate._run_serial(case_operators) + serial_results = evaluate._run_evaluation(case_operators, parallel_config=None) assert len(serial_results) == num_cases assert mock_compute_case_operator.call_count == num_cases @@ -2257,7 +2295,7 @@ def test_large_case_operator_list_handling(self, mock_compute_case_operator): mock_parallel_class.return_value = mock_parallel_instance mock_parallel_instance.return_value = mock_results - parallel_results = evaluate._run_parallel( + parallel_results = evaluate._run_parallel_evaluation( case_operators, parallel_config={"backend": "threading", "n_jobs": 4} ) diff --git a/tests/test_evaluate_cli.py b/tests/test_evaluate_cli.py index e5797712..ca71d408 100644 --- a/tests/test_evaluate_cli.py +++ b/tests/test_evaluate_cli.py @@ -1,12 +1,9 @@ """Tests for the evaluate_cli interface.""" import pickle -import tempfile import textwrap -from pathlib import Path from unittest import mock -import click.testing import pandas as pd import pytest @@ -23,23 +20,6 @@ def suppress_cli_output(): yield -@pytest.fixture -def runner(): - """Create a Click test runner with output suppression.""" - return click.testing.CliRunner() - - -@pytest.fixture -def temp_config_dir(): - """Create a temporary directory for config files and test outputs. - - This ensures all test files are created in temporary directories and automatically - cleaned up after each test. 
- """ - with tempfile.TemporaryDirectory() as temp_dir: - yield Path(temp_dir) - - @pytest.fixture def sample_config_py(temp_config_dir): """Create a sample Python config file.""" @@ -88,7 +68,7 @@ def test_default_mode_basic( # Mock the ExtremeWeatherBench class and its methods mock_ewb = mock.Mock() mock_ewb.case_operators = [mock.Mock(), mock.Mock()] # Mock 2 case operators - mock_ewb.run.return_value = pd.DataFrame({"test": [1, 2]}) + mock_ewb.run_evaluation.return_value = pd.DataFrame({"test": [1, 2]}) mock_ewb_class.return_value = mock_ewb # Mock loading default cases @@ -100,7 +80,7 @@ def test_default_mode_basic( assert result.exit_code == 0 mock_ewb_class.assert_called_once() - mock_ewb.run.assert_called_once() + mock_ewb.run_evaluation.assert_called_once() @mock.patch( "extremeweatherbench.defaults.get_brightband_evaluation_objects", @@ -119,7 +99,7 @@ def test_default_mode_with_cache_dir( """Test default mode with cache directory.""" mock_ewb = mock.Mock() mock_ewb.case_operators = [] - mock_ewb.run.return_value = pd.DataFrame() + mock_ewb.run_evaluation.return_value = pd.DataFrame() mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = [] @@ -145,7 +125,7 @@ def test_config_file_mode_basic( """Test basic config file mode execution.""" mock_ewb = mock.Mock() mock_ewb.case_operators = [mock.Mock()] - mock_ewb.run.return_value = pd.DataFrame({"test": [1]}) + mock_ewb.run_evaluation.return_value = pd.DataFrame({"test": [1]}) mock_ewb_class.return_value = mock_ewb result = runner.invoke( @@ -219,15 +199,15 @@ def test_parallel_execution( """Test parallel execution mode.""" mock_ewb = mock.Mock() mock_ewb.case_operators = [mock.Mock(), mock.Mock(), mock.Mock()] - mock_ewb.run.return_value = pd.DataFrame({"test": [1, 2, 3]}) + mock_ewb.run_evaluation.return_value = pd.DataFrame({"test": [1, 2, 3]}) mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = [] result = runner.invoke(evaluate_cli.cli_runner, ["--default", "--n-jobs", "3"]) assert result.exit_code == 0 - # Verify ewb.run was called with parallel config - mock_ewb.run.assert_called_once_with( + # Verify ewb.run_evaluation was called with parallel config + mock_ewb.run_evaluation.assert_called_once_with( n_jobs=3, parallel_config=None, ) @@ -244,7 +224,7 @@ def test_serial_execution_default( """Test that serial execution is default (parallel=1).""" mock_ewb = mock.Mock() mock_ewb.case_operators = [] - mock_ewb.run.return_value = pd.DataFrame() + mock_ewb.run_evaluation.return_value = pd.DataFrame() mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = [] @@ -252,7 +232,7 @@ def test_serial_execution_default( assert result.exit_code == 0 # Output suppressed - only check exit code - mock_ewb.run.assert_called_once() + mock_ewb.run_evaluation.assert_called_once() class TestCaseOperatorSaving: @@ -278,7 +258,7 @@ def test_save_case_operators( mock_case_op2 = {"id": 2, "type": "test_case_op"} mock_ewb = mock.Mock() mock_ewb.case_operators = [mock_case_op1, mock_case_op2] - mock_ewb.run.return_value = pd.DataFrame() + mock_ewb.run_evaluation.return_value = pd.DataFrame() mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = [] @@ -317,7 +297,7 @@ def test_save_case_operators_creates_directory( """Test that saving case operators creates parent directories.""" mock_ewb = mock.Mock() mock_ewb.case_operators = [] - mock_ewb.run.return_value = pd.DataFrame() + mock_ewb.run_evaluation.return_value = pd.DataFrame() mock_ewb_class.return_value = mock_ewb 
mock_load_cases.return_value = [] @@ -369,7 +349,7 @@ def test_output_directory_creation( """Test that output directory is created if it doesn't exist.""" mock_ewb = mock.Mock() mock_ewb.case_operators = [] - mock_ewb.run.return_value = pd.DataFrame() + mock_ewb.run_evaluation.return_value = pd.DataFrame() mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = [] @@ -395,7 +375,7 @@ def test_default_output_directory( """Test that default output directory is current working directory.""" mock_ewb = mock.Mock() mock_ewb.case_operators = [] - mock_ewb.run.return_value = pd.DataFrame() + mock_ewb.run_evaluation.return_value = pd.DataFrame() mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = [] @@ -435,7 +415,7 @@ def test_results_saved_to_csv( mock_ewb = mock.Mock() mock_ewb.case_operators = [] - mock_ewb.run.return_value = mock_results + mock_ewb.run_evaluation.return_value = mock_results mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = [] @@ -456,7 +436,7 @@ def test_empty_results_handling(self, mock_ewb_class, mock_load_cases, runner): """Test handling when no results are returned.""" mock_ewb = mock.Mock() mock_ewb.case_operators = [] - mock_ewb.run.return_value = pd.DataFrame() # Empty results + mock_ewb.run_evaluation.return_value = pd.DataFrame() # Empty results mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = [] @@ -469,7 +449,9 @@ def test_empty_results_handling(self, mock_ewb_class, mock_load_cases, runner): class TestHelperFunctions: """Test helper function functionality.""" - @mock.patch("extremeweatherbench.cases.load_ewb_events_yaml_into_case_list") + @mock.patch( + "extremeweatherbench.evaluate_cli.cases.load_ewb_events_yaml_into_case_list" + ) def test_load_default_cases(self, mock_load_yaml): """Test _load_default_cases function.""" mock_cases = [{"id": 1}] diff --git a/tests/test_golden.py b/tests/test_golden.py new file mode 100644 index 00000000..8476f0ba --- /dev/null +++ b/tests/test_golden.py @@ -0,0 +1,282 @@ +"""Tests which use the full end-to-end EWB workflow. + +These tests are likely incompatible with GitHub Actions and are meant to be run on a +VM or other dedicated environment. These are intended to be fairly lightweight marquee +examples of each event type and core metrics. If the computed values deviate from +those expected for a release, the test will be flagged as a failure.""" + + +# Load case data from the default events.yaml + +import pathlib + +import pytest + +from extremeweatherbench import cases, defaults, derived, evaluate, inputs, metrics + + +@pytest.fixture(scope="module") +def reference_data_dir(): + """Path to reference data directory.""" + path = pathlib.Path(__file__).parent / "data" + if not path.exists(): + pytest.skip( + "Reference data not found. Run 'uv run data/generate_cape_reference_data.py' first."
+ ) + return path + + +@pytest.fixture(scope="module") +def golden_tests_event_data(reference_data_dir): + """Load golden tests event data.""" + ref_file = reference_data_dir / "golden_tests.yaml" + if not ref_file.exists(): + pytest.skip(f"Golden tests event data not found: {ref_file}") + + return cases.load_individual_cases_from_yaml(ref_file) + + +@pytest.mark.integration +class TestGoldenTests: + """Golden tests.""" + + def test_heatwaves(self, golden_tests_event_data): + """Heatwave tests.""" + # Define heatwave objects + era5_heatwave_target = inputs.ERA5() + ghcn_heatwave_target = inputs.GHCN() + + heatwave_metrics = [ + metrics.MaximumMeanAbsoluteError( + forecast_variable="surface_air_temperature", + target_variable="surface_air_temperature", + ), + metrics.RootMeanSquaredError( + forecast_variable="surface_air_temperature", + target_variable="surface_air_temperature", + ), + metrics.DurationMeanError( + threshold_criteria=defaults.get_climatology(quantile=0.85) + ), + metrics.MaximumLowestMeanAbsoluteError( + forecast_variable="surface_air_temperature", + target_variable="surface_air_temperature", + ), + ] + + hres_heatwave_forecast = inputs.ZarrForecast( + name="hres_heatwave_forecast", + source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr", + variables=["surface_air_temperature"], + variable_mapping=inputs.HRES_metadata_variable_mapping, + ) + + heatwave_evaluation_objects = [ + inputs.EvaluationObject( + event_type="heat_wave", + metric_list=heatwave_metrics, + target=era5_heatwave_target, + forecast=hres_heatwave_forecast, + ), + inputs.EvaluationObject( + event_type="heat_wave", + metric_list=heatwave_metrics, + target=ghcn_heatwave_target, + forecast=hres_heatwave_forecast, + ), + ] + # Run the evaluation + ewb = evaluate.ExtremeWeatherBench( + case_metadata=golden_tests_event_data, + evaluation_objects=heatwave_evaluation_objects, + ) + + ewb.run( + parallel_config={ + "backend": "loky", + "n_jobs": len(heatwave_evaluation_objects) * len(heatwave_metrics), + }, + ) + + def test_freezes(self, golden_tests_event_data): + """Freeze tests.""" + era5_freeze_target = inputs.ERA5() + ghcn_freeze_target = inputs.GHCN() + hres_freeze_forecast = inputs.ZarrForecast( + name="hres_freeze_forecast", + source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr", + variables=["surface_air_temperature"], + variable_mapping=inputs.HRES_metadata_variable_mapping, + ) + freeze_metrics = [ + metrics.MinimumMeanAbsoluteError( + forecast_variable="surface_air_temperature", + target_variable="surface_air_temperature", + ), + metrics.RootMeanSquaredError( + forecast_variable="surface_air_temperature", + target_variable="surface_air_temperature", + ), + metrics.DurationMeanError( + threshold_criteria=defaults.get_climatology(quantile=0.15) + ), + ] + freeze_evaluation_objects = [ + inputs.EvaluationObject( + event_type="freeze", + metric_list=freeze_metrics, + target=era5_freeze_target, + forecast=hres_freeze_forecast, + ), + inputs.EvaluationObject( + event_type="freeze", + metric_list=freeze_metrics, + target=ghcn_freeze_target, + forecast=hres_freeze_forecast, + ), + ] + # Run the evaluation + ewb = evaluate.ExtremeWeatherBench( + case_metadata=golden_tests_event_data, + evaluation_objects=freeze_evaluation_objects, + ) + ewb.run( + parallel_config={ + "backend": "loky", + "n_jobs": len(freeze_evaluation_objects) * len(freeze_metrics), + }, + ) + + def test_severe_convection(self, golden_tests_event_data): + """Severe convection tests.""" + 
lsr_severe_convection_target = inputs.LSR() + pph_severe_convection_target = inputs.PPH() + hres_severe_convection_forecast = inputs.ZarrForecast( + name="hres_severe_convection_forecast", + source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr", + variables=[derived.CravenBrooksSignificantSevere()], + variable_mapping=inputs.HRES_metadata_variable_mapping, + ) + severe_convection_metrics = [ + metrics.ThresholdMetric( + metrics=[metrics.CriticalSuccessIndex, metrics.FalseAlarmRatio], + forecast_threshold=15000, + target_threshold=0.3, + ), + metrics.EarlySignal(threshold=15000), + ] + severe_convection_evaluation_objects = [ + inputs.EvaluationObject( + event_type="severe_convection", + metric_list=severe_convection_metrics, + target=lsr_severe_convection_target, + forecast=hres_severe_convection_forecast, + ), + inputs.EvaluationObject( + event_type="severe_convection", + metric_list=severe_convection_metrics, + target=pph_severe_convection_target, + forecast=hres_severe_convection_forecast, + ), + ] + # Run the evaluation + ewb = evaluate.ExtremeWeatherBench( + case_metadata=golden_tests_event_data, + evaluation_objects=severe_convection_evaluation_objects, + ) + ewb.run( + parallel_config={ + "backend": "loky", + "n_jobs": len(severe_convection_evaluation_objects) + * len(severe_convection_metrics), + }, + ) + + def test_atmospheric_river(self, golden_tests_event_data): + """Atmospheric river tests.""" + era5_atmospheric_river_target = inputs.ERA5() + hres_atmospheric_river_forecast = inputs.ZarrForecast( + name="hres_atmospheric_river_forecast", + source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr", + variables=[ + derived.AtmosphericRiverVariables( + output_variables=[ + "atmospheric_river_mask", + "integrated_vapor_transport", + "atmospheric_river_land_intersection", + ] + ) + ], + variable_mapping=inputs.HRES_metadata_variable_mapping, + ) + atmospheric_river_metrics = [ + metrics.CriticalSuccessIndex(), + metrics.EarlySignal(), + metrics.SpatialDisplacement(), + ] + atmospheric_river_evaluation_objects = [ + inputs.EvaluationObject( + event_type="atmospheric_river", + metric_list=atmospheric_river_metrics, + target=era5_atmospheric_river_target, + forecast=hres_atmospheric_river_forecast, + ), + ] + # Run the evaluation + ewb = evaluate.ExtremeWeatherBench( + case_metadata=golden_tests_event_data, + evaluation_objects=atmospheric_river_evaluation_objects, + ) + ewb.run( + parallel_config={ + "backend": "loky", + "n_jobs": len(atmospheric_river_evaluation_objects) + * len(atmospheric_river_metrics), + }, + ) + + def test_tropical_cyclone(self, golden_tests_event_data): + """Tropical cyclone tests.""" + ibtracs_tropical_cyclone_target = inputs.IBTrACS() + hres_tropical_cyclone_forecast = inputs.ZarrForecast( + name="hres_tropical_cyclone_forecast", + source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr", + variables=[derived.TropicalCycloneTrackVariables()], + variable_mapping=inputs.HRES_metadata_variable_mapping, + ) + tropical_cyclone_metrics = [ + metrics.LandfallMetric( + metrics=[ + metrics.LandfallIntensityMeanAbsoluteError, + metrics.LandfallTimeMeanError, + metrics.LandfallDisplacement, + ], + approach="next", + forecast_variable="air_pressure_at_mean_sea_level", + target_variable="air_pressure_at_mean_sea_level", + ), + ] + tropical_cyclone_evaluation_objects = [ + inputs.EvaluationObject( + event_type="tropical_cyclone", + metric_list=tropical_cyclone_metrics, + target=ibtracs_tropical_cyclone_target, + 
forecast=hres_tropical_cyclone_forecast, + ), + ] + # Run the evaluation + ewb = evaluate.ExtremeWeatherBench( + case_metadata=golden_tests_event_data, + evaluation_objects=tropical_cyclone_evaluation_objects, + ) + ewb.run( + parallel_config={ + "backend": "loky", + "n_jobs": len(tropical_cyclone_evaluation_objects) + * len(tropical_cyclone_metrics), + }, + ) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/test_init.py b/tests/test_init.py new file mode 100644 index 00000000..00517ae8 --- /dev/null +++ b/tests/test_init.py @@ -0,0 +1,259 @@ +"""Tests for the extremeweatherbench package __init__.py API.""" + +import types + + +class TestModuleImports: + """Test that submodules are importable and are actual modules.""" + + def test_calc_is_module(self): + """Test that calc is an actual module, not a SimpleNamespace.""" + from extremeweatherbench import calc + + assert isinstance(calc, types.ModuleType) + + def test_utils_is_module(self): + """Test that utils is an actual module, not a SimpleNamespace.""" + from extremeweatherbench import utils + + assert isinstance(utils, types.ModuleType) + + def test_metrics_is_module(self): + """Test that metrics is an actual module, not a SimpleNamespace.""" + from extremeweatherbench import metrics + + assert isinstance(metrics, types.ModuleType) + + def test_regions_is_module(self): + """Test that regions is an actual module, not a SimpleNamespace.""" + from extremeweatherbench import regions + + assert isinstance(regions, types.ModuleType) + + def test_derived_is_module(self): + """Test that derived is an actual module, not a SimpleNamespace.""" + from extremeweatherbench import derived + + assert isinstance(derived, types.ModuleType) + + def test_defaults_is_module(self): + """Test that defaults is an actual module, not a SimpleNamespace.""" + from extremeweatherbench import defaults + + assert isinstance(defaults, types.ModuleType) + + def test_cases_is_module(self): + """Test that cases is an actual module, not a SimpleNamespace.""" + from extremeweatherbench import cases + + assert isinstance(cases, types.ModuleType) + + +class TestModuleAccessPatterns: + """Test both import patterns work identically.""" + + def test_ewb_dot_notation_equals_direct_import_calc(self): + """Test ewb.calc is the same object as direct import.""" + import extremeweatherbench as ewb + from extremeweatherbench import calc + + assert ewb.calc is calc + + def test_ewb_dot_notation_equals_direct_import_metrics(self): + """Test ewb.metrics is the same object as direct import.""" + import extremeweatherbench as ewb + from extremeweatherbench import metrics + + assert ewb.metrics is metrics + + def test_ewb_dot_notation_equals_direct_import_utils(self): + """Test ewb.utils is the same object as direct import.""" + import extremeweatherbench as ewb + from extremeweatherbench import utils + + assert ewb.utils is utils + + +class TestModuleLevelConstants: + """Test that module-level constants are accessible.""" + + def test_calc_g0_accessible(self): + """Test that calc.g0 constant is accessible.""" + from extremeweatherbench import calc + + assert hasattr(calc, "g0") + assert calc.g0 == 9.80665 + + def test_calc_epsilon_accessible(self): + """Test that calc.epsilon constant is accessible.""" + from extremeweatherbench import calc + + assert hasattr(calc, "epsilon") + assert isinstance(calc.epsilon, float) + + +class TestPrivateFunctionAccess: + """Test that private functions are accessible for testing purposes.""" + + def 
test_calc_private_functions_accessible(self): + """Test that private functions in calc are accessible.""" + from extremeweatherbench import calc + + assert hasattr(calc, "_is_true_landfall") + assert hasattr(calc, "_detect_landfalls_wrapper") + assert hasattr(calc, "_mask_init_time_boundaries") + assert hasattr(calc, "_interpolate_and_format_landfalls") + + def test_utils_private_functions_accessible(self): + """Test that private functions in utils are accessible.""" + from extremeweatherbench import utils + + assert hasattr(utils, "_create_nan_dataarray") + assert hasattr(utils, "_cache_maybe_densify_helper") + + def test_derived_private_functions_accessible(self): + """Test that private functions in derived are accessible.""" + from extremeweatherbench import derived + + assert hasattr(derived, "_maybe_convert_variable_to_string") + + def test_defaults_private_functions_accessible(self): + """Test that private functions in defaults are accessible.""" + from extremeweatherbench import defaults + + assert hasattr(defaults, "_preprocess_cira_forecast_dataset") + + def test_regions_private_functions_accessible(self): + """Test that private functions in regions are accessible.""" + from extremeweatherbench import regions + + assert hasattr(regions, "_adjust_bounds_to_dataset_convention") + + +class TestPublicFunctionAccess: + """Test that all public functions are accessible via module.""" + + def test_calc_public_functions(self): + """Test public functions in calc are accessible.""" + from extremeweatherbench import calc + + assert hasattr(calc, "find_landfalls") + assert hasattr(calc, "nantrapezoid") + assert hasattr(calc, "dewpoint_from_specific_humidity") + assert hasattr(calc, "find_land_intersection") + assert hasattr(calc, "haversine_distance") + + def test_utils_public_functions(self): + """Test public functions in utils are accessible.""" + from extremeweatherbench import utils + + assert hasattr(utils, "reduce_dataarray") + assert hasattr(utils, "stack_dataarray_from_dims") + assert hasattr(utils, "convert_longitude_to_360") + + +class TestTopLevelImports: + """Test that top-level imports work for commonly used items.""" + + def test_top_level_metric_imports(self): + """Test that metrics can be imported at top level.""" + from extremeweatherbench import ( + MeanAbsoluteError, + MeanError, + MeanSquaredError, + RootMeanSquaredError, + ) + + assert MeanAbsoluteError is not None + assert MeanError is not None + assert MeanSquaredError is not None + assert RootMeanSquaredError is not None + + def test_top_level_input_imports(self): + """Test that input classes can be imported at top level.""" + from extremeweatherbench import ERA5, GHCN, IBTrACS, ZarrForecast + + assert ERA5 is not None + assert GHCN is not None + assert IBTrACS is not None + assert ZarrForecast is not None + + def test_top_level_region_imports(self): + """Test that region classes can be imported at top level.""" + from extremeweatherbench import BoundingBoxRegion, CenteredRegion, Region + + assert Region is not None + assert BoundingBoxRegion is not None + assert CenteredRegion is not None + + def test_top_level_case_imports(self): + """Test that case classes can be imported at top level.""" + from extremeweatherbench import CaseOperator, IndividualCase + + assert IndividualCase is not None + assert CaseOperator is not None + + def test_evaluation_alias(self): + """Test that evaluation alias works.""" + from extremeweatherbench import ExtremeWeatherBench, evaluation + + assert evaluation is ExtremeWeatherBench + + def 
test_load_cases_alias(self): + """Test that load_cases alias works.""" + from extremeweatherbench import ( + load_cases, + load_ewb_events_yaml_into_case_list, + ) + + assert load_cases is load_ewb_events_yaml_into_case_list + + +class TestNamespaceSubmodules: + """Test the convenience namespace submodules.""" + + def test_targets_namespace(self): + """Test targets SimpleNamespace contains expected items.""" + from extremeweatherbench import targets + + assert isinstance(targets, types.SimpleNamespace) + assert hasattr(targets, "ERA5") + assert hasattr(targets, "GHCN") + assert hasattr(targets, "IBTrACS") + assert hasattr(targets, "TargetBase") + + def test_forecasts_namespace(self): + """Test forecasts SimpleNamespace contains expected items.""" + from extremeweatherbench import forecasts + + assert isinstance(forecasts, types.SimpleNamespace) + assert hasattr(forecasts, "ZarrForecast") + assert hasattr(forecasts, "KerchunkForecast") + assert hasattr(forecasts, "ForecastBase") + + +class TestMockPatching: + """Test that mock.patch.object works with module imports.""" + + def test_mock_patch_object_on_calc(self): + """Test that mock.patch.object works on calc module.""" + from unittest import mock + + from extremeweatherbench import calc + + with mock.patch.object(calc, "haversine_distance") as mock_func: + mock_func.return_value = 42.0 + result = calc.haversine_distance([0, 0], [1, 1]) + assert result == 42.0 + mock_func.assert_called_once() + + def test_mock_patch_string_on_calc(self): + """Test that mock.patch with string path works on calc module.""" + from unittest import mock + + with mock.patch("extremeweatherbench.calc.haversine_distance") as mock_func: + mock_func.return_value = 100.0 + from extremeweatherbench import calc + + result = calc.haversine_distance([0, 0], [1, 1]) + assert result == 100.0 diff --git a/tests/test_integration.py b/tests/test_integration.py index 07fba348..43ba0d30 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -600,7 +600,7 @@ def test_full_workflow_single_variable( evaluation_objects=[evaluation_obj], ) - result = ewb.run() + result = ewb.run_evaluation() # Verify results assert isinstance(result, pd.DataFrame) @@ -679,7 +679,7 @@ def test_full_workflow_multiple_variables( evaluation_objects=[evaluation_obj], ) - result = ewb.run() + result = ewb.run_evaluation() # Verify results assert isinstance(result, pd.DataFrame)
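
For reference, a minimal end-to-end sketch of the workflow the new golden tests exercise. It is illustrative only and built solely from names appearing in this diff; note that test_golden.py above still calls ewb.run(...), whereas the CLI and integration hunks switch to run_evaluation, so the sketch uses the latter. The golden_tests.yaml path and the parallel_config keyword being accepted by run_evaluation are assumptions, not confirmed API.

    # Illustrative sketch of the golden-test workflow; names taken from this diff.
    # Assumptions: tests/data/golden_tests.yaml exists, and run_evaluation accepts
    # the same parallel_config argument that the golden tests pass to run().
    import pathlib

    from extremeweatherbench import cases, evaluate, inputs, metrics

    # Load the case metadata the golden tests read from the reference YAML.
    case_metadata = cases.load_individual_cases_from_yaml(
        pathlib.Path("tests/data/golden_tests.yaml")
    )

    # One forecast source, mirroring the HRES Zarr store used throughout the tests.
    forecast = inputs.ZarrForecast(
        name="hres_heatwave_forecast",
        source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr",
        variables=["surface_air_temperature"],
        variable_mapping=inputs.HRES_metadata_variable_mapping,
    )

    # A single evaluation object pairing an ERA5 target with one metric.
    evaluation_objects = [
        inputs.EvaluationObject(
            event_type="heat_wave",
            metric_list=[
                metrics.RootMeanSquaredError(
                    forecast_variable="surface_air_temperature",
                    target_variable="surface_air_temperature",
                )
            ],
            target=inputs.ERA5(),
            forecast=forecast,
        )
    ]

    # Construct the benchmark and run it via the renamed entry point.
    ewb = evaluate.ExtremeWeatherBench(
        case_metadata=case_metadata,
        evaluation_objects=evaluation_objects,
    )
    results = ewb.run_evaluation(
        parallel_config={"backend": "loky", "n_jobs": 1},
    )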