diff --git a/README.md b/README.md
index 85dd2ef..261d095 100644
--- a/README.md
+++ b/README.md
@@ -368,9 +368,13 @@
 result = sum(tile; dims=2)                    # (M, N) → (M, 1)
 result = dropdims(sum(tile; dims=2); dims=2)  # (M, N) → (M,)
 ```
 
-### Store reshaping
+### Automatic rank matching
 
-`ct.store` automatically reshapes the tile to match the target array's rank by dropping singleton dimensions (e.g., storing a `(1, N)` tile into a 1D array reshapes it to `(N,)`). Scalar `()` tiles are reshaped to `(1,)`.
+`ct.load` and `ct.store` automatically match the tile rank to that of the target:
+
+- **Lower rank**: trailing `1`s are appended. Loading `(M, N)` from a 4D array internally uses `(M, N, 1, 1)`. Storing a scalar tile into a 2D array pads to `(1, 1)`.
+- **Higher rank**: trailing `1`s are stripped. Storing `(M, 1)` into a 1D array reshapes to `(M,)`.
+Non-trailing singletons (e.g., from `sum(tile; dims=1)`) require explicit `dropdims`.
 
 ### Broadcasting shape alignment
diff --git a/examples/layernorm.jl b/examples/layernorm.jl
index 30b0d51..0950999 100644
--- a/examples/layernorm.jl
+++ b/examples/layernorm.jl
@@ -166,8 +166,8 @@ Accumulates partial gradients using atomic locks.
 Args:
     DX: Output gradient with respect to X (M, N).
     DY: Input gradient with respect to Y (M, N).
-    DW: Partial gradient with respect to W (GROUP_SIZE_M, N).
-    DB: Partial gradient with respect to B (GROUP_SIZE_M, N).
+    DW: Partial gradient with respect to W (N, GROUP_SIZE_M).
+    DB: Partial gradient with respect to B (N, GROUP_SIZE_M).
     X: Input tensor (M, N).
     W: Weight tensor (N,).
     Mean: Mean tensor (M,).
@@ -211,8 +211,8 @@ function layer_norm_bwd_dx_partial_dwdb(DX::ct.TileArray{Float32, 2}, DY::ct.Til
     tdx = (wdy .- (xhat .* c1 .+ c2)) .* rstd
     ct.store(DX, (bid_m, j), tdx)
 
-    partial_dw = tdy .* xhat
-    partial_db = tdy
+    partial_dw = reshape(tdy .* xhat, (TILE_N[], 1))
+    partial_db = reshape(tdy, (TILE_N[], 1))
 
     # Acquire spinlock
     while ct.atomic_cas(Locks, group_bid_m, 0, 1;
@@ -221,10 +221,10 @@ function layer_norm_bwd_dx_partial_dwdb(DX::ct.TileArray{Float32, 2}, DY::ct.Til
     end
 
     # Critical section: accumulate partial gradients
-    partial_dw = partial_dw .+ ct.load(DW, (group_bid_m, j), (1, TILE_N[]); padding_mode=ct.PaddingMode.Zero)
-    partial_db = partial_db .+ ct.load(DB, (group_bid_m, j), (1, TILE_N[]); padding_mode=ct.PaddingMode.Zero)
-    ct.store(DW, (group_bid_m, j), partial_dw)
-    ct.store(DB, (group_bid_m, j), partial_db)
+    partial_dw = partial_dw .+ ct.load(DW, (j, group_bid_m), (TILE_N[], 1); padding_mode=ct.PaddingMode.Zero)
+    partial_db = partial_db .+ ct.load(DB, (j, group_bid_m), (TILE_N[], 1); padding_mode=ct.PaddingMode.Zero)
+    ct.store(DW, (j, group_bid_m), partial_dw)
+    ct.store(DB, (j, group_bid_m), partial_db)
 
     # Release spinlock
     ct.atomic_xchg(Locks, group_bid_m, 0;
@@ -242,8 +242,8 @@ end
 Backward pass part 2: Final reduction for dW and dB.
 
 Args:
-    DW: Partial gradient with respect to W (TILE_M, N).
-    DB: Partial gradient with respect to B (TILE_M, N).
+    DW: Partial gradient with respect to W (N, TILE_M).
+    DB: Partial gradient with respect to B (N, TILE_M).
     FINAL_DW: Final gradient with respect to W (N,).
     FINAL_DB: Final gradient with respect to B (N,).
     TILE_M: Number of partial gradients to reduce.
@@ -253,18 +253,18 @@ function layer_norm_bwd_dwdb(DW::ct.TileArray{Float32, 2}, DB::ct.TileArray{Floa
                              FINAL_DW::ct.TileArray{Float32, 1}, FINAL_DB::ct.TileArray{Float32, 1},
                              TILE_M::ConstInt, TILE_N::ConstInt)
     bid_n = ct.bid(1)
-    num_tiles = ct.num_tiles(DW, 1, (TILE_M[], TILE_N[]))
+    num_tiles = ct.num_tiles(DW, 2, (TILE_N[], TILE_M[]))
 
-    dw = ct.zeros((TILE_M[], TILE_N[]), Float32)
-    db = ct.zeros((TILE_M[], TILE_N[]), Float32)
+    dw = ct.zeros((TILE_N[], TILE_M[]), Float32)
+    db = ct.zeros((TILE_N[], TILE_M[]), Float32)
     i = Int32(1)
     while i <= num_tiles
-        dw = dw .+ ct.load(DW, (i, bid_n), (TILE_M[], TILE_N[]); padding_mode=ct.PaddingMode.Zero)
-        db = db .+ ct.load(DB, (i, bid_n), (TILE_M[], TILE_N[]); padding_mode=ct.PaddingMode.Zero)
+        dw = dw .+ ct.load(DW, (bid_n, i), (TILE_N[], TILE_M[]); padding_mode=ct.PaddingMode.Zero)
+        db = db .+ ct.load(DB, (bid_n, i), (TILE_N[], TILE_M[]); padding_mode=ct.PaddingMode.Zero)
        i += Int32(1)
     end
 
-    sum_dw = sum(dw; dims=1)
-    sum_db = sum(db; dims=1)
+    sum_dw = sum(dw; dims=2)
+    sum_db = sum(db; dims=2)
     ct.store(FINAL_DW, bid_n, sum_dw)
     ct.store(FINAL_DB, bid_n, sum_db)
@@ -291,8 +291,8 @@ function prepare(; benchmark::Bool=false,
         # Backward inputs/outputs
         DY = 0.1f0 .* CUDA.randn(Float32, M, N),
         DX = CuArray{Float32}(undef, M, N),
-        DW_partial = CuArray{Float32}(undef, GROUP_SIZE_M, N),
-        DB_partial = CuArray{Float32}(undef, GROUP_SIZE_M, N),
+        DW_partial = CuArray{Float32}(undef, N, GROUP_SIZE_M),
+        DB_partial = CuArray{Float32}(undef, N, GROUP_SIZE_M),
         Locks = CuArray{Int}(undef, GROUP_SIZE_M),
         FINAL_DW = CuArray{Float32}(undef, N),
         FINAL_DB = CuArray{Float32}(undef, N),
diff --git a/src/language/operations.jl b/src/language/operations.jl
index 0d49a7e..1f1ec95 100644
--- a/src/language/operations.jl
+++ b/src/language/operations.jl
@@ -65,6 +65,28 @@ Axis is 1-indexed. Equivalent to cld(arr.sizes[axis], shape[axis]).
     Intrinsics.get_index_space_shape(pv, axis - One()) # convert to 0-indexed
 end
 
+# Match a shape tuple to a target rank N by padding trailing 1s or squeezing trailing singletons.
+@generated function _match_shape(::Val{Shape}, ::Val{N}) where {Shape, N}
+    M = length(Shape)
+    M == N && return :($Shape)
+    if M < N
+        padded = (Shape..., ntuple(_ -> 1, N - M)...)
+        return :($padded)
+    end
+    trailing = M - something(findlast(!=(1), Shape), 0)
+    trailing >= M - N || error("cannot squeeze shape $Shape to rank $N: ",
+                               "need to drop $(M-N) trailing singletons but only found $trailing")
+    kept = Shape[1:N]
+    return :($kept)
+end
+
+# Reshape a tile to match target rank N, preserving data layout.
+@inline function _reshape_to_rank(tile::Tile, ::Val{N}) where {N}
+    new_shape = _match_shape(Val(size(tile)), Val(N))
+    reshape(tile, new_shape)
+end
+
+
 """
     load(arr::TileArray, index, shape; order=nothing, padding_mode=PaddingMode.Undetermined, latency=nothing, allow_tma=true) -> Tile
 
@@ -102,66 +124,26 @@ tile = ct.load(arr, (bidx, bidy), (TM, TN); order=(2, 1))
                       padding_mode::Int=PaddingMode.Undetermined,
                       latency::Union{Int, Nothing}=nothing,
                       allow_tma::Bool=true)
+    matched = _match_shape(Val(shape), Val(ndims(arr)))
     tv = Intrinsics.make_tensor_view(arr)
-    pv = Intrinsics.make_partition_view(tv, shape, padding_mode, order)
-    Intrinsics.load_partition_view(pv, latency, allow_tma, promote(index...) .- One())
-end
-
-@inline function load(arr::TileArray, index::Integer, shape::NTuple{<:Any, Int};
-                      order::Union{NTuple{<:Any, Int}, Nothing}=nothing,
-                      padding_mode::Int=PaddingMode.Undetermined,
-                      latency::Union{Int, Nothing}=nothing,
-                      allow_tma::Bool=true)
-    tv = Intrinsics.make_tensor_view(arr)
-    pv = Intrinsics.make_partition_view(tv, shape, padding_mode, order)
-    Intrinsics.load_partition_view(pv, latency, allow_tma, (index - One(),))
-end
-
-# Load with Constant shape tuple
-@inline function load(arr::TileArray, index, shape::Tuple{Vararg{Constant{Int}}};
-                      order::Union{NTuple{<:Any, Int}, Nothing}=nothing,
-                      padding_mode::Int=PaddingMode.Undetermined,
-                      latency::Union{Int, Nothing}=nothing,
-                      allow_tma::Bool=true)
-    shape_val = _extract_shape(shape)
-    tv = Intrinsics.make_tensor_view(arr)
-    pv = Intrinsics.make_partition_view(tv, shape_val, padding_mode, order)
-    Intrinsics.load_partition_view(pv, latency, allow_tma, promote(index...) .- One())
+    pv = Intrinsics.make_partition_view(tv, matched, padding_mode, order)
+    tile = Intrinsics.load_partition_view(pv, latency, allow_tma, promote(index...) .- One())
+    reshape(tile, shape)
 end
 
-# Keyword argument version
-@inline function load(arr::TileArray; index, shape,
-                      order::Union{NTuple{<:Any, Int}, Nothing}=nothing,
-                      padding_mode::Int=PaddingMode.Undetermined,
-                      latency::Union{Int, Nothing}=nothing,
-                      allow_tma::Bool=true)
-    shape_val = _extract_shape(shape)
-    tv = Intrinsics.make_tensor_view(arr)
-    pv = Intrinsics.make_partition_view(tv, shape_val, padding_mode, order)
-    Intrinsics.load_partition_view(pv, latency, allow_tma, promote(index...) .- One())
+# Scalar index → wrap in tuple
+@inline function load(arr::TileArray, index::Integer, shape::NTuple{<:Any, Int}; kwargs...)
+    load(arr, (index,), shape; kwargs...)
 end
 
-# Auto-reshape tile to match target array rank for store.
-@inline function _reshape_for_store(tile::Tile, ::Val{N}) where {N}
-    ndims(tile) <= N && return tile
-    new_shape = _store_shape(Val(size(tile)), Val(N))
-    reshape(tile, new_shape)
+# Constant shape → extract and delegate
+@inline function load(arr::TileArray, index, shape::Tuple{Vararg{Constant{Int}}}; kwargs...)
+    load(arr, index, _extract_shape(shape); kwargs...)
 end
 
-@generated function _store_shape(::Val{Shape}, ::Val{N}) where {Shape, N}
-    M = length(Shape)
-    to_drop = M - N
-    singletons = [i for i in 1:M if Shape[i] == 1]
-    if length(singletons) < to_drop
-        error("cannot squeeze shape $Shape to rank $N: only $(length(singletons)) singleton dims but need to drop $to_drop")
-    end
-    drop_set = Set(singletons[1:to_drop])
-    kept = tuple((Shape[i] for i in 1:M if !(i in drop_set))...)
-    # Partition views require at least 1D
-    if isempty(kept)
-        kept = (1,)
-    end
-    return :($kept)
+# Keyword argument version → extract and delegate
+@inline function load(arr::TileArray; index, shape, kwargs...)
+    load(arr, index, _extract_shape(shape); kwargs...)
 end
 
 """
@@ -183,18 +165,14 @@ Returns the stored tile (enables chaining and helps constant folding).
                 order::Union{NTuple{<:Any, Int}, Nothing}=nothing,
                 latency::Union{Int, Nothing}=nothing,
                 allow_tma::Bool=true) where {T}
-    reshaped = _reshape_for_store(tile, Val(ndims(arr)))
+    reshaped = _reshape_to_rank(tile, Val(ndims(arr)))
     _store_reshaped(arr, reshaped, order, latency, allow_tma, promote(index...) .- One())
     return tile # XXX: enables constant folding; remove when possible (see "constant folding" test)
 end
 
-@inline function store(arr::TileArray{T}, index::Integer, tile::Tile{T};
-                       order::Union{NTuple{<:Any, Int}, Nothing}=nothing,
-                       latency::Union{Int, Nothing}=nothing,
-                       allow_tma::Bool=true) where {T}
-    reshaped = _reshape_for_store(tile, Val(ndims(arr)))
-    _store_reshaped(arr, reshaped, order, latency, allow_tma, (index - One(),))
-    return tile # XXX: enables constant folding; remove when possible (see "constant folding" test)
+# Scalar index → wrap in tuple
+@inline function store(arr::TileArray{T}, index::Integer, tile::Tile{T}; kwargs...) where {T}
+    store(arr, (index,), tile; kwargs...)
 end
 
 @inline function _store_reshaped(arr::TileArray{T}, tile::Tile{T},
@@ -496,8 +474,10 @@ tile = ct.load(arr, (1, 1), (4, 8))   # Shape (4, 8), 32 elements
 reshaped = reshape(tile, (2, 16))     # Shape (2, 16), still 32 elements
 ```
 """
-@inline Base.reshape(tile::Tile{T}, shape::NTuple{<:Any, Int}) where {T} =
+@inline function Base.reshape(tile::Tile{T}, shape::NTuple{<:Any, Int}) where {T}
+    size(tile) === shape && return tile
     Intrinsics.reshape(tile, shape)
+end
 
 """
     permutedims(tile::Tile{T, S}, perm) -> Tile{T, permuted_shape}
