diff --git a/README.md b/README.md
index b98457e5..f79323da 100644
--- a/README.md
+++ b/README.md
@@ -155,7 +155,7 @@ TypeScript, JavaScript, Python, Go, Rust, Java, C#, PHP, Ruby, C, C++, Swift, Ko
 ### 🔒 100% Local
-No data leaves your machine. No API keys. No external services. Everything runs on your local SQLite database.
+No data leaves your machine. No API keys. No external services. Everything runs locally — SQLite by default, with optional PostgreSQL for faster vector search.
@@ -289,6 +289,181 @@ At the start of a session, ask the user if they'd like to initialize CodeGraph:
 ## 📋 Requirements
 - Node.js >= 18.0.0
+- PostgreSQL with [pgvector](https://github.com/pgvector/pgvector) *(optional — for faster semantic search)*
+
+---
+
+## 🐘 Using PostgreSQL for Vector Search
+
+By default, CodeGraph stores embeddings in SQLite. For faster semantic search on larger codebases, you can use PostgreSQL with pgvector for approximate nearest neighbor (ANN) search via HNSW indexes.
+
+### Step 1: Install PostgreSQL with pgvector
+
+**macOS (Homebrew):**
+```bash
+brew install postgresql
+brew install pgvector
+```
+
+**Linux (Ubuntu/Debian):**
+```bash
+sudo apt-get install postgresql postgresql-contrib
+sudo apt-get install postgresql-16-pgvector   # replace "16" with your PostgreSQL major version
+```
+
+**Other platforms:** See the [pgvector installation guide](https://github.com/pgvector/pgvector#installation)
+
+### Step 2: Start PostgreSQL
+
+**macOS (Homebrew):**
+```bash
+brew services start postgresql
+```
+
+**Linux (systemd):**
+```bash
+sudo systemctl start postgresql
+```
+
+### Step 3: Create a Database
+
+Connect to PostgreSQL and create a database for CodeGraph:
+
+```bash
+# Connect to PostgreSQL (on macOS with Homebrew, no password needed)
+psql postgres
+
+# Inside psql:
+CREATE DATABASE codegraph;
+\c codegraph
+CREATE EXTENSION IF NOT EXISTS vector;
+\q
+```
+
+Verify the setup:
+```bash
+psql codegraph -c "SELECT extname FROM pg_extension WHERE extname = 'vector';"
+```
+
+You should see: `vector`
+
+### 
Step 4: Install the PostgreSQL Driver + +```bash +npm install pg +``` + +### Step 5: Configure CodeGraph + +You have two options: + +**Option A: Per-Project Configuration** + +Edit `.codegraph/config.json` in your project: + +```json +{ + "version": 1, + "languages": ["typescript", "javascript"], + "vectorStore": { + "backend": "pgvector", + "connectionString": "postgresql://localhost:5432/codegraph" + } +} +``` + +**Option B: Global Configuration (Environment Variable)** + +Set once in your shell profile (`~/.bashrc`, `~/.zshrc`, etc.): + +```bash +export CODEGRAPH_PG_URL="postgresql://localhost:5432/codegraph" +``` + +Then in `.codegraph/config.json`: + +```json +{ + "version": 1, + "vectorStore": { + "backend": "pgvector" + } +} +``` + +The `CODEGRAPH_PG_URL` environment variable will be used as the connection string. + +### Step 6: Index Your Project + +If you're switching from SQLite to pgvector, regenerate embeddings: + +```bash +codegraph index --force +``` + +For new projects: + +```bash +codegraph init --index +``` + +### Step 7: Verify the Setup + +Check that vectors are being stored in PostgreSQL: + +```bash +psql codegraph -c "SELECT COUNT(*) FROM codegraph_vectors;" +``` + +You should see the number of vectors indexed for your project. 
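How the two options in Step 5 interact can be sketched as follows. This is a hypothetical helper for illustration only: `resolveConnectionString` and the exact config shape are assumptions, not CodeGraph's actual code.

```typescript
// Hypothetical sketch of connection-string precedence (not CodeGraph's
// real implementation): an explicit per-project connectionString wins,
// and CODEGRAPH_PG_URL serves as the fallback.

interface VectorStoreConfig {
  backend: "sqlite" | "pgvector";
  connectionString?: string;
}

function resolveConnectionString(
  config: VectorStoreConfig,
  env: Record<string, string | undefined> = process.env
): string {
  if (config.backend !== "pgvector") {
    throw new Error("pgvector backend is not enabled for this project");
  }
  // Option A (per-project connectionString) takes precedence over
  // Option B (the CODEGRAPH_PG_URL environment variable).
  const url = config.connectionString ?? env.CODEGRAPH_PG_URL;
  if (!url) {
    throw new Error(
      "No connection string: set vectorStore.connectionString or CODEGRAPH_PG_URL"
    );
  }
  return url;
}
```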
+
+### Configuration Options
+
+For advanced configuration, see the full options in `.codegraph/config.json`:
+
+```json
+{
+  "vectorStore": {
+    "backend": "pgvector",
+    "connectionString": "postgresql://user:pass@host:5432/dbname",
+    "indexType": "hnsw",
+    "distanceMetric": "cosine",
+    "poolSize": 5,
+    "tablePrefix": "codegraph_"
+  }
+}
+```
+
+| Option | Description | Default |
+|--------|-------------|---------|
+| `backend` | `"sqlite"` or `"pgvector"` | `"sqlite"` |
+| `connectionString` | PostgreSQL connection URL | — |
+| `indexType` | `"hnsw"` (recommended), `"ivfflat"`, or `"none"` | `"hnsw"` |
+| `distanceMetric` | `"cosine"`, `"l2"`, or `"inner_product"` | `"cosine"` |
+| `poolSize` | Connection pool size | `5` |
+| `tablePrefix` | Table name prefix | `"codegraph_"` |
+
+### Troubleshooting PostgreSQL Setup
+
+**Connection Refused:**
+```bash
+# Verify PostgreSQL is running
+psql postgres -c "SELECT version();"
+```
+
+**pgvector Extension Not Found:**
+```bash
+# This errors if the pgvector extension files are not installed
+psql codegraph -c "CREATE EXTENSION IF NOT EXISTS vector;"
+```
+
+**Switching Backends:**
+When switching from SQLite to pgvector or back, always re-index to regenerate embeddings:
+```bash
+codegraph index --force
+```
+
+The graph data (nodes, edges, files) stays in SQLite — only vector embeddings use the configured backend. 
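For intuition, the three `distanceMetric` values correspond to the computations below. These are plain TypeScript reference implementations for illustration only; in practice pgvector evaluates them inside the database via its `<=>` (cosine distance), `<->` (L2 distance), and `<#>` (negative inner product) operators.

```typescript
// Reference implementations of the three supported distance metrics.
// pgvector computes these server-side; shown here purely for intuition.

function dot(a: number[], b: number[]): number {
  return a.reduce((sum, ai, i) => sum + ai * b[i], 0);
}

// "cosine": 1 - (a.b) / (|a| * |b|), what pgvector's <=> operator returns
function cosineDistance(a: number[], b: number[]): number {
  return 1 - dot(a, b) / (Math.hypot(...a) * Math.hypot(...b));
}

// "l2": Euclidean distance, what pgvector's <-> operator returns
function l2Distance(a: number[], b: number[]): number {
  return Math.hypot(...a.map((ai, i) => ai - b[i]));
}

// "inner_product": pgvector orders by <#>, the *negative* inner product,
// so ascending order yields the largest inner products first
function innerProduct(a: number[], b: number[]): number {
  return dot(a, b);
}
```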
--- @@ -603,7 +778,7 @@ All data is stored in a local SQLite database (`.codegraph/codegraph.db`): - **edges** table: Relationships between nodes - **files** table: File tracking for incremental updates - **unresolved_refs** table: References pending resolution -- **vectors** table: Embeddings stored as BLOBs for semantic search +- **vectors** table: Embeddings stored as BLOBs for semantic search (or in PostgreSQL with pgvector — see [configuration](#-vector-store-postgresql--pgvector)) - **nodes_fts**: FTS5 virtual table for full-text search - **schema_versions** table: Schema version tracking - **project_metadata** table: Project-level key-value metadata @@ -621,10 +796,12 @@ After extraction, CodeGraph resolves references: CodeGraph uses local embeddings (via [@xenova/transformers](https://github.com/xenova/transformers.js)) to enable semantic search: -1. Code symbols are embedded using a transformer model +1. Code symbols are embedded using a transformer model (nomic-embed-text-v1.5, 768 dimensions) 2. Queries are embedded and compared using cosine similarity 3. Results are ranked by relevance +By default, embeddings are stored in SQLite as BLOBs with brute-force cosine similarity search. For larger codebases, you can use **PostgreSQL with pgvector** for production-grade HNSW indexes and significantly faster approximate nearest neighbor search. See [Vector Store Configuration](#-vector-store-postgresql--pgvector) below. + ### 5. Graph Queries The graph structure enables powerful queries: @@ -675,6 +852,75 @@ The `.codegraph/config.json` file controls indexing behavior: | `extractDocstrings` | Whether to extract docstrings from code | `true` | | `trackCallSites` | Whether to track call site locations | `true` | +### 🐘 Vector Store (PostgreSQL + pgvector) + +By default, CodeGraph stores embeddings in SQLite. 
For faster semantic search on large codebases, you can use PostgreSQL with the [pgvector](https://github.com/pgvector/pgvector) extension, which provides HNSW indexes for approximate nearest neighbor search. + +**Prerequisites:** +1. PostgreSQL installed with the pgvector extension +2. A database created for CodeGraph (e.g., `createdb codegraph`) +3. Install the `pg` driver: `npm install pg` + +#### Per-Project Configuration + +Add `vectorStore` to your `.codegraph/config.json`: + +```json +{ + "version": 1, + "languages": ["typescript", "javascript"], + "vectorStore": { + "backend": "pgvector", + "connectionString": "postgresql://localhost:5432/codegraph" + } +} +``` + +#### Global Configuration (Environment Variable) + +Set the `CODEGRAPH_PG_URL` environment variable to use pgvector for all projects without per-project config: + +```bash +# In your shell profile (~/.bashrc, ~/.zshrc, etc.) +export CODEGRAPH_PG_URL="postgresql://localhost:5432/codegraph" +``` + +When `CODEGRAPH_PG_URL` is set and a project's config has `"backend": "pgvector"` without a `connectionString`, the environment variable is used as the connection string. 
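The brute-force cosine strategy mentioned above for the SQLite backend can be made concrete with a short sketch. This is illustrative only: `searchBruteForce` is not part of CodeGraph's API, and the in-memory `Map` stands in for the BLOB-backed vectors table.

```typescript
// Illustrative brute-force top-k cosine search: the strategy the SQLite
// backend uses, and the exact result an HNSW index approximates.

type Hit = { nodeId: string; score: number };

function cosine(a: Float32Array, b: Float32Array): number {
  let dotProd = 0, normA = 0, normB = 0;
  for (let i = 0; i < a.length; i++) {
    dotProd += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return dotProd / (Math.sqrt(normA) * Math.sqrt(normB));
}

function searchBruteForce(
  query: Float32Array,
  vectors: Map<string, Float32Array>,
  limit: number
): Hit[] {
  const hits: Hit[] = [];
  // Every stored vector is scored against the query, so cost grows
  // linearly with the index size; this is why larger codebases benefit
  // from an ANN index instead.
  for (const [nodeId, v] of vectors) {
    hits.push({ nodeId, score: cosine(query, v) });
  }
  return hits.sort((x, y) => y.score - x.score).slice(0, limit);
}
```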
+ +#### Full Options + +```json +{ + "vectorStore": { + "backend": "pgvector", + "connectionString": "postgresql://user:pass@host:5432/dbname", + "indexType": "hnsw", + "distanceMetric": "cosine", + "poolSize": 5, + "tablePrefix": "codegraph_" + } +} +``` + +| Option | Description | Default | +|--------|-------------|---------| +| `backend` | `"sqlite"` or `"pgvector"` | `"sqlite"` | +| `connectionString` | PostgreSQL connection URL (or use `CODEGRAPH_PG_URL` env var) | — | +| `indexType` | `"hnsw"` (recommended), `"ivfflat"`, or `"none"` | `"hnsw"` | +| `distanceMetric` | `"cosine"`, `"l2"`, or `"inner_product"` | `"cosine"` | +| `poolSize` | Connection pool size | `5` | +| `tablePrefix` | Table name prefix (letters, digits, underscores) | `"codegraph_"` | + +#### After Switching Backends + +When switching from SQLite to pgvector (or vice versa), regenerate embeddings: + +```bash +codegraph index --force # Re-index the project +``` + +The graph data (nodes, edges, files) always stays in SQLite — only the vector embeddings use the configured backend. 
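For intuition about what "embeddings stored as BLOBs" means on the SQLite side, here is a round-trip sketch. The byte layout (raw native-endian float32) is an assumption for illustration, not CodeGraph's documented serialization format.

```typescript
// Sketch: round-tripping one embedding through a binary blob, as a
// SQLite-backed vector store might do. Byte layout is an assumption.

function toBlob(embedding: Float32Array): Uint8Array {
  // Copy the underlying bytes; a SQLite driver stores these in a BLOB column.
  return new Uint8Array(embedding.buffer, embedding.byteOffset, embedding.byteLength).slice();
}

function fromBlob(blob: Uint8Array): Float32Array {
  // Copy into a fresh buffer (offset 0, aligned) before reinterpreting
  // the bytes as 32-bit floats.
  const bytes = blob.slice();
  return new Float32Array(bytes.buffer, 0, bytes.byteLength / 4);
}
```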
+ ## 🌐 Supported Languages | Language | Extension | Status | diff --git a/__tests__/context.test.ts b/__tests__/context.test.ts index 52dae1fe..a500c55d 100644 --- a/__tests__/context.test.ts +++ b/__tests__/context.test.ts @@ -159,7 +159,7 @@ export function validateEmail(email: string): boolean { describe('getCode()', () => { it('should extract code for a node', async () => { // Find the PaymentService class - const nodes = cg.getNodesByKind('class'); + const nodes = await cg.getNodesByKind('class'); const paymentService = nodes.find((n) => n.name === 'PaymentService'); expect(paymentService).toBeDefined(); diff --git a/__tests__/extraction.test.ts b/__tests__/extraction.test.ts index 05ac094a..b2c6cd9d 100644 --- a/__tests__/extraction.test.ts +++ b/__tests__/extraction.test.ts @@ -2564,7 +2564,7 @@ export function multiply(a: number, b: number): number { expect(result.nodesCreated).toBeGreaterThanOrEqual(2); // Check nodes were stored - const nodes = cg.getNodesInFile('src/utils.ts'); + const nodes = await cg.getNodesInFile('src/utils.ts'); expect(nodes.length).toBeGreaterThanOrEqual(2); const addFunc = nodes.find((n) => n.name === 'add'); @@ -2596,7 +2596,7 @@ export function multiply(a: number, b: number): number { expect(result.success).toBe(true); expect(result.filesIndexed).toBe(2); - const files = cg.getFiles(); + const files = await cg.getFiles(); expect(files.length).toBe(2); cg.close(); @@ -2613,7 +2613,7 @@ export function multiply(a: number, b: number): number { await cg.indexAll(); // Check file is tracked - const file = cg.getFile('src/main.ts'); + const file = await cg.getFile('src/main.ts'); expect(file).toBeDefined(); expect(file?.contentHash).toBeDefined(); @@ -2621,7 +2621,7 @@ export function multiply(a: number, b: number): number { fs.writeFileSync(path.join(srcDir, 'main.ts'), `export const x = 2;`); // Check for changes - const changes = cg.getChangedFiles(); + const changes = await cg.getChangedFiles(); 
expect(changes.modified).toContain('src/main.ts'); cg.close(); @@ -2640,7 +2640,7 @@ export function multiply(a: number, b: number): number { const cg = CodeGraph.initSync(tempDir); await cg.indexAll(); - const initialNodes = cg.getNodesInFile('src/main.ts'); + const initialNodes = await cg.getNodesInFile('src/main.ts'); expect(initialNodes.some((n) => n.name === 'original')).toBe(true); // Modify file @@ -2654,7 +2654,7 @@ export function multiply(a: number, b: number): number { expect(syncResult.filesModified).toBe(1); // Check nodes were updated - const updatedNodes = cg.getNodesInFile('src/main.ts'); + const updatedNodes = await cg.getNodesInFile('src/main.ts'); expect(updatedNodes.some((n) => n.name === 'updated')).toBe(true); expect(updatedNodes.some((n) => n.name === 'original')).toBe(false); diff --git a/__tests__/foundation.test.ts b/__tests__/foundation.test.ts index 1dc5559d..a02e5639 100644 --- a/__tests__/foundation.test.ts +++ b/__tests__/foundation.test.ts @@ -138,11 +138,11 @@ describe('CodeGraph Foundation', () => { }); describe('Database', () => { - it('should create database with correct schema', () => { + it('should create database with correct schema', async () => { const cg = CodeGraph.initSync(tempDir); // Check that we can get stats (requires tables to exist) - const stats = cg.getStats(); + const stats = await cg.getStats(); expect(stats.nodeCount).toBe(0); expect(stats.edgeCount).toBe(0); expect(stats.fileCount).toBe(0); @@ -150,9 +150,9 @@ describe('CodeGraph Foundation', () => { cg.close(); }); - it('should return correct database size', () => { + it('should return correct database size', async () => { const cg = CodeGraph.initSync(tempDir); - const stats = cg.getStats(); + const stats = await cg.getStats(); // Database should have some size (at least the schema) expect(stats.dbSizeBytes).toBeGreaterThan(0); @@ -160,22 +160,22 @@ describe('CodeGraph Foundation', () => { cg.close(); }); - it('should support optimize operation', () => { + 
it('should support optimize operation', async () => { const cg = CodeGraph.initSync(tempDir); // Should not throw - expect(() => cg.optimize()).not.toThrow(); + await expect(cg.optimize()).resolves.not.toThrow(); cg.close(); }); - it('should support clear operation', () => { + it('should support clear operation', async () => { const cg = CodeGraph.initSync(tempDir); // Should not throw - expect(() => cg.clear()).not.toThrow(); + await cg.clear(); - const stats = cg.getStats(); + const stats = await cg.getStats(); expect(stats.nodeCount).toBe(0); cg.close(); @@ -192,10 +192,10 @@ describe('CodeGraph Foundation', () => { expect(config.rootDir).toBe(path.resolve(tempDir)); }); - it('should update configuration', () => { + it('should update configuration', async () => { const cg = CodeGraph.initSync(tempDir); - cg.updateConfig({ maxFileSize: 999999 }); + await cg.updateConfig({ maxFileSize: 999999 }); expect(cg.getConfig().maxFileSize).toBe(999999); @@ -247,29 +247,29 @@ describe('CodeGraph Foundation', () => { }); describe('Graph Query Methods', () => { - it('should throw "Node not found" for non-existent nodes', () => { + it('should throw "Node not found" for non-existent nodes', async () => { const cg = CodeGraph.initSync(tempDir); // getContext throws for non-existent nodes - expect(() => cg.getContext('non-existent')).toThrow(/not found/i); + await expect(cg.getContext('non-existent')).rejects.toThrow(/not found/i); cg.close(); }); - it('should return empty results for non-existent nodes', () => { + it('should return empty results for non-existent nodes', async () => { const cg = CodeGraph.initSync(tempDir); // These methods return empty results instead of throwing - const traverseResult = cg.traverse('non-existent'); + const traverseResult = await cg.traverse('non-existent'); expect(traverseResult.nodes.size).toBe(0); - const callGraph = cg.getCallGraph('non-existent'); + const callGraph = await cg.getCallGraph('non-existent'); 
expect(callGraph.nodes.size).toBe(0); - const typeHierarchy = cg.getTypeHierarchy('non-existent'); + const typeHierarchy = await cg.getTypeHierarchy('non-existent'); expect(typeHierarchy.nodes.size).toBe(0); - const usages = cg.findUsages('non-existent'); + const usages = await cg.findUsages('non-existent'); expect(usages.length).toBe(0); cg.close(); @@ -356,28 +356,28 @@ describe('Query Builder', () => { cleanupTempDir(tempDir); }); - it('should return null for non-existent node', () => { - const node = cg.getNode('nonexistent'); + it('should return null for non-existent node', async () => { + const node = await cg.getNode('nonexistent'); expect(node).toBeNull(); }); - it('should return empty array for nodes in non-existent file', () => { - const nodes = cg.getNodesInFile('nonexistent.ts'); + it('should return empty array for nodes in non-existent file', async () => { + const nodes = await cg.getNodesInFile('nonexistent.ts'); expect(nodes).toEqual([]); }); - it('should return empty array for edges from non-existent node', () => { - const edges = cg.getOutgoingEdges('nonexistent'); + it('should return empty array for edges from non-existent node', async () => { + const edges = await cg.getOutgoingEdges('nonexistent'); expect(edges).toEqual([]); }); - it('should return null for non-existent file', () => { - const file = cg.getFile('nonexistent.ts'); + it('should return null for non-existent file', async () => { + const file = await cg.getFile('nonexistent.ts'); expect(file).toBeNull(); }); - it('should return empty array for files when none tracked', () => { - const files = cg.getFiles(); + it('should return empty array for files when none tracked', async () => { + const files = await cg.getFiles(); expect(files).toEqual([]); }); }); diff --git a/__tests__/graph.test.ts b/__tests__/graph.test.ts index 7c771af0..a01e6c1c 100644 --- a/__tests__/graph.test.ts +++ b/__tests__/graph.test.ts @@ -122,7 +122,7 @@ export { main }; }); await cg.indexAll(); - 
cg.resolveReferences(); + await cg.resolveReferences(); }); afterEach(() => { @@ -135,8 +135,8 @@ export { main }; }); describe('traverse()', () => { - it('should traverse graph from a starting node', () => { - const nodes = cg.getNodesByKind('function'); + it('should traverse graph from a starting node', async () => { + const nodes = await cg.getNodesByKind('function'); const mainFunc = nodes.find((n) => n.name === 'main'); if (!mainFunc) { @@ -144,7 +144,7 @@ export { main }; return; } - const subgraph = cg.traverse(mainFunc.id, { + const subgraph = await cg.traverse(mainFunc.id, { maxDepth: 2, direction: 'outgoing', }); @@ -153,29 +153,29 @@ export { main }; expect(subgraph.roots).toContain(mainFunc.id); }); - it('should respect maxDepth option', () => { - const nodes = cg.getNodesByKind('function'); + it('should respect maxDepth option', async () => { + const nodes = await cg.getNodesByKind('function'); const mainFunc = nodes.find((n) => n.name === 'main'); if (!mainFunc) { return; } - const shallow = cg.traverse(mainFunc.id, { maxDepth: 1 }); - const deep = cg.traverse(mainFunc.id, { maxDepth: 3 }); + const shallow = await cg.traverse(mainFunc.id, { maxDepth: 1 }); + const deep = await cg.traverse(mainFunc.id, { maxDepth: 3 }); expect(deep.nodes.size).toBeGreaterThanOrEqual(shallow.nodes.size); }); - it('should support incoming direction', () => { - const nodes = cg.getNodesByKind('function'); + it('should support incoming direction', async () => { + const nodes = await cg.getNodesByKind('function'); const formatValue = nodes.find((n) => n.name === 'formatValue'); if (!formatValue) { return; } - const subgraph = cg.traverse(formatValue.id, { + const subgraph = await cg.traverse(formatValue.id, { maxDepth: 2, direction: 'incoming', }); @@ -185,8 +185,8 @@ export { main }; }); describe('getContext()', () => { - it('should return context for a node', () => { - const nodes = cg.getNodesByKind('class'); + it('should return context for a node', async () => { + const 
nodes = await cg.getNodesByKind('class'); const derivedClass = nodes.find((n) => n.name === 'DerivedClass'); if (!derivedClass) { @@ -194,7 +194,7 @@ export { main }; return; } - const context = cg.getContext(derivedClass.id); + const context = await cg.getContext(derivedClass.id); expect(context.focal).toBeDefined(); expect(context.focal.id).toBe(derivedClass.id); @@ -204,14 +204,14 @@ export { main }; expect(context.outgoingRefs).toBeDefined(); }); - it('should throw for non-existent node', () => { - expect(() => cg.getContext('non-existent-id')).toThrow('Node not found'); + it('should throw for non-existent node', async () => { + await expect(cg.getContext('non-existent-id')).rejects.toThrow('Node not found'); }); }); describe('getCallGraph()', () => { - it('should return call graph for a function', () => { - const nodes = cg.getNodesByKind('function'); + it('should return call graph for a function', async () => { + const nodes = await cg.getNodesByKind('function'); const processValue = nodes.find((n) => n.name === 'processValue'); if (!processValue) { @@ -219,7 +219,7 @@ export { main }; return; } - const callGraph = cg.getCallGraph(processValue.id, 2); + const callGraph = await cg.getCallGraph(processValue.id, 2); expect(callGraph.nodes.size).toBeGreaterThan(0); expect(callGraph.nodes.has(processValue.id)).toBe(true); @@ -227,22 +227,22 @@ export { main }; }); describe('getTypeHierarchy()', () => { - it('should return type hierarchy for a class', () => { - const nodes = cg.getNodesByKind('class'); + it('should return type hierarchy for a class', async () => { + const nodes = await cg.getNodesByKind('class'); const derivedClass = nodes.find((n) => n.name === 'DerivedClass'); if (!derivedClass) { return; } - const hierarchy = cg.getTypeHierarchy(derivedClass.id); + const hierarchy = await cg.getTypeHierarchy(derivedClass.id); expect(hierarchy.nodes.size).toBeGreaterThan(0); expect(hierarchy.nodes.has(derivedClass.id)).toBe(true); }); - it('should return empty 
subgraph for non-existent node', () => { - const hierarchy = cg.getTypeHierarchy('non-existent-id'); + it('should return empty subgraph for non-existent node', async () => { + const hierarchy = await cg.getTypeHierarchy('non-existent-id'); expect(hierarchy.nodes.size).toBe(0); expect(hierarchy.edges.length).toBe(0); @@ -250,15 +250,15 @@ export { main }; }); describe('findUsages()', () => { - it('should find usages of a symbol', () => { - const nodes = cg.getNodesByKind('class'); + it('should find usages of a symbol', async () => { + const nodes = await cg.getNodesByKind('class'); const baseClass = nodes.find((n) => n.name === 'BaseClass'); if (!baseClass) { return; } - const usages = cg.findUsages(baseClass.id); + const usages = await cg.findUsages(baseClass.id); // Should find at least the extends relationship expect(usages).toBeDefined(); @@ -267,44 +267,44 @@ export { main }; }); describe('getCallers() and getCallees()', () => { - it('should get callers of a function', () => { - const nodes = cg.getNodesByKind('function'); + it('should get callers of a function', async () => { + const nodes = await cg.getNodesByKind('function'); const formatValue = nodes.find((n) => n.name === 'formatValue'); if (!formatValue) { return; } - const callers = cg.getCallers(formatValue.id); + const callers = await cg.getCallers(formatValue.id); // processValue calls formatValue expect(Array.isArray(callers)).toBe(true); }); - it('should get callees of a function', () => { - const nodes = cg.getNodesByKind('function'); + it('should get callees of a function', async () => { + const nodes = await cg.getNodesByKind('function'); const processValue = nodes.find((n) => n.name === 'processValue'); if (!processValue) { return; } - const callees = cg.getCallees(processValue.id); + const callees = await cg.getCallees(processValue.id); expect(Array.isArray(callees)).toBe(true); }); }); describe('getImpactRadius()', () => { - it('should calculate impact radius', () => { - const nodes = 
cg.getNodesByKind('function'); + it('should calculate impact radius', async () => { + const nodes = await cg.getNodesByKind('function'); const formatValue = nodes.find((n) => n.name === 'formatValue'); if (!formatValue) { return; } - const impact = cg.getImpactRadius(formatValue.id, 3); + const impact = await cg.getImpactRadius(formatValue.id, 3); expect(impact.nodes.size).toBeGreaterThan(0); expect(impact.nodes.has(formatValue.id)).toBe(true); @@ -312,14 +312,14 @@ export { main }; }); describe('findPath()', () => { - it('should find path between connected nodes', () => { - const stats = cg.getStats(); + it('should find path between connected nodes', async () => { + const stats = await cg.getStats(); if (stats.nodeCount < 2) { return; } - const functions = cg.getNodesByKind('function'); + const functions = await cg.getNodesByKind('function'); if (functions.length < 2) { return; } @@ -329,45 +329,45 @@ export { main }; const formatValue = functions.find((n) => n.name === 'formatValue'); if (processValue && formatValue) { - const path = cg.findPath(processValue.id, formatValue.id); + const foundPath = await cg.findPath(processValue.id, formatValue.id); // Path might exist or might not depending on edge direction - expect(path === null || Array.isArray(path)).toBe(true); + expect(foundPath === null || Array.isArray(foundPath)).toBe(true); } }); - it('should return null for disconnected nodes', () => { + it('should return null for disconnected nodes', async () => { // Create two nodes that definitely don't have a path - const path = cg.findPath('non-existent-1', 'non-existent-2'); + const foundPath = await cg.findPath('non-existent-1', 'non-existent-2'); - expect(path).toBeNull(); + expect(foundPath).toBeNull(); }); }); describe('getAncestors() and getChildren()', () => { - it('should get ancestors of a node', () => { - const methods = cg.getNodesByKind('method'); + it('should get ancestors of a node', async () => { + const methods = await cg.getNodesByKind('method'); 
const printMethod = methods.find((n) => n.name === 'print'); if (!printMethod) { return; } - const ancestors = cg.getAncestors(printMethod.id); + const ancestors = await cg.getAncestors(printMethod.id); // Should have class and file as ancestors expect(Array.isArray(ancestors)).toBe(true); }); - it('should get children of a node', () => { - const classes = cg.getNodesByKind('class'); + it('should get children of a node', async () => { + const classes = await cg.getNodesByKind('class'); const derivedClass = classes.find((n) => n.name === 'DerivedClass'); if (!derivedClass) { return; } - const children = cg.getChildren(derivedClass.id); + const children = await cg.getChildren(derivedClass.id); // Should have methods as children expect(Array.isArray(children)).toBe(true); @@ -375,22 +375,22 @@ export { main }; }); describe('File dependency analysis', () => { - it('should get file dependencies', () => { - const deps = cg.getFileDependencies('src/main.ts'); + it('should get file dependencies', async () => { + const deps = await cg.getFileDependencies('src/main.ts'); expect(Array.isArray(deps)).toBe(true); }); - it('should get file dependents', () => { - const dependents = cg.getFileDependents('src/utils.ts'); + it('should get file dependents', async () => { + const dependents = await cg.getFileDependents('src/utils.ts'); expect(Array.isArray(dependents)).toBe(true); }); }); describe('findCircularDependencies()', () => { - it('should detect circular dependencies', () => { - const cycles = cg.findCircularDependencies(); + it('should detect circular dependencies', async () => { + const cycles = await cg.findCircularDependencies(); // Our test files don't have circular deps expect(Array.isArray(cycles)).toBe(true); @@ -398,8 +398,8 @@ export { main }; }); describe('findDeadCode()', () => { - it('should find dead code', () => { - const deadCode = cg.findDeadCode(['function']); + it('should find dead code', async () => { + const deadCode = await cg.findDeadCode(['function']); 
expect(Array.isArray(deadCode)).toBe(true); @@ -411,15 +411,15 @@ export { main }; }); describe('getNodeMetrics()', () => { - it('should return metrics for a node', () => { - const functions = cg.getNodesByKind('function'); + it('should return metrics for a node', async () => { + const functions = await cg.getNodesByKind('function'); const func = functions[0]; if (!func) { return; } - const metrics = cg.getNodeMetrics(func.id); + const metrics = await cg.getNodeMetrics(func.id); expect(metrics).toHaveProperty('incomingEdgeCount'); expect(metrics).toHaveProperty('outgoingEdgeCount'); diff --git a/__tests__/pg-vectors.test.ts b/__tests__/pg-vectors.test.ts new file mode 100644 index 00000000..ef4c3923 --- /dev/null +++ b/__tests__/pg-vectors.test.ts @@ -0,0 +1,211 @@ +/** + * PostgreSQL Vector Store (pgvector) Integration Tests + * + * These tests require a running PostgreSQL instance with pgvector extension. + * Set CODEGRAPH_TEST_PG_URL to enable these tests. + * + * Example: + * CODEGRAPH_TEST_PG_URL="postgresql://user:pass@localhost:5432/codegraph_test" npm test + */ + +import { describe, it, expect, beforeEach, afterEach } from 'vitest'; + +const PG_URL = process.env.CODEGRAPH_TEST_PG_URL; + +describe.skipIf(!PG_URL)('PgVectorStore', () => { + let PgVectorStore: any; + let store: any; + const TEST_DIMENSION = 3; + const testPrefix = `codegraph_test_${Date.now()}_`; + + beforeEach(async () => { + const mod = await import('../src/vectors/pg-store'); + PgVectorStore = mod.PgVectorStore; + + store = new PgVectorStore({ + connectionString: PG_URL!, + dimension: TEST_DIMENSION, + indexType: 'none', // Skip index creation for tests with small vectors + tablePrefix: testPrefix, + }); + await store.initialize(); + }); + + afterEach(async () => { + if (store) { + // Clean up test table + try { + await store.clear(); + } catch { /* ignore */ } + await store.dispose(); + } + }); + + it('should store and retrieve vectors', async () => { + const embedding = new 
Float32Array([0.1, 0.2, 0.3]); + await store.storeVector('node1', embedding, 'test-model'); + + const retrieved = await store.getVector('node1'); + + expect(retrieved).not.toBeNull(); + expect(retrieved?.length).toBe(3); + expect(retrieved?.[0]).toBeCloseTo(0.1, 4); + }); + + it('should return null for non-existent vectors', async () => { + const retrieved = await store.getVector('non-existent'); + expect(retrieved).toBeNull(); + }); + + it('should check if vector exists', async () => { + await store.storeVector('node1', new Float32Array([0.1, 0.2, 0.3]), 'test'); + + expect(await store.hasVector('node1')).toBe(true); + expect(await store.hasVector('node2')).toBe(false); + }); + + it('should delete vectors', async () => { + await store.storeVector('node1', new Float32Array([0.1, 0.2, 0.3]), 'test'); + expect(await store.hasVector('node1')).toBe(true); + + await store.deleteVector('node1'); + expect(await store.hasVector('node1')).toBe(false); + }); + + it('should count vectors', async () => { + expect(await store.getVectorCount()).toBe(0); + + await store.storeVector('node1', new Float32Array([0.1, 0.2, 0.3]), 'test'); + await store.storeVector('node2', new Float32Array([0.4, 0.5, 0.6]), 'test'); + + expect(await store.getVectorCount()).toBe(2); + }); + + it('should clear all vectors', async () => { + await store.storeVector('node1', new Float32Array([0.1, 0.2, 0.3]), 'test'); + await store.storeVector('node2', new Float32Array([0.4, 0.5, 0.6]), 'test'); + + expect(await store.getVectorCount()).toBe(2); + + await store.clear(); + + expect(await store.getVectorCount()).toBe(0); + }); + + it('should store vectors in batch', async () => { + const entries = [ + { nodeId: 'node1', embedding: new Float32Array([1.0, 0.0, 0.0]) }, + { nodeId: 'node2', embedding: new Float32Array([0.0, 1.0, 0.0]) }, + { nodeId: 'node3', embedding: new Float32Array([0.0, 0.0, 1.0]) }, + ]; + + await store.storeVectorBatch(entries, 'test-model'); + + expect(await 
store.getVectorCount()).toBe(3); + expect(await store.hasVector('node1')).toBe(true); + expect(await store.hasVector('node2')).toBe(true); + expect(await store.hasVector('node3')).toBe(true); + }); + + it('should get indexed node IDs', async () => { + await store.storeVector('node1', new Float32Array([0.1, 0.2, 0.3]), 'test'); + await store.storeVector('node2', new Float32Array([0.4, 0.5, 0.6]), 'test'); + + const ids = await store.getIndexedNodeIds(); + + expect(ids).toContain('node1'); + expect(ids).toContain('node2'); + expect(ids.length).toBe(2); + }); + + it('should perform cosine similarity search', async () => { + await store.storeVector('node1', new Float32Array([1, 0, 0]), 'test'); + await store.storeVector('node2', new Float32Array([0.9, 0.1, 0]), 'test'); + await store.storeVector('node3', new Float32Array([0, 1, 0]), 'test'); + + const query = new Float32Array([1, 0, 0]); + const results = await store.search(query, { limit: 3 }); + + expect(results.length).toBe(3); + expect(results[0].nodeId).toBe('node1'); + expect(results[0].score).toBeCloseTo(1.0, 3); + expect(results[1].nodeId).toBe('node2'); + }); + + it('should respect minScore in search', async () => { + await store.storeVector('node1', new Float32Array([1, 0, 0]), 'test'); + await store.storeVector('node2', new Float32Array([0, 1, 0]), 'test'); + + const query = new Float32Array([1, 0, 0]); + const results = await store.search(query, { limit: 10, minScore: 0.5 }); + + expect(results.length).toBe(1); + expect(results[0].nodeId).toBe('node1'); + }); + + it('should upsert on conflict', async () => { + await store.storeVector('node1', new Float32Array([0.1, 0.2, 0.3]), 'test'); + await store.storeVector('node1', new Float32Array([0.4, 0.5, 0.6]), 'test'); + + expect(await store.getVectorCount()).toBe(1); + + const retrieved = await store.getVector('node1'); + expect(retrieved?.[0]).toBeCloseTo(0.4, 4); + }); + + it('should report ANN disabled when indexType is none', () => { + 
expect(store.isAnnEnabled()).toBe(false); + }); +}); + +describe.skipIf(!PG_URL)('PgVectorStore with HNSW', () => { + let PgVectorStore: any; + let store: any; + const TEST_DIMENSION = 3; + const testPrefix = `codegraph_hnsw_${Date.now()}_`; + + beforeEach(async () => { + const mod = await import('../src/vectors/pg-store'); + PgVectorStore = mod.PgVectorStore; + + store = new PgVectorStore({ + connectionString: PG_URL!, + dimension: TEST_DIMENSION, + indexType: 'hnsw', + distanceMetric: 'cosine', + tablePrefix: testPrefix, + }); + await store.initialize(); + }); + + afterEach(async () => { + if (store) { + try { await store.clear(); } catch { /* ignore */ } + await store.dispose(); + } + }); + + it('should report ANN enabled with HNSW index', () => { + expect(store.isAnnEnabled()).toBe(true); + }); + + it('should search with HNSW index', async () => { + await store.storeVector('node1', new Float32Array([1, 0, 0]), 'test'); + await store.storeVector('node2', new Float32Array([0, 1, 0]), 'test'); + + const results = await store.search(new Float32Array([1, 0, 0]), { limit: 2 }); + + expect(results.length).toBe(2); + expect(results[0].nodeId).toBe('node1'); + }); + + it('should rebuild index', async () => { + await store.storeVector('node1', new Float32Array([1, 0, 0]), 'test'); + + // Should not throw + await store.rebuildIndex(); + + const results = await store.search(new Float32Array([1, 0, 0]), { limit: 1 }); + expect(results.length).toBe(1); + }); +}); diff --git a/__tests__/pr19-improvements.test.ts b/__tests__/pr19-improvements.test.ts index 5fbe17d7..e7743b7c 100644 --- a/__tests__/pr19-improvements.test.ts +++ b/__tests__/pr19-improvements.test.ts @@ -257,9 +257,9 @@ export function funcC(): void { console.log('c'); } }); await cg.indexAll(); - cg.resolveReferences(); + await cg.resolveReferences(); - const functions = cg.getNodesByKind('function'); + const functions = await cg.getNodesByKind('function'); const funcB = functions.find((n) => n.name === 
'funcB'); if (!funcB) { @@ -268,7 +268,7 @@ export function funcC(): void { console.log('c'); } } // Traverse 'both' from B - should find A (incoming caller) and C (outgoing callee) - const subgraph = cg.traverse(funcB.id, { + const subgraph = await cg.traverse(funcB.id, { maxDepth: 1, direction: 'both', }); @@ -329,10 +329,11 @@ describe('Database Layer Improvements', () => { const dbPath = path.join(testDir, 'codegraph.db'); const db = DatabaseConnection.initialize(dbPath); - const queries = new QueryBuilder(db.getDb()); + const { SqliteDbAdapter } = await import('../src/db/sqlite-db-adapter'); + const queries = new QueryBuilder(new SqliteDbAdapter(db.getDb())); // Insert a node first (needed as foreign key) - queries.insertNode({ + await queries.insertNode({ id: 'func:test:1', kind: 'function', name: 'testFunc', @@ -347,7 +348,7 @@ describe('Database Layer Improvements', () => { }); // Batch insert unresolved refs with filePath and language - queries.insertUnresolvedRefsBatch([ + await queries.insertUnresolvedRefsBatch([ { fromNodeId: 'func:test:1', referenceName: 'helperA', @@ -368,7 +369,7 @@ describe('Database Layer Improvements', () => { }, ]); - const refs = queries.getUnresolvedReferences(); + const refs = await queries.getUnresolvedReferences(); expect(refs).toHaveLength(2); expect(refs.map((r) => r.referenceName).sort()).toEqual(['helperA', 'helperB']); @@ -385,11 +386,12 @@ describe('Database Layer Improvements', () => { const dbPath = path.join(testDir, 'codegraph.db'); const db = DatabaseConnection.initialize(dbPath); - const queries = new QueryBuilder(db.getDb()); + const { SqliteDbAdapter } = await import('../src/db/sqlite-db-adapter'); + const queries = new QueryBuilder(new SqliteDbAdapter(db.getDb())); // Insert some nodes for (let i = 0; i < 3; i++) { - queries.insertNode({ + await queries.insertNode({ id: `func:test:${i}`, kind: 'function', name: `func${i}`, @@ -404,7 +406,7 @@ describe('Database Layer Improvements', () => { }); } - const 
allNodes = queries.getAllNodes(); + const allNodes = await queries.getAllNodes(); expect(allNodes).toHaveLength(3); expect(allNodes.map((n) => n.name).sort()).toEqual(['func0', 'func1', 'func2']); @@ -440,10 +442,11 @@ describe('Database Layer Improvements', () => { const dbPath = path.join(testDir, 'codegraph.db'); const db = DatabaseConnection.initialize(dbPath); - const queries = new QueryBuilder(db.getDb()); + const { SqliteDbAdapter } = await import('../src/db/sqlite-db-adapter'); + const queries = new QueryBuilder(new SqliteDbAdapter(db.getDb())); // Should not throw on empty array - expect(() => queries.insertUnresolvedRefsBatch([])).not.toThrow(); + await expect(queries.insertUnresolvedRefsBatch([])).resolves.not.toThrow(); db.close(); }); @@ -482,7 +485,7 @@ export function otherFunc(): void { myFunc(); } await cg.indexAll(); // resolveReferences internally calls warmCaches - const result = cg.resolveReferences(); + const result = await cg.resolveReferences(); // Should complete without error expect(result.stats.total).toBeGreaterThanOrEqual(0); @@ -569,7 +572,7 @@ export function getValueFromCache(): number { return 2; } const handler = new ToolHandler(cg); const findSymbol = (handler as any).findSymbol.bind(handler); - const match = findSymbol(cg, 'getValue'); + const match = await findSymbol(cg, 'getValue'); expect(match).not.toBeNull(); expect(match.node.name).toBe('getValue'); // Should not have a disambiguation note for single exact match @@ -604,7 +607,7 @@ export function handle(): void {} const handler = new ToolHandler(cg); const findSymbol = (handler as any).findSymbol.bind(handler); - const match = findSymbol(cg, 'handle'); + const match = await findSymbol(cg, 'handle'); expect(match).not.toBeNull(); expect(match.node.name).toBe('handle'); // Should have a disambiguation note @@ -632,7 +635,7 @@ export function handle(): void {} const handler = new ToolHandler(cg); const findSymbol = (handler as any).findSymbol.bind(handler); - const match = 
findSymbol(cg, 'nonExistentSymbol'); + const match = await findSymbol(cg, 'nonExistentSymbol'); expect(match).toBeNull(); handler.closeAll(); diff --git a/__tests__/resolution.test.ts b/__tests__/resolution.test.ts index bb7fe9b0..82c822aa 100644 --- a/__tests__/resolution.test.ts +++ b/__tests__/resolution.test.ts @@ -36,7 +36,7 @@ describe('Resolution Module', () => { }); describe('Name Matcher', () => { - it('should match exact name references', () => { + it('should match exact name references', async () => { // Create a mock context const mockNodes: Node[] = [ { @@ -55,14 +55,16 @@ describe('Resolution Module', () => { ]; const context: ResolutionContext = { - getNodesInFile: () => mockNodes, - getNodesByName: (name) => mockNodes.filter((n) => n.name === name), - getNodesByQualifiedName: () => [], - getNodesByKind: () => [], + getNodesInFile: async () => mockNodes, + getNodesByName: async (name) => mockNodes.filter((n) => n.name === name), + getNodesByQualifiedName: async () => [], + getNodesByKind: async () => [], fileExists: () => true, readFile: () => null, getProjectRoot: () => '/test', - getAllFiles: () => ['test.ts'], + getAllFiles: async () => ['test.ts'], + getNodesByLowerName: async () => [], + getImportMappings: async () => [], }; const ref = { @@ -75,14 +77,14 @@ describe('Resolution Module', () => { language: 'typescript' as const, }; - const result = matchReference(ref, context); + const result = await matchReference(ref, context); expect(result).not.toBeNull(); expect(result?.targetNodeId).toBe('func:test.ts:myFunction:10'); expect(result?.resolvedBy).toBe('exact-match'); }); - it('should prefer same-module candidates over cross-module matches', () => { + it('should prefer same-module candidates over cross-module matches', async () => { // Simulates a Python monorepo where multiple apps define navigate() const candidateA: Node = { id: 'func:apps/app_a/src/server.py:navigate:10', @@ -113,16 +115,16 @@ describe('Resolution Module', () => { }; const 
context: ResolutionContext = { - getNodesInFile: () => [], - getNodesByName: (name) => name === 'navigate' ? [candidateA, candidateB] : [], - getNodesByQualifiedName: () => [], - getNodesByKind: () => [], + getNodesInFile: async () => [], + getNodesByName: async (name) => name === 'navigate' ? [candidateA, candidateB] : [], + getNodesByQualifiedName: async () => [], + getNodesByKind: async () => [], fileExists: () => true, readFile: () => null, getProjectRoot: () => '/test', - getAllFiles: () => [], - getNodesByLowerName: () => [], - getImportMappings: () => [], + getAllFiles: async () => [], + getNodesByLowerName: async () => [], + getImportMappings: async () => [], }; // Reference from app_a should resolve to app_a's navigate, not app_b's @@ -136,14 +138,14 @@ describe('Resolution Module', () => { language: 'python' as const, }; - const result = matchReference(ref, context); + const result = await matchReference(ref, context); expect(result).not.toBeNull(); expect(result?.targetNodeId).toBe('func:apps/app_a/src/server.py:navigate:10'); expect(result?.resolvedBy).toBe('exact-match'); }); - it('should lower confidence for cross-module exact matches', () => { + it('should lower confidence for cross-module exact matches', async () => { // Only one candidate but in a completely different module const candidates: Node[] = [ { @@ -175,16 +177,16 @@ describe('Resolution Module', () => { ]; const context: ResolutionContext = { - getNodesInFile: () => [], - getNodesByName: (name) => name === 'navigate' ? candidates : [], - getNodesByQualifiedName: () => [], - getNodesByKind: () => [], + getNodesInFile: async () => [], + getNodesByName: async (name) => name === 'navigate' ? 
candidates : [], + getNodesByQualifiedName: async () => [], + getNodesByKind: async () => [], fileExists: () => true, readFile: () => null, getProjectRoot: () => '/test', - getAllFiles: () => [], - getNodesByLowerName: () => [], - getImportMappings: () => [], + getAllFiles: async () => [], + getNodesByLowerName: async () => [], + getImportMappings: async () => [], }; // Reference from app_a — neither candidate is in the same module @@ -198,14 +200,14 @@ describe('Resolution Module', () => { language: 'python' as const, }; - const result = matchReference(ref, context); + const result = await matchReference(ref, context); // Should still resolve but with low confidence expect(result).not.toBeNull(); expect(result?.confidence).toBeLessThanOrEqual(0.4); }); - it('should match qualified name references', () => { + it('should match qualified name references', async () => { const mockClassNode: Node = { id: 'class:user.ts:User:5', kind: 'class', @@ -235,21 +237,21 @@ describe('Resolution Module', () => { }; const context: ResolutionContext = { - getNodesInFile: (fp) => fp === 'user.ts' ? [mockClassNode, mockMethodNode] : [], - getNodesByName: (name) => { + getNodesInFile: async (fp) => fp === 'user.ts' ? 
[mockClassNode, mockMethodNode] : [], + getNodesByName: async (name) => { if (name === 'User') return [mockClassNode]; if (name === 'save') return [mockMethodNode]; return []; }, - getNodesByQualifiedName: (qn) => { + getNodesByQualifiedName: async (qn) => { if (qn === 'user.ts::User::save') return [mockMethodNode]; return []; }, - getNodesByKind: () => [], + getNodesByKind: async () => [], fileExists: () => true, readFile: () => null, getProjectRoot: () => '/test', - getAllFiles: () => ['user.ts'], + getAllFiles: async () => ['user.ts'], }; const ref = { @@ -262,7 +264,7 @@ describe('Resolution Module', () => { language: 'typescript' as const, }; - const result = matchReference(ref, context); + const result = await matchReference(ref, context); expect(result).not.toBeNull(); expect(result?.targetNodeId).toBe('method:user.ts:User.save:15'); @@ -272,14 +274,14 @@ describe('Resolution Module', () => { describe('Import Resolver', () => { it('should resolve relative import paths', () => { const context: ResolutionContext = { - getNodesInFile: () => [], - getNodesByName: () => [], - getNodesByQualifiedName: () => [], - getNodesByKind: () => [], + getNodesInFile: async () => [], + getNodesByName: async () => [], + getNodesByQualifiedName: async () => [], + getNodesByKind: async () => [], fileExists: (p) => p === 'src/components/utils.ts' || p === 'src/components/utils/index.ts', readFile: () => null, getProjectRoot: () => '', - getAllFiles: () => ['src/components/utils.ts', 'src/components/utils/index.ts'], + getAllFiles: async () => ['src/components/utils.ts', 'src/components/utils/index.ts'], }; const result = resolveImportPath( @@ -294,14 +296,14 @@ describe('Resolution Module', () => { it('should resolve parent directory imports', () => { const context: ResolutionContext = { - getNodesInFile: () => [], - getNodesByName: () => [], - getNodesByQualifiedName: () => [], - getNodesByKind: () => [], + getNodesInFile: async () => [], + getNodesByName: async () => [], + 
getNodesByQualifiedName: async () => [], + getNodesByKind: async () => [], fileExists: (p) => p === 'src/helpers.ts' || p === 'src/helpers/index.ts', readFile: () => null, getProjectRoot: () => '', - getAllFiles: () => ['src/helpers.ts', 'src/helpers/index.ts'], + getAllFiles: async () => ['src/helpers.ts', 'src/helpers/index.ts'], }; const result = resolveImportPath( @@ -354,12 +356,12 @@ from ..services import auth_service }); describe('Framework Detection', () => { - it('should detect React framework', () => { + it('should detect React framework', async () => { const context: ResolutionContext = { - getNodesInFile: () => [], - getNodesByName: () => [], - getNodesByQualifiedName: () => [], - getNodesByKind: () => [], + getNodesInFile: async () => [], + getNodesByName: async () => [], + getNodesByQualifiedName: async () => [], + getNodesByKind: async () => [], fileExists: () => false, readFile: (p) => { if (p === 'package.json') { @@ -370,19 +372,19 @@ from ..services import auth_service return null; }, getProjectRoot: () => '/test', - getAllFiles: () => ['package.json', 'src/App.tsx'], + getAllFiles: async () => ['package.json', 'src/App.tsx'], }; - const frameworks = detectFrameworks(context); + const frameworks = await detectFrameworks(context); expect(frameworks.some((f) => f.name === 'react')).toBe(true); }); - it('should detect Express framework', () => { + it('should detect Express framework', async () => { const context: ResolutionContext = { - getNodesInFile: () => [], - getNodesByName: () => [], - getNodesByQualifiedName: () => [], - getNodesByKind: () => [], + getNodesInFile: async () => [], + getNodesByName: async () => [], + getNodesByQualifiedName: async () => [], + getNodesByKind: async () => [], fileExists: () => false, readFile: (p) => { if (p === 'package.json') { @@ -393,26 +395,26 @@ from ..services import auth_service return null; }, getProjectRoot: () => '/test', - getAllFiles: () => ['package.json', 'src/app.js'], + getAllFiles: async () => 
['package.json', 'src/app.js'], }; - const frameworks = detectFrameworks(context); + const frameworks = await detectFrameworks(context); expect(frameworks.some((f) => f.name === 'express')).toBe(true); }); - it('should detect Laravel framework', () => { + it('should detect Laravel framework', async () => { const context: ResolutionContext = { - getNodesInFile: () => [], - getNodesByName: () => [], - getNodesByQualifiedName: () => [], - getNodesByKind: () => [], + getNodesInFile: async () => [], + getNodesByName: async () => [], + getNodesByQualifiedName: async () => [], + getNodesByKind: async () => [], fileExists: (p) => p === 'artisan', readFile: () => null, getProjectRoot: () => '/test', - getAllFiles: () => ['artisan', 'app/Http/Kernel.php'], + getAllFiles: async () => ['artisan', 'app/Http/Kernel.php'], }; - const frameworks = detectFrameworks(context); + const frameworks = await detectFrameworks(context); expect(frameworks.some((f) => f.name === 'laravel')).toBe(true); }); @@ -426,7 +428,7 @@ from ..services import auth_service }); describe('React Framework Resolver', () => { - it('should resolve React component references', () => { + it('should resolve React component references', async () => { const mockNodes: Node[] = [ { id: 'component:src/Button.tsx:Button:5', @@ -444,10 +446,10 @@ from ..services import auth_service ]; const context: ResolutionContext = { - getNodesInFile: (fp) => (fp === 'src/Button.tsx' ? mockNodes : []), - getNodesByName: () => mockNodes, - getNodesByQualifiedName: () => [], - getNodesByKind: () => [], + getNodesInFile: async (fp) => (fp === 'src/Button.tsx' ? 
mockNodes : []), + getNodesByName: async () => mockNodes, + getNodesByQualifiedName: async () => [], + getNodesByKind: async () => [], fileExists: () => false, readFile: (p) => { if (p === 'package.json') { @@ -456,10 +458,10 @@ from ..services import auth_service return null; }, getProjectRoot: () => '/test', - getAllFiles: () => ['package.json', 'src/Button.tsx', 'src/App.tsx'], + getAllFiles: async () => ['package.json', 'src/Button.tsx', 'src/App.tsx'], }; - const frameworks = detectFrameworks(context); + const frameworks = await detectFrameworks(context); const reactResolver = frameworks.find((f) => f.name === 'react'); expect(reactResolver).toBeDefined(); @@ -473,12 +475,12 @@ from ..services import auth_service language: 'typescript' as const, }; - const result = reactResolver!.resolve(ref, context); + const result = await reactResolver!.resolve(ref, context); expect(result).not.toBeNull(); expect(result?.targetNodeId).toBe('component:src/Button.tsx:Button:5'); }); - it('should resolve custom hook references', () => { + it('should resolve custom hook references', async () => { const mockNodes: Node[] = [ { id: 'hook:src/hooks/useAuth.ts:useAuth:1', @@ -496,10 +498,10 @@ from ..services import auth_service ]; const context: ResolutionContext = { - getNodesInFile: (fp) => (fp.includes('useAuth') ? mockNodes : []), - getNodesByName: () => mockNodes, - getNodesByQualifiedName: () => [], - getNodesByKind: () => [], + getNodesInFile: async (fp) => (fp.includes('useAuth') ? 
mockNodes : []), + getNodesByName: async () => mockNodes, + getNodesByQualifiedName: async () => [], + getNodesByKind: async () => [], fileExists: () => false, readFile: (p) => { if (p === 'package.json') { @@ -508,10 +510,10 @@ from ..services import auth_service return null; }, getProjectRoot: () => '/test', - getAllFiles: () => ['package.json', 'src/hooks/useAuth.ts'], + getAllFiles: async () => ['package.json', 'src/hooks/useAuth.ts'], }; - const frameworks = detectFrameworks(context); + const frameworks = await detectFrameworks(context); const reactResolver = frameworks.find((f) => f.name === 'react'); const ref = { @@ -524,7 +526,7 @@ from ..services import auth_service language: 'typescript' as const, }; - const result = reactResolver!.resolve(ref, context); + const result = await reactResolver!.resolve(ref, context); expect(result).not.toBeNull(); expect(result?.targetNodeId).toBe('hook:src/hooks/useAuth.ts:useAuth:1'); }); @@ -572,7 +574,7 @@ function processDate(input: string): string { expect(frameworks).toContain('react'); // Get stats to verify indexing worked - const stats = cg.getStats(); + const stats = await cg.getStats(); expect(stats.fileCount).toBe(2); expect(stats.nodeCount).toBeGreaterThan(0); }); @@ -601,7 +603,7 @@ function main(): void { cg = await CodeGraph.init(tempDir, { index: true }); // Run reference resolution - const result = cg.resolveReferences(); + const result = await cg.resolveReferences(); // Should have attempted resolution expect(result.stats.total).toBeGreaterThanOrEqual(0); diff --git a/__tests__/security.test.ts b/__tests__/security.test.ts index 53441d58..a031735b 100644 --- a/__tests__/security.test.ts +++ b/__tests__/security.test.ts @@ -164,7 +164,7 @@ describe('Path Traversal Prevention', () => { }); it('should read code for valid nodes within project', async () => { - const nodes = cg.getNodesByKind('function'); + const nodes = await cg.getNodesByKind('function'); const hello = nodes.find((n) => n.name === 'hello'); 
expect(hello).toBeDefined(); @@ -364,10 +364,11 @@ describe('JSON.parse Error Boundaries in DB', () => { cleanupTempDir(tempDir); }); - it('should not crash when node has malformed JSON in decorators column', () => { + it('should not crash when node has malformed JSON in decorators column', async () => { const dbPath = path.join(tempDir, 'test.db'); const db = DatabaseConnection.initialize(dbPath); - const queries = new QueryBuilder(db.getDb()); + const { SqliteDbAdapter } = await import('../src/db/sqlite-db-adapter'); + const queries = new QueryBuilder(new SqliteDbAdapter(db.getDb())); // Insert a node with malformed JSON in the decorators column db.getDb().prepare(` @@ -381,7 +382,7 @@ describe('JSON.parse Error Boundaries in DB', () => { ); // Should not throw - should return node with undefined decorators - const node = queries.getNodeById('test-node-1'); + const node = await queries.getNodeById('test-node-1'); expect(node).not.toBeNull(); expect(node!.name).toBe('myFunc'); expect(node!.decorators).toBeUndefined(); @@ -389,10 +390,11 @@ describe('JSON.parse Error Boundaries in DB', () => { db.close(); }); - it('should not crash when edge has malformed JSON in metadata column', () => { + it('should not crash when edge has malformed JSON in metadata column', async () => { const dbPath = path.join(tempDir, 'test.db'); const db = DatabaseConnection.initialize(dbPath); - const queries = new QueryBuilder(db.getDb()); + const { SqliteDbAdapter } = await import('../src/db/sqlite-db-adapter'); + const queries = new QueryBuilder(new SqliteDbAdapter(db.getDb())); // Insert two nodes first const insertNode = db.getDb().prepare(` @@ -409,7 +411,7 @@ describe('JSON.parse Error Boundaries in DB', () => { `).run('node-a', 'node-b', 'calls', 'broken json {{{'); // Should not throw - should return edge with undefined metadata - const edges = queries.getOutgoingEdges('node-a'); + const edges = await queries.getOutgoingEdges('node-a'); expect(edges.length).toBe(1); 
expect(edges[0].source).toBe('node-a'); expect(edges[0].target).toBe('node-b'); @@ -418,10 +420,11 @@ describe('JSON.parse Error Boundaries in DB', () => { db.close(); }); - it('should not crash when file record has malformed JSON in errors column', () => { + it('should not crash when file record has malformed JSON in errors column', async () => { const dbPath = path.join(tempDir, 'test.db'); const db = DatabaseConnection.initialize(dbPath); - const queries = new QueryBuilder(db.getDb()); + const { SqliteDbAdapter } = await import('../src/db/sqlite-db-adapter'); + const queries = new QueryBuilder(new SqliteDbAdapter(db.getDb())); // Insert a file with malformed errors JSON db.getDb().prepare(` @@ -430,7 +433,7 @@ describe('JSON.parse Error Boundaries in DB', () => { `).run('test.ts', 'abc123', 'typescript', 100, Date.now(), Date.now(), 5, 'not-an-array'); // Should not throw - should return file with undefined errors - const file = queries.getFileByPath('test.ts'); + const file = await queries.getFileByPath('test.ts'); expect(file).not.toBeNull(); expect(file!.path).toBe('test.ts'); expect(file!.errors).toBeUndefined(); diff --git a/__tests__/sync.test.ts b/__tests__/sync.test.ts index 8365f630..2ca9b044 100644 --- a/__tests__/sync.test.ts +++ b/__tests__/sync.test.ts @@ -49,39 +49,39 @@ describe('Sync Module', () => { }); describe('getChangedFiles()', () => { - it('should detect added files', () => { + it('should detect added files', async () => { // Add a new file fs.writeFileSync( path.join(testDir, 'src', 'new.ts'), `export function newFunc() { return 42; }` ); - const changes = cg.getChangedFiles(); + const changes = await cg.getChangedFiles(); expect(changes.added).toContain('src/new.ts'); expect(changes.modified).toHaveLength(0); expect(changes.removed).toHaveLength(0); }); - it('should detect modified files', () => { + it('should detect modified files', async () => { // Modify existing file fs.writeFileSync( path.join(testDir, 'src', 'index.ts'), `export 
function hello() { return 'modified'; }` ); - const changes = cg.getChangedFiles(); + const changes = await cg.getChangedFiles(); expect(changes.added).toHaveLength(0); expect(changes.modified).toContain('src/index.ts'); expect(changes.removed).toHaveLength(0); }); - it('should detect removed files', () => { + it('should detect removed files', async () => { // Remove file fs.unlinkSync(path.join(testDir, 'src', 'index.ts')); - const changes = cg.getChangedFiles(); + const changes = await cg.getChangedFiles(); expect(changes.added).toHaveLength(0); expect(changes.modified).toHaveLength(0); @@ -104,7 +104,7 @@ describe('Sync Module', () => { expect(result.filesRemoved).toBe(0); // Verify new function is in the graph - const nodes = cg.searchNodes('newFunc'); + const nodes = await cg.searchNodes('newFunc'); expect(nodes.length).toBeGreaterThan(0); }); @@ -120,11 +120,11 @@ describe('Sync Module', () => { expect(result.filesModified).toBe(1); // Verify new function is in the graph - const nodes = cg.searchNodes('goodbye'); + const nodes = await cg.searchNodes('goodbye'); expect(nodes.length).toBeGreaterThan(0); // Verify old function is gone - const oldNodes = cg.searchNodes('hello'); + const oldNodes = await cg.searchNodes('hello'); expect(oldNodes.length).toBe(0); }); @@ -137,7 +137,7 @@ describe('Sync Module', () => { expect(result.filesRemoved).toBe(1); // Verify function is gone - const nodes = cg.searchNodes('hello'); + const nodes = await cg.searchNodes('hello'); expect(nodes.length).toBe(0); }); @@ -221,7 +221,7 @@ describe('Sync Module', () => { expect(result.changedFilePaths).toContain('src/new.ts'); // Verify the function was indexed - const nodes = cg.searchNodes('newFunc'); + const nodes = await cg.searchNodes('newFunc'); expect(nodes.length).toBeGreaterThan(0); }); @@ -233,7 +233,7 @@ describe('Sync Module', () => { expect(result.filesRemoved).toBe(1); // Verify function is gone - const nodes = cg.searchNodes('hello'); + const nodes = await 
cg.searchNodes('hello'); expect(nodes.length).toBe(0); }); diff --git a/__tests__/vectors.test.ts b/__tests__/vectors.test.ts index 449d8def..0fb8c336 100644 --- a/__tests__/vectors.test.ts +++ b/__tests__/vectors.test.ts @@ -120,9 +120,9 @@ describe('Vector Embeddings', () => { await searchManager.initialize(); const embedding = new Float32Array([0.1, 0.2, 0.3]); - searchManager.storeVector('node1', embedding, 'test-model'); + await searchManager.storeVector('node1', embedding, 'test-model'); - const retrieved = searchManager.getVector('node1'); + const retrieved = await searchManager.getVector('node1'); expect(retrieved).not.toBeNull(); expect(retrieved?.length).toBe(3); @@ -132,7 +132,7 @@ describe('Vector Embeddings', () => { it('should return null for non-existent vectors', async () => { await searchManager.initialize(); - const retrieved = searchManager.getVector('non-existent'); + const retrieved = await searchManager.getVector('non-existent'); expect(retrieved).toBeNull(); }); @@ -141,60 +141,60 @@ describe('Vector Embeddings', () => { await searchManager.initialize(); const embedding = new Float32Array([0.1, 0.2, 0.3]); - searchManager.storeVector('node1', embedding, 'test-model'); + await searchManager.storeVector('node1', embedding, 'test-model'); - expect(searchManager.hasVector('node1')).toBe(true); - expect(searchManager.hasVector('node2')).toBe(false); + expect(await searchManager.hasVector('node1')).toBe(true); + expect(await searchManager.hasVector('node2')).toBe(false); }); it('should delete vectors', async () => { await searchManager.initialize(); const embedding = new Float32Array([0.1, 0.2, 0.3]); - searchManager.storeVector('node1', embedding, 'test-model'); + await searchManager.storeVector('node1', embedding, 'test-model'); - expect(searchManager.hasVector('node1')).toBe(true); + expect(await searchManager.hasVector('node1')).toBe(true); - searchManager.deleteVector('node1'); + await searchManager.deleteVector('node1'); - 
expect(searchManager.hasVector('node1')).toBe(false); + expect(await searchManager.hasVector('node1')).toBe(false); }); it('should count vectors', async () => { await searchManager.initialize(); - expect(searchManager.getVectorCount()).toBe(0); + expect(await searchManager.getVectorCount()).toBe(0); - searchManager.storeVector('node1', new Float32Array([0.1, 0.2, 0.3]), 'test'); - searchManager.storeVector('node2', new Float32Array([0.4, 0.5, 0.6]), 'test'); + await searchManager.storeVector('node1', new Float32Array([0.1, 0.2, 0.3]), 'test'); + await searchManager.storeVector('node2', new Float32Array([0.4, 0.5, 0.6]), 'test'); - expect(searchManager.getVectorCount()).toBe(2); + expect(await searchManager.getVectorCount()).toBe(2); }); it('should clear all vectors', async () => { await searchManager.initialize(); - searchManager.storeVector('node1', new Float32Array([0.1, 0.2, 0.3]), 'test'); - searchManager.storeVector('node2', new Float32Array([0.4, 0.5, 0.6]), 'test'); + await searchManager.storeVector('node1', new Float32Array([0.1, 0.2, 0.3]), 'test'); + await searchManager.storeVector('node2', new Float32Array([0.4, 0.5, 0.6]), 'test'); - expect(searchManager.getVectorCount()).toBe(2); + expect(await searchManager.getVectorCount()).toBe(2); - searchManager.clear(); + await searchManager.clear(); - expect(searchManager.getVectorCount()).toBe(0); + expect(await searchManager.getVectorCount()).toBe(0); }); it('should perform brute-force similarity search', async () => { await searchManager.initialize(); // Store some test vectors - searchManager.storeVector('node1', new Float32Array([1, 0, 0]), 'test'); - searchManager.storeVector('node2', new Float32Array([0.9, 0.1, 0]), 'test'); - searchManager.storeVector('node3', new Float32Array([0, 1, 0]), 'test'); + await searchManager.storeVector('node1', new Float32Array([1, 0, 0]), 'test'); + await searchManager.storeVector('node2', new Float32Array([0.9, 0.1, 0]), 'test'); + await searchManager.storeVector('node3', 
new Float32Array([0, 1, 0]), 'test'); // Search for similar to [1, 0, 0] const query = new Float32Array([1, 0, 0]); - const results = searchManager.search(query, { limit: 3 }); + const results = await searchManager.search(query, { limit: 3 }); expect(results.length).toBe(3); expect(results[0].nodeId).toBe('node1'); // Most similar @@ -205,11 +205,11 @@ describe('Vector Embeddings', () => { it('should respect minScore in search', async () => { await searchManager.initialize(); - searchManager.storeVector('node1', new Float32Array([1, 0, 0]), 'test'); - searchManager.storeVector('node2', new Float32Array([0, 1, 0]), 'test'); + await searchManager.storeVector('node1', new Float32Array([1, 0, 0]), 'test'); + await searchManager.storeVector('node2', new Float32Array([0, 1, 0]), 'test'); const query = new Float32Array([1, 0, 0]); - const results = searchManager.search(query, { limit: 10, minScore: 0.5 }); + const results = await searchManager.search(query, { limit: 10, minScore: 0.5 }); // Only node1 should match with score >= 0.5 expect(results.length).toBe(1); @@ -226,21 +226,21 @@ describe('Vector Embeddings', () => { { nodeId: 'node3', embedding: new Float32Array([0.0, 0.0, 1.0]) }, ]; - searchManager.storeVectorBatch(entries, 'test-model'); + await searchManager.storeVectorBatch(entries, 'test-model'); - expect(searchManager.getVectorCount()).toBe(3); - expect(searchManager.hasVector('node1')).toBe(true); - expect(searchManager.hasVector('node2')).toBe(true); - expect(searchManager.hasVector('node3')).toBe(true); + expect(await searchManager.getVectorCount()).toBe(3); + expect(await searchManager.hasVector('node1')).toBe(true); + expect(await searchManager.hasVector('node2')).toBe(true); + expect(await searchManager.hasVector('node3')).toBe(true); }); it('should get indexed node IDs', async () => { await searchManager.initialize(); - searchManager.storeVector('node1', new Float32Array([0.1, 0.2, 0.3]), 'test'); - searchManager.storeVector('node2', new 
Float32Array([0.4, 0.5, 0.6]), 'test'); + await searchManager.storeVector('node1', new Float32Array([0.1, 0.2, 0.3]), 'test'); + await searchManager.storeVector('node2', new Float32Array([0.4, 0.5, 0.6]), 'test'); - const ids = searchManager.getIndexedNodeIds(); + const ids = await searchManager.getIndexedNodeIds(); expect(ids).toContain('node1'); expect(ids).toContain('node2'); @@ -286,8 +286,8 @@ export function processData(input: string): string { expect(cg.isEmbeddingsInitialized()).toBe(false); }); - it('should return embedding stats even before initialization', () => { - const stats = cg.getEmbeddingStats(); + it('should return embedding stats even before initialization', async () => { + const stats = await cg.getEmbeddingStats(); expect(stats).not.toBeNull(); expect(stats!.totalVectors).toBe(0); }); diff --git a/package-lock.json b/package-lock.json index 910a6c01..f8207c90 100644 --- a/package-lock.json +++ b/package-lock.json @@ -24,6 +24,7 @@ "devDependencies": { "@types/better-sqlite3": "^7.6.0", "@types/node": "^20.19.30", + "@types/pg": "^8.11.0", "@types/picomatch": "^4.0.2", "typescript": "^5.0.0", "vitest": "^2.1.9" @@ -33,6 +34,7 @@ }, "optionalDependencies": { "better-sqlite3": "^11.0.0", + "pg": "^8.13.0", "sqlite-vss": "^0.1.2" } }, @@ -911,6 +913,18 @@ "undici-types": "~6.21.0" } }, + "node_modules/@types/pg": { + "version": "8.20.0", + "resolved": "https://registry.npmjs.org/@types/pg/-/pg-8.20.0.tgz", + "integrity": "sha512-bEPFOaMAHTEP1EzpvHTbmwR8UsFyHSKsRisLIHVMXnpNefSbGA1bD6CVy+qKjGSqmZqNqBDV2azOBo8TgkcVow==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/node": "*", + "pg-protocol": "*", + "pg-types": "^2.2.0" + } + }, "node_modules/@types/picomatch": { "version": "4.0.2", "resolved": "https://registry.npmjs.org/@types/picomatch/-/picomatch-4.0.2.tgz", @@ -1768,6 +1782,102 @@ "node": ">= 14.16" } }, + "node_modules/pg": { + "version": "8.20.0", + "resolved": "https://registry.npmjs.org/pg/-/pg-8.20.0.tgz", + 
"integrity": "sha512-ldhMxz2r8fl/6QkXnBD3CR9/xg694oT6DZQ2s6c/RI28OjtSOpxnPrUCGOBJ46RCUxcWdx3p6kw/xnDHjKvaRA==", + "license": "MIT", + "optional": true, + "dependencies": { + "pg-connection-string": "^2.12.0", + "pg-pool": "^3.13.0", + "pg-protocol": "^1.13.0", + "pg-types": "2.2.0", + "pgpass": "1.0.5" + }, + "engines": { + "node": ">= 16.0.0" + }, + "optionalDependencies": { + "pg-cloudflare": "^1.3.0" + }, + "peerDependencies": { + "pg-native": ">=3.0.1" + }, + "peerDependenciesMeta": { + "pg-native": { + "optional": true + } + } + }, + "node_modules/pg-cloudflare": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/pg-cloudflare/-/pg-cloudflare-1.3.0.tgz", + "integrity": "sha512-6lswVVSztmHiRtD6I8hw4qP/nDm1EJbKMRhf3HCYaqud7frGysPv7FYJ5noZQdhQtN2xJnimfMtvQq21pdbzyQ==", + "license": "MIT", + "optional": true + }, + "node_modules/pg-connection-string": { + "version": "2.12.0", + "resolved": "https://registry.npmjs.org/pg-connection-string/-/pg-connection-string-2.12.0.tgz", + "integrity": "sha512-U7qg+bpswf3Cs5xLzRqbXbQl85ng0mfSV/J0nnA31MCLgvEaAo7CIhmeyrmJpOr7o+zm0rXK+hNnT5l9RHkCkQ==", + "license": "MIT", + "optional": true + }, + "node_modules/pg-int8": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/pg-int8/-/pg-int8-1.0.1.tgz", + "integrity": "sha512-WCtabS6t3c8SkpDBUlb1kjOs7l66xsGdKpIPZsg4wR+B3+u9UAum2odSsF9tnvxg80h4ZxLWMy4pRjOsFIqQpw==", + "devOptional": true, + "license": "ISC", + "engines": { + "node": ">=4.0.0" + } + }, + "node_modules/pg-pool": { + "version": "3.13.0", + "resolved": "https://registry.npmjs.org/pg-pool/-/pg-pool-3.13.0.tgz", + "integrity": "sha512-gB+R+Xud1gLFuRD/QgOIgGOBE2KCQPaPwkzBBGC9oG69pHTkhQeIuejVIk3/cnDyX39av2AxomQiyPT13WKHQA==", + "license": "MIT", + "optional": true, + "peerDependencies": { + "pg": ">=8.0" + } + }, + "node_modules/pg-protocol": { + "version": "1.13.0", + "resolved": "https://registry.npmjs.org/pg-protocol/-/pg-protocol-1.13.0.tgz", + "integrity": 
"sha512-zzdvXfS6v89r6v7OcFCHfHlyG/wvry1ALxZo4LqgUoy7W9xhBDMaqOuMiF3qEV45VqsN6rdlcehHrfDtlCPc8w==", + "devOptional": true, + "license": "MIT" + }, + "node_modules/pg-types": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/pg-types/-/pg-types-2.2.0.tgz", + "integrity": "sha512-qTAAlrEsl8s4OiEQY69wDvcMIdQN6wdz5ojQiOy6YRMuynxenON0O5oCpJI6lshc6scgAY8qvJ2On/p+CXY0GA==", + "devOptional": true, + "license": "MIT", + "dependencies": { + "pg-int8": "1.0.1", + "postgres-array": "~2.0.0", + "postgres-bytea": "~1.0.0", + "postgres-date": "~1.0.4", + "postgres-interval": "^1.1.0" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/pgpass": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/pgpass/-/pgpass-1.0.5.tgz", + "integrity": "sha512-FdW9r/jQZhSeohs1Z3sI1yxFQNFvMcnmfuj4WBMUTxOrAyLMaTcE1aAMBiTlbMNaXvBCQuVi0R7hd8udDSP7ug==", + "license": "MIT", + "optional": true, + "dependencies": { + "split2": "^4.1.0" + } + }, "node_modules/picocolors": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", @@ -1822,6 +1932,49 @@ "node": "^10 || ^12 || >=14" } }, + "node_modules/postgres-array": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/postgres-array/-/postgres-array-2.0.0.tgz", + "integrity": "sha512-VpZrUqU5A69eQyW2c5CA1jtLecCsN2U/bD6VilrFDWq5+5UIEVO7nazS3TEcHf1zuPYO/sqGvUvW62g86RXZuA==", + "devOptional": true, + "license": "MIT", + "engines": { + "node": ">=4" + } + }, + "node_modules/postgres-bytea": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/postgres-bytea/-/postgres-bytea-1.0.1.tgz", + "integrity": "sha512-5+5HqXnsZPE65IJZSMkZtURARZelel2oXUEO8rH83VS/hxH5vv1uHquPg5wZs8yMAfdv971IU+kcPUczi7NVBQ==", + "devOptional": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/postgres-date": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/postgres-date/-/postgres-date-1.0.7.tgz", + "integrity": 
"sha512-suDmjLVQg78nMK2UZ454hAG+OAW+HQPZ6n++TNDUX+L0+uUlLywnoxJKDou51Zm+zTCjrCl0Nq6J9C5hP9vK/Q==", + "devOptional": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/postgres-interval": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/postgres-interval/-/postgres-interval-1.2.0.tgz", + "integrity": "sha512-9ZhXKM/rw350N1ovuWHbGxnGh/SNJ4cnxHiM0rxE4VN41wsg8P8zWn9hv/buK00RP4WvlOyr/RBDiptyxVbkZQ==", + "devOptional": true, + "license": "MIT", + "dependencies": { + "xtend": "^4.0.0" + }, + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/prebuild-install": { "version": "7.1.3", "resolved": "https://registry.npmjs.org/prebuild-install/-/prebuild-install-7.1.3.tgz", @@ -2121,6 +2274,16 @@ "node": ">=0.10.0" } }, + "node_modules/split2": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/split2/-/split2-4.2.0.tgz", + "integrity": "sha512-UcjcJOWknrNkF6PLX83qcHM6KHgVKNkV62Y8a5uYDVv9ydGQVwAHMKqHdJje1VTWpljG0WYpCDhrCdAOYH4TWg==", + "license": "ISC", + "optional": true, + "engines": { + "node": ">= 10.x" + } + }, "node_modules/sqlite-vss": { "version": "0.1.2", "resolved": "https://registry.npmjs.org/sqlite-vss/-/sqlite-vss-0.1.2.tgz", @@ -2346,7 +2509,6 @@ "integrity": "sha512-o5a9xKjbtuhY6Bi5S3+HvbRERmouabWbyUcpXXUA1u+GNUKoROi9byOJ8M0nHbHYHkYICiMlqxkg1KkYmm25Sw==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "esbuild": "^0.21.3", "postcss": "^8.4.43", @@ -2526,6 +2688,16 @@ "resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz", "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==", "license": "ISC" + }, + "node_modules/xtend": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz", + "integrity": "sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ==", + "devOptional": true, + "license": "MIT", + "engines": { + "node": ">=0.4" + } } } 
} diff --git a/package.json b/package.json index cff7b65e..bf04205c 100644 --- a/package.json +++ b/package.json @@ -46,11 +46,13 @@ "@types/better-sqlite3": "^7.6.0", "@types/node": "^20.19.30", "@types/picomatch": "^4.0.2", + "@types/pg": "^8.11.0", "typescript": "^5.0.0", "vitest": "^2.1.9" }, "optionalDependencies": { "better-sqlite3": "^11.0.0", + "pg": "^8.13.0", "sqlite-vss": "^0.1.2" }, "engines": { diff --git a/src/bin/codegraph.ts b/src/bin/codegraph.ts index 13f603fe..cde9e300 100644 --- a/src/bin/codegraph.ts +++ b/src/bin/codegraph.ts @@ -448,7 +448,7 @@ program } const { default: CodeGraph } = await loadCodeGraph(); - const cg = CodeGraph.openSync(projectPath); + const cg = await CodeGraph.open(projectPath); cg.uninitialize(); success(`Removed CodeGraph from ${projectPath}`); @@ -482,7 +482,7 @@ program if (options.quiet) { // Quiet mode: no UI, just run - if (options.force) cg.clear(); + if (options.force) await cg.clear(); const result = await cg.indexAll(); if (!result.success) process.exit(1); cg.destroy(); @@ -493,7 +493,7 @@ program clack.intro('Indexing project'); if (options.force) { - cg.clear(); + await cg.clear(); clack.log.info('Cleared existing index'); } @@ -614,8 +614,8 @@ program const { default: CodeGraph } = await loadCodeGraph(); const cg = await CodeGraph.open(projectPath); - const stats = cg.getStats(); - const changes = cg.getChangedFiles(); + const stats = await cg.getStats(); + const changes = await cg.getChangedFiles(); // JSON output mode if (options.json) { @@ -721,7 +721,7 @@ program const cg = await CodeGraph.open(projectPath); const limit = parseInt(options.limit || '10', 10); - const results = cg.searchNodes(search, { + const results = await cg.searchNodes(search, { limit, kinds: options.kind ? 
[options.kind as any] : undefined,
     });
@@ -792,7 +792,7 @@
     const { default: CodeGraph } = await loadCodeGraph();
     const cg = await CodeGraph.open(projectPath);
-    let files = cg.getFiles();
+    let files = await cg.getFiles();
     if (files.length === 0) {
       info('No files indexed. Run "codegraph index" first.');
@@ -1082,7 +1082,7 @@
     const { default: CodeGraph } = await loadCodeGraph();
     const cg = await CodeGraph.open(projectPath);
-    const stats = cg.getStats();
+    const stats = await cg.getStats();
     console.log(chalk.bold('\n CodeGraph Explorer\n'));
     info(`Project: ${projectPath}`);
@@ -1330,7 +1330,7 @@
       const current = queue.shift()!;
       if (current.depth >= maxDepth) continue;
-      const dependents = cg.getFileDependents(current.file);
+      const dependents = await cg.getFileDependents(current.file);
       for (const dep of dependents) {
         if (visited.has(dep)) continue;
         visited.add(dep);
diff --git a/src/config.ts b/src/config.ts
index 9ab1032a..f22255eb 100644
--- a/src/config.ts
+++ b/src/config.ts
@@ -107,6 +107,28 @@ export function validateConfig(config: unknown): config is CodeGraphConfig {
     }
   }
 
+  // Validate database if present
+  if (c.database !== undefined) {
+    if (typeof c.database !== 'object' || c.database === null) return false;
+    const db = c.database as Record<string, unknown>;
+    if (db.backend !== undefined && db.backend !== 'sqlite' && db.backend !== 'postgres') return false;
+    if (db.connectionString !== undefined && typeof db.connectionString !== 'string') return false;
+    if (db.poolSize !== undefined && (typeof db.poolSize !== 'number' || db.poolSize < 1)) return false;
+    if (db.tablePrefix !== undefined && (typeof db.tablePrefix !== 'string' || !/^[a-zA-Z_][a-zA-Z0-9_]{0,50}$/.test(db.tablePrefix))) return false;
+  }
+
+  // Validate vectorStore if present
+  if (c.vectorStore !== undefined) {
+    if (typeof c.vectorStore !== 'object' || c.vectorStore === null) return false;
+    const vs = c.vectorStore as Record<string, unknown>;
+    if (vs.backend !== undefined && vs.backend !==
'sqlite' && vs.backend !== 'pgvector') return false;
+    if (vs.connectionString !== undefined && typeof vs.connectionString !== 'string') return false;
+    if (vs.indexType !== undefined && vs.indexType !== 'hnsw' && vs.indexType !== 'ivfflat' && vs.indexType !== 'none') return false;
+    if (vs.distanceMetric !== undefined && vs.distanceMetric !== 'cosine' && vs.distanceMetric !== 'l2' && vs.distanceMetric !== 'inner_product') return false;
+    if (vs.poolSize !== undefined && (typeof vs.poolSize !== 'number' || vs.poolSize < 1)) return false;
+    if (vs.tablePrefix !== undefined && (typeof vs.tablePrefix !== 'string' || !/^[a-zA-Z_][a-zA-Z0-9_]{0,50}$/.test(vs.tablePrefix))) return false;
+  }
+
   return true;
 }
@@ -128,6 +150,8 @@ function mergeConfig(
     extractDocstrings: overrides.extractDocstrings ?? defaults.extractDocstrings,
     trackCallSites: overrides.trackCallSites ?? defaults.trackCallSites,
     customPatterns: overrides.customPatterns ?? defaults.customPatterns,
+    database: overrides.database ?? defaults.database,
+    vectorStore: overrides.vectorStore ?? defaults.vectorStore,
   };
 }
diff --git a/src/context/index.ts b/src/context/index.ts
index 3d4da098..ed4fb9c0 100644
--- a/src/context/index.ts
+++ b/src/context/index.ts
@@ -291,7 +291,7 @@ export class ContextBuilder {
     let exactMatches: SearchResult[] = [];
     if (symbolsFromQuery.length > 0) {
       try {
-        exactMatches = this.queries.findNodesByExactName(symbolsFromQuery, {
+        exactMatches = await this.queries.findNodesByExactName(symbolsFromQuery, {
           limit: Math.ceil(opts.searchLimit * 2), // Get more since we'll merge
           kinds: opts.nodeKinds && opts.nodeKinds.length > 0 ?
opts.nodeKinds : undefined,
         });
@@ -327,7 +327,7 @@ export class ContextBuilder {
     // then boost results that match multiple terms
     const termResultsMap = new Map();
     for (const term of searchTerms) {
-      const termResults = this.queries.searchNodes(term, {
+      const termResults = await this.queries.searchNodes(term, {
         limit: opts.searchLimit * 2,
         kinds: opts.nodeKinds && opts.nodeKinds.length > 0 ? opts.nodeKinds : undefined,
       });
@@ -403,7 +403,7 @@ export class ContextBuilder {
     // Resolve imports/exports to their actual definitions
     // If someone searches "terminal" and finds `import { TerminalPanel }`,
     // they want the TerminalPanel class, not the import statement
-    filteredResults = this.resolveImportsToDefinitions(filteredResults);
+    filteredResults = await this.resolveImportsToDefinitions(filteredResults);
 
     // Add entry points to subgraph
     for (const result of filteredResults) {
@@ -413,7 +413,7 @@ export class ContextBuilder {
     // Traverse from each entry point
     for (const result of filteredResults) {
-      const traversalResult = this.traverser.traverseBFS(result.node.id, {
+      const traversalResult = await this.traverser.traverseBFS(result.node.id, {
         maxDepth: opts.traversalDepth,
         edgeKinds: opts.edgeKinds && opts.edgeKinds.length > 0 ? opts.edgeKinds : undefined,
         nodeKinds: opts.nodeKinds && opts.nodeKinds.length > 0 ?
opts.nodeKinds : undefined,
@@ -489,7 +489,7 @@ export class ContextBuilder {
    * @returns Code string or null if not found
    */
   async getCode(nodeId: string): Promise<string | null> {
-    const node = this.queries.getNodeById(nodeId);
+    const node = await this.queries.getNodeById(nodeId);
     if (!node) {
       return null;
     }
@@ -636,7 +636,7 @@ export class ContextBuilder {
    * @param results - Search results that may include import/export nodes
    * @returns Results with imports resolved to definitions where possible
    */
-  private resolveImportsToDefinitions(results: SearchResult[]): SearchResult[] {
+  private async resolveImportsToDefinitions(results: SearchResult[]): Promise<SearchResult[]> {
     const resolved: SearchResult[] = [];
     const seenIds = new Set<string>();
@@ -656,11 +656,11 @@ export class ContextBuilder {
       // Imports have outgoing 'imports' edges to the definition
       // Exports have outgoing 'exports' edges to the definition
       const edgeKind = node.kind === 'import' ? 'imports' : 'exports';
-      const outgoingEdges = this.queries.getOutgoingEdges(node.id, [edgeKind as EdgeKind]);
+      const outgoingEdges = await this.queries.getOutgoingEdges(node.id, [edgeKind as EdgeKind]);
 
       let foundDefinition = false;
       for (const edge of outgoingEdges) {
-        const targetNode = this.queries.getNodeById(edge.target);
+        const targetNode = await this.queries.getNodeById(edge.target);
         if (targetNode && !seenIds.has(targetNode.id)) {
           // Found the definition - use it instead of the import
           seenIds.add(targetNode.id);
diff --git a/src/db/adapter.ts b/src/db/adapter.ts
new file mode 100644
index 00000000..d8a6f692
--- /dev/null
+++ b/src/db/adapter.ts
@@ -0,0 +1,130 @@
+/**
+ * Database Adapter Interface
+ *
+ * Provides a unified async interface over SQLite and PostgreSQL backends.
+ * SQLite wraps its synchronous calls in resolved Promises; PostgreSQL
+ * uses native async operations via the `pg` driver.
+ */
+
+import { NodeKind, Language } from '../types';
+
+/**
+ * Result of a write operation (INSERT, UPDATE, DELETE)
+ */
+export interface RunResult {
+  /** Number of rows changed */
+  changes: number;
+  /** Last auto-increment ID (SQLite only; 0 for PostgreSQL) */
+  lastInsertRowid: number | bigint;
+}
+
+/**
+ * A prepared statement whose execution methods are async.
+ *
+ * - SQLite: wraps better-sqlite3's synchronous `.run()/.get()/.all()` in `Promise.resolve()`
+ * - PostgreSQL: executes `pool.query()` on each call
+ */
+export interface DbStatement {
+  run(...params: any[]): Promise<RunResult>;
+  get(...params: any[]): Promise<any>;
+  all(...params: any[]): Promise<any[]>;
+}
+
+/**
+ * Options for full-text search
+ */
+export interface FtsSearchOptions {
+  /** Filter by node kinds */
+  kinds?: NodeKind[];
+  /** Filter by languages */
+  languages?: Language[];
+  /** Maximum results */
+  limit: number;
+  /** Result offset */
+  offset: number;
+}
+
+/**
+ * A single FTS search result row
+ */
+export interface FtsSearchResult {
+  /** Raw database row (NodeRow shape) */
+  row: any;
+  /** Relevance score (higher = better) */
+  score: number;
+}
+
+/**
+ * Unified async database adapter.
+ *
+ * Both SQLite and PostgreSQL implement this interface.
+ * The adapter handles SQL dialect differences internally:
+ * - Parameter binding (@named for SQLite, $N for PostgreSQL)
+ * - INSERT OR REPLACE vs ON CONFLICT DO UPDATE
+ * - FTS5 vs tsvector full-text search
+ */
+export interface DbAdapter {
+  /** Backend identifier */
+  readonly backendType: 'sqlite' | 'postgres';
+
+  /**
+   * Prepare a SQL statement for execution.
+   *
+   * Synchronous -- returns a statement object whose `run/get/all` methods are async.
+   * The adapter translates SQL dialect differences at prepare time:
+   * - PostgreSQL rewrites @named params to $N positional params
+   * - PostgreSQL rewrites INSERT OR REPLACE to ON CONFLICT DO UPDATE
+   * - PostgreSQL rewrites INSERT OR IGNORE to ON CONFLICT DO NOTHING
+   */
+  prepare(sql: string): DbStatement;
+
+  /**
+   * Execute raw SQL (DDL, multi-statement scripts, etc.)
+   */
+  exec(sql: string): Promise<void>;
+
+  /**
+   * Execute a function within a database transaction.
+   *
+   * - SQLite: uses better-sqlite3's synchronous transaction wrapper
+   * - PostgreSQL: acquires a client, BEGIN/COMMIT/ROLLBACK
+   *
+   * Nested transactions use SAVEPOINTs on PostgreSQL.
+   */
+  transaction<T>(fn: () => Promise<T>): Promise<T>;
+
+  /**
+   * Close the database connection and release resources.
+   */
+  close(): Promise<void>;
+
+  /**
+   * Whether the database connection is open.
+   */
+  readonly open: boolean;
+
+  /**
+   * Full-text search abstraction.
+   *
+   * Each backend implements its own FTS dialect:
+   * - SQLite: FTS5 with MATCH and bm25() scoring
+   * - PostgreSQL: tsvector with @@ and ts_rank_cd() scoring
+   *
+   * Returns rows from the `nodes` table with relevance scores.
+   */
+  ftsSearch(query: string, options: FtsSearchOptions): Promise<FtsSearchResult[]>;
+}
+
+/**
+ * Primary key map for each table.
+ * Used by PostgreSQL adapter to rewrite INSERT OR REPLACE to ON CONFLICT.
+ */
+export const TABLE_PRIMARY_KEYS: Record<string, string> = {
+  nodes: 'id',
+  edges: 'id',
+  files: 'path',
+  unresolved_refs: 'id',
+  schema_versions: 'version',
+  project_metadata: 'key',
+  vectors: 'node_id',
+};
diff --git a/src/db/db-factory.ts b/src/db/db-factory.ts
new file mode 100644
index 00000000..97cb1481
--- /dev/null
+++ b/src/db/db-factory.ts
@@ -0,0 +1,99 @@
+/**
+ * Database Adapter Factory
+ *
+ * Creates the appropriate database adapter based on configuration.
+ * Follows the same pattern as vectors/store-factory.ts.
+ */
+
+import * as fs from 'fs';
+import * as path from 'path';
+import { DbAdapter } from './adapter';
+
+/**
+ * Configuration for the database backend
+ */
+export interface DbBackendConfig {
+  /** Backend type: 'sqlite' (default) or 'postgres' */
+  backend: 'sqlite' | 'postgres';
+
+  /** PostgreSQL connection string. Can also use CODEGRAPH_PG_URL env var. */
+  connectionString?: string;
+
+  /** Connection pool size for PostgreSQL (default: 10) */
+  poolSize?: number;
+
+  /** Table name prefix for PostgreSQL (default: '') */
+  tablePrefix?: string;
+}
+
+/** Default database config (SQLite) */
+export const DEFAULT_DB_CONFIG: DbBackendConfig = {
+  backend: 'sqlite',
+};
+
+/**
+ * Create the appropriate database adapter based on configuration.
+ *
+ * For PostgreSQL, the `pg` module is dynamically imported so it's only
+ * loaded when actually needed.
+ *
+ * @param config - Database backend configuration
+ * @param sqliteDbPath - Path to SQLite database file (required for sqlite backend)
+ * @returns Initialized DbAdapter ready for use
+ */
+export async function createDbAdapter(
+  config: DbBackendConfig = DEFAULT_DB_CONFIG,
+  sqliteDbPath?: string,
+): Promise<DbAdapter> {
+  if (config.backend === 'postgres') {
+    const connectionString = config.connectionString || process.env.CODEGRAPH_PG_URL;
+    if (!connectionString) {
+      throw new Error(
+        'PostgreSQL connection string required for postgres backend. ' +
+        'Set "database.connectionString" in .codegraph/config.json or set the CODEGRAPH_PG_URL environment variable.'
+      );
+    }
+
+    // Dynamic import so `pg` is only loaded when postgres is configured
+    const { PgDbAdapter } = await import('./pg-db-adapter');
+    const adapter = new PgDbAdapter({
+      connectionString,
+      poolSize: config.poolSize,
+      tablePrefix: config.tablePrefix,
+    });
+
+    // Initialize connection pool
+    await adapter.initialize();
+
+    // Apply the base schema (runs on every open, so pg-schema.sql must be idempotent)
+    const schemaFile = path.join(__dirname, 'pg-schema.sql');
+    if (fs.existsSync(schemaFile)) {
+      const schema = fs.readFileSync(schemaFile, 'utf-8');
+      await adapter.exec(schema);
+    }
+
+    // Run pending migrations
+    const { getCurrentPgVersion, runPgMigrations } = await import('./pg-migrations');
+    const currentVersion = await getCurrentPgVersion(adapter);
+    await runPgMigrations(adapter, currentVersion);
+
+    return adapter;
+  }
+
+  // Default: SQLite
+  if (!sqliteDbPath) {
+    throw new Error('SQLite database path required for sqlite backend.');
+  }
+
+  // Use existing DatabaseConnection infrastructure for SQLite
+  const { DatabaseConnection } = await import('./index');
+  const { SqliteDbAdapter } = await import('./sqlite-db-adapter');
+
+  // Check if database already exists
+  const dbExists = fs.existsSync(sqliteDbPath);
+  const dbConn = dbExists
+    ? DatabaseConnection.open(sqliteDbPath)
+    : DatabaseConnection.initialize(sqliteDbPath);
+
+  return new SqliteDbAdapter(dbConn.getDb());
+}
diff --git a/src/db/pg-db-adapter.ts b/src/db/pg-db-adapter.ts
new file mode 100644
index 00000000..f5b3441c
--- /dev/null
+++ b/src/db/pg-db-adapter.ts
@@ -0,0 +1,386 @@
+/**
+ * PostgreSQL Database Adapter
+ *
+ * Implements the DbAdapter interface using the `pg` driver with connection pooling.
+ * Handles SQL dialect translation from SQLite-style to PostgreSQL:
+ * - @named params -> $N positional params
+ * - INSERT OR REPLACE -> ON CONFLICT DO UPDATE
+ * - INSERT OR IGNORE -> ON CONFLICT DO NOTHING
+ * - FTS5 MATCH -> tsvector @@ to_tsquery
+ */
+
+import { DbAdapter, DbStatement, FtsSearchOptions, FtsSearchResult, RunResult, TABLE_PRIMARY_KEYS } from './adapter';
+
+/**
+ * Options for the PostgreSQL database adapter.
+ */
+export interface PgDbAdapterOptions {
+  /** PostgreSQL connection string */
+  connectionString: string;
+  /** Connection pool size (default: 10) */
+  poolSize?: number;
+  /** Table name prefix (default: '') */
+  tablePrefix?: string;
+}
+
+// ============================================================================
+// SQL Translation Utilities
+// ============================================================================
+
+/**
+ * Translate @named parameters to $N positional parameters for PostgreSQL.
+ *
+ * Returns the rewritten SQL and an ordered list of parameter names.
+ * If no named params are found, returns null for paramOrder (positional mode).
+ */
+function translateNamedToPositional(sql: string): { sql: string; paramOrder: string[] | null } {
+  const paramOrder: string[] = [];
+  const paramMap = new Map<string, number>();
+
+  const rewritten = sql.replace(/@(\w+)/g, (_match, name: string) => {
+    if (!paramMap.has(name)) {
+      paramMap.set(name, paramOrder.length + 1);
+      paramOrder.push(name);
+    }
+    return `$${paramMap.get(name)}`;
+  });
+
+  if (paramOrder.length === 0) {
+    return { sql: rewritten, paramOrder: null };
+  }
+  return { sql: rewritten, paramOrder };
+}
+
+/**
+ * Translate positional ? params to $N params for PostgreSQL.
+ */
+function translatePositionalParams(sql: string): string {
+  let idx = 0;
+  return sql.replace(/\?/g, () => `$${++idx}`);
+}
+
+/**
+ * Resolve parameters from better-sqlite3 calling conventions to a positional array.
+ *
+ * Handles:
+ * - Named object: run({ id: '1', name: 'a' }) -> positional array via paramOrder
+ * - Positional args: run('a', 'b') -> ['a', 'b']
+ * - No args: run() -> []
+ */
+function resolveParams(params: any[], paramOrder: string[] | null): any[] {
+  if (params.length === 0) return [];
+
+  // Named object -> positional array
+  if (
+    paramOrder &&
+    params.length === 1 &&
+    params[0] !== null &&
+    typeof params[0] === 'object' &&
+    !Array.isArray(params[0]) &&
+    !(params[0] instanceof Buffer) &&
+    !(params[0] instanceof Uint8Array)
+  ) {
+    return paramOrder.map(name => params[0][name]);
+  }
+
+  // Already positional
+  return params;
+}
+
+/**
+ * Rewrite INSERT OR REPLACE to INSERT ... ON CONFLICT (pk) DO UPDATE SET ...
+ */
+function rewriteInsertOrReplace(sql: string): string {
+  const match = sql.match(/INSERT\s+OR\s+REPLACE\s+INTO\s+(\w+)\s*\(([^)]+)\)/i);
+  if (!match) return sql;
+
+  const tableName = match[1]!;
+  const columns = match[2]!.split(',').map(c => c.trim());
+  const pk = TABLE_PRIMARY_KEYS[tableName];
+
+  if (!pk) {
+    // Unknown table, can't determine PK -- fall back to basic insert
+    return sql.replace(/INSERT\s+OR\s+REPLACE/i, 'INSERT');
+  }
+
+  // Build ON CONFLICT clause with all non-PK columns
+  const updateCols = columns.filter(c => c !== pk);
+  const updateSet = updateCols
+    .map(col => `${col} = EXCLUDED.${col}`)
+    .join(', ');
+
+  // Replace the INSERT OR REPLACE prefix and append ON CONFLICT
+  let newSql = sql.replace(/INSERT\s+OR\s+REPLACE/i, 'INSERT');
+  newSql += ` ON CONFLICT(${pk}) DO UPDATE SET ${updateSet}`;
+
+  return newSql;
+}
+
+/**
+ * Rewrite INSERT OR IGNORE to INSERT ... ON CONFLICT DO NOTHING
+ */
+function rewriteInsertOrIgnore(sql: string): string {
+  return sql.replace(/INSERT\s+OR\s+IGNORE/i, 'INSERT') + ' ON CONFLICT DO NOTHING';
+}
+
+/**
+ * Full SQL translation pipeline for PostgreSQL.
+ */
+function translateSql(sql: string): { sql: string; paramOrder: string[] | null } {
+  let translated = sql;
+
+  // Rewrite INSERT OR REPLACE before param translation (since it modifies SQL structure)
+  if (/INSERT\s+OR\s+REPLACE/i.test(translated)) {
+    translated = rewriteInsertOrReplace(translated);
+  } else if (/INSERT\s+OR\s+IGNORE/i.test(translated)) {
+    translated = rewriteInsertOrIgnore(translated);
+  }
+
+  // Translate parameters: check for @named first, then ? positional
+  if (/@\w+/.test(translated)) {
+    return translateNamedToPositional(translated);
+  }
+
+  if (translated.includes('?')) {
+    return { sql: translatePositionalParams(translated), paramOrder: null };
+  }
+
+  return { sql: translated, paramOrder: null };
+}
+
+// ============================================================================
+// PostgreSQL Statement
+// ============================================================================
+
+/**
+ * A virtual prepared statement for PostgreSQL.
+ *
+ * Unlike SQLite's real prepared statements, this just stores the translated SQL
+ * and executes it via pool.query() on each call. PostgreSQL handles statement
+ * caching at the driver/server level.
+ */
+class PgStatement implements DbStatement {
+  private pool: any; // pg.Pool
+  private sql: string;
+  private paramOrder: string[] | null;
+
+  constructor(pool: any, originalSql: string) {
+    this.pool = pool;
+    const { sql, paramOrder } = translateSql(originalSql);
+    this.sql = sql;
+    this.paramOrder = paramOrder;
+  }
+
+  async run(...params: any[]): Promise<RunResult> {
+    const resolved = resolveParams(params, this.paramOrder);
+    const result = await this.pool.query(this.sql, resolved.length > 0 ? resolved : undefined);
+    return {
+      changes: result.rowCount ??
0,
+      lastInsertRowid: 0, // PostgreSQL doesn't have this concept in the same way
+    };
+  }
+
+  async get(...params: any[]): Promise<any> {
+    const resolved = resolveParams(params, this.paramOrder);
+    const result = await this.pool.query(this.sql, resolved.length > 0 ? resolved : undefined);
+    return result.rows[0];
+  }
+
+  async all(...params: any[]): Promise<any[]> {
+    const resolved = resolveParams(params, this.paramOrder);
+    const result = await this.pool.query(this.sql, resolved.length > 0 ? resolved : undefined);
+    return result.rows;
+  }
+}
+
+// ============================================================================
+// PostgreSQL Database Adapter
+// ============================================================================
+
+/**
+ * PostgreSQL implementation of DbAdapter.
+ *
+ * Uses pg.Pool for connection pooling. All SQL is translated from
+ * SQLite-compatible syntax at prepare time.
+ */
+export class PgDbAdapter implements DbAdapter {
+  readonly backendType = 'postgres' as const;
+  private pool: any; // pg.Pool
+  private _open = false;
+  private options: Required<PgDbAdapterOptions>;
+  private transactionDepth = 0;
+  private transactionClient: any = null;
+
+  constructor(options: PgDbAdapterOptions) {
+    this.options = {
+      connectionString: options.connectionString,
+      poolSize: options.poolSize ?? 10,
+      tablePrefix: options.tablePrefix ?? '',
+    };
+  }
+
+  get open(): boolean {
+    return this._open;
+  }
+
+  /**
+   * Initialize the connection pool.
+   * Must be called before any other method.
+   */
+  async initialize(): Promise<void> {
+    if (this._open) return;
+
+    let pg: any;
+    try {
+      pg = await import('pg');
+    } catch {
+      throw new Error(
+        'The "pg" package is required for PostgreSQL backend. Install it with: npm install pg'
+      );
+    }
+
+    const Pool = pg.default?.Pool ??
pg.Pool;
+    this.pool = new Pool({
+      connectionString: this.options.connectionString,
+      max: this.options.poolSize,
+    });
+
+    // Test connection
+    let client: any;
+    try {
+      client = await this.pool.connect();
+      client.release();
+    } catch (error: any) {
+      await this.pool.end().catch(() => {});
+      throw new Error(
+        `Failed to connect to PostgreSQL: ${error.message}. ` +
+        'Verify your connection string and ensure the database is running.'
+      );
+    }
+
+    this._open = true;
+  }
+
+  prepare(sql: string): DbStatement {
+    return new PgStatement(this.pool, sql);
+  }
+
+  async exec(sql: string): Promise<void> {
+    // Execute raw SQL -- may contain multiple statements
+    await this.pool.query(sql);
+  }
+
+  async transaction<T>(fn: () => Promise<T>): Promise<T> {
+    // Support nested transactions via SAVEPOINTs
+    if (this.transactionDepth > 0 && this.transactionClient) {
+      const savepointName = `sp_${this.transactionDepth}`;
+      this.transactionDepth++;
+      await this.transactionClient.query(`SAVEPOINT ${savepointName}`);
+      try {
+        const result = await fn();
+        await this.transactionClient.query(`RELEASE SAVEPOINT ${savepointName}`);
+        this.transactionDepth--;
+        return result;
+      } catch (error) {
+        await this.transactionClient.query(`ROLLBACK TO SAVEPOINT ${savepointName}`);
+        this.transactionDepth--;
+        throw error;
+      }
+    }
+
+    // Top-level transaction
+    const client = await this.pool.connect();
+    const originalPool = this.pool;
+    this.transactionClient = client;
+    this.transactionDepth = 1;
+
+    // Temporarily redirect queries through the transaction client
+    // so that prepared statements within the transaction use the same connection.
+    // Note: statements prepared *before* the transaction began still hold a
+    // reference to the original pool and will execute outside this transaction.
+    this.pool = {
+      query: (...args: any[]) => client.query(...args),
+    };
+
+    try {
+      await client.query('BEGIN');
+      const result = await fn();
+      await client.query('COMMIT');
+      return result;
+    } catch (error) {
+      await client.query('ROLLBACK');
+      throw error;
+    } finally {
+      this.pool = originalPool;
+      this.transactionClient = null;
+      this.transactionDepth = 0;
client.release();
+    }
+  }
+
+  async close(): Promise<void> {
+    if (this.pool) {
+      await this.pool.end();
+      this.pool = null;
+    }
+    this._open = false;
+  }
+
+  /**
+   * Full-text search using PostgreSQL tsvector.
+   *
+   * Queries the `search_vector` tsvector column on the nodes table
+   * using to_tsquery() with prefix matching and ts_rank_cd() scoring.
+   */
+  async ftsSearch(query: string, options: FtsSearchOptions): Promise<FtsSearchResult[]> {
+    const { kinds, languages, limit, offset } = options;
+
+    // Build tsquery: each term gets :* suffix for prefix matching
+    const terms = query
+      .replace(/['"*():^&|!<>]/g, '')
+      .split(/\s+/)
+      .filter(term => term.length > 0)
+      .filter(term => !/^(AND|OR|NOT|NEAR)$/i.test(term));
+
+    if (terms.length === 0) {
+      return [];
+    }
+
+    // Use | (OR) between terms, with :* for prefix matching
+    const tsQueryStr = terms.map(t => `${t}:*`).join(' | ');
+
+    let sql = `
+      SELECT nodes.*,
+             ts_rank_cd(search_vector, to_tsquery('simple', $1)) as score
+      FROM nodes
+      WHERE search_vector @@ to_tsquery('simple', $1)
+    `;
+
+    const params: (string | number)[] = [tsQueryStr];
+    let paramIdx = 2;
+
+    if (kinds && kinds.length > 0) {
+      const placeholders = kinds.map(() => `$${paramIdx++}`).join(',');
+      sql += ` AND kind IN (${placeholders})`;
+      params.push(...kinds);
+    }
+
+    if (languages && languages.length > 0) {
+      const placeholders = languages.map(() => `$${paramIdx++}`).join(',');
+      sql += ` AND language IN (${placeholders})`;
+      params.push(...languages);
+    }
+
+    sql += ` ORDER BY score DESC LIMIT $${paramIdx++} OFFSET $${paramIdx++}`;
+    params.push(limit, offset);
+
+    try {
+      const result = await this.pool.query(sql, params);
+      return result.rows.map((row: any) => ({
+        row,
+        score: parseFloat(row.score),
+      }));
+    } catch {
+      // Query failed, return empty
+      return [];
+    }
+  }
+}
diff --git a/src/db/pg-migrations.ts b/src/db/pg-migrations.ts
new file mode 100644
index 00000000..ebfbd42c
--- /dev/null
+++ b/src/db/pg-migrations.ts
@@ -0,0 +1,116 @@
+/**
+ *
PostgreSQL Database Migrations + * + * Async migration runner for the PostgreSQL backend. + * Follows the same versioning scheme as migrations.ts (SQLite). + */ + +import { DbAdapter } from './adapter'; + +/** + * Current PostgreSQL schema version + */ +export const CURRENT_PG_SCHEMA_VERSION = 3; + +/** + * PostgreSQL migration definition + */ +interface PgMigration { + version: number; + description: string; + up: (adapter: DbAdapter) => Promise<void>; +} + +/** + * All PostgreSQL migrations in order. + * + * Version 1 is the initial schema (pg-schema.sql). + * Future migrations go here. + */ +const pgMigrations: PgMigration[] = [ + { + version: 2, + description: 'Add project metadata, provenance tracking, and unresolved ref context', + up: async (adapter) => { + // These are already in pg-schema.sql for fresh installs. + // This migration handles upgrades from v1. + await adapter.exec(` + CREATE TABLE IF NOT EXISTS project_metadata ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL, + updated_at BIGINT NOT NULL + ); + `); + // ALTER TABLE ADD COLUMN IF NOT EXISTS is PostgreSQL 9.6+ + await adapter.exec(` + ALTER TABLE unresolved_refs ADD COLUMN IF NOT EXISTS file_path TEXT NOT NULL DEFAULT ''; + ALTER TABLE unresolved_refs ADD COLUMN IF NOT EXISTS language TEXT NOT NULL DEFAULT 'unknown'; + ALTER TABLE edges ADD COLUMN IF NOT EXISTS provenance TEXT DEFAULT NULL; + CREATE INDEX IF NOT EXISTS idx_unresolved_file_path ON unresolved_refs(file_path); + CREATE INDEX IF NOT EXISTS idx_edges_provenance ON edges(provenance); + `); + }, + }, + { + version: 3, + description: 'Add lower(name) expression index for memory-efficient case-insensitive lookups', + up: async (adapter) => { + await adapter.exec(` + CREATE INDEX IF NOT EXISTS idx_nodes_lower_name ON nodes(LOWER(name)); + `); + }, + }, +]; + +/** + * Get the current schema version from the PostgreSQL database + */ +export async function getCurrentPgVersion(adapter: DbAdapter): Promise<number> { + try { + const stmt = 
adapter.prepare('SELECT MAX(version) as version FROM schema_versions'); + const row = await stmt.get() as { version: number | null } | undefined; + return row?.version ?? 0; + } catch { + // Table doesn't exist yet + return 0; + } +} + +/** + * Record a migration as applied + */ +async function recordPgMigration(adapter: DbAdapter, version: number, description: string): Promise<void> { + const stmt = adapter.prepare( + 'INSERT INTO schema_versions (version, applied_at, description) VALUES ($1, $2, $3)' + ); + await stmt.run(version, Date.now(), description); +} + +/** + * Run all pending PostgreSQL migrations + */ +export async function runPgMigrations(adapter: DbAdapter, fromVersion: number): Promise<void> { + const pending = pgMigrations + .filter((m) => m.version > fromVersion) + .sort((a, b) => a.version - b.version); + + if (pending.length === 0) { + return; + } + + // Run each migration in a transaction + for (const migration of pending) { + await adapter.transaction(async () => { + await migration.up(adapter); + await recordPgMigration(adapter, migration.version, migration.description); + }); + } +} + +/** + * Check if the PostgreSQL database needs migration + */ +export async function needsPgMigration(adapter: DbAdapter): Promise<boolean> { + const current = await getCurrentPgVersion(adapter); + return current < CURRENT_PG_SCHEMA_VERSION; +} diff --git a/src/db/pg-schema.sql b/src/db/pg-schema.sql new file mode 100644 index 00000000..d273983c --- /dev/null +++ b/src/db/pg-schema.sql @@ -0,0 +1,181 @@ +-- CodeGraph PostgreSQL Schema +-- Version 1 +-- +-- PostgreSQL equivalent of schema.sql (SQLite). 
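The pending-migration selection used by `runPgMigrations()` above can be sketched as a standalone helper (`pendingMigrations` is a hypothetical name; the real code inlines this filter-and-sort logic):

```typescript
// A migration only needs its version number for selection purposes.
interface MigrationLike {
  version: number;
}

// Keep only migrations newer than the database's current version,
// and apply them in ascending version order. filter() returns a new
// array, so the registered migration list is never mutated.
function pendingMigrations<T extends MigrationLike>(all: T[], fromVersion: number): T[] {
  return all
    .filter((m) => m.version > fromVersion)
    .sort((a, b) => a.version - b.version);
}
```

For example, a database at version 1 with migrations for versions 3 and 2 registered gets them back as [2, 3]; a database already at the current schema version gets an empty list, and `runPgMigrations()` returns early.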
+-- Key differences: +-- - SERIAL instead of AUTOINCREMENT +-- - tsvector + GIN instead of FTS5 +-- - Trigger function instead of FTS5 sync triggers +-- - No COLLATE NOCASE (handled in queries via LOWER()) + +-- ============================================================================= +-- Schema Version Tracking +-- ============================================================================= + +CREATE TABLE IF NOT EXISTS schema_versions ( + version INTEGER PRIMARY KEY, + applied_at BIGINT NOT NULL, + description TEXT +); + +INSERT INTO schema_versions (version, applied_at, description) +VALUES (1, (EXTRACT(EPOCH FROM NOW()) * 1000)::BIGINT, 'Initial PostgreSQL schema') +ON CONFLICT (version) DO NOTHING; + +-- ============================================================================= +-- Core Tables +-- ============================================================================= + +-- Nodes: Code symbols (functions, classes, variables, etc.) +CREATE TABLE IF NOT EXISTS nodes ( + id TEXT PRIMARY KEY, + kind TEXT NOT NULL, + name TEXT NOT NULL, + qualified_name TEXT NOT NULL, + file_path TEXT NOT NULL, + language TEXT NOT NULL, + start_line INTEGER NOT NULL, + end_line INTEGER NOT NULL, + start_column INTEGER NOT NULL, + end_column INTEGER NOT NULL, + docstring TEXT, + signature TEXT, + visibility TEXT, + is_exported INTEGER DEFAULT 0, + is_async INTEGER DEFAULT 0, + is_static INTEGER DEFAULT 0, + is_abstract INTEGER DEFAULT 0, + decorators TEXT, -- JSON array + type_parameters TEXT, -- JSON array + updated_at BIGINT NOT NULL, + -- Full-text search vector (populated by trigger) + search_vector tsvector +); + +-- Edges: Relationships between nodes +CREATE TABLE IF NOT EXISTS edges ( + id SERIAL PRIMARY KEY, + source TEXT NOT NULL, + target TEXT NOT NULL, + kind TEXT NOT NULL, + metadata TEXT, -- JSON object + line INTEGER, + col INTEGER, + provenance TEXT DEFAULT NULL, + FOREIGN KEY (source) REFERENCES nodes(id) ON DELETE CASCADE, + FOREIGN KEY (target) 
REFERENCES nodes(id) ON DELETE CASCADE +); + +-- Files: Tracked source files +CREATE TABLE IF NOT EXISTS files ( + path TEXT PRIMARY KEY, + content_hash TEXT NOT NULL, + language TEXT NOT NULL, + size INTEGER NOT NULL, + modified_at BIGINT NOT NULL, + indexed_at BIGINT NOT NULL, + node_count INTEGER DEFAULT 0, + errors TEXT -- JSON array +); + +-- Unresolved References: References that need resolution after full indexing +CREATE TABLE IF NOT EXISTS unresolved_refs ( + id SERIAL PRIMARY KEY, + from_node_id TEXT NOT NULL, + reference_name TEXT NOT NULL, + reference_kind TEXT NOT NULL, + line INTEGER NOT NULL, + col INTEGER NOT NULL, + candidates TEXT, -- JSON array + file_path TEXT NOT NULL DEFAULT '', + language TEXT NOT NULL DEFAULT 'unknown', + FOREIGN KEY (from_node_id) REFERENCES nodes(id) ON DELETE CASCADE +); + +-- ============================================================================= +-- Full-Text Search (tsvector + GIN) +-- ============================================================================= + +-- Trigger function to maintain the search_vector column +CREATE OR REPLACE FUNCTION update_nodes_search_vector() +RETURNS trigger AS $$ +BEGIN + NEW.search_vector := + setweight(to_tsvector('simple', COALESCE(NEW.name, '')), 'A') || + setweight(to_tsvector('simple', COALESCE(NEW.qualified_name, '')), 'A') || + setweight(to_tsvector('simple', COALESCE(NEW.docstring, '')), 'B') || + setweight(to_tsvector('simple', COALESCE(NEW.signature, '')), 'C'); + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +-- Trigger: update search_vector on INSERT or UPDATE +CREATE OR REPLACE TRIGGER trg_nodes_search_vector + BEFORE INSERT OR UPDATE ON nodes + FOR EACH ROW EXECUTE FUNCTION update_nodes_search_vector(); + +-- GIN index for fast tsvector searches +CREATE INDEX IF NOT EXISTS idx_nodes_search_vector ON nodes USING gin(search_vector); + +-- ============================================================================= +-- Indexes for Query Performance +-- 
============================================================================= + +-- Node indexes +CREATE INDEX IF NOT EXISTS idx_nodes_kind ON nodes(kind); +CREATE INDEX IF NOT EXISTS idx_nodes_name ON nodes(name); +CREATE INDEX IF NOT EXISTS idx_nodes_qualified_name ON nodes(qualified_name); +CREATE INDEX IF NOT EXISTS idx_nodes_file_path ON nodes(file_path); +CREATE INDEX IF NOT EXISTS idx_nodes_language ON nodes(language); +CREATE INDEX IF NOT EXISTS idx_nodes_file_line ON nodes(file_path, start_line); +CREATE INDEX IF NOT EXISTS idx_nodes_lower_name ON nodes(LOWER(name)); + +-- Edge indexes +CREATE INDEX IF NOT EXISTS idx_edges_source ON edges(source); +CREATE INDEX IF NOT EXISTS idx_edges_target ON edges(target); +CREATE INDEX IF NOT EXISTS idx_edges_kind ON edges(kind); +CREATE INDEX IF NOT EXISTS idx_edges_source_kind ON edges(source, kind); +CREATE INDEX IF NOT EXISTS idx_edges_target_kind ON edges(target, kind); +CREATE INDEX IF NOT EXISTS idx_edges_provenance ON edges(provenance); + +-- File indexes +CREATE INDEX IF NOT EXISTS idx_files_language ON files(language); +CREATE INDEX IF NOT EXISTS idx_files_modified_at ON files(modified_at); + +-- Unresolved refs indexes +CREATE INDEX IF NOT EXISTS idx_unresolved_from_node ON unresolved_refs(from_node_id); +CREATE INDEX IF NOT EXISTS idx_unresolved_name ON unresolved_refs(reference_name); +CREATE INDEX IF NOT EXISTS idx_unresolved_file_path ON unresolved_refs(file_path); +CREATE INDEX IF NOT EXISTS idx_unresolved_from_name ON unresolved_refs(from_node_id, reference_name); + +-- ============================================================================= +-- Vector Storage +-- ============================================================================= + +-- Vector embeddings for semantic search +-- Uses pgvector extension for native vector type and ANN indexes +CREATE EXTENSION IF NOT EXISTS vector; + +CREATE TABLE IF NOT EXISTS vectors ( + node_id TEXT PRIMARY KEY, + embedding vector(768) NOT NULL, + model 
TEXT NOT NULL, + created_at BIGINT NOT NULL +); + +CREATE INDEX IF NOT EXISTS idx_vectors_model ON vectors(model); + +-- HNSW index for fast approximate nearest neighbor search +CREATE INDEX IF NOT EXISTS idx_vectors_embedding + ON vectors USING hnsw (embedding vector_cosine_ops) + WITH (m = 16, ef_construction = 64); + +-- ============================================================================= +-- Project Metadata +-- ============================================================================= + +CREATE TABLE IF NOT EXISTS project_metadata ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL, + updated_at BIGINT NOT NULL +); diff --git a/src/db/queries.ts b/src/db/queries.ts index df46018a..2d99943c 100644 --- a/src/db/queries.ts +++ b/src/db/queries.ts @@ -1,10 +1,11 @@ /** * Database Queries * - * Prepared statements for CRUD operations on the knowledge graph. + * Async prepared statements for CRUD operations on the knowledge graph. + * Works with both SQLite and PostgreSQL backends via the DbAdapter interface. */ -import { SqliteDatabase, SqliteStatement } from './sqlite-adapter'; +import { DbAdapter, DbStatement } from './adapter'; import { Node, Edge, @@ -21,7 +22,7 @@ import { safeJsonParse } from '../utils'; import { kindBonus, scorePathRelevance } from '../search/query-utils'; /** - * Database row types (snake_case from SQLite) + * Database row types (snake_case from database) */ interface NodeRow { id: string; @@ -140,10 +141,29 @@ function rowToFileRecord(row: FileRow): FileRecord { } /** - * Query builder for the knowledge graph database + * Convert UnresolvedRefRow to UnresolvedReference + */ +function rowToUnresolvedRef(row: UnresolvedRefRow): UnresolvedReference { + return { + fromNodeId: row.from_node_id, + referenceName: row.reference_name, + referenceKind: row.reference_kind as EdgeKind, + line: row.line, + column: row.col, + candidates: row.candidates ? 
safeJsonParse(row.candidates, undefined) : undefined, + filePath: row.file_path, + language: row.language as Language, + }; +} + +/** + * Async query builder for the knowledge graph database. + * + * All methods are async to support both SQLite (sync wrapped in Promise.resolve) + * and PostgreSQL (native async via pg driver) backends. */ export class QueryBuilder { - private db: SqliteDatabase; + private db: DbAdapter; // Node cache for frequently accessed nodes (LRU-style, max 1000 entries) private nodeCache: Map = new Map(); @@ -151,39 +171,46 @@ export class QueryBuilder { // Prepared statements (lazily initialized) private stmts: { - insertNode?: SqliteStatement; - updateNode?: SqliteStatement; - deleteNode?: SqliteStatement; - deleteNodesByFile?: SqliteStatement; - getNodeById?: SqliteStatement; - getNodesByFile?: SqliteStatement; - getNodesByKind?: SqliteStatement; - insertEdge?: SqliteStatement; - upsertFile?: SqliteStatement; - deleteEdgesBySource?: SqliteStatement; - deleteEdgesByTarget?: SqliteStatement; - getEdgesBySource?: SqliteStatement; - getEdgesByTarget?: SqliteStatement; - insertFile?: SqliteStatement; - updateFile?: SqliteStatement; - deleteFile?: SqliteStatement; - getFileByPath?: SqliteStatement; - getAllFiles?: SqliteStatement; - insertUnresolved?: SqliteStatement; - deleteUnresolvedByNode?: SqliteStatement; - getUnresolvedByName?: SqliteStatement; - getNodesByName?: SqliteStatement; - getNodesByQualifiedNameExact?: SqliteStatement; - getNodesByLowerName?: SqliteStatement; - getUnresolvedCount?: SqliteStatement; - getUnresolvedBatch?: SqliteStatement; - getAllFilePaths?: SqliteStatement; + insertNode?: DbStatement; + updateNode?: DbStatement; + deleteNode?: DbStatement; + deleteNodesByFile?: DbStatement; + getNodeById?: DbStatement; + getNodesByFile?: DbStatement; + getNodesByKind?: DbStatement; + insertEdge?: DbStatement; + upsertFile?: DbStatement; + deleteEdgesBySource?: DbStatement; + deleteEdgesByTarget?: DbStatement; + getEdgesBySource?: 
DbStatement; + getEdgesByTarget?: DbStatement; + insertFile?: DbStatement; + updateFile?: DbStatement; + deleteFile?: DbStatement; + getFileByPath?: DbStatement; + getAllFiles?: DbStatement; + insertUnresolved?: DbStatement; + deleteUnresolvedByNode?: DbStatement; + getUnresolvedByName?: DbStatement; + getNodesByName?: DbStatement; + getNodesByQualifiedNameExact?: DbStatement; + getNodesByLowerName?: DbStatement; + getUnresolvedCount?: DbStatement; + getUnresolvedBatch?: DbStatement; + getAllFilePaths?: DbStatement; } = {}; - constructor(db: SqliteDatabase) { + constructor(db: DbAdapter) { this.db = db; } + /** + * Get the underlying database adapter. + */ + getAdapter(): DbAdapter { + return this.db; + } + // =========================================================================== // Node Operations // =========================================================================== @@ -191,7 +218,7 @@ export class QueryBuilder { /** * Insert a new node */ - insertNode(node: Node): void { + async insertNode(node: Node): Promise { if (!this.stmts.insertNode) { this.stmts.insertNode = this.db.prepare(` INSERT OR REPLACE INTO nodes ( @@ -210,7 +237,7 @@ export class QueryBuilder { `); } - // Validate required fields to prevent SQLite bind errors + // Validate required fields to prevent bind errors if (!node.id || !node.kind || !node.name || !node.filePath || !node.language) { console.error('[CodeGraph] Skipping node with missing required fields:', { id: node.id, @@ -223,7 +250,7 @@ export class QueryBuilder { } try { - this.stmts.insertNode.run({ + await this.stmts.insertNode.run({ id: node.id, kind: node.kind, name: node.name, @@ -253,18 +280,18 @@ export class QueryBuilder { /** * Insert multiple nodes in a transaction */ - insertNodes(nodes: Node[]): void { - this.db.transaction(() => { + async insertNodes(nodes: Node[]): Promise { + await this.db.transaction(async () => { for (const node of nodes) { - this.insertNode(node); + await this.insertNode(node); } - })(); + 
}); } /** * Update an existing node */ - updateNode(node: Node): void { + async updateNode(node: Node): Promise { if (!this.stmts.updateNode) { this.stmts.updateNode = this.db.prepare(` UPDATE nodes SET @@ -300,7 +327,7 @@ export class QueryBuilder { return; } - this.stmts.updateNode.run({ + await this.stmts.updateNode.run({ id: node.id, kind: node.kind, name: node.name, @@ -327,19 +354,19 @@ export class QueryBuilder { /** * Delete a node by ID */ - deleteNode(id: string): void { + async deleteNode(id: string): Promise { if (!this.stmts.deleteNode) { this.stmts.deleteNode = this.db.prepare('DELETE FROM nodes WHERE id = ?'); } // Invalidate cache this.nodeCache.delete(id); - this.stmts.deleteNode.run(id); + await this.stmts.deleteNode.run(id); } /** * Delete all nodes for a file */ - deleteNodesByFile(filePath: string): void { + async deleteNodesByFile(filePath: string): Promise { if (!this.stmts.deleteNodesByFile) { this.stmts.deleteNodesByFile = this.db.prepare('DELETE FROM nodes WHERE file_path = ?'); } @@ -349,13 +376,13 @@ export class QueryBuilder { this.nodeCache.delete(id); } } - this.stmts.deleteNodesByFile.run(filePath); + await this.stmts.deleteNodesByFile.run(filePath); } /** * Get a node by ID */ - getNodeById(id: string): Node | null { + async getNodeById(id: string): Promise { // Check cache first if (this.nodeCache.has(id)) { const cached = this.nodeCache.get(id)!; @@ -368,7 +395,7 @@ export class QueryBuilder { if (!this.stmts.getNodeById) { this.stmts.getNodeById = this.db.prepare('SELECT * FROM nodes WHERE id = ?'); } - const row = this.stmts.getNodeById.get(id) as NodeRow | undefined; + const row = await this.stmts.getNodeById.get(id) as NodeRow | undefined; if (!row) { return null; } @@ -402,69 +429,69 @@ export class QueryBuilder { /** * Get all nodes in a file */ - getNodesByFile(filePath: string): Node[] { + async getNodesByFile(filePath: string): Promise { if (!this.stmts.getNodesByFile) { this.stmts.getNodesByFile = this.db.prepare( 
'SELECT * FROM nodes WHERE file_path = ? ORDER BY start_line' ); } - const rows = this.stmts.getNodesByFile.all(filePath) as NodeRow[]; + const rows = await this.stmts.getNodesByFile.all(filePath) as NodeRow[]; return rows.map(rowToNode); } /** * Get all nodes of a specific kind */ - getNodesByKind(kind: NodeKind): Node[] { + async getNodesByKind(kind: NodeKind): Promise { if (!this.stmts.getNodesByKind) { this.stmts.getNodesByKind = this.db.prepare('SELECT * FROM nodes WHERE kind = ?'); } - const rows = this.stmts.getNodesByKind.all(kind) as NodeRow[]; + const rows = await this.stmts.getNodesByKind.all(kind) as NodeRow[]; return rows.map(rowToNode); } /** * Get all nodes in the database */ - getAllNodes(): Node[] { - const rows = this.db.prepare('SELECT * FROM nodes').all() as NodeRow[]; + async getAllNodes(): Promise { + const rows = await this.db.prepare('SELECT * FROM nodes').all() as NodeRow[]; return rows.map(rowToNode); } /** * Get nodes by exact name match (uses idx_nodes_name index) */ - getNodesByName(name: string): Node[] { + async getNodesByName(name: string): Promise { if (!this.stmts.getNodesByName) { this.stmts.getNodesByName = this.db.prepare('SELECT * FROM nodes WHERE name = ?'); } - const rows = this.stmts.getNodesByName.all(name) as NodeRow[]; + const rows = await this.stmts.getNodesByName.all(name) as NodeRow[]; return rows.map(rowToNode); } /** * Get nodes by exact qualified name match (uses idx_nodes_qualified_name index) */ - getNodesByQualifiedNameExact(qualifiedName: string): Node[] { + async getNodesByQualifiedNameExact(qualifiedName: string): Promise { if (!this.stmts.getNodesByQualifiedNameExact) { this.stmts.getNodesByQualifiedNameExact = this.db.prepare( 'SELECT * FROM nodes WHERE qualified_name = ?' 
); } - const rows = this.stmts.getNodesByQualifiedNameExact.all(qualifiedName) as NodeRow[]; + const rows = await this.stmts.getNodesByQualifiedNameExact.all(qualifiedName) as NodeRow[]; return rows.map(rowToNode); } /** * Get nodes by lowercase name match (uses idx_nodes_lower_name expression index) */ - getNodesByLowerName(lowerName: string): Node[] { + async getNodesByLowerName(lowerName: string): Promise { if (!this.stmts.getNodesByLowerName) { this.stmts.getNodesByLowerName = this.db.prepare( 'SELECT * FROM nodes WHERE lower(name) = ?' ); } - const rows = this.stmts.getNodesByLowerName.all(lowerName) as NodeRow[]; + const rows = await this.stmts.getNodesByLowerName.all(lowerName) as NodeRow[]; return rows.map(rowToNode); } @@ -472,19 +499,19 @@ export class QueryBuilder { * Search nodes by name using FTS with fallback to LIKE for better matching * * Search strategy: - * 1. Try FTS5 prefix match (query*) for word-start matching - * 2. If no results, try LIKE for substring matching (e.g., "signIn" finds "signInWithGoogle") + * 1. Try FTS (FTS5 on SQLite, tsvector on PostgreSQL) for word-start matching + * 2. If no results, try LIKE for substring matching * 3. 
Score results based on match quality */ - searchNodes(query: string, options: SearchOptions = {}): SearchResult[] { + async searchNodes(query: string, options: SearchOptions = {}): Promise { const { kinds, languages, limit = 100, offset = 0 } = options; - // First try FTS5 with prefix matching - let results = this.searchNodesFTS(query, { kinds, languages, limit, offset }); + // Delegate FTS to the adapter (handles SQLite FTS5 vs PostgreSQL tsvector) + let results = await this.searchNodesFTS(query, { kinds, languages, limit, offset }); // If no FTS results, try LIKE-based substring search if (results.length === 0 && query.length >= 2) { - results = this.searchNodesLike(query, { kinds, languages, limit, offset }); + results = await this.searchNodesLike(query, { kinds, languages, limit, offset }); } // Apply multi-signal scoring @@ -500,65 +527,30 @@ export class QueryBuilder { } /** - * FTS5 search with prefix matching + * FTS search -- delegates to the adapter's ftsSearch() method. + * SQLite uses FTS5 MATCH + bm25, PostgreSQL uses tsvector + ts_rank_cd. */ - private searchNodesFTS(query: string, options: SearchOptions): SearchResult[] { + private async searchNodesFTS(query: string, options: SearchOptions): Promise { const { kinds, languages, limit = 100, offset = 0 } = options; - // Add prefix wildcard for better matching (e.g., "auth" matches "AuthService", "authenticate") - // Escape special FTS5 characters and add prefix wildcard - const ftsQuery = query - .replace(/['"*():^]/g, '') // Remove FTS5 special chars - .split(/\s+/) - .filter(term => term.length > 0) - // Strip FTS5 boolean operators to prevent query manipulation - .filter(term => !/^(AND|OR|NOT|NEAR)$/i.test(term)) - .map(term => `"${term}"*`) // Prefix match each term - .join(' OR '); - - if (!ftsQuery) { - return []; - } - - let sql = ` - SELECT nodes.*, bm25(nodes_fts) as score - FROM nodes_fts - JOIN nodes ON nodes_fts.id = nodes.id - WHERE nodes_fts MATCH ? 
- `; - - const params: (string | number)[] = [ftsQuery]; - - if (kinds && kinds.length > 0) { - sql += ` AND nodes.kind IN (${kinds.map(() => '?').join(',')})`; - params.push(...kinds); - } - - if (languages && languages.length > 0) { - sql += ` AND nodes.language IN (${languages.map(() => '?').join(',')})`; - params.push(...languages); - } - - sql += ' ORDER BY score LIMIT ? OFFSET ?'; - params.push(limit, offset); + const ftsResults = await this.db.ftsSearch(query, { + kinds: kinds as NodeKind[], + languages: languages as Language[], + limit, + offset, + }); - try { - const rows = this.db.prepare(sql).all(...params) as (NodeRow & { score: number })[]; - return rows.map((row) => ({ - node: rowToNode(row), - score: Math.abs(row.score), // bm25 returns negative scores - })); - } catch { - // FTS query failed, return empty - return []; - } + return ftsResults.map(({ row, score }) => ({ + node: rowToNode(row as NodeRow), + score, + })); } /** * LIKE-based substring search for cases where FTS doesn't match * Useful for camelCase matching (e.g., "signIn" finds "signInWithGoogle") */ - private searchNodesLike(query: string, options: SearchOptions): SearchResult[] { + private async searchNodesLike(query: string, options: SearchOptions): Promise { const { kinds, languages, limit = 100, offset = 0 } = options; let sql = ` @@ -606,7 +598,7 @@ export class QueryBuilder { sql += ' ORDER BY score DESC, length(name) ASC LIMIT ? OFFSET ?'; params.push(limit, offset); - const rows = this.db.prepare(sql).all(...params) as (NodeRow & { score: number })[]; + const rows = await this.db.prepare(sql).all(...params) as (NodeRow & { score: number })[]; return rows.map((row) => ({ node: rowToNode(row), @@ -617,31 +609,27 @@ export class QueryBuilder { /** * Find nodes by exact name match * - * Used for hybrid search - looks up symbols by exact name or case-insensitive match. - * Returns high-confidence matches for known symbol names extracted from query. 
- * - * @param names - Array of symbol names to look up - * @param options - Search options (kinds, languages, limit) - * @returns SearchResult array with exact matches scored at 1.0 + * Uses case-insensitive matching via LOWER() for cross-database compatibility. */ - findNodesByExactName(names: string[], options: SearchOptions = {}): SearchResult[] { + async findNodesByExactName(names: string[], options: SearchOptions = {}): Promise { if (names.length === 0) return []; const { kinds, languages, limit = 50 } = options; + const lowerNames = names.map(n => n.toLowerCase()); - // Build query with exact matches (case-insensitive) + // Build query with exact matches (case-insensitive via LOWER) let sql = ` SELECT nodes.*, CASE - WHEN name COLLATE NOCASE IN (${names.map(() => '?').join(',')}) THEN 1.0 + WHEN LOWER(name) IN (${lowerNames.map(() => '?').join(',')}) THEN 1.0 ELSE 0.9 END as score FROM nodes - WHERE name COLLATE NOCASE IN (${names.map(() => '?').join(',')}) + WHERE LOWER(name) IN (${lowerNames.map(() => '?').join(',')}) `; - // Duplicate names for both SELECT and WHERE clauses - const params: (string | number)[] = [...names, ...names]; + // Duplicate lowerNames for both SELECT and WHERE clauses + const params: (string | number)[] = [...lowerNames, ...lowerNames]; if (kinds && kinds.length > 0) { sql += ` AND kind IN (${kinds.map(() => '?').join(',')})`; @@ -656,7 +644,7 @@ export class QueryBuilder { sql += ' ORDER BY score DESC, length(name) ASC LIMIT ?'; params.push(limit); - const rows = this.db.prepare(sql).all(...params) as (NodeRow & { score: number })[]; + const rows = await this.db.prepare(sql).all(...params) as (NodeRow & { score: number })[]; return rows.map((row) => ({ node: rowToNode(row), @@ -671,7 +659,7 @@ export class QueryBuilder { /** * Insert a new edge */ - insertEdge(edge: Edge): void { + async insertEdge(edge: Edge): Promise { if (!this.stmts.insertEdge) { this.stmts.insertEdge = this.db.prepare(` INSERT OR IGNORE INTO edges (source, 
target, kind, metadata, line, col, provenance) @@ -679,7 +667,7 @@ export class QueryBuilder { `); } - this.stmts.insertEdge.run({ + await this.stmts.insertEdge.run({ source: edge.source, target: edge.target, kind: edge.kind, @@ -693,28 +681,28 @@ export class QueryBuilder { /** * Insert multiple edges in a transaction */ - insertEdges(edges: Edge[]): void { - this.db.transaction(() => { + async insertEdges(edges: Edge[]): Promise { + await this.db.transaction(async () => { for (const edge of edges) { - this.insertEdge(edge); + await this.insertEdge(edge); } - })(); + }); } /** * Delete all edges from a source node */ - deleteEdgesBySource(sourceId: string): void { + async deleteEdgesBySource(sourceId: string): Promise { if (!this.stmts.deleteEdgesBySource) { this.stmts.deleteEdgesBySource = this.db.prepare('DELETE FROM edges WHERE source = ?'); } - this.stmts.deleteEdgesBySource.run(sourceId); + await this.stmts.deleteEdgesBySource.run(sourceId); } /** * Get outgoing edges from a node */ - getOutgoingEdges(sourceId: string, kinds?: EdgeKind[], provenance?: string): Edge[] { + async getOutgoingEdges(sourceId: string, kinds?: EdgeKind[], provenance?: string): Promise { if ((kinds && kinds.length > 0) || provenance) { let sql = 'SELECT * FROM edges WHERE source = ?'; const params: (string | number)[] = [sourceId]; @@ -729,31 +717,31 @@ export class QueryBuilder { params.push(provenance); } - const rows = this.db.prepare(sql).all(...params) as EdgeRow[]; + const rows = await this.db.prepare(sql).all(...params) as EdgeRow[]; return rows.map(rowToEdge); } if (!this.stmts.getEdgesBySource) { this.stmts.getEdgesBySource = this.db.prepare('SELECT * FROM edges WHERE source = ?'); } - const rows = this.stmts.getEdgesBySource.all(sourceId) as EdgeRow[]; + const rows = await this.stmts.getEdgesBySource.all(sourceId) as EdgeRow[]; return rows.map(rowToEdge); } /** * Get incoming edges to a node */ - getIncomingEdges(targetId: string, kinds?: EdgeKind[]): Edge[] { + async 
getIncomingEdges(targetId: string, kinds?: EdgeKind[]): Promise { if (kinds && kinds.length > 0) { const sql = `SELECT * FROM edges WHERE target = ? AND kind IN (${kinds.map(() => '?').join(',')})`; - const rows = this.db.prepare(sql).all(targetId, ...kinds) as EdgeRow[]; + const rows = await this.db.prepare(sql).all(targetId, ...kinds) as EdgeRow[]; return rows.map(rowToEdge); } if (!this.stmts.getEdgesByTarget) { this.stmts.getEdgesByTarget = this.db.prepare('SELECT * FROM edges WHERE target = ?'); } - const rows = this.stmts.getEdgesByTarget.all(targetId) as EdgeRow[]; + const rows = await this.stmts.getEdgesByTarget.all(targetId) as EdgeRow[]; return rows.map(rowToEdge); } @@ -764,7 +752,7 @@ export class QueryBuilder { /** * Insert or update a file record */ - upsertFile(file: FileRecord): void { + async upsertFile(file: FileRecord): Promise { if (!this.stmts.upsertFile) { this.stmts.upsertFile = this.db.prepare(` INSERT INTO files (path, content_hash, language, size, modified_at, indexed_at, node_count, errors) @@ -780,7 +768,7 @@ export class QueryBuilder { `); } - this.stmts.upsertFile.run({ + await this.stmts.upsertFile.run({ path: file.path, contentHash: file.contentHash, language: file.language, @@ -795,43 +783,43 @@ export class QueryBuilder { /** * Delete a file record and its nodes */ - deleteFile(filePath: string): void { - this.db.transaction(() => { - this.deleteNodesByFile(filePath); + async deleteFile(filePath: string): Promise { + await this.db.transaction(async () => { + await this.deleteNodesByFile(filePath); if (!this.stmts.deleteFile) { this.stmts.deleteFile = this.db.prepare('DELETE FROM files WHERE path = ?'); } - this.stmts.deleteFile.run(filePath); - })(); + await this.stmts.deleteFile.run(filePath); + }); } /** * Get a file record by path */ - getFileByPath(filePath: string): FileRecord | null { + async getFileByPath(filePath: string): Promise { if (!this.stmts.getFileByPath) { this.stmts.getFileByPath = this.db.prepare('SELECT * FROM 
files WHERE path = ?'); } - const row = this.stmts.getFileByPath.get(filePath) as FileRow | undefined; + const row = await this.stmts.getFileByPath.get(filePath) as FileRow | undefined; return row ? rowToFileRecord(row) : null; } /** * Get all tracked files */ - getAllFiles(): FileRecord[] { + async getAllFiles(): Promise { if (!this.stmts.getAllFiles) { this.stmts.getAllFiles = this.db.prepare('SELECT * FROM files ORDER BY path'); } - const rows = this.stmts.getAllFiles.all() as FileRow[]; + const rows = await this.stmts.getAllFiles.all() as FileRow[]; return rows.map(rowToFileRecord); } /** * Get files that need re-indexing (hash changed) */ - getStaleFiles(currentHashes: Map): FileRecord[] { - const files = this.getAllFiles(); + async getStaleFiles(currentHashes: Map): Promise { + const files = await this.getAllFiles(); return files.filter((f) => { const currentHash = currentHashes.get(f.path); return currentHash && currentHash !== f.contentHash; @@ -845,7 +833,7 @@ export class QueryBuilder { /** * Insert an unresolved reference */ - insertUnresolvedRef(ref: UnresolvedReference): void { + async insertUnresolvedRef(ref: UnresolvedReference): Promise { if (!this.stmts.insertUnresolved) { this.stmts.insertUnresolved = this.db.prepare(` INSERT INTO unresolved_refs (from_node_id, reference_name, reference_kind, line, col, candidates, file_path, language) @@ -853,7 +841,7 @@ export class QueryBuilder { `); } - this.stmts.insertUnresolved.run({ + await this.stmts.insertUnresolved.run({ fromNodeId: ref.fromNodeId, referenceName: ref.referenceName, referenceKind: ref.referenceKind, @@ -868,77 +856,58 @@ export class QueryBuilder { /** * Insert multiple unresolved references in a transaction */ - insertUnresolvedRefsBatch(refs: UnresolvedReference[]): void { + async insertUnresolvedRefsBatch(refs: UnresolvedReference[]): Promise { if (refs.length === 0) return; - const insert = this.db.transaction(() => { + await this.db.transaction(async () => { for (const ref of refs) 
{ - this.insertUnresolvedRef(ref); + await this.insertUnresolvedRef(ref); } }); - insert(); } /** * Delete unresolved references from a node */ - deleteUnresolvedByNode(nodeId: string): void { + async deleteUnresolvedByNode(nodeId: string): Promise { if (!this.stmts.deleteUnresolvedByNode) { this.stmts.deleteUnresolvedByNode = this.db.prepare( 'DELETE FROM unresolved_refs WHERE from_node_id = ?' ); } - this.stmts.deleteUnresolvedByNode.run(nodeId); + await this.stmts.deleteUnresolvedByNode.run(nodeId); } /** * Get unresolved references by name (for resolution) */ - getUnresolvedByName(name: string): UnresolvedReference[] { + async getUnresolvedByName(name: string): Promise { if (!this.stmts.getUnresolvedByName) { this.stmts.getUnresolvedByName = this.db.prepare( 'SELECT * FROM unresolved_refs WHERE reference_name = ?' ); } - const rows = this.stmts.getUnresolvedByName.all(name) as UnresolvedRefRow[]; - return rows.map((row) => ({ - fromNodeId: row.from_node_id, - referenceName: row.reference_name, - referenceKind: row.reference_kind as EdgeKind, - line: row.line, - column: row.col, - candidates: row.candidates ? safeJsonParse(row.candidates, undefined) : undefined, - filePath: row.file_path, - language: row.language as Language, - })); + const rows = await this.stmts.getUnresolvedByName.all(name) as UnresolvedRefRow[]; + return rows.map(rowToUnresolvedRef); } /** * Get all unresolved references */ - getUnresolvedReferences(): UnresolvedReference[] { - const rows = this.db.prepare('SELECT * FROM unresolved_refs').all() as UnresolvedRefRow[]; - return rows.map((row) => ({ - fromNodeId: row.from_node_id, - referenceName: row.reference_name, - referenceKind: row.reference_kind as EdgeKind, - line: row.line, - column: row.col, - candidates: row.candidates ? 
safeJsonParse(row.candidates, undefined) : undefined, - filePath: row.file_path, - language: row.language as Language, - })); + async getUnresolvedReferences(): Promise { + const rows = await this.db.prepare('SELECT * FROM unresolved_refs').all() as UnresolvedRefRow[]; + return rows.map(rowToUnresolvedRef); } /** * Get the count of unresolved references without loading them into memory */ - getUnresolvedReferencesCount(): number { + async getUnresolvedReferencesCount(): Promise { if (!this.stmts.getUnresolvedCount) { this.stmts.getUnresolvedCount = this.db.prepare( 'SELECT COUNT(*) as count FROM unresolved_refs' ); } - const row = this.stmts.getUnresolvedCount.get() as { count: number }; + const row = await this.stmts.getUnresolvedCount.get() as { count: number }; return row.count; } @@ -946,33 +915,24 @@ export class QueryBuilder { * Get a batch of unresolved references using LIMIT/OFFSET pagination. * Used to process references in bounded memory chunks. */ - getUnresolvedReferencesBatch(offset: number, limit: number): UnresolvedReference[] { + async getUnresolvedReferencesBatch(offset: number, limit: number): Promise { if (!this.stmts.getUnresolvedBatch) { this.stmts.getUnresolvedBatch = this.db.prepare( 'SELECT * FROM unresolved_refs LIMIT ? OFFSET ?' ); } - const rows = this.stmts.getUnresolvedBatch.all(limit, offset) as UnresolvedRefRow[]; - return rows.map((row) => ({ - fromNodeId: row.from_node_id, - referenceName: row.reference_name, - referenceKind: row.reference_kind as EdgeKind, - line: row.line, - column: row.col, - candidates: row.candidates ? 
safeJsonParse(row.candidates, undefined) : undefined, - filePath: row.file_path, - language: row.language as Language, - })); + const rows = await this.stmts.getUnresolvedBatch.all(limit, offset) as UnresolvedRefRow[]; + return rows.map(rowToUnresolvedRef); } /** - * Get all tracked file paths (lightweight — no full FileRecord objects) + * Get all tracked file paths (lightweight -- no full FileRecord objects) */ - getAllFilePaths(): string[] { + async getAllFilePaths(): Promise { if (!this.stmts.getAllFilePaths) { this.stmts.getAllFilePaths = this.db.prepare('SELECT path FROM files ORDER BY path'); } - const rows = this.stmts.getAllFilePaths.all() as Array<{ path: string }>; + const rows = await this.stmts.getAllFilePaths.all() as Array<{ path: string }>; return rows.map((r) => r.path); } @@ -980,57 +940,47 @@ export class QueryBuilder { * Get unresolved references scoped to specific file paths. * Uses the idx_unresolved_file_path index for efficient lookup. */ - getUnresolvedReferencesByFiles(filePaths: string[]): UnresolvedReference[] { + async getUnresolvedReferencesByFiles(filePaths: string[]): Promise { if (filePaths.length === 0) return []; const placeholders = filePaths.map(() => '?').join(','); - const rows = this.db + const rows = await this.db .prepare(`SELECT * FROM unresolved_refs WHERE file_path IN (${placeholders})`) .all(...filePaths) as UnresolvedRefRow[]; - return rows.map((row) => ({ - fromNodeId: row.from_node_id, - referenceName: row.reference_name, - referenceKind: row.reference_kind as EdgeKind, - line: row.line, - column: row.col, - candidates: row.candidates ? 
safeJsonParse(row.candidates, undefined) : undefined, - filePath: row.file_path, - language: row.language as Language, - })); + return rows.map(rowToUnresolvedRef); } /** * Delete all unresolved references (after resolution) */ - clearUnresolvedReferences(): void { - this.db.exec('DELETE FROM unresolved_refs'); + async clearUnresolvedReferences(): Promise { + await this.db.exec('DELETE FROM unresolved_refs'); } /** * Delete resolved references by their IDs */ - deleteResolvedReferences(fromNodeIds: string[]): void { + async deleteResolvedReferences(fromNodeIds: string[]): Promise { if (fromNodeIds.length === 0) return; const placeholders = fromNodeIds.map(() => '?').join(','); - this.db.prepare(`DELETE FROM unresolved_refs WHERE from_node_id IN (${placeholders})`).run(...fromNodeIds); + await this.db.prepare(`DELETE FROM unresolved_refs WHERE from_node_id IN (${placeholders})`).run(...fromNodeIds); } /** * Delete specific resolved references by (fromNodeId, referenceName, referenceKind) tuples. - * More precise than deleteResolvedReferences — only removes refs that were actually resolved. + * More precise than deleteResolvedReferences -- only removes refs that were actually resolved. */ - deleteSpecificResolvedReferences(refs: Array<{ fromNodeId: string; referenceName: string; referenceKind: string }>): void { + async deleteSpecificResolvedReferences(refs: Array<{ fromNodeId: string; referenceName: string; referenceKind: string }>): Promise { if (refs.length === 0) return; - const stmt = this.db.prepare( - 'DELETE FROM unresolved_refs WHERE from_node_id = ? AND reference_name = ? AND reference_kind = ?' - ); - const deleteMany = this.db.transaction((items: typeof refs) => { - for (const ref of items) { - stmt.run(ref.fromNodeId, ref.referenceName, ref.referenceKind); + await this.db.transaction(async () => { + const stmt = this.db.prepare( + 'DELETE FROM unresolved_refs WHERE from_node_id = ? AND reference_name = ? AND reference_kind = ?' 
+ ); + for (const ref of refs) { + await stmt.run(ref.fromNodeId, ref.referenceName, ref.referenceKind); } }); - deleteMany(refs); } // =========================================================================== @@ -1040,9 +990,9 @@ export class QueryBuilder { /** * Get graph statistics */ - getStats(): GraphStats { + async getStats(): Promise { // Single query for all three aggregate counts - const counts = this.db.prepare(` + const counts = await this.db.prepare(` SELECT (SELECT COUNT(*) FROM nodes) AS node_count, (SELECT COUNT(*) FROM edges) AS edge_count, @@ -1050,7 +1000,7 @@ export class QueryBuilder { `).get() as { node_count: number; edge_count: number; file_count: number }; const nodesByKind = {} as Record; - const nodeKindRows = this.db + const nodeKindRows = await this.db .prepare('SELECT kind, COUNT(*) as count FROM nodes GROUP BY kind') .all() as Array<{ kind: string; count: number }>; for (const row of nodeKindRows) { @@ -1058,7 +1008,7 @@ export class QueryBuilder { } const edgesByKind = {} as Record; - const edgeKindRows = this.db + const edgeKindRows = await this.db .prepare('SELECT kind, COUNT(*) as count FROM edges GROUP BY kind') .all() as Array<{ kind: string; count: number }>; for (const row of edgeKindRows) { @@ -1066,7 +1016,7 @@ export class QueryBuilder { } const filesByLanguage = {} as Record; - const languageRows = this.db + const languageRows = await this.db .prepare('SELECT language, COUNT(*) as count FROM files GROUP BY language') .all() as Array<{ language: string; count: number }>; for (const row of languageRows) { @@ -1092,16 +1042,16 @@ export class QueryBuilder { /** * Get a metadata value by key */ - getMetadata(key: string): string | null { - const row = this.db.prepare('SELECT value FROM project_metadata WHERE key = ?').get(key) as { value: string } | undefined; + async getMetadata(key: string): Promise { + const row = await this.db.prepare('SELECT value FROM project_metadata WHERE key = ?').get(key) as { value: string } | 
undefined;
     return row?.value ?? null;
   }

   /**
    * Set a metadata key-value pair (upsert)
    */
-  setMetadata(key: string, value: string): void {
-    this.db.prepare(
+  async setMetadata(key: string, value: string): Promise<void> {
+    await this.db.prepare(
       'INSERT INTO project_metadata (key, value, updated_at) VALUES (?, ?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at'
     ).run(key, value, Date.now());
   }
@@ -1109,8 +1059,8 @@ export class QueryBuilder {
   /**
    * Get all metadata as a key-value record
    */
-  getAllMetadata(): Record<string, string> {
-    const rows = this.db.prepare('SELECT key, value FROM project_metadata').all() as { key: string; value: string }[];
+  async getAllMetadata(): Promise<Record<string, string>> {
+    const rows = await this.db.prepare('SELECT key, value FROM project_metadata').all() as { key: string; value: string }[];
     const result: Record<string, string> = {};
     for (const row of rows) {
       result[row.key] = row.value;
@@ -1121,14 +1071,14 @@
   /**
    * Clear all data from the database
    */
-  clear(): void {
+  async clear(): Promise<void> {
     this.nodeCache.clear();
-    this.db.transaction(() => {
-      this.db.exec('DELETE FROM unresolved_refs');
-      this.db.exec('DELETE FROM vectors');
-      this.db.exec('DELETE FROM edges');
-      this.db.exec('DELETE FROM nodes');
-      this.db.exec('DELETE FROM files');
-    })();
+    await this.db.transaction(async () => {
+      await this.db.exec('DELETE FROM unresolved_refs');
+      await this.db.exec('DELETE FROM vectors');
+      await this.db.exec('DELETE FROM edges');
+      await this.db.exec('DELETE FROM nodes');
+      await this.db.exec('DELETE FROM files');
+    });
   }
 }
diff --git a/src/db/sqlite-db-adapter.ts b/src/db/sqlite-db-adapter.ts
new file mode 100644
index 00000000..47f7ab96
--- /dev/null
+++ b/src/db/sqlite-db-adapter.ts
@@ -0,0 +1,155 @@
+/**
+ * SQLite Database Adapter
+ *
+ * Wraps the existing SqliteDatabase (better-sqlite3 or WASM) into the
+ * async DbAdapter interface. Synchronous calls are wrapped in resolved Promises.
+ */
+
+import { DbAdapter, DbStatement, FtsSearchOptions, FtsSearchResult, RunResult } from './adapter';
+import { SqliteDatabase } from './sqlite-adapter';
+
+/**
+ * Wrap a synchronous SqliteStatement into an async DbStatement.
+ */
+function wrapStatement(syncStmt: ReturnType<SqliteDatabase['prepare']>): DbStatement {
+  return {
+    async run(...params: any[]): Promise<RunResult> {
+      const result = syncStmt.run(...params);
+      return {
+        changes: result.changes,
+        lastInsertRowid: result.lastInsertRowid,
+      };
+    },
+    async get(...params: any[]): Promise<any> {
+      return syncStmt.get(...params);
+    },
+    async all(...params: any[]): Promise<any[]> {
+      return syncStmt.all(...params);
+    },
+  };
+}
+
+/**
+ * SQLite implementation of DbAdapter.
+ *
+ * Wraps the existing SqliteDatabase (better-sqlite3 native or WASM fallback)
+ * in the unified async interface. All async methods resolve immediately since
+ * SQLite operations are synchronous.
+ */
+export class SqliteDbAdapter implements DbAdapter {
+  readonly backendType = 'sqlite' as const;
+  private db: SqliteDatabase;
+
+  constructor(db: SqliteDatabase) {
+    this.db = db;
+  }
+
+  get open(): boolean {
+    return this.db.open;
+  }
+
+  /**
+   * Get the underlying SqliteDatabase for direct access.
+   * Used during migration from the old API.
+   */
+  getDb(): SqliteDatabase {
+    return this.db;
+  }
+
+  prepare(sql: string): DbStatement {
+    const stmt = this.db.prepare(sql);
+    return wrapStatement(stmt);
+  }
+
+  async exec(sql: string): Promise<void> {
+    this.db.exec(sql);
+  }
+
+  async transaction<T>(fn: () => Promise<T>): Promise<T> {
+    // better-sqlite3's transaction() expects a synchronous function, so we
+    // manage BEGIN/COMMIT/ROLLBACK manually to support async callbacks
+    // (even though they resolve immediately for the SQLite backend).
+    this.db.exec('BEGIN');
+    try {
+      const result = await fn();
+      this.db.exec('COMMIT');
+      return result;
+    } catch (error) {
+      this.db.exec('ROLLBACK');
+      throw error;
+    }
+  }
+
+  async close(): Promise<void> {
+    this.db.close();
+  }
+
+  /**
+   * Execute a SQLite pragma.
+   * This is SQLite-specific and not part of the DbAdapter interface.
+   */
+  pragma(str: string): any {
+    return this.db.pragma(str);
+  }
+
+  /**
+   * FTS5 full-text search.
+   *
+   * Executes an FTS5 MATCH query with bm25() scoring against the nodes_fts
+   * virtual table, joined with the nodes table for full row data.
+   */
+  async ftsSearch(query: string, options: FtsSearchOptions): Promise<FtsSearchResult[]> {
+    const { kinds, languages, limit, offset } = options;
+
+    // Build FTS5 query: escape special chars and add prefix wildcards
+    const ftsQuery = query
+      .replace(/['"*():^]/g, '')
+      .split(/\s+/)
+      .filter(term => term.length > 0)
+      .filter(term => !/^(AND|OR|NOT|NEAR)$/i.test(term))
+      .map(term => `"${term}"*`)
+      .join(' OR ');
+
+    if (!ftsQuery) {
+      return [];
+    }
+
+    let sql = `
+      SELECT nodes.*, bm25(nodes_fts) as score
+      FROM nodes_fts
+      JOIN nodes ON nodes_fts.id = nodes.id
+      WHERE nodes_fts MATCH ?
+    `;
+
+    const params: (string | number)[] = [ftsQuery];
+
+    if (kinds && kinds.length > 0) {
+      sql += ` AND nodes.kind IN (${kinds.map(() => '?').join(',')})`;
+      params.push(...kinds);
+    }
+
+    if (languages && languages.length > 0) {
+      sql += ` AND nodes.language IN (${languages.map(() => '?').join(',')})`;
+      params.push(...languages);
+    }
+
+    sql += ' ORDER BY score LIMIT ? OFFSET ?';
+    params.push(limit, offset);
+
+    try {
+      const rows = this.db.prepare(sql).all(...params) as any[];
+      return rows.map(row => ({
+        row,
+        score: Math.abs(row.score), // bm25 returns negative scores
+      }));
+    } catch {
+      // FTS query failed (e.g., invalid query syntax), return empty
+      return [];
+    }
+  }
+}
diff --git a/src/extraction/index.ts b/src/extraction/index.ts
index c026bce7..7f48dd59 100644
--- a/src/extraction/index.ts
+++ b/src/extraction/index.ts
@@ -697,7 +697,7 @@ export class ExtractionOrchestrator {
     // Store in database on main thread (SQLite is not thread-safe)
     if (result.nodes.length > 0 || result.errors.length === 0) {
       const language = detectLanguage(filePath);
-      this.storeExtractionResult(filePath, content, language, stats, result);
+      await this.storeExtractionResult(filePath, content, language, stats, result);
     }

     if (result.errors.length > 0) {
@@ -759,7 +759,7 @@
       if (result.nodes.length > 0 || result.errors.length === 0) {
         const language = detectLanguage(filePath);
         const stats = await fsp.stat(path.join(this.rootDir, filePath));
-        this.storeExtractionResult(filePath, content, language, stats, result);
+        await this.storeExtractionResult(filePath, content, language, stats, result);

         const idx = errors.indexOf(errEntry);
         if (idx >= 0) errors.splice(idx, 1);
@@ -810,7 +810,7 @@
       if (result.nodes.length > 0 || result.errors.length === 0) {
         const language = detectLanguage(filePath);
         const stats = await fsp.stat(path.join(this.rootDir, filePath));
-        this.storeExtractionResult(filePath, fullContent, language, stats, result);
+        await this.storeExtractionResult(filePath, fullContent, language, stats, result);

         const idx = errors.indexOf(errEntry);
         if (idx >= 0) errors.splice(idx, 1);
@@ -992,7 +992,7 @@
     // Store in database
     if (result.nodes.length > 0 || result.errors.length === 0) {
-      this.storeExtractionResult(relativePath, content,
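
The term-sanitization step inside `ftsSearch` can be exercised on its own. A minimal sketch — the helper name `buildFtsQuery` is hypothetical (the real adapter inlines this logic in `ftsSearch`): strip FTS5 metacharacters, drop bare boolean keywords, quote each term, and append a prefix wildcard so partial identifiers match.

```typescript
// Hypothetical standalone version of the query-building step in ftsSearch.
function buildFtsQuery(query: string): string {
  return query
    .replace(/['"*():^]/g, '')                       // strip FTS5 metacharacters
    .split(/\s+/)
    .filter((term) => term.length > 0)
    .filter((term) => !/^(AND|OR|NOT|NEAR)$/i.test(term)) // drop bare boolean keywords
    .map((term) => `"${term}"*`)                     // quote + prefix wildcard
    .join(' OR ');
}

// buildFtsQuery('graph traversal') → '"graph"* OR "traversal"*'
// buildFtsQuery('AND')             → '' (caller returns [] instead of running MATCH)
```

An empty result signals the caller to skip the MATCH query entirely, which is why `ftsSearch` returns `[]` when `ftsQuery` is falsy.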
language, stats, result); + await this.storeExtractionResult(relativePath, content, language, stats, result); } return result; @@ -1001,24 +1001,24 @@ export class ExtractionOrchestrator { /** * Store extraction result in database */ - private storeExtractionResult( + private async storeExtractionResult( filePath: string, content: string, language: Language, stats: fs.Stats, result: ExtractionResult - ): void { + ): Promise { const contentHash = hashContent(content); // Check if file already exists and hasn't changed - const existingFile = this.queries.getFileByPath(filePath); + const existingFile = await this.queries.getFileByPath(filePath); if (existingFile && existingFile.contentHash === contentHash) { return; // No changes } // Delete existing data for this file if (existingFile) { - this.queries.deleteFile(filePath); + await this.queries.deleteFile(filePath); } // Filter out nodes with missing required fields before insertion. @@ -1028,7 +1028,7 @@ export class ExtractionOrchestrator { // Insert nodes if (validNodes.length > 0) { - this.queries.insertNodes(validNodes); + await this.queries.insertNodes(validNodes); } // Filter edges to only reference nodes that were actually inserted @@ -1038,7 +1038,7 @@ export class ExtractionOrchestrator { (e) => insertedIds.has(e.source) && insertedIds.has(e.target) ); if (validEdges.length > 0) { - this.queries.insertEdges(validEdges); + await this.queries.insertEdges(validEdges); } } @@ -1053,7 +1053,7 @@ export class ExtractionOrchestrator { language: ref.language ?? language, })); if (refsWithContext.length > 0) { - this.queries.insertUnresolvedRefsBatch(refsWithContext); + await this.queries.insertUnresolvedRefsBatch(refsWithContext); } } @@ -1068,7 +1068,7 @@ export class ExtractionOrchestrator { nodeCount: result.nodes.length, errors: result.errors.length > 0 ? 
result.errors : undefined, }; - this.queries.upsertFile(fileRecord); + await this.queries.upsertFile(fileRecord); } /** @@ -1101,9 +1101,9 @@ export class ExtractionOrchestrator { // Handle deleted files for (const filePath of gitChanges.deleted) { - const tracked = this.queries.getFileByPath(filePath); + const tracked = await this.queries.getFileByPath(filePath); if (tracked) { - this.queries.deleteFile(filePath); + await this.queries.deleteFile(filePath); filesRemoved++; } } @@ -1120,7 +1120,7 @@ export class ExtractionOrchestrator { } const contentHash = hashContent(content); - const tracked = this.queries.getFileByPath(filePath); + const tracked = await this.queries.getFileByPath(filePath); if (!tracked) { filesToIndex.push(filePath); @@ -1145,7 +1145,7 @@ export class ExtractionOrchestrator { filesChecked = currentFiles.size; // Build Map for O(1) lookups instead of .find() per file - const trackedFiles = this.queries.getAllFiles(); + const trackedFiles = await this.queries.getAllFiles(); const trackedMap = new Map(); for (const f of trackedFiles) { trackedMap.set(f.path, f); @@ -1154,7 +1154,7 @@ export class ExtractionOrchestrator { // Find files to remove (in DB but not on disk) for (const tracked of trackedFiles) { if (!currentFiles.has(tracked.path)) { - this.queries.deleteFile(tracked.path); + await this.queries.deleteFile(tracked.path); filesRemoved++; } } @@ -1221,7 +1221,7 @@ export class ExtractionOrchestrator { * Get files that have changed since last index. * Uses git status as a fast path when available, falling back to full scan. 
*/ - getChangedFiles(): { added: string[]; modified: string[]; removed: string[] } { + async getChangedFiles(): Promise<{ added: string[]; modified: string[]; removed: string[] }> { const gitChanges = getGitChangedFiles(this.rootDir, this.config); if (gitChanges) { @@ -1232,7 +1232,7 @@ export class ExtractionOrchestrator { // Deleted files — only report if tracked in DB for (const filePath of gitChanges.deleted) { - const tracked = this.queries.getFileByPath(filePath); + const tracked = await this.queries.getFileByPath(filePath); if (tracked) { removed.push(filePath); } @@ -1250,7 +1250,7 @@ export class ExtractionOrchestrator { } const contentHash = hashContent(content); - const tracked = this.queries.getFileByPath(filePath); + const tracked = await this.queries.getFileByPath(filePath); if (!tracked) { added.push(filePath); @@ -1269,7 +1269,7 @@ export class ExtractionOrchestrator { // === Fallback: full scan (non-git project or git failure) === const currentFiles = new Set(scanDirectory(this.rootDir, this.config)); - const trackedFiles = this.queries.getAllFiles(); + const trackedFiles = await this.queries.getAllFiles(); // Build Map for O(1) lookups const trackedMap = new Map(); diff --git a/src/graph/queries.ts b/src/graph/queries.ts index c39e2e32..3fcb5bf6 100644 --- a/src/graph/queries.ts +++ b/src/graph/queries.ts @@ -29,42 +29,42 @@ export class GraphQueryManager { * @param nodeId - ID of the focal node * @returns Context object with all related information */ - getContext(nodeId: string): Context { - const focal = this.queries.getNodeById(nodeId); + async getContext(nodeId: string): Promise { + const focal = await this.queries.getNodeById(nodeId); if (!focal) { throw new Error(`Node not found: ${nodeId}`); } // Get ancestors (containment hierarchy) - const ancestors = this.traverser.getAncestors(nodeId); + const ancestors = await this.traverser.getAncestors(nodeId); // Get children - const children = this.traverser.getChildren(nodeId); + const children = 
await this.traverser.getChildren(nodeId); // Get incoming references (things that reference this node) - const incomingEdges = this.queries.getIncomingEdges(nodeId); + const incomingEdges = await this.queries.getIncomingEdges(nodeId); const incomingRefs: Array<{ node: Node; edge: Edge }> = []; for (const edge of incomingEdges) { // Skip containment edges (already in ancestors) if (edge.kind === 'contains') { continue; } - const node = this.queries.getNodeById(edge.source); + const node = await this.queries.getNodeById(edge.source); if (node) { incomingRefs.push({ node, edge }); } } // Get outgoing references (things this node references) - const outgoingEdges = this.queries.getOutgoingEdges(nodeId); + const outgoingEdges = await this.queries.getOutgoingEdges(nodeId); const outgoingRefs: Array<{ node: Node; edge: Edge }> = []; for (const edge of outgoingEdges) { // Skip containment edges (already in children) if (edge.kind === 'contains') { continue; } - const node = this.queries.getNodeById(edge.target); + const node = await this.queries.getNodeById(edge.target); if (node) { outgoingRefs.push({ node, edge }); } @@ -74,9 +74,9 @@ export class GraphQueryManager { const types: Node[] = []; const typeEdgeKinds: EdgeKind[] = ['type_of', 'returns']; for (const kind of typeEdgeKinds) { - const typeEdges = this.queries.getOutgoingEdges(nodeId, [kind]); + const typeEdges = await this.queries.getOutgoingEdges(nodeId, [kind]); for (const edge of typeEdges) { - const typeNode = this.queries.getNodeById(edge.target); + const typeNode = await this.queries.getNodeById(edge.target); if (typeNode && !types.some((t) => t.id === typeNode.id)) { types.push(typeNode); } @@ -87,9 +87,9 @@ export class GraphQueryManager { const imports: Node[] = []; const fileNode = ancestors.find((a) => a.kind === 'file'); if (fileNode) { - const importEdges = this.queries.getOutgoingEdges(fileNode.id, ['imports']); + const importEdges = await this.queries.getOutgoingEdges(fileNode.id, ['imports']); for 
(const edge of importEdges) { - const importNode = this.queries.getNodeById(edge.target); + const importNode = await this.queries.getNodeById(edge.target); if (importNode) { imports.push(importNode); } @@ -115,8 +115,8 @@ export class GraphQueryManager { * @param filePath - Path to the file * @returns Array of file paths this file depends on */ - getFileDependencies(filePath: string): string[] { - const nodes = this.queries.getNodesByFile(filePath); + async getFileDependencies(filePath: string): Promise { + const nodes = await this.queries.getNodesByFile(filePath); const fileNode = nodes.find((n) => n.kind === 'file'); if (!fileNode) { @@ -124,10 +124,10 @@ export class GraphQueryManager { } const dependencies = new Set(); - const importEdges = this.queries.getOutgoingEdges(fileNode.id, ['imports']); + const importEdges = await this.queries.getOutgoingEdges(fileNode.id, ['imports']); for (const edge of importEdges) { - const targetNode = this.queries.getNodeById(edge.target); + const targetNode = await this.queries.getNodeById(edge.target); if (targetNode && targetNode.filePath !== filePath) { dependencies.add(targetNode.filePath); } @@ -144,16 +144,16 @@ export class GraphQueryManager { * @param filePath - Path to the file * @returns Array of file paths that depend on this file */ - getFileDependents(filePath: string): string[] { - const nodes = this.queries.getNodesByFile(filePath); + async getFileDependents(filePath: string): Promise { + const nodes = await this.queries.getNodesByFile(filePath); const dependents = new Set(); // Check file-level incoming import edges (file:X imports file:Y) const fileNode = nodes.find((n) => n.kind === 'file'); if (fileNode) { - const incomingFileEdges = this.queries.getIncomingEdges(fileNode.id, ['imports']); + const incomingFileEdges = await this.queries.getIncomingEdges(fileNode.id, ['imports']); for (const edge of incomingFileEdges) { - const sourceNode = this.queries.getNodeById(edge.source); + const sourceNode = await 
this.queries.getNodeById(edge.source); if (sourceNode && sourceNode.filePath !== filePath) { dependents.add(sourceNode.filePath); } @@ -163,9 +163,9 @@ export class GraphQueryManager { // Also check node-level imports of exported symbols for (const node of nodes) { if (node.isExported) { - const incomingEdges = this.queries.getIncomingEdges(node.id, ['imports']); + const incomingEdges = await this.queries.getIncomingEdges(node.id, ['imports']); for (const edge of incomingEdges) { - const sourceNode = this.queries.getNodeById(edge.source); + const sourceNode = await this.queries.getNodeById(edge.source); if (sourceNode && sourceNode.filePath !== filePath) { dependents.add(sourceNode.filePath); } @@ -182,8 +182,8 @@ export class GraphQueryManager { * @param filePath - Path to the file * @returns Array of exported nodes */ - getExportedSymbols(filePath: string): Node[] { - const nodes = this.queries.getNodesByFile(filePath); + async getExportedSymbols(filePath: string): Promise { + const nodes = await this.queries.getNodesByFile(filePath); return nodes.filter((n) => n.isExported); } @@ -193,7 +193,7 @@ export class GraphQueryManager { * @param pattern - Pattern to match (supports * wildcard) * @returns Array of matching nodes */ - findByQualifiedName(pattern: string): Node[] { + async findByQualifiedName(pattern: string): Promise { // Convert glob pattern to regex const regexPattern = pattern .replace(/[.+^${}()|[\]\\]/g, '\\$&') @@ -216,7 +216,7 @@ export class GraphQueryManager { ]; for (const kind of kinds) { - const nodes = this.queries.getNodesByKind(kind); + const nodes = await this.queries.getNodesByKind(kind); for (const node of nodes) { if (regex.test(node.qualifiedName)) { allNodes.push(node); @@ -234,8 +234,8 @@ export class GraphQueryManager { * * @returns Map of directory paths to contained files */ - getModuleStructure(): Map { - const files = this.queries.getAllFiles(); + async getModuleStructure(): Promise> { + const files = await 
this.queries.getAllFiles(); const structure = new Map(); for (const file of files) { @@ -256,13 +256,13 @@ export class GraphQueryManager { * * @returns Array of cycles, each cycle is an array of node IDs */ - findCircularDependencies(): string[][] { - const files = this.queries.getAllFiles(); + async findCircularDependencies(): Promise { + const files = await this.queries.getAllFiles(); const cycles: string[][] = []; const visited = new Set(); const recursionStack = new Set(); - const dfs = (filePath: string, path: string[]): void => { + const dfs = async (filePath: string, path: string[]): Promise => { if (recursionStack.has(filePath)) { // Found a cycle const cycleStart = path.indexOf(filePath); @@ -279,9 +279,9 @@ export class GraphQueryManager { visited.add(filePath); recursionStack.add(filePath); - const dependencies = this.getFileDependencies(filePath); + const dependencies = await this.getFileDependencies(filePath); for (const dep of dependencies) { - dfs(dep, [...path, filePath]); + await dfs(dep, [...path, filePath]); } recursionStack.delete(filePath); @@ -289,7 +289,7 @@ export class GraphQueryManager { for (const file of files) { if (!visited.has(file.path)) { - dfs(file.path, []); + await dfs(file.path, []); } } @@ -302,22 +302,22 @@ export class GraphQueryManager { * @param nodeId - ID of the node * @returns Object containing various complexity metrics */ - getNodeMetrics(nodeId: string): { + async getNodeMetrics(nodeId: string): Promise<{ incomingEdgeCount: number; outgoingEdgeCount: number; callCount: number; callerCount: number; childCount: number; depth: number; - } { - const incomingEdges = this.queries.getIncomingEdges(nodeId); - const outgoingEdges = this.queries.getOutgoingEdges(nodeId); + }> { + const incomingEdges = await this.queries.getIncomingEdges(nodeId); + const outgoingEdges = await this.queries.getOutgoingEdges(nodeId); const callEdges = outgoingEdges.filter((e) => e.kind === 'calls'); const callerEdges = incomingEdges.filter((e) => 
e.kind === 'calls');
     const containsEdges = outgoingEdges.filter((e) => e.kind === 'contains');
-    const ancestors = this.traverser.getAncestors(nodeId);
+    const ancestors = await this.traverser.getAncestors(nodeId);

     return {
       incomingEdgeCount: incomingEdges.length,
@@ -335,19 +335,19 @@ export class GraphQueryManager {
    * @param kinds - Node kinds to check (default: functions, methods, classes)
    * @returns Array of unreferenced nodes
    */
-  findDeadCode(kinds?: Node['kind'][]): Node[] {
+  async findDeadCode(kinds?: Node['kind'][]): Promise<Node[]> {
     const targetKinds = kinds || ['function', 'method', 'class'];
     const deadCode: Node[] = [];

     for (const kind of targetKinds) {
-      const nodes = this.queries.getNodesByKind(kind);
+      const nodes = await this.queries.getNodesByKind(kind);

       for (const node of nodes) {
         // Skip exported symbols (they may be used externally)
         if (node.isExported) {
           continue;
         }

-        const incomingEdges = this.queries.getIncomingEdges(node.id);
+        const incomingEdges = await this.queries.getIncomingEdges(node.id);

         // Filter out containment edges
         const references = incomingEdges.filter((e) => e.kind !== 'contains');
@@ -368,10 +368,10 @@ export class GraphQueryManager {
    * @param includeEdges - Whether to include edges between matching nodes
    * @returns Subgraph containing matching nodes
    */
-  getFilteredSubgraph(
+  async getFilteredSubgraph(
     filter: (node: Node) => boolean,
     includeEdges: boolean = true
-  ): Subgraph {
+  ): Promise<Subgraph> {
     const nodes = new Map<string, Node>();
     const edges: Edge[] = [];
@@ -392,7 +392,7 @@
     ];

     for (const kind of kinds) {
-      const kindNodes = this.queries.getNodesByKind(kind);
+      const kindNodes = await this.queries.getNodesByKind(kind);
       for (const node of kindNodes) {
         if (filter(node)) {
           nodes.set(node.id, node);
@@ -403,7 +403,7 @@
     // Include edges between matching nodes
     if (includeEdges) {
       for (const nodeId of nodes.keys()) {
-        const outgoing = this.queries.getOutgoingEdges(nodeId);
+        const outgoing = await this.queries.getOutgoingEdges(nodeId);
         for (const edge of outgoing) {
           if (nodes.has(edge.target)) {
             edges.push(edge);
diff --git a/src/graph/traversal.ts b/src/graph/traversal.ts
index 7d6723ba..298ef6e8 100644
--- a/src/graph/traversal.ts
+++ b/src/graph/traversal.ts
@@ -45,9 +45,9 @@ export class GraphTraverser {
    * @param options - Traversal options
    * @returns Subgraph containing traversed nodes and edges
    */
-  traverseBFS(startId: string, options: TraversalOptions = {}): Subgraph {
+  async traverseBFS(startId: string, options: TraversalOptions = {}): Promise<Subgraph> {
     const opts = { ...DEFAULT_OPTIONS, ...options };
-    const startNode = this.queries.getNodeById(startId);
+    const startNode = await this.queries.getNodeById(startId);

     if (!startNode) {
       return { nodes: new Map(), edges: [], roots: [] };
@@ -82,7 +82,7 @@
     }

     // Get adjacent edges
-    const adjacentEdges = this.getAdjacentEdges(node.id, opts.direction, opts.edgeKinds);
+    const adjacentEdges = await this.getAdjacentEdges(node.id, opts.direction, opts.edgeKinds);

     for (const adjEdge of adjacentEdges) {
       // Determine next node: for 'both' direction, edges can be either
@@ -93,7 +93,7 @@
         continue;
       }

-      const nextNode = this.queries.getNodeById(nextNodeId);
+      const nextNode = await this.queries.getNodeById(nextNodeId);
       if (!nextNode) {
         continue;
       }
@@ -125,9 +125,9 @@
    * @param options - Traversal options
    * @returns Subgraph containing traversed nodes and edges
    */
-  traverseDFS(startId: string, options: TraversalOptions = {}): Subgraph {
+  async traverseDFS(startId: string, options: TraversalOptions = {}): Promise<Subgraph> {
     const opts = { ...DEFAULT_OPTIONS, ...options };
-    const startNode = this.queries.getNodeById(startId);
+    const startNode = await this.queries.getNodeById(startId);

     if (!startNode) {
       return { nodes: new Map(), edges: [], roots: [] };
@@ -141,7 +141,7 @@
       nodes.set(startNode.id, startNode);
     }

-    this.dfsRecursive(startNode, 0, opts, nodes, edges, visited);
+    await this.dfsRecursive(startNode, 0, opts, nodes, edges, visited);

     return {
       nodes,
@@ -153,14 +153,14 @@
   /**
    * Recursive DFS helper
    */
-  private dfsRecursive(
+  private async dfsRecursive(
     node: Node,
     depth: number,
     opts: Required<TraversalOptions>,
     nodes: Map<string, Node>,
     edges: Edge[],
     visited: Set<string>
-  ): void {
+  ): Promise<void> {
     if (visited.has(node.id) || nodes.size >= opts.limit || depth >= opts.maxDepth) {
       return;
     }
@@ -168,7 +168,7 @@
     visited.add(node.id);

     // Get adjacent edges
-    const adjacentEdges = this.getAdjacentEdges(node.id, opts.direction, opts.edgeKinds);
+    const adjacentEdges = await this.getAdjacentEdges(node.id, opts.direction, opts.edgeKinds);

     for (const edge of adjacentEdges) {
       // Determine next node: for 'both' direction, edges can be either
@@ -179,7 +179,7 @@
         continue;
       }

-      const nextNode = this.queries.getNodeById(nextNodeId);
+      const nextNode = await this.queries.getNodeById(nextNodeId);
       if (!nextNode) {
         continue;
       }
@@ -194,28 +194,28 @@
       edges.push(edge);

       // Recurse
-      this.dfsRecursive(nextNode, depth + 1, opts, nodes, edges, visited);
+      await this.dfsRecursive(nextNode, depth + 1, opts, nodes, edges, visited);
     }
   }

   /**
    * Get adjacent edges based on direction
    */
-  private getAdjacentEdges(
+  private async getAdjacentEdges(
     nodeId: string,
     direction: 'outgoing' | 'incoming' | 'both',
     edgeKinds?: EdgeKind[]
-  ): Edge[] {
+  ): Promise<Edge[]> {
     const kinds = edgeKinds && edgeKinds.length > 0 ? edgeKinds : undefined;

     if (direction === 'outgoing') {
-      return this.queries.getOutgoingEdges(nodeId, kinds);
+      return await this.queries.getOutgoingEdges(nodeId, kinds);
     } else if (direction === 'incoming') {
-      return this.queries.getIncomingEdges(nodeId, kinds);
+      return await this.queries.getIncomingEdges(nodeId, kinds);
     } else {
       // Both directions
-      const outgoing = this.queries.getOutgoingEdges(nodeId, kinds);
-      const incoming = this.queries.getIncomingEdges(nodeId, kinds);
+      const outgoing = await this.queries.getOutgoingEdges(nodeId, kinds);
+      const incoming = await this.queries.getIncomingEdges(nodeId, kinds);
       return [...outgoing, ...incoming];
     }
   }
@@ -227,34 +227,34 @@
    * @param maxDepth - Maximum depth to traverse (default: 1)
    * @returns Array of nodes that call this function
    */
-  getCallers(nodeId: string, maxDepth: number = 1): Array<{ node: Node; edge: Edge }> {
+  async getCallers(nodeId: string, maxDepth: number = 1): Promise<Array<{ node: Node; edge: Edge }>> {
     const result: Array<{ node: Node; edge: Edge }> = [];
     const visited = new Set();
-    this.getCallersRecursive(nodeId, maxDepth, 0, result, visited);
+    await this.getCallersRecursive(nodeId, maxDepth, 0, result, visited);
     return result;
   }

-  private getCallersRecursive(
+  private async getCallersRecursive(
     nodeId: string,
     maxDepth: number,
     currentDepth: number,
     result: Array<{ node: Node; edge: Edge }>,
     visited: Set<string>
-  ): void {
+  ): Promise<void> {
     if (currentDepth >= maxDepth || visited.has(nodeId)) {
       return;
     }
     visited.add(nodeId);

-    const incomingEdges = this.queries.getIncomingEdges(nodeId, ['calls', 'references', 'imports']);
+    const incomingEdges = await this.queries.getIncomingEdges(nodeId, ['calls', 'references', 'imports']);

     for (const edge of incomingEdges) {
-      const callerNode = this.queries.getNodeById(edge.source);
+      const callerNode = await this.queries.getNodeById(edge.source);
       if (callerNode && !visited.has(callerNode.id)) {
         result.push({ node: callerNode, edge });
-        this.getCallersRecursive(callerNode.id, maxDepth, currentDepth + 1, result, visited);
+        await this.getCallersRecursive(callerNode.id, maxDepth, currentDepth + 1, result, visited);
       }
     }
   }
@@ -266,34 +266,34 @@ export class GraphTraverser {
    * @param maxDepth - Maximum depth to traverse (default: 1)
    * @returns Array of nodes called by this function
    */
-  getCallees(nodeId: string, maxDepth: number = 1): Array<{ node: Node; edge: Edge }> {
+  async getCallees(nodeId: string, maxDepth: number = 1): Promise<Array<{ node: Node; edge: Edge }>> {
     const result: Array<{ node: Node; edge: Edge }> = [];
     const visited = new Set();
-    this.getCalleesRecursive(nodeId, maxDepth, 0, result, visited);
+    await this.getCalleesRecursive(nodeId, maxDepth, 0, result, visited);
     return result;
   }

-  private getCalleesRecursive(
+  private async getCalleesRecursive(
     nodeId: string,
     maxDepth: number,
     currentDepth: number,
     result: Array<{ node: Node; edge: Edge }>,
     visited: Set<string>
-  ): void {
+  ): Promise<void> {
     if (currentDepth >= maxDepth || visited.has(nodeId)) {
       return;
     }
     visited.add(nodeId);

-    const outgoingEdges = this.queries.getOutgoingEdges(nodeId, ['calls', 'references', 'imports']);
+    const outgoingEdges = await this.queries.getOutgoingEdges(nodeId, ['calls', 'references', 'imports']);

     for (const edge of outgoingEdges) {
-      const calleeNode = this.queries.getNodeById(edge.target);
+      const calleeNode = await this.queries.getNodeById(edge.target);
       if (calleeNode && !visited.has(calleeNode.id)) {
         result.push({ node: calleeNode, edge });
-        this.getCalleesRecursive(calleeNode.id, maxDepth, currentDepth + 1, result, visited);
+        await this.getCalleesRecursive(calleeNode.id, maxDepth, currentDepth + 1, result, visited);
       }
     }
   }
@@ -305,8 +305,8 @@ export class GraphTraverser {
    * @param depth - Maximum depth in each direction (default: 2)
    * @returns Subgraph containing the call graph
    */
-  getCallGraph(nodeId: string, depth: number = 2): Subgraph {
-    const focalNode = this.queries.getNodeById(nodeId);
+  async getCallGraph(nodeId: string, depth: number = 2): Promise<Subgraph> {
+    const focalNode = await this.queries.getNodeById(nodeId);
     if (!focalNode) {
       return { nodes: new Map(), edges: [], roots: [] };
     }
@@ -318,14 +318,14 @@
     nodes.set(focalNode.id, focalNode);

     // Get callers
-    const callers = this.getCallers(nodeId, depth);
+    const callers = await this.getCallers(nodeId, depth);
     for (const { node, edge } of callers) {
       nodes.set(node.id, node);
       edges.push(edge);
     }

     // Get callees
-    const callees = this.getCallees(nodeId, depth);
+    const callees = await this.getCallees(nodeId, depth);
     for (const { node, edge } of callees) {
       nodes.set(node.id, node);
       edges.push(edge);
@@ -344,8 +344,8 @@ export class GraphTraverser {
    * @param nodeId - ID of the class/interface node
    * @returns Subgraph containing the type hierarchy
    */
-  getTypeHierarchy(nodeId: string): Subgraph {
-    const focalNode = this.queries.getNodeById(nodeId);
+  async getTypeHierarchy(nodeId: string): Promise<Subgraph> {
+    const focalNode = await this.queries.getNodeById(nodeId);
     if (!focalNode) {
       return { nodes: new Map(), edges: [], roots: [] };
     }
@@ -358,10 +358,10 @@
     nodes.set(focalNode.id, focalNode);

     // Get ancestors (what this extends/implements)
-    this.getTypeAncestors(nodeId, nodes, edges, visited);
+    await this.getTypeAncestors(nodeId, nodes, edges, visited);

     // Get descendants (what extends/implements this)
-    this.getTypeDescendants(nodeId, nodes, edges, visited);
+    await this.getTypeDescendants(nodeId, nodes, edges, visited);

     return {
       nodes,
@@ -370,48 +370,48 @@
     };
   }

-  private getTypeAncestors(
+  private async getTypeAncestors(
     nodeId: string,
     nodes: Map<string, Node>,
     edges: Edge[],
     visited: Set<string>
-  ): void {
+  ): Promise<void> {
     if (visited.has(nodeId)) {
       return;
     }
     visited.add(nodeId);

-    const outgoingEdges = this.queries.getOutgoingEdges(nodeId, ['extends', 'implements']);
+    const outgoingEdges = await this.queries.getOutgoingEdges(nodeId, ['extends', 'implements']);

     for (const edge of outgoingEdges) {
-      const parentNode = this.queries.getNodeById(edge.target);
+      const parentNode = await this.queries.getNodeById(edge.target);
       if (parentNode && !nodes.has(parentNode.id)) {
         nodes.set(parentNode.id, parentNode);
         edges.push(edge);
-        this.getTypeAncestors(parentNode.id, nodes, edges, visited);
+        await this.getTypeAncestors(parentNode.id, nodes, edges, visited);
       }
     }
   }

-  private getTypeDescendants(
+  private async getTypeDescendants(
     nodeId: string,
     nodes: Map<string, Node>,
     edges: Edge[],
     visited: Set<string>
-  ): void {
+  ): Promise<void> {
     if (visited.has(nodeId)) {
       return;
     }
     visited.add(nodeId);

-    const incomingEdges = this.queries.getIncomingEdges(nodeId, ['extends', 'implements']);
+    const incomingEdges = await this.queries.getIncomingEdges(nodeId, ['extends', 'implements']);

     for (const edge of incomingEdges) {
-      const childNode = this.queries.getNodeById(edge.source);
+      const childNode = await this.queries.getNodeById(edge.source);
       if (childNode && !nodes.has(childNode.id)) {
         nodes.set(childNode.id, childNode);
         edges.push(edge);
-        this.getTypeDescendants(childNode.id, nodes, edges, visited);
+        await this.getTypeDescendants(childNode.id, nodes, edges, visited);
       }
     }
   }
@@ -422,14 +422,14 @@ export class GraphTraverser {
    * @param nodeId - ID of the symbol node
    * @returns Array of nodes and edges that reference this symbol
    */
-  findUsages(nodeId: string): Array<{ node: Node; edge: Edge }> {
+  async findUsages(nodeId: string): Promise<Array<{ node: Node; edge: Edge }>> {
     const result: Array<{ node: Node; edge: Edge }> = [];

     // Get all incoming edges (references, calls, type_of, etc.)
-    const incomingEdges = this.queries.getIncomingEdges(nodeId);
+    const incomingEdges = await this.queries.getIncomingEdges(nodeId);

     for (const edge of incomingEdges) {
-      const sourceNode = this.queries.getNodeById(edge.source);
+      const sourceNode = await this.queries.getNodeById(edge.source);
       if (sourceNode) {
         result.push({ node: sourceNode, edge });
       }
@@ -447,8 +447,8 @@ export class GraphTraverser {
    * @param maxDepth - Maximum depth to traverse (default: 3)
    * @returns Subgraph containing potentially impacted nodes
    */
-  getImpactRadius(nodeId: string, maxDepth: number = 3): Subgraph {
-    const focalNode = this.queries.getNodeById(nodeId);
+  async getImpactRadius(nodeId: string, maxDepth: number = 3): Promise<Subgraph> {
+    const focalNode = await this.queries.getNodeById(nodeId);
     if (!focalNode) {
       return { nodes: new Map(), edges: [], roots: [] };
     }
@@ -461,7 +461,7 @@
     nodes.set(focalNode.id, focalNode);

     // Traverse incoming edges to find all dependents
-    this.getImpactRecursive(nodeId, maxDepth, 0, nodes, edges, visited);
+    await this.getImpactRecursive(nodeId, maxDepth, 0, nodes, edges, visited);

     return {
       nodes,
@@ -470,14 +470,14 @@
     };
   }

-  private getImpactRecursive(
+  private async getImpactRecursive(
     nodeId: string,
     maxDepth: number,
     currentDepth: number,
     nodes: Map<string, Node>,
     edges: Edge[],
     visited: Set<string>
-  ): void {
+  ): Promise<void> {
     if (currentDepth >= maxDepth || visited.has(nodeId)) {
       return;
     }
@@ -485,32 +485,32 @@
     // For container nodes (classes, interfaces, structs, etc.), also traverse
     // into their children so that callers of contained methods appear in impact
-    const focalNode = this.queries.getNodeById(nodeId);
+    const focalNode = await this.queries.getNodeById(nodeId);
     if (focalNode) {
       const containerKinds = new Set(['class', 'interface', 'struct', 'trait', 'protocol', 'module', 'enum']);
       if (containerKinds.has(focalNode.kind)) {
-        const containsEdges = this.queries.getOutgoingEdges(nodeId, ['contains']);
+        const containsEdges = await this.queries.getOutgoingEdges(nodeId, ['contains']);
         for (const edge of containsEdges) {
-          const childNode = this.queries.getNodeById(edge.target);
+          const childNode = await this.queries.getNodeById(edge.target);
           if (childNode && !visited.has(childNode.id)) {
             nodes.set(childNode.id, childNode);
             edges.push(edge);
             // Recurse into children at the same depth (they're part of the same symbol)
-            this.getImpactRecursive(childNode.id, maxDepth, currentDepth, nodes, edges, visited);
+            await this.getImpactRecursive(childNode.id, maxDepth, currentDepth, nodes, edges, visited);
           }
         }
       }
     }

     // Get all incoming edges (things that depend on this node)
-    const incomingEdges = this.queries.getIncomingEdges(nodeId);
+    const incomingEdges = await this.queries.getIncomingEdges(nodeId);

     for (const edge of incomingEdges) {
-      const sourceNode = this.queries.getNodeById(edge.source);
+      const sourceNode = await this.queries.getNodeById(edge.source);
       if (sourceNode && !nodes.has(sourceNode.id)) {
         nodes.set(sourceNode.id, sourceNode);
         edges.push(edge);
-        this.getImpactRecursive(sourceNode.id, maxDepth, currentDepth + 1, nodes, edges, visited);
+        await this.getImpactRecursive(sourceNode.id, maxDepth, currentDepth + 1, nodes, edges, visited);
       }
     }
   }
@@ -523,13 +523,13 @@ export class GraphTraverser {
    * @param edgeKinds - Edge types to consider (all if empty)
    * @returns Array of nodes and edges forming the path, or null if no path exists
    */
-  findPath(
+  async findPath(
     fromId: string,
     toId: string,
     edgeKinds: EdgeKind[] = []
-  ): Array<{ node: Node; edge: Edge | null }> | null {
-    const fromNode = this.queries.getNodeById(fromId);
-    const toNode = this.queries.getNodeById(toId);
+  ): Promise<Array<{ node: Node; edge: Edge | null }> | null> {
+    const fromNode = await this.queries.getNodeById(fromId);
+    const toNode = await this.queries.getNodeById(toId);

     if (!fromNode || !toNode) {
       return null;
@@ -554,14 +554,14 @@
     visited.add(nodeId);

     // Get outgoing edges
-    const outgoingEdges = this.queries.getOutgoingEdges(
+    const outgoingEdges = await this.queries.getOutgoingEdges(
       nodeId,
       edgeKinds.length > 0 ? edgeKinds : undefined
     );

     for (const edge of outgoingEdges) {
       if (!visited.has(edge.target)) {
-        const nextNode = this.queries.getNodeById(edge.target);
+        const nextNode = await this.queries.getNodeById(edge.target);
         if (nextNode) {
           queue.push({
             nodeId: edge.target,
@@ -581,7 +581,7 @@ export class GraphTraverser {
    * @param nodeId - ID of the node
    * @returns Array of ancestor nodes from immediate parent to root
    */
-  getAncestors(nodeId: string): Node[] {
+  async getAncestors(nodeId: string): Promise<Node[]> {
     const ancestors: Node[] = [];
     const visited = new Set();
     let currentId = nodeId;
@@ -593,7 +593,7 @@
       visited.add(currentId);

       // Look for 'contains' edges pointing to this node
-      const containingEdges = this.queries.getIncomingEdges(currentId, ['contains']);
+      const containingEdges = await this.queries.getIncomingEdges(currentId, ['contains']);
       const firstEdge = containingEdges[0];

       if (!firstEdge) {
@@ -601,7 +601,7 @@
       }

       // Typically there should be at most one containing parent
-      const parentNode = this.queries.getNodeById(firstEdge.source);
+      const parentNode = await this.queries.getNodeById(firstEdge.source);
       if (parentNode) {
         ancestors.push(parentNode);
         currentId = parentNode.id;
@@ -619,12 +619,12 @@ export class GraphTraverser {
    * @param nodeId - ID of the node
    * @returns Array of child nodes
    */
-  getChildren(nodeId: string): Node[] {
-    const containsEdges = this.queries.getOutgoingEdges(nodeId, ['contains']);
+  async getChildren(nodeId: string): Promise<Node[]> {
+    const containsEdges = await this.queries.getOutgoingEdges(nodeId, ['contains']);
     const children: Node[] = [];

     for (const edge of containsEdges) {
-      const childNode = this.queries.getNodeById(edge.target);
+      const childNode = await this.queries.getNodeById(edge.target);
       if (childNode) {
         children.push(childNode);
       }
diff --git a/src/index.ts b/src/index.ts
index 9d0afb18..6e9afd42 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -25,6 +25,9 @@ import {
 } from './types';
 import { DatabaseConnection, getDatabasePath } from './db';
 import { QueryBuilder } from './db/queries';
+import { SqliteDbAdapter } from './db/sqlite-db-adapter';
+import { DbAdapter } from './db/adapter';
+import { createDbAdapter } from './db/db-factory';
 import { loadConfig, saveConfig, createDefaultConfig } from './config';
 import {
   isInitialized,
@@ -46,7 +49,7 @@ import {
   ResolutionResult,
 } from './resolution';
 import { GraphTraverser, GraphQueryManager } from './graph';
-import { VectorManager, createVectorManager, EmbeddingProgress } from './vectors';
+import { VectorManager, createVectorManager, EmbeddingProgress, SqliteVectorStore, createVectorStore, VectorStoreConfig } from './vectors';
 import { ContextBuilder, createContextBuilder } from './context';
 import { Mutex, FileLock } from './utils';
@@ -126,7 +129,8 @@ export interface IndexOptions {
  * Provides the primary interface for interacting with the code knowledge graph.
  */
 export class CodeGraph {
-  private db: DatabaseConnection;
+  private db: DatabaseConnection | null;
+  private dbAdapter: DbAdapter;
   private queries: QueryBuilder;
   private config: CodeGraphConfig;
   private projectRoot: string;
@@ -144,12 +148,14 @@ export class CodeGraph {
   private fileLock: FileLock;

   private constructor(
-    db: DatabaseConnection,
+    db: DatabaseConnection | null,
+    dbAdapter: DbAdapter,
     queries: QueryBuilder,
     config: CodeGraphConfig,
     projectRoot: string
   ) {
     this.db = db;
+    this.dbAdapter = dbAdapter;
     this.queries = queries;
     this.config = config;
     this.projectRoot = projectRoot;
@@ -157,12 +163,17 @@ export class CodeGraph {
       path.join(projectRoot, '.codegraph', 'codegraph.lock')
     );
     this.orchestrator = new ExtractionOrchestrator(projectRoot, config, queries);
-    this.resolver = createResolver(projectRoot, queries);
+    // Create resolver without async initialize() -- deferred to first use or explicit reinitializeResolver()
+    this.resolver = new ReferenceResolver(projectRoot, queries);
     this.graphManager = new GraphQueryManager(queries);
     this.traverser = new GraphTraverser(queries);
-    // Vector manager — always created, embeddings generated lazily on first use
-    this.vectorManager = createVectorManager(db.getDb(), queries, {});
-    // Context builder (uses vector manager for semantic search)
+    // Vector manager -- created lazily via initializeEmbeddings() when pgvector is configured.
+    // For SQLite (default), create eagerly so context builder can use it immediately.
+    if (db && (!config.vectorStore || config.vectorStore.backend === 'sqlite')) {
+      const sqliteStore = new SqliteVectorStore(db.getDb());
+      this.vectorManager = createVectorManager(sqliteStore, queries, {});
+    }
+    // Context builder (uses vector manager for semantic search if available)
     this.contextBuilder = createContextBuilder(
       projectRoot,
       queries,
@@ -204,11 +215,18 @@ export class CodeGraph {
     saveConfig(resolvedRoot, config);

     // Initialize database
-    const dbPath = getDatabasePath(resolvedRoot);
-    const db = DatabaseConnection.initialize(dbPath);
-    const queries = new QueryBuilder(db.getDb());
+    let db: DatabaseConnection | null = null;
+    let adapter: DbAdapter;
+    if (config.database?.backend === 'postgres') {
+      adapter = await createDbAdapter({ backend: 'postgres', connectionString: config.database.connectionString, poolSize: config.database.poolSize, tablePrefix: config.database.tablePrefix });
+    } else {
+      const dbPath = getDatabasePath(resolvedRoot);
+      db = DatabaseConnection.initialize(dbPath);
+      adapter = new SqliteDbAdapter(db.getDb());
+    }
+    const queries = new QueryBuilder(adapter);

-    const instance = new CodeGraph(db, queries, config, resolvedRoot);
+    const instance = new CodeGraph(db, adapter, queries, config, resolvedRoot);

     // Run initial indexing if requested
     if (options.index) {
@@ -239,12 +257,13 @@ export class CodeGraph {
     }
     saveConfig(resolvedRoot, config);

-    // Initialize database
+    // Initialize database (initSync only supports SQLite)
     const dbPath = getDatabasePath(resolvedRoot);
     const db = DatabaseConnection.initialize(dbPath);
-    const queries = new QueryBuilder(db.getDb());
+    const adapter = new SqliteDbAdapter(db.getDb());
+    const queries = new QueryBuilder(adapter);

-    return new CodeGraph(db, queries, config, resolvedRoot);
+    return new CodeGraph(db, adapter, queries, config, resolvedRoot);
   }

   /**
@@ -273,11 +292,18 @@ export class CodeGraph {
     const config = loadConfig(resolvedRoot);

     // Open database
-    const dbPath = getDatabasePath(resolvedRoot);
-    const db = DatabaseConnection.open(dbPath);
-    const queries = new QueryBuilder(db.getDb());
+    let db: DatabaseConnection | null = null;
+    let adapter: DbAdapter;
+    if (config.database?.backend === 'postgres') {
+      adapter = await createDbAdapter({ backend: 'postgres', connectionString: config.database.connectionString, poolSize: config.database.poolSize, tablePrefix: config.database.tablePrefix });
+    } else {
+      const dbPath = getDatabasePath(resolvedRoot);
+      db = DatabaseConnection.open(dbPath);
+      adapter = new SqliteDbAdapter(db.getDb());
+    }
+    const queries = new QueryBuilder(adapter);

-    const instance = new CodeGraph(db, queries, config, resolvedRoot);
+    const instance = new CodeGraph(db, adapter, queries, config, resolvedRoot);

     // Sync if requested
     if (options.sync) {
@@ -307,12 +333,13 @@ export class CodeGraph {
     // Load configuration
     const config = loadConfig(resolvedRoot);

-    // Open database
+    // Open database (openSync only supports SQLite)
     const dbPath = getDatabasePath(resolvedRoot);
     const db = DatabaseConnection.open(dbPath);
-    const queries = new QueryBuilder(db.getDb());
+    const adapter = new SqliteDbAdapter(db.getDb());
+    const queries = new QueryBuilder(adapter);

-    return new CodeGraph(db, queries, config, resolvedRoot);
+    return new CodeGraph(db, adapter, queries, config, resolvedRoot);
   }

   /**
@@ -328,12 +355,17 @@ export class CodeGraph {
   close(): void {
     // Release file lock if held
     this.fileLock.release();
-    // Dispose vector manager first to release ONNX workers
+    // Dispose vector manager to release ONNX workers and close pg pool
+    // Fire-and-forget: SQLite dispose is a no-op, pg pool cleanup is best-effort
     if (this.vectorManager) {
-      this.vectorManager.dispose();
+      this.vectorManager.dispose().catch(() => {});
       this.vectorManager = null;
     }
-    this.db.close();
+    if (this.db) {
+      this.db.close();
+    } else {
+      this.dbAdapter.close().catch(() => {});
+    }
   }

   // ===========================================================================
@@ -350,7 +382,7 @@ export class CodeGraph {
   /**
    * Update configuration
    */
-  updateConfig(updates: Partial<CodeGraphConfig>): void {
+  async updateConfig(updates: Partial<CodeGraphConfig>): Promise<void> {
     Object.assign(this.config, updates);
     saveConfig(this.projectRoot, this.config);
     // Recreate orchestrator and resolver with new config
     this.orchestrator = new ExtractionOrchestrator(
       this.projectRoot,
       this.config,
       this.queries
     );
-    this.resolver = createResolver(this.projectRoot, this.queries);
+    this.resolver = await createResolver(this.projectRoot, this.queries);
   }

   /**
@@ -391,7 +423,7 @@
     // Resolve references to create call/import/extends edges
     if (result.success && result.filesIndexed > 0) {
       // Get count without loading all refs into memory
-      const unresolvedCount = this.queries.getUnresolvedReferencesCount();
+      const unresolvedCount = await this.queries.getUnresolvedReferencesCount();

       options.onProgress?.({
         phase: 'resolving',
@@ -399,7 +431,7 @@
         total: unresolvedCount,
       });

-      this.resolveReferencesBatched((current, total) => {
+      await this.resolveReferencesBatched((current, total) => {
         options.onProgress?.({
           phase: 'resolving',
           current,
@@ -454,7 +486,7 @@
     if (result.filesAdded > 0 || result.filesModified > 0) {
       if (result.changedFilePaths) {
         // Scope resolution to changed files (git fast path — bounded set)
-        const unresolvedRefs = this.queries.getUnresolvedReferencesByFiles(result.changedFilePaths);
+        const unresolvedRefs = await this.queries.getUnresolvedReferencesByFiles(result.changedFilePaths);

         options.onProgress?.({
           phase: 'resolving',
@@ -462,7 +494,7 @@
           total: unresolvedRefs.length,
         });

-        this.resolver.resolveAndPersist(unresolvedRefs, (current, total) => {
+        await this.resolver.resolveAndPersist(unresolvedRefs, (current, total) => {
           options.onProgress?.({
             phase: 'resolving',
             current,
@@ -471,7 +503,7 @@
           });
         } else {
         // No git info — use batched resolution to avoid OOM
-        const unresolvedCount = this.queries.getUnresolvedReferencesCount();
+        const unresolvedCount = await this.queries.getUnresolvedReferencesCount();

         options.onProgress?.({
           phase: 'resolving',
@@ -479,7 +511,7 @@
           total: unresolvedCount,
         });

-        this.resolveReferencesBatched((current, total) => {
+        await this.resolveReferencesBatched((current, total) => {
           options.onProgress?.({
             phase: 'resolving',
             current,
@@ -506,7 +538,7 @@
   /**
    * Get files that have changed since last index
    */
-  getChangedFiles(): { added: string[]; modified: string[]; removed: string[] } {
+  async getChangedFiles(): Promise<{ added: string[]; modified: string[]; removed: string[] }> {
     return this.orchestrator.getChangedFiles();
   }
@@ -530,9 +562,9 @@
    * - Import-based resolution
    * - Name-based symbol matching
    */
-  resolveReferences(onProgress?: (current: number, total: number) => void): ResolutionResult {
+  async resolveReferences(onProgress?: (current: number, total: number) => void): Promise<ResolutionResult> {
     // Get all unresolved references from the database
-    const unresolvedRefs = this.queries.getUnresolvedReferences();
+    const unresolvedRefs = await this.queries.getUnresolvedReferences();
     return this.resolver.resolveAndPersist(unresolvedRefs, onProgress);
   }

   /**
    * Resolve references in batches to keep memory bounded on large codebases.
    * Processes chunks of unresolved refs, persisting results after each batch.
    */
-  resolveReferencesBatched(onProgress?: (current: number, total: number) => void): ResolutionResult {
+  async resolveReferencesBatched(onProgress?: (current: number, total: number) => void): Promise<ResolutionResult> {
     return this.resolver.resolveAndPersistBatched(onProgress);
   }
@@ -554,8 +586,8 @@
   /**
    * Re-initialize the resolver (useful after adding new files)
    */
-  reinitializeResolver(): void {
-    this.resolver.initialize();
+  async reinitializeResolver(): Promise<void> {
+    await this.resolver.initialize();
   }

   // ===========================================================================
@@ -565,9 +597,9 @@
   /**
    * Get statistics about the knowledge graph
    */
-  getStats(): GraphStats {
-    const stats = this.queries.getStats();
-    stats.dbSizeBytes = this.db.getSize();
+  async getStats(): Promise<GraphStats> {
+    const stats = await this.queries.getStats();
+    stats.dbSizeBytes = this.db ? this.db.getSize() : 0;
     return stats;
   }
@@ -578,28 +610,28 @@
   /**
    * Get a node by ID
    */
-  getNode(id: string): Node | null {
+  async getNode(id: string): Promise<Node | null> {
     return this.queries.getNodeById(id);
   }

   /**
    * Get all nodes in a file
    */
-  getNodesInFile(filePath: string): Node[] {
+  async getNodesInFile(filePath: string): Promise<Node[]> {
     return this.queries.getNodesByFile(filePath);
   }

   /**
    * Get all nodes of a specific kind
    */
-  getNodesByKind(kind: Node['kind']): Node[] {
+  async getNodesByKind(kind: Node['kind']): Promise<Node[]> {
     return this.queries.getNodesByKind(kind);
   }

   /**
    * Search nodes by text
    */
-  searchNodes(query: string, options?: SearchOptions): SearchResult[] {
+  async searchNodes(query: string, options?: SearchOptions): Promise<SearchResult[]> {
     return this.queries.searchNodes(query, options);
   }
@@ -610,14 +642,14 @@
   /**
    * Get outgoing edges from a node
    */
-  getOutgoingEdges(nodeId: string): Edge[] {
+  async getOutgoingEdges(nodeId: string): Promise<Edge[]> {
     return this.queries.getOutgoingEdges(nodeId);
   }

   /**
    * Get incoming edges to a node
    */
-  getIncomingEdges(nodeId: string): Edge[] {
+  async getIncomingEdges(nodeId: string): Promise<Edge[]> {
     return this.queries.getIncomingEdges(nodeId);
   }
@@ -628,14 +660,14 @@
   /**
    * Get a file record by path
    */
-  getFile(filePath: string): FileRecord | null {
+  async getFile(filePath: string): Promise<FileRecord | null> {
     return this.queries.getFileByPath(filePath);
   }

   /**
    * Get all tracked files
    */
-  getFiles(): FileRecord[] {
+  async getFiles(): Promise<FileRecord[]> {
     return this.queries.getAllFiles();
   }
@@ -653,7 +685,7 @@
    * @param nodeId - ID of the focal node
    * @returns Context object with all related information
    */
-  getContext(nodeId: string): Context {
+  async getContext(nodeId: string): Promise<Context> {
     return this.graphManager.getContext(nodeId);
   }
@@ -667,7 +699,7 @@
    * @param options - Traversal options
    * @returns Subgraph containing traversed nodes and edges
    */
-  traverse(startId: string, options?: TraversalOptions): Subgraph {
+  async traverse(startId: string, options?: TraversalOptions): Promise<Subgraph> {
     return this.traverser.traverseBFS(startId, options);
   }
@@ -681,7 +713,7 @@
    * @param depth - Maximum depth in each direction (default: 2)
    * @returns Subgraph containing the call graph
    */
-  getCallGraph(nodeId: string, depth: number = 2): Subgraph {
+  async getCallGraph(nodeId: string, depth: number = 2): Promise<Subgraph> {
     return this.traverser.getCallGraph(nodeId, depth);
   }
@@ -694,7 +726,7 @@
    * @param nodeId - ID of the class/interface node
    * @returns Subgraph containing the type hierarchy
    */
-  getTypeHierarchy(nodeId: string): Subgraph {
+  async getTypeHierarchy(nodeId: string): Promise<Subgraph> {
     return this.traverser.getTypeHierarchy(nodeId);
   }
@@ -707,7 +739,7 @@
    * @param nodeId - ID of the symbol node
    * @returns Array of nodes and edges that reference this symbol
    */
-  findUsages(nodeId: string): Array<{ node: Node; edge: Edge }> {
+  async findUsages(nodeId: string): Promise<Array<{ node: Node; edge: Edge }>> {
     return this.traverser.findUsages(nodeId);
   }
@@ -718,7 +750,7 @@
    * @param maxDepth - Maximum depth to traverse (default: 1)
    * @returns Array of nodes that call this function
    */
-  getCallers(nodeId: string, maxDepth: number = 1): Array<{ node: Node; edge: Edge }> {
+  async getCallers(nodeId: string, maxDepth: number = 1): Promise<Array<{ node: Node; edge: Edge }>> {
     return this.traverser.getCallers(nodeId, maxDepth);
   }
@@ -729,7 +761,7 @@
    * @param maxDepth - Maximum depth to traverse (default: 1)
    * @returns Array of nodes called by this function
    */
-  getCallees(nodeId: string, maxDepth: number = 1): Array<{ node: Node; edge: Edge }> {
+  async getCallees(nodeId: string, maxDepth: number = 1): Promise<Array<{ node: Node; edge: Edge }>> {
     return this.traverser.getCallees(nodeId, maxDepth);
   }
@@ -742,7 +774,7 @@
    * @param maxDepth - Maximum depth to traverse (default: 3)
    * @returns Subgraph containing potentially impacted nodes
    */
-  getImpactRadius(nodeId: string, maxDepth: number = 3): Subgraph {
+  async getImpactRadius(nodeId: string, maxDepth: number = 3): Promise<Subgraph> {
     return this.traverser.getImpactRadius(nodeId, maxDepth);
   }
@@ -754,11 +786,11 @@
    * @param edgeKinds - Edge types to consider (all if empty)
    * @returns Array of nodes and edges forming the path, or null if no path exists
    */
-  findPath(
+  async findPath(
     fromId: string,
     toId: string,
     edgeKinds?: Edge['kind'][]
-  ): Array<{ node: Node; edge: Edge | null }> | null {
+  ): Promise<Array<{ node: Node; edge: Edge | null }> | null> {
     return this.traverser.findPath(fromId, toId, edgeKinds);
   }
@@ -768,7 +800,7 @@
    * @param nodeId - ID of the node
    * @returns Array of ancestor nodes from immediate parent to root
    */
-  getAncestors(nodeId: string): Node[] {
+  async getAncestors(nodeId: string): Promise<Node[]> {
     return this.traverser.getAncestors(nodeId);
   }
@@ -778,7 +810,7 @@
    * @param nodeId - ID of the node
    * @returns Array of child nodes
    */
-  getChildren(nodeId: string): Node[] {
+  async getChildren(nodeId: string): Promise<Node[]> {
     return this.traverser.getChildren(nodeId);
   }
@@ -788,7 +820,7 @@
    * @param filePath - Path to the file
    * @returns Array of file paths this file depends on
    */
-  getFileDependencies(filePath: string): string[] {
+  async getFileDependencies(filePath: string): Promise<string[]> {
     return this.graphManager.getFileDependencies(filePath);
   }
@@ -798,7 +830,7 @@
    * @param filePath - Path to the file
    * @returns Array of file paths that depend on this file
    */
-  getFileDependents(filePath: string): string[] {
+  async getFileDependents(filePath: string): Promise<string[]> {
     return this.graphManager.getFileDependents(filePath);
   }
@@ -807,7 +839,7 @@
    *
    * @returns Array of cycles, each cycle is an array of file paths
    */
-  findCircularDependencies(): string[][] {
+  async findCircularDependencies(): Promise<string[][]> {
     return this.graphManager.findCircularDependencies();
   }
@@ -817,7 +849,7 @@
    * @param kinds - Node kinds to check (default: functions, methods, classes)
    * @returns Array of unreferenced nodes
    */
-  findDeadCode(kinds?: Node['kind'][]): Node[] {
+  async findDeadCode(kinds?: Node['kind'][]): Promise<Node[]> {
     return this.graphManager.findDeadCode(kinds);
   }
@@ -827,14 +859,14 @@
    * @param nodeId - ID of the node
    * @returns Object containing various complexity metrics
    */
-  getNodeMetrics(nodeId: string): {
+  async getNodeMetrics(nodeId: string): Promise<{
     incomingEdgeCount: number;
     outgoingEdgeCount: number;
     callCount: number;
     callerCount: number;
     childCount: number;
     depth: number;
-  } {
+  }> {
     return this.graphManager.getNodeMetrics(nodeId);
   }
@@ -850,7 +882,18 @@
    */
   async initializeEmbeddings(): Promise<void> {
     if (!this.vectorManager) {
-      this.vectorManager = createVectorManager(this.db.getDb(), this.queries, {
+      const storeConfig: VectorStoreConfig = this.config.vectorStore
+        ? {
+            backend: this.config.vectorStore.backend,
+            connectionString: this.config.vectorStore.connectionString,
+            indexType: this.config.vectorStore.indexType,
+            distanceMetric: this.config.vectorStore.distanceMetric,
+            poolSize: this.config.vectorStore.poolSize,
+            tablePrefix: this.config.vectorStore.tablePrefix,
+          }
+        : { backend: 'sqlite' };
+      const store = await createVectorStore(storeConfig, this.db?.getDb());
+      this.vectorManager = createVectorManager(store, this.queries, {
         embedder: {
           showProgress: true,
         },
@@ -922,12 +965,14 @@
   /**
    * Get vector embedding statistics
    */
-  getEmbeddingStats(): {
+  async getEmbeddingStats(): Promise<{
     totalVectors: number;
     vssEnabled: boolean;
+    annEnabled: boolean;
+    backend: 'sqlite' | 'pgvector';
     modelId: string;
     dimension: number;
-  } | null {
+  } | null> {
     if (!this.vectorManager) {
       return null;
     }
@@ -1008,15 +1053,17 @@
   /**
    * Optimize the database (vacuum and analyze)
    */
-  optimize(): void {
-    this.db.optimize();
+  async optimize(): Promise<void> {
+    if (this.db) {
+      this.db.optimize();
+    }
   }

   /**
    * Clear all data from the graph
    */
-  clear(): void {
-    this.queries.clear();
+  async clear(): Promise<void> {
+    await this.queries.clear();
   }

   /**
diff --git a/src/mcp/tools.ts b/src/mcp/tools.ts
index aebfe973..0eabe0f7 100644
--- a/src/mcp/tools.ts
+++ b/src/mcp/tools.ts
@@ -307,7 +307,7 @@ export class ToolHandler {
    * Walks up parent directories to find the nearest .codegraph/ folder,
    * similar to how git finds .git/ directories.
    */
-  private getCodeGraph(projectPath?: string): CodeGraph {
+  private async getCodeGraph(projectPath?: string): Promise<CodeGraph> {
     if (!projectPath) {
       if (!this.cg) {
         throw new Error('CodeGraph not initialized for this project. Run \'codegraph init\' first.');
@@ -336,7 +336,7 @@
     }

     // Open and cache under both paths
-    const cg = CodeGraph.openSync(resolvedRoot);
+    const cg = await CodeGraph.open(resolvedRoot);
     this.projectCache.set(resolvedRoot, cg);
     if (projectPath !== resolvedRoot) {
       this.projectCache.set(projectPath, cg);
@@ -403,12 +403,12 @@
     const query = this.validateString(args.query, 'query');
     if (typeof query !== 'string') return query;

-    const cg = this.getCodeGraph(args.projectPath as string | undefined);
+    const cg = await this.getCodeGraph(args.projectPath as string | undefined);
     const kind = args.kind as string | undefined;
     const rawLimit = Number(args.limit) || 10;
     const limit = clamp(rawLimit, 1, 100);

-    const results = cg.searchNodes(query, {
+    const results = await cg.searchNodes(query, {
       limit,
       kinds: kind ? [kind as NodeKind] : undefined,
     });
@@ -434,7 +434,7 @@
       markSessionConsulted(sessionId);
     }

-    const cg = this.getCodeGraph(args.projectPath as string | undefined);
+    const cg = await this.getCodeGraph(args.projectPath as string | undefined);
     const maxNodes = (args.maxNodes as number) || 20;
     const includeCode = args.includeCode !== false;
@@ -494,10 +494,10 @@
     const symbol = this.validateString(args.symbol, 'symbol');
     if (typeof symbol !== 'string') return symbol;

-    const cg = this.getCodeGraph(args.projectPath as string | undefined);
+    const cg = await this.getCodeGraph(args.projectPath as string | undefined);
     const limit = clamp((args.limit as number) || 20, 1, 100);

-    const allMatches = this.findAllSymbols(cg, symbol);
+    const allMatches = await this.findAllSymbols(cg, symbol);
     if (allMatches.nodes.length === 0) {
       return this.textResult(`Symbol "${symbol}" not found in the codebase`);
     }
@@ -506,7 +506,7 @@
     const seen = new Set();
     const allCallers: Node[] = [];
     for (const node of allMatches.nodes) {
-      for (const c of
cg.getCallers(node.id)) { + for (const c of await cg.getCallers(node.id)) { if (!seen.has(c.node.id)) { seen.add(c.node.id); allCallers.push(c.node); @@ -529,10 +529,10 @@ export class ToolHandler { const symbol = this.validateString(args.symbol, 'symbol'); if (typeof symbol !== 'string') return symbol; - const cg = this.getCodeGraph(args.projectPath as string | undefined); + const cg = await this.getCodeGraph(args.projectPath as string | undefined); const limit = clamp((args.limit as number) || 20, 1, 100); - const allMatches = this.findAllSymbols(cg, symbol); + const allMatches = await this.findAllSymbols(cg, symbol); if (allMatches.nodes.length === 0) { return this.textResult(`Symbol "${symbol}" not found in the codebase`); } @@ -541,7 +541,7 @@ export class ToolHandler { const seen = new Set(); const allCallees: Node[] = []; for (const node of allMatches.nodes) { - for (const c of cg.getCallees(node.id)) { + for (const c of await cg.getCallees(node.id)) { if (!seen.has(c.node.id)) { seen.add(c.node.id); allCallees.push(c.node); @@ -564,10 +564,10 @@ export class ToolHandler { const symbol = this.validateString(args.symbol, 'symbol'); if (typeof symbol !== 'string') return symbol; - const cg = this.getCodeGraph(args.projectPath as string | undefined); + const cg = await this.getCodeGraph(args.projectPath as string | undefined); const depth = clamp((args.depth as number) || 2, 1, 10); - const allMatches = this.findAllSymbols(cg, symbol); + const allMatches = await this.findAllSymbols(cg, symbol); if (allMatches.nodes.length === 0) { return this.textResult(`Symbol "${symbol}" not found in the codebase`); } @@ -578,7 +578,7 @@ export class ToolHandler { const seenEdges = new Set(); for (const node of allMatches.nodes) { - const impact = cg.getImpactRadius(node.id, depth); + const impact = await cg.getImpactRadius(node.id, depth); for (const [id, n] of impact.nodes) { mergedNodes.set(id, n); } @@ -615,7 +615,7 @@ export class ToolHandler { const query = 
this.validateString(args.query, 'query'); if (typeof query !== 'string') return query; - const cg = this.getCodeGraph(args.projectPath as string | undefined); + const cg = await this.getCodeGraph(args.projectPath as string | undefined); const maxFiles = clamp((args.maxFiles as number) || 12, 1, 20); const projectRoot = cg.getProjectRoot(); @@ -864,11 +864,11 @@ export class ToolHandler { const symbol = this.validateString(args.symbol, 'symbol'); if (typeof symbol !== 'string') return symbol; - const cg = this.getCodeGraph(args.projectPath as string | undefined); + const cg = await this.getCodeGraph(args.projectPath as string | undefined); // Default to false to minimize context usage const includeCode = args.includeCode === true; - const match = this.findSymbol(cg, symbol); + const match = await this.findSymbol(cg, symbol); if (!match) { return this.textResult(`Symbol "${symbol}" not found in the codebase`); } @@ -887,8 +887,8 @@ export class ToolHandler { * Handle codegraph_status */ private async handleStatus(args: Record): Promise { - const cg = this.getCodeGraph(args.projectPath as string | undefined); - const stats = cg.getStats(); + const cg = await this.getCodeGraph(args.projectPath as string | undefined); + const stats = await cg.getStats(); const lines: string[] = [ '## CodeGraph Status', @@ -921,7 +921,7 @@ export class ToolHandler { * Handle codegraph_files - get project file structure from the index */ private async handleFiles(args: Record): Promise { - const cg = this.getCodeGraph(args.projectPath as string | undefined); + const cg = await this.getCodeGraph(args.projectPath as string | undefined); const pathFilter = args.path as string | undefined; const pattern = args.pattern as string | undefined; const format = (args.format as 'tree' | 'flat' | 'grouped') || 'tree'; @@ -929,7 +929,7 @@ export class ToolHandler { const maxDepth = args.maxDepth != null ? 
clamp(args.maxDepth as number, 1, 20) : undefined; // Get all files from the index - const allFiles = cg.getFiles(); + const allFiles = await cg.getFiles(); if (allFiles.length === 0) { return this.textResult('No files indexed. Run `codegraph index` first.'); @@ -1134,11 +1134,11 @@ export class ToolHandler { return false; } - private findSymbol(cg: CodeGraph, symbol: string): { node: Node; note: string } | null { + private async findSymbol(cg: CodeGraph, symbol: string): Promise<{ node: Node; note: string } | null> { // Use higher limit for qualified lookups (e.g., "Session.request") since the // target may rank lower in FTS when there are many partial matches const limit = symbol.includes('.') ? 50 : 10; - const results = cg.searchNodes(symbol, { limit }); + const results = await cg.searchNodes(symbol, { limit }); if (results.length === 0 || !results[0]) { return null; @@ -1168,8 +1168,8 @@ export class ToolHandler { * Find ALL symbols matching a name. Used by callers/callees/impact to aggregate * results across all matching symbols (e.g., multiple classes with an `execute` method). 
   */
-  private findAllSymbols(cg: CodeGraph, symbol: string): { nodes: Node[]; note: string } {
-    const results = cg.searchNodes(symbol, { limit: 50 });
+  private async findAllSymbols(cg: CodeGraph, symbol: string): Promise<{ nodes: Node[]; note: string }> {
+    const results = await cg.searchNodes(symbol, { limit: 50 });

    if (results.length === 0) {
      return { nodes: [], note: '' };
diff --git a/src/resolution/frameworks/csharp.ts b/src/resolution/frameworks/csharp.ts
index 5f278f01..e7ba51ff 100644
--- a/src/resolution/frameworks/csharp.ts
+++ b/src/resolution/frameworks/csharp.ts
@@ -10,9 +10,9 @@ import { FrameworkResolver, UnresolvedRef, ResolvedRef, ResolutionContext } from
 export const aspnetResolver: FrameworkResolver = {
   name: 'aspnet',

-  detect(context: ResolutionContext): boolean {
+  async detect(context: ResolutionContext): Promise<boolean> {
     // Check for .csproj files with ASP.NET references
-    const allFiles = context.getAllFiles();
+    const allFiles = await context.getAllFiles();
     for (const file of allFiles) {
       if (file.endsWith('.csproj')) {
         const content = context.readFile(file);
@@ -45,10 +45,10 @@ export const aspnetResolver: FrameworkResolver = {
     return allFiles.some((f) => f.includes('/Controllers/') && f.endsWith('Controller.cs'));
   },

-  resolve(ref: UnresolvedRef, context: ResolutionContext): ResolvedRef | null {
+  async resolve(ref: UnresolvedRef, context: ResolutionContext): Promise<ResolvedRef | null> {
     // Pattern 1: Controller references
     if (ref.referenceName.endsWith('Controller')) {
-      const result = resolveController(ref.referenceName, context);
+      const result = await resolveController(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -61,7 +61,7 @@
     // Pattern 2: Service references (dependency injection)
     if (ref.referenceName.endsWith('Service') || ref.referenceName.startsWith('I') && ref.referenceName.length > 1) {
-      const result = resolveService(ref.referenceName, context);
+      const result = await resolveService(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -74,7 +74,7 @@
     // Pattern 3: Repository references
     if (ref.referenceName.endsWith('Repository')) {
-      const result = resolveRepository(ref.referenceName, context);
+      const result = await resolveRepository(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -87,7 +87,7 @@
     // Pattern 4: Model/Entity references
     if (/^[A-Z][a-zA-Z]+$/.test(ref.referenceName)) {
-      const result = resolveModel(ref.referenceName, context);
+      const result = await resolveModel(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -100,7 +100,7 @@
     // Pattern 5: ViewModel references
     if (ref.referenceName.endsWith('ViewModel') || ref.referenceName.endsWith('Dto')) {
-      const result = resolveViewModel(ref.referenceName, context);
+      const result = await resolveViewModel(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -215,12 +215,12 @@
 // Helper functions

-function resolveController(name: string, context: ResolutionContext): string | null {
-  const allFiles = context.getAllFiles();
+async function resolveController(name: string, context: ResolutionContext): Promise<string | null> {
+  const allFiles = await context.getAllFiles();

   for (const file of allFiles) {
     if (file.endsWith('.cs') && file.includes('/Controllers/')) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const controllerNode = nodes.find(
         (n) => n.kind === 'class' && n.name === name
       );
@@ -233,13 +233,13 @@ function resolveController(name: string, context: ResolutionContext): string | n
   return null;
 }

-function resolveService(name: string, context: ResolutionContext): string | null {
+async function resolveService(name: string, context: ResolutionContext): Promise<string | null> {
   const serviceDirs = ['Services', 'Service', 'Application'];
-  const allFiles = context.getAllFiles();
+  const allFiles = await context.getAllFiles();

   for (const file of allFiles) {
     if (file.endsWith('.cs') && serviceDirs.some((d) => file.includes(`/${d}/`))) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const serviceNode = nodes.find(
         (n) => (n.kind === 'class' || n.kind === 'interface') && n.name === name
       );
@@ -252,7 +252,7 @@
   // Search all C# files for interfaces (often services are injected via interface)
   for (const file of allFiles) {
     if (file.endsWith('.cs')) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const serviceNode = nodes.find(
         (n) => (n.kind === 'class' || n.kind === 'interface') && n.name === name
       );
@@ -265,13 +265,13 @@ function resolveService(name: string, context: ResolutionContext): string | null
   return null;
 }

-function resolveRepository(name: string, context: ResolutionContext): string | null {
+async function resolveRepository(name: string, context: ResolutionContext): Promise<string | null> {
   const repoDirs = ['Repositories', 'Repository', 'Data', 'Infrastructure'];
-  const allFiles = context.getAllFiles();
+  const allFiles = await context.getAllFiles();

   for (const file of allFiles) {
     if (file.endsWith('.cs') && repoDirs.some((d) => file.includes(`/${d}/`))) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const repoNode = nodes.find(
         (n) => (n.kind === 'class' || n.kind === 'interface') && n.name === name
       );
@@ -284,13 +284,13 @@ function resolveRepository(name: string, context: ResolutionContext): string | n
   return null;
 }

-function resolveModel(name: string, context: ResolutionContext): string | null {
+async function resolveModel(name: string, context: ResolutionContext): Promise<string | null> {
   const modelDirs = ['Models', 'Model', 'Entities', 'Entity', 'Domain'];
-  const allFiles = context.getAllFiles();
+  const allFiles = await context.getAllFiles();

   for (const file of allFiles) {
     if (file.endsWith('.cs') && modelDirs.some((d) => file.includes(`/${d}/`))) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const modelNode = nodes.find(
         (n) => n.kind === 'class' && n.name === name
       );
@@ -303,13 +303,13 @@ function resolveModel(name: string, context: ResolutionContext): string | null {
   return null;
 }

-function resolveViewModel(name: string, context: ResolutionContext): string | null {
+async function resolveViewModel(name: string, context: ResolutionContext): Promise<string | null> {
   const viewModelDirs = ['ViewModels', 'ViewModel', 'DTOs', 'Dto'];
-  const allFiles = context.getAllFiles();
+  const allFiles = await context.getAllFiles();

   for (const file of allFiles) {
     if (file.endsWith('.cs') && viewModelDirs.some((d) => file.includes(`/${d}/`))) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const vmNode = nodes.find(
         (n) => n.kind === 'class' && n.name === name
       );
diff --git a/src/resolution/frameworks/express.ts b/src/resolution/frameworks/express.ts
index f98dac65..da9d3ea8 100644
--- a/src/resolution/frameworks/express.ts
+++ b/src/resolution/frameworks/express.ts
@@ -10,7 +10,7 @@ import { FrameworkResolver, UnresolvedRef, ResolvedRef, ResolutionContext } from
 export const expressResolver: FrameworkResolver = {
   name: 'express',

-  detect(context: ResolutionContext): boolean {
+  async detect(context: ResolutionContext): Promise<boolean> {
     // Check for Express in package.json
     const packageJson = context.readFile('package.json');
     if (packageJson) {
@@ -26,7 +26,7 @@ export const expressResolver: FrameworkResolver = {
     }

     // Check for common Express patterns
-    const allFiles = context.getAllFiles();
+    const allFiles = await context.getAllFiles();
     for (const file of allFiles) {
       if (
         file.includes('routes') ||
@@ -43,10 +43,10 @@ export const expressResolver: FrameworkResolver = {
     return false;
   },

-  resolve(ref: UnresolvedRef, context: ResolutionContext): ResolvedRef | null {
+  async resolve(ref: UnresolvedRef, context: ResolutionContext): Promise<ResolvedRef | null> {
     // Pattern 1: Middleware references
     if (isMiddlewareName(ref.referenceName)) {
-      const result = resolveMiddleware(ref.referenceName, context);
+      const result = await resolveMiddleware(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -61,7 +61,7 @@
     const controllerMatch = ref.referenceName.match(/^(\w+)Controller\.(\w+)$/);
     if (controllerMatch) {
       const [, controller, method] = controllerMatch;
-      const result = resolveController(controller!, method!, context);
+      const result = await resolveController(controller!, method!, context);
       if (result) {
         return {
           original: ref,
@@ -76,7 +76,7 @@
     const serviceMatch = ref.referenceName.match(/^(\w+)(Service|Helper|Utils?)\.(\w+)$/);
     if (serviceMatch) {
       const [, name, suffix, method] = serviceMatch;
-      const result = resolveService(name! + suffix!, method!, context);
+      const result = await resolveService(name! + suffix!, method!, context);
       if (result) {
         return {
           original: ref,
@@ -156,18 +156,18 @@ function isMiddlewareName(name: string): boolean {
 /**
  * Resolve middleware reference
  */
-function resolveMiddleware(
+async function resolveMiddleware(
   name: string,
   context: ResolutionContext
-): string | null {
+): Promise<string | null> {
   // Look in middleware directories
   const middlewareDirs = ['middleware', 'middlewares', 'src/middleware', 'src/middlewares'];

   for (const dir of middlewareDirs) {
-    const allFiles = context.getAllFiles();
+    const allFiles = await context.getAllFiles();
     for (const file of allFiles) {
       if (file.startsWith(dir) || file.includes('/middleware/')) {
-        const nodes = context.getNodesInFile(file);
+        const nodes = await context.getNodesInFile(file);
         const match = nodes.find(
           (n) =>
             n.name.toLowerCase() === name.toLowerCase() ||
@@ -186,21 +186,21 @@
 /**
  * Resolve controller method
  */
-function resolveController(
+async function resolveController(
   controller: string,
   method: string,
   context: ResolutionContext
-): string | null {
+): Promise<string | null> {
   const controllerDirs = ['controllers', 'src/controllers', 'app/controllers'];

   for (const dir of controllerDirs) {
-    const allFiles = context.getAllFiles();
+    const allFiles = await context.getAllFiles();
     for (const file of allFiles) {
       if (
         (file.startsWith(dir) || file.includes('/controllers/')) &&
         file.toLowerCase().includes(controller.toLowerCase())
       ) {
-        const nodes = context.getNodesInFile(file);
+        const nodes = await context.getNodesInFile(file);
         const methodNode = nodes.find(
           (n) => (n.kind === 'method' || n.kind === 'function') && n.name === method
         );
@@ -217,21 +217,21 @@
 /**
  * Resolve service/helper
  */
-function resolveService(
+async function resolveService(
   serviceName: string,
   method: string,
   context: ResolutionContext
-): string | null {
+): Promise<string | null> {
   const serviceDirs = ['services', 'src/services', 'helpers', 'src/helpers', 'utils', 'src/utils'];

   for (const dir of serviceDirs) {
-    const allFiles = context.getAllFiles();
+    const allFiles = await context.getAllFiles();
     for (const file of allFiles) {
       if (
         (file.startsWith(dir) || file.includes('/services/') || file.includes('/helpers/') || file.includes('/utils/')) &&
         file.toLowerCase().includes(serviceName.toLowerCase().replace(/(service|helper|utils?)$/i, ''))
       ) {
-        const nodes = context.getNodesInFile(file);
+        const nodes = await context.getNodesInFile(file);
         const methodNode = nodes.find(
           (n) => (n.kind === 'method' || n.kind === 'function') && n.name === method
         );
diff --git a/src/resolution/frameworks/go.ts b/src/resolution/frameworks/go.ts
index 1a06e70e..6235cb3f 100644
--- a/src/resolution/frameworks/go.ts
+++ b/src/resolution/frameworks/go.ts
@@ -10,7 +10,7 @@ import { FrameworkResolver, UnresolvedRef, ResolvedRef, ResolutionContext } from
 export const goResolver: FrameworkResolver = {
   name: 'go',

-  detect(context: ResolutionContext): boolean {
+  async detect(context: ResolutionContext): Promise<boolean> {
     // Check for go.mod file (Go modules)
     const goMod = context.readFile('go.mod');
     if (goMod) {
@@ -18,14 +18,14 @@ export const goResolver: FrameworkResolver = {
     }

     // Check for .go files
-    const allFiles = context.getAllFiles();
+    const allFiles = await context.getAllFiles();
     return allFiles.some((f) => f.endsWith('.go'));
   },

-  resolve(ref: UnresolvedRef, context: ResolutionContext): ResolvedRef | null {
+  async resolve(ref: UnresolvedRef, context: ResolutionContext): Promise<ResolvedRef | null> {
     // Pattern 1: Handler references
     if (ref.referenceName.endsWith('Handler') || ref.referenceName.startsWith('Handle')) {
-      const result = resolveHandler(ref.referenceName, context);
+      const result = await resolveHandler(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -38,7 +38,7 @@
     // Pattern 2: Service/Repository references
     if (ref.referenceName.endsWith('Service') || ref.referenceName.endsWith('Repository') || ref.referenceName.endsWith('Store')) {
-      const result = resolveService(ref.referenceName, context);
+      const result = await resolveService(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -51,7 +51,7 @@
     // Pattern 3: Middleware references
     if (ref.referenceName.endsWith('Middleware') || ref.referenceName.startsWith('Auth') || ref.referenceName.startsWith('Log')) {
-      const result = resolveMiddleware(ref.referenceName, context);
+      const result = await resolveMiddleware(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -64,7 +64,7 @@
     // Pattern 4: Model/Entity references (typically PascalCase structs)
     if (/^[A-Z][a-zA-Z]+$/.test(ref.referenceName)) {
-      const result = resolveModel(ref.referenceName, context);
+      const result = await resolveModel(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -180,13 +180,13 @@
 // Helper functions

-function resolveHandler(name: string, context: ResolutionContext): string | null {
+async function resolveHandler(name: string, context: ResolutionContext): Promise<string | null> {
   const handlerDirs = ['handler', 'handlers', 'api', 'routes', 'controller', 'controllers'];
-  const allFiles = context.getAllFiles();
+  const allFiles = await context.getAllFiles();

   for (const file of allFiles) {
     if (file.endsWith('.go') && handlerDirs.some((d) => file.includes(`/${d}/`))) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const handlerNode = nodes.find(
         (n) => n.kind === 'function' && n.name === name
       );
@@ -199,7 +199,7 @@
   // Search all go files
   for (const file of allFiles) {
     if (file.endsWith('.go')) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const handlerNode = nodes.find(
         (n) => n.kind === 'function' && n.name === name
       );
@@ -212,13 +212,13 @@ function resolveHandler(name: string, context: ResolutionContext): string | null
   return null;
 }

-function resolveService(name: string, context: ResolutionContext): string | null {
+async function resolveService(name: string, context: ResolutionContext): Promise<string | null> {
   const serviceDirs = ['service', 'services', 'repository', 'store', 'pkg'];
-  const allFiles = context.getAllFiles();
+  const allFiles = await context.getAllFiles();

   for (const file of allFiles) {
     if (file.endsWith('.go') && serviceDirs.some((d) => file.includes(`/${d}/`))) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const serviceNode = nodes.find(
         (n) => (n.kind === 'struct' || n.kind === 'interface') && n.name === name
       );
@@ -231,13 +231,13 @@ function resolveService(name: string, context: ResolutionContext): string | null
   return null;
 }

-function resolveMiddleware(name: string, context: ResolutionContext): string | null {
+async function resolveMiddleware(name: string, context: ResolutionContext): Promise<string | null> {
   const middlewareDirs = ['middleware', 'middlewares'];
-  const allFiles = context.getAllFiles();
+  const allFiles = await context.getAllFiles();

   for (const file of allFiles) {
     if (file.endsWith('.go') && middlewareDirs.some((d) => file.includes(`/${d}/`))) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const mwNode = nodes.find(
         (n) => n.kind === 'function' && n.name === name
       );
@@ -250,13 +250,13 @@ function resolveMiddleware(name: string, context: ResolutionContext): string | n
   return null;
 }

-function resolveModel(name: string, context: ResolutionContext): string | null {
+async function resolveModel(name: string, context: ResolutionContext): Promise<string | null> {
   const modelDirs = ['model', 'models', 'entity', 'entities', 'domain', 'pkg'];
-  const allFiles = context.getAllFiles();
+  const allFiles = await context.getAllFiles();

   for (const file of allFiles) {
     if (file.endsWith('.go') && modelDirs.some((d) => file.includes(`/${d}/`))) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const modelNode = nodes.find(
         (n) => n.kind === 'struct' && n.name === name
       );
diff --git a/src/resolution/frameworks/index.ts b/src/resolution/frameworks/index.ts
index 830a7d62..3752c6bc 100644
--- a/src/resolution/frameworks/index.ts
+++ b/src/resolution/frameworks/index.ts
@@ -64,14 +64,18 @@ export function getFrameworkResolver(name: string): FrameworkResolver | undefine
 /**
  * Detect which frameworks are used in a project
  */
-export function detectFrameworks(context: ResolutionContext): FrameworkResolver[] {
-  return FRAMEWORK_RESOLVERS.filter((resolver) => {
-    try {
-      return resolver.detect(context);
-    } catch {
-      return false;
-    }
-  });
+export async function detectFrameworks(context: ResolutionContext): Promise<FrameworkResolver[]> {
+  const results = await Promise.all(
+    FRAMEWORK_RESOLVERS.map(async (resolver) => {
+      try {
+        const detected = await resolver.detect(context);
+        return detected ? resolver : null;
+      } catch {
+        return null;
+      }
+    })
+  );
+  return results.filter((r): r is FrameworkResolver => r !== null);
 }

 /**
diff --git a/src/resolution/frameworks/java.ts b/src/resolution/frameworks/java.ts
index 6bb3ae7d..9467b05d 100644
--- a/src/resolution/frameworks/java.ts
+++ b/src/resolution/frameworks/java.ts
@@ -10,7 +10,7 @@ import { FrameworkResolver, UnresolvedRef, ResolvedRef, ResolutionContext } from
 export const springResolver: FrameworkResolver = {
   name: 'spring',

-  detect(context: ResolutionContext): boolean {
+  async detect(context: ResolutionContext): Promise<boolean> {
     // Check for pom.xml with Spring
     const pomXml = context.readFile('pom.xml');
     if (pomXml && (pomXml.includes('spring-boot') || pomXml.includes('springframework'))) {
@@ -29,7 +29,7 @@ export const springResolver: FrameworkResolver = {
     }

     // Check for Spring annotations in Java files
-    const allFiles = context.getAllFiles();
+    const allFiles = await context.getAllFiles();
     for (const file of allFiles) {
       if (file.endsWith('.java')) {
         const content = context.readFile(file);
@@ -47,10 +47,10 @@ export const springResolver: FrameworkResolver = {
     return false;
   },

-  resolve(ref: UnresolvedRef, context: ResolutionContext): ResolvedRef | null {
+  async resolve(ref: UnresolvedRef, context: ResolutionContext): Promise<ResolvedRef | null> {
     // Pattern 1: Service references (dependency injection)
     if (ref.referenceName.endsWith('Service')) {
-      const result = resolveService(ref.referenceName, context);
+      const result = await resolveService(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -63,7 +63,7 @@
     // Pattern 2: Repository references
     if (ref.referenceName.endsWith('Repository')) {
-      const result = resolveRepository(ref.referenceName, context);
+      const result = await resolveRepository(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -76,7 +76,7 @@
     // Pattern 3: Controller references
     if (ref.referenceName.endsWith('Controller')) {
-      const result = resolveController(ref.referenceName, context);
+      const result = await resolveController(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -89,7 +89,7 @@
     // Pattern 4: Entity/Model references
     if (/^[A-Z][a-zA-Z]+$/.test(ref.referenceName)) {
-      const result = resolveEntity(ref.referenceName, context);
+      const result = await resolveEntity(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -102,7 +102,7 @@
     // Pattern 5: Component references
     if (ref.referenceName.endsWith('Component') || ref.referenceName.endsWith('Config')) {
-      const result = resolveComponent(ref.referenceName, context);
+      const result = await resolveComponent(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -178,12 +178,12 @@
 // Helper functions

-function resolveService(name: string, context: ResolutionContext): string | null {
-  const allFiles = context.getAllFiles();
+async function resolveService(name: string, context: ResolutionContext): Promise<string | null> {
+  const allFiles = await context.getAllFiles();

   for (const file of allFiles) {
     if (file.endsWith('.java') && (file.includes('/service/') || file.includes('/services/'))) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const serviceNode = nodes.find(
         (n) => n.kind === 'class' && n.name === name
       );
@@ -196,7 +196,7 @@
   // Also check interface definitions
   for (const file of allFiles) {
     if (file.endsWith('.java')) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const serviceNode = nodes.find(
         (n) => (n.kind === 'class' || n.kind === 'interface') && n.name === name
       );
@@ -209,12 +209,12 @@ function resolveService(name: string, context: ResolutionContext): string | null
   return null;
 }

-function resolveRepository(name: string, context: ResolutionContext): string | null {
-  const allFiles = context.getAllFiles();
+async function resolveRepository(name: string, context: ResolutionContext): Promise<string | null> {
+  const allFiles = await context.getAllFiles();

   for (const file of allFiles) {
     if (file.endsWith('.java') && (file.includes('/repository/') || file.includes('/repositories/'))) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const repoNode = nodes.find(
         (n) => (n.kind === 'class' || n.kind === 'interface') && n.name === name
       );
@@ -227,12 +227,12 @@
   return null;
 }

-function resolveController(name: string, context: ResolutionContext): string | null {
-  const allFiles = context.getAllFiles();
+async function resolveController(name: string, context: ResolutionContext): Promise<string | null> {
+  const allFiles = await context.getAllFiles();

   for (const file of allFiles) {
     if (file.endsWith('.java') && (file.includes('/controller/') || file.includes('/controllers/'))) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const controllerNode = nodes.find(
         (n) => n.kind === 'class' && n.name === name
       );
@@ -245,8 +245,8 @@ function resolveController(name: string, context: ResolutionContext): string | n
   return null;
 }

-function resolveEntity(name: string, context: ResolutionContext): string | null {
-  const allFiles = context.getAllFiles();
+async function resolveEntity(name: string, context: ResolutionContext): Promise<string | null> {
+  const allFiles = await context.getAllFiles();

   // Check entity/model directories first
   for (const file of allFiles) {
@@ -257,7 +257,7 @@ function resolveEntity(name: string, context: ResolutionContext): string | null
       file.includes('/models/') ||
       file.includes('/domain/')
     )) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const entityNode = nodes.find(
         (n) => n.kind === 'class' && n.name === name
       );
@@ -270,8 +270,8 @@ function resolveEntity(name: string, context: ResolutionContext): string | null
   return null;
 }

-function resolveComponent(name: string, context: ResolutionContext): string | null {
-  const allFiles = context.getAllFiles();
+async function resolveComponent(name: string, context: ResolutionContext): Promise<string | null> {
+  const allFiles = await context.getAllFiles();

   for (const file of allFiles) {
     if (file.endsWith('.java') && (
@@ -279,7 +279,7 @@ function resolveComponent(name: string, context: ResolutionContext): string | nu
       file.includes('/components/') ||
       file.includes('/config/')
     )) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const componentNode = nodes.find(
         (n) => n.kind === 'class' && n.name === name
       );
diff --git a/src/resolution/frameworks/laravel.ts b/src/resolution/frameworks/laravel.ts
index 05d0559c..0ce6ee0f 100644
--- a/src/resolution/frameworks/laravel.ts
+++ b/src/resolution/frameworks/laravel.ts
@@ -37,17 +37,17 @@ export const FACADE_MAPPINGS: Record = {
 export const laravelResolver: FrameworkResolver = {
   name: 'laravel',

-  detect(context: ResolutionContext): boolean {
+  async detect(context: ResolutionContext): Promise<boolean> {
     // Check for artisan file (Laravel signature)
     return context.fileExists('artisan') || context.fileExists('app/Http/Kernel.php');
   },

-  resolve(ref: UnresolvedRef, context: ResolutionContext): ResolvedRef | null {
+  async resolve(ref: UnresolvedRef, context: ResolutionContext): Promise<ResolvedRef | null> {
     // Pattern 1: Model::method() - Eloquent static calls
     const modelMatch = ref.referenceName.match(/^([A-Z][a-zA-Z]+)::(\w+)$/);
     if (modelMatch) {
       const [, className, methodName] = modelMatch;
-      const result = resolveModelCall(className!, methodName!, context);
+      const result = await resolveModelCall(className!, methodName!, context);
       if (result) {
         return {
           original: ref,
@@ -76,7 +76,7 @@ export const laravelResolver: FrameworkResolver = {
     const controllerMatch = ref.referenceName.match(/^([A-Z][a-zA-Z]+Controller)@(\w+)$/);
     if (controllerMatch) {
       const [, controller, method] = controllerMatch;
-      const result = resolveControllerMethod(controller!, method!, context);
+      const result = await resolveControllerMethod(controller!, method!, context);
       if (result) {
         return {
           original: ref,
@@ -150,15 +150,15 @@ export const laravelResolver: FrameworkResolver = {
 /**
  * Resolve a Model::method() call
  */
-function resolveModelCall(
+async function resolveModelCall(
   className: string,
   methodName: string,
   context: ResolutionContext
-): string | null {
+): Promise<string | null> {
   // Try app/Models/ first (Laravel 8+)
   let modelPath = `app/Models/${className}.php`;
   if (context.fileExists(modelPath)) {
-    const nodes = context.getNodesInFile(modelPath);
+    const nodes = await context.getNodesInFile(modelPath);
     // Look for the method in this class
     const methodNode = nodes.find(
       (n) => n.kind === 'method' && n.name === methodName
@@ -178,7 +178,7 @@ function resolveModelCall(
   // Try app/ (Laravel 7 and below)
   modelPath = `app/${className}.php`;
   if (context.fileExists(modelPath)) {
-    const nodes = context.getNodesInFile(modelPath);
+    const nodes = await context.getNodesInFile(modelPath);
     const methodNode = nodes.find(
       (n) => n.kind === 'method' && n.name === methodName
     );
@@ -199,15 +199,15 @@
 /**
  * Resolve a Controller@method reference
  */
-function resolveControllerMethod(
+async function resolveControllerMethod(
   controller: string,
   method: string,
   context: ResolutionContext
-): string | null {
+): Promise<string | null> {
   // Try app/Http/Controllers/
   const controllerPath = `app/Http/Controllers/${controller}.php`;
   if (context.fileExists(controllerPath)) {
-    const nodes = context.getNodesInFile(controllerPath);
+    const nodes = await context.getNodesInFile(controllerPath);
     const methodNode =
nodes.find( (n) => n.kind === 'method' && n.name === method ); @@ -217,10 +217,10 @@ function resolveControllerMethod( } // Try subdirectories (namespaced controllers) - const allFiles = context.getAllFiles(); + const allFiles = await context.getAllFiles(); for (const file of allFiles) { if (file.endsWith(`${controller}.php`) && file.includes('Controllers')) { - const nodes = context.getNodesInFile(file); + const nodes = await context.getNodesInFile(file); const methodNode = nodes.find( (n) => n.kind === 'method' && n.name === method ); diff --git a/src/resolution/frameworks/python.ts b/src/resolution/frameworks/python.ts index b91623fb..61cfc48e 100644 --- a/src/resolution/frameworks/python.ts +++ b/src/resolution/frameworks/python.ts @@ -10,7 +10,7 @@ import { FrameworkResolver, UnresolvedRef, ResolvedRef, ResolutionContext } from export const djangoResolver: FrameworkResolver = { name: 'django', - detect(context: ResolutionContext): boolean { + async detect(context: ResolutionContext): Promise { // Check for Django in requirements.txt or setup.py const requirements = context.readFile('requirements.txt'); if (requirements && requirements.includes('django')) { @@ -31,10 +31,10 @@ export const djangoResolver: FrameworkResolver = { return context.fileExists('manage.py'); }, - resolve(ref: UnresolvedRef, context: ResolutionContext): ResolvedRef | null { + async resolve(ref: UnresolvedRef, context: ResolutionContext): Promise { // Pattern 1: Model references if (ref.referenceName.endsWith('Model') || /^[A-Z][a-z]+$/.test(ref.referenceName)) { - const result = resolveModel(ref.referenceName, context); + const result = await resolveModel(ref.referenceName, context); if (result) { return { original: ref, @@ -47,7 +47,7 @@ export const djangoResolver: FrameworkResolver = { // Pattern 2: View references if (ref.referenceName.endsWith('View') || ref.referenceName.endsWith('ViewSet')) { - const result = resolveView(ref.referenceName, context); + const result = await 
resolveView(ref.referenceName, context); if (result) { return { original: ref, @@ -60,7 +60,7 @@ export const djangoResolver: FrameworkResolver = { // Pattern 3: Form references if (ref.referenceName.endsWith('Form')) { - const result = resolveForm(ref.referenceName, context); + const result = await resolveForm(ref.referenceName, context); if (result) { return { original: ref, @@ -114,7 +114,7 @@ export const djangoResolver: FrameworkResolver = { export const flaskResolver: FrameworkResolver = { name: 'flask', - detect(context: ResolutionContext): boolean { + async detect(context: ResolutionContext): Promise { const requirements = context.readFile('requirements.txt'); if (requirements && (requirements.includes('flask') || requirements.includes('Flask'))) { return true; @@ -137,10 +137,10 @@ export const flaskResolver: FrameworkResolver = { return false; }, - resolve(ref: UnresolvedRef, context: ResolutionContext): ResolvedRef | null { + async resolve(ref: UnresolvedRef, context: ResolutionContext): Promise { // Pattern 1: Blueprint references if (ref.referenceName.endsWith('_bp') || ref.referenceName.endsWith('_blueprint')) { - const result = resolveBlueprint(ref.referenceName, context); + const result = await resolveBlueprint(ref.referenceName, context); if (result) { return { original: ref, @@ -189,7 +189,7 @@ export const flaskResolver: FrameworkResolver = { export const fastapiResolver: FrameworkResolver = { name: 'fastapi', - detect(context: ResolutionContext): boolean { + async detect(context: ResolutionContext): Promise { const requirements = context.readFile('requirements.txt'); if (requirements && requirements.includes('fastapi')) { return true; @@ -212,10 +212,10 @@ export const fastapiResolver: FrameworkResolver = { return false; }, - resolve(ref: UnresolvedRef, context: ResolutionContext): ResolvedRef | null { + async resolve(ref: UnresolvedRef, context: ResolutionContext): Promise { // Pattern 1: Router references if 
(ref.referenceName.endsWith('_router') || ref.referenceName === 'router') { - const result = resolveRouter(ref.referenceName, context); + const result = await resolveRouter(ref.referenceName, context); if (result) { return { original: ref, @@ -228,7 +228,7 @@ export const fastapiResolver: FrameworkResolver = { // Pattern 2: Dependency references if (ref.referenceName.startsWith('get_') || ref.referenceName.startsWith('Depends')) { - const result = resolveDependency(ref.referenceName, context); + const result = await resolveDependency(ref.referenceName, context); if (result) { return { original: ref, @@ -276,14 +276,14 @@ export const fastapiResolver: FrameworkResolver = { // Helper functions -function resolveModel(name: string, context: ResolutionContext): string | null { +async function resolveModel(name: string, context: ResolutionContext): Promise { const modelDirs = ['models', 'app/models', 'src/models']; for (const dir of modelDirs) { - const allFiles = context.getAllFiles(); + const allFiles = await context.getAllFiles(); for (const file of allFiles) { if (file.startsWith(dir) && file.endsWith('.py')) { - const nodes = context.getNodesInFile(file); + const nodes = await context.getNodesInFile(file); const modelNode = nodes.find( (n) => n.kind === 'class' && n.name === name ); @@ -297,14 +297,14 @@ function resolveModel(name: string, context: ResolutionContext): string | null { return null; } -function resolveView(name: string, context: ResolutionContext): string | null { +async function resolveView(name: string, context: ResolutionContext): Promise { const viewDirs = ['views', 'app/views', 'src/views', 'api/views']; for (const dir of viewDirs) { - const allFiles = context.getAllFiles(); + const allFiles = await context.getAllFiles(); for (const file of allFiles) { if (file.startsWith(dir) && file.endsWith('.py')) { - const nodes = context.getNodesInFile(file); + const nodes = await context.getNodesInFile(file); const viewNode = nodes.find( (n) => (n.kind === 
'class' || n.kind === 'function') && n.name === name ); @@ -318,14 +318,14 @@ function resolveView(name: string, context: ResolutionContext): string | null { return null; } -function resolveForm(name: string, context: ResolutionContext): string | null { +async function resolveForm(name: string, context: ResolutionContext): Promise { const formDirs = ['forms', 'app/forms', 'src/forms']; for (const dir of formDirs) { - const allFiles = context.getAllFiles(); + const allFiles = await context.getAllFiles(); for (const file of allFiles) { if (file.startsWith(dir) && file.endsWith('.py')) { - const nodes = context.getNodesInFile(file); + const nodes = await context.getNodesInFile(file); const formNode = nodes.find( (n) => n.kind === 'class' && n.name === name ); @@ -339,11 +339,11 @@ function resolveForm(name: string, context: ResolutionContext): string | null { return null; } -function resolveBlueprint(name: string, context: ResolutionContext): string | null { - const allFiles = context.getAllFiles(); +async function resolveBlueprint(name: string, context: ResolutionContext): Promise { + const allFiles = await context.getAllFiles(); for (const file of allFiles) { if (file.endsWith('.py')) { - const nodes = context.getNodesInFile(file); + const nodes = await context.getNodesInFile(file); const bpNode = nodes.find( (n) => n.kind === 'variable' && n.name === name ); @@ -356,14 +356,14 @@ function resolveBlueprint(name: string, context: ResolutionContext): string | nu return null; } -function resolveRouter(name: string, context: ResolutionContext): string | null { +async function resolveRouter(name: string, context: ResolutionContext): Promise { const routerDirs = ['routers', 'api', 'routes', 'endpoints']; for (const dir of routerDirs) { - const allFiles = context.getAllFiles(); + const allFiles = await context.getAllFiles(); for (const file of allFiles) { if ((file.startsWith(dir) || file.includes('/routers/')) && file.endsWith('.py')) { - const nodes = 
context.getNodesInFile(file); + const nodes = await context.getNodesInFile(file); const routerNode = nodes.find( (n) => n.kind === 'variable' && n.name === name ); @@ -377,14 +377,14 @@ function resolveRouter(name: string, context: ResolutionContext): string | null return null; } -function resolveDependency(name: string, context: ResolutionContext): string | null { +async function resolveDependency(name: string, context: ResolutionContext): Promise { const depDirs = ['dependencies', 'deps', 'core']; for (const dir of depDirs) { - const allFiles = context.getAllFiles(); + const allFiles = await context.getAllFiles(); for (const file of allFiles) { if ((file.startsWith(dir) || file.includes('/dependencies/')) && file.endsWith('.py')) { - const nodes = context.getNodesInFile(file); + const nodes = await context.getNodesInFile(file); const depNode = nodes.find( (n) => n.kind === 'function' && n.name === name ); diff --git a/src/resolution/frameworks/react.ts b/src/resolution/frameworks/react.ts index f1ec697c..d0dc9d2c 100644 --- a/src/resolution/frameworks/react.ts +++ b/src/resolution/frameworks/react.ts @@ -10,7 +10,7 @@ import { FrameworkResolver, UnresolvedRef, ResolvedRef, ResolutionContext } from export const reactResolver: FrameworkResolver = { name: 'react', - detect(context: ResolutionContext): boolean { + async detect(context: ResolutionContext): Promise { // Check for React in package.json const packageJson = context.readFile('package.json'); if (packageJson) { @@ -26,14 +26,14 @@ export const reactResolver: FrameworkResolver = { } // Check for .jsx/.tsx files - const allFiles = context.getAllFiles(); + const allFiles = await context.getAllFiles(); return allFiles.some((f) => f.endsWith('.jsx') || f.endsWith('.tsx')); }, - resolve(ref: UnresolvedRef, context: ResolutionContext): ResolvedRef | null { + async resolve(ref: UnresolvedRef, context: ResolutionContext): Promise { // Pattern 1: Component references (PascalCase) if (isPascalCase(ref.referenceName) 
&& !isBuiltInType(ref.referenceName)) { - const result = resolveComponent(ref.referenceName, ref.filePath, context); + const result = await resolveComponent(ref.referenceName, ref.filePath, context); if (result) { return { original: ref, @@ -46,7 +46,7 @@ export const reactResolver: FrameworkResolver = { // Pattern 2: Hook references (use*) if (ref.referenceName.startsWith('use') && ref.referenceName.length > 3) { - const result = resolveHook(ref.referenceName, context); + const result = await resolveHook(ref.referenceName, context); if (result) { return { original: ref, @@ -59,7 +59,7 @@ export const reactResolver: FrameworkResolver = { // Pattern 3: Context references if (ref.referenceName.endsWith('Context') || ref.referenceName.endsWith('Provider')) { - const result = resolveContext(ref.referenceName, context); + const result = await resolveContext(ref.referenceName, context); if (result) { return { original: ref, @@ -194,11 +194,11 @@ function isBuiltInType(name: string): boolean { /** * Resolve a component reference */ -function resolveComponent( +async function resolveComponent( name: string, fromFile: string, context: ResolutionContext -): string | null { +): Promise { // Look for component in common locations const componentDirs = [ 'components', @@ -212,10 +212,11 @@ function resolveComponent( // First, check same directory const fromDir = fromFile.substring(0, fromFile.lastIndexOf('/')); - const sameDir = context.getAllFiles().filter((f) => f.startsWith(fromDir)); + const allFilesForSameDir = await context.getAllFiles(); + const sameDir = allFilesForSameDir.filter((f) => f.startsWith(fromDir)); for (const file of sameDir) { if (file.toLowerCase().includes(name.toLowerCase())) { - const nodes = context.getNodesInFile(file); + const nodes = await context.getNodesInFile(file); const component = nodes.find( (n) => (n.kind === 'component' || n.kind === 'function' || n.kind === 'class') && n.name === name ); @@ -227,10 +228,10 @@ function resolveComponent( // 
Then check component directories for (const dir of componentDirs) { - const allFiles = context.getAllFiles(); + const allFiles = await context.getAllFiles(); for (const file of allFiles) { if (file.startsWith(dir) && file.toLowerCase().includes(name.toLowerCase())) { - const nodes = context.getNodesInFile(file); + const nodes = await context.getNodesInFile(file); const component = nodes.find( (n) => (n.kind === 'component' || n.kind === 'function' || n.kind === 'class') && n.name === name ); @@ -247,14 +248,14 @@ function resolveComponent( /** * Resolve a custom hook reference */ -function resolveHook(name: string, context: ResolutionContext): string | null { +async function resolveHook(name: string, context: ResolutionContext): Promise { const hookDirs = ['hooks', 'src/hooks', 'lib/hooks', 'utils/hooks']; for (const dir of hookDirs) { - const allFiles = context.getAllFiles(); + const allFiles = await context.getAllFiles(); for (const file of allFiles) { if (file.startsWith(dir) || file.includes('/hooks/')) { - const nodes = context.getNodesInFile(file); + const nodes = await context.getNodesInFile(file); const hook = nodes.find((n) => n.kind === 'function' && n.name === name); if (hook) { return hook.id; @@ -264,7 +265,7 @@ function resolveHook(name: string, context: ResolutionContext): string | null { } // Also check all files for the hook - const allNodes = context.getNodesByName(name); + const allNodes = await context.getNodesByName(name); const hookNode = allNodes.find((n) => n.kind === 'function' && n.name.startsWith('use')); if (hookNode) { return hookNode.id; @@ -276,14 +277,14 @@ function resolveHook(name: string, context: ResolutionContext): string | null { /** * Resolve a context reference */ -function resolveContext(name: string, context: ResolutionContext): string | null { +async function resolveContext(name: string, context: ResolutionContext): Promise { const contextDirs = ['context', 'contexts', 'src/context', 'src/contexts', 'providers', 
'src/providers']; for (const dir of contextDirs) { - const allFiles = context.getAllFiles(); + const allFiles = await context.getAllFiles(); for (const file of allFiles) { if (file.startsWith(dir) || file.includes('/context/') || file.includes('/contexts/')) { - const nodes = context.getNodesInFile(file); + const nodes = await context.getNodesInFile(file); const contextNode = nodes.find((n) => n.name === name || n.name === name.replace(/Context$|Provider$/, '')); if (contextNode) { return contextNode.id; diff --git a/src/resolution/frameworks/ruby.ts b/src/resolution/frameworks/ruby.ts index c2c6f402..937f3460 100644 --- a/src/resolution/frameworks/ruby.ts +++ b/src/resolution/frameworks/ruby.ts @@ -10,7 +10,7 @@ import { FrameworkResolver, UnresolvedRef, ResolvedRef, ResolutionContext } from export const railsResolver: FrameworkResolver = { name: 'rails', - detect(context: ResolutionContext): boolean { + async detect(context: ResolutionContext): Promise { // Check for Gemfile with rails const gemfile = context.readFile('Gemfile'); if (gemfile && gemfile.includes("'rails'")) { @@ -29,10 +29,10 @@ export const railsResolver: FrameworkResolver = { ); }, - resolve(ref: UnresolvedRef, context: ResolutionContext): ResolvedRef | null { + async resolve(ref: UnresolvedRef, context: ResolutionContext): Promise { // Pattern 1: Model references (ActiveRecord) if (/^[A-Z][a-zA-Z]+$/.test(ref.referenceName)) { - const result = resolveModel(ref.referenceName, context); + const result = await resolveModel(ref.referenceName, context); if (result) { return { original: ref, @@ -45,7 +45,7 @@ export const railsResolver: FrameworkResolver = { // Pattern 2: Controller references if (ref.referenceName.endsWith('Controller')) { - const result = resolveController(ref.referenceName, context); + const result = await resolveController(ref.referenceName, context); if (result) { return { original: ref, @@ -58,7 +58,7 @@ export const railsResolver: FrameworkResolver = { // Pattern 3: Helper 
references if (ref.referenceName.endsWith('Helper')) { - const result = resolveHelper(ref.referenceName, context); + const result = await resolveHelper(ref.referenceName, context); if (result) { return { original: ref, @@ -71,7 +71,7 @@ export const railsResolver: FrameworkResolver = { // Pattern 4: Service/Job references if (ref.referenceName.endsWith('Service') || ref.referenceName.endsWith('Job')) { - const result = resolveService(ref.referenceName, context); + const result = await resolveService(ref.referenceName, context); if (result) { return { original: ref, @@ -188,7 +188,7 @@ export const railsResolver: FrameworkResolver = { // Helper functions -function resolveModel(name: string, context: ResolutionContext): string | null { +async function resolveModel(name: string, context: ResolutionContext): Promise { // Convert CamelCase to snake_case for file lookup const snakeName = name.replace(/([A-Z])/g, '_$1').toLowerCase().slice(1); const possiblePaths = [ @@ -198,7 +198,7 @@ function resolveModel(name: string, context: ResolutionContext): string | null { for (const modelPath of possiblePaths) { if (context.fileExists(modelPath)) { - const nodes = context.getNodesInFile(modelPath); + const nodes = await context.getNodesInFile(modelPath); const modelNode = nodes.find( (n) => n.kind === 'class' && n.name === name ); @@ -209,10 +209,10 @@ function resolveModel(name: string, context: ResolutionContext): string | null { } // Search all model files - const allFiles = context.getAllFiles(); + const allFiles = await context.getAllFiles(); for (const file of allFiles) { if (file.includes('app/models/') && file.endsWith('.rb')) { - const nodes = context.getNodesInFile(file); + const nodes = await context.getNodesInFile(file); const modelNode = nodes.find( (n) => n.kind === 'class' && n.name === name ); @@ -225,7 +225,7 @@ function resolveModel(name: string, context: ResolutionContext): string | null { return null; } -function resolveController(name: string, context: 
ResolutionContext): string | null { +async function resolveController(name: string, context: ResolutionContext): Promise { // Convert CamelCase to snake_case const snakeName = name.replace(/([A-Z])/g, '_$1').toLowerCase().slice(1); const possiblePaths = [ @@ -236,7 +236,7 @@ function resolveController(name: string, context: ResolutionContext): string | n for (const controllerPath of possiblePaths) { if (context.fileExists(controllerPath)) { - const nodes = context.getNodesInFile(controllerPath); + const nodes = await context.getNodesInFile(controllerPath); const controllerNode = nodes.find( (n) => n.kind === 'class' && n.name === name ); @@ -247,10 +247,10 @@ function resolveController(name: string, context: ResolutionContext): string | n } // Search all controller files - const allFiles = context.getAllFiles(); + const allFiles = await context.getAllFiles(); for (const file of allFiles) { if (file.includes('controllers/') && file.endsWith('.rb')) { - const nodes = context.getNodesInFile(file); + const nodes = await context.getNodesInFile(file); const controllerNode = nodes.find( (n) => n.kind === 'class' && n.name === name ); @@ -263,12 +263,12 @@ function resolveController(name: string, context: ResolutionContext): string | n return null; } -function resolveHelper(name: string, context: ResolutionContext): string | null { +async function resolveHelper(name: string, context: ResolutionContext): Promise { const snakeName = name.replace(/([A-Z])/g, '_$1').toLowerCase().slice(1); const helperPath = `app/helpers/${snakeName}.rb`; if (context.fileExists(helperPath)) { - const nodes = context.getNodesInFile(helperPath); + const nodes = await context.getNodesInFile(helperPath); const helperNode = nodes.find( (n) => n.kind === 'module' && n.name === name ); @@ -280,7 +280,7 @@ function resolveHelper(name: string, context: ResolutionContext): string | null return null; } -function resolveService(name: string, context: ResolutionContext): string | null { +async function 
resolveService(name: string, context: ResolutionContext): Promise { const snakeName = name.replace(/([A-Z])/g, '_$1').toLowerCase().slice(1); const possiblePaths = [ `app/services/${snakeName}.rb`, @@ -290,7 +290,7 @@ function resolveService(name: string, context: ResolutionContext): string | null for (const servicePath of possiblePaths) { if (context.fileExists(servicePath)) { - const nodes = context.getNodesInFile(servicePath); + const nodes = await context.getNodesInFile(servicePath); const serviceNode = nodes.find( (n) => n.kind === 'class' && n.name === name ); diff --git a/src/resolution/frameworks/rust.ts b/src/resolution/frameworks/rust.ts index 68b6b2d4..9f16fa21 100644 --- a/src/resolution/frameworks/rust.ts +++ b/src/resolution/frameworks/rust.ts @@ -10,15 +10,15 @@ import { FrameworkResolver, UnresolvedRef, ResolvedRef, ResolutionContext } from export const rustResolver: FrameworkResolver = { name: 'rust', - detect(context: ResolutionContext): boolean { + async detect(context: ResolutionContext): Promise { // Check for Cargo.toml (Rust project signature) return context.fileExists('Cargo.toml'); }, - resolve(ref: UnresolvedRef, context: ResolutionContext): ResolvedRef | null { + async resolve(ref: UnresolvedRef, context: ResolutionContext): Promise { // Pattern 1: Handler references if (ref.referenceName.endsWith('_handler') || ref.referenceName.startsWith('handle_')) { - const result = resolveHandler(ref.referenceName, context); + const result = await resolveHandler(ref.referenceName, context); if (result) { return { original: ref, @@ -31,7 +31,7 @@ export const rustResolver: FrameworkResolver = { // Pattern 2: Service/Repository trait implementations if (ref.referenceName.endsWith('Service') || ref.referenceName.endsWith('Repository')) { - const result = resolveService(ref.referenceName, context); + const result = await resolveService(ref.referenceName, context); if (result) { return { original: ref, @@ -44,7 +44,7 @@ export const rustResolver: 
FrameworkResolver = { // Pattern 3: Struct references (PascalCase) if (/^[A-Z][a-zA-Z]+$/.test(ref.referenceName)) { - const result = resolveStruct(ref.referenceName, context); + const result = await resolveStruct(ref.referenceName, context); if (result) { return { original: ref, @@ -57,7 +57,7 @@ export const rustResolver: FrameworkResolver = { // Pattern 4: Module references if (/^[a-z_]+$/.test(ref.referenceName)) { - const result = resolveModule(ref.referenceName, context); + const result = await resolveModule(ref.referenceName, context); if (result) { return { original: ref, @@ -155,13 +155,13 @@ export const rustResolver: FrameworkResolver = { // Helper functions -function resolveHandler(name: string, context: ResolutionContext): string | null { +async function resolveHandler(name: string, context: ResolutionContext): Promise { const handlerDirs = ['handlers', 'handler', 'api', 'routes', 'controllers']; - const allFiles = context.getAllFiles(); + const allFiles = await context.getAllFiles(); for (const file of allFiles) { if (file.endsWith('.rs') && handlerDirs.some((d) => file.includes(`/${d}/`) || file.includes(`/${d}.rs`))) { - const nodes = context.getNodesInFile(file); + const nodes = await context.getNodesInFile(file); const handlerNode = nodes.find( (n) => n.kind === 'function' && n.name === name ); @@ -174,7 +174,7 @@ function resolveHandler(name: string, context: ResolutionContext): string | null // Search all Rust files for (const file of allFiles) { if (file.endsWith('.rs')) { - const nodes = context.getNodesInFile(file); + const nodes = await context.getNodesInFile(file); const handlerNode = nodes.find( (n) => n.kind === 'function' && n.name === name ); @@ -187,13 +187,13 @@ function resolveHandler(name: string, context: ResolutionContext): string | null return null; } -function resolveService(name: string, context: ResolutionContext): string | null { +async function resolveService(name: string, context: ResolutionContext): Promise { const 
serviceDirs = ['services', 'service', 'repository', 'domain']; - const allFiles = context.getAllFiles(); + const allFiles = await context.getAllFiles(); for (const file of allFiles) { if (file.endsWith('.rs') && serviceDirs.some((d) => file.includes(`/${d}/`) || file.includes(`/${d}.rs`))) { - const nodes = context.getNodesInFile(file); + const nodes = await context.getNodesInFile(file); const serviceNode = nodes.find( (n) => (n.kind === 'struct' || n.kind === 'trait') && n.name === name ); @@ -206,15 +206,15 @@ function resolveService(name: string, context: ResolutionContext): string | null return null; } -function resolveStruct(name: string, context: ResolutionContext): string | null { +async function resolveStruct(name: string, context: ResolutionContext): Promise { const modelDirs = ['models', 'model', 'entities', 'entity', 'domain', 'types']; - const allFiles = context.getAllFiles(); + const allFiles = await context.getAllFiles(); // Check model directories first for (const file of allFiles) { if (file.endsWith('.rs') && modelDirs.some((d) => file.includes(`/${d}/`) || file.includes(`/${d}.rs`))) { - const nodes = context.getNodesInFile(file); + const nodes = await context.getNodesInFile(file); const structNode = nodes.find( (n) => n.kind === 'struct' && n.name === name ); @@ -227,7 +227,7 @@ function resolveStruct(name: string, context: ResolutionContext): string | null // Search all Rust files for (const file of allFiles) { if (file.endsWith('.rs')) { - const nodes = context.getNodesInFile(file); + const nodes = await context.getNodesInFile(file); const structNode = nodes.find( (n) => n.kind === 'struct' && n.name === name ); @@ -240,7 +240,7 @@ function resolveStruct(name: string, context: ResolutionContext): string | null return null; } -function resolveModule(name: string, context: ResolutionContext): string | null { +async function resolveModule(name: string, context: ResolutionContext): Promise { // Rust modules can be either mod.rs in a directory or 
name.rs const possiblePaths = [ `src/${name}.rs`, @@ -249,7 +249,7 @@ function resolveModule(name: string, context: ResolutionContext): string | null for (const modPath of possiblePaths) { if (context.fileExists(modPath)) { - const nodes = context.getNodesInFile(modPath); + const nodes = await context.getNodesInFile(modPath); const modNode = nodes.find((n) => n.kind === 'module'); if (modNode) { return modNode.id; diff --git a/src/resolution/frameworks/svelte.ts b/src/resolution/frameworks/svelte.ts index 5e0fd9a0..efe5cf0d 100644 --- a/src/resolution/frameworks/svelte.ts +++ b/src/resolution/frameworks/svelte.ts @@ -9,7 +9,7 @@ import { Node } from '../../types'; import { FrameworkResolver, UnresolvedRef, ResolvedRef, ResolutionContext } from '../types'; /** - * Svelte 5 runes — compiler-provided, not user code + * Svelte 5 runes -- compiler-provided, not user code */ const SVELTE_RUNES = new Set([ '$state', @@ -45,7 +45,7 @@ const SVELTEKIT_MODULE_PREFIXES = [ export const svelteResolver: FrameworkResolver = { name: 'svelte', - detect(context: ResolutionContext): boolean { + async detect(context: ResolutionContext): Promise { // Check for svelte or @sveltejs/kit in package.json const packageJson = context.readFile('package.json'); if (packageJson) { @@ -61,14 +61,14 @@ export const svelteResolver: FrameworkResolver = { } // Check for .svelte files in project - const allFiles = context.getAllFiles(); + const allFiles = await context.getAllFiles(); return allFiles.some((f) => f.endsWith('.svelte')); }, - resolve(ref: UnresolvedRef, context: ResolutionContext): ResolvedRef | null { + async resolve(ref: UnresolvedRef, context: ResolutionContext): Promise { // Pattern 1: Svelte runes ($state, $derived, $effect, etc.) 
     if (isRuneReference(ref.referenceName)) {
-      // Runes are compiler-provided — return a high-confidence "framework" resolution
+      // Runes are compiler-provided -- return a high-confidence "framework" resolution
       // so CodeGraph doesn't waste time searching for user-defined symbols.
       // We use the fromNodeId as targetNodeId since runes don't have real targets.
       return {
@@ -82,7 +82,8 @@ export const svelteResolver: FrameworkResolver = {
     // Pattern 2: Store auto-subscriptions ($storeName)
     if (ref.referenceName.startsWith('$') && !ref.referenceName.startsWith('$$')) {
       const storeName = ref.referenceName.substring(1);
-      const storeNode = context.getNodesByName(storeName).find(
+      const storeNodes = await context.getNodesByName(storeName);
+      const storeNode = storeNodes.find(
         (n) => n.kind === 'variable' || n.kind === 'constant'
       );
       if (storeNode) {
@@ -97,14 +98,14 @@
     // Pattern 3: SvelteKit module imports ($app/*, $env/*, $lib/*)
     if (ref.referenceKind === 'imports' && ref.referenceName.startsWith('$')) {
-      // $lib/* resolves to src/lib/* — try to find the target file
+      // $lib/* resolves to src/lib/* -- try to find the target file
       if (ref.referenceName.startsWith('$lib/')) {
         const libPath = ref.referenceName.replace('$lib/', 'src/lib/');
         // Try common extensions
         for (const ext of ['', '.ts', '.js', '.svelte', '/index.ts', '/index.js']) {
           const fullPath = libPath + ext;
           if (context.fileExists(fullPath)) {
-            const nodes = context.getNodesInFile(fullPath);
+            const nodes = await context.getNodesInFile(fullPath);
             if (nodes.length > 0) {
               return {
                 original: ref,
@@ -128,9 +129,9 @@
       }
     }
 
-    // Pattern 4: Component references (PascalCase) — resolve to .svelte files
+    // Pattern 4: Component references (PascalCase) -- resolve to .svelte files
     if (isPascalCase(ref.referenceName) && ref.referenceKind === 'calls') {
-      const result = resolveComponent(ref.referenceName, ref.filePath, context);
+      const result = await resolveComponent(ref.referenceName, ref.filePath, context);
       if (result) {
         return {
           original: ref,
@@ -203,13 +204,13 @@ function isPascalCase(str: string): boolean {
 /**
  * Resolve a Svelte component reference to its .svelte file
 */
-function resolveComponent(
+async function resolveComponent(
   name: string,
   fromFile: string,
   context: ResolutionContext
-): string | null {
+): Promise<string | null> {
   // Look for matching .svelte files
-  const allFiles = context.getAllFiles();
+  const allFiles = await context.getAllFiles();
   const svelteFiles = allFiles.filter((f) => f.endsWith('.svelte'));
 
   // Check for exact name match (Button -> Button.svelte)
@@ -217,7 +218,7 @@ function resolveComponent(
     const fileName = file.split(/[/\\]/).pop() || '';
     const componentName = fileName.replace(/\.svelte$/, '');
     if (componentName === name) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const component = nodes.find((n) => n.kind === 'component' && n.name === name);
       if (component) {
         return component.id;
@@ -232,7 +233,7 @@ function resolveComponent(
     const fileName = file.split(/[/\\]/).pop() || '';
     const componentName = fileName.replace(/\.svelte$/, '');
     if (componentName === name) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const component = nodes.find((n) => n.kind === 'component');
       if (component) {
         return component.id;
diff --git a/src/resolution/frameworks/swift.ts b/src/resolution/frameworks/swift.ts
index 6a7f8a25..b3a8dad5 100644
--- a/src/resolution/frameworks/swift.ts
+++ b/src/resolution/frameworks/swift.ts
@@ -10,9 +10,9 @@ import { FrameworkResolver, UnresolvedRef, ResolvedRef, ResolutionContext } from
 export const swiftUIResolver: FrameworkResolver = {
   name: 'swiftui',
 
-  detect(context: ResolutionContext): boolean {
+  async detect(context: ResolutionContext): Promise<boolean> {
     // Check for SwiftUI imports in Swift files
-    const allFiles = context.getAllFiles();
+    const allFiles = await context.getAllFiles();
     for (const file of allFiles) {
       if (file.endsWith('.swift')) {
         const content = context.readFile(file);
@@ -32,10 +32,10 @@ export const swiftUIResolver: FrameworkResolver = {
     return false;
   },
 
-  resolve(ref: UnresolvedRef, context: ResolutionContext): ResolvedRef | null {
+  async resolve(ref: UnresolvedRef, context: ResolutionContext): Promise<ResolvedRef | null> {
     // Pattern 1: View references (SwiftUI views are PascalCase ending in View)
     if (ref.referenceName.endsWith('View') && /^[A-Z]/.test(ref.referenceName)) {
-      const result = resolveView(ref.referenceName, context);
+      const result = await resolveView(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -48,7 +48,7 @@ export const swiftUIResolver: FrameworkResolver = {
     // Pattern 2: ViewModel/ObservableObject references
     if (ref.referenceName.endsWith('ViewModel') || ref.referenceName.endsWith('Store') || ref.referenceName.endsWith('Manager')) {
-      const result = resolveViewModel(ref.referenceName, context);
+      const result = await resolveViewModel(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -61,7 +61,7 @@ export const swiftUIResolver: FrameworkResolver = {
     // Pattern 3: Model references
     if (/^[A-Z][a-zA-Z]+$/.test(ref.referenceName)) {
-      const result = resolveModel(ref.referenceName, context);
+      const result = await resolveModel(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -132,8 +132,8 @@ export const swiftUIResolver: FrameworkResolver = {
 export const uikitResolver: FrameworkResolver = {
   name: 'uikit',
 
-  detect(context: ResolutionContext): boolean {
-    const allFiles = context.getAllFiles();
+  async detect(context: ResolutionContext): Promise<boolean> {
+    const allFiles = await context.getAllFiles();
     for (const file of allFiles) {
       if (file.endsWith('.swift')) {
         const content = context.readFile(file);
@@ -150,10 +150,10 @@ export const uikitResolver: FrameworkResolver = {
     return false;
   },
 
-  resolve(ref: UnresolvedRef, context: ResolutionContext): ResolvedRef | null {
+  async resolve(ref: UnresolvedRef, context: ResolutionContext): Promise<ResolvedRef | null> {
     // Pattern 1: ViewController references
     if (ref.referenceName.endsWith('ViewController')) {
-      const result = resolveViewController(ref.referenceName, context);
+      const result = await resolveViewController(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -166,7 +166,7 @@ export const uikitResolver: FrameworkResolver = {
     // Pattern 2: UIView subclass references
     if (ref.referenceName.endsWith('View') && !ref.referenceName.endsWith('ViewController')) {
-      const result = resolveUIView(ref.referenceName, context);
+      const result = await resolveUIView(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -179,7 +179,7 @@ export const uikitResolver: FrameworkResolver = {
     // Pattern 3: Cell references
     if (ref.referenceName.endsWith('Cell')) {
-      const result = resolveCell(ref.referenceName, context);
+      const result = await resolveCell(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -192,7 +192,7 @@ export const uikitResolver: FrameworkResolver = {
     // Pattern 4: Delegate/DataSource references
     if (ref.referenceName.endsWith('Delegate') || ref.referenceName.endsWith('DataSource')) {
-      const result = resolveProtocol(ref.referenceName, context);
+      const result = await resolveProtocol(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -262,7 +262,7 @@ export const uikitResolver: FrameworkResolver = {
 export const vaporResolver: FrameworkResolver = {
   name: 'vapor',
 
-  detect(context: ResolutionContext): boolean {
+  async detect(context: ResolutionContext): Promise<boolean> {
     // Check for Package.swift with Vapor dependency
     const packageSwift = context.readFile('Package.swift');
     if (packageSwift && packageSwift.includes('vapor')) {
@@ -270,7 +270,7 @@ export const vaporResolver: FrameworkResolver = {
     }
 
     // Check for Vapor imports
-    const allFiles = context.getAllFiles();
+    const allFiles = await context.getAllFiles();
     for (const file of allFiles) {
       if (file.endsWith('.swift')) {
         const content = context.readFile(file);
@@ -283,10 +283,10 @@ export const vaporResolver: FrameworkResolver = {
     return false;
   },
 
-  resolve(ref: UnresolvedRef, context: ResolutionContext): ResolvedRef | null {
+  async resolve(ref: UnresolvedRef, context: ResolutionContext): Promise<ResolvedRef | null> {
     // Pattern 1: Controller references
     if (ref.referenceName.endsWith('Controller')) {
-      const result = resolveVaporController(ref.referenceName, context);
+      const result = await resolveVaporController(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -299,7 +299,7 @@ export const vaporResolver: FrameworkResolver = {
     // Pattern 2: Model references (Fluent)
     if (/^[A-Z][a-zA-Z]+$/.test(ref.referenceName)) {
-      const result = resolveFluentModel(ref.referenceName, context);
+      const result = await resolveFluentModel(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -312,7 +312,7 @@ export const vaporResolver: FrameworkResolver = {
     // Pattern 3: Middleware references
     if (ref.referenceName.endsWith('Middleware')) {
-      const result = resolveVaporMiddleware(ref.referenceName, context);
+      const result = await resolveVaporMiddleware(ref.referenceName, context);
       if (result) {
         return {
           original: ref,
@@ -384,13 +384,13 @@ export const vaporResolver: FrameworkResolver = {
 // Helper functions for SwiftUI
 
-function resolveView(name: string, context: ResolutionContext): string | null {
+async function resolveView(name: string, context: ResolutionContext): Promise<string | null> {
   const viewDirs = ['Views', 'View', 'Screens', 'Components', 'UI'];
-  const allFiles = context.getAllFiles();
+  const allFiles = await context.getAllFiles();
 
   for (const file of allFiles) {
     if (file.endsWith('.swift') && viewDirs.some((d) => file.includes(`/${d}/`))) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const viewNode = nodes.find(
         (n) => (n.kind === 'struct' || n.kind === 'component') && n.name === name
       );
@@ -403,7 +403,7 @@ function resolveView(name: string, context: ResolutionContext): string | null {
   // Search all Swift files
   for (const file of allFiles) {
     if (file.endsWith('.swift')) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const viewNode = nodes.find(
         (n) => (n.kind === 'struct' || n.kind === 'component') && n.name === name
       );
@@ -416,13 +416,13 @@ function resolveView(name: string, context: ResolutionContext): string | null {
   return null;
 }
 
-function resolveViewModel(name: string, context: ResolutionContext): string | null {
+async function resolveViewModel(name: string, context: ResolutionContext): Promise<string | null> {
   const vmDirs = ['ViewModels', 'ViewModel', 'Stores', 'Managers', 'Services'];
-  const allFiles = context.getAllFiles();
+  const allFiles = await context.getAllFiles();
 
   for (const file of allFiles) {
     if (file.endsWith('.swift') && vmDirs.some((d) => file.includes(`/${d}/`))) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const vmNode = nodes.find(
         (n) => n.kind === 'class' && n.name === name
       );
@@ -435,13 +435,13 @@ function resolveViewModel(name: string, context: ResolutionContext): string | nu
   return null;
 }
 
-function resolveModel(name: string, context: ResolutionContext): string | null {
+async function resolveModel(name: string, context: ResolutionContext): Promise<string | null> {
   const modelDirs = ['Models', 'Model', 'Entities', 'Domain'];
-  const allFiles = context.getAllFiles();
+  const allFiles = await context.getAllFiles();
 
   for (const file of allFiles) {
     if (file.endsWith('.swift') && modelDirs.some((d) => file.includes(`/${d}/`))) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const modelNode = nodes.find(
         (n) => (n.kind === 'struct' || n.kind === 'class') && n.name === name
       );
@@ -456,13 +456,13 @@ function resolveModel(name: string, context: ResolutionContext): string | null {
 // Helper functions for UIKit
 
-function resolveViewController(name: string, context: ResolutionContext): string | null {
+async function resolveViewController(name: string, context: ResolutionContext): Promise<string | null> {
   const vcDirs = ['ViewControllers', 'ViewController', 'Controllers', 'Screens'];
-  const allFiles = context.getAllFiles();
+  const allFiles = await context.getAllFiles();
 
   for (const file of allFiles) {
     if (file.endsWith('.swift') && (vcDirs.some((d) => file.includes(`/${d}/`)) || file.includes(name))) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const vcNode = nodes.find(
         (n) => n.kind === 'class' && n.name === name
       );
@@ -475,13 +475,13 @@ function resolveViewController(name: string, context: ResolutionContext): string
   return null;
 }
 
-function resolveUIView(name: string, context: ResolutionContext): string | null {
+async function resolveUIView(name: string, context: ResolutionContext): Promise<string | null> {
   const viewDirs = ['Views', 'View', 'UI', 'Components'];
-  const allFiles = context.getAllFiles();
+  const allFiles = await context.getAllFiles();
 
   for (const file of allFiles) {
     if (file.endsWith('.swift') && viewDirs.some((d) => file.includes(`/${d}/`))) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const viewNode = nodes.find(
         (n) => n.kind === 'class' && n.name === name
       );
@@ -494,13 +494,13 @@ function resolveUIView(name: string, context: ResolutionContext): string | null
   return null;
 }
 
-function resolveCell(name: string, context: ResolutionContext): string | null {
+async function resolveCell(name: string, context: ResolutionContext): Promise<string | null> {
   const cellDirs = ['Cells', 'Cell', 'Views', 'TableViewCells', 'CollectionViewCells'];
-  const allFiles = context.getAllFiles();
+  const allFiles = await context.getAllFiles();
 
   for (const file of allFiles) {
     if (file.endsWith('.swift') && cellDirs.some((d) => file.includes(`/${d}/`))) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const cellNode = nodes.find(
         (n) => n.kind === 'class' && n.name === name
       );
@@ -513,12 +513,12 @@ function resolveCell(name: string, context: ResolutionContext): string | null {
   return null;
 }
 
-function resolveProtocol(name: string, context: ResolutionContext): string | null {
-  const allFiles = context.getAllFiles();
+async function resolveProtocol(name: string, context: ResolutionContext): Promise<string | null> {
+  const allFiles = await context.getAllFiles();
 
   for (const file of allFiles) {
     if (file.endsWith('.swift')) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const protocolNode = nodes.find(
         (n) => n.kind === 'protocol' && n.name === name
       );
@@ -533,13 +533,13 @@ function resolveProtocol(name: string, context: ResolutionContext): string | nul
 // Helper functions for Vapor
 
-function resolveVaporController(name: string, context: ResolutionContext): string | null {
+async function resolveVaporController(name: string, context: ResolutionContext): Promise<string | null> {
   const controllerDirs = ['Controllers', 'Controller', 'Routes'];
-  const allFiles = context.getAllFiles();
+  const allFiles = await context.getAllFiles();
 
   for (const file of allFiles) {
     if (file.endsWith('.swift') && controllerDirs.some((d) => file.includes(`/${d}/`))) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const controllerNode = nodes.find(
         (n) => (n.kind === 'class' || n.kind === 'struct') && n.name === name
       );
@@ -552,13 +552,13 @@ function resolveVaporController(name: string, context: ResolutionContext): strin
   return null;
 }
 
-function resolveFluentModel(name: string, context: ResolutionContext): string | null {
+async function resolveFluentModel(name: string, context: ResolutionContext): Promise<string | null> {
   const modelDirs = ['Models', 'Model', 'Entities', 'Database'];
-  const allFiles = context.getAllFiles();
+  const allFiles = await context.getAllFiles();
 
   for (const file of allFiles) {
     if (file.endsWith('.swift') && modelDirs.some((d) => file.includes(`/${d}/`))) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const modelNode = nodes.find(
         (n) => n.kind === 'class' && n.name === name
       );
@@ -571,13 +571,13 @@ function resolveFluentModel(name: string, context: ResolutionContext): string |
   return null;
 }
 
-function resolveVaporMiddleware(name: string, context: ResolutionContext): string | null {
+async function resolveVaporMiddleware(name: string, context: ResolutionContext): Promise<string | null> {
   const middlewareDirs = ['Middleware', 'Middlewares'];
-  const allFiles = context.getAllFiles();
+  const allFiles = await context.getAllFiles();
 
   for (const file of allFiles) {
     if (file.endsWith('.swift') && middlewareDirs.some((d) => file.includes(`/${d}/`))) {
-      const nodes = context.getNodesInFile(file);
+      const nodes = await context.getNodesInFile(file);
       const mwNode = nodes.find(
         (n) => (n.kind === 'class' || n.kind === 'struct') && n.name === name
       );
diff --git a/src/resolution/import-resolver.ts b/src/resolution/import-resolver.ts
index a2c509dc..e7793216 100644
--- a/src/resolution/import-resolver.ts
+++ b/src/resolution/import-resolver.ts
@@ -438,12 +438,12 @@ export function clearImportMappingCache(): void {
 /**
  * Resolve a reference using import mappings
 */
-export function resolveViaImport(
+export async function resolveViaImport(
   ref: UnresolvedRef,
   context: ResolutionContext
-): ResolvedRef | null {
+): Promise<ResolvedRef | null> {
   // Use cached import mappings (avoids re-reading and re-parsing per ref)
-  const imports = context.getImportMappings(ref.filePath, ref.language);
+  const imports = await context.getImportMappings(ref.filePath, ref.language);
   if (imports.length === 0 && !context.readFile(ref.filePath)) {
     return null;
   }
@@ -461,7 +461,7 @@
     if (resolvedPath) {
       // Find the exported symbol in the resolved file
-      const nodesInFile = context.getNodesInFile(resolvedPath);
+      const nodesInFile = await context.getNodesInFile(resolvedPath);
       const exportedName = imp.isDefault ? 'default' : imp.exportedName;
 
       // Look for the symbol
diff --git a/src/resolution/index.ts b/src/resolution/index.ts
index f4c10aab..c1a1f31b 100644
--- a/src/resolution/index.ts
+++ b/src/resolution/index.ts
@@ -39,6 +39,7 @@ export class ReferenceResolver {
   private importMappingCache: Map<string, ImportMapping[]> = new Map();
   private knownFiles: Set<string> | null = null;
   private cachesWarmed = false;
+  private initialized = false;
 
   constructor(projectRoot: string, queries: QueryBuilder) {
     this.projectRoot = projectRoot;
@@ -49,8 +50,9 @@ export class ReferenceResolver {
   /**
    * Initialize the resolver (detect frameworks, etc.)
    */
-  initialize(): void {
-    this.frameworks = detectFrameworks(this.context);
+  async initialize(): Promise<void> {
+    this.frameworks = await detectFrameworks(this.context);
+    this.initialized = true;
     this.clearCaches();
   }
 
@@ -59,11 +61,16 @@ export class ReferenceResolver {
   /**
    * Node lookups are now handled by indexed SQLite queries instead of
    * loading all nodes into memory (which caused OOM on large codebases).
    */
-  warmCaches(): void {
+  async warmCaches(): Promise<void> {
+    // Ensure framework detection has run
+    if (!this.initialized) {
+      await this.initialize();
+    }
+
     if (this.cachesWarmed) return;
 
     // Only cache the set of known file paths (lightweight string set)
-    this.knownFiles = new Set(this.queries.getAllFilePaths());
+    this.knownFiles = new Set(await this.queries.getAllFilePaths());
 
     this.cachesWarmed = true;
   }
 
@@ -84,22 +91,22 @@
    */
   private createContext(): ResolutionContext {
     return {
-      getNodesInFile: (filePath: string) => {
+      getNodesInFile: async (filePath: string) => {
         if (!this.nodeCache.has(filePath)) {
-          this.nodeCache.set(filePath, this.queries.getNodesByFile(filePath));
+          this.nodeCache.set(filePath, await this.queries.getNodesByFile(filePath));
         }
         return this.nodeCache.get(filePath)!;
       },
 
-      getNodesByName: (name: string) => {
+      getNodesByName: async (name: string) => {
         return this.queries.getNodesByName(name);
       },
 
-      getNodesByQualifiedName: (qualifiedName: string) => {
+      getNodesByQualifiedName: async (qualifiedName: string) => {
         return this.queries.getNodesByQualifiedNameExact(qualifiedName);
       },
 
-      getNodesByKind: (kind: Node['kind']) => {
+      getNodesByKind: async (kind: Node['kind']) => {
         return this.queries.getNodesByKind(kind);
       },
@@ -140,15 +147,15 @@
       getProjectRoot: () => this.projectRoot,
 
-      getAllFiles: () => {
+      getAllFiles: async () => {
         return this.queries.getAllFilePaths();
       },
 
-      getNodesByLowerName: (lowerName: string) => {
+      getNodesByLowerName: async (lowerName: string) => {
         return this.queries.getNodesByLowerName(lowerName);
       },
 
-      getImportMappings: (filePath: string, language) => {
+      getImportMappings: async (filePath: string, language) => {
         const cacheKey = filePath;
         const cached = this.importMappingCache.get(cacheKey);
         if (cached) return cached;
@@ -169,34 +176,37 @@
   /**
    * Resolve all unresolved references
    */
-  resolveAll(
+  async resolveAll(
     unresolvedRefs: UnresolvedReference[],
     onProgress?: (current: number, total: number) => void
-  ): ResolutionResult {
+  ): Promise<ResolutionResult> {
     // Pre-load all nodes into memory for fast lookups
-    this.warmCaches();
+    await this.warmCaches();
 
     const resolved: ResolvedRef[] = [];
     const unresolved: UnresolvedRef[] = [];
     const byMethod: Record<string, number> = {};
 
     // Convert to our internal format, using denormalized fields when available
-    const refs: UnresolvedRef[] = unresolvedRefs.map((ref) => ({
-      fromNodeId: ref.fromNodeId,
-      referenceName: ref.referenceName,
-      referenceKind: ref.referenceKind,
-      line: ref.line,
-      column: ref.column,
-      filePath: ref.filePath || this.getFilePathFromNodeId(ref.fromNodeId),
-      language: ref.language || this.getLanguageFromNodeId(ref.fromNodeId),
-    }));
+    const refs: UnresolvedRef[] = [];
+    for (const ref of unresolvedRefs) {
+      refs.push({
+        fromNodeId: ref.fromNodeId,
+        referenceName: ref.referenceName,
+        referenceKind: ref.referenceKind,
+        line: ref.line,
+        column: ref.column,
+        filePath: ref.filePath || await this.getFilePathFromNodeId(ref.fromNodeId),
+        language: ref.language || await this.getLanguageFromNodeId(ref.fromNodeId),
+      });
+    }
 
     const total = refs.length;
     let lastReportedPercent = -1;
 
     for (let i = 0; i < refs.length; i++) {
       const ref = refs[i]!; // Array index is guaranteed to be in bounds
-      const result = this.resolveOne(ref);
+      const result = await this.resolveOne(ref);
 
       if (result) {
         resolved.push(result);
@@ -235,7 +245,7 @@
   /**
    * Resolve a single reference
    */
-  resolveOne(ref: UnresolvedRef): ResolvedRef | null {
+  async resolveOne(ref: UnresolvedRef): Promise<ResolvedRef | null> {
     // Skip built-in/external references
     if (this.isBuiltInOrExternal(ref)) {
       return null;
@@ -245,7 +255,7 @@
     // Strategy 1: Try framework-specific resolution
     for (const framework of this.frameworks) {
-      const result = framework.resolve(ref, this.context);
+      const result = await framework.resolve(ref, this.context);
       if (result) {
         if (result.confidence >= 0.9) return result; // High confidence, return immediately
         candidates.push(result);
@@ -253,14 +263,14 @@
     }
 
     // Strategy 2: Try import-based resolution
-    const importResult = resolveViaImport(ref, this.context);
+    const importResult = await resolveViaImport(ref, this.context);
     if (importResult) {
       if (importResult.confidence >= 0.9) return importResult;
       candidates.push(importResult);
     }
 
     // Strategy 3: Try name matching
-    const nameResult = matchReference(ref, this.context);
+    const nameResult = await matchReference(ref, this.context);
     if (nameResult) {
       candidates.push(nameResult);
     }
@@ -293,23 +303,23 @@
   /**
    * Resolve and persist edges to database
    */
-  resolveAndPersist(
+  async resolveAndPersist(
     unresolvedRefs: UnresolvedReference[],
     onProgress?: (current: number, total: number) => void
-  ): ResolutionResult {
-    const result = this.resolveAll(unresolvedRefs, onProgress);
+  ): Promise<ResolutionResult> {
+    const result = await this.resolveAll(unresolvedRefs, onProgress);
 
     // Create edges from resolved references
     const edges = this.createEdges(result.resolved);
 
     // Insert edges into database
     if (edges.length > 0) {
-      this.queries.insertEdges(edges);
+      await this.queries.insertEdges(edges);
     }
 
     // Clean up resolved refs from unresolved_refs table so metrics are accurate
     if (result.resolved.length > 0) {
-      this.queries.deleteSpecificResolvedReferences(
+      await this.queries.deleteSpecificResolvedReferences(
         result.resolved.map((r) => ({
           fromNodeId: r.original.fromNodeId,
           referenceName: r.original.referenceName,
@@ -326,13 +336,13 @@
    * Processes unresolved references in chunks, persisting edges and cleaning
    * up resolved refs after each batch to avoid accumulating large arrays.
    */
-  resolveAndPersistBatched(
+  async resolveAndPersistBatched(
    onProgress?: (current: number, total: number) => void,
     batchSize: number = 5000
-  ): ResolutionResult {
-    this.warmCaches();
+  ): Promise<ResolutionResult> {
+    await this.warmCaches();
 
-    const total = this.queries.getUnresolvedReferencesCount();
+    const total = await this.queries.getUnresolvedReferencesCount();
     let processed = 0;
     const aggregateStats = {
       total: 0,
@@ -344,20 +354,20 @@
     // Process in batches. We always read from offset 0 because resolved refs
     // are deleted after each batch, shifting the remaining rows forward.
     while (true) {
-      const batch = this.queries.getUnresolvedReferencesBatch(0, batchSize);
+      const batch = await this.queries.getUnresolvedReferencesBatch(0, batchSize);
       if (batch.length === 0) break;
 
-      const result = this.resolveAll(batch);
+      const result = await this.resolveAll(batch);
 
       // Persist edges immediately
       const edges = this.createEdges(result.resolved);
       if (edges.length > 0) {
-        this.queries.insertEdges(edges);
+        await this.queries.insertEdges(edges);
       }
 
       // Clean up resolved refs so they don't appear in the next batch
       if (result.resolved.length > 0) {
-        this.queries.deleteSpecificResolvedReferences(
+        await this.queries.deleteSpecificResolvedReferences(
           result.resolved.map((r) => ({
             fromNodeId: r.original.fromNodeId,
             referenceName: r.original.referenceName,
@@ -368,7 +378,7 @@
       // Delete unresolvable refs from this batch to avoid re-processing them
       if (result.unresolved.length > 0) {
-        this.queries.deleteSpecificResolvedReferences(
+        await this.queries.deleteSpecificResolvedReferences(
           result.unresolved.map((r) => ({
             fromNodeId: r.fromNodeId,
             referenceName: r.referenceName,
@@ -455,7 +465,7 @@
     const dotIdx = name.indexOf('.');
     if (dotIdx > 0) {
       const receiver = name.substring(0, dotIdx);
-      // self.method and cls.method are internal calls, not built-in — let them resolve
+      // self.method and cls.method are internal calls, not built-in -- let them resolve
       // But receiver types that are built-in types should be filtered
       const pythonBuiltInTypes = new Set([
         'list', 'dict', 'set', 'tuple', 'str', 'int', 'float', 'bool',
@@ -483,7 +493,7 @@
     // Pascal/Delphi built-ins and standard library units
     if (ref.language === 'pascal') {
-      // Standard RTL/VCL/FMX unit prefixes — these are external dependencies
+      // Standard RTL/VCL/FMX unit prefixes -- these are external dependencies
       const pascalUnitPrefixes = [
         'System.', 'Winapi.', 'Vcl.', 'Fmx.', 'Data.',
         'Datasnap.', 'Soap.', 'Xml.', 'Web.', 'REST.', 'FireDAC.', 'IBX.',
@@ -525,16 +535,16 @@
   /**
    * Get file path from node ID
    */
-  private getFilePathFromNodeId(nodeId: string): string {
-    const node = this.queries.getNodeById(nodeId);
+  private async getFilePathFromNodeId(nodeId: string): Promise<string> {
+    const node = await this.queries.getNodeById(nodeId);
     return node?.filePath || '';
   }
 
   /**
    * Get language from node ID
    */
-  private getLanguageFromNodeId(nodeId: string): UnresolvedRef['language'] {
-    const node = this.queries.getNodeById(nodeId);
+  private async getLanguageFromNodeId(nodeId: string): Promise<UnresolvedRef['language']> {
+    const node = await this.queries.getNodeById(nodeId);
     return node?.language || 'unknown';
   }
 }
@@ -542,8 +552,8 @@
 /**
  * Create a reference resolver instance
 */
-export function createResolver(projectRoot: string, queries: QueryBuilder): ReferenceResolver {
+export async function createResolver(projectRoot: string, queries: QueryBuilder): Promise<ReferenceResolver> {
   const resolver = new ReferenceResolver(projectRoot, queries);
-  resolver.initialize();
+  await resolver.initialize();
   return resolver;
 }
diff --git a/src/resolution/name-matcher.ts b/src/resolution/name-matcher.ts
index 7a508b92..9042b39b 100644
--- a/src/resolution/name-matcher.ts
+++ b/src/resolution/name-matcher.ts
@@ -10,11 +10,11 @@ import { UnresolvedRef, ResolvedRef, ResolutionContext } from './types';
 /**
  * Try to resolve a reference by exact name match
 */
-export function matchByExactName(
+export async function matchByExactName(
   ref: UnresolvedRef,
   context: ResolutionContext
-): ResolvedRef | null {
-  const candidates = context.getNodesByName(ref.referenceName);
+): Promise<ResolvedRef | null> {
+  const candidates = await context.getNodesByName(ref.referenceName);
 
   if (candidates.length === 0) {
     return null;
@@ -51,16 +51,16 @@
 /**
  * Try to resolve by qualified name
 */
-export function matchByQualifiedName(
+export async function matchByQualifiedName(
   ref: UnresolvedRef,
   context: ResolutionContext
-): ResolvedRef | null {
+): Promise<ResolvedRef | null> {
   // Check if the reference name looks qualified (contains :: or .)
   if (!ref.referenceName.includes('::') && !ref.referenceName.includes('.')) {
     return null;
   }
 
-  const candidates = context.getNodesByQualifiedName(ref.referenceName);
+  const candidates = await context.getNodesByQualifiedName(ref.referenceName);
 
   if (candidates.length === 1) {
     return {
@@ -75,7 +75,7 @@
   const parts = ref.referenceName.split(/[:.]/);
   const lastName = parts[parts.length - 1];
   if (lastName) {
-    const partialCandidates = context.getNodesByName(lastName);
+    const partialCandidates = await context.getNodesByName(lastName);
     for (const candidate of partialCandidates) {
       if (candidate.qualifiedName.endsWith(ref.referenceName)) {
         return {
@@ -94,10 +94,10 @@
 /**
  * Try to resolve by method name on a class/object
 */
-export function matchMethodCall(
+export async function matchMethodCall(
   ref: UnresolvedRef,
   context: ResolutionContext
-): ResolvedRef | null {
+): Promise<ResolvedRef | null> {
   // Parse method call patterns like "obj.method" or "Class::method"
   const dotMatch = ref.referenceName.match(/^(\w+)\.(\w+)$/);
   const colonMatch = ref.referenceName.match(/^(\w+)::(\w+)$/);
@@ -110,14 +110,14 @@
   const [, objectOrClass, methodName] = match;
 
   // Strategy 1: Direct class name match (existing logic)
-  const classCandidates = context.getNodesByName(objectOrClass!);
+  const classCandidates = await context.getNodesByName(objectOrClass!);
   for (const classNode of classCandidates) {
     if (classNode.kind === 'class' || classNode.kind === 'struct' || classNode.kind === 'interface') {
       // Skip cross-language class matches
       if (classNode.language !== ref.language) continue;
 
-      const nodesInFile = context.getNodesInFile(classNode.filePath);
+      const nodesInFile = await context.getNodesInFile(classNode.filePath);
       const methodNode = nodesInFile.find(
         (n) =>
           n.kind === 'method' &&
@@ -140,13 +140,13 @@
   // e.g., "permissionEngine" → look for classes containing "PermissionEngine"
   const capitalizedReceiver = objectOrClass!.charAt(0).toUpperCase() + objectOrClass!.slice(1);
   if (capitalizedReceiver !== objectOrClass) {
-    const fuzzyClassCandidates = context.getNodesByName(capitalizedReceiver);
+    const fuzzyClassCandidates = await context.getNodesByName(capitalizedReceiver);
     for (const classNode of fuzzyClassCandidates) {
       if (classNode.kind === 'class' || classNode.kind === 'struct' || classNode.kind === 'interface') {
         // Skip cross-language class matches
         if (classNode.language !== ref.language) continue;
 
-        const nodesInFile = context.getNodesInFile(classNode.filePath);
+        const nodesInFile = await context.getNodesInFile(classNode.filePath);
         const methodNode = nodesInFile.find(
           (n) =>
             n.kind === 'method' &&
@@ -170,7 +170,7 @@
   // name similarity with the containing class. Handles abbreviated variable
   // names like permissionEngine → PermissionRuleEngine.
   if (methodName) {
-    const methodCandidates = context.getNodesByName(methodName!);
+    const methodCandidates = await context.getNodesByName(methodName!);
     const methods = methodCandidates.filter(
       (n) => n.kind === 'method' && n.name === methodName
     );
@@ -320,14 +320,14 @@ function findBestMatch(
 /**
  * Fuzzy match - last resort with lower confidence
  */
-export function matchFuzzy(
+export async function matchFuzzy(
   ref: UnresolvedRef,
   context: ResolutionContext
-): ResolvedRef | null {
+): Promise<ResolvedRef | null> {
   const lowerName = ref.referenceName.toLowerCase();

   // Use pre-built lowercase index for O(1) lookup instead of scanning all nodes
-  const candidates = context.getNodesByLowerName(lowerName);
+  const candidates = await context.getNodesByLowerName(lowerName);

   // Filter to callable kinds only (function, method, class)
   const callableKinds = new Set(['function', 'method', 'class']);
@@ -353,27 +353,27 @@
 /**
  * Match all strategies in order of confidence
  */
-export function matchReference(
+export async function matchReference(
   ref: UnresolvedRef,
   context: ResolutionContext
-): ResolvedRef | null {
+): Promise<ResolvedRef | null> {
   // Try strategies in order of confidence
   let result: ResolvedRef | null;

   // 1. Qualified name match (highest confidence)
-  result = matchByQualifiedName(ref, context);
+  result = await matchByQualifiedName(ref, context);
   if (result) return result;

   // 2. Method call pattern
-  result = matchMethodCall(ref, context);
+  result = await matchMethodCall(ref, context);
   if (result) return result;

   // 3. Exact name match
-  result = matchByExactName(ref, context);
+  result = await matchByExactName(ref, context);
   if (result) return result;

   // 4. Fuzzy match (lowest confidence)
-  result = matchFuzzy(ref, context);
+  result = await matchFuzzy(ref, context);
   if (result) return result;

   return null;
diff --git a/src/resolution/types.ts b/src/resolution/types.ts
index c28c9584..d6d5d83e 100644
--- a/src/resolution/types.ts
+++ b/src/resolution/types.ts
@@ -64,13 +64,13 @@ export interface ResolutionResult {
  */
 export interface ResolutionContext {
   /** Get all nodes in a file */
-  getNodesInFile(filePath: string): Node[];
+  getNodesInFile(filePath: string): Promise<Node[]>;
   /** Get all nodes by name */
-  getNodesByName(name: string): Node[];
+  getNodesByName(name: string): Promise<Node[]>;
   /** Get all nodes by qualified name */
-  getNodesByQualifiedName(qualifiedName: string): Node[];
+  getNodesByQualifiedName(qualifiedName: string): Promise<Node[]>;
   /** Get all nodes of a kind */
-  getNodesByKind(kind: Node['kind']): Node[];
+  getNodesByKind(kind: Node['kind']): Promise<Node[]>;
   /** Check if a file exists */
   fileExists(filePath: string): boolean;
   /** Read file content */
@@ -78,11 +78,11 @@ export interface ResolutionContext {
   /** Get project root */
   getProjectRoot(): string;
   /** Get all files */
-  getAllFiles(): string[];
+  getAllFiles(): Promise<string[]>;
   /** Get nodes by lowercase name (O(1) lookup for fuzzy matching) */
-  getNodesByLowerName(lowerName: string): Node[];
+  getNodesByLowerName(lowerName: string): Promise<Node[]>;
   /** Get cached import mappings for a file */
-  getImportMappings(filePath: string, language: Language): ImportMapping[];
+  getImportMappings(filePath: string, language: Language): Promise<ImportMapping[]>;
 }

 /**
@@ -92,9 +92,9 @@ export interface FrameworkResolver {
   /** Framework name */
   name: string;
   /** Detect if project uses this framework */
-  detect(context: ResolutionContext): boolean;
+  detect(context: ResolutionContext): Promise<boolean>;
   /** Resolve a reference using framework-specific patterns */
-  resolve(ref: UnresolvedRef, context: ResolutionContext): ResolvedRef | null;
+  resolve(ref: UnresolvedRef, context: ResolutionContext): Promise<ResolvedRef | null>;
   /** Extract additional nodes specific to this framework */
   extractNodes?(filePath: string, content: string): Node[];
 }
diff --git a/src/types.ts b/src/types.ts
index 6834483d..48b17ab8 100644
--- a/src/types.ts
+++ b/src/types.ts
@@ -474,6 +474,42 @@ export interface CodeGraphConfig {
     /** Node kind to assign */
     kind: NodeKind;
   }[];
+
+  /** Database backend configuration (optional, default: sqlite) */
+  database?: {
+    /** Backend type: 'sqlite' (default) or 'postgres' */
+    backend: 'sqlite' | 'postgres';
+
+    /** PostgreSQL connection string. Can also use CODEGRAPH_PG_URL env var. */
+    connectionString?: string;
+
+    /** Connection pool size for PostgreSQL (default: 10) */
+    poolSize?: number;
+
+    /** Table name prefix for PostgreSQL (default: '') */
+    tablePrefix?: string;
+  };
+
+  /** Vector store backend configuration (optional, default: sqlite) */
+  vectorStore?: {
+    /** Backend type: 'sqlite' (default) or 'pgvector' */
+    backend: 'sqlite' | 'pgvector';
+
+    /** PostgreSQL connection string (pgvector only). Can also use CODEGRAPH_PG_URL env var. */
+    connectionString?: string;
+
+    /** Index type for pgvector: 'hnsw' (default), 'ivfflat', or 'none' */
+    indexType?: 'hnsw' | 'ivfflat' | 'none';
+
+    /** Distance metric for pgvector: 'cosine' (default), 'l2', or 'inner_product' */
+    distanceMetric?: 'cosine' | 'l2' | 'inner_product';
+
+    /** Connection pool size for pgvector (default: 5) */
+    poolSize?: number;
+
+    /** Table name prefix for pgvector (default: 'codegraph_') */
+    tablePrefix?: string;
+  };
 }

 /**
diff --git a/src/vectors/index.ts b/src/vectors/index.ts
index 06cbbe93..6d2b6079 100644
--- a/src/vectors/index.ts
+++ b/src/vectors/index.ts
@@ -2,6 +2,7 @@
  * Vectors Module
  *
  * Provides text embedding and vector similarity search for semantic code search.
+ * Supports SQLite (default) and PostgreSQL (pgvector) backends.
  */

 export {
@@ -14,10 +15,24 @@ export {
   BatchEmbeddingResult,
 } from './embedder';

+export {
+  VectorStore,
+  VectorSearchOptions,
+  VectorSearchResult,
+} from './store';
+
+export { SqliteVectorStore } from './sqlite-store';
+
+export {
+  VectorStoreConfig,
+  DEFAULT_VECTOR_STORE_CONFIG,
+  createVectorStore,
+} from './store-factory';
+
+// Backward-compatible re-exports
 export {
   VectorSearchManager,
   createVectorSearch,
-  VectorSearchOptions,
 } from './search';

 export {
diff --git a/src/vectors/manager.ts b/src/vectors/manager.ts
index 3f8ac3d7..3fe36338 100644
--- a/src/vectors/manager.ts
+++ b/src/vectors/manager.ts
@@ -4,10 +4,9 @@
  * High-level manager that coordinates embedding generation and vector search.
  */

-import { SqliteDatabase } from '../db/sqlite-adapter';
 import { Node, SearchResult, SearchOptions } from '../types';
-import { TextEmbedder, createEmbedder, EmbedderOptions, EMBEDDING_DIMENSION } from './embedder';
-import { VectorSearchManager, createVectorSearch } from './search';
+import { TextEmbedder, createEmbedder, EmbedderOptions } from './embedder';
+import { VectorStore } from './store';
 import { QueryBuilder } from '../db/queries';

 /**
@@ -56,24 +55,24 @@ const DEFAULT_NODE_KINDS: Node['kind'][] = [
  *
  * Provides high-level interface for semantic search:
  * - Generates embeddings for code nodes
- * - Stores embeddings in the database
+ * - Stores embeddings in the configured backend
  * - Performs semantic similarity search
  */
 export class VectorManager {
   private embedder: TextEmbedder;
-  private searchManager: VectorSearchManager;
+  private store: VectorStore;
   private queries: QueryBuilder;
   private nodeKinds: Node['kind'][];
   private batchSize: number;
   private initialized = false;

   constructor(
-    db: SqliteDatabase,
+    store: VectorStore,
     queries: QueryBuilder,
     options: VectorManagerOptions = {}
   ) {
     this.embedder = createEmbedder(options.embedder);
-    this.searchManager = createVectorSearch(db, EMBEDDING_DIMENSION);
+    this.store = store;
     this.queries = queries;
     this.nodeKinds = options.nodeKinds || DEFAULT_NODE_KINDS;
     this.batchSize = options.batchSize || 32;
@@ -82,7 +81,7 @@ export class VectorManager {
   /**
    * Initialize the vector manager
    *
-   * Loads the embedding model and initializes vector search.
+   * Loads the embedding model and initializes the vector store.
    */
   async initialize(): Promise<void> {
     if (this.initialized) {
@@ -92,8 +91,8 @@ export class VectorManager {
     // Initialize embedder (downloads model if needed)
     await this.embedder.initialize();

-    // Initialize vector search (loads sqlite-vss if available)
-    await this.searchManager.initialize();
+    // Initialize vector store
+    await this.store.initialize();

     this.initialized = true;
   }
@@ -119,12 +118,12 @@ export class VectorManager {
     // Get all nodes that should be embedded
     const nodesToEmbed: Node[] = [];
     for (const kind of this.nodeKinds) {
-      const nodes = this.queries.getNodesByKind(kind);
+      const nodes = await this.queries.getNodesByKind(kind);
       nodesToEmbed.push(...nodes);
     }

     // Filter out nodes that already have embeddings
-    const existingIds = new Set(this.searchManager.getIndexedNodeIds());
+    const existingIds = new Set(await this.store.getIndexedNodeIds());
     const newNodes = nodesToEmbed.filter((n) => !existingIds.has(n.id));

     if (newNodes.length === 0) {
@@ -153,7 +152,7 @@ export class VectorManager {
         entries.push({ nodeId: node.id, embedding });
       }
     }
-    this.searchManager.storeVectorBatch(entries, model);
+    await this.store.storeVectorBatch(entries, model);

     processed += batch.length;
@@ -182,7 +181,7 @@ export class VectorManager {
     const text = TextEmbedder.createNodeText(node);
     const result = await this.embedder.embed(text);

-    this.searchManager.storeVector(node.id, result.embedding, result.model);
+    await this.store.storeVector(node.id, result.embedding, result.model);
   }

   /**
@@ -203,7 +202,7 @@ export class VectorManager {
     const queryResult = await this.embedder.embedQuery(query);

     // Search for similar vectors
-    const vectorResults = this.searchManager.search(queryResult.embedding, {
+    const vectorResults = await this.store.search(queryResult.embedding, {
       limit: limit * 2, // Get more results to filter
       minScore: 0.3, // Minimum similarity threshold
     });
@@ -211,7 +210,7 @@ export class VectorManager {
     // Get nodes and filter by kind if specified
     const results: SearchResult[] = [];
     for (const vr of vectorResults) {
-      const node = this.queries.getNodeById(vr.nodeId);
+      const node = await this.queries.getNodeById(vr.nodeId);
       if (!node) {
         continue;
       }
@@ -249,17 +248,17 @@ export class VectorManager {
     const { limit = 10, kinds } = options;

     // Get the node's embedding
-    let embedding = this.searchManager.getVector(nodeId);
+    let embedding = await this.store.getVector(nodeId);

     // If no embedding exists, generate one
     if (!embedding) {
-      const node = this.queries.getNodeById(nodeId);
+      const node = await this.queries.getNodeById(nodeId);
       if (!node) {
         throw new Error(`Node not found: ${nodeId}`);
       }

       await this.embedNode(node);
-      embedding = this.searchManager.getVector(nodeId);
+      embedding = await this.store.getVector(nodeId);

       if (!embedding) {
         throw new Error(`Failed to generate embedding for node: ${nodeId}`);
@@ -267,7 +266,7 @@ export class VectorManager {
     }

     // Search for similar vectors (excluding the source node)
-    const vectorResults = this.searchManager.search(embedding, {
+    const vectorResults = await this.store.search(embedding, {
       limit: limit + 1, // Get one extra to exclude the source
       minScore: 0.3,
     });
@@ -280,7 +279,7 @@ export class VectorManager {
         continue;
       }

-      const node = this.queries.getNodeById(vr.nodeId);
+      const node = await this.queries.getNodeById(vr.nodeId);
       if (!node) {
         continue;
       }
@@ -308,22 +307,27 @@ export class VectorManager {
    *
    * @param nodeId - ID of the node
    */
-  deleteNodeEmbedding(nodeId: string): void {
-    this.searchManager.deleteVector(nodeId);
+  async deleteNodeEmbedding(nodeId: string): Promise<void> {
+    await this.store.deleteVector(nodeId);
   }

   /**
    * Get statistics about vector storage
    */
-  getStats(): {
+  async getStats(): Promise<{
     totalVectors: number;
     vssEnabled: boolean;
+    annEnabled: boolean;
+    backend: 'sqlite' | 'pgvector';
     modelId: string;
     dimension: number;
-  } {
+  }> {
+    const annEnabled = this.store.isAnnEnabled();
     return {
-      totalVectors: this.searchManager.getVectorCount(),
-      vssEnabled: this.searchManager.isVssEnabled(),
+      totalVectors: await this.store.getVectorCount(),
+      vssEnabled: annEnabled, // backward compat
+      annEnabled,
+      backend: this.store.backendType,
       modelId: this.embedder.getModelId(),
       dimension: this.embedder.getDimension(),
     };
@@ -332,22 +336,23 @@ export class VectorManager {
   /**
    * Clear all vectors
    */
-  clear(): void {
-    this.searchManager.clear();
+  async clear(): Promise<void> {
+    await this.store.clear();
   }

   /**
-   * Rebuild the VSS index
+   * Rebuild the vector index
    */
-  rebuildIndex(): void {
-    this.searchManager.rebuildVssIndex();
+  async rebuildIndex(): Promise<void> {
+    await this.store.rebuildIndex();
   }

   /**
    * Release resources
    */
-  dispose(): void {
+  async dispose(): Promise<void> {
     this.embedder.dispose();
+    await this.store.dispose();
   }
 }

@@ -355,9 +360,9 @@
 /**
  * Create a vector manager
  */
 export function createVectorManager(
-  db: SqliteDatabase,
+  store: VectorStore,
   queries: QueryBuilder,
   options?: VectorManagerOptions
 ): VectorManager {
-  return new VectorManager(db, queries, options);
+  return new VectorManager(store, queries, options);
 }
diff --git a/src/vectors/pg-store.ts b/src/vectors/pg-store.ts
new file mode 100644
index 00000000..738a06aa
--- /dev/null
+++ b/src/vectors/pg-store.ts
@@ -0,0 +1,388 @@
+/**
+ * PostgreSQL Vector Store (pgvector)
+ *
+ * Vector storage backend using PostgreSQL with the pgvector extension.
+ * Provides HNSW/IVFFlat indexes for fast approximate nearest neighbor search.
+ */
+
+import { VectorStore, VectorSearchOptions, VectorSearchResult } from './store';
+import { EMBEDDING_DIMENSION } from './embedder';
+
+/** Only allow safe SQL identifiers: letters, digits, underscores, starting with letter/underscore */
+const SAFE_IDENTIFIER = /^[a-zA-Z_][a-zA-Z0-9_]{0,50}$/;
+
+/**
+ * Options for the PostgreSQL vector store
+ */
+export interface PgVectorStoreOptions {
+  /** PostgreSQL connection string */
+  connectionString: string;
+
+  /** Vector dimension (default: 768 for nomic-embed-text-v1.5) */
+  dimension?: number;
+
+  /** Index type: 'hnsw' (default), 'ivfflat', or 'none' */
+  indexType?: 'hnsw' | 'ivfflat' | 'none';
+
+  /** Distance metric: 'cosine' (default), 'l2', or 'inner_product' */
+  distanceMetric?: 'cosine' | 'l2' | 'inner_product';
+
+  /** Connection pool size (default: 5) */
+  poolSize?: number;
+
+  /** Table name prefix (default: 'codegraph_') */
+  tablePrefix?: string;
+}
+
+/** pgvector operator for each distance metric */
+const DISTANCE_OPERATORS: Record<string, string> = {
+  cosine: '<=>',
+  l2: '<->',
+  inner_product: '<#>',
+};
+
+/** pgvector index ops class for each distance metric */
+const INDEX_OPS: Record<string, string> = {
+  cosine: 'vector_cosine_ops',
+  l2: 'vector_l2_ops',
+  inner_product: 'vector_ip_ops',
+};
+
+/**
+ * Convert a Float32Array to pgvector string format: '[0.1,0.2,...]'
+ */
+function toVectorString(embedding: Float32Array): string {
+  return '[' + Array.from(embedding).join(',') + ']';
+}
+
+/**
+ * Convert a pgvector string back to Float32Array
+ */
+function fromVectorString(str: string): Float32Array {
+  const values = str.slice(1, -1).split(',').map(Number);
+  return new Float32Array(values);
+}
+
+/**
+ * PostgreSQL vector store using pgvector extension
+ *
+ * Requires PostgreSQL with the pgvector extension installed.
+ * Provides production-grade HNSW indexes for fast ANN search.
+ */
+export class PgVectorStore implements VectorStore {
+  readonly backendType = 'pgvector' as const;
+  private pool: any; // pg.Pool
+  private options: Required<PgVectorStoreOptions>;
+  private tableName: string;
+  private _initialized = false;
+
+  constructor(options: PgVectorStoreOptions) {
+    const prefix = options.tablePrefix ?? 'codegraph_';
+    if (!SAFE_IDENTIFIER.test(prefix)) {
+      throw new Error(
+        `tablePrefix must be a safe SQL identifier (letters, digits, underscores, max 50 chars). Got: "${prefix}"`
+      );
+    }
+
+    this.options = {
+      connectionString: options.connectionString,
+      dimension: options.dimension ?? EMBEDDING_DIMENSION,
+      indexType: options.indexType ?? 'hnsw',
+      distanceMetric: options.distanceMetric ?? 'cosine',
+      poolSize: options.poolSize ?? 5,
+      tablePrefix: prefix,
+    };
+    this.tableName = `${prefix}vectors`;
+  }
+
+  async initialize(): Promise<void> {
+    if (this._initialized) {
+      return;
+    }
+
+    // Dynamically import pg
+    let pg: any;
+    try {
+      pg = await import('pg');
+    } catch {
+      throw new Error(
+        'The "pg" package is required for the pgvector backend. Install it with: npm install pg'
+      );
+    }
+
+    const Pool = pg.default?.Pool ?? pg.Pool;
+    const pool = new Pool({
+      connectionString: this.options.connectionString,
+      max: this.options.poolSize,
+    });
+
+    try {
+      // Test connection
+      let client: any;
+      try {
+        client = await pool.connect();
+      } catch (error: any) {
+        throw new Error(
+          `Failed to connect to PostgreSQL: ${error.message}. ` +
+            'Verify your connection string and ensure the database is running.'
+        );
+      }
+
+      try {
+        // Enable pgvector extension
+        try {
+          await client.query('CREATE EXTENSION IF NOT EXISTS vector');
+        } catch (error: any) {
+          throw new Error(
+            `pgvector extension not available: ${error.message}. ` +
+              'Install pgvector on your PostgreSQL server: https://github.com/pgvector/pgvector'
+          );
+        }
+
+        // Create vectors table
+        await client.query(`
+          CREATE TABLE IF NOT EXISTS ${this.tableName} (
+            node_id TEXT PRIMARY KEY,
+            embedding vector(${this.options.dimension}) NOT NULL,
+            model TEXT NOT NULL,
+            created_at BIGINT NOT NULL
+          )
+        `);
+
+        // Create index based on configured type
+        await this.createIndex(client);
+      } finally {
+        client.release();
+      }
+    } catch (error) {
+      // Clean up pool on any setup failure to prevent leaks
+      await pool.end().catch(() => {});
+      throw error;
+    }
+
+    this.pool = pool;
+    this._initialized = true;
+  }
+
+  private async createIndex(client: any): Promise<void> {
+    const { indexType, distanceMetric } = this.options;
+    const ops = INDEX_OPS[distanceMetric];
+    const indexName = `${this.tableName}_embedding_idx`;
+
+    if (indexType === 'none') {
+      return;
+    }
+
+    // Check if index already exists
+    const indexCheck = await client.query(
+      `SELECT 1 FROM pg_indexes WHERE indexname = $1`,
+      [indexName]
+    );
+    if (indexCheck.rows.length > 0) {
+      return;
+    }
+
+    if (indexType === 'hnsw') {
+      await client.query(`
+        CREATE INDEX ${indexName}
+        ON ${this.tableName}
+        USING hnsw (embedding ${ops})
+        WITH (m = 16, ef_construction = 64)
+      `);
+    } else if (indexType === 'ivfflat') {
+      const countResult = await client.query(`SELECT COUNT(*) as count FROM ${this.tableName}`);
+      const count = parseInt(countResult.rows[0].count, 10);
+      const lists = Math.max(1, Math.floor(Math.sqrt(count)));
+
+      await client.query(`
+        CREATE INDEX ${indexName}
+        ON ${this.tableName}
+        USING ivfflat (embedding ${ops})
+        WITH (lists = ${lists})
+      `);
+    }
+  }
+
+  async storeVector(nodeId: string, embedding: Float32Array, model: string): Promise<void> {
+    const vectorStr = toVectorString(embedding);
+    const now = Date.now();
+
+    await this.pool.query(
+      `INSERT INTO ${this.tableName} (node_id, embedding, model, created_at)
+       VALUES ($1, $2, $3, $4)
+       ON CONFLICT (node_id) DO UPDATE
+       SET embedding = $2, model = $3, created_at = $4`,
+      [nodeId, vectorStr, model, now]
+    );
+  }
+
+  async storeVectorBatch(
+    entries: Array<{ nodeId: string; embedding: Float32Array }>,
+    model: string
+  ): Promise<void> {
+    if (entries.length === 0) return;
+
+    const now = Date.now();
+    const client = await this.pool.connect();
+
+    try {
+      await client.query('BEGIN');
+
+      // Batch in chunks of 100 to stay well under the parameter limit
+      const chunkSize = 100;
+      for (let i = 0; i < entries.length; i += chunkSize) {
+        const chunk = entries.slice(i, i + chunkSize);
+        const values: any[] = [];
+        const placeholders: string[] = [];
+        let paramIdx = 1;
+
+        for (const entry of chunk) {
+          placeholders.push(`($${paramIdx}, $${paramIdx + 1}, $${paramIdx + 2}, $${paramIdx + 3})`);
+          values.push(entry.nodeId, toVectorString(entry.embedding), model, now);
+          paramIdx += 4;
+        }
+
+        await client.query(
+          `INSERT INTO ${this.tableName} (node_id, embedding, model, created_at)
+           VALUES ${placeholders.join(', ')}
+           ON CONFLICT (node_id) DO UPDATE SET
+             embedding = EXCLUDED.embedding,
+             model = EXCLUDED.model,
+             created_at = EXCLUDED.created_at`,
+          values
+        );
+      }

+      await client.query('COMMIT');
+    } catch (error) {
+      await client.query('ROLLBACK');
+      throw error;
+    } finally {
+      client.release();
+    }
+  }
+
+  async getVector(nodeId: string): Promise<Float32Array | null> {
+    const result = await this.pool.query(
+      `SELECT embedding::text FROM ${this.tableName} WHERE node_id = $1`,
+      [nodeId]
+    );
+
+    if (result.rows.length === 0) {
+      return null;
+    }
+
+    return fromVectorString(result.rows[0].embedding);
+  }
+
+  async hasVector(nodeId: string): Promise<boolean> {
+    const result = await this.pool.query(
+      `SELECT 1 FROM ${this.tableName} WHERE node_id = $1 LIMIT 1`,
+      [nodeId]
+    );
+    return result.rows.length > 0;
+  }
+
+  async getIndexedNodeIds(): Promise<string[]> {
+    const result = await this.pool.query(
+      `SELECT node_id FROM ${this.tableName}`
+    );
+    return result.rows.map((r: any) => r.node_id);
+  }
+
+  async search(
+    queryEmbedding: Float32Array,
+    options: VectorSearchOptions = {}
+  ): Promise<VectorSearchResult[]> {
+    const { limit = 10, minScore = 0 } = options;
+    const vectorStr = toVectorString(queryEmbedding);
+    const operator = DISTANCE_OPERATORS[this.options.distanceMetric];
+
+    // Compute the score expression once per row via a subquery.
+    // All distance operators return values where lower = more similar,
+    // so ORDER BY distance ASC gives the best matches first.
+    let scoreExpr: string;
+    if (this.options.distanceMetric === 'cosine') {
+      // <=> returns cosine distance in [0,2]; similarity = 1 - distance
+      scoreExpr = `1 - (embedding ${operator} $1::vector)`;
+    } else if (this.options.distanceMetric === 'inner_product') {
+      // <#> returns negative inner product; negate for similarity score
+      scoreExpr = `-(embedding ${operator} $1::vector)`;
+    } else {
+      // <-> returns L2 distance; convert to similarity
+      scoreExpr = `1.0 / (1.0 + (embedding ${operator} $1::vector))`;
+    }
+
+    const result = await this.pool.query(
+      `SELECT node_id, score FROM (
+         SELECT node_id, ${scoreExpr} AS score
+         FROM ${this.tableName}
+         ORDER BY embedding ${operator} $1::vector
+         LIMIT $3
+       ) sub
+       WHERE score >= $2
+       ORDER BY score DESC`,
+      [vectorStr, minScore, limit]
+    );
+
+    return result.rows.map((row: any) => ({
+      nodeId: row.node_id,
+      score: parseFloat(row.score),
+    }));
+  }
+
+  async deleteVector(nodeId: string): Promise<void> {
+    await this.pool.query(
+      `DELETE FROM ${this.tableName} WHERE node_id = $1`,
+      [nodeId]
+    );
+  }
+
+  async clear(): Promise<void> {
+    await this.pool.query(`DELETE FROM ${this.tableName}`);
+  }
+
+  async getVectorCount(): Promise<number> {
+    const result = await this.pool.query(
+      `SELECT COUNT(*) as count FROM ${this.tableName}`
+    );
+    return parseInt(result.rows[0].count, 10);
+  }
+
+  isAnnEnabled(): boolean {
+    return this.options.indexType !== 'none';
+  }
+
+  async rebuildIndex(): Promise<void> {
+    const indexName = `${this.tableName}_embedding_idx`;
+    const ops = INDEX_OPS[this.options.distanceMetric];
+
+    await this.pool.query(`DROP INDEX IF EXISTS ${indexName}`);
+
+    if (this.options.indexType === 'hnsw') {
+      await this.pool.query(`
+        CREATE INDEX ${indexName}
+        ON ${this.tableName}
+        USING hnsw (embedding ${ops})
+        WITH (m = 16, ef_construction = 64)
+      `);
+    } else if (this.options.indexType === 'ivfflat') {
+      const countResult = await this.pool.query(`SELECT COUNT(*) as count FROM ${this.tableName}`);
+      const count = parseInt(countResult.rows[0].count, 10);
+      const lists = Math.max(1, Math.floor(Math.sqrt(count)));
+
+      await this.pool.query(`
+        CREATE INDEX ${indexName}
+        ON ${this.tableName}
+        USING ivfflat (embedding ${ops})
+        WITH (lists = ${lists})
+      `);
+    }
+  }
+
+  async dispose(): Promise<void> {
+    if (this.pool) {
+      await this.pool.end();
+      this.pool = null;
+    }
+  }
+}
diff --git a/src/vectors/search.ts b/src/vectors/search.ts
index bdde94d3..ca48e826 100644
--- a/src/vectors/search.ts
+++ b/src/vectors/search.ts
@@ -1,472 +1,32 @@
 /**
  * Vector Search
  *
- * Provides vector similarity search using sqlite-vss extension.
- * Falls back to brute-force cosine similarity if sqlite-vss is not available.
+ * Re-exports from sqlite-store for backward compatibility.
+ * New code should import from './store' and './sqlite-store' directly.
+ *
+ * @deprecated Use SqliteVectorStore from './sqlite-store' instead of VectorSearchManager
  */

 import { SqliteDatabase } from '../db/sqlite-adapter';
-import { Node } from '../types';
-import { TextEmbedder, EMBEDDING_DIMENSION } from './embedder';
-
-/**
- * Options for vector search
- */
-export interface VectorSearchOptions {
-  /** Maximum number of results to return */
-  limit?: number;
+import { SqliteVectorStore } from './sqlite-store';

-  /** Minimum similarity score (0-1) */
-  minScore?: number;
-
-  /** Node kinds to filter results */
-  nodeKinds?: Node['kind'][];
-}
+// Re-export types from store.ts
+export { VectorSearchOptions } from './store';

 /**
- * Vector Search Manager
- *
- * Handles vector storage and similarity search for semantic code search.
+ * @deprecated Use SqliteVectorStore instead
  */
-export class VectorSearchManager {
-  private db: SqliteDatabase;
-  private vssEnabled = false;
-  private embeddingDimension: number;
-
-  constructor(db: SqliteDatabase, dimension: number = EMBEDDING_DIMENSION) {
-    this.db = db;
-    this.embeddingDimension = dimension;
-  }
-
-  /**
-   * Initialize vector search
-   *
-   * Attempts to load sqlite-vss extension. Falls back to brute-force
-   * search if the extension is not available.
-   */
-  async initialize(): Promise<void> {
-    try {
-      // Try to load sqlite-vss extension
-      await this.loadVssExtension();
-      this.vssEnabled = true;
-      console.log('sqlite-vss extension loaded successfully');
-
-      // Create the VSS virtual table
-      this.createVssTable();
-    } catch (error) {
-      // Fall back to brute-force search
-      console.warn(
-        'sqlite-vss extension not available, falling back to brute-force search:',
-        error instanceof Error ? error.message : String(error)
-      );
-      this.vssEnabled = false;
-    }
-
-    // Ensure the vectors table exists (for both VSS and fallback modes)
-    this.ensureVectorsTable();
-  }
-
-  /**
-   * Load the sqlite-vss extension
-   */
-  private async loadVssExtension(): Promise<void> {
-    try {
-      // The sqlite-vss npm package provides functions to load extensions
-      const vss = await import('sqlite-vss');
-
-      // Use the load function which loads both vector0 and vss0
-      // VSS extension expects the raw better-sqlite3 Database instance
-      if (typeof vss.load === 'function') {
-        vss.load(this.db as any);
-      } else if (typeof vss.default?.load === 'function') {
-        vss.default.load(this.db as any);
-      } else {
-        throw new Error('sqlite-vss load function not found');
-      }
-    } catch (error) {
-      throw new Error(`Failed to load sqlite-vss: ${error instanceof Error ? error.message : String(error)}`);
-    }
-  }
-
-  /**
-   * Create the VSS virtual table for vector search
-   */
-  private createVssTable(): void {
-    // Check if the table already exists
-    const tableExists = this.db
-      .prepare("SELECT name FROM sqlite_master WHERE type='table' AND name='vss_vectors'")
-      .get();
-
-    if (!tableExists) {
-      // Create VSS virtual table
-      // vss0 is the vector search extension
-      this.db.exec(`
-        CREATE VIRTUAL TABLE IF NOT EXISTS vss_vectors USING vss0(
-          embedding(${this.embeddingDimension})
-        );
-      `);
-
-      // Create mapping table to link VSS rowids to node IDs
-      this.db.exec(`
-        CREATE TABLE IF NOT EXISTS vss_map (
-          rowid INTEGER PRIMARY KEY,
-          node_id TEXT NOT NULL UNIQUE
-        );
-      `);
-
-      // Create index on node_id
-      this.db.exec(`
-        CREATE INDEX IF NOT EXISTS idx_vss_map_node ON vss_map(node_id);
-      `);
-    }
-  }
-
-  /**
-   * Ensure the basic vectors table exists (for fallback mode)
-   */
-  private ensureVectorsTable(): void {
-    this.db.exec(`
-      CREATE TABLE IF NOT EXISTS vectors (
-        node_id TEXT PRIMARY KEY,
-        embedding BLOB NOT NULL,
-        model TEXT NOT NULL,
-        created_at INTEGER NOT NULL
-      );
-    `);
-  }
-
-  /**
-   * Check if VSS extension is enabled
-   */
-  isVssEnabled(): boolean {
-    return this.vssEnabled;
-  }
-
-  /**
-   * Store a vector embedding for a node
-   *
-   * @param nodeId - ID of the node
-   * @param embedding - Vector embedding
-   * @param model - Model used to generate embedding
-   */
-  storeVector(nodeId: string, embedding: Float32Array, model: string): void {
-    const now = Date.now();
-
-    // Store in the vectors table (always, for persistence)
-    const blob = Buffer.from(embedding.buffer);
-    this.db
-      .prepare(
-        `
-        INSERT OR REPLACE INTO vectors (node_id, embedding, model, created_at)
-        VALUES (?, ?, ?, ?)
-      `
-      )
-      .run(nodeId, blob, model, now);
-
-    // Also store in VSS table if enabled
-    if (this.vssEnabled) {
-      this.storeInVss(nodeId, embedding);
-    }
-  }
-
-  /**
-   * Store vector in VSS virtual table
-   */
-  private storeInVss(nodeId: string, embedding: Float32Array): void {
-    try {
-      // Check if already exists
-      const existing = this.db
-        .prepare('SELECT rowid FROM vss_map WHERE node_id = ?')
-        .get(nodeId) as { rowid: number } | undefined;
-
-      if (existing) {
-        // Update existing vector
-        const vectorJson = JSON.stringify(Array.from(embedding));
-        this.db
-          .prepare('UPDATE vss_vectors SET embedding = ? WHERE rowid = ?')
-          .run(vectorJson, existing.rowid);
-      } else {
-        // Insert new vector - get max rowid and increment
-        const maxRow = this.db
-          .prepare('SELECT MAX(rowid) as max FROM vss_map')
-          .get() as { max: number | null } | undefined;
-        const newRowid = (maxRow?.max ?? 0) + 1;
-
-        const vectorJson = JSON.stringify(Array.from(embedding));
-        this.db
-          .prepare('INSERT INTO vss_vectors (rowid, embedding) VALUES (?, ?)')
-          .run(newRowid, vectorJson);
-
-        // Map the rowid to node_id
-        this.db
-          .prepare('INSERT INTO vss_map (rowid, node_id) VALUES (?, ?)')
-          .run(newRowid, nodeId);
-      }
-    } catch (error) {
-      // VSS operations can fail for various reasons (dimension mismatch, etc.)
-      // Fall back to brute-force search silently
-      console.warn(
-        'VSS storage failed, using brute-force search:',
-        error instanceof Error ? error.message : String(error)
-      );
-    }
-  }
-
-  /**
-   * Store multiple vectors in a batch
-   *
-   * @param entries - Array of node IDs and embeddings
-   * @param model - Model used to generate embeddings
-   */
-  storeVectorBatch(
-    entries: Array<{ nodeId: string; embedding: Float32Array }>,
-    model: string
-  ): void {
-    const now = Date.now();
-
-    // Use a transaction for better performance
-    this.db.transaction(() => {
-      for (const entry of entries) {
-        const blob = Buffer.from(entry.embedding.buffer);
-        this.db
-          .prepare(
-            `
-            INSERT OR REPLACE INTO vectors (node_id, embedding, model, created_at)
-            VALUES (?, ?, ?, ?)
-          `
-          )
-          .run(entry.nodeId, blob, model, now);
-
-        if (this.vssEnabled) {
-          this.storeInVss(entry.nodeId, entry.embedding);
-        }
-      }
-    })();
-  }
-
-  /**
-   * Get vector for a node
-   *
-   * @param nodeId - ID of the node
-   * @returns Embedding or null if not found
-   */
-  getVector(nodeId: string): Float32Array | null {
-    const row = this.db
-      .prepare('SELECT embedding FROM vectors WHERE node_id = ?')
-      .get(nodeId) as { embedding: Buffer } | undefined;
-
-    if (!row) {
-      return null;
-    }
-
-    return new Float32Array(row.embedding.buffer.slice(
-      row.embedding.byteOffset,
-      row.embedding.byteOffset + row.embedding.byteLength
-    ));
-  }
-
-  /**
-   * Delete vector for a node
-   *
-   * @param nodeId - ID of the node
-   */
-  deleteVector(nodeId: string): void {
-    this.db.prepare('DELETE FROM vectors WHERE node_id = ?').run(nodeId);
-
-    if (this.vssEnabled) {
-      // Get the rowid before deleting
-      const mapping = this.db
-        .prepare('SELECT rowid FROM vss_map WHERE node_id = ?')
-        .get(nodeId) as { rowid: number } | undefined;
-
-      if (mapping) {
-        this.db.prepare('DELETE FROM vss_vectors WHERE rowid = ?').run(mapping.rowid);
-        this.db.prepare('DELETE FROM vss_map WHERE node_id = ?').run(nodeId);
-      }
-    }
-  }
-
-  /**
-   * Search for similar vectors
-   *
-   * @param queryEmbedding - Query vector to search for
-   * @param options - Search options
-   * @returns Array of node IDs with similarity scores
-   */
-  search(
-    queryEmbedding: Float32Array,
-    options: VectorSearchOptions = {}
-  ): Array<{ nodeId: string; score: number }> {
-    const { limit = 10, minScore = 0 } = options;
-
-    if (this.vssEnabled) {
-      return this.searchWithVss(queryEmbedding, limit, minScore);
-    } else {
-      return this.searchBruteForce(queryEmbedding, limit, minScore);
-    }
-  }
-
-  /**
-   * Search using sqlite-vss KNN search
-   */
-  private searchWithVss(
-    queryEmbedding: Float32Array,
-    limit: number,
-    minScore: number
-  ): Array<{ nodeId: string; score: number }> {
-    try {
-      const vectorJson = JSON.stringify(Array.from(queryEmbedding));
-      // Sanitize limit to prevent SQL injection (ensure it's a positive integer)
-      const safeLimit = Math.max(1, Math.floor(limit));
-
-      // Use VSS KNN search
-      // The distance is L2 (euclidean), so we need to convert it to a similarity score
-      // Note: sqlite-vss requires LIMIT to be a literal, not a parameter
-      const rows = this.db
-        .prepare(
-          `
-          SELECT m.node_id, v.distance
-          FROM (
-            SELECT rowid, distance
-            FROM vss_vectors
-            WHERE vss_search(embedding, ?)
-            LIMIT ${safeLimit}
-          ) v
-          JOIN vss_map m ON m.rowid = v.rowid
-        `
-        )
-        .all(vectorJson) as Array<{ node_id: string; distance: number }>;
-
-      // Convert L2 distance to similarity score (1 / (1 + distance))
-      return rows
-        .map((row) => ({
-          nodeId: row.node_id,
-          score: 1 / (1 + row.distance),
-        }))
-        .filter((r) => r.score >= minScore);
-    } catch (error) {
-      // VSS search failed, fall back to brute force
-      console.warn(
-        'VSS search failed, using brute-force:',
-        error instanceof Error ? error.message : String(error)
-      );
-      return this.searchBruteForce(queryEmbedding, limit, minScore);
-    }
-  }
-
-  /**
-   * Brute-force search using cosine similarity
-   */
-  private searchBruteForce(
-    queryEmbedding: Float32Array,
-    limit: number,
-    minScore: number
-  ): Array<{ nodeId: string; score: number }> {
-    // Get all vectors
-    const rows = this.db
-      .prepare('SELECT node_id, embedding FROM vectors')
-      .all() as Array<{ node_id: string; embedding: Buffer }>;
-
-    // Calculate cosine similarity for each
-    const results: Array<{ nodeId: string; score: number }> = [];
-
-    for (const row of rows) {
-      const embedding = new Float32Array(row.embedding.buffer.slice(
-        row.embedding.byteOffset,
-        row.embedding.byteOffset + row.embedding.byteLength
-      ));
-
-      const score = TextEmbedder.cosineSimilarity(queryEmbedding, embedding);
-
-      if (score >= minScore) {
-        results.push({ nodeId: row.node_id, score });
-      }
-    }
-
-    // Sort by score descending and limit
-    results.sort((a, b) => b.score - a.score);
-    return results.slice(0, limit);
-  }
-
-  /**
-   * Get count of stored vectors
-   */
-  getVectorCount(): number {
-    const result = this.db
-      .prepare('SELECT COUNT(*) as count FROM vectors')
-      .get() as { count: number };
-    return result.count;
-  }
-
-  /**
-   * Check if a node has a vector
-   */
-  hasVector(nodeId: string): boolean {
-    const result = this.db
-      .prepare('SELECT 1 FROM vectors WHERE node_id = ? LIMIT 1')
-      .get(nodeId);
-    return !!result;
-  }
-
-  /**
-   * Get all node IDs that have vectors
-   */
-  getIndexedNodeIds(): string[] {
-    const rows = this.db
-      .prepare('SELECT node_id FROM vectors')
-      .all() as Array<{ node_id: string }>;
-    return rows.map((r) => r.node_id);
-  }
-
-  /**
-   * Clear all vectors
-   */
-  clear(): void {
-    this.db.prepare('DELETE FROM vectors').run();
-
-    if (this.vssEnabled) {
-      this.db.prepare('DELETE FROM vss_vectors').run();
-      this.db.prepare('DELETE FROM vss_map').run();
-    }
-  }
-
-  /**
-   * Rebuild VSS index from vectors table
-   *
-   * Useful after bulk operations or if the VSS index gets out of sync.
-   */
-  rebuildVssIndex(): void {
-    if (!this.vssEnabled) {
-      return;
-    }
-
-    // Clear VSS tables
-    this.db.prepare('DELETE FROM vss_vectors').run();
-    this.db.prepare('DELETE FROM vss_map').run();
-
-    // Reload from vectors table
-    const rows = this.db
-      .prepare('SELECT node_id, embedding FROM vectors')
-      .all() as Array<{ node_id: string; embedding: Buffer }>;
-
-    this.db.transaction(() => {
-      for (const row of rows) {
-        const embedding = new Float32Array(row.embedding.buffer.slice(
-          row.embedding.byteOffset,
-          row.embedding.byteOffset + row.embedding.byteLength
-        ));
-        this.storeInVss(row.node_id, embedding);
-      }
-    })();
-  }
-}
+export const VectorSearchManager = SqliteVectorStore;
+export type VectorSearchManager = SqliteVectorStore;

 /**
  * Create a vector search manager
+ *
+ * @deprecated Use new SqliteVectorStore(db, dimension) instead
  */
 export function createVectorSearch(
   db: SqliteDatabase,
   dimension?: number
-): VectorSearchManager {
-  return new VectorSearchManager(db, dimension);
+): SqliteVectorStore {
+  return new SqliteVectorStore(db, dimension);
 }
diff --git a/src/vectors/sqlite-store.ts b/src/vectors/sqlite-store.ts
new file mode 100644
index 00000000..3cee729b
--- /dev/null
+++ b/src/vectors/sqlite-store.ts
@@ -0,0 +1,343 @@
+/**
+ * SQLite Vector Store
+ *
+ * Vector storage backend using SQLite with optional sqlite-vss
for ANN search. + * Falls back to brute-force cosine similarity if sqlite-vss is not available. + */ + +import { SqliteDatabase } from '../db/sqlite-adapter'; +import { TextEmbedder, EMBEDDING_DIMENSION } from './embedder'; +import { VectorStore, VectorSearchOptions, VectorSearchResult } from './store'; + +/** + * SQLite-based vector store + * + * Stores embeddings as BLOBs in SQLite. Optionally uses sqlite-vss + * extension for accelerated approximate nearest neighbor search. + */ +export class SqliteVectorStore implements VectorStore { + readonly backendType = 'sqlite' as const; + private db: SqliteDatabase; + private vssEnabled = false; + private embeddingDimension: number; + + constructor(db: SqliteDatabase, dimension: number = EMBEDDING_DIMENSION) { + this.db = db; + this.embeddingDimension = dimension; + } + + async initialize(): Promise<void> { + try { + await this.loadVssExtension(); + this.vssEnabled = true; + console.log('sqlite-vss extension loaded successfully'); + this.createVssTable(); + } catch (error) { + console.warn( + 'sqlite-vss extension not available, falling back to brute-force search:', + error instanceof Error ? error.message : String(error) + ); + this.vssEnabled = false; + } + + this.ensureVectorsTable(); + } + + private async loadVssExtension(): Promise<void> { + try { + const vss = await import('sqlite-vss'); + if (typeof vss.load === 'function') { + vss.load(this.db as any); + } else if (typeof vss.default?.load === 'function') { + vss.default.load(this.db as any); + } else { + throw new Error('sqlite-vss load function not found'); + } + } catch (error) { + throw new Error(`Failed to load sqlite-vss: ${error instanceof Error ? 
error.message : String(error)}`); + } + } + + private createVssTable(): void { + const tableExists = this.db + .prepare("SELECT name FROM sqlite_master WHERE type='table' AND name='vss_vectors'") + .get(); + + if (!tableExists) { + this.db.exec(` + CREATE VIRTUAL TABLE IF NOT EXISTS vss_vectors USING vss0( + embedding(${this.embeddingDimension}) + ); + `); + + this.db.exec(` + CREATE TABLE IF NOT EXISTS vss_map ( + rowid INTEGER PRIMARY KEY, + node_id TEXT NOT NULL UNIQUE + ); + `); + + this.db.exec(` + CREATE INDEX IF NOT EXISTS idx_vss_map_node ON vss_map(node_id); + `); + } + } + + private ensureVectorsTable(): void { + this.db.exec(` + CREATE TABLE IF NOT EXISTS vectors ( + node_id TEXT PRIMARY KEY, + embedding BLOB NOT NULL, + model TEXT NOT NULL, + created_at INTEGER NOT NULL + ); + `); + } + + isAnnEnabled(): boolean { + return this.vssEnabled; + } + + async storeVector(nodeId: string, embedding: Float32Array, model: string): Promise<void> { + const now = Date.now(); + const blob = Buffer.from(embedding.buffer); + this.db + .prepare( + ` + INSERT OR REPLACE INTO vectors (node_id, embedding, model, created_at) + VALUES (?, ?, ?, ?) + ` + ) + .run(nodeId, blob, model, now); + + if (this.vssEnabled) { + this.storeInVss(nodeId, embedding); + } + } + + private storeInVss(nodeId: string, embedding: Float32Array): void { + try { + const existing = this.db + .prepare('SELECT rowid FROM vss_map WHERE node_id = ?') + .get(nodeId) as { rowid: number } | undefined; + + if (existing) { + const vectorJson = JSON.stringify(Array.from(embedding)); + this.db + .prepare('UPDATE vss_vectors SET embedding = ? WHERE rowid = ?') + .run(vectorJson, existing.rowid); + } else { + const maxRow = this.db + .prepare('SELECT MAX(rowid) as max FROM vss_map') + .get() as { max: number | null } | undefined; + const newRowid = (maxRow?.max ?? 
0) + 1; + + const vectorJson = JSON.stringify(Array.from(embedding)); + this.db + .prepare('INSERT INTO vss_vectors (rowid, embedding) VALUES (?, ?)') + .run(newRowid, vectorJson); + + this.db + .prepare('INSERT INTO vss_map (rowid, node_id) VALUES (?, ?)') + .run(newRowid, nodeId); + } + } catch (error) { + console.warn( + 'VSS storage failed, using brute-force search:', + error instanceof Error ? error.message : String(error) + ); + } + } + + async storeVectorBatch( + entries: Array<{ nodeId: string; embedding: Float32Array }>, + model: string + ): Promise<void> { + const now = Date.now(); + + this.db.transaction(() => { + for (const entry of entries) { + const blob = Buffer.from(entry.embedding.buffer); + this.db + .prepare( + ` + INSERT OR REPLACE INTO vectors (node_id, embedding, model, created_at) + VALUES (?, ?, ?, ?) + ` + ) + .run(entry.nodeId, blob, model, now); + + if (this.vssEnabled) { + this.storeInVss(entry.nodeId, entry.embedding); + } + } + })(); + } + + async getVector(nodeId: string): Promise<Float32Array | null> { + const row = this.db + .prepare('SELECT embedding FROM vectors WHERE node_id = ?') + .get(nodeId) as { embedding: Buffer } | undefined; + + if (!row) { + return null; + } + + return new Float32Array(row.embedding.buffer.slice( + row.embedding.byteOffset, + row.embedding.byteOffset + row.embedding.byteLength + )); + } + + async deleteVector(nodeId: string): Promise<void> { + this.db.prepare('DELETE FROM vectors WHERE node_id = ?').run(nodeId); + + if (this.vssEnabled) { + const mapping = this.db + .prepare('SELECT rowid FROM vss_map WHERE node_id = ?') + .get(nodeId) as { rowid: number } | undefined; + + if (mapping) { + this.db.prepare('DELETE FROM vss_vectors WHERE rowid = ?').run(mapping.rowid); + this.db.prepare('DELETE FROM vss_map WHERE node_id = ?').run(nodeId); + } + } + } + + async search( + queryEmbedding: Float32Array, + options: VectorSearchOptions = {} + ): Promise<VectorSearchResult[]> { + const { limit = 10, minScore = 0 } = options; + + if (this.vssEnabled) { + return 
this.searchWithVss(queryEmbedding, limit, minScore); + } else { + return this.searchBruteForce(queryEmbedding, limit, minScore); + } + } + + private searchWithVss( + queryEmbedding: Float32Array, + limit: number, + minScore: number + ): VectorSearchResult[] { + try { + const vectorJson = JSON.stringify(Array.from(queryEmbedding)); + const safeLimit = Math.max(1, Math.floor(limit)); + + const rows = this.db + .prepare( + ` + SELECT m.node_id, v.distance + FROM ( + SELECT rowid, distance + FROM vss_vectors + WHERE vss_search(embedding, ?) + LIMIT ${safeLimit} + ) v + JOIN vss_map m ON m.rowid = v.rowid + ` + ) + .all(vectorJson) as Array<{ node_id: string; distance: number }>; + + return rows + .map((row) => ({ + nodeId: row.node_id, + score: 1 / (1 + row.distance), + })) + .filter((r) => r.score >= minScore); + } catch (error) { + console.warn( + 'VSS search failed, using brute-force:', + error instanceof Error ? error.message : String(error) + ); + return this.searchBruteForce(queryEmbedding, limit, minScore); + } + } + + private searchBruteForce( + queryEmbedding: Float32Array, + limit: number, + minScore: number + ): VectorSearchResult[] { + const rows = this.db + .prepare('SELECT node_id, embedding FROM vectors') + .all() as Array<{ node_id: string; embedding: Buffer }>; + + const results: VectorSearchResult[] = []; + + for (const row of rows) { + const embedding = new Float32Array(row.embedding.buffer.slice( + row.embedding.byteOffset, + row.embedding.byteOffset + row.embedding.byteLength + )); + + const score = TextEmbedder.cosineSimilarity(queryEmbedding, embedding); + + if (score >= minScore) { + results.push({ nodeId: row.node_id, score }); + } + } + + results.sort((a, b) => b.score - a.score); + return results.slice(0, limit); + } + + async getVectorCount(): Promise<number> { + const result = this.db + .prepare('SELECT COUNT(*) as count FROM vectors') + .get() as { count: number }; + return result.count; + } + + async hasVector(nodeId: string): Promise<boolean> { + const 
result = this.db + .prepare('SELECT 1 FROM vectors WHERE node_id = ? LIMIT 1') + .get(nodeId); + return !!result; + } + + async getIndexedNodeIds(): Promise<string[]> { + const rows = this.db + .prepare('SELECT node_id FROM vectors') + .all() as Array<{ node_id: string }>; + return rows.map((r) => r.node_id); + } + + async clear(): Promise<void> { + this.db.prepare('DELETE FROM vectors').run(); + + if (this.vssEnabled) { + this.db.prepare('DELETE FROM vss_vectors').run(); + this.db.prepare('DELETE FROM vss_map').run(); + } + } + + async rebuildIndex(): Promise<void> { + if (!this.vssEnabled) { + return; + } + + this.db.prepare('DELETE FROM vss_vectors').run(); + this.db.prepare('DELETE FROM vss_map').run(); + + const rows = this.db + .prepare('SELECT node_id, embedding FROM vectors') + .all() as Array<{ node_id: string; embedding: Buffer }>; + + this.db.transaction(() => { + for (const row of rows) { + const embedding = new Float32Array(row.embedding.buffer.slice( + row.embedding.byteOffset, + row.embedding.byteOffset + row.embedding.byteLength + )); + this.storeInVss(row.node_id, embedding); + } + })(); + } + + async dispose(): Promise<void> { + // No-op: SQLite connection is managed by DatabaseConnection + } +} diff --git a/src/vectors/store-factory.ts b/src/vectors/store-factory.ts new file mode 100644 index 00000000..d5bf375d --- /dev/null +++ b/src/vectors/store-factory.ts @@ -0,0 +1,77 @@ +/** + * Vector Store Factory + * + * Creates the appropriate vector store backend based on configuration. + */ + +import { SqliteDatabase } from '../db/sqlite-adapter'; +import { SqliteVectorStore } from './sqlite-store'; +import { VectorStore } from './store'; +import { EMBEDDING_DIMENSION } from './embedder'; + +/** + * Configuration for the vector store backend + */ +export interface VectorStoreConfig { + /** Backend type: 'sqlite' (default) or 'pgvector' */ + backend: 'sqlite' | 'pgvector'; + + /** PostgreSQL connection string (pgvector only). Can also use CODEGRAPH_PG_URL env var. 
*/ + connectionString?: string; + + /** Index type for pgvector: 'hnsw' (default), 'ivfflat', or 'none' */ + indexType?: 'hnsw' | 'ivfflat' | 'none'; + + /** Distance metric for pgvector: 'cosine' (default), 'l2', or 'inner_product' */ + distanceMetric?: 'cosine' | 'l2' | 'inner_product'; + + /** Connection pool size for pgvector (default: 5) */ + poolSize?: number; + + /** Table name prefix for pgvector (default: 'codegraph_') */ + tablePrefix?: string; +} + +/** Default vector store config (SQLite) */ +export const DEFAULT_VECTOR_STORE_CONFIG: VectorStoreConfig = { + backend: 'sqlite', +}; + +/** + * Create the appropriate vector store based on configuration. + * + * For pgvector, the `pg` module is dynamically imported so it's only + * loaded when actually needed. + */ +export async function createVectorStore( + config: VectorStoreConfig = DEFAULT_VECTOR_STORE_CONFIG, + sqliteDb?: SqliteDatabase, + dimension: number = EMBEDDING_DIMENSION +): Promise<VectorStore> { + if (config.backend === 'pgvector') { + const connectionString = config.connectionString || process.env.CODEGRAPH_PG_URL; + if (!connectionString) { + throw new Error( + 'PostgreSQL connection string required for pgvector backend. ' + + 'Set "vectorStore.connectionString" in .codegraph/config.json or set the CODEGRAPH_PG_URL environment variable.' 
+ ); + } + + // Dynamic import so `pg` is only loaded when pgvector is configured + const { PgVectorStore } = await import('./pg-store'); + return new PgVectorStore({ + connectionString, + dimension, + indexType: config.indexType, + distanceMetric: config.distanceMetric, + poolSize: config.poolSize, + tablePrefix: config.tablePrefix, + }); + } + + // Default: SQLite + if (!sqliteDb) { + throw new Error('SQLite database instance required for sqlite vector backend.'); + } + return new SqliteVectorStore(sqliteDb, dimension); +} diff --git a/src/vectors/store.ts b/src/vectors/store.ts new file mode 100644 index 00000000..17b5472e --- /dev/null +++ b/src/vectors/store.ts @@ -0,0 +1,87 @@ +/** + * Vector Store Interface + * + * Abstraction layer for vector storage backends. + * Implementations: SqliteVectorStore (default), PgVectorStore (optional). + */ + +import { Node } from '../types'; + +/** + * Options for vector search + */ +export interface VectorSearchOptions { + /** Maximum number of results to return */ + limit?: number; + + /** Minimum similarity score (0-1) */ + minScore?: number; + + /** Node kinds to filter results */ + nodeKinds?: Node['kind'][]; +} + +/** + * Vector search result + */ +export interface VectorSearchResult { + nodeId: string; + score: number; +} + +/** + * Vector store backend interface + * + * All vector storage backends must implement this interface. + * Methods return Promises to support both synchronous (SQLite) and + * asynchronous (PostgreSQL) backends uniformly. + */ +export interface VectorStore { + /** Backend identifier — used for stats reporting. Minification-safe. 
*/ + readonly backendType: 'sqlite' | 'pgvector'; + + /** Initialize the store (create tables, load extensions, connect) */ + initialize(): Promise<void>; + + /** Store a single vector */ + storeVector(nodeId: string, embedding: Float32Array, model: string): Promise<void>; + + /** Store multiple vectors in a batch (transactionally) */ + storeVectorBatch( + entries: Array<{ nodeId: string; embedding: Float32Array }>, + model: string + ): Promise<void>; + + /** Retrieve a vector by node ID */ + getVector(nodeId: string): Promise<Float32Array | null>; + + /** Check if a node has a stored vector */ + hasVector(nodeId: string): Promise<boolean>; + + /** Get all node IDs that have vectors */ + getIndexedNodeIds(): Promise<string[]>; + + /** Search for similar vectors */ + search( + queryEmbedding: Float32Array, + options?: VectorSearchOptions + ): Promise<VectorSearchResult[]>; + + /** Delete a vector */ + deleteVector(nodeId: string): Promise<void>; + + /** Clear all vectors */ + clear(): Promise<void>; + + /** Get count of stored vectors */ + getVectorCount(): Promise<number>; + + /** Whether the store supports approximate nearest neighbor search */ + isAnnEnabled(): boolean; + + /** Rebuild index (no-op for stores without separate index structures) */ + rebuildIndex(): Promise<void>; + + /** Release resources / close connections */ + dispose(): Promise<void>; +} diff --git a/src/visualizer/server.ts b/src/visualizer/server.ts index 58f86881..749ba501 100644 --- a/src/visualizer/server.ts +++ b/src/visualizer/server.ts @@ -48,14 +48,15 @@ export class VisualizerServer { /** * Build a compact symbol index string for Claude prompts */ - private buildSymbolIndex(): string { + private async buildSymbolIndex(): Promise<string> { if (this.symbolIndexCache) return this.symbolIndexCache; const validKinds: NodeKind[] = ['function', 'method', 'class', 'interface', 'component', 'route', 'enum', 'type_alias']; const byFile = new Map<string, string[]>(); for (const kind of validKinds) { - for (const node of this.cg.getNodesByKind(kind)) { + const kindNodes = await this.cg.getNodesByKind(kind); + for (const node of 
kindNodes) { const symbols = byFile.get(node.filePath) || []; symbols.push(`${node.kind}:${node.name}`); byFile.set(node.filePath, symbols); @@ -78,7 +79,7 @@ export class VisualizerServer { // Check if claude is available (cache result) if (this.claudeAvailable === false) return null; - const symbolIndex = this.buildSymbolIndex(); + const symbolIndex = await this.buildSymbolIndex(); const prompt = `Given the question and codebase symbol index below, identify the single best ENTRY POINT symbol — the one function, component, or route handler where this flow starts. @@ -229,17 +230,17 @@ ${symbolIndex}`; try { // GET /api/status if (pathname === '/api/status') { - const stats = this.cg.getStats(); + const stats = await this.cg.getStats(); json({ stats, projectRoot: this.projectRoot, projectName: path.basename(this.projectRoot) }); return; } // GET /api/embeddings/status if (pathname === '/api/embeddings/status') { - const embeddingStats = this.cg.getEmbeddingStats(); + const embeddingStats = await this.cg.getEmbeddingStats(); const isInitialized = this.cg.isEmbeddingsInitialized(); const totalVectors = embeddingStats?.totalVectors ?? 0; - const stats = this.cg.getStats(); + const stats = await this.cg.getStats(); // Consider ready if we have vectors for at least half the eligible nodes const eligibleNodes = stats.nodeCount - (stats.nodesByKind.file ?? 0) - (stats.nodesByKind.import ?? 0); const isReady = totalVectors > 0 && totalVectors >= eligibleNodes * 0.5; @@ -295,7 +296,7 @@ ${symbolIndex}`; json({ results: [] }); return; } - const results = this.cg.searchNodes(q, { kinds: kind ? [kind] : undefined, limit }); + const results = await this.cg.searchNodes(q, { kinds: kind ? 
[kind] : undefined, limit }); json({ results }); return; } @@ -320,7 +321,7 @@ ${symbolIndex}`; // Find the entry point in the graph for (const name of claudeNames) { if (entryNodeId) break; - const results = this.cg.searchNodes(name, { kinds: validKinds, limit: 3 }); + const results = await this.cg.searchNodes(name, { kinds: validKinds, limit: 3 }); for (const r of results) { if (r.node.name.toLowerCase() === name.toLowerCase() || r.node.name.toLowerCase().includes(name.toLowerCase()) || @@ -341,7 +342,7 @@ ${symbolIndex}`; for (const kw of keywords) { if (entryNodeId) break; - const results = this.cg.searchNodes(kw, { kinds: validKinds, limit: 5 }); + const results = await this.cg.searchNodes(kw, { kinds: validKinds, limit: 5 }); if (results.length > 0) { entryNodeId = results[0]!.node.id; } @@ -354,7 +355,7 @@ ${symbolIndex}`; } // Get the call graph from this entry point (depth 3) - const callGraph = this.cg.getCallGraph(entryNodeId, 3); + const callGraph = await this.cg.getCallGraph(entryNodeId, 3); const result = serializeSubgraph(callGraph); json({ @@ -374,7 +375,7 @@ ${symbolIndex}`; const kinds: NodeKind[] = ['class', 'function', 'interface', 'component', 'enum', 'type_alias']; const nodes: Node[] = []; for (const kind of kinds) { - const kindNodes = this.cg.getNodesByKind(kind); + const kindNodes = await this.cg.getNodesByKind(kind); for (const n of kindNodes) { if (n.isExported || n.kind === 'class' || n.kind === 'component') { nodes.push(n); @@ -389,7 +390,7 @@ ${symbolIndex}`; // GET /api/files if (pathname === '/api/files') { - const files = this.cg.getFiles(); + const files = await this.cg.getFiles(); json({ files }); return; } @@ -402,13 +403,13 @@ ${symbolIndex}`; // GET /api/node/ if (!sub || sub === '/') { - const node = this.cg.getNode(nodeId); + const node = await this.cg.getNode(nodeId); if (!node) { json({ error: 'Node not found' }, 404); return; } const code = await this.cg.getCode(nodeId); - const ancestors = this.cg.getAncestors(nodeId); + 
const ancestors = await this.cg.getAncestors(nodeId); json({ node, code, ancestors }); return; } @@ -416,7 +417,7 @@ ${symbolIndex}`; // GET /api/node//callers?depth=... if (sub === '/callers') { const depth = parseInt(query.depth || '1', 10); - const items = this.cg.getCallers(nodeId, depth); + const items = await this.cg.getCallers(nodeId, depth); json({ items }); return; } @@ -424,14 +425,14 @@ ${symbolIndex}`; // GET /api/node//callees?depth=... if (sub === '/callees') { const depth = parseInt(query.depth || '1', 10); - const items = this.cg.getCallees(nodeId, depth); + const items = await this.cg.getCallees(nodeId, depth); json({ items }); return; } // GET /api/node//children if (sub === '/children') { - const children = this.cg.getChildren(nodeId); + const children = await this.cg.getChildren(nodeId); json({ children }); return; } @@ -439,7 +440,7 @@ ${symbolIndex}`; // GET /api/node//impact?depth=... if (sub === '/impact') { const depth = parseInt(query.depth || '2', 10); - const subgraph = this.cg.getImpactRadius(nodeId, depth); + const subgraph = await this.cg.getImpactRadius(nodeId, depth); json(serializeSubgraph(subgraph)); return; } @@ -447,14 +448,14 @@ ${symbolIndex}`; // GET /api/node//callgraph?depth=... if (sub === '/callgraph') { const depth = parseInt(query.depth || '2', 10); - const subgraph = this.cg.getCallGraph(nodeId, depth); + const subgraph = await this.cg.getCallGraph(nodeId, depth); json(serializeSubgraph(subgraph)); return; } // GET /api/node//context if (sub === '/context') { - const context = this.cg.getContext(nodeId); + const context = await this.cg.getContext(nodeId); json({ context }); return; } @@ -470,7 +471,7 @@ ${symbolIndex}`; json({ error: 'path parameter required' }, 400); return; } - const nodes = this.cg.getNodesInFile(filePath); + const nodes = await this.cg.getNodesInFile(filePath); json({ nodes }); return; }
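For reference, the brute-force fallback that `SqliteVectorStore.searchBruteForce` implements can be sketched standalone. This is not part of the diff: the names are hypothetical, `TextEmbedder.cosineSimilarity` is inlined, and an in-memory `Map` stands in for the `vectors` table.

```typescript
// Minimal sketch of cosine-similarity brute-force search over stored embeddings.
// Assumes embeddings are plain Float32Arrays keyed by node ID.

function cosineSimilarity(a: Float32Array, b: Float32Array): number {
  let dot = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  const denom = Math.sqrt(normA) * Math.sqrt(normB);
  return denom === 0 ? 0 : dot / denom;
}

function searchBruteForce(
  vectors: Map<string, Float32Array>,
  query: Float32Array,
  limit = 10,
  minScore = 0
): Array<{ nodeId: string; score: number }> {
  const results: Array<{ nodeId: string; score: number }> = [];
  // Score every stored vector, keep those above the threshold.
  vectors.forEach((embedding, nodeId) => {
    const score = cosineSimilarity(query, embedding);
    if (score >= minScore) results.push({ nodeId, score });
  });
  // Sort by score descending and truncate to the requested limit.
  results.sort((a, b) => b.score - a.score);
  return results.slice(0, limit);
}

// Usage with hypothetical node IDs:
const store = new Map<string, Float32Array>([
  ['parseConfig', Float32Array.from([1, 0, 0])],
  ['loadConfig', Float32Array.from([0.9, 0.1, 0])],
  ['renderChart', Float32Array.from([0, 0, 1])],
]);
const top = searchBruteForce(store, Float32Array.from([1, 0, 0]), 2, 0.5);
// 'renderChart' is filtered out by minScore; 'parseConfig' ranks first
```

This is O(n) per query, which is why the PR adds sqlite-vss and pgvector as ANN-indexed alternatives for larger codebases.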