Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,24 @@ b = ct.load(...) # (M, N)
result = reshape(a, (1, N)) .+ b # (1, N) .+ (M, N) → (M, N)
```

### Scalar access and 0-D tiles

cuTile Python represents single-element loads as 0-D tiles (`shape=()`), which can be used
directly as indices. cuTile.jl uses Julia's standard indexing syntax instead: `arr[i, j, ...]`
(`getindex`) returns a scalar of the array's element type `T`, and `arr[i, j, ...] = val`
(`setindex!`) stores a scalar of that type:

```python
# Python
expert_id = ct.load(ids, index=bid_m, shape=())
b = ct.load(B, (expert_id, k, bid_n), shape=(1, TILE_K, TILE_N))
```

```julia
# Julia
expert_id = ids[bid_m]
b = ct.load(B, (expert_id, k, bid_n), (1, TILE_K, TILE_N))
```


## Limitations

Expand Down
46 changes: 46 additions & 0 deletions src/language/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,41 @@ end
load(arr, index, _extract_shape(shape); kwargs...)
end

# Scalar load from a TileArray: `arr[i, j, ...]` → one element of type `T`.
# The element is read as a partition tile whose extent is 1 in each of the N dimensions,
# reshaped to a 0-D tile, and converted to a scalar.
@overlay function Base.getindex(arr::TileArray{T, N}, indices::Vararg{Integer, N}) where {T, N}
    tensor_view = Intrinsics.make_tensor_view(arr)
    unit_shape = ntuple(_ -> 1, Val(N))
    partition = Intrinsics.make_partition_view(tensor_view, Val(unit_shape), PaddingMode.Undetermined)
    # Bring all indices to a common integer type, then subtract `One()` from each
    # (presumably converting Julia's 1-based indices to 0-based ones; confirm against
    # the intrinsic's contract).
    shifted = promote(indices...) .- One()
    unit_tile = Intrinsics.load_partition_view(partition, nothing, true, shifted)
    return Intrinsics.to_scalar(reshape(unit_tile, ()))
end

# Scalar read from a Tile: `tile[i, j, ...]` → one element.
# A sub-tile with extent 1 per supplied index is extracted at `indices`, reshaped to a
# 0-D tile, and converted to a scalar.
@inline function Base.getindex(tile::Tile, indices::Vararg{Int})
    unit_shape = ntuple(_ -> 1, Val(length(indices)))
    picked = extract(tile, indices, unit_shape)
    return Intrinsics.to_scalar(reshape(picked, ()))
end

# Non-mutating element replacement: `Base.setindex(tile, val, i, j, ...)` returns a new
# Tile that matches `tile` everywhere except at the given position, which holds `val`
# converted to the tile's element type.
@inline function Base.setindex(tile::Tile, val, indices::Vararg{Int})
    dims = size(tile)
    n = prod(dims)
    # 0-indexed column-major position of the target element in the flattened tile
    target = _linear_index(dims, indices)
    flattened = reshape(tile, (n,))
    # Mask is true only where the iota value equals the target position
    hit = Intrinsics.iota((n,), Int32) .== Int32(target)
    replacement = broadcast_to(Tile(eltype(tile)(val)), (n,))
    patched = where(hit, replacement, flattened)
    return reshape(patched, dims)
end

# Convert 1-indexed per-dimension `indices` into a 0-indexed column-major linear offset
# for a shape `S`. Recurses over the tuples one dimension at a time; the empty-tuple
# method terminates the recursion with offset 0.
@inline _linear_index(::Tuple{}, ::Tuple{}) = 0
@inline function _linear_index(S::NTuple{N, Int}, indices::NTuple{N, Int}) where {N}
    # Offset contributed by the trailing dimensions, measured in units of S[1]
    trailing = _linear_index(Base.tail(S), Base.tail(indices))
    leading = first(indices) - 1
    return leading + first(S) * trailing
end

# Keyword argument version → extract and delegate
@inline function load(arr::TileArray; index, shape, kwargs...)
load(arr, index, _extract_shape(shape); kwargs...)
Expand Down Expand Up @@ -190,6 +225,17 @@ end
store(arr, index, tile; order, latency, allow_tma)
end

# Scalar store into a TileArray: `arr[i, j, ...] = val`.
# `val` must already be of the array's element type `T` (no conversion happens here).
# It is turned into a tile via `from_scalar`, reshaped to extent 1 in each of the N
# dimensions, and written with `store`.
# NOTE: @overlay cannot be used here. It adds @assume_effects :foldable, and because
# setindex! is a side-effecting operation returning nothing, the compiler would DCE
# the entire call as a pure function with an unused result.
Base.Experimental.@consistent_overlay cuTileMethodTable function Base.setindex!(arr::TileArray{T, N}, val::T, indices::Vararg{Integer, N}) where {T, N}
    unit_shape = ntuple(_ -> 1, Val(N))
    scalar_tile = Intrinsics.from_scalar(val, Val(Tuple{}))
    store(arr, indices, reshape(scalar_tile, unit_shape))
    return nothing
end

"""
gather(array::TileArray{T, 1}, indices::Tile{I, S}; latency=nothing) -> Tile{T, S}

Expand Down
27 changes: 27 additions & 0 deletions test/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,33 @@
end
end

# Scalar tile getindex: indexing a loaded tile with one integer (`tile[3]`) is expected
# to appear in the generated code as an `extract` followed by a reshape from
# `tile<1xf32>` to the 0-D `tile<f32>`.
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}}) do a
tile = ct.load(a, 1, (8,))
@check "extract"
# extract produces tile<1xf32>, reshape to scalar tile<f32>
@check "tile<1xf32> -> tile<f32>"
scalar = tile[3]
# Broadcast the scalar back to a full tile and store it (presumably so the indexing
# result is actually used by the kernel).
ct.store(a, 1, ct.broadcast_to(ct.Tile(scalar), (8,)))
return
end
end

# Scalar tile setindex (functional): `Base.setindex(tile, val, 3)` is expected to appear
# in the generated code as an `iota` followed by a `select`.
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}}) do a
tile = ct.load(a, 1, (8,))
@check "iota"
@check "select"
new_tile = Base.setindex(tile, 0.0f0, 3)
# Store the updated tile back into `a`.
ct.store(a, 1, new_tile)
return
end
end

# Extract slice from 3D tile (FFT real/imag pattern)
@test @filecheck begin
@check_label "entry"
Expand Down
30 changes: 30 additions & 0 deletions test/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,36 @@ end
end
end

@testset "scalar tile getindex" begin
    # Kernel: read element 3 of an 8-element tile and fill the output tile with it.
    function tile_getindex_kernel(x::ct.TileArray{Float32,1}, y::ct.TileArray{Float32,1})
        loaded = ct.load(x, 1, (8,))
        third = loaded[3]
        filled = ct.broadcast_to(ct.Tile(third), (8,))
        ct.store(y, 1, filled)
        return
    end

    # Only position 3 is non-zero, so reading any other element would yield zeros.
    input = zeros(Float32, 8)
    input[3] = 42.0f0
    x = CuArray(input)
    y = CUDA.zeros(Float32, 8)
    ct.launch(tile_getindex_kernel, 1, x, y)
    @test all(Array(y) .≈ 42.0f0)
end

@testset "scalar tile setindex" begin
    # Kernel: copy an 8-element tile from `x` to `y` with element 3 replaced by zero.
    function tile_setindex_kernel(x::ct.TileArray{Float32,1}, y::ct.TileArray{Float32,1})
        loaded = ct.load(x, 1, (8,))
        ct.store(y, 1, Base.setindex(loaded, 0.0f0, 3))
        return
    end

    x = CuArray(Float32.(1:8))
    y = CUDA.zeros(Float32, 8)
    ct.launch(tile_setindex_kernel, 1, x, y)

    # Same ramp as the input, except position 3 is zeroed.
    expected = Float32[i == 3 ? 0 : i for i in 1:8]
    @test Array(y) ≈ expected
end

@testset "cat" begin
@testset "cat along last axis (axis -1)" begin
function cat_last_axis_kernel(a::ct.TileArray{Float32,2}, b::ct.TileArray{Float32,2},
Expand Down