Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,8 @@ uses standard Julia syntax and is overlaid on `Base`.
### Type Conversion
| Operation | Description |
|-----------|-------------|
| `ct.astype(tile, T)` | Convert element type |
| `convert(Tile{T}, tile)` | Julia-style conversion |
| `convert(Tile{T}, tile)` | Convert element type |
| `T.(tile)` | Broadcasting conversion (e.g. `Float16.(tile)`) |

### Integer Arithmetic
| Operation | Description |
Expand Down Expand Up @@ -414,6 +414,15 @@ b = ct.load(B, (expert_id, k, bid_n), (1, TILE_K, TILE_N))
```


## Differences from Julia

### Float-to-integer conversion truncates

Inside cuTile kernels, `Int32(x::Float32)` and similar float-to-integer constructors
truncate toward zero (like C-style casts). Standard Julia instead throws `InexactError`
when the value is not exactly representable in the target integer type (for example
`Int32(1.5f0)`). The truncating behavior matches GPU hardware and cuTile Python's `ct.astype`.


## Limitations

### `for` loops
Expand Down
20 changes: 20 additions & 0 deletions ext/DLFP8TypesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,24 @@ function ct.julia_to_tile_dtype!(table::ct.TypeTable, ::Type{Float8_E5M2})
return ct.F8E5M2(table)
end

# Scalar constructor overlays between FP8 and the other float types
# (used for map/convert dispatch). Every overlay lowers to `ct.Intrinsics.ftof`.
const FP8Types = (Float8_E4M3FN, Float8_E5M2)
const StandardFloats = (Float16, ct.BFloat16, Float32, ct.TFloat32, Float64)

for Narrow in FP8Types
    for Wide in StandardFloats
        # Standard float → FP8
        @eval Base.Experimental.@consistent_overlay ct.cuTileMethodTable Base.@assume_effects :foldable $Narrow(x::$Wide) = ct.Intrinsics.ftof(x, $Narrow)
        # FP8 → standard float
        @eval Base.Experimental.@consistent_overlay ct.cuTileMethodTable Base.@assume_effects :foldable $Wide(x::$Narrow) = ct.Intrinsics.ftof(x, $Wide)
    end
    # FP8 → FP8 (only between distinct FP8 formats)
    for Other in FP8Types
        Other === Narrow && continue
        @eval Base.Experimental.@consistent_overlay ct.cuTileMethodTable Base.@assume_effects :foldable $Narrow(x::$Other) = ct.Intrinsics.ftof(x, $Narrow)
    end
end

end
72 changes: 0 additions & 72 deletions src/compiler/intrinsics/conversions.jl
Original file line number Diff line number Diff line change
@@ -1,77 +1,5 @@
# Type conversions

# cuda_tile.astype (high-level tile conversion)
@eval Intrinsics begin
    """
        astype(tile, T2)

    Change the element type of `tile` from `T1` to `T2`, keeping its shape.
    Lowered to one of cuda_tile.ftof, cuda_tile.ftoi, cuda_tile.itof,
    cuda_tile.exti, or cuda_tile.trunci depending on the source/target types.
    """
    @noinline astype(tile::Tile{T1, Shape}, ::Type{T2}) where {T1, Shape, T2} =
        Tile{T2, Shape}()
end
# Emit the conversion op for `Intrinsics.astype`.
#
# `args[1]` is the source tile and `args[2]` must resolve to a compile-time constant `Type`.
# Returns the source `CGVal` unchanged when the element types are identical; otherwise a new
# `CGVal` carrying the converted value, typed `Tile{target, Tuple{shape...}}`.
# Throws `IRError` when the source cannot be resolved, the target is not a constant type,
# or the source/target pair is neither float nor integer on both sides.
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.astype), args)
    builder = ctx.cb
    types = ctx.tt

    # Signedness attribute matching an integer element type.
    sign_of(T) = T <: Signed ? SignednessSigned : SignednessUnsigned

    # Operand being converted, together with its element type and shape.
    src = @something emit_value!(ctx, args[1]) throw(IRError("astype: cannot resolve source"))
    from = eltype(CC.widenconst(src.jltype))
    shape = src.shape

    # Destination element type, taken from the `Type` argument.
    to = @something get_constant(ctx, args[2]) throw(IRError("astype() requires a compile-time constant type"))
    to isa Type || throw(IRError("astype() second argument must be a Type"))

    # Identical element types: nothing to emit.
    from === to && return src

    out_type = tile_type!(types, julia_to_tile_dtype!(types, to), shape)

    from_float, from_int = from <: AbstractFloat, from <: Integer
    to_float, to_int = to <: AbstractFloat, to <: Integer

    converted = if from_float && to_float
        encode_FToFOp!(builder, out_type, src.v)
    elseif from_int && to_float
        # Signedness of the integer source decides how it is read.
        encode_IToFOp!(builder, out_type, src.v; signedness=sign_of(from))
    elseif from_float && to_int
        # Signedness of the integer target decides how it is written.
        encode_FToIOp!(builder, out_type, src.v; signedness=sign_of(to))
    elseif from_int && to_int
        if sizeof(to) == sizeof(from)
            # Equal byte size: the existing value is reused as-is.
            src.v
        elseif sizeof(to) > sizeof(from)
            # Widening extends according to the source's signedness.
            encode_ExtIOp!(builder, out_type, src.v; signedness=sign_of(from))
        else
            # Narrowing truncates.
            encode_TruncIOp!(builder, out_type, src.v)
        end
    else
        throw(IRError("astype() unsupported conversion: $from -> $to"))
    end

    return CGVal(converted, out_type, Tile{to, Tuple{shape...}}, shape)
end

# TODO: cuda_tile.bitcast

# cuda_tile.exti (scalar integer extension)
Expand Down
8 changes: 7 additions & 1 deletion src/compiler/intrinsics/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -948,7 +948,7 @@ end
# to_scalar: jltype becomes scalar T (for overlay dispatch), but IR value stays shaped.
# from_scalar: restores jltype to Tile{T, S}.
@eval Intrinsics begin
# Stub definitions evaluated into the `Intrinsics` module. The bodies only produce a
# placeholder value behind a `compilerbarrier`.
@noinline to_scalar(tile::Tile{T, S}) where {T, S} = compilerbarrier(:const, T(0))
# NOTE(review): this has the same signature as the definition directly above, so the later
# one is the definition that takes effect. Presumably the pair is the old and new line of a
# diff shown together — confirm which one belongs in the file.
@noinline to_scalar(tile::Tile{T, S}) where {T, S} = compilerbarrier(:type, nothing)
# Inverse direction: builds an empty `Tile{T, S}` from a scalar of type `T` and a shape type `S`.
@noinline from_scalar(x::T, ::Type{S}) where {T, S} = Tile{T, S}()
end
function tfunc(::typeof(Intrinsics.from_scalar), argtypes::Vector{Any})
Expand All @@ -959,6 +959,12 @@ function tfunc(::typeof(Intrinsics.from_scalar), argtypes::Vector{Any})
S = shape_type.parameters[1]
return Tile{T, S}
end
# Return-type function for `Intrinsics.to_scalar`: yields the element type of the tile
# argument, or `nothing` when there is no second entry in `argtypes` or it is not a `Tile`.
function tfunc(::typeof(Intrinsics.to_scalar), argtypes::Vector{Any})
    length(argtypes) < 2 && return nothing
    arg = CC.widenconst(argtypes[2])
    return arg <: Tile ? eltype(arg) : nothing
end
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.to_scalar), args)
tv = emit_value!(ctx, args[1])
tv === nothing && throw(IRError("Cannot resolve tile for to_scalar"))
Expand Down
33 changes: 9 additions & 24 deletions src/language/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ tile = ct.gather(arr, indices; latency=3)
indices_0 = indices .- one(I)

# Convert to Int32 for consistency with array.sizes
indices_i32 = astype(indices_0, Int32)
indices_i32 = convert(Tile{Int32}, indices_0)

# Compute pointer tile
ptr_tile = Intrinsics.offset(array.ptr, indices_i32)
Expand Down Expand Up @@ -297,8 +297,8 @@ Indices are 1-indexed. Index tiles are broadcast to a common shape.
idx1_bc = broadcast_to(idx1_0, S)

# Convert to Int32 for linear index computation
idx0_i32 = astype(idx0_bc, Int32)
idx1_i32 = astype(idx1_bc, Int32)
idx0_i32 = convert(Tile{Int32}, idx0_bc)
idx1_i32 = convert(Tile{Int32}, idx1_bc)

# Get strides and broadcast to tile shape
stride0_0d = Tile(array.strides[1])
Expand Down Expand Up @@ -350,7 +350,7 @@ ct.scatter(arr, indices, result_tile; latency=3)
indices_0 = indices .- one(I)

# Convert to Int32 for consistency with array.sizes
indices_i32 = astype(indices_0, Int32)
indices_i32 = convert(Tile{Int32}, indices_0)

# Compute pointer tile
ptr_tile = Intrinsics.offset(array.ptr, indices_i32)
Expand Down Expand Up @@ -387,8 +387,8 @@ Indices are 1-indexed. Index tiles and value tile must broadcast to same shape.
tile_bc = broadcast_to(tile, S)

# Convert to Int32 for linear index computation
idx0_i32 = astype(idx0_bc, Int32)
idx1_i32 = astype(idx1_bc, Int32)
idx0_i32 = convert(Tile{Int32}, idx0_bc)
idx1_i32 = convert(Tile{Int32}, idx1_bc)

# Get strides and broadcast to tile shape
stride0_0d = Tile(array.strides[1])
Expand Down Expand Up @@ -468,7 +468,7 @@ zeros_tile = ct.zeros((32, 32), Float32)
Shape & DType
=============================================================================#

public cat, broadcast_to, astype
public cat, broadcast_to

"""
cat(tiles::Tuple{Tile, Tile}, axis::Int) -> Tile
Expand Down Expand Up @@ -580,23 +580,8 @@ Equivalent to `permute(tile, (2, 1))`.
@inline function Base.transpose(tile::Tile{T}) where {T}
    # Thin forwarder to the intrinsic of the same name.
    return Intrinsics.transpose(tile)
end

"""
    astype(tile::Tile{T1, Shape}, ::Type{T2}) -> Tile{T2, Shape}

Return a tile with the same shape as `tile` whose element type is `T2` instead of `T1`.

# Example
```julia
acc = ct.full((64, 64), 0.0f0, Float32)
result = ct.astype(acc, ct.TFloat32) # Convert to TF32 for tensor cores
```
"""
@inline function astype(tile::Tile{T1, Shape}, ::Type{T2}) where {T1, Shape, T2}
    return Intrinsics.astype(tile, T2)
end

# Julia-style convert syntax: `convert(Tile{T2}, tile)` requests a tile with element type T2.
# NOTE(review): two candidate bodies follow the `=` (`astype(tile, T2)` and `map(T2, tile)`).
# As written only the first line is the method body. Presumably these are the old and new
# line of a diff shown together — confirm which one is intended.
@inline Base.convert(::Type{Tile{T2}}, tile::Tile{T1, Shape}) where {T1, T2, Shape} =
astype(tile, T2)
map(T2, tile)

#=============================================================================
Reduction
Expand Down Expand Up @@ -754,7 +739,7 @@ n_positive = count(tile .> 0.0f0; dims=1)
```
"""
@inline function Base.count(tile::Tile{Bool,S}; dims::Integer) where {S}
# Counts `true` entries by converting the Bool tile to Int32 and summing along `dims`.
# NOTE(review): two `sum(...)` statements appear; only the value of the last one is returned.
# Presumably the first is the pre-change line of a diff — confirm and drop it if so.
sum(astype(tile, Int32); dims)
sum(convert(Tile{Int32}, tile); dims)
end

"""
Expand Down
12 changes: 11 additions & 1 deletion src/language/overlays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ end
# Generic overlays don't take precedence over Core's Int64(x::BuiltinInts) etc.
const SignedInts = (Int8, Int16, Int32, Int64)
const UnsignedInts = (UInt8, UInt16, UInt32, UInt64)
const Floats = (Float16, Float32, Float64)
const Floats = (Float16, BFloat16, Float32, TFloat32, Float64)

