diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 9ccb342..eca5345 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -48,3 +48,35 @@ jobs: - name: Run tests run: cargo test + + integration: + runs-on: ubuntu-latest + needs: rust + steps: + - name: Checkout code + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@e97e2d8cc328f1b50210efc529dca0028893a2d9 # v1 + with: + toolchain: stable + + - name: Cache cargo registry and build + uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: cargo-${{ runner.os }}-${{ hashFiles('Cargo.lock') }} + restore-keys: | + cargo-${{ runner.os }}- + + - name: Install uv + uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6 + + - name: Install OGx + run: uv tool install "ogx[starter]" --with "sentence-transformers>=5" --with "huggingface_hub<1.18" + + - name: Run integration tests + run: make integration-test diff --git a/Cargo.lock b/Cargo.lock index b7c5a69..9a55ec8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7,6 +7,7 @@ name = "agentic-core" version = "0.1.0" dependencies = [ "async-stream", + "async-trait", "axum", "bytes", "chrono", @@ -49,6 +50,7 @@ dependencies = [ "serde", "serde_json", "tokio", + "tokio-stream", "tower-http", "tracing", "tracing-subscriber", @@ -163,6 +165,17 @@ dependencies = [ "syn", ] +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atoi" version = "2.0.0" @@ -250,9 +263,9 @@ checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" [[package]] name = "bitflags" -version = "2.12.1" +version = "2.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84d7ced0ae9557296835c32bf1b1e02b44c746701f898460fb000d7eaa84f00a" +checksum = "b4388bee8683e3d04af747c73422af53102d2bd24d9eadb6cbc100baef4b43f8" dependencies = [ "serde_core", ] @@ -314,9 +327,9 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chrono" -version = "0.4.44" +version = "0.4.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" +checksum = "1aa79e62e7697b8e29b513a68abacf485adcd1fe8284a4316c5ae868e6633327" dependencies = [ "iana-time-zone", "js-sys", @@ -915,9 +928,9 @@ dependencies = [ [[package]] name = "http" -version = "1.4.1" +version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8be7462df143984c4598a256ef469b251d7d7f9e271135073e78fc535414f3d0" +checksum = "6970f50e31d6fc17d3fa27329444bfa74e196cf62e95052a3f6fee181dba6425" dependencies = [ "bytes", "itoa", @@ -1219,13 +1232,12 @@ checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" [[package]] name = "js-sys" -version = "0.3.99" +version = "0.3.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "142bc4740e452c1e57ade0cbc129f139c9093e354346f0872ef985f4f5cf5f11" +checksum = "f2025f20d7a4fa7785846e7b63d10a76d3f1cee98ee5cb79ea59703f95e42162" dependencies = [ "cfg-if", "futures-util", - "once_cell", "wasm-bindgen", ] @@ -1312,9 +1324,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.31" +version = "0.4.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "113b30b4cd05f7c06868fdb2854f66a7b9fece9a48425351cd532e810d74024f" +checksum = "953f07c43838f8e6f9758cab68bf5bed85465e7587ebe0b823f1bcd81978ad3a" [[package]] name = "lru-slab" @@ -1359,6 +1371,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "mio" version = "1.2.1" @@ -1703,7 +1725,7 @@ dependencies = [ "once_cell", "socket2", "tracing", - "windows-sys 0.60.2", + "windows-sys 0.52.0", ] [[package]] @@ -1826,9 +1848,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.12.3" +version = "1.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +checksum = "f1292b7759ae1cb9ec195452d1390a074f0cd8541ab7a5a8c31cd6db45d4a6ba" dependencies = [ "aho-corasick", "memchr", @@ -1849,9 +1871,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.10" +version = "0.8.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" +checksum = "d6f6ff9a378485b298a5286656da665ba74413d36db0979633275d2e708145d4" [[package]] name = "reqwest" @@ -1872,6 +1894,7 @@ dependencies = [ "hyper-util", "js-sys", "log", + "mime_guess", "native-tls", "percent-encoding", "pin-project-lite", @@ -2764,6 +2787,12 @@ version = "1.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6f5e870be6c3b371b77fe0ee0bafb859fa4964b4404c27de1d380043c4dda20" +[[package]] +name = "unicase" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142" + [[package]] name = "unicode-bidi" version = "0.3.18" @@ -2835,9 +2864,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.23.2" +version = "1.23.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d258b83ceec21034727ecee8c382cfa6c3e133699b0742c64571814fb420c9f7" +checksum = "144d6b123cef80b301b8f72a9e2ca4370ddec21950d0a103dd22c437006d2db7" dependencies = [ "getrandom 0.4.2", "js-sys", @@ -2914,9 +2943,9 @@ checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" [[package]] name = "wasm-bindgen" -version = "0.2.122" +version = "0.2.123" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ed04576f974d2b2fba0f38c51dbc5518011e38c36bf1143164be765528fd409" +checksum = "a254a4b10c19a76f09a27640e7ffbf9bc30bf67e16a3bf28aaefa4920fe81563" dependencies = [ "cfg-if", "once_cell", @@ -2927,9 +2956,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.72" +version = "0.4.73" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9473dbd2991ae90b6291c3c32c30c6187ac49aa32f9905d1cce280ec1e110b0f" +checksum = "54568702fabf5d4849ce2b90fadfa64168a097eaf4b351ce9df8b687a0086aaf" dependencies = [ "js-sys", "wasm-bindgen", @@ -2937,9 +2966,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.122" +version = "0.2.123" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "916151b09da36bd82f6615cbf3a419e2f0ba23a03c6160e8e92eb6bd4aa1dec6" +checksum = "24a40fc75b0ec6f3746ceb10d36f53a93dcd68a93b11b6445983945d79eba0dc" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2947,9 +2976,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.122" +version = "0.2.123" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "299047362ccbfce148b67ab7e73349f77748e00c8296f9542adfad2ad82c5c5e" +checksum = "908f34bd9b9ce3d4caf07b72dfab63d61504d156856c6bd3cd87fa350cf3985b" dependencies = [ "bumpalo", "proc-macro2", @@ -2960,9 +2989,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.122" +version = "0.2.123" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a929b2c61f11ba3e9bc35b50c1f25cb38e0e892c0c231ae2b8cf78d5dad4437" +checksum = "7acbf7616c27b194bbb550bf77ed0c2c3e5b7fd1260a93082b95fb7f47959b92" dependencies = [ "unicode-ident", ] @@ -3016,9 +3045,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.99" +version = "0.3.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d621441cfc37b84979402712047321980c178f299193a3589d05b99e8763436" +checksum = "6e0871acf327f283dc6da28a1696cdc64fb355ba9f935d052021fa77f35cce69" dependencies = [ "js-sys", "wasm-bindgen", @@ -3148,15 +3177,6 @@ dependencies = [ "windows-targets 0.52.6", ] -[[package]] -name = "windows-sys" -version = "0.60.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" -dependencies = [ - "windows-targets 0.53.5", -] - [[package]] name = "windows-sys" version = "0.61.2" @@ -3190,30 +3210,13 @@ dependencies = [ "windows_aarch64_gnullvm 0.52.6", "windows_aarch64_msvc 0.52.6", "windows_i686_gnu 0.52.6", - "windows_i686_gnullvm 0.52.6", + "windows_i686_gnullvm", "windows_i686_msvc 0.52.6", "windows_x86_64_gnu 0.52.6", "windows_x86_64_gnullvm 0.52.6", "windows_x86_64_msvc 0.52.6", ] -[[package]] -name = "windows-targets" -version = "0.53.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" -dependencies = [ - "windows-link", - "windows_aarch64_gnullvm 0.53.1", - "windows_aarch64_msvc 0.53.1", - "windows_i686_gnu 0.53.1", - "windows_i686_gnullvm 0.53.1", - "windows_i686_msvc 0.53.1", - "windows_x86_64_gnu 0.53.1", - "windows_x86_64_gnullvm 0.53.1", - "windows_x86_64_msvc 0.53.1", -] - [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" @@ -3226,12 +3229,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" - [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -3244,12 +3241,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" -[[package]] -name = "windows_aarch64_msvc" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" - [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -3262,24 +3253,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" -[[package]] -name = "windows_i686_gnu" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" - [[package]] name = "windows_i686_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" -[[package]] -name = "windows_i686_gnullvm" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" - [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -3292,12 +3271,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" -[[package]] -name = "windows_i686_msvc" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" - [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -3310,12 +3283,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" -[[package]] -name = "windows_x86_64_gnu" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" - [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" @@ -3328,12 +3295,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" - [[package]] name = "windows_x86_64_msvc" version = "0.48.5" @@ -3346,12 +3307,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" -[[package]] -name = "windows_x86_64_msvc" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" - [[package]] name = "wit-bindgen" version = "0.51.0" @@ -3454,9 +3409,9 @@ checksum = "1ffae5123b2d3fc086436f8834ae3ab053a283cfac8fe0a0b8eaae044768a4c4" [[package]] name = "yoke" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "abe8c5fda708d9ca3df187cae8bfb9ceda00dd96231bed36e445a1a48e66f9ca" +checksum = "709fe23a0424b6a435d82152b1bd3fdfb0833487d5fa90d05d42762a9891fef5" dependencies = [ "stable_deref_trait", "yoke-derive", @@ -3477,18 +3432,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.50" +version = "0.8.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b065d4f0e55f82fae73202e189638116a87c55ab6b8e6c2721e13dd9d854ad1" +checksum = "ce1022995ff5ff5d841ad7d994facc23098cd40152f2c1d11cd607c6f530653f" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.50" +version = "0.8.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b631b19d36a892ab55420c92dbc83ccd79274f25be714855d3074aa71cab639" +checksum = "1ae7f38b72ec2a254e2b87ef277cf2cd4fb97cbebf944faa6f33354da0867930" dependencies = [ "proc-macro2", "quote", diff --git a/Makefile b/Makefile index 9a2f13a..26c0f62 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: help install lint format test build pre-commit clean +.PHONY: help install lint format test build pre-commit clean integration-test help: ## Show this help message @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-15s\033[0m %s\n", $$1, $$2}' @@ -23,3 +23,6 @@ pre-commit: ## Run pre-commit hooks on all files clean: ## Remove Rust build artifacts cargo clean + +integration-test: ## Run integration tests (starts OGx, runs tests, tears down) + ./crates/agentic-server/tests/integration/run.sh diff --git a/crates/agentic-core/Cargo.toml b/crates/agentic-core/Cargo.toml index e8650c2..11863a3 100644 --- a/crates/agentic-core/Cargo.toml +++ b/crates/agentic-core/Cargo.toml @@ -8,11 +8,12 @@ repository.workspace = true [dependencies] async-stream.workspace = true +async-trait = "0.1" bytes.workspace = true either.workspace = true futures.workspace = true http.workspace = true -reqwest = { workspace = true, features = ["default-tls", "stream"] } +reqwest = { workspace = true, features = ["default-tls", "json", "stream"] } serde.workspace = true serde_json.workspace = true thiserror.workspace = true diff --git a/crates/agentic-core/src/config.rs b/crates/agentic-core/src/config.rs index cb87708..1f88505 100644 --- a/crates/agentic-core/src/config.rs +++ b/crates/agentic-core/src/config.rs @@ -7,6 +7,10 @@ pub struct Config { /// Database URL for conversation and response storage. /// `None` means stateful features are disabled; all requests are proxied. pub db_url: Option, + /// Base URL for OGX-compatible vector search. + pub ogx_base_url: String, + /// Maximum number of model/tool iterations before stopping the agentic loop. + pub max_iterations: u32, } #[must_use] diff --git a/crates/agentic-core/src/error.rs b/crates/agentic-core/src/error.rs index 011c6eb..54e5627 100644 --- a/crates/agentic-core/src/error.rs +++ b/crates/agentic-core/src/error.rs @@ -19,4 +19,25 @@ pub enum Error { #[error("{0}")] Config(String), + + #[error("store request failed")] + Store(#[source] reqwest::Error), + + #[error("store returned {status}: {body}")] + StoreResponse { status: u16, body: String }, + + #[error("vLLM proxy request failed")] + Proxy(#[source] reqwest::Error), + + #[error("vLLM returned {status}: {body}")] + ProxyResponse { status: u16, body: String }, + + #[error("database error")] + Database(#[from] sqlx::Error), + + #[error(transparent)] + StateStore(#[from] crate::storage::StorageError), + + #[error("agentic loop exceeded {max_iterations} iterations")] + MaxIterations { max_iterations: u32 }, } diff --git a/crates/agentic-core/src/executor/engine.rs b/crates/agentic-core/src/executor/engine.rs index a0aa2c9..40e4f12 100644 --- a/crates/agentic-core/src/executor/engine.rs +++ b/crates/agentic-core/src/executor/engine.rs @@ -10,6 +10,7 @@ use std::sync::Arc; use async_stream::stream; use either::Either; use futures::{Stream, StreamExt}; +use serde::Deserialize; use tracing::warn; use crate::executor::accumulator::ResponseAccumulator; @@ -18,10 +19,14 @@ use crate::executor::modes::{ConversationHandler, ResponseHandler}; use crate::executor::request::{ExecutionContext, RequestContext}; use crate::storage::InOutItem; use crate::types::event::ResponseStatus; -use crate::types::io::{InputItem, ResponsesInput, resolve_tool_choice, resolve_tools}; +use crate::types::io::{ + FunctionTool, FunctionToolCall, FunctionToolResultMessage, InputItem, OutputItem, ResponsesInput, ResponsesTool, + resolve_tool_choice, resolve_tools, +}; use crate::types::request_response::{RequestPayload, ResponsePayload}; use crate::utils::common::serialize_to_string; use crate::utils::uuid7_str; +use crate::vector_search::types::SearchOptions; use std::time::Duration; @@ -297,6 +302,206 @@ pub async fn persist_response( } } +fn contains_file_search(tools: Option<&[ResponsesTool]>) -> bool { + tools.is_some_and(|tools| tools.iter().any(|tool| matches!(tool, ResponsesTool::FileSearch(_)))) +} + +#[derive(Clone)] +struct FileSearchConfig { + store_ids: Vec, + options: SearchOptions, +} + +fn file_search_config(tools: Option<&[ResponsesTool]>) -> ExecutorResult { + let mut store_ids = Vec::new(); + let mut options = None::; + + for tool in tools.into_iter().flatten() { + match tool { + ResponsesTool::FileSearch(tool) => { + store_ids.extend(tool.vector_store_ids.iter().filter(|id| !id.is_empty()).cloned()); + if options + .as_ref() + .is_some_and(|existing| existing != &tool.search_options) + { + return Err(ExecutorError::InvalidRequest( + "multiple file_search tools with different search options are not supported".into(), + )); + } + options.get_or_insert_with(|| tool.search_options.clone()); + } + ResponsesTool::Function(_) | ResponsesTool::Unknown => {} + } + } + + if store_ids.is_empty() { + return Err(ExecutorError::InvalidRequest( + "file_search requires at least one vector_store_ids entry".into(), + )); + } + + Ok(FileSearchConfig { + store_ids, + options: options.unwrap_or_default(), + }) +} + +fn file_search_function_tool() -> ResponsesTool { + ResponsesTool::Function(FunctionTool { + name: "file_search".to_string(), + description: Some("Search attached vector stores for relevant file content.".to_string()), + parameters: Some(serde_json::json!({ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query to run against the vector store." + } + }, + "required": ["query"], + "additionalProperties": false + })), + strict: Some(true), + }) +} + +fn translate_file_search_tools(tools: Option<&[ResponsesTool]>) -> Option> { + let tools = tools?; + let mut translated = Vec::with_capacity(tools.len()); + for tool in tools { + match tool { + ResponsesTool::Function(tool) => translated.push(ResponsesTool::Function(tool.clone())), + ResponsesTool::FileSearch(_) => translated.push(file_search_function_tool()), + ResponsesTool::Unknown => {} + } + } + Some(translated) +} + +fn file_search_calls(output: &[OutputItem]) -> ExecutorResult> { + let mut file_search_calls = Vec::new(); + let mut other_tool_names = Vec::new(); + + for item in output { + match item { + OutputItem::FunctionCall(call) if call.name == "file_search" => file_search_calls.push(call.clone()), + OutputItem::FunctionCall(call) => other_tool_names.push(call.name.clone()), + OutputItem::Message(_) | OutputItem::Reasoning(_) | OutputItem::Unknown => {} + } + } + + if !file_search_calls.is_empty() && !other_tool_names.is_empty() { + return Err(ExecutorError::ToolExecution(format!( + "mixed tool calls are not supported in file_search loop: {}", + other_tool_names.join(", ") + ))); + } + + Ok(file_search_calls) +} + +#[derive(Deserialize)] +struct FileSearchArguments { + query: String, +} + +fn query_from_arguments(arguments: &str) -> ExecutorResult { + let args = serde_json::from_str::(arguments) + .map_err(|err| ExecutorError::ToolExecution(format!("invalid file_search arguments: {err}")))?; + + if args.query.trim().is_empty() { + return Err(ExecutorError::ToolExecution( + "file_search query argument is required".into(), + )); + } + + Ok(args.query) +} + +fn append_input_item(input: &mut ResponsesInput, item: InputItem) { + let mut items = Vec::::from(&*input); + items.push(item); + *input = ResponsesInput::Items(items); +} + +async fn run_file_search_loop(mut ctx: RequestContext, exec_ctx: &ExecutionContext) -> ExecutorResult { + if ctx.original_request.stream { + return Err(ExecutorError::InvalidRequest( + "streaming file_search requests are not supported".into(), + )); + } + + let Some(vector_search) = exec_ctx.vector_search.as_ref() else { + return Err(ExecutorError::InvalidRequest( + "file_search requires a configured vector search backend".into(), + )); + }; + + let file_search = file_search_config(ctx.enriched_request.tools.as_deref())?; + ctx.enriched_request.tools = translate_file_search_tools(ctx.enriched_request.tools.as_deref()); + let url = exec_ctx.responses_url(); + + for _ in 0..exec_ctx.max_iterations { + let upstream_json = + serialize_to_string(&ctx.enriched_request.to_upstream_request(false)).map_err(ExecutorError::JsonError)?; + let body = fetch_response_json(upstream_json, &url, &exec_ctx.client, exec_ctx.client_auth.as_deref()).await?; + let acc = ResponseAccumulator::from_json(&body, ctx.conversation_id.as_deref())?; + let mut payload = acc.finalize( + &ctx.enriched_request.model, + ctx.original_request.previous_response_id.as_deref(), + ctx.original_request.instructions.as_deref(), + ); + + let tool_calls = file_search_calls(&payload.output)?; + if tool_calls.is_empty() { + ctx.inject_ids(&mut payload); + let should_persist = ctx.original_request.store + || ctx.original_request.previous_response_id.is_some() + || ctx.original_request.conversation_id.is_some(); + if should_persist { + let ch = exec_ctx.conv_handler.clone(); + let rh = exec_ctx.resp_handler.clone(); + if let Err(e) = persist_response(payload.clone(), ctx, ch, rh).await { + warn!("persist failed: {e}"); + } + } + return Ok(payload); + } + + for call in tool_calls { + let input_call = InputItem::FunctionCall(call.clone()); + append_input_item(&mut ctx.enriched_request.input, input_call.clone()); + ctx.new_input_items.push(input_call); + + let query = query_from_arguments(&call.arguments)?; + let mut results = Vec::new(); + for store_id in &file_search.store_ids { + match vector_search.search(store_id, &query, &file_search.options).await { + Ok(mut store_results) => results.append(&mut store_results), + Err(err) => { + return Err(ExecutorError::ToolExecution(format!( + "file_search vector lookup failed for vector store {store_id}: {err}" + ))); + } + } + } + + let output = + serialize_to_string(&serde_json::json!({ "results": results })).map_err(ExecutorError::JsonError)?; + let result_item = InputItem::FunctionCallOutput(FunctionToolResultMessage { + call_id: call.call_id, + output, + }); + append_input_item(&mut ctx.enriched_request.input, result_item.clone()); + ctx.new_input_items.push(result_item); + } + } + + Err(ExecutorError::MaxIterations { + max_iterations: exec_ctx.max_iterations, + }) +} + async fn run_blocking(ctx: RequestContext, exec_ctx: &ExecutionContext) -> ExecutorResult { let url = exec_ctx.responses_url(); // Non-streaming request: stream=false → full JSON body → from_json. @@ -410,6 +615,9 @@ pub async fn execute( exec_ctx: Arc, ) -> ExecutorResult> { let ctx = rehydrate_conversation(request, &exec_ctx).await?; + if contains_file_search(ctx.enriched_request.tools.as_deref()) { + return Ok(Either::Left(run_file_search_loop(ctx, &exec_ctx).await?)); + } if ctx.original_request.stream { Ok(Either::Right(run_stream(ctx, exec_ctx))) } else { diff --git a/crates/agentic-core/src/executor/error.rs b/crates/agentic-core/src/executor/error.rs index df200a2..dee656e 100644 --- a/crates/agentic-core/src/executor/error.rs +++ b/crates/agentic-core/src/executor/error.rs @@ -52,6 +52,12 @@ pub enum ExecutorError { #[error("{entity} not found: {id}")] NotFound { entity: String, id: String }, + #[error("agentic loop exceeded {max_iterations} iterations")] + MaxIterations { max_iterations: u32 }, + + #[error("{0}")] + ToolExecution(String), + #[error("invalid request: {0}")] InvalidRequest(String), } @@ -62,6 +68,7 @@ impl ExecutorError { pub fn http_status(&self) -> StatusCode { match self { Self::Storage(e) if e.is_not_found() => StatusCode::NOT_FOUND, + Self::MaxIterations { .. } | Self::ToolExecution(_) => StatusCode::BAD_GATEWAY, Self::LLMRequest { status, .. } => *status, Self::InvalidRequest(_) | Self::JsonError(_) => StatusCode::BAD_REQUEST, Self::ParseError(_) => StatusCode::UNPROCESSABLE_ENTITY, @@ -75,6 +82,7 @@ impl ExecutorError { match self { Self::Storage(e) if e.is_not_found() => "not_found", Self::LLMRequest { .. } => "upstream_error", + Self::ToolExecution(_) => "tool_execution_error", Self::InvalidRequest(_) | Self::ParseError(_) | Self::JsonError(_) => "invalid_request_error", _ => "server_error", } diff --git a/crates/agentic-core/src/executor/request.rs b/crates/agentic-core/src/executor/request.rs index 66f7fe0..83529e8 100644 --- a/crates/agentic-core/src/executor/request.rs +++ b/crates/agentic-core/src/executor/request.rs @@ -7,6 +7,7 @@ use crate::executor::modes::{ConversationHandler, ResponseHandler}; use crate::storage::{ConversationStore, ResponseStore, create_pool_with_schema}; use crate::types::io::InputItem; use crate::types::request_response::{RequestPayload, ResponsePayload}; +use crate::vector_search::{VectorSearch, ogx::OgxStore}; /// Context built by `rehydrate_conversation`, threaded through the execute pipeline. #[derive(Debug)] @@ -39,15 +40,18 @@ impl RequestContext { /// Runtime dependencies passed into `execute()`. /// /// Owns the storage handlers, HTTP client, and LLM endpoint configuration. -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct ExecutionContext { pub conv_handler: ConversationHandler, pub resp_handler: ResponseHandler, pub client: Arc, + pub vector_search: Option>, /// Base URL for the LLM backend, e.g. `"http://localhost:8000"`. pub llm_base_url: String, /// Bearer token forwarded from the client, if any. pub client_auth: Option, + /// Maximum model/tool turns for the agentic loop. + pub max_iterations: u32, /// Maximum wait time for the next SSE chunk. `Duration::ZERO` disables the timeout. /// Sourced from [`Config::streaming_chunk_timeout_s`](crate::config::Config::streaming_chunk_timeout_s). pub streaming_timeout: Duration, @@ -78,12 +82,21 @@ impl ExecutionContext { conv_handler, resp_handler, client, + vector_search: None, llm_base_url, client_auth, + max_iterations: 10, streaming_timeout: Duration::from_secs(30), } } + #[must_use] + pub fn with_vector_search(mut self, vector_search: Arc, max_iterations: u32) -> Self { + self.vector_search = Some(vector_search); + self.max_iterations = max_iterations; + self + } + /// Build an `ExecutionContext` directly from [`Config`](crate::config::Config). /// /// Creates the database pool, both storage handlers, and an HTTP client @@ -102,13 +115,16 @@ impl ExecutionContext { let conv_handler = ConversationHandler::new(ConversationStore::new(pool.clone())); let resp_handler = ResponseHandler::new(ResponseStore::new(pool)); let client = Arc::new(reqwest::Client::new()); + let vector_search = Arc::new(OgxStore::new(&cfg.ogx_base_url, reqwest::Client::new())); Ok(Self { conv_handler, resp_handler, client, + vector_search: Some(vector_search), llm_base_url: cfg.llm_api_base.clone(), client_auth: cfg.openai_api_key.clone(), + max_iterations: cfg.max_iterations, streaming_timeout: Duration::from_secs(30), }) } diff --git a/crates/agentic-core/src/lib.rs b/crates/agentic-core/src/lib.rs index 5828a68..2545490 100644 --- a/crates/agentic-core/src/lib.rs +++ b/crates/agentic-core/src/lib.rs @@ -7,6 +7,7 @@ pub mod readiness; pub mod storage; pub mod types; pub mod utils; +pub mod vector_search; pub use storage::{ ConversationData, ConversationStore, DbPool, InOutItem, ItemKind, ResponseData, ResponseMetadata, ResponseStore, diff --git a/crates/agentic-core/src/proxy.rs b/crates/agentic-core/src/proxy.rs index 79329b9..116bd03 100644 --- a/crates/agentic-core/src/proxy.rs +++ b/crates/agentic-core/src/proxy.rs @@ -232,6 +232,8 @@ mod tests { llm_ready_timeout_s: 5.0, llm_ready_interval_s: 0.1, db_url: None, + ogx_base_url: "http://localhost:8080".to_owned(), + max_iterations: 10, } } diff --git a/crates/agentic-core/src/storage/types/item.rs b/crates/agentic-core/src/storage/types/item.rs index 2f36fa4..82d297c 100644 --- a/crates/agentic-core/src/storage/types/item.rs +++ b/crates/agentic-core/src/storage/types/item.rs @@ -88,9 +88,13 @@ impl InOutItem { .into_iter() .filter_map(|i| match i { InOutItem::Input(item) => Some(item), - InOutItem::Output(OutputItem::Message(msg)) => Some(InputItem::Message(msg.into())), + InOutItem::Output(OutputItem::Message(msg)) => { + // Embed history OutputMessage as an input item so the model sees prior turns. + Some(InputItem::Message(msg.into())) + } InOutItem::Output(OutputItem::Reasoning(r)) => Some(InputItem::Reasoning(r)), - InOutItem::Output(OutputItem::FunctionCall(_) | OutputItem::Unknown) => None, + InOutItem::Output(OutputItem::FunctionCall(call)) => Some(InputItem::FunctionCall(call)), + InOutItem::Output(OutputItem::Unknown) => None, }) .collect() } diff --git a/crates/agentic-core/src/types/io.rs b/crates/agentic-core/src/types/io.rs index 9e24d8e..27135e0 100644 --- a/crates/agentic-core/src/types/io.rs +++ b/crates/agentic-core/src/types/io.rs @@ -1,5 +1,7 @@ -use serde::{Deserialize, Serialize}; -use serde_json::Value; +use serde::{Deserialize, Deserializer, Serialize}; +use serde_json::{Map, Value}; + +use crate::vector_search::types::SearchOptions; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct InputTextContent { @@ -44,11 +46,13 @@ pub struct FunctionToolResultMessage { pub output: String, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize)] #[serde(tag = "type")] pub enum InputItem { #[serde(rename = "message")] Message(InputMessage), + #[serde(rename = "function_call")] + FunctionCall(FunctionToolCall), #[serde(rename = "function_call_output")] FunctionCallOutput(FunctionToolResultMessage), #[serde(rename = "reasoning")] @@ -57,6 +61,49 @@ pub enum InputItem { Unknown, } +impl<'de> Deserialize<'de> for InputItem { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let mut value = Value::deserialize(deserializer)?; + let Some(object) = value.as_object_mut() else { + return Ok(Self::Unknown); + }; + + match object.get("type").and_then(Value::as_str) { + Some("message") => { + object.remove("type"); + serde_json::from_value(value) + .map(Self::Message) + .map_err(serde::de::Error::custom) + } + Some("function_call") => { + object.remove("type"); + serde_json::from_value(value) + .map(Self::FunctionCall) + .map_err(serde::de::Error::custom) + } + Some("function_call_output") => { + object.remove("type"); + serde_json::from_value(value) + .map(Self::FunctionCallOutput) + .map_err(serde::de::Error::custom) + } + Some("reasoning") => { + object.remove("type"); + serde_json::from_value(value) + .map(Self::Reasoning) + .map_err(serde::de::Error::custom) + } + None if object.contains_key("role") && object.contains_key("content") => serde_json::from_value(value) + .map(Self::Message) + .map_err(serde::de::Error::custom), + Some(_) | None => Ok(Self::Unknown), + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OutputTextContent { #[serde(rename = "type")] @@ -78,8 +125,10 @@ impl OutputTextContent { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OutputMessage { + #[serde(default)] pub id: String, pub role: String, + #[serde(default)] pub status: String, #[serde(default)] pub content: Vec, @@ -121,7 +170,8 @@ pub struct FunctionToolCall { pub call_id: String, pub name: String, pub arguments: String, - pub status: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub status: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -199,15 +249,32 @@ pub struct ResponseUsage { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct FunctionTool { - #[serde(rename = "type")] - pub type_: String, pub name: String, pub description: Option, pub parameters: Option, pub strict: Option, } -pub type ResponsesTool = FunctionTool; +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FileSearchTool { + #[serde(default)] + pub vector_store_ids: Vec, + #[serde(default, flatten)] + pub search_options: SearchOptions, + #[serde(flatten)] + pub rest: Map, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ResponsesTool { + #[serde(rename = "function")] + Function(FunctionTool), + #[serde(rename = "file_search")] + FileSearch(FileSearchTool), + #[serde(other)] + Unknown, +} #[derive(Debug, Clone, Default, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] diff --git a/crates/agentic-core/src/types/request_response.rs b/crates/agentic-core/src/types/request_response.rs index de89a08..5cca28a 100644 --- a/crates/agentic-core/src/types/request_response.rs +++ b/crates/agentic-core/src/types/request_response.rs @@ -80,6 +80,13 @@ impl RequestPayload { metadata: self.metadata.as_ref(), } } + + #[must_use] + pub fn has_file_search_tool(&self) -> bool { + self.tools + .as_deref() + .is_some_and(|tools| tools.iter().any(|tool| matches!(tool, ResponsesTool::FileSearch(_)))) + } } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/crates/agentic-core/src/vector_search/mod.rs b/crates/agentic-core/src/vector_search/mod.rs new file mode 100644 index 0000000..9a97377 --- /dev/null +++ b/crates/agentic-core/src/vector_search/mod.rs @@ -0,0 +1,16 @@ +pub mod ogx; +pub mod types; + +use async_trait::async_trait; + +use types::{SearchOptions, SearchResult}; + +#[async_trait] +pub trait VectorSearch: Send + Sync { + async fn search( + &self, + store_id: &str, + query: &str, + options: &SearchOptions, + ) -> Result, crate::error::Error>; +} diff --git a/crates/agentic-core/src/vector_search/ogx.rs b/crates/agentic-core/src/vector_search/ogx.rs new file mode 100644 index 0000000..77fbe85 --- /dev/null +++ b/crates/agentic-core/src/vector_search/ogx.rs @@ -0,0 +1,54 @@ +use async_trait::async_trait; +use serde_json::{Map, Value}; +use tracing::debug; + +use super::types::{SearchOptions, SearchResponse, SearchResult}; +use crate::error::Error; + +pub struct OgxStore { + base_url: String, + client: reqwest::Client, +} + +impl OgxStore { + #[must_use] + pub fn new(base_url: &str, client: reqwest::Client) -> Self { + let base_url = base_url.trim_end_matches('/').to_owned(); + Self { base_url, client } + } +} + +#[async_trait] +impl super::VectorSearch for OgxStore { + async fn search(&self, store_id: &str, query: &str, options: &SearchOptions) -> Result, Error> { + let url = format!("{}/v1/vector_stores/{store_id}/search", self.base_url); + debug!(%url, "searching vector store via OGx"); + + let mut body = Map::new(); + body.insert("query".to_owned(), Value::String(query.to_owned())); + body.insert( + "max_num_results".to_owned(), + Value::from(options.max_num_results.unwrap_or(10)), + ); + if let Some(filters) = &options.filters { + body.insert("filters".to_owned(), filters.clone()); + } + if let Some(ranking_options) = &options.ranking_options { + body.insert("ranking_options".to_owned(), ranking_options.clone()); + } + + let resp = self.client.post(&url).json(&body).send().await.map_err(Error::Store)?; + + let status = resp.status(); + if !status.is_success() { + let body = resp.text().await.unwrap_or_default(); + return Err(Error::StoreResponse { + status: status.as_u16(), + body, + }); + } + + let search_resp: SearchResponse = resp.json().await.map_err(Error::Store)?; + Ok(search_resp.data) + } +} diff --git a/crates/agentic-core/src/vector_search/types.rs b/crates/agentic-core/src/vector_search/types.rs new file mode 100644 index 0000000..7b9c218 --- /dev/null +++ b/crates/agentic-core/src/vector_search/types.rs @@ -0,0 +1,133 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ResponseRequest { + pub model: String, + #[serde(default)] + pub input: ResponseInput, + #[serde(default)] + pub stream: bool, + #[serde(default)] + pub tools: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub previous_response_id: Option, + #[serde(flatten)] + pub rest: serde_json::Map, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ResponseInput { + Text(String), + Items(Vec), +} + +impl Default for ResponseInput { + fn default() -> Self { + Self::Items(Vec::new()) + } +} + +impl ResponseInput { + #[must_use] + pub fn to_values(&self) -> Vec { + match self { + Self::Text(text) => vec![serde_json::json!({ + "type": "message", + "role": "user", + "content": text + })], + Self::Items(items) => items.clone(), + } + } + + pub fn prepend(&mut self, mut history: Vec) { + history.extend(self.to_values()); + *self = Self::Items(history); + } + + pub fn push(&mut self, item: serde_json::Value) { + match self { + Self::Text(_) => { + let mut items = self.to_values(); + items.push(item); + *self = Self::Items(items); + } + Self::Items(items) => items.push(item), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolConfig { + pub r#type: String, + #[serde(default)] + pub vector_store_ids: Option>, + #[serde(flatten)] + pub rest: serde_json::Map, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ResponseBody { + pub id: String, + #[serde(default)] + pub output: Vec, + #[serde(default)] + pub status: String, + #[serde(flatten)] + pub rest: serde_json::Map, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum VllmOutputItem { + #[serde(rename = "message")] + Message { + #[serde(flatten)] + fields: serde_json::Map, + }, + #[serde(rename = "function_call")] + FunctionCall { + id: String, + call_id: String, + name: String, + arguments: String, + #[serde(flatten)] + rest: serde_json::Map, + }, + #[serde(other)] + Other, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SearchResponse { + pub data: Vec, +} + +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] +pub struct SearchOptions { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub filters: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_num_results: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub ranking_options: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SearchResult { + pub file_id: String, + pub filename: String, + pub score: f64, + #[serde(default)] + pub attributes: Option>, + #[serde(default)] + pub content: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ContentChunk { + pub r#type: String, + pub text: String, +} diff --git a/crates/agentic-server/Cargo.toml b/crates/agentic-server/Cargo.toml index 9a894d6..601faad 100644 --- a/crates/agentic-server/Cargo.toml +++ b/crates/agentic-server/Cargo.toml @@ -26,9 +26,10 @@ tracing-subscriber.workspace = true bytes.workspace = true criterion.workspace = true futures.workspace = true -reqwest = { workspace = true, features = ["json"] } +reqwest = { workspace = true, features = ["json", "multipart"] } serde_json.workspace = true tokio = { workspace = true, features = ["test-util"] } +tokio-stream = "0.1" uuid = { version = "1", features = ["v7"] } [[bench]] diff --git a/crates/agentic-server/benches/gateway_bench.rs b/crates/agentic-server/benches/gateway_bench.rs index 725f944..a08965a 100644 --- a/crates/agentic-server/benches/gateway_bench.rs +++ b/crates/agentic-server/benches/gateway_bench.rs @@ -157,6 +157,8 @@ async fn spawn_gateway(llm_url: &str) -> (Arc, String) { llm_ready_timeout_s: 5.0, llm_ready_interval_s: 0.1, db_url: Some(format!("sqlite://{}", db_path.display())), + ogx_base_url: "http://127.0.0.1:1".to_owned(), + max_iterations: 10, }; let proxy_state = ProxyState::new(config.clone()).unwrap(); diff --git a/crates/agentic-server/benches/proxy_bench.rs b/crates/agentic-server/benches/proxy_bench.rs index ae677d6..62adeac 100644 --- a/crates/agentic-server/benches/proxy_bench.rs +++ b/crates/agentic-server/benches/proxy_bench.rs @@ -26,6 +26,8 @@ fn bench_config(llm_url: &str) -> Config { llm_ready_timeout_s: 5.0, llm_ready_interval_s: 0.1, db_url: None, + ogx_base_url: "http://127.0.0.1:1".to_owned(), + max_iterations: 10, } } diff --git a/crates/agentic-server/src/handler.rs b/crates/agentic-server/src/handler.rs index d132464..f4ee081 100644 --- a/crates/agentic-server/src/handler.rs +++ b/crates/agentic-server/src/handler.rs @@ -177,7 +177,7 @@ pub async fn responses(State(state): State, req: Request) -> Response let should_persist = payload.store || payload.previous_response_id.is_some() || payload.conversation_id.is_some(); - if should_persist { + if should_persist || payload.has_file_search_tool() { execute_responses(&state, parts, payload).await } else { proxy_responses(&state, parts, bytes).await diff --git a/crates/agentic-server/src/main.rs b/crates/agentic-server/src/main.rs index 5cf7609..f5640bf 100644 --- a/crates/agentic-server/src/main.rs +++ b/crates/agentic-server/src/main.rs @@ -31,6 +31,12 @@ struct CommonArgs { global = true )] db_url: String, + + #[arg(long, default_value = "http://localhost:8080", global = true)] + ogx_base_url: String, + + #[arg(long, default_value_t = 10, global = true)] + max_iterations: u32, } #[derive(Parser)] @@ -70,6 +76,8 @@ fn build_config(llm_api_base: String, common: &CommonArgs) -> Config { llm_ready_timeout_s: common.llm_ready_timeout_s, llm_ready_interval_s: common.llm_ready_interval_s, db_url: Some(common.db_url.clone()), + ogx_base_url: normalize_base_url(&common.ogx_base_url), + max_iterations: common.max_iterations, } } diff --git a/crates/agentic-server/tests/agentic_loop_test.rs b/crates/agentic-server/tests/agentic_loop_test.rs new file mode 100644 index 0000000..6a9c15b --- /dev/null +++ b/crates/agentic-server/tests/agentic_loop_test.rs @@ -0,0 +1,541 @@ +#[allow(dead_code)] +mod common; + +use common::{ + spawn_ogx, spawn_ogx_recording, spawn_vllm, spawn_vllm_recording, spawn_vllm_with_tool_calls, start_gateway, +}; + +#[tokio::test] +async fn test_passthrough_no_tools() { + let (vllm_port, _h) = spawn_vllm().await; + let (ogx_port, _h2) = spawn_ogx().await; + let (gw_addr, _) = start_gateway(vllm_port, Some(ogx_port), None).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": [{"role": "user", "content": "hello"}], + "store": false + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + let body: serde_json::Value = resp.json().await.unwrap(); + assert_eq!(body["id"], "resp_test"); +} + +#[tokio::test] +async fn test_single_file_search() { + let tool_call_response = serde_json::json!({ + "id": "resp_1", + "object": "response", + "status": "completed", + "output": [{ + "type": "function_call", + "id": "fc_1", + "call_id": "call_1", + "name": "file_search", + "arguments": "{\"query\": \"test query\"}", + "status": "completed" + }] + }); + + let final_response = serde_json::json!({ + "id": "resp_2", + "object": "response", + "status": "completed", + "output": [{ + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "Based on the search results..."}] + }] + }); + + let (vllm_port, _h) = spawn_vllm_with_tool_calls(vec![tool_call_response, final_response]).await; + let (ogx_port, _h2) = spawn_ogx().await; + let (gw_addr, _) = start_gateway(vllm_port, Some(ogx_port), None).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": [{"role": "user", "content": "search for something"}], + "tools": [{"type": "file_search", "vector_store_ids": ["vs_123"]}] + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + let body: serde_json::Value = resp.json().await.unwrap(); + assert!(body["id"].as_str().unwrap_or("").starts_with("resp_")); + assert_eq!(body["output"][0]["type"], "message"); +} + +#[tokio::test] +async fn test_file_search_backend_failure_returns_error() { + let tool_call_response = serde_json::json!({ + "id": "resp_1", + "object": "response", + "status": "completed", + "output": [{ + "type": "function_call", + "id": "fc_1", + "call_id": "call_1", + "name": "file_search", + "arguments": "{\"query\": \"test query\"}", + "status": "completed" + }] + }); + + let final_response = serde_json::json!({ + "id": "resp_2", + "object": "response", + "status": "completed", + "output": [{ + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "answer without search context"}] + }] + }); + + let (vllm_port, _h) = spawn_vllm_with_tool_calls(vec![tool_call_response, final_response]).await; + let (gw_addr, _) = start_gateway(vllm_port, None, None).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": "search for something", + "tools": [{"type": "file_search", "vector_store_ids": ["vs_123"]}] + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 502); + let body: serde_json::Value = resp.json().await.unwrap(); + let msg = body["error"]["message"].as_str().unwrap_or(""); + assert!( + msg.contains("file_search vector lookup failed"), + "unexpected error: {msg}" + ); +} + +#[tokio::test] +async fn test_file_search_rejects_missing_query_argument() { + let tool_call_response = serde_json::json!({ + "id": "resp_1", + "object": "response", + "status": "completed", + "output": [{ + "type": "function_call", + "id": "fc_1", + "call_id": "call_1", + "name": "file_search", + "arguments": "{}", + "status": "completed" + }] + }); + + let final_response = serde_json::json!({ + "id": "resp_2", + "object": "response", + "status": "completed", + "output": [{ + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "answer without a query"}] + }] + }); + + let (vllm_port, _h) = spawn_vllm_with_tool_calls(vec![tool_call_response, final_response]).await; + let (ogx_port, _h2) = spawn_ogx().await; + let (gw_addr, _) = start_gateway(vllm_port, Some(ogx_port), None).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": "search for something", + "tools": [{"type": "file_search", "vector_store_ids": ["vs_123"]}] + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 502); + let body: serde_json::Value = resp.json().await.unwrap(); + let msg = body["error"]["message"].as_str().unwrap_or(""); + assert!(msg.contains("query"), "unexpected error: {msg}"); +} + +#[tokio::test] +async fn test_file_search_rejects_empty_vector_store_ids_before_vllm() { + let response = serde_json::json!({ + "id": "resp_1", + "object": "response", + "status": "completed", + "output": [{ + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "should not be called"}] + }] + }); + + let (vllm_port, requests, _h) = spawn_vllm_recording(vec![response]).await; + let (ogx_port, _h2) = spawn_ogx().await; + let (gw_addr, _) = start_gateway(vllm_port, Some(ogx_port), None).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": "search for something", + "tools": [{"type": "file_search", "vector_store_ids": []}] + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 400); + let body: serde_json::Value = resp.json().await.unwrap(); + let msg = body["error"]["message"].as_str().unwrap_or(""); + assert!(msg.contains("vector_store_ids"), "unexpected error: {msg}"); + assert!( + requests.lock().await.is_empty(), + "gateway should reject before calling vLLM" + ); +} + +#[tokio::test] +async fn test_file_search_preserves_search_options() { + let tool_call_response = serde_json::json!({ + "id": "resp_1", + "object": "response", + "status": "completed", + "output": [{ + "type": "function_call", + "id": "fc_1", + "call_id": "call_1", + "name": "file_search", + "arguments": "{\"query\": \"test query\"}", + "status": "completed" + }] + }); + let final_response = serde_json::json!({ + "id": "resp_2", + "object": "response", + "status": "completed", + "output": [{ + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "Based on filtered search results..."}] + }] + }); + + let (vllm_port, _h) = spawn_vllm_with_tool_calls(vec![tool_call_response, final_response]).await; + let (ogx_port, ogx_requests, _h2) = spawn_ogx_recording().await; + let (gw_addr, _) = start_gateway(vllm_port, Some(ogx_port), None).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": "search for something", + "tools": [{ + "type": "file_search", + "vector_store_ids": ["vs_123"], + "max_num_results": 3, + "filters": {"type": "eq", "key": "tenant_id", "value": "tenant-a"}, + "ranking_options": {"ranker": "default", "score_threshold": 0.25} + }] + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + let requests = ogx_requests.lock().await; + assert_eq!(requests.len(), 1); + assert_eq!(requests[0]["query"], "test query"); + assert_eq!(requests[0]["max_num_results"], 3); + assert_eq!(requests[0]["filters"]["key"], "tenant_id"); + assert_eq!(requests[0]["ranking_options"]["score_threshold"], 0.25); +} + +#[tokio::test] +async fn test_file_search_rejects_mixed_tool_calls() { + let mixed_response = serde_json::json!({ + "id": "resp_1", + "object": "response", + "status": "completed", + "output": [ + { + "type": "function_call", + "id": "fc_1", + "call_id": "call_1", + "name": "file_search", + "arguments": "{\"query\": \"test query\"}", + "status": "completed" + }, + { + "type": "function_call", + "id": "fc_2", + "call_id": "call_2", + "name": "get_weather", + "arguments": "{\"city\": \"SF\"}", + "status": "completed" + } + ] + }); + let final_response = serde_json::json!({ + "id": "resp_2", + "object": "response", + "status": "completed", + "output": [{ + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "answer after dropping get_weather"}] + }] + }); + + let (vllm_port, _h) = spawn_vllm_with_tool_calls(vec![mixed_response, final_response]).await; + let (ogx_port, _h2) = spawn_ogx().await; + let (gw_addr, _) = start_gateway(vllm_port, Some(ogx_port), None).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": "search and call another function", + "tools": [ + {"type": "file_search", "vector_store_ids": ["vs_123"]}, + { + "type": "function", + "name": "get_weather", + "description": "Get weather.", + "parameters": {"type": "object", "properties": {"city": {"type": "string"}}} + } + ] + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 502); + let body: serde_json::Value = resp.json().await.unwrap(); + let msg = body["error"]["message"].as_str().unwrap_or(""); + assert!(msg.contains("mixed tool calls"), "unexpected error: {msg}"); +} + +#[tokio::test] +async fn test_file_search_streaming_rejected() { + let (vllm_port, _h) = spawn_vllm().await; + let (ogx_port, _h2) = spawn_ogx().await; + let (gw_addr, _) = start_gateway(vllm_port, Some(ogx_port), None).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": "search for something", + "stream": true, + "tools": [{"type": "file_search", "vector_store_ids": ["vs_123"]}] + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 400); + let body: serde_json::Value = resp.json().await.unwrap(); + let msg = body["error"]["message"].as_str().unwrap_or(""); + assert!(msg.contains("streaming file_search"), "unexpected error: {msg}"); +} + +#[tokio::test] +async fn test_previous_response_id_hydrates_history() { + let first_response = serde_json::json!({ + "id": "resp_1", + "object": "response", + "status": "completed", + "output": [{ + "type": "message", + "id": "msg_1", + "role": "assistant", + "status": "completed", + "content": [{"type": "output_text", "text": "first answer"}] + }] + }); + + let second_response = serde_json::json!({ + "id": "resp_2", + "object": "response", + "status": "completed", + "output": [{ + "type": "message", + "id": "msg_2", + "role": "assistant", + "status": "completed", + "content": [{"type": "output_text", "text": "second answer"}] + }] + }); + + let (vllm_port, requests, _h) = spawn_vllm_recording(vec![first_response, second_response]).await; + let (ogx_port, _h2) = spawn_ogx().await; + let (gw_addr, _) = start_gateway(vllm_port, Some(ogx_port), None).await; + + let client = reqwest::Client::new(); + let first = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": "first question", + "tools": [{"type": "file_search", "vector_store_ids": ["vs_123"]}] + })) + .send() + .await + .unwrap(); + assert_eq!(first.status(), 200); + let first_body: serde_json::Value = first.json().await.unwrap(); + let first_id = first_body["id"].as_str().expect("first response id"); + + let second = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": "follow up", + "previous_response_id": first_id, + "tools": [{"type": "file_search", "vector_store_ids": ["vs_123"]}] + })) + .send() + .await + .unwrap(); + let second_status = second.status(); + let second_text = second.text().await.unwrap(); + assert_eq!(second_status, 200, "second response body: {second_text}"); + + let requests = requests.lock().await; + assert_eq!(requests.len(), 2); + assert!(requests[1].get("previous_response_id").is_none()); + let input = requests[1]["input"] + .as_array() + .expect("hydrated input should be an array"); + assert!( + input.len() >= 3, + "expected prior user/output plus follow-up input, got {input:?}" + ); + assert!(input.iter().any(|item| item["content"] == "first question")); + assert!(input.iter().any(|item| item["content"] == "follow up")); +} + +#[tokio::test] +async fn test_multi_turn_tool_calls() { + let turn1 = serde_json::json!({ + "id": "resp_1", + "object": "response", + "status": "completed", + "output": [{ + "type": "function_call", + "id": "fc_1", + "call_id": "call_1", + "name": "file_search", + "arguments": "{\"query\": \"first query\"}", + "status": "completed" + }] + }); + + let turn2 = serde_json::json!({ + "id": "resp_2", + "object": "response", + "status": "completed", + "output": [{ + "type": "function_call", + "id": "fc_2", + "call_id": "call_2", + "name": "file_search", + "arguments": "{\"query\": \"second query\"}", + "status": "completed" + }] + }); + + let final_resp = serde_json::json!({ + "id": "resp_3", + "object": "response", + "status": "completed", + "output": [{ + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "final answer"}] + }] + }); + + let (vllm_port, _h) = spawn_vllm_with_tool_calls(vec![turn1, turn2, final_resp]).await; + let (ogx_port, _h2) = spawn_ogx().await; + let (gw_addr, _) = start_gateway(vllm_port, Some(ogx_port), None).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": [{"role": "user", "content": "multi-turn search"}], + "tools": [{"type": "file_search", "vector_store_ids": ["vs_123"]}] + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + let body: serde_json::Value = resp.json().await.unwrap(); + assert!(body["id"].as_str().unwrap_or("").starts_with("resp_")); +} + +#[tokio::test] +async fn test_max_iterations_reached() { + let tool_call = serde_json::json!({ + "id": "resp_loop", + "object": "response", + "status": "completed", + "output": [{ + "type": "function_call", + "id": "fc_loop", + "call_id": "call_loop", + "name": "file_search", + "arguments": "{\"query\": \"infinite loop\"}", + "status": "completed" + }] + }); + + let (vllm_port, _h) = spawn_vllm_with_tool_calls(vec![tool_call]).await; + let (ogx_port, _h2) = spawn_ogx().await; + let (gw_addr, _) = start_gateway(vllm_port, Some(ogx_port), None).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": [{"role": "user", "content": "search forever"}], + "tools": [{"type": "file_search", "vector_store_ids": ["vs_123"]}] + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 502); + let body: serde_json::Value = resp.json().await.unwrap(); + let msg = body["error"]["message"].as_str().unwrap_or(""); + assert!(msg.contains("exceeded"), "expected max iterations error, got: {msg}"); +} diff --git a/crates/agentic-server/tests/cassettes/file_search/vllm-file-search-openai-gpt-oss-20b.json b/crates/agentic-server/tests/cassettes/file_search/vllm-file-search-openai-gpt-oss-20b.json new file mode 100644 index 0000000..f4067bc --- /dev/null +++ b/crates/agentic-server/tests/cassettes/file_search/vllm-file-search-openai-gpt-oss-20b.json @@ -0,0 +1,296 @@ +{ + "metadata": { + "recorded_at": "2026-06-17T12:30:24+00:00", + "model": "openai/gpt-oss-20b", + "note": "Harmony models reject tool_choice=required, so this cassette uses tool_choice=auto with a direct prompt." + }, + "turns": [ + { + "request": { + "method": "POST", + "path": "/v1/responses", + "body": { + "model": "openai/gpt-oss-20b", + "input": "Use the file_search tool to find information about Rust memory safety ownership.", + "tools": [ + { + "type": "function", + "name": "file_search", + "description": "Search uploaded files for relevant passages.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query." + } + }, + "required": [ + "query" + ] + } + } + ], + "tool_choice": "auto", + "stream": false + } + }, + "response": { + "status_code": 200, + "body": { + "id": "resp_a4c1eeb7507be20a", + "created_at": 1781699419, + "incomplete_details": null, + "instructions": null, + "metadata": null, + "model": "openai/gpt-oss-20b", + "object": "response", + "output": [ + { + "id": "rs_b28e1cdf8ad9b24c", + "summary": [], + "type": "reasoning", + "content": [ + { + "text": "We need to use the file_search tool to search for \"Rust memory safety ownership\" or something like that. The instruction: \"Use the file_search tool to find information about Rust memory safety ownership.\" We need to call the function. We can pass query \"Rust memory safety ownership\". Then we need to answer by providing the relevant info. Probably the tool returns passages or summaries. Use tool.", + "type": "reasoning_text" + } + ], + "encrypted_content": null, + "status": null + }, + { + "arguments": "{\"query\":\"Rust memory safety ownership\"}", + "call_id": "call_8e217e84eeadc2fb", + "name": "file_search", + "type": "function_call", + "id": "fc_8e217e84eeadc2fb", + "status": null + } + ], + "parallel_tool_calls": true, + "temperature": 1.0, + "tool_choice": "auto", + "tools": [ + { + "name": "file_search", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query." + } + }, + "required": [ + "query" + ] + }, + "strict": null, + "type": "function", + "description": "Search uploaded files for relevant passages." + } + ], + "top_p": 1.0, + "background": false, + "max_output_tokens": 16247, + "max_tool_calls": null, + "previous_response_id": null, + "prompt": null, + "reasoning": null, + "service_tier": "auto", + "status": "completed", + "text": null, + "top_logprobs": null, + "truncation": "disabled", + "usage": { + "input_tokens": 137, + "input_tokens_details": { + "cached_tokens": 128, + "input_tokens_per_turn": [ + 137 + ], + "cached_tokens_per_turn": [ + 128 + ] + }, + "output_tokens": 106, + "output_tokens_details": { + "reasoning_tokens": 89, + "tool_output_tokens": 0, + "output_tokens_per_turn": [ + 106 + ], + "tool_output_tokens_per_turn": [ + 0 + ] + }, + "total_tokens": 243 + }, + "user": null, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "kv_transfer_params": null, + "input_messages": null, + "output_messages": null + } + } + }, + { + "request": { + "method": "POST", + "path": "/v1/responses", + "body": { + "model": "openai/gpt-oss-20b", + "input": [ + { + "type": "message", + "role": "user", + "content": "Use the file_search tool to find information about Rust memory safety ownership." + }, + { + "arguments": "{\"query\":\"Rust memory safety ownership\"}", + "call_id": "call_8e217e84eeadc2fb", + "name": "file_search", + "type": "function_call", + "id": "fc_8e217e84eeadc2fb", + "status": null + }, + { + "type": "function_call_output", + "call_id": "call_8e217e84eeadc2fb", + "output": "{\"results\":[{\"file_id\":\"file_abc\",\"filename\":\"rust-memory-safety.txt\",\"score\":0.95,\"attributes\":{},\"content\":[{\"type\":\"text\",\"text\":\"Rust enforces memory safety without a garbage collector through ownership, borrowing, and lifetimes.\"}]}]}" + } + ], + "tools": [ + { + "type": "function", + "name": "file_search", + "description": "Search uploaded files for relevant passages.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query." + } + }, + "required": [ + "query" + ] + } + } + ], + "tool_choice": "auto", + "stream": false + } + }, + "response": { + "status_code": 200, + "body": { + "id": "resp_92585e571ba11e5e", + "created_at": 1781699421, + "incomplete_details": null, + "instructions": null, + "metadata": null, + "model": "openai/gpt-oss-20b", + "object": "response", + "output": [ + { + "id": "rs_b6dcc6546ab2d7c2", + "summary": [], + "type": "reasoning", + "content": [ + { + "text": "We have the file content. Provide answer.", + "type": "reasoning_text" + } + ], + "encrypted_content": null, + "status": null + }, + { + "id": "msg_a7c9b9489fe74d58", + "content": [ + { + "annotations": [], + "text": "**Rust Memory Safety & Ownership**\n\nRust guarantees memory safety at compile time without relying on a garbage collector. The core concepts that enable this are:\n\n- **Ownership** – Each value has a single owner.\n- **Borrowing** – References (`&T` for immutable and `&mut T` for mutable) allow temporary access without taking ownership.\n- **Lifetimes** – Compile‑time annotations that ensure references do not outlive the data they point to.\n\nThese rules together prevent common bugs such as use‑after‑free, double frees, and data races in concurrent programs.", + "type": "output_text", + "logprobs": null + } + ], + "role": "assistant", + "status": "completed", + "type": "message" + } + ], + "parallel_tool_calls": true, + "temperature": 1.0, + "tool_choice": "auto", + "tools": [ + { + "name": "file_search", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query." + } + }, + "required": [ + "query" + ] + }, + "strict": null, + "type": "function", + "description": "Search uploaded files for relevant passages." + } + ], + "top_p": 1.0, + "background": false, + "max_output_tokens": 16155, + "max_tool_calls": null, + "previous_response_id": null, + "prompt": null, + "reasoning": null, + "service_tier": "auto", + "status": "completed", + "text": null, + "top_logprobs": null, + "truncation": "disabled", + "usage": { + "input_tokens": 229, + "input_tokens_details": { + "cached_tokens": 224, + "input_tokens_per_turn": [ + 229 + ], + "cached_tokens_per_turn": [ + 224 + ] + }, + "output_tokens": 137, + "output_tokens_details": { + "reasoning_tokens": 10, + "tool_output_tokens": 0, + "output_tokens_per_turn": [ + 137 + ], + "tool_output_tokens_per_turn": [ + 0 + ] + }, + "total_tokens": 366 + }, + "user": null, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "kv_transfer_params": null, + "input_messages": null, + "output_messages": null + } + } + } + ] +} diff --git a/crates/agentic-server/tests/cassettes/record_file_search_cassettes.py b/crates/agentic-server/tests/cassettes/record_file_search_cassettes.py new file mode 100644 index 0000000..e7147cd --- /dev/null +++ b/crates/agentic-server/tests/cassettes/record_file_search_cassettes.py @@ -0,0 +1,151 @@ +"""Record vLLM Responses API cassettes for the file_search integration test.""" + +from __future__ import annotations + +import datetime as dt +import json +import os +import sys +import urllib.error +import urllib.request +from pathlib import Path +from typing import Any + + +DEFAULT_MODEL = "openai/gpt-oss-20b" +DEFAULT_PROMPT = "Use the file_search tool to find information about Rust memory safety ownership." +DEFAULT_OUTPUT = Path(__file__).with_name("file_search") / "vllm-file-search-openai-gpt-oss-20b.json" + +SEARCH_RESULT = { + "results": [ + { + "file_id": "file_abc", + "filename": "rust-memory-safety.txt", + "score": 0.95, + "attributes": {}, + "content": [ + { + "type": "text", + "text": "Rust enforces memory safety without a garbage collector through ownership, borrowing, and lifetimes.", + } + ], + } + ] +} + + +def post_json(base_url: str, payload: dict[str, Any]) -> dict[str, Any]: + url = f"{base_url.rstrip('/')}/v1/responses" + data = json.dumps(payload).encode("utf-8") + request = urllib.request.Request(url, data=data, headers={"content-type": "application/json"}, method="POST") + try: + with urllib.request.urlopen(request, timeout=300) as response: + return json.loads(response.read().decode("utf-8")) + except urllib.error.HTTPError as err: + body = err.read().decode("utf-8", errors="replace") + raise RuntimeError(f"vLLM request failed with HTTP {err.code}: {body}") from err + + +def file_search_tool() -> dict[str, Any]: + return { + "type": "function", + "name": "file_search", + "description": "Search uploaded files for relevant passages.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query.", + } + }, + "required": ["query"], + }, + } + + +def find_file_search_call(response: dict[str, Any]) -> dict[str, Any]: + for item in response.get("output", []): + if item.get("type") == "function_call" and item.get("name") == "file_search": + return item + raise RuntimeError(f"recorded response did not include a file_search function call: {response}") + + +def main() -> int: + vllm_url = os.environ.get("VLLM_URL", "http://localhost:8000") + model = os.environ.get("MODEL", DEFAULT_MODEL) + output = Path(os.environ.get("OUTPUT", DEFAULT_OUTPUT)).resolve() + prompt = os.environ.get("PROMPT", DEFAULT_PROMPT) + tools = [file_search_tool()] + + first_request = { + "model": model, + "input": prompt, + "tools": tools, + "tool_choice": "auto", + "stream": False, + } + first_response = post_json(vllm_url, first_request) + call = find_file_search_call(first_response) + + second_request = { + "model": model, + "input": [ + { + "type": "message", + "role": "user", + "content": prompt, + }, + call, + { + "type": "function_call_output", + "call_id": call["call_id"], + "output": json.dumps(SEARCH_RESULT, separators=(",", ":")), + }, + ], + "tools": tools, + "tool_choice": "auto", + "stream": False, + } + second_response = post_json(vllm_url, second_request) + + cassette = { + "metadata": { + "recorded_at": dt.datetime.now(dt.UTC).isoformat(timespec="seconds"), + "model": model, + "note": "Harmony models reject tool_choice=required, so this cassette uses tool_choice=auto with a direct prompt.", + }, + "turns": [ + { + "request": { + "method": "POST", + "path": "/v1/responses", + "body": first_request, + }, + "response": { + "status_code": 200, + "body": first_response, + }, + }, + { + "request": { + "method": "POST", + "path": "/v1/responses", + "body": second_request, + }, + "response": { + "status_code": 200, + "body": second_response, + }, + }, + ], + } + + output.parent.mkdir(parents=True, exist_ok=True) + output.write_text(json.dumps(cassette, indent=2, ensure_ascii=False) + "\n", encoding="utf-8") + print(f"recorded file_search cassette -> {output}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/crates/agentic-server/tests/common/mod.rs b/crates/agentic-server/tests/common/mod.rs index d852e85..163ced2 100644 --- a/crates/agentic-server/tests/common/mod.rs +++ b/crates/agentic-server/tests/common/mod.rs @@ -1,15 +1,24 @@ +use std::convert::Infallible; use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; use axum::Router; -use axum::response::IntoResponse; -use axum::routing::get; +use axum::body::Body; +use axum::extract::Request; +use axum::response::{IntoResponse, Response}; +use axum::routing::{get, post}; +use bytes::Bytes; +use futures::stream; use http::StatusCode; use tokio::net::TcpListener; +use tokio::sync::Mutex; use agentic_core::config::Config; use agentic_core::executor::{ConversationHandler, ExecutionContext, ResponseHandler}; use agentic_core::proxy::ProxyState; -use agentic_core::storage::{ConversationStore, ResponseStore}; +use agentic_core::storage::{ConversationStore, ResponseStore, create_pool_with_schema}; +use agentic_core::uuid7_str; +use agentic_core::vector_search::ogx::OgxStore; use agentic_server::app::{AppState, ServerConfig, build_router}; pub fn test_config(llm_url: &str) -> Config { @@ -19,6 +28,8 @@ pub fn test_config(llm_url: &str) -> Config { llm_ready_timeout_s: 5.0, llm_ready_interval_s: 0.1, db_url: None, + ogx_base_url: "http://127.0.0.1:1".to_owned(), + max_iterations: 10, } } @@ -55,3 +66,318 @@ pub async fn spawn_gateway(state: AppState) -> (String, tokio::task::JoinHandle< let handle = tokio::spawn(async move { axum::serve(listener, router).await.unwrap() }); (format!("http://{addr}"), handle) } + +pub async fn start_gateway(vllm_port: u16, ogx_port: Option, api_key: Option<&str>) -> (String, u16) { + let ogx_base = match ogx_port { + Some(p) => format!("http://127.0.0.1:{p}"), + None => "http://127.0.0.1:1".to_owned(), + }; + start_gateway_with_ogx_base(vllm_port, &ogx_base, api_key).await +} + +pub async fn start_gateway_with_ogx_base(vllm_port: u16, ogx_base: &str, api_key: Option<&str>) -> (String, u16) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let addr = format!("127.0.0.1:{port}"); + + let llm_url = format!("http://127.0.0.1:{vllm_port}"); + + let mut config = test_config(&llm_url); + config.openai_api_key = api_key.map(String::from); + config.ogx_base_url = ogx_base.to_owned(); + config.db_url = Some(format!("sqlite:///tmp/{}.db", uuid7_str("agentic-api-test-"))); + + let proxy_state = ProxyState::new(config.clone()).unwrap(); + let pool = create_pool_with_schema(config.db_url.as_deref()).await.unwrap(); + let ogx_store = Arc::new(OgxStore::new(ogx_base, reqwest::Client::new())); + let exec_ctx = ExecutionContext::new( + ConversationHandler::new(ConversationStore::new(pool.clone())), + ResponseHandler::new(ResponseStore::new(pool)), + Arc::new(reqwest::Client::new()), + config.llm_api_base.clone(), + config.openai_api_key.clone(), + ) + .with_vector_search(ogx_store, config.max_iterations); + + let state = AppState { + proxy_state, + exec_ctx: Arc::new(exec_ctx), + llm_api_base: config.llm_api_base, + }; + + let server_config = ServerConfig::from_env(); + let router = build_router(state, &server_config); + + tokio::spawn(async move { + axum::serve(listener, router).await.unwrap(); + }); + + (addr, port) +} + +async fn health_handler() -> impl IntoResponse { + StatusCode::OK +} + +async fn responses_handler(req: Request) -> Response { + let headers = req.headers().clone(); + let body_bytes = axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024) + .await + .unwrap_or_default(); + + let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap_or_default(); + + if body + .get("echo_auth") + .and_then(serde_json::Value::as_bool) + .unwrap_or(false) + { + let auth = headers.get("authorization").and_then(|v| v.to_str().ok()).unwrap_or(""); + let resp_body = serde_json::json!({"authorization": auth}); + return ( + StatusCode::OK, + [("content-type", "application/json")], + serde_json::to_string(&resp_body).unwrap(), + ) + .into_response(); + } + + if body.get("force_error").and_then(serde_json::Value::as_u64) == Some(429) { + return ( + StatusCode::TOO_MANY_REQUESTS, + [("content-type", "application/json")], + r#"{"error":{"message":"rate limited","code":"rate_limit"}}"#, + ) + .into_response(); + } + + if body.get("stream").and_then(serde_json::Value::as_bool).unwrap_or(false) { + let chunks: Vec> = vec![ + Ok(Bytes::from( + "data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello\"}\n\n", + )), + Ok(Bytes::from("data: [DONE]\n\n")), + ]; + let body = Body::from_stream(stream::iter(chunks)); + return ( + StatusCode::OK, + [("content-type", "text/event-stream; charset=utf-8")], + body, + ) + .into_response(); + } + + let out = r#"{"id":"resp_test","object":"response","status":"completed","output":[]}"#; + (StatusCode::OK, [("content-type", "application/json")], out).into_response() +} + +pub async fn spawn_vllm() -> (u16, tokio::task::JoinHandle<()>) { + let app = Router::new() + .route("/health", get(health_handler)) + .route("/v1/responses", post(responses_handler)); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let handle = tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + (port, handle) +} + +pub async fn spawn_mid_stream_failure_vllm() -> (u16, tokio::task::JoinHandle<()>) { + async fn handler(_req: Request) -> Response { + let (tx, rx) = tokio::sync::mpsc::channel::>(2); + tokio::spawn(async move { + let _ = tx + .send(Ok(Bytes::from( + "data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello\"}\n\n", + ))) + .await; + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + drop(tx); + }); + let stream = tokio_stream::wrappers::ReceiverStream::new(rx); + let body = Body::from_stream(stream); + ( + StatusCode::OK, + [("content-type", "text/event-stream; charset=utf-8")], + body, + ) + .into_response() + } + + let app = Router::new() + .route("/health", get(health_handler)) + .route("/v1/responses", post(handler)); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let handle = tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + (port, handle) +} + +pub async fn spawn_vllm_with_tool_calls(responses: Vec) -> (u16, tokio::task::JoinHandle<()>) { + let responses = Arc::new(responses); + let counter = Arc::new(AtomicUsize::new(0)); + + let app = Router::new().route("/health", get(health_handler)).route( + "/v1/responses", + post({ + let responses = Arc::clone(&responses); + let counter = Arc::clone(&counter); + move |_req: Request| { + let responses = Arc::clone(&responses); + let counter = Arc::clone(&counter); + async move { + let idx = counter.fetch_add(1, Ordering::SeqCst); + let resp = responses.get(idx).unwrap_or(responses.last().unwrap()); + ( + StatusCode::OK, + [("content-type", "application/json")], + serde_json::to_string(resp).unwrap(), + ) + .into_response() + } + } + }), + ); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let handle = tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + (port, handle) +} + +pub async fn spawn_vllm_recording( + responses: Vec, +) -> (u16, Arc>>, tokio::task::JoinHandle<()>) { + let responses = Arc::new(responses); + let counter = Arc::new(AtomicUsize::new(0)); + let requests = Arc::new(Mutex::new(Vec::new())); + + let app = Router::new().route("/health", get(health_handler)).route( + "/v1/responses", + post({ + let responses = Arc::clone(&responses); + let counter = Arc::clone(&counter); + let requests_for_handler = Arc::clone(&requests); + move |req: Request| { + let responses = Arc::clone(&responses); + let counter = Arc::clone(&counter); + let requests_for_handler = Arc::clone(&requests_for_handler); + async move { + let body_bytes = axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024) + .await + .unwrap_or_default(); + let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap_or_default(); + requests_for_handler.lock().await.push(body); + + let idx = counter.fetch_add(1, Ordering::SeqCst); + let resp = responses.get(idx).unwrap_or(responses.last().unwrap()); + ( + StatusCode::OK, + [("content-type", "application/json")], + serde_json::to_string(resp).unwrap(), + ) + .into_response() + } + } + }), + ); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let handle = tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + (port, requests, handle) +} + +pub async fn spawn_ogx() -> (u16, tokio::task::JoinHandle<()>) { + async fn search_handler(_req: Request) -> Response { + let body = serde_json::json!({ + "object": "vector_store.search_results.page", + "search_query": ["test query"], + "data": [{ + "file_id": "file_abc", + "filename": "doc.txt", + "score": 0.95, + "attributes": {}, + "content": [{"type": "text", "text": "relevant content from doc"}] + }], + "has_more": false + }); + ( + StatusCode::OK, + [("content-type", "application/json")], + serde_json::to_string(&body).unwrap(), + ) + .into_response() + } + + let app = Router::new().route("/v1/vector_stores/{store_id}/search", post(search_handler)); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let handle = tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + (port, handle) +} + +pub async fn spawn_ogx_recording() -> (u16, Arc>>, tokio::task::JoinHandle<()>) { + let requests = Arc::new(Mutex::new(Vec::new())); + + let app = Router::new().route( + "/v1/vector_stores/{store_id}/search", + post({ + let requests_for_handler = Arc::clone(&requests); + move |req: Request| { + let requests_for_handler = Arc::clone(&requests_for_handler); + async move { + let body_bytes = axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024) + .await + .unwrap_or_default(); + let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap_or_default(); + requests_for_handler.lock().await.push(body); + + let response = serde_json::json!({ + "object": "vector_store.search_results.page", + "search_query": ["test query"], + "data": [{ + "file_id": "file_abc", + "filename": "doc.txt", + "score": 0.95, + "attributes": {}, + "content": [{"type": "text", "text": "relevant content from doc"}] + }], + "has_more": false + }); + ( + StatusCode::OK, + [("content-type", "application/json")], + serde_json::to_string(&response).unwrap(), + ) + .into_response() + } + } + }), + ); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let handle = tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + (port, requests, handle) +} diff --git a/crates/agentic-server/tests/conversations_test.rs b/crates/agentic-server/tests/conversations_test.rs index 44204bb..48894e8 100644 --- a/crates/agentic-server/tests/conversations_test.rs +++ b/crates/agentic-server/tests/conversations_test.rs @@ -1,3 +1,4 @@ +#[allow(dead_code)] mod common; use http::StatusCode; diff --git a/crates/agentic-server/tests/cors_test.rs b/crates/agentic-server/tests/cors_test.rs index 4719e9f..84df570 100644 --- a/crates/agentic-server/tests/cors_test.rs +++ b/crates/agentic-server/tests/cors_test.rs @@ -1,3 +1,4 @@ +#[allow(dead_code)] mod common; use common::{spawn_gateway, spawn_mock_llm, test_config, test_state}; diff --git a/crates/agentic-server/tests/health_test.rs b/crates/agentic-server/tests/health_test.rs index 72d6dcb..e6b6c3e 100644 --- a/crates/agentic-server/tests/health_test.rs +++ b/crates/agentic-server/tests/health_test.rs @@ -1,3 +1,4 @@ +#[allow(dead_code)] mod common; use agentic_core::config::Config; diff --git a/crates/agentic-server/tests/integration/ogx-config.yaml b/crates/agentic-server/tests/integration/ogx-config.yaml new file mode 100644 index 0000000..c79c86e --- /dev/null +++ b/crates/agentic-server/tests/integration/ogx-config.yaml @@ -0,0 +1,56 @@ +version: 2 +distro_name: agentic-api-test + +apis: + - inference + - files + - vector_io + - file_processors + +providers: + inference: + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers + config: + trust_remote_code: true + + files: + - provider_id: localfs + provider_type: inline::localfs + config: + storage_dir: /tmp/ogx-test/files + metadata_store: + table_name: files_metadata + backend: sql_default + + vector_io: + - provider_id: faiss + provider_type: inline::faiss + config: + persistence: + namespace: vector_io::faiss + backend: kv_default + + file_processors: + - provider_id: auto + provider_type: inline::auto + config: {} + +storage: + backends: + kv_default: + type: kv_sqlite + db_path: /tmp/ogx-test/kvstore.db + sql_default: + type: sql_sqlite + db_path: /tmp/ogx-test/sql_store.db + stores: + metadata: + namespace: registry + backend: kv_default + inference: + table_name: inference_store + backend: sql_default + vector_stores: + table_name: vector_store_metadata + backend: sql_default diff --git a/crates/agentic-server/tests/integration/run.sh b/crates/agentic-server/tests/integration/run.sh new file mode 100755 index 0000000..b588f73 --- /dev/null +++ b/crates/agentic-server/tests/integration/run.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash +set -euo pipefail + +OGX_PORT="${OGX_PORT:-8321}" +OGX_PID="" + +cleanup() { + if [ -n "$OGX_PID" ] && kill -0 "$OGX_PID" 2>/dev/null; then + echo "Stopping OGx (pid $OGX_PID)..." + kill "$OGX_PID" 2>/dev/null || true + wait "$OGX_PID" 2>/dev/null || true + fi +} +trap cleanup EXIT + +OGX_CMD="${OGX_CMD:-ogx}" +OGX_CONFIG="$(cd "$(dirname "$0")" && pwd)/ogx-config.yaml" + +echo "Starting OGx on port $OGX_PORT..." +HF_HUB_TRUST_REMOTE_CODE=1 $OGX_CMD run "$OGX_CONFIG" --port "$OGX_PORT" > /tmp/ogx-server.log 2>&1 & +OGX_PID=$! + +echo "Waiting for OGx to be ready..." +for i in $(seq 1 60); do + if curl -sf "http://localhost:$OGX_PORT/v1/health" > /dev/null 2>&1; then + echo "OGx is ready." + break + fi + if ! kill -0 "$OGX_PID" 2>/dev/null; then + echo "OGx process exited unexpectedly. Logs:" + cat /tmp/ogx-server.log + exit 1 + fi + sleep 1 +done + +if ! curl -sf "http://localhost:$OGX_PORT/v1/health" > /dev/null 2>&1; then + echo "OGx failed to start within 60s. Logs:" + cat /tmp/ogx-server.log + exit 1 +fi + +echo "Running integration tests..." +OGX_BASE_URL="http://localhost:$OGX_PORT" cargo test -p agentic-server --test integration_test -- --nocapture + +echo "Integration tests passed." diff --git a/crates/agentic-server/tests/integration_test.rs b/crates/agentic-server/tests/integration_test.rs new file mode 100644 index 0000000..e60df71 --- /dev/null +++ b/crates/agentic-server/tests/integration_test.rs @@ -0,0 +1,242 @@ +#[allow(dead_code)] +mod common; + +use common::{spawn_vllm_recording, start_gateway_with_ogx_base}; +use serde::Deserialize; +use serde_json::Value; + +const FILE_SEARCH_VLLM_CASSETTE: &str = concat!( + env!("CARGO_MANIFEST_DIR"), + "/tests/cassettes/file_search/vllm-file-search-openai-gpt-oss-20b.json" +); + +#[derive(Debug, Deserialize)] +struct VllmCassette { + turns: Vec, +} + +#[derive(Debug, Deserialize)] +struct VllmTurn { + response: VllmResponse, +} + +#[derive(Debug, Deserialize)] +struct VllmResponse { + status_code: u16, + body: Value, +} + +fn ogx_base_url() -> Option { + std::env::var("OGX_BASE_URL").ok() +} + +fn load_file_search_vllm_responses() -> Vec { + let text = std::fs::read_to_string(FILE_SEARCH_VLLM_CASSETTE) + .unwrap_or_else(|err| panic!("failed to read cassette {FILE_SEARCH_VLLM_CASSETTE}: {err}")); + let cassette: VllmCassette = serde_json::from_str(&text) + .unwrap_or_else(|err| panic!("failed to parse cassette {FILE_SEARCH_VLLM_CASSETTE}: {err}")); + cassette + .turns + .into_iter() + .map(|turn| { + assert_eq!(turn.response.status_code, 200, "cassette response should be successful"); + turn.response.body + }) + .collect() +} + +fn output_text(body: &Value) -> String { + body["output"] + .as_array() + .into_iter() + .flatten() + .filter_map(|item| item["content"].as_array()) + .flatten() + .filter_map(|content| content["text"].as_str()) + .collect() +} + +async fn find_embedding_model(client: &reqwest::Client, ogx_url: &str) -> (String, u64) { + let models_resp = client.get(format!("{ogx_url}/v1/models")).send().await.unwrap(); + let models: serde_json::Value = models_resp.json().await.unwrap(); + let embedding_model = models["data"] + .as_array() + .and_then(|arr| { + arr.iter() + .find(|m| m["custom_metadata"]["model_type"].as_str() == Some("embedding")) + }) + .expect("OGx should have at least one embedding model") + .clone(); + let model_id = embedding_model["id"].as_str().unwrap().to_owned(); + let dim = embedding_model["custom_metadata"]["embedding_dimension"] + .as_u64() + .unwrap(); + (model_id, dim) +} + +async fn create_vector_store(client: &reqwest::Client, ogx_url: &str, model_id: &str, dim: u64) -> String { + let vs_resp = client + .post(format!("{ogx_url}/v1/vector_stores")) + .json(&serde_json::json!({ + "name": "integration-test-docs", + "metadata": { "embedding_model": model_id, "embedding_dimension": dim } + })) + .send() + .await + .unwrap(); + assert!(vs_resp.status().is_success(), "Failed to create vector store"); + let vs: serde_json::Value = vs_resp.json().await.unwrap(); + vs["id"].as_str().unwrap().to_owned() +} + +async fn upload_and_attach(client: &reqwest::Client, ogx_url: &str, vs_id: &str) { + let file_content = "Rust enforces memory safety without a garbage collector through its ownership system with borrowing and lifetimes. The borrow checker ensures references do not outlive the data they point to."; + + let form = reqwest::multipart::Form::new().text("purpose", "assistants").part( + "file", + reqwest::multipart::Part::text(file_content.to_owned()) + .file_name("rust-memory-safety.txt") + .mime_str("text/plain") + .unwrap(), + ); + + let file_resp = client + .post(format!("{ogx_url}/v1/files")) + .multipart(form) + .send() + .await + .unwrap(); + assert!(file_resp.status().is_success(), "Failed to upload file"); + + let file: serde_json::Value = file_resp.json().await.unwrap(); + let file_id = file["id"].as_str().unwrap(); + eprintln!("Uploaded file: {file_id}"); + + let attach_resp = client + .post(format!("{ogx_url}/v1/vector_stores/{vs_id}/files")) + .json(&serde_json::json!({"file_id": file_id})) + .send() + .await + .unwrap(); + assert!(attach_resp.status().is_success(), "Failed to attach file"); + + let attach: serde_json::Value = attach_resp.json().await.unwrap(); + let status = attach["status"].as_str().unwrap_or("unknown"); + assert_eq!( + status, + "completed", + "File attachment failed: {}", + attach + .get("last_error") + .map_or("none".to_owned(), std::string::ToString::to_string) + ); +} + +async fn assert_gateway_file_search_uses_ogx(client: &reqwest::Client, ogx_url: &str, vs_id: &str) { + let (vllm_port, requests, _vllm_handle) = spawn_vllm_recording(load_file_search_vllm_responses()).await; + let (gateway_addr, _) = start_gateway_with_ogx_base(vllm_port, ogx_url, None).await; + + let gateway_resp = client + .post(format!("http://{gateway_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": "How does Rust provide memory safety?", + "tools": [{"type": "file_search", "vector_store_ids": [vs_id]}] + })) + .send() + .await + .unwrap(); + let gateway_status = gateway_resp.status(); + let gateway_body: Value = gateway_resp.json().await.unwrap(); + assert!( + gateway_status.is_success(), + "gateway file_search failed: {gateway_body}" + ); + + let answer = output_text(&gateway_body); + assert!( + answer.contains("ownership"), + "gateway should return the recorded final vLLM answer, got: {answer}" + ); + + let requests = requests.lock().await; + assert_eq!( + requests.len(), + 2, + "gateway should call vLLM before and after OGX search" + ); + + let first_tools = requests[0]["tools"] + .as_array() + .expect("first request should include tools"); + assert_eq!(first_tools[0]["type"], "function"); + assert_eq!(first_tools[0]["name"], "file_search"); + + let second_input = requests[1]["input"] + .as_array() + .expect("second request should include input items"); + let tool_output = second_input + .iter() + .find(|item| item["type"] == "function_call_output") + .expect("gateway should append function_call_output"); + let output = tool_output["output"] + .as_str() + .expect("tool output should be a JSON string"); + let output_json: serde_json::Value = serde_json::from_str(output).expect("tool output should parse as JSON"); + let gateway_results = output_json["results"] + .as_array() + .expect("tool output should include results"); + assert!( + !gateway_results.is_empty(), + "gateway should pass OGX search results back to vLLM" + ); +} + +#[tokio::test] +async fn test_vector_search_with_ogx() { + let Some(ogx_url) = ogx_base_url() else { + eprintln!("Skipping: OGX_BASE_URL not set"); + return; + }; + + let client = reqwest::Client::new(); + + let (model_id, dim) = find_embedding_model(&client, &ogx_url).await; + eprintln!("Using embedding model: {model_id} (dim={dim})"); + + let vs_id = create_vector_store(&client, &ogx_url, &model_id, dim).await; + eprintln!("Created vector store: {vs_id}"); + + upload_and_attach(&client, &ogx_url, &vs_id).await; + + let search_resp = client + .post(format!("{ogx_url}/v1/vector_stores/{vs_id}/search")) + .json(&serde_json::json!({ + "query": "memory safety ownership", + "max_num_results": 2 + })) + .send() + .await + .unwrap(); + assert!(search_resp.status().is_success(), "Search failed"); + + let results: serde_json::Value = search_resp.json().await.unwrap(); + let data = results["data"].as_array().expect("search should return data array"); + assert!(!data.is_empty(), "search should return at least one result"); + + let top_result = &data[0]; + let score = top_result["score"].as_f64().unwrap_or(0.0); + assert!(score > 0.0, "top result should have a positive score"); + + let content = top_result["content"] + .as_array() + .and_then(|c| c.first()) + .and_then(|c| c["text"].as_str()) + .unwrap_or(""); + assert!(!content.is_empty(), "top result should have content text"); + + eprintln!("Search returned {} results, top score: {score:.3}", data.len()); + eprintln!("Top result: {content}"); + + assert_gateway_file_search_uses_ogx(&client, &ogx_url, &vs_id).await; +} diff --git a/crates/agentic-server/tests/proxy_test.rs b/crates/agentic-server/tests/proxy_test.rs new file mode 100644 index 0000000..d06c122 --- /dev/null +++ b/crates/agentic-server/tests/proxy_test.rs @@ -0,0 +1,167 @@ +#[allow(dead_code)] +mod common; + +use common::{spawn_mid_stream_failure_vllm, spawn_vllm, start_gateway}; + +#[tokio::test] +async fn test_non_stream_passthrough() { + let (vllm_port, _h) = spawn_vllm().await; + let (gw_addr, _) = start_gateway(vllm_port, None, Some("env-vllm-key")).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": [{"role": "user", "content": "hello"}], + "store": false + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + let body: serde_json::Value = resp.json().await.unwrap(); + assert_eq!(body["id"], "resp_test"); +} + +#[tokio::test] +async fn test_string_input_passthrough() { + let (vllm_port, _h) = spawn_vllm().await; + let (gw_addr, _) = start_gateway(vllm_port, None, Some("env-vllm-key")).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": "hello", + "store": false + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + let body: serde_json::Value = resp.json().await.unwrap(); + assert_eq!(body["id"], "resp_test"); +} + +#[tokio::test] +async fn test_stream_passthrough() { + let (vllm_port, _h) = spawn_vllm().await; + let (gw_addr, _) = start_gateway(vllm_port, None, Some("env-vllm-key")).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": [{"role": "user", "content": "hello"}], + "store": false, + "stream": true + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + + let text = resp.text().await.unwrap(); + assert!(text.contains("data: [DONE]")); + assert!(text.contains("response.output_text.delta")); +} + +#[tokio::test] +async fn test_auth_injection() { + let (vllm_port, _h) = spawn_vllm().await; + let (gw_addr, _) = start_gateway(vllm_port, None, Some("env-vllm-key")).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({"model": "model-a", "input": [], "store": false, "echo_auth": true})) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + let body: serde_json::Value = resp.json().await.unwrap(); + assert_eq!(body["authorization"], "Bearer env-vllm-key"); +} + +#[tokio::test] +async fn test_client_auth_precedence() { + let (vllm_port, _h) = spawn_vllm().await; + let (gw_addr, _) = start_gateway(vllm_port, None, Some("env-vllm-key")).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({"model": "model-a", "input": [], "store": false, "echo_auth": true})) + .header("authorization", "Bearer client-token") + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + let body: serde_json::Value = resp.json().await.unwrap(); + assert_eq!(body["authorization"], "Bearer client-token"); +} + +#[tokio::test] +async fn test_vllm_http_error_passthrough() { + let (vllm_port, _h) = spawn_vllm().await; + let (gw_addr, _) = start_gateway(vllm_port, None, Some("env-vllm-key")).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({"model": "model-a", "input": [], "store": false, "force_error": 429})) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 429); + let body: serde_json::Value = resp.json().await.unwrap(); + assert_eq!(body["error"]["message"], "rate limited"); + assert_eq!(body["error"]["code"], "rate_limit"); +} + +#[tokio::test] +async fn test_mid_stream_failure_closes_cleanly() { + let (vllm_port, _h) = spawn_mid_stream_failure_vllm().await; + let (gw_addr, _) = start_gateway(vllm_port, None, Some("env-vllm-key")).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": [], + "store": false, + "stream": true + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + let text = resp.text().await.unwrap_or_default(); + assert!(text.contains("response.output_text.delta")); +} + +#[tokio::test] +async fn test_connect_error_maps_to_502() { + let (gw_addr, _) = start_gateway(1, None, None).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({"model": "model-a", "input": [], "store": false})) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 502); +} diff --git a/crates/agentic-server/tests/responses_test.rs b/crates/agentic-server/tests/responses_test.rs index d639679..449ddc9 100644 --- a/crates/agentic-server/tests/responses_test.rs +++ b/crates/agentic-server/tests/responses_test.rs @@ -1,3 +1,4 @@ +#[allow(dead_code)] mod common; use axum::Router; diff --git a/docs/architecture/index.md b/docs/architecture/index.md new file mode 100644 index 0000000..700560b --- /dev/null +++ b/docs/architecture/index.md @@ -0,0 +1,78 @@ +# Architecture + +## Overview + +The vLLM Agentic API is a Rust gateway that sits between clients and vLLM, adding stateful capabilities on top of vLLM's stateless Responses API. The gateway is structured as a three-crate workspace. + +```mermaid +graph TD + Client -->|POST /v1/responses| Gateway[agentic-server :9000] + Gateway -->|proxy| vLLM[vLLM :8000] + Gateway -.->|vector search| OGx[OGx :8080] +``` + +## Crate Structure + +| Crate | Role | +|-------|------| +| `agentic-core` | Framework-agnostic core: inference caller, storage, vector search traits, OGx client | +| `agentic-server` | Axum HTTP server: routes, handler, CLI, agentic loop | +| `agentic-praxis` | Reserved for Praxis gateway adapter | + +### agentic-core + +Pure async Rust with no framework dependency. Contains: + +- **Proxy** (`proxy.rs`) — HTTP client that forwards requests to vLLM with auth injection, header filtering, and streaming support +- **Readiness** (`readiness.rs`) — Polls vLLM's `/health` endpoint until ready +- **Storage** (`storage/`) — SQLx-based CRUD for conversations and responses (SQLite, PostgreSQL, MySQL) +- **Vector search** (`vector_search/`) — `VectorSearch` trait and OGx implementation for file_search tool calls +- **Types** (`types/`) — Serde structs for the Responses API IO types + +### agentic-server + +Axum-based HTTP server that wires everything together: + +- **Handler** (`handler.rs`) — Request routing: proxies `store=false` requests to vLLM, otherwise runs the executor and its agentic loop when `file_search` tools are present +- **App** (`app.rs`) — Router with `/health`, `/ready`, `/v1/responses` routes and CORS +- **CLI** (`main.rs`) — Clap-based CLI with `--llm-api-base`, `--ogx-base-url`, `--max-iterations`, `--db-url`, and a `serve` subcommand that spawns vLLM as a subprocess + +## Request Flow + +### Passthrough (`store=false`) + +``` +Client → Gateway → vLLM → Gateway → Client +``` + +Requests with `store: false` and no state IDs are forwarded to vLLM unchanged. Streaming responses are proxied as SSE. + +### Stateful Requests + +Requests with `store: true`, `previous_response_id`, or `conversation_id` run through the executor so the gateway can assign response IDs and persist/replay conversation history. + +### Agentic Loop (`file_search`) + +When the request includes `tools: [{type: "file_search", vector_store_ids: [...]}]`: + +1. Reject streaming file-search requests until the tool loop can emit interleaved SSE events +2. Hydrate `previous_response_id` history from the response store, if present +3. Convert `file_search` to a `function` tool definition for vLLM +4. Send to vLLM (non-streaming, forced `stream: false`) +5. If vLLM returns `function_call` output items with `name: "file_search"`: + - Extract the query from the call arguments + - Search each vector store via OGx (`POST /v1/vector_stores/{id}/search`) + - Append the tool call output and search results to the input + - Go to step 4 +6. If no tool calls, persist the response and return the final response to the client +7. If `max_iterations` is reached, return a 502 error + +## OGx Integration + +[OGx](https://github.com/meta-llama/llama-stack) provides the vector search backend via its OpenAI-compatible API: + +| Endpoint | Purpose | +|----------|---------| +| `POST /v1/vector_stores/{id}/search` | Execute vector search for file_search tool calls | + +The `OgxStore` struct implements the `VectorSearch` trait, so the handler depends on the trait, not OGx directly.