Skip to content

Commit 75bff2a

Browse files
committed
fast calculate-shards
1 parent 32bf9a6 commit 75bff2a

File tree

3 files changed

+563
-72
lines changed

3 files changed

+563
-72
lines changed

.github/workflows/backend.yml

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -184,14 +184,6 @@ jobs:
184184
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
185185
if: needs.select-tests.outputs.has-selected-tests == 'true'
186186

187-
- name: Setup sentry env
188-
if: needs.select-tests.outputs.has-selected-tests == 'true'
189-
uses: ./.github/actions/setup-sentry
190-
id: setup
191-
with:
192-
mode: backend-ci
193-
skip-devservices: true
194-
195187
- name: Download selected tests artifact
196188
if: needs.select-tests.outputs.has-selected-tests == 'true'
197189
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0

.github/workflows/scripts/calculate-backend-test-shards.py

Lines changed: 111 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,18 @@
11
#!/usr/bin/env python3
2+
"""Calculate the number of backend test shards needed for CI.
3+
4+
Uses AST-based static analysis to count tests instead of running
5+
pytest --collect-only, which requires importing every module and
6+
bootstrapping Django (~100s). AST parsing takes a few seconds.
7+
"""
8+
9+
from __future__ import annotations
10+
11+
import ast
212
import json
313
import math
414
import os
515
import re
6-
import subprocess
716
import sys
817
from pathlib import Path
918

@@ -12,85 +21,123 @@
1221
MAX_SHARDS = 22
1322
DEFAULT_SHARDS = MAX_SHARDS
1423

24+
# Repo-relative test directories that are left out of the count when the
# whole tests/ tree is scanned (see collect_test_count).
IGNORED_DIRS = frozenset(("tests/acceptance", "tests/apidocs", "tests/js", "tests/tools"))
25+
26+
27+
def _resolve(node: ast.expr, scope: dict[str, ast.expr]) -> ast.expr:
28+
"""Chase Name and Subscript references back to a concrete AST node."""
29+
if isinstance(node, ast.Name) and node.id in scope:
30+
return _resolve(scope[node.id], scope)
31+
if (
32+
isinstance(node, ast.Subscript)
33+
and isinstance(node.value, ast.Name)
34+
and isinstance(node.slice, ast.Constant)
35+
and isinstance(node.slice.value, int)
36+
and node.value.id in scope
37+
):
38+
target = _resolve(scope[node.value.id], scope)
39+
i = node.slice.value
40+
if isinstance(target, (ast.List, ast.Tuple)) and 0 <= i < len(target.elts):
41+
return _resolve(target.elts[i], scope)
42+
return node
43+
44+
45+
def _parametrize_count(dec: ast.expr, scope: dict[str, ast.expr]) -> int | None:
    """If *dec* is a ``@pytest.mark.parametrize``, return the case count.

    *dec* is first resolved through *scope*, so a decorator stored in a
    module-level variable is recognised as well. The argument values may be
    given positionally (second argument) or as the ``argvalues=`` keyword.

    Args:
        dec: Decorator expression taken from a function's decorator list.
        scope: Mapping of names to the expressions assigned to them.

    Returns:
        The number of elements in the argument-values list/tuple literal,
        or ``None`` when *dec* is not a ``pytest.mark.parametrize(...)``
        call or its argument values cannot be resolved to a literal.
    """
    call = _resolve(dec, scope)
    if not isinstance(call, ast.Call):
        return None

    func = call.func
    is_parametrize = (
        isinstance(func, ast.Attribute)
        and func.attr == "parametrize"
        and isinstance(func.value, ast.Attribute)
        and func.value.attr == "mark"
        and isinstance(func.value.value, ast.Name)
        and func.value.value.id == "pytest"
    )
    if not is_parametrize:
        return None

    if len(call.args) >= 2:
        raw_values = call.args[1]
    else:
        # parametrize("x", argvalues=[...]) passes the cases by keyword.
        raw_values = next(
            (kw.value for kw in call.keywords if kw.arg == "argvalues"), None
        )
        if raw_values is None:
            return None

    argvals = _resolve(raw_values, scope)
    if isinstance(argvals, (ast.List, ast.Tuple)):
        return len(argvals.elts)
    return None
62+
63+
64+
_TEST_FUNC_RE = re.compile(r"^\s*(?:async\s+)?def\s+test_", re.MULTILINE)
65+
66+
67+
def count_tests_in_file(filepath: Path) -> int:
68+
"""Count the test items *filepath* would produce.
69+
70+
Accounts for ``@pytest.mark.parametrize`` multipliers including
71+
stacked decorators.
72+
"""
73+
try:
74+
source = filepath.read_text(encoding="utf-8")
75+
except (UnicodeDecodeError, OSError):
76+
return 0
77+
78+
# Fast path: no parametrize means each def test_ is exactly one test.
79+
if "parametrize" not in source:
80+
return len(_TEST_FUNC_RE.findall(source))
81+
82+
try:
83+
tree = ast.parse(source, filename=str(filepath))
84+
except SyntaxError:
85+
return len(_TEST_FUNC_RE.findall(source))
86+
87+
scope = {
88+
target.id: node.value
89+
for node in ast.iter_child_nodes(tree)
90+
if isinstance(node, ast.Assign)
91+
for target in node.targets
92+
if isinstance(target, ast.Name)
93+
}
94+
95+
total = 0
96+
for node in ast.walk(tree):
97+
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name.startswith(
98+
"test_"
99+
):
100+
counts = filter(None, (_parametrize_count(d, scope) for d in node.decorator_list))
101+
total += math.prod(counts, start=1)
102+
return total
103+
15104

16105
def collect_test_count() -> int | None:
    """Count tests via AST analysis of test files.

    When the ``SELECTED_TESTS_FILE`` environment variable is set, the file
    it names is read as a newline-separated list of test file paths and
    only those files are counted. Otherwise every ``test_*.py`` under
    ``tests/`` is counted, except files inside a directory listed in
    ``IGNORED_DIRS``.

    Returns:
        The total test count, ``0`` when the selection file is empty, or
        ``None`` when the selection file or the ``tests/`` directory does
        not exist.
    """
    selected_tests_file = os.environ.get("SELECTED_TESTS_FILE")

    if selected_tests_file:
        path = Path(selected_tests_file)
        if not path.exists():
            print(
                f"Selected tests file not found: {selected_tests_file}",
                file=sys.stderr,
            )
            return None

        test_files = [Path(line.strip()) for line in path.read_text().splitlines() if line.strip()]

        if not test_files:
            print("No selected test files, running 0 tests", file=sys.stderr)
            return 0

        print(f"Counting tests in {len(test_files)} selected files", file=sys.stderr)
    else:
        tests_dir = Path("tests")
        if not tests_dir.is_dir():
            print("tests/ directory not found", file=sys.stderr)
            return None

        # Compare against "<dir>/" on the POSIX form of the path so that a
        # sibling such as tests/tools_extra is not excluded by the
        # tests/tools prefix, and so the check does not depend on the OS
        # path separator.
        ignored_prefixes = tuple(f"{ignored}/" for ignored in IGNORED_DIRS)
        test_files = sorted(
            candidate
            for candidate in tests_dir.rglob("test_*.py")
            if not candidate.as_posix().startswith(ignored_prefixes)
        )
        print(f"Found {len(test_files)} test files", file=sys.stderr)

    total = sum(count_tests_in_file(f) for f in test_files)
    print(f"Counted {total} tests via AST analysis", file=sys.stderr)
    return total
94141

95142

96143
def calculate_shards(test_count: int | None) -> int:

0 commit comments

Comments
 (0)