diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h index 1814f9b9ec..f06e1081ba 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h @@ -349,14 +349,20 @@ template < const short col_pos = base_col + ik * UK + sn; MSubTile mfrag; - mfrag.load_safe( - mask, - int64_t(mask_params->M_strides[2]), - Int<1>{}, - params->qL, - params->kL, - row_pos, - col_pos); + if ((!align_Q && is_last_q) || (!align_K && is_last_k)) { + mfrag.load_safe( + mask, + int64_t(mask_params->M_strides[2]), + Int<1>{}, + params->qL, + params->kL, + row_pos, + col_pos); + } else { + const int M_str = int(mask_params->M_strides[2]); + const int M_load_off = row_pos * M_str + col_pos; + mfrag.load(mask + M_load_off, M_str, Int<1>{}); + } thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0);