From 3ba4b3fd36adc164c4542df87f93426b4f0ae1ce Mon Sep 17 00:00:00 2001
From: AntonOresten
Date: Fri, 6 Feb 2026 12:21:51 +0100
Subject: [PATCH 1/7] implement more cases of automatic resqueezing

---
 README.md                  |  7 ++-
 src/language/operations.jl | 88 ++++++++++++++++++++++++--------------
 test/codegen.jl            | 15 +++++++
 test/execution.jl          | 39 ++++++++++++++++-
 4 files changed, 113 insertions(+), 36 deletions(-)

diff --git a/README.md b/README.md
index 85dd2ef..ee5465e 100644
--- a/README.md
+++ b/README.md
@@ -368,9 +368,12 @@ result = sum(tile; dims=2) # (M, N) → (M, 1)
 result = dropdims(sum(tile; dims=2); dims=2) # (M, N) → (M,)
 ```
 
-### Store reshaping
+### Automatic rank matching
 
-`ct.store` automatically reshapes the tile to match the target array's rank by dropping singleton dimensions (e.g., storing a `(1, N)` tile into a 1D array reshapes it to `(N,)`). Scalar `()` tiles are reshaped to `(1,)`.
+`ct.load` and `ct.store` automatically match the tile rank to that of the target:
+
+- **Lower rank**: trailing `1`s are appended. Loading `(M, N)` from a 4D array internally uses `(M, N, 1, 1)`. Storing a scalar tile into a 2D array pads to `(1, 1)`.
+- **Higher rank**: trailing singletons are squeezed. Storing `(M, 1)` into a 1D array reshapes to `(M,)`.
 
 ### Broadcasting shape alignment
 
diff --git a/src/language/operations.jl b/src/language/operations.jl
index 0d49a7e..624140c 100644
--- a/src/language/operations.jl
+++ b/src/language/operations.jl
@@ -65,6 +65,43 @@ Axis is 1-indexed. Equivalent to cld(arr.sizes[axis], shape[axis]).
     Intrinsics.get_index_space_shape(pv, axis - One()) # convert to 0-indexed
 end
 
+# Match a shape tuple to a target rank N by padding trailing 1s or squeezing trailing singletons.
+@generated function _match_shape(::Val{Shape}, ::Val{N}) where {Shape, N}
+    M = length(Shape)
+    if M == N
+        return :($Shape)
+    elseif M < N
+        # Pad with trailing 1s
+        padded = (Shape..., ntuple(_ -> 1, N - M)...)
+        return :($padded)
+    else
+        # Squeeze: drop trailing singleton dims first
+        to_drop = M - N
+        singletons = [i for i in M:-1:1 if Shape[i] == 1]
+        if length(singletons) < to_drop
+            error("cannot squeeze shape $Shape to rank $N: only $(length(singletons)) singleton dims but need to drop $to_drop")
+        end
+        drop_set = Set(singletons[1:to_drop])
+        kept = tuple((Shape[i] for i in 1:M if !(i in drop_set))...)
+        # Partition views require at least 1D
+        if isempty(kept)
+            kept = (1,)
+        end
+        return :($kept)
+    end
+end
+
+# Reshape a tile to match target rank N, preserving data layout.
+@inline function _reshape_to_rank(tile::Tile, ::Val{N}) where {N}
+    ndims(tile) == N && return tile
+    new_shape = _match_shape(Val(size(tile)), Val(N))
+    reshape(tile, new_shape)
+end
+
+# Reshape a tile back to a requested shape (no-op when already matching).
+@inline _reshape_to_shape(tile::Tile, shape) =
+    size(tile) === shape ? tile : reshape(tile, shape)
+
 """
     load(arr::TileArray, index, shape; order=nothing, padding_mode=PaddingMode.Undetermined, latency=nothing, allow_tma=true) -> Tile
 
@@ -102,9 +139,11 @@ tile = ct.load(arr, (bidx, bidy), (TM, TN); order=(2, 1))
                       padding_mode::Int=PaddingMode.Undetermined,
                       latency::Union{Int, Nothing}=nothing,
                       allow_tma::Bool=true)
+    matched = _match_shape(Val(shape), Val(ndims(arr)))
     tv = Intrinsics.make_tensor_view(arr)
-    pv = Intrinsics.make_partition_view(tv, shape, padding_mode, order)
-    Intrinsics.load_partition_view(pv, latency, allow_tma, promote(index...) .- One())
+    pv = Intrinsics.make_partition_view(tv, matched, padding_mode, order)
+    tile = Intrinsics.load_partition_view(pv, latency, allow_tma, promote(index...) .- One())
+    _reshape_to_shape(tile, shape)
 end
 
 @inline function load(arr::TileArray, index::Integer, shape::NTuple{<:Any, Int};
@@ -112,9 +151,11 @@ end
                       padding_mode::Int=PaddingMode.Undetermined,
                       latency::Union{Int, Nothing}=nothing,
                       allow_tma::Bool=true)
+    matched = _match_shape(Val(shape), Val(ndims(arr)))
     tv = Intrinsics.make_tensor_view(arr)
-    pv = Intrinsics.make_partition_view(tv, shape, padding_mode, order)
-    Intrinsics.load_partition_view(pv, latency, allow_tma, (index - One(),))
+    pv = Intrinsics.make_partition_view(tv, matched, padding_mode, order)
+    tile = Intrinsics.load_partition_view(pv, latency, allow_tma, (index - One(),))
+    _reshape_to_shape(tile, shape)
 end
 
 # Load with Constant shape tuple
@@ -124,9 +165,11 @@ end
                       latency::Union{Int, Nothing}=nothing,
                       allow_tma::Bool=true)
     shape_val = _extract_shape(shape)
+    matched = _match_shape(Val(shape_val), Val(ndims(arr)))
     tv = Intrinsics.make_tensor_view(arr)
-    pv = Intrinsics.make_partition_view(tv, shape_val, padding_mode, order)
-    Intrinsics.load_partition_view(pv, latency, allow_tma, promote(index...) .- One())
+    pv = Intrinsics.make_partition_view(tv, matched, padding_mode, order)
+    tile = Intrinsics.load_partition_view(pv, latency, allow_tma, promote(index...) .- One())
+    _reshape_to_shape(tile, shape_val)
 end
 
 # Keyword argument version
@@ -136,32 +179,11 @@ end
                       latency::Union{Int, Nothing}=nothing,
                       allow_tma::Bool=true)
     shape_val = _extract_shape(shape)
+    matched = _match_shape(Val(shape_val), Val(ndims(arr)))
     tv = Intrinsics.make_tensor_view(arr)
-    pv = Intrinsics.make_partition_view(tv, shape_val, padding_mode, order)
-    Intrinsics.load_partition_view(pv, latency, allow_tma, promote(index...) .- One())
-end
-
-# Auto-reshape tile to match target array rank for store.
-@inline function _reshape_for_store(tile::Tile, ::Val{N}) where {N}
-    ndims(tile) <= N && return tile
-    new_shape = _store_shape(Val(size(tile)), Val(N))
-    reshape(tile, new_shape)
-end
-
-@generated function _store_shape(::Val{Shape}, ::Val{N}) where {Shape, N}
-    M = length(Shape)
-    to_drop = M - N
-    singletons = [i for i in 1:M if Shape[i] == 1]
-    if length(singletons) < to_drop
-        error("cannot squeeze shape $Shape to rank $N: only $(length(singletons)) singleton dims but need to drop $to_drop")
-    end
-    drop_set = Set(singletons[1:to_drop])
-    kept = tuple((Shape[i] for i in 1:M if !(i in drop_set))...)
-    # Partition views require at least 1D
-    if isempty(kept)
-        kept = (1,)
-    end
-    return :($kept)
+    pv = Intrinsics.make_partition_view(tv, matched, padding_mode, order)
+    tile = Intrinsics.load_partition_view(pv, latency, allow_tma, promote(index...) .- One())
+    _reshape_to_shape(tile, shape_val)
 end
 
 """
@@ -183,7 +205,7 @@ Returns the stored tile (enables chaining and helps constant folding).
                        order::Union{NTuple{<:Any, Int}, Nothing}=nothing,
                        latency::Union{Int, Nothing}=nothing,
                        allow_tma::Bool=true) where {T}
-    reshaped = _reshape_for_store(tile, Val(ndims(arr)))
+    reshaped = _reshape_to_rank(tile, Val(ndims(arr)))
     _store_reshaped(arr, reshaped, order, latency, allow_tma, promote(index...) .- One())
     return tile # XXX: enables constant folding; remove when possible (see "constant folding" test)
 end
 
@@ -192,7 +214,7 @@ end
                        order::Union{NTuple{<:Any, Int}, Nothing}=nothing,
                        latency::Union{Int, Nothing}=nothing,
                        allow_tma::Bool=true) where {T}
-    reshaped = _reshape_for_store(tile, Val(ndims(arr)))
+    reshaped = _reshape_to_rank(tile, Val(ndims(arr)))
     _store_reshaped(arr, reshaped, order, latency, allow_tma, (index - One(),))
     return tile # XXX: enables constant folding; remove when possible (see "constant folding" test)
 end
diff --git a/test/codegen.jl b/test/codegen.jl
index a27bbf3..239a78c 100644
--- a/test/codegen.jl
+++ b/test/codegen.jl
@@ -1351,6 +1351,21 @@ end
         end
     end
 
+    @testset "rank mismatch load/store" begin
+        # 1D shape on 2D array: should pad shape to (16, 1) internally
+        @test @filecheck begin
+            @check_label "entry"
+            code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float32,2,spec2d}}) do a, b
+                pid = ct.bid(1)
+                @check "load_view_tko"
+                tile = ct.load(a, (pid, 1), (16,))
+                @check "store_view_tko"
+                ct.store(b, (pid, 1), tile)
+                return
+            end
+        end
+    end
+
     @testset "num_tiles helper" begin
         spec = ct.ArraySpec{2}(16, true)
         @test @filecheck begin
diff --git a/test/execution.jl b/test/execution.jl
index 72d33ad..644a6bd 100644
--- a/test/execution.jl
+++ b/test/execution.jl
@@ -114,6 +114,43 @@ end
     @test Array(c) ≈ Array(a) + Array(b)
 end
 
+@testset "rank mismatch load/store" begin
+    @testset "1D shape on 2D array" begin
+        function copy_1d_2d(src::ct.TileArray{Float32,2}, dst::ct.TileArray{Float32,2})
+            bid = ct.bid(1)
+            tile = ct.load(src, (bid, 1), (16,))
+            ct.store(dst, (bid, 1), tile)
+            return
+        end
+
+        m = 64
+        src = CUDA.rand(Float32, m, 1)
+        dst = CUDA.zeros(Float32, m, 1)
+
+        ct.launch(copy_1d_2d, cld(m, 16), src, dst)
+
+        @test Array(dst) ≈ Array(src)
+    end
+
+    @testset "2D shape on 4D array" begin
+        function copy_2d_4d(src::ct.TileArray{Float32,4}, dst::ct.TileArray{Float32,4})
+            bidx = ct.bid(1)
+            bidy = ct.bid(2)
+            tile = ct.load(src, (bidx, bidy, 1, 1), (4, 4))
+            ct.store(dst, (bidx, bidy, 1, 1), tile)
+            return
+        end
+
+        d1, d2 = 16, 16
+        src = CUDA.rand(Float32, d1, d2, 1, 1)
+        dst = CUDA.zeros(Float32, d1, d2, 1, 1)
+
+        ct.launch(copy_2d_4d, (cld(d1, 4), cld(d2, 4)), src, dst)
+
+        @test Array(dst) ≈ Array(src)
+    end
+end
+
 @testset "transpose" begin
     function transpose_kernel(x::ct.TileArray{Float32,2}, y::ct.TileArray{Float32,2})
         bidx = ct.bid(1)
@@ -2098,7 +2135,7 @@ end
         cpu_result = cpu_reduce(a_reshaped, cpu_op)
 
         if elType <: AbstractFloat
-            @test b_cpu ≈ cpu_result rtol=1e-3
+            @test b_cpu ≈ cpu_result rtol=2e-3
         else
             @test b_cpu == cpu_result
         end

From 101ad9fc08184a166a2cfc3e8f76c54f69866e4d Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Fri, 6 Feb 2026 15:45:53 +0100
Subject: [PATCH 2/7] Restrict shape matching to consecutive trailing
 singletons.

---
 README.md                  |  3 ++-
 examples/layernorm.jl      |  4 ++--
 src/language/operations.jl | 15 +++++++++------
 test/codegen.jl            |  6 +++---
 4 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/README.md b/README.md
index ee5465e..261d095 100644
--- a/README.md
+++ b/README.md
@@ -373,7 +373,8 @@ result = dropdims(sum(tile; dims=2); dims=2) # (M, N) → (M,)
 `ct.load` and `ct.store` automatically match the tile rank to that of the target:
 
 - **Lower rank**: trailing `1`s are appended. Loading `(M, N)` from a 4D array internally uses `(M, N, 1, 1)`. Storing a scalar tile into a 2D array pads to `(1, 1)`.
-- **Higher rank**: trailing singletons are squeezed. Storing `(M, 1)` into a 1D array reshapes to `(M,)`.
+- **Higher rank**: trailing `1`s are stripped. Storing `(M, 1)` into a 1D array reshapes to `(M,)`.
+  Non-trailing singletons (e.g., from `sum(tile; dims=1)`) require explicit `dropdims`.
 
 ### Broadcasting shape alignment
 
diff --git a/examples/layernorm.jl b/examples/layernorm.jl
index 30b0d51..adf275c 100644
--- a/examples/layernorm.jl
+++ b/examples/layernorm.jl
@@ -263,8 +263,8 @@ function layer_norm_bwd_dwdb(DW::ct.TileArray{Float32, 2}, DB::ct.TileArray{Floa
         db = db .+ ct.load(DB, (i, bid_n), (TILE_M[], TILE_N[]); padding_mode=ct.PaddingMode.Zero)
         i += Int32(1)
     end
-    sum_dw = sum(dw; dims=1)
-    sum_db = sum(db; dims=1)
+    sum_dw = dropdims(sum(dw; dims=1); dims=1)
+    sum_db = dropdims(sum(db; dims=1); dims=1)
 
     ct.store(FINAL_DW, bid_n, sum_dw)
     ct.store(FINAL_DB, bid_n, sum_db)
diff --git a/src/language/operations.jl b/src/language/operations.jl
index 624140c..ea6e258 100644
--- a/src/language/operations.jl
+++ b/src/language/operations.jl
@@ -75,14 +75,17 @@ end
         padded = (Shape..., ntuple(_ -> 1, N - M)...)
         return :($padded)
     else
-        # Squeeze: drop trailing singleton dims first
+        # Squeeze: only drop consecutive trailing singletons
         to_drop = M - N
-        singletons = [i for i in M:-1:1 if Shape[i] == 1]
-        if length(singletons) < to_drop
-            error("cannot squeeze shape $Shape to rank $N: only $(length(singletons)) singleton dims but need to drop $to_drop")
+        trailing_ones = 0
+        for i in M:-1:1
+            Shape[i] == 1 || break
+            trailing_ones += 1
         end
-        drop_set = Set(singletons[1:to_drop])
-        kept = tuple((Shape[i] for i in 1:M if !(i in drop_set))...)
+        if trailing_ones < to_drop
+            error("cannot squeeze shape $Shape to rank $N: need to drop $to_drop trailing singletons but only found $trailing_ones")
+        end
+        kept = Shape[1:N]
         # Partition views require at least 1D
         if isempty(kept)
             kept = (1,)
diff --git a/test/codegen.jl b/test/codegen.jl
index 239a78c..5f2ebdf 100644
--- a/test/codegen.jl
+++ b/test/codegen.jl
@@ -1781,8 +1781,8 @@ end
             i += Int32(1)
         end
 
-        sum_dw = sum(dw; dims=1)
-        sum_db = sum(db; dims=1)
+        sum_dw = dropdims(sum(dw; dims=1); dims=1)
+        sum_db = dropdims(sum(db; dims=1); dims=1)
 
         ct.store(FINAL_DW, bid_n, sum_dw)
         ct.store(FINAL_DB, bid_n, sum_db)
@@ -1835,7 +1835,7 @@ end
         end
         @check "reduce"
         @check "store_view_tko"
-        ct.store(out, bid, sum(acc2; dims=1))
+        ct.store(out, bid, dropdims(sum(acc2; dims=1); dims=1))
         return
     end

From ba0b86a278b57990cd2cea9c50643f65c3685310 Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Sat, 7 Feb 2026 14:56:32 +0100
Subject: [PATCH 3/7] Transpose DW/DB layout to avoid dropdims on leading
 singletons.

Switch layernorm backward's DW/DB partial buffers from (GROUP_SIZE_M, N)
to (N, GROUP_SIZE_M) so that sum(; dims=2) produces trailing singletons
that auto-squeeze, removing the need for explicit dropdims calls.

Co-Authored-By: Claude Opus 4.6
---
 examples/layernorm.jl | 38 +++++++++++++++++++-------------------
 test/codegen.jl       | 18 +++++++++---------
 2 files changed, 28 insertions(+), 28 deletions(-)

diff --git a/examples/layernorm.jl b/examples/layernorm.jl
index adf275c..0950999 100644
--- a/examples/layernorm.jl
+++ b/examples/layernorm.jl
@@ -166,8 +166,8 @@ Accumulates partial gradients using atomic locks.
 Args:
     DX: Output gradient with respect to X (M, N).
     DY: Input gradient with respect to Y (M, N).
-    DW: Partial gradient with respect to W (GROUP_SIZE_M, N).
-    DB: Partial gradient with respect to B (GROUP_SIZE_M, N).
+    DW: Partial gradient with respect to W (N, GROUP_SIZE_M).
+    DB: Partial gradient with respect to B (N, GROUP_SIZE_M).
     X: Input tensor (M, N).
     W: Weight tensor (N,).
     Mean: Mean tensor (M,).
@@ -211,8 +211,8 @@ function layer_norm_bwd_dx_partial_dwdb(DX::ct.TileArray{Float32, 2}, DY::ct.Til
         tdx = (wdy .- (xhat .* c1 .+ c2)) .* rstd
         ct.store(DX, (bid_m, j), tdx)
 
-        partial_dw = tdy .* xhat
-        partial_db = tdy
+        partial_dw = reshape(tdy .* xhat, (TILE_N[], 1))
+        partial_db = reshape(tdy, (TILE_N[], 1))
 
         # Acquire spinlock
         while ct.atomic_cas(Locks, group_bid_m, 0, 1;
@@ -221,10 +221,10 @@ function layer_norm_bwd_dx_partial_dwdb(DX::ct.TileArray{Float32, 2}, DY::ct.Til
         end
 
         # Critical section: accumulate partial gradients
-        partial_dw = partial_dw .+ ct.load(DW, (group_bid_m, j), (1, TILE_N[]); padding_mode=ct.PaddingMode.Zero)
-        partial_db = partial_db .+ ct.load(DB, (group_bid_m, j), (1, TILE_N[]); padding_mode=ct.PaddingMode.Zero)
-        ct.store(DW, (group_bid_m, j), partial_dw)
-        ct.store(DB, (group_bid_m, j), partial_db)
+        partial_dw = partial_dw .+ ct.load(DW, (j, group_bid_m), (TILE_N[], 1); padding_mode=ct.PaddingMode.Zero)
+        partial_db = partial_db .+ ct.load(DB, (j, group_bid_m), (TILE_N[], 1); padding_mode=ct.PaddingMode.Zero)
+        ct.store(DW, (j, group_bid_m), partial_dw)
+        ct.store(DB, (j, group_bid_m), partial_db)
 
         # Release spinlock
        ct.atomic_xchg(Locks, group_bid_m, 0;
@@ -242,8 +242,8 @@ end
 Backward pass part 2: Final reduction for dW and dB.
 
 Args:
-    DW: Partial gradient with respect to W (TILE_M, N).
-    DB: Partial gradient with respect to B (TILE_M, N).
+    DW: Partial gradient with respect to W (N, TILE_M).
+    DB: Partial gradient with respect to B (N, TILE_M).
     FINAL_DW: Final gradient with respect to W (N,).
     FINAL_DB: Final gradient with respect to B (N,).
     TILE_M: Number of partial gradients to reduce.
@@ -253,18 +253,18 @@ function layer_norm_bwd_dwdb(DW::ct.TileArray{Float32, 2}, DB::ct.TileArray{Floa
                              FINAL_DW::ct.TileArray{Float32, 1}, FINAL_DB::ct.TileArray{Float32, 1},
                              TILE_M::ConstInt, TILE_N::ConstInt)
     bid_n = ct.bid(1)
-    num_tiles = ct.num_tiles(DW, 1, (TILE_M[], TILE_N[]))
+    num_tiles = ct.num_tiles(DW, 2, (TILE_N[], TILE_M[]))
 
-    dw = ct.zeros((TILE_M[], TILE_N[]), Float32)
-    db = ct.zeros((TILE_M[], TILE_N[]), Float32)
+    dw = ct.zeros((TILE_N[], TILE_M[]), Float32)
+    db = ct.zeros((TILE_N[], TILE_M[]), Float32)
     i = Int32(1)
     while i <= num_tiles
-        dw = dw .+ ct.load(DW, (i, bid_n), (TILE_M[], TILE_N[]); padding_mode=ct.PaddingMode.Zero)
-        db = db .+ ct.load(DB, (i, bid_n), (TILE_M[], TILE_N[]); padding_mode=ct.PaddingMode.Zero)
+        dw = dw .+ ct.load(DW, (bid_n, i), (TILE_N[], TILE_M[]); padding_mode=ct.PaddingMode.Zero)
+        db = db .+ ct.load(DB, (bid_n, i), (TILE_N[], TILE_M[]); padding_mode=ct.PaddingMode.Zero)
         i += Int32(1)
     end
-    sum_dw = dropdims(sum(dw; dims=1); dims=1)
-    sum_db = dropdims(sum(db; dims=1); dims=1)
+    sum_dw = sum(dw; dims=2)
+    sum_db = sum(db; dims=2)
 
     ct.store(FINAL_DW, bid_n, sum_dw)
     ct.store(FINAL_DB, bid_n, sum_db)
@@ -291,8 +291,8 @@ function prepare(; benchmark::Bool=false,
         # Backward inputs/outputs
         DY = 0.1f0 .* CUDA.randn(Float32, M, N),
         DX = CuArray{Float32}(undef, M, N),
-        DW_partial = CuArray{Float32}(undef, GROUP_SIZE_M, N),
-        DB_partial = CuArray{Float32}(undef, GROUP_SIZE_M, N),
+        DW_partial = CuArray{Float32}(undef, N, GROUP_SIZE_M),
+        DB_partial = CuArray{Float32}(undef, N, GROUP_SIZE_M),
         Locks = CuArray{Int}(undef, GROUP_SIZE_M),
         FINAL_DW = CuArray{Float32}(undef, N),
         FINAL_DB = CuArray{Float32}(undef, N),
diff --git a/test/codegen.jl b/test/codegen.jl
index 5f2ebdf..2f43042 100644
--- a/test/codegen.jl
+++ b/test/codegen.jl
@@ -1756,7 +1756,7 @@ end
     TILE_N = 1024
 
     # Use ArraySpec with shape_div_by to match real CuArray behavior
-    spec2d = ct.ArraySpec{2}(128, true, (0, 4), (32, 32))
+    spec2d = ct.ArraySpec{2}(128, true, (4, 0), (32, 32))
     spec1d = ct.ArraySpec{1}(128, true, (0,), (32,))
 
     @test @filecheck begin
@@ -1770,19 +1770,19 @@ end
                    ct.TileArray{Float32, 1, spec1d}, ct.TileArray{Float32, 1, spec1d},
                    ct.Constant{Int, TILE_M}, ct.Constant{Int, TILE_N}}) do DW, DB, FINAL_DW, FINAL_DB, _TILE_M, _TILE_N
             bid_n = ct.bid(1)
-            num_tiles = ct.num_tiles(DW, 1, (_TILE_M[], _TILE_N[]))
+            num_tiles = ct.num_tiles(DW, 2, (_TILE_N[], _TILE_M[]))
 
-            dw = ct.zeros((_TILE_M[], _TILE_N[]), Float32)
-            db = ct.zeros((_TILE_M[], _TILE_N[]), Float32)
+            dw = ct.zeros((_TILE_N[], _TILE_M[]), Float32)
+            db = ct.zeros((_TILE_N[], _TILE_M[]), Float32)
             i = Int32(1)
             while i <= num_tiles
-                dw = dw .+ ct.load(DW, (i, bid_n), (_TILE_M[], _TILE_N[]); padding_mode=ct.PaddingMode.Zero)
-                db = db .+ ct.load(DB, (i, bid_n), (_TILE_M[], _TILE_N[]); padding_mode=ct.PaddingMode.Zero)
+                dw = dw .+ ct.load(DW, (bid_n, i), (_TILE_N[], _TILE_M[]); padding_mode=ct.PaddingMode.Zero)
+                db = db .+ ct.load(DB, (bid_n, i), (_TILE_N[], _TILE_M[]); padding_mode=ct.PaddingMode.Zero)
                 i += Int32(1)
             end
-            sum_dw = dropdims(sum(dw; dims=1); dims=1)
-            sum_db = dropdims(sum(db; dims=1); dims=1)
+            sum_dw = sum(dw; dims=2)
+            sum_db = sum(db; dims=2)
 
             ct.store(FINAL_DW, bid_n, sum_dw)
             ct.store(FINAL_DB, bid_n, sum_db)
@@ -1835,7 +1835,7 @@ end
         end
         @check "reduce"
         @check "store_view_tko"
-        ct.store(out, bid, dropdims(sum(acc2; dims=1); dims=1))
+        ct.store(out, bid, sum(acc2; dims=1))
         return
     end

From 3491b65159a6ad2b8dc8cdfbb1a118161852fd1a Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Sat, 7 Feb 2026 15:02:23 +0100
Subject: [PATCH 4/7] Add dropdims codegen and execution tests.

Codegen tests verify dropdims emits a reshape op for both dim=1 and
dim=2 on tiles with singleton dimensions. Execution test verifies
correctness of sum + dropdims pattern on GPU.
Co-Authored-By: Claude Opus 4.6
---
 test/codegen.jl   | 28 ++++++++++++++++++++++++++++
 test/execution.jl | 25 +++++++++++++++++++++++++
 2 files changed, 53 insertions(+)

diff --git a/test/codegen.jl b/test/codegen.jl
index 2f43042..dc1887d 100644
--- a/test/codegen.jl
+++ b/test/codegen.jl
@@ -165,6 +165,34 @@ end
         end
     end
 
+    @testset "dropdims" begin
+        # dropdims on dim 1: (1, 8) -> dropdims(; dims=1) -> (8,)
+        @test @filecheck begin
+            @check_label "entry"
+            code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float32,1,spec1d}}) do a, b
+                pid = ct.bid(1)
+                tile = ct.load(a, (pid, 1), (1, 8))
+                @check "reshape"
+                squeezed = dropdims(tile; dims=1)
+                ct.store(b, pid, squeezed)
+                return
+            end
+        end
+
+        # dropdims on dim 2: (8, 1) -> dropdims(; dims=2) -> (8,)
+        @test @filecheck begin
+            @check_label "entry"
+            code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float32,1,spec1d}}) do a, b
+                pid = ct.bid(1)
+                tile = ct.load(a, (1, pid), (8, 1))
+                @check "reshape"
+                squeezed = dropdims(tile; dims=2)
+                ct.store(b, pid, squeezed)
+                return
+            end
+        end
+    end
+
     @testset "permutedims" begin
         # 2D permutedims with explicit perm (same as transpose)
         @test @filecheck begin
diff --git a/test/execution.jl b/test/execution.jl
index 644a6bd..7993e8c 100644
--- a/test/execution.jl
+++ b/test/execution.jl
@@ -1200,6 +1200,31 @@ end
     end
 end
 
+@testset "dropdims" begin
+    # Row-sum pattern: reduce a (1, 128) row tile to a (1, 1) scalar, then
+    # dropdims the trailing singleton to (1,) and store into the 1D output.
+    function dropdims_kernel(a::ct.TileArray{Float32,2}, b::ct.TileArray{Float32,1})
+        pid = ct.bid(1)
+        tile = ct.load(a, (pid, 1), (1, 128)) # (1, 128)
+        row_sum = sum(tile; dims=2)           # (1, 1)
+        row_sum_1d = dropdims(row_sum; dims=2) # (1,)
+        ct.store(b, pid, row_sum_1d)
+        return
+    end
+
+    m, n = 64, 128
+    a = CUDA.rand(Float32, m, n)
+    b = CUDA.zeros(Float32, m)
+
+    ct.launch(dropdims_kernel, m, a, b)
+
+    a_cpu = Array(a)
+    b_cpu = Array(b)
+    for i in 1:m
+        @test b_cpu[i] ≈ sum(a_cpu[i, :]) rtol=1e-3
+    end
+end
+
 end
 
 @testset "scan" begin

From 2e59e02f9c2fa1cd188e3c24adec236e54ebeb71 Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Sat, 7 Feb 2026 22:14:32 +0100
Subject: [PATCH 5/7] Simplify implementation.

---
 src/language/operations.jl | 81 +++++++++-----------------------------
 1 file changed, 19 insertions(+), 62 deletions(-)

diff --git a/src/language/operations.jl b/src/language/operations.jl
index ea6e258..ad75076 100644
--- a/src/language/operations.jl
+++ b/src/language/operations.jl
@@ -68,30 +68,16 @@ end
 # Match a shape tuple to a target rank N by padding trailing 1s or squeezing trailing singletons.
 @generated function _match_shape(::Val{Shape}, ::Val{N}) where {Shape, N}
     M = length(Shape)
-    if M == N
-        return :($Shape)
-    elseif M < N
-        # Pad with trailing 1s
+    M == N && return :($Shape)
+    if M < N
         padded = (Shape..., ntuple(_ -> 1, N - M)...)
         return :($padded)
-    else
-        # Squeeze: only drop consecutive trailing singletons
-        to_drop = M - N
-        trailing_ones = 0
-        for i in M:-1:1
-            Shape[i] == 1 || break
-            trailing_ones += 1
-        end
-        if trailing_ones < to_drop
-            error("cannot squeeze shape $Shape to rank $N: need to drop $to_drop trailing singletons but only found $trailing_ones")
-        end
-        kept = Shape[1:N]
-        # Partition views require at least 1D
-        if isempty(kept)
-            kept = (1,)
-        end
-        return :($kept)
     end
+    trailing = M - something(findlast(!=(1), Shape), 0)
+    trailing >= M - N || error("cannot squeeze shape $Shape to rank $N: ",
+                               "need to drop $(M-N) trailing singletons but only found $trailing")
+    kept = Shape[1:N]
+    return :($kept)
 end
 
-@inline function load(arr::TileArray, index::Integer, shape::NTuple{<:Any, Int};
-                      order::Union{NTuple{<:Any, Int}, Nothing}=nothing,
-                      padding_mode::Int=PaddingMode.Undetermined,
-                      latency::Union{Int, Nothing}=nothing,
-                      allow_tma::Bool=true)
-    matched = _match_shape(Val(shape), Val(ndims(arr)))
-    tv = Intrinsics.make_tensor_view(arr)
-    pv = Intrinsics.make_partition_view(tv, matched, padding_mode, order)
-    tile = Intrinsics.load_partition_view(pv, latency, allow_tma, (index - One(),))
-    _reshape_to_shape(tile, shape)
+# Scalar index → wrap in tuple
+@inline function load(arr::TileArray, index::Integer, shape::NTuple{<:Any, Int}; kwargs...)
+    load(arr, (index,), shape; kwargs...)
 end
 
-# Load with Constant shape tuple
-@inline function load(arr::TileArray, index, shape::Tuple{Vararg{Constant{Int}}};
-                      order::Union{NTuple{<:Any, Int}, Nothing}=nothing,
-                      padding_mode::Int=PaddingMode.Undetermined,
-                      latency::Union{Int, Nothing}=nothing,
-                      allow_tma::Bool=true)
-    shape_val = _extract_shape(shape)
-    matched = _match_shape(Val(shape_val), Val(ndims(arr)))
-    tv = Intrinsics.make_tensor_view(arr)
-    pv = Intrinsics.make_partition_view(tv, matched, padding_mode, order)
-    tile = Intrinsics.load_partition_view(pv, latency, allow_tma, promote(index...) .- One())
-    _reshape_to_shape(tile, shape_val)
+# Constant shape → extract and delegate
+@inline function load(arr::TileArray, index, shape::Tuple{Vararg{Constant{Int}}}; kwargs...)
+    load(arr, index, _extract_shape(shape); kwargs...)
 end
 
-# Keyword argument version
-@inline function load(arr::TileArray; index, shape,
-                      order::Union{NTuple{<:Any, Int}, Nothing}=nothing,
-                      padding_mode::Int=PaddingMode.Undetermined,
-                      latency::Union{Int, Nothing}=nothing,
-                      allow_tma::Bool=true)
-    shape_val = _extract_shape(shape)
-    matched = _match_shape(Val(shape_val), Val(ndims(arr)))
-    tv = Intrinsics.make_tensor_view(arr)
-    pv = Intrinsics.make_partition_view(tv, matched, padding_mode, order)
-    tile = Intrinsics.load_partition_view(pv, latency, allow_tma, promote(index...) .- One())
-    _reshape_to_shape(tile, shape_val)
+# Keyword argument version → extract and delegate
+@inline function load(arr::TileArray; index, shape, kwargs...)
+    load(arr, index, _extract_shape(shape); kwargs...)
 end
 
 """
     return tile # XXX: enables constant folding; remove when possible (see "constant folding" test)
 end
 
-@inline function store(arr::TileArray{T}, index::Integer, tile::Tile{T};
-                       order::Union{NTuple{<:Any, Int}, Nothing}=nothing,
-                       latency::Union{Int, Nothing}=nothing,
-                       allow_tma::Bool=true) where {T}
-    reshaped = _reshape_to_rank(tile, Val(ndims(arr)))
-    _store_reshaped(arr, reshaped, order, latency, allow_tma, (index - One(),))
-    return tile # XXX: enables constant folding; remove when possible (see "constant folding" test)
+# Scalar index → wrap in tuple
+@inline function store(arr::TileArray{T}, index::Integer, tile::Tile{T}; kwargs...) where {T}
+    store(arr, (index,), tile; kwargs...)
 end
 
 @inline function _store_reshaped(arr::TileArray{T}, tile::Tile{T},

From f53eafef0cda0c71924c3c87ce91355588faca Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Sat, 7 Feb 2026 22:23:49 +0100
Subject: [PATCH 6/7] Undo spurious rtol change.

---
 test/execution.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/execution.jl b/test/execution.jl
index 7993e8c..7c0d271 100644
--- a/test/execution.jl
+++ b/test/execution.jl
@@ -2160,7 +2160,7 @@ end
         cpu_result = cpu_reduce(a_reshaped, cpu_op)
 
         if elType <: AbstractFloat
-            @test b_cpu ≈ cpu_result rtol=2e-3
+            @test b_cpu ≈ cpu_result rtol=1e-3
         else
             @test b_cpu == cpu_result
         end

From 889c3bfe726e71cf82d342035c72b97321596ec1 Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Sun, 8 Feb 2026 07:44:30 +0100
Subject: [PATCH 7/7] Make same-shape reshape a no-op.

---
 src/language/operations.jl | 10 ++++------
 test/codegen.jl            |  6 +++---
 2 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/src/language/operations.jl b/src/language/operations.jl
index ad75076..1f1ec95 100644
--- a/src/language/operations.jl
+++ b/src/language/operations.jl
@@ -82,14 +82,10 @@ end
 
 # Reshape a tile to match target rank N, preserving data layout.
 @inline function _reshape_to_rank(tile::Tile, ::Val{N}) where {N}
-    ndims(tile) == N && return tile
     new_shape = _match_shape(Val(size(tile)), Val(N))
     reshape(tile, new_shape)
 end
 
-# Reshape a tile back to a requested shape (no-op when already matching).
-@inline _reshape_to_shape(tile::Tile, shape) =
-    size(tile) === shape ? tile : reshape(tile, shape)
 
 """
     load(arr::TileArray, index, shape; order=nothing, padding_mode=PaddingMode.Undetermined, latency=nothing, allow_tma=true) -> Tile
     tv = Intrinsics.make_tensor_view(arr)
     pv = Intrinsics.make_partition_view(tv, matched, padding_mode, order)
     tile = Intrinsics.load_partition_view(pv, latency, allow_tma, promote(index...) .- One())
-    _reshape_to_shape(tile, shape)
+    reshape(tile, shape)
 end
 
 # Scalar index → wrap in tuple
@@ -478,8 +474,10 @@
 tile = ct.load(arr, (1, 1), (4, 8)) # Shape (4, 8), 32 elements
 reshaped = reshape(tile, (2, 16))   # Shape (2, 16), still 32 elements
 ```
 """
-@inline Base.reshape(tile::Tile{T}, shape::NTuple{<:Any, Int}) where {T} =
+@inline function Base.reshape(tile::Tile{T}, shape::NTuple{<:Any, Int}) where {T}
+    size(tile) === shape && return tile
     Intrinsics.reshape(tile, shape)
+end
 
 """
     permutedims(tile::Tile{T, S}, perm) -> Tile{T, permuted_shape}
diff --git a/test/codegen.jl b/test/codegen.jl
index dc1887d..58a056d 100644
--- a/test/codegen.jl
+++ b/test/codegen.jl
@@ -135,14 +135,14 @@ end
         end
     end
 
-    # 1D -> 1D reshape (no permutes needed - optimization)
+    # 1D -> 1D same-shape reshape is a no-op
     @test @filecheck begin
         @check_label "entry"
-        @check_not "permute" # should NOT have permute for 1D->1D
+        @check_not "permute"
+        @check_not "reshape"
         code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}}) do a
             pid = ct.bid(1)
            tile = ct.load(a, pid, (32,))
-            @check "reshape"
             reshaped = reshape(tile, (32,))
             ct.store(a, pid, reshaped)
             return
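
---

End-state usage sketch (editor's note, not part of the patches): the series lets reduction results flow straight into lower-rank stores. Below is a minimal kernel in the style of the execution tests above; the module alias in the `import` line, the kernel name, and the tile sizes are assumptions for illustration, not code from this repository.

```julia
using CUDA, Test
import cuTile as ct  # assumption: the module the tests bind to `ct`

# Row sums without rank bookkeeping: sum(tile; dims=2) on a (16, 8) tile
# yields (16, 1); per the README change in PATCH 1-3, ct.store squeezes the
# trailing singleton to (16,) to match the 1D output, so no explicit
# dropdims or reshape is needed.
function rowsum_kernel(a::ct.TileArray{Float32,2}, out::ct.TileArray{Float32,1})
    bid = ct.bid(1)
    tile = ct.load(a, (bid, 1), (16, 8))
    ct.store(out, bid, sum(tile; dims=2)) # (16, 1) stored as (16,)
    return
end

a = CUDA.rand(Float32, 64, 8)
out = CUDA.zeros(Float32, 64)
ct.launch(rowsum_kernel, cld(64, 16), a, out)
@test Array(out) ≈ vec(sum(Array(a); dims=2))
```

Note the trailing-only restriction from PATCH 2/7: a `(1, N)` result from `sum(tile; dims=1)` still needs an explicit `dropdims(...; dims=1)` before a 1D store.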