Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions src/bytecode/encodings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1331,6 +1331,78 @@ function encode_ReduceOp!(body::Function, cb::CodeBuilder,
end
end


#=============================================================================
Scan operations
=============================================================================#

"""
encode_ScanOp!(body::Function, cb::CodeBuilder,
result_types::Vector{TypeId},
operands::Vector{Value},
dim::Int,
reverse::Bool,
identities::Vector{<:IdentityVal},
body_scalar_types::Vector{TypeId})

Encode a ScanOp (parallel prefix sum) operation.

# Arguments
- body: Function that takes block args and yields result(s)
- cb: CodeBuilder for the bytecode
- result_types: Output tile types
- operands: Input tiles to scan
- dim: Dimension to scan along (0-indexed)
- reverse: Whether to scan in reverse order
- identities: Identity values for each operand
- body_scalar_types: 0D tile types for body arguments
"""
function encode_ScanOp!(body::Function, cb::CodeBuilder,
result_types::Vector{TypeId},
operands::Vector{Value},
dim::Int,
reverse::Bool,
identities::Vector{<:IdentityVal},
body_scalar_types::Vector{TypeId})
encode_varint!(cb.buf, Opcode.ScanOp)

# Variadic result types
encode_typeid_seq!(cb.buf, result_types)

# Attributes: dim (int), reverse (bool), identities (array)
encode_opattr_int!(cb, dim)
encode_opattr_bool!(cb, reverse)
encode_identity_array!(cb, identities)

# Variadic operands
encode_varint!(cb.buf, length(operands))
encode_operands!(cb.buf, operands)

# Number of regions
push!(cb.debug_attrs, cb.cur_debug_attr)
cb.num_ops += 1
encode_varint!(cb.buf, 1) # 1 region: body

# Body region - block args are pairs of (acc, elem) for each operand
# The body operates on 0D tiles (scalars)
body_arg_types = TypeId[]
for scalar_type in body_scalar_types
push!(body_arg_types, scalar_type) # accumulator
push!(body_arg_types, scalar_type) # element
end
with_region(body, cb, body_arg_types)

# Create result values
num_results = length(result_types)
if num_results == 0
return Value[]
else
vals = [Value(cb.next_value_id + i) for i in 0:num_results-1]
cb.next_value_id += num_results
return vals
end
end

#=============================================================================
Comparison and selection operations
=============================================================================#
Expand Down
11 changes: 3 additions & 8 deletions src/compiler/intrinsics/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,17 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.astype), args)
target_dtype = julia_to_tile_dtype!(tt, target_elem)
target_tile_type = tile_type!(tt, target_dtype, tile_shape)

# Determine signedness for integer types
function is_signed_int(T)
T <: Signed || T === Int32 || T === Int64 || T === Int16 || T === Int8
end

# Emit conversion based on source and target types
result = if source_elem <: AbstractFloat && target_elem <: AbstractFloat
# Float -> Float
encode_FToFOp!(cb, target_tile_type, source.v)
elseif source_elem <: Integer && target_elem <: AbstractFloat
# Integer -> Float
signedness = is_signed_int(source_elem) ? SignednessSigned : SignednessUnsigned
signedness = source_elem <: Signed ? SignednessSigned : SignednessUnsigned
encode_IToFOp!(cb, target_tile_type, source.v; signedness)
elseif source_elem <: AbstractFloat && target_elem <: Integer
# Float -> Integer
signedness = is_signed_int(target_elem) ? SignednessSigned : SignednessUnsigned
signedness = target_elem <: Signed ? SignednessSigned : SignednessUnsigned
encode_FToIOp!(cb, target_tile_type, source.v; signedness)
elseif source_elem <: Integer && target_elem <: Integer
# Integer -> Integer
Expand All @@ -66,7 +61,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.astype), args)
source.v
elseif target_size > source_size
# Extension (upsize)
signedness = is_signed_int(source_elem) ? SignednessSigned : SignednessUnsigned
signedness = source_elem <: Signed ? SignednessSigned : SignednessUnsigned
encode_ExtIOp!(cb, target_tile_type, source.v; signedness)
else
# Truncation (downsize)
Expand Down
95 changes: 76 additions & 19 deletions src/compiler/intrinsics/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ function emit_reduce!(ctx::CGCtx, args, reduce_fn::Symbol)
results = encode_ReduceOp!(cb, [output_tile_type], [input_tv.v], axis, [identity], [scalar_tile_type]) do block_args
acc, elem = block_args[1], block_args[2]

