Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 99 additions & 12 deletions test_xbatcher.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,19 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "6fcec189-ecb8-4bc3-81c2-a858cb4f7a5a",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/n/nath/nobackups/miniforge3/envs/torch_24/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import xarray_batcher as xb"
]
Expand Down Expand Up @@ -49,7 +58,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "e3824ef0-1d4b-4932-b568-f8905f93e360",
"metadata": {},
"outputs": [],
Expand All @@ -60,28 +69,104 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"id": "cdc49def-ea1d-4ab9-94db-a8061093f126",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Only getting truth values over, 376 dates\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████████████████████████████████████████████████████████████| 376/376 [00:19<00:00, 19.46it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Finished retrieving truth values in ---- 20.60894227027893 s---- for years [2018]\n"
]
}
],
"source": [
"truth_ds = get_all([2018],model='truth')"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"id": "d4e94238-e2ae-4b54-96dc-7c8e84423efb",
"metadata": {},
"outputs": [],
"source": [
"from xarray_batcher.loading import load_hires_constants\n",
"\n",
"constants = load_hires_constants(4)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7194092b-1ced-48b6-881d-cbde68d6a5bf",
"metadata": {},
"outputs": [],
"source": [
"import xbatcher\n",
"import importlib\n",
"importlib.reload(xb)\n",
"\n",
"batch_size=[1,128,128]\n",
"\n",
"y_generator = xbatcher.BatchGenerator(truth_ds,\n",
" {\"time\": batch_size[0], \"lat\": batch_size[1], \"lon\": batch_size[2]},\n",
" input_overlap={\"lat\": int(batch_size[1]/32), \"lon\": int(batch_size[2]/32)})\n"
"data_generator = xb.StreamDataset(truth_ds,['cp', 'mcc', 'sp', 'ssr', 't2m', 'tciw', 'tclw', 'tcrw', 'tcw', 'tcwv', 'tp','u700','v700'],\n",
" constants,batch_size=[4,128,128])\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "cd8f5d87-5655-44ef-bc58-c190948ae855",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 3.97 s, sys: 813 ms, total: 4.78 s\n",
"Wall time: 4.76 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"X,y = next(data_generator.__iter__())"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "b2f77a4b-cb35-4e83-ae27-30ce0f866d22",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 128, 128, 1])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y.shape"
]
},
{
Expand All @@ -93,7 +178,9 @@
"source": [
"%%time\n",
"\n",
"get_all([2018],model='ifs',truth_batch=y_generator[0],stream=True,offset=24,variables=['tp','cp','u700'])"
"get_all([2018],model='ifs',truth_batch=y_generator[0],stream=True,offset=24,\n",
" variables=['cp', 'u700', 'tp'])\n",
"\n"
]
},
{
Expand Down Expand Up @@ -121,7 +208,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "75d4787a-82b7-4374-b61a-fd59a8fdf880",
"id": "df9e88f2-5647-43f0-a99a-f20e21a9fbc5",
"metadata": {},
"outputs": [],
"source": []
Expand Down
5 changes: 3 additions & 2 deletions xarray_batcher/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
## Initialisation for xarray batcher, import all helper functions
import sys


from .setup_data import DataModule
from .torch_batcher import BatchDataset, BatchTruth
from .torch_streamer import StreamDataset, StreamTruth

__all__ = ["DataModule"]
__all__ = ["DataModule", "BatchDataset", "BatchTruth", "StreamDataset", "StreamTruth"]
4 changes: 4 additions & 0 deletions xarray_batcher/create_npz.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ def collate_fn(batch, elev=elev, reg_dict={}):
"elevation": elev_values,
"spherical_coords": spherical_coords,
"precipitation": [],
"time": [],
}
reg_dict[reg_sel]["precipitation"].append(batch.precipitation.values)
reg_dict[reg_sel]["time"].append(batch.time.values)

return reg_dict

Expand Down Expand Up @@ -81,10 +83,12 @@ def TruthDataloader_to_Npz(
precipitation = np.stack(
[np.vstack(reg_dict[key]["precipitation"]) for key in reg_dict.keys()]
)
time = np.stack([np.hstack(reg_dict[key]["time"]) for key in reg_dict.keys()])

np.savez(
out_path + f"{year}_30min_IMERG_Nairobi_windowsize={window_size}.npz",
spherical_coords=spherical_coords,
elevation=elevation,
precipitation=precipitation,
time=time,
)
12 changes: 8 additions & 4 deletions xarray_batcher/get_fcst_and_truth.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ def open_mfzarr(
combined = datasets[0] # or some combination of concat, merge
dates = []
for dataset in combined:
dates += list(np.unique(dataset.time.values.astype("datetime64[D]")))
if time_idx is None:
dates += list(np.unique(dataset.time.values.astype("datetime64[D]")))
else:
dates += list(np.unique(dataset.time.values))

return combined, dates

Expand Down Expand Up @@ -154,6 +157,7 @@ def modify(
ds = streamline_and_normalise_ifs(
name[1].split("_")[0], ds, time_idx=time_idx, split_steps=split_steps
).to_dataset(name=name[1].split("_")[0])

return ds

else:
Expand Down Expand Up @@ -189,11 +193,11 @@ def stream_ifs(truth_batch, offset=24, variables=None):

# Get hours in the truth_batch times object
# First need to convert to format with base units of hours to extract hour offset
hour = batch_time.astype("datetime64[h]").astype(object)[0].hour
hour = [time.hour for time in batch_time.astype("datetime64[h]").astype(object)]

# Note that if hour is 0 then we add 24 as this
# is the offset+24
hour = hour + 24 * (hour == 0) + offset
hour = [h + 24 * (h == 0) + offset for h in hour]

fcst_date, time_idx = match_fcst_to_valid_time(batch_time, hour)

Expand All @@ -215,7 +219,7 @@ def stream_ifs(truth_batch, offset=24, variables=None):
time_idx=time_idx,
clip_to_window=False,
)
assert batch_time.astype("datetime64[D]") == np.unique(dates_modified)
assert np.all(np.isin(dates_modified, batch_time))
return ds


Expand Down
39 changes: 27 additions & 12 deletions xarray_batcher/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,16 @@ def streamline_and_normalise_ifs(

xr.DataArray or xr.Dataset of streamline and normalised values

NOTE: We replace the time with the valid time NOT initia
NOTE: We replace the time with the valid time NOT initial fcst
time.

"""

all_data_mean = da[f"{field}_mean"].values
all_data_sd = da[f"{field}_sd"].values

da.close()

if time_idx is None:
times = np.hstack(
(
Expand All @@ -230,6 +233,8 @@ def streamline_and_normalise_ifs(
)
)
else:
if time_idx.shape[0] % 4 == 0:
time_idx = time_idx.reshape(-1, 4)
assert da.fcst_valid_time.values.shape[0] == time_idx.shape[0]
times = np.hstack(
(
Expand All @@ -253,26 +258,36 @@ def streamline_and_normalise_ifs(
else:

for i_row, start in enumerate(time_idx):

data.append(
retrieve_vars_ifs(
field,
all_data_mean[[i_row]],
all_data_sd[[i_row]],
start=start,
end=start + 1,
if isinstance(start, np.ndarray):
for s in start:
data.append(
retrieve_vars_ifs(
field,
all_data_mean[[i_row]],
all_data_sd[[i_row]],
start=s,
end=s + 1,
)
)
else:
data.append(
retrieve_vars_ifs(
field,
all_data_mean[[i_row]],
all_data_sd[[i_row]],
start=start,
end=start + 1,
)
)
)

data = np.hstack((data)).reshape(-1, da.latitude.shape[0], da.longitude.shape[0], 4)

da = xr.DataArray(
data=data,
dims=["time", "lat", "lon", "i_x"],
coords=dict(
lon=da.longitude.values,
lat=da.latitude.values,
time=times,
time=times.flatten(),
i_x=np.arange(4),
),
)
Expand Down
27 changes: 17 additions & 10 deletions xarray_batcher/torch_batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(
X,
y,
constants,
batch_size: List[int] = [4, 128, 128],
batch_size: list[int] = [4, 128, 128],
weighted_sampler: bool = True,
for_NJ: bool = False,
for_val: bool = False,
Expand Down Expand Up @@ -90,7 +90,8 @@ def __getitem__(self, idx):
np.concatenate(
X_batch,
axis=-1,
)).float()
)
).float()

constant_batch = torch.from_numpy(
np.stack(
Expand All @@ -101,7 +102,8 @@ def __getitem__(self, idx):
for constant in self.constants
],
axis=-1,
)).float()
)
).float()

if self.for_NJ:

Expand Down Expand Up @@ -132,7 +134,8 @@ def __getitem__(self, idx):
torch.from_numpy(
y_batch.precipitation.fillna(0).values.reshape(
self.batch_size[0], -1, 1
)).float(),
)
).float(),
X_batch.reshape(self.batch_size[0], -1, len(self.variables) * 4),
),
dim=-1,
Expand Down Expand Up @@ -164,7 +167,8 @@ def __getitem__(self, idx):

else:
y_batch = torch.from_numpy(
y_batch.precipitation.fillna(np.log10(0.02)).values[:, :, :, None]).float()
y_batch.precipitation.fillna(np.log10(0.02)).values[:, :, :, None]
).float()
return (torch.cat((X_batch, constant_batch), dim=-1), y_batch)


Expand Down Expand Up @@ -200,11 +204,14 @@ def __init__(
else {"lat": int(batch_size[1] // 8), "lon": int(batch_size[2] // 8)}
)
self.y_generator = xbatcher.BatchGenerator(
y,
{"time": batch_size[0],
"latitude" if for_NJ else "lat": batch_size[1], "longitude" if for_NJ else "lon": batch_size[2]},
input_overlap=overlap,
)
y,
{
"time": batch_size[0],
"latitude" if for_NJ else "lat": batch_size[1],
"longitude" if for_NJ else "lon": batch_size[2],
},
input_overlap=overlap,
)

if weighted_sampler:
if self.for_NJ:
Expand Down
Loading