Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/community_of_python_flake8_plugin/checks/final_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,24 @@ def is_model_factory_class(class_node: ast.ClassDef) -> bool:
return check_inherits_from_bases(class_node, {"ModelFactory", "SQLAlchemyFactory"})


def has_local_subclasses(syntax_tree: ast.AST, class_node: ast.ClassDef) -> bool:
"""Check if there are classes in the same file that inherit from this class."""
for one_node in ast.walk(syntax_tree):
if isinstance(one_node, ast.ClassDef) and one_node != class_node:
for one_base in one_node.bases:
# Check for direct class reference: class Child(Parent):
if isinstance(one_base, ast.Name) and one_base.id == class_node.name:
return True
# Check for attributed class reference: class Child(module.Parent):
if isinstance(one_base, ast.Attribute) and one_base.attr == class_node.name:
return True
return False


@typing.final
class FinalClassCheck(ast.NodeVisitor):
def __init__(self, syntax_tree: ast.AST) -> None: # noqa: ARG002
def __init__(self, syntax_tree: ast.AST) -> None:
self.syntax_tree = syntax_tree
self.violations: list[Violation] = []

def visit_ClassDef(self, ast_node: ast.ClassDef) -> None:
Expand All @@ -54,6 +69,10 @@ def _check_final_decorator(self, ast_node: ast.ClassDef) -> None:
if is_protocol_class(ast_node) or ast_node.name.startswith("Test") or is_model_factory_class(ast_node):
return

# If there are classes in this file that inherit from this class, don't require the decorator
if has_local_subclasses(self.syntax_tree, ast_node):
return

