diff --git a/src/command/profile.rs b/src/command/profile.rs index 76d361c..ae6f29e 100644 --- a/src/command/profile.rs +++ b/src/command/profile.rs @@ -86,19 +86,28 @@ async fn autocomplete_voice_name( ctx: Context<'_>, partial: &str, ) -> impl Iterator { - let candidates = ctx.data().registry.find_prefixed_all(partial); - candidates.map(|(id, package)| { - AutocompleteChoice::new( - match package.detail.description.as_ref() { - Some(description) => format!( - "{} | {} ({})", - package.detail.provider, package.detail.name, description - ), - None => format!("{} | {}", package.detail.provider, package.detail.name), - }, - id, - ) - }) + let keywords: Vec<&str> = partial.split_whitespace().filter(|s| *s != "|").collect(); + let candidates = ctx + .data() + .registry + .find_matching_keywords(keywords.as_ref()); + + candidates + .map(|(id, package)| { + AutocompleteChoice::new( + match package.detail.description.as_ref() { + Some(description) => format!( + "{} | {} ({})", + package.detail.provider, package.detail.name, description + ), + None => format!("{} | {}", package.detail.provider, package.detail.name), + }, + id, + ) + }) + .take(25) + .collect::>() + .into_iter() } async fn common_choose(ctx: Context<'_>, scope: Scope, name: String) -> Result<()> { diff --git a/src/tts/registry.rs b/src/tts/registry.rs index c44819b..b68b3ad 100644 --- a/src/tts/registry.rs +++ b/src/tts/registry.rs @@ -12,6 +12,15 @@ use std::sync::Arc; pub struct VoicePackage { pub voice: Arc, pub detail: VoiceDetail, + pub search_index: String, +} + +impl VoicePackage { + fn matches_keywords(&self, keywords: &[String]) -> bool { + keywords + .iter() + .all(|keyword| self.search_index.contains(keyword)) + } } #[derive(Clone)] @@ -45,6 +54,17 @@ impl VoicePackageRegistry { .filter(move |&(_, package)| package.detail.name.starts_with(prefix)) .map(|(id, voice)| (id.as_str(), voice)) } + + pub fn find_matching_keywords( + &self, + keywords: &[&str], + ) -> impl Iterator { + let normalized_keywords: Vec = keywords.iter().map(|s| s.to_lowercase()).collect(); + self.packages + .iter() + .filter(move |&(_, package)| package.matches_keywords(&normalized_keywords)) + .map(|(id, package)| (id.as_str(), package)) + } } pub struct VoiceRegistryBuilder { @@ -112,7 +132,22 @@ impl VoiceRegistryBuilder { } }; - voices.insert(id.to_string(), VoicePackage { voice, detail }); + let search_index = format!( + "{} {} {}", + detail.name, + detail.provider, + detail.description.as_deref().unwrap_or("") + ) + .to_lowercase(); + + voices.insert( + id.to_string(), + VoicePackage { + voice, + detail, + search_index, + }, + ); } Ok(VoicePackageRegistry::new(voices)) @@ -137,6 +172,7 @@ mod tests { use super::*; use crate::config::{ CacheConfig, DatabaseConfig, DatabaseKind, InMemoryCacheConfig, ProfileConfig, + VoiceDetailConfig, }; use crate::tts::google_cloud::GoogleCloudVoiceConfig; @@ -145,7 +181,10 @@ mod tests { profiles.insert( "test_preset".to_string(), ProfileConfig { - note: Default::default(), + note: Some(VoiceDetailConfig { + name: Some("ja-JP-Wavenet-A".to_string()), + description: Some("test description".to_string()), + }), voice_backend: ProfileBackendConfig::GoogleCloudVoice(GoogleCloudVoiceConfig { language_code: "ja-JP".to_string(), name: Some("ja-JP-Wavenet-A".to_string()), @@ -226,11 +265,39 @@ mod tests { } #[tokio::test] - async fn test_build_fail_missing_client() { + async fn test_find_matching_keywords() { let config = create_test_config(CacheConfig::Disabled); + let client = create_dummy_client().await; + + let registry = VoicePackageRegistry::builder(config) + .google_cloud(client) + .build() + .expect("Should build successfully"); - // build without client - let result = VoicePackageRegistry::builder(config).build(); - assert!(result.is_err()); + // "test" should match "test_preset" name + let keywords = vec!["test"]; + let results: Vec<_> = registry.find_matching_keywords(&keywords).collect(); + assert_eq!(results.len(), 1); + assert_eq!(results[0].0, "test_preset"); + + // "WAVENET" (uppercase) should match "ja-JP-Wavenet-A" (case-insensitive) + let keywords = vec!["WAVENET"]; + let results: Vec<_> = registry.find_matching_keywords(&keywords).collect(); + assert_eq!(results.len(), 1); + + // "google" should match provider + let keywords = vec!["google"]; + let results: Vec<_> = registry.find_matching_keywords(&keywords).collect(); + assert_eq!(results.len(), 1); + + // multiple keywords (AND) + let keywords = vec!["test", "google"]; + let results: Vec<_> = registry.find_matching_keywords(&keywords).collect(); + assert_eq!(results.len(), 1); + + // "nonexistent" should not match + let keywords = vec!["nonexistent"]; + let results: Vec<_> = registry.find_matching_keywords(&keywords).collect(); + assert_eq!(results.is_empty(), true); } }