49 changes: 42 additions & 7 deletions megatron/training/training.py
@@ -251,6 +251,39 @@ def calculate_layer_counts():
num_moe_layers = 0
return num_attn_layers, num_mamba_layers, num_mlp_layers, num_moe_layers

def get_effective_seq_length(seq_len):
"""
Calculate effective sequence length for attention FLOPs based on attention pattern.

For causal attention, only half the attention matrix is computed (lower triangular),
so we use seq_len / 2. For specialized attention patterns:
- Sliding Window Attention: uses min(seq_len, window_size)
- Chunk Attention: uses chunk_size
"""
# Check for chunk attention (e.g., Llama 4)
if hasattr(args, 'chunk_attention_size') and args.chunk_attention_size is not None:
effective_len = args.chunk_attention_size
# Check for sliding window attention (e.g., Gemma 3)
elif hasattr(args, 'window_size') and args.window_size is not None:
# window_size is a tuple of window extents (e.g. left/right, with -1 meaning infinite);
# for the FLOPs estimate, use the largest finite window
if isinstance(args.window_size, tuple):
# Filter out -1 (infinite window) and take the max of finite windows
finite_windows = [w for w in args.window_size if w > 0]
if finite_windows:
effective_len = min(seq_len, max(finite_windows))
else:
# All windows are infinite (-1), so use full seq_len
effective_len = seq_len
else:
effective_len = min(seq_len, args.window_size)
else:
# Full causal attention over the entire sequence
effective_len = seq_len

# Halve to account for the causal mask (lower-triangular attention matrix); this is
# applied to every pattern, so it is an approximation for windowed attention
return effective_len / 2
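# Illustrative values (a hedged sketch, assuming seq_len=2048):
#   full causal (no window/chunk):    2048 / 2 -> 1024.0
#   window_size=(512, 512):           min(2048, 512) / 2 -> 256.0
#   window_size=(-1, -1):             2048 / 2 -> 1024.0 (infinite window == full attention)
#   chunk_attention_size=256:         256 / 2 -> 128.0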

def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=False):
"""Calculate FLOPs for an MLP layer."""
scale_factor = 3.0 / 2.0 if swiglu else 1.0
@@ -279,13 +312,14 @@ def attn_layer_flops(
"""Calculate FLOPs for an attention layer."""
p = (kv_channels * num_heads / hidden_size) if kv_channels else 1
g = gqa_groups if gqa else num_heads
effective_seq_len = get_effective_seq_length(seq_len)
return (
4
* batch_size
* seq_len
* hidden_size
* p
* (hidden_size + (hidden_size * (g / num_heads)) + (seq_len / 2))
* (hidden_size + (hidden_size * (g / num_heads)) + effective_seq_len)
)
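# A hedged reading of the expression above: the leading 4 is 2 (multiply-accumulate)
# x 2 (each bracketed term covers a pair of matmuls); `hidden_size` covers the query and
# output projections, `hidden_size * (g / num_heads)` the GQA-scaled key and value
# projections, and `effective_seq_len` the core attention (QK^T and (QK^T)V), with the
# causal halving already applied by get_effective_seq_length.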

def mamba_layer_flops(batch_size, seq_len, hidden_size, state_dim=16,
@@ -427,6 +461,7 @@ def transformer_flops():
+ args.num_attention_heads * (args.qk_head_dim + args.qk_pos_emb_head_dim)
+ 1
)
effective_seq_length = get_effective_seq_length(args.seq_length)
standard_self_attn_term = (
forward_backward_expansion_factor
* fma_expansion_factor
@@ -443,11 +478,11 @@
+ args.hidden_size * args.qk_pos_emb_head_dim
## o proj
+ (args.num_attention_heads * args.v_head_dim) * args.hidden_size
## core attn
+ args.seq_length
## core attn - QK^T
+ effective_seq_length
* (args.num_attention_heads * (args.qk_head_dim + args.qk_pos_emb_head_dim))
/ 2 # causal mask (only half of the mask is non-zero)
+ args.seq_length * args.num_attention_heads * args.v_head_dim / 2
## core attn - (QK^T)V
+ effective_seq_length * args.num_attention_heads * args.v_head_dim
)
)
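# Hedged note on the core-attention terms above: the QK^T matmul scales with the query
# width (qk_head_dim + qk_pos_emb_head_dim) per head, and (QK^T)V with v_head_dim per
# head, each multiplied by effective_seq_length keys/values; the explicit "/ 2" causal
# factors were dropped because get_effective_seq_length already applies the halving.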

@@ -457,6 +492,7 @@ def transformer_flops():
key_projection_size = args.kv_channels * args.num_query_groups
value_projection_size = args.kv_channels * args.num_query_groups
gate_projection_size = query_projection_size if args.attention_output_gate else 0
effective_seq_length = get_effective_seq_length(args.seq_length)
standard_self_attn_term = (
forward_backward_expansion_factor
* fma_expansion_factor
@@ -471,8 +507,7 @@
)
## core attention
+ query_projection_size
* args.seq_length
/ 2 # causal mask (only half of the mask is non-zero)
* effective_seq_length
* 2 # QK^T and (QK^T)V
## out proj
+ query_projection_size
247 changes: 246 additions & 1 deletion tests/unit_tests/test_training.py
@@ -4,12 +4,14 @@
from pathlib import Path
from types import SimpleNamespace

import pytest
import torch

from megatron.core.tokenizers.utils.build_tokenizer import vocab_size_with_padding
from megatron.training.checkpointing import save_grads
from megatron.training.global_vars import set_args
from megatron.training.training import build_train_valid_test_data_iterators
from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding
from megatron.training.training import build_train_valid_test_data_iterators, num_floating_point_operations
from tests.unit_tests.dist_checkpointing import TempNamedDir
from tests.unit_tests.test_utilities import Utils

@@ -135,3 +137,246 @@ def test_save_grads(self, tmp_path_dist_ckpt):
assert torch.equal(
loaded["model_chunk0"]["layer.bias"], state_dict["model_chunk0"]["layer.bias"]
)


class TestFLOPsCalculation:
"""Tests for FLOPs calculation with different attention patterns."""

def create_base_args(self):
"""Create base args for FLOPs testing."""
args = SimpleNamespace()
args.num_layers = 12
args.hidden_size = 768
args.num_attention_heads = 12
args.kv_channels = 64
args.seq_length = 2048
args.ffn_hidden_size = 3072
args.swiglu = False
args.group_query_attention = False
args.num_query_groups = 12
args.attention_output_gate = False
args.multi_latent_attention = False
args.num_experts = None
args.moe_layer_freq = None
args.mtp_num_layers = None
args.experimental_attention_variant = None
args.linear_attention_freq = None
args.hybrid_override_pattern = None
args.hybrid_attention_ratio = 0.0
args.hybrid_mlp_ratio = 0.0
return args

def test_full_causal_attention_baseline(self):
"""Test FLOPs calculation for standard full causal attention."""
args = self.create_base_args()
# No window_size or chunk_attention_size
args.window_size = None
args.chunk_attention_size = None

batch_size = 8
flops = num_floating_point_operations(args, batch_size)

# FLOPs should be positive
assert flops > 0, "FLOPs should be positive for baseline case"

def test_sliding_window_attention_reduces_flops(self):
"""Test that sliding window attention reduces FLOPs compared to full attention."""
args = self.create_base_args()
batch_size = 8

# Calculate baseline (full causal attention)
args.window_size = None
args.chunk_attention_size = None
baseline_flops = num_floating_point_operations(args, batch_size)

# Calculate with sliding window (window much smaller than seq_length)
args.window_size = (512, 512) # Much smaller than seq_length=2048
sliding_window_flops = num_floating_point_operations(args, batch_size)

# Sliding window should result in fewer FLOPs
assert sliding_window_flops < baseline_flops, (
f"Sliding window FLOPs ({sliding_window_flops}) should be less than "
f"baseline FLOPs ({baseline_flops})"
)

# Check the reduction ratio. Only the core-attention (QK^T and (QK^T)V) terms shrink
# with the window (effective length 2048 -> 512 before the causal halving); the QKV/output
# projections and the MLP FLOPs are unchanged, so the overall drop is modest.
reduction_ratio = sliding_window_flops / baseline_flops
# Expect at least a 5% reduction
assert reduction_ratio < 0.95, (
f"Expected significant FLOPs reduction with sliding window, "
f"but got ratio {reduction_ratio}"
)

def test_sliding_window_with_infinite_window(self):
"""Test sliding window with -1 (infinite window) equals full attention."""
args = self.create_base_args()
batch_size = 8

# Full attention baseline
args.window_size = None
args.chunk_attention_size = None
baseline_flops = num_floating_point_operations(args, batch_size)

# Sliding window with infinite window (-1)
args.window_size = (-1, -1)
infinite_window_flops = num_floating_point_operations(args, batch_size)

# Should be the same as baseline
assert abs(infinite_window_flops - baseline_flops) < 1e-6, (
f"Infinite window FLOPs ({infinite_window_flops}) should equal "
f"baseline FLOPs ({baseline_flops})"
)

def test_chunked_attention_reduces_flops(self):
"""Test that chunked attention reduces FLOPs compared to full attention."""
args = self.create_base_args()
batch_size = 8

# Calculate baseline (full causal attention)
args.window_size = None
args.chunk_attention_size = None
baseline_flops = num_floating_point_operations(args, batch_size)

# Calculate with chunked attention (chunk_size much smaller than seq_length)
args.chunk_attention_size = 256 # Much smaller than seq_length=2048
chunked_flops = num_floating_point_operations(args, batch_size)

# Chunked attention should result in fewer FLOPs
assert chunked_flops < baseline_flops, (
f"Chunked attention FLOPs ({chunked_flops}) should be less than "
f"baseline FLOPs ({baseline_flops})"
)

# Check the reduction ratio; as with sliding windows, only the core-attention terms
# shrink (effective length 2048 -> 256 before the causal halving), so expect at least
# a 10% reduction.
reduction_ratio = chunked_flops / baseline_flops
assert reduction_ratio < 0.9, (
f"Expected significant FLOPs reduction with chunked attention, "
f"but got ratio {reduction_ratio}"
)

def test_gqa_with_sliding_window(self):
"""Test FLOPs calculation for GQA with sliding window attention."""
args = self.create_base_args()
args.group_query_attention = True
args.num_query_groups = 4 # GQA with 4 groups
batch_size = 8

# GQA baseline
args.window_size = None
args.chunk_attention_size = None
gqa_baseline_flops = num_floating_point_operations(args, batch_size)

# GQA with sliding window
args.window_size = (512, 512)
gqa_sliding_flops = num_floating_point_operations(args, batch_size)

# Sliding window should still reduce FLOPs for GQA
assert gqa_sliding_flops < gqa_baseline_flops, (
f"GQA with sliding window FLOPs ({gqa_sliding_flops}) should be less than "
f"GQA baseline FLOPs ({gqa_baseline_flops})"
)

def test_mla_with_sliding_window(self):
"""Test FLOPs calculation for MLA with sliding window attention."""
args = self.create_base_args()
# Enable MLA
args.multi_latent_attention = True
args.q_lora_rank = None # Use standard q projection
args.kv_lora_rank = 512
args.qk_head_dim = 64
args.v_head_dim = 64
args.qk_pos_emb_head_dim = 64
batch_size = 8

# MLA baseline
args.window_size = None
args.chunk_attention_size = None
mla_baseline_flops = num_floating_point_operations(args, batch_size)

# MLA with sliding window
args.window_size = (512, 512)
mla_sliding_flops = num_floating_point_operations(args, batch_size)

# Sliding window should reduce FLOPs for MLA
assert mla_sliding_flops < mla_baseline_flops, (
f"MLA with sliding window FLOPs ({mla_sliding_flops}) should be less than "
f"MLA baseline FLOPs ({mla_baseline_flops})"
)

def test_chunk_attention_takes_precedence_over_sliding_window(self):
"""Test that chunk_attention_size takes precedence over window_size."""
args = self.create_base_args()
batch_size = 8

# Only chunk attention
args.window_size = None
args.chunk_attention_size = 256
chunk_only_flops = num_floating_point_operations(args, batch_size)

# Both chunk and sliding window (chunk should take precedence)
args.window_size = (1024, 1024)
args.chunk_attention_size = 256
both_flops = num_floating_point_operations(args, batch_size)

# Should be identical since chunk takes precedence
assert abs(both_flops - chunk_only_flops) < 1e-6, (
f"Chunk attention should take precedence. "
f"chunk_only: {chunk_only_flops}, both: {both_flops}"
)

@pytest.mark.parametrize("window_size", [
(128, 128),
(256, 512),
(1024, 2048),
(2048, -1), # One finite, one infinite
])
def test_various_window_sizes(self, window_size):
"""Test FLOPs calculation with various window sizes."""
args = self.create_base_args()
args.window_size = window_size
args.chunk_attention_size = None
batch_size = 8

flops = num_floating_point_operations(args, batch_size)

# FLOPs should always be positive
assert flops > 0, f"FLOPs should be positive for window_size={window_size}"

@pytest.mark.parametrize("chunk_size", [64, 128, 256, 512, 1024])
def test_various_chunk_sizes(self, chunk_size):
"""Test FLOPs calculation with various chunk sizes."""
args = self.create_base_args()
args.window_size = None
args.chunk_attention_size = chunk_size
batch_size = 8

flops = num_floating_point_operations(args, batch_size)

# FLOPs should always be positive
assert flops > 0, f"FLOPs should be positive for chunk_size={chunk_size}"

def test_flops_scale_with_batch_size(self):
"""Test that FLOPs scale linearly with batch size."""
args = self.create_base_args()
args.window_size = (512, 512)
args.chunk_attention_size = None

batch_size_1 = 1
batch_size_8 = 8

flops_1 = num_floating_point_operations(args, batch_size_1)
flops_8 = num_floating_point_operations(args, batch_size_8)

# Should scale linearly
ratio = flops_8 / flops_1
assert abs(ratio - 8.0) < 0.01, (
f"FLOPs should scale linearly with batch size, "
f"expected ratio ~8.0, got {ratio}"
)