diff --git a/src/community_of_python_flake8_plugin/checks/final_class.py b/src/community_of_python_flake8_plugin/checks/final_class.py index 1da75b2..5842e9b 100644 --- a/src/community_of_python_flake8_plugin/checks/final_class.py +++ b/src/community_of_python_flake8_plugin/checks/final_class.py @@ -40,9 +40,24 @@ def is_model_factory_class(class_node: ast.ClassDef) -> bool: return check_inherits_from_bases(class_node, {"ModelFactory", "SQLAlchemyFactory"}) +def has_local_subclasses(syntax_tree: ast.AST, class_node: ast.ClassDef) -> bool: + """Check if there are classes in the same file that inherit from this class.""" + for one_node in ast.walk(syntax_tree): + if isinstance(one_node, ast.ClassDef) and one_node != class_node: + for one_base in one_node.bases: + # Check for direct class reference: class Child(Parent): + if isinstance(one_base, ast.Name) and one_base.id == class_node.name: + return True + # Check for attributed class reference: class Child(module.Parent): + if isinstance(one_base, ast.Attribute) and one_base.attr == class_node.name: + return True + return False + + @typing.final class FinalClassCheck(ast.NodeVisitor): - def __init__(self, syntax_tree: ast.AST) -> None: # noqa: ARG002 + def __init__(self, syntax_tree: ast.AST) -> None: + self.syntax_tree = syntax_tree self.violations: list[Violation] = [] def visit_ClassDef(self, ast_node: ast.ClassDef) -> None: @@ -54,6 +69,10 @@ def _check_final_decorator(self, ast_node: ast.ClassDef) -> None: if is_protocol_class(ast_node) or ast_node.name.startswith("Test") or is_model_factory_class(ast_node): return + # If there are classes in this file that inherit from this class, don't require the decorator + if has_local_subclasses(self.syntax_tree, ast_node): + return + if not contains_final_decorator(ast_node): self.violations.append( Violation( diff --git a/src/community_of_python_flake8_plugin/checks/for_loop_one_prefix.py b/src/community_of_python_flake8_plugin/checks/for_loop_one_prefix.py index 4f46dce..91ce668 100644 --- a/src/community_of_python_flake8_plugin/checks/for_loop_one_prefix.py +++ b/src/community_of_python_flake8_plugin/checks/for_loop_one_prefix.py @@ -22,14 +22,14 @@ def visit_ListComp(self, ast_node: ast.ListComp) -> None: # Validate targets in generators (the 'v' in 'for v in lst') for one_comprehension in ast_node.generators: if not self._is_partial_unpacking(ast_node.elt, one_comprehension.target): - self._validate_comprehension_target(one_comprehension.target) + self._validate_comprehension_target(one_comprehension.target, one_comprehension.iter) self.generic_visit(ast_node) def visit_SetComp(self, ast_node: ast.SetComp) -> None: # Validate targets in generators (the 'v' in 'for v in lst') for one_comprehension in ast_node.generators: if not self._is_partial_unpacking(ast_node.elt, one_comprehension.target): - self._validate_comprehension_target(one_comprehension.target) + self._validate_comprehension_target(one_comprehension.target, one_comprehension.iter) self.generic_visit(ast_node) def visit_DictComp(self, ast_node: ast.DictComp) -> None: @@ -38,7 +38,7 @@ def visit_DictComp(self, ast_node: ast.DictComp) -> None: # key and value are both used for one_comprehension in ast_node.generators: if not self._is_partial_unpacking_expr_count(2, one_comprehension.target): - self._validate_comprehension_target(one_comprehension.target) + self._validate_comprehension_target(one_comprehension.target, one_comprehension.iter) self.generic_visit(ast_node) def visit_For(self, ast_node: ast.For) -> None: @@ -46,14 +46,14 @@ def visit_For(self, ast_node: ast.For) -> None: # Apply same unpacking logic as comprehensions # For-loops don't have an expression that references vars if not self._is_partial_unpacking_expr_count(1, ast_node.target): - self._validate_comprehension_target(ast_node.target) + self._validate_comprehension_target(ast_node.target, ast_node.iter) self.generic_visit(ast_node) def visit_GeneratorExp(self, ast_node: ast.GeneratorExp) -> None: # Validate targets in generators (the 'v' in 'for v in lst') for one_comprehension in ast_node.generators: if not self._is_partial_unpacking(ast_node.elt, one_comprehension.target): - self._validate_comprehension_target(one_comprehension.target) + self._validate_comprehension_target(one_comprehension.target, one_comprehension.iter) self.generic_visit(ast_node) def _is_partial_unpacking(self, expression: ast.expr, target_node: ast.expr) -> bool: @@ -65,6 +65,22 @@ def _is_partial_unpacking_expr_count(self, expression_count: int, target_node: a target_count: typing.Final = self._count_unpacked_vars(target_node) return target_count > expression_count and target_count > 1 + def _is_literal_range(self, iter_node: ast.expr) -> bool: + """Check if the iteration is over a literal range() call.""" + # Check for direct range() call + if isinstance(iter_node, ast.Call) and isinstance(iter_node.func, ast.Name) and iter_node.func.id == "range": + # Check if all arguments are literals (no variables) + for one_arg in iter_node.args: + if not isinstance(one_arg, (ast.Constant, ast.UnaryOp)): + # If any argument is not a literal, this is not a literal range + # Note: UnaryOp is included to handle negative numbers like -1 + return False + # For UnaryOp (like -1), check if operand is a literal + if isinstance(one_arg, ast.UnaryOp) and not isinstance(one_arg.operand, ast.Constant): + return False + return True + return False + def _count_referenced_vars(self, expression: ast.expr) -> int: """Count how many variables are referenced in the expression.""" if isinstance(expression, ast.Name): @@ -83,8 +99,12 @@ def _count_unpacked_vars(self, target_node: ast.expr) -> int: return len([one_element for one_element in target_node.elts if isinstance(one_element, ast.Name)]) return 0 - def _validate_comprehension_target(self, target_node: ast.expr) -> None: + def _validate_comprehension_target(self, target_node: ast.expr, iter_node: ast.expr | None = None) -> None: """Validate that comprehension target follows the one_ prefix rule.""" + # Skip validation if iterating over literal range() + if iter_node is not None and self._is_literal_range(iter_node): + return + # Skip ignored targets (underscore, unpacking) if _is_ignored_target(target_node): return diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 71bb976..ca06ef1 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -422,6 +422,44 @@ def test_variable_usage_validations(input_source: str, expected_output: list[str "import typing\nclass MyProtocol(typing.Protocol, object):\n def fetch_value(self) -> int: ...\n", [], ), + # No violation: Child classes don't require @typing.final since they inherit from other classes + ( + "class ParentClass:\n pass\n\nclass ChildClass(ParentClass):\n pass", + ["COP012"], # Only ParentClass should require final decorator, ChildClass inherits so it's exempt + ), + # No violation: Multiple levels of inheritance + ( + "class GrandParentClass:\n pass\n\nclass ParentClass(GrandParentClass):\n pass\n\n" + "class ChildClass(ParentClass):\n pass", + ["COP012"], # Only GrandParentClass requires final decorator, ParentClass and + # ChildClass inherit so they're exempt + ), + # No violation: Child classes with module notation don't require @typing.final + ( + "class ParentClass:\n pass\n\nclass ChildClass(module.ParentClass):\n pass", + ["COP012"], # Only ParentClass should require final decorator, ChildClass inherits so it's exempt + ), + # No violation: Child class properly inherits, parent doesn't need final decorator + ( + "import typing\n\nclass ParentClass:\n pass\n\n@typing.final\nclass ChildClass(ParentClass):\n pass", + [], # No violations - ChildClass is properly marked final + ), + # No violation: Complex inheritance hierarchy with proper final decorators + ( + "import typing\n\n" + "class BaseClass:\n pass\n\n" + "class MiddleClass(BaseClass):\n pass\n\n" + "@typing.final\nclass DerivedClass(MiddleClass):\n pass", + [], # No violations - derived classes are properly marked final + ), + # No violation: Multiple inheritance with proper final decorators + ( + "import typing\n\n" + "class FirstParent:\n pass\n\n" + "class SecondParent:\n pass\n\n" + "@typing.final\nclass ChildClass(FirstParent, SecondParent):\n pass", + [], # No violations - ChildClass is properly marked final + ), ], ) def test_class_validations(input_source: str, expected_output: list[str]) -> None: @@ -605,6 +643,28 @@ def test_dataclass_validations(input_source: str, expected_output: list[str]) -> ("for x, y in pairs: pass", []), # No violation: Regular for-loop with one_ prefix ("for one_x in some_list: pass", []), + # No violation: Regular for-loop over literal range() without one_ prefix + ("for cur_number in range(10): pass", []), + # No violation: Regular for-loop over literal range() with start and stop + ("for cur_number in range(5, 10): pass", []), + # No violation: Regular for-loop over literal range() with start, stop, and step + ("for cur_number in range(0, 10, 2): pass", []), + # No violation: Regular for-loop over literal range() with negative values + ("for cur_number in range(-5, 5): pass", []), + # COP015: Regular for-loop over non-literal range() should still require one_ prefix + ("for cur_number in range(some_variable): pass", ["COP015"]), + # COP015: Regular for-loop over non-literal range() with multiple variables should still require one_ prefix + ("for cur_number in range(start, stop): pass", ["COP015"]), + # No violation: List comprehension over literal range() without one_ prefix + ("my_result = [cur_number for cur_number in range(10)]", []), + # COP015: List comprehension over non-literal range() should still require one_ prefix + ("my_result = [cur_number for cur_number in range(variable)]", ["COP015"]), + # No violation: Set comprehension over literal range() without one_ prefix + ("my_result = {cur_number for cur_number in range(10)}", []), + # No violation: Dict comprehension over literal range() without one_ prefix + ("my_result = {cur_number: cur_number for cur_number in range(10)}", []), + # No violation: Generator expression over literal range() without one_ prefix + ("my_result = (cur_number for cur_number in range(10))", []), ], ) def test_module_vs_class_level_assignments(input_source: str, expected_output: list[str]) -> None: