diff --git a/libs/openant-core/parsers/c/call_graph_builder.py b/libs/openant-core/parsers/c/call_graph_builder.py
index 9de6fae..a7af09e 100644
--- a/libs/openant-core/parsers/c/call_graph_builder.py
+++ b/libs/openant-core/parsers/c/call_graph_builder.py
@@ -110,6 +110,11 @@ def __init__(self, extractor_output: Dict, options: Optional[Dict] = None):
         self.macros = extractor_output.get('macros', {})
         self.macro_aliases = extractor_output.get('macro_aliases', {})
         self.prototypes = extractor_output.get('prototypes', {})
+        # class_name -> [direct base-class name, ...] for the inheritance walk in
+        # member dispatch (bug [30]). Defaults to {} when the extractor output
+        # predates base-class extraction, so resolution degrades to the [51]
+        # same-type behavior rather than erroring.
+        self.class_bases: Dict[str, List[str]] = extractor_output.get('class_bases', {})
         self.repo_path = extractor_output.get('repository', '')
 
         self.max_depth = options.get('max_depth', 3)
@@ -121,6 +126,10 @@ def __init__(self, extractor_output: Dict, options: Optional[Dict] = None):
         # Indexes for faster lookup
         self.functions_by_name: Dict[str, List[str]] = {}
         self.functions_by_file: Dict[str, List[str]] = {}
+        # class_name -> {base_method_name -> [func_id, ...]} for member dispatch.
+        # Scoped per (class, method) so a receiver-typed call resolves only to a
+        # method actually declared on that class, never a sibling/free function.
+        self.methods_by_class: Dict[str, Dict[str, List[str]]] = {}
 
         # Include map: file -> set of included header files
         self.include_map: Dict[str, Set[str]] = {}
@@ -153,6 +162,13 @@ def _build_indexes(self) -> None:
                 self.functions_by_file[file_path] = []
             self.functions_by_file[file_path].append(func_id)
 
+            # Index methods by their declaring class for receiver-type dispatch.
+            class_name = func_data.get('class_name')
+            if class_name and name:
+                method_base = name.split('::')[-1] if '::' in name else name
+                self.methods_by_class.setdefault(class_name, {}) \
+                    .setdefault(method_base, []).append(func_id)
+
         # Build include map
         for file_path, inc_list in self.includes.items():
             self.include_map[file_path] = set()
@@ -191,15 +207,25 @@ def _extract_calls_from_code(self, code: str, caller_id: str) -> Set[str]:
         except Exception:
             return self._extract_calls_regex(code, caller_id)
 
+        # Receiver static types inferred from local declarations in this body,
+        # used to resolve member calls (w.compute() / w->compute()) to the
+        # method on the receiver's known type.
+        local_var_types = self._extract_local_var_types(tree.root_node, code_bytes)
+
         stack = [tree.root_node]
         while stack:
             node = stack.pop()
             if node.type == 'call_expression':
                 func_node = node.child_by_field_name('function')
                 if func_node:
-                    call_name = self._extract_call_name(func_node, code_bytes)
+                    call_name, receiver = self._extract_call_name_and_receiver(
+                        func_node, code_bytes
+                    )
                     if call_name:
-                        resolved = self._resolve_call(call_name, caller_file)
+                        receiver_type = local_var_types.get(receiver) if receiver else None
+                        resolved = self._resolve_call(call_name, caller_file,
+                                                      receiver_type=receiver_type,
+                                                      is_member=func_node.type == 'field_expression')
                         if resolved:
                             calls.add(resolved)
                 # A function passed by name as an argument (e.g.
@@ -232,6 +258,81 @@ def _extract_callback_args(self, call_node, source: bytes, caller_file: str) ->
                     found.add(resolved)
         return found
 
