diff --git a/megatron/training/initialize.py b/megatron/training/initialize.py
index c150ac3d5ca..f671a83ef0c 100644
--- a/megatron/training/initialize.py
+++ b/megatron/training/initialize.py
@@ -13,6 +13,7 @@
 from megatron.core import mpu, tensor_parallel
 from megatron.core.fusions.fused_bias_dropout import bias_dropout_add_fused_train
+from megatron.core.fusions.fused_bias_geglu import bias_geglu
 from megatron.core.fusions.fused_bias_gelu import bias_gelu
 from megatron.core.fusions.fused_bias_swiglu import bias_swiglu
 from megatron.core.parallel_state import create_group
@@ -470,56 +471,125 @@ def _warmup_jit_function():
     else:
         dtype = torch.float32
 
-    # Warmup fused bias+gelu
-    bias = torch.rand(
-        args.ffn_hidden_size // args.tensor_model_parallel_size, dtype=dtype, device="cuda"
+    # Check if TE activation function is used (in which case, no need to warmup custom fusions)
+    use_te_activation_func = getattr(args, 'use_te_activation_func', False)
+    gated_linear_unit = getattr(args, 'gated_linear_unit', False)
+
+    # Warmup bias_swiglu: swiglu activation (F.silu + GLU)
+    warmup_bias_swiglu = (
+        not use_te_activation_func
+        and args.swiglu
+        and args.bias_swiglu_fusion
     )
-    input = torch.rand(
-        (
-            args.seq_length // args.context_parallel_size,
-            args.micro_batch_size,
-            args.ffn_hidden_size // args.tensor_model_parallel_size,
-        ),
-        dtype=dtype,
-        device="cuda",
+
+    warmup_bias_gelu = (
+        not use_te_activation_func
+        and not args.swiglu
+        and not getattr(args, 'quick_geglu', False)
+        and not gated_linear_unit
+        and args.bias_gelu_fusion
+    )
+
+    warmup_bias_geglu = (
+        not use_te_activation_func
+        and not args.swiglu
+        and not getattr(args, 'quick_geglu', False)
+        and gated_linear_unit
+        and args.bias_gelu_fusion
     )
-    # Warmup JIT fusions with the input grad_enable state of both forward
-    # prop and recomputation
-    for bias_grad, input_grad in zip([True, True], [False, True]):
-        bias.requires_grad, input.requires_grad = bias_grad, input_grad
-        for _ in range(5):
-            if args.swiglu:
+
+    if warmup_bias_swiglu:
+        bias = torch.rand(
+            args.ffn_hidden_size // args.tensor_model_parallel_size, dtype=dtype, device="cuda"
+        )
+        input = torch.rand(
+            (
+                args.seq_length // args.context_parallel_size,
+                args.micro_batch_size,
+                args.ffn_hidden_size // args.tensor_model_parallel_size,
+            ),
+            dtype=dtype,
+            device="cuda",
+        )
+        # Warmup JIT fusions with the input grad_enable state of both forward
+        # prop and recomputation
+        for bias_grad, input_grad in zip([True, True], [False, True]):
+            bias.requires_grad, input.requires_grad = bias_grad, input_grad
+            for _ in range(5):
                 output = bias_swiglu(input, bias)
-            else:
+        del bias, input, output
+
+    if warmup_bias_gelu:
+        bias = torch.rand(
+            args.ffn_hidden_size // args.tensor_model_parallel_size, dtype=dtype, device="cuda"
+        )
+        input = torch.rand(
+            (
+                args.seq_length // args.context_parallel_size,
+                args.micro_batch_size,
+                args.ffn_hidden_size // args.tensor_model_parallel_size,
+            ),
+            dtype=dtype,
+            device="cuda",
+        )
+        # Warmup JIT fusions with the input grad_enable state of both forward
+        # prop and recomputation
+        for bias_grad, input_grad in zip([True, True], [False, True]):
+            bias.requires_grad, input.requires_grad = bias_grad, input_grad
+            for _ in range(5):
                 output = bias_gelu(bias, input)
-    del bias, input, output
+        del bias, input, output
+
+    # Warmup fused bias+geglu (gated gelu)
+    if warmup_bias_geglu:
+        # For geglu, input size is 2x ffn_hidden_size (will be split into two halves)
+        bias = torch.rand(
+            (args.ffn_hidden_size // args.tensor_model_parallel_size) * 2, dtype=dtype, device="cuda"
+        )
+        input = torch.rand(
+            (
+                args.seq_length // args.context_parallel_size,
+                args.micro_batch_size,
+                (args.ffn_hidden_size // args.tensor_model_parallel_size) * 2,
+            ),
+            dtype=dtype,
+            device="cuda",
+        )
+        # Warmup JIT fusions with the input grad_enable state of both forward
+        # prop and recomputation
+        for bias_grad, input_grad in zip([True, True], [False, True]):
+            bias.requires_grad, input.requires_grad = bias_grad, input_grad
+            for _ in range(5):
+                output = bias_geglu(bias, input)
+        del bias, input, output
 
     # Warmup fused bias+dropout+add
-    if args.sequence_parallel:
-        seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
-    else:
-        seq_length = args.seq_length
-    input = torch.rand(
-        (seq_length // args.context_parallel_size, args.micro_batch_size, args.hidden_size),
-        dtype=dtype,
-        device="cuda",
-    )
-    residual = torch.rand(
-        (seq_length // args.context_parallel_size, args.micro_batch_size, args.hidden_size),
-        dtype=dtype,
-        device="cuda",
-    )
-    bias = torch.rand((args.hidden_size), dtype=dtype, device="cuda").expand_as(residual)
-    dropout_rate = 0.1
-    # Warmup JIT fusions with the input grad_enable state of both forward
-    # prop and recomputation
-    for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
-        input.requires_grad = input_grad
-        bias.requires_grad = bias_grad
-        residual.requires_grad = residual_grad
-        for _ in range(5):
-            output = bias_dropout_add_fused_train([input, bias], residual, dropout_rate)
-    del bias, input, residual, output
+    if args.bias_dropout_fusion:
+        if args.sequence_parallel:
+            seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
+        else:
+            seq_length = args.seq_length
+        input = torch.rand(
+            (seq_length // args.context_parallel_size, args.micro_batch_size, args.hidden_size),
+            dtype=dtype,
+            device="cuda",
+        )
+        residual = torch.rand(
+            (seq_length // args.context_parallel_size, args.micro_batch_size, args.hidden_size),
+            dtype=dtype,
+            device="cuda",
+        )
+        bias = torch.rand((args.hidden_size), dtype=dtype, device="cuda").expand_as(residual)
+        dropout_rate = 0.1
+        # Warmup JIT fusions with the input grad_enable state of both forward
+        # prop and recomputation
+        for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
+            input.requires_grad = input_grad
+            bias.requires_grad = bias_grad
+            residual.requires_grad = residual_grad
+            for _ in range(5):
+                output = bias_dropout_add_fused_train([input, bias], residual, dropout_rate)
+        del bias, input, residual, output
 
     torch.cuda.empty_cache()
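
For reference, here is a minimal standalone sketch (not part of the patch) of the selection logic that the new warmup_bias_swiglu / warmup_bias_gelu / warmup_bias_geglu conditions implement. The select_warmup helper and the SimpleNamespace config are hypothetical and only illustrate which fused-activation warmup a given flag combination would trigger.

# Hypothetical helper mirroring the warmup_bias_* conditions introduced in the diff above.
from types import SimpleNamespace


def select_warmup(args):
    """Return which fused-activation warmup the new conditions would run, or None."""
    use_te_activation_func = getattr(args, 'use_te_activation_func', False)
    gated_linear_unit = getattr(args, 'gated_linear_unit', False)
    quick_geglu = getattr(args, 'quick_geglu', False)

    if use_te_activation_func:
        # TE supplies its own activation; no custom fusion to warm up.
        return None
    if args.swiglu and args.bias_swiglu_fusion:
        return 'bias_swiglu'
    if not args.swiglu and not quick_geglu and args.bias_gelu_fusion:
        # bias_gelu_fusion covers both the plain gelu and the gated (geglu) path.
        return 'bias_geglu' if gated_linear_unit else 'bias_gelu'
    return None


# Illustrative flag combination only; real runs read these from the Megatron args.
cfg = SimpleNamespace(swiglu=False, bias_swiglu_fusion=False,
                      bias_gelu_fusion=True, gated_linear_unit=True)
print(select_warmup(cfg))  # -> 'bias_geglu'

Because the three conditions are mutually exclusive (swiglu vs. gelu, gated vs. non-gated), at most one of the warmup branches runs for any given configuration.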