# Integer to integer (specific type pairs for promotion/truncation)
for T in SignedInts, S in SignedInts
Expand Down Expand Up @@ -77,3 +77,13 @@ for F in Floats
@eval @overlay Base.unsafe_trunc(::Type{$I}, x::$F) = Intrinsics.ftoi(x, $I, SignednessUnsigned)
end
end

# Float to integer (direct constructor - truncates like C-style cast)
# One overlay per (float, integer) pair; the signedness passed to `ftoi` follows the
# integer target. Signed targets are defined before unsigned ones for each float type.
for F in Floats
    for (IntTypes, sign) in ((SignedInts, :SignednessSigned), (UnsignedInts, :SignednessUnsigned))
        for I in IntTypes
            @eval @overlay $I(x::$F) = Intrinsics.ftoi(x, $I, $sign)
        end
    end
end
46 changes: 42 additions & 4 deletions test/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@
@check "select"
result = a .< b
# Use same-typed operands for where to avoid Union type
b_promoted = ct.astype(b, Int64)
b_promoted = convert(ct.Tile{Int64}, b)
selected = ct.where(result, a, b_promoted)
ct.store(out, Int32(0), selected)
return
Expand Down Expand Up @@ -779,7 +779,7 @@
pid = ct.bid(1)
tile = ct.load(a, pid, (16,))
@check "ftof"
converted = ct.astype(tile, Float16)
converted = convert(ct.Tile{Float16}, tile)
ct.store(b, pid, converted)
return
end
Expand All @@ -793,7 +793,7 @@
tile = ct.load(a, pid, (16,))
@check "ftof"
converted = convert(ct.Tile{ct.TFloat32}, tile)
ct.store(b, pid, ct.astype(converted, Float32))
ct.store(b, pid, convert(ct.Tile{Float32}, converted))
return
end
end
Expand All @@ -806,7 +806,45 @@
tile = ct.load(a, pid, (16,))
@check "ftof"
converted = convert(ct.Tile{ct.BFloat16}, tile)
ct.store(b, pid, ct.astype(converted, Float32))
ct.store(b, pid, convert(ct.Tile{Float32}, converted))
return
end
end

