Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ ci:
autoupdate_commit_msg: "chore: Update pre-commit hooks"
autofix_commit_msg: "style: Pre-commit fixes"

default_language_version:
python: python3.10

repos:
- repo: meta
hooks:
Expand Down
113 changes: 59 additions & 54 deletions aurora/model/aurora.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,8 @@ def __init__(
surf_stats (dict[str, tuple[float, float]], optional): For these surface-level
variables, adjust the normalisation to the given tuple consisting of a new location
and scale.
bf16_mode (bool, optional): To reduce memory usage, convert the tokens to BF16, run
the backbone in pure BF16, and run the decoder in FP16 AMP. This should enable a
gradient computation. USE AT YOUR OWN RISK. THIS WAS NOT USED DURING THE DEVELOPMENT
OF AURORA AND IS PURELY PROVIDED AS A STARTING POINT FOR FINE-TUNING.
autocast (bool, optional): To reduce memory usage, `torch.autocast` only the backbone
to BF16. This is critical to enable fine-tuning.
level_condition (tuple[int | float, ...], optional): Make the patch embeddings dependent
on pressure level. If you want to enable this feature, provide a tuple of all
possible pressure levels.
Expand Down Expand Up @@ -228,6 +226,7 @@ def __init__(
embed_dim=embed_dim,
mlp_ratio=mlp_ratio,
drop_path_rate=drop_path,
attn_drop_rate=drop_rate,
drop_rate=drop_rate,
use_lora=use_lora,
lora_steps=lora_steps,
Expand All @@ -252,18 +251,16 @@ def __init__(
modulation_heads=modulation_heads,
)

if autocast and not bf16_mode:
if bf16_mode and not autocast:
warnings.warn(
"The argument `autocast` no longer does anything due to limited utility. "
"Consider instead using `bf16_mode`.",
"`bf16_mode` was removed, because it caused serious issues for gradient "
"computation. `bf16_mode` now automatically activates `autocast`, which will not "
"save as much memory, but should be much more stable.",
stacklevel=2,
)
autocast = True

self.bf16_mode = bf16_mode

if self.bf16_mode:
# We run the backbone in pure BF16.
self.backbone.to(torch.bfloat16)
self.autocast = autocast

def forward(self, batch: Batch) -> Batch:
"""Forward pass.
Expand Down Expand Up @@ -327,44 +324,30 @@ def forward(self, batch: Batch) -> Batch:
lead_time=self.timestep,
)

# In BF16 mode, the backbone is run in pure BF16.
if self.bf16_mode:
x = x.to(torch.bfloat16)
x = self.backbone(
x,
lead_time=self.timestep,
patch_res=patch_res,
rollout_step=batch.metadata.rollout_step,
)

# In BF16 mode, the decoder is run in AMP PF16, and the output is converted back to FP32.
# We run in PF16 as opposed to BF16 for improved relative precision.
if self.bf16_mode:
device_type = (
"cuda"
if torch.cuda.is_available()
else "xpu"
if torch.xpu.is_available()
else "cpu"
)
context = torch.autocast(device_type=device_type, dtype=torch.float16)
x = x.to(torch.float16)
if self.autocast:
if torch.cuda.is_available():
device_type = "cuda"
elif torch.xpu.is_available():
device_type = "xpu"
else:
device_type = "cpu"
context = torch.autocast(device_type=device_type, dtype=torch.bfloat16)
else:
context = contextlib.nullcontext()
with context:
pred = self.decoder(
x = self.backbone(
x,
batch,
lead_time=self.timestep,
patch_res=patch_res,
rollout_step=batch.metadata.rollout_step,
)
if self.bf16_mode:
pred = dataclasses.replace(
pred,
surf_vars={k: v.float() for k, v in pred.surf_vars.items()},
static_vars={k: v.float() for k, v in pred.static_vars.items()},
atmos_vars={k: v.float() for k, v in pred.atmos_vars.items()},
)

pred = self.decoder(
x,
batch,
lead_time=self.timestep,
patch_res=patch_res,
)

# Remove batch and history dimension from static variables.
pred = dataclasses.replace(
Expand Down Expand Up @@ -520,27 +503,49 @@ def adapt_checkpoint_max_history_size(self, checkpoint: dict[str, torch.Tensor])

checkpoint[name] = new_weight

def configure_activation_checkpointing(self):
def configure_activation_checkpointing(
self,
module_names: tuple[str, ...] = (
"Basic3DDecoderLayer",
"Basic3DEncoderLayer",
"LinearPatchReconstruction",
"Perceiver3DDecoder",
"Perceiver3DEncoder",
"Swin3DTransformerBackbone",
"Swin3DTransformerBlock",
),
) -> None:
"""Configure activation checkpointing.

This is required in order to compute gradients without running out of memory.

Args:
module_names (tuple[str, ...], optional): Names of the modules to checkpoint
on.

Raises:
RuntimeError: If any module specified in `module_names` was not found and
thus could not be checkpointed.
"""
# Checkpoint these modules:
module_names = (
"Perceiver3DEncoder",
"Swin3DTransformerBackbone",
"Basic3DEncoderLayer",
"Basic3DDecoderLayer",
"Perceiver3DDecoder",
"LinearPatchReconstruction",
)

found: set[str] = set()

def check(x: torch.nn.Module) -> bool:
name = x.__class__.__name__
return name in module_names
if name in module_names:
found.add(name)
return True
else:
return False

apply_activation_checkpointing(self, check_fn=check)

if found != set(module_names):
raise RuntimeError(
f'Could not checkpoint on the following modules: '
f'{", ".join(sorted(set(module_names) - found))}.'
)


class AuroraPretrained(Aurora):
"""Pretrained version of Aurora."""
Expand Down
4 changes: 2 additions & 2 deletions aurora/model/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(
embed_dim: int = 1024,
num_heads: int = 16,
head_dim: int = 64,
drop_rate: float = 0.1,
drop_rate: float = 0.0,
depth: int = 2,
mlp_ratio: float = 4.0,
max_history_size: int = 2,
Expand All @@ -66,7 +66,7 @@ def __init__(
Defaults to `16`.
head_dim (int, optional): Dimension of attention heads used in aggregation blocks.
Defaults to `64`.
drop_rate (float, optional): Drop out rate for input patches. Defaults to `0.1`.
drop_rate (float, optional): Drop out rate for input patches. Defaults to `0.0`.
depth (int, optional): Number of Perceiver cross-attention and feed-forward blocks.
Defaults to `2`.
mlp_ratio (float, optional): Ratio of hidden dimensionality to embedding dimensionality
Expand Down
8 changes: 4 additions & 4 deletions aurora/model/swin3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,8 +762,8 @@ def __init__(
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
drop_rate: float = 0.0,
attn_drop_rate: float = 0.1,
drop_path_rate: float = 0.1,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
lora_steps: int = 40,
lora_mode: LoRAMode = "single",
use_lora: bool = False,
Expand All @@ -785,8 +785,8 @@ def __init__(
qkv_bias (bool): If `True`, add a learnable bias to the query, key, and value. Defaults
to `True`.
drop_rate (float): Drop-out rate. Defaults to `0.0`.
attn_drop_rate (float): Attention drop-out rate. Defaults to `0.1`.
drop_path_rate (float): Stochastic depth rate. Defaults to `0.1`.
attn_drop_rate (float): Attention drop-out rate. Defaults to `0.0`.
drop_path_rate (float): Stochastic depth rate. Defaults to `0.0`.
lora_steps (int, optional): Maximum number of LoRA roll-out steps. Defaults to `40`.
lora_mode (str, optional): LoRA mode. `"single"` uses the same LoRA for all roll-out
steps, `"from_second"` uses the same LoRA from the second roll-out step on,
Expand Down
52 changes: 45 additions & 7 deletions docs/finetuning.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,47 @@ model = AuroraPretrained()
model.load_checkpoint()
```

## Basic Fine-Tuning Environment

We provide a very basic Docker image and fine-tuning loop to get you started.
This Docker image is built from an NVIDIA PyTorch base image,
so it is tailored to NVIDIA GPUs, and has been tested on an 80 GB A100.
The image can be found at `finetuning/Dockerfile` and the fine-tuning
loop at `finetuning/finetune.py`.
Assuming that you have cloned the Aurora repository, you can build and run
the image by running the following from the root of the repository:

```bash
docker build . -t aurora:latest -f finetuning/Dockerfile
docker run --rm -it -v .:/app/aurora \
--gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \
aurora:latest
```

Then, within the image, execute

```bash
python finetuning/finetune.py
```

to run the sample fine-tuning loop.

For example, on Azure, launch a VM with size `Standard_NC24ads_A100_v4`, image
Ubuntu 24.04 LTS (x64), and 256 GB of disk space.
Then [install CUDA](https://learn.microsoft.com/en-us/azure/virtual-machines/linux/n-series-driver-setup).
Be sure to install the latest supported version of the CUDA Toolkit by
checking `nvidia-smi` after installing the drivers with
`sudo ubuntu-drivers autoinstall` and rebooting.
Best performance is achieved with CUDA Toolkit 13.0 or higher, which
requires drivers that support CUDA 13.0 or higher.
Then install Docker with `sudo apt install docker.io`,
set the right permissions for the current user with
`sudo usermod -a -G docker $USER`,
[install the NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html),
and reboot.
You should now be able to clone the repo and build and run the image using
the instructions above.

## Computing Gradients

To compute gradients, you will need an A100 with 80 GB of memory.
Expand All @@ -19,13 +60,7 @@ You can do this as follows:
```python
from aurora import AuroraPretrained

model = AuroraPretrained(
# BF16 mode is an EXPERIMENTAL mode that saves memory by running the backbone in pure BF16
# and the decoder in FP16 AMP. This should enable gradient computation. USE AT YOUR OWN RISK.
# THIS WAS NOT USED IN THE DEVELOPMENT OF AURORA AND IS PURELY PROVIDED AS A STARTING POINT
# FOR FINE-TUNING.
bf16_mode=True,
)
model = AuroraPretrained(autocast=True)
model.load_checkpoint()

batch = ... # Load some data.
Expand All @@ -39,6 +74,9 @@ loss = ...
loss.backward()
```

Here `autocast` enables AMP with `bfloat16` for only the backbone.
This is necessary to be able to fit gradients in memory.

## Exploding Gradients

When fine-tuning, you may run into very large gradient values.
Expand Down
20 changes: 20 additions & 0 deletions finetuning/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# NVIDIA's PyTorch base image ships CUDA-enabled PyTorch; `uv` is copied in from the
# official distroless image to manage the Python environment.
FROM nvcr.io/nvidia/pytorch:25.08-py3
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

WORKDIR /app
SHELL ["/bin/bash", "-c"]

# Create the environment and install the repo in editable mode.
# Only `pyproject.toml` and the license are copied, so image rebuilds are not triggered
# by source changes; the actual source tree is bind-mounted at run time (see docs).
RUN mkdir -p /app/aurora/aurora
COPY pyproject.toml LICENSE.txt /app/aurora/
# NOTE(review): this creates `__init__.py` at the repo root (/app/aurora), not inside the
# package directory created above (/app/aurora/aurora) — confirm this matches what the
# build backend expects for the editable install.
RUN touch /app/aurora/__init__.py \
    && touch /app/aurora/README.md \
    && uv venv --python 3.13 \
    && SETUPTOOLS_SCM_PRETEND_VERSION=0.0.0 uv pip install -e /app/aurora

# Use the environment automatically.
ENV VIRTUAL_ENV="/app/.venv/"
ENV PATH="/app/.venv/bin:$PATH"

# Let the user enter at `/app/aurora`.
WORKDIR /app/aurora
45 changes: 45 additions & 0 deletions finetuning/finetune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""

from datetime import datetime

import torch

from aurora import AuroraPretrained, Batch, Metadata


def loss(pred: Batch) -> torch.Tensor:
    """A sample loss function. You should replace this with your own loss function.

    Args:
        pred (Batch): Prediction to compute the loss for.

    Returns:
        torch.Tensor: Sum of squares of all surface-level and atmospheric variables.
    """
    # Bug fix: the original body read the module-level global `prediction` instead of the
    # `pred` argument, silently ignoring whatever was passed in.
    surf_values = pred.surf_vars.values()
    atmos_values = pred.atmos_vars.values()
    return sum((x * x).sum() for x in tuple(surf_values) + tuple(atmos_values))


# Build the pretrained model with `autocast=True`: the backbone runs under BF16 AMP,
# which is needed to fit gradients in memory (see docs/finetuning.md).
model = AuroraPretrained(autocast=True)
model.load_checkpoint()
# Activation checkpointing is required to compute gradients without running out of memory.
model.configure_activation_checkpointing()
model.train()
model = model.to("cuda")

opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

for i in range(10):
    print(f"Step {i}")

    # Train on random data. You should replace this with your own data.
    # Shapes: surface vars are (batch, history, lat, lon) on a 0.25-degree global grid
    # (721 x 1440); atmospheric vars additionally have a level dimension of size 13.
    batch = Batch(
        surf_vars={k: torch.randn(1, 2, 721, 1440) for k in ("2t", "10u", "10v", "msl")},
        static_vars={k: torch.randn(721, 1440) for k in ("lsm", "z", "slt")},
        atmos_vars={k: torch.randn(1, 2, 13, 721, 1440) for k in ("z", "u", "v", "t", "q")},
        metadata=Metadata(
            # Latitudes run from the north pole to the south pole inclusive.
            lat=torch.linspace(90, -90, 721),
            # Longitudes cover [0, 360) exclusive of the wrap-around endpoint.
            lon=torch.linspace(0, 360, 1440 + 1)[:-1],
            time=(datetime(2020, 6, 1, 12, 0),),
            atmos_levels=(50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000),
        ),
    )

    # Standard training step: forward, loss, backward, update.
    opt.zero_grad()
    prediction = model(batch.to("cuda"))
    loss_value = loss(prediction)
    loss_value.backward()
    opt.step()