diff --git a/testmon/configure.py b/testmon/configure.py index 2288d5ce..29744b70 100644 --- a/testmon/configure.py +++ b/testmon/configure.py @@ -99,6 +99,7 @@ class TmConf: message: str collect: bool select: bool + import_depth: int = -1 tmnet: bool = False def __eq__(self, other): @@ -106,6 +107,7 @@ def __eq__(self, other): self.message == other.message and self.collect == other.collect and self.select == other.select + and self.include_imports == other.include_imports and self.tmnet == other.tmnet ) @@ -124,6 +126,8 @@ def _header_collect_select( if notestmon_reasons: return TmConf("testmon: " + notestmon_reasons, False, False) + import_depth = options['testmon_imports_recursion_depth'] + nocollect_reasons = _get_nocollect_reasons( options, debugger=debugger, @@ -146,6 +150,7 @@ def _header_collect_select( f"testmon: {message}", not bool(nocollect_reasons), not bool(noselect_reasons), + import_depth, bool(options.get("tmnet")), ) diff --git a/testmon/pytest_testmon.py b/testmon/pytest_testmon.py index 481e0701..d2272b5c 100644 --- a/testmon/pytest_testmon.py +++ b/testmon/pytest_testmon.py @@ -82,6 +82,20 @@ def pytest_addoption(parser): ), ) + group.addoption( + "--testmon-imports-recursion-depth", + action="store", + dest="testmon_imports_recursion_depth", + default=-1, + type=int, + help=( + "Recursively include imported files as dependencies of " + "tests up to this depth. Defaults to -1, which disables " + "this feature. A depth of 0 will include just the " + "modules directly imported by the test" + ), + ) + group.addoption( "--no-testmon", action="store_true", @@ -181,6 +195,7 @@ def init_testmon_data(config: Config): environment=environment, system_packages=system_packages, readonly=get_running_as(config) == "worker", + import_depth=config.testmon_config.import_depth ) testmon_data.determine_stable(bool(rpc_proxy)) config.testmon_data = testmon_data diff --git a/testmon/testmon_core.py b/testmon/testmon_core.py index a97cd78f..97f7c1c7 100644 --- a/testmon/testmon_core.py +++ b/testmon/testmon_core.py @@ -4,6 +4,7 @@ import sys import sysconfig import textwrap +import ast from functools import lru_cache from collections import defaultdict from xmlrpc.client import Fault, ProtocolError @@ -38,6 +39,8 @@ from testmon.common import DepsNOutcomes, TestExecutions +from typing import Optional + T = TypeVar("T") TEST_BATCH_SIZE = 250 @@ -60,6 +63,70 @@ def is_python_file(file_path): return file_path[-3:] == ".py" +# helpers for import dependency tracking +def parse_imported_modules(rootdir: str, source_path: str, level: int = 0, imported: Optional[dict] = None) -> dict: + """ + Return a set of module names imported by a Python file. + + Only `import` and `from` statements are considered. + """ + if imported is None: + imported = dict() + fullpath = os.path.join(rootdir, source_path) + if not os.path.exists(fullpath): + return imported + try: + with open(fullpath, "r", encoding="utf8") as f: + contents = f.read() + except (OSError, IOError): + return imported + nextlevel = level - 1 + try: + tree = ast.parse(contents, filename=fullpath) + except SyntaxError: + # If the file contains syntax errors we can't parse it, so return an empty set. + return imported + for node in ast.walk(tree): + # Handle `import x` statements + if isinstance(node, ast.Import): + for alias in node.names: + name = alias.name + if not name or name in imported: + continue + relpath = resolve_module_to_file(name, rootdir) + if relpath is None: + continue + imported[name] = relpath + if nextlevel >= 0: + imported = parse_imported_modules(rootdir, relpath, nextlevel, imported) + + # Handle `from x import y` statements + elif isinstance(node, ast.ImportFrom): + if node.module in imported: + continue + + relpath = resolve_module_to_file(node.module, rootdir) + if relpath is None: + continue + imported[node.module] = relpath + if nextlevel >= 0: + imported = parse_imported_modules(rootdir, relpath, nextlevel, imported) + + return imported + +def resolve_module_to_file(module_name: str, rootdir: str) -> Optional[str]: + """ + Attempt to resolve a dotted module name to a Python file within rootdir. + """ + # convert module name to potential file system paths + relative_module_path = module_name.replace(".", os.sep) + for candidate in [f"{relative_module_path}.py", os.path.join(relative_module_path, "__init__.py")]: + absolute_path = os.path.join(rootdir, candidate) + if os.path.exists(absolute_path): + return candidate + return None + + class TestmonException(Exception): pass @@ -165,10 +232,12 @@ def __init__( # pylint: disable=too-many-arguments system_packages=None, python_version=None, readonly=False, + import_depth: int = -1, ): self.rootdir = rootdir self.environment = environment if environment else "default" self.source_tree = SourceTree(rootdir=self.rootdir) + self.import_depth = import_depth if system_packages is None: system_packages = get_system_packages() system_packages = drop_patch_version(system_packages) @@ -236,6 +305,7 @@ def get_tests_fingerprints(self, nodes_files_lines, reports) -> TestExecutions: test_executions_fingerprints = {} for context in nodes_files_lines: deps_n_outcomes: DepsNOutcomes = {"deps": []} + processed_filenames: set[str] = set() for filename, covered in nodes_files_lines[context].items(): if os.path.exists(os.path.join(self.rootdir, filename)): @@ -249,7 +319,30 @@ def get_tests_fingerprints(self, nodes_files_lines, reports) -> TestExecutions: "method_checksums": fingerprint, } ) + processed_filenames.add(filename) + + # include modules imported by the test as dependencies + if self.import_depth >= 0: + test_file = home_file(context) + imported = parse_imported_modules(self.rootdir, test_file, level=self.import_depth) + for mod_rel in imported.values(): + if not mod_rel or mod_rel in processed_filenames: + continue + module = self.source_tree.get_file(mod_rel) + if not module: + continue + deps_n_outcomes["deps"].append( + { + "filename": mod_rel, + "mtime": module.mtime, + "fsha": module.fs_fsha, + # Use the full method_checksums for the module as fingerprint + "method_checksums": module.method_checksums, + } + ) + processed_filenames.add(mod_rel) + # Copy over execution result fields and forced flag deps_n_outcomes.update(process_result(reports[context])) deps_n_outcomes["forced"] = context in self.stable_test_names and ( context not in self.failing_tests