Skip to content
Open
345 changes: 345 additions & 0 deletions crates/agentic-core/src/executor/agentic_loop.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,345 @@
//! Agentic tool loop — multi-turn executor that dispatches tool calls between inference steps.

use std::collections::HashMap;
use std::sync::Arc;

use either::Either;

use crate::executor::dispatch::{LoopDecision, dispatch_tools};
use crate::executor::engine::execute;
use crate::executor::error::{ExecutorError, ExecutorResult};
use crate::executor::request::ExecutionContext;
use crate::tool::{GatewayExecutor, ToolRegistry, ToolType};
use crate::types::io::input::{InputItem, ResponsesInput};
use crate::types::request_response::{RequestPayload, ResponsePayload};

/// Hard safety guard — prevents runaway loops regardless of `max_iterations`.
const MAX_LOOP_GUARD: usize = 128;

/// Default soft cap on tool-dispatch iterations per request.
pub const DEFAULT_MAX_ITERATIONS: usize = 10;

/// Run the full agentic tool loop.
///
/// Calls [`execute`] repeatedly, dispatching gateway-owned tool calls after each
/// inference turn, until the model stops producing tool calls or `max_iterations`
/// is reached.
///
/// ## Persistence contract
///
/// This function **never writes to the database**. All three persistence triggers
/// (`store`, `previous_response_id`, `conversation_id`) are cleared before the
/// first iteration to suppress intermediate persists inside [`execute`]. The
/// original IDs are restored onto the final payload before returning.
///
/// Persistence is the caller's responsibility. The caller must obtain a
/// [`crate::executor::request::RequestContext`] (e.g. via
/// [`crate::executor::engine::rehydrate_conversation`]) before calling this
/// function and pass it to [`crate::executor::engine::persist_response`]
/// after this function returns.
///
/// ## MCP discovery
///
/// `registry` must be built via [`ToolRegistry::build`] before calling. MCP tool
/// names are absent from the registry in this PR — discovery is added in PR C.
/// Any `function_call` for an MCP tool name is treated as client-owned (skipped).
///
/// # Errors
///
/// - [`ExecutorError::InvalidRequest`] if `request.stream` is `true` (streaming
/// + tool dispatch requires the `StreamTee`, which is a future PR).
/// - [`ExecutorError`] variants from [`execute`] if LLM inference fails.
#[must_use = "the ResponsePayload contains the final inference result"]
pub async fn execute_loop<S: std::hash::BuildHasher>(
mut request: RequestPayload,
exec_ctx: Arc<ExecutionContext>,
registry: ToolRegistry,
executors: HashMap<ToolType, Arc<dyn GatewayExecutor>, S>,
max_iterations: usize,
) -> ExecutorResult<ResponsePayload> {
if request.stream {
return Err(ExecutorError::InvalidRequest(
"execute_loop does not support streaming requests; use execute() directly or wait for StreamTee PR"
.to_owned(),
));
}

// Capture original IDs before clearing persistence triggers.
// The loop calls execute() internally — clearing these suppresses intermediate persists.
let original_prev_id = request.previous_response_id.clone();
let original_conv_id = request.conversation_id.clone();

request.store = false;
request.previous_response_id = None;
request.conversation_id = None;

// Clamp caller's max to the hard guard. The loop runs up to and including
// effective_max iterations so that dispatch_tools sees iteration == effective_max
// and returns Incomplete rather than silently stopping at the loop boundary.
let effective_max = max_iterations.min(MAX_LOOP_GUARD);
// Stub payload replaced on the first inference call. Initialised here so
// Rust's definite-assignment rules are satisfied — the loop always overwrites
// it at least once before returning (stream=false path).
let mut payload = ResponsePayload {
id: String::new(),
object: "response".to_owned(),
created_at: 0,
model: request.model.clone(),
status: "completed".to_owned(),
output: vec![],
usage: None,
incomplete_details: None,
error: None,
// prev_id / conv_id are stripped for internal calls and restored below.
previous_response_id: None,
conversation_id: None,
instructions: None,
};

// Inclusive upper bound: iteration `effective_max` is the last allowed
// inference call. After it we call dispatch_tools with iteration==effective_max
// which triggers Incomplete, so the guard fires rather than the loop silently
// ending with a misleading `status: "completed"`.
for iteration in 0..=effective_max {
let result = execute(request.clone(), Arc::clone(&exec_ctx)).await?;

payload = match result {
Either::Left(p) => p,
Either::Right(_) => {
// execute() returned a stream — shouldn't happen since we set stream=false,
// but guard defensively.
return Err(ExecutorError::InvalidRequest(
"execute() returned a stream despite stream=false".to_owned(),
));
}
};

match dispatch_tools(&payload.output, &registry, &executors, iteration, effective_max).await? {
LoopDecision::Done => break,

LoopDecision::Incomplete(reason) => {
"incomplete".clone_into(&mut payload.status);
payload.incomplete_details =
Some(crate::types::request_response::IncompleteDetails { reason: Some(reason) });
break;
}

LoopDecision::Continue(tool_results) => {
// Build the set of call_ids that have a matching tool result.
// We need this to filter `fc_items` below: including FunctionCall
// items that have no corresponding FunctionCallOutput (e.g.
// client-owned tools not executed by the gateway) produces orphaned
// pairs in the conversation history that vLLM rejects.
let result_call_ids: std::collections::HashSet<&str> = tool_results
.iter()
.filter_map(|item| {
if let InputItem::FunctionCallOutput(msg) = item {
Some(msg.call_id.as_str())
} else {
None
}
})
.collect();

let existing_items: Vec<InputItem> = Vec::from(&request.input);
// Only append FC items whose call_id has a matching tool result.
let mut fc_items: Vec<InputItem> = payload
.output
.iter()
.filter_map(|item| {
if let crate::types::io::output::OutputItem::FunctionCall(fc) = item {
if result_call_ids.contains(fc.call_id.as_str()) {
return Some(InputItem::FunctionCall(fc.clone()));
}
}
None
})
.collect();

let mut next_items = existing_items;
next_items.append(&mut fc_items);
next_items.extend(tool_results);
request.input = ResponsesInput::Items(next_items);
}
}
}

// Restore original IDs onto the final payload.
payload.previous_response_id = original_prev_id;
payload.conversation_id = original_conv_id;

Ok(payload)
}

