Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 222 additions & 19 deletions libs/openant-core/parsers/c/call_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ def __init__(self, extractor_output: Dict, options: Optional[Dict] = None):
self.macros = extractor_output.get('macros', {})
self.macro_aliases = extractor_output.get('macro_aliases', {})
self.prototypes = extractor_output.get('prototypes', {})
# class_name -> [direct base-class name, ...] for the inheritance walk in
# member dispatch (bug [30]). Defaults to {} when the extractor output
# predates base-class extraction, so resolution degrades to the [51]
# same-type behavior rather than erroring.
self.class_bases: Dict[str, List[str]] = extractor_output.get('class_bases', {})
self.repo_path = extractor_output.get('repository', '')

self.max_depth = options.get('max_depth', 3)
Expand All @@ -121,6 +126,10 @@ def __init__(self, extractor_output: Dict, options: Optional[Dict] = None):
# Indexes for faster lookup
self.functions_by_name: Dict[str, List[str]] = {}
self.functions_by_file: Dict[str, List[str]] = {}
# class_name -> {base_method_name -> [func_id, ...]} for member dispatch.
# Scoped per (class, method) so a receiver-typed call resolves only to a
# method actually declared on that class, never a sibling/free function.
self.methods_by_class: Dict[str, Dict[str, List[str]]] = {}

# Include map: file -> set of included header files
self.include_map: Dict[str, Set[str]] = {}
Expand Down Expand Up @@ -153,6 +162,13 @@ def _build_indexes(self) -> None:
self.functions_by_file[file_path] = []
self.functions_by_file[file_path].append(func_id)

# Index methods by their declaring class for receiver-type dispatch.
class_name = func_data.get('class_name')
if class_name and name:
method_base = name.split('::')[-1] if '::' in name else name
self.methods_by_class.setdefault(class_name, {}) \
.setdefault(method_base, []).append(func_id)

# Build include map
for file_path, inc_list in self.includes.items():
self.include_map[file_path] = set()
Expand Down Expand Up @@ -191,15 +207,25 @@ def _extract_calls_from_code(self, code: str, caller_id: str) -> Set[str]:
except Exception:
return self._extract_calls_regex(code, caller_id)

# Receiver static types inferred from local declarations in this body,
# used to resolve member calls (w.compute() / w->compute()) to the
# method on the receiver's known type.
local_var_types = self._extract_local_var_types(tree.root_node, code_bytes)

stack = [tree.root_node]
while stack:
node = stack.pop()
if node.type == 'call_expression':
func_node = node.child_by_field_name('function')
if func_node:
call_name = self._extract_call_name(func_node, code_bytes)
call_name, receiver = self._extract_call_name_and_receiver(
func_node, code_bytes
)
if call_name:
resolved = self._resolve_call(call_name, caller_file)
receiver_type = local_var_types.get(receiver) if receiver else None
resolved = self._resolve_call(call_name, caller_file,
receiver_type=receiver_type,
is_member=func_node.type == 'field_expression')
if resolved:
calls.add(resolved)
# A function passed by name as an argument (e.g.
Expand Down Expand Up @@ -232,6 +258,81 @@ def _extract_callback_args(self, call_node, source: bytes, caller_file: str) ->
found.add(resolved)
return found

def _extract_call_name_and_receiver(self, node, source: bytes):
"""Return (call_name, receiver_identifier) for a call's function child.

receiver_identifier is the bare identifier text of a member-call receiver
(the `w` in `w.compute()` / `w->compute()`) when it is a simple
identifier, else None. The call_name is identical to what
_extract_call_name returns, so non-member calls are unaffected.
"""
if node.type == 'field_expression':
receiver = None
arg = node.child_by_field_name('argument')
if arg is not None and arg.type == 'identifier':
receiver = source[arg.start_byte:arg.end_byte].decode(
'utf-8', errors='replace')
# _extract_call_name declines field_expression (no false free-function
# edges); the member name is recovered here from the `field` child and
# resolved ONLY through typed/same-file member dispatch in _resolve_call.
field = node.child_by_field_name('field')
name = None
if field is not None:
name = source[field.start_byte:field.end_byte].decode(
'utf-8', errors='replace')
if not name.isidentifier():
name = None
return name, receiver
return self._extract_call_name(node, source), None

def _extract_local_var_types(self, root, source: bytes) -> Dict[str, str]:
"""Map local variable name -> declared type name within a function body.

Walks `declaration` nodes and records the (type_identifier, variable)
pairs for both plain declarations (`Widget w;`) and pointer declarations
(`Widget* w = ...;`). Only simple type_identifier types are recorded;
anything else (templates, qualified types, multiple declarators we can't
cleanly attribute) is skipped so callers fall back to base-name
resolution rather than risk a wrong-type edge.
"""
var_types: Dict[str, str] = {}
stack = [root]
while stack:
node = stack.pop()
if node.type == 'declaration':
type_node = node.child_by_field_name('type')
if type_node is not None and type_node.type == 'type_identifier':
type_name = source[type_node.start_byte:type_node.end_byte] \
.decode('utf-8', errors='replace')
# A declaration can hold several declarators (Widget a, b;);
# attribute the type to every variable name we extract.
for child in node.children:
var_name = self._declared_var_name(child, source)
if var_name:
var_types[var_name] = type_name
stack.extend(reversed(node.children))
return var_types

def _declared_var_name(self, node, source: bytes) -> Optional[str]:
"""Extract the declared variable identifier from a declarator subtree.

Handles the plain identifier (`w`), pointer_declarator (`* w`) and
init_declarator (`* w = ...` / `w = ...`) shapes. Returns None for nodes
that are not a variable declarator (e.g. the type node, `;`).
"""
if node.type == 'identifier':
return source[node.start_byte:node.end_byte].decode('utf-8', errors='replace')
if node.type in ('pointer_declarator', 'init_declarator', 'reference_declarator'):
inner = node.child_by_field_name('declarator')
if inner is not None:
return self._declared_var_name(inner, source)
# init_declarator with no declarator field: scan children.
for child in node.children:
name = self._declared_var_name(child, source)
if name:
return name
return None

def _extract_call_name(self, node, source: bytes) -> Optional[str]:
"""Extract the function name from a call_expression's function child."""
text = source[node.start_byte:node.end_byte].decode('utf-8', errors='replace')
Expand Down Expand Up @@ -274,27 +375,126 @@ def _is_visible_from(self, func_id: str, caller_file: str) -> bool:
return True
return not func_data.get('is_static', False)

def _resolve_call(self, call_name: str, caller_file: str) -> Optional[str]:
"""Resolve a function call name to a function ID."""
def _resolve_same_file(self, call_name: str, caller_file: str) -> Optional[str]:
"""Resolve a call to a user-defined function in the same file, if any."""
same_file_funcs = self.functions_by_file.get(caller_file, [])
for func_id in same_file_funcs:
func_data = self.functions.get(func_id, {})
fname = func_data.get('name', '')
base_name = fname.split('::')[-1] if '::' in fname else fname
if base_name == call_name:
return func_id
return None

def _resolve_method_on_class(self, class_name: str, call_name: str,
caller_file: str) -> Optional[str]:
"""Resolve call_name to a method DIRECTLY declared on class_name (same file).

Returns the func_id of a method named call_name declared on class_name and
defined in caller_file, else None. No inheritance — this is the single-hop
lookup the walk in _resolve_member_call composes over the base chain.
"""
by_method = self.methods_by_class.get(class_name)
if not by_method:
return None
for func_id in by_method.get(call_name, []):
func_data = self.functions.get(func_id, {})
if func_data.get('file_path', '') == caller_file:
return func_id
return None

def _resolve_member_call(self, call_name: str, caller_file: str,
receiver_type: str) -> Optional[str]:
"""Resolve a member call to the method on the receiver's STATIC type,
walking UP the base-class chain to the first ancestor that defines it.

Sound static-type floor (bug [30]): start at the receiver's declared type
and return its own method if it defines call_name; otherwise walk up its
direct base classes (BFS, cycle-guarded) and resolve to the FIRST ancestor
that declares call_name in the same file. The walk STOPS at the first
definer, so a derived override resolves to itself, never an ancestor.

Deliberately does NOT link derived overrides of an ancestor's virtual
method (a documented non-goal that would create false edges): a call via a
Base* receiver resolves to Base's method only — the static-type floor.

Same-file only: if no class on the chain defines call_name in this
translation unit, returns None so the caller falls back to base-name
resolution (never a wrong-type / unrelated-free-function edge).
"""
visited: Set[str] = set()
queue: List[str] = [receiver_type]
while queue:
cls = queue.pop(0)
if cls in visited:
continue
visited.add(cls)
# First definer on the chain wins (own type before ancestors).
match = self._resolve_method_on_class(cls, call_name, caller_file)
if match:
return match
for base in self.class_bases.get(cls, []):
if base not in visited:
queue.append(base)
return None

def _resolve_call(self, call_name: str, caller_file: str,
receiver_type: Optional[str] = None,
is_member: bool = False,
_alias_chain: Optional[Set[str]] = None) -> Optional[str]:
"""Resolve a function call name to a function ID.

When receiver_type is given (a member call like w.compute() whose receiver
w has a known same-file type), resolve to that type's method FIRST. If
that fails, fall through to the unchanged base-name resolution below.
"""
if receiver_type:
member_match = self._resolve_member_call(call_name, caller_file,
receiver_type)
if member_match:
return member_match

# A user-defined function in the SAME FILE shadows any stdlib/builtin
# of the same name, so it must be checked BEFORE the stdlib filter.
# Scope is deliberately same-file only: a genuine stdlib call (no
# same-file definition) still falls through to _is_stdlib below, so we
# never wrongly link a real stdlib call (e.g. printf/open) to an
# unrelated same-named user function in another file.
same_file_user_func = self._resolve_same_file(call_name, caller_file)
if same_file_user_func:
return same_file_user_func

# A member call (obj->m() / obj.m()) whose receiver type is unknown or
# whose chain defines no such method resolves same-file only: declining
# here keeps the field-expression precision guarantee (never an edge to
# an unrelated cross-file free function of the same name).
if is_member:
return None

if self._is_stdlib(call_name):
return None

# Check for macro aliases
resolved_name = self.macro_aliases.get(call_name, call_name)
if resolved_name != call_name:
# Try resolving the aliased name instead
result = self._resolve_call(resolved_name, caller_file)
if result:
return result
# Guard against cyclic macro aliases (e.g. ``#define A B`` /
# ``#define B A`` -> {"A": "B", "B": "A"}). Without a visited-set
# the recursion below would loop A->B->A->... until RecursionError
# aborted the whole repo's call-graph build.
if _alias_chain is None:
_alias_chain = {call_name}
if resolved_name not in _alias_chain:
_alias_chain.add(resolved_name)
# Try resolving the aliased name instead
result = self._resolve_call(resolved_name, caller_file,
_alias_chain=_alias_chain)
if result:
return result

# 1. Same-file functions
same_file_funcs = self.functions_by_file.get(caller_file, [])
for func_id in same_file_funcs:
func_data = self.functions.get(func_id, {})
fname = func_data.get('name', '')
base_name = fname.split('::')[-1] if '::' in fname else fname
if base_name == call_name:
return func_id
same_file_match = self._resolve_same_file(call_name, caller_file)
if same_file_match:
return same_file_match

# 2. Functions in included headers
included_files = self.include_map.get(caller_file, set())
Expand Down Expand Up @@ -392,10 +592,13 @@ def _extract_calls_regex(self, code: str, caller_id: str) -> Set[str]:
if func_name in ('if', 'while', 'for', 'switch', 'return', 'sizeof',
'typeof', 'alignof', 'offsetof', 'case', 'else'):
continue
if not self._is_stdlib(func_name):
resolved = self._resolve_call(func_name, caller_file)
if resolved:
calls.add(resolved)
# No _is_stdlib gate here: _resolve_call applies the same-file-first
# rule and the stdlib filter internally, so a user function whose
# name collides with a builtin still resolves (same leak as the
# tree-sitter path otherwise).
resolved = self._resolve_call(func_name, caller_file)
if resolved:
calls.add(resolved)

return calls

Expand Down
51 changes: 51 additions & 0 deletions libs/openant-core/tests/test_c_macro_alias_cycle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Regression test for F4: C call-graph builder must not crash on cyclic macro aliases.

A function-like ``#define`` pair such as::

#define A(x) B(x)
#define B(x) A(x)

produces ``macro_aliases = {"A": "B", "B": "A"}``. Before the fix,
``_resolve_call`` recursed on the aliased name with no cycle guard, so
resolving ``A`` recursed A->B->A->... until ``RecursionError`` aborted the
entire repository's C call-graph build.
"""

from __future__ import annotations

import pytest

tree_sitter_c = pytest.importorskip("tree_sitter_c")

from parsers.c.call_graph_builder import CallGraphBuilder


def _builder(macro_aliases):
return CallGraphBuilder(
{
"functions": {},
"includes": {},
"macros": {},
"macro_aliases": macro_aliases,
}
)


def test_cyclic_macro_alias_does_not_recurse():
"""Two-node alias cycle must resolve to None, not raise RecursionError."""
builder = _builder({"A": "B", "B": "A"})
# Must not raise; unresolved cyclic alias returns None.
assert builder._resolve_call("A", "foo.c") is None
assert builder._resolve_call("B", "foo.c") is None


def test_longer_macro_alias_cycle_terminates():
"""A 3-node alias cycle must also terminate."""
builder = _builder({"A": "B", "B": "C", "C": "A"})
assert builder._resolve_call("A", "foo.c") is None


def test_self_macro_alias_does_not_recurse():
"""A self-alias is short-circuited by the != guard but must stay safe."""
builder = _builder({"A": "A"})
assert builder._resolve_call("A", "foo.c") is None
Loading