Dear Microsoft,
I am trying to fine-tune the Aurora model, but it runs out of memory even with pseudo data and a single forward pass.
The following is my training code:
```python
import os
from datetime import datetime, timedelta
import xarray as xr
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from aurora import AuroraPretrained, Batch, Metadata, rollout, AuroraSmallPretrained
from aurora.normalisation import locations, scales

device = torch.device("cuda:0")
print(f"device is {device}")

model = AuroraPretrained(
    surf_vars=("2t", "10u", "10v", "msl"),
    static_vars=("lsm", "z", "slt"),
    atmos_vars=("z", "u", "v", "t", "q"),
    bf16_mode=False,
).to(device)
model.load_checkpoint_local(path="ckpt/aurora-0.25-pretrained.ckpt", strict=True)

# set up training
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
criterion = torch.nn.HuberLoss(delta=1.0)  # robust vs spikes
scaler = torch.amp.GradScaler("cuda")

# pseudo data
batch = Batch(
    surf_vars={k: torch.randn(1, 2, 721, 1440).to(device) for k in ("2t", "10u", "10v", "msl")},
    static_vars={k: torch.randn(721, 1440).to(device) for k in ("lsm", "z", "slt")},
    atmos_vars={k: torch.randn(1, 2, 4, 721, 1440).to(device) for k in ("z", "u", "v", "t", "q")},
    metadata=Metadata(
        lat=torch.linspace(90, -90, 721).to(device),
        lon=torch.linspace(0, 360, 1440 + 1)[:-1].to(device),
        time=(datetime(2020, 6, 1, 12, 0),),
        atmos_levels=(100, 250, 500, 850),
    ),
)

# model forward
model.train()
pred = model.forward(batch)
# XXX
```
As you can see, I only feed pseudo data at 0.25° resolution with 4 pressure levels. When running model.forward(batch), it goes out of memory (over 80 GB).
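To make the numbers concrete, here is my rough back-of-the-envelope estimate. I am assuming the pretrained model has roughly 1.3 billion parameters (please correct me if that is wrong); the input shapes are the ones from the batch above:

```python
# Input tensors from the batch above, float32 (4 bytes per element).
surf = 4 * (1 * 2 * 721 * 1440)        # 4 surface variables
static = 3 * (721 * 1440)              # 3 static variables
atmos = 5 * (1 * 2 * 4 * 721 * 1440)   # 5 atmospheric variables, 4 levels
input_gb = (surf + static + atmos) * 4 / 1024**3
print(f"inputs: {input_gb:.2f} GiB")   # well under 1 GiB -- inputs are not the problem

# Assumed ~1.3e9 parameters. Full-precision AdamW keeps, per parameter:
# 4 B weights + 4 B gradients + 8 B moment estimates = 16 B.
params = 1.3e9
fixed_gb = params * 16 / 1024**3
print(f"weights + grads + AdamW state: {fixed_gb:.1f} GiB")
```

If this estimate is right, weights plus optimizer state alone take roughly 20 GiB, so I suspect the remaining tens of GB are transformer activations at the 721×1440 resolution.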
Therefore, I would like to ask:
- First, am I training and computing the gradients in the right way?
- How can I reduce GPU memory usage during training?
- How can I train on multiple GPUs?
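On the memory question, here is a generic PyTorch sketch of the two techniques I have in mind, activation checkpointing and gradient accumulation, shown on a small stand-in module since I cannot run the full model here. The `Block` class and all shapes are illustrative, not Aurora API:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Stand-in for a large model block; Aurora would take its place.
class Block(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        return self.net(x)

model = nn.ModuleList([Block() for _ in range(4)])
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
criterion = nn.HuberLoss(delta=1.0)

accum_steps = 4  # gradient accumulation: smaller micro-batches, same effective batch
optimizer.zero_grad(set_to_none=True)
for step in range(accum_steps):
    x = torch.randn(2, 64)
    target = torch.randn(2, 64)
    h = x
    for blk in model:
        # Checkpointing drops intermediate activations and recomputes
        # them in the backward pass, trading compute for memory.
        h = checkpoint(blk, h, use_reentrant=False)
    loss = criterion(h, target) / accum_steps  # average over micro-batches
    loss.backward()
optimizer.step()
```

Is applying something like this to Aurora the right direction, or does the package expose its own mechanism for this?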
This seems to be a major hurdle when fine-tuning Aurora.
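On the multi-GPU question, this is the standard DistributedDataParallel pattern I would try. The sketch below uses a single-process `gloo` group and a tiny stand-in model so it runs anywhere; on a real machine one would launch one process per GPU with `torchrun` and `backend="nccl"`:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process CPU setup purely for illustration; torchrun sets
# MASTER_ADDR/MASTER_PORT, RANK, and WORLD_SIZE in a real launch.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(8, 8)  # stand-in for the Aurora model
ddp_model = DDP(model)  # gradients are all-reduced across ranks in backward

x = torch.randn(4, 8)
loss = ddp_model(x).pow(2).mean()
loss.backward()
dist.destroy_process_group()
```

Would plain DDP be enough here, or does a model of this size need something like FSDP to shard the optimizer state as well?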
Best