Skip to content

Commit cb80c99

Browse files
Deduplicate AST import detection across helper boundary guards
Co-authored-by: Shri Sukhani <shrisukhani@users.noreply.github.com>
1 parent d8980f1 commit cb80c99

File tree

3 files changed

+19
-28
lines changed

3 files changed

+19
-28
lines changed

tests/ast_import_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import ast
2+
3+
4+
def imports_collect_function_sources(module_text: str) -> bool:
5+
module_ast = ast.parse(module_text)
6+
for node in module_ast.body:
7+
if not isinstance(node, ast.ImportFrom):
8+
continue
9+
if node.module != "tests.ast_function_source_utils":
10+
continue
11+
if any(alias.name == "collect_function_sources" for alias in node.names):
12+
return True
13+
return False

tests/test_ast_function_source_helper_usage.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
import ast
21
from pathlib import Path
32

43
import pytest
54

5+
from tests.ast_import_utils import imports_collect_function_sources
6+
67
pytestmark = pytest.mark.architecture
78

89

@@ -23,23 +24,11 @@
2324
"tests/test_web_request_wrapper_internal_reuse.py",
2425
)
2526

26-
def _imports_collect_function_sources(module_text: str) -> bool:
27-
module_ast = ast.parse(module_text)
28-
for node in module_ast.body:
29-
if not isinstance(node, ast.ImportFrom):
30-
continue
31-
if node.module != "tests.ast_function_source_utils":
32-
continue
33-
if any(alias.name == "collect_function_sources" for alias in node.names):
34-
return True
35-
return False
36-
37-
3827
def test_ast_guard_modules_reuse_shared_collect_function_sources_helper():
3928
violating_modules: list[str] = []
4029
for module_path in AST_FUNCTION_SOURCE_GUARD_MODULES:
4130
module_text = Path(module_path).read_text(encoding="utf-8")
42-
if not _imports_collect_function_sources(module_text):
31+
if not imports_collect_function_sources(module_text):
4332
violating_modules.append(module_path)
4433
continue
4534
if "collect_function_sources(" not in module_text:
@@ -62,7 +51,7 @@ def test_ast_guard_inventory_stays_in_sync_with_helper_imports():
6251
if normalized_path in excluded_modules:
6352
continue
6453
module_text = module_path.read_text(encoding="utf-8")
65-
if not _imports_collect_function_sources(module_text):
54+
if not imports_collect_function_sources(module_text):
6655
continue
6756
discovered_modules.append(normalized_path)
6857

tests/test_ast_function_source_import_boundary.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
import ast
21
from pathlib import Path
32

43
import pytest
54

5+
from tests.ast_import_utils import imports_collect_function_sources
66
from tests.test_ast_function_source_helper_usage import (
77
AST_FUNCTION_SOURCE_GUARD_MODULES,
88
)
@@ -12,23 +12,12 @@
1212

1313
EXPECTED_EXTRA_IMPORTER_MODULES = ("tests/test_ast_function_source_utils.py",)
1414

15-
def _imports_collect_function_sources(module_text: str) -> bool:
16-
module_ast = ast.parse(module_text)
17-
for node in module_ast.body:
18-
if not isinstance(node, ast.ImportFrom):
19-
continue
20-
if node.module != "tests.ast_function_source_utils":
21-
continue
22-
if any(alias.name == "collect_function_sources" for alias in node.names):
23-
return True
24-
return False
25-
2615

2716
def test_ast_function_source_helper_imports_are_centralized():
2817
discovered_modules: list[str] = []
2918
for module_path in sorted(Path("tests").glob("test_*.py")):
3019
module_text = module_path.read_text(encoding="utf-8")
31-
if not _imports_collect_function_sources(module_text):
20+
if not imports_collect_function_sources(module_text):
3221
continue
3322
discovered_modules.append(module_path.as_posix())
3423

0 commit comments

Comments
 (0)