diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8248a1b..0d0e3c7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -9,7 +9,7 @@ on: jobs: test: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 strategy: matrix: python-version: [3.8, 3.9, '3.10', 3.11, 3.12-dev,3.13-dev] diff --git a/executing/_position_node_finder.py b/executing/_position_node_finder.py index 0f83441..a912400 100644 --- a/executing/_position_node_finder.py +++ b/executing/_position_node_finder.py @@ -5,6 +5,7 @@ from typing import Any, Callable, Iterator, Optional, Sequence, Set, Tuple, Type, Union, cast from .executing import EnhancedAST, NotOneValueFound, Source, only, function_node_types, assert_ from ._exceptions import KnownIssue, VerifierFailure +from ._utils import mangled_name from functools import lru_cache @@ -25,51 +26,6 @@ def node_and_parents(node: EnhancedAST) -> Iterator[EnhancedAST]: yield from parents(node) -def mangled_name(node: EnhancedAST) -> str: - """ - - Parameters: - node: the node which should be mangled - name: the name of the node - - Returns: - The mangled name of `node` - """ - if isinstance(node, ast.Attribute): - name = node.attr - elif isinstance(node, ast.Name): - name = node.id - elif isinstance(node, (ast.alias)): - name = node.asname or node.name.split(".")[0] - elif isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.AsyncFunctionDef)): - name = node.name - elif isinstance(node, ast.ExceptHandler): - assert node.name - name = node.name - elif sys.version_info >= (3,12) and isinstance(node,ast.TypeVar): - name=node.name - else: - raise TypeError("no node to mangle for type "+repr(type(node))) - - if name.startswith("__") and not name.endswith("__"): - - parent,child=node.parent,node - - while not (isinstance(parent,ast.ClassDef) and child not in parent.bases): - if not hasattr(parent,"parent"): - break # pragma: no mutate - - parent,child=parent.parent,parent - else: - class_name=parent.name.lstrip("_") - if class_name!="": - return "_" + class_name + name - - - - return name - - @lru_cache(128) # pragma: no mutate def get_instructions(code: CodeType) -> list[dis.Instruction]: return list(dis.get_instructions(code)) diff --git a/executing/_utils.py b/executing/_utils.py new file mode 100644 index 0000000..4251184 --- /dev/null +++ b/executing/_utils.py @@ -0,0 +1,139 @@ + +import ast +import sys +import dis +from typing import cast, Any,Iterator +import types + + + +def assert_(condition, message=""): + # type: (Any, str) -> None + """ + Like an assert statement, but unaffected by -O + :param condition: value that is expected to be truthy + :type message: Any + """ + if not condition: + raise AssertionError(str(message)) + + +if sys.version_info >= (3, 4): + # noinspection PyUnresolvedReferences + _get_instructions = dis.get_instructions + from dis import Instruction as _Instruction + + class Instruction(_Instruction): + lineno = None # type: int +else: + from collections import namedtuple + + class Instruction(namedtuple('Instruction', 'offset argval opname starts_line')): + lineno = None # type: int + + from dis import HAVE_ARGUMENT, EXTENDED_ARG, hasconst, opname, findlinestarts, hasname + + # Based on dis.disassemble from 2.7 + # Left as similar as possible for easy diff + + def _get_instructions(co): + # type: (types.CodeType) -> Iterator[Instruction] + code = co.co_code + linestarts = dict(findlinestarts(co)) + n = len(code) + i = 0 + extended_arg = 0 + while i < n: + offset = i + c = code[i] + op = ord(c) + lineno = linestarts.get(i) + argval = None + i = i + 1 + if op >= HAVE_ARGUMENT: + oparg = ord(code[i]) + ord(code[i + 1]) * 256 + extended_arg + extended_arg = 0 + i = i + 2 + if op == EXTENDED_ARG: + extended_arg = oparg * 65536 + + if op in hasconst: + argval = co.co_consts[oparg] + elif op in hasname: + argval = co.co_names[oparg] + elif opname[op] == 'LOAD_FAST': + argval = co.co_varnames[oparg] + yield Instruction(offset, argval, opname[op], lineno) + +def get_instructions(co): + # type: (types.CodeType) -> Iterator[EnhancedInstruction] + lineno = co.co_firstlineno + for inst in _get_instructions(co): + inst = cast(EnhancedInstruction, inst) + lineno = inst.starts_line or lineno + assert_(lineno) + inst.lineno = lineno + yield inst + + +# Type class used to expand out the definition of AST to include fields added by this library +# It's not actually used for anything other than type checking though! +class EnhancedAST(ast.AST): + parent = None # type: EnhancedAST + +# Type class used to expand out the definition of AST to include fields added by this library +# It's not actually used for anything other than type checking though! +class EnhancedInstruction(Instruction): + _copied = None # type: bool + + + + + +def mangled_name(node): + # type: (EnhancedAST) -> str + """ + + Parameters: + node: the node which should be mangled + name: the name of the node + + Returns: + The mangled name of `node` + """ + + function_class_types=(ast.FunctionDef, ast.ClassDef, ast.AsyncFunctionDef) + + if isinstance(node, ast.Attribute): + name = node.attr + elif isinstance(node, ast.Name): + name = node.id + elif isinstance(node, (ast.alias)): + name = node.asname or node.name.split(".")[0] + elif isinstance(node, function_class_types): + name = node.name + elif isinstance(node, ast.ExceptHandler): + assert node.name + name = node.name + elif sys.version_info >= (3,12) and isinstance(node,ast.TypeVar): + name=node.name + else: + raise TypeError("no node to mangle") + + if name.startswith("__") and not name.endswith("__"): + + parent,child=node.parent,node + + while not (isinstance(parent,ast.ClassDef) and child not in parent.bases): + if not hasattr(parent,"parent"): + break # pragma: no mutate + + parent,child=parent.parent,parent + else: + class_name=parent.name.lstrip("_") + if class_name!="" and child not in parent.decorator_list: + return "_" + class_name + name + + + + return name diff --git a/executing/executing.py b/executing/executing.py index 5cf117e..dd1e1d7 100644 --- a/executing/executing.py +++ b/executing/executing.py @@ -40,8 +40,8 @@ from pathlib import Path from threading import RLock from tokenize import detect_encoding -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Sized, Tuple, \ - Type, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Sized, Tuple, Type, TypeVar, Union, cast +from ._utils import mangled_name,assert_, EnhancedAST,EnhancedInstruction,Instruction,get_instructions if TYPE_CHECKING: # pragma: no cover from asttokens import ASTTokens, ASTText @@ -52,48 +52,8 @@ cache = lru_cache(maxsize=None) -# Type class used to expand out the definition of AST to include fields added by this library -# It's not actually used for anything other than type checking though! -class EnhancedAST(ast.AST): - parent = None # type: EnhancedAST - - -class Instruction(dis.Instruction): - lineno = None # type: int - - -# Type class used to expand out the definition of AST to include fields added by this library -# It's not actually used for anything other than type checking though! -class EnhancedInstruction(Instruction): - _copied = None # type: bool - - - -def assert_(condition, message=""): - # type: (Any, str) -> None - """ - Like an assert statement, but unaffected by -O - :param condition: value that is expected to be truthy - :type message: Any - """ - if not condition: - raise AssertionError(str(message)) - - -def get_instructions(co): - # type: (types.CodeType) -> Iterator[EnhancedInstruction] - lineno = co.co_firstlineno - for inst in dis.get_instructions(co): - inst = cast(EnhancedInstruction, inst) - lineno = inst.starts_line or lineno - assert_(lineno) - inst.lineno = lineno - yield inst - - TESTING = 0 - class NotOneValueFound(Exception): def __init__(self,msg,values=[]): # type: (str, Sequence) -> None @@ -581,11 +541,11 @@ def __init__(self, frame, stmts, tree, lasti, source): elif op_name in ('LOAD_ATTR', 'LOAD_METHOD', 'LOOKUP_METHOD'): typ = ast.Attribute ctx = ast.Load - extra_filter = lambda e: attr_names_match(e.attr, instruction.argval) + extra_filter = lambda e:mangled_name(e) == instruction.argval elif op_name in ('LOAD_NAME', 'LOAD_GLOBAL', 'LOAD_FAST', 'LOAD_DEREF', 'LOAD_CLASSDEREF'): typ = ast.Name ctx = ast.Load - extra_filter = lambda e: e.id == instruction.argval + extra_filter = lambda e:mangled_name(e) == instruction.argval elif op_name in ('COMPARE_OP', 'IS_OP', 'CONTAINS_OP'): typ = ast.Compare extra_filter = lambda e: len(e.ops) == 1 @@ -595,10 +555,11 @@ def __init__(self, frame, stmts, tree, lasti, source): elif op_name.startswith('STORE_ATTR'): ctx = ast.Store typ = ast.Attribute - extra_filter = lambda e: attr_names_match(e.attr, instruction.argval) + extra_filter = lambda e:mangled_name(e) == instruction.argval else: raise RuntimeError(op_name) + with lock: exprs = { cast(EnhancedAST, node) @@ -1126,19 +1087,6 @@ def find_node_ipython(frame, lasti, stmts, source): return decorator, node -def attr_names_match(attr, argval): - # type: (str, str) -> bool - """ - Checks that the user-visible attr (from ast) can correspond to - the argval in the bytecode, i.e. the real attribute fetched internally, - which may be mangled for private attributes. - """ - if attr == argval: - return True - if not attr.startswith("__"): - return False - return bool(re.match(r"^_\w+%s$" % attr, argval)) - def node_linenos(node): # type: (ast.AST) -> Iterator[int] diff --git a/pyproject.toml b/pyproject.toml index d5cf1ab..226e061 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,4 +14,7 @@ warn_redundant_casts=true [[tool.mypy.overrides]] module = "astroid" -ignore_missing_imports = true \ No newline at end of file +ignore_missing_imports = true + +[tool.pytest.ini_options] +python_functions = "test_" diff --git a/tests/test_main.py b/tests/test_main.py index e3bc9d6..6d16998 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -15,6 +15,7 @@ from collections import defaultdict, namedtuple from random import shuffle import pytest +from executing._utils import mangled_name sys.path.append(os.path.dirname(os.path.dirname(__file__))) @@ -715,11 +716,7 @@ def sample_files(samples): @pytest.mark.skipif(sys.version_info<(3,),reason="no 2.7 support") def test_small_samples(full_filename, result_filename): skip_sentinel = [ - "load_deref", "4851dc1b626a95e97dbe0c53f96099d165b755dd1bd552c6ca771f7bca6d30f5", - "508ccd0dcac13ecee6f0cea939b73ba5319c780ddbb6c496be96fe5614871d4a", - "fc6eb521024986baa84af2634f638e40af090be4aa70ab3c22f3d022e8068228", - "42a37b8a823eb2e510b967332661afd679c82c60b7177b992a47c16d81117c8a", "206e0609ff0589a0a32422ee902f09156af91746e27157c32c9595d12072f92a", ] @@ -1077,6 +1074,7 @@ def p(*args): p() p("ast node:") + p(mangled_name(node)) p(ast_dump(node, indent=4)) parents = [] @@ -1429,11 +1427,9 @@ def check_code(self, code, nodes, decorators, check_names): raise - # `argval` isn't set for all relevant instructions in python 2 - # The relation between `ast.Name` and `argval` is already - # covered by the verifier and much more complex in python 3.11 - if isinstance(node, ast.Name) and not py11: - assert inst.argval == node.id, (inst, ast.dump(node)) + if isinstance(node, ast.Name) and inst.opname != "CALL_INTRINSIC_1" and inst.argval not in ("__classdict__","__classdictcell__","__static_attributes__"): + # CALL_INTRINSIC_1 and some special names are excuded here because they are generated by cpython for some synthetic code + assert mangled_name(node) == inst.argval if ex.decorator: decorators[(node.lineno, node.name)].append(ex.decorator) diff --git a/tests/test_pytest.py b/tests/test_pytest.py index 5cbe0a2..a169ca4 100644 --- a/tests/test_pytest.py +++ b/tests/test_pytest.py @@ -3,6 +3,7 @@ import linecache import os import sys +import dis from time import sleep import asttokens @@ -13,7 +14,10 @@ import executing.executing from executing import Source, NotOneValueFound from executing._exceptions import KnownIssue -from executing.executing import is_ipython_cell_code, attr_names_match, is_rewritten_by_pytest +from executing.executing import is_ipython_cell_code, is_rewritten_by_pytest +from executing._utils import get_instructions, mangled_name + +from textwrap import indent sys.path.append(os.path.dirname(os.path.dirname(__file__))) @@ -56,20 +60,6 @@ def test_ipython_cell_code(): ) -def test_attr_names_match(): - assert attr_names_match("foo", "foo") - - assert not attr_names_match("foo", "_foo") - assert not attr_names_match("foo", "__foo") - assert not attr_names_match("_foo", "foo") - assert not attr_names_match("__foo", "foo") - - assert attr_names_match("__foo", "_Class__foo") - assert not attr_names_match("_Class__foo", "__foo") - assert not attr_names_match("__foo", "Class__foo") - assert not attr_names_match("__foo", "_Class_foo") - - def test_source_file_text_change(tmpdir): # Check that Source.for_filename notices changes in file contents # (assuming that linecache can notice) @@ -167,12 +157,8 @@ def test_bad_linecache(): assert ex.source.text == fake_text -if sys.version_info >= (3, 11): - from executing._position_node_finder import mangled_name - from textwrap import indent - import dis - def test_mangled_name(): +def test_mangled_name(): def result(*code_levels): code = "" for i, level in enumerate(code_levels): @@ -184,25 +170,29 @@ def result(*code_levels): for child in ast.iter_child_nodes(parent): child.parent = parent + + ast_types=( + ast.Name, + ast.Attribute, + ast.alias, + ast.FunctionDef, + ast.ClassDef, + ast.ExceptHandler, + ast.AsyncFunctionDef, + ) + tree_names = { mangled_name(n) for n in ast.walk(tree) if isinstance( n, - ( - ast.Name, - ast.Attribute, - ast.alias, - ast.FunctionDef, - ast.ClassDef, - ast.AsyncFunctionDef, - ast.ExceptHandler, - ), + ast_types + , ) } def collect_names(code): - for instruction in dis.get_instructions(code): + for instruction in get_instructions(code): if instruction.opname in ( "STORE_NAME", "LOAD_NAME", @@ -401,6 +391,14 @@ def collect_names(code): ) == {"Test","_","a", "self", "__thing"} + assert result( + "@__thing\n" + "class Test:\n" + " pass" + )== {"Test","__thing"} + + + def test_pytest_rewrite(): frame = inspect.currentframe() diff --git a/tests/utils.py b/tests/utils.py index 0e20ecf..93ece6a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,7 +5,7 @@ import executing.executing -from executing.executing import attr_names_match, Instruction +from executing.executing import mangled_name, Instruction try: from dis import Instruction as DisInstruction except ImportError: @@ -104,7 +104,7 @@ def __setattr__(self, name, value): assert name == "_{self.__class__.__name__}{node.attr}".format(self=self, node=node) else: assert name == node.attr - assert attr_names_match(node.attr, name) + assert mangled_name(node) == name return self def __delattr__(self, name):