diff --git a/test/codegen.jl b/test/codegen.jl
index a27bbf3..58a056d 100644
--- a/test/codegen.jl
+++ b/test/codegen.jl
@@ -135,14 +135,14 @@
             end
         end
 
-        # 1D -> 1D reshape (no permutes needed - optimization)
+        # 1D -> 1D same-shape reshape is a no-op
        @test @filecheck begin
            @check_label "entry"
-            @check_not "permute"  # should NOT have permute for 1D->1D
+            @check_not "permute"
+            @check_not "reshape"
            code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}}) do a
                pid = ct.bid(1)
                tile = ct.load(a, pid, (32,))
-                @check "reshape"
                reshaped = reshape(tile, (32,))
                ct.store(a, pid, reshaped)
                return
@@ -165,6 +165,34 @@
         end
     end
 
+    @testset "dropdims" begin
+        # dropdims on dim 1: (1, 8) -> dropdims(; dims=1) -> (8,)
+        @test @filecheck begin
+            @check_label "entry"
+            code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float32,1,spec1d}}) do a, b
+                pid = ct.bid(1)
+                tile = ct.load(a, (pid, 1), (1, 8))
+                @check "reshape"
+                squeezed = dropdims(tile; dims=1)
+                ct.store(b, pid, squeezed)
+                return
+            end
+        end
+
+        # dropdims on dim 2: (8, 1) -> dropdims(; dims=2) -> (8,)
+        @test @filecheck begin
+            @check_label "entry"
+            code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float32,1,spec1d}}) do a, b
+                pid = ct.bid(1)
+                tile = ct.load(a, (1, pid), (8, 1))
+                @check "reshape"
+                squeezed = dropdims(tile; dims=2)
+                ct.store(b, pid, squeezed)
+                return
+            end
+        end
+    end
+
     @testset "permutedims" begin
         # 2D permutedims with explicit perm (same as transpose)
         @test @filecheck begin
