88from luisa_lang .utils import IdentityDict , check_type , is_generic_class
99import luisa_lang .hir as hir
1010from luisa_lang .hir import PyTreeStructure
11- from typing import Any , Callable , Dict , List , Mapping , Optional , Sequence , Tuple , Type , Union , cast
11+ from typing import (
12+ Any ,
13+ Callable ,
14+ Dict ,
15+ List ,
16+ Mapping ,
17+ Optional ,
18+ Sequence ,
19+ Tuple ,
20+ Type ,
21+ Union ,
22+ cast ,
23+ )
1224
1325
1426class Scope :
@@ -43,12 +55,12 @@ class FuncTracer:
4355 locals : List [hir .Var ]
4456 params : List [hir .Var ]
4557 scopes : List [Scope ]
46- ret_type : hir . Type | None
58+ ret_type : Type [ "JitVar" ] | None
4759 func_globals : Dict [str , Any ]
4860 name : str
4961 entry_bb : hir .BasicBlock
5062
51- def __init__ (self , name :str , func_globals : Dict [str , Any ]):
63+ def __init__ (self , name : str , func_globals : Dict [str , Any ]):
5264 self .locals = []
5365 self .py_locals = {}
5466 self .params = []
@@ -76,7 +88,6 @@ def create_var(self, name: str, ty: hir.Type, is_param: bool) -> hir.Var:
7688 self .params .append (var )
7789 return var
7890
79-
8091 def add_py_var (self , name : str , obj : object ):
8192 assert not isinstance (obj , JitVar )
8293 if name in self .py_locals :
@@ -119,7 +130,7 @@ def set_var(self, key: str, value: Any) -> None:
119130 else :
120131 self .py_locals [key ] = value
121132
122- def check_return_type (self , ty :hir . Type ) :
133+ def check_return_type (self , ty : Type [ "JitVar" ]) -> None :
123134 if self .ret_type is None :
124135 self .ret_type = ty
125136 else :
@@ -130,7 +141,7 @@ def check_return_type(self, ty:hir.Type):
130141
131142 def cur_bb (self ) -> hir .BasicBlock :
132143 return self .scopes [- 1 ].bb
133-
144+
134145 def set_cur_bb (self , bb : hir .BasicBlock ) -> None :
135146 """
136147 Set the current basic block to `bb`
@@ -142,7 +153,14 @@ def finalize(self) -> hir.Function:
142153 assert len (self .scopes ) == 1
143154 entry_bb = self .entry_bb
144155 assert self .ret_type is not None
145- return hir .Function (self .name , self .params , self .locals , entry_bb , self .ret_type )
156+ return hir .Function (
157+ self .name ,
158+ self .params ,
159+ self .locals ,
160+ entry_bb ,
161+ self .ret_type ,
162+ self .ret_type .hir_type (),
163+ )
146164
147165
148166FUNC_STACK : List [FuncTracer ] = []
@@ -160,7 +178,6 @@ def push_to_current_bb[T: hir.Node](node: T) -> T:
160178 return current_func ().cur_bb ().append (node )
161179
162180
163-
164181class Symbolic :
165182 node : hir .Value
166183 scope : Scope
@@ -178,7 +195,9 @@ class FlattenedTree:
178195 children : List ["FlattenedTree" ]
179196
180197 def __init__ (
181- self , metadata : Tuple [Type [Any ], Tuple [Any ], Any ], children : List ["FlattenedTree" ]
198+ self ,
199+ metadata : Tuple [Type [Any ], Tuple [Any ], Any ],
200+ children : List ["FlattenedTree" ],
182201 ):
183202 self .metadata = metadata
184203 self .children = children
@@ -214,8 +233,8 @@ def structure(self) -> hir.PyTreeStructure:
214233 return hir .PyTreeStructure (
215234 (typ , self .metadata [1 ], self .metadata [2 ]), children
216235 )
217-
218- def collect_jitvars (self ) -> List [' JitVar' ]:
236+
237+ def collect_jitvars (self ) -> List [" JitVar" ]:
219238 """
220239 Collect all JitVar instances from the flattened tree
221240 """
@@ -275,7 +294,7 @@ class JitVar:
275294 __symbolic__ : Optional [Symbolic ]
276295 dtype : type [Any ]
277296
278- def __init__ (self , dtype :type [Any ]):
297+ def __init__ (self , dtype : type [Any ]):
279298 """
280299 Zero-initialize a variable with given data type
281300 """
@@ -316,6 +335,14 @@ def from_hir_node[T: JitVar](cls: type[T], node: hir.Value) -> T:
316335 instance .dtype = cls
317336 return instance
318337
338+ @classmethod
339+ def hir_type (cls ) -> hir .Type :
340+ """
341+ Get the HIR type of the JitVar
342+ """
343+ # TODO: handle generic types
344+ return hir .get_dsl_type (cls ).default ()
345+
319346 def symbolic (self ) -> Symbolic :
320347 """
321348 Retrieve the internal symbolic representation of the variable. This is used for internal DSL code generation.
@@ -423,7 +450,8 @@ def unflatten_primitive(tree: FlattenedTree) -> Any:
423450
424451 def flatten_list (obj : List [Any ]) -> FlattenedTree :
425452 return FlattenedTree (
426- (list , cast (Tuple [Any , ...], tuple ()), None ), [tree_flatten (o , True ) for o in obj ]
453+ (list , cast (Tuple [Any , ...], tuple ()), None ),
454+ [tree_flatten (o , True ) for o in obj ],
427455 )
428456
429457 def unflatten_list (tree : FlattenedTree ) -> List [Any ]:
@@ -435,7 +463,8 @@ def unflatten_list(tree: FlattenedTree) -> List[Any]:
435463
436464 def flatten_tuple (obj : Tuple [Any , ...]) -> FlattenedTree :
437465 return FlattenedTree (
438- (tuple , cast (Tuple [Any , ...], tuple ()), None ), [tree_flatten (o , True ) for o in obj ]
466+ (tuple , cast (Tuple [Any , ...], tuple ()), None ),
467+ [tree_flatten (o , True ) for o in obj ],
439468 )
440469
441470 def unflatten_tuple (tree : FlattenedTree ) -> Tuple [Any , ...]:
@@ -491,7 +520,9 @@ def create_intrinsic_node[T: JitVar](
491520 elif isinstance (a , hir .Value ):
492521 nodes .append (a )
493522 else :
494- raise ValueError (f"Argument [{ i } ] `{ a } ` of type { type (a )} is not a valid DSL variable or HIR node" )
523+ raise ValueError (
524+ f"Argument [{ i } ] `{ a } ` of type { type (a )} is not a valid DSL variable or HIR node"
525+ )
495526 if ret_type is not None :
496527 ret_dsl_type = hir .get_dsl_type (ret_type ).default ()
497528 if ret_dsl_type is None :
@@ -500,25 +531,28 @@ def create_intrinsic_node[T: JitVar](
500531 ret_dsl_type = hir .UnitType ()
501532 return push_to_current_bb (hir .Intrinsic (name , nodes , ret_dsl_type ))
502533
534+
503535def __escape__ (x : Any ) -> Any :
504536 return x
505537
538+
506539def __intrinsic_checked__ [T ](
507540 name : str , arg_types : Sequence [Any ], ret_type : type [T ], * args
508541) -> T :
509542 """
510543 Call an intrinsic function with type checking.
511544 """
512- assert len (args ) == len (arg_types ), (
513- f"Intrinsic { name } expects { len ( arg_types ) } arguments, got { len ( args ) } "
514- )
545+ assert len (args ) == len (
546+ arg_types
547+ ), f"Intrinsic { name } expects { len ( arg_types ) } arguments, got { len ( args ) } "
515548 for i , (arg , arg_type ) in enumerate (zip (args , arg_types )):
516549 if not check_type (arg_type , arg ):
517550 raise ValueError (
518551 f"Argument { i } of intrinsic { name } is not of type { arg_type } , got { type (arg )} "
519552 )
520553 return __intrinsic__ (name , ret_type , * args )
521554
555+
522556def __intrinsic__ [T ](name : str , ret_type : type [T ], * args ) -> T :
523557 """
524558 Call an intrinsic function. This function does not check the arguemnts.
@@ -572,19 +606,21 @@ def on_exit(self) -> None:
572606 """
573607 pass
574608
609+
575610class ScopeGuard :
576611 def __enter__ (self ) -> Scope :
577612 """
578613 Enter a new scope
579614 """
580615 return current_func ().push_scope ()
581-
616+
582617 def __exit__ (self , exc_type , exc_val , exc_tb ):
583618 """
584619 Exit the current scope
585620 """
586621 current_func ().pop_scope ()
587622
623+
588624class IfFrame (ControlFlowFrame ):
589625 static_cond : Optional [bool ]
590626 true_bb : Optional [hir .BasicBlock ]
@@ -625,12 +661,16 @@ def on_exit(self) -> None:
625661 cond = self .cond
626662 assert isinstance (cond , JitVar ), "Condition must be a DSL variable"
627663 merge_bb = hir .BasicBlock ()
628- if_stmt = hir .If (cond .symbolic ().node ,
629- cast (hir .BasicBlock , self .true_bb ),
630- cast (hir .BasicBlock , self .false_bb ), merge_bb )
664+ if_stmt = hir .If (
665+ cond .symbolic ().node ,
666+ cast (hir .BasicBlock , self .true_bb ),
667+ cast (hir .BasicBlock , self .false_bb ),
668+ merge_bb ,
669+ )
631670 push_to_current_bb (if_stmt )
632671 current_func ().set_cur_bb (merge_bb )
633672
673+
634674class ControlFrameGuard [T : ControlFlowFrame ]:
635675 cf_type : type [T ]
636676 args : Tuple [Any , ...]
@@ -707,31 +747,52 @@ def __exit__(self, exc_type, exc_val, exc_tb):
707747 "GtE" : ["__ge__" , "__le__" ],
708748}
709749
750+
751+ class LineTable :
752+ span_of_line : Dict [int , hir .Span ]
753+
754+
710755class TraceContext :
711756 cf_frame : ControlFlowFrame
712757 is_top_level : bool
713758 top_level_func : Optional [hir .Function ]
759+ line_table : Optional [LineTable ] # for better error reporting
760+ current_line : Optional [int ] = None
714761
715762 def __init__ (self , is_top_level ):
716763 self .cf_frame = ControlFlowFrame (parent = None )
717764 self .is_top_level = is_top_level
718765 self .top_level_func = None
719766
767+ def set_line_table (self , line_table : LineTable ) -> None :
768+ self .line_table = line_table
769+
770+ def set_current_line (self , line : int ) -> None :
771+ self .current_line = line
772+
773+ def current_span (self ) -> hir .Span | None :
774+ """
775+ Get the current span for error reporting
776+ """
777+ if self .line_table is not None and self .current_line is not None :
778+ return self .line_table .span_of_line .get (self .current_line , None )
779+ return None
780+
720781 def is_parent_static (self ) -> bool :
721782 return self .cf_frame .is_static
722783
723784 def if_ (self , cond : Any ) -> ControlFrameGuard [IfFrame ]:
724785 return ControlFrameGuard (self , IfFrame , cond )
725-
786+
726787 def scope (self ) -> ScopeGuard :
727788 return ScopeGuard ()
728789
729790 def return_ (self , expr : JitVar ) -> None :
730791 """
731792 Return a value from the current function
732793 """
733- ty = expr . _symbolic_type ()
734- current_func ().check_return_type (ty )
794+ assert isinstance ( expr , JitVar ), "Return expression must be a DSL variable"
795+ current_func ().check_return_type (type ( expr )) # TODO: handle generics
735796 push_to_current_bb (hir .Return (expr .symbolic ().node ))
736797
737798 def redirect_binary (self , op , x , y ):
@@ -745,7 +806,7 @@ def redirect_binary(self, op, x, y):
745806 raise ValueError (
746807 f"Binary operation { op } not supported for { type (x )} and { type (y )} "
747808 )
748-
809+
749810 def redirect_cmp (self , op , x , y ):
750811 op , rop = CMP_OP_TO_METHOD_NAMES [op ]
751812 if hasattr (x , op ):
@@ -758,11 +819,11 @@ def redirect_cmp(self, op, x, y):
758819 )
759820
760821 def redirect_call (self , f , * args , ** kwargs ):
761- return f (* args , ** kwargs , __lc_ctx__ = self ) # TODO: shoould not always pass self
822+ return f (* args , ** kwargs , __lc_ctx__ = self ) # TODO: shoould not always pass self
762823
763824 def intrinsic (self , f , * args , ** kwargs ):
764825 return __intrinsic__ (f , * args , ** kwargs )
765-
826+
766827 def intrinsic_checked (self , f , arg_types , ret_type , * args ):
767828 return __intrinsic_checked__ (f , arg_types , ret_type , * args )
768829
@@ -837,7 +898,7 @@ def _invoke_function_tracer(
837898 trace_ctx = TraceContext (False )
838899
839900 # args is Type | object
840- func_tracer = FuncTracer (f .__name__ .replace ('.' , '_' ), globalns )
901+ func_tracer = FuncTracer (f .__name__ .replace ("." , "_" ), globalns )
841902 FUNC_STACK .append (func_tracer )
842903 try :
843904 args_vars , kwargs_vars , jit_vars = _encode_func_args (args )
@@ -852,7 +913,7 @@ class KernelTracer:
852913 top_level_tracer : FuncTracer
853914
854915 def __init__ (self , func_globals : Dict [str , Any ]):
855- self .top_level_tracer = FuncTracer (' __kernel__' , func_globals )
916+ self .top_level_tracer = FuncTracer (" __kernel__" , func_globals )
856917
857918 def __enter__ (self ) -> FuncTracer :
858919 FUNC_STACK .append (self .top_level_tracer )
0 commit comments