From 586ceae588976369c069f07232be714d2501ae95 Mon Sep 17 00:00:00 2001 From: Gideon Zenz <91069374+gzenz@users.noreply.github.com> Date: Tue, 14 Apr 2026 17:26:18 +0200 Subject: [PATCH] feat: extract per-language logic into handler classes + parser improvements Refactor parser.py to use a strategy pattern for language-specific logic: - Add BaseLanguageHandler with 7-method interface - 21 handler implementations in code_review_graph/lang/ - Thread-safe handler registration and type-set caching Parser accuracy improvements: - Typed variable call resolution for Python, Kotlin, Java, JS/TS - Star import expansion via __all__ or module-level definitions - Method call noise filtering via _INSTANCE_METHOD_BLOCKLIST - Function/class reference detection in call args and return values - Angular .component.html template parsing - Dart call extraction for non-standard AST structure --- code_review_graph/lang/__init__.py | 56 + code_review_graph/lang/_base.py | 62 + code_review_graph/lang/_c_cpp.py | 41 + code_review_graph/lang/_csharp.py | 33 + code_review_graph/lang/_dart.py | 65 + code_review_graph/lang/_go.py | 73 + code_review_graph/lang/_java.py | 30 + code_review_graph/lang/_javascript.py | 304 ++ code_review_graph/lang/_kotlin.py | 24 + code_review_graph/lang/_lua.py | 314 ++ code_review_graph/lang/_perl.py | 24 + code_review_graph/lang/_php.py | 13 + code_review_graph/lang/_python.py | 109 + code_review_graph/lang/_r.py | 339 ++ code_review_graph/lang/_ruby.py | 23 + code_review_graph/lang/_rust.py | 22 + code_review_graph/lang/_scala.py | 54 + code_review_graph/lang/_solidity.py | 222 ++ code_review_graph/lang/_swift.py | 13 + code_review_graph/parser.py | 3317 ++++++++++------- tests/fixtures/android_lifecycle.kt | 33 + tests/fixtures/express_routes.ts | 24 + tests/fixtures/js_namespace_import.ts | 6 + tests/fixtures/js_reexport.ts | 2 + tests/fixtures/js_require.js | 8 + tests/fixtures/jsx_handler_refs.tsx | 32 + tests/fixtures/resolution_java_import.java | 21 + 
tests/fixtures/resolution_kotlin_import.kt | 20 + .../resolution_python_module_import.py | 16 + .../fixtures/resolution_python_star_import.py | 11 + tests/fixtures/resolution_ts_cross_file.ts | 17 + tests/fixtures/sample.kt | 9 + tests/fixtures/sample.swift | 29 - tests/fixtures/servlet_handler.java | 26 + tests/test_multilang.py | 248 +- tests/test_parser.py | 655 +++- 36 files changed, 4804 insertions(+), 1491 deletions(-) create mode 100644 code_review_graph/lang/__init__.py create mode 100644 code_review_graph/lang/_base.py create mode 100644 code_review_graph/lang/_c_cpp.py create mode 100644 code_review_graph/lang/_csharp.py create mode 100644 code_review_graph/lang/_dart.py create mode 100644 code_review_graph/lang/_go.py create mode 100644 code_review_graph/lang/_java.py create mode 100644 code_review_graph/lang/_javascript.py create mode 100644 code_review_graph/lang/_kotlin.py create mode 100644 code_review_graph/lang/_lua.py create mode 100644 code_review_graph/lang/_perl.py create mode 100644 code_review_graph/lang/_php.py create mode 100644 code_review_graph/lang/_python.py create mode 100644 code_review_graph/lang/_r.py create mode 100644 code_review_graph/lang/_ruby.py create mode 100644 code_review_graph/lang/_rust.py create mode 100644 code_review_graph/lang/_scala.py create mode 100644 code_review_graph/lang/_solidity.py create mode 100644 code_review_graph/lang/_swift.py create mode 100644 tests/fixtures/android_lifecycle.kt create mode 100644 tests/fixtures/express_routes.ts create mode 100644 tests/fixtures/js_namespace_import.ts create mode 100644 tests/fixtures/js_reexport.ts create mode 100644 tests/fixtures/js_require.js create mode 100644 tests/fixtures/jsx_handler_refs.tsx create mode 100644 tests/fixtures/resolution_java_import.java create mode 100644 tests/fixtures/resolution_kotlin_import.kt create mode 100644 tests/fixtures/resolution_python_module_import.py create mode 100644 tests/fixtures/resolution_python_star_import.py create 
mode 100644 tests/fixtures/resolution_ts_cross_file.ts create mode 100644 tests/fixtures/servlet_handler.java diff --git a/code_review_graph/lang/__init__.py b/code_review_graph/lang/__init__.py new file mode 100644 index 00000000..80b85b76 --- /dev/null +++ b/code_review_graph/lang/__init__.py @@ -0,0 +1,56 @@ +"""Per-language parsing handlers.""" + +from ._base import BaseLanguageHandler +from ._c_cpp import CHandler, CppHandler +from ._csharp import CSharpHandler +from ._dart import DartHandler +from ._go import GoHandler +from ._java import JavaHandler +from ._javascript import JavaScriptHandler, TsxHandler, TypeScriptHandler +from ._kotlin import KotlinHandler +from ._lua import LuaHandler, LuauHandler +from ._perl import PerlHandler +from ._php import PhpHandler +from ._python import PythonHandler +from ._r import RHandler +from ._ruby import RubyHandler +from ._rust import RustHandler +from ._scala import ScalaHandler +from ._solidity import SolidityHandler +from ._swift import SwiftHandler + +ALL_HANDLERS: list[BaseLanguageHandler] = [ + GoHandler(), + PythonHandler(), + JavaScriptHandler(), + TypeScriptHandler(), + TsxHandler(), + RustHandler(), + CHandler(), + CppHandler(), + JavaHandler(), + CSharpHandler(), + KotlinHandler(), + ScalaHandler(), + SolidityHandler(), + RubyHandler(), + DartHandler(), + SwiftHandler(), + PhpHandler(), + PerlHandler(), + RHandler(), + LuaHandler(), + LuauHandler(), +] + +__all__ = [ + "BaseLanguageHandler", "ALL_HANDLERS", + "GoHandler", "PythonHandler", + "JavaScriptHandler", "TypeScriptHandler", "TsxHandler", + "RustHandler", "CHandler", "CppHandler", + "JavaHandler", "CSharpHandler", "KotlinHandler", + "ScalaHandler", "SolidityHandler", + "RubyHandler", "DartHandler", + "SwiftHandler", "PhpHandler", "PerlHandler", + "RHandler", "LuaHandler", "LuauHandler", +] diff --git a/code_review_graph/lang/_base.py b/code_review_graph/lang/_base.py new file mode 100644 index 00000000..fb2ddca0 --- /dev/null +++ 
class BaseLanguageHandler:
    """Strategy base class for per-language parsing behaviour.

    Every hook below either does language-specific work or tells the caller
    to use the generic ``CodeParser`` logic.  A hook that returns
    ``NotImplemented`` (or ``False`` for the boolean hooks) means "not
    customised here -- use the default code path", so a subclass overrides
    only the hooks where its language really differs.
    """

    # Language identifier; subclasses set it (e.g. "go", "java", "dart").
    language: str = ""
    # Syntax-tree node type names, grouped by the construct they represent.
    # Subclasses replace these with their own lists rather than mutating them.
    class_types: list[str] = []
    function_types: list[str] = []
    import_types: list[str] = []
    call_types: list[str] = []
    # Names a language treats as builtins (empty unless a subclass sets it).
    builtin_names: frozenset[str] = frozenset()

    def get_name(self, node, kind: str) -> str | None:
        """Return the declared name for ``node``.

        The annotation describes what an override returns; this base
        implementation always returns the ``NotImplemented`` sentinel so the
        caller uses its default name lookup.
        """
        return NotImplemented

    def get_bases(self, node, source: bytes) -> list[str]:
        """Return base/super type names for a class-like ``node``.

        Base implementation: ``NotImplemented`` (use the default logic).
        """
        return NotImplemented

    def extract_import_targets(self, node, source: bytes) -> list[str]:
        """Return the targets named by an import ``node``.

        Base implementation: ``NotImplemented`` (use the default logic).
        """
        return NotImplemented

    def collect_import_names(
        self,
        node,
        file_path: str,
        import_map: dict[str, str],
    ) -> bool:
        """Populate ``import_map`` from an import node.

        Returns:
            ``True`` if this handler dealt with the node.  The base
            implementation leaves ``import_map`` untouched and returns
            ``False``.
        """
        return False

    def resolve_module(self, module: str, caller_file: str) -> str | None:
        """Resolve a module path to a file path.

        Base implementation: ``NotImplemented``, which asks the caller to
        fall back to its own resolution.
        """
        return NotImplemented

    def extract_constructs(
        self,
        child,
        node_type: str,
        parser: CodeParser,
        source: bytes,
        file_path: str,
        nodes: list[NodeInfo],
        edges: list[EdgeInfo],
        enclosing_class: str | None,
        enclosing_func: str | None,
        import_map: dict[str, str] | None,
        defined_names: set[str] | None,
        depth: int,
    ) -> bool:
        """Handle language-specific AST constructs for ``child``.

        Overrides may append to ``nodes`` and ``edges`` in place.

        Returns:
            ``True`` if ``child`` was fully handled, in which case generic
            dispatch is skipped.  The base implementation does nothing and
            returns ``False``.
        """
        return False
