Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions warp/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1771,6 +1771,22 @@ def emit_If(adj, node):
out = adj.add_builtin_call("where", [cond, var1, var2])
adj.symbols[sym] = out

def emit_IfExp(adj, node):
cond = adj.eval(node.test)

if cond.constant is not None:
return adj.eval(node.body) if cond.constant else adj.eval(node.orelse)

adj.begin_if(cond)
body = adj.eval(node.body)
adj.end_if(cond)

adj.begin_else(cond)
orelse = adj.eval(node.orelse)
adj.end_else(cond)

return adj.add_builtin_call("where", [cond, body, orelse])

def emit_Compare(adj, node):
# node.left, node.ops (list of ops), node.comparators (things to compare to)
# e.g. (left ops[0] node.comparators[0]) ops[1] node.comparators[1]
Expand Down Expand Up @@ -2780,6 +2796,7 @@ def emit_Pass(adj, node):
node_visitors: ClassVar[dict[type[ast.AST], Callable]] = {
ast.FunctionDef: emit_FunctionDef,
ast.If: emit_If,
ast.IfExp: emit_IfExp,
ast.Compare: emit_Compare,
ast.BoolOp: emit_BoolOp,
ast.Name: emit_Name,
Expand Down
19 changes: 19 additions & 0 deletions warp/tests/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,22 @@ def test_codegen_return_in_kernel(test, device):
test.assertEqual(result.numpy()[0], grid_size - 256)


@wp.kernel
def conditional_ifexp(x: float, result: wp.array(dtype=wp.int32)):
wp.atomic_add(result, 0, 1) if x > 0.0 else wp.atomic_add(result, 1, 1)


def test_ifexp_only_executes_one_branch(test, device):
result = wp.zeros(2, dtype=wp.int32, device=device)

wp.launch(conditional_ifexp, dim=1, inputs=[1.0, result], device=device)

values = result.numpy()
# Only first branch is taken
test.assertEqual(values[0], 1)
test.assertEqual(values[1], 0)


@wp.kernel
def test_multiple_return_values_quat_to_axis_angle_kernel(
q: wp.quath,
Expand Down Expand Up @@ -941,6 +957,9 @@ class TestCodeGen(unittest.TestCase):
add_kernel_test(TestCodeGen, name="test_shadow_builtin", kernel=test_shadow_builtin, dim=1, devices=devices)
add_kernel_test(TestCodeGen, name="test_while_condition_eval", kernel=test_while_condition_eval, dim=1, devices=devices)
add_function_test(TestCodeGen, "test_codegen_return_in_kernel", test_codegen_return_in_kernel, devices=devices)
add_function_test(
TestCodeGen, "test_ifexp_only_executes_one_branch", test_ifexp_only_executes_one_branch, devices=devices
)
add_function_test(
TestCodeGen,
func=test_multiple_return_values,
Expand Down
46 changes: 46 additions & 0 deletions warp/tests/test_conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,48 @@ def test_conditional_if_else_nested():
wp.expect_eq(e, -2.0)


@wp.kernel
def test_conditional_ifexp():
a = 0.5
b = 2.0

c = 1.0 if a > b else -1.0

wp.expect_eq(c, -1.0)


@wp.kernel
def test_conditional_ifexp_nested():
a = 1.0
b = 2.0

c = 3.0 if a > b else 6.0
d = 4.0 if a > b else 7.0
e = 1.0 if (a > b and c > d) else (-1.0 if a > b else (2.0 if c > d else -2.0))

wp.expect_eq(e, -2.0)


@wp.kernel
def test_conditional_ifexp_constant():
a = 1.0 if False else -1.0
b = 2.0 if 123 else -2.0

wp.expect_eq(a, -1.0)
wp.expect_eq(b, 2.0)


@wp.kernel
def test_conditional_ifexp_constant_nested():
a = 1.0 if False else (2.0 if True else 3.0)
b = 4.0 if 0 else (5.0 if 0 else (6.0 if False else 7.0))
c = 8.0 if False else (9.0 if False else (10.0 if 321 else 11.0))

wp.expect_eq(a, 2.0)
wp.expect_eq(b, 7.0)
wp.expect_eq(c, 10.0)


@wp.kernel
def test_boolean_and():
a = 1.0
Expand Down Expand Up @@ -231,6 +273,10 @@ class TestConditional(unittest.TestCase):

add_kernel_test(TestConditional, kernel=test_conditional_if_else, dim=1, devices=devices)
add_kernel_test(TestConditional, kernel=test_conditional_if_else_nested, dim=1, devices=devices)
add_kernel_test(TestConditional, kernel=test_conditional_ifexp, dim=1, devices=devices)
add_kernel_test(TestConditional, kernel=test_conditional_ifexp_nested, dim=1, devices=devices)
add_kernel_test(TestConditional, kernel=test_conditional_ifexp_constant, dim=1, devices=devices)
add_kernel_test(TestConditional, kernel=test_conditional_ifexp_constant_nested, dim=1, devices=devices)
add_kernel_test(TestConditional, kernel=test_boolean_and, dim=1, devices=devices)
add_kernel_test(TestConditional, kernel=test_boolean_or, dim=1, devices=devices)
add_kernel_test(TestConditional, kernel=test_boolean_compound, dim=1, devices=devices)
Expand Down