diff --git a/crates/tui/src/tools/file_search.rs b/crates/tui/src/tools/file_search.rs index c417e81e4..ccc72d231 100644 --- a/crates/tui/src/tools/file_search.rs +++ b/crates/tui/src/tools/file_search.rs @@ -7,8 +7,9 @@ use async_trait::async_trait; use ignore::WalkBuilder; use serde::Serialize; use serde_json::{Value, json}; +use tokio_util::sync::CancellationToken; -use crate::tools::search::matches_glob; +use crate::tools::search::{check_cancelled, matches_glob}; use super::spec::{ ApprovalRequirement, ToolCapability, ToolContext, ToolError, ToolResult, ToolSpec, @@ -87,7 +88,14 @@ impl ToolSpec for FileSearchTool { let extensions = parse_extensions(&input); let exclude_patterns = parse_exclude_patterns(&input); - let matches = search_files(query, &base_path, extensions, exclude_patterns, limit)?; + let matches = search_files( + query, + &base_path, + extensions, + exclude_patterns, + limit, + context.cancel_token.as_ref(), + )?; ToolResult::json(&matches).map_err(|e| ToolError::execution_failed(e.to_string())) } } @@ -147,6 +155,7 @@ fn search_files( extensions: Vec, exclude_patterns: Vec, limit: usize, + cancel_token: Option<&CancellationToken>, ) -> Result, ToolError> { if !base_path.exists() { return Err(ToolError::invalid_input(format!( @@ -163,6 +172,8 @@ fn search_files( let walker = builder.build(); for entry in walker { + check_cancelled(cancel_token)?; + let entry = match entry { Ok(entry) => entry, Err(_) => continue, @@ -430,4 +441,24 @@ mod tests { assert!(result.success); assert!(!result.content.contains("secret.txt")); } + + #[tokio::test] + async fn test_file_search_respects_cancel_token() { + let tmp = tempdir().expect("tempdir"); + std::fs::write(tmp.path().join("needle.txt"), "x\n").expect("write"); + let cancel_token = CancellationToken::new(); + cancel_token.cancel(); + let ctx = ToolContext::new(tmp.path().to_path_buf()).with_cancel_token(cancel_token); + + let tool = FileSearchTool; + let err = tool + .execute(json!({"query": "needle"}), &ctx) + .await + .expect_err("cancelled file_search should return an error"); + + assert!( + format!("{err:?}").contains("cancelled"), + "unexpected error: {err:?}" + ); + } } diff --git a/crates/tui/src/tools/search.rs b/crates/tui/src/tools/search.rs index b4fc8d1f6..7840ecb32 100644 --- a/crates/tui/src/tools/search.rs +++ b/crates/tui/src/tools/search.rs @@ -347,7 +347,7 @@ fn collect_files_recursive( Ok(()) } -fn check_cancelled(cancel_token: Option<&CancellationToken>) -> Result<(), ToolError> { +pub(super) fn check_cancelled(cancel_token: Option<&CancellationToken>) -> Result<(), ToolError> { if cancel_token.is_some_and(CancellationToken::is_cancelled) { return Err(ToolError::execution_failed( "search cancelled before completion",