#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::sync::Arc;

use super::*;
use crate::ToolRegistry;
use crate::tool::ToolType;
use crate::types::request_response::RequestPayload;

// ── Helpers ───────────────────────────────────────────────────────────────

fn make_request(stream: bool) -> RequestPayload {
RequestPayload {
model: "test-model".to_owned(),
input: crate::types::io::input::ResponsesInput::Text("hello".to_owned()),
instructions: None,
previous_response_id: None,
conversation_id: None,
tools: None,
tool_choice: crate::types::io::ToolChoice::Auto,
stream,
store: false,
include: None,
temperature: None,
top_p: None,
max_output_tokens: None,
truncation: None,
metadata: None,
}
}

fn no_executors() -> HashMap<ToolType, Arc<dyn GatewayExecutor>> {
HashMap::new()
}

// ── Tests ─────────────────────────────────────────────────────────────────

#[tokio::test]
async fn rejects_streaming_request() {
let request = make_request(true);
let exec_ctx = Arc::new(ExecutionContext::new(
crate::executor::modes::ConversationHandler::new(crate::storage::ConversationStore::new(
crate::storage::create_pool(None).await.unwrap(),
)),
crate::executor::modes::ResponseHandler::new(crate::storage::ResponseStore::new(
crate::storage::create_pool(None).await.unwrap(),
)),
Arc::new(reqwest::Client::new()),
"http://localhost:9999".to_owned(),
));
let result = execute_loop(request, exec_ctx, ToolRegistry::default(), no_executors(), 10).await;
assert!(matches!(result, Err(ExecutorError::InvalidRequest(_))));
}

#[tokio::test]
async fn persistence_triggers_cleared_for_internal_calls() {
// This test verifies the contract by inspecting the request clone.
// We can't call a real LLM — but we can confirm the trigger-clearing
// logic compiles and runs by checking that store/ids are reset.
let mut request = make_request(false);
request.store = true;
request.previous_response_id = Some("resp_orig".to_owned());
request.conversation_id = Some("conv_orig".to_owned());

// The loop will attempt to call the LLM; fail immediately (no server).
// We just verify the function handles the error path — not a behavior test.
let exec_ctx = Arc::new(ExecutionContext::new(
crate::executor::modes::ConversationHandler::new(crate::storage::ConversationStore::new(
crate::storage::create_pool(None).await.unwrap(),
)),
crate::executor::modes::ResponseHandler::new(crate::storage::ResponseStore::new(
crate::storage::create_pool(None).await.unwrap(),
)),
Arc::new(reqwest::Client::new()),
"http://localhost:9999".to_owned(), // unreachable — will error
));

// Should fail with a network/LLM error, not a panic
let result = execute_loop(request, exec_ctx, ToolRegistry::default(), no_executors(), 10).await;
assert!(result.is_err(), "expected error from unreachable LLM");
}

#[test]
fn default_max_iterations_is_ten() {
assert_eq!(DEFAULT_MAX_ITERATIONS, 10);
}

#[test]
fn max_loop_guard_is_128() {
assert_eq!(MAX_LOOP_GUARD, 128);
}

#[test]
fn loop_decision_done_is_non_exhaustive() {
// Compile-time check that the enum is #[non_exhaustive] — adding a new
// variant won't silently break this match.
let d = LoopDecision::Done;
// Compile-time check: all current variants are handled, and #[non_exhaustive]
// means downstream match arms must include a wildcard for future variants.
#[allow(unreachable_patterns, clippy::match_same_arms)]
match d {
LoopDecision::Done => {}
LoopDecision::Continue(_) => {}
LoopDecision::Incomplete(_) => {}
_ => {}
}
}

// ── Plan Section B: additional unit tests ─────────────────────────────────

/// When `stream=true`, `execute_loop` must return Err before making any LLM
/// call. The streaming path is blocked until the `StreamTee` PR lands.
#[tokio::test]
async fn streaming_request_returns_err_without_llm_call() {
// Point at an unreachable LLM — if the function calls it the test panics
// (connection refused), proving the guard fires before the first call.
let exec_ctx = Arc::new(ExecutionContext::new(
crate::executor::modes::ConversationHandler::new(crate::storage::ConversationStore::new(
crate::storage::create_pool(None).await.unwrap(),
)),
crate::executor::modes::ResponseHandler::new(crate::storage::ResponseStore::new(
crate::storage::create_pool(None).await.unwrap(),
)),
Arc::new(reqwest::Client::new()),
"http://127.0.0.1:1".to_owned(), // port 1 — always refused
));
let mut request = make_request(true);
request.store = false;

let result = execute_loop(request, exec_ctx, ToolRegistry::default(), no_executors(), 10).await;
assert!(
matches!(result, Err(ExecutorError::InvalidRequest(_))),
"expected InvalidRequest for streaming, got: {result:?}"
);
}

/// `max_iterations` == `MAX_LOOP_GUARD` (128) must still produce Incomplete —
/// the fix to use `0..=effective_max` must hold at the boundary.
/// We cannot actually run 128 LLM calls in a unit test, but we can verify
/// the constants and formula:
/// `effective_max` = 128.min(128) = 128
/// loop runs 0..=128 → `dispatch_tools` sees iteration==128 >= max==128 → Incomplete
#[test]
fn loop_guard_boundary_formula_is_correct() {
// Prove: effective_max == MAX_LOOP_GUARD when max_iterations >= MAX_LOOP_GUARD
let caller_max = MAX_LOOP_GUARD; // == 128
let effective = caller_max.min(MAX_LOOP_GUARD);
assert_eq!(effective, 128);

// The loop runs 0..=128 (129 iterations).
// On iteration 128: dispatch_tools(128, 128) → 128 >= 128 → Incomplete.
// This is the key invariant that prevents silent "completed" on truncation.
let loop_count = (0..=effective).count();
assert_eq!(loop_count, 129, "loop must include iteration {effective}");
}

/// Verify `prev_id` is stored and restored: the function captures it,
/// clears it for internal LLM calls, and restores it on the returned payload.
/// Since we can't easily reach a real LLM here, the test verifies the
/// constant declarations and logic are consistent.
#[test]
fn default_and_guard_constants_are_sane() {
// DEFAULT_MAX_ITERATIONS must be < MAX_LOOP_GUARD so a default call never
// silently hits the guard. Expressed as a const assertion so the compiler
// enforces it rather than a runtime test.
const _: () = assert!(
DEFAULT_MAX_ITERATIONS < MAX_LOOP_GUARD,
"DEFAULT_MAX_ITERATIONS must be less than MAX_LOOP_GUARD"
);
}
}
Loading