# Broadcasting syntax: Float16.(tile)
# Loads a 16-element tile from a Float32 array, converts it with `Float16.(tile)` and
# stores it into a Float16 array. The `@check "ftof"` directive presumably requires an
# `ftof` op in the generated code — confirm against the `@filecheck` helper.
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float16,1,spec1d}}) do a, b
pid = ct.bid(1)
tile = ct.load(a, pid, (16,))
@check "ftof"
ct.store(b, pid, Float16.(tile))
return
end
end

# Broadcasting syntax: BFloat16.(tile)
# Round trip Float32 → BFloat16 → Float32 using broadcast constructors only. Two
# `@check "ftof"` directives are listed, presumably one per conversion — confirm.
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b
pid = ct.bid(1)
tile = ct.load(a, pid, (16,))
@check "ftof"
@check "ftof"
ct.store(b, pid, Float32.(ct.BFloat16.(tile)))
return
end
end

# Broadcasting syntax: TFloat32.(tile)
# Round trip Float32 → TFloat32 → Float32 using broadcast constructors only. Two
# `@check "ftof"` directives are listed, presumably one per conversion — confirm.
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b
pid = ct.bid(1)
tile = ct.load(a, pid, (16,))
@check "ftof"
@check "ftof"
ct.store(b, pid, Float32.(ct.TFloat32.(tile)))
return
end
end
Expand Down
4 changes: 2 additions & 2 deletions test/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2269,7 +2269,7 @@ end
tile = ct.load(a, ct.bid(1), (tileSz[],))
mask = tile .> 0.0f0
result = any(mask; dims=1)
ct.store(b, ct.bid(1), ct.astype(result, Int32))
ct.store(b, ct.bid(1), convert(ct.Tile{Int32}, result))
return nothing
end

Expand All @@ -2278,7 +2278,7 @@ end
tile = ct.load(a, ct.bid(1), (tileSz[],))
mask = tile .> 0.0f0
result = all(mask; dims=1)
ct.store(b, ct.bid(1), ct.astype(result, Int32))
ct.store(b, ct.bid(1), convert(ct.Tile{Int32}, result))
return nothing
end

Expand Down
4 changes: 2 additions & 2 deletions test/ext/DLFP8TypesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ spec1d = ct.ArraySpec{1}(16, true)
tile = ct.load(a, pid, (16,))
@check "ftof"
converted = convert(ct.Tile{Float8_E4M3FN}, tile)
ct.store(b, pid, ct.astype(converted, Float32))
ct.store(b, pid, convert(ct.Tile{Float32}, converted))
return
end
end
Expand All @@ -25,7 +25,7 @@ end
tile = ct.load(a, pid, (16,))
@check "ftof"
converted = convert(ct.Tile{Float8_E5M2}, tile)
ct.store(b, pid, ct.astype(converted, Float32))
ct.store(b, pid, convert(ct.Tile{Float32}, converted))
return
end
end
Expand Down