From 97e3689a5b4766b23a2e860c9f1397291b5c9cc8 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Fri, 9 May 2025 01:06:00 -0400 Subject: [PATCH 1/2] Add support for `ast.IfExp` in kernels Signed-off-by: fairywreath --- warp/codegen.py | 17 +++++++++++++++++ warp/tests/test_codegen.py | 19 +++++++++++++++++++ warp/tests/test_conditional.py | 24 ++++++++++++++++++++++++ 3 files changed, 60 insertions(+) diff --git a/warp/codegen.py b/warp/codegen.py index bb394b4d03..0974f59500 100644 --- a/warp/codegen.py +++ b/warp/codegen.py @@ -1771,6 +1771,22 @@ def emit_If(adj, node): out = adj.add_builtin_call("where", [cond, var1, var2]) adj.symbols[sym] = out + def emit_IfExp(adj, node): + cond = adj.eval(node.test) + + if cond.constant is not None: + return adj.eval(node.body) if cond.constant else adj.eval(node.orelse) + + adj.begin_if(cond) + body = adj.eval(node.body) + adj.end_if(cond) + + adj.begin_else(cond) + orelse = adj.eval(node.orelse) + adj.end_else(cond) + + return adj.add_builtin_call("where", [cond, body, orelse]) + def emit_Compare(adj, node): # node.left, node.ops (list of ops), node.comparators (things to compare to) # e.g. (left ops[0] node.comparators[0]) ops[1] node.comparators[1] @@ -2780,6 +2796,7 @@ def emit_Pass(adj, node): node_visitors: ClassVar[dict[type[ast.AST], Callable]] = { ast.FunctionDef: emit_FunctionDef, ast.If: emit_If, + ast.IfExp: emit_IfExp, ast.Compare: emit_Compare, ast.BoolOp: emit_BoolOp, ast.Name: emit_Name, diff --git a/warp/tests/test_codegen.py b/warp/tests/test_codegen.py index 14006d06dd..7c68ce9a35 100644 --- a/warp/tests/test_codegen.py +++ b/warp/tests/test_codegen.py @@ -693,6 +693,22 @@ def test_codegen_return_in_kernel(test, device): test.assertEqual(result.numpy()[0], grid_size - 256) +@wp.kernel +def conditional_ifexp(x: float, result: wp.array(dtype=wp.int32)): + wp.atomic_add(result, 0, 1) if x > 0.0 else wp.atomic_add(result, 1, 1) + + +def test_ifexp_only_executes_one_branch(test, device): + result = wp.zeros(2, dtype=wp.int32, device=device) + + wp.launch(conditional_ifexp, dim=1, inputs=[1.0, result], device=device) + + values = result.numpy() + # Only first branch is taken + test.assertEqual(values[0], 1) + test.assertEqual(values[1], 0) + + @wp.kernel def test_multiple_return_values_quat_to_axis_angle_kernel( q: wp.quath, @@ -941,6 +957,9 @@ class TestCodeGen(unittest.TestCase): add_kernel_test(TestCodeGen, name="test_shadow_builtin", kernel=test_shadow_builtin, dim=1, devices=devices) add_kernel_test(TestCodeGen, name="test_while_condition_eval", kernel=test_while_condition_eval, dim=1, devices=devices) add_function_test(TestCodeGen, "test_codegen_return_in_kernel", test_codegen_return_in_kernel, devices=devices) +add_function_test( + TestCodeGen, "test_ifexp_only_executes_one_branch", test_ifexp_only_executes_one_branch, devices=devices +) add_function_test( TestCodeGen, func=test_multiple_return_values, diff --git a/warp/tests/test_conditional.py b/warp/tests/test_conditional.py index 659ba2741c..cd3c92cf77 100644 --- a/warp/tests/test_conditional.py +++ b/warp/tests/test_conditional.py @@ -58,6 +58,28 @@ def test_conditional_if_else_nested(): wp.expect_eq(e, -2.0) +@wp.kernel +def test_conditional_ifexp(): + a = 0.5 + b = 2.0 + + c = 1.0 if a > b else -1.0 + + wp.expect_eq(c, -1.0) + + +@wp.kernel +def test_conditional_ifexp_nested(): + a = 1.0 + b = 2.0 + + c = 3.0 if a > b else 6.0 + d = 4.0 if a > b else 7.0 + e = 1.0 if (a > b and c > d) else (-1.0 if a > b else (2.0 if c > d else -2.0)) + + wp.expect_eq(e, -2.0) + + @wp.kernel def test_boolean_and(): a = 1.0 @@ -231,6 +253,8 @@ class TestConditional(unittest.TestCase): add_kernel_test(TestConditional, kernel=test_conditional_if_else, dim=1, devices=devices) add_kernel_test(TestConditional, kernel=test_conditional_if_else_nested, dim=1, devices=devices) +add_kernel_test(TestConditional, kernel=test_conditional_ifexp, dim=1, devices=devices) +add_kernel_test(TestConditional, kernel=test_conditional_ifexp_nested, dim=1, devices=devices) add_kernel_test(TestConditional, kernel=test_boolean_and, dim=1, devices=devices) add_kernel_test(TestConditional, kernel=test_boolean_or, dim=1, devices=devices) add_kernel_test(TestConditional, kernel=test_boolean_compound, dim=1, devices=devices) From dfdd8ef85df6b9bb147275dadb5f64f170b17787 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Fri, 16 May 2025 00:13:45 -0400 Subject: [PATCH 2/2] Add ast.IfExp constant test cases Signed-off-by: fairywreath --- warp/tests/test_conditional.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/warp/tests/test_conditional.py b/warp/tests/test_conditional.py index cd3c92cf77..6cb0f9befc 100644 --- a/warp/tests/test_conditional.py +++ b/warp/tests/test_conditional.py @@ -80,6 +80,26 @@ def test_conditional_ifexp_nested(): wp.expect_eq(e, -2.0) +@wp.kernel +def test_conditional_ifexp_constant(): + a = 1.0 if False else -1.0 + b = 2.0 if 123 else -2.0 + + wp.expect_eq(a, -1.0) + wp.expect_eq(b, 2.0) + + +@wp.kernel +def test_conditional_ifexp_constant_nested(): + a = 1.0 if False else (2.0 if True else 3.0) + b = 4.0 if 0 else (5.0 if 0 else (6.0 if False else 7.0)) + c = 8.0 if False else (9.0 if False else (10.0 if 321 else 11.0)) + + wp.expect_eq(a, 2.0) + wp.expect_eq(b, 7.0) + wp.expect_eq(c, 10.0) + + @wp.kernel def test_boolean_and(): a = 1.0 @@ -255,6 +275,8 @@ class TestConditional(unittest.TestCase): add_kernel_test(TestConditional, kernel=test_conditional_if_else_nested, dim=1, devices=devices) add_kernel_test(TestConditional, kernel=test_conditional_ifexp, dim=1, devices=devices) add_kernel_test(TestConditional, kernel=test_conditional_ifexp_nested, dim=1, devices=devices) +add_kernel_test(TestConditional, kernel=test_conditional_ifexp_constant, dim=1, devices=devices) +add_kernel_test(TestConditional, kernel=test_conditional_ifexp_constant_nested, dim=1, devices=devices) add_kernel_test(TestConditional, kernel=test_boolean_and, dim=1, devices=devices) add_kernel_test(TestConditional, kernel=test_boolean_or, dim=1, devices=devices) add_kernel_test(TestConditional, kernel=test_boolean_compound, dim=1, devices=devices)