Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 62 additions & 5 deletions executing/executing.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,50 @@ def wrapper(*args):
# noinspection PyUnresolvedReferences
text_type = unicode


class CachedIntervalList(object):
"""
Poor's man interval datastructure.

We don't need a full-interval tree but this is much faster to construct
and query that a mapping line:nodes, especially when some (nested) nodes
span a really large number of lines.
"""

_tree = {} # type: Dict[int, Dict[int, List[Any]]]

_cache = {} # type: Dict[int, List[Any]]


def add(self, low, high, value):
# type: (int, int, Any) -> None
assert low <= high
self._tree.setdefault(low, {}).setdefault(high, []).append(value)
self._cache = {}

def __getitem__(self, index):
# type: (int) -> Any
if index not in self._cache:
self._cache[index] = self._querry(index)
return self._cache[index]

def _querry(self, index):
# type: (int) -> List[Any]

acc = []
for mn, v in self._tree.items():
if mn > index:
continue
for it in [item for mx, item in v.items() if mx >= index]:
acc.extend(it)
return acc

def __init__(self):
# type: () -> None
self._tree = {}
self._cache = {}


# Type class used to expand out the definition of AST to include fields added by this library
# It's not actually used for anything other than type checking though!
class EnhancedAST(ast.AST):
Expand Down Expand Up @@ -233,7 +277,7 @@ def __init__(self, filename, lines):
lines = [line.decode(encoding) for line in lines]

self.text = text
self.lines = [line.rstrip('\r\n') for line in lines]
self.lines = text.splitlines()

if sys.version_info[0] == 3:
ast_text = text
Expand All @@ -247,7 +291,7 @@ def __init__(self, filename, lines):
for i, line in enumerate(lines)
])

self._nodes_by_line = defaultdict(list)
self._nodes_by_line = CachedIntervalList()
self.tree = None
self._qualnames = {}
self._asttokens = None # type: Optional[ASTTokens]
Expand All @@ -261,8 +305,10 @@ def __init__(self, filename, lines):
for node in ast.walk(self.tree):
for child in ast.iter_child_nodes(node):
cast(EnhancedAST, child).parent = cast(EnhancedAST, node)
for lineno in node_linenos(node):
self._nodes_by_line[lineno].append(node)
linenos = node_minmax(node)
if linenos:
min_, max_ = linenos[0], linenos[-1]
self._nodes_by_line.add(min_, max_, node)

visitor = QualnameVisitor()
visitor.visit(self.tree)
Expand Down Expand Up @@ -1207,6 +1253,7 @@ def is_ipython_cell_code(code_obj):
)



def find_node_ipython(frame, lasti, stmts, source):
# type: (types.FrameType, int, Set[EnhancedAST], Source) -> Tuple[Optional[Any], Optional[Any]]
node = decorator = None
Expand Down Expand Up @@ -1239,7 +1286,6 @@ def attr_names_match(attr, argval):
return False
return bool(re.match(r"^_\w+%s$" % attr, argval))


def node_linenos(node):
# type: (ast.AST) -> Iterator[int]
if hasattr(node, "lineno"):
Expand All @@ -1252,6 +1298,17 @@ def node_linenos(node):
for lineno in linenos:
yield lineno

def node_minmax(node):
# type: (ast.AST) -> Optional[Tuple[int, int]]
if hasattr(node, "lineno"):
linenos = [] # type: Sequence[int]
if hasattr(node, "end_lineno") and isinstance(node, ast.expr):
assert node.end_lineno is not None # type: ignore[attr-defined]
return (node.lineno, node.end_lineno) # type: ignore[attr-defined]
else:
return (node.lineno, node.lineno) # type: ignore[attr-defined]
return None


if sys.version_info >= (3, 11):
from ._position_node_finder import PositionNodeFinder as NodeFinder
Expand Down