diff --git a/CHANGELOG.md b/CHANGELOG.md index 35a6f8e2..5ef03049 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,46 +2,6 @@ ## [Unreleased] -## [2.3.2] - 2026-04-14 - -Major feature release — 15 new capabilities, 6 community PRs merged, 6 new MCP tools, 4 new languages, multi-format export, and graph analysis suite. - -### Added - -- **Hub node detection** (`get_hub_nodes_tool`): find the most-connected nodes in the codebase (architectural hotspots) by in+out degree, excluding File nodes. -- **Bridge node detection** (`get_bridge_nodes_tool`): find architectural chokepoints via betweenness centrality with sampling approximation for graphs >5000 nodes. -- **Knowledge gap analysis** (`get_knowledge_gaps_tool`): identify structural weaknesses — isolated nodes, thin communities (<3 members), untested hotspots, and single-file communities. -- **Surprise scoring** (`get_surprising_connections_tool`): composite scoring for unexpected architectural coupling (cross-community, cross-language, peripheral-to-hub, cross-test-boundary). -- **Suggested questions** (`get_suggested_questions_tool`): auto-generate prioritized review questions from graph analysis (bridge nodes, untested hubs, surprising connections, thin communities). -- **BFS/DFS traversal** (`traverse_graph_tool`): free-form graph exploration from any node with configurable depth (1-6) and token budget. -- **Edge confidence scoring**: three-tier system (EXTRACTED/INFERRED/AMBIGUOUS) with float confidence scores on all edges. Schema migration v9. -- **Export formats**: GraphML (Gephi/yEd/Cytoscape), Neo4j Cypher statements, Obsidian vault (wikilinks + YAML frontmatter + community pages), SVG static graph. CLI: `visualize --format graphml|cypher|obsidian|svg`. -- **Graph diff**: snapshot/compare graph state over time — new/removed nodes, edges, community membership changes. -- **Token reduction benchmark**: measure naive full-corpus tokens vs graph query tokens with per-question reduction ratios. 
-- **Memory/feedback loop**: persist Q&A results as markdown for re-ingestion via `save_result` / `list_memories` / `clear_memories`. -- **Oversized community auto-splitting**: communities exceeding 25% of graph are recursively split via Leiden algorithm. -- **4 new languages**: Zig, PowerShell, Julia, Svelte SFC (23 total). -- **Visualization enhancements**: node size scaled by degree, community legend with toggle visibility, improved interactivity. -- **README translations**: Simplified Chinese, Japanese, Korean, Hindi. - -### Merged community PRs - -- **#127** (xtfer): SQLite compound edge indexes for query performance. -- **#184** (realkotob): batch `_compute_summaries` — fixes build hangs on large repos. -- **#202** (lngyeen): Swift extension detection, inheritance edges, type kind metadata. -- **#249** (gzenz): community detection resolution scaling (21x speedup), expanded framework patterns, framework-aware dead code detection (56 new tests). -- **#253** (cwoolum): automatic graph build for new worktrees in Claude Code. -- **#267** (jindalarpit): Kiro platform support with 9 tests. - -### Changed - -- MCP tool count: 22 → 28. -- Schema version: 8 → 9 (edge confidence columns). -- Community detection uses resolution scaling for large graphs. -- Risk scoring uses weighted flow criticality and graduated test coverage. -- Dead code detection is framework-aware (ORM models, Pydantic, CDK constructs filtered). -- Flow entry points expanded with 30+ framework decorator patterns. - ## [2.3.1] - 2026-04-11 Hotfix for the Windows long-running-MCP-tool hang that v2.2.4 only partially fixed. @@ -150,14 +110,50 @@ Hotfix on top of 2.2.3 for two bugs surfaced by a full first-time-user smoke tes ### Added - **Codex platform install support** (PR #177): `code-review-graph install --platform codex` appends a `mcp_servers.code-review-graph` section to `~/.codex/config.toml` without overwriting existing Codex settings. 
-- **Luau language support** (PR #165, closes #153): Roblox Luau (`.luau`) parsing — functions, classes, local functions, requires, tests. +- **Luau language support** (PR #165, closes #153): Roblox Luau (`.luau`) parsing -- functions, classes, local functions, requires, tests. - **REFERENCES edge type** (PR #217): New edge kind for symbol references that aren't direct calls (map/dispatch lookups, string-keyed handlers), including Python and TypeScript patterns. - **`recurse_submodules` build option** (PR #215): Build/update can now optionally recurse into git submodules. - **`.gitignore` default for `.code-review-graph/`** (PR #185): Fresh installs automatically add the SQLite DB directory to `.gitignore` so the database isn't accidentally committed. - **Clearer gitignore docs** (PR #171, closes #157): Documentation now spells out that `code-review-graph` already respects `.gitignore` via `git ls-files`. +- **Parser refactoring**: Extracted 16 per-language handler modules into `code_review_graph/lang/` package using a strategy pattern, replacing monolithic conditionals in `parser.py` +- **Jedi-based call resolution**: New `jedi_resolver.py` module resolves Python method calls at build time via Jedi static analysis, with pre-scan filtering by project function names (36s to 3s on large repos) +- **PreToolUse search enrichment**: New `enrich.py` module and `code-review-graph enrich` CLI command inject graph context (callers, callees, flows, community, tests) into agent search results passively +- **Typed variable call enrichment**: Track constructor-based type inference and instance method calls for Python, JS/TS, and Kotlin/Java +- **Star import resolution**: Resolve `from module import *` by scanning target module's exported names +- **Namespace imports**: Track `import * as X from 'module'` and CommonJS `require()` patterns +- **Angular template parsing**: Extract call targets from Angular component templates +- **JSX handler tracking**: Detect function/class 
references passed as JSX event handler props +- **Framework decorator recognition**: Identify entry points decorated with `@app.route`, `@router.get`, `@cli.command`, etc., reducing dead code false positives +- **Module-level import tracking**: Track module-qualified call resolution (`module.function()`) +- **Thread safety**: Double-check locking on parser caches (`_type_sets`, `_get_parser`, `_resolve_module_to_file`, `_get_exported_names`) +- **Batch file storage**: `store_file_batch()` groups file insertions into 50-file transactions for faster builds +- **Bulk node loading**: `get_all_nodes()` replaces per-file SQL queries for community detection +- **Adjacency-indexed cohesion**: Community cohesion computed in O(community-edges) instead of O(all-edges), yielding 21x speedup (48.6s to 2.3s on 41k-node repos) +- **Phase timing instrumentation**: `time.perf_counter()` timing at INFO level for all build phases +- **Batch risk_index**: 2 GROUP BY queries replace per-node COUNT loops in risk scoring +- **Weighted flow risk scoring**: Risk scores weighted by flow criticality instead of flat edge counts +- **Transitive TESTED_BY lookup**: `tests_for` and risk scoring follow transitive test relationships +- **DB schema v8**: Composite edge index for upsert performance (v7 reserved by upstream PR #127) +- **`--quiet` and `--json` CLI flags**: Machine-readable output for `build`, `update`, `status` +- **829+ tests** across 26 test files (up from 615), including `test_pain_points.py` (1,587 lines TDD suite), `test_hardened.py` (467 lines), `test_enrich.py` (237 lines) +- **14 new test fixtures**: Kotlin, Java, TypeScript, JSX, Python resolution scenarios ### Changed -- Community detection is now bounded — large repos complete in reasonable time instead of hanging indefinitely. +- Community detection is now bounded -- large repos complete in reasonable time instead of hanging indefinitely. 
+- New `[enrichment]` optional dependency group for Jedi-based Python call resolution +- Leiden community detection scales resolution parameter with graph size +- Adaptive directory-based fallback for community detection when Leiden produces poor clusters +- Search query deduplication and test function deprioritization + +### Fixed +- **Dead code false positives**: Decorators, CDK construct methods, abstract overrides, and overriding methods with called parents no longer flagged as dead +- **E2e test exclusion**: Playwright/Cypress e2e test directories excluded from dead code detection +- **Unique-name plausible caller optimization**: Faster dead code analysis via pre-filtered candidate sets +- **Store cache liveness check**: Cached SQLite connections verified as alive before reuse + +### Performance +- **Community detection**: 48.6s to 2.3s (21x) on Gadgetbridge (41k nodes, 280k edges) +- **Jedi enrichment**: 36s to 3s (12x) via pre-scan filtering by project function names ## [2.2.2] - 2026-04-08 diff --git a/CLAUDE.md b/CLAUDE.md index 682cfea0..b67ec637 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -46,7 +46,7 @@ When using code-review-graph MCP tools, follow these rules: ```bash # Development -uv run pytest tests/ --tb=short -q # Run tests (572 tests) +uv run pytest tests/ --tb=short -q # Run tests (609 tests) uv run ruff check code_review_graph/ # Lint uv run mypy code_review_graph/ --ignore-missing-imports --no-strict-optional diff --git a/README.md b/README.md index aa389d12..25e91f27 100644 --- a/README.md +++ b/README.md @@ -5,23 +5,14 @@

- English | - 简体中文 | - 日本語 | - 한국어 | - हिन्दी -

- -

- PyPI - Downloads + Website + Discord Stars MIT Licence CI Python 3.10+ MCP - Website - Discord + v2.1.0


@@ -45,7 +36,7 @@ code-review-graph build # parse your codebase One command sets up everything. `install` detects which AI coding tools you have, writes the correct MCP configuration for each one, and injects graph-aware instructions into your platform rules. It also detects whether you installed via `uvx` or `pip`/`pipx` and generates the matching config. Restart your editor/tool after installing.

- One Install, Every Platform: auto-detects Codex, Claude Code, Cursor, Windsurf, Zed, Continue, OpenCode, Antigravity, and Kiro + One Install, Every Platform: auto-detects Codex, Claude Code, Cursor, Windsurf, Zed, Continue, OpenCode, and Antigravity

To target a specific platform: @@ -54,7 +45,6 @@ To target a specific platform: code-review-graph install --platform codex # configure only Codex code-review-graph install --platform cursor # configure only Cursor code-review-graph install --platform claude-code # configure only Claude Code -code-review-graph install --platform kiro # configure only Kiro ``` Requires Python 3.10+. For the best experience, install [uv](https://docs.astral.sh/uv/) (the MCP config will use `uvx` if available, otherwise falls back to the `code-review-graph` command directly). @@ -105,13 +95,13 @@ Large monorepos are where token waste is most painful. The graph cuts through th Next.js monorepo: 27,732 files funnelled through code-review-graph down to ~15 files — 49x fewer tokens

-### 23 languages + Jupyter notebooks +### 19 languages + Jupyter notebooks

19 languages organized by category: Web, Backend, Systems, Mobile, Scripting, plus Jupyter/Databricks notebook support

-Full Tree-sitter grammar support for functions, classes, imports, call sites, inheritance, and test detection in every language. Includes Zig, PowerShell, Julia, and Svelte SFC support. Plus Jupyter/Databricks notebook parsing (`.ipynb`) with multi-language cell support (Python, R, SQL), and Perl XS files (`.xs`). +Full Tree-sitter grammar support for functions, classes, imports, call sites, inheritance, and test detection in every language. Plus Jupyter/Databricks notebook parsing (`.ipynb`) with multi-language cell support (Python, R, SQL), and Perl XS files (`.xs`). --- @@ -193,33 +183,22 @@ The blast-radius analysis never misses an actually impacted file (perfect recall | Feature | Details | |---------|---------| | **Incremental updates** | Re-parses only changed files. Subsequent updates complete in under 2 seconds. | -| **23 languages + notebooks** | Python, TypeScript/TSX, JavaScript, Vue, Svelte, Go, Rust, Java, Scala, C#, Ruby, Kotlin, Swift, PHP, Solidity, C/C++, Dart, R, Perl, Lua, Zig, PowerShell, Julia, Jupyter/Databricks (.ipynb) | +| **19 languages + notebooks** | Python, TypeScript/TSX, JavaScript, Vue, Go, Rust, Java, Scala, C#, Ruby, Kotlin, Swift, PHP, Solidity, C/C++, Dart, R, Perl, Lua, Jupyter/Databricks (.ipynb) | | **Blast-radius analysis** | Shows exactly which functions, classes, and files are affected by any change | | **Auto-update hooks** | Graph updates on every file edit and git commit without manual intervention | | **Semantic search** | Optional vector embeddings via sentence-transformers, Google Gemini, or MiniMax | -| **Interactive visualisation** | D3.js force-directed graph with search, community legend toggles, and degree-scaled nodes | -| **Hub & bridge detection** | Find most-connected nodes and architectural chokepoints via betweenness centrality | -| **Surprise scoring** | Detect unexpected coupling: cross-community, cross-language, peripheral-to-hub edges | -| **Knowledge gap analysis** | Identify isolated nodes, untested 
hotspots, thin communities, and structural weaknesses | -| **Suggested questions** | Auto-generated review questions from graph analysis (bridges, hubs, surprises) | -| **Edge confidence** | Three-tier confidence scoring (EXTRACTED/INFERRED/AMBIGUOUS) with float scores on edges | -| **Graph traversal** | Free-form BFS/DFS exploration from any node with configurable depth and token budget | -| **Export formats** | GraphML (Gephi/yEd), Neo4j Cypher, Obsidian vault with wikilinks, SVG static graph | -| **Graph diff** | Compare graph snapshots over time: new/removed nodes, edges, community changes | -| **Token benchmarking** | Measure naive full-corpus tokens vs graph query tokens with per-question ratios | -| **Memory loop** | Persist Q&A results as markdown for re-ingestion, so the graph grows from queries | -| **Community auto-split** | Oversized communities (>25% of graph) are recursively split via Leiden | -| **Execution flows** | Trace call chains from entry points, sorted by weighted criticality | -| **Community detection** | Cluster related code via Leiden algorithm with resolution scaling for large graphs | +| **Interactive visualisation** | D3.js force-directed graph with edge-type toggles and search | +| **Local storage** | SQLite file in `.code-review-graph/`. No external database, no cloud dependency. 
| +| **Watch mode** | Continuous graph updates as you work | +| **Execution flows** | Trace call chains from entry points, sorted by criticality | +| **Community detection** | Cluster related code via Leiden algorithm or file grouping | | **Architecture overview** | Auto-generated architecture map with coupling warnings | | **Risk-scored reviews** | `detect_changes` maps diffs to affected functions, flows, and test gaps | -| **Refactoring tools** | Rename preview, framework-aware dead code detection, community-driven suggestions | +| **Refactoring tools** | Rename preview, dead code detection, community-driven suggestions | | **Wiki generation** | Auto-generate markdown wiki from community structure | | **Multi-repo registry** | Register multiple repos, search across all of them | | **MCP prompts** | 5 workflow templates: review, architecture, debug, onboard, pre-merge | | **Full-text search** | FTS5-powered hybrid search combining keyword and vector similarity | -| **Local storage** | SQLite file in `.code-review-graph/`. No external database, no cloud dependency. 
| -| **Watch mode** | Continuous graph updates as you work | --- @@ -249,12 +228,9 @@ code-review-graph update # Incremental update (changed files only) code-review-graph status # Graph statistics code-review-graph watch # Auto-update on file changes code-review-graph visualize # Generate interactive HTML graph -code-review-graph visualize --format graphml # Export as GraphML -code-review-graph visualize --format svg # Export as SVG -code-review-graph visualize --format obsidian # Export as Obsidian vault -code-review-graph visualize --format cypher # Export as Neo4j Cypher code-review-graph wiki # Generate markdown wiki from communities code-review-graph detect-changes # Risk-scored change impact analysis +code-review-graph enrich # Enrich search results with graph context code-review-graph register # Register repo in multi-repo registry code-review-graph unregister # Remove repo from registry code-review-graph repos # List registered repositories @@ -265,7 +241,7 @@ code-review-graph serve # Start MCP server
-28 MCP tools +22 MCP tools
Your AI assistant uses these automatically once the graph is built. @@ -273,11 +249,9 @@ Your AI assistant uses these automatically once the graph is built. | Tool | Description | |------|-------------| | `build_or_update_graph_tool` | Build or incrementally update the graph | -| `get_minimal_context_tool` | Ultra-compact context (~100 tokens) — call this first | | `get_impact_radius_tool` | Blast radius of changed files | | `get_review_context_tool` | Token-optimised review context with structural summary | | `query_graph_tool` | Callers, callees, tests, imports, inheritance queries | -| `traverse_graph_tool` | BFS/DFS traversal from any node with token budget | | `semantic_search_nodes_tool` | Search code entities by name or meaning | | `embed_graph_tool` | Compute vector embeddings for semantic search | | `list_graph_stats_tool` | Graph size and health | @@ -290,11 +264,6 @@ Your AI assistant uses these automatically once the graph is built. | `get_community_tool` | Get details of a single community | | `get_architecture_overview_tool` | Architecture overview from community structure | | `detect_changes_tool` | Risk-scored change impact analysis for code review | -| `get_hub_nodes_tool` | Find most-connected nodes (architectural hotspots) | -| `get_bridge_nodes_tool` | Find chokepoints via betweenness centrality | -| `get_knowledge_gaps_tool` | Identify structural weaknesses and untested hotspots | -| `get_surprising_connections_tool` | Detect unexpected cross-community coupling | -| `get_suggested_questions_tool` | Auto-generated review questions from analysis | | `refactor_tool` | Rename preview, dead code detection, suggestions | | `apply_refactor_tool` | Apply a previously previewed refactoring | | `generate_wiki_tool` | Generate markdown wiki from communities | @@ -328,6 +297,7 @@ Optional dependency groups: pip install code-review-graph[embeddings] # Local vector embeddings (sentence-transformers) pip install code-review-graph[google-embeddings] # Google 
Gemini embeddings pip install code-review-graph[communities] # Community detection (igraph) +pip install code-review-graph[enrichment] # Jedi-based Python call resolution pip install code-review-graph[eval] # Evaluation benchmarks (matplotlib) pip install code-review-graph[wiki] # Wiki generation with LLM summaries (ollama) pip install code-review-graph[all] # All optional dependencies @@ -351,7 +321,7 @@ pytest Adding a new language
-Edit `code_review_graph/parser.py` and add your extension to `EXTENSION_TO_LANGUAGE` along with node type mappings in `_CLASS_TYPES`, `_FUNCTION_TYPES`, `_IMPORT_TYPES`, and `_CALL_TYPES`. Include a test fixture and open a PR. +Edit the appropriate language handler in `code_review_graph/lang/` (e.g., `_python.py`, `_kotlin.py`) or create a new one following `_base.py`. Add your extension to `EXTENSION_TO_LANGUAGE` in `parser.py`, include a test fixture, and open a PR.
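The contributing note above points new-language PRs at per-language handler modules following `_base.py` in a strategy pattern. As a rough illustration of that shape only — the class names, attribute names, and node-type strings below are hypothetical, not the project's actual API; always follow the real `_base.py` — a handler might bundle the Tree-sitter node-type sets the parser consults:

```python
# Hypothetical sketch of a strategy-pattern language handler; all names
# here are illustrative, not code-review-graph's real classes.
class LanguageHandler:
    """Base strategy: per-language node-type sets the parser consults."""
    extensions: tuple[str, ...] = ()
    class_types: frozenset[str] = frozenset()
    function_types: frozenset[str] = frozenset()
    import_types: frozenset[str] = frozenset()
    call_types: frozenset[str] = frozenset()


class LuaHandler(LanguageHandler):
    """Example concrete handler covering Lua and Roblox Luau files."""
    extensions = (".lua", ".luau")
    function_types = frozenset({"function_declaration", "function_definition"})
    call_types = frozenset({"function_call"})


# A registry keyed by extension replaces monolithic if/elif chains:
# the parser looks up the handler once and dispatches to it.
HANDLERS: dict[str, LanguageHandler] = {
    ext: handler for handler in (LuaHandler(),) for ext in handler.extensions
}
```

The payoff of the pattern is that adding a language is a new module plus one registry entry, with no edits to shared dispatch logic.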
@@ -363,5 +333,5 @@ MIT. See [LICENSE](LICENSE).
code-review-graph.com

pip install code-review-graph && code-review-graph install
-Works with Codex, Claude Code, Cursor, Windsurf, Zed, Continue, OpenCode, Antigravity, and Kiro +Works with Codex, Claude Code, Cursor, Windsurf, Zed, Continue, OpenCode, and Antigravity
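The `cli.py` changes that follow add a machine-readable `status --json` mode. As a sketch of how a consumer (say, a CI gate) might use it — the payload below is fabricated sample data, with keys mirroring the `data` dict built in the diff's `--json` branch:

```python
import json

# Fabricated example payload; the keys mirror the dict assembled in
# cli.py's `status --json` branch (nodes, edges, files, languages,
# last_updated, branch, commit, stale). Values are illustrative.
status_json = json.dumps({
    "nodes": 4120,
    "edges": 18304,
    "files": 312,
    "languages": ["python", "typescript"],
    "last_updated": "2026-04-11T09:30:00",
    "branch": "main",
    "commit": "0123abcd4567",
    "stale": None,  # holds a warning string when branch/commit are stale
})

data = json.loads(status_json)
# A CI check might require a non-empty, non-stale graph before review.
graph_ok = data["nodes"] > 0 and data["stale"] is None
```

Because `stale` carries the human-readable warning string (or `None`), scripts can both gate on it and surface it verbatim in logs.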

diff --git a/code_review_graph/cli.py b/code_review_graph/cli.py index 70b0b7cd..32d836c0 100644 --- a/code_review_graph/cli.py +++ b/code_review_graph/cli.py @@ -11,6 +11,7 @@ code-review-graph visualize code-review-graph wiki code-review-graph detect-changes [--base BASE] [--brief] + code-review-graph enrich code-review-graph register [--alias name] code-review-graph unregister code-review-graph repos @@ -250,6 +251,67 @@ def _handle_init(args: argparse.Namespace) -> None: print(" 2. Restart your AI coding tool to pick up the new config") +def _run_post_processing(store, quiet: bool = False) -> None: + """Run signatures, FTS, flows, and communities after build/update.""" + import sqlite3 + + # Signatures + try: + nodes = store._conn.execute( + "SELECT id, name, kind, params, return_type FROM nodes " + "WHERE kind IN ('Function','Test','Class')" + ).fetchall() + for row in nodes: + node_id, name, kind, params, ret = row + if kind in ("Function", "Test"): + sig = f"{name}({params or ''})" + if ret: + sig += f" -> {ret}" + elif kind == "Class": + sig = f"class {name}" + else: + sig = name + store.update_node_signature(node_id, sig[:512]) + store.commit() + except (sqlite3.OperationalError, TypeError, KeyError) as e: + if not quiet: + print(f"Warning: signature computation failed: {e}") + + # FTS index + try: + from .search import rebuild_fts_index + fts_count = rebuild_fts_index(store) + if not quiet: + print(f"FTS indexed: {fts_count} nodes") + except (sqlite3.OperationalError, ImportError) as e: + if not quiet: + print(f"Warning: FTS index rebuild failed: {e}") + + # Flows + try: + from .flows import store_flows as _store_flows + from .flows import trace_flows as _trace_flows + flows = _trace_flows(store) + count = _store_flows(store, flows) + if not quiet: + print(f"Flows detected: {count}") + except (sqlite3.OperationalError, ImportError) as e: + if not quiet: + print(f"Warning: flow detection failed: {e}") + + # Communities + try: + from .communities import 
detect_communities as _detect_communities + from .communities import store_communities as _store_communities + comms = _detect_communities(store) + count = _store_communities(store, comms) + if not quiet: + print(f"Communities detected: {count}") + except (sqlite3.OperationalError, ImportError) as e: + if not quiet: + print(f"Warning: community detection failed: {e}") + + def main() -> None: """Main CLI entry point.""" ap = argparse.ArgumentParser( @@ -342,6 +404,7 @@ def main() -> None: # build build_cmd = sub.add_parser("build", help="Full graph build (re-parse all files)") build_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") + build_cmd.add_argument("-q", "--quiet", action="store_true", help="Suppress output") build_cmd.add_argument( "--skip-flows", action="store_true", help="Skip flow/community detection (signatures + FTS only)", @@ -355,6 +418,7 @@ def main() -> None: update_cmd = sub.add_parser("update", help="Incremental update (only changed files)") update_cmd.add_argument("--base", default="HEAD~1", help="Git diff base (default: HEAD~1)") update_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") + update_cmd.add_argument("-q", "--quiet", action="store_true", help="Suppress output") update_cmd.add_argument( "--skip-flows", action="store_true", help="Skip flow/community detection (signatures + FTS only)", @@ -381,6 +445,11 @@ def main() -> None: # status status_cmd = sub.add_parser("status", help="Show graph statistics") status_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") + status_cmd.add_argument("-q", "--quiet", action="store_true", help="Suppress output") + status_cmd.add_argument( + "--json", action="store_true", dest="json_output", + help="Output as JSON", + ) # visualize vis_cmd = sub.add_parser("visualize", help="Generate interactive HTML graph visualization") @@ -448,6 +517,13 @@ def main() -> None: ) detect_cmd.add_argument("--repo", default=None, 
help="Repository root (auto-detected)") + # embed + embed_cmd = sub.add_parser("embed", help="Compute vector embeddings for graph nodes") + embed_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") + + # enrich (PreToolUse hook -- reads hook JSON from stdin) + sub.add_parser("enrich", help="Enrich search results with graph context (hook)") + # serve serve_cmd = sub.add_parser("serve", help="Start MCP server (stdio transport)") serve_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") @@ -467,6 +543,28 @@ def main() -> None: serve_main(repo_root=args.repo) return + if args.command == "embed": + from .incremental import find_repo_root + repo_root = Path(args.repo) if args.repo else find_repo_root() + if not repo_root: + repo_root = Path.cwd() + db_path = repo_root / ".code-review-graph" / "graph.db" + if not db_path.exists(): + print("No graph database found. Run 'code-review-graph build' first.") + return + from .embeddings import EmbeddingStore, embed_all_nodes + from .graph import GraphStore + store = GraphStore(str(db_path)) + emb_store = EmbeddingStore(str(db_path)) + count = embed_all_nodes(store, emb_store) + print(f"Embedded {count} nodes.") + return + + if args.command == "enrich": + from .enrich import run_hook + run_hook() + return + if args.command == "eval": from .eval.reporter import generate_full_report, generate_readme_tables from .eval.runner import run_eval @@ -606,13 +704,14 @@ def main() -> None: parsed = result.get("files_parsed", 0) nodes = result.get("total_nodes", 0) edges = result.get("total_edges", 0) - print( - f"Full build: {parsed} files, " - f"{nodes} nodes, {edges} edges" - f" (postprocess={pp})" - ) - if result.get("errors"): - print(f"Errors: {len(result['errors'])}") + if not getattr(args, "quiet", False): + print( + f"Full build: {parsed} files, " + f"{nodes} nodes, {edges} edges" + f" (postprocess={pp})" + ) + if result.get("errors"): + print(f"Errors: 
{len(result['errors'])}") elif args.command == "update": pp = "none" if getattr(args, "skip_postprocess", False) else ( @@ -626,35 +725,53 @@ def main() -> None: updated = result.get("files_updated", 0) nodes = result.get("total_nodes", 0) edges = result.get("total_edges", 0) - print( - f"Incremental: {updated} files updated, " - f"{nodes} nodes, {edges} edges" - f" (postprocess={pp})" - ) + if not getattr(args, "quiet", False): + print( + f"Incremental: {updated} files updated, " + f"{nodes} nodes, {edges} edges" + f" (postprocess={pp})" + ) elif args.command == "status": + import json as json_mod stats = store.get_stats() - print(f"Nodes: {stats.total_nodes}") - print(f"Edges: {stats.total_edges}") - print(f"Files: {stats.files_count}") - print(f"Languages: {', '.join(stats.languages)}") - print(f"Last updated: {stats.last_updated or 'never'}") - # Show branch info and warn if stale stored_branch = store.get_metadata("git_branch") stored_sha = store.get_metadata("git_head_sha") - if stored_branch: - print(f"Built on branch: {stored_branch}") - if stored_sha: - print(f"Built at commit: {stored_sha[:12]}") from .incremental import _git_branch_info current_branch, current_sha = _git_branch_info(repo_root) + stale_warning = None if stored_branch and current_branch and stored_branch != current_branch: - print( - f"WARNING: Graph was built on '{stored_branch}' " + stale_warning = ( + f"Graph was built on '{stored_branch}' " f"but you are now on '{current_branch}'. " f"Run 'code-review-graph build' to rebuild." 
) + if getattr(args, "json_output", False): + data = { + "nodes": stats.total_nodes, + "edges": stats.total_edges, + "files": stats.files_count, + "languages": list(stats.languages), + "last_updated": stats.last_updated, + "branch": stored_branch, + "commit": stored_sha[:12] if stored_sha else None, + "stale": stale_warning, + } + print(json_mod.dumps(data)) + elif not args.quiet: + print(f"Nodes: {stats.total_nodes}") + print(f"Edges: {stats.total_edges}") + print(f"Files: {stats.files_count}") + print(f"Languages: {', '.join(stats.languages)}") + print(f"Last updated: {stats.last_updated or 'never'}") + if stored_branch: + print(f"Built on branch: {stored_branch}") + if stored_sha: + print(f"Built at commit: {stored_sha[:12]}") + if stale_warning: + print(f"WARNING: {stale_warning}") + elif args.command == "watch": watch(repo_root, store) diff --git a/code_review_graph/enrich.py b/code_review_graph/enrich.py new file mode 100644 index 00000000..f95c334a --- /dev/null +++ b/code_review_graph/enrich.py @@ -0,0 +1,303 @@ +"""PreToolUse search enrichment for Claude Code hooks. + +Intercepts Grep/Glob/Bash/Read tool calls and enriches them with +structural context from the code knowledge graph: callers, callees, +execution flows, community membership, and test coverage. 
+""" + +from __future__ import annotations + +import json +import logging +import os +import re +import sys +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + +# Flags that consume the next token in grep/rg commands +_RG_FLAGS_WITH_VALUES = frozenset({ + "-e", "-f", "-m", "-A", "-B", "-C", "-g", "--glob", + "-t", "--type", "--include", "--exclude", "--max-count", + "--max-depth", "--max-filesize", "--color", "--colors", + "--context-separator", "--field-match-separator", + "--path-separator", "--replace", "--sort", "--sortr", +}) + + +def extract_pattern(tool_name: str, tool_input: dict[str, Any]) -> str | None: + """Extract a search pattern from a tool call's input. + + Returns None if no meaningful pattern can be extracted. + """ + if tool_name == "Grep": + return tool_input.get("pattern") + + if tool_name == "Glob": + raw = tool_input.get("pattern", "") + # Extract meaningful name from glob: "**/auth*.ts" -> "auth" + # Skip pure extension globs like "**/*.ts" + match = re.search(r"[*/]([a-zA-Z][a-zA-Z0-9_]{2,})", raw) + return match.group(1) if match else None + + if tool_name == "Bash": + cmd = tool_input.get("command", "") + if not re.search(r"\brg\b|\bgrep\b", cmd): + return None + tokens = cmd.split() + found_cmd = False + skip_next = False + for token in tokens: + if skip_next: + skip_next = False + continue + if not found_cmd: + if re.search(r"\brg$|\bgrep$", token): + found_cmd = True + continue + if token.startswith("-"): + if token in _RG_FLAGS_WITH_VALUES: + skip_next = True + continue + cleaned = token.strip("'\"") + return cleaned if len(cleaned) >= 3 else None + return None + + return None + + +def _make_relative(file_path: str, repo_root: str) -> str: + """Make a file path relative to repo_root for display.""" + try: + return str(Path(file_path).relative_to(repo_root)) + except ValueError: + return file_path + + +def _get_community_name(conn: Any, community_id: int) -> str: + """Fetch a community name by 
ID.""" + row = conn.execute( + "SELECT name FROM communities WHERE id = ?", (community_id,) + ).fetchone() + return row["name"] if row else "" + + +def _get_flow_names_for_node(conn: Any, node_id: int) -> list[str]: + """Fetch execution flow names that a node participates in (max 3).""" + rows = conn.execute( + "SELECT f.name FROM flow_memberships fm " + "JOIN flows f ON fm.flow_id = f.id " + "WHERE fm.node_id = ? LIMIT 3", + (node_id,), + ).fetchall() + return [r["name"] for r in rows] + + +def _format_node_context( + node: Any, + store: Any, + conn: Any, + repo_root: str, +) -> list[str]: + """Format a single node's structural context as plain text lines.""" + from .graph import GraphNode + assert isinstance(node, GraphNode) + + qn = node.qualified_name + loc = _make_relative(node.file_path, repo_root) + if node.line_start: + loc = f"{loc}:{node.line_start}" + + header = f"{node.name} ({loc})" + + # Community + if node.extra.get("community_id"): + cname = _get_community_name(conn, node.extra["community_id"]) + if cname: + header += f" [{cname}]" + else: + # Check via direct query + row = conn.execute( + "SELECT community_id FROM nodes WHERE id = ?", (node.id,) + ).fetchone() + if row and row["community_id"]: + cname = _get_community_name(conn, row["community_id"]) + if cname: + header += f" [{cname}]" + + lines = [header] + + # Callers (max 5, deduplicated) + callers: list[str] = [] + seen: set[str] = set() + for e in store.get_edges_by_target(qn): + if e.kind == "CALLS" and len(callers) < 5: + c = store.get_node(e.source_qualified) + if c and c.name not in seen: + seen.add(c.name) + callers.append(c.name) + if callers: + lines.append(f" Called by: {', '.join(callers)}") + + # Callees (max 5, deduplicated) + callees: list[str] = [] + seen.clear() + for e in store.get_edges_by_source(qn): + if e.kind == "CALLS" and len(callees) < 5: + c = store.get_node(e.target_qualified) + if c and c.name not in seen: + seen.add(c.name) + callees.append(c.name) + if callees: + 
lines.append(f" Calls: {', '.join(callees)}") + + # Execution flows + flow_names = _get_flow_names_for_node(conn, node.id) + if flow_names: + lines.append(f" Flows: {', '.join(flow_names)}") + + # Tests + tests: list[str] = [] + for e in store.get_edges_by_target(qn): + if e.kind == "TESTED_BY" and len(tests) < 3: + t = store.get_node(e.source_qualified) + if t: + tests.append(t.name) + if tests: + lines.append(f" Tests: {', '.join(tests)}") + + return lines + + +def enrich_search(pattern: str, repo_root: str) -> str: + """Search the graph for pattern and return enriched context.""" + from .graph import GraphStore + from .search import _fts_search + + db_path = Path(repo_root) / ".code-review-graph" / "graph.db" + if not db_path.exists(): + return "" + + store = GraphStore(db_path) + try: + conn = store._conn + + fts_results = _fts_search(conn, pattern, limit=8) + if not fts_results: + return "" + + all_lines: list[str] = [] + count = 0 + for node_id, _score in fts_results: + if count >= 5: + break + node = store.get_node_by_id(node_id) + if not node or node.is_test: + continue + node_lines = _format_node_context(node, store, conn, repo_root) + all_lines.extend(node_lines) + all_lines.append("") + count += 1 + + if not all_lines: + return "" + + header = f'[code-review-graph] {count} symbol(s) matching "{pattern}":\n' + return header + "\n".join(all_lines) + finally: + store.close() + + +def enrich_file_read(file_path: str, repo_root: str) -> str: + """Enrich a file read with structural context for functions in that file.""" + from .graph import GraphStore + + db_path = Path(repo_root) / ".code-review-graph" / "graph.db" + if not db_path.exists(): + return "" + + store = GraphStore(db_path) + try: + conn = store._conn + nodes = store.get_nodes_by_file(file_path) + if not nodes: + # Try with resolved path + try: + resolved = str(Path(file_path).resolve()) + nodes = store.get_nodes_by_file(resolved) + except (OSError, ValueError): + pass + if not nodes: + return "" + 
+ # Filter to functions/classes/types (skip File nodes), limit to 10 + interesting = [ + n for n in nodes + if n.kind in ("Function", "Class", "Type", "Test") + ][:10] + + if not interesting: + return "" + + all_lines: list[str] = [] + for node in interesting: + node_lines = _format_node_context(node, store, conn, repo_root) + all_lines.extend(node_lines) + all_lines.append("") + + rel_path = _make_relative(file_path, repo_root) + header = ( + f"[code-review-graph] {len(interesting)} symbol(s) in {rel_path}:\n" + ) + return header + "\n".join(all_lines) + finally: + store.close() + + +def run_hook() -> None: + """Entry point for the enrich CLI subcommand. + + Reads Claude Code hook JSON from stdin, extracts the search pattern, + queries the graph, and outputs hookSpecificOutput JSON to stdout. + """ + try: + hook_input = json.load(sys.stdin) + except (json.JSONDecodeError, ValueError): + return + + tool_name = hook_input.get("tool_name", "") + tool_input = hook_input.get("tool_input", {}) + cwd = hook_input.get("cwd", os.getcwd()) + + # Find repo root by walking up from cwd + from .incremental import find_project_root + + repo_root = str(find_project_root(Path(cwd))) + db_path = Path(repo_root) / ".code-review-graph" / "graph.db" + if not db_path.exists(): + return + + # Dispatch + context = "" + if tool_name == "Read": + fp = tool_input.get("file_path", "") + if fp: + context = enrich_file_read(fp, repo_root) + else: + pattern = extract_pattern(tool_name, tool_input) + if not pattern or len(pattern) < 3: + return + context = enrich_search(pattern, repo_root) + + if not context: + return + + response = { + "hookSpecificOutput": { + "hookEventName": "PreToolUse", + "additionalContext": context, + } + } + json.dump(response, sys.stdout) diff --git a/code_review_graph/incremental.py b/code_review_graph/incremental.py index cfa672c0..b579a55d 100644 --- a/code_review_graph/incremental.py +++ b/code_review_graph/incremental.py @@ -62,6 +62,8 @@ "*.min.css", "*.map", 
"*.lock", + "*.bundle.js", + "cdk.out/**", "package-lock.json", "yarn.lock", "*.db", @@ -486,6 +488,16 @@ def _parse_single_file( return (rel_path, [], [], str(e), "") +def _run_jedi_enrichment(store: GraphStore, repo_root: Path) -> dict: + """Run optional Jedi enrichment for Python method calls.""" + try: + from .jedi_resolver import enrich_jedi_calls + return enrich_jedi_calls(store, repo_root) + except Exception as e: + logger.warning("Jedi enrichment failed: %s", e) + return {"error": str(e)} + + def full_build( repo_root: Path, store: GraphStore, @@ -520,6 +532,7 @@ def full_build( use_serial = os.environ.get("CRG_SERIAL_PARSE", "") == "1" + t0 = time.perf_counter() if use_serial or file_count < 8: # Serial fallback (for debugging or tiny repos) for i, rel_path in enumerate(files, 1): @@ -539,8 +552,10 @@ def full_build( if i % 50 == 0 or i == file_count: logger.info("Progress: %d/%d files parsed", i, file_count) else: - # Parallel parsing — store calls remain serial (SQLite single-writer) + # Parallel parsing -- batch store to reduce transaction overhead + batch_size = 50 args_list = [(rel_path, str(repo_root)) for rel_path in files] + batch: list[tuple[str, list, list, str]] = [] with concurrent.futures.ProcessPoolExecutor( max_workers=_MAX_PARSE_WORKERS, ) as executor: @@ -552,13 +567,28 @@ def full_build( errors.append({"file": rel_path, "error": error}) continue full_path = repo_root / rel_path - store.store_file_nodes_edges( - str(full_path), nodes, edges, fhash, - ) + batch.append((str(full_path), nodes, edges, fhash)) total_nodes += len(nodes) total_edges += len(edges) + if len(batch) >= batch_size: + store.store_file_batch(batch) + batch = [] if i % 200 == 0 or i == file_count: logger.info("Progress: %d/%d files parsed", i, file_count) + if batch: + store.store_file_batch(batch) + t_parse = time.perf_counter() + logger.info("Phase: parsing %d files took %.2fs", file_count, t_parse - t0) + + # Post-parse Jedi enrichment for Python method calls + 
jedi_stats = _run_jedi_enrichment(store, repo_root) + t_jedi = time.perf_counter() + logger.info("Phase: Jedi enrichment took %.2fs", t_jedi - t_parse) + + # Post-build: resolve bare-name CALLS targets across all files + bare_resolved = store.resolve_bare_call_targets() + t_bare = time.perf_counter() + logger.info("Phase: bare-name resolution took %.2fs", t_bare - t_jedi) store.set_metadata("last_updated", time.strftime("%Y-%m-%dT%H:%M:%S")) store.set_metadata("last_build_type", "full") @@ -574,6 +604,13 @@ def full_build( "total_nodes": total_nodes, "total_edges": total_edges, "errors": errors, + "jedi": jedi_stats, + "bare_resolved": bare_resolved, + "timing": { + "parse_s": round(t_parse - t0, 2), + "jedi_s": round(t_jedi - t_parse, 2), + "bare_resolve_s": round(t_bare - t_jedi, 2), + }, } @@ -650,6 +687,7 @@ def incremental_update( use_serial = os.environ.get("CRG_SERIAL_PARSE", "") == "1" + t0 = time.perf_counter() if use_serial or len(to_parse) < 8: for rel_path in to_parse: abs_path = repo_root / rel_path @@ -666,7 +704,9 @@ def incremental_update( logger.warning("Error parsing %s: %s", rel_path, e) errors.append({"file": rel_path, "error": str(e)}) else: + batch_size = 50 args_list = [(rel_path, str(repo_root)) for rel_path in to_parse] + batch: list[tuple[str, list, list, str]] = [] with concurrent.futures.ProcessPoolExecutor( max_workers=_MAX_PARSE_WORKERS, ) as executor: @@ -677,11 +717,26 @@ def incremental_update( logger.warning("Error parsing %s: %s", rel_path, error) errors.append({"file": rel_path, "error": error}) continue - store.store_file_nodes_edges( - str(repo_root / rel_path), nodes, edges, fhash, - ) + batch.append((str(repo_root / rel_path), nodes, edges, fhash)) total_nodes += len(nodes) total_edges += len(edges) + if len(batch) >= batch_size: + store.store_file_batch(batch) + batch = [] + if batch: + store.store_file_batch(batch) + t_parse = time.perf_counter() + logger.info("Phase: parsing %d files took %.2fs", len(to_parse), t_parse - 
t0) + + # Post-parse Jedi enrichment for Python method calls + jedi_stats = _run_jedi_enrichment(store, repo_root) + t_jedi = time.perf_counter() + logger.info("Phase: Jedi enrichment took %.2fs", t_jedi - t_parse) + + # Post-build: resolve bare-name CALLS targets across all files + bare_resolved = store.resolve_bare_call_targets() + t_bare = time.perf_counter() + logger.info("Phase: bare-name resolution took %.2fs", t_bare - t_jedi) store.set_metadata("last_updated", time.strftime("%Y-%m-%dT%H:%M:%S")) store.set_metadata("last_build_type", "incremental") @@ -699,6 +754,13 @@ def incremental_update( "changed_files": list(changed_files), "dependent_files": list(dependent_files), "errors": errors, + "jedi": jedi_stats, + "bare_resolved": bare_resolved, + "timing": { + "parse_s": round(t_parse - t0, 2), + "jedi_s": round(t_jedi - t_parse, 2), + "bare_resolve_s": round(t_bare - t_jedi, 2), + }, } diff --git a/code_review_graph/jedi_resolver.py b/code_review_graph/jedi_resolver.py new file mode 100644 index 00000000..8ec007e5 --- /dev/null +++ b/code_review_graph/jedi_resolver.py @@ -0,0 +1,303 @@ +"""Post-build Jedi enrichment for Python call resolution. + +After tree-sitter parsing, many method calls on lowercase-receiver variables +are dropped (e.g. ``svc.authenticate()`` where ``svc = factory()``). Jedi +can resolve these by tracing return types across files. + +This module runs as a post-build step: it re-walks Python ASTs to find +dropped calls, uses ``jedi.Script.goto()`` to resolve them, and adds the +resulting CALLS edges to the graph database. 
+""" + +from __future__ import annotations + +import logging +import os +from pathlib import Path +from typing import Optional + +from .parser import CodeParser, EdgeInfo +from .parser import _is_test_file as _parser_is_test_file + +logger = logging.getLogger(__name__) + +_SELF_NAMES = frozenset({"self", "cls", "super"}) + + +def enrich_jedi_calls(store, repo_root: Path) -> dict: + """Resolve untracked Python method calls via Jedi. + + Walks Python files, finds ``receiver.method()`` calls that tree-sitter + dropped (lowercase receiver, not self/cls), resolves them with Jedi, + and inserts new CALLS edges. + + Returns stats dict with ``resolved`` count. + """ + try: + import jedi + except ImportError: + logger.info("Jedi not installed, skipping Python enrichment") + return {"skipped": True, "reason": "jedi not installed"} + + repo_root = Path(repo_root).resolve() + + # Get Python files from the graph — skip early if none + all_files = store.get_all_files() + py_files = [f for f in all_files if f.endswith(".py")] + + if not py_files: + return {"resolved": 0, "files": 0} + + # Scope the Jedi project to Python-only directories to avoid scanning + # non-Python files (e.g. node_modules, TS sources). This matters for + # polyglot monorepos where jedi.Project(path=repo_root) would scan + # thousands of irrelevant files during initialization. + py_dirs = sorted({str(Path(f).parent) for f in py_files}) + common_py_root = Path(os.path.commonpath(py_dirs)) if py_dirs else repo_root + if not str(common_py_root).startswith(str(repo_root)): + common_py_root = repo_root + project = jedi.Project( + path=str(common_py_root), + added_sys_path=[str(repo_root)], + smart_sys_path=False, + ) + + # Pre-parse all Python files to find which ones have pending method calls. + # This avoids expensive Jedi Script creation for files with nothing to resolve. 
+ parser = CodeParser() + ts_parser = parser._get_parser("python") + if not ts_parser: + return {"resolved": 0, "files": 0} + + # Build set of method names that actually exist in project code. + # No point asking Jedi to resolve `logger.getLogger()` if no project + # file defines a function called `getLogger`. + project_func_names = { + r["name"] + for r in store._conn.execute( + "SELECT DISTINCT name FROM nodes WHERE kind IN ('Function', 'Test')" + ).fetchall() + } + + files_with_pending: list[tuple[str, bytes, list]] = [] + total_skipped = 0 + for file_path in py_files: + try: + source = Path(file_path).read_bytes() + except (OSError, PermissionError): + continue + tree = ts_parser.parse(source) + is_test = _parser_is_test_file(file_path) + pending = _find_untracked_method_calls(tree.root_node, is_test) + if pending: + # Only keep calls whose method name exists in project code + filtered = [p for p in pending if p[2] in project_func_names] + total_skipped += len(pending) - len(filtered) + if filtered: + files_with_pending.append((file_path, source, filtered)) + + if not files_with_pending: + return {"resolved": 0, "files": 0} + + logger.debug( + "Jedi: %d/%d Python files have pending calls (%d calls skipped — no project target)", + len(files_with_pending), len(py_files), total_skipped, + ) + + resolved_count = 0 + files_enriched = 0 + errors = 0 + + for file_path, source, pending in files_with_pending: + source_text = source.decode("utf-8", errors="replace") + + # Get existing CALLS edges for this file to skip duplicates + existing = set() + for edge in _get_file_call_edges(store, file_path): + existing.add((edge.source_qualified, edge.line)) + + # Get function nodes from DB for enclosing-function lookup + func_nodes = [ + n for n in store.get_nodes_by_file(file_path) + if n.kind in ("Function", "Test") + ] + + # Create Jedi script once per file + try: + script = jedi.Script(source_text, path=file_path, project=project) + except Exception as e: + 
logger.debug("Jedi failed to load %s: %s", file_path, e) + errors += 1 + continue + + file_resolved = 0 + for jedi_line, col, _method_name, _enclosing_name in pending: + # Find enclosing function qualified name + enclosing = _find_enclosing(func_nodes, jedi_line) + if not enclosing: + enclosing = file_path # module-level + + # Skip if we already have a CALLS edge from this source at this line + if (enclosing, jedi_line) in existing: + continue + + # Ask Jedi to resolve + try: + names = script.goto(jedi_line, col) + except Exception: # nosec B112 - Jedi may fail on malformed code + continue + + if not names: + continue + + name = names[0] + if not name.module_path: + continue + + module_path = Path(name.module_path).resolve() + + # Only emit edges for project-internal definitions + try: + module_path.relative_to(repo_root) + except ValueError: + continue + + # Build qualified target: file_path::Class.method or file_path::func + target_file = str(module_path) + parent = name.parent() + if parent and parent.type == "class": + target = f"{target_file}::{parent.name}.{name.name}" + else: + target = f"{target_file}::{name.name}" + + store.upsert_edge(EdgeInfo( + kind="CALLS", + source=enclosing, + target=target, + file_path=file_path, + line=jedi_line, + )) + existing.add((enclosing, jedi_line)) + file_resolved += 1 + + if file_resolved: + files_enriched += 1 + resolved_count += file_resolved + + if resolved_count: + store.commit() + logger.info( + "Jedi enrichment: resolved %d calls in %d files", + resolved_count, files_enriched, + ) + + return { + "resolved": resolved_count, + "files": files_enriched, + "errors": errors, + } + + +def _get_file_call_edges(store, file_path: str): + """Get all CALLS edges originating from a file.""" + conn = store._conn + rows = conn.execute( + "SELECT * FROM edges WHERE file_path = ? 
AND kind = 'CALLS'", + (file_path,), + ).fetchall() + from .graph import GraphEdge + return [ + GraphEdge( + id=r["id"], kind=r["kind"], + source_qualified=r["source_qualified"], + target_qualified=r["target_qualified"], + file_path=r["file_path"], line=r["line"], + extra={}, + ) + for r in rows + ] + + +def _find_enclosing(func_nodes, line: int) -> Optional[str]: + """Find the qualified name of the function enclosing a given line.""" + best = None + best_span = float("inf") + for node in func_nodes: + if node.line_start <= line <= node.line_end: + span = node.line_end - node.line_start + if span < best_span: + best = node.qualified_name + best_span = span + return best + + +def _find_untracked_method_calls(root, is_test_file: bool = False): + """Walk Python AST to find method calls the parser would have dropped. + + Returns list of (jedi_line, col, method_name, enclosing_func_name) tuples. + Jedi_line is 1-indexed, col is 0-indexed. + """ + results: list[tuple[int, int, str, Optional[str]]] = [] + _walk_calls(root, results, is_test_file, enclosing_func=None) + return results + + +def _walk_calls(node, results, is_test_file, enclosing_func): + """Recursively walk AST collecting dropped method calls.""" + # Track enclosing function scope + if node.type == "function_definition": + name = None + for child in node.children: + if child.type == "identifier": + name = child.text.decode("utf-8", errors="replace") + break + for child in node.children: + _walk_calls(child, results, is_test_file, name or enclosing_func) + return + + if node.type == "decorated_definition": + for child in node.children: + _walk_calls(child, results, is_test_file, enclosing_func) + return + + # Check for call expressions with attribute access + if node.type == "call": + first = node.children[0] if node.children else None + if first and first.type == "attribute": + _check_dropped_call(first, results, is_test_file, enclosing_func) + + for child in node.children: + _walk_calls(child, results, 
is_test_file, enclosing_func) + + +def _check_dropped_call(attr_node, results, is_test_file, enclosing_func): + """Check if an attribute-based call was dropped by the parser.""" + children = attr_node.children + if len(children) < 2: + return + + receiver = children[0] + # Only handle simple identifier receivers + if receiver.type != "identifier": + return + + receiver_text = receiver.text.decode("utf-8", errors="replace") + + # The parser keeps: self/cls/super calls and uppercase-receiver calls + # The parser keeps: calls handled by typed-var enrichment (but those are + # separate edges -- we check for duplicates via existing-edge set) + if receiver_text in _SELF_NAMES: + return + if receiver_text[:1].isupper(): + return + if is_test_file: + return # test files already track all calls + + # Find the method name identifier + method_node = children[-1] + if method_node.type != "identifier": + return + + row, col = method_node.start_point # 0-indexed + method_name = method_node.text.decode("utf-8", errors="replace") + results.append((row + 1, col, method_name, enclosing_func)) diff --git a/code_review_graph/lang/__init__.py b/code_review_graph/lang/__init__.py new file mode 100644 index 00000000..80b85b76 --- /dev/null +++ b/code_review_graph/lang/__init__.py @@ -0,0 +1,56 @@ +"""Per-language parsing handlers.""" + +from ._base import BaseLanguageHandler +from ._c_cpp import CHandler, CppHandler +from ._csharp import CSharpHandler +from ._dart import DartHandler +from ._go import GoHandler +from ._java import JavaHandler +from ._javascript import JavaScriptHandler, TsxHandler, TypeScriptHandler +from ._kotlin import KotlinHandler +from ._lua import LuaHandler, LuauHandler +from ._perl import PerlHandler +from ._php import PhpHandler +from ._python import PythonHandler +from ._r import RHandler +from ._ruby import RubyHandler +from ._rust import RustHandler +from ._scala import ScalaHandler +from ._solidity import SolidityHandler +from ._swift import SwiftHandler + 
+ALL_HANDLERS: list[BaseLanguageHandler] = [ + GoHandler(), + PythonHandler(), + JavaScriptHandler(), + TypeScriptHandler(), + TsxHandler(), + RustHandler(), + CHandler(), + CppHandler(), + JavaHandler(), + CSharpHandler(), + KotlinHandler(), + ScalaHandler(), + SolidityHandler(), + RubyHandler(), + DartHandler(), + SwiftHandler(), + PhpHandler(), + PerlHandler(), + RHandler(), + LuaHandler(), + LuauHandler(), +] + +__all__ = [ + "BaseLanguageHandler", "ALL_HANDLERS", + "GoHandler", "PythonHandler", + "JavaScriptHandler", "TypeScriptHandler", "TsxHandler", + "RustHandler", "CHandler", "CppHandler", + "JavaHandler", "CSharpHandler", "KotlinHandler", + "ScalaHandler", "SolidityHandler", + "RubyHandler", "DartHandler", + "SwiftHandler", "PhpHandler", "PerlHandler", + "RHandler", "LuaHandler", "LuauHandler", +] diff --git a/code_review_graph/lang/_base.py b/code_review_graph/lang/_base.py new file mode 100644 index 00000000..fb2ddca0 --- /dev/null +++ b/code_review_graph/lang/_base.py @@ -0,0 +1,62 @@ +"""Base class for language-specific parsing handlers.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ..parser import CodeParser, EdgeInfo, NodeInfo + + +class BaseLanguageHandler: + """Override methods where a language differs from default CodeParser logic. + + Methods returning ``NotImplemented`` signal 'use the default code path'. + Subclasses only need to override what they actually customise. 
+ """ + + language: str = "" + class_types: list[str] = [] + function_types: list[str] = [] + import_types: list[str] = [] + call_types: list[str] = [] + builtin_names: frozenset[str] = frozenset() + + def get_name(self, node, kind: str) -> str | None: + return NotImplemented + + def get_bases(self, node, source: bytes) -> list[str]: + return NotImplemented + + def extract_import_targets(self, node, source: bytes) -> list[str]: + return NotImplemented + + def collect_import_names(self, node, file_path: str, import_map: dict[str, str]) -> bool: + """Populate import_map from an import node. Return True if handled.""" + return False + + def resolve_module(self, module: str, caller_file: str) -> str | None: + """Resolve a module path to a file path. Return NotImplemented to fall back.""" + return NotImplemented + + def extract_constructs( + self, + child, + node_type: str, + parser: CodeParser, + source: bytes, + file_path: str, + nodes: list[NodeInfo], + edges: list[EdgeInfo], + enclosing_class: str | None, + enclosing_func: str | None, + import_map: dict[str, str] | None, + defined_names: set[str] | None, + depth: int, + ) -> bool: + """Handle language-specific AST constructs. + + Returns True if the child was fully handled (skip generic dispatch). + Default: returns False (no language-specific handling). 
+ """ + return False diff --git a/code_review_graph/lang/_c_cpp.py b/code_review_graph/lang/_c_cpp.py new file mode 100644 index 00000000..9659db80 --- /dev/null +++ b/code_review_graph/lang/_c_cpp.py @@ -0,0 +1,41 @@ +"""C / C++ language handlers.""" + +from __future__ import annotations + +from ._base import BaseLanguageHandler + + +class _CBase(BaseLanguageHandler): + """Shared handler logic for C and C++.""" + + import_types = ["preproc_include"] + call_types = ["call_expression"] + + def extract_import_targets(self, node, source: bytes) -> list[str]: + imports = [] + for child in node.children: + if child.type in ("system_lib_string", "string_literal"): + val = child.text.decode("utf-8", errors="replace").strip("<>\"") + imports.append(val) + return imports + + +class CHandler(_CBase): + language = "c" + class_types = ["struct_specifier", "type_definition"] + function_types = ["function_definition"] + + +class CppHandler(_CBase): + language = "cpp" + class_types = ["class_specifier", "struct_specifier"] + function_types = ["function_definition"] + + def get_bases(self, node, source: bytes) -> list[str]: + bases = [] + for child in node.children: + if child.type == "base_class_clause": + for sub in child.children: + if sub.type == "type_identifier": + bases.append(sub.text.decode("utf-8", errors="replace")) + return bases diff --git a/code_review_graph/lang/_csharp.py b/code_review_graph/lang/_csharp.py new file mode 100644 index 00000000..0821ecc7 --- /dev/null +++ b/code_review_graph/lang/_csharp.py @@ -0,0 +1,33 @@ +"""C# language handler.""" + +from __future__ import annotations + +from ._base import BaseLanguageHandler + + +class CSharpHandler(BaseLanguageHandler): + language = "csharp" + class_types = [ + "class_declaration", "interface_declaration", + "enum_declaration", "struct_declaration", + ] + function_types = ["method_declaration", "constructor_declaration"] + import_types = ["using_directive"] + call_types = ["invocation_expression", 
"object_creation_expression"] + + def extract_import_targets(self, node, source: bytes) -> list[str]: + text = node.text.decode("utf-8", errors="replace").strip() + parts = text.split() + if len(parts) >= 2: + return [parts[-1].rstrip(";")] + return [] + + def get_bases(self, node, source: bytes) -> list[str]: + bases = [] + for child in node.children: + if child.type in ( + "superclass", "super_interfaces", "extends_type", + "implements_type", "type_identifier", "supertype", + ): + bases.append(child.text.decode("utf-8", errors="replace")) + return bases diff --git a/code_review_graph/lang/_dart.py b/code_review_graph/lang/_dart.py new file mode 100644 index 00000000..8d9b3069 --- /dev/null +++ b/code_review_graph/lang/_dart.py @@ -0,0 +1,65 @@ +"""Dart language handler.""" + +from __future__ import annotations + +from typing import Optional + +from ._base import BaseLanguageHandler + + +class DartHandler(BaseLanguageHandler): + language = "dart" + class_types = ["class_definition", "mixin_declaration", "enum_declaration"] + # function_signature covers both top-level functions and class methods + # (class methods appear as method_signature > function_signature pairs; + # the parser recurses into method_signature generically and then matches + # function_signature inside it). + function_types = ["function_signature"] + # import_or_export wraps library_import > import_specification > configurable_uri + import_types = ["import_or_export"] + call_types: list[str] = [] # Dart uses call_expression from fallback + + def get_name(self, node, kind: str) -> str | None: + # function_signature has a return-type node before the identifier; + # search only for 'identifier' to avoid returning the return type name. 
+ if node.type == "function_signature": + for child in node.children: + if child.type == "identifier": + return child.text.decode("utf-8", errors="replace") + return None + return NotImplemented + + def extract_import_targets(self, node, source: bytes) -> list[str]: + val = self._find_string_literal(node) + if val: + return [val] + return [] + + @staticmethod + def _find_string_literal(node) -> Optional[str]: + if node.type == "string_literal": + return node.text.decode("utf-8", errors="replace").strip("'\"") + for child in node.children: + result = DartHandler._find_string_literal(child) + if result is not None: + return result + return None + + def get_bases(self, node, source: bytes) -> list[str]: + bases = [] + for child in node.children: + if child.type == "superclass": + for sub in child.children: + if sub.type == "type_identifier": + bases.append(sub.text.decode("utf-8", errors="replace")) + elif sub.type == "mixins": + for m in sub.children: + if m.type == "type_identifier": + bases.append( + m.text.decode("utf-8", errors="replace"), + ) + elif child.type == "interfaces": + for sub in child.children: + if sub.type == "type_identifier": + bases.append(sub.text.decode("utf-8", errors="replace")) + return bases diff --git a/code_review_graph/lang/_go.py b/code_review_graph/lang/_go.py new file mode 100644 index 00000000..048f1477 --- /dev/null +++ b/code_review_graph/lang/_go.py @@ -0,0 +1,73 @@ +"""Go language handler.""" + +from __future__ import annotations + +from ._base import BaseLanguageHandler + + +class GoHandler(BaseLanguageHandler): + language = "go" + class_types = ["type_declaration"] + function_types = ["function_declaration", "method_declaration"] + import_types = ["import_declaration"] + call_types = ["call_expression"] + builtin_names = frozenset({ + "len", "cap", "make", "new", "delete", "append", "copy", + "close", "panic", "recover", "print", "println", + }) + + def get_name(self, node, kind: str) -> str | None: + # Go type_declaration 
wraps type_spec which holds the identifier + if node.type == "type_declaration": + for child in node.children: + if child.type == "type_spec": + for sub in child.children: + if sub.type in ("identifier", "name", "type_identifier"): + return sub.text.decode("utf-8", errors="replace") + return None + return NotImplemented # fall back to default for function_declaration etc. + + def get_bases(self, node, source: bytes) -> list[str]: + # Embedded structs / interface composition + # Embedded fields are field_declaration nodes with only a type_identifier + # (no field name), e.g. `type Child struct { Parent }` + bases = [] + for child in node.children: + if child.type == "type_spec": + for sub in child.children: + if sub.type in ("struct_type", "interface_type"): + for field_node in sub.children: + if field_node.type == "field_declaration_list": + for f in field_node.children: + if f.type == "field_declaration": + children = [ + c for c in f.children + if c.type not in ("comment",) + ] + if ( + len(children) == 1 + and children[0].type == "type_identifier" + ): + bases.append( + children[0].text.decode( + "utf-8", errors="replace", + ) + ) + return bases + + def extract_import_targets(self, node, source: bytes) -> list[str]: + imports = [] + for child in node.children: + if child.type == "import_spec_list": + for spec in child.children: + if spec.type == "import_spec": + for s in spec.children: + if s.type == "interpreted_string_literal": + val = s.text.decode("utf-8", errors="replace") + imports.append(val.strip('"')) + elif child.type == "import_spec": + for s in child.children: + if s.type == "interpreted_string_literal": + val = s.text.decode("utf-8", errors="replace") + imports.append(val.strip('"')) + return imports diff --git a/code_review_graph/lang/_java.py b/code_review_graph/lang/_java.py new file mode 100644 index 00000000..08849574 --- /dev/null +++ b/code_review_graph/lang/_java.py @@ -0,0 +1,30 @@ +"""Java language handler.""" + +from __future__ import 
annotations + +from ._base import BaseLanguageHandler + + +class JavaHandler(BaseLanguageHandler): + language = "java" + class_types = ["class_declaration", "interface_declaration", "enum_declaration"] + function_types = ["method_declaration", "constructor_declaration"] + import_types = ["import_declaration"] + call_types = ["method_invocation", "object_creation_expression"] + + def extract_import_targets(self, node, source: bytes) -> list[str]: + text = node.text.decode("utf-8", errors="replace").strip() + parts = text.split() + if len(parts) >= 2: + return [parts[-1].rstrip(";")] + return [] + + def get_bases(self, node, source: bytes) -> list[str]: + bases = [] + for child in node.children: + if child.type in ( + "superclass", "super_interfaces", "extends_type", + "implements_type", "type_identifier", "supertype", + ): + bases.append(child.text.decode("utf-8", errors="replace")) + return bases diff --git a/code_review_graph/lang/_javascript.py b/code_review_graph/lang/_javascript.py new file mode 100644 index 00000000..5e565f81 --- /dev/null +++ b/code_review_graph/lang/_javascript.py @@ -0,0 +1,304 @@ +"""JavaScript / TypeScript / TSX language handler.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +from ..parser import EdgeInfo, NodeInfo, _is_test_function +from ._base import BaseLanguageHandler + +if TYPE_CHECKING: + from ..parser import CodeParser + + +class _JsTsBase(BaseLanguageHandler): + """Shared handler logic for JS, TS, and TSX.""" + + class_types = ["class_declaration", "class"] + function_types = ["function_declaration", "method_definition", "arrow_function"] + import_types = ["import_statement"] + # No builtin_names -- JS/TS builtins are not filtered + + _JS_FUNC_VALUE_TYPES = frozenset( + {"arrow_function", "function_expression", "function"}, + ) + + def get_bases(self, node, source: bytes) -> list[str]: + bases = [] + for child in node.children: + if child.type in ("extends_clause", "implements_clause"): 
+ for sub in child.children: + if sub.type in ("identifier", "type_identifier", "nested_identifier"): + bases.append(sub.text.decode("utf-8", errors="replace")) + return bases + + def extract_import_targets(self, node, source: bytes) -> list[str]: + imports = [] + for child in node.children: + if child.type == "string": + val = child.text.decode("utf-8", errors="replace").strip("'\"") + imports.append(val) + return imports + + def extract_constructs( + self, + child, + node_type: str, + parser: CodeParser, + source: bytes, + file_path: str, + nodes: list[NodeInfo], + edges: list[EdgeInfo], + enclosing_class: str | None, + enclosing_func: str | None, + import_map: dict[str, str] | None, + defined_names: set[str] | None, + depth: int, + ) -> bool: + # --- Variable-assigned functions (const foo = () => {}) --- + if node_type in ("lexical_declaration", "variable_declaration"): + if self._extract_var_functions( + child, source, parser, file_path, nodes, edges, + enclosing_class, enclosing_func, + import_map, defined_names, depth, + ): + return True + + # --- Class field arrow functions (handler = () => {}) --- + if node_type == "public_field_definition": + if self._extract_field_function( + child, source, parser, file_path, nodes, edges, + enclosing_class, enclosing_func, + import_map, defined_names, depth, + ): + return True + + # --- Re-exports: export { X } from './mod', export * from './mod' --- + if node_type == "export_statement": + self._extract_reexport_edges(child, parser, file_path, edges) + # Don't return True -- export_statement may also contain definitions + return False + + return False + + # ------------------------------------------------------------------ + # Extraction helpers + # ------------------------------------------------------------------ + + def _extract_var_functions( + self, + child, + source: bytes, + parser: CodeParser, + file_path: str, + nodes: list[NodeInfo], + edges: list[EdgeInfo], + enclosing_class: Optional[str], + enclosing_func: 
Optional[str], + import_map: Optional[dict[str, str]], + defined_names: Optional[set[str]], + _depth: int, + ) -> bool: + """Handle JS/TS variable declarations that assign functions. + + Patterns handled: + const foo = () => {} + let bar = function() {} + export const baz = (x: number): string => x.toString() + + Returns True if at least one function was extracted from the + declaration, so the caller can skip generic recursion. + """ + language = self.language + handled = False + for declarator in child.children: + if declarator.type != "variable_declarator": + continue + + # Find identifier and function value + var_name = None + func_node = None + for sub in declarator.children: + if sub.type == "identifier" and var_name is None: + var_name = sub.text.decode("utf-8", errors="replace") + elif sub.type in self._JS_FUNC_VALUE_TYPES: + func_node = sub + + if not var_name or not func_node: + continue + + is_test = _is_test_function(var_name, file_path) + kind = "Test" if is_test else "Function" + qualified = parser._qualify(var_name, file_path, enclosing_class) + params = parser._get_params(func_node, language, source) + ret_type = parser._get_return_type(func_node, language, source) + + nodes.append(NodeInfo( + kind=kind, + name=var_name, + file_path=file_path, + line_start=child.start_point[0] + 1, + line_end=child.end_point[0] + 1, + language=language, + parent_name=enclosing_class, + params=params, + return_type=ret_type, + is_test=is_test, + )) + container = ( + parser._qualify(enclosing_class, file_path, None) + if enclosing_class else file_path + ) + edges.append(EdgeInfo( + kind="CONTAINS", + source=container, + target=qualified, + file_path=file_path, + line=child.start_point[0] + 1, + )) + + # Recurse into the function body for calls + parser._extract_from_tree( + func_node, source, language, file_path, nodes, edges, + enclosing_class=enclosing_class, + enclosing_func=var_name, + import_map=import_map, + defined_names=defined_names, + _depth=_depth + 1, + ) 
+ handled = True + + if not handled: + # Not a function assignment -- let generic recursion handle it + return False + return True + + def _extract_field_function( + self, + child, + source: bytes, + parser: CodeParser, + file_path: str, + nodes: list[NodeInfo], + edges: list[EdgeInfo], + enclosing_class: Optional[str], + enclosing_func: Optional[str], + import_map: Optional[dict[str, str]], + defined_names: Optional[set[str]], + _depth: int, + ) -> bool: + """Handle class field arrow functions: handler = (e) => { ... }""" + language = self.language + prop_name = None + func_node = None + for sub in child.children: + if sub.type == "property_identifier" and prop_name is None: + prop_name = sub.text.decode("utf-8", errors="replace") + elif sub.type in self._JS_FUNC_VALUE_TYPES: + func_node = sub + + if not prop_name or not func_node: + return False + + is_test = _is_test_function(prop_name, file_path) + kind = "Test" if is_test else "Function" + qualified = parser._qualify(prop_name, file_path, enclosing_class) + params = parser._get_params(func_node, language, source) + + nodes.append(NodeInfo( + kind=kind, + name=prop_name, + file_path=file_path, + line_start=child.start_point[0] + 1, + line_end=child.end_point[0] + 1, + language=language, + parent_name=enclosing_class, + params=params, + is_test=is_test, + )) + container = ( + parser._qualify(enclosing_class, file_path, None) + if enclosing_class else file_path + ) + edges.append(EdgeInfo( + kind="CONTAINS", + source=container, + target=qualified, + file_path=file_path, + line=child.start_point[0] + 1, + )) + + parser._extract_from_tree( + func_node, source, language, file_path, nodes, edges, + enclosing_class=enclosing_class, + enclosing_func=prop_name, + import_map=import_map, + defined_names=defined_names, + _depth=_depth + 1, + ) + return True + + def _extract_reexport_edges( + self, + node, + parser: CodeParser, + file_path: str, + edges: list[EdgeInfo], + ) -> None: + """Emit IMPORTS_FROM edges for JS/TS 
re-exports with ``from`` clause.""" + language = self.language + # Must have a 'from' string + module = None + for child in node.children: + if child.type == "string": + module = child.text.decode("utf-8", errors="replace").strip("'\"") + if not module: + return + resolved = parser._resolve_module_to_file(module, file_path, language) + target = resolved if resolved else module + # File-level IMPORTS_FROM + edges.append(EdgeInfo( + kind="IMPORTS_FROM", + source=file_path, + target=target, + file_path=file_path, + line=node.start_point[0] + 1, + )) + # Per-symbol edges for named re-exports + if resolved: + for child in node.children: + if child.type == "export_clause": + for spec in child.children: + if spec.type == "export_specifier": + names = [ + s.text.decode("utf-8", errors="replace") + for s in spec.children + if s.type == "identifier" + ] + if names: + edges.append(EdgeInfo( + kind="IMPORTS_FROM", + source=file_path, + target=f"{resolved}::{names[0]}", + file_path=file_path, + line=node.start_point[0] + 1, + )) + + +class JavaScriptHandler(_JsTsBase): + language = "javascript" + call_types = [ + "call_expression", "new_expression", + ] + + +class TypeScriptHandler(_JsTsBase): + language = "typescript" + call_types = ["call_expression", "new_expression"] + + +class TsxHandler(_JsTsBase): + language = "tsx" + call_types = [ + "call_expression", "new_expression", + ] diff --git a/code_review_graph/lang/_kotlin.py b/code_review_graph/lang/_kotlin.py new file mode 100644 index 00000000..bb972156 --- /dev/null +++ b/code_review_graph/lang/_kotlin.py @@ -0,0 +1,24 @@ +"""Kotlin language handler.""" + +from __future__ import annotations + +from ._base import BaseLanguageHandler + + +class KotlinHandler(BaseLanguageHandler): + language = "kotlin" + class_types = ["class_declaration", "object_declaration"] + function_types = ["function_declaration"] + import_types = ["import_header"] + call_types = ["call_expression"] + + def get_bases(self, node, source: bytes) -> 
list[str]: + bases = [] + for child in node.children: + if child.type in ( + "superclass", "super_interfaces", "extends_type", + "implements_type", "type_identifier", "supertype", + "delegation_specifier", + ): + bases.append(child.text.decode("utf-8", errors="replace")) + return bases diff --git a/code_review_graph/lang/_lua.py b/code_review_graph/lang/_lua.py new file mode 100644 index 00000000..2df58079 --- /dev/null +++ b/code_review_graph/lang/_lua.py @@ -0,0 +1,314 @@ +"""Lua language handler.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +from ..parser import EdgeInfo, NodeInfo, _is_test_function +from ._base import BaseLanguageHandler + +if TYPE_CHECKING: + from ..parser import CodeParser + + +class LuaHandler(BaseLanguageHandler): + language = "lua" + class_types: list[str] = [] # Lua has no class keyword; table-based OOP + function_types = ["function_declaration"] + import_types: list[str] = [] # require() handled via extract_constructs + call_types = ["function_call"] + + def get_name(self, node, kind: str) -> str | None: + # function_declaration names may be dot_index_expression or + # method_index_expression (e.g. function Animal.new() / Animal:speak()). + # Return only the method name; the table name is used as parent_name + # in extract_constructs. 
+ if node.type == "function_declaration": + for child in node.children: + if child.type in ("dot_index_expression", "method_index_expression"): + for sub in reversed(child.children): + if sub.type == "identifier": + return sub.text.decode("utf-8", errors="replace") + return None + return NotImplemented + + def extract_constructs( + self, + child, + node_type: str, + parser: CodeParser, + source: bytes, + file_path: str, + nodes: list[NodeInfo], + edges: list[EdgeInfo], + enclosing_class: str | None, + enclosing_func: str | None, + import_map: dict[str, str] | None, + defined_names: set[str] | None, + depth: int, + ) -> bool: + """Handle Lua-specific AST constructs. + + Handles: + - variable_declaration with require() -> IMPORTS_FROM edge + - variable_declaration with function_definition -> named Function node + - function_declaration with dot/method name -> Function with table parent + - top-level require() call -> IMPORTS_FROM edge + """ + if node_type == "variable_declaration": + return self._handle_variable_declaration( + child, source, parser, file_path, nodes, edges, + enclosing_class, enclosing_func, + import_map, defined_names, depth, + ) + + if node_type == "function_declaration": + return self._handle_table_function( + child, source, parser, file_path, nodes, edges, + enclosing_class, enclosing_func, + import_map, defined_names, depth, + ) + + # Top-level require() not wrapped in variable_declaration + if node_type == "function_call" and not enclosing_func: + req_target = self._get_require_target(child) + if req_target is not None: + resolved = parser._resolve_module_to_file( + req_target, file_path, self.language, + ) + edges.append(EdgeInfo( + kind="IMPORTS_FROM", + source=file_path, + target=resolved if resolved else req_target, + file_path=file_path, + line=child.start_point[0] + 1, + )) + return True + + return False + + # ------------------------------------------------------------------ + # Lua-specific helpers + # 
------------------------------------------------------------------ + + @staticmethod + def _get_require_target(call_node) -> Optional[str]: + """Extract the module path from a Lua require() call. + + Returns the string argument or None if this is not a require() call. + """ + first_child = call_node.children[0] if call_node.children else None + if ( + not first_child + or first_child.type != "identifier" + or first_child.text != b"require" + ): + return None + for child in call_node.children: + if child.type == "arguments": + for arg in child.children: + if arg.type == "string": + for sub in arg.children: + if sub.type == "string_content": + return sub.text.decode( + "utf-8", errors="replace", + ) + raw = arg.text.decode("utf-8", errors="replace") + return raw.strip("'\"") + return None + + def _handle_variable_declaration( + self, + child, + source: bytes, + parser: CodeParser, + file_path: str, + nodes: list[NodeInfo], + edges: list[EdgeInfo], + enclosing_class: Optional[str], + enclosing_func: Optional[str], + import_map: Optional[dict[str, str]], + defined_names: Optional[set[str]], + depth: int, + ) -> bool: + """Handle Lua variable declarations that contain require() or + anonymous function definitions. + + ``local json = require("json")`` -> IMPORTS_FROM edge + ``local fn = function(x) ... 
end`` -> Function node named "fn" + """ + language = self.language + + # Walk into: variable_declaration > assignment_statement + assign = None + for sub in child.children: + if sub.type == "assignment_statement": + assign = sub + break + if not assign: + return False + + # Get variable name from variable_list + var_name = None + for sub in assign.children: + if sub.type == "variable_list": + for ident in sub.children: + if ident.type == "identifier": + var_name = ident.text.decode("utf-8", errors="replace") + break + break + + # Get value from expression_list + expr_list = None + for sub in assign.children: + if sub.type == "expression_list": + expr_list = sub + break + + if not var_name or not expr_list: + return False + + # Check for require() call + for expr in expr_list.children: + if expr.type == "function_call": + req_target = self._get_require_target(expr) + if req_target is not None: + resolved = parser._resolve_module_to_file( + req_target, file_path, language, + ) + edges.append(EdgeInfo( + kind="IMPORTS_FROM", + source=file_path, + target=resolved if resolved else req_target, + file_path=file_path, + line=child.start_point[0] + 1, + )) + return True + + # Check for anonymous function: local foo = function(...) 
end + for expr in expr_list.children: + if expr.type == "function_definition": + is_test = _is_test_function(var_name, file_path) + kind = "Test" if is_test else "Function" + qualified = parser._qualify(var_name, file_path, enclosing_class) + params = parser._get_params(expr, language, source) + + nodes.append(NodeInfo( + kind=kind, + name=var_name, + file_path=file_path, + line_start=child.start_point[0] + 1, + line_end=child.end_point[0] + 1, + language=language, + parent_name=enclosing_class, + params=params, + is_test=is_test, + )) + container = ( + parser._qualify(enclosing_class, file_path, None) + if enclosing_class else file_path + ) + edges.append(EdgeInfo( + kind="CONTAINS", + source=container, + target=qualified, + file_path=file_path, + line=child.start_point[0] + 1, + )) + # Recurse into the function body for calls + parser._extract_from_tree( + expr, source, language, file_path, nodes, edges, + enclosing_class=enclosing_class, + enclosing_func=var_name, + import_map=import_map, + defined_names=defined_names, + _depth=depth + 1, + ) + return True + + return False + + def _handle_table_function( + self, + child, + source: bytes, + parser: CodeParser, + file_path: str, + nodes: list[NodeInfo], + edges: list[EdgeInfo], + enclosing_class: Optional[str], + enclosing_func: Optional[str], + import_map: Optional[dict[str, str]], + defined_names: Optional[set[str]], + depth: int, + ) -> bool: + """Handle Lua function declarations with table-qualified names. + + ``function Animal.new(name)`` -> Function "new", parent "Animal" + ``function Animal:speak()`` -> Function "speak", parent "Animal" + + Plain ``function foo()`` is NOT handled here (returns False). 
+ """ + language = self.language + table_name = None + method_name = None + + for sub in child.children: + if sub.type in ("dot_index_expression", "method_index_expression"): + identifiers = [ + c for c in sub.children if c.type == "identifier" + ] + if len(identifiers) >= 2: + table_name = identifiers[0].text.decode( + "utf-8", errors="replace", + ) + method_name = identifiers[-1].text.decode( + "utf-8", errors="replace", + ) + break + + if not table_name or not method_name: + return False + + is_test = _is_test_function(method_name, file_path) + kind = "Test" if is_test else "Function" + qualified = parser._qualify(method_name, file_path, table_name) + params = parser._get_params(child, language, source) + + nodes.append(NodeInfo( + kind=kind, + name=method_name, + file_path=file_path, + line_start=child.start_point[0] + 1, + line_end=child.end_point[0] + 1, + language=language, + parent_name=table_name, + params=params, + is_test=is_test, + )) + # CONTAINS: table -> method + container = parser._qualify(table_name, file_path, None) + edges.append(EdgeInfo( + kind="CONTAINS", + source=container, + target=qualified, + file_path=file_path, + line=child.start_point[0] + 1, + )) + # Recurse into function body for calls + parser._extract_from_tree( + child, source, language, file_path, nodes, edges, + enclosing_class=table_name, + enclosing_func=method_name, + import_map=import_map, + defined_names=defined_names, + _depth=depth + 1, + ) + return True + + +class LuauHandler(LuaHandler): + """Roblox Luau (.luau) handler -- reuses the Lua handler.""" + + language = "luau" + class_types = ["type_definition"] diff --git a/code_review_graph/lang/_perl.py b/code_review_graph/lang/_perl.py new file mode 100644 index 00000000..fba72cf6 --- /dev/null +++ b/code_review_graph/lang/_perl.py @@ -0,0 +1,24 @@ +"""Perl language handler.""" + +from __future__ import annotations + +from ._base import BaseLanguageHandler + + +class PerlHandler(BaseLanguageHandler): + language = "perl" + 
class_types = ["package_statement", "class_statement", "role_statement"] + function_types = ["subroutine_declaration_statement", "method_declaration_statement"] + import_types = ["use_statement", "require_expression"] + call_types = [ + "function_call_expression", "method_call_expression", + "ambiguous_function_call_expression", + ] + + def get_name(self, node, kind: str) -> str | None: + for child in node.children: + if child.type == "bareword": + return child.text.decode("utf-8", errors="replace") + if child.type == "package" and child.text != b"package": + return child.text.decode("utf-8", errors="replace") + return NotImplemented diff --git a/code_review_graph/lang/_php.py b/code_review_graph/lang/_php.py new file mode 100644 index 00000000..f299835f --- /dev/null +++ b/code_review_graph/lang/_php.py @@ -0,0 +1,13 @@ +"""PHP language handler.""" + +from __future__ import annotations + +from ._base import BaseLanguageHandler + + +class PhpHandler(BaseLanguageHandler): + language = "php" + class_types = ["class_declaration", "interface_declaration"] + function_types = ["function_definition", "method_declaration"] + import_types = ["namespace_use_declaration"] + call_types = ["function_call_expression", "member_call_expression"] diff --git a/code_review_graph/lang/_python.py b/code_review_graph/lang/_python.py new file mode 100644 index 00000000..f836aeef --- /dev/null +++ b/code_review_graph/lang/_python.py @@ -0,0 +1,109 @@ +"""Python language handler.""" + +from __future__ import annotations + +from pathlib import Path + +from ._base import BaseLanguageHandler + + +class PythonHandler(BaseLanguageHandler): + language = "python" + class_types = ["class_definition"] + function_types = ["function_definition"] + import_types = ["import_statement", "import_from_statement"] + call_types = ["call"] + builtin_names = frozenset({ + "len", "str", "int", "float", "bool", "list", "dict", "set", "tuple", + "print", "range", "enumerate", "zip", "map", "filter", "sorted", + 
"reversed", "isinstance", "issubclass", "type", "id", "hash", + "hasattr", "getattr", "setattr", "delattr", "callable", + "repr", "abs", "min", "max", "sum", "round", "pow", "divmod", + "iter", "next", "open", "super", "property", "staticmethod", + "classmethod", "vars", "dir", "help", "input", "format", + "bytes", "bytearray", "memoryview", "frozenset", "complex", + "chr", "ord", "hex", "oct", "bin", "any", "all", + }) + + def get_bases(self, node, source: bytes) -> list[str]: + bases = [] + for child in node.children: + if child.type == "argument_list": + for arg in child.children: + if arg.type in ("identifier", "attribute"): + bases.append(arg.text.decode("utf-8", errors="replace")) + return bases + + def extract_import_targets(self, node, source: bytes) -> list[str]: + imports = [] + if node.type == "import_from_statement": + for child in node.children: + if child.type == "dotted_name": + imports.append(child.text.decode("utf-8", errors="replace")) + break + else: + for child in node.children: + if child.type == "dotted_name": + imports.append(child.text.decode("utf-8", errors="replace")) + return imports + + def collect_import_names( + self, node, file_path: str, import_map: dict[str, str], + ) -> bool: + if node.type == "import_from_statement": + # from X.Y import A, B -> {A: X.Y, B: X.Y} + module = None + seen_import_keyword = False + for child in node.children: + if child.type == "dotted_name" and not seen_import_keyword: + module = child.text.decode("utf-8", errors="replace") + elif child.type == "import": + seen_import_keyword = True + elif seen_import_keyword and module: + if child.type in ("identifier", "dotted_name"): + name = child.text.decode("utf-8", errors="replace") + import_map[name] = module + elif child.type == "aliased_import": + # from X import A as B -> {B: X} + names = [ + sub.text.decode("utf-8", errors="replace") + for sub in child.children + if sub.type in ("identifier", "dotted_name") + ] + if names: + import_map[names[-1]] = module + 
elif node.type == "import_statement": + # import json -> {json: json} + # import os.path -> {os: os.path} + # import X as Y -> {Y: X} + for child in node.children: + if child.type in ("dotted_name", "identifier"): + mod = child.text.decode("utf-8", errors="replace") + top_level = mod.split(".")[0] + import_map[top_level] = mod + elif child.type == "aliased_import": + names = [ + sub.text.decode("utf-8", errors="replace") + for sub in child.children + if sub.type in ("identifier", "dotted_name") + ] + if len(names) >= 2: + import_map[names[-1]] = names[0] + else: + return False + return True + + def resolve_module(self, module: str, caller_file: str) -> str | None: + caller_dir = Path(caller_file).parent + rel_path = module.replace(".", "/") + candidates = [rel_path + ".py", rel_path + "/__init__.py"] + current = caller_dir + while True: + for candidate in candidates: + target = current / candidate + if target.is_file(): + return str(target.resolve()) + if current == current.parent: + break + current = current.parent + return None diff --git a/code_review_graph/lang/_r.py b/code_review_graph/lang/_r.py new file mode 100644 index 00000000..a15ad973 --- /dev/null +++ b/code_review_graph/lang/_r.py @@ -0,0 +1,339 @@ +"""R language handler.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +from ..parser import EdgeInfo, NodeInfo, _is_test_function +from ._base import BaseLanguageHandler + +if TYPE_CHECKING: + from ..parser import CodeParser + + +class RHandler(BaseLanguageHandler): + language = "r" + class_types: list[str] = [] # Classes detected via call pattern-matching + function_types = ["function_definition"] + import_types = ["call"] # library(), require(), source() -- filtered downstream + call_types = ["call"] + + def extract_import_targets(self, node, source: bytes) -> list[str]: + """Extract import targets from R library/require/source calls.""" + imports = [] + func_name = self._call_func_name(node) + if func_name in 
("library", "require", "source"): + for _name, value in self._iter_args(node): + if value.type == "identifier": + imports.append(value.text.decode("utf-8", errors="replace")) + elif value.type == "string": + val = self._first_string_arg(node) + if val: + imports.append(val) + break # Only first argument matters + return imports + + def extract_constructs( + self, + child, + node_type: str, + parser: CodeParser, + source: bytes, + file_path: str, + nodes: list[NodeInfo], + edges: list[EdgeInfo], + enclosing_class: str | None, + enclosing_func: str | None, + import_map: dict[str, str] | None, + defined_names: set[str] | None, + depth: int, + ) -> bool: + if node_type == "binary_operator": + if self._handle_binary_operator( + child, source, parser, file_path, nodes, edges, + enclosing_class, enclosing_func, + import_map, defined_names, + ): + return True + + if node_type == "call": + if self._handle_call( + child, source, parser, file_path, nodes, edges, + enclosing_class, enclosing_func, + import_map, defined_names, + ): + return True + + return False + + # ------------------------------------------------------------------ + # R-specific helpers + # ------------------------------------------------------------------ + + @staticmethod + def _call_func_name(call_node) -> Optional[str]: + """Extract the function name from an R call node.""" + for child in call_node.children: + if child.type in ("identifier", "namespace_operator"): + return child.text.decode("utf-8", errors="replace") + return None + + @staticmethod + def _first_string_arg(call_node) -> Optional[str]: + """Extract the first string argument value from an R call node.""" + for child in call_node.children: + if child.type == "arguments": + for arg in child.children: + if arg.type == "argument": + for sub in arg.children: + if sub.type == "string": + for sc in sub.children: + if sc.type == "string_content": + return sc.text.decode("utf-8", errors="replace") + break + return None + + @staticmethod + def 
_iter_args(call_node): + """Yield (name_str, value_node) pairs from an R call's arguments.""" + for child in call_node.children: + if child.type != "arguments": + continue + for arg in child.children: + if arg.type != "argument": + continue + has_eq = any(sub.type == "=" for sub in arg.children) + if has_eq: + name = None + value = None + for sub in arg.children: + if sub.type == "identifier" and name is None: + name = sub.text.decode("utf-8", errors="replace") + elif sub.type not in ("=", ","): + value = sub + yield (name, value) + else: + for sub in arg.children: + if sub.type not in (",",): + yield (None, sub) + break + break + + @classmethod + def _find_named_arg(cls, call_node, arg_name: str): + """Find a named argument's value node in an R call.""" + for name, value in cls._iter_args(call_node): + if name == arg_name: + return value + return None + + # ------------------------------------------------------------------ + # Extraction methods + # ------------------------------------------------------------------ + + def _handle_binary_operator( + self, node, source: bytes, parser: CodeParser, file_path: str, + nodes: list[NodeInfo], edges: list[EdgeInfo], + enclosing_class: Optional[str], enclosing_func: Optional[str], + import_map: Optional[dict[str, str]], + defined_names: Optional[set[str]], + ) -> bool: + """Handle R binary_operator nodes: name <- function(...) { ... 
}.""" + language = self.language + children = node.children + if len(children) < 3: + return False + + left, op, right = children[0], children[1], children[2] + if op.type not in ("<-", "="): + return False + + if right.type == "function_definition" and left.type == "identifier": + name = left.text.decode("utf-8", errors="replace") + is_test = _is_test_function(name, file_path) + kind = "Test" if is_test else "Function" + qualified = parser._qualify(name, file_path, enclosing_class) + params = parser._get_params(right, language, source) + + nodes.append(NodeInfo( + kind=kind, + name=name, + file_path=file_path, + line_start=right.start_point[0] + 1, + line_end=right.end_point[0] + 1, + language=language, + parent_name=enclosing_class, + params=params, + is_test=is_test, + )) + + container = ( + parser._qualify(enclosing_class, file_path, None) + if enclosing_class else file_path + ) + edges.append(EdgeInfo( + kind="CONTAINS", + source=container, + target=qualified, + file_path=file_path, + line=right.start_point[0] + 1, + )) + + parser._extract_from_tree( + right, source, language, file_path, nodes, edges, + enclosing_class=enclosing_class, enclosing_func=name, + import_map=import_map, defined_names=defined_names, + ) + return True + + if right.type == "call" and left.type == "identifier": + call_func = self._call_func_name(right) + if call_func in ("setRefClass", "setClass", "setGeneric"): + assign_name = left.text.decode("utf-8", errors="replace") + return self._handle_class_call( + right, source, parser, file_path, nodes, edges, + enclosing_class, enclosing_func, + import_map, defined_names, + assign_name=assign_name, + ) + + return False + + def _handle_call( + self, node, source: bytes, parser: CodeParser, file_path: str, + nodes: list[NodeInfo], edges: list[EdgeInfo], + enclosing_class: Optional[str], enclosing_func: Optional[str], + import_map: Optional[dict[str, str]], + defined_names: Optional[set[str]], + ) -> bool: + """Handle R call nodes for imports 
and class definitions.""" + language = self.language + func_name = self._call_func_name(node) + if not func_name: + return False + + if func_name in ("library", "require", "source"): + imports = parser._extract_import(node, language, source) + for imp_target in imports: + edges.append(EdgeInfo( + kind="IMPORTS_FROM", + source=file_path, + target=imp_target, + file_path=file_path, + line=node.start_point[0] + 1, + )) + return True + + if func_name in ("setRefClass", "setClass", "setGeneric"): + return self._handle_class_call( + node, source, parser, file_path, nodes, edges, + enclosing_class, enclosing_func, + import_map, defined_names, + ) + + if enclosing_func: + call_name = parser._get_call_name(node, language, source) + if call_name: + caller = parser._qualify(enclosing_func, file_path, enclosing_class) + target = parser._resolve_call_target( + call_name, file_path, language, + import_map or {}, defined_names or set(), + ) + edges.append(EdgeInfo( + kind="CALLS", + source=caller, + target=target, + file_path=file_path, + line=node.start_point[0] + 1, + )) + + parser._extract_from_tree( + node, source, language, file_path, nodes, edges, + enclosing_class=enclosing_class, enclosing_func=enclosing_func, + import_map=import_map, defined_names=defined_names, + ) + return True + + def _handle_class_call( + self, node, source: bytes, parser: CodeParser, file_path: str, + nodes: list[NodeInfo], edges: list[EdgeInfo], + enclosing_class: Optional[str], enclosing_func: Optional[str], + import_map: Optional[dict[str, str]], + defined_names: Optional[set[str]], + assign_name: Optional[str] = None, + ) -> bool: + """Handle setClass/setRefClass/setGeneric calls -> Class nodes.""" + language = self.language + class_name = self._first_string_arg(node) or assign_name + if not class_name: + return False + + qualified = parser._qualify(class_name, file_path, enclosing_class) + nodes.append(NodeInfo( + kind="Class", + name=class_name, + file_path=file_path, + 
line_start=node.start_point[0] + 1, + line_end=node.end_point[0] + 1, + language=language, + parent_name=enclosing_class, + )) + edges.append(EdgeInfo( + kind="CONTAINS", + source=file_path, + target=qualified, + file_path=file_path, + line=node.start_point[0] + 1, + )) + + methods_list = self._find_named_arg(node, "methods") + if methods_list is not None: + self._extract_methods( + methods_list, source, parser, file_path, + nodes, edges, class_name, + import_map, defined_names, + ) + + return True + + def _extract_methods( + self, list_call, source: bytes, parser: CodeParser, file_path: str, + nodes: list[NodeInfo], edges: list[EdgeInfo], + class_name: str, + import_map: Optional[dict[str, str]], + defined_names: Optional[set[str]], + ) -> None: + """Extract methods from a setRefClass methods = list(...) call.""" + language = self.language + for method_name, func_def in self._iter_args(list_call): + if not method_name or func_def is None: + continue + if func_def.type != "function_definition": + continue + + qualified = parser._qualify(method_name, file_path, class_name) + params = parser._get_params(func_def, language, source) + nodes.append(NodeInfo( + kind="Function", + name=method_name, + file_path=file_path, + line_start=func_def.start_point[0] + 1, + line_end=func_def.end_point[0] + 1, + language=language, + parent_name=class_name, + params=params, + )) + edges.append(EdgeInfo( + kind="CONTAINS", + source=parser._qualify(class_name, file_path, None), + target=qualified, + file_path=file_path, + line=func_def.start_point[0] + 1, + )) + parser._extract_from_tree( + func_def, source, language, file_path, nodes, edges, + enclosing_class=class_name, + enclosing_func=method_name, + import_map=import_map, + defined_names=defined_names, + ) diff --git a/code_review_graph/lang/_ruby.py b/code_review_graph/lang/_ruby.py new file mode 100644 index 00000000..5a6b11fd --- /dev/null +++ b/code_review_graph/lang/_ruby.py @@ -0,0 +1,23 @@ +"""Ruby language handler.""" + 
+from __future__ import annotations
+
+import re
+
+from ._base import BaseLanguageHandler
+
+
+class RubyHandler(BaseLanguageHandler):
+    language = "ruby"
+    class_types = ["class", "module"]
+    function_types = ["method", "singleton_method"]
+    import_types = ["call"]  # require / require_relative
+    call_types = ["call", "method_call"]
+
+    def extract_import_targets(self, node, source: bytes) -> list[str]:
+        text = node.text.decode("utf-8", errors="replace").strip()
+        if "require" in text:
+            match = re.search(r"""['"](.*?)['"]""", text)
+            if match:
+                return [match.group(1)]
+        return []
diff --git a/code_review_graph/lang/_rust.py b/code_review_graph/lang/_rust.py
new file mode 100644
index 00000000..839006ee
--- /dev/null
+++ b/code_review_graph/lang/_rust.py
@@ -0,0 +1,25 @@
+"""Rust language handler."""
+
+from __future__ import annotations
+
+from ._base import BaseLanguageHandler
+
+
+class RustHandler(BaseLanguageHandler):
+    language = "rust"
+    class_types = ["struct_item", "enum_item", "impl_item"]
+    function_types = ["function_item"]
+    import_types = ["use_declaration"]
+    call_types = ["call_expression", "macro_invocation"]
+    builtin_names = frozenset({
+        "println", "eprintln", "format", "vec", "panic", "todo",
+        "unimplemented", "unreachable", "assert", "assert_eq", "assert_ne",
+        "dbg", "cfg",
+    })
+
+    def extract_import_targets(self, node, source: bytes) -> list[str]:
+        text = node.text.decode("utf-8", errors="replace").strip()
+        # Strip only the leading keywords: a global str.replace("use ", "")
+        # would also mangle path segments containing "use " and leaves
+        # "pub" behind for `pub use` re-exports.
+        text = text.removeprefix("pub ").removeprefix("use ")
+        return [text.rstrip(";").strip()]
diff --git a/code_review_graph/lang/_scala.py b/code_review_graph/lang/_scala.py
new file mode 100644
index 00000000..e5159d1b
--- /dev/null
+++ b/code_review_graph/lang/_scala.py
@@ -0,0 +1,54 @@
+"""Scala language handler."""
+
+from __future__ import annotations
+
+from ._base import BaseLanguageHandler
+
+
+class ScalaHandler(BaseLanguageHandler):
+    language = "scala"
+    class_types = [
+        "class_definition", "trait_definition",
+        "object_definition", "enum_definition",
+    ]
+    function_types = ["function_definition", "function_declaration"]
+    import_types = ["import_declaration"]
+    call_types = ["call_expression", "instance_expression", "generic_function"]
+
+    def extract_import_targets(self, node, source: bytes) -> list[str]:
+        parts: list[str] = []
+        selectors: list[str] = []
+        is_wildcard = False
+        for child in node.children:
+            if child.type == "identifier":
+                parts.append(child.text.decode("utf-8", errors="replace"))
+            elif child.type == "namespace_selectors":
+                for sub in child.children:
+                    if sub.type == "identifier":
+                        selectors.append(sub.text.decode("utf-8", errors="replace"))
+            elif child.type == "namespace_wildcard":
+                is_wildcard = True
+        base = ".".join(parts)
+        if selectors:
+            return [f"{base}.{name}" for name in selectors]
+        if is_wildcard:
+            return [f"{base}.*"]
+        if base:
+            return [base]
+        return []
+
+    def get_bases(self, node, source: bytes) -> list[str]:
+        bases = []
+        for child in node.children:
+            if child.type == "extends_clause":
+                for sub in child.children:
+                    if sub.type == "type_identifier":
+                        bases.append(sub.text.decode("utf-8", errors="replace"))
+                    elif sub.type == "generic_type":
+                        for ident in sub.children:
+                            if ident.type == "type_identifier":
+                                bases.append(
+                                    ident.text.decode("utf-8", errors="replace"),
+                                )
+                break
+        return bases
diff --git a/code_review_graph/lang/_solidity.py b/code_review_graph/lang/_solidity.py
new file mode 100644
index 00000000..efd5560d
--- /dev/null
+++ b/code_review_graph/lang/_solidity.py
@@ -0,0 +1,222 @@
+"""Solidity language handler."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from ..parser import EdgeInfo, NodeInfo
+from ._base import BaseLanguageHandler
+
+if TYPE_CHECKING:
+    from ..parser import CodeParser
+
+
+class SolidityHandler(BaseLanguageHandler):
+    language = "solidity"
+    class_types = [
+        "contract_declaration", "interface_declaration", "library_declaration",
+        "struct_declaration", "enum_declaration", "error_declaration",
+        "user_defined_type_definition",
+    ]
+    # Events and modifiers use kind="Function" because the graph schema has no
+    # dedicated kind for them. State variables are also modeled as Function
+    # nodes (public ones auto-generate getters).
+    function_types = [
+        "function_definition", "constructor_definition", "modifier_definition",
+        "event_definition", "fallback_receive_definition",
+    ]
+    import_types = ["import_directive"]
+    call_types = ["call_expression"]
+
+    def get_name(self, node, kind: str) -> str | None:
+        if node.type == "constructor_definition":
+            return "constructor"
+        if node.type == "fallback_receive_definition":
+            for child in node.children:
+                if child.type in ("receive", "fallback"):
+                    return child.text.decode("utf-8", errors="replace")
+        return NotImplemented
+
+    def extract_import_targets(self, node, source: bytes) -> list[str]:
+        imports = []
+        for child in node.children:
+            if child.type == "string":
+                val = child.text.decode("utf-8", errors="replace").strip('"')
+                if val:
+                    imports.append(val)
+        return imports
+
+    def get_bases(self, node, source: bytes) -> list[str]:
+        bases = []
+        for child in node.children:
+            if child.type == "inheritance_specifier":
+                for sub in child.children:
+                    if sub.type == "user_defined_type":
+                        for ident in sub.children:
+                            if ident.type == "identifier":
+                                bases.append(
+                                    ident.text.decode("utf-8", errors="replace"),
+                                )
+        return bases
+
+    def extract_constructs(
+        self,
+        child,
+        node_type: str,
+        parser: CodeParser,
+        source: bytes,
+        file_path: str,
+        nodes: list[NodeInfo],
+        edges: list[EdgeInfo],
+        enclosing_class: str | None,
+        enclosing_func: str | None,
+        import_map: dict[str, str] | None,
+        defined_names: set[str] | None,
+        depth: int,
+    ) -> bool:
+        # Emit statements: emit EventName(...) -> CALLS edge
+        if node_type == "emit_statement" and enclosing_func:
+            for sub in child.children:
+                if sub.type == "expression":
+                    for ident in sub.children:
+                        if ident.type == "identifier":
+                            caller = parser._qualify(
+                                enclosing_func, file_path,
+                                enclosing_class,
+                            )
+                            edges.append(EdgeInfo(
+                                kind="CALLS",
+                                source=caller,
+                                target=ident.text.decode(
+                                    "utf-8", errors="replace",
+                                ),
+                                file_path=file_path,
+                                line=child.start_point[0] + 1,
+                            ))
+            # emit_statement falls through to default recursion
+            return False
+
+        # State variable declarations -> Function nodes (public ones
+        # auto-generate getters, and all are critical for reviews)
+        if node_type == "state_variable_declaration" and enclosing_class:
+            var_name = None
+            var_visibility = None
+            var_mutability = None
+            var_type = None
+            for sub in child.children:
+                if sub.type == "identifier":
+                    var_name = sub.text.decode(
+                        "utf-8", errors="replace",
+                    )
+                elif sub.type == "visibility":
+                    var_visibility = sub.text.decode(
+                        "utf-8", errors="replace",
+                    )
+                elif sub.type == "type_name":
+                    var_type = sub.text.decode(
+                        "utf-8", errors="replace",
+                    )
+                elif sub.type in ("constant", "immutable"):
+                    var_mutability = sub.type
+            if var_name:
+                qualified = parser._qualify(
+                    var_name, file_path, enclosing_class,
+                )
+                nodes.append(NodeInfo(
+                    kind="Function",
+                    name=var_name,
+                    file_path=file_path,
+                    line_start=child.start_point[0] + 1,
+                    line_end=child.end_point[0] + 1,
+                    language=self.language,
+                    parent_name=enclosing_class,
+                    return_type=var_type,
+                    modifiers=var_visibility,
+                    extra={
+                        "solidity_kind": "state_variable",
+                        "mutability": var_mutability,
+                    },
+                ))
+                edges.append(EdgeInfo(
+                    kind="CONTAINS",
+                    source=parser._qualify(
+                        enclosing_class, file_path, None,
+                    ),
+                    target=qualified,
+                    file_path=file_path,
+                    line=child.start_point[0] + 1,
+                ))
+                return True
+            return False
+
+        # File-level and contract-level constant declarations
+        if node_type == "constant_variable_declaration":
+            var_name = None
+            var_type = None
+            for sub in child.children:
+                if sub.type == "identifier":
+                    var_name = sub.text.decode(
+                        "utf-8", errors="replace",
+                    )
+                elif sub.type == "type_name":
+                    var_type = sub.text.decode(
+                        "utf-8", errors="replace",
+                    )
+            if var_name:
+                qualified = parser._qualify(
+                    var_name, file_path, enclosing_class,
+                )
+                nodes.append(NodeInfo(
+                    kind="Function",
+                    name=var_name,
+                    file_path=file_path,
+                    line_start=child.start_point[0] + 1,
+                    line_end=child.end_point[0] + 1,
+                    language=self.language,
+                    parent_name=enclosing_class,
+                    return_type=var_type,
+                    extra={"solidity_kind": "constant"},
+                ))
+                container = (
+                    parser._qualify(enclosing_class, file_path, None)
+                    if enclosing_class
+                    else file_path
+                )
+                edges.append(EdgeInfo(
+                    kind="CONTAINS",
+                    source=container,
+                    target=qualified,
+                    file_path=file_path,
+                    line=child.start_point[0] + 1,
+                ))
+                return True
+            return False
+
+        # Using directives: using LibName for Type -> DEPENDS_ON edge
+        if node_type == "using_directive":
+            lib_name = None
+            for sub in child.children:
+                if sub.type == "type_alias":
+                    for ident in sub.children:
+                        if ident.type == "identifier":
+                            lib_name = ident.text.decode(
+                                "utf-8", errors="replace",
+                            )
+            if lib_name:
+                source_name = (
+                    parser._qualify(
+                        enclosing_class, file_path, None,
+                    )
+                    if enclosing_class
+                    else file_path
+                )
+                edges.append(EdgeInfo(
+                    kind="DEPENDS_ON",
+                    source=source_name,
+                    target=lib_name,
+                    file_path=file_path,
+                    line=child.start_point[0] + 1,
+                ))
+            return True
+
+        return False
diff --git a/code_review_graph/lang/_swift.py b/code_review_graph/lang/_swift.py
new file mode 100644
index 00000000..4a4c6754
--- /dev/null
+++ b/code_review_graph/lang/_swift.py
@@ -0,0 +1,13 @@
+"""Swift language handler."""
+
+from __future__ import annotations
+
+from ._base import BaseLanguageHandler
+
+
+class SwiftHandler(BaseLanguageHandler):
+    language = "swift"
+    class_types = ["class_declaration", "struct_declaration", "protocol_declaration"]
+    function_types = ["function_declaration"]
+    import_types = ["import_declaration"]
+    call_types = ["call_expression"]
diff --git a/code_review_graph/migrations.py b/code_review_graph/migrations.py
index 1787b98d..7803f37d 100644
--- a/code_review_graph/migrations.py
+++ b/code_review_graph/migrations.py
@@ -204,16 +204,8 @@ def _migrate_v6(conn: sqlite3.Connection) -> None:
 
 
 def _migrate_v7(conn: sqlite3.Connection) -> None:
-    """v7: Add compound edge indexes for summary and risk queries."""
-    conn.execute(
-        "CREATE INDEX IF NOT EXISTS idx_edges_target_kind "
-        "ON edges(target_qualified, kind)"
-    )
-    conn.execute(
-        "CREATE INDEX IF NOT EXISTS idx_edges_source_kind "
-        "ON edges(source_qualified, kind)"
-    )
-    logger.info("Migration v7: added compound edge indexes")
+    """v7: Reserved (upstream PR #127). No-op for forward compatibility."""
+    logger.info("Migration v7: reserved (no-op)")
 
 
 def _migrate_v8(conn: sqlite3.Connection) -> None:
diff --git a/code_review_graph/parser.py b/code_review_graph/parser.py
index 31af17f7..d07ba725 100644
--- a/code_review_graph/parser.py
+++ b/code_review_graph/parser.py
@@ -10,14 +10,18 @@
 import json
 import logging
 import re
+import threading
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import NamedTuple, Optional
+from typing import TYPE_CHECKING, NamedTuple, Optional
 
 import tree_sitter_language_pack as tslp
 
 from .tsconfig_resolver import TsconfigResolver
 
+if TYPE_CHECKING:
+    from .lang import BaseLanguageHandler
+
 
 class CellInfo(NamedTuple):
     """Represents a single cell in a notebook with its language."""
@@ -111,12 +115,7 @@ class EdgeInfo:
     ".ex": "elixir",
     ".exs": "elixir",
     ".ipynb": "notebook",
-    ".zig": "zig",
-    ".ps1": "powershell",
-    ".psm1": "powershell",
-    ".psd1": "powershell",
-    ".svelte": "svelte",
-    ".jl": "julia",
+    ".html": "html",
 }
 
 # Tree-sitter node type mappings per language
@@ -161,9 +160,6 @@ class EdgeInfo:
     # identifier is literally "defmodule". Dispatched via
     # _extract_elixir_constructs to avoid matching every ``call`` here.
     "elixir": [],
-    "zig": ["container_declaration"],
-    "powershell": ["class_statement"],
-    "julia": ["struct_definition", "abstract_definition"],
 }
 
 _FUNCTION_TYPES: dict[str, list[str]] = {
@@ -208,12 +204,6 @@ class EdgeInfo:
     # Elixir: def/defp/defmacro are all ``call`` nodes whose first
    # identifier matches. Dispatched via _extract_elixir_constructs.
     "elixir": [],
-    "zig": ["fn_proto", "fn_decl"],
-    "powershell": ["function_statement"],
-    "julia": [
-        "function_definition",
-        "short_function_definition",
-    ],
 }
 
 _IMPORT_TYPES: dict[str, list[str]] = {
@@ -248,12 +238,6 @@ class EdgeInfo:
     # Elixir: alias/import/require/use are all ``call`` nodes —
     # handled in _extract_elixir_constructs.
     "elixir": [],
-    # Zig: @import("...") is a builtin_call_expr — handled
-    # generically via call types below.
-    "zig": [],
-    "powershell": [],
-    # Julia: import/using are import_statement nodes.
-    "julia": ["import_statement", "using_statement"],
 }
 
 _CALL_TYPES: dict[str, list[str]] = {
@@ -289,9 +273,6 @@ class EdgeInfo:
     # _extract_elixir_constructs which filters out def/defmodule/alias/etc.
     # before treating what's left as a real call.
     "elixir": [],
-    "zig": ["call_expression", "builtin_call_expr"],
-    "powershell": ["command_expression"],
-    "julia": ["call_expression"],
 }
 
 # Patterns that indicate a test function
@@ -329,6 +310,50 @@
     "org.junit.Test", "org.junit.jupiter.api.Test",
 })
 
+_BUILTIN_NAMES: dict[str, frozenset[str]] = {
+}
+
+# Common JS/TS prototype and built-in method names that should NOT create
+# CALLS edges when seen as instance method calls (obj.method()). These are
+# so ubiquitous that emitting bare-name edges for them creates noise without
+# helping dead-code or flow analysis.
+_INSTANCE_METHOD_BLOCKLIST: frozenset[str] = frozenset({
+    # Array / iterable
+    "push", "pop", "shift", "unshift", "splice", "slice", "concat",
+    "map", "filter", "reduce", "reduceRight", "find", "findIndex",
+    "forEach", "every", "some", "includes", "indexOf", "lastIndexOf",
+    "flat", "flatMap", "fill", "sort", "reverse", "join", "entries",
+    "keys", "values", "at", "with",
+    # Object / prototype
+    "toString", "valueOf", "toJSON", "hasOwnProperty", "toLocaleString",
+    # String
+    "trim", "trimStart", "trimEnd", "split", "replace", "replaceAll",
+    "match", "matchAll", "search", "startsWith", "endsWith", "padStart",
+    "padEnd", "repeat", "substring", "toLowerCase", "toUpperCase", "charAt",
+    "charCodeAt", "normalize", "localeCompare",
+    # Promise / async
+    "then", "catch", "finally",
+    # Map / Set
+    "get", "set", "has", "delete", "clear", "add", "size",
+    # EventEmitter / stream (very generic)
+    "emit", "pipe", "write", "end", "destroy", "pause", "resume",
+    # Logging / console
+    "log", "warn", "error", "info", "debug", "trace",
+    # DOM / common
+    "addEventListener", "removeEventListener", "querySelector",
+    "querySelectorAll", "getElementById", "setAttribute",
+    "getAttribute", "appendChild", "removeChild", "createElement",
+    "preventDefault", "stopPropagation",
+    # RxJS / Observable
+    "subscribe", "unsubscribe", "next", "complete",
+    # Common generic names (too ambiguous to resolve)
+    "call", "apply", "bind", "resolve", "reject",
+    # Python common builtins used as methods
+    "append", "extend", "insert", "remove", "update", "items",
+    "encode", "decode", "strip", "lstrip", "rstrip", "format",
+    "upper", "lower", "title", "count", "copy", "deepcopy",
+})
+
 
 def _is_test_file(path: str) -> bool:
     return any(p.search(path) for p in _TEST_FILE_PATTERNS)
@@ -368,19 +393,61 @@ def __init__(self) -> None:
         self._parsers: dict[str, object] = {}
         self._module_file_cache: dict[str, Optional[str]] = {}
         self._export_symbol_cache: dict[str, Optional[str]] = {}
+        self._star_export_cache: dict[str, set[str]] = {}
         self._tsconfig_resolver = TsconfigResolver()
         # Per-parse cache of Dart pubspec root lookups; see #87
         self._dart_pubspec_cache: dict[tuple[str, str], Optional[Path]] = {}
+        self._handlers: dict[str, "BaseLanguageHandler"] = {}
+        self._type_sets_cache: dict[str, tuple] = {}
+        self._workspace_map: dict[str, str] = {}  # pkg name → directory path
+        self._workspace_map_built = False
+        self._lock = threading.Lock()
+        self._register_handlers()
+
+    def _register_handlers(self) -> None:
+        from .lang import ALL_HANDLERS
+        for handler in ALL_HANDLERS:
+            self._handlers[handler.language] = handler
+
+    def _type_sets(self, language: str):
+        cached = self._type_sets_cache.get(language)
+        if cached is not None:
+            return cached
+        with self._lock:
+            cached = self._type_sets_cache.get(language)
+            if cached is not None:
+                return cached
+            handler = self._handlers.get(language)
+            if handler is not None:
+                result = (
+                    set(handler.class_types),
+                    set(handler.function_types),
+                    set(handler.import_types),
+                    set(handler.call_types),
+                )
+            else:
+                result = (
+                    set(_CLASS_TYPES.get(language, [])),
+                    set(_FUNCTION_TYPES.get(language, [])),
+                    set(_IMPORT_TYPES.get(language, [])),
+                    set(_CALL_TYPES.get(language, [])),
+                )
+            self._type_sets_cache[language] = result
+            return result
 
     def _get_parser(self, language: str):  # type: ignore[arg-type]
-        if language not in self._parsers:
+        if language in self._parsers:
+            return self._parsers[language]
+        with self._lock:
+            if language in self._parsers:
+                return self._parsers[language]
             try:
                 self._parsers[language] = tslp.get_parser(language)  # type: ignore[arg-type]
             except (LookupError, ValueError, ImportError) as exc:
                 # language not packaged, or grammar load failed
                 logger.debug("tree-sitter parser unavailable for %s: %s", language, exc)
                 return None
-        return self._parsers[language]
+            return self._parsers[language]
 
     def detect_language(self, path: Path) -> Optional[str]:
         return EXTENSION_TO_LANGUAGE.get(path.suffix.lower())
@@ -389,7 +456,8 @@
     def parse_file(self, path: Path) -> tuple[list[NodeInfo], list[EdgeInfo]]:
         """Parse a single file and return extracted nodes and edges."""
         try:
            source = path.read_bytes()
-        except (OSError, PermissionError):
+        except (OSError, PermissionError) as e:
+            logger.warning("Cannot read %s: %s", path, e)
             return [], []
         return self.parse_bytes(path, source)
@@ -403,21 +471,27 @@ def parse_bytes(self, path: Path, source: bytes) -> tuple[list[NodeInfo], list[EdgeInfo]]:
         if not language:
             return [], []
 
+        # Skip likely bundled JS files (Rollup/Vite/webpack output).
+        # These are single files with thousands of lines that pollute the graph.
+        if language in ("javascript",) and len(source) > 500_000:
+            return [], []
+
+        # Angular templates: regex-based extraction (no tree-sitter grammar)
+        if language == "html":
+            return self._parse_angular_template(path, source)
+
         # Vue SFCs: parse with vue parser, then delegate script blocks to JS/TS
         if language == "vue":
             return self._parse_vue(path, source)
 
-        # Svelte SFCs: same approach as Vue — extract