From 128bf1113dbbe29c41ecc4f23058f5c20a346f2d Mon Sep 17 00:00:00 2001 From: Gagan Kalra Date: Sat, 4 Apr 2026 18:03:26 +0530 Subject: [PATCH] fix: CLI build/update/watch now run post-processing (signatures, FTS, flows, communities) Extract the 4-step post-processing pipeline from tools/build.py into a shared postprocessing.py module and wire it into all CLI entry points so that `build`, `update`, and `watch` produce the same complete graph as the MCP tool. - Add code_review_graph/postprocessing.py with run_post_processing() - tools/build.py now delegates to run_post_processing() (no duplication) - cli.py calls _cli_post_process() after build and update commands - watch() accepts on_files_updated callback, invoked after each flush - 18 new tests covering all steps, isolation, idempotency, and watch Closes #93 --- code_review_graph/cli.py | 210 +++++++++++------- code_review_graph/incremental.py | 98 ++++++--- code_review_graph/postprocessing.py | 134 ++++++++++++ code_review_graph/tools/build.py | 68 +++--- tests/test_postprocessing.py | 327 ++++++++++++++++++++++++++++ 5 files changed, 698 insertions(+), 139 deletions(-) create mode 100644 code_review_graph/postprocessing.py create mode 100644 tests/test_postprocessing.py diff --git a/code_review_graph/cli.py b/code_review_graph/cli.py index 70b0b7cd..f18aeea2 100644 --- a/code_review_graph/cli.py +++ b/code_review_graph/cli.py @@ -35,6 +35,10 @@ from importlib.metadata import PackageNotFoundError from importlib.metadata import version as pkg_version from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .graph import GraphStore logger = logging.getLogger(__name__) @@ -63,12 +67,12 @@ def _print_banner() -> None: version = _get_version() # ANSI escape codes - c = "\033[36m" if color else "" # cyan — graph art - y = "\033[33m" if color else "" # yellow — center node - b = "\033[1m" if color else "" # bold - d = "\033[2m" if color else "" # dim - g = "\033[32m" if color else "" # green — commands - r = "\033[0m" if color else "" # reset + c = "\033[36m" if color else "" # cyan — graph art + y = "\033[33m" if color else "" # yellow — center node + b = "\033[1m" if color else "" # bold + d = "\033[2m" if color else "" # dim + g = "\033[32m" if color else "" # green — commands + r = "\033[0m" if color else "" # reset print(f""" {c} ●──●──●{r} @@ -99,7 +103,8 @@ def _print_banner() -> None: def _instruction_files_to_modify( - repo_root: Path, target: str, + repo_root: Path, + target: str, ) -> list[str]: """Return the list of instruction files that ``install`` would write or modify, given the current state of the repo and the selected @@ -250,90 +255,131 @@ def _handle_init(args: argparse.Namespace) -> None: print(" 2. 
Restart your AI coding tool to pick up the new config") +def _cli_post_process(store: GraphStore) -> None: + """Run post-build pipeline and print a summary line for each step.""" + from .postprocessing import run_post_processing + + pp = run_post_processing(store) + if pp.get("signatures_computed"): + print(f"Signatures: {pp['signatures_computed']} nodes") + if pp.get("fts_indexed"): + print(f"FTS indexed: {pp['fts_indexed']} nodes") + if pp.get("flows_detected") is not None: + print(f"Flows: {pp['flows_detected']}") + if pp.get("communities_detected") is not None: + print(f"Communities: {pp['communities_detected']}") + + def main() -> None: """Main CLI entry point.""" ap = argparse.ArgumentParser( prog="code-review-graph", description="Persistent incremental knowledge graph for code reviews", ) - ap.add_argument( - "-v", "--version", action="store_true", help="Show version and exit" - ) + ap.add_argument("-v", "--version", action="store_true", help="Show version and exit") sub = ap.add_subparsers(dest="command") # install (primary) + init (alias) - install_cmd = sub.add_parser( - "install", help="Register MCP server with AI coding platforms" - ) + install_cmd = sub.add_parser("install", help="Register MCP server with AI coding platforms") install_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") install_cmd.add_argument( - "--dry-run", action="store_true", + "--dry-run", + action="store_true", help="Show what would be done without writing files", ) install_cmd.add_argument( - "--no-skills", action="store_true", + "--no-skills", + action="store_true", help="Skip generating Claude Code skill files", ) install_cmd.add_argument( - "--no-hooks", action="store_true", + "--no-hooks", + action="store_true", help="Skip installing Claude Code hooks", ) install_cmd.add_argument( - "--no-instructions", action="store_true", + "--no-instructions", + action="store_true", help="Skip injecting graph instructions into CLAUDE.md / AGENTS.md / etc.", ) install_cmd.add_argument( - "-y", "--yes", action="store_true", + "-y", + "--yes", + action="store_true", help="Auto-confirm instruction injection without an interactive prompt", ) # Legacy flags (kept for backwards compat, now no-ops since all is default) install_cmd.add_argument("--skills", action="store_true", help=argparse.SUPPRESS) install_cmd.add_argument("--hooks", action="store_true", help=argparse.SUPPRESS) - install_cmd.add_argument("--all", action="store_true", dest="install_all", - help=argparse.SUPPRESS) + install_cmd.add_argument( + "--all", action="store_true", dest="install_all", help=argparse.SUPPRESS + ) install_cmd.add_argument( "--platform", choices=[ - "codex", "claude", "claude-code", "cursor", "windsurf", "zed", - "continue", "opencode", "antigravity", "qwen", "kiro", "all", + "codex", + "claude", + "claude-code", + "cursor", + "windsurf", + "zed", + "continue", + "opencode", + "antigravity", + "qwen", + "kiro", + "all", ], default="all", help="Target platform for MCP config (default: all detected)", ) - init_cmd = sub.add_parser( - "init", help="Alias for install" - ) + init_cmd = sub.add_parser("init", help="Alias for install") init_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") init_cmd.add_argument( - "--dry-run", action="store_true", + "--dry-run", + action="store_true", help="Show what would be done without writing files", ) init_cmd.add_argument( - "--no-skills", action="store_true", + "--no-skills", + action="store_true", help="Skip generating Claude Code skill 
files", ) init_cmd.add_argument( - "--no-hooks", action="store_true", + "--no-hooks", + action="store_true", help="Skip installing Claude Code hooks", ) init_cmd.add_argument( - "--no-instructions", action="store_true", + "--no-instructions", + action="store_true", help="Skip injecting graph instructions into CLAUDE.md / AGENTS.md / etc.", ) init_cmd.add_argument( - "-y", "--yes", action="store_true", + "-y", + "--yes", + action="store_true", help="Auto-confirm instruction injection without an interactive prompt", ) init_cmd.add_argument("--skills", action="store_true", help=argparse.SUPPRESS) init_cmd.add_argument("--hooks", action="store_true", help=argparse.SUPPRESS) - init_cmd.add_argument("--all", action="store_true", dest="install_all", - help=argparse.SUPPRESS) + init_cmd.add_argument("--all", action="store_true", dest="install_all", help=argparse.SUPPRESS) init_cmd.add_argument( "--platform", choices=[ - "codex", "claude", "claude-code", "cursor", "windsurf", "zed", - "continue", "opencode", "antigravity", "qwen", "kiro", "all", + "codex", + "claude", + "claude-code", + "cursor", + "windsurf", + "zed", + "continue", + "opencode", + "antigravity", + "qwen", + "kiro", + "all", ], default="all", help="Target platform for MCP config (default: all detected)", @@ -343,11 +389,13 @@ def main() -> None: build_cmd = sub.add_parser("build", help="Full graph build (re-parse all files)") build_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") build_cmd.add_argument( - "--skip-flows", action="store_true", + "--skip-flows", + action="store_true", help="Skip flow/community detection (signatures + FTS only)", ) build_cmd.add_argument( - "--skip-postprocess", action="store_true", + "--skip-postprocess", + action="store_true", help="Skip all post-processing (raw parse only)", ) @@ -356,11 +404,13 @@ def main() -> None: update_cmd.add_argument("--base", default="HEAD~1", help="Git diff base (default: HEAD~1)") update_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") update_cmd.add_argument( - "--skip-flows", action="store_true", + "--skip-flows", + action="store_true", help="Skip flow/community detection (signatures + FTS only)", ) update_cmd.add_argument( - "--skip-postprocess", action="store_true", + "--skip-postprocess", + action="store_true", help="Skip all post-processing (raw parse only)", ) @@ -392,7 +442,8 @@ def main() -> None: help="Rendering mode: auto (default), full, community, or file", ) vis_cmd.add_argument( - "--serve", action="store_true", + "--serve", + action="store_true", help="Start a local HTTP server to view the visualization (localhost:8765)", ) vis_cmd.add_argument( @@ -406,7 +457,8 @@ def main() -> None: wiki_cmd = sub.add_parser("wiki", help="Generate markdown wiki from community structure") wiki_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") wiki_cmd.add_argument( - "--force", action="store_true", + "--force", + action="store_true", help="Regenerate all pages even if content unchanged", ) @@ -429,9 +481,10 @@ def main() -> None: # eval eval_cmd = sub.add_parser("eval", help="Run evaluation benchmarks") eval_cmd.add_argument( - "--benchmark", default=None, + "--benchmark", + default=None, help="Comma-separated benchmarks to run (token_efficiency, impact_accuracy, " - "flow_completeness, search_quality, build_performance)", + "flow_completeness, search_quality, build_performance)", ) eval_cmd.add_argument("--repo", default=None, help="Comma-separated repo config names") 
eval_cmd.add_argument("--all", action="store_true", dest="run_all", help="Run all benchmarks") @@ -440,12 +493,8 @@ def main() -> None: # detect-changes detect_cmd = sub.add_parser("detect-changes", help="Analyze change impact") - detect_cmd.add_argument( - "--base", default="HEAD~1", help="Git diff base (default: HEAD~1)" - ) - detect_cmd.add_argument( - "--brief", action="store_true", help="Show brief summary only" - ) + detect_cmd.add_argument("--base", default="HEAD~1", help="Git diff base (default: HEAD~1)") + detect_cmd.add_argument("--brief", action="store_true", help="Show brief summary only") detect_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") # serve @@ -464,6 +513,7 @@ def main() -> None: if args.command == "serve": from .main import main as serve_main + serve_main(repo_root=args.repo) return @@ -472,9 +522,7 @@ def main() -> None: from .eval.runner import run_eval if getattr(args, "report", False): - output_dir = Path( - getattr(args, "output_dir", None) or "evaluate/results" - ) + output_dir = Path(getattr(args, "output_dir", None) or "evaluate/results") report = generate_full_report(output_dir) report_path = Path("evaluate/reports/summary.md") report_path.parent.mkdir(parents=True, exist_ok=True) @@ -486,9 +534,7 @@ def main() -> None: print(tables) else: repos = ( - [r.strip() for r in args.repo.split(",")] - if getattr(args, "repo", None) - else None + [r.strip() for r in args.repo.split(",")] if getattr(args, "repo", None) else None ) benchmarks = ( [b.strip() for b in args.benchmark.split(",")] @@ -560,6 +606,7 @@ def main() -> None: store = GraphStore(db_path) try: from .tools.build import run_postprocess + result = run_postprocess( flows=not getattr(args, "no_flows", False), communities=not getattr(args, "no_communities", False), @@ -596,32 +643,39 @@ def main() -> None: try: if args.command == "build": - pp = "none" if getattr(args, "skip_postprocess", False) else ( - "minimal" if getattr(args, "skip_flows", False) else "full" + pp = ( + "none" + if getattr(args, "skip_postprocess", False) + else ("minimal" if getattr(args, "skip_flows", False) else "full") ) from .tools.build import build_or_update_graph + result = build_or_update_graph( - full_rebuild=True, repo_root=str(repo_root), postprocess=pp, + full_rebuild=True, + repo_root=str(repo_root), + postprocess=pp, ) parsed = result.get("files_parsed", 0) nodes = result.get("total_nodes", 0) edges = result.get("total_edges", 0) - print( - f"Full build: {parsed} files, " - f"{nodes} nodes, {edges} edges" - f" (postprocess={pp})" - ) + print(f"Full build: {parsed} files, {nodes} nodes, {edges} edges (postprocess={pp})") if result.get("errors"): print(f"Errors: {len(result['errors'])}") + _cli_post_process(store) elif args.command == "update": - pp = "none" if getattr(args, "skip_postprocess", False) else ( - "minimal" if getattr(args, "skip_flows", False) else "full" + pp = ( + "none" + if getattr(args, "skip_postprocess", False) + else ("minimal" if getattr(args, "skip_flows", False) else "full") ) from .tools.build import build_or_update_graph + result = build_or_update_graph( - full_rebuild=False, repo_root=str(repo_root), - base=args.base, postprocess=pp, + full_rebuild=False, + repo_root=str(repo_root), + base=args.base, + postprocess=pp, ) updated = result.get("files_updated", 0) nodes = result.get("total_nodes", 0) @@ -631,6 +685,8 @@ def main() -> None: f"{nodes} nodes, {edges} edges" f" (postprocess={pp})" ) + if result.get("files_updated", 0) > 0: + _cli_post_process(store) 
elif args.command == "status": stats = store.get_stats() @@ -647,6 +703,7 @@ def main() -> None: if stored_sha: print(f"Built at commit: {stored_sha[:12]}") from .incremental import _git_branch_info + current_branch, current_sha = _git_branch_info(repo_root) if stored_branch and current_branch and stored_branch != current_branch: print( @@ -656,35 +713,43 @@ def main() -> None: ) elif args.command == "watch": - watch(repo_root, store) + from .postprocessing import run_post_processing + + watch(repo_root, store, on_files_updated=run_post_processing) elif args.command == "visualize": from .incremental import get_data_dir + data_dir = get_data_dir(repo_root) fmt = getattr(args, "format", "html") or "html" if fmt == "graphml": from .exports import export_graphml + out = data_dir / "graph.graphml" export_graphml(store, out) print(f"GraphML exported: {out}") elif fmt == "cypher": from .exports import export_neo4j_cypher + out = data_dir / "graph.cypher" export_neo4j_cypher(store, out) print(f"Neo4j Cypher exported: {out}") elif fmt == "obsidian": from .exports import export_obsidian_vault + out = data_dir / "obsidian" export_obsidian_vault(store, out) print(f"Obsidian vault exported: {out}") elif fmt == "svg": from .exports import export_svg + out = data_dir / "graph.svg" export_svg(store, out) print(f"SVG exported: {out}") else: from .visualization import generate_html + html_path = data_dir / "graph.html" vis_mode = getattr(args, "mode", "auto") or "auto" generate_html(store, html_path, mode=vis_mode) @@ -692,20 +757,16 @@ def main() -> None: if getattr(args, "serve", False): import functools import http.server + serve_dir = html_path.parent port = 8765 handler = functools.partial( http.server.SimpleHTTPRequestHandler, directory=str(serve_dir), ) - print( - f"Serving at http://localhost:{port}" - f"/graph.html" - ) + print(f"Serving at http://localhost:{port}/graph.html") print("Press Ctrl+C to stop.") - with http.server.HTTPServer( - ("localhost", port), handler - ) as httpd: + with http.server.HTTPServer(("localhost", port), handler) as httpd: try: httpd.serve_forever() except KeyboardInterrupt: @@ -716,6 +777,7 @@ def main() -> None: elif args.command == "wiki": from .incremental import get_data_dir from .wiki import generate_wiki + wiki_dir = get_data_dir(repo_root) / "wiki" result = generate_wiki(store, wiki_dir, force=args.force) total = result["pages_generated"] + result["pages_updated"] + result["pages_unchanged"] diff --git a/code_review_graph/incremental.py b/code_review_graph/incremental.py index cfa672c0..e20fdb74 100644 --- a/code_review_graph/incremental.py +++ b/code_review_graph/incremental.py @@ -15,14 +15,12 @@ import subprocess import time from pathlib import Path, PurePosixPath -from typing import Optional +from typing import Callable, Optional from .graph import GraphStore from .parser import CodeParser -_MAX_PARSE_WORKERS = int(os.environ.get( - "CRG_PARSE_WORKERS", str(min(os.cpu_count() or 4, 8)) -)) +_MAX_PARSE_WORKERS = int(os.environ.get("CRG_PARSE_WORKERS", str(min(os.cpu_count() or 4, 8)))) logger = logging.getLogger(__name__) @@ -247,9 +245,7 @@ def _is_binary(path: Path) -> bool: # When True, `git ls-files --recurse-submodules` is used so that files # inside git submodules are included in the graph. Opt-in via env var; # can also be overridden per-call through function parameters. 
-_RECURSE_SUBMODULES = os.environ.get( - "CRG_RECURSE_SUBMODULES", "" -).lower() in ("1", "true", "yes") +_RECURSE_SUBMODULES = os.environ.get("CRG_RECURSE_SUBMODULES", "").lower() in ("1", "true", "yes") def _git_branch_info(repo_root: Path) -> tuple[str, str]: @@ -259,8 +255,10 @@ def _git_branch_info(repo_root: Path) -> tuple[str, str]: try: result = subprocess.run( ["git", "rev-parse", "--abbrev-ref", "HEAD"], - capture_output=True, text=True, - cwd=str(repo_root), timeout=_GIT_TIMEOUT, + capture_output=True, + text=True, + cwd=str(repo_root), + timeout=_GIT_TIMEOUT, ) if result.returncode == 0: branch = result.stdout.strip() @@ -269,8 +267,10 @@ def _git_branch_info(repo_root: Path) -> tuple[str, str]: try: result = subprocess.run( ["git", "rev-parse", "HEAD"], - capture_output=True, text=True, - cwd=str(repo_root), timeout=_GIT_TIMEOUT, + capture_output=True, + text=True, + cwd=str(repo_root), + timeout=_GIT_TIMEOUT, ) if result.returncode == 0: sha = result.stdout.strip() @@ -278,6 +278,7 @@ def _git_branch_info(repo_root: Path) -> tuple[str, str]: pass return branch, sha + _SAFE_GIT_REF = re.compile(r"^[A-Za-z0-9_.~^/@{}\-]+$") @@ -386,11 +387,7 @@ def collect_all_files( candidates = tracked else: # Fallback: walk directory - candidates = [ - str(p.relative_to(repo_root)) - for p in repo_root.rglob("*") - if p.is_file() - ] + candidates = [str(p.relative_to(repo_root)) for p in repo_root.rglob("*") if p.is_file()] for rel_path in candidates: if _should_ignore(rel_path, ignore_patterns): @@ -458,7 +455,8 @@ def find_dependents( if len(all_dependents) > _MAX_DEPENDENT_FILES: logger.warning( "Dependent expansion capped at %d files for %s", - len(all_dependents), file_path, + len(all_dependents), + file_path, ) # Truncate to the cap return list(all_dependents)[:_MAX_DEPENDENT_FILES] @@ -545,7 +543,8 @@ def full_build( max_workers=_MAX_PARSE_WORKERS, ) as executor: for i, (rel_path, nodes, edges, error, fhash) in enumerate( - executor.map(_parse_single_file, args_list, chunksize=20), 1, + executor.map(_parse_single_file, args_list, chunksize=20), + 1, ): if error: logger.warning("Error parsing %s: %s", rel_path, error) @@ -553,7 +552,10 @@ def full_build( continue full_path = repo_root / rel_path store.store_file_nodes_edges( - str(full_path), nodes, edges, fhash, + str(full_path), + nodes, + edges, + fhash, ) total_nodes += len(nodes) total_edges += len(edges) @@ -671,14 +673,19 @@ def incremental_update( max_workers=_MAX_PARSE_WORKERS, ) as executor: for rel_path, nodes, edges, error, fhash in executor.map( - _parse_single_file, args_list, chunksize=20, + _parse_single_file, + args_list, + chunksize=20, ): if error: logger.warning("Error parsing %s: %s", rel_path, error) errors.append({"file": rel_path, "error": error}) continue store.store_file_nodes_edges( - str(repo_root / rel_path), nodes, edges, fhash, + str(repo_root / rel_path), + nodes, + edges, + fhash, ) total_nodes += len(nodes) total_edges += len(edges) @@ -710,10 +717,22 @@ def incremental_update( _DEBOUNCE_SECONDS = 0.3 -def watch(repo_root: Path, store: GraphStore) -> None: +def watch( + repo_root: Path, + store: GraphStore, + on_files_updated: Optional[Callable] = None, +) -> None: """Watch for file changes and auto-update the graph. Uses a 300ms debounce to batch rapid-fire saves into a single update. + + Args: + repo_root: Repository root to watch. + store: Graph database to update. + on_files_updated: Optional callback invoked after each debounced + batch of file updates completes. 
Receives the store as its + only argument. Used by the CLI to run post-processing + (FTS, flows, communities) after watch updates. """ import threading @@ -777,9 +796,7 @@ def _schedule(self, abs_path: str): self._pending.add(abs_path) if self._timer is not None: self._timer.cancel() - self._timer = threading.Timer( - _DEBOUNCE_SECONDS, self._flush - ) + self._timer = threading.Timer(_DEBOUNCE_SECONDS, self._flush) self._timer.start() def _flush(self): @@ -789,33 +806,43 @@ def _flush(self): self._pending.clear() self._timer = None + updated = 0 for abs_path in paths: - self._update_file(abs_path) + if self._update_file(abs_path): + updated += 1 + + if updated > 0 and on_files_updated is not None: + try: + on_files_updated(store) + except Exception as e: + logger.error("Post-update callback failed: %s", e) - def _update_file(self, abs_path: str): + def _update_file(self, abs_path: str) -> bool: path = Path(abs_path) if not path.is_file(): - return + return False if path.is_symlink(): - return + return False if _is_binary(path): - return + return False try: source = path.read_bytes() fhash = hashlib.sha256(source).hexdigest() nodes, edges = parser.parse_bytes(path, source) store.store_file_nodes_edges(abs_path, nodes, edges, fhash) - store.set_metadata( - "last_updated", time.strftime("%Y-%m-%dT%H:%M:%S") - ) + store.set_metadata("last_updated", time.strftime("%Y-%m-%dT%H:%M:%S")) store.commit() rel = str(path.relative_to(repo_root)) logger.info( "Updated: %s (%d nodes, %d edges)", - rel, len(nodes), len(edges), + rel, + len(nodes), + len(edges), ) + return True except Exception as e: logger.error("Error updating %s: %s", abs_path, e) + return False handler = GraphUpdateHandler() observer = Observer() @@ -825,11 +852,10 @@ def _update_file(self, abs_path: str): logger.info("Watching %s for changes... (Ctrl+C to stop)", repo_root) try: import time as _time + while True: _time.sleep(1) except KeyboardInterrupt: observer.stop() observer.join() logger.info("Watch stopped.") - - diff --git a/code_review_graph/postprocessing.py b/code_review_graph/postprocessing.py new file mode 100644 index 00000000..c7dec597 --- /dev/null +++ b/code_review_graph/postprocessing.py @@ -0,0 +1,134 @@ +"""Shared post-build processing pipeline. + +After the core Tree-sitter parse (full_build or incremental_update), four +post-processing steps must run to populate derived tables: + +1. Compute node signatures +2. Rebuild FTS5 search index +3. Trace execution flows +4. Detect code communities + +This module extracts that pipeline so every entry point — MCP tool, CLI +commands, and watch mode — produces identical results. +""" + +from __future__ import annotations + +import logging +import sqlite3 +from typing import Any + +from .graph import GraphStore + +logger = logging.getLogger(__name__) + + +def run_post_processing(store: GraphStore) -> dict[str, Any]: + """Run all post-build steps on a populated graph. + + Each step is non-fatal: failures are logged and collected as warnings + so the primary build result is never lost. + + Args: + store: An open GraphStore with nodes and edges already populated. + + Returns: + Dict with keys for each step's result count and a ``warnings`` + list (only present when at least one step failed). 
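+
+    Example (illustrative; assumes a previously built graph database)::
+
+        store = GraphStore(".code-review-graph/graph.db")
+        stats = run_post_processing(store)
+        print(stats.get("fts_indexed", 0), "nodes searchable")
+        store.close()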
+ """ + result: dict[str, Any] = {} + warnings: list[str] = [] + + _compute_signatures(store, result, warnings) + _rebuild_fts_index(store, result, warnings) + _trace_flows(store, result, warnings) + _detect_communities(store, result, warnings) + + if warnings: + result["warnings"] = warnings + return result + + +# -- Individual steps (private) ------------------------------------------ + + +def _compute_signatures( + store: GraphStore, + result: dict[str, Any], + warnings: list[str], +) -> None: + """Compute human-readable signatures for nodes that lack one.""" + try: + rows = store.get_nodes_without_signature() + for row in rows: + node_id, name, kind, params, ret = ( + row[0], + row[1], + row[2], + row[3], + row[4], + ) + if kind in ("Function", "Test"): + sig = f"def {name}({params or ''})" + if ret: + sig += f" -> {ret}" + elif kind == "Class": + sig = f"class {name}" + else: + sig = name + store.update_node_signature(node_id, sig[:512]) + store.commit() + result["signatures_computed"] = len(rows) + except (sqlite3.OperationalError, TypeError, KeyError) as e: + logger.warning("Signature computation failed: %s", e) + warnings.append(f"Signature computation failed: {type(e).__name__}: {e}") + + +def _rebuild_fts_index( + store: GraphStore, + result: dict[str, Any], + warnings: list[str], +) -> None: + """Rebuild the FTS5 full-text search index.""" + try: + from .search import rebuild_fts_index + + fts_count = rebuild_fts_index(store) + result["fts_indexed"] = fts_count + except (sqlite3.OperationalError, ImportError) as e: + logger.warning("FTS index rebuild failed: %s", e) + warnings.append(f"FTS index rebuild failed: {type(e).__name__}: {e}") + + +def _trace_flows( + store: GraphStore, + result: dict[str, Any], + warnings: list[str], +) -> None: + """Trace execution flows from entry points.""" + try: + from .flows import store_flows, trace_flows + + flows = trace_flows(store) + count = store_flows(store, flows) + result["flows_detected"] = count + except (sqlite3.OperationalError, ImportError) as e: + logger.warning("Flow detection failed: %s", e) + warnings.append(f"Flow detection failed: {type(e).__name__}: {e}") + + +def _detect_communities( + store: GraphStore, + result: dict[str, Any], + warnings: list[str], +) -> None: + """Detect code communities via Leiden algorithm or file grouping.""" + try: + from .communities import detect_communities, store_communities + + comms = detect_communities(store) + count = store_communities(store, comms) + result["communities_detected"] = count + except (sqlite3.OperationalError, ImportError) as e: + logger.warning("Community detection failed: %s", e) + warnings.append(f"Community detection failed: {type(e).__name__}: {e}") diff --git a/code_review_graph/tools/build.py b/code_review_graph/tools/build.py index ddfc210c..d8bfe9b1 100644 --- a/code_review_graph/tools/build.py +++ b/code_review_graph/tools/build.py @@ -38,7 +38,11 @@ def _run_postprocess( rows = store.get_nodes_without_signature() for row in rows: node_id, name, kind, params, ret = ( - row[0], row[1], row[2], row[3], row[4], + row[0], + row[1], + row[2], + row[3], + row[4], ) if kind in ("Function", "Test"): sig = f"def {name}({params or ''})" @@ -118,7 +122,8 @@ def _run_postprocess( warnings.append(f"Summary computation failed: {type(e).__name__}: {e}") store.set_metadata( - "last_postprocessed_at", time.strftime("%Y-%m-%dT%H:%M:%S"), + "last_postprocessed_at", + time.strftime("%Y-%m-%dT%H:%M:%S"), ) store.set_metadata("postprocess_level", postprocess) @@ -156,13 +161,11 @@ def 
_compute_summaries(store: Any) -> None: # thousands of communities was the second-biggest hang. edge_counts: dict[str, int] = defaultdict(int) for row in conn.execute( - "SELECT source_qualified, COUNT(*) FROM edges " - "GROUP BY source_qualified" + "SELECT source_qualified, COUNT(*) FROM edges GROUP BY source_qualified" ): edge_counts[row[0]] += row[1] for row in conn.execute( - "SELECT target_qualified, COUNT(*) FROM edges " - "GROUP BY target_qualified" + "SELECT target_qualified, COUNT(*) FROM edges GROUP BY target_qualified" ): edge_counts[row[0]] += row[1] @@ -180,8 +183,7 @@ def _compute_summaries(store: Any) -> None: files_by_comm: dict[int, list[str]] = defaultdict(list) seen_files: dict[int, set[str]] = defaultdict(set) for row in conn.execute( - "SELECT community_id, file_path FROM nodes " - "WHERE community_id IS NOT NULL" + "SELECT community_id, file_path FROM nodes WHERE community_id IS NOT NULL" ): cid, fp = row[0], row[1] if fp not in seen_files[cid]: @@ -209,10 +211,7 @@ def _compute_summaries(store: Any) -> None: if paths: prefix = commonprefix(paths) if "/" in prefix: - purpose = ( - prefix.rsplit("/", 1)[0].split("/")[-1] - if "/" in prefix else "" - ) + purpose = prefix.rsplit("/", 1)[0].split("/")[-1] if "/" in prefix else "" conn.execute( "INSERT OR REPLACE INTO community_summaries " @@ -255,11 +254,10 @@ def _compute_summaries(store: Any) -> None: # GraphStore.get_edges_among. id_list = list(needed_ids) for i in range(0, len(id_list), 450): - batch = id_list[i:i + 450] + batch = id_list[i : i + 450] placeholders = ",".join("?" for _ in batch) node_rows = conn.execute( - "SELECT id, qualified_name FROM nodes " - f"WHERE id IN ({placeholders})", # nosec B608 + f"SELECT id, qualified_name FROM nodes WHERE id IN ({placeholders})", # nosec B608 batch, ).fetchall() for nr in node_rows: @@ -285,8 +283,7 @@ def _compute_summaries(store: Any) -> None: "INSERT OR REPLACE INTO flow_snapshots " "(flow_id, name, entry_point, critical_path, criticality, " "node_count, file_count) VALUES (?, ?, ?, ?, ?, ?, ?)", - (fid, fname, ep_name, _json.dumps(critical_path), - crit, ncount, fcount), + (fid, fname, ep_name, _json.dumps(critical_path), crit, ncount, fcount), ) conn.commit() except sqlite3.OperationalError: @@ -317,12 +314,20 @@ def _compute_summaries(store: Any) -> None: tested_counts[row[0]] = row[1] risk_nodes = conn.execute( - "SELECT id, qualified_name, name FROM nodes " - "WHERE kind IN ('Function', 'Class', 'Test')" + "SELECT id, qualified_name, name FROM nodes WHERE kind IN ('Function', 'Class', 'Test')" ).fetchall() security_kw = { - "auth", "login", "password", "token", "session", "crypt", - "secret", "credential", "permission", "sql", "execute", + "auth", + "login", + "password", + "token", + "session", + "crypt", + "secret", + "credential", + "permission", + "sql", + "execute", } for n in risk_nodes: nid, qn, name = n[0], n[1], n[2] @@ -330,9 +335,7 @@ def _compute_summaries(store: Any) -> None: tested = tested_counts.get(qn, 0) coverage = "tested" if tested > 0 else "untested" name_lower = name.lower() - sec_relevant = ( - 1 if any(kw in name_lower for kw in security_kw) else 0 - ) + sec_relevant = 1 if any(kw in name_lower for kw in security_kw) else 0 risk = 0.0 if caller_count > 10: risk += 0.3 @@ -421,8 +424,11 @@ def build_or_update_graph( # Pass changed_files for incremental flow/community detection changed = result.get("changed_files") if not full_rebuild else None warnings = _run_postprocess( - store, build_result, postprocess, - full_rebuild=full_rebuild, 
changed_files=changed, + store, + build_result, + postprocess, + full_rebuild=full_rebuild, + changed_files=changed, ) if warnings: build_result["warnings"] = warnings @@ -457,12 +463,15 @@ def run_postprocess( warnings: list[str] = [] try: - # Signatures are always fast — run them try: rows = store.get_nodes_without_signature() for row in rows: node_id, name, kind, params, ret = ( - row[0], row[1], row[2], row[3], row[4], + row[0], + row[1], + row[2], + row[3], + row[4], ) if kind in ("Function", "Test"): sig = f"def {name}({params or ''})" @@ -518,7 +527,8 @@ def run_postprocess( warnings.append(f"Community detection failed: {type(e).__name__}: {e}") store.set_metadata( - "last_postprocessed_at", time.strftime("%Y-%m-%dT%H:%M:%S"), + "last_postprocessed_at", + time.strftime("%Y-%m-%dT%H:%M:%S"), ) result["summary"] = "Post-processing complete." if warnings: diff --git a/tests/test_postprocessing.py b/tests/test_postprocessing.py new file mode 100644 index 00000000..f9b0f946 --- /dev/null +++ b/tests/test_postprocessing.py @@ -0,0 +1,327 @@ +"""Tests for the shared post-processing pipeline.""" + +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +from code_review_graph.graph import GraphStore +from code_review_graph.incremental import full_build +from code_review_graph.parser import EdgeInfo, NodeInfo +from code_review_graph.postprocessing import run_post_processing + + +def _get_signature(store, qualified_name): + row = store._conn.execute( + "SELECT signature FROM nodes WHERE qualified_name = ?", + (qualified_name,), + ).fetchone() + return row["signature"] if row else None + + +class TestRunPostProcessing: + def setup_method(self): + self.tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False) + self.store = GraphStore(self.tmp.name) + self._seed_data() + + def teardown_method(self): + self.store.close() + Path(self.tmp.name).unlink(missing_ok=True) + + def _seed_data(self): + self.store.upsert_node( + NodeInfo( + kind="File", + name="/repo/app.py", + file_path="/repo/app.py", + line_start=1, + line_end=50, + language="python", + ) + ) + self.store.upsert_node( + NodeInfo( + kind="Class", + name="Service", + file_path="/repo/app.py", + line_start=5, + line_end=40, + language="python", + ) + ) + self.store.upsert_node( + NodeInfo( + kind="Function", + name="handle", + file_path="/repo/app.py", + line_start=10, + line_end=20, + language="python", + parent_name="Service", + params="request", + return_type="Response", + ) + ) + self.store.upsert_node( + NodeInfo( + kind="Function", + name="process", + file_path="/repo/app.py", + line_start=25, + line_end=35, + language="python", + ) + ) + self.store.upsert_node( + NodeInfo( + kind="Test", + name="test_handle", + file_path="/repo/test_app.py", + line_start=1, + line_end=10, + language="python", + is_test=True, + ) + ) + + self.store.upsert_edge( + EdgeInfo( + kind="CONTAINS", + source="/repo/app.py", + target="/repo/app.py::Service", + file_path="/repo/app.py", + ) + ) + self.store.upsert_edge( + EdgeInfo( + kind="CONTAINS", + source="/repo/app.py::Service", + target="/repo/app.py::Service.handle", + file_path="/repo/app.py", + ) + ) + self.store.upsert_edge( + EdgeInfo( + kind="CALLS", + source="/repo/app.py::Service.handle", + target="/repo/app.py::process", + file_path="/repo/app.py", + line=15, + ) + ) + self.store.commit() + + def test_computes_signatures(self): + unsigned = self.store.get_nodes_without_signature() + assert len(unsigned) > 0 + + result = run_post_processing(self.store) + + 
assert result["signatures_computed"] > 0 + remaining = self.store.get_nodes_without_signature() + assert len(remaining) == 0 + + def test_function_signature_format(self): + run_post_processing(self.store) + + sig = _get_signature(self.store, "/repo/app.py::Service.handle") + assert sig == "def handle(request) -> Response" + + def test_class_signature_format(self): + run_post_processing(self.store) + + sig = _get_signature(self.store, "/repo/app.py::Service") + assert sig == "class Service" + + def test_test_signature_format(self): + run_post_processing(self.store) + + sig = _get_signature(self.store, "/repo/test_app.py::test_handle") + assert sig is not None + assert sig.startswith("def test_handle(") + + def test_rebuilds_fts_index(self): + result = run_post_processing(self.store) + + assert "fts_indexed" in result + assert result["fts_indexed"] > 0 + + def test_fts_search_works_after_post_processing(self): + run_post_processing(self.store) + + from code_review_graph.search import hybrid_search + + hits = hybrid_search(self.store, "handle") + names = {h["name"] for h in hits} + assert "handle" in names + + def test_detects_flows(self): + result = run_post_processing(self.store) + + assert "flows_detected" in result + assert result["flows_detected"] >= 0 + + def test_detects_communities(self): + result = run_post_processing(self.store) + + assert "communities_detected" in result + assert result["communities_detected"] >= 0 + + def test_no_warnings_on_healthy_store(self): + result = run_post_processing(self.store) + + assert "warnings" not in result + + def test_empty_store_no_crash(self): + empty_tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False) + empty_store = GraphStore(empty_tmp.name) + try: + result = run_post_processing(empty_store) + assert result["signatures_computed"] == 0 + assert result["fts_indexed"] == 0 + finally: + empty_store.close() + Path(empty_tmp.name).unlink(missing_ok=True) + + def test_idempotent(self): + first = run_post_processing(self.store) + second = run_post_processing(self.store) + + assert second["fts_indexed"] == first["fts_indexed"] + assert second["signatures_computed"] == 0 + + def test_signature_truncated_at_512(self): + self.store.upsert_node( + NodeInfo( + kind="Function", + name="f", + file_path="/repo/big.py", + line_start=1, + line_end=2, + language="python", + params="a" * 600, + ) + ) + self.store.commit() + + run_post_processing(self.store) + sig = _get_signature(self.store, "/repo/big.py::f") + assert sig is not None + assert len(sig) <= 512 + + +class TestPostProcessingStepIsolation: + def setup_method(self): + self.tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False) + self.store = GraphStore(self.tmp.name) + self.store.upsert_node( + NodeInfo( + kind="Function", + name="fn", + file_path="/repo/a.py", + line_start=1, + line_end=5, + language="python", + ) + ) + self.store.commit() + + def teardown_method(self): + self.store.close() + Path(self.tmp.name).unlink(missing_ok=True) + + def test_fts_failure_does_not_block_flows(self): + with patch( + "code_review_graph.search.rebuild_fts_index", + side_effect=ImportError("fts boom"), + ): + result = run_post_processing(self.store) + + assert "flows_detected" in result + assert "communities_detected" in result + assert "warnings" in result + assert any("FTS" in w for w in result["warnings"]) + + def test_flow_failure_does_not_block_communities(self): + with patch( + "code_review_graph.flows.trace_flows", + side_effect=ImportError("flow boom"), + ): + result = 
run_post_processing(self.store) + + assert "communities_detected" in result + assert "warnings" in result + assert any("Flow" in w for w in result["warnings"]) + + def test_community_failure_still_has_signatures(self): + with patch( + "code_review_graph.communities.detect_communities", + side_effect=ImportError("comm boom"), + ): + result = run_post_processing(self.store) + + assert result["signatures_computed"] > 0 + assert "warnings" in result + assert any("Community" in w for w in result["warnings"]) + + +class TestToolBuildUsesSharedPipeline: + def test_build_tool_runs_post_processing(self, tmp_path): + py_file = tmp_path / "sample.py" + py_file.write_text("def hello():\n pass\n") + (tmp_path / ".git").mkdir() + (tmp_path / ".code-review-graph").mkdir() + + db_path = tmp_path / ".code-review-graph" / "graph.db" + store = GraphStore(db_path) + try: + mock_target = "code_review_graph.incremental.get_all_tracked_files" + with patch(mock_target, return_value=["sample.py"]): + full_build(tmp_path, store) + + unsigned_before_pp = store.get_nodes_without_signature() + run_post_processing(store) + unsigned_after_pp = store.get_nodes_without_signature() + + assert len(unsigned_before_pp) > 0 + assert len(unsigned_after_pp) == 0 + finally: + store.close() + + +class TestWatchCallbackIntegration: + def test_watch_accepts_callback_parameter(self): + import inspect + + from code_review_graph.incremental import watch + + sig = inspect.signature(watch) + assert "on_files_updated" in sig.parameters + + def test_watch_callback_not_called_without_updates(self, tmp_path): + import threading + + from code_review_graph.incremental import watch + + (tmp_path / ".git").mkdir() + db_path = tmp_path / "test.db" + store = GraphStore(db_path) + callback = MagicMock() + + try: + + def run_watch(): + try: + watch(tmp_path, store, on_files_updated=callback) + except KeyboardInterrupt: + pass + + t = threading.Thread(target=run_watch, daemon=True) + t.start() + + import time + + time.sleep(0.5) + callback.assert_not_called() + finally: + store.close()
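
A minimal sketch of wiring the watch callback by hand (illustrative; assumes
the package is installed and a graph database already exists at the default
path the CLI uses):

    from pathlib import Path

    from code_review_graph.graph import GraphStore
    from code_review_graph.incremental import watch
    from code_review_graph.postprocessing import run_post_processing

    repo_root = Path(".")
    store = GraphStore(repo_root / ".code-review-graph" / "graph.db")
    # Same hookup the CLI's `watch` command now performs: after each
    # debounced batch of file updates, the shared pipeline re-runs.
    # Blocks until Ctrl+C, mirroring `code-review-graph watch`.
    watch(repo_root, store, on_files_updated=run_post_processing)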