From b1ba1befcd971d9feaafa04cf1a455cbba4ed735 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Sun, 8 Feb 2026 21:52:14 +0100 Subject: [PATCH] Replace astype by idiomatic convert/broadcast. Also rework internals so that we can call truncating intrinsics from the language instead of relying on codegen. --- README.md | 13 ++++- ext/DLFP8TypesExt.jl | 20 +++++++ src/compiler/intrinsics/conversions.jl | 72 -------------------------- src/compiler/intrinsics/core.jl | 8 ++- src/language/operations.jl | 33 ++++-------- src/language/overlays.jl | 12 ++++- test/codegen.jl | 46 ++++++++++++++-- test/execution.jl | 4 +- test/ext/DLFP8TypesExt.jl | 4 +- 9 files changed, 104 insertions(+), 108 deletions(-) diff --git a/README.md b/README.md index c0eba6a..52dc333 100644 --- a/README.md +++ b/README.md @@ -208,8 +208,8 @@ uses standard Julia syntax and is overlaid on `Base`. ### Type Conversion | Operation | Description | |-----------|-------------| -| `ct.astype(tile, T)` | Convert element type | -| `convert(Tile{T}, tile)` | Julia-style conversion | +| `convert(Tile{T}, tile)` | Convert element type | +| `T.(tile)` | Broadcasting conversion (e.g. `Float16.(tile)`) | ### Integer Arithmetic | Operation | Description | @@ -414,6 +414,15 @@ b = ct.load(B, (expert_id, k, bid_n), (1, TILE_K, TILE_N)) ``` +## Differences from Julia + +### Float-to-integer conversion truncates + +Inside cuTile kernels, `Int32(x::Float32)` and similar float-to-integer constructors +truncate toward zero (like C-style casts), rather than throwing `InexactError` as in +standard Julia. This matches the behavior of GPU hardware and cuTile Python's `ct.astype`. + + ## Limitations ### `for` loops diff --git a/ext/DLFP8TypesExt.jl b/ext/DLFP8TypesExt.jl index 0d4cdbc..bd22663 100644 --- a/ext/DLFP8TypesExt.jl +++ b/ext/DLFP8TypesExt.jl @@ -12,4 +12,24 @@ function ct.julia_to_tile_dtype!(table::ct.TypeTable, ::Type{Float8_E5M2}) return ct.F8E5M2(table) end +# Float ↔ FP8 scalar constructor overlays (for map/convert dispatch) +const FP8Types = (Float8_E4M3FN, Float8_E5M2) +const StandardFloats = (Float16, ct.BFloat16, Float32, ct.TFloat32, Float64) + +for F8 in FP8Types + # Standard float → FP8 + for F in StandardFloats + @eval Base.Experimental.@consistent_overlay ct.cuTileMethodTable Base.@assume_effects :foldable $F8(x::$F) = ct.Intrinsics.ftof(x, $F8) + end + # FP8 → standard float + for F in StandardFloats + @eval Base.Experimental.@consistent_overlay ct.cuTileMethodTable Base.@assume_effects :foldable $F(x::$F8) = ct.Intrinsics.ftof(x, $F) + end + # FP8 → FP8 + for F8b in FP8Types + F8 === F8b && continue + @eval Base.Experimental.@consistent_overlay ct.cuTileMethodTable Base.@assume_effects :foldable $F8(x::$F8b) = ct.Intrinsics.ftof(x, $F8) + end +end + end diff --git a/src/compiler/intrinsics/conversions.jl b/src/compiler/intrinsics/conversions.jl index 5b522c4..6c33afc 100644 --- a/src/compiler/intrinsics/conversions.jl +++ b/src/compiler/intrinsics/conversions.jl @@ -1,77 +1,5 @@ # Type conversions -# cuda_tile.astype (high-level tile conversion) -@eval Intrinsics begin - """ - astype(tile, T2) - - Convert tile element type from T1 to T2. - Compiled to cuda_tile.ftof, cuda_tile.ftoi, cuda_tile.itof, - cuda_tile.exti, or cuda_tile.trunci based on source/target types. - """ - @noinline function astype(tile::Tile{T1, Shape}, ::Type{T2}) where {T1, Shape, T2} - Tile{T2, Shape}() - end -end -function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.astype), args) - cb = ctx.cb - tt = ctx.tt - - # Get source tile - source = @something emit_value!(ctx, args[1]) throw(IRError("astype: cannot resolve source")) - - # Get source element type and shape - source_type = CC.widenconst(source.jltype) - source_elem = eltype(source_type) - tile_shape = source.shape - - # Get target element type from the Type argument - target_elem = @something get_constant(ctx, args[2]) throw(IRError("astype() requires a compile-time constant type")) - target_elem isa Type || throw(IRError("astype() second argument must be a Type")) - - # Same type? Return source unchanged - if source_elem === target_elem - return source - end - - # Create target type - target_dtype = julia_to_tile_dtype!(tt, target_elem) - target_tile_type = tile_type!(tt, target_dtype, tile_shape) - - # Emit conversion based on source and target types - result = if source_elem <: AbstractFloat && target_elem <: AbstractFloat - # Float -> Float - encode_FToFOp!(cb, target_tile_type, source.v) - elseif source_elem <: Integer && target_elem <: AbstractFloat - # Integer -> Float - signedness = source_elem <: Signed ? SignednessSigned : SignednessUnsigned - encode_IToFOp!(cb, target_tile_type, source.v; signedness) - elseif source_elem <: AbstractFloat && target_elem <: Integer - # Float -> Integer - signedness = target_elem <: Signed ? SignednessSigned : SignednessUnsigned - encode_FToIOp!(cb, target_tile_type, source.v; signedness) - elseif source_elem <: Integer && target_elem <: Integer - # Integer -> Integer - source_size = sizeof(source_elem) - target_size = sizeof(target_elem) - if source_size == target_size - # Same size - no conversion needed (just reinterpret) - source.v - elseif target_size > source_size - # Extension (upsize) - signedness = source_elem <: Signed ? SignednessSigned : SignednessUnsigned - encode_ExtIOp!(cb, target_tile_type, source.v; signedness) - else - # Truncation (downsize) - encode_TruncIOp!(cb, target_tile_type, source.v) - end - else - throw(IRError("astype() unsupported conversion: $source_elem -> $target_elem")) - end - - CGVal(result, target_tile_type, Tile{target_elem, Tuple{tile_shape...}}, tile_shape) -end - # TODO: cuda_tile.bitcast # cuda_tile.exti (scalar integer extension) diff --git a/src/compiler/intrinsics/core.jl b/src/compiler/intrinsics/core.jl index bcb3eca..b64fbcf 100644 --- a/src/compiler/intrinsics/core.jl +++ b/src/compiler/intrinsics/core.jl @@ -948,7 +948,7 @@ end # to_scalar: jltype becomes scalar T (for overlay dispatch), but IR value stays shaped. # from_scalar: restores jltype to Tile{T, S}. @eval Intrinsics begin - @noinline to_scalar(tile::Tile{T, S}) where {T, S} = compilerbarrier(:const, T(0)) + @noinline to_scalar(tile::Tile{T, S}) where {T, S} = compilerbarrier(:type, nothing) @noinline from_scalar(x::T, ::Type{S}) where {T, S} = Tile{T, S}() end function tfunc(::typeof(Intrinsics.from_scalar), argtypes::Vector{Any}) @@ -959,6 +959,12 @@ function tfunc(::typeof(Intrinsics.from_scalar), argtypes::Vector{Any}) S = shape_type.parameters[1] return Tile{T, S} end +function tfunc(::typeof(Intrinsics.to_scalar), argtypes::Vector{Any}) + length(argtypes) >= 2 || return nothing + tile_type = CC.widenconst(argtypes[2]) + tile_type <: Tile || return nothing + return eltype(tile_type) +end function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.to_scalar), args) tv = emit_value!(ctx, args[1]) tv === nothing && throw(IRError("Cannot resolve tile for to_scalar")) diff --git a/src/language/operations.jl b/src/language/operations.jl index 50bdb7a..cbf8597 100644 --- a/src/language/operations.jl +++ b/src/language/operations.jl @@ -258,7 +258,7 @@ tile = ct.gather(arr, indices; latency=3) indices_0 = indices .- one(I) # Convert to Int32 for consistency with array.sizes - indices_i32 = astype(indices_0, Int32) + indices_i32 = convert(Tile{Int32}, indices_0) # Compute pointer tile ptr_tile = Intrinsics.offset(array.ptr, indices_i32) @@ -297,8 +297,8 @@ Indices are 1-indexed. Index tiles are broadcast to a common shape. idx1_bc = broadcast_to(idx1_0, S) # Convert to Int32 for linear index computation - idx0_i32 = astype(idx0_bc, Int32) - idx1_i32 = astype(idx1_bc, Int32) + idx0_i32 = convert(Tile{Int32}, idx0_bc) + idx1_i32 = convert(Tile{Int32}, idx1_bc) # Get strides and broadcast to tile shape stride0_0d = Tile(array.strides[1]) @@ -350,7 +350,7 @@ ct.scatter(arr, indices, result_tile; latency=3) indices_0 = indices .- one(I) # Convert to Int32 for consistency with array.sizes - indices_i32 = astype(indices_0, Int32) + indices_i32 = convert(Tile{Int32}, indices_0) # Compute pointer tile ptr_tile = Intrinsics.offset(array.ptr, indices_i32) @@ -387,8 +387,8 @@ Indices are 1-indexed. Index tiles and value tile must broadcast to same shape. tile_bc = broadcast_to(tile, S) # Convert to Int32 for linear index computation - idx0_i32 = astype(idx0_bc, Int32) - idx1_i32 = astype(idx1_bc, Int32) + idx0_i32 = convert(Tile{Int32}, idx0_bc) + idx1_i32 = convert(Tile{Int32}, idx1_bc) # Get strides and broadcast to tile shape stride0_0d = Tile(array.strides[1]) @@ -468,7 +468,7 @@ zeros_tile = ct.zeros((32, 32), Float32) Shape & DType =============================================================================# -public cat, broadcast_to, astype +public cat, broadcast_to """ cat(tiles::Tuple{Tile, Tile}, axis::Int) -> Tile @@ -580,23 +580,8 @@ Equivalent to `permute(tile, (2, 1))`. @inline Base.transpose(tile::Tile{T}) where {T} = Intrinsics.transpose(tile) -""" - astype(tile::Tile{T1, Shape}, ::Type{T2}) -> Tile{T2, Shape} - -Convert a tile's element type from T1 to T2. - -# Example -```julia -acc = ct.full((64, 64), 0.0f0, Float32) -result = ct.astype(acc, ct.TFloat32) # Convert to TF32 for tensor cores -``` -""" -@inline astype(tile::Tile{T1, Shape}, ::Type{T2}) where {T1, Shape, T2} = - Intrinsics.astype(tile, T2) - -# Julia-style convert syntax @inline Base.convert(::Type{Tile{T2}}, tile::Tile{T1, Shape}) where {T1, T2, Shape} = - astype(tile, T2) + map(T2, tile) #============================================================================= Reduction @@ -754,7 +739,7 @@ n_positive = count(tile .> 0.0f0; dims=1) ``` """ @inline function Base.count(tile::Tile{Bool,S}; dims::Integer) where {S} - sum(astype(tile, Int32); dims) + sum(convert(Tile{Int32}, tile); dims) end """ diff --git a/src/language/overlays.jl b/src/language/overlays.jl index b2ef4b0..7094634 100644 --- a/src/language/overlays.jl +++ b/src/language/overlays.jl @@ -16,7 +16,7 @@ end # Generic overlays don't take precedence over Core's Int64(x::BuiltinInts) etc. const SignedInts = (Int8, Int16, Int32, Int64) const UnsignedInts = (UInt8, UInt16, UInt32, UInt64) -const Floats = (Float16, Float32, Float64) +const Floats = (Float16, BFloat16, Float32, TFloat32, Float64) # Integer to integer (specific type pairs for promotion/truncation) for T in SignedInts, S in SignedInts @@ -77,3 +77,13 @@ for F in Floats @eval @overlay Base.unsafe_trunc(::Type{$I}, x::$F) = Intrinsics.ftoi(x, $I, SignednessUnsigned) end end + +# Float to integer (direct constructor - truncates like C-style cast) +for F in Floats + for I in SignedInts + @eval @overlay $I(x::$F) = Intrinsics.ftoi(x, $I, SignednessSigned) + end + for I in UnsignedInts + @eval @overlay $I(x::$F) = Intrinsics.ftoi(x, $I, SignednessUnsigned) + end +end diff --git a/test/codegen.jl b/test/codegen.jl index e452e31..4da6e0f 100644 --- a/test/codegen.jl +++ b/test/codegen.jl @@ -407,7 +407,7 @@ @check "select" result = a .< b # Use same-typed operands for where to avoid Union type - b_promoted = ct.astype(b, Int64) + b_promoted = convert(ct.Tile{Int64}, b) selected = ct.where(result, a, b_promoted) ct.store(out, Int32(0), selected) return @@ -779,7 +779,7 @@ pid = ct.bid(1) tile = ct.load(a, pid, (16,)) @check "ftof" - converted = ct.astype(tile, Float16) + converted = convert(ct.Tile{Float16}, tile) ct.store(b, pid, converted) return end @@ -793,7 +793,7 @@ tile = ct.load(a, pid, (16,)) @check "ftof" converted = convert(ct.Tile{ct.TFloat32}, tile) - ct.store(b, pid, ct.astype(converted, Float32)) + ct.store(b, pid, convert(ct.Tile{Float32}, converted)) return end end @@ -806,7 +806,45 @@ tile = ct.load(a, pid, (16,)) @check "ftof" converted = convert(ct.Tile{ct.BFloat16}, tile) - ct.store(b, pid, ct.astype(converted, Float32)) + ct.store(b, pid, convert(ct.Tile{Float32}, converted)) + return + end + end + + # Broadcasting syntax: Float16.(tile) + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float16,1,spec1d}}) do a, b + pid = ct.bid(1) + tile = ct.load(a, pid, (16,)) + @check "ftof" + ct.store(b, pid, Float16.(tile)) + return + end + end + + # Broadcasting syntax: BFloat16.(tile) + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b + pid = ct.bid(1) + tile = ct.load(a, pid, (16,)) + @check "ftof" + @check "ftof" + ct.store(b, pid, Float32.(ct.BFloat16.(tile))) + return + end + end + + # Broadcasting syntax: TFloat32.(tile) + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b + pid = ct.bid(1) + tile = ct.load(a, pid, (16,)) + @check "ftof" + @check "ftof" + ct.store(b, pid, Float32.(ct.TFloat32.(tile))) return end end diff --git a/test/execution.jl b/test/execution.jl index 6a8d6b0..8e7549d 100644 --- a/test/execution.jl +++ b/test/execution.jl @@ -2269,7 +2269,7 @@ end tile = ct.load(a, ct.bid(1), (tileSz[],)) mask = tile .> 0.0f0 result = any(mask; dims=1) - ct.store(b, ct.bid(1), ct.astype(result, Int32)) + ct.store(b, ct.bid(1), convert(ct.Tile{Int32}, result)) return nothing end @@ -2278,7 +2278,7 @@ end tile = ct.load(a, ct.bid(1), (tileSz[],)) mask = tile .> 0.0f0 result = all(mask; dims=1) - ct.store(b, ct.bid(1), ct.astype(result, Int32)) + ct.store(b, ct.bid(1), convert(ct.Tile{Int32}, result)) return nothing end diff --git a/test/ext/DLFP8TypesExt.jl b/test/ext/DLFP8TypesExt.jl index 472dbcd..100a88e 100644 --- a/test/ext/DLFP8TypesExt.jl +++ b/test/ext/DLFP8TypesExt.jl @@ -12,7 +12,7 @@ spec1d = ct.ArraySpec{1}(16, true) tile = ct.load(a, pid, (16,)) @check "ftof" converted = convert(ct.Tile{Float8_E4M3FN}, tile) - ct.store(b, pid, ct.astype(converted, Float32)) + ct.store(b, pid, convert(ct.Tile{Float32}, converted)) return end end @@ -25,7 +25,7 @@ end tile = ct.load(a, pid, (16,)) @check "ftof" converted = convert(ct.Tile{Float8_E5M2}, tile) - ct.store(b, pid, ct.astype(converted, Float32)) + ct.store(b, pid, convert(ct.Tile{Float32}, converted)) return end end