Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 32 additions & 11 deletions src/browser/approval_gate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ pub struct ApprovalGateMiddleware {
denial_counts: Mutex<HashMap<String, u32>>,
/// Set when the same action is denied twice — triggers session termination.
user_denied: AtomicBool,
/// When true, also listen for a spoken "yes/approve" alongside the keyboard reply.
voice: bool,
}

impl ApprovalGateMiddleware {
Expand All @@ -239,6 +241,7 @@ impl ApprovalGateMiddleware {
current_url: Arc<tokio::sync::Mutex<String>>,
approval_tx: mpsc::Sender<ApprovalPrompt>,
step_counter: Arc<AtomicU32>,
voice: bool,
) -> Self {
Self {
gate,
Expand All @@ -248,6 +251,7 @@ impl ApprovalGateMiddleware {
step_counter,
denial_counts: Mutex::new(HashMap::new()),
user_denied: AtomicBool::new(false),
voice,
}
}

Expand Down Expand Up @@ -313,7 +317,8 @@ impl ToolMiddleware for ApprovalGateMiddleware {
MiddlewareVerdict::Allow
}
GateVerdict::RequireConfirmation { reason, .. } => {
let step = self.step_counter.load(Ordering::Relaxed);
// +1 because the step_emitter middleware increments after the gate runs.
let step = self.step_counter.load(Ordering::Relaxed) + 1;
let (tx, rx) = oneshot::channel();
let prompt = ApprovalPrompt {
step,
Expand All @@ -330,14 +335,34 @@ impl ToolMiddleware for ApprovalGateMiddleware {
};
}
use tokio::time::{timeout, Duration};
match timeout(Duration::from_secs(60), rx).await {
Ok(Ok(true)) => {
// If voice is on, race the keyboard reply against a voice-approval
// listener. Whichever resolves first wins. Voice only contributes
// an Approve vote (false/timeout is ignored unless no keyboard reply
// arrives either).
let approved_opt: Option<bool> = if self.voice {
tokio::select! {
kb = timeout(Duration::from_secs(60), rx) => match kb {
Ok(Ok(b)) => Some(b),
_ => None,
},
voice_yes = crate::voice::await_voice_approval(60) => {
if voice_yes { Some(true) } else { None }
}
}
} else {
match timeout(Duration::from_secs(60), rx).await {
Ok(Ok(b)) => Some(b),
_ => None,
}
};
match approved_opt {
Some(true) => {
// Approved — clear denial counter for this action.
let key = format!("{tool_name}:{target_text}");
self.denial_counts.lock().unwrap_or_else(|e| e.into_inner()).remove(&key);
MiddlewareVerdict::Allow
}
Ok(Ok(false)) => {
Some(false) => {
// User explicitly denied — increment counter; terminate after 2 denials.
let key = format!("{tool_name}:{target_text}");
let count = {
Expand All @@ -355,19 +380,15 @@ impl ToolMiddleware for ApprovalGateMiddleware {
}
MiddlewareVerdict::Deny { reason }
}
Ok(Err(_)) => MiddlewareVerdict::Deny {
reason: "approval channel dropped".into(),
},
Err(_) => MiddlewareVerdict::Deny {
reason: "approval timed out (60s)".into(),
None => MiddlewareVerdict::Deny {
reason: "approval timed out or channel dropped".into(),
},
}
}
}
}

async fn after_tool(&self, _tool_name: &str, _output: &str) {
// Increment step counter after each tool execution.
self.step_counter.fetch_add(1, Ordering::Relaxed);
// Step counting is owned by StepEmitterMiddleware (runs after this middleware).
}
}
125 changes: 118 additions & 7 deletions src/browser/browse_loop.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::sync::Arc;
use tokio::sync::mpsc;

use crate::browser::approval_gate::{ApprovalGate, ApprovalGateMiddleware, ApprovalPrompt};
use crate::browser::loop_detector::LoopDetectorMiddleware;
use crate::browser::middleware::{MiddlewareVerdict, ToolMiddleware};
use crate::config::Config;
use crate::query_engine::QueryEngine;
use crate::tools::DynTool;
Expand Down Expand Up @@ -110,6 +112,45 @@ fn parse_browse_done(text: &str) -> Option<(bool, String)> {
Some((achieved, summary))
}

/// Extract a human-readable "target" from tool input for Step events.
/// Prefers `url` → `ref` → `selector` → `key` → empty string.
fn extract_target(input: &serde_json::Value) -> String {
for key in ["url", "ref", "selector", "key"] {
if let Some(s) = input.get(key).and_then(|v| v.as_str()) {
if !s.is_empty() {
return s.to_string();
}
}
}
String::new()
}

/// Middleware that reports every tool call reaching it as a `BrowseProgress::Step`.
///
/// It is meant to sit last in the middleware chain: a call denied earlier
/// (by the approval gate or the loop detector) never gets here, so it is
/// never reported as an executed step.
struct StepEmitterMiddleware {
    /// Channel on which `Step` progress events are published.
    progress_tx: mpsc::Sender<BrowseProgress>,
    /// Step counter; bumped once for every call that reaches `before_tool`.
    counter: Arc<AtomicU32>,
}

#[async_trait]
impl ToolMiddleware for StepEmitterMiddleware {
    /// Bumps the step counter, publishes a `Step` event for this call and
    /// always returns `Allow` — this middleware never blocks a tool.
    async fn before_tool(&self, tool_name: &str, input: &serde_json::Value) -> MiddlewareVerdict {
        // `fetch_add` hands back the value *before* the increment, so the
        // step number being announced is that value plus one.
        let previous = self.counter.fetch_add(1, Ordering::Relaxed);
        let event = BrowseProgress::Step {
            n: previous + 1,
            action: tool_name.to_string(),
            target: extract_target(input),
        };
        // Best effort: a send error is deliberately ignored so that progress
        // reporting can never stop the tool call itself.
        let _ = self.progress_tx.send(event).await;
        MiddlewareVerdict::Allow
    }

    /// Nothing to do once the tool has run.
    async fn after_tool(&self, _tool_name: &str, _output: &str) {}
}

/// Orchestrate an autonomous browser agent run.
pub async fn run_browse(
req: BrowseRequest,
Expand All @@ -120,35 +161,90 @@ pub async fn run_browse(
approval_tx: mpsc::Sender<ApprovalPrompt>,
cancel: Arc<AtomicBool>,
) -> Result<BrowseResult> {
// 1. Emit Started event.
// 1. Emit Started event + speak the goal if voice is enabled.
let _ = progress_tx
.send(BrowseProgress::Started {
goal: req.goal.clone(),
max_steps: req.max_steps,
})
.await;
if req.voice {
let goal = req.goal.clone();
tokio::spawn(async move {
crate::voice::speak_browse_milestone(
crate::voice::BrowseMilestone::Start,
&goal,
)
.await;
});
}

// 2. Shared step counter for the approval prompt.
let step_counter = Arc::new(AtomicU32::new(0));

// 3. Build the approval gate from config patterns.
// The gate sends prompts to an internal channel; a bridge task mirrors each
// prompt as BrowseProgress::ApprovalNeeded, then forwards it to the caller.
let gate = ApprovalGate::with_user_patterns(config.browse_approval_patterns.clone());
let (internal_approval_tx, mut internal_approval_rx) = mpsc::channel::<ApprovalPrompt>(16);
let gate_mw = Arc::new(ApprovalGateMiddleware::new(
gate,
req.policy,
current_url.clone(),
approval_tx,
internal_approval_tx,
step_counter.clone(),
req.voice,
));

// Bridge: internal approvals → progress event + external approval channel.
let approval_bridge_progress_tx = progress_tx.clone();
let voice_on_gate = req.voice;
let approval_bridge_handle = tokio::spawn(async move {
while let Some(prompt) = internal_approval_rx.recv().await {
let _ = approval_bridge_progress_tx
.send(BrowseProgress::ApprovalNeeded {
step: prompt.step,
action: prompt.tool_name.clone(),
target_text: prompt.target_text.clone(),
url: prompt.url.clone(),
reason: prompt.reason.clone(),
})
.await;
if voice_on_gate {
let phrase = format!(
"Approval needed for {} — {}",
prompt.tool_name, prompt.reason
);
tokio::spawn(async move {
crate::voice::speak_browse_milestone(
crate::voice::BrowseMilestone::GateTrip,
&phrase,
)
.await;
});
}
if approval_tx.send(prompt).await.is_err() {
// Caller dropped the approval channel — stop bridging.
break;
}
}
});

// 4. Build the loop detector middleware.
let (nudge_tx, mut nudge_rx) = mpsc::channel::<String>(16);
let loop_mw = Arc::new(LoopDetectorMiddleware::new(nudge_tx));

// 5. Assemble middleware chain (keep Arc refs for post-run inspection).
// 5. Build the step-emitter middleware (runs last — only fires for allowed calls).
let step_emitter = Arc::new(StepEmitterMiddleware {
progress_tx: progress_tx.clone(),
counter: step_counter.clone(),
});

// 6. Assemble middleware chain (keep Arc refs for post-run inspection).
let middlewares: crate::browser::middleware::MiddlewareChain = vec![
gate_mw.clone() as Arc<dyn crate::browser::middleware::ToolMiddleware>,
loop_mw.clone() as Arc<dyn crate::browser::middleware::ToolMiddleware>,
gate_mw.clone() as Arc<dyn ToolMiddleware>,
loop_mw.clone() as Arc<dyn ToolMiddleware>,
step_emitter as Arc<dyn ToolMiddleware>,
];

// 6. Build browse-specific system prompt.
Expand Down Expand Up @@ -179,8 +275,9 @@ pub async fn run_browse(

// 11. Check for early cancellation before starting the loop.
if cancel.load(Ordering::SeqCst) {
drop(engine); // drop engine (and its middleware chain) to shut down nudge_tx
drop(engine); // drop engine (and its middleware chain) to shut down channels
let _ = nudge_handle.await;
let _ = approval_bridge_handle.await;
let final_url = {
let url = current_url.lock().await;
if url.is_empty() { None } else { Some(url.clone()) }
Expand Down Expand Up @@ -317,10 +414,24 @@ pub async fn run_browse(
}
};

// 14. Emit Completed event.
// 14. Emit Completed event + speak the final summary if voice is enabled.
let _ = progress_tx
.send(BrowseProgress::Completed(result.clone()))
.await;
if req.voice {
let phrase = if result.achieved {
format!("Done. {}", result.summary)
} else {
format!("Stopped. {}", result.summary)
};
tokio::spawn(async move {
crate::voice::speak_browse_milestone(
crate::voice::BrowseMilestone::End,
&phrase,
)
.await;
});
}

Ok(result)
}
9 changes: 0 additions & 9 deletions src/browser/loop_detector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,6 @@ impl LoopDetector {
self.window.clear();
self.nudge_level = 0;
}

pub fn window_len(&self) -> usize {
self.window.len()
}
}

impl Default for LoopDetector {
Expand All @@ -88,11 +84,6 @@ impl Default for LoopDetector {
}
}

/// Public helper: SHA-256 of "{action_type}:{target}".
pub fn fingerprint_action(action_type: &str, target: &str, _extra: &str) -> String {
hash_string(&format!("{action_type}:{target}"))
}

fn hash_string(s: &str) -> String {
let mut hasher = Sha256::new();
hasher.update(s.as_bytes());
Expand Down
5 changes: 1 addition & 4 deletions src/browser/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,12 @@ pub enum MiddlewareVerdict {
Allow,
/// Block the tool with an error reason.
Deny { reason: String },
/// Request confirmation before proceeding.
/// Treated as Deny until the approval gate resolves it internally.
RequireConfirmation { reason: String, detail: String },
}

/// Extension point invoked before and after every tool execution.
#[async_trait]
pub trait ToolMiddleware: Send + Sync {
/// Called before a tool runs. Return `Deny` or `RequireConfirmation` to block.
/// Called before a tool runs. Return `Deny` to block.
async fn before_tool(&self, tool_name: &str, input: &Value) -> MiddlewareVerdict;

/// Called after a tool runs with its output text.
Expand Down
6 changes: 5 additions & 1 deletion src/commands/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1851,7 +1851,11 @@ pub fn clipboard_write(text: &str) -> CommandAction {
.args(args)
.stdin(std::process::Stdio::piped())
.spawn()?;
child.stdin.as_mut().unwrap().write_all(text.as_bytes())?;
child
.stdin
.as_mut()
.ok_or_else(|| std::io::Error::new(std::io::ErrorKind::BrokenPipe, "stdin not available"))?
.write_all(text.as_bytes())?;
child.wait()
};

Expand Down
18 changes: 5 additions & 13 deletions src/query_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -439,18 +439,6 @@ impl QueryEngine {
middleware_denied = true;
break;
}
MiddlewareVerdict::RequireConfirmation { reason, .. } => {
// Treat as Deny until the approval gate resolves internally.
results.push(ContentBlock::ToolResult {
tool_use_id: id.clone(),
content: vec![ToolResultContent::text(format!(
"Middleware requires confirmation: {reason}"
))],
is_error: Some(true),
});
middleware_denied = true;
break;
}
}
}
if middleware_denied {
Expand Down Expand Up @@ -597,7 +585,11 @@ fn truncate_json(v: &serde_json::Value, max_len: usize) -> String {
if s.len() <= max_len {
s
} else {
format!("{}...", &s[..max_len])
let mut end = max_len;
while end > 0 && !s.is_char_boundary(end) {
end -= 1;
}
format!("{}...", &s[..end])
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/rag/indexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,8 @@ fn collect_symbols(
// Extract the source text for this node
let start_byte = child.start_byte();
let end_byte = child.end_byte();
let content = &ctx.full_source[start_byte..end_byte.min(ctx.full_source.len())];
let src_len = ctx.full_source.len();
let content = &ctx.full_source[start_byte.min(src_len)..end_byte.min(src_len)];

// Cap chunk size at 200 lines — huge functions get truncated
let content = if end_line - start_line > 200 {
Expand Down
Loading
Loading