diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 28e3dde01c4..dc1cf05c235 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -803,6 +803,8 @@ def flash_decode_and_prefill( q = q.squeeze(1) if getattr(self, "softmax_scale", None) is not None: softmax_scale = self.softmax_scale + elif self.config.softmax_scale is not None: + softmax_scale = self.config.softmax_scale else: softmax_scale = q.shape[-1] ** -0.5 if HAVE_FA3: diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index a9cdc697cc8..2cdcc3a7fd3 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -125,7 +125,15 @@ def __init__( self.qkv_up_checkpoint = None mscale = _yarn_get_mscale(self.config.rotary_scaling_factor, self.config.mscale_all_dim) - self.softmax_scale = mscale * mscale / math.sqrt(self.q_head_dim) + self.softmax_scale = ( + mscale + * mscale + * ( + (1 / math.sqrt(self.q_head_dim)) + if self.config.softmax_scale is None + else self.config.softmax_scale + ) + ) self.cache_mla_latents = self.config.cache_mla_latents if self.config.rope_type == "rope":