res = encode_reduce_body(cb, scalar_tile_type, acc, elem, reduce_fn, elem_type)
res = encode_binop_body(cb, scalar_tile_type, acc, elem, reduce_fn, elem_type)
encode_YieldOp!(cb, [res])
end

Expand Down Expand Up @@ -609,26 +609,18 @@ operation_identity(::Val{:max}, dtype, ::Type{T}) where T <: Integer =
IntegerIdentityVal(to_uint128(typemin(T)), dtype, T)

#=============================================================================#
# Reduce Body Operations
# Binary Operation Body Encoding (shared by reduce and scan)
#=============================================================================#
function encode_reduce_body(cb, type, acc, elem, op::Symbol, ::Type{T}) where T
function encode_binop_body(cb, type, acc, elem, op::Symbol, ::Type{T}) where T
if T <: AbstractFloat
if op == :add
encode_AddFOp!(cb, type, acc, elem)
elseif op == :max
encode_MaxFOp!(cb, type, acc, elem)
else
error("Unsupported float reduction operation: $op")
end
else # Integer
op == :add ? encode_AddFOp!(cb, type, acc, elem) :
op == :max ? encode_MaxFOp!(cb, type, acc, elem) :
error("Unsupported float operation: $op")
else
signedness = T <: Signed ? SignednessSigned : SignednessUnsigned
if op == :add
encode_AddIOp!(cb, type, acc, elem)
elseif op == :max
encode_MaxIOp!(cb, type, acc, elem; signedness)
else
error("Unsupported integer reduction operation: $op")
end
op == :add ? encode_AddIOp!(cb, type, acc, elem) :
op == :max ? encode_MaxIOp!(cb, type, acc, elem; signedness) :
error("Unsupported integer operation: $op")
end
end

Expand Down Expand Up @@ -702,7 +694,72 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.reshape), args)
CGVal(current_val, result_type_id, Tile{elem_type, Tuple(target_shape)}, target_shape)
end

# TODO: cuda_tile.scan
# cuda_tile.scan
@eval Intrinsics begin
"""
scan(tile, axis_val, fn_type; reverse=false)

Parallel prefix scan along specified dimension.
fn_type=:add for cumulative sum (only supported operation).
reverse=false for forward scan, true for reverse scan.
Compiled to cuda_tile.scan.
"""
@noinline function scan(tile::Tile{T, S}, ::Val{axis}, fn::Symbol, reverse::Bool=false) where {T, S, axis}
# Scan preserves shape - result has same dimensions as input
Tile{T, S}()
end
end

function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.scan), args)
cb = ctx.cb
tt = ctx.tt

# Get input tile
input_tv = emit_value!(ctx, args[1])
input_tv === nothing && error("Cannot resolve input tile for scan")

# Get scan axis
axis = @something get_constant(ctx, args[2]) error("Scan axis must be a compile-time constant")

# Get scan function type (only :add is supported)
fn_type = @something get_constant(ctx, args[3]) error("Scan function type must be a compile-time constant")
fn_type == :add || error("Only :add (cumulative sum) is currently supported for scan operations")

# Get reverse flag (optional, defaults to false)
reverse = false
if length(args) >= 4
reverse_val = get_constant(ctx, args[4])
reverse = reverse_val === true
end

# Get element type and shapes
input_type = unwrap_type(input_tv.jltype)
elem_type = input_type <: Tile ? input_type.parameters[1] : input_type
input_shape = input_tv.shape

# For scan, output shape is same as input shape
output_shape = copy(input_shape)

dtype = julia_to_tile_dtype!(tt, elem_type)

# Output tile type (same shape as input)
output_tile_type = tile_type!(tt, dtype, output_shape)

# Scalar type for scan body (0D tile)
scalar_tile_type = tile_type!(tt, dtype, Int[])

# Create identity value using operation_identity
identity = operation_identity(Val(fn_type), dtype, elem_type)

# Emit ScanOp
results = encode_ScanOp!(cb, [output_tile_type], [input_tv.v], axis, reverse, [identity], [scalar_tile_type]) do block_args
acc, elem = block_args[1], block_args[2]
res = encode_binop_body(cb, scalar_tile_type, acc, elem, fn_type, elem_type)
encode_YieldOp!(cb, [res])
end

CGVal(results[1], output_tile_type, Tile{elem_type, Tuple(output_shape)}, output_shape)
end

