22from ast import NodeTransformer
33import copy
44from typing import Callable , Any , List , Set , cast
5- from luisa_lang .utils import checked_cast , retrieve_ast_and_filename , NestedHashMap
5+ from luisa_lang .utils import Span , checked_cast , retrieve_ast_and_filename , NestedHashMap
66
77"""
88Rewrite rules:
@@ -163,17 +163,21 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
163163 return node
164164
165165 def visit_Name (self , node : ast .Name ) -> Any :
166+ span = Span .from_ast (node )
167+ assert span is not None
166168 # rewrite to __lc_ctx__.name
167- return ast .Subscript (
169+ return span . apply_to_ast ( ast .Subscript (
168170 value = ast .Name (id = "__lc_ctx__" , ctx = ast .Load ()),
169171 slice = ast .Constant (value = node .id ),
170172 ctx = node .ctx ,
171- )
173+ ))
172174
173175 def visit_Assign (self , node : ast .Assign ) -> Any :
174176 return self .generic_visit (node )
175177
176178 def visit_AnnAssign (self , node : ast .AnnAssign ) -> Any :
179+ span = Span .from_ast (node )
180+ assert span is not None
177181 target = checked_cast (ast .expr , self .visit (node .target ))
178182 assert isinstance (target , (ast .Name , ast .Subscript , ast .Attribute ))
179183 target .ctx = ast .Load ()
@@ -193,81 +197,93 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
193197 target = copy .deepcopy (target )
194198 target .ctx = ast .Store ()
195199 assign = ast .Assign (targets = [target ], value = self .visit (node .value ))
200+ span .apply_to_ast (anno )
201+ span .apply_to_ast (assign )
196202 return [anno , assign ]
197203
198204 def visit_Call (self , node : ast .Call ) -> Any :
205+ span = Span .from_ast (node )
206+ assert span is not None
199207 # first check if it is of form `__intrinsic__(...)`
200208 if isinstance (node .func , ast .Name ):
201209 if node .func .id in NO_REWRITE_FUNCTIONS :
202210 return node
203211 if node .func .id == "__intrinsic__" or node .func .id == "__intrinsic_checked__" :
204212 # rewrite to __lc_ctx__.intrinsic(...)
205- return ast .Call (
213+ return span . apply_to_ast ( ast .Call (
206214 func = ast .Attribute (
207215 value = ast .Name (id = "__lc_ctx__" , ctx = ast .Load ()),
208216 attr = node .func .id [2 :- 2 ],
209217 ctx = ast .Load (),
210218 ),
211219 args = [self .visit (arg ) for arg in node .args ],
212220 keywords = [self .visit (kw ) for kw in node .keywords ],
213- )
221+ ))
214222 # rewrite to __lc_ctx__.redirect_call(func, args...)
215223 func = self .visit (node .func )
216224 args = [self .visit (arg ) for arg in node .args ]
217225 keywords = [self .visit (kw ) for kw in node .keywords ]
218- return ast .Call (
226+ return span . apply_to_ast ( ast .Call (
219227 func = ast .Attribute (
220228 value = ast .Name (id = "__lc_ctx__" , ctx = ast .Load ()),
221229 attr = "redirect_call" ,
222230 ctx = ast .Load (),
223231 ),
224232 args = [func ] + args ,
225233 keywords = keywords ,
226- )
234+ ))
227235
228236 def visit_BinOp (self , node : ast .BinOp ) -> Any :
237+ span = Span .from_ast (node )
238+ assert span is not None
229239 lhs = self .visit (node .left )
230240 rhs = self .visit (node .right )
231- return ast .Call (
241+ return span . apply_to_ast ( ast .Call (
232242 func = ast .Attribute (
233243 value = ast .Name (id = "__lc_ctx__" , ctx = ast .Load ()),
234244 attr = "redirect_binary" ,
235245 ctx = ast .Load (),
236246 ),
237247 args = [ast .Constant (value = type (node .op ).__name__ ), lhs , rhs ],
238248 keywords = [],
239- )
249+ ))
240250
241251 def visit_UnaryOp (self , node : ast .UnaryOp ) -> Any :
252+ span = Span .from_ast (node )
253+ assert span is not None
242254 operand = self .visit (node .operand )
243- return ast .Call (
255+ return span . apply_to_ast ( ast .Call (
244256 func = ast .Attribute (
245257 value = ast .Name (id = "__lc_ctx__" , ctx = ast .Load ()),
246258 attr = "redirect_unary" ,
247259 ctx = ast .Load (),
248260 ),
249261 args = [ast .Constant (value = type (node .op ).__name__ ), operand ],
250262 keywords = [],
251- )
263+ ))
252264
253265 def visit_Compare (self , node : ast .Compare ) -> Any :
266+ span = Span .from_ast (node )
267+ assert span is not None
254268 if len (node .ops ) != 1 or len (node .comparators ) != 1 :
255269 raise NotImplementedError ("Only single comparison is supported" )
256270 left = self .visit (node .left )
257271 right = self .visit (node .comparators [0 ])
258- return ast .Call (
272+ return span . apply_to_ast ( ast .Call (
259273 func = ast .Attribute (
260274 value = ast .Name (id = "__lc_ctx__" , ctx = ast .Load ()),
261275 attr = "redirect_binary" ,
262276 ctx = ast .Load (),
263277 ),
264278 args = [ast .Constant (value = type (node .ops [0 ]).__name__ ), left , right ],
265279 keywords = [],
266- )
280+ ))
267281
268282 def visit_Subscript (self , node : ast .Subscript ) -> Any :
283+ span = Span .from_ast (node )
284+ assert span is not None
269285 value = self .visit (node .value )
270- return ast .Subscript (
286+ return span . apply_to_ast ( ast .Subscript (
271287 value = ast .Call (
272288 func = ast .Attribute (
273289 value = ast .Name (id = "__lc_ctx__" , ctx = ast .Load ()),
@@ -279,11 +295,13 @@ def visit_Subscript(self, node: ast.Subscript) -> Any:
279295 ),
280296 slice = node .slice ,
281297 ctx = node .ctx ,
282- )
298+ ))
283299
284300 def visit_Attribute (self , node : ast .Attribute ) -> Any :
301+ span = Span .from_ast (node )
302+ assert span is not None
285303 value = self .visit (node .value )
286- return ast .Attribute (
304+ return span . apply_to_ast ( ast .Attribute (
287305 value = ast .Call (
288306 func = ast .Attribute (
289307 value = ast .Name (id = "__lc_ctx__" , ctx = ast .Load ()),
@@ -295,9 +313,11 @@ def visit_Attribute(self, node: ast.Attribute) -> Any:
295313 ),
296314 attr = node .attr ,
297315 ctx = node .ctx ,
298- )
316+ ))
299317
300318 def visit_If (self , node : ast .If ) -> Any :
319+ span = Span .from_ast (node )
320+ assert span is not None
301321 if_id = self .new_id () + "_if"
302322 with_item = ast .withitem (
303323 context_expr = ast .Call (
@@ -361,10 +381,12 @@ def visit_If(self, node: ast.If) -> Any:
361381 ]),
362382 orelse = [],
363383 )
364- with_stmt = ast .With (items = [with_item ], body = [true_branch , false_branch ])
384+ with_stmt = span . apply_to_ast ( ast .With (items = [with_item ], body = [true_branch , false_branch ]) )
365385 return with_stmt
366386
367387 def visit_Return (self , node : ast .Return ) -> Any :
388+ span = Span .from_ast (node )
389+ assert span is not None
368390 self .return_cnt += 1
369391 if self .is_tracing :
370392 if self .return_cnt > 1 :
@@ -380,7 +402,7 @@ def visit_Return(self, node: ast.Return) -> Any:
380402 tmp = self .visit (node .value )
381403 assert isinstance (tmp , ast .expr )
382404 ret_value = tmp
383- return ast .If (
405+ return span . apply_to_ast ( ast .If (
384406 test = ast .Call (
385407 func = ast .Attribute (
386408 value = ast .Name (id = "__lc_ctx__" , ctx = ast .Load ()),
@@ -407,10 +429,12 @@ def visit_Return(self, node: ast.Return) -> Any:
407429 )
408430 ],
409431 ),
410- )
432+ ))
411433
412434 def visit_Break (self , node : ast .Break ) -> Any :
413- return ast .If (
435+ span = Span .from_ast (node )
436+ assert span is not None
437+ return span .apply_to_ast (ast .If (
414438 test = ast .Call (
415439 func = ast .Attribute (
416440 value = ast .Name (id = "__lc_ctx__" , ctx = ast .Load ()),
@@ -437,10 +461,12 @@ def visit_Break(self, node: ast.Break) -> Any:
437461 )
438462 ],
439463 ),
440- )
464+ ))
441465
442466 def visit_Continue (self , node : ast .Continue ) -> Any :
443- return ast .If (
467+ span = Span .from_ast (node )
468+ assert span is not None
469+ return span .apply_to_ast (ast .If (
444470 test = ast .Call (
445471 func = ast .Attribute (
446472 value = ast .Name (id = "__lc_ctx__" , ctx = ast .Load ()),
@@ -467,15 +493,14 @@ def visit_Continue(self, node: ast.Continue) -> Any:
467493 )
468494 ],
469495 ),
470- )
496+ ))
471497
472498
473499def rewrite_function [F : Callable [..., Any ]](f : F , decorator_name : str ) -> F :
474500 tree , filename = retrieve_ast_and_filename (f )
475501 tree = FuncRewriter (decorator_name , filename ).visit (tree )
476502 ast .fix_missing_locations (tree )
477- # print(ast.unparse(tree))
478- code = compile (tree , filename = "<ast>" , mode = "exec" )
503+ code = compile (tree , filename = filename , mode = "exec" )
479504 local_dict : dict [Any , Any ] = {}
480505 exec (code , f .__globals__ , local_dict )
481506 rewrote_f = local_dict [f .__name__ ]
0 commit comments