8 changes: 6 additions & 2 deletions README.md
@@ -368,9 +368,13 @@ result = sum(tile; dims=2) # (M, N) → (M, 1)
result = dropdims(sum(tile; dims=2); dims=2) # (M, N) → (M,)
```

### Store reshaping
### Automatic rank matching

`ct.store` automatically reshapes the tile to match the target array's rank by dropping singleton dimensions (e.g., storing a `(1, N)` tile into a 1D array reshapes it to `(N,)`). Scalar `()` tiles are reshaped to `(1,)`.
`ct.load` and `ct.store` automatically match the tile rank to that of the target:

- **Lower rank**: trailing `1`s are appended. Loading `(M, N)` from a 4D array internally uses `(M, N, 1, 1)`. Storing a scalar tile into a 2D array pads to `(1, 1)`.
- **Higher rank**: trailing `1`s are stripped. Storing `(M, 1)` into a 1D array reshapes to `(M,)`.
Non-trailing singletons (e.g., from `sum(tile; dims=1)`) require explicit `dropdims`.
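
A minimal sketch of these rules (the 4D array `A4`, the 1D array `v`, and the index variables are hypothetical, for illustration only):

```julia
tile = ct.load(A4, (i, j, 1, 1), (M, N))  # loaded as (M, N, 1, 1), returned as (M, N)
col = sum(tile; dims=2)                   # (M, 1): trailing singleton
ct.store(v, i, col)                       # stored as (M,): trailing 1 stripped
row = sum(tile; dims=1)                   # (1, N): leading, non-trailing singleton
ct.store(v, i, dropdims(row; dims=1))     # explicit dropdims required here
```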

### Broadcasting shape alignment

38 changes: 19 additions & 19 deletions examples/layernorm.jl
@@ -166,8 +166,8 @@ Accumulates partial gradients using atomic locks.
Args:
DX: Output gradient with respect to X (M, N).
DY: Input gradient with respect to Y (M, N).
DW: Partial gradient with respect to W (GROUP_SIZE_M, N).
DB: Partial gradient with respect to B (GROUP_SIZE_M, N).
DW: Partial gradient with respect to W (N, GROUP_SIZE_M).
DB: Partial gradient with respect to B (N, GROUP_SIZE_M).
X: Input tensor (M, N).
W: Weight tensor (N,).
Mean: Mean tensor (M,).
@@ -211,8 +211,8 @@ function layer_norm_bwd_dx_partial_dwdb(DX::ct.TileArray{Float32, 2}, DY::ct.Til
tdx = (wdy .- (xhat .* c1 .+ c2)) .* rstd
ct.store(DX, (bid_m, j), tdx)

partial_dw = tdy .* xhat
partial_db = tdy
partial_dw = reshape(tdy .* xhat, (TILE_N[], 1))
partial_db = reshape(tdy, (TILE_N[], 1))

# Acquire spinlock
while ct.atomic_cas(Locks, group_bid_m, 0, 1;
@@ -221,10 +221,10 @@
end

# Critical section: accumulate partial gradients
partial_dw = partial_dw .+ ct.load(DW, (group_bid_m, j), (1, TILE_N[]); padding_mode=ct.PaddingMode.Zero)
partial_db = partial_db .+ ct.load(DB, (group_bid_m, j), (1, TILE_N[]); padding_mode=ct.PaddingMode.Zero)
ct.store(DW, (group_bid_m, j), partial_dw)
ct.store(DB, (group_bid_m, j), partial_db)
partial_dw = partial_dw .+ ct.load(DW, (j, group_bid_m), (TILE_N[], 1); padding_mode=ct.PaddingMode.Zero)
partial_db = partial_db .+ ct.load(DB, (j, group_bid_m), (TILE_N[], 1); padding_mode=ct.PaddingMode.Zero)
ct.store(DW, (j, group_bid_m), partial_dw)
ct.store(DB, (j, group_bid_m), partial_db)

# Release spinlock
ct.atomic_xchg(Locks, group_bid_m, 0;
Expand All @@ -242,8 +242,8 @@ end
Backward pass part 2: Final reduction for dW and dB.

Args:
DW: Partial gradient with respect to W (TILE_M, N).
DB: Partial gradient with respect to B (TILE_M, N).
DW: Partial gradient with respect to W (N, TILE_M).
DB: Partial gradient with respect to B (N, TILE_M).
FINAL_DW: Final gradient with respect to W (N,).
FINAL_DB: Final gradient with respect to B (N,).
TILE_M: Number of partial gradients to reduce.
@@ -253,18 +253,18 @@ function layer_norm_bwd_dwdb(DW::ct.TileArray{Float32, 2}, DB::ct.TileArray{Floa
FINAL_DW::ct.TileArray{Float32, 1}, FINAL_DB::ct.TileArray{Float32, 1},
TILE_M::ConstInt, TILE_N::ConstInt)
bid_n = ct.bid(1)
num_tiles = ct.num_tiles(DW, 1, (TILE_M[], TILE_N[]))
num_tiles = ct.num_tiles(DW, 2, (TILE_N[], TILE_M[]))

dw = ct.zeros((TILE_M[], TILE_N[]), Float32)
db = ct.zeros((TILE_M[], TILE_N[]), Float32)
dw = ct.zeros((TILE_N[], TILE_M[]), Float32)
db = ct.zeros((TILE_N[], TILE_M[]), Float32)
i = Int32(1)
while i <= num_tiles
dw = dw .+ ct.load(DW, (i, bid_n), (TILE_M[], TILE_N[]); padding_mode=ct.PaddingMode.Zero)
db = db .+ ct.load(DB, (i, bid_n), (TILE_M[], TILE_N[]); padding_mode=ct.PaddingMode.Zero)
dw = dw .+ ct.load(DW, (bid_n, i), (TILE_N[], TILE_M[]); padding_mode=ct.PaddingMode.Zero)
db = db .+ ct.load(DB, (bid_n, i), (TILE_N[], TILE_M[]); padding_mode=ct.PaddingMode.Zero)
i += Int32(1)
end
sum_dw = sum(dw; dims=1)
sum_db = sum(db; dims=1)
sum_dw = sum(dw; dims=2)
sum_db = sum(db; dims=2)

ct.store(FINAL_DW, bid_n, sum_dw)
ct.store(FINAL_DB, bid_n, sum_db)
@@ -291,8 +291,8 @@ function prepare(; benchmark::Bool=false,
# Backward inputs/outputs
DY = 0.1f0 .* CUDA.randn(Float32, M, N),
DX = CuArray{Float32}(undef, M, N),
DW_partial = CuArray{Float32}(undef, GROUP_SIZE_M, N),
DB_partial = CuArray{Float32}(undef, GROUP_SIZE_M, N),
DW_partial = CuArray{Float32}(undef, N, GROUP_SIZE_M),
DB_partial = CuArray{Float32}(undef, N, GROUP_SIZE_M),
Locks = CuArray{Int}(undef, GROUP_SIZE_M),
FINAL_DW = CuArray{Float32}(undef, N),
FINAL_DB = CuArray{Float32}(undef, N),
104 changes: 42 additions & 62 deletions src/language/operations.jl
@@ -65,6 +65,28 @@ Axis is 1-indexed. Equivalent to cld(arr.sizes[axis], shape[axis]).
Intrinsics.get_index_space_shape(pv, axis - One()) # convert to 0-indexed
end
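# For example, a (100, 64) array tiled with (32, 32) has cld(100, 32) == 4 tiles along axis 1.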

# Match a shape tuple to a target rank N by padding trailing 1s or squeezing trailing singletons.
@generated function _match_shape(::Val{Shape}, ::Val{N}) where {Shape, N}
M = length(Shape)
M == N && return :($Shape)
if M < N
padded = (Shape..., ntuple(_ -> 1, N - M)...)
return :($padded)
end
trailing = M - something(findlast(!=(1), Shape), 0)
trailing >= M - N || error("cannot squeeze shape $Shape to rank $N: ",
"need to drop $(M-N) trailing singletons but only found $trailing")
kept = Shape[1:N]
return :($kept)
end
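# Illustrative examples of the rules above (hypothetical calls, not exercised here):
#   _match_shape(Val((8,)), Val(3))      == (8, 1, 1)   # pad with trailing 1s
#   _match_shape(Val((8, 1, 1)), Val(1)) == (8,)        # squeeze trailing singletons
#   _match_shape(Val((1, 8)), Val(1))    # errors: the singleton is not trailing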

# Reshape a tile to match target rank N, preserving data layout.
@inline function _reshape_to_rank(tile::Tile, ::Val{N}) where {N}
new_shape = _match_shape(Val(size(tile)), Val(N))
reshape(tile, new_shape)
end


"""
load(arr::TileArray, index, shape; order=nothing, padding_mode=PaddingMode.Undetermined, latency=nothing, allow_tma=true) -> Tile

@@ -102,66 +102,26 @@ tile = ct.load(arr, (bidx, bidy), (TM, TN); order=(2, 1))
padding_mode::Int=PaddingMode.Undetermined,
latency::Union{Int, Nothing}=nothing,
allow_tma::Bool=true)
matched = _match_shape(Val(shape), Val(ndims(arr)))
tv = Intrinsics.make_tensor_view(arr)
pv = Intrinsics.make_partition_view(tv, shape, padding_mode, order)
Intrinsics.load_partition_view(pv, latency, allow_tma, promote(index...) .- One())
end

@inline function load(arr::TileArray, index::Integer, shape::NTuple{<:Any, Int};
order::Union{NTuple{<:Any, Int}, Nothing}=nothing,
padding_mode::Int=PaddingMode.Undetermined,
latency::Union{Int, Nothing}=nothing,
allow_tma::Bool=true)
tv = Intrinsics.make_tensor_view(arr)
pv = Intrinsics.make_partition_view(tv, shape, padding_mode, order)
Intrinsics.load_partition_view(pv, latency, allow_tma, (index - One(),))
end

# Load with Constant shape tuple
@inline function load(arr::TileArray, index, shape::Tuple{Vararg{Constant{Int}}};
order::Union{NTuple{<:Any, Int}, Nothing}=nothing,
padding_mode::Int=PaddingMode.Undetermined,
latency::Union{Int, Nothing}=nothing,
allow_tma::Bool=true)
shape_val = _extract_shape(shape)
tv = Intrinsics.make_tensor_view(arr)
pv = Intrinsics.make_partition_view(tv, shape_val, padding_mode, order)
Intrinsics.load_partition_view(pv, latency, allow_tma, promote(index...) .- One())
pv = Intrinsics.make_partition_view(tv, matched, padding_mode, order)
tile = Intrinsics.load_partition_view(pv, latency, allow_tma, promote(index...) .- One())
reshape(tile, shape)
end

# Keyword argument version
@inline function load(arr::TileArray; index, shape,
order::Union{NTuple{<:Any, Int}, Nothing}=nothing,
padding_mode::Int=PaddingMode.Undetermined,
latency::Union{Int, Nothing}=nothing,
allow_tma::Bool=true)
shape_val = _extract_shape(shape)
tv = Intrinsics.make_tensor_view(arr)
pv = Intrinsics.make_partition_view(tv, shape_val, padding_mode, order)
Intrinsics.load_partition_view(pv, latency, allow_tma, promote(index...) .- One())
# Scalar index → wrap in tuple
@inline function load(arr::TileArray, index::Integer, shape::NTuple{<:Any, Int}; kwargs...)
load(arr, (index,), shape; kwargs...)
end

# Auto-reshape tile to match target array rank for store.
@inline function _reshape_for_store(tile::Tile, ::Val{N}) where {N}
ndims(tile) <= N && return tile
new_shape = _store_shape(Val(size(tile)), Val(N))
reshape(tile, new_shape)
# Constant shape → extract and delegate
@inline function load(arr::TileArray, index, shape::Tuple{Vararg{Constant{Int}}}; kwargs...)
load(arr, index, _extract_shape(shape); kwargs...)
end

@generated function _store_shape(::Val{Shape}, ::Val{N}) where {Shape, N}
M = length(Shape)
to_drop = M - N
singletons = [i for i in 1:M if Shape[i] == 1]
if length(singletons) < to_drop
error("cannot squeeze shape $Shape to rank $N: only $(length(singletons)) singleton dims but need to drop $to_drop")
end
drop_set = Set(singletons[1:to_drop])
kept = tuple((Shape[i] for i in 1:M if !(i in drop_set))...)
# Partition views require at least 1D
if isempty(kept)
kept = (1,)
end
return :($kept)
# Keyword argument version → extract and delegate
@inline function load(arr::TileArray; index, shape, kwargs...)
load(arr, index, _extract_shape(shape); kwargs...)
end

"""
@@ -183,18 +165,14 @@
order::Union{NTuple{<:Any, Int}, Nothing}=nothing,
latency::Union{Int, Nothing}=nothing,
allow_tma::Bool=true) where {T}
reshaped = _reshape_for_store(tile, Val(ndims(arr)))
reshaped = _reshape_to_rank(tile, Val(ndims(arr)))
_store_reshaped(arr, reshaped, order, latency, allow_tma, promote(index...) .- One())
return tile # XXX: enables constant folding; remove when possible (see "constant folding" test)
end

@inline function store(arr::TileArray{T}, index::Integer, tile::Tile{T};
order::Union{NTuple{<:Any, Int}, Nothing}=nothing,
latency::Union{Int, Nothing}=nothing,
allow_tma::Bool=true) where {T}
reshaped = _reshape_for_store(tile, Val(ndims(arr)))
_store_reshaped(arr, reshaped, order, latency, allow_tma, (index - One(),))
return tile # XXX: enables constant folding; remove when possible (see "constant folding" test)
# Scalar index → wrap in tuple
@inline function store(arr::TileArray{T}, index::Integer, tile::Tile{T}; kwargs...) where {T}
store(arr, (index,), tile; kwargs...)
end

@inline function _store_reshaped(arr::TileArray{T}, tile::Tile{T},
@@ -496,8 +474,10 @@ tile = ct.load(arr, (1, 1), (4, 8)) # Shape (4, 8), 32 elements
reshaped = reshape(tile, (2, 16)) # Shape (2, 16), still 32 elements
```
"""
@inline Base.reshape(tile::Tile{T}, shape::NTuple{<:Any, Int}) where {T} =
@inline function Base.reshape(tile::Tile{T}, shape::NTuple{<:Any, Int}) where {T}
size(tile) === shape && return tile
Intrinsics.reshape(tile, shape)
end

"""
permutedims(tile::Tile{T, S}, perm) -> Tile{T, permuted_shape}
65 changes: 54 additions & 11 deletions test/codegen.jl
@@ -135,14 +135,14 @@
end
end

# 1D -> 1D reshape (no permutes needed - optimization)
# 1D -> 1D same-shape reshape is a no-op
@test @filecheck begin
@check_label "entry"
@check_not "permute" # should NOT have permute for 1D->1D
@check_not "permute"
@check_not "reshape"
code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}}) do a
pid = ct.bid(1)
tile = ct.load(a, pid, (32,))
@check "reshape"
reshaped = reshape(tile, (32,))
ct.store(a, pid, reshaped)
return
@@ -165,6 +165,34 @@
end
end

@testset "dropdims" begin
# dropdims on dim 1: (1, 8) -> dropdims(; dims=1) -> (8,)
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float32,1,spec1d}}) do a, b
pid = ct.bid(1)
tile = ct.load(a, (pid, 1), (1, 8))
@check "reshape"
squeezed = dropdims(tile; dims=1)
ct.store(b, pid, squeezed)
return
end
end

# dropdims on dim 2: (8, 1) -> dropdims(; dims=2) -> (8,)
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float32,1,spec1d}}) do a, b
pid = ct.bid(1)
tile = ct.load(a, (1, pid), (8, 1))
@check "reshape"
squeezed = dropdims(tile; dims=2)
ct.store(b, pid, squeezed)
return
end
end
end

@testset "permutedims" begin
# 2D permutedims with explicit perm (same as transpose)
@test @filecheck begin
@@ -1351,6 +1379,21 @@
end
end

@testset "rank mismatch load/store" begin
# 1D shape on 2D array: should pad shape to (16, 1) internally
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float32,2,spec2d}}) do a, b
pid = ct.bid(1)
@check "load_view_tko"
tile = ct.load(a, (pid, 1), (16,))
@check "store_view_tko"
ct.store(b, (pid, 1), tile)
return
end
end
end

@testset "num_tiles helper" begin
spec = ct.ArraySpec{2}(16, true)
@test @filecheck begin
@@ -1741,7 +1784,7 @@ end
TILE_N = 1024

# Use ArraySpec with shape_div_by to match real CuArray behavior
spec2d = ct.ArraySpec{2}(128, true, (0, 4), (32, 32))
spec2d = ct.ArraySpec{2}(128, true, (4, 0), (32, 32))
spec1d = ct.ArraySpec{1}(128, true, (0,), (32,))

@test @filecheck begin
@@ -1755,19 +1798,19 @@ end
ct.TileArray{Float32, 1, spec1d}, ct.TileArray{Float32, 1, spec1d},
ct.Constant{Int, TILE_M}, ct.Constant{Int, TILE_N}}) do DW, DB, FINAL_DW, FINAL_DB, _TILE_M, _TILE_N
bid_n = ct.bid(1)
num_tiles = ct.num_tiles(DW, 1, (_TILE_M[], _TILE_N[]))
num_tiles = ct.num_tiles(DW, 2, (_TILE_N[], _TILE_M[]))

dw = ct.zeros((_TILE_M[], _TILE_N[]), Float32)
db = ct.zeros((_TILE_M[], _TILE_N[]), Float32)
dw = ct.zeros((_TILE_N[], _TILE_M[]), Float32)
db = ct.zeros((_TILE_N[], _TILE_M[]), Float32)
i = Int32(1)
while i <= num_tiles
dw = dw .+ ct.load(DW, (i, bid_n), (_TILE_M[], _TILE_N[]); padding_mode=ct.PaddingMode.Zero)
db = db .+ ct.load(DB, (i, bid_n), (_TILE_M[], _TILE_N[]); padding_mode=ct.PaddingMode.Zero)
dw = dw .+ ct.load(DW, (bid_n, i), (_TILE_N[], _TILE_M[]); padding_mode=ct.PaddingMode.Zero)
db = db .+ ct.load(DB, (bid_n, i), (_TILE_N[], _TILE_M[]); padding_mode=ct.PaddingMode.Zero)
i += Int32(1)
end

sum_dw = sum(dw; dims=1)
sum_db = sum(db; dims=1)
sum_dw = sum(dw; dims=2)
sum_db = sum(db; dims=2)

ct.store(FINAL_DW, bid_n, sum_dw)
ct.store(FINAL_DB, bid_n, sum_db)