12 changes: 10 additions & 2 deletions mlir/include/mlir/Dialect/Rock/IR/RockOps.td
@@ -220,7 +220,8 @@ def Rock_AttentionOp
Optional<TensorOrMemRefOf<[F32, F16, BF16]>>:$lse, I32Attr:$numHeadsQ,
I32Attr:$numHeadsKV, UnitAttr:$qTransposed, UnitAttr:$kTransposed,
UnitAttr:$vTransposed, UnitAttr:$oTransposed, UnitAttr:$causal,
-I32Attr:$splitKV, OptionalAttr<Rock_GemmFeaturesAttr>:$features,
+I32Attr:$splitKV, OptionalAttr<I32Attr>:$slidingWindowSize,
+OptionalAttr<Rock_GemmFeaturesAttr>:$features,
StoreMethodAttr:$storeMethod, OptionalAttr<TypeAttr>:$softmaxType,
OptionalAttr<RockTuningParamAttrInterface>:$params0,
OptionalAttr<RockTuningParamAttrInterface>:$params1,
@@ -253,6 +254,11 @@ def Rock_AttentionOp
- A tensor of shape [G]: per-group/batch offsets, allowing different prefix
lengths for each sequence in the batch

+If slidingWindowSize is set, we implement sliding window attention where
+only the last `slidingWindowSize` key positions (relative to currentSeqLen)
+are attended to. Positions before `max(0, currentSeqLen - slidingWindowSize)`
+are masked with -inf. This requires currentSeqLen to be set.
+
LSE (log-sum-exp) is an optional output typically used for flash decoding.
For flash decoding, pass splitKV > 1; the default value of 1 disables flash decoding.
Flash decoding multiplies the number of blocks by splitKV. Note that "lse" has to be non-null for splitKV > 1.
@@ -278,6 +284,7 @@ def Rock_AttentionOp
` ` `qk` `=` (`tr` $qTransposed^)? $queries `*` (`tr` $kTransposed^)? $keys `:` type($queries) `,` type($keys) `\n`
(`currentSeqLen` `=` `(` $currentSeqLen^ `:` type($currentSeqLen) `)` `\n`)?
(`prefixOffset` `=` `(` $prefixOffset^ `:` type($prefixOffset) `)` `\n`)?
+(`slidingWindowSize` `=` $slidingWindowSize^ `\n`)?
(`causal` `\n` $causal^)?
(`lse` `=` $lse^ `:` type($lse) `\n`)?
(`qk` `=` `elementwise` (`otherIns` `(` $preSoftmaxElemWiseInputs^ `:` type($preSoftmaxElemWiseInputs) `)`)? $preSoftmaxBody^ `\n`)?
@@ -583,7 +590,8 @@ def Rock_GridwiseAttentionAccelOp
Optional<MemRefRankOf<[I32], [1]>>:$prefixOffset,
MemRefRankOf<[F32, F16, BF16], [3]>:$out,
Optional<MemRefRankOf<[F32, F16, BF16], [2]>>:$lse, UnitAttr:$causal,
-I32Attr:$splitKV, OptionalAttr<Rock_GemmFeaturesAttr>:$features,
+I32Attr:$splitKV, OptionalAttr<I32Attr>:$slidingWindowSize,
+OptionalAttr<Rock_GemmFeaturesAttr>:$features,
StoreMethodAttr:$storeMethod, I32Attr:$blockSize, I32Attr:$gridSize,
UnitAttr:$disableQBypassLDS, OptionalAttr<IndexAttr>:$prePadG0M,
OptionalAttr<IndexAttr>:$prePadG0N,
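
Reviewer note: the masking rule stated in the new `slidingWindowSize` description can be sketched in plain Python to make the boundary explicit. This is a minimal illustration under stated assumptions, not the Rock lowering: the helper names `window_start` and `apply_sliding_window_mask` are hypothetical, scores are modeled as a single row of floats indexed by key position, and any masking of positions at or beyond currentSeqLen is not modeled here.

```python
import math


def window_start(current_seq_len, sliding_window_size):
    """First key position that remains visible (hypothetical helper).

    Mirrors the expression in the op description:
    max(0, currentSeqLen - slidingWindowSize).
    """
    return max(0, current_seq_len - sliding_window_size)


def apply_sliding_window_mask(scores, current_seq_len, sliding_window_size):
    """Return a copy of one row of QK scores with pre-window positions set to -inf.

    Only the lower bound of the window is applied; positions at or past
    current_seq_len are left untouched in this sketch.
    """
    start = window_start(current_seq_len, sliding_window_size)
    return [-math.inf if pos < start else s for pos, s in enumerate(scores)]


if __name__ == "__main__":
    row = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0]
    # With currentSeqLen = 6 and a window of 2, positions 0..3 are masked.
    print(apply_sliding_window_mask(row, current_seq_len=6, sliding_window_size=2))
```

When the window is at least as large as currentSeqLen, `window_start` returns 0 and no position is masked, which matches the `max(0, ...)` clamp in the description.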