4 changes: 3 additions & 1 deletion src/nwp_consumer/internal/entities/coordinates.py
@@ -346,7 +346,8 @@ def from_pandas(

@classmethod
def from_xarray(
cls, xarray_obj: xr.DataArray | xr.Dataset,
cls,
xarray_obj: xr.DataArray | xr.Dataset,
) -> ResultE["NWPDimensionCoordinateMap"]:
"""Create a new NWPDimensionCoordinateMap from an XArray DataArray or Dataset object."""
return cls.from_pandas(xarray_obj.coords.indexes) # type: ignore
@@ -560,6 +561,7 @@ def as_zeroed_dataarray(self, name: str, chunks: dict[str, int]) -> xr.DataArray
dummy_values = dask.array.zeros( # type: ignore
shape=list(self.shapemap.values()),
chunks=tuple([chunks[k] for k in self.shapemap]),
dtype=np.float32,
)
attrs: dict[str, str] = {
"produced_by": "".join(
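The explicit dtype matters here because dask.array.zeros defaults to float64, so the zero-initialised store template would otherwise reserve twice the space of the float32 data it is meant to hold. A minimal sketch of the effect (illustrative shapes only, not taken from the repository):

    import dask.array
    import numpy as np

    # Hypothetical (init_time, step, variable, latitude, longitude) sizes.
    shape = (1, 48, 10, 721, 1440)
    chunks = (1, 1, 1, 721, 1440)

    f64 = dask.array.zeros(shape, chunks=chunks)                    # default dtype: float64
    f32 = dask.array.zeros(shape, chunks=chunks, dtype=np.float32)  # what the change requests

    print(f64.dtype, f"{f64.nbytes / 1e9:.1f} GB")  # float64, ~4.0 GB
    print(f32.dtype, f"{f32.nbytes / 1e9:.1f} GB")  # float32, ~2.0 GB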
7 changes: 3 additions & 4 deletions src/nwp_consumer/internal/entities/tensorstore.py
@@ -468,7 +468,7 @@ def delete_store(self) -> ResultE[None]:
except Exception as e:
return Failure(
OSError(
f"Unable to delete store at path '{self.path}'. " f"Error context: {e}",
f"Unable to delete store at path '{self.path}'. Error context: {e}",
),
)
log.info("Deleted zarr store at '%s'", self.path)
@@ -571,8 +571,7 @@ def missing_times(self) -> ResultE[list[dt.datetime]]:
missing_times.append(pd.Timestamp(it).to_pydatetime().replace(tzinfo=dt.UTC))
if len(missing_times) > 0:
log.debug(
f"NaNs in init times '{missing_times}' suggest they are missing, "
f"will redownload",
f"NaNs in init times '{missing_times}' suggest they are missing, will redownload",
)
return Success(missing_times)

@@ -594,7 +593,7 @@ def _create_zarrstore_s3(s3_folder: str, filename: str) -> ResultE[tuple[Mutable
if not s3_folder.startswith("s3://"):
return Failure(
ValueError(
"S3 folder path must start with 's3://'. " f"Got: {s3_folder}",
f"S3 folder path must start with 's3://'. Got: {s3_folder}",
),
)
log.debug("Attempting AWS connection using credential discovery")
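All three message fixes in this file remove leftover implicit concatenation of adjacent f-strings, which is easy to mis-edit; a single f-string says the same thing. A standalone illustration (placeholder path and error, not repository values):

    path = "/tmp/store.zarr"
    err = FileNotFoundError("no such store")

    # Adjacent string literals concatenate implicitly (the old form):
    msg_old = f"Unable to delete store at path '{path}'. " f"Error context: {err}"
    # The single f-string reads the same and is harder to break when edited:
    msg_new = f"Unable to delete store at path '{path}'. Error context: {err}"
    assert msg_old == msg_new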
26 changes: 25 additions & 1 deletion src/nwp_consumer/internal/entities/test_coordinates.py
@@ -166,7 +166,8 @@ class TestCase:
self.assertListEqual(list(result.keys()), list(t.expected_indexes.keys()))
for key in result:
self.assertListEqual(
result[key].values.tolist(), t.expected_indexes[key].values.tolist(),
result[key].values.tolist(),
t.expected_indexes[key].values.tolist(),
)

def test_from_pandas(self) -> None:
@@ -322,6 +323,29 @@ class TestCase:
self.assertListEqual(coords.latitude, t.expected_latitude) # type: ignore
self.assertListEqual(coords.longitude, t.expected_longitude) # type: ignore

def test_as_zeroed_dataarray_float32_dtype(self) -> None:
"""Test that as_zeroed_dataarray creates data with float32 dtype."""
coords = NWPDimensionCoordinateMap(
init_time=[dt.datetime(2021, 1, 1, 0, tzinfo=dt.UTC)],
step=[0, 3, 6],
variable=[Parameter.TEMPERATURE_SL],
latitude=[50.0, 51.0],
longitude=[0.0, 1.0],
)

chunks = {"init_time": 1, "step": 1, "variable": 1, "latitude": 2, "longitude": 2}

da = coords.as_zeroed_dataarray(name="test_model", chunks=chunks)

# Check that the data type is float32
self.assertEqual(da.dtype, "float32")

# Also check the dtype of the underlying dask array; dask arrays expose a
# NumPy dtype, so comparing against the string "float32" needs no extra import
self.assertEqual(da.data.dtype, "float32")


if __name__ == "__main__":
unittest.main()
3 changes: 2 additions & 1 deletion src/nwp_consumer/internal/ports/repositories.py
@@ -50,7 +50,8 @@ def authenticate(cls) -> ResultE["RawRepository"]:

@abc.abstractmethod
def fetch_init_data(
self, it: dt.datetime,
self,
it: dt.datetime,
) -> Iterator[Callable[..., ResultE[list[xr.DataArray]]]]:
"""Fetch raw data files for an init time as xarray datasets.

@@ -154,7 +154,8 @@ def model() -> entities.ModelMetadata:

@override
def fetch_init_data(
self, it: dt.datetime,
self,
it: dt.datetime,
) -> Iterator[Callable[..., ResultE[list[xr.DataArray]]]]:
parameter_stubs: list[str] = [
"Total_Downward_Surface_SW_Flux",
@@ -364,7 +365,7 @@ def _convert_global(path: pathlib.Path) -> ResultE[list[xr.DataArray]]:
# * and so may not produce a contiguous subset of the expected coordinates.
processed_das.extend(
[
da.where(cond=da.coords["variable"] == v, drop=True)
da.where(cond=da.coords["variable"] == v, drop=True).astype(np.float32)
for v in da.coords["variable"].values
],
)
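For context, this pattern splits each raw DataArray into single-variable slices and now also casts each slice to float32 before it is appended to the store. A small self-contained sketch of the same idiom (toy data and made-up variable names, not the repository's real parameters):

    import numpy as np
    import xarray as xr

    # Toy DataArray with a two-entry "variable" coordinate; random data is float64.
    da = xr.DataArray(
        np.random.rand(2, 3, 4),
        dims=["variable", "latitude", "longitude"],
        coords={"variable": ["temperature_sl", "downward_shortwave_radiation_flux_gl"]},
    )

    slices = [
        da.where(cond=da.coords["variable"] == v, drop=True).astype(np.float32)
        for v in da.coords["variable"].values
    ]

    for s in slices:
        print(s.coords["variable"].values, s.dtype)  # one variable per slice, float32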
@@ -473,7 +474,7 @@ def _convert_ukv(path: pathlib.Path) -> ResultE[list[xr.DataArray]]:
# * and so may not produce a contiguous subset of the expected coordinates.
processed_das.extend(
[
da.where(cond=da.coords["variable"] == v, drop=True)
da.where(cond=da.coords["variable"] == v, drop=True).astype(np.float32)
for v in da.coords["variable"].values
],
)
@@ -18,6 +18,7 @@
from typing import override

import cfgrib
import numpy as np
import xarray as xr
from ecmwfapi import ECMWFService
from joblib import delayed
@@ -385,7 +386,7 @@ def _convert(path: pathlib.Path) -> ResultE[list[xr.DataArray]]:
# * and so may not produce a contiguous subset of the expected coordinates.
processed_das.extend(
[
da.where(cond=da.coords["variable"] == v, drop=True)
da.where(cond=da.coords["variable"] == v, drop=True).astype(np.float32)
for v in da.coords["variable"].values
],
)
@@ -39,6 +39,7 @@
from typing import override

import cfgrib
import numpy as np
import s3fs
import xarray as xr
from joblib import delayed
@@ -109,7 +110,8 @@ def model() -> entities.ModelMetadata:

@override
def fetch_init_data(
self, it: dt.datetime,
self,
it: dt.datetime,
) -> Iterator[Callable[..., ResultE[list[xr.DataArray]]]]:
# List relevant files in the S3 bucket
try:
@@ -318,7 +320,10 @@ def _convert(path: pathlib.Path) -> ResultE[list[xr.DataArray]]:
# * Each raw file does not contain a full set of parameters
# * and so may not produce a contiguous subset of the expected coordinates.
processed_das.extend(
[da.where(cond=da["variable"] == v, drop=True) for v in da["variable"].values],
[
da.where(cond=da["variable"] == v, drop=True).astype(np.float32)
for v in da["variable"].values
],
)

if len(processed_das) == 0:
@@ -112,6 +112,7 @@
from collections.abc import Callable, Iterator
from typing import TYPE_CHECKING, ClassVar, override

import numpy as np
import xarray as xr
from joblib import delayed
from returns.result import Failure, ResultE, Success
@@ -205,8 +206,8 @@ def fetch_init_data(
it: dt.datetime,
) -> Iterator[Callable[..., ResultE[list[xr.DataArray]]]]:
req: urllib.request.Request = urllib.request.Request( # noqa: S310
url=self.request_url + \
f"?detail=MINIMAL&runfilter={it:%Y%m%d%H}&dataSpec={self.dataspec}",
url=self.request_url
+ f"?detail=MINIMAL&runfilter={it:%Y%m%d%H}&dataSpec={self.dataspec}",
headers=self._headers,
method="GET",
)
@@ -244,7 +245,7 @@ def fetch_init_data(
for filedata in data["orderDetails"]["files"]:
if "fileId" in filedata and "+" not in filedata["fileId"]:
urls.append(
f"{self.request_url}/{filedata["fileId"]}/data?dataSpec={self.dataspec}",
f"{self.request_url}/{filedata['fileId']}/data?dataSpec={self.dataspec}",
)

log.debug(
@@ -286,7 +287,7 @@ def _download(self, url: str) -> ResultE[pathlib.Path]:
f"~/.local/cache/nwp/{self.repository().name}/{self.model().name}/raw",
),
)
/ f"{url.split("/")[-2]}.grib"
/ f"{url.split('/')[-2]}.grib"
).expanduser()

# Only download the file if not already present
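Both quoting changes in this file (the fileId lookup and the url.split call) address the same issue: reusing the enclosing quote character inside an f-string expression only parses on Python 3.12+ (PEP 701), so switching the inner quotes keeps the module importable on older interpreters. A standalone illustration with a made-up URL, not a DataHub endpoint:

    url = "https://example.invalid/orders/agl_temperature/data"

    # Works on every supported Python 3 version:
    stem = f"{url.split('/')[-2]}.grib"

    # f"{url.split("/")[-2]}.grib" -- the same double quotes inside and out --
    # is a SyntaxError before Python 3.12.
    print(stem)  # agl_temperature.grib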
@@ -309,7 +310,7 @@ def _download(self, url: str) -> ResultE[pathlib.Path]:
except Exception as e:
return Failure(
OSError(
"Unable to request file data from MetOffice DataHub at " f"'{url}': {e}",
f"Unable to request file data from MetOffice DataHub at '{url}': {e}",
),
)

@@ -405,7 +406,7 @@ def _convert_global(path: pathlib.Path) -> ResultE[list[xr.DataArray]]:
ValueError(f"Error processing DataArray for path '{path}'. Error context: {e}"),
)

return Success([da])
return Success([da.astype(np.float32)])

@staticmethod
def _convert_ukv(path: pathlib.Path) -> ResultE[list[xr.DataArray]]:
@@ -516,4 +517,4 @@ def _convert_ukv(path: pathlib.Path) -> ResultE[list[xr.DataArray]]:
),
)

return Success([da])
return Success([da.astype(np.float32)])
@@ -35,6 +35,7 @@
from typing import override

import cfgrib
import numpy as np
import s3fs
import xarray as xr
from joblib import delayed
@@ -299,7 +300,7 @@ def _convert(path: pathlib.Path) -> ResultE[list[xr.DataArray]]:
),
)

return Success([da])
return Success([da.astype(np.float32)])

@staticmethod
def _wanted_file(filename: str, it: dt.datetime, steps: list[int]) -> bool:
@@ -72,6 +72,14 @@ def test__convert_global(self) -> None:
self.assertIsInstance(region_result, Failure, msg=f"{region_result}")
else:
self.assertIsInstance(region_result, Success, msg=f"{region_result}")
# Test that the data is stored as float32
das = result.unwrap()
for da in das:
self.assertEqual(
da.dtype,
"float32",
msg=f"DataArray should have float32 dtype, got {da.dtype}",
)

@patch.dict(os.environ, {"MODEL": "mo-um-ukv"}, clear=True)
def test__convert_ukv(self) -> None:
@@ -117,6 +125,14 @@ def test__convert_ukv(self) -> None:
self.assertIsInstance(region_result, Failure, msg=f"{region_result}")
else:
self.assertIsInstance(region_result, Success, msg=f"{region_result}")
# Test that the data is stored as float32
das = result.unwrap()
for da in das:
self.assertEqual(
da.dtype,
"float32",
msg=f"DataArray should have float32 dtype, got {da.dtype}",
)


if __name__ == "__main__":
5 changes: 3 additions & 2 deletions src/nwp_consumer/internal/services/_dummy_adaptors.py
@@ -53,14 +53,15 @@ def model() -> entities.ModelMetadata:

@override
def fetch_init_data(
self, it: dt.datetime,
self,
it: dt.datetime,
) -> Iterator[Callable[..., ResultE[list[xr.DataArray]]]]:
def gen_dataset(step: int, variable: str) -> ResultE[list[xr.DataArray]]:
"""Define a generator that provides one variable at one step."""
da = xr.DataArray(
name=self.model().name,
dims=["init_time", "step", "variable", "latitude", "longitude"],
data=np.random.rand(1, 1, 1, 721, 1440),
data=np.random.rand(1, 1, 1, 721, 1440).astype(np.float32),
coords=self.model().expected_coordinates.to_pandas()
| {
"init_time": [np.datetime64(it.replace(tzinfo=None), "ns")],
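The cast in the dummy adaptor is needed because NumPy's legacy random functions return float64 by default, so without it the dummy data would not match the float32 store. A two-line standalone check (not repository code):

    import numpy as np

    print(np.random.rand(2, 2).dtype)                     # float64
    print(np.random.rand(2, 2).astype(np.float32).dtype)  # float32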
2 changes: 1 addition & 1 deletion src/nwp_consumer/internal/services/consumer_service.py
@@ -92,7 +92,7 @@ def _fold_dataarrays_generator(
return Failure(
OSError(
"Error threshold exceeded: "
f"{len(failures)/len(results)} errors (>6%) occurred during processing.",
f"{len(failures) / len(results)} errors (>6%) occurred during processing.",
),
)
else: