From f56c0eb5d3b7741fc10b89a95b55292780798a55 Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Thu, 2 Jul 2026 11:01:15 -0700 Subject: [PATCH 1/2] feat: add web search gateway tool Signed-off-by: Francisco Javier Arceo --- README.md | 6 + crates/agentic-core/src/events/normalize.rs | 2 + crates/agentic-core/src/events/types.rs | 1 + crates/agentic-core/src/executor/engine.rs | 615 +++++++++++- crates/agentic-core/src/executor/error.rs | 12 +- crates/agentic-core/src/executor/request.rs | 49 + crates/agentic-core/src/lib.rs | 5 +- crates/agentic-core/src/storage/types/item.rs | 2 +- crates/agentic-core/src/tool/handler.rs | 1 + crates/agentic-core/src/tool/mod.rs | 2 + crates/agentic-core/src/tool/normalize.rs | 6 +- crates/agentic-core/src/tool/web_search.rs | 332 +++++++ crates/agentic-core/src/types/io/mod.rs | 1 + crates/agentic-core/src/types/io/output.rs | 71 ++ crates/agentic-core/src/types/mod.rs | 4 +- crates/agentic-core/src/types/tools/mod.rs | 2 +- crates/agentic-core/src/types/tools/params.rs | 62 +- .../tests/event_normalizer_test.rs | 23 +- crates/agentic-core/tests/support/mod.rs | 5 +- .../tests/tool_normalization_test.rs | 21 + .../tests/web_search_tool_test.rs | 937 ++++++++++++++++++ 21 files changed, 2086 insertions(+), 73 deletions(-) create mode 100644 crates/agentic-core/src/tool/web_search.rs create mode 100644 crates/agentic-core/tests/web_search_tool_test.rs diff --git a/README.md b/README.md index 4b3f9a1..451bc30 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,12 @@ cargo build cargo test ``` +## Web search + +The stateful `/v1/responses` executor supports OpenAI-compatible `web_search` +tool declarations by normalizing them into a `web_search` function call for +vLLM. Set `YOU_API_KEY` to enable execution through You.com's Search API. 
+ ## Lint and format ```bash diff --git a/crates/agentic-core/src/events/normalize.rs b/crates/agentic-core/src/events/normalize.rs index 6d2b52a..9f774a2 100644 --- a/crates/agentic-core/src/events/normalize.rs +++ b/crates/agentic-core/src/events/normalize.rs @@ -57,6 +57,7 @@ fn classify_event_type(type_str: &str) -> SSEEventType { "response.reasoning_summary_text.done" => SSEEventType::ReasoningSummaryTextDone, "response.file_search_call.searching" => SSEEventType::FileSearchCallSearching, "response.file_search_call.completed" => SSEEventType::FileSearchCallCompleted, + "response.web_search_call.in_progress" => SSEEventType::WebSearchCallInProgress, "response.web_search_call.searching" => SSEEventType::WebSearchCallSearching, "response.web_search_call.completed" => SSEEventType::WebSearchCallCompleted, _ => SSEEventType::Other, @@ -90,6 +91,7 @@ fn extract_payload(event_type: SSEEventType, json: &Value) -> EventPayload { | SSEEventType::ReasoningPartDone | SSEEventType::FileSearchCallSearching | SSEEventType::FileSearchCallCompleted + | SSEEventType::WebSearchCallInProgress | SSEEventType::WebSearchCallSearching | SSEEventType::WebSearchCallCompleted | SSEEventType::Other => EventPayload::Raw(json.clone()), diff --git a/crates/agentic-core/src/events/types.rs b/crates/agentic-core/src/events/types.rs index 84d0ee5..11b408e 100644 --- a/crates/agentic-core/src/events/types.rs +++ b/crates/agentic-core/src/events/types.rs @@ -88,6 +88,7 @@ pub enum SSEEventType { // Built-in tool calls FileSearchCallSearching, FileSearchCallCompleted, + WebSearchCallInProgress, WebSearchCallSearching, WebSearchCallCompleted, diff --git a/crates/agentic-core/src/executor/engine.rs b/crates/agentic-core/src/executor/engine.rs index 1afc1a4..6f9accf 100644 --- a/crates/agentic-core/src/executor/engine.rs +++ b/crates/agentic-core/src/executor/engine.rs @@ -9,6 +9,9 @@ use std::sync::Arc; use async_stream::stream; use either::Either; +use futures::future::join_all; +use 
serde_json::Value; +use tokio::sync::mpsc; use tracing::warn; use crate::executor::accumulator::ResponseAccumulator; @@ -17,13 +20,24 @@ use crate::executor::inference::{DONE_MARKER, call_inference, fetch_response_jso use crate::executor::persist::persist_response; use crate::executor::rehydrate::rehydrate_conversation; use crate::executor::request::{ExecutionContext, RequestContext}; +use crate::tool::{ToolError, ToolOutput, ToolRegistry, ToolType}; +use crate::types::io::output::{FunctionToolCall, WebSearchCall, WebSearchCallStatus, WebSearchSource}; +use crate::types::io::{InputItem, OutputItem, ResponseUsage, ResponsesInput, ToolChoice}; use crate::types::request_response::{RequestPayload, ResponsePayload}; use crate::utils::common::serialize_to_string; pub use crate::executor::inference::BoxStream; -async fn run_blocking( - ctx: RequestContext, +const MAX_GATEWAY_TOOL_ROUNDS: usize = 10; + +fn should_persist(ctx: &RequestContext) -> bool { + ctx.original_request.store + || ctx.original_request.previous_response_id.is_some() + || ctx.original_request.conversation_id.is_some() +} + +async fn fetch_blocking_payload( + ctx: &RequestContext, exec_ctx: &ExecutionContext, auth: Option<&str>, ) -> ExecutorResult { @@ -42,10 +56,508 @@ async fn run_blocking( ); ctx.inject_ids(&mut payload); - let should_persist = ctx.original_request.store - || ctx.original_request.previous_response_id.is_some() - || ctx.original_request.conversation_id.is_some(); - if should_persist { + Ok(payload) +} + +async fn fetch_stream_payload( + ctx: &RequestContext, + exec_ctx: &ExecutionContext, + auth: Option<&str>, +) -> ExecutorResult { + let url = exec_ctx.responses_url(); + let upstream_json = + serialize_to_string(&ctx.enriched_request.to_upstream_request(true)).map_err(ExecutorError::JsonError)?; + let line_stream = Box::pin(call_inference( + upstream_json, + url, + Arc::clone(&exec_ctx.client), + auth.map(str::to_owned), + exec_ctx.streaming_timeout, + )); + let acc = 
ResponseAccumulator::from_stream(line_stream, ctx.conversation_id.as_deref()).await?; + let mut payload = acc.finalize( + &ctx.enriched_request.model, + ctx.original_request.previous_response_id.as_deref(), + ctx.original_request.instructions.as_deref(), + ); + ctx.inject_ids(&mut payload); + Ok(payload) +} + +fn append_input_item(input: &mut ResponsesInput, item: InputItem) { + match input { + ResponsesInput::Items(items) => items.push(item), + ResponsesInput::Text(text) => { + let text_input = ResponsesInput::Text(std::mem::take(text)); + let mut items = Vec::::from(&text_input); + items.push(item); + *input = ResponsesInput::Items(items); + } + } +} + +fn append_output_items_to_input(input: &mut ResponsesInput, output_items: &[OutputItem]) { + for output_item in output_items { + let input_item = match output_item { + OutputItem::Message(message) => Some(InputItem::Message(message.clone().into())), + OutputItem::FunctionCall(call) => Some(InputItem::FunctionCall(call.clone())), + OutputItem::Reasoning(reasoning) => Some(InputItem::Reasoning(reasoning.clone())), + OutputItem::WebSearchCall(_) | OutputItem::Unknown => None, + }; + if let Some(input_item) = input_item { + append_input_item(input, input_item); + } + } +} + +fn append_tool_outputs(ctx: &mut RequestContext, tool_outputs: Vec) { + for output in tool_outputs { + ctx.new_input_items.push(output.clone()); + append_input_item(&mut ctx.enriched_request.input, output); + } +} + +fn function_calls(output_items: &[OutputItem]) -> Vec { + output_items + .iter() + .filter_map(|item| match item { + OutputItem::FunctionCall(call) => Some(call.clone()), + _ => None, + }) + .collect() +} + +fn is_gateway_owned_call(call: &FunctionToolCall, registry: &ToolRegistry) -> bool { + registry + .lookup(&call.name) + .is_some_and(|entry| entry.tool_type != ToolType::Function) +} + +fn append_gateway_calls_to_new_input(ctx: &mut RequestContext, output_items: &[OutputItem], registry: &ToolRegistry) { + 
ctx.new_input_items.extend(output_items.iter().filter_map(|item| { + let OutputItem::FunctionCall(call) = item else { + return None; + }; + is_gateway_owned_call(call, registry).then(|| InputItem::FunctionCall(call.clone())) + })); +} + +fn public_output_items( + output_items: Vec, + registry: &ToolRegistry, + gateway_results: &[GatewayCallResult], +) -> Vec { + output_items + .into_iter() + .map(|item| match item { + OutputItem::FunctionCall(call) if is_gateway_owned_call(&call, registry) => gateway_results + .iter() + .find(|result| result.call.call_id == call.call_id) + .and_then(|result| result.public_output.clone()) + .unwrap_or(OutputItem::FunctionCall(call)), + other => other, + }) + .collect() +} + +fn add_usage(total: ResponseUsage, usage: ResponseUsage) -> ResponseUsage { + ResponseUsage { + input_tokens: total.input_tokens.saturating_add(usage.input_tokens), + output_tokens: total.output_tokens.saturating_add(usage.output_tokens), + total_tokens: total.total_tokens.saturating_add(usage.total_tokens), + input_tokens_details: crate::types::io::InputTokenDetails { + cached_tokens: total + .input_tokens_details + .cached_tokens + .saturating_add(usage.input_tokens_details.cached_tokens), + }, + output_tokens_details: crate::types::io::OutputTokenDetails { + reasoning_tokens: total + .output_tokens_details + .reasoning_tokens + .saturating_add(usage.output_tokens_details.reasoning_tokens), + }, + } +} + +fn accumulate_usage(total: &mut Option, usage: Option) { + if let Some(usage) = usage { + *total = Some(total.map_or(usage, |current| add_usage(current, usage))); + } +} + +struct GatewayCallDispatch { + call: FunctionToolCall, + tool_type: ToolType, + config: Value, + executor: Arc, +} + +struct GatewayCallExecution { + call: FunctionToolCall, + tool_type: ToolType, + output: Result, +} + +#[derive(Clone)] +struct GatewayCallResult { + call: FunctionToolCall, + input_item: InputItem, + public_output: Option, +} + +struct GatewayCallEventPlan { + call_id: 
String, + output_index: u32, + started_output: Option, +} + +fn gateway_dispatches( + calls: &[FunctionToolCall], + registry: &ToolRegistry, + exec_ctx: &ExecutionContext, +) -> ExecutorResult> { + registry + .gateway_owned(calls) + .into_iter() + .map(|call| { + let entry = registry.lookup(&call.name).ok_or_else(|| { + ToolError::Config(format!( + "gateway tool '{}' was not found in the request registry", + call.name + )) + })?; + let executor = exec_ctx + .gateway_executors + .get(entry.tool_type) + .ok_or_else(|| ToolError::Config(format!("no gateway executor registered for tool '{}'", call.name)))?; + Ok(GatewayCallDispatch { + call: call.clone(), + tool_type: entry.tool_type, + config: entry.config.clone(), + executor, + }) + }) + .collect::, ToolError>>() + .map_err(ExecutorError::from) +} + +async fn execute_gateway_dispatch(dispatch: GatewayCallDispatch) -> GatewayCallExecution { + let GatewayCallDispatch { + call, + tool_type, + config, + executor, + } = dispatch; + let output = executor + .execute(&call.call_id, &call.name, &call.arguments, &config) + .await; + GatewayCallExecution { + call, + tool_type, + output, + } +} + +fn execution_error_output(call: &FunctionToolCall, message: &str) -> ExecutorResult { + let output = serialize_to_string(&serde_json::json!({ "error": message })).map_err(ExecutorError::JsonError)?; + Ok(ToolOutput { + call_id: call.call_id.clone(), + output, + }) +} + +fn gateway_public_output( + tool_type: ToolType, + call: &FunctionToolCall, + output: &ToolOutput, + status: WebSearchCallStatus, +) -> Option { + match tool_type { + ToolType::WebSearch => Some(web_search_output_item(call, output, status)), + ToolType::Function | ToolType::Mcp | ToolType::FileSearch | ToolType::CodeInterpreter => None, + } +} + +async fn execute_gateway_calls( + calls: &[FunctionToolCall], + registry: &ToolRegistry, + exec_ctx: &ExecutionContext, +) -> ExecutorResult> { + let dispatches = gateway_dispatches(calls, registry, exec_ctx)?; + let executions 
= join_all(dispatches.into_iter().map(execute_gateway_dispatch)).await; + let mut results = Vec::with_capacity(executions.len()); + + for execution in executions { + let GatewayCallExecution { + call, + tool_type, + output, + } = execution; + let (output, status) = match output { + Ok(output) => (output, WebSearchCallStatus::Completed), + Err(ToolError::Execution(message)) => { + (execution_error_output(&call, &message)?, WebSearchCallStatus::Failed) + } + Err(err @ ToolError::Config(_)) => return Err(err.into()), + }; + let public_output = gateway_public_output(tool_type, &call, &output, status); + results.push(GatewayCallResult { + call, + input_item: InputItem::FunctionCallOutput(output.into()), + public_output, + }); + } + + Ok(results) +} + +fn web_search_output_item(call: &FunctionToolCall, output: &ToolOutput, status: WebSearchCallStatus) -> OutputItem { + let parsed_output = serde_json::from_str::(&output.output).ok(); + let query = parsed_output + .as_ref() + .and_then(|value| clean_json_str(value.get("query"))) + .or_else(|| web_search_query_from_arguments(&call.arguments)) + .unwrap_or_default(); + let sources = parsed_output + .as_ref() + .map(web_search_sources_from_output) + .unwrap_or_default(); + OutputItem::WebSearchCall(WebSearchCall::new(web_search_call_id(call), status, query, sources)) +} + +fn started_web_search_output_item(call: &FunctionToolCall) -> OutputItem { + OutputItem::WebSearchCall(WebSearchCall::new( + web_search_call_id(call), + WebSearchCallStatus::InProgress, + web_search_query_from_arguments(&call.arguments).unwrap_or_default(), + Vec::new(), + )) +} + +fn web_search_call_id(call: &FunctionToolCall) -> String { + if let Some(suffix) = call.id.strip_prefix("fc_").filter(|suffix| !suffix.is_empty()) { + return format!("ws_{suffix}"); + } + if let Some(suffix) = call.call_id.strip_prefix("call_").filter(|suffix| !suffix.is_empty()) { + return format!("ws_{suffix}"); + } + crate::utils::uuid7_str("ws_") +} + +fn 
web_search_query_from_arguments(arguments: &str) -> Option { + let args = serde_json::from_str::(arguments).ok()?; + clean_json_str(args.get("query")) +} + +fn web_search_sources_from_output(output: &Value) -> Vec { + ["web", "news"] + .into_iter() + .filter_map(|section| output.get("results")?.get(section)?.as_array()) + .flat_map(|results| results.iter()) + .filter_map(web_search_source_from_result) + .collect() +} + +fn web_search_source_from_result(result: &Value) -> Option { + let url = clean_json_str(result.get("url"))?; + Some(WebSearchSource { + url, + title: clean_json_str(result.get("title")), + }) +} + +fn clean_json_str(value: Option<&Value>) -> Option { + value + .and_then(Value::as_str) + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(str::to_owned) +} + +fn gateway_event_plans( + output_items: &[OutputItem], + registry: &ToolRegistry, + output_offset: usize, +) -> Vec { + let mut output_index = output_offset; + let mut plans = Vec::new(); + for item in output_items { + if let OutputItem::FunctionCall(call) = item + && let Some(entry) = registry.lookup(&call.name) + && entry.tool_type != ToolType::Function + { + plans.push(GatewayCallEventPlan { + call_id: call.call_id.clone(), + output_index: u32::try_from(output_index).unwrap_or(u32::MAX), + started_output: match entry.tool_type { + ToolType::WebSearch => Some(started_web_search_output_item(call)), + ToolType::Function | ToolType::Mcp | ToolType::FileSearch | ToolType::CodeInterpreter => None, + }, + }); + } + output_index = output_index.saturating_add(1); + } + plans +} + +fn emit_sse_json(sender: &mpsc::UnboundedSender, event: &Value) -> ExecutorResult<()> { + let event_json = serialize_to_string(&event).map_err(ExecutorError::JsonError)?; + let _ = sender.send(format!("data: {event_json}\n\n")); + Ok(()) +} + +fn output_item_value(item: &OutputItem) -> ExecutorResult { + serde_json::to_value(item).map_err(ExecutorError::JsonError) +} + +fn emit_gateway_start_events( + plans: 
&[GatewayCallEventPlan], + stream_events: Option<&mpsc::UnboundedSender>, +) -> ExecutorResult<()> { + let Some(sender) = stream_events else { + return Ok(()); + }; + for plan in plans { + let Some(output_item) = &plan.started_output else { + continue; + }; + let OutputItem::WebSearchCall(web_search_call) = output_item else { + continue; + }; + let item = output_item_value(output_item)?; + let added_event = serde_json::json!({ + "type": "response.output_item.added", + "output_index": plan.output_index, + "item": item + }); + emit_sse_json(sender, &added_event)?; + let in_progress_event = serde_json::json!({ + "type": "response.web_search_call.in_progress", + "item_id": web_search_call.id, + "output_index": plan.output_index + }); + emit_sse_json(sender, &in_progress_event)?; + let searching_event = serde_json::json!({ + "type": "response.web_search_call.searching", + "item_id": web_search_call.id, + "output_index": plan.output_index + }); + emit_sse_json(sender, &searching_event)?; + } + Ok(()) +} + +fn emit_gateway_completed_events( + results: &[GatewayCallResult], + plans: &[GatewayCallEventPlan], + stream_events: Option<&mpsc::UnboundedSender>, +) -> ExecutorResult<()> { + let Some(sender) = stream_events else { + return Ok(()); + }; + for result in results { + let Some(OutputItem::WebSearchCall(web_search_call)) = &result.public_output else { + continue; + }; + let output_index = plans + .iter() + .find(|plan| plan.call_id == result.call.call_id) + .map_or(0, |plan| plan.output_index); + let output_item = OutputItem::WebSearchCall(web_search_call.clone()); + let item = output_item_value(&output_item)?; + let completed_event = serde_json::json!({ + "type": "response.web_search_call.completed", + "item_id": web_search_call.id, + "output_index": output_index, + "item": item.clone() + }); + emit_sse_json(sender, &completed_event)?; + let done_event = serde_json::json!({ + "type": "response.output_item.done", + "output_index": output_index, + "item": item + }); + 
emit_sse_json(sender, &done_event)?; + } + Ok(()) +} + +async fn run_until_gateway_tools_complete( + mut ctx: RequestContext, + exec_ctx: &ExecutionContext, + auth: Option<&str>, + stream_upstream: bool, + stream_events: Option<&mpsc::UnboundedSender>, +) -> ExecutorResult<(ResponsePayload, RequestContext)> { + let registry = ctx + .enriched_request + .tools + .as_ref() + .map_or_else(ToolRegistry::default, |tools| ToolRegistry::build(tools)); + let mut combined_output = Vec::new(); + let mut combined_usage = None; + + for _ in 0..MAX_GATEWAY_TOOL_ROUNDS { + let mut payload = if stream_upstream { + fetch_stream_payload(&ctx, exec_ctx, auth).await? + } else { + fetch_blocking_payload(&ctx, exec_ctx, auth).await? + }; + accumulate_usage(&mut combined_usage, payload.usage); + let current_output = std::mem::take(&mut payload.output); + let calls = function_calls(&current_output); + let has_client_owned_calls = !registry.client_owned(&calls).is_empty(); + let event_plans = gateway_event_plans(&current_output, &registry, combined_output.len()); + emit_gateway_start_events(&event_plans, stream_events)?; + let gateway_results = execute_gateway_calls(&calls, &registry, exec_ctx).await?; + emit_gateway_completed_events(&gateway_results, &event_plans, stream_events)?; + let public_output = public_output_items(current_output.clone(), &registry, &gateway_results); + + if has_client_owned_calls { + combined_output.extend(public_output); + append_gateway_calls_to_new_input(&mut ctx, &current_output, &registry); + append_tool_outputs( + &mut ctx, + gateway_results.into_iter().map(|result| result.input_item).collect(), + ); + payload.output = combined_output; + payload.usage = combined_usage; + ctx.inject_ids(&mut payload); + return Ok((payload, ctx)); + } + + if gateway_results.is_empty() { + combined_output.extend(public_output); + payload.output = combined_output; + payload.usage = combined_usage; + ctx.inject_ids(&mut payload); + return Ok((payload, ctx)); + } + + combined_output.extend(public_output); +
ctx.enriched_request.tool_choice = ToolChoice::Auto; + append_output_items_to_input(&mut ctx.enriched_request.input, &current_output); + append_gateway_calls_to_new_input(&mut ctx, &current_output, &registry); + append_tool_outputs( + &mut ctx, + gateway_results.into_iter().map(|result| result.input_item).collect(), + ); + } + + Err(ExecutorError::InvalidRequest(format!( + "gateway tool execution exceeded {MAX_GATEWAY_TOOL_ROUNDS} rounds" + ))) +} + +async fn run_blocking( + ctx: RequestContext, + exec_ctx: &ExecutionContext, + auth: Option<&str>, +) -> ExecutorResult { + let (payload, ctx) = run_until_gateway_tools_complete(ctx, exec_ctx, auth, false, None).await?; + + if should_persist(&ctx) { let ch = exec_ctx.conv_handler.clone(); let rh = exec_ctx.resp_handler.clone(); if let Err(e) = persist_response(payload.clone(), ctx, ch, rh).await { @@ -57,56 +569,53 @@ async fn run_blocking( } fn run_stream(ctx: RequestContext, exec_ctx: Arc, auth: Option) -> BoxStream { - let url = exec_ctx.responses_url(); - // Streaming request: stream=true → SSE lines → from_stream. - let upstream_json = match serialize_to_string(&ctx.enriched_request.to_upstream_request(true)) { - Ok(s) => s, - Err(e) => { - return Box::pin(stream! { - yield format!("data: {{\"error\": \"serialize error: {e}\"}}\n\n"); - yield DONE_MARKER.to_string(); - }); - } - }; + Box::pin(stream! { + let should_persist = should_persist(&ctx); + let (event_tx, mut event_rx) = mpsc::unbounded_channel(); + let exec_ctx_for_run = Arc::clone(&exec_ctx); + let mut run_handle = tokio::spawn(async move { + run_until_gateway_tools_complete( + ctx, + exec_ctx_for_run.as_ref(), + auth.as_deref(), + true, + Some(&event_tx), + ) + .await + }); - // Persist when store=true, or when an ID is passed — context continuity must - // be preserved even if the caller sets store=false.
- let should_persist = ctx.original_request.store - || ctx.original_request.previous_response_id.is_some() - || ctx.original_request.conversation_id.is_some(); + loop { + tokio::select! { + Some(event) = event_rx.recv() => { + yield event; + } + result = &mut run_handle => { + while let Ok(event) = event_rx.try_recv() { + yield event; + } + match result { + Err(e) => { + yield format!("data: {{\"error\": \"stream task failed: {e}\"}}\n\n"); + yield DONE_MARKER.to_string(); + } + Ok(Err(e)) => { + yield format!("data: {{\"error\": \"{e}\"}}\n\n"); + yield DONE_MARKER.to_string(); + } + Ok(Ok((payload, ctx))) => { + yield payload.as_responses_chunk(); + yield DONE_MARKER.to_string(); - Box::pin(stream! { - let line_stream = Box::pin(call_inference( - upstream_json, - url, - Arc::clone(&exec_ctx.client), - auth, - exec_ctx.streaming_timeout, - )); - - // from_stream feeds SSE lines to a spawn_blocking worker via channel. - // All JSON parsing is CPU-bound and runs off the async executor. - match ResponseAccumulator::from_stream(line_stream, ctx.conversation_id.as_deref()).await { - Err(e) => { - yield format!("data: {{\"error\": \"{e}\"}}\n\n"); - yield DONE_MARKER.to_string(); - } - Ok(acc) => { - let mut payload = acc.finalize( - &ctx.enriched_request.model, - ctx.original_request.previous_response_id.as_deref(), - ctx.original_request.instructions.as_deref(), - ); - ctx.inject_ids(&mut payload); - yield payload.as_responses_chunk(); - yield DONE_MARKER.to_string(); - - if should_persist { - let ch = exec_ctx.conv_handler.clone(); - let rh = exec_ctx.resp_handler.clone(); - if let Err(e) = persist_response(payload, ctx, ch, rh).await { - warn!("persist failed: {e}"); + if should_persist { + let ch = exec_ctx.conv_handler.clone(); + let rh = exec_ctx.resp_handler.clone(); + if let Err(e) = persist_response(payload, ctx, ch, rh).await { + warn!("persist failed: {e}"); + } + } + } } + break; } } } diff --git a/crates/agentic-core/src/executor/error.rs 
b/crates/agentic-core/src/executor/error.rs index df200a2..845044e 100644 --- a/crates/agentic-core/src/executor/error.rs +++ b/crates/agentic-core/src/executor/error.rs @@ -2,6 +2,7 @@ use http::StatusCode; use thiserror::Error; use crate::StorageError; +use crate::tool::ToolError; use crate::utils::common::serialize_to_vec_or_default; #[non_exhaustive] @@ -54,6 +55,9 @@ pub enum ExecutorError { #[error("invalid request: {0}")] InvalidRequest(String), + + #[error("tool error: {0}")] + Tool(#[from] ToolError), } impl ExecutorError { @@ -63,7 +67,8 @@ impl ExecutorError { match self { Self::Storage(e) if e.is_not_found() => StatusCode::NOT_FOUND, Self::LLMRequest { status, .. } => *status, - Self::InvalidRequest(_) | Self::JsonError(_) => StatusCode::BAD_REQUEST, + Self::Tool(ToolError::Config(_)) | Self::InvalidRequest(_) | Self::JsonError(_) => StatusCode::BAD_REQUEST, + Self::Tool(ToolError::Execution(_)) => StatusCode::BAD_GATEWAY, Self::ParseError(_) => StatusCode::UNPROCESSABLE_ENTITY, _ => StatusCode::INTERNAL_SERVER_ERROR, } @@ -75,7 +80,10 @@ impl ExecutorError { match self { Self::Storage(e) if e.is_not_found() => "not_found", Self::LLMRequest { .. 
} => "upstream_error", - Self::InvalidRequest(_) | Self::ParseError(_) | Self::JsonError(_) => "invalid_request_error", + Self::Tool(ToolError::Config(_)) | Self::InvalidRequest(_) | Self::ParseError(_) | Self::JsonError(_) => { + "invalid_request_error" + } + Self::Tool(ToolError::Execution(_)) => "tool_error", _ => "server_error", } } diff --git a/crates/agentic-core/src/executor/request.rs b/crates/agentic-core/src/executor/request.rs index 4c8ee2c..9ec98d5 100644 --- a/crates/agentic-core/src/executor/request.rs +++ b/crates/agentic-core/src/executor/request.rs @@ -5,9 +5,47 @@ use crate::config::Config; use crate::error::Error; use crate::executor::modes::{ConversationHandler, ResponseHandler}; use crate::storage::{ConversationStore, ResponseStore, create_pool_with_schema}; +use crate::tool::{GatewayExecutor, ToolType, WebSearchHandler}; use crate::types::io::InputItem; use crate::types::request_response::{RequestPayload, ResponsePayload}; +#[derive(Clone, Default)] +pub struct GatewayExecutors { + web_search: Option>, +} + +impl GatewayExecutors { + #[must_use] + pub fn from_env(client: Arc) -> Self { + Self { + web_search: Some(Arc::new(WebSearchHandler::from_env(client))), + } + } + + pub fn insert(&mut self, executor: Arc) { + match executor.tool_type() { + ToolType::WebSearch => self.web_search = Some(executor), + other => tracing::debug!(tool_type = ?other, "gateway executor type not wired yet"), + } + } + + #[must_use] + pub fn get(&self, tool_type: ToolType) -> Option> { + match tool_type { + ToolType::WebSearch => self.web_search.clone(), + ToolType::Function | ToolType::Mcp | ToolType::FileSearch | ToolType::CodeInterpreter => None, + } + } +} + +impl std::fmt::Debug for GatewayExecutors { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("GatewayExecutors") + .field("web_search", &self.web_search.is_some()) + .finish() + } +} + /// Context built by `rehydrate_conversation`, threaded through the execute pipeline. 
#[derive(Debug)] pub struct RequestContext { @@ -46,6 +84,7 @@ pub struct ExecutionContext { pub conv_handler: ConversationHandler, pub resp_handler: ResponseHandler, pub client: Arc, + pub gateway_executors: GatewayExecutors, /// Base URL for the LLM backend, e.g. `"http://localhost:8000"`. pub llm_base_url: String, /// Maximum wait time for the next SSE chunk. `Duration::ZERO` disables the timeout. @@ -73,15 +112,23 @@ impl ExecutionContext { client: Arc, llm_base_url: String, ) -> Self { + let gateway_executors = GatewayExecutors::from_env(Arc::clone(&client)); Self { conv_handler, resp_handler, client, + gateway_executors, llm_base_url, streaming_timeout: Duration::from_secs(30), } } + #[must_use] + pub fn with_gateway_executor(mut self, executor: Arc) -> Self { + self.gateway_executors.insert(executor); + self + } + /// Build an `ExecutionContext` directly from [`Config`](crate::config::Config). /// /// Creates the database pool, both storage handlers, and an HTTP client @@ -100,11 +147,13 @@ impl ExecutionContext { let conv_handler = ConversationHandler::new(ConversationStore::new(pool.clone())); let resp_handler = ResponseHandler::new(ResponseStore::new(pool)); let client = Arc::new(reqwest::Client::new()); + let gateway_executors = GatewayExecutors::from_env(Arc::clone(&client)); Ok(Self { conv_handler, resp_handler, client, + gateway_executors, llm_base_url: cfg.llm_api_base.clone(), streaming_timeout: Duration::from_secs(30), }) diff --git a/crates/agentic-core/src/lib.rs b/crates/agentic-core/src/lib.rs index 93bb05a..380ff52 100644 --- a/crates/agentic-core/src/lib.rs +++ b/crates/agentic-core/src/lib.rs @@ -16,12 +16,15 @@ pub use storage::{ }; pub use tool::{ FunctionHandler, GatewayExecutor, ToolEntry, ToolError, ToolHandler, ToolOutput, ToolRegistry, ToolType, + WebSearchHandler, }; pub use types::{ CodeInterpreterToolParam, EmptyToolNameError, FileSearchToolParam, FunctionTool, FunctionToolCall, FunctionToolParam, FunctionToolResultMessage, 
IncompleteDetails, InputContent, InputImageContent, InputItem, InputMessage, InputMessageContent, InputTextContent, InputTokenDetails, McpToolParam, NonEmptyToolName, OutputItem, OutputMessage, OutputTextContent, OutputTokenDetails, ReasoningOutput, ReasoningTextContent, RequestPayload, - ResponsePayload, ResponseUsage, ResponsesInput, ResponsesTool, ToolChoice, UpstreamRequest, WebSearchToolParam, + ResponsePayload, ResponseUsage, ResponsesInput, ResponsesTool, ToolChoice, UpstreamRequest, WebSearchActionSearch, + WebSearchCall, WebSearchCallStatus, WebSearchContextSize, WebSearchFilters, WebSearchSource, WebSearchToolParam, + WebSearchUserLocation, }; pub use utils::{utcnow_str, uuid7_str}; diff --git a/crates/agentic-core/src/storage/types/item.rs b/crates/agentic-core/src/storage/types/item.rs index 85a5491..20751a1 100644 --- a/crates/agentic-core/src/storage/types/item.rs +++ b/crates/agentic-core/src/storage/types/item.rs @@ -91,7 +91,7 @@ impl InOutItem { InOutItem::Output(OutputItem::Message(msg)) => Some(InputItem::Message(msg.into())), InOutItem::Output(OutputItem::Reasoning(r)) => Some(InputItem::Reasoning(r)), InOutItem::Output(OutputItem::FunctionCall(f)) => Some(InputItem::FunctionCall(f)), - InOutItem::Output(OutputItem::Unknown) => None, + InOutItem::Output(OutputItem::WebSearchCall(_) | OutputItem::Unknown) => None, }) .collect() } diff --git a/crates/agentic-core/src/tool/handler.rs b/crates/agentic-core/src/tool/handler.rs index 7e9ba79..4ec6953 100644 --- a/crates/agentic-core/src/tool/handler.rs +++ b/crates/agentic-core/src/tool/handler.rs @@ -67,6 +67,7 @@ pub trait GatewayExecutor: ToolHandler + 'static { /// Returns [`ToolError::Execution`] if the tool call fails. 
fn execute( &self, + call_id: &str, tool_name: &str, arguments: &str, config: &Value, diff --git a/crates/agentic-core/src/tool/mod.rs b/crates/agentic-core/src/tool/mod.rs index 6b9b49d..eadef03 100644 --- a/crates/agentic-core/src/tool/mod.rs +++ b/crates/agentic-core/src/tool/mod.rs @@ -7,7 +7,9 @@ pub mod function; pub mod handler; pub mod normalize; pub mod registry; +pub mod web_search; pub use function::FunctionHandler; pub use handler::{GatewayExecutor, ToolError, ToolHandler, ToolOutput}; pub use registry::{ToolEntry, ToolRegistry, ToolType}; +pub use web_search::WebSearchHandler; diff --git a/crates/agentic-core/src/tool/normalize.rs b/crates/agentic-core/src/tool/normalize.rs index de88a4b..ffc7539 100644 --- a/crates/agentic-core/src/tool/normalize.rs +++ b/crates/agentic-core/src/tool/normalize.rs @@ -3,6 +3,7 @@ use crate::types::io::input::FunctionToolResultMessage; use crate::types::tools::ResponsesTool; use super::handler::ToolOutput; +use super::web_search::web_search_function_tool; impl ResponsesTool { /// Normalise this tool declaration to the `FunctionTool` wire format that vLLM understands. 
@@ -27,10 +28,7 @@ impl ResponsesTool { ); None } - ResponsesTool::WebSearch(_) => { - tracing::debug!("web_search tool skipped in normalize — handler not yet registered"); - None - } + ResponsesTool::WebSearch(_) => Some(web_search_function_tool()), ResponsesTool::FileSearch(_) => { tracing::debug!("file_search tool skipped in normalize — handler not yet registered"); None diff --git a/crates/agentic-core/src/tool/web_search.rs b/crates/agentic-core/src/tool/web_search.rs new file mode 100644 index 0000000..88471d7 --- /dev/null +++ b/crates/agentic-core/src/tool/web_search.rs @@ -0,0 +1,332 @@ +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use crate::types::io::FunctionTool; +use crate::types::tools::{WebSearchContextSize, WebSearchToolParam}; +use crate::utils::common::serialize_to_string; + +use super::handler::{GatewayExecutor, ToolError, ToolHandler, ToolOutput}; +use super::registry::ToolType; + +const DEFAULT_YOU_SEARCH_BASE_URL: &str = "https://ydc-index.io"; +const YOU_API_KEY_ENV: &str = "YOU_API_KEY"; +const YOU_API_BASE_URL_ENV: &str = "YOU_API_BASE_URL"; + +#[must_use] +pub(crate) fn web_search_function_tool() -> FunctionTool { + FunctionTool { + type_: "function".to_owned(), + name: "web_search".to_owned(), + description: Some( + "Search the public web for current information and return structured web and news results.".to_owned(), + ), + parameters: Some(serde_json::json!({ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The natural language web search query." + }, + "count": { + "type": "integer", + "description": "Maximum results per section, from 1 to 100." + }, + "freshness": { + "type": "string", + "description": "Optional recency filter: day, week, month, year, or YYYY-MM-DDtoYYYY-MM-DD." + }, + "country": { + "type": "string", + "description": "Optional ISO 3166-1 alpha-2 country code." 
+ }, + "language": { + "type": "string", + "description": "Optional BCP 47 language code." + }, + "include_domains": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional strict allowlist of domains." + }, + "exclude_domains": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional domain blocklist." + } + }, + "required": ["query"] + })), + strict: Some(false), + } +} + +#[derive(Debug, Clone)] +pub struct WebSearchHandler { + client: Arc, + api_key: Option, + base_url: String, +} + +impl WebSearchHandler { + #[must_use] + pub fn from_env(client: Arc) -> Self { + let api_key = std::env::var(YOU_API_KEY_ENV) + .ok() + .map(|value| value.trim().to_owned()) + .filter(|value| !value.is_empty()); + let base_url = std::env::var(YOU_API_BASE_URL_ENV) + .ok() + .map(|value| value.trim().trim_end_matches('/').to_owned()) + .filter(|value| !value.is_empty()) + .unwrap_or_else(|| DEFAULT_YOU_SEARCH_BASE_URL.to_owned()); + Self { + client, + api_key, + base_url, + } + } + + #[must_use] + pub fn with_api_key(client: Arc, api_key: String, base_url: &str) -> Self { + Self { + client, + api_key: Some(api_key), + base_url: base_url.trim_end_matches('/').to_owned(), + } + } + + async fn execute_search(&self, call_id: &str, arguments: &str, config: &Value) -> Result { + let api_key = self + .api_key + .as_deref() + .ok_or_else(|| ToolError::Config(format!("{YOU_API_KEY_ENV} must be set to use the web_search tool")))?; + let args = WebSearchArguments::from_json(arguments)?; + let config = serde_json::from_value::(config.clone()) + .map_err(|e| ToolError::Config(format!("invalid web_search config: {e}")))?; + let request = YouSearchRequest::from_args_and_config(&args, &config)?; + let url = format!("{}/v1/search", self.base_url); + let body = serialize_to_string(&request) + .map_err(|e| ToolError::Execution(format!("failed to serialize web_search request: {e}")))?; + + let resp = self + .client + .post(url) + .header("X-API-Key", 
api_key) + .header("Content-Type", "application/json") + .body(body) + .send() + .await + .map_err(|e| ToolError::Execution(format!("You.com search request failed: {e}")))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(ToolError::Execution(format!( + "You.com search returned {status}: {body}" + ))); + } + + let response_text = resp + .text() + .await + .map_err(|e| ToolError::Execution(format!("failed to read You.com search response: {e}")))?; + let response: Value = serde_json::from_str(&response_text) + .map_err(|e| ToolError::Execution(format!("You.com search returned invalid JSON: {e}")))?; + let output = serde_json::to_string(&serde_json::json!({ + "query": request.query, + "results": response + .get("results") + .cloned() + .unwrap_or_else(|| serde_json::json!({"web": [], "news": []})), + "metadata": response.get("metadata").cloned().unwrap_or(Value::Null) + })) + .map_err(|e| ToolError::Execution(format!("failed to serialize web_search output: {e}")))?; + + Ok(ToolOutput { + call_id: call_id.to_owned(), + output, + }) + } +} + +impl ToolHandler for WebSearchHandler { + fn tool_type(&self) -> ToolType { + ToolType::WebSearch + } + + fn validate(&self, param: &Value) -> Result<(), ToolError> { + serde_json::from_value::(param.clone()) + .map(|_| ()) + .map_err(|e| ToolError::Config(format!("invalid web_search config: {e}"))) + } + + fn normalize(&self, _param: &Value) -> Vec { + vec![web_search_function_tool()] + } +} + +impl GatewayExecutor for WebSearchHandler { + fn execute( + &self, + call_id: &str, + tool_name: &str, + arguments: &str, + config: &Value, + ) -> Pin> + Send + '_>> { + let call_id = call_id.to_owned(); + let tool_name = tool_name.to_owned(); + let arguments = arguments.to_owned(); + let config = config.clone(); + Box::pin(async move { + if tool_name != "web_search" { + return Err(ToolError::Config(format!( + "web_search handler cannot execute tool 
'{tool_name}'" + ))); + } + self.execute_search(&call_id, &arguments, &config).await + }) + } +} + +#[derive(Debug, Deserialize)] +struct WebSearchArguments { + query: String, + count: Option, + freshness: Option, + country: Option, + language: Option, + safesearch: Option, + livecrawl: Option, + livecrawl_formats: Option>, + crawl_timeout: Option, + include_domains: Option>, + exclude_domains: Option>, + boost_domains: Option>, +} + +impl WebSearchArguments { + fn from_json(arguments: &str) -> Result { + let args = serde_json::from_str::(arguments) + .map_err(|e| ToolError::Config(format!("web_search arguments must be valid JSON: {e}")))?; + if args.query.trim().is_empty() { + return Err(ToolError::Config("web_search query must not be empty".to_owned())); + } + Ok(args) + } +} + +#[derive(Debug, Serialize)] +struct YouSearchRequest { + query: String, + #[serde(skip_serializing_if = "Option::is_none")] + count: Option, + #[serde(skip_serializing_if = "Option::is_none")] + freshness: Option, + #[serde(skip_serializing_if = "Option::is_none")] + country: Option, + #[serde(skip_serializing_if = "Option::is_none")] + language: Option, + #[serde(skip_serializing_if = "Option::is_none")] + safesearch: Option, + #[serde(skip_serializing_if = "Option::is_none")] + livecrawl: Option, + #[serde(skip_serializing_if = "Option::is_none")] + livecrawl_formats: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + crawl_timeout: Option, + #[serde(skip_serializing_if = "Option::is_none")] + include_domains: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + exclude_domains: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + boost_domains: Option>, +} + +impl YouSearchRequest { + fn from_args_and_config(args: &WebSearchArguments, config: &WebSearchToolParam) -> Result { + let count = args + .count + .or_else(|| { + config + .search_context_size + .map(WebSearchContextSize::default_count) + .map(u16::from) + }) + .map(validate_count) + .transpose()?; 
+ let crawl_timeout = args.crawl_timeout.map(validate_crawl_timeout).transpose()?; + let config_domains = config + .filters + .as_ref() + .and_then(|filters| clean_vec(filters.allowed_domains.as_deref())); + let include_domains = config_domains.or_else(|| clean_vec(args.include_domains.as_deref())); + let exclude_domains = clean_vec(args.exclude_domains.as_deref()); + let boost_domains = clean_vec(args.boost_domains.as_deref()); + if include_domains.is_some() && (exclude_domains.is_some() || boost_domains.is_some()) { + return Err(ToolError::Config( + "include_domains cannot be combined with exclude_domains or boost_domains".to_owned(), + )); + } + let country = config + .user_location + .as_ref() + .and_then(|location| clean_string(location.country.as_deref())) + .or_else(|| clean_string(args.country.as_deref())) + .map(|value| value.to_ascii_uppercase()); + + Ok(Self { + query: args.query.trim().to_owned(), + count, + freshness: clean_string(args.freshness.as_deref()), + country, + language: clean_string(args.language.as_deref()), + safesearch: clean_string(args.safesearch.as_deref()), + livecrawl: clean_string(args.livecrawl.as_deref()), + livecrawl_formats: clean_vec(args.livecrawl_formats.as_deref()), + crawl_timeout, + include_domains, + exclude_domains, + boost_domains, + }) + } +} + +fn validate_count(count: u16) -> Result { + if (1..=100).contains(&count) { + Ok(u8::try_from(count).expect("validated web_search count must fit in u8")) + } else { + Err(ToolError::Config( + "web_search count must be between 1 and 100".to_owned(), + )) + } +} + +fn validate_crawl_timeout(timeout: u16) -> Result { + if (1..=60).contains(&timeout) { + u8::try_from(timeout).map_err(|e| ToolError::Config(format!("invalid crawl_timeout: {e}"))) + } else { + Err(ToolError::Config( + "web_search crawl_timeout must be between 1 and 60".to_owned(), + )) + } +} + +fn clean_string(value: Option<&str>) -> Option { + value + .map(str::trim) + .filter(|value| !value.is_empty()) + 
.map(str::to_owned) +} + +fn clean_vec(values: Option<&[String]>) -> Option> { + let cleaned: Vec = values + .unwrap_or_default() + .iter() + .filter_map(|value| clean_string(Some(value.as_str()))) + .collect(); + (!cleaned.is_empty()).then_some(cleaned) +} diff --git a/crates/agentic-core/src/types/io/mod.rs b/crates/agentic-core/src/types/io/mod.rs index 3675995..78b48bf 100644 --- a/crates/agentic-core/src/types/io/mod.rs +++ b/crates/agentic-core/src/types/io/mod.rs @@ -9,6 +9,7 @@ pub use input::{ }; pub use output::{ ApplyDone, FunctionToolCall, OutputItem, OutputMessage, OutputTextContent, ReasoningOutput, ReasoningTextContent, + WebSearchActionSearch, WebSearchCall, WebSearchCallStatus, WebSearchSource, }; pub use tools::{FunctionTool, ToolChoice}; pub(crate) use tools::{resolve_tool_choice, resolve_tools}; diff --git a/crates/agentic-core/src/types/io/output.rs b/crates/agentic-core/src/types/io/output.rs index fa4a85e..0b65d95 100644 --- a/crates/agentic-core/src/types/io/output.rs +++ b/crates/agentic-core/src/types/io/output.rs @@ -125,6 +125,75 @@ impl TryFrom<&EventPayload> for FunctionToolCall { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum WebSearchCallStatus { + InProgress, + Completed, + Failed, +} + +impl WebSearchCallStatus { + #[must_use] + pub const fn as_str(self) -> &'static str { + match self { + Self::InProgress => "in_progress", + Self::Completed => "completed", + Self::Failed => "failed", + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WebSearchSource { + pub url: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WebSearchActionSearch { + #[serde(rename = "type")] + pub type_: String, + pub query: String, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub sources: Vec, +} + +impl WebSearchActionSearch { + #[must_use] + pub fn 
new(query: impl Into, sources: Vec) -> Self { + Self { + type_: "search".to_owned(), + query: query.into(), + sources, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WebSearchCall { + pub id: String, + pub status: WebSearchCallStatus, + pub action: WebSearchActionSearch, +} + +impl WebSearchCall { + #[must_use] + pub fn new( + id: impl Into, + status: WebSearchCallStatus, + query: impl Into, + sources: Vec, + ) -> Self { + Self { + id: id.into(), + status, + action: WebSearchActionSearch::new(query, sources), + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ReasoningTextContent { #[serde(rename = "type")] @@ -240,6 +309,8 @@ pub enum OutputItem { Message(OutputMessage), #[serde(rename = "function_call")] FunctionCall(FunctionToolCall), + #[serde(rename = "web_search_call")] + WebSearchCall(WebSearchCall), #[serde(rename = "reasoning")] Reasoning(ReasoningOutput), #[serde(other)] diff --git a/crates/agentic-core/src/types/mod.rs b/crates/agentic-core/src/types/mod.rs index 8cdde1a..655daf7 100644 --- a/crates/agentic-core/src/types/mod.rs +++ b/crates/agentic-core/src/types/mod.rs @@ -7,10 +7,10 @@ pub use io::{ FunctionTool, FunctionToolCall, FunctionToolResultMessage, InputContent, InputImageContent, InputItem, InputMessage, InputMessageContent, InputTextContent, InputTokenDetails, OutputItem, OutputMessage, OutputTextContent, OutputTokenDetails, ReasoningOutput, ReasoningTextContent, ResponseUsage, ResponsesInput, - ToolChoice, + ToolChoice, WebSearchActionSearch, WebSearchCall, WebSearchCallStatus, WebSearchSource, }; pub use request_response::{IncompleteDetails, RequestPayload, ResponsePayload, UpstreamRequest}; pub use tools::{ CodeInterpreterToolParam, EmptyToolNameError, FileSearchToolParam, FunctionToolParam, McpToolParam, - NonEmptyToolName, ResponsesTool, WebSearchToolParam, + NonEmptyToolName, ResponsesTool, WebSearchContextSize, WebSearchFilters, WebSearchToolParam, WebSearchUserLocation, }; diff --git 
a/crates/agentic-core/src/types/tools/mod.rs b/crates/agentic-core/src/types/tools/mod.rs index dd63b8d..b3e5a6c 100644 --- a/crates/agentic-core/src/types/tools/mod.rs +++ b/crates/agentic-core/src/types/tools/mod.rs @@ -7,5 +7,5 @@ pub mod params; pub use params::{ CodeInterpreterToolParam, EmptyToolNameError, FileSearchToolParam, FunctionToolParam, McpToolParam, - NonEmptyToolName, ResponsesTool, WebSearchToolParam, + NonEmptyToolName, ResponsesTool, WebSearchContextSize, WebSearchFilters, WebSearchToolParam, WebSearchUserLocation, }; diff --git a/crates/agentic-core/src/types/tools/params.rs b/crates/agentic-core/src/types/tools/params.rs index e25a0a8..ee9dc60 100644 --- a/crates/agentic-core/src/types/tools/params.rs +++ b/crates/agentic-core/src/types/tools/params.rs @@ -83,7 +83,12 @@ pub enum ResponsesTool { #[serde(rename = "mcp")] Mcp(McpToolParam), - #[serde(rename = "web_search_preview")] + #[serde( + rename = "web_search_preview", + alias = "web_search", + alias = "web_search_preview_2025_03_11", + alias = "web_search_2025_08_26" + )] WebSearch(WebSearchToolParam), #[serde(rename = "file_search")] @@ -118,9 +123,46 @@ pub struct McpToolParam { pub headers: Option>, } -/// Parameters for a web search tool (no required fields). 
+#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum WebSearchContextSize { + Low, + Medium, + High, +} + +impl WebSearchContextSize { + pub(crate) const fn default_count(self) -> u8 { + match self { + Self::Low => 3, + Self::Medium => 5, + Self::High => 10, + } + } +} + #[derive(Debug, Clone, Default, Serialize, Deserialize)] -pub struct WebSearchToolParam {} +pub struct WebSearchFilters { + pub allowed_domains: Option>, +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct WebSearchUserLocation { + #[serde(rename = "type")] + pub type_: Option, + pub city: Option, + pub country: Option, + pub region: Option, + pub timezone: Option, +} + +/// Parameters for a web search tool. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct WebSearchToolParam { + pub search_context_size: Option, + pub filters: Option, + pub user_location: Option, +} /// Parameters for a file search tool. #[derive(Debug, Clone, Default, Serialize, Deserialize)] @@ -206,6 +248,20 @@ mod tests { assert_eq!(serde_json::to_value(&tool).unwrap()["type"], "web_search_preview"); } + #[test] + fn responses_tool_web_search_accepts_openai_aliases() { + for type_name in [ + "web_search", + "web_search_preview", + "web_search_preview_2025_03_11", + "web_search_2025_08_26", + ] { + let json = serde_json::json!({"type": type_name}); + let tool: ResponsesTool = serde_json::from_value(json).unwrap(); + assert!(matches!(tool, ResponsesTool::WebSearch(_))); + } + } + #[test] fn responses_tool_file_search_round_trips() { let json = serde_json::json!({"type": "file_search", "vector_store_ids": ["vs_abc"]}); diff --git a/crates/agentic-core/tests/event_normalizer_test.rs b/crates/agentic-core/tests/event_normalizer_test.rs index c2e861f..37c70af 100644 --- a/crates/agentic-core/tests/event_normalizer_test.rs +++ b/crates/agentic-core/tests/event_normalizer_test.rs @@ -342,11 +342,24 @@ fn test_file_search_classification() { 
#[test] fn test_web_search_classification() { - let line = - r#"data: {"type":"response.web_search_call.completed","item_id":"ws_1","output_index":0,"sequence_number":6}"#; - let frame = normalize_sse_line(line).unwrap(); - assert_eq!(frame.event_type, SSEEventType::WebSearchCallCompleted); - assert!(matches!(frame.payload, EventPayload::Raw(_))); + for (line, expected) in [ + ( + r#"data: {"type":"response.web_search_call.in_progress","item_id":"ws_1","output_index":0,"sequence_number":4}"#, + SSEEventType::WebSearchCallInProgress, + ), + ( + r#"data: {"type":"response.web_search_call.searching","item_id":"ws_1","output_index":0,"sequence_number":5}"#, + SSEEventType::WebSearchCallSearching, + ), + ( + r#"data: {"type":"response.web_search_call.completed","item_id":"ws_1","output_index":0,"sequence_number":6}"#, + SSEEventType::WebSearchCallCompleted, + ), + ] { + let frame = normalize_sse_line(line).unwrap(); + assert_eq!(frame.event_type, expected); + assert!(matches!(frame.payload, EventPayload::Raw(_))); + } } // --- Helpers and constants for integration tests --- diff --git a/crates/agentic-core/tests/support/mod.rs b/crates/agentic-core/tests/support/mod.rs index 4837593..7a9edd8 100644 --- a/crates/agentic-core/tests/support/mod.rs +++ b/crates/agentic-core/tests/support/mod.rs @@ -402,7 +402,10 @@ pub fn output_text(payload: &ResponsePayload) -> String { .iter() .filter_map(|item| match item { OutputItem::Message(msg) => Some(msg.content.iter().map(|c| c.text.as_str()).collect::()), - OutputItem::FunctionCall(_) | OutputItem::Reasoning(_) | OutputItem::Unknown => None, + OutputItem::FunctionCall(_) + | OutputItem::WebSearchCall(_) + | OutputItem::Reasoning(_) + | OutputItem::Unknown => None, }) .collect::() } diff --git a/crates/agentic-core/tests/tool_normalization_test.rs b/crates/agentic-core/tests/tool_normalization_test.rs index 6065118..f59be6c 100644 --- a/crates/agentic-core/tests/tool_normalization_test.rs +++ 
b/crates/agentic-core/tests/tool_normalization_test.rs @@ -192,3 +192,24 @@ fn roundtrip_5turn() { fn roundtrip_parallel() { assert_full_roundtrip("openai_responses_tool_calls_parallel.yaml"); } + +#[test] +fn web_search_preview_normalizes_to_gateway_function() { + let payload: RequestPayload = serde_json::from_value(serde_json::json!({ + "model": "test", + "input": "what changed today?", + "tools": [{"type": "web_search_preview"}] + })) + .unwrap(); + + let upstream = payload.to_upstream_request(false); + let tools = upstream.tools.expect("web_search should normalize to a function tool"); + + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].type_, "function"); + assert_eq!(tools[0].name, "web_search"); + assert_eq!( + tools[0].parameters.as_ref().unwrap()["required"], + serde_json::json!(["query"]) + ); +} diff --git a/crates/agentic-core/tests/web_search_tool_test.rs b/crates/agentic-core/tests/web_search_tool_test.rs new file mode 100644 index 0000000..bffe42b --- /dev/null +++ b/crates/agentic-core/tests/web_search_tool_test.rs @@ -0,0 +1,937 @@ +use std::sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, +}; +use std::time::Duration; + +use agentic_core::executor::{ConversationHandler, ExecuteRequest, ExecutionContext, ResponseHandler}; +use agentic_core::storage::{ConversationStore, ResponseStore}; +use agentic_core::tool::{GatewayExecutor, WebSearchHandler}; +use agentic_core::types::io::{OutputItem, ResponsesInput, ToolChoice}; +use agentic_core::types::request_response::RequestPayload; +use agentic_core::types::tools::ResponsesTool; +use axum::extract::State; +use axum::http::{HeaderMap, StatusCode}; +use axum::routing::post; +use axum::{Json, Router}; +use either::Either; +use futures::StreamExt; +use tokio::net::TcpListener; +use tokio::sync::{Notify, mpsc}; + +mod support; + +#[derive(Debug)] +struct CapturedSearchRequest { + api_key: String, + body: serde_json::Value, +} + +async fn spawn_mock_you() -> ( + String, + mpsc::Receiver, + 
tokio::task::JoinHandle<()>, +) { + spawn_mock_you_with_response( + StatusCode::OK, + serde_json::json!({ + "results": { + "web": [{ + "url": "https://example.com/rust", + "title": "Rust async guide", + "description": "A useful guide", + "snippets": ["Use async carefully."] + }], + "news": [] + }, + "metadata": { + "query": "rust async", + "search_uuid": "search_123", + "latency": 0.12 + } + }), + ) + .await +} + +async fn spawn_mock_you_with_response( + status: StatusCode, + response_body: serde_json::Value, +) -> ( + String, + mpsc::Receiver, + tokio::task::JoinHandle<()>, +) { + let (tx, rx) = mpsc::channel(16); + let app = Router::new() + .route( + "/v1/search", + post( + move |State(tx): State>, + headers: HeaderMap, + Json(body): Json| async move { + let api_key = headers + .get("x-api-key") + .and_then(|value| value.to_str().ok()) + .unwrap_or_default() + .to_owned(); + tx.send(CapturedSearchRequest { api_key, body }).await.unwrap(); + (status, Json(response_body.clone())) + }, + ), + ) + .with_state(tx); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let handle = tokio::spawn(async move { axum::serve(listener, app).await.unwrap() }); + (format!("http://{addr}"), rx, handle) +} + +async fn spawn_mock_you_waiting_for_two_searches() -> ( + String, + mpsc::Receiver, + tokio::task::JoinHandle<()>, +) { + let (tx, rx) = mpsc::channel(16); + let started = Arc::new(AtomicUsize::new(0)); + let notify = Arc::new(Notify::new()); + let app = Router::new() + .route( + "/v1/search", + post( + move |State((tx, started, notify)): State<( + mpsc::Sender, + Arc, + Arc, + )>, + headers: HeaderMap, + Json(body): Json| async move { + let api_key = headers + .get("x-api-key") + .and_then(|value| value.to_str().ok()) + .unwrap_or_default() + .to_owned(); + tx.send(CapturedSearchRequest { + api_key, + body: body.clone(), + }) + .await + .unwrap(); + if started.fetch_add(1, Ordering::SeqCst) + 1 >= 2 { + 
notify.notify_waiters(); + } + while started.load(Ordering::SeqCst) < 2 { + notify.notified().await; + } + + let query = body["query"].as_str().unwrap_or("unknown"); + let slug = query.replace(' ', "-"); + ( + StatusCode::OK, + Json(serde_json::json!({ + "results": { + "web": [{ + "url": format!("https://example.com/{slug}"), + "title": format!("{query} guide") + }], + "news": [] + }, + "metadata": {"query": query} + })), + ) + }, + ), + ) + .with_state((tx, started, notify)); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let handle = tokio::spawn(async move { axum::serve(listener, app).await.unwrap() }); + (format!("http://{addr}"), rx, handle) +} + +#[tokio::test] +async fn web_search_handler_posts_to_you_and_formats_results() { + let (base_url, mut captured, _handle) = spawn_mock_you().await; + let handler = + WebSearchHandler::with_api_key(Arc::new(reqwest::Client::new()), "secret-you-key".to_owned(), &base_url); + + let output = handler + .execute( + "call_search", + "web_search", + r#"{"query":"rust async","count":2}"#, + &serde_json::json!({"type":"web_search_preview"}), + ) + .await + .unwrap(); + + let request = captured.recv().await.expect("mock You.com should receive request"); + assert_eq!(request.api_key, "secret-you-key"); + assert_eq!(request.body["query"], "rust async"); + assert_eq!(request.body["count"], 2); + + assert_eq!(output.call_id, "call_search"); + let output_json: serde_json::Value = serde_json::from_str(&output.output).unwrap(); + assert_eq!(output_json["query"], "rust async"); + assert_eq!(output_json["results"]["web"][0]["url"], "https://example.com/rust"); + assert_eq!(output_json["metadata"]["search_uuid"], "search_123"); +} + +fn web_search_function_call_response() -> support::MockResponse { + support::MockResponse::Json( + serde_json::json!({ + "id": "resp_tool_call", + "object": "response", + "created_at": 0, + "model": "test-model", + "status": "completed", + 
"output": [{ + "id": "fc_search", + "type": "function_call", + "call_id": "call_search", + "name": "web_search", + "arguments": "{\"query\":\"rust async\",\"count\":2}", + "status": "completed" + }], + "usage": null, + "incomplete_details": null, + "error": null, + "previous_response_id": null, + "conversation_id": null, + "instructions": null + }) + .to_string(), + ) +} + +fn two_web_search_function_call_response() -> support::MockResponse { + support::MockResponse::Json( + serde_json::json!({ + "id": "resp_two_tool_calls", + "object": "response", + "created_at": 0, + "model": "test-model", + "status": "completed", + "output": [ + { + "id": "fc_search_1", + "type": "function_call", + "call_id": "call_search_1", + "name": "web_search", + "arguments": "{\"query\":\"rust async\",\"count\":2}", + "status": "completed" + }, + { + "id": "fc_search_2", + "type": "function_call", + "call_id": "call_search_2", + "name": "web_search", + "arguments": "{\"query\":\"tokio streams\",\"count\":2}", + "status": "completed" + } + ], + "usage": null, + "incomplete_details": null, + "error": null, + "previous_response_id": null, + "conversation_id": null, + "instructions": null + }) + .to_string(), + ) +} + +fn sse_response(events: impl IntoIterator) -> support::MockResponse { + let mut body = String::new(); + for event in events { + body.push_str("data: "); + body.push_str(&serde_json::to_string(&event).unwrap()); + body.push_str("\n\n"); + } + body.push_str("data: [DONE]\n\n"); + support::MockResponse::Sse(body) +} + +fn web_search_function_call_sse_response() -> support::MockResponse { + sse_response([ + serde_json::json!({ + "type": "response.created", + "response": {"id": "resp_tool_call", "status": "in_progress", "usage": null} + }), + serde_json::json!({ + "type": "response.output_item.added", + "output_index": 0, + "item": { + "id": "fc_search", + "type": "function_call", + "call_id": "call_search", + "name": "web_search", + "arguments": "", + "status": "in_progress" + } + 
}), + serde_json::json!({ + "type": "response.function_call_arguments.done", + "item_id": "fc_search", + "output_index": 0, + "call_id": "call_search", + "name": "web_search", + "arguments": "{\"query\":\"rust async\",\"count\":2}" + }), + serde_json::json!({ + "type": "response.completed", + "response": {"id": "resp_tool_call", "status": "completed", "usage": null} + }), + ]) +} + +fn text_sse_response(text: &str) -> support::MockResponse { + sse_response([ + serde_json::json!({ + "type": "response.created", + "response": {"id": "resp_final", "status": "in_progress", "usage": null} + }), + serde_json::json!({ + "type": "response.output_item.added", + "output_index": 0, + "item": { + "id": "msg_final", + "type": "message", + "role": "assistant", + "status": "in_progress", + "content": [] + } + }), + serde_json::json!({ + "type": "response.output_text.delta", + "item_id": "msg_final", + "output_index": 0, + "content_index": 0, + "delta": text + }), + serde_json::json!({ + "type": "response.completed", + "response": {"id": "resp_final", "status": "completed", "usage": null} + }), + ]) +} + +fn mixed_web_search_and_client_function_response() -> support::MockResponse { + support::MockResponse::Json( + serde_json::json!({ + "id": "resp_mixed_tool_call", + "object": "response", + "created_at": 0, + "model": "test-model", + "status": "completed", + "output": [ + { + "id": "fc_search", + "type": "function_call", + "call_id": "call_search", + "name": "web_search", + "arguments": "{\"query\":\"rust async\",\"count\":2}", + "status": "completed" + }, + { + "id": "fc_weather", + "type": "function_call", + "call_id": "call_weather", + "name": "get_weather", + "arguments": "{\"city\":\"San Francisco\"}", + "status": "completed" + } + ], + "usage": null, + "incomplete_details": null, + "error": null, + "previous_response_id": null, + "conversation_id": null, + "instructions": null + }) + .to_string(), + ) +} + +fn text_response_with_usage(text: &str, input_tokens: i64, 
output_tokens: i64) -> support::MockResponse { + let id_suffix = text.replace(' ', "_"); + support::MockResponse::Json( + serde_json::json!({ + "id": format!("resp_upstream_{id_suffix}"), + "object": "response", + "created_at": 0, + "model": "test-model", + "status": "completed", + "output": [{ + "id": format!("msg_upstream_{id_suffix}"), + "type": "message", + "role": "assistant", + "status": "completed", + "content": [{ + "type": "output_text", + "text": text, + "annotations": [] + }] + }], + "usage": { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, + "input_tokens_details": {"cached_tokens": 1}, + "output_tokens_details": {"reasoning_tokens": 2} + }, + "incomplete_details": null, + "error": null, + "previous_response_id": null, + "conversation_id": null, + "instructions": null + }) + .to_string(), + ) +} + +fn web_search_function_call_response_with_usage(input_tokens: i64, output_tokens: i64) -> support::MockResponse { + support::MockResponse::Json( + serde_json::json!({ + "id": "resp_tool_call", + "object": "response", + "created_at": 0, + "model": "test-model", + "status": "completed", + "output": [{ + "id": "fc_search", + "type": "function_call", + "call_id": "call_search", + "name": "web_search", + "arguments": "{\"query\":\"rust async\",\"count\":2}", + "status": "completed" + }], + "usage": { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, + "input_tokens_details": {"cached_tokens": 3}, + "output_tokens_details": {"reasoning_tokens": 4} + }, + "incomplete_details": null, + "error": null, + "previous_response_id": null, + "conversation_id": null, + "instructions": null + }) + .to_string(), + ) +} + +async fn build_exec_ctx(llm_url: &str, you_url: String) -> Arc { + let pool = support::setup_pool().await; + let conv_handler = ConversationHandler::new(ConversationStore::new(Arc::clone(&pool))); + let resp_handler = 
ResponseHandler::new(ResponseStore::new(pool)); + let client = Arc::new(reqwest::Client::new()); + Arc::new( + ExecutionContext::new(conv_handler, resp_handler, Arc::clone(&client), llm_url.to_owned()) + .with_gateway_executor(Arc::new(WebSearchHandler::with_api_key( + client, + "secret-you-key".to_owned(), + &you_url, + ))), + ) +} + +#[tokio::test] +async fn execute_runs_web_search_and_sends_tool_output_back_to_model() { + let (you_url, mut captured_you, _you_handle) = spawn_mock_you().await; + let llm = support::MockServer::start_deque(vec![ + web_search_function_call_response(), + support::text_response("Use async carefully."), + ]) + .await; + let exec_ctx = build_exec_ctx(llm.url(), you_url).await; + let web_search: ResponsesTool = serde_json::from_value(serde_json::json!({"type": "web_search_preview"})).unwrap(); + let payload = RequestPayload { + model: "test-model".to_owned(), + input: ResponsesInput::Text("look up rust async".to_owned()), + instructions: None, + previous_response_id: None, + conversation_id: None, + tools: Some(vec![web_search]), + tool_choice: ToolChoice::Auto, + stream: false, + store: true, + include: None, + temperature: None, + top_p: None, + max_output_tokens: None, + truncation: None, + metadata: None, + }; + + let result = ExecuteRequest::new(payload, Arc::clone(&exec_ctx)).run().await.unwrap(); + let Either::Left(response) = result else { + panic!("expected non-streaming response"); + }; + + let request = captured_you.recv().await.expect("mock You.com should receive request"); + assert_eq!(request.body["query"], "rust async"); + + let request_bodies = llm.request_bodies().await; + assert_eq!(request_bodies.len(), 2); + assert_eq!(request_bodies[0]["tools"][0]["name"], "web_search"); + let second_input = request_bodies[1]["input"] + .as_array() + .expect("second request input array"); + let tool_output = second_input + .iter() + .find(|item| item["type"] == "function_call_output") + .expect("second request includes web_search 
output"); + assert_eq!(tool_output["call_id"], "call_search"); + assert!( + tool_output["output"] + .as_str() + .unwrap() + .contains("https://example.com/rust") + ); + + let response_output = serde_json::to_value(&response.output).unwrap(); + let output_items = response_output.as_array().unwrap(); + assert!( + !output_items + .iter() + .any(|item| item["type"] == "function_call" && item["name"] == "web_search"), + "raw web_search function calls must stay internal" + ); + let web_search_call = output_items + .iter() + .find(|item| item["type"] == "web_search_call") + .expect("response output should include web_search_call item"); + assert_eq!(web_search_call["status"], "completed"); + assert_eq!(web_search_call["action"]["type"], "search"); + assert_eq!(web_search_call["action"]["query"], "rust async"); + assert_eq!( + web_search_call["action"]["sources"][0]["url"], + "https://example.com/rust" + ); + assert_eq!(web_search_call["action"]["sources"][0]["title"], "Rust async guide"); + assert!( + response + .output + .iter() + .any(|item| matches!(item, OutputItem::Message(_))) + ); +} + +#[tokio::test] +async fn execute_relaxes_forced_tool_choice_after_web_search_result() { + let (you_url, mut captured_you, _you_handle) = spawn_mock_you().await; + let llm = support::MockServer::start_deque(vec![ + web_search_function_call_response(), + support::text_response("Use async carefully."), + ]) + .await; + let exec_ctx = build_exec_ctx(llm.url(), you_url).await; + let web_search: ResponsesTool = serde_json::from_value(serde_json::json!({"type": "web_search_preview"})).unwrap(); + let payload = RequestPayload { + model: "test-model".to_owned(), + input: ResponsesInput::Text("look up rust async".to_owned()), + instructions: None, + previous_response_id: None, + conversation_id: None, + tools: Some(vec![web_search]), + tool_choice: ToolChoice::Required, + stream: false, + store: true, + include: None, + temperature: None, + top_p: None, + max_output_tokens: None, + 
truncation: None, + metadata: None, + }; + + let result = ExecuteRequest::new(payload, Arc::clone(&exec_ctx)).run().await.unwrap(); + assert!(matches!(result, Either::Left(_))); + captured_you.recv().await.expect("mock You.com should receive request"); + + let request_bodies = llm.request_bodies().await; + assert_eq!(request_bodies.len(), 2); + assert_eq!(request_bodies[0]["tool_choice"], "required"); + assert!(request_bodies[1].get("tool_choice").is_none()); +} + +#[tokio::test] +async fn execute_returns_mixed_client_tool_calls_without_followup_model_request() { + let (you_url, mut captured_you, _you_handle) = spawn_mock_you().await; + let llm = support::MockServer::start_deque(vec![ + mixed_web_search_and_client_function_response(), + support::text_response("continued after mixed tools"), + ]) + .await; + let exec_ctx = build_exec_ctx(llm.url(), you_url).await; + let web_search: ResponsesTool = serde_json::from_value(serde_json::json!({"type": "web_search_preview"})).unwrap(); + let client_function: ResponsesTool = serde_json::from_value(serde_json::json!({ + "type": "function", + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}} + } + })) + .unwrap(); + let payload = RequestPayload { + model: "test-model".to_owned(), + input: ResponsesInput::Text("look up rust async and weather".to_owned()), + instructions: None, + previous_response_id: None, + conversation_id: None, + tools: Some(vec![web_search, client_function]), + tool_choice: ToolChoice::Auto, + stream: false, + store: true, + include: None, + temperature: None, + top_p: None, + max_output_tokens: None, + truncation: None, + metadata: None, + }; + + let result = ExecuteRequest::new(payload, Arc::clone(&exec_ctx)).run().await.unwrap(); + let Either::Left(response) = result else { + panic!("expected non-streaming response"); + }; + + assert_eq!(llm.request_bodies().await.len(), 1); + let search_request = captured_you.recv().await.expect("mock You.com 
should receive request"); + assert_eq!(search_request.body["query"], "rust async"); + let response_output = serde_json::to_value(&response.output).unwrap(); + let output_items = response_output.as_array().unwrap(); + assert_eq!(output_items[0]["type"], "web_search_call"); + assert_eq!(output_items[0]["action"]["query"], "rust async"); + assert!( + !output_items + .iter() + .any(|item| item["type"] == "function_call" && item["name"] == "web_search"), + "raw web_search function calls must stay internal" + ); + let function_names: Vec<&str> = response + .output + .iter() + .filter_map(|item| match item { + OutputItem::FunctionCall(call) => Some(call.name.as_str()), + _ => None, + }) + .collect(); + assert_eq!(function_names, ["get_weather"]); + + let continuation_payload = RequestPayload { + model: "test-model".to_owned(), + input: ResponsesInput::Text("continue".to_owned()), + instructions: None, + previous_response_id: Some(response.id), + conversation_id: None, + tools: None, + tool_choice: ToolChoice::Auto, + stream: false, + store: true, + include: None, + temperature: None, + top_p: None, + max_output_tokens: None, + truncation: None, + metadata: None, + }; + let continuation = ExecuteRequest::new(continuation_payload, exec_ctx).run().await.unwrap(); + assert!(matches!(continuation, Either::Left(_))); + let request_bodies = llm.request_bodies().await; + assert_eq!(request_bodies.len(), 2); + let second_input = request_bodies[1]["input"] + .as_array() + .expect("second request input array"); + let tool_output = second_input + .iter() + .find(|item| item["type"] == "function_call_output" && item["call_id"] == "call_search") + .expect("continuation includes persisted web_search output"); + assert!( + tool_output["output"] + .as_str() + .unwrap() + .contains("https://example.com/rust") + ); +} + +#[tokio::test] +async fn web_search_rejects_incompatible_domain_filters_before_calling_you() { + let (base_url, mut captured, _handle) = spawn_mock_you().await; + let 
handler = + WebSearchHandler::with_api_key(Arc::new(reqwest::Client::new()), "secret-you-key".to_owned(), &base_url); + + let err = handler + .execute( + "call_search", + "web_search", + r#"{"query":"rust async","exclude_domains":["example.com"]}"#, + &serde_json::json!({ + "type": "web_search_preview", + "filters": {"allowed_domains": ["rust-lang.org"]} + }), + ) + .await + .expect_err("allowed_domains and exclude_domains should be rejected"); + + assert!(err.to_string().contains("include_domains cannot be combined")); + assert!(captured.try_recv().is_err()); +} + +#[tokio::test] +async fn execute_accumulates_usage_across_web_search_model_rounds() { + let (you_url, mut captured_you, _you_handle) = spawn_mock_you().await; + let llm = support::MockServer::start_deque(vec![ + web_search_function_call_response_with_usage(10, 5), + text_response_with_usage("Use async carefully.", 7, 3), + ]) + .await; + let exec_ctx = build_exec_ctx(llm.url(), you_url).await; + let web_search: ResponsesTool = serde_json::from_value(serde_json::json!({"type": "web_search_preview"})).unwrap(); + let payload = RequestPayload { + model: "test-model".to_owned(), + input: ResponsesInput::Text("look up rust async".to_owned()), + instructions: None, + previous_response_id: None, + conversation_id: None, + tools: Some(vec![web_search]), + tool_choice: ToolChoice::Auto, + stream: false, + store: true, + include: None, + temperature: None, + top_p: None, + max_output_tokens: None, + truncation: None, + metadata: None, + }; + + let result = ExecuteRequest::new(payload, exec_ctx).run().await.unwrap(); + let Either::Left(response) = result else { + panic!("expected non-streaming response"); + }; + captured_you.recv().await.expect("mock You.com should receive request"); + + let usage = response.usage.expect("usage should be present"); + assert_eq!(usage.input_tokens, 17); + assert_eq!(usage.output_tokens, 8); + assert_eq!(usage.total_tokens, 25); + assert_eq!(usage.input_tokens_details.cached_tokens, 
4); + assert_eq!(usage.output_tokens_details.reasoning_tokens, 6); +} + +#[tokio::test] +async fn stream_emits_web_search_lifecycle_events_before_final_payload() { + let (you_url, mut captured_you, _you_handle) = spawn_mock_you().await; + let llm = support::MockServer::start_deque(vec![ + web_search_function_call_sse_response(), + text_sse_response("Use async carefully."), + ]) + .await; + let exec_ctx = build_exec_ctx(llm.url(), you_url).await; + let web_search: ResponsesTool = serde_json::from_value(serde_json::json!({"type": "web_search_preview"})).unwrap(); + let payload = RequestPayload { + model: "test-model".to_owned(), + input: ResponsesInput::Text("look up rust async".to_owned()), + instructions: None, + previous_response_id: None, + conversation_id: None, + tools: Some(vec![web_search]), + tool_choice: ToolChoice::Auto, + stream: true, + store: true, + include: None, + temperature: None, + top_p: None, + max_output_tokens: None, + truncation: None, + metadata: None, + }; + + let result = ExecuteRequest::new(payload, Arc::clone(&exec_ctx)).run().await.unwrap(); + let Either::Right(stream) = result else { + panic!("expected streaming response"); + }; + let chunks: Vec = stream.collect().await; + captured_you.recv().await.expect("mock You.com should receive request"); + + let json_events: Vec = chunks + .iter() + .filter_map(|chunk| { + let data = chunk.trim_end_matches('\n').strip_prefix("data: ")?; + if data == "[DONE]" { + return None; + } + serde_json::from_str(data).ok() + }) + .collect(); + let event_types: Vec<&str> = json_events.iter().filter_map(|event| event["type"].as_str()).collect(); + let expected_types = [ + "response.output_item.added", + "response.web_search_call.in_progress", + "response.web_search_call.searching", + "response.web_search_call.completed", + "response.output_item.done", + ]; + let mut last_index = 0; + for expected in expected_types { + let index = event_types + .iter() + .enumerate() + .skip(last_index) + .find_map(|(index, 
actual)| (*actual == expected).then_some(index)) + .unwrap_or_else(|| panic!("missing streaming event {expected}; got {event_types:?}")); + last_index = index + 1; + } + + let final_payload = json_events + .iter() + .find(|event| event["object"] == "response") + .expect("stream should include final response payload"); + let output = final_payload["output"].as_array().unwrap(); + assert!(output.iter().any(|item| item["type"] == "web_search_call")); + assert!(output.iter().any(|item| item["type"] == "message")); +} + +#[tokio::test] +async fn execute_runs_multiple_web_search_calls_concurrently() { + let (you_url, mut captured_you, _you_handle) = spawn_mock_you_waiting_for_two_searches().await; + let llm = support::MockServer::start_deque(vec![ + two_web_search_function_call_response(), + support::text_response("Use async carefully."), + ]) + .await; + let exec_ctx = build_exec_ctx(llm.url(), you_url).await; + let web_search: ResponsesTool = serde_json::from_value(serde_json::json!({"type": "web_search_preview"})).unwrap(); + let payload = RequestPayload { + model: "test-model".to_owned(), + input: ResponsesInput::Text("look up rust async and tokio streams".to_owned()), + instructions: None, + previous_response_id: None, + conversation_id: None, + tools: Some(vec![web_search]), + tool_choice: ToolChoice::Auto, + stream: false, + store: true, + include: None, + temperature: None, + top_p: None, + max_output_tokens: None, + truncation: None, + metadata: None, + }; + + let result = tokio::time::timeout(Duration::from_secs(2), ExecuteRequest::new(payload, exec_ctx).run()) + .await + .expect("gateway calls should execute concurrently instead of waiting on the first search") + .unwrap(); + assert!(matches!(result, Either::Left(_))); + + let mut queries = Vec::new(); + for _ in 0..2 { + let request = captured_you.recv().await.expect("mock You.com should receive request"); + queries.push(request.body["query"].as_str().unwrap().to_owned()); + } + queries.sort(); + 
assert_eq!(queries, ["rust async", "tokio streams"]); +} + +#[tokio::test] +async fn execute_feeds_web_search_execution_errors_back_to_model() { + let (you_url, mut captured_you, _you_handle) = spawn_mock_you_with_response( + StatusCode::INTERNAL_SERVER_ERROR, + serde_json::json!({"error": "search backend down"}), + ) + .await; + let llm = support::MockServer::start_deque(vec![ + web_search_function_call_response(), + support::text_response("I could not search live web results."), + ]) + .await; + let exec_ctx = build_exec_ctx(llm.url(), you_url).await; + let web_search: ResponsesTool = serde_json::from_value(serde_json::json!({"type": "web_search_preview"})).unwrap(); + let payload = RequestPayload { + model: "test-model".to_owned(), + input: ResponsesInput::Text("look up rust async".to_owned()), + instructions: None, + previous_response_id: None, + conversation_id: None, + tools: Some(vec![web_search]), + tool_choice: ToolChoice::Auto, + stream: false, + store: true, + include: None, + temperature: None, + top_p: None, + max_output_tokens: None, + truncation: None, + metadata: None, + }; + + let result = ExecuteRequest::new(payload, exec_ctx).run().await.unwrap(); + assert!(matches!(result, Either::Left(_))); + captured_you.recv().await.expect("mock You.com should receive request"); + + let request_bodies = llm.request_bodies().await; + assert_eq!(request_bodies.len(), 2); + let second_input = request_bodies[1]["input"] + .as_array() + .expect("second request input array"); + let tool_output = second_input + .iter() + .find(|item| item["type"] == "function_call_output") + .expect("second request includes web_search error output"); + let output_json: serde_json::Value = serde_json::from_str(tool_output["output"].as_str().unwrap()).unwrap(); + assert!( + output_json["error"] + .as_str() + .unwrap() + .contains("You.com search returned 500 Internal Server Error") + ); +} + +#[tokio::test] +async fn execute_errors_after_max_gateway_tool_rounds() { + let (you_url, mut 
captured_you, _you_handle) = spawn_mock_you().await; + let llm_responses = std::iter::repeat_with(web_search_function_call_response) + .take(10) + .collect(); + let llm = support::MockServer::start_deque(llm_responses).await; + let exec_ctx = build_exec_ctx(llm.url(), you_url).await; + let web_search: ResponsesTool = serde_json::from_value(serde_json::json!({"type": "web_search_preview"})).unwrap(); + let payload = RequestPayload { + model: "test-model".to_owned(), + input: ResponsesInput::Text("look up rust async".to_owned()), + instructions: None, + previous_response_id: None, + conversation_id: None, + tools: Some(vec![web_search]), + tool_choice: ToolChoice::Auto, + stream: false, + store: true, + include: None, + temperature: None, + top_p: None, + max_output_tokens: None, + truncation: None, + metadata: None, + }; + + let result = ExecuteRequest::new(payload, exec_ctx).run().await; + assert!( + result + .err() + .is_some_and(|err| err.to_string().contains("gateway tool execution exceeded 10 rounds")) + ); + for _ in 0..10 { + captured_you.recv().await.expect("mock You.com should receive request"); + } + assert!(captured_you.try_recv().is_err()); + assert_eq!(llm.request_bodies().await.len(), 10); +} From b87aacc81642ede052f6d5dfe917d474b1412316 Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Thu, 2 Jul 2026 20:12:28 -0700 Subject: [PATCH 2/2] fix: require you api base url Signed-off-by: Francisco Javier Arceo --- crates/agentic-core/src/tool/web_search.rs | 30 +++++++++++-------- .../tests/web_search_tool_test.rs | 20 +++++++++++++ 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/crates/agentic-core/src/tool/web_search.rs b/crates/agentic-core/src/tool/web_search.rs index 88471d7..53a1fc4 100644 --- a/crates/agentic-core/src/tool/web_search.rs +++ b/crates/agentic-core/src/tool/web_search.rs @@ -12,9 +12,8 @@ use crate::utils::common::serialize_to_string; use super::handler::{GatewayExecutor, ToolError, ToolHandler, ToolOutput}; use 
super::registry::ToolType; -const DEFAULT_YOU_SEARCH_BASE_URL: &str = "https://ydc-index.io"; -const YOU_API_KEY_ENV: &str = "YOU_API_KEY"; -const YOU_API_BASE_URL_ENV: &str = "YOU_API_BASE_URL"; +const YOU_API_KEY: &str = "YOU_API_KEY"; +const YOU_API_BASE_URL: &str = "YOU_API_BASE_URL"; #[must_use] pub(crate) fn web_search_function_tool() -> FunctionTool { @@ -68,21 +67,19 @@ pub(crate) fn web_search_function_tool() -> FunctionTool { pub struct WebSearchHandler { client: Arc, api_key: Option, - base_url: String, + base_url: Option, } impl WebSearchHandler { #[must_use] pub fn from_env(client: Arc) -> Self { - let api_key = std::env::var(YOU_API_KEY_ENV) + let api_key = std::env::var(YOU_API_KEY) .ok() .map(|value| value.trim().to_owned()) .filter(|value| !value.is_empty()); - let base_url = std::env::var(YOU_API_BASE_URL_ENV) + let base_url = std::env::var(YOU_API_BASE_URL) .ok() - .map(|value| value.trim().trim_end_matches('/').to_owned()) - .filter(|value| !value.is_empty()) - .unwrap_or_else(|| DEFAULT_YOU_SEARCH_BASE_URL.to_owned()); + .and_then(|value| clean_base_url(&value)); Self { client, api_key, @@ -95,7 +92,7 @@ impl WebSearchHandler { Self { client, api_key: Some(api_key), - base_url: base_url.trim_end_matches('/').to_owned(), + base_url: clean_base_url(base_url), } } @@ -103,12 +100,16 @@ impl WebSearchHandler { let api_key = self .api_key .as_deref() - .ok_or_else(|| ToolError::Config(format!("{YOU_API_KEY_ENV} must be set to use the web_search tool")))?; + .ok_or_else(|| ToolError::Config(format!("{YOU_API_KEY} must be set to use the web_search tool")))?; + let base_url = self + .base_url + .as_deref() + .ok_or_else(|| ToolError::Config(format!("{YOU_API_BASE_URL} must be set to use the web_search tool")))?; let args = WebSearchArguments::from_json(arguments)?; let config = serde_json::from_value::(config.clone()) .map_err(|e| ToolError::Config(format!("invalid web_search config: {e}")))?; let request = YouSearchRequest::from_args_and_config(&args, 
&config)?; - let url = format!("{}/v1/search", self.base_url); + let url = format!("{base_url}/v1/search"); let body = serialize_to_string(&request) .map_err(|e| ToolError::Execution(format!("failed to serialize web_search request: {e}")))?; @@ -322,6 +323,11 @@ fn clean_string(value: Option<&str>) -> Option { .map(str::to_owned) } +fn clean_base_url(value: &str) -> Option { + let trimmed = value.trim().trim_end_matches('/'); + (!trimmed.is_empty()).then(|| trimmed.to_owned()) +} + fn clean_vec(values: Option<&[String]>) -> Option> { let cleaned: Vec = values .unwrap_or_default() diff --git a/crates/agentic-core/tests/web_search_tool_test.rs b/crates/agentic-core/tests/web_search_tool_test.rs index bffe42b..264578d 100644 --- a/crates/agentic-core/tests/web_search_tool_test.rs +++ b/crates/agentic-core/tests/web_search_tool_test.rs @@ -179,6 +179,26 @@ async fn web_search_handler_posts_to_you_and_formats_results() { assert_eq!(output_json["metadata"]["search_uuid"], "search_123"); } +#[tokio::test] +async fn web_search_handler_requires_base_url() { + let handler = WebSearchHandler::with_api_key(Arc::new(reqwest::Client::new()), "secret-you-key".to_owned(), ""); + + let err = handler + .execute( + "call_search", + "web_search", + r#"{"query":"rust async"}"#, + &serde_json::json!({"type":"web_search_preview"}), + ) + .await + .unwrap_err(); + + assert_eq!( + err.to_string(), + "invalid tool config: YOU_API_BASE_URL must be set to use the web_search tool" + ); +} + fn web_search_function_call_response() -> support::MockResponse { support::MockResponse::Json( serde_json::json!({