diff --git a/test_xbatcher.ipynb b/test_xbatcher.ipynb index 1525653..ef7bc69 100644 --- a/test_xbatcher.ipynb +++ b/test_xbatcher.ipynb @@ -2,10 +2,19 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "6fcec189-ecb8-4bc3-81c2-a858cb4f7a5a", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/n/nath/nobackups/miniforge3/envs/torch_24/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "import xarray_batcher as xb" ] @@ -49,7 +58,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "e3824ef0-1d4b-4932-b568-f8905f93e360", "metadata": {}, "outputs": [], @@ -60,28 +69,104 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "cdc49def-ea1d-4ab9-94db-a8061093f126", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Only getting truth values over, 376 dates\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████| 376/376 [00:19<00:00, 19.46it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Finished retrieving truth values in ---- 20.60894227027893 s---- for years [2018]\n" + ] + } + ], "source": [ "truth_ds = get_all([2018],model='truth')" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, + "id": "d4e94238-e2ae-4b54-96dc-7c8e84423efb", + "metadata": {}, + "outputs": [], + "source": [ + "from xarray_batcher.loading import load_hires_constants\n", + "\n", + "constants = load_hires_constants(4)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, "id": "7194092b-1ced-48b6-881d-cbde68d6a5bf", "metadata": {}, "outputs": [], "source": [ - "import xbatcher\n", + "import importlib\n", + "importlib.reload(xb)\n", "\n", "batch_size=[1,128,128]\n", "\n", - "y_generator = xbatcher.BatchGenerator(truth_ds,\n", - " {\"time\": batch_size[0], \"lat\": batch_size[1], \"lon\": batch_size[2]},\n", - " input_overlap={\"lat\": int(batch_size[1]/32), \"lon\": int(batch_size[2]/32)})\n" + "data_generator = xb.StreamDataset(truth_ds,['cp', 'mcc', 'sp', 'ssr', 't2m', 'tciw', 'tclw', 'tcrw', 'tcw', 'tcwv', 'tp','u700','v700'],\n", + " constants,batch_size=[4,128,128])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "cd8f5d87-5655-44ef-bc58-c190948ae855", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 3.97 s, sys: 813 ms, total: 4.78 s\n", + "Wall time: 4.76 s\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "X,y = next(data_generator.__iter__())" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "b2f77a4b-cb35-4e83-ae27-30ce0f866d22", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([4, 128, 128, 1])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y.shape" ] }, { @@ -93,7 +178,9 @@ "source": [ "%%time\n", "\n", - "get_all([2018],model='ifs',truth_batch=y_generator[0],stream=True,offset=24,variables=['tp','cp','u700'])" + "get_all([2018],model='ifs',truth_batch=y_generator[0],stream=True,offset=24,\n", + " variables=['cp', 'u700', 'tp'])\n", + "\n" ] }, { @@ -121,7 +208,7 @@ { "cell_type": "code", "execution_count": null, - "id": "75d4787a-82b7-4374-b61a-fd59a8fdf880", + "id": "df9e88f2-5647-43f0-a99a-f20e21a9fbc5", "metadata": {}, "outputs": [], "source": [] diff --git a/xarray_batcher/__init__.py b/xarray_batcher/__init__.py index 86b18df..772e0e6 100644 --- a/xarray_batcher/__init__.py +++ b/xarray_batcher/__init__.py @@ -1,7 +1,8 @@ ## Initialisation for xarray batcher, import all helper functions import sys - from .setup_data import DataModule +from .torch_batcher import BatchDataset, BatchTruth +from .torch_streamer import StreamDataset, StreamTruth -__all__ = ["DataModule"] \ No newline at end of file +__all__ = ["DataModule", "BatchDataset", "BatchTruth", "StreamDataset", "StreamTruth"] diff --git a/xarray_batcher/create_npz.py b/xarray_batcher/create_npz.py index 309f531..586c336 100644 --- a/xarray_batcher/create_npz.py +++ b/xarray_batcher/create_npz.py @@ -38,8 +38,10 @@ def collate_fn(batch, elev=elev, reg_dict={}): "elevation": elev_values, "spherical_coords": spherical_coords, "precipitation": [], + "time": [], } reg_dict[reg_sel]["precipitation"].append(batch.precipitation.values) + reg_dict[reg_sel]["time"].append(batch.time.values) return reg_dict @@ -81,10 +83,12 @@ def TruthDataloader_to_Npz( precipitation = np.stack( [np.vstack(reg_dict[key]["precipitation"]) for key in reg_dict.keys()] ) + time = np.stack([np.hstack(reg_dict[key]["time"]) for key in reg_dict.keys()]) np.savez( out_path + f"{year}_30min_IMERG_Nairobi_windowsize={window_size}.npz", spherical_coords=spherical_coords, elevation=elevation, precipitation=precipitation, + time=time, ) diff --git a/xarray_batcher/get_fcst_and_truth.py b/xarray_batcher/get_fcst_and_truth.py index 68ad71d..633d2f5 100644 --- a/xarray_batcher/get_fcst_and_truth.py +++ b/xarray_batcher/get_fcst_and_truth.py @@ -70,7 +70,10 @@ def open_mfzarr( combined = datasets[0] # or some combination of concat, merge dates = [] for dataset in combined: - dates += list(np.unique(dataset.time.values.astype("datetime64[D]"))) + if time_idx is None: + dates += list(np.unique(dataset.time.values.astype("datetime64[D]"))) + else: + dates += list(np.unique(dataset.time.values)) return combined, dates @@ -154,6 +157,7 @@ def modify( ds = streamline_and_normalise_ifs( name[1].split("_")[0], ds, time_idx=time_idx, split_steps=split_steps ).to_dataset(name=name[1].split("_")[0]) + return ds else: @@ -189,11 +193,11 @@ def stream_ifs(truth_batch, offset=24, variables=None): # Get hours in the truth_batch times object # First need to convert to format with base units of hours to extract hour offset - hour = batch_time.astype("datetime64[h]").astype(object)[0].hour + hour = [time.hour for time in batch_time.astype("datetime64[h]").astype(object)] # Note that if hour is 0 then we add 24 as this # is the offset+24 - hour = hour + 24 * (hour == 0) + offset + hour = [h + 24 * (h == 0) + offset for h in hour] fcst_date, time_idx = match_fcst_to_valid_time(batch_time, hour) @@ -215,7 +219,7 @@ def stream_ifs(truth_batch, offset=24, variables=None): time_idx=time_idx, clip_to_window=False, ) - assert batch_time.astype("datetime64[D]") == np.unique(dates_modified) + assert np.all(np.isin(dates_modified, batch_time)) return ds diff --git a/xarray_batcher/loading.py b/xarray_batcher/loading.py index 3fb746c..704fbef 100644 --- a/xarray_batcher/loading.py +++ b/xarray_batcher/loading.py @@ -213,13 +213,16 @@ def streamline_and_normalise_ifs( xr.DataArray or xr.Dataset of streamline and normalised values - NOTE: We replace the time with the valid time NOT initia + NOTE: We replace the time with the valid time NOT initial fcst + time. """ all_data_mean = da[f"{field}_mean"].values all_data_sd = da[f"{field}_sd"].values + da.close() + if time_idx is None: times = np.hstack( ( @@ -230,6 +233,8 @@ def streamline_and_normalise_ifs( ) ) else: + if time_idx.shape[0] % 4 == 0: + time_idx = time_idx.reshape(-1, 4) assert da.fcst_valid_time.values.shape[0] == time_idx.shape[0] times = np.hstack( ( @@ -253,26 +258,36 @@ def streamline_and_normalise_ifs( else: for i_row, start in enumerate(time_idx): - - data.append( - retrieve_vars_ifs( - field, - all_data_mean[[i_row]], - all_data_sd[[i_row]], - start=start, - end=start + 1, + if isinstance(start, np.ndarray): + for s in start: + data.append( + retrieve_vars_ifs( + field, + all_data_mean[[i_row]], + all_data_sd[[i_row]], + start=s, + end=s + 1, + ) + ) + else: + data.append( + retrieve_vars_ifs( + field, + all_data_mean[[i_row]], + all_data_sd[[i_row]], + start=start, + end=start + 1, + ) ) - ) data = np.hstack((data)).reshape(-1, da.latitude.shape[0], da.longitude.shape[0], 4) - da = xr.DataArray( data=data, dims=["time", "lat", "lon", "i_x"], coords=dict( lon=da.longitude.values, lat=da.latitude.values, - time=times, + time=times.flatten(), i_x=np.arange(4), ), ) diff --git a/xarray_batcher/torch_batcher.py b/xarray_batcher/torch_batcher.py index 1e13ea6..fbc3628 100644 --- a/xarray_batcher/torch_batcher.py +++ b/xarray_batcher/torch_batcher.py @@ -20,7 +20,7 @@ def __init__( X, y, constants, - batch_size: List[int] = [4, 128, 128], + batch_size: list[int] = [4, 128, 128], weighted_sampler: bool = True, for_NJ: bool = False, for_val: bool = False, @@ -90,7 +90,8 @@ def __getitem__(self, idx): np.concatenate( X_batch, axis=-1, - )).float() + ) + ).float() constant_batch = torch.from_numpy( np.stack( @@ -101,7 +102,8 @@ def __getitem__(self, idx): for constant in self.constants ], axis=-1, - )).float() + ) + ).float() if self.for_NJ: @@ -132,7 +134,8 @@ def __getitem__(self, idx): torch.from_numpy( y_batch.precipitation.fillna(0).values.reshape( self.batch_size[0], -1, 1 - )).float(), + ) + ).float(), X_batch.reshape(self.batch_size[0], -1, len(self.variables) * 4), ), dim=-1, @@ -164,7 +167,8 @@ def __getitem__(self, idx): else: y_batch = torch.from_numpy( - y_batch.precipitation.fillna(np.log10(0.02)).values[:, :, :, None]).float() + y_batch.precipitation.fillna(np.log10(0.02)).values[:, :, :, None] + ).float() return (torch.cat((X_batch, constant_batch), dim=-1), y_batch) @@ -200,11 +204,14 @@ def __init__( else {"lat": int(batch_size[1] // 8), "lon": int(batch_size[2] // 8)} ) self.y_generator = xbatcher.BatchGenerator( - y, - {"time": batch_size[0], - "latitude" if for_NJ else "lat": batch_size[1], "longitude" if for_NJ else "lon": batch_size[2]}, - input_overlap=overlap, - ) + y, + { + "time": batch_size[0], + "latitude" if for_NJ else "lat": batch_size[1], + "longitude" if for_NJ else "lon": batch_size[2], + }, + input_overlap=overlap, + ) if weighted_sampler: if self.for_NJ: diff --git a/xarray_batcher/torch_streamer.py b/xarray_batcher/torch_streamer.py index e69de29..618757e 100644 --- a/xarray_batcher/torch_streamer.py +++ b/xarray_batcher/torch_streamer.py @@ -0,0 +1,380 @@ +import dask +import numpy as np +import torch +import xbatcher +from scipy.spatial import KDTree + +from xarray_batcher.get_fcst_and_truth import get_all + +from .batch_helper_functions import Antialiasing, get_spherical + + +class StreamDataset(torch.utils.data.IterableDataset): + + """ + Similar as BatchDataset, see torch_batcher.py apart + from the new workflow to assist in streaming: + + 1) Start using only truth data + 2) Calculate sampler + 3) When iterating through truth, load in the fcst. + data on-the-fly + + """ + + def __init__( + self, + y, + variables, + constants, + batch_size: list[int] = [4, 128, 128], + batches_per_epoch=1200, + weighted_sampler: bool = True, + for_NJ: bool = False, + for_val: bool = False, + antialiasing: bool = False, + ): + self.batch_size = batch_size + self.batches_per_epoch = batches_per_epoch + self.variables = variables + self.y_generator = xbatcher.BatchGenerator( + y, + {"time": batch_size[0], "lat": batch_size[1], "lon": batch_size[2]}, + input_overlap={ + "lat": int(batch_size[1] / 32), + "lon": int(batch_size[2] / 32), + }, + ) + constants["lat"] = np.round(y.lat.values, decimals=2) + constants["lon"] = np.round(y.lon.values, decimals=2) + + self.constants_generator = constants + + self.constants = list(constants.data_vars) + self.for_NJ = for_NJ + self.for_val = for_val + self.antialiasing = antialiasing + + if weighted_sampler: + y_train = [ + self.y_generator[i].precipitation.mean( + ["time", "lat", "lon"], skipna=False + ) + for i in range(len(self.y_generator)) + ] + + rounded_y_train = np.round(y_train, decimals=0) + unique_classes = np.unique(rounded_y_train) + class_sample_count = np.bincount( + np.digitize(rounded_y_train, unique_classes) - 1 + ) + weight = 1.0 / class_sample_count + sample_weights = weight[np.digitize(rounded_y_train, unique_classes) - 1] + sample_weights = sample_weights / np.sum(sample_weights) + + self.sample_weights = torch.from_numpy(np.asarray(sample_weights)) + else: + self.sample_weights = None + + self.len = len(self.y_generator) + + def __len__(self): + return self.batches_per_epoch + + def __iter__(self): + self.idx = 0 + while self.idx <= self.__len__(): + try: + yield self.__sample__() + self.idx += 1 + except: + continue + + def __sample__(self): + + if self.sample_weights is None: + idx_samp = np.random.randint(0, self.len) + else: + idx_samp = int(np.random.choice(self.len, p=self.sample_weights)) + + y_batch = self.y_generator[idx_samp] + time_batch = y_batch.time.values + lat_batch = np.round(y_batch.lat.values, decimals=2) + lon_batch = np.round(y_batch.lon.values, decimals=2) + + X_generator = get_all( + None, + model="ifs", + truth_batch=y_batch, + stream=True, + offset=24, + variables=self.variables, + ) + + X_batch = [] + for x, variable in zip(X_generator, self.variables): + X_batch.append(x[variable].values) + + X_batch = torch.from_numpy( + np.concatenate( + X_batch, + axis=-1, + ) + ).float() + + constant_batch = torch.from_numpy( + np.stack( + [ + self.constants_generator[constant] + .sel({"lat": lat_batch, "lon": lon_batch}) + .values + for constant in self.constants + ], + axis=-1, + ) + ).float() + + if self.for_NJ: + + elev_values = np.squeeze(constant_batch[:, :, 0]).reshape(-1, 1) + lat_values, lon_values = np.meshgrid(lat_batch, lon_batch) + spherical_coords = get_spherical( + lat_values.reshape(-1, 1), lon_values.reshape(-1, 1), elev_values + ) + + kdtree = KDTree(spherical_coords) + + pairs = [] + + for i_coord, coord in enumerate(spherical_coords): + pairs.append( + np.vstack( + ( + np.full(3, fill_value=i_coord).reshape(1, -1), + kdtree.query(coord, k=3)[1], + ) + ) + ) + + pairs = np.hstack((pairs)) + + rainfall_path = torch.cat( + ( + torch.from_numpy( + y_batch.precipitation.fillna(0).values.reshape( + self.batch_size[0], -1, 1 + ) + ).float(), + X_batch.reshape(self.batch_size[0], -1, len(self.variables) * 4), + ), + dim=-1, + ) + obs_dates = np.ones(self.batch_size[0]).reshape(1, -1) + n_obs = np.array([self.batch_size[0]]) + if self.for_val: + obs_dates = np.zeros(self.batch_size[0]).reshape(1, -1) + n_obs = np.random.randint(1, self.batch_size[0] - 8, 1) + obs_dates[: n_obs[0]] = 1 + + return { + "idx": idx, + "rainfall_path": rainfall_path[None, :, :, :], + "observed_dates": obs_dates, + "nb_obs": n_obs, + "dt": 1, + "edge_indices": pairs, + "obs_noise": None, + } + + else: + + if self.antialiasing: + antialiaser = Antialiasing() + y_batch = y_batch.precipitation.fillna(np.log10(0.02)).values + y_batch = antialiaser(y_batch) + y_batch = torch.from_numpy(np.moveaxis(y_batch, 0, -1)).float() + + else: + y_batch = torch.from_numpy( + y_batch.precipitation.fillna(np.log10(0.02)).values[:, :, :, None] + ).float() + return (torch.cat((X_batch, constant_batch), dim=-1), y_batch) + + +class StreamTruth(torch.utils.data.Dataset): + + """ + class for iterating over a dataset + """ + + def __init__( + self, + y, + batch_size=[4, 128, 128], + weighted_sampler=True, + for_NJ=False, + for_val=False, + length=None, + antialiasing=False, + transform=None, + return_dataset=False, + ): + + self.batch_size = batch_size + self.for_NJ = for_NJ + self.for_val = for_val + self.length = length + self.antialiasing = antialiasing + self.transform = transform + self.return_dataset = return_dataset + overlap = ( + {"latitude": int(batch_size[1] - 8), "longitude": int(batch_size[2] - 8)} + if for_NJ + else {"lat": int(batch_size[1] // 8), "lon": int(batch_size[2] // 8)} + ) + self.y_generator = xbatcher.BatchGenerator( + y, + { + "time": batch_size[0], + "latitude" if for_NJ else "lat": batch_size[1], + "longitude" if for_NJ else "lon": batch_size[2], + }, + input_overlap=overlap, + ) + + if weighted_sampler: + if self.for_NJ: + y_train = [ + self.y_generator[i].mean( + ["time", "latitude", "longitude"], skipna=False + ) + for i in range(len(self.y_generator)) + ] + else: + y_train = [ + self.y_generator[i].precipitation.mean( + ["time", "lat", "lon"], skipna=False + ) + for i in range(len(self.y_generator)) + ] + rounded_y_train = np.round(y_train, decimals=1) + unique_classes = np.unique(rounded_y_train) + class_sample_count = np.bincount( + np.digitize(rounded_y_train, unique_classes) - 1 + ) + weight = 1.0 / class_sample_count + samples_weight = weight[np.digitize(rounded_y_train, unique_classes) - 1] + + self.samples_weight = torch.from_numpy(np.asarray(samples_weight)) + self.sampler = torch.utils.data.WeightedRandomSampler( + self.samples_weight.type("torch.DoubleTensor"), len(samples_weight) + ) + + def __len__(self) -> int: + return len(self.y_generator) + + def __getitem__(self, idx): + + y_batch = self.y_generator[idx] + + if self.return_dataset: + return y_batch + + if self.for_NJ: + + def generate(y_batch, length, stop=None): + + rng = np.random.default_rng() + random_year = rng.choice(np.unique(y_batch["time.year"].values), 1)[0] + ds_sel = y_batch.sel( + { + "time": slice( + "%i-01-01" % random_year, "%i-01-01" % (random_year + 1) + ) + } + ) + + time_of_event = rng.choice(ds_sel.time.values[length:-length], 1)[0] + + time_to_event = rng.choice(np.arange(length), 1)[0] + time_after_event = length - time_to_event - 1 + + rainfall_path = ds_sel.sel( + { + "time": slice( + time_of_event - np.timedelta64(time_to_event * 30, "m"), + time_of_event + np.timedelta64(time_after_event * 30, "m"), + ), + } + ) + times_rainfall = rainfall_path.time.values + rainfall_path = rainfall_path.fillna(0).values[None, :, :, :] + + if stop is not None: + # limit observations to once a day + nb_obs_single = stop + obs_ptr = np.arange(1, nb_obs_single) + + else: + nb_obs_single = length + obs_ptr = np.arange(1, length) + + observed_date = np.zeros(rainfall_path.shape[1]) + observed_date[0] = 1 + + for i_obs in obs_ptr: + + observed_date[i_obs] = 1 + + return rainfall_path, observed_date, nb_obs_single + + rainfall_paths = [] + observed_dates = [] + n_obs = [] + + stop = None + batch_size = 50 + if self.for_val: + rng = np.random.default_rng() + stop = rng.choice(np.arange(2, self.length - 100), 1)[0] + batch_size = 2 + + for i in range(batch_size): + rainfall_path, observed_date, nb_obs = generate( + y_batch, self.length, stop=stop + ) + rainfall_paths.append(rainfall_path) + observed_dates.append(observed_date) + n_obs.append(nb_obs) + + rainfall_paths = np.vstack(rainfall_paths) + observed_dates = np.stack(observed_dates) + n_obs = np.asarray(n_obs) + + return { + "idx": idx, + "rainfall_path": torch.tensor( + rainfall_paths[:, :, :, :, None], dtype=torch.float32 + ), + "observed_dates": observed_dates, + "nb_obs": n_obs, + "dt": 1, + "obs_noise": None, + } + + else: + if self.antialiasing: + antialiaser = Antialiasing() + y_batch = y_batch.precipitation.fillna(np.log10(0.02)).values + y_batch = antialiaser(y_batch) + y_batch = torch.tensor(np.moveaxis(y_batch, 0, -1), dtype=torch.float32) + + else: + y_batch = torch.tensor( + y_batch.precipitation.fillna(np.log10(0.02)).values[:, :, :, None], + dtype=torch.float32, + ) + if self.transform: + y_batch = self.transform(y_batch) + + return y_batch diff --git a/xarray_batcher/utils.py b/xarray_batcher/utils.py index 978319b..50ef143 100644 --- a/xarray_batcher/utils.py +++ b/xarray_batcher/utils.py @@ -154,9 +154,15 @@ def match_fcst_to_valid_time(valid_times, time_idx, step_type="h"): to select """ - time_offset = np.timedelta64(time_idx, step_type) - fcst_times = valid_times - time_offset + if not isinstance(time_idx, list): + time_offset = np.timedelta64(time_idx, step_type) + valid_date_idx = np.asarray([int(time_offset.astype(int) / TIME_RES)]) + else: + time_offset = [np.timedelta64(t_idx, step_type) for t_idx in time_idx] + valid_date_idx = np.asarray( + [int(t_offset.astype(int) / TIME_RES) for t_offset in time_offset] + ) - valid_date_idx = np.asarray([int(time_offset.astype(int) / TIME_RES)]) + fcst_times = valid_times - time_offset return fcst_times, valid_date_idx