From 4737a943cc308c427daea99ea2ea30831432fe1f Mon Sep 17 00:00:00 2001 From: Shruti Nath Date: Tue, 17 Jun 2025 19:28:46 +0100 Subject: [PATCH 1/4] fix comma --- xarray_batcher/torch_batcher.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/xarray_batcher/torch_batcher.py b/xarray_batcher/torch_batcher.py index 1e13ea6..fbc3628 100644 --- a/xarray_batcher/torch_batcher.py +++ b/xarray_batcher/torch_batcher.py @@ -20,7 +20,7 @@ def __init__( X, y, constants, - batch_size: List[int] = [4, 128, 128], + batch_size: list[int] = [4, 128, 128], weighted_sampler: bool = True, for_NJ: bool = False, for_val: bool = False, @@ -90,7 +90,8 @@ def __getitem__(self, idx): np.concatenate( X_batch, axis=-1, - )).float() + ) + ).float() constant_batch = torch.from_numpy( np.stack( @@ -101,7 +102,8 @@ def __getitem__(self, idx): for constant in self.constants ], axis=-1, - )).float() + ) + ).float() if self.for_NJ: @@ -132,7 +134,8 @@ def __getitem__(self, idx): torch.from_numpy( y_batch.precipitation.fillna(0).values.reshape( self.batch_size[0], -1, 1 - )).float(), + ) + ).float(), X_batch.reshape(self.batch_size[0], -1, len(self.variables) * 4), ), dim=-1, @@ -164,7 +167,8 @@ def __getitem__(self, idx): else: y_batch = torch.from_numpy( - y_batch.precipitation.fillna(np.log10(0.02)).values[:, :, :, None]).float() + y_batch.precipitation.fillna(np.log10(0.02)).values[:, :, :, None] + ).float() return (torch.cat((X_batch, constant_batch), dim=-1), y_batch) @@ -200,11 +204,14 @@ def __init__( else {"lat": int(batch_size[1] // 8), "lon": int(batch_size[2] // 8)} ) self.y_generator = xbatcher.BatchGenerator( - y, - {"time": batch_size[0], - "latitude" if for_NJ else "lat": batch_size[1], "longitude" if for_NJ else "lon": batch_size[2]}, - input_overlap=overlap, - ) + y, + { + "time": batch_size[0], + "latitude" if for_NJ else "lat": batch_size[1], + "longitude" if for_NJ else "lon": batch_size[2], + }, + input_overlap=overlap, + ) if weighted_sampler: if self.for_NJ: From 33ad55a04049ef7bd79085e1ae61d858d29bdeb5 Mon Sep 17 00:00:00 2001 From: Shruti Nath Date: Thu, 26 Jun 2025 21:26:07 +0000 Subject: [PATCH 2/4] add streaming option and test --- test_xbatcher.ipynb | 109 +++++++- xarray_batcher/__init__.py | 5 +- xarray_batcher/create_npz.py | 4 + xarray_batcher/get_fcst_and_truth.py | 11 +- xarray_batcher/loading.py | 37 ++- xarray_batcher/torch_streamer.py | 374 +++++++++++++++++++++++++++ xarray_batcher/utils.py | 12 +- 7 files changed, 520 insertions(+), 32 deletions(-) diff --git a/test_xbatcher.ipynb b/test_xbatcher.ipynb index 1525653..8f37016 100644 --- a/test_xbatcher.ipynb +++ b/test_xbatcher.ipynb @@ -2,10 +2,19 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "6fcec189-ecb8-4bc3-81c2-a858cb4f7a5a", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/n/nath/nobackups/miniforge3/envs/torch_24/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "import xarray_batcher as xb" ] @@ -49,7 +58,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "e3824ef0-1d4b-4932-b568-f8905f93e360", "metadata": {}, "outputs": [], @@ -60,28 +69,104 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "cdc49def-ea1d-4ab9-94db-a8061093f126", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Only getting truth values over, 376 dates\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████| 376/376 [00:19<00:00, 19.46it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Finished retrieving truth values in ---- 20.60894227027893 s---- for years [2018]\n" + ] + } + ], "source": [ "truth_ds = get_all([2018],model='truth')" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, + "id": "d4e94238-e2ae-4b54-96dc-7c8e84423efb", + "metadata": {}, + "outputs": [], + "source": [ + "from xarray_batcher.loading import load_hires_constants\n", + "\n", + "constants = load_hires_constants(4)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, "id": "7194092b-1ced-48b6-881d-cbde68d6a5bf", "metadata": {}, "outputs": [], "source": [ - "import xbatcher\n", + "import importlib\n", + "importlib.reload(xb)\n", "\n", "batch_size=[1,128,128]\n", "\n", - "y_generator = xbatcher.BatchGenerator(truth_ds,\n", - " {\"time\": batch_size[0], \"lat\": batch_size[1], \"lon\": batch_size[2]},\n", - " input_overlap={\"lat\": int(batch_size[1]/32), \"lon\": int(batch_size[2]/32)})\n" + "data_generator = xb.StreamDataset(truth_ds,['cp', 'mcc', 'sp', 'ssr', 't2m', 'tciw', 'tclw', 'tcrw', 'tcw', 'tcwv', 'tp','u700','v700'],\n", + " constants,batch_size=[4,128,128])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "cd8f5d87-5655-44ef-bc58-c190948ae855", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 3.97 s, sys: 813 ms, total: 4.78 s\n", + "Wall time: 4.76 s\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "X,y = next(data_generator.__iter__())" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "b2f77a4b-cb35-4e83-ae27-30ce0f866d22", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([4, 128, 128, 1])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y.shape" ] }, { @@ -93,7 +178,9 @@ "source": [ "%%time\n", "\n", - "get_all([2018],model='ifs',truth_batch=y_generator[0],stream=True,offset=24,variables=['tp','cp','u700'])" + "get_all([2018],model='ifs',truth_batch=y_generator[0],stream=True,offset=24,\n", + " variables=['cp', 'u700', 'tp'])\n", + "\n" ] }, { diff --git a/xarray_batcher/__init__.py b/xarray_batcher/__init__.py index 86b18df..772e0e6 100644 --- a/xarray_batcher/__init__.py +++ b/xarray_batcher/__init__.py @@ -1,7 +1,8 @@ ## Initialisation for xarray batcher, import all helper functions import sys - from .setup_data import DataModule +from .torch_batcher import BatchDataset, BatchTruth +from .torch_streamer import StreamDataset, StreamTruth -__all__ = ["DataModule"] \ No newline at end of file +__all__ = ["DataModule", 
"BatchDataset", "BatchTruth", "StreamDataset", "StreamTruth"] diff --git a/xarray_batcher/create_npz.py b/xarray_batcher/create_npz.py index 309f531..586c336 100644 --- a/xarray_batcher/create_npz.py +++ b/xarray_batcher/create_npz.py @@ -38,8 +38,10 @@ def collate_fn(batch, elev=elev, reg_dict={}): "elevation": elev_values, "spherical_coords": spherical_coords, "precipitation": [], + "time": [], } reg_dict[reg_sel]["precipitation"].append(batch.precipitation.values) + reg_dict[reg_sel]["time"].append(batch.time.values) return reg_dict @@ -81,10 +83,12 @@ def TruthDataloader_to_Npz( precipitation = np.stack( [np.vstack(reg_dict[key]["precipitation"]) for key in reg_dict.keys()] ) + time = np.stack([np.hstack(reg_dict[key]["time"]) for key in reg_dict.keys()]) np.savez( out_path + f"{year}_30min_IMERG_Nairobi_windowsize={window_size}.npz", spherical_coords=spherical_coords, elevation=elevation, precipitation=precipitation, + time=time, ) diff --git a/xarray_batcher/get_fcst_and_truth.py b/xarray_batcher/get_fcst_and_truth.py index 68ad71d..2ccc62a 100644 --- a/xarray_batcher/get_fcst_and_truth.py +++ b/xarray_batcher/get_fcst_and_truth.py @@ -70,7 +70,10 @@ def open_mfzarr( combined = datasets[0] # or some combination of concat, merge dates = [] for dataset in combined: - dates += list(np.unique(dataset.time.values.astype("datetime64[D]"))) + if time_idx is None: + dates += list(np.unique(dataset.time.values.astype("datetime64[D]"))) + else: + dates += list(np.unique(dataset.time.values)) return combined, dates @@ -189,11 +192,11 @@ def stream_ifs(truth_batch, offset=24, variables=None): # Get hours in the truth_batch times object # First need to convert to format with base units of hours to extract hour offset - hour = batch_time.astype("datetime64[h]").astype(object)[0].hour + hour = [time.hour for time in batch_time.astype("datetime64[h]").astype(object)] # Note that if hour is 0 then we add 24 as this # is the offset+24 - hour = hour + 24 * (hour == 0) + offset + hour = [h + 24 * (h == 0) + offset for h in hour] fcst_date, time_idx = match_fcst_to_valid_time(batch_time, hour) @@ -215,7 +218,7 @@ def stream_ifs(truth_batch, offset=24, variables=None): time_idx=time_idx, clip_to_window=False, ) - assert batch_time.astype("datetime64[D]") == np.unique(dates_modified) + assert np.all(np.isin(dates_modified, batch_time)) return ds diff --git a/xarray_batcher/loading.py b/xarray_batcher/loading.py index 3fb746c..5a51311 100644 --- a/xarray_batcher/loading.py +++ b/xarray_batcher/loading.py @@ -213,7 +213,8 @@ def streamline_and_normalise_ifs( xr.DataArray or xr.Dataset of streamline and normalised values - NOTE: We replace the time with the valid time NOT initia + NOTE: We replace the time with the valid time NOT initial fcst + time. 
""" @@ -230,6 +231,8 @@ def streamline_and_normalise_ifs( ) ) else: + if time_idx.shape[0] % 4 == 0: + time_idx = time_idx.reshape(-1, 4) assert da.fcst_valid_time.values.shape[0] == time_idx.shape[0] times = np.hstack( ( @@ -253,26 +256,36 @@ def streamline_and_normalise_ifs( else: for i_row, start in enumerate(time_idx): - - data.append( - retrieve_vars_ifs( - field, - all_data_mean[[i_row]], - all_data_sd[[i_row]], - start=start, - end=start + 1, + if isinstance(start, np.ndarray): + for s in start: + data.append( + retrieve_vars_ifs( + field, + all_data_mean[[i_row]], + all_data_sd[[i_row]], + start=s, + end=s + 1, + ) + ) + else: + data.append( + retrieve_vars_ifs( + field, + all_data_mean[[i_row]], + all_data_sd[[i_row]], + start=start, + end=start + 1, + ) ) - ) data = np.hstack((data)).reshape(-1, da.latitude.shape[0], da.longitude.shape[0], 4) - da = xr.DataArray( data=data, dims=["time", "lat", "lon", "i_x"], coords=dict( lon=da.longitude.values, lat=da.latitude.values, - time=times, + time=times.flatten(), i_x=np.arange(4), ), ) diff --git a/xarray_batcher/torch_streamer.py b/xarray_batcher/torch_streamer.py index e69de29..91d864c 100644 --- a/xarray_batcher/torch_streamer.py +++ b/xarray_batcher/torch_streamer.py @@ -0,0 +1,374 @@ +import dask +import numpy as np +import torch +import xbatcher +from scipy.spatial import KDTree + +from xarray_batcher.get_fcst_and_truth import get_all + +from .batch_helper_functions import Antialiasing, get_spherical + + +class StreamDataset(torch.utils.data.IterableDataset): + + """ + Similar as BatchDataset, see torch_batcher.py apart + from the new workflow to assist in streaming: + + 1) Start using only truth data + 2) Calculate sampler + 3) When iterating through truth, load in the fcst. + data on-the-fly + + """ + + def __init__( + self, + y, + variables, + constants, + batch_size: list[int] = [4, 128, 128], + weighted_sampler: bool = True, + for_NJ: bool = False, + for_val: bool = False, + antialiasing: bool = False, + ): + self.batch_size = batch_size + self.variables = variables + self.y_generator = xbatcher.BatchGenerator( + y, + {"time": batch_size[0], "lat": batch_size[1], "lon": batch_size[2]}, + input_overlap={ + "lat": int(batch_size[1] / 32), + "lon": int(batch_size[2] / 32), + }, + ) + constants["lat"] = np.round(y.lat.values, decimals=2) + constants["lon"] = np.round(y.lon.values, decimals=2) + + self.constants_generator = constants + + self.constants = list(constants.data_vars) + self.for_NJ = for_NJ + self.for_val = for_val + self.antialiasing = antialiasing + + if weighted_sampler: + y_train = [ + self.y_generator[i].precipitation.mean( + ["time", "lat", "lon"], skipna=False + ) + for i in range(len(self.y_generator)) + ] + + rounded_y_train = np.round(y_train, decimals=1) + unique_classes = np.unique(rounded_y_train) + class_sample_count = np.bincount( + np.digitize(rounded_y_train, unique_classes) - 1 + ) + weight = 1.0 / class_sample_count + samples_weight = weight[np.digitize(rounded_y_train, unique_classes) - 1] + samples_weight = samples_weight / np.sum(samples_weight) + + self.samples_weight = torch.from_numpy(np.asarray(samples_weight)) + self.sampler = torch.utils.data.WeightedRandomSampler( + self.samples_weight.type("torch.DoubleTensor"), len(samples_weight) + ) + else: + self.sampler = None + + self.len = len(self.y_generator) + + def __iter__(self): + + if self.sampler is None: + idx = np.random.randint(0, self.len) + else: + idx = int(np.random.choice(self.len, p=self.samples_weight)) + + while True: + + yield 
self.__sample__(idx) + + def __sample__(self, idx): + y_batch = self.y_generator[idx] + time_batch = y_batch.time.values + lat_batch = np.round(y_batch.lat.values, decimals=2) + lon_batch = np.round(y_batch.lon.values, decimals=2) + + X_generator = get_all( + None, + model="ifs", + truth_batch=y_batch, + stream=True, + offset=24, + variables=self.variables, + ) + + X_batch = [] + for x, variable in zip(X_generator, self.variables): + X_batch.append(x[variable].values) + + X_batch = torch.from_numpy( + np.concatenate( + X_batch, + axis=-1, + ) + ).float() + + constant_batch = torch.from_numpy( + np.stack( + [ + self.constants_generator[constant] + .sel({"lat": lat_batch, "lon": lon_batch}) + .values + for constant in self.constants + ], + axis=-1, + ) + ).float() + + if self.for_NJ: + + elev_values = np.squeeze(constant_batch[:, :, 0]).reshape(-1, 1) + lat_values, lon_values = np.meshgrid(lat_batch, lon_batch) + spherical_coords = get_spherical( + lat_values.reshape(-1, 1), lon_values.reshape(-1, 1), elev_values + ) + + kdtree = KDTree(spherical_coords) + + pairs = [] + + for i_coord, coord in enumerate(spherical_coords): + pairs.append( + np.vstack( + ( + np.full(3, fill_value=i_coord).reshape(1, -1), + kdtree.query(coord, k=3)[1], + ) + ) + ) + + pairs = np.hstack((pairs)) + + rainfall_path = torch.cat( + ( + torch.from_numpy( + y_batch.precipitation.fillna(0).values.reshape( + self.batch_size[0], -1, 1 + ) + ).float(), + X_batch.reshape(self.batch_size[0], -1, len(self.variables) * 4), + ), + dim=-1, + ) + obs_dates = np.ones(self.batch_size[0]).reshape(1, -1) + n_obs = np.array([self.batch_size[0]]) + if self.for_val: + obs_dates = np.zeros(self.batch_size[0]).reshape(1, -1) + n_obs = np.random.randint(1, self.batch_size[0] - 8, 1) + obs_dates[: n_obs[0]] = 1 + + return { + "idx": idx, + "rainfall_path": rainfall_path[None, :, :, :], + "observed_dates": obs_dates, + "nb_obs": n_obs, + "dt": 1, + "edge_indices": pairs, + "obs_noise": None, + } + + else: + + if self.antialiasing: + antialiaser = Antialiasing() + y_batch = y_batch.precipitation.fillna(np.log10(0.02)).values + y_batch = antialiaser(y_batch) + y_batch = torch.from_numpy(np.moveaxis(y_batch, 0, -1)).float() + + else: + y_batch = torch.from_numpy( + y_batch.precipitation.fillna(np.log10(0.02)).values[:, :, :, None] + ).float() + return (torch.cat((X_batch, constant_batch), dim=-1), y_batch) + + +class StreamTruth(torch.utils.data.Dataset): + + """ + class for iterating over a dataset + """ + + def __init__( + self, + y, + batch_size=[4, 128, 128], + weighted_sampler=True, + for_NJ=False, + for_val=False, + length=None, + antialiasing=False, + transform=None, + return_dataset=False, + ): + + self.batch_size = batch_size + self.for_NJ = for_NJ + self.for_val = for_val + self.length = length + self.antialiasing = antialiasing + self.transform = transform + self.return_dataset = return_dataset + overlap = ( + {"latitude": int(batch_size[1] - 8), "longitude": int(batch_size[2] - 8)} + if for_NJ + else {"lat": int(batch_size[1] // 8), "lon": int(batch_size[2] // 8)} + ) + self.y_generator = xbatcher.BatchGenerator( + y, + { + "time": batch_size[0], + "latitude" if for_NJ else "lat": batch_size[1], + "longitude" if for_NJ else "lon": batch_size[2], + }, + input_overlap=overlap, + ) + + if weighted_sampler: + if self.for_NJ: + y_train = [ + self.y_generator[i].mean( + ["time", "latitude", "longitude"], skipna=False + ) + for i in range(len(self.y_generator)) + ] + else: + y_train = [ + self.y_generator[i].precipitation.mean( + 
["time", "lat", "lon"], skipna=False + ) + for i in range(len(self.y_generator)) + ] + rounded_y_train = np.round(y_train, decimals=1) + unique_classes = np.unique(rounded_y_train) + class_sample_count = np.bincount( + np.digitize(rounded_y_train, unique_classes) - 1 + ) + weight = 1.0 / class_sample_count + samples_weight = weight[np.digitize(rounded_y_train, unique_classes) - 1] + + self.samples_weight = torch.from_numpy(np.asarray(samples_weight)) + self.sampler = torch.utils.data.WeightedRandomSampler( + self.samples_weight.type("torch.DoubleTensor"), len(samples_weight) + ) + + def __len__(self) -> int: + return len(self.y_generator) + + def __getitem__(self, idx): + + y_batch = self.y_generator[idx] + + if self.return_dataset: + return y_batch + + if self.for_NJ: + + def generate(y_batch, length, stop=None): + + rng = np.random.default_rng() + random_year = rng.choice(np.unique(y_batch["time.year"].values), 1)[0] + ds_sel = y_batch.sel( + { + "time": slice( + "%i-01-01" % random_year, "%i-01-01" % (random_year + 1) + ) + } + ) + + time_of_event = rng.choice(ds_sel.time.values[length:-length], 1)[0] + + time_to_event = rng.choice(np.arange(length), 1)[0] + time_after_event = length - time_to_event - 1 + + rainfall_path = ds_sel.sel( + { + "time": slice( + time_of_event - np.timedelta64(time_to_event * 30, "m"), + time_of_event + np.timedelta64(time_after_event * 30, "m"), + ), + } + ) + times_rainfall = rainfall_path.time.values + rainfall_path = rainfall_path.fillna(0).values[None, :, :, :] + + if stop is not None: + # limit observations to once a day + nb_obs_single = stop + obs_ptr = np.arange(1, nb_obs_single) + + else: + nb_obs_single = length + obs_ptr = np.arange(1, length) + + observed_date = np.zeros(rainfall_path.shape[1]) + observed_date[0] = 1 + + for i_obs in obs_ptr: + + observed_date[i_obs] = 1 + + return rainfall_path, observed_date, nb_obs_single + + rainfall_paths = [] + observed_dates = [] + n_obs = [] + + stop = None + batch_size = 50 + if self.for_val: + rng = np.random.default_rng() + stop = rng.choice(np.arange(2, self.length - 100), 1)[0] + batch_size = 2 + + for i in range(batch_size): + rainfall_path, observed_date, nb_obs = generate( + y_batch, self.length, stop=stop + ) + rainfall_paths.append(rainfall_path) + observed_dates.append(observed_date) + n_obs.append(nb_obs) + + rainfall_paths = np.vstack(rainfall_paths) + observed_dates = np.stack(observed_dates) + n_obs = np.asarray(n_obs) + + return { + "idx": idx, + "rainfall_path": torch.tensor( + rainfall_paths[:, :, :, :, None], dtype=torch.float32 + ), + "observed_dates": observed_dates, + "nb_obs": n_obs, + "dt": 1, + "obs_noise": None, + } + + else: + if self.antialiasing: + antialiaser = Antialiasing() + y_batch = y_batch.precipitation.fillna(np.log10(0.02)).values + y_batch = antialiaser(y_batch) + y_batch = torch.tensor(np.moveaxis(y_batch, 0, -1), dtype=torch.float32) + + else: + y_batch = torch.tensor( + y_batch.precipitation.fillna(np.log10(0.02)).values[:, :, :, None], + dtype=torch.float32, + ) + if self.transform: + y_batch = self.transform(y_batch) + + return y_batch diff --git a/xarray_batcher/utils.py b/xarray_batcher/utils.py index 978319b..50ef143 100644 --- a/xarray_batcher/utils.py +++ b/xarray_batcher/utils.py @@ -154,9 +154,15 @@ def match_fcst_to_valid_time(valid_times, time_idx, step_type="h"): to select """ - time_offset = np.timedelta64(time_idx, step_type) - fcst_times = valid_times - time_offset + if not isinstance(time_idx, list): + time_offset = np.timedelta64(time_idx, 
step_type) + valid_date_idx = np.asarray([int(time_offset.astype(int) / TIME_RES)]) + else: + time_offset = [np.timedelta64(t_idx, step_type) for t_idx in time_idx] + valid_date_idx = np.asarray( + [int(t_offset.astype(int) / TIME_RES) for t_offset in time_offset] + ) - valid_date_idx = np.asarray([int(time_offset.astype(int) / TIME_RES)]) + fcst_times = valid_times - time_offset return fcst_times, valid_date_idx From b11093163c7a66f8a38ff293b8631b159d61f541 Mon Sep 17 00:00:00 2001 From: Shruti Nath Date: Wed, 2 Jul 2025 10:14:15 +0100 Subject: [PATCH 3/4] fix sample weight handling --- test_xbatcher.ipynb | 2 +- .../.ipynb_checkpoints/__init__-checkpoint.py | 8 + .../batch_helper_functions-checkpoint.py | 96 ++++ .../create_npz-checkpoint.py | 95 ++++ .../custom_collate_fn-checkpoint.py | 141 +++++ .../get_fcst_and_truth-checkpoint.py | 488 +++++++++++++++++ .../.ipynb_checkpoints/loading-checkpoint.py | 492 ++++++++++++++++++ .../normalise-checkpoint.py | 128 +++++ .../setup_data-checkpoint.py | 100 ++++ .../torch_batcher-checkpoint.py | 351 +++++++++++++ .../torch_streamer-checkpoint.py | 380 ++++++++++++++ .../.ipynb_checkpoints/utils-checkpoint.py | 168 ++++++ xarray_batcher/get_fcst_and_truth.py | 1 + xarray_batcher/loading.py | 2 + xarray_batcher/torch_streamer.py | 42 +- 15 files changed, 2475 insertions(+), 19 deletions(-) create mode 100644 xarray_batcher/.ipynb_checkpoints/__init__-checkpoint.py create mode 100644 xarray_batcher/.ipynb_checkpoints/batch_helper_functions-checkpoint.py create mode 100644 xarray_batcher/.ipynb_checkpoints/create_npz-checkpoint.py create mode 100644 xarray_batcher/.ipynb_checkpoints/custom_collate_fn-checkpoint.py create mode 100644 xarray_batcher/.ipynb_checkpoints/get_fcst_and_truth-checkpoint.py create mode 100644 xarray_batcher/.ipynb_checkpoints/loading-checkpoint.py create mode 100644 xarray_batcher/.ipynb_checkpoints/normalise-checkpoint.py create mode 100644 xarray_batcher/.ipynb_checkpoints/setup_data-checkpoint.py create mode 100644 xarray_batcher/.ipynb_checkpoints/torch_batcher-checkpoint.py create mode 100644 xarray_batcher/.ipynb_checkpoints/torch_streamer-checkpoint.py create mode 100644 xarray_batcher/.ipynb_checkpoints/utils-checkpoint.py diff --git a/test_xbatcher.ipynb b/test_xbatcher.ipynb index 8f37016..ef7bc69 100644 --- a/test_xbatcher.ipynb +++ b/test_xbatcher.ipynb @@ -208,7 +208,7 @@ { "cell_type": "code", "execution_count": null, - "id": "75d4787a-82b7-4374-b61a-fd59a8fdf880", + "id": "df9e88f2-5647-43f0-a99a-f20e21a9fbc5", "metadata": {}, "outputs": [], "source": [] diff --git a/xarray_batcher/.ipynb_checkpoints/__init__-checkpoint.py b/xarray_batcher/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 0000000..772e0e6 --- /dev/null +++ b/xarray_batcher/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1,8 @@ +## Initialisation for xarray batcher, import all helper functions +import sys + +from .setup_data import DataModule +from .torch_batcher import BatchDataset, BatchTruth +from .torch_streamer import StreamDataset, StreamTruth + +__all__ = ["DataModule", "BatchDataset", "BatchTruth", "StreamDataset", "StreamTruth"] diff --git a/xarray_batcher/.ipynb_checkpoints/batch_helper_functions-checkpoint.py b/xarray_batcher/.ipynb_checkpoints/batch_helper_functions-checkpoint.py new file mode 100644 index 0000000..efd7317 --- /dev/null +++ b/xarray_batcher/.ipynb_checkpoints/batch_helper_functions-checkpoint.py @@ -0,0 +1,96 @@ +import concurrent.futures +import multiprocessing + +import numpy as np +from 
scipy.ndimage import convolve + + +class Antialiasing: + def __init__(self): + (x, y) = np.mgrid[-2:3, -2:3] + self.kernel = np.exp(-0.5 * (x**2 + y**2) / (0.5**2)) + self.kernel /= self.kernel.sum() + self.edge_factors = {} + self.img_smooth = {} + num_threads = multiprocessing.cpu_count() + self.executor = concurrent.futures.ThreadPoolExecutor(num_threads) + + def __call__(self, img): + if img.ndim < 3: + img = img[None, None, :, :] + elif img.ndim < 4: + img = img[None, :, :, :] + img_shape = img.shape[-2:] + if img_shape not in self.edge_factors: + s = convolve( + np.ones(img_shape, dtype=np.float32), self.kernel, mode="constant" + ) + s = 1.0 / s + self.edge_factors[img_shape] = s + else: + s = self.edge_factors[img_shape] + + if img.shape not in self.img_smooth: + img_smooth = np.empty_like(img) + self.img_smooth[img_shape] = img_smooth + else: + img_smooth = self.img_smooth[img_shape] + + def _convolve_frame(i, j): + convolve( + img[i, j, :, :], + self.kernel, + mode="constant", + output=img_smooth[i, j, :, :], + ) + img_smooth[i, j, :, :] *= s + + futures = [] + for i in range(img.shape[0]): + for j in range(img.shape[1]): + args = (_convolve_frame, i, j) + futures.append(self.executor.submit(*args)) + concurrent.futures.wait(futures) + + return img_smooth + + +def get_spherical(lat, lon, elev, return_hstacked=True): + + """ + Get spherical coordinates of lat and lon, not assuming unit ball for radius + So we also take elev into account + + Inputs + ------ + + lat: np.array or xr.DataArray (n_lats,n_lons) + meshgrid of latitude points + + lon: np.array or xr.DataArray (n_lats,n_lons) + meshgrid of longitude points + + elev: np.array or xr.DataArray (n_lats,n_lons) + altitude values in m + + return_hstacked: boolean + typically for graph networks we collapse + lat lon into one dimension + Output + ------ + + r, sigma and phi + See: https://en.wikipedia.org/wiki/Spherical_coordinate_system + for more details + """ + + lat, lon = np.deg2rad(lat), np.deg2rad(lon) + + x = elev * np.cos(lat) * np.cos(lon) + y = elev * np.cos(lat) * np.sin(lon) + z = elev * np.sin(lat) + + if return_hstacked: + return np.hstack((x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1))) + else: + return np.dstack((x[:, :, None], y[:, :, None], z[:, :, None])) diff --git a/xarray_batcher/.ipynb_checkpoints/create_npz-checkpoint.py b/xarray_batcher/.ipynb_checkpoints/create_npz-checkpoint.py new file mode 100644 index 0000000..ed35e38 --- /dev/null +++ b/xarray_batcher/.ipynb_checkpoints/create_npz-checkpoint.py @@ -0,0 +1,95 @@ +import os + +import numpy as np +import xarray as xr +from torch.utils.data import DataLoader + +from .batch_helper_functions import get_spherical +from .loading import get_IMERG_year +from .torch_batcher import BatchDataset, BatchTruth +from .utils import get_paths + +_, _, CONSTANTS_PATH = get_paths() +elev = xr.open_dataset(CONSTANTS_PATH + "elev.nc") + + +def collate_fn(batch, elev=elev, reg_dict={}): + + lat_batch = np.round(batch.lat.values, decimals=2) + lon_batch = np.round(batch.lon.values, decimals=2) + lat_values, lon_values = np.meshgrid(lat_batch, lon_batch) + elev_values = elev.sel({"lat": lat_batch, "lon": lon_batch}, method="nearest") + elev_values = np.squeeze(elev_values.elevation.values) / 10000.0 + spherical_coords = get_spherical( + lat_values, lon_values, elev_values, return_hstacked=False + ) + + i = 0 + reg_sel = None + for reg in reg_dict.keys(): + if np.array_equal(reg_dict[reg]["spherical_coords"], spherical_coords): + reg_sel = reg + break + i += 1 + + if 
reg_sel is None: + reg_sel = i + reg_dict[reg_sel] = { + "elevation": elev_values, + "spherical_coords": spherical_coords, + "precipitation": [], + "time": [], + } + reg_dict[reg_sel]["precipitation"].append(batch.precipitation.values) + reg_dict[reg_sel]["time"].append(batch.time.values) + + return reg_dict + + +def TruthDataloader_to_Npz( + out_path, + years=[2018, 2019, 2020, 2021, 2023, 2024], + centre=[-1.25, 36.80], + months=None, + window_size=3, + collate_fn=collate_fn, +): + + if not os.path.exists(out_path): + os.makedirs(out_path) + + if months is None: + months = np.arange(1, 13).tolist() + elev = xr.open_dataset(CONSTANTS_PATH + "elev.nc") + for year in years: + + ds = get_IMERG_year(year, months=months, centre=centre, window_size=window_size) + if not isinstance(ds, xr.Dataset): + ds = ds.to_dataset() + + # load in truth to batcher without any weighting in sampler + ds = BatchTruth( + ds, batch_size=[1, 128, 128], weighted_sampler=False, return_dataset=True + ) + reg_dict = {} + + for batch in ds: + reg_dict = collate_fn(batch, elev=elev, reg_dict=reg_dict) + + spherical_coords = np.stack( + [reg_dict[key]["spherical_coords"] for key in reg_dict.keys()] + ) + elevation = np.stack([reg_dict[key]["elevation"] for key in reg_dict.keys()]) + precipitation = np.stack( + [np.vstack(reg_dict[key]["precipitation"]) for key in reg_dict.keys()] + ) + time = np.stack([np.hstack(reg_dict[key]["time"]) for key in reg_dict.keys()]) + print(time.shape) + + np.savez( + out_path + f"{year}_30min_IMERG_Nairobi_windowsize={window_size}.npz", + spherical_coords=spherical_coords, + elevation=elevation, + precipitation=precipitation, + time=time, + ) diff --git a/xarray_batcher/.ipynb_checkpoints/custom_collate_fn-checkpoint.py b/xarray_batcher/.ipynb_checkpoints/custom_collate_fn-checkpoint.py new file mode 100644 index 0000000..934df67 --- /dev/null +++ b/xarray_batcher/.ipynb_checkpoints/custom_collate_fn-checkpoint.py @@ -0,0 +1,141 @@ +import numpy as np +import torch + + +def _get_func(name): + """ + transform a function given as str to a python function + :param name: str, correspond to a function, + supported: 'exp', 'power-x' (x the wanted power) + :return: numpy fuction + """ + if name in ["exp", "exponential"]: + return np.exp + if "power-" in name: + x = float(name.split("-")[1]) + + def pow(input): + return np.power(input, x) + + return pow + else: + try: + return eval(name) + except Exception: + return None + + +def _get_X_with_func_appl(X, functions, axis): + """ + apply a list of functions to the paths in X and append X by the outputs + along the given axis + :param X: np.array, with the data, + :param functions: list of functions to be applied + :param axis: int, the data_dimension (not batch and not time dim) along + which the new paths are appended + :return: np.array + """ + Y = X + for f in functions: + Y = np.concatenate([Y, f(X)], axis=axis) + return Y + + +def CustomCollateFnGen(func_names=None): + """ + a function to get the costume collate function that can be used in + torch.DataLoader with the wanted functions applied to the data as new + dimensions + -> the functions are applied on the fly to the dataset, and this additional + data doesn't have to be saved + + :param func_names: list of str, with all function names, see _get_func + :return: collate function, int (multiplication factor of dimension before + and after applying the functions) + """ + # get functions that should be applied to X, additionally to identity + functions = [] + if func_names is not None: + for 
func_name in func_names: + f = _get_func(func_name) + if f is not None: + functions.append(f) + mult = len(functions) + 1 + + def custom_collate_fn(batch): + dt = batch[0]["dt"] + stock_paths = np.concatenate([b["rainfall_path"] for b in batch], axis=0) + observed_dates = np.concatenate([b["observed_dates"] for b in batch], axis=0) + # edge_indices = np.concatenate([b['edge_indices'] for b in batch], axis=0) + obs_noise = None + if batch[0]["obs_noise"] is not None: + obs_noise = np.concatenate([b["obs_noise"] for b in batch], axis=0) + masked = False + mask = None + if len(observed_dates.shape) == 3: + masked = True + mask = observed_dates + observed_dates = observed_dates.max(axis=1) + nb_obs = torch.tensor(np.concatenate([b["nb_obs"] for b in batch], axis=0)) + + # here axis=1, since we have elements of dim + # [batch_size, data_dimension] => add as new data_dimensions + sp = stock_paths[:, 0] + if obs_noise is not None: + sp = stock_paths[:, :, 0] + obs_noise[:, :, 0] + start_X = torch.tensor( + _get_X_with_func_appl(sp, functions, axis=1), dtype=torch.float32 + ) + X = [] + if masked: + M = [] + start_M = torch.tensor(mask[:, :, 0], dtype=torch.float32).repeat((1, mult)) + else: + M = None + start_M = None + times = [] + time_ptr = [0] + obs_idx = [] + current_time = 0.0 + counter = 0 + for t in range(1, observed_dates.shape[-1]): + current_time += dt + if observed_dates[:, t].sum() > 0: + times.append(current_time) + for i in range(observed_dates.shape[0]): + if observed_dates[i, t] == 1: + counter += 1 + # here axis=0, since only 1 dim (the data_dimension), + # i.e. the batch-dim is cummulated outside together + # with the time dimension + sp = stock_paths[i, t] + if obs_noise is not None: + sp = stock_paths[i, :, t] + obs_noise[i, :, t] + X.append(_get_X_with_func_appl(sp, functions, axis=0)) + if masked: + M.append(np.tile(mask[i, :, t], reps=mult)) + obs_idx.append(i) + time_ptr.append(counter) + # if obs_noise is not None: + # print("noisy observations used") + + assert len(obs_idx) == observed_dates[:, 1:].sum() + if masked: + M = torch.tensor(np.array(M), dtype=torch.float32) + res = { + "times": np.array(times), + "time_ptr": np.array(time_ptr), + "obs_idx": torch.tensor(obs_idx, dtype=torch.long), + "start_X": start_X, + "n_obs_ot": nb_obs, + "X": torch.tensor(np.array(X), dtype=torch.float32).permute(3, 1, 2, 0), + "true_paths": stock_paths, + "observed_dates": observed_dates, + "true_mask": mask, + "obs_noise": obs_noise, #'edge_indices': torch.from_numpy(edge_indices).long().contiguous(), + "M": M, + "start_M": start_M, + } + return res + + return custom_collate_fn, mult diff --git a/xarray_batcher/.ipynb_checkpoints/get_fcst_and_truth-checkpoint.py b/xarray_batcher/.ipynb_checkpoints/get_fcst_and_truth-checkpoint.py new file mode 100644 index 0000000..633d2f5 --- /dev/null +++ b/xarray_batcher/.ipynb_checkpoints/get_fcst_and_truth-checkpoint.py @@ -0,0 +1,488 @@ +import glob +import time + +import dask +import numpy as np +import xarray as xr + +from .loading import ( + load_hires_constants, + load_truth_and_mask, + streamline_and_normalise_ifs, +) +from .utils import get_paths, get_valid_dates, match_fcst_to_valid_time + +FCST_PATH_IFS, TRUTH_PATH, CONSTANTS_PATH = get_paths() + + +def open_mfzarr( + file_names, + use_modify=False, + centre=[-1.25, 36.80], + window_size=2, + lats=None, + lons=None, + months=[3, 4, 5, 6], + dates=None, + time_idx=None, + split_steps=[5, 6, 7, 8, 9], + clip_to_window=True, +): + """ + Open multiple files using dask delayed + + Inputs + 
------ + file_names: list + list of file names to open typically one + file for each variable + kwargs: to be passed on to modify + + Outputs + ------- + List of xr.DataArray or xr.Dataset we avoid concatenation + as this takes to long, but through modify, we are confident + that the times align. + """ + + # this is basically what open_mfdataset does + open_kwargs = dict(decode_cf=True, decode_times=True) + open_tasks = [dask.delayed(xr.open_dataset)(f, **open_kwargs) for f in file_names] + + tasks = [ + dask.delayed(modify)( + task, + use_modify=use_modify, + centre=centre, + window_size=window_size, + lats=lats, + lons=lons, + months=months, + dates=dates, + time_idx=time_idx, + split_steps=split_steps, + clip_to_window=clip_to_window, + ) + for task in open_tasks + ] + + datasets = dask.compute(tasks) # get a list of xarray.Datasets + combined = datasets[0] # or some combination of concat, merge + dates = [] + for dataset in combined: + if time_idx is None: + dates += list(np.unique(dataset.time.values.astype("datetime64[D]"))) + else: + dates += list(np.unique(dataset.time.values)) + + return combined, dates + + +def modify( + ds, + use_modify=False, + centre=[-1.25, 36.80], + window_size=2, + lats=None, + lons=None, + months=[3, 4, 5, 6], + dates=None, + time_idx=None, + split_steps=[5, 6, 7, 8, 9], + clip_to_window=True, +): + + """ + Modification function to wrap around dask delayed compute + + Inputs + ------ + + use_modify: boolean + whether to apply modification function at all. + centre: list or tuple + centre around which to select a region when looking at sub-domains + window_size: integer + window size around centre to use in sub-domain selection + lats: ndarray or list + alternatively, latitudes can be given to sub-select + lons: ndarray or list + Ditto as lats + months: list + months to select if we are doing seasonal/monthly training + dates: list or 1-D array + dates to sub-select, typically is all months are used but too + expensive to downlaod. 
If time_idx is None, we assume these are + the dates of forecast issue + time_idx: list or 1-D array + valid time indices to select in index form, not absolute + value + Outputs + ------- + + xr.Dataset or xr.DataArray with modifications if use_midify=True or simply + without modifications + + **Note: when time_idx is provided, split_steps is ignored** + """ + + if use_modify: + name = [var for var in ds.data_vars] + if lats is not None and lons is not None: + lat_batch = np.round(lats, decimals=2) + lon_batch = np.round(lons, decimals=2) + + ds = ds.sel(time=ds.time.dt.month.isin(months)).sel( + { + "latitude": lat_batch, + "longitude": lon_batch, + } + ) + elif clip_to_window: + + ds = ds.sel(time=ds.time.dt.month.isin(months)).sel( + { + "latitude": slice(centre[0] - window_size, centre[0] + window_size), + "longitude": slice( + centre[1] - window_size, centre[1] + window_size + ), + } + ) + if dates is not None: + _, dates_intersect, _ = np.intersect1d( + ds.time.values.astype("datetime64[D]"), dates, return_indices=True + ) + ds = ds.isel(time=dates_intersect) + + ds = streamline_and_normalise_ifs( + name[1].split("_")[0], ds, time_idx=time_idx, split_steps=split_steps + ).to_dataset(name=name[1].split("_")[0]) + + return ds + + else: + return ds + + +def stream_ifs(truth_batch, offset=24, variables=None): + """ + Input + ----- + truth_batch: xr.DataArray or xr.Dataset + truth values of a single batch item + to match and load + fcst data for + offset: int + day offset to factor in, should be in hours + + variables: list or None + variables to load, if None then all are + loaded. + Output + ------ + + forecast batch as ndarray + """ + + batch_time = truth_batch.time.values + batch_lats = truth_batch.lat.values + batch_lons = truth_batch.lon.values + + if not isinstance(batch_time, np.ndarray): + batch_time = np.asarray(batch_time) + + # Get hours in the truth_batch times object + # First need to convert to format with base units of hours to extract hour offset + hour = [time.hour for time in batch_time.astype("datetime64[h]").astype(object)] + + # Note that if hour is 0 then we add 24 as this + # is the offset+24 + hour = [h + 24 * (h == 0) + offset for h in hour] + + fcst_date, time_idx = match_fcst_to_valid_time(batch_time, hour) + + year = fcst_date.astype("datetime64[D]").astype(object)[0].year + month = fcst_date.astype("datetime64[D]").astype(object)[0].month + + if variables is None: + files = sorted(glob.glob(FCST_PATH_IFS + str(year) + "/" + "*.nc")) + else: + files = [FCST_PATH_IFS + str(year) + "/" + "%s.nc" % var for var in variables] + + ds, dates_modified = open_mfzarr( + files, + use_modify=True, + lats=batch_lats, + lons=batch_lons, + months=[month], + dates=fcst_date, + time_idx=time_idx, + clip_to_window=False, + ) + assert np.all(np.isin(dates_modified, batch_time)) + return ds + + +def get_whole_year_ifs( + years, + centre=[-1.25, 36.80], + window_size=30, + months=[3, 4, 5, 6], + n_days=None, + split_steps=[5, 6, 7, 8, 9], + ignore_truth=False, + variables=None, + clip_to_window=True, +): + dates_all = [] + + if ignore_truth: + dates_year = [ + list( + np.arange( + "%i-01-01" % year, + "%i-01-01" % (year + 1), + np.timedelta64(1, "D"), + dtype="datetime64[D]", + ).astype("str") + ) + for year in years + ] + else: + dates_year = [get_valid_dates(year, raw_list=True) for year in years] + + for dates in dates_year: + start_time = time.time() + + if len(months) == 12 and n_days is not None: + dates_sel = np.random.choice( + np.array(dates, dtype="datetime64[D]"), 
n_days, replace=False + ) + else: + dates_sel = None + + dates_final = np.array(dates.copy(), dtype="datetime64[D]") + year = dates_final[0].astype(object).year + + if variables is None: + files = sorted(glob.glob(FCST_PATH_IFS + str(year) + "/" + "*.nc")) + else: + files = [ + FCST_PATH_IFS + str(year) + "/" + "%s.nc" % var for var in variables + ] + + ds, dates_modified = open_mfzarr( + files, + use_modify=True, + centre=centre, + window_size=window_size, + months=months, + dates=dates_sel, + split_steps=split_steps, + clip_to_window=clip_to_window, + ) + + if year == years[0]: + ds_vars = ds + else: + ds_vars = [ + xr.concat([ds_1, ds_2], "time") for ds_1, ds_2 in zip(ds_vars, ds) + ] + + dates_final = np.append(dates_final, dates_modified, axis=0) + + del ds + + print( + "Extracted all %i variables in ----" % len(files), + time.time() - start_time, + "s---- for year", + year, + ) + + dates_final, dates_count = np.unique(dates_final, return_counts=True) + dates_idx = np.squeeze(np.argwhere(dates_count == (len(files) + 1))) + dates_final = dates_final[dates_idx] + dates_all += [str(date) for date in dates_final] + # print(len(dates)-len(dates_final)," missing dates in year", year) + + if ignore_truth: + return ds_vars + + print("Now doing truth values") + start_time = time.time() + # time_idx is hard-coded in here as forecast is made to have time as valid_time + ds_truth_and_mask = load_truth_and_mask( + np.array(dates_all, dtype="datetime64[ns]").flatten(), + time_idx=[1, 2, 3, 4], + ) + if dates_sel is not None: + # Because of 6 hour offset when streamlining select dates 6AM to midnight is used + # Meaning that the next day midnight is in truth but no other time step in that date. + # Therefore, need to guarantee alignment in times. + ds_truth_and_mask = ds_truth_and_mask.drop_duplicates("time") + times_sel = np.intersect1d( + ds_vars[0].time.values, ds_truth_and_mask.time.values + ) + ds_truth_and_mask = ds_truth_and_mask.sel({"time": times_sel}) + ds_constants = load_hires_constants(batch_size=1) + + print( + "Finished retrieving truth values in ----", + time.time() - start_time, + "s---- for year", + year, + ) + + return ( + ds_vars, + ds_truth_and_mask.rename({"latitude": "lat", "longitude": "lon"}) + .sel(time=ds_truth_and_mask.time.dt.month.isin(months)) + .sel( + { + "lat": slice(centre[0] - window_size, centre[0] + window_size), + "lon": slice(centre[1] - window_size, centre[1] + window_size), + } + ), + ds_constants.sel( + { + "lat": slice(centre[0] - window_size, centre[0] + window_size), + "lon": slice(centre[1] - window_size, centre[1] + window_size), + } + ), + ) + + +def get_all( + years, + model="ifs", + offset=None, + stream=False, + truth_batch=None, + time_idx=[5, 6, 7, 8], + split_steps=[5, 6, 7, 8, 9], + ignore_truth=False, + variables=None, + months=None, + n_days=10, + centre=[-1.25, 36.80], + window_size=30, + clip_to_window=False, + log_precip=True, +): + + """ + Wrapper function to return either: + + * IFS data alongside truth data fully loaded into memory. + This is recommended if an npz file is being created for + example. + * stream IFS data in which case truth_batch should be given + and offset is the no. 
days (in hours) lead time used to + obtain the initialisation time from truth batch valid time + (See stream_ifs for further details) + * Obtain only truth data (when model="truth") + + Inputs + ------ + years: list + list of years to calculate over + model: str + either truth or ifs (later additions possible) + default='ifs' + offset: int or None + Hour offset that accounts for no. days lead time + needed to match back to intiialisation time + during streaming. + stream: boolean + whether or not to stream data in which case + truth batch should be provided + truth_batch: xr.DataArray or xr.Dataset + truth batch to obtain forecast variables from + time_idx: list + time_idx to obtain truth + split_steps: list + fcst lead times to take when loading in IFS + ignore_truth: boolean + passed on to get_whole_year_ifs, whether to + only load in the fcst and ignore turht + variables: list or None + variables to load, if None then all are + loaded. + months: ndarray or list or None + months to load in in case of seasonal/sub-seasonal + training, if None all months are used + centre: list or tuple + centre around which to select a region when looking + at sub-domains, only used when clip_to_window is true + window_size: integer + window size around centre to use in sub-domain + selection, only used when clip_to_window is true + clip_to_window: boolean + passed on to get_whole_year_ifs or simply + getting truth, to signal when to clip to a + window (size window_size) around centre + log_precip: boolean + passed on to load_truth_and_mask when model is "truth + default=True + + Outputs: + ------- + list of xr.DataArray or streamed or whole IFS+truth or simply truth data + """ + + if months is None: + months = np.arange(1, 13).tolist() + + if model == "ifs": + if not stream: + return get_whole_year_ifs( + years, + split_steps=split_steps, + ignore_truth=ignore_truth, + variables=variables, + months=months, + n_days=n_days, + centre=centre, + window_size=window_size, + clip_to_window=clip_to_window, + ) + + else: + assert truth_batch is not None + assert offset is not None + return stream_ifs(truth_batch, offset=offset, variables=variables) + + elif model == "truth": + dates_all = [] + for year in years: + dates = get_valid_dates(year) + dates_all += [date.strftime("%Y-%m-%d") for date in dates] + + print("Only getting truth values over,", len(dates_all), "dates") + start_time = time.time() + ds_truth_and_mask = load_truth_and_mask( + np.array(dates_all, dtype="datetime64[ns]").flatten(), + time_idx=time_idx, + log_precip=log_precip, + ) + + print( + "Finished retrieving truth values in ----", + time.time() - start_time, + "s---- for years", + years, + ) + if clip_to_window: + ds_truth_and_mask = ds_truth_and_mask.sel( + { + "latitude": slice(centre[0] - window_size, centre[0] + window_size), + "longitude": slice( + centre[1] - window_size, centre[1] + window_size + ), + } + ) + + return ds_truth_and_mask.sel( + time=ds_truth_and_mask.time.dt.month.isin(months) + ).rename({"latitude": "lat", "longitude": "lon"}) diff --git a/xarray_batcher/.ipynb_checkpoints/loading-checkpoint.py b/xarray_batcher/.ipynb_checkpoints/loading-checkpoint.py new file mode 100644 index 0000000..704fbef --- /dev/null +++ b/xarray_batcher/.ipynb_checkpoints/loading-checkpoint.py @@ -0,0 +1,492 @@ +import datetime +import glob +import time + +import h5py +import numpy as np +import xarray as xr +from tqdm import tqdm + +from .normalise import convert_units, get_norm, logprec, nonnegative +from .utils import get_metadata, get_paths 
+ +(FCST_PATH_IFS, TRUTH_PATH, CONSTANTS_PATH) = get_paths() + +(fcst_time_res, time_res, lonlatbox, fcst_spat_res) = get_metadata() + + +def daterange(start_date, end_date): + """ + Generator to get date range for a given time period + from start_date to end_date + """ + + for n in range(int((end_date - start_date).days)): + yield start_date + datetime.timedelta(days=n) + + +def get_lonlat(): + """ + Function to get longitudes and latitudes of forecast and truth data + + Input + ------ + + lonlatbox: list of int with length 4 + bottom, left, top, right corners of lon-lat box + fcst_spat_res: float + spatial resolution of forecasts + + Output + ------ + + centres of forecast (lon_reg, lat_reg), + and truth (lon__reg_TRUTH, lat_reg_TRUTH), and their box + edges: (lon_reg_b, lat_reg_b) and (lon__reg_TRUTH, + lat_reg_TRUTH) for forecasts and truth resp. + + """ + assert len(lonlatbox) == 4 + + lat_reg_b = np.arange(lonlatbox[0], lonlatbox[2], fcst_spat_res) - fcst_spat_res / 2 + lat_reg = 0.5 * (lat_reg_b[1:] + lat_reg_b[:-1]) + + lon_reg_b = np.arange(lonlatbox[1], lonlatbox[3], fcst_spat_res) - fcst_spat_res / 2 + lon_reg = 0.5 * (lon_reg_b[1:] + lon_reg_b[:-1]) + + data_path = glob.glob(TRUTH_PATH + "*.nc") + + ds = xr.open_mfdataset(data_path[0]) + # print(ds) + ##infer spatial resolution of truth, we assume a standard lon lat grid! + + lat_reg_TRUTH = ds.latitude.values + lon_reg_TRUTH = ds.longitude.values + + TRUTH_RES = np.abs(lat_reg_TRUTH[1] - lat_reg_TRUTH[0]) + + lat_reg_TRUTH_b = np.append( + (lat_reg_TRUTH - TRUTH_RES / 2), lat_reg_TRUTH[-1] + TRUTH_RES / 2 + ) + lon_reg_TRUTH_b = np.append( + (lon_reg_TRUTH - TRUTH_RES / 2), lon_reg_TRUTH[-1] + TRUTH_RES / 2 + ) + + return ( + lon_reg, + lat_reg, + lon_reg_b, + lat_reg_b, + lon_reg_TRUTH, + lat_reg_TRUTH, + lon_reg_TRUTH_b, + lat_reg_TRUTH_b, + ) + + +def get_IMERG_lonlat(): + + # A single IMERG data file to get latitude and longitude + IMERG_file_name = "/network/group/aopp/predict/TIP021_MCRAECOOPER_IFS/IMERG_V07/2018/Jan/3B-HHR.MS.MRG.3IMERG.20180116-S120000-E122959.0720.V07B.HDF5" + + # HDF5 in the ICPAC region + h5_file = h5py.File(IMERG_file_name) + latitude = h5_file["Grid"]["lat"][763:1147] + longitude = h5_file["Grid"]["lon"][1991:2343] + h5_file.close() + + return latitude, longitude + + +def prepare_year_and_month_input(years, months): + + if not isinstance(years, list): + assert isinstance(years, int) + years = [years, years] + if not isinstance(months, list): + assert isinstance(months, int) + months = [months, months] + + assert len(years) > 1 + assert len(months) > 1 + + years = np.sort(years) + months = np.sort(months) + + assert years[-1] >= years[0] + assert months[-1] >= months[0] + + year_beg = years[0] + year_end = years[-1] + + month_beg = months[0] + month_end = months[-1] + + if month_end == 12: + year_end += 1 + else: + month_end += 1 + + return year_beg, year_end, month_beg, month_end + + +def retrieve_vars_ifs(field, all_data_mean, all_data_sd, start=1, end=2): + + if field in ["tp", "cp", "ssr"]: + # return mean, sd, 0, 0. zero fields are so + # that each field returns a 4 x ny x nx array. + # accumulated fields have been pre-processed + # s.t. 
data[:, j, :, :] has accumulation between times j and j+1 + data1 = all_data_mean[:, start:end, :, :].reshape( + -1, all_data_mean.shape[2], all_data_mean.shape[3] + ) + data2 = all_data_sd[:, start:end, :, :].reshape( + -1, all_data_sd.shape[2], all_data_sd.shape[3] + ) + data3 = np.zeros(data1.shape) + data = np.stack([data1, data2, data3, data3], axis=-1)[:, None, :, :, :] + + else: + temp_data_mean_start = all_data_mean[:, start:end, :, :].reshape( + -1, all_data_mean.shape[2], all_data_mean.shape[3] + ) + temp_data_mean_end = all_data_mean[:, end : end + 1, :, :].reshape( + -1, all_data_mean.shape[2], all_data_mean.shape[3] + ) + temp_data_sd_start = all_data_sd[:, start:end, :, :].reshape( + -1, all_data_sd.shape[2], all_data_sd.shape[3] + ) + temp_data_sd_end = all_data_sd[:, end : end + 1, :, :].reshape( + -1, all_data_sd.shape[2], all_data_sd.shape[3] + ) + + data = np.stack( + [ + temp_data_mean_start, + temp_data_sd_start, + temp_data_mean_end, + temp_data_sd_end, + ], + axis=-1, + )[:, None, :, :, :] + + return data + + +def streamline_and_normalise_ifs( + field, + da, + log_prec=True, + norm=True, + time_idx=None, + split_steps=[5, 6, 7, 8, 9], +): + """ + Streamline IFS date to: + * Have appropriate valid time from time of forecast initialization + * If time_idx are provided then we directly select based on that + * Otherwise we select based on split_steps (default 30 - 54 hour lead time) + + Inputs + ------ + + field: str + field to select, needed to check accumulated or non- + negative field + da: xr.DataArray or xr.Dataset + data over which to streamline and normalise + log_prec: boolean + whether to calculate the log of precipitation, + default=True. + norm: boolean + whether to normalise or not, default = True + split_steps: list or 1-D array + valid_time steps to iterate over + default=[5,6,7,8,9] + time_idx: 1-D array or None + instead of split-steps if we have a more randomised + selection of valid time to operate on for each + initialisation time available + + Outputs + ------- + + xr.DataArray or xr.Dataset of streamline and normalised values + + NOTE: We replace the time with the valid time NOT initial fcst + time. 
+ + """ + + all_data_mean = da[f"{field}_mean"].values + all_data_sd = da[f"{field}_sd"].values + + da.close() + + if time_idx is None: + times = np.hstack( + ( + [ + time[split_steps[0] : split_steps[-1]] + for time in da.fcst_valid_time.values + ] + ) + ) + else: + if time_idx.shape[0] % 4 == 0: + time_idx = time_idx.reshape(-1, 4) + assert da.fcst_valid_time.values.shape[0] == time_idx.shape[0] + times = np.hstack( + ( + [ + time[[idx_t]] + for time, idx_t in zip(da.fcst_valid_time.values, time_idx) + ] + ) + ) + + data = [] + + if time_idx is None: + for start, end in zip(split_steps[:4], split_steps[1:5]): + + data.append( + retrieve_vars_ifs( + field, all_data_mean, all_data_sd, start=start, end=end + ) + ) + else: + + for i_row, start in enumerate(time_idx): + if isinstance(start, np.ndarray): + for s in start: + data.append( + retrieve_vars_ifs( + field, + all_data_mean[[i_row]], + all_data_sd[[i_row]], + start=s, + end=s + 1, + ) + ) + else: + data.append( + retrieve_vars_ifs( + field, + all_data_mean[[i_row]], + all_data_sd[[i_row]], + start=start, + end=start + 1, + ) + ) + + data = np.hstack((data)).reshape(-1, da.latitude.shape[0], da.longitude.shape[0], 4) + da = xr.DataArray( + data=data, + dims=["time", "lat", "lon", "i_x"], + coords=dict( + lon=da.longitude.values, + lat=da.latitude.values, + time=times.flatten(), + i_x=np.arange(4), + ), + ) + + if field in [ + "cape", + "cp", + "mcc", + "sp", + "ssr", + "t2m", + "tciw", + "tclw", + "tcrw", + "tcw", + "tcwv", + "tp", + ]: + da = nonnegative(da) + + da = convert_units(da, field, log_prec, m_to_mm=True) + + if norm: + da = get_norm(da, field) + + return da.where(np.isfinite(da), 0).sortby("time") + + +def get_IMERG_year( + years, + months=[3, 4, 5, 6], + centre=[-1.25, 36.80], + window_size=30, + clip_to_window=False, +): + + year_beg, year_end, month_beg, month_end = prepare_year_and_month_input( + years, months + ) + + latitude, longitude = get_IMERG_lonlat() + + # Load the IMERG data averaged over 6h periods + d = datetime.datetime(year_beg, month_beg, 1, 6) + + if year_end != year_beg: + d_end = datetime.datetime(year_end, 1, 1, 6) + times_xr = np.arange( + "%s-%s-01" % (str(year_beg), str(month_beg).zfill(2)), + "%s-%s-01" % (str(year_end), str(1).zfill(2)), + np.timedelta64(30, "m"), + dtype="datetime64[ns]", + ) + else: + d_end = datetime.datetime(year_end, month_end, 1, 6) + times_xr = np.arange( + "%s-%s-01" % (str(year_beg), str(month_beg).zfill(2)), + "%s-%s-01" % (str(year_end), str(month_end).zfill(2)), + np.timedelta64(30, "m"), + dtype="datetime64[ns]", + ) + # Number of 30 minutes rainfall periods + num_time_pts = (d_end - d).days * 48 + + # The 6h average rainfall + rain_IMERG = np.full([num_time_pts, len(longitude), len(latitude)], np.nan) + + start_time = time.time() + + time_idx = 0 + progbar = tqdm(total=int((d_end - d).days) * 2 * 24) + while d < d_end: + + if d.month not in np.arange(month_beg, month_end): + progbar.update(1) + # Move to the next timesetp + d += datetime.timedelta(minutes=30) + time_idx += 1 + continue + + # Load an IMERG file with the current date + d2 = d + datetime.timedelta(seconds=30 * 60 - 1) + # Number of minutes since 00:00 + count = int((d - datetime.datetime(d.year, d.month, d.day)).seconds / 60) + IMERG_file_name = ( + "/network/group/aopp/predict/TIP021_MCRAECOOPER_IFS/IMERG_V07" + + "/%s/%s/" % (str(d.year), str(d.strftime("%b"))) + + f"3B-HHR.MS.MRG.3IMERG.{d.year}{d.month:02d}{d.day:02d}-S{d.hour:02d}{d.minute:02d}00-" + + 
f"E{d2.hour:02d}{d2.minute:02d}{d2.second:02d}.{count:04d}.V07B.HDF5" + ) + + h5_file = h5py.File(IMERG_file_name) + times = h5_file["Grid"]["time"][:] + + # Check the time is correct + # if (d != datetime(1970,1,1) + timedelta(seconds=int(times[0]))): + # print(f"Incorrect time for {d}", datetime(1970,1,1) + timedelta(seconds=int(times[0]))) + + # Accumulate the rainfall + rain_IMERG[time_idx, :, :] = h5_file["Grid"]["precipitation"][ + 0, 1991:2343, 763:1147 + ] + h5_file.close() + + # Move to the next timesetp + d += datetime.timedelta(minutes=30) + + # Move to the next time index + time_idx += 1 + progbar.update(1) + + progbar.close() + + # Put into the same order as the IFS and cGAN data + rain_IMERG = np.moveaxis(rain_IMERG, [0, 1, 2], [0, 2, 1]) + + obs = xr.DataArray( + data=rain_IMERG.reshape(-1, len(latitude), len(longitude)), + dims=["time", "lat", "lon"], + coords={ + "time": times_xr, + "lat": latitude, + "lon": longitude, + }, + attrs=dict(description="IMERG 30 min precipitation", units="mm"), + ).rename("precipitation") + if clip_to_window: + obs = obs.sel( + { + "lat": slice(centre[0] - window_size, centre[0] + window_size), + "lon": slice(centre[1] - window_size, centre[1] + window_size), + } + ) + print( + "Finished loading in IMERG data in ----%.2f s-----" % (time.time() - start_time) + ) + + return obs.dropna("time", how="all") + + +def load_truth_and_mask(dates, time_idx=[5, 6, 7, 8], log_precip=True, normalise=True): + """ + Returns a single (truth, mask) item of data. + Parameters: + date: forecast start date + time_idx: forecast 'valid time' array index + log_precip: whether to apply log10(1+x) transformation + """ + ds_to_concat = [] + + for date in tqdm(dates): + date = str(date).split("T")[0].replace("-", "") + + # convert date and time_idx to get the correct truth file + fcst_date = datetime.datetime.strptime(date, "%Y%m%d") + + for idx_t in time_idx: + valid_dt = fcst_date + datetime.timedelta( + hours=int(idx_t) * time_res + ) # needs to change for 12Z forecasts + + fname = valid_dt.strftime("%Y%m%d_%H") + # print(fname) + + data_path = glob.glob(TRUTH_PATH + f"{date[:4]}/{fname}.nc") + if len(data_path) < 1: + break + # ds = xr.concat([xr.open_dataset(dataset).expand_dims(dim={'time':i}, + # axis=0) + # for i,dataset in enumerate(data_path)],dim='time').mean('time') + ds = xr.open_dataset(data_path[0]) + if log_precip: + ds["precipitation"] = logprec(ds["precipitation"]) + + # mask: False for valid truth data, True for invalid truth data + # (compatible with the NumPy masked array functionality) + # if all data is valid: + mask = ~np.isfinite(ds["precipitation"]) + ds["mask"] = mask + + ds_to_concat.append(ds) + + return xr.concat(ds_to_concat, dim="time") + + +def load_hires_constants(batch_size=1): + """ + + Get elevation and land sea mask on IMERG resolution + + """ + + oro_path = CONSTANTS_PATH + "elev.nc" + + lsm_path = CONSTANTS_PATH + "lsm.nc" + + ds = xr.open_mfdataset([lsm_path, oro_path]) + + # LSM is already 0:1 + ds["elevation"] = ds["elevation"] / 10000.0 + + return ds.expand_dims(dim={"time": batch_size}).compute() diff --git a/xarray_batcher/.ipynb_checkpoints/normalise-checkpoint.py b/xarray_batcher/.ipynb_checkpoints/normalise-checkpoint.py new file mode 100644 index 0000000..84a594e --- /dev/null +++ b/xarray_batcher/.ipynb_checkpoints/normalise-checkpoint.py @@ -0,0 +1,128 @@ +## Normalisation functions, note to self you can apply universal functions from numpy as scipy that operate element wise on xarray +## not too many function comments 
as I feel like they are self-explanatory + +import numpy as np + +from .utils import get_metadata, load_fcst_norm + +## Unfortunately need to have this look up table, not sure what a work around is +precip_fields = ["Convective precipitation (water)", "Total Precipitation", "cp", "tp"] + +accumulated_fields = ["ssr"] + +## Normalisation to apply !!! make sure a field doesn't appear twice!!! +standard_scaling = ["Surface pressure", "2 metre temperature", "sp", "t2m"] +maximum_scaling = [ + "Convective available potential energy", + "Upward short-wave radiation flux", + "Downward short-wave radiation flux", + "Cloud water", + "Precipitable water", + "Ice water mixing ratio", + "Cloud mixing ratio", + "Rain mixing ratio", + "cape", + "ssr", + "tciw", + "tclw", + "tcrw", + "tcw", + "tcwv", +] +absminimum_maximum_scaling = [ + "U component of wind", + "V component of wind", + "u700", + "v700", +] + +fcst_norm = load_fcst_norm() +## get some standard stuff from utils +fcst_time_res, time_res, lonlatbox, fcst_spat_res = get_metadata() + + +def logprec(data, threshold=0.1, fill_value=0.02, mean=0.051, std=0.266): + log_scale = np.log10(1e-1 + data).astype(np.float32) + if threshold is not None: + log_scale.where(log_scale > np.log10(threshold), np.log10(fill_value)) + + log_scale.fillna(np.log10(fill_value)) + log_scale -= mean + log_scale /= std + return log_scale + + +def nonnegative(data): + return np.maximum(data, 0.0) # eliminate any data weirdness/regridding issues + + +def m_to_mm_per_hour(data, time_res): + data *= 1000 + return data / time_res # convert to mm/hr + + +def to_per_second(data, time_res): + # for all other accumulated fields [just ssr for us] + return data / ( + time_res * 3600 + ) # convert from a 6-hr difference to a per-second rate + + +def centre_at_mean(data, field): + # these are bounded well away from zero, so subtract mean from ens mean (but NOT from ens sd!) 
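+ # channel layout from retrieve_vars_ifs (loading.py) is [mean_start, sd_start, mean_end, sd_end];
+ # get_norm below applies this centring only to the ensemble-mean channels (i_x 0 and 2)
+ # before dividing every channel by the stored std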
+ return data - fcst_norm[field]["mean"] + + +def change_to_unit_std(data, field): + return data / fcst_norm[field]["std"] + + +def max_scaling(data, field): + return (data) / (fcst_norm[field]["max"]) + + +def absmin_max_scaling(data, field): + return data / max(-fcst_norm[field]["min"], fcst_norm[field]["max"]) + + +def convert_units(data, field, log_prec, m_to_mm=True): + if field in precip_fields: + + if m_to_mm: + data = m_to_mm_per_hour(data, time_res) + + if log_prec: + return logprec(data) + + else: + return data + + elif field in accumulated_fields: + data = to_per_second(data, time_res) + + return data + + else: + return data + + +def get_norm(data, field, location_of_vals=[0, 2]): + + if field in precip_fields: + return data + + if field in standard_scaling: + data.loc[{"i_x": location_of_vals}] = centre_at_mean( + data.sel({"i_x": location_of_vals}), field + ) + + return change_to_unit_std(data, field) + + if field in maximum_scaling: + return max_scaling(data, field) + + if field in absminimum_maximum_scaling: + return absmin_max_scaling(data, field) + + else: + return data diff --git a/xarray_batcher/.ipynb_checkpoints/setup_data-checkpoint.py b/xarray_batcher/.ipynb_checkpoints/setup_data-checkpoint.py new file mode 100644 index 0000000..4009919 --- /dev/null +++ b/xarray_batcher/.ipynb_checkpoints/setup_data-checkpoint.py @@ -0,0 +1,100 @@ +import copy +import sys + +import numpy as np +import pytorch_lightning as pl +from torch.utils.data import DataLoader +from torchvision import transforms + +from .custom_collate_fn import CustomCollateFnGen +from .loading import get_IMERG_year +from .torch_batcher import BatchTruth + + +class DataModule(pl.LightningDataModule): + def __init__( + self, + train_years=2018, + val_years=2018, + test_years=None, + xbatch_size=[None, 64, 64], + batch_size=8, + train_epoch_size=1000, + valid_epoch_size=200, + test_epoch_size=1000, + ): + super().__init__() + + self.datasets = {} + self.batch_size = batch_size + + mult = 1 + if train_years is not None: + + df_truth = get_IMERG_year(train_years, months=2) + self.datasets["train"] = BatchTruth( + df_truth, + batch_size=xbatch_size, + antialiasing=True, + transform=transforms.RandomVerticalFlip(p=0.5), + for_NJ=True, + length=240, + ) + + if val_years is not None: + + df_truth = get_IMERG_year(val_years, months=2) + self.datasets["valid"] = BatchTruth( + df_truth, + batch_size=[xbatch_size[0], 300, 300], + weighted_sampler=False, + antialiasing=True, + transform=transforms.RandomVerticalFlip(p=0.5), + for_NJ=True, + length=240, + ) + + if test_years is not None: + df_truth = get_IMERG_year(val_years, months=2) + self.datasets["test"] = BatchTruth( + df_truth, + batch_size=[xbatch_size[0], 300, 300], + weighted_sampler=False, + antialiasing=True, + for_NJ=True, + length=240, + ) + + else: + self.datasets["test"] = self.datasets["valid"] + + def dataloader(self, split): + collate_fn, mult = CustomCollateFnGen(None) + if split == "train": + + return DataLoader( + self.datasets[split], + batch_size=self.batch_size, + collate_fn=collate_fn, + pin_memory=True, + num_workers=0, + sampler=self.datasets[split].sampler, + drop_last=True, + ) + else: + return DataLoader( + self.datasets[split], + collate_fn=collate_fn, + pin_memory=True, + num_workers=0, + drop_last=True, + ) + + def train_dataloader(self): + return self.dataloader("train") + + def val_dataloader(self): + return self.dataloader("valid") + + def test_dataloader(self): + return self.dataloader("test") diff --git 
a/xarray_batcher/.ipynb_checkpoints/torch_batcher-checkpoint.py b/xarray_batcher/.ipynb_checkpoints/torch_batcher-checkpoint.py new file mode 100644 index 0000000..fbc3628 --- /dev/null +++ b/xarray_batcher/.ipynb_checkpoints/torch_batcher-checkpoint.py @@ -0,0 +1,351 @@ +import dask +import numpy as np +import torch +import xbatcher +from scipy.spatial import KDTree +from tqdm import tqdm +from tqdm.dask import TqdmCallback + +from .batch_helper_functions import Antialiasing, get_spherical + + +class BatchDataset(torch.utils.data.Dataset): + + """ + class for iterating over a dataset + """ + + def __init__( + self, + X, + y, + constants, + batch_size: list[int] = [4, 128, 128], + weighted_sampler: bool = True, + for_NJ: bool = False, + for_val: bool = False, + antialiasing: bool = False, + ): + self.batch_size = batch_size + self.X_generator = X + self.y_generator = xbatcher.BatchGenerator( + y, + {"time": batch_size[0], "lat": batch_size[1], "lon": batch_size[2]}, + input_overlap={ + "lat": int(batch_size[1] / 32), + "lon": int(batch_size[2] / 32), + }, + ) + constants["lat"] = np.round(y.lat.values, decimals=2) + constants["lon"] = np.round(y.lon.values, decimals=2) + + self.constants_generator = constants + + self.variables = [list(x.data_vars)[0] for x in X] + self.constants = list(constants.data_vars) + self.for_NJ = for_NJ + self.for_val = for_val + self.antialiasing = antialiasing + + if weighted_sampler: + y_train = [ + self.y_generator[i].precipitation.mean( + ["time", "lat", "lon"], skipna=False + ) + for i in range(len(self.y_generator)) + ] + + rounded_y_train = np.round(y_train, decimals=1) + unique_classes = np.unique(rounded_y_train) + class_sample_count = np.bincount( + np.digitize(rounded_y_train, unique_classes) - 1 + ) + weight = 1.0 / class_sample_count + samples_weight = weight[np.digitize(rounded_y_train, unique_classes) - 1] + + self.samples_weight = torch.from_numpy(np.asarray(samples_weight)) + self.sampler = torch.utils.data.WeightedRandomSampler( + self.samples_weight.type("torch.DoubleTensor"), len(samples_weight) + ) + + def __len__(self) -> int: + return len(self.y_generator) + + def __getitem__(self, idx): + + y_batch = self.y_generator[idx] + time_batch = y_batch.time.values + lat_batch = np.round(y_batch.lat.values, decimals=2) + lon_batch = np.round(y_batch.lon.values, decimals=2) + + X_batch = [] + for x, variable in zip(self.X_generator, self.variables): + X_batch.append( + x[variable] + .sel({"time": time_batch, "lat": lat_batch, "lon": lon_batch}) + .values + ) + + X_batch = torch.from_numpy( + np.concatenate( + X_batch, + axis=-1, + ) + ).float() + + constant_batch = torch.from_numpy( + np.stack( + [ + self.constants_generator[constant] + .sel({"lat": lat_batch, "lon": lon_batch}) + .values + for constant in self.constants + ], + axis=-1, + ) + ).float() + + if self.for_NJ: + + elev_values = np.squeeze(constant_batch[:, :, 0]).reshape(-1, 1) + lat_values, lon_values = np.meshgrid(lat_batch, lon_batch) + spherical_coords = get_spherical( + lat_values.reshape(-1, 1), lon_values.reshape(-1, 1), elev_values + ) + + kdtree = KDTree(spherical_coords) + + pairs = [] + + for i_coord, coord in enumerate(spherical_coords): + pairs.append( + np.vstack( + ( + np.full(3, fill_value=i_coord).reshape(1, -1), + kdtree.query(coord, k=3)[1], + ) + ) + ) + + pairs = np.hstack((pairs)) + + rainfall_path = torch.cat( + ( + torch.from_numpy( + y_batch.precipitation.fillna(0).values.reshape( + self.batch_size[0], -1, 1 + ) + ).float(), + X_batch.reshape(self.batch_size[0], 
-1, len(self.variables) * 4), + ), + dim=-1, + ) + obs_dates = np.ones(self.batch_size[0]).reshape(1, -1) + n_obs = np.array([self.batch_size[0]]) + if self.for_val: + obs_dates = np.zeros(self.batch_size[0]).reshape(1, -1) + n_obs = np.random.randint(1, self.batch_size[0] - 8, 1) + obs_dates[: n_obs[0]] = 1 + + return { + "idx": idx, + "rainfall_path": rainfall_path[None, :, :, :], + "observed_dates": obs_dates, + "nb_obs": n_obs, + "dt": 1, + "edge_indices": pairs, + "obs_noise": None, + } + + else: + + if self.antialiasing: + antialiaser = Antialiasing() + y_batch = y_batch.precipitation.fillna(np.log10(0.02)).values + y_batch = antialiaser(y_batch) + y_batch = torch.from_numpy(np.moveaxis(y_batch, 0, -1)).float() + + else: + y_batch = torch.from_numpy( + y_batch.precipitation.fillna(np.log10(0.02)).values[:, :, :, None] + ).float() + return (torch.cat((X_batch, constant_batch), dim=-1), y_batch) + + +class BatchTruth(torch.utils.data.Dataset): + + """ + class for iterating over a dataset + """ + + def __init__( + self, + y, + batch_size=[4, 128, 128], + weighted_sampler=True, + for_NJ=False, + for_val=False, + length=None, + antialiasing=False, + transform=None, + return_dataset=False, + ): + + self.batch_size = batch_size + self.for_NJ = for_NJ + self.for_val = for_val + self.length = length + self.antialiasing = antialiasing + self.transform = transform + self.return_dataset = return_dataset + overlap = ( + {"latitude": int(batch_size[1] - 8), "longitude": int(batch_size[2] - 8)} + if for_NJ + else {"lat": int(batch_size[1] // 8), "lon": int(batch_size[2] // 8)} + ) + self.y_generator = xbatcher.BatchGenerator( + y, + { + "time": batch_size[0], + "latitude" if for_NJ else "lat": batch_size[1], + "longitude" if for_NJ else "lon": batch_size[2], + }, + input_overlap=overlap, + ) + + if weighted_sampler: + if self.for_NJ: + y_train = [ + self.y_generator[i].mean( + ["time", "latitude", "longitude"], skipna=False + ) + for i in range(len(self.y_generator)) + ] + else: + y_train = [ + self.y_generator[i].precipitation.mean( + ["time", "lat", "lon"], skipna=False + ) + for i in range(len(self.y_generator)) + ] + rounded_y_train = np.round(y_train, decimals=1) + unique_classes = np.unique(rounded_y_train) + class_sample_count = np.bincount( + np.digitize(rounded_y_train, unique_classes) - 1 + ) + weight = 1.0 / class_sample_count + samples_weight = weight[np.digitize(rounded_y_train, unique_classes) - 1] + + self.samples_weight = torch.from_numpy(np.asarray(samples_weight)) + self.sampler = torch.utils.data.WeightedRandomSampler( + self.samples_weight.type("torch.DoubleTensor"), len(samples_weight) + ) + + def __len__(self) -> int: + return len(self.y_generator) + + def __getitem__(self, idx): + + y_batch = self.y_generator[idx] + + if self.return_dataset: + return y_batch + + if self.for_NJ: + + def generate(y_batch, length, stop=None): + + rng = np.random.default_rng() + random_year = rng.choice(np.unique(y_batch["time.year"].values), 1)[0] + ds_sel = y_batch.sel( + { + "time": slice( + "%i-01-01" % random_year, "%i-01-01" % (random_year + 1) + ) + } + ) + + time_of_event = rng.choice(ds_sel.time.values[length:-length], 1)[0] + + time_to_event = rng.choice(np.arange(length), 1)[0] + time_after_event = length - time_to_event - 1 + + rainfall_path = ds_sel.sel( + { + "time": slice( + time_of_event - np.timedelta64(time_to_event * 30, "m"), + time_of_event + np.timedelta64(time_after_event * 30, "m"), + ), + } + ) + times_rainfall = rainfall_path.time.values + rainfall_path = 
rainfall_path.fillna(0).values[None, :, :, :] + + if stop is not None: + # limit observations to once a day + nb_obs_single = stop + obs_ptr = np.arange(1, nb_obs_single) + + else: + nb_obs_single = length + obs_ptr = np.arange(1, length) + + observed_date = np.zeros(rainfall_path.shape[1]) + observed_date[0] = 1 + + for i_obs in obs_ptr: + + observed_date[i_obs] = 1 + + return rainfall_path, observed_date, nb_obs_single + + rainfall_paths = [] + observed_dates = [] + n_obs = [] + + stop = None + batch_size = 50 + if self.for_val: + rng = np.random.default_rng() + stop = rng.choice(np.arange(2, self.length - 100), 1)[0] + batch_size = 2 + + for i in range(batch_size): + rainfall_path, observed_date, nb_obs = generate( + y_batch, self.length, stop=stop + ) + rainfall_paths.append(rainfall_path) + observed_dates.append(observed_date) + n_obs.append(nb_obs) + + rainfall_paths = np.vstack(rainfall_paths) + observed_dates = np.stack(observed_dates) + n_obs = np.asarray(n_obs) + + return { + "idx": idx, + "rainfall_path": torch.tensor( + rainfall_paths[:, :, :, :, None], dtype=torch.float32 + ), + "observed_dates": observed_dates, + "nb_obs": n_obs, + "dt": 1, + "obs_noise": None, + } + + else: + if self.antialiasing: + antialiaser = Antialiasing() + y_batch = y_batch.precipitation.fillna(np.log10(0.02)).values + y_batch = antialiaser(y_batch) + y_batch = torch.tensor(np.moveaxis(y_batch, 0, -1), dtype=torch.float32) + + else: + y_batch = torch.tensor( + y_batch.precipitation.fillna(np.log10(0.02)).values[:, :, :, None], + dtype=torch.float32, + ) + if self.transform: + y_batch = self.transform(y_batch) + + return y_batch diff --git a/xarray_batcher/.ipynb_checkpoints/torch_streamer-checkpoint.py b/xarray_batcher/.ipynb_checkpoints/torch_streamer-checkpoint.py new file mode 100644 index 0000000..618757e --- /dev/null +++ b/xarray_batcher/.ipynb_checkpoints/torch_streamer-checkpoint.py @@ -0,0 +1,380 @@ +import dask +import numpy as np +import torch +import xbatcher +from scipy.spatial import KDTree + +from xarray_batcher.get_fcst_and_truth import get_all + +from .batch_helper_functions import Antialiasing, get_spherical + + +class StreamDataset(torch.utils.data.IterableDataset): + + """ + Similar as BatchDataset, see torch_batcher.py apart + from the new workflow to assist in streaming: + + 1) Start using only truth data + 2) Calculate sampler + 3) When iterating through truth, load in the fcst. 
+ data on-the-fly + + """ + + def __init__( + self, + y, + variables, + constants, + batch_size: list[int] = [4, 128, 128], + batches_per_epoch=1200, + weighted_sampler: bool = True, + for_NJ: bool = False, + for_val: bool = False, + antialiasing: bool = False, + ): + self.batch_size = batch_size + self.batches_per_epoch = batches_per_epoch + self.variables = variables + self.y_generator = xbatcher.BatchGenerator( + y, + {"time": batch_size[0], "lat": batch_size[1], "lon": batch_size[2]}, + input_overlap={ + "lat": int(batch_size[1] / 32), + "lon": int(batch_size[2] / 32), + }, + ) + constants["lat"] = np.round(y.lat.values, decimals=2) + constants["lon"] = np.round(y.lon.values, decimals=2) + + self.constants_generator = constants + + self.constants = list(constants.data_vars) + self.for_NJ = for_NJ + self.for_val = for_val + self.antialiasing = antialiasing + + if weighted_sampler: + y_train = [ + self.y_generator[i].precipitation.mean( + ["time", "lat", "lon"], skipna=False + ) + for i in range(len(self.y_generator)) + ] + + rounded_y_train = np.round(y_train, decimals=0) + unique_classes = np.unique(rounded_y_train) + class_sample_count = np.bincount( + np.digitize(rounded_y_train, unique_classes) - 1 + ) + weight = 1.0 / class_sample_count + sample_weights = weight[np.digitize(rounded_y_train, unique_classes) - 1] + sample_weights = sample_weights / np.sum(sample_weights) + + self.sample_weights = torch.from_numpy(np.asarray(sample_weights)) + else: + self.sample_weights = None + + self.len = len(self.y_generator) + + def __len__(self): + return self.batches_per_epoch + + def __iter__(self): + self.idx = 0 + while self.idx <= self.__len__(): + try: + yield self.__sample__() + self.idx += 1 + except: + continue + + def __sample__(self): + + if self.sample_weights is None: + idx_samp = np.random.randint(0, self.len) + else: + idx_samp = int(np.random.choice(self.len, p=self.sample_weights)) + + y_batch = self.y_generator[idx_samp] + time_batch = y_batch.time.values + lat_batch = np.round(y_batch.lat.values, decimals=2) + lon_batch = np.round(y_batch.lon.values, decimals=2) + + X_generator = get_all( + None, + model="ifs", + truth_batch=y_batch, + stream=True, + offset=24, + variables=self.variables, + ) + + X_batch = [] + for x, variable in zip(X_generator, self.variables): + X_batch.append(x[variable].values) + + X_batch = torch.from_numpy( + np.concatenate( + X_batch, + axis=-1, + ) + ).float() + + constant_batch = torch.from_numpy( + np.stack( + [ + self.constants_generator[constant] + .sel({"lat": lat_batch, "lon": lon_batch}) + .values + for constant in self.constants + ], + axis=-1, + ) + ).float() + + if self.for_NJ: + + elev_values = np.squeeze(constant_batch[:, :, 0]).reshape(-1, 1) + lat_values, lon_values = np.meshgrid(lat_batch, lon_batch) + spherical_coords = get_spherical( + lat_values.reshape(-1, 1), lon_values.reshape(-1, 1), elev_values + ) + + kdtree = KDTree(spherical_coords) + + pairs = [] + + for i_coord, coord in enumerate(spherical_coords): + pairs.append( + np.vstack( + ( + np.full(3, fill_value=i_coord).reshape(1, -1), + kdtree.query(coord, k=3)[1], + ) + ) + ) + + pairs = np.hstack((pairs)) + + rainfall_path = torch.cat( + ( + torch.from_numpy( + y_batch.precipitation.fillna(0).values.reshape( + self.batch_size[0], -1, 1 + ) + ).float(), + X_batch.reshape(self.batch_size[0], -1, len(self.variables) * 4), + ), + dim=-1, + ) + obs_dates = np.ones(self.batch_size[0]).reshape(1, -1) + n_obs = np.array([self.batch_size[0]]) + if self.for_val: + obs_dates = 
np.zeros(self.batch_size[0]).reshape(1, -1) + n_obs = np.random.randint(1, self.batch_size[0] - 8, 1) + obs_dates[: n_obs[0]] = 1 + + return { + "idx": idx, + "rainfall_path": rainfall_path[None, :, :, :], + "observed_dates": obs_dates, + "nb_obs": n_obs, + "dt": 1, + "edge_indices": pairs, + "obs_noise": None, + } + + else: + + if self.antialiasing: + antialiaser = Antialiasing() + y_batch = y_batch.precipitation.fillna(np.log10(0.02)).values + y_batch = antialiaser(y_batch) + y_batch = torch.from_numpy(np.moveaxis(y_batch, 0, -1)).float() + + else: + y_batch = torch.from_numpy( + y_batch.precipitation.fillna(np.log10(0.02)).values[:, :, :, None] + ).float() + return (torch.cat((X_batch, constant_batch), dim=-1), y_batch) + + +class StreamTruth(torch.utils.data.Dataset): + + """ + class for iterating over a dataset + """ + + def __init__( + self, + y, + batch_size=[4, 128, 128], + weighted_sampler=True, + for_NJ=False, + for_val=False, + length=None, + antialiasing=False, + transform=None, + return_dataset=False, + ): + + self.batch_size = batch_size + self.for_NJ = for_NJ + self.for_val = for_val + self.length = length + self.antialiasing = antialiasing + self.transform = transform + self.return_dataset = return_dataset + overlap = ( + {"latitude": int(batch_size[1] - 8), "longitude": int(batch_size[2] - 8)} + if for_NJ + else {"lat": int(batch_size[1] // 8), "lon": int(batch_size[2] // 8)} + ) + self.y_generator = xbatcher.BatchGenerator( + y, + { + "time": batch_size[0], + "latitude" if for_NJ else "lat": batch_size[1], + "longitude" if for_NJ else "lon": batch_size[2], + }, + input_overlap=overlap, + ) + + if weighted_sampler: + if self.for_NJ: + y_train = [ + self.y_generator[i].mean( + ["time", "latitude", "longitude"], skipna=False + ) + for i in range(len(self.y_generator)) + ] + else: + y_train = [ + self.y_generator[i].precipitation.mean( + ["time", "lat", "lon"], skipna=False + ) + for i in range(len(self.y_generator)) + ] + rounded_y_train = np.round(y_train, decimals=1) + unique_classes = np.unique(rounded_y_train) + class_sample_count = np.bincount( + np.digitize(rounded_y_train, unique_classes) - 1 + ) + weight = 1.0 / class_sample_count + samples_weight = weight[np.digitize(rounded_y_train, unique_classes) - 1] + + self.samples_weight = torch.from_numpy(np.asarray(samples_weight)) + self.sampler = torch.utils.data.WeightedRandomSampler( + self.samples_weight.type("torch.DoubleTensor"), len(samples_weight) + ) + + def __len__(self) -> int: + return len(self.y_generator) + + def __getitem__(self, idx): + + y_batch = self.y_generator[idx] + + if self.return_dataset: + return y_batch + + if self.for_NJ: + + def generate(y_batch, length, stop=None): + + rng = np.random.default_rng() + random_year = rng.choice(np.unique(y_batch["time.year"].values), 1)[0] + ds_sel = y_batch.sel( + { + "time": slice( + "%i-01-01" % random_year, "%i-01-01" % (random_year + 1) + ) + } + ) + + time_of_event = rng.choice(ds_sel.time.values[length:-length], 1)[0] + + time_to_event = rng.choice(np.arange(length), 1)[0] + time_after_event = length - time_to_event - 1 + + rainfall_path = ds_sel.sel( + { + "time": slice( + time_of_event - np.timedelta64(time_to_event * 30, "m"), + time_of_event + np.timedelta64(time_after_event * 30, "m"), + ), + } + ) + times_rainfall = rainfall_path.time.values + rainfall_path = rainfall_path.fillna(0).values[None, :, :, :] + + if stop is not None: + # limit observations to once a day + nb_obs_single = stop + obs_ptr = np.arange(1, nb_obs_single) + + else: + 
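# no early stop: observe the whole sampled window, one observation per time step +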
nb_obs_single = length + obs_ptr = np.arange(1, length) + + observed_date = np.zeros(rainfall_path.shape[1]) + observed_date[0] = 1 + + for i_obs in obs_ptr: + + observed_date[i_obs] = 1 + + return rainfall_path, observed_date, nb_obs_single + + rainfall_paths = [] + observed_dates = [] + n_obs = [] + + stop = None + batch_size = 50 + if self.for_val: + rng = np.random.default_rng() + stop = rng.choice(np.arange(2, self.length - 100), 1)[0] + batch_size = 2 + + for i in range(batch_size): + rainfall_path, observed_date, nb_obs = generate( + y_batch, self.length, stop=stop + ) + rainfall_paths.append(rainfall_path) + observed_dates.append(observed_date) + n_obs.append(nb_obs) + + rainfall_paths = np.vstack(rainfall_paths) + observed_dates = np.stack(observed_dates) + n_obs = np.asarray(n_obs) + + return { + "idx": idx, + "rainfall_path": torch.tensor( + rainfall_paths[:, :, :, :, None], dtype=torch.float32 + ), + "observed_dates": observed_dates, + "nb_obs": n_obs, + "dt": 1, + "obs_noise": None, + } + + else: + if self.antialiasing: + antialiaser = Antialiasing() + y_batch = y_batch.precipitation.fillna(np.log10(0.02)).values + y_batch = antialiaser(y_batch) + y_batch = torch.tensor(np.moveaxis(y_batch, 0, -1), dtype=torch.float32) + + else: + y_batch = torch.tensor( + y_batch.precipitation.fillna(np.log10(0.02)).values[:, :, :, None], + dtype=torch.float32, + ) + if self.transform: + y_batch = self.transform(y_batch) + + return y_batch diff --git a/xarray_batcher/.ipynb_checkpoints/utils-checkpoint.py b/xarray_batcher/.ipynb_checkpoints/utils-checkpoint.py new file mode 100644 index 0000000..50ef143 --- /dev/null +++ b/xarray_batcher/.ipynb_checkpoints/utils-checkpoint.py @@ -0,0 +1,168 @@ +## Utils needed in loading zarr and batching +import datetime +import os +import pickle + +import numpy as np + +## Put all forecast fields, their levels (can be None also), and specify categories of accumulated fields +accumulated_fields = ["ssr", "cp", "tp"] + + +## Put other user-specification i.e., lon-lat box, spatial and temporal resolution (in hours) +TIME_RES = 6 + +LONLATBOX = [-14, 19, 25.25, 54.75] +FCST_SPAT_RES = 0.1 +FCST_TIME_RES = 3 + +## Put all directories here + +TRUTH_PATH = ( + "/network/group/aopp/predict/TIP021_MCRAECOOPER_IFS/IMERG_V07/ICPAC_region/6h/" +) +FCST_PATH_IFS = ( + "/network/group/aopp/predict/TIP021_MCRAECOOPER_IFS/IFS-regICPAC-meansd/" +) + +CONSTANTS_PATH = ( + "/network/group/aopp/predict/TIP022_NATH_GFSAIMOD/cGAN/constants-regICPAC/" +) + + +def get_metadata(): + + """ + Returns time resolution (in hours), lonlat box (bottom, left, top, right) and the forecast's spatial resolution + """ + + return FCST_TIME_RES, TIME_RES, LONLATBOX, FCST_SPAT_RES + + +def get_paths(): + + return FCST_PATH_IFS, TRUTH_PATH, CONSTANTS_PATH + + +import pickle + + +def load_fcst_norm(year=2018): + + fcstnorm_path = os.path.join( + CONSTANTS_PATH.replace("-regICPAC", "_IFS"), f"FCSTNorm{year}.pkl" + ) + + with open(fcstnorm_path, "rb") as f: + return pickle.load(f) + + +def daterange(start_date, end_date): + + """ + Generator to get date range for a given time period from start_date to end_date + """ + + for n in range(int((end_date - start_date).days)): + yield start_date + datetime.timedelta(days=n) + + +def get_valid_dates( + year, + TIME_RES=TIME_RES, + start_hour=30, + end_hour=60, + raw_list=False, +): + + """ + Returns list of valid forecast start dates for which 'truth' data + exists, given the other input parameters. 
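(Availability is checked against the truth files under TRUTH_PATH.)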
If truth data is not available + for certain days/hours, this will not be the full year. Dates are returned + as a list of YYYYMMDD strings. + + Parameters: + year (list): forecasts starting in this year + start_hour (int): Lead time of first forecast desired + end_hour (int): Lead time of last forecast desired + """ + + # sanity checks for our dataset + assert year in (2018, 2019, 2020, 2021, 2022, 2023, 2024) + assert start_hour >= 0 + assert start_hour % TIME_RES == 0 + assert end_hour % TIME_RES == 0 + assert end_hour > start_hour + + # Build "cache" of truth data dates/times that exist as well as forecasts + valid_dates = [] + + start_date = datetime.date(year, 1, 1) + end_date = datetime.date( + year + 1, 1, end_hour // TIME_RES + 2 + ) # go a bit into following year + + for curdate in daterange(start_date, end_date): + datestr = curdate.strftime("%Y%m%d") + valid = True + + ## then check for truth data at the desired lead time + for hr in np.arange(start_hour, end_hour, TIME_RES): + datestr_true = curdate + datetime.timedelta(hours=6) + datestr_true = datestr_true.strftime("%Y%m%d_%H") + fname = f"{datestr_true}" # {hr:02} + + if not os.path.exists( + os.path.join(TRUTH_PATH, f"{datestr_true[:4]}/{fname}.nc") + ): + valid = False + break + + if valid: + valid_dates.append(curdate) + + if raw_list: + # Need to get it from datetime to numpy readable format + valid_dates = [date.strftime("%Y-%m-%d") for date in valid_dates] + + return valid_dates + + +def match_fcst_to_valid_time(valid_times, time_idx, step_type="h"): + + """ + Inputs + ------ + valid_times: ndarray or datetime64 object + array of dates as data type datetime64[ns] + TIME_RES: integer + hourly timesteps which the fcst. is in + default = 6 + time_idx: int + array of prediction timedelta of same shape + as valid_dates, should be in hours, + data type = int. 
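+ (each entry is subtracted from valid_times to recover the forecast
+ initialisation time, and divided by TIME_RES to get the forecast step index)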
+ step_type: str + type of fcst step e.g., D for day + default: 'h' for hour + + Outputs + ------- + fcst_dates: ndarray + i.e., valid_dates-time_idx + valid_date_idx: ndarray + to select + """ + + if not isinstance(time_idx, list): + time_offset = np.timedelta64(time_idx, step_type) + valid_date_idx = np.asarray([int(time_offset.astype(int) / TIME_RES)]) + else: + time_offset = [np.timedelta64(t_idx, step_type) for t_idx in time_idx] + valid_date_idx = np.asarray( + [int(t_offset.astype(int) / TIME_RES) for t_offset in time_offset] + ) + + fcst_times = valid_times - time_offset + + return fcst_times, valid_date_idx diff --git a/xarray_batcher/get_fcst_and_truth.py b/xarray_batcher/get_fcst_and_truth.py index 2ccc62a..633d2f5 100644 --- a/xarray_batcher/get_fcst_and_truth.py +++ b/xarray_batcher/get_fcst_and_truth.py @@ -157,6 +157,7 @@ def modify( ds = streamline_and_normalise_ifs( name[1].split("_")[0], ds, time_idx=time_idx, split_steps=split_steps ).to_dataset(name=name[1].split("_")[0]) + return ds else: diff --git a/xarray_batcher/loading.py b/xarray_batcher/loading.py index 5a51311..704fbef 100644 --- a/xarray_batcher/loading.py +++ b/xarray_batcher/loading.py @@ -221,6 +221,8 @@ def streamline_and_normalise_ifs( all_data_mean = da[f"{field}_mean"].values all_data_sd = da[f"{field}_sd"].values + da.close() + if time_idx is None: times = np.hstack( ( diff --git a/xarray_batcher/torch_streamer.py b/xarray_batcher/torch_streamer.py index 91d864c..618757e 100644 --- a/xarray_batcher/torch_streamer.py +++ b/xarray_batcher/torch_streamer.py @@ -28,12 +28,14 @@ def __init__( variables, constants, batch_size: list[int] = [4, 128, 128], + batches_per_epoch=1200, weighted_sampler: bool = True, for_NJ: bool = False, for_val: bool = False, antialiasing: bool = False, ): self.batch_size = batch_size + self.batches_per_epoch = batches_per_epoch self.variables = variables self.y_generator = xbatcher.BatchGenerator( y, @@ -61,37 +63,41 @@ def __init__( for i in range(len(self.y_generator)) ] - rounded_y_train = np.round(y_train, decimals=1) + rounded_y_train = np.round(y_train, decimals=0) unique_classes = np.unique(rounded_y_train) class_sample_count = np.bincount( np.digitize(rounded_y_train, unique_classes) - 1 ) weight = 1.0 / class_sample_count - samples_weight = weight[np.digitize(rounded_y_train, unique_classes) - 1] - samples_weight = samples_weight / np.sum(samples_weight) + sample_weights = weight[np.digitize(rounded_y_train, unique_classes) - 1] + sample_weights = sample_weights / np.sum(sample_weights) - self.samples_weight = torch.from_numpy(np.asarray(samples_weight)) - self.sampler = torch.utils.data.WeightedRandomSampler( - self.samples_weight.type("torch.DoubleTensor"), len(samples_weight) - ) + self.sample_weights = torch.from_numpy(np.asarray(sample_weights)) else: - self.sampler = None + self.sample_weights = None self.len = len(self.y_generator) - def __iter__(self): + def __len__(self): + return self.batches_per_epoch - if self.sampler is None: - idx = np.random.randint(0, self.len) + def __iter__(self): + self.idx = 0 + while self.idx <= self.__len__(): + try: + yield self.__sample__() + self.idx += 1 + except: + continue + + def __sample__(self): + + if self.sample_weights is None: + idx_samp = np.random.randint(0, self.len) else: - idx = int(np.random.choice(self.len, p=self.samples_weight)) - - while True: + idx_samp = int(np.random.choice(self.len, p=self.sample_weights)) - yield self.__sample__(idx) - - def __sample__(self, idx): - y_batch = self.y_generator[idx] + 
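# fetch the truth batch for the sampled index; the matching IFS forecast is loaded on the fly below +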
y_batch = self.y_generator[idx_samp] time_batch = y_batch.time.values lat_batch = np.round(y_batch.lat.values, decimals=2) lon_batch = np.round(y_batch.lon.values, decimals=2) From 32bf1851457c21a5f5f0b032addc4d333977c79e Mon Sep 17 00:00:00 2001 From: Shruti Nath Date: Wed, 2 Jul 2025 10:14:42 +0100 Subject: [PATCH 4/4] remove cahced --- .../.ipynb_checkpoints/__init__-checkpoint.py | 8 - .../batch_helper_functions-checkpoint.py | 96 ---- .../create_npz-checkpoint.py | 95 ---- .../custom_collate_fn-checkpoint.py | 141 ----- .../get_fcst_and_truth-checkpoint.py | 488 ----------------- .../.ipynb_checkpoints/loading-checkpoint.py | 492 ------------------ .../normalise-checkpoint.py | 128 ----- .../setup_data-checkpoint.py | 100 ---- .../torch_batcher-checkpoint.py | 351 ------------- .../torch_streamer-checkpoint.py | 380 -------------- .../.ipynb_checkpoints/utils-checkpoint.py | 168 ------ 11 files changed, 2447 deletions(-) delete mode 100644 xarray_batcher/.ipynb_checkpoints/__init__-checkpoint.py delete mode 100644 xarray_batcher/.ipynb_checkpoints/batch_helper_functions-checkpoint.py delete mode 100644 xarray_batcher/.ipynb_checkpoints/create_npz-checkpoint.py delete mode 100644 xarray_batcher/.ipynb_checkpoints/custom_collate_fn-checkpoint.py delete mode 100644 xarray_batcher/.ipynb_checkpoints/get_fcst_and_truth-checkpoint.py delete mode 100644 xarray_batcher/.ipynb_checkpoints/loading-checkpoint.py delete mode 100644 xarray_batcher/.ipynb_checkpoints/normalise-checkpoint.py delete mode 100644 xarray_batcher/.ipynb_checkpoints/setup_data-checkpoint.py delete mode 100644 xarray_batcher/.ipynb_checkpoints/torch_batcher-checkpoint.py delete mode 100644 xarray_batcher/.ipynb_checkpoints/torch_streamer-checkpoint.py delete mode 100644 xarray_batcher/.ipynb_checkpoints/utils-checkpoint.py diff --git a/xarray_batcher/.ipynb_checkpoints/__init__-checkpoint.py b/xarray_batcher/.ipynb_checkpoints/__init__-checkpoint.py deleted file mode 100644 index 772e0e6..0000000 --- a/xarray_batcher/.ipynb_checkpoints/__init__-checkpoint.py +++ /dev/null @@ -1,8 +0,0 @@ -## Initialisation for xarray batcher, import all helper functions -import sys - -from .setup_data import DataModule -from .torch_batcher import BatchDataset, BatchTruth -from .torch_streamer import StreamDataset, StreamTruth - -__all__ = ["DataModule", "BatchDataset", "BatchTruth", "StreamDataset", "StreamTruth"] diff --git a/xarray_batcher/.ipynb_checkpoints/batch_helper_functions-checkpoint.py b/xarray_batcher/.ipynb_checkpoints/batch_helper_functions-checkpoint.py deleted file mode 100644 index efd7317..0000000 --- a/xarray_batcher/.ipynb_checkpoints/batch_helper_functions-checkpoint.py +++ /dev/null @@ -1,96 +0,0 @@ -import concurrent.futures -import multiprocessing - -import numpy as np -from scipy.ndimage import convolve - - -class Antialiasing: - def __init__(self): - (x, y) = np.mgrid[-2:3, -2:3] - self.kernel = np.exp(-0.5 * (x**2 + y**2) / (0.5**2)) - self.kernel /= self.kernel.sum() - self.edge_factors = {} - self.img_smooth = {} - num_threads = multiprocessing.cpu_count() - self.executor = concurrent.futures.ThreadPoolExecutor(num_threads) - - def __call__(self, img): - if img.ndim < 3: - img = img[None, None, :, :] - elif img.ndim < 4: - img = img[None, :, :, :] - img_shape = img.shape[-2:] - if img_shape not in self.edge_factors: - s = convolve( - np.ones(img_shape, dtype=np.float32), self.kernel, mode="constant" - ) - s = 1.0 / s - self.edge_factors[img_shape] = s - else: - s = self.edge_factors[img_shape] - - if img.shape 
not in self.img_smooth: - img_smooth = np.empty_like(img) - self.img_smooth[img_shape] = img_smooth - else: - img_smooth = self.img_smooth[img_shape] - - def _convolve_frame(i, j): - convolve( - img[i, j, :, :], - self.kernel, - mode="constant", - output=img_smooth[i, j, :, :], - ) - img_smooth[i, j, :, :] *= s - - futures = [] - for i in range(img.shape[0]): - for j in range(img.shape[1]): - args = (_convolve_frame, i, j) - futures.append(self.executor.submit(*args)) - concurrent.futures.wait(futures) - - return img_smooth - - -def get_spherical(lat, lon, elev, return_hstacked=True): - - """ - Get spherical coordinates of lat and lon, not assuming unit ball for radius - So we also take elev into account - - Inputs - ------ - - lat: np.array or xr.DataArray (n_lats,n_lons) - meshgrid of latitude points - - lon: np.array or xr.DataArray (n_lats,n_lons) - meshgrid of longitude points - - elev: np.array or xr.DataArray (n_lats,n_lons) - altitude values in m - - return_hstacked: boolean - typically for graph networks we collapse - lat lon into one dimension - Output - ------ - - r, sigma and phi - See: https://en.wikipedia.org/wiki/Spherical_coordinate_system - for more details - """ - - lat, lon = np.deg2rad(lat), np.deg2rad(lon) - - x = elev * np.cos(lat) * np.cos(lon) - y = elev * np.cos(lat) * np.sin(lon) - z = elev * np.sin(lat) - - if return_hstacked: - return np.hstack((x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1))) - else: - return np.dstack((x[:, :, None], y[:, :, None], z[:, :, None])) diff --git a/xarray_batcher/.ipynb_checkpoints/create_npz-checkpoint.py b/xarray_batcher/.ipynb_checkpoints/create_npz-checkpoint.py deleted file mode 100644 index ed35e38..0000000 --- a/xarray_batcher/.ipynb_checkpoints/create_npz-checkpoint.py +++ /dev/null @@ -1,95 +0,0 @@ -import os - -import numpy as np -import xarray as xr -from torch.utils.data import DataLoader - -from .batch_helper_functions import get_spherical -from .loading import get_IMERG_year -from .torch_batcher import BatchDataset, BatchTruth -from .utils import get_paths - -_, _, CONSTANTS_PATH = get_paths() -elev = xr.open_dataset(CONSTANTS_PATH + "elev.nc") - - -def collate_fn(batch, elev=elev, reg_dict={}): - - lat_batch = np.round(batch.lat.values, decimals=2) - lon_batch = np.round(batch.lon.values, decimals=2) - lat_values, lon_values = np.meshgrid(lat_batch, lon_batch) - elev_values = elev.sel({"lat": lat_batch, "lon": lon_batch}, method="nearest") - elev_values = np.squeeze(elev_values.elevation.values) / 10000.0 - spherical_coords = get_spherical( - lat_values, lon_values, elev_values, return_hstacked=False - ) - - i = 0 - reg_sel = None - for reg in reg_dict.keys(): - if np.array_equal(reg_dict[reg]["spherical_coords"], spherical_coords): - reg_sel = reg - break - i += 1 - - if reg_sel is None: - reg_sel = i - reg_dict[reg_sel] = { - "elevation": elev_values, - "spherical_coords": spherical_coords, - "precipitation": [], - "time": [], - } - reg_dict[reg_sel]["precipitation"].append(batch.precipitation.values) - reg_dict[reg_sel]["time"].append(batch.time.values) - - return reg_dict - - -def TruthDataloader_to_Npz( - out_path, - years=[2018, 2019, 2020, 2021, 2023, 2024], - centre=[-1.25, 36.80], - months=None, - window_size=3, - collate_fn=collate_fn, -): - - if not os.path.exists(out_path): - os.makedirs(out_path) - - if months is None: - months = np.arange(1, 13).tolist() - elev = xr.open_dataset(CONSTANTS_PATH + "elev.nc") - for year in years: - - ds = get_IMERG_year(year, months=months, centre=centre, 
window_size=window_size) - if not isinstance(ds, xr.Dataset): - ds = ds.to_dataset() - - # load in truth to batcher without any weighting in sampler - ds = BatchTruth( - ds, batch_size=[1, 128, 128], weighted_sampler=False, return_dataset=True - ) - reg_dict = {} - - for batch in ds: - reg_dict = collate_fn(batch, elev=elev, reg_dict=reg_dict) - - spherical_coords = np.stack( - [reg_dict[key]["spherical_coords"] for key in reg_dict.keys()] - ) - elevation = np.stack([reg_dict[key]["elevation"] for key in reg_dict.keys()]) - precipitation = np.stack( - [np.vstack(reg_dict[key]["precipitation"]) for key in reg_dict.keys()] - ) - time = np.stack([np.hstack(reg_dict[key]["time"]) for key in reg_dict.keys()]) - print(time.shape) - - np.savez( - out_path + f"{year}_30min_IMERG_Nairobi_windowsize={window_size}.npz", - spherical_coords=spherical_coords, - elevation=elevation, - precipitation=precipitation, - time=time, - ) diff --git a/xarray_batcher/.ipynb_checkpoints/custom_collate_fn-checkpoint.py b/xarray_batcher/.ipynb_checkpoints/custom_collate_fn-checkpoint.py deleted file mode 100644 index 934df67..0000000 --- a/xarray_batcher/.ipynb_checkpoints/custom_collate_fn-checkpoint.py +++ /dev/null @@ -1,141 +0,0 @@ -import numpy as np -import torch - - -def _get_func(name): - """ - transform a function given as str to a python function - :param name: str, correspond to a function, - supported: 'exp', 'power-x' (x the wanted power) - :return: numpy fuction - """ - if name in ["exp", "exponential"]: - return np.exp - if "power-" in name: - x = float(name.split("-")[1]) - - def pow(input): - return np.power(input, x) - - return pow - else: - try: - return eval(name) - except Exception: - return None - - -def _get_X_with_func_appl(X, functions, axis): - """ - apply a list of functions to the paths in X and append X by the outputs - along the given axis - :param X: np.array, with the data, - :param functions: list of functions to be applied - :param axis: int, the data_dimension (not batch and not time dim) along - which the new paths are appended - :return: np.array - """ - Y = X - for f in functions: - Y = np.concatenate([Y, f(X)], axis=axis) - return Y - - -def CustomCollateFnGen(func_names=None): - """ - a function to get the costume collate function that can be used in - torch.DataLoader with the wanted functions applied to the data as new - dimensions - -> the functions are applied on the fly to the dataset, and this additional - data doesn't have to be saved - - :param func_names: list of str, with all function names, see _get_func - :return: collate function, int (multiplication factor of dimension before - and after applying the functions) - """ - # get functions that should be applied to X, additionally to identity - functions = [] - if func_names is not None: - for func_name in func_names: - f = _get_func(func_name) - if f is not None: - functions.append(f) - mult = len(functions) + 1 - - def custom_collate_fn(batch): - dt = batch[0]["dt"] - stock_paths = np.concatenate([b["rainfall_path"] for b in batch], axis=0) - observed_dates = np.concatenate([b["observed_dates"] for b in batch], axis=0) - # edge_indices = np.concatenate([b['edge_indices'] for b in batch], axis=0) - obs_noise = None - if batch[0]["obs_noise"] is not None: - obs_noise = np.concatenate([b["obs_noise"] for b in batch], axis=0) - masked = False - mask = None - if len(observed_dates.shape) == 3: - masked = True - mask = observed_dates - observed_dates = observed_dates.max(axis=1) - nb_obs = 
torch.tensor(np.concatenate([b["nb_obs"] for b in batch], axis=0)) - - # here axis=1, since we have elements of dim - # [batch_size, data_dimension] => add as new data_dimensions - sp = stock_paths[:, 0] - if obs_noise is not None: - sp = stock_paths[:, :, 0] + obs_noise[:, :, 0] - start_X = torch.tensor( - _get_X_with_func_appl(sp, functions, axis=1), dtype=torch.float32 - ) - X = [] - if masked: - M = [] - start_M = torch.tensor(mask[:, :, 0], dtype=torch.float32).repeat((1, mult)) - else: - M = None - start_M = None - times = [] - time_ptr = [0] - obs_idx = [] - current_time = 0.0 - counter = 0 - for t in range(1, observed_dates.shape[-1]): - current_time += dt - if observed_dates[:, t].sum() > 0: - times.append(current_time) - for i in range(observed_dates.shape[0]): - if observed_dates[i, t] == 1: - counter += 1 - # here axis=0, since only 1 dim (the data_dimension), - # i.e. the batch-dim is cummulated outside together - # with the time dimension - sp = stock_paths[i, t] - if obs_noise is not None: - sp = stock_paths[i, :, t] + obs_noise[i, :, t] - X.append(_get_X_with_func_appl(sp, functions, axis=0)) - if masked: - M.append(np.tile(mask[i, :, t], reps=mult)) - obs_idx.append(i) - time_ptr.append(counter) - # if obs_noise is not None: - # print("noisy observations used") - - assert len(obs_idx) == observed_dates[:, 1:].sum() - if masked: - M = torch.tensor(np.array(M), dtype=torch.float32) - res = { - "times": np.array(times), - "time_ptr": np.array(time_ptr), - "obs_idx": torch.tensor(obs_idx, dtype=torch.long), - "start_X": start_X, - "n_obs_ot": nb_obs, - "X": torch.tensor(np.array(X), dtype=torch.float32).permute(3, 1, 2, 0), - "true_paths": stock_paths, - "observed_dates": observed_dates, - "true_mask": mask, - "obs_noise": obs_noise, #'edge_indices': torch.from_numpy(edge_indices).long().contiguous(), - "M": M, - "start_M": start_M, - } - return res - - return custom_collate_fn, mult diff --git a/xarray_batcher/.ipynb_checkpoints/get_fcst_and_truth-checkpoint.py b/xarray_batcher/.ipynb_checkpoints/get_fcst_and_truth-checkpoint.py deleted file mode 100644 index 633d2f5..0000000 --- a/xarray_batcher/.ipynb_checkpoints/get_fcst_and_truth-checkpoint.py +++ /dev/null @@ -1,488 +0,0 @@ -import glob -import time - -import dask -import numpy as np -import xarray as xr - -from .loading import ( - load_hires_constants, - load_truth_and_mask, - streamline_and_normalise_ifs, -) -from .utils import get_paths, get_valid_dates, match_fcst_to_valid_time - -FCST_PATH_IFS, TRUTH_PATH, CONSTANTS_PATH = get_paths() - - -def open_mfzarr( - file_names, - use_modify=False, - centre=[-1.25, 36.80], - window_size=2, - lats=None, - lons=None, - months=[3, 4, 5, 6], - dates=None, - time_idx=None, - split_steps=[5, 6, 7, 8, 9], - clip_to_window=True, -): - """ - Open multiple files using dask delayed - - Inputs - ------ - file_names: list - list of file names to open typically one - file for each variable - kwargs: to be passed on to modify - - Outputs - ------- - List of xr.DataArray or xr.Dataset we avoid concatenation - as this takes to long, but through modify, we are confident - that the times align. 
- """ - - # this is basically what open_mfdataset does - open_kwargs = dict(decode_cf=True, decode_times=True) - open_tasks = [dask.delayed(xr.open_dataset)(f, **open_kwargs) for f in file_names] - - tasks = [ - dask.delayed(modify)( - task, - use_modify=use_modify, - centre=centre, - window_size=window_size, - lats=lats, - lons=lons, - months=months, - dates=dates, - time_idx=time_idx, - split_steps=split_steps, - clip_to_window=clip_to_window, - ) - for task in open_tasks - ] - - datasets = dask.compute(tasks) # get a list of xarray.Datasets - combined = datasets[0] # or some combination of concat, merge - dates = [] - for dataset in combined: - if time_idx is None: - dates += list(np.unique(dataset.time.values.astype("datetime64[D]"))) - else: - dates += list(np.unique(dataset.time.values)) - - return combined, dates - - -def modify( - ds, - use_modify=False, - centre=[-1.25, 36.80], - window_size=2, - lats=None, - lons=None, - months=[3, 4, 5, 6], - dates=None, - time_idx=None, - split_steps=[5, 6, 7, 8, 9], - clip_to_window=True, -): - - """ - Modification function to wrap around dask delayed compute - - Inputs - ------ - - use_modify: boolean - whether to apply modification function at all. - centre: list or tuple - centre around which to select a region when looking at sub-domains - window_size: integer - window size around centre to use in sub-domain selection - lats: ndarray or list - alternatively, latitudes can be given to sub-select - lons: ndarray or list - Ditto as lats - months: list - months to select if we are doing seasonal/monthly training - dates: list or 1-D array - dates to sub-select, typically is all months are used but too - expensive to downlaod. If time_idx is None, we assume these are - the dates of forecast issue - time_idx: list or 1-D array - valid time indices to select in index form, not absolute - value - Outputs - ------- - - xr.Dataset or xr.DataArray with modifications if use_midify=True or simply - without modifications - - **Note: when time_idx is provided, split_steps is ignored** - """ - - if use_modify: - name = [var for var in ds.data_vars] - if lats is not None and lons is not None: - lat_batch = np.round(lats, decimals=2) - lon_batch = np.round(lons, decimals=2) - - ds = ds.sel(time=ds.time.dt.month.isin(months)).sel( - { - "latitude": lat_batch, - "longitude": lon_batch, - } - ) - elif clip_to_window: - - ds = ds.sel(time=ds.time.dt.month.isin(months)).sel( - { - "latitude": slice(centre[0] - window_size, centre[0] + window_size), - "longitude": slice( - centre[1] - window_size, centre[1] + window_size - ), - } - ) - if dates is not None: - _, dates_intersect, _ = np.intersect1d( - ds.time.values.astype("datetime64[D]"), dates, return_indices=True - ) - ds = ds.isel(time=dates_intersect) - - ds = streamline_and_normalise_ifs( - name[1].split("_")[0], ds, time_idx=time_idx, split_steps=split_steps - ).to_dataset(name=name[1].split("_")[0]) - - return ds - - else: - return ds - - -def stream_ifs(truth_batch, offset=24, variables=None): - """ - Input - ----- - truth_batch: xr.DataArray or xr.Dataset - truth values of a single batch item - to match and load - fcst data for - offset: int - day offset to factor in, should be in hours - - variables: list or None - variables to load, if None then all are - loaded. 
- Output - ------ - - forecast batch as ndarray - """ - - batch_time = truth_batch.time.values - batch_lats = truth_batch.lat.values - batch_lons = truth_batch.lon.values - - if not isinstance(batch_time, np.ndarray): - batch_time = np.asarray(batch_time) - - # Get hours in the truth_batch times object - # First need to convert to format with base units of hours to extract hour offset - hour = [time.hour for time in batch_time.astype("datetime64[h]").astype(object)] - - # Note that if hour is 0 then we add 24 as this - # is the offset+24 - hour = [h + 24 * (h == 0) + offset for h in hour] - - fcst_date, time_idx = match_fcst_to_valid_time(batch_time, hour) - - year = fcst_date.astype("datetime64[D]").astype(object)[0].year - month = fcst_date.astype("datetime64[D]").astype(object)[0].month - - if variables is None: - files = sorted(glob.glob(FCST_PATH_IFS + str(year) + "/" + "*.nc")) - else: - files = [FCST_PATH_IFS + str(year) + "/" + "%s.nc" % var for var in variables] - - ds, dates_modified = open_mfzarr( - files, - use_modify=True, - lats=batch_lats, - lons=batch_lons, - months=[month], - dates=fcst_date, - time_idx=time_idx, - clip_to_window=False, - ) - assert np.all(np.isin(dates_modified, batch_time)) - return ds - - -def get_whole_year_ifs( - years, - centre=[-1.25, 36.80], - window_size=30, - months=[3, 4, 5, 6], - n_days=None, - split_steps=[5, 6, 7, 8, 9], - ignore_truth=False, - variables=None, - clip_to_window=True, -): - dates_all = [] - - if ignore_truth: - dates_year = [ - list( - np.arange( - "%i-01-01" % year, - "%i-01-01" % (year + 1), - np.timedelta64(1, "D"), - dtype="datetime64[D]", - ).astype("str") - ) - for year in years - ] - else: - dates_year = [get_valid_dates(year, raw_list=True) for year in years] - - for dates in dates_year: - start_time = time.time() - - if len(months) == 12 and n_days is not None: - dates_sel = np.random.choice( - np.array(dates, dtype="datetime64[D]"), n_days, replace=False - ) - else: - dates_sel = None - - dates_final = np.array(dates.copy(), dtype="datetime64[D]") - year = dates_final[0].astype(object).year - - if variables is None: - files = sorted(glob.glob(FCST_PATH_IFS + str(year) + "/" + "*.nc")) - else: - files = [ - FCST_PATH_IFS + str(year) + "/" + "%s.nc" % var for var in variables - ] - - ds, dates_modified = open_mfzarr( - files, - use_modify=True, - centre=centre, - window_size=window_size, - months=months, - dates=dates_sel, - split_steps=split_steps, - clip_to_window=clip_to_window, - ) - - if year == years[0]: - ds_vars = ds - else: - ds_vars = [ - xr.concat([ds_1, ds_2], "time") for ds_1, ds_2 in zip(ds_vars, ds) - ] - - dates_final = np.append(dates_final, dates_modified, axis=0) - - del ds - - print( - "Extracted all %i variables in ----" % len(files), - time.time() - start_time, - "s---- for year", - year, - ) - - dates_final, dates_count = np.unique(dates_final, return_counts=True) - dates_idx = np.squeeze(np.argwhere(dates_count == (len(files) + 1))) - dates_final = dates_final[dates_idx] - dates_all += [str(date) for date in dates_final] - # print(len(dates)-len(dates_final)," missing dates in year", year) - - if ignore_truth: - return ds_vars - - print("Now doing truth values") - start_time = time.time() - # time_idx is hard-coded in here as forecast is made to have time as valid_time - ds_truth_and_mask = load_truth_and_mask( - np.array(dates_all, dtype="datetime64[ns]").flatten(), - time_idx=[1, 2, 3, 4], - ) - if dates_sel is not None: - # Because of 6 hour offset when streamlining select dates 6AM to 
midnight is used - # Meaning that the next day midnight is in truth but no other time step in that date. - # Therefore, need to guarantee alignment in times. - ds_truth_and_mask = ds_truth_and_mask.drop_duplicates("time") - times_sel = np.intersect1d( - ds_vars[0].time.values, ds_truth_and_mask.time.values - ) - ds_truth_and_mask = ds_truth_and_mask.sel({"time": times_sel}) - ds_constants = load_hires_constants(batch_size=1) - - print( - "Finished retrieving truth values in ----", - time.time() - start_time, - "s---- for year", - year, - ) - - return ( - ds_vars, - ds_truth_and_mask.rename({"latitude": "lat", "longitude": "lon"}) - .sel(time=ds_truth_and_mask.time.dt.month.isin(months)) - .sel( - { - "lat": slice(centre[0] - window_size, centre[0] + window_size), - "lon": slice(centre[1] - window_size, centre[1] + window_size), - } - ), - ds_constants.sel( - { - "lat": slice(centre[0] - window_size, centre[0] + window_size), - "lon": slice(centre[1] - window_size, centre[1] + window_size), - } - ), - ) - - -def get_all( - years, - model="ifs", - offset=None, - stream=False, - truth_batch=None, - time_idx=[5, 6, 7, 8], - split_steps=[5, 6, 7, 8, 9], - ignore_truth=False, - variables=None, - months=None, - n_days=10, - centre=[-1.25, 36.80], - window_size=30, - clip_to_window=False, - log_precip=True, -): - - """ - Wrapper function to return either: - - * IFS data alongside truth data fully loaded into memory. - This is recommended if an npz file is being created for - example. - * stream IFS data in which case truth_batch should be given - and offset is the no. days (in hours) lead time used to - obtain the initialisation time from truth batch valid time - (See stream_ifs for further details) - * Obtain only truth data (when model="truth") - - Inputs - ------ - years: list - list of years to calculate over - model: str - either truth or ifs (later additions possible) - default='ifs' - offset: int or None - Hour offset that accounts for no. days lead time - needed to match back to intiialisation time - during streaming. - stream: boolean - whether or not to stream data in which case - truth batch should be provided - truth_batch: xr.DataArray or xr.Dataset - truth batch to obtain forecast variables from - time_idx: list - time_idx to obtain truth - split_steps: list - fcst lead times to take when loading in IFS - ignore_truth: boolean - passed on to get_whole_year_ifs, whether to - only load in the fcst and ignore turht - variables: list or None - variables to load, if None then all are - loaded. 
- months: ndarray or list or None - months to load in in case of seasonal/sub-seasonal - training, if None all months are used - centre: list or tuple - centre around which to select a region when looking - at sub-domains, only used when clip_to_window is true - window_size: integer - window size around centre to use in sub-domain - selection, only used when clip_to_window is true - clip_to_window: boolean - passed on to get_whole_year_ifs or simply - getting truth, to signal when to clip to a - window (size window_size) around centre - log_precip: boolean - passed on to load_truth_and_mask when model is "truth - default=True - - Outputs: - ------- - list of xr.DataArray or streamed or whole IFS+truth or simply truth data - """ - - if months is None: - months = np.arange(1, 13).tolist() - - if model == "ifs": - if not stream: - return get_whole_year_ifs( - years, - split_steps=split_steps, - ignore_truth=ignore_truth, - variables=variables, - months=months, - n_days=n_days, - centre=centre, - window_size=window_size, - clip_to_window=clip_to_window, - ) - - else: - assert truth_batch is not None - assert offset is not None - return stream_ifs(truth_batch, offset=offset, variables=variables) - - elif model == "truth": - dates_all = [] - for year in years: - dates = get_valid_dates(year) - dates_all += [date.strftime("%Y-%m-%d") for date in dates] - - print("Only getting truth values over,", len(dates_all), "dates") - start_time = time.time() - ds_truth_and_mask = load_truth_and_mask( - np.array(dates_all, dtype="datetime64[ns]").flatten(), - time_idx=time_idx, - log_precip=log_precip, - ) - - print( - "Finished retrieving truth values in ----", - time.time() - start_time, - "s---- for years", - years, - ) - if clip_to_window: - ds_truth_and_mask = ds_truth_and_mask.sel( - { - "latitude": slice(centre[0] - window_size, centre[0] + window_size), - "longitude": slice( - centre[1] - window_size, centre[1] + window_size - ), - } - ) - - return ds_truth_and_mask.sel( - time=ds_truth_and_mask.time.dt.month.isin(months) - ).rename({"latitude": "lat", "longitude": "lon"}) diff --git a/xarray_batcher/.ipynb_checkpoints/loading-checkpoint.py b/xarray_batcher/.ipynb_checkpoints/loading-checkpoint.py deleted file mode 100644 index 704fbef..0000000 --- a/xarray_batcher/.ipynb_checkpoints/loading-checkpoint.py +++ /dev/null @@ -1,492 +0,0 @@ -import datetime -import glob -import time - -import h5py -import numpy as np -import xarray as xr -from tqdm import tqdm - -from .normalise import convert_units, get_norm, logprec, nonnegative -from .utils import get_metadata, get_paths - -(FCST_PATH_IFS, TRUTH_PATH, CONSTANTS_PATH) = get_paths() - -(fcst_time_res, time_res, lonlatbox, fcst_spat_res) = get_metadata() - - -def daterange(start_date, end_date): - """ - Generator to get date range for a given time period - from start_date to end_date - """ - - for n in range(int((end_date - start_date).days)): - yield start_date + datetime.timedelta(days=n) - - -def get_lonlat(): - """ - Function to get longitudes and latitudes of forecast and truth data - - Input - ------ - - lonlatbox: list of int with length 4 - bottom, left, top, right corners of lon-lat box - fcst_spat_res: float - spatial resolution of forecasts - - Output - ------ - - centres of forecast (lon_reg, lat_reg), - and truth (lon__reg_TRUTH, lat_reg_TRUTH), and their box - edges: (lon_reg_b, lat_reg_b) and (lon__reg_TRUTH, - lat_reg_TRUTH) for forecasts and truth resp. 
- - """ - assert len(lonlatbox) == 4 - - lat_reg_b = np.arange(lonlatbox[0], lonlatbox[2], fcst_spat_res) - fcst_spat_res / 2 - lat_reg = 0.5 * (lat_reg_b[1:] + lat_reg_b[:-1]) - - lon_reg_b = np.arange(lonlatbox[1], lonlatbox[3], fcst_spat_res) - fcst_spat_res / 2 - lon_reg = 0.5 * (lon_reg_b[1:] + lon_reg_b[:-1]) - - data_path = glob.glob(TRUTH_PATH + "*.nc") - - ds = xr.open_mfdataset(data_path[0]) - # print(ds) - ##infer spatial resolution of truth, we assume a standard lon lat grid! - - lat_reg_TRUTH = ds.latitude.values - lon_reg_TRUTH = ds.longitude.values - - TRUTH_RES = np.abs(lat_reg_TRUTH[1] - lat_reg_TRUTH[0]) - - lat_reg_TRUTH_b = np.append( - (lat_reg_TRUTH - TRUTH_RES / 2), lat_reg_TRUTH[-1] + TRUTH_RES / 2 - ) - lon_reg_TRUTH_b = np.append( - (lon_reg_TRUTH - TRUTH_RES / 2), lon_reg_TRUTH[-1] + TRUTH_RES / 2 - ) - - return ( - lon_reg, - lat_reg, - lon_reg_b, - lat_reg_b, - lon_reg_TRUTH, - lat_reg_TRUTH, - lon_reg_TRUTH_b, - lat_reg_TRUTH_b, - ) - - -def get_IMERG_lonlat(): - - # A single IMERG data file to get latitude and longitude - IMERG_file_name = "/network/group/aopp/predict/TIP021_MCRAECOOPER_IFS/IMERG_V07/2018/Jan/3B-HHR.MS.MRG.3IMERG.20180116-S120000-E122959.0720.V07B.HDF5" - - # HDF5 in the ICPAC region - h5_file = h5py.File(IMERG_file_name) - latitude = h5_file["Grid"]["lat"][763:1147] - longitude = h5_file["Grid"]["lon"][1991:2343] - h5_file.close() - - return latitude, longitude - - -def prepare_year_and_month_input(years, months): - - if not isinstance(years, list): - assert isinstance(years, int) - years = [years, years] - if not isinstance(months, list): - assert isinstance(months, int) - months = [months, months] - - assert len(years) > 1 - assert len(months) > 1 - - years = np.sort(years) - months = np.sort(months) - - assert years[-1] >= years[0] - assert months[-1] >= months[0] - - year_beg = years[0] - year_end = years[-1] - - month_beg = months[0] - month_end = months[-1] - - if month_end == 12: - year_end += 1 - else: - month_end += 1 - - return year_beg, year_end, month_beg, month_end - - -def retrieve_vars_ifs(field, all_data_mean, all_data_sd, start=1, end=2): - - if field in ["tp", "cp", "ssr"]: - # return mean, sd, 0, 0. zero fields are so - # that each field returns a 4 x ny x nx array. - # accumulated fields have been pre-processed - # s.t. 
data[:, j, :, :] has accumulation between times j and j+1 - data1 = all_data_mean[:, start:end, :, :].reshape( - -1, all_data_mean.shape[2], all_data_mean.shape[3] - ) - data2 = all_data_sd[:, start:end, :, :].reshape( - -1, all_data_sd.shape[2], all_data_sd.shape[3] - ) - data3 = np.zeros(data1.shape) - data = np.stack([data1, data2, data3, data3], axis=-1)[:, None, :, :, :] - - else: - temp_data_mean_start = all_data_mean[:, start:end, :, :].reshape( - -1, all_data_mean.shape[2], all_data_mean.shape[3] - ) - temp_data_mean_end = all_data_mean[:, end : end + 1, :, :].reshape( - -1, all_data_mean.shape[2], all_data_mean.shape[3] - ) - temp_data_sd_start = all_data_sd[:, start:end, :, :].reshape( - -1, all_data_sd.shape[2], all_data_sd.shape[3] - ) - temp_data_sd_end = all_data_sd[:, end : end + 1, :, :].reshape( - -1, all_data_sd.shape[2], all_data_sd.shape[3] - ) - - data = np.stack( - [ - temp_data_mean_start, - temp_data_sd_start, - temp_data_mean_end, - temp_data_sd_end, - ], - axis=-1, - )[:, None, :, :, :] - - return data - - -def streamline_and_normalise_ifs( - field, - da, - log_prec=True, - norm=True, - time_idx=None, - split_steps=[5, 6, 7, 8, 9], -): - """ - Streamline IFS date to: - * Have appropriate valid time from time of forecast initialization - * If time_idx are provided then we directly select based on that - * Otherwise we select based on split_steps (default 30 - 54 hour lead time) - - Inputs - ------ - - field: str - field to select, needed to check accumulated or non- - negative field - da: xr.DataArray or xr.Dataset - data over which to streamline and normalise - log_prec: boolean - whether to calculate the log of precipitation, - default=True. - norm: boolean - whether to normalise or not, default = True - split_steps: list or 1-D array - valid_time steps to iterate over - default=[5,6,7,8,9] - time_idx: 1-D array or None - instead of split-steps if we have a more randomised - selection of valid time to operate on for each - initialisation time available - - Outputs - ------- - - xr.DataArray or xr.Dataset of streamline and normalised values - - NOTE: We replace the time with the valid time NOT initial fcst - time. 
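A toy numpy sketch of the four-channel layout that retrieve_vars_ifs (above) assembles for a single field; the shapes are illustrative, with n_init initialisation times, n_steps lead-time steps and an ny x nx grid:

    import numpy as np

    n_init, n_steps, ny, nx = 2, 10, 4, 5
    all_mean = np.random.rand(n_init, n_steps, ny, nx)
    all_sd = np.random.rand(n_init, n_steps, ny, nx)
    start, end = 5, 6

    # accumulated fields (tp, cp, ssr): channels are [mean, sd, 0, 0]
    m = all_mean[:, start:end].reshape(-1, ny, nx)
    s = all_sd[:, start:end].reshape(-1, ny, nx)
    acc = np.stack([m, s, np.zeros_like(m), np.zeros_like(m)], axis=-1)[:, None]

    # all other fields: channels are [mean_start, sd_start, mean_end, sd_end]
    inst = np.stack(
        [
            all_mean[:, start:end].reshape(-1, ny, nx),
            all_sd[:, start:end].reshape(-1, ny, nx),
            all_mean[:, end : end + 1].reshape(-1, ny, nx),
            all_sd[:, end : end + 1].reshape(-1, ny, nx),
        ],
        axis=-1,
    )[:, None]  # both end up with shape (n_init, 1, ny, nx, 4)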
- - """ - - all_data_mean = da[f"{field}_mean"].values - all_data_sd = da[f"{field}_sd"].values - - da.close() - - if time_idx is None: - times = np.hstack( - ( - [ - time[split_steps[0] : split_steps[-1]] - for time in da.fcst_valid_time.values - ] - ) - ) - else: - if time_idx.shape[0] % 4 == 0: - time_idx = time_idx.reshape(-1, 4) - assert da.fcst_valid_time.values.shape[0] == time_idx.shape[0] - times = np.hstack( - ( - [ - time[[idx_t]] - for time, idx_t in zip(da.fcst_valid_time.values, time_idx) - ] - ) - ) - - data = [] - - if time_idx is None: - for start, end in zip(split_steps[:4], split_steps[1:5]): - - data.append( - retrieve_vars_ifs( - field, all_data_mean, all_data_sd, start=start, end=end - ) - ) - else: - - for i_row, start in enumerate(time_idx): - if isinstance(start, np.ndarray): - for s in start: - data.append( - retrieve_vars_ifs( - field, - all_data_mean[[i_row]], - all_data_sd[[i_row]], - start=s, - end=s + 1, - ) - ) - else: - data.append( - retrieve_vars_ifs( - field, - all_data_mean[[i_row]], - all_data_sd[[i_row]], - start=start, - end=start + 1, - ) - ) - - data = np.hstack((data)).reshape(-1, da.latitude.shape[0], da.longitude.shape[0], 4) - da = xr.DataArray( - data=data, - dims=["time", "lat", "lon", "i_x"], - coords=dict( - lon=da.longitude.values, - lat=da.latitude.values, - time=times.flatten(), - i_x=np.arange(4), - ), - ) - - if field in [ - "cape", - "cp", - "mcc", - "sp", - "ssr", - "t2m", - "tciw", - "tclw", - "tcrw", - "tcw", - "tcwv", - "tp", - ]: - da = nonnegative(da) - - da = convert_units(da, field, log_prec, m_to_mm=True) - - if norm: - da = get_norm(da, field) - - return da.where(np.isfinite(da), 0).sortby("time") - - -def get_IMERG_year( - years, - months=[3, 4, 5, 6], - centre=[-1.25, 36.80], - window_size=30, - clip_to_window=False, -): - - year_beg, year_end, month_beg, month_end = prepare_year_and_month_input( - years, months - ) - - latitude, longitude = get_IMERG_lonlat() - - # Load the IMERG data averaged over 6h periods - d = datetime.datetime(year_beg, month_beg, 1, 6) - - if year_end != year_beg: - d_end = datetime.datetime(year_end, 1, 1, 6) - times_xr = np.arange( - "%s-%s-01" % (str(year_beg), str(month_beg).zfill(2)), - "%s-%s-01" % (str(year_end), str(1).zfill(2)), - np.timedelta64(30, "m"), - dtype="datetime64[ns]", - ) - else: - d_end = datetime.datetime(year_end, month_end, 1, 6) - times_xr = np.arange( - "%s-%s-01" % (str(year_beg), str(month_beg).zfill(2)), - "%s-%s-01" % (str(year_end), str(month_end).zfill(2)), - np.timedelta64(30, "m"), - dtype="datetime64[ns]", - ) - # Number of 30 minutes rainfall periods - num_time_pts = (d_end - d).days * 48 - - # The 6h average rainfall - rain_IMERG = np.full([num_time_pts, len(longitude), len(latitude)], np.nan) - - start_time = time.time() - - time_idx = 0 - progbar = tqdm(total=int((d_end - d).days) * 2 * 24) - while d < d_end: - - if d.month not in np.arange(month_beg, month_end): - progbar.update(1) - # Move to the next timesetp - d += datetime.timedelta(minutes=30) - time_idx += 1 - continue - - # Load an IMERG file with the current date - d2 = d + datetime.timedelta(seconds=30 * 60 - 1) - # Number of minutes since 00:00 - count = int((d - datetime.datetime(d.year, d.month, d.day)).seconds / 60) - IMERG_file_name = ( - "/network/group/aopp/predict/TIP021_MCRAECOOPER_IFS/IMERG_V07" - + "/%s/%s/" % (str(d.year), str(d.strftime("%b"))) - + f"3B-HHR.MS.MRG.3IMERG.{d.year}{d.month:02d}{d.day:02d}-S{d.hour:02d}{d.minute:02d}00-" - + 
f"E{d2.hour:02d}{d2.minute:02d}{d2.second:02d}.{count:04d}.V07B.HDF5" - ) - - h5_file = h5py.File(IMERG_file_name) - times = h5_file["Grid"]["time"][:] - - # Check the time is correct - # if (d != datetime(1970,1,1) + timedelta(seconds=int(times[0]))): - # print(f"Incorrect time for {d}", datetime(1970,1,1) + timedelta(seconds=int(times[0]))) - - # Accumulate the rainfall - rain_IMERG[time_idx, :, :] = h5_file["Grid"]["precipitation"][ - 0, 1991:2343, 763:1147 - ] - h5_file.close() - - # Move to the next timesetp - d += datetime.timedelta(minutes=30) - - # Move to the next time index - time_idx += 1 - progbar.update(1) - - progbar.close() - - # Put into the same order as the IFS and cGAN data - rain_IMERG = np.moveaxis(rain_IMERG, [0, 1, 2], [0, 2, 1]) - - obs = xr.DataArray( - data=rain_IMERG.reshape(-1, len(latitude), len(longitude)), - dims=["time", "lat", "lon"], - coords={ - "time": times_xr, - "lat": latitude, - "lon": longitude, - }, - attrs=dict(description="IMERG 30 min precipitation", units="mm"), - ).rename("precipitation") - if clip_to_window: - obs = obs.sel( - { - "lat": slice(centre[0] - window_size, centre[0] + window_size), - "lon": slice(centre[1] - window_size, centre[1] + window_size), - } - ) - print( - "Finished loading in IMERG data in ----%.2f s-----" % (time.time() - start_time) - ) - - return obs.dropna("time", how="all") - - -def load_truth_and_mask(dates, time_idx=[5, 6, 7, 8], log_precip=True, normalise=True): - """ - Returns a single (truth, mask) item of data. - Parameters: - date: forecast start date - time_idx: forecast 'valid time' array index - log_precip: whether to apply log10(1+x) transformation - """ - ds_to_concat = [] - - for date in tqdm(dates): - date = str(date).split("T")[0].replace("-", "") - - # convert date and time_idx to get the correct truth file - fcst_date = datetime.datetime.strptime(date, "%Y%m%d") - - for idx_t in time_idx: - valid_dt = fcst_date + datetime.timedelta( - hours=int(idx_t) * time_res - ) # needs to change for 12Z forecasts - - fname = valid_dt.strftime("%Y%m%d_%H") - # print(fname) - - data_path = glob.glob(TRUTH_PATH + f"{date[:4]}/{fname}.nc") - if len(data_path) < 1: - break - # ds = xr.concat([xr.open_dataset(dataset).expand_dims(dim={'time':i}, - # axis=0) - # for i,dataset in enumerate(data_path)],dim='time').mean('time') - ds = xr.open_dataset(data_path[0]) - if log_precip: - ds["precipitation"] = logprec(ds["precipitation"]) - - # mask: False for valid truth data, True for invalid truth data - # (compatible with the NumPy masked array functionality) - # if all data is valid: - mask = ~np.isfinite(ds["precipitation"]) - ds["mask"] = mask - - ds_to_concat.append(ds) - - return xr.concat(ds_to_concat, dim="time") - - -def load_hires_constants(batch_size=1): - """ - - Get elevation and land sea mask on IMERG resolution - - """ - - oro_path = CONSTANTS_PATH + "elev.nc" - - lsm_path = CONSTANTS_PATH + "lsm.nc" - - ds = xr.open_mfdataset([lsm_path, oro_path]) - - # LSM is already 0:1 - ds["elevation"] = ds["elevation"] / 10000.0 - - return ds.expand_dims(dim={"time": batch_size}).compute() diff --git a/xarray_batcher/.ipynb_checkpoints/normalise-checkpoint.py b/xarray_batcher/.ipynb_checkpoints/normalise-checkpoint.py deleted file mode 100644 index 84a594e..0000000 --- a/xarray_batcher/.ipynb_checkpoints/normalise-checkpoint.py +++ /dev/null @@ -1,128 +0,0 @@ -## Normalisation functions, note to self you can apply universal functions from numpy as scipy that operate element wise on xarray -## not too many function 
comments as I feel like they are self-explanatory - -import numpy as np - -from .utils import get_metadata, load_fcst_norm - -## Unfortunately need to have this look up table, not sure what a work around is -precip_fields = ["Convective precipitation (water)", "Total Precipitation", "cp", "tp"] - -accumulated_fields = ["ssr"] - -## Normalisation to apply !!! make sure a field doesn't appear twice!!! -standard_scaling = ["Surface pressure", "2 metre temperature", "sp", "t2m"] -maximum_scaling = [ - "Convective available potential energy", - "Upward short-wave radiation flux", - "Downward short-wave radiation flux", - "Cloud water", - "Precipitable water", - "Ice water mixing ratio", - "Cloud mixing ratio", - "Rain mixing ratio", - "cape", - "ssr", - "tciw", - "tclw", - "tcrw", - "tcw", - "tcwv", -] -absminimum_maximum_scaling = [ - "U component of wind", - "V component of wind", - "u700", - "v700", -] - -fcst_norm = load_fcst_norm() -## get some standard stuff from utils -fcst_time_res, time_res, lonlatbox, fcst_spat_res = get_metadata() - - -def logprec(data, threshold=0.1, fill_value=0.02, mean=0.051, std=0.266): - log_scale = np.log10(1e-1 + data).astype(np.float32) - if threshold is not None: - log_scale.where(log_scale > np.log10(threshold), np.log10(fill_value)) - - log_scale.fillna(np.log10(fill_value)) - log_scale -= mean - log_scale /= std - return log_scale - - -def nonnegative(data): - return np.maximum(data, 0.0) # eliminate any data weirdness/regridding issues - - -def m_to_mm_per_hour(data, time_res): - data *= 1000 - return data / time_res # convert to mm/hr - - -def to_per_second(data, time_res): - # for all other accumulated fields [just ssr for us] - return data / ( - time_res * 3600 - ) # convert from a 6-hr difference to a per-second rate - - -def centre_at_mean(data, field): - # these are bounded well away from zero, so subtract mean from ens mean (but NOT from ens sd!) 
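    # Illustrative aside (not part of the original module): for a rain rate of
    # 5 mm/h the logprec transform defined above gives
    #   log10(0.1 + 5) ~= 0.708  ->  (0.708 - 0.051) / 0.266 ~= 2.47,
    # while rates at or below the 0.1 mm/h threshold are meant to be mapped to
    # log10(0.02) before the same standardisation.
    # The surrounding centre_at_mean definition continues below.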
- return data - fcst_norm[field]["mean"] - - -def change_to_unit_std(data, field): - return data / fcst_norm[field]["std"] - - -def max_scaling(data, field): - return (data) / (fcst_norm[field]["max"]) - - -def absmin_max_scaling(data, field): - return data / max(-fcst_norm[field]["min"], fcst_norm[field]["max"]) - - -def convert_units(data, field, log_prec, m_to_mm=True): - if field in precip_fields: - - if m_to_mm: - data = m_to_mm_per_hour(data, time_res) - - if log_prec: - return logprec(data) - - else: - return data - - elif field in accumulated_fields: - data = to_per_second(data, time_res) - - return data - - else: - return data - - -def get_norm(data, field, location_of_vals=[0, 2]): - - if field in precip_fields: - return data - - if field in standard_scaling: - data.loc[{"i_x": location_of_vals}] = centre_at_mean( - data.sel({"i_x": location_of_vals}), field - ) - - return change_to_unit_std(data, field) - - if field in maximum_scaling: - return max_scaling(data, field) - - if field in absminimum_maximum_scaling: - return absmin_max_scaling(data, field) - - else: - return data diff --git a/xarray_batcher/.ipynb_checkpoints/setup_data-checkpoint.py b/xarray_batcher/.ipynb_checkpoints/setup_data-checkpoint.py deleted file mode 100644 index 4009919..0000000 --- a/xarray_batcher/.ipynb_checkpoints/setup_data-checkpoint.py +++ /dev/null @@ -1,100 +0,0 @@ -import copy -import sys - -import numpy as np -import pytorch_lightning as pl -from torch.utils.data import DataLoader -from torchvision import transforms - -from .custom_collate_fn import CustomCollateFnGen -from .loading import get_IMERG_year -from .torch_batcher import BatchTruth - - -class DataModule(pl.LightningDataModule): - def __init__( - self, - train_years=2018, - val_years=2018, - test_years=None, - xbatch_size=[None, 64, 64], - batch_size=8, - train_epoch_size=1000, - valid_epoch_size=200, - test_epoch_size=1000, - ): - super().__init__() - - self.datasets = {} - self.batch_size = batch_size - - mult = 1 - if train_years is not None: - - df_truth = get_IMERG_year(train_years, months=2) - self.datasets["train"] = BatchTruth( - df_truth, - batch_size=xbatch_size, - antialiasing=True, - transform=transforms.RandomVerticalFlip(p=0.5), - for_NJ=True, - length=240, - ) - - if val_years is not None: - - df_truth = get_IMERG_year(val_years, months=2) - self.datasets["valid"] = BatchTruth( - df_truth, - batch_size=[xbatch_size[0], 300, 300], - weighted_sampler=False, - antialiasing=True, - transform=transforms.RandomVerticalFlip(p=0.5), - for_NJ=True, - length=240, - ) - - if test_years is not None: - df_truth = get_IMERG_year(val_years, months=2) - self.datasets["test"] = BatchTruth( - df_truth, - batch_size=[xbatch_size[0], 300, 300], - weighted_sampler=False, - antialiasing=True, - for_NJ=True, - length=240, - ) - - else: - self.datasets["test"] = self.datasets["valid"] - - def dataloader(self, split): - collate_fn, mult = CustomCollateFnGen(None) - if split == "train": - - return DataLoader( - self.datasets[split], - batch_size=self.batch_size, - collate_fn=collate_fn, - pin_memory=True, - num_workers=0, - sampler=self.datasets[split].sampler, - drop_last=True, - ) - else: - return DataLoader( - self.datasets[split], - collate_fn=collate_fn, - pin_memory=True, - num_workers=0, - drop_last=True, - ) - - def train_dataloader(self): - return self.dataloader("train") - - def val_dataloader(self): - return self.dataloader("valid") - - def test_dataloader(self): - return self.dataloader("test") diff --git 
a/xarray_batcher/.ipynb_checkpoints/torch_batcher-checkpoint.py b/xarray_batcher/.ipynb_checkpoints/torch_batcher-checkpoint.py deleted file mode 100644 index fbc3628..0000000 --- a/xarray_batcher/.ipynb_checkpoints/torch_batcher-checkpoint.py +++ /dev/null @@ -1,351 +0,0 @@ -import dask -import numpy as np -import torch -import xbatcher -from scipy.spatial import KDTree -from tqdm import tqdm -from tqdm.dask import TqdmCallback - -from .batch_helper_functions import Antialiasing, get_spherical - - -class BatchDataset(torch.utils.data.Dataset): - - """ - class for iterating over a dataset - """ - - def __init__( - self, - X, - y, - constants, - batch_size: list[int] = [4, 128, 128], - weighted_sampler: bool = True, - for_NJ: bool = False, - for_val: bool = False, - antialiasing: bool = False, - ): - self.batch_size = batch_size - self.X_generator = X - self.y_generator = xbatcher.BatchGenerator( - y, - {"time": batch_size[0], "lat": batch_size[1], "lon": batch_size[2]}, - input_overlap={ - "lat": int(batch_size[1] / 32), - "lon": int(batch_size[2] / 32), - }, - ) - constants["lat"] = np.round(y.lat.values, decimals=2) - constants["lon"] = np.round(y.lon.values, decimals=2) - - self.constants_generator = constants - - self.variables = [list(x.data_vars)[0] for x in X] - self.constants = list(constants.data_vars) - self.for_NJ = for_NJ - self.for_val = for_val - self.antialiasing = antialiasing - - if weighted_sampler: - y_train = [ - self.y_generator[i].precipitation.mean( - ["time", "lat", "lon"], skipna=False - ) - for i in range(len(self.y_generator)) - ] - - rounded_y_train = np.round(y_train, decimals=1) - unique_classes = np.unique(rounded_y_train) - class_sample_count = np.bincount( - np.digitize(rounded_y_train, unique_classes) - 1 - ) - weight = 1.0 / class_sample_count - samples_weight = weight[np.digitize(rounded_y_train, unique_classes) - 1] - - self.samples_weight = torch.from_numpy(np.asarray(samples_weight)) - self.sampler = torch.utils.data.WeightedRandomSampler( - self.samples_weight.type("torch.DoubleTensor"), len(samples_weight) - ) - - def __len__(self) -> int: - return len(self.y_generator) - - def __getitem__(self, idx): - - y_batch = self.y_generator[idx] - time_batch = y_batch.time.values - lat_batch = np.round(y_batch.lat.values, decimals=2) - lon_batch = np.round(y_batch.lon.values, decimals=2) - - X_batch = [] - for x, variable in zip(self.X_generator, self.variables): - X_batch.append( - x[variable] - .sel({"time": time_batch, "lat": lat_batch, "lon": lon_batch}) - .values - ) - - X_batch = torch.from_numpy( - np.concatenate( - X_batch, - axis=-1, - ) - ).float() - - constant_batch = torch.from_numpy( - np.stack( - [ - self.constants_generator[constant] - .sel({"lat": lat_batch, "lon": lon_batch}) - .values - for constant in self.constants - ], - axis=-1, - ) - ).float() - - if self.for_NJ: - - elev_values = np.squeeze(constant_batch[:, :, 0]).reshape(-1, 1) - lat_values, lon_values = np.meshgrid(lat_batch, lon_batch) - spherical_coords = get_spherical( - lat_values.reshape(-1, 1), lon_values.reshape(-1, 1), elev_values - ) - - kdtree = KDTree(spherical_coords) - - pairs = [] - - for i_coord, coord in enumerate(spherical_coords): - pairs.append( - np.vstack( - ( - np.full(3, fill_value=i_coord).reshape(1, -1), - kdtree.query(coord, k=3)[1], - ) - ) - ) - - pairs = np.hstack((pairs)) - - rainfall_path = torch.cat( - ( - torch.from_numpy( - y_batch.precipitation.fillna(0).values.reshape( - self.batch_size[0], -1, 1 - ) - ).float(), - 
X_batch.reshape(self.batch_size[0], -1, len(self.variables) * 4), - ), - dim=-1, - ) - obs_dates = np.ones(self.batch_size[0]).reshape(1, -1) - n_obs = np.array([self.batch_size[0]]) - if self.for_val: - obs_dates = np.zeros(self.batch_size[0]).reshape(1, -1) - n_obs = np.random.randint(1, self.batch_size[0] - 8, 1) - obs_dates[: n_obs[0]] = 1 - - return { - "idx": idx, - "rainfall_path": rainfall_path[None, :, :, :], - "observed_dates": obs_dates, - "nb_obs": n_obs, - "dt": 1, - "edge_indices": pairs, - "obs_noise": None, - } - - else: - - if self.antialiasing: - antialiaser = Antialiasing() - y_batch = y_batch.precipitation.fillna(np.log10(0.02)).values - y_batch = antialiaser(y_batch) - y_batch = torch.from_numpy(np.moveaxis(y_batch, 0, -1)).float() - - else: - y_batch = torch.from_numpy( - y_batch.precipitation.fillna(np.log10(0.02)).values[:, :, :, None] - ).float() - return (torch.cat((X_batch, constant_batch), dim=-1), y_batch) - - -class BatchTruth(torch.utils.data.Dataset): - - """ - class for iterating over a dataset - """ - - def __init__( - self, - y, - batch_size=[4, 128, 128], - weighted_sampler=True, - for_NJ=False, - for_val=False, - length=None, - antialiasing=False, - transform=None, - return_dataset=False, - ): - - self.batch_size = batch_size - self.for_NJ = for_NJ - self.for_val = for_val - self.length = length - self.antialiasing = antialiasing - self.transform = transform - self.return_dataset = return_dataset - overlap = ( - {"latitude": int(batch_size[1] - 8), "longitude": int(batch_size[2] - 8)} - if for_NJ - else {"lat": int(batch_size[1] // 8), "lon": int(batch_size[2] // 8)} - ) - self.y_generator = xbatcher.BatchGenerator( - y, - { - "time": batch_size[0], - "latitude" if for_NJ else "lat": batch_size[1], - "longitude" if for_NJ else "lon": batch_size[2], - }, - input_overlap=overlap, - ) - - if weighted_sampler: - if self.for_NJ: - y_train = [ - self.y_generator[i].mean( - ["time", "latitude", "longitude"], skipna=False - ) - for i in range(len(self.y_generator)) - ] - else: - y_train = [ - self.y_generator[i].precipitation.mean( - ["time", "lat", "lon"], skipna=False - ) - for i in range(len(self.y_generator)) - ] - rounded_y_train = np.round(y_train, decimals=1) - unique_classes = np.unique(rounded_y_train) - class_sample_count = np.bincount( - np.digitize(rounded_y_train, unique_classes) - 1 - ) - weight = 1.0 / class_sample_count - samples_weight = weight[np.digitize(rounded_y_train, unique_classes) - 1] - - self.samples_weight = torch.from_numpy(np.asarray(samples_weight)) - self.sampler = torch.utils.data.WeightedRandomSampler( - self.samples_weight.type("torch.DoubleTensor"), len(samples_weight) - ) - - def __len__(self) -> int: - return len(self.y_generator) - - def __getitem__(self, idx): - - y_batch = self.y_generator[idx] - - if self.return_dataset: - return y_batch - - if self.for_NJ: - - def generate(y_batch, length, stop=None): - - rng = np.random.default_rng() - random_year = rng.choice(np.unique(y_batch["time.year"].values), 1)[0] - ds_sel = y_batch.sel( - { - "time": slice( - "%i-01-01" % random_year, "%i-01-01" % (random_year + 1) - ) - } - ) - - time_of_event = rng.choice(ds_sel.time.values[length:-length], 1)[0] - - time_to_event = rng.choice(np.arange(length), 1)[0] - time_after_event = length - time_to_event - 1 - - rainfall_path = ds_sel.sel( - { - "time": slice( - time_of_event - np.timedelta64(time_to_event * 30, "m"), - time_of_event + np.timedelta64(time_after_event * 30, "m"), - ), - } - ) - times_rainfall = 
rainfall_path.time.values - rainfall_path = rainfall_path.fillna(0).values[None, :, :, :] - - if stop is not None: - # limit observations to once a day - nb_obs_single = stop - obs_ptr = np.arange(1, nb_obs_single) - - else: - nb_obs_single = length - obs_ptr = np.arange(1, length) - - observed_date = np.zeros(rainfall_path.shape[1]) - observed_date[0] = 1 - - for i_obs in obs_ptr: - - observed_date[i_obs] = 1 - - return rainfall_path, observed_date, nb_obs_single - - rainfall_paths = [] - observed_dates = [] - n_obs = [] - - stop = None - batch_size = 50 - if self.for_val: - rng = np.random.default_rng() - stop = rng.choice(np.arange(2, self.length - 100), 1)[0] - batch_size = 2 - - for i in range(batch_size): - rainfall_path, observed_date, nb_obs = generate( - y_batch, self.length, stop=stop - ) - rainfall_paths.append(rainfall_path) - observed_dates.append(observed_date) - n_obs.append(nb_obs) - - rainfall_paths = np.vstack(rainfall_paths) - observed_dates = np.stack(observed_dates) - n_obs = np.asarray(n_obs) - - return { - "idx": idx, - "rainfall_path": torch.tensor( - rainfall_paths[:, :, :, :, None], dtype=torch.float32 - ), - "observed_dates": observed_dates, - "nb_obs": n_obs, - "dt": 1, - "obs_noise": None, - } - - else: - if self.antialiasing: - antialiaser = Antialiasing() - y_batch = y_batch.precipitation.fillna(np.log10(0.02)).values - y_batch = antialiaser(y_batch) - y_batch = torch.tensor(np.moveaxis(y_batch, 0, -1), dtype=torch.float32) - - else: - y_batch = torch.tensor( - y_batch.precipitation.fillna(np.log10(0.02)).values[:, :, :, None], - dtype=torch.float32, - ) - if self.transform: - y_batch = self.transform(y_batch) - - return y_batch diff --git a/xarray_batcher/.ipynb_checkpoints/torch_streamer-checkpoint.py b/xarray_batcher/.ipynb_checkpoints/torch_streamer-checkpoint.py deleted file mode 100644 index 618757e..0000000 --- a/xarray_batcher/.ipynb_checkpoints/torch_streamer-checkpoint.py +++ /dev/null @@ -1,380 +0,0 @@ -import dask -import numpy as np -import torch -import xbatcher -from scipy.spatial import KDTree - -from xarray_batcher.get_fcst_and_truth import get_all - -from .batch_helper_functions import Antialiasing, get_spherical - - -class StreamDataset(torch.utils.data.IterableDataset): - - """ - Similar as BatchDataset, see torch_batcher.py apart - from the new workflow to assist in streaming: - - 1) Start using only truth data - 2) Calculate sampler - 3) When iterating through truth, load in the fcst. 
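A hedged usage sketch for the streaming dataset whose workflow this docstring describes (continued just below); the live copy lives in xarray_batcher/torch_streamer.py, truth_ds is assumed to come from get_all(..., model="truth") and constants from load_hires_constants, with the constants' time dimension matching the first entry of batch_size:

    from xarray_batcher.torch_streamer import StreamDataset

    variables = ["cp", "u700", "tp"]   # any subset of the IFS fields
    stream_ds = StreamDataset(truth_ds, variables, constants, batch_size=[4, 128, 128])
    X, y = next(iter(stream_ds))       # forecast + constant channels, truth precipitation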
- data on-the-fly - - """ - - def __init__( - self, - y, - variables, - constants, - batch_size: list[int] = [4, 128, 128], - batches_per_epoch=1200, - weighted_sampler: bool = True, - for_NJ: bool = False, - for_val: bool = False, - antialiasing: bool = False, - ): - self.batch_size = batch_size - self.batches_per_epoch = batches_per_epoch - self.variables = variables - self.y_generator = xbatcher.BatchGenerator( - y, - {"time": batch_size[0], "lat": batch_size[1], "lon": batch_size[2]}, - input_overlap={ - "lat": int(batch_size[1] / 32), - "lon": int(batch_size[2] / 32), - }, - ) - constants["lat"] = np.round(y.lat.values, decimals=2) - constants["lon"] = np.round(y.lon.values, decimals=2) - - self.constants_generator = constants - - self.constants = list(constants.data_vars) - self.for_NJ = for_NJ - self.for_val = for_val - self.antialiasing = antialiasing - - if weighted_sampler: - y_train = [ - self.y_generator[i].precipitation.mean( - ["time", "lat", "lon"], skipna=False - ) - for i in range(len(self.y_generator)) - ] - - rounded_y_train = np.round(y_train, decimals=0) - unique_classes = np.unique(rounded_y_train) - class_sample_count = np.bincount( - np.digitize(rounded_y_train, unique_classes) - 1 - ) - weight = 1.0 / class_sample_count - sample_weights = weight[np.digitize(rounded_y_train, unique_classes) - 1] - sample_weights = sample_weights / np.sum(sample_weights) - - self.sample_weights = torch.from_numpy(np.asarray(sample_weights)) - else: - self.sample_weights = None - - self.len = len(self.y_generator) - - def __len__(self): - return self.batches_per_epoch - - def __iter__(self): - self.idx = 0 - while self.idx <= self.__len__(): - try: - yield self.__sample__() - self.idx += 1 - except: - continue - - def __sample__(self): - - if self.sample_weights is None: - idx_samp = np.random.randint(0, self.len) - else: - idx_samp = int(np.random.choice(self.len, p=self.sample_weights)) - - y_batch = self.y_generator[idx_samp] - time_batch = y_batch.time.values - lat_batch = np.round(y_batch.lat.values, decimals=2) - lon_batch = np.round(y_batch.lon.values, decimals=2) - - X_generator = get_all( - None, - model="ifs", - truth_batch=y_batch, - stream=True, - offset=24, - variables=self.variables, - ) - - X_batch = [] - for x, variable in zip(X_generator, self.variables): - X_batch.append(x[variable].values) - - X_batch = torch.from_numpy( - np.concatenate( - X_batch, - axis=-1, - ) - ).float() - - constant_batch = torch.from_numpy( - np.stack( - [ - self.constants_generator[constant] - .sel({"lat": lat_batch, "lon": lon_batch}) - .values - for constant in self.constants - ], - axis=-1, - ) - ).float() - - if self.for_NJ: - - elev_values = np.squeeze(constant_batch[:, :, 0]).reshape(-1, 1) - lat_values, lon_values = np.meshgrid(lat_batch, lon_batch) - spherical_coords = get_spherical( - lat_values.reshape(-1, 1), lon_values.reshape(-1, 1), elev_values - ) - - kdtree = KDTree(spherical_coords) - - pairs = [] - - for i_coord, coord in enumerate(spherical_coords): - pairs.append( - np.vstack( - ( - np.full(3, fill_value=i_coord).reshape(1, -1), - kdtree.query(coord, k=3)[1], - ) - ) - ) - - pairs = np.hstack((pairs)) - - rainfall_path = torch.cat( - ( - torch.from_numpy( - y_batch.precipitation.fillna(0).values.reshape( - self.batch_size[0], -1, 1 - ) - ).float(), - X_batch.reshape(self.batch_size[0], -1, len(self.variables) * 4), - ), - dim=-1, - ) - obs_dates = np.ones(self.batch_size[0]).reshape(1, -1) - n_obs = np.array([self.batch_size[0]]) - if self.for_val: - obs_dates = 
np.zeros(self.batch_size[0]).reshape(1, -1) - n_obs = np.random.randint(1, self.batch_size[0] - 8, 1) - obs_dates[: n_obs[0]] = 1 - - return { - "idx": idx, - "rainfall_path": rainfall_path[None, :, :, :], - "observed_dates": obs_dates, - "nb_obs": n_obs, - "dt": 1, - "edge_indices": pairs, - "obs_noise": None, - } - - else: - - if self.antialiasing: - antialiaser = Antialiasing() - y_batch = y_batch.precipitation.fillna(np.log10(0.02)).values - y_batch = antialiaser(y_batch) - y_batch = torch.from_numpy(np.moveaxis(y_batch, 0, -1)).float() - - else: - y_batch = torch.from_numpy( - y_batch.precipitation.fillna(np.log10(0.02)).values[:, :, :, None] - ).float() - return (torch.cat((X_batch, constant_batch), dim=-1), y_batch) - - -class StreamTruth(torch.utils.data.Dataset): - - """ - class for iterating over a dataset - """ - - def __init__( - self, - y, - batch_size=[4, 128, 128], - weighted_sampler=True, - for_NJ=False, - for_val=False, - length=None, - antialiasing=False, - transform=None, - return_dataset=False, - ): - - self.batch_size = batch_size - self.for_NJ = for_NJ - self.for_val = for_val - self.length = length - self.antialiasing = antialiasing - self.transform = transform - self.return_dataset = return_dataset - overlap = ( - {"latitude": int(batch_size[1] - 8), "longitude": int(batch_size[2] - 8)} - if for_NJ - else {"lat": int(batch_size[1] // 8), "lon": int(batch_size[2] // 8)} - ) - self.y_generator = xbatcher.BatchGenerator( - y, - { - "time": batch_size[0], - "latitude" if for_NJ else "lat": batch_size[1], - "longitude" if for_NJ else "lon": batch_size[2], - }, - input_overlap=overlap, - ) - - if weighted_sampler: - if self.for_NJ: - y_train = [ - self.y_generator[i].mean( - ["time", "latitude", "longitude"], skipna=False - ) - for i in range(len(self.y_generator)) - ] - else: - y_train = [ - self.y_generator[i].precipitation.mean( - ["time", "lat", "lon"], skipna=False - ) - for i in range(len(self.y_generator)) - ] - rounded_y_train = np.round(y_train, decimals=1) - unique_classes = np.unique(rounded_y_train) - class_sample_count = np.bincount( - np.digitize(rounded_y_train, unique_classes) - 1 - ) - weight = 1.0 / class_sample_count - samples_weight = weight[np.digitize(rounded_y_train, unique_classes) - 1] - - self.samples_weight = torch.from_numpy(np.asarray(samples_weight)) - self.sampler = torch.utils.data.WeightedRandomSampler( - self.samples_weight.type("torch.DoubleTensor"), len(samples_weight) - ) - - def __len__(self) -> int: - return len(self.y_generator) - - def __getitem__(self, idx): - - y_batch = self.y_generator[idx] - - if self.return_dataset: - return y_batch - - if self.for_NJ: - - def generate(y_batch, length, stop=None): - - rng = np.random.default_rng() - random_year = rng.choice(np.unique(y_batch["time.year"].values), 1)[0] - ds_sel = y_batch.sel( - { - "time": slice( - "%i-01-01" % random_year, "%i-01-01" % (random_year + 1) - ) - } - ) - - time_of_event = rng.choice(ds_sel.time.values[length:-length], 1)[0] - - time_to_event = rng.choice(np.arange(length), 1)[0] - time_after_event = length - time_to_event - 1 - - rainfall_path = ds_sel.sel( - { - "time": slice( - time_of_event - np.timedelta64(time_to_event * 30, "m"), - time_of_event + np.timedelta64(time_after_event * 30, "m"), - ), - } - ) - times_rainfall = rainfall_path.time.values - rainfall_path = rainfall_path.fillna(0).values[None, :, :, :] - - if stop is not None: - # limit observations to once a day - nb_obs_single = stop - obs_ptr = np.arange(1, nb_obs_single) - - else: - 
nb_obs_single = length - obs_ptr = np.arange(1, length) - - observed_date = np.zeros(rainfall_path.shape[1]) - observed_date[0] = 1 - - for i_obs in obs_ptr: - - observed_date[i_obs] = 1 - - return rainfall_path, observed_date, nb_obs_single - - rainfall_paths = [] - observed_dates = [] - n_obs = [] - - stop = None - batch_size = 50 - if self.for_val: - rng = np.random.default_rng() - stop = rng.choice(np.arange(2, self.length - 100), 1)[0] - batch_size = 2 - - for i in range(batch_size): - rainfall_path, observed_date, nb_obs = generate( - y_batch, self.length, stop=stop - ) - rainfall_paths.append(rainfall_path) - observed_dates.append(observed_date) - n_obs.append(nb_obs) - - rainfall_paths = np.vstack(rainfall_paths) - observed_dates = np.stack(observed_dates) - n_obs = np.asarray(n_obs) - - return { - "idx": idx, - "rainfall_path": torch.tensor( - rainfall_paths[:, :, :, :, None], dtype=torch.float32 - ), - "observed_dates": observed_dates, - "nb_obs": n_obs, - "dt": 1, - "obs_noise": None, - } - - else: - if self.antialiasing: - antialiaser = Antialiasing() - y_batch = y_batch.precipitation.fillna(np.log10(0.02)).values - y_batch = antialiaser(y_batch) - y_batch = torch.tensor(np.moveaxis(y_batch, 0, -1), dtype=torch.float32) - - else: - y_batch = torch.tensor( - y_batch.precipitation.fillna(np.log10(0.02)).values[:, :, :, None], - dtype=torch.float32, - ) - if self.transform: - y_batch = self.transform(y_batch) - - return y_batch diff --git a/xarray_batcher/.ipynb_checkpoints/utils-checkpoint.py b/xarray_batcher/.ipynb_checkpoints/utils-checkpoint.py deleted file mode 100644 index 50ef143..0000000 --- a/xarray_batcher/.ipynb_checkpoints/utils-checkpoint.py +++ /dev/null @@ -1,168 +0,0 @@ -## Utils needed in loading zarr and batching -import datetime -import os -import pickle - -import numpy as np - -## Put all forecast fields, their levels (can be None also), and specify categories of accumulated fields -accumulated_fields = ["ssr", "cp", "tp"] - - -## Put other user-specification i.e., lon-lat box, spatial and temporal resolution (in hours) -TIME_RES = 6 - -LONLATBOX = [-14, 19, 25.25, 54.75] -FCST_SPAT_RES = 0.1 -FCST_TIME_RES = 3 - -## Put all directories here - -TRUTH_PATH = ( - "/network/group/aopp/predict/TIP021_MCRAECOOPER_IFS/IMERG_V07/ICPAC_region/6h/" -) -FCST_PATH_IFS = ( - "/network/group/aopp/predict/TIP021_MCRAECOOPER_IFS/IFS-regICPAC-meansd/" -) - -CONSTANTS_PATH = ( - "/network/group/aopp/predict/TIP022_NATH_GFSAIMOD/cGAN/constants-regICPAC/" -) - - -def get_metadata(): - - """ - Returns time resolution (in hours), lonlat box (bottom, left, top, right) and the forecast's spatial resolution - """ - - return FCST_TIME_RES, TIME_RES, LONLATBOX, FCST_SPAT_RES - - -def get_paths(): - - return FCST_PATH_IFS, TRUTH_PATH, CONSTANTS_PATH - - -import pickle - - -def load_fcst_norm(year=2018): - - fcstnorm_path = os.path.join( - CONSTANTS_PATH.replace("-regICPAC", "_IFS"), f"FCSTNorm{year}.pkl" - ) - - with open(fcstnorm_path, "rb") as f: - return pickle.load(f) - - -def daterange(start_date, end_date): - - """ - Generator to get date range for a given time period from start_date to end_date - """ - - for n in range(int((end_date - start_date).days)): - yield start_date + datetime.timedelta(days=n) - - -def get_valid_dates( - year, - TIME_RES=TIME_RES, - start_hour=30, - end_hour=60, - raw_list=False, -): - - """ - Returns list of valid forecast start dates for which 'truth' data - exists, given the other input parameters. 
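A small usage sketch for this helper (its docstring continues below), assuming the live xarray_batcher/utils.py mirrors the checkpoint removed here; with the defaults it returns datetime.date objects, which the truth branch of get_all converts to numpy datetimes:

    import numpy as np
    from xarray_batcher.utils import get_valid_dates

    dates = get_valid_dates(2018)
    dates_np = np.array([d.strftime("%Y-%m-%d") for d in dates], dtype="datetime64[ns]")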
If truth data is not available - for certain days/hours, this will not be the full year. Dates are returned - as a list of YYYYMMDD strings. - - Parameters: - year (list): forecasts starting in this year - start_hour (int): Lead time of first forecast desired - end_hour (int): Lead time of last forecast desired - """ - - # sanity checks for our dataset - assert year in (2018, 2019, 2020, 2021, 2022, 2023, 2024) - assert start_hour >= 0 - assert start_hour % TIME_RES == 0 - assert end_hour % TIME_RES == 0 - assert end_hour > start_hour - - # Build "cache" of truth data dates/times that exist as well as forecasts - valid_dates = [] - - start_date = datetime.date(year, 1, 1) - end_date = datetime.date( - year + 1, 1, end_hour // TIME_RES + 2 - ) # go a bit into following year - - for curdate in daterange(start_date, end_date): - datestr = curdate.strftime("%Y%m%d") - valid = True - - ## then check for truth data at the desired lead time - for hr in np.arange(start_hour, end_hour, TIME_RES): - datestr_true = curdate + datetime.timedelta(hours=6) - datestr_true = datestr_true.strftime("%Y%m%d_%H") - fname = f"{datestr_true}" # {hr:02} - - if not os.path.exists( - os.path.join(TRUTH_PATH, f"{datestr_true[:4]}/{fname}.nc") - ): - valid = False - break - - if valid: - valid_dates.append(curdate) - - if raw_list: - # Need to get it from datetime to numpy readable format - valid_dates = [date.strftime("%Y-%m-%d") for date in valid_dates] - - return valid_dates - - -def match_fcst_to_valid_time(valid_times, time_idx, step_type="h"): - - """ - Inputs - ------ - valid_times: ndarray or datetime64 object - array of dates as data type datetime64[ns] - TIME_RES: integer - hourly timesteps which the fcst. is in - default = 6 - time_idx: int - array of prediction timedelta of same shape - as valid_dates, should be in hours, - data type = int. - step_type: str - type of fcst step e.g., D for day - default: 'h' for hour - - Outputs - ------- - fcst_dates: ndarray - i.e., valid_dates-time_idx - valid_date_idx: ndarray - to select - """ - - if not isinstance(time_idx, list): - time_offset = np.timedelta64(time_idx, step_type) - valid_date_idx = np.asarray([int(time_offset.astype(int) / TIME_RES)]) - else: - time_offset = [np.timedelta64(t_idx, step_type) for t_idx in time_idx] - valid_date_idx = np.asarray( - [int(t_offset.astype(int) / TIME_RES) for t_offset in time_offset] - ) - - fcst_times = valid_times - time_offset - - return fcst_times, valid_date_idx
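Finally, a worked example of the lead-time bookkeeping in match_fcst_to_valid_time above (again assuming the live utils module matches this checkpoint): a 30 h lead time with the 6 h TIME_RES maps a valid time back to its initialisation time and to step index 5.

    import numpy as np
    from xarray_batcher.utils import match_fcst_to_valid_time

    valid_times = np.array(["2018-01-02T12:00"], dtype="datetime64[ns]")
    fcst_times, valid_date_idx = match_fcst_to_valid_time(valid_times, 30)
    # fcst_times     -> ['2018-01-01T06:00...']  (valid time minus 30 h)
    # valid_date_idx -> [5]                       (30 h / 6 h per step)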