if not contains_final_decorator(ast_node):
self.violations.append(
Violation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ def visit_ListComp(self, ast_node: ast.ListComp) -> None:
# Validate targets in generators (the 'v' in 'for v in lst')
for one_comprehension in ast_node.generators:
if not self._is_partial_unpacking(ast_node.elt, one_comprehension.target):
self._validate_comprehension_target(one_comprehension.target)
self._validate_comprehension_target(one_comprehension.target, one_comprehension.iter)
self.generic_visit(ast_node)

def visit_SetComp(self, ast_node: ast.SetComp) -> None:
# Validate targets in generators (the 'v' in 'for v in lst')
for one_comprehension in ast_node.generators:
if not self._is_partial_unpacking(ast_node.elt, one_comprehension.target):
self._validate_comprehension_target(one_comprehension.target)
self._validate_comprehension_target(one_comprehension.target, one_comprehension.iter)
self.generic_visit(ast_node)

def visit_DictComp(self, ast_node: ast.DictComp) -> None:
Expand All @@ -38,22 +38,22 @@ def visit_DictComp(self, ast_node: ast.DictComp) -> None:
# key and value are both used
for one_comprehension in ast_node.generators:
if not self._is_partial_unpacking_expr_count(2, one_comprehension.target):
self._validate_comprehension_target(one_comprehension.target)
self._validate_comprehension_target(one_comprehension.target, one_comprehension.iter)
self.generic_visit(ast_node)

def visit_For(self, ast_node: ast.For) -> None:
# Validate target variables in regular for-loops
# Apply same unpacking logic as comprehensions
# For-loops don't have an expression that references vars
if not self._is_partial_unpacking_expr_count(1, ast_node.target):
self._validate_comprehension_target(ast_node.target)
self._validate_comprehension_target(ast_node.target, ast_node.iter)
self.generic_visit(ast_node)

def visit_GeneratorExp(self, ast_node: ast.GeneratorExp) -> None:
# Validate targets in generators (the 'v' in 'for v in lst')
for one_comprehension in ast_node.generators:
if not self._is_partial_unpacking(ast_node.elt, one_comprehension.target):
self._validate_comprehension_target(one_comprehension.target)
self._validate_comprehension_target(one_comprehension.target, one_comprehension.iter)
self.generic_visit(ast_node)

def _is_partial_unpacking(self, expression: ast.expr, target_node: ast.expr) -> bool:
Expand All @@ -65,6 +65,22 @@ def _is_partial_unpacking_expr_count(self, expression_count: int, target_node: a
target_count: typing.Final = self._count_unpacked_vars(target_node)
return target_count > expression_count and target_count > 1

def _is_literal_range(self, iter_node: ast.expr) -> bool:
"""Check if the iteration is over a literal range() call."""
# Check for direct range() call
if isinstance(iter_node, ast.Call) and isinstance(iter_node.func, ast.Name) and iter_node.func.id == "range":
# Check if all arguments are literals (no variables)
for one_arg in iter_node.args:
if not isinstance(one_arg, (ast.Constant, ast.UnaryOp)):
# If any argument is not a literal, this is not a literal range
# Note: UnaryOp is included to handle negative numbers like -1
return False
# For UnaryOp (like -1), check if operand is a literal
if isinstance(one_arg, ast.UnaryOp) and not isinstance(one_arg.operand, ast.Constant):
return False
return True
return False

def _count_referenced_vars(self, expression: ast.expr) -> int:
"""Count how many variables are referenced in the expression."""
if isinstance(expression, ast.Name):
Expand All @@ -83,8 +99,12 @@ def _count_unpacked_vars(self, target_node: ast.expr) -> int:
return len([one_element for one_element in target_node.elts if isinstance(one_element, ast.Name)])
return 0

def _validate_comprehension_target(self, target_node: ast.expr) -> None:
def _validate_comprehension_target(self, target_node: ast.expr, iter_node: ast.expr | None = None) -> None:
"""Validate that comprehension target follows the one_ prefix rule."""
# Skip validation if iterating over literal range()
if iter_node is not None and self._is_literal_range(iter_node):
return

# Skip ignored targets (underscore, unpacking)
if _is_ignored_target(target_node):
return
Expand Down
60 changes: 60 additions & 0 deletions tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,44 @@ def test_variable_usage_validations(input_source: str, expected_output: list[str
"import typing\nclass MyProtocol(typing.Protocol, object):\n def fetch_value(self) -> int: ...\n",
[],
),
# No violation: Child classes don't require @typing.final since they inherit from other classes
(
"class ParentClass:\n pass\n\nclass ChildClass(ParentClass):\n pass",
["COP012"], # Only ParentClass should require final decorator, ChildClass inherits so it's exempt
),
# No violation: Multiple levels of inheritance
(
"class GrandParentClass:\n pass\n\nclass ParentClass(GrandParentClass):\n pass\n\n"
"class ChildClass(ParentClass):\n pass",
["COP012"], # Only GrandParentClass requires final decorator, ParentClass and
# ChildClass inherit so they're exempt
),
# No violation: Child classes with module notation don't require @typing.final
(
"class ParentClass:\n pass\n\nclass ChildClass(module.ParentClass):\n pass",
["COP012"], # Only ParentClass should require final decorator, ChildClass inherits so it's exempt
),
# No violation: Child class properly inherits, parent doesn't need final decorator
(
"import typing\n\nclass ParentClass:\n pass\n\n@typing.final\nclass ChildClass(ParentClass):\n pass",
[], # No violations - ChildClass is properly marked final
),
# No violation: Complex inheritance hierarchy with proper final decorators
(
"import typing\n\n"
"class BaseClass:\n pass\n\n"
"class MiddleClass(BaseClass):\n pass\n\n"
"@typing.final\nclass DerivedClass(MiddleClass):\n pass",
[], # No violations - derived classes are properly marked final
),
# No violation: Multiple inheritance with proper final decorators
(
"import typing\n\n"
"class FirstParent:\n pass\n\n"
"class SecondParent:\n pass\n\n"
"@typing.final\nclass ChildClass(FirstParent, SecondParent):\n pass",
[], # No violations - ChildClass is properly marked final
),
],
)
def test_class_validations(input_source: str, expected_output: list[str]) -> None:
Expand Down Expand Up @@ -605,6 +643,28 @@ def test_dataclass_validations(input_source: str, expected_output: list[str]) ->
("for x, y in pairs: pass", []),
# No violation: Regular for-loop with one_ prefix
("for one_x in some_list: pass", []),
# No violation: Regular for-loop over literal range() without one_ prefix
("for cur_number in range(10): pass", []),
# No violation: Regular for-loop over literal range() with start and stop
("for cur_number in range(5, 10): pass", []),
# No violation: Regular for-loop over literal range() with start, stop, and step
("for cur_number in range(0, 10, 2): pass", []),
# No violation: Regular for-loop over literal range() with negative values
("for cur_number in range(-5, 5): pass", []),
# COP015: Regular for-loop over non-literal range() should still require one_ prefix
("for cur_number in range(some_variable): pass", ["COP015"]),
# COP015: Regular for-loop over non-literal range() with multiple variables should still require one_ prefix
("for cur_number in range(start, stop): pass", ["COP015"]),
# No violation: List comprehension over literal range() without one_ prefix
("my_result = [cur_number for cur_number in range(10)]", []),
# COP015: List comprehension over non-literal range() should still require one_ prefix
("my_result = [cur_number for cur_number in range(variable)]", ["COP015"]),
# No violation: Set comprehension over literal range() without one_ prefix
("my_result = {cur_number for cur_number in range(10)}", []),
# No violation: Dict comprehension over literal range() without one_ prefix
("my_result = {cur_number: cur_number for cur_number in range(10)}", []),
# No violation: Generator expression over literal range() without one_ prefix
("my_result = (cur_number for cur_number in range(10))", []),
],
)
def test_module_vs_class_level_assignments(input_source: str, expected_output: list[str]) -> None:
Expand Down