diff --git a/megatron/training/training.py b/megatron/training/training.py index 2c68c70735d..bf1e7752c92 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -251,6 +251,39 @@ def calculate_layer_counts(): num_moe_layers = 0 return num_attn_layers, num_mamba_layers, num_mlp_layers, num_moe_layers + def get_effective_seq_length(seq_len): + """ + Calculate effective sequence length for attention FLOPs based on attention pattern. + + For causal attention, only half the attention matrix is computed (lower triangular), + so we use seq_len / 2. For specialized attention patterns: + - Sliding Window Attention: uses min(seq_len, window_size) + - Chunk Attention: uses chunk_size + """ + # Check for chunk attention (e.g., Llama 4) + if hasattr(args, 'chunk_attention_size') and args.chunk_attention_size is not None: + effective_len = args.chunk_attention_size + # Check for sliding window attention (e.g., Gemma 3) + elif hasattr(args, 'window_size') and args.window_size is not None: + # window_size is a tuple (local_window, global_window) + # For FLOPs calculation, use the maximum window size + if isinstance(args.window_size, tuple): + # Filter out -1 (infinite window) and take the max of finite windows + finite_windows = [w for w in args.window_size if w > 0] + if finite_windows: + effective_len = min(seq_len, max(finite_windows)) + else: + # All windows are infinite (-1), so use full seq_len + effective_len = seq_len + else: + effective_len = min(seq_len, args.window_size) + else: + # Full causal attention - only half the matrix is computed + effective_len = seq_len + + # For causal attention, divide by 2 (lower triangular matrix) + return effective_len / 2 + def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=False): """Calculate FLOPs for an MLP layer.""" scale_factor = 3.0 / 2.0 if swiglu else 1.0 @@ -279,13 +312,14 @@ def attn_layer_flops( """Calculate FLOPs for an attention layer.""" p = (kv_channels * num_heads / 
hidden_size) if kv_channels else 1 g = gqa_groups if gqa else num_heads + effective_seq_len = get_effective_seq_length(seq_len) return ( 4 * batch_size * seq_len * hidden_size * p - * (hidden_size + (hidden_size * (g / num_heads)) + (seq_len / 2)) + * (hidden_size + (hidden_size * (g / num_heads)) + effective_seq_len) ) def mamba_layer_flops(batch_size, seq_len, hidden_size, state_dim=16, @@ -427,6 +461,7 @@ def transformer_flops(): + args.num_attention_heads * (args.qk_head_dim + args.qk_pos_emb_head_dim) + 1 ) + effective_seq_length = get_effective_seq_length(args.seq_length) standard_self_attn_term = ( forward_backward_expansion_factor * fma_expansion_factor @@ -443,11 +478,11 @@ def transformer_flops(): + args.hidden_size * args.qk_pos_emb_head_dim ## o proj + (args.num_attention_heads * args.v_head_dim) * args.hidden_size - ## core attn - + args.seq_length + ## core attn - QK^T + + effective_seq_length * (args.num_attention_heads * (args.qk_head_dim + args.qk_pos_emb_head_dim)) - / 2 # causal mask (only half of the mask is non-zero) - + args.seq_length * args.num_attention_heads * args.v_head_dim / 2 + ## core attn - (QK^T)V + + effective_seq_length * args.num_attention_heads * args.v_head_dim ) ) @@ -457,6 +492,7 @@ def transformer_flops(): key_projection_size = args.kv_channels * args.num_query_groups value_projection_size = args.kv_channels * args.num_query_groups gate_projection_size = query_projection_size if args.attention_output_gate else 0 + effective_seq_length = get_effective_seq_length(args.seq_length) standard_self_attn_term = ( forward_backward_expansion_factor * fma_expansion_factor @@ -471,8 +507,7 @@ def transformer_flops(): ) ## core attention + query_projection_size - * args.seq_length - / 2 # causal mask (only half of the mask is non-zero) + * effective_seq_length * 2 # QK^T and (QK^T)V ## out proj + query_projection_size diff --git a/tests/unit_tests/test_training.py b/tests/unit_tests/test_training.py index 6f94296ad89..bc45f543c86 100644 
--- a/tests/unit_tests/test_training.py +++ b/tests/unit_tests/test_training.py @@ -4,12 +4,14 @@ from pathlib import Path from types import SimpleNamespace +import pytest import torch from megatron.core.tokenizers.utils.build_tokenizer import vocab_size_with_padding from megatron.training.checkpointing import save_grads from megatron.training.global_vars import set_args -from megatron.training.training import build_train_valid_test_data_iterators +from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding +from megatron.training.training import build_train_valid_test_data_iterators, num_floating_point_operations from tests.unit_tests.dist_checkpointing import TempNamedDir from tests.unit_tests.test_utilities import Utils @@ -135,3 +137,246 @@ def test_save_grads(self, tmp_path_dist_ckpt): assert torch.equal( loaded["model_chunk0"]["layer.bias"], state_dict["model_chunk0"]["layer.bias"] ) + + +class TestFLOPsCalculation: + """Tests for FLOPs calculation with different attention patterns.""" + + def create_base_args(self): + """Create base args for FLOPs testing.""" + args = SimpleNamespace() + args.num_layers = 12 + args.hidden_size = 768 + args.num_attention_heads = 12 + args.kv_channels = 64 + args.seq_length = 2048 + args.ffn_hidden_size = 3072 + args.swiglu = False + args.group_query_attention = False + args.num_query_groups = 12 + args.attention_output_gate = False + args.multi_latent_attention = False + args.num_experts = None + args.moe_layer_freq = None + args.mtp_num_layers = None + args.experimental_attention_variant = None + args.linear_attention_freq = None + args.hybrid_override_pattern = None + args.hybrid_attention_ratio = 0.0 + args.hybrid_mlp_ratio = 0.0 + return args + + def test_full_causal_attention_baseline(self): + """Test FLOPs calculation for standard full causal attention.""" + args = self.create_base_args() + # No window_size or chunk_attention_size + args.window_size = None + args.chunk_attention_size = None + + batch_size 
= 8 + flops = num_floating_point_operations(args, batch_size) + + # FLOPs should be positive + assert flops > 0, "FLOPs should be positive for baseline case" + + # Store baseline for comparison + baseline_flops = flops + return baseline_flops + + def test_sliding_window_attention_reduces_flops(self): + """Test that sliding window attention reduces FLOPs compared to full attention.""" + args = self.create_base_args() + batch_size = 8 + + # Calculate baseline (full causal attention) + args.window_size = None + args.chunk_attention_size = None + baseline_flops = num_floating_point_operations(args, batch_size) + + # Calculate with sliding window (window much smaller than seq_length) + args.window_size = (512, 512) # Much smaller than seq_length=2048 + sliding_window_flops = num_floating_point_operations(args, batch_size) + + # Sliding window should result in fewer FLOPs + assert sliding_window_flops < baseline_flops, ( + f"Sliding window FLOPs ({sliding_window_flops}) should be less than " + f"baseline FLOPs ({baseline_flops})" + ) + + # Calculate expected reduction ratio + # For attention, effective_seq_len changes from 2048 to 512 + # The reduction should be approximately proportional to the window size + reduction_ratio = sliding_window_flops / baseline_flops + # Should see significant reduction (at least 5% savings) + assert reduction_ratio < 0.95, ( + f"Expected significant FLOPs reduction with sliding window, " + f"but got ratio {reduction_ratio}" + ) + + def test_sliding_window_with_infinite_window(self): + """Test sliding window with -1 (infinite window) equals full attention.""" + args = self.create_base_args() + batch_size = 8 + + # Full attention baseline + args.window_size = None + args.chunk_attention_size = None + baseline_flops = num_floating_point_operations(args, batch_size) + + # Sliding window with infinite window (-1) + args.window_size = (-1, -1) + infinite_window_flops = num_floating_point_operations(args, batch_size) + + # Should be the same as
baseline + assert abs(infinite_window_flops - baseline_flops) < 1e-6, ( + f"Infinite window FLOPs ({infinite_window_flops}) should equal " + f"baseline FLOPs ({baseline_flops})" + ) + + def test_chunked_attention_reduces_flops(self): + """Test that chunked attention reduces FLOPs compared to full attention.""" + args = self.create_base_args() + batch_size = 8 + + # Calculate baseline (full causal attention) + args.window_size = None + args.chunk_attention_size = None + baseline_flops = num_floating_point_operations(args, batch_size) + + # Calculate with chunked attention (chunk_size much smaller than seq_length) + args.chunk_attention_size = 256 # Much smaller than seq_length=2048 + chunked_flops = num_floating_point_operations(args, batch_size) + + # Chunked attention should result in fewer FLOPs + assert chunked_flops < baseline_flops, ( + f"Chunked attention FLOPs ({chunked_flops}) should be less than " + f"baseline FLOPs ({baseline_flops})" + ) + + # Calculate expected reduction ratio + reduction_ratio = chunked_flops / baseline_flops + # Should see significant reduction (at least 10% savings) + assert reduction_ratio < 0.9, ( + f"Expected significant FLOPs reduction with chunked attention, " + f"but got ratio {reduction_ratio}" + ) + + def test_gqa_with_sliding_window(self): + """Test FLOPs calculation for GQA with sliding window attention.""" + args = self.create_base_args() + args.group_query_attention = True + args.num_query_groups = 4 # GQA with 4 groups + batch_size = 8 + + # GQA baseline + args.window_size = None + args.chunk_attention_size = None + gqa_baseline_flops = num_floating_point_operations(args, batch_size) + + # GQA with sliding window + args.window_size = (512, 512) + gqa_sliding_flops = num_floating_point_operations(args, batch_size) + + # Sliding window should still reduce FLOPs for GQA + assert gqa_sliding_flops < gqa_baseline_flops, ( + f"GQA with sliding window FLOPs ({gqa_sliding_flops}) should be less than " + f"GQA baseline FLOPs
({gqa_baseline_flops})" + ) + + def test_mla_with_sliding_window(self): + """Test FLOPs calculation for MLA with sliding window attention.""" + args = self.create_base_args() + # Enable MLA + args.multi_latent_attention = True + args.q_lora_rank = None # Use standard q projection + args.kv_lora_rank = 512 + args.qk_head_dim = 64 + args.v_head_dim = 64 + args.qk_pos_emb_head_dim = 64 + batch_size = 8 + + # MLA baseline + args.window_size = None + args.chunk_attention_size = None + mla_baseline_flops = num_floating_point_operations(args, batch_size) + + # MLA with sliding window + args.window_size = (512, 512) + mla_sliding_flops = num_floating_point_operations(args, batch_size) + + # Sliding window should reduce FLOPs for MLA + assert mla_sliding_flops < mla_baseline_flops, ( + f"MLA with sliding window FLOPs ({mla_sliding_flops}) should be less than " + f"MLA baseline FLOPs ({mla_baseline_flops})" + ) + + def test_chunk_attention_takes_precedence_over_sliding_window(self): + """Test that chunk_attention_size takes precedence over window_size.""" + args = self.create_base_args() + batch_size = 8 + + # Only chunk attention + args.window_size = None + args.chunk_attention_size = 256 + chunk_only_flops = num_floating_point_operations(args, batch_size) + + # Both chunk and sliding window (chunk should take precedence) + args.window_size = (1024, 1024) + args.chunk_attention_size = 256 + both_flops = num_floating_point_operations(args, batch_size) + + # Should be identical since chunk takes precedence + assert abs(both_flops - chunk_only_flops) < 1e-6, ( + f"Chunk attention should take precedence. 
" + f"chunk_only: {chunk_only_flops}, both: {both_flops}" + ) + + @pytest.mark.parametrize("window_size", [ + (128, 128), + (256, 512), + (1024, 2048), + (2048, -1), # One finite, one infinite + ]) + def test_various_window_sizes(self, window_size): + """Test FLOPs calculation with various window sizes.""" + args = self.create_base_args() + args.window_size = window_size + args.chunk_attention_size = None + batch_size = 8 + + flops = num_floating_point_operations(args, batch_size) + + # FLOPs should always be positive + assert flops > 0, f"FLOPs should be positive for window_size={window_size}" + + @pytest.mark.parametrize("chunk_size", [64, 128, 256, 512, 1024]) + def test_various_chunk_sizes(self, chunk_size): + """Test FLOPs calculation with various chunk sizes.""" + args = self.create_base_args() + args.window_size = None + args.chunk_attention_size = chunk_size + batch_size = 8 + + flops = num_floating_point_operations(args, batch_size) + + # FLOPs should always be positive + assert flops > 0, f"FLOPs should be positive for chunk_size={chunk_size}" + + def test_flops_scale_with_batch_size(self): + """Test that FLOPs scale linearly with batch size.""" + args = self.create_base_args() + args.window_size = (512, 512) + args.chunk_attention_size = None + + batch_size_1 = 1 + batch_size_8 = 8 + + flops_1 = num_floating_point_operations(args, batch_size_1) + flops_8 = num_floating_point_operations(args, batch_size_8) + + # Should scale linearly + ratio = flops_8 / flops_1 + assert abs(ratio - 8.0) < 0.01, ( + f"FLOPs should scale linearly with batch size, " + f"expected ratio ~8.0, got {ratio}" + )