+    def _extract_call_name_and_receiver(self, node, source: bytes):
+        """Return (call_name, receiver_identifier) for a call's function child.
+
+        For a field_expression (`w.compute()` / `w->compute()`), call_name is the
+        `field` child's text (None unless it is an identifier) and the receiver
+        is the `argument` child's text when that child is a plain identifier,
+        else None. Any other node is delegated to _extract_call_name.
+        """
+        if node.type == 'field_expression':
+            receiver = None
+            arg = node.child_by_field_name('argument')
+            if arg is not None and arg.type == 'identifier':
+                receiver = source[arg.start_byte:arg.end_byte].decode(
+                    'utf-8', errors='replace')
+            # _extract_call_name declines field_expression (no false free-function
+            # edges); the member name is recovered here from the `field` child and
+            # resolved ONLY through typed/same-file member dispatch in _resolve_call.
+            field = node.child_by_field_name('field')
+            name = None
+            if field is not None:
+                name = source[field.start_byte:field.end_byte].decode(
+                    'utf-8', errors='replace')
+                if not name.isidentifier():
+                    name = None
+            return name, receiver
+        return self._extract_call_name(node, source), None
+
+    def _extract_local_var_types(self, root, source: bytes) -> Dict[str, str]:
+        """Map local variable name -> declared type name within a function body.
+
+        Walks every `declaration` node under root and, when its `type` field is
+        a plain type_identifier, maps each declared variable (`Widget w;`,
+        `Widget* w = ...;`, `Widget a, b;`) to that type name. Declarations
+        whose type node is anything else are skipped, so such receivers stay
+        untyped rather than risk a wrong-type edge. A name declared more than
+        once keeps the type of the declaration visited last.
+        """
+        var_types: Dict[str, str] = {}
+        stack = [root]
+        while stack:
+            node = stack.pop()
+            if node.type == 'declaration':
+                type_node = node.child_by_field_name('type')
+                if type_node is not None and type_node.type == 'type_identifier':
+                    type_name = source[type_node.start_byte:type_node.end_byte] \
+                        .decode('utf-8', errors='replace')
+                    # A declaration can hold several declarators (Widget a, b;);
+                    # attribute the type to every variable name we extract.
+                    for child in node.children:
+                        var_name = self._declared_var_name(child, source)
+                        if var_name:
+                            var_types[var_name] = type_name
+            stack.extend(reversed(node.children))
+        return var_types
+
+    def _declared_var_name(self, node, source: bytes) -> Optional[str]:
+        """Extract the declared variable identifier from a declarator subtree.
+
+        Handles a plain identifier (`w`) and the pointer_declarator (`* w`),
+        reference_declarator and init_declarator (`w = ...`) wrappers around
+        one. Returns None for any other node (e.g. the type node, `;`).
+        """
+        if node.type == 'identifier':
+            return source[node.start_byte:node.end_byte].decode('utf-8', errors='replace')
+        if node.type in ('pointer_declarator', 'init_declarator', 'reference_declarator'):
+            inner = node.child_by_field_name('declarator')
+            if inner is not None:
+                return self._declared_var_name(inner, source)
+            # No `declarator` field on this wrapper node: scan its children instead.
+            for child in node.children:
+                name = self._declared_var_name(child, source)
+                if name:
+                    return name
+        return None
+
     def _extract_call_name(self, node, source: bytes) -> Optional[str]:
         """Extract the function name from a call_expression's function child."""
         text = source[node.start_byte:node.end_byte].decode('utf-8', errors='replace')
@@ -274,27 +375,126 @@ def _is_visible_from(self, func_id: str, caller_file: str) -> bool:
             return True
         return not func_data.get('is_static', False)
 
