diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 365fe32..d3f71f4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,9 +2,6 @@ ci: autoupdate_commit_msg: "chore: Update pre-commit hooks" autofix_commit_msg: "style: Pre-commit fixes" -default_language_version: - python: python3.10 - repos: - repo: meta hooks: diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py index f3e3401..db3a5d2 100644 --- a/aurora/model/aurora.py +++ b/aurora/model/aurora.py @@ -145,10 +145,8 @@ def __init__( surf_stats (dict[str, tuple[float, float]], optional): For these surface-level variables, adjust the normalisation to the given tuple consisting of a new location and scale. - bf16_mode (bool, optional): To reduce memory usage, convert the tokens to BF16, run - the backbone in pure BF16, and run the decoder in FP16 AMP. This should enable a - gradient computation. USE AT YOUR OWN RISK. THIS WAS NOT USED DURING THE DEVELOPMENT - OF AURORA AND IS PURELY PROVIDED AS A STARTING POINT FOR FINE-TUNING. + autocast (bool, optional): To reduce memory usage, `torch.autocast` only the backbone + to BF16. This is critical to enable fine-tuning. level_condition (tuple[int | float, ...], optional): Make the patch embeddings dependent on pressure level. If you want to enable this feature, provide a tuple of all possible pressure levels. @@ -228,6 +226,7 @@ def __init__( embed_dim=embed_dim, mlp_ratio=mlp_ratio, drop_path_rate=drop_path, + attn_drop_rate=drop_rate, drop_rate=drop_rate, use_lora=use_lora, lora_steps=lora_steps, @@ -252,18 +251,16 @@ def __init__( modulation_heads=modulation_heads, ) - if autocast and not bf16_mode: + if bf16_mode and not autocast: warnings.warn( - "The argument `autocast` no longer does anything due to limited utility. " - "Consider instead using `bf16_mode`.", + "`bf16_mode` was removed, because it caused serious issues for gradient " + "computation. 
`bf16_mode` now automatically activates `autocast`, which will not " + "save as much memory, but should be much more stable.", stacklevel=2, ) + autocast = True - self.bf16_mode = bf16_mode - - if self.bf16_mode: - # We run the backbone in pure BF16. - self.backbone.to(torch.bfloat16) + self.autocast = autocast def forward(self, batch: Batch) -> Batch: """Forward pass. @@ -327,44 +324,30 @@ def forward(self, batch: Batch) -> Batch: lead_time=self.timestep, ) - # In BF16 mode, the backbone is run in pure BF16. - if self.bf16_mode: - x = x.to(torch.bfloat16) - x = self.backbone( - x, - lead_time=self.timestep, - patch_res=patch_res, - rollout_step=batch.metadata.rollout_step, - ) - - # In BF16 mode, the decoder is run in AMP PF16, and the output is converted back to FP32. - # We run in PF16 as opposed to BF16 for improved relative precision. - if self.bf16_mode: - device_type = ( - "cuda" - if torch.cuda.is_available() - else "xpu" - if torch.xpu.is_available() - else "cpu" - ) - context = torch.autocast(device_type=device_type, dtype=torch.float16) - x = x.to(torch.float16) + if self.autocast: + if torch.cuda.is_available(): + device_type = "cuda" + elif torch.xpu.is_available(): + device_type = "xpu" + else: + device_type = "cpu" + context = torch.autocast(device_type=device_type, dtype=torch.bfloat16) else: context = contextlib.nullcontext() with context: - pred = self.decoder( + x = self.backbone( x, - batch, lead_time=self.timestep, patch_res=patch_res, + rollout_step=batch.metadata.rollout_step, ) - if self.bf16_mode: - pred = dataclasses.replace( - pred, - surf_vars={k: v.float() for k, v in pred.surf_vars.items()}, - static_vars={k: v.float() for k, v in pred.static_vars.items()}, - atmos_vars={k: v.float() for k, v in pred.atmos_vars.items()}, - ) + + pred = self.decoder( + x, + batch, + lead_time=self.timestep, + patch_res=patch_res, + ) # Remove batch and history dimension from static variables. 
pred = dataclasses.replace( @@ -520,27 +503,49 @@ def adapt_checkpoint_max_history_size(self, checkpoint: dict[str, torch.Tensor]) checkpoint[name] = new_weight - def configure_activation_checkpointing(self): + def configure_activation_checkpointing( + self, + module_names: tuple[str, ...] = ( + "Basic3DDecoderLayer", + "Basic3DEncoderLayer", + "LinearPatchReconstruction", + "Perceiver3DDecoder", + "Perceiver3DEncoder", + "Swin3DTransformerBackbone", + "Swin3DTransformerBlock", + ), + ) -> None: """Configure activation checkpointing. This is required in order to compute gradients without running out of memory. + + Args: + module_names (tuple[str, ...], optional): Names of the modules to checkpoint + on. + + Raises: + RuntimeError: If any module specifies in `module_names` was not found and + thus could not be checkpointed. """ - # Checkpoint these modules: - module_names = ( - "Perceiver3DEncoder", - "Swin3DTransformerBackbone", - "Basic3DEncoderLayer", - "Basic3DDecoderLayer", - "Perceiver3DDecoder", - "LinearPatchReconstruction", - ) + + found: set[str] = set() def check(x: torch.nn.Module) -> bool: name = x.__class__.__name__ - return name in module_names + if name in module_names: + found.add(name) + return True + else: + return False apply_activation_checkpointing(self, check_fn=check) + if found != set(module_names): + raise RuntimeError( + f'Could not checkpoint on the following modules: ' + f'{", ".join(sorted(set(module_names) - found))}.' + ) + class AuroraPretrained(Aurora): """Pretrained version of Aurora.""" diff --git a/aurora/model/encoder.py b/aurora/model/encoder.py index 84aa1d4..c4d1988 100644 --- a/aurora/model/encoder.py +++ b/aurora/model/encoder.py @@ -41,7 +41,7 @@ def __init__( embed_dim: int = 1024, num_heads: int = 16, head_dim: int = 64, - drop_rate: float = 0.1, + drop_rate: float = 0.0, depth: int = 2, mlp_ratio: float = 4.0, max_history_size: int = 2, @@ -66,7 +66,7 @@ def __init__( Defaults to `16`. 
head_dim (int, optional): Dimension of attention heads used in aggregation blocks. Defaults to `64`. - drop_rate (float, optional): Drop out rate for input patches. Defaults to `0.1`. + drop_rate (float, optional): Drop out rate for input patches. Defaults to `0.0`. depth (int, optional): Number of Perceiver cross-attention and feed-forward blocks. Defaults to `2`. mlp_ratio (float, optional): Ratio of hidden dimensionality to embedding dimensionality diff --git a/aurora/model/swin3d.py b/aurora/model/swin3d.py index 4d4c084..cc83e72 100644 --- a/aurora/model/swin3d.py +++ b/aurora/model/swin3d.py @@ -762,8 +762,8 @@ def __init__( mlp_ratio: float = 4.0, qkv_bias: bool = True, drop_rate: float = 0.0, - attn_drop_rate: float = 0.1, - drop_path_rate: float = 0.1, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, lora_steps: int = 40, lora_mode: LoRAMode = "single", use_lora: bool = False, @@ -785,8 +785,8 @@ def __init__( qkv_bias (bool): If `True`, add a learnable bias to the query, key, and value. Defaults to `True`. drop_rate (float): Drop-out rate. Defaults to `0.0`. - attn_drop_rate (float): Attention drop-out rate. Defaults to `0.1`. - drop_path_rate (float): Stochastic depth rate. Defaults to `0.1`. + attn_drop_rate (float): Attention drop-out rate. Defaults to `0.0`. + drop_path_rate (float): Stochastic depth rate. Defaults to `0.0`. lora_steps (int, optional): Maximum number of LoRA roll-out steps. Defaults to `40`. lora_mode (str, optional): LoRA mode. `"single"` uses the same LoRA for all roll-out steps, `"from_second"` uses the same LoRA from the second roll-out step on, diff --git a/docs/finetuning.md b/docs/finetuning.md index 3a1353a..0ce9ab1 100644 --- a/docs/finetuning.md +++ b/docs/finetuning.md @@ -10,6 +10,47 @@ model = AuroraPretrained() model.load_checkpoint() ``` +## Basic Fine-Tuning Environment + +We provide a very basic Docker image and fine-tuning loop to get you started. 
+This Docker image is built from a NVIDIA PyTorch base image, +so is tailored to work for NVIDIA GPUs, and has been tested on an 80 GB A100. +The image can be found at `finetuning/Dockerfile` and the fine-tuning +loop at `finetuning/finetune.py`. +Assuming that you have cloned the Aurora repository, you can build and run +the image by running the following from the root of the repository: + +```bash +docker build . -t aurora:latest -f finetuning/Dockerfile +docker run --rm -it -v .:/app/aurora \ + --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \ + aurora:latest +``` + +Then, within the image, execute + +```bash +python finetuning/finetune.py +``` + +to run the sample fine-tuning loop. + +For example, on Azure, launch a VM with size `Standard_NC24ads_A100_v4`, image +Ubuntu 24.04 LTS (x64), and 256 GB of disk space. +Then [install CUDA](https://learn.microsoft.com/en-us/azure/virtual-machines/linux/n-series-driver-setup). +Be sure to install the latest supported version of the CUDA Toolkit by +checking `nvidia-smi` after installing the drivers with +`sudo ubuntu-drivers autoinstall` and rebooting. +Best performance is achieved with CUDA Toolkit 13.0 or higher, which +requires drivers that support CUDA 13.0 or higher. +Then install Docker with `sudo apt install docker.io`, +set the right permissions for the current user with +`sudo usermod -a -G docker $USER`, +[install the NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html), +and reboot. +You should now be able to clone the repo and build and run the image using +the instructions above. + ## Computing Gradients To compute gradients, you will need an A100 with 80 GB of memory. @@ -19,13 +60,7 @@ You can do this as follows: ```python from aurora import AuroraPretrained -model = AuroraPretrained( - # BF16 mode is an EXPERIMENTAL mode that saves memory by running the backbone in pure BF16 - # and the decoder in FP16 AMP. 
This should enable gradient computation. USE AT YOUR OWN RISK. - # THIS WAS NOT USED IN THE DEVELOPMENT OF AURORA AND IS PURELY PROVIDED AS A STARTING POINT - # FOR FINE-TUNING. - bf16_mode=True, -) +model = AuroraPretrained(autocast=True) model.load_checkpoint() batch = ... # Load some data. @@ -39,6 +74,9 @@ loss = ... loss.backward() ``` +Here `autocast` enables AMP with `bfloat16` for only the backbone. +This is necessary to be able to fit gradients in memory. + ## Exploding Gradients When fine-tuning, you may run into very large gradient values. diff --git a/finetuning/Dockerfile b/finetuning/Dockerfile new file mode 100644 index 0000000..575138b --- /dev/null +++ b/finetuning/Dockerfile @@ -0,0 +1,20 @@ +FROM nvcr.io/nvidia/pytorch:25.08-py3 +COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ + +WORKDIR /app +SHELL ["/bin/bash", "-c"] + +# Create the environment and install the repo in editable mode. +RUN mkdir -p /app/aurora/aurora +COPY pyproject.toml LICENSE.txt /app/aurora/ +RUN touch /app/aurora/__init__.py \ + && touch /app/aurora/README.md \ + && uv venv --python 3.13 \ + && SETUPTOOLS_SCM_PRETEND_VERSION=0.0.0 uv pip install -e /app/aurora + +# Use the environment automatically. +ENV VIRTUAL_ENV="/app/.venv/" +ENV PATH="/app/.venv/bin:$PATH" + +# Let the user enter at `/app/aurora`. +WORKDIR /app/aurora diff --git a/finetuning/finetune.py b/finetuning/finetune.py new file mode 100644 index 0000000..c862c42 --- /dev/null +++ b/finetuning/finetune.py @@ -0,0 +1,45 @@ +"""Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" + +from datetime import datetime + +import torch + +from aurora import AuroraPretrained, Batch, Metadata + + +def loss(pred: Batch) -> torch.Tensor: + """A sample loss function. 
You should replace this with your own loss function.""" + surf_values = pred.surf_vars.values() + atmos_values = pred.atmos_vars.values() + return sum((x * x).sum() for x in tuple(surf_values) + tuple(atmos_values)) + + +model = AuroraPretrained(autocast=True) +model.load_checkpoint() +model.configure_activation_checkpointing() +model.train() +model = model.to("cuda") + +opt = torch.optim.AdamW(model.parameters(), lr=3e-4) + +for i in range(10): + print(f"Step {i}") + + # Train on random data. You should replace this with your own data. + batch = Batch( + surf_vars={k: torch.randn(1, 2, 721, 1440) for k in ("2t", "10u", "10v", "msl")}, + static_vars={k: torch.randn(721, 1440) for k in ("lsm", "z", "slt")}, + atmos_vars={k: torch.randn(1, 2, 13, 721, 1440) for k in ("z", "u", "v", "t", "q")}, + metadata=Metadata( + lat=torch.linspace(90, -90, 721), + lon=torch.linspace(0, 360, 1440 + 1)[:-1], + time=(datetime(2020, 6, 1, 12, 0),), + atmos_levels=(50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000), + ), + ) + + opt.zero_grad() + prediction = model(batch.to("cuda")) + loss_value = loss(prediction) + loss_value.backward() + opt.step()