class CachedIntervalList(object):
    """
    Poor man's interval data structure.

    We don't need a full interval tree, but this is much faster to construct
    and query than a mapping ``line: nodes``, especially when some (nested)
    nodes span a really large number of lines.

    Intervals are closed: a value added with ``add(low, high, value)`` is
    returned by ``self[index]`` for every ``low <= index <= high``.
    """

    def __init__(self):
        # type: () -> None
        # Created per instance on purpose: class-level dict defaults would be
        # shared by every instance that did not go through __init__.
        # Maps interval start -> interval end -> values stored for that span.
        self._tree = {}  # type: Dict[int, Dict[int, List[Any]]]
        # Memoized query results; reset whenever a new interval is added.
        self._cache = {}  # type: Dict[int, List[Any]]

    def add(self, low, high, value):
        # type: (int, int, Any) -> None
        """
        Register ``value`` for the closed interval ``[low, high]``.

        Any previously cached query results are discarded, since they may no
        longer be complete.
        """
        assert low <= high
        self._tree.setdefault(low, {}).setdefault(high, []).append(value)
        self._cache = {}

    def __getitem__(self, index):
        # type: (int) -> Any
        """
        Return the list of values whose interval contains ``index``.

        The result is computed once per index and then served from the cache
        until the next call to ``add``. The returned list is the cached
        object itself, so callers should not mutate it.
        """
        try:
            return self._cache[index]
        except KeyError:
            found = self._cache[index] = self._query(index)
            return found

    def _query(self, index):
        # type: (int) -> List[Any]
        """
        Scan all stored intervals and collect the values covering ``index``.

        Values are returned in iteration order of the underlying dicts
        (start first, then end, then order of addition).
        """
        acc = []  # type: List[Any]
        for low, by_high in self._tree.items():
            if low > index:
                continue
            for high, values in by_high.items():
                if high >= index:
                    acc.extend(values)
        return acc

    # Backward-compatible alias for the original (misspelled) helper name.
    _querry = _query
def node_minmax(node):
    # type: (ast.AST) -> Optional[Tuple[int, int]]
    """
    Return the ``(first_line, last_line)`` pair for ``node``.

    Returns ``None`` when the node has no ``lineno`` attribute. Only
    expression nodes (``ast.expr``) that have an ``end_lineno`` attribute get
    a real multi-line span; every other node with a ``lineno`` is reported as
    the single line ``(lineno, lineno)``.
    """
    if not hasattr(node, "lineno"):
        return None

    if hasattr(node, "end_lineno") and isinstance(node, ast.expr):
        assert node.end_lineno is not None  # type: ignore[attr-defined]
        return (node.lineno, node.end_lineno)  # type: ignore[attr-defined]

    # Non-expression nodes are deliberately collapsed to their first line.
    return (node.lineno, node.lineno)  # type: ignore[attr-defined]
._position_node_finder import PositionNodeFinder as NodeFinder