From 097abca15689caf85de3365bdb0f60826aaf8011 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Thu, 15 Jan 2026 14:05:54 -0800 Subject: [PATCH 1/2] Fix edge checking in attention mask --- .../steel/attn/kernels/steel_attention_nax.h | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h index 1814f9b9ec..2c92f7c489 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h @@ -349,14 +349,21 @@ template < const short col_pos = base_col + ik * UK + sn; MSubTile mfrag; - mfrag.load_safe( - mask, - int64_t(mask_params->M_strides[2]), - Int<1>{}, - params->qL, - params->kL, - row_pos, - col_pos); + if ((!align_Q && is_last_q) || (!align_K && is_last_k)) { + mfrag.load_safe( + mask, + int64_t(mask_params->M_strides[2]), + Int<1>{}, + params->qL, + params->kL, + row_pos, + col_pos); + } else { + const int M_str = int(mask_params->M_strides[2]); + const int M_load_off = row_pos * M_str + col_pos; + mfrag.load(mask + M_load_off, M_str, Int<1>{}); + + } thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0); From 12f766efe96b694e1ed7bdcabe1067fd39867f65 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Thu, 15 Jan 2026 14:10:24 -0800 Subject: [PATCH 2/2] Fix edge checking in attention mask --- .../metal/kernels/steel/attn/kernels/steel_attention_nax.h | 1 - 1 file changed, 1 deletion(-) diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h index 2c92f7c489..f06e1081ba 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h @@ -362,7 +362,6 @@ template < const int M_str = int(mask_params->M_strides[2]); const int M_load_off = row_pos * M_str + col_pos; mfrag.load(mask + M_load_off, M_str, Int<1>{}); - } thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0);