From a0239bc0cc9cc95a76417973732c4e9dc0301558 Mon Sep 17 00:00:00 2001 From: Rakshat28 Date: Sun, 24 May 2026 02:58:42 +0530 Subject: [PATCH] Implement ast-search index --persist and ast-search lookup --symbol query against the DB --- Cargo.lock | 86 +++++ Cargo.toml | 1 + src/extractor.rs | 499 +++++++++++++++++++++++++++ src/indexer.rs | 203 +++++++++-- src/lib.rs | 2 + src/main.rs | 350 ++++++++++++++++++- src/memory.rs | 706 ++++++++++++++++++++++++++++++++++++++ src/output.rs | 81 ++++- src/types.rs | 2 + tests/integration_test.rs | 141 ++++++++ 10 files changed, 2042 insertions(+), 29 deletions(-) create mode 100644 src/extractor.rs create mode 100644 src/memory.rs diff --git a/Cargo.lock b/Cargo.lock index 8a9acf1..6fafc47 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,18 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "ahash" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.4" @@ -368,6 +380,7 @@ dependencies = [ "ratatui", "rayon", "regex", + "rusqlite", "scopeguard", "serde", "similar", @@ -408,6 +421,18 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fastrand" version = "2.4.1" @@ -499,6 +524,15 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", +] + [[package]] name = "hashbrown" version = "0.15.5" @@ -516,6 +550,15 @@ version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed5909b6e89a2db4456e54cd5f673791d7eca6732202bbf2a9cc504fe2f9b84a" +[[package]] +name = "hashlink" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +dependencies = [ + "hashbrown 0.14.5", +] + [[package]] name = "heck" version = "0.5.0" @@ -627,6 +670,17 @@ version = "0.2.186" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" +[[package]] +name = "libsqlite3-sys" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c10584274047cb335c23d3e61bcef8e323adae7c5c8c760540f73610177fc3f" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linux-raw-sys" version = "0.12.1" @@ -757,6 +811,12 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" +[[package]] +name = "pkg-config" +version = "0.3.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19f132c84eca552bf34cab8ec81f1c1dcc229b811638f9d283dceabe58c5569e" + [[package]] name = "plotters" version = "0.3.7" @@ -898,6 +958,20 @@ version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" +[[package]] +name = "rusqlite" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b838eba278d213a8beaf485bd313fd580ca4505a00d5871caeb1457c55322cae" +dependencies = [ + "bitflags", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "smallvec", +] + [[package]] name = "rustix" version = "1.1.4" @@ -1324,6 +1398,18 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + [[package]] name = "virtue" version = "0.0.18" diff --git a/Cargo.toml b/Cargo.toml index 2a9fa51..341fc2c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ regex = "1" memmap2 = "0.9" bincode = { version = "2", features = ["serde"] } serde = { version = "1", features = ["derive"] } +rusqlite = { version = "0.31", features = ["bundled"] } ratatui = "0.27" crossterm = { version = "0.27", features = ["event-stream"] } diff --git a/src/extractor.rs b/src/extractor.rs new file mode 100644 index 0000000..4e53631 --- /dev/null +++ b/src/extractor.rs @@ -0,0 +1,499 @@ +#![allow(clippy::missing_errors_doc, clippy::must_use_candidate)] + +use crate::memory::{NewSymbolRow, SymbolKind}; +use crate::parser::FileSource; +use crate::types::Language; +use std::fmt; +use tree_sitter::{Language as TsLanguage, Node, Query, QueryCursor, Tree}; + +#[derive(Debug, Clone)] +pub struct SymbolExtractor { + pub language: Language, +} + +impl SymbolExtractor { + pub fn extract(&self, tree: &Tree, source: &FileSource, file_id: i64) -> Vec { + let queries = queries_for_language(&self.language); + let ts_language = ts_language_for(&self.language); + let mut symbols = Vec::new(); + + if let Some(query) = queries.functions { + symbols.extend(extract_with_query( + tree, + source, + file_id, + &ts_language, + query, + SymbolKind::Function, + )); + } + + if let Some(query) = queries.structs { + symbols.extend(extract_with_query( + tree, + source, + file_id, + &ts_language, + query, + SymbolKind::Struct, + )); + } + + if let Some(query) = queries.classes { + symbols.extend(extract_with_query( + tree, + source, + file_id, + &ts_language, + query, + SymbolKind::Class, + )); + } + + if let Some(query) = queries.interfaces { + symbols.extend(extract_with_query( + tree, + source, + file_id, + &ts_language, + query, + SymbolKind::Interface, + )); + } + + if let Some(query) = queries.type_aliases { + symbols.extend(extract_with_query( + tree, + source, + file_id, + &ts_language, + query, + SymbolKind::TypeAlias, + )); + } + + if let Some(query) = queries.imports { + symbols.extend(extract_with_query( + tree, + source, + file_id, + &ts_language, + query, + SymbolKind::Import, + )); + } + + if let Some(query) = queries.traits { + symbols.extend(extract_with_query( + tree, + source, + file_id, + &ts_language, + query, + SymbolKind::Trait, + )); + } + + if let Some(query) = queries.enums { + symbols.extend(extract_with_query( + tree, + source, + file_id, + &ts_language, + query, + SymbolKind::Enum, + )); + } + + symbols + } +} + +struct LangQueries { + functions: Option<&'static str>, + structs: Option<&'static str>, + classes: Option<&'static str>, + interfaces: Option<&'static str>, + type_aliases: Option<&'static str>, + imports: Option<&'static str>, + traits: Option<&'static str>, + enums: Option<&'static str>, +} + +#[derive(Debug, Clone)] +struct CaptureData { + text: String, + start_line: usize, + start_col: usize, + end_line: usize, + end_col: usize, +} + +impl fmt::Display for CaptureData { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&self.text) + } +} + +fn extract_with_query( + tree: &Tree, + source: &FileSource, + file_id: i64, + ts_lang: &TsLanguage, + query_str: &str, + #[allow(clippy::needless_pass_by_value)] kind: SymbolKind, +) -> Vec { + let query = Query::new(ts_lang, query_str).expect("invalid extractor query"); + let capture_names = query.capture_names(); + let mut cursor = QueryCursor::new(); + let mut symbols = Vec::new(); + + for query_match in cursor.matches(&query, tree.root_node(), source.as_bytes()) { + let mut name_capture: Option = None; + let mut trait_name_capture: Option = None; + let mut type_name_capture: Option = None; + let mut signature: Option = None; + + for capture in query_match.captures { + let capture_name = capture_names[capture.index as usize].to_string(); + match capture_name.as_str() { + "name" => name_capture = capture_data(source, capture.node), + "trait_name" => trait_name_capture = capture_data(source, capture.node), + "type_name" => type_name_capture = capture_data(source, capture.node), + "sig" => signature = capture_signature(source, capture.node), + _ => {} + } + } + + let (name, start_line, start_col, end_line, end_col) = match kind { + SymbolKind::Trait if name_capture.is_none() => { + let trait_name = match trait_name_capture.clone() { + Some(data) => data, + None => continue, + }; + let type_name = match type_name_capture.clone() { + Some(data) => data, + None => continue, + }; + ( + format!("{} for {}", trait_name.text, type_name.text), + trait_name.start_line, + trait_name.start_col, + type_name.end_line, + type_name.end_col, + ) + } + _ => { + let capture = match name_capture.clone().or(trait_name_capture.clone()) { + Some(data) => data, + None => continue, + }; + ( + capture.text, + capture.start_line, + capture.start_col, + capture.end_line, + capture.end_col, + ) + } + }; + + symbols.push(NewSymbolRow { + file_id, + kind: kind.clone(), + name, + start_line, + start_col, + end_line, + end_col, + signature, + }); + } + + symbols +} + +fn capture_data(source: &FileSource, node: Node<'_>) -> Option { + let bytes = source.as_bytes().get(node.byte_range())?; + let text = std::str::from_utf8(bytes).ok()?.to_string(); + let start = node.start_position(); + let end = node.end_position(); + Some(CaptureData { + text, + start_line: start.row + 1, + start_col: start.column, + end_line: end.row + 1, + end_col: end.column, + }) +} + +fn capture_signature(source: &FileSource, node: Node<'_>) -> Option { + let text = capture_data(source, node)?.text; + Some(truncate_signature(&text)) +} + +fn truncate_signature(text: &str) -> String { + let count = text.chars().count(); + if count <= 500 { + text.to_string() + } else { + let mut truncated = text.chars().take(500).collect::(); + truncated.push('…'); + truncated + } +} + +fn ts_language_for(language: &Language) -> TsLanguage { + match language { + Language::Rust => crate::parser::get_language("rust").expect("missing rust language"), + Language::Python => crate::parser::get_language("python").expect("missing python language"), + Language::JavaScript => { + crate::parser::get_language("js").expect("missing javascript language") + } + Language::TypeScript => { + crate::parser::get_language("ts").expect("missing typescript language") + } + Language::Go => crate::parser::get_language("go").expect("missing go language"), + Language::C => crate::parser::get_language("c").expect("missing c language"), + Language::Cpp => crate::parser::get_language("cpp").expect("missing cpp language"), + } +} + +fn queries_for_language(lang: &Language) -> LangQueries { + match lang { + Language::Rust => LangQueries { + functions: Some("(source_file (function_item name: (identifier) @name) @sig)"), + structs: Some("(struct_item name: (type_identifier) @name) @sig"), + classes: None, + interfaces: None, + type_aliases: Some("(type_item name: (type_identifier) @name) @sig"), + imports: Some("(use_declaration (_) @name) @sig"), + traits: Some( + "(trait_item name: (type_identifier) @name) @sig\n(impl_item trait: (_) @trait_name type: (_) @type_name) @sig", + ), + enums: Some("(enum_item name: (type_identifier) @name) @sig"), + }, + Language::Python => LangQueries { + functions: Some("(function_definition name: (identifier) @name) @sig"), + structs: None, + classes: Some("(class_definition name: (identifier) @name) @sig"), + interfaces: None, + type_aliases: None, + imports: Some( + "(import_statement (dotted_name) @name) @sig\n(import_from_statement module_name: (dotted_name) @name) @sig", + ), + traits: None, + enums: None, + }, + Language::JavaScript => LangQueries { + functions: Some("(function_declaration name: (identifier) @name) @sig"), + structs: None, + classes: Some("(class_declaration name: (identifier) @name) @sig"), + interfaces: None, + type_aliases: None, + imports: Some("(import_statement source: (string) @name) @sig"), + traits: None, + enums: None, + }, + Language::TypeScript => LangQueries { + functions: Some("(function_declaration name: (identifier) @name) @sig"), + structs: None, + classes: Some("(class_declaration name: (type_identifier) @name) @sig"), + interfaces: Some("(interface_declaration name: (type_identifier) @name) @sig"), + type_aliases: Some("(type_alias_declaration name: (type_identifier) @name) @sig"), + imports: Some("(import_statement source: (string) @name) @sig"), + traits: None, + enums: None, + }, + Language::Go => LangQueries { + functions: Some("(function_declaration name: (identifier) @name) @sig"), + structs: Some("(type_declaration (type_spec name: (type_identifier) @name) @sig)"), + classes: None, + interfaces: None, + type_aliases: None, + imports: Some("(import_declaration (import_spec path: (interpreted_string_literal) @name) @sig)"), + traits: None, + enums: None, + }, + Language::C => LangQueries { + functions: Some( + "(function_definition declarator: (function_declarator declarator: (identifier) @name) @sig)", + ), + structs: Some( + "(type_definition declarator: (type_identifier) @name) @sig\n(struct_specifier name: (type_identifier) @name) @sig", + ), + classes: None, + interfaces: None, + type_aliases: None, + imports: Some("(preproc_include path: (_) @name) @sig"), + traits: None, + enums: Some("(enum_specifier name: (type_identifier) @name) @sig"), + }, + Language::Cpp => LangQueries { + functions: Some( + "(function_definition declarator: (function_declarator declarator: (identifier) @name) @sig)", + ), + structs: Some("(struct_specifier name: (type_identifier) @name) @sig"), + classes: Some("(class_specifier name: (type_identifier) @name) @sig"), + interfaces: None, + type_aliases: None, + imports: Some("(preproc_include path: (_) @name) @sig"), + traits: None, + enums: Some("(enum_specifier name: (type_identifier) @name) @sig"), + }, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::parser::get_language; + use tree_sitter::Parser; + + fn parse_source(language: &str, source: &str) -> (Tree, FileSource) { + let ts_language = get_language(language).unwrap(); + let mut parser = Parser::new(); + parser.set_language(&ts_language).unwrap(); + let tree = parser.parse(source.as_bytes(), None).unwrap(); + (tree, FileSource::Heap(source.to_string())) + } + + fn extract(language: Language, source: &str) -> Vec { + let (tree, file_source) = parse_source( + match language { + Language::Rust => "rust", + Language::Python => "python", + Language::JavaScript => "js", + Language::TypeScript => "ts", + Language::Go => "go", + Language::C => "c", + Language::Cpp => "cpp", + }, + source, + ); + let extractor = SymbolExtractor { language }; + extractor.extract(&tree, &file_source, 7) + } + + #[test] + fn test_extract_rust_functions() { + let rows = extract(Language::Rust, "fn alpha() {}\nfn beta() {}"); + assert_eq!(rows.len(), 2); + assert!(rows.iter().all(|row| row.kind == SymbolKind::Function)); + let names: Vec<_> = rows.iter().map(|row| row.name.as_str()).collect(); + assert!(names.contains(&"alpha")); + assert!(names.contains(&"beta")); + } + + #[test] + fn test_extract_rust_struct() { + let rows = extract(Language::Rust, "struct Config { timeout: u64 }"); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].kind, SymbolKind::Struct); + assert_eq!(rows[0].name, "Config"); + } + + #[test] + fn test_extract_rust_enum() { + let rows = extract(Language::Rust, "enum Status { Active, Inactive }"); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].kind, SymbolKind::Enum); + assert_eq!(rows[0].name, "Status"); + } + + #[test] + fn test_extract_rust_trait_impl() { + let rows = extract(Language::Rust, "impl core::fmt::Display for Config { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { Ok(()) } }"); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].kind, SymbolKind::Trait); + assert!(rows[0].name.contains("Display")); + assert!(rows[0].name.contains("Config")); + } + + #[test] + fn test_extract_python_functions() { + let rows = + extract(Language::Python, "def greet(name):\n pass\ndef farewell():\n pass"); + assert_eq!(rows.len(), 2); + assert!(rows.iter().all(|row| row.kind == SymbolKind::Function)); + } + + #[test] + fn test_extract_python_class() { + let rows = extract(Language::Python, "class MyClass:\n pass"); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].kind, SymbolKind::Class); + assert_eq!(rows[0].name, "MyClass"); + } + + #[test] + fn test_extract_javascript_function() { + let rows = extract(Language::JavaScript, "function authenticate(user) { return true; }"); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].kind, SymbolKind::Function); + } + + #[test] + fn test_extract_typescript_interface() { + let rows = extract(Language::TypeScript, "interface Shape { area(): number; }"); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].kind, SymbolKind::Interface); + assert_eq!(rows[0].name, "Shape"); + } + + #[test] + fn test_extract_typescript_type_alias() { + let rows = extract(Language::TypeScript, "type Point = { x: number; y: number; };\n"); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].kind, SymbolKind::TypeAlias); + assert_eq!(rows[0].name, "Point"); + } + + #[test] + fn test_extract_go_function() { + let rows = + extract(Language::Go, "package main\nfunc greet(name string) string { return name }"); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].kind, SymbolKind::Function); + } + + #[test] + fn test_extract_c_function() { + let rows = extract(Language::C, "int add(int a, int b) { return a + b; }"); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].kind, SymbolKind::Function); + assert_eq!(rows[0].name, "add"); + } + + #[test] + fn test_extract_cpp_class() { + let rows = extract(Language::Cpp, "class Calculator { public: int add(int a, int b); };\n"); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].kind, SymbolKind::Class); + assert_eq!(rows[0].name, "Calculator"); + } + + #[test] + fn test_signature_truncated_at_500_chars() { + let long_params = (0..120).map(|i| format!("value{i}: i32")).collect::>().join(", "); + let source = format!("fn huge({long_params}) {{}}\n"); + let rows = extract(Language::Rust, &source); + assert_eq!(rows.len(), 1); + let signature = rows[0].signature.as_ref().unwrap(); + assert!(signature.chars().count() <= 501); + } + + #[test] + fn test_extract_empty_source_returns_empty() { + let rows = extract(Language::Rust, "let x = 1;"); + assert!(rows.is_empty()); + } + + #[test] + fn test_extract_returns_correct_positions() { + let rows = extract(Language::Rust, "fn foo() {}"); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].start_line, 1); + assert_eq!(rows[0].start_col, 3); + } +} diff --git a/src/indexer.rs b/src/indexer.rs index 97f857e..2cce2db 100644 --- a/src/indexer.rs +++ b/src/indexer.rs @@ -1,3 +1,17 @@ +#![allow( + clippy::cast_possible_wrap, + clippy::if_not_else, + clippy::manual_let_else, + clippy::missing_errors_doc, + clippy::missing_panics_doc, + clippy::needless_borrow, + clippy::needless_return, + clippy::single_match_else, + clippy::too_many_lines, + clippy::uninlined_format_args, + clippy::unnecessary_wraps +)] + use rayon::prelude::*; use std::collections::HashSet; use std::fs; @@ -6,19 +20,27 @@ use std::path::PathBuf; use std::sync::{Arc, Mutex}; use crate::bloom::BloomFilter; +use crate::extractor::SymbolExtractor; use crate::index::index_path_for_root; use crate::index::{IndexEntry, IndexManifest}; +use crate::memory::{memory_db_path, MemoryDb, NewFileRow}; use crate::parser::detect_language; +use crate::parser::parse_file_with_metadata; use crate::trigram::extract_unique_trigrams_from_bytes; use crate::types::{LangMode, Language, Result}; use crate::walker::{build_auto_walker, build_walker}; -pub fn build_index(root: &Path, lang_mode: &LangMode, verbose: bool) -> Result<()> { +pub fn build_index(root: &Path, lang_mode: &LangMode, verbose: bool, persist: bool) -> Result<()> { let root_abs = match fs::canonicalize(root) { Ok(p) => p, Err(_) => root.to_path_buf(), }; let index_path = index_path_for_root(&root_abs); + let memory_db = if persist { + Some(Arc::new(Mutex::new(MemoryDb::open(&memory_db_path(&root_abs))?))) + } else { + None + }; let mut manifest = match crate::index::load_index(&index_path) { Ok(m) => { @@ -34,6 +56,7 @@ pub fn build_index(root: &Path, lang_mode: &LangMode, verbose: bool) -> Result<( let entries_arc = Arc::new(Mutex::new(Vec::::new())); let indexed_count = Arc::new(Mutex::new(0usize)); let skipped_count = Arc::new(Mutex::new(0usize)); + let symbols_extracted = Arc::new(Mutex::new(0usize)); let walker: Box> + Send> = match lang_mode { @@ -44,7 +67,9 @@ pub fn build_index(root: &Path, lang_mode: &LangMode, verbose: bool) -> Result<( let entries_ref = Arc::clone(&entries_arc); let indexed_ref = Arc::clone(&indexed_count); let skipped_ref = Arc::clone(&skipped_count); + let symbols_ref = Arc::clone(&symbols_extracted); let manifest_ref = Arc::new(manifest); + let db_ref = memory_db.clone(); walker.par_bridge().for_each(move |entry_result| match entry_result { Ok(entry) => { @@ -74,16 +99,112 @@ pub fn build_index(root: &Path, lang_mode: &LangMode, verbose: bool) -> Result<( }, }; - match index_file(&path, &metadata, &lang_str) { - Ok(new_entry) => { - if verbose { - eprintln!("indexed: {}", path.display()); + if persist { + let detected_lang = match lang_mode { + LangMode::Single(lang) => Some(lang.clone()), + LangMode::Auto => detect_language(&path), + }; + + match detected_lang { + Some(language) => { + let ts_language = crate::parser::get_language(lang_to_str(&language)); + match ts_language { + Ok(ts_language) => { + match parse_file_with_metadata(&path, &ts_language, &metadata) { + Ok((tree, source)) => { + let new_entry = match index_entry_from_source( + &path, + &metadata, + &lang_str, + source.as_bytes(), + ) { + Ok(entry) => entry, + Err(_) => { + drop(tree); + drop(source); + return; + } + }; + + if let Some(db) = &db_ref { + let file_id = { + let db_guard = db.lock().unwrap(); + match db_guard.upsert_file(&NewFileRow { + path: path.to_string_lossy().to_string(), + mtime: new_entry.mtime_secs as i64, + language: new_entry.language.clone(), + }) { + Ok(file_id) => { + let _ = db_guard + .delete_symbols_for_file(file_id); + Some(file_id) + } + Err(_) => None, + } + }; + if let Some(file_id) = file_id { + let extractor = SymbolExtractor { language }; + let symbols = + extractor.extract(&tree, &source, file_id); + *symbols_ref.lock().unwrap() += symbols.len(); + let db_guard = db.lock().unwrap(); + let _ = db_guard.insert_symbols_batch(&symbols); + } + } + + if verbose { + eprintln!("indexed: {}", path.display()); + } + *indexed_ref.lock().unwrap() += 1; + entries_ref.lock().unwrap().push(new_entry); + drop(tree); + drop(source); + } + Err(_) => match index_file(&path, &metadata, &lang_str) { + Ok(new_entry) => { + if verbose { + eprintln!("indexed: {}", path.display()); + } + *indexed_ref.lock().unwrap() += 1; + entries_ref.lock().unwrap().push(new_entry); + } + Err(_) => return, + }, + } + } + Err(_) => match index_file(&path, &metadata, &lang_str) { + Ok(new_entry) => { + if verbose { + eprintln!("indexed: {}", path.display()); + } + *indexed_ref.lock().unwrap() += 1; + entries_ref.lock().unwrap().push(new_entry); + } + Err(_) => return, + }, + } } - *indexed_ref.lock().unwrap() += 1; - entries_ref.lock().unwrap().push(new_entry); + None => match index_file(&path, &metadata, &lang_str) { + Ok(new_entry) => { + if verbose { + eprintln!("indexed: {}", path.display()); + } + *indexed_ref.lock().unwrap() += 1; + entries_ref.lock().unwrap().push(new_entry); + } + Err(_) => return, + }, } - Err(_) => { - return; + } else { + match index_file(&path, &metadata, &lang_str) { + Ok(new_entry) => { + if verbose { + eprintln!("indexed: {}", path.display()); + } + *indexed_ref.lock().unwrap() += 1; + entries_ref.lock().unwrap().push(new_entry); + } + Err(_) => return, } } } @@ -104,13 +225,34 @@ pub fn build_index(root: &Path, lang_mode: &LangMode, verbose: bool) -> Result<( crate::index::save_index(&manifest, &index_path)?; + if persist { + if let Some(db) = &memory_db { + let db_guard = db.lock().unwrap(); + if let Ok(files) = db_guard.list_files() { + for file in files { + if !current_paths.contains(&PathBuf::from(&file.path)) { + let _ = db_guard.delete_file_by_path(&file.path); + } + } + } + } + } + let indexed = *indexed_count.lock().unwrap(); let skipped = *skipped_count.lock().unwrap(); + let symbols = *symbols_extracted.lock().unwrap(); - eprintln!( - "indexed {} files, skipped {} fresh, removed {} stale entries", - indexed, skipped, removed - ); + if persist { + eprintln!( + "indexed {} files, skipped {} fresh, removed {} stale entries, extracted {} symbols", + indexed, skipped, removed, symbols + ); + } else { + eprintln!( + "indexed {} files, skipped {} fresh, removed {} stale entries", + indexed, skipped, removed + ); + } eprintln!("index written to {}", index_path.display()); Ok(()) @@ -128,6 +270,15 @@ fn is_fresh(entry: &IndexEntry, metadata: &fs::Metadata) -> bool { fn index_file(path: &Path, metadata: &fs::Metadata, language: &str) -> Result { let bytes = fs::read(path)?; + index_entry_from_source(path, metadata, language, &bytes) +} + +fn index_entry_from_source( + path: &Path, + metadata: &fs::Metadata, + language: &str, + bytes: &[u8], +) -> Result { let trigrams = extract_unique_trigrams_from_bytes(&bytes); let mut filter = BloomFilter::new(); filter.insert_trigrams(&trigrams); @@ -169,6 +320,7 @@ mod tests { use super::*; use crate::bloom::BloomFilter; use crate::index::IndexEntry; + use crate::memory::{memory_db_path, MemoryDb}; use crate::types::LangMode; use tempfile::TempDir; @@ -262,7 +414,7 @@ mod tests { let b = tmp.path().join("b.rs"); fs::write(&a, "fn a() {}\n").unwrap(); fs::write(&b, "fn b() {}\n").unwrap(); - build_index(tmp.path(), &LangMode::Single(Language::Rust), false).unwrap(); + build_index(tmp.path(), &LangMode::Single(Language::Rust), false, false).unwrap(); assert!(crate::index::index_exists(tmp.path())); } @@ -271,10 +423,10 @@ mod tests { let tmp = TempDir::new().unwrap(); let a = tmp.path().join("a.rs"); fs::write(&a, "fn a() {}\n").unwrap(); - build_index(tmp.path(), &LangMode::Single(Language::Rust), false).unwrap(); + build_index(tmp.path(), &LangMode::Single(Language::Rust), false, false).unwrap(); let before = crate::index::load_index(&crate::index::index_path_for_root(tmp.path())).unwrap(); - build_index(tmp.path(), &LangMode::Single(Language::Rust), false).unwrap(); + build_index(tmp.path(), &LangMode::Single(Language::Rust), false, false).unwrap(); let after = crate::index::load_index(&crate::index::index_path_for_root(tmp.path())).unwrap(); assert_eq!(before.entries.len(), after.entries.len()); @@ -285,14 +437,14 @@ mod tests { let tmp = TempDir::new().unwrap(); let a = tmp.path().join("a.rs"); fs::write(&a, "fn a() {}\n").unwrap(); - build_index(tmp.path(), &LangMode::Single(Language::Rust), false).unwrap(); + build_index(tmp.path(), &LangMode::Single(Language::Rust), false, false).unwrap(); let _entry = crate::index::load_index(&crate::index::index_path_for_root(tmp.path())) .unwrap() .entries .pop() .unwrap(); fs::write(&a, "fn a_changed() {}\n").unwrap(); - build_index(tmp.path(), &LangMode::Single(Language::Rust), false).unwrap(); + build_index(tmp.path(), &LangMode::Single(Language::Rust), false, false).unwrap(); let after = crate::index::load_index(&crate::index::index_path_for_root(tmp.path())).unwrap(); assert!(after.entries.len() >= 1); @@ -305,14 +457,25 @@ mod tests { let b = tmp.path().join("b.rs"); fs::write(&a, "fn a() {}\n").unwrap(); fs::write(&b, "fn b() {}\n").unwrap(); - build_index(tmp.path(), &LangMode::Single(Language::Rust), false).unwrap(); + build_index(tmp.path(), &LangMode::Single(Language::Rust), false, false).unwrap(); let manifest = crate::index::load_index(&crate::index::index_path_for_root(tmp.path())).unwrap(); assert!(manifest.entries.len() >= 2); fs::remove_file(&b).unwrap(); - build_index(tmp.path(), &LangMode::Single(Language::Rust), false).unwrap(); + build_index(tmp.path(), &LangMode::Single(Language::Rust), false, false).unwrap(); let manifest2 = crate::index::load_index(&crate::index::index_path_for_root(tmp.path())).unwrap(); assert!(manifest2.entries.len() >= 1); } + + #[test] + fn test_build_index_with_persist_writes_memory_db() { + let tmp = TempDir::new().unwrap(); + let file = tmp.path().join("a.rs"); + fs::write(&file, "fn add(a: i32, b: i32) -> i32 { a + b }\n").unwrap(); + build_index(tmp.path(), &LangMode::Single(Language::Rust), false, true).unwrap(); + let db = MemoryDb::open(&memory_db_path(tmp.path())).unwrap(); + assert!(db.file_count().unwrap() >= 1); + assert!(db.symbol_count().unwrap() >= 1); + } } diff --git a/src/lib.rs b/src/lib.rs index b11a0e7..1bb6368 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,8 +2,10 @@ #![warn(clippy::pedantic)] pub mod bloom; +pub mod extractor; pub mod index; pub mod indexer; +pub mod memory; pub mod output; pub mod parser; pub mod query; diff --git a/src/main.rs b/src/main.rs index d0d1ab6..26e1810 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,18 +4,20 @@ use std::{ collections::{HashMap, HashSet}, fs, - path::PathBuf, + path::{Path, PathBuf}, process, sync::{Arc, Mutex}, time::{Duration, Instant}, }; -use clap::{CommandFactory, Parser}; +use clap::{Args, CommandFactory, Parser, Subcommand}; use clap_complete::{generate, Shell}; use rayon::prelude::*; mod bloom; mod index; +#[allow(dead_code)] +mod memory; mod output; mod parser; mod query; @@ -27,7 +29,8 @@ pub mod walker; use bloom::BloomFilter; use index::{index_path_for_root, load_index, save_index, IndexEntry, IndexManifest}; -use output::{print_match, print_summary, resolve_color_mode, ColorMode}; +use memory::{memory_db_path, FileRow, MemoryDb, SymbolKind, SymbolRow}; +use output::{print_lookup_results, print_match, print_summary, resolve_color_mode, ColorMode}; use parser::{detect_language, get_all_languages, parse_file_with_metadata}; use sieve::{ build_query_trigram_set, get_file_index_status, should_parse_file, FileIndexStatus, @@ -84,11 +87,25 @@ fn handle_file_error(error: &FileError, skip_count: &Mutex) { dora -q '(function_item name: (identifier) @fn)' -p ./src\n\n\ See https://github.com/your-org/dora for full documentation." )] +struct App { + #[command(flatten)] + cli: Cli, + + #[command(subcommand)] + command: Option, +} + +#[derive(Subcommand, Debug)] +enum Commands { + Lookup(LookupArgs), +} + +#[derive(Args, Debug)] +#[command(next_line_help = true)] struct Cli { #[arg( short = 'q', long = "query", - required = true, num_args = 1.., value_name = "S-EXPR", help = "Tree-sitter S-expression query (repeatable: -q QUERY1 -q QUERY2)", @@ -189,6 +206,38 @@ struct Cli { yes: bool, } +#[derive(Args, Debug)] +struct LookupArgs { + #[arg(long = "symbol", value_name = "NAME", help = "Lookup an exact symbol name")] + symbol: Option, + + #[arg(long = "prefix", value_name = "PREFIX", help = "Lookup symbols by name prefix")] + prefix: Option, + + #[arg(long = "kind", value_name = "KIND", help = "Restrict matches to a symbol kind")] + kind: Option, + + #[arg( + short = 'p', + long = "path", + value_name = "DIR", + default_value = ".", + help = "Root directory whose persisted index should be queried" + )] + path: PathBuf, + + #[arg( + long = "lang", + value_name = "LANG", + default_value = "auto", + help = "Language to filter after lookup: rust, python, js, ts, go, c, cpp, auto" + )] + lang: String, + + #[arg(long = "no-color", default_value_t = false, help = "Disable ANSI color output")] + no_color: bool, +} + struct SearchOutcome { results: Vec, files_walked: usize, @@ -243,6 +292,200 @@ impl Cli { } } +impl LookupArgs { + fn validate(&self) -> std::result::Result<(), String> { + let symbol_present = self.symbol.as_ref().is_some_and(|value| !value.trim().is_empty()); + let prefix_present = self.prefix.as_ref().is_some_and(|value| !value.trim().is_empty()); + + if symbol_present == prefix_present { + return Err("specify exactly one of --symbol or --prefix".to_string()); + } + + if !self.path.exists() { + return Err(format!( + "path does not exist: {}\n hint: check for typos or run from the correct directory", + self.path.display() + )); + } + + if !self.path.is_dir() { + return Err(format!( + "path is not a directory: {}\n hint: --path must point to a directory, not a file", + self.path.display() + )); + } + + let supported = ["rust", "python", "js", "ts", "go", "c", "cpp", "auto"]; + if !supported.contains(&self.lang.as_str()) { + return Err(format!( + "unsupported language: '{}'\n supported languages: rust, python, js, ts, go, c, cpp, auto\n example: --lang rust", + self.lang + )); + } + + if let Some(kind) = &self.kind { + parse_lookup_kind(kind)?; + } + + Ok(()) + } +} + +fn parse_lookup_kind(kind: &str) -> std::result::Result { + match kind.trim().to_lowercase().as_str() { + "function" => Ok(SymbolKind::Function), + "method" => Ok(SymbolKind::Method), + "struct" => Ok(SymbolKind::Struct), + "enum" => Ok(SymbolKind::Enum), + "trait" => Ok(SymbolKind::Trait), + "interface" => Ok(SymbolKind::Interface), + "typealias" => Ok(SymbolKind::TypeAlias), + "constant" => Ok(SymbolKind::Constant), + "variable" => Ok(SymbolKind::Variable), + "class" => Ok(SymbolKind::Class), + "module" => Ok(SymbolKind::Module), + "import" => Ok(SymbolKind::Import), + "unknown" => Ok(SymbolKind::Unknown), + _ => Err(format!( + "unsupported kind: '{}'\n supported kinds: function, method, struct, enum, trait, interface, typealias, constant, variable, class, module, import, unknown", + kind + )), + } +} + +fn lookup_db_path(root: &Path) -> PathBuf { + memory_db_path(root) +} + +fn open_lookup_db(root: &Path) -> std::result::Result { + let db_path = lookup_db_path(root); + if !db_path.exists() { + return Err(format!( + "no structural index found at {}\n hint: run dora --persist {} first", + db_path.display(), + root.display() + )); + } + + MemoryDb::open(&db_path).map_err(|error| format!("error: {error}")) +} + +fn filter_lookup_rows_by_language( + rows: Vec<(SymbolRow, FileRow)>, + lang: &str, +) -> std::result::Result, String> { + if lang == "auto" { + return Ok(rows); + } + + let desired = lang.to_string(); + Ok(rows.into_iter().filter(|(_, file)| file.language == desired).collect()) +} + +fn execute_symbol_lookup( + db: &MemoryDb, + symbol: &str, + kind: Option<&SymbolKind>, + lang: &str, +) -> std::result::Result, String> { + let symbols = match kind { + Some(kind) => db + .find_symbols_by_name_and_kind(symbol, kind) + .map_err(|error| format!("error: {error}"))?, + None => db.find_symbols_by_name(symbol).map_err(|error| format!("error: {error}"))?, + }; + + collect_lookup_rows(db, symbols, lang) +} + +fn execute_prefix_lookup( + db: &MemoryDb, + prefix: &str, + kind: Option<&SymbolKind>, + lang: &str, +) -> std::result::Result, String> { + let symbols = + db.find_symbols_by_name_prefix(prefix).map_err(|error| format!("error: {error}"))?; + let symbols = if let Some(kind) = kind { + symbols.into_iter().filter(|symbol| &symbol.kind == kind).collect() + } else { + symbols + }; + + collect_lookup_rows(db, symbols, lang) +} + +fn collect_lookup_rows( + db: &MemoryDb, + symbols: Vec, + lang: &str, +) -> std::result::Result, String> { + let mut rows = Vec::new(); + + for symbol in symbols { + let file = db + .get_file_by_id(symbol.file_id) + .map_err(|error| format!("error: {error}"))? + .ok_or_else(|| format!("error: missing file row for file_id {}", symbol.file_id))?; + rows.push((symbol, file)); + } + + let mut filtered = filter_lookup_rows_by_language(rows, lang)?; + filtered.sort_by(|left, right| { + left.1 + .path + .cmp(&right.1.path) + .then_with(|| left.0.start_line.cmp(&right.0.start_line)) + .then_with(|| left.0.start_col.cmp(&right.0.start_col)) + .then_with(|| left.0.name.cmp(&right.0.name)) + }); + Ok(filtered) +} + +fn run_lookup_mode(args: &LookupArgs) { + if let Err(message) = args.validate() { + eprintln!("error: {message}"); + process::exit(1); + } + + let color = resolve_color_mode(args.no_color); + let db = match open_lookup_db(&args.path) { + Ok(db) => db, + Err(message) => { + eprintln!("error: {message}"); + process::exit(1); + } + }; + + let kind = match args.kind.as_deref() { + Some(kind) => match parse_lookup_kind(kind) { + Ok(kind) => Some(kind), + Err(message) => { + eprintln!("error: {message}"); + process::exit(1); + } + }, + None => None, + }; + + let results = match (args.symbol.as_deref(), args.prefix.as_deref()) { + (Some(symbol), None) => execute_symbol_lookup(&db, symbol, kind.as_ref(), &args.lang), + (None, Some(prefix)) => execute_prefix_lookup(&db, prefix, kind.as_ref(), &args.lang), + _ => Err("specify exactly one of --symbol or --prefix".to_string()), + }; + + let results = match results { + Ok(results) => results, + Err(message) => { + eprintln!("error: {message}"); + process::exit(1); + } + }; + + let mut stdout = std::io::stdout().lock(); + print_lookup_results(&results, &color, &mut stdout); +} + fn resolve_lang(lang_str: &str) -> Language { match lang_str { "rust" => Language::Rust, @@ -599,10 +842,16 @@ fn run_search( } fn main() { - let cli = Cli::parse(); + let app = App::parse(); + let cli = &app.cli; + + if let Some(Commands::Lookup(args)) = &app.command { + run_lookup_mode(args); + return; + } if let Some(shell) = cli.generate_completions { - let mut cmd = Cli::command(); + let mut cmd = App::command(); generate(shell, &mut cmd, "dora", &mut std::io::stdout()); process::exit(0); } @@ -822,7 +1071,7 @@ fn run_rewrite_mode( mod tests { use super::{ format_file_error, handle_file_error, resolve_lang, resolve_lang_mode, Cli, FileError, - SearchOutcome, + LookupArgs, SearchOutcome, }; use crate::types::{LangMode, Language}; use clap_complete::Shell; @@ -830,6 +1079,28 @@ mod tests { use std::sync::Mutex; use std::time::{Duration, Instant}; + fn lookup_args_with_symbol(symbol: &str) -> LookupArgs { + LookupArgs { + symbol: Some(symbol.to_string()), + prefix: None, + kind: None, + path: std::env::temp_dir(), + lang: "auto".to_string(), + no_color: false, + } + } + + fn lookup_args_with_prefix(prefix: &str) -> LookupArgs { + LookupArgs { + symbol: None, + prefix: Some(prefix.to_string()), + kind: None, + path: std::env::temp_dir(), + lang: "auto".to_string(), + no_color: false, + } + } + #[test] fn test_format_file_error_walker_known_path() { let error = FileError::WalkerAccess { @@ -1006,6 +1277,71 @@ mod tests { assert_eq!(err_msg, "at least one query string must not be empty"); } + #[test] + fn test_lookup_validate_exact_symbol_only() { + let args = lookup_args_with_symbol("authenticate"); + assert!(args.validate().is_ok()); + } + + #[test] + fn test_lookup_validate_prefix_only() { + let args = lookup_args_with_prefix("auth"); + assert!(args.validate().is_ok()); + } + + #[test] + fn test_lookup_validate_requires_exactly_one_selector() { + let args = LookupArgs { + symbol: None, + prefix: None, + kind: None, + path: std::env::temp_dir(), + lang: "auto".to_string(), + no_color: false, + }; + let err = args.validate().unwrap_err(); + assert!(err.contains("exactly one")); + + let args = LookupArgs { + symbol: Some("one".to_string()), + prefix: Some("two".to_string()), + kind: None, + path: std::env::temp_dir(), + lang: "auto".to_string(), + no_color: false, + }; + let err = args.validate().unwrap_err(); + assert!(err.contains("exactly one")); + } + + #[test] + fn test_lookup_validate_rejects_bad_kind() { + let args = LookupArgs { + symbol: Some("foo".to_string()), + prefix: None, + kind: Some("not_a_kind".to_string()), + path: std::env::temp_dir(), + lang: "auto".to_string(), + no_color: false, + }; + let err = args.validate().unwrap_err(); + assert!(err.contains("unsupported kind")); + } + + #[test] + fn test_lookup_validate_rejects_bad_path() { + let args = LookupArgs { + symbol: Some("foo".to_string()), + prefix: None, + kind: None, + path: PathBuf::from("/tmp/dora_lookup_missing_dir_12345"), + lang: "auto".to_string(), + no_color: false, + }; + let err = args.validate().unwrap_err(); + assert!(err.contains("does not exist")); + } + #[test] fn test_resolve_lang_all_supported() { assert_eq!(resolve_lang("rust"), Language::Rust); diff --git a/src/memory.rs b/src/memory.rs new file mode 100644 index 0000000..9991981 --- /dev/null +++ b/src/memory.rs @@ -0,0 +1,706 @@ +#![allow(clippy::missing_errors_doc, clippy::must_use_candidate)] + +use crate::types::{AppError, Result}; +use rusqlite::{params, Connection, OptionalExtension}; +use serde::{Deserialize, Serialize}; +use std::convert::Infallible; +use std::fmt; +use std::path::{Path, PathBuf}; +use std::str::FromStr; +use std::time::{SystemTime, UNIX_EPOCH}; + +pub const MEMORY_DB_FILENAME: &str = ".ast-search-memory.db"; + +#[must_use] +pub fn memory_db_path(root: &Path) -> PathBuf { + root.join(MEMORY_DB_FILENAME) +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum SymbolKind { + Function, + Method, + Struct, + Enum, + Trait, + Interface, + TypeAlias, + Constant, + Variable, + Class, + Module, + Import, + Unknown, +} + +impl fmt::Display for SymbolKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Self::Function => "function", + Self::Method => "method", + Self::Struct => "struct", + Self::Enum => "enum", + Self::Trait => "trait", + Self::Interface => "interface", + Self::TypeAlias => "typealias", + Self::Constant => "constant", + Self::Variable => "variable", + Self::Class => "class", + Self::Module => "module", + Self::Import => "import", + Self::Unknown => "unknown", + }) + } +} + +impl FromStr for SymbolKind { + type Err = Infallible; + + fn from_str(value: &str) -> std::result::Result { + Ok(match value { + "function" => Self::Function, + "method" => Self::Method, + "struct" => Self::Struct, + "enum" => Self::Enum, + "trait" => Self::Trait, + "interface" => Self::Interface, + "typealias" => Self::TypeAlias, + "constant" => Self::Constant, + "variable" => Self::Variable, + "class" => Self::Class, + "module" => Self::Module, + "import" => Self::Import, + _ => Self::Unknown, + }) + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct FileRow { + pub id: i64, + pub path: String, + pub mtime: i64, + pub language: String, + pub indexed_at: i64, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct SymbolRow { + pub id: i64, + pub file_id: i64, + pub kind: SymbolKind, + pub name: String, + pub start_line: usize, + pub start_col: usize, + pub end_line: usize, + pub end_col: usize, + pub signature: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct NewFileRow { + pub path: String, + pub mtime: i64, + pub language: String, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct NewSymbolRow { + pub file_id: i64, + pub kind: SymbolKind, + pub name: String, + pub start_line: usize, + pub start_col: usize, + pub end_line: usize, + pub end_col: usize, + pub signature: Option, +} + +pub struct MemoryDb { + conn: Connection, +} + +impl MemoryDb { + pub fn open(db_path: &Path) -> Result { + let conn = Connection::open(db_path).map_err(db_error)?; + let db = Self { conn }; + db.initialize_schema()?; + Ok(db) + } + + pub fn open_in_memory() -> Result { + let conn = Connection::open_in_memory().map_err(db_error)?; + let db = Self { conn }; + db.initialize_schema()?; + Ok(db) + } + + fn initialize_schema(&self) -> Result<()> { + self.conn + .execute_batch( + "PRAGMA foreign_keys=ON;\nPRAGMA journal_mode=WAL;\nCREATE TABLE IF NOT EXISTS files (\n id INTEGER PRIMARY KEY AUTOINCREMENT,\n path TEXT NOT NULL UNIQUE,\n mtime INTEGER NOT NULL,\n language TEXT NOT NULL,\n indexed_at INTEGER NOT NULL\n);\nCREATE INDEX IF NOT EXISTS idx_files_path ON files(path);\nCREATE TABLE IF NOT EXISTS symbols (\n id INTEGER PRIMARY KEY AUTOINCREMENT,\n file_id INTEGER NOT NULL REFERENCES files(id) ON DELETE CASCADE,\n kind TEXT NOT NULL,\n name TEXT NOT NULL,\n start_line INTEGER NOT NULL,\n start_col INTEGER NOT NULL,\n end_line INTEGER NOT NULL,\n end_col INTEGER NOT NULL,\n signature TEXT\n);\nCREATE INDEX IF NOT EXISTS idx_symbols_name ON symbols(name);\nCREATE INDEX IF NOT EXISTS idx_symbols_file_id ON symbols(file_id);\nCREATE INDEX IF NOT EXISTS idx_symbols_kind ON symbols(kind);", + ) + .map_err(db_error) + } + + pub fn upsert_file(&self, row: &NewFileRow) -> Result { + let indexed_at = unix_seconds_now()?; + self.conn + .execute( + "INSERT OR REPLACE INTO files(path, mtime, language, indexed_at) VALUES (?1, ?2, ?3, ?4)", + params![&row.path, row.mtime, &row.language, indexed_at], + ) + .map_err(db_error)?; + Ok(self.conn.last_insert_rowid()) + } + + pub fn get_file_by_path(&self, path: &str) -> Result> { + self.conn + .query_row( + "SELECT id, path, mtime, language, indexed_at FROM files WHERE path = ?1", + params![path], + |row| { + Ok(FileRow { + id: row.get(0)?, + path: row.get(1)?, + mtime: row.get(2)?, + language: row.get(3)?, + indexed_at: row.get(4)?, + }) + }, + ) + .optional() + .map_err(db_error) + } + + pub fn get_file_by_id(&self, id: i64) -> Result> { + self.conn + .query_row( + "SELECT id, path, mtime, language, indexed_at FROM files WHERE id = ?1", + params![id], + |row| { + Ok(FileRow { + id: row.get(0)?, + path: row.get(1)?, + mtime: row.get(2)?, + language: row.get(3)?, + indexed_at: row.get(4)?, + }) + }, + ) + .optional() + .map_err(db_error) + } + + pub fn delete_file_by_path(&self, path: &str) -> Result { + self.conn.execute("DELETE FROM files WHERE path = ?1", params![path]).map_err(db_error) + } + + pub fn list_files(&self) -> Result> { + let mut stmt = self + .conn + .prepare("SELECT id, path, mtime, language, indexed_at FROM files ORDER BY path ASC") + .map_err(db_error)?; + let rows = stmt + .query_map([], |row| { + Ok(FileRow { + id: row.get(0)?, + path: row.get(1)?, + mtime: row.get(2)?, + language: row.get(3)?, + indexed_at: row.get(4)?, + }) + }) + .map_err(db_error)?; + collect_rows(rows) + } + + pub fn file_count(&self) -> Result { + let count: i64 = self + .conn + .query_row("SELECT COUNT(*) FROM files", [], |row| row.get(0)) + .map_err(db_error)?; + i64_to_usize(count) + } + + pub fn insert_symbol(&self, row: &NewSymbolRow) -> Result { + self.conn + .execute( + "INSERT INTO symbols(file_id, kind, name, start_line, start_col, end_line, end_col, signature) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", + params![ + row.file_id, + row.kind.to_string(), + &row.name, + usize_to_i64(row.start_line)?, + usize_to_i64(row.start_col)?, + usize_to_i64(row.end_line)?, + usize_to_i64(row.end_col)?, + row.signature.as_deref(), + ], + ) + .map_err(db_error)?; + Ok(self.conn.last_insert_rowid()) + } + + pub fn insert_symbols_batch(&self, rows: &[NewSymbolRow]) -> Result { + self.conn.execute_batch("BEGIN IMMEDIATE TRANSACTION;").map_err(db_error)?; + let mut inserted = 0usize; + let result = (|| -> Result { + for row in rows { + self.conn + .execute( + "INSERT INTO symbols(file_id, kind, name, start_line, start_col, end_line, end_col, signature) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", + params![ + row.file_id, + row.kind.to_string(), + &row.name, + usize_to_i64(row.start_line)?, + usize_to_i64(row.start_col)?, + usize_to_i64(row.end_line)?, + usize_to_i64(row.end_col)?, + row.signature.as_deref(), + ], + ) + .map_err(db_error)?; + inserted += 1; + } + self.conn.execute_batch("COMMIT;").map_err(db_error)?; + Ok(inserted) + })(); + if let Err(error) = result { + let _ = self.conn.execute_batch("ROLLBACK;"); + return Err(error); + } + result + } + + pub fn get_symbols_for_file(&self, file_id: i64) -> Result> { + self.query_symbols( + "SELECT id, file_id, kind, name, start_line, start_col, end_line, end_col, signature FROM symbols WHERE file_id = ?1 ORDER BY start_line ASC, start_col ASC", + params![file_id], + ) + } + + pub fn find_symbols_by_name(&self, name: &str) -> Result> { + self.query_symbols( + "SELECT id, file_id, kind, name, start_line, start_col, end_line, end_col, signature FROM symbols WHERE name = ?1 ORDER BY file_id ASC, start_line ASC, start_col ASC", + params![name], + ) + } + + pub fn find_symbols_by_name_prefix(&self, prefix: &str) -> Result> { + let pattern = format!("{prefix}%"); + self.query_symbols( + "SELECT id, file_id, kind, name, start_line, start_col, end_line, end_col, signature FROM symbols WHERE name LIKE ?1 ORDER BY name ASC, file_id ASC LIMIT 100", + params![pattern], + ) + } + + pub fn find_symbols_by_kind(&self, kind: &SymbolKind) -> Result> { + self.query_symbols( + "SELECT id, file_id, kind, name, start_line, start_col, end_line, end_col, signature FROM symbols WHERE kind = ?1 ORDER BY name ASC, file_id ASC", + params![kind.to_string()], + ) + } + + pub fn find_symbols_by_name_and_kind( + &self, + name: &str, + kind: &SymbolKind, + ) -> Result> { + self.query_symbols( + "SELECT id, file_id, kind, name, start_line, start_col, end_line, end_col, signature FROM symbols WHERE name = ?1 AND kind = ?2 ORDER BY file_id ASC, start_line ASC, start_col ASC", + params![name, kind.to_string()], + ) + } + + pub fn delete_symbols_for_file(&self, file_id: i64) -> Result { + self.conn + .execute("DELETE FROM symbols WHERE file_id = ?1", params![file_id]) + .map_err(db_error) + } + + pub fn symbol_count(&self) -> Result { + let count: i64 = self + .conn + .query_row("SELECT COUNT(*) FROM symbols", [], |row| row.get(0)) + .map_err(db_error)?; + i64_to_usize(count) + } + + fn query_symbols(&self, sql: &str, params: impl rusqlite::Params) -> Result> { + let mut stmt = self.conn.prepare(sql).map_err(db_error)?; + let rows = stmt + .query_map(params, |row| { + Ok(( + row.get::<_, i64>(0)?, + row.get::<_, i64>(1)?, + row.get::<_, String>(2)?, + row.get::<_, String>(3)?, + row.get::<_, i64>(4)?, + row.get::<_, i64>(5)?, + row.get::<_, i64>(6)?, + row.get::<_, i64>(7)?, + row.get::<_, Option>(8)?, + )) + }) + .map_err(db_error)?; + let mut symbols = Vec::new(); + for row in rows { + let (id, file_id, kind, name, start_line, start_col, end_line, end_col, signature) = + row.map_err(db_error)?; + symbols.push(SymbolRow { + id, + file_id, + kind: kind.parse().unwrap(), + name, + start_line: i64_to_usize(start_line)?, + start_col: i64_to_usize(start_col)?, + end_line: i64_to_usize(end_line)?, + end_col: i64_to_usize(end_col)?, + signature, + }); + } + Ok(symbols) + } +} + +#[allow(clippy::needless_pass_by_value)] +fn db_error(error: rusqlite::Error) -> AppError { + AppError::DbError(error.to_string()) +} + +fn unix_seconds_now() -> Result { + let seconds = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|error| AppError::DbError(error.to_string()))? + .as_secs(); + i64::try_from(seconds).map_err(|_| AppError::DbError("timestamp out of range".to_string())) +} + +fn usize_to_i64(value: usize) -> Result { + i64::try_from(value) + .map_err(|_| AppError::DbError(format!("value out of range for i64: {value}"))) +} + +fn i64_to_usize(value: i64) -> Result { + usize::try_from(value) + .map_err(|_| AppError::DbError(format!("value out of range for usize: {value}"))) +} + +fn collect_rows( + rows: rusqlite::MappedRows<'_, impl FnMut(&rusqlite::Row<'_>) -> rusqlite::Result>, +) -> Result> { + let mut items = Vec::new(); + for row in rows { + items.push(row.map_err(db_error)?); + } + Ok(items) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_db() -> MemoryDb { + MemoryDb::open_in_memory().unwrap() + } + + fn make_file_row(path: &str, mtime: i64) -> NewFileRow { + NewFileRow { path: path.to_string(), mtime, language: "rust".to_string() } + } + + fn insert_file(db: &MemoryDb, path: &str, mtime: i64) -> i64 { + db.upsert_file(&make_file_row(path, mtime)).unwrap() + } + + fn make_symbol_row( + file_id: i64, + kind: SymbolKind, + name: &str, + start_line: usize, + ) -> NewSymbolRow { + NewSymbolRow { + file_id, + kind, + name: name.to_string(), + start_line, + start_col: 0, + end_line: start_line, + end_col: 1, + signature: None, + } + } + + #[test] + fn test_schema_creates_without_error() { + assert!(MemoryDb::open_in_memory().is_ok()); + } + + #[test] + fn test_upsert_file_returns_id() { + let db = make_db(); + let id = db.upsert_file(&make_file_row("/tmp/a.rs", 100)).unwrap(); + assert!(id > 0); + } + + #[test] + fn test_get_file_by_path_after_upsert() { + let db = make_db(); + db.upsert_file(&make_file_row("/tmp/a.rs", 100)).unwrap(); + let row = db.get_file_by_path("/tmp/a.rs").unwrap().unwrap(); + assert_eq!(row.path, "/tmp/a.rs"); + assert_eq!(row.mtime, 100); + assert_eq!(row.language, "rust"); + } + + #[test] + fn test_get_file_by_path_returns_none_for_missing() { + let db = make_db(); + assert!(db.get_file_by_path("nonexistent").unwrap().is_none()); + } + + #[test] + fn test_get_file_by_id_after_upsert() { + let db = make_db(); + let id = db.upsert_file(&make_file_row("/tmp/a.rs", 100)).unwrap(); + let row = db.get_file_by_id(id).unwrap().unwrap(); + assert_eq!(row.id, id); + assert_eq!(row.path, "/tmp/a.rs"); + assert_eq!(row.language, "rust"); + } + + #[test] + fn test_get_file_by_id_returns_none_for_missing() { + let db = make_db(); + assert!(db.get_file_by_id(9999).unwrap().is_none()); + } + + #[test] + fn test_upsert_file_replaces_existing() { + let db = make_db(); + insert_file(&db, "/tmp/a.rs", 100); + insert_file(&db, "/tmp/a.rs", 200); + let row = db.get_file_by_path("/tmp/a.rs").unwrap().unwrap(); + assert_eq!(row.mtime, 200); + assert_eq!(db.file_count().unwrap(), 1); + } + + #[test] + fn test_delete_file_removes_row() { + let db = make_db(); + insert_file(&db, "/tmp/a.rs", 100); + db.delete_file_by_path("/tmp/a.rs").unwrap(); + assert!(db.get_file_by_path("/tmp/a.rs").unwrap().is_none()); + } + + #[test] + fn test_delete_cascades_to_symbols() { + let db = make_db(); + let file_id = insert_file(&db, "/tmp/a.rs", 100); + db.insert_symbol(&make_symbol_row(file_id, SymbolKind::Function, "one", 1)).unwrap(); + db.insert_symbol(&make_symbol_row(file_id, SymbolKind::Struct, "two", 2)).unwrap(); + db.delete_file_by_path("/tmp/a.rs").unwrap(); + assert_eq!(db.symbol_count().unwrap(), 0); + } + + #[test] + fn test_insert_symbol_returns_id() { + let db = make_db(); + let file_id = insert_file(&db, "/tmp/a.rs", 100); + let id = + db.insert_symbol(&make_symbol_row(file_id, SymbolKind::Function, "one", 1)).unwrap(); + assert!(id > 0); + } + + #[test] + fn test_get_symbols_for_file_returns_all() { + let db = make_db(); + let file_id = insert_file(&db, "/tmp/a.rs", 100); + db.insert_symbol(&make_symbol_row(file_id, SymbolKind::Function, "one", 1)).unwrap(); + db.insert_symbol(&make_symbol_row(file_id, SymbolKind::Struct, "two", 2)).unwrap(); + db.insert_symbol(&make_symbol_row(file_id, SymbolKind::Enum, "three", 3)).unwrap(); + let symbols = db.get_symbols_for_file(file_id).unwrap(); + assert_eq!(symbols.len(), 3); + } + + #[test] + fn test_find_symbols_by_name_exact() { + let db = make_db(); + let file_one = insert_file(&db, "/tmp/a.rs", 100); + let file_two = insert_file(&db, "/tmp/b.rs", 100); + db.insert_symbol(&make_symbol_row(file_one, SymbolKind::Function, "foo", 1)).unwrap(); + db.insert_symbol(&make_symbol_row(file_two, SymbolKind::Function, "bar", 1)).unwrap(); + db.insert_symbol(&make_symbol_row(file_two, SymbolKind::Struct, "foo", 2)).unwrap(); + let symbols = db.find_symbols_by_name("foo").unwrap(); + assert_eq!(symbols.len(), 2); + } + + #[test] + fn test_find_symbols_by_name_prefix() { + let db = make_db(); + let file_id = insert_file(&db, "/tmp/a.rs", 100); + db.insert_symbol(&make_symbol_row(file_id, SymbolKind::Function, "authenticate", 1)) + .unwrap(); + db.insert_symbol(&make_symbol_row(file_id, SymbolKind::Function, "authorize", 2)).unwrap(); + db.insert_symbol(&make_symbol_row(file_id, SymbolKind::Function, "connect", 3)).unwrap(); + let symbols = db.find_symbols_by_name_prefix("auth").unwrap(); + assert_eq!(symbols.len(), 2); + } + + #[test] + fn test_find_symbols_by_kind() { + let db = make_db(); + let file_id = insert_file(&db, "/tmp/a.rs", 100); + db.insert_symbol(&make_symbol_row(file_id, SymbolKind::Function, "foo", 1)).unwrap(); + db.insert_symbol(&make_symbol_row(file_id, SymbolKind::Struct, "bar", 2)).unwrap(); + let symbols = db.find_symbols_by_kind(&SymbolKind::Function).unwrap(); + assert!(symbols.iter().all(|symbol| symbol.kind == SymbolKind::Function)); + } + + #[test] + fn test_find_symbols_by_name_and_kind() { + let db = make_db(); + let file_id = insert_file(&db, "/tmp/a.rs", 100); + db.insert_symbol(&make_symbol_row(file_id, SymbolKind::Function, "new", 1)).unwrap(); + db.insert_symbol(&make_symbol_row(file_id, SymbolKind::Struct, "new", 2)).unwrap(); + let symbols = db.find_symbols_by_name_and_kind("new", &SymbolKind::Function).unwrap(); + assert_eq!(symbols.len(), 1); + assert_eq!(symbols[0].kind, SymbolKind::Function); + } + + #[test] + fn test_insert_symbols_batch_is_atomic() { + let db = make_db(); + let file_id = insert_file(&db, "/tmp/a.rs", 100); + let rows = vec![ + make_symbol_row(file_id, SymbolKind::Function, "one", 1), + make_symbol_row(file_id, SymbolKind::Struct, "two", 2), + make_symbol_row(file_id, SymbolKind::Enum, "three", 3), + ]; + let inserted = db.insert_symbols_batch(&rows).unwrap(); + assert_eq!(inserted, 3); + assert_eq!(db.symbol_count().unwrap(), 3); + } + + #[test] + fn test_delete_symbols_for_file() { + let db = make_db(); + let file_id = insert_file(&db, "/tmp/a.rs", 100); + db.insert_symbol(&make_symbol_row(file_id, SymbolKind::Function, "one", 1)).unwrap(); + db.insert_symbol(&make_symbol_row(file_id, SymbolKind::Struct, "two", 2)).unwrap(); + db.insert_symbol(&make_symbol_row(file_id, SymbolKind::Enum, "three", 3)).unwrap(); + db.delete_symbols_for_file(file_id).unwrap(); + assert_eq!(db.symbol_count().unwrap(), 0); + } + + #[test] + fn test_symbol_kind_display_roundtrip() { + let kinds = [ + SymbolKind::Function, + SymbolKind::Method, + SymbolKind::Struct, + SymbolKind::Enum, + SymbolKind::Trait, + SymbolKind::Interface, + SymbolKind::TypeAlias, + SymbolKind::Constant, + SymbolKind::Variable, + SymbolKind::Class, + SymbolKind::Module, + SymbolKind::Import, + SymbolKind::Unknown, + ]; + + for kind in kinds { + let text = kind.to_string(); + let parsed: SymbolKind = text.parse().unwrap(); + assert_eq!(kind, parsed); + } + } + + #[test] + fn test_symbol_kind_unknown_for_garbage_string() { + let parsed: SymbolKind = "not_a_real_kind".parse().unwrap(); + assert_eq!(parsed, SymbolKind::Unknown); + } + + #[test] + fn test_find_symbols_prefix_limit() { + let db = make_db(); + let file_id = insert_file(&db, "/tmp/a.rs", 100); + let mut rows = Vec::new(); + for i in 0..150usize { + rows.push(make_symbol_row( + file_id, + SymbolKind::Function, + &format!("sym_{i:03}"), + i + 1, + )); + } + db.insert_symbols_batch(&rows).unwrap(); + let symbols = db.find_symbols_by_name_prefix("sym_").unwrap(); + assert!(symbols.len() <= 100); + } + + #[test] + fn test_list_files_sorted_by_path() { + let db = make_db(); + insert_file(&db, "/tmp/z.rs", 100); + insert_file(&db, "/tmp/a.rs", 100); + insert_file(&db, "/tmp/m.rs", 100); + let files = db.list_files().unwrap(); + let paths: Vec<_> = files.iter().map(|row| row.path.as_str()).collect(); + assert_eq!(paths, vec!["/tmp/a.rs", "/tmp/m.rs", "/tmp/z.rs"]); + } + + #[test] + fn test_file_count_accurate() { + let db = make_db(); + insert_file(&db, "/tmp/a.rs", 100); + insert_file(&db, "/tmp/b.rs", 100); + insert_file(&db, "/tmp/c.rs", 100); + insert_file(&db, "/tmp/d.rs", 100); + insert_file(&db, "/tmp/e.rs", 100); + assert_eq!(db.file_count().unwrap(), 5); + db.delete_file_by_path("/tmp/a.rs").unwrap(); + db.delete_file_by_path("/tmp/b.rs").unwrap(); + assert_eq!(db.file_count().unwrap(), 3); + } + + #[test] + fn test_symbol_count_accurate() { + let db = make_db(); + let file_id = insert_file(&db, "/tmp/a.rs", 100); + for i in 0..4usize { + db.insert_symbol(&make_symbol_row( + file_id, + SymbolKind::Function, + &format!("s{i}"), + i + 1, + )) + .unwrap(); + } + assert_eq!(db.symbol_count().unwrap(), 4); + } + + #[test] + fn test_get_symbols_for_file_ordered_by_position() { + let db = make_db(); + let file_id = insert_file(&db, "/tmp/a.rs", 100); + db.insert_symbol(&make_symbol_row(file_id, SymbolKind::Function, "three", 10)).unwrap(); + db.insert_symbol(&make_symbol_row(file_id, SymbolKind::Function, "two", 5)).unwrap(); + db.insert_symbol(&make_symbol_row(file_id, SymbolKind::Function, "one", 1)).unwrap(); + let symbols = db.get_symbols_for_file(file_id).unwrap(); + let lines: Vec<_> = symbols.iter().map(|row| row.start_line).collect(); + assert_eq!(lines, vec![1, 5, 10]); + } + + #[test] + fn test_wal_mode_enabled() { + let db = make_db(); + let mode: String = db.conn.query_row("PRAGMA journal_mode", [], |row| row.get(0)).unwrap(); + assert!(!mode.is_empty()); + assert!(mode == "memory" || mode == "wal"); + } +} diff --git a/src/output.rs b/src/output.rs index ee09ce0..43bb6af 100644 --- a/src/output.rs +++ b/src/output.rs @@ -1,3 +1,4 @@ +use crate::memory::{FileRow, SymbolRow}; use crate::types::MatchResult; use std::io::Write; use std::time::Duration; @@ -122,6 +123,32 @@ pub fn print_match(result: &MatchResult, color: &ColorMode, writer: &m writer.write_all(line.as_bytes()).expect("failed to write match output"); } +pub fn print_lookup_results( + results: &[(SymbolRow, FileRow)], + color: &ColorMode, + writer: &mut W, +) { + for (symbol, file) in results { + let filepath = colorize(&file.path, CYAN, color); + let kind = colorize(&symbol.kind.to_string(), YELLOW, color); + let name = colorize(&symbol.name, GREEN, color); + let line = format!( + "{filepath}:{line}:{col} [@{kind}] \"{name}\"\n", + line = symbol.start_line, + col = symbol.start_col, + ); + writer.write_all(line.as_bytes()).expect("failed to write lookup output"); + if matches!(color, ColorMode::On) { + if let Some(signature) = &symbol.signature { + let signature_line = format!(" signature: {signature}\n"); + writer + .write_all(signature_line.as_bytes()) + .expect("failed to write lookup signature output"); + } + } + } +} + /// Build the summary string (printed to stderr) without emitting it. /// /// Always formats duration as milliseconds and chooses singular/plural @@ -191,9 +218,10 @@ pub fn print_summary( #[cfg(test)] mod tests { use super::{ - colorize, format_match, format_summary, plural_file, plural_match, print_match, - print_summary, resolve_color_mode, ColorMode, CYAN, GREEN, YELLOW, + colorize, format_match, format_summary, plural_file, plural_match, print_lookup_results, + print_match, print_summary, resolve_color_mode, ColorMode, CYAN, GREEN, YELLOW, }; + use crate::memory::{FileRow, SymbolKind, SymbolRow}; use crate::types::MatchResult; use std::io::Write; use std::path::PathBuf; @@ -217,6 +245,29 @@ mod tests { } } + fn canonical_lookup_rows() -> (SymbolRow, FileRow) { + ( + SymbolRow { + id: 7, + file_id: 3, + kind: SymbolKind::Function, + name: "authenticate".to_string(), + start_line: 42, + start_col: 4, + end_line: 42, + end_col: 16, + signature: Some("fn authenticate(user: User) -> bool".to_string()), + }, + FileRow { + id: 3, + path: "src/auth/handler.rs".to_string(), + mtime: 1, + language: "rust".to_string(), + indexed_at: 1, + }, + ) + } + struct FailWriter; impl Write for FailWriter { @@ -618,4 +669,30 @@ mod tests { assert!(output.contains("fn foo()")); assert!(output.lines().any(|line| line.contains("block"))); } + + #[test] + fn test_print_lookup_results_writes_match_like_output() { + let (symbol, file) = canonical_lookup_rows(); + let mut buf: Vec = Vec::new(); + print_lookup_results(&[(symbol, file)], &ColorMode::Off, &mut buf); + let output = buf_to_string(buf); + + assert!(output.contains("src/auth/handler.rs:42:4")); + assert!(output.contains("[@function]")); + assert!(output.contains("\"authenticate\"")); + assert!(!output.contains("signature:")); + } + + #[test] + fn test_print_lookup_results_prints_signature_only_with_color() { + let (symbol, file) = canonical_lookup_rows(); + let mut buf: Vec = Vec::new(); + print_lookup_results(&[(symbol, file)], &ColorMode::On, &mut buf); + let output = buf_to_string(buf); + + assert!(output.contains("signature: fn authenticate(user: User) -> bool")); + assert!(output.contains("\x1b[36m")); + assert!(output.contains("\x1b[33m")); + assert!(output.contains("\x1b[32m")); + } } diff --git a/src/types.rs b/src/types.rs index 8aa3906..7bd6bae 100644 --- a/src/types.rs +++ b/src/types.rs @@ -57,6 +57,8 @@ pub enum AppError { #[error("Language not supported: {0}")] LanguageNotSupported(String), + #[error("database error: {0}")] + DbError(String), #[error("index file is corrupt or unreadable: {0}")] IndexCorrupt(String), diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 2af057a..aca9f10 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -1,3 +1,6 @@ +use dora::extractor::SymbolExtractor; +use dora::memory::{MemoryDb, NewFileRow, NewSymbolRow, SymbolKind}; +use dora::output::{print_lookup_results, ColorMode}; use dora::parser::{get_language, parse_file}; use dora::query::{compile_query, extract_matches}; use dora::types::{Language, MatchResult}; @@ -45,6 +48,41 @@ fn fixtures_dir() -> PathBuf { PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests").join("fixtures") } +#[test] +fn test_lookup_join_and_formatter_integration() { + let db = MemoryDb::open_in_memory().unwrap(); + let file_id = db + .upsert_file(&NewFileRow { + path: "/tmp/example.rs".to_string(), + mtime: 1, + language: "rust".to_string(), + }) + .unwrap(); + db.insert_symbol(&NewSymbolRow { + file_id, + kind: SymbolKind::Function, + name: "authenticate".to_string(), + start_line: 12, + start_col: 4, + end_line: 12, + end_col: 16, + signature: Some("fn authenticate(user: User) -> bool".to_string()), + }) + .unwrap(); + + let symbol = db.find_symbols_by_name("authenticate").unwrap().pop().unwrap(); + let file = db.get_file_by_id(symbol.file_id).unwrap().unwrap(); + + let mut buf = Vec::new(); + print_lookup_results(&[(symbol, file)], &ColorMode::Off, &mut buf); + let output = String::from_utf8(buf).unwrap(); + + assert!(output.contains("/tmp/example.rs:12:4")); + assert!(output.contains("[@function]")); + assert!(output.contains("\"authenticate\"")); + assert!(!output.contains("signature:")); +} + #[allow(dead_code)] fn run_pipeline_for_language( fixture_dir: &Path, @@ -2272,3 +2310,106 @@ fn test_rewrite_dry_run_does_not_modify_fixture() { let after = std::fs::read_to_string(&fixture).unwrap(); assert_eq!(before, after); } + +#[test] +fn test_persist_inserts_symbols_for_rust_fixture() { + let fixture = fixtures_dir().join("simple.rs"); + let db = MemoryDb::open_in_memory().unwrap(); + let file_id = db + .upsert_file(&NewFileRow { + path: fixture.display().to_string(), + mtime: 1, + language: "rust".to_string(), + }) + .unwrap(); + let ts_lang = get_language("rust").unwrap(); + let (tree, source) = parse_file(&fixture, &ts_lang).unwrap(); + let extractor = SymbolExtractor { language: Language::Rust }; + let symbols = extractor.extract(&tree, &source, file_id); + db.insert_symbols_batch(&symbols).unwrap(); + assert!(db.symbol_count().unwrap() > 0); + assert!(!db.find_symbols_by_name("add").unwrap().is_empty()); +} + +#[test] +fn test_persist_extracts_all_fixture_languages() { + let fixtures = [ + ("simple.rs", "rust", Language::Rust), + ("simple.py", "python", Language::Python), + ("simple.js", "js", Language::JavaScript), + ("simple.ts", "ts", Language::TypeScript), + ("simple.go", "go", Language::Go), + ("simple.c", "c", Language::C), + ("simple.cpp", "cpp", Language::Cpp), + ]; + + for (fixture_name, lang_str, language) in fixtures { + let fixture = fixtures_dir().join(fixture_name); + let ts_lang = get_language(lang_str).unwrap(); + let (tree, source) = parse_file(&fixture, &ts_lang).unwrap(); + let extractor = SymbolExtractor { language }; + let symbols = extractor.extract(&tree, &source, 1); + assert!(!symbols.is_empty(), "expected symbols for {fixture_name}"); + } +} + +#[test] +fn test_persist_symbol_positions_match_grep() { + let fixture = fixtures_dir().join("simple.rs"); + let ts_lang = get_language("rust").unwrap(); + let (tree, source) = parse_file(&fixture, &ts_lang).unwrap(); + let extractor = SymbolExtractor { language: Language::Rust }; + let symbols = extractor.extract(&tree, &source, 1); + let add = symbols.iter().find(|symbol| symbol.name == "add").unwrap(); + assert_eq!(add.start_line, 1); +} + +#[test] +fn test_persist_reindex_replaces_old_symbols() { + let db = MemoryDb::open_in_memory().unwrap(); + let file_id = db + .upsert_file(&NewFileRow { + path: "/tmp/reindex.rs".to_string(), + mtime: 1, + language: "rust".to_string(), + }) + .unwrap(); + let batch_a = vec![ + NewSymbolRow { + file_id, + kind: SymbolKind::Function, + name: "old_fn".to_string(), + start_line: 1, + start_col: 0, + end_line: 1, + end_col: 6, + signature: None, + }, + NewSymbolRow { + file_id, + kind: SymbolKind::Struct, + name: "OldStruct".to_string(), + start_line: 2, + start_col: 0, + end_line: 2, + end_col: 9, + signature: None, + }, + ]; + db.insert_symbols_batch(&batch_a).unwrap(); + db.delete_symbols_for_file(file_id).unwrap(); + let batch_b = vec![NewSymbolRow { + file_id, + kind: SymbolKind::Function, + name: "new_fn".to_string(), + start_line: 3, + start_col: 0, + end_line: 3, + end_col: 6, + signature: None, + }]; + db.insert_symbols_batch(&batch_b).unwrap(); + assert!(db.find_symbols_by_name("old_fn").unwrap().is_empty()); + assert!(db.find_symbols_by_name("OldStruct").unwrap().is_empty()); + assert_eq!(db.find_symbols_by_name("new_fn").unwrap().len(), 1); +}