# cuda_tile.select
@eval Intrinsics begin
Expand Down
13 changes: 13 additions & 0 deletions src/language/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,19 @@ end
Intrinsics.reduce_max(tile, Val(axis - 1))
end

# Scan (Prefix Sum) Operations

@inline function scan(tile::Tile{T, S}, ::Val{axis},
fn::Symbol=:add,
reverse::Bool=false) where {T<:Number, S, axis}
Intrinsics.scan(tile, Val(axis - 1), fn, reverse)
end

@inline function cumsum(tile::Tile{T, S}, ::Val{axis},
reverse::Bool=false) where {T<:Number, S, axis}
scan(tile, Val(axis), :add, reverse)
end

#=============================================================================
Matrix multiplication
=============================================================================#
Expand Down
130 changes: 75 additions & 55 deletions test/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,63 @@
# TODO: mmai - integer matrix multiply-accumulate
# TODO: offset - tile offset computation
# TODO: pack - pack tiles
# TODO: scan - parallel scan/prefix sum
@testset "scan" begin
# Forward scan - float and integer types
for (T, spec, op_check) in [
(Float32, spec1d, "addf"),
(Int32, spec1d, "addi"),
]
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{T,1,spec}}) do a
pid = ct.bid(1)
tile = ct.load(a, pid, (16,))
@check "scan"
@check op_check
Base.donotdelete(ct.scan(tile, Val(1), :add, false))
return
end
end
end

# 2D scan along different axes
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}}) do a
pid = ct.bid(1)
tile = ct.load(a, pid, (4, 8))
@check "scan"
Base.donotdelete(ct.scan(tile, Val(1), :add, false))
@check "scan"
Base.donotdelete(ct.scan(tile, Val(2), :add, false))
return
end
end

# Reverse scan
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}}) do a
pid = ct.bid(1)
tile = ct.load(a, pid, (4, 8))
@check "scan"
Base.donotdelete(ct.scan(tile, Val(1), :add, true))
return
end
end

# cumsum convenience
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}}) do a
pid = ct.bid(1)
tile = ct.load(a, pid, (4, 8))
@check "scan"
Base.donotdelete(ct.cumsum(tile, Val(2), false))
return
end
end
end
# TODO: unpack - unpack tiles

@testset "reshape" begin
Expand Down Expand Up @@ -385,61 +441,25 @@
return
end
end
end

# Integer reduce_sum (Int32)
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{Int32,2,spec2d}, ct.TileArray{Int32,1,spec1d}}) do a, b
pid = ct.bid(1)
tile = ct.load(a, pid, (4, 16))
@check "reduce"
@check "addi"
sums = ct.reduce_sum(tile, 2)
ct.store(b, pid, sums)
return
end
end

# Integer reduce_max (Int32)
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{Int32,2,spec2d}, ct.TileArray{Int32,1,spec1d}}) do a, b
pid = ct.bid(1)
tile = ct.load(a, pid, (4, 16))
@check "reduce"
@check "maxi"
maxes = ct.reduce_max(tile, 2)
ct.store(b, pid, maxes)
return
end
end

# Unsigned reduce_sum (UInt32)
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{UInt32,2,spec2d}, ct.TileArray{UInt32,1,spec1d}}) do a, b
pid = ct.bid(1)
tile = ct.load(a, pid, (4, 16))
@check "reduce"
@check "addi"
sums = ct.reduce_sum(tile, 2)
ct.store(b, pid, sums)
return
end
end

# Unsigned reduce_max (UInt32)
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{UInt32,2,spec2d}, ct.TileArray{UInt32,1,spec1d}}) do a, b
pid = ct.bid(1)
tile = ct.load(a, pid, (4, 16))
@check "reduce"
@check "maxi"
maxes = ct.reduce_max(tile, 2)
ct.store(b, pid, maxes)
return
# Integer/unsigned reduce
for (T, op, op_check) in [
(Int32, ct.reduce_sum, "addi"),
(Int32, ct.reduce_max, "maxi"),
(UInt32, ct.reduce_sum, "addi"),
(UInt32, ct.reduce_max, "maxi"),
]
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{T,2,spec2d}}) do a
pid = ct.bid(1)
tile = ct.load(a, pid, (4, 16))
@check "reduce"
@check op_check
Base.donotdelete(op(tile, 2))
return
end
end
end
end

Expand Down
Loading