From f5c4491096dafd050c9e82e5dbad8c5512d9f0f5 Mon Sep 17 00:00:00 2001 From: Alan Turner Date: Fri, 13 Feb 2026 05:06:46 +0800 Subject: [PATCH 1/4] Update matcher --- src/fuse_attention.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index 8912b186cdb..be83f89b66f 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -762,7 +762,7 @@ struct find_kv_cache_attention auto matcher() const { static const std::unordered_set skip_set = { - "multibroadcast", "reshape", "unsqueeze"}; + "multibroadcast", "broadcast", "reshape", "unsqueeze", "squeeze"}; auto keys = match::skip(match::name(skip_set))(match::name("concat_past_present")).bind("pres_k"); @@ -775,12 +775,14 @@ struct find_kv_cache_attention auto attn_scores = match::any_of(scale, gemm1); auto causal_mask = match::name("where")(match::arg(0)(broadcasted_const), match::arg(2)(attn_scores)); + auto local_window_comp = match::skip(match::name(skip_set))(match::name("convert")(match::arg(0)(match::name("greater")))); + auto local_window_mask = match::name("where")(match::arg(0)(match::any_of(local_window_comp, broadcasted_const)), match::arg(2)(match::any_of(causal_mask, scale, gemm1))); auto greater = match::name("greater")(match::arg(1)(match::any().bind("total_sl"))); auto conv_greater = match::skip(match::name("unsqueeze"))(match::name("convert")(match::arg(0)(greater))); auto bc_greater = match::name("multibroadcast")(match::arg(0)(conv_greater)); auto mask = match::name("where")(match::arg(0)(bc_greater), - match::arg(2)(match::any_of(causal_mask, scale, gemm1))); + match::arg(2)(match::any_of(local_window_mask, causal_mask, scale, gemm1))); auto attn_probabilities = match::skip(match::name("convert"))( match::softmax_input(match::skip(match::name("convert"))(mask))); auto values = @@ -825,7 +827,8 @@ struct find_kv_cache_attention "broadcast", "multibroadcast", "@literal", - "unsqueeze"}; + "unsqueeze", + "squeeze"}; 
auto is_valid_attn_op = [&](auto i) { return i->get_operator().attributes().get("pointwise", false) or From 6e581b4987b7b951d2d73f62c64b7c2e775a4f5c Mon Sep 17 00:00:00 2001 From: Alan Turner Date: Fri, 13 Feb 2026 07:38:36 +0800 Subject: [PATCH 2/4] Rocmlir commit --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 32c218936a0..a4969cadc47 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,4 +29,4 @@ pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off -DMSGPACK_BUILD_EXAMPLES=Off -DCMAKE_POLICY_VERSION_MINIMUM=3.5 sqlite3@3.50.4 -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/composable_kernel@ad0db05b040bacda751c65c705261b8a0a7ed25d --cmake subdir -DCMAKE_DIR=codegen -DCMAKE_POSITION_INDEPENDENT_CODE=On -DBUILD_TESTING=Off -ROCm/rocMLIR@9f843291bab85785796985b60f9a2840e7aad302 -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off +ROCm/rocMLIR@7b4aa8860e1002f98a2fbde411bf1a6289763d6a -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off From f5d1fb72d2805e2c109e3afa90e62682211ac581 Mon Sep 17 00:00:00 2001 From: Alan Turner Date: Wed, 18 Feb 2026 00:15:49 +0800 Subject: [PATCH 3/4] Formatting --- src/fuse_attention.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index be83f89b66f..5cd8b41955d 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -775,14 +775,18 @@ struct find_kv_cache_attention auto attn_scores = match::any_of(scale, gemm1); auto causal_mask = match::name("where")(match::arg(0)(broadcasted_const), match::arg(2)(attn_scores)); - auto local_window_comp = match::skip(match::name(skip_set))(match::name("convert")(match::arg(0)(match::name("greater")))); - auto local_window_mask = match::name("where")(match::arg(0)(match::any_of(local_window_comp, broadcasted_const)), 
match::arg(2)(match::any_of(causal_mask, scale, gemm1))); + auto local_window_comp = match::skip(match::name(skip_set))( + match::name("convert")(match::arg(0)(match::name("greater")))); + auto local_window_mask = + match::name("where")(match::arg(0)(match::any_of(local_window_comp, broadcasted_const)), + match::arg(2)(match::any_of(causal_mask, scale, gemm1))); auto greater = match::name("greater")(match::arg(1)(match::any().bind("total_sl"))); auto conv_greater = match::skip(match::name("unsqueeze"))(match::name("convert")(match::arg(0)(greater))); - auto bc_greater = match::name("multibroadcast")(match::arg(0)(conv_greater)); - auto mask = match::name("where")(match::arg(0)(bc_greater), - match::arg(2)(match::any_of(local_window_mask, causal_mask, scale, gemm1))); + auto bc_greater = match::name("multibroadcast")(match::arg(0)(conv_greater)); + auto mask = match::name("where")( + match::arg(0)(bc_greater), + match::arg(2)(match::any_of(local_window_mask, causal_mask, scale, gemm1))); auto attn_probabilities = match::skip(match::name("convert"))( match::softmax_input(match::skip(match::name("convert"))(mask))); auto values = From ee4534002ac7c1d731e8bbd64b72e25cde37c075 Mon Sep 17 00:00:00 2001 From: Alan Turner Date: Wed, 18 Feb 2026 00:26:01 +0800 Subject: [PATCH 4/4] Changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b14c035834f..0014b4a6112 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ Full documentation for MIGraphX is available at * Fixed an issue with `reshape_lazy`'s shape computation that was leading to invalid reshapes (#4594). ### Optimized +* Added optimized fusion for the `local_window` mode of the GQA operator. ### Removed