77
88from luisa_lang .utils import IdentityDict , check_type , is_generic_class
99import luisa_lang .hir as hir
10- from hir import PyTreeStructure
11- from typing import Any , Callable , Dict , List , Mapping , Optional , Sequence , Tuple , Union , cast
10+ from luisa_lang . hir import PyTreeStructure
11+ from typing import Any , Callable , Dict , List , Mapping , Optional , Sequence , Tuple , Type , Union , cast
1212
1313
1414class Scope :
@@ -174,11 +174,11 @@ def __init__(self, node: hir.Value, scope: Scope | None = None):
174174
175175
176176class FlattenedTree :
177- metadata : Tuple [type , Tuple [Any ], Any ] # (type, type_args, Any)
177+ metadata : Tuple [Type [ Any ] , Tuple [Any ], Any ] # (type, type_args, Any)
178178 children : List ["FlattenedTree" ]
179179
180180 def __init__ (
181- self , metadata : Tuple [type , Tuple [Any ], Any ], children : List ["FlattenedTree" ]
181+ self , metadata : Tuple [Type [ Any ] , Tuple [Any ], Any ], children : List ["FlattenedTree" ]
182182 ):
183183 self .metadata = metadata
184184 self .children = children
@@ -342,13 +342,13 @@ def tree_flatten(obj: Any, allow_non_pytree_objects: bool) -> FlattenedTree:
342342 if isinstance (obj , PyTree ):
343343 return obj ._flatten ()
344344 if allow_non_pytree_objects and not PyTreeRegistry .is_registered (type (obj )):
345- return FlattenedTree ((type (obj ), tuple (), obj ), [])
345+ return FlattenedTree ((type (obj ), cast ( Tuple [ Any , ...], tuple () ), obj ), [])
346346 flatten_func , _ = PyTreeRegistry .get (type (obj ))
347347 return flatten_func (obj )
348348
349349
350350def tree_unflatten (obj : FlattenedTree , allow_non_pytree_objects : bool ) -> Any :
351- typ = obj .metadata [0 ]
351+ typ : Type [ Any ] = obj .metadata [0 ]
352352 if issubclass (typ , JitVar ):
353353 _type_args , v = obj .metadata [1 :]
354354 assert isinstance (v , JitVar )
@@ -409,7 +409,7 @@ def is_registered(typ: type) -> bool:
409409 @staticmethod
410410 def __register_default_types () -> None :
411411 def flatten_primitive (obj : Any ) -> FlattenedTree :
412- return FlattenedTree ((type (obj ), tuple (), obj ), [])
412+ return FlattenedTree ((type (obj ), cast ( Tuple [ Any , ...], tuple () ), obj ), [])
413413
414414 def unflatten_primitive (tree : FlattenedTree ) -> Any :
415415 assert len (tree .children ) == 0
@@ -423,7 +423,7 @@ def unflatten_primitive(tree: FlattenedTree) -> Any:
423423
424424 def flatten_list (obj : List [Any ]) -> FlattenedTree :
425425 return FlattenedTree (
426- (list , tuple (), None ), [tree_flatten (o , True ) for o in obj ]
426+ (list , cast ( Tuple [ Any , ...], tuple () ), None ), [tree_flatten (o , True ) for o in obj ]
427427 )
428428
429429 def unflatten_list (tree : FlattenedTree ) -> List [Any ]:
@@ -435,7 +435,7 @@ def unflatten_list(tree: FlattenedTree) -> List[Any]:
435435
436436 def flatten_tuple (obj : Tuple [Any , ...]) -> FlattenedTree :
437437 return FlattenedTree (
438- (tuple , tuple (), None ), [tree_flatten (o , True ) for o in obj ]
438+ (tuple , cast ( Tuple [ Any , ...], tuple () ), None ), [tree_flatten (o , True ) for o in obj ]
439439 )
440440
441441 def unflatten_tuple (tree : FlattenedTree ) -> Tuple [Any , ...]:
@@ -447,14 +447,15 @@ def unflatten_tuple(tree: FlattenedTree) -> Tuple[Any, ...]:
447447
448448 def flatten_dict (obj : Dict [Any , Any ]) -> FlattenedTree :
449449 return FlattenedTree (
450- (dict , tuple (), (len (obj .keys ()))),
450+ (dict , cast ( Tuple [ Any , ...], tuple () ), (len (obj .keys ()))),
451451 [tree_flatten (k , True ) for k in obj .keys ()]
452452 + [tree_flatten (v , True ) for v in obj .values ()],
453453 )
454454
455455 def unflatten_dict (tree : FlattenedTree ) -> Dict [Any , Any ]:
456456 assert tree .metadata [0 ] is dict
457- length = tree .metadata [1 ]
457+ length = tree .metadata [2 ][0 ]
458+ assert isinstance (length , int ), "Invalid length for dict unflattening"
458459 assert len (tree .children ) == length * 2
459460 keys = tree .children [:length ]
460461 values = tree .children [length :]
@@ -561,7 +562,7 @@ class ControlFlowFrame:
561562 parent : Optional ["ControlFlowFrame" ]
562563 is_static : bool
563564
564- def __init__ (self , parent : Optional ["ControlFlowFrame" ]):
565+ def __init__ (self , * args , parent : Optional ["ControlFlowFrame" ]):
565566 self .parent = parent
566567 self .is_static = False
567568
@@ -590,7 +591,7 @@ class IfFrame(ControlFlowFrame):
590591 false_bb : Optional [hir .BasicBlock ]
591592
592593 def __init__ (self , cond : Any , parent : ControlFlowFrame ):
593- super ().__init__ (parent )
594+ super ().__init__ (parent = parent )
594595 self .cond = cond
595596 self .is_static = not isinstance (cond , JitVar )
596597 self .static_cond = bool (cond ) if self .is_static else None
@@ -712,7 +713,7 @@ class TraceContext:
712713 top_level_func : Optional [hir .Function ]
713714
714715 def __init__ (self , is_top_level ):
715- self .cf_frame = ControlFlowFrame (None )
716+ self .cf_frame = ControlFlowFrame (parent = None )
716717 self .is_top_level = is_top_level
717718 self .top_level_func = None
718719
0 commit comments