diff --git a/src/twinkle/model/megatron/model/register.py b/src/twinkle/model/megatron/model/register.py
index ab59569d..07dfd82a 100644
--- a/src/twinkle/model/megatron/model/register.py
+++ b/src/twinkle/model/megatron/model/register.py
@@ -55,10 +55,10 @@ def get_layer_spec(self, config, args, mg_config_dict):
             A ``ModuleSpec`` or ``TransformerBlockSubmodules`` instance.
         """
         from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
-        num_experts = mg_config_dict.get('num_experts', 0) or 0
+        num_experts = mg_config_dict.get('num_experts') or None
         return get_gpt_layer_with_transformer_engine_spec(
             num_experts=num_experts,
-            moe_grouped_gemm=num_experts > 0,
+            moe_grouped_gemm=num_experts is not None,
             qk_layernorm=mg_config_dict.get('qk_layernorm', False),
         )