158 changes: 114 additions & 44 deletions megatron/training/initialize.py
@@ -13,6 +13,7 @@

from megatron.core import mpu, tensor_parallel
from megatron.core.fusions.fused_bias_dropout import bias_dropout_add_fused_train
from megatron.core.fusions.fused_bias_geglu import bias_geglu
from megatron.core.fusions.fused_bias_gelu import bias_gelu
from megatron.core.fusions.fused_bias_swiglu import bias_swiglu
from megatron.core.parallel_state import create_group
@@ -470,56 +471,125 @@ def _warmup_jit_function():
    else:
        dtype = torch.float32

    # Check if a TE activation function is used (in which case there is no need to
    # warm up the custom fusions).
    use_te_activation_func = getattr(args, 'use_te_activation_func', False)
    gated_linear_unit = getattr(args, 'gated_linear_unit', False)

    # Warmup bias_swiglu: swiglu activation (F.silu + GLU).
    warmup_bias_swiglu = (
        not use_te_activation_func
        and args.swiglu
        and args.bias_swiglu_fusion
    )

    warmup_bias_gelu = (
        not use_te_activation_func
        and not args.swiglu
        and not getattr(args, 'quick_geglu', False)
        and not gated_linear_unit
        and args.bias_gelu_fusion
    )

    warmup_bias_geglu = (
        not use_te_activation_func
        and not args.swiglu
        and not getattr(args, 'quick_geglu', False)
        and gated_linear_unit
        and args.bias_gelu_fusion
    )

    if warmup_bias_swiglu:
        bias = torch.rand(
            args.ffn_hidden_size // args.tensor_model_parallel_size, dtype=dtype, device="cuda"
        )
        input = torch.rand(
            (
                args.seq_length // args.context_parallel_size,
                args.micro_batch_size,
                args.ffn_hidden_size // args.tensor_model_parallel_size,
            ),
            dtype=dtype,
            device="cuda",
        )
        # Warmup JIT fusions with the input grad_enable state of both forward
        # prop and recomputation
        for bias_grad, input_grad in zip([True, True], [False, True]):
            bias.requires_grad, input.requires_grad = bias_grad, input_grad
            for _ in range(5):
                output = bias_swiglu(input, bias)
        del bias, input, output

    if warmup_bias_gelu:
        bias = torch.rand(
            args.ffn_hidden_size // args.tensor_model_parallel_size, dtype=dtype, device="cuda"
        )
        input = torch.rand(
            (
                args.seq_length // args.context_parallel_size,
                args.micro_batch_size,
                args.ffn_hidden_size // args.tensor_model_parallel_size,
            ),
            dtype=dtype,
            device="cuda",
        )
        # Warmup JIT fusions with the input grad_enable state of both forward
        # prop and recomputation
        for bias_grad, input_grad in zip([True, True], [False, True]):
            bias.requires_grad, input.requires_grad = bias_grad, input_grad
            for _ in range(5):
                output = bias_gelu(bias, input)
        del bias, input, output

    # Warmup fused bias+geglu (gated gelu)
    if warmup_bias_geglu:
        # For geglu, input size is 2x ffn_hidden_size (will be split into two halves)
        bias = torch.rand(
            (args.ffn_hidden_size // args.tensor_model_parallel_size) * 2, dtype=dtype, device="cuda"
        )
        input = torch.rand(
            (
                args.seq_length // args.context_parallel_size,
                args.micro_batch_size,
                (args.ffn_hidden_size // args.tensor_model_parallel_size) * 2,
            ),
            dtype=dtype,
            device="cuda",
        )
        # Warmup JIT fusions with the input grad_enable state of both forward
        # prop and recomputation
        for bias_grad, input_grad in zip([True, True], [False, True]):
            bias.requires_grad, input.requires_grad = bias_grad, input_grad
            for _ in range(5):
                output = bias_geglu(bias, input)
        del bias, input, output

    # Warmup fused bias+dropout+add
    if args.bias_dropout_fusion:
        if args.sequence_parallel:
            seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
        else:
            seq_length = args.seq_length
        input = torch.rand(
            (seq_length // args.context_parallel_size, args.micro_batch_size, args.hidden_size),
            dtype=dtype,
            device="cuda",
        )
        residual = torch.rand(
            (seq_length // args.context_parallel_size, args.micro_batch_size, args.hidden_size),
            dtype=dtype,
            device="cuda",
        )
        bias = torch.rand((args.hidden_size), dtype=dtype, device="cuda").expand_as(residual)
        dropout_rate = 0.1
        # Warmup JIT fusions with the input grad_enable state of both forward
        # prop and recomputation
        for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
            input.requires_grad = input_grad
            bias.requires_grad = bias_grad
            residual.requires_grad = residual_grad
            for _ in range(5):
                output = bias_dropout_add_fused_train([input, bias], residual, dropout_rate)
        del bias, input, residual, output
    torch.cuda.empty_cache()
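
As an illustrative aside: the comment in the geglu branch notes that the activation tensor is 2x ffn_hidden_size and is split into two halves. A minimal unfused sketch of that gated-GeLU computation is given below, assuming one half gates the other; bias_geglu_reference is a hypothetical helper for illustration only, not the fused kernel imported from megatron.core.fusions.fused_bias_geglu, which may use a different GeLU variant internally.

import torch
import torch.nn.functional as F


def bias_geglu_reference(bias: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Hypothetical unfused reference, for illustration only: the last dimension is
    # 2 * ffn_hidden_size; after the bias add it is split into a gate half and a
    # value half, and the GeLU-activated gate scales the value.
    y = x + bias
    gate, value = y.chunk(2, dim=-1)
    return F.gelu(gate) * value


# Example shapes mirroring the warmup tensors: [seq, micro_batch, 2 * (ffn_hidden // tp)]
x = torch.rand(8, 2, 2 * 16)
b = torch.rand(2 * 16)
out = bias_geglu_reference(b, x)  # -> [8, 2, 16]

The warmup loops above simply drive the selected fusion a few times under both requires_grad settings so the JIT-compiled kernels are built for forward prop and recomputation before the timed training steps.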