class DartHandler(BaseLanguageHandler):
    """Dart-specific naming, import and inheritance extraction."""

    language = "dart"
    class_types = ["class_definition", "mixin_declaration", "enum_declaration"]
    # ``function_signature`` serves top-level functions and class methods
    # alike: methods show up as method_signature > function_signature, the
    # parser descends into method_signature generically and then matches the
    # function_signature nested inside it.
    function_types = ["function_signature"]
    # import_or_export wraps
    # library_import > import_specification > configurable_uri
    import_types = ["import_or_export"]
    call_types: list[str] = []  # Dart uses call_expression from fallback

    def get_name(self, node, kind: str) -> str | None:
        """Return the function name for a ``function_signature`` node.

        A return-type node precedes the identifier in a function_signature,
        so only a direct ``identifier`` child is accepted; this avoids
        reporting the return type as the name.  Returns ``None`` when no such
        child exists, and ``NotImplemented`` for every other node type so the
        default lookup is used.
        """
        if node.type != "function_signature":
            return NotImplemented
        ident = next(
            (c for c in node.children if c.type == "identifier"),
            None,
        )
        if ident is None:
            return None
        return ident.text.decode("utf-8", errors="replace")

    def extract_import_targets(self, node, source: bytes) -> list[str]:
        """Return the imported URI as a one-element list, or ``[]``.

        An empty URI is treated the same as a missing one.
        """
        uri = self._find_string_literal(node)
        return [uri] if uri else []

    @staticmethod
    def _find_string_literal(node) -> Optional[str]:
        """Return the first ``string_literal`` under ``node``, unquoted.

        Nodes are visited in pre-order (a node before its children, children
        left to right), starting with ``node`` itself.  Surrounding single
        and double quote characters are stripped from the literal's text.
        Returns ``None`` when the subtree has no string literal.
        """
        pending = [node]
        while pending:
            current = pending.pop()
            if current.type == "string_literal":
                return current.text.decode("utf-8", errors="replace").strip("'\"")
            # Push children reversed so the leftmost child is popped first.
            pending.extend(reversed(current.children))
        return None

    def get_bases(self, node, source: bytes) -> list[str]:
        """Return superclass, mixin and interface names, in source order.

        Looks at ``superclass`` children (their direct ``type_identifier``
        children plus those nested in a ``mixins`` child) and at
        ``interfaces`` children (their direct ``type_identifier`` children).
        """

        def type_names(parent) -> list[str]:
            # Text of every direct type_identifier child of ``parent``.
            return [
                c.text.decode("utf-8", errors="replace")
                for c in parent.children
                if c.type == "type_identifier"
            ]

        found: list[str] = []
        for clause in node.children:
            if clause.type == "superclass":
                for part in clause.children:
                    if part.type == "type_identifier":
                        found.append(part.text.decode("utf-8", errors="replace"))
                    elif part.type == "mixins":
                        found.extend(type_names(part))
            elif clause.type == "interfaces":
                found.extend(type_names(clause))
        return found
