diff --git a/CHANGELOG.md b/CHANGELOG.md index b14c035834f..0014b4a6112 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ Full documentation for MIGraphX is available at * Fixed an issue with `reshape_lazy`'s shape computation that was leading to invalid reshapes (#4594). ### Optimized +* Added optimized fusion for the `local_window` mode of the GQA operator. ### Removed diff --git a/requirements.txt b/requirements.txt index 32c218936a0..a4969cadc47 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,4 +29,4 @@ pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off -DMSGPACK_BUILD_EXAMPLES=Off -DCMAKE_POLICY_VERSION_MINIMUM=3.5 sqlite3@3.50.4 -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/composable_kernel@ad0db05b040bacda751c65c705261b8a0a7ed25d --cmake subdir -DCMAKE_DIR=codegen -DCMAKE_POSITION_INDEPENDENT_CODE=On -DBUILD_TESTING=Off -ROCm/rocMLIR@9f843291bab85785796985b60f9a2840e7aad302 -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off +ROCm/rocMLIR@7b4aa8860e1002f98a2fbde411bf1a6289763d6a -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index 8912b186cdb..5cd8b41955d 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -762,7 +762,7 @@ struct find_kv_cache_attention auto matcher() const { static const std::unordered_set skip_set = { - "multibroadcast", "reshape", "unsqueeze"}; + "multibroadcast", "broadcast", "reshape", "unsqueeze", "squeeze"}; auto keys = match::skip(match::name(skip_set))(match::name("concat_past_present")).bind("pres_k"); @@ -775,12 +775,18 @@ struct find_kv_cache_attention auto attn_scores = match::any_of(scale, gemm1); auto causal_mask = match::name("where")(match::arg(0)(broadcasted_const), match::arg(2)(attn_scores)); + auto local_window_comp = match::skip(match::name(skip_set))( + match::name("convert")(match::arg(0)(match::name("greater")))); + auto local_window_mask = 
match::name("where")(match::arg(0)(match::any_of(local_window_comp, broadcasted_const)), + match::arg(2)(match::any_of(causal_mask, scale, gemm1))); auto greater = match::name("greater")(match::arg(1)(match::any().bind("total_sl"))); auto conv_greater = match::skip(match::name("unsqueeze"))(match::name("convert")(match::arg(0)(greater))); - auto bc_greater = match::name("multibroadcast")(match::arg(0)(conv_greater)); - auto mask = match::name("where")(match::arg(0)(bc_greater), - match::arg(2)(match::any_of(causal_mask, scale, gemm1))); + auto bc_greater = match::name("multibroadcast")(match::arg(0)(conv_greater)); + auto mask = match::name("where")( + match::arg(0)(bc_greater), + match::arg(2)(match::any_of(local_window_mask, causal_mask, scale, gemm1))); auto attn_probabilities = match::skip(match::name("convert"))( match::softmax_input(match::skip(match::name("convert"))(mask))); auto values = @@ -825,7 +831,8 @@ struct find_kv_cache_attention "broadcast", "multibroadcast", "@literal", - "unsqueeze"}; + "unsqueeze", + "squeeze"}; auto is_valid_attn_op = [&](auto i) { return i->get_operator().attributes().get("pointwise", false) or