958 changes: 420 additions & 538 deletions notebooks/tutorial/CNN-Model-Training.ipynb

Large diffs are not rendered by default.

12 changes: 9 additions & 3 deletions packages/data/src/pyearthtools/data/download/weatherbench.py

@@ -2,6 +2,7 @@
 import logging
 import textwrap
 import hashlib
+import shutil
 from pathlib import Path
 from typing import Literal

@@ -97,10 +98,14 @@ def _save_variable(darr: xr.DataArray, path: Path):
     zarrpath = path / f"{darr.name}.zarr"
     varname = f"{darr.name} variable"

-    if zarrpath.is_dir():
+    if (canary_file := zarrpath / ".completed").is_file():
         logger.info(f"Skip saving {varname}, folder {zarrpath} already exists.")
         return

+    if zarrpath.is_dir():
+        logger.info(f"Incomplete download of {varname} found, removing folder {zarrpath}.")
+        shutil.rmtree(zarrpath)
+
     compressor = {"compressor": Blosc(cname="zstd", clevel=6)}
     zarr_kwargs = {"encoding": {darr.name: compressor}, "consolidated": False}

@@ -111,6 +116,7 @@ def _save_variable(darr: xr.DataArray, path: Path):
     with TqdmCallback(desc="Writing", disable=disable_bar):
         darr.to_zarr(zarrpath, **zarr_kwargs)

+    canary_file.touch()
     logger.info(f"Saving {varname} finished.")

@@ -152,12 +158,12 @@ def open_local_dataset(path: Path, variables: list[str], level: list[int]) -> xr.Dataset:
     dsets = []
     for varname in variables:
         filepath = path / f"{varname}.zarr"
-        if filepath.is_dir():
+        if (filepath / ".completed").is_file():
             logger.debug(f"Loading {varname} variable from folder {filepath}.")
             dset = xr.open_zarr(filepath, consolidated=False)
         else:
             filelist = [path / f"{varname}_level-{lvl}.zarr" for lvl in level]
-            if any(not fpath.is_dir() for fpath in filelist):
+            if any(not (fpath / ".completed").is_file() for fpath in filelist):
                 raise MissingVariableFile("Missing .zarr folder for some variables")
             logger.debug(f"Loading {varname} variable from folders {[str(p) for p in filelist]}.")
             dset = xr.open_mfdataset(filelist, concat_dim="level", combine="nested", consolidated=False)
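The change above is a sentinel-file (canary) pattern: the `.completed` marker is written only after `to_zarr` finishes, so a crash mid-write leaves a folder without the marker, which the next run detects, removes, and redoes. A minimal standalone sketch of the pattern; the helper name and the `write_fn` callback are illustrative, not part of pyearthtools:

```python
import shutil
from pathlib import Path
from typing import Callable

SENTINEL = ".completed"


def write_with_sentinel(target: Path, write_fn: Callable[[Path], None]) -> None:
    """Run write_fn(target) at most once, marking success with a sentinel file."""
    sentinel = target / SENTINEL
    if sentinel.is_file():
        return  # a previous run completed; nothing to do
    if target.is_dir():
        shutil.rmtree(target)  # a folder without the sentinel is an interrupted write
    write_fn(target)
    sentinel.touch()  # created last, so its presence implies a complete write
```

Writing the marker strictly after the data makes the check safe even if the process dies between the two steps, at the cost of re-downloading a partially written variable.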
2 changes: 1 addition & 1 deletion packages/data/src/pyearthtools/data/indexes/_indexes.py

@@ -66,7 +66,7 @@

 LOG = logging.getLogger("pyearthtools.data")

-
+
 class Index(CallRedirectMixin, CatalogMixin, metaclass=ABCMeta):
     """
     Base Level Index to define the structure
@@ -83,7 +83,7 @@ def get_config(mf: bool = False):
     return xr.open_mfdataset(
         filter_files(location),
         decode_timedelta=True,  # TODO: should we raise a warning? It seems to be required for almost all our data.
-        compat='override',
+        compat="override",
         **get_config(True),
     )
1 change: 1 addition & 0 deletions packages/data/tests/indexes/test_indexes.py

@@ -7,6 +7,7 @@

 from pyearthtools.data.exceptions import DataNotFoundError

+
 def test_Index(monkeypatch):

     monkeypatch.setattr("pyearthtools.data.indexes.Index.__abstractmethods__", set())
@@ -46,15 +46,15 @@ def unjoin(self, sample: Any) -> tuple:


 class LatLonInterpolate(Joiner):
-    '''
+    """
     Makes additional assumptions about how interpolation should work and
     how the data is structured. In this case, interpolation is primarily
-    expected to occur according to latitude and longitude, performing 
+    expected to occur according to latitude and longitude, performing
     no broadcasting, and iterating over other dimensions instead.

     It is assumed the dimensions 'latitude', 'longitude', 'time', and 'level' will
     be present. 'lat' or 'lon' may also be used for convenience.
-    '''
+    """

     _override_interface = "Serial"

@@ -78,15 +78,15 @@ def __init__(
         self._merge_kwargs = merge_kwargs

     def raise_if_dimensions_wrong(self, dataset):
-        '''
+        """
         Raise exceptions if the supplied dataset does not meet requirements
-        '''
+        """

-        if not hasattr(self, 'required_dims'):
-            if 'lat' in dataset.coords:
-                self.required_dims = ['lat', 'lon']
+        if not hasattr(self, "required_dims"):
+            if "lat" in dataset.coords:
+                self.required_dims = ["lat", "lon"]
             else:
-                self.required_dims = ['latitude', 'longitude']
+                self.required_dims = ["latitude", "longitude"]

         present_in_coords = [d in dataset.coords for d in self.required_dims]
         if not all(present_in_coords):
@@ -100,12 +100,12 @@ def raise_if_dimensions_wrong(self, dataset):
         # raise ValueError(f"Cannot perform a GeoMergePancake on the data variables {data_var} without {self.required_dims} as a dimension")

     def maybe_interp(self, ds):
-        '''
+        """
        This method will only interpolate the datasets if the latitudes and longitudes don't already
        match. This means, for example, you can't use it to interpolate between time steps
        or vertical levels alone. The primary purpose here is lat/lon interpolation, not general
        model interpolation or arbitrarily-dimensioned data interpolation.
-        '''
+        """

        ds_coords_ok = [ds[coord].equals(self.reference_dataset[coord]) for coord in self.required_dims]

@@ -115,7 +115,6 @@

         return ds

-
     def _join_two_datasets(self, sample_a: xr.Dataset, sample_b: xr.Dataset) -> xr.Dataset:
         """
         Used to reduce a sequence of joinable items. Only called by the public interface join method.
@@ -144,7 +143,7 @@ def join(self, sample: tuple[Union[xr.Dataset, xr.DataArray], ...]) -> xr.Dataset:
         return merged

     def unjoin(self, sample: Any) -> tuple:
-        raise NotImplementedError("Not Implemented") 
+        raise NotImplementedError("Not Implemented")


 class GeospatialTimeSeriesMerge(Joiner):
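The `maybe_interp` logic above boils down to: regrid only when the horizontal coordinates differ from the reference dataset, and regrid only along lat/lon so other dimensions such as time and level are carried along rather than broadcast. A rough sketch of that check with plain xarray; the standalone function is illustrative and assumes the dimension names have already been resolved as in `raise_if_dimensions_wrong`:

```python
import xarray as xr


def maybe_interp(ds: xr.Dataset, reference: xr.Dataset, dims=("latitude", "longitude")) -> xr.Dataset:
    # Skip interpolation entirely when every horizontal coordinate
    # already matches the reference grid exactly.
    if all(ds[d].equals(reference[d]) for d in dims):
        return ds
    # Interpolate onto the reference lat/lon values only; xarray leaves
    # time, level, and any other dimensions untouched.
    return ds.interp({d: reference[d] for d in dims})
```

Comparing coordinates with `.equals` before interpolating avoids a silent lossy resample when the grids already match.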
@@ -157,10 +157,10 @@ class Deviation(xarrayNormalisation):
     """Deviation Normalisation"""

     def __init__(
-        self, 
-        mean: FILE | xr.Dataset | xr.DataArray | float, 
+        self,
+        mean: FILE | xr.Dataset | xr.DataArray | float,
         deviation: FILE | xr.Dataset | xr.DataArray | float,
-        debug=False
+        debug=False,
     ):
         """
         Each argument can take a Dataset, DataArray, float or file object.
@@ -173,7 +173,9 @@ def __init__(
         self.record_initialisation()

         if debug:
-            import pdb; pdb.set_trace()
+            import pdb
+
+            pdb.set_trace()

         if isinstance(mean, xr.Dataset):
             self.mean = mean
@@ -187,7 +189,7 @@ def __init__(
         if isinstance(deviation, xr.Dataset):
             self.deviation = deviation
         elif isinstance(deviation, xr.DataArray):
-            self.deviation = deviation 
+            self.deviation = deviation
         elif isinstance(deviation, float):
             self.deviation = deviation
         else:
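`Deviation` applies standard-score normalisation, `(x - mean) / deviation`; because mean and deviation may be scalars or xarray objects, xarray's name-based broadcasting handles both cases with the same expression. A minimal sketch of the forward and inverse transforms under that assumption (the function names are hypothetical, not the class's API):

```python
import xarray as xr


def normalise(ds: xr.Dataset, mean, deviation) -> xr.Dataset:
    # mean/deviation may be floats or xarray objects; xarray broadcasts
    # them against ds by dimension and variable name.
    return (ds - mean) / deviation


def denormalise(ds: xr.Dataset, mean, deviation) -> xr.Dataset:
    # Exact inverse of normalise, mapping model output back to physical units.
    return ds * deviation + mean
```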