wraps type_spec which holds the identifier + if node.type == "type_declaration": + for child in node.children: + if child.type == "type_spec": + for sub in child.children: + if sub.type in ("identifier", "name", "type_identifier"): + return sub.text.decode("utf-8", errors="replace") + return None + return NotImplemented # fall back to default for function_declaration etc. + + def get_bases(self, node, source: bytes) -> list[str]: + # Embedded structs / interface composition + # Embedded fields are field_declaration nodes with only a type_identifier + # (no field name), e.g. `type Child struct { Parent }` + bases = [] + for child in node.children: + if child.type == "type_spec": + for sub in child.children: + if sub.type in ("struct_type", "interface_type"): + for field_node in sub.children: + if field_node.type == "field_declaration_list": + for f in field_node.children: + if f.type == "field_declaration": + children = [ + c for c in f.children + if c.type not in ("comment",) + ] + if ( + len(children) == 1 + and children[0].type == "type_identifier" + ): + bases.append( + children[0].text.decode( + "utf-8", errors="replace", + ) + ) + return bases + + def extract_import_targets(self, node, source: bytes) -> list[str]: + imports = [] + for child in node.children: + if child.type == "import_spec_list": + for spec in child.children: + if spec.type == "import_spec": + for s in spec.children: + if s.type == "interpreted_string_literal": + val = s.text.decode("utf-8", errors="replace") + imports.append(val.strip('"')) + elif child.type == "import_spec": + for s in child.children: + if s.type == "interpreted_string_literal": + val = s.text.decode("utf-8", errors="replace") + imports.append(val.strip('"')) + return imports diff --git a/code_review_graph/lang/_java.py b/code_review_graph/lang/_java.py new file mode 100644 index 00000000..08849574 --- /dev/null +++ b/code_review_graph/lang/_java.py @@ -0,0 +1,30 @@ +"""Java language handler.""" + +from __future__ import 
annotations + +from ._base import BaseLanguageHandler + + +class JavaHandler(BaseLanguageHandler): + language = "java" + class_types = ["class_declaration", "interface_declaration", "enum_declaration"] + function_types = ["method_declaration", "constructor_declaration"] + import_types = ["import_declaration"] + call_types = ["method_invocation", "object_creation_expression"] + + def extract_import_targets(self, node, source: bytes) -> list[str]: + text = node.text.decode("utf-8", errors="replace").strip() + parts = text.split() + if len(parts) >= 2: + return [parts[-1].rstrip(";")] + return [] + + def get_bases(self, node, source: bytes) -> list[str]: + bases = [] + for child in node.children: + if child.type in ( + "superclass", "super_interfaces", "extends_type", + "implements_type", "type_identifier", "supertype", + ): + bases.append(child.text.decode("utf-8", errors="replace")) + return bases diff --git a/code_review_graph/lang/_javascript.py b/code_review_graph/lang/_javascript.py new file mode 100644 index 00000000..5e565f81 --- /dev/null +++ b/code_review_graph/lang/_javascript.py @@ -0,0 +1,304 @@ +"""JavaScript / TypeScript / TSX language handler.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +from ..parser import EdgeInfo, NodeInfo, _is_test_function +from ._base import BaseLanguageHandler + +if TYPE_CHECKING: + from ..parser import CodeParser + + +class _JsTsBase(BaseLanguageHandler): + """Shared handler logic for JS, TS, and TSX.""" + + class_types = ["class_declaration", "class"] + function_types = ["function_declaration", "method_definition", "arrow_function"] + import_types = ["import_statement"] + # No builtin_names -- JS/TS builtins are not filtered + + _JS_FUNC_VALUE_TYPES = frozenset( + {"arrow_function", "function_expression", "function"}, + ) + + def get_bases(self, node, source: bytes) -> list[str]: + bases = [] + for child in node.children: + if child.type in ("extends_clause", "implements_clause"): 
+ for sub in child.children: + if sub.type in ("identifier", "type_identifier", "nested_identifier"): + bases.append(sub.text.decode("utf-8", errors="replace")) + return bases + + def extract_import_targets(self, node, source: bytes) -> list[str]: + imports = [] + for child in node.children: + if child.type == "string": + val = child.text.decode("utf-8", errors="replace").strip("'\"") + imports.append(val) + return imports + + def extract_constructs( + self, + child, + node_type: str, + parser: CodeParser, + source: bytes, + file_path: str, + nodes: list[NodeInfo], + edges: list[EdgeInfo], + enclosing_class: str | None, + enclosing_func: str | None, + import_map: dict[str, str] | None, + defined_names: set[str] | None, + depth: int, + ) -> bool: + # --- Variable-assigned functions (const foo = () => {}) --- + if node_type in ("lexical_declaration", "variable_declaration"): + if self._extract_var_functions( + child, source, parser, file_path, nodes, edges, + enclosing_class, enclosing_func, + import_map, defined_names, depth, + ): + return True + + # --- Class field arrow functions (handler = () => {}) --- + if node_type == "public_field_definition": + if self._extract_field_function( + child, source, parser, file_path, nodes, edges, + enclosing_class, enclosing_func, + import_map, defined_names, depth, + ): + return True + + # --- Re-exports: export { X } from './mod', export * from './mod' --- + if node_type == "export_statement": + self._extract_reexport_edges(child, parser, file_path, edges) + # Don't return True -- export_statement may also contain definitions + return False + + return False + + # ------------------------------------------------------------------ + # Extraction helpers + # ------------------------------------------------------------------ + + def _extract_var_functions( + self, + child, + source: bytes, + parser: CodeParser, + file_path: str, + nodes: list[NodeInfo], + edges: list[EdgeInfo], + enclosing_class: Optional[str], + enclosing_func: 
Optional[str], + import_map: Optional[dict[str, str]], + defined_names: Optional[set[str]], + _depth: int, + ) -> bool: + """Handle JS/TS variable declarations that assign functions. + + Patterns handled: + const foo = () => {} + let bar = function() {} + export const baz = (x: number): string => x.toString() + + Returns True if at least one function was extracted from the + declaration, so the caller can skip generic recursion. + """ + language = self.language + handled = False + for declarator in child.children: + if declarator.type != "variable_declarator": + continue + + # Find identifier and function value + var_name = None + func_node = None + for sub in declarator.children: + if sub.type == "identifier" and var_name is None: + var_name = sub.text.decode("utf-8", errors="replace") + elif sub.type in self._JS_FUNC_VALUE_TYPES: + func_node = sub + + if not var_name or not func_node: + continue + + is_test = _is_test_function(var_name, file_path) + kind = "Test" if is_test else "Function" + qualified = parser._qualify(var_name, file_path, enclosing_class) + params = parser._get_params(func_node, language, source) + ret_type = parser._get_return_type(func_node, language, source) + + nodes.append(NodeInfo( + kind=kind, + name=var_name, + file_path=file_path, + line_start=child.start_point[0] + 1, + line_end=child.end_point[0] + 1, + language=language, + parent_name=enclosing_class, + params=params, + return_type=ret_type, + is_test=is_test, + )) + container = ( + parser._qualify(enclosing_class, file_path, None) + if enclosing_class else file_path + ) + edges.append(EdgeInfo( + kind="CONTAINS", + source=container, + target=qualified, + file_path=file_path, + line=child.start_point[0] + 1, + )) + + # Recurse into the function body for calls + parser._extract_from_tree( + func_node, source, language, file_path, nodes, edges, + enclosing_class=enclosing_class, + enclosing_func=var_name, + import_map=import_map, + defined_names=defined_names, + _depth=_depth + 1, + ) 
+ handled = True + + if not handled: + # Not a function assignment -- let generic recursion handle it + return False + return True + + def _extract_field_function( + self, + child, + source: bytes, + parser: CodeParser, + file_path: str, + nodes: list[NodeInfo], + edges: list[EdgeInfo], + enclosing_class: Optional[str], + enclosing_func: Optional[str], + import_map: Optional[dict[str, str]], + defined_names: Optional[set[str]], + _depth: int, + ) -> bool: + """Handle class field arrow functions: handler = (e) => { ... }""" + language = self.language + prop_name = None + func_node = None + for sub in child.children: + if sub.type == "property_identifier" and prop_name is None: + prop_name = sub.text.decode("utf-8", errors="replace") + elif sub.type in self._JS_FUNC_VALUE_TYPES: + func_node = sub + + if not prop_name or not func_node: + return False + + is_test = _is_test_function(prop_name, file_path) + kind = "Test" if is_test else "Function" + qualified = parser._qualify(prop_name, file_path, enclosing_class) + params = parser._get_params(func_node, language, source) + + nodes.append(NodeInfo( + kind=kind, + name=prop_name, + file_path=file_path, + line_start=child.start_point[0] + 1, + line_end=child.end_point[0] + 1, + language=language, + parent_name=enclosing_class, + params=params, + is_test=is_test, + )) + container = ( + parser._qualify(enclosing_class, file_path, None) + if enclosing_class else file_path + ) + edges.append(EdgeInfo( + kind="CONTAINS", + source=container, + target=qualified, + file_path=file_path, + line=child.start_point[0] + 1, + )) + + parser._extract_from_tree( + func_node, source, language, file_path, nodes, edges, + enclosing_class=enclosing_class, + enclosing_func=prop_name, + import_map=import_map, + defined_names=defined_names, + _depth=_depth + 1, + ) + return True + + def _extract_reexport_edges( + self, + node, + parser: CodeParser, + file_path: str, + edges: list[EdgeInfo], + ) -> None: + """Emit IMPORTS_FROM edges for JS/TS 
class KotlinHandler(BaseLanguageHandler):
    """Kotlin-specific inheritance extraction."""

    language = "kotlin"
    class_types = ["class_declaration", "object_declaration"]
    function_types = ["function_declaration"]
    import_types = ["import_header"]
    call_types = ["call_expression"]

    def get_bases(self, node, source: bytes) -> list[str]:
        """Return the text of each direct child that denotes a supertype.

        Children are matched by node type against the names listed below,
        which include ``delegation_specifier``.  The matching child's full
        source text is returned unmodified, in child order.
        """
        accepted = (
            "superclass", "super_interfaces", "extends_type",
            "implements_type", "type_identifier", "supertype",
            "delegation_specifier",
        )
        result: list[str] = []
        for part in node.children:
            if part.type not in accepted:
                continue
            result.append(part.text.decode("utf-8", errors="replace"))
        return result
