From 7be8c3e0624938b657c8485fb963e5d6e9c3c66a Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Mon, 9 Feb 2026 09:20:54 +0100 Subject: [PATCH 1/2] Add support for assertions. --- src/bytecode/encodings.jl | 13 +++++++ src/compiler/intrinsics/misc.jl | 14 ++++++++ src/language/operations.jl | 30 ++++++++++++++++ test/codegen.jl | 33 +++++++++++++++++- test/execution.jl | 62 +++++++++++++++++++++++++++++++++ 5 files changed, 151 insertions(+), 1 deletion(-) diff --git a/src/bytecode/encodings.jl b/src/bytecode/encodings.jl index bdf6db0..b759241 100644 --- a/src/bytecode/encodings.jl +++ b/src/bytecode/encodings.jl @@ -289,6 +289,19 @@ function encode_ConstantOp!(cb::CodeBuilder, result_type::TypeId, value_bytes::V return new_op!(cb) end +""" + encode_AssertOp!(cb, condition, message) + +Assert that a condition is true, killing the kernel with a message on failure. +Opcode: 5 +""" +function encode_AssertOp!(cb::CodeBuilder, condition::Value, message::String) + encode_varint!(cb.buf, Opcode.AssertOp) + encode_opattr_str!(cb, message) + encode_operand!(cb.buf, condition) + return new_op!(cb, 0) +end + """ encode_AssumeOp!(cb, result_type, value, predicate) -> Value diff --git a/src/compiler/intrinsics/misc.jl b/src/compiler/intrinsics/misc.jl index 157a26a..0b9f332 100644 --- a/src/compiler/intrinsics/misc.jl +++ b/src/compiler/intrinsics/misc.jl @@ -1,5 +1,19 @@ # miscellaneous intrinsics +# cuda_tile.assert +@eval Intrinsics begin + @noinline function assert(cond::Bool, message::String) + donotdelete(cond, message) + nothing + end +end +function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.assert), args) + cond = @something emit_value!(ctx, args[1]) throw(IRError("assert: cannot resolve condition")) + message = @something get_constant(ctx, args[2]) throw(IRError("assert: requires constant message")) + encode_AssertOp!(ctx.cb, cond.v, message) + nothing # no result value +end + # XXX: cuda_tile.assume # make this a pass? function emit_assume_ops!(ctx::CGCtx, array_val::Value, size_vals::Vector{Value}, diff --git a/src/language/operations.jl b/src/language/operations.jl index cbf8597..161ad08 100644 --- a/src/language/operations.jl +++ b/src/language/operations.jl @@ -935,3 +935,33 @@ br = ct.extract(tile, (2, 2), (4, 4)) # Bottom-right (rows 5-8, cols 5-8) Intrinsics.extract(tile, map(i -> i - 1, index), shape) @inline extract(tile::Tile{T}, ::Val{Index}, ::Val{Shape}) where {T, Index, Shape} = Intrinsics.extract(tile, map(i -> i - 1, Index), Shape) + +#============================================================================= + Assert +=============================================================================# + +public @assert + +""" + @assert cond [message] + +Assert that `cond` is true, aborting the kernel with `message` on failure. +If no message is given, the stringified condition is used. + +Works like `Base.@assert` but compiles to a Tile IR assert op. +Failed assertions are **fatal** — they crash the kernel and corrupt the +CUDA context (not a catchable exception). + +# Examples +```julia +ct.@assert bid > Int32(0) +ct.@assert bid > Int32(0) "bid must be positive" +``` +""" +macro assert(cond) + msg = string(cond) + :($(Intrinsics.assert)($(esc(cond)), $msg)) +end +macro assert(cond, msg) + :($(Intrinsics.assert)($(esc(cond)), $(esc(msg)))) +end diff --git a/test/codegen.jl b/test/codegen.jl index 4da6e0f..cb9c2fc 100644 --- a/test/codegen.jl +++ b/test/codegen.jl @@ -855,7 +855,38 @@ 8.5 Control Flow =========================================================================# @testset "Control Flow" begin - # TODO: assert - runtime assertion + @testset "@assert with message" begin + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}}) do a + bid = ct.bid(1) + @check "cmpi" + ct.@assert bid > Int32(0) "bid must be positive" + @check "assert" + @check "bid must be positive" + tile = ct.load(a, bid, (16,)) + ct.store(a, bid, tile) + return + end + end + end + + @testset "@assert without message" begin + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}}) do a + bid = ct.bid(1) + @check "cmpi" + ct.@assert bid > Int32(0) + @check "assert" + # Auto-generated message from the expression + @check "bid > Int32(0)" + tile = ct.load(a, bid, (16,)) + ct.store(a, bid, tile) + return + end + end + end @testset "if with empty branch" begin # Empty if branches must emit YieldOp to satisfy MLIR block terminator requirements diff --git a/test/execution.jl b/test/execution.jl index 8e7549d..e01adb0 100644 --- a/test/execution.jl +++ b/test/execution.jl @@ -3193,3 +3193,65 @@ end # invalidations end end +@testset "assert" begin + @testset "passing assertion with message" begin + function assert_msg_kernel(a::ct.TileArray{Float32,1}, tile_size::ct.Constant{Int}) + bid = ct.bid(1) + ct.@assert bid > Int32(0) "bid must be positive" + t = ct.load(a, bid, (tile_size[],)) + ct.store(a, bid, t) + return + end + + a = CUDA.ones(Float32, 1024) + ct.launch(assert_msg_kernel, cld(1024, 128), a, ct.Constant(128)) + CUDA.synchronize() + @test all(Array(a) .== 1.0f0) + end + + @testset "passing assertion without message" begin + function assert_nomsg_kernel(a::ct.TileArray{Float32,1}, tile_size::ct.Constant{Int}) + bid = ct.bid(1) + ct.@assert bid > Int32(0) + t = ct.load(a, bid, (tile_size[],)) + ct.store(a, bid, t) + return + end + + a = CUDA.ones(Float32, 1024) + ct.launch(assert_nomsg_kernel, cld(1024, 128), a, ct.Constant(128)) + CUDA.synchronize() + @test all(Array(a) .== 1.0f0) + end + + @testset "failing assertion" begin + # Failed assertions crash the CUDA context, so we must test in a subprocess + # (following the same pattern as cuTile Python's test_assert.py) + script = """ + using CUDA + import cuTile as ct + + function assert_fail_kernel(a::ct.TileArray{Float32,1}, tile_size::ct.Constant{Int}) + bid = ct.bid(1) + ct.@assert bid > Int32(999999) "custom assert message" + t = ct.load(a, bid, (tile_size[],)) + ct.store(a, bid, t) + return + end + + a = CUDA.ones(Float32, 1024) + ct.launch(assert_fail_kernel, cld(1024, 128), a, ct.Constant(128)) + CUDA.synchronize() + """ + cmd = `$(Base.julia_cmd()) --project=$(Base.active_project()) -e $script` + output = Pipe() + proc = run(pipeline(ignorestatus(cmd); stdout=output, stderr=output); wait=false) + close(output.in) + reader = @async read(output, String) + wait(proc) + result = fetch(reader) + @test proc.exitcode != 0 + @test contains(result, "custom assert message") + end +end + From df26257ff41ef6dedf929be4cf0c8cb2aedaa709 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Mon, 9 Feb 2026 09:34:09 +0100 Subject: [PATCH 2/2] Make `Base.@assert` uses into errors. --- src/bytecode/basic.jl | 5 ++++- src/bytecode/writer.jl | 4 +++- src/compiler/codegen/kernel.jl | 8 ++++++-- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/bytecode/basic.jl b/src/bytecode/basic.jl index ebf327b..a36b5c6 100644 --- a/src/bytecode/basic.jl +++ b/src/bytecode/basic.jl @@ -11,7 +11,10 @@ Encode an unsigned integer using variable-length encoding (LEB128-style). Each byte uses 7 bits for data and 1 bit to indicate continuation. """ function encode_varint!(buf::Vector{UInt8}, x::Integer) - @assert x >= 0 "Varint encoding requires non-negative integers, got $x" + if x < 0 + throw(ArgumentError("Varint encoding requires non-negative integers, got $x")) + end + # Handle zero specially if x == 0 push!(buf, 0x00) diff --git a/src/bytecode/writer.jl b/src/bytecode/writer.jl index 7009e50..acac1b8 100644 --- a/src/bytecode/writer.jl +++ b/src/bytecode/writer.jl @@ -496,7 +496,9 @@ function write_bytecode!(f::Function, num_functions::Int) # Let user build functions f(writer, func_buf) - @assert writer.num_functions == num_functions "Expected $num_functions functions, got $(writer.num_functions)" + if writer.num_functions != num_functions + throw(ArgumentError("Expected $num_functions functions, got $(writer.num_functions)")) + end # Build final output buf = UInt8[] diff --git a/src/compiler/codegen/kernel.jl b/src/compiler/codegen/kernel.jl index 0caafee..1c80d31 100644 --- a/src/compiler/codegen/kernel.jl +++ b/src/compiler/codegen/kernel.jl @@ -92,7 +92,9 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8}, if field === nothing # Regular argument - create concrete CGVal - @assert length(values) == 1 + if length(values) != 1 + throw(IRError("Expected exactly one value for argument $arg_idx, got $(length(values))")) + end val = values[1] type_id = tile_type_for_julia!(ctx, sci.argtypes[arg_idx]) tv = CGVal(val, type_id, sci.argtypes[arg_idx]) @@ -252,7 +254,9 @@ function emit_subprogram!(ctx::CGCtx, func, arg_types::Vector, ) # 2. Compile through cuTile pipeline (cached) - @assert haskey(ctx.cache, mi) "Expected $func($(join(arg_types, ", "))) to be cached already by inference." + if !haskey(ctx.cache, mi) + error("Expected $func($(join(arg_types, ", "))) to be cached already by inference.") + end sci, _ = emit_ir(ctx.cache, mi) # 3. Create sub-context