Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions .github/workflows/backend.yml
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,6 @@ jobs:
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
if: needs.select-tests.outputs.has-selected-tests == 'true'

- name: Setup sentry env
if: needs.select-tests.outputs.has-selected-tests == 'true'
uses: ./.github/actions/setup-sentry
id: setup
with:
mode: backend-ci
skip-devservices: true

- name: Download selected tests artifact
if: needs.select-tests.outputs.has-selected-tests == 'true'
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
Expand Down
178 changes: 114 additions & 64 deletions .github/workflows/scripts/calculate-backend-test-shards.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
#!/usr/bin/env python3
"""Calculate the number of backend test shards needed for CI.

Uses AST-based static analysis to count tests instead of running
pytest --collect-only, which requires importing every module and
bootstrapping Django (~100s). AST parsing takes a few seconds.
"""

from __future__ import annotations

import ast
import json
import math
import os
import re
import subprocess
import sys
from pathlib import Path

Expand All @@ -12,85 +21,126 @@
MAX_SHARDS = 22
DEFAULT_SHARDS = MAX_SHARDS

IGNORED_DIRS = frozenset(("tests/acceptance/", "tests/apidocs/", "tests/js/", "tests/tools/"))


def _resolve(node: ast.expr, scope: dict[str, ast.expr]) -> ast.expr:
"""Chase Name and Subscript references back to a concrete AST node."""
if isinstance(node, ast.Name) and node.id in scope:
return _resolve(scope[node.id], scope)
if (
isinstance(node, ast.Subscript)
and isinstance(node.value, ast.Name)
and isinstance(node.slice, ast.Constant)
and isinstance(node.slice.value, int)
and node.value.id in scope
):
target = _resolve(scope[node.value.id], scope)
i = node.slice.value
if isinstance(target, (ast.List, ast.Tuple)) and 0 <= i < len(target.elts):
return _resolve(target.elts[i], scope)
return node


def _parametrize_count(dec: ast.expr, scope: dict[str, ast.expr]) -> int | None:
"""If *dec* is a ``@pytest.mark.parametrize``, return the case count."""
dec = _resolve(dec, scope)
if not isinstance(dec, ast.Call) or len(dec.args) < 2:
return None
f = dec.func
if not (
isinstance(f, ast.Attribute)
and f.attr == "parametrize"
and isinstance(f.value, ast.Attribute)
and f.value.attr == "mark"
and isinstance(f.value.value, ast.Name)
and f.value.value.id == "pytest"
):
return None
argvals = _resolve(dec.args[1], scope)
return len(argvals.elts) if isinstance(argvals, (ast.List, ast.Tuple)) else None


_TEST_FUNC_RE = re.compile(r"^\s*(?:async\s+)?def\s+test_", re.MULTILINE)


def count_tests_in_file(filepath: Path) -> int:
"""Count the test items *filepath* would produce.

Accounts for ``@pytest.mark.parametrize`` multipliers including
stacked decorators.
"""
try:
source = filepath.read_text(encoding="utf-8")
except (UnicodeDecodeError, OSError):
return 0

# Fast path: no parametrize means each def test_ is exactly one test.
if "parametrize" not in source:
return len(_TEST_FUNC_RE.findall(source))

try:
tree = ast.parse(source, filename=str(filepath))
except SyntaxError:
return len(_TEST_FUNC_RE.findall(source))

scope: dict[str, ast.expr] = {}
for node in ast.iter_child_nodes(tree):
if isinstance(node, ast.Assign):
for target in node.targets:
if isinstance(target, ast.Name):
scope[target.id] = node.value
elif isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name) and node.value:
scope[node.target.id] = node.value

total = 0
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name.startswith(
"test_"
):
counts = (
c for d in node.decorator_list if (c := _parametrize_count(d, scope)) is not None
)
total += math.prod(counts, start=1)
return total


def collect_test_count() -> int | None:
"""Collect the number of tests to run, either from selected files or full suite."""
"""Count tests via AST analysis of test files."""
selected_tests_file = os.environ.get("SELECTED_TESTS_FILE")

if selected_tests_file:
path = Path(selected_tests_file)
if not path.exists():
print(f"Selected tests file not found: {selected_tests_file}", file=sys.stderr)
print(
f"Selected tests file not found: {selected_tests_file}",
file=sys.stderr,
)
return None