+ if node.type == "function_declaration": + for child in node.children: + if child.type in ("dot_index_expression", "method_index_expression"): + for sub in reversed(child.children): + if sub.type == "identifier": + return sub.text.decode("utf-8", errors="replace") + return None + return NotImplemented + + def extract_constructs( + self, + child, + node_type: str, + parser: CodeParser, + source: bytes, + file_path: str, + nodes: list[NodeInfo], + edges: list[EdgeInfo], + enclosing_class: str | None, + enclosing_func: str | None, + import_map: dict[str, str] | None, + defined_names: set[str] | None, + depth: int, + ) -> bool: + """Handle Lua-specific AST constructs. + + Handles: + - variable_declaration with require() -> IMPORTS_FROM edge + - variable_declaration with function_definition -> named Function node + - function_declaration with dot/method name -> Function with table parent + - top-level require() call -> IMPORTS_FROM edge + """ + if node_type == "variable_declaration": + return self._handle_variable_declaration( + child, source, parser, file_path, nodes, edges, + enclosing_class, enclosing_func, + import_map, defined_names, depth, + ) + + if node_type == "function_declaration": + return self._handle_table_function( + child, source, parser, file_path, nodes, edges, + enclosing_class, enclosing_func, + import_map, defined_names, depth, + ) + + # Top-level require() not wrapped in variable_declaration + if node_type == "function_call" and not enclosing_func: + req_target = self._get_require_target(child) + if req_target is not None: + resolved = parser._resolve_module_to_file( + req_target, file_path, self.language, + ) + edges.append(EdgeInfo( + kind="IMPORTS_FROM", + source=file_path, + target=resolved if resolved else req_target, + file_path=file_path, + line=child.start_point[0] + 1, + )) + return True + + return False + + # ------------------------------------------------------------------ + # Lua-specific helpers + # 
------------------------------------------------------------------ + + @staticmethod + def _get_require_target(call_node) -> Optional[str]: + """Extract the module path from a Lua require() call. + + Returns the string argument or None if this is not a require() call. + """ + first_child = call_node.children[0] if call_node.children else None + if ( + not first_child + or first_child.type != "identifier" + or first_child.text != b"require" + ): + return None + for child in call_node.children: + if child.type == "arguments": + for arg in child.children: + if arg.type == "string": + for sub in arg.children: + if sub.type == "string_content": + return sub.text.decode( + "utf-8", errors="replace", + ) + raw = arg.text.decode("utf-8", errors="replace") + return raw.strip("'\"") + return None + + def _handle_variable_declaration( + self, + child, + source: bytes, + parser: CodeParser, + file_path: str, + nodes: list[NodeInfo], + edges: list[EdgeInfo], + enclosing_class: Optional[str], + enclosing_func: Optional[str], + import_map: Optional[dict[str, str]], + defined_names: Optional[set[str]], + depth: int, + ) -> bool: + """Handle Lua variable declarations that contain require() or + anonymous function definitions. + + ``local json = require("json")`` -> IMPORTS_FROM edge + ``local fn = function(x) ... 
end`` -> Function node named "fn" + """ + language = self.language + + # Walk into: variable_declaration > assignment_statement + assign = None + for sub in child.children: + if sub.type == "assignment_statement": + assign = sub + break + if not assign: + return False + + # Get variable name from variable_list + var_name = None + for sub in assign.children: + if sub.type == "variable_list": + for ident in sub.children: + if ident.type == "identifier": + var_name = ident.text.decode("utf-8", errors="replace") + break + break + + # Get value from expression_list + expr_list = None + for sub in assign.children: + if sub.type == "expression_list": + expr_list = sub + break + + if not var_name or not expr_list: + return False + + # Check for require() call + for expr in expr_list.children: + if expr.type == "function_call": + req_target = self._get_require_target(expr) + if req_target is not None: + resolved = parser._resolve_module_to_file( + req_target, file_path, language, + ) + edges.append(EdgeInfo( + kind="IMPORTS_FROM", + source=file_path, + target=resolved if resolved else req_target, + file_path=file_path, + line=child.start_point[0] + 1, + )) + return True + + # Check for anonymous function: local foo = function(...) 
end + for expr in expr_list.children: + if expr.type == "function_definition": + is_test = _is_test_function(var_name, file_path) + kind = "Test" if is_test else "Function" + qualified = parser._qualify(var_name, file_path, enclosing_class) + params = parser._get_params(expr, language, source) + + nodes.append(NodeInfo( + kind=kind, + name=var_name, + file_path=file_path, + line_start=child.start_point[0] + 1, + line_end=child.end_point[0] + 1, + language=language, + parent_name=enclosing_class, + params=params, + is_test=is_test, + )) + container = ( + parser._qualify(enclosing_class, file_path, None) + if enclosing_class else file_path + ) + edges.append(EdgeInfo( + kind="CONTAINS", + source=container, + target=qualified, + file_path=file_path, + line=child.start_point[0] + 1, + )) + # Recurse into the function body for calls + parser._extract_from_tree( + expr, source, language, file_path, nodes, edges, + enclosing_class=enclosing_class, + enclosing_func=var_name, + import_map=import_map, + defined_names=defined_names, + _depth=depth + 1, + ) + return True + + return False + + def _handle_table_function( + self, + child, + source: bytes, + parser: CodeParser, + file_path: str, + nodes: list[NodeInfo], + edges: list[EdgeInfo], + enclosing_class: Optional[str], + enclosing_func: Optional[str], + import_map: Optional[dict[str, str]], + defined_names: Optional[set[str]], + depth: int, + ) -> bool: + """Handle Lua function declarations with table-qualified names. + + ``function Animal.new(name)`` -> Function "new", parent "Animal" + ``function Animal:speak()`` -> Function "speak", parent "Animal" + + Plain ``function foo()`` is NOT handled here (returns False). 
+ """ + language = self.language + table_name = None + method_name = None + + for sub in child.children: + if sub.type in ("dot_index_expression", "method_index_expression"): + identifiers = [ + c for c in sub.children if c.type == "identifier" + ] + if len(identifiers) >= 2: + table_name = identifiers[0].text.decode( + "utf-8", errors="replace", + ) + method_name = identifiers[-1].text.decode( + "utf-8", errors="replace", + ) + break + + if not table_name or not method_name: + return False + + is_test = _is_test_function(method_name, file_path) + kind = "Test" if is_test else "Function" + qualified = parser._qualify(method_name, file_path, table_name) + params = parser._get_params(child, language, source) + + nodes.append(NodeInfo( + kind=kind, + name=method_name, + file_path=file_path, + line_start=child.start_point[0] + 1, + line_end=child.end_point[0] + 1, + language=language, + parent_name=table_name, + params=params, + is_test=is_test, + )) + # CONTAINS: table -> method + container = parser._qualify(table_name, file_path, None) + edges.append(EdgeInfo( + kind="CONTAINS", + source=container, + target=qualified, + file_path=file_path, + line=child.start_point[0] + 1, + )) + # Recurse into function body for calls + parser._extract_from_tree( + child, source, language, file_path, nodes, edges, + enclosing_class=table_name, + enclosing_func=method_name, + import_map=import_map, + defined_names=defined_names, + _depth=depth + 1, + ) + return True + + +class LuauHandler(LuaHandler): + """Roblox Luau (.luau) handler -- reuses the Lua handler.""" + + language = "luau" + class_types = ["type_definition"] diff --git a/code_review_graph/lang/_perl.py b/code_review_graph/lang/_perl.py new file mode 100644 index 00000000..fba72cf6 --- /dev/null +++ b/code_review_graph/lang/_perl.py @@ -0,0 +1,24 @@ +"""Perl language handler.""" + +from __future__ import annotations + +from ._base import BaseLanguageHandler + + +class PerlHandler(BaseLanguageHandler): + language = "perl" + 
class_types = ["package_statement", "class_statement", "role_statement"] + function_types = ["subroutine_declaration_statement", "method_declaration_statement"] + import_types = ["use_statement", "require_expression"] + call_types = [ + "function_call_expression", "method_call_expression", + "ambiguous_function_call_expression", + ] + + def get_name(self, node, kind: str) -> str | None: + for child in node.children: + if child.type == "bareword": + return child.text.decode("utf-8", errors="replace") + if child.type == "package" and child.text != b"package": + return child.text.decode("utf-8", errors="replace") + return NotImplemented diff --git a/code_review_graph/lang/_php.py b/code_review_graph/lang/_php.py new file mode 100644 index 00000000..f299835f --- /dev/null +++ b/code_review_graph/lang/_php.py @@ -0,0 +1,13 @@ +"""PHP language handler.""" + +from __future__ import annotations + +from ._base import BaseLanguageHandler + + +class PhpHandler(BaseLanguageHandler): + language = "php" + class_types = ["class_declaration", "interface_declaration"] + function_types = ["function_definition", "method_declaration"] + import_types = ["namespace_use_declaration"] + call_types = ["function_call_expression", "member_call_expression"] diff --git a/code_review_graph/lang/_python.py b/code_review_graph/lang/_python.py new file mode 100644 index 00000000..f836aeef --- /dev/null +++ b/code_review_graph/lang/_python.py @@ -0,0 +1,109 @@ +"""Python language handler.""" + +from __future__ import annotations + +from pathlib import Path + +from ._base import BaseLanguageHandler + + +class PythonHandler(BaseLanguageHandler): + language = "python" + class_types = ["class_definition"] + function_types = ["function_definition"] + import_types = ["import_statement", "import_from_statement"] + call_types = ["call"] + builtin_names = frozenset({ + "len", "str", "int", "float", "bool", "list", "dict", "set", "tuple", + "print", "range", "enumerate", "zip", "map", "filter", "sorted", + 
"reversed", "isinstance", "issubclass", "type", "id", "hash", + "hasattr", "getattr", "setattr", "delattr", "callable", + "repr", "abs", "min", "max", "sum", "round", "pow", "divmod", + "iter", "next", "open", "super", "property", "staticmethod", + "classmethod", "vars", "dir", "help", "input", "format", + "bytes", "bytearray", "memoryview", "frozenset", "complex", + "chr", "ord", "hex", "oct", "bin", "any", "all", + }) + + def get_bases(self, node, source: bytes) -> list[str]: + bases = [] + for child in node.children: + if child.type == "argument_list": + for arg in child.children: + if arg.type in ("identifier", "attribute"): + bases.append(arg.text.decode("utf-8", errors="replace")) + return bases + + def extract_import_targets(self, node, source: bytes) -> list[str]: + imports = [] + if node.type == "import_from_statement": + for child in node.children: + if child.type == "dotted_name": + imports.append(child.text.decode("utf-8", errors="replace")) + break + else: + for child in node.children: + if child.type == "dotted_name": + imports.append(child.text.decode("utf-8", errors="replace")) + return imports + + def collect_import_names( + self, node, file_path: str, import_map: dict[str, str], + ) -> bool: + if node.type == "import_from_statement": + # from X.Y import A, B -> {A: X.Y, B: X.Y} + module = None + seen_import_keyword = False + for child in node.children: + if child.type == "dotted_name" and not seen_import_keyword: + module = child.text.decode("utf-8", errors="replace") + elif child.type == "import": + seen_import_keyword = True + elif seen_import_keyword and module: + if child.type in ("identifier", "dotted_name"): + name = child.text.decode("utf-8", errors="replace") + import_map[name] = module + elif child.type == "aliased_import": + # from X import A as B -> {B: X} + names = [ + sub.text.decode("utf-8", errors="replace") + for sub in child.children + if sub.type in ("identifier", "dotted_name") + ] + if names: + import_map[names[-1]] = module + 
elif node.type == "import_statement": + # import json -> {json: json} + # import os.path -> {os: os.path} + # import X as Y -> {Y: X} + for child in node.children: + if child.type in ("dotted_name", "identifier"): + mod = child.text.decode("utf-8", errors="replace") + top_level = mod.split(".")[0] + import_map[top_level] = mod + elif child.type == "aliased_import": + names = [ + sub.text.decode("utf-8", errors="replace") + for sub in child.children + if sub.type in ("identifier", "dotted_name") + ] + if len(names) >= 2: + import_map[names[-1]] = names[0] + else: + return False + return True + + def resolve_module(self, module: str, caller_file: str) -> str | None: + caller_dir = Path(caller_file).parent + rel_path = module.replace(".", "/") + candidates = [rel_path + ".py", rel_path + "/__init__.py"] + current = caller_dir + while True: + for candidate in candidates: + target = current / candidate + if target.is_file(): + return str(target.resolve()) + if current == current.parent: + break + current = current.parent + return None diff --git a/code_review_graph/lang/_r.py b/code_review_graph/lang/_r.py new file mode 100644 index 00000000..a15ad973 --- /dev/null +++ b/code_review_graph/lang/_r.py @@ -0,0 +1,339 @@ +"""R language handler.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +from ..parser import EdgeInfo, NodeInfo, _is_test_function +from ._base import BaseLanguageHandler + +if TYPE_CHECKING: + from ..parser import CodeParser + + +class RHandler(BaseLanguageHandler): + language = "r" + class_types: list[str] = [] # Classes detected via call pattern-matching + function_types = ["function_definition"] + import_types = ["call"] # library(), require(), source() -- filtered downstream + call_types = ["call"] + + def extract_import_targets(self, node, source: bytes) -> list[str]: + """Extract import targets from R library/require/source calls.""" + imports = [] + func_name = self._call_func_name(node) + if func_name in 
("library", "require", "source"): + for _name, value in self._iter_args(node): + if value.type == "identifier": + imports.append(value.text.decode("utf-8", errors="replace")) + elif value.type == "string": + val = self._first_string_arg(node) + if val: + imports.append(val) + break # Only first argument matters + return imports + + def extract_constructs( + self, + child, + node_type: str, + parser: CodeParser, + source: bytes, + file_path: str, + nodes: list[NodeInfo], + edges: list[EdgeInfo], + enclosing_class: str | None, + enclosing_func: str | None, + import_map: dict[str, str] | None, + defined_names: set[str] | None, + depth: int, + ) -> bool: + if node_type == "binary_operator": + if self._handle_binary_operator( + child, source, parser, file_path, nodes, edges, + enclosing_class, enclosing_func, + import_map, defined_names, + ): + return True + + if node_type == "call": + if self._handle_call( + child, source, parser, file_path, nodes, edges, + enclosing_class, enclosing_func, + import_map, defined_names, + ): + return True + + return False + + # ------------------------------------------------------------------ + # R-specific helpers + # ------------------------------------------------------------------ + + @staticmethod + def _call_func_name(call_node) -> Optional[str]: + """Extract the function name from an R call node.""" + for child in call_node.children: + if child.type in ("identifier", "namespace_operator"): + return child.text.decode("utf-8", errors="replace") + return None + + @staticmethod + def _first_string_arg(call_node) -> Optional[str]: + """Extract the first string argument value from an R call node.""" + for child in call_node.children: + if child.type == "arguments": + for arg in child.children: + if arg.type == "argument": + for sub in arg.children: + if sub.type == "string": + for sc in sub.children: + if sc.type == "string_content": + return sc.text.decode("utf-8", errors="replace") + break + return None + + @staticmethod + def 
_iter_args(call_node): + """Yield (name_str, value_node) pairs from an R call's arguments.""" + for child in call_node.children: + if child.type != "arguments": + continue + for arg in child.children: + if arg.type != "argument": + continue + has_eq = any(sub.type == "=" for sub in arg.children) + if has_eq: + name = None + value = None + for sub in arg.children: + if sub.type == "identifier" and name is None: + name = sub.text.decode("utf-8", errors="replace") + elif sub.type not in ("=", ","): + value = sub + yield (name, value) + else: + for sub in arg.children: + if sub.type not in (",",): + yield (None, sub) + break + break + + @classmethod + def _find_named_arg(cls, call_node, arg_name: str): + """Find a named argument's value node in an R call.""" + for name, value in cls._iter_args(call_node): + if name == arg_name: + return value + return None + + # ------------------------------------------------------------------ + # Extraction methods + # ------------------------------------------------------------------ + + def _handle_binary_operator( + self, node, source: bytes, parser: CodeParser, file_path: str, + nodes: list[NodeInfo], edges: list[EdgeInfo], + enclosing_class: Optional[str], enclosing_func: Optional[str], + import_map: Optional[dict[str, str]], + defined_names: Optional[set[str]], + ) -> bool: + """Handle R binary_operator nodes: name <- function(...) { ... 
}.""" + language = self.language + children = node.children + if len(children) < 3: + return False + + left, op, right = children[0], children[1], children[2] + if op.type not in ("<-", "="): + return False + + if right.type == "function_definition" and left.type == "identifier": + name = left.text.decode("utf-8", errors="replace") + is_test = _is_test_function(name, file_path) + kind = "Test" if is_test else "Function" + qualified = parser._qualify(name, file_path, enclosing_class) + params = parser._get_params(right, language, source) + + nodes.append(NodeInfo( + kind=kind, + name=name, + file_path=file_path, + line_start=right.start_point[0] + 1, + line_end=right.end_point[0] + 1, + language=language, + parent_name=enclosing_class, + params=params, + is_test=is_test, + )) + + container = ( + parser._qualify(enclosing_class, file_path, None) + if enclosing_class else file_path + ) + edges.append(EdgeInfo( + kind="CONTAINS", + source=container, + target=qualified, + file_path=file_path, + line=right.start_point[0] + 1, + )) + + parser._extract_from_tree( + right, source, language, file_path, nodes, edges, + enclosing_class=enclosing_class, enclosing_func=name, + import_map=import_map, defined_names=defined_names, + ) + return True + + if right.type == "call" and left.type == "identifier": + call_func = self._call_func_name(right) + if call_func in ("setRefClass", "setClass", "setGeneric"): + assign_name = left.text.decode("utf-8", errors="replace") + return self._handle_class_call( + right, source, parser, file_path, nodes, edges, + enclosing_class, enclosing_func, + import_map, defined_names, + assign_name=assign_name, + ) + + return False + + def _handle_call( + self, node, source: bytes, parser: CodeParser, file_path: str, + nodes: list[NodeInfo], edges: list[EdgeInfo], + enclosing_class: Optional[str], enclosing_func: Optional[str], + import_map: Optional[dict[str, str]], + defined_names: Optional[set[str]], + ) -> bool: + """Handle R call nodes for imports 
and class definitions.""" + language = self.language + func_name = self._call_func_name(node) + if not func_name: + return False + + if func_name in ("library", "require", "source"): + imports = parser._extract_import(node, language, source) + for imp_target in imports: + edges.append(EdgeInfo( + kind="IMPORTS_FROM", + source=file_path, + target=imp_target, + file_path=file_path, + line=node.start_point[0] + 1, + )) + return True + + if func_name in ("setRefClass", "setClass", "setGeneric"): + return self._handle_class_call( + node, source, parser, file_path, nodes, edges, + enclosing_class, enclosing_func, + import_map, defined_names, + ) + + if enclosing_func: + call_name = parser._get_call_name(node, language, source) + if call_name: + caller = parser._qualify(enclosing_func, file_path, enclosing_class) + target = parser._resolve_call_target( + call_name, file_path, language, + import_map or {}, defined_names or set(), + ) + edges.append(EdgeInfo( + kind="CALLS", + source=caller, + target=target, + file_path=file_path, + line=node.start_point[0] + 1, + )) + + parser._extract_from_tree( + node, source, language, file_path, nodes, edges, + enclosing_class=enclosing_class, enclosing_func=enclosing_func, + import_map=import_map, defined_names=defined_names, + ) + return True + + def _handle_class_call( + self, node, source: bytes, parser: CodeParser, file_path: str, + nodes: list[NodeInfo], edges: list[EdgeInfo], + enclosing_class: Optional[str], enclosing_func: Optional[str], + import_map: Optional[dict[str, str]], + defined_names: Optional[set[str]], + assign_name: Optional[str] = None, + ) -> bool: + """Handle setClass/setRefClass/setGeneric calls -> Class nodes.""" + language = self.language + class_name = self._first_string_arg(node) or assign_name + if not class_name: + return False + + qualified = parser._qualify(class_name, file_path, enclosing_class) + nodes.append(NodeInfo( + kind="Class", + name=class_name, + file_path=file_path, + 
line_start=node.start_point[0] + 1, + line_end=node.end_point[0] + 1, + language=language, + parent_name=enclosing_class, + )) + edges.append(EdgeInfo( + kind="CONTAINS", + source=file_path, + target=qualified, + file_path=file_path, + line=node.start_point[0] + 1, + )) + + methods_list = self._find_named_arg(node, "methods") + if methods_list is not None: + self._extract_methods( + methods_list, source, parser, file_path, + nodes, edges, class_name, + import_map, defined_names, + ) + + return True + + def _extract_methods( + self, list_call, source: bytes, parser: CodeParser, file_path: str, + nodes: list[NodeInfo], edges: list[EdgeInfo], + class_name: str, + import_map: Optional[dict[str, str]], + defined_names: Optional[set[str]], + ) -> None: + """Extract methods from a setRefClass methods = list(...) call.""" + language = self.language + for method_name, func_def in self._iter_args(list_call): + if not method_name or func_def is None: + continue + if func_def.type != "function_definition": + continue + + qualified = parser._qualify(method_name, file_path, class_name) + params = parser._get_params(func_def, language, source) + nodes.append(NodeInfo( + kind="Function", + name=method_name, + file_path=file_path, + line_start=func_def.start_point[0] + 1, + line_end=func_def.end_point[0] + 1, + language=language, + parent_name=class_name, + params=params, + )) + edges.append(EdgeInfo( + kind="CONTAINS", + source=parser._qualify(class_name, file_path, None), + target=qualified, + file_path=file_path, + line=func_def.start_point[0] + 1, + )) + parser._extract_from_tree( + func_def, source, language, file_path, nodes, edges, + enclosing_class=class_name, + enclosing_func=method_name, + import_map=import_map, + defined_names=defined_names, + ) diff --git a/code_review_graph/lang/_ruby.py b/code_review_graph/lang/_ruby.py new file mode 100644 index 00000000..5a6b11fd --- /dev/null +++ b/code_review_graph/lang/_ruby.py @@ -0,0 +1,23 @@ +"""Ruby language handler.""" + 
+from __future__ import annotations + +import re + +from ._base import BaseLanguageHandler + + +class RubyHandler(BaseLanguageHandler): + language = "ruby" + class_types = ["class", "module"] + function_types = ["method", "singleton_method"] + import_types = ["call"] # require / require_relative + call_types = ["call", "method_call"] + + def extract_import_targets(self, node, source: bytes) -> list[str]: + text = node.text.decode("utf-8", errors="replace").strip() + if "require" in text: + match = re.search(r"""['"](.*?)['"]""", text) + if match: + return [match.group(1)] + return [] diff --git a/code_review_graph/lang/_rust.py b/code_review_graph/lang/_rust.py new file mode 100644 index 00000000..839006ee --- /dev/null +++ b/code_review_graph/lang/_rust.py @@ -0,0 +1,22 @@ +"""Rust language handler.""" + +from __future__ import annotations + +from ._base import BaseLanguageHandler + + +class RustHandler(BaseLanguageHandler): + language = "rust" + class_types = ["struct_item", "enum_item", "impl_item"] + function_types = ["function_item"] + import_types = ["use_declaration"] + call_types = ["call_expression", "macro_invocation"] + builtin_names = frozenset({ + "println", "eprintln", "format", "vec", "panic", "todo", + "unimplemented", "unreachable", "assert", "assert_eq", "assert_ne", + "dbg", "cfg", + }) + + def extract_import_targets(self, node, source: bytes) -> list[str]: + text = node.text.decode("utf-8", errors="replace").strip() + return [text.replace("use ", "").rstrip(";").strip()] diff --git a/code_review_graph/lang/_scala.py b/code_review_graph/lang/_scala.py new file mode 100644 index 00000000..e5159d1b --- /dev/null +++ b/code_review_graph/lang/_scala.py @@ -0,0 +1,54 @@ +"""Scala language handler.""" + +from __future__ import annotations + +from ._base import BaseLanguageHandler + + +class ScalaHandler(BaseLanguageHandler): + language = "scala" + class_types = [ + "class_definition", "trait_definition", + "object_definition", "enum_definition", + ] + 
function_types = ["function_definition", "function_declaration"] + import_types = ["import_declaration"] + call_types = ["call_expression", "instance_expression", "generic_function"] + + def extract_import_targets(self, node, source: bytes) -> list[str]: + parts: list[str] = [] + selectors: list[str] = [] + is_wildcard = False + for child in node.children: + if child.type == "identifier": + parts.append(child.text.decode("utf-8", errors="replace")) + elif child.type == "namespace_selectors": + for sub in child.children: + if sub.type == "identifier": + selectors.append(sub.text.decode("utf-8", errors="replace")) + elif child.type == "namespace_wildcard": + is_wildcard = True + base = ".".join(parts) + if selectors: + return [f"{base}.{name}" for name in selectors] + if is_wildcard: + return [f"{base}.*"] + if base: + return [base] + return [] + + def get_bases(self, node, source: bytes) -> list[str]: + bases = [] + for child in node.children: + if child.type == "extends_clause": + for sub in child.children: + if sub.type == "type_identifier": + bases.append(sub.text.decode("utf-8", errors="replace")) + elif sub.type == "generic_type": + for ident in sub.children: + if ident.type == "type_identifier": + bases.append( + ident.text.decode("utf-8", errors="replace"), + ) + break + return bases diff --git a/code_review_graph/lang/_solidity.py b/code_review_graph/lang/_solidity.py new file mode 100644 index 00000000..efd5560d --- /dev/null +++ b/code_review_graph/lang/_solidity.py @@ -0,0 +1,222 @@ +"""Solidity language handler.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ..parser import EdgeInfo, NodeInfo +from ._base import BaseLanguageHandler + +if TYPE_CHECKING: + from ..parser import CodeParser + + +class SolidityHandler(BaseLanguageHandler): + language = "solidity" + class_types = [ + "contract_declaration", "interface_declaration", "library_declaration", + "struct_declaration", "enum_declaration", "error_declaration", + 
"user_defined_type_definition", + ] + # Events and modifiers use kind="Function" because the graph schema has no + # dedicated kind for them. State variables are also modeled as Function + # nodes (public ones auto-generate getters). + function_types = [ + "function_definition", "constructor_definition", "modifier_definition", + "event_definition", "fallback_receive_definition", + ] + import_types = ["import_directive"] + call_types = ["call_expression"] + + def get_name(self, node, kind: str) -> str | None: + if node.type == "constructor_definition": + return "constructor" + if node.type == "fallback_receive_definition": + for child in node.children: + if child.type in ("receive", "fallback"): + return child.text.decode("utf-8", errors="replace") + return NotImplemented + + def extract_import_targets(self, node, source: bytes) -> list[str]: + imports = [] + for child in node.children: + if child.type == "string": + val = child.text.decode("utf-8", errors="replace").strip('"') + if val: + imports.append(val) + return imports + + def get_bases(self, node, source: bytes) -> list[str]: + bases = [] + for child in node.children: + if child.type == "inheritance_specifier": + for sub in child.children: + if sub.type == "user_defined_type": + for ident in sub.children: + if ident.type == "identifier": + bases.append( + ident.text.decode("utf-8", errors="replace"), + ) + return bases + + def extract_constructs( + self, + child, + node_type: str, + parser: CodeParser, + source: bytes, + file_path: str, + nodes: list[NodeInfo], + edges: list[EdgeInfo], + enclosing_class: str | None, + enclosing_func: str | None, + import_map: dict[str, str] | None, + defined_names: set[str] | None, + depth: int, + ) -> bool: + # Emit statements: emit EventName(...) 
-> CALLS edge + if node_type == "emit_statement" and enclosing_func: + for sub in child.children: + if sub.type == "expression": + for ident in sub.children: + if ident.type == "identifier": + caller = parser._qualify( + enclosing_func, file_path, + enclosing_class, + ) + edges.append(EdgeInfo( + kind="CALLS", + source=caller, + target=ident.text.decode( + "utf-8", errors="replace", + ), + file_path=file_path, + line=child.start_point[0] + 1, + )) + # emit_statement falls through to default recursion + return False + + # State variable declarations -> Function nodes (public ones + # auto-generate getters, and all are critical for reviews) + if node_type == "state_variable_declaration" and enclosing_class: + var_name = None + var_visibility = None + var_mutability = None + var_type = None + for sub in child.children: + if sub.type == "identifier": + var_name = sub.text.decode( + "utf-8", errors="replace", + ) + elif sub.type == "visibility": + var_visibility = sub.text.decode( + "utf-8", errors="replace", + ) + elif sub.type == "type_name": + var_type = sub.text.decode( + "utf-8", errors="replace", + ) + elif sub.type in ("constant", "immutable"): + var_mutability = sub.type + if var_name: + qualified = parser._qualify( + var_name, file_path, enclosing_class, + ) + nodes.append(NodeInfo( + kind="Function", + name=var_name, + file_path=file_path, + line_start=child.start_point[0] + 1, + line_end=child.end_point[0] + 1, + language=self.language, + parent_name=enclosing_class, + return_type=var_type, + modifiers=var_visibility, + extra={ + "solidity_kind": "state_variable", + "mutability": var_mutability, + }, + )) + edges.append(EdgeInfo( + kind="CONTAINS", + source=parser._qualify( + enclosing_class, file_path, None, + ), + target=qualified, + file_path=file_path, + line=child.start_point[0] + 1, + )) + return True + return False + + # File-level and contract-level constant declarations + if node_type == "constant_variable_declaration": + var_name = None + var_type = 
None + for sub in child.children: + if sub.type == "identifier": + var_name = sub.text.decode( + "utf-8", errors="replace", + ) + elif sub.type == "type_name": + var_type = sub.text.decode( + "utf-8", errors="replace", + ) + if var_name: + qualified = parser._qualify( + var_name, file_path, enclosing_class, + ) + nodes.append(NodeInfo( + kind="Function", + name=var_name, + file_path=file_path, + line_start=child.start_point[0] + 1, + line_end=child.end_point[0] + 1, + language=self.language, + parent_name=enclosing_class, + return_type=var_type, + extra={"solidity_kind": "constant"}, + )) + container = ( + parser._qualify(enclosing_class, file_path, None) + if enclosing_class + else file_path + ) + edges.append(EdgeInfo( + kind="CONTAINS", + source=container, + target=qualified, + file_path=file_path, + line=child.start_point[0] + 1, + )) + return True + return False + + # Using directives: using LibName for Type -> DEPENDS_ON edge + if node_type == "using_directive": + lib_name = None + for sub in child.children: + if sub.type == "type_alias": + for ident in sub.children: + if ident.type == "identifier": + lib_name = ident.text.decode( + "utf-8", errors="replace", + ) + if lib_name: + source_name = ( + parser._qualify( + enclosing_class, file_path, None, + ) + if enclosing_class + else file_path + ) + edges.append(EdgeInfo( + kind="DEPENDS_ON", + source=source_name, + target=lib_name, + file_path=file_path, + line=child.start_point[0] + 1, + )) + return True + + return False diff --git a/code_review_graph/lang/_swift.py b/code_review_graph/lang/_swift.py new file mode 100644 index 00000000..4a4c6754 --- /dev/null +++ b/code_review_graph/lang/_swift.py @@ -0,0 +1,13 @@ +"""Swift language handler.""" + +from __future__ import annotations + +from ._base import BaseLanguageHandler + + +class SwiftHandler(BaseLanguageHandler): + language = "swift" + class_types = ["class_declaration", "struct_declaration", "protocol_declaration"] + function_types = 
["function_declaration"] + import_types = ["import_declaration"] + call_types = ["call_expression"] diff --git a/code_review_graph/parser.py b/code_review_graph/parser.py index 31af17f7..d07ba725 100644 --- a/code_review_graph/parser.py +++ b/code_review_graph/parser.py @@ -10,14 +10,18 @@ import json import logging import re +import threading from dataclasses import dataclass, field from pathlib import Path -from typing import NamedTuple, Optional +from typing import TYPE_CHECKING, NamedTuple, Optional import tree_sitter_language_pack as tslp from .tsconfig_resolver import TsconfigResolver +if TYPE_CHECKING: + from .lang import BaseLanguageHandler + class CellInfo(NamedTuple): """Represents a single cell in a notebook with its language.""" @@ -111,12 +115,7 @@ class EdgeInfo: ".ex": "elixir", ".exs": "elixir", ".ipynb": "notebook", - ".zig": "zig", - ".ps1": "powershell", - ".psm1": "powershell", - ".psd1": "powershell", - ".svelte": "svelte", - ".jl": "julia", + ".html": "html", } # Tree-sitter node type mappings per language @@ -161,9 +160,6 @@ class EdgeInfo: # identifier is literally "defmodule". Dispatched via # _extract_elixir_constructs to avoid matching every ``call`` here. "elixir": [], - "zig": ["container_declaration"], - "powershell": ["class_statement"], - "julia": ["struct_definition", "abstract_definition"], } _FUNCTION_TYPES: dict[str, list[str]] = { @@ -208,12 +204,6 @@ class EdgeInfo: # Elixir: def/defp/defmacro are all ``call`` nodes whose first # identifier matches. Dispatched via _extract_elixir_constructs. "elixir": [], - "zig": ["fn_proto", "fn_decl"], - "powershell": ["function_statement"], - "julia": [ - "function_definition", - "short_function_definition", - ], } _IMPORT_TYPES: dict[str, list[str]] = { @@ -248,12 +238,6 @@ class EdgeInfo: # Elixir: alias/import/require/use are all ``call`` nodes — # handled in _extract_elixir_constructs. 
"elixir": [], - # Zig: @import("...") is a builtin_call_expr — handled - # generically via call types below. - "zig": [], - "powershell": [], - # Julia: import/using are import_statement nodes. - "julia": ["import_statement", "using_statement"], } _CALL_TYPES: dict[str, list[str]] = { @@ -289,9 +273,6 @@ class EdgeInfo: # _extract_elixir_constructs which filters out def/defmodule/alias/etc. # before treating what's left as a real call. "elixir": [], - "zig": ["call_expression", "builtin_call_expr"], - "powershell": ["command_expression"], - "julia": ["call_expression"], } # Patterns that indicate a test function @@ -329,6 +310,50 @@ class EdgeInfo: "org.junit.Test", "org.junit.jupiter.api.Test", }) +_BUILTIN_NAMES: dict[str, frozenset[str]] = { +} + +# Common JS/TS prototype and built-in method names that should NOT create +# CALLS edges when seen as instance method calls (obj.method()). These are +# so ubiquitous that emitting bare-name edges for them creates noise without +# helping dead-code or flow analysis. 
+_INSTANCE_METHOD_BLOCKLIST: frozenset[str] = frozenset({ + # Array / iterable + "push", "pop", "shift", "unshift", "splice", "slice", "concat", + "map", "filter", "reduce", "reduceRight", "find", "findIndex", + "forEach", "every", "some", "includes", "indexOf", "lastIndexOf", + "flat", "flatMap", "fill", "sort", "reverse", "join", "entries", + "keys", "values", "at", "with", + # Object / prototype + "toString", "valueOf", "toJSON", "hasOwnProperty", "toLocaleString", + # String + "trim", "trimStart", "trimEnd", "split", "replace", "replaceAll", + "match", "matchAll", "search", "startsWith", "endsWith", "padStart", + "padEnd", "repeat", "substring", "toLowerCase", "toUpperCase", "charAt", + "charCodeAt", "normalize", "localeCompare", + # Promise / async + "then", "catch", "finally", + # Map / Set + "get", "set", "has", "delete", "clear", "add", "size", + # EventEmitter / stream (very generic) + "emit", "pipe", "write", "end", "destroy", "pause", "resume", + # Logging / console + "log", "warn", "error", "info", "debug", "trace", + # DOM / common + "addEventListener", "removeEventListener", "querySelector", + "querySelectorAll", "getElementById", "setAttribute", + "getAttribute", "appendChild", "removeChild", "createElement", + "preventDefault", "stopPropagation", + # RxJS / Observable + "subscribe", "unsubscribe", "next", "complete", + # Common generic names (too ambiguous to resolve) + "call", "apply", "bind", "resolve", "reject", + # Python common builtins used as methods + "append", "extend", "insert", "remove", "update", "items", + "encode", "decode", "strip", "lstrip", "rstrip", "format", + "upper", "lower", "title", "count", "copy", "deepcopy", +}) + def _is_test_file(path: str) -> bool: return any(p.search(path) for p in _TEST_FILE_PATTERNS) @@ -368,19 +393,61 @@ def __init__(self) -> None: self._parsers: dict[str, object] = {} self._module_file_cache: dict[str, Optional[str]] = {} self._export_symbol_cache: dict[str, Optional[str]] = {} + 
self._star_export_cache: dict[str, set[str]] = {} self._tsconfig_resolver = TsconfigResolver() # Per-parse cache of Dart pubspec root lookups; see #87 self._dart_pubspec_cache: dict[tuple[str, str], Optional[Path]] = {} + self._handlers: dict[str, "BaseLanguageHandler"] = {} + self._type_sets_cache: dict[str, tuple] = {} + self._workspace_map: dict[str, str] = {} # pkg name → directory path + self._workspace_map_built = False + self._lock = threading.Lock() + self._register_handlers() + + def _register_handlers(self) -> None: + from .lang import ALL_HANDLERS + for handler in ALL_HANDLERS: + self._handlers[handler.language] = handler + + def _type_sets(self, language: str): + cached = self._type_sets_cache.get(language) + if cached is not None: + return cached + with self._lock: + cached = self._type_sets_cache.get(language) + if cached is not None: + return cached + handler = self._handlers.get(language) + if handler is not None: + result = ( + set(handler.class_types), + set(handler.function_types), + set(handler.import_types), + set(handler.call_types), + ) + else: + result = ( + set(_CLASS_TYPES.get(language, [])), + set(_FUNCTION_TYPES.get(language, [])), + set(_IMPORT_TYPES.get(language, [])), + set(_CALL_TYPES.get(language, [])), + ) + self._type_sets_cache[language] = result + return result def _get_parser(self, language: str): # type: ignore[arg-type] - if language not in self._parsers: + if language in self._parsers: + return self._parsers[language] + with self._lock: + if language in self._parsers: + return self._parsers[language] try: self._parsers[language] = tslp.get_parser(language) # type: ignore[arg-type] except (LookupError, ValueError, ImportError) as exc: # language not packaged, or grammar load failed logger.debug("tree-sitter parser unavailable for %s: %s", language, exc) return None - return self._parsers[language] + return self._parsers[language] def detect_language(self, path: Path) -> Optional[str]: return 
EXTENSION_TO_LANGUAGE.get(path.suffix.lower()) @@ -389,7 +456,8 @@ def parse_file(self, path: Path) -> tuple[list[NodeInfo], list[EdgeInfo]]: """Parse a single file and return extracted nodes and edges.""" try: source = path.read_bytes() - except (OSError, PermissionError): + except (OSError, PermissionError) as e: + logger.warning("Cannot read %s: %s", path, e) return [], [] return self.parse_bytes(path, source) @@ -403,21 +471,27 @@ def parse_bytes(self, path: Path, source: bytes) -> tuple[list[NodeInfo], list[E if not language: return [], [] + # Skip likely bundled JS files (Rollup/Vite/webpack output). + # These are single files with thousands of lines that pollute the graph. + if language in ("javascript",) and len(source) > 500_000: + return [], [] + + # Angular templates: regex-based extraction (no tree-sitter grammar) + if language == "html": + return self._parse_angular_template(path, source) + # Vue SFCs: parse with vue parser, then delegate script blocks to JS/TS if language == "vue": return self._parse_vue(path, source) - # Svelte SFCs: same approach as Vue — extract