Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/bytecode/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ Encode an unsigned integer using variable-length encoding (LEB128-style).
Each byte uses 7 bits for data and 1 bit to indicate continuation.
"""
function encode_varint!(buf::Vector{UInt8}, x::Integer)
@assert x >= 0 "Varint encoding requires non-negative integers, got $x"
if x < 0
throw(ArgumentError("Varint encoding requires non-negative integers, got $x"))
end

# Handle zero specially
if x == 0
push!(buf, 0x00)
Expand Down
13 changes: 13 additions & 0 deletions src/bytecode/encodings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,19 @@ function encode_ConstantOp!(cb::CodeBuilder, result_type::TypeId, value_bytes::V
return new_op!(cb)
end

"""
encode_AssertOp!(cb, condition, message)

Assert that a condition is true, killing the kernel with a message on failure.
Opcode: 5
"""
function encode_AssertOp!(cb::CodeBuilder, condition::Value, message::String)
encode_varint!(cb.buf, Opcode.AssertOp)
encode_opattr_str!(cb, message)
encode_operand!(cb.buf, condition)
return new_op!(cb, 0)
end

"""
encode_AssumeOp!(cb, result_type, value, predicate) -> Value

Expand Down
4 changes: 3 additions & 1 deletion src/bytecode/writer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,9 @@ function write_bytecode!(f::Function, num_functions::Int)
# Let user build functions
f(writer, func_buf)

@assert writer.num_functions == num_functions "Expected $num_functions functions, got $(writer.num_functions)"
if writer.num_functions != num_functions
throw(ArgumentError("Expected $num_functions functions, got $(writer.num_functions)"))
end

# Build final output
buf = UInt8[]
Expand Down
8 changes: 6 additions & 2 deletions src/compiler/codegen/kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},

if field === nothing
# Regular argument - create concrete CGVal
@assert length(values) == 1
if length(values) != 1
throw(IRError("Expected exactly one value for argument $arg_idx, got $(length(values))"))
end
val = values[1]
type_id = tile_type_for_julia!(ctx, sci.argtypes[arg_idx])
tv = CGVal(val, type_id, sci.argtypes[arg_idx])
Expand Down Expand Up @@ -252,7 +254,9 @@ function emit_subprogram!(ctx::CGCtx, func, arg_types::Vector,
)

# 2. Compile through cuTile pipeline (cached)
@assert haskey(ctx.cache, mi) "Expected $func($(join(arg_types, ", "))) to be cached already by inference."
if !haskey(ctx.cache, mi)
error("Expected $func($(join(arg_types, ", "))) to be cached already by inference.")
end
sci, _ = emit_ir(ctx.cache, mi)

# 3. Create sub-context
Expand Down
14 changes: 14 additions & 0 deletions src/compiler/intrinsics/misc.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
# miscellaneous intrinsics

# cuda_tile.assert
@eval Intrinsics begin
@noinline function assert(cond::Bool, message::String)
donotdelete(cond, message)
nothing
end
end
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.assert), args)
cond = @something emit_value!(ctx, args[1]) throw(IRError("assert: cannot resolve condition"))
message = @something get_constant(ctx, args[2]) throw(IRError("assert: requires constant message"))
encode_AssertOp!(ctx.cb, cond.v, message)
nothing # no result value
end

# XXX: cuda_tile.assume
# make this a pass?
function emit_assume_ops!(ctx::CGCtx, array_val::Value, size_vals::Vector{Value},
Expand Down
30 changes: 30 additions & 0 deletions src/language/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -935,3 +935,33 @@ br = ct.extract(tile, (2, 2), (4, 4)) # Bottom-right (rows 5-8, cols 5-8)
Intrinsics.extract(tile, map(i -> i - 1, index), shape)
@inline extract(tile::Tile{T}, ::Val{Index}, ::Val{Shape}) where {T, Index, Shape} =
Intrinsics.extract(tile, map(i -> i - 1, Index), Shape)

#=============================================================================
Assert
=============================================================================#

public @assert

"""
@assert cond [message]

Assert that `cond` is true, aborting the kernel with `message` on failure.
If no message is given, the stringified condition is used.

Works like `Base.@assert` but compiles to a Tile IR assert op.
Failed assertions are **fatal** — they crash the kernel and corrupt the
CUDA context (not a catchable exception).

# Examples
```julia
ct.@assert bid > Int32(0)
ct.@assert bid > Int32(0) "bid must be positive"
```
"""
macro assert(cond)
msg = string(cond)
:($(Intrinsics.assert)($(esc(cond)), $msg))
end
macro assert(cond, msg)
:($(Intrinsics.assert)($(esc(cond)), $(esc(msg))))
end
33 changes: 32 additions & 1 deletion test/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -855,7 +855,38 @@
8.5 Control Flow
=========================================================================#
@testset "Control Flow" begin
# TODO: assert - runtime assertion
@testset "@assert with message" begin
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}}) do a
bid = ct.bid(1)
@check "cmpi"
ct.@assert bid > Int32(0) "bid must be positive"
@check "assert"
@check "bid must be positive"
tile = ct.load(a, bid, (16,))
ct.store(a, bid, tile)
return
end
end
end

@testset "@assert without message" begin
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}}) do a
bid = ct.bid(1)
@check "cmpi"
ct.@assert bid > Int32(0)
@check "assert"
# Auto-generated message from the expression
@check "bid > Int32(0)"
tile = ct.load(a, bid, (16,))
ct.store(a, bid, tile)
return
end
end
end

@testset "if with empty branch" begin
# Empty if branches must emit YieldOp to satisfy MLIR block terminator requirements
Expand Down
62 changes: 62 additions & 0 deletions test/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3193,3 +3193,65 @@ end # invalidations
end
end

@testset "assert" begin
@testset "passing assertion with message" begin
function assert_msg_kernel(a::ct.TileArray{Float32,1}, tile_size::ct.Constant{Int})
bid = ct.bid(1)
ct.@assert bid > Int32(0) "bid must be positive"
t = ct.load(a, bid, (tile_size[],))
ct.store(a, bid, t)
return
end

a = CUDA.ones(Float32, 1024)
ct.launch(assert_msg_kernel, cld(1024, 128), a, ct.Constant(128))
CUDA.synchronize()
@test all(Array(a) .== 1.0f0)
end

@testset "passing assertion without message" begin
function assert_nomsg_kernel(a::ct.TileArray{Float32,1}, tile_size::ct.Constant{Int})
bid = ct.bid(1)
ct.@assert bid > Int32(0)
t = ct.load(a, bid, (tile_size[],))
ct.store(a, bid, t)
return
end

a = CUDA.ones(Float32, 1024)
ct.launch(assert_nomsg_kernel, cld(1024, 128), a, ct.Constant(128))
CUDA.synchronize()
@test all(Array(a) .== 1.0f0)
end

@testset "failing assertion" begin
# Failed assertions crash the CUDA context, so we must test in a subprocess
# (following the same pattern as cuTile Python's test_assert.py)
script = """
using CUDA
import cuTile as ct

function assert_fail_kernel(a::ct.TileArray{Float32,1}, tile_size::ct.Constant{Int})
bid = ct.bid(1)
ct.@assert bid > Int32(999999) "custom assert message"
t = ct.load(a, bid, (tile_size[],))
ct.store(a, bid, t)
return
end

a = CUDA.ones(Float32, 1024)
ct.launch(assert_fail_kernel, cld(1024, 128), a, ct.Constant(128))
CUDA.synchronize()
"""
cmd = `$(Base.julia_cmd()) --project=$(Base.active_project()) -e $script`
output = Pipe()
proc = run(pipeline(ignorestatus(cmd); stdout=output, stderr=output); wait=false)
close(output.in)
reader = @async read(output, String)
wait(proc)
result = fetch(reader)
@test proc.exitcode != 0
@test contains(result, "custom assert message")
end
end