@@ -1351,6 +1379,21 @@
         end
     end
 
+    @testset "rank mismatch load/store" begin
+        # 1D shape on 2D array: should pad shape to (16, 1) internally
+        @test @filecheck begin
+            @check_label "entry"
+            code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float32,2,spec2d}}) do a, b
+                pid = ct.bid(1)
+                @check "load_view_tko"
+                tile = ct.load(a, (pid, 1), (16,))
+                @check "store_view_tko"
+                ct.store(b, (pid, 1), tile)
+                return
+            end
+        end
+    end
+
     @testset "num_tiles helper" begin
         spec = ct.ArraySpec{2}(16, true)
         @test @filecheck begin
@@ -1741,7 +1784,7 @@ end
     TILE_N = 1024
 
     # Use ArraySpec with shape_div_by to match real CuArray behavior
-    spec2d = ct.ArraySpec{2}(128, true, (0, 4), (32, 32))
+    spec2d = ct.ArraySpec{2}(128, true, (4, 0), (32, 32))
    spec1d = ct.ArraySpec{1}(128, true, (0,), (32,))
 
    @test @filecheck begin
@@ -1755,19 +1798,19 @@ end
                         ct.TileArray{Float32, 1, spec1d}, ct.TileArray{Float32, 1, spec1d},
                         ct.Constant{Int, TILE_M}, ct.Constant{Int, TILE_N}}) do DW, DB, FINAL_DW, FINAL_DB, _TILE_M, _TILE_N
             bid_n = ct.bid(1)
