From 3e84ef6a9fb78d959041fd40c1cf93dd3d680d1a Mon Sep 17 00:00:00 2001 From: kaihsin Date: Tue, 25 Nov 2025 05:17:04 -0500 Subject: [PATCH 1/2] lambda comp --- src/kirin/dialects/lowering/func.py | 72 ++++++++++++++++++++++++++++- src/kirin/dialects/py/assign.py | 2 + test/lowering/test_lambda_comp.py | 64 +++++++++++++++++++++++++ 3 files changed, 136 insertions(+), 2 deletions(-) create mode 100644 test/lowering/test_lambda_comp.py diff --git a/src/kirin/dialects/lowering/func.py b/src/kirin/dialects/lowering/func.py index 19b57fc49..97dcf45ec 100644 --- a/src/kirin/dialects/lowering/func.py +++ b/src/kirin/dialects/lowering/func.py @@ -22,6 +22,9 @@ def lower_Return(self, state: lowering.State, node: ast.Return) -> lowering.Resu def lower_FunctionDef( self, state: lowering.State[ast.AST], node: ast.FunctionDef ) -> lowering.Result: + + frame = state.current_frame + slots = tuple(arg.arg for arg in node.args.args) self.assert_simple_arguments(node.args) signature = func.Signature( inputs=tuple( @@ -29,9 +32,7 @@ def lower_FunctionDef( ), output=self.get_hint(state, node.returns), ) - frame = state.current_frame - slots = tuple(arg.arg for arg in node.args.args) entries: dict[str, ir.SSAValue] = {} entr_block = ir.Block() fn_self = entr_block.args.append_from( @@ -109,6 +110,73 @@ def callback(frame: lowering.Frame, value: ir.SSAValue): # NOTE: Python automatically assigns the lambda to the name frame.defs[node.name] = lambda_stmt.result + def lower_Lambda( + self, state: lowering.State[ast.AST], node: ast.Lambda + ) -> lowering.Result: + + frame = state.current_frame + slots = tuple(arg.arg for arg in node.args.args) + self.assert_simple_arguments(node.args) + signature = func.Signature( + inputs=tuple( + self.get_hint(state, arg.annotation) for arg in node.args.args + ), + output=types.Any, + ) + node_name = f"lambda_0x{id(node)}" + + entries: dict[str, ir.SSAValue] = {} + entr_block = ir.Block() + fn_self = entr_block.args.append_from( + types.MethodType[list(signature.inputs), signature.output], + node_name + "_self", + ) + entries[node_name] = fn_self + for arg, type in zip(node.args.args, signature.inputs): + entries[arg.arg] = entr_block.args.append_from(type, arg.arg) + + def callback(frame: lowering.Frame, value: ir.SSAValue): + first_stmt = entr_block.first_stmt + stmt = func.GetField(obj=fn_self, field=len(frame.captures) - 1) + if value.name: + stmt.result.name = value.name + stmt.result.type = value.type + stmt.source = state.source + if first_stmt: + stmt.insert_before(first_stmt) + else: + entr_block.stmts.append(stmt) + return stmt.result + + with state.frame( + [node.body], entr_block=entr_block, capture_callback=callback + ) as func_frame: + func_frame.defs.update(entries) + func_frame.exhaust() + + last_stmt = func_frame.curr_region.blocks[0].last_stmt + rtrn_stmt = func.Return(last_stmt.result) + func_frame.curr_block.stmts.append(rtrn_stmt) + + first_stmt = func_frame.curr_region.blocks[0].first_stmt + if first_stmt is None: + raise lowering.BuildError("empty lambda body") + + func_frame.curr_region.blocks[1].delete() + + lambda_stmt = func.Lambda( + tuple(value for value in func_frame.captures.values()), + sym_name=node_name, + slots=slots, + signature=signature, + body=func_frame.curr_region, + ) + + lambda_stmt.result.name = node_name + frame.push(lambda_stmt) + frame.defs[node_name] = lambda_stmt.result + return lambda_stmt.result + def assert_simple_arguments(self, node: ast.arguments) -> None: if node.kwonlyargs: raise lowering.BuildError("keyword-only arguments are not supported") diff --git a/src/kirin/dialects/py/assign.py b/src/kirin/dialects/py/assign.py index f36e5010e..ba1511425 100644 --- a/src/kirin/dialects/py/assign.py +++ b/src/kirin/dialects/py/assign.py @@ -125,9 +125,11 @@ def lower_Assign(self, state: lowering.State, node: ast.Assign) -> lowering.Resu case ast.Assign( targets=[ast.Name(lhs_name, ast.Store())], value=ast.Name(_, ast.Load()) ): + stmt = Alias( value=result.data[0], target=ir.PyAttr(lhs_name) ) # NOTE: this is guaranteed to be one result + stmt.result.name = lhs_name current_frame.defs[lhs_name] = current_frame.push(stmt).result case _: diff --git a/test/lowering/test_lambda_comp.py b/test/lowering/test_lambda_comp.py new file mode 100644 index 000000000..39c95a4fe --- /dev/null +++ b/test/lowering/test_lambda_comp.py @@ -0,0 +1,64 @@ +from kirin import ir +from kirin.prelude import basic +from kirin.dialects import ilist + + +def test_lambda_comp_with_closure(): + @basic(fold=False) + def main(z, r): + return (lambda x: x + z)(r) + + assert main(3, 4) == 7 + + +def test_lambda_comp(): + @basic(fold=False) + def main(z): + return lambda x: x + z + + x = main(3) + assert isinstance(x, ir.Method) + assert x(4) == 7 + + +def test_invoke_from_lambda_comp(): + + @basic + def foo(a): + return a * 2 + + @basic(fold=False) + def main(z): + return lambda x: x + foo(z) + + x = main(3) + + assert isinstance(x, ir.Method) + assert x(4) == 10 + + +def test_lambda_in_lambda(): + + @basic(fold=False) + def main(z): + + def my_foo(a): + return lambda x: x * a + + return my_foo(z) + + x = main(3) + + assert isinstance(x, ir.Method) + assert x(4) == 12 + + +def test_ilist_map(): + + @basic(fold=False) + def main(z): + return ilist.map(lambda x: x + z, ilist.range(10)) + + x = main(3) + assert len(x) == 10 + assert x.data == [3, 4, 5, 6, 7, 8, 9, 10, 11, 12] From 9253bc875d9082f3d81e56f4e88f4a2f865ae57d Mon Sep 17 00:00:00 2001 From: kaihsin Date: Tue, 6 Jan 2026 15:15:55 -0500 Subject: [PATCH 2/2] fix pyright --- src/kirin/dialects/lowering/func.py | 3 ++- test/lowering/{test_lambda_comp.py => test_py_lambda.py} | 0 2 files changed, 2 insertions(+), 1 deletion(-) rename test/lowering/{test_lambda_comp.py => test_py_lambda.py} (100%) diff --git a/src/kirin/dialects/lowering/func.py b/src/kirin/dialects/lowering/func.py index 97dcf45ec..75ed7d48f 100644 --- a/src/kirin/dialects/lowering/func.py +++ b/src/kirin/dialects/lowering/func.py @@ -155,7 +155,8 @@ def callback(frame: lowering.Frame, value: ir.SSAValue): func_frame.exhaust() last_stmt = func_frame.curr_region.blocks[0].last_stmt - rtrn_stmt = func.Return(last_stmt.result) + # assert hasattr(last_stmt,"result"), "python lambda should always have a return value" + rtrn_stmt = func.Return(last_stmt) func_frame.curr_block.stmts.append(rtrn_stmt) first_stmt = func_frame.curr_region.blocks[0].first_stmt diff --git a/test/lowering/test_lambda_comp.py b/test/lowering/test_py_lambda.py similarity index 100% rename from test/lowering/test_lambda_comp.py rename to test/lowering/test_py_lambda.py