with path.open() as f:
selected_files = [line.strip() for line in f if line.strip()]
test_files = [Path(line.strip()) for line in path.read_text().splitlines() if line.strip()]

if not selected_files:
if not test_files:
print("No selected test files, running 0 tests", file=sys.stderr)
return 0

print(f"Counting tests in {len(selected_files)} selected files", file=sys.stderr)

pytest_args = [
"pytest",
# Always pass tests/ directory to ensure proper conftest loading order.
# SELECTED_TESTS_FILE env var triggers filtering in pytest_collection_modifyitems.
"tests",
"--collect-only",
"--quiet",
"--ignore=tests/acceptance",
"--ignore=tests/apidocs",
"--ignore=tests/js",
"--ignore=tests/tools",
]
print(f"Counting tests in {len(test_files)} selected files", file=sys.stderr)
else:
tests_dir = Path("tests")
if not tests_dir.is_dir():
print("tests/ directory not found", file=sys.stderr)
return None

try:
result = subprocess.run(
pytest_args,
capture_output=True,
text=True,
check=False,
test_files = sorted(
p
for p in tests_dir.rglob("test_*.py")
if not any(str(p).startswith(d) for d in IGNORED_DIRS)
)
print(f"Found {len(test_files)} test files", file=sys.stderr)

# Parse output for test count
# Format without deselection: "27000 tests collected in 18.53s"
# Format with deselection: "29/31510 tests collected (31481 deselected) in 18.13s"
output = result.stdout + result.stderr

# Try format with deselection first (selected/total)
match = re.search(r"(\d+)/\d+ tests? collected", output)
if match:
count = int(match.group(1))
print(f"Collected {count} tests", file=sys.stderr)
return count

# Fall back to format without deselection
match = re.search(r"(\d+) tests? collected", output)
if match:
count = int(match.group(1))
print(f"Collected {count} tests", file=sys.stderr)
return count

if result.returncode == 5:
# Exit code 5 indicates no tests collected (https://docs.pytest.org/en/stable/reference/exit-codes.html)
# This can stem from files being deleted in a branch/PR.
print("No tests collected (exit 5)", file=sys.stderr)
return 0

if result.returncode != 0:
print(
f"Pytest collection failed (exit {result.returncode})",
file=sys.stderr,
)
print(result.stderr, file=sys.stderr)
return None

print("No tests collected", file=sys.stderr)
return 0
except Exception as e:
print(f"Error collecting tests: {e}", file=sys.stderr)
return None
total = sum(count_tests_in_file(f) for f in test_files)
print(f"Counted {total} tests via AST analysis", file=sys.stderr)
return total


def calculate_shards(test_count: int | None) -> int:
Expand Down
28 changes: 26 additions & 2 deletions .github/workflows/scripts/compute-sentry-selected-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,36 @@
"tests/integration/",
)

# Most of these won't have coverage info because they're evaluated at
# module load time and app warmup, before any per-test coverage context is active.
#
# Tracking a "startup" coverage context doesn't work: django.setup()
# eagerly imports models, fields, validators, utils, etc. We also have
# large dynamic __init__'s so a startup context would select nearly every
# test.
FULL_SUITE_TRIGGERS: list[str | re.Pattern[str]] = [
"src/sentry/testutils/pytest/sentry.py",
re.compile(r"^src/sentry/testutils/pytest/"),
re.compile(r"(^|/)conftest\.py$"),
"src/sentry/runner/initializer.py",
"src/sentry/constants.py",
"pyproject.toml",
# option defaults registered at startup via initialize_app()
re.compile(r"^src/sentry/options/"),
# feature flags registered via manager.add() at import time
re.compile(r"^src/sentry/features/"),
# signal definitions created at module level; receivers depend on these
"src/sentry/signals.py",
# signal handlers registered globally via initialize_receivers()
re.compile(r"^src/sentry/receivers/"),
# stdlib/third-party monkey-patches applied before Django setup
re.compile(r"^src/sentry/monkey/"),
# monkeypatches transaction.atomic for silo-aware DB routing
re.compile(r"^src/sentry/silo/patches/"),
# SiloRouter loaded via DATABASE_ROUTERS; affects every DB query
"src/sentry/db/router.py",
"src/sentry/conf/server.py",
"src/sentry/web/urls.py",
"pyproject.toml",
"uv.lock",
re.compile(r"/migrations/\d{4}_[^/]+\.py$"),
]

Expand Down
Loading
Loading