From b11177a2e0752f6b956f128fa9cf0d27fc32756e Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Mon, 1 Jun 2026 11:55:35 -0400 Subject: [PATCH 1/6] feat: add OGx-backed agentic loop and state hydration Add trait-based store backend (ResponseStore, VectorSearch) in agentic-core with OGx as the first implementation. The gateway handler detects file_search tool calls and runs an agentic loop: send to vLLM with file_search converted to a function tool, execute vector search via OGx, feed results back, repeat until no more tool calls or max iterations reached. State hydration prepends previous conversation history when previous_response_id is set. - agentic-core: types, store traits, OGx impl, error expansion - agentic-server: handler with agentic loop, state hydration, proxy - 22 new tests: proxy, agentic loop, state hydration, integration - CI: integration job with OGx vector search Co-Authored-By: Claude Opus 4.6 Signed-off-by: Francisco Javier Arceo --- .github/workflows/rust.yml | 32 ++ Cargo.lock | 359 ++++-------------- Makefile | 5 +- crates/agentic-core/Cargo.toml | 3 +- crates/agentic-core/src/error.rs | 15 + crates/agentic-core/src/lib.rs | 1 + crates/agentic-core/src/vector_search/mod.rs | 11 + crates/agentic-core/src/vector_search/ogx.rs | 49 +++ .../agentic-core/src/vector_search/types.rs | 79 ++++ crates/agentic-server/Cargo.toml | 7 +- crates/agentic-server/benches/proxy_bench.rs | 12 +- crates/agentic-server/src/app.rs | 11 +- crates/agentic-server/src/handler.rs | 273 +++++++++++-- crates/agentic-server/src/main.rs | 25 +- crates/agentic-server/src/server.rs | 41 +- .../agentic-server/tests/agentic_loop_test.rs | 175 +++++++++ crates/agentic-server/tests/common/mod.rs | 231 +++++++++++ crates/agentic-server/tests/cors_test.rs | 13 +- crates/agentic-server/tests/health_test.rs | 41 +- .../tests/integration/ogx-config.yaml | 55 +++ .../agentic-server/tests/integration/run.sh | 46 +++ .../agentic-server/tests/integration_test.rs | 126 ++++++ 
crates/agentic-server/tests/proxy_test.rs | 142 +++++++ docs/architecture/index.md | 72 ++++ 24 files changed, 1472 insertions(+), 352 deletions(-) create mode 100644 crates/agentic-core/src/vector_search/mod.rs create mode 100644 crates/agentic-core/src/vector_search/ogx.rs create mode 100644 crates/agentic-core/src/vector_search/types.rs create mode 100644 crates/agentic-server/tests/agentic_loop_test.rs create mode 100644 crates/agentic-server/tests/common/mod.rs create mode 100644 crates/agentic-server/tests/integration/ogx-config.yaml create mode 100755 crates/agentic-server/tests/integration/run.sh create mode 100644 crates/agentic-server/tests/integration_test.rs create mode 100644 crates/agentic-server/tests/proxy_test.rs create mode 100644 docs/architecture/index.md diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 9ccb342..eca5345 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -48,3 +48,35 @@ jobs: - name: Run tests run: cargo test + + integration: + runs-on: ubuntu-latest + needs: rust + steps: + - name: Checkout code + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@e97e2d8cc328f1b50210efc529dca0028893a2d9 # v1 + with: + toolchain: stable + + - name: Cache cargo registry and build + uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: cargo-${{ runner.os }}-${{ hashFiles('Cargo.lock') }} + restore-keys: | + cargo-${{ runner.os }}- + + - name: Install uv + uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6 + + - name: Install OGx + run: uv tool install "ogx[starter]" --with "sentence-transformers>=5" --with "huggingface_hub<1.18" + + - name: Run integration tests + run: make integration-test diff --git a/Cargo.lock b/Cargo.lock index 463a8e9..84f0288 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7,6 +7,7 @@ name 
= "agentic-core" version = "0.1.0" dependencies = [ "async-stream", + "async-trait", "axum", "bytes", "chrono", @@ -47,6 +48,7 @@ dependencies = [ "reqwest", "serde_json", "tokio", + "tokio-stream", "tower-http", "tracing", "tracing-subscriber", @@ -160,6 +162,17 @@ dependencies = [ "syn", ] +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atoi" version = "2.0.0" @@ -247,9 +260,9 @@ checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" [[package]] name = "bitflags" -version = "2.12.1" +version = "2.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84d7ced0ae9557296835c32bf1b1e02b44c746701f898460fb000d7eaa84f00a" +checksum = "b4388bee8683e3d04af747c73422af53102d2bd24d9eadb6cbc100baef4b43f8" dependencies = [ "serde_core", ] @@ -303,17 +316,11 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" -[[package]] -name = "cfg_aliases" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" - [[package]] name = "chrono" -version = "0.4.44" +version = "0.4.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" +checksum = "1aa79e62e7697b8e29b513a68abacf485adcd1fe8284a4316c5ae868e6633327" dependencies = [ "iana-time-zone", "js-sys", @@ -795,24 +802,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" dependencies = [ "cfg-if", - "js-sys", "libc", "wasi", - "wasm-bindgen", -] - -[[package]] 
-name = "getrandom" -version = "0.3.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" -dependencies = [ - "cfg-if", - "js-sys", - "libc", - "r-efi 5.3.0", - "wasip2", - "wasm-bindgen", ] [[package]] @@ -823,7 +814,7 @@ checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" dependencies = [ "cfg-if", "libc", - "r-efi 6.0.0", + "r-efi", "wasip2", "wasip3", ] @@ -912,9 +903,9 @@ dependencies = [ [[package]] name = "http" -version = "1.4.1" +version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8be7462df143984c4598a256ef469b251d7d7f9e271135073e78fc535414f3d0" +checksum = "6970f50e31d6fc17d3fa27329444bfa74e196cf62e95052a3f6fee181dba6425" dependencies = [ "bytes", "itoa", @@ -976,22 +967,6 @@ dependencies = [ "want", ] -[[package]] -name = "hyper-rustls" -version = "0.27.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ca68d021ef39cf6463ab54c1d0f5daf03377b70561305bb89a8f83aab66e0f" -dependencies = [ - "http", - "hyper", - "hyper-util", - "rustls", - "tokio", - "tokio-rustls", - "tower-service", - "webpki-roots 1.0.7", -] - [[package]] name = "hyper-tls" version = "0.6.0" @@ -1216,13 +1191,12 @@ checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" [[package]] name = "js-sys" -version = "0.3.99" +version = "0.3.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "142bc4740e452c1e57ade0cbc129f139c9093e354346f0872ef985f4f5cf5f11" +checksum = "f2025f20d7a4fa7785846e7b63d10a76d3f1cee98ee5cb79ea59703f95e42162" dependencies = [ "cfg-if", "futures-util", - "once_cell", "wasm-bindgen", ] @@ -1309,15 +1283,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.31" +version = "0.4.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "113b30b4cd05f7c06868fdb2854f66a7b9fece9a48425351cd532e810d74024f" 
- -[[package]] -name = "lru-slab" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" +checksum = "953f07c43838f8e6f9758cab68bf5bed85465e7587ebe0b823f1bcd81978ad3a" [[package]] name = "matchers" @@ -1356,6 +1324,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "mio" version = "1.2.1" @@ -1404,7 +1382,7 @@ dependencies = [ "num-integer", "num-iter", "num-traits", - "rand 0.8.6", + "rand", "smallvec", "zeroize", ] @@ -1648,61 +1626,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "quinn" -version = "0.11.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" -dependencies = [ - "bytes", - "cfg_aliases", - "pin-project-lite", - "quinn-proto", - "quinn-udp", - "rustc-hash", - "rustls", - "socket2", - "thiserror", - "tokio", - "tracing", - "web-time", -] - -[[package]] -name = "quinn-proto" -version = "0.11.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098" -dependencies = [ - "bytes", - "getrandom 0.3.4", - "lru-slab", - "rand 0.9.4", - "ring", - "rustc-hash", - "rustls", - "rustls-pki-types", - "slab", - "thiserror", - "tinyvec", - "tracing", - "web-time", -] - -[[package]] -name = "quinn-udp" -version = "0.5.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" -dependencies = [ - 
"cfg_aliases", - "libc", - "once_cell", - "socket2", - "tracing", - "windows-sys 0.60.2", -] - [[package]] name = "quote" version = "1.0.45" @@ -1712,12 +1635,6 @@ dependencies = [ "proc-macro2", ] -[[package]] -name = "r-efi" -version = "5.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" - [[package]] name = "r-efi" version = "6.0.0" @@ -1731,18 +1648,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5ca0ecfa931c29007047d1bc58e623ab12e5590e8c7cc53200d5202b69266d8a" dependencies = [ "libc", - "rand_chacha 0.3.1", - "rand_core 0.6.4", -] - -[[package]] -name = "rand" -version = "0.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44c5af06bb1b7d3216d91932aed5265164bf384dc89cd6ba05cf59a35f5f76ea" -dependencies = [ - "rand_chacha 0.9.0", - "rand_core 0.9.5", + "rand_chacha", + "rand_core", ] [[package]] @@ -1752,17 +1659,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core 0.6.4", -] - -[[package]] -name = "rand_chacha" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" -dependencies = [ - "ppv-lite86", - "rand_core 0.9.5", + "rand_core", ] [[package]] @@ -1774,15 +1671,6 @@ dependencies = [ "getrandom 0.2.17", ] -[[package]] -name = "rand_core" -version = "0.9.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" -dependencies = [ - "getrandom 0.3.4", -] - [[package]] name = "rayon" version = "1.12.0" @@ -1823,9 +1711,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.12.3" +version = "1.12.4" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +checksum = "f1292b7759ae1cb9ec195452d1390a074f0cd8541ab7a5a8c31cd6db45d4a6ba" dependencies = [ "aho-corasick", "memchr", @@ -1846,9 +1734,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.10" +version = "0.8.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" +checksum = "d6f6ff9a378485b298a5286656da665ba74413d36db0979633275d2e708145d4" [[package]] name = "reqwest" @@ -1864,16 +1752,14 @@ dependencies = [ "http-body", "http-body-util", "hyper", - "hyper-rustls", "hyper-tls", "hyper-util", "js-sys", "log", + "mime_guess", "native-tls", "percent-encoding", "pin-project-lite", - "quinn", - "rustls", "rustls-pki-types", "serde", "serde_json", @@ -1881,7 +1767,6 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-native-tls", - "tokio-rustls", "tokio-util", "tower", "tower-http", @@ -1891,7 +1776,6 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", - "webpki-roots 1.0.7", ] [[package]] @@ -1921,19 +1805,13 @@ dependencies = [ "num-traits", "pkcs1", "pkcs8", - "rand_core 0.6.4", + "rand_core", "signature", "spki", "subtle", "zeroize", ] -[[package]] -name = "rustc-hash" -version = "2.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe" - [[package]] name = "rustix" version = "1.1.4" @@ -1967,7 +1845,6 @@ version = "1.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "30a7197ae7eb376e574fe940d068c30fe0462554a3ddbe4eca7838e049c937a9" dependencies = [ - "web-time", "zeroize", ] @@ -2195,7 +2072,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" dependencies = [ "digest", - "rand_core 
0.6.4", + "rand_core", ] [[package]] @@ -2358,7 +2235,7 @@ dependencies = [ "memchr", "once_cell", "percent-encoding", - "rand 0.8.6", + "rand", "rsa", "serde", "sha1", @@ -2396,7 +2273,7 @@ dependencies = [ "md-5", "memchr", "once_cell", - "rand 0.8.6", + "rand", "serde", "serde_json", "sha2", @@ -2607,16 +2484,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "tokio-rustls" -version = "0.26.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" -dependencies = [ - "rustls", - "tokio", -] - [[package]] name = "tokio-stream" version = "0.1.18" @@ -2761,6 +2628,12 @@ version = "1.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6f5e870be6c3b371b77fe0ee0bafb859fa4964b4404c27de1d380043c4dda20" +[[package]] +name = "unicase" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142" + [[package]] name = "unicode-bidi" version = "0.3.18" @@ -2832,9 +2705,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.23.2" +version = "1.23.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d258b83ceec21034727ecee8c382cfa6c3e133699b0742c64571814fb420c9f7" +checksum = "144d6b123cef80b301b8f72a9e2ca4370ddec21950d0a103dd22c437006d2db7" dependencies = [ "getrandom 0.4.2", "js-sys", @@ -2911,9 +2784,9 @@ checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" [[package]] name = "wasm-bindgen" -version = "0.2.122" +version = "0.2.123" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ed04576f974d2b2fba0f38c51dbc5518011e38c36bf1143164be765528fd409" +checksum = "a254a4b10c19a76f09a27640e7ffbf9bc30bf67e16a3bf28aaefa4920fe81563" dependencies = [ "cfg-if", "once_cell", @@ -2924,9 +2797,9 @@ dependencies = [ 
[[package]] name = "wasm-bindgen-futures" -version = "0.4.72" +version = "0.4.73" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9473dbd2991ae90b6291c3c32c30c6187ac49aa32f9905d1cce280ec1e110b0f" +checksum = "54568702fabf5d4849ce2b90fadfa64168a097eaf4b351ce9df8b687a0086aaf" dependencies = [ "js-sys", "wasm-bindgen", @@ -2934,9 +2807,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.122" +version = "0.2.123" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "916151b09da36bd82f6615cbf3a419e2f0ba23a03c6160e8e92eb6bd4aa1dec6" +checksum = "24a40fc75b0ec6f3746ceb10d36f53a93dcd68a93b11b6445983945d79eba0dc" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2944,9 +2817,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.122" +version = "0.2.123" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "299047362ccbfce148b67ab7e73349f77748e00c8296f9542adfad2ad82c5c5e" +checksum = "908f34bd9b9ce3d4caf07b72dfab63d61504d156856c6bd3cd87fa350cf3985b" dependencies = [ "bumpalo", "proc-macro2", @@ -2957,9 +2830,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.122" +version = "0.2.123" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a929b2c61f11ba3e9bc35b50c1f25cb38e0e892c0c231ae2b8cf78d5dad4437" +checksum = "7acbf7616c27b194bbb550bf77ed0c2c3e5b7fd1260a93082b95fb7f47959b92" dependencies = [ "unicode-ident", ] @@ -3013,19 +2886,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.99" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d621441cfc37b84979402712047321980c178f299193a3589d05b99e8763436" -dependencies = [ - "js-sys", - "wasm-bindgen", -] - -[[package]] -name = "web-time" -version = "1.1.0" +version = "0.3.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +checksum = "6e0871acf327f283dc6da28a1696cdc64fb355ba9f935d052021fa77f35cce69" dependencies = [ "js-sys", "wasm-bindgen", @@ -3145,15 +3008,6 @@ dependencies = [ "windows-targets 0.52.6", ] -[[package]] -name = "windows-sys" -version = "0.60.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" -dependencies = [ - "windows-targets 0.53.5", -] - [[package]] name = "windows-sys" version = "0.61.2" @@ -3187,30 +3041,13 @@ dependencies = [ "windows_aarch64_gnullvm 0.52.6", "windows_aarch64_msvc 0.52.6", "windows_i686_gnu 0.52.6", - "windows_i686_gnullvm 0.52.6", + "windows_i686_gnullvm", "windows_i686_msvc 0.52.6", "windows_x86_64_gnu 0.52.6", "windows_x86_64_gnullvm 0.52.6", "windows_x86_64_msvc 0.52.6", ] -[[package]] -name = "windows-targets" -version = "0.53.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" -dependencies = [ - "windows-link", - "windows_aarch64_gnullvm 0.53.1", - "windows_aarch64_msvc 0.53.1", - "windows_i686_gnu 0.53.1", - "windows_i686_gnullvm 0.53.1", - "windows_i686_msvc 0.53.1", - "windows_x86_64_gnu 0.53.1", - "windows_x86_64_gnullvm 0.53.1", - "windows_x86_64_msvc 0.53.1", -] - [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" @@ -3223,12 +3060,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" - [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -3241,12 +3072,6 @@ version = "0.52.6" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" -[[package]] -name = "windows_aarch64_msvc" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" - [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -3259,24 +3084,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" -[[package]] -name = "windows_i686_gnu" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" - [[package]] name = "windows_i686_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" -[[package]] -name = "windows_i686_gnullvm" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" - [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -3289,12 +3102,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" -[[package]] -name = "windows_i686_msvc" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" - [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -3307,12 +3114,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" -[[package]] -name = "windows_x86_64_gnu" -version = "0.53.1" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" - [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" @@ -3325,12 +3126,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" - [[package]] name = "windows_x86_64_msvc" version = "0.48.5" @@ -3343,12 +3138,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" -[[package]] -name = "windows_x86_64_msvc" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" - [[package]] name = "wit-bindgen" version = "0.51.0" @@ -3451,9 +3240,9 @@ checksum = "1ffae5123b2d3fc086436f8834ae3ab053a283cfac8fe0a0b8eaae044768a4c4" [[package]] name = "yoke" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "abe8c5fda708d9ca3df187cae8bfb9ceda00dd96231bed36e445a1a48e66f9ca" +checksum = "709fe23a0424b6a435d82152b1bd3fdfb0833487d5fa90d05d42762a9891fef5" dependencies = [ "stable_deref_trait", "yoke-derive", @@ -3474,18 +3263,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.50" +version = "0.8.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b065d4f0e55f82fae73202e189638116a87c55ab6b8e6c2721e13dd9d854ad1" +checksum = "ce1022995ff5ff5d841ad7d994facc23098cd40152f2c1d11cd607c6f530653f" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.50" +version = "0.8.52" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b631b19d36a892ab55420c92dbc83ccd79274f25be714855d3074aa71cab639" +checksum = "1ae7f38b72ec2a254e2b87ef277cf2cd4fb97cbebf944faa6f33354da0867930" dependencies = [ "proc-macro2", "quote", diff --git a/Makefile b/Makefile index 9a2f13a..26c0f62 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: help install lint format test build pre-commit clean +.PHONY: help install lint format test build pre-commit clean integration-test help: ## Show this help message @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-15s\033[0m %s\n", $$1, $$2}' @@ -23,3 +23,6 @@ pre-commit: ## Run pre-commit hooks on all files clean: ## Remove Rust build artifacts cargo clean + +integration-test: ## Run integration tests (starts OGx, runs tests, tears down) + ./crates/agentic-server/tests/integration/run.sh diff --git a/crates/agentic-core/Cargo.toml b/crates/agentic-core/Cargo.toml index e8650c2..11863a3 100644 --- a/crates/agentic-core/Cargo.toml +++ b/crates/agentic-core/Cargo.toml @@ -8,11 +8,12 @@ repository.workspace = true [dependencies] async-stream.workspace = true +async-trait = "0.1" bytes.workspace = true either.workspace = true futures.workspace = true http.workspace = true -reqwest = { workspace = true, features = ["default-tls", "stream"] } +reqwest = { workspace = true, features = ["default-tls", "json", "stream"] } serde.workspace = true serde_json.workspace = true thiserror.workspace = true diff --git a/crates/agentic-core/src/error.rs b/crates/agentic-core/src/error.rs index 011c6eb..980ee67 100644 --- a/crates/agentic-core/src/error.rs +++ b/crates/agentic-core/src/error.rs @@ -19,4 +19,19 @@ pub enum Error { #[error("{0}")] Config(String), + + #[error("store request failed")] + Store(#[source] reqwest::Error), + + #[error("store returned {status}: {body}")] + StoreResponse { status: u16, body: String }, + + #[error("vLLM proxy 
request failed")] + Proxy(#[source] reqwest::Error), + + #[error("vLLM returned {status}: {body}")] + ProxyResponse { status: u16, body: String }, + + #[error("agentic loop exceeded {max_iterations} iterations")] + MaxIterations { max_iterations: u32 }, } diff --git a/crates/agentic-core/src/lib.rs b/crates/agentic-core/src/lib.rs index b89988e..9f903fc 100644 --- a/crates/agentic-core/src/lib.rs +++ b/crates/agentic-core/src/lib.rs @@ -7,6 +7,7 @@ pub mod readiness; pub mod storage; pub mod types; pub mod utils; +pub mod vector_search; pub use storage::{ ConversationData, ConversationStore, DbPool, InOutItem, ItemKind, ResponseData, ResponseMetadata, ResponseStore, diff --git a/crates/agentic-core/src/vector_search/mod.rs b/crates/agentic-core/src/vector_search/mod.rs new file mode 100644 index 0000000..267186b --- /dev/null +++ b/crates/agentic-core/src/vector_search/mod.rs @@ -0,0 +1,11 @@ +pub mod ogx; +pub mod types; + +use async_trait::async_trait; + +use types::SearchResult; + +#[async_trait] +pub trait VectorSearch: Send + Sync { + async fn search(&self, store_id: &str, query: &str) -> Result, crate::error::Error>; +} diff --git a/crates/agentic-core/src/vector_search/ogx.rs b/crates/agentic-core/src/vector_search/ogx.rs new file mode 100644 index 0000000..88b99a6 --- /dev/null +++ b/crates/agentic-core/src/vector_search/ogx.rs @@ -0,0 +1,49 @@ +use async_trait::async_trait; +use tracing::debug; + +use super::types::{SearchResponse, SearchResult}; +use crate::error::Error; + +pub struct OgxStore { + base_url: String, + client: reqwest::Client, +} + +impl OgxStore { + #[must_use] + pub fn new(base_url: &str, client: reqwest::Client) -> Self { + let base_url = base_url.trim_end_matches('/').to_owned(); + Self { base_url, client } + } +} + +#[async_trait] +impl super::VectorSearch for OgxStore { + async fn search(&self, store_id: &str, query: &str) -> Result, Error> { + let url = format!("{}/v1/vector_stores/{store_id}/search", self.base_url); + debug!(%url, 
%query, "searching vector store via OGx"); + + let resp = self + .client + .post(&url) + .json(&serde_json::json!({ + "query": query, + "max_num_results": 10 + })) + .send() + .await + .map_err(Error::Store)?; + + let status = resp.status(); + if !status.is_success() { + let body = resp.text().await.unwrap_or_default(); + return Err(Error::StoreResponse { + status: status.as_u16(), + body, + }); + } + + let search_resp: SearchResponse = resp.json().await.map_err(Error::Store)?; + Ok(search_resp.data) + } +} diff --git a/crates/agentic-core/src/vector_search/types.rs b/crates/agentic-core/src/vector_search/types.rs new file mode 100644 index 0000000..c23e97f --- /dev/null +++ b/crates/agentic-core/src/vector_search/types.rs @@ -0,0 +1,79 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ResponseRequest { + pub model: String, + #[serde(default)] + pub input: Vec, + #[serde(default)] + pub stream: bool, + #[serde(default)] + pub tools: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub previous_response_id: Option, + #[serde(flatten)] + pub rest: serde_json::Map, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolConfig { + pub r#type: String, + #[serde(default)] + pub vector_store_ids: Option>, + #[serde(flatten)] + pub rest: serde_json::Map, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ResponseBody { + pub id: String, + #[serde(default)] + pub output: Vec, + #[serde(default)] + pub status: String, + #[serde(flatten)] + pub rest: serde_json::Map, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum VllmOutputItem { + #[serde(rename = "message")] + Message { + #[serde(flatten)] + fields: serde_json::Map, + }, + #[serde(rename = "function_call")] + FunctionCall { + id: String, + call_id: String, + name: String, + arguments: String, + #[serde(flatten)] + rest: serde_json::Map, + }, + #[serde(other)] + Other, +} + 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SearchResponse { + pub data: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SearchResult { + pub file_id: String, + pub filename: String, + pub score: f64, + #[serde(default)] + pub attributes: Option>, + #[serde(default)] + pub content: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ContentChunk { + pub r#type: String, + pub text: String, +} diff --git a/crates/agentic-server/Cargo.toml b/crates/agentic-server/Cargo.toml index cc9bbf0..a522ad8 100644 --- a/crates/agentic-server/Cargo.toml +++ b/crates/agentic-server/Cargo.toml @@ -9,9 +9,10 @@ repository.workspace = true [dependencies] agentic-core.workspace = true axum.workspace = true +bytes.workspace = true clap.workspace = true http.workspace = true -reqwest = { workspace = true, default-features = false, features = ["rustls-tls"] } +reqwest = { workspace = true, features = ["json", "stream"] } serde_json.workspace = true tokio.workspace = true tower-http.workspace = true @@ -19,12 +20,12 @@ tracing.workspace = true tracing-subscriber.workspace = true [dev-dependencies] -bytes.workspace = true criterion.workspace = true futures.workspace = true -reqwest = { workspace = true, features = ["json"] } +reqwest = { workspace = true, features = ["json", "multipart"] } serde_json.workspace = true tokio = { workspace = true, features = ["test-util"] } +tokio-stream = "0.1" [[bench]] name = "proxy_bench" diff --git a/crates/agentic-server/benches/proxy_bench.rs b/crates/agentic-server/benches/proxy_bench.rs index 7d924bc..32c41b2 100644 --- a/crates/agentic-server/benches/proxy_bench.rs +++ b/crates/agentic-server/benches/proxy_bench.rs @@ -1,4 +1,5 @@ use std::convert::Infallible; +use std::sync::Arc; use axum::body::Body; use axum::extract::Request; @@ -14,7 +15,9 @@ use tokio::runtime::Runtime; use agentic_core::config::Config; use agentic_core::proxy::ProxyState; +use 
agentic_core::vector_search::ogx::OgxStore; use agentic_server::app::{ServerConfig, build_router}; +use agentic_server::handler::AppState; fn bench_config(llm_url: &str) -> Config { Config { @@ -71,7 +74,14 @@ async fn spawn_llm() -> String { } async fn spawn_gateway(config: Config) -> String { - let state = ProxyState::new(config).unwrap(); + let proxy = ProxyState::new(config).unwrap(); + let client = reqwest::Client::new(); + let ogx_store = Arc::new(OgxStore::new("http://127.0.0.1:1", client)); + let state = Arc::new(AppState { + proxy, + max_iterations: 10, + vector_search: ogx_store, + }); let server_config = ServerConfig::from_env(); let router = build_router(state, &server_config); diff --git a/crates/agentic-server/src/app.rs b/crates/agentic-server/src/app.rs index 1c22a18..f028186 100644 --- a/crates/agentic-server/src/app.rs +++ b/crates/agentic-server/src/app.rs @@ -1,18 +1,17 @@ -use agentic_core::proxy::ProxyState; +use std::sync::Arc; + use axum::Router; use axum::routing::{get, post}; use http::HeaderValue; use tower_http::cors::{AllowOrigin, Any, CorsLayer}; -use crate::handler::{health, proxy_responses, ready}; +use crate::handler::{AppState, handle_responses, health, ready}; -/// Server-level configuration read from environment variables. pub struct ServerConfig { pub cors_allowed_origins: Vec, } impl ServerConfig { - /// Read `CORS_ALLOWED_ORIGINS` (comma-separated). Unset or empty = permissive. 
#[must_use] pub fn from_env() -> Self { let cors_allowed_origins = std::env::var("CORS_ALLOWED_ORIGINS") @@ -47,11 +46,11 @@ impl ServerConfig { } } -pub fn build_router(state: ProxyState, server_config: &ServerConfig) -> Router { +pub fn build_router(state: Arc, server_config: &ServerConfig) -> Router { Router::new() .route("/health", get(health)) .route("/ready", get(ready)) - .route("/v1/responses", post(proxy_responses)) + .route("/v1/responses", post(handle_responses)) .layer(server_config.cors_layer()) .with_state(state) } diff --git a/crates/agentic-server/src/handler.rs b/crates/agentic-server/src/handler.rs index 154cdd6..8b4c07f 100644 --- a/crates/agentic-server/src/handler.rs +++ b/crates/agentic-server/src/handler.rs @@ -1,22 +1,33 @@ -use agentic_core::proxy::{ProxyBody, ProxyRequest, ProxyResponse, ProxyState, error_response}; +use std::sync::Arc; + use axum::body::Body; use axum::extract::State; +use axum::http::{HeaderMap, StatusCode}; use axum::response::{IntoResponse, Response}; -use http::StatusCode; -use tracing::warn; +use bytes::Bytes; +use tracing::{debug, error, warn}; + +use agentic_core::proxy::ProxyState; +use agentic_core::vector_search::VectorSearch; +use agentic_core::vector_search::types::{ResponseBody, ResponseRequest, SearchResult, VllmOutputItem}; -const MAX_BODY_SIZE: usize = 10 * 1024 * 1024; +pub struct AppState { + pub proxy: ProxyState, + pub max_iterations: u32, + pub vector_search: Arc, +} -pub async fn health() -> impl IntoResponse { +#[allow(clippy::unused_async)] +pub async fn health() -> StatusCode { StatusCode::OK } -pub async fn ready(State(state): State) -> impl IntoResponse { - let base = state.config.llm_api_base.trim_end_matches('/'); +pub async fn ready(State(state): State>) -> impl IntoResponse { + let base = state.proxy.config.llm_api_base.trim_end_matches('/'); let url = format!("{base}/health"); let mut headers = reqwest::header::HeaderMap::new(); - if let Some(key) = state.config.openai_api_key.as_deref() { 
+ if let Some(key) = state.proxy.config.openai_api_key.as_deref() { let trimmed = key.trim(); if !trimmed.is_empty() { if let Ok(v) = reqwest::header::HeaderValue::from_str(&format!("Bearer {trimmed}")) { @@ -47,33 +58,237 @@ pub async fn ready(State(state): State) -> impl IntoResponse { } } -fn convert_response(resp: ProxyResponse) -> Response { - let mut builder = Response::builder().status(resp.status); - for (name, value) in &resp.headers { - builder = builder.header(name, value); +pub async fn handle_responses(State(state): State>, headers: HeaderMap, body: Bytes) -> Response { + let request: ResponseRequest = match serde_json::from_slice(&body) { + Ok(r) => r, + Err(e) => { + return ( + StatusCode::BAD_REQUEST, + serde_json::json!({"error": {"message": format!("invalid request body: {e}")}}).to_string(), + ) + .into_response(); + } + }; + + let has_file_search = request.tools.iter().any(|t| t.r#type == "file_search"); + + if has_file_search { + match agentic_loop(&state, &headers, request).await { + Ok(resp) => resp, + Err(e) => { + error!(error = %e, "agentic loop failed"); + json_error_response(StatusCode::BAD_GATEWAY, &format!("agentic loop error: {e}")) + } + } + } else { + proxy_to_vllm(&state, &headers, &body, request.stream).await + } +} + +async fn agentic_loop( + state: &AppState, + client_headers: &HeaderMap, + mut request: ResponseRequest, +) -> Result { + let vector_store_ids: Vec = request + .tools + .iter() + .filter(|t| t.r#type == "file_search") + .filter_map(|t| t.vector_store_ids.clone()) + .flatten() + .collect(); + + for iteration in 0..state.max_iterations { + debug!(iteration, "agentic loop iteration"); + + let mut loop_request = build_vllm_request(state, client_headers); + + let mut body = serde_json::to_value(&request) + .map_err(|e| agentic_core::error::Error::Config(format!("failed to serialize request: {e}")))?; + if let Some(obj) = body.as_object_mut() { + obj.insert("stream".to_owned(), serde_json::Value::Bool(false)); + if let 
Some(serde_json::Value::Array(tools)) = obj.get_mut("tools") { + let had_file_search = tools + .iter() + .any(|t| t.get("type").and_then(serde_json::Value::as_str) == Some("file_search")); + tools.retain(|t| t.get("type").and_then(serde_json::Value::as_str) != Some("file_search")); + if had_file_search { + tools.push(serde_json::json!({ + "type": "function", + "name": "file_search", + "description": "Search uploaded files for relevant content. Use this when the user asks about documents or needs information from files.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query to find relevant content in files" + } + }, + "required": ["query"] + } + })); + } + } + } + + loop_request = loop_request.json(&body); + + let resp = loop_request.send().await.map_err(agentic_core::error::Error::Proxy)?; + + let status = resp.status(); + if !status.is_success() { + let resp_body = resp.text().await.unwrap_or_default(); + return Ok(( + StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY), + resp_body, + ) + .into_response()); + } + + let response_body: ResponseBody = resp.json().await.map_err(agentic_core::error::Error::Proxy)?; + + let tool_calls: Vec<_> = response_body + .output + .iter() + .filter_map(|item| match item { + VllmOutputItem::FunctionCall { + call_id, + name, + arguments, + .. 
+ } if name == "file_search" => Some((call_id.clone(), arguments.clone())), + _ => None, + }) + .collect(); + + if tool_calls.is_empty() { + debug!(iteration, "no tool calls, returning final response"); + let final_json = serde_json::to_string(&response_body) + .unwrap_or_else(|_| r#"{"error":"serialization failed"}"#.to_owned()); + return Ok((StatusCode::OK, [("content-type", "application/json")], final_json).into_response()); + } + + for output_item in &response_body.output { + request + .input + .push(serde_json::to_value(output_item).unwrap_or_default()); + } + + for (call_id, arguments) in &tool_calls { + let query = extract_query(arguments); + debug!(%call_id, %query, "executing file_search tool call"); + + let results = execute_file_search(state, &vector_store_ids, &query).await; + + let tool_output = serde_json::json!({ + "type": "function_call_output", + "call_id": call_id, + "output": serde_json::to_string(&results).unwrap_or_default() + }); + request.input.push(tool_output); + } + + debug!( + iteration, + tool_calls = tool_calls.len(), + "fed tool results back, continuing loop" + ); } - match resp.body { - ProxyBody::Full(bytes) => builder.body(Body::from(bytes)).expect("valid response"), - ProxyBody::Stream(stream) => builder.body(Body::from_stream(stream)).expect("valid response"), + + warn!( + max_iterations = state.max_iterations, + "agentic loop reached max iterations" + ); + Err(agentic_core::error::Error::MaxIterations { + max_iterations: state.max_iterations, + }) +} + +async fn execute_file_search(state: &AppState, vector_store_ids: &[String], query: &str) -> Vec { + let mut all_results = Vec::new(); + for store_id in vector_store_ids { + match state.vector_search.search(store_id, query).await { + Ok(results) => all_results.extend(results), + Err(e) => { + warn!(%store_id, error = %e, "file_search failed for vector store"); + } + } } + all_results } -pub async fn proxy_responses(State(state): State, req: axum::extract::Request) -> Response { - 
let (parts, body) = req.into_parts(); +fn extract_query(arguments: &str) -> String { + serde_json::from_str::(arguments) + .ok() + .and_then(|v| v.get("query").and_then(serde_json::Value::as_str).map(String::from)) + .unwrap_or_default() +} - let Ok(body_bytes) = axum::body::to_bytes(body, MAX_BODY_SIZE).await else { - return convert_response(error_response( - StatusCode::PAYLOAD_TOO_LARGE, - "body_too_large", - "Request body too large", - )); - }; +async fn proxy_to_vllm(state: &AppState, client_headers: &HeaderMap, body: &Bytes, stream: bool) -> Response { + let req = build_vllm_request(state, client_headers).body(body.clone()); - let proxy_req = ProxyRequest { - headers: parts.headers, - body: body_bytes, - query: parts.uri.query().map(String::from), + let resp = match req.send().await { + Ok(r) => r, + Err(e) => { + error!(error = %e, "failed to connect to vLLM"); + return json_error_response(StatusCode::BAD_GATEWAY, &format!("vLLM connection failed: {e}")); + } }; - convert_response(agentic_core::proxy::proxy_request(proxy_req, &state).await) + let status = StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY); + + if stream { + let content_type = resp + .headers() + .get("content-type") + .and_then(|v| v.to_str().ok()) + .unwrap_or("text/event-stream") + .to_owned(); + + let stream = resp.bytes_stream(); + let body = Body::from_stream(stream); + + (status, [("content-type", content_type)], body).into_response() + } else { + let content_type = resp + .headers() + .get("content-type") + .and_then(|v| v.to_str().ok()) + .unwrap_or("application/json") + .to_owned(); + + match resp.bytes().await { + Ok(bytes) => (status, [("content-type", content_type)], bytes).into_response(), + Err(e) => json_error_response(StatusCode::BAD_GATEWAY, &format!("failed to read vLLM response: {e}")), + } + } +} + +fn build_vllm_request(state: &AppState, client_headers: &HeaderMap) -> reqwest::RequestBuilder { + let url = format!("{}/v1/responses", 
state.proxy.config.llm_api_base); + let mut req = state + .proxy + .non_stream_client + .post(&url) + .header("content-type", "application/json"); + + if let Some(auth) = client_headers.get(http::header::AUTHORIZATION) { + if let Ok(v) = auth.to_str() { + req = req.header("authorization", v); + } + } else if let Some(key) = &state.proxy.config.openai_api_key { + req = req.header("authorization", format!("Bearer {key}")); + } + + req +} + +fn json_error_response(status: StatusCode, message: &str) -> Response { + ( + status, + [("content-type", "application/json")], + serde_json::json!({"error": {"message": message}}).to_string(), + ) + .into_response() } diff --git a/crates/agentic-server/src/main.rs b/crates/agentic-server/src/main.rs index 216b510..f0e8a1a 100644 --- a/crates/agentic-server/src/main.rs +++ b/crates/agentic-server/src/main.rs @@ -21,6 +21,12 @@ struct CommonArgs { #[arg(long, default_value_t = 2.0, global = true)] llm_ready_interval_s: f64, + + #[arg(long, default_value = "http://localhost:8080", global = true)] + ogx_base_url: String, + + #[arg(long, default_value_t = 10, global = true)] + max_iterations: u32, } #[derive(Parser)] @@ -86,7 +92,14 @@ async fn main() -> Result<(), Error> { ) })?; let config = build_config(normalize_base_url(&base), &common); - server::run(config, &common.gateway_host, common.gateway_port).await + server::run( + config, + &common.gateway_host, + common.gateway_port, + &common.ogx_base_url, + common.max_iterations, + ) + .await } Some(Commands::Serve { model, port, llm_args }) => { if llm_api_base.is_some() { @@ -102,7 +115,15 @@ async fn main() -> Result<(), Error> { args.push(port.to_string()); args.extend(llm_args); - server::run_with_llm(config, &common.gateway_host, common.gateway_port, args).await + server::run_with_llm( + config, + &common.gateway_host, + common.gateway_port, + args, + &common.ogx_base_url, + common.max_iterations, + ) + .await } } } diff --git a/crates/agentic-server/src/server.rs 
b/crates/agentic-server/src/server.rs index e9ce317..b5bfcbd 100644 --- a/crates/agentic-server/src/server.rs +++ b/crates/agentic-server/src/server.rs @@ -1,14 +1,36 @@ +use std::sync::Arc; + use agentic_core::config::Config; use agentic_core::error::Error; use agentic_core::proxy::ProxyState; use agentic_core::readiness::wait_llm_ready; +use agentic_core::vector_search::ogx::OgxStore; use agentic_server::app::{ServerConfig, build_router}; +use agentic_server::handler::AppState; use tokio::net::TcpListener; use tracing::info; -async fn serve_gateway(config: Config, host: &str, port: u16) -> Result<(), Error> { +fn build_app_state(config: Config, ogx_base_url: &str, max_iterations: u32) -> Result, Error> { + let proxy = ProxyState::new(config)?; + let client = reqwest::Client::new(); + let ogx_store = Arc::new(OgxStore::new(ogx_base_url, client)); + + Ok(Arc::new(AppState { + proxy, + max_iterations, + vector_search: ogx_store, + })) +} + +async fn serve_gateway( + config: Config, + host: &str, + port: u16, + ogx_base_url: &str, + max_iterations: u32, +) -> Result<(), Error> { let addr = format!("{host}:{port}"); - let state = ProxyState::new(config)?; + let state = build_app_state(config, ogx_base_url, max_iterations)?; let server_config = ServerConfig::from_env(); let router = build_router(state, &server_config); let listener = TcpListener::bind(&addr).await?; @@ -22,10 +44,10 @@ async fn serve_gateway(config: Config, host: &str, port: u16) -> Result<(), Erro /// # Errors /// /// Returns an error if LLM readiness polling fails or the server cannot bind. 
-pub async fn run(config: Config, host: &str, port: u16) -> Result<(), Error> { +pub async fn run(config: Config, host: &str, port: u16, ogx_base_url: &str, max_iterations: u32) -> Result<(), Error> { wait_llm_ready(&config).await?; info!("LLM ready: {}", config.llm_api_base); - serve_gateway(config, host, port).await + serve_gateway(config, host, port, ogx_base_url, max_iterations).await } /// Spawn vLLM as a subprocess and run the gateway in the foreground. @@ -33,7 +55,14 @@ pub async fn run(config: Config, host: &str, port: u16) -> Result<(), Error> { /// # Errors /// /// Returns an error if vLLM fails to start or the gateway errors. -pub async fn run_with_llm(config: Config, host: &str, port: u16, llm_args: Vec) -> Result<(), Error> { +pub async fn run_with_llm( + config: Config, + host: &str, + port: u16, + llm_args: Vec, + ogx_base_url: &str, + max_iterations: u32, +) -> Result<(), Error> { let mut cmd = tokio::process::Command::new("python"); cmd.arg("-m").arg("vllm.entrypoints.openai.api_server"); cmd.args(&llm_args); @@ -61,7 +90,7 @@ pub async fn run_with_llm(config: Config, host: &str, port: u16, llm_args: Vec gateway, + gateway = serve_gateway(config, host, port, ogx_base_url, max_iterations) => gateway, status = child.wait() => { let status = status?; Err(Error::LlmProcessExited { diff --git a/crates/agentic-server/tests/agentic_loop_test.rs b/crates/agentic-server/tests/agentic_loop_test.rs new file mode 100644 index 0000000..fbbb5be --- /dev/null +++ b/crates/agentic-server/tests/agentic_loop_test.rs @@ -0,0 +1,175 @@ +#[allow(dead_code)] +mod common; + +use common::{spawn_ogx, spawn_vllm, spawn_vllm_with_tool_calls, start_gateway}; + +#[tokio::test] +async fn test_passthrough_no_tools() { + let (vllm_port, _h) = spawn_vllm().await; + let (ogx_port, _h2) = spawn_ogx().await; + let (gw_addr, _) = start_gateway(vllm_port, Some(ogx_port), None).await; + + let client = reqwest::Client::new(); + let resp = client + 
.post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": [{"role": "user", "content": "hello"}] + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + let body: serde_json::Value = resp.json().await.unwrap(); + assert_eq!(body["id"], "resp_test"); +} + +#[tokio::test] +async fn test_single_file_search() { + let tool_call_response = serde_json::json!({ + "id": "resp_1", + "object": "response", + "status": "completed", + "output": [{ + "type": "function_call", + "id": "fc_1", + "call_id": "call_1", + "name": "file_search", + "arguments": "{\"query\": \"test query\"}", + "status": "completed" + }] + }); + + let final_response = serde_json::json!({ + "id": "resp_2", + "object": "response", + "status": "completed", + "output": [{ + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "Based on the search results..."}] + }] + }); + + let (vllm_port, _h) = spawn_vllm_with_tool_calls(vec![tool_call_response, final_response]).await; + let (ogx_port, _h2) = spawn_ogx().await; + let (gw_addr, _) = start_gateway(vllm_port, Some(ogx_port), None).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": [{"role": "user", "content": "search for something"}], + "tools": [{"type": "file_search", "vector_store_ids": ["vs_123"]}] + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + let body: serde_json::Value = resp.json().await.unwrap(); + assert_eq!(body["id"], "resp_2"); + assert_eq!(body["output"][0]["type"], "message"); +} + +#[tokio::test] +async fn test_multi_turn_tool_calls() { + let turn1 = serde_json::json!({ + "id": "resp_1", + "object": "response", + "status": "completed", + "output": [{ + "type": "function_call", + "id": "fc_1", + "call_id": "call_1", + "name": "file_search", + "arguments": "{\"query\": \"first query\"}", 
+ "status": "completed" + }] + }); + + let turn2 = serde_json::json!({ + "id": "resp_2", + "object": "response", + "status": "completed", + "output": [{ + "type": "function_call", + "id": "fc_2", + "call_id": "call_2", + "name": "file_search", + "arguments": "{\"query\": \"second query\"}", + "status": "completed" + }] + }); + + let final_resp = serde_json::json!({ + "id": "resp_3", + "object": "response", + "status": "completed", + "output": [{ + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "final answer"}] + }] + }); + + let (vllm_port, _h) = spawn_vllm_with_tool_calls(vec![turn1, turn2, final_resp]).await; + let (ogx_port, _h2) = spawn_ogx().await; + let (gw_addr, _) = start_gateway(vllm_port, Some(ogx_port), None).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": [{"role": "user", "content": "multi-turn search"}], + "tools": [{"type": "file_search", "vector_store_ids": ["vs_123"]}] + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + let body: serde_json::Value = resp.json().await.unwrap(); + assert_eq!(body["id"], "resp_3"); +} + +#[tokio::test] +async fn test_max_iterations_reached() { + let tool_call = serde_json::json!({ + "id": "resp_loop", + "object": "response", + "status": "completed", + "output": [{ + "type": "function_call", + "id": "fc_loop", + "call_id": "call_loop", + "name": "file_search", + "arguments": "{\"query\": \"infinite loop\"}", + "status": "completed" + }] + }); + + let (vllm_port, _h) = spawn_vllm_with_tool_calls(vec![tool_call]).await; + let (ogx_port, _h2) = spawn_ogx().await; + let (gw_addr, _) = start_gateway(vllm_port, Some(ogx_port), None).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": [{"role": "user", 
"content": "search forever"}], + "tools": [{"type": "file_search", "vector_store_ids": ["vs_123"]}] + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 502); + let body: serde_json::Value = resp.json().await.unwrap(); + let msg = body["error"]["message"].as_str().unwrap_or(""); + assert!(msg.contains("exceeded"), "expected max iterations error, got: {msg}"); +} diff --git a/crates/agentic-server/tests/common/mod.rs b/crates/agentic-server/tests/common/mod.rs new file mode 100644 index 0000000..2cf0d75 --- /dev/null +++ b/crates/agentic-server/tests/common/mod.rs @@ -0,0 +1,231 @@ +use std::convert::Infallible; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use axum::Router; +use axum::body::Body; +use axum::extract::Request; +use axum::response::{IntoResponse, Response}; +use axum::routing::{get, post}; +use bytes::Bytes; +use futures::stream; +use http::StatusCode; +use tokio::net::TcpListener; + +use agentic_core::config::Config; +use agentic_core::proxy::ProxyState; +use agentic_core::vector_search::ogx::OgxStore; +use agentic_server::handler::AppState; + +fn test_config(llm_port: u16, api_key: Option<&str>) -> Config { + Config { + llm_api_base: format!("http://127.0.0.1:{llm_port}"), + openai_api_key: api_key.map(String::from), + llm_ready_timeout_s: 5.0, + llm_ready_interval_s: 0.1, + } +} + +pub async fn start_gateway(vllm_port: u16, ogx_port: Option, api_key: Option<&str>) -> (String, u16) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let addr = format!("127.0.0.1:{port}"); + + let ogx_base = match ogx_port { + Some(p) => format!("http://127.0.0.1:{p}"), + None => "http://127.0.0.1:1".to_owned(), + }; + + let config = test_config(vllm_port, api_key); + let proxy = ProxyState::new(config).unwrap(); + let client = reqwest::Client::new(); + let ogx_store = Arc::new(OgxStore::new(&ogx_base, client)); + + let state = Arc::new(AppState { + proxy, 
+ max_iterations: 10, + vector_search: ogx_store, + }); + + let server_config = agentic_server::app::ServerConfig::from_env(); + let router = agentic_server::app::build_router(state, &server_config); + + tokio::spawn(async move { + axum::serve(listener, router).await.unwrap(); + }); + + (addr, port) +} + +async fn health_handler() -> impl IntoResponse { + StatusCode::OK +} + +async fn responses_handler(req: Request) -> Response { + let headers = req.headers().clone(); + let body_bytes = axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024) + .await + .unwrap_or_default(); + + let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap_or_default(); + + if body + .get("echo_auth") + .and_then(serde_json::Value::as_bool) + .unwrap_or(false) + { + let auth = headers.get("authorization").and_then(|v| v.to_str().ok()).unwrap_or(""); + let resp_body = serde_json::json!({"authorization": auth}); + return ( + StatusCode::OK, + [("content-type", "application/json")], + serde_json::to_string(&resp_body).unwrap(), + ) + .into_response(); + } + + if body.get("force_error").and_then(serde_json::Value::as_u64) == Some(429) { + return ( + StatusCode::TOO_MANY_REQUESTS, + [("content-type", "application/json")], + r#"{"error":{"message":"rate limited","code":"rate_limit"}}"#, + ) + .into_response(); + } + + if body.get("stream").and_then(serde_json::Value::as_bool).unwrap_or(false) { + let chunks: Vec> = vec![ + Ok(Bytes::from( + "data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello\"}\n\n", + )), + Ok(Bytes::from("data: [DONE]\n\n")), + ]; + let body = Body::from_stream(stream::iter(chunks)); + return ( + StatusCode::OK, + [("content-type", "text/event-stream; charset=utf-8")], + body, + ) + .into_response(); + } + + let out = r#"{"id":"resp_test","object":"response","status":"completed","output":[]}"#; + (StatusCode::OK, [("content-type", "application/json")], out).into_response() +} + +pub async fn spawn_vllm() -> (u16, tokio::task::JoinHandle<()>) { + 
let app = Router::new() + .route("/health", get(health_handler)) + .route("/v1/responses", post(responses_handler)); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let handle = tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + (port, handle) +} + +pub async fn spawn_mid_stream_failure_vllm() -> (u16, tokio::task::JoinHandle<()>) { + async fn handler(_req: Request) -> Response { + let (tx, rx) = tokio::sync::mpsc::channel::>(2); + tokio::spawn(async move { + let _ = tx + .send(Ok(Bytes::from( + "data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello\"}\n\n", + ))) + .await; + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + drop(tx); + }); + let stream = tokio_stream::wrappers::ReceiverStream::new(rx); + let body = Body::from_stream(stream); + ( + StatusCode::OK, + [("content-type", "text/event-stream; charset=utf-8")], + body, + ) + .into_response() + } + + let app = Router::new() + .route("/health", get(health_handler)) + .route("/v1/responses", post(handler)); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let handle = tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + (port, handle) +} + +pub async fn spawn_vllm_with_tool_calls(responses: Vec) -> (u16, tokio::task::JoinHandle<()>) { + let responses = Arc::new(responses); + let counter = Arc::new(AtomicUsize::new(0)); + + let app = Router::new().route("/health", get(health_handler)).route( + "/v1/responses", + post({ + let responses = Arc::clone(&responses); + let counter = Arc::clone(&counter); + move |_req: Request| { + let responses = Arc::clone(&responses); + let counter = Arc::clone(&counter); + async move { + let idx = counter.fetch_add(1, Ordering::SeqCst); + let resp = responses.get(idx).unwrap_or(responses.last().unwrap()); + ( + StatusCode::OK, + [("content-type", 
"application/json")], + serde_json::to_string(resp).unwrap(), + ) + .into_response() + } + } + }), + ); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let handle = tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + (port, handle) +} + +pub async fn spawn_ogx() -> (u16, tokio::task::JoinHandle<()>) { + async fn search_handler(_req: Request) -> Response { + let body = serde_json::json!({ + "object": "vector_store.search_results.page", + "search_query": ["test query"], + "data": [{ + "file_id": "file_abc", + "filename": "doc.txt", + "score": 0.95, + "attributes": {}, + "content": [{"type": "text", "text": "relevant content from doc"}] + }], + "has_more": false + }); + ( + StatusCode::OK, + [("content-type", "application/json")], + serde_json::to_string(&body).unwrap(), + ) + .into_response() + } + + let app = Router::new().route("/v1/vector_stores/{store_id}/search", post(search_handler)); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let handle = tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + (port, handle) +} diff --git a/crates/agentic-server/tests/cors_test.rs b/crates/agentic-server/tests/cors_test.rs index e10f05e..eb41fed 100644 --- a/crates/agentic-server/tests/cors_test.rs +++ b/crates/agentic-server/tests/cors_test.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use axum::Router; use axum::response::IntoResponse; use axum::routing::get; @@ -6,6 +8,8 @@ use tokio::net::TcpListener; use agentic_core::config::Config; use agentic_core::proxy::ProxyState; +use agentic_core::vector_search::ogx::OgxStore; +use agentic_server::handler::AppState; fn test_config(llm_url: &str) -> Config { Config { @@ -27,7 +31,14 @@ async fn spawn_mock_llm() -> (String, tokio::task::JoinHandle<()>) { } async fn spawn_gateway(config: Config) -> (String, 
tokio::task::JoinHandle<()>) { - let state = ProxyState::new(config).unwrap(); + let proxy = ProxyState::new(config).unwrap(); + let client = reqwest::Client::new(); + let ogx_store = Arc::new(OgxStore::new("http://127.0.0.1:1", client)); + let state = Arc::new(AppState { + proxy, + max_iterations: 10, + vector_search: ogx_store, + }); let server_config = agentic_server::app::ServerConfig::from_env(); let router = agentic_server::app::build_router(state, &server_config); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); diff --git a/crates/agentic-server/tests/health_test.rs b/crates/agentic-server/tests/health_test.rs index 180d65f..f0c92dd 100644 --- a/crates/agentic-server/tests/health_test.rs +++ b/crates/agentic-server/tests/health_test.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use axum::Router; use axum::response::IntoResponse; use axum::routing::get; @@ -6,6 +8,8 @@ use tokio::net::TcpListener; use agentic_core::config::Config; use agentic_core::proxy::ProxyState; +use agentic_core::vector_search::ogx::OgxStore; +use agentic_server::handler::AppState; fn test_config(llm_url: &str) -> Config { Config { @@ -34,7 +38,14 @@ async fn spawn_mock_llm() -> (String, tokio::task::JoinHandle<()>) { } async fn spawn_gateway(config: Config) -> (String, tokio::task::JoinHandle<()>) { - let state = ProxyState::new(config).unwrap(); + let proxy = ProxyState::new(config).unwrap(); + let client = reqwest::Client::new(); + let ogx_store = Arc::new(OgxStore::new("http://127.0.0.1:1", client)); + let state = Arc::new(AppState { + proxy, + max_iterations: 10, + vector_search: ogx_store, + }); let server_config = agentic_server::app::ServerConfig::from_env(); let router = agentic_server::app::build_router(state, &server_config); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); @@ -51,20 +62,18 @@ async fn test_health_returns_200() { let config = test_config(&llm_url); let (gw_url, _h2) = spawn_gateway(config).await; - let resp = 
reqwest::get(format!("{gw_url}/health")).await.unwrap(); + let client = reqwest::Client::new(); + let resp = client.get(format!("{gw_url}/health")).send().await.unwrap(); assert_eq!(resp.status(), 200); } #[tokio::test] async fn test_health_returns_200_even_when_llm_down() { - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let dead_addr = listener.local_addr().unwrap(); - drop(listener); - - let config = test_config_no_key(&format!("http://{dead_addr}")); - let (gw_url, _h2) = spawn_gateway(config).await; + let config = test_config_no_key("http://127.0.0.1:1"); + let (gw_url, _h) = spawn_gateway(config).await; - let resp = reqwest::get(format!("{gw_url}/health")).await.unwrap(); + let client = reqwest::Client::new(); + let resp = client.get(format!("{gw_url}/health")).send().await.unwrap(); assert_eq!(resp.status(), 200); } @@ -74,19 +83,17 @@ async fn test_ready_returns_200_when_llm_healthy() { let config = test_config(&llm_url); let (gw_url, _h2) = spawn_gateway(config).await; - let resp = reqwest::get(format!("{gw_url}/ready")).await.unwrap(); + let client = reqwest::Client::new(); + let resp = client.get(format!("{gw_url}/ready")).send().await.unwrap(); assert_eq!(resp.status(), 200); } #[tokio::test] async fn test_ready_returns_503_when_llm_unreachable() { - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let dead_addr = listener.local_addr().unwrap(); - drop(listener); - - let config = test_config_no_key(&format!("http://{dead_addr}")); - let (gw_url, _h2) = spawn_gateway(config).await; + let config = test_config_no_key("http://127.0.0.1:1"); + let (gw_url, _h) = spawn_gateway(config).await; - let resp = reqwest::get(format!("{gw_url}/ready")).await.unwrap(); + let client = reqwest::Client::new(); + let resp = client.get(format!("{gw_url}/ready")).send().await.unwrap(); assert_eq!(resp.status(), 503); } diff --git a/crates/agentic-server/tests/integration/ogx-config.yaml 
b/crates/agentic-server/tests/integration/ogx-config.yaml new file mode 100644 index 0000000..22cabf4 --- /dev/null +++ b/crates/agentic-server/tests/integration/ogx-config.yaml @@ -0,0 +1,55 @@ +version: 2 +distro_name: agentic-api-test + +apis: + - inference + - files + - vector_io + - file_processors + +providers: + inference: + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers + config: {} + + files: + - provider_id: localfs + provider_type: inline::localfs + config: + storage_dir: /tmp/ogx-test/files + metadata_store: + table_name: files_metadata + backend: sql_default + + vector_io: + - provider_id: faiss + provider_type: inline::faiss + config: + persistence: + namespace: vector_io::faiss + backend: kv_default + + file_processors: + - provider_id: auto + provider_type: inline::auto + config: {} + +storage: + backends: + kv_default: + type: kv_sqlite + db_path: /tmp/ogx-test/kvstore.db + sql_default: + type: sql_sqlite + db_path: /tmp/ogx-test/sql_store.db + stores: + metadata: + namespace: registry + backend: kv_default + inference: + table_name: inference_store + backend: sql_default + vector_stores: + table_name: vector_store_metadata + backend: sql_default diff --git a/crates/agentic-server/tests/integration/run.sh b/crates/agentic-server/tests/integration/run.sh new file mode 100755 index 0000000..b588f73 --- /dev/null +++ b/crates/agentic-server/tests/integration/run.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash +set -euo pipefail + +OGX_PORT="${OGX_PORT:-8321}" +OGX_PID="" + +cleanup() { + if [ -n "$OGX_PID" ] && kill -0 "$OGX_PID" 2>/dev/null; then + echo "Stopping OGx (pid $OGX_PID)..." + kill "$OGX_PID" 2>/dev/null || true + wait "$OGX_PID" 2>/dev/null || true + fi +} +trap cleanup EXIT + +OGX_CMD="${OGX_CMD:-ogx}" +OGX_CONFIG="$(cd "$(dirname "$0")" && pwd)/ogx-config.yaml" + +echo "Starting OGx on port $OGX_PORT..." 
+HF_HUB_TRUST_REMOTE_CODE=1 $OGX_CMD run "$OGX_CONFIG" --port "$OGX_PORT" > /tmp/ogx-server.log 2>&1 & +OGX_PID=$! + +echo "Waiting for OGx to be ready..." +for i in $(seq 1 60); do + if curl -sf "http://localhost:$OGX_PORT/v1/health" > /dev/null 2>&1; then + echo "OGx is ready." + break + fi + if ! kill -0 "$OGX_PID" 2>/dev/null; then + echo "OGx process exited unexpectedly. Logs:" + cat /tmp/ogx-server.log + exit 1 + fi + sleep 1 +done + +if ! curl -sf "http://localhost:$OGX_PORT/v1/health" > /dev/null 2>&1; then + echo "OGx failed to start within 60s. Logs:" + cat /tmp/ogx-server.log + exit 1 +fi + +echo "Running integration tests..." +OGX_BASE_URL="http://localhost:$OGX_PORT" cargo test -p agentic-server --test integration_test -- --nocapture + +echo "Integration tests passed." diff --git a/crates/agentic-server/tests/integration_test.rs b/crates/agentic-server/tests/integration_test.rs new file mode 100644 index 0000000..6e09893 --- /dev/null +++ b/crates/agentic-server/tests/integration_test.rs @@ -0,0 +1,126 @@ +fn ogx_base_url() -> Option { + std::env::var("OGX_BASE_URL").ok() +} + +async fn find_embedding_model(client: &reqwest::Client, ogx_url: &str) -> (String, u64) { + let models_resp = client.get(format!("{ogx_url}/v1/models")).send().await.unwrap(); + let models: serde_json::Value = models_resp.json().await.unwrap(); + let embedding_model = models["data"] + .as_array() + .and_then(|arr| { + arr.iter() + .find(|m| m["custom_metadata"]["model_type"].as_str() == Some("embedding")) + }) + .expect("OGx should have at least one embedding model") + .clone(); + let model_id = embedding_model["id"].as_str().unwrap().to_owned(); + let dim = embedding_model["custom_metadata"]["embedding_dimension"] + .as_u64() + .unwrap(); + (model_id, dim) +} + +async fn create_vector_store(client: &reqwest::Client, ogx_url: &str, model_id: &str, dim: u64) -> String { + let vs_resp = client + .post(format!("{ogx_url}/v1/vector_stores")) + .json(&serde_json::json!({ + "name": 
"integration-test-docs", + "metadata": { "embedding_model": model_id, "embedding_dimension": dim } + })) + .send() + .await + .unwrap(); + assert!(vs_resp.status().is_success(), "Failed to create vector store"); + let vs: serde_json::Value = vs_resp.json().await.unwrap(); + vs["id"].as_str().unwrap().to_owned() +} + +async fn upload_and_attach(client: &reqwest::Client, ogx_url: &str, vs_id: &str) { + let file_content = "Rust enforces memory safety without a garbage collector through its ownership system with borrowing and lifetimes. The borrow checker ensures references do not outlive the data they point to."; + + let form = reqwest::multipart::Form::new().text("purpose", "assistants").part( + "file", + reqwest::multipart::Part::text(file_content.to_owned()) + .file_name("rust-memory-safety.txt") + .mime_str("text/plain") + .unwrap(), + ); + + let file_resp = client + .post(format!("{ogx_url}/v1/files")) + .multipart(form) + .send() + .await + .unwrap(); + assert!(file_resp.status().is_success(), "Failed to upload file"); + + let file: serde_json::Value = file_resp.json().await.unwrap(); + let file_id = file["id"].as_str().unwrap(); + eprintln!("Uploaded file: {file_id}"); + + let attach_resp = client + .post(format!("{ogx_url}/v1/vector_stores/{vs_id}/files")) + .json(&serde_json::json!({"file_id": file_id})) + .send() + .await + .unwrap(); + assert!(attach_resp.status().is_success(), "Failed to attach file"); + + let attach: serde_json::Value = attach_resp.json().await.unwrap(); + let status = attach["status"].as_str().unwrap_or("unknown"); + assert_eq!( + status, + "completed", + "File attachment failed: {}", + attach + .get("last_error") + .map_or("none".to_owned(), std::string::ToString::to_string) + ); +} + +#[tokio::test] +async fn test_vector_search_with_ogx() { + let Some(ogx_url) = ogx_base_url() else { + eprintln!("Skipping: OGX_BASE_URL not set"); + return; + }; + + let client = reqwest::Client::new(); + + let (model_id, dim) = 
find_embedding_model(&client, &ogx_url).await; + eprintln!("Using embedding model: {model_id} (dim={dim})"); + + let vs_id = create_vector_store(&client, &ogx_url, &model_id, dim).await; + eprintln!("Created vector store: {vs_id}"); + + upload_and_attach(&client, &ogx_url, &vs_id).await; + + let search_resp = client + .post(format!("{ogx_url}/v1/vector_stores/{vs_id}/search")) + .json(&serde_json::json!({ + "query": "memory safety ownership", + "max_num_results": 2 + })) + .send() + .await + .unwrap(); + assert!(search_resp.status().is_success(), "Search failed"); + + let results: serde_json::Value = search_resp.json().await.unwrap(); + let data = results["data"].as_array().expect("search should return data array"); + assert!(!data.is_empty(), "search should return at least one result"); + + let top_result = &data[0]; + let score = top_result["score"].as_f64().unwrap_or(0.0); + assert!(score > 0.0, "top result should have a positive score"); + + let content = top_result["content"] + .as_array() + .and_then(|c| c.first()) + .and_then(|c| c["text"].as_str()) + .unwrap_or(""); + assert!(!content.is_empty(), "top result should have content text"); + + eprintln!("Search returned {} results, top score: {score:.3}", data.len()); + eprintln!("Top result: {content}"); +} diff --git a/crates/agentic-server/tests/proxy_test.rs b/crates/agentic-server/tests/proxy_test.rs new file mode 100644 index 0000000..4d06b40 --- /dev/null +++ b/crates/agentic-server/tests/proxy_test.rs @@ -0,0 +1,142 @@ +#[allow(dead_code)] +mod common; + +use common::{spawn_mid_stream_failure_vllm, spawn_vllm, start_gateway}; + +#[tokio::test] +async fn test_non_stream_passthrough() { + let (vllm_port, _h) = spawn_vllm().await; + let (gw_addr, _) = start_gateway(vllm_port, None, Some("env-vllm-key")).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": [{"role": "user", 
"content": "hello"}] + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + let body: serde_json::Value = resp.json().await.unwrap(); + assert_eq!(body["id"], "resp_test"); +} + +#[tokio::test] +async fn test_stream_passthrough() { + let (vllm_port, _h) = spawn_vllm().await; + let (gw_addr, _) = start_gateway(vllm_port, None, Some("env-vllm-key")).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": [{"role": "user", "content": "hello"}], + "stream": true + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + + let text = resp.text().await.unwrap(); + assert!(text.contains("data: [DONE]")); + assert!(text.contains("response.output_text.delta")); +} + +#[tokio::test] +async fn test_auth_injection() { + let (vllm_port, _h) = spawn_vllm().await; + let (gw_addr, _) = start_gateway(vllm_port, None, Some("env-vllm-key")).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({"model": "model-a", "input": [], "echo_auth": true})) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + let body: serde_json::Value = resp.json().await.unwrap(); + assert_eq!(body["authorization"], "Bearer env-vllm-key"); +} + +#[tokio::test] +async fn test_client_auth_precedence() { + let (vllm_port, _h) = spawn_vllm().await; + let (gw_addr, _) = start_gateway(vllm_port, None, Some("env-vllm-key")).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({"model": "model-a", "input": [], "echo_auth": true})) + .header("authorization", "Bearer client-token") + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + let body: serde_json::Value = resp.json().await.unwrap(); + assert_eq!(body["authorization"], "Bearer 
client-token"); +} + +#[tokio::test] +async fn test_vllm_http_error_passthrough() { + let (vllm_port, _h) = spawn_vllm().await; + let (gw_addr, _) = start_gateway(vllm_port, None, Some("env-vllm-key")).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({"model": "model-a", "input": [], "force_error": 429})) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 429); + let body: serde_json::Value = resp.json().await.unwrap(); + assert_eq!(body["error"]["message"], "rate limited"); + assert_eq!(body["error"]["code"], "rate_limit"); +} + +#[tokio::test] +async fn test_mid_stream_failure_closes_cleanly() { + let (vllm_port, _h) = spawn_mid_stream_failure_vllm().await; + let (gw_addr, _) = start_gateway(vllm_port, None, Some("env-vllm-key")).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": [], + "stream": true + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + let text = resp.text().await.unwrap_or_default(); + assert!(text.contains("response.output_text.delta")); +} + +#[tokio::test] +async fn test_connect_error_maps_to_502() { + let (gw_addr, _) = start_gateway(1, None, None).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({"model": "model-a", "input": []})) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 502); +} diff --git a/docs/architecture/index.md b/docs/architecture/index.md new file mode 100644 index 0000000..b7a3e7e --- /dev/null +++ b/docs/architecture/index.md @@ -0,0 +1,72 @@ +# Architecture + +## Overview + +The vLLM Agentic API is a Rust gateway that sits between clients and vLLM, adding stateful capabilities on top of vLLM's stateless Responses API. 
The gateway is structured as a three-crate workspace. + +```mermaid +graph TD + Client -->|POST /v1/responses| Gateway[agentic-server :9000] + Gateway -->|proxy| vLLM[vLLM :8000] + Gateway -.->|vector search| OGx[OGx :8080] +``` + +## Crate Structure + +| Crate | Role | +|-------|------| +| `agentic-core` | Framework-agnostic core: inference caller, storage, vector search traits, OGx client | +| `agentic-server` | Axum HTTP server: routes, handler, CLI, agentic loop | +| `agentic-praxis` | Reserved for Praxis gateway adapter | + +### agentic-core + +Pure async Rust with no framework dependency. Contains: + +- **Proxy** (`proxy.rs`) — HTTP client that forwards requests to vLLM with auth injection, header filtering, and streaming support +- **Readiness** (`readiness.rs`) — Polls vLLM's `/health` endpoint until ready +- **Storage** (`storage/`) — SQLx-based CRUD for conversations and responses (SQLite, PostgreSQL, MySQL) +- **Vector search** (`vector_search/`) — `VectorSearch` trait and OGx implementation for file_search tool calls +- **Types** (`types/`) — Serde structs for the Responses API IO types + +### agentic-server + +Axum-based HTTP server that wires everything together: + +- **Handler** (`handler.rs`) — Request routing: runs the agentic loop if `file_search` tools are present, otherwise proxies to vLLM +- **App** (`app.rs`) — Router with `/health`, `/ready`, `/v1/responses` routes and CORS +- **CLI** (`main.rs`) — Clap-based CLI with `--llm-api-base`, `--ogx-base-url`, `--max-iterations`, and a `serve` subcommand that spawns vLLM as a subprocess + +## Request Flow + +### Passthrough (no `file_search` tool) + +``` +Client → Gateway → vLLM → Gateway → Client +``` + +The request body is forwarded to vLLM unchanged. Streaming responses are proxied as SSE. + +### Agentic Loop (file_search) + +When the request includes `tools: [{type: "file_search", vector_store_ids: [...]}]`: + +1. Convert `file_search` to a `function` tool definition for vLLM +2. 
Send to vLLM (non-streaming, forced `stream: false`) +3. If vLLM returns `function_call` output items with `name: "file_search"`: + - Extract the query from the call arguments + - Search each vector store via OGx (`POST /v1/vector_stores/{id}/search`) + - Append the tool call output and search results to the input + - Go to step 2 +4. If no tool calls, return the final response to the client +5. If `max_iterations` is reached, return a 502 error + +## OGx Integration + +[OGx](https://github.com/meta-llama/llama-stack) provides the vector search backend via its OpenAI-compatible API: + +| Endpoint | Purpose | +|----------|---------| +| `POST /v1/vector_stores/{id}/search` | Execute vector search for file_search tool calls | + +The `OgxStore` struct implements the `VectorSearch` trait, so the handler depends on the trait, not OGx directly. From cc71bb12a5b8014b46ddb52c501921ffef8a1f66 Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Thu, 11 Jun 2026 10:01:21 -0400 Subject: [PATCH 2/6] fix: preserve file search gateway semantics Signed-off-by: Francisco Javier Arceo --- crates/agentic-core/src/error.rs | 6 + .../agentic-core/src/vector_search/types.rs | 45 +- crates/agentic-server/benches/proxy_bench.rs | 8 + crates/agentic-server/src/handler.rs | 450 +++++++++++++----- crates/agentic-server/src/main.rs | 5 + crates/agentic-server/src/server.rs | 30 +- .../agentic-server/tests/agentic_loop_test.rs | 101 +++- crates/agentic-server/tests/common/mod.rs | 55 +++ crates/agentic-server/tests/cors_test.rs | 8 + crates/agentic-server/tests/health_test.rs | 8 + crates/agentic-server/tests/proxy_test.rs | 21 + docs/architecture/index.md | 16 +- 12 files changed, 616 insertions(+), 137 deletions(-) diff --git a/crates/agentic-core/src/error.rs b/crates/agentic-core/src/error.rs index 980ee67..54e5627 100644 --- a/crates/agentic-core/src/error.rs +++ b/crates/agentic-core/src/error.rs @@ -32,6 +32,12 @@ pub enum Error { #[error("vLLM returned {status}: {body}")] 
ProxyResponse { status: u16, body: String }, + #[error("database error")] + Database(#[from] sqlx::Error), + + #[error(transparent)] + StateStore(#[from] crate::storage::StorageError), + #[error("agentic loop exceeded {max_iterations} iterations")] MaxIterations { max_iterations: u32 }, } diff --git a/crates/agentic-core/src/vector_search/types.rs b/crates/agentic-core/src/vector_search/types.rs index c23e97f..061709c 100644 --- a/crates/agentic-core/src/vector_search/types.rs +++ b/crates/agentic-core/src/vector_search/types.rs @@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize}; pub struct ResponseRequest { pub model: String, #[serde(default)] - pub input: Vec, + pub input: ResponseInput, #[serde(default)] pub stream: bool, #[serde(default)] @@ -15,6 +15,49 @@ pub struct ResponseRequest { pub rest: serde_json::Map, } +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ResponseInput { + Text(String), + Items(Vec), +} + +impl Default for ResponseInput { + fn default() -> Self { + Self::Items(Vec::new()) + } +} + +impl ResponseInput { + #[must_use] + pub fn to_values(&self) -> Vec { + match self { + Self::Text(text) => vec![serde_json::json!({ + "type": "message", + "role": "user", + "content": text + })], + Self::Items(items) => items.clone(), + } + } + + pub fn prepend(&mut self, mut history: Vec) { + history.extend(self.to_values()); + *self = Self::Items(history); + } + + pub fn push(&mut self, item: serde_json::Value) { + match self { + Self::Text(_) => { + let mut items = self.to_values(); + items.push(item); + *self = Self::Items(items); + } + Self::Items(items) => items.push(item), + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolConfig { pub r#type: String, diff --git a/crates/agentic-server/benches/proxy_bench.rs b/crates/agentic-server/benches/proxy_bench.rs index 32c41b2..a34047f 100644 --- a/crates/agentic-server/benches/proxy_bench.rs +++ b/crates/agentic-server/benches/proxy_bench.rs @@ -15,6 +15,8 
@@ use tokio::runtime::Runtime; use agentic_core::config::Config; use agentic_core::proxy::ProxyState; +use agentic_core::storage::{ConversationStore, ResponseStore, create_pool_with_schema}; +use agentic_core::uuid7_str; use agentic_core::vector_search::ogx::OgxStore; use agentic_server::app::{ServerConfig, build_router}; use agentic_server::handler::AppState; @@ -77,10 +79,16 @@ async fn spawn_gateway(config: Config) -> String { let proxy = ProxyState::new(config).unwrap(); let client = reqwest::Client::new(); let ogx_store = Arc::new(OgxStore::new("http://127.0.0.1:1", client)); + let db_url = format!("sqlite:///tmp/{}.db", uuid7_str("agentic-api-bench-")); + let pool = create_pool_with_schema(Some(&db_url)).await.unwrap(); + let response_store = ResponseStore::new(pool.clone()); + let conversation_store = ConversationStore::new(pool); let state = Arc::new(AppState { proxy, max_iterations: 10, vector_search: ogx_store, + response_store, + conversation_store, }); let server_config = ServerConfig::from_env(); let router = build_router(state, &server_config); diff --git a/crates/agentic-server/src/handler.rs b/crates/agentic-server/src/handler.rs index 8b4c07f..69582bf 100644 --- a/crates/agentic-server/src/handler.rs +++ b/crates/agentic-server/src/handler.rs @@ -4,17 +4,25 @@ use axum::body::Body; use axum::extract::State; use axum::http::{HeaderMap, StatusCode}; use axum::response::{IntoResponse, Response}; -use bytes::Bytes; use tracing::{debug, error, warn}; -use agentic_core::proxy::ProxyState; +use agentic_core::proxy::{ProxyBody, ProxyRequest, ProxyResponse, ProxyState, error_response}; +use agentic_core::storage::{ConversationStore, InOutItem, ResponseMetadata, ResponseStore}; +use agentic_core::types::io::{ + FunctionToolCall, InputItem, InputMessage, InputMessageContent, OutputItem, OutputMessage, ToolChoice, +}; +use agentic_core::utils::uuid7_str; use agentic_core::vector_search::VectorSearch; -use agentic_core::vector_search::types::{ResponseBody, 
ResponseRequest, SearchResult, VllmOutputItem}; +use agentic_core::vector_search::types::{ResponseBody, ResponseInput, ResponseRequest, SearchResult, VllmOutputItem}; + +const MAX_BODY_SIZE: usize = 10 * 1024 * 1024; pub struct AppState { pub proxy: ProxyState, pub max_iterations: u32, pub vector_search: Arc, + pub response_store: ResponseStore, + pub conversation_store: ConversationStore, } #[allow(clippy::unused_async)] @@ -58,7 +66,28 @@ pub async fn ready(State(state): State>) -> impl IntoResponse { } } -pub async fn handle_responses(State(state): State>, headers: HeaderMap, body: Bytes) -> Response { +fn convert_response(resp: ProxyResponse) -> Response { + let mut builder = Response::builder().status(resp.status); + for (name, value) in &resp.headers { + builder = builder.header(name, value); + } + match resp.body { + ProxyBody::Full(bytes) => builder.body(Body::from(bytes)).expect("valid response"), + ProxyBody::Stream(stream) => builder.body(Body::from_stream(stream)).expect("valid response"), + } +} + +pub async fn handle_responses(State(state): State>, req: axum::extract::Request) -> Response { + let (parts, body) = req.into_parts(); + + let Ok(body) = axum::body::to_bytes(body, MAX_BODY_SIZE).await else { + return convert_response(error_response( + StatusCode::PAYLOAD_TOO_LARGE, + "body_too_large", + "Request body too large", + )); + }; + let request: ResponseRequest = match serde_json::from_slice(&body) { Ok(r) => r, Err(e) => { @@ -73,7 +102,14 @@ pub async fn handle_responses(State(state): State>, headers: Heade let has_file_search = request.tools.iter().any(|t| t.r#type == "file_search"); if has_file_search { - match agentic_loop(&state, &headers, request).await { + if request.stream { + return json_error_response( + StatusCode::BAD_REQUEST, + "streaming file_search is not supported by this gateway path yet", + ); + } + + match agentic_loop(&state, &parts.headers, request).await { Ok(resp) => resp, Err(e) => { error!(error = %e, "agentic loop 
failed"); @@ -81,7 +117,12 @@ pub async fn handle_responses(State(state): State>, headers: Heade } } } else { - proxy_to_vllm(&state, &headers, &body, request.stream).await + let proxy_req = ProxyRequest { + headers: parts.headers, + body, + query: parts.uri.query().map(String::from), + }; + convert_response(agentic_core::proxy::proxy_request(proxy_req, &state.proxy).await) } } @@ -90,6 +131,12 @@ async fn agentic_loop( client_headers: &HeaderMap, mut request: ResponseRequest, ) -> Result { + let original_input = request.input.clone(); + let original_previous_response_id = request.previous_response_id.clone(); + let mut turn_items = input_values_to_storage_items(&original_input); + + hydrate_previous_response(state, &mut request).await?; + let vector_store_ids: Vec = request .tools .iter() @@ -101,93 +148,35 @@ async fn agentic_loop( for iteration in 0..state.max_iterations { debug!(iteration, "agentic loop iteration"); - let mut loop_request = build_vllm_request(state, client_headers); - - let mut body = serde_json::to_value(&request) - .map_err(|e| agentic_core::error::Error::Config(format!("failed to serialize request: {e}")))?; - if let Some(obj) = body.as_object_mut() { - obj.insert("stream".to_owned(), serde_json::Value::Bool(false)); - if let Some(serde_json::Value::Array(tools)) = obj.get_mut("tools") { - let had_file_search = tools - .iter() - .any(|t| t.get("type").and_then(serde_json::Value::as_str) == Some("file_search")); - tools.retain(|t| t.get("type").and_then(serde_json::Value::as_str) != Some("file_search")); - if had_file_search { - tools.push(serde_json::json!({ - "type": "function", - "name": "file_search", - "description": "Search uploaded files for relevant content. 
Use this when the user asks about documents or needs information from files.", - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "The search query to find relevant content in files" - } - }, - "required": ["query"] - } - })); - } - } - } - - loop_request = loop_request.json(&body); + let response_body = match send_agentic_request(state, client_headers, &request).await? { + Ok(response_body) => response_body, + Err(error_response) => return Ok(error_response), + }; + let tool_calls = file_search_tool_calls(&response_body.output); - let resp = loop_request.send().await.map_err(agentic_core::error::Error::Proxy)?; - - let status = resp.status(); - if !status.is_success() { - let resp_body = resp.text().await.unwrap_or_default(); - return Ok(( - StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY), - resp_body, - ) - .into_response()); + for output_item in &response_body.output { + if let Some(item) = output_item_to_storage_item(output_item) { + turn_items.push(item); + } } - let response_body: ResponseBody = resp.json().await.map_err(agentic_core::error::Error::Proxy)?; - - let tool_calls: Vec<_> = response_body - .output - .iter() - .filter_map(|item| match item { - VllmOutputItem::FunctionCall { - call_id, - name, - arguments, - .. 
- } if name == "file_search" => Some((call_id.clone(), arguments.clone())), - _ => None, - }) - .collect(); - if tool_calls.is_empty() { debug!(iteration, "no tool calls, returning final response"); + persist_response( + state, + &request, + &response_body, + original_previous_response_id.as_deref(), + turn_items, + ) + .await?; let final_json = serde_json::to_string(&response_body) .unwrap_or_else(|_| r#"{"error":"serialization failed"}"#.to_owned()); return Ok((StatusCode::OK, [("content-type", "application/json")], final_json).into_response()); } - for output_item in &response_body.output { - request - .input - .push(serde_json::to_value(output_item).unwrap_or_default()); - } - - for (call_id, arguments) in &tool_calls { - let query = extract_query(arguments); - debug!(%call_id, %query, "executing file_search tool call"); - - let results = execute_file_search(state, &vector_store_ids, &query).await; - - let tool_output = serde_json::json!({ - "type": "function_call_output", - "call_id": call_id, - "output": serde_json::to_string(&results).unwrap_or_default() - }); - request.input.push(tool_output); - } + append_output_items_to_request(&mut request, &response_body.output); + append_tool_outputs(state, &mut request, &mut turn_items, &vector_store_ids, &tool_calls).await; debug!( iteration, @@ -205,6 +194,261 @@ async fn agentic_loop( }) } +fn build_agentic_body(request: &ResponseRequest) -> Result { + let mut body = serde_json::to_value(request) + .map_err(|e| agentic_core::error::Error::Config(format!("failed to serialize request: {e}")))?; + if let Some(obj) = body.as_object_mut() { + obj.insert("stream".to_owned(), serde_json::Value::Bool(false)); + if let Some(serde_json::Value::Array(tools)) = obj.get_mut("tools") { + let had_file_search = tools + .iter() + .any(|t| t.get("type").and_then(serde_json::Value::as_str) == Some("file_search")); + tools.retain(|t| t.get("type").and_then(serde_json::Value::as_str) != Some("file_search")); + if had_file_search { + 
tools.push(serde_json::json!({ + "type": "function", + "name": "file_search", + "description": "Search uploaded files for relevant content. Use this when the user asks about documents or needs information from files.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query to find relevant content in files" + } + }, + "required": ["query"] + } + })); + } + } + } + Ok(body) +} + +async fn send_agentic_request( + state: &AppState, + client_headers: &HeaderMap, + request: &ResponseRequest, +) -> Result, agentic_core::error::Error> { + let body = build_agentic_body(request)?; + let resp = build_vllm_request(state, client_headers) + .json(&body) + .send() + .await + .map_err(agentic_core::error::Error::Proxy)?; + + let status = resp.status(); + if !status.is_success() { + let resp_body = resp.text().await.unwrap_or_default(); + return Ok(Err(( + StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY), + resp_body, + ) + .into_response())); + } + + resp.json().await.map(Ok).map_err(agentic_core::error::Error::Proxy) +} + +fn file_search_tool_calls(output: &[VllmOutputItem]) -> Vec<(String, String)> { + output + .iter() + .filter_map(|item| match item { + VllmOutputItem::FunctionCall { + call_id, + name, + arguments, + .. 
+ } if name == "file_search" => Some((call_id.clone(), arguments.clone())), + _ => None, + }) + .collect() +} + +fn append_output_items_to_request(request: &mut ResponseRequest, output: &[VllmOutputItem]) { + for output_item in output { + request + .input + .push(serde_json::to_value(output_item).unwrap_or_default()); + } +} + +async fn append_tool_outputs( + state: &AppState, + request: &mut ResponseRequest, + turn_items: &mut Vec, + vector_store_ids: &[String], + tool_calls: &[(String, String)], +) { + for (call_id, arguments) in tool_calls { + let query = extract_query(arguments); + debug!(%call_id, %query, "executing file_search tool call"); + + let results = execute_file_search(state, vector_store_ids, &query).await; + let tool_output = serde_json::json!({ + "type": "function_call_output", + "call_id": call_id, + "output": serde_json::to_string(&results).unwrap_or_default() + }); + if let Some(item) = input_value_to_storage_item(&tool_output) { + turn_items.push(item); + } + request.input.push(tool_output); + } +} + +async fn hydrate_previous_response( + state: &AppState, + request: &mut ResponseRequest, +) -> Result<(), agentic_core::error::Error> { + let Some(previous_response_id) = request.previous_response_id.clone() else { + return Ok(()); + }; + + let history = state.response_store.rehydrate(&previous_response_id).await?; + let history_values = history + .into_iter() + .filter_map(|item| match item { + InOutItem::Input(input) => serde_json::to_value(input).ok(), + InOutItem::Output(output) => serde_json::to_value(output).ok(), + }) + .collect(); + + request.input.prepend(history_values); + request.previous_response_id = None; + Ok(()) +} + +async fn persist_response( + state: &AppState, + request: &ResponseRequest, + response_body: &ResponseBody, + previous_response_id: Option<&str>, + turn_items: Vec, +) -> Result<(), agentic_core::error::Error> { + let store = request + .rest + .get("store") + .and_then(serde_json::Value::as_bool) + .unwrap_or(true); + 
if !store { + return Ok(()); + } + + let metadata = ResponseMetadata { + model: request.model.clone(), + previous_response_id: previous_response_id.map(str::to_owned), + effective_tools: None, + effective_tool_choice: ToolChoice::Auto, + effective_instructions: request + .rest + .get("instructions") + .and_then(serde_json::Value::as_str) + .map(str::to_owned), + }; + + if let Some(conversation_id) = request.rest.get("conversation_id").and_then(serde_json::Value::as_str) { + state.conversation_store.get_or_create(conversation_id).await?; + state + .conversation_store + .persist( + conversation_id, + &response_body.id, + previous_response_id, + turn_items, + &metadata, + ) + .await?; + } else { + state + .response_store + .persist(&response_body.id, previous_response_id, turn_items, &metadata) + .await?; + } + + Ok(()) +} + +fn input_values_to_storage_items(input: &ResponseInput) -> Vec { + input + .to_values() + .iter() + .filter_map(input_value_to_storage_item) + .collect() +} + +fn input_value_to_storage_item(value: &serde_json::Value) -> Option { + if let Some(text) = value.as_str() { + return Some(InOutItem::Input(InputItem::Message(InputMessage { + role: "user".to_owned(), + content: InputMessageContent::Text(text.to_owned()), + }))); + } + + let mut value = value.clone(); + if let Some(obj) = value.as_object_mut() { + if !obj.contains_key("type") && obj.contains_key("role") && obj.contains_key("content") { + obj.insert("type".to_owned(), serde_json::Value::String("message".to_owned())); + } + } + + serde_json::from_value::(value) + .ok() + .filter(|item| !matches!(item, InputItem::Unknown)) + .map(InOutItem::Input) +} + +fn output_item_to_storage_item(item: &VllmOutputItem) -> Option { + match item { + VllmOutputItem::Message { fields } => { + let id = fields + .get("id") + .and_then(serde_json::Value::as_str) + .map_or_else(|| uuid7_str("msg_"), str::to_owned); + let role = fields + .get("role") + .and_then(serde_json::Value::as_str) + .unwrap_or("assistant") 
+ .to_owned(); + let status = fields + .get("status") + .and_then(serde_json::Value::as_str) + .unwrap_or("completed") + .to_owned(); + let content = fields + .get("content") + .cloned() + .and_then(|v| serde_json::from_value(v).ok()) + .unwrap_or_default(); + Some(InOutItem::Output(OutputItem::Message(OutputMessage { + id, + role, + status, + content, + }))) + } + VllmOutputItem::FunctionCall { + id, + call_id, + name, + arguments, + rest, + } => Some(InOutItem::Output(OutputItem::FunctionCall(FunctionToolCall { + id: id.clone(), + call_id: call_id.clone(), + name: name.clone(), + arguments: arguments.clone(), + status: rest + .get("status") + .and_then(serde_json::Value::as_str) + .unwrap_or("completed") + .to_owned(), + }))), + VllmOutputItem::Other => None, + } +} + async fn execute_file_search(state: &AppState, vector_store_ids: &[String], query: &str) -> Vec { let mut all_results = Vec::new(); for store_id in vector_store_ids { @@ -225,46 +469,6 @@ fn extract_query(arguments: &str) -> String { .unwrap_or_default() } -async fn proxy_to_vllm(state: &AppState, client_headers: &HeaderMap, body: &Bytes, stream: bool) -> Response { - let req = build_vllm_request(state, client_headers).body(body.clone()); - - let resp = match req.send().await { - Ok(r) => r, - Err(e) => { - error!(error = %e, "failed to connect to vLLM"); - return json_error_response(StatusCode::BAD_GATEWAY, &format!("vLLM connection failed: {e}")); - } - }; - - let status = StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY); - - if stream { - let content_type = resp - .headers() - .get("content-type") - .and_then(|v| v.to_str().ok()) - .unwrap_or("text/event-stream") - .to_owned(); - - let stream = resp.bytes_stream(); - let body = Body::from_stream(stream); - - (status, [("content-type", content_type)], body).into_response() - } else { - let content_type = resp - .headers() - .get("content-type") - .and_then(|v| v.to_str().ok()) - .unwrap_or("application/json") - 
.to_owned(); - - match resp.bytes().await { - Ok(bytes) => (status, [("content-type", content_type)], bytes).into_response(), - Err(e) => json_error_response(StatusCode::BAD_GATEWAY, &format!("failed to read vLLM response: {e}")), - } - } -} - fn build_vllm_request(state: &AppState, client_headers: &HeaderMap) -> reqwest::RequestBuilder { let url = format!("{}/v1/responses", state.proxy.config.llm_api_base); let mut req = state diff --git a/crates/agentic-server/src/main.rs b/crates/agentic-server/src/main.rs index f0e8a1a..1efe373 100644 --- a/crates/agentic-server/src/main.rs +++ b/crates/agentic-server/src/main.rs @@ -27,6 +27,9 @@ struct CommonArgs { #[arg(long, default_value_t = 10, global = true)] max_iterations: u32, + + #[arg(long, env = "DATABASE_URL", hide_env_values = true, global = true)] + database_url: Option, } #[derive(Parser)] @@ -98,6 +101,7 @@ async fn main() -> Result<(), Error> { common.gateway_port, &common.ogx_base_url, common.max_iterations, + common.database_url.as_deref(), ) .await } @@ -122,6 +126,7 @@ async fn main() -> Result<(), Error> { args, &common.ogx_base_url, common.max_iterations, + common.database_url.as_deref(), ) .await } diff --git a/crates/agentic-server/src/server.rs b/crates/agentic-server/src/server.rs index b5bfcbd..840d387 100644 --- a/crates/agentic-server/src/server.rs +++ b/crates/agentic-server/src/server.rs @@ -4,21 +4,32 @@ use agentic_core::config::Config; use agentic_core::error::Error; use agentic_core::proxy::ProxyState; use agentic_core::readiness::wait_llm_ready; +use agentic_core::storage::{ConversationStore, ResponseStore, create_pool_with_schema}; use agentic_core::vector_search::ogx::OgxStore; use agentic_server::app::{ServerConfig, build_router}; use agentic_server::handler::AppState; use tokio::net::TcpListener; use tracing::info; -fn build_app_state(config: Config, ogx_base_url: &str, max_iterations: u32) -> Result, Error> { +async fn build_app_state( + config: Config, + ogx_base_url: &str, + 
max_iterations: u32, + database_url: Option<&str>, +) -> Result, Error> { let proxy = ProxyState::new(config)?; let client = reqwest::Client::new(); let ogx_store = Arc::new(OgxStore::new(ogx_base_url, client)); + let pool = create_pool_with_schema(database_url).await?; + let response_store = ResponseStore::new(pool.clone()); + let conversation_store = ConversationStore::new(pool); Ok(Arc::new(AppState { proxy, max_iterations, vector_search: ogx_store, + response_store, + conversation_store, })) } @@ -28,9 +39,10 @@ async fn serve_gateway( port: u16, ogx_base_url: &str, max_iterations: u32, + database_url: Option<&str>, ) -> Result<(), Error> { let addr = format!("{host}:{port}"); - let state = build_app_state(config, ogx_base_url, max_iterations)?; + let state = build_app_state(config, ogx_base_url, max_iterations, database_url).await?; let server_config = ServerConfig::from_env(); let router = build_router(state, &server_config); let listener = TcpListener::bind(&addr).await?; @@ -44,10 +56,17 @@ async fn serve_gateway( /// # Errors /// /// Returns an error if LLM readiness polling fails or the server cannot bind. -pub async fn run(config: Config, host: &str, port: u16, ogx_base_url: &str, max_iterations: u32) -> Result<(), Error> { +pub async fn run( + config: Config, + host: &str, + port: u16, + ogx_base_url: &str, + max_iterations: u32, + database_url: Option<&str>, +) -> Result<(), Error> { wait_llm_ready(&config).await?; info!("LLM ready: {}", config.llm_api_base); - serve_gateway(config, host, port, ogx_base_url, max_iterations).await + serve_gateway(config, host, port, ogx_base_url, max_iterations, database_url).await } /// Spawn vLLM as a subprocess and run the gateway in the foreground. 
@@ -62,6 +81,7 @@ pub async fn run_with_llm( llm_args: Vec, ogx_base_url: &str, max_iterations: u32, + database_url: Option<&str>, ) -> Result<(), Error> { let mut cmd = tokio::process::Command::new("python"); cmd.arg("-m").arg("vllm.entrypoints.openai.api_server"); @@ -90,7 +110,7 @@ pub async fn run_with_llm( } let result = tokio::select! { - gateway = serve_gateway(config, host, port, ogx_base_url, max_iterations) => gateway, + gateway = serve_gateway(config, host, port, ogx_base_url, max_iterations, database_url) => gateway, status = child.wait() => { let status = status?; Err(Error::LlmProcessExited { diff --git a/crates/agentic-server/tests/agentic_loop_test.rs b/crates/agentic-server/tests/agentic_loop_test.rs index fbbb5be..fcee920 100644 --- a/crates/agentic-server/tests/agentic_loop_test.rs +++ b/crates/agentic-server/tests/agentic_loop_test.rs @@ -1,7 +1,7 @@ #[allow(dead_code)] mod common; -use common::{spawn_ogx, spawn_vllm, spawn_vllm_with_tool_calls, start_gateway}; +use common::{spawn_ogx, spawn_vllm, spawn_vllm_recording, spawn_vllm_with_tool_calls, start_gateway}; #[tokio::test] async fn test_passthrough_no_tools() { @@ -74,6 +74,105 @@ async fn test_single_file_search() { assert_eq!(body["output"][0]["type"], "message"); } +#[tokio::test] +async fn test_file_search_streaming_rejected() { + let (vllm_port, _h) = spawn_vllm().await; + let (ogx_port, _h2) = spawn_ogx().await; + let (gw_addr, _) = start_gateway(vllm_port, Some(ogx_port), None).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": "search for something", + "stream": true, + "tools": [{"type": "file_search", "vector_store_ids": ["vs_123"]}] + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 400); + let body: serde_json::Value = resp.json().await.unwrap(); + let msg = body["error"]["message"].as_str().unwrap_or(""); + 
assert!(msg.contains("streaming file_search"), "unexpected error: {msg}"); +} + +#[tokio::test] +async fn test_previous_response_id_hydrates_history() { + let first_response = serde_json::json!({ + "id": "resp_1", + "object": "response", + "status": "completed", + "output": [{ + "type": "message", + "id": "msg_1", + "role": "assistant", + "status": "completed", + "content": [{"type": "output_text", "text": "first answer"}] + }] + }); + + let second_response = serde_json::json!({ + "id": "resp_2", + "object": "response", + "status": "completed", + "output": [{ + "type": "message", + "id": "msg_2", + "role": "assistant", + "status": "completed", + "content": [{"type": "output_text", "text": "second answer"}] + }] + }); + + let (vllm_port, requests, _h) = spawn_vllm_recording(vec![first_response, second_response]).await; + let (ogx_port, _h2) = spawn_ogx().await; + let (gw_addr, _) = start_gateway(vllm_port, Some(ogx_port), None).await; + + let client = reqwest::Client::new(); + let first = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": "first question", + "tools": [{"type": "file_search", "vector_store_ids": ["vs_123"]}] + })) + .send() + .await + .unwrap(); + assert_eq!(first.status(), 200); + + let second = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": "follow up", + "previous_response_id": "resp_1", + "tools": [{"type": "file_search", "vector_store_ids": ["vs_123"]}] + })) + .send() + .await + .unwrap(); + let second_status = second.status(); + let second_text = second.text().await.unwrap(); + assert_eq!(second_status, 200, "second response body: {second_text}"); + + let requests = requests.lock().await; + assert_eq!(requests.len(), 2); + assert!(requests[1].get("previous_response_id").is_none()); + let input = requests[1]["input"] + .as_array() + .expect("hydrated input should be an array"); + assert!( + input.len() >= 
3, + "expected prior user/output plus follow-up input, got {input:?}" + ); + assert!(input.iter().any(|item| item["content"] == "first question")); + assert!(input.iter().any(|item| item["content"] == "follow up")); +} + #[tokio::test] async fn test_multi_turn_tool_calls() { let turn1 = serde_json::json!({ diff --git a/crates/agentic-server/tests/common/mod.rs b/crates/agentic-server/tests/common/mod.rs index 2cf0d75..3804ffd 100644 --- a/crates/agentic-server/tests/common/mod.rs +++ b/crates/agentic-server/tests/common/mod.rs @@ -11,9 +11,12 @@ use bytes::Bytes; use futures::stream; use http::StatusCode; use tokio::net::TcpListener; +use tokio::sync::Mutex; use agentic_core::config::Config; use agentic_core::proxy::ProxyState; +use agentic_core::storage::{ConversationStore, ResponseStore, create_pool_with_schema}; +use agentic_core::uuid7_str; use agentic_core::vector_search::ogx::OgxStore; use agentic_server::handler::AppState; @@ -40,11 +43,17 @@ pub async fn start_gateway(vllm_port: u16, ogx_port: Option, api_key: Optio let proxy = ProxyState::new(config).unwrap(); let client = reqwest::Client::new(); let ogx_store = Arc::new(OgxStore::new(&ogx_base, client)); + let db_url = format!("sqlite:///tmp/{}.db", uuid7_str("agentic-api-test-")); + let pool = create_pool_with_schema(Some(&db_url)).await.unwrap(); + let response_store = ResponseStore::new(pool.clone()); + let conversation_store = ConversationStore::new(pool); let state = Arc::new(AppState { proxy, max_iterations: 10, vector_search: ogx_store, + response_store, + conversation_store, }); let server_config = agentic_server::app::ServerConfig::from_env(); @@ -197,6 +206,52 @@ pub async fn spawn_vllm_with_tool_calls(responses: Vec) -> (u (port, handle) } +pub async fn spawn_vllm_recording( + responses: Vec, +) -> (u16, Arc>>, tokio::task::JoinHandle<()>) { + let responses = Arc::new(responses); + let counter = Arc::new(AtomicUsize::new(0)); + let requests = Arc::new(Mutex::new(Vec::new())); + + let app = 
Router::new().route("/health", get(health_handler)).route( + "/v1/responses", + post({ + let responses = Arc::clone(&responses); + let counter = Arc::clone(&counter); + let requests_for_handler = Arc::clone(&requests); + move |req: Request| { + let responses = Arc::clone(&responses); + let counter = Arc::clone(&counter); + let requests_for_handler = Arc::clone(&requests_for_handler); + async move { + let body_bytes = axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024) + .await + .unwrap_or_default(); + let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap_or_default(); + requests_for_handler.lock().await.push(body); + + let idx = counter.fetch_add(1, Ordering::SeqCst); + let resp = responses.get(idx).unwrap_or(responses.last().unwrap()); + ( + StatusCode::OK, + [("content-type", "application/json")], + serde_json::to_string(resp).unwrap(), + ) + .into_response() + } + } + }), + ); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let handle = tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + (port, requests, handle) +} + pub async fn spawn_ogx() -> (u16, tokio::task::JoinHandle<()>) { async fn search_handler(_req: Request) -> Response { let body = serde_json::json!({ diff --git a/crates/agentic-server/tests/cors_test.rs b/crates/agentic-server/tests/cors_test.rs index eb41fed..6e5deb5 100644 --- a/crates/agentic-server/tests/cors_test.rs +++ b/crates/agentic-server/tests/cors_test.rs @@ -8,6 +8,8 @@ use tokio::net::TcpListener; use agentic_core::config::Config; use agentic_core::proxy::ProxyState; +use agentic_core::storage::{ConversationStore, ResponseStore, create_pool_with_schema}; +use agentic_core::uuid7_str; use agentic_core::vector_search::ogx::OgxStore; use agentic_server::handler::AppState; @@ -34,10 +36,16 @@ async fn spawn_gateway(config: Config) -> (String, tokio::task::JoinHandle<()>) let proxy = 
ProxyState::new(config).unwrap(); let client = reqwest::Client::new(); let ogx_store = Arc::new(OgxStore::new("http://127.0.0.1:1", client)); + let db_url = format!("sqlite:///tmp/{}.db", uuid7_str("agentic-api-test-")); + let pool = create_pool_with_schema(Some(&db_url)).await.unwrap(); + let response_store = ResponseStore::new(pool.clone()); + let conversation_store = ConversationStore::new(pool); let state = Arc::new(AppState { proxy, max_iterations: 10, vector_search: ogx_store, + response_store, + conversation_store, }); let server_config = agentic_server::app::ServerConfig::from_env(); let router = agentic_server::app::build_router(state, &server_config); diff --git a/crates/agentic-server/tests/health_test.rs b/crates/agentic-server/tests/health_test.rs index f0c92dd..f75c943 100644 --- a/crates/agentic-server/tests/health_test.rs +++ b/crates/agentic-server/tests/health_test.rs @@ -8,6 +8,8 @@ use tokio::net::TcpListener; use agentic_core::config::Config; use agentic_core::proxy::ProxyState; +use agentic_core::storage::{ConversationStore, ResponseStore, create_pool_with_schema}; +use agentic_core::uuid7_str; use agentic_core::vector_search::ogx::OgxStore; use agentic_server::handler::AppState; @@ -41,10 +43,16 @@ async fn spawn_gateway(config: Config) -> (String, tokio::task::JoinHandle<()>) let proxy = ProxyState::new(config).unwrap(); let client = reqwest::Client::new(); let ogx_store = Arc::new(OgxStore::new("http://127.0.0.1:1", client)); + let db_url = format!("sqlite:///tmp/{}.db", uuid7_str("agentic-api-test-")); + let pool = create_pool_with_schema(Some(&db_url)).await.unwrap(); + let response_store = ResponseStore::new(pool.clone()); + let conversation_store = ConversationStore::new(pool); let state = Arc::new(AppState { proxy, max_iterations: 10, vector_search: ogx_store, + response_store, + conversation_store, }); let server_config = agentic_server::app::ServerConfig::from_env(); let router = agentic_server::app::build_router(state, 
&server_config); diff --git a/crates/agentic-server/tests/proxy_test.rs b/crates/agentic-server/tests/proxy_test.rs index 4d06b40..77dc658 100644 --- a/crates/agentic-server/tests/proxy_test.rs +++ b/crates/agentic-server/tests/proxy_test.rs @@ -24,6 +24,27 @@ async fn test_non_stream_passthrough() { assert_eq!(body["id"], "resp_test"); } +#[tokio::test] +async fn test_string_input_passthrough() { + let (vllm_port, _h) = spawn_vllm().await; + let (gw_addr, _) = start_gateway(vllm_port, None, Some("env-vllm-key")).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": "hello" + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + let body: serde_json::Value = resp.json().await.unwrap(); + assert_eq!(body["id"], "resp_test"); +} + #[tokio::test] async fn test_stream_passthrough() { let (vllm_port, _h) = spawn_vllm().await; diff --git a/docs/architecture/index.md b/docs/architecture/index.md index b7a3e7e..85f5f16 100644 --- a/docs/architecture/index.md +++ b/docs/architecture/index.md @@ -35,7 +35,7 @@ Axum-based HTTP server that wires everything together: - **Handler** (`handler.rs`) — Request routing: runs the agentic loop if `file_search` tools are present, otherwise proxies to vLLM - **App** (`app.rs`) — Router with `/health`, `/ready`, `/v1/responses` routes and CORS -- **CLI** (`main.rs`) — Clap-based CLI with `--llm-api-base`, `--ogx-base-url`, `--max-iterations`, and a `serve` subcommand that spawns vLLM as a subprocess +- **CLI** (`main.rs`) — Clap-based CLI with `--llm-api-base`, `--ogx-base-url`, `--max-iterations`, `--database-url`, and a `serve` subcommand that spawns vLLM as a subprocess ## Request Flow @@ -51,15 +51,17 @@ The request is forwarded to vLLM unchanged. Streaming responses are proxied as S When the request includes `tools: [{type: "file_search", vector_store_ids: [...]}]`: -1. 
Convert `file_search` to a `function` tool definition for vLLM -2. Send to vLLM (non-streaming, forced `stream: false`) -3. If vLLM returns `function_call` output items with `name: "file_search"`: +1. Reject streaming file-search requests until the tool loop can emit interleaved SSE events +2. Hydrate `previous_response_id` history from the response store, if present +3. Convert `file_search` to a `function` tool definition for vLLM +4. Send to vLLM (non-streaming, forced `stream: false`) +5. If vLLM returns `function_call` output items with `name: "file_search"`: - Extract the query from the call arguments - Search each vector store via OGx (`POST /v1/vector_stores/{id}/search`) - Append the tool call output and search results to the input - - Go to step 2 -4. If no tool calls, return the final response to the client -5. If `max_iterations` is reached, return a 502 error + - Go to step 4 +6. If no tool calls, persist the response and return the final response to the client +7. If `max_iterations` is reached, return a 502 error ## OGx Integration From 659608377bb6f85e8ffe4c8ad553316f8ccefa9f Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Tue, 16 Jun 2026 16:31:44 -0400 Subject: [PATCH 3/6] fix: avoid pre-commit shebang false positive Signed-off-by: Francisco Javier Arceo --- crates/agentic-server/tests/agentic_loop_test.rs | 1 + crates/agentic-server/tests/common/mod.rs | 2 -- crates/agentic-server/tests/conversations_test.rs | 1 + crates/agentic-server/tests/cors_test.rs | 1 + crates/agentic-server/tests/health_test.rs | 1 + crates/agentic-server/tests/proxy_test.rs | 1 + crates/agentic-server/tests/responses_test.rs | 1 + 7 files changed, 6 insertions(+), 2 deletions(-) diff --git a/crates/agentic-server/tests/agentic_loop_test.rs b/crates/agentic-server/tests/agentic_loop_test.rs index caaa917..00e1254 100644 --- a/crates/agentic-server/tests/agentic_loop_test.rs +++ b/crates/agentic-server/tests/agentic_loop_test.rs @@ -1,3 +1,4 @@ 
+#[allow(dead_code)] mod common; use common::{spawn_ogx, spawn_vllm, spawn_vllm_recording, spawn_vllm_with_tool_calls, start_gateway}; diff --git a/crates/agentic-server/tests/common/mod.rs b/crates/agentic-server/tests/common/mod.rs index a0129ae..f1ce964 100644 --- a/crates/agentic-server/tests/common/mod.rs +++ b/crates/agentic-server/tests/common/mod.rs @@ -1,5 +1,3 @@ -#![allow(dead_code)] - use std::convert::Infallible; use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; diff --git a/crates/agentic-server/tests/conversations_test.rs b/crates/agentic-server/tests/conversations_test.rs index 44204bb..48894e8 100644 --- a/crates/agentic-server/tests/conversations_test.rs +++ b/crates/agentic-server/tests/conversations_test.rs @@ -1,3 +1,4 @@ +#[allow(dead_code)] mod common; use http::StatusCode; diff --git a/crates/agentic-server/tests/cors_test.rs b/crates/agentic-server/tests/cors_test.rs index 4719e9f..84df570 100644 --- a/crates/agentic-server/tests/cors_test.rs +++ b/crates/agentic-server/tests/cors_test.rs @@ -1,3 +1,4 @@ +#[allow(dead_code)] mod common; use common::{spawn_gateway, spawn_mock_llm, test_config, test_state}; diff --git a/crates/agentic-server/tests/health_test.rs b/crates/agentic-server/tests/health_test.rs index 72d6dcb..e6b6c3e 100644 --- a/crates/agentic-server/tests/health_test.rs +++ b/crates/agentic-server/tests/health_test.rs @@ -1,3 +1,4 @@ +#[allow(dead_code)] mod common; use agentic_core::config::Config; diff --git a/crates/agentic-server/tests/proxy_test.rs b/crates/agentic-server/tests/proxy_test.rs index d1ba134..d06c122 100644 --- a/crates/agentic-server/tests/proxy_test.rs +++ b/crates/agentic-server/tests/proxy_test.rs @@ -1,3 +1,4 @@ +#[allow(dead_code)] mod common; use common::{spawn_mid_stream_failure_vllm, spawn_vllm, start_gateway}; diff --git a/crates/agentic-server/tests/responses_test.rs b/crates/agentic-server/tests/responses_test.rs index d639679..449ddc9 100644 --- 
a/crates/agentic-server/tests/responses_test.rs +++ b/crates/agentic-server/tests/responses_test.rs @@ -1,3 +1,4 @@ +#[allow(dead_code)] mod common; use axum::Router; From d3de710c442eea070e4c6743cd41498b9dc08af8 Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Tue, 16 Jun 2026 20:53:55 -0400 Subject: [PATCH 4/6] fix: surface file search backend failures Signed-off-by: Francisco Javier Arceo --- crates/agentic-core/src/executor/engine.rs | 6 +- crates/agentic-core/src/executor/error.rs | 6 +- .../agentic-server/tests/agentic_loop_test.rs | 51 ++++++++++++ crates/agentic-server/tests/common/mod.rs | 16 ++-- .../agentic-server/tests/integration_test.rs | 83 +++++++++++++++++++ 5 files changed, 154 insertions(+), 8 deletions(-) diff --git a/crates/agentic-core/src/executor/engine.rs b/crates/agentic-core/src/executor/engine.rs index 3b2fc2a..b520f6a 100644 --- a/crates/agentic-core/src/executor/engine.rs +++ b/crates/agentic-core/src/executor/engine.rs @@ -430,7 +430,11 @@ async fn run_file_search_loop(mut ctx: RequestContext, exec_ctx: &ExecutionConte for store_id in &store_ids { match vector_search.search(store_id, &query).await { Ok(mut store_results) => results.append(&mut store_results), - Err(err) => warn!(%store_id, %query, "file_search vector lookup failed: {err}"), + Err(err) => { + return Err(ExecutorError::ToolExecution(format!( + "file_search vector lookup failed for vector store {store_id}: {err}" + ))); + } } } diff --git a/crates/agentic-core/src/executor/error.rs b/crates/agentic-core/src/executor/error.rs index 5df3bb8..dee656e 100644 --- a/crates/agentic-core/src/executor/error.rs +++ b/crates/agentic-core/src/executor/error.rs @@ -55,6 +55,9 @@ pub enum ExecutorError { #[error("agentic loop exceeded {max_iterations} iterations")] MaxIterations { max_iterations: u32 }, + #[error("{0}")] + ToolExecution(String), + #[error("invalid request: {0}")] InvalidRequest(String), } @@ -65,7 +68,7 @@ impl ExecutorError { pub fn http_status(&self) -> 
StatusCode { match self { Self::Storage(e) if e.is_not_found() => StatusCode::NOT_FOUND, - Self::MaxIterations { .. } => StatusCode::BAD_GATEWAY, + Self::MaxIterations { .. } | Self::ToolExecution(_) => StatusCode::BAD_GATEWAY, Self::LLMRequest { status, .. } => *status, Self::InvalidRequest(_) | Self::JsonError(_) => StatusCode::BAD_REQUEST, Self::ParseError(_) => StatusCode::UNPROCESSABLE_ENTITY, @@ -79,6 +82,7 @@ impl ExecutorError { match self { Self::Storage(e) if e.is_not_found() => "not_found", Self::LLMRequest { .. } => "upstream_error", + Self::ToolExecution(_) => "tool_execution_error", Self::InvalidRequest(_) | Self::ParseError(_) | Self::JsonError(_) => "invalid_request_error", _ => "server_error", } diff --git a/crates/agentic-server/tests/agentic_loop_test.rs b/crates/agentic-server/tests/agentic_loop_test.rs index 00e1254..695ef72 100644 --- a/crates/agentic-server/tests/agentic_loop_test.rs +++ b/crates/agentic-server/tests/agentic_loop_test.rs @@ -75,6 +75,57 @@ async fn test_single_file_search() { assert_eq!(body["output"][0]["type"], "message"); } +#[tokio::test] +async fn test_file_search_backend_failure_returns_error() { + let tool_call_response = serde_json::json!({ + "id": "resp_1", + "object": "response", + "status": "completed", + "output": [{ + "type": "function_call", + "id": "fc_1", + "call_id": "call_1", + "name": "file_search", + "arguments": "{\"query\": \"test query\"}", + "status": "completed" + }] + }); + + let final_response = serde_json::json!({ + "id": "resp_2", + "object": "response", + "status": "completed", + "output": [{ + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "answer without search context"}] + }] + }); + + let (vllm_port, _h) = spawn_vllm_with_tool_calls(vec![tool_call_response, final_response]).await; + let (gw_addr, _) = start_gateway(vllm_port, None, None).await; + + let client = reqwest::Client::new(); + let resp = client + 
.post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": "search for something", + "tools": [{"type": "file_search", "vector_store_ids": ["vs_123"]}] + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 502); + let body: serde_json::Value = resp.json().await.unwrap(); + let msg = body["error"]["message"].as_str().unwrap_or(""); + assert!( + msg.contains("file_search vector lookup failed"), + "unexpected error: {msg}" + ); +} + #[tokio::test] async fn test_file_search_streaming_rejected() { let (vllm_port, _h) = spawn_vllm().await; diff --git a/crates/agentic-server/tests/common/mod.rs b/crates/agentic-server/tests/common/mod.rs index f1ce964..f7cc66f 100644 --- a/crates/agentic-server/tests/common/mod.rs +++ b/crates/agentic-server/tests/common/mod.rs @@ -68,24 +68,28 @@ pub async fn spawn_gateway(state: AppState) -> (String, tokio::task::JoinHandle< } pub async fn start_gateway(vllm_port: u16, ogx_port: Option, api_key: Option<&str>) -> (String, u16) { + let ogx_base = match ogx_port { + Some(p) => format!("http://127.0.0.1:{p}"), + None => "http://127.0.0.1:1".to_owned(), + }; + start_gateway_with_ogx_base(vllm_port, &ogx_base, api_key).await +} + +pub async fn start_gateway_with_ogx_base(vllm_port: u16, ogx_base: &str, api_key: Option<&str>) -> (String, u16) { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let port = listener.local_addr().unwrap().port(); let addr = format!("127.0.0.1:{port}"); let llm_url = format!("http://127.0.0.1:{vllm_port}"); - let ogx_base = match ogx_port { - Some(p) => format!("http://127.0.0.1:{p}"), - None => "http://127.0.0.1:1".to_owned(), - }; let mut config = test_config(&llm_url); config.openai_api_key = api_key.map(String::from); - config.ogx_base_url.clone_from(&ogx_base); + config.ogx_base_url = ogx_base.to_owned(); config.db_url = Some(format!("sqlite:///tmp/{}.db", uuid7_str("agentic-api-test-"))); let proxy_state = 
ProxyState::new(config.clone()).unwrap(); let pool = create_pool_with_schema(config.db_url.as_deref()).await.unwrap(); - let ogx_store = Arc::new(OgxStore::new(&ogx_base, reqwest::Client::new())); + let ogx_store = Arc::new(OgxStore::new(ogx_base, reqwest::Client::new())); let exec_ctx = ExecutionContext::new( ConversationHandler::new(ConversationStore::new(pool.clone())), ResponseHandler::new(ResponseStore::new(pool)), diff --git a/crates/agentic-server/tests/integration_test.rs b/crates/agentic-server/tests/integration_test.rs index 6e09893..12fef02 100644 --- a/crates/agentic-server/tests/integration_test.rs +++ b/crates/agentic-server/tests/integration_test.rs @@ -1,3 +1,8 @@ +#[allow(dead_code)] +mod common; + +use common::{spawn_vllm_recording, start_gateway_with_ogx_base}; + fn ogx_base_url() -> Option { std::env::var("OGX_BASE_URL").ok() } @@ -78,6 +83,82 @@ async fn upload_and_attach(client: &reqwest::Client, ogx_url: &str, vs_id: &str) ); } +async fn assert_gateway_file_search_uses_ogx(client: &reqwest::Client, ogx_url: &str, vs_id: &str) { + let tool_call_response = serde_json::json!({ + "id": "resp_1", + "object": "response", + "status": "completed", + "output": [{ + "type": "function_call", + "id": "fc_1", + "call_id": "call_1", + "name": "file_search", + "arguments": "{\"query\": \"memory safety ownership\"}", + "status": "completed" + }] + }); + let final_response = serde_json::json!({ + "id": "resp_2", + "object": "response", + "status": "completed", + "output": [{ + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "Rust uses ownership and borrowing."}] + }] + }); + let (vllm_port, requests, _vllm_handle) = spawn_vllm_recording(vec![tool_call_response, final_response]).await; + let (gateway_addr, _) = start_gateway_with_ogx_base(vllm_port, ogx_url, None).await; + + let gateway_resp = client + .post(format!("http://{gateway_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": "How 
does Rust provide memory safety?", + "tools": [{"type": "file_search", "vector_store_ids": [vs_id]}] + })) + .send() + .await + .unwrap(); + assert!( + gateway_resp.status().is_success(), + "gateway file_search failed: {}", + gateway_resp.text().await.unwrap_or_default() + ); + + let requests = requests.lock().await; + assert_eq!( + requests.len(), + 2, + "gateway should call vLLM before and after OGX search" + ); + + let first_tools = requests[0]["tools"] + .as_array() + .expect("first request should include tools"); + assert_eq!(first_tools[0]["type"], "function"); + assert_eq!(first_tools[0]["name"], "file_search"); + + let second_input = requests[1]["input"] + .as_array() + .expect("second request should include input items"); + let tool_output = second_input + .iter() + .find(|item| item["type"] == "function_call_output") + .expect("gateway should append function_call_output"); + let output = tool_output["output"] + .as_str() + .expect("tool output should be a JSON string"); + let output_json: serde_json::Value = serde_json::from_str(output).expect("tool output should parse as JSON"); + let gateway_results = output_json["results"] + .as_array() + .expect("tool output should include results"); + assert!( + !gateway_results.is_empty(), + "gateway should pass OGX search results back to vLLM" + ); +} + #[tokio::test] async fn test_vector_search_with_ogx() { let Some(ogx_url) = ogx_base_url() else { @@ -123,4 +204,6 @@ async fn test_vector_search_with_ogx() { eprintln!("Search returned {} results, top score: {score:.3}", data.len()); eprintln!("Top result: {content}"); + + assert_gateway_file_search_uses_ogx(&client, &ogx_url, &vs_id).await; } From 4a119460b61105c26c8277980738245a859dad1b Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Tue, 16 Jun 2026 23:26:40 -0400 Subject: [PATCH 5/6] fix: harden file search tool boundary Signed-off-by: Francisco Javier Arceo --- crates/agentic-core/src/executor/engine.rs | 114 +++++++--- 
crates/agentic-core/src/types/io.rs | 4 + crates/agentic-core/src/vector_search/mod.rs | 9 +- crates/agentic-core/src/vector_search/ogx.rs | 33 +-- .../agentic-core/src/vector_search/types.rs | 11 + .../agentic-server/tests/agentic_loop_test.rs | 215 +++++++++++++++++- crates/agentic-server/tests/common/mod.rs | 48 ++++ 7 files changed, 384 insertions(+), 50 deletions(-) diff --git a/crates/agentic-core/src/executor/engine.rs b/crates/agentic-core/src/executor/engine.rs index b520f6a..2ba11a9 100644 --- a/crates/agentic-core/src/executor/engine.rs +++ b/crates/agentic-core/src/executor/engine.rs @@ -10,6 +10,7 @@ use std::sync::Arc; use async_stream::stream; use either::Either; use futures::{Stream, StreamExt}; +use serde::Deserialize; use tracing::warn; use crate::executor::accumulator::ResponseAccumulator; @@ -25,6 +26,7 @@ use crate::types::io::{ use crate::types::request_response::{RequestPayload, ResponsePayload}; use crate::utils::common::serialize_to_string; use crate::utils::uuid7_str; +use crate::vector_search::types::SearchOptions; use std::time::Duration; @@ -304,16 +306,44 @@ fn contains_file_search(tools: Option<&[ResponsesTool]>) -> bool { tools.is_some_and(|tools| tools.iter().any(|tool| matches!(tool, ResponsesTool::FileSearch(_)))) } -fn file_search_store_ids(tools: Option<&[ResponsesTool]>) -> Vec { - tools - .into_iter() - .flatten() - .filter_map(|tool| match tool { - ResponsesTool::FileSearch(tool) => Some(tool.vector_store_ids.iter().cloned()), - ResponsesTool::Function(_) | ResponsesTool::Unknown => None, - }) - .flatten() - .collect() +#[derive(Clone)] +struct FileSearchConfig { + store_ids: Vec, + options: SearchOptions, +} + +fn file_search_config(tools: Option<&[ResponsesTool]>) -> ExecutorResult { + let mut store_ids = Vec::new(); + let mut options = None::; + + for tool in tools.into_iter().flatten() { + match tool { + ResponsesTool::FileSearch(tool) => { + store_ids.extend(tool.vector_store_ids.iter().filter(|id| 
!id.is_empty()).cloned()); + if options + .as_ref() + .is_some_and(|existing| existing != &tool.search_options) + { + return Err(ExecutorError::InvalidRequest( + "multiple file_search tools with different search options are not supported".into(), + )); + } + options.get_or_insert_with(|| tool.search_options.clone()); + } + ResponsesTool::Function(_) | ResponsesTool::Unknown => {} + } + } + + if store_ids.is_empty() { + return Err(ExecutorError::InvalidRequest( + "file_search requires at least one vector_store_ids entry".into(), + )); + } + + Ok(FileSearchConfig { + store_ids, + options: options.unwrap_or_default(), + }) } fn file_search_function_tool() -> ResponsesTool { @@ -348,26 +378,44 @@ fn translate_file_search_tools(tools: Option<&[ResponsesTool]>) -> Option Vec { - output - .iter() - .filter_map(|item| match item { - OutputItem::FunctionCall(call) if call.name == "file_search" => Some(call.clone()), - OutputItem::Message(_) | OutputItem::FunctionCall(_) | OutputItem::Unknown => None, - }) - .collect() +fn file_search_calls(output: &[OutputItem]) -> ExecutorResult> { + let mut file_search_calls = Vec::new(); + let mut other_tool_names = Vec::new(); + + for item in output { + match item { + OutputItem::FunctionCall(call) if call.name == "file_search" => file_search_calls.push(call.clone()), + OutputItem::FunctionCall(call) => other_tool_names.push(call.name.clone()), + OutputItem::Message(_) | OutputItem::Unknown => {} + } + } + + if !file_search_calls.is_empty() && !other_tool_names.is_empty() { + return Err(ExecutorError::ToolExecution(format!( + "mixed tool calls are not supported in file_search loop: {}", + other_tool_names.join(", ") + ))); + } + + Ok(file_search_calls) } -fn query_from_arguments(arguments: &str) -> String { - serde_json::from_str::(arguments) - .ok() - .and_then(|value| { - value - .get("query") - .and_then(serde_json::Value::as_str) - .map(str::to_string) - }) - .unwrap_or_default() +#[derive(Deserialize)] +struct FileSearchArguments { 
+ query: String, +} + +fn query_from_arguments(arguments: &str) -> ExecutorResult { + let args = serde_json::from_str::(arguments) + .map_err(|err| ExecutorError::ToolExecution(format!("invalid file_search arguments: {err}")))?; + + if args.query.trim().is_empty() { + return Err(ExecutorError::ToolExecution( + "file_search query argument is required".into(), + )); + } + + Ok(args.query) } fn append_input_item(input: &mut ResponsesInput, item: InputItem) { @@ -389,7 +437,7 @@ async fn run_file_search_loop(mut ctx: RequestContext, exec_ctx: &ExecutionConte )); }; - let store_ids = file_search_store_ids(ctx.enriched_request.tools.as_deref()); + let file_search = file_search_config(ctx.enriched_request.tools.as_deref())?; ctx.enriched_request.tools = translate_file_search_tools(ctx.enriched_request.tools.as_deref()); let url = exec_ctx.responses_url(); @@ -404,7 +452,7 @@ async fn run_file_search_loop(mut ctx: RequestContext, exec_ctx: &ExecutionConte ctx.original_request.instructions.as_deref(), ); - let tool_calls = file_search_calls(&payload.output); + let tool_calls = file_search_calls(&payload.output)?; if tool_calls.is_empty() { ctx.inject_ids(&mut payload); let should_persist = ctx.original_request.store @@ -425,10 +473,10 @@ async fn run_file_search_loop(mut ctx: RequestContext, exec_ctx: &ExecutionConte append_input_item(&mut ctx.enriched_request.input, input_call.clone()); ctx.new_input_items.push(input_call); - let query = query_from_arguments(&call.arguments); + let query = query_from_arguments(&call.arguments)?; let mut results = Vec::new(); - for store_id in &store_ids { - match vector_search.search(store_id, &query).await { + for store_id in &file_search.store_ids { + match vector_search.search(store_id, &query, &file_search.options).await { Ok(mut store_results) => results.append(&mut store_results), Err(err) => { return Err(ExecutorError::ToolExecution(format!( diff --git a/crates/agentic-core/src/types/io.rs b/crates/agentic-core/src/types/io.rs index 
ab52e40..949abd0 100644 --- a/crates/agentic-core/src/types/io.rs +++ b/crates/agentic-core/src/types/io.rs @@ -1,6 +1,8 @@ use serde::{Deserialize, Deserializer, Serialize}; use serde_json::{Map, Value}; +use crate::vector_search::types::SearchOptions; + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct InputTextContent { #[serde(rename = "type")] @@ -207,6 +209,8 @@ pub struct FunctionTool { pub struct FileSearchTool { #[serde(default)] pub vector_store_ids: Vec, + #[serde(default, flatten)] + pub search_options: SearchOptions, #[serde(flatten)] pub rest: Map, } diff --git a/crates/agentic-core/src/vector_search/mod.rs b/crates/agentic-core/src/vector_search/mod.rs index 267186b..9a97377 100644 --- a/crates/agentic-core/src/vector_search/mod.rs +++ b/crates/agentic-core/src/vector_search/mod.rs @@ -3,9 +3,14 @@ pub mod types; use async_trait::async_trait; -use types::SearchResult; +use types::{SearchOptions, SearchResult}; #[async_trait] pub trait VectorSearch: Send + Sync { - async fn search(&self, store_id: &str, query: &str) -> Result, crate::error::Error>; + async fn search( + &self, + store_id: &str, + query: &str, + options: &SearchOptions, + ) -> Result, crate::error::Error>; } diff --git a/crates/agentic-core/src/vector_search/ogx.rs b/crates/agentic-core/src/vector_search/ogx.rs index 88b99a6..77fbe85 100644 --- a/crates/agentic-core/src/vector_search/ogx.rs +++ b/crates/agentic-core/src/vector_search/ogx.rs @@ -1,7 +1,8 @@ use async_trait::async_trait; +use serde_json::{Map, Value}; use tracing::debug; -use super::types::{SearchResponse, SearchResult}; +use super::types::{SearchOptions, SearchResponse, SearchResult}; use crate::error::Error; pub struct OgxStore { @@ -19,20 +20,24 @@ impl OgxStore { #[async_trait] impl super::VectorSearch for OgxStore { - async fn search(&self, store_id: &str, query: &str) -> Result, Error> { + async fn search(&self, store_id: &str, query: &str, options: &SearchOptions) -> Result, Error> { let url = 
format!("{}/v1/vector_stores/{store_id}/search", self.base_url); - debug!(%url, %query, "searching vector store via OGx"); - - let resp = self - .client - .post(&url) - .json(&serde_json::json!({ - "query": query, - "max_num_results": 10 - })) - .send() - .await - .map_err(Error::Store)?; + debug!(%url, "searching vector store via OGx"); + + let mut body = Map::new(); + body.insert("query".to_owned(), Value::String(query.to_owned())); + body.insert( + "max_num_results".to_owned(), + Value::from(options.max_num_results.unwrap_or(10)), + ); + if let Some(filters) = &options.filters { + body.insert("filters".to_owned(), filters.clone()); + } + if let Some(ranking_options) = &options.ranking_options { + body.insert("ranking_options".to_owned(), ranking_options.clone()); + } + + let resp = self.client.post(&url).json(&body).send().await.map_err(Error::Store)?; let status = resp.status(); if !status.is_success() { diff --git a/crates/agentic-core/src/vector_search/types.rs b/crates/agentic-core/src/vector_search/types.rs index 061709c..7b9c218 100644 --- a/crates/agentic-core/src/vector_search/types.rs +++ b/crates/agentic-core/src/vector_search/types.rs @@ -1,4 +1,5 @@ use serde::{Deserialize, Serialize}; +use serde_json::Value; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ResponseRequest { @@ -104,6 +105,16 @@ pub struct SearchResponse { pub data: Vec, } +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] +pub struct SearchOptions { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub filters: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_num_results: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub ranking_options: Option, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SearchResult { pub file_id: String, diff --git a/crates/agentic-server/tests/agentic_loop_test.rs b/crates/agentic-server/tests/agentic_loop_test.rs index 695ef72..6a9c15b 100644 --- 
a/crates/agentic-server/tests/agentic_loop_test.rs +++ b/crates/agentic-server/tests/agentic_loop_test.rs @@ -1,7 +1,9 @@ #[allow(dead_code)] mod common; -use common::{spawn_ogx, spawn_vllm, spawn_vllm_recording, spawn_vllm_with_tool_calls, start_gateway}; +use common::{ + spawn_ogx, spawn_ogx_recording, spawn_vllm, spawn_vllm_recording, spawn_vllm_with_tool_calls, start_gateway, +}; #[tokio::test] async fn test_passthrough_no_tools() { @@ -126,6 +128,217 @@ async fn test_file_search_backend_failure_returns_error() { ); } +#[tokio::test] +async fn test_file_search_rejects_missing_query_argument() { + let tool_call_response = serde_json::json!({ + "id": "resp_1", + "object": "response", + "status": "completed", + "output": [{ + "type": "function_call", + "id": "fc_1", + "call_id": "call_1", + "name": "file_search", + "arguments": "{}", + "status": "completed" + }] + }); + + let final_response = serde_json::json!({ + "id": "resp_2", + "object": "response", + "status": "completed", + "output": [{ + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "answer without a query"}] + }] + }); + + let (vllm_port, _h) = spawn_vllm_with_tool_calls(vec![tool_call_response, final_response]).await; + let (ogx_port, _h2) = spawn_ogx().await; + let (gw_addr, _) = start_gateway(vllm_port, Some(ogx_port), None).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": "search for something", + "tools": [{"type": "file_search", "vector_store_ids": ["vs_123"]}] + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 502); + let body: serde_json::Value = resp.json().await.unwrap(); + let msg = body["error"]["message"].as_str().unwrap_or(""); + assert!(msg.contains("query"), "unexpected error: {msg}"); +} + +#[tokio::test] +async fn test_file_search_rejects_empty_vector_store_ids_before_vllm() { + let response = 
serde_json::json!({ + "id": "resp_1", + "object": "response", + "status": "completed", + "output": [{ + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "should not be called"}] + }] + }); + + let (vllm_port, requests, _h) = spawn_vllm_recording(vec![response]).await; + let (ogx_port, _h2) = spawn_ogx().await; + let (gw_addr, _) = start_gateway(vllm_port, Some(ogx_port), None).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": "search for something", + "tools": [{"type": "file_search", "vector_store_ids": []}] + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 400); + let body: serde_json::Value = resp.json().await.unwrap(); + let msg = body["error"]["message"].as_str().unwrap_or(""); + assert!(msg.contains("vector_store_ids"), "unexpected error: {msg}"); + assert!( + requests.lock().await.is_empty(), + "gateway should reject before calling vLLM" + ); +} + +#[tokio::test] +async fn test_file_search_preserves_search_options() { + let tool_call_response = serde_json::json!({ + "id": "resp_1", + "object": "response", + "status": "completed", + "output": [{ + "type": "function_call", + "id": "fc_1", + "call_id": "call_1", + "name": "file_search", + "arguments": "{\"query\": \"test query\"}", + "status": "completed" + }] + }); + let final_response = serde_json::json!({ + "id": "resp_2", + "object": "response", + "status": "completed", + "output": [{ + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "Based on filtered search results..."}] + }] + }); + + let (vllm_port, _h) = spawn_vllm_with_tool_calls(vec![tool_call_response, final_response]).await; + let (ogx_port, ogx_requests, _h2) = spawn_ogx_recording().await; + let (gw_addr, _) = start_gateway(vllm_port, Some(ogx_port), None).await; + + let client = reqwest::Client::new(); + let resp = client + 
.post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": "search for something", + "tools": [{ + "type": "file_search", + "vector_store_ids": ["vs_123"], + "max_num_results": 3, + "filters": {"type": "eq", "key": "tenant_id", "value": "tenant-a"}, + "ranking_options": {"ranker": "default", "score_threshold": 0.25} + }] + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + let requests = ogx_requests.lock().await; + assert_eq!(requests.len(), 1); + assert_eq!(requests[0]["query"], "test query"); + assert_eq!(requests[0]["max_num_results"], 3); + assert_eq!(requests[0]["filters"]["key"], "tenant_id"); + assert_eq!(requests[0]["ranking_options"]["score_threshold"], 0.25); +} + +#[tokio::test] +async fn test_file_search_rejects_mixed_tool_calls() { + let mixed_response = serde_json::json!({ + "id": "resp_1", + "object": "response", + "status": "completed", + "output": [ + { + "type": "function_call", + "id": "fc_1", + "call_id": "call_1", + "name": "file_search", + "arguments": "{\"query\": \"test query\"}", + "status": "completed" + }, + { + "type": "function_call", + "id": "fc_2", + "call_id": "call_2", + "name": "get_weather", + "arguments": "{\"city\": \"SF\"}", + "status": "completed" + } + ] + }); + let final_response = serde_json::json!({ + "id": "resp_2", + "object": "response", + "status": "completed", + "output": [{ + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "answer after dropping get_weather"}] + }] + }); + + let (vllm_port, _h) = spawn_vllm_with_tool_calls(vec![mixed_response, final_response]).await; + let (ogx_port, _h2) = spawn_ogx().await; + let (gw_addr, _) = start_gateway(vllm_port, Some(ogx_port), None).await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{gw_addr}/v1/responses")) + .json(&serde_json::json!({ + "model": "model-a", + "input": "search and call another function", + "tools": [ + 
{"type": "file_search", "vector_store_ids": ["vs_123"]}, + { + "type": "function", + "name": "get_weather", + "description": "Get weather.", + "parameters": {"type": "object", "properties": {"city": {"type": "string"}}} + } + ] + })) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 502); + let body: serde_json::Value = resp.json().await.unwrap(); + let msg = body["error"]["message"].as_str().unwrap_or(""); + assert!(msg.contains("mixed tool calls"), "unexpected error: {msg}"); +} + #[tokio::test] async fn test_file_search_streaming_rejected() { let (vllm_port, _h) = spawn_vllm().await; diff --git a/crates/agentic-server/tests/common/mod.rs b/crates/agentic-server/tests/common/mod.rs index f7cc66f..163ced2 100644 --- a/crates/agentic-server/tests/common/mod.rs +++ b/crates/agentic-server/tests/common/mod.rs @@ -333,3 +333,51 @@ pub async fn spawn_ogx() -> (u16, tokio::task::JoinHandle<()>) { (port, handle) } + +pub async fn spawn_ogx_recording() -> (u16, Arc>>, tokio::task::JoinHandle<()>) { + let requests = Arc::new(Mutex::new(Vec::new())); + + let app = Router::new().route( + "/v1/vector_stores/{store_id}/search", + post({ + let requests_for_handler = Arc::clone(&requests); + move |req: Request| { + let requests_for_handler = Arc::clone(&requests_for_handler); + async move { + let body_bytes = axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024) + .await + .unwrap_or_default(); + let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap_or_default(); + requests_for_handler.lock().await.push(body); + + let response = serde_json::json!({ + "object": "vector_store.search_results.page", + "search_query": ["test query"], + "data": [{ + "file_id": "file_abc", + "filename": "doc.txt", + "score": 0.95, + "attributes": {}, + "content": [{"type": "text", "text": "relevant content from doc"}] + }], + "has_more": false + }); + ( + StatusCode::OK, + [("content-type", "application/json")], + serde_json::to_string(&response).unwrap(), + ) + 
.into_response() + } + } + }), + ); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let handle = tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + (port, requests, handle) +} From 8a9eff1e8b6e2a05ff2e72edbd3048cffcd2d318 Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Wed, 17 Jun 2026 08:34:45 -0400 Subject: [PATCH 6/6] test: replay recorded file search responses Signed-off-by: Francisco Javier Arceo --- crates/agentic-core/src/types/io.rs | 3 +- .../vllm-file-search-openai-gpt-oss-20b.json | 296 ++++++++++++++++++ .../cassettes/record_file_search_cassettes.py | 151 +++++++++ .../agentic-server/tests/integration_test.rs | 87 +++-- 4 files changed, 509 insertions(+), 28 deletions(-) create mode 100644 crates/agentic-server/tests/cassettes/file_search/vllm-file-search-openai-gpt-oss-20b.json create mode 100644 crates/agentic-server/tests/cassettes/record_file_search_cassettes.py diff --git a/crates/agentic-core/src/types/io.rs b/crates/agentic-core/src/types/io.rs index 949abd0..f611935 100644 --- a/crates/agentic-core/src/types/io.rs +++ b/crates/agentic-core/src/types/io.rs @@ -162,7 +162,8 @@ pub struct FunctionToolCall { pub call_id: String, pub name: String, pub arguments: String, - pub status: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub status: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/crates/agentic-server/tests/cassettes/file_search/vllm-file-search-openai-gpt-oss-20b.json b/crates/agentic-server/tests/cassettes/file_search/vllm-file-search-openai-gpt-oss-20b.json new file mode 100644 index 0000000..f4067bc --- /dev/null +++ b/crates/agentic-server/tests/cassettes/file_search/vllm-file-search-openai-gpt-oss-20b.json @@ -0,0 +1,296 @@ +{ + "metadata": { + "recorded_at": "2026-06-17T12:30:24+00:00", + "model": "openai/gpt-oss-20b", + "note": "Harmony models reject tool_choice=required, 
so this cassette uses tool_choice=auto with a direct prompt." + }, + "turns": [ + { + "request": { + "method": "POST", + "path": "/v1/responses", + "body": { + "model": "openai/gpt-oss-20b", + "input": "Use the file_search tool to find information about Rust memory safety ownership.", + "tools": [ + { + "type": "function", + "name": "file_search", + "description": "Search uploaded files for relevant passages.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query." + } + }, + "required": [ + "query" + ] + } + } + ], + "tool_choice": "auto", + "stream": false + } + }, + "response": { + "status_code": 200, + "body": { + "id": "resp_a4c1eeb7507be20a", + "created_at": 1781699419, + "incomplete_details": null, + "instructions": null, + "metadata": null, + "model": "openai/gpt-oss-20b", + "object": "response", + "output": [ + { + "id": "rs_b28e1cdf8ad9b24c", + "summary": [], + "type": "reasoning", + "content": [ + { + "text": "We need to use the file_search tool to search for \"Rust memory safety ownership\" or something like that. The instruction: \"Use the file_search tool to find information about Rust memory safety ownership.\" We need to call the function. We can pass query \"Rust memory safety ownership\". Then we need to answer by providing the relevant info. Probably the tool returns passages or summaries. Use tool.", + "type": "reasoning_text" + } + ], + "encrypted_content": null, + "status": null + }, + { + "arguments": "{\"query\":\"Rust memory safety ownership\"}", + "call_id": "call_8e217e84eeadc2fb", + "name": "file_search", + "type": "function_call", + "id": "fc_8e217e84eeadc2fb", + "status": null + } + ], + "parallel_tool_calls": true, + "temperature": 1.0, + "tool_choice": "auto", + "tools": [ + { + "name": "file_search", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query." 
+ } + }, + "required": [ + "query" + ] + }, + "strict": null, + "type": "function", + "description": "Search uploaded files for relevant passages." + } + ], + "top_p": 1.0, + "background": false, + "max_output_tokens": 16247, + "max_tool_calls": null, + "previous_response_id": null, + "prompt": null, + "reasoning": null, + "service_tier": "auto", + "status": "completed", + "text": null, + "top_logprobs": null, + "truncation": "disabled", + "usage": { + "input_tokens": 137, + "input_tokens_details": { + "cached_tokens": 128, + "input_tokens_per_turn": [ + 137 + ], + "cached_tokens_per_turn": [ + 128 + ] + }, + "output_tokens": 106, + "output_tokens_details": { + "reasoning_tokens": 89, + "tool_output_tokens": 0, + "output_tokens_per_turn": [ + 106 + ], + "tool_output_tokens_per_turn": [ + 0 + ] + }, + "total_tokens": 243 + }, + "user": null, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "kv_transfer_params": null, + "input_messages": null, + "output_messages": null + } + } + }, + { + "request": { + "method": "POST", + "path": "/v1/responses", + "body": { + "model": "openai/gpt-oss-20b", + "input": [ + { + "type": "message", + "role": "user", + "content": "Use the file_search tool to find information about Rust memory safety ownership." 
+ }, + { + "arguments": "{\"query\":\"Rust memory safety ownership\"}", + "call_id": "call_8e217e84eeadc2fb", + "name": "file_search", + "type": "function_call", + "id": "fc_8e217e84eeadc2fb", + "status": null + }, + { + "type": "function_call_output", + "call_id": "call_8e217e84eeadc2fb", + "output": "{\"results\":[{\"file_id\":\"file_abc\",\"filename\":\"rust-memory-safety.txt\",\"score\":0.95,\"attributes\":{},\"content\":[{\"type\":\"text\",\"text\":\"Rust enforces memory safety without a garbage collector through ownership, borrowing, and lifetimes.\"}]}]}" + } + ], + "tools": [ + { + "type": "function", + "name": "file_search", + "description": "Search uploaded files for relevant passages.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query." + } + }, + "required": [ + "query" + ] + } + } + ], + "tool_choice": "auto", + "stream": false + } + }, + "response": { + "status_code": 200, + "body": { + "id": "resp_92585e571ba11e5e", + "created_at": 1781699421, + "incomplete_details": null, + "instructions": null, + "metadata": null, + "model": "openai/gpt-oss-20b", + "object": "response", + "output": [ + { + "id": "rs_b6dcc6546ab2d7c2", + "summary": [], + "type": "reasoning", + "content": [ + { + "text": "We have the file content. Provide answer.", + "type": "reasoning_text" + } + ], + "encrypted_content": null, + "status": null + }, + { + "id": "msg_a7c9b9489fe74d58", + "content": [ + { + "annotations": [], + "text": "**Rust Memory Safety & Ownership**\n\nRust guarantees memory safety at compile time without relying on a garbage collector. 
The core concepts that enable this are:\n\n- **Ownership** – Each value has a single owner.\n- **Borrowing** – References (`&T` for immutable and `&mut T` for mutable) allow temporary access without taking ownership.\n- **Lifetimes** – Compile‑time annotations that ensure references do not outlive the data they point to.\n\nThese rules together prevent common bugs such as use‑after‑free, double frees, and data races in concurrent programs.", + "type": "output_text", + "logprobs": null + } + ], + "role": "assistant", + "status": "completed", + "type": "message" + } + ], + "parallel_tool_calls": true, + "temperature": 1.0, + "tool_choice": "auto", + "tools": [ + { + "name": "file_search", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query." + } + }, + "required": [ + "query" + ] + }, + "strict": null, + "type": "function", + "description": "Search uploaded files for relevant passages." + } + ], + "top_p": 1.0, + "background": false, + "max_output_tokens": 16155, + "max_tool_calls": null, + "previous_response_id": null, + "prompt": null, + "reasoning": null, + "service_tier": "auto", + "status": "completed", + "text": null, + "top_logprobs": null, + "truncation": "disabled", + "usage": { + "input_tokens": 229, + "input_tokens_details": { + "cached_tokens": 224, + "input_tokens_per_turn": [ + 229 + ], + "cached_tokens_per_turn": [ + 224 + ] + }, + "output_tokens": 137, + "output_tokens_details": { + "reasoning_tokens": 10, + "tool_output_tokens": 0, + "output_tokens_per_turn": [ + 137 + ], + "tool_output_tokens_per_turn": [ + 0 + ] + }, + "total_tokens": 366 + }, + "user": null, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "kv_transfer_params": null, + "input_messages": null, + "output_messages": null + } + } + } + ] +} diff --git a/crates/agentic-server/tests/cassettes/record_file_search_cassettes.py b/crates/agentic-server/tests/cassettes/record_file_search_cassettes.py new file mode 
100644 index 0000000..e7147cd --- /dev/null +++ b/crates/agentic-server/tests/cassettes/record_file_search_cassettes.py @@ -0,0 +1,151 @@ +"""Record vLLM Responses API cassettes for the file_search integration test.""" + +from __future__ import annotations + +import datetime as dt +import json +import os +import sys +import urllib.error +import urllib.request +from pathlib import Path +from typing import Any + + +DEFAULT_MODEL = "openai/gpt-oss-20b" +DEFAULT_PROMPT = "Use the file_search tool to find information about Rust memory safety ownership." +DEFAULT_OUTPUT = Path(__file__).with_name("file_search") / "vllm-file-search-openai-gpt-oss-20b.json" + +SEARCH_RESULT = { + "results": [ + { + "file_id": "file_abc", + "filename": "rust-memory-safety.txt", + "score": 0.95, + "attributes": {}, + "content": [ + { + "type": "text", + "text": "Rust enforces memory safety without a garbage collector through ownership, borrowing, and lifetimes.", + } + ], + } + ] +} + + +def post_json(base_url: str, payload: dict[str, Any]) -> dict[str, Any]: + url = f"{base_url.rstrip('/')}/v1/responses" + data = json.dumps(payload).encode("utf-8") + request = urllib.request.Request(url, data=data, headers={"content-type": "application/json"}, method="POST") + try: + with urllib.request.urlopen(request, timeout=300) as response: + return json.loads(response.read().decode("utf-8")) + except urllib.error.HTTPError as err: + body = err.read().decode("utf-8", errors="replace") + raise RuntimeError(f"vLLM request failed with HTTP {err.code}: {body}") from err + + +def file_search_tool() -> dict[str, Any]: + return { + "type": "function", + "name": "file_search", + "description": "Search uploaded files for relevant passages.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query.", + } + }, + "required": ["query"], + }, + } + + +def find_file_search_call(response: dict[str, Any]) -> dict[str, Any]: + for item in 
response.get("output", []): + if item.get("type") == "function_call" and item.get("name") == "file_search": + return item + raise RuntimeError(f"recorded response did not include a file_search function call: {response}") + + +def main() -> int: + vllm_url = os.environ.get("VLLM_URL", "http://localhost:8000") + model = os.environ.get("MODEL", DEFAULT_MODEL) + output = Path(os.environ.get("OUTPUT", DEFAULT_OUTPUT)).resolve() + prompt = os.environ.get("PROMPT", DEFAULT_PROMPT) + tools = [file_search_tool()] + + first_request = { + "model": model, + "input": prompt, + "tools": tools, + "tool_choice": "auto", + "stream": False, + } + first_response = post_json(vllm_url, first_request) + call = find_file_search_call(first_response) + + second_request = { + "model": model, + "input": [ + { + "type": "message", + "role": "user", + "content": prompt, + }, + call, + { + "type": "function_call_output", + "call_id": call["call_id"], + "output": json.dumps(SEARCH_RESULT, separators=(",", ":")), + }, + ], + "tools": tools, + "tool_choice": "auto", + "stream": False, + } + second_response = post_json(vllm_url, second_request) + + cassette = { + "metadata": { + "recorded_at": dt.datetime.now(dt.UTC).isoformat(timespec="seconds"), + "model": model, + "note": "Harmony models reject tool_choice=required, so this cassette uses tool_choice=auto with a direct prompt.", + }, + "turns": [ + { + "request": { + "method": "POST", + "path": "/v1/responses", + "body": first_request, + }, + "response": { + "status_code": 200, + "body": first_response, + }, + }, + { + "request": { + "method": "POST", + "path": "/v1/responses", + "body": second_request, + }, + "response": { + "status_code": 200, + "body": second_response, + }, + }, + ], + } + + output.parent.mkdir(parents=True, exist_ok=True) + output.write_text(json.dumps(cassette, indent=2, ensure_ascii=False) + "\n", encoding="utf-8") + print(f"recorded file_search cassette -> {output}") + return 0 + + +if __name__ == "__main__": + 
sys.exit(main()) diff --git a/crates/agentic-server/tests/integration_test.rs b/crates/agentic-server/tests/integration_test.rs index 12fef02..e60df71 100644 --- a/crates/agentic-server/tests/integration_test.rs +++ b/crates/agentic-server/tests/integration_test.rs @@ -2,11 +2,60 @@ mod common; use common::{spawn_vllm_recording, start_gateway_with_ogx_base}; +use serde::Deserialize; +use serde_json::Value; + +const FILE_SEARCH_VLLM_CASSETTE: &str = concat!( + env!("CARGO_MANIFEST_DIR"), + "/tests/cassettes/file_search/vllm-file-search-openai-gpt-oss-20b.json" +); + +#[derive(Debug, Deserialize)] +struct VllmCassette { + turns: Vec, +} + +#[derive(Debug, Deserialize)] +struct VllmTurn { + response: VllmResponse, +} + +#[derive(Debug, Deserialize)] +struct VllmResponse { + status_code: u16, + body: Value, +} fn ogx_base_url() -> Option { std::env::var("OGX_BASE_URL").ok() } +fn load_file_search_vllm_responses() -> Vec { + let text = std::fs::read_to_string(FILE_SEARCH_VLLM_CASSETTE) + .unwrap_or_else(|err| panic!("failed to read cassette {FILE_SEARCH_VLLM_CASSETTE}: {err}")); + let cassette: VllmCassette = serde_json::from_str(&text) + .unwrap_or_else(|err| panic!("failed to parse cassette {FILE_SEARCH_VLLM_CASSETTE}: {err}")); + cassette + .turns + .into_iter() + .map(|turn| { + assert_eq!(turn.response.status_code, 200, "cassette response should be successful"); + turn.response.body + }) + .collect() +} + +fn output_text(body: &Value) -> String { + body["output"] + .as_array() + .into_iter() + .flatten() + .filter_map(|item| item["content"].as_array()) + .flatten() + .filter_map(|content| content["text"].as_str()) + .collect() +} + async fn find_embedding_model(client: &reqwest::Client, ogx_url: &str) -> (String, u64) { let models_resp = client.get(format!("{ogx_url}/v1/models")).send().await.unwrap(); let models: serde_json::Value = models_resp.json().await.unwrap(); @@ -84,30 +133,7 @@ async fn upload_and_attach(client: &reqwest::Client, ogx_url: &str, vs_id: 
&str) } async fn assert_gateway_file_search_uses_ogx(client: &reqwest::Client, ogx_url: &str, vs_id: &str) { - let tool_call_response = serde_json::json!({ - "id": "resp_1", - "object": "response", - "status": "completed", - "output": [{ - "type": "function_call", - "id": "fc_1", - "call_id": "call_1", - "name": "file_search", - "arguments": "{\"query\": \"memory safety ownership\"}", - "status": "completed" - }] - }); - let final_response = serde_json::json!({ - "id": "resp_2", - "object": "response", - "status": "completed", - "output": [{ - "type": "message", - "role": "assistant", - "content": [{"type": "output_text", "text": "Rust uses ownership and borrowing."}] - }] - }); - let (vllm_port, requests, _vllm_handle) = spawn_vllm_recording(vec![tool_call_response, final_response]).await; + let (vllm_port, requests, _vllm_handle) = spawn_vllm_recording(load_file_search_vllm_responses()).await; let (gateway_addr, _) = start_gateway_with_ogx_base(vllm_port, ogx_url, None).await; let gateway_resp = client @@ -120,10 +146,17 @@ async fn assert_gateway_file_search_uses_ogx(client: &reqwest::Client, ogx_url: .send() .await .unwrap(); + let gateway_status = gateway_resp.status(); + let gateway_body: Value = gateway_resp.json().await.unwrap(); + assert!( + gateway_status.is_success(), + "gateway file_search failed: {gateway_body}" + ); + + let answer = output_text(&gateway_body); assert!( - gateway_resp.status().is_success(), - "gateway file_search failed: {}", - gateway_resp.text().await.unwrap_or_default() + answer.contains("ownership"), + "gateway should return the recorded final vLLM answer, got: {answer}" ); let requests = requests.lock().await;