-    def _resolve_call(self, call_name: str, caller_file: str) -> Optional[str]:
-        """Resolve a function call name to a function ID."""
+    def _resolve_same_file(self, call_name: str, caller_file: str) -> Optional[str]:
+        """Resolve a call to a user-defined function in the same file, if any."""
+        same_file_funcs = self.functions_by_file.get(caller_file, [])
+        for func_id in same_file_funcs:
+            func_data = self.functions.get(func_id, {})
+            fname = func_data.get('name', '')
+            base_name = fname.split('::')[-1] if '::' in fname else fname
+            if base_name == call_name:
+                return func_id
+        return None
+
+    def _resolve_method_on_class(self, class_name: str, call_name: str,
+                                 caller_file: str) -> Optional[str]:
+        """Resolve call_name to a method DIRECTLY declared on class_name (same file).
+
+        Returns the func_id of a method named call_name declared on class_name and
+        defined in caller_file, else None. No inheritance — this is the single-hop
+        lookup the walk in _resolve_member_call composes over the base chain.
+        """
+        by_method = self.methods_by_class.get(class_name)
+        if not by_method:
+            return None
+        for func_id in by_method.get(call_name, []):
+            func_data = self.functions.get(func_id, {})
+            if func_data.get('file_path', '') == caller_file:
+                return func_id
+        return None
+
+    def _resolve_member_call(self, call_name: str, caller_file: str,
+                             receiver_type: str) -> Optional[str]:
+        """Resolve a member call to the method on the receiver's STATIC type,
+        walking UP the base-class chain to the first ancestor that defines it.
+
+        Sound static-type floor (bug [30]): start at the receiver's declared type
+        and return its own method if it defines call_name; otherwise walk up its
+        direct base classes (BFS, cycle-guarded) and resolve to the FIRST ancestor
+        that declares call_name in the same file. The walk STOPS at the first
+        definer, so a derived override resolves to itself, never an ancestor.
+
+        Deliberately does NOT link derived overrides of an ancestor's virtual
+        method (a documented non-goal that would create false edges): a call via a
+        Base* receiver resolves to Base's method only — the static-type floor.
+
+        Same-file only: if no class on the chain defines call_name in this
+        translation unit, returns None so the caller falls back to base-name
+        resolution (never a wrong-type / unrelated-free-function edge).
+        """
+        visited: Set[str] = set()
+        queue: List[str] = [receiver_type]
+        while queue:
+            cls = queue.pop(0)
+            if cls in visited:
+                continue
+            visited.add(cls)
+            # First definer on the chain wins (own type before ancestors).
+            match = self._resolve_method_on_class(cls, call_name, caller_file)
+            if match:
+                return match
+            for base in self.class_bases.get(cls, []):
+                if base not in visited:
+                    queue.append(base)
+        return None
+
+    def _resolve_call(self, call_name: str, caller_file: str,
+                      receiver_type: Optional[str] = None,
+                      is_member: bool = False,
+                      _alias_chain: Optional[Set[str]] = None) -> Optional[str]:
+        """Resolve a function call name to a function ID.
+
+        A known receiver_type is tried first, through _resolve_member_call. A
+        member call (is_member) then resolves same-file only and never reaches
+        the cross-file steps. _alias_chain is internal macro-alias cycle state.
+        """
+        if receiver_type:
+            member_match = self._resolve_member_call(call_name, caller_file,
+                                                     receiver_type)
+            if member_match:
+                return member_match
+
+        # A user-defined function in the SAME FILE shadows any stdlib/builtin
+        # of the same name, so it must be checked BEFORE the stdlib filter.
+        # Scope is deliberately same-file only: a genuine stdlib call (no
+        # same-file definition) still falls through to _is_stdlib below, so we
+        # never wrongly link a real stdlib call (e.g. printf/open) to an
+        # unrelated same-named user function in another file.
+        same_file_user_func = self._resolve_same_file(call_name, caller_file)
+        if same_file_user_func:
+            return same_file_user_func
+
+        # A member call (obj->m() / obj.m()) whose receiver type is unknown or
+        # whose chain defines no such method resolves same-file only: declining
+        # here keeps the field-expression precision guarantee (never an edge to
+        # an unrelated cross-file free function of the same name).
+        if is_member:
+            return None
+
         if self._is_stdlib(call_name):
             return None
 
         # Check for macro aliases
         resolved_name = self.macro_aliases.get(call_name, call_name)
         if resolved_name != call_name:
-            # Try resolving the aliased name instead
-            result = self._resolve_call(resolved_name, caller_file)
-            if result:
-                return result
+            # Guard against cyclic macro aliases (e.g. ``#define A B`` /
+            # ``#define B A`` -> {"A": "B", "B": "A"}). Without a visited-set
+            # the recursion below would loop A->B->A->... until RecursionError
+            # aborted the whole repo's call-graph build.
+            if _alias_chain is None:
+                _alias_chain = {call_name}
+            if resolved_name not in _alias_chain:
+                _alias_chain.add(resolved_name)
+                # Try resolving the aliased name instead
+                result = self._resolve_call(resolved_name, caller_file,
+                                            _alias_chain=_alias_chain)
+                if result:
+                    return result
 
         # 1. Same-file functions
-        same_file_funcs = self.functions_by_file.get(caller_file, [])
-        for func_id in same_file_funcs:
-            func_data = self.functions.get(func_id, {})
-            fname = func_data.get('name', '')
-            base_name = fname.split('::')[-1] if '::' in fname else fname
-            if base_name == call_name:
-                return func_id
+        same_file_match = self._resolve_same_file(call_name, caller_file)
+        if same_file_match:
+            return same_file_match
 
         # 2. Functions in included headers
         included_files = self.include_map.get(caller_file, set())
@@ -392,10 +592,13 @@ def _extract_calls_regex(self, code: str, caller_id: str) -> Set[str]:
             if func_name in ('if', 'while', 'for', 'switch', 'return', 'sizeof',
                              'typeof', 'alignof', 'offsetof', 'case', 'else'):
                 continue
-            if not self._is_stdlib(func_name):
-                resolved = self._resolve_call(func_name, caller_file)
-                if resolved:
-                    calls.add(resolved)
+            # No _is_stdlib gate here: _resolve_call applies the same-file-first
+            # rule and the stdlib filter internally, so a user function whose
+            # name collides with a builtin still resolves, matching the
+            # tree-sitter path.
+            resolved = self._resolve_call(func_name, caller_file)
+            if resolved:
+                calls.add(resolved)
 
         return calls
 
diff --git a/libs/openant-core/tests/test_c_macro_alias_cycle.py b/libs/openant-core/tests/test_c_macro_alias_cycle.py
new file mode 100644
index 0000000..7e53b5a
--- /dev/null
+++ b/libs/openant-core/tests/test_c_macro_alias_cycle.py
@@ -0,0 +1,51 @@
+"""Regression test for F4: C call-graph builder must not crash on cyclic macro aliases.
+
+A function-like ``#define`` pair such as::
+
+    #define A(x) B(x)
+    #define B(x) A(x)
+
+produces ``macro_aliases = {"A": "B", "B": "A"}``. Before the fix,
+``_resolve_call`` recursed on the aliased name with no cycle guard, so
+resolving ``A`` recursed A->B->A->... until ``RecursionError`` aborted the
+entire repository's C call-graph build.
+"""
+
+from __future__ import annotations
+
+import pytest
+
+tree_sitter_c = pytest.importorskip("tree_sitter_c")
+
+from parsers.c.call_graph_builder import CallGraphBuilder
+
+
+def _builder(macro_aliases):
+    return CallGraphBuilder(
+        {
+            "functions": {},
+            "includes": {},
+            "macros": {},
+            "macro_aliases": macro_aliases,
+        }
+    )
+
+
+def test_cyclic_macro_alias_does_not_recurse():
+    """Two-node alias cycle must resolve to None, not raise RecursionError."""
+    builder = _builder({"A": "B", "B": "A"})
+    # Must not raise; unresolved cyclic alias returns None.
+    assert builder._resolve_call("A", "foo.c") is None
+    assert builder._resolve_call("B", "foo.c") is None
+
+
+def test_longer_macro_alias_cycle_terminates():
+    """A 3-node alias cycle must also terminate."""
+    builder = _builder({"A": "B", "B": "C", "C": "A"})
+    assert builder._resolve_call("A", "foo.c") is None
+
+
+def test_self_macro_alias_does_not_recurse():
+    """A self-alias is short-circuited by the != guard but must stay safe."""
+    builder = _builder({"A": "A"})
+    assert builder._resolve_call("A", "foo.c") is None