-            num_tiles = ct.num_tiles(DW, 1, (_TILE_M[], _TILE_N[]))
+            num_tiles = ct.num_tiles(DW, 2, (_TILE_N[], _TILE_M[]))
 
-            dw = ct.zeros((_TILE_M[], _TILE_N[]), Float32)
-            db = ct.zeros((_TILE_M[], _TILE_N[]), Float32)
+            dw = ct.zeros((_TILE_N[], _TILE_M[]), Float32)
+            db = ct.zeros((_TILE_N[], _TILE_M[]), Float32)
 
             i = Int32(1)
             while i <= num_tiles
-                dw = dw .+ ct.load(DW, (i, bid_n), (_TILE_M[], _TILE_N[]); padding_mode=ct.PaddingMode.Zero)
-                db = db .+ ct.load(DB, (i, bid_n), (_TILE_M[], _TILE_N[]); padding_mode=ct.PaddingMode.Zero)
+                dw = dw .+ ct.load(DW, (bid_n, i), (_TILE_N[], _TILE_M[]); padding_mode=ct.PaddingMode.Zero)
+                db = db .+ ct.load(DB, (bid_n, i), (_TILE_N[], _TILE_M[]); padding_mode=ct.PaddingMode.Zero)
                i += Int32(1)
             end
 
-            sum_dw = sum(dw; dims=1)
-            sum_db = sum(db; dims=1)
+            sum_dw = sum(dw; dims=2)
+            sum_db = sum(db; dims=2)
             ct.store(FINAL_DW, bid_n, sum_dw)
             ct.store(FINAL_DB, bid_n, sum_db)
diff --git a/test/execution.jl b/test/execution.jl
index 72d33ad..7c0d271 100644
--- a/test/execution.jl
+++ b/test/execution.jl
@@ -114,6 +114,43 @@ end
     @test Array(c) ≈ Array(a) + Array(b)
 end
 
+@testset "rank mismatch load/store" begin
+    @testset "1D shape on 2D array" begin
+        function copy_1d_2d(src::ct.TileArray{Float32,2}, dst::ct.TileArray{Float32,2})
+            bid = ct.bid(1)
+            tile = ct.load(src, (bid, 1), (16,))
+            ct.store(dst, (bid, 1), tile)
+            return
+        end
+
+        m = 64
+        src = CUDA.rand(Float32, m, 1)
+        dst = CUDA.zeros(Float32, m, 1)
+
+        ct.launch(copy_1d_2d, cld(m, 16), src, dst)
+
+        @test Array(dst) ≈ Array(src)
+    end
+
+    @testset "2D shape on 4D array" begin
+        function copy_2d_4d(src::ct.TileArray{Float32,4}, dst::ct.TileArray{Float32,4})
+            bidx = ct.bid(1)
+            bidy = ct.bid(2)
+            tile = ct.load(src, (bidx, bidy, 1, 1), (4, 4))
+            ct.store(dst, (bidx, bidy, 1, 1), tile)
+            return
+        end
+
+        d1, d2 = 16, 16
+        src = CUDA.rand(Float32, d1, d2, 1, 1)
+        dst = CUDA.zeros(Float32, d1, d2, 1, 1)
+
+        ct.launch(copy_2d_4d, (cld(d1, 4), cld(d2, 4)), src, dst)
+
+        @test Array(dst) ≈ Array(src)
+    end
+end
+
 @testset "transpose" begin
     function transpose_kernel(x::ct.TileArray{Float32,2}, y::ct.TileArray{Float32,2})
         bidx = ct.bid(1)
@@ -1163,6 +1200,31 @@ end
 
 end
 
+@testset "dropdims" begin
+    # Row-sum pattern: load a (1, N) row tile, reduce it to (1, 1), dropdims the
+    # trailing singleton to get a (1,) tile, and store the per-row sum into a 1D output.
+    function dropdims_kernel(a::ct.TileArray{Float32,2}, b::ct.TileArray{Float32,1})
+        pid = ct.bid(1)
+        tile = ct.load(a, (pid, 1), (1, 128))   # (1, 128)
+        row_sum = sum(tile; dims=2)             # (1, 1)
+        row_sum_1d = dropdims(row_sum; dims=2)  # (1,)
+        ct.store(b, pid, row_sum_1d)
+        return
+    end
+
+    m, n = 64, 128
+    a = CUDA.rand(Float32, m, n)
+    b = CUDA.zeros(Float32, m)
+
+    ct.launch(dropdims_kernel, m, a, b)
+
+    a_cpu = Array(a)
+    b_cpu = Array(b)
+    for i in 1:m
+        @test b_cpu[i] ≈ sum(a_cpu[i, :]) rtol=1e-3
+    end
+end
+
 end
 
 @testset "scan" begin