From 845a0e3151010c7468faad92548f3591c9717e80 Mon Sep 17 00:00:00 2001 From: yanliang Date: Tue, 6 Jan 2026 14:39:37 +0800 Subject: [PATCH 1/3] Skip JIT warmup when fusion is disabled via arguments --- megatron/training/initialize.py | 96 +++++++++++++++++---------------- 1 file changed, 49 insertions(+), 47 deletions(-) diff --git a/megatron/training/initialize.py b/megatron/training/initialize.py index e88222fe7fe..76363089e04 100644 --- a/megatron/training/initialize.py +++ b/megatron/training/initialize.py @@ -472,55 +472,57 @@ def _warmup_jit_function(): dtype = torch.float32 # Warmup fused bias+gelu - bias = torch.rand( - args.ffn_hidden_size // args.tensor_model_parallel_size, dtype=dtype, device="cuda" - ) - input = torch.rand( - ( - args.seq_length // args.context_parallel_size, - args.micro_batch_size, - args.ffn_hidden_size // args.tensor_model_parallel_size, - ), - dtype=dtype, - device="cuda", - ) - # Warmup JIT fusions with the input grad_enable state of both forward - # prop and recomputation - for bias_grad, input_grad in zip([True, True], [False, True]): - bias.requires_grad, input.requires_grad = bias_grad, input_grad - for _ in range(5): - if args.swiglu: - output = bias_swiglu(input, bias) - else: - output = bias_gelu(bias, input) - del bias, input, output + if (args.swiglu and args.bias_swiglu_fusion) or (not args.swiglu and args.bias_gelu_fusion): + bias = torch.rand( + args.ffn_hidden_size // args.tensor_model_parallel_size, dtype=dtype, device="cuda" + ) + input = torch.rand( + ( + args.seq_length // args.context_parallel_size, + args.micro_batch_size, + args.ffn_hidden_size // args.tensor_model_parallel_size, + ), + dtype=dtype, + device="cuda", + ) + # Warmup JIT fusions with the input grad_enable state of both forward + # prop and recomputation + for bias_grad, input_grad in zip([True, True], [False, True]): + bias.requires_grad, input.requires_grad = bias_grad, input_grad + for _ in range(5): + if args.swiglu: + output = bias_swiglu(input, bias) + else: + output = bias_gelu(bias, input) + del bias, input, output # Warmup fused bias+dropout+add - if args.sequence_parallel: - seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size() - else: - seq_length = args.seq_length - input = torch.rand( - (seq_length // args.context_parallel_size, args.micro_batch_size, args.hidden_size), - dtype=dtype, - device="cuda", - ) - residual = torch.rand( - (seq_length // args.context_parallel_size, args.micro_batch_size, args.hidden_size), - dtype=dtype, - device="cuda", - ) - bias = torch.rand((args.hidden_size), dtype=dtype, device="cuda").expand_as(residual) - dropout_rate = 0.1 - # Warmup JIT fusions with the input grad_enable state of both forward - # prop and recomputation - for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]): - input.requires_grad = input_grad - bias.requires_grad = bias_grad - residual.requires_grad = residual_grad - for _ in range(5): - output = bias_dropout_add_fused_train([input, bias], residual, dropout_rate) - del bias, input, residual, output + if args.bias_dropout_fusion: + if args.sequence_parallel: + seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size() + else: + seq_length = args.seq_length + input = torch.rand( + (seq_length // args.context_parallel_size, args.micro_batch_size, args.hidden_size), + dtype=dtype, + device="cuda", + ) + residual = torch.rand( + (seq_length // args.context_parallel_size, args.micro_batch_size, args.hidden_size), + dtype=dtype, + device="cuda", + ) + bias = torch.rand((args.hidden_size), dtype=dtype, device="cuda").expand_as(residual) + dropout_rate = 0.1 + # Warmup JIT fusions with the input grad_enable state of both forward + # prop and recomputation + for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]): + input.requires_grad = input_grad + bias.requires_grad = bias_grad + residual.requires_grad = residual_grad + for _ in range(5): + output = bias_dropout_add_fused_train([input, bias], residual, dropout_rate) + del bias, input, residual, output torch.cuda.empty_cache() From 7396dc414b5e30f7d4ae06687e10c99e86a37b92 Mon Sep 17 00:00:00 2001 From: yanliang Date: Thu, 5 Feb 2026 11:23:39 +0800 Subject: [PATCH 2/3] Fix: Add use_te_activation_func and geglu checks for JIT warmup --- megatron/training/initialize.py | 100 ++++++++++++++++++++++++++++++-- 1 file changed, 94 insertions(+), 6 deletions(-) diff --git a/megatron/training/initialize.py b/megatron/training/initialize.py index 76363089e04..90acbf73cb2 100644 --- a/megatron/training/initialize.py +++ b/megatron/training/initialize.py @@ -13,6 +13,7 @@ from megatron.core import mpu, tensor_parallel from megatron.core.fusions.fused_bias_dropout import bias_dropout_add_fused_train +from megatron.core.fusions.fused_bias_geglu import bias_geglu from megatron.core.fusions.fused_bias_gelu import bias_gelu from megatron.core.fusions.fused_bias_swiglu import bias_swiglu from megatron.core.parallel_state import create_group @@ -471,8 +472,75 @@ def _warmup_jit_function(): else: dtype = torch.float32 - # Warmup fused bias+gelu - if (args.swiglu and args.bias_swiglu_fusion) or (not args.swiglu and args.bias_gelu_fusion): + # Check if TE activation function is used (in which case, no need to warmup custom fusions) + use_te_activation_func = getattr(args, 'use_te_activation_func', False) + + # Determine which activation fusion to warmup based on args + # Reference: megatron/core/transformer/mlp.py and megatron/core/transformer/moe/shared_experts.py + # + # In MLP forward (when bias_activation_fusion=True and use_te_activation_func=False): + # - activation_func=F.gelu + gated_linear_unit=True -> bias_geglu_impl + # - activation_func=F.gelu + gated_linear_unit=False -> bias_gelu_impl + # - activation_func=F.silu + gated_linear_unit=True -> bias_swiglu_impl + # + # Args mapping: + # - args.swiglu=True -> gated_linear_unit=True, activation_func=F.silu + # - args.quick_geglu=True -> gated_linear_unit=True, activation_func=quick_gelu (no fusion warmup needed) + # - default (neither set) -> gated_linear_unit=False, activation_func=F.gelu + + # gated_linear_unit can be set via YAML config or other means + gated_linear_unit = getattr(args, 'gated_linear_unit', False) + + # Warmup bias_swiglu: swiglu activation (F.silu + GLU) + warmup_bias_swiglu = ( + not use_te_activation_func + and args.swiglu + and args.bias_swiglu_fusion + ) + + # Warmup bias_gelu: non-gated gelu activation (F.gelu without GLU) + warmup_bias_gelu = ( + not use_te_activation_func + and not args.swiglu + and not getattr(args, 'quick_geglu', False) + and not gated_linear_unit + and args.bias_gelu_fusion + ) + + # Warmup bias_geglu: gated gelu activation (F.gelu + GLU) + # This is triggered when gated_linear_unit=True with gelu activation + warmup_bias_geglu = ( + not use_te_activation_func + and not args.swiglu + and not getattr(args, 'quick_geglu', False) + and gated_linear_unit + and args.bias_gelu_fusion + ) + + # Warmup fused bias+swiglu + if warmup_bias_swiglu: + bias = torch.rand( + args.ffn_hidden_size // args.tensor_model_parallel_size, dtype=dtype, device="cuda" + ) + input = torch.rand( + ( + args.seq_length // args.context_parallel_size, + args.micro_batch_size, + args.ffn_hidden_size // args.tensor_model_parallel_size, + ), + dtype=dtype, + device="cuda", + ) + # Warmup JIT fusions with the input grad_enable state of both forward + # prop and recomputation + for bias_grad, input_grad in zip([True, True], [False, True]): + bias.requires_grad, input.requires_grad = bias_grad, input_grad + for _ in range(5): + output = bias_swiglu(input, bias) + del bias, input, output + + # Warmup fused bias+gelu (non-gated) + if warmup_bias_gelu: bias = torch.rand( args.ffn_hidden_size // args.tensor_model_parallel_size, dtype=dtype, device="cuda" ) @@ -490,10 +558,30 @@ def _warmup_jit_function(): for bias_grad, input_grad in zip([True, True], [False, True]): bias.requires_grad, input.requires_grad = bias_grad, input_grad for _ in range(5): - if args.swiglu: - output = bias_swiglu(input, bias) - else: - output = bias_gelu(bias, input) + output = bias_gelu(bias, input) + del bias, input, output + + # Warmup fused bias+geglu (gated gelu) + if warmup_bias_geglu: + # For geglu, input size is 2x ffn_hidden_size (will be split into two halves) + bias = torch.rand( + (args.ffn_hidden_size // args.tensor_model_parallel_size) * 2, dtype=dtype, device="cuda" + ) + input = torch.rand( + ( + args.seq_length // args.context_parallel_size, + args.micro_batch_size, + (args.ffn_hidden_size // args.tensor_model_parallel_size) * 2, + ), + dtype=dtype, + device="cuda", + ) + # Warmup JIT fusions with the input grad_enable state of both forward + # prop and recomputation + for bias_grad, input_grad in zip([True, True], [False, True]): + bias.requires_grad, input.requires_grad = bias_grad, input_grad + for _ in range(5): + output = bias_geglu(bias, input) del bias, input, output # Warmup fused bias+dropout+add From 0fd720f0283a2f9330861b3057b2a502d804111b Mon Sep 17 00:00:00 2001 From: yanliang Date: Tue, 10 Feb 2026 09:45:06 +0800 Subject: [PATCH 3/3] Remove some comments --- megatron/training/initialize.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/megatron/training/initialize.py b/megatron/training/initialize.py index b0739d808d8..50e337b4b5d 100644 --- a/megatron/training/initialize.py +++ b/megatron/training/initialize.py @@ -480,21 +480,6 @@ def _warmup_jit_function(): # Check if TE activation function is used (in which case, no need to warmup custom fusions) use_te_activation_func = getattr(args, 'use_te_activation_func', False) - - # Determine which activation fusion to warmup based on args - # Reference: megatron/core/transformer/mlp.py and megatron/core/transformer/moe/shared_experts.py - # - # In MLP forward (when bias_activation_fusion=True and use_te_activation_func=False): - # - activation_func=F.gelu + gated_linear_unit=True -> bias_geglu_impl - # - activation_func=F.gelu + gated_linear_unit=False -> bias_gelu_impl - # - activation_func=F.silu + gated_linear_unit=True -> bias_swiglu_impl - # - # Args mapping: - # - args.swiglu=True -> gated_linear_unit=True, activation_func=F.silu - # - args.quick_geglu=True -> gated_linear_unit=True, activation_func=quick_gelu (no fusion warmup needed) - # - default (neither set) -> gated_linear_unit=False, activation_func=F.gelu - - # gated_linear_unit can be set via YAML config or other means gated_linear_unit = getattr(args, 'gated_linear_unit', False) # Warmup bias_swiglu: swiglu activation (F.silu + GLU) @@ -504,7 +489,6 @@ def _warmup_jit_function(): and args.bias_swiglu_fusion ) - # Warmup bias_gelu: non-gated gelu activation (F.gelu without GLU) warmup_bias_gelu = ( not use_te_activation_func and not args.swiglu @@ -513,8 +497,6 @@ def _warmup_jit_function(): and args.bias_gelu_fusion ) - # Warmup bias_geglu: gated gelu activation (F.gelu + GLU) - # This is triggered when gated_linear_unit=True with gelu activation warmup_bias_geglu = ( not use_te_activation_func and not args.swiglu @@ -523,7 +505,6 @@ def _warmup_jit_function(): and args.bias_gelu_fusion ) - # Warmup fused bias+swiglu if warmup_bias_swiglu: bias = torch.rand( args.ffn_hidden_size // args.tensor_model_parallel_size, dtype=dtype, device="cuda" @@ -545,7 +526,6 @@ def _warmup_jit_function(): output = bias_swiglu(input, bias) del bias, input, output - # Warmup fused bias+gelu (non-gated) if warmup_bias_gelu: bias = torch.rand( args.ffn_hidden_size // args.tensor_model_parallel_size, dtype=dtype, device="cuda"