diff --git a/notebooks/Gallery.ipynb b/notebooks/Gallery.ipynb index 47a615ce..644e323c 100644 --- a/notebooks/Gallery.ipynb +++ b/notebooks/Gallery.ipynb @@ -35,12 +35,14 @@ "\n", "These tutorials can be run on a 4GB GPU using relatively low volumes of data (3-10GB). They will also work in HPC environments.\n", "\n", - "| Title | Description | Image | Notebooks | Last Tested |\n", + "| Topic | Description | Image | Notebooks | Last Tested |\n", "|-------|--------------|-------|-------------|-------------|\n", "| **Simplified weather model** | Train a reduced-size weather model on a standard GPU with fetchable dataset | ![Image showing FourCastMini prediction outputs](https://pyearthtools.readthedocs.io/en/latest/_images/notebooks_tutorial_FourCastMini_Demo_18_1.png) | [Train and run a simplified global weather model (low hardware and data requirements)](./tutorial/FourCastMini_Demo.ipynb) | 18 Aug 2025 |\n", "| **MLX Demo** | Shows how to integrate PyEarthTools with a non-PyTorch framework (Apple MLX) optimised for M-series chips | ![Image showing weather model outputs from MLX demo](https://pyearthtools.readthedocs.io/en/latest/_images/notebooks_tutorial_MLX-Demo-Custom-Arch_13_1.png) | [MLX Framework Example](./tutorial/MLX-Demo-Custom-Arch.ipynb) | 8 Jun 2025 | \n", "| **Convolutional Neural Net on ERA5** | Shows all steps to train a CNN on ERA5, running on CPU or a standard GPU | ![Image showing weather model outputs](https://pyearthtools.readthedocs.io/en/latest/_images/notebooks_tutorial_CNN-Model-Training_44_1.png) | [End-to-end CNN Training Example](./tutorial/CNN-Model-Training.ipynb) | 25 Aug 2025 |\n", - "| **Radar Visualisation** | Shows how to visualise radar data as a time-series, in 2D and in 3D | ![Image showing a top down view of radar data](https://pyearthtools.readthedocs.io/en/latest/_images/notebooks_RadarVisualisation_10_1.png) | [Radar Visualisation](./RadarVisualisation.ipynb) | 23 Aug 2025 |\n" + "| **Radar Visualisation** | Shows how to 
visualise radar data as a time-series, in 2D and in 3D | ![Image showing a top down view of radar data](https://pyearthtools.readthedocs.io/en/latest/_images/notebooks_RadarVisualisation_10_1.png) | [Radar Visualisation](./RadarVisualisation.ipynb) | 23 Aug 2025 |\n", + "| **LUCIE Climate Model** | Train a climate model | (no image) | [LUCIE-Training](./tutorial/LUCIE/LUCIE-Training.ipynb) | 13 Nov 2025 |\n", + "| **LUCIE Climate Model** | Make predictions from a climate model | (no image) | [LUCIE-Inference](./tutorial/LUCIE/LUCIE-Inference.ipynb) | 13 Nov 2025 |\n" ] }, { diff --git a/notebooks/tutorial/LUCIE/LUCIE-Inference.ipynb b/notebooks/tutorial/LUCIE/LUCIE-Inference.ipynb new file mode 100644 index 00000000..b8be428a --- /dev/null +++ b/notebooks/tutorial/LUCIE/LUCIE-Inference.ipynb @@ -0,0 +1,147 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "175e2165-f568-48e2-bedb-1245603b1ab5", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import lucie\n", + "import lucie.inference\n", + "from pathlib import Path\n", + "import numpy as np\n", + "import xarray as xr" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "bc8460bd-8691-403f-8cbb-dbeb4e39875a", + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device(\"mps\" if torch.backends.mps.is_available() else \"cpu\")\n", + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else device)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "e5000929-4fec-4b6c-82c1-a5769287aa38", + "metadata": {}, + "outputs": [], + "source": [ + "regridded_path = Path.home() / 'dev/data/lucie' / 'era5_T30_regridded.npz'\n", + "regridded_data = lucie.train.load_data(regridded_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "3facc4a4-ec3f-4fc8-be29-c244ec4268e2", + "metadata": {}, + "outputs": [], + "source": [ + "preprocessed_path = Path.home() / 'dev/data/lucie' / 
'era5_T30_preprocessed.npz'\n", + "preprocessed_data = np.load(preprocessed_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "fb5d5b70-681c-4cc4-b973-7b259bb6e81d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 1min 49s, sys: 30.3 s, total: 2min 19s\n", + "Wall time: 1min 54s\n" + ] + } + ], + "source": [ + "%%time\n", + "%%capture\n", + "\n", + "# Note - these timings were obtained on a laptop, not on a high-performance GPU.\n", + "\n", + "predictions = lucie.inference.load_data_and_predict(device, regridded_data, preprocessed_data,model_weights_pth='model.pth')" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "9f9e517e-a6e6-4cff-bc82-fe0cff6c89ec", + "metadata": {}, + "outputs": [], + "source": [ + "da = xr.DataArray(predictions)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "e0f847a4-edd0-4dc4-9a2c-278c5df6127e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Please note - this image was generated from only a few samples of training and does not represent the final model\n", + "da[5][0].plot()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25598079-8d9b-4893-a807-cfe1c50d35b8", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/tutorial/LUCIE/LUCIE-Training.ipynb b/notebooks/tutorial/LUCIE/LUCIE-Training.ipynb new file mode 100644 index 00000000..8aaf970a --- /dev/null +++ b/notebooks/tutorial/LUCIE/LUCIE-Training.ipynb @@ -0,0 +1,194 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a3575c36-ed8d-4bae-ab90-aefe441949f9", + "metadata": {}, + "source": [ + "# Training the LUCIE model\n", + "\n", + "LUCIE is a climate model developed by Haiwen Guan, Troy Arcomano, Ashesh Chattopadhyay and Romit Maulik (2024). See their preprint at https://doi.org/10.48550/arXiv.2405.16297 and the archive of their training data, code and results here https://doi.org/10.5281/zenodo.15164648.\n", + "\n", + "The code in PyEarthTools was based on their code repository at https://github.com/ISCLPennState/LUCIE, which is made available under the MIT license (see the PyEarthTools NOTICE file for full information on this point)\n", + "\n", + "LUCIE is a model which of interest to climate researchers due to its long-term stability for rollouts for many decades. 
This model is licensed in a compatible fashion, so we are able to provide a bundled, customised version of LUCIE which can be used within the PyEarthTools framework, integrated with its data pipelines and configurable to work flexibly.\n", + "\n", + "We have only just begun the process of this integration, and so for now the model does not make extensive use of the PyEarthTools classes. This is expected to change fairly quickly, and as this happens, this notebook will be updated. However, in the interests of providing the bundled version to the community as soon as possible for those already seeking to work with the model, we present it in a \"work in progress\" fashion.\n", + "\n", + "You need to manually download the original published dataset from Zenodo, and update the paths in this notebook to point to them. The initial focus will be on reproducing the paper fairly closely using the same data and only slightly modified code (changes to support more devices and updates for compatibility), true enough to the original. 
Subsequently, we will develop the code further to be adaptable to new data sources.\n", + "\n", + "The intention is to:\n", + " - [done] Supply the source code to train and run the model in PyEarthTools\n", + " - [done] Validate that the model can train without obvious code-level errors\n", + " - Validate inference and reproduce the training results to ensure the trained model is valid\n", + " - Support library updates and other changes\n", + " - Support multiple ML backends beyond CUDA\n", + " - Support connection to multiple data sources through PET data accessors\n", + " - Move the normalisation into a PET pipeline so it can be easily modified and experimented with\n", + "\n", + "If you would like to know more, or get involved with this work, please [let us know on the issue tracker](https://github.com/ACCESS-Community-Hub/PyEarthTools/issues/211)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "e5068eca-cfcc-4dec-bf88-8b1fb870dc3b", + "metadata": {}, + "outputs": [], + "source": [ + "import lucie\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "f69a338a-ff4e-465f-a664-cd76630baa52", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b81e2c08-bf62-49fc-9090-0595cbfd24ab", + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device(\"mps\" if torch.backends.mps.is_available() else \"cpu\")\n", + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else device)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "4180dd8c-ff64-466b-b3bc-9771b2053a57", + "metadata": {}, + "outputs": [], + "source": [ + "regridded_path = Path.home() / 'dev/data/lucie' / 'era5_T30_regridded.npz'" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "7f5ca64a-87c8-4cae-a2e3-3a4788066a73", + "metadata": {}, + "outputs": [], + "source": [ + 
"regridded_data = lucie.train.load_data(regridded_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b53d4754-4303-4325-801a-afa626aac582", + "metadata": {}, + "outputs": [], + "source": [ + "preprocessed_path = Path.home() / 'dev/data/lucie' / 'era5_T30_preprocessed.npz'\n", + "preprocessed_data = np.load(preprocessed_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "22148013-b8d6-40c7-8c11-9f8545295b85", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting Training\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/2 [00:00 0: - out = torch.sum( - ugrid[..., polar_opt:-polar_opt, :] * quad_weights[polar_opt:-polar_opt] * dlon * radius**2, dim=(-2, -1) - ) - else: - out = torch.sum(ugrid * quad_weights * dlon * radius**2, dim=(-2, -1)) - return out - - -def l2loss_sphere(prd, tar, relative=False, squared=True): - loss = integrate_grid((prd - tar) ** 2, dimensionless=True).sum(dim=-1) - if relative: - loss = loss / integrate_grid(tar**2, dimensionless=True).sum(dim=-1) - - if not squared: - loss = torch.sqrt(loss) - loss = loss.mean() - - return loss - - -def train_model( - model, - train_loader, - val_loader, - optimizer, - scheduler=None, - nepochs=20, - nfuture=0, - num_examples=256, - num_valid=8, - reg_rate=0, -): - - infer_bias = 1e80 - recall_count = 0 - for epoch in tqdm(range(nepochs)): - if epoch < 149: - if scheduler is not None: - scheduler.step() - else: - for param_group in optimizer.param_groups: - param_group["lr"] = 1e-6 - - optimizer.zero_grad() - - model.train() - batch_num = 0 - for inp, tar in train_loader: - batch_num += 1 - loss = 0 - - inp = inp.to(device) - tar = tar.to(device) - prd = model(inp) - - loss_delta = l2loss_sphere(prd[:, :5, :, :], tar[:, :5, :, :], relative=True) - loss_tp = torch.mean((prd[:, 5:, :, :] - tar[:, 5:, :, :]) ** 2) - loss = loss_delta + loss_tp / tar.shape[1] - - lat_index 
= np.r_[7:15, 32:40] - # lat_index = np.r_[0:48] - # quad_weight_reg = quad_weights.reshape(1,1,48,1)[:,:,lat_index,:] - out_fft = torch.mean(torch.abs(torch.fft.rfft(prd[:, :, lat_index, :], dim=3)), dim=2) - target_fft = torch.mean(torch.abs(torch.fft.rfft(tar[:, :, lat_index, :], dim=3)), dim=2) - loss_reg = 0.05 * torch.mean(torch.abs(out_fft - target_fft)) - - if epoch > 150: - loss = loss + loss_reg - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - if epoch % 10 == 0: - rollout_steps = 2920 - rollout = torch.tensor( - inference( - model, - rollout_steps, - data_inp[0:1].to(device), - data_inp[:1460, -2:].to(device), - 1, - prog_means, - prog_stds, - diag_means, - diag_stds, - diff_stds, - ) - ).to(device) - rollout_clim = torch.mean(rollout[1460:], dim=0) - clim_bias = torch.mean(torch.abs(rollout_clim - true_clim)) - print("2 year rollout bias", clim_bias) - if epoch > 60: - if clim_bias <= infer_bias: - infer_bias = clim_bias - torch.save(model.state_dict(), "regular_training_checkpoint.pth") - recall_count = 0 - else: - state_pth = torch.load("regular_training_checkpoint.pth") - model.load_state_dict(state_pth) - recall_count += 1 - if recall_count > 3: - break - - -data = load_data("era5_T30_regridded.npz")[..., :6] -true_clim = torch.tensor(np.mean(data, axis=0)).to(device).permute(2, 0, 1) - -data = np.load("era5_T30_preprocessed.npz") # standardized data with mean and stds generated from dataset_generator.py -data_inp = torch.tensor(data["data_inp"], dtype=torch.float32) # input data -data_tar = torch.tensor(data["data_tar"], dtype=torch.float32) -raw_means = torch.tensor(data["raw_means"], dtype=torch.float32).reshape(1, -1, 1, 1).to(device) -raw_stds = torch.tensor(data["raw_stds"], dtype=torch.float32).reshape(1, -1, 1, 1).to(device) -prog_means = raw_means[:, :5] -prog_stds = raw_stds[:, :5] -diag_means = torch.tensor(data["diag_means"], dtype=torch.float32).reshape(1, -1, 1, 1).to(device) -diag_stds = 
torch.tensor(data["diag_stds"], dtype=torch.float32).reshape(1, -1, 1, 1).to(device) -diff_means = torch.tensor(data["diff_means"], dtype=torch.float32).reshape(1, -1, 1, 1).to(device) -diff_stds = torch.tensor(data["diff_stds"], dtype=torch.float32).reshape(1, -1, 1, 1).to(device) - -ntrain = 16000 -nval = 100 - -train_set = TensorDataset(data_inp[:ntrain], data_tar[:ntrain]) -val_set = TensorDataset(data_inp[ntrain : ntrain + nval], data_tar[ntrain : ntrain + nval]) - -train_loader = DataLoader(train_set, batch_size=16, shuffle=True) -val_loader = DataLoader(val_set, batch_size=4, shuffle=False) - - -grid = "legendre-gauss" -nlat = 48 -nlon = 96 -hard_thresholding_fraction = 0.9 -lmax = ceil(nlat / 1) -mmax = lmax -modes_lat = int(nlat * hard_thresholding_fraction) -modes_lon = int(nlon // 2 * hard_thresholding_fraction) -modes_lat = modes_lon = min(modes_lat, modes_lon) -sht = RealSHT(nlat, nlon, lmax=modes_lat, mmax=modes_lon, grid=grid, csphase=False) -radius = 6.37122e6 -cost, quad_weights = legendre_gauss_weights(nlat, -1, 1) -quad_weights = (torch.as_tensor(quad_weights).reshape(-1, 1)).to(device) - -model = SphericalFourierNeuralOperatorNet( - params={}, - spectral_transform="sht", - filter_type="linear", - operator_type="dhconv", - img_shape=(48, 96), - num_layers=8, - in_chans=7, - out_chans=6, - scale_factor=1, - embed_dim=72, - activation_function="silu", - big_skip=True, - pos_embed="latlon", - use_mlp=True, - normalization_layer="instance_norm", - hard_thresholding_fraction=hard_thresholding_fraction, - mlp_ratio=2.0, -).to(device) - -optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0) -scheduler = CosineAnnealingLR(optimizer, T_max=150, eta_min=1e-5) -train_model(model, train_loader, val_loader, optimizer, scheduler=scheduler, nepochs=500) -torch.save(model.state_dict(), "model.pth") diff --git a/packages/bundled_models/lucie/README.md b/packages/bundled_models/lucie/README.md index 3eebfac5..dca21111 100644 --- 
a/packages/bundled_models/lucie/README.md +++ b/packages/bundled_models/lucie/README.md @@ -1,13 +1,17 @@ # LUCIE: Lightweight Uncoupled ClImate Emulator -Please note - this is a fork of https://github.com/ISCLPennState/LUCIE which has been adapted included in PyEarthTools for the purposes of maintenance, compatbility and to supply an integrated approach to using the LUCIE model within the PyEarthTools framework. +Please note - this is an adaptation of https://github.com/ISCLPennState/LUCIE which has been modified for inclusion in PyEarthTools for the purposes of maintenance, compatibility and to supply an integrated approach to using the LUCIE model within the PyEarthTools framework. + +This code was copied from the LUCIE repository from commit hash 19a1d6ebe844f49893f92e8b377ebdca8f6aa0e6 (Jul 9th, 2025). --- ## Paper & Data -- [arXiv Preprint: arxiv.org/abs/2405.16297](https://arxiv.org/abs/2405.16297) -- [Zenodo Archive: zenodo.org/records/15164648](https://zenodo.org/records/15164648) +These are the links for the original paper, code and data published by the LUCIE authors. The code was published to Zenodo under a Creative Commons license but the license in their GitHub repository was MIT to allow improved code re-use. + +- [arXiv Preprint: https://doi.org/10.48550/arXiv.2405.16297](https://doi.org/10.48550/arXiv.2405.16297) +- [Zenodo Archive: https://doi.org/10.5281/zenodo.15164648](https://doi.org/10.5281/zenodo.15164648) --- @@ -22,4 +26,4 @@ This repository prvides the following: 5. The data generator file that precprocesses the regridded ERA5 data. ## Note -Please refer to the zenodo link for the regridded ERA5 data. The link also includes the preprocessed data from the data generator file. +Please refer to the LUCIE zenodo link for the regridded ERA5 data. The link also includes the preprocessed data from the data generator file. 
diff --git a/packages/bundled_models/lucie/pyproject.toml b/packages/bundled_models/lucie/pyproject.toml new file mode 100644 index 00000000..34fbb687 --- /dev/null +++ b/packages/bundled_models/lucie/pyproject.toml @@ -0,0 +1,55 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + + +[project] +name = "pyearthtools-bundled-lucie" +version = "0.6.0" +description = "LUCIE Bundled Model" +readme = "README.md" +requires-python = ">=3.11, <3.14" +keywords = ["lucie"] +maintainers = [ + {name = "Tennessee Leeuwenburg", email = "tennessee.leeuwenburg@bom.gov.au"} +] +classifiers = [ + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] +dependencies = [ + 'pyearthtools.training[lightning]>=0.5.1', + 'pyearthtools.zoo>=0.5.1', + 'pyearthtools.data>=0.5.1', + 'pyearthtools.pipeline>=0.5.1', + 'torch_optimizer', + 'timm', +] + + +[project.urls] +homepage = "https://pyearthtools.readthedocs.io/" +documentation = "https://pyearthtools.readthedocs.io/" +repository = "https://github.com/ACCESS-Community-Hub/PyEarthTools" + +[tool.isort] +profile = "black" + +[tool.black] +line-length = 120 + +[tool.mypy] +warn_return_any = true +warn_unused_configs = true + +[[tool.mypy.overrides]] +ignore_missing_imports = true + +[tool.hatch.version] +path = "src/lucie/__init__.py" + +[tool.hatch.build.targets.wheel] +packages = ["src/lucie/"] diff --git a/packages/bundled_models/lucie/src/lucie/__init__.py b/packages/bundled_models/lucie/src/lucie/__init__.py new file mode 100644 index 00000000..95a13f6c --- /dev/null +++ b/packages/bundled_models/lucie/src/lucie/__init__.py @@ -0,0 +1,2 @@ +from lucie import train +from lucie import torch_harmonics_local diff --git a/packages/bundled_models/lucie/dataset_generator.py b/packages/bundled_models/lucie/src/lucie/dataset_generator.py 
similarity index 100% rename from packages/bundled_models/lucie/dataset_generator.py rename to packages/bundled_models/lucie/src/lucie/dataset_generator.py diff --git a/packages/bundled_models/lucie/LUCIE_inference.py b/packages/bundled_models/lucie/src/lucie/inference.py similarity index 86% rename from packages/bundled_models/lucie/LUCIE_inference.py rename to packages/bundled_models/lucie/src/lucie/inference.py index a36c102f..93dcd756 100644 --- a/packages/bundled_models/lucie/LUCIE_inference.py +++ b/packages/bundled_models/lucie/src/lucie/inference.py @@ -28,13 +28,14 @@ # import torch_harmonics.distributed as thd # from torch_harmonics import * +# from torch._C import float32 import torch.fft from tqdm import tqdm import torch -from torch_harmonics_local import * +from lucie.torch_harmonics_local import * device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -42,8 +43,9 @@ torch.cuda.set_device(0) -def inference( - model, steps, initial_frame, forcing, initial_forcing_idx, prog_means, prog_stds, diag_means, diag_stds, diff_stds +def infer(device, + model, steps, initial_frame, forcing, initial_forcing_idx, + prog_means, prog_stds, diag_means, diag_stds, diff_stds ): inf_data = [] model.eval() @@ -76,14 +78,17 @@ def inference( return inf_data -if __name__ == "__main__": +def load_data_and_predict( + device, + regridded_data, + preprocessed_data, # standardised data generated by dataset_generator.py + model_weights_pth='model.pth', + ): - data = load_data("era5_T30_regridded.npz")[..., :6] - true_clim = torch.tensor(np.mean(data, axis=0)).to(device).permute(2, 0, 1) + regridded_data = regridded_data[..., :6] + true_clim = torch.tensor(np.mean(regridded_data, axis=0)).to(device).permute(2, 0, 1) - data = np.load( - "era5_T30_preprocessed.npz" - ) # standardized data with mean and stds generated from dataset_generator.py + data = preprocessed_data # dictionary-like numpy array data_inp = torch.tensor(data["data_inp"], dtype=torch.float32) # 
input data data_tar = torch.tensor(data["data_tar"], dtype=torch.float32) raw_means = torch.tensor(data["raw_means"], dtype=torch.float32).reshape(1, -1, 1, 1).to(device) @@ -107,7 +112,8 @@ def inference( sht = RealSHT(nlat, nlon, lmax=modes_lat, mmax=modes_lon, grid=grid, csphase=False) radius = 6.37122e6 cost, quad_weights = legendre_gauss_weights(nlat, -1, 1) - quad_weights = (torch.as_tensor(quad_weights).reshape(-1, 1)).to(device) + quad_weights = (torch.as_tensor(quad_weights).reshape(-1, 1)).to(torch.float32).to(device) + # quad_weights = (torch.as_tensor(quad_weights).reshape(-1, 1)).to(device) model = SphericalFourierNeuralOperatorNet( params={}, @@ -129,7 +135,7 @@ def inference( mlp_ratio=2.0, ).to(device) - path = torch.load("regular_8x72_fftreg_baseline.pth") + path = torch.load(model_weights_pth) model.load_state_dict(path) forcing = data_inp[:1460, -2:] # repeating tisr and constant oro @@ -137,7 +143,8 @@ def inference( rollout_step = 14600 initial_frame_idx = 16000 + 100 forcing_initial_idx = (16000 + 100) % 1460 + 1 - rollout = inference( + rollout = infer( + device, model, rollout_step, data_inp[initial_frame_idx].unsqueeze(0).to(device), @@ -149,3 +156,5 @@ def inference( diag_stds, diff_stds, ) + + return rollout diff --git a/packages/bundled_models/lucie/torch_harmonics_local.py b/packages/bundled_models/lucie/src/lucie/torch_harmonics_local.py similarity index 99% rename from packages/bundled_models/lucie/torch_harmonics_local.py rename to packages/bundled_models/lucie/src/lucie/torch_harmonics_local.py index 18024be2..e7b4e08e 100644 --- a/packages/bundled_models/lucie/torch_harmonics_local.py +++ b/packages/bundled_models/lucie/src/lucie/torch_harmonics_local.py @@ -11,7 +11,7 @@ # from torch_harmonics import * import torch.nn.functional as F import torch.fft -from torch.cuda import amp +from torch import amp # was from torch.cuda import amp import math import logging @@ -1158,7 +1158,7 @@ def forward(self, x): # pragma: no cover x = 
x.float() B, C, H, W = x.shape - with amp.autocast(enabled=False): + with amp.autocast(str(device), enabled=False): x = self.forward_transform(x) if self.scale_residual: x = x.contiguous() @@ -1179,7 +1179,7 @@ def forward(self, x): # pragma: no cover # x = self._contract(x, self.weight, separable=self.separable, operator_type=self.operator_type) # x = x.contiguous() - with amp.autocast(enabled=False): + with amp.autocast(str(device), enabled=False): x = self.inverse_transform(x) if hasattr(self, "bias"): diff --git a/packages/bundled_models/lucie/src/lucie/train.py b/packages/bundled_models/lucie/src/lucie/train.py new file mode 100644 index 00000000..623ae39f --- /dev/null +++ b/packages/bundled_models/lucie/src/lucie/train.py @@ -0,0 +1,282 @@ +# MIT License + +# Copyright (c) 2025 ISCLPennState + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
from math import ceil

import numpy as np
import torch
import torch.fft
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# The bundled ``lucie.torch_harmonics_local`` module is used in place of the
# upstream ``torch_harmonics`` package. The star import is kept deliberately:
# this module uses names it exports (``legendre_gauss_weights``,
# ``SphericalFourierNeuralOperatorNet``) and callers presumably reach further
# re-exported helpers through this module (e.g. ``lucie.train.load_data``).
from lucie.torch_harmonics_local import *  # noqa: F401,F403
from lucie import inference

# Sphere radius used when an integral is NOT requested in dimensionless form.
_SPHERE_RADIUS = 6.37122e6

# File used to keep (and roll back to) the best weights seen during training.
_CHECKPOINT_PATH = "regular_training_checkpoint.pth"

# Latitude rows over which the spectral (FFT) regulariser is evaluated.
_REG_LAT_INDEX = np.r_[7:15, 32:40]

# Fallbacks applied when the corresponding keyword is passed as ``None``.
_DEFAULT_N_EPOCHS = 500
_DEFAULT_NTRAIN = 16000
_DEFAULT_NVAL = 100


def integrate_grid(ugrid, nlon, quad_weights, dimensionless=False, polar_opt=0, radius=_SPHERE_RADIUS):
    """
    Integrate a gridded field over the sphere using quadrature weights.

    The field is multiplied by ``quad_weights`` and by the longitude spacing
    ``2 * pi / nlon`` and summed over its last two axes.

    args:
        ugrid: tensor whose last two axes are summed over (assumed to be
            latitude then longitude - TODO confirm against callers)
        nlon: number of longitude points, used to compute the longitude spacing
        quad_weights: quadrature weights, broadcast against ``ugrid``
        dimensionless: when True the result is not scaled by ``radius ** 2``
        polar_opt: when > 0, this many rows are dropped from each end of the
            second-to-last axis (and of ``quad_weights``) before summing
        radius: sphere radius applied when ``dimensionless`` is False

    returns:
        tensor with the last two axes of ``ugrid`` reduced away
    """
    dlon = 2 * torch.pi / nlon
    # The original code read an unbound local ``radius`` here and therefore
    # raised UnboundLocalError for every dimensionless=False call.
    scale = 1 if dimensionless else radius
    if polar_opt > 0:
        ugrid = ugrid[..., polar_opt:-polar_opt, :]
        quad_weights = quad_weights[polar_opt:-polar_opt]
    return torch.sum(ugrid * quad_weights * dlon * scale**2, dim=(-2, -1))


def l2loss_sphere(prd, tar, nlon, quad_weights, relative=False, squared=True):
    """
    Quadrature-weighted L2 loss between a prediction and a target.

    args:
        prd: predicted tensor
        tar: target tensor, same shape as ``prd``
        nlon: number of longitude points (forwarded to ``integrate_grid``)
        quad_weights: quadrature weights (forwarded to ``integrate_grid``)
        relative: divide the error integral by the integral of ``tar ** 2``
        squared: when False, the square root is taken before averaging

    returns:
        scalar tensor: the mean of the per-sample loss
    """
    loss = integrate_grid((prd - tar) ** 2, nlon, quad_weights, dimensionless=True).sum(dim=-1)
    if relative:
        loss = loss / integrate_grid(tar**2, nlon, quad_weights, dimensionless=True).sum(dim=-1)
    if not squared:
        loss = torch.sqrt(loss)
    return loss.mean()


def _spectral_loss(prd, tar, weight=0.05):
    """Weighted mean absolute difference of the longitude-FFT amplitudes over the regulariser rows."""
    out_fft = torch.mean(torch.abs(torch.fft.rfft(prd[:, :, _REG_LAT_INDEX, :], dim=3)), dim=2)
    target_fft = torch.mean(torch.abs(torch.fft.rfft(tar[:, :, _REG_LAT_INDEX, :], dim=3)), dim=2)
    return weight * torch.mean(torch.abs(out_fft - target_fft))


def train_model(
    device,
    model,
    train_loader,
    val_loader,
    optimizer,
    data_inp=None,
    prog_means=None,
    prog_stds=None,
    diag_means=None,
    diag_stds=None,
    diff_stds=None,
    nlon=96,
    scheduler=None,
    nepochs=20,
    debug_sample_limit=5,
    quad_weights=None,
    true_clim=None,
    nfuture=0,
    num_examples=256,
    num_valid=8,
    reg_rate=0,
):
    """
    Train your own weights for the LUCIE model

    The model is updated in place; nothing is returned. Every tenth epoch a
    rollout is produced with ``inference.infer`` and its bias against
    ``true_clim`` is printed. After epoch 60 that bias also drives
    checkpointing: an improvement saves the weights to
    ``regular_training_checkpoint.pth``, anything else reloads the saved
    weights, and training stops after more than three consecutive reloads.

    args:
        device: torch device the batches and the rollout are moved to
        model: network to train
        train_loader: iterable of ``(input, target)`` batches
        val_loader: accepted for interface compatibility; currently unused
        optimizer: optimiser updating ``model``
        data_inp: full input tensor; provides the rollout initial frame and forcing (required)
        prog_means, prog_stds, diag_means, diag_stds, diff_stds: normalisation
            statistics forwarded unchanged to ``inference.infer``
        nlon: number of longitude points, used by the loss
        scheduler: optional learning-rate scheduler, stepped once per epoch
            for epochs below 149; from epoch 149 the rate is pinned to 1e-6
        nepochs: maximum number of epochs
        debug_sample_limit: maximum number of batches consumed per epoch;
            ``None`` consumes the whole loader. NOTE the default of 5 truncates
            every epoch, so pass ``None`` for a real training run.
        quad_weights: quadrature weights for the loss (required)
        true_clim: reference climatology the rollout mean is compared to (required)
        nfuture, num_examples, num_valid, reg_rate: accepted for interface
            compatibility; currently unused

    raises:
        ValueError: if ``data_inp``, ``quad_weights`` or ``true_clim`` is missing
    """
    required = {"data_inp": data_inp, "quad_weights": quad_weights, "true_clim": true_clim}
    missing = [name for name, value in required.items() if value is None]
    if missing:
        raise ValueError(f"train_model requires the following arguments: {', '.join(missing)}")

    best_bias = 1e80
    recall_count = 0

    print("Starting Training")
    for epoch in tqdm(range(nepochs)):

        # The schedule is advanced at the start of the epoch, matching the
        # training script this function was adapted from.
        if epoch < 149:
            if scheduler is not None:
                scheduler.step()
        else:
            for param_group in optimizer.param_groups:
                param_group["lr"] = 1e-6

        model.train()

        # The spectral regulariser only enters the loss late in training, so
        # it is only computed when it is actually used.
        use_spectral_loss = epoch > 150

        for batch_num, (inp, tar) in enumerate(train_loader, start=1):
            # A limit of None means "no limit"; comparing an int with None
            # would raise TypeError.
            if debug_sample_limit is not None and batch_num > debug_sample_limit:
                break

            inp = inp.to(device)
            tar = tar.to(device)
            prd = model(inp)

            # Channels 0-4 use the relative spherical L2 loss; the remaining
            # channel(s) use a plain mean-squared error.
            loss_delta = l2loss_sphere(prd[:, :5, :, :], tar[:, :5, :, :], nlon, quad_weights, relative=True)
            loss_tp = torch.mean((prd[:, 5:, :, :] - tar[:, 5:, :, :]) ** 2)
            loss = loss_delta + loss_tp / tar.shape[1]

            if use_spectral_loss:
                loss = loss + _spectral_loss(prd, tar)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % 10 != 0:
            continue

        # 2920 steps (per paper); the message below reports this as a two
        # year rollout and only steps 1460 onward enter the climatology.
        rollout_steps = 2920
        rollout = torch.tensor(
            inference.infer(
                device,
                model,
                rollout_steps,
                data_inp[0:1].to(device),
                data_inp[:1460, -2:].to(device),
                1,
                prog_means,
                prog_stds,
                diag_means,
                diag_stds,
                diff_stds,
            )
        ).to(device)
        rollout_clim = torch.mean(rollout[1460:], dim=0)
        clim_bias = torch.mean(torch.abs(rollout_clim - true_clim))
        print("2 year rollout bias", clim_bias)

        if epoch > 60:
            if clim_bias <= best_bias:
                best_bias = clim_bias
                torch.save(model.state_dict(), _CHECKPOINT_PATH)
                recall_count = 0
            else:
                # map_location keeps the reload on the training device.
                state_pth = torch.load(_CHECKPOINT_PATH, map_location=device)
                model.load_state_dict(state_pth)
                recall_count += 1
                if recall_count > 3:
                    break


def load_data_and_train(
    device,
    regridded_data,
    preprocessed_data,
    *,
    debug_sample_limit: int | None = None,
    n_epochs: int | None = 500,
    ntrain: int | None = 16000,
    nval: int | None = 100,
):
    """
    Build the LUCIE network, prepare the data loaders and train the model.

    args:
        device: torch device to train on
        regridded_data: array of regridded fields. Only the first six entries
            of the last axis are used; they are averaged over axis 0 to form
            the reference climatology (assumes a (time, lat, lon, channel)
            layout - TODO confirm).
        preprocessed_data: dictionary or numpy collection containing
            'data_inp', 'data_tar', 'raw_means', 'raw_stds', 'diag_means',
            'diag_stds' and 'diff_stds'
        debug_sample_limit: maximum number of batches per epoch; ``None``
            (the default) uses every batch
        n_epochs: maximum number of epochs; ``None`` selects the default of 500
        ntrain: number of leading samples used for training; ``None`` selects
            the default of 16000
        nval: number of samples following the training samples placed in the
            validation loader; ``None`` selects the default of 100

    returns:
        the trained model. The final weights are not written to disk here;
        save ``model.state_dict()`` if they are needed later.
    """
    n_epochs = _DEFAULT_N_EPOCHS if n_epochs is None else n_epochs
    ntrain = _DEFAULT_NTRAIN if ntrain is None else ntrain
    nval = _DEFAULT_NVAL if nval is None else nval

    regridded_data = regridded_data[..., :6]
    true_clim = torch.tensor(np.mean(regridded_data, axis=0)).to(device).permute(2, 0, 1)

    data = preprocessed_data  # dictionary-like numpy array

    def _stat(key):
        # Per-channel statistic reshaped so that it broadcasts over 4-D batches.
        return torch.tensor(data[key], dtype=torch.float32).reshape(1, -1, 1, 1).to(device)

    data_inp = torch.tensor(data["data_inp"], dtype=torch.float32)  # input data
    data_tar = torch.tensor(data["data_tar"], dtype=torch.float32)
    prog_means = _stat("raw_means")[:, :5]
    prog_stds = _stat("raw_stds")[:, :5]
    diag_means = _stat("diag_means")
    diag_stds = _stat("diag_stds")
    diff_stds = _stat("diff_stds")

    train_set = TensorDataset(data_inp[:ntrain], data_tar[:ntrain])
    val_set = TensorDataset(data_inp[ntrain : ntrain + nval], data_tar[ntrain : ntrain + nval])

    train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=4, shuffle=False)

    nlat = 48
    nlon = 96
    hard_thresholding_fraction = 0.9

    _, quad_weights = legendre_gauss_weights(nlat, -1, 1)

    # mps only supports float32, so the weights are always cast to float32.
    # TODO: Experimentally verify the impact of using float32 here by default vs float64 performance
    quad_weights = (torch.as_tensor(quad_weights).reshape(-1, 1)).to(torch.float32).to(device)

    model = SphericalFourierNeuralOperatorNet(
        params={},
        spectral_transform="sht",
        filter_type="linear",
        operator_type="dhconv",
        img_shape=(nlat, nlon),
        num_layers=8,
        in_chans=7,
        out_chans=6,
        scale_factor=1,
        embed_dim=72,
        activation_function="silu",
        big_skip=True,
        pos_embed="latlon",
        use_mlp=True,
        normalization_layer="instance_norm",
        hard_thresholding_fraction=hard_thresholding_fraction,
        mlp_ratio=2.0,
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0)
    scheduler = CosineAnnealingLR(optimizer, T_max=150, eta_min=1e-5)
    train_model(
        device,
        model,
        train_loader,
        val_loader,
        optimizer,
        prog_means=prog_means,
        prog_stds=prog_stds,
        diag_means=diag_means,
        diag_stds=diag_stds,
        diff_stds=diff_stds,
        true_clim=true_clim,
        data_inp=data_inp,
        nlon=nlon,
        quad_weights=quad_weights,
        scheduler=scheduler,
        nepochs=n_epochs,
        debug_sample_limit=debug_sample_limit,
    )

    return model