1414
1515TNode = TypeVar ("TNode" , bound = Node )
1616
17-
1817LUA_DOUBLE_SQUARE_RE = re .compile (r'^\[(?P<eq>=*)\[(?P<body>[\s\S]*?)\]\1\]$' )
1918
19+
2020def _listify (obj ):
2121 if not isinstance (obj , list ):
2222 return [obj ]
@@ -79,7 +79,7 @@ def add_comments(self, start, stop: CommonToken, node: TNode,
7979
8080 if ((token .channel is LuaLexer .DEFAULT_TOKEN_CHANNEL and
8181 (token .type is not LuaLexer .COMMA ) or (token .type is LuaLexer .COMMA and not ignore_left_comma )) or
82- ((token .type is LuaLexer .NL ) and (token_next .type is LuaLexer .NL ) and not ignore_left_double_nl )):
82+ ((token .type is LuaLexer .NL ) and (token_next .type is LuaLexer .NL ) and not ignore_left_double_nl )):
8383 break
8484
8585 if token .channel == self .COMMENT_CHANNEL :
@@ -105,7 +105,7 @@ def add_comments(self, start, stop: CommonToken, node: TNode,
105105 next_right_token += 1
106106
107107 if ((token .channel is LuaLexer .DEFAULT_TOKEN_CHANNEL and token .type is not LuaLexer .COMMA ) or
108- (token .type is LuaLexer .NL and not ignore_right_nl )):
108+ (token .type is LuaLexer .NL and not ignore_right_nl )):
109109 break
110110
111111 if token .channel == self .COMMENT_CHANNEL :
@@ -452,51 +452,42 @@ def visitVar(self, ctx: LuaParser.VarContext):
452452 return Name (ctx .NAME ().getText ())
453453 else : # prefixexp tail
454454 root = self .visit (ctx .prefixexp ())
455- return self .visitAllTails (root , [ctx .tail ()])
455+ return self .visit_tail_chain (root , [ctx .tail ()])
456456
457457 # Visit a parse tree produced by LuaParser#prefixexp.
458458 def visitPrefixexp (self , ctx : LuaParser .PrefixexpContext ):
459- if ctx .NAME (): # NAME tail*
460- root = self .visit (ctx .NAME ())
461- elif ctx .functioncall (): # functioncall tail*
459+ if ctx .functioncall (): # functioncall tail*
462460 root = self .visit (ctx .functioncall ())
461+ elif ctx .NAME (): # NAME tail*
462+ root = self .visit (ctx .NAME ())
463463 else : # '(' exp ')' tail*
464464 root : Expression = self .visit (ctx .exp ())
465465 root .wrapped = True
466466
467- tail = self .visitAllTails (root , ctx .tail ())
467+ tail = self .visit_tail_chain (root , ctx .tail ())
468468 return tail
469469
470470 # Visit a parse tree produced by LuaParser#functioncall_name.
471471 def visitFunctioncall_name (self , ctx : LuaParser .Functioncall_nameContext ):
472472 name = self .visit (ctx .NAME ())
473- tail = self .visitAllTails (name , ctx .tail ())
474- par , args = self .visitArgs (ctx .args ())
475- return self .add_context (ctx , Call (tail , _listify (args ),
476- style = CallStyle .DEFAULT if par else CallStyle .NO_PARENTHESIS ))
473+ return self .visit_call_chain (name , ctx .call ())
477474
478475 # Visit a parse tree produced by LuaParser#functioncall_nested.
479476 def visitFunctioncall_nested (self , ctx : LuaParser .Functioncall_nestedContext ):
480477 call = self .visit (ctx .functioncall ())
481- tail = self .visitAllTails (call , ctx .tail ())
482- par , args = self .visitArgs (ctx .args ())
483- return self .add_context (ctx , Call (tail , _listify (args ),
484- style = CallStyle .DEFAULT if par else CallStyle .NO_PARENTHESIS ))
478+ return self .visit_call_chain (call , ctx .call ())
485479
486480 # Visit a parse tree produced by LuaParser#functioncall_exp.
487481 def visitFunctioncall_exp (self , ctx : LuaParser .Functioncall_expContext ):
488482 exp = self .visitExp (ctx .exp ())
489483 exp .wrapped = True
490- tail = self .visitAllTails (exp , ctx .tail ())
491- par , args = self .visitArgs (ctx .args ())
492- return self .add_context (ctx , Call (tail , _listify (args ),
493- style = CallStyle .DEFAULT if par else CallStyle .NO_PARENTHESIS ))
484+ return self .visit_call_chain (exp , ctx .call ())
494485
495486 # Visit a parse tree produced by LuaParser#functioncall_expinvoke.
496487 def visitFunctioncall_expinvoke (self , ctx : LuaParser .Functioncall_expinvokeContext ):
497488 exp = self .visitExp (ctx .exp ())
498489 exp .wrapped = True
499- tail = self .visitAllTails (exp , ctx .tail ())
490+ tail = self .visit_tail_chain (exp , ctx .tail ())
500491 par , args = self .visitArgs (ctx .args ())
501492 func = self .visit (ctx .NAME ())
502493 return self .add_context (ctx , Invoke (tail , func , _listify (args ),
@@ -506,26 +497,61 @@ def visitFunctioncall_expinvoke(self, ctx: LuaParser.Functioncall_expinvokeConte
506497 def visitFunctioncall_invoke (self , ctx : LuaParser .Functioncall_invokeContext ):
507498 source = self .visit (ctx .NAME (0 ))
508499 func = self .visit (ctx .NAME (1 ))
509- tail = self .visitAllTails (source , ctx .tail ())
500+ tail = self .visit_tail_chain (source , ctx .tail ())
510501 par , args = self .visitArgs (ctx .args ())
511- return self .add_context (ctx , Invoke (tail , func , _listify (args ),
512- style = CallStyle .DEFAULT if par else CallStyle .NO_PARENTHESIS ))
502+ invoke = self .add_context (ctx , Invoke (
503+ tail ,
504+ func ,
505+ _listify (args ),
506+ style = CallStyle .DEFAULT if par else CallStyle .NO_PARENTHESIS
507+ ))
508+ return self .visit_call_chain (invoke , ctx .call ())
513509
514510 # Visit a parse tree produced by LuaParser#functioncall_nestedinvoke.
515511 def visitFunctioncall_nestedinvoke (self , ctx : LuaParser .Functioncall_nestedinvokeContext ):
516512 call = self .visit (ctx .functioncall ())
517513 func = self .visit (ctx .NAME ())
518- tail = self .visitAllTails (call , ctx .tail ())
514+ tail = self .visit_tail_chain (call , ctx .tail ())
519515 par , args = self .visitArgs (ctx .args ())
520- return self .add_context (ctx , Invoke (tail , func , _listify (args ),
521- style = CallStyle .DEFAULT if par else CallStyle .NO_PARENTHESIS ))
516+ invoke = self .add_context (ctx , Invoke (
517+ tail ,
518+ func ,
519+ _listify (args ),
520+ style = CallStyle .DEFAULT if par else CallStyle .NO_PARENTHESIS
521+ ))
522+ return self .visit_call_chain (invoke , ctx .call ())
523+
524+ def visit_call_chain (self , root_exp : Optional [Expression ], calls : List [LuaParser .CallContext ]):
525+ if not calls :
526+ return root_exp
527+
528+ root = root_exp # parent root will be set in caller
529+ call : Call = self .visit_call (root , calls [0 ]) # root tail
530+ i = 1
531+ while call :
532+ root = call
533+ if i >= len (calls ):
534+ break
535+
536+ call = self .visit_call (root , calls [i ])
537+ i += 1
538+ return root
539+
540+ def visit_call (self , root_exp : Optional [Expression ], ctx : LuaParser .CallContext ):
541+ tail = self .visit_tail_chain (root_exp , ctx .tail ())
542+ par , args = self .visitArgs (ctx .args ())
543+ return self .add_context (ctx , Call (
544+ tail ,
545+ _listify (args ),
546+ style = CallStyle .DEFAULT if par else CallStyle .NO_PARENTHESIS
547+ ))
522548
523- def visitAllTails (self , root_exp : Expression , tails : List [LuaParser .TailContext ]):
549+ def visit_tail_chain (self , root_exp : Optional [ Expression ] , tails : List [LuaParser .TailContext ]):
524550 if not tails :
525551 return root_exp
526552
527553 root = root_exp # parent root will be set in caller
528- tail : Index = self .visitTail (tails [0 ]) # root tail
554+ tail : Index = self .visit (tails [0 ]) # root tail
529555 i = 1
530556 while tail :
531557 tail .value = root
0 commit comments