From 97b5cb684f956e034ed9697da871d38615ab34b7 Mon Sep 17 00:00:00 2001
From: ram-from-tvl
Date: Fri, 5 Sep 2025 11:26:50 +0530
Subject: [PATCH] Convert data storage from float64 to float32

1. Updating as_zeroed_dataarray() in coordinates.py to create float32 dask arrays
2. Updating dummy_adaptors.py to generate float32 test data
3. Converting all repository implementations to save data as float32:
   - ceda.py: Both global and UKV conversion methods
   - mo_datahub.py: Both global and UKV conversion methods
   - noaa_s3.py: Main conversion method
   - ecmwf_realtime.py: Main conversion method
   - ecmwf_mars.py: Main conversion method
4. Adding test cases to verify float32 dtype is maintained
5. Ensuring all code passes linting and formatting checks
---
 .../internal/entities/coordinates.py | 4 ++-
 .../internal/entities/tensorstore.py | 7 +++--
 .../internal/entities/test_coordinates.py | 26 ++++++++++++++++++-
 .../internal/ports/repositories.py | 3 ++-
 .../repositories/raw_repositories/ceda.py | 7 ++---
 .../raw_repositories/ecmwf_mars.py | 3 ++-
 .../raw_repositories/ecmwf_realtime.py | 9 +++++--
 .../raw_repositories/mo_datahub.py | 15 ++++++-----
 .../repositories/raw_repositories/noaa_s3.py | 3 ++-
 .../raw_repositories/test_ceda.py | 16 ++++++++++++
 .../internal/services/_dummy_adaptors.py | 5 ++--
 .../internal/services/consumer_service.py | 2 +-
 12 files changed, 76 insertions(+), 24 deletions(-)
diff --git a/src/nwp_consumer/internal/entities/coordinates.py b/src/nwp_consumer/internal/entities/coordinates.py
index f782fa66..6321a843 100644
--- a/src/nwp_consumer/internal/entities/coordinates.py
+++ b/src/nwp_consumer/internal/entities/coordinates.py
@@ -346,7 +346,8 @@ def from_pandas(
 
     @classmethod
     def from_xarray(
-        cls, xarray_obj: xr.DataArray | xr.Dataset,
+        cls,
+        xarray_obj: xr.DataArray | xr.Dataset,
     ) -> ResultE["NWPDimensionCoordinateMap"]:
         """Create a new NWPDimensionCoordinateMap from an XArray DataArray or Dataset object."""
         return cls.from_pandas(xarray_obj.coords.indexes)  # type: ignore
@@ -560,6 +561,7 @@ def as_zeroed_dataarray(self, name: str, chunks: dict[str, int]) -> xr.DataArray
         dummy_values = dask.array.zeros(  # type: ignore
             shape=list(self.shapemap.values()),
             chunks=tuple([chunks[k] for k in self.shapemap]),
+            dtype=np.float32,
         )
         attrs: dict[str, str] = {
             "produced_by": "".join(
diff --git a/src/nwp_consumer/internal/entities/tensorstore.py b/src/nwp_consumer/internal/entities/tensorstore.py
index a82e53fc..ff959002 100644
--- a/src/nwp_consumer/internal/entities/tensorstore.py
+++ b/src/nwp_consumer/internal/entities/tensorstore.py
@@ -468,7 +468,7 @@ def delete_store(self) -> ResultE[None]:
         except Exception as e:
             return Failure(
                 OSError(
-                    f"Unable to delete store at path '{self.path}'. " f"Error context: {e}",
+                    f"Unable to delete store at path '{self.path}'. Error context: {e}",
                 ),
             )
         log.info("Deleted zarr store at '%s'", self.path)
@@ -571,8 +571,7 @@ def missing_times(self) -> ResultE[list[dt.datetime]]:
                 missing_times.append(pd.Timestamp(it).to_pydatetime().replace(tzinfo=dt.UTC))
         if len(missing_times) > 0:
             log.debug(
-                f"NaNs in init times '{missing_times}' suggest they are missing, "
-                f"will redownload",
+                f"NaNs in init times '{missing_times}' suggest they are missing, will redownload",
             )
         return Success(missing_times)
 
@@ -594,7 +593,7 @@ def _create_zarrstore_s3(s3_folder: str, filename: str) -> ResultE[tuple[Mutable
     if not s3_folder.startswith("s3://"):
         return Failure(
             ValueError(
-                "S3 folder path must start with 's3://'. " f"Got: {s3_folder}",
+                f"S3 folder path must start with 's3://'. Got: {s3_folder}",
             ),
         )
     log.debug("Attempting AWS connection using credential discovery")
diff --git a/src/nwp_consumer/internal/entities/test_coordinates.py b/src/nwp_consumer/internal/entities/test_coordinates.py
index 15c99ac4..47a35abc 100644
--- a/src/nwp_consumer/internal/entities/test_coordinates.py
+++ b/src/nwp_consumer/internal/entities/test_coordinates.py
@@ -166,7 +166,8 @@ class TestCase:
         self.assertListEqual(list(result.keys()), list(t.expected_indexes.keys()))
         for key in result:
             self.assertListEqual(
-                result[key].values.tolist(), t.expected_indexes[key].values.tolist(),
+                result[key].values.tolist(),
+                t.expected_indexes[key].values.tolist(),
             )
 
     def test_from_pandas(self) -> None:
@@ -322,6 +323,29 @@ class TestCase:
         self.assertListEqual(coords.latitude, t.expected_latitude)  # type: ignore
         self.assertListEqual(coords.longitude, t.expected_longitude)  # type: ignore
 
+    def test_as_zeroed_dataarray_float32_dtype(self) -> None:
+        """Test that as_zeroed_dataarray creates data with float32 dtype."""
+        coords = NWPDimensionCoordinateMap(
+            init_time=[dt.datetime(2021, 1, 1, 0, tzinfo=dt.UTC)],
+            step=[0, 3, 6],
+            variable=[Parameter.TEMPERATURE_SL],
+            latitude=[50.0, 51.0],
+            longitude=[0.0, 1.0],
+        )
+
+        chunks = {"init_time": 1, "step": 1, "variable": 1, "latitude": 2, "longitude": 2}
+
+        da = coords.as_zeroed_dataarray(name="test_model", chunks=chunks)
+
+        # Check that the data type is float32
+        self.assertEqual(da.dtype, "float32")
+
+        # Also check the underlying dask array dtype
+        import numpy as np
+
+        if hasattr(da.data, "dtype"):
+            self.assertEqual(da.data.dtype, np.float32)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/src/nwp_consumer/internal/ports/repositories.py b/src/nwp_consumer/internal/ports/repositories.py
index ae53821b..777751dc 100644
--- a/src/nwp_consumer/internal/ports/repositories.py
+++ b/src/nwp_consumer/internal/ports/repositories.py
@@ -50,7 +50,8 @@ def authenticate(cls) -> ResultE["RawRepository"]:
 
     @abc.abstractmethod
     def fetch_init_data(
-        self, it: dt.datetime,
+        self,
+        it: dt.datetime,
     ) -> Iterator[Callable[..., ResultE[list[xr.DataArray]]]]:
         """Fetch raw data files for an init time as xarray datasets.
 
diff --git a/src/nwp_consumer/internal/repositories/raw_repositories/ceda.py b/src/nwp_consumer/internal/repositories/raw_repositories/ceda.py
index ab519814..13ff27d3 100644
--- a/src/nwp_consumer/internal/repositories/raw_repositories/ceda.py
+++ b/src/nwp_consumer/internal/repositories/raw_repositories/ceda.py
@@ -155,7 +155,8 @@ def model() -> entities.ModelMetadata:
 
     @override
     def fetch_init_data(
-        self, it: dt.datetime,
+        self,
+        it: dt.datetime,
     ) -> Iterator[Callable[..., ResultE[list[xr.DataArray]]]]:
         parameter_stubs: list[str] = [
             "Total_Downward_Surface_SW_Flux",
@@ -365,7 +366,7 @@ def _convert_global(path: pathlib.Path) -> ResultE[list[xr.DataArray]]:
         # * and so may not produce a contiguous subset of the expected coordinates.
         processed_das.extend(
             [
-                da.where(cond=da.coords["variable"] == v, drop=True)
+                da.where(cond=da.coords["variable"] == v, drop=True).astype(np.float32)
                 for v in da.coords["variable"].values
             ],
         )
@@ -474,7 +475,7 @@ def _convert_ukv(path: pathlib.Path) -> ResultE[list[xr.DataArray]]:
         # * and so may not produce a contiguous subset of the expected coordinates.
         processed_das.extend(
             [
-                da.where(cond=da.coords["variable"] == v, drop=True)
+                da.where(cond=da.coords["variable"] == v, drop=True).astype(np.float32)
                 for v in da.coords["variable"].values
             ],
         )
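The .astype(np.float32) call introduced here, and repeated in the repositories below, casts each
per-variable slice without touching its dims or coordinates. A toy illustration with a made-up
DataArray (not CEDA output):

    import numpy as np
    import xarray as xr

    da = xr.DataArray(
        np.zeros((2, 3), dtype=np.float64),
        dims=["latitude", "longitude"],
        coords={"latitude": [50.0, 51.0], "longitude": [0.0, 0.5, 1.0]},
    )
    converted = da.astype(np.float32)  # values cast; dims and coords unchanged
    assert converted.dtype == np.float32
    assert converted.data.nbytes == da.data.nbytes // 2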
diff --git a/src/nwp_consumer/internal/repositories/raw_repositories/ecmwf_mars.py b/src/nwp_consumer/internal/repositories/raw_repositories/ecmwf_mars.py
index e0a7e505..79f0b3e7 100644
--- a/src/nwp_consumer/internal/repositories/raw_repositories/ecmwf_mars.py
+++ b/src/nwp_consumer/internal/repositories/raw_repositories/ecmwf_mars.py
@@ -18,6 +18,7 @@
 from typing import override
 
 import cfgrib
+import numpy as np
 import xarray as xr
 from ecmwfapi import ECMWFService
 from joblib import delayed
@@ -386,7 +387,7 @@ def _convert(path: pathlib.Path) -> ResultE[list[xr.DataArray]]:
         # * and so may not produce a contiguous subset of the expected coordinates.
         processed_das.extend(
             [
-                da.where(cond=da.coords["variable"] == v, drop=True)
+                da.where(cond=da.coords["variable"] == v, drop=True).astype(np.float32)
                 for v in da.coords["variable"].values
             ],
         )
diff --git a/src/nwp_consumer/internal/repositories/raw_repositories/ecmwf_realtime.py b/src/nwp_consumer/internal/repositories/raw_repositories/ecmwf_realtime.py
index 8161ef37..587e2230 100644
--- a/src/nwp_consumer/internal/repositories/raw_repositories/ecmwf_realtime.py
+++ b/src/nwp_consumer/internal/repositories/raw_repositories/ecmwf_realtime.py
@@ -39,6 +39,7 @@
 from typing import override
 
 import cfgrib
+import numpy as np
 import s3fs
 import xarray as xr
 from joblib import delayed
@@ -108,7 +109,8 @@ def model() -> entities.ModelMetadata:
 
     @override
     def fetch_init_data(
-        self, it: dt.datetime,
+        self,
+        it: dt.datetime,
     ) -> Iterator[Callable[..., ResultE[list[xr.DataArray]]]]:
         # List relevant files in the S3 bucket
         try:
@@ -317,7 +319,10 @@ def _convert(path: pathlib.Path) -> ResultE[list[xr.DataArray]]:
         # * Each raw file does not contain a full set of parameters
         # * and so may not produce a contiguous subset of the expected coordinates.
         processed_das.extend(
-            [da.where(cond=da["variable"] == v, drop=True) for v in da["variable"].values],
+            [
+                da.where(cond=da["variable"] == v, drop=True).astype(np.float32)
+                for v in da["variable"].values
+            ],
         )
 
         if len(processed_das) == 0:
diff --git a/src/nwp_consumer/internal/repositories/raw_repositories/mo_datahub.py b/src/nwp_consumer/internal/repositories/raw_repositories/mo_datahub.py
index 16f0e7c8..c576a209 100644
--- a/src/nwp_consumer/internal/repositories/raw_repositories/mo_datahub.py
+++ b/src/nwp_consumer/internal/repositories/raw_repositories/mo_datahub.py
@@ -112,6 +112,7 @@
 from collections.abc import Callable, Iterator
 from typing import TYPE_CHECKING, ClassVar, override
 
+import numpy as np
 import xarray as xr
 from joblib import delayed
 from returns.result import Failure, ResultE, Success
@@ -205,8 +206,8 @@ def fetch_init_data(
         it: dt.datetime,
     ) -> Iterator[Callable[..., ResultE[list[xr.DataArray]]]]:
         req: urllib.request.Request = urllib.request.Request(  # noqa: S310
-            url=self.request_url + \
-                f"?detail=MINIMAL&runfilter={it:%Y%m%d%H}&dataSpec={self.dataspec}",
+            url=self.request_url
+            + f"?detail=MINIMAL&runfilter={it:%Y%m%d%H}&dataSpec={self.dataspec}",
             headers=self._headers,
             method="GET",
         )
@@ -244,7 +245,7 @@ def fetch_init_data(
         for filedata in data["orderDetails"]["files"]:
             if "fileId" in filedata and "+" not in filedata["fileId"]:
                 urls.append(
-                    f"{self.request_url}/{filedata["fileId"]}/data?dataSpec={self.dataspec}",
+                    f"{self.request_url}/{filedata['fileId']}/data?dataSpec={self.dataspec}",
                 )
 
         log.debug(
@@ -286,7 +287,7 @@ def _download(self, url: str) -> ResultE[pathlib.Path]:
                     f"~/.local/cache/nwp/{self.repository().name}/{self.model().name}/raw",
                 ),
             )
-            / f"{url.split("/")[-2]}.grib"
+            / f"{url.split('/')[-2]}.grib"
         ).expanduser()
 
         # Only download the file if not already present
@@ -309,7 +310,7 @@ def _download(self, url: str) -> ResultE[pathlib.Path]:
         except Exception as e:
             return Failure(
                 OSError(
-                    "Unable to request file data from MetOffice DataHub at " f"'{url}': {e}",
+                    f"Unable to request file data from MetOffice DataHub at '{url}': {e}",
                 ),
             )
 
@@ -405,7 +406,7 @@ def _convert_global(path: pathlib.Path) -> ResultE[list[xr.DataArray]]:
                 ValueError(f"Error processing DataArray for path '{path}'. Error context: {e}"),
             )
 
-        return Success([da])
+        return Success([da.astype(np.float32)])
 
     @staticmethod
     def _convert_ukv(path: pathlib.Path) -> ResultE[list[xr.DataArray]]:
@@ -516,4 +517,4 @@ def _convert_ukv(path: pathlib.Path) -> ResultE[list[xr.DataArray]]:
             ),
         )
 
-        return Success([da])
+        return Success([da.astype(np.float32)])
diff --git a/src/nwp_consumer/internal/repositories/raw_repositories/noaa_s3.py b/src/nwp_consumer/internal/repositories/raw_repositories/noaa_s3.py
index e6556287..68aad863 100644
--- a/src/nwp_consumer/internal/repositories/raw_repositories/noaa_s3.py
+++ b/src/nwp_consumer/internal/repositories/raw_repositories/noaa_s3.py
@@ -35,6 +35,7 @@
 from typing import override
 
 import cfgrib
+import numpy as np
 import s3fs
 import xarray as xr
 from joblib import delayed
@@ -300,7 +301,7 @@ def _convert(path: pathlib.Path) -> ResultE[list[xr.DataArray]]:
             ),
         )
 
-        return Success([da])
+        return Success([da.astype(np.float32)])
 
     @staticmethod
     def _wanted_file(filename: str, it: dt.datetime, steps: list[int]) -> bool:
diff --git a/src/nwp_consumer/internal/repositories/raw_repositories/test_ceda.py b/src/nwp_consumer/internal/repositories/raw_repositories/test_ceda.py
index 9b2e74b2..79179f18 100644
--- a/src/nwp_consumer/internal/repositories/raw_repositories/test_ceda.py
+++ b/src/nwp_consumer/internal/repositories/raw_repositories/test_ceda.py
@@ -72,6 +72,14 @@ def test__convert_global(self) -> None:
                 self.assertIsInstance(region_result, Failure, msg=f"{region_result}")
             else:
                 self.assertIsInstance(region_result, Success, msg=f"{region_result}")
+        # Test that the data is stored as float32
+        das = result.unwrap()
+        for da in das:
+            self.assertEqual(
+                da.dtype,
+                "float32",
+                msg=f"DataArray should have float32 dtype, got {da.dtype}",
+            )
 
     @patch.dict(os.environ, {"MODEL": "mo-um-ukv"}, clear=True)
     def test__convert_ukv(self) -> None:
@@ -117,6 +125,14 @@ def test__convert_ukv(self) -> None:
                 self.assertIsInstance(region_result, Failure, msg=f"{region_result}")
             else:
                 self.assertIsInstance(region_result, Success, msg=f"{region_result}")
+        # Test that the data is stored as float32
+        das = result.unwrap()
+        for da in das:
+            self.assertEqual(
+                da.dtype,
+                "float32",
+                msg=f"DataArray should have float32 dtype, got {da.dtype}",
+            )
 
 
 if __name__ == "__main__":
diff --git a/src/nwp_consumer/internal/services/_dummy_adaptors.py b/src/nwp_consumer/internal/services/_dummy_adaptors.py
index 13ca9745..dd66d477 100644
--- a/src/nwp_consumer/internal/services/_dummy_adaptors.py
+++ b/src/nwp_consumer/internal/services/_dummy_adaptors.py
@@ -53,14 +53,15 @@ def model() -> entities.ModelMetadata:
 
     @override
     def fetch_init_data(
-        self, it: dt.datetime,
+        self,
+        it: dt.datetime,
     ) -> Iterator[Callable[..., ResultE[list[xr.DataArray]]]]:
         def gen_dataset(step: int, variable: str) -> ResultE[list[xr.DataArray]]:
             """Define a generator that provides one variable at one step."""
             da = xr.DataArray(
                 name=self.model().name,
                 dims=["init_time", "step", "variable", "latitude", "longitude"],
-                data=np.random.rand(1, 1, 1, 721, 1440),
+                data=np.random.rand(1, 1, 1, 721, 1440).astype(np.float32),
                 coords=self.model().expected_coordinates.to_pandas()
                 | {
                     "init_time": [np.datetime64(it.replace(tzinfo=None), "ns")],
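np.random.rand always returns float64, which is why the dummy adaptor needs the explicit .astype
above. A sketch of an alternative using the newer Generator API, which can draw float32 samples
directly without a float64 intermediate (shown only as an option, not what this patch does):

    import numpy as np

    rng = np.random.default_rng()
    data = rng.random(size=(1, 1, 1, 721, 1440), dtype=np.float32)  # float32 straight away
    assert data.dtype == np.float32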
diff --git a/src/nwp_consumer/internal/services/consumer_service.py b/src/nwp_consumer/internal/services/consumer_service.py
index 591b4929..2940ae86 100644
--- a/src/nwp_consumer/internal/services/consumer_service.py
+++ b/src/nwp_consumer/internal/services/consumer_service.py
@@ -92,7 +92,7 @@ def _fold_dataarrays_generator(
             return Failure(
                 OSError(
                     "Error threshold exceeded: "
-                    f"{len(failures)/len(results)} errors (>6%) occurred during processing.",
+                    f"{len(failures) / len(results)} errors (>6%) occurred during processing.",
                 ),
             )
         else:
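A quick downstream check that the converted pipeline really writes float32 is to open a store
produced by a consumer run and inspect its variables. The path below is a placeholder, not a path
used by the package:

    import numpy as np
    import xarray as xr

    ds = xr.open_zarr("path/to/consumed_store.zarr")  # placeholder path
    for name, var in ds.data_vars.items():
        assert var.dtype == np.float32, f"{name} is {var.dtype}, expected float32"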