diff --git a/code-review-graph-vscode/src/backend/sqlite.ts b/code-review-graph-vscode/src/backend/sqlite.ts index 15d617a6..4e8e1b4b 100644 --- a/code-review-graph-vscode/src/backend/sqlite.ts +++ b/code-review-graph-vscode/src/backend/sqlite.ts @@ -212,7 +212,7 @@ export class SqliteReader { if (row) { const version = parseInt(row.value, 10); // Must match LATEST_VERSION in code_review_graph/migrations.py - const SUPPORTED_SCHEMA_VERSION = 6; + const SUPPORTED_SCHEMA_VERSION = 8; if (!isNaN(version) && version > SUPPORTED_SCHEMA_VERSION) { return `Database was created with a newer version (schema v${version}). Update the extension.`; } diff --git a/code_review_graph/changes.py b/code_review_graph/changes.py index 33da1978..8ed8b41e 100644 --- a/code_review_graph/changes.py +++ b/code_review_graph/changes.py @@ -154,15 +154,19 @@ def compute_risk_score(store: GraphStore, node: GraphNode) -> float: Scoring factors: - Flow participation: 0.05 per flow membership, capped at 0.25 - Community crossing: 0.05 per caller from a different community, capped at 0.15 - - Test coverage: 0.30 if no TESTED_BY edges, 0.05 if tested + - Test coverage: 0.30 (untested) scaling down to 0.05 (5+ TESTED_BY edges) - Security sensitivity: 0.20 if name matches security keywords - Caller count: callers / 20, capped at 0.10 """ score = 0.0 - # --- Flow participation (cap 0.25) --- - flow_count = store.count_flow_memberships(node.id) - score += min(flow_count * 0.05, 0.25) + # --- Flow participation (cap 0.25), weighted by criticality --- + flow_criticalities = store.get_flow_criticalities_for_node(node.id) + if flow_criticalities: + score += min(sum(flow_criticalities), 0.25) + else: + flow_count = store.count_flow_memberships(node.id) + score += min(flow_count * 0.05, 0.25) # --- Community crossing (cap 0.15) --- callers = store.get_edges_by_target(node.qualified_name) @@ -179,10 +183,10 @@ def compute_risk_score(store: GraphStore, node: GraphNode) -> float: cross_community += 1 score += 
min(cross_community * 0.05, 0.15) - # --- Test coverage --- - tested_edges = store.get_edges_by_target(node.qualified_name) - has_test = any(e.kind == "TESTED_BY" for e in tested_edges) - score += 0.05 if has_test else 0.30 + # --- Test coverage (direct + transitive) --- + transitive_tests = store.get_transitive_tests(node.qualified_name) + test_count = len(transitive_tests) + score += 0.30 - (min(test_count / 5.0, 1.0) * 0.25) # --- Security sensitivity --- name_lower = node.name.lower() diff --git a/code_review_graph/communities.py b/code_review_graph/communities.py index eef8c4aa..3810927b 100644 --- a/code_review_graph/communities.py +++ b/code_review_graph/communities.py @@ -196,8 +196,19 @@ def _compute_cohesion_batch( return results +def _build_adjacency(edges: list[GraphEdge]) -> dict[str, list[str]]: + """Build adjacency list from edges (one pass over all edges).""" + adj: dict[str, list[str]] = defaultdict(list) + for e in edges: + adj[e.source_qualified].append(e.target_qualified) + adj[e.target_qualified].append(e.source_qualified) + return adj + + def _compute_cohesion( - member_qns: set[str], all_edges: list[GraphEdge] + member_qns: set[str], + all_edges: list[GraphEdge], + adj: dict[str, list[str]] | None = None, ) -> float: """Compute cohesion: internal_edges / (internal_edges + external_edges). @@ -213,7 +224,10 @@ def _compute_cohesion( def _detect_leiden( - nodes: list[GraphNode], edges: list[GraphEdge], min_size: int + nodes: list[GraphNode], + edges: list[GraphEdge], + min_size: int, + adj: dict[str, list[str]] | None = None, ) -> list[dict[str, Any]]: """Detect communities using Leiden algorithm via igraph. 
@@ -251,11 +265,18 @@ def _detect_leiden( weights.append(EDGE_WEIGHTS.get(e.kind, 0.5)) if not edge_list: - return _detect_file_based(nodes, edges, min_size) + return _detect_file_based(nodes, edges, min_size, adj=adj) g.add_edges(edge_list) g.es["weight"] = weights + # Run Leiden -- scale resolution inversely with graph size to get + # coarser clusters on large repos. Default resolution=1.0 produces + # thousands of tiny communities for 30k+ node graphs. + import math + n_nodes = g.vcount() + resolution = max(0.05, 1.0 / math.log10(max(n_nodes, 10))) + logger.info( "Running Leiden on %d nodes, %d edges...", g.vcount(), g.ecount(), @@ -264,6 +285,7 @@ def _detect_leiden( partition = g.community_leiden( objective_function="modularity", weights="weight", + resolution=resolution, n_iterations=2, ) @@ -311,28 +333,73 @@ def _detect_leiden( def _detect_file_based( - nodes: list[GraphNode], edges: list[GraphEdge], min_size: int + nodes: list[GraphNode], + edges: list[GraphEdge], + min_size: int, + adj: dict[str, list[str]] | None = None, ) -> list[dict[str, Any]]: - """Group nodes by file_path when igraph is not available.""" - by_file: dict[str, list[GraphNode]] = defaultdict(list) + """Group nodes by directory when Leiden is unavailable or over-fragments. + + Strips the longest common directory prefix from all file paths, then + adaptively picks a grouping depth that yields 10-200 communities. 
+ """ + # Collect all directory paths (normalized, without filename) + all_dir_parts: list[list[str]] = [] for n in nodes: - by_file[n.file_path].append(n) + parts = n.file_path.replace("\\", "/").split("/") + all_dir_parts.append([p for p in parts[:-1] if p]) + + # Find the longest common prefix among directory parts + prefix_len = 0 + if all_dir_parts: + shortest = min(len(p) for p in all_dir_parts) + for i in range(shortest): + seg = all_dir_parts[0][i] + if all(p[i] == seg for p in all_dir_parts): + prefix_len = i + 1 + else: + break + + def _group_at_depth(depth: int) -> dict[str, list[GraphNode]]: + groups: dict[str, list[GraphNode]] = defaultdict(list) + for n in nodes: + parts = n.file_path.replace("\\", "/").split("/") + dir_parts = [p for p in parts[:-1] if p] + remainder = dir_parts[prefix_len:] + if remainder: + key = "/".join(remainder[:depth]) + else: + key = parts[-1].rsplit(".", 1)[0] if parts else "root" + groups[key].append(n) + return groups + + # Try increasing depths until we get 10-200 qualifying groups + max_depth = max((len(p) - prefix_len for p in all_dir_parts), default=0) + best_groups = _group_at_depth(1) # depth=1 always works (file stem fallback) + for depth in range(1, max_depth + 1): + groups = _group_at_depth(depth) + qualifying = sum(1 for v in groups.values() if len(v) >= min_size) + best_groups = groups + if qualifying >= 10: + break + + by_dir = best_groups # Pre-filter to communities meeting min_size and collect their member # sets so we can batch-compute all cohesions in a single O(edges) pass. # Without this, per-community cohesion is O(edges * files), which makes # community detection effectively hang on large repos. 
pending: list[tuple[str, list[GraphNode], set[str]]] = [] - for file_path, members in by_file.items(): + for dir_path, members in by_dir.items(): if len(members) < min_size: continue member_qns = {m.qualified_name for m in members} - pending.append((file_path, members, member_qns)) + pending.append((dir_path, members, member_qns)) cohesions = _compute_cohesion_batch([p[2] for p in pending], edges) communities: list[dict[str, Any]] = [] - for (file_path, members, member_qns), cohesion in zip(pending, cohesions): + for (dir_path, members, member_qns), cohesion in zip(pending, cohesions): lang_counts = Counter(m.language for m in members if m.language) dominant_lang = lang_counts.most_common(1)[0][0] if lang_counts else "" name = _generate_community_name(members) @@ -343,7 +410,7 @@ def _detect_file_based( "size": len(members), "cohesion": round(cohesion, 4), "dominant_language": dominant_lang, - "description": f"File-based community: {file_path}", + "description": f"Directory-based community: {dir_path}", "members": [m.qualified_name for m in members], "member_qns": member_qns, }) @@ -374,28 +441,10 @@ def detect_communities( """ # Gather all nodes (exclude File nodes to focus on code entities) all_edges = store.get_all_edges() - all_files = store.get_all_files() + unique_nodes = store.get_all_nodes(exclude_files=True) - logger.info("Loading nodes from %d files...", len(all_files)) - - nodes: list[GraphNode] = [] - for fp in all_files: - nodes.extend(store.get_nodes_by_file(fp)) - - # Also gather nodes from files referenced in edges but not in all_files - edge_files: set[str] = set() - for e in all_edges: - edge_files.add(e.file_path) - for fp in edge_files - set(all_files): - nodes.extend(store.get_nodes_by_file(fp)) - - # Deduplicate by qualified_name - seen_qns: set[str] = set() - unique_nodes: list[GraphNode] = [] - for n in nodes: - if n.qualified_name not in seen_qns: - seen_qns.add(n.qualified_name) - unique_nodes.append(n) + # Build adjacency index once for 
fast cohesion computation + adj = _build_adjacency(all_edges) logger.info( "Loaded %d unique nodes, %d edges", @@ -404,10 +453,10 @@ def detect_communities( if IGRAPH_AVAILABLE: logger.info("Detecting communities with Leiden algorithm (igraph)") - results = _detect_leiden(unique_nodes, all_edges, min_size) + results = _detect_leiden(unique_nodes, all_edges, min_size, adj=adj) else: logger.info("igraph not available, using file-based community detection") - results = _detect_file_based(unique_nodes, all_edges, min_size) + results = _detect_file_based(unique_nodes, all_edges, min_size, adj=adj) # Convert member_qns (internal set) to a list for serialization safety, # then strip it from the returned dicts to avoid leaking internal state. @@ -569,6 +618,17 @@ def get_communities( return communities +_TEST_COMMUNITY_RE = re.compile( + r"(^test[-/]|[-/]test([:/]|$)|it:should|describe:|spec[-/]|[-/]spec$)", + re.IGNORECASE, +) + + +def _is_test_community(name: str) -> bool: + """Return True if a community name indicates it is test-dominated.""" + return bool(_TEST_COMMUNITY_RE.search(name)) + + def get_architecture_overview(store: GraphStore) -> dict[str, Any]: """Generate an architecture overview based on community structure. @@ -596,6 +656,10 @@ def get_architecture_overview(store: GraphStore) -> dict[str, Any]: cross_counts: Counter[tuple[int, int]] = Counter() for e in all_edges: + # TESTED_BY edges are expected cross-community coupling (test → code), + # not an architectural smell. + if e.kind == "TESTED_BY": + continue src_comm = node_to_community.get(e.source_qualified) tgt_comm = node_to_community.get(e.target_qualified) if ( @@ -613,13 +677,17 @@ def get_architecture_overview(store: GraphStore) -> dict[str, Any]: "target": _sanitize_name(e.target_qualified), }) - # Generate warnings for high coupling + # Generate warnings for high coupling, skipping test-dominated pairs. 
warnings: list[str] = [] comm_name_map = {c.get("id", 0): c["name"] for c in communities} for (c1, c2), count in cross_counts.most_common(): if count > 10: name1 = comm_name_map.get(c1, f"community-{c1}") name2 = comm_name_map.get(c2, f"community-{c2}") + # Skip pairs where either community is test-dominated — coupling + # between test and production code is expected, not architectural. + if _is_test_community(name1) or _is_test_community(name2): + continue warnings.append( f"High coupling ({count} edges) between " f"'{name1}' and '{name2}'" diff --git a/code_review_graph/flows.py b/code_review_graph/flows.py index 1cfb4964..193171e4 100644 --- a/code_review_graph/flows.py +++ b/code_review_graph/flows.py @@ -25,14 +25,46 @@ # Decorator patterns that indicate a function is a framework entry point. _FRAMEWORK_DECORATOR_PATTERNS: list[re.Pattern[str]] = [ - re.compile(r"app\.(get|post|put|delete|patch|route|websocket)", re.IGNORECASE), + # Python web frameworks + re.compile(r"app\.(get|post|put|delete|patch|route|websocket|on_event)", re.IGNORECASE), re.compile(r"router\.(get|post|put|delete|patch|route)", re.IGNORECASE), re.compile(r"blueprint\.(route|before_request|after_request)", re.IGNORECASE), + re.compile(r"(before|after)_(request|response)", re.IGNORECASE), + # CLI frameworks re.compile(r"click\.(command|group)", re.IGNORECASE), - re.compile(r"celery\.(task|shared_task)", re.IGNORECASE), + re.compile(r"\w+\.(command|group)\b", re.IGNORECASE), # Click subgroups: @mygroup.command() + # Pydantic validators/serializers + re.compile(r"(field|model)_(serializer|validator)", re.IGNORECASE), + # Task queues + re.compile(r"(celery\.)?(task|shared_task|periodic_task)", re.IGNORECASE), + # Django + re.compile(r"receiver", re.IGNORECASE), re.compile(r"api_view", re.IGNORECASE), re.compile(r"\baction\b", re.IGNORECASE), - re.compile(r"@(Get|Post|Put|Delete|Patch|RequestMapping)", re.IGNORECASE), + # Testing + re.compile(r"pytest\.(fixture|mark)"), + 
re.compile(r"(override_settings|modify_settings)", re.IGNORECASE), + # SQLAlchemy / event systems + re.compile(r"(event\.)?listens_for", re.IGNORECASE), + # Java Spring + re.compile(r"(Get|Post|Put|Delete|Patch|RequestMapping)Mapping", re.IGNORECASE), + re.compile(r"(Scheduled|EventListener|Bean|Configuration)", re.IGNORECASE), + # JS/TS frameworks + re.compile(r"(Component|Injectable|Controller|Module|Guard|Pipe)", re.IGNORECASE), + re.compile(r"(Subscribe|Mutation|Query|Resolver)", re.IGNORECASE), + # Express / Koa / Hono route handlers + re.compile(r"(app|router)\.(get|post|put|delete|patch|use|all)\b"), + # Android lifecycle + re.compile(r"@(Override|OnLifecycleEvent|Composable)", re.IGNORECASE), + # Kotlin coroutines / Android ViewModel + re.compile(r"(HiltViewModel|AndroidEntryPoint|Inject)", re.IGNORECASE), + # AI/agent frameworks (pydantic-ai, langchain, etc.) + re.compile(r"\w+\.(tool|tool_plain|system_prompt|result_validator)\b", re.IGNORECASE), + re.compile(r"^tool\b"), # bare @tool (LangChain, etc.) + # Middleware and exception handlers (Starlette, FastAPI, Sanic) + re.compile(r"\w+\.(middleware|exception_handler|on_exception)\b", re.IGNORECASE), + # Generic route decorator (Flask blueprints: @bp.route, @auth_bp.route, etc.) + re.compile(r"\w+\.route\b", re.IGNORECASE), ] # Name patterns that indicate conventional entry points. 
@@ -43,6 +75,38 @@ re.compile(r"^Test[A-Z]"), re.compile(r"^on_"), re.compile(r"^handle_"), + # Lambda / serverless handler functions (wired via config, not code calls) + re.compile(r"^handler$"), + re.compile(r"^handle$"), + re.compile(r"^lambda_handler$"), + # Alembic migration entry points + re.compile(r"^upgrade$"), + re.compile(r"^downgrade$"), + # FastAPI lifecycle / dependency injection + re.compile(r"^lifespan$"), + re.compile(r"^get_db$"), + # Android Activity/Fragment lifecycle + re.compile(r"^on(Create|Start|Resume|Pause|Stop|Destroy|Bind|Receive)"), + # Servlet / JAX-RS + re.compile(r"^do(Get|Post|Put|Delete)$"), + # Python BaseHTTPRequestHandler + re.compile(r"^do_(GET|POST|PUT|DELETE|PATCH|HEAD|OPTIONS)$"), + re.compile(r"^log_message$"), + # Express middleware signature + re.compile(r"^(middleware|errorHandler)$"), + # Angular lifecycle hooks + re.compile( + r"^ng(OnInit|OnChanges|OnDestroy|DoCheck" + r"|AfterContentInit|AfterContentChecked|AfterViewInit|AfterViewChecked)$" + ), + # Angular Pipe / ControlValueAccessor / Guards / Resolvers + re.compile(r"^(transform|writeValue|registerOnChange|registerOnTouched|setDisabledState)$"), + re.compile(r"^(canActivate|canDeactivate|canActivateChild|canLoad|canMatch|resolve)$"), + # React class component lifecycle + re.compile( + r"^(componentDidMount|componentDidUpdate|componentWillUnmount" + r"|shouldComponentUpdate|render)$" + ), ] @@ -73,13 +137,29 @@ def _matches_entry_name(node: GraphNode) -> bool: return False -def detect_entry_points(store: GraphStore) -> list[GraphNode]: +_TEST_FILE_RE = re.compile( + r"([\\/]__tests__[\\/]|\.spec\.[jt]sx?$|\.test\.[jt]sx?$|[\\/]test_[^/\\]*\.py$)", +) + + +def _is_test_file(file_path: str) -> bool: + """Return True if *file_path* looks like a test file.""" + return bool(_TEST_FILE_RE.search(file_path)) + + +def detect_entry_points( + store: GraphStore, + include_tests: bool = False, +) -> list[GraphNode]: """Find functions that are entry points in the graph. 
An entry point is a Function/Test node that either: 1. Has no incoming CALLS edges (true root), or 2. Has a framework decorator (e.g. ``@app.get``), or 3. Matches a conventional name pattern (``main``, ``test_*``, etc.). + + When *include_tests* is False (the default), Test nodes are excluded so + that flow analysis focuses on production entry points. """ # Build a set of all qualified names that are CALLS targets. called_qnames = store.get_all_call_targets() @@ -91,6 +171,9 @@ def detect_entry_points(store: GraphStore) -> list[GraphNode]: seen_qn: set[str] = set() for node in candidate_nodes: + if not include_tests and (node.is_test or _is_test_file(node.file_path)): + continue + is_entry = False # True root: no one calls this function. @@ -189,7 +272,11 @@ def _trace_single_flow( return flow -def trace_flows(store: GraphStore, max_depth: int = 15) -> list[dict]: +def trace_flows( + store: GraphStore, + max_depth: int = 15, + include_tests: bool = False, +) -> list[dict]: """Trace execution flows from every entry point via forward BFS. 
Returns a list of flow dicts, each containing: @@ -203,7 +290,7 @@ def trace_flows(store: GraphStore, max_depth: int = 15) -> list[dict]: - files: list of distinct file paths - criticality: computed criticality score (0.0-1.0) """ - entry_points = detect_entry_points(store) + entry_points = detect_entry_points(store, include_tests=include_tests) flows: list[dict] = [] for ep in entry_points: diff --git a/code_review_graph/graph.py b/code_review_graph/graph.py index b2d75bba..71540ac5 100644 --- a/code_review_graph/graph.py +++ b/code_review_graph/graph.py @@ -260,6 +260,24 @@ def store_file_nodes_edges( raise self._invalidate_cache() + def store_file_batch( + self, batch: list[tuple[str, list[NodeInfo], list[EdgeInfo], str]] + ) -> None: + """Atomically replace data for a batch of files in one transaction.""" + self._conn.execute("BEGIN IMMEDIATE") + try: + for file_path, nodes, edges, fhash in batch: + self.remove_file_data(file_path) + for node in nodes: + self.upsert_node(node, file_hash=fhash) + for edge in edges: + self.upsert_edge(edge) + self._conn.commit() + except BaseException: + self._conn.rollback() + raise + self._invalidate_cache() + def set_metadata(self, key: str, value: str) -> None: self._conn.execute( "INSERT OR REPLACE INTO metadata (key, value) VALUES (?, ?)", (key, value) @@ -290,6 +308,16 @@ def get_nodes_by_file(self, file_path: str) -> list[GraphNode]: ).fetchall() return [self._row_to_node(r) for r in rows] + def get_all_nodes(self, exclude_files: bool = True) -> list[GraphNode]: + """Return all nodes, optionally excluding File nodes.""" + if exclude_files: + rows = self._conn.execute( + "SELECT * FROM nodes WHERE kind != 'File'" + ).fetchall() + else: + rows = self._conn.execute("SELECT * FROM nodes").fetchall() + return [self._row_to_node(r) for r in rows] + def get_edges_by_source(self, qualified_name: str) -> list[GraphEdge]: rows = self._conn.execute( "SELECT * FROM edges WHERE source_qualified = ?", (qualified_name,) @@ -317,6 
+345,182 @@ def search_edges_by_target_name(self, name: str, kind: str = "CALLS") -> list[Gr ).fetchall() return [self._row_to_edge(r) for r in rows] + def get_transitive_tests( + self, qualified_name: str, max_depth: int = 1, + ) -> list[dict]: + """Find tests covering a node, including indirect (transitive) coverage. + + 1. Direct: TESTED_BY edges targeting this node (+ bare-name fallback). + 2. Indirect: follow outgoing CALLS edges up to *max_depth* hops, + then collect TESTED_BY edges on each callee. + + Returns a list of dicts with node fields plus ``indirect: bool``. + """ + conn = self._conn + seen: set[str] = set() + results: list[dict] = [] + + # If the input is a class, expand to its methods first. + input_qns = [qualified_name] + row = conn.execute( + "SELECT kind FROM nodes WHERE qualified_name = ?", + (qualified_name,), + ).fetchone() + if row and row["kind"] == "Class": + for mrow in conn.execute( + "SELECT target_qualified FROM edges " + "WHERE source_qualified = ? AND kind = 'CONTAINS'", + (qualified_name,), + ).fetchall(): + input_qns.append(mrow["target_qualified"]) + + def _node_dict(qn: str, indirect: bool) -> dict | None: + row = conn.execute( + "SELECT * FROM nodes WHERE qualified_name = ?", (qn,) + ).fetchone() + if not row: + return None + return { + "name": row["name"], + "qualified_name": row["qualified_name"], + "file_path": row["file_path"], + "kind": row["kind"], + "indirect": indirect, + } + + # Direct TESTED_BY + for qn in input_qns: + for row in conn.execute( + "SELECT source_qualified FROM edges " + "WHERE target_qualified = ? AND kind = 'TESTED_BY'", + (qn,), + ).fetchall(): + src = row["source_qualified"] + if src not in seen: + seen.add(src) + d = _node_dict(src, indirect=False) + if d: + results.append(d) + + # Bare-name fallback for direct + bare = qualified_name.rsplit("::", 1)[-1] if "::" in qualified_name else qualified_name + for row in conn.execute( + "SELECT source_qualified FROM edges " + "WHERE target_qualified = ? 
AND kind = 'TESTED_BY'", + (bare,), + ).fetchall(): + src = row["source_qualified"] + if src not in seen: + seen.add(src) + d = _node_dict(src, indirect=False) + if d: + results.append(d) + + # Transitive: follow CALLS edges, then collect TESTED_BY on callees + frontier = set(input_qns) + for _ in range(max_depth): + next_frontier: set[str] = set() + for qn in frontier: + for row in conn.execute( + "SELECT target_qualified FROM edges " + "WHERE source_qualified = ? AND kind = 'CALLS'", + (qn,), + ).fetchall(): + next_frontier.add(row["target_qualified"]) + for callee in next_frontier: + for row in conn.execute( + "SELECT source_qualified FROM edges " + "WHERE target_qualified = ? AND kind = 'TESTED_BY'", + (callee,), + ).fetchall(): + src = row["source_qualified"] + if src not in seen: + seen.add(src) + d = _node_dict(src, indirect=True) + if d: + results.append(d) + frontier = next_frontier + + return results + + def resolve_bare_call_targets(self) -> int: + """Batch-resolve bare-name CALLS targets using the global node table. + + After parsing, some CALLS edges have bare targets (no ``::`` separator) + because the parser couldn't resolve cross-file. This method matches + them against nodes and updates unambiguous matches in-place. + + Disambiguation strategy: + 1. Single node with that name -> resolve directly + 2. Multiple candidates -> prefer one whose file is imported by the + source file (via IMPORTS_FROM edges) + + Returns the number of resolved edges. 
+ """ + conn = self._conn + + bare_edges = conn.execute( + "SELECT id, source_qualified, target_qualified, file_path " + "FROM edges WHERE kind = 'CALLS' AND target_qualified NOT LIKE '%::%'" + ).fetchall() + if not bare_edges: + return 0 + + # bare_name -> list of qualified_names + node_lookup: dict[str, list[str]] = {} + for row in conn.execute( + "SELECT name, qualified_name FROM nodes " + "WHERE kind IN ('Function', 'Test', 'Class')" + ).fetchall(): + node_lookup.setdefault(row["name"], []).append(row["qualified_name"]) + + # source_file -> set of imported files (for disambiguation) + import_targets: dict[str, set[str]] = {} + for row in conn.execute( + "SELECT DISTINCT file_path, target_qualified FROM edges " + "WHERE kind = 'IMPORTS_FROM'" + ).fetchall(): + target = row["target_qualified"] + target_file = target.split("::", 1)[0] if "::" in target else target + import_targets.setdefault(row["file_path"], set()).add(target_file) + + resolved = 0 + for edge in bare_edges: + bare_name = edge["target_qualified"] + candidates = node_lookup.get(bare_name, []) + if not candidates: + continue + + if len(candidates) == 1: + qualified = candidates[0] + else: + # Disambiguate via imports + src_qn = edge["source_qualified"] + src_file = ( + src_qn.split("::", 1)[0] if "::" in src_qn + else edge["file_path"] + ) + imported_files = import_targets.get(src_file, set()) + imported = [ + c for c in candidates + if c.split("::", 1)[0] in imported_files + ] + if len(imported) == 1: + qualified = imported[0] + else: + continue + + conn.execute( + "UPDATE edges SET target_qualified = ? 
WHERE id = ?", + (qualified, edge["id"]), + ) + resolved += 1 + + if resolved: + conn.commit() + logger.info("Resolved %d bare-name CALLS targets", resolved) + return resolved + def get_all_files(self) -> list[str]: rows = self._conn.execute( "SELECT DISTINCT file_path FROM nodes WHERE kind = 'File'" @@ -324,24 +528,43 @@ def get_all_files(self) -> list[str]: return [r["file_path"] for r in rows] def search_nodes(self, query: str, limit: int = 20) -> list[GraphNode]: - """Keyword search across node names with multi-word AND logic. + """Keyword search across node names. - Each word in the query must match independently (case-insensitive) - against the node name or qualified name. For example, - ``"firebase auth"`` matches ``verify_firebase_token`` and - ``FirebaseAuth`` but not ``get_user``. + Tries FTS5 first (fast, tokenized matching), then falls back to + LIKE-based substring search when FTS5 returns no results. """ - words = query.lower().split() + words = query.split() if not words: return [] + # Phase 1: FTS5 search (uses the indexed nodes_fts table) + try: + if len(words) == 1: + fts_query = '"' + query.replace('"', '""') + '"' + else: + fts_query = " AND ".join( + '"' + w.replace('"', '""') + '"' for w in words + ) + rows = self._conn.execute( + "SELECT n.* FROM nodes_fts f " + "JOIN nodes n ON f.rowid = n.id " + "WHERE nodes_fts MATCH ? LIMIT ?", + (fts_query, limit), + ).fetchall() + if rows: + return [self._row_to_node(r) for r in rows] + except Exception: # nosec B110 - FTS5 table may not exist on older schemas + pass + + # Phase 2: LIKE fallback (substring matching) conditions: list[str] = [] params: list[str | int] = [] for word in words: + w = word.lower() conditions.append( "(LOWER(name) LIKE ? OR LOWER(qualified_name) LIKE ?)" ) - params.extend([f"%{word}%", f"%{word}%"]) + params.extend([f"%{w}%", f"%{w}%"]) where = " AND ".join(conditions) sql = f"SELECT * FROM nodes WHERE {where} LIMIT ?" 
# nosec B608 @@ -699,6 +922,16 @@ def count_flow_memberships(self, node_id: int) -> int: ).fetchone() return row["cnt"] if row else 0 + def get_flow_criticalities_for_node(self, node_id: int) -> list[float]: + """Return criticality values for all flows a node participates in.""" + rows = self._conn.execute( + "SELECT f.criticality FROM flows f " + "JOIN flow_memberships fm ON fm.flow_id = f.id " + "WHERE fm.node_id = ?", + (node_id,), + ).fetchall() + return [r["criticality"] for r in rows] + def get_node_community_id(self, node_id: int) -> int | None: """Return the ``community_id`` for a node, or ``None``.""" row = self._conn.execute( diff --git a/code_review_graph/migrations.py b/code_review_graph/migrations.py index 9da04885..6ef33ac2 100644 --- a/code_review_graph/migrations.py +++ b/code_review_graph/migrations.py @@ -203,6 +203,20 @@ def _migrate_v6(conn: sqlite3.Connection) -> None: "(community_summaries, flow_snapshots, risk_index)") +def _migrate_v7(conn: sqlite3.Connection) -> None: + """v7: Reserved (upstream PR #127). 
No-op for forward compatibility.""" + logger.info("Migration v7: reserved (no-op)") + + +def _migrate_v8(conn: sqlite3.Connection) -> None: + """v8: Add composite index on edges for upsert_edge performance.""" + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_edges_composite + ON edges(kind, source_qualified, target_qualified, file_path, line) + """) + logger.info("Migration v8: created composite edge index") + + # --------------------------------------------------------------------------- # Migration registry # --------------------------------------------------------------------------- @@ -213,6 +227,8 @@ def _migrate_v6(conn: sqlite3.Connection) -> None: 4: _migrate_v4, 5: _migrate_v5, 6: _migrate_v6, + 7: _migrate_v7, + 8: _migrate_v8, } LATEST_VERSION = max(MIGRATIONS.keys()) diff --git a/code_review_graph/refactor.py b/code_review_graph/refactor.py index 4dce1603..32994603 100644 --- a/code_review_graph/refactor.py +++ b/code_review_graph/refactor.py @@ -8,7 +8,9 @@ from __future__ import annotations +import functools import logging +import re import threading import time import uuid @@ -20,6 +22,28 @@ logger = logging.getLogger(__name__) +# Base class names that indicate a framework-managed class (ORM models, +# Pydantic schemas, settings). Classes inheriting from these are invoked +# via metaclass/framework magic and should not be flagged as dead code. +_FRAMEWORK_BASE_CLASSES = frozenset({ + "Base", "DeclarativeBase", "Model", "BaseModel", "BaseSettings", + "db.Model", "TableBase", + # AWS CDK constructs -- instantiated by CDK app wiring, not explicit CALLS. + "Stack", "NestedStack", "Construct", "Resource", +}) + +# Class name suffixes that indicate CDK/IaC constructs. +# These are instantiated by framework wiring, not direct CALLS edges. +# Used as fallback when INHERITS edges to external base classes are absent. 
+_CDK_CLASS_SUFFIXES = ("Stack", "Construct", "Pipeline", "Resources", "Layer") + +# Patterns for mock/stub variables in test files that should not be flagged dead. +_MOCK_NAME_RE = re.compile( + r"^(mock[A-Z_]|Mock[A-Z]|createMock[A-Z])|" # mockDynamoClient, MockService, createMockX + r"(Mock|Stub|Fake|Spy)$", # s3ClientMock, dbStub + re.IGNORECASE, +) + # --------------------------------------------------------------------------- # Thread-safe pending refactors storage # --------------------------------------------------------------------------- @@ -173,6 +197,46 @@ def _is_entry_point(node: Any) -> bool: return False +# Matches identifiers inside type annotations (e.g. "GoalCreate" in +# "body: GoalCreate", "Optional[UserResponse]", "list[Item]"). +_TEST_FILE_RE = re.compile( + r"([\\/]__tests__[\\/]|\.spec\.[jt]sx?$|\.test\.[jt]sx?$|[\\/]test_[^/\\]*\.py$" + r"|[\\/]e2e[_-]?tests?[\\/]|[\\/]test[_-]utils?[\\/])", +) + + +def _is_test_file(file_path: str) -> bool: + """Return True if *file_path* looks like a test file.""" + return bool(_TEST_FILE_RE.search(file_path)) + + +_MIN_PKG_SEGMENT_LEN = 4 # ignore short dirs like "src", "lib", "app" + + +@functools.lru_cache(maxsize=4096) +def _path_segments(file_path: str) -> tuple[str, ...]: + """Return directory segments long enough to serve as package-name anchors.""" + parts = file_path.replace("\\", "/").split("/") + return tuple( + p for p in parts[:-1] # skip the filename itself + if len(p) >= _MIN_PKG_SEGMENT_LEN and p not in ("home", "src", "lib", "app") + ) + + +_TYPE_IDENT_RE = re.compile(r"[A-Z][A-Za-z0-9_]*") + + +def _collect_type_referenced_names(store: GraphStore) -> set[str]: + """Collect class names that appear in function params or return types.""" + funcs = store.get_nodes_by_kind(kinds=["Function", "Test"]) + names: set[str] = set() + for f in funcs: + for text in (f.params, f.return_type): + if text: + names.update(_TYPE_IDENT_RE.findall(text)) + return names + + def find_dead_code( store: 
GraphStore, kind: Optional[str] = None, @@ -207,34 +271,297 @@ def find_dead_code( file_pattern=file_pattern, ) + # Build set of class names referenced in function type annotations. + type_ref_names = _collect_type_referenced_names(store) + + # Build class hierarchy: class_qualified_name -> [bare_base_names] + class_bases: dict[str, list[str]] = {} + conn = store._conn + for row in conn.execute( + "SELECT source_qualified, target_qualified FROM edges WHERE kind = 'INHERITS'" + ).fetchall(): + base = row[1].rsplit("::", 1)[-1] if "::" in row[1] else row[1] + class_bases.setdefault(row[0], []).append(base) + + # Build import graph: file_path -> set of file_paths it imports from. + # Used to filter bare-name caller matches to plausible callers. + importer_files: dict[str, set[str]] = {} + for row in conn.execute( + "SELECT file_path, target_qualified FROM edges WHERE kind = 'IMPORTS_FROM'" + ).fetchall(): + importer_files.setdefault(row[0], set()).add(row[1]) + + # Build set of globally unique names (only one non-test node with that name). + # For unique names, any bare-name CALLS edge is reliable — no ambiguity. + name_counts: dict[str, int] = {} + for row in conn.execute( + "SELECT name, COUNT(*) FROM nodes " + "WHERE kind IN ('Function', 'Class') AND is_test = 0 " + "GROUP BY name" + ).fetchall(): + name_counts[row[0]] = row[1] + + def _is_plausible_caller( + edge_file: str, node_file: str, node_name: str = "", + ) -> bool: + """A bare-name edge is plausible if it comes from the same file, + from a file that has an IMPORTS_FROM edge whose target matches + the node's file path, or the name is globally unique (no ambiguity).""" + if edge_file == node_file: + return True + # Unique names (only one definition) have no ambiguity -- accept all callers. 
+ if node_name and name_counts.get(node_name, 0) == 1: + return True + for imp_target in importer_files.get(edge_file, ()): + # Strip "::name" suffix — workspace-resolved imports may include it + imp_path = imp_target.split("::")[0] if "::" in imp_target else imp_target + # __init__.py represents its parent package directory + if imp_path.endswith("/__init__.py"): + imp_dir = imp_path[:-12] # strip "/__init__.py" + if node_file.startswith(imp_dir + "/"): + return True + if imp_path.startswith(node_file) or node_file.startswith(imp_path + "/"): + return True + # 2-hop: edge_file imports X, X re-exports from node_file (barrel files) + for imp2 in importer_files.get(imp_target, ()): + imp2_path = imp2.split("::")[0] if "::" in imp2 else imp2 + if imp2_path.endswith("/__init__.py"): + imp2_dir = imp2_path[:-12] + if node_file.startswith(imp2_dir + "/"): + return True + if imp2_path.startswith(node_file) or node_file.startswith(imp2_path + "/"): + return True + # Package-alias heuristic: monorepo imports like "@scope/pkg-name" + # contain the directory name of the target package. Check if the + # import target string contains a significant directory segment from + # the node's file path (e.g. "lambda-common" in both the import + # "@cova-utils/lambda-common" and the path "libraries/lambda-common/..."). + if not imp_target.startswith("/"): + # imp_target is a package specifier, not a file path + for seg in _path_segments(node_file): + if seg in imp_target: + return True + return False + dead: list[dict[str, Any]] = [] for node in candidates: - # Skip test nodes. - if node.is_test: + # Skip test nodes and anything defined in test files. + if node.is_test or _is_test_file(node.file_path): continue + # Skip ambient type declarations (.d.ts) — they describe external APIs. + if node.file_path.endswith(".d.ts"): + continue + + # Skip dunder methods -- invoked by runtime, never have explicit callers. 
+        if node.name.startswith("__") and node.name.endswith("__"):
+            continue
+
+        # Skip JS/TS/Java constructors -- invoked via `new ClassName()`, which
+        # creates a CALLS edge to the class, not to `constructor`.
+        if node.name == "constructor" and node.parent_name:
+            continue
+
+        # NOTE(review): unreachable -- the test-node/test-file guard at the top
+        # of this loop already `continue`d for these nodes; remove or hoist.
+        if node.is_test or _is_test_file(node.file_path):
+            if _MOCK_NAME_RE.search(node.name):
+                continue
+
         # Skip entry points (by name pattern or decorator, not just "uncalled").
         if _is_entry_point(node):
             continue
 
         # Check for callers (CALLS), test refs (TESTED_BY), importers (IMPORTS_FROM),
-        # and value references (REFERENCES — function-as-value in maps, arrays, etc.).
+        # and value references (REFERENCES -- function-as-value in maps, arrays, etc.).
+
+        # Skip classes referenced in type annotations (Pydantic schemas, etc.).
+        if node.kind == "Class" and node.name in type_ref_names:
+            continue
+
+        # Skip Angular/NestJS decorated classes -- they are framework-managed
+        # and instantiated by the DI container, not direct CALLS edges.
+        if node.kind == "Class" and _has_framework_decorator(node):
+            continue
+
+        # Skip classes (and their methods) inheriting from known framework bases.
+ _is_framework_class = False + _check_qn = node.qualified_name if node.kind == "Class" else ( + node.qualified_name.rsplit(".", 1)[0] if node.parent_name else None + ) + if _check_qn: + outgoing = store.get_edges_by_source(_check_qn) + base_names = { + e.target_qualified.rsplit("::", 1)[-1] + for e in outgoing if e.kind == "INHERITS" + } + if base_names & _FRAMEWORK_BASE_CLASSES: + _is_framework_class = True + if node.kind == "Class": + if _is_framework_class: + continue + # Fallback: CDK class name suffixes (no INHERITS edge for external bases) + if any(node.name.endswith(s) for s in _CDK_CLASS_SUFFIXES): + continue + if node.kind == "Function" and _is_framework_class: + continue + # Also skip methods whose parent class name matches CDK suffixes + # (fallback for external base classes without INHERITS edges). + if ( + node.kind == "Function" + and node.parent_name + and any(node.parent_name.endswith(s) for s in _CDK_CLASS_SUFFIXES) + ): + continue + + # Skip decorated functions/classes that are invoked implicitly rather + # than via explicit CALLS edges. + decorators = node.extra.get("decorators", ()) + if isinstance(decorators, (list, tuple)) and decorators: + if node.kind in ("Function", "Test"): + # @property -- invoked via attribute access + # @abstractmethod -- polymorphic dispatch, never called directly + # @classmethod/@staticmethod -- called via Class.method() + if any( + d in ("property", "abstractmethod", "classmethod", "staticmethod") + or d.endswith(".abstractmethod") + # Angular @HostListener -- method called by framework event system + or d.startswith("HostListener") + for d in decorators + ): + continue + if node.kind == "Class": + # @dataclass classes are instantiated as types, not via CALLS + if any("dataclass" in d for d in decorators): + continue + + # Skip methods that override an @abstractmethod in a base class -- + # they are called polymorphically via the base class reference. 
+ if node.kind == "Function" and node.parent_name: + parent_qn = node.qualified_name.rsplit(".", 1)[0] + parent_edges = store.get_edges_by_source(parent_qn) + base_class_names = [ + e.target_qualified for e in parent_edges if e.kind == "INHERITS" + ] + for base_name in base_class_names: + # Try fully-qualified base first, then bare name match + base_method_qn = f"{base_name}.{node.name}" + base_nodes = store.get_node(base_method_qn) + if base_nodes is None: + # Base class may be bare name -- search in same file + base_method_qn2 = ( + node.file_path + "::" + base_name + "." + node.name + ) + base_nodes = store.get_node(base_method_qn2) + if base_nodes is not None: + base_decos = base_nodes.extra.get("decorators", ()) + if isinstance(base_decos, (list, tuple)) and any( + "abstractmethod" in d for d in base_decos + ): + break + else: + base_name = None # no abstract override found + if base_name is not None: + continue + incoming = store.get_edges_by_target(node.qualified_name) + # Also check class-qualified edges (e.g. "ClassName::method") which + # lack the file-path prefix used in node.qualified_name. + if not any(e.kind == "CALLS" for e in incoming) and node.parent_name: + class_qn = f"{node.parent_name}::{node.name}" + incoming = incoming + store.get_edges_by_target(class_qn) + # Also check bare-name and partially-qualified edges. + # CALLS targets may be bare ("funcName"), class-qualified + # ("Class::method"), or workspace-qualified ("pkg/dir::funcName"). 
+ if not any(e.kind == "CALLS" for e in incoming): + bare = store.search_edges_by_target_name(node.name, kind="CALLS") + # Also search for partially-qualified targets ending with ::name + suffix_rows = conn.execute( + "SELECT * FROM edges WHERE kind = 'CALLS'" + " AND target_qualified LIKE ?", + (f"%::{node.name}",), + ).fetchall() + suffix_edges = [store._row_to_edge(r) for r in suffix_rows] + all_bare = bare + suffix_edges + all_bare = [ + e for e in all_bare + if _is_plausible_caller(e.file_path, node.file_path, node.name) + ] + incoming = incoming + all_bare + if not any(e.kind == "TESTED_BY" for e in incoming): + bare_tb = store.search_edges_by_target_name(node.name, kind="TESTED_BY") + bare_tb = [ + e for e in bare_tb + if _is_plausible_caller(e.file_path, node.file_path, node.name) + ] + incoming = incoming + bare_tb + # Check INHERITS -- classes with subclasses are not dead. + if node.kind == "Class" and not any(e.kind == "INHERITS" for e in incoming): + bare_inh = store.search_edges_by_target_name(node.name, kind="INHERITS") + incoming = incoming + bare_inh has_callers = any(e.kind == "CALLS" for e in incoming) has_test_refs = any(e.kind == "TESTED_BY" for e in incoming) has_importers = any(e.kind == "IMPORTS_FROM" for e in incoming) has_references = any(e.kind == "REFERENCES" for e in incoming) + has_subclasses = any(e.kind == "INHERITS" for e in incoming) - if not has_callers and not has_test_refs and not has_importers and not has_references: - dead.append({ - "name": _sanitize_name(node.name), - "qualified_name": _sanitize_name(node.qualified_name), - "kind": node.kind, - "file": node.file_path, - "line": node.line_start, - }) + # For classes with no direct references, check if any member has callers. + no_refs = not ( + has_callers or has_test_refs or has_importers + or has_references or has_subclasses + ) + if node.kind == "Class" and no_refs: + member_prefix = node.qualified_name + "." 
+ # Also check bare class-name pattern (unresolved CALLS targets) + bare_prefix = node.name + "." + member_calls = conn.execute( + "SELECT COUNT(*) FROM edges WHERE kind = 'CALLS'" + " AND (target_qualified LIKE ? OR target_qualified LIKE ?)", + (f"%{member_prefix}%", f"%{bare_prefix}%"), + ).fetchone()[0] + if member_calls > 0: + has_callers = True + + if not ( + has_callers or has_test_refs or has_importers + or has_references or has_subclasses + ): + # Check if this is a method override where the base class method + # has callers (polymorphic dispatch: callers of Base.method() + # implicitly call SubClass.method() at runtime). + if node.kind == "Function" and node.parent_name and not has_callers: + method_suffix = "." + node.name + if node.qualified_name.endswith(method_suffix): + class_qn = node.qualified_name[: -len(method_suffix)] + for base_name in class_bases.get(class_qn, []): + rows = conn.execute( + "SELECT n.qualified_name FROM nodes n " + "WHERE n.parent_name = ? AND n.name = ? " + "AND n.kind IN ('Function', 'Test')", + (base_name, node.name), + ).fetchall() + for (base_method_qn,) in rows: + if conn.execute( + "SELECT 1 FROM edges " + "WHERE target_qualified = ? AND kind = 'CALLS' " + "LIMIT 1", + (base_method_qn,), + ).fetchone(): + has_callers = True + break + if has_callers: + break + + if not has_callers: + dead.append({ + "name": _sanitize_name(node.name), + "qualified_name": _sanitize_name(node.qualified_name), + "kind": node.kind, + "file": node.file_path, + "line": node.line_start, + }) logger.info("find_dead_code: found %d dead symbols", len(dead)) return dead diff --git a/tests/test_changes.py b/tests/test_changes.py index 93562e28..2e4c803c 100644 --- a/tests/test_changes.py +++ b/tests/test_changes.py @@ -294,6 +294,42 @@ def test_risk_score_with_flow_membership(self): # helper should have flow participation bonus. 
assert helper_score >= isolated_score + def test_risk_score_weighted_by_flow_criticality(self): + """Nodes in high-criticality flows score higher than low-criticality.""" + # Build two separate flows with different criticality + self._add_func("hi_entry", path="hi.py", line_start=1, line_end=5) + self._add_func("hi_func", path="hi.py", line_start=10, line_end=20) + self._add_call("hi.py::hi_entry", "hi.py::hi_func") + + self._add_func("lo_entry", path="lo.py", line_start=1, line_end=5) + self._add_func("lo_func", path="lo.py", line_start=10, line_end=20) + self._add_call("lo.py::lo_entry", "lo.py::lo_func") + + flows = trace_flows(self.store) + store_flows(self.store, flows) + + # Manually set different criticality values + self.store._conn.execute( + "UPDATE flows SET criticality = 0.9 " + "WHERE name = 'hi_entry'" + ) + self.store._conn.execute( + "UPDATE flows SET criticality = 0.1 " + "WHERE name = 'lo_entry'" + ) + self.store.commit() + + hi = self.store.get_node("hi.py::hi_func") + lo = self.store.get_node("lo.py::lo_func") + assert hi and lo + + hi_score = compute_risk_score(self.store, hi) + lo_score = compute_risk_score(self.store, lo) + assert hi_score > lo_score, ( + f"High-criticality flow node ({hi_score}) should score " + f"higher than low-criticality ({lo_score})" + ) + # --------------------------------------------------------------- # analyze_changes # --------------------------------------------------------------- diff --git a/tests/test_communities.py b/tests/test_communities.py index bcb1e921..7d7b2dca 100644 --- a/tests/test_communities.py +++ b/tests/test_communities.py @@ -165,6 +165,50 @@ def test_architecture_overview(self): assert isinstance(overview["cross_community_edges"], list) assert isinstance(overview["warnings"], list) + def test_architecture_overview_excludes_tested_by_coupling(self): + """TESTED_BY edges do not count toward coupling warnings.""" + self._seed_two_clusters() + communities = detect_communities(self.store, 
min_size=2) + store_communities(self.store, communities) + + # Add many TESTED_BY cross-community edges (well above the threshold of 10) + for i in range(20): + self.store.upsert_edge(EdgeInfo( + kind="TESTED_BY", source=f"auth.py::login", + target=f"db.py::query", file_path="auth.py", line=i + 100, + )) + self.store.commit() + + overview = get_architecture_overview(self.store) + # Warnings should not include any that are purely from TESTED_BY edges + for w in overview["warnings"]: + assert "TESTED_BY" not in w + + def test_architecture_overview_excludes_test_community_warnings(self): + """Warnings involving test-dominated communities are filtered out.""" + self._seed_two_clusters() + communities = detect_communities(self.store, min_size=2) + store_communities(self.store, communities) + + # Manually insert a test-named community with high cross-coupling + conn = self.store._conn + cursor = conn.execute( + "INSERT INTO communities (name, level, cohesion, size, dominant_language, description)" + " VALUES (?, 0, 0.5, 10, 'typescript', 'Test community')", + ("handler-it:should",), + ) + test_comm_id = cursor.lastrowid + # Assign some nodes to this community (reuse existing node) + conn.execute( + "UPDATE nodes SET community_id = ? 
WHERE name = 'login'", + (test_comm_id,), + ) + conn.commit() + + overview = get_architecture_overview(self.store) + for w in overview["warnings"]: + assert "it:should" not in w, f"Test community should be filtered: {w}" + def test_fallback_file_communities(self): """File-based fallback produces communities grouped by file.""" self._seed_two_clusters() @@ -351,8 +395,8 @@ def mk_edge(eid: int, src: str, tgt: str, fp: str) -> GraphEdge: assert len(result) == 2 by_desc = {c["description"]: c for c in result} - auth = by_desc["File-based community: auth.py"] - db = by_desc["File-based community: db.py"] + auth = by_desc["Directory-based community: auth"] + db = by_desc["Directory-based community: db"] # Member sets — catches wrong member_qns being passed to batch helper assert set(auth["members"]) == { @@ -478,6 +522,54 @@ def test_igraph_available_is_bool(self): """IGRAPH_AVAILABLE is a boolean.""" assert isinstance(IGRAPH_AVAILABLE, bool) + def test_leiden_fallback_to_file_based(self): + """When Leiden produces 0 communities (all < min_size), fall back to file-based.""" + # Seed nodes with only CONTAINS edges (no CALLS/IMPORTS -- sparse graph) + self.store.upsert_node( + NodeInfo( + kind="File", name="a.py", file_path="a.py", + line_start=1, line_end=100, language="python", + ), file_hash="a1" + ) + self.store.upsert_node( + NodeInfo( + kind="Function", name="f1", file_path="a.py", + line_start=1, line_end=10, language="python", + parent_name=None, + ), file_hash="a1" + ) + self.store.upsert_node( + NodeInfo( + kind="Function", name="f2", file_path="a.py", + line_start=11, line_end=20, language="python", + parent_name=None, + ), file_hash="a1" + ) + self.store.upsert_node( + NodeInfo( + kind="Function", name="f3", file_path="a.py", + line_start=21, line_end=30, language="python", + parent_name=None, + ), file_hash="a1" + ) + self.store.upsert_edge( + EdgeInfo(kind="CONTAINS", source="a.py", target="a.py::f1", + file_path="a.py", line=1) + ) + self.store.upsert_edge( 
+ EdgeInfo(kind="CONTAINS", source="a.py", target="a.py::f2", + file_path="a.py", line=11) + ) + self.store.upsert_edge( + EdgeInfo(kind="CONTAINS", source="a.py", target="a.py::f3", + file_path="a.py", line=21) + ) + # With high min_size, Leiden may produce tiny clusters that get dropped. + # The fallback to file-based should still produce results. + result = detect_communities(self.store, min_size=2) + assert isinstance(result, list) + assert len(result) >= 1 + def test_incremental_detect_no_affected_communities(self): """incremental_detect_communities returns 0 when no communities are affected.""" self._seed_two_clusters() diff --git a/tests/test_flows.py b/tests/test_flows.py index 2600ebcc..34cfd05d 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -109,10 +109,99 @@ def test_detect_entry_points_name_pattern(self): assert "handle_request" in ep_names assert "regular_func" not in ep_names + # --------------------------------------------------------------- + # detect_entry_points -- expanded decorator patterns + # --------------------------------------------------------------- + + def test_detect_entry_points_pytest_fixture(self): + """pytest.fixture decorator marks function as entry point.""" + self._add_func("my_fixture", extra={"decorators": ["pytest.fixture"]}) + eps = detect_entry_points(self.store) + ep_names = {ep.name for ep in eps} + assert "my_fixture" in ep_names + + def test_detect_entry_points_django_receiver(self): + """Django signal receiver decorator marks function as entry point.""" + self._add_func("on_save", extra={"decorators": ["receiver(post_save)"]}) + eps = detect_entry_points(self.store) + ep_names = {ep.name for ep in eps} + assert "on_save" in ep_names + + def test_detect_entry_points_spring_scheduled(self): + """Java Spring @Scheduled marks function as entry point.""" + self._add_func("cleanup_job", extra={"decorators": ["Scheduled(cron='0 0 * * *')"]}) + eps = detect_entry_points(self.store) + ep_names = {ep.name for ep 
in eps} + assert "cleanup_job" in ep_names + + def test_detect_entry_points_celery_task(self): + """Bare @task decorator marks function as entry point.""" + self._add_func("process_data", extra={"decorators": ["task"]}) + eps = detect_entry_points(self.store) + ep_names = {ep.name for ep in eps} + assert "process_data" in ep_names + + def test_detect_entry_points_agent_tool(self): + """@agent.tool decorator marks function as entry point.""" + self._add_func("query_health", extra={"decorators": ["health_agent.tool"]}) + eps = detect_entry_points(self.store) + ep_names = {ep.name for ep in eps} + assert "query_health" in ep_names + + def test_detect_entry_points_alembic(self): + """upgrade/downgrade functions are entry points.""" + self._add_func("upgrade") + self._add_func("downgrade") + eps = detect_entry_points(self.store) + ep_names = {ep.name for ep in eps} + assert "upgrade" in ep_names + assert "downgrade" in ep_names + + def test_detect_entry_points_lifespan(self): + """FastAPI lifespan function is an entry point.""" + self._add_func("lifespan") + eps = detect_entry_points(self.store) + ep_names = {ep.name for ep in eps} + assert "lifespan" in ep_names + # --------------------------------------------------------------- # trace_flows # --------------------------------------------------------------- + def test_detect_entry_points_excludes_tests_by_default(self): + """Test nodes are excluded from entry points by default.""" + self._add_func("production_handler") + self._add_func("it:should do something", is_test=True) + self.store.commit() + + eps = detect_entry_points(self.store) + ep_names = {ep.name for ep in eps} + assert "production_handler" in ep_names + assert "it:should do something" not in ep_names + + # With include_tests=True, both appear + eps_all = detect_entry_points(self.store, include_tests=True) + ep_names_all = {ep.name for ep in eps_all} + assert "production_handler" in ep_names_all + assert "it:should do something" in ep_names_all + + def 
test_detect_entry_points_excludes_test_files(self): + """Functions in test files (*.spec.ts, *.test.ts) are excluded by default.""" + self._add_func("production_func", path="src/handler.ts") + self._add_func("describe_block", path="src/handler.spec.ts") + self._add_func("test_helper", path="tests/__tests__/utils.ts") + + eps = detect_entry_points(self.store) + ep_files = {ep.file_path for ep in eps} + assert "src/handler.ts" in ep_files + assert "src/handler.spec.ts" not in ep_files + assert "tests/__tests__/utils.ts" not in ep_files + + # With include_tests=True, they appear + eps_all = detect_entry_points(self.store, include_tests=True) + ep_files_all = {ep.file_path for ep in eps_all} + assert "src/handler.spec.ts" in ep_files_all + def test_trace_simple_flow(self): """BFS traces a linear call chain: A -> B -> C.""" self._add_func("entry") diff --git a/tests/test_refactor.py b/tests/test_refactor.py index 993aef4f..435a0aee 100644 --- a/tests/test_refactor.py +++ b/tests/test_refactor.py @@ -191,6 +191,232 @@ def test_find_dead_code_file_pattern(self): dead = find_dead_code(self.store, file_pattern="nonexistent") assert len(dead) == 0 + def test_find_dead_code_excludes_dunder(self): + """Dunder methods are not flagged as dead code.""" + self.store.upsert_node(NodeInfo( + kind="Function", name="__init__", file_path="/repo/app.py", + line_start=90, line_end=95, language="python", + parent_name="MyClass", + )) + self.store.commit() + dead = find_dead_code(self.store) + dead_names = {d["name"] for d in dead} + assert "__init__" not in dead_names + + def test_find_dead_code_excludes_constructor(self): + """JS/TS constructors are not flagged as dead code.""" + self.store.upsert_node(NodeInfo( + kind="Function", name="constructor", file_path="/repo/component.ts", + line_start=10, line_end=15, language="typescript", + parent_name="MyComponent", + )) + self.store.commit() + dead = find_dead_code(self.store) + dead_names = {d["name"] for d in dead} + assert "constructor" 
not in dead_names + + def test_find_dead_code_excludes_angular_lifecycle(self): + """Angular lifecycle hooks are not flagged as dead code.""" + for name in ("ngOnInit", "ngOnChanges", "ngOnDestroy", "transform", + "writeValue", "canActivate"): + self.store.upsert_node(NodeInfo( + kind="Function", name=name, file_path="/repo/component.ts", + line_start=10, line_end=15, language="typescript", + parent_name="MyComponent", + )) + self.store.commit() + dead = find_dead_code(self.store) + dead_names = {d["name"] for d in dead} + for name in ("ngOnInit", "ngOnChanges", "ngOnDestroy", "transform", + "writeValue", "canActivate"): + assert name not in dead_names, f"{name} should not be dead" + + def test_find_dead_code_excludes_decorated_entry(self): + """Functions with framework decorators are not flagged as dead code.""" + self.store.upsert_node(NodeInfo( + kind="Function", name="get_users", file_path="/repo/app.py", + line_start=90, line_end=95, language="python", + extra={"decorators": ["app.get('/users')"]}, + )) + self.store.commit() + dead = find_dead_code(self.store) + dead_names = {d["name"] for d in dead} + assert "get_users" not in dead_names + + def test_find_dead_code_excludes_type_referenced_class(self): + """Classes referenced in function type annotations are not dead code.""" + self.store.upsert_node(NodeInfo( + kind="Class", name="UserSchema", file_path="/repo/app.py", + line_start=5, line_end=15, language="python", + )) + # A function that uses UserSchema in its params + self.store.upsert_node(NodeInfo( + kind="Function", name="create_user", file_path="/repo/app.py", + line_start=20, line_end=30, language="python", + params="body: UserSchema", + )) + self.store.commit() + dead = find_dead_code(self.store) + dead_names = {d["name"] for d in dead} + assert "UserSchema" not in dead_names + + def test_find_dead_code_excludes_return_type_reference(self): + """Classes referenced in return types are not dead code.""" + self.store.upsert_node(NodeInfo( + 
kind="Class", name="UserResponse", file_path="/repo/app.py", + line_start=5, line_end=15, language="python", + )) + self.store.upsert_node(NodeInfo( + kind="Function", name="get_user", file_path="/repo/app.py", + line_start=20, line_end=30, language="python", + return_type="Optional[UserResponse]", + )) + self.store.commit() + dead = find_dead_code(self.store) + dead_names = {d["name"] for d in dead} + assert "UserResponse" not in dead_names + + def test_find_dead_code_excludes_orm_model(self): + """Classes inheriting from known ORM bases are not dead code.""" + self.store.upsert_node(NodeInfo( + kind="Class", name="User", file_path="/repo/app.py", + line_start=5, line_end=20, language="python", + )) + self.store.upsert_edge(EdgeInfo( + kind="INHERITS", source="/repo/app.py::User", + target="Base", file_path="/repo/app.py", line=5, + )) + self.store.commit() + dead = find_dead_code(self.store) + dead_names = {d["name"] for d in dead} + assert "User" not in dead_names + + def test_find_dead_code_excludes_pydantic_settings(self): + """Classes inheriting from BaseSettings are not dead code.""" + self.store.upsert_node(NodeInfo( + kind="Class", name="AppConfig", file_path="/repo/app.py", + line_start=5, line_end=15, language="python", + )) + self.store.upsert_edge(EdgeInfo( + kind="INHERITS", source="/repo/app.py::AppConfig", + target="BaseSettings", file_path="/repo/app.py", line=5, + )) + self.store.commit() + dead = find_dead_code(self.store) + dead_names = {d["name"] for d in dead} + assert "AppConfig" not in dead_names + + def test_find_dead_code_excludes_agent_tool(self): + """Functions with @agent.tool decorator are not dead code.""" + self.store.upsert_node(NodeInfo( + kind="Function", name="query_data", file_path="/repo/app.py", + line_start=10, line_end=20, language="python", + extra={"decorators": ["health_agent.tool"]}, + )) + self.store.commit() + dead = find_dead_code(self.store) + dead_names = {d["name"] for d in dead} + assert "query_data" not in 
dead_names + + def test_find_dead_code_excludes_alembic_upgrade(self): + """upgrade() and downgrade() in alembic files are not dead code.""" + self.store.upsert_node(NodeInfo( + kind="Function", name="upgrade", file_path="/repo/alembic/versions/001.py", + line_start=5, line_end=15, language="python", + )) + self.store.upsert_node(NodeInfo( + kind="Function", name="downgrade", file_path="/repo/alembic/versions/001.py", + line_start=20, line_end=30, language="python", + )) + self.store.commit() + dead = find_dead_code(self.store) + dead_names = {d["name"] for d in dead} + assert "upgrade" not in dead_names + assert "downgrade" not in dead_names + + def test_find_dead_code_excludes_subclassed_class(self): + """Classes with subclasses (INHERITS edges) are not dead code.""" + self.store.upsert_node(NodeInfo( + kind="Class", name="BaseConnector", file_path="/repo/connectors.py", + line_start=5, line_end=50, language="python", + )) + # A subclass inherits from BaseConnector (bare-name target) + self.store.upsert_edge(EdgeInfo( + kind="INHERITS", source="/repo/connectors.py::GarminConnector", + target="BaseConnector", file_path="/repo/connectors.py", line=60, + )) + self.store.commit() + dead = find_dead_code(self.store) + dead_names = {d["name"] for d in dead} + assert "BaseConnector" not in dead_names + + def test_find_dead_code_bare_name_not_tricked_by_unrelated_caller(self): + """Bare-name CALLS from unrelated files don't save a dead function + when there are multiple definitions with the same name.""" + # Two unrelated functions named "processor" in different files + self.store.upsert_node(NodeInfo( + kind="Function", name="processor", file_path="/repo/api/routes.py", + line_start=10, line_end=20, language="python", + )) + self.store.upsert_node(NodeInfo( + kind="Function", name="processor", file_path="/repo/worker/tasks.py", + line_start=10, line_end=20, language="python", + )) + # A bare CALLS edge from a third file that imports only routes.py + 
self.store.upsert_edge(EdgeInfo( + kind="IMPORTS_FROM", source="/repo/main.py", + target="/repo/api/routes.py", file_path="/repo/main.py", line=1, + )) + self.store.upsert_edge(EdgeInfo( + kind="CALLS", source="/repo/main.py::start", + target="processor", file_path="/repo/main.py", line=10, + )) + self.store.commit() + dead = find_dead_code(self.store) + dead_qnames = {d["qualified_name"] for d in dead} + # routes.py processor is saved (caller imports its file) + assert "/repo/api/routes.py::processor" not in dead_qnames + # worker/tasks.py processor is dead (no relationship with caller) + assert "/repo/worker/tasks.py::processor" in dead_qnames + + def test_find_dead_code_excludes_mock_variables(self): + """Mock/stub variables in test files are not flagged as dead code.""" + for name in ("mockDynamoClient", "s3ClientMock", "MockService", "createMockRequest"): + self.store.upsert_node(NodeInfo( + kind="Function", name=name, file_path="/repo/tests/handler.spec.ts", + line_start=10, line_end=15, language="typescript", + )) + self.store.commit() + dead = find_dead_code(self.store) + dead_names = {d["name"] for d in dead} + for name in ("mockDynamoClient", "s3ClientMock", "MockService", "createMockRequest"): + assert name not in dead_names, f"{name} should not be dead (mock pattern)" + + def test_find_dead_code_excludes_angular_decorated_class(self): + """Angular @Component classes are not flagged as dead code.""" + self.store.upsert_node(NodeInfo( + kind="Class", name="ClipboardButtonComponent", + file_path="/repo/src/app/clipboard.component.ts", + line_start=5, line_end=50, language="typescript", + extra={"decorators": ["Component({selector: 'app-clipboard'})"]}, + )) + self.store.commit() + dead = find_dead_code(self.store) + dead_names = {d["name"] for d in dead} + assert "ClipboardButtonComponent" not in dead_names + + def test_find_dead_code_excludes_property(self): + """Functions decorated with @property are not dead code.""" + self.store.upsert_node(NodeInfo( + 
kind="Function", name="db", file_path="/repo/deps.py", + line_start=10, line_end=15, language="python", + extra={"decorators": ["property"]}, + )) + self.store.commit() + dead = find_dead_code(self.store) + dead_names = {d["name"] for d in dead} + assert "db" not in dead_names + class TestSuggestRefactorings: """Tests for suggest_refactorings.""" @@ -553,3 +779,45 @@ def test_only_references_edge_sufficient(self): dead_names = {d["name"] for d in dead} # handleCreate has only a REFERENCES edge, no CALLS targeting it assert "handleCreate" not in dead_names + + +class TestTransitiveImportResolution: + """Tests for 2-hop transitive import resolution in plausible caller.""" + + def setup_method(self): + self.store = GraphStore(":memory:") + for f in ("/repo/consumer.ts", "/repo/lib/index.ts", "/repo/lib/utils.ts"): + self.store.upsert_node(NodeInfo( + kind="File", name=f, file_path=f, + line_start=1, line_end=50, language="typescript", + )) + + def test_transitive_import_via_barrel_file(self): + """consumer.ts imports index.ts which re-exports from utils.ts. 
+ A bare-name CALLS from consumer.ts should be plausible for utils.ts functions.""" + # Function defined in utils.ts + self.store.upsert_node(NodeInfo( + kind="Function", name="safeJsonParse", + file_path="/repo/lib/utils.ts", + line_start=10, line_end=20, language="typescript", + )) + # Import chain: consumer -> index -> utils + self.store.upsert_edge(EdgeInfo( + kind="IMPORTS_FROM", source="/repo/consumer.ts", + target="/repo/lib/index.ts", file_path="/repo/consumer.ts", line=1, + )) + self.store.upsert_edge(EdgeInfo( + kind="IMPORTS_FROM", source="/repo/lib/index.ts", + target="/repo/lib/utils.ts", file_path="/repo/lib/index.ts", line=1, + )) + # Bare-name CALLS from consumer + self.store.upsert_edge(EdgeInfo( + kind="CALLS", source="/repo/consumer.ts::processData", + target="safeJsonParse", file_path="/repo/consumer.ts", line=5, + )) + self.store.commit() + dead = find_dead_code(self.store) + dead_names = {d["name"] for d in dead} + assert "safeJsonParse" not in dead_names, ( + "2-hop import chain should make consumer a plausible caller" + )