From d7be45ef4ed61441f37abe4cfe1e3b5828f1e01a Mon Sep 17 00:00:00 2001 From: shane Date: Thu, 2 Apr 2026 07:36:48 -0500 Subject: [PATCH 1/3] feat(rust): add initial tanstack-ai crate and rs-chat example --- .gitignore | 1 + examples/rs-chat/Cargo.lock | 1957 +++++++++++++++++ examples/rs-chat/Cargo.toml | 9 + examples/rs-chat/src/main.rs | 223 ++ packages/rust/tanstack-ai/Cargo.lock | 1948 ++++++++++++++++ packages/rust/tanstack-ai/Cargo.toml | 26 + packages/rust/tanstack-ai/src/adapter.rs | 89 + .../tanstack-ai/src/adapters/anthropic.rs | 779 +++++++ .../rust/tanstack-ai/src/adapters/gemini.rs | 486 ++++ packages/rust/tanstack-ai/src/adapters/mod.rs | 7 + .../rust/tanstack-ai/src/adapters/openai.rs | 914 ++++++++ packages/rust/tanstack-ai/src/chat.rs | 546 +++++ .../src/client/connection_adapters.rs | 586 +++++ packages/rust/tanstack-ai/src/client/mod.rs | 3 + packages/rust/tanstack-ai/src/error.rs | 60 + packages/rust/tanstack-ai/src/lib.rs | 237 ++ packages/rust/tanstack-ai/src/messages.rs | 324 +++ packages/rust/tanstack-ai/src/middleware.rs | 322 +++ .../tanstack-ai/src/stream/json_parser.rs | 102 + packages/rust/tanstack-ai/src/stream/mod.rs | 9 + .../rust/tanstack-ai/src/stream/processor.rs | 265 +++ .../rust/tanstack-ai/src/stream/strategies.rs | 111 + packages/rust/tanstack-ai/src/stream/types.rs | 98 + .../rust/tanstack-ai/src/stream_response.rs | 135 ++ .../rust/tanstack-ai/src/tools/definition.rs | 105 + packages/rust/tanstack-ai/src/tools/mod.rs | 7 + .../rust/tanstack-ai/src/tools/registry.rs | 130 ++ .../rust/tanstack-ai/src/tools/tool_calls.rs | 357 +++ packages/rust/tanstack-ai/src/types.rs | 885 ++++++++ packages/rust/tanstack-ai/tests/chat_tests.rs | 861 ++++++++ .../tests/stream_processor_tests.rs | 208 ++ packages/rust/tanstack-ai/tests/test_utils.rs | 290 +++ .../tests/tool_call_manager_tests.rs | 170 ++ packages/rust/tanstack-ai/tests/unit_tests.rs | 730 ++++++ 34 files changed, 12980 insertions(+) create mode 100644 
examples/rs-chat/Cargo.lock create mode 100644 examples/rs-chat/Cargo.toml create mode 100644 examples/rs-chat/src/main.rs create mode 100644 packages/rust/tanstack-ai/Cargo.lock create mode 100644 packages/rust/tanstack-ai/Cargo.toml create mode 100644 packages/rust/tanstack-ai/src/adapter.rs create mode 100644 packages/rust/tanstack-ai/src/adapters/anthropic.rs create mode 100644 packages/rust/tanstack-ai/src/adapters/gemini.rs create mode 100644 packages/rust/tanstack-ai/src/adapters/mod.rs create mode 100644 packages/rust/tanstack-ai/src/adapters/openai.rs create mode 100644 packages/rust/tanstack-ai/src/chat.rs create mode 100644 packages/rust/tanstack-ai/src/client/connection_adapters.rs create mode 100644 packages/rust/tanstack-ai/src/client/mod.rs create mode 100644 packages/rust/tanstack-ai/src/error.rs create mode 100644 packages/rust/tanstack-ai/src/lib.rs create mode 100644 packages/rust/tanstack-ai/src/messages.rs create mode 100644 packages/rust/tanstack-ai/src/middleware.rs create mode 100644 packages/rust/tanstack-ai/src/stream/json_parser.rs create mode 100644 packages/rust/tanstack-ai/src/stream/mod.rs create mode 100644 packages/rust/tanstack-ai/src/stream/processor.rs create mode 100644 packages/rust/tanstack-ai/src/stream/strategies.rs create mode 100644 packages/rust/tanstack-ai/src/stream/types.rs create mode 100644 packages/rust/tanstack-ai/src/stream_response.rs create mode 100644 packages/rust/tanstack-ai/src/tools/definition.rs create mode 100644 packages/rust/tanstack-ai/src/tools/mod.rs create mode 100644 packages/rust/tanstack-ai/src/tools/registry.rs create mode 100644 packages/rust/tanstack-ai/src/tools/tool_calls.rs create mode 100644 packages/rust/tanstack-ai/src/types.rs create mode 100644 packages/rust/tanstack-ai/tests/chat_tests.rs create mode 100644 packages/rust/tanstack-ai/tests/stream_processor_tests.rs create mode 100644 packages/rust/tanstack-ai/tests/test_utils.rs create mode 100644 
packages/rust/tanstack-ai/tests/tool_call_manager_tests.rs create mode 100644 packages/rust/tanstack-ai/tests/unit_tests.rs diff --git a/.gitignore b/.gitignore index 2ba04aa8..4e0fd8c9 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ node_modules package-lock.json yarn.lock +target/ # builds build diff --git a/examples/rs-chat/Cargo.lock b/examples/rs-chat/Cargo.lock new file mode 100644 index 00000000..1c4bec63 --- /dev/null +++ b/examples/rs-chat/Cargo.lock @@ -0,0 +1,1957 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "base64" +version = "0.22.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bitflags" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" + +[[package]] +name = "bumpalo" +version = "3.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + +[[package]] +name = "cc" +version = "1.2.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1e928d4b69e3077709075a938a05ffbedfa53a84c8f766efbf8220bb1ff60e1" +dependencies = [ + "find-msvc-tools", + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "chrono" +version = "0.4.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" +dependencies = [ + "iana-time-zone", + "js-sys", + "num-traits", + "serde", + "wasm-bindgen", + "windows-link", +] + +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + 
"libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "encoding_rs" +version = "0.8.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "foreign-types" 
+version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "futures-channel" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" +dependencies = [ + "futures-core", +] + +[[package]] +name = "futures-core" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" + +[[package]] +name = "futures-io" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" + +[[package]] +name = "futures-macro" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" + +[[package]] +name = "futures-task" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" + 
+[[package]] +name = "futures-util" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" +dependencies = [ + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "slab", +] + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "getrandom" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", + "wasip3", +] + +[[package]] +name = "h2" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "http" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +dependencies = [ + "bytes", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "hyper" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6299f016b246a94207e63da54dbe807655bf9e00044f73ded42c3ac5305fbcca" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "h2", + "http", + "http-body", + "httparse", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http", + "hyper", + "hyper-util", + "rustls", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", +] + +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + +[[package]] +name = "hyper-util" +version = 
"0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" +dependencies = [ + "base64", + "bytes", + "futures-channel", + "futures-util", + "http", + "http-body", + "hyper", + "ipnet", + "libc", + "percent-encoding", + "pin-project-lite", + "socket2", + "system-configuration", + "tokio", + "tower-service", + "tracing", + "windows-registry", +] + +[[package]] +name = "iana-time-zone" +version = "0.1.65" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "log", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "icu_collections" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2984d1cd16c883d7935b9e07e44071dca8d917fd52ecc02c04d5fa0b5a3f191c" +dependencies = [ + "displaydoc", + "potential_utf", + "utf8_iter", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locale_core" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92219b62b3e2b4d88ac5119f8904c10f8f61bf7e95b640d25ba3075e6cac2c29" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_normalizer" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c56e5ee99d6e3d33bd91c5d85458b6005a22140021cc324cea84dd0e72cff3b4" +dependencies = [ + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "zerovec", +] + 
+[[package]] +name = "icu_normalizer_data" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da3be0ae77ea334f4da67c12f149704f19f81d1adf7c51cf482943e84a2bad38" + +[[package]] +name = "icu_properties" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bee3b67d0ea5c2cca5003417989af8996f8604e34fb9ddf96208a033901e70de" +dependencies = [ + "icu_collections", + "icu_locale_core", + "icu_properties_data", + "icu_provider", + "zerotrie", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e2bbb201e0c04f7b4b3e14382af113e17ba4f63e2c9d2ee626b720cbce54a14" + +[[package]] +name = "icu_provider" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "139c4cf31c8b5f33d7e199446eff9c1e02decfc2f0eec2c8d71f65befa45b421" +dependencies = [ + "displaydoc", + "icu_locale_core", + "writeable", + "yoke", + "zerofrom", + "zerotrie", + "zerovec", +] + +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + +[[package]] +name = "idna" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + +[[package]] +name = "indexmap" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +dependencies = [ + "equivalent", + "hashbrown 0.16.1", + "serde", + "serde_core", +] + +[[package]] +name = "ipnet" +version = "2.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" + +[[package]] +name = "iri-string" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25e659a4bb38e810ebc252e53b5814ff908a8c58c2a9ce2fae1bbec24cbf4e20" +dependencies = [ + "memchr", + "serde", +] + +[[package]] +name = "itoa" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" + +[[package]] +name = "js-sys" +version = "0.3.94" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e04e2ef80ce82e13552136fabeef8a5ed1f985a96805761cbb9a2c34e7664d9" +dependencies = [ + "cfg-if", + "futures-util", + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + +[[package]] +name = "libc" +version = "0.2.184" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48f5d2a454e16a5ea0f4ced81bd44e4cfc7bd3a507b61887c99fd3538b28e4af" + +[[package]] +name = "linux-raw-sys" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" + +[[package]] +name = "litemap" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92daf443525c4cce67b150400bc2316076100ce0b3686209eb8cf3c31612e6f0" + +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + +[[package]] +name = "mio" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.61.2", +] + +[[package]] +name = "native-tls" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "465500e14ea162429d264d44189adc38b199b62b1c21eea9f69e4b73cb03bbf2" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" + +[[package]] +name = "openssl" +version = "0.10.76" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "951c002c75e16ea2c65b8c7e4d3d51d5530d8dfa7d060b4776828c88cfb18ecf" +dependencies = [ + "bitflags", + "cfg-if", + 
"foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "openssl-probe" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" + +[[package]] +name = "openssl-sys" +version = "0.9.112" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57d55af3b3e226502be1526dfdba67ab0e9c96fc293004e79576b2b9edb0dbdb" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + 
+[[package]] +name = "potential_utf" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0103b1cef7ec0cf76490e969665504990193874ea05c85ff9bab8b911d0a0564" +dependencies = [ + "zerovec", +] + +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" + +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + +[[package]] +name = "regex" +version = "1.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" 
+version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" + +[[package]] +name = "reqwest" +version = "0.12.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" +dependencies = [ + "base64", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-rustls", + "hyper-tls", + "hyper-util", + "js-sys", + "log", + "mime", + "native-tls", + "percent-encoding", + "pin-project-lite", + "rustls-pki-types", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tokio-native-tls", + "tokio-util", + "tower", + "tower-http", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-streams", + "web-sys", +] + +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.17", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + +[[package]] +name = "rs-chat" +version = "0.1.0" +dependencies = [ + "serde_json", + "tanstack-ai", + "tokio", +] + +[[package]] +name = "rustix" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustls" +version = "0.23.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" +dependencies = [ + "once_cell", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" 
+version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" +dependencies = [ + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df33b2b81ac578cabaf06b89b0631153a3f416b0a886e8a7a1707fb51abbd1ef" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "ryu" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" + +[[package]] +name = "schannel" +version = "0.1.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91c1b7e4904c873ef0710c1f407dde2e6287de2bebc1bbbf7d430bb7cbffd939" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "security-framework" +version = "3.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" +dependencies = [ + "bitflags", + "core-foundation 0.10.1", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "semver" +version = "1.0.27" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "signal-hook-registry" +version = "1.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" +dependencies = [ + "errno", + "libc", +] + +[[package]] +name = "slab" +version = "0.4.12" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "socket2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" +dependencies = [ + "futures-core", +] + +[[package]] +name = "synstructure" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "system-configuration" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a13f3d0daba03132c0aa9767f98351b3488edc2c100cda2d2ec2b04f3d8d3c8b" +dependencies = [ + "bitflags", + 
"core-foundation 0.9.4", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "tanstack-ai" +version = "0.1.0" +dependencies = [ + "async-trait", + "bytes", + "chrono", + "futures-core", + "futures-util", + "regex", + "reqwest", + "serde", + "serde_json", + "thiserror", + "tokio", + "tokio-stream", + "tokio-util", + "uuid", +] + +[[package]] +name = "tempfile" +version = "3.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" +dependencies = [ + "fastrand", + "getrandom 0.4.2", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tinystr" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8323304221c2a851516f22236c5722a72eaa19749016521d6dff0824447d96d" +dependencies = [ + "displaydoc", + "zerovec", +] + +[[package]] +name = "tokio" +version = "1.50.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" +dependencies = [ + "bytes", + "libc", + "mio", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", 
+ "windows-sys 0.61.2", +] + +[[package]] +name = "tokio-macros" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + +[[package]] +name = "tokio-stream" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32da49809aab5c3bc678af03902d4ccddea2a87d028d86392a4b1560c6906c70" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tokio-util" +version = "0.7.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tower" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tokio", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-http" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" +dependencies = [ + "bitflags", + "bytes", + "futures-util", + "http", + "http-body", + 
"iri-string", + "pin-project-lite", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "pin-project-lite", + "tracing-core", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", +] + +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "url" +version = "2.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff67a8a4397373c3ef660812acab3268222035010ab8680ec4215f38ba3d0eed" +dependencies = 
[ + "form_urlencoded", + "idna", + "percent-encoding", + "serde", +] + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + +[[package]] +name = "uuid" +version = "1.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ac8b6f42ead25368cf5b098aeb3dc8a1a2c05a3eee8a9a1a68c640edbfc79d9" +dependencies = [ + "getrandom 0.4.2", + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.2+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0551fc1bb415591e3372d0bc4780db7e587d84e2a7e79da121051c5c4b89d0b0" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] 
+ +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.67" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03623de6905b7206edd0a75f69f747f134b7f0a2323392d664448bf2d3c5d87e" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fbdf9a35adf44786aecd5ff89b4563a90325f9da0923236f6104e603c7e86be" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dca9693ef2bab6d4e6707234500350d8dad079eb508dca05530c85dc3a529ff2" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39129a682a6d2d841b6c429d0c51e5cb0ed1a03829d8b3d1e69a011e62cb3d3b" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + +[[package]] +name = "wasmparser" +version = 
"0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + +[[package]] +name = "web-sys" +version = "0.3.94" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd70027e39b12f0849461e08ffc50b9cd7688d942c1c8e3c7b22273236b4dd0a" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "windows-core" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-interface" +version = "0.59.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-registry" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720" +dependencies = [ + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + +[[package]] +name = "writeable" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" + +[[package]] +name = "yoke" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "abe8c5fda708d9ca3df187cae8bfb9ceda00dd96231bed36e445a1a48e66f9ca" +dependencies = [ + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de844c262c8848816172cef550288e7dc6c7b7814b4ee56b3e1553f275f1858e" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zerofrom" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69faa1f2a1ea75661980b013019ed6687ed0e83d069bc1114e2cc74c6c04c4df" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.7" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11532158c46691caf0f2593ea8358fed6bbf68a0315e80aae9bd41fbade684a1" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + +[[package]] +name = "zerotrie" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f9152d31db0792fa83f70fb2f83148effb5c1f5b8c7686c3459e361d9bc20bf" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + +[[package]] +name = "zerovec" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90f911cbc359ab6af17377d242225f4d75119aec87ea711a880987b18cd7b239" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "625dc425cab0dca6dc3c3319506e6593dcb08a9f387ea3b284dbd52a92c40555" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/examples/rs-chat/Cargo.toml b/examples/rs-chat/Cargo.toml new file mode 100644 index 00000000..c86c3788 --- /dev/null +++ b/examples/rs-chat/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "rs-chat" +version = "0.1.0" +edition = "2021" + +[dependencies] +tanstack-ai = { path = "../../packages/rust/tanstack-ai" } +tokio = { version = "1", features = ["full"] } +serde_json = "1" diff --git a/examples/rs-chat/src/main.rs b/examples/rs-chat/src/main.rs new file mode 100644 index 00000000..9ebd57d7 --- /dev/null +++ b/examples/rs-chat/src/main.rs @@ -0,0 +1,223 @@ +use std::sync::Arc; 
+use tanstack_ai::*; + +/// Quick example: chat with OpenAI using tanstack-ai in Rust. +/// +/// Usage: +/// OPENAI_API_KEY=sk-... cargo run +/// +/// Or: +/// cargo run -- sk-... +#[tokio::main] +async fn main() { + // Get API key from env or first arg + let api_key = std::env::var("OPENAI_API_KEY") + .or_else(|_| { + std::env::args() + .nth(1) + .ok_or_else(|| "Provide OPENAI_API_KEY as env var or first argument".to_string()) + }) + .expect("No API key provided"); + + // 1. Create an adapter + let adapter: Arc = Arc::new(openai_text("gpt-4o", &api_key)); + + println!("=== Simple Text Chat ===\n"); + simple_text(adapter.clone()).await; + + println!("\n=== Chat with Tool ===\n"); + chat_with_tool(adapter.clone()).await; + + println!("\n=== Non-Streaming Chat ===\n"); + non_streaming(adapter.clone()).await; + + println!("\n=== Multi-Turn Conversation ===\n"); + multi_turn(adapter.clone()).await; + + println!("\nDone!"); +} + +/// Helper to build ChatOptions with defaults +fn chat_opts(adapter: Arc, messages: Vec) -> ChatOptions { + ChatOptions { + adapter, + messages, + system_prompts: vec![], + tools: vec![], + temperature: None, + top_p: None, + max_tokens: None, + metadata: None, + model_options: None, + agent_loop_strategy: None, + conversation_id: None, + middleware: vec![], + stream: true, + output_schema: None, + } +} + +/// Simple streaming text chat +async fn simple_text(adapter: Arc) { + let mut opts = chat_opts( + adapter, + vec![ModelMessage { + role: MessageRole::User, + content: MessageContent::Text("Say hello in exactly 5 words.".into()), + name: None, + tool_calls: None, + tool_call_id: None, + }], + ); + opts.stream = true; + + let result = chat(opts).await.unwrap(); + + if let ChatResult::Chunks(chunks) = result { + let text = extract_text(&chunks); + println!(" Response: {}", text); + println!(" Chunks received: {}", chunks.len()); + } +} + +/// Chat with a tool that the model can call +async fn chat_with_tool(adapter: Arc) { + let weather_tool = 
Tool::new("get_weather", "Get current weather for a city") + .with_input_schema(json_schema(serde_json::json!({ + "type": "object", + "properties": { + "city": { "type": "string", "description": "City name" } + }, + "required": ["city"] + }))) + .with_execute(|args: serde_json::Value, _ctx| async move { + let city = args["city"].as_str().unwrap_or("unknown"); + println!(" [Tool called: get_weather({})]", city); + Ok(serde_json::json!({ + "city": city, + "temperature": 72, + "conditions": "Sunny", + "unit": "fahrenheit" + })) + }); + + let mut opts = chat_opts( + adapter, + vec![ModelMessage { + role: MessageRole::User, + content: MessageContent::Text("What's the weather in San Francisco?".into()), + name: None, + tool_calls: None, + tool_call_id: None, + }], + ); + opts.tools = vec![weather_tool]; + + let result = chat(opts).await.unwrap(); + + if let ChatResult::Chunks(chunks) = result { + let text = extract_text(&chunks); + let tool_calls = count_chunk_type(&chunks, "TOOL_CALL_END"); + println!(" Response: {}", text); + println!(" Tool calls executed: {}", tool_calls); + println!(" Total chunks: {}", chunks.len()); + } +} + +/// Non-streaming chat (collects all text at once) +async fn non_streaming(adapter: Arc) { + let mut opts = chat_opts( + adapter, + vec![ModelMessage { + role: MessageRole::User, + content: MessageContent::Text("What is 2+2? 
Reply with just the number.".into()), + name: None, + tool_calls: None, + tool_call_id: None, + }], + ); + opts.stream = false; + + let result = chat(opts).await.unwrap(); + + if let ChatResult::Text(text) = result { + println!(" Response: {}", text); + } +} + +/// Multi-turn conversation +async fn multi_turn(adapter: Arc) { + let mut messages = vec![ModelMessage { + role: MessageRole::User, + content: MessageContent::Text("My favorite color is blue.".into()), + name: None, + tool_calls: None, + tool_call_id: None, + }]; + + // First turn - get response + let mut opts = chat_opts(adapter.clone(), messages.clone()); + opts.stream = false; + let result = chat(opts).await.unwrap(); + + let first_response = match &result { + ChatResult::Text(t) => t.clone(), + _ => String::new(), + }; + println!(" Turn 1 response: {}", first_response); + + // Add assistant response to history + messages.push(ModelMessage { + role: MessageRole::Assistant, + content: MessageContent::Text(first_response), + name: None, + tool_calls: None, + tool_call_id: None, + }); + + // Second turn - follow up + messages.push(ModelMessage { + role: MessageRole::User, + content: MessageContent::Text("What color did I just tell you?".into()), + name: None, + tool_calls: None, + tool_call_id: None, + }); + + let mut opts = chat_opts(adapter, messages); + opts.stream = false; + let result = chat(opts).await.unwrap(); + + if let ChatResult::Text(text) = result { + println!(" Turn 2 response: {}", text); + } +} + +// Helpers + +fn extract_text(chunks: &[StreamChunk]) -> String { + let mut content = String::new(); + for chunk in chunks { + if let StreamChunk::TextMessageContent { + delta, content: full, .. 
+ } = chunk + { + if let Some(f) = full { + content = f.clone(); + } else { + content.push_str(delta); + } + } + } + content +} + +fn count_chunk_type(chunks: &[StreamChunk], chunk_type: &str) -> usize { + chunks + .iter() + .filter(|c| match chunk_type { + "TOOL_CALL_END" => matches!(c, StreamChunk::ToolCallEnd { .. }), + _ => false, + }) + .count() +} diff --git a/packages/rust/tanstack-ai/Cargo.lock b/packages/rust/tanstack-ai/Cargo.lock new file mode 100644 index 00000000..3d34ec8b --- /dev/null +++ b/packages/rust/tanstack-ai/Cargo.lock @@ -0,0 +1,1948 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bitflags" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" + +[[package]] +name = "bumpalo" +version = "3.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + +[[package]] +name = "cc" +version = "1.2.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1e928d4b69e3077709075a938a05ffbedfa53a84c8f766efbf8220bb1ff60e1" +dependencies = [ + "find-msvc-tools", + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "chrono" +version = "0.4.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" +dependencies = [ + "iana-time-zone", + "js-sys", + "num-traits", + "serde", + "wasm-bindgen", + "windows-link", +] + +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "encoding_rs" +version = "0.8.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "foldhash" +version = "0.1.5" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "futures-channel" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" +dependencies = [ + "futures-core", +] + +[[package]] +name = "futures-core" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" + +[[package]] +name = "futures-io" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" + +[[package]] +name = "futures-macro" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" + +[[package]] +name = "futures-task" 
+version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" + +[[package]] +name = "futures-util" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" +dependencies = [ + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "slab", +] + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "getrandom" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", + "wasip3", +] + +[[package]] +name = "h2" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "http" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +dependencies = [ + "bytes", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "hyper" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6299f016b246a94207e63da54dbe807655bf9e00044f73ded42c3ac5305fbcca" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "h2", + "http", + "http-body", + "httparse", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http", + "hyper", + "hyper-util", + "rustls", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", +] + +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = 
[ + "bytes", + "http-body-util", + "hyper", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + +[[package]] +name = "hyper-util" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" +dependencies = [ + "base64", + "bytes", + "futures-channel", + "futures-util", + "http", + "http-body", + "hyper", + "ipnet", + "libc", + "percent-encoding", + "pin-project-lite", + "socket2", + "system-configuration", + "tokio", + "tower-service", + "tracing", + "windows-registry", +] + +[[package]] +name = "iana-time-zone" +version = "0.1.65" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "log", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "icu_collections" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2984d1cd16c883d7935b9e07e44071dca8d917fd52ecc02c04d5fa0b5a3f191c" +dependencies = [ + "displaydoc", + "potential_utf", + "utf8_iter", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locale_core" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92219b62b3e2b4d88ac5119f8904c10f8f61bf7e95b640d25ba3075e6cac2c29" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_normalizer" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"c56e5ee99d6e3d33bd91c5d85458b6005a22140021cc324cea84dd0e72cff3b4" +dependencies = [ + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da3be0ae77ea334f4da67c12f149704f19f81d1adf7c51cf482943e84a2bad38" + +[[package]] +name = "icu_properties" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bee3b67d0ea5c2cca5003417989af8996f8604e34fb9ddf96208a033901e70de" +dependencies = [ + "icu_collections", + "icu_locale_core", + "icu_properties_data", + "icu_provider", + "zerotrie", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e2bbb201e0c04f7b4b3e14382af113e17ba4f63e2c9d2ee626b720cbce54a14" + +[[package]] +name = "icu_provider" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "139c4cf31c8b5f33d7e199446eff9c1e02decfc2f0eec2c8d71f65befa45b421" +dependencies = [ + "displaydoc", + "icu_locale_core", + "writeable", + "yoke", + "zerofrom", + "zerotrie", + "zerovec", +] + +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + +[[package]] +name = "idna" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" +dependencies = [ + "icu_normalizer", + "icu_properties", 
+] + +[[package]] +name = "indexmap" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +dependencies = [ + "equivalent", + "hashbrown 0.16.1", + "serde", + "serde_core", +] + +[[package]] +name = "ipnet" +version = "2.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" + +[[package]] +name = "iri-string" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25e659a4bb38e810ebc252e53b5814ff908a8c58c2a9ce2fae1bbec24cbf4e20" +dependencies = [ + "memchr", + "serde", +] + +[[package]] +name = "itoa" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" + +[[package]] +name = "js-sys" +version = "0.3.94" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e04e2ef80ce82e13552136fabeef8a5ed1f985a96805761cbb9a2c34e7664d9" +dependencies = [ + "cfg-if", + "futures-util", + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + +[[package]] +name = "libc" +version = "0.2.184" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48f5d2a454e16a5ea0f4ced81bd44e4cfc7bd3a507b61887c99fd3538b28e4af" + +[[package]] +name = "linux-raw-sys" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" + +[[package]] +name = "litemap" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"92daf443525c4cce67b150400bc2316076100ce0b3686209eb8cf3c31612e6f0" + +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + +[[package]] +name = "mio" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.61.2", +] + +[[package]] +name = "native-tls" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "465500e14ea162429d264d44189adc38b199b62b1c21eea9f69e4b73cb03bbf2" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" + +[[package]] +name = "openssl" +version = 
"0.10.76" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "951c002c75e16ea2c65b8c7e4d3d51d5530d8dfa7d060b4776828c88cfb18ecf" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "openssl-probe" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" + +[[package]] +name = "openssl-sys" +version = "0.9.112" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57d55af3b3e226502be1526dfdba67ab0e9c96fc293004e79576b2b9edb0dbdb" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + 
+[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + +[[package]] +name = "potential_utf" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0103b1cef7ec0cf76490e969665504990193874ea05c85ff9bab8b911d0a0564" +dependencies = [ + "zerovec", +] + +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" + +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + +[[package]] +name = "regex" +version = "1.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" + +[[package]] +name = "reqwest" +version = "0.12.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" +dependencies = [ + "base64", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-rustls", + "hyper-tls", + "hyper-util", + "js-sys", + "log", + "mime", + "native-tls", + "percent-encoding", + "pin-project-lite", + "rustls-pki-types", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tokio-native-tls", + "tokio-util", + "tower", + "tower-http", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-streams", + "web-sys", +] + +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.17", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustix" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustls" +version = "0.23.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" +dependencies = [ + 
"once_cell", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" +dependencies = [ + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df33b2b81ac578cabaf06b89b0631153a3f416b0a886e8a7a1707fb51abbd1ef" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "ryu" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" + +[[package]] +name = "schannel" +version = "0.1.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91c1b7e4904c873ef0710c1f407dde2e6287de2bebc1bbbf7d430bb7cbffd939" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "security-framework" +version = "3.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" +dependencies = [ + "bitflags", + "core-foundation 0.10.1", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3" +dependencies = 
[ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "signal-hook-registry" +version = "1.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" +dependencies = [ + "errno", + "libc", +] 
+ +[[package]] +name = "slab" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "socket2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" +dependencies = [ + "futures-core", +] + +[[package]] +name = "synstructure" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "system-configuration" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"a13f3d0daba03132c0aa9767f98351b3488edc2c100cda2d2ec2b04f3d8d3c8b" +dependencies = [ + "bitflags", + "core-foundation 0.9.4", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "tanstack-ai" +version = "0.1.0" +dependencies = [ + "async-trait", + "bytes", + "chrono", + "futures-core", + "futures-util", + "regex", + "reqwest", + "serde", + "serde_json", + "thiserror", + "tokio", + "tokio-stream", + "tokio-util", + "uuid", +] + +[[package]] +name = "tempfile" +version = "3.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" +dependencies = [ + "fastrand", + "getrandom 0.4.2", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tinystr" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8323304221c2a851516f22236c5722a72eaa19749016521d6dff0824447d96d" +dependencies = [ + "displaydoc", + "zerovec", +] + +[[package]] +name = "tokio" +version = "1.50.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" +dependencies = [ + "bytes", + "libc", + 
"mio", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys 0.61.2", +] + +[[package]] +name = "tokio-macros" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + +[[package]] +name = "tokio-stream" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32da49809aab5c3bc678af03902d4ccddea2a87d028d86392a4b1560c6906c70" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tokio-util" +version = "0.7.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tower" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tokio", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-http" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" +dependencies = [ + "bitflags", + "bytes", + "futures-util", + "http", + "http-body", + "iri-string", + "pin-project-lite", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "pin-project-lite", + "tracing-core", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", +] + +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "url" +version = "2.5.8" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff67a8a4397373c3ef660812acab3268222035010ab8680ec4215f38ba3d0eed" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", + "serde", +] + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + +[[package]] +name = "uuid" +version = "1.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ac8b6f42ead25368cf5b098aeb3dc8a1a2c05a3eee8a9a1a68c640edbfc79d9" +dependencies = [ + "getrandom 0.4.2", + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.2+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0551fc1bb415591e3372d0bc4780db7e587d84e2a7e79da121051c5c4b89d0b0" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.67" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03623de6905b7206edd0a75f69f747f134b7f0a2323392d664448bf2d3c5d87e" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fbdf9a35adf44786aecd5ff89b4563a90325f9da0923236f6104e603c7e86be" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dca9693ef2bab6d4e6707234500350d8dad079eb508dca05530c85dc3a529ff2" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39129a682a6d2d841b6c429d0c51e5cb0ed1a03829d8b3d1e69a011e62cb3d3b" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + +[[package]] +name = "web-sys" +version = "0.3.94" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd70027e39b12f0849461e08ffc50b9cd7688d942c1c8e3c7b22273236b4dd0a" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "windows-core" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-interface" +version = "0.59.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-registry" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720" +dependencies = [ + 
"windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + 
"wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + +[[package]] +name = "writeable" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" + +[[package]] +name = "yoke" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "abe8c5fda708d9ca3df187cae8bfb9ceda00dd96231bed36e445a1a48e66f9ca" +dependencies = [ + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de844c262c8848816172cef550288e7dc6c7b7814b4ee56b3e1553f275f1858e" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zerofrom" +version = "0.1.7" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "69faa1f2a1ea75661980b013019ed6687ed0e83d069bc1114e2cc74c6c04c4df" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11532158c46691caf0f2593ea8358fed6bbf68a0315e80aae9bd41fbade684a1" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + +[[package]] +name = "zerotrie" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f9152d31db0792fa83f70fb2f83148effb5c1f5b8c7686c3459e361d9bc20bf" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + +[[package]] +name = "zerovec" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90f911cbc359ab6af17377d242225f4d75119aec87ea711a880987b18cd7b239" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "625dc425cab0dca6dc3c3319506e6593dcb08a9f387ea3b284dbd52a92c40555" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/packages/rust/tanstack-ai/Cargo.toml b/packages/rust/tanstack-ai/Cargo.toml new file mode 100644 index 00000000..41ccaf00 --- /dev/null +++ b/packages/rust/tanstack-ai/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "tanstack-ai" +version = "0.1.0" +edition = "2021" +description = "Type-safe, provider-agnostic AI SDK for Rust — chat, streaming, 
tools, and agent loops" +license = "MIT" +repository = "https://github.com/TanStack/ai" + +[dependencies] +serde = { version = "1", features = ["derive"] } +serde_json = "1" +tokio = { version = "1", features = ["full"] } +futures-core = "0.3" +futures-util = "0.3" +async-trait = "0.1" +reqwest = { version = "0.12", features = ["json", "stream"] } +thiserror = "2" +uuid = { version = "1", features = ["v4"] } +chrono = { version = "0.4", features = ["serde"] } +tokio-stream = "0.1" +tokio-util = { version = "0.7", features = ["codec"] } +bytes = "1" +regex = "1" + +[dev-dependencies] +tokio = { version = "1", features = ["full", "test-util"] } diff --git a/packages/rust/tanstack-ai/src/adapter.rs b/packages/rust/tanstack-ai/src/adapter.rs new file mode 100644 index 00000000..b91210da --- /dev/null +++ b/packages/rust/tanstack-ai/src/adapter.rs @@ -0,0 +1,89 @@ +use async_trait::async_trait; +use futures_core::Stream; +use std::pin::Pin; + +use crate::error::AiResult; +use crate::types::{StreamChunk, StructuredOutputOptions, StructuredOutputResult, TextOptions}; + +/// Configuration for adapter instances. +#[derive(Debug, Clone, Default)] +pub struct TextAdapterConfig { + pub api_key: Option, + pub base_url: Option, + pub timeout: Option, + pub max_retries: Option, + pub headers: Option>, +} + +/// A pinned, boxed stream of `StreamChunk` events. +pub type ChunkStream = Pin> + Send>>; + +/// Core text adapter trait. +/// +/// An adapter bridges the TanStack AI engine to a specific AI provider (OpenAI, Anthropic, etc.). +/// Providers implement this trait to handle streaming chat completions and structured output. +#[async_trait] +pub trait TextAdapter: Send + Sync { + /// Provider name identifier (e.g., "openai", "anthropic"). + fn name(&self) -> &str; + + /// The model this adapter is configured for. + fn model(&self) -> &str; + + /// Stream chat completions from the model, yielding AG-UI protocol events. 
+ async fn chat_stream(&self, options: &TextOptions) -> AiResult; + + /// Generate structured output using the provider's native structured output API. + async fn structured_output( + &self, + options: &StructuredOutputOptions, + ) -> AiResult; +} + +/// Image generation adapter trait. +#[async_trait] +pub trait ImageAdapter: Send + Sync { + fn name(&self) -> &str; + fn model(&self) -> &str; + + async fn generate_image( + &self, + options: &crate::types::ImageGenerationOptions, + ) -> AiResult; +} + +/// Summarization adapter trait. +#[async_trait] +pub trait SummarizeAdapter: Send + Sync { + fn name(&self) -> &str; + fn model(&self) -> &str; + + async fn summarize( + &self, + options: &crate::types::SummarizationOptions, + ) -> AiResult; +} + +/// Text-to-speech adapter trait. +#[async_trait] +pub trait TtsAdapter: Send + Sync { + fn name(&self) -> &str; + fn model(&self) -> &str; + + async fn generate_speech( + &self, + options: &crate::types::TtsOptions, + ) -> AiResult; +} + +/// Transcription adapter trait. +#[async_trait] +pub trait TranscriptionAdapter: Send + Sync { + fn name(&self) -> &str; + fn model(&self) -> &str; + + async fn transcribe( + &self, + options: &crate::types::TranscriptionOptions, + ) -> AiResult; +} diff --git a/packages/rust/tanstack-ai/src/adapters/anthropic.rs b/packages/rust/tanstack-ai/src/adapters/anthropic.rs new file mode 100644 index 00000000..c0852391 --- /dev/null +++ b/packages/rust/tanstack-ai/src/adapters/anthropic.rs @@ -0,0 +1,779 @@ +use async_trait::async_trait; +use futures_core::Stream; +use reqwest::Client; +use serde::Deserialize; +use std::collections::HashMap; + +use crate::adapter::{ChunkStream, TextAdapter, TextAdapterConfig}; +use crate::error::{AiError, AiResult}; +use crate::types::*; + +/// Anthropic Claude text adapter. +/// +/// Uses the Anthropic Messages API with streaming to produce AG-UI protocol events. 
+pub struct AnthropicTextAdapter { + api_key: String, + model: String, + base_url: String, + client: Client, +} + +impl AnthropicTextAdapter { + /// Create a new Anthropic adapter. + pub fn new(model: impl Into, api_key: impl Into) -> Self { + Self { + api_key: api_key.into(), + model: model.into(), + base_url: "https://api.anthropic.com".to_string(), + client: Client::new(), + } + } + + /// Create with custom configuration. + pub fn with_config( + model: impl Into, + api_key: impl Into, + config: TextAdapterConfig, + ) -> Self { + let mut adapter = Self::new(model, api_key); + if let Some(base_url) = config.base_url { + adapter.base_url = base_url; + } + adapter + } + + fn build_request_body( + &self, + options: &TextOptions, + stream: bool, + ) -> serde_json::Value { + let mut body = serde_json::Map::new(); + + body.insert("model".to_string(), serde_json::json!(self.model)); + body.insert("stream".to_string(), serde_json::json!(stream)); + + // System prompts + if !options.system_prompts.is_empty() { + let system = options.system_prompts.join("\n"); + body.insert("system".to_string(), serde_json::json!(system)); + } + + // Messages + let messages: Vec = options + .messages + .iter() + .filter(|m| m.role != MessageRole::System) + .map(|msg| { + let mut m = serde_json::Map::new(); + // Anthropic uses "user" for tool result messages + let role_str = if msg.role == MessageRole::Tool { "user" } else { msg.role.as_str() }; + m.insert("role".to_string(), serde_json::json!(role_str)); + + if msg.role == MessageRole::Tool { + let tool_result_text = match &msg.content { + MessageContent::Text(text) => text.clone(), + MessageContent::Null | MessageContent::Parts(_) => String::new(), + }; + m.insert( + "content".to_string(), + serde_json::json!([{ + "type": "tool_result", + "tool_use_id": msg.tool_call_id, + "content": tool_result_text + }]), + ); + return serde_json::Value::Object(m); + } + + match &msg.content { + MessageContent::Text(text) => { + 
m.insert("content".to_string(), serde_json::json!(text)); + } + MessageContent::Parts(parts) => { + let content: Vec = parts + .iter() + .map(|part| match part { + ContentPart::Text { content } => { + serde_json::json!({"type": "text", "text": content}) + } + ContentPart::Image { source } => { + match source { + ContentPartSource::Data { value, mime_type } => { + serde_json::json!({ + "type": "image", + "source": { + "type": "base64", + "media_type": mime_type, + "data": value + } + }) + } + ContentPartSource::Url { value, .. } => { + serde_json::json!({ + "type": "image", + "source": { + "type": "url", + "url": value + } + }) + } + } + } + _ => serde_json::json!({"type": "text", "text": ""}), + }) + .collect(); + m.insert("content".to_string(), serde_json::json!(content)); + } + MessageContent::Null => { + if let Some(tool_calls) = &msg.tool_calls { + let content: Vec = tool_calls + .iter() + .map(|tc| { + serde_json::json!({ + "type": "tool_use", + "id": tc.id, + "name": tc.function.name, + "input": serde_json::from_str::( + &tc.function.arguments + ).unwrap_or(serde_json::json!({})) + }) + }) + .collect(); + m.insert("content".to_string(), serde_json::json!(content)); + } else if msg.role == MessageRole::Tool { + m.insert( + "content".to_string(), + serde_json::json!([{ + "type": "tool_result", + "tool_use_id": msg.tool_call_id, + "content": msg.content.as_str().unwrap_or("") + }]), + ); + } else { + m.insert( + "content".to_string(), + serde_json::json!(""), + ); + } + } + } + + serde_json::Value::Object(m) + }) + .collect(); + + body.insert("messages".to_string(), serde_json::json!(messages)); + + // Tools (Anthropic format) + if !options.tools.is_empty() { + let tools: Vec = options + .tools + .iter() + .map(|t| { + let mut tool = serde_json::Map::new(); + tool.insert("name".to_string(), serde_json::json!(t.name)); + tool.insert("description".to_string(), serde_json::json!(t.description)); + if let Some(schema) = &t.input_schema { + tool.insert( + 
"input_schema".to_string(), + serde_json::to_value(schema).unwrap_or(serde_json::json!({})), + ); + } + serde_json::Value::Object(tool) + }) + .collect(); + body.insert("tools".to_string(), serde_json::json!(tools)); + } + + if let Some(temp) = options.temperature { + body.insert("temperature".to_string(), serde_json::json!(temp)); + } + if let Some(top_p) = options.top_p { + body.insert("top_p".to_string(), serde_json::json!(top_p)); + } + if let Some(max_tokens) = options.max_tokens { + body.insert("max_tokens".to_string(), serde_json::json!(max_tokens)); + } + + serde_json::Value::Object(body) + } +} + +// Anthropic SSE event types +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +#[serde(tag = "type")] +enum AnthropicEvent { + #[serde(rename = "message_start")] + MessageStart { message: AnthropicMessage }, + #[serde(rename = "content_block_start")] + ContentBlockStart { + index: usize, + content_block: AnthropicContentBlock, + }, + #[serde(rename = "content_block_delta")] + ContentBlockDelta { + index: usize, + delta: AnthropicDelta, + }, + #[serde(rename = "content_block_stop")] + ContentBlockStop { index: usize }, + #[serde(rename = "message_delta")] + MessageDelta { + delta: AnthropicMessageDelta, + usage: Option, + }, + #[serde(rename = "message_stop")] + MessageStop, + #[serde(rename = "error")] + Error { error: AnthropicError }, + #[serde(other)] + Unknown, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct AnthropicMessage { + id: String, + #[serde(rename = "type")] + msg_type: String, + role: String, + content: Vec, + model: String, + stop_reason: Option, + usage: Option, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +#[serde(tag = "type")] +enum AnthropicContentBlock { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "tool_use")] + ToolUse { + id: String, + name: String, + input: serde_json::Value, + }, + #[serde(other)] + Unknown, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +#[serde(tag = 
"type")] +enum AnthropicDelta { + #[serde(rename = "text_delta")] + TextDelta { text: String }, + #[serde(rename = "input_json_delta")] + InputJsonDelta { partial_json: String }, + #[serde(other)] + Unknown, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct AnthropicMessageDelta { + stop_reason: Option, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct AnthropicUsage { + input_tokens: u32, + output_tokens: u32, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct AnthropicError { + #[serde(rename = "type")] + error_type: String, + message: String, +} + +#[async_trait] +impl TextAdapter for AnthropicTextAdapter { + fn name(&self) -> &str { + "anthropic" + } + + fn model(&self) -> &str { + &self.model + } + + async fn chat_stream(&self, options: &TextOptions) -> AiResult { + let body = self.build_request_body(options, true); + let url = format!("{}/v1/messages", self.base_url); + + let response = self + .client + .post(&url) + .header("x-api-key", &self.api_key) + .header("anthropic-version", "2023-06-01") + .header("Content-Type", "application/json") + .json(&body) + .send() + .await?; + + if !response.status().is_success() { + let status = response.status(); + let text = response.text().await.unwrap_or_default(); + return Err(AiError::Provider(format!( + "Anthropic API error ({}): {}", + status, text + ))); + } + + let model = self.model.clone(); + let byte_stream = response.bytes_stream(); + + let chunk_stream = parse_anthropic_sse_stream(byte_stream, model); + Ok(Box::pin(chunk_stream)) + } + + async fn structured_output( + &self, + options: &StructuredOutputOptions, + ) -> AiResult { + let mut body = self.build_request_body(&options.chat_options, false); + + // Add tool for structured output + if let Some(obj) = body.as_object_mut() { + obj.insert( + "tools".to_string(), + serde_json::json!([{ + "name": "structured_output", + "description": "Return structured output", + "input_schema": options.output_schema + }]), + ); + 
obj.insert( + "tool_choice".to_string(), + serde_json::json!({"type": "tool", "name": "structured_output"}), + ); + } + + let url = format!("{}/v1/messages", self.base_url); + + let response = self + .client + .post(&url) + .header("x-api-key", &self.api_key) + .header("anthropic-version", "2023-06-01") + .header("Content-Type", "application/json") + .json(&body) + .send() + .await?; + + if !response.status().is_success() { + let status = response.status(); + let text = response.text().await.unwrap_or_default(); + return Err(AiError::Provider(format!( + "Anthropic API error ({}): {}", + status, text + ))); + } + + let response_data: serde_json::Value = response.json().await?; + + // Extract tool use input + if let Some(content) = response_data.get("content").and_then(|c| c.as_array()) { + for block in content { + if block.get("type").and_then(|t| t.as_str()) == Some("tool_use") { + if let Some(input) = block.get("input") { + return Ok(StructuredOutputResult { + data: input.clone(), + raw_text: serde_json::to_string(input).unwrap_or_default(), + }); + } + } + } + } + + Err(AiError::Provider( + "No structured output in Anthropic response".to_string(), + )) + } +} + +fn parse_anthropic_sse_stream( + stream: S, + model: String, +) -> impl Stream> +where + S: Stream> + Send, +{ + use futures_util::StreamExt; + + let stream = Box::pin(stream); + futures_util::stream::unfold( + (stream, String::new(), HashMap::::new()), + move |(mut stream, mut line_buf, mut tool_block_ids)| { + let model = model.clone(); + async move { + loop { + match stream.next().await { + Some(Ok(bytes)) => { + let text = String::from_utf8_lossy(&bytes); + line_buf.push_str(&text); + + let mut chunks: Vec> = Vec::new(); + let mut processed_to = 0; + + while let Some(newline_pos) = line_buf[processed_to..].find('\n') { + let abs_pos = processed_to + newline_pos; + let line = line_buf[processed_to..abs_pos].trim(); + processed_to = abs_pos + 1; + + if line.is_empty() || line.starts_with(':') { + 
continue; + } + + if let Some(data_str) = line.strip_prefix("data: ") { + match serde_json::from_str::(data_str) { + Ok(event) => { + let now = chrono::Utc::now().timestamp_millis() as f64 / 1000.0; + let model_str = Some(model.clone()); + + match event { + AnthropicEvent::ContentBlockStart { index, content_block } => { + match content_block { + AnthropicContentBlock::Text { .. } => { + chunks.push(Ok(StreamChunk::TextMessageStart { + timestamp: now, + message_id: format!("block-{}", index), + role: "assistant".to_string(), + model: model_str, + })); + } + AnthropicContentBlock::ToolUse { id, name, .. } => { + tool_block_ids.insert(index, id.clone()); + chunks.push(Ok(StreamChunk::ToolCallStart { + timestamp: now, + tool_call_id: id, + tool_name: name, + parent_message_id: None, + index: Some(index), + provider_metadata: None, + model: model_str, + })); + } + _ => {} + } + } + AnthropicEvent::ContentBlockDelta { index, delta } => { + match delta { + AnthropicDelta::TextDelta { text } => { + chunks.push(Ok(StreamChunk::TextMessageContent { + timestamp: now, + message_id: format!("block-{}", index), + delta: text, + content: None, + model: model_str, + })); + } + AnthropicDelta::InputJsonDelta { partial_json } => { + let tool_call_id = tool_block_ids + .get(&index) + .cloned() + .unwrap_or_else(|| format!("block-{}", index)); + chunks.push(Ok(StreamChunk::ToolCallArgs { + timestamp: now, + tool_call_id, + delta: partial_json, + args: None, + model: model_str, + })); + } + _ => {} + } + } + other => { + for chunk in convert_anthropic_event(other, &model) { + chunks.push(Ok(chunk)); + } + } + } + } + Err(e) => { + eprintln!( + "Failed to parse Anthropic event: {} - {}", + e, + &data_str[..data_str.len().min(200)] + ); + } + } + } + } + + line_buf = line_buf[processed_to..].to_string(); + + if !chunks.is_empty() { + return Some(( + futures_util::stream::iter(chunks), + (stream, line_buf, tool_block_ids), + )); + } + } + Some(Err(e)) => { + return Some(( + 
futures_util::stream::iter(vec![Err(AiError::Http(e))]), + (stream, line_buf, tool_block_ids), + )); + } + None => { + let line = line_buf.trim(); + if line.is_empty() { + return None; + } + + if let Some(data_str) = line.strip_prefix("data: ") { + match serde_json::from_str::(data_str) { + Ok(event) => { + let now = chrono::Utc::now().timestamp_millis() as f64 / 1000.0; + let model_str = Some(model.clone()); + let mut chunks: Vec> = Vec::new(); + + match event { + AnthropicEvent::ContentBlockStart { index, content_block } => { + match content_block { + AnthropicContentBlock::Text { .. } => { + chunks.push(Ok(StreamChunk::TextMessageStart { + timestamp: now, + message_id: format!("block-{}", index), + role: "assistant".to_string(), + model: model_str, + })); + } + AnthropicContentBlock::ToolUse { id, name, .. } => { + tool_block_ids.insert(index, id.clone()); + chunks.push(Ok(StreamChunk::ToolCallStart { + timestamp: now, + tool_call_id: id, + tool_name: name, + parent_message_id: None, + index: Some(index), + provider_metadata: None, + model: model_str, + })); + } + _ => {} + } + } + AnthropicEvent::ContentBlockDelta { index, delta } => { + match delta { + AnthropicDelta::TextDelta { text } => { + chunks.push(Ok(StreamChunk::TextMessageContent { + timestamp: now, + message_id: format!("block-{}", index), + delta: text, + content: None, + model: model_str, + })); + } + AnthropicDelta::InputJsonDelta { partial_json } => { + let tool_call_id = tool_block_ids + .get(&index) + .cloned() + .unwrap_or_else(|| format!("block-{}", index)); + chunks.push(Ok(StreamChunk::ToolCallArgs { + timestamp: now, + tool_call_id, + delta: partial_json, + args: None, + model: model_str, + })); + } + _ => {} + } + } + other => { + for chunk in convert_anthropic_event(other, &model) { + chunks.push(Ok(chunk)); + } + } + } + + return Some(( + futures_util::stream::iter(chunks), + (stream, String::new(), tool_block_ids), + )); + } + Err(e) => { + eprintln!( + "Failed to parse Anthropic 
event: {} - {}", + e, + &data_str[..data_str.len().min(200)] + ); + } + } + } + + return None; + } + } + } + } + }, + ) + .flatten() +} + +fn convert_anthropic_event(event: AnthropicEvent, model: &str) -> Vec { + let now = chrono::Utc::now().timestamp_millis() as f64 / 1000.0; + let model_str = Some(model.to_string()); + + match event { + AnthropicEvent::MessageStart { message } => { + vec![ + StreamChunk::RunStarted { + timestamp: now, + run_id: message.id, + thread_id: None, + model: model_str, + }, + ] + } + + AnthropicEvent::ContentBlockStart { index, content_block } => match content_block { + AnthropicContentBlock::Text { .. } => { + vec![StreamChunk::TextMessageStart { + timestamp: now, + message_id: format!("block-{}", index), + role: "assistant".to_string(), + model: model_str, + }] + } + AnthropicContentBlock::ToolUse { id, name, .. } => { + vec![StreamChunk::ToolCallStart { + timestamp: now, + tool_call_id: id, + tool_name: name, + parent_message_id: None, + index: Some(index), + provider_metadata: None, + model: model_str, + }] + } + _ => vec![], + }, + + AnthropicEvent::ContentBlockDelta { index, delta } => match delta { + AnthropicDelta::TextDelta { text } => { + vec![StreamChunk::TextMessageContent { + timestamp: now, + message_id: format!("block-{}", index), + delta: text, + content: None, + model: model_str, + }] + } + AnthropicDelta::InputJsonDelta { partial_json } => { + vec![StreamChunk::ToolCallArgs { + timestamp: now, + tool_call_id: format!("block-{}", index), + delta: partial_json, + args: None, + model: model_str, + }] + } + _ => vec![], + }, + + AnthropicEvent::MessageDelta { delta, usage } => { + let mut chunks = vec![]; + + if let Some(reason) = delta.stop_reason { + let finish_reason = match reason.as_str() { + "end_turn" => "stop".to_string(), + "max_tokens" => "length".to_string(), + "tool_use" => "tool_calls".to_string(), + _ => reason, + }; + + let u = usage.map(|u| Usage { + prompt_tokens: u.input_tokens, + completion_tokens: 
u.output_tokens, + total_tokens: u.input_tokens + u.output_tokens, + }); + + chunks.push(StreamChunk::RunFinished { + timestamp: now, + run_id: format!("anthropic-{}", chrono::Utc::now().timestamp_millis()), + finish_reason: Some(finish_reason), + usage: u, + model: model_str, + }); + } + + chunks + } + + AnthropicEvent::Error { error } => { + vec![StreamChunk::RunError { + timestamp: now, + run_id: None, + error: RunErrorData { + message: error.message, + code: Some(error.error_type), + }, + model: model_str, + }] + } + + _ => vec![], + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + use futures_util::StreamExt; + + #[tokio::test] + async fn parse_anthropic_maps_index_to_tool_use_id() { + let start = serde_json::json!({ + "type": "content_block_start", + "index": 0, + "content_block": { + "type": "tool_use", + "id": "toolu_123", + "name": "get_weather", + "input": {} + } + }) + .to_string(); + + let delta = serde_json::json!({ + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "input_json_delta", + "partial_json": "{\"city\":\"SF\"}" + } + }) + .to_string(); + + let sse = format!("data: {}\ndata: {}", start, delta); + let stream = futures_util::stream::iter(vec![Ok(Bytes::from(sse))]); + + let chunks = parse_anthropic_sse_stream(stream, "claude-3-5-sonnet".to_string()) + .collect::>() + .await; + + assert_eq!(chunks.len(), 2); + + match &chunks[0] { + Ok(StreamChunk::ToolCallStart { tool_call_id, .. }) => assert_eq!(tool_call_id, "toolu_123"), + other => panic!("unexpected chunk 0: {other:?}"), + } + + match &chunks[1] { + Ok(StreamChunk::ToolCallArgs { tool_call_id, .. 
}) => assert_eq!(tool_call_id, "toolu_123"), + other => panic!("unexpected chunk 1: {other:?}"), + } + } +} diff --git a/packages/rust/tanstack-ai/src/adapters/gemini.rs b/packages/rust/tanstack-ai/src/adapters/gemini.rs new file mode 100644 index 00000000..02ebfb97 --- /dev/null +++ b/packages/rust/tanstack-ai/src/adapters/gemini.rs @@ -0,0 +1,486 @@ +use async_trait::async_trait; +use futures_core::Stream; +use reqwest::Client; +use serde::Deserialize; + +use crate::adapter::{ChunkStream, TextAdapter, TextAdapterConfig}; +use crate::error::{AiError, AiResult}; +use crate::types::*; + +/// Google Gemini text adapter. +/// +/// Uses the Gemini Generate Content API with streaming to produce AG-UI protocol events. +pub struct GeminiTextAdapter { + api_key: String, + model: String, + base_url: String, + client: Client, +} + +impl GeminiTextAdapter { + /// Create a new Gemini adapter. + pub fn new(model: impl Into, api_key: impl Into) -> Self { + Self { + api_key: api_key.into(), + model: model.into(), + base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(), + client: Client::new(), + } + } + + /// Create with custom configuration. 
+ pub fn with_config( + model: impl Into, + api_key: impl Into, + config: TextAdapterConfig, + ) -> Self { + let mut adapter = Self::new(model, api_key); + if let Some(base_url) = config.base_url { + adapter.base_url = base_url; + } + adapter + } + + fn build_request_body( + &self, + options: &TextOptions, + ) -> serde_json::Value { + let mut body = serde_json::Map::new(); + + // Contents (messages) + let mut contents = Vec::new(); + + // System instruction + if !options.system_prompts.is_empty() { + body.insert( + "system_instruction".to_string(), + serde_json::json!({ + "parts": options.system_prompts.iter().map(|p| { + serde_json::json!({"text": p}) + }).collect::>() + }), + ); + } + + for msg in &options.messages { + if msg.role == MessageRole::System { + continue; + } + + if msg.role == MessageRole::Tool { + let tool_result_text = match &msg.content { + MessageContent::Text(text) => text.clone(), + MessageContent::Null | MessageContent::Parts(_) => String::new(), + }; + contents.push(serde_json::json!({ + "role": "user", + "parts": [{ + "functionResponse": { + "name": msg.tool_call_id.as_deref().unwrap_or(""), + "response": { + "result": tool_result_text + } + } + }] + })); + continue; + } + + let role = match msg.role { + MessageRole::User | MessageRole::Tool => "user", + MessageRole::Assistant => "model", + MessageRole::System => continue, + }; + + let parts = match &msg.content { + MessageContent::Text(text) => { + vec![serde_json::json!({"text": text})] + } + MessageContent::Parts(ps) => { + ps.iter() + .map(|part| match part { + ContentPart::Text { content } => { + serde_json::json!({"text": content}) + } + ContentPart::Image { source } => match source { + ContentPartSource::Data { value, mime_type } => { + serde_json::json!({ + "inlineData": { + "mimeType": mime_type, + "data": value + } + }) + } + ContentPartSource::Url { value, .. 
} => { + serde_json::json!({"text": format!("[Image: {}]", value)}) + } + }, + _ => serde_json::json!({"text": ""}), + }) + .collect() + } + MessageContent::Null => { + vec![serde_json::json!({"text": ""})] + } + }; + + contents.push(serde_json::json!({ + "role": role, + "parts": parts + })); + } + + body.insert("contents".to_string(), serde_json::json!(contents)); + + // Tools + if !options.tools.is_empty() { + let function_declarations: Vec = options + .tools + .iter() + .map(|t| { + let mut func = serde_json::Map::new(); + func.insert("name".to_string(), serde_json::json!(t.name)); + func.insert("description".to_string(), serde_json::json!(t.description)); + if let Some(schema) = &t.input_schema { + func.insert( + "parameters".to_string(), + serde_json::to_value(schema).unwrap_or(serde_json::json!({})), + ); + } + serde_json::Value::Object(func) + }) + .collect(); + + body.insert( + "tools".to_string(), + serde_json::json!([{"function_declarations": function_declarations}]), + ); + } + + // Generation config + let mut gen_config = serde_json::Map::new(); + if let Some(temp) = options.temperature { + gen_config.insert("temperature".to_string(), serde_json::json!(temp)); + } + if let Some(top_p) = options.top_p { + gen_config.insert("topP".to_string(), serde_json::json!(top_p)); + } + if let Some(max_tokens) = options.max_tokens { + gen_config.insert("maxOutputTokens".to_string(), serde_json::json!(max_tokens)); + } + + if !gen_config.is_empty() { + body.insert( + "generationConfig".to_string(), + serde_json::Value::Object(gen_config), + ); + } + + serde_json::Value::Object(body) + } +} + +// Gemini SSE event types +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct GeminiResponse { + candidates: Option>, + usage_metadata: Option, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct GeminiCandidate { + content: Option, + finish_reason: Option, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct GeminiContent { + parts: Option>, + 
role: Option, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +#[serde(untagged)] +enum GeminiPart { + Text { text: String }, + FunctionCall { function_call: GeminiFunctionCall }, + FunctionResponse { function_response: GeminiFunctionResponse }, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct GeminiFunctionCall { + name: String, + args: Option, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct GeminiFunctionResponse { + name: String, + response: serde_json::Value, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct GeminiUsage { + prompt_token_count: Option, + candidates_token_count: Option, + total_token_count: Option, +} + +#[async_trait] +impl TextAdapter for GeminiTextAdapter { + fn name(&self) -> &str { + "gemini" + } + + fn model(&self) -> &str { + &self.model + } + + async fn chat_stream(&self, options: &TextOptions) -> AiResult { + let body = self.build_request_body(options); + let url = format!( + "{}/models/{}:streamGenerateContent?key={}&alt=sse", + self.base_url, self.model, self.api_key + ); + + let response = self + .client + .post(&url) + .header("Content-Type", "application/json") + .json(&body) + .send() + .await?; + + if !response.status().is_success() { + let status = response.status(); + let text = response.text().await.unwrap_or_default(); + return Err(AiError::Provider(format!( + "Gemini API error ({}): {}", + status, text + ))); + } + + let model = self.model.clone(); + let byte_stream = response.bytes_stream(); + + let chunk_stream = parse_gemini_sse_stream(byte_stream, model); + Ok(Box::pin(chunk_stream)) + } + + async fn structured_output( + &self, + options: &StructuredOutputOptions, + ) -> AiResult { + let mut body = self.build_request_body(&options.chat_options); + + // Add response schema + if let Some(obj) = body.as_object_mut() { + obj.insert( + "generationConfig".to_string(), + serde_json::json!({ + "response_mime_type": "application/json", + "response_schema": 
options.output_schema + }), + ); + } + + let url = format!( + "{}/models/{}:generateContent?key={}", + self.base_url, self.model, self.api_key + ); + + let response = self + .client + .post(&url) + .header("Content-Type", "application/json") + .json(&body) + .send() + .await?; + + if !response.status().is_success() { + let status = response.status(); + let text = response.text().await.unwrap_or_default(); + return Err(AiError::Provider(format!( + "Gemini API error ({}): {}", + status, text + ))); + } + + let response_data: GeminiResponse = response.json().await?; + + let text = response_data + .candidates + .and_then(|c| c.into_iter().next()) + .and_then(|c| c.content) + .and_then(|c| c.parts) + .and_then(|p| p.into_iter().next()) + .and_then(|p| match p { + GeminiPart::Text { text } => Some(text), + _ => None, + }) + .unwrap_or_default(); + + let data: serde_json::Value = serde_json::from_str(&text)?; + + Ok(StructuredOutputResult { + data, + raw_text: text, + }) + } +} + +fn parse_gemini_sse_stream( + stream: S, + model: String, +) -> impl Stream> +where + S: Stream> + Send, +{ + use futures_util::StreamExt; + + let stream = Box::pin(stream); + futures_util::stream::unfold( + (stream, String::new()), + move |(mut stream, mut line_buf)| { + let model = model.clone(); + async move { + loop { + match stream.next().await { + Some(Ok(bytes)) => { + let text = String::from_utf8_lossy(&bytes); + line_buf.push_str(&text); + + let mut chunks: Vec> = Vec::new(); + let mut processed_to = 0; + + while let Some(newline_pos) = line_buf[processed_to..].find('\n') { + let abs_pos = processed_to + newline_pos; + let line = line_buf[processed_to..abs_pos].trim(); + processed_to = abs_pos + 1; + + if line.is_empty() || line.starts_with(':') { + continue; + } + + if let Some(data_str) = line.strip_prefix("data: ") { + match serde_json::from_str::(data_str) { + Ok(response) => { + for chunk in convert_gemini_response(response, &model) { + chunks.push(Ok(chunk)); + } + } + Err(e) => { 
+ eprintln!( + "Failed to parse Gemini event: {} - {}", + e, + &data_str[..data_str.len().min(200)] + ); + } + } + } + } + + line_buf = line_buf[processed_to..].to_string(); + + if !chunks.is_empty() { + return Some((futures_util::stream::iter(chunks), (stream, line_buf))); + } + } + Some(Err(e)) => { + return Some(( + futures_util::stream::iter(vec![Err(AiError::Http(e))]), + (stream, line_buf), + )); + } + None => return None, + } + } + } + }, + ) + .flatten() +} + +fn convert_gemini_response(response: GeminiResponse, model: &str) -> Vec { + let now = chrono::Utc::now().timestamp_millis() as f64 / 1000.0; + let model_str = Some(model.to_string()); + let mut chunks = vec![]; + + if let Some(candidates) = response.candidates { + for (i, candidate) in candidates.iter().enumerate() { + if let Some(content) = &candidate.content { + if let Some(parts) = &content.parts { + for part in parts { + match part { + GeminiPart::Text { text } => { + chunks.push(StreamChunk::TextMessageContent { + timestamp: now, + message_id: format!("gemini-{}", i), + delta: text.clone(), + content: None, + model: model_str.clone(), + }); + } + GeminiPart::FunctionCall { function_call } => { + chunks.push(StreamChunk::ToolCallStart { + timestamp: now, + tool_call_id: format!("gemini-tool-{}", i), + tool_name: function_call.name.clone(), + parent_message_id: None, + index: Some(i), + provider_metadata: None, + model: model_str.clone(), + }); + + if let Some(args) = &function_call.args { + chunks.push(StreamChunk::ToolCallEnd { + timestamp: now, + tool_call_id: format!("gemini-tool-{}", i), + tool_name: function_call.name.clone(), + input: Some(args.clone()), + result: None, + model: model_str.clone(), + }); + } + } + _ => {} + } + } + } + } + + if let Some(finish_reason) = &candidate.finish_reason { + let reason = match finish_reason.as_str() { + "STOP" => "stop", + "MAX_TOKENS" => "length", + "SAFETY" => "content_filter", + "TOOL_CALL" => "tool_calls", + _ => finish_reason.as_str(), + }; + + 
let usage = response.usage_metadata.as_ref().map(|u| Usage { + prompt_tokens: u.prompt_token_count.unwrap_or(0), + completion_tokens: u.candidates_token_count.unwrap_or(0), + total_tokens: u.total_token_count.unwrap_or(0), + }); + + chunks.push(StreamChunk::RunFinished { + timestamp: now, + run_id: format!("gemini-{}", chrono::Utc::now().timestamp_millis()), + finish_reason: Some(reason.to_string()), + usage, + model: model_str.clone(), + }); + } + } + } + + chunks +} diff --git a/packages/rust/tanstack-ai/src/adapters/mod.rs b/packages/rust/tanstack-ai/src/adapters/mod.rs new file mode 100644 index 00000000..bc2f6bfa --- /dev/null +++ b/packages/rust/tanstack-ai/src/adapters/mod.rs @@ -0,0 +1,7 @@ +pub mod anthropic; +pub mod gemini; +pub mod openai; + +pub use anthropic::*; +pub use gemini::*; +pub use openai::*; diff --git a/packages/rust/tanstack-ai/src/adapters/openai.rs b/packages/rust/tanstack-ai/src/adapters/openai.rs new file mode 100644 index 00000000..85e1095d --- /dev/null +++ b/packages/rust/tanstack-ai/src/adapters/openai.rs @@ -0,0 +1,914 @@ +use async_trait::async_trait; +use futures_core::Stream; +use reqwest::Client; +use serde::Deserialize; +use std::collections::HashMap; + +use crate::adapter::{ChunkStream, TextAdapter, TextAdapterConfig}; +use crate::error::{AiError, AiResult}; +use crate::types::*; + +/// OpenAI text adapter. +/// +/// Uses the OpenAI Responses API with streaming to produce AG-UI protocol events. +pub struct OpenAiTextAdapter { + api_key: String, + model: String, + base_url: String, + client: Client, +} + +impl OpenAiTextAdapter { + /// Create a new OpenAI adapter with an explicit API key. + pub fn new(model: impl Into, api_key: impl Into) -> Self { + Self { + api_key: api_key.into(), + model: model.into(), + base_url: "https://api.openai.com/v1".to_string(), + client: Client::new(), + } + } + + /// Create with custom configuration. 
+ pub fn with_config( + model: impl Into, + api_key: impl Into, + config: TextAdapterConfig, + ) -> Self { + let mut adapter = Self::new(model, api_key); + if let Some(base_url) = config.base_url { + adapter.base_url = base_url; + } + adapter + } + + /// Build the request body for the OpenAI Responses API. + fn build_request_body( + &self, + options: &TextOptions, + stream: bool, + ) -> serde_json::Value { + let mut body = serde_json::Map::new(); + + body.insert("model".to_string(), serde_json::json!(self.model)); + body.insert("stream".to_string(), serde_json::json!(stream)); + + // Convert messages to OpenAI format + let input = self.convert_messages(&options.messages, &options.system_prompts); + body.insert("input".to_string(), input); + + // Tools + if !options.tools.is_empty() { + let tools: Vec = options + .tools + .iter() + .map(|t| { + let mut tool = serde_json::Map::new(); + tool.insert("type".to_string(), serde_json::json!("function")); + tool.insert("name".to_string(), serde_json::json!(t.name)); + tool.insert("description".to_string(), serde_json::json!(t.description)); + if let Some(schema) = &t.input_schema { + tool.insert( + "parameters".to_string(), + serde_json::to_value(schema).unwrap_or(serde_json::json!({})), + ); + } + serde_json::Value::Object(tool) + }) + .collect(); + body.insert("tools".to_string(), serde_json::json!(tools)); + } + + // Optional parameters + if let Some(temp) = options.temperature { + body.insert("temperature".to_string(), serde_json::json!(temp)); + } + if let Some(top_p) = options.top_p { + body.insert("top_p".to_string(), serde_json::json!(top_p)); + } + if let Some(max_tokens) = options.max_tokens { + body.insert("max_output_tokens".to_string(), serde_json::json!(max_tokens)); + } + + serde_json::Value::Object(body) + } + + /// Convert ModelMessages to OpenAI Responses API input format. 
+ fn convert_messages( + &self, + messages: &[ModelMessage], + system_prompts: &[String], + ) -> serde_json::Value { + let mut items = Vec::new(); + + // System prompts + for prompt in system_prompts { + items.push(serde_json::json!({ + "role": "system", + "content": prompt + })); + } + + for msg in messages { + if msg.role == MessageRole::Tool { + let content = match &msg.content { + MessageContent::Text(text) => text.clone(), + MessageContent::Null => String::new(), + MessageContent::Parts(_) => String::new(), + }; + items.push(serde_json::json!({ + "role": "tool", + "tool_call_id": msg.tool_call_id, + "content": content + })); + continue; + } + + match &msg.content { + MessageContent::Text(text) => { + items.push(serde_json::json!({ + "role": msg.role.as_str(), + "content": text + })); + } + MessageContent::Parts(parts) => { + let mut content = Vec::new(); + for part in parts { + match part { + ContentPart::Text { content: text } => { + content.push(serde_json::json!({ + "type": "text", + "text": text + })); + } + ContentPart::Image { source } => { + let url = match source { + ContentPartSource::Url { value, .. } => value.clone(), + ContentPartSource::Data { value, mime_type, .. 
} => { + format!("data:{};base64,{}", mime_type, value) + } + }; + content.push(serde_json::json!({ + "type": "image_url", + "image_url": { "url": url } + })); + } + _ => {} + } + } + items.push(serde_json::json!({ + "role": msg.role.as_str(), + "content": content + })); + } + MessageContent::Null => { + // Handle tool calls in assistant messages + if let Some(tool_calls) = &msg.tool_calls { + items.push(serde_json::json!({ + "role": "assistant", + "tool_calls": tool_calls.iter().map(|tc| { + serde_json::json!({ + "id": tc.id, + "type": "function", + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments + } + }) + }).collect::>() + })); + } + } + } + } + + serde_json::json!(items) + } +} + +// ============================================================================ +// OpenAI SSE Event Types (deserialization structs - fields needed for serde) +// ============================================================================ + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +#[serde(tag = "type")] +enum OpenAiStreamEvent { + #[serde(rename = "response.created")] + ResponseCreated { response: OpenAiResponse }, + #[serde(rename = "response.output_item.added")] + OutputItemAdded { + #[serde(default)] + output_index: usize, + item: OpenAiOutputItem, + }, + #[serde(rename = "response.content_part.added")] + ContentPartAdded { + item_id: String, + #[serde(default)] + output_index: usize, + #[serde(default)] + content_index: usize, + part: OpenAiContentPart, + }, + #[serde(rename = "response.output_text.delta")] + OutputTextDelta { + item_id: String, + #[serde(default)] + output_index: usize, + #[serde(default)] + content_index: usize, + delta: String, + }, + #[serde(rename = "response.output_text.done")] + OutputTextDone { + item_id: String, + #[serde(default)] + output_index: usize, + #[serde(default)] + content_index: usize, + text: String, + }, + #[serde(rename = "response.function_call_arguments.delta")] + FunctionCallArgumentsDelta { + 
item_id: String, + #[serde(default)] + output_index: usize, + #[serde(default)] + content_index: usize, + delta: String, + }, + #[serde(rename = "response.function_call_arguments.done")] + FunctionCallArgumentsDone { + item_id: String, + #[serde(default)] + output_index: usize, + #[serde(default)] + content_index: usize, + arguments: String, + }, + #[serde(rename = "response.completed")] + ResponseCompleted { response: OpenAiResponse }, + #[serde(rename = "error")] + Error { code: String, message: String }, + #[serde(other)] + Unknown, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct OpenAiResponse { + id: String, + status: Option, + usage: Option, + #[serde(flatten)] + _extra: std::collections::HashMap, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct OpenAiUsage { + input_tokens: u32, + output_tokens: u32, + total_tokens: u32, + #[serde(flatten)] + _extra: std::collections::HashMap, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +#[serde(tag = "type")] +enum OpenAiOutputItem { + #[serde(rename = "message")] + Message { + id: String, + role: String, + status: Option, + }, + #[serde(rename = "function_call")] + FunctionCall { + id: String, + name: String, + arguments: String, + status: Option, + call_id: String, + }, + #[serde(other)] + Unknown, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +#[serde(tag = "type")] +enum OpenAiContentPart { + #[serde(rename = "text")] + Text { text: String }, + #[serde(other)] + Unknown, +} + +// ============================================================================ +// Adapter Implementation +// ============================================================================ + +#[async_trait] +impl TextAdapter for OpenAiTextAdapter { + fn name(&self) -> &str { + "openai" + } + + fn model(&self) -> &str { + &self.model + } + + async fn chat_stream(&self, options: &TextOptions) -> AiResult { + let body = self.build_request_body(options, true); + let url = format!("{}/responses", 
self.base_url); + + let response = self + .client + .post(&url) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .json(&body) + .send() + .await?; + + if !response.status().is_success() { + let status = response.status(); + let text = response.text().await.unwrap_or_default(); + return Err(AiError::Provider(format!( + "OpenAI API error ({}): {}", + status, text + ))); + } + + let model = self.model.clone(); + let byte_stream = response.bytes_stream(); + + // Convert SSE byte stream to StreamChunk stream + let chunk_stream = parse_openai_sse_stream(byte_stream, model); + + Ok(Box::pin(chunk_stream)) + } + + async fn structured_output( + &self, + options: &StructuredOutputOptions, + ) -> AiResult { + let mut body = self.build_request_body(&options.chat_options, false); + + // Add structured output format + if let Some(obj) = body.as_object_mut() { + obj.insert( + "text".to_string(), + serde_json::json!({ + "format": { + "type": "json_schema", + "name": "structured_output", + "schema": options.output_schema, + "strict": true + } + }), + ); + } + + let url = format!("{}/responses", self.base_url); + + let response = self + .client + .post(&url) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .json(&body) + .send() + .await?; + + if !response.status().is_success() { + let status = response.status(); + let text = response.text().await.unwrap_or_default(); + return Err(AiError::Provider(format!( + "OpenAI API error ({}): {}", + status, text + ))); + } + + let response_data: serde_json::Value = response.json().await?; + + // Extract text content from response + let text = extract_text_from_response(&response_data); + let data: serde_json::Value = serde_json::from_str(&text)?; + + Ok(StructuredOutputResult { + data, + raw_text: text, + }) + } +} + +/// Parse OpenAI SSE stream into StreamChunk events. 
+/// +/// Uses an unfolding state machine to buffer incomplete SSE lines +/// across HTTP response chunks. +fn parse_openai_sse_stream(stream: S, model: String) -> impl Stream> +where + S: Stream> + Send, +{ + use futures_util::StreamExt; + + let stream = Box::pin(stream); + futures_util::stream::unfold( + ( + stream, + String::new(), + HashMap::::new(), + HashMap::::new(), + ), + move |(mut stream, mut line_buf, mut item_to_call, mut item_to_tool)| { + let model = model.clone(); + async move { + loop { + match stream.next().await { + Some(Ok(bytes)) => { + let text = String::from_utf8_lossy(&bytes); + line_buf.push_str(&text); + + let mut chunks: Vec> = Vec::new(); + let mut processed_to = 0; + + // Process complete lines (delimited by \n) + while let Some(newline_pos) = line_buf[processed_to..].find('\n') { + let abs_pos = processed_to + newline_pos; + let line = line_buf[processed_to..abs_pos].trim(); + processed_to = abs_pos + 1; + + if line.is_empty() || line.starts_with(':') { + continue; + } + + if let Some(data_str) = line.strip_prefix("data: ") { + if data_str == "[DONE]" { + continue; + } + + match serde_json::from_str::(data_str) { + Ok(event) => { + let now = chrono::Utc::now().timestamp_millis() as f64 / 1000.0; + let model_str = Some(model.clone()); + + match event { + OpenAiStreamEvent::OutputItemAdded { item, .. } => match item { + OpenAiOutputItem::Message { id, role, .. } => { + chunks.push(Ok(StreamChunk::TextMessageStart { + timestamp: now, + message_id: id, + role, + model: model_str, + })); + } + OpenAiOutputItem::FunctionCall { + id, + name, + call_id, + .. 
+ } => { + item_to_call.insert(id.clone(), call_id.clone()); + item_to_tool.insert(id.clone(), name.clone()); + chunks.push(Ok(StreamChunk::ToolCallStart { + timestamp: now, + tool_call_id: call_id, + tool_name: name, + parent_message_id: Some(id), + index: None, + provider_metadata: None, + model: model_str, + })); + } + _ => {} + }, + OpenAiStreamEvent::FunctionCallArgumentsDelta { + item_id, + delta, + .. + } => { + let tool_call_id = + item_to_call.get(&item_id).cloned().unwrap_or(item_id); + chunks.push(Ok(StreamChunk::ToolCallArgs { + timestamp: now, + tool_call_id, + delta, + args: None, + model: model_str, + })); + } + OpenAiStreamEvent::FunctionCallArgumentsDone { + item_id, + arguments, + .. + } => { + let tool_call_id = + item_to_call.get(&item_id).cloned().unwrap_or(item_id.clone()); + let tool_name = + item_to_tool.get(&item_id).cloned().unwrap_or_default(); + let input: Option = + serde_json::from_str(&arguments).ok(); + chunks.push(Ok(StreamChunk::ToolCallEnd { + timestamp: now, + tool_call_id, + tool_name, + input, + result: None, + model: model_str, + })); + } + other => { + for chunk in convert_openai_event(other, &model) { + chunks.push(Ok(chunk)); + } + } + } + } + Err(e) => { + eprintln!( + "Failed to parse OpenAI event: {} - {}", + e, &data_str[..data_str.len().min(200)] + ); + } + } + } + } + + // Keep incomplete trailing data in buffer + line_buf = line_buf[processed_to..].to_string(); + + if !chunks.is_empty() { + return Some(( + futures_util::stream::iter(chunks), + (stream, line_buf, item_to_call, item_to_tool), + )); + } + } + Some(Err(e)) => { + return Some(( + futures_util::stream::iter(vec![Err(AiError::Http(e))]), + (stream, line_buf, item_to_call, item_to_tool), + )); + } + None => { + let line = line_buf.trim(); + if line.is_empty() { + return None; + } + + if let Some(data_str) = line.strip_prefix("data: ") { + if data_str == "[DONE]" { + return None; + } + + match serde_json::from_str::(data_str) { + Ok(event) => { + let now = 
chrono::Utc::now().timestamp_millis() as f64 / 1000.0; + let model_str = Some(model.clone()); + let mut chunks: Vec> = Vec::new(); + + match event { + OpenAiStreamEvent::OutputItemAdded { item, .. } => match item { + OpenAiOutputItem::Message { id, role, .. } => { + chunks.push(Ok(StreamChunk::TextMessageStart { + timestamp: now, + message_id: id, + role, + model: model_str, + })); + } + OpenAiOutputItem::FunctionCall { + id, + name, + call_id, + .. + } => { + item_to_call.insert(id.clone(), call_id.clone()); + item_to_tool.insert(id.clone(), name.clone()); + chunks.push(Ok(StreamChunk::ToolCallStart { + timestamp: now, + tool_call_id: call_id, + tool_name: name, + parent_message_id: Some(id), + index: None, + provider_metadata: None, + model: model_str, + })); + } + _ => {} + }, + OpenAiStreamEvent::FunctionCallArgumentsDelta { + item_id, + delta, + .. + } => { + let tool_call_id = + item_to_call.get(&item_id).cloned().unwrap_or(item_id); + chunks.push(Ok(StreamChunk::ToolCallArgs { + timestamp: now, + tool_call_id, + delta, + args: None, + model: model_str, + })); + } + OpenAiStreamEvent::FunctionCallArgumentsDone { + item_id, + arguments, + .. 
+ } => { + let tool_call_id = item_to_call + .get(&item_id) + .cloned() + .unwrap_or(item_id.clone()); + let tool_name = + item_to_tool.get(&item_id).cloned().unwrap_or_default(); + let input: Option = + serde_json::from_str(&arguments).ok(); + chunks.push(Ok(StreamChunk::ToolCallEnd { + timestamp: now, + tool_call_id, + tool_name, + input, + result: None, + model: model_str, + })); + } + other => { + for chunk in convert_openai_event(other, &model) { + chunks.push(Ok(chunk)); + } + } + } + + return Some(( + futures_util::stream::iter(chunks), + (stream, String::new(), item_to_call, item_to_tool), + )); + } + Err(e) => { + eprintln!( + "Failed to parse OpenAI event: {} - {}", + e, + &data_str[..data_str.len().min(200)] + ); + } + } + } + + return None; + } + } + } + } + }, + ) + .flatten() +} + +/// Convert an OpenAI SSE event to one or more StreamChunks. +fn convert_openai_event(event: OpenAiStreamEvent, model: &str) -> Vec { + let now = chrono::Utc::now().timestamp_millis() as f64 / 1000.0; + let model_str = Some(model.to_string()); + + match event { + OpenAiStreamEvent::ResponseCreated { response } => { + vec![StreamChunk::RunStarted { + timestamp: now, + run_id: response.id, + thread_id: None, + model: model_str, + }] + } + + OpenAiStreamEvent::OutputItemAdded { item, .. } => match item { + OpenAiOutputItem::Message { id, role, .. } => { + vec![StreamChunk::TextMessageStart { + timestamp: now, + message_id: id, + role, + model: model_str, + }] + } + OpenAiOutputItem::FunctionCall { + id, + name, + call_id, + .. + } => { + vec![StreamChunk::ToolCallStart { + timestamp: now, + tool_call_id: call_id.clone(), + tool_name: name, + parent_message_id: Some(id), + index: None, + provider_metadata: None, + model: model_str, + }] + } + _ => vec![], + }, + + OpenAiStreamEvent::OutputTextDelta { item_id, delta, .. 
} => { + vec![StreamChunk::TextMessageContent { + timestamp: now, + message_id: item_id, + delta, + content: None, + model: model_str, + }] + } + + OpenAiStreamEvent::OutputTextDone { item_id, text, .. } => { + vec![StreamChunk::TextMessageContent { + timestamp: now, + message_id: item_id, + delta: String::new(), + content: Some(text), + model: model_str.clone(), + }] + } + + OpenAiStreamEvent::FunctionCallArgumentsDelta { + item_id, delta, .. + } => { + vec![StreamChunk::ToolCallArgs { + timestamp: now, + tool_call_id: item_id, + delta, + args: None, + model: model_str, + }] + } + + OpenAiStreamEvent::FunctionCallArgumentsDone { + item_id, + arguments, + .. + } => { + let input: Option = serde_json::from_str(&arguments).ok(); + vec![StreamChunk::ToolCallEnd { + timestamp: now, + tool_call_id: item_id, + tool_name: String::new(), + input, + result: None, + model: model_str, + }] + } + + OpenAiStreamEvent::ResponseCompleted { response } => { + let mut chunks = vec![]; + + // Emit text message end if we had text + // (This is a simplification — in production we'd track state) + + let finish_reason = response + .status + .as_deref() + .map(|s| match s { + "completed" => "stop".to_string(), + "failed" => "error".to_string(), + _ => s.to_string(), + }); + + let usage = response.usage.map(|u| Usage { + prompt_tokens: u.input_tokens, + completion_tokens: u.output_tokens, + total_tokens: u.total_tokens, + }); + + chunks.push(StreamChunk::RunFinished { + timestamp: now, + run_id: response.id, + finish_reason, + usage, + model: model_str, + }); + + chunks + } + + OpenAiStreamEvent::Error { code, message } => { + vec![StreamChunk::RunError { + timestamp: now, + run_id: None, + error: RunErrorData { + message, + code: Some(code), + }, + model: model_str, + }] + } + + OpenAiStreamEvent::Unknown => vec![], + OpenAiStreamEvent::ContentPartAdded { .. } => vec![], + } +} + +/// Extract text content from an OpenAI response object. 
+fn extract_text_from_response(response: &serde_json::Value) -> String { + // Try output array (Responses API format) + if let Some(output) = response.get("output").and_then(|o| o.as_array()) { + for item in output { + if let Some(content) = item.get("content").and_then(|c| c.as_array()) { + for part in content { + if let Some(text) = part.get("text").and_then(|t| t.as_str()) { + return text.to_string(); + } + } + } + } + } + + // Try output_text direct field + if let Some(text) = response.get("output_text").and_then(|t| t.as_str()) { + return text.to_string(); + } + + // Try choices (Chat Completions API format) + if let Some(choices) = response.get("choices").and_then(|c| c.as_array()) { + if let Some(first) = choices.first() { + if let Some(message) = first.get("message") { + if let Some(text) = message.get("content").and_then(|c| c.as_str()) { + return text.to_string(); + } + } + } + } + + String::new() +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + use futures_util::StreamExt; + + #[tokio::test] + async fn parse_openai_maps_item_id_to_call_id_across_events() { + let start = serde_json::json!({ + "type": "response.output_item.added", + "item": { + "type": "function_call", + "id": "item_1", + "name": "get_weather", + "arguments": "", + "status": "in_progress", + "call_id": "call_abc" + } + }) + .to_string(); + + let delta = serde_json::json!({ + "type": "response.function_call_arguments.delta", + "item_id": "item_1", + "delta": "{\"city\":\"San" + }) + .to_string(); + + let done = serde_json::json!({ + "type": "response.function_call_arguments.done", + "item_id": "item_1", + "arguments": "{\"city\":\"San Francisco\"}" + }) + .to_string(); + + let sse_1 = format!("data: {}\ndata: {}\n", start, delta); + let sse_2 = format!("data: {}", done); + + let stream = futures_util::stream::iter(vec![ + Ok(Bytes::from(sse_1)), + Ok(Bytes::from(sse_2)), + ]); + + let chunks = parse_openai_sse_stream(stream, "gpt-4o".to_string()) + .collect::>() + 
.await; + + assert_eq!(chunks.len(), 3); + + match &chunks[0] { + Ok(StreamChunk::ToolCallStart { tool_call_id, .. }) => assert_eq!(tool_call_id, "call_abc"), + other => panic!("unexpected chunk 0: {other:?}"), + } + + match &chunks[1] { + Ok(StreamChunk::ToolCallArgs { tool_call_id, .. }) => assert_eq!(tool_call_id, "call_abc"), + other => panic!("unexpected chunk 1: {other:?}"), + } + + match &chunks[2] { + Ok(StreamChunk::ToolCallEnd { tool_call_id, tool_name, .. }) => { + assert_eq!(tool_call_id, "call_abc"); + assert_eq!(tool_name, "get_weather"); + } + other => panic!("unexpected chunk 2: {other:?}"), + } + } +} diff --git a/packages/rust/tanstack-ai/src/chat.rs b/packages/rust/tanstack-ai/src/chat.rs new file mode 100644 index 00000000..58999d07 --- /dev/null +++ b/packages/rust/tanstack-ai/src/chat.rs @@ -0,0 +1,546 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use crate::adapter::TextAdapter; +use crate::error::AiResult; +use crate::messages::generate_message_id; +use crate::middleware::*; +use crate::tools::tool_calls::*; +use crate::types::*; + +/// Options for the chat function. +pub struct ChatOptions { + pub adapter: Arc, + pub messages: Vec, + pub system_prompts: Vec, + pub tools: Vec, + pub temperature: Option, + pub top_p: Option, + pub max_tokens: Option, + pub metadata: Option, + pub model_options: Option, + pub agent_loop_strategy: Option, + pub conversation_id: Option, + pub middleware: Vec>, + pub stream: bool, + pub output_schema: Option, +} + +/// Result from the chat engine. +pub enum ChatResult { + /// Collected chunks (used internally; callers get the stream directly via chat_stream). + Chunks(Vec), + /// Collected text content. + Text(String), + /// Structured output. + Structured(serde_json::Value), +} + +/// Chat with streaming - returns a Vec of all chunks after the full agentic loop completes. +/// +/// This is the main entry point that handles: +/// 1. Streaming text (single or multi-iteration) +/// 2. 
Automatic tool execution when finish_reason is 'tool_calls' +/// 3. Agent loop with configurable strategy +/// 4. Middleware lifecycle hooks +pub async fn chat(options: ChatOptions) -> AiResult { + if options.output_schema.is_some() { + return run_agentic_structured_output(options).await; + } + + if !options.stream { + return run_non_streaming_text(options).await; + } + + run_streaming_text(options).await +} + +/// Run the full agentic streaming loop, collecting all chunks. +async fn run_streaming_text(options: ChatOptions) -> AiResult { + let mut engine = TextEngine::new(options); + let chunks = engine.run().await?; + Ok(ChatResult::Chunks(chunks)) +} + +/// Run non-streaming: collect all chunks, extract text content. +async fn run_non_streaming_text(options: ChatOptions) -> AiResult { + let mut engine = TextEngine::new(options); + let chunks = engine.run().await?; + + let mut content = String::new(); + for chunk in &chunks { + if let StreamChunk::TextMessageContent { + delta, content: full, .. + } = chunk + { + if let Some(full_content) = full { + content = full_content.clone(); + } else { + content.push_str(delta); + } + } + } + + Ok(ChatResult::Text(content)) +} + +/// Run agentic structured output: +/// 1. Execute the full agentic loop (with tools) +/// 2. Once complete, call adapter.structured_output with the conversation context +/// 3. 
Return the structured result +async fn run_agentic_structured_output(options: ChatOptions) -> AiResult { + let schema = options.output_schema.clone() + .expect("run_agentic_structured_output called without output_schema"); + let adapter = options.adapter.clone(); + let model_name = adapter.model().to_string(); + + // Run the agentic loop without the schema + let mut engine = TextEngine::new(ChatOptions { + output_schema: None, + adapter: options.adapter.clone(), + messages: options.messages, + system_prompts: options.system_prompts.clone(), + tools: options.tools, + temperature: options.temperature, + top_p: options.top_p, + max_tokens: options.max_tokens, + metadata: options.metadata, + model_options: options.model_options.clone(), + agent_loop_strategy: options.agent_loop_strategy, + conversation_id: options.conversation_id, + middleware: options.middleware, + stream: true, + }); + + // Consume the agentic loop + let _chunks = engine.run().await?; + + // Get final messages + let final_messages = engine.messages.clone(); + let system_prompts = engine.system_prompts.clone(); + let temperature = engine.temperature; + let top_p = engine.top_p; + let max_tokens = engine.max_tokens; + + let structured_options = StructuredOutputOptions { + chat_options: TextOptions { + model: model_name, + messages: final_messages, + tools: Vec::new(), + system_prompts, + temperature, + top_p, + max_tokens, + output_schema: Some(schema.clone()), + ..Default::default() + }, + output_schema: schema, + }; + + let result = adapter.structured_output(&structured_options).await?; + Ok(ChatResult::Structured(result.data)) +} + +/// Core text engine with full agentic loop. 
+/// +/// Manages: +/// - Multi-iteration streaming with tool execution +/// - Middleware lifecycle hooks +/// - Agent loop strategy +/// - Tool call accumulation and execution +pub struct TextEngine { + adapter: Arc, + pub messages: Vec, + pub system_prompts: Vec, + pub tools: Vec, + pub temperature: Option, + pub top_p: Option, + pub max_tokens: Option, + metadata: Option, + model_options: Option, + loop_strategy: AgentLoopStrategy, + tool_call_manager: ToolCallManager, + middleware_runner: MiddlewareRunner, + request_id: String, + stream_id: String, + conversation_id: Option, + iteration_count: u32, + last_finish_reason: Option, +} + +impl TextEngine { + pub fn new(options: ChatOptions) -> Self { + let loop_strategy = options + .agent_loop_strategy + .unwrap_or_else(|| max_iterations(5)); + + let tool_manager = ToolCallManager::new(); + let middleware_runner = MiddlewareRunner::new(options.middleware); + + Self { + adapter: options.adapter, + messages: options.messages, + system_prompts: options.system_prompts, + tools: options.tools, + temperature: options.temperature, + top_p: options.top_p, + max_tokens: options.max_tokens, + metadata: options.metadata, + model_options: options.model_options, + loop_strategy, + tool_call_manager: tool_manager, + middleware_runner, + request_id: generate_message_id("chat"), + stream_id: generate_message_id("stream"), + conversation_id: options.conversation_id, + iteration_count: 0, + last_finish_reason: None, + } + } + + /// Run the full agentic loop, returning all chunks from all iterations. 
    pub async fn run(&mut self) -> AiResult<Vec<StreamChunk>> {
        let mut all_chunks = Vec::new();
        let stream_start = std::time::Instant::now();

        // Build middleware context shared across all iterations; per-iteration
        // fields (phase, iteration, message id, accumulated content) are
        // overwritten inside the loop.
        let mut middleware_ctx = ChatMiddlewareContext {
            request_id: self.request_id.clone(),
            stream_id: self.stream_id.clone(),
            conversation_id: self.conversation_id.clone(),
            phase: ChatMiddlewarePhase::Init,
            iteration: 0,
            chunk_index: 0,
            context: None,
            provider: self.adapter.name().to_string(),
            model: self.adapter.model().to_string(),
            source: "server".to_string(),
            streaming: true,
            system_prompts: self.system_prompts.clone(),
            tool_names: Some(self.tools.iter().map(|t| t.name.clone()).collect()),
            options: None,
            model_options: None,
            message_count: self.messages.len(),
            has_tools: !self.tools.is_empty(),
            current_message_id: None,
            accumulated_content: String::new(),
        };

        // Run onStart middleware
        self.middleware_runner.run_on_start(&middleware_ctx);

        // Agent loop
        loop {
            // Check loop strategy; the strategy sees a snapshot of the current
            // loop state and returns false to stop iterating.
            let state = AgentLoopState {
                iteration_count: self.iteration_count,
                messages: self.messages.clone(),
                finish_reason: self.last_finish_reason.clone(),
            };

            if !(self.loop_strategy)(&state) {
                break;
            }

            // Update middleware context for this iteration
            let message_id = generate_message_id("msg");
            middleware_ctx.phase = ChatMiddlewarePhase::BeforeModel;
            middleware_ctx.iteration = self.iteration_count;
            middleware_ctx.current_message_id = Some(message_id.clone());
            middleware_ctx.accumulated_content = String::new();

            self.middleware_runner.run_on_iteration(
                &middleware_ctx,
                &IterationInfo {
                    iteration: self.iteration_count,
                    message_id: message_id.clone(),
                },
            );

            // Run onConfig middleware — middleware may rewrite messages,
            // tools, and sampling parameters before each model call.
            let config = self.build_middleware_config();
            let transformed = self.middleware_runner.run_on_config(&middleware_ctx, config);

            // Build text options for this iteration from middleware-transformed config
            let text_options = self.build_text_options_from_config(&transformed);

            // Stream from adapter
            middleware_ctx.phase = ChatMiddlewarePhase::ModelStream;
            let mut stream = self.adapter.chat_stream(&text_options).await?;
            let mut accumulated_content = String::new();
            let mut iteration_chunks = Vec::new();

            // Process chunks from this iteration
            use futures_util::StreamExt;
            while let Some(result) = stream.next().await {
                match result {
                    Ok(chunk) => {
                        // Run onChunk middleware — a single input chunk may be
                        // expanded into several output chunks.
                        let output_chunks =
                            self.middleware_runner.run_on_chunk(&middleware_ctx, chunk);

                        for output_chunk in output_chunks {
                            // Track state (text accumulation, tool calls,
                            // finish reason) before exposing to middleware.
                            self.process_chunk(&output_chunk, &mut accumulated_content);

                            // Update middleware context
                            middleware_ctx.accumulated_content = accumulated_content.clone();
                            middleware_ctx.chunk_index += 1;

                            iteration_chunks.push(output_chunk);
                        }
                    }
                    Err(e) => {
                        let error_info = ErrorInfo {
                            error: e.to_string(),
                            duration_ms: stream_start.elapsed().as_millis(),
                        };
                        self.middleware_runner
                            .run_on_error(&middleware_ctx, &error_info);
                        return Err(e);
                    }
                }
            }

            all_chunks.extend(iteration_chunks);

            // Check if we need to execute tools
            if self.last_finish_reason.as_deref() == Some("tool_calls")
                && self.tool_call_manager.has_tool_calls()
            {
                middleware_ctx.phase = ChatMiddlewarePhase::BeforeTools;

                // Execute tools
                let tool_calls = self.tool_call_manager.tool_calls();
                let results = self.execute_tool_calls(&tool_calls).await?;

                middleware_ctx.phase = ChatMiddlewarePhase::AfterTools;

                // Build tool result chunks and update messages
                let tool_result_chunks =
                    self.build_tool_result_chunks(&results, &message_id);

                // Add assistant message with tool calls
                self.messages.push(ModelMessage {
                    role: MessageRole::Assistant,
                    content: if accumulated_content.is_empty() {
                        MessageContent::Null
                    } else {
                        MessageContent::Text(accumulated_content.clone())
                    },
                    name: None,
                    tool_calls: Some(tool_calls),
                    tool_call_id: None,
                });

                // Add tool result messages
                for result in &results {
                    let content_str = serde_json::to_string(&result.result)
                        .unwrap_or_else(|_| result.result.to_string());
                    self.messages.push(ModelMessage {
                        role: MessageRole::Tool,
                        content: MessageContent::Text(content_str),
                        name: None,
                        tool_calls: None,
                        tool_call_id: Some(result.tool_call_id.clone()),
                    });
                }

                all_chunks.extend(tool_result_chunks);

                // Clear tool call manager for next iteration
                // NOTE(review): iteration_count only advances on tool
                // iterations, so a max_iterations strategy bounds tool rounds,
                // not total model calls — confirm this is intended.
                self.tool_call_manager.clear();
                self.iteration_count += 1;
                self.last_finish_reason = None;

                // Continue the loop
                continue;
            }

            // finish_reason is 'tool_calls' but no tool calls were accumulated —
            // the model may have more to say, continue the loop
            if self.last_finish_reason.as_deref() == Some("tool_calls") {
                self.iteration_count += 1;
                self.last_finish_reason = None;
                continue;
            }

            // No tool calls or finish reason is not tool_calls — we're done
            break;
        }

        // Run onFinish middleware
        let finish_info = FinishInfo {
            finish_reason: self.last_finish_reason.clone(),
            duration_ms: stream_start.elapsed().as_millis(),
            content: accumulated_content_from_chunks(&all_chunks),
            usage: None,
        };
        self.middleware_runner
            .run_on_finish(&middleware_ctx, &finish_info);

        Ok(all_chunks)
    }

    /// Build the adapter-facing TextOptions for one iteration from the
    /// (possibly middleware-transformed) config.
    fn build_text_options_from_config(&self, config: &ChatMiddlewareConfig) -> TextOptions {
        TextOptions {
            model: self.adapter.model().to_string(),
            messages: config.messages.clone(),
            tools: config.tools.clone(),
            system_prompts: config.system_prompts.clone(),
            temperature: config.temperature,
            top_p: config.top_p,
            max_tokens: config.max_tokens,
            metadata: config.metadata.clone(),
            model_options: config.model_options.clone(),
            output_schema: None,
            conversation_id: self.conversation_id.clone(),
            agent_loop_strategy: None,
        }
    }
self.system_prompts.clone(), + tools: self.tools.clone(), + temperature: self.temperature, + top_p: self.top_p, + max_tokens: self.max_tokens, + metadata: self.metadata.clone(), + model_options: self.model_options.clone(), + } + } + + /// Process a single chunk, updating internal state. + fn process_chunk(&mut self, chunk: &StreamChunk, accumulated_content: &mut String) { + match chunk { + StreamChunk::TextMessageContent { delta, content, .. } => { + if let Some(full) = content { + *accumulated_content = full.clone(); + } else { + accumulated_content.push_str(delta); + } + } + + StreamChunk::ToolCallStart { .. } => { + self.tool_call_manager.add_start_event(chunk); + } + + StreamChunk::ToolCallArgs { .. } => { + self.tool_call_manager.add_args_event(chunk); + } + + StreamChunk::ToolCallEnd { .. } => { + self.tool_call_manager.complete_tool_call(chunk); + } + + StreamChunk::RunFinished { finish_reason, .. } => { + self.last_finish_reason = finish_reason.clone(); + } + + StreamChunk::RunError { error, .. } => { + self.last_finish_reason = Some("error".to_string()); + eprintln!("Run error: {}", error.message); + } + + _ => {} + } + } + + /// Execute all pending tool calls. 
+ async fn execute_tool_calls(&self, tool_calls: &[ToolCall]) -> AiResult> { + let approvals: HashMap = HashMap::new(); + let client_results: HashMap = HashMap::new(); + + let result = crate::tools::tool_calls::execute_tool_calls( + tool_calls, + &self.tools, + &approvals, + &client_results, + None, + ) + .await?; + + // Combine server results with error results for client tools + let mut all_results = result.results; + + // Handle tools that need client execution (mark as error for server-side) + for client_tool in result.needs_client_execution { + all_results.push(ToolResult { + tool_call_id: client_tool.tool_call_id, + tool_name: client_tool.tool_name, + result: serde_json::json!({"error": "Client-side tool execution not available"}), + state: ToolResultOutputState::Error, + duration_ms: None, + }); + } + + // Handle tools that need approval (mark as error) + for approval in result.needs_approval { + all_results.push(ToolResult { + tool_call_id: approval.tool_call_id, + tool_name: approval.tool_name, + result: serde_json::json!({"error": "Tool requires approval"}), + state: ToolResultOutputState::Error, + duration_ms: None, + }); + } + + Ok(all_results) + } + + /// Build TOOL_CALL_END chunks for tool results. + fn build_tool_result_chunks( + &self, + results: &[ToolResult], + _message_id: &str, + ) -> Vec { + let now = chrono::Utc::now().timestamp_millis() as f64 / 1000.0; + let model = Some(self.adapter.model().to_string()); + + results + .iter() + .map(|result| { + let result_str = serde_json::to_string(&result.result).ok(); + StreamChunk::ToolCallEnd { + timestamp: now, + tool_call_id: result.tool_call_id.clone(), + tool_name: result.tool_name.clone(), + input: None, + result: result_str, + model: model.clone(), + } + }) + .collect() + } + + /// Get accumulated content from the last text message. 
+ pub fn content(&self) -> String { + self.messages + .iter() + .rev() + .find(|m| m.role == MessageRole::Assistant) + .and_then(|m| m.content.as_str()) + .map(|s| s.to_string()) + .unwrap_or_default() + } +} + +/// Extract accumulated text content from all chunks. +fn accumulated_content_from_chunks(chunks: &[StreamChunk]) -> String { + let mut content = String::new(); + for chunk in chunks { + if let StreamChunk::TextMessageContent { + delta, content: full, .. + } = chunk + { + if let Some(full_content) = full { + content = full_content.clone(); + } else { + content.push_str(delta); + } + } + } + content +} diff --git a/packages/rust/tanstack-ai/src/client/connection_adapters.rs b/packages/rust/tanstack-ai/src/client/connection_adapters.rs new file mode 100644 index 00000000..545388ad --- /dev/null +++ b/packages/rust/tanstack-ai/src/client/connection_adapters.rs @@ -0,0 +1,586 @@ +use bytes::Bytes; +use futures_core::Stream; +use futures_util::StreamExt; +use reqwest::Client; +use std::collections::HashMap; +use std::pin::Pin; +use std::sync::Arc; +use tokio::sync::{broadcast, Mutex, RwLock}; + +use crate::error::{AiError, AiResult}; +use crate::stream::StreamProcessor; +use crate::types::*; + +/// Connection type for the chat client. +pub enum ConnectionAdapter { + /// Server-Sent Events via fetch. + ServerSentEvents { url: String, headers: HashMap }, + /// HTTP streaming. + HttpStream { url: String, headers: HashMap }, + /// Custom stream provider. + Custom { + provider: Box) -> Pin> + Send>> + Send + Sync>, + }, +} + +/// Subscription handle for chat client events. +pub type SubscriptionId = u64; + +/// Chat state that clients can subscribe to. 
+#[derive(Debug, Clone)] +pub struct ChatState { + pub messages: Vec, + pub is_loading: bool, + pub error: Option, + pub accumulated_content: String, +} + +impl Default for ChatState { + fn default() -> Self { + Self { + messages: Vec::new(), + is_loading: false, + error: None, + accumulated_content: String::new(), + } + } +} + +/// Headless chat client with framework-agnostic state management. +/// +/// Supports connection adapters (SSE, HTTP stream, custom) and +/// provides subscription-based architecture for real-time updates. +pub struct ChatClient { + state: Arc>, + connection: ConnectionAdapter, + processor: Arc>, + next_subscription_id: Arc>, + subscribers: Arc>>>, + client: Client, +} + +impl ChatClient { + /// Create a new chat client with a connection adapter. + pub fn new(connection: ConnectionAdapter) -> Self { + Self { + state: Arc::new(RwLock::new(ChatState::default())), + connection, + processor: Arc::new(Mutex::new(StreamProcessor::new())), + next_subscription_id: Arc::new(Mutex::new(0)), + subscribers: Arc::new(RwLock::new(HashMap::new())), + client: Client::new(), + } + } + + /// Subscribe to state changes. Returns a broadcast receiver. + pub async fn subscribe(&self) -> broadcast::Receiver { + let mut id_guard = self.next_subscription_id.lock().await; + *id_guard += 1; + + let (tx, rx) = broadcast::channel(100); + self.subscribers.write().await.insert(*id_guard, tx); + + rx + } + + /// Get the current state snapshot. + pub async fn get_state(&self) -> ChatState { + self.state.read().await.clone() + } + + /// Send a message and stream the response. 
+ pub async fn send(&self, content: impl Into) -> AiResult<()> { + let user_content = content.into(); + + // Add user message + { + let mut state = self.state.write().await; + state.messages.push(UiMessage { + id: generate_message_id("msg"), + role: UiMessageRole::User, + parts: vec![MessagePart::Text { + content: user_content.clone(), + metadata: None, + }], + created_at: Some(chrono::Utc::now()), + }); + state.is_loading = true; + state.error = None; + } + + self.notify_subscribers().await; + + // Build messages for the provider + let messages = { + let state = self.state.read().await; + crate::messages::ui_messages_to_model_messages(&state.messages) + }; + + // Stream the response + match &self.connection { + ConnectionAdapter::ServerSentEvents { url, headers } => { + self.stream_via_sse(url, headers, messages).await + } + ConnectionAdapter::HttpStream { url, headers } => { + self.stream_via_http(url, headers, messages).await + } + ConnectionAdapter::Custom { provider } => { + let mut stream = provider(messages); + self.process_stream(&mut stream).await + } + } + } + + /// Stop the current generation. + pub async fn stop(&self) { + let mut state = self.state.write().await; + state.is_loading = false; + self.notify_subscribers().await; + } + + /// Clear all messages. 
+ pub async fn clear(&self) { + let mut state = self.state.write().await; + state.messages.clear(); + state.accumulated_content.clear(); + state.error = None; + self.notify_subscribers().await; + } + + async fn stream_via_sse( + &self, + url: &str, + headers: &HashMap, + messages: Vec, + ) -> AiResult<()> { + let body = serde_json::json!({ "messages": messages }); + + let mut request = self.client.post(url).json(&body); + for (key, value) in headers { + request = request.header(key.as_str(), value.as_str()); + } + + let response = request.send().await?; + if !response.status().is_success() { + let status = response.status(); + let text = response.text().await.unwrap_or_default(); + let mut state = self.state.write().await; + state.is_loading = false; + state.error = Some(format!("HTTP {}: {}", status, text)); + self.notify_subscribers().await; + return Err(AiError::Provider(format!("HTTP {}: {}", status, text))); + } + + let byte_stream = response.bytes_stream(); + + // Parse SSE events from byte stream + let chunk_stream = parse_sse_to_chunks(byte_stream); + futures_util::pin_mut!(chunk_stream); + + self.process_stream(&mut chunk_stream).await + } + + async fn stream_via_http( + &self, + url: &str, + headers: &HashMap, + messages: Vec, + ) -> AiResult<()> { + let body = serde_json::json!({ "messages": messages }); + + let mut request = self.client.post(url).json(&body); + for (key, value) in headers { + request = request.header(key.as_str(), value.as_str()); + } + + let response = request.send().await?; + if !response.status().is_success() { + let status = response.status(); + let text = response.text().await.unwrap_or_default(); + return Err(AiError::Provider(format!("HTTP {}: {}", status, text))); + } + + let byte_stream = response.bytes_stream(); + + // Parse NDJSON events + let chunk_stream = parse_ndjson_to_chunks(byte_stream); + futures_util::pin_mut!(chunk_stream); + + self.process_stream(&mut chunk_stream).await + } + + async fn process_stream(&self, 
stream: &mut Pin> + Send>>) -> AiResult<()> + { + let mut processor = self.processor.lock().await; + + while let Some(result) = stream.next().await { + match result { + Ok(chunk) => { + if let Some(processed) = processor.process_chunk(chunk) { + self.apply_chunk(&processed).await; + } + } + Err(e) => { + let mut state = self.state.write().await; + state.is_loading = false; + state.error = Some(e.to_string()); + self.notify_subscribers().await; + return Err(e); + } + } + } + + let mut state = self.state.write().await; + state.is_loading = false; + self.notify_subscribers().await; + + Ok(()) + } + + async fn apply_chunk(&self, chunk: &StreamChunk) { + let mut state = self.state.write().await; + + match chunk { + StreamChunk::TextMessageContent { delta, content, .. } => { + if let Some(full) = content { + state.accumulated_content = full.clone(); + } else { + state.accumulated_content.push_str(delta); + } + + // Update or create the assistant message + let new_content = state.accumulated_content.clone(); + if let Some(last) = state.messages.last_mut() { + if last.role == UiMessageRole::Assistant { + if let Some(MessagePart::Text { content, .. }) = last.parts.last_mut() { + *content = new_content; + } + } + } + } + + StreamChunk::TextMessageStart { role, .. } => { + let ui_role = match role.as_str() { + "system" => UiMessageRole::System, + "assistant" => UiMessageRole::Assistant, + _ => UiMessageRole::Assistant, + }; + state.messages.push(UiMessage { + id: generate_message_id("msg"), + role: ui_role, + parts: vec![MessagePart::Text { + content: String::new(), + metadata: None, + }], + created_at: Some(chrono::Utc::now()), + }); + } + + StreamChunk::ToolCallStart { tool_call_id, tool_name, .. 
} => { + if let Some(last) = state.messages.last_mut() { + if last.role == UiMessageRole::Assistant { + last.parts.push(MessagePart::ToolCall { + id: tool_call_id.clone(), + name: tool_name.clone(), + arguments: String::new(), + state: ToolCallState::AwaitingInput, + approval: None, + output: None, + }); + } + } + } + + StreamChunk::ToolCallArgs { tool_call_id, delta, .. } => { + if let Some(last) = state.messages.last_mut() { + for part in &mut last.parts { + if let MessagePart::ToolCall { id, arguments, state: tc_state, .. } = part { + if id == tool_call_id { + arguments.push_str(delta); + *tc_state = ToolCallState::InputStreaming; + break; + } + } + } + } + } + + StreamChunk::ToolCallEnd { tool_call_id, input, result, .. } => { + if let Some(last) = state.messages.last_mut() { + for part in &mut last.parts { + if let MessagePart::ToolCall { id, state: tc_state, arguments, output, .. } = part { + if id == tool_call_id { + *tc_state = ToolCallState::InputComplete; + if let Some(input_val) = input { + *arguments = serde_json::to_string(input_val).unwrap_or_default(); + } + if let Some(result_str) = result { + *output = serde_json::from_str(result_str).ok(); + } + break; + } + } + } + } + } + + StreamChunk::RunError { error, .. } => { + state.is_loading = false; + state.error = Some(error.message.clone()); + } + + _ => {} + } + + self.notify_subscribers().await; + } + + async fn notify_subscribers(&self) { + let state = self.state.read().await; + let subscribers = self.subscribers.read().await; + + // Remove closed subscribers + let mut to_remove = Vec::new(); + for (id, tx) in subscribers.iter() { + if tx.send(state.clone()).is_err() { + to_remove.push(*id); + } + } + + if !to_remove.is_empty() { + drop(subscribers); + let mut subs = self.subscribers.write().await; + for id in to_remove { + subs.remove(&id); + } + } + } +} + +/// Generate a unique message ID. 
+fn generate_message_id(prefix: &str) -> String { + format!( + "{}-{}-{}", + prefix, + chrono::Utc::now().timestamp_millis(), + &uuid::Uuid::new_v4().to_string()[..8] + ) +} + +/// Parse SSE byte stream into StreamChunk events. +fn parse_sse_to_chunks( + stream: S, +) -> Pin> + Send>> +where + S: Stream> + Send + 'static, +{ + use futures_util::StreamExt; + + futures_util::stream::unfold( + (Box::pin(stream), String::new()), + |(mut stream, mut line_buf)| async move { + loop { + match stream.next().await { + Some(Ok(bytes)) => { + let text = String::from_utf8_lossy(&bytes); + line_buf.push_str(&text); + + let mut chunks = Vec::new(); + let mut processed_to = 0; + + while let Some(newline_pos) = line_buf[processed_to..].find('\n') { + let abs_pos = processed_to + newline_pos; + let line = line_buf[processed_to..abs_pos].trim(); + processed_to = abs_pos + 1; + + if line.is_empty() || line.starts_with(':') { + continue; + } + + if let Some(data_str) = line.strip_prefix("data: ") { + if data_str == "[DONE]" { + continue; + } + match serde_json::from_str::(data_str) { + Ok(chunk) => chunks.push(Ok(chunk)), + Err(e) => { + chunks.push(Err(AiError::Stream(format!( + "Failed to parse SSE chunk: {}", + e + )))); + } + } + } + } + + line_buf = line_buf[processed_to..].to_string(); + + if !chunks.is_empty() { + return Some((futures_util::stream::iter(chunks), (stream, line_buf))); + } + } + Some(Err(e)) => { + return Some(( + futures_util::stream::iter(vec![Err(AiError::Http(e))]), + (stream, line_buf), + )); + } + None => { + let line = line_buf.trim(); + if line.is_empty() { + return None; + } + + if let Some(data_str) = line.strip_prefix("data: ") { + if data_str == "[DONE]" { + return None; + } + let parsed = serde_json::from_str::(data_str) + .map_err(|e| { + AiError::Stream(format!("Failed to parse SSE chunk: {}", e)) + }); + return Some((futures_util::stream::iter(vec![parsed]), (stream, String::new()))); + } + + return Some(( + 
futures_util::stream::iter(vec![Err(AiError::Stream( + "Failed to parse SSE chunk: missing data prefix".to_string(), + ))]), + (stream, String::new()), + )); + } + } + } + }, + ) + .flatten() + .boxed() +} + +/// Parse NDJSON byte stream into StreamChunk events. +fn parse_ndjson_to_chunks( + stream: S, +) -> Pin> + Send>> +where + S: Stream> + Send + 'static, +{ + use futures_util::StreamExt; + + futures_util::stream::unfold( + (Box::pin(stream), String::new()), + |(mut stream, mut line_buf)| async move { + loop { + match stream.next().await { + Some(Ok(bytes)) => { + let text = String::from_utf8_lossy(&bytes); + line_buf.push_str(&text); + + let mut chunks = Vec::new(); + let mut processed_to = 0; + + while let Some(newline_pos) = line_buf[processed_to..].find('\n') { + let abs_pos = processed_to + newline_pos; + let line = line_buf[processed_to..abs_pos].trim(); + processed_to = abs_pos + 1; + + if line.is_empty() { + continue; + } + + match serde_json::from_str::(line) { + Ok(chunk) => chunks.push(Ok(chunk)), + Err(e) => { + chunks.push(Err(AiError::Stream(format!( + "Failed to parse NDJSON chunk: {}", + e + )))); + } + } + } + + line_buf = line_buf[processed_to..].to_string(); + + if !chunks.is_empty() { + return Some((futures_util::stream::iter(chunks), (stream, line_buf))); + } + } + Some(Err(e)) => { + return Some(( + futures_util::stream::iter(vec![Err(AiError::Http(e))]), + (stream, line_buf), + )); + } + None => { + let line = line_buf.trim(); + if line.is_empty() { + return None; + } + + let parsed = serde_json::from_str::(line) + .map_err(|e| AiError::Stream(format!("Failed to parse NDJSON chunk: {}", e))); + return Some((futures_util::stream::iter(vec![parsed]), (stream, String::new()))); + } + } + } + }, + ) + .flatten() + .boxed() +} + +#[cfg(test)] +mod tests { + use super::*; + use futures_util::StreamExt; + + #[tokio::test] + async fn parse_sse_handles_final_line_without_newline() { + let json = serde_json::to_string(&StreamChunk::RunStarted { + 
timestamp: 1.0, + run_id: "run_1".to_string(), + thread_id: None, + model: Some("test-model".to_string()), + }) + .expect("serialize stream chunk"); + + let sse = format!("data: {}", json); + let stream = futures_util::stream::iter(vec![Ok(Bytes::from(sse))]); + + let parsed = parse_sse_to_chunks(stream).collect::>().await; + assert_eq!(parsed.len(), 1); + + match &parsed[0] { + Ok(StreamChunk::RunStarted { run_id, .. }) => assert_eq!(run_id, "run_1"), + other => panic!("unexpected parsed chunk: {other:?}"), + } + } + + #[tokio::test] + async fn parse_ndjson_handles_split_final_line_without_newline() { + let json = serde_json::to_string(&StreamChunk::RunFinished { + timestamp: 2.0, + run_id: "run_2".to_string(), + finish_reason: Some("stop".to_string()), + usage: None, + model: Some("test-model".to_string()), + }) + .expect("serialize stream chunk"); + + let split = json.len() / 2; + let stream = futures_util::stream::iter(vec![ + Ok(Bytes::from(json[..split].to_string())), + Ok(Bytes::from(json[split..].to_string())), + ]); + + let parsed = parse_ndjson_to_chunks(stream).collect::>().await; + assert_eq!(parsed.len(), 1); + + match &parsed[0] { + Ok(StreamChunk::RunFinished { run_id, finish_reason, .. }) => { + assert_eq!(run_id, "run_2"); + assert_eq!(finish_reason.as_deref(), Some("stop")); + } + other => panic!("unexpected parsed chunk: {other:?}"), + } + } +} diff --git a/packages/rust/tanstack-ai/src/client/mod.rs b/packages/rust/tanstack-ai/src/client/mod.rs new file mode 100644 index 00000000..cf1a88fc --- /dev/null +++ b/packages/rust/tanstack-ai/src/client/mod.rs @@ -0,0 +1,3 @@ +pub mod connection_adapters; + +pub use connection_adapters::*; diff --git a/packages/rust/tanstack-ai/src/error.rs b/packages/rust/tanstack-ai/src/error.rs new file mode 100644 index 00000000..562a72cf --- /dev/null +++ b/packages/rust/tanstack-ai/src/error.rs @@ -0,0 +1,60 @@ +use thiserror::Error; + +/// Result type alias using `AiError`. 
/// Result type alias using `AiError`.
pub type AiResult<T> = Result<T, AiError>;

/// Errors that can occur in the TanStack AI SDK.
#[derive(Error, Debug)]
pub enum AiError {
    /// Error from the underlying AI provider.
    #[error("provider error: {0}")]
    Provider(String),

    /// HTTP request error.
    #[error("http error: {0}")]
    Http(#[from] reqwest::Error),

    /// JSON serialization/deserialization error.
    #[error("json error: {0}")]
    Json(#[from] serde_json::Error),

    /// Tool execution error.
    #[error("tool execution error: {0}")]
    ToolExecution(String),

    /// Unknown tool requested by the model.
    #[error("unknown tool: {0}")]
    UnknownTool(String),

    /// Input validation error.
    #[error("validation error: {0}")]
    Validation(String),

    /// Stream processing error.
    #[error("stream error: {0}")]
    Stream(String),

    /// Request was aborted.
    #[error("request aborted: {0}")]
    Aborted(String),

    /// Agent loop exceeded maximum iterations.
    #[error("agent loop exceeded max iterations ({0})")]
    MaxIterationsExceeded(u32),

    /// Generic error.
    #[error("{0}")]
    Other(String),
}

/// Allow plain `String`s to be used as ad-hoc errors (mapped to `Other`).
impl From<String> for AiError {
    fn from(s: String) -> Self {
        AiError::Other(s)
    }
}

/// Allow `&str` literals to be used as ad-hoc errors (mapped to `Other`).
impl From<&str> for AiError {
    fn from(s: &str) -> Self {
        AiError::Other(s.to_string())
    }
}
- [`OpenAiTextAdapter`] — OpenAI Responses API +//! - [`AnthropicTextAdapter`] — Anthropic Messages API +//! - [`GeminiTextAdapter`] — Google Gemini API +//! +//! ### Chat +//! The [`chat`] function is the main entry point for text generation: +//! - Streaming with automatic tool execution (agentic loop) +//! - Non-streaming (collected text) +//! - Structured output with schema validation +//! +//! ### Tools +//! Tools enable the model to interact with external systems: +//! - [`Tool`] — Tool definition with schemas and execute function +//! - [`ToolDefinition`] — Builder for creating tools +//! - [`ToolRegistry`] — Registry for managing available tools +//! +//! ### Streaming +//! The streaming system uses the AG-UI protocol with typed events: +//! - [`StreamChunk`] — Union of all AG-UI event types +//! - [`StreamProcessor`] — State machine for processing streams +//! - Chunk strategies for controlling emission frequency +//! +//! ### Middleware +//! The middleware system allows observing and transforming chat behavior: +//! - [`ChatMiddleware`] — Trait for implementing middleware +//! - [`MiddlewareRunner`] — Composes multiple middlewares +//! +//! ### Chat Client +//! The [`ChatClient`] provides headless state management: +//! - Connection adapters (SSE, HTTP stream, custom) +//! - Subscription-based state updates +//! - Framework-agnostic design + +pub mod adapter; +pub mod adapters; +pub mod chat; +pub mod client; +pub mod error; +pub mod messages; +pub mod middleware; +pub mod stream; +pub mod stream_response; +pub mod tools; +pub mod types; + +// Re-export main types and functions +pub use adapter::*; +pub use adapters::*; +pub use chat::*; +pub use client::*; +pub use error::*; +pub use messages::*; +pub use middleware::*; +pub use stream::*; +pub use stream_response::*; +pub use tools::*; +pub use types::*; + +/// Convenience function to create a chat with an OpenAI model. 
+pub fn openai_text( + model: impl Into, + api_key: impl Into, +) -> OpenAiTextAdapter { + OpenAiTextAdapter::new(model, api_key) +} + +/// Convenience function to create a chat with an Anthropic model. +pub fn anthropic_text( + model: impl Into, + api_key: impl Into, +) -> AnthropicTextAdapter { + AnthropicTextAdapter::new(model, api_key) +} + +/// Convenience function to create a chat with a Gemini model. +pub fn gemini_text( + model: impl Into, + api_key: impl Into, +) -> GeminiTextAdapter { + GeminiTextAdapter::new(model, api_key) +} + +/// Create adapter from environment variables. +/// +/// Looks for provider-specific API keys: +/// - OpenAI: `OPENAI_API_KEY` +/// - Anthropic: `ANTHROPIC_API_KEY` +/// - Gemini: `GEMINI_API_KEY` or `GOOGLE_API_KEY` +pub fn openai_text_from_env(model: impl Into) -> AiResult { + let api_key = std::env::var("OPENAI_API_KEY") + .map_err(|_| AiError::Other("OPENAI_API_KEY not set".to_string()))?; + Ok(OpenAiTextAdapter::new(model, api_key)) +} + +/// Create Anthropic adapter from environment variables. +pub fn anthropic_text_from_env(model: impl Into) -> AiResult { + let api_key = std::env::var("ANTHROPIC_API_KEY") + .map_err(|_| AiError::Other("ANTHROPIC_API_KEY not set".to_string()))?; + Ok(AnthropicTextAdapter::new(model, api_key)) +} + +/// Create Gemini adapter from environment variables. +pub fn gemini_text_from_env(model: impl Into) -> AiResult { + let api_key = std::env::var("GEMINI_API_KEY") + .or_else(|_| std::env::var("GOOGLE_API_KEY")) + .map_err(|_| AiError::Other("GEMINI_API_KEY or GOOGLE_API_KEY not set".to_string()))?; + Ok(GeminiTextAdapter::new(model, api_key)) +} + +/// Detect image MIME type from bytes. 
+pub fn detect_image_mime_type(data: &[u8]) -> &'static str { + if data.starts_with(&[0x89, 0x50, 0x4E, 0x47]) { + "image/png" + } else if data.starts_with(&[0xFF, 0xD8, 0xFF]) { + "image/jpeg" + } else if data.starts_with(b"GIF87a") || data.starts_with(b"GIF89a") { + "image/gif" + } else if data.starts_with(b"RIFF") && data.len() > 12 && &data[8..12] == b"WEBP" { + "image/webp" + } else { + "application/octet-stream" + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_detect_image_mime_type() { + let png_header = &[0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A]; + assert_eq!(detect_image_mime_type(png_header), "image/png"); + + let jpeg_header = &[0xFF, 0xD8, 0xFF, 0xE0]; + assert_eq!(detect_image_mime_type(jpeg_header), "image/jpeg"); + } + + #[test] + fn test_tool_builder() { + let tool = Tool::new("get_weather", "Get the weather for a location") + .with_input_schema(json_schema(serde_json::json!({ + "type": "object", + "properties": { + "location": { "type": "string" } + }, + "required": ["location"] + }))) + .with_approval(); + + assert_eq!(tool.name, "get_weather"); + assert!(tool.needs_approval); + assert!(tool.input_schema.is_some()); + } + + #[test] + fn test_stream_chunk_serialization() { + let chunk = StreamChunk::TextMessageContent { + timestamp: 1234567890.0, + message_id: "msg-123".to_string(), + delta: "Hello".to_string(), + content: None, + model: Some("gpt-4o".to_string()), + }; + + let json = serde_json::to_string(&chunk).unwrap(); + assert!(json.contains("\"type\":\"TEXT_MESSAGE_CONTENT\"")); + assert!(json.contains("\"delta\":\"Hello\"")); + + let deserialized: StreamChunk = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.event_type(), AguiEventType::TextMessageContent); + } + + #[test] + fn test_model_message_serialization() { + let msg = ModelMessage { + role: MessageRole::User, + content: MessageContent::Text("Hello!".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }; + + let json 
= serde_json::to_string(&msg).unwrap(); + let deserialized: ModelMessage = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.role, MessageRole::User); + assert_eq!(deserialized.content.as_str(), Some("Hello!")); + } + + #[test] + fn test_agent_loop_strategies() { + let strategy = max_iterations(3); + let state = AgentLoopState { + iteration_count: 2, + messages: vec![], + finish_reason: None, + }; + assert!(strategy(&state)); + + let state = AgentLoopState { + iteration_count: 3, + messages: vec![], + finish_reason: None, + }; + assert!(!strategy(&state)); + } + + #[test] + fn test_tool_definition_builder() { + let def = tool_definition("search", "Search the web") + .input_schema(json_schema(serde_json::json!({ + "type": "object", + "properties": { "query": { "type": "string" } } + }))) + .needs_approval(false) + .lazy(true); + + assert_eq!(def.name, "search"); + assert!(def.lazy); + assert!(def.input_schema.is_some()); + + let tool = def.to_tool(); + assert_eq!(tool.name, "search"); + assert!(!tool.needs_approval); + } +} diff --git a/packages/rust/tanstack-ai/src/messages.rs b/packages/rust/tanstack-ai/src/messages.rs new file mode 100644 index 00000000..1cf26ed3 --- /dev/null +++ b/packages/rust/tanstack-ai/src/messages.rs @@ -0,0 +1,324 @@ +use crate::types::*; + +/// Generate a unique message ID. +pub fn generate_message_id(prefix: &str) -> String { + format!( + "{}-{}-{}", + prefix, + chrono::Utc::now().timestamp_millis(), + &uuid::Uuid::new_v4().to_string()[..8] + ) +} + +/// Convert UI messages to model messages. +/// +/// Handles the conversion of UIMessage format (with parts) to ModelMessage +/// format (with content/toolCalls) that is sent to providers. +pub fn ui_messages_to_model_messages(ui_messages: &[UiMessage]) -> Vec { + ui_messages + .iter() + .flat_map(ui_message_to_model_messages) + .collect() +} + +/// Convert a single UI message to one or more model messages. 
+pub fn ui_message_to_model_messages(ui_msg: &UiMessage) -> Vec { + let role = match ui_msg.role { + UiMessageRole::System => MessageRole::System, + UiMessageRole::User => MessageRole::User, + UiMessageRole::Assistant => MessageRole::Assistant, + }; + + match role { + MessageRole::System | MessageRole::User => { + // Collect text parts into a single message + let content_parts: Vec = ui_msg + .parts + .iter() + .filter_map(|part| match part { + MessagePart::Text { content, .. } => Some(ContentPart::Text { + content: content.clone(), + }), + MessagePart::Image { source, .. } => Some(ContentPart::Image { + source: source.clone(), + }), + MessagePart::Audio { source, .. } => Some(ContentPart::Audio { + source: source.clone(), + }), + _ => None, + }) + .collect(); + + if content_parts.is_empty() { + vec![ModelMessage { + role, + content: MessageContent::Null, + name: None, + tool_calls: None, + tool_call_id: None, + }] + } else if content_parts.len() == 1 { + if let ContentPart::Text { content } = &content_parts[0] { + vec![ModelMessage { + role, + content: MessageContent::Text(content.clone()), + name: None, + tool_calls: None, + tool_call_id: None, + }] + } else { + vec![ModelMessage { + role, + content: MessageContent::Parts(content_parts), + name: None, + tool_calls: None, + tool_call_id: None, + }] + } + } else { + vec![ModelMessage { + role, + content: MessageContent::Parts(content_parts), + name: None, + tool_calls: None, + tool_call_id: None, + }] + } + } + + MessageRole::Assistant => { + let mut messages = Vec::new(); + let mut text_content = String::new(); + let mut tool_calls = Vec::new(); + + for part in &ui_msg.parts { + match part { + MessagePart::Text { content, .. } => { + text_content.push_str(content); + } + MessagePart::Thinking { content } => { + // Thinking parts are included as text with a marker + text_content.push_str(&format!("[thinking]{}[/thinking]", content)); + } + MessagePart::ToolCall { + id, + name, + arguments, + .. 
+ } => { + tool_calls.push(ToolCall { + id: id.clone(), + call_type: "function".to_string(), + function: ToolCallFunction { + name: name.clone(), + arguments: arguments.clone(), + }, + provider_metadata: None, + }); + } + MessagePart::ToolResult { + tool_call_id, + content, + state: _, + error, + .. + } => { + // Tool results are separate messages with role=tool + let result_content = if let Some(err) = error { + serde_json::json!({"error": err}).to_string() + } else { + content.clone() + }; + + messages.push(ModelMessage { + role: MessageRole::Tool, + content: MessageContent::Text(result_content), + name: None, + tool_calls: None, + tool_call_id: Some(tool_call_id.clone()), + }); + } + _ => {} + } + } + + // Create the assistant message + let content = if text_content.is_empty() { + MessageContent::Null + } else { + MessageContent::Text(text_content) + }; + + messages.insert( + 0, + ModelMessage { + role: MessageRole::Assistant, + content, + name: None, + tool_calls: if tool_calls.is_empty() { + None + } else { + Some(tool_calls) + }, + tool_call_id: None, + }, + ); + + messages + } + + MessageRole::Tool => { + // UI messages never have Tool role — this is unreachable + unreachable!("Tool role should not appear in UI messages") + } + } +} + +/// Convert model messages back to UI messages. +pub fn model_messages_to_ui_messages(messages: &[ModelMessage]) -> Vec { + messages.iter().map(model_message_to_ui_message).collect() +} + +/// Convert a model message to a UI message. 
+pub fn model_message_to_ui_message(msg: &ModelMessage) -> UiMessage { + let role = match msg.role { + MessageRole::System => UiMessageRole::System, + MessageRole::User => UiMessageRole::User, + MessageRole::Assistant | MessageRole::Tool => UiMessageRole::Assistant, + }; + + let mut parts = Vec::new(); + + // Add content parts + match &msg.content { + MessageContent::Text(text) => { + parts.push(MessagePart::Text { + content: text.clone(), + metadata: None, + }); + } + MessageContent::Parts(content_parts) => { + for part in content_parts { + match part { + ContentPart::Text { content } => { + parts.push(MessagePart::Text { + content: content.clone(), + metadata: None, + }); + } + ContentPart::Image { source } => { + parts.push(MessagePart::Image { + source: source.clone(), + metadata: None, + }); + } + ContentPart::Audio { source } => { + parts.push(MessagePart::Audio { + source: source.clone(), + metadata: None, + }); + } + ContentPart::Video { source } => { + parts.push(MessagePart::Video { + source: source.clone(), + metadata: None, + }); + } + ContentPart::Document { source } => { + parts.push(MessagePart::Document { + source: source.clone(), + metadata: None, + }); + } + } + } + } + MessageContent::Null => {} + } + + // Add tool call parts + if let Some(tool_calls) = &msg.tool_calls { + for tc in tool_calls { + parts.push(MessagePart::ToolCall { + id: tc.id.clone(), + name: tc.function.name.clone(), + arguments: tc.function.arguments.clone(), + state: ToolCallState::InputComplete, + approval: None, + output: None, + }); + } + } + + // Add tool result parts + if msg.role == MessageRole::Tool { + if let Some(tool_call_id) = &msg.tool_call_id { + let content_str = match &msg.content { + MessageContent::Text(s) => s.clone(), + _ => String::new(), + }; + parts.push(MessagePart::ToolResult { + tool_call_id: tool_call_id.clone(), + content: content_str, + state: ToolResultState::Complete, + error: None, + }); + } + } + + UiMessage { + id: generate_message_id("msg"), 
+ role, + parts, + created_at: Some(chrono::Utc::now()), + } +} + +/// Normalize messages to ModelMessage format. +/// If the messages are already ModelMessages, pass through. +/// If they are UI messages, convert them. +pub fn normalize_to_model_messages(messages: &[ModelMessage]) -> Vec { + messages.to_vec() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_ui_to_model_text() { + let ui_msg = UiMessage { + id: "test-1".to_string(), + role: UiMessageRole::User, + parts: vec![MessagePart::Text { + content: "Hello!".to_string(), + metadata: None, + }], + created_at: None, + }; + + let model_msgs = ui_message_to_model_messages(&ui_msg); + assert_eq!(model_msgs.len(), 1); + assert_eq!(model_msgs[0].role, MessageRole::User); + assert_eq!(model_msgs[0].content.as_str(), Some("Hello!")); + } + + #[test] + fn test_model_to_ui_round_trip() { + let model_msg = ModelMessage { + role: MessageRole::Assistant, + content: MessageContent::Text("Hi there!".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }; + + let ui_msg = model_message_to_ui_message(&model_msg); + assert_eq!(ui_msg.role, UiMessageRole::Assistant); + assert_eq!(ui_msg.parts.len(), 1); + + let model_msgs = ui_message_to_model_messages(&ui_msg); + assert_eq!(model_msgs.len(), 1); + assert_eq!(model_msgs[0].content.as_str(), Some("Hi there!")); + } +} diff --git a/packages/rust/tanstack-ai/src/middleware.rs b/packages/rust/tanstack-ai/src/middleware.rs new file mode 100644 index 00000000..7fe5e307 --- /dev/null +++ b/packages/rust/tanstack-ai/src/middleware.rs @@ -0,0 +1,322 @@ +use crate::types::{ModelMessage, StreamChunk, Tool, ToolCall}; +use std::collections::HashMap; + +/// Phase of the chat middleware lifecycle. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ChatMiddlewarePhase { + Init, + BeforeModel, + ModelStream, + BeforeTools, + AfterTools, +} + +/// Stable context object passed to all middleware hooks. 
+#[derive(Debug, Clone)] +pub struct ChatMiddlewareContext { + pub request_id: String, + pub stream_id: String, + pub conversation_id: Option, + pub phase: ChatMiddlewarePhase, + pub iteration: u32, + pub chunk_index: usize, + pub context: Option, + pub provider: String, + pub model: String, + pub source: String, + pub streaming: bool, + pub system_prompts: Vec, + pub tool_names: Option>, + pub options: Option>, + pub model_options: Option>, + pub message_count: usize, + pub has_tools: bool, + pub current_message_id: Option, + pub accumulated_content: String, +} + +/// Chat configuration that middleware can observe or transform. +#[derive(Debug, Clone)] +pub struct ChatMiddlewareConfig { + pub messages: Vec, + pub system_prompts: Vec, + pub tools: Vec, + pub temperature: Option, + pub top_p: Option, + pub max_tokens: Option, + pub metadata: Option, + pub model_options: Option, +} + +/// Context provided to tool call hooks. +#[derive(Debug, Clone)] +pub struct ToolCallHookContext { + pub tool_call: ToolCall, + pub tool: Option, + pub args: serde_json::Value, + pub tool_name: String, + pub tool_call_id: String, +} + +/// Decision from onBeforeToolCall. +#[derive(Debug, Clone)] +pub enum BeforeToolCallDecision { + /// Continue with normal execution. + Continue, + /// Replace args used for execution. + TransformArgs { args: serde_json::Value }, + /// Skip execution, use provided result. + Skip { result: serde_json::Value }, + /// Abort the entire chat run. + Abort { reason: Option }, +} + +/// Outcome information provided to onAfterToolCall. +#[derive(Debug, Clone)] +pub struct AfterToolCallInfo { + pub tool_call: ToolCall, + pub tool_name: String, + pub tool_call_id: String, + pub ok: bool, + pub duration_ms: u128, + pub result: Option, + pub error: Option, +} + +/// Information passed to onIteration. +#[derive(Debug, Clone)] +pub struct IterationInfo { + pub iteration: u32, + pub message_id: String, +} + +/// Aggregate information passed to onToolPhaseComplete. 
+#[derive(Debug, Clone)] +pub struct ToolPhaseCompleteInfo { + pub tool_calls: Vec, + pub results: Vec, + pub needs_approval: Vec, + pub needs_client_execution: Vec, +} + +#[derive(Debug, Clone)] +pub struct ToolPhaseResultInfo { + pub tool_call_id: String, + pub tool_name: String, + pub result: serde_json::Value, + pub duration_ms: Option, +} + +#[derive(Debug, Clone)] +pub struct ToolPhaseApprovalInfo { + pub tool_call_id: String, + pub tool_name: String, + pub input: serde_json::Value, + pub approval_id: String, +} + +#[derive(Debug, Clone)] +pub struct ToolPhaseClientInfo { + pub tool_call_id: String, + pub tool_name: String, + pub input: serde_json::Value, +} + +/// Token usage statistics. +#[derive(Debug, Clone)] +pub struct UsageInfo { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +/// Information passed to onFinish. +#[derive(Debug, Clone)] +pub struct FinishInfo { + pub finish_reason: Option, + pub duration_ms: u128, + pub content: String, + pub usage: Option, +} + +/// Information passed to onAbort. +#[derive(Debug, Clone)] +pub struct AbortInfo { + pub reason: Option, + pub duration_ms: u128, +} + +/// Information passed to onError. +#[derive(Debug, Clone)] +pub struct ErrorInfo { + pub error: String, + pub duration_ms: u128, +} + +/// Chat middleware trait. +/// +/// All methods have default implementations that do nothing. +/// Middleware is composed in array order. +pub trait ChatMiddleware: Send + Sync { + /// Optional name for debugging. + fn name(&self) -> Option<&str> { + None + } + + /// Called to observe or transform the chat configuration. + fn on_config( + &self, + _ctx: &ChatMiddlewareContext, + config: &ChatMiddlewareConfig, + ) -> Option { + Some(config.clone()) + } + + /// Called when the chat run starts. + fn on_start(&self, _ctx: &ChatMiddlewareContext) {} + + /// Called at the start of each agent loop iteration. 
+ fn on_iteration(&self, _ctx: &ChatMiddlewareContext, _info: &IterationInfo) {} + + /// Called for every chunk yielded by the chat engine. + /// Returns None to drop, Some(chunk) to pass through or transform. + fn on_chunk(&self, _ctx: &ChatMiddlewareContext, chunk: StreamChunk) -> Option { + Some(chunk) + } + + /// Called before a tool is executed. + fn on_before_tool_call( + &self, + _ctx: &ChatMiddlewareContext, + _hook_ctx: &ToolCallHookContext, + ) -> BeforeToolCallDecision { + BeforeToolCallDecision::Continue + } + + /// Called after a tool execution completes. + fn on_after_tool_call(&self, _ctx: &ChatMiddlewareContext, _info: &AfterToolCallInfo) {} + + /// Called after all tool calls in an iteration have been processed. + fn on_tool_phase_complete(&self, _ctx: &ChatMiddlewareContext, _info: &ToolPhaseCompleteInfo) {} + + /// Called when usage data is available. + fn on_usage(&self, _ctx: &ChatMiddlewareContext, _usage: &UsageInfo) {} + + /// Called when the chat run completes normally. + fn on_finish(&self, _ctx: &ChatMiddlewareContext, _info: &FinishInfo) {} + + /// Called when the chat run is aborted. + fn on_abort(&self, _ctx: &ChatMiddlewareContext, _info: &AbortInfo) {} + + /// Called when the chat run encounters an unhandled error. + fn on_error(&self, _ctx: &ChatMiddlewareContext, _info: &ErrorInfo) {} +} + +/// Middleware runner that composes multiple middlewares. +pub struct MiddlewareRunner { + middlewares: Vec>, +} + +impl MiddlewareRunner { + pub fn new(middlewares: Vec>) -> Self { + Self { middlewares } + } + + /// Run on_config through all middlewares in order. + pub fn run_on_config( + &self, + ctx: &ChatMiddlewareContext, + config: ChatMiddlewareConfig, + ) -> ChatMiddlewareConfig { + let mut current = config; + for mw in &self.middlewares { + if let Some(transformed) = mw.on_config(ctx, ¤t) { + current = transformed; + } + } + current + } + + /// Run on_start through all middlewares. 
+ pub fn run_on_start(&self, ctx: &ChatMiddlewareContext) { + for mw in &self.middlewares { + mw.on_start(ctx); + } + } + + /// Run on_iteration through all middlewares. + pub fn run_on_iteration(&self, ctx: &ChatMiddlewareContext, info: &IterationInfo) { + for mw in &self.middlewares { + mw.on_iteration(ctx, info); + } + } + + /// Run on_chunk through all middlewares in order. + /// Each middleware can observe, transform, expand, or drop chunks. + pub fn run_on_chunk( + &self, + ctx: &ChatMiddlewareContext, + chunk: StreamChunk, + ) -> Vec { + let mut chunks = vec![chunk]; + for mw in &self.middlewares { + let mut next_chunks = Vec::new(); + for c in chunks { + if let Some(output) = mw.on_chunk(ctx, c) { + next_chunks.push(output); + } + } + chunks = next_chunks; + } + chunks + } + + /// Run on_before_tool_call through middlewares. + pub fn run_on_before_tool_call( + &self, + ctx: &ChatMiddlewareContext, + hook_ctx: &ToolCallHookContext, + ) -> BeforeToolCallDecision { + for mw in &self.middlewares { + match mw.on_before_tool_call(ctx, hook_ctx) { + BeforeToolCallDecision::Continue => {} + decision => return decision, + } + } + BeforeToolCallDecision::Continue + } + + /// Run on_after_tool_call through all middlewares. + pub fn run_on_after_tool_call(&self, ctx: &ChatMiddlewareContext, info: &AfterToolCallInfo) { + for mw in &self.middlewares { + mw.on_after_tool_call(ctx, info); + } + } + + /// Run on_finish through all middlewares. + pub fn run_on_finish(&self, ctx: &ChatMiddlewareContext, info: &FinishInfo) { + for mw in &self.middlewares { + mw.on_finish(ctx, info); + } + } + + /// Run on_abort through all middlewares. + pub fn run_on_abort(&self, ctx: &ChatMiddlewareContext, info: &AbortInfo) { + for mw in &self.middlewares { + mw.on_abort(ctx, info); + } + } + + /// Run on_error through all middlewares. 
+ pub fn run_on_error(&self, ctx: &ChatMiddlewareContext, info: &ErrorInfo) { + for mw in &self.middlewares { + mw.on_error(ctx, info); + } + } +} + +impl Default for MiddlewareRunner { + fn default() -> Self { + Self::new(Vec::new()) + } +} diff --git a/packages/rust/tanstack-ai/src/stream/json_parser.rs b/packages/rust/tanstack-ai/src/stream/json_parser.rs new file mode 100644 index 00000000..74a78187 --- /dev/null +++ b/packages/rust/tanstack-ai/src/stream/json_parser.rs @@ -0,0 +1,102 @@ +/// JSON parser for partial/incomplete JSON strings. +/// +/// Used during streaming to parse tool call arguments that may be incomplete. + +/// Parse a potentially incomplete JSON string. +/// +/// Returns `None` if parsing fails or the input is empty. +pub fn parse_partial_json(json_string: &str) -> Option { + if json_string.trim().is_empty() { + return None; + } + + // Try standard JSON parse first + if let Ok(value) = serde_json::from_str(json_string) { + return Some(value); + } + + // Attempt to close incomplete JSON structures + let trimmed = json_string.trim(); + let mut attempt = trimmed.to_string(); + + // Count open/close brackets and braces to determine what to close + let mut brace_depth: i32 = 0; + let mut bracket_depth: i32 = 0; + let mut in_string = false; + let mut escape_next = false; + + for ch in trimmed.chars() { + if escape_next { + escape_next = false; + continue; + } + if ch == '\\' && in_string { + escape_next = true; + continue; + } + if ch == '"' { + in_string = !in_string; + continue; + } + if in_string { + continue; + } + match ch { + '{' => brace_depth += 1, + '}' => brace_depth -= 1, + '[' => bracket_depth += 1, + ']' => bracket_depth -= 1, + _ => {} + } + } + + // Close unclosed string + if in_string { + attempt.push('"'); + } + + // Close unclosed brackets and braces + for _ in 0..bracket_depth { + attempt.push(']'); + } + for _ in 0..brace_depth { + attempt.push('}'); + } + + serde_json::from_str(&attempt).ok() +} + +#[cfg(test)] +mod tests { + 
use super::*; + + #[test] + fn test_complete_json() { + let result = parse_partial_json(r#"{"name": "test"}"#).unwrap(); + assert_eq!(result["name"], "test"); + } + + #[test] + fn test_partial_json_object() { + let result = parse_partial_json(r#"{"name": "te"#).unwrap(); + assert_eq!(result["name"], "te"); + } + + #[test] + fn test_partial_json_array() { + let result = parse_partial_json(r#"[1, 2, 3"#).unwrap(); + assert_eq!(result, serde_json::json!([1, 2, 3])); + } + + #[test] + fn test_empty_string() { + assert!(parse_partial_json("").is_none()); + assert!(parse_partial_json(" ").is_none()); + } + + #[test] + fn test_nested_partial() { + let result = parse_partial_json(r#"{"user": {"name": "Jo"#).unwrap(); + assert_eq!(result["user"]["name"], "Jo"); + } +} diff --git a/packages/rust/tanstack-ai/src/stream/mod.rs b/packages/rust/tanstack-ai/src/stream/mod.rs new file mode 100644 index 00000000..1d389364 --- /dev/null +++ b/packages/rust/tanstack-ai/src/stream/mod.rs @@ -0,0 +1,9 @@ +pub mod processor; +pub mod strategies; +pub mod json_parser; +pub mod types; + +pub use processor::*; +pub use strategies::*; +pub use json_parser::*; +pub use types::*; diff --git a/packages/rust/tanstack-ai/src/stream/processor.rs b/packages/rust/tanstack-ai/src/stream/processor.rs new file mode 100644 index 00000000..edebff68 --- /dev/null +++ b/packages/rust/tanstack-ai/src/stream/processor.rs @@ -0,0 +1,265 @@ +use crate::stream::strategies::ImmediateStrategy; +use crate::stream::types::*; +use crate::types::{StreamChunk, ToolCall, ToolCallFunction, ToolCallState}; + +/// Core stream processing state machine. +/// +/// Manages full UIMessage conversation state, tracks per-message stream state, +/// handles text accumulation with configurable chunking strategies, manages +/// parallel tool calls with lifecycle state tracking, and supports +/// thinking/reasoning content. 
+pub struct StreamProcessor { + messages: Vec, + current_message: Option, + chunk_strategy: Box, + recordings: Vec, + recording_enabled: bool, +} + +impl StreamProcessor { + /// Create a new StreamProcessor with the default immediate chunk strategy. + pub fn new() -> Self { + Self { + messages: Vec::new(), + current_message: None, + chunk_strategy: Box::new(ImmediateStrategy), + recordings: Vec::new(), + recording_enabled: false, + } + } + + /// Create a new StreamProcessor with a custom chunk strategy. + pub fn with_strategy(strategy: Box) -> Self { + Self { + messages: Vec::new(), + current_message: None, + chunk_strategy: strategy, + recordings: Vec::new(), + recording_enabled: false, + } + } + + /// Enable recording for replay testing. + pub fn enable_recording(&mut self) { + self.recording_enabled = true; + } + + /// Get the recorded chunks. + pub fn recordings(&self) -> &[RecordedChunk] { + &self.recordings + } + + /// Process a single stream chunk, updating internal state. + /// + /// Returns the processed chunk if it should be emitted to the UI. + pub fn process_chunk(&mut self, chunk: StreamChunk) -> Option { + if self.recording_enabled { + self.recordings.push(RecordedChunk { + chunk: chunk.clone(), + timestamp: chrono::Utc::now().timestamp_millis() as f64 / 1000.0, + index: self.recordings.len(), + }); + } + + match &chunk { + StreamChunk::TextMessageStart { + message_id, role, .. + } => { + let state = MessageStreamState::new(message_id.clone(), role.clone()); + self.current_message = Some(state); + Some(chunk) + } + + StreamChunk::TextMessageContent { + message_id: _, + delta, + content, + .. 
+ } => { + if let Some(msg) = &mut self.current_message { + if let Some(full_content) = content { + msg.total_text_content = full_content.clone(); + } else { + msg.total_text_content.push_str(delta); + msg.current_segment_text.push_str(delta); + } + + if self + .chunk_strategy + .should_emit(delta, &msg.total_text_content) + { + msg.last_emitted_text = msg.total_text_content.clone(); + msg.current_segment_text.clear(); + Some(chunk) + } else { + None + } + } else { + Some(chunk) + } + } + + StreamChunk::TextMessageEnd { .. } => { + if let Some(msg) = &mut self.current_message { + // Flush any remaining text that wasn't emitted + if !msg.current_segment_text.is_empty() { + msg.last_emitted_text = msg.total_text_content.clone(); + msg.current_segment_text.clear(); + } + msg.is_complete = true; + if let Some(completed) = self.current_message.take() { + self.messages.push(completed); + } + } + self.chunk_strategy.reset(); + Some(chunk) + } + + StreamChunk::ToolCallStart { + tool_call_id, + tool_name, + index, + .. + } => { + // Auto-create a current message if none exists + if self.current_message.is_none() { + self.current_message = Some(MessageStreamState::new("auto-msg", "assistant")); + } + if let Some(msg) = &mut self.current_message { + let tc_state = InternalToolCallState { + id: tool_call_id.clone(), + name: tool_name.clone(), + arguments: String::new(), + state: ToolCallState::AwaitingInput, + parsed_arguments: None, + index: index.unwrap_or(msg.tool_calls.len()), + }; + msg.tool_calls.insert(tool_call_id.clone(), tc_state); + msg.tool_call_order.push(tool_call_id.clone()); + msg.has_tool_calls_since_text_start = true; + } + Some(chunk) + } + + StreamChunk::ToolCallArgs { + tool_call_id, + delta, + .. 
+ } => { + if let Some(msg) = &mut self.current_message { + if let Some(tc) = msg.tool_calls.get_mut(tool_call_id) { + tc.arguments.push_str(delta); + tc.state = ToolCallState::InputStreaming; + tc.parsed_arguments = + crate::stream::json_parser::parse_partial_json(&tc.arguments); + } + } + Some(chunk) + } + + StreamChunk::ToolCallEnd { + tool_call_id, + input, + .. + } => { + if let Some(msg) = &mut self.current_message { + if let Some(tc) = msg.tool_calls.get_mut(tool_call_id) { + if let Some(final_input) = input { + tc.arguments = serde_json::to_string(final_input).unwrap_or_default(); + tc.parsed_arguments = Some(final_input.clone()); + } + tc.state = ToolCallState::InputComplete; + } + } + Some(chunk) + } + + StreamChunk::StepStarted { .. } => Some(chunk), + StreamChunk::StepFinished { + step_id: _step_id, + delta, + content, + .. + } => { + if let Some(msg) = &mut self.current_message { + if let Some(full) = content { + msg.thinking_content = full.clone(); + } else { + msg.thinking_content.push_str(delta); + } + } + Some(chunk) + } + + StreamChunk::RunFinished { .. } => Some(chunk), + StreamChunk::RunError { .. } => Some(chunk), + + // Pass through other events unchanged + _ => Some(chunk), + } + } + + /// Get the final processor result after all chunks have been processed. 
+ pub fn result(&self) -> ProcessorResult { + // Check current_message first (may not have been finalized), then fall back to messages + let msg = self + .current_message + .as_ref() + .or_else(|| self.messages.last()); + + let content = msg + .map(|m| m.total_text_content.clone()) + .unwrap_or_default(); + + let thinking = msg + .filter(|m| !m.thinking_content.is_empty()) + .map(|m| m.thinking_content.clone()); + + let tool_calls = msg.map(|m| { + m.tool_call_order + .iter() + .filter_map(|id| m.tool_calls.get(id)) + .filter(|tc| tc.state == ToolCallState::InputComplete) + .map(|tc| ToolCall { + id: tc.id.clone(), + call_type: "function".to_string(), + function: ToolCallFunction { + name: tc.name.clone(), + arguments: tc.arguments.clone(), + }, + provider_metadata: None, + }) + .collect() + }); + + ProcessorResult { + content, + thinking, + tool_calls, + finish_reason: None, + } + } + + /// Get all messages processed so far. + pub fn messages(&self) -> &[MessageStreamState] { + &self.messages + } + + /// Create a recording for replay testing. + pub fn to_recording(&self, model: Option, provider: Option) -> ChunkRecording { + ChunkRecording { + version: "1.0".to_string(), + timestamp: chrono::Utc::now().timestamp_millis() as f64 / 1000.0, + model, + provider, + chunks: self.recordings.clone(), + result: None, + } + } +} + +impl Default for StreamProcessor { + fn default() -> Self { + Self::new() + } +} diff --git a/packages/rust/tanstack-ai/src/stream/strategies.rs b/packages/rust/tanstack-ai/src/stream/strategies.rs new file mode 100644 index 00000000..7eb2b778 --- /dev/null +++ b/packages/rust/tanstack-ai/src/stream/strategies.rs @@ -0,0 +1,111 @@ +use crate::stream::types::ChunkStrategy; +use regex::Regex; + +/// Immediate Strategy - emit on every chunk (default behavior). 
+#[derive(Debug, Clone)] +pub struct ImmediateStrategy; + +impl ChunkStrategy for ImmediateStrategy { + fn should_emit(&mut self, _chunk: &str, _accumulated: &str) -> bool { + true + } +} + +/// Punctuation Strategy - emit when chunk contains punctuation. +#[derive(Debug, Clone)] +pub struct PunctuationStrategy { + pattern: Regex, +} + +impl PunctuationStrategy { + pub fn new() -> Self { + Self { + pattern: Regex::new(r"[.,!?;:\n]").unwrap(), + } + } +} + +impl Default for PunctuationStrategy { + fn default() -> Self { + Self::new() + } +} + +impl ChunkStrategy for PunctuationStrategy { + fn should_emit(&mut self, chunk: &str, _accumulated: &str) -> bool { + self.pattern.is_match(chunk) + } +} + +/// Batch Strategy - emit every N chunks. +#[derive(Debug, Clone)] +pub struct BatchStrategy { + batch_size: usize, + chunk_count: usize, +} + +impl BatchStrategy { + pub fn new(batch_size: usize) -> Self { + Self { + batch_size, + chunk_count: 0, + } + } +} + +impl Default for BatchStrategy { + fn default() -> Self { + Self::new(5) + } +} + +impl ChunkStrategy for BatchStrategy { + fn should_emit(&mut self, _chunk: &str, _accumulated: &str) -> bool { + self.chunk_count += 1; + if self.chunk_count >= self.batch_size { + self.chunk_count = 0; + true + } else { + false + } + } + + fn reset(&mut self) { + self.chunk_count = 0; + } +} + +/// Word Boundary Strategy - emit at word boundaries (whitespace). +#[derive(Debug, Clone)] +pub struct WordBoundaryStrategy; + +impl ChunkStrategy for WordBoundaryStrategy { + fn should_emit(&mut self, chunk: &str, _accumulated: &str) -> bool { + chunk.ends_with(|c: char| c.is_whitespace()) + } +} + +/// Composite Strategy - combine multiple strategies with OR logic. 
+pub struct CompositeStrategy { + strategies: Vec>, +} + +impl CompositeStrategy { + pub fn new(strategies: Vec>) -> Self { + Self { strategies } + } +} + +impl ChunkStrategy for CompositeStrategy { + fn should_emit(&mut self, chunk: &str, accumulated: &str) -> bool { + self.strategies + .iter_mut() + .any(|s| s.should_emit(chunk, accumulated)) + } + + fn reset(&mut self) { + for s in &mut self.strategies { + s.reset(); + } + } +} diff --git a/packages/rust/tanstack-ai/src/stream/types.rs b/packages/rust/tanstack-ai/src/stream/types.rs new file mode 100644 index 00000000..8ff01522 --- /dev/null +++ b/packages/rust/tanstack-ai/src/stream/types.rs @@ -0,0 +1,98 @@ +use crate::types::{StreamChunk, ToolCall, ToolCallState}; + +/// Internal state for a tool call being tracked during streaming. +#[derive(Debug, Clone)] +pub struct InternalToolCallState { + pub id: String, + pub name: String, + pub arguments: String, + pub state: ToolCallState, + pub parsed_arguments: Option, + pub index: usize, +} + +/// Strategy for determining when to emit text updates. +pub trait ChunkStrategy: Send + Sync { + /// Called for each text chunk received. Returns true if an update should be emitted now. + fn should_emit(&mut self, chunk: &str, accumulated: &str) -> bool; + + /// Reset strategy state (called when streaming starts). + fn reset(&mut self) {} +} + +/// Per-message streaming state. 
+#[derive(Debug)] +pub struct MessageStreamState { + pub id: String, + pub role: String, + pub total_text_content: String, + pub current_segment_text: String, + pub last_emitted_text: String, + pub thinking_content: String, + pub tool_calls: std::collections::HashMap, + pub tool_call_order: Vec, + pub has_tool_calls_since_text_start: bool, + pub is_complete: bool, +} + +impl MessageStreamState { + pub fn new(id: impl Into, role: impl Into) -> Self { + Self { + id: id.into(), + role: role.into(), + total_text_content: String::new(), + current_segment_text: String::new(), + last_emitted_text: String::new(), + thinking_content: String::new(), + tool_calls: std::collections::HashMap::new(), + tool_call_order: Vec::new(), + has_tool_calls_since_text_start: false, + is_complete: false, + } + } +} + +/// Result from processing a stream. +#[derive(Debug, Clone)] +pub struct ProcessorResult { + pub content: String, + pub thinking: Option, + pub tool_calls: Option>, + pub finish_reason: Option, +} + +/// Current state of the processor. +#[derive(Debug)] +pub struct ProcessorState { + pub content: String, + pub thinking: String, + pub tool_calls: std::collections::HashMap, + pub tool_call_order: Vec, + pub finish_reason: Option, + pub done: bool, +} + +/// Recording format for replay testing. 
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct ChunkRecording { + pub version: String, + pub timestamp: f64, + pub model: Option, + pub provider: Option, + pub chunks: Vec, + pub result: Option, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct RecordedChunk { + pub chunk: StreamChunk, + pub timestamp: f64, + pub index: usize, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct ProcessorResultSerde { + pub content: String, + pub thinking: Option, + pub finish_reason: Option, +} diff --git a/packages/rust/tanstack-ai/src/stream_response.rs b/packages/rust/tanstack-ai/src/stream_response.rs new file mode 100644 index 00000000..bb3ce913 --- /dev/null +++ b/packages/rust/tanstack-ai/src/stream_response.rs @@ -0,0 +1,135 @@ +use bytes::Bytes; +use futures_util::Stream; +use std::pin::Pin; + +use crate::error::{AiError, AiResult}; +use crate::types::StreamChunk; + +/// Parse an SSE event line into (field, value). +pub fn parse_sse_line(line: &str) -> Option<(String, String)> { + let line = line.trim(); + if line.is_empty() || line.starts_with(':') { + return None; + } + + if let Some(colon_pos) = line.find(':') { + let field = &line[..colon_pos]; + let value = line[colon_pos + 1..].trim_start(); + Some((field.to_string(), value.to_string())) + } else { + Some((line.to_string(), String::new())) + } +} + +/// Parse SSE data from a byte stream into JSON values. 
+pub fn sse_stream_to_json( + stream: Pin> + Send>>, +) -> Pin> + Send>> { + use futures_util::StreamExt; + + let lines = tokio_util::codec::FramedRead::new( + tokio_util::io::StreamReader::new(stream.map(|r| r.map_err(|e| { + std::io::Error::new(std::io::ErrorKind::Other, e) + }))), + tokio_util::codec::LinesCodec::new(), + ); + + let stream = futures_util::stream::unfold( + (lines, String::new()), + |(mut lines, mut data_buf)| async move { + loop { + match lines.next().await { + Some(Ok(line)) => { + if line.trim().is_empty() { + // Empty line signals end of SSE event + if !data_buf.is_empty() { + let data = data_buf.clone(); + data_buf.clear(); + + // Handle [DONE] sentinel + if data.trim() == "[DONE]" { + return None; + } + + match serde_json::from_str::(&data) { + Ok(json) => return Some((Ok(json), (lines, data_buf))), + Err(e) => { + return Some(( + Err(AiError::Stream(format!( + "Failed to parse SSE data: {}", + e + ))), + (lines, data_buf), + )); + } + } + } + continue; + } + + if let Some((field, value)) = parse_sse_line(&line) { + if field == "data" { + if !data_buf.is_empty() { + data_buf.push('\n'); + } + data_buf.push_str(&value); + } + // Ignore other fields (id, event, retry) + } + } + Some(Err(e)) => { + return Some((Err(AiError::Stream(e.to_string())), (lines, data_buf))); + } + None => return None, // Stream ended + } + } + }, + ); + + Box::pin(stream) +} + +/// Convert a stream of StreamChunks to an HTTP response with SSE format. +pub fn chunks_to_sse_stream( + stream: Pin> + Send>>, +) -> Pin> + Send>> { + use futures_util::StreamExt; + + stream + .map(|result| match result { + Ok(chunk) => { + let json = serde_json::to_string(&chunk).unwrap_or_default(); + Ok(Bytes::from(format!("data: {}\n\n", json))) + } + Err(e) => Ok(Bytes::from(format!("data: {}\n\n", + serde_json::json!({"type": "RUN_ERROR", "error": {"message": e.to_string()}}) + ))), + }) + .boxed() +} + +/// Stream chunks to collected text. 
+pub async fn stream_to_text( + stream: &mut Pin> + Send>>, +) -> AiResult { + use futures_util::StreamExt; + let mut content = String::new(); + + while let Some(result) = stream.next().await { + match result? { + StreamChunk::TextMessageContent { delta, content: full, .. } => { + if let Some(full_content) = full { + content = full_content; + } else { + content.push_str(&delta); + } + } + StreamChunk::RunError { error, .. } => { + return Err(AiError::Provider(error.message)); + } + _ => {} + } + } + + Ok(content) +} diff --git a/packages/rust/tanstack-ai/src/tools/definition.rs b/packages/rust/tanstack-ai/src/tools/definition.rs new file mode 100644 index 00000000..9c04ae24 --- /dev/null +++ b/packages/rust/tanstack-ai/src/tools/definition.rs @@ -0,0 +1,105 @@ +use crate::error::AiResult; +use crate::types::{JsonSchema, Tool, ToolExecutionContext}; + +/// Tool definition builder for creating isomorphic tools. +/// +/// Tools defined with `ToolDefinition` can be used directly in chat or +/// converted to server/client variants. +pub struct ToolDefinition { + pub name: String, + pub description: String, + pub input_schema: Option, + pub output_schema: Option, + pub needs_approval: bool, + pub lazy: bool, + pub metadata: Option, +} + +impl ToolDefinition { + /// Create a new tool definition. + pub fn new(name: impl Into, description: impl Into) -> Self { + Self { + name: name.into(), + description: description.into(), + input_schema: None, + output_schema: None, + needs_approval: false, + lazy: false, + metadata: None, + } + } + + /// Set the input schema. + pub fn input_schema(mut self, schema: JsonSchema) -> Self { + self.input_schema = Some(schema); + self + } + + /// Set the output schema. + pub fn output_schema(mut self, schema: JsonSchema) -> Self { + self.output_schema = Some(schema); + self + } + + /// Mark tool as requiring approval. + pub fn needs_approval(mut self, value: bool) -> Self { + self.needs_approval = value; + self + } + + /// Mark tool as lazy. 
+ pub fn lazy(mut self, value: bool) -> Self { + self.lazy = value; + self + } + + /// Set metadata. + pub fn metadata(mut self, metadata: serde_json::Value) -> Self { + self.metadata = Some(metadata); + self + } + + /// Convert to a `Tool` definition (without execute function). + pub fn to_tool(&self) -> Tool { + Tool { + name: self.name.clone(), + description: self.description.clone(), + input_schema: self.input_schema.clone(), + output_schema: self.output_schema.clone(), + needs_approval: self.needs_approval, + lazy: self.lazy, + metadata: self.metadata.clone(), + execute: None, + } + } + + /// Convert to a server `Tool` with an execute function. + pub fn to_server_tool(&self, execute: F) -> Tool + where + F: Fn(serde_json::Value, ToolExecutionContext) -> Fut + Send + Sync + 'static, + Fut: std::future::Future> + Send + 'static, + { + Tool { + name: self.name.clone(), + description: self.description.clone(), + input_schema: self.input_schema.clone(), + output_schema: self.output_schema.clone(), + needs_approval: self.needs_approval, + lazy: self.lazy, + metadata: self.metadata.clone(), + execute: Some(std::sync::Arc::new(move |args, ctx| { + Box::pin(execute(args, ctx)) + })), + } + } +} + +/// Convenience function to create a tool definition. +pub fn tool_definition(name: impl Into, description: impl Into) -> ToolDefinition { + ToolDefinition::new(name, description) +} + +/// Helper to create a JSON Schema from a JSON value. 
+pub fn json_schema(value: serde_json::Value) -> JsonSchema { + serde_json::from_value(value).unwrap_or_default() +} diff --git a/packages/rust/tanstack-ai/src/tools/mod.rs b/packages/rust/tanstack-ai/src/tools/mod.rs new file mode 100644 index 00000000..f38a047e --- /dev/null +++ b/packages/rust/tanstack-ai/src/tools/mod.rs @@ -0,0 +1,7 @@ +pub mod definition; +pub mod registry; +pub mod tool_calls; + +pub use definition::*; +pub use registry::*; +pub use tool_calls::*; diff --git a/packages/rust/tanstack-ai/src/tools/registry.rs b/packages/rust/tanstack-ai/src/tools/registry.rs new file mode 100644 index 00000000..8931dd82 --- /dev/null +++ b/packages/rust/tanstack-ai/src/tools/registry.rs @@ -0,0 +1,130 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use crate::types::Tool; + +/// Registry for managing tools available to the chat engine. +#[derive(Debug, Clone)] +pub struct ToolRegistry { + tools: HashMap, +} + +impl ToolRegistry { + /// Create an empty registry. + pub fn new() -> Self { + Self { + tools: HashMap::new(), + } + } + + /// Create a registry from a list of tools. + pub fn from_tools(tools: Vec) -> Self { + let mut registry = Self::new(); + for tool in tools { + registry.register(tool); + } + registry + } + + /// Register a tool. + pub fn register(&mut self, tool: Tool) { + self.tools.insert(tool.name.clone(), tool); + } + + /// Get a tool by name. + pub fn get(&self, name: &str) -> Option<&Tool> { + self.tools.get(name) + } + + /// Check if a tool exists. + pub fn contains(&self, name: &str) -> bool { + self.tools.contains_key(name) + } + + /// Get all tool names. + pub fn names(&self) -> Vec { + self.tools.keys().cloned().collect() + } + + /// Get all tools as a Vec. + pub fn all(&self) -> Vec<&Tool> { + self.tools.values().collect() + } + + /// Get the number of tools. + pub fn len(&self) -> usize { + self.tools.len() + } + + /// Check if the registry is empty. 
+ pub fn is_empty(&self) -> bool { + self.tools.is_empty() + } + + /// Remove a tool by name. + pub fn remove(&mut self, name: &str) -> Option { + self.tools.remove(name) + } +} + +impl Default for ToolRegistry { + fn default() -> Self { + Self::new() + } +} + +impl From> for ToolRegistry { + fn from(tools: Vec) -> Self { + Self::from_tools(tools) + } +} + +/// Frozen/immutable tool registry for thread-safe sharing. +#[derive(Debug, Clone)] +pub struct FrozenToolRegistry { + tools: Arc>, +} + +impl FrozenToolRegistry { + /// Create a frozen registry from a mutable one. + pub fn freeze(registry: ToolRegistry) -> Self { + Self { + tools: Arc::new(registry.tools), + } + } + + /// Get a tool by name. + pub fn get(&self, name: &str) -> Option<&Tool> { + self.tools.get(name) + } + + /// Check if a tool exists. + pub fn contains(&self, name: &str) -> bool { + self.tools.contains_key(name) + } + + /// Get all tool names. + pub fn names(&self) -> Vec { + self.tools.keys().cloned().collect() + } + + /// Get all tools. + pub fn all(&self) -> Vec<&Tool> { + self.tools.values().collect() + } + + /// Get the number of tools. + pub fn len(&self) -> usize { + self.tools.len() + } + + /// Check if the registry is empty. + pub fn is_empty(&self) -> bool { + self.tools.is_empty() + } +} + +/// Convenience function to create a frozen registry. +pub fn create_frozen_registry(tools: Vec) -> FrozenToolRegistry { + FrozenToolRegistry::freeze(ToolRegistry::from_tools(tools)) +} diff --git a/packages/rust/tanstack-ai/src/tools/tool_calls.rs b/packages/rust/tanstack-ai/src/tools/tool_calls.rs new file mode 100644 index 00000000..cfc422c8 --- /dev/null +++ b/packages/rust/tanstack-ai/src/tools/tool_calls.rs @@ -0,0 +1,357 @@ +use std::collections::HashMap; +use crate::error::{AiError, AiResult}; +use crate::types::*; +use tokio::sync::mpsc; + +/// Result of a tool execution. 
+#[derive(Debug, Clone)] +pub struct ToolResult { + pub tool_call_id: String, + pub tool_name: String, + pub result: serde_json::Value, + pub state: ToolResultOutputState, + pub duration_ms: Option, +} + +/// Output state for a tool result. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ToolResultOutputState { + Available, + Error, +} + +/// A tool call that needs user approval. +#[derive(Debug, Clone)] +pub struct ApprovalRequest { + pub tool_call_id: String, + pub tool_name: String, + pub input: serde_json::Value, + pub approval_id: String, +} + +/// A tool that needs client-side execution. +#[derive(Debug, Clone)] +pub struct ClientToolRequest { + pub tool_call_id: String, + pub tool_name: String, + pub input: serde_json::Value, +} + +/// Result from executing a batch of tool calls. +#[derive(Debug)] +pub struct ExecuteToolCallsResult { + pub results: Vec, + pub needs_approval: Vec, + pub needs_client_execution: Vec, +} + +/// Custom event emitted during tool execution. +#[derive(Debug, Clone)] +pub struct ToolCustomEvent { + pub name: String, + pub value: serde_json::Value, +} + +/// Manages tool call accumulation and execution for the chat engine. +#[derive(Debug, Default)] +pub struct ToolCallManager { + tool_calls: HashMap, +} + +impl ToolCallManager { + pub fn new() -> Self { + Self { + tool_calls: HashMap::new(), + } + } + + /// Add a TOOL_CALL_START event. + pub fn add_start_event(&mut self, event: &StreamChunk) { + if let StreamChunk::ToolCallStart { + tool_call_id, + tool_name, + index: _, + provider_metadata, + .. + } = event + { + self.tool_calls.insert( + tool_call_id.clone(), + ToolCall { + id: tool_call_id.clone(), + call_type: "function".to_string(), + function: ToolCallFunction { + name: tool_name.clone(), + arguments: String::new(), + }, + provider_metadata: provider_metadata.clone(), + }, + ); + } + } + + /// Add a TOOL_CALL_ARGS event to accumulate arguments. 
+ pub fn add_args_event(&mut self, event: &StreamChunk) { + if let StreamChunk::ToolCallArgs { tool_call_id, delta, .. } = event { + if let Some(tc) = self.tool_calls.get_mut(tool_call_id) { + tc.function.arguments.push_str(delta); + } + } + } + + /// Complete a tool call with its final input. + pub fn complete_tool_call(&mut self, event: &StreamChunk) { + if let StreamChunk::ToolCallEnd { tool_call_id, input, .. } = event { + if let Some(tc) = self.tool_calls.get_mut(tool_call_id) { + if let Some(final_input) = input { + tc.function.arguments = + serde_json::to_string(final_input).unwrap_or_default(); + } + } + } + } + + /// Check if there are any complete tool calls to execute. + pub fn has_tool_calls(&self) -> bool { + self.tool_calls + .values() + .any(|tc| !tc.id.is_empty() && !tc.function.name.trim().is_empty()) + } + + /// Get all tool calls as a Vec. + pub fn tool_calls(&self) -> Vec { + self.tool_calls + .values() + .filter(|tc| !tc.id.is_empty() && !tc.function.name.trim().is_empty()) + .cloned() + .collect() + } + + /// Clear all tool calls for the next iteration. + pub fn clear(&mut self) { + self.tool_calls.clear(); + } +} + +/// Execute tool calls with full approval and client-tool support. +/// +/// Returns a stream of custom events during execution, and the final result. 
+pub async fn execute_tool_calls( + tool_calls: &[ToolCall], + tools: &[Tool], + approvals: &HashMap, + client_results: &HashMap, + event_tx: Option>, +) -> AiResult { + let mut results = Vec::new(); + let mut needs_approval = Vec::new(); + let mut needs_client_execution = Vec::new(); + + // Build tool lookup map + let tool_map: HashMap<&str, &Tool> = tools.iter().map(|t| (t.name.as_str(), t)).collect(); + + // Check if any tools need pending approvals (batch gating) + let has_pending_approvals = tool_calls.iter().any(|tc| { + tool_map + .get(tc.function.name.as_str()) + .map(|t| t.needs_approval && !approvals.contains_key(&format!("approval_{}", tc.id))) + .unwrap_or(false) + }); + + for tool_call in tool_calls { + let tool_name = &tool_call.function.name; + let tool = tool_map.get(tool_name.as_str()).copied(); + + if tool.is_none() { + results.push(ToolResult { + tool_call_id: tool_call.id.clone(), + tool_name: tool_name.clone(), + result: serde_json::json!({"error": format!("Unknown tool: {}", tool_name)}), + state: ToolResultOutputState::Error, + duration_ms: None, + }); + continue; + } + + let tool = tool.unwrap(); + + // Skip non-pending tools while approvals are outstanding + if has_pending_approvals { + if !tool.needs_approval || approvals.contains_key(&format!("approval_{}", tool_call.id)) + { + continue; + } + } + + // Parse arguments + let input: serde_json::Value = { + let args_str = if tool_call.function.arguments.trim().is_empty() { + "{}" + } else { + tool_call.function.arguments.trim() + }; + match serde_json::from_str(args_str) { + Ok(v) => v, + Err(e) => { + return Err(AiError::ToolExecution(format!( + "Failed to parse tool arguments: {}", + e + ))); + } + } + }; + + // CASE 1: Tool has no execute function (client-side tool) + if tool.execute.is_none() { + if tool.needs_approval { + let approval_id = format!("approval_{}", tool_call.id); + if let Some(&approved) = approvals.get(&approval_id) { + if approved { + if let Some(result) = 
client_results.get(&tool_call.id) { + results.push(ToolResult { + tool_call_id: tool_call.id.clone(), + tool_name: tool_name.clone(), + result: result.clone(), + state: ToolResultOutputState::Available, + duration_ms: None, + }); + } else { + needs_client_execution.push(ClientToolRequest { + tool_call_id: tool_call.id.clone(), + tool_name: tool_name.clone(), + input, + }); + } + } else { + results.push(ToolResult { + tool_call_id: tool_call.id.clone(), + tool_name: tool_name.clone(), + result: serde_json::json!({"error": "User declined tool execution"}), + state: ToolResultOutputState::Error, + duration_ms: None, + }); + } + } else { + needs_approval.push(ApprovalRequest { + tool_call_id: tool_call.id.clone(), + tool_name: tool_name.clone(), + input, + approval_id, + }); + } + } else if let Some(result) = client_results.get(&tool_call.id) { + results.push(ToolResult { + tool_call_id: tool_call.id.clone(), + tool_name: tool_name.clone(), + result: result.clone(), + state: ToolResultOutputState::Available, + duration_ms: None, + }); + } else { + needs_client_execution.push(ClientToolRequest { + tool_call_id: tool_call.id.clone(), + tool_name: tool_name.clone(), + input, + }); + } + continue; + } + + // CASE 2: Server tool with approval + if tool.needs_approval { + let approval_id = format!("approval_{}", tool_call.id); + if let Some(&approved) = approvals.get(&approval_id) { + if approved { + execute_server_tool( + tool_call, + tool, + input, + &event_tx, + &mut results, + ) + .await?; + } else { + results.push(ToolResult { + tool_call_id: tool_call.id.clone(), + tool_name: tool_name.clone(), + result: serde_json::json!({"error": "User declined tool execution"}), + state: ToolResultOutputState::Error, + duration_ms: None, + }); + } + } else { + needs_approval.push(ApprovalRequest { + tool_call_id: tool_call.id.clone(), + tool_name: tool_name.clone(), + input, + approval_id, + }); + } + continue; + } + + // CASE 3: Normal server tool - execute immediately + 
execute_server_tool(tool_call, tool, input, &event_tx, &mut results).await?; + } + + Ok(ExecuteToolCallsResult { + results, + needs_approval, + needs_client_execution, + }) +} + +async fn execute_server_tool( + tool_call: &ToolCall, + tool: &Tool, + input: serde_json::Value, + event_tx: &Option>, + results: &mut Vec, +) -> AiResult<()> { + let start = std::time::Instant::now(); + + let ctx = ToolExecutionContext { + tool_call_id: Some(tool_call.id.clone()), + custom_event_tx: None, // TODO: wire up custom event channel + }; + + let execute_fn = tool.execute.as_ref().unwrap(); + + match execute_fn(input, ctx).await { + Ok(result) => { + let duration = start.elapsed().as_millis(); + + // Emit custom event if channel is available + if let Some(tx) = event_tx { + let _ = tx.send(ToolCustomEvent { + name: "tool-result".to_string(), + value: serde_json::json!({ + "toolCallId": tool_call.id, + "toolName": tool.name, + "result": result, + }), + }); + } + + results.push(ToolResult { + tool_call_id: tool_call.id.clone(), + tool_name: tool.name.clone(), + result, + state: ToolResultOutputState::Available, + duration_ms: Some(duration), + }); + } + Err(e) => { + let duration = start.elapsed().as_millis(); + results.push(ToolResult { + tool_call_id: tool_call.id.clone(), + tool_name: tool.name.clone(), + result: serde_json::json!({"error": e.to_string()}), + state: ToolResultOutputState::Error, + duration_ms: Some(duration), + }); + } + } + + Ok(()) +} diff --git a/packages/rust/tanstack-ai/src/types.rs b/packages/rust/tanstack-ai/src/types.rs new file mode 100644 index 00000000..374623d2 --- /dev/null +++ b/packages/rust/tanstack-ai/src/types.rs @@ -0,0 +1,885 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +// ============================================================================ +// Tool Call States +// ============================================================================ + +/// Lifecycle state of a tool call. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub enum ToolCallState { + /// Received start but no arguments yet. + AwaitingInput, + /// Partial arguments received. + InputStreaming, + /// All arguments received. + InputComplete, + /// Waiting for user approval. + ApprovalRequested, + /// User has approved/denied. + ApprovalResponded, +} + +/// Lifecycle state of a tool result. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub enum ToolResultState { + /// Placeholder for future streamed output. + Streaming, + /// Result is complete. + Complete, + /// Error occurred. + Error, +} + +// ============================================================================ +// JSON Schema +// ============================================================================ + +/// JSON Schema type for defining tool input/output schemas. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +pub struct JsonSchema { + #[serde(skip_serializing_if = "Option::is_none")] + pub r#type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub properties: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub items: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub required: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub r#enum: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub default: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub r#ref: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub defs: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub all_of: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub any_of: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub one_of: Option>, + 
#[serde(skip_serializing_if = "Option::is_none")] + pub minimum: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub maximum: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub min_length: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_length: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub pattern: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub format: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub additional_properties: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + #[serde(flatten)] + pub extra: HashMap, +} + +// ============================================================================ +// Multimodal Content Types +// ============================================================================ + +/// Supported input modality types. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Modality { + Text, + Image, + Audio, + Video, + Document, +} + +/// Source for inline data content (base64). +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum ContentPartSource { + /// Inline base64 data. + Data { value: String, mime_type: String }, + /// URL-referenced content. + Url { + value: String, + #[serde(skip_serializing_if = "Option::is_none")] + mime_type: Option, + }, +} + +/// Image content part. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ImagePart { + #[serde(rename = "type")] + pub part_type: &'static str, // always "image" + pub source: ContentPartSource, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, +} + +/// Audio content part. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AudioPart { + #[serde(rename = "type")] + pub part_type: &'static str, + pub source: ContentPartSource, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, +} + +/// Video content part. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VideoPart { + #[serde(rename = "type")] + pub part_type: &'static str, + pub source: ContentPartSource, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, +} + +/// Document content part (e.g., PDFs). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DocumentPart { + #[serde(rename = "type")] + pub part_type: &'static str, + pub source: ContentPartSource, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, +} + +/// Union type for all multimodal content parts. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum ContentPart { + Text { content: String }, + Image { source: ContentPartSource }, + Audio { source: ContentPartSource }, + Video { source: ContentPartSource }, + Document { source: ContentPartSource }, +} + +// ============================================================================ +// Message Types +// ============================================================================ + +/// A message in the conversation with a model. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelMessage { + pub role: MessageRole, + pub content: MessageContent, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, +} + +/// Message role. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum MessageRole { + System, + User, + Assistant, + Tool, +} + +/// Message content — can be simple text, null, or multimodal parts. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum MessageContent { + Text(String), + Parts(Vec), + Null, +} + +impl MessageContent { + pub fn text(s: impl Into) -> Self { + MessageContent::Text(s.into()) + } + + pub fn as_str(&self) -> Option<&str> { + match self { + MessageContent::Text(s) => Some(s.as_str()), + _ => None, + } + } +} + +// ============================================================================ +// Tool Call +// ============================================================================ + +/// A tool call from the model. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCall { + pub id: String, + #[serde(rename = "type")] + pub call_type: String, // always "function" + pub function: ToolCallFunction, + #[serde(skip_serializing_if = "Option::is_none")] + pub provider_metadata: Option, +} + +/// Function details within a tool call. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCallFunction { + pub name: String, + pub arguments: String, // JSON string +} + +// ============================================================================ +// UI Message & Parts +// ============================================================================ + +/// Domain-specific message format optimized for chat UIs. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UiMessage { + pub id: String, + pub role: UiMessageRole, + pub parts: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub created_at: Option>, +} + +/// Role for UI messages. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum UiMessageRole { + System, + User, + Assistant, +} + +/// A part of a UI message. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "kebab-case")] +pub enum MessagePart { + Text { + content: String, + #[serde(skip_serializing_if = "Option::is_none")] + metadata: Option, + }, + Image { + source: ContentPartSource, + #[serde(skip_serializing_if = "Option::is_none")] + metadata: Option, + }, + Audio { + source: ContentPartSource, + #[serde(skip_serializing_if = "Option::is_none")] + metadata: Option, + }, + Video { + source: ContentPartSource, + #[serde(skip_serializing_if = "Option::is_none")] + metadata: Option, + }, + Document { + source: ContentPartSource, + #[serde(skip_serializing_if = "Option::is_none")] + metadata: Option, + }, + ToolCall { + id: String, + name: String, + arguments: String, + state: ToolCallState, + #[serde(skip_serializing_if = "Option::is_none")] + approval: Option, + #[serde(skip_serializing_if = "Option::is_none")] + output: Option, + }, + ToolResult { + #[serde(rename = "toolCallId")] + tool_call_id: String, + content: String, + state: ToolResultState, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, + }, + Thinking { + content: String, + }, +} + +/// Approval metadata for a tool call. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCallApproval { + pub id: String, + pub needs_approval: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub approved: Option, +} + +// ============================================================================ +// Tool Definition +// ============================================================================ + +/// Tool/Function definition for function calling. +#[derive(Clone)] +pub struct Tool { + /// Unique name (used by the model to call it). + pub name: String, + /// Description of what the tool does. + pub description: String, + /// JSON Schema for input parameters. + pub input_schema: Option, + /// JSON Schema for output validation. 
+ pub output_schema: Option, + /// If true, tool execution requires user approval. + pub needs_approval: bool, + /// If true, this tool is lazy and discovered on-demand. + pub lazy: bool, + /// Additional metadata. + pub metadata: Option, + /// Execute function (server-side tools). + pub execute: Option, +} + +/// Type for tool execute functions. +pub type ToolExecuteFn = + std::sync::Arc ToolFuture + Send + Sync>; + +/// Future returned by tool execute functions. +pub type ToolFuture = + std::pin::Pin> + Send>>; + +/// Context passed to tool execute functions. +#[derive(Debug, Clone)] +pub struct ToolExecutionContext { + pub tool_call_id: Option, + pub custom_event_tx: Option>, +} + +impl ToolExecutionContext { + /// Emit a custom event during tool execution. + pub fn emit_custom_event(&self, event_name: &str, value: serde_json::Value) { + if let Some(tx) = &self.custom_event_tx { + let _ = tx.send(CustomEventData { + name: event_name.to_string(), + value, + tool_call_id: self.tool_call_id.clone(), + }); + } + } +} + +/// Data for a custom event emitted during tool execution. +#[derive(Debug, Clone)] +pub struct CustomEventData { + pub name: String, + pub value: serde_json::Value, + pub tool_call_id: Option, +} + +impl Tool { + /// Create a new tool definition. + pub fn new(name: impl Into, description: impl Into) -> Self { + Self { + name: name.into(), + description: description.into(), + input_schema: None, + output_schema: None, + needs_approval: false, + lazy: false, + metadata: None, + execute: None, + } + } + + /// Set the input schema. + pub fn with_input_schema(mut self, schema: JsonSchema) -> Self { + self.input_schema = Some(schema); + self + } + + /// Set the output schema. + pub fn with_output_schema(mut self, schema: JsonSchema) -> Self { + self.output_schema = Some(schema); + self + } + + /// Mark this tool as requiring approval. + pub fn with_approval(mut self) -> Self { + self.needs_approval = true; + self + } + + /// Mark this tool as lazy. 
+ pub fn with_lazy(mut self) -> Self { + self.lazy = true; + self + } + + /// Set the execute function. + pub fn with_execute(mut self, f: F) -> Self + where + F: Fn(serde_json::Value, ToolExecutionContext) -> Fut + Send + Sync + 'static, + Fut: std::future::Future> + Send + 'static, + { + self.execute = Some(std::sync::Arc::new(move |args, ctx| Box::pin(f(args, ctx)))); + self + } +} + +// ============================================================================ +// AG-UI Protocol Event Types +// ============================================================================ + +/// AG-UI Protocol event types. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum AguiEventType { + RunStarted, + RunFinished, + RunError, + TextMessageStart, + TextMessageContent, + TextMessageEnd, + ToolCallStart, + ToolCallArgs, + ToolCallEnd, + StepStarted, + StepFinished, + MessagesSnapshot, + StateSnapshot, + StateDelta, + Custom, +} + +/// Token usage statistics. +#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +/// Finish reason from the model. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum FinishReason { + Stop, + Length, + ContentFilter, + ToolCalls, +} + +/// Stream chunk / AG-UI event. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "SCREAMING_SNAKE_CASE")] +pub enum StreamChunk { + #[serde(rename_all = "camelCase")] + RunStarted { + timestamp: f64, + run_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + thread_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + model: Option, + }, + #[serde(rename_all = "camelCase")] + RunFinished { + timestamp: f64, + run_id: String, + finish_reason: Option, + #[serde(skip_serializing_if = "Option::is_none")] + usage: Option, + #[serde(skip_serializing_if = "Option::is_none")] + model: Option, + }, + #[serde(rename_all = "camelCase")] + RunError { + timestamp: f64, + #[serde(skip_serializing_if = "Option::is_none")] + run_id: Option, + error: RunErrorData, + #[serde(skip_serializing_if = "Option::is_none")] + model: Option, + }, + #[serde(rename_all = "camelCase")] + TextMessageStart { + timestamp: f64, + message_id: String, + role: String, + #[serde(skip_serializing_if = "Option::is_none")] + model: Option, + }, + #[serde(rename_all = "camelCase")] + TextMessageContent { + timestamp: f64, + message_id: String, + delta: String, + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + model: Option, + }, + #[serde(rename_all = "camelCase")] + TextMessageEnd { + timestamp: f64, + message_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + model: Option, + }, + #[serde(rename_all = "camelCase")] + ToolCallStart { + timestamp: f64, + tool_call_id: String, + tool_name: String, + #[serde(skip_serializing_if = "Option::is_none")] + parent_message_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + index: Option, + #[serde(skip_serializing_if = "Option::is_none")] + provider_metadata: Option, + #[serde(skip_serializing_if = "Option::is_none")] + model: Option, + }, + #[serde(rename_all = "camelCase")] + ToolCallArgs { + timestamp: f64, + tool_call_id: String, 
+ delta: String, + #[serde(skip_serializing_if = "Option::is_none")] + args: Option, + #[serde(skip_serializing_if = "Option::is_none")] + model: Option, + }, + #[serde(rename_all = "camelCase")] + ToolCallEnd { + timestamp: f64, + tool_call_id: String, + tool_name: String, + #[serde(skip_serializing_if = "Option::is_none")] + input: Option, + #[serde(skip_serializing_if = "Option::is_none")] + result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + model: Option, + }, + #[serde(rename_all = "camelCase")] + StepStarted { + timestamp: f64, + step_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + step_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + model: Option, + }, + #[serde(rename_all = "camelCase")] + StepFinished { + timestamp: f64, + step_id: String, + delta: String, + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + model: Option, + }, + #[serde(rename_all = "camelCase")] + MessagesSnapshot { + timestamp: f64, + messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + model: Option, + }, + #[serde(rename_all = "camelCase")] + StateSnapshot { + timestamp: f64, + state: serde_json::Value, + #[serde(skip_serializing_if = "Option::is_none")] + model: Option, + }, + #[serde(rename_all = "camelCase")] + StateDelta { + timestamp: f64, + delta: serde_json::Value, + #[serde(skip_serializing_if = "Option::is_none")] + model: Option, + }, + #[serde(rename_all = "camelCase")] + Custom { + timestamp: f64, + name: String, + #[serde(skip_serializing_if = "Option::is_none")] + value: Option, + #[serde(skip_serializing_if = "Option::is_none")] + model: Option, + }, +} + +/// Error data in a RUN_ERROR event. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RunErrorData { + pub message: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub code: Option, +} + +impl StreamChunk { + /// Get the event type of this chunk. 
+ pub fn event_type(&self) -> AguiEventType { + match self { + StreamChunk::RunStarted { .. } => AguiEventType::RunStarted, + StreamChunk::RunFinished { .. } => AguiEventType::RunFinished, + StreamChunk::RunError { .. } => AguiEventType::RunError, + StreamChunk::TextMessageStart { .. } => AguiEventType::TextMessageStart, + StreamChunk::TextMessageContent { .. } => AguiEventType::TextMessageContent, + StreamChunk::TextMessageEnd { .. } => AguiEventType::TextMessageEnd, + StreamChunk::ToolCallStart { .. } => AguiEventType::ToolCallStart, + StreamChunk::ToolCallArgs { .. } => AguiEventType::ToolCallArgs, + StreamChunk::ToolCallEnd { .. } => AguiEventType::ToolCallEnd, + StreamChunk::StepStarted { .. } => AguiEventType::StepStarted, + StreamChunk::StepFinished { .. } => AguiEventType::StepFinished, + StreamChunk::MessagesSnapshot { .. } => AguiEventType::MessagesSnapshot, + StreamChunk::StateSnapshot { .. } => AguiEventType::StateSnapshot, + StreamChunk::StateDelta { .. } => AguiEventType::StateDelta, + StreamChunk::Custom { .. } => AguiEventType::Custom, + } + } +} + +// ============================================================================ +// Configuration Types +// ============================================================================ + +/// State passed to agent loop strategy functions. +#[derive(Debug, Clone)] +pub struct AgentLoopState { + pub iteration_count: u32, + pub messages: Vec, + pub finish_reason: Option, +} + +/// Strategy function that determines whether the agent loop should continue. +pub type AgentLoopStrategy = Box bool + Send + Sync>; + +/// Options for text generation / chat. 
+pub struct TextOptions { + pub model: String, + pub messages: Vec, + pub tools: Vec, + pub system_prompts: Vec, + pub agent_loop_strategy: Option, + pub temperature: Option, + pub top_p: Option, + pub max_tokens: Option, + pub metadata: Option, + pub model_options: Option, + pub output_schema: Option, + pub conversation_id: Option, +} + +impl Clone for TextOptions { + fn clone(&self) -> Self { + Self { + model: self.model.clone(), + messages: self.messages.clone(), + tools: self.tools.clone(), + system_prompts: self.system_prompts.clone(), + agent_loop_strategy: None, // Strategy functions cannot be cloned + temperature: self.temperature, + top_p: self.top_p, + max_tokens: self.max_tokens, + metadata: self.metadata.clone(), + model_options: self.model_options.clone(), + output_schema: self.output_schema.clone(), + conversation_id: self.conversation_id.clone(), + } + } +} + +impl Default for TextOptions { + fn default() -> Self { + Self { + model: String::new(), + messages: Vec::new(), + tools: Vec::new(), + system_prompts: Vec::new(), + agent_loop_strategy: None, + temperature: None, + top_p: None, + max_tokens: None, + metadata: None, + model_options: None, + output_schema: None, + conversation_id: None, + } + } +} + +/// Options for structured output generation. +#[derive(Clone)] +pub struct StructuredOutputOptions { + pub chat_options: TextOptions, + pub output_schema: JsonSchema, +} + +/// Result from structured output generation. +#[derive(Debug, Clone)] +pub struct StructuredOutputResult { + pub data: serde_json::Value, + pub raw_text: String, +} + +// ============================================================================ +// Agent Loop Strategies +// ============================================================================ + +/// Continue for up to `max` iterations. 
+pub fn max_iterations(max: u32) -> AgentLoopStrategy { + Box::new(move |state: &AgentLoopState| state.iteration_count < max) +} + +/// Continue until a specific finish reason is received. +pub fn until_finish_reason(reason: impl Into) -> AgentLoopStrategy { + let reason = reason.into(); + Box::new(move |state: &AgentLoopState| state.finish_reason.as_ref() != Some(&reason)) +} + +/// Combine multiple strategies with AND logic (all must return true). +pub fn combine_strategies(strategies: Vec) -> AgentLoopStrategy { + Box::new(move |state: &AgentLoopState| strategies.iter().all(|s| s(state))) +} + +// ============================================================================ +// Summarization, Image, Video, TTS, Transcription Types +// ============================================================================ + +/// Options for summarization. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SummarizationOptions { + pub model: String, + pub text: String, + pub max_length: Option, + pub style: Option, + pub focus: Option>, +} + +/// Result of summarization. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SummarizationResult { + pub id: String, + pub model: String, + pub summary: String, + pub usage: Usage, +} + +/// Options for image generation. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ImageGenerationOptions { + pub model: String, + pub prompt: String, + pub number_of_images: Option, + pub size: Option, +} + +/// A generated image. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GeneratedImage { + pub b64_json: Option, + pub url: Option, + pub revised_prompt: Option, +} + +/// Result of image generation. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ImageGenerationResult { + pub id: String, + pub model: String, + pub images: Vec, + pub usage: Option, +} + +/// Options for text-to-speech. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TtsOptions { + pub model: String, + pub text: String, + pub voice: Option, + pub format: Option, + pub speed: Option, +} + +/// Result of text-to-speech. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TtsResult { + pub id: String, + pub model: String, + pub audio: String, // base64 + pub format: String, + pub duration: Option, + pub content_type: Option, +} + +/// Options for transcription. +#[derive(Debug, Clone)] +pub struct TranscriptionOptions { + pub model: String, + pub audio: Vec, + pub language: Option, + pub prompt: Option, + pub response_format: Option, +} + +/// A transcribed segment. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TranscriptionSegment { + pub id: u32, + pub start: f64, + pub end: f64, + pub text: String, + pub confidence: Option, + pub speaker: Option, +} + +/// Result of transcription. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TranscriptionResult { + pub id: String, + pub model: String, + pub text: String, + pub language: Option, + pub duration: Option, + pub segments: Option>, +} + +use crate::error::AiResult; + +impl MessageRole { + /// Get the string representation of the role for API requests. 
+ pub fn as_str(&self) -> &'static str { + match self { + MessageRole::System => "system", + MessageRole::User => "user", + MessageRole::Assistant => "assistant", + MessageRole::Tool => "tool", + } + } +} + +impl std::fmt::Debug for Tool { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Tool") + .field("name", &self.name) + .field("description", &self.description) + .field("input_schema", &self.input_schema) + .field("output_schema", &self.output_schema) + .field("needs_approval", &self.needs_approval) + .field("lazy", &self.lazy) + .field("metadata", &self.metadata) + .field("execute", &self.execute.as_ref().map(|_| "")) + .finish() + } +} diff --git a/packages/rust/tanstack-ai/tests/chat_tests.rs b/packages/rust/tanstack-ai/tests/chat_tests.rs new file mode 100644 index 00000000..d955db64 --- /dev/null +++ b/packages/rust/tanstack-ai/tests/chat_tests.rs @@ -0,0 +1,861 @@ +//! Chat engine integration tests with mock adapter. +//! +//! Ports the TypeScript chat.test.ts patterns to Rust. 
+ +mod test_utils; + +use std::sync::Arc; +use tanstack_ai::*; +use test_utils::*; + +// ============================================================================ +// Streaming text (no tools) +// ============================================================================ + +#[tokio::test] +async fn test_streaming_text_yields_all_chunks() { + let adapter = Arc::new(MockAdapter::new(vec![vec![ + run_started("run-1"), + text_start("msg-1"), + text_content("Hello", "msg-1"), + text_content(" world!", "msg-1"), + text_end("msg-1"), + run_finished("stop", "run-1"), + ]])); + + let result = chat(ChatOptions { + adapter, + messages: vec![ModelMessage { + role: MessageRole::User, + content: MessageContent::Text("Hi".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }], + system_prompts: vec![], + tools: vec![], + temperature: None, + top_p: None, + max_tokens: None, + metadata: None, + model_options: None, + agent_loop_strategy: None, + conversation_id: None, + middleware: vec![], + stream: true, + output_schema: None, + }) + .await + .unwrap(); + + match result { + ChatResult::Chunks(chunks) => { + assert_eq!(chunks.len(), 6); + assert_eq!(count_chunk_type(&chunks, "RUN_STARTED"), 1); + assert_eq!(count_chunk_type(&chunks, "TEXT_MESSAGE_START"), 1); + assert_eq!(count_chunk_type(&chunks, "TEXT_MESSAGE_CONTENT"), 2); + assert_eq!(count_chunk_type(&chunks, "TEXT_MESSAGE_END"), 1); + assert_eq!(count_chunk_type(&chunks, "RUN_FINISHED"), 1); + } + _ => panic!("Expected Chunks result"), + } +} + +#[tokio::test] +async fn test_streaming_text_passes_messages() { + let adapter = Arc::new(MockAdapter::new(vec![vec![ + run_started("run-1"), + text_content("Hi", "msg-1"), + run_finished("stop", "run-1"), + ]])); + + let chat_adapter = adapter.clone(); + + let _ = chat(ChatOptions { + adapter: chat_adapter, + messages: vec![ModelMessage { + role: MessageRole::User, + content: MessageContent::Text("Hello".to_string()), + name: None, + tool_calls: None, + 
tool_call_id: None, + }], + system_prompts: vec![], + tools: vec![], + temperature: None, + top_p: None, + max_tokens: None, + metadata: None, + model_options: None, + agent_loop_strategy: None, + conversation_id: None, + middleware: vec![], + stream: true, + output_schema: None, + }) + .await + .unwrap(); + + let calls = adapter.calls(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].messages.len(), 1); + assert_eq!(calls[0].messages[0].role, MessageRole::User); +} + +#[tokio::test] +async fn test_streaming_text_passes_options() { + let adapter = Arc::new(MockAdapter::new(vec![vec![ + run_started("run-1"), + run_finished("stop", "run-1"), + ]])); + + let _ = chat(ChatOptions { + adapter: adapter.clone(), + messages: vec![ModelMessage { + role: MessageRole::User, + content: MessageContent::Text("Hello".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }], + system_prompts: vec!["You are helpful".to_string()], + tools: vec![], + temperature: Some(0.5), + top_p: Some(0.9), + max_tokens: Some(100), + metadata: None, + model_options: None, + agent_loop_strategy: None, + conversation_id: None, + middleware: vec![], + stream: true, + output_schema: None, + }) + .await + .unwrap(); + + let calls = adapter.calls(); + assert_eq!(calls[0].system_prompts, vec!["You are helpful"]); + assert_eq!(calls[0].temperature, Some(0.5)); + assert_eq!(calls[0].top_p, Some(0.9)); + assert_eq!(calls[0].max_tokens, Some(100)); +} + +// ============================================================================ +// Non-streaming text (stream: false) +// ============================================================================ + +#[tokio::test] +async fn test_non_streaming_returns_collected_text() { + let adapter = Arc::new(MockAdapter::new(vec![vec![ + run_started("run-1"), + text_start("msg-1"), + text_content("Hello", "msg-1"), + text_content(" world!", "msg-1"), + text_end("msg-1"), + run_finished("stop", "run-1"), + ]])); + + let result = chat(ChatOptions { 
+ adapter, + messages: vec![ModelMessage { + role: MessageRole::User, + content: MessageContent::Text("Hi".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }], + system_prompts: vec![], + tools: vec![], + temperature: None, + top_p: None, + max_tokens: None, + metadata: None, + model_options: None, + agent_loop_strategy: None, + conversation_id: None, + middleware: vec![], + stream: false, + output_schema: None, + }) + .await + .unwrap(); + + match result { + ChatResult::Text(text) => assert_eq!(text, "Hello world!"), + _ => panic!("Expected Text result"), + } +} + +// ============================================================================ +// Server tool execution +// ============================================================================ + +#[tokio::test] +async fn test_server_tool_execution() { + let tool = server_tool("getWeather", serde_json::json!({"temp": 72})); + let adapter = Arc::new(MockAdapter::new(vec![ + // First iteration: model requests tool + vec![ + run_started("run-1"), + text_start("msg-1"), + text_content("Let me check.", "msg-1"), + text_end("msg-1"), + tool_start("call_1", "getWeather", None), + tool_args("call_1", "{\"city\":\"NYC\"}"), + run_finished("tool_calls", "run-1"), + ], + // Second iteration: model produces final text + vec![ + run_started("run-2"), + text_start("msg-2"), + text_content("72F in NYC.", "msg-2"), + text_end("msg-2"), + run_finished("stop", "run-2"), + ], + ])); + + let result = chat(ChatOptions { + adapter: adapter.clone(), + messages: vec![ModelMessage { + role: MessageRole::User, + content: MessageContent::Text("Weather?".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }], + system_prompts: vec![], + tools: vec![tool], + temperature: None, + top_p: None, + max_tokens: None, + metadata: None, + model_options: None, + agent_loop_strategy: None, + conversation_id: None, + middleware: vec![], + stream: true, + output_schema: None, + }) + .await + .unwrap(); + + 
// Adapter was called twice (tool call + final text) + assert_eq!(adapter.call_count(), 2); + + // Second call should have tool result in messages + let calls = adapter.calls(); + let second_call_messages = &calls[1].messages; + let has_tool_result = second_call_messages + .iter() + .any(|m| m.role == MessageRole::Tool); + assert!(has_tool_result, "Expected tool result message in second call"); + + // Should have TOOL_CALL_END chunks + match result { + ChatResult::Chunks(chunks) => { + let tool_ends = count_chunk_type(&chunks, "TOOL_CALL_END"); + assert!(tool_ends >= 1, "Expected at least one TOOL_CALL_END"); + } + _ => panic!("Expected Chunks result"), + } +} + +#[tokio::test] +async fn test_tool_execution_error_handled_gracefully() { + // Create a tool that always fails + let tool = Tool::new("failTool", "A tool that fails").with_execute( + |_args: serde_json::Value, _ctx| async { + Err(AiError::ToolExecution("Tool broke".to_string())) + }, + ); + + let adapter = Arc::new(MockAdapter::new(vec![ + vec![ + run_started("run-1"), + tool_start("call_1", "failTool", None), + tool_args("call_1", "{}"), + run_finished("tool_calls", "run-1"), + ], + vec![ + run_started("run-2"), + text_content("Error happened.", "msg-2"), + run_finished("stop", "run-2"), + ], + ])); + + let result = chat(ChatOptions { + adapter, + messages: vec![ModelMessage { + role: MessageRole::User, + content: MessageContent::Text("Do something".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }], + system_prompts: vec![], + tools: vec![tool], + temperature: None, + top_p: None, + max_tokens: None, + metadata: None, + model_options: None, + agent_loop_strategy: None, + conversation_id: None, + middleware: vec![], + stream: true, + output_schema: None, + }) + .await + .unwrap(); + + match result { + ChatResult::Chunks(chunks) => { + // Should still complete with error result + let tool_ends: Vec<_> = chunks + .iter() + .filter(|c| matches!(c, StreamChunk::ToolCallEnd { .. 
})) + .collect(); + assert!(!tool_ends.is_empty(), "Expected TOOL_CALL_END with error"); + } + _ => panic!("Expected Chunks result"), + } +} + +// ============================================================================ +// Parallel tool calls +// ============================================================================ + +#[tokio::test] +async fn test_parallel_tool_calls() { + let weather_tool = server_tool("getWeather", serde_json::json!({"temp": 72})); + let time_tool = server_tool("getTime", serde_json::json!({"time": "3pm"})); + + let adapter = Arc::new(MockAdapter::new(vec![ + // Model requests two tools + vec![ + run_started("run-1"), + tool_start("call_1", "getWeather", Some(0)), + tool_start("call_2", "getTime", Some(1)), + tool_args("call_1", "{\"city\":\"NYC\"}"), + tool_args("call_2", "{\"tz\":\"EST\"}"), + run_finished("tool_calls", "run-1"), + ], + // Model produces final text + vec![ + run_started("run-2"), + text_content("It's 3pm and 72F in NYC.", "msg-2"), + run_finished("stop", "run-2"), + ], + ])); + + let _result = chat(ChatOptions { + adapter: adapter.clone(), + messages: vec![ModelMessage { + role: MessageRole::User, + content: MessageContent::Text("Weather and time?".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }], + system_prompts: vec![], + tools: vec![weather_tool, time_tool], + temperature: None, + top_p: None, + max_tokens: None, + metadata: None, + model_options: None, + agent_loop_strategy: None, + conversation_id: None, + middleware: vec![], + stream: true, + output_schema: None, + }) + .await + .unwrap(); + + // Second call should have two tool result messages + let calls = adapter.calls(); + let second_call_messages = &calls[1].messages; + let tool_results: Vec<_> = second_call_messages + .iter() + .filter(|m| m.role == MessageRole::Tool) + .collect(); + assert_eq!(tool_results.len(), 2, "Expected 2 tool results"); +} + +// 
============================================================================ +// Multi-iteration agent loop +// ============================================================================ + +#[tokio::test] +async fn test_multi_iteration_agent_loop() { + let tool1 = server_tool("search", serde_json::json!({"results": "found"})); + let tool2 = server_tool("analyze", serde_json::json!({"analysis": "done"})); + + let adapter = Arc::new(MockAdapter::new(vec![ + // Iteration 1: search + vec![ + run_started("run-1"), + tool_start("call_1", "search", None), + tool_args("call_1", "{\"query\":\"test\"}"), + run_finished("tool_calls", "run-1"), + ], + // Iteration 2: analyze + vec![ + run_started("run-2"), + tool_start("call_2", "analyze", None), + tool_args("call_2", "{\"data\":\"found\"}"), + run_finished("tool_calls", "run-2"), + ], + // Iteration 3: final text + vec![ + run_started("run-3"), + text_content("Analysis complete.", "msg-3"), + run_finished("stop", "run-3"), + ], + ])); + + let result = chat(ChatOptions { + adapter: adapter.clone(), + messages: vec![ModelMessage { + role: MessageRole::User, + content: MessageContent::Text("Research this".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }], + system_prompts: vec![], + tools: vec![tool1, tool2], + temperature: None, + top_p: None, + max_tokens: None, + metadata: None, + model_options: None, + agent_loop_strategy: None, // Uses default max_iterations(5) + conversation_id: None, + middleware: vec![], + stream: true, + output_schema: None, + }) + .await + .unwrap(); + + // Should have called adapter 3 times + assert_eq!(adapter.call_count(), 3); + + // Final text should be present + match result { + ChatResult::Chunks(chunks) => { + let text = collect_text(&chunks); + assert_eq!(text, "Analysis complete."); + } + _ => panic!("Expected Chunks result"), + } +} + +// ============================================================================ +// Agent loop strategy enforcement +// 
============================================================================ + +#[tokio::test] +async fn test_agent_loop_respects_max_iterations() { + let tool = server_tool("loop", serde_json::json!({"result": "ok"})); + + // Adapter always returns tool_calls - would loop forever without strategy + let adapter = Arc::new(MockAdapter::new(vec![ + vec![ + run_started("run-1"), + tool_start("call_1", "loop", None), + tool_args("call_1", "{}"), + run_finished("tool_calls", "run-1"), + ], + vec![ + run_started("run-2"), + tool_start("call_2", "loop", None), + tool_args("call_2", "{}"), + run_finished("tool_calls", "run-2"), + ], + vec![ + run_started("run-3"), + tool_start("call_3", "loop", None), + tool_args("call_3", "{}"), + run_finished("tool_calls", "run-3"), + ], + ])); + + let _result = chat(ChatOptions { + adapter: adapter.clone(), + messages: vec![ModelMessage { + role: MessageRole::User, + content: MessageContent::Text("Loop test".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }], + system_prompts: vec![], + tools: vec![tool], + temperature: None, + top_p: None, + max_tokens: None, + metadata: None, + model_options: None, + agent_loop_strategy: Some(max_iterations(2)), // Only allow 2 iterations + conversation_id: None, + middleware: vec![], + stream: true, + output_schema: None, + }) + .await + .unwrap(); + + // Should stop at 2 iterations even though adapter has 3 + assert_eq!(adapter.call_count(), 2); +} + +// ============================================================================ +// Non-streaming with tools +// ============================================================================ + +#[tokio::test] +async fn test_non_streaming_still_executes_tools() { + let tool = server_tool("getWeather", serde_json::json!({"temp": 72})); + + let adapter = Arc::new(MockAdapter::new(vec![ + vec![ + run_started("run-1"), + tool_start("call_1", "getWeather", None), + tool_args("call_1", "{\"city\":\"NYC\"}"), + run_finished("tool_calls", 
"run-1"), + ], + vec![ + run_started("run-2"), + text_content("72F in NYC", "msg-2"), + run_finished("stop", "run-2"), + ], + ])); + + let result = chat(ChatOptions { + adapter, + messages: vec![ModelMessage { + role: MessageRole::User, + content: MessageContent::Text("Weather in NYC?".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }], + system_prompts: vec![], + tools: vec![tool], + temperature: None, + top_p: None, + max_tokens: None, + metadata: None, + model_options: None, + agent_loop_strategy: None, + conversation_id: None, + middleware: vec![], + stream: false, + output_schema: None, + }) + .await + .unwrap(); + + match result { + ChatResult::Text(text) => assert_eq!(text, "72F in NYC"), + _ => panic!("Expected Text result"), + } +} + +// ============================================================================ +// Text accumulation across chunks +// ============================================================================ + +#[tokio::test] +async fn test_text_accumulates_across_content_chunks() { + let adapter = Arc::new(MockAdapter::new(vec![vec![ + run_started("run-1"), + text_start("msg-1"), + text_content("Hello", "msg-1"), + text_content(" ", "msg-1"), + text_content("world", "msg-1"), + text_content("!", "msg-1"), + text_end("msg-1"), + run_finished("stop", "run-1"), + ]])); + + let result = chat(ChatOptions { + adapter, + messages: vec![ModelMessage { + role: MessageRole::User, + content: MessageContent::Text("Hi".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }], + system_prompts: vec![], + tools: vec![], + temperature: None, + top_p: None, + max_tokens: None, + metadata: None, + model_options: None, + agent_loop_strategy: None, + conversation_id: None, + middleware: vec![], + stream: false, + output_schema: None, + }) + .await + .unwrap(); + + match result { + ChatResult::Text(text) => assert_eq!(text, "Hello world!"), + _ => panic!("Expected Text result"), + } +} + +// 
============================================================================ +// Custom events from tool execution +// ============================================================================ + +#[tokio::test] +async fn test_tool_receives_correct_arguments() { + let (tool, captured) = capturing_tool("getWeather"); + + let adapter = Arc::new(MockAdapter::new(vec![ + vec![ + run_started("run-1"), + tool_start("call_1", "getWeather", None), + tool_args("call_1", "{\"city\":\"NYC\",\"unit\":\"fahrenheit\"}"), + run_finished("tool_calls", "run-1"), + ], + vec![ + run_started("run-2"), + text_content("Done.", "msg-2"), + run_finished("stop", "run-2"), + ], + ])); + + let _ = chat(ChatOptions { + adapter, + messages: vec![ModelMessage { + role: MessageRole::User, + content: MessageContent::Text("Weather?".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }], + system_prompts: vec![], + tools: vec![tool], + temperature: None, + top_p: None, + max_tokens: None, + metadata: None, + model_options: None, + agent_loop_strategy: None, + conversation_id: None, + middleware: vec![], + stream: true, + output_schema: None, + }) + .await + .unwrap(); + + // Check captured arguments + let captured_args = captured.lock().unwrap(); + assert_eq!(captured_args.len(), 1); + assert_eq!(captured_args[0]["city"], "NYC"); + assert_eq!(captured_args[0]["unit"], "fahrenheit"); +} + +// ============================================================================ +// Thinking/step events +// ============================================================================ + +#[tokio::test] +async fn test_thinking_step_events_passed_through() { + let adapter = Arc::new(MockAdapter::new(vec![vec![ + run_started("run-1"), + text_start("msg-1"), + step_finished("Let me think...", "step-1"), + text_content("Here is my answer.", "msg-1"), + text_end("msg-1"), + run_finished("stop", "run-1"), + ]])); + + let result = chat(ChatOptions { + adapter, + messages: vec![ModelMessage { + 
role: MessageRole::User, + content: MessageContent::Text("Think about it".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }], + system_prompts: vec![], + tools: vec![], + temperature: None, + top_p: None, + max_tokens: None, + metadata: None, + model_options: None, + agent_loop_strategy: None, + conversation_id: None, + middleware: vec![], + stream: true, + output_schema: None, + }) + .await + .unwrap(); + + match result { + ChatResult::Chunks(chunks) => { + let step_chunks: Vec<_> = chunks + .iter() + .filter(|c| matches!(c, StreamChunk::StepFinished { .. })) + .collect(); + assert_eq!(step_chunks.len(), 1, "Expected StepFinished chunk"); + let text = collect_text(&chunks); + assert_eq!(text, "Here is my answer."); + } + _ => panic!("Expected Chunks result"), + } +} + +// ============================================================================ +// Error handling +// ============================================================================ + +#[tokio::test] +async fn test_run_error_from_adapter() { + let adapter = Arc::new(MockAdapter::new(vec![vec![ + run_started("run-1"), + run_error("Something went wrong", "run-1"), + ]])); + + let result = chat(ChatOptions { + adapter, + messages: vec![ModelMessage { + role: MessageRole::User, + content: MessageContent::Text("Hi".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }], + system_prompts: vec![], + tools: vec![], + temperature: None, + top_p: None, + max_tokens: None, + metadata: None, + model_options: None, + agent_loop_strategy: None, + conversation_id: None, + middleware: vec![], + stream: true, + output_schema: None, + }) + .await + .unwrap(); + + match result { + ChatResult::Chunks(chunks) => { + let error_chunks: Vec<_> = chunks + .iter() + .filter(|c| matches!(c, StreamChunk::RunError { .. 
})) + .collect(); + assert_eq!(error_chunks.len(), 1); + } + _ => panic!("Expected Chunks result"), + } +} + +// ============================================================================ +// Empty adapter response +// ============================================================================ + +#[tokio::test] +async fn test_empty_adapter_response() { + let adapter = Arc::new(MockAdapter::new(vec![vec![]])); + + let result = chat(ChatOptions { + adapter, + messages: vec![ModelMessage { + role: MessageRole::User, + content: MessageContent::Text("Hi".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }], + system_prompts: vec![], + tools: vec![], + temperature: None, + top_p: None, + max_tokens: None, + metadata: None, + model_options: None, + agent_loop_strategy: None, + conversation_id: None, + middleware: vec![], + stream: true, + output_schema: None, + }) + .await + .unwrap(); + + match result { + ChatResult::Chunks(chunks) => assert!(chunks.is_empty()), + _ => panic!("Expected Chunks result"), + } +} + +// ============================================================================ +// Tool with tool_calls finish but no TOOL_CALL events +// ============================================================================ + +#[tokio::test] +async fn test_tool_calls_finish_without_tool_events() { + // Edge case: finish_reason is tool_calls but no tool call events were emitted + let adapter = Arc::new(MockAdapter::new(vec![ + vec![ + run_started("run-1"), + text_content("I tried.", "msg-1"), + run_finished("tool_calls", "run-1"), + ], + vec![ + run_started("run-2"), + text_content("OK.", "msg-2"), + run_finished("stop", "run-2"), + ], + ])); + + let tool = server_tool("unused", serde_json::json!({})); + + let _result = chat(ChatOptions { + adapter: adapter.clone(), + messages: vec![ModelMessage { + role: MessageRole::User, + content: MessageContent::Text("Hi".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }], + 
system_prompts: vec![], + tools: vec![tool], + temperature: None, + top_p: None, + max_tokens: None, + metadata: None, + model_options: None, + agent_loop_strategy: None, + conversation_id: None, + middleware: vec![], + stream: true, + output_schema: None, + }) + .await + .unwrap(); + + // Should have called adapter twice (no tool calls to execute, but still loops) + assert_eq!(adapter.call_count(), 2); +} diff --git a/packages/rust/tanstack-ai/tests/stream_processor_tests.rs b/packages/rust/tanstack-ai/tests/stream_processor_tests.rs new file mode 100644 index 00000000..5dd022be --- /dev/null +++ b/packages/rust/tanstack-ai/tests/stream_processor_tests.rs @@ -0,0 +1,208 @@ +//! Stream processor tests. + +use tanstack_ai::stream::strategies::*; +use tanstack_ai::stream::StreamProcessor; +use tanstack_ai::types::*; + +fn now() -> f64 { + chrono::Utc::now().timestamp_millis() as f64 / 1000.0 +} + +#[test] +fn test_processor_text_streaming() { + let mut processor = StreamProcessor::new(); + + processor.process_chunk(StreamChunk::TextMessageStart { + timestamp: now(), + message_id: "msg-1".to_string(), + role: "assistant".to_string(), + model: None, + }); + + processor.process_chunk(StreamChunk::TextMessageContent { + timestamp: now(), + message_id: "msg-1".to_string(), + delta: "Hello".to_string(), + content: None, + model: None, + }); + + processor.process_chunk(StreamChunk::TextMessageContent { + timestamp: now(), + message_id: "msg-1".to_string(), + delta: " world!".to_string(), + content: None, + model: None, + }); + + processor.process_chunk(StreamChunk::TextMessageEnd { + timestamp: now(), + message_id: "msg-1".to_string(), + model: None, + }); + + let result = processor.result(); + assert_eq!(result.content, "Hello world!"); +} + +#[test] +fn test_processor_with_full_content() { + let mut processor = StreamProcessor::new(); + + processor.process_chunk(StreamChunk::TextMessageStart { + timestamp: now(), + message_id: "msg-1".to_string(), + role: 
"assistant".to_string(), + model: None, + }); + + processor.process_chunk(StreamChunk::TextMessageContent { + timestamp: now(), + message_id: "msg-1".to_string(), + delta: "Final".to_string(), + content: Some("Full content here".to_string()), + model: None, + }); + + let result = processor.result(); + assert_eq!(result.content, "Full content here"); +} + +#[test] +fn test_processor_tool_calls() { + let mut processor = StreamProcessor::new(); + + processor.process_chunk(StreamChunk::ToolCallStart { + timestamp: now(), + tool_call_id: "call_1".to_string(), + tool_name: "getWeather".to_string(), + parent_message_id: None, + index: Some(0), + provider_metadata: None, + model: None, + }); + + processor.process_chunk(StreamChunk::ToolCallArgs { + timestamp: now(), + tool_call_id: "call_1".to_string(), + delta: "{\"city\":".to_string(), + args: None, + model: None, + }); + + processor.process_chunk(StreamChunk::ToolCallArgs { + timestamp: now(), + tool_call_id: "call_1".to_string(), + delta: "\"NYC\"}".to_string(), + args: None, + model: None, + }); + + processor.process_chunk(StreamChunk::ToolCallEnd { + timestamp: now(), + tool_call_id: "call_1".to_string(), + tool_name: "getWeather".to_string(), + input: Some(serde_json::json!({"city": "NYC"})), + result: None, + model: None, + }); + + let result = processor.result(); + assert!(result.tool_calls.is_some()); + let tool_calls = result.tool_calls.unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].function.name, "getWeather"); + assert_eq!(tool_calls[0].function.arguments, "{\"city\":\"NYC\"}"); +} + +#[test] +fn test_processor_with_batch_strategy() { + let mut processor = StreamProcessor::with_strategy(Box::new(BatchStrategy::new(2))); + + // First chunk: should not emit (batch size 2) + let r1 = processor.process_chunk(StreamChunk::TextMessageStart { + timestamp: now(), + message_id: "msg-1".to_string(), + role: "assistant".to_string(), + model: None, + }); + assert!(r1.is_some()); // Start always 
passes through + + let r2 = processor.process_chunk(StreamChunk::TextMessageContent { + timestamp: now(), + message_id: "msg-1".to_string(), + delta: "a".to_string(), + content: None, + model: None, + }); + assert!(r2.is_none()); // Batch: not enough chunks yet + + let r3 = processor.process_chunk(StreamChunk::TextMessageContent { + timestamp: now(), + message_id: "msg-1".to_string(), + delta: "b".to_string(), + content: None, + model: None, + }); + assert!(r3.is_some()); // Batch: reached 2 chunks +} + +#[test] +fn test_processor_recording() { + let mut processor = StreamProcessor::new(); + processor.enable_recording(); + + processor.process_chunk(StreamChunk::RunStarted { + timestamp: now(), + run_id: "run-1".to_string(), + thread_id: None, + model: None, + }); + + processor.process_chunk(StreamChunk::TextMessageContent { + timestamp: now(), + message_id: "msg-1".to_string(), + delta: "Hello".to_string(), + content: None, + model: None, + }); + + assert_eq!(processor.recordings().len(), 2); + + let recording = processor.to_recording(Some("gpt-4o".to_string()), Some("openai".to_string())); + assert_eq!(recording.version, "1.0"); + assert_eq!(recording.model, Some("gpt-4o".to_string())); + assert_eq!(recording.chunks.len(), 2); +} + +#[test] +fn test_processor_thinking_content() { + let mut processor = StreamProcessor::new(); + + processor.process_chunk(StreamChunk::TextMessageStart { + timestamp: now(), + message_id: "msg-1".to_string(), + role: "assistant".to_string(), + model: None, + }); + + processor.process_chunk(StreamChunk::StepFinished { + timestamp: now(), + step_id: "step-1".to_string(), + delta: "Let me think...".to_string(), + content: None, + model: None, + }); + + processor.process_chunk(StreamChunk::TextMessageContent { + timestamp: now(), + message_id: "msg-1".to_string(), + delta: "Here's my answer.".to_string(), + content: None, + model: None, + }); + + let result = processor.result(); + assert_eq!(result.thinking, Some("Let me 
think...".to_string())); + assert_eq!(result.content, "Here's my answer."); +} diff --git a/packages/rust/tanstack-ai/tests/test_utils.rs b/packages/rust/tanstack-ai/tests/test_utils.rs new file mode 100644 index 00000000..3b24cdc4 --- /dev/null +++ b/packages/rust/tanstack-ai/tests/test_utils.rs @@ -0,0 +1,290 @@ +//! Test utilities for tanstack-ai tests. +//! +//! Provides a mock adapter, chunk factories, and helpers mirroring +//! the TypeScript test-utils.ts. + +use async_trait::async_trait; +use futures_util::stream; +use std::sync::{Arc, Mutex}; +use tanstack_ai::adapter::{ChunkStream, TextAdapter}; +use tanstack_ai::error::AiResult; +use tanstack_ai::types::*; + +// ============================================================================ +// Chunk factory helpers +// ============================================================================ + +pub fn now() -> f64 { + chrono::Utc::now().timestamp_millis() as f64 / 1000.0 +} + +pub fn run_started(run_id: &str) -> StreamChunk { + StreamChunk::RunStarted { + timestamp: now(), + run_id: run_id.to_string(), + thread_id: None, + model: Some("mock-model".to_string()), + } +} + +pub fn text_start(message_id: &str) -> StreamChunk { + StreamChunk::TextMessageStart { + timestamp: now(), + message_id: message_id.to_string(), + role: "assistant".to_string(), + model: Some("mock-model".to_string()), + } +} + +pub fn text_content(delta: &str, message_id: &str) -> StreamChunk { + StreamChunk::TextMessageContent { + timestamp: now(), + message_id: message_id.to_string(), + delta: delta.to_string(), + content: None, + model: Some("mock-model".to_string()), + } +} + +pub fn _text_content_with_full(delta: &str, full: &str, message_id: &str) -> StreamChunk { + StreamChunk::TextMessageContent { + timestamp: now(), + message_id: message_id.to_string(), + delta: delta.to_string(), + content: Some(full.to_string()), + model: Some("mock-model".to_string()), + } +} + +pub fn text_end(message_id: &str) -> StreamChunk { + 
StreamChunk::TextMessageEnd { + timestamp: now(), + message_id: message_id.to_string(), + model: Some("mock-model".to_string()), + } +} + +pub fn tool_start(tool_call_id: &str, tool_name: &str, index: Option) -> StreamChunk { + StreamChunk::ToolCallStart { + timestamp: now(), + tool_call_id: tool_call_id.to_string(), + tool_name: tool_name.to_string(), + parent_message_id: None, + index, + provider_metadata: None, + model: Some("mock-model".to_string()), + } +} + +pub fn tool_args(tool_call_id: &str, delta: &str) -> StreamChunk { + StreamChunk::ToolCallArgs { + timestamp: now(), + tool_call_id: tool_call_id.to_string(), + delta: delta.to_string(), + args: None, + model: Some("mock-model".to_string()), + } +} + +pub fn _tool_end( + tool_call_id: &str, + tool_name: &str, + input: Option, + result: Option, +) -> StreamChunk { + StreamChunk::ToolCallEnd { + timestamp: now(), + tool_call_id: tool_call_id.to_string(), + tool_name: tool_name.to_string(), + input, + result, + model: Some("mock-model".to_string()), + } +} + +pub fn run_finished( + finish_reason: &str, + run_id: &str, +) -> StreamChunk { + StreamChunk::RunFinished { + timestamp: now(), + run_id: run_id.to_string(), + finish_reason: Some(finish_reason.to_string()), + usage: None, + model: Some("mock-model".to_string()), + } +} + +pub fn run_error(message: &str, run_id: &str) -> StreamChunk { + StreamChunk::RunError { + timestamp: now(), + run_id: Some(run_id.to_string()), + error: RunErrorData { + message: message.to_string(), + code: None, + }, + model: Some("mock-model".to_string()), + } +} + +pub fn step_finished(delta: &str, step_id: &str) -> StreamChunk { + StreamChunk::StepFinished { + timestamp: now(), + step_id: step_id.to_string(), + delta: delta.to_string(), + content: None, + model: Some("mock-model".to_string()), + } +} + +// ============================================================================ +// Mock adapter +// ============================================================================ 
+ +/// Tracks calls made to the mock adapter. +#[derive(Debug, Clone)] +pub struct MockCall { + pub messages: Vec, + pub _tools: Vec, + pub system_prompts: Vec, + pub temperature: Option, + pub top_p: Option, + pub max_tokens: Option, +} + +/// A mock adapter that returns predetermined chunks per iteration. +pub struct MockAdapter { + iterations: Vec>, + calls: Arc>>, + call_index: Arc>, +} + +impl MockAdapter { + pub fn new(iterations: Vec>) -> Self { + Self { + iterations, + calls: Arc::new(Mutex::new(Vec::new())), + call_index: Arc::new(Mutex::new(0)), + } + } + + pub fn calls(&self) -> Vec { + self.calls.lock().unwrap().clone() + } + + pub fn call_count(&self) -> usize { + self.calls.lock().unwrap().len() + } +} + +#[async_trait] +impl TextAdapter for MockAdapter { + fn name(&self) -> &str { + "mock" + } + + fn model(&self) -> &str { + "mock-model" + } + + async fn chat_stream(&self, options: &TextOptions) -> AiResult { + // Record the call + let call = MockCall { + messages: options.messages.clone(), + _tools: options.tools.iter().map(|t| t.name.clone()).collect(), + system_prompts: options.system_prompts.clone(), + temperature: options.temperature, + top_p: options.top_p, + max_tokens: options.max_tokens, + }; + self.calls.lock().unwrap().push(call); + + // Get chunks for this iteration + let mut idx = self.call_index.lock().unwrap(); + let chunks = self.iterations.get(*idx).cloned().unwrap_or_default(); + *idx += 1; + + let stream = stream::iter(chunks.into_iter().map(Ok)); + Ok(Box::pin(stream)) + } + + async fn structured_output( + &self, + _options: &StructuredOutputOptions, + ) -> AiResult { + Ok(StructuredOutputResult { + data: serde_json::json!({}), + raw_text: "{}".to_string(), + }) + } +} + +// ============================================================================ +// Collect helper +// ============================================================================ + +/// Collect all text content from chunks. 
+pub fn collect_text(chunks: &[StreamChunk]) -> String { + let mut content = String::new(); + for chunk in chunks { + if let StreamChunk::TextMessageContent { + delta, content: full, .. + } = chunk + { + if let Some(full_content) = full { + content = full_content.clone(); + } else { + content.push_str(delta); + } + } + } + content +} + +/// Count chunks of a specific type. +pub fn count_chunk_type(chunks: &[StreamChunk], chunk_type: &str) -> usize { + chunks + .iter() + .filter(|c| match chunk_type { + "RUN_STARTED" => matches!(c, StreamChunk::RunStarted { .. }), + "RUN_FINISHED" => matches!(c, StreamChunk::RunFinished { .. }), + "TEXT_MESSAGE_START" => matches!(c, StreamChunk::TextMessageStart { .. }), + "TEXT_MESSAGE_CONTENT" => matches!(c, StreamChunk::TextMessageContent { .. }), + "TEXT_MESSAGE_END" => matches!(c, StreamChunk::TextMessageEnd { .. }), + "TOOL_CALL_START" => matches!(c, StreamChunk::ToolCallStart { .. }), + "TOOL_CALL_ARGS" => matches!(c, StreamChunk::ToolCallArgs { .. }), + "TOOL_CALL_END" => matches!(c, StreamChunk::ToolCallEnd { .. }), + _ => false, + }) + .count() +} + +/// Helper to create a simple server tool for testing. +pub fn server_tool( + name: &str, + result: serde_json::Value, +) -> Tool { + let result_clone = result.clone(); + Tool::new(name, format!("Test tool: {}", name)).with_execute( + move |_args: serde_json::Value, _ctx: tanstack_ai::types::ToolExecutionContext| { + let r = result_clone.clone(); + async move { Ok(r) } + }, + ) +} + +/// Helper to create a server tool that captures its arguments. 
+pub fn capturing_tool(name: &str) -> (Tool, Arc>>) { + let captured: Arc>> = Arc::new(Mutex::new(Vec::new())); + let captured_clone = captured.clone(); + let tool = Tool::new(name, format!("Capturing tool: {}", name)).with_execute( + move |args: serde_json::Value, _ctx: tanstack_ai::types::ToolExecutionContext| { + let c = captured_clone.clone(); + async move { + c.lock().unwrap().push(args.clone()); + Ok(args) + } + }, + ); + (tool, captured) +} diff --git a/packages/rust/tanstack-ai/tests/tool_call_manager_tests.rs b/packages/rust/tanstack-ai/tests/tool_call_manager_tests.rs new file mode 100644 index 00000000..72f7f0a9 --- /dev/null +++ b/packages/rust/tanstack-ai/tests/tool_call_manager_tests.rs @@ -0,0 +1,170 @@ +//! Tool call manager tests. + +use tanstack_ai::tools::ToolCallManager; +use tanstack_ai::types::*; + +fn now() -> f64 { + chrono::Utc::now().timestamp_millis() as f64 / 1000.0 +} + +#[test] +fn test_manager_accumulate_tool_calls() { + let mut manager = ToolCallManager::new(); + + manager.add_start_event(&StreamChunk::ToolCallStart { + timestamp: now(), + tool_call_id: "call_1".to_string(), + tool_name: "getWeather".to_string(), + parent_message_id: None, + index: Some(0), + provider_metadata: None, + model: None, + }); + + assert!(manager.has_tool_calls()); + let calls = manager.tool_calls(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].function.name, "getWeather"); +} + +#[test] +fn test_manager_accumulate_arguments() { + let mut manager = ToolCallManager::new(); + + manager.add_start_event(&StreamChunk::ToolCallStart { + timestamp: now(), + tool_call_id: "call_1".to_string(), + tool_name: "getWeather".to_string(), + parent_message_id: None, + index: Some(0), + provider_metadata: None, + model: None, + }); + + manager.add_args_event(&StreamChunk::ToolCallArgs { + timestamp: now(), + tool_call_id: "call_1".to_string(), + delta: "{\"city\":".to_string(), + args: None, + model: None, + }); + + manager.add_args_event(&StreamChunk::ToolCallArgs 
{ + timestamp: now(), + tool_call_id: "call_1".to_string(), + delta: "\"NYC\"}".to_string(), + args: None, + model: None, + }); + + let calls = manager.tool_calls(); + assert_eq!(calls[0].function.arguments, "{\"city\":\"NYC\"}"); +} + +#[test] +fn test_manager_complete_tool_call() { + let mut manager = ToolCallManager::new(); + + manager.add_start_event(&StreamChunk::ToolCallStart { + timestamp: now(), + tool_call_id: "call_1".to_string(), + tool_name: "getWeather".to_string(), + parent_message_id: None, + index: Some(0), + provider_metadata: None, + model: None, + }); + + manager.complete_tool_call(&StreamChunk::ToolCallEnd { + timestamp: now(), + tool_call_id: "call_1".to_string(), + tool_name: "getWeather".to_string(), + input: Some(serde_json::json!({"city": "NYC"})), + result: None, + model: None, + }); + + let calls = manager.tool_calls(); + assert_eq!(calls[0].function.arguments, "{\"city\":\"NYC\"}"); +} + +#[test] +fn test_manager_clear() { + let mut manager = ToolCallManager::new(); + + manager.add_start_event(&StreamChunk::ToolCallStart { + timestamp: now(), + tool_call_id: "call_1".to_string(), + tool_name: "getWeather".to_string(), + parent_message_id: None, + index: Some(0), + provider_metadata: None, + model: None, + }); + + assert!(manager.has_tool_calls()); + manager.clear(); + assert!(!manager.has_tool_calls()); +} + +#[test] +fn test_manager_filters_incomplete_calls() { + let mut manager = ToolCallManager::new(); + + // Add call with empty name (incomplete) + manager.add_start_event(&StreamChunk::ToolCallStart { + timestamp: now(), + tool_call_id: "call_1".to_string(), + tool_name: "".to_string(), + parent_message_id: None, + index: Some(0), + provider_metadata: None, + model: None, + }); + + assert!(!manager.has_tool_calls()); +} + +#[test] +fn test_manager_parallel_tool_calls() { + let mut manager = ToolCallManager::new(); + + manager.add_start_event(&StreamChunk::ToolCallStart { + timestamp: now(), + tool_call_id: "call_1".to_string(), + 
tool_name: "getWeather".to_string(), + parent_message_id: None, + index: Some(0), + provider_metadata: None, + model: None, + }); + + manager.add_start_event(&StreamChunk::ToolCallStart { + timestamp: now(), + tool_call_id: "call_2".to_string(), + tool_name: "getTime".to_string(), + parent_message_id: None, + index: Some(1), + provider_metadata: None, + model: None, + }); + + manager.add_args_event(&StreamChunk::ToolCallArgs { + timestamp: now(), + tool_call_id: "call_1".to_string(), + delta: "{\"city\":\"NYC\"}".to_string(), + args: None, + model: None, + }); + + manager.add_args_event(&StreamChunk::ToolCallArgs { + timestamp: now(), + tool_call_id: "call_2".to_string(), + delta: "{\"tz\":\"EST\"}".to_string(), + args: None, + model: None, + }); + + let calls = manager.tool_calls(); + assert_eq!(calls.len(), 2); +} diff --git a/packages/rust/tanstack-ai/tests/unit_tests.rs b/packages/rust/tanstack-ai/tests/unit_tests.rs new file mode 100644 index 00000000..52bcfece --- /dev/null +++ b/packages/rust/tanstack-ai/tests/unit_tests.rs @@ -0,0 +1,730 @@ +//! Comprehensive tests for tanstack-ai. +//! +//! Ports the TypeScript test suite to Rust with full feature parity. 
+ +use tanstack_ai::*; + +// ============================================================================ +// Strategies Tests +// ============================================================================ + +#[test] +fn test_immediate_strategy_always_emits() { + let mut strategy = ImmediateStrategy; + assert!(strategy.should_emit("", "")); + assert!(strategy.should_emit("hello", "hello")); + assert!(strategy.should_emit("world", "hello world")); +} + +#[test] +fn test_immediate_strategy_empty_strings() { + let mut strategy = ImmediateStrategy; + assert!(strategy.should_emit("", "")); + assert!(strategy.should_emit("", "accumulated")); +} + +#[test] +fn test_punctuation_strategy_emits_on_punctuation() { + let mut strategy = PunctuationStrategy::new(); + assert!(strategy.should_emit("Hello.", "Hello.")); + assert!(strategy.should_emit("World!", "Hello. World!")); + assert!(strategy.should_emit("How?", "Hello. World! How?")); + assert!(strategy.should_emit("Test;", "Test;")); + assert!(strategy.should_emit("Test:", "Test:")); +} + +#[test] +fn test_punctuation_strategy_no_punctuation() { + let mut strategy = PunctuationStrategy::new(); + assert!(!strategy.should_emit("Hello", "Hello")); + assert!(!strategy.should_emit("world", "Hello world")); + assert!(!strategy.should_emit("test", "Hello world test")); +} + +#[test] +fn test_punctuation_strategy_newline() { + let mut strategy = PunctuationStrategy::new(); + assert!(strategy.should_emit("Line 1\n", "Line 1\n")); + assert!(strategy.should_emit("\nLine 2", "Line 1\n\nLine 2")); +} + +#[test] +fn test_punctuation_strategy_comma() { + let mut strategy = PunctuationStrategy::new(); + assert!(strategy.should_emit("Hello,", "Hello,")); + assert!(strategy.should_emit("world,", "Hello, world,")); +} + +#[test] +fn test_batch_strategy_emits_every_n() { + let mut strategy = BatchStrategy::new(3); + assert!(!strategy.should_emit("chunk1", "chunk1")); + assert!(!strategy.should_emit("chunk2", "chunk1chunk2")); + 
assert!(strategy.should_emit("chunk3", "chunk1chunk2chunk3")); + assert!(!strategy.should_emit("chunk4", "chunk1chunk2chunk3chunk4")); + assert!(!strategy.should_emit("chunk5", "chunk1chunk2chunk3chunk4chunk5")); + assert!(strategy.should_emit("chunk6", "chunk1chunk2chunk3chunk4chunk5chunk6")); +} + +#[test] +fn test_batch_strategy_default_size() { + let mut strategy = BatchStrategy::default(); + for i in 1..5 { + assert!(!strategy.should_emit(&format!("chunk{}", i), &"x".repeat(i))); + } + assert!(strategy.should_emit("chunk5", "xxxxx")); +} + +#[test] +fn test_batch_strategy_resets_after_emit() { + let mut strategy = BatchStrategy::new(2); + assert!(!strategy.should_emit("chunk1", "chunk1")); + assert!(strategy.should_emit("chunk2", "chunk1chunk2")); + // Counter resets + assert!(!strategy.should_emit("chunk3", "chunk1chunk2chunk3")); + assert!(strategy.should_emit("chunk4", "chunk1chunk2chunk3chunk4")); +} + +#[test] +fn test_batch_strategy_size_one() { + let mut strategy = BatchStrategy::new(1); + assert!(strategy.should_emit("chunk1", "chunk1")); + assert!(strategy.should_emit("chunk2", "chunk1chunk2")); + assert!(strategy.should_emit("chunk3", "chunk1chunk2chunk3")); +} + +#[test] +fn test_batch_strategy_reset_method() { + let mut strategy = BatchStrategy::new(3); + assert!(!strategy.should_emit("chunk1", "chunk1")); + assert!(!strategy.should_emit("chunk2", "chunk1chunk2")); + strategy.reset(); + assert!(!strategy.should_emit("chunk3", "chunk1chunk2chunk3")); + assert!(!strategy.should_emit("chunk4", "chunk1chunk2chunk3chunk4")); + assert!(strategy.should_emit("chunk5", "chunk1chunk2chunk3chunk4chunk5")); +} + +#[test] +fn test_word_boundary_strategy_whitespace() { + let mut strategy = WordBoundaryStrategy; + assert!(strategy.should_emit("Hello ", "Hello ")); + assert!(strategy.should_emit("world ", "Hello world ")); + assert!(strategy.should_emit("test\n", "Hello world test\n")); + assert!(strategy.should_emit("more\t", "Hello world test\nmore\t")); +} + 
+#[test] +fn test_word_boundary_strategy_no_whitespace() { + let mut strategy = WordBoundaryStrategy; + assert!(!strategy.should_emit("Hello", "Hello")); + assert!(!strategy.should_emit("world", "Helloworld")); + assert!(!strategy.should_emit("test", "Helloworldtest")); +} + +#[test] +fn test_composite_strategy_or_logic() { + let mut strategy = CompositeStrategy::new(vec![ + Box::new(ImmediateStrategy), + Box::new(PunctuationStrategy::new()), + ]); + // ImmediateStrategy always returns true + assert!(strategy.should_emit("hello", "hello")); + assert!(strategy.should_emit("world", "hello world")); +} + +#[test] +fn test_composite_strategy_all_false() { + let mut strategy = CompositeStrategy::new(vec![ + Box::new(BatchStrategy::new(10)), + Box::new(WordBoundaryStrategy), + ]); + assert!(!strategy.should_emit("Hello", "Hello")); + assert!(!strategy.should_emit("world", "Helloworld")); +} + +#[test] +fn test_composite_strategy_any_true() { + let mut strategy = CompositeStrategy::new(vec![ + Box::new(BatchStrategy::new(10)), + Box::new(WordBoundaryStrategy), + ]); + // Batch says no, but wordBoundary says yes + assert!(strategy.should_emit("Hello ", "Hello ")); +} + +#[test] +fn test_composite_strategy_reset() { + let mut strategy = CompositeStrategy::new(vec![ + Box::new(BatchStrategy::new(3)), + Box::new(BatchStrategy::new(5)), + ]); + strategy.should_emit("chunk1", "chunk1"); + strategy.should_emit("chunk2", "chunk1chunk2"); + strategy.reset(); + assert!(!strategy.should_emit("chunk3", "chunk1chunk2chunk3")); + assert!(!strategy.should_emit("chunk4", "chunk1chunk2chunk3chunk4")); + assert!(strategy.should_emit("chunk5", "chunk1chunk2chunk3chunk4chunk5")); +} + +#[test] +fn test_strategies_unicode() { + let mut punctuation = PunctuationStrategy::new(); + let mut word_boundary = WordBoundaryStrategy; + + assert!(punctuation.should_emit("Hello 世界.", "Hello 世界.")); + assert!(word_boundary.should_emit("世界 ", "Hello 世界 ")); + assert!(!word_boundary.should_emit("世界", "Hello 
世界")); +} + +#[test] +fn test_strategies_empty_chunks() { + assert!(ImmediateStrategy.should_emit("", "")); + assert!(!PunctuationStrategy::new().should_emit("", "")); + assert!(!WordBoundaryStrategy.should_emit("", "")); +} + +// ============================================================================ +// Agent Loop Strategy Tests +// ============================================================================ + +fn make_state(iteration_count: u32, finish_reason: Option<&str>) -> AgentLoopState { + AgentLoopState { + iteration_count, + messages: vec![], + finish_reason: finish_reason.map(String::from), + } +} + +#[test] +fn test_max_iterations_below_max() { + let strategy = max_iterations(5); + assert!(strategy(&make_state(0, None))); + assert!(strategy(&make_state(2, None))); + assert!(strategy(&make_state(4, None))); +} + +#[test] +fn test_max_iterations_at_max() { + let strategy = max_iterations(5); + assert!(!strategy(&make_state(5, None))); + assert!(!strategy(&make_state(6, None))); +} + +#[test] +fn test_max_iterations_one() { + let strategy = max_iterations(1); + assert!(strategy(&make_state(0, None))); + assert!(!strategy(&make_state(1, None))); +} + +#[test] +fn test_max_iterations_zero() { + let strategy = max_iterations(0); + assert!(!strategy(&make_state(0, None))); +} + +#[test] +fn test_until_finish_reason_stops_on_match() { + let strategy = until_finish_reason("stop".to_string()); + assert!(!strategy(&make_state(1, Some("stop")))); +} + +#[test] +fn test_until_finish_reason_continues_on_no_match() { + let strategy = until_finish_reason("stop".to_string()); + assert!(strategy(&make_state(1, Some("tool_calls")))); + assert!(strategy(&make_state(1, None))); +} + +#[test] +fn test_combine_strategies_all_true() { + let strategy = combine_strategies(vec![ + max_iterations(5), + Box::new(|state: &AgentLoopState| state.iteration_count < 10), + ]); + assert!(strategy(&make_state(2, None))); +} + +#[test] +fn test_combine_strategies_any_false() { + let 
strategy = combine_strategies(vec![ + max_iterations(5), + Box::new(|state: &AgentLoopState| state.iteration_count < 10), + ]); + assert!(!strategy(&make_state(5, None))); +} + +#[test] +fn test_combine_strategies_empty() { + let strategy = combine_strategies(vec![]); + assert!(strategy(&make_state(0, None))); +} + +// ============================================================================ +// Tool Definition Tests +// ============================================================================ + +#[test] +fn test_tool_definition_basic() { + let def = tool_definition("getWeather", "Get the weather for a location"); + assert_eq!(def.name, "getWeather"); + assert_eq!(def.description, "Get the weather for a location"); +} + +#[test] +fn test_tool_definition_with_schemas() { + let def = tool_definition("addToCart", "Add item to cart") + .input_schema(json_schema(serde_json::json!({ + "type": "object", + "properties": { + "itemId": { "type": "string" }, + "quantity": { "type": "number" } + } + }))) + .output_schema(json_schema(serde_json::json!({ + "type": "object", + "properties": { + "success": { "type": "boolean" }, + "cartId": { "type": "string" } + } + }))); + + assert!(def.input_schema.is_some()); + assert!(def.output_schema.is_some()); +} + +#[test] +fn test_tool_definition_to_tool() { + let def = tool_definition("simpleTool", "A simple tool"); + let tool = def.to_tool(); + assert_eq!(tool.name, "simpleTool"); + assert_eq!(tool.description, "A simple tool"); + assert!(!tool.needs_approval); + assert!(!tool.lazy); +} + +#[test] +fn test_tool_definition_needs_approval() { + let def = tool_definition("deleteFile", "Delete a file").needs_approval(true); + assert!(def.needs_approval); + + let tool = def.to_tool(); + assert!(tool.needs_approval); +} + +#[test] +fn test_tool_definition_lazy() { + let def = tool_definition("discoverableTool", "A lazy tool").lazy(true); + assert!(def.lazy); + + let tool = def.to_tool(); + assert!(tool.lazy); +} + +#[test] +fn 
test_tool_definition_metadata() { + let def = tool_definition("customTool", "A custom tool") + .metadata(serde_json::json!({"category": "utility"})); + assert!(def.metadata.is_some()); +} + +#[test] +fn test_tool_definition_to_server_tool() { + let def = tool_definition("compute", "Compute something"); + let tool = def.to_server_tool(|_args: serde_json::Value, _ctx| async move { + Ok(serde_json::json!({"result": 42})) + }); + assert_eq!(tool.name, "compute"); + assert!(tool.execute.is_some()); +} + +#[test] +fn test_tool_builder() { + let tool = Tool::new("get_weather", "Get the weather") + .with_input_schema(json_schema(serde_json::json!({ + "type": "object", + "properties": { "location": { "type": "string" } }, + "required": ["location"] + }))) + .with_approval(); + + assert_eq!(tool.name, "get_weather"); + assert!(tool.needs_approval); + assert!(tool.input_schema.is_some()); +} + +// ============================================================================ +// Message Converter Tests +// ============================================================================ + +#[test] +fn test_ui_to_model_simple_text() { + let ui_msg = UiMessage { + id: "msg-1".to_string(), + role: UiMessageRole::User, + parts: vec![MessagePart::Text { + content: "Hello".to_string(), + metadata: None, + }], + created_at: None, + }; + + let model_msgs = ui_message_to_model_messages(&ui_msg); + assert_eq!(model_msgs.len(), 1); + assert_eq!(model_msgs[0].role, MessageRole::User); + assert_eq!(model_msgs[0].content.as_str(), Some("Hello")); +} + +#[test] +fn test_ui_to_model_multiple_text_parts() { + let ui_msg = UiMessage { + id: "msg-1".to_string(), + role: UiMessageRole::User, + parts: vec![ + MessagePart::Text { + content: "Hello ".to_string(), + metadata: None, + }, + MessagePart::Text { + content: "world!".to_string(), + metadata: None, + }, + ], + created_at: None, + }; + + let model_msgs = ui_message_to_model_messages(&ui_msg); + assert_eq!(model_msgs.len(), 1); + // Multiple text 
parts are combined into MessageContent::Parts + match &model_msgs[0].content { + MessageContent::Parts(parts) => { + assert_eq!(parts.len(), 2); + } + _ => panic!("Expected Parts content for multiple text parts"), + } +} + +#[test] +fn test_ui_to_model_multimodal_image() { + let ui_msg = UiMessage { + id: "msg-1".to_string(), + role: UiMessageRole::User, + parts: vec![ + MessagePart::Text { + content: "What is in this image?".to_string(), + metadata: None, + }, + MessagePart::Image { + source: ContentPartSource::Url { + value: "https://example.com/cat.jpg".to_string(), + mime_type: None, + }, + metadata: None, + }, + ], + created_at: None, + }; + + let model_msgs = ui_message_to_model_messages(&ui_msg); + assert_eq!(model_msgs.len(), 1); + assert_eq!(model_msgs[0].role, MessageRole::User); + match &model_msgs[0].content { + MessageContent::Parts(parts) => { + assert_eq!(parts.len(), 2); + assert!(matches!(&parts[0], ContentPart::Text { .. })); + assert!(matches!(&parts[1], ContentPart::Image { .. })); + } + _ => panic!("Expected Parts content"), + } +} + +#[test] +fn test_ui_to_model_multimodal_audio() { + let ui_msg = UiMessage { + id: "msg-1".to_string(), + role: UiMessageRole::User, + parts: vec![ + MessagePart::Text { + content: "Transcribe this".to_string(), + metadata: None, + }, + MessagePart::Audio { + source: ContentPartSource::Data { + value: "base64audio".to_string(), + mime_type: "audio/mp3".to_string(), + }, + metadata: None, + }, + ], + created_at: None, + }; + + let model_msgs = ui_message_to_model_messages(&ui_msg); + match &model_msgs[0].content { + MessageContent::Parts(parts) => { + assert_eq!(parts.len(), 2); + assert!(matches!(&parts[1], ContentPart::Audio { .. 
})); + } + _ => panic!("Expected Parts content"), + } +} + +#[test] +fn test_model_to_ui_round_trip_text() { + let model_msg = ModelMessage { + role: MessageRole::Assistant, + content: MessageContent::Text("Hi there!".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }; + + let ui_msg = model_message_to_ui_message(&model_msg); + assert_eq!(ui_msg.role, UiMessageRole::Assistant); + assert_eq!(ui_msg.parts.len(), 1); + match &ui_msg.parts[0] { + MessagePart::Text { content, .. } => assert_eq!(content, "Hi there!"), + _ => panic!("Expected Text part"), + } + + // Round-trip back + let model_msgs = ui_message_to_model_messages(&ui_msg); + assert_eq!(model_msgs.len(), 1); + assert_eq!(model_msgs[0].content.as_str(), Some("Hi there!")); +} + +#[test] +fn test_model_to_ui_with_tool_calls() { + let model_msg = ModelMessage { + role: MessageRole::Assistant, + content: MessageContent::Null, + name: None, + tool_calls: Some(vec![ToolCall { + id: "call_1".to_string(), + call_type: "function".to_string(), + function: ToolCallFunction { + name: "getWeather".to_string(), + arguments: "{\"city\":\"NYC\"}".to_string(), + }, + provider_metadata: None, + }]), + tool_call_id: None, + }; + + let ui_msg = model_message_to_ui_message(&model_msg); + assert_eq!(ui_msg.parts.len(), 1); + match &ui_msg.parts[0] { + MessagePart::ToolCall { id, name, .. 
} => { + assert_eq!(id, "call_1"); + assert_eq!(name, "getWeather"); + } + _ => panic!("Expected ToolCall part"), + } +} + +#[test] +fn test_model_messages_to_ui_messages_batch() { + let messages = vec![ + ModelMessage { + role: MessageRole::User, + content: MessageContent::Text("Hello".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }, + ModelMessage { + role: MessageRole::Assistant, + content: MessageContent::Text("Hi!".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }, + ]; + + let ui_messages = model_messages_to_ui_messages(&messages); + assert_eq!(ui_messages.len(), 2); + assert_eq!(ui_messages[0].role, UiMessageRole::User); + assert_eq!(ui_messages[1].role, UiMessageRole::Assistant); +} + +// ============================================================================ +// JSON Parser Tests +// ============================================================================ + +#[test] +fn test_parse_complete_json() { + let result = tanstack_ai::stream::json_parser::parse_partial_json(r#"{"name": "test"}"#); + assert!(result.is_some()); + assert_eq!(result.unwrap()["name"], "test"); +} + +#[test] +fn test_parse_partial_json_object() { + let result = tanstack_ai::stream::json_parser::parse_partial_json(r#"{"name": "te"#); + assert!(result.is_some()); + assert_eq!(result.unwrap()["name"], "te"); +} + +#[test] +fn test_parse_partial_json_array() { + let result = tanstack_ai::stream::json_parser::parse_partial_json(r#"[1, 2, 3"#); + assert!(result.is_some()); + assert_eq!(result.unwrap(), serde_json::json!([1, 2, 3])); +} + +#[test] +fn test_parse_empty_string() { + assert!(tanstack_ai::stream::json_parser::parse_partial_json("").is_none()); + assert!(tanstack_ai::stream::json_parser::parse_partial_json(" ").is_none()); +} + +#[test] +fn test_parse_nested_partial() { + let result = tanstack_ai::stream::json_parser::parse_partial_json(r#"{"user": {"name": "Jo"#); + assert!(result.is_some()); + 
assert_eq!(result.unwrap()["user"]["name"], "Jo"); +} + +// ============================================================================ +// Stream Chunk Serialization Tests +// ============================================================================ + +#[test] +fn test_stream_chunk_serialization() { + let chunk = StreamChunk::TextMessageContent { + timestamp: 1234567890.0, + message_id: "msg-123".to_string(), + delta: "Hello".to_string(), + content: None, + model: Some("gpt-4o".to_string()), + }; + + let json = serde_json::to_string(&chunk).unwrap(); + assert!(json.contains("\"type\":\"TEXT_MESSAGE_CONTENT\"")); + assert!(json.contains("\"delta\":\"Hello\"")); + + let deserialized: StreamChunk = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.event_type(), AguiEventType::TextMessageContent); +} + +#[test] +fn test_stream_chunk_run_started() { + let chunk = StreamChunk::RunStarted { + timestamp: 1234567890.0, + run_id: "run-1".to_string(), + thread_id: None, + model: Some("gpt-4o".to_string()), + }; + + let json = serde_json::to_string(&chunk).unwrap(); + assert!(json.contains("\"type\":\"RUN_STARTED\"")); + assert!(json.contains("\"runId\":\"run-1\"")); + + let deserialized: StreamChunk = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.event_type(), AguiEventType::RunStarted); +} + +#[test] +fn test_stream_chunk_tool_call_events() { + let start = StreamChunk::ToolCallStart { + timestamp: 0.0, + tool_call_id: "call_1".to_string(), + tool_name: "getWeather".to_string(), + parent_message_id: None, + index: Some(0), + provider_metadata: None, + model: None, + }; + assert_eq!(start.event_type(), AguiEventType::ToolCallStart); + + let args = StreamChunk::ToolCallArgs { + timestamp: 0.0, + tool_call_id: "call_1".to_string(), + delta: "{\"city\":".to_string(), + args: None, + model: None, + }; + assert_eq!(args.event_type(), AguiEventType::ToolCallArgs); + + let end = StreamChunk::ToolCallEnd { + timestamp: 0.0, + tool_call_id: 
"call_1".to_string(), + tool_name: "getWeather".to_string(), + input: Some(serde_json::json!({"city": "NYC"})), + result: Some("{\"temp\":72}".to_string()), + model: None, + }; + assert_eq!(end.event_type(), AguiEventType::ToolCallEnd); +} + +#[test] +fn test_stream_chunk_run_finished_with_usage() { + let chunk = StreamChunk::RunFinished { + timestamp: 0.0, + run_id: "run-1".to_string(), + finish_reason: Some("stop".to_string()), + usage: Some(Usage { + prompt_tokens: 10, + completion_tokens: 20, + total_tokens: 30, + }), + model: Some("gpt-4o".to_string()), + }; + + let json = serde_json::to_string(&chunk).unwrap(); + assert!(json.contains("\"finishReason\":\"stop\"")); + assert!(json.contains("\"promptTokens\":10")); + + let deserialized: StreamChunk = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.event_type(), AguiEventType::RunFinished); +} + +#[test] +fn test_model_message_serialization() { + let msg = ModelMessage { + role: MessageRole::User, + content: MessageContent::Text("Hello!".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }; + + let json = serde_json::to_string(&msg).unwrap(); + let deserialized: ModelMessage = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.role, MessageRole::User); + assert_eq!(deserialized.content.as_str(), Some("Hello!")); +} + +#[test] +fn test_model_message_with_tool_calls() { + let msg = ModelMessage { + role: MessageRole::Assistant, + content: MessageContent::Null, + name: None, + tool_calls: Some(vec![ToolCall { + id: "call_1".to_string(), + call_type: "function".to_string(), + function: ToolCallFunction { + name: "getWeather".to_string(), + arguments: "{\"city\":\"NYC\"}".to_string(), + }, + provider_metadata: None, + }]), + tool_call_id: None, + }; + + let json = serde_json::to_string(&msg).unwrap(); + let deserialized: ModelMessage = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.role, MessageRole::Assistant); + 
assert!(deserialized.tool_calls.is_some()); + assert_eq!(deserialized.tool_calls.as_ref().unwrap()[0].function.name, "getWeather"); +} + +#[test] +fn test_detect_image_mime_type() { + let png_header = &[0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A]; + assert_eq!(tanstack_ai::detect_image_mime_type(png_header), "image/png"); + + let jpeg_header = &[0xFF, 0xD8, 0xFF, 0xE0]; + assert_eq!(tanstack_ai::detect_image_mime_type(jpeg_header), "image/jpeg"); +} From 9c7d4d75e3d437d7ee5db9558f6319b2679b2b88 Mon Sep 17 00:00:00 2001 From: shane Date: Thu, 2 Apr 2026 16:08:22 -0500 Subject: [PATCH 2/3] refactor(rust): move crate under crates and add release guide --- .../rust => crates}/tanstack-ai/Cargo.lock | 0 .../rust => crates}/tanstack-ai/Cargo.toml | 0 crates/tanstack-ai/release.md | 49 +++++++++++++++++++ .../tanstack-ai/src/adapter.rs | 0 .../tanstack-ai/src/adapters/anthropic.rs | 0 .../tanstack-ai/src/adapters/gemini.rs | 0 .../tanstack-ai/src/adapters/mod.rs | 0 .../tanstack-ai/src/adapters/openai.rs | 0 .../rust => crates}/tanstack-ai/src/chat.rs | 0 .../src/client/connection_adapters.rs | 0 .../tanstack-ai/src/client/mod.rs | 0 .../rust => crates}/tanstack-ai/src/error.rs | 0 .../rust => crates}/tanstack-ai/src/lib.rs | 0 .../tanstack-ai/src/messages.rs | 0 .../tanstack-ai/src/middleware.rs | 0 .../tanstack-ai/src/stream/json_parser.rs | 0 .../tanstack-ai/src/stream/mod.rs | 0 .../tanstack-ai/src/stream/processor.rs | 0 .../tanstack-ai/src/stream/strategies.rs | 0 .../tanstack-ai/src/stream/types.rs | 0 .../tanstack-ai/src/stream_response.rs | 0 .../tanstack-ai/src/tools/definition.rs | 0 .../tanstack-ai/src/tools/mod.rs | 0 .../tanstack-ai/src/tools/registry.rs | 0 .../tanstack-ai/src/tools/tool_calls.rs | 0 .../rust => crates}/tanstack-ai/src/types.rs | 0 .../tanstack-ai/tests/chat_tests.rs | 0 .../tests/stream_processor_tests.rs | 0 .../tanstack-ai/tests/test_utils.rs | 0 .../tests/tool_call_manager_tests.rs | 0 .../tanstack-ai/tests/unit_tests.rs | 0 
examples/rs-chat/Cargo.toml | 2 +- 32 files changed, 50 insertions(+), 1 deletion(-) rename {packages/rust => crates}/tanstack-ai/Cargo.lock (100%) rename {packages/rust => crates}/tanstack-ai/Cargo.toml (100%) create mode 100644 crates/tanstack-ai/release.md rename {packages/rust => crates}/tanstack-ai/src/adapter.rs (100%) rename {packages/rust => crates}/tanstack-ai/src/adapters/anthropic.rs (100%) rename {packages/rust => crates}/tanstack-ai/src/adapters/gemini.rs (100%) rename {packages/rust => crates}/tanstack-ai/src/adapters/mod.rs (100%) rename {packages/rust => crates}/tanstack-ai/src/adapters/openai.rs (100%) rename {packages/rust => crates}/tanstack-ai/src/chat.rs (100%) rename {packages/rust => crates}/tanstack-ai/src/client/connection_adapters.rs (100%) rename {packages/rust => crates}/tanstack-ai/src/client/mod.rs (100%) rename {packages/rust => crates}/tanstack-ai/src/error.rs (100%) rename {packages/rust => crates}/tanstack-ai/src/lib.rs (100%) rename {packages/rust => crates}/tanstack-ai/src/messages.rs (100%) rename {packages/rust => crates}/tanstack-ai/src/middleware.rs (100%) rename {packages/rust => crates}/tanstack-ai/src/stream/json_parser.rs (100%) rename {packages/rust => crates}/tanstack-ai/src/stream/mod.rs (100%) rename {packages/rust => crates}/tanstack-ai/src/stream/processor.rs (100%) rename {packages/rust => crates}/tanstack-ai/src/stream/strategies.rs (100%) rename {packages/rust => crates}/tanstack-ai/src/stream/types.rs (100%) rename {packages/rust => crates}/tanstack-ai/src/stream_response.rs (100%) rename {packages/rust => crates}/tanstack-ai/src/tools/definition.rs (100%) rename {packages/rust => crates}/tanstack-ai/src/tools/mod.rs (100%) rename {packages/rust => crates}/tanstack-ai/src/tools/registry.rs (100%) rename {packages/rust => crates}/tanstack-ai/src/tools/tool_calls.rs (100%) rename {packages/rust => crates}/tanstack-ai/src/types.rs (100%) rename {packages/rust => crates}/tanstack-ai/tests/chat_tests.rs (100%) rename 
{packages/rust => crates}/tanstack-ai/tests/stream_processor_tests.rs (100%) rename {packages/rust => crates}/tanstack-ai/tests/test_utils.rs (100%) rename {packages/rust => crates}/tanstack-ai/tests/tool_call_manager_tests.rs (100%) rename {packages/rust => crates}/tanstack-ai/tests/unit_tests.rs (100%) diff --git a/packages/rust/tanstack-ai/Cargo.lock b/crates/tanstack-ai/Cargo.lock similarity index 100% rename from packages/rust/tanstack-ai/Cargo.lock rename to crates/tanstack-ai/Cargo.lock diff --git a/packages/rust/tanstack-ai/Cargo.toml b/crates/tanstack-ai/Cargo.toml similarity index 100% rename from packages/rust/tanstack-ai/Cargo.toml rename to crates/tanstack-ai/Cargo.toml diff --git a/crates/tanstack-ai/release.md b/crates/tanstack-ai/release.md new file mode 100644 index 00000000..81371341 --- /dev/null +++ b/crates/tanstack-ai/release.md @@ -0,0 +1,49 @@ +# tanstack-ai release process + +This document describes how to publish the Rust crate to crates.io. + +## Prerequisites + +- crates.io owner access for `tanstack-ai` +- `cargo` authenticated (`cargo login `), or `CARGO_REGISTRY_TOKEN` set +- all intended changes merged to `main` + +## 1) Bump version + +Update `version` in `crates/tanstack-ai/Cargo.toml`. + +## 2) Validate locally + +From repo root: + +```bash +cargo test --manifest-path crates/tanstack-ai/Cargo.toml +cargo package --manifest-path crates/tanstack-ai/Cargo.toml +``` + +`cargo package` verifies the crate can be built from the published archive. + +## 3) Publish + +From repo root: + +```bash +cargo publish --manifest-path crates/tanstack-ai/Cargo.toml +``` + +If needed, publish with an explicit token: + +```bash +CARGO_REGISTRY_TOKEN= cargo publish --manifest-path crates/tanstack-ai/Cargo.toml +``` + +## 4) Tag and notes + +Create a git tag after publish, for example: + +```bash +git tag rust/tanstack-ai-v0.1.0 +git push origin rust/tanstack-ai-v0.1.0 +``` + +Then document highlights in the GitHub release notes. 
diff --git a/packages/rust/tanstack-ai/src/adapter.rs b/crates/tanstack-ai/src/adapter.rs similarity index 100% rename from packages/rust/tanstack-ai/src/adapter.rs rename to crates/tanstack-ai/src/adapter.rs diff --git a/packages/rust/tanstack-ai/src/adapters/anthropic.rs b/crates/tanstack-ai/src/adapters/anthropic.rs similarity index 100% rename from packages/rust/tanstack-ai/src/adapters/anthropic.rs rename to crates/tanstack-ai/src/adapters/anthropic.rs diff --git a/packages/rust/tanstack-ai/src/adapters/gemini.rs b/crates/tanstack-ai/src/adapters/gemini.rs similarity index 100% rename from packages/rust/tanstack-ai/src/adapters/gemini.rs rename to crates/tanstack-ai/src/adapters/gemini.rs diff --git a/packages/rust/tanstack-ai/src/adapters/mod.rs b/crates/tanstack-ai/src/adapters/mod.rs similarity index 100% rename from packages/rust/tanstack-ai/src/adapters/mod.rs rename to crates/tanstack-ai/src/adapters/mod.rs diff --git a/packages/rust/tanstack-ai/src/adapters/openai.rs b/crates/tanstack-ai/src/adapters/openai.rs similarity index 100% rename from packages/rust/tanstack-ai/src/adapters/openai.rs rename to crates/tanstack-ai/src/adapters/openai.rs diff --git a/packages/rust/tanstack-ai/src/chat.rs b/crates/tanstack-ai/src/chat.rs similarity index 100% rename from packages/rust/tanstack-ai/src/chat.rs rename to crates/tanstack-ai/src/chat.rs diff --git a/packages/rust/tanstack-ai/src/client/connection_adapters.rs b/crates/tanstack-ai/src/client/connection_adapters.rs similarity index 100% rename from packages/rust/tanstack-ai/src/client/connection_adapters.rs rename to crates/tanstack-ai/src/client/connection_adapters.rs diff --git a/packages/rust/tanstack-ai/src/client/mod.rs b/crates/tanstack-ai/src/client/mod.rs similarity index 100% rename from packages/rust/tanstack-ai/src/client/mod.rs rename to crates/tanstack-ai/src/client/mod.rs diff --git a/packages/rust/tanstack-ai/src/error.rs b/crates/tanstack-ai/src/error.rs similarity index 100% rename from 
packages/rust/tanstack-ai/src/error.rs rename to crates/tanstack-ai/src/error.rs diff --git a/packages/rust/tanstack-ai/src/lib.rs b/crates/tanstack-ai/src/lib.rs similarity index 100% rename from packages/rust/tanstack-ai/src/lib.rs rename to crates/tanstack-ai/src/lib.rs diff --git a/packages/rust/tanstack-ai/src/messages.rs b/crates/tanstack-ai/src/messages.rs similarity index 100% rename from packages/rust/tanstack-ai/src/messages.rs rename to crates/tanstack-ai/src/messages.rs diff --git a/packages/rust/tanstack-ai/src/middleware.rs b/crates/tanstack-ai/src/middleware.rs similarity index 100% rename from packages/rust/tanstack-ai/src/middleware.rs rename to crates/tanstack-ai/src/middleware.rs diff --git a/packages/rust/tanstack-ai/src/stream/json_parser.rs b/crates/tanstack-ai/src/stream/json_parser.rs similarity index 100% rename from packages/rust/tanstack-ai/src/stream/json_parser.rs rename to crates/tanstack-ai/src/stream/json_parser.rs diff --git a/packages/rust/tanstack-ai/src/stream/mod.rs b/crates/tanstack-ai/src/stream/mod.rs similarity index 100% rename from packages/rust/tanstack-ai/src/stream/mod.rs rename to crates/tanstack-ai/src/stream/mod.rs diff --git a/packages/rust/tanstack-ai/src/stream/processor.rs b/crates/tanstack-ai/src/stream/processor.rs similarity index 100% rename from packages/rust/tanstack-ai/src/stream/processor.rs rename to crates/tanstack-ai/src/stream/processor.rs diff --git a/packages/rust/tanstack-ai/src/stream/strategies.rs b/crates/tanstack-ai/src/stream/strategies.rs similarity index 100% rename from packages/rust/tanstack-ai/src/stream/strategies.rs rename to crates/tanstack-ai/src/stream/strategies.rs diff --git a/packages/rust/tanstack-ai/src/stream/types.rs b/crates/tanstack-ai/src/stream/types.rs similarity index 100% rename from packages/rust/tanstack-ai/src/stream/types.rs rename to crates/tanstack-ai/src/stream/types.rs diff --git a/packages/rust/tanstack-ai/src/stream_response.rs 
b/crates/tanstack-ai/src/stream_response.rs similarity index 100% rename from packages/rust/tanstack-ai/src/stream_response.rs rename to crates/tanstack-ai/src/stream_response.rs diff --git a/packages/rust/tanstack-ai/src/tools/definition.rs b/crates/tanstack-ai/src/tools/definition.rs similarity index 100% rename from packages/rust/tanstack-ai/src/tools/definition.rs rename to crates/tanstack-ai/src/tools/definition.rs diff --git a/packages/rust/tanstack-ai/src/tools/mod.rs b/crates/tanstack-ai/src/tools/mod.rs similarity index 100% rename from packages/rust/tanstack-ai/src/tools/mod.rs rename to crates/tanstack-ai/src/tools/mod.rs diff --git a/packages/rust/tanstack-ai/src/tools/registry.rs b/crates/tanstack-ai/src/tools/registry.rs similarity index 100% rename from packages/rust/tanstack-ai/src/tools/registry.rs rename to crates/tanstack-ai/src/tools/registry.rs diff --git a/packages/rust/tanstack-ai/src/tools/tool_calls.rs b/crates/tanstack-ai/src/tools/tool_calls.rs similarity index 100% rename from packages/rust/tanstack-ai/src/tools/tool_calls.rs rename to crates/tanstack-ai/src/tools/tool_calls.rs diff --git a/packages/rust/tanstack-ai/src/types.rs b/crates/tanstack-ai/src/types.rs similarity index 100% rename from packages/rust/tanstack-ai/src/types.rs rename to crates/tanstack-ai/src/types.rs diff --git a/packages/rust/tanstack-ai/tests/chat_tests.rs b/crates/tanstack-ai/tests/chat_tests.rs similarity index 100% rename from packages/rust/tanstack-ai/tests/chat_tests.rs rename to crates/tanstack-ai/tests/chat_tests.rs diff --git a/packages/rust/tanstack-ai/tests/stream_processor_tests.rs b/crates/tanstack-ai/tests/stream_processor_tests.rs similarity index 100% rename from packages/rust/tanstack-ai/tests/stream_processor_tests.rs rename to crates/tanstack-ai/tests/stream_processor_tests.rs diff --git a/packages/rust/tanstack-ai/tests/test_utils.rs b/crates/tanstack-ai/tests/test_utils.rs similarity index 100% rename from 
packages/rust/tanstack-ai/tests/test_utils.rs rename to crates/tanstack-ai/tests/test_utils.rs diff --git a/packages/rust/tanstack-ai/tests/tool_call_manager_tests.rs b/crates/tanstack-ai/tests/tool_call_manager_tests.rs similarity index 100% rename from packages/rust/tanstack-ai/tests/tool_call_manager_tests.rs rename to crates/tanstack-ai/tests/tool_call_manager_tests.rs diff --git a/packages/rust/tanstack-ai/tests/unit_tests.rs b/crates/tanstack-ai/tests/unit_tests.rs similarity index 100% rename from packages/rust/tanstack-ai/tests/unit_tests.rs rename to crates/tanstack-ai/tests/unit_tests.rs diff --git a/examples/rs-chat/Cargo.toml b/examples/rs-chat/Cargo.toml index c86c3788..7677ebac 100644 --- a/examples/rs-chat/Cargo.toml +++ b/examples/rs-chat/Cargo.toml @@ -4,6 +4,6 @@ version = "0.1.0" edition = "2021" [dependencies] -tanstack-ai = { path = "../../packages/rust/tanstack-ai" } +tanstack-ai = { path = "../../crates/tanstack-ai" } tokio = { version = "1", features = ["full"] } serde_json = "1" From a774d4db6a363905a932403df8f0792f1eda6cb5 Mon Sep 17 00:00:00 2001 From: shane Date: Fri, 3 Apr 2026 12:22:47 -0500 Subject: [PATCH 3/3] fix(rust): address PR feedback across streaming and tools --- crates/tanstack-ai/Cargo.lock | 13 + crates/tanstack-ai/Cargo.toml | 1 + crates/tanstack-ai/src/adapter.rs | 60 +++ crates/tanstack-ai/src/adapters/anthropic.rs | 226 +++++------ crates/tanstack-ai/src/adapters/gemini.rs | 140 ++++--- crates/tanstack-ai/src/adapters/openai.rs | 363 ++++++++++------- crates/tanstack-ai/src/chat.rs | 214 +++++++--- .../src/client/connection_adapters.rs | 373 ++++++++++++------ crates/tanstack-ai/src/lib.rs | 15 +- crates/tanstack-ai/src/messages.rs | 102 +++-- crates/tanstack-ai/src/middleware.rs | 18 + crates/tanstack-ai/src/stream/json_parser.rs | 38 +- crates/tanstack-ai/src/stream/mod.rs | 4 +- crates/tanstack-ai/src/stream/processor.rs | 47 ++- crates/tanstack-ai/src/stream/strategies.rs | 4 +- 
crates/tanstack-ai/src/stream_response.rs | 29 +- crates/tanstack-ai/src/tools/definition.rs | 7 +- crates/tanstack-ai/src/tools/registry.rs | 4 + crates/tanstack-ai/src/tools/tool_calls.rs | 55 +-- crates/tanstack-ai/src/types.rs | 39 +- crates/tanstack-ai/tests/chat_tests.rs | 5 +- .../tests/stream_processor_tests.rs | 24 +- crates/tanstack-ai/tests/test_utils.rs | 14 +- .../tests/tool_call_manager_tests.rs | 2 + crates/tanstack-ai/tests/unit_tests.rs | 65 ++- examples/rs-chat/Cargo.lock | 13 + examples/rs-chat/src/main.rs | 4 +- 27 files changed, 1216 insertions(+), 663 deletions(-) diff --git a/crates/tanstack-ai/Cargo.lock b/crates/tanstack-ai/Cargo.lock index 3d34ec8b..e42df945 100644 --- a/crates/tanstack-ai/Cargo.lock +++ b/crates/tanstack-ai/Cargo.lock @@ -1226,6 +1226,7 @@ dependencies = [ "tokio", "tokio-stream", "tokio-util", + "tracing", "uuid", ] @@ -1396,9 +1397,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ "pin-project-lite", + "tracing-attributes", "tracing-core", ] +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tracing-core" version = "0.1.36" diff --git a/crates/tanstack-ai/Cargo.toml b/crates/tanstack-ai/Cargo.toml index 41ccaf00..4c503b63 100644 --- a/crates/tanstack-ai/Cargo.toml +++ b/crates/tanstack-ai/Cargo.toml @@ -21,6 +21,7 @@ tokio-stream = "0.1" tokio-util = { version = "0.7", features = ["codec"] } bytes = "1" regex = "1" +tracing = "0.1" [dev-dependencies] tokio = { version = "1", features = ["full", "test-util"] } diff --git a/crates/tanstack-ai/src/adapter.rs b/crates/tanstack-ai/src/adapter.rs index b91210da..f24fef6d 100644 --- a/crates/tanstack-ai/src/adapter.rs +++ 
b/crates/tanstack-ai/src/adapter.rs @@ -1,6 +1,11 @@ use async_trait::async_trait; use futures_core::Stream; +use reqwest::{ + header::{HeaderMap, HeaderName, HeaderValue}, + Client, RequestBuilder, Response, +}; use std::pin::Pin; +use std::time::Duration; use crate::error::AiResult; use crate::types::{StreamChunk, StructuredOutputOptions, StructuredOutputResult, TextOptions}; @@ -40,6 +45,61 @@ pub trait TextAdapter: Send + Sync { ) -> AiResult; } +pub(crate) fn build_http_client(config: &TextAdapterConfig) -> Client { + let mut builder = Client::builder(); + + if let Some(timeout_secs) = config.timeout { + builder = builder.timeout(Duration::from_secs(timeout_secs)); + } + + if let Some(headers) = &config.headers { + let mut default_headers = HeaderMap::new(); + for (name, value) in headers { + let header_name = + HeaderName::from_bytes(name.as_bytes()).expect("invalid adapter header name"); + let header_value = HeaderValue::from_str(value).expect("invalid adapter header value"); + default_headers.insert(header_name, header_value); + } + builder = builder.default_headers(default_headers); + } + + builder + .build() + .expect("failed to build adapter http client") +} + +pub(crate) async fn send_with_retries( + request: RequestBuilder, + max_retries: u32, +) -> Result { + for attempt in 0..=max_retries { + let response = request + .try_clone() + .expect("failed to clone request builder for retry") + .send() + .await; + + match response { + Ok(response) + if attempt < max_retries + && (response.status().is_server_error() + || response.status().as_u16() == 429) => + { + continue; + } + Ok(response) => return Ok(response), + Err(error) if attempt < max_retries => continue, + Err(error) => return Err(error), + } + } + + unreachable!("retry loop always returns") +} + +/// Placeholder traits for non-text capabilities. 
+/// +/// The Rust SDK currently ships text adapters first, with these traits reserved +/// for upcoming image, summarization, speech, and transcription adapters. /// Image generation adapter trait. #[async_trait] pub trait ImageAdapter: Send + Sync { diff --git a/crates/tanstack-ai/src/adapters/anthropic.rs b/crates/tanstack-ai/src/adapters/anthropic.rs index c0852391..6c9f2457 100644 --- a/crates/tanstack-ai/src/adapters/anthropic.rs +++ b/crates/tanstack-ai/src/adapters/anthropic.rs @@ -4,7 +4,9 @@ use reqwest::Client; use serde::Deserialize; use std::collections::HashMap; -use crate::adapter::{ChunkStream, TextAdapter, TextAdapterConfig}; +use crate::adapter::{ + build_http_client, send_with_retries, ChunkStream, TextAdapter, TextAdapterConfig, +}; use crate::error::{AiError, AiResult}; use crate::types::*; @@ -16,6 +18,7 @@ pub struct AnthropicTextAdapter { model: String, base_url: String, client: Client, + max_retries: u32, } impl AnthropicTextAdapter { @@ -26,6 +29,7 @@ impl AnthropicTextAdapter { model: model.into(), base_url: "https://api.anthropic.com".to_string(), client: Client::new(), + max_retries: 0, } } @@ -35,18 +39,17 @@ impl AnthropicTextAdapter { api_key: impl Into, config: TextAdapterConfig, ) -> Self { + let api_key = config.api_key.clone().unwrap_or_else(|| api_key.into()); let mut adapter = Self::new(model, api_key); + adapter.client = build_http_client(&config); + adapter.max_retries = config.max_retries.unwrap_or(0); if let Some(base_url) = config.base_url { adapter.base_url = base_url; } adapter } - fn build_request_body( - &self, - options: &TextOptions, - stream: bool, - ) -> serde_json::Value { + fn build_request_body(&self, options: &TextOptions, stream: bool) -> serde_json::Value { let mut body = serde_json::Map::new(); body.insert("model".to_string(), serde_json::json!(self.model)); @@ -66,7 +69,11 @@ impl AnthropicTextAdapter { .map(|msg| { let mut m = serde_json::Map::new(); // Anthropic uses "user" for tool result messages - let 
role_str = if msg.role == MessageRole::Tool { "user" } else { msg.role.as_str() }; + let role_str = if msg.role == MessageRole::Tool { + "user" + } else { + msg.role.as_str() + }; m.insert("role".to_string(), serde_json::json!(role_str)); if msg.role == MessageRole::Tool { @@ -96,29 +103,27 @@ impl AnthropicTextAdapter { ContentPart::Text { content } => { serde_json::json!({"type": "text", "text": content}) } - ContentPart::Image { source } => { - match source { - ContentPartSource::Data { value, mime_type } => { - serde_json::json!({ - "type": "image", - "source": { - "type": "base64", - "media_type": mime_type, - "data": value - } - }) - } - ContentPartSource::Url { value, .. } => { - serde_json::json!({ - "type": "image", - "source": { - "type": "url", - "url": value - } - }) - } + ContentPart::Image { source } => match source { + ContentPartSource::Data { value, mime_type } => { + serde_json::json!({ + "type": "image", + "source": { + "type": "base64", + "media_type": mime_type, + "data": value + } + }) } - } + ContentPartSource::Url { value, .. 
} => { + serde_json::json!({ + "type": "image", + "source": { + "type": "url", + "url": value + } + }) + } + }, _ => serde_json::json!({"type": "text", "text": ""}), }) .collect(); @@ -150,10 +155,7 @@ impl AnthropicTextAdapter { }]), ); } else { - m.insert( - "content".to_string(), - serde_json::json!(""), - ); + m.insert("content".to_string(), serde_json::json!("")); } } } @@ -212,10 +214,7 @@ enum AnthropicEvent { content_block: AnthropicContentBlock, }, #[serde(rename = "content_block_delta")] - ContentBlockDelta { - index: usize, - delta: AnthropicDelta, - }, + ContentBlockDelta { index: usize, delta: AnthropicDelta }, #[serde(rename = "content_block_stop")] ContentBlockStop { index: usize }, #[serde(rename = "message_delta")] @@ -307,15 +306,16 @@ impl TextAdapter for AnthropicTextAdapter { let body = self.build_request_body(options, true); let url = format!("{}/v1/messages", self.base_url); - let response = self - .client - .post(&url) - .header("x-api-key", &self.api_key) - .header("anthropic-version", "2023-06-01") - .header("Content-Type", "application/json") - .json(&body) - .send() - .await?; + let response = send_with_retries( + self.client + .post(&url) + .header("x-api-key", &self.api_key) + .header("anthropic-version", "2023-06-01") + .header("Content-Type", "application/json") + .json(&body), + self.max_retries, + ) + .await?; if !response.status().is_success() { let status = response.status(); @@ -357,15 +357,16 @@ impl TextAdapter for AnthropicTextAdapter { let url = format!("{}/v1/messages", self.base_url); - let response = self - .client - .post(&url) - .header("x-api-key", &self.api_key) - .header("anthropic-version", "2023-06-01") - .header("Content-Type", "application/json") - .json(&body) - .send() - .await?; + let response = send_with_retries( + self.client + .post(&url) + .header("x-api-key", &self.api_key) + .header("anthropic-version", "2023-06-01") + .header("Content-Type", "application/json") + .json(&body), + self.max_retries, + ) + 
.await?; if !response.status().is_success() { let status = response.status(); @@ -475,17 +476,17 @@ where })); } AnthropicDelta::InputJsonDelta { partial_json } => { - let tool_call_id = tool_block_ids - .get(&index) - .cloned() - .unwrap_or_else(|| format!("block-{}", index)); - chunks.push(Ok(StreamChunk::ToolCallArgs { - timestamp: now, - tool_call_id, - delta: partial_json, - args: None, - model: model_str, - })); + if let Some(tool_call_id) = tool_block_ids.get(&index).cloned() { + chunks.push(Ok(StreamChunk::ToolCallArgs { + timestamp: now, + tool_call_id, + delta: partial_json, + args: None, + model: model_str, + })); + } else { + tracing::warn!(index, "missing anthropic tool_use id for input_json_delta"); + } } _ => {} } @@ -498,10 +499,10 @@ where } } Err(e) => { - eprintln!( - "Failed to parse Anthropic event: {} - {}", - e, - &data_str[..data_str.len().min(200)] + tracing::warn!( + error = %e, + data = %&data_str[..data_str.len().min(200)], + "failed to parse anthropic event" ); } } @@ -574,17 +575,17 @@ where })); } AnthropicDelta::InputJsonDelta { partial_json } => { - let tool_call_id = tool_block_ids - .get(&index) - .cloned() - .unwrap_or_else(|| format!("block-{}", index)); - chunks.push(Ok(StreamChunk::ToolCallArgs { - timestamp: now, - tool_call_id, - delta: partial_json, - args: None, - model: model_str, - })); + if let Some(tool_call_id) = tool_block_ids.get(&index).cloned() { + chunks.push(Ok(StreamChunk::ToolCallArgs { + timestamp: now, + tool_call_id, + delta: partial_json, + args: None, + model: model_str, + })); + } else { + tracing::warn!(index, "missing anthropic tool_use id for input_json_delta"); + } } _ => {} } @@ -602,10 +603,10 @@ where )); } Err(e) => { - eprintln!( - "Failed to parse Anthropic event: {} - {}", - e, - &data_str[..data_str.len().min(200)] + tracing::warn!( + error = %e, + data = %&data_str[..data_str.len().min(200)], + "failed to parse anthropic event" ); } } @@ -627,39 +628,14 @@ fn convert_anthropic_event(event: 
AnthropicEvent, model: &str) -> Vec { - vec![ - StreamChunk::RunStarted { - timestamp: now, - run_id: message.id, - thread_id: None, - model: model_str, - }, - ] + vec![StreamChunk::RunStarted { + timestamp: now, + run_id: message.id, + thread_id: None, + model: model_str, + }] } - AnthropicEvent::ContentBlockStart { index, content_block } => match content_block { - AnthropicContentBlock::Text { .. } => { - vec![StreamChunk::TextMessageStart { - timestamp: now, - message_id: format!("block-{}", index), - role: "assistant".to_string(), - model: model_str, - }] - } - AnthropicContentBlock::ToolUse { id, name, .. } => { - vec![StreamChunk::ToolCallStart { - timestamp: now, - tool_call_id: id, - tool_name: name, - parent_message_id: None, - index: Some(index), - provider_metadata: None, - model: model_str, - }] - } - _ => vec![], - }, - AnthropicEvent::ContentBlockDelta { index, delta } => match delta { AnthropicDelta::TextDelta { text } => { vec![StreamChunk::TextMessageContent { @@ -670,15 +646,7 @@ fn convert_anthropic_event(event: AnthropicEvent, model: &str) -> Vec { - vec![StreamChunk::ToolCallArgs { - timestamp: now, - tool_call_id: format!("block-{}", index), - delta: partial_json, - args: None, - model: model_str, - }] - } + AnthropicDelta::InputJsonDelta { .. } => vec![], _ => vec![], }, @@ -767,12 +735,16 @@ mod tests { assert_eq!(chunks.len(), 2); match &chunks[0] { - Ok(StreamChunk::ToolCallStart { tool_call_id, .. }) => assert_eq!(tool_call_id, "toolu_123"), + Ok(StreamChunk::ToolCallStart { tool_call_id, .. }) => { + assert_eq!(tool_call_id, "toolu_123") + } other => panic!("unexpected chunk 0: {other:?}"), } match &chunks[1] { - Ok(StreamChunk::ToolCallArgs { tool_call_id, .. }) => assert_eq!(tool_call_id, "toolu_123"), + Ok(StreamChunk::ToolCallArgs { tool_call_id, .. 
}) => { + assert_eq!(tool_call_id, "toolu_123") + } other => panic!("unexpected chunk 1: {other:?}"), } } diff --git a/crates/tanstack-ai/src/adapters/gemini.rs b/crates/tanstack-ai/src/adapters/gemini.rs index 02ebfb97..15435f25 100644 --- a/crates/tanstack-ai/src/adapters/gemini.rs +++ b/crates/tanstack-ai/src/adapters/gemini.rs @@ -3,7 +3,9 @@ use futures_core::Stream; use reqwest::Client; use serde::Deserialize; -use crate::adapter::{ChunkStream, TextAdapter, TextAdapterConfig}; +use crate::adapter::{ + build_http_client, send_with_retries, ChunkStream, TextAdapter, TextAdapterConfig, +}; use crate::error::{AiError, AiResult}; use crate::types::*; @@ -15,6 +17,7 @@ pub struct GeminiTextAdapter { model: String, base_url: String, client: Client, + max_retries: u32, } impl GeminiTextAdapter { @@ -25,6 +28,7 @@ impl GeminiTextAdapter { model: model.into(), base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(), client: Client::new(), + max_retries: 0, } } @@ -34,17 +38,17 @@ impl GeminiTextAdapter { api_key: impl Into, config: TextAdapterConfig, ) -> Self { + let api_key = config.api_key.clone().unwrap_or_else(|| api_key.into()); let mut adapter = Self::new(model, api_key); + adapter.client = build_http_client(&config); + adapter.max_retries = config.max_retries.unwrap_or(0); if let Some(base_url) = config.base_url { adapter.base_url = base_url; } adapter } - fn build_request_body( - &self, - options: &TextOptions, - ) -> serde_json::Value { + fn build_request_body(&self, options: &TextOptions) -> serde_json::Value { let mut body = serde_json::Map::new(); // Contents (messages) @@ -72,6 +76,8 @@ impl GeminiTextAdapter { MessageContent::Text(text) => text.clone(), MessageContent::Null | MessageContent::Parts(_) => String::new(), }; + // Gemini correlates tool results by function name rather than a separate call ID. + // `tool_call_id` is expected to carry that name for tool result messages here. 
contents.push(serde_json::json!({ "role": "user", "parts": [{ @@ -96,29 +102,28 @@ impl GeminiTextAdapter { MessageContent::Text(text) => { vec![serde_json::json!({"text": text})] } - MessageContent::Parts(ps) => { - ps.iter() - .map(|part| match part { - ContentPart::Text { content } => { - serde_json::json!({"text": content}) + MessageContent::Parts(ps) => ps + .iter() + .map(|part| match part { + ContentPart::Text { content } => { + serde_json::json!({"text": content}) + } + ContentPart::Image { source } => match source { + ContentPartSource::Data { value, mime_type } => { + serde_json::json!({ + "inlineData": { + "mimeType": mime_type, + "data": value + } + }) } - ContentPart::Image { source } => match source { - ContentPartSource::Data { value, mime_type } => { - serde_json::json!({ - "inlineData": { - "mimeType": mime_type, - "data": value - } - }) - } - ContentPartSource::Url { value, .. } => { - serde_json::json!({"text": format!("[Image: {}]", value)}) - } - }, - _ => serde_json::json!({"text": ""}), - }) - .collect() - } + ContentPartSource::Url { value, .. 
} => { + serde_json::json!({"text": format!("[Image: {}]", value)}) + } + }, + _ => serde_json::json!({"text": ""}), + }) + .collect(), MessageContent::Null => { vec![serde_json::json!({"text": ""})] } @@ -206,9 +211,15 @@ struct GeminiContent { #[derive(Debug, Deserialize)] #[serde(untagged)] enum GeminiPart { - Text { text: String }, - FunctionCall { function_call: GeminiFunctionCall }, - FunctionResponse { function_response: GeminiFunctionResponse }, + Text { + text: String, + }, + FunctionCall { + function_call: GeminiFunctionCall, + }, + FunctionResponse { + function_response: GeminiFunctionResponse, + }, } #[allow(dead_code)] @@ -250,13 +261,14 @@ impl TextAdapter for GeminiTextAdapter { self.base_url, self.model, self.api_key ); - let response = self - .client - .post(&url) - .header("Content-Type", "application/json") - .json(&body) - .send() - .await?; + let response = send_with_retries( + self.client + .post(&url) + .header("Content-Type", "application/json") + .json(&body), + self.max_retries, + ) + .await?; if !response.status().is_success() { let status = response.status(); @@ -282,13 +294,20 @@ impl TextAdapter for GeminiTextAdapter { // Add response schema if let Some(obj) = body.as_object_mut() { - obj.insert( - "generationConfig".to_string(), - serde_json::json!({ - "response_mime_type": "application/json", - "response_schema": options.output_schema - }), - ); + let generation_config = obj + .entry("generationConfig".to_string()) + .or_insert_with(|| serde_json::Value::Object(serde_json::Map::new())); + + if let Some(config) = generation_config.as_object_mut() { + config.insert( + "response_mime_type".to_string(), + serde_json::json!("application/json"), + ); + config.insert( + "response_schema".to_string(), + serde_json::json!(options.output_schema), + ); + } } let url = format!( @@ -296,13 +315,14 @@ impl TextAdapter for GeminiTextAdapter { self.base_url, self.model, self.api_key ); - let response = self - .client - .post(&url) - 
.header("Content-Type", "application/json") - .json(&body) - .send() - .await?; + let response = send_with_retries( + self.client + .post(&url) + .header("Content-Type", "application/json") + .json(&body), + self.max_retries, + ) + .await?; if !response.status().is_success() { let status = response.status(); @@ -336,10 +356,7 @@ impl TextAdapter for GeminiTextAdapter { } } -fn parse_gemini_sse_stream( - stream: S, - model: String, -) -> impl Stream> +fn parse_gemini_sse_stream(stream: S, model: String) -> impl Stream> where S: Stream> + Send, { @@ -377,10 +394,10 @@ where } } Err(e) => { - eprintln!( - "Failed to parse Gemini event: {} - {}", - e, - &data_str[..data_str.len().min(200)] + tracing::warn!( + error = %e, + data = %&data_str[..data_str.len().min(200)], + "failed to parse gemini event" ); } } @@ -390,7 +407,10 @@ where line_buf = line_buf[processed_to..].to_string(); if !chunks.is_empty() { - return Some((futures_util::stream::iter(chunks), (stream, line_buf))); + return Some(( + futures_util::stream::iter(chunks), + (stream, line_buf), + )); } } Some(Err(e)) => { diff --git a/crates/tanstack-ai/src/adapters/openai.rs b/crates/tanstack-ai/src/adapters/openai.rs index 85e1095d..5e4154b9 100644 --- a/crates/tanstack-ai/src/adapters/openai.rs +++ b/crates/tanstack-ai/src/adapters/openai.rs @@ -4,7 +4,9 @@ use reqwest::Client; use serde::Deserialize; use std::collections::HashMap; -use crate::adapter::{ChunkStream, TextAdapter, TextAdapterConfig}; +use crate::adapter::{ + build_http_client, send_with_retries, ChunkStream, TextAdapter, TextAdapterConfig, +}; use crate::error::{AiError, AiResult}; use crate::types::*; @@ -16,6 +18,7 @@ pub struct OpenAiTextAdapter { model: String, base_url: String, client: Client, + max_retries: u32, } impl OpenAiTextAdapter { @@ -26,6 +29,7 @@ impl OpenAiTextAdapter { model: model.into(), base_url: "https://api.openai.com/v1".to_string(), client: Client::new(), + max_retries: 0, } } @@ -35,19 +39,34 @@ impl OpenAiTextAdapter 
{ api_key: impl Into, config: TextAdapterConfig, ) -> Self { + let api_key = config.api_key.clone().unwrap_or_else(|| api_key.into()); let mut adapter = Self::new(model, api_key); + adapter.client = build_http_client(&config); + adapter.max_retries = config.max_retries.unwrap_or(0); if let Some(base_url) = config.base_url { adapter.base_url = base_url; } adapter } + fn tool_calls_json(tool_calls: &[ToolCall]) -> Vec { + tool_calls + .iter() + .map(|tc| { + serde_json::json!({ + "id": tc.id, + "type": "function", + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments + } + }) + }) + .collect() + } + /// Build the request body for the OpenAI Responses API. - fn build_request_body( - &self, - options: &TextOptions, - stream: bool, - ) -> serde_json::Value { + fn build_request_body(&self, options: &TextOptions, stream: bool) -> serde_json::Value { let mut body = serde_json::Map::new(); body.insert("model".to_string(), serde_json::json!(self.model)); @@ -87,7 +106,10 @@ impl OpenAiTextAdapter { body.insert("top_p".to_string(), serde_json::json!(top_p)); } if let Some(max_tokens) = options.max_tokens { - body.insert("max_output_tokens".to_string(), serde_json::json!(max_tokens)); + body.insert( + "max_output_tokens".to_string(), + serde_json::json!(max_tokens), + ); } serde_json::Value::Object(body) @@ -126,10 +148,16 @@ impl OpenAiTextAdapter { match &msg.content { MessageContent::Text(text) => { - items.push(serde_json::json!({ - "role": msg.role.as_str(), - "content": text - })); + let mut item = serde_json::Map::new(); + item.insert("role".to_string(), serde_json::json!(msg.role.as_str())); + item.insert("content".to_string(), serde_json::json!(text)); + if let Some(tool_calls) = &msg.tool_calls { + item.insert( + "tool_calls".to_string(), + serde_json::json!(Self::tool_calls_json(tool_calls)), + ); + } + items.push(serde_json::Value::Object(item)); } MessageContent::Parts(parts) => { let mut content = Vec::new(); @@ -144,7 +172,9 @@ impl 
OpenAiTextAdapter { ContentPart::Image { source } => { let url = match source { ContentPartSource::Url { value, .. } => value.clone(), - ContentPartSource::Data { value, mime_type, .. } => { + ContentPartSource::Data { + value, mime_type, .. + } => { format!("data:{};base64,{}", mime_type, value) } }; @@ -156,26 +186,23 @@ impl OpenAiTextAdapter { _ => {} } } - items.push(serde_json::json!({ - "role": msg.role.as_str(), - "content": content - })); + let mut item = serde_json::Map::new(); + item.insert("role".to_string(), serde_json::json!(msg.role.as_str())); + item.insert("content".to_string(), serde_json::json!(content)); + if let Some(tool_calls) = &msg.tool_calls { + item.insert( + "tool_calls".to_string(), + serde_json::json!(Self::tool_calls_json(tool_calls)), + ); + } + items.push(serde_json::Value::Object(item)); } MessageContent::Null => { // Handle tool calls in assistant messages if let Some(tool_calls) = &msg.tool_calls { items.push(serde_json::json!({ "role": "assistant", - "tool_calls": tool_calls.iter().map(|tc| { - serde_json::json!({ - "id": tc.id, - "type": "function", - "function": { - "name": tc.function.name, - "arguments": tc.function.arguments - } - }) - }).collect::>() + "tool_calls": Self::tool_calls_json(tool_calls) })); } } @@ -325,14 +352,15 @@ impl TextAdapter for OpenAiTextAdapter { let body = self.build_request_body(options, true); let url = format!("{}/responses", self.base_url); - let response = self - .client - .post(&url) - .header("Authorization", format!("Bearer {}", self.api_key)) - .header("Content-Type", "application/json") - .json(&body) - .send() - .await?; + let response = send_with_retries( + self.client + .post(&url) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .json(&body), + self.max_retries, + ) + .await?; if !response.status().is_success() { let status = response.status(); @@ -375,14 +403,15 @@ impl TextAdapter for OpenAiTextAdapter { let url = 
format!("{}/responses", self.base_url); - let response = self - .client - .post(&url) - .header("Authorization", format!("Bearer {}", self.api_key)) - .header("Content-Type", "application/json") - .json(&body) - .send() - .await?; + let response = send_with_retries( + self.client + .post(&url) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .json(&body), + self.max_retries, + ) + .await?; if !response.status().is_success() { let status = response.status(); @@ -423,8 +452,15 @@ where String::new(), HashMap::::new(), HashMap::::new(), + false, ), - move |(mut stream, mut line_buf, mut item_to_call, mut item_to_tool)| { + move |( + mut stream, + mut line_buf, + mut item_to_call, + mut item_to_tool, + mut has_tool_calls, + )| { let model = model.clone(); async move { loop { @@ -453,18 +489,25 @@ where match serde_json::from_str::(data_str) { Ok(event) => { - let now = chrono::Utc::now().timestamp_millis() as f64 / 1000.0; + let now = chrono::Utc::now().timestamp_millis() as f64 + / 1000.0; let model_str = Some(model.clone()); match event { - OpenAiStreamEvent::OutputItemAdded { item, .. } => match item { - OpenAiOutputItem::Message { id, role, .. } => { - chunks.push(Ok(StreamChunk::TextMessageStart { - timestamp: now, - message_id: id, - role, - model: model_str, - })); + OpenAiStreamEvent::OutputItemAdded { + item, .. + } => match item { + OpenAiOutputItem::Message { + id, role, .. + } => { + chunks.push(Ok( + StreamChunk::TextMessageStart { + timestamp: now, + message_id: id, + role, + model: model_str, + }, + )); } OpenAiOutputItem::FunctionCall { id, @@ -472,17 +515,22 @@ where call_id, .. 
} => { - item_to_call.insert(id.clone(), call_id.clone()); - item_to_tool.insert(id.clone(), name.clone()); - chunks.push(Ok(StreamChunk::ToolCallStart { - timestamp: now, - tool_call_id: call_id, - tool_name: name, - parent_message_id: Some(id), - index: None, - provider_metadata: None, - model: model_str, - })); + has_tool_calls = true; + item_to_call + .insert(id.clone(), call_id.clone()); + item_to_tool + .insert(id.clone(), name.clone()); + chunks.push(Ok( + StreamChunk::ToolCallStart { + timestamp: now, + tool_call_id: call_id, + tool_name: name, + parent_message_id: Some(id), + index: None, + provider_metadata: None, + model: model_str, + }, + )); } _ => {} }, @@ -491,8 +539,10 @@ where delta, .. } => { - let tool_call_id = - item_to_call.get(&item_id).cloned().unwrap_or(item_id); + let tool_call_id = item_to_call + .get(&item_id) + .cloned() + .unwrap_or(item_id); chunks.push(Ok(StreamChunk::ToolCallArgs { timestamp: now, tool_call_id, @@ -506,10 +556,14 @@ where arguments, .. } => { - let tool_call_id = - item_to_call.get(&item_id).cloned().unwrap_or(item_id.clone()); - let tool_name = - item_to_tool.get(&item_id).cloned().unwrap_or_default(); + let tool_call_id = item_to_call + .get(&item_id) + .cloned() + .unwrap_or(item_id.clone()); + let tool_name = item_to_tool + .get(&item_id) + .cloned() + .unwrap_or_default(); let input: Option = serde_json::from_str(&arguments).ok(); chunks.push(Ok(StreamChunk::ToolCallEnd { @@ -522,16 +576,21 @@ where })); } other => { - for chunk in convert_openai_event(other, &model) { + for chunk in convert_openai_event( + other, + &model, + has_tool_calls, + ) { chunks.push(Ok(chunk)); } } } } Err(e) => { - eprintln!( - "Failed to parse OpenAI event: {} - {}", - e, &data_str[..data_str.len().min(200)] + tracing::warn!( + error = %e, + data = %&data_str[..data_str.len().min(200)], + "failed to parse openai event" ); } } @@ -544,14 +603,14 @@ where if !chunks.is_empty() { return Some(( futures_util::stream::iter(chunks), - 
(stream, line_buf, item_to_call, item_to_tool), + (stream, line_buf, item_to_call, item_to_tool, has_tool_calls), )); } } Some(Err(e)) => { return Some(( futures_util::stream::iter(vec![Err(AiError::Http(e))]), - (stream, line_buf, item_to_call, item_to_tool), + (stream, line_buf, item_to_call, item_to_tool, has_tool_calls), )); } None => { @@ -567,47 +626,61 @@ where match serde_json::from_str::(data_str) { Ok(event) => { - let now = chrono::Utc::now().timestamp_millis() as f64 / 1000.0; + let now = + chrono::Utc::now().timestamp_millis() as f64 / 1000.0; let model_str = Some(model.clone()); let mut chunks: Vec> = Vec::new(); match event { - OpenAiStreamEvent::OutputItemAdded { item, .. } => match item { - OpenAiOutputItem::Message { id, role, .. } => { - chunks.push(Ok(StreamChunk::TextMessageStart { - timestamp: now, - message_id: id, - role, - model: model_str, - })); - } - OpenAiOutputItem::FunctionCall { - id, - name, - call_id, - .. - } => { - item_to_call.insert(id.clone(), call_id.clone()); - item_to_tool.insert(id.clone(), name.clone()); - chunks.push(Ok(StreamChunk::ToolCallStart { - timestamp: now, - tool_call_id: call_id, - tool_name: name, - parent_message_id: Some(id), - index: None, - provider_metadata: None, - model: model_str, - })); + OpenAiStreamEvent::OutputItemAdded { item, .. } => { + match item { + OpenAiOutputItem::Message { + id, role, .. + } => { + chunks.push(Ok( + StreamChunk::TextMessageStart { + timestamp: now, + message_id: id, + role, + model: model_str, + }, + )); + } + OpenAiOutputItem::FunctionCall { + id, + name, + call_id, + .. 
+ } => { + has_tool_calls = true; + item_to_call + .insert(id.clone(), call_id.clone()); + item_to_tool + .insert(id.clone(), name.clone()); + chunks.push(Ok( + StreamChunk::ToolCallStart { + timestamp: now, + tool_call_id: call_id, + tool_name: name, + parent_message_id: Some(id), + index: None, + provider_metadata: None, + model: model_str, + }, + )); + } + _ => {} } - _ => {} - }, + } OpenAiStreamEvent::FunctionCallArgumentsDelta { item_id, delta, .. } => { - let tool_call_id = - item_to_call.get(&item_id).cloned().unwrap_or(item_id); + let tool_call_id = item_to_call + .get(&item_id) + .cloned() + .unwrap_or(item_id); chunks.push(Ok(StreamChunk::ToolCallArgs { timestamp: now, tool_call_id, @@ -625,8 +698,10 @@ where .get(&item_id) .cloned() .unwrap_or(item_id.clone()); - let tool_name = - item_to_tool.get(&item_id).cloned().unwrap_or_default(); + let tool_name = item_to_tool + .get(&item_id) + .cloned() + .unwrap_or_default(); let input: Option = serde_json::from_str(&arguments).ok(); chunks.push(Ok(StreamChunk::ToolCallEnd { @@ -639,7 +714,11 @@ where })); } other => { - for chunk in convert_openai_event(other, &model) { + for chunk in convert_openai_event( + other, + &model, + has_tool_calls, + ) { chunks.push(Ok(chunk)); } } @@ -647,14 +726,20 @@ where return Some(( futures_util::stream::iter(chunks), - (stream, String::new(), item_to_call, item_to_tool), + ( + stream, + String::new(), + item_to_call, + item_to_tool, + has_tool_calls, + ), )); } Err(e) => { - eprintln!( - "Failed to parse OpenAI event: {} - {}", - e, - &data_str[..data_str.len().min(200)] + tracing::warn!( + error = %e, + data = %&data_str[..data_str.len().min(200)], + "failed to parse openai event" ); } } @@ -671,7 +756,11 @@ where } /// Convert an OpenAI SSE event to one or more StreamChunks. 
-fn convert_openai_event(event: OpenAiStreamEvent, model: &str) -> Vec { +fn convert_openai_event( + event: OpenAiStreamEvent, + model: &str, + has_tool_calls: bool, +) -> Vec { let now = chrono::Utc::now().timestamp_millis() as f64 / 1000.0; let model_str = Some(model.to_string()); @@ -695,10 +784,7 @@ fn convert_openai_event(event: OpenAiStreamEvent, model: &str) -> Vec { vec![StreamChunk::ToolCallStart { timestamp: now, @@ -733,9 +819,7 @@ fn convert_openai_event(event: OpenAiStreamEvent, model: &str) -> Vec { + OpenAiStreamEvent::FunctionCallArgumentsDelta { item_id, delta, .. } => { vec![StreamChunk::ToolCallArgs { timestamp: now, tool_call_id: item_id, @@ -746,9 +830,7 @@ fn convert_openai_event(event: OpenAiStreamEvent, model: &str) -> Vec { let input: Option = serde_json::from_str(&arguments).ok(); vec![StreamChunk::ToolCallEnd { @@ -767,14 +849,15 @@ fn convert_openai_event(event: OpenAiStreamEvent, model: &str) -> Vec "stop".to_string(), "failed" => "error".to_string(), _ => s.to_string(), - }); + }) + }; let usage = response.usage.map(|u| Usage { prompt_tokens: u.input_tokens, @@ -882,10 +965,8 @@ mod tests { let sse_1 = format!("data: {}\ndata: {}\n", start, delta); let sse_2 = format!("data: {}", done); - let stream = futures_util::stream::iter(vec![ - Ok(Bytes::from(sse_1)), - Ok(Bytes::from(sse_2)), - ]); + let stream = + futures_util::stream::iter(vec![Ok(Bytes::from(sse_1)), Ok(Bytes::from(sse_2))]); let chunks = parse_openai_sse_stream(stream, "gpt-4o".to_string()) .collect::>() @@ -894,17 +975,25 @@ mod tests { assert_eq!(chunks.len(), 3); match &chunks[0] { - Ok(StreamChunk::ToolCallStart { tool_call_id, .. }) => assert_eq!(tool_call_id, "call_abc"), + Ok(StreamChunk::ToolCallStart { tool_call_id, .. }) => { + assert_eq!(tool_call_id, "call_abc") + } other => panic!("unexpected chunk 0: {other:?}"), } match &chunks[1] { - Ok(StreamChunk::ToolCallArgs { tool_call_id, .. 
}) => assert_eq!(tool_call_id, "call_abc"), + Ok(StreamChunk::ToolCallArgs { tool_call_id, .. }) => { + assert_eq!(tool_call_id, "call_abc") + } other => panic!("unexpected chunk 1: {other:?}"), } match &chunks[2] { - Ok(StreamChunk::ToolCallEnd { tool_call_id, tool_name, .. }) => { + Ok(StreamChunk::ToolCallEnd { + tool_call_id, + tool_name, + .. + }) => { assert_eq!(tool_call_id, "call_abc"); assert_eq!(tool_name, "get_weather"); } diff --git a/crates/tanstack-ai/src/chat.rs b/crates/tanstack-ai/src/chat.rs index 58999d07..b37a2797 100644 --- a/crates/tanstack-ai/src/chat.rs +++ b/crates/tanstack-ai/src/chat.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::sync::Arc; use crate::adapter::TextAdapter; -use crate::error::AiResult; +use crate::error::{AiError, AiResult}; use crate::messages::generate_message_id; use crate::middleware::*; use crate::tools::tool_calls::*; @@ -70,7 +70,9 @@ async fn run_non_streaming_text(options: ChatOptions) -> AiResult { let mut content = String::new(); for chunk in &chunks { if let StreamChunk::TextMessageContent { - delta, content: full, .. + delta, + content: full, + .. } = chunk { if let Some(full_content) = full { @@ -89,7 +91,9 @@ async fn run_non_streaming_text(options: ChatOptions) -> AiResult { /// 2. Once complete, call adapter.structured_output with the conversation context /// 3. 
Return the structured result async fn run_agentic_structured_output(options: ChatOptions) -> AiResult { - let schema = options.output_schema.clone() + let schema = options + .output_schema + .clone() .expect("run_agentic_structured_output called without output_schema"); let adapter = options.adapter.clone(); let model_name = adapter.model().to_string(); @@ -115,6 +119,15 @@ async fn run_agentic_structured_output(options: ChatOptions) -> AiResult AiResult AiResult> { + async fn execute_tool_calls( + &self, + tool_calls: &[ToolCall], + ) -> AiResult { let approvals: HashMap = HashMap::new(); let client_results: HashMap = HashMap::new(); - let result = crate::tools::tool_calls::execute_tool_calls( + crate::tools::tool_calls::execute_tool_calls( tool_calls, &self.tools, &approvals, &client_results, None, ) - .await?; - - // Combine server results with error results for client tools - let mut all_results = result.results; - - // Handle tools that need client execution (mark as error for server-side) - for client_tool in result.needs_client_execution { - all_results.push(ToolResult { - tool_call_id: client_tool.tool_call_id, - tool_name: client_tool.tool_name, - result: serde_json::json!({"error": "Client-side tool execution not available"}), - state: ToolResultOutputState::Error, - duration_ms: None, - }); - } - - // Handle tools that need approval (mark as error) - for approval in result.needs_approval { - all_results.push(ToolResult { - tool_call_id: approval.tool_call_id, - tool_name: approval.tool_name, - result: serde_json::json!({"error": "Tool requires approval"}), - state: ToolResultOutputState::Error, - duration_ms: None, - }); - } - - Ok(all_results) + .await } /// Build TOOL_CALL_END chunks for tool results. 
@@ -515,6 +566,63 @@ impl TextEngine { .collect() } + fn build_pending_tool_chunks( + &self, + approvals: &[ApprovalRequest], + client_requests: &[ClientToolRequest], + ) -> Vec { + let now = chrono::Utc::now().timestamp_millis() as f64 / 1000.0; + let model = Some(self.adapter.model().to_string()); + let mut chunks = Vec::new(); + + for approval in approvals { + chunks.push(StreamChunk::Custom { + timestamp: now, + name: "tool-approval-required".to_string(), + value: Some(serde_json::json!({ + "toolCallId": approval.tool_call_id, + "toolName": approval.tool_name, + "input": approval.input, + "approvalId": approval.approval_id, + })), + model: model.clone(), + }); + } + + for request in client_requests { + chunks.push(StreamChunk::Custom { + timestamp: now, + name: "tool-client-execution-required".to_string(), + value: Some(serde_json::json!({ + "toolCallId": request.tool_call_id, + "toolName": request.tool_name, + "input": request.input, + })), + model: model.clone(), + }); + } + + chunks + } + + fn push_assistant_message( + &mut self, + accumulated_content: &str, + tool_calls: Option>, + ) { + self.messages.push(ModelMessage { + role: MessageRole::Assistant, + content: if accumulated_content.is_empty() { + MessageContent::Null + } else { + MessageContent::Text(accumulated_content.to_string()) + }, + name: None, + tool_calls, + tool_call_id: None, + }); + } + /// Get accumulated content from the last text message. pub fn content(&self) -> String { self.messages @@ -532,7 +640,9 @@ fn accumulated_content_from_chunks(chunks: &[StreamChunk]) -> String { let mut content = String::new(); for chunk in chunks { if let StreamChunk::TextMessageContent { - delta, content: full, .. + delta, + content: full, + .. 
} = chunk { if let Some(full_content) = full { diff --git a/crates/tanstack-ai/src/client/connection_adapters.rs b/crates/tanstack-ai/src/client/connection_adapters.rs index 545388ad..47cd0bd7 100644 --- a/crates/tanstack-ai/src/client/connection_adapters.rs +++ b/crates/tanstack-ai/src/client/connection_adapters.rs @@ -6,6 +6,7 @@ use std::collections::HashMap; use std::pin::Pin; use std::sync::Arc; use tokio::sync::{broadcast, Mutex, RwLock}; +use tokio_util::sync::CancellationToken; use crate::error::{AiError, AiResult}; use crate::stream::StreamProcessor; @@ -14,12 +15,22 @@ use crate::types::*; /// Connection type for the chat client. pub enum ConnectionAdapter { /// Server-Sent Events via fetch. - ServerSentEvents { url: String, headers: HashMap }, + ServerSentEvents { + url: String, + headers: HashMap, + }, /// HTTP streaming. - HttpStream { url: String, headers: HashMap }, + HttpStream { + url: String, + headers: HashMap, + }, /// Custom stream provider. Custom { - provider: Box) -> Pin> + Send>> + Send + Sync>, + provider: Box< + dyn Fn(Vec) -> Pin> + Send>> + + Send + + Sync, + >, }, } @@ -53,7 +64,7 @@ impl Default for ChatState { pub struct ChatClient { state: Arc>, connection: ConnectionAdapter, - processor: Arc>, + active_cancellation: Arc>>, next_subscription_id: Arc>, subscribers: Arc>>>, client: Client, @@ -65,7 +76,7 @@ impl ChatClient { Self { state: Arc::new(RwLock::new(ChatState::default())), connection, - processor: Arc::new(Mutex::new(StreamProcessor::new())), + active_cancellation: Arc::new(Mutex::new(None)), next_subscription_id: Arc::new(Mutex::new(0)), subscribers: Arc::new(RwLock::new(HashMap::new())), client: Client::new(), @@ -104,12 +115,21 @@ impl ChatClient { }], created_at: Some(chrono::Utc::now()), }); + state.accumulated_content.clear(); state.is_loading = true; state.error = None; } self.notify_subscribers().await; + let cancel_token = CancellationToken::new(); + { + let mut active_cancellation = 
self.active_cancellation.lock().await; + if let Some(existing) = active_cancellation.replace(cancel_token.clone()) { + existing.cancel(); + } + } + // Build messages for the provider let messages = { let state = self.state.read().await; @@ -119,31 +139,41 @@ impl ChatClient { // Stream the response match &self.connection { ConnectionAdapter::ServerSentEvents { url, headers } => { - self.stream_via_sse(url, headers, messages).await + self.stream_via_sse(url, headers, messages, cancel_token.clone()) + .await } ConnectionAdapter::HttpStream { url, headers } => { - self.stream_via_http(url, headers, messages).await + self.stream_via_http(url, headers, messages, cancel_token.clone()) + .await } ConnectionAdapter::Custom { provider } => { let mut stream = provider(messages); - self.process_stream(&mut stream).await + self.process_stream(&mut stream, cancel_token).await } } } /// Stop the current generation. pub async fn stop(&self) { - let mut state = self.state.write().await; - state.is_loading = false; + if let Some(cancel_token) = self.active_cancellation.lock().await.take() { + cancel_token.cancel(); + } + { + let mut state = self.state.write().await; + state.is_loading = false; + } self.notify_subscribers().await; } /// Clear all messages. 
pub async fn clear(&self) { - let mut state = self.state.write().await; - state.messages.clear(); - state.accumulated_content.clear(); - state.error = None; + { + let mut state = self.state.write().await; + state.messages.clear(); + state.accumulated_content.clear(); + state.is_loading = false; + state.error = None; + } self.notify_subscribers().await; } @@ -152,6 +182,7 @@ impl ChatClient { url: &str, headers: &HashMap, messages: Vec, + cancel_token: CancellationToken, ) -> AiResult<()> { let body = serde_json::json!({ "messages": messages }); @@ -160,13 +191,28 @@ impl ChatClient { request = request.header(key.as_str(), value.as_str()); } - let response = request.send().await?; + let response = match request.send().await { + Ok(response) => response, + Err(error) => { + { + let mut state = self.state.write().await; + state.is_loading = false; + state.error = Some(error.to_string()); + } + self.clear_active_cancellation().await; + self.notify_subscribers().await; + return Err(AiError::Http(error)); + } + }; if !response.status().is_success() { let status = response.status(); let text = response.text().await.unwrap_or_default(); - let mut state = self.state.write().await; - state.is_loading = false; - state.error = Some(format!("HTTP {}: {}", status, text)); + { + let mut state = self.state.write().await; + state.is_loading = false; + state.error = Some(format!("HTTP {}: {}", status, text)); + } + self.clear_active_cancellation().await; self.notify_subscribers().await; return Err(AiError::Provider(format!("HTTP {}: {}", status, text))); } @@ -177,7 +223,7 @@ impl ChatClient { let chunk_stream = parse_sse_to_chunks(byte_stream); futures_util::pin_mut!(chunk_stream); - self.process_stream(&mut chunk_stream).await + self.process_stream(&mut chunk_stream, cancel_token).await } async fn stream_via_http( @@ -185,6 +231,7 @@ impl ChatClient { url: &str, headers: &HashMap, messages: Vec, + cancel_token: CancellationToken, ) -> AiResult<()> { let body = serde_json::json!({ 
"messages": messages }); @@ -193,10 +240,29 @@ impl ChatClient { request = request.header(key.as_str(), value.as_str()); } - let response = request.send().await?; + let response = match request.send().await { + Ok(response) => response, + Err(error) => { + { + let mut state = self.state.write().await; + state.is_loading = false; + state.error = Some(error.to_string()); + } + self.clear_active_cancellation().await; + self.notify_subscribers().await; + return Err(AiError::Http(error)); + } + }; if !response.status().is_success() { let status = response.status(); let text = response.text().await.unwrap_or_default(); + { + let mut state = self.state.write().await; + state.is_loading = false; + state.error = Some(format!("HTTP {}: {}", status, text)); + } + self.clear_active_cancellation().await; + self.notify_subscribers().await; return Err(AiError::Provider(format!("HTTP {}: {}", status, text))); } @@ -206,135 +272,191 @@ impl ChatClient { let chunk_stream = parse_ndjson_to_chunks(byte_stream); futures_util::pin_mut!(chunk_stream); - self.process_stream(&mut chunk_stream).await + self.process_stream(&mut chunk_stream, cancel_token).await } - async fn process_stream(&self, stream: &mut Pin> + Send>>) -> AiResult<()> - { - let mut processor = self.processor.lock().await; - - while let Some(result) = stream.next().await { - match result { - Ok(chunk) => { - if let Some(processed) = processor.process_chunk(chunk) { - self.apply_chunk(&processed).await; + async fn process_stream( + &self, + stream: &mut Pin> + Send>>, + cancel_token: CancellationToken, + ) -> AiResult<()> { + let mut processor = StreamProcessor::new(); + + loop { + tokio::select! 
{ + _ = cancel_token.cancelled() => { + { + let mut state = self.state.write().await; + state.is_loading = false; } - } - Err(e) => { - let mut state = self.state.write().await; - state.is_loading = false; - state.error = Some(e.to_string()); + self.clear_active_cancellation().await; self.notify_subscribers().await; - return Err(e); + return Err(AiError::Aborted("generation stopped".to_string())); + } + result = stream.next() => { + match result { + Some(Ok(chunk)) => { + for processed in processor.process_chunk(chunk) { + self.apply_chunk(&processed).await; + } + } + Some(Err(e)) => { + { + let mut state = self.state.write().await; + state.is_loading = false; + state.error = Some(e.to_string()); + } + self.clear_active_cancellation().await; + self.notify_subscribers().await; + return Err(e); + } + None => break, + } } } } - let mut state = self.state.write().await; - state.is_loading = false; + { + let mut state = self.state.write().await; + state.is_loading = false; + } + self.clear_active_cancellation().await; self.notify_subscribers().await; Ok(()) } async fn apply_chunk(&self, chunk: &StreamChunk) { - let mut state = self.state.write().await; - - match chunk { - StreamChunk::TextMessageContent { delta, content, .. } => { - if let Some(full) = content { - state.accumulated_content = full.clone(); - } else { - state.accumulated_content.push_str(delta); - } + { + let mut state = self.state.write().await; + + match chunk { + StreamChunk::TextMessageContent { delta, content, .. } => { + if let Some(full) = content { + state.accumulated_content = full.clone(); + } else { + state.accumulated_content.push_str(delta); + } - // Update or create the assistant message - let new_content = state.accumulated_content.clone(); - if let Some(last) = state.messages.last_mut() { - if last.role == UiMessageRole::Assistant { - if let Some(MessagePart::Text { content, .. 
}) = last.parts.last_mut() { - *content = new_content; + let new_content = state.accumulated_content.clone(); + if let Some(last) = state.messages.last_mut() { + if last.role == UiMessageRole::Assistant { + if let Some(MessagePart::Text { content, .. }) = last.parts.last_mut() { + *content = new_content; + } } } } - } - StreamChunk::TextMessageStart { role, .. } => { - let ui_role = match role.as_str() { - "system" => UiMessageRole::System, - "assistant" => UiMessageRole::Assistant, - _ => UiMessageRole::Assistant, - }; - state.messages.push(UiMessage { - id: generate_message_id("msg"), - role: ui_role, - parts: vec![MessagePart::Text { - content: String::new(), - metadata: None, - }], - created_at: Some(chrono::Utc::now()), - }); - } + StreamChunk::TextMessageStart { role, .. } => { + let ui_role = match role.as_str() { + "system" => UiMessageRole::System, + "assistant" => UiMessageRole::Assistant, + _ => UiMessageRole::Assistant, + }; + state.accumulated_content.clear(); + state.messages.push(UiMessage { + id: generate_message_id("msg"), + role: ui_role, + parts: vec![MessagePart::Text { + content: String::new(), + metadata: None, + }], + created_at: Some(chrono::Utc::now()), + }); + } - StreamChunk::ToolCallStart { tool_call_id, tool_name, .. } => { - if let Some(last) = state.messages.last_mut() { - if last.role == UiMessageRole::Assistant { - last.parts.push(MessagePart::ToolCall { - id: tool_call_id.clone(), - name: tool_name.clone(), - arguments: String::new(), - state: ToolCallState::AwaitingInput, - approval: None, - output: None, - }); + StreamChunk::ToolCallStart { + tool_call_id, + tool_name, + .. 
+ } => { + if let Some(last) = state.messages.last_mut() { + if last.role == UiMessageRole::Assistant { + last.parts.push(MessagePart::ToolCall { + id: tool_call_id.clone(), + name: tool_name.clone(), + arguments: String::new(), + state: ToolCallState::AwaitingInput, + approval: None, + output: None, + }); + } } } - } - StreamChunk::ToolCallArgs { tool_call_id, delta, .. } => { - if let Some(last) = state.messages.last_mut() { - for part in &mut last.parts { - if let MessagePart::ToolCall { id, arguments, state: tc_state, .. } = part { - if id == tool_call_id { - arguments.push_str(delta); - *tc_state = ToolCallState::InputStreaming; - break; + StreamChunk::ToolCallArgs { + tool_call_id, + delta, + .. + } => { + if let Some(last) = state.messages.last_mut() { + for part in &mut last.parts { + if let MessagePart::ToolCall { + id, + arguments, + state: tc_state, + .. + } = part + { + if id == tool_call_id { + arguments.push_str(delta); + *tc_state = ToolCallState::InputStreaming; + break; + } } } } } - } - StreamChunk::ToolCallEnd { tool_call_id, input, result, .. } => { - if let Some(last) = state.messages.last_mut() { - for part in &mut last.parts { - if let MessagePart::ToolCall { id, state: tc_state, arguments, output, .. } = part { - if id == tool_call_id { - *tc_state = ToolCallState::InputComplete; - if let Some(input_val) = input { - *arguments = serde_json::to_string(input_val).unwrap_or_default(); - } - if let Some(result_str) = result { - *output = serde_json::from_str(result_str).ok(); + StreamChunk::ToolCallEnd { + tool_call_id, + input, + result, + .. + } => { + if let Some(last) = state.messages.last_mut() { + for part in &mut last.parts { + if let MessagePart::ToolCall { + id, + state: tc_state, + arguments, + output, + .. 
+ } = part + { + if id == tool_call_id { + *tc_state = ToolCallState::InputComplete; + if let Some(input_val) = input { + *arguments = + serde_json::to_string(input_val).unwrap_or_default(); + } + if let Some(result_str) = result { + *output = serde_json::from_str(result_str).ok(); + } + break; } - break; } } } } - } - StreamChunk::RunError { error, .. } => { - state.is_loading = false; - state.error = Some(error.message.clone()); - } + StreamChunk::RunError { error, .. } => { + state.is_loading = false; + state.error = Some(error.message.clone()); + } - _ => {} + _ => {} + } } self.notify_subscribers().await; } + async fn clear_active_cancellation(&self) { + self.active_cancellation.lock().await.take(); + } + async fn notify_subscribers(&self) { let state = self.state.read().await; let subscribers = self.subscribers.read().await; @@ -368,9 +490,7 @@ fn generate_message_id(prefix: &str) -> String { } /// Parse SSE byte stream into StreamChunk events. -fn parse_sse_to_chunks( - stream: S, -) -> Pin> + Send>> +fn parse_sse_to_chunks(stream: S) -> Pin> + Send>> where S: Stream> + Send + 'static, { @@ -435,11 +555,14 @@ where if data_str == "[DONE]" { return None; } - let parsed = serde_json::from_str::(data_str) - .map_err(|e| { + let parsed = + serde_json::from_str::(data_str).map_err(|e| { AiError::Stream(format!("Failed to parse SSE chunk: {}", e)) }); - return Some((futures_util::stream::iter(vec![parsed]), (stream, String::new()))); + return Some(( + futures_util::stream::iter(vec![parsed]), + (stream, String::new()), + )); } return Some(( @@ -458,9 +581,7 @@ where } /// Parse NDJSON byte stream into StreamChunk events. 
-fn parse_ndjson_to_chunks( - stream: S, -) -> Pin> + Send>> +fn parse_ndjson_to_chunks(stream: S) -> Pin> + Send>> where S: Stream> + Send + 'static, { @@ -516,9 +637,13 @@ where return None; } - let parsed = serde_json::from_str::(line) - .map_err(|e| AiError::Stream(format!("Failed to parse NDJSON chunk: {}", e))); - return Some((futures_util::stream::iter(vec![parsed]), (stream, String::new()))); + let parsed = serde_json::from_str::(line).map_err(|e| { + AiError::Stream(format!("Failed to parse NDJSON chunk: {}", e)) + }); + return Some(( + futures_util::stream::iter(vec![parsed]), + (stream, String::new()), + )); } } } @@ -576,7 +701,11 @@ mod tests { assert_eq!(parsed.len(), 1); match &parsed[0] { - Ok(StreamChunk::RunFinished { run_id, finish_reason, .. }) => { + Ok(StreamChunk::RunFinished { + run_id, + finish_reason, + .. + }) => { assert_eq!(run_id, "run_2"); assert_eq!(finish_reason.as_deref(), Some("stop")); } diff --git a/crates/tanstack-ai/src/lib.rs b/crates/tanstack-ai/src/lib.rs index dfeebe7f..6b9da67c 100644 --- a/crates/tanstack-ai/src/lib.rs +++ b/crates/tanstack-ai/src/lib.rs @@ -69,10 +69,7 @@ pub use tools::*; pub use types::*; /// Convenience function to create a chat with an OpenAI model. -pub fn openai_text( - model: impl Into, - api_key: impl Into, -) -> OpenAiTextAdapter { +pub fn openai_text(model: impl Into, api_key: impl Into) -> OpenAiTextAdapter { OpenAiTextAdapter::new(model, api_key) } @@ -85,10 +82,7 @@ pub fn anthropic_text( } /// Convenience function to create a chat with a Gemini model. 
-pub fn gemini_text( - model: impl Into, - api_key: impl Into, -) -> GeminiTextAdapter { +pub fn gemini_text(model: impl Into, api_key: impl Into) -> GeminiTextAdapter { GeminiTextAdapter::new(model, api_key) } @@ -127,7 +121,7 @@ pub fn detect_image_mime_type(data: &[u8]) -> &'static str { "image/jpeg" } else if data.starts_with(b"GIF87a") || data.starts_with(b"GIF89a") { "image/gif" - } else if data.starts_with(b"RIFF") && data.len() > 12 && &data[8..12] == b"WEBP" { + } else if data.starts_with(b"RIFF") && data.len() >= 12 && &data[8..12] == b"WEBP" { "image/webp" } else { "application/octet-stream" @@ -145,6 +139,9 @@ mod tests { let jpeg_header = &[0xFF, 0xD8, 0xFF, 0xE0]; assert_eq!(detect_image_mime_type(jpeg_header), "image/jpeg"); + + let webp_header = b"RIFF1234WEBP"; + assert_eq!(detect_image_mime_type(webp_header), "image/webp"); } #[test] diff --git a/crates/tanstack-ai/src/messages.rs b/crates/tanstack-ai/src/messages.rs index 1cf26ed3..9df4b79d 100644 --- a/crates/tanstack-ai/src/messages.rs +++ b/crates/tanstack-ai/src/messages.rs @@ -190,50 +190,52 @@ pub fn model_message_to_ui_message(msg: &ModelMessage) -> UiMessage { let mut parts = Vec::new(); // Add content parts - match &msg.content { - MessageContent::Text(text) => { - parts.push(MessagePart::Text { - content: text.clone(), - metadata: None, - }); - } - MessageContent::Parts(content_parts) => { - for part in content_parts { - match part { - ContentPart::Text { content } => { - parts.push(MessagePart::Text { - content: content.clone(), - metadata: None, - }); - } - ContentPart::Image { source } => { - parts.push(MessagePart::Image { - source: source.clone(), - metadata: None, - }); - } - ContentPart::Audio { source } => { - parts.push(MessagePart::Audio { - source: source.clone(), - metadata: None, - }); - } - ContentPart::Video { source } => { - parts.push(MessagePart::Video { - source: source.clone(), - metadata: None, - }); - } - ContentPart::Document { source } => { - 
parts.push(MessagePart::Document { - source: source.clone(), - metadata: None, - }); + if msg.role != MessageRole::Tool { + match &msg.content { + MessageContent::Text(text) => { + parts.push(MessagePart::Text { + content: text.clone(), + metadata: None, + }); + } + MessageContent::Parts(content_parts) => { + for part in content_parts { + match part { + ContentPart::Text { content } => { + parts.push(MessagePart::Text { + content: content.clone(), + metadata: None, + }); + } + ContentPart::Image { source } => { + parts.push(MessagePart::Image { + source: source.clone(), + metadata: None, + }); + } + ContentPart::Audio { source } => { + parts.push(MessagePart::Audio { + source: source.clone(), + metadata: None, + }); + } + ContentPart::Video { source } => { + parts.push(MessagePart::Video { + source: source.clone(), + metadata: None, + }); + } + ContentPart::Document { source } => { + parts.push(MessagePart::Document { + source: source.clone(), + metadata: None, + }); + } } } } + MessageContent::Null => {} } - MessageContent::Null => {} } // Add tool call parts @@ -275,8 +277,9 @@ pub fn model_message_to_ui_message(msg: &ModelMessage) -> UiMessage { } /// Normalize messages to ModelMessage format. -/// If the messages are already ModelMessages, pass through. -/// If they are UI messages, convert them. +/// +/// This is intentionally a passthrough for `ModelMessage` input today. +/// UI-message normalization happens in `ui_messages_to_model_messages`. 
pub fn normalize_to_model_messages(messages: &[ModelMessage]) -> Vec { messages.to_vec() } @@ -321,4 +324,19 @@ mod tests { assert_eq!(model_msgs.len(), 1); assert_eq!(model_msgs[0].content.as_str(), Some("Hi there!")); } + + #[test] + fn test_model_to_ui_tool_result_does_not_duplicate_text() { + let model_msg = ModelMessage { + role: MessageRole::Tool, + content: MessageContent::Text("{\"temp\":72}".to_string()), + name: None, + tool_calls: None, + tool_call_id: Some("call_1".to_string()), + }; + + let ui_msg = model_message_to_ui_message(&model_msg); + assert_eq!(ui_msg.parts.len(), 1); + assert!(matches!(ui_msg.parts[0], MessagePart::ToolResult { .. })); + } } diff --git a/crates/tanstack-ai/src/middleware.rs b/crates/tanstack-ai/src/middleware.rs index 7fe5e307..17d71536 100644 --- a/crates/tanstack-ai/src/middleware.rs +++ b/crates/tanstack-ai/src/middleware.rs @@ -293,6 +293,24 @@ impl MiddlewareRunner { } } + /// Run on_tool_phase_complete through all middlewares. + pub fn run_on_tool_phase_complete( + &self, + ctx: &ChatMiddlewareContext, + info: &ToolPhaseCompleteInfo, + ) { + for mw in &self.middlewares { + mw.on_tool_phase_complete(ctx, info); + } + } + + /// Run on_usage through all middlewares. + pub fn run_on_usage(&self, ctx: &ChatMiddlewareContext, usage: &UsageInfo) { + for mw in &self.middlewares { + mw.on_usage(ctx, usage); + } + } + /// Run on_finish through all middlewares. 
pub fn run_on_finish(&self, ctx: &ChatMiddlewareContext, info: &FinishInfo) { for mw in &self.middlewares { diff --git a/crates/tanstack-ai/src/stream/json_parser.rs b/crates/tanstack-ai/src/stream/json_parser.rs index 74a78187..dc9b4632 100644 --- a/crates/tanstack-ai/src/stream/json_parser.rs +++ b/crates/tanstack-ai/src/stream/json_parser.rs @@ -19,9 +19,8 @@ pub fn parse_partial_json(json_string: &str) -> Option { let trimmed = json_string.trim(); let mut attempt = trimmed.to_string(); - // Count open/close brackets and braces to determine what to close - let mut brace_depth: i32 = 0; - let mut bracket_depth: i32 = 0; + // Track opener order so incomplete nested structures close correctly. + let mut open_stack = Vec::new(); let mut in_string = false; let mut escape_next = false; @@ -42,10 +41,17 @@ pub fn parse_partial_json(json_string: &str) -> Option { continue; } match ch { - '{' => brace_depth += 1, - '}' => brace_depth -= 1, - '[' => bracket_depth += 1, - ']' => bracket_depth -= 1, + '{' | '[' => open_stack.push(ch), + '}' => { + if matches!(open_stack.last(), Some('{')) { + open_stack.pop(); + } + } + ']' => { + if matches!(open_stack.last(), Some('[')) { + open_stack.pop(); + } + } _ => {} } } @@ -55,12 +61,12 @@ pub fn parse_partial_json(json_string: &str) -> Option { attempt.push('"'); } - // Close unclosed brackets and braces - for _ in 0..bracket_depth { - attempt.push(']'); - } - for _ in 0..brace_depth { - attempt.push('}'); + while let Some(open) = open_stack.pop() { + attempt.push(match open { + '{' => '}', + '[' => ']', + _ => continue, + }); } serde_json::from_str(&attempt).ok() @@ -99,4 +105,10 @@ mod tests { let result = parse_partial_json(r#"{"user": {"name": "Jo"#).unwrap(); assert_eq!(result["user"]["name"], "Jo"); } + + #[test] + fn test_interleaved_nesting() { + let result = parse_partial_json(r#"[{"name": "Jo"#).unwrap(); + assert_eq!(result, serde_json::json!([{"name": "Jo"}])); + } } diff --git a/crates/tanstack-ai/src/stream/mod.rs 
b/crates/tanstack-ai/src/stream/mod.rs index 1d389364..1b6e88df 100644 --- a/crates/tanstack-ai/src/stream/mod.rs +++ b/crates/tanstack-ai/src/stream/mod.rs @@ -1,9 +1,9 @@ +pub mod json_parser; pub mod processor; pub mod strategies; -pub mod json_parser; pub mod types; +pub use json_parser::*; pub use processor::*; pub use strategies::*; -pub use json_parser::*; pub use types::*; diff --git a/crates/tanstack-ai/src/stream/processor.rs b/crates/tanstack-ai/src/stream/processor.rs index edebff68..27f95080 100644 --- a/crates/tanstack-ai/src/stream/processor.rs +++ b/crates/tanstack-ai/src/stream/processor.rs @@ -51,8 +51,8 @@ impl StreamProcessor { /// Process a single stream chunk, updating internal state. /// - /// Returns the processed chunk if it should be emitted to the UI. - pub fn process_chunk(&mut self, chunk: StreamChunk) -> Option { + /// Returns the processed chunks that should be emitted to the UI. + pub fn process_chunk(&mut self, chunk: StreamChunk) -> Vec { if self.recording_enabled { self.recordings.push(RecordedChunk { chunk: chunk.clone(), @@ -67,7 +67,7 @@ impl StreamProcessor { } => { let state = MessageStreamState::new(message_id.clone(), role.clone()); self.current_message = Some(state); - Some(chunk) + vec![chunk] } StreamChunk::TextMessageContent { @@ -90,19 +90,31 @@ impl StreamProcessor { { msg.last_emitted_text = msg.total_text_content.clone(); msg.current_segment_text.clear(); - Some(chunk) + vec![chunk] } else { - None + Vec::new() } } else { - Some(chunk) + vec![chunk] } } - StreamChunk::TextMessageEnd { .. 
} => { + StreamChunk::TextMessageEnd { + timestamp, + message_id, + model, + } => { + let mut emitted_chunks = Vec::new(); + if let Some(msg) = &mut self.current_message { - // Flush any remaining text that wasn't emitted if !msg.current_segment_text.is_empty() { + emitted_chunks.push(StreamChunk::TextMessageContent { + timestamp: *timestamp, + message_id: message_id.clone(), + delta: msg.current_segment_text.clone(), + content: Some(msg.total_text_content.clone()), + model: model.clone(), + }); msg.last_emitted_text = msg.total_text_content.clone(); msg.current_segment_text.clear(); } @@ -112,7 +124,8 @@ impl StreamProcessor { } } self.chunk_strategy.reset(); - Some(chunk) + emitted_chunks.push(chunk); + emitted_chunks } StreamChunk::ToolCallStart { @@ -138,7 +151,7 @@ impl StreamProcessor { msg.tool_call_order.push(tool_call_id.clone()); msg.has_tool_calls_since_text_start = true; } - Some(chunk) + vec![chunk] } StreamChunk::ToolCallArgs { @@ -154,7 +167,7 @@ impl StreamProcessor { crate::stream::json_parser::parse_partial_json(&tc.arguments); } } - Some(chunk) + vec![chunk] } StreamChunk::ToolCallEnd { @@ -171,10 +184,10 @@ impl StreamProcessor { tc.state = ToolCallState::InputComplete; } } - Some(chunk) + vec![chunk] } - StreamChunk::StepStarted { .. } => Some(chunk), + StreamChunk::StepStarted { .. } => vec![chunk], StreamChunk::StepFinished { step_id: _step_id, delta, @@ -188,14 +201,14 @@ impl StreamProcessor { msg.thinking_content.push_str(delta); } } - Some(chunk) + vec![chunk] } - StreamChunk::RunFinished { .. } => Some(chunk), - StreamChunk::RunError { .. } => Some(chunk), + StreamChunk::RunFinished { .. } => vec![chunk], + StreamChunk::RunError { .. 
} => vec![chunk], // Pass through other events unchanged - _ => Some(chunk), + _ => vec![chunk], } } diff --git a/crates/tanstack-ai/src/stream/strategies.rs b/crates/tanstack-ai/src/stream/strategies.rs index 7eb2b778..4bf6323c 100644 --- a/crates/tanstack-ai/src/stream/strategies.rs +++ b/crates/tanstack-ai/src/stream/strategies.rs @@ -100,7 +100,9 @@ impl ChunkStrategy for CompositeStrategy { fn should_emit(&mut self, chunk: &str, accumulated: &str) -> bool { self.strategies .iter_mut() - .any(|s| s.should_emit(chunk, accumulated)) + .fold(false, |should_emit, strategy| { + strategy.should_emit(chunk, accumulated) || should_emit + }) } fn reset(&mut self) { diff --git a/crates/tanstack-ai/src/stream_response.rs b/crates/tanstack-ai/src/stream_response.rs index bb3ce913..45e0a7d7 100644 --- a/crates/tanstack-ai/src/stream_response.rs +++ b/crates/tanstack-ai/src/stream_response.rs @@ -28,9 +28,9 @@ pub fn sse_stream_to_json( use futures_util::StreamExt; let lines = tokio_util::codec::FramedRead::new( - tokio_util::io::StreamReader::new(stream.map(|r| r.map_err(|e| { - std::io::Error::new(std::io::ErrorKind::Other, e) - }))), + tokio_util::io::StreamReader::new( + stream.map(|r| r.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))), + ), tokio_util::codec::LinesCodec::new(), ); @@ -97,11 +97,18 @@ pub fn chunks_to_sse_stream( stream .map(|result| match result { - Ok(chunk) => { - let json = serde_json::to_string(&chunk).unwrap_or_default(); - Ok(Bytes::from(format!("data: {}\n\n", json))) - } - Err(e) => Ok(Bytes::from(format!("data: {}\n\n", + Ok(chunk) => match serde_json::to_string(&chunk) { + Ok(json) => Ok(Bytes::from(format!("data: {}\n\n", json))), + Err(err) => Ok(Bytes::from(format!( + "data: {}\n\n", + serde_json::json!({ + "type": "RUN_ERROR", + "error": { "message": format!("failed to serialize stream chunk: {}", err) } + }) + ))), + }, + Err(e) => Ok(Bytes::from(format!( + "data: {}\n\n", serde_json::json!({"type": "RUN_ERROR", "error": 
{"message": e.to_string()}}) ))), }) @@ -117,7 +124,11 @@ pub async fn stream_to_text( while let Some(result) = stream.next().await { match result? { - StreamChunk::TextMessageContent { delta, content: full, .. } => { + StreamChunk::TextMessageContent { + delta, + content: full, + .. + } => { if let Some(full_content) = full { content = full_content; } else { diff --git a/crates/tanstack-ai/src/tools/definition.rs b/crates/tanstack-ai/src/tools/definition.rs index 9c04ae24..3f163a6c 100644 --- a/crates/tanstack-ai/src/tools/definition.rs +++ b/crates/tanstack-ai/src/tools/definition.rs @@ -99,7 +99,12 @@ pub fn tool_definition(name: impl Into, description: impl Into) ToolDefinition::new(name, description) } +/// Fallible helper to create a JSON Schema from a JSON value. +pub fn try_json_schema(value: serde_json::Value) -> Result { + serde_json::from_value(value) +} + /// Helper to create a JSON Schema from a JSON value. pub fn json_schema(value: serde_json::Value) -> JsonSchema { - serde_json::from_value(value).unwrap_or_default() + try_json_schema(value).expect("invalid json schema value") } diff --git a/crates/tanstack-ai/src/tools/registry.rs b/crates/tanstack-ai/src/tools/registry.rs index 8931dd82..1ddb9689 100644 --- a/crates/tanstack-ai/src/tools/registry.rs +++ b/crates/tanstack-ai/src/tools/registry.rs @@ -80,6 +80,10 @@ impl From> for ToolRegistry { } /// Frozen/immutable tool registry for thread-safe sharing. +/// +/// Use `FrozenToolRegistry` when multiple clients or tasks need to reuse the +/// same tool set. `ChatOptions` still accepts `Vec` because a single chat +/// run owns its tool definitions outright. 
#[derive(Debug, Clone)] pub struct FrozenToolRegistry { tools: Arc>, diff --git a/crates/tanstack-ai/src/tools/tool_calls.rs b/crates/tanstack-ai/src/tools/tool_calls.rs index cfc422c8..5d344303 100644 --- a/crates/tanstack-ai/src/tools/tool_calls.rs +++ b/crates/tanstack-ai/src/tools/tool_calls.rs @@ -1,6 +1,6 @@ -use std::collections::HashMap; use crate::error::{AiError, AiResult}; use crate::types::*; +use std::collections::HashMap; use tokio::sync::mpsc; /// Result of a tool execution. @@ -45,23 +45,18 @@ pub struct ExecuteToolCallsResult { pub needs_client_execution: Vec, } -/// Custom event emitted during tool execution. -#[derive(Debug, Clone)] -pub struct ToolCustomEvent { - pub name: String, - pub value: serde_json::Value, -} - /// Manages tool call accumulation and execution for the chat engine. #[derive(Debug, Default)] pub struct ToolCallManager { tool_calls: HashMap, + tool_call_order: Vec, } impl ToolCallManager { pub fn new() -> Self { Self { tool_calls: HashMap::new(), + tool_call_order: Vec::new(), } } @@ -75,6 +70,9 @@ impl ToolCallManager { .. } = event { + if !self.tool_calls.contains_key(tool_call_id) { + self.tool_call_order.push(tool_call_id.clone()); + } self.tool_calls.insert( tool_call_id.clone(), ToolCall { @@ -92,7 +90,12 @@ impl ToolCallManager { /// Add a TOOL_CALL_ARGS event to accumulate arguments. pub fn add_args_event(&mut self, event: &StreamChunk) { - if let StreamChunk::ToolCallArgs { tool_call_id, delta, .. } = event { + if let StreamChunk::ToolCallArgs { + tool_call_id, + delta, + .. + } = event + { if let Some(tc) = self.tool_calls.get_mut(tool_call_id) { tc.function.arguments.push_str(delta); } @@ -101,11 +104,15 @@ impl ToolCallManager { /// Complete a tool call with its final input. pub fn complete_tool_call(&mut self, event: &StreamChunk) { - if let StreamChunk::ToolCallEnd { tool_call_id, input, .. } = event { + if let StreamChunk::ToolCallEnd { + tool_call_id, + input, + .. 
+ } = event + { if let Some(tc) = self.tool_calls.get_mut(tool_call_id) { if let Some(final_input) = input { - tc.function.arguments = - serde_json::to_string(final_input).unwrap_or_default(); + tc.function.arguments = serde_json::to_string(final_input).unwrap_or_default(); } } } @@ -120,8 +127,9 @@ impl ToolCallManager { /// Get all tool calls as a Vec. pub fn tool_calls(&self) -> Vec { - self.tool_calls - .values() + self.tool_call_order + .iter() + .filter_map(|tool_call_id| self.tool_calls.get(tool_call_id)) .filter(|tc| !tc.id.is_empty() && !tc.function.name.trim().is_empty()) .cloned() .collect() @@ -130,6 +138,7 @@ impl ToolCallManager { /// Clear all tool calls for the next iteration. pub fn clear(&mut self) { self.tool_calls.clear(); + self.tool_call_order.clear(); } } @@ -141,7 +150,7 @@ pub async fn execute_tool_calls( tools: &[Tool], approvals: &HashMap, client_results: &HashMap, - event_tx: Option>, + event_tx: Option>, ) -> AiResult { let mut results = Vec::new(); let mut needs_approval = Vec::new(); @@ -262,14 +271,7 @@ pub async fn execute_tool_calls( let approval_id = format!("approval_{}", tool_call.id); if let Some(&approved) = approvals.get(&approval_id) { if approved { - execute_server_tool( - tool_call, - tool, - input, - &event_tx, - &mut results, - ) - .await?; + execute_server_tool(tool_call, tool, input, &event_tx, &mut results).await?; } else { results.push(ToolResult { tool_call_id: tool_call.id.clone(), @@ -305,14 +307,14 @@ async fn execute_server_tool( tool_call: &ToolCall, tool: &Tool, input: serde_json::Value, - event_tx: &Option>, + event_tx: &Option>, results: &mut Vec, ) -> AiResult<()> { let start = std::time::Instant::now(); let ctx = ToolExecutionContext { tool_call_id: Some(tool_call.id.clone()), - custom_event_tx: None, // TODO: wire up custom event channel + custom_event_tx: event_tx.clone(), }; let execute_fn = tool.execute.as_ref().unwrap(); @@ -323,13 +325,14 @@ async fn execute_server_tool( // Emit custom event if 
channel is available if let Some(tx) = event_tx { - let _ = tx.send(ToolCustomEvent { + let _ = tx.send(CustomEventData { name: "tool-result".to_string(), value: serde_json::json!({ "toolCallId": tool_call.id, "toolName": tool.name, "result": result, }), + tool_call_id: Some(tool_call.id.clone()), }); } diff --git a/crates/tanstack-ai/src/types.rs b/crates/tanstack-ai/src/types.rs index 374623d2..05daf10c 100644 --- a/crates/tanstack-ai/src/types.rs +++ b/crates/tanstack-ai/src/types.rs @@ -1,5 +1,6 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use std::sync::Arc; // ============================================================================ // Tool Call States @@ -56,8 +57,10 @@ pub struct JsonSchema { #[serde(skip_serializing_if = "Option::is_none")] pub default: Option, #[serde(skip_serializing_if = "Option::is_none")] + #[serde(rename = "$ref")] pub r#ref: Option, #[serde(skip_serializing_if = "Option::is_none")] + #[serde(rename = "$defs")] pub defs: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub all_of: Option>, @@ -115,7 +118,7 @@ pub enum ContentPartSource { } /// Image content part. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize)] pub struct ImagePart { #[serde(rename = "type")] pub part_type: &'static str, // always "image" @@ -125,7 +128,7 @@ pub struct ImagePart { } /// Audio content part. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize)] pub struct AudioPart { #[serde(rename = "type")] pub part_type: &'static str, @@ -135,7 +138,7 @@ pub struct AudioPart { } /// Video content part. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize)] pub struct VideoPart { #[serde(rename = "type")] pub part_type: &'static str, @@ -145,7 +148,7 @@ pub struct VideoPart { } /// Document content part (e.g., PDFs). 
-#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize)] pub struct DocumentPart { #[serde(rename = "type")] pub part_type: &'static str, @@ -663,9 +666,10 @@ pub struct AgentLoopState { } /// Strategy function that determines whether the agent loop should continue. -pub type AgentLoopStrategy = Box bool + Send + Sync>; +pub type AgentLoopStrategy = Arc bool + Send + Sync>; /// Options for text generation / chat. +#[derive(Clone)] pub struct TextOptions { pub model: String, pub messages: Vec, @@ -681,25 +685,6 @@ pub struct TextOptions { pub conversation_id: Option, } -impl Clone for TextOptions { - fn clone(&self) -> Self { - Self { - model: self.model.clone(), - messages: self.messages.clone(), - tools: self.tools.clone(), - system_prompts: self.system_prompts.clone(), - agent_loop_strategy: None, // Strategy functions cannot be cloned - temperature: self.temperature, - top_p: self.top_p, - max_tokens: self.max_tokens, - metadata: self.metadata.clone(), - model_options: self.model_options.clone(), - output_schema: self.output_schema.clone(), - conversation_id: self.conversation_id.clone(), - } - } -} - impl Default for TextOptions { fn default() -> Self { Self { @@ -739,18 +724,18 @@ pub struct StructuredOutputResult { /// Continue for up to `max` iterations. pub fn max_iterations(max: u32) -> AgentLoopStrategy { - Box::new(move |state: &AgentLoopState| state.iteration_count < max) + Arc::new(move |state: &AgentLoopState| state.iteration_count < max) } /// Continue until a specific finish reason is received. pub fn until_finish_reason(reason: impl Into) -> AgentLoopStrategy { let reason = reason.into(); - Box::new(move |state: &AgentLoopState| state.finish_reason.as_ref() != Some(&reason)) + Arc::new(move |state: &AgentLoopState| state.finish_reason.as_ref() != Some(&reason)) } /// Combine multiple strategies with AND logic (all must return true). 
pub fn combine_strategies(strategies: Vec) -> AgentLoopStrategy { - Box::new(move |state: &AgentLoopState| strategies.iter().all(|s| s(state))) + Arc::new(move |state: &AgentLoopState| strategies.iter().all(|s| s(state))) } // ============================================================================ diff --git a/crates/tanstack-ai/tests/chat_tests.rs b/crates/tanstack-ai/tests/chat_tests.rs index d955db64..4e1d2045 100644 --- a/crates/tanstack-ai/tests/chat_tests.rs +++ b/crates/tanstack-ai/tests/chat_tests.rs @@ -249,7 +249,10 @@ async fn test_server_tool_execution() { let has_tool_result = second_call_messages .iter() .any(|m| m.role == MessageRole::Tool); - assert!(has_tool_result, "Expected tool result message in second call"); + assert!( + has_tool_result, + "Expected tool result message in second call" + ); // Should have TOOL_CALL_END chunks match result { diff --git a/crates/tanstack-ai/tests/stream_processor_tests.rs b/crates/tanstack-ai/tests/stream_processor_tests.rs index 5dd022be..42340548 100644 --- a/crates/tanstack-ai/tests/stream_processor_tests.rs +++ b/crates/tanstack-ai/tests/stream_processor_tests.rs @@ -126,7 +126,7 @@ fn test_processor_with_batch_strategy() { role: "assistant".to_string(), model: None, }); - assert!(r1.is_some()); // Start always passes through + assert_eq!(r1.len(), 1); // Start always passes through let r2 = processor.process_chunk(StreamChunk::TextMessageContent { timestamp: now(), @@ -135,7 +135,7 @@ fn test_processor_with_batch_strategy() { content: None, model: None, }); - assert!(r2.is_none()); // Batch: not enough chunks yet + assert!(r2.is_empty()); // Batch: not enough chunks yet let r3 = processor.process_chunk(StreamChunk::TextMessageContent { timestamp: now(), @@ -144,7 +144,25 @@ fn test_processor_with_batch_strategy() { content: None, model: None, }); - assert!(r3.is_some()); // Batch: reached 2 chunks + assert_eq!(r3.len(), 1); // Batch: reached 2 chunks + + let r4 = 
processor.process_chunk(StreamChunk::TextMessageContent { + timestamp: now(), + message_id: "msg-1".to_string(), + delta: "c".to_string(), + content: None, + model: None, + }); + assert!(r4.is_empty()); // Final chunk is still buffered + + let r5 = processor.process_chunk(StreamChunk::TextMessageEnd { + timestamp: now(), + message_id: "msg-1".to_string(), + model: None, + }); + assert_eq!(r5.len(), 2); // Flush buffered text, then emit end + assert!(matches!(r5[0], StreamChunk::TextMessageContent { .. })); + assert!(matches!(r5[1], StreamChunk::TextMessageEnd { .. })); } #[test] diff --git a/crates/tanstack-ai/tests/test_utils.rs b/crates/tanstack-ai/tests/test_utils.rs index 3b24cdc4..5d85d019 100644 --- a/crates/tanstack-ai/tests/test_utils.rs +++ b/crates/tanstack-ai/tests/test_utils.rs @@ -102,10 +102,7 @@ pub fn _tool_end( } } -pub fn run_finished( - finish_reason: &str, - run_id: &str, -) -> StreamChunk { +pub fn run_finished(finish_reason: &str, run_id: &str) -> StreamChunk { StreamChunk::RunFinished { timestamp: now(), run_id: run_id.to_string(), @@ -228,7 +225,9 @@ pub fn collect_text(chunks: &[StreamChunk]) -> String { let mut content = String::new(); for chunk in chunks { if let StreamChunk::TextMessageContent { - delta, content: full, .. + delta, + content: full, + .. } = chunk { if let Some(full_content) = full { @@ -260,10 +259,7 @@ pub fn count_chunk_type(chunks: &[StreamChunk], chunk_type: &str) -> usize { } /// Helper to create a simple server tool for testing. 
-pub fn server_tool( - name: &str, - result: serde_json::Value, -) -> Tool { +pub fn server_tool(name: &str, result: serde_json::Value) -> Tool { let result_clone = result.clone(); Tool::new(name, format!("Test tool: {}", name)).with_execute( move |_args: serde_json::Value, _ctx: tanstack_ai::types::ToolExecutionContext| { diff --git a/crates/tanstack-ai/tests/tool_call_manager_tests.rs b/crates/tanstack-ai/tests/tool_call_manager_tests.rs index 72f7f0a9..68e3ca60 100644 --- a/crates/tanstack-ai/tests/tool_call_manager_tests.rs +++ b/crates/tanstack-ai/tests/tool_call_manager_tests.rs @@ -167,4 +167,6 @@ fn test_manager_parallel_tool_calls() { let calls = manager.tool_calls(); assert_eq!(calls.len(), 2); + assert_eq!(calls[0].id, "call_1"); + assert_eq!(calls[1].id, "call_2"); } diff --git a/crates/tanstack-ai/tests/unit_tests.rs b/crates/tanstack-ai/tests/unit_tests.rs index 52bcfece..28f380fc 100644 --- a/crates/tanstack-ai/tests/unit_tests.rs +++ b/crates/tanstack-ai/tests/unit_tests.rs @@ -2,6 +2,7 @@ //! //! Ports the TypeScript test suite to Rust with full feature parity. 
+use std::sync::Arc; use tanstack_ai::*; // ============================================================================ @@ -240,7 +241,7 @@ fn test_until_finish_reason_continues_on_no_match() { fn test_combine_strategies_all_true() { let strategy = combine_strategies(vec![ max_iterations(5), - Box::new(|state: &AgentLoopState| state.iteration_count < 10), + Arc::new(|state: &AgentLoopState| state.iteration_count < 10), ]); assert!(strategy(&make_state(2, None))); } @@ -249,11 +250,23 @@ fn test_combine_strategies_all_true() { fn test_combine_strategies_any_false() { let strategy = combine_strategies(vec![ max_iterations(5), - Box::new(|state: &AgentLoopState| state.iteration_count < 10), + Arc::new(|state: &AgentLoopState| state.iteration_count < 10), ]); assert!(!strategy(&make_state(5, None))); } +#[test] +fn test_text_options_clone_preserves_agent_loop_strategy() { + let options = TextOptions { + agent_loop_strategy: Some(max_iterations(2)), + ..Default::default() + }; + + let cloned = options.clone(); + assert!(cloned.agent_loop_strategy.is_some()); + assert!((cloned.agent_loop_strategy.unwrap())(&make_state(1, None))); +} + #[test] fn test_combine_strategies_empty() { let strategy = combine_strategies(vec![]); @@ -404,6 +417,25 @@ fn test_ui_to_model_multiple_text_parts() { } } +#[test] +fn test_ui_to_model_thinking_part() { + let ui_msg = UiMessage { + id: "msg-1".to_string(), + role: UiMessageRole::Assistant, + parts: vec![MessagePart::Thinking { + content: "reasoning".to_string(), + }], + created_at: None, + }; + + let model_msgs = ui_message_to_model_messages(&ui_msg); + assert_eq!(model_msgs.len(), 1); + assert_eq!( + model_msgs[0].content.as_str(), + Some("[thinking]reasoning[/thinking]"), + ); +} + #[test] fn test_ui_to_model_multimodal_image() { let ui_msg = UiMessage { @@ -717,7 +749,10 @@ fn test_model_message_with_tool_calls() { let deserialized: ModelMessage = serde_json::from_str(&json).unwrap(); assert_eq!(deserialized.role, 
MessageRole::Assistant); assert!(deserialized.tool_calls.is_some()); - assert_eq!(deserialized.tool_calls.as_ref().unwrap()[0].function.name, "getWeather"); + assert_eq!( + deserialized.tool_calls.as_ref().unwrap()[0].function.name, + "getWeather" + ); } #[test] @@ -726,5 +761,27 @@ fn test_detect_image_mime_type() { assert_eq!(tanstack_ai::detect_image_mime_type(png_header), "image/png"); let jpeg_header = &[0xFF, 0xD8, 0xFF, 0xE0]; - assert_eq!(tanstack_ai::detect_image_mime_type(jpeg_header), "image/jpeg"); + assert_eq!( + tanstack_ai::detect_image_mime_type(jpeg_header), + "image/jpeg" + ); +} + +#[test] +fn test_json_schema_serializes_ref_and_defs_keys() { + let schema = JsonSchema { + r#ref: Some("#/$defs/thing".to_string()), + defs: Some(std::collections::HashMap::from([( + "thing".to_string(), + JsonSchema { + r#type: Some(serde_json::json!("string")), + ..Default::default() + }, + )])), + ..Default::default() + }; + + let json = serde_json::to_value(schema).unwrap(); + assert_eq!(json["$ref"], "#/$defs/thing"); + assert!(json.get("$defs").is_some()); } diff --git a/examples/rs-chat/Cargo.lock b/examples/rs-chat/Cargo.lock index 1c4bec63..a0b4098c 100644 --- a/examples/rs-chat/Cargo.lock +++ b/examples/rs-chat/Cargo.lock @@ -1235,6 +1235,7 @@ dependencies = [ "tokio", "tokio-stream", "tokio-util", + "tracing", "uuid", ] @@ -1405,9 +1406,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ "pin-project-lite", + "tracing-attributes", "tracing-core", ] +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tracing-core" version = "0.1.36" diff --git a/examples/rs-chat/src/main.rs b/examples/rs-chat/src/main.rs index 
9ebd57d7..7356d1d4 100644 --- a/examples/rs-chat/src/main.rs +++ b/examples/rs-chat/src/main.rs @@ -199,7 +199,9 @@ fn extract_text(chunks: &[StreamChunk]) -> String { let mut content = String::new(); for chunk in chunks { if let StreamChunk::TextMessageContent { - delta, content: full, .. + delta, + content: full, + .. } = chunk { if let Some(f) = full {