diff --git a/README.md b/README.md
index 261d095..c0eba6a 100644
--- a/README.md
+++ b/README.md
@@ -395,6 +395,24 @@ b = ct.load(...) # (M, N)
 result = reshape(a, (1, N)) .+ b # (1, N) .+ (M, N) → (M, N)
 ```
 
+### Scalar access and 0-D tiles
+
+cuTile Python represents single-element loads as 0-D tiles (`shape=()`), which can be used
+directly as indices. cuTile.jl uses Julia's standard indexing syntax instead — `getindex`
+returns a scalar `T` and `setindex!` stores a scalar:
+
+```python
+# Python
+expert_id = ct.load(ids, index=bid_m, shape=())
+b = ct.load(B, (expert_id, k, bid_n), shape=(1, TILE_K, TILE_N))
+```
+
+```julia
+# Julia
+expert_id = ids[bid_m]
+b = ct.load(B, (expert_id, k, bid_n), (1, TILE_K, TILE_N))
+```
+
 ## Limitations
 
 
diff --git a/src/language/operations.jl b/src/language/operations.jl
index 1f1ec95..2fd53bd 100644
--- a/src/language/operations.jl
+++ b/src/language/operations.jl
@@ -141,6 +141,41 @@ end
     load(arr, index, _extract_shape(shape); kwargs...)
 end
 
+# Scalar indexing: arr[i, j, ...] → scalar T (all-ones-shape partition load; indices are shifted by -1 for the load)
+@overlay function Base.getindex(arr::TileArray{T, N}, indices::Vararg{Integer, N}) where {T, N}
+    tv = Intrinsics.make_tensor_view(arr)
+    shape = ntuple(_ -> 1, Val(N))
+    pv = Intrinsics.make_partition_view(tv, Val(shape), PaddingMode.Undetermined)
+    tile = Intrinsics.load_partition_view(pv, nothing, true, promote(indices...) .- One())
+    Intrinsics.to_scalar(reshape(tile, ()))
+end
+
+# Scalar indexing: tile[i, j, ...] → scalar T (all-ones-shape `extract`, reshaped to 0-D, then unwrapped)
+@inline function Base.getindex(tile::Tile, indices::Vararg{Int})
+    shape = ntuple(_ -> 1, Val(length(indices)))
+    subtile = extract(tile, indices, shape)
+    Intrinsics.to_scalar(reshape(subtile, ()))
+end
+
+# Functional setindex: Base.setindex(tile, val, i, j, ...) → new Tile with element replaced (val converted via T(val); selected with `where` on an iota mask)
+@inline function Base.setindex(tile::Tile, val, indices::Vararg{Int})
+    T = eltype(tile)
+    S = size(tile)
+    flat_len = prod(S)
+    linear = _linear_index(S, indices)
+    flat = reshape(tile, (flat_len,))
+    idx = Intrinsics.iota((flat_len,), Int32)
+    mask = idx .== Int32(linear)
+    val_tile = broadcast_to(Tile(T(val)), (flat_len,))
+    new_flat = where(mask, val_tile, flat)
+    reshape(new_flat, S)
+end
+
+# 0-indexed column-major linear index from 1-indexed indices (recurses over the tail; empty tuples give 0)
+@inline _linear_index(::Tuple{}, ::Tuple{}) = 0
+@inline _linear_index(S::NTuple{N, Int}, indices::NTuple{N, Int}) where {N} =
+    (indices[1] - 1) + S[1] * _linear_index(Base.tail(S), Base.tail(indices))
+
 # Keyword argument version → extract and delegate
 @inline function load(arr::TileArray; index, shape, kwargs...)
     load(arr, index, _extract_shape(shape); kwargs...)
@@ -190,6 +225,17 @@ end
     store(arr, index, tile; order, latency, allow_tma)
 end
 
+# Scalar store: arr[i, j, ...] = val (val must already be of element type T; it is stored as an all-ones-shape tile)
+# NOTE: Cannot use @overlay (which adds @assume_effects :foldable) because
+# setindex! is a side-effecting operation that returns nothing — the compiler
+# would DCE the entire call as a pure function with an unused result.
+Base.Experimental.@consistent_overlay cuTileMethodTable function Base.setindex!(arr::TileArray{T, N}, val::T, indices::Vararg{Integer, N}) where {T, N}
+    shape = ntuple(_ -> 1, Val(N))
+    tile = reshape(Intrinsics.from_scalar(val, Val(Tuple{})), shape)
+    store(arr, indices, tile)
+    return
+end
+
 """
     gather(array::TileArray{T, 1}, indices::Tile{I, S}; latency=nothing) -> Tile{T, S}
 
diff --git a/test/codegen.jl b/test/codegen.jl
index 58a056d..1c2f292 100644
--- a/test/codegen.jl
+++ b/test/codegen.jl
@@ -277,6 +277,33 @@ end
         end
     end
 
+    # Scalar tile getindex: tile[3] is expected to emit an extract followed by a reshape
+    @test @filecheck begin
+        @check_label "entry"
+        code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}}) do a
+            tile = ct.load(a, 1, (8,))
+            @check "extract"
+            # extract produces tile<1xf32>, which is then reshaped to a scalar (0-D) tile
+            @check "tile<1xf32> -> tile"
+            scalar = tile[3]
+            ct.store(a, 1, ct.broadcast_to(ct.Tile(scalar), (8,)))
+            return
+        end
+    end
+
+    # Scalar tile setindex (functional): Base.setindex is expected to emit an iota and a select
+    @test @filecheck begin
+        @check_label "entry"
+        code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}}) do a
+            tile = ct.load(a, 1, (8,))
+            @check "iota"
+            @check "select"
+            new_tile = Base.setindex(tile, 0.0f0, 3)
+            ct.store(a, 1, new_tile)
+            return
+        end
+    end
+
     # Extract slice from 3D tile (FFT real/imag pattern)
     @test @filecheck begin
         @check_label "entry"
diff --git a/test/execution.jl b/test/execution.jl
index 7c0d271..089a673 100644
--- a/test/execution.jl
+++ b/test/execution.jl
@@ -650,6 +650,36 @@ end
     end
 end
 
+@testset "scalar tile getindex" begin
+    function tile_getindex_kernel(x::ct.TileArray{Float32,1}, y::ct.TileArray{Float32,1})
+        tile = ct.load(x, 1, (8,))
+        scalar = tile[3] # Read the 3rd element; it is broadcast to all 8 outputs below
+        ct.store(y, 1, ct.broadcast_to(ct.Tile(scalar), (8,)))
+        return
+    end
+    host_x = zeros(Float32, 8)
+    host_x[3] = 42.0f0
+    x = CuArray(host_x)
+    y = CUDA.zeros(Float32, 8)
+    ct.launch(tile_getindex_kernel, 1, x, y)
+    @test all(Array(y) .≈ 42.0f0)
+end
+
+@testset "scalar tile setindex" begin
+    function tile_setindex_kernel(x::ct.TileArray{Float32,1}, y::ct.TileArray{Float32,1})
+        tile = ct.load(x, 1, (8,))
+        new_tile = Base.setindex(tile, 0.0f0, 3)
+        ct.store(y, 1, new_tile)
+        return
+    end
+    x = CuArray(Float32.(1:8))
+    y = CUDA.zeros(Float32, 8)
+    ct.launch(tile_setindex_kernel, 1, x, y)
+    expected = Float32.(1:8)
+    expected[3] = 0.0f0
+    @test Array(y) ≈ expected
+end
+
 @testset "cat" begin
     @testset "cat along last axis (axis -1)" begin
         function cat_last_axis_kernel(a::ct.TileArray{Float32,2}, b::ct.TileArray{Float32,2},