From ca0eb805f6c967e1d765dd0a4046371f359416e8 Mon Sep 17 00:00:00 2001 From: Ruben Cartuyvels Date: Wed, 29 Oct 2025 16:17:07 +0000 Subject: [PATCH 1/3] Use flash_attn_varlen_qkvpacked_func and benchmark it --- presto/presto.py | 140 ++++++++++++++++++++++++++++-------- train_time.py | 181 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 290 insertions(+), 31 deletions(-) create mode 100644 train_time.py diff --git a/presto/presto.py b/presto/presto.py index 4f6e8e5..2d3e41f 100644 --- a/presto/presto.py +++ b/presto/presto.py @@ -1,10 +1,11 @@ import math from copy import deepcopy -from typing import Optional, Tuple, Union, cast +from typing import Literal, Optional, Tuple, Union, cast import numpy as np import torch from einops import rearrange, repeat +from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func from torch import nn from torch.jit import Final from torch.nn import functional as F @@ -17,7 +18,8 @@ class Attention(nn.Module): # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py - fast_attn: Final[bool] + # fast_attn: Final[bool] + attn: Final[Literal["sdpa", "flash", "naive"]] def __init__( self, @@ -28,13 +30,14 @@ def __init__( attn_drop=0.0, proj_drop=0.0, norm_layer=nn.LayerNorm, + attn: Literal["sdpa", "flash", "naive"] = "sdpa", ): super().__init__() assert dim % num_heads == 0, "dim should be divisible by num_heads" self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim**-0.5 - self.fast_attn = hasattr(torch.nn.functional, "scaled_dot_product_attention") # FIXME + self.attn = attn self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() @@ -45,31 +48,103 @@ def __init__( def forward(self, x, attn_mask=None): B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) - q, k, v = qkv.unbind(0) - q, k = self.q_norm(q), 
self.k_norm(k) + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim) - if self.fast_attn: + if self.attn == "flash": + # Compute sequence lengths from attention mask + # attn_mask is [B, N] where True/1 means valid token, False/0 means padding if attn_mask is not None: - attn_mask = attn_mask[:, None, None].repeat((1, self.num_heads, N, 1)) - x = F.scaled_dot_product_attention( - q, - k, - v, - # a value of True indicates that the element should take part in attention - attn_mask=attn_mask, - dropout_p=self.attn_drop.p, - ) - else: + # Count valid (non-padding) tokens per sequence + seqlens = attn_mask.sum(dim=1).int() # [B] + else: + # No padding, all tokens are valid + seqlens = torch.full((B,), N, dtype=torch.int32, device=x.device) + + # Compute cumulative sequence lengths for flash attention + # cu_seqlens[i] = total length of sequences 0..i-1, so cu_seqlens[0] = 0 + cu_seqlens = torch.cat( + [torch.zeros(1, dtype=torch.int32, device=x.device), seqlens.cumsum(dim=0)] + ).to(torch.int32) # [B+1] + + max_seqlen = seqlens.max().item() + + # Remove padding tokens to create packed tensor + # We need to flatten and select only valid tokens if attn_mask is not None: - raise NotImplementedError - q = q * self.scale - attn = q @ k.transpose(-2, -1) - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - x = attn @ v - - x = x.transpose(1, 2).reshape(B, N, C) + # Create indices for valid tokens + valid_mask = attn_mask.bool() # [B, N] + # Flatten qkv_packed and select valid tokens + qkv_packed_flat = qkv.reshape(B * N, 3, self.num_heads, self.head_dim) + valid_mask_flat = valid_mask.reshape(B * N) + qkv_packed_valid = qkv_packed_flat[ + valid_mask_flat + ] # [total_valid, 3, num_heads, head_dim] + else: + # No padding, just reshape + qkv_packed_valid = qkv.reshape(B * N, 3, self.num_heads, self.head_dim) + + # Call flash attention + x_packed = flash_attn_varlen_qkvpacked_func( + qkv_packed_valid, + cu_seqlens, + max_seqlen, + dropout_p=self.attn_drop.p if
self.training else 0.0, + softmax_scale=self.scale, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + ) # [total_valid, num_heads, head_dim] + + # Unpack result back to [B, N, num_heads, head_dim] with padding + if attn_mask is not None: + # Create output tensor with zeros for padding + x = torch.zeros( + B * N, + self.num_heads, + self.head_dim, + dtype=x_packed.dtype, + device=x_packed.device, + ) + # Fill in valid positions + x[valid_mask_flat] = x_packed + x = x.reshape(B, N, -1) + else: + x = x_packed.reshape(B, N, -1) + + else: + q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.attn == "sdpa": + if attn_mask is not None: + if torch.all(attn_mask): + attn_mask = None + else: + attn_mask = attn_mask[:, None, None].repeat((1, self.num_heads, N, 1)) + + x = F.scaled_dot_product_attention( + q, + k, + v, + # a value of True indicates that the element should take part in attention + attn_mask=attn_mask, + dropout_p=self.attn_drop.p, + ) + else: + if attn_mask is not None: + raise NotImplementedError + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2) + + x = x.reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x @@ -129,6 +204,7 @@ def __init__( init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, + attn="sdpa", ): super().__init__() self.norm1 = norm_layer(dim) @@ -140,6 +216,7 @@ def __init__( attn_drop=attn_drop, proj_drop=drop, norm_layer=norm_layer, + attn=attn, ) self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() @@ -225,6 +302,7 @@ def __init__( mlp_ratio=2, num_heads=8, max_sequence_length=24, + attn="sdpa", ): super().__init__() @@ -256,6 +334,7 @@ def __init__( mlp_ratio, qkv_bias=True, norm_layer=nn.LayerNorm, + 
attn=attn, ) for _ in range(depth) ] @@ -279,7 +358,6 @@ def __init__( self.initialize_weights() def initialize_weights(self): - pos_embed = get_sinusoid_encoding_table(self.pos_embed.shape[1], self.pos_embed.shape[-1]) self.pos_embed.data.copy_(pos_embed) @@ -334,7 +412,6 @@ def forward( month: Union[torch.Tensor, int] = 0, eval_task: bool = True, ): - if mask is None: mask = torch.zeros_like(x, device=x.device).float() @@ -432,6 +509,7 @@ def __init__( decoder_num_heads=8, mlp_ratio=2, max_sequence_length=24, + attn="sdpa", ): super().__init__() @@ -455,6 +533,7 @@ def __init__( mlp_ratio, qkv_bias=True, norm_layer=nn.LayerNorm, + attn=attn, ) for _ in range(decoder_depth) ] @@ -485,7 +564,6 @@ def __init__( self.initialize_weights() def initialize_weights(self): - pos_embed = get_sinusoid_encoding_table(self.pos_embed.shape[1], self.pos_embed.shape[-1]) self.pos_embed.data.copy_(pos_embed) @@ -606,7 +684,6 @@ def reconstruct_inputs(self, x) -> Tuple[torch.Tensor, torch.Tensor]: return torch.cat(eo_output, dim=-1), cast(torch.Tensor, dw_output) def forward(self, x, orig_indices, x_mask, month): - x = self.decoder_embed(x) x = self.add_masked_tokens(x, orig_indices, x_mask) x = self.add_embeddings(x, month) @@ -639,7 +716,6 @@ def forward( mask: Optional[torch.Tensor] = None, month: Union[torch.Tensor, int] = 0, ) -> torch.Tensor: - return self.head( self.encoder( x=x, @@ -716,7 +792,6 @@ def forward( mask: Optional[torch.Tensor] = None, month: Union[torch.Tensor, int] = 0, ) -> torch.Tensor: - # inputs are expected to be with 2 batch dimensions # (batches of images) (patches within an image) ... 
# vmap doesn't work with data dependent flows (yet) @@ -773,6 +848,7 @@ def construct( decoder_depth=2, decoder_num_heads=8, max_sequence_length=24, + attn="sdpa", ): encoder = Encoder( embedding_size=encoder_embedding_size, @@ -782,6 +858,7 @@ def construct( mlp_ratio=mlp_ratio, num_heads=encoder_num_heads, max_sequence_length=max_sequence_length, + attn=attn, ) decoder = Decoder( channel_embeddings=encoder.channel_embed, @@ -791,6 +868,7 @@ def construct( decoder_num_heads=decoder_num_heads, mlp_ratio=mlp_ratio, max_sequence_length=max_sequence_length, + attn=attn, ) return cls(encoder, decoder) diff --git a/train_time.py b/train_time.py new file mode 100644 index 0000000..4ef1a1f --- /dev/null +++ b/train_time.py @@ -0,0 +1,181 @@ +from torch.utils.benchmark import Timer + +NUM_ITERATIONS = 100 +# For seq length (in Attention forward, so after channel grouping) in [30, 55] +VARIABLE_LENGTH_RATIO = 0.5 + +setup = """ +import json +from pathlib import Path + +import torch +import torch.nn as nn +import webdataset as wds +from torch import optim + +from presto import Presto +from presto.dataops import MASK_STRATEGIES, MaskParams +from presto.dataops.dataset import ( + S1_S2_ERA5_SRTM_DynamicWorldMonthly_2020_2021, +) +from presto.model import LossWrapper + +train_url: str = "data/dw_144_mini_shard_44.tar" +device = torch.device("cuda:0") +path_to_config = Path("config") / "default.json" +model_kwargs = json.load(Path(path_to_config).open("r")) + +# ------------ Dataloaders ------------------------------------- +# Set mask_ratio to 0.0 here because it will generate masks that result in +# sequences with equal length +mask_params = MaskParams(MASK_STRATEGIES, 0.01) + + +def load_dataset(url, shuffle_on_load): + dataset = S1_S2_ERA5_SRTM_DynamicWorldMonthly_2020_2021(mask_params=mask_params) + return dataset.as_webdataset(url, shuffle_on_load) + + +train_dataset = load_dataset("data/dw_144_mini_shard_44.tar", shuffle_on_load=True) +train_dataloader = 
wds.WebLoader(train_dataset, batch_size=4096) + +mse = LossWrapper(nn.MSELoss()) + +b = next(iter(train_dataloader)) +mask, x, y, start_month = b[0].to(device), b[2].to(device), b[3].to(device), b[6] +dw_mask, x_dw, y_dw = b[1].to(device), b[4].to(device).long(), b[5].to(device).long() +latlons = b[7].to(device) +""" + +model_sdpa_setup = """ +model = Presto.construct(**model_kwargs, attn="sdpa") +model.to(device) +optimizer = optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95)) +model.train() +""" + +model_flash_setup = """ +model = Presto.construct(**model_kwargs, attn="flash") +model.to(device) +optimizer = optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95)) +model.train() +""" + +bfloat16_setup = """ +latlons = latlons.bfloat16() +x = x.bfloat16() +y = y.bfloat16() +model = model.bfloat16() +""" + +variable_length_mask_setup = f""" +B, T, C = mask.shape +total_tokens = B * T * C +num_tokens_to_mask = int(total_tokens * {VARIABLE_LENGTH_RATIO}) + +# Create flat mask and randomly select positions to mask +flat_mask = torch.zeros(total_tokens, dtype=torch.bool, device=device) +mask_indices = torch.randperm(total_tokens, device=device)[:num_tokens_to_mask] +flat_mask[mask_indices] = True + +# Reshape back to (B, T, C) +mask = ~flat_mask.reshape(B, T, C) + +# Apply mask: x is zeroed where mask is True, y is zeroed where mask is False +x = torch.where(mask, torch.zeros_like(x), x) +y = torch.where(mask, y, torch.zeros_like(y)) +""" + +forward_backward = """ +optimizer.zero_grad() +y_pred, dw_pred = model(x, mask=mask, dynamic_world=x_dw, latlons=latlons, month=start_month) +loss = mse(y_pred[mask], y[mask]) +loss.backward() +optimizer.step() +""" + +# Run these with pytorch 2.0, and without the code changes to presto.py + +# timer = Timer( +# stmt=forward_backward, +# setup=setup + model_sdpa_setup, +# label="Pytorch 2.0, scaled_dot_product_attention, on A100", +# ) +# print(timer.timeit(NUM_ITERATIONS)) + +# timer = Timer( +#
stmt=forward_backward, +# setup=setup + model_sdpa_setup + variable_length_mask_setup, +# label="Pytorch 2.0, scaled_dot_product_attention, on A100, varlen mask", +# ) +# print(timer.timeit(NUM_ITERATIONS)) + + +# timer = Timer( +# stmt=forward_backward, +# setup=setup + model_sdpa_setup + bfloat16_setup, +# label="Pytorch 2.0, scaled_dot_product_attention, on A100, bfloat16", +# ) +# print(timer.timeit(NUM_ITERATIONS)) + +# timer = Timer( +# stmt=forward_backward, +# setup=setup + model_sdpa_setup + bfloat16_setup + variable_length_mask_setup, +# label="Pytorch 2.0, scaled_dot_product_attention, on A100, bfloat16, varlen mask", +# ) +# print(timer.timeit(NUM_ITERATIONS)) + + +# run +# pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 + + +timer = Timer( + stmt=forward_backward, + setup=setup + model_sdpa_setup, + label="Pytorch 2.2, scaled_dot_product_attention, on A100", +) +print(timer.timeit(NUM_ITERATIONS)) + +timer = Timer( + stmt=forward_backward, + setup=setup + model_sdpa_setup + variable_length_mask_setup, + label="Pytorch 2.2, scaled_dot_product_attention, on A100, varlen mask", +) +print(timer.timeit(NUM_ITERATIONS)) + + +timer = Timer( + stmt=forward_backward, + setup=setup + model_sdpa_setup + bfloat16_setup, + label="Pytorch 2.2, scaled_dot_product_attention, on A100, bfloat16", +) +print(timer.timeit(NUM_ITERATIONS)) + +timer = Timer( + stmt=forward_backward, + setup=setup + model_sdpa_setup + bfloat16_setup + variable_length_mask_setup, + label="Pytorch 2.2, scaled_dot_product_attention, on A100, bfloat16, varlen mask", +) +print(timer.timeit(NUM_ITERATIONS)) + +# 2.7.4.post1 because latest version gave error `lib/libstdc++.so.6: version `GLIBCXX_3.4.32' not found` +# may be a LINUX version issue, may be a python version issue? 
Tested for python 3.9 +# https://github.com/Dao-AILab/flash-attention/issues/1708 +# https://stackoverflow.com/questions/76974555/glibcxx-3-4-32-not-found-error-at-runtime-gcc-13-2-0 +# pip install packaging ninja +# pip install flash-attn==2.7.4.post1 --no-build-isolation + +timer = Timer( + stmt=forward_backward, + setup=setup + model_flash_setup + bfloat16_setup, + label="Pytorch 2.2, flashattention `flash_attn_varlen_qkvpacked_func`, on A100, bfloat16", +) +print(timer.timeit(NUM_ITERATIONS)) + +timer = Timer( + stmt=forward_backward, + setup=setup + model_flash_setup + bfloat16_setup + variable_length_mask_setup, + label="Pytorch 2.2, flashattention `flash_attn_varlen_qkvpacked_func`, on A100, bfloat16, varlen mask", +) +print(timer.timeit(NUM_ITERATIONS)) From 5e9545742ea13347cda6b69b25ae38bb74a37fde Mon Sep 17 00:00:00 2001 From: Ruben Cartuyvels Date: Thu, 30 Oct 2025 16:28:04 +0000 Subject: [PATCH 2/3] Add memory usage and longer sequences --- presto/presto.py | 6 +++-- train_time.py | 67 +++++++++++++++++++++++++++++++++--------------- 2 files changed, 51 insertions(+), 22 deletions(-) diff --git a/presto/presto.py b/presto/presto.py index 2d3e41f..fb930a2 100644 --- a/presto/presto.py +++ b/presto/presto.py @@ -124,7 +124,6 @@ def forward(self, x, attn_mask=None): attn_mask = None else: attn_mask = attn_mask[:, None, None].repeat((1, self.num_heads, N, 1)) - x = F.scaled_dot_product_attention( q, k, @@ -134,6 +133,8 @@ def forward(self, x, attn_mask=None): dropout_p=self.attn_drop.p, ) else: + q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0) + q, k = self.q_norm(q), self.k_norm(k) if attn_mask is not None: raise NotImplementedError q = q * self.scale @@ -492,7 +493,8 @@ def forward( if eval_task: # set masked tokens to 0 x_for_mean = x * (1 - upd_mask.unsqueeze(-1)) - x_mean = x_for_mean.sum(dim=1) / torch.sum(1 - upd_mask, -1, keepdim=True) + # type_as for half precision + x_mean = (x_for_mean.sum(dim=1) / torch.sum(1 - upd_mask, -1, 
keepdim=True)).type_as(x) # note: page 6 of https://arxiv.org/pdf/2104.02057.pdf # suggests removing the norm layer return self.norm(x_mean) diff --git a/train_time.py b/train_time.py index 4ef1a1f..fb0a6e4 100644 --- a/train_time.py +++ b/train_time.py @@ -1,10 +1,13 @@ +import torch +from torch.cuda import max_memory_allocated, reset_peak_memory_stats from torch.utils.benchmark import Timer NUM_ITERATIONS = 100 -# For seq length (in Attention forward, so after channel grouping) in [30, 55] VARIABLE_LENGTH_RATIO = 0.5 +BATCH_SIZE = 4096 +REPEATS = 1 -setup = """ +setup = f""" import json from pathlib import Path @@ -28,6 +31,7 @@ # ------------ Dataloaders ------------------------------------- # Set mask_ratio to 0.0 here because it will generate masks that result in # sequences with equal length +# Set to 0.01 because 0.0 gives gradient errors mask_params = MaskParams(MASK_STRATEGIES, 0.01) @@ -37,7 +41,7 @@ def load_dataset(url, shuffle_on_load): train_dataset = load_dataset("data/dw_144_mini_shard_44.tar", shuffle_on_load=True) -train_dataloader = wds.WebLoader(train_dataset, batch_size=4096) +train_dataloader = wds.WebLoader(train_dataset, batch_size={BATCH_SIZE}) mse = LossWrapper(nn.MSELoss()) @@ -45,17 +49,23 @@ def load_dataset(url, shuffle_on_load): mask, x, y, start_month = b[0].to(device), b[2].to(device), b[3].to(device), b[6] dw_mask, x_dw, y_dw = b[1].to(device), b[4].to(device).long(), b[5].to(device).long() latlons = b[7].to(device) + +x = torch.repeat_interleave(x, {REPEATS}, dim=1) +y = torch.repeat_interleave(y, {REPEATS}, dim=1) +x_dw = torch.repeat_interleave(x_dw, {REPEATS}, dim=1) +y_dw = torch.repeat_interleave(y_dw, {REPEATS}, dim=1) +mask = torch.repeat_interleave(mask, {REPEATS}, dim=1) """ model_sdpa_setup = """ -model = Presto.construct(**model_kwargs, attn="sdpa") +model = Presto.construct(**model_kwargs, attn="sdpa", max_sequence_length=480) model.to(device) optimizer = optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95)) 
model.train() """ model_flash_setup = """ -model = Presto.construct(**model_kwargs, attn="flash") +model = Presto.construct(**model_kwargs, attn="flash", max_sequence_length=480) model.to(device) optimizer = optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95)) model.train() @@ -94,88 +104,105 @@ def load_dataset(url, shuffle_on_load): optimizer.step() """ +device = torch.device("cuda:0") + # Run these with pytorch 2.0, and without the code changes to presto.py +# (or just comment the line that imports flash_attn_varlen_qkvpacked_func) # timer = Timer( # stmt=forward_backward, # setup=setup + model_sdpa_setup, -# label="Pytorch 2.0, scaled_dot_product_attention, on A100", +# label="Pytorch 2.0, scaled_dot_product_attention, on A100, set to None", # ) # print(timer.timeit(NUM_ITERATIONS)) - +# print(max_memory_allocated(device) / 10**6) +# reset_peak_memory_stats(device) # timer = Timer( # stmt=forward_backward, # setup=setup + model_sdpa_setup + variable_length_mask_setup, # label="Pytorch 2.0, scaled_dot_product_attention, on A100, varlen mask", # ) # print(timer.timeit(NUM_ITERATIONS)) +# print(max_memory_allocated(device) / 10**6) - +# reset_peak_memory_stats(device) # timer = Timer( # stmt=forward_backward, # setup=setup + model_sdpa_setup + bfloat16_setup, # label="Pytorch 2.0, scaled_dot_product_attention, on A100, bfloat16", # ) # print(timer.timeit(NUM_ITERATIONS)) +# print(max_memory_allocated(device) / 10**6) +# reset_peak_memory_stats(device) # timer = Timer( # stmt=forward_backward, # setup=setup + model_sdpa_setup + bfloat16_setup + variable_length_mask_setup, # label="Pytorch 2.0, scaled_dot_product_attention, on A100, bfloat16, varlen mask", # ) # print(timer.timeit(NUM_ITERATIONS)) +# print(max_memory_allocated(device) / 10**6) - +# exit() # run # pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 - +# reset_peak_memory_stats(device) timer = Timer( stmt=forward_backward, setup=setup + model_sdpa_setup, - label="Pytorch 2.2, 
scaled_dot_product_attention, on A100", + label="Pytorch 2.2, scaled_dot_product_attention, on A100, set to None", ) print(timer.timeit(NUM_ITERATIONS)) +print(max_memory_allocated(device) / 10**6) +reset_peak_memory_stats(device) timer = Timer( stmt=forward_backward, setup=setup + model_sdpa_setup + variable_length_mask_setup, label="Pytorch 2.2, scaled_dot_product_attention, on A100, varlen mask", ) print(timer.timeit(NUM_ITERATIONS)) +print(max_memory_allocated(device) / 10**6) - +reset_peak_memory_stats(device) timer = Timer( stmt=forward_backward, setup=setup + model_sdpa_setup + bfloat16_setup, - label="Pytorch 2.2, scaled_dot_product_attention, on A100, bfloat16", + label="Pytorch 2.2, scaled_dot_product_attention, on A100, bfloat16, set to None", ) print(timer.timeit(NUM_ITERATIONS)) +print(max_memory_allocated(device) / 10**6) +reset_peak_memory_stats(device) timer = Timer( stmt=forward_backward, setup=setup + model_sdpa_setup + bfloat16_setup + variable_length_mask_setup, label="Pytorch 2.2, scaled_dot_product_attention, on A100, bfloat16, varlen mask", ) print(timer.timeit(NUM_ITERATIONS)) +print(max_memory_allocated(device) / 10**6) -# 2.7.4.post1 because latest version gave error `lib/libstdc++.so.6: version `GLIBCXX_3.4.32' not found` -# may be a LINUX version issue, may be a python version issue? Tested for python 3.9 -# https://github.com/Dao-AILab/flash-attention/issues/1708 -# https://stackoverflow.com/questions/76974555/glibcxx-3-4-32-not-found-error-at-runtime-gcc-13-2-0 -# pip install packaging ninja -# pip install flash-attn==2.7.4.post1 --no-build-isolation +# # 2.7.4.post1 because latest version gave error `lib/libstdc++.so.6: version `GLIBCXX_3.4.32' not found` +# # may be a LINUX version issue, may be a python version issue? 
Tested for python 3.9 +# # https://github.com/Dao-AILab/flash-attention/issues/1708 +# # https://stackoverflow.com/questions/76974555/glibcxx-3-4-32-not-found-error-at-runtime-gcc-13-2-0 +# # pip install packaging ninja +# # pip install flash-attn==2.7.4.post1 --no-build-isolation +reset_peak_memory_stats(device) timer = Timer( stmt=forward_backward, setup=setup + model_flash_setup + bfloat16_setup, label="Pytorch 2.2, flashattention `flash_attn_varlen_qkvpacked_func`, on A100, bfloat16", ) print(timer.timeit(NUM_ITERATIONS)) - +print(max_memory_allocated(device) / 10**6) +reset_peak_memory_stats(device) timer = Timer( stmt=forward_backward, setup=setup + model_flash_setup + bfloat16_setup + variable_length_mask_setup, label="Pytorch 2.2, flashattention `flash_attn_varlen_qkvpacked_func`, on A100, bfloat16, varlen mask", ) print(timer.timeit(NUM_ITERATIONS)) +print(max_memory_allocated(device) / 10**6) From 5a7b5c38e925ced3b8d6a4080792ab28b5d52aff Mon Sep 17 00:00:00 2001 From: Ruben Cartuyvels Date: Thu, 30 Oct 2025 16:28:26 +0000 Subject: [PATCH 3/3] Add script to benchmark inference --- inference_time.py | 206 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 206 insertions(+) create mode 100644 inference_time.py diff --git a/inference_time.py b/inference_time.py new file mode 100644 index 0000000..408de8e --- /dev/null +++ b/inference_time.py @@ -0,0 +1,206 @@ +import torch +from torch.cuda import max_memory_allocated, reset_peak_memory_stats +from torch.utils.benchmark import Timer + +NUM_ITERATIONS = 100 +VARIABLE_LENGTH_RATIO = 0.5 +BATCH_SIZE = 4096 +REPEATS = 1 + +setup = f""" +import json +from pathlib import Path + +import torch +import torch.nn as nn +import webdataset as wds +from torch import optim + +from presto import Presto +from presto.dataops import MASK_STRATEGIES, MaskParams +from presto.dataops.dataset import ( + S1_S2_ERA5_SRTM_DynamicWorldMonthly_2020_2021, +) +from presto.model import LossWrapper + +train_url: str = 
"data/dw_144_mini_shard_44.tar" +device = torch.device("cuda:0") +path_to_config = Path("config") / "default.json" +model_kwargs = json.load(Path(path_to_config).open("r")) + +# ------------ Dataloaders ------------------------------------- +# Set mask_ratio to 0.0 here because it will generate masks that result in +# sequences with equal length +mask_params = MaskParams(MASK_STRATEGIES, 0.0) + + +def load_dataset(url, shuffle_on_load): + dataset = S1_S2_ERA5_SRTM_DynamicWorldMonthly_2020_2021(mask_params=mask_params) + return dataset.as_webdataset(url, shuffle_on_load) + + +train_dataset = load_dataset("data/dw_144_mini_shard_44.tar", shuffle_on_load=True) +train_dataloader = wds.WebLoader(train_dataset, batch_size={BATCH_SIZE}) + +mse = LossWrapper(nn.MSELoss()) + +b = next(iter(train_dataloader)) +mask, x, y, start_month = b[0].to(device), b[2].to(device), b[3].to(device), b[6] +dw_mask, x_dw, y_dw = b[1].to(device), b[4].to(device).long(), b[5].to(device).long() +latlons = b[7].to(device) + +x = torch.repeat_interleave(x, {REPEATS}, dim=1) +y = torch.repeat_interleave(y, {REPEATS}, dim=1) +x_dw = torch.repeat_interleave(x_dw, {REPEATS}, dim=1) +y_dw = torch.repeat_interleave(y_dw, {REPEATS}, dim=1) +mask = torch.repeat_interleave(mask, {REPEATS}, dim=1) +""" + +model_sdpa_setup = """ +_model = Presto.construct(**model_kwargs, attn="sdpa", max_sequence_length=480) +model = _model.construct_finetuning_model(num_outputs=1, regression=True) +model.to(device) +optimizer = optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95)) +model.eval() +""" + +model_flash_setup = """ +_model = Presto.construct(**model_kwargs, attn="flash", max_sequence_length=480) +model = _model.construct_finetuning_model(num_outputs=1, regression=True) +model.to(device) +model.eval() +""" + +bfloat16_setup = """ +latlons = latlons.bfloat16() +x = x.bfloat16() +y = y.bfloat16() +model = model.bfloat16() +""" + +variable_length_mask_setup = f""" +B, T, C = mask.shape +total_tokens = B * T * 
C +num_tokens_to_mask = int(total_tokens * {VARIABLE_LENGTH_RATIO}) + +# Create flat mask and randomly select positions to mask +flat_mask = torch.zeros(total_tokens, dtype=torch.bool, device=device) +mask_indices = torch.randperm(total_tokens, device=device)[:num_tokens_to_mask] +flat_mask[mask_indices] = True + +# Reshape back to (B, T, C) +mask = ~flat_mask.reshape(B, T, C) + +# Apply mask: x is zeroed where mask is True, y is zeroed where mask is False +x = torch.where(mask, torch.zeros_like(x), x) +y = torch.where(mask, y, torch.zeros_like(y)) +""" + +forward_only = """ +with torch.no_grad(): + y_pred = model(x, mask=mask, dynamic_world=x_dw, latlons=latlons, month=start_month) + loss = mse(y_pred, torch.ones_like(y_pred)) +""" + +device = torch.device("cuda:0") + +# Run these with pytorch 2.0, and without the code changes to presto.py +# (or just comment the line that imports flash_attn_varlen_qkvpacked_func) + +# timer = Timer( +# stmt=forward_backward, +# setup=setup + model_sdpa_setup, +# label="Pytorch 2.0, scaled_dot_product_attention, on A100, set to None", +# ) +# print(timer.timeit(NUM_ITERATIONS)) +# print(max_memory_allocated(device) / 10**6) +# reset_peak_memory_stats(device) +# timer = Timer( +# stmt=forward_backward, +# setup=setup + model_sdpa_setup + variable_length_mask_setup, +# label="Pytorch 2.0, scaled_dot_product_attention, on A100, varlen mask", +# ) +# print(timer.timeit(NUM_ITERATIONS)) +# print(max_memory_allocated(device) / 10**6) + +# reset_peak_memory_stats(device) +# timer = Timer( +# stmt=forward_backward, +# setup=setup + model_sdpa_setup + bfloat16_setup, +# label="Pytorch 2.0, scaled_dot_product_attention, on A100, bfloat16", +# ) +# print(timer.timeit(NUM_ITERATIONS)) +# print(max_memory_allocated(device) / 10**6) + +# reset_peak_memory_stats(device) +# timer = Timer( +# stmt=forward_backward, +# setup=setup + model_sdpa_setup + bfloat16_setup + variable_length_mask_setup, +# label="Pytorch 2.0,
scaled_dot_product_attention, on A100, bfloat16, varlen mask", +# ) +# print(timer.timeit(NUM_ITERATIONS)) +# print(max_memory_allocated(device) / 10**6) + +# exit() +# run +# pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 + +# reset_peak_memory_stats(device) +timer = Timer( + stmt=forward_only, + setup=setup + model_sdpa_setup, + label="Pytorch 2.2, scaled_dot_product_attention, on A100, NOT set to None", +) +print(timer.timeit(NUM_ITERATIONS)) +print(max_memory_allocated(device) / 10**6) + +reset_peak_memory_stats(device) +timer = Timer( + stmt=forward_only, + setup=setup + model_sdpa_setup + variable_length_mask_setup, + label="Pytorch 2.2, scaled_dot_product_attention, on A100, varlen mask", +) +print(timer.timeit(NUM_ITERATIONS)) +print(max_memory_allocated(device) / 10**6) + +reset_peak_memory_stats(device) +timer = Timer( + stmt=forward_only, + setup=setup + model_sdpa_setup + bfloat16_setup, + label="Pytorch 2.2, scaled_dot_product_attention, on A100, bfloat16, set to None", +) +print(timer.timeit(NUM_ITERATIONS)) +print(max_memory_allocated(device) / 10**6) + +reset_peak_memory_stats(device) +timer = Timer( + stmt=forward_only, + setup=setup + model_sdpa_setup + bfloat16_setup + variable_length_mask_setup, + label="Pytorch 2.2, scaled_dot_product_attention, on A100, bfloat16, varlen mask", +) +print(timer.timeit(NUM_ITERATIONS)) +print(max_memory_allocated(device) / 10**6) + +# # 2.7.4.post1 because latest version gave error `lib/libstdc++.so.6: version `GLIBCXX_3.4.32' not found` +# # may be a LINUX version issue, may be a python version issue? 
Tested for python 3.9 +# # https://github.com/Dao-AILab/flash-attention/issues/1708 +# # https://stackoverflow.com/questions/76974555/glibcxx-3-4-32-not-found-error-at-runtime-gcc-13-2-0 +# # pip install packaging ninja +# # pip install flash-attn==2.7.4.post1 --no-build-isolation + +reset_peak_memory_stats(device) +timer = Timer( + stmt=forward_only, + setup=setup + model_flash_setup + bfloat16_setup, + label="Pytorch 2.2, flashattention `flash_attn_varlen_qkvpacked_func`, on A100, bfloat16", +) +print(timer.timeit(NUM_ITERATIONS)) +print(max_memory_allocated(device) / 10**6) +reset_peak_memory_stats(device) +timer = Timer( + stmt=forward_only, + setup=setup + model_flash_setup + bfloat16_setup + variable_length_mask_setup, + label="Pytorch 2.2, flashattention `flash_attn_varlen_qkvpacked_func`, on A100, bfloat16, varlen mask", +) +print(timer.timeit(NUM_ITERATIONS)) +print(max_memory_allocated(device) / 10**6)