From 9178f78d68fcbdcf7d58e407939daa193202ac48 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Sat, 10 Jan 2026 17:39:19 +0100 Subject: [PATCH 1/8] add mod1, max, min --- src/language/arithmetic.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/language/arithmetic.jl b/src/language/arithmetic.jl index 1a69f39..99a040b 100644 --- a/src/language/arithmetic.jl +++ b/src/language/arithmetic.jl @@ -29,6 +29,7 @@ @overlay Base.div(x::T, y::T, ::typeof(RoundUp)) where {T <: Unsigned} = Intrinsics.cldi(x, y, SignednessUnsigned) @overlay Base.rem(x::T, y::T) where {T <: Signed} = Intrinsics.remi(x, y, SignednessSigned) @overlay Base.rem(x::T, y::T) where {T <: Unsigned} = Intrinsics.remi(x, y, SignednessUnsigned) +@overlay Base.mod1(x::T, y::T) where {T <: ScalarInt} = (m = mod(x, y); m == zero(m) ? y : m) # float @overlay Base.:+(x::T, y::T) where {T <: ScalarFloat} = Intrinsics.addf(x, y) @@ -77,7 +78,7 @@ @inline Base.:(-)(a::Tile{T, S}, b::Tile{T, S}) where {T <: Integer, S} = Intrinsics.subi(a, b) # broadcasted arithmetic (float) -for (op, intrinsic) in ((:+, :addf), (:-, :subf), (:*, :mulf), (:/, :divf)) +for (op, intrinsic) in ((:+, :addf), (:-, :subf), (:*, :mulf), (:/, :divf), (:max, :maxf), (:min, :minf)) @eval @inline function Base.Broadcast.broadcasted(::TileStyle, ::typeof($op), a::Tile{T,S1}, b::Tile{T,S2}) where {T<:AbstractFloat,S1,S2} S = broadcast_shape(S1, S2) Intrinsics.$intrinsic(broadcast_to(a, S), broadcast_to(b, S)) @@ -157,7 +158,7 @@ end @inline Base.:(/)(a::Tile{T, S}, b::Number) where {T <: AbstractFloat, S} = Intrinsics.divf(a, broadcast_to(Tile(T(b)), S)) # broadcasted arithmetic (float) -for (op, intrinsic) in ((:+, :addf), (:-, :subf), (:*, :mulf), (:/, :divf)) +for (op, intrinsic) in ((:+, :addf), (:-, :subf), (:*, :mulf), (:/, :divf), (:max, :maxf), (:min, :minf)) @eval begin @inline Base.Broadcast.broadcasted(::TileStyle, ::typeof($op), a::Tile{T,S}, b::Number) where {T<:AbstractFloat,S} = Intrinsics.$intrinsic(a, broadcast_to(Tile(T(b)), S)) From 1bd757ca70372c758aad8cc40780a6f9d541a9be Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Sat, 10 Jan 2026 17:39:48 +0100 Subject: [PATCH 2/8] add fmha --- examples/fmha.jl | 208 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 208 insertions(+) create mode 100644 examples/fmha.jl diff --git a/examples/fmha.jl b/examples/fmha.jl new file mode 100644 index 0000000..ed49df0 --- /dev/null +++ b/examples/fmha.jl @@ -0,0 +1,208 @@ +# Batch matrix multiplication example - Julia port of cuTile Python's AttentionFMHA.py sample +# +# SPDX-License-Identifier: Apache-2.0 + +using CUDA +import cuTile as ct + +import NNlib + +const INV_LOG_2 = Float32(1 / log(2)) +const ConstInt = ct.Constant{Int} +const ConstBool = ct.Constant{Bool} + +# TODO: "latency" + +# cuTile kernel for Fused Multi-Head Attention +# Q: d x +function fmha_kernel( + Q::ct.TileArray{T,4}, K::ct.TileArray{T,4}, V::ct.TileArray{T,4}, Out::ct.TileArray{T,4}, + qk_scale::AbstractFloat, + input_pos::Integer, + TILE_D::ConstInt, + H::ConstInt, # number of heads? 
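+    # (H is the number of *query* heads: bid(2) runs over Heads*Batch and is split below via
+    #  `cld(bid_y, H[])` / `mod1(bid_y, H[])`; e.g. H = 8, bid_y = 10 gives batch 2, head 2.
+    #  The `mod1` overlay added in the arithmetic.jl hunk above is what makes `mod1` usable here.)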
+ TILE_M::ConstInt, + TILE_N::ConstInt, + QUERY_GROUP_SIZE::ConstInt, + CAUSAL::ConstBool, + EVEN_K::ConstBool +) where T + bid_x = ct.bid(1) + bid_y = ct.bid(2) + batch_idx = cld(bid_y, H[]) + head_idx = mod1(bid_y, H[]) + off_kv_h = cld(head_idx, QUERY_GROUP_SIZE[]) + + qk_scale = Float32(qk_scale) * Float32(INV_LOG_2) + + # Offsets for query tile (M-dimension) + offs_m = bid_x * TILE_M[] .+ ct.arange((TILE_M[],), Int32) .+ input_pos + + # local offsets for key/value tile (N-dimension) + offs_n_tile = ct.reshape(ct.arange((TILE_N[],), Int32), (1, TILE_N[])) + + # online softmax accumulators in Float32 for stability + m_i = ct.full((1, TILE_M[]), -Inf32, Float32) + l_i = ct.zeros((1, TILE_M[]), Float32) + acc = ct.zeros((TILE_D[], TILE_M[]), Float32) + + # query tile for this batch, head, and M-chunk + q = ct.load(Q, (1, bid_x, head_idx, batch_idx), (TILE_D[], TILE_M[], 1, 1)) + q = ct.reshape(q, (TILE_D[], TILE_M[])) + + m_end = input_pos + (bid_x + 1) * TILE_M[] + k_seqlen = K.sizes[2] + if CAUSAL[] + # when kv pos could exceed q pos + mask_start = cld(input_pos + bid_x * TILE_M[], TILE_N[]) + # when kv pos could exceed k_seqlen + mask_start = min(mask_start, cld(k_seqlen, TILE_N[])) + Tc = cld(min(m_end, k_seqlen), TILE_N[]) + else + Tc = cld(k_seqlen, TILE_N[]) + mask_start = cld(k_seqlen, TILE_N[]) + end + + # loop over K, V blocks (N-dimension chunks) + j = Int32(1) + while j <= Tc + k = ct.load(K, (1, j, off_kv_h, batch_idx), (TILE_D[], TILE_N[], 1, 1)) + k = ct.reshape(k, (TILE_D[], TILE_N[])) + k = ct.transpose(k) + + qk = ct.zeros((TILE_N[], TILE_M[]), Float32) + qk = ct.muladd(k, q, qk) + + if (CAUSAL[] || !EVEN_K[]) && j >= mask_start + offs_n = j * TILE_N[] + offs_n_tile + mask = ct.full((TILE_N[], TILE_M[]), true, Bool) + if !EVEN_K[] + mask = mask .& (offs_n .< k_seqlen) + end + if CAUSAL[] + mask = mask .& (offs_m .>= offs_n) + end + mask = ct.where(mask, -Inf32, Float32) + qk = qk .+ mask + end + + # moving qk_scale multiplication after reduce_max + m_ij = max.(m_i, (ct.reduce_max(qk, 1) * qk_scale)) + qk = qk * qk_scale .- m_ij + + # attention weights [TILE_N, TILE_M] + p = exp2.(qk) # might need to expose "flush_to_zero" + l_ij = ct.reduce_sum(p, 1) + alpha = exp2.(m_i .- m_ij) # flush to zero? + + l_i = l_i .* alpha .+ l_ij + acc = acc .* alpha + + v = ct.load(V, (1, j, off_kv_h, batch_idx), (TILE_D[], TILE_N[], 1, 1)) + v = ct.reshape(v, (TILE_D[], TILE_N[])) + p = ct.astype(p, eltype(q)) + acc = ct.muladd(v, p, acc) # [TILE_D, TILE_M] + m_i = m_ij + + j += Int32(1) + end + + acc = acc ./ l_i # flush to zero? rounding mode? 
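+    # The division above completes the online softmax: acc[:, m] accumulated
+    # sum_j exp2(s_jm - m_i[m]) * V[:, j] and l_i[m] = sum_j exp2(s_jm - m_i[m]),
+    # so acc now holds the softmax-weighted average of V for each query column.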
+ acc = ct.reshape(acc, (TILE_D[], TILE_M[], 1, 1)) + ct.store(Out, (1, bid_x, head_idx, batch_idx), acc) + + return +end + +function cutile_fmha(Q::AbstractArray{T,4}, K::AbstractArray{T,4}, V::AbstractArray{T,4}; + qk_scale::Union{AbstractFloat,Nothing} = nothing, + input_pos::Integer = 0, + tile_m::Integer = 128, + tile_n::Integer = 128, + query_group_size::Integer = 1, + causal::Bool = false, +) where T + if size(Q, 4) != size(K, 4) || size(Q, 4) != size(V, 4) + throw(ArgumentError("Batch dimensions must match for Q, K, V.")) + end + if size(Q, 3) % query_group_size != 0 + throw(ArgumentError("Number of query heads must be divisible by query_group_size.")) + end + if size(K, 3) * query_group_size != size(Q, 3) + throw(ArgumentError("K_heads * query_group_size must equal Q_heads.")) + end + if size(Q, 1) != size(K, 1) + throw(ArgumentError("D_k (first dim of Q and K) must match.")) + end + if size(K, 2) != size(V, 2) + throw(ArgumentError("SeqLen_KV (dim 2 of K and V) must match.")) + end + + D_k, SeqLen_Q, Heads, Batch = size(Q) + D_v, SeqLen_KV, KV_heads, _ = size(V) + even_k = (SeqLen_KV % tile_n) == 0 + + isnothing(qk_scale) && (qk_scale = 1 / sqrt(D_k)) + + Out = CUDA.zeros(T, D_v, SeqLen_Q, Heads, Batch) + + grid_x = cld(SeqLen_Q, tile_m) + grid_y = Heads * Batch + grid = (grid_x, grid_y, 1) + + ct.launch(fmha_kernel, grid, + Q, K, V, Out, + qk_scale, input_pos, + ct.Constant(D_k), + ct.Constant(Heads), + ct.Constant(tile_m), + ct.Constant(tile_n), + ct.Constant(query_group_size), + ct.Constant(causal), + ct.Constant(even_k)) + + return Out +end + +function nnlib_fmha(Q::AbstractArray{T,4}, K::AbstractArray{T,4}, V::AbstractArray{T,4}; + query_group_size::Integer = 1, + causal::Bool = false, +) where T + mask = causal ? NNlib.make_causal_mask(Q; dims=2) : nothing + if query_group_size > 1 + K, V = repeat.((K, V), inner=(1, 1, query_group_size)) + end + Out, _ = NNlib.dot_product_attention(Q, K, V; mask) + return Out +end + + +function test_fmha(::Type{T}, + D_k, SeqLen_Q, Heads, Batch, + D_v, SeqLen_KV, KV_heads, + causal, tile_m, tile_n, +) where T + query_group_size = Heads ÷ KV_heads + + Q = CUDA.randn(T, D_k, SeqLen_Q, Heads, Batch) + K = CUDA.randn(T, D_k, SeqLen_KV, KV_heads, Batch) + V = CUDA.randn(T, D_v, SeqLen_KV, KV_heads, Batch) + + out_cutile = cutile_fmha(Q, K, V; + causal=causal, + tile_m=tile_m, tile_n=tile_n, + query_group_size=query_group_size) + + Q_cpu = Array(Q) + K_cpu = Array(K) + V_cpu = Array(V) + expected = nnlib_fmha(Q_cpu, K_cpu, V_cpu; query_group_size, causal) + result = Array(out_cutile) + + if isapprox(result, expected, rtol=1e-2, atol=1e-2) + println(" passed") + else + max_diff = maximum(abs.(result - expected)) + println(" FAILED (max diff: $max_diff)") + end +end From 905e73224acefedd45c34e7b1d6b145b5252e954 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Sat, 10 Jan 2026 18:13:10 +0100 Subject: [PATCH 3/8] fix tests --- examples/fmha.jl | 15 ++++++++++++++- test/Project.toml | 1 + 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/examples/fmha.jl b/examples/fmha.jl index ed49df0..97c27a3 100644 --- a/examples/fmha.jl +++ b/examples/fmha.jl @@ -170,7 +170,7 @@ function nnlib_fmha(Q::AbstractArray{T,4}, K::AbstractArray{T,4}, V::AbstractArr ) where T mask = causal ? 
NNlib.make_causal_mask(Q; dims=2) : nothing if query_group_size > 1 - K, V = repeat.((K, V), inner=(1, 1, query_group_size)) + K, V = repeat.((K, V), inner=(1, 1, query_group_size, 1)) end Out, _ = NNlib.dot_product_attention(Q, K, V; mask) return Out @@ -206,3 +206,16 @@ function test_fmha(::Type{T}, println(" FAILED (max diff: $max_diff)") end end + +function main() + println("--- cuTile Fused Multi-Head Attention Examples ---\n") + + # Float32 tests, causal=false + test_fmha(Float32, 64, 256, 8, 2, 64, 256, 8, false, 32, 32) + test_fmha(Float32, 64, 256, 8, 2, 64, 128, 8, false, 32, 32) + test_fmha(Float32, 64, 256, 8, 2, 64, 128, 4, false, 32, 32) + + println("\n--- All batch matmul examples completed ---") +end + +isinteractive() || main() diff --git a/test/Project.toml b/test/Project.toml index c98454c..e340931 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,6 +3,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" FileCheck = "4e644321-382b-4b05-b0b6-5d23c3d944fb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" ParallelTestRunner = "d3525ed8-44d0-4b2c-a655-542cee43accc" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" From d6f9d9b3f4956cb5596760b582cbfa7d14bf7ab7 Mon Sep 17 00:00:00 2001 From: Anton Oresten Date: Sat, 10 Jan 2026 17:42:49 +0100 Subject: [PATCH 4/8] Update fmha.jl --- examples/fmha.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fmha.jl b/examples/fmha.jl index 97c27a3..e069891 100644 --- a/examples/fmha.jl +++ b/examples/fmha.jl @@ -1,4 +1,4 @@ -# Batch matrix multiplication example - Julia port of cuTile Python's AttentionFMHA.py sample +# Fused Multi-Head Attention example - Julia port of cuTile Python's AttentionFMHA.py sample # # SPDX-License-Identifier: Apache-2.0 From 5c950fd5829376d3068dcc6b7a30616fd2b71077 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Mon, 26 Jan 2026 14:49:26 +0100 Subject: [PATCH 5/8] fix off-by-ones --- examples/fmha.jl | 90 +++++++++++++++++++++++++++++++----------------- 1 file changed, 58 insertions(+), 32 deletions(-) diff --git a/examples/fmha.jl b/examples/fmha.jl index e069891..fb602d7 100644 --- a/examples/fmha.jl +++ b/examples/fmha.jl @@ -11,16 +11,14 @@ const INV_LOG_2 = Float32(1 / log(2)) const ConstInt = ct.Constant{Int} const ConstBool = ct.Constant{Bool} -# TODO: "latency" - # cuTile kernel for Fused Multi-Head Attention -# Q: d x +# Layout: (D, SeqLen, Heads, Batch) - Julia column-major function fmha_kernel( Q::ct.TileArray{T,4}, K::ct.TileArray{T,4}, V::ct.TileArray{T,4}, Out::ct.TileArray{T,4}, qk_scale::AbstractFloat, input_pos::Integer, TILE_D::ConstInt, - H::ConstInt, # number of heads? 
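+    # H = number of query heads; with QUERY_GROUP_SIZE > 1 each of the Heads ÷ QUERY_GROUP_SIZE
+    # KV heads is shared by QUERY_GROUP_SIZE query heads (grouped-query attention)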
+ H::ConstInt, TILE_M::ConstInt, TILE_N::ConstInt, QUERY_GROUP_SIZE::ConstInt, @@ -35,11 +33,14 @@ function fmha_kernel( qk_scale = Float32(qk_scale) * Float32(INV_LOG_2) - # Offsets for query tile (M-dimension) - offs_m = bid_x * TILE_M[] .+ ct.arange((TILE_M[],), Int32) .+ input_pos + # Offsets for query tile (M-dimension) - 0-indexed positions + # bid_x is 1-indexed, so first tile (bid_x=1) has positions [0, TILE_M-1] + # NOTE: ct.arange is 1-indexed in Julia (returns [1,2,...,N]), so subtract 1 for 0-indexed + offs_m = ct.reshape((bid_x - Int32(1)) * TILE_M[] .+ (ct.arange((TILE_M[],), Int32) .- Int32(1)) .+ input_pos, (1, TILE_M[])) # local offsets for key/value tile (N-dimension) - offs_n_tile = ct.reshape(ct.arange((TILE_N[],), Int32), (1, TILE_N[])) + # NOTE: ct.arange is 1-indexed in Julia, subtract 1 for 0-indexed + offs_n_tile = ct.reshape(ct.arange((TILE_N[],), Int32) .- Int32(1), (TILE_N[], 1)) # online softmax accumulators in Float32 for stability m_i = ct.full((1, TILE_M[]), -Inf32, Float32) @@ -50,64 +51,73 @@ function fmha_kernel( q = ct.load(Q, (1, bid_x, head_idx, batch_idx), (TILE_D[], TILE_M[], 1, 1)) q = ct.reshape(q, (TILE_D[], TILE_M[])) - m_end = input_pos + (bid_x + 1) * TILE_M[] + # m_end: one past the last query position in this tile + m_end = input_pos + bid_x * TILE_M[] k_seqlen = K.sizes[2] + if CAUSAL[] - # when kv pos could exceed q pos - mask_start = cld(input_pos + bid_x * TILE_M[], TILE_N[]) - # when kv pos could exceed k_seqlen - mask_start = min(mask_start, cld(k_seqlen, TILE_N[])) + # Python: mask_start = (input_pos + bid_x * TILE_M) // TILE_N + # In Julia with 1-indexed bid_x: mask_start = (input_pos + (bid_x-1) * TILE_M) // TILE_N + 1 + mask_start = div(input_pos + (bid_x - Int32(1)) * TILE_M[], TILE_N[]) + Int32(1) + # Python: mask_start = min(mask_start, k_seqlen // TILE_N) + mask_start = min(mask_start, div(k_seqlen, TILE_N[]) + Int32(1)) Tc = cld(min(m_end, k_seqlen), TILE_N[]) else Tc = cld(k_seqlen, TILE_N[]) - mask_start = cld(k_seqlen, TILE_N[]) + # Python: mask_start = k_seqlen // TILE_N + mask_start = div(k_seqlen, TILE_N[]) + Int32(1) end # loop over K, V blocks (N-dimension chunks) j = Int32(1) while j <= Tc - k = ct.load(K, (1, j, off_kv_h, batch_idx), (TILE_D[], TILE_N[], 1, 1)) + k = ct.load(K, (1, j, off_kv_h, batch_idx), (TILE_D[], TILE_N[], 1, 1); padding_mode=ct.PaddingMode.Zero) k = ct.reshape(k, (TILE_D[], TILE_N[])) k = ct.transpose(k) qk = ct.zeros((TILE_N[], TILE_M[]), Float32) qk = ct.muladd(k, q, qk) + # Apply masking (matches Python: if (CAUSAL or not EVEN_K) and j >= mask_start) if (CAUSAL[] || !EVEN_K[]) && j >= mask_start - offs_n = j * TILE_N[] + offs_n_tile - mask = ct.full((TILE_N[], TILE_M[]), true, Bool) + offs_n = (j - Int32(1)) * TILE_N[] .+ offs_n_tile + # Build mask: start with all true + valid_mask = ct.full((TILE_N[], TILE_M[]), true, Bool) + # out of bound mask (Python: if not EVEN_K: mask = mask & (offs_n < k_seqlen)) if !EVEN_K[] - mask = mask .& (offs_n .< k_seqlen) + valid_mask = valid_mask .& (offs_n .< k_seqlen) end + # causal mask (Python: if CAUSAL: mask = mask & (offs_m >= offs_n)) if CAUSAL[] - mask = mask .& (offs_m .>= offs_n) + valid_mask = valid_mask .& (offs_m .>= offs_n) end - mask = ct.where(mask, -Inf32, Float32) - qk = qk .+ mask + # Apply mask: set invalid positions to -Inf + qk = ct.where(valid_mask, qk, ct.full((TILE_N[], TILE_M[]), -Inf32, Float32)) end - # moving qk_scale multiplication after reduce_max + # Online Softmax Update + # Moving qk_scale multiplication after reduce_max is 
to improve performance m_ij = max.(m_i, (ct.reduce_max(qk, 1) * qk_scale)) qk = qk * qk_scale .- m_ij - + # attention weights [TILE_N, TILE_M] - p = exp2.(qk) # might need to expose "flush_to_zero" + p = exp2.(qk) # TODO: flush_to_zero=True l_ij = ct.reduce_sum(p, 1) - alpha = exp2.(m_i .- m_ij) # flush to zero? + alpha = exp2.(m_i .- m_ij) # TODO: flush_to_zero=True l_i = l_i .* alpha .+ l_ij acc = acc .* alpha - v = ct.load(V, (1, j, off_kv_h, batch_idx), (TILE_D[], TILE_N[], 1, 1)) + v = ct.load(V, (1, j, off_kv_h, batch_idx), (TILE_D[], TILE_N[], 1, 1); padding_mode=ct.PaddingMode.Zero) v = ct.reshape(v, (TILE_D[], TILE_N[])) p = ct.astype(p, eltype(q)) - acc = ct.muladd(v, p, acc) # [TILE_D, TILE_M] + acc = ct.muladd(v, p, acc) m_i = m_ij j += Int32(1) end - acc = acc ./ l_i # flush to zero? rounding mode? + acc = acc ./ l_i # TODO: flush_to_zero=True, rounding_mode=APPROX acc = ct.reshape(acc, (TILE_D[], TILE_M[], 1, 1)) ct.store(Out, (1, bid_x, head_idx, batch_idx), acc) @@ -149,7 +159,7 @@ function cutile_fmha(Q::AbstractArray{T,4}, K::AbstractArray{T,4}, V::AbstractAr grid_x = cld(SeqLen_Q, tile_m) grid_y = Heads * Batch grid = (grid_x, grid_y, 1) - + ct.launch(fmha_kernel, grid, Q, K, V, Out, qk_scale, input_pos, @@ -160,7 +170,7 @@ function cutile_fmha(Q::AbstractArray{T,4}, K::AbstractArray{T,4}, V::AbstractAr ct.Constant(query_group_size), ct.Constant(causal), ct.Constant(even_k)) - + return Out end @@ -187,7 +197,7 @@ function test_fmha(::Type{T}, Q = CUDA.randn(T, D_k, SeqLen_Q, Heads, Batch) K = CUDA.randn(T, D_k, SeqLen_KV, KV_heads, Batch) V = CUDA.randn(T, D_v, SeqLen_KV, KV_heads, Batch) - + out_cutile = cutile_fmha(Q, K, V; causal=causal, tile_m=tile_m, tile_n=tile_n, @@ -210,12 +220,28 @@ end function main() println("--- cuTile Fused Multi-Head Attention Examples ---\n") - # Float32 tests, causal=false + # Float32 tests, causal=false, EVEN_K=true + println("Non-causal, EVEN_K=true:") test_fmha(Float32, 64, 256, 8, 2, 64, 256, 8, false, 32, 32) test_fmha(Float32, 64, 256, 8, 2, 64, 128, 8, false, 32, 32) test_fmha(Float32, 64, 256, 8, 2, 64, 128, 4, false, 32, 32) - println("\n--- All batch matmul examples completed ---") + # Float32 tests, causal=true, EVEN_K=true + println("\nCausal, EVEN_K=true:") + test_fmha(Float32, 64, 256, 8, 2, 64, 256, 8, true, 32, 32) + test_fmha(Float32, 64, 128, 8, 2, 64, 128, 8, true, 32, 32) + test_fmha(Float32, 64, 64, 4, 1, 64, 64, 4, true, 32, 32) + + # Float32 tests, EVEN_K=false (K_SEQLEN not divisible by tile_n) + println("\nNon-causal, EVEN_K=false:") + test_fmha(Float32, 64, 64, 1, 1, 64, 50, 1, false, 32, 32) # 18 keys in partial tile + test_fmha(Float32, 64, 64, 1, 1, 64, 33, 1, false, 32, 32) # 1 key in partial tile + + println("\nCausal, EVEN_K=false:") + test_fmha(Float32, 64, 50, 1, 1, 64, 50, 1, true, 32, 32) # 18 keys in partial tile + test_fmha(Float32, 64, 33, 1, 1, 64, 33, 1, true, 32, 32) # 1 key in partial tile + + println("\n--- All FMHA tests completed ---") end isinteractive() || main() From 3fd6f0b98e8c0f7f7513cd7b1426a7e8a4481003 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Thu, 5 Feb 2026 19:52:42 +0100 Subject: [PATCH 6/8] fmha.jl -> attention.jl --- examples/attention.jl | 289 ++++++++++++++++++++++++++++++++++++++++++ examples/fmha.jl | 247 ------------------------------------ 2 files changed, 289 insertions(+), 247 deletions(-) create mode 100644 examples/attention.jl delete mode 100644 examples/fmha.jl diff --git a/examples/attention.jl b/examples/attention.jl new file mode 100644 index 
0000000..ad58df5 --- /dev/null +++ b/examples/attention.jl @@ -0,0 +1,289 @@ +# Fused Multi-Head Attention example - Julia port of cuTile Python's AttentionFMHA.py sample +# +# SPDX-License-Identifier: Apache-2.0 + +using CUDA +import cuTile as ct + +import NNlib +import CUDA.GPUArrays: AllocCache, @cached # more fair NNlib comparison + +const INV_LOG_2 = Float32(1 / log(2)) +const ConstInt = ct.Constant{Int} +const ConstBool = ct.Constant{Bool} + +# cuTile kernel for Fused Multi-Head Attention (FMHA) +# +# Computes attention output for a psecific batch item and head, +# using tiling and online softmax. +# +# Layout: (D, SeqLen, Heads, Batch) +function fmha_kernel( + Q::ct.TileArray{T,4}, + K::ct.TileArray{T,4}, + V::ct.TileArray{T,4}, + Out::ct.TileArray{T,4}, + qk_scale::AbstractFloat, + input_pos::Integer, + D_K::ConstInt, # Head dimension of Q and K + D_V::ConstInt, # Head dimension of V + H::ConstInt, + TILE_M::ConstInt, + TILE_N::ConstInt, + QUERY_GROUP_SIZE::ConstInt, + CAUSAL::ConstBool, + EVEN_K::ConstBool +) where T + # Map block IDs to batch and head indices + bid_x = ct.bid(1) + bid_y = ct.bid(2) + batch_idx, head_idx = fldmod1(bid_y, H[]) # floored division and modulus for 1-based indexing + off_kv_h = cld(head_idx, QUERY_GROUP_SIZE[]) + + # Adjust qk_scale for exp2 + qk_scale = Float32(qk_scale) * Float32(INV_LOG_2) + + # Initialize offsets for current query tile (M-dimension) + # bid_x is 1-indexed, so first tile (bid_x=1) has offsets [0, TILE_M-1] + offs_m = (bid_x - 1) * TILE_M[] .+ (ct.arange((TILE_M[],), Int32) .- 1) + offs_m = offs_m .+ input_pos + offs_m = reshape(offs_m, (1, TILE_M[])) + + # local offsets for key/value tile (N-dimension) + offs_n_tile = ct.arange((TILE_N[],), Int32) .- 1 + offs_n_tile = reshape(offs_n_tile, (TILE_N[], 1)) + + # online softmax accumulators in Float32 for stability + m_i = ct.full((1, TILE_M[]), -Inf32, Float32) + l_i = ct.zeros((1, TILE_M[]), Float32) + acc = ct.zeros((D_V[], TILE_M[]), Float32) + + # query tile for this batch, head, and M-chunk + q = ct.load(Q, (1, bid_x, head_idx, batch_idx), (D_K[], TILE_M[], 1, 1)) + q = reshape(q, (D_K[], TILE_M[])) + + # m_end: one past the last query position in this tile + m_end = input_pos + bid_x * TILE_M[] + k_seqlen = K.sizes[2] + if CAUSAL[] + # Python: mask_start = (input_pos + bid_x * TILE_M) // TILE_N + # In Julia with 1-indexed bid_x: mask_start = (input_pos + (bid_x-1) * TILE_M) // TILE_N + 1 + mask_start = fld(input_pos + (bid_x - 1) * TILE_M[], TILE_N[]) + 1 + # Python: mask_start = min(mask_start, k_seqlen // TILE_N) + mask_start = min(mask_start, fld(k_seqlen, TILE_N[]) + 1) + Tc = cld(min(m_end, k_seqlen), TILE_N[]) + else + Tc = cld(k_seqlen, TILE_N[]) + # Python: mask_start = k_seqlen // TILE_N + mask_start = fld(k_seqlen, TILE_N[]) + 1 + end + + # loop over K, V blocks (N-dimension chunks) + j = Int32(1) + while j <= Tc + k = ct.load( + K, (1, j, off_kv_h, batch_idx), (D_K[], TILE_N[], 1, 1), + latency=2) + k = reshape(k, (D_K[], TILE_N[])) + k = transpose(k) + + qk = ct.zeros((TILE_N[], TILE_M[]), Float32) + qk = ct.muladd(k, q, qk) + + # Apply masking (matches Python: if (CAUSAL or not EVEN_K) and j >= mask_start) + if (CAUSAL[] || !EVEN_K[]) && j >= mask_start + offs_n = (j - 1) * TILE_N[] .+ offs_n_tile + # Build mask: start with all true + mask = ct.full((TILE_N[], TILE_M[]), true, Bool) + # out of bound mask (Python: if not EVEN_K: mask = mask & (offs_n < k_seqlen)) + if !EVEN_K[] + mask = mask .& (offs_n .< k_seqlen) + end + # causal mask (Python: if CAUSAL: mask = mask 
& (offs_m >= offs_n)) + if CAUSAL[] + mask = mask .& (offs_m .>= offs_n) + end + # Apply mask: set invalid positions to -Inf + qk = ifelse.(mask, qk, -Inf32) + end + + # Online Softmax Update + # Moving qk_scale multiplication after reduce_max is to improve performance + m_ij = max.(m_i, maximum(qk, dims=1) * qk_scale) + qk = qk * qk_scale .- m_ij + + # attention weights [TILE_N, TILE_M] + p = exp2.(qk) # XXX: flush_to_zero=True + l_ij = sum(p, dims=1) + alpha = exp2.(m_i .- m_ij) # XXX: flush_to_zero=True + + l_i = l_i .* alpha .+ l_ij + acc = acc .* alpha + + v = ct.load( + V, (1, j, off_kv_h, batch_idx), (D_V[], TILE_N[], 1, 1), + latency=4) + v = reshape(v, (D_V[], TILE_N[])) + p = ct.astype(p, eltype(q)) + acc = ct.muladd(v, p, acc) + m_i = m_ij + + j += Int32(1) + end + + acc = acc ./ l_i # XXX: flush_to_zero=True, rounding_mode=APPROX + acc = reshape(acc, (D_V[], TILE_M[], 1, 1)) + acc = ct.astype(acc, eltype(Out)) + ct.store(Out, (1, bid_x, head_idx, batch_idx), acc) + + return +end + +function prepare(; benchmark::Bool=false, + D_k::Int=64, + SeqLen_Q::Int=benchmark ? 4096 : 256, + Heads::Int=4, + Batch::Int=4, + D_v::Int=D_k, + SeqLen_KV::Int=SeqLen_Q, + Heads_KV::Int=Heads, + causal::Bool=false, + T::DataType=Float32) + return (; + Q = CUDA.randn(T, D_k, SeqLen_Q, Heads, Batch), + K = CUDA.randn(T, D_k, SeqLen_KV, Heads_KV, Batch), + V = CUDA.randn(T, D_v, SeqLen_KV, Heads_KV, Batch), + Out = CUDA.randn(T, D_v, SeqLen_Q, Heads, Batch), + D_k, SeqLen_Q, Heads, Batch, + D_v, SeqLen_KV, Heads_KV, causal + ) +end + +function run(data; tm::Int=64, tn::Int=64, nruns::Int=1, warmup::Int=0) + (; Q, K, V, Out, D_k, D_v, SeqLen_Q, Heads, Batch, SeqLen_KV, Heads_KV, causal) = data + grid_x = cld(SeqLen_Q, tm) + grid_y = Heads * Batch + grid = (grid_x, grid_y) + + qk_scale = 1 / sqrt(D_k) + input_pos = 0 + + query_group_size, remainder = divrem(Heads, Heads_KV) + @assert remainder == 0 + + even_k = (SeqLen_KV % tn) == 0 + + CUDA.@sync for _ in 1:warmup + ct.launch(fmha_kernel, grid, Q, K, V, Out, + qk_scale, input_pos, + ct.Constant(D_k), ct.Constant(D_v), ct.Constant(Heads), + ct.Constant(tm), ct.Constant(tn), + ct.Constant(query_group_size), + ct.Constant(causal), ct.Constant(even_k)) + end + + times = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed ct.launch(fmha_kernel, grid, Q, K, V, Out, + qk_scale, input_pos, + ct.Constant(D_k), ct.Constant(D_v), ct.Constant(Heads), + ct.Constant(tm), ct.Constant(tn), + ct.Constant(query_group_size), + ct.Constant(causal), ct.Constant(even_k)) + push!(times, t * 1000) + end + + return (; Out, times) +end + +function nnlib_attention( + Q::AbstractArray{T,4}, K::AbstractArray{T,4}, V::AbstractArray{T,4}; + causal::Bool = false, +) where T + mask = causal ? 
NNlib.make_causal_mask(Q; dims=2) : nothing + query_group_size = cld(size(Q, 3), size(K, 3)) + if query_group_size > 1 + K, V = repeat.((K, V), inner=(1, 1, query_group_size, 1)) + end + Out, _ = NNlib.dot_product_attention(Q, K, V; mask) + return Out +end + +function verify(data, result) + # run on GPU for proper accumulation + expected = nnlib_attention(data.Q, data.K, data.V; data.causal) + @assert isapprox(expected, result.Out, rtol=1e-2) "max diff: $(maximum(abs, result.Out - expected))" +end + +#============================================================================= + Reference implementations for benchmarking +=============================================================================# + +function run_others(data; nruns::Int=1, warmup::Int=0) + (; Q, K, V, causal) = data + results = Dict{String, Vector{Float64}}() + + cache = AllocCache() + + CUDA.@sync for _ in 1:warmup + @cached cache nnlib_attention(Q, K, V; causal) + end + times = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed @cached cache nnlib_attention(Q, K, V; causal) + push!(times, t * 1000) + end + results["NNlib"] = times + + return results +end + +#============================================================================= + Main +=============================================================================# + +function test_attention(::Type{T}, + D_k, SeqLen_Q, Heads, Batch, + D_v, SeqLen_KV, Heads_KV, + causal, tm, tn; + name=nothing +) where T + name = something(name, + join([ + T, + "tile=$tm×$tn", + "Q=$D_k×$SeqLen_Q", + "K=$D_k×$SeqLen_KV", + "V=$D_v×$SeqLen_KV", + "Heads=$Heads/$Heads_KV", + "Batch=$Batch", + "causal=$causal" + ], ", ")) + println("--- $name ---") + data = prepare(; T, D_k, SeqLen_Q, Heads, Batch, D_v, SeqLen_KV, Heads_KV, causal) + result = run(data; tm, tn) + verify(data, result) + println(" passed") +end + +function main() + println("--- cuTile Fused Multi-Head Attention Examples ---\n") + + for T in (Float32, Float16) + # basic + test_attention(T, 64, 256, 8, 2, 64, 256, 8, false, 32, 32) + test_attention(T, 64, 256, 8, 2, 64, 128, 4, false, 32, 64) + test_attention(T, 64, 256, 8, 2, 64, 256, 8, true, 32, 32) + + # uneven seqlen + test_attention(T, 64, 128, 4, 1, 64, 97, 2, false, 32, 32) + test_attention(T, 64, 127, 4, 1, 64, 127, 4, true, 32, 32) + + # D_k != D_v + test_attention(T, 64, 256, 8, 2, 32, 256, 4, false, 32, 32) + end + + println("\n--- All attention examples completed ---") +end + +isinteractive() || main() diff --git a/examples/fmha.jl b/examples/fmha.jl deleted file mode 100644 index fb602d7..0000000 --- a/examples/fmha.jl +++ /dev/null @@ -1,247 +0,0 @@ -# Fused Multi-Head Attention example - Julia port of cuTile Python's AttentionFMHA.py sample -# -# SPDX-License-Identifier: Apache-2.0 - -using CUDA -import cuTile as ct - -import NNlib - -const INV_LOG_2 = Float32(1 / log(2)) -const ConstInt = ct.Constant{Int} -const ConstBool = ct.Constant{Bool} - -# cuTile kernel for Fused Multi-Head Attention -# Layout: (D, SeqLen, Heads, Batch) - Julia column-major -function fmha_kernel( - Q::ct.TileArray{T,4}, K::ct.TileArray{T,4}, V::ct.TileArray{T,4}, Out::ct.TileArray{T,4}, - qk_scale::AbstractFloat, - input_pos::Integer, - TILE_D::ConstInt, - H::ConstInt, - TILE_M::ConstInt, - TILE_N::ConstInt, - QUERY_GROUP_SIZE::ConstInt, - CAUSAL::ConstBool, - EVEN_K::ConstBool -) where T - bid_x = ct.bid(1) - bid_y = ct.bid(2) - batch_idx = cld(bid_y, H[]) - head_idx = mod1(bid_y, H[]) - off_kv_h = cld(head_idx, QUERY_GROUP_SIZE[]) - - qk_scale = Float32(qk_scale) * 
Float32(INV_LOG_2) - - # Offsets for query tile (M-dimension) - 0-indexed positions - # bid_x is 1-indexed, so first tile (bid_x=1) has positions [0, TILE_M-1] - # NOTE: ct.arange is 1-indexed in Julia (returns [1,2,...,N]), so subtract 1 for 0-indexed - offs_m = ct.reshape((bid_x - Int32(1)) * TILE_M[] .+ (ct.arange((TILE_M[],), Int32) .- Int32(1)) .+ input_pos, (1, TILE_M[])) - - # local offsets for key/value tile (N-dimension) - # NOTE: ct.arange is 1-indexed in Julia, subtract 1 for 0-indexed - offs_n_tile = ct.reshape(ct.arange((TILE_N[],), Int32) .- Int32(1), (TILE_N[], 1)) - - # online softmax accumulators in Float32 for stability - m_i = ct.full((1, TILE_M[]), -Inf32, Float32) - l_i = ct.zeros((1, TILE_M[]), Float32) - acc = ct.zeros((TILE_D[], TILE_M[]), Float32) - - # query tile for this batch, head, and M-chunk - q = ct.load(Q, (1, bid_x, head_idx, batch_idx), (TILE_D[], TILE_M[], 1, 1)) - q = ct.reshape(q, (TILE_D[], TILE_M[])) - - # m_end: one past the last query position in this tile - m_end = input_pos + bid_x * TILE_M[] - k_seqlen = K.sizes[2] - - if CAUSAL[] - # Python: mask_start = (input_pos + bid_x * TILE_M) // TILE_N - # In Julia with 1-indexed bid_x: mask_start = (input_pos + (bid_x-1) * TILE_M) // TILE_N + 1 - mask_start = div(input_pos + (bid_x - Int32(1)) * TILE_M[], TILE_N[]) + Int32(1) - # Python: mask_start = min(mask_start, k_seqlen // TILE_N) - mask_start = min(mask_start, div(k_seqlen, TILE_N[]) + Int32(1)) - Tc = cld(min(m_end, k_seqlen), TILE_N[]) - else - Tc = cld(k_seqlen, TILE_N[]) - # Python: mask_start = k_seqlen // TILE_N - mask_start = div(k_seqlen, TILE_N[]) + Int32(1) - end - - # loop over K, V blocks (N-dimension chunks) - j = Int32(1) - while j <= Tc - k = ct.load(K, (1, j, off_kv_h, batch_idx), (TILE_D[], TILE_N[], 1, 1); padding_mode=ct.PaddingMode.Zero) - k = ct.reshape(k, (TILE_D[], TILE_N[])) - k = ct.transpose(k) - - qk = ct.zeros((TILE_N[], TILE_M[]), Float32) - qk = ct.muladd(k, q, qk) - - # Apply masking (matches Python: if (CAUSAL or not EVEN_K) and j >= mask_start) - if (CAUSAL[] || !EVEN_K[]) && j >= mask_start - offs_n = (j - Int32(1)) * TILE_N[] .+ offs_n_tile - # Build mask: start with all true - valid_mask = ct.full((TILE_N[], TILE_M[]), true, Bool) - # out of bound mask (Python: if not EVEN_K: mask = mask & (offs_n < k_seqlen)) - if !EVEN_K[] - valid_mask = valid_mask .& (offs_n .< k_seqlen) - end - # causal mask (Python: if CAUSAL: mask = mask & (offs_m >= offs_n)) - if CAUSAL[] - valid_mask = valid_mask .& (offs_m .>= offs_n) - end - # Apply mask: set invalid positions to -Inf - qk = ct.where(valid_mask, qk, ct.full((TILE_N[], TILE_M[]), -Inf32, Float32)) - end - - # Online Softmax Update - # Moving qk_scale multiplication after reduce_max is to improve performance - m_ij = max.(m_i, (ct.reduce_max(qk, 1) * qk_scale)) - qk = qk * qk_scale .- m_ij - - # attention weights [TILE_N, TILE_M] - p = exp2.(qk) # TODO: flush_to_zero=True - l_ij = ct.reduce_sum(p, 1) - alpha = exp2.(m_i .- m_ij) # TODO: flush_to_zero=True - - l_i = l_i .* alpha .+ l_ij - acc = acc .* alpha - - v = ct.load(V, (1, j, off_kv_h, batch_idx), (TILE_D[], TILE_N[], 1, 1); padding_mode=ct.PaddingMode.Zero) - v = ct.reshape(v, (TILE_D[], TILE_N[])) - p = ct.astype(p, eltype(q)) - acc = ct.muladd(v, p, acc) - m_i = m_ij - - j += Int32(1) - end - - acc = acc ./ l_i # TODO: flush_to_zero=True, rounding_mode=APPROX - acc = ct.reshape(acc, (TILE_D[], TILE_M[], 1, 1)) - ct.store(Out, (1, bid_x, head_idx, batch_idx), acc) - - return -end - -function 
cutile_fmha(Q::AbstractArray{T,4}, K::AbstractArray{T,4}, V::AbstractArray{T,4}; - qk_scale::Union{AbstractFloat,Nothing} = nothing, - input_pos::Integer = 0, - tile_m::Integer = 128, - tile_n::Integer = 128, - query_group_size::Integer = 1, - causal::Bool = false, -) where T - if size(Q, 4) != size(K, 4) || size(Q, 4) != size(V, 4) - throw(ArgumentError("Batch dimensions must match for Q, K, V.")) - end - if size(Q, 3) % query_group_size != 0 - throw(ArgumentError("Number of query heads must be divisible by query_group_size.")) - end - if size(K, 3) * query_group_size != size(Q, 3) - throw(ArgumentError("K_heads * query_group_size must equal Q_heads.")) - end - if size(Q, 1) != size(K, 1) - throw(ArgumentError("D_k (first dim of Q and K) must match.")) - end - if size(K, 2) != size(V, 2) - throw(ArgumentError("SeqLen_KV (dim 2 of K and V) must match.")) - end - - D_k, SeqLen_Q, Heads, Batch = size(Q) - D_v, SeqLen_KV, KV_heads, _ = size(V) - even_k = (SeqLen_KV % tile_n) == 0 - - isnothing(qk_scale) && (qk_scale = 1 / sqrt(D_k)) - - Out = CUDA.zeros(T, D_v, SeqLen_Q, Heads, Batch) - - grid_x = cld(SeqLen_Q, tile_m) - grid_y = Heads * Batch - grid = (grid_x, grid_y, 1) - - ct.launch(fmha_kernel, grid, - Q, K, V, Out, - qk_scale, input_pos, - ct.Constant(D_k), - ct.Constant(Heads), - ct.Constant(tile_m), - ct.Constant(tile_n), - ct.Constant(query_group_size), - ct.Constant(causal), - ct.Constant(even_k)) - - return Out -end - -function nnlib_fmha(Q::AbstractArray{T,4}, K::AbstractArray{T,4}, V::AbstractArray{T,4}; - query_group_size::Integer = 1, - causal::Bool = false, -) where T - mask = causal ? NNlib.make_causal_mask(Q; dims=2) : nothing - if query_group_size > 1 - K, V = repeat.((K, V), inner=(1, 1, query_group_size, 1)) - end - Out, _ = NNlib.dot_product_attention(Q, K, V; mask) - return Out -end - - -function test_fmha(::Type{T}, - D_k, SeqLen_Q, Heads, Batch, - D_v, SeqLen_KV, KV_heads, - causal, tile_m, tile_n, -) where T - query_group_size = Heads ÷ KV_heads - - Q = CUDA.randn(T, D_k, SeqLen_Q, Heads, Batch) - K = CUDA.randn(T, D_k, SeqLen_KV, KV_heads, Batch) - V = CUDA.randn(T, D_v, SeqLen_KV, KV_heads, Batch) - - out_cutile = cutile_fmha(Q, K, V; - causal=causal, - tile_m=tile_m, tile_n=tile_n, - query_group_size=query_group_size) - - Q_cpu = Array(Q) - K_cpu = Array(K) - V_cpu = Array(V) - expected = nnlib_fmha(Q_cpu, K_cpu, V_cpu; query_group_size, causal) - result = Array(out_cutile) - - if isapprox(result, expected, rtol=1e-2, atol=1e-2) - println(" passed") - else - max_diff = maximum(abs.(result - expected)) - println(" FAILED (max diff: $max_diff)") - end -end - -function main() - println("--- cuTile Fused Multi-Head Attention Examples ---\n") - - # Float32 tests, causal=false, EVEN_K=true - println("Non-causal, EVEN_K=true:") - test_fmha(Float32, 64, 256, 8, 2, 64, 256, 8, false, 32, 32) - test_fmha(Float32, 64, 256, 8, 2, 64, 128, 8, false, 32, 32) - test_fmha(Float32, 64, 256, 8, 2, 64, 128, 4, false, 32, 32) - - # Float32 tests, causal=true, EVEN_K=true - println("\nCausal, EVEN_K=true:") - test_fmha(Float32, 64, 256, 8, 2, 64, 256, 8, true, 32, 32) - test_fmha(Float32, 64, 128, 8, 2, 64, 128, 8, true, 32, 32) - test_fmha(Float32, 64, 64, 4, 1, 64, 64, 4, true, 32, 32) - - # Float32 tests, EVEN_K=false (K_SEQLEN not divisible by tile_n) - println("\nNon-causal, EVEN_K=false:") - test_fmha(Float32, 64, 64, 1, 1, 64, 50, 1, false, 32, 32) # 18 keys in partial tile - test_fmha(Float32, 64, 64, 1, 1, 64, 33, 1, false, 32, 32) # 1 key in partial tile - - 
println("\nCausal, EVEN_K=false:") - test_fmha(Float32, 64, 50, 1, 1, 64, 50, 1, true, 32, 32) # 18 keys in partial tile - test_fmha(Float32, 64, 33, 1, 1, 64, 33, 1, true, 32, 32) # 1 key in partial tile - - println("\n--- All FMHA tests completed ---") -end - -isinteractive() || main() From 1e40c305b09465263f456e9a3648288568ad44cb Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Thu, 5 Feb 2026 19:53:05 +0100 Subject: [PATCH 7/8] add attention.py, fix transpose.py --- examples/attention.py | 337 ++++++++++++++++++++++++++++++++++++++++++ examples/transpose.py | 2 +- 2 files changed, 338 insertions(+), 1 deletion(-) create mode 100644 examples/attention.py diff --git a/examples/attention.py b/examples/attention.py new file mode 100644 index 0000000..4039add --- /dev/null +++ b/examples/attention.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 +""" +Fused Multi-Head Attention example - cuTile Python +Julia port equivalent with prepare/run/verify pattern for benchmarking. +""" + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import numpy as np +import cuda.tile as ct +from cuda.tile import RoundingMode as RMd +from math import ceil, sqrt + +from torch.nn.functional import scaled_dot_product_attention +from torch.nn.attention import sdpa_kernel, SDPBackend + +INV_LOG_2 = 1.0 / np.log(2) +ConstInt = ct.Constant[int] +ConstBool = ct.Constant[bool] + + +@ct.kernel(occupancy=2) +def fmha_kernel(Q, K, V, Out, + qk_scale: float, + input_pos: int, + D_K: ConstInt, # Head dimension of Q and K + D_V: ConstInt, # Head dimension of V + H: ConstInt, + TILE_M: ConstInt, + TILE_N: ConstInt, + QUERY_GROUP_SIZE: ConstInt, + CAUSAL: ConstBool, + EVEN_K: ConstBool): + """ + cuTile kernel for Fused Multi-Head Attention (FMHA). + Computes attention output for a specific batch item and head, using tiling and online softmax. 
+ + Layout: (Batch, Heads, SeqLen, D) + """ + # Map block IDs to batch and head indices + bid_x = ct.bid(0) + bid_y = ct.bid(1) + batch_idx = bid_y // H + head_idx = bid_y % H + off_kv_h = head_idx // QUERY_GROUP_SIZE + + # Adjust qk_scale for exp2 + qk_scale = qk_scale * INV_LOG_2 + + # Initialize offsets for current query tile (M-dimension) + offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=np.int32) # [TILE_M] + offs_m += input_pos + offs_m = offs_m[:, None] # [TILE_M, 1] + + # Initialize local offsets for key/value tile (N-dimension) + offs_n_tile = ct.arange(TILE_N, dtype=np.int32) # [TILE_N] + offs_n_tile = offs_n_tile[None, :] # [1, TILE_N] + + # Initialize online softmax accumulators in float32 for stability + m_i = ct.full((TILE_M, 1), -np.inf, dtype=np.float32) + l_i = ct.full((TILE_M, 1), 0.0, dtype=np.float32) + acc = ct.full((TILE_M, D_V), 0.0, dtype=np.float32) + + # Load query tile for this batch, head, and M-chunk + q = ct.load( + Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, D_K) + ).reshape((TILE_M, D_K)) # [TILE_M, D_K] + + # loop over k, v and update accumulator + m_end = input_pos + (bid_x + 1) * TILE_M + k_seqlen = K.shape[2] + if CAUSAL: + # when kv pos could exceed q pos + mask_start = (input_pos + bid_x * TILE_M) // TILE_N + # when kv pos could exceed k_seqlen + mask_start = min(mask_start, k_seqlen // TILE_N) + Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N) + else: + Tc = ct.cdiv(k_seqlen, TILE_N) + mask_start = k_seqlen // TILE_N + + # Loop over K, V blocks (N-dimension chunks) + for j in range(0, Tc): + # --- Compute QK product --- + k = ct.load( + K, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, D_K, TILE_N), + order=(0, 1, 3, 2), + latency=2, + ) + k = k.reshape((D_K, TILE_N)) # [D_K, TILE_N] + qk = ct.full((TILE_M, TILE_N), 0., dtype=np.float32) + qk = ct.mma(q, k, qk) # [TILE_M, TILE_N] + + # --- Apply Causal Masking --- + if (CAUSAL or not EVEN_K) and j >= mask_start: + offs_n = j * TILE_N + offs_n_tile + mask = ct.full((TILE_M, TILE_N), True, dtype=np.bool) + # out of bound mask + if not EVEN_K: + mask = mask & (offs_n < k_seqlen) + # causal mask + if CAUSAL: + mask = mask & (offs_m >= offs_n) # [TILE_M, TILE_N] + mask = ct.where(mask, 0.0, -np.inf) # [TILE_M, TILE_N] + qk += mask + + # --- Online Softmax Update --- + # Moving qk_scale multiplication after reduce_max is to improve performance. 
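+        # (qk_scale was pre-multiplied by 1/log(2) above, so the exp2() calls below compute
+        #  the same values as exp() with the original scale, via the cheaper exp2 path.)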
+ m_ij = max(m_i, ct.max(qk, axis=-1, keepdims=True) * qk_scale) + qk = qk * qk_scale - m_ij # [TILE_M, TILE_N] + + # attention weights + p = ct.exp2(qk, flush_to_zero=True) # [TILE_M, TILE_N] + l_ij = ct.sum(p, axis=-1, keepdims=True) # [TILE_M, 1] + alpha = ct.exp2(m_i - m_ij, flush_to_zero=True) # [TILE_M, 1] + # update m_i and l_i + l_i = l_i * alpha + l_ij # [TILE_M, 1] + # scale acc + acc = acc * alpha # [TILE_M, D_V] + + # --- Compute PV product --- + v = ct.load( + V, index=(batch_idx, off_kv_h, j, 0), shape=(1, 1, TILE_N, D_V), + latency=4, + ).reshape((TILE_N, D_V)) # [TILE_N, D_V] + p = p.astype(Q.dtype) + acc = ct.mma(p, v, acc) # [TILE_M, D_V] + m_i = m_ij # [TILE_M, 1] + + # --- Final Normalization and Store --- + acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX) + acc = acc.reshape((1, 1, TILE_M, D_V)).astype(Out.dtype) + ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc) + + +#============================================================================= +# Example harness +#============================================================================= + +def prepare(*, benchmark: bool = False, + D_k: int = 64, + SeqLen_Q: int = None, + Heads: int = 4, + Batch: int = 4, + D_v: int = None, + SeqLen_KV: int = None, + Heads_KV: int = None, + causal: bool = False, + dtype=torch.float32): + """Allocate and initialize data for FMHA.""" + if SeqLen_Q is None: + SeqLen_Q = 4096 if benchmark else 256 + if D_v is None: + D_v = D_k + if SeqLen_KV is None: + SeqLen_KV = SeqLen_Q + if Heads_KV is None: + Heads_KV = Heads + + # Layout: (Batch, Heads, SeqLen, D) + return { + "Q": torch.randn(Batch, Heads, SeqLen_Q, D_k, dtype=dtype, device='cuda'), + "K": torch.randn(Batch, Heads_KV, SeqLen_KV, D_k, dtype=dtype, device='cuda'), + "V": torch.randn(Batch, Heads_KV, SeqLen_KV, D_v, dtype=dtype, device='cuda'), + "Out": torch.empty(Batch, Heads, SeqLen_Q, D_v, dtype=dtype, device='cuda'), + "D_k": D_k, + "D_v": D_v, + "SeqLen_Q": SeqLen_Q, + "SeqLen_KV": SeqLen_KV, + "Heads": Heads, + "Heads_KV": Heads_KV, + "Batch": Batch, + "causal": causal, + } + + +def run(data, *, tm: int = 64, tn: int = 64, nruns: int = 1, warmup: int = 0): + """Run FMHA kernel with timing.""" + Q, K, V, Out = data["Q"], data["K"], data["V"], data["Out"] + D_k, D_v = data["D_k"], data["D_v"] + SeqLen_Q, SeqLen_KV = data["SeqLen_Q"], data["SeqLen_KV"] + Heads, Heads_KV, Batch = data["Heads"], data["Heads_KV"], data["Batch"] + causal = data["causal"] + + grid_x = ceil(SeqLen_Q / tm) + grid_y = Heads * Batch + grid = (grid_x, grid_y, 1) + + qk_scale = 1.0 / sqrt(D_k) + input_pos = 0 + + query_group_size, remainder = divmod(Heads, Heads_KV) + assert remainder == 0, "Heads must be divisible by Heads_KV" + + even_k = (SeqLen_KV % tn) == 0 + + stream = torch.cuda.current_stream() + + # Warmup + for _ in range(warmup): + ct.launch(stream, grid, fmha_kernel, ( + Q, K, V, Out, + qk_scale, input_pos, + D_k, D_v, Heads, + tm, tn, + query_group_size, + causal, even_k + )) + torch.cuda.synchronize() + + # Timed runs + times = [] + for _ in range(nruns): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + ct.launch(stream, grid, fmha_kernel, ( + Q, K, V, Out, + qk_scale, input_pos, + D_k, D_v, Heads, + tm, tn, + query_group_size, + causal, even_k + )) + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) # ms + + return {"Out": Out, "times": times} + + +def torch_sdpa(Q, K, V, *, causal: bool = False, enable_gqa: 
bool = False): + """Reference scaled dot-product attention using PyTorch.""" + # Use MATH backend as fallback (works with all dtypes) + # cuDNN/Flash only support float16/bfloat16 + with sdpa_kernel(SDPBackend.MATH): + return scaled_dot_product_attention(Q, K, V, is_causal=causal, enable_gqa=enable_gqa) + + +def verify(data, result): + """Verify FMHA results against reference implementation.""" + Q, K, V = data["Q"], data["K"], data["V"] + causal = data["causal"] + Heads, Heads_KV = data["Heads"], data["Heads_KV"] + + enable_gqa = Heads != Heads_KV + expected = torch_sdpa(Q, K, V, causal=causal, enable_gqa=enable_gqa) + actual = result["Out"] + + max_diff = float(torch.max(torch.abs(actual - expected))) + assert torch.allclose(actual, expected, rtol=1e-2, atol=1e-2), \ + f"FMHA mismatch! max diff: {max_diff}" + + +#============================================================================= +# Reference implementations for benchmarking +#============================================================================= + +def run_others(data, *, nruns: int = 1, warmup: int = 0): + """Run reference implementations for comparison.""" + results = {} + Q, K, V = data["Q"], data["K"], data["V"] + causal = data["causal"] + Heads, Heads_KV = data["Heads"], data["Heads_KV"] + enable_gqa = Heads != Heads_KV + + # PyTorch SDPA (uses cuDNN or Flash Attention) + for _ in range(warmup): + _ = torch_sdpa(Q, K, V, causal=causal, enable_gqa=enable_gqa) + torch.cuda.synchronize() + + times_torch = [] + for _ in range(nruns): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + _ = torch_sdpa(Q, K, V, causal=causal, enable_gqa=enable_gqa) + end.record() + torch.cuda.synchronize() + times_torch.append(start.elapsed_time(end)) + results["PyTorch SDPA"] = times_torch + + return results + + +#============================================================================= +# Main +#============================================================================= + +def test_attention(dtype, D_k, SeqLen_Q, Heads, Batch, D_v, SeqLen_KV, Heads_KV, + causal, tm, tn, name=None): + """Test attention with given parameters.""" + if name is None: + dtype_name = str(dtype).split('.')[-1] + name = ", ".join([ + dtype_name, + f"tile={tm}x{tn}", + f"Q={D_k}x{SeqLen_Q}", + f"K={D_k}x{SeqLen_KV}", + f"V={D_v}x{SeqLen_KV}", + f"Heads={Heads}/{Heads_KV}", + f"Batch={Batch}", + f"causal={causal}" + ]) + print(f"--- {name} ---") + data = prepare( + D_k=D_k, SeqLen_Q=SeqLen_Q, Heads=Heads, Batch=Batch, + D_v=D_v, SeqLen_KV=SeqLen_KV, Heads_KV=Heads_KV, + causal=causal, dtype=dtype + ) + result = run(data, tm=tm, tn=tn) + verify(data, result) + print(" passed") + + +def main(): + print("--- cuTile Fused Multi-Head Attention Examples ---\n") + + for dtype in (torch.float32, torch.float16): + # basic + test_attention(dtype, 64, 256, 8, 2, 64, 256, 8, False, 32, 32) + test_attention(dtype, 64, 256, 8, 2, 64, 128, 4, False, 32, 64) + test_attention(dtype, 64, 256, 8, 2, 64, 256, 8, True, 32, 32) + + # uneven seqlen + test_attention(dtype, 64, 127, 4, 1, 64, 127, 4, False, 32, 32) + test_attention(dtype, 64, 128, 4, 1, 64, 97, 2, False, 32, 32) + + # D_k != D_v + test_attention(dtype, 64, 256, 8, 2, 64, 256, 8, False, 32, 32) + + print("\n--- All attention examples completed ---") + + +if __name__ == "__main__": + main() diff --git a/examples/transpose.py b/examples/transpose.py index e299ef3..1996a3b 100644 --- a/examples/transpose.py +++ b/examples/transpose.py @@ -12,7 +12,7 @@ def 
transpose_cutile_kernel(input, output, tile_m: ct.Constant[int], tile_n: ct. pid_m = ct.bid(0) pid_n = ct.bid(1) tile = ct.load(input, index=(pid_m, pid_n), shape=(tile_m, tile_n)) - tile_t = transpose(tile) + tile_t = ct.transpose(tile) ct.store(output, index=(pid_n, pid_m), tile=tile_t) From 9ef252ce5380ac17e7e790281fd38d9cde7015a9 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Thu, 5 Feb 2026 20:58:08 +0100 Subject: [PATCH 8/8] remove mod1 overlay (fixed on main) --- src/language/arithmetic.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/language/arithmetic.jl b/src/language/arithmetic.jl index 20ce934..0440bcd 100644 --- a/src/language/arithmetic.jl +++ b/src/language/arithmetic.jl @@ -29,7 +29,6 @@ @overlay Base.div(x::T, y::T, ::typeof(RoundUp)) where {T <: Unsigned} = Intrinsics.cldi(x, y, SignednessUnsigned) @overlay Base.rem(x::T, y::T) where {T <: Signed} = Intrinsics.remi(x, y, SignednessSigned) @overlay Base.rem(x::T, y::T) where {T <: Unsigned} = Intrinsics.remi(x, y, SignednessUnsigned) -@overlay Base.mod1(x::T, y::T) where {T <: ScalarInt} = (m = mod(x, y); m == zero(m) ? y : m) # float @overlay Base.:+(x::T, y::T) where {T <: ScalarFloat} = Intrinsics.addf(x, y)