|
1 | 1 | #!/usr/bin/env python3 |
| 2 | +"""Calculate the number of backend test shards needed for CI. |
| 3 | +
|
| 4 | +Uses AST-based static analysis to count tests instead of running |
| 5 | +pytest --collect-only, which requires importing every module and |
| 6 | +bootstrapping Django (~100s). AST parsing takes a few seconds. |
| 7 | +""" |
| 8 | + |
| 9 | +from __future__ import annotations |
| 10 | + |
| 11 | +import ast |
2 | 12 | import json |
3 | 13 | import math |
4 | 14 | import os |
5 | 15 | import re |
6 | | -import subprocess |
7 | 16 | import sys |
8 | 17 | from pathlib import Path |
9 | 18 |
|
|
# Hard cap on CI shards; also the safe fallback when the test count is unknown.
MAX_SHARDS = 22
# Used when collection fails: over-provision rather than overload a shard.
DEFAULT_SHARDS = MAX_SHARDS


# Test trees excluded from the backend count; mirrors the --ignore= flags
# previously passed to pytest --collect-only.
IGNORED_DIRS = frozenset(("tests/acceptance", "tests/apidocs", "tests/js", "tests/tools"))

| 26 | + |
| 27 | +def _resolve(node: ast.expr, scope: dict[str, ast.expr]) -> ast.expr: |
| 28 | + """Chase Name and Subscript references back to a concrete AST node.""" |
| 29 | + if isinstance(node, ast.Name) and node.id in scope: |
| 30 | + return _resolve(scope[node.id], scope) |
| 31 | + if ( |
| 32 | + isinstance(node, ast.Subscript) |
| 33 | + and isinstance(node.value, ast.Name) |
| 34 | + and isinstance(node.slice, ast.Constant) |
| 35 | + and isinstance(node.slice.value, int) |
| 36 | + and node.value.id in scope |
| 37 | + ): |
| 38 | + target = _resolve(scope[node.value.id], scope) |
| 39 | + i = node.slice.value |
| 40 | + if isinstance(target, (ast.List, ast.Tuple)) and 0 <= i < len(target.elts): |
| 41 | + return _resolve(target.elts[i], scope) |
| 42 | + return node |
| 43 | + |
| 44 | + |
def _parametrize_count(dec: ast.expr, scope: dict[str, ast.expr]) -> int | None:
    """Return the case count of a ``@pytest.mark.parametrize`` decorator.

    Returns ``None`` when *dec* is not a parametrize call or the case list
    cannot be determined statically.
    """
    call = _resolve(dec, scope)
    if not isinstance(call, ast.Call):
        return None
    if len(call.args) < 2:
        return None

    # Verify the exact attribute chain: pytest.mark.parametrize
    func = call.func
    is_parametrize = (
        isinstance(func, ast.Attribute)
        and func.attr == "parametrize"
        and isinstance(func.value, ast.Attribute)
        and func.value.attr == "mark"
        and isinstance(func.value.value, ast.Name)
        and func.value.value.id == "pytest"
    )
    if not is_parametrize:
        return None

    # Second positional arg is the case list; resolve indirection first.
    cases = _resolve(call.args[1], scope)
    if isinstance(cases, (ast.List, ast.Tuple)):
        return len(cases.elts)
    return None
| 62 | + |
| 63 | + |
| 64 | +_TEST_FUNC_RE = re.compile(r"^\s*(?:async\s+)?def\s+test_", re.MULTILINE) |
| 65 | + |
| 66 | + |
def count_tests_in_file(filepath: Path) -> int:
    """Count the test items *filepath* would produce.

    Accounts for ``@pytest.mark.parametrize`` multipliers, including
    stacked decorators (case counts are multiplied together).
    """
    try:
        source = filepath.read_text(encoding="utf-8")
    except (UnicodeDecodeError, OSError):
        # Unreadable files contribute no tests.
        return 0

    # Fast path: no parametrize means each "def test_" is exactly one test.
    if "parametrize" not in source:
        return len(_TEST_FUNC_RE.findall(source))

    try:
        tree = ast.parse(source, filename=str(filepath))
    except SyntaxError:
        # Unparsable source: fall back to the regex approximation.
        return len(_TEST_FUNC_RE.findall(source))

    # Module-level simple assignments, used to resolve indirect parametrize
    # args (e.g. CASES = [...] then @pytest.mark.parametrize("x", CASES)).
    module_scope: dict[str, ast.expr] = {}
    for stmt in ast.iter_child_nodes(tree):
        if not isinstance(stmt, ast.Assign):
            continue
        for target in stmt.targets:
            if isinstance(target, ast.Name):
                module_scope[target.id] = stmt.value

    total = 0
    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        if not node.name.startswith("test_"):
            continue
        multiplier = 1
        for decorator in node.decorator_list:
            cases = _parametrize_count(decorator, module_scope)
            # None (not a parametrize) and 0 (empty case list) are both
            # skipped, matching the original filter(None, ...) semantics.
            if cases:
                multiplier *= cases
        total += multiplier
    return total
| 103 | + |
15 | 104 |
|
def collect_test_count() -> int | None:
    """Count tests via AST analysis of test files.

    Reads the newline-separated file list from the SELECTED_TESTS_FILE env
    var when set; otherwise scans tests/ recursively for test_*.py files,
    skipping IGNORED_DIRS. Returns the total test count, or None on
    unrecoverable errors so the caller can fall back to a default.
    """
    selected_tests_file = os.environ.get("SELECTED_TESTS_FILE")

    if selected_tests_file:
        path = Path(selected_tests_file)
        if not path.exists():
            print(
                f"Selected tests file not found: {selected_tests_file}",
                file=sys.stderr,
            )
            return None

        test_files = [Path(line.strip()) for line in path.read_text().splitlines() if line.strip()]

        if not test_files:
            print("No selected test files, running 0 tests", file=sys.stderr)
            return 0

        print(f"Counting tests in {len(test_files)} selected files", file=sys.stderr)
    else:
        tests_dir = Path("tests")
        if not tests_dir.is_dir():
            print("tests/ directory not found", file=sys.stderr)
            return None

        # Compare POSIX-style paths so the ignore list also works on Windows
        # (str(p) would use backslashes there), and require a trailing "/"
        # so an entry like "tests/js" cannot accidentally exclude a sibling
        # directory such as "tests/js_extra".
        test_files = sorted(
            p
            for p in tests_dir.rglob("test_*.py")
            if not any(p.as_posix().startswith(f"{d}/") for d in IGNORED_DIRS)
        )
        print(f"Found {len(test_files)} test files", file=sys.stderr)

    total = sum(count_tests_in_file(f) for f in test_files)
    print(f"Counted {total} tests via AST analysis", file=sys.stderr)
    return total
94 | 141 |
|
95 | 142 |
|
96 | 143 | def calculate_shards(test_count: int | None) -> int: |
|
0 commit comments