diff --git a/.claude/rules/github-workflows.md b/.claude/rules/github-workflows.md
new file mode 100644
index 00000000..68d68c71
--- /dev/null
+++ b/.claude/rules/github-workflows.md
@@ -0,0 +1,11 @@
+---
+paths:
+  - ".github/**/*"
+---
+
+# GitHub Workflows Standards
+
+- Always use Makefile targets in the workflow to avoid code duplication.
+- Never add tests that use LLMs to GitHub workflows, because the default GitHub runner does not have the capacity to run them.
+- Only add unit tests to GitHub workflows.
+- Keep each GitHub workflow responsible for a single concern. For example, run the linter and the tests as separate workflows, in parallel.
diff --git a/.claude/rules/makefile.md b/.claude/rules/makefile.md
new file mode 100644
index 00000000..35407319
--- /dev/null
+++ b/.claude/rules/makefile.md
@@ -0,0 +1,13 @@
+---
+paths:
+  - "**/Makefile"
+---
+
+# Makefile Standards
+
+- Keep variables at the top of the file. Always.
+- Prefer real targets over phony targets. If something can be expressed as a real target, do that.
+- If you see that a phony target can be expressed as a real target, you can suggest a fix.
+- Group real targets together and phony targets together. Keep targets alphabetically sorted within each group.
+- Keep all real targets above the phony targets.
+- Make sure each Makefile target has enough dependencies to be able to run from a clean state.
diff --git a/.claude/rules/python-on-nixos.md b/.claude/rules/python-on-nixos.md
new file mode 100644
index 00000000..04329068
--- /dev/null
+++ b/.claude/rules/python-on-nixos.md
@@ -0,0 +1,19 @@
+---
+paths:
+  - "paddler_client_python/**/*"
+---
+
+# Running Python tooling on NixOS
+
+To run any Python tool that may have ELF / dynamic-linker issues on NixOS — `ruff`, `mypy`, `pyright`, `pytest`, anything installed from a pip wheel with native
+bits — first enter `paddler_client_python/shell.nix`, then drive everything through `poetry` from inside that shell.
+
+**Why:**
+pip wheels like `ruff` ship a generic-linux binary that expects a dynamic loader NixOS does not provide.
+Running them directly fails with `Could not start dynamically linked executable: ... NixOS cannot run dynamically linked executables intended for generic linux environments`.
+`shell.nix` provides the Nix-built loader / replacement tools that make those binaries (or their Nix equivalents) actually launch.
+`poetry` is just the dispatcher you use *inside* that prepared shell — never the entry point on its own.
+
+**How to apply:**
+- Never invoke `ruff`, `poetry run ...`, `python`, `pytest`, etc. from outside `nix-shell`. If a command starts with one of those, it must be inside `nix-shell --run "..."`.
+- If `paddler_client_python/shell.nix` is missing, stop and ask. Adding a `shell.nix` is the fix; running tooling unwrapped is not.
diff --git a/.claude/rules/rust-integration-tests.md b/.claude/rules/rust-integration-tests.md
new file mode 100644
index 00000000..53abd1c1
--- /dev/null
+++ b/.claude/rules/rust-integration-tests.md
@@ -0,0 +1,15 @@
+---
+paths:
+  - "**/tests/**/*.rs"
+---
+
+# Rust Integration Tests Standards
+
+- Each test must be named after the functionality or issue it actually tests.
+- Each test file must be named after the functionality or issue it actually tests.
+- Each test represents a specific scenario that the core project needs to support, or reproduces an uncovered issue.
+- If you uncover a new issue while testing, create another targeted test that covers it.
+- Every test must use production code. Never recreate the original code to test something conceptually.
+- Tests must be single-purpose.
+- It must be clear from the filename alone what a test file is testing.
+
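A minimal sketch of what these integration-test standards imply in practice; the crate path and API below are hypothetical placeholders, not names taken from the repository:

```rust
// tests/rejects_request_when_no_slots_are_available.rs
//
// Single-purpose: one scenario, named after the behavior it verifies.
// It calls the production API; nothing under test is reimplemented here.

// Hypothetical production entry point; adjust to the real crate paths.
use paddler::balancer::Balancer;

#[test]
fn rejects_request_when_no_slots_are_available() {
    let balancer = Balancer::default();

    // With zero registered slots, routing a request must fail loudly
    // rather than queue forever.
    assert!(balancer.route_request("completion").is_err());
}
```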
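Likewise, a Makefile shaped by the standards in `makefile.md` above might look like this sketch (target names are borrowed loosely from this repository's Makefile and are illustrative; recipes are tab-indented as `make` requires):

```make
# Variables stay at the top of the file, always.
CARGO ?= cargo

# Real targets first, alphabetically sorted, each with enough
# dependencies to run from a clean checkout.
esbuild-meta.json: node_modules
	npm run build

node_modules: package.json
	npm install

# Phony targets grouped together below the real ones,
# also alphabetically sorted within the group.
.PHONY: test.unit
test.unit: esbuild-meta.json
	$(CARGO) test --features web_admin_panel
```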
diff --git a/.claude/rules/rust.md b/.claude/rules/rust.md
index be678983..4584b330 100644
--- a/.claude/rules/rust.md
+++ b/.claude/rules/rust.md
@@ -17,3 +17,20 @@ paths:
 - In Rust, when implementing a `new` method in a struct, prefer to use a struct with a parameter list instead of multiple function arguments. It should be easier to maintain.
 - Always check the project with Clippy.
 - Always format the code with `cargo fmt`.
+- Each file must contain at most a single struct or a single enum. For readability, split multiple types into separate modules. You can still keep multiple private function helpers.
+- Never use `Result` as a function argument.
+- Never forward `Result` in enums if you can instead create a targeted error enum. It is always better to signal the specific issue, so it can be handled downstream.
+- Always destructure structs in arguments if possible.
+
+# Code Style
+
+Imports/uses must not be mixed with other kinds of Rust syntax.
+
+Each file needs to follow this order:
+1. `pub mod`/`mod` exports
+2. vendor crate `use`
+3. project crate `use`
+4. local crate `use`
+5. private function helpers
+6. private struct helpers
+7. single public export
diff --git a/.claude/rules/testing.md b/.claude/rules/testing.md
index d24bd31f..c60bf4b2 100644
--- a/.claude/rules/testing.md
+++ b/.claude/rules/testing.md
@@ -6,3 +6,9 @@
 - If some piece of code can be handled by proper types, use types instead. Write tests as a last resort.
 - In unit tests, make sure there is always just a single correct way to do a specific thing. Never accept fuzzy inputs from end users.
 - When working on tests, if you notice that the tested code can be better, you can suggest changes.
+- Maintain 100% test coverage across the codebase. No file, branch, or line may be excluded from coverage reports.
+- Reach 100% coverage with the minimum number of tests. Each test must cover a unique code path, behavior, or edge case that no other test already covers.
+- If two tests cover overlapping paths, remove the weaker one. Redundant tests waste maintenance effort without improving correctness signal.
+- Tests must exercise actual functionality and observable behavior. Never write a test purely to hit lines for the sake of coverage.
+- Design tests deliberately before writing them. Identify the feature or branch under test, then write the smallest test that verifies it.
+- Coverage gaps signal missing tests, never permission to exclude files. Write the test instead of suppressing the gap.
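The file order prescribed in `rust.md` above, instantiated as a skeleton; the module, type, and helper names here are hypothetical, not from the codebase:

```rust
// src/slot_registry.rs -- a hypothetical file laid out in the prescribed order.

// 1. `pub mod`/`mod` exports
pub mod slot_status;

// 2. vendor crate `use`
use log::debug;

// 3. project crate `use` (the type is a made-up example)
use paddler_types::SlotId;

// 4. local crate `use`
use crate::slot_registry::slot_status::SlotStatus;

// 5. private function helpers
fn log_transition(slot_id: &SlotId, status: &SlotStatus) {
    debug!("slot {slot_id:?} is now {status:?}");
}

// 6. private struct helpers -- none needed in this file.

// 7. single public export
pub struct SlotRegistry {
    pub slots: Vec<(SlotId, SlotStatus)>,
}

impl SlotRegistry {
    pub fn set_status(&mut self, slot_id: SlotId, status: SlotStatus) {
        log_transition(&slot_id, &status);
        self.slots.push((slot_id, status));
    }
}
```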
diff --git a/.claude/skills/running-all-tests/SKILL.md b/.claude/skills/running-all-tests/SKILL.md
new file mode 100644
index 00000000..7af693e7
--- /dev/null
+++ b/.claude/skills/running-all-tests/SKILL.md
@@ -0,0 +1,56 @@
+---
+name: running-all-tests
+description: Runs every test suite in the paddler workspace on the fastest available device. Use when the user asks to run the tests, run all the tests, run the full test suite, or check that everything still passes.
+---
+
+# Running all tests
+
+Run every test suite in the workspace, picking the fastest compiled device backend for the host.
+
+## Step 1: detect the device
+
+Run this once at the start and echo the chosen device:
+
+```bash
+if [[ "$OSTYPE" == "darwin"* ]]; then
+  DEVICE=metal
+elif command -v nvidia-smi >/dev/null 2>&1 && nvidia-smi >/dev/null 2>&1; then
+  DEVICE=cuda
+else
+  DEVICE=cpu
+fi
+echo "Device: $DEVICE"
+```
+
+`$DEVICE` selects the Rust integration suite variant in Step 2. The other four suites don't take a device feature.
+
+## Step 2: run the five suites
+
+Copy this checklist and tick each item as the suite completes:
+
+```
+- [ ] JS client
+- [ ] Python client
+- [ ] Rust unit
+- [ ] Rust integration
+- [ ] paddler_gui
+```
+
+| # | Suite            | Inner command                                                                                                     | Working dir              |
+|---|------------------|-------------------------------------------------------------------------------------------------------------------|--------------------------|
+| 1 | JS client        | `make test.client.js`                                                                                               | repo root                |
+| 2 | Python client    | `poetry run pytest`, `poetry run ruff`, `poetry run mypy`; on NixOS, wrap each in `nix-shell --run "..."`           | `paddler_client_python/` |
+| 3 | Rust unit        | `make test.unit`                                                                                                    | repo root                |
+| 4 | Rust integration | `make test.integration` (cpu) / `make test.integration.cuda` / `make test.integration.metal` — pick by `$DEVICE`    | repo root                |
+| 5 | paddler_gui      | `cargo test -p paddler_gui --features web_admin_panel`                                                              | `paddler_gui/`           |
+
+Run them in this order. The cheap suites (1–3) surface bugs quickly; the heavier, device-bound suites (4, 5) come last.
+
+## Step 3: rules during the run
+
+- **Serialize GPU suites.** When `$DEVICE` is `cuda` or `metal`, run test suites sequentially to avoid device contention.
+- **Per-test 30 s budget.** Flag any individual test that exceeds 30 s wall-clock. That is a real bug — production or test — not flakiness.
+
+## Step 4: report
+
+After all suites finish, summarize the results in an actionable report.
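For suite 2 on NixOS, the table's inner commands combine with the python-on-nixos rule above into roughly this sketch; the `nix-shell --run "..."` wrapper is what that rule prescribes, while the exact subcommands and flags are illustrative and may need adjusting:

```bash
# Hypothetical sketch: run the Python client suite on NixOS.
# Every tool is dispatched through poetry from inside nix-shell,
# per .claude/rules/python-on-nixos.md (shell.nix is picked up
# from the current directory).
cd paddler_client_python
nix-shell --run "poetry run pytest"
nix-shell --run "poetry run ruff check"
nix-shell --run "poetry run mypy ."
```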
diff --git a/.gitignore b/.gitignore index 69980977..cb516b8c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,5 @@ .DS_Store /*.db -/.tsimp /esbuild-meta.json /godepgraph.png /models diff --git a/Cargo.lock b/Cargo.lock index ef76369c..62046878 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -121,7 +121,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e01ed3140b2f8d422c68afa1ed2e85d996ea619c988ac834d255db32138655cb" dependencies = [ "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -239,7 +239,7 @@ dependencies = [ "actix-router", "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -287,7 +287,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2ac4d6e04e97fe707286509b4f338e99c5fb7249c770e1da074af5e27faa96b3" dependencies = [ "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -312,7 +312,7 @@ checksum = "b6ac1e58cded18cb28ddc17143c4dea5345b3ad575e14f32f66e4054a56eb271" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -505,7 +505,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -571,7 +571,7 @@ dependencies = [ "rustc-hash 2.1.2", "serde", "serde_derive", - "syn", + "syn 2.0.117", ] [[package]] @@ -679,7 +679,7 @@ checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -719,7 +719,7 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -736,7 +736,16 @@ checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", +] + +[[package]] +name = "atomic" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89cbf775b137e9b968e67227ef7f775587cde3fd31b0d8599dbd0f598a48340" +dependencies = [ + "bytemuck", ] [[package]] @@ -826,7 +835,16 @@ dependencies = [ "regex", "rustc-hash 2.1.2", "shlex", - "syn", + "syn 2.0.117", +] + +[[package]] +name = "bit-set" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +dependencies = [ + "bit-vec 0.6.3", ] [[package]] @@ -835,9 +853,15 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" dependencies = [ - "bit-vec", + "bit-vec 0.8.0", ] +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + [[package]] name = "bit-vec" version = "0.8.0" @@ -988,7 +1012,7 @@ checksum = "f9abbd1bc6865053c427f7198e6af43bfdedc55ab791faed4fbd361d789575ff" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -1088,11 +1112,20 @@ dependencies = [ "wayland-client", ] +[[package]] +name = "castaway" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a" +dependencies = [ + "rustversion", +] + [[package]] name = "cc" -version = "1.2.59" +version = "1.2.58" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "b7a4d3ec6524d28a329fc53654bbadc9bdd7b0431f5d65f1a56ffb28a1ee5283" +checksum = "e1e928d4b69e3077709075a938a05ffbedfa53a84c8f766efbf8220bb1ff60e1" dependencies = [ "find-msvc-tools", "jobserver", @@ -1180,7 +1213,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -1270,6 +1303,20 @@ dependencies = [ "memchr", ] +[[package]] +name = "compact_str" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a" +dependencies = [ + "castaway", + "cfg-if", + "itoa", + "rustversion", + "ryu", + "static_assertions", +] + [[package]] name = "concurrent-queue" version = "2.5.0" @@ -1482,6 +1529,34 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" +[[package]] +name = "crossterm" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8b9f2e4c67f833b660cdb0a3523065869fb35570177239812ed4c905aeff87b" +dependencies = [ + "bitflags 2.11.0", + "crossterm_winapi", + "derive_more", + "document-features", + "futures-core", + "mio", + "parking_lot", + "rustix 1.1.4", + "signal-hook", + "signal-hook-mio", + "winapi", +] + +[[package]] +name = "crossterm_winapi" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" +dependencies = [ + "winapi", +] + [[package]] name = "crunchy" version = "0.2.4" @@ -1520,6 +1595,16 @@ dependencies = [ "hybrid-array", ] +[[package]] +name = "csscolorparser" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb2a7d3066da2de787b7f032c736763eb7ae5d355f81a68bab2675a96008b0bf" +dependencies = [ + "lab", + "phf", +] + [[package]] name = "csv" version = "1.4.0" @@ -1553,6 +1638,40 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f27ae1dd37df86211c42e150270f82743308803d90a6f6e6651cd730d5e1732f" +[[package]] +name = "darling" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25ae13da2f202d56bd7f91c25fba009e7717a1e4a1cc98a76d844b65ae912e9d" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9865a50f7c335f53564bb694ef660825eb8610e0a53d3e11bf1b0d3df31e03b0" +dependencies = [ + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.117", +] + +[[package]] +name = "darling_macro" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3984ec7bd6cfa798e62b4a642426a5be0e68f9401cfc2a01e3fa9ea2fcdb8d" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.117", +] + [[package]] name = "dashmap" version = "6.1.0" @@ -1579,6 +1698,12 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be1e0bca6c3637f992fc1cc7cbc52a78c1ef6db076dbf1059c4323d6a2048376" +[[package]] +name = "deltae" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5729f5117e208430e437df2f4843f5e5952997175992d1414f94c57d61e270b4" + [[package]] name = "deranged" version = "0.5.8" @@ -1608,10 +1733,25 @@ dependencies = [ "proc-macro2", "quote", "rustc_version", - "syn", + "syn 2.0.117", 
"unicode-xid", ] +[[package]] +name = "derivre" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "786c7c65c4ef0c7deb05de3005e01991612a8f09fe0844fc0969c68b90468ba8" +dependencies = [ + "ahash", + "anyhow", + "bytemuck", + "bytemuck_derive", + "hashbrown 0.15.5", + "regex-syntax", + "strum", +] + [[package]] name = "diff" version = "0.1.13" @@ -1684,7 +1824,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -1772,7 +1912,7 @@ checksum = "67c78a4d8fdf9953a5c9d458f9efe940fd97a0cab0941c075a813ac594733827" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -1815,7 +1955,7 @@ checksum = "44f23cf4b44bfce11a86ace86f8a73ffdec849c9fd00a386a53d278bd9e81fb3" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -1912,13 +2052,23 @@ dependencies = [ "zune-inflate", ] +[[package]] +name = "fancy-regex" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2" +dependencies = [ + "bit-set 0.5.3", + "regex", +] + [[package]] name = "fancy-regex" version = "0.16.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "998b056554fbe42e03ae0e152895cd1a7e1002aec800fdc6635d20270260c46f" dependencies = [ - "bit-set", + "bit-set 0.8.0", "regex-automata", "regex-syntax", ] @@ -1946,7 +2096,7 @@ checksum = "a0aca10fb742cb43f9e7bb8467c91aa9bcb8e3ffbc6a6f7389bb93ffc920577d" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -1958,6 +2108,17 @@ dependencies = [ "simd-adler32", ] +[[package]] +name = "filedescriptor" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e40758ed24c9b2eeb76c35fb0aebc66c626084edd827e07e1552279814c6682d" +dependencies = [ + "libc", + "thiserror 1.0.69", + "winapi", +] + [[package]] name = "find-msvc-tools" version = "0.1.9" @@ -1973,6 +2134,18 @@ dependencies = [ "glob", ] +[[package]] +name = "finl_unicode" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9844ddc3a6e533d62bba727eb6c28b5d360921d5175e9ff0f1e621a5c590a4d5" + +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + [[package]] name = "flate2" version = "1.1.9" @@ -1989,6 +2162,17 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "98de4bbd547a563b716d8dfa9aad1cb19bfab00f4fa09a6a4ed21dbcf44ce9c4" +[[package]] +name = "fluent-uri" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1918b65d96df47d3591bed19c5cca17e3fa5d0707318e4b5ef2eae01764df7e5" +dependencies = [ + "borrow-or-share", + "ref-cast", + "serde", +] + [[package]] name = "fluent-uri" version = "0.4.1" @@ -2086,7 +2270,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -2205,7 +2389,7 @@ checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -2980,6 +3164,12 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" 
checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "1.1.0" @@ -3109,6 +3299,19 @@ dependencies = [ "rustversion", ] +[[package]] +name = "instability" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5eb2d60ef19920a3a9193c3e371f726ec1dafc045dac788d0fb3704272458971" +dependencies = [ + "darling", + "indoc", + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "interpolate_name" version = "0.2.4" @@ -3117,7 +3320,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -3215,7 +3418,7 @@ checksum = "2a8c8b344124222efd714b73bb41f8b5120b27a7cc1c75593a6ff768d9d05aa4" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -3261,7 +3464,7 @@ dependencies = [ "quote", "rustc_version", "simd_cesu8", - "syn", + "syn 2.0.117", ] [[package]] @@ -3289,7 +3492,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264" dependencies = [ "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -3324,7 +3527,7 @@ dependencies = [ "bytecount", "data-encoding", "email_address", - "fancy-regex", + "fancy-regex 0.16.2", "fraction", "getrandom 0.3.4", "idna", @@ -3332,7 +3535,7 @@ dependencies = [ "num-cmp", "num-traits", "percent-encoding", - "referencing", + "referencing 0.37.4", "regex", "regex-syntax", "serde", @@ -3350,6 +3553,17 @@ dependencies = [ "mutate_once", ] +[[package]] +name = "kasuari" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bde5057d6143cc94e861d90f591b9303d6716c6b9602309150bd068853c10899" +dependencies = [ + "hashbrown 0.16.1", + "portable-atomic", + "thiserror 2.0.18", +] + [[package]] name = "khronos-egl" version = "6.0.0" @@ -3399,6 +3613,12 @@ dependencies = [ "smallvec", ] +[[package]] +name = "lab" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf36173d4167ed999940f804952e6b08197cae5ad5d572eb4db150ce8ad5d58f" + [[package]] name = "language-tags" version = "0.3.2" @@ -3476,6 +3696,15 @@ dependencies = [ "web-time", ] +[[package]] +name = "line-clipping" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f50e8f47623268b5407192d26876c4d7f89d686ca130fdc53bced4814cd29f8" +dependencies = [ + "bitflags 2.11.0", +] + [[package]] name = "linebender_resource_handle" version = "0.1.1" @@ -3499,7 +3728,7 @@ checksum = "e5cec0ec4228b4853bb129c84dbf093a27e6c7a20526da046defc334a1b017f7" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -3528,41 +3757,75 @@ checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092" [[package]] name = "llama-cpp-bindings" -version = "0.4.2" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9627055cad4854d59fcc449c9ddf3f93f96951cb1e120e6dc4c9a2af51902862" +checksum = "5c12d8de3511f1c3e3025e811ad2644d22d5b6657f18e9496ea3977c456eed8a" dependencies = [ "encoding_rs", "enumflags2", "llama-cpp-bindings-sys", + "llama-cpp-bindings-types", + "llguidance", + "nom 8.0.0", + "serde_json", 
"thiserror 2.0.18", + "toktrie", "tracing", "tracing-core", ] [[package]] name = "llama-cpp-bindings-build" -version = "0.4.2" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d19130eaca34f6578ee105766f08fd004f1d2c1e187cb75cf51a444852dd4aa" +checksum = "352fc011d0d723af3864d500d3a78b2d5ee0a5200993229f93421984d9abbfe3" dependencies = [ "bindgen", "cc", "cmake", "find_cuda_helper", "glob", + "thiserror 2.0.18", "walkdir", ] [[package]] name = "llama-cpp-bindings-sys" -version = "0.4.2" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15bf30008f4d6624d200b3fc5ea953e7c08410b6d14dabecc6e4b288f20f9f05" +checksum = "ac6ce8ade04ae8cacd4ce4e627786c0a449c54d87b5e7287e936ab13fd8ceca8" dependencies = [ "llama-cpp-bindings-build", ] +[[package]] +name = "llama-cpp-bindings-types" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a76906d544513079d6dbd299d9f0469f618077153aba8eeaf950727a06b45aac" +dependencies = [ + "serde", + "serde_json", + "thiserror 2.0.18", +] + +[[package]] +name = "llguidance" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "586ccb282d83fe337ddad2e28bb11bb258a462caa80e117fc76a6c73dac919e4" +dependencies = [ + "anyhow", + "derivre", + "indexmap", + "rayon", + "referencing 0.29.1", + "regex-syntax", + "serde", + "serde_json", + "toktrie", +] + [[package]] name = "local-channel" version = "0.1.5" @@ -3609,6 +3872,19 @@ name = "lru" version = "0.16.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f66e8d5d03f609abc3a39e6f08e4164ebf1447a732906d39eb9b99b7919ef39" +dependencies = [ + "hashbrown 0.16.1", +] + +[[package]] +name = "mac_address" +version = "1.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0aeb26bf5e836cc1c341c8106051b573f1766dfa05aa87f0b98be5e51b02303" +dependencies = [ + "nix 0.29.0", + "winapi", +] [[package]] name = "macro_registry" @@ -3617,7 +3893,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28c03fc749d06e1000766283015673e91aea121f30c60b7445681f2248e4994c" dependencies = [ "module_path_extractor", - "syn", + "syn 2.0.117", ] [[package]] @@ -3654,6 +3930,12 @@ dependencies = [ "libc", ] +[[package]] +name = "memmem" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a64a92489e2744ce060c349162be1c5f33c6969234104dbd99ddb5feb08b8c15" + [[package]] name = "memo-map" version = "0.3.3" @@ -3812,7 +4094,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "066cf25f0e8b11ee0df221219010f213ad429855f57c494f995590c861a9a7d8" dependencies = [ "arrayvec", - "bit-set", + "bit-set 0.8.0", "bitflags 2.11.0", "cfg-if", "cfg_aliases", @@ -3893,6 +4175,19 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" +[[package]] +name = "nix" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +dependencies = [ + "bitflags 2.11.0", + "cfg-if", + "cfg_aliases", + "libc", + "memoffset", +] + [[package]] name = "nix" version = "0.30.1" @@ -3983,7 +4278,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] 
@@ -4056,7 +4351,16 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn", + "syn 2.0.117", +] + +[[package]] +name = "num_threads" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c7398b9c8b70908f6371f47ed36737907c87c52af34c268fed0bf0ceb92ead9" +dependencies = [ + "libc", ] [[package]] @@ -4470,7 +4774,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -4507,6 +4811,15 @@ dependencies = [ "libredox", ] +[[package]] +name = "ordered-float" +version = "4.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bb71e1b3fa6ca1c61f383464aaf2bb0e2f8e772a1f01d486832464de363b951" +dependencies = [ + "num-traits", +] + [[package]] name = "ordered-float" version = "5.3.0" @@ -4588,6 +4901,7 @@ dependencies = [ "thiserror 2.0.18", "tokio", "tokio-stream", + "tokio-test", "tokio-tungstenite", "tokio-util", "url", @@ -4648,6 +4962,29 @@ dependencies = [ "url", ] +[[package]] +name = "paddler_client_cli" +version = "3.1.2" +dependencies = [ + "anyhow", + "async-trait", + "clap", + "crossterm", + "env_logger", + "futures-util", + "llama-cpp-bindings-types", + "log", + "paddler_bootstrap", + "paddler_client", + "paddler_types", + "ratatui", + "reqwest", + "serde_json", + "tokio", + "tokio-util", + "url", +] + [[package]] name = "paddler_gui" version = "3.1.2" @@ -4678,8 +5015,10 @@ dependencies = [ "async-stream", "base64", "futures-util", + "hf-hub", + "llama-cpp-bindings", "log", - "nix", + "nix 0.30.1", "paddler", "paddler_bootstrap", "paddler_client", @@ -4701,8 +5040,10 @@ version = "3.1.2" dependencies = [ "anyhow", "jsonschema", + "llama-cpp-bindings-types", "serde", "serde_json", + "thiserror 2.0.18", ] [[package]] @@ -4758,6 +5099,101 @@ version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" +[[package]] +name = "pest" +version = "2.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0848c601009d37dfa3430c4666e147e49cdcf1b92ecd3e63657d8a5f19da662" +dependencies = [ + "memchr", + "ucd-trie", +] + +[[package]] +name = "pest_derive" +version = "2.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11f486f1ea21e6c10ed15d5a7c77165d0ee443402f0780849d1768e7d9d6fe77" +dependencies = [ + "pest", + "pest_generator", +] + +[[package]] +name = "pest_generator" +version = "2.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8040c4647b13b210a963c1ed407c1ff4fdfa01c31d6d2a098218702e6664f94f" +dependencies = [ + "pest", + "pest_meta", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "pest_meta" +version = "2.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89815c69d36021a140146f26659a81d6c2afa33d216d736dd4be5381a7362220" +dependencies = [ + "pest", + "sha2", +] + +[[package]] +name = "phf" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" +dependencies = [ + "phf_macros", + "phf_shared", +] + +[[package]] +name = "phf_codegen" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a" +dependencies = [ + "phf_generator", + 
"phf_shared", +] + +[[package]] +name = "phf_generator" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" +dependencies = [ + "phf_shared", + "rand 0.8.6", +] + +[[package]] +name = "phf_macros" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f84ac04429c13a7ff43785d75ad27569f2951ce0ffd30a3321230db2fc727216" +dependencies = [ + "phf_generator", + "phf_shared", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "phf_shared" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" +dependencies = [ + "siphasher", +] + [[package]] name = "pico-args" version = "0.5.0" @@ -4781,7 +5217,7 @@ checksum = "d9b20ed30f105399776b9c883e68e536ef602a16ae6f596d2c473591d6ad64c6" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -4921,7 +5357,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", - "syn", + "syn 2.0.117", ] [[package]] @@ -4958,7 +5394,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "52717f9a02b6965224f95ca2a81e2e0c5c43baacd28ca057577988930b6c3d5b" dependencies = [ "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -5100,6 +5536,91 @@ version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "973443cf09a9c8656b574a866ab68dfa19f0867d0340648c7d2f6a71b8a8ea68" +[[package]] +name = "ratatui" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1ce67fb8ba4446454d1c8dbaeda0557ff5e94d39d5e5ed7f10a65eb4c8266bc" +dependencies = [ + "instability", + "ratatui-core", + "ratatui-crossterm", + "ratatui-macros", + "ratatui-termwiz", + "ratatui-widgets", +] + +[[package]] +name = "ratatui-core" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ef8dea09a92caaf73bff7adb70b76162e5937524058a7e5bff37869cbbec293" +dependencies = [ + "bitflags 2.11.0", + "compact_str", + "hashbrown 0.16.1", + "indoc", + "itertools 0.14.0", + "kasuari", + "lru", + "strum", + "thiserror 2.0.18", + "unicode-segmentation", + "unicode-truncate", + "unicode-width", +] + +[[package]] +name = "ratatui-crossterm" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "577c9b9f652b4c121fb25c6a391dd06406d3b092ba68827e6d2f09550edc54b3" +dependencies = [ + "cfg-if", + "crossterm", + "instability", + "ratatui-core", +] + +[[package]] +name = "ratatui-macros" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7f1342a13e83e4bb9d0b793d0ea762be633f9582048c892ae9041ef39c936f4" +dependencies = [ + "ratatui-core", + "ratatui-widgets", +] + +[[package]] +name = "ratatui-termwiz" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f76fe0bd0ed4295f0321b1676732e2454024c15a35d01904ddb315afd3d545c" +dependencies = [ + "ratatui-core", + "termwiz", +] + +[[package]] +name = "ratatui-widgets" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7dbfa023cd4e604c2553483820c5fe8aa9d71a42eea5aa77c6e7f35756612db" +dependencies = [ + "bitflags 2.11.0", + "hashbrown 0.16.1", + 
"indoc", + "instability", + "itertools 0.14.0", + "line-clipping", + "ratatui-core", + "strum", + "time", + "unicode-segmentation", + "unicode-width", +] + [[package]] name = "rav1e" version = "0.8.1" @@ -5252,7 +5773,21 @@ checksum = "b7186006dcb21920990093f30e3dea63b7d6e977bf1256be20c3563a5db070da" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", +] + +[[package]] +name = "referencing" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40a64b3a635fad9000648b4d8a59c8710c523ab61a23d392a7d91d47683f5adc" +dependencies = [ + "ahash", + "fluent-uri 0.3.2", + "once_cell", + "parking_lot", + "percent-encoding", + "serde_json", ] [[package]] @@ -5262,7 +5797,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4283168a506f0dcbdce31c9f9cce3129c924da4c6bca46e46707fcb746d2d70c" dependencies = [ "ahash", - "fluent-uri", + "fluent-uri 0.4.1", "getrandom 0.3.4", "hashbrown 0.16.1", "parking_lot", @@ -5447,7 +5982,7 @@ dependencies = [ "quote", "rust-embed-utils", "shellexpand", - "syn", + "syn 2.0.117", "walkdir", ] @@ -5693,7 +6228,7 @@ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -5715,6 +6250,7 @@ version = "1.0.149" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" dependencies = [ + "indexmap", "itoa", "memchr", "serde", @@ -5741,7 +6277,7 @@ checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -5780,7 +6316,7 @@ checksum = "0a7d91949b85b0d2fb687445e448b40d322b6b3e4af6b44a29b21d9a5f33e6d9" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -5831,6 +6367,27 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "signal-hook" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d881a16cf4426aa584979d30bd82cb33429027e42122b169753d6ef1085ed6e2" +dependencies = [ + "libc", + "signal-hook-registry", +] + +[[package]] +name = "signal-hook-mio" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b75a19a7a740b25bc7944bdee6172368f988763b744e3d4dfe753f6b4ece40cc" +dependencies = [ + "libc", + "mio", + "signal-hook", +] + [[package]] name = "signal-hook-registry" version = "1.4.8" @@ -6119,7 +6676,7 @@ dependencies = [ "moddef", "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -6137,6 +6694,27 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "strum" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7695ce3845ea4b33927c055a39dc438a45b059f7c1b3d91d38d10355fb8cbca7" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "subtle" version = "2.6.1" @@ -6180,6 +6758,17 @@ 
dependencies = [ "zeno", ] +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.117" @@ -6208,7 +6797,7 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -6263,6 +6852,69 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "terminfo" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4ea810f0692f9f51b382fff5893887bb4580f5fa246fde546e0b13e7fcee662" +dependencies = [ + "fnv", + "nom 7.1.3", + "phf", + "phf_codegen", +] + +[[package]] +name = "termios" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "411c5bf740737c7918b8b1fe232dca4dc9f8e754b8ad5e20966814001ed0ac6b" +dependencies = [ + "libc", +] + +[[package]] +name = "termwiz" +version = "0.23.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4676b37242ccbd1aabf56edb093a4827dc49086c0ffd764a5705899e0f35f8f7" +dependencies = [ + "anyhow", + "base64", + "bitflags 2.11.0", + "fancy-regex 0.11.0", + "filedescriptor", + "finl_unicode", + "fixedbitset", + "hex", + "lazy_static", + "libc", + "log", + "memmem", + "nix 0.29.0", + "num-derive", + "num-traits", + "ordered-float 4.6.0", + "pest", + "pest_derive", + "phf", + "sha2", + "signal-hook", + "siphasher", + "terminfo", + "termios", + "thiserror 1.0.69", + "ucd-trie", + "unicode-segmentation", + "vtparse", + "wezterm-bidi", + "wezterm-blob-leases", + "wezterm-color-types", + "wezterm-dynamic", + "wezterm-input-types", + "winapi", +] + [[package]] name = "textwrap" version = "0.16.2" @@ -6298,7 +6950,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -6309,7 +6961,7 @@ checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -6334,7 +6986,9 @@ checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" dependencies = [ "deranged", "itoa", + "libc", "num-conv", + "num_threads", "powerfmt", "serde_core", "time-core", @@ -6457,7 +7111,7 @@ checksum = "385a6cb71ab9ab790c5fe8d67f1645e6c450a7ce006a33de03daa956cf70a496" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -6492,6 +7146,17 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "tokio-test" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f6d24790a10a7af737693a3e8f1d03faef7e6ca0cc99aae5066f533766de545" +dependencies = [ + "futures-core", + "tokio", + "tokio-stream", +] + [[package]] name = "tokio-tungstenite" version = "0.28.0" @@ -6517,6 +7182,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "toktrie" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0310776df61c26b98e6ed7c3629de5131e0c8731439350b79152b7f52468229a" +dependencies = [ + "anyhow", + "bytemuck", + "bytemuck_derive", + "serde", + "serde_json", +] + [[package]] name = "toml_datetime" version = "1.1.1+spec-1.1.0" @@ -6612,7 +7290,7 @@ checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" 
dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -6663,6 +7341,12 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" +[[package]] +name = "ucd-trie" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" + [[package]] name = "uds_windows" version = "1.2.1" @@ -6734,6 +7418,17 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9629274872b2bfaf8d66f5f15725007f635594914870f65218920345aa11aa8c" +[[package]] +name = "unicode-truncate" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16b380a1238663e5f8a691f9039c73e1cdae598a30e9855f541d29b08b53e9a5" +dependencies = [ + "itertools 0.14.0", + "unicode-segmentation", + "unicode-width", +] + [[package]] name = "unicode-vo" version = "0.1.0" @@ -6875,6 +7570,8 @@ version = "1.23.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ddd74a9687298c6858e9b88ec8935ec45d22e8fd5e6394fa1bd4e99a87789c76" dependencies = [ + "atomic", + "getrandom 0.4.2", "js-sys", "serde_core", "wasm-bindgen", @@ -6925,6 +7622,15 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64" +[[package]] +name = "vtparse" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d9b2acfb050df409c972a37d3b8e08cdea3bddb0c09db9d53137e504cfabed0" +dependencies = [ + "utf8parse", +] + [[package]] name = "walkdir" version = "2.5.0" @@ -7010,7 +7716,7 @@ dependencies = [ "bumpalo", "proc-macro2", "quote", - "syn", + "syn 2.0.117", "wasm-bindgen-shared", ] @@ -7263,6 +7969,78 @@ version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a28ac98ddc8b9274cb41bb4d9d4d5c425b6020c50c46f25559911905610b4a88" +[[package]] +name = "wezterm-bidi" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0a6e355560527dd2d1cf7890652f4f09bb3433b6aadade4c9b5ed76de5f3ec" +dependencies = [ + "log", + "wezterm-dynamic", +] + +[[package]] +name = "wezterm-blob-leases" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "692daff6d93d94e29e4114544ef6d5c942a7ed998b37abdc19b17136ea428eb7" +dependencies = [ + "getrandom 0.3.4", + "mac_address", + "sha2", + "thiserror 1.0.69", + "uuid", +] + +[[package]] +name = "wezterm-color-types" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7de81ef35c9010270d63772bebef2f2d6d1f2d20a983d27505ac850b8c4b4296" +dependencies = [ + "csscolorparser", + "deltae", + "lazy_static", + "wezterm-dynamic", +] + +[[package]] +name = "wezterm-dynamic" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f2ab60e120fd6eaa68d9567f3226e876684639d22a4219b313ff69ec0ccd5ac" +dependencies = [ + "log", + "ordered-float 4.6.0", + "strsim", + "thiserror 1.0.69", + "wezterm-dynamic-derive", +] + +[[package]] +name = "wezterm-dynamic-derive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46c0cf2d539c645b448eaffec9ec494b8b19bd5077d9e58cb1ae7efece8d575b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + 
+[[package]] +name = "wezterm-input-types" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7012add459f951456ec9d6c7e6fc340b1ce15d6fc9629f8c42853412c029e57e" +dependencies = [ + "bitflags 1.3.2", + "euclid", + "lazy_static", + "serde", + "wezterm-dynamic", +] + [[package]] name = "wgpu" version = "27.0.1" @@ -7299,8 +8077,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "27a75de515543b1897b26119f93731b385a19aea165a1ec5f0e3acecc229cae7" dependencies = [ "arrayvec", - "bit-set", - "bit-vec", + "bit-set 0.8.0", + "bit-vec 0.8.0", "bitflags 2.11.0", "bytemuck", "cfg_aliases", @@ -7360,7 +8138,7 @@ dependencies = [ "android_system_properties", "arrayvec", "ash", - "bit-set", + "bit-set 0.8.0", "bitflags 2.11.0", "block", "bytemuck", @@ -7383,7 +8161,7 @@ dependencies = [ "ndk-sys", "objc", "once_cell", - "ordered-float", + "ordered-float 5.3.0", "parking_lot", "portable-atomic", "portable-atomic-util", @@ -7535,7 +8313,7 @@ checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -7546,7 +8324,7 @@ checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -7557,7 +8335,7 @@ checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -7568,7 +8346,7 @@ checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -7975,7 +8753,7 @@ dependencies = [ "heck", "indexmap", "prettyplease", - "syn", + "syn 2.0.117", "wasm-metadata", "wit-bindgen-core", "wit-component", @@ -7991,7 +8769,7 @@ dependencies = [ "prettyplease", "proc-macro2", "quote", - "syn", + "syn 2.0.117", "wit-bindgen-core", "wit-bindgen-rust", ] @@ -8145,7 +8923,7 @@ checksum = "de844c262c8848816172cef550288e7dc6c7b7814b4ee56b3e1553f275f1858e" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", "synstructure", ] @@ -8193,7 +8971,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn", + "syn 2.0.117", "zbus_names", "zvariant", "zvariant_utils", @@ -8233,7 +9011,7 @@ checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -8253,7 +9031,7 @@ checksum = "11532158c46691caf0f2593ea8358fed6bbf68a0315e80aae9bd41fbade684a1" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", "synstructure", ] @@ -8293,7 +9071,7 @@ checksum = "625dc425cab0dca6dc3c3319506e6593dcb08a9f387ea3b284dbd52a92c40555" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -8392,7 +9170,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn", + "syn 2.0.117", "zvariant_utils", ] @@ -8405,6 +9183,6 @@ dependencies = [ "proc-macro2", "quote", "serde", - "syn", + "syn 2.0.117", "winnow 0.7.15", ] diff --git a/Cargo.toml b/Cargo.toml index 706fe91e..17e67296 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["paddler", "paddler_bootstrap", "paddler_cli", "paddler_client", "paddler_gui", "paddler_tests", "paddler_types"] +members = ["paddler", "paddler_bootstrap", "paddler_cli", "paddler_client", "paddler_client_cli", "paddler_gui", "paddler_tests", "paddler_types"] 
resolver = "2" [workspace.package] @@ -25,6 +25,7 @@ async-trait = "0.1" bytes = "1.11" cadence = "1.6" clap = { version = "4.5", features = ["derive"] } +crossterm = { version = "=0.29.0", features = ["event-stream"] } dashmap = "6.1" encoding_rs = { version = "0.8", features = ["serde"] } env_logger = "0.11" @@ -35,8 +36,9 @@ hf-hub = { version = "0.4", features = ["tokio"] } image = "0.25" indoc = "2" jsonschema = { version = "0.37", default-features = false } -llama-cpp-bindings = { version = "0.4.2", features = ["mtmd"] } -llama-cpp-bindings-sys = "0.4.2" +llama-cpp-bindings = "=0.6.0" +llama-cpp-bindings-sys = "=0.6.0" +llama-cpp-bindings-types = "=0.6.0" base64 = "0.22" log = "0.4" mime_guess = "2" @@ -47,6 +49,7 @@ nix = { version = "0.30", features = ["signal"] } open = "5.3.4" pastey = "0.2" rand = "0.9" +ratatui = "=0.30.0" reqwest = { version = "0.12", features = ["json", "stream"] } resvg = "0.46" rust-embed = { version = "8.9", features = ["interpolate-folder-path"] } @@ -60,6 +63,7 @@ statum = "0.6" tempfile = "3.20.0" tokio = { version = "1.48", features = ["full"] } tokio-stream = { version = "0.1.17", features = ["sync"] } +tokio-test = "0.4.4" tokio-tungstenite = "0.28" tokio-util = "0.7" thiserror = "2" diff --git a/Makefile b/Makefile index 34bea6c8..1f90ea08 100644 --- a/Makefile +++ b/Makefile @@ -72,16 +72,24 @@ test.integration: target/debug/paddler .PHONY: test.integration.cuda test.integration.cuda: target/cuda/debug/paddler - PADDLER_BINARY_PATH=../target/cuda/debug/paddler PADDLER_TEST_DEVICE=cuda cargo test -p paddler_tests --features cuda,tests_that_use_compiled_paddler,tests_that_use_llms + PADDLER_BINARY_PATH=../target/cuda/debug/paddler PADDLER_TEST_DEVICE=cuda cargo test --target-dir target/cuda -p paddler_tests --features cuda,tests_that_use_compiled_paddler,tests_that_use_llms .PHONY: test.integration.metal test.integration.metal: target/metal/debug/paddler - PADDLER_BINARY_PATH=../target/metal/debug/paddler PADDLER_TEST_DEVICE=metal cargo test -p paddler_tests --features metal,tests_that_use_compiled_paddler,tests_that_use_llms + PADDLER_BINARY_PATH=../target/metal/debug/paddler PADDLER_TEST_DEVICE=metal cargo test --target-dir target/metal -p paddler_tests --features metal,tests_that_use_compiled_paddler,tests_that_use_llms .PHONY: test.unit test.unit: esbuild-meta.json cargo test --features web_admin_panel +.PHONY: build.client.js +build.client.js: + npm --workspace @intentee/paddler-client run build + +.PHONY: test.client.js +test.client.js: + npm --workspace @intentee/paddler-client test + .PHONY: watch watch: node_modules ./jarmuz-watch.mjs diff --git a/fixtures/sarnow.jpeg b/fixtures/sarnow.jpeg new file mode 100644 index 00000000..a8b67b13 Binary files /dev/null and b/fixtures/sarnow.jpeg differ diff --git a/jarmuz/run-website.mjs b/jarmuz/run-website.mjs index 0b8028a2..ee8453e5 100644 --- a/jarmuz/run-website.mjs +++ b/jarmuz/run-website.mjs @@ -6,7 +6,13 @@ export function run({ development, once = false, rustJobs }) { jarmuz({ once, pipeline: ["stylelint", "tcm", "tsc", "eslint", esbuildJob, ...rustJobs], - watch: ["paddler", "paddler_client", "paddler_types", "resources"], + watch: [ + "paddler", + "paddler_client", + "paddler_client_javascript", + "paddler_types", + "resources", + ], }).decide(function ({ matches, schedule }) { if (matches("resources/**/*.css")) { schedule("stylelint"); @@ -14,6 +20,7 @@ export function run({ development, once = false, rustJobs }) { switch (true) { case matches("resources/**/*.{ts,tsx}"): + case 
matches("paddler_client_javascript/src/**/*.ts"): schedule("tsc"); schedule("eslint"); break; diff --git a/package-lock.json b/package-lock.json index 88c5d46e..69b7026c 100644 --- a/package-lock.json +++ b/package-lock.json @@ -6,11 +6,14 @@ "": { "name": "paddler", "license": "Apache-2.0", + "workspaces": [ + "paddler_client_javascript" + ], "dependencies": { "@codemirror/lang-jinja": "^6.0.0", "@uiw/react-codemirror": "^4.24.2", "clsx": "^2.1.1", - "nanoid": "^5.1.5", + "nanoid": "^5.1.11", "path-to-regexp": "^8.4.0", "react": "^19.1.1", "react-dom": "^19.1.1", @@ -23,7 +26,6 @@ "@types/hotwired__turbo": "^8", "@types/react": "^19.1.10", "@types/react-dom": "^19.1.7", - "ava": "^6.4.1", "esbuild": "^0.25.8", "eslint": "^9.33.0", "eslint-plugin-react-hooks": "^5.2.0", @@ -37,7 +39,6 @@ "stylelint": "^16.23.1", "stylelint-config-recommended": "^17.0.0", "tempy": "^3.1.0", - "tsimp": "^2.0.12", "tslib": "^2.8.1", "typed-css-modules": "^0.9.1", "typescript": "^5.9.2", @@ -979,22 +980,9 @@ "url": "https://github.com/sponsors/nzakas" } }, - "node_modules/@isaacs/cached": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/@isaacs/cached/-/cached-1.0.1.tgz", - "integrity": "sha512-7kGcJ9Hc1f4qpTApWz3swxbF9Qv1NF/GxuPtXeTptbsgvJIoufSd0h854Nq/2bw80F5C1onsFgEI05l+q0e4vw==", - "dev": true, - "license": "BlueOak-1.0.0", - "dependencies": { - "@isaacs/catcher": "^1.0.0" - } - }, - "node_modules/@isaacs/catcher": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/@isaacs/catcher/-/catcher-1.0.4.tgz", - "integrity": "sha512-g2klMwbnguClWNnCeQ1zYaDJsvPbIbnjdJPDE0z09MqoejJDZSLK5vIKiClq2Bkg5ubuI8vaN6wfIUi5GYzMVA==", - "dev": true, - "license": "BlueOak-1.0.0" + "node_modules/@intentee/paddler-client": { + "resolved": "paddler_client_javascript", + "link": true }, "node_modules/@isaacs/cliui": { "version": "8.0.2", @@ -1039,19 +1027,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/@isaacs/fs-minipass": { - "version": "4.0.1", - "resolved": "https://registry.npmjs.org/@isaacs/fs-minipass/-/fs-minipass-4.0.1.tgz", - "integrity": "sha512-wgm9Ehl2jpeqP3zw/7mo3kRHFp5MEDhqAdwy1fTGkHAwnkGOVsgpvQhL8B5n1qlb01jV3n/bI0ZfZp5lWA1k4w==", - "dev": true, - "license": "ISC", - "dependencies": { - "minipass": "^7.0.4" - }, - "engines": { - "node": ">=18.0.0" - } - }, "node_modules/@keyv/serialize": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/@keyv/serialize/-/serialize-1.1.0.tgz", @@ -1116,28 +1091,6 @@ "@lezer/common": "^1.0.0" } }, - "node_modules/@mapbox/node-pre-gyp": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/@mapbox/node-pre-gyp/-/node-pre-gyp-2.0.0.tgz", - "integrity": "sha512-llMXd39jtP0HpQLVI37Bf1m2ADlEb35GYSh1SDSLsBhR+5iCxiNGlT31yqbNtVHygHAtMy6dWFERpU2JgufhPg==", - "dev": true, - "license": "BSD-3-Clause", - "dependencies": { - "consola": "^3.2.3", - "detect-libc": "^2.0.0", - "https-proxy-agent": "^7.0.5", - "node-fetch": "^2.6.7", - "nopt": "^8.0.0", - "semver": "^7.5.3", - "tar": "^7.4.0" - }, - "bin": { - "node-pre-gyp": "bin/node-pre-gyp" - }, - "engines": { - "node": ">=18" - } - }, "node_modules/@marijn/find-cluster-break": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/@marijn/find-cluster-break/-/find-cluster-break-1.0.2.tgz", @@ -1193,42 +1146,6 @@ "node": ">=14" } }, - "node_modules/@rollup/pluginutils": { - "version": "5.2.0", - "resolved": "https://registry.npmjs.org/@rollup/pluginutils/-/pluginutils-5.2.0.tgz", - "integrity": 
"sha512-qWJ2ZTbmumwiLFomfzTyt5Kng4hwPi9rwCYN4SHb6eaRU1KNO4ccxINHr/VhH4GgPlt1XfSTLX2LBTme8ne4Zw==", - "dev": true, - "license": "MIT", - "dependencies": { - "@types/estree": "^1.0.0", - "estree-walker": "^2.0.2", - "picomatch": "^4.0.2" - }, - "engines": { - "node": ">=14.0.0" - }, - "peerDependencies": { - "rollup": "^1.20.0||^2.0.0||^3.0.0||^4.0.0" - }, - "peerDependenciesMeta": { - "rollup": { - "optional": true - } - } - }, - "node_modules/@sindresorhus/merge-streams": { - "version": "2.3.0", - "resolved": "https://registry.npmjs.org/@sindresorhus/merge-streams/-/merge-streams-2.3.0.tgz", - "integrity": "sha512-LtoMMhxAlorcGhmFYI+LhPgbPZCkgP6ra1YL604EeF6U98pLlQ3iWIGMdWSC+vWmPBWBNgmDBAhnAobLROJmwg==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=18" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/@types/debug": { "version": "4.1.12", "resolved": "https://registry.npmjs.org/@types/debug/-/debug-4.1.12.tgz", @@ -1291,6 +1208,16 @@ "integrity": "sha512-GsCCIZDE/p3i96vtEqx+7dBUGXrc7zeSK3wwPHIaRThS+9OhWIXRqzs4d6k1SVU8g91DrNRWxWUGhp5KXQb2VA==", "license": "MIT" }, + "node_modules/@types/node": { + "version": "22.19.17", + "resolved": "https://registry.npmjs.org/@types/node/-/node-22.19.17.tgz", + "integrity": "sha512-wGdMcf+vPYM6jikpS/qhg6WiqSV/OhG+jeeHT/KlVqxYfD40iYJf9/AE1uQxVWFvU7MipKRkRv8NSHiCGgPr8Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "undici-types": "~6.21.0" + } + }, "node_modules/@types/react": { "version": "19.1.10", "resolved": "https://registry.npmjs.org/@types/react/-/react-19.1.10.tgz", @@ -1633,130 +1560,6 @@ "integrity": "sha512-WmoN8qaIAo7WTYWbAZuG8PYEhn5fkz7dZrqTBZ7dtt//lL2Gwms1IcnQ5yHqjDfX8Ft5j4YzDM23f87zBfDe9g==", "license": "ISC" }, - "node_modules/@vercel/nft": { - "version": "0.29.4", - "resolved": "https://registry.npmjs.org/@vercel/nft/-/nft-0.29.4.tgz", - "integrity": "sha512-6lLqMNX3TuycBPABycx7A9F1bHQR7kiQln6abjFbPrf5C/05qHM9M5E4PeTE59c7z8g6vHnx1Ioihb2AQl7BTA==", - "dev": true, - "license": "MIT", - "dependencies": { - "@mapbox/node-pre-gyp": "^2.0.0", - "@rollup/pluginutils": "^5.1.3", - "acorn": "^8.6.0", - "acorn-import-attributes": "^1.9.5", - "async-sema": "^3.1.1", - "bindings": "^1.4.0", - "estree-walker": "2.0.2", - "glob": "^10.4.5", - "graceful-fs": "^4.2.9", - "node-gyp-build": "^4.2.2", - "picomatch": "^4.0.2", - "resolve-from": "^5.0.0" - }, - "bin": { - "nft": "out/cli.js" - }, - "engines": { - "node": ">=18" - } - }, - "node_modules/@vercel/nft/node_modules/brace-expansion": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.1.0.tgz", - "integrity": "sha512-TN1kCZAgdgweJhWWpgKYrQaMNHcDULHkWwQIspdtjV4Y5aurRdZpjAqn6yX3FPqTA9ngHCc4hJxMAMgGfve85w==", - "dev": true, - "license": "MIT", - "dependencies": { - "balanced-match": "^1.0.0" - } - }, - "node_modules/@vercel/nft/node_modules/glob": { - "version": "10.5.0", - "resolved": "https://registry.npmjs.org/glob/-/glob-10.5.0.tgz", - "integrity": "sha512-DfXN8DfhJ7NH3Oe7cFmu3NCu1wKbkReJ8TorzSAFbSKrlNaQSKfIzqYqVY8zlbs2NLBbWpRiU52GX2PbaBVNkg==", - "dev": true, - "license": "ISC", - "dependencies": { - "foreground-child": "^3.1.0", - "jackspeak": "^3.1.2", - "minimatch": "^9.0.4", - "minipass": "^7.1.2", - "package-json-from-dist": "^1.0.0", - "path-scurry": "^1.11.1" - }, - "bin": { - "glob": "dist/esm/bin.mjs" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, - "node_modules/@vercel/nft/node_modules/jackspeak": { - "version": "3.4.3", - "resolved": 
"https://registry.npmjs.org/jackspeak/-/jackspeak-3.4.3.tgz", - "integrity": "sha512-OGlZQpz2yfahA/Rd1Y8Cd9SIEsqvXkLVoSw/cgwhnhFMDbsQFeZYoJJ7bIZBS9BcamUW96asq/npPWugM+RQBw==", - "dev": true, - "license": "BlueOak-1.0.0", - "dependencies": { - "@isaacs/cliui": "^8.0.2" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - }, - "optionalDependencies": { - "@pkgjs/parseargs": "^0.11.0" - } - }, - "node_modules/@vercel/nft/node_modules/lru-cache": { - "version": "10.4.3", - "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-10.4.3.tgz", - "integrity": "sha512-JNAzZcXrCt42VGLuYz0zfAzDfAvJWW6AfYlDBQyDV5DClI2m5sAmK+OIO7s59XfsRsWHp02jAJrRadPRGTt6SQ==", - "dev": true, - "license": "ISC" - }, - "node_modules/@vercel/nft/node_modules/minimatch": { - "version": "9.0.9", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.9.tgz", - "integrity": "sha512-OBwBN9AL4dqmETlpS2zasx+vTeWclWzkblfZk7KTA5j3jeOONz/tRCnZomUyvNg83wL5Zv9Ss6HMJXAgL8R2Yg==", - "dev": true, - "license": "ISC", - "dependencies": { - "brace-expansion": "^2.0.2" - }, - "engines": { - "node": ">=16 || 14 >=14.17" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, - "node_modules/@vercel/nft/node_modules/path-scurry": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/path-scurry/-/path-scurry-1.11.1.tgz", - "integrity": "sha512-Xa4Nw17FS9ApQFJ9umLiJS4orGjm7ZzwUrwamcGQuHSzDyth9boKDaycYdDcZDuqYATXw4HFXgaqWTctW/v1HA==", - "dev": true, - "license": "BlueOak-1.0.0", - "dependencies": { - "lru-cache": "^10.2.0", - "minipass": "^5.0.0 || ^6.0.2 || ^7.0.0" - }, - "engines": { - "node": ">=16 || 14 >=14.18" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, - "node_modules/abbrev": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/abbrev/-/abbrev-3.0.1.tgz", - "integrity": "sha512-AO2ac6pjRB3SJmGJo+v5/aK6Omggp6fsLrs6wN9bd35ulu4cCwaAU9+7ZhXjeqHVkaHThLuzH0nZr0YpCDhygg==", - "dev": true, - "license": "ISC", - "engines": { - "node": "^18.17.0 || >=20.5.0" - } - }, "node_modules/acorn": { "version": "8.15.0", "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz", @@ -1770,16 +1573,6 @@ "node": ">=0.4.0" } }, - "node_modules/acorn-import-attributes": { - "version": "1.9.5", - "resolved": "https://registry.npmjs.org/acorn-import-attributes/-/acorn-import-attributes-1.9.5.tgz", - "integrity": "sha512-n02Vykv5uA3eHGM/Z2dQrcD56kL8TyDb2p1+0P83PClMnC/nc+anbQRhIOWnSq4Ke/KvDPrY3C9hDtC/A3eHnQ==", - "dev": true, - "license": "MIT", - "peerDependencies": { - "acorn": "^8" - } - }, "node_modules/acorn-jsx": { "version": "5.3.2", "resolved": "https://registry.npmjs.org/acorn-jsx/-/acorn-jsx-5.3.2.tgz", @@ -1790,29 +1583,6 @@ "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0" } }, - "node_modules/acorn-walk": { - "version": "8.3.4", - "resolved": "https://registry.npmjs.org/acorn-walk/-/acorn-walk-8.3.4.tgz", - "integrity": "sha512-ueEepnujpqee2o5aIYnvHU6C0A42MNdsIDeqy5BydrkuC5R1ZuUFnm27EeFJGoEHJQgn3uleRvmTXaJgfXbt4g==", - "dev": true, - "license": "MIT", - "dependencies": { - "acorn": "^8.11.0" - }, - "engines": { - "node": ">=0.4.0" - } - }, - "node_modules/agent-base": { - "version": "7.1.4", - "resolved": "https://registry.npmjs.org/agent-base/-/agent-base-7.1.4.tgz", - "integrity": "sha512-MnA+YT8fwfJPgBx3m60MNqakm30XOkyIoH1y6huTQvC0PwZG7ki8NacLBcrPbNoo8vEZy7Jpuk7+jMO+CUovTQ==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 14" - } - }, "node_modules/ajv": { "version": "6.14.0", "resolved": 
"https://registry.npmjs.org/ajv/-/ajv-6.14.0.tgz", @@ -1890,16 +1660,6 @@ "dev": true, "license": "Python-2.0" }, - "node_modules/array-find-index": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/array-find-index/-/array-find-index-1.0.2.tgz", - "integrity": "sha512-M1HQyIXcBGtVywBt8WVdim+lrNaK7VHp99Qt5pSNziXznKHViIBbXWtfRTpEFpF/c4FdfxNAsCCwPp5phBYJtw==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=0.10.0" - } - }, "node_modules/array-union": { "version": "2.1.0", "resolved": "https://registry.npmjs.org/array-union/-/array-union-2.1.0.tgz", @@ -1910,29 +1670,6 @@ "node": ">=8" } }, - "node_modules/arrgv": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/arrgv/-/arrgv-1.0.2.tgz", - "integrity": "sha512-a4eg4yhp7mmruZDQFqVMlxNRFGi/i1r87pt8SDHy0/I8PqSXoUTlWZRdAZo0VXgvEARcujbtTk8kiZRi1uDGRw==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=8.0.0" - } - }, - "node_modules/arrify": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/arrify/-/arrify-3.0.0.tgz", - "integrity": "sha512-tLkvA81vQG/XqE2mjDkGQHoOINtMHtysSnemrmoGe6PydDPMRbVugqyk4A6V/WDWEfm3l+0d8anA9r8cv/5Jaw==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=12" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/astral-regex": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/astral-regex/-/astral-regex-2.0.0.tgz", @@ -1943,76 +1680,6 @@ "node": ">=8" } }, - "node_modules/async-sema": { - "version": "3.1.1", - "resolved": "https://registry.npmjs.org/async-sema/-/async-sema-3.1.1.tgz", - "integrity": "sha512-tLRNUXati5MFePdAk8dw7Qt7DpxPB60ofAgn8WRhW6a2rcimZnYBP9oxHiv0OHy+Wz7kPMG+t4LGdt31+4EmGg==", - "dev": true, - "license": "MIT" - }, - "node_modules/ava": { - "version": "6.4.1", - "resolved": "https://registry.npmjs.org/ava/-/ava-6.4.1.tgz", - "integrity": "sha512-vxmPbi1gZx9zhAjHBgw81w/iEDKcrokeRk/fqDTyA2DQygZ0o+dUGRHFOtX8RA5N0heGJTTsIk7+xYxitDb61Q==", - "dev": true, - "license": "MIT", - "dependencies": { - "@vercel/nft": "^0.29.4", - "acorn": "^8.15.0", - "acorn-walk": "^8.3.4", - "ansi-styles": "^6.2.1", - "arrgv": "^1.0.2", - "arrify": "^3.0.0", - "callsites": "^4.2.0", - "cbor": "^10.0.9", - "chalk": "^5.4.1", - "chunkd": "^2.0.1", - "ci-info": "^4.3.0", - "ci-parallel-vars": "^1.0.1", - "cli-truncate": "^4.0.0", - "code-excerpt": "^4.0.0", - "common-path-prefix": "^3.0.0", - "concordance": "^5.0.4", - "currently-unhandled": "^0.4.1", - "debug": "^4.4.1", - "emittery": "^1.2.0", - "figures": "^6.1.0", - "globby": "^14.1.0", - "ignore-by-default": "^2.1.0", - "indent-string": "^5.0.0", - "is-plain-object": "^5.0.0", - "is-promise": "^4.0.0", - "matcher": "^5.0.0", - "memoize": "^10.1.0", - "ms": "^2.1.3", - "p-map": "^7.0.3", - "package-config": "^5.0.0", - "picomatch": "^4.0.2", - "plur": "^5.1.0", - "pretty-ms": "^9.2.0", - "resolve-cwd": "^3.0.0", - "stack-utils": "^2.0.6", - "strip-ansi": "^7.1.0", - "supertap": "^3.0.1", - "temp-dir": "^3.0.0", - "write-file-atomic": "^6.0.0", - "yargs": "^17.7.2" - }, - "bin": { - "ava": "entrypoints/cli.mjs" - }, - "engines": { - "node": "^18.18 || ^20.8 || ^22 || ^23 || >=24" - }, - "peerDependencies": { - "@ava/typescript": "*" - }, - "peerDependenciesMeta": { - "@ava/typescript": { - "optional": true - } - } - }, "node_modules/bail": { "version": "2.0.2", "resolved": "https://registry.npmjs.org/bail/-/bail-2.0.2.tgz", @@ -2043,23 +1710,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/bindings": { - 
"version": "1.5.0", - "resolved": "https://registry.npmjs.org/bindings/-/bindings-1.5.0.tgz", - "integrity": "sha512-p2q/t/mhvuOj/UeLlV6566GD/guowlr0hHxClI0W9m7MWYkL1F0hLo+0Aexs9HSPCtR1SXQ0TD3MMKrXZajbiQ==", - "dev": true, - "license": "MIT", - "dependencies": { - "file-uri-to-path": "1.0.0" - } - }, - "node_modules/blueimp-md5": { - "version": "2.19.0", - "resolved": "https://registry.npmjs.org/blueimp-md5/-/blueimp-md5-2.19.0.tgz", - "integrity": "sha512-DRQrD6gJyy8FbiE4s+bDoXS9hiW3Vbx5uCdwvcCf3zLHL+Iv7LtGHLpr+GZV8rHG8tK766FGYBwRbu8pELTt+w==", - "dev": true, - "license": "MIT" - }, "node_modules/brace-expansion": { "version": "1.1.14", "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.14.tgz", @@ -2105,19 +1755,6 @@ "@keyv/serialize": "^1.1.0" } }, - "node_modules/callsites": { - "version": "4.2.0", - "resolved": "https://registry.npmjs.org/callsites/-/callsites-4.2.0.tgz", - "integrity": "sha512-kfzR4zzQtAE9PC7CzZsjl3aBNbXWuXiSeOCdLcPpBfGW8YuCqQHcRPFDbr/BPVmd3EEPVpuFzLyuT/cUhPr4OQ==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=12.20" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/camelcase": { "version": "6.3.0", "resolved": "https://registry.npmjs.org/camelcase/-/camelcase-6.3.0.tgz", @@ -2131,19 +1768,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/cbor": { - "version": "10.0.9", - "resolved": "https://registry.npmjs.org/cbor/-/cbor-10.0.9.tgz", - "integrity": "sha512-KEWYehb/vJkRmigctVQLsz73Us2RNnITo/wOwQV5AtZpLGH1r2PPlsNHdsX460YuHZCyhLklbYzAOuJfOeg34Q==", - "dev": true, - "license": "MIT", - "dependencies": { - "nofilter": "^3.0.2" - }, - "engines": { - "node": ">=20" - } - }, "node_modules/ccount": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/ccount/-/ccount-2.0.1.tgz", @@ -2154,23 +1778,10 @@ "url": "https://github.com/sponsors/wooorm" } }, - "node_modules/chalk": { - "version": "5.4.1", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-5.4.1.tgz", - "integrity": "sha512-zgVZuo2WcZgfUEmsn6eO3kINexW8RAE4maiQ8QNs8CtpPCSyMiYsULR3HQYkm3w8FIA3SberyMJMSldGsW+U3w==", - "dev": true, - "license": "MIT", - "engines": { - "node": "^12.17.0 || ^14.13 || >=16.0.0" - }, - "funding": { - "url": "https://github.com/chalk/chalk?sponsor=1" - } - }, - "node_modules/character-entities": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/character-entities/-/character-entities-2.0.2.tgz", - "integrity": "sha512-shx7oQ0Awen/BRIdkjkvz54PnEEI/EjwXDSIZp86/KKdbafHh1Df/RYGBhn4hbe2+uKC9FnT5UCEdyPz3ai9hQ==", + "node_modules/character-entities": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/character-entities/-/character-entities-2.0.2.tgz", + "integrity": "sha512-shx7oQ0Awen/BRIdkjkvz54PnEEI/EjwXDSIZp86/KKdbafHh1Df/RYGBhn4hbe2+uKC9FnT5UCEdyPz3ai9hQ==", "license": "MIT", "funding": { "type": "github", @@ -2223,63 +1834,6 @@ "url": "https://paulmillr.com/funding/" } }, - "node_modules/chownr": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/chownr/-/chownr-3.0.0.tgz", - "integrity": "sha512-+IxzY9BZOQd/XuYPRmrvEVjF/nqj5kgT4kEq7VofrDoM1MxoRjEWkrCC3EtLi59TVawxTAn+orJwFQcrqEN1+g==", - "dev": true, - "license": "BlueOak-1.0.0", - "engines": { - "node": ">=18" - } - }, - "node_modules/chunkd": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/chunkd/-/chunkd-2.0.1.tgz", - "integrity": "sha512-7d58XsFmOq0j6el67Ug9mHf9ELUXsQXYJBkyxhH/k+6Ke0qXRnv0kbemx+Twc6fRJ07C49lcbdgm9FL1Ei/6SQ==", - "dev": true, - 
"license": "MIT" - }, - "node_modules/ci-info": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ci-info/-/ci-info-4.3.0.tgz", - "integrity": "sha512-l+2bNRMiQgcfILUi33labAZYIWlH1kWDp+ecNo5iisRKrbm0xcRyCww71/YU0Fkw0mAFpz9bJayXPjey6vkmaQ==", - "dev": true, - "funding": [ - { - "type": "github", - "url": "https://github.com/sponsors/sibiraj-s" - } - ], - "license": "MIT", - "engines": { - "node": ">=8" - } - }, - "node_modules/ci-parallel-vars": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/ci-parallel-vars/-/ci-parallel-vars-1.0.1.tgz", - "integrity": "sha512-uvzpYrpmidaoxvIQHM+rKSrigjOe9feHYbw4uOI2gdfe1C3xIlxO+kVXq83WQWNniTf8bAxVpy+cQeFQsMERKg==", - "dev": true, - "license": "MIT" - }, - "node_modules/cli-truncate": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/cli-truncate/-/cli-truncate-4.0.0.tgz", - "integrity": "sha512-nPdaFdQ0h/GEigbPClz11D0v/ZJEwxmeVZGeMo3Z5StPtUTkA9o1lD6QwoirYiSDzbcwn2XcjwmCp68W1IS4TA==", - "dev": true, - "license": "MIT", - "dependencies": { - "slice-ansi": "^5.0.0", - "string-width": "^7.0.0" - }, - "engines": { - "node": ">=18" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/cliui": { "version": "8.0.1", "resolved": "https://registry.npmjs.org/cliui/-/cliui-8.0.1.tgz", @@ -2393,19 +1947,6 @@ "node": ">=6" } }, - "node_modules/code-excerpt": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/code-excerpt/-/code-excerpt-4.0.0.tgz", - "integrity": "sha512-xxodCmBen3iy2i0WtAK8FlFNrRzjUqjRsMfho58xT/wvZU1YTM3fCnRjcy1gJPMepaRlgm/0e6w8SpWHpn3/cA==", - "dev": true, - "license": "MIT", - "dependencies": { - "convert-to-spaces": "^2.0.1" - }, - "engines": { - "node": "^12.20.0 || ^14.13.1 || >=16.0.0" - } - }, "node_modules/codemirror": { "version": "6.0.2", "resolved": "https://registry.npmjs.org/codemirror/-/codemirror-6.0.2.tgz", @@ -2458,13 +1999,6 @@ "url": "https://github.com/sponsors/wooorm" } }, - "node_modules/common-path-prefix": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/common-path-prefix/-/common-path-prefix-3.0.0.tgz", - "integrity": "sha512-QE33hToZseCH3jS0qN96O/bSh3kaw/h+Tq7ngyY9eWDUnTlTNUyqfqvCXioLe5Na5jFsL78ra/wuBU4iuEgd4w==", - "dev": true, - "license": "ISC" - }, "node_modules/concat-map": { "version": "0.0.1", "resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz", @@ -2472,46 +2006,6 @@ "dev": true, "license": "MIT" }, - "node_modules/concordance": { - "version": "5.0.4", - "resolved": "https://registry.npmjs.org/concordance/-/concordance-5.0.4.tgz", - "integrity": "sha512-OAcsnTEYu1ARJqWVGwf4zh4JDfHZEaSNlNccFmt8YjB2l/n19/PF2viLINHc57vO4FKIAFl2FWASIGZZWZ2Kxw==", - "dev": true, - "license": "ISC", - "dependencies": { - "date-time": "^3.1.0", - "esutils": "^2.0.3", - "fast-diff": "^1.2.0", - "js-string-escape": "^1.0.1", - "lodash": "^4.17.15", - "md5-hex": "^3.0.1", - "semver": "^7.3.2", - "well-known-symbols": "^2.0.0" - }, - "engines": { - "node": ">=10.18.0 <11 || >=12.14.0 <13 || >=14" - } - }, - "node_modules/consola": { - "version": "3.4.2", - "resolved": "https://registry.npmjs.org/consola/-/consola-3.4.2.tgz", - "integrity": "sha512-5IKcdX0nnYavi6G7TtOhwkYzyjfJlatbjMjuLSfE2kYT5pMDOilZ4OvMhi637CcDICTmz3wARPoyhqyX1Y+XvA==", - "dev": true, - "license": "MIT", - "engines": { - "node": "^14.18.0 || >=16.10.0" - } - }, - "node_modules/convert-to-spaces": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/convert-to-spaces/-/convert-to-spaces-2.0.1.tgz", - "integrity": 
"sha512-rcQ1bsQO9799wq24uE5AM2tAILy4gXGIK/njFWcVQkGNZ96edlpY+A7bjwvzjYvLDyzmG1MmMLZhpcsb+klNMQ==", - "dev": true, - "license": "MIT", - "engines": { - "node": "^12.20.0 || ^14.13.1 || >=16.0.0" - } - }, "node_modules/cosmiconfig": { "version": "9.0.0", "resolved": "https://registry.npmjs.org/cosmiconfig/-/cosmiconfig-9.0.0.tgz", @@ -2632,32 +2126,6 @@ "integrity": "sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw==", "license": "MIT" }, - "node_modules/currently-unhandled": { - "version": "0.4.1", - "resolved": "https://registry.npmjs.org/currently-unhandled/-/currently-unhandled-0.4.1.tgz", - "integrity": "sha512-/fITjgjGU50vjQ4FH6eUoYu+iUoUKIXws2hL15JJpIR+BbTxaXQsMuuyjtNh2WqsSBS5nsaZHFsFecyw5CCAng==", - "dev": true, - "license": "MIT", - "dependencies": { - "array-find-index": "^1.0.1" - }, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/date-time": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/date-time/-/date-time-3.1.0.tgz", - "integrity": "sha512-uqCUKXE5q1PNBXjPqvwhwJf9SwMoAHBgWJ6DcrnS5o+W2JOiIILl0JEdVD8SGujrNS02GGxgwAg2PN2zONgtjg==", - "dev": true, - "license": "MIT", - "dependencies": { - "time-zone": "^1.0.0" - }, - "engines": { - "node": ">=6" - } - }, "node_modules/debug": { "version": "4.4.1", "resolved": "https://registry.npmjs.org/debug/-/debug-4.4.1.tgz", @@ -2704,16 +2172,6 @@ "node": ">=6" } }, - "node_modules/detect-libc": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/detect-libc/-/detect-libc-2.0.4.tgz", - "integrity": "sha512-3UDv+G9CsCKO1WKMGw9fwq/SWJYbI0c5Y7LU1AXYoDdbhE2AHQ6N6Nb34sG8Fj7T5APy8qXDCKuuIHd1BR0tVA==", - "dev": true, - "license": "Apache-2.0", - "engines": { - "node": ">=8" - } - }, "node_modules/devlop": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/devlop/-/devlop-1.1.0.tgz", @@ -2757,26 +2215,6 @@ "dev": true, "license": "MIT" }, - "node_modules/emittery": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/emittery/-/emittery-1.2.0.tgz", - "integrity": "sha512-KxdRyyFcS85pH3dnU8Y5yFUm2YJdaHwcBZWrfG8o89ZY9a13/f9itbN+YG3ELbBo9Pg5zvIozstmuV8bX13q6g==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=14.16" - }, - "funding": { - "url": "https://github.com/sindresorhus/emittery?sponsor=1" - } - }, - "node_modules/emoji-regex": { - "version": "10.4.0", - "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-10.4.0.tgz", - "integrity": "sha512-EC+0oUMY1Rqm4O6LLrgjtYDvcVYTy7chDnM4Q7030tP4Kwj3u/pR6gP9ygnp2CJMK5Gq+9Q2oqmrFJAz01DXjw==", - "dev": true, - "license": "MIT" - }, "node_modules/env-paths": { "version": "2.2.1", "resolved": "https://registry.npmjs.org/env-paths/-/env-paths-2.2.1.tgz", @@ -3017,20 +2455,6 @@ "url": "https://opencollective.com/eslint" } }, - "node_modules/esprima": { - "version": "4.0.1", - "resolved": "https://registry.npmjs.org/esprima/-/esprima-4.0.1.tgz", - "integrity": "sha512-eGuFFw7Upda+g4p+QHvnW0RyTX/SVeJBDM/gCtMARO0cLuT2HcEKnTPvhjV6aGeqrCB/sbNop0Kszm0jsaWU4A==", - "dev": true, - "license": "BSD-2-Clause", - "bin": { - "esparse": "bin/esparse.js", - "esvalidate": "bin/esvalidate.js" - }, - "engines": { - "node": ">=4" - } - }, "node_modules/esquery": { "version": "1.6.0", "resolved": "https://registry.npmjs.org/esquery/-/esquery-1.6.0.tgz", @@ -3077,13 +2501,6 @@ "url": "https://opencollective.com/unified" } }, - "node_modules/estree-walker": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-2.0.2.tgz", - "integrity": 
"sha512-Rfkk/Mp/DL7JVje3u18FxFujQlTNR2q6QfMSMB7AvCBx91NGj/ba3kCfza0f6dVDbw7YlRf/nDrn7pQrCCyQ/w==", - "dev": true, - "license": "MIT" - }, "node_modules/esutils": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/esutils/-/esutils-2.0.3.tgz", @@ -3107,13 +2524,6 @@ "dev": true, "license": "MIT" }, - "node_modules/fast-diff": { - "version": "1.3.0", - "resolved": "https://registry.npmjs.org/fast-diff/-/fast-diff-1.3.0.tgz", - "integrity": "sha512-VxPP4NqbUjj6MaAOafWeUn2cXWLcCtljklUtZf0Ind4XQ+QPtmA0b18zZy0jIQx+ExRVCR/ZQpBmik5lXshNsw==", - "dev": true, - "license": "Apache-2.0" - }, "node_modules/fast-glob": { "version": "3.3.3", "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.3.tgz", @@ -3195,22 +2605,6 @@ "reusify": "^1.0.4" } }, - "node_modules/figures": { - "version": "6.1.0", - "resolved": "https://registry.npmjs.org/figures/-/figures-6.1.0.tgz", - "integrity": "sha512-d+l3qxjSesT4V7v2fh+QnmFnUWv9lSpjarhShNTgBOfA0ttejbQUAlHLitbjkoRiDulW0OPoQPYIGhIC8ohejg==", - "dev": true, - "license": "MIT", - "dependencies": { - "is-unicode-supported": "^2.0.0" - }, - "engines": { - "node": ">=18" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/file-entry-cache": { "version": "8.0.0", "resolved": "https://registry.npmjs.org/file-entry-cache/-/file-entry-cache-8.0.0.tgz", @@ -3224,13 +2618,6 @@ "node": ">=16.0.0" } }, - "node_modules/file-uri-to-path": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/file-uri-to-path/-/file-uri-to-path-1.0.0.tgz", - "integrity": "sha512-0Zt+s3L7Vf1biwWZ29aARiVYLx7iMGnEUl9x33fbB/j3jR81u/O2LbqK+Bm1CDSNDKVtJ/YjwY7TUd5SkeLQLw==", - "dev": true, - "license": "MIT" - }, "node_modules/fill-range": { "version": "7.1.1", "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", @@ -3261,19 +2648,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/find-up-simple": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/find-up-simple/-/find-up-simple-1.0.1.tgz", - "integrity": "sha512-afd4O7zpqHeRyg4PfDQsXmlDe2PfdHtJt6Akt8jOWaApLOZk5JXs6VMR29lz03pRe9mpykrRCYIYxaJYcfpncQ==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=18" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/flat-cache": { "version": "4.0.1", "resolved": "https://registry.npmjs.org/flat-cache/-/flat-cache-4.0.1.tgz", @@ -3352,17 +2726,17 @@ "node": "6.* || 8.* || >= 10.*" } }, - "node_modules/get-east-asian-width": { - "version": "1.3.0", - "resolved": "https://registry.npmjs.org/get-east-asian-width/-/get-east-asian-width-1.3.0.tgz", - "integrity": "sha512-vpeMIQKxczTD/0s2CdEWHcb0eeJe6TFjxb+J5xgX7hScxqrGuyjmv4c1D4A/gelKfyox0gJJwIHF+fLjeaM8kQ==", + "node_modules/get-tsconfig": { + "version": "4.14.0", + "resolved": "https://registry.npmjs.org/get-tsconfig/-/get-tsconfig-4.14.0.tgz", + "integrity": "sha512-yTb+8DXzDREzgvYmh6s9vHsSVCHeC0G3PI5bEXNBHtmshPnO+S5O7qgLEOn0I5QvMy6kpZN8K1NKGyilLb93wA==", "dev": true, "license": "MIT", - "engines": { - "node": ">=18" + "dependencies": { + "resolve-pkg-maps": "^1.0.0" }, "funding": { - "url": "https://github.com/sponsors/sindresorhus" + "url": "https://github.com/privatenumber/get-tsconfig?sponsor=1" } }, "node_modules/glob": { @@ -3495,37 +2869,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/globby": { - "version": "14.1.0", - "resolved": "https://registry.npmjs.org/globby/-/globby-14.1.0.tgz", - "integrity": 
"sha512-0Ia46fDOaT7k4og1PDW4YbodWWr3scS2vAr2lTbsplOt2WkKp0vQbkI9wKis/T5LV/dqPjO3bpS/z6GTJB82LA==", - "dev": true, - "license": "MIT", - "dependencies": { - "@sindresorhus/merge-streams": "^2.1.0", - "fast-glob": "^3.3.3", - "ignore": "^7.0.3", - "path-type": "^6.0.0", - "slash": "^5.1.0", - "unicorn-magic": "^0.3.0" - }, - "engines": { - "node": ">=18" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/globby/node_modules/ignore": { - "version": "7.0.5", - "resolved": "https://registry.npmjs.org/ignore/-/ignore-7.0.5.tgz", - "integrity": "sha512-Hs59xBNfUIunMFgWAbGX5cq6893IbWg4KnrjbYwX3tx0ztorVgTDA6B2sxf8ejHJ4wz8BqGUMYlnzNBer5NvGg==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 4" - } - }, "node_modules/globjoin": { "version": "0.1.4", "resolved": "https://registry.npmjs.org/globjoin/-/globjoin-0.1.4.tgz", @@ -3634,20 +2977,6 @@ "url": "https://opencollective.com/unified" } }, - "node_modules/https-proxy-agent": { - "version": "7.0.6", - "resolved": "https://registry.npmjs.org/https-proxy-agent/-/https-proxy-agent-7.0.6.tgz", - "integrity": "sha512-vK9P5/iUfdl95AI+JVyUuIcVtd4ofvtrOr3HNtM2yxC9bnMbEdp3x01OhQNnjb8IJYi38VlTE3mBXwcfvywuSw==", - "dev": true, - "license": "MIT", - "dependencies": { - "agent-base": "^7.1.2", - "debug": "4" - }, - "engines": { - "node": ">= 14" - } - }, "node_modules/icss-replace-symbols": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/icss-replace-symbols/-/icss-replace-symbols-1.1.0.tgz", @@ -3678,16 +3007,6 @@ "node": ">= 4" } }, - "node_modules/ignore-by-default": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/ignore-by-default/-/ignore-by-default-2.1.0.tgz", - "integrity": "sha512-yiWd4GVmJp0Q6ghmM2B/V3oZGRmjrKLXvHR3TE1nfoXsmoggllfZUQe74EN0fJdPFZu2NIvNdrMMLm3OsV7Ohw==", - "dev": true, - "license": "ISC", - "engines": { - "node": ">=10 <11 || >=12 <13 || >=14" - } - }, "node_modules/import-fresh": { "version": "3.3.1", "resolved": "https://registry.npmjs.org/import-fresh/-/import-fresh-3.3.1.tgz", @@ -3725,19 +3044,6 @@ "node": ">=0.8.19" } }, - "node_modules/indent-string": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/indent-string/-/indent-string-5.0.0.tgz", - "integrity": "sha512-m6FAo/spmsW2Ab2fU35JTYwtOKa2yAwXSwgjSv1TJzh4Mh7mC3lzAOVLBprb72XsTrgkEIsl7YrFNAiDiRhIGg==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=12" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/ini": { "version": "1.3.8", "resolved": "https://registry.npmjs.org/ini/-/ini-1.3.8.tgz", @@ -3751,16 +3057,6 @@ "integrity": "sha512-0aO8FkhNZlj/ZIbNi7Lxxr12obT7cL1moPfE4tg1LkX7LlLfC6DeX4l2ZEud1ukP9jNQyNnfzQVqwbwmAATY4Q==", "license": "MIT" }, - "node_modules/irregular-plurals": { - "version": "3.5.0", - "resolved": "https://registry.npmjs.org/irregular-plurals/-/irregular-plurals-3.5.0.tgz", - "integrity": "sha512-1ANGLZ+Nkv1ptFb2pa8oG8Lem4krflKuX/gINiHJHjJUKaJHk/SXk5x6K3J+39/p0h1RQ2saROclJJ+QLvETCQ==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=8" - } - }, "node_modules/is-alphabetical": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/is-alphabetical/-/is-alphabetical-2.0.1.tgz", @@ -3841,19 +3137,6 @@ "node": ">=0.10.0" } }, - "node_modules/is-fullwidth-code-point": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-4.0.0.tgz", - "integrity": 
"sha512-O4L094N2/dZ7xqVdrXhh9r1KODPJpFms8B5sGdJLPy664AgvXsreZUyCQQNItZRDlYug4xStLjNp/sz3HvBowQ==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=12" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/is-glob": { "version": "4.0.3", "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", @@ -3909,13 +3192,6 @@ "node": ">=0.10.0" } }, - "node_modules/is-promise": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/is-promise/-/is-promise-4.0.0.tgz", - "integrity": "sha512-hvpoI6korhJMnej285dSg6nu1+e6uxs7zG3BYAm5byqDsgJNWwxzM6z6iZiAgQR4TJ30JmBTOwqZUw3WlyH3AQ==", - "dev": true, - "license": "MIT" - }, "node_modules/is-stream": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/is-stream/-/is-stream-3.0.0.tgz", @@ -3936,19 +3212,6 @@ "dev": true, "license": "MIT" }, - "node_modules/is-unicode-supported": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/is-unicode-supported/-/is-unicode-supported-2.1.0.tgz", - "integrity": "sha512-mE00Gnza5EEB3Ds0HfMyllZzbBrmLOX3vfWoj9A9PEnTfratQ/BcaJOuMhnkhjXvb2+FkY3VuHqtAGpTPmglFQ==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=18" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/is-wsl": { "version": "2.2.0", "resolved": "https://registry.npmjs.org/is-wsl/-/is-wsl-2.2.0.tgz", @@ -4038,16 +3301,6 @@ "url": "https://github.com/sponsors/isaacs" } }, - "node_modules/js-string-escape": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/js-string-escape/-/js-string-escape-1.0.1.tgz", - "integrity": "sha512-Smw4xcfIQ5LVjAOuJCvN/zIodzA/BBSsluuoSykP+lUvScIi4U6RJLfwHet5cxFnCswUjISV8oAXaqaJDY3chg==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 0.8" - } - }, "node_modules/js-tokens": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz", @@ -4157,27 +3410,14 @@ "dev": true, "license": "MIT" }, - "node_modules/load-json-file": { - "version": "7.0.1", - "resolved": "https://registry.npmjs.org/load-json-file/-/load-json-file-7.0.1.tgz", - "integrity": "sha512-Gnxj3ev3mB5TkVBGad0JM6dmLiQL+o0t23JPBZ9sd+yvSLk05mFoqKBw5N8gbbkU4TNXyqCgIrl/VM17OgUIgQ==", + "node_modules/locate-path": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-6.0.0.tgz", + "integrity": "sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw==", "dev": true, "license": "MIT", - "engines": { - "node": "^12.20.0 || ^14.13.1 || >=16.0.0" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/locate-path": { - "version": "6.0.0", - "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-6.0.0.tgz", - "integrity": "sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw==", - "dev": true, - "license": "MIT", - "dependencies": { - "p-locate": "^5.0.0" + "dependencies": { + "p-locate": "^5.0.0" }, "engines": { "node": ">=10" @@ -4227,35 +3467,6 @@ "node": "20 || >=22" } }, - "node_modules/matcher": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/matcher/-/matcher-5.0.0.tgz", - "integrity": "sha512-s2EMBOWtXFc8dgqvoAzKJXxNHibcdJMV0gwqKUaw9E2JBJuGUK7DrNKrA6g/i+v72TT16+6sVm5mS3thaMLQUw==", - "dev": true, - "license": "MIT", - "dependencies": { - "escape-string-regexp": "^5.0.0" - }, - "engines": { - "node": "^12.20.0 || ^14.13.1 || >=16.0.0" - }, - "funding": { - "url": 
"https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/matcher/node_modules/escape-string-regexp": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-5.0.0.tgz", - "integrity": "sha512-/veY75JbMK4j1yjvuUxuVsiS/hr/4iHs9FTT6cgTexxdE0Ly/glccBAkloH/DofkjRbZU3bnoj38mOmhkZ0lHw==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=12" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/mathml-tag-names": { "version": "2.1.3", "resolved": "https://registry.npmjs.org/mathml-tag-names/-/mathml-tag-names-2.1.3.tgz", @@ -4267,19 +3478,6 @@ "url": "https://github.com/sponsors/wooorm" } }, - "node_modules/md5-hex": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/md5-hex/-/md5-hex-3.0.1.tgz", - "integrity": "sha512-BUiRtTtV39LIJwinWBjqVsU9xhdnz7/i889V859IBFpuqGAj6LuOvHv5XLbgZ2R7ptJoJaEcxkv88/h25T7Ciw==", - "dev": true, - "license": "MIT", - "dependencies": { - "blueimp-md5": "^2.10.0" - }, - "engines": { - "node": ">=8" - } - }, "node_modules/mdast-util-from-markdown": { "version": "2.0.2", "resolved": "https://registry.npmjs.org/mdast-util-from-markdown/-/mdast-util-from-markdown-2.0.2.tgz", @@ -4440,22 +3638,6 @@ "dev": true, "license": "CC0-1.0" }, - "node_modules/memoize": { - "version": "10.1.0", - "resolved": "https://registry.npmjs.org/memoize/-/memoize-10.1.0.tgz", - "integrity": "sha512-MMbFhJzh4Jlg/poq1si90XRlTZRDHVqdlz2mPyGJ6kqMpyHUyVpDd5gpFAvVehW64+RA1eKE9Yt8aSLY7w2Kgg==", - "dev": true, - "license": "MIT", - "dependencies": { - "mimic-function": "^5.0.1" - }, - "engines": { - "node": ">=18" - }, - "funding": { - "url": "https://github.com/sindresorhus/memoize?sponsor=1" - } - }, "node_modules/meow": { "version": "13.2.0", "resolved": "https://registry.npmjs.org/meow/-/meow-13.2.0.tgz", @@ -4948,19 +4130,6 @@ "url": "https://github.com/sponsors/jonschlinkert" } }, - "node_modules/mimic-function": { - "version": "5.0.1", - "resolved": "https://registry.npmjs.org/mimic-function/-/mimic-function-5.0.1.tgz", - "integrity": "sha512-VP79XUPxV2CigYP3jWwAUFSku2aKqBH7uTAapFWCBqutsbmDo96KY5o8uh6U+/YSIn5OxJnXp73beVkpqMIGhA==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=18" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/minimatch": { "version": "3.1.5", "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz", @@ -4984,19 +4153,6 @@ "node": ">=16 || 14 >=14.17" } }, - "node_modules/minizlib": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/minizlib/-/minizlib-3.1.0.tgz", - "integrity": "sha512-KZxYo1BUkWD2TVFLr0MQoM8vUUigWD3LlD83a/75BqC+4qE0Hb1Vo5v1FgcfaNXvfXzr+5EhQ6ing/CaBijTlw==", - "dev": true, - "license": "MIT", - "dependencies": { - "minipass": "^7.1.2" - }, - "engines": { - "node": ">= 18" - } - }, "node_modules/mitt": { "version": "3.0.1", "resolved": "https://registry.npmjs.org/mitt/-/mitt-3.0.1.tgz", @@ -5026,9 +4182,9 @@ "license": "MIT" }, "node_modules/nanoid": { - "version": "5.1.5", - "resolved": "https://registry.npmjs.org/nanoid/-/nanoid-5.1.5.tgz", - "integrity": "sha512-Ir/+ZpE9fDsNH0hQ3C68uyThDXzYcim2EqcZ8zn8Chtt1iylPT9xXJB0kPCnqzgcEGikO9RxSrh63MsmVCU7Fw==", + "version": "5.1.11", + "resolved": "https://registry.npmjs.org/nanoid/-/nanoid-5.1.11.tgz", + "integrity": "sha512-v+KEsUv2ps74PaSKv0gHTxTCgMXOIfBEbaqa6w6ISIGC7ZsvHN4N9oJ8d4cmf0n5oTzQz2SLmThbQWhjd/8eKg==", "funding": [ { "type": "github", @@ -5050,39 +4206,6 @@ "dev": true, "license": 
"MIT" }, - "node_modules/node-fetch": { - "version": "2.7.0", - "resolved": "https://registry.npmjs.org/node-fetch/-/node-fetch-2.7.0.tgz", - "integrity": "sha512-c4FRfUm/dbcWZ7U+1Wq0AwCyFL+3nt2bEw05wfxSz+DWpWsitgmSgYmy2dQdWyKC1694ELPqMs/YzUSNozLt8A==", - "dev": true, - "license": "MIT", - "dependencies": { - "whatwg-url": "^5.0.0" - }, - "engines": { - "node": "4.x || >=6.0.0" - }, - "peerDependencies": { - "encoding": "^0.1.0" - }, - "peerDependenciesMeta": { - "encoding": { - "optional": true - } - } - }, - "node_modules/node-gyp-build": { - "version": "4.8.4", - "resolved": "https://registry.npmjs.org/node-gyp-build/-/node-gyp-build-4.8.4.tgz", - "integrity": "sha512-LA4ZjwlnUblHVgq0oBF3Jl/6h/Nvs5fzBLwdEF4nuxnFdsfajde4WfxtJr3CaiH+F6ewcIB/q4jQ4UzPyid+CQ==", - "dev": true, - "license": "MIT", - "bin": { - "node-gyp-build": "bin.js", - "node-gyp-build-optional": "optional.js", - "node-gyp-build-test": "build-test.js" - } - }, "node_modules/node-notifier": { "version": "10.0.1", "resolved": "https://registry.npmjs.org/node-notifier/-/node-notifier-10.0.1.tgz", @@ -5098,32 +4221,6 @@ "which": "^2.0.2" } }, - "node_modules/nofilter": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/nofilter/-/nofilter-3.1.0.tgz", - "integrity": "sha512-l2NNj07e9afPnhAhvgVrCD/oy2Ai1yfLpuo3EpiO1jFTsB4sFz6oIfAfSZyQzVpkZQ9xS8ZS5g1jCBgq4Hwo0g==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=12.19" - } - }, - "node_modules/nopt": { - "version": "8.1.0", - "resolved": "https://registry.npmjs.org/nopt/-/nopt-8.1.0.tgz", - "integrity": "sha512-ieGu42u/Qsa4TFktmaKEwM6MQH0pOWnaB3htzh0JRtx84+Mebc0cbZYN5bC+6WTZ4+77xrL9Pn5m7CV6VIkV7A==", - "dev": true, - "license": "ISC", - "dependencies": { - "abbrev": "^3.0.0" - }, - "bin": { - "nopt": "bin/nopt.js" - }, - "engines": { - "node": "^18.17.0 || >=20.5.0" - } - }, "node_modules/normalize-path": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/normalize-path/-/normalize-path-3.0.0.tgz", @@ -5184,36 +4281,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/p-map": { - "version": "7.0.3", - "resolved": "https://registry.npmjs.org/p-map/-/p-map-7.0.3.tgz", - "integrity": "sha512-VkndIv2fIB99swvQoA65bm+fsmt6UNdGeIB0oxBs+WhAhdh08QA04JXpI7rbB9r08/nkbysKoya9rtDERYOYMA==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=18" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/package-config": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/package-config/-/package-config-5.0.0.tgz", - "integrity": "sha512-GYTTew2slBcYdvRHqjhwaaydVMvn/qrGC323+nKclYioNSLTDUM/lGgtGTgyHVtYcozb+XkE8CNhwcraOmZ9Mg==", - "dev": true, - "license": "MIT", - "dependencies": { - "find-up-simple": "^1.0.0", - "load-json-file": "^7.0.1" - }, - "engines": { - "node": ">=18" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/package-json-from-dist": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/package-json-from-dist/-/package-json-from-dist-1.0.1.tgz", @@ -5288,19 +4355,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/parse-ms": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/parse-ms/-/parse-ms-4.0.0.tgz", - "integrity": "sha512-TXfryirbmq34y8QBwgqCVLi+8oA3oWx2eAnSn62ITyEhEYaWRlVZ2DvMM9eZbMs/RfxPu/PK/aBLyGj4IrqMHw==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=18" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, 
"node_modules/path-exists": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-4.0.0.tgz", @@ -5348,19 +4402,6 @@ "url": "https://opencollective.com/express" } }, - "node_modules/path-type": { - "version": "6.0.0", - "resolved": "https://registry.npmjs.org/path-type/-/path-type-6.0.0.tgz", - "integrity": "sha512-Vj7sf++t5pBD637NSfkxpHSMfWaeig5+DKWLhcqIYx6mWQz5hdJTGDVMQiJcw1ZYkhs7AazKDGpRVji1LJCZUQ==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=18" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/picocolors": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", @@ -5368,45 +4409,6 @@ "dev": true, "license": "ISC" }, - "node_modules/picomatch": { - "version": "4.0.4", - "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz", - "integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=12" - }, - "funding": { - "url": "https://github.com/sponsors/jonschlinkert" - } - }, - "node_modules/pirates": { - "version": "4.0.7", - "resolved": "https://registry.npmjs.org/pirates/-/pirates-4.0.7.tgz", - "integrity": "sha512-TfySrs/5nm8fQJDcBDuUng3VOUKsd7S+zqvbOTiGXHfxX4wK31ard+hoNuvkicM/2YFzlpDgABOevKSsB4G/FA==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 6" - } - }, - "node_modules/plur": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/plur/-/plur-5.1.0.tgz", - "integrity": "sha512-VP/72JeXqak2KiOzjgKtQen5y3IZHn+9GOuLDafPv0eXa47xq0At93XahYBs26MsifCQ4enGKwbjBTKgb9QJXg==", - "dev": true, - "license": "MIT", - "dependencies": { - "irregular-plurals": "^3.3.0" - }, - "engines": { - "node": "^12.20.0 || ^14.13.1 || >=16.0.0" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/postcss": { "version": "8.5.6", "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.6.tgz", @@ -5616,22 +4618,6 @@ } } }, - "node_modules/pretty-ms": { - "version": "9.2.0", - "resolved": "https://registry.npmjs.org/pretty-ms/-/pretty-ms-9.2.0.tgz", - "integrity": "sha512-4yf0QO/sllf/1zbZWYnvWw3NxCQwLXKzIj0G849LSufP15BXKM0rbD2Z3wVnkMfjdn/CB0Dpp444gYAACdsplg==", - "dev": true, - "license": "MIT", - "dependencies": { - "parse-ms": "^4.0.0" - }, - "engines": { - "node": ">=18" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/property-information": { "version": "7.1.0", "resolved": "https://registry.npmjs.org/property-information/-/property-information-7.1.0.tgz", @@ -5797,19 +4783,6 @@ "node": ">=0.10.0" } }, - "node_modules/resolve-cwd": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/resolve-cwd/-/resolve-cwd-3.0.0.tgz", - "integrity": "sha512-OrZaX2Mb+rJCpH/6CpSqt9xFVpN++x01XnN2ie9g6P5/3xelLAkXWVADpdz1IHD/KFfEXyE6V0U01OQ3UO2rEg==", - "dev": true, - "license": "MIT", - "dependencies": { - "resolve-from": "^5.0.0" - }, - "engines": { - "node": ">=8" - } - }, "node_modules/resolve-from": { "version": "5.0.0", "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-5.0.0.tgz", @@ -5820,6 +4793,16 @@ "node": ">=8" } }, + "node_modules/resolve-pkg-maps": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/resolve-pkg-maps/-/resolve-pkg-maps-1.0.0.tgz", + "integrity": "sha512-seS2Tj26TBVOC2NIc2rOe2y2ZO7efxITtLZcGSOnHHNOQ7CkiUBfw0Iw2ck6xkIhPwLhKNLS8BO+hEpngQlqzw==", + "dev": true, + 
"license": "MIT", + "funding": { + "url": "https://github.com/privatenumber/resolve-pkg-maps?sponsor=1" + } + }, "node_modules/reusify": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/reusify/-/reusify-1.1.0.tgz", @@ -5831,26 +4814,6 @@ "node": ">=0.10.0" } }, - "node_modules/rimraf": { - "version": "6.0.1", - "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-6.0.1.tgz", - "integrity": "sha512-9dkvaxAsk/xNXSJzMgFqqMCuFgt2+KsOFek3TMLfo8NCPfWpBmqwyNn5Y+NX56QUYfCtsyhF3ayiboEoUmJk/A==", - "dev": true, - "license": "ISC", - "dependencies": { - "glob": "^11.0.0", - "package-json-from-dist": "^1.0.0" - }, - "bin": { - "rimraf": "dist/esm/bin.mjs" - }, - "engines": { - "node": "20 || >=22" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, "node_modules/run-parallel": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz", @@ -5903,22 +4866,6 @@ "node": ">=10" } }, - "node_modules/serialize-error": { - "version": "7.0.1", - "resolved": "https://registry.npmjs.org/serialize-error/-/serialize-error-7.0.1.tgz", - "integrity": "sha512-8I8TjW5KMOKsZQTvoxjuSIa7foAwPWGOts+6o7sgjz41/qMD9VQHEDxi6PBvK2l0MXUmqZyNpUK+T2tQaaElvw==", - "dev": true, - "license": "MIT", - "dependencies": { - "type-fest": "^0.13.1" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/shebang-command": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", @@ -5962,208 +4909,24 @@ "url": "https://github.com/sponsors/isaacs" } }, - "node_modules/slash": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/slash/-/slash-5.1.0.tgz", - "integrity": "sha512-ZA6oR3T/pEyuqwMgAKT0/hAv8oAXckzbkmR0UkUosQ+Mc4RxGoJkRmwHgHufaenlyAgE1Mxgpdcrf75y6XcnDg==", + "node_modules/source-map-js": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz", + "integrity": "sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==", "dev": true, - "license": "MIT", + "license": "BSD-3-Clause", "engines": { - "node": ">=14.16" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" + "node": ">=0.10.0" } }, - "node_modules/slice-ansi": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/slice-ansi/-/slice-ansi-5.0.0.tgz", - "integrity": "sha512-FC+lgizVPfie0kkhqUScwRu1O/lF6NOgJmlCgK+/LYxDCTk8sGelYaHDhFcDN+Sn3Cv+3VSa4Byeo+IMCzpMgQ==", - "dev": true, + "node_modules/space-separated-tokens": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/space-separated-tokens/-/space-separated-tokens-2.0.2.tgz", + "integrity": "sha512-PEGlAwrG8yXGXRjW32fGbg66JAlOAwbObuqVoJpv/mRgoWDQfgH1wDPvtzWyUSNAXBGSk8h755YDbbcEy3SH2Q==", "license": "MIT", - "dependencies": { - "ansi-styles": "^6.0.0", - "is-fullwidth-code-point": "^4.0.0" - }, - "engines": { - "node": ">=12" - }, "funding": { - "url": "https://github.com/chalk/slice-ansi?sponsor=1" - } - }, - "node_modules/sock-daemon": { - "version": "1.4.2", - "resolved": "https://registry.npmjs.org/sock-daemon/-/sock-daemon-1.4.2.tgz", - "integrity": "sha512-IzbegWshWWR+UzQ7487mbdYNmfJ1jXUXQBUHooqtpylO+aW0vMVbFN2d2ug3CSPZ0wbG7ZTTGwpUuthIDFIOGg==", - "dev": true, - "license": "BlueOak-1.0.0", - "dependencies": { - "rimraf": "^5.0.5", - "signal-exit": "^4.1.0", - "socket-post-message": "^1.0.3" - }, - "engines": { - "node": "16 >=16.17.0 || 18 >= 18.6.0 || >=20" - } - }, - 
"node_modules/sock-daemon/node_modules/brace-expansion": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.1.0.tgz", - "integrity": "sha512-TN1kCZAgdgweJhWWpgKYrQaMNHcDULHkWwQIspdtjV4Y5aurRdZpjAqn6yX3FPqTA9ngHCc4hJxMAMgGfve85w==", - "dev": true, - "license": "MIT", - "dependencies": { - "balanced-match": "^1.0.0" - } - }, - "node_modules/sock-daemon/node_modules/glob": { - "version": "10.5.0", - "resolved": "https://registry.npmjs.org/glob/-/glob-10.5.0.tgz", - "integrity": "sha512-DfXN8DfhJ7NH3Oe7cFmu3NCu1wKbkReJ8TorzSAFbSKrlNaQSKfIzqYqVY8zlbs2NLBbWpRiU52GX2PbaBVNkg==", - "dev": true, - "license": "ISC", - "dependencies": { - "foreground-child": "^3.1.0", - "jackspeak": "^3.1.2", - "minimatch": "^9.0.4", - "minipass": "^7.1.2", - "package-json-from-dist": "^1.0.0", - "path-scurry": "^1.11.1" - }, - "bin": { - "glob": "dist/esm/bin.mjs" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, - "node_modules/sock-daemon/node_modules/jackspeak": { - "version": "3.4.3", - "resolved": "https://registry.npmjs.org/jackspeak/-/jackspeak-3.4.3.tgz", - "integrity": "sha512-OGlZQpz2yfahA/Rd1Y8Cd9SIEsqvXkLVoSw/cgwhnhFMDbsQFeZYoJJ7bIZBS9BcamUW96asq/npPWugM+RQBw==", - "dev": true, - "license": "BlueOak-1.0.0", - "dependencies": { - "@isaacs/cliui": "^8.0.2" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - }, - "optionalDependencies": { - "@pkgjs/parseargs": "^0.11.0" - } - }, - "node_modules/sock-daemon/node_modules/lru-cache": { - "version": "10.4.3", - "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-10.4.3.tgz", - "integrity": "sha512-JNAzZcXrCt42VGLuYz0zfAzDfAvJWW6AfYlDBQyDV5DClI2m5sAmK+OIO7s59XfsRsWHp02jAJrRadPRGTt6SQ==", - "dev": true, - "license": "ISC" - }, - "node_modules/sock-daemon/node_modules/minimatch": { - "version": "9.0.9", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.9.tgz", - "integrity": "sha512-OBwBN9AL4dqmETlpS2zasx+vTeWclWzkblfZk7KTA5j3jeOONz/tRCnZomUyvNg83wL5Zv9Ss6HMJXAgL8R2Yg==", - "dev": true, - "license": "ISC", - "dependencies": { - "brace-expansion": "^2.0.2" - }, - "engines": { - "node": ">=16 || 14 >=14.17" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, - "node_modules/sock-daemon/node_modules/path-scurry": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/path-scurry/-/path-scurry-1.11.1.tgz", - "integrity": "sha512-Xa4Nw17FS9ApQFJ9umLiJS4orGjm7ZzwUrwamcGQuHSzDyth9boKDaycYdDcZDuqYATXw4HFXgaqWTctW/v1HA==", - "dev": true, - "license": "BlueOak-1.0.0", - "dependencies": { - "lru-cache": "^10.2.0", - "minipass": "^5.0.0 || ^6.0.2 || ^7.0.0" - }, - "engines": { - "node": ">=16 || 14 >=14.18" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, - "node_modules/sock-daemon/node_modules/rimraf": { - "version": "5.0.10", - "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-5.0.10.tgz", - "integrity": "sha512-l0OE8wL34P4nJH/H2ffoaniAokM2qSmrtXHmlpvYr5AVVX8msAyW0l8NVJFDxlSK4u3Uh/f41cQheDVdnYijwQ==", - "dev": true, - "license": "ISC", - "dependencies": { - "glob": "^10.3.7" - }, - "bin": { - "rimraf": "dist/esm/bin.mjs" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, - "node_modules/socket-post-message": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/socket-post-message/-/socket-post-message-1.0.3.tgz", - "integrity": "sha512-UhJaB3xR2oF+HvddFOq2cBZi4zVKOHvdiBo+BaScNxsEUg3TLWSP8BkweKfe07kfH1thjn1hJR0af/w1EtBFjg==", - "dev": true - }, - 
"node_modules/source-map-js": { - "version": "1.2.1", - "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz", - "integrity": "sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==", - "dev": true, - "license": "BSD-3-Clause", - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/space-separated-tokens": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/space-separated-tokens/-/space-separated-tokens-2.0.2.tgz", - "integrity": "sha512-PEGlAwrG8yXGXRjW32fGbg66JAlOAwbObuqVoJpv/mRgoWDQfgH1wDPvtzWyUSNAXBGSk8h755YDbbcEy3SH2Q==", - "license": "MIT", - "funding": { - "type": "github", - "url": "https://github.com/sponsors/wooorm" - } - }, - "node_modules/sprintf-js": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/sprintf-js/-/sprintf-js-1.0.3.tgz", - "integrity": "sha512-D9cPgkvLlV3t3IzL0D0YLvGA9Ahk4PcvVwUbN0dSGr1aP0Nrt4AEnTUbuGvquEC0mA64Gqt1fzirlRs5ibXx8g==", - "dev": true, - "license": "BSD-3-Clause" - }, - "node_modules/stack-utils": { - "version": "2.0.6", - "resolved": "https://registry.npmjs.org/stack-utils/-/stack-utils-2.0.6.tgz", - "integrity": "sha512-XlkWvfIm6RmsWtNJx+uqtKLS8eqFbxUg0ZzLXqY0caEy9l7hruX8IpiDnjsLavoBgqCCR71TqWO8MaXYheJ3RQ==", - "dev": true, - "license": "MIT", - "dependencies": { - "escape-string-regexp": "^2.0.0" - }, - "engines": { - "node": ">=10" - } - }, - "node_modules/stack-utils/node_modules/escape-string-regexp": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-2.0.0.tgz", - "integrity": "sha512-UpzcLCXolUWcNu5HtVMHYdXJjArjsF9C0aNnquZYY4uW/Vu0miy5YoWvbV345HauVvcAUnpRuhMMcqTcGOY2+w==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=8" + "type": "github", + "url": "https://github.com/sponsors/wooorm" } }, "node_modules/string-argv": { @@ -6176,24 +4939,6 @@ "node": ">=0.6.19" } }, - "node_modules/string-width": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/string-width/-/string-width-7.2.0.tgz", - "integrity": "sha512-tsaTIkKW9b4N+AEj+SVA+WhJzV7/zMhcSu78mLKWSk7cXMOSHsBKFWUs0fWwq8QyK3MgJBQRX6Gbi4kYbdvGkQ==", - "dev": true, - "license": "MIT", - "dependencies": { - "emoji-regex": "^10.3.0", - "get-east-asian-width": "^1.0.0", - "strip-ansi": "^7.1.0" - }, - "engines": { - "node": ">=18" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/string-width-cjs": { "name": "string-width", "version": "4.2.3", @@ -6576,46 +5321,6 @@ "node": "^14.17.0 || ^16.13.0 || >=18.0.0" } }, - "node_modules/supertap": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/supertap/-/supertap-3.0.1.tgz", - "integrity": "sha512-u1ZpIBCawJnO+0QePsEiOknOfCRq0yERxiAchT0i4li0WHNUJbf0evXXSXOcCAR4M8iMDoajXYmstm/qO81Isw==", - "dev": true, - "license": "MIT", - "dependencies": { - "indent-string": "^5.0.0", - "js-yaml": "^3.14.1", - "serialize-error": "^7.0.1", - "strip-ansi": "^7.0.1" - }, - "engines": { - "node": "^12.20.0 || ^14.13.1 || >=16.0.0" - } - }, - "node_modules/supertap/node_modules/argparse": { - "version": "1.0.10", - "resolved": "https://registry.npmjs.org/argparse/-/argparse-1.0.10.tgz", - "integrity": "sha512-o5Roy6tNG4SL/FOkCAN6RzjiakZS25RLYFrcMttJqbdd8BWrnA+fGz57iN5Pb06pvBGvl5gQ0B48dJlslXvoTg==", - "dev": true, - "license": "MIT", - "dependencies": { - "sprintf-js": "~1.0.2" - } - }, - "node_modules/supertap/node_modules/js-yaml": { - "version": "3.14.2", - "resolved": 
"https://registry.npmjs.org/js-yaml/-/js-yaml-3.14.2.tgz", - "integrity": "sha512-PMSmkqxr106Xa156c2M265Z+FTrPl+oxd/rgOQy2tijQeK5TxQ43psO1ZCwhVOSdnn+RzkzlRz/eY4BgJBYVpg==", - "dev": true, - "license": "MIT", - "dependencies": { - "argparse": "^1.0.7", - "esprima": "^4.0.0" - }, - "bin": { - "js-yaml": "bin/js-yaml.js" - } - }, "node_modules/supports-color": { "version": "7.2.0", "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", @@ -6760,183 +5465,626 @@ "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", "dev": true, "license": "MIT", - "dependencies": { - "emoji-regex": "^8.0.0", - "is-fullwidth-code-point": "^3.0.0", - "strip-ansi": "^6.0.1" - }, + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/table/node_modules/strip-ansi": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/temp-dir": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/temp-dir/-/temp-dir-3.0.0.tgz", + "integrity": "sha512-nHc6S/bwIilKHNRgK/3jlhDoIHcp45YgyiwcAk46Tr0LfEqGBVpmiAyuiuxeVE44m3mXnEeVhaipLOEWmH+Njw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=14.16" + } + }, + "node_modules/tempy": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/tempy/-/tempy-3.1.0.tgz", + "integrity": "sha512-7jDLIdD2Zp0bDe5r3D2qtkd1QOCacylBuL7oa4udvN6v2pqr4+LcCr67C8DR1zkpaZ8XosF5m1yQSabKAW6f2g==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-stream": "^3.0.0", + "temp-dir": "^3.0.0", + "type-fest": "^2.12.2", + "unique-string": "^3.0.0" + }, + "engines": { + "node": ">=14.16" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/tempy/node_modules/type-fest": { + "version": "2.19.0", + "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-2.19.0.tgz", + "integrity": "sha512-RAH822pAdBgcNMAfWnCBU3CFZcfZ/i1eZjwFU/dsLKumyuuP3niueg2UAukXYF0E2AAoc82ZSSf9J0WQBinzHA==", + "dev": true, + "license": "(MIT OR CC0-1.0)", + "engines": { + "node": ">=12.20" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/to-regex-range": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", + "integrity": "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-number": "^7.0.0" + }, + "engines": { + "node": ">=8.0" + } + }, + "node_modules/trim-lines": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/trim-lines/-/trim-lines-3.0.1.tgz", + "integrity": "sha512-kRj8B+YHZCc9kQYdWfJB2/oUl9rA99qbowYYBtr4ui4mZyAQ2JpvVBd/6U2YloATfqBhBTSMhTpgBHtU0Mf3Rg==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/trough": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/trough/-/trough-2.2.0.tgz", + "integrity": "sha512-tmMpK00BjZiUyVyvrBK7knerNgmgvcV/KLVyuma/SC+TQN167GrMRciANTz09+k3zW8L8t60jWO1GpfkZdjTaw==", + "license": "MIT", + "funding": 
{ + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/ts-api-utils": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/ts-api-utils/-/ts-api-utils-2.1.0.tgz", + "integrity": "sha512-CUgTZL1irw8u29bzrOD/nH85jqyc74D6SshFgujOIA7osm2Rz7dYH77agkx7H4FBNxDq7Cjf+IjaX/8zwFW+ZQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18.12" + }, + "peerDependencies": { + "typescript": ">=4.8.4" + } + }, + "node_modules/tslib": { + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "license": "0BSD" + }, + "node_modules/tsx": { + "version": "4.21.0", + "resolved": "https://registry.npmjs.org/tsx/-/tsx-4.21.0.tgz", + "integrity": "sha512-5C1sg4USs1lfG0GFb2RLXsdpXqBSEhAaA/0kPL01wxzpMqLILNxIxIOKiILz+cdg/pLnOUxFYOR5yhHU666wbw==", + "dev": true, + "license": "MIT", + "dependencies": { + "esbuild": "~0.27.0", + "get-tsconfig": "^4.7.5" + }, + "bin": { + "tsx": "dist/cli.mjs" + }, + "engines": { + "node": ">=18.0.0" + }, + "optionalDependencies": { + "fsevents": "~2.3.3" + } + }, + "node_modules/tsx/node_modules/@esbuild/aix-ppc64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.27.7.tgz", + "integrity": "sha512-EKX3Qwmhz1eMdEJokhALr0YiD0lhQNwDqkPYyPhiSwKrh7/4KRjQc04sZ8db+5DVVnZ1LmbNDI1uAMPEUBnQPg==", + "cpu": [ + "ppc64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "aix" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/tsx/node_modules/@esbuild/android-arm": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm/-/android-arm-0.27.7.tgz", + "integrity": "sha512-jbPXvB4Yj2yBV7HUfE2KHe4GJX51QplCN1pGbYjvsyCZbQmies29EoJbkEc+vYuU5o45AfQn37vZlyXy4YJ8RQ==", + "cpu": [ + "arm" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/tsx/node_modules/@esbuild/android-arm64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm64/-/android-arm64-0.27.7.tgz", + "integrity": "sha512-62dPZHpIXzvChfvfLJow3q5dDtiNMkwiRzPylSCfriLvZeq0a1bWChrGx/BbUbPwOrsWKMn8idSllklzBy+dgQ==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/tsx/node_modules/@esbuild/android-x64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/android-x64/-/android-x64-0.27.7.tgz", + "integrity": "sha512-x5VpMODneVDb70PYV2VQOmIUUiBtY3D3mPBG8NxVk5CogneYhkR7MmM3yR/uMdITLrC1ml/NV1rj4bMJuy9MCg==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/tsx/node_modules/@esbuild/darwin-arm64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-arm64/-/darwin-arm64-0.27.7.tgz", + "integrity": "sha512-5lckdqeuBPlKUwvoCXIgI2D9/ABmPq3Rdp7IfL70393YgaASt7tbju3Ac+ePVi3KDH6N2RqePfHnXkaDtY9fkw==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/tsx/node_modules/@esbuild/darwin-x64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-x64/-/darwin-x64-0.27.7.tgz", + "integrity": 
"sha512-rYnXrKcXuT7Z+WL5K980jVFdvVKhCHhUwid+dDYQpH+qu+TefcomiMAJpIiC2EM3Rjtq0sO3StMV/+3w3MyyqQ==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/tsx/node_modules/@esbuild/freebsd-arm64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-arm64/-/freebsd-arm64-0.27.7.tgz", + "integrity": "sha512-B48PqeCsEgOtzME2GbNM2roU29AMTuOIN91dsMO30t+Ydis3z/3Ngoj5hhnsOSSwNzS+6JppqWsuhTp6E82l2w==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/tsx/node_modules/@esbuild/freebsd-x64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-x64/-/freebsd-x64-0.27.7.tgz", + "integrity": "sha512-jOBDK5XEjA4m5IJK3bpAQF9/Lelu/Z9ZcdhTRLf4cajlB+8VEhFFRjWgfy3M1O4rO2GQ/b2dLwCUGpiF/eATNQ==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/tsx/node_modules/@esbuild/linux-arm": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm/-/linux-arm-0.27.7.tgz", + "integrity": "sha512-RkT/YXYBTSULo3+af8Ib0ykH8u2MBh57o7q/DAs3lTJlyVQkgQvlrPTnjIzzRPQyavxtPtfg0EopvDyIt0j1rA==", + "cpu": [ + "arm" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/tsx/node_modules/@esbuild/linux-arm64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm64/-/linux-arm64-0.27.7.tgz", + "integrity": "sha512-RZPHBoxXuNnPQO9rvjh5jdkRmVizktkT7TCDkDmQ0W2SwHInKCAV95GRuvdSvA7w4VMwfCjUiPwDi0ZO6Nfe9A==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/tsx/node_modules/@esbuild/linux-ia32": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ia32/-/linux-ia32-0.27.7.tgz", + "integrity": "sha512-GA48aKNkyQDbd3KtkplYWT102C5sn/EZTY4XROkxONgruHPU72l+gW+FfF8tf2cFjeHaRbWpOYa/uRBz/Xq1Pg==", + "cpu": [ + "ia32" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/tsx/node_modules/@esbuild/linux-loong64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/linux-loong64/-/linux-loong64-0.27.7.tgz", + "integrity": "sha512-a4POruNM2oWsD4WKvBSEKGIiWQF8fZOAsycHOt6JBpZ+JN2n2JH9WAv56SOyu9X5IqAjqSIPTaJkqN8F7XOQ5Q==", + "cpu": [ + "loong64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/tsx/node_modules/@esbuild/linux-mips64el": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/linux-mips64el/-/linux-mips64el-0.27.7.tgz", + "integrity": "sha512-KabT5I6StirGfIz0FMgl1I+R1H73Gp0ofL9A3nG3i/cYFJzKHhouBV5VWK1CSgKvVaG4q1RNpCTR2LuTVB3fIw==", + "cpu": [ + "mips64el" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/tsx/node_modules/@esbuild/linux-ppc64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ppc64/-/linux-ppc64-0.27.7.tgz", + "integrity": 
"sha512-gRsL4x6wsGHGRqhtI+ifpN/vpOFTQtnbsupUF5R5YTAg+y/lKelYR1hXbnBdzDjGbMYjVJLJTd2OFmMewAgwlQ==", + "cpu": [ + "ppc64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/tsx/node_modules/@esbuild/linux-riscv64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/linux-riscv64/-/linux-riscv64-0.27.7.tgz", + "integrity": "sha512-hL25LbxO1QOngGzu2U5xeXtxXcW+/GvMN3ejANqXkxZ/opySAZMrc+9LY/WyjAan41unrR3YrmtTsUpwT66InQ==", + "cpu": [ + "riscv64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/tsx/node_modules/@esbuild/linux-s390x": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/linux-s390x/-/linux-s390x-0.27.7.tgz", + "integrity": "sha512-2k8go8Ycu1Kb46vEelhu1vqEP+UeRVj2zY1pSuPdgvbd5ykAw82Lrro28vXUrRmzEsUV0NzCf54yARIK8r0fdw==", + "cpu": [ + "s390x" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], "engines": { - "node": ">=8" + "node": ">=18" } }, - "node_modules/table/node_modules/strip-ansi": { - "version": "6.0.1", - "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", - "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "node_modules/tsx/node_modules/@esbuild/linux-x64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/linux-x64/-/linux-x64-0.27.7.tgz", + "integrity": "sha512-hzznmADPt+OmsYzw1EE33ccA+HPdIqiCRq7cQeL1Jlq2gb1+OyWBkMCrYGBJ+sxVzve2ZJEVeePbLM2iEIZSxA==", + "cpu": [ + "x64" + ], "dev": true, "license": "MIT", - "dependencies": { - "ansi-regex": "^5.0.1" - }, + "optional": true, + "os": [ + "linux" + ], "engines": { - "node": ">=8" + "node": ">=18" } }, - "node_modules/tar": { - "version": "7.5.11", - "resolved": "https://registry.npmjs.org/tar/-/tar-7.5.11.tgz", - "integrity": "sha512-ChjMH33/KetonMTAtpYdgUFr0tbz69Fp2v7zWxQfYZX4g5ZN2nOBXm1R2xyA+lMIKrLKIoKAwFj93jE/avX9cQ==", + "node_modules/tsx/node_modules/@esbuild/netbsd-arm64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/netbsd-arm64/-/netbsd-arm64-0.27.7.tgz", + "integrity": "sha512-b6pqtrQdigZBwZxAn1UpazEisvwaIDvdbMbmrly7cDTMFnw/+3lVxxCTGOrkPVnsYIosJJXAsILG9XcQS+Yu6w==", + "cpu": [ + "arm64" + ], "dev": true, - "license": "BlueOak-1.0.0", - "dependencies": { - "@isaacs/fs-minipass": "^4.0.0", - "chownr": "^3.0.0", - "minipass": "^7.1.2", - "minizlib": "^3.1.0", - "yallist": "^5.0.0" - }, + "license": "MIT", + "optional": true, + "os": [ + "netbsd" + ], "engines": { "node": ">=18" } }, - "node_modules/temp-dir": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/temp-dir/-/temp-dir-3.0.0.tgz", - "integrity": "sha512-nHc6S/bwIilKHNRgK/3jlhDoIHcp45YgyiwcAk46Tr0LfEqGBVpmiAyuiuxeVE44m3mXnEeVhaipLOEWmH+Njw==", + "node_modules/tsx/node_modules/@esbuild/netbsd-x64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/netbsd-x64/-/netbsd-x64-0.27.7.tgz", + "integrity": "sha512-OfatkLojr6U+WN5EDYuoQhtM+1xco+/6FSzJJnuWiUw5eVcicbyK3dq5EeV/QHT1uy6GoDhGbFpprUiHUYggrw==", + "cpu": [ + "x64" + ], "dev": true, "license": "MIT", + "optional": true, + "os": [ + "netbsd" + ], "engines": { - "node": ">=14.16" + "node": ">=18" } }, - "node_modules/tempy": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/tempy/-/tempy-3.1.0.tgz", - "integrity": 
"sha512-7jDLIdD2Zp0bDe5r3D2qtkd1QOCacylBuL7oa4udvN6v2pqr4+LcCr67C8DR1zkpaZ8XosF5m1yQSabKAW6f2g==", + "node_modules/tsx/node_modules/@esbuild/openbsd-arm64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/openbsd-arm64/-/openbsd-arm64-0.27.7.tgz", + "integrity": "sha512-AFuojMQTxAz75Fo8idVcqoQWEHIXFRbOc1TrVcFSgCZtQfSdc1RXgB3tjOn/krRHENUB4j00bfGjyl2mJrU37A==", + "cpu": [ + "arm64" + ], "dev": true, "license": "MIT", - "dependencies": { - "is-stream": "^3.0.0", - "temp-dir": "^3.0.0", - "type-fest": "^2.12.2", - "unique-string": "^3.0.0" - }, + "optional": true, + "os": [ + "openbsd" + ], "engines": { - "node": ">=14.16" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" + "node": ">=18" } }, - "node_modules/tempy/node_modules/type-fest": { - "version": "2.19.0", - "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-2.19.0.tgz", - "integrity": "sha512-RAH822pAdBgcNMAfWnCBU3CFZcfZ/i1eZjwFU/dsLKumyuuP3niueg2UAukXYF0E2AAoc82ZSSf9J0WQBinzHA==", + "node_modules/tsx/node_modules/@esbuild/openbsd-x64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/openbsd-x64/-/openbsd-x64-0.27.7.tgz", + "integrity": "sha512-+A1NJmfM8WNDv5CLVQYJ5PshuRm/4cI6WMZRg1by1GwPIQPCTs1GLEUHwiiQGT5zDdyLiRM/l1G0Pv54gvtKIg==", + "cpu": [ + "x64" + ], "dev": true, - "license": "(MIT OR CC0-1.0)", + "license": "MIT", + "optional": true, + "os": [ + "openbsd" + ], "engines": { - "node": ">=12.20" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" + "node": ">=18" } }, - "node_modules/time-zone": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/time-zone/-/time-zone-1.0.0.tgz", - "integrity": "sha512-TIsDdtKo6+XrPtiTm1ssmMngN1sAhyKnTO2kunQWqNPWIVvCm15Wmw4SWInwTVgJ5u/Tr04+8Ei9TNcw4x4ONA==", + "node_modules/tsx/node_modules/@esbuild/openharmony-arm64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/openharmony-arm64/-/openharmony-arm64-0.27.7.tgz", + "integrity": "sha512-+KrvYb/C8zA9CU/g0sR6w2RBw7IGc5J2BPnc3dYc5VJxHCSF1yNMxTV5LQ7GuKteQXZtspjFbiuW5/dOj7H4Yw==", + "cpu": [ + "arm64" + ], "dev": true, "license": "MIT", + "optional": true, + "os": [ + "openharmony" + ], "engines": { - "node": ">=4" + "node": ">=18" } }, - "node_modules/to-regex-range": { - "version": "5.0.1", - "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", - "integrity": "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==", + "node_modules/tsx/node_modules/@esbuild/sunos-x64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/sunos-x64/-/sunos-x64-0.27.7.tgz", + "integrity": "sha512-ikktIhFBzQNt/QDyOL580ti9+5mL/YZeUPKU2ivGtGjdTYoqz6jObj6nOMfhASpS4GU4Q/Clh1QtxWAvcYKamA==", + "cpu": [ + "x64" + ], "dev": true, "license": "MIT", - "dependencies": { - "is-number": "^7.0.0" - }, + "optional": true, + "os": [ + "sunos" + ], "engines": { - "node": ">=8.0" + "node": ">=18" } }, - "node_modules/tr46": { - "version": "0.0.3", - "resolved": "https://registry.npmjs.org/tr46/-/tr46-0.0.3.tgz", - "integrity": "sha512-N3WMsuqV66lT30CrXNbEjx4GEwlow3v6rr4mCcv6prnfwhS01rkgyFdjPNBYd9br7LpXV1+Emh01fHnq2Gdgrw==", + "node_modules/tsx/node_modules/@esbuild/win32-arm64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/win32-arm64/-/win32-arm64-0.27.7.tgz", + "integrity": "sha512-7yRhbHvPqSpRUV7Q20VuDwbjW5kIMwTHpptuUzV+AA46kiPze5Z7qgt6CLCK3pWFrHeNfDd1VKgyP4O+ng17CA==", + "cpu": [ + "arm64" + ], "dev": true, 
- "license": "MIT" - }, - "node_modules/trim-lines": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/trim-lines/-/trim-lines-3.0.1.tgz", - "integrity": "sha512-kRj8B+YHZCc9kQYdWfJB2/oUl9rA99qbowYYBtr4ui4mZyAQ2JpvVBd/6U2YloATfqBhBTSMhTpgBHtU0Mf3Rg==", "license": "MIT", - "funding": { - "type": "github", - "url": "https://github.com/sponsors/wooorm" + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=18" } }, - "node_modules/trough": { - "version": "2.2.0", - "resolved": "https://registry.npmjs.org/trough/-/trough-2.2.0.tgz", - "integrity": "sha512-tmMpK00BjZiUyVyvrBK7knerNgmgvcV/KLVyuma/SC+TQN167GrMRciANTz09+k3zW8L8t60jWO1GpfkZdjTaw==", + "node_modules/tsx/node_modules/@esbuild/win32-ia32": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/win32-ia32/-/win32-ia32-0.27.7.tgz", + "integrity": "sha512-SmwKXe6VHIyZYbBLJrhOoCJRB/Z1tckzmgTLfFYOfpMAx63BJEaL9ExI8x7v0oAO3Zh6D/Oi1gVxEYr5oUCFhw==", + "cpu": [ + "ia32" + ], + "dev": true, "license": "MIT", - "funding": { - "type": "github", - "url": "https://github.com/sponsors/wooorm" + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=18" } }, - "node_modules/ts-api-utils": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/ts-api-utils/-/ts-api-utils-2.1.0.tgz", - "integrity": "sha512-CUgTZL1irw8u29bzrOD/nH85jqyc74D6SshFgujOIA7osm2Rz7dYH77agkx7H4FBNxDq7Cjf+IjaX/8zwFW+ZQ==", + "node_modules/tsx/node_modules/@esbuild/win32-x64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/win32-x64/-/win32-x64-0.27.7.tgz", + "integrity": "sha512-56hiAJPhwQ1R4i+21FVF7V8kSD5zZTdHcVuRFMW0hn753vVfQN8xlx4uOPT4xoGH0Z/oVATuR82AiqSTDIpaHg==", + "cpu": [ + "x64" + ], "dev": true, "license": "MIT", + "optional": true, + "os": [ + "win32" + ], "engines": { - "node": ">=18.12" - }, - "peerDependencies": { - "typescript": ">=4.8.4" + "node": ">=18" } }, - "node_modules/tsimp": { - "version": "2.0.12", - "resolved": "https://registry.npmjs.org/tsimp/-/tsimp-2.0.12.tgz", - "integrity": "sha512-0XbhMfDB1BlN4iuheUaCUVB2iAjWb9z6Ik/6WcxREc4MhjYmkScK+CRNf34wkDO8wMvmFBb0lYdrd8H44g9yjg==", + "node_modules/tsx/node_modules/esbuild": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.27.7.tgz", + "integrity": "sha512-IxpibTjyVnmrIQo5aqNpCgoACA/dTKLTlhMHihVHhdkxKyPO1uBBthumT0rdHmcsk9uMonIWS0m4FljWzILh3w==", "dev": true, - "license": "BlueOak-1.0.0", - "dependencies": { - "@isaacs/cached": "^1.0.1", - "@isaacs/catcher": "^1.0.4", - "foreground-child": "^3.1.1", - "mkdirp": "^3.0.1", - "pirates": "^4.0.6", - "rimraf": "^6.0.1", - "signal-exit": "^4.1.0", - "sock-daemon": "^1.4.2", - "walk-up-path": "^4.0.0" - }, + "hasInstallScript": true, + "license": "MIT", "bin": { - "tsimp": "dist/esm/bin.mjs" + "esbuild": "bin/esbuild" }, "engines": { - "node": "16 >=16.17.0 || 18 >= 18.6.0 || >=20" + "node": ">=18" }, - "peerDependencies": { - "typescript": "^5.1.0" + "optionalDependencies": { + "@esbuild/aix-ppc64": "0.27.7", + "@esbuild/android-arm": "0.27.7", + "@esbuild/android-arm64": "0.27.7", + "@esbuild/android-x64": "0.27.7", + "@esbuild/darwin-arm64": "0.27.7", + "@esbuild/darwin-x64": "0.27.7", + "@esbuild/freebsd-arm64": "0.27.7", + "@esbuild/freebsd-x64": "0.27.7", + "@esbuild/linux-arm": "0.27.7", + "@esbuild/linux-arm64": "0.27.7", + "@esbuild/linux-ia32": "0.27.7", + "@esbuild/linux-loong64": "0.27.7", + "@esbuild/linux-mips64el": "0.27.7", + "@esbuild/linux-ppc64": "0.27.7", + "@esbuild/linux-riscv64": "0.27.7", + 
"@esbuild/linux-s390x": "0.27.7", + "@esbuild/linux-x64": "0.27.7", + "@esbuild/netbsd-arm64": "0.27.7", + "@esbuild/netbsd-x64": "0.27.7", + "@esbuild/openbsd-arm64": "0.27.7", + "@esbuild/openbsd-x64": "0.27.7", + "@esbuild/openharmony-arm64": "0.27.7", + "@esbuild/sunos-x64": "0.27.7", + "@esbuild/win32-arm64": "0.27.7", + "@esbuild/win32-ia32": "0.27.7", + "@esbuild/win32-x64": "0.27.7" } }, - "node_modules/tslib": { - "version": "2.8.1", - "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", - "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", - "license": "0BSD" - }, "node_modules/type-check": { "version": "0.4.0", "resolved": "https://registry.npmjs.org/type-check/-/type-check-0.4.0.tgz", @@ -6950,19 +6098,6 @@ "node": ">= 0.8.0" } }, - "node_modules/type-fest": { - "version": "0.13.1", - "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.13.1.tgz", - "integrity": "sha512-34R7HTnG0XIJcBSn5XhDd7nNFPRcXYRZrBB2O2jdKqYODldSzBAqzsWoZYYvduky73toYS/ESqxPvkDf/F0XMg==", - "dev": true, - "license": "(MIT OR CC0-1.0)", - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/typed-css-modules": { "version": "0.9.1", "resolved": "https://registry.npmjs.org/typed-css-modules/-/typed-css-modules-0.9.1.tgz", @@ -7213,18 +6348,12 @@ "typescript": ">=4.8.4 <6.0.0" } }, - "node_modules/unicorn-magic": { - "version": "0.3.0", - "resolved": "https://registry.npmjs.org/unicorn-magic/-/unicorn-magic-0.3.0.tgz", - "integrity": "sha512-+QBBXBCvifc56fsbuxZQ6Sic3wqqc3WWaqxs58gvJrcOuN83HGTCwz3oS5phzU9LthRNE9VrJCFCLUgHeeFnfA==", + "node_modules/undici-types": { + "version": "6.21.0", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.21.0.tgz", + "integrity": "sha512-iwDZqg0QAGrg9Rav5H4n0M64c3mkR59cJ6wQp+7C4nI0gsmExaedaYLNO44eT4AtBBwjbTiGPMlt2Md0T9H9JQ==", "dev": true, - "license": "MIT", - "engines": { - "node": ">=18" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } + "license": "MIT" }, "node_modules/unified": { "version": "11.0.5", @@ -7413,44 +6542,6 @@ "integrity": "sha512-dpojBhNsCNN7T82Tm7k26A6G9ML3NkhDsnw9n/eoxSRlVBB4CEtIQ/KTCLI2Fwf3ataSXRhYFkQi3SlnFwPvPQ==", "license": "MIT" }, - "node_modules/walk-up-path": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/walk-up-path/-/walk-up-path-4.0.0.tgz", - "integrity": "sha512-3hu+tD8YzSLGuFYtPRb48vdhKMi0KQV5sn+uWr8+7dMEq/2G/dtLrdDinkLjqq5TIbIBjYJ4Ax/n3YiaW7QM8A==", - "dev": true, - "license": "ISC", - "engines": { - "node": "20 || >=22" - } - }, - "node_modules/webidl-conversions": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/webidl-conversions/-/webidl-conversions-3.0.1.tgz", - "integrity": "sha512-2JAn3z8AR6rjK8Sm8orRC0h/bcl/DqL7tRPdGZ4I1CjdF+EaMLmYxBHyXuKL849eucPFhvBoxMsflfOb8kxaeQ==", - "dev": true, - "license": "BSD-2-Clause" - }, - "node_modules/well-known-symbols": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/well-known-symbols/-/well-known-symbols-2.0.0.tgz", - "integrity": "sha512-ZMjC3ho+KXo0BfJb7JgtQ5IBuvnShdlACNkKkdsqBmYw3bPAaJfPeYUo6tLUaT5tG/Gkh7xkpBhKRQ9e7pyg9Q==", - "dev": true, - "license": "ISC", - "engines": { - "node": ">=6" - } - }, - "node_modules/whatwg-url": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/whatwg-url/-/whatwg-url-5.0.0.tgz", - "integrity": "sha512-saE57nupxk6v3HY35+jzBwYa0rKSy0XR8JSxZPwgLr7ys0IBzhGviA1/TUGJLmSVqs8pb9AnvICXEuOHLprYTw==", - 
"dev": true, - "license": "MIT", - "dependencies": { - "tr46": "~0.0.3", - "webidl-conversions": "^3.0.0" - } - }, "node_modules/which": { "version": "2.0.2", "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", @@ -7624,20 +6715,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/write-file-atomic": { - "version": "6.0.0", - "resolved": "https://registry.npmjs.org/write-file-atomic/-/write-file-atomic-6.0.0.tgz", - "integrity": "sha512-GmqrO8WJ1NuzJ2DrziEI2o57jKAVIQNf8a18W3nCYU3H7PNWqCCVTeH6/NQE93CIllIgQS98rrmVkYgTX9fFJQ==", - "dev": true, - "license": "ISC", - "dependencies": { - "imurmurhash": "^0.1.4", - "signal-exit": "^4.0.1" - }, - "engines": { - "node": "^18.17.0 || >=20.5.0" - } - }, "node_modules/y18n": { "version": "5.0.8", "resolved": "https://registry.npmjs.org/y18n/-/y18n-5.0.8.tgz", @@ -7648,16 +6725,6 @@ "node": ">=10" } }, - "node_modules/yallist": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/yallist/-/yallist-5.0.0.tgz", - "integrity": "sha512-YgvUTfwqyc7UXVMrB+SImsVYSmTS8X/tSrtdNZMImM+n7+QTriRXyXim0mBrTXNeqzVF0KWGgHPeiyViFFrNDw==", - "dev": true, - "license": "BlueOak-1.0.0", - "engines": { - "node": ">=18" - } - }, "node_modules/yargs": { "version": "17.7.2", "resolved": "https://registry.npmjs.org/yargs/-/yargs-17.7.2.tgz", @@ -7773,6 +6840,26 @@ "type": "github", "url": "https://github.com/sponsors/wooorm" } + }, + "paddler_client_javascript": { + "name": "@intentee/paddler-client", + "version": "3.1.2", + "license": "Apache-2.0", + "dependencies": { + "nanoid": "^5.1.11" + }, + "devDependencies": { + "@types/node": "^22", + "tsx": "^4.21.0", + "typescript": "^5.9.2" + }, + "engines": { + "node": ">=22" + }, + "peerDependencies": { + "rxjs": "^7.8", + "zod": "^4" + } } } } diff --git a/package.json b/package.json index d2c84c10..cdf98de8 100644 --- a/package.json +++ b/package.json @@ -1,17 +1,8 @@ { - "ava": { - "extensions": { - "ts": "module" - }, - "nodeArguments": [ - "--import=tsimp" - ] - }, "devDependencies": { "@types/hotwired__turbo": "^8", "@types/react": "^19.1.10", "@types/react-dom": "^19.1.7", - "ava": "^6.4.1", "esbuild": "^0.25.8", "eslint": "^9.33.0", "eslint-plugin-react-hooks": "^5.2.0", @@ -25,7 +16,6 @@ "stylelint": "^16.23.1", "stylelint-config-recommended": "^17.0.0", "tempy": "^3.1.0", - "tsimp": "^2.0.12", "tslib": "^2.8.1", "typed-css-modules": "^0.9.1", "typescript": "^5.9.2", @@ -35,7 +25,7 @@ "@codemirror/lang-jinja": "^6.0.0", "@uiw/react-codemirror": "^4.24.2", "clsx": "^2.1.1", - "nanoid": "^5.1.5", + "nanoid": "^5.1.11", "path-to-regexp": "^8.4.0", "react": "^19.1.1", "react-dom": "^19.1.1", @@ -46,10 +36,12 @@ }, "license": "Apache-2.0", "name": "paddler", + "private": true, "overrides": { "node-notifier": { "uuid": "^14.0.0" } }, - "type": "module" + "type": "module", + "workspaces": ["paddler_client_javascript"] } diff --git a/paddler/Cargo.toml b/paddler/Cargo.toml index 21f7ec99..603b3f6c 100644 --- a/paddler/Cargo.toml +++ b/paddler/Cargo.toml @@ -71,6 +71,7 @@ vulkan = ["llama-cpp-bindings/vulkan"] [dev-dependencies] tempfile = { workspace = true } +tokio-test = { workspace = true } [lints] workspace = true diff --git a/paddler/src/agent/continue_from_conversation_history_request.rs b/paddler/src/agent/continue_from_conversation_history_request.rs index af68c877..1b16aa6b 100644 --- a/paddler/src/agent/continue_from_conversation_history_request.rs +++ b/paddler/src/agent/continue_from_conversation_history_request.rs @@ -1,14 +1,19 @@ +use std::sync::Arc; + use 
paddler_types::generated_token_result::GeneratedTokenResult;
-use paddler_types::request_params::ContinueFromConversationHistoryParams;
+use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams;
 use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema;
 use tokio::sync::mpsc;
 
 use crate::agent::from_request_params::FromRequestParams;
+use crate::agent::slot_guard::SlotGuard;
+use crate::slot_aggregated_status::SlotAggregatedStatus;
 
 pub struct ContinueFromConversationHistoryRequest {
     pub generate_tokens_stop_rx: mpsc::UnboundedReceiver<()>,
     pub generated_tokens_tx: mpsc::UnboundedSender<GeneratedTokenResult>,
     pub params: ContinueFromConversationHistoryParams,
+    pub slot_guard: SlotGuard,
 }
 
 impl FromRequestParams for ContinueFromConversationHistoryRequest {
@@ -19,11 +24,13 @@ impl FromRequestParams for ContinueFromConversationHistoryRequest {
         params: Self::RequestParams,
         generated_tokens_tx: mpsc::UnboundedSender<GeneratedTokenResult>,
         generate_tokens_stop_rx: mpsc::UnboundedReceiver<()>,
+        slot_aggregated_status: Arc<SlotAggregatedStatus>,
     ) -> Self {
         Self {
             generate_tokens_stop_rx,
             generated_tokens_tx,
             params,
+            slot_guard: SlotGuard::new(slot_aggregated_status),
         }
     }
 }
diff --git a/paddler/src/agent/continue_from_raw_prompt_request.rs b/paddler/src/agent/continue_from_raw_prompt_request.rs
index f6365ae1..9f573e74 100644
--- a/paddler/src/agent/continue_from_raw_prompt_request.rs
+++ b/paddler/src/agent/continue_from_raw_prompt_request.rs
@@ -1,13 +1,18 @@
+use std::sync::Arc;
+
 use paddler_types::generated_token_result::GeneratedTokenResult;
 use paddler_types::request_params::ContinueFromRawPromptParams;
 use tokio::sync::mpsc;
 
 use crate::agent::from_request_params::FromRequestParams;
+use crate::agent::slot_guard::SlotGuard;
+use crate::slot_aggregated_status::SlotAggregatedStatus;
 
 pub struct ContinueFromRawPromptRequest {
     pub generate_tokens_stop_rx: mpsc::UnboundedReceiver<()>,
     pub generated_tokens_tx: mpsc::UnboundedSender<GeneratedTokenResult>,
     pub params: ContinueFromRawPromptParams,
+    pub slot_guard: SlotGuard,
 }
 
 impl FromRequestParams for ContinueFromRawPromptRequest {
@@ -18,11 +23,13 @@ impl FromRequestParams for ContinueFromRawPromptRequest {
         params: Self::RequestParams,
         generated_tokens_tx: mpsc::UnboundedSender<GeneratedTokenResult>,
         generate_tokens_stop_rx: mpsc::UnboundedReceiver<()>,
+        slot_aggregated_status: Arc<SlotAggregatedStatus>,
     ) -> Self {
         Self {
             generate_tokens_stop_rx,
             generated_tokens_tx,
             params,
+            slot_guard: SlotGuard::new(slot_aggregated_status),
         }
     }
 }
diff --git a/paddler/src/agent/continuous_batch_active_request.rs b/paddler/src/agent/continuous_batch_active_request.rs
index 1eb0413f..13d239ef 100644
--- a/paddler/src/agent/continuous_batch_active_request.rs
+++ b/paddler/src/agent/continuous_batch_active_request.rs
@@ -1,3 +1,5 @@
+use llama_cpp_bindings::SampledToken;
+use llama_cpp_bindings::SampledTokenClassifier;
 use llama_cpp_bindings::sampling::LlamaSampler;
 use llama_cpp_bindings::token::LlamaToken;
 use log::warn;
@@ -6,22 +8,25 @@
 use tokio::sync::mpsc;
 use tokio::sync::mpsc::error::TryRecvError;
 
 use crate::agent::continuous_batch_request_phase::ContinuousBatchRequestPhase;
+use crate::agent::slot_guard::SlotGuard;
+use crate::tool_call_pipeline::ToolCallPipeline;
 
 pub struct ContinuousBatchActiveRequest {
     pub chain: LlamaSampler,
+    pub token_classifier: SampledTokenClassifier<'static>,
     pub current_token_position: i32,
     pub grammar_sampler: Option<LlamaSampler>,
-    pub generated_tokens_count: i32,
     pub generated_tokens_tx: mpsc::UnboundedSender<GeneratedTokenResult>,
     pub generate_tokens_stop_rx: mpsc::UnboundedReceiver<()>,
     pub i_batch: Option<i32>,
     pub max_tokens: i32,
-    pub pending_sampled_token: Option<LlamaToken>,
+    pub pending_sampled_token: Option<SampledToken>,
     pub phase: ContinuousBatchRequestPhase,
     pub prompt_tokens: Vec<LlamaToken>,
     pub prompt_tokens_ingested: usize,
     pub sequence_id: i32,
-    pub utf8_decoder: encoding_rs::Decoder,
+    pub slot_guard: SlotGuard,
+    pub tool_call_pipeline: Option<ToolCallPipeline>,
 }
 
 impl ContinuousBatchActiveRequest {
diff --git a/paddler/src/agent/continuous_batch_arbiter.rs b/paddler/src/agent/continuous_batch_arbiter.rs
index 384c542b..d662d0ec 100644
--- a/paddler/src/agent/continuous_batch_arbiter.rs
+++ b/paddler/src/agent/continuous_batch_arbiter.rs
@@ -8,8 +8,11 @@
 use std::thread::available_parallelism;
 
 use anyhow::Context as _;
 use anyhow::Result;
 use anyhow::anyhow;
+use llama_cpp_bindings::SampledToken;
+use llama_cpp_bindings::context::LlamaContext;
 use llama_cpp_bindings::context::params::LlamaContextParams;
 use llama_cpp_bindings::llama_backend::LlamaBackend;
+use llama_cpp_bindings::llama_batch::LlamaBatch;
 use llama_cpp_bindings::model::LlamaModel;
 use llama_cpp_bindings::model::params::LlamaModelParams;
 use llama_cpp_bindings::mtmd::MtmdContext;
@@ -17,6 +20,7 @@
 use llama_cpp_bindings::mtmd::MtmdContextParams;
 use llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_AUTO;
 use log::error;
 use log::info;
+use log::warn;
 use paddler_types::agent_issue::AgentIssue;
 use paddler_types::agent_issue_params::ChatTemplateDoesNotCompileParams;
 use paddler_types::agent_issue_params::ModelPath;
@@ -26,10 +30,12 @@
 use paddler_types::inference_parameters::InferenceParameters;
 use paddler_types::model_metadata::ModelMetadata;
 use tokio::sync::oneshot;
 
+use crate::agent::continuous_batch_arbiter_build_outcome::ContinuousBatchArbiterBuildOutcome;
 use crate::agent::continuous_batch_arbiter_handle::ContinuousBatchArbiterHandle;
 use crate::agent::continuous_batch_scheduler::ContinuousBatchScheduler;
 use crate::agent::continuous_batch_scheduler_context::ContinuousBatchSchedulerContext;
 use crate::agent::model_metadata_holder::ModelMetadataHolder;
+use crate::agent_applicable_state::AgentApplicableState;
 use crate::agent_issue_fix::AgentIssueFix;
 use crate::chat_template_renderer::ChatTemplateRenderer;
 use crate::converts_to_llama_kv_cache_dtype::ConvertsToLlamaKvCacheDtype;
@@ -49,9 +55,38 @@ pub struct ContinuousBatchArbiter {
 }
 
 impl ContinuousBatchArbiter {
+    #[must_use]
+    pub fn build_from_applicable_state(
+        applicable_state: AgentApplicableState,
+        agent_name: Option<String>,
+        desired_slots_total: i32,
+        model_metadata_holder: Arc<ModelMetadataHolder>,
+        slot_aggregated_status_manager: Arc<SlotAggregatedStatusManager>,
+    ) -> ContinuousBatchArbiterBuildOutcome {
+        let Some(model_path) = applicable_state.model_path else {
+            return ContinuousBatchArbiterBuildOutcome::NoModelConfigured;
+        };
+
+        let model_path_string = model_path.display().to_string();
+
+        ContinuousBatchArbiterBuildOutcome::ReadyToSpawn(Box::new(Self {
+            agent_name,
+            chat_template_override: applicable_state.chat_template_override,
+            desired_slots_total,
+            inference_parameters: applicable_state.inference_parameters,
+            multimodal_projection_path: applicable_state.multimodal_projection_path,
+            model_metadata_holder,
+            model_path,
+            model_path_string,
+            slot_aggregated_status_manager,
+        }))
+    }
+
     pub async fn spawn(&self) -> Result<ContinuousBatchArbiterHandle> {
         let (chat_template_loaded_tx, chat_template_loaded_rx) = oneshot::channel::<()>();
         let (model_loaded_tx, model_loaded_rx) = oneshot::channel::<()>();
+        let (agent_warm_and_scheduler_running_tx,
agent_warm_and_scheduler_running_rx) = + oneshot::channel::<()>(); let available_parallelism_value: i32 = available_parallelism()?.get().try_into()?; let n_threads = max(2, available_parallelism_value / 2); @@ -82,9 +117,16 @@ impl ContinuousBatchArbiter { )] let n_seq_max = desired_slots_total as u32; + #[expect( + clippy::cast_possible_truncation, + reason = "n_batch fits in u32 for llama.cpp FFI; usize is the internal type" + )] + let inference_parameters_n_batch_u32 = inference_parameters.n_batch as u32; + let context_params = LlamaContextParams::default() .with_embeddings(inference_parameters.enable_embeddings) .with_n_ctx(NonZeroU32::new(inference_parameters.context_size)) + .with_n_batch(inference_parameters_n_batch_u32) .with_flash_attention_policy(LLAMA_FLASH_ATTN_TYPE_AUTO) .with_n_seq_max(n_seq_max) .with_n_threads(n_threads) @@ -255,19 +297,19 @@ impl ContinuousBatchArbiter { model_path: model_path.clone(), multimodal_context, token_bos_str: model.token_to_piece( - model.token_bos(), + &SampledToken::Content(model.token_bos()), &mut special_token_decoder, true, None, )?, token_nl_str: model.token_to_piece( - model.token_nl(), + &SampledToken::Content(model.token_nl()), &mut special_token_decoder, true, None, )?, token_eos_str: model.token_to_piece( - model.token_eos(), + &SampledToken::Content(model.token_eos()), &mut special_token_decoder, true, None, @@ -275,50 +317,53 @@ impl ContinuousBatchArbiter { model: model.clone(), }); - let llama_context = match model - .new_context(&llama_backend, context_params) - .context("Unable to create llama.cpp context") - { - Ok(context) => context, - Err(err) => { - for slot_index in 0..desired_slots_total { - #[expect( - clippy::cast_sign_loss, - reason = "slot_index is always non-negative" - )] - slot_aggregated_status_manager - .slot_aggregated_status - .register_issue(AgentIssue::SlotCannotStart(SlotCannotStartParams { - error: format!("{err:#}"), - slot_index: slot_index as u32, - })); - } - - return Err(err); - } - }; + let mut llama_context = + match LlamaContext::from_model(&model, &llama_backend, context_params) + .context("Unable to create llama.cpp context") + { + Ok(context) => context, + Err(err) => { + for slot_index in 0..desired_slots_total { + #[expect( + clippy::cast_sign_loss, + reason = "slot_index is always non-negative" + )] + slot_aggregated_status_manager + .slot_aggregated_status + .register_issue(AgentIssue::SlotCannotStart( + SlotCannotStartParams { + error: format!("{err:#}"), + slot_index: slot_index as u32, + }, + )); + } - for slot_index in 0..desired_slots_total { - slot_aggregated_status_manager - .slot_aggregated_status - .increment_total_slots(); + return Err(err); + } + }; - #[expect(clippy::cast_sign_loss, reason = "slot_index is always non-negative")] - slot_aggregated_status_manager - .slot_aggregated_status - .register_fix(&AgentIssueFix::SlotStarted(slot_index as u32)); - } + Self::run_warmup_decode( + &model, + &mut llama_context, + scheduler_context.inference_parameters.n_batch, + desired_slots_total, + ); let mut scheduler = ContinuousBatchScheduler::new( command_rx, scheduler_context, llama_context, desired_slots_total, - slot_aggregated_status_manager - .slot_aggregated_status - .clone(), ); + if agent_warm_and_scheduler_running_tx.send(()).is_err() { + let message = "Arbiter dropped the agent-warm-and-scheduler-running receiver before the scheduler could start"; + + error!("{message}"); + + return Err(anyhow!(message)); + } + scheduler.run(); Ok(()) @@ -376,9 +421,54 @@ impl 
ContinuousBatchArbiter {
         }
     }
 
+        agent_warm_and_scheduler_running_rx.await.context(
+            "Scheduler thread did not signal agent-warm-and-scheduler-running before exiting",
+        )?;
+
+        for slot_index in 0..self.desired_slots_total {
+            self.slot_aggregated_status_manager
+                .slot_aggregated_status
+                .increment_total_slots();
+
+            #[expect(clippy::cast_sign_loss, reason = "slot_index is always non-negative")]
+            self.slot_aggregated_status_manager
+                .slot_aggregated_status
+                .register_fix(&AgentIssueFix::SlotStarted(slot_index as u32));
+        }
+
         Ok(ContinuousBatchArbiterHandle {
             command_tx,
             scheduler_thread_handle,
         })
     }
+
+    fn run_warmup_decode(
+        model: &LlamaModel,
+        llama_context: &mut LlamaContext<'_>,
+        n_batch: usize,
+        desired_slots_total: i32,
+    ) {
+        let warmup_tokens = vec![model.token_bos(); 4];
+        let mut warmup_batch = match LlamaBatch::new(n_batch, desired_slots_total) {
+            Ok(warmup_batch) => warmup_batch,
+            Err(err) => {
+                warn!("Warmup batch allocation failed: {err:#}");
+                return;
+            }
+        };
+
+        for sequence_index in 0..desired_slots_total {
+            if let Err(err) = warmup_batch.add_sequence(&warmup_tokens, sequence_index, true) {
+                warn!("Warmup batch add_sequence failed: {err:#}");
+                return;
+            }
+        }
+
+        llama_context.clear_kv_cache();
+        if let Err(err) = llama_context.decode(&mut warmup_batch) {
+            warn!("Warmup decode failed: {err:#}");
+        }
+        llama_context.synchronize();
+        llama_context.clear_kv_cache();
+    }
 }
diff --git a/paddler/src/agent/continuous_batch_arbiter_build_outcome.rs b/paddler/src/agent/continuous_batch_arbiter_build_outcome.rs
new file mode 100644
index 00000000..1abd91ba
--- /dev/null
+++ b/paddler/src/agent/continuous_batch_arbiter_build_outcome.rs
@@ -0,0 +1,6 @@
+use crate::agent::continuous_batch_arbiter::ContinuousBatchArbiter;
+
+pub enum ContinuousBatchArbiterBuildOutcome {
+    NoModelConfigured,
+    ReadyToSpawn(Box<ContinuousBatchArbiter>),
+}
diff --git a/paddler/src/agent/continuous_batch_embedding_processor.rs b/paddler/src/agent/continuous_batch_embedding_processor.rs
index 162023c5..9952047d 100644
--- a/paddler/src/agent/continuous_batch_embedding_processor.rs
+++ b/paddler/src/agent/continuous_batch_embedding_processor.rs
@@ -6,9 +6,11 @@
 use anyhow::anyhow;
 use llama_cpp_bindings::context::LlamaContext;
 use llama_cpp_bindings::llama_batch::LlamaBatch;
 use llama_cpp_bindings::model::AddBos;
+use log::warn;
 use paddler_types::embedding::Embedding;
 use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod;
 use paddler_types::embedding_result::EmbeddingResult;
+use paddler_types::oversized_embedding_document_details::OversizedEmbeddingDocumentDetails;
 use paddler_types::request_params::GenerateEmbeddingBatchParams;
 use tokio::sync::mpsc;
 
@@ -43,16 +45,21 @@ impl<'context> ContinuousBatchEmbeddingProcessor<'context> {
                 input_batch,
                 normalization_method,
             },
+            slot_guard,
         }: GenerateEmbeddingBatchRequest,
     ) -> Result<()> {
+        #[expect(
+            unused_variables,
+            reason = "slot_guard is held until function returns to release the slot via Drop"
+        )]
+        let slot_guard = slot_guard;
+
         if !self
             .scheduler_context
             .inference_parameters
             .enable_embeddings
         {
-            generated_embedding_tx.send(EmbeddingResult::Error(
-                "Embeddings are not enabled for this agent".to_owned(),
-            ))?;
+            generated_embedding_tx.send(EmbeddingResult::EmbeddingsDisabled)?;
 
             return Err(anyhow!("Embeddings are not enabled"));
         }
@@ -75,15 +82,45 @@ impl<'context> ContinuousBatchEmbeddingProcessor<'context> {
             .collect::<Result<Vec<_>, _>>()
             .context("failed to tokenize embedding input batch")?;
 
-        let batch_n_tokens = self.scheduler_context.inference_parameters.batch_n_tokens;
+        let n_batch = self.scheduler_context.inference_parameters.n_batch;
         let max_sequences_per_batch = self.scheduler_context.desired_slots_total;
-        let token_counts: Vec<usize> = tokens_lines_list
+
+        let mut tokens_lines_list_within_batch: Vec<EmbeddingInputTokenized> = Vec::new();
+        for input in tokens_lines_list {
+            if input.tokens.len() > n_batch {
+                #[expect(
+                    clippy::cast_possible_truncation,
+                    reason = "document token counts and n_batch are model-bounded and fit in u32"
+                )]
+                let details = OversizedEmbeddingDocumentDetails {
+                    document_tokens: input.tokens.len() as u32,
+                    n_batch: n_batch as u32,
+                    source_document_id: input.id.clone(),
+                };
+
+                warn!(
+                    "{:?}: skipped embedding document {:?}: {} tokens exceeds n_batch {}",
+                    self.scheduler_context.agent_name,
+                    input.id,
+                    details.document_tokens,
+                    details.n_batch,
+                );
+
+                generated_embedding_tx.send(EmbeddingResult::DocumentExceedsBatchSize(details))?;
+            } else {
+                tokens_lines_list_within_batch.push(input);
+            }
+        }
+
+        let token_counts: Vec<usize> = tokens_lines_list_within_batch
             .iter()
             .map(|input| input.tokens.len())
             .collect();
         let planned_batches =
-            plan_embedding_batches(&token_counts, batch_n_tokens, max_sequences_per_batch);
-        let mut batch = LlamaBatch::new(batch_n_tokens, max_sequences_per_batch)?;
+            plan_embedding_batches(&token_counts, n_batch, max_sequences_per_batch);
+        let mut batch = LlamaBatch::new(n_batch, max_sequences_per_batch)?;
+
+        let mut embeddings_emitted: usize = 0;
 
         #[expect(
             clippy::cast_possible_truncation,
@@ -95,8 +132,10 @@ ...
             if ... {
                 break;
             }
 
-            let batch_inputs: Vec<&EmbeddingInputTokenized> =
-                tokens_lines_list[planned_batch].iter().collect();
+            let batch_inputs: Vec<&EmbeddingInputTokenized> = tokens_lines_list_within_batch
+                [planned_batch]
+                .iter()
+                .collect();
 
             for (sequence_index, input) in batch_inputs.iter().enumerate() {
                 batch.add_sequence(&input.tokens, sequence_index as i32, true)?;
@@ -108,9 +147,15 @@ ...
                 &generated_embedding_tx,
                 &normalization_method,
             )?;
+
+            embeddings_emitted += batch_inputs.len();
         }
 
-        generated_embedding_tx.send(EmbeddingResult::Done)?;
+        if embeddings_emitted == 0 {
+            generated_embedding_tx.send(EmbeddingResult::NoEmbeddingsProduced)?;
+        } else {
+            generated_embedding_tx.send(EmbeddingResult::Done)?;
+        }
 
         Ok(())
     }
diff --git a/paddler/src/agent/continuous_batch_scheduler/advance_generating_phase.rs b/paddler/src/agent/continuous_batch_scheduler/advance_generating_phase.rs
new file mode 100644
index 00000000..32c46840
--- /dev/null
+++ b/paddler/src/agent/continuous_batch_scheduler/advance_generating_phase.rs
@@ -0,0 +1,179 @@
+use llama_cpp_bindings::SampledToken;
+use llama_cpp_bindings::context::LlamaContext;
+use log::error;
+use log::warn;
+use paddler_types::generated_token_result::GeneratedTokenResult;
+use paddler_types::generation_summary::GenerationSummary;
+
+use crate::agent::continuous_batch_active_request::ContinuousBatchActiveRequest;
+use crate::agent::continuous_batch_request_phase::ContinuousBatchRequestPhase;
+use crate::agent::continuous_batch_scheduler::advance_outcome::AdvanceOutcome;
+use crate::agent::continuous_batch_scheduler::classify_token_phase;
+use crate::agent::continuous_batch_scheduler::completion_check_outcome::CompletionCheckOutcome;
+use crate::agent::continuous_batch_scheduler::completion_check_phase::CompletionCheckPhase;
+use crate::agent::continuous_batch_scheduler::emit_token_outcome::EmitTokenOutcome;
+use crate::agent::continuous_batch_scheduler::emit_token_phase;
+use crate::agent::continuous_batch_scheduler::sample_outcome::SampleOutcome;
+use crate::agent::continuous_batch_scheduler::sample_token_phase::SampleTokenPhase;
+use crate::agent::continuous_batch_scheduler::tool_call_pass;
+use crate::agent::continuous_batch_scheduler_context::ContinuousBatchSchedulerContext;
+
+pub struct AdvanceGeneratingPhase<'context> {
+    pub scheduler_context: &'context ContinuousBatchSchedulerContext,
+    pub llama_context: &'context LlamaContext<'context>,
+}
+
+impl AdvanceGeneratingPhase<'_> {
+    pub fn run(self, requests: &mut [ContinuousBatchActiveRequest]) {
+        for request in requests {
+            let outcome = self.advance_one(request);
+            self.apply_outcome(request, outcome);
+        }
+    }
+
+    fn advance_one(&self, request: &mut ContinuousBatchActiveRequest) -> Option<AdvanceOutcome> {
+        if !matches!(request.phase, ContinuousBatchRequestPhase::Generating) {
+            return None;
+        }
+
+        if request.pending_sampled_token.is_some() {
+            return None;
+        }
+
+        let batch_index = request.i_batch?;
+
+        let raw_token = match (SampleTokenPhase {
+            context: self.llama_context,
+        })
+        .run(request, batch_index)
+        {
+            SampleOutcome::Sampled(token) => token,
+            SampleOutcome::AllCandidatesEliminated => {
+                error!(
+                    "{:?}: sequence {} sampling exhausted candidates",
+                    self.scheduler_context.agent_name, request.sequence_id
+                );
+                return Some(AdvanceOutcome::Completed(
+                    GeneratedTokenResult::SamplerError(
+                        "all token candidates were eliminated during sampling".to_owned(),
+                    ),
+                ));
+            }
+            SampleOutcome::GrammarRejected(message) => {
+                error!(
+                    "{:?}: sequence {} grammar rejected sampled token: {message}",
+                    self.scheduler_context.agent_name, request.sequence_id
+                );
+                return Some(AdvanceOutcome::Completed(
+                    GeneratedTokenResult::GrammarRejectedModelOutput(message),
+                ));
+            }
+            SampleOutcome::Failed(message) => {
+                error!(
+                    "{:?}: sequence {} sampling error: {message}",
+                    self.scheduler_context.agent_name, request.sequence_id
+                );
+                return Some(AdvanceOutcome::Completed(
+                    GeneratedTokenResult::SamplerError(message),
+                ));
+            }
+        };
+
+        let classified_outcomes = classify_token_phase::run(request, raw_token);
+
+        let completion_phase = CompletionCheckPhase {
+            model: &self.scheduler_context.model,
+        };
+
+        let raw_as_sampled = SampledToken::Content(raw_token);
+        if matches!(
+            completion_phase.run(request, &raw_as_sampled),
+            CompletionCheckOutcome::ReachedEog
+        ) {
+            if let Some(pipeline) = request.tool_call_pipeline.as_mut()
+                && !pipeline.buffer_is_empty()
+                && let Some(event) = pipeline.finalize_to_generated_event()
+                && request.generated_tokens_tx.send(event).is_err()
+            {
+                warn!(
+                    "{:?}: sequence {} client disconnected (receiver dropped) during EOG tool-call flush",
+                    self.scheduler_context.agent_name, request.sequence_id
+                );
+                return Some(AdvanceOutcome::ChannelDropped);
+            }
+            return Some(AdvanceOutcome::Completed(GeneratedTokenResult::Done(
+                GenerationSummary {
+                    usage: *request.token_classifier.usage(),
+                },
+            )));
+        }
+
+        for classified in &classified_outcomes {
+            match emit_token_phase::run(request, classified) {
+                EmitTokenOutcome::Emitted(_) => {}
+                EmitTokenOutcome::ChannelDropped => {
+                    warn!(
+                        "{:?}: sequence {} client disconnected (receiver dropped)",
+                        self.scheduler_context.agent_name, request.sequence_id
+                    );
+                    return Some(AdvanceOutcome::ChannelDropped);
+                }
+            }
+
+            if let Some(event) =
+                tool_call_pass::run(request.tool_call_pipeline.as_mut(), classified)
+                && request.generated_tokens_tx.send(event).is_err()
+            {
+                warn!(
+                    "{:?}: sequence {} client disconnected (receiver dropped)",
+                    self.scheduler_context.agent_name, request.sequence_id
+                );
+                return Some(AdvanceOutcome::ChannelDropped);
+            }
+        }
+
+        match completion_phase.run(request, &raw_as_sampled) {
+            CompletionCheckOutcome::ReachedEog | CompletionCheckOutcome::ReachedMaxTokens => {
+                if let Some(pipeline) = request.tool_call_pipeline.as_mut()
+                    && !pipeline.buffer_is_empty()
+                    && let Some(event) = pipeline.finalize_to_generated_event()
+                    && request.generated_tokens_tx.send(event).is_err()
+                {
+                    warn!(
+                        "{:?}: sequence {} client disconnected (receiver dropped) during tool-call EOG flush",
+                        self.scheduler_context.agent_name, request.sequence_id
+                    );
+                    return Some(AdvanceOutcome::ChannelDropped);
+                }
+                Some(AdvanceOutcome::Completed(GeneratedTokenResult::Done(
+                    GenerationSummary {
+                        usage: *request.token_classifier.usage(),
+                    },
+                )))
+            }
+            CompletionCheckOutcome::Continue => {
+                Some(AdvanceOutcome::SampledAndStored(raw_as_sampled))
+            }
+        }
+    }
+
+    fn apply_outcome(
+        &self,
+        request: &mut ContinuousBatchActiveRequest,
+        outcome: Option<AdvanceOutcome>,
+    ) {
+        match outcome {
+            None => {}
+            Some(AdvanceOutcome::SampledAndStored(token)) => {
+                request.pending_sampled_token = Some(token);
+            }
+            Some(AdvanceOutcome::Completed(event)) => {
+                request.complete_with_outcome(&self.scheduler_context.agent_name, event);
+            }
+            Some(AdvanceOutcome::ChannelDropped) => {
+                request.i_batch = None;
+                request.phase = ContinuousBatchRequestPhase::Completed;
+            }
+        }
+    }
+}
diff --git a/paddler/src/agent/continuous_batch_scheduler/advance_outcome.rs b/paddler/src/agent/continuous_batch_scheduler/advance_outcome.rs
new file mode 100644
index 00000000..174d4b7e
--- /dev/null
+++ b/paddler/src/agent/continuous_batch_scheduler/advance_outcome.rs
@@ -0,0 +1,34 @@
+use llama_cpp_bindings::SampledToken;
+use paddler_types::generated_token_result::GeneratedTokenResult;
+
+pub enum AdvanceOutcome {
+    SampledAndStored(SampledToken),
+    Completed(GeneratedTokenResult),
+    ChannelDropped,
+}
+
+#[cfg(test)]
+mod tests {
+    use paddler_types::generated_token_result::GeneratedTokenResult;
+    use paddler_types::generation_summary::GenerationSummary;
+
+    use super::AdvanceOutcome;
+
+    #[test]
+    fn completed_carries_event_through_into_inner() {
+        let outcome =
+            AdvanceOutcome::Completed(GeneratedTokenResult::Done(GenerationSummary::default()));
+
+        assert!(matches!(
+            outcome,
+            AdvanceOutcome::Completed(GeneratedTokenResult::Done(_))
+        ));
+    }
+
+    #[test]
+    fn channel_dropped_is_distinct_variant() {
+        let outcome = AdvanceOutcome::ChannelDropped;
+
+        assert!(matches!(outcome, AdvanceOutcome::ChannelDropped));
+    }
+}
diff --git a/paddler/src/agent/continuous_batch_scheduler/assemble_batch_phase.rs b/paddler/src/agent/continuous_batch_scheduler/assemble_batch_phase.rs
new file mode 100644
index 00000000..3255d18e
--- /dev/null
+++ b/paddler/src/agent/continuous_batch_scheduler/assemble_batch_phase.rs
@@ -0,0 +1,164 @@
+use anyhow::Result;
+use llama_cpp_bindings::SampledToken;
+
+use crate::agent::continuous_batch_active_request::ContinuousBatchActiveRequest;
+use crate::agent::continuous_batch_request_phase::ContinuousBatchRequestPhase;
+use crate::agent::continuous_batch_scheduler::batch_pass::BatchPass;
+use crate::agent::continuous_batch_scheduler::generating_contribution::GeneratingContribution;
+use crate::agent::continuous_batch_scheduler::ingesting_contribution::IngestingContribution;
+
+pub struct AssembleBatchPhase {
+    pub n_batch: usize,
+}
+
+impl AssembleBatchPhase {
+    /// # Errors
+    /// Forwards `LlamaBatch::add` failures verbatim.
+    pub fn run(
+        &self,
+        pass: &mut BatchPass,
+        requests: &mut [ContinuousBatchActiveRequest],
+    ) -> Result<()> {
+        let added = self.fill_generating(pass, requests)?;
+        pass.contributions.current_batch_token_count += added;
+        self.fill_ingesting(pass, requests)?;
+        Ok(())
+    }
+
+    fn fill_generating(
+        &self,
+        pass: &mut BatchPass,
+        requests: &[ContinuousBatchActiveRequest],
+    ) -> Result<usize> {
+        let mut tokens_added: usize = 0;
+
+        for (request_index, request) in requests.iter().enumerate() {
+            if !matches!(request.phase, ContinuousBatchRequestPhase::Generating) {
+                continue;
+            }
+
+            let Some(pending_token) = request.pending_sampled_token else {
+                continue;
+            };
+
+            if tokens_added >= self.n_batch {
+                break;
+            }
+
+            let batch_position = pass.batch.n_tokens();
+
+            pass.batch.add(
+                &pending_token,
+                request.current_token_position,
+                &[request.sequence_id],
+                true,
+            )?;
+
+            pass.contributions.generating.push(GeneratingContribution {
+                request_index,
+                batch_position,
+            });
+
+            tokens_added += 1;
+        }
+
+        Ok(tokens_added)
+    }
+
+    #[expect(
+        clippy::cast_possible_truncation,
+        clippy::cast_possible_wrap,
+        reason = "token counts and positions fit in i32 for llama.cpp FFI"
+    )]
+    fn fill_ingesting(
+        &self,
+        pass: &mut BatchPass,
+        requests: &[ContinuousBatchActiveRequest],
+    ) -> Result<()> {
+        for (request_index, request) in requests.iter().enumerate() {
+            if !matches!(request.phase, ContinuousBatchRequestPhase::Ingesting) {
+                continue;
+            }
+
+            let remaining = request.remaining_prompt_tokens();
+            let chunk_size = compute_ingesting_chunk_size(
+                remaining.len(),
+                self.n_batch,
+                pass.contributions.current_batch_token_count,
+            );
+
+            if chunk_size == 0 {
+                continue;
+            }
+
+            let chunk = &request.prompt_tokens
+                [request.prompt_tokens_ingested..request.prompt_tokens_ingested + chunk_size];
+            let is_last_chunk =
+                request.prompt_tokens_ingested + chunk_size >= request.prompt_tokens.len();
+
+            for (offset, token) in chunk.iter().enumerate() {
+                let position = request.current_token_position + offset as i32;
+                let is_last_token_of_prompt = is_last_chunk && offset == chunk_size - 1;
+
+                pass.batch.add(
+                    &SampledToken::Content(*token),
+                    position,
+                    &[request.sequence_id],
+                    is_last_token_of_prompt,
+                )?;
+            }
+
+            pass.contributions.ingesting.push(IngestingContribution {
+                request_index,
+                chunk_size,
+                is_last_chunk,
+                last_batch_position: pass.batch.n_tokens() - 1,
+            });
+
+            pass.contributions.current_batch_token_count += chunk_size;
+        }
+
+        Ok(())
+    }
+}
+
+fn compute_ingesting_chunk_size(
+    remaining_prompt_len: usize,
+    n_batch: usize,
+    current_batch_token_count: usize,
+) -> usize {
+    let available_space = n_batch.saturating_sub(current_batch_token_count);
+    remaining_prompt_len.min(available_space)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::compute_ingesting_chunk_size;
+
+    #[test]
+    fn chunk_size_is_min_of_remaining_and_available_space() {
+        assert_eq!(compute_ingesting_chunk_size(10, 32, 0), 10);
+        assert_eq!(compute_ingesting_chunk_size(100, 32, 0), 32);
+    }
+
+    #[test]
+    fn chunk_size_subtracts_already_used_space_from_batch_capacity() {
+        assert_eq!(compute_ingesting_chunk_size(20, 32, 12), 20);
+        assert_eq!(compute_ingesting_chunk_size(50, 32, 12), 20);
+    }
+
+    #[test]
+    fn chunk_size_is_zero_when_batch_already_full() {
+        assert_eq!(compute_ingesting_chunk_size(50, 32, 32), 0);
+    }
+
+    #[test]
+    fn chunk_size_is_zero_when_already_overfilled_via_saturating_sub() {
+        assert_eq!(compute_ingesting_chunk_size(50, 32, 40), 0);
+    }
+
+    #[test]
+    fn chunk_size_is_zero_when_remaining_prompt_is_empty() {
+        assert_eq!(compute_ingesting_chunk_size(0, 32, 0), 0);
+    }
+}
diff --git a/paddler/src/agent/continuous_batch_scheduler/batch_pass.rs b/paddler/src/agent/continuous_batch_scheduler/batch_pass.rs
new file mode 100644
index 00000000..e5bfd0cb
--- /dev/null
+++ b/paddler/src/agent/continuous_batch_scheduler/batch_pass.rs
@@ -0,0 +1,25 @@
+use anyhow::Result;
+use llama_cpp_bindings::llama_batch::LlamaBatch;
+
+use crate::agent::continuous_batch_scheduler::contributions::Contributions;
+
+pub struct BatchPass<'tokens> {
+    pub batch: LlamaBatch<'tokens>,
+    pub contributions: Contributions,
+}
+
+impl BatchPass<'_> {
+    /// # Errors
+    /// Forwards [`LlamaBatch::new`] failures verbatim.
+    pub fn new(n_batch: usize, max_sequences: i32) -> Result<Self> {
+        Ok(Self {
+            batch: LlamaBatch::new(n_batch, max_sequences)?,
+            contributions: Contributions::default(),
+        })
+    }
+
+    #[must_use]
+    pub const fn is_empty(&self) -> bool {
+        self.contributions.is_empty()
+    }
+}
diff --git a/paddler/src/agent/continuous_batch_scheduler/classified_token.rs b/paddler/src/agent/continuous_batch_scheduler/classified_token.rs
new file mode 100644
index 00000000..4ac63597
--- /dev/null
+++ b/paddler/src/agent/continuous_batch_scheduler/classified_token.rs
@@ -0,0 +1,15 @@
+use llama_cpp_bindings::SampledToken;
+
+pub struct ClassifiedToken {
+    pub sampled_token: SampledToken,
+    pub was_in_tool_call: bool,
+    pub is_in_tool_call: bool,
+    /// User-visible decoded piece. Empty when this token is part of a marker
+    /// (e.g. `<think>` or `[/THINK]`) — emit phases must skip emission for
+    /// empty pieces so marker text never reaches client streams.
+    pub visible_piece: String,
+    /// Always the decoded UTF-8 piece, including marker bytes. Used by the
+    /// tool-call buffer so downstream parsers see the wrapped form
+    /// (`<tool_call>...</tool_call>` etc.) that llama.cpp's autoparser expects.
+    pub raw_piece: String,
+}
diff --git a/paddler/src/agent/continuous_batch_scheduler/classify_token_phase.rs b/paddler/src/agent/continuous_batch_scheduler/classify_token_phase.rs
new file mode 100644
index 00000000..39bc589f
--- /dev/null
+++ b/paddler/src/agent/continuous_batch_scheduler/classify_token_phase.rs
@@ -0,0 +1,159 @@
+use llama_cpp_bindings::SampledToken;
+use llama_cpp_bindings::sampled_token_classifier::IngestOutcome;
+use llama_cpp_bindings::sampled_token_classifier::SampledTokenSection;
+use llama_cpp_bindings::token::LlamaToken;
+
+use crate::agent::continuous_batch_active_request::ContinuousBatchActiveRequest;
+use crate::agent::continuous_batch_scheduler::classified_token::ClassifiedToken;
+
+pub fn run(
+    request: &mut ContinuousBatchActiveRequest,
+    raw_token: LlamaToken,
+) -> Vec<ClassifiedToken> {
+    let section_before_ingest = request.token_classifier.current_section();
+    let outcomes = request.token_classifier.ingest(raw_token);
+    classify_ingest_outcomes(outcomes, section_before_ingest)
+}
+
+fn classify_ingest_outcomes(
+    outcomes: Vec<IngestOutcome>,
+    section_before: SampledTokenSection,
+) -> Vec<ClassifiedToken> {
+    let mut previous_section = section_before;
+    outcomes
+        .into_iter()
+        .map(|outcome| {
+            let section = section_of(outcome.sampled_token);
+            let classified = ClassifiedToken {
+                sampled_token: outcome.sampled_token,
+                was_in_tool_call: previous_section == SampledTokenSection::ToolCall,
+                is_in_tool_call: section == SampledTokenSection::ToolCall,
+                visible_piece: outcome.visible_piece,
+                raw_piece: outcome.raw_piece,
+            };
+            previous_section = section;
+            classified
+        })
+        .collect()
+}
+
+const fn section_of(token: SampledToken) -> SampledTokenSection {
+    match token {
+        SampledToken::Reasoning(_) => SampledTokenSection::Reasoning,
+        SampledToken::Content(_) => SampledTokenSection::Content,
+        SampledToken::ToolCall(_) => SampledTokenSection::ToolCall,
+        SampledToken::Undeterminable(_) => SampledTokenSection::Pending,
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use llama_cpp_bindings::SampledToken;
+    use llama_cpp_bindings::sampled_token_classifier::IngestOutcome;
+    use llama_cpp_bindings::sampled_token_classifier::SampledTokenSection;
+    use llama_cpp_bindings::token::LlamaToken;
+
+    use super::classify_ingest_outcomes;
+
+    fn outcome(sampled: SampledToken) -> IngestOutcome {
+        IngestOutcome {
+            sampled_token: sampled,
+            visible_piece: String::new(),
+            raw_piece: String::new(),
+        }
+    }
+
+    #[test]
+    fn content_after_content_stays_outside_tool_call() {
+        let classified = classify_ingest_outcomes(
+            vec![outcome(SampledToken::Content(LlamaToken::new(1)))],
+            SampledTokenSection::Content,
+        );
+
+        assert_eq!(classified.len(), 1);
+        assert!(!classified[0].was_in_tool_call);
+        assert!(!classified[0].is_in_tool_call);
+    }
+
+    #[test]
+    fn content_to_tool_call_marks_entry_transition() {
+        let classified = classify_ingest_outcomes(
+            vec![outcome(SampledToken::ToolCall(LlamaToken::new(2)))],
+            SampledTokenSection::Content,
+        );
+
+        assert_eq!(classified.len(), 1);
+        assert!(!classified[0].was_in_tool_call);
+        assert!(classified[0].is_in_tool_call);
+    }
+
+    #[test]
+    fn tool_call_to_tool_call_stays_inside() {
+        let classified = classify_ingest_outcomes(
+            vec![outcome(SampledToken::ToolCall(LlamaToken::new(3)))],
+            SampledTokenSection::ToolCall,
+        );
+
+        assert_eq!(classified.len(), 1);
+        assert!(classified[0].was_in_tool_call);
+        assert!(classified[0].is_in_tool_call);
+    }
+
+    #[test]
+    fn tool_call_to_content_marks_exit_transition() {
+        let classified = classify_ingest_outcomes(
vec![outcome(SampledToken::Content(LlamaToken::new(4)))], + SampledTokenSection::ToolCall, + ); + + assert_eq!(classified.len(), 1); + assert!(classified[0].was_in_tool_call); + assert!(!classified[0].is_in_tool_call); + } + + #[test] + fn reasoning_after_content_stays_outside_tool_call() { + let classified = classify_ingest_outcomes( + vec![outcome(SampledToken::Reasoning(LlamaToken::new(5)))], + SampledTokenSection::Content, + ); + + assert_eq!(classified.len(), 1); + assert!(!classified[0].was_in_tool_call); + assert!(!classified[0].is_in_tool_call); + } + + #[test] + fn undeterminable_after_content_maps_to_pending() { + let classified = classify_ingest_outcomes( + vec![outcome(SampledToken::Undeterminable(LlamaToken::new(6)))], + SampledTokenSection::Content, + ); + + assert_eq!(classified.len(), 1); + assert!(!classified[0].was_in_tool_call); + assert!(!classified[0].is_in_tool_call); + } + + #[test] + fn previous_section_carries_forward_across_multi_outcome_vec() { + let outcomes = vec![ + outcome(SampledToken::ToolCall(LlamaToken::new(7))), + outcome(SampledToken::ToolCall(LlamaToken::new(8))), + outcome(SampledToken::Content(LlamaToken::new(9))), + ]; + + let classified = classify_ingest_outcomes(outcomes, SampledTokenSection::Content); + + assert_eq!(classified.len(), 3); + + assert!(!classified[0].was_in_tool_call); + assert!(classified[0].is_in_tool_call); + + assert!(classified[1].was_in_tool_call); + assert!(classified[1].is_in_tool_call); + + assert!(classified[2].was_in_tool_call); + assert!(!classified[2].is_in_tool_call); + } +} diff --git a/paddler/src/agent/continuous_batch_scheduler/commit_phase.rs b/paddler/src/agent/continuous_batch_scheduler/commit_phase.rs new file mode 100644 index 00000000..9e58ca7b --- /dev/null +++ b/paddler/src/agent/continuous_batch_scheduler/commit_phase.rs @@ -0,0 +1,30 @@ +use crate::agent::continuous_batch_active_request::ContinuousBatchActiveRequest; +use crate::agent::continuous_batch_request_phase::ContinuousBatchRequestPhase; +use crate::agent::continuous_batch_scheduler::batch_pass::BatchPass; + +#[expect( + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + reason = "chunk sizes fit in i32 for llama.cpp position arithmetic" +)] +pub fn run(pass: BatchPass, requests: &mut [ContinuousBatchActiveRequest]) { + for contribution in pass.contributions.generating { + let request = &mut requests[contribution.request_index]; + + request.pending_sampled_token = None; + request.i_batch = Some(contribution.batch_position); + request.current_token_position += 1; + } + + for contribution in pass.contributions.ingesting { + let request = &mut requests[contribution.request_index]; + + request.prompt_tokens_ingested += contribution.chunk_size; + request.current_token_position += contribution.chunk_size as i32; + + if contribution.is_last_chunk { + request.i_batch = Some(contribution.last_batch_position); + request.phase = ContinuousBatchRequestPhase::Generating; + } + } +} diff --git a/paddler/src/agent/continuous_batch_scheduler/completion_check_outcome.rs b/paddler/src/agent/continuous_batch_scheduler/completion_check_outcome.rs new file mode 100644 index 00000000..d3321426 --- /dev/null +++ b/paddler/src/agent/continuous_batch_scheduler/completion_check_outcome.rs @@ -0,0 +1,5 @@ +pub enum CompletionCheckOutcome { + Continue, + ReachedEog, + ReachedMaxTokens, +} diff --git a/paddler/src/agent/continuous_batch_scheduler/completion_check_phase.rs b/paddler/src/agent/continuous_batch_scheduler/completion_check_phase.rs new file mode 
100644
index 00000000..e2855e50
--- /dev/null
+++ b/paddler/src/agent/continuous_batch_scheduler/completion_check_phase.rs
@@ -0,0 +1,87 @@
+use llama_cpp_bindings::SampledToken;
+use llama_cpp_bindings::TokenUsage;
+use llama_cpp_bindings::model::LlamaModel;
+
+use crate::agent::continuous_batch_active_request::ContinuousBatchActiveRequest;
+use crate::agent::continuous_batch_scheduler::completion_check_outcome::CompletionCheckOutcome;
+
+pub struct CompletionCheckPhase<'model> {
+    pub model: &'model LlamaModel,
+}
+
+impl CompletionCheckPhase<'_> {
+    #[must_use]
+    pub fn run(
+        &self,
+        request: &ContinuousBatchActiveRequest,
+        sampled_token: &SampledToken,
+    ) -> CompletionCheckOutcome {
+        if self.model.is_eog_token(sampled_token) {
+            return CompletionCheckOutcome::ReachedEog;
+        }
+
+        #[expect(
+            clippy::cast_sign_loss,
+            reason = "max_tokens is non-negative by API contract"
+        )]
+        let max_tokens_u64 = request.max_tokens as u64;
+
+        if completion_token_count(request.token_classifier.usage()) >= max_tokens_u64 {
+            CompletionCheckOutcome::ReachedMaxTokens
+        } else {
+            CompletionCheckOutcome::Continue
+        }
+    }
+}
+
+const fn completion_token_count(usage: &TokenUsage) -> u64 {
+    usage.content_tokens + usage.reasoning_tokens + usage.undeterminable_tokens
+}
+
+#[cfg(test)]
+mod tests {
+    use llama_cpp_bindings::TokenUsage;
+
+    use super::completion_token_count;
+
+    #[test]
+    fn completion_token_count_sums_content_reasoning_and_undeterminable() {
+        let usage = TokenUsage {
+            content_tokens: 5,
+            reasoning_tokens: 3,
+            undeterminable_tokens: 2,
+            ..TokenUsage::new()
+        };
+
+        assert_eq!(completion_token_count(&usage), 10);
+    }
+
+    #[test]
+    fn completion_token_count_excludes_prompt_and_cached_prompt_tokens() {
+        let usage = TokenUsage {
+            prompt_tokens: 100,
+            cached_prompt_tokens: 50,
+            input_image_tokens: 20,
+            input_audio_tokens: 10,
+            content_tokens: 4,
+            reasoning_tokens: 0,
+            tool_call_tokens: 0,
+            undeterminable_tokens: 0,
+        };
+
+        assert_eq!(completion_token_count(&usage), 4);
+    }
+
+    #[test]
+    fn completion_token_count_excludes_tool_call_tokens() {
+        let usage = TokenUsage {
+            content_tokens: 1,
+            reasoning_tokens: 0,
+            tool_call_tokens: 99,
+            undeterminable_tokens: 0,
+            ..TokenUsage::new()
+        };
+
+        assert_eq!(completion_token_count(&usage), 1);
+    }
+}
diff --git a/paddler/src/agent/continuous_batch_scheduler/contributions.rs b/paddler/src/agent/continuous_batch_scheduler/contributions.rs
new file mode 100644
index 00000000..4ce8805b
--- /dev/null
+++ b/paddler/src/agent/continuous_batch_scheduler/contributions.rs
@@ -0,0 +1,55 @@
+use crate::agent::continuous_batch_scheduler::generating_contribution::GeneratingContribution;
+use crate::agent::continuous_batch_scheduler::ingesting_contribution::IngestingContribution;
+
+#[derive(Default)]
+pub struct Contributions {
+    pub generating: Vec<GeneratingContribution>,
+    pub ingesting: Vec<IngestingContribution>,
+    pub current_batch_token_count: usize,
+}
+
+impl Contributions {
+    #[must_use]
+    pub const fn is_empty(&self) -> bool {
+        self.generating.is_empty() && self.ingesting.is_empty()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::Contributions;
+    use super::GeneratingContribution;
+    use super::IngestingContribution;
+
+    #[test]
+    fn default_contributions_are_empty() {
+        let contributions = Contributions::default();
+
+        assert!(contributions.is_empty());
+        assert_eq!(contributions.current_batch_token_count, 0);
+    }
+
+    #[test]
+    fn contributions_with_generating_entry_is_not_empty() {
+        let mut contributions = Contributions::default();
contributions.generating.push(GeneratingContribution { + request_index: 0, + batch_position: 1, + }); + + assert!(!contributions.is_empty()); + } + + #[test] + fn contributions_with_ingesting_entry_is_not_empty() { + let mut contributions = Contributions::default(); + contributions.ingesting.push(IngestingContribution { + request_index: 0, + chunk_size: 4, + is_last_chunk: false, + last_batch_position: 3, + }); + + assert!(!contributions.is_empty()); + } +} diff --git a/paddler/src/agent/continuous_batch_scheduler/decode_batch_phase.rs b/paddler/src/agent/continuous_batch_scheduler/decode_batch_phase.rs new file mode 100644 index 00000000..d080bbba --- /dev/null +++ b/paddler/src/agent/continuous_batch_scheduler/decode_batch_phase.rs @@ -0,0 +1,8 @@ +use llama_cpp_bindings::context::LlamaContext; + +use crate::agent::continuous_batch_scheduler::batch_pass::BatchPass; +use crate::agent::continuous_batch_scheduler::decode_outcome::DecodeOutcome; + +pub fn run(pass: &mut BatchPass, context: &mut LlamaContext) -> DecodeOutcome { + DecodeOutcome::from_decode_result(&context.decode(&mut pass.batch)) +} diff --git a/paddler/src/agent/continuous_batch_scheduler/decode_outcome.rs b/paddler/src/agent/continuous_batch_scheduler/decode_outcome.rs new file mode 100644 index 00000000..667d2821 --- /dev/null +++ b/paddler/src/agent/continuous_batch_scheduler/decode_outcome.rs @@ -0,0 +1,67 @@ +use llama_cpp_bindings::DecodeError; + +#[derive(Debug)] +pub enum DecodeOutcome { + Decoded, + NeedsEviction, + Aborted, + Errored(i32), +} + +impl DecodeOutcome { + #[must_use] + pub const fn from_decode_result(result: &Result<(), DecodeError>) -> Self { + match result { + Ok(()) => Self::Decoded, + Err(DecodeError::NoKvCacheSlot) => Self::NeedsEviction, + Err(DecodeError::Aborted | DecodeError::NTokensZero) => Self::Aborted, + Err(DecodeError::Unknown(error_code)) => Self::Errored(*error_code), + } + } +} + +#[cfg(test)] +mod tests { + use llama_cpp_bindings::DecodeError; + + use super::DecodeOutcome; + + #[test] + fn ok_maps_to_decoded() { + assert!(matches!( + DecodeOutcome::from_decode_result(&Ok(())), + DecodeOutcome::Decoded + )); + } + + #[test] + fn no_kv_cache_slot_maps_to_needs_eviction() { + assert!(matches!( + DecodeOutcome::from_decode_result(&Err(DecodeError::NoKvCacheSlot)), + DecodeOutcome::NeedsEviction + )); + } + + #[test] + fn aborted_maps_to_aborted() { + assert!(matches!( + DecodeOutcome::from_decode_result(&Err(DecodeError::Aborted)), + DecodeOutcome::Aborted + )); + } + + #[test] + fn n_tokens_zero_maps_to_aborted() { + assert!(matches!( + DecodeOutcome::from_decode_result(&Err(DecodeError::NTokensZero)), + DecodeOutcome::Aborted + )); + } + + #[test] + fn unknown_carries_error_code() { + let outcome = DecodeOutcome::from_decode_result(&Err(DecodeError::Unknown(42))); + + assert!(matches!(outcome, DecodeOutcome::Errored(42))); + } +} diff --git a/paddler/src/agent/continuous_batch_scheduler/emit_token_outcome.rs b/paddler/src/agent/continuous_batch_scheduler/emit_token_outcome.rs new file mode 100644 index 00000000..2860bc03 --- /dev/null +++ b/paddler/src/agent/continuous_batch_scheduler/emit_token_outcome.rs @@ -0,0 +1,5 @@ +#[derive(Debug)] +pub enum EmitTokenOutcome { + Emitted(String), + ChannelDropped, +} diff --git a/paddler/src/agent/continuous_batch_scheduler/emit_token_phase.rs b/paddler/src/agent/continuous_batch_scheduler/emit_token_phase.rs new file mode 100644 index 00000000..a1af128d --- /dev/null +++ b/paddler/src/agent/continuous_batch_scheduler/emit_token_phase.rs @@ -0,0 
+1,149 @@ +use llama_cpp_bindings::SampledToken; +use paddler_types::generated_token_result::GeneratedTokenResult; +use tokio::sync::mpsc; + +use crate::agent::continuous_batch_active_request::ContinuousBatchActiveRequest; +use crate::agent::continuous_batch_scheduler::classified_token::ClassifiedToken; +use crate::agent::continuous_batch_scheduler::emit_token_outcome::EmitTokenOutcome; + +pub fn run( + request: &mut ContinuousBatchActiveRequest, + classified: &ClassifiedToken, +) -> EmitTokenOutcome { + emit_classified(classified, &request.generated_tokens_tx) +} + +fn emit_classified( + classified: &ClassifiedToken, + tx: &mpsc::UnboundedSender<GeneratedTokenResult>, +) -> EmitTokenOutcome { + if classified.visible_piece.is_empty() { + return EmitTokenOutcome::Emitted(String::new()); + } + + let piece = classified.visible_piece.clone(); + let event = token_to_event(classified.sampled_token, piece.clone()); + + if tx.send(event).is_err() { + return EmitTokenOutcome::ChannelDropped; + } + + EmitTokenOutcome::Emitted(piece) +} + +const fn token_to_event(sampled_token: SampledToken, piece: String) -> GeneratedTokenResult { + match sampled_token { + SampledToken::Content(_) => GeneratedTokenResult::ContentToken(piece), + SampledToken::Reasoning(_) => GeneratedTokenResult::ReasoningToken(piece), + SampledToken::ToolCall(_) => GeneratedTokenResult::ToolCallToken(piece), + SampledToken::Undeterminable(_) => GeneratedTokenResult::UndeterminableToken(piece), + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use anyhow::bail; + use llama_cpp_bindings::SampledToken; + use llama_cpp_bindings::token::LlamaToken; + use paddler_types::generated_token_result::GeneratedTokenResult; + use tokio::sync::mpsc; + + use super::emit_classified; + use crate::agent::continuous_batch_scheduler::classified_token::ClassifiedToken; + use crate::agent::continuous_batch_scheduler::emit_token_outcome::EmitTokenOutcome; + + fn classified_with_piece(sampled: SampledToken, piece: &str) -> ClassifiedToken { + ClassifiedToken { + sampled_token: sampled, + was_in_tool_call: false, + is_in_tool_call: false, + visible_piece: piece.to_owned(), + raw_piece: piece.to_owned(), + } + } + + #[test] + fn empty_visible_piece_emits_empty_string_without_sending() -> Result<()> { + let (tx, mut rx) = mpsc::unbounded_channel::<GeneratedTokenResult>(); + let classified = classified_with_piece(SampledToken::Content(LlamaToken::new(1)), ""); + + match emit_classified(&classified, &tx) { + EmitTokenOutcome::Emitted(piece) if piece.is_empty() => {} + other => bail!("expected Emitted(\"\"), got {other:?}"), + } + + match rx.try_recv() { + Err(mpsc::error::TryRecvError::Empty) => Ok(()), + other => bail!("expected empty channel, got {other:?}"), + } + } + + #[test] + fn content_token_emits_content_event() -> Result<()> { + let (tx, mut rx) = mpsc::unbounded_channel::<GeneratedTokenResult>(); + let classified = classified_with_piece(SampledToken::Content(LlamaToken::new(2)), "hi"); + + emit_classified(&classified, &tx); + + match rx.try_recv() { + Ok(GeneratedTokenResult::ContentToken(text)) if text == "hi" => Ok(()), + other => bail!("expected ContentToken(\"hi\"), got {other:?}"), + } + } + + #[test] + fn reasoning_token_emits_reasoning_event() -> Result<()> { + let (tx, mut rx) = mpsc::unbounded_channel::<GeneratedTokenResult>(); + let classified = + classified_with_piece(SampledToken::Reasoning(LlamaToken::new(3)), "think"); + + emit_classified(&classified, &tx); + + match rx.try_recv() { + Ok(GeneratedTokenResult::ReasoningToken(text)) if text == "think" => Ok(()), + other => bail!("expected ReasoningToken(\"think\"), 
got {other:?}"), + } + } + + #[test] + fn tool_call_token_emits_tool_call_event() -> Result<()> { + let (tx, mut rx) = mpsc::unbounded_channel::<GeneratedTokenResult>(); + let classified = classified_with_piece(SampledToken::ToolCall(LlamaToken::new(4)), "{"); + + emit_classified(&classified, &tx); + + match rx.try_recv() { + Ok(GeneratedTokenResult::ToolCallToken(text)) if text == "{" => Ok(()), + other => bail!("expected ToolCallToken(\"{{\"), got {other:?}"), + } + } + + #[test] + fn undeterminable_token_emits_undeterminable_event() -> Result<()> { + let (tx, mut rx) = mpsc::unbounded_channel::<GeneratedTokenResult>(); + let classified = + classified_with_piece(SampledToken::Undeterminable(LlamaToken::new(5)), "?"); + + emit_classified(&classified, &tx); + + match rx.try_recv() { + Ok(GeneratedTokenResult::UndeterminableToken(text)) if text == "?" => Ok(()), + other => bail!("expected UndeterminableToken(\"?\"), got {other:?}"), + } + } + + #[test] + fn dropped_receiver_returns_channel_dropped() -> Result<()> { + let (tx, rx) = mpsc::unbounded_channel::<GeneratedTokenResult>(); + drop(rx); + let classified = classified_with_piece(SampledToken::Content(LlamaToken::new(6)), "hi"); + + match emit_classified(&classified, &tx) { + EmitTokenOutcome::ChannelDropped => Ok(()), + EmitTokenOutcome::Emitted(piece) => { + bail!("expected ChannelDropped on dropped receiver, got Emitted({piece:?})") + } + } + } +} diff --git a/paddler/src/agent/continuous_batch_scheduler/generating_contribution.rs b/paddler/src/agent/continuous_batch_scheduler/generating_contribution.rs new file mode 100644 index 00000000..4eeb07fa --- /dev/null +++ b/paddler/src/agent/continuous_batch_scheduler/generating_contribution.rs @@ -0,0 +1,4 @@ +pub struct GeneratingContribution { + pub request_index: usize, + pub batch_position: i32, +} diff --git a/paddler/src/agent/continuous_batch_scheduler/ingesting_contribution.rs b/paddler/src/agent/continuous_batch_scheduler/ingesting_contribution.rs new file mode 100644 index 00000000..48a2b032 --- /dev/null +++ b/paddler/src/agent/continuous_batch_scheduler/ingesting_contribution.rs @@ -0,0 +1,6 @@ +pub struct IngestingContribution { + pub request_index: usize, + pub chunk_size: usize, + pub is_last_chunk: bool, + pub last_batch_position: i32, +} diff --git a/paddler/src/agent/continuous_batch_scheduler.rs b/paddler/src/agent/continuous_batch_scheduler/mod.rs similarity index 70% rename from paddler/src/agent/continuous_batch_scheduler.rs rename to paddler/src/agent/continuous_batch_scheduler/mod.rs index 50e29b7f..a21b5f2e 100644 --- a/paddler/src/agent/continuous_batch_scheduler.rs +++ b/paddler/src/agent/continuous_batch_scheduler/mod.rs @@ -1,17 +1,39 @@ +pub mod advance_generating_phase; +pub mod advance_outcome; +pub mod assemble_batch_phase; +pub mod batch_pass; +pub mod classified_token; +pub mod classify_token_phase; +pub mod commit_phase; +pub mod completion_check_outcome; +pub mod completion_check_phase; +pub mod contributions; +pub mod decode_batch_phase; +pub mod decode_outcome; +pub mod emit_token_outcome; +pub mod emit_token_phase; +pub mod generating_contribution; +pub mod ingesting_contribution; +pub mod sample_outcome; +pub mod sample_token_phase; +pub mod tool_call_pass; +pub mod tool_call_pipeline_build_outcome; + use std::collections::VecDeque; use std::sync::Arc; use std::sync::mpsc::Receiver; use std::sync::mpsc::TryRecvError; use std::time::Duration; +use anyhow::Context as _; use anyhow::Result; use anyhow::anyhow; -use llama_cpp_bindings::DecodeError; use llama_cpp_bindings::context::LlamaContext; -use 
llama_cpp_bindings::llama_batch::LlamaBatch; +use llama_cpp_bindings::error::EvalMultimodalChunksError; use llama_cpp_bindings::model::AddBos; use llama_cpp_bindings::mtmd::MtmdBitmap; use llama_cpp_bindings::mtmd::MtmdContext; +use llama_cpp_bindings::mtmd::MtmdEvalError; use llama_cpp_bindings::mtmd::MtmdInputText; use llama_cpp_bindings::sampling::LlamaSampler; use log::debug; @@ -20,11 +42,20 @@ use log::info; use log::warn; use paddler_types::embedding_result::EmbeddingResult; use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::generation_summary::GenerationSummary; +use paddler_types::oversized_image_details::OversizedImageDetails; use paddler_types::request_params::ContinueFromRawPromptParams; +use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; use rand::Rng as _; use rand::rngs::ThreadRng; use tokio::sync::mpsc; +use self::advance_generating_phase::AdvanceGeneratingPhase; +use self::assemble_batch_phase::AssembleBatchPhase; +use self::batch_pass::BatchPass; +use self::decode_outcome::DecodeOutcome; +use self::tool_call_pipeline_build_outcome::ToolCallPipelineBuildOutcome; use crate::agent::continue_from_conversation_history_request::ContinueFromConversationHistoryRequest; use crate::agent::continue_from_raw_prompt_request::ContinueFromRawPromptRequest; use crate::agent::continuous_batch_active_request::ContinuousBatchActiveRequest; @@ -40,21 +71,11 @@ use crate::agent::resolve_grammar::resolve_grammar; use crate::agent::sample_token_at_batch_index::sample_token_at_batch_index; use crate::agent::sampling_outcome::SamplingOutcome; use crate::agent::sequence_id_pool::SequenceIdPool; +use crate::agent::slot_guard::SlotGuard; use crate::decoded_image::DecodedImage; -use crate::dispenses_slots::DispensesSlots; -use crate::slot_aggregated_status::SlotAggregatedStatus; - -struct GeneratingContribution { - request_index: usize, - batch_position: i32, -} - -struct IngestingContribution { - request_index: usize, - chunk_size: usize, - is_last_chunk: bool, - last_batch_position: i32, -} +use crate::tool_call_pipeline::ToolCallPipeline; +use crate::tool_call_validator::ToolCallValidator; +use crate::tool_call_validator::ValidatorBuildError; pub struct ContinuousBatchScheduler { active_requests: Vec<ContinuousBatchActiveRequest>, @@ -65,10 +86,10 @@ pub struct ContinuousBatchScheduler { running: bool, scheduler_context: Arc<SchedulerContext>, sequence_id_pool: SequenceIdPool, - slot_aggregated_status: Arc<SlotAggregatedStatus>, } impl ContinuousBatchScheduler { + #[must_use] #[expect( unsafe_code, reason = "required for FFI lifetime extension with llama.cpp" )] @@ -78,7 +99,6 @@ impl ContinuousBatchScheduler { scheduler_context: Arc<SchedulerContext>, llama_context: LlamaContext, max_concurrent_sequences: i32, - slot_aggregated_status: Arc<SlotAggregatedStatus>, ) -> Self { let llama_context = unsafe { std::mem::transmute::<LlamaContext<'_>, LlamaContext<'static>>(llama_context) @@ -93,7 +113,6 @@ running: true, scheduler_context, sequence_id_pool: SequenceIdPool::new(max_concurrent_sequences), - slot_aggregated_status, } } @@ -177,13 +196,15 @@ impl ContinuousBatchScheduler { fn accept_conversation_history_request( &mut self, - request: ContinueFromConversationHistoryRequest, + ContinueFromConversationHistoryRequest { + generate_tokens_stop_rx, + generated_tokens_tx, + params, + slot_guard, + }: ContinueFromConversationHistoryRequest, 
) { - let generated_tokens_tx = request.generated_tokens_tx; - let generate_tokens_stop_rx = request.generate_tokens_stop_rx; - let prepared = match prepare_conversation_history_request( - request.params, + params, &generated_tokens_tx, &self.scheduler_context, ) { @@ -203,32 +224,52 @@ raw_prompt, max_tokens, grammar_sampler, + parse_tool_calls, + tools, } => { - self.accept_text_prompt( + if let Err(err) = self.accept_text_prompt( &raw_prompt, max_tokens, grammar_sampler, + parse_tool_calls, + tools, generated_tokens_tx, generate_tokens_stop_rx, - ); + slot_guard, + ) { + error!( + "{:?}: failed to accept text prompt: {err:#}", + self.scheduler_context.agent_name + ); + } } PreparedConversationHistoryRequest::MultimodalPrompt { raw_prompt, images, max_tokens, grammar_sampler, + parse_tool_calls, + tools, } => { let multimodal_context = self.scheduler_context.multimodal_context.clone(); - if let Some(multimodal_context) = multimodal_context.as_ref() { - self.accept_multimodal_request( + if let Some(multimodal_context) = multimodal_context.as_ref() + && let Err(err) = self.accept_multimodal_request( multimodal_context, raw_prompt, &images, max_tokens, grammar_sampler, + parse_tool_calls, + tools, generated_tokens_tx, generate_tokens_stop_rx, + slot_guard, + ) + { + error!( + "{:?}: failed to accept multimodal request: {err:#}", + self.scheduler_context.agent_name ); } } @@ -246,6 +287,7 @@ max_tokens, raw_prompt, }, + slot_guard, }: ContinueFromRawPromptRequest, ) { let grammar_sampler = match resolve_grammar(grammar.as_ref(), false, &generated_tokens_tx) { @@ -260,13 +302,21 @@ } }; - self.accept_text_prompt( + if let Err(err) = self.accept_text_prompt( &raw_prompt, max_tokens, grammar_sampler, + false, + Vec::new(), generated_tokens_tx, generate_tokens_stop_rx, - ); + slot_guard, + ) { + error!( + "{:?}: failed to accept raw prompt: {err:#}", + self.scheduler_context.agent_name + ); + } } fn create_sampler_chain(&mut self) -> LlamaSampler { @@ -324,14 +374,99 @@ ) } + #[expect( + unsafe_code, + reason = "the SchedulerContext owns the LlamaModel for the lifetime of the active_requests vec — same pattern as LlamaContext<'static> above" + )] + fn build_token_classifier_for_active_request( + &self, + ) -> llama_cpp_bindings::SampledTokenClassifier<'static> { + let classifier = self.scheduler_context.model.sampled_token_classifier(); + + unsafe { + std::mem::transmute::< + llama_cpp_bindings::SampledTokenClassifier<'_>, + llama_cpp_bindings::SampledTokenClassifier<'static>, + >(classifier) + } + }
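// The transmute above leans on the same ownership argument as the
// `LlamaContext<'static>` extension in `new`: the `Arc<SchedulerContext>`
// that owns the `LlamaModel` outlives every entry in `active_requests`, so
// the forged 'static lifetime is never observed past it. A hedged,
// self-contained miniature of the pattern (illustrative only; `Owner` and
// `extend_to_static` are hypothetical names, not part of this diff):
//
//     struct Owner { text: String }
//
//     /// SAFETY: the caller must guarantee `owner` outlives every use of
//     /// the returned reference; the 'static lifetime is a deliberate lie.
//     unsafe fn extend_to_static(owner: &Owner) -> &'static str {
//         std::mem::transmute::<&str, &'static str>(owner.text.as_str())
//     }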
+ + fn build_tool_call_pipeline( + &self, + tools: Vec<Tool<ValidatedParametersSchema>>, + parse_tool_calls: bool, + ) -> Result<ToolCallPipelineBuildOutcome> { + if !parse_tool_calls || tools.is_empty() { + return Ok(ToolCallPipelineBuildOutcome::Disabled); + } + + let validator = match ToolCallValidator::from_tools(&tools) { + Ok(validator) => validator, + Err(ValidatorBuildError::InvalidSchema { tool_name, message }) => { + return Ok(ToolCallPipelineBuildOutcome::SchemaInvalid(format!( + "tool {tool_name:?} parameters are not a valid JSON Schema: {message}" + ))); + } + Err(err @ ValidatorBuildError::SerializationFailed { .. }) => { + return Err(anyhow::Error::from(err)) + .context("failed to serialize tool parameters during validator build"); + } + }; + + let tools_json: Vec<serde_json::Value> = tools + .into_iter() + .map(|tool| serde_json::to_value(&tool)) + .collect::<Result<Vec<_>, _>>() + .context("failed to serialize tools to JSON")?; + + let pipeline = + ToolCallPipeline::new(self.scheduler_context.model.clone(), &tools_json, validator) + .context("failed to serialize tools for tool-call pipeline")?; + + Ok(ToolCallPipelineBuildOutcome::Ready(pipeline)) + } + + #[expect( + clippy::too_many_arguments, + reason = "text prompt acceptance genuinely needs all these parameters from the caller" + )] fn accept_text_prompt( &mut self, prompt: &str, max_tokens: i32, grammar_sampler: Option<GrammarSampler>, + parse_tool_calls: bool, + tools: Vec<Tool<ValidatedParametersSchema>>, generated_tokens_tx: mpsc::UnboundedSender<GeneratedTokenResult>, generate_tokens_stop_rx: mpsc::UnboundedReceiver<()>, - ) { + slot_guard: SlotGuard, + ) -> Result<()> { + let tool_call_pipeline = match self + .build_tool_call_pipeline(tools, parse_tool_calls) + .context("failed to build tool-call pipeline for text prompt")? + { + ToolCallPipelineBuildOutcome::Disabled => None, + ToolCallPipelineBuildOutcome::Ready(pipeline) => Some(pipeline), + ToolCallPipelineBuildOutcome::SchemaInvalid(message) => { + error!( + "{:?}: rejecting text prompt: {message}", + self.scheduler_context.agent_name + ); + + if generated_tokens_tx + .send(GeneratedTokenResult::ToolSchemaInvalid(message)) + .is_err() + { + warn!( + "{:?}: failed to send result to client (receiver dropped)", + self.scheduler_context.agent_name + ); + } + + return Ok(()); + } + }; + let mut sequence_id_option = self.sequence_id_pool.acquire(); if sequence_id_option.is_none() { @@ -357,7 +492,7 @@ ); } - return; + return Ok(()); }; let prompt_tokens = match self @@ -385,7 +520,7 @@ ); } - return; + return Ok(()); } }; @@ -394,11 +529,16 @@ else { self.sequence_id_pool.release(sequence_id); - return; + return Ok(()); }; let chain = self.create_sampler_chain(); + let mut token_classifier = self.build_token_classifier_for_active_request(); + + token_classifier.record_prompt_tokens(prompt_tokens.len() as u64); + token_classifier.ingest_prompt_tokens(&prompt_tokens); + #[expect( clippy::cast_sign_loss, reason = "sequence IDs are always non-negative" )] @@ -413,8 +553,6 @@ ); } - self.slot_aggregated_status.take_slot(); - debug!( "{:?}: accepted text prompt request on sequence {sequence_id} ({} tokens)", self.scheduler_context.agent_name, @@ -423,9 +561,9 @@ self.active_requests.push(ContinuousBatchActiveRequest { chain, + token_classifier, current_token_position: 0, grammar_sampler: llama_grammar_sampler, - generated_tokens_count: 0, generated_tokens_tx, generate_tokens_stop_rx, i_batch: None, @@ -435,8 +573,11 @@ prompt_tokens, prompt_tokens_ingested: 0, sequence_id, - utf8_decoder: encoding_rs::UTF_8.new_decoder(), + slot_guard, + tool_call_pipeline, }); + + Ok(()) + } #[expect( @@ -450,9 +591,38 @@ images: &[DecodedImage], max_tokens: i32, grammar_sampler: Option<GrammarSampler>, + parse_tool_calls: bool, + tools: Vec<Tool<ValidatedParametersSchema>>, generated_tokens_tx: mpsc::UnboundedSender<GeneratedTokenResult>, generate_tokens_stop_rx: mpsc::UnboundedReceiver<()>, - ) { + slot_guard: SlotGuard, + ) -> Result<()> { + let tool_call_pipeline = match self + .build_tool_call_pipeline(tools, parse_tool_calls) + .context("failed to build 
tool-call pipeline for multimodal request")? + { + ToolCallPipelineBuildOutcome::Disabled => None, + ToolCallPipelineBuildOutcome::Ready(pipeline) => Some(pipeline), + ToolCallPipelineBuildOutcome::SchemaInvalid(message) => { + error!( + "{:?}: rejecting multimodal request: {message}", + self.scheduler_context.agent_name + ); + + if generated_tokens_tx + .send(GeneratedTokenResult::ToolSchemaInvalid(message)) + .is_err() + { + warn!( + "{:?}: failed to send result to client (receiver dropped)", + self.scheduler_context.agent_name + ); + } + + return Ok(()); + } + }; + let Some(sequence_id) = self.sequence_id_pool.acquire() else { let message = format!( "{:?}: no available sequence slots for multimodal request", @@ -471,7 +641,7 @@ ); } - return; + return Ok(()); }; let bitmaps: Vec<MtmdBitmap> = match images @@ -502,7 +672,7 @@ ); } - return; + return Ok(()); } }; @@ -538,11 +708,11 @@ ); } - return; + return Ok(()); } }; - let batch_size = self.scheduler_context.inference_parameters.batch_n_tokens; + let batch_size = self.scheduler_context.inference_parameters.n_batch; #[expect( clippy::cast_sign_loss, @@ -560,23 +730,52 @@ self.harvest_pending_samples_before_external_decode(); + let mut token_classifier = self.build_token_classifier_for_active_request(); + #[expect( clippy::cast_possible_truncation, clippy::cast_possible_wrap, reason = "batch_size fits in i32 for llama.cpp FFI" )] - let tokens_ingested = match input_chunks - .eval_chunks( - multimodal_context, - &self.llama_context, - 0, - sequence_id, - batch_size as i32, - true, - ) - .map_err(|err| anyhow!("Failed to evaluate multimodal chunks: {err}")) - { + let eval_outcome = token_classifier.eval_multimodal_chunks( + &input_chunks, + multimodal_context, + &self.llama_context, + 0, + sequence_id, + batch_size as i32, + true, + ); + + let tokens_ingested = match eval_outcome { Ok(tokens_ingested) => tokens_ingested, + Err(EvalMultimodalChunksError::EvalFailed( + MtmdEvalError::ImageChunkExceedsBatchSize(mismatch), + )) => { + warn!( + "{:?}: refused multimodal request: image chunk has {} tokens but n_batch is {}", + self.scheduler_context.agent_name, mismatch.image_tokens, mismatch.n_batch, + ); + + self.sequence_id_pool.release(sequence_id); + + if generated_tokens_tx + .send(GeneratedTokenResult::ImageExceedsBatchSize( + OversizedImageDetails { + image_tokens: mismatch.image_tokens, + n_batch: mismatch.n_batch, + }, + )) + .is_err() + { + warn!( + "{:?}: failed to send result to client (receiver dropped)", + self.scheduler_context.agent_name + ); + } + + return Ok(()); + } Err(err) => { let message = format!( "{:?}: failed to ingest multimodal prompt: {err}", @@ -596,7 +795,7 @@ ); } - return; + return Ok(()); } }; @@ -607,13 +806,11 @@ else { self.sequence_id_pool.release(sequence_id); - return; + return Ok(()); }; let chain = self.create_sampler_chain(); - self.slot_aggregated_status.take_slot(); - debug!( "{:?}: accepted multimodal request on sequence {sequence_id} ({tokens_ingested} tokens ingested)", self.scheduler_context.agent_name ); self.active_requests.push(ContinuousBatchActiveRequest { chain, + token_classifier, current_token_position: tokens_ingested, grammar_sampler: llama_grammar_sampler, - generated_tokens_count: 0, generated_tokens_tx, generate_tokens_stop_rx, i_batch: Some(-1), @@ -633,8 
+830,11 @@ impl ContinuousBatchScheduler { prompt_tokens: Vec::new(), prompt_tokens_ingested: 0, sequence_id, - utf8_decoder: encoding_rs::UTF_8.new_decoder(), + slot_guard, + tool_call_pipeline, }); + + Ok(()) + } fn harvest_pending_samples_before_external_decode(&mut self) { @@ -660,8 +860,15 @@ &mut active_request.chain, &mut active_request.grammar_sampler, ) { - Ok(SamplingOutcome::Token(sampled_token)) => { - active_request.pending_sampled_token = Some(sampled_token); + Ok(SamplingOutcome::Token(raw_token)) => { + // Update classifier state (section / usage counters) but drop the + // outcomes — harvest-sampled tokens are funnelled into the next + // batch via `pending_sampled_token`; their user-visible emission + // happens in `advance_generating_phase` after the next decode, + // not here. + let _ = active_request.token_classifier.ingest(raw_token); + active_request.pending_sampled_token = + Some(llama_cpp_bindings::SampledToken::Content(raw_token)); active_request.i_batch = None; } Ok(SamplingOutcome::AllCandidatesEliminated) => { @@ -703,9 +910,13 @@ fn check_stop_signals(&mut self) { for active_request in &mut self.active_requests { if active_request.is_stop_requested() { + let summary = GenerationSummary { + usage: *active_request.token_classifier.usage(), + }; + active_request.complete_with_outcome( + &self.scheduler_context.agent_name, - GeneratedTokenResult::Done, + GeneratedTokenResult::Done(summary), ); } } @@ -719,10 +930,7 @@ if self.has_active_requests() { if request .generated_embedding_tx - .send(EmbeddingResult::Error( - "Embedding requests cannot be processed while generation requests are active" - .to_owned(), - )) + .send(EmbeddingResult::EmbeddingRejectedDueToActiveTokenGeneration) .is_err() { warn!( @@ -756,7 +964,8 @@ fn execute_one_iteration(&mut self) -> Result<()> { self.advance_generating_requests(); - let batch_n_tokens = self.scheduler_context.inference_parameters.batch_n_tokens; + let n_batch = self.scheduler_context.inference_parameters.n_batch; + let assemble_phase = AssembleBatchPhase { n_batch }; loop { let max_sequences = self.active_requests.len(); @@ ... #[expect( clippy::cast_possible_truncation, clippy::cast_possible_wrap, reason = "token counts and positions fit in i32 for llama.cpp FFI" )] - let mut batch = LlamaBatch::new(batch_n_tokens, max_sequences.max(1) as i32)?; + let mut pass = BatchPass::new(n_batch, max_sequences.max(1) as i32)?; - let mut generating_contributions: Vec<GeneratingContribution> = Vec::new(); - let mut ingesting_contributions: Vec<IngestingContribution> = Vec::new(); + assemble_phase.run(&mut pass, &mut self.active_requests)?; - let mut current_batch_token_count: usize = 0; - - current_batch_token_count += self.add_generating_pending_tokens_to_batch( - &mut batch, - batch_n_tokens, - &mut generating_contributions, - )?; - - self.add_ingesting_prompt_chunks_to_batch( - &mut batch, - batch_n_tokens, - current_batch_token_count, - &mut ingesting_contributions, - )?; - - if batch.n_tokens() == 0 { + if pass.is_empty() { return Ok(()); } debug!( "{:?}: decoding batch with {} tokens for {} active requests", self.scheduler_context.agent_name, - batch.n_tokens(), + pass.batch.n_tokens(), self.active_requests.len() ); - match self.llama_context.decode(&mut batch) { - Ok(()) => { - self.commit_contributions(&generating_contributions, &ingesting_contributions); + match decode_batch_phase::run(&mut pass, &mut self.llama_context) { + 
DecodeOutcome::Decoded => { + commit_phase::run(pass, &mut self.active_requests); return Ok(()); } - Err(DecodeError::NoKvCacheSlot) => { + DecodeOutcome::NeedsEviction => { self.evict_largest_sequence(); if self.active_requests.is_empty() { return Ok(()); } } - Err(DecodeError::Aborted | DecodeError::NTokensZero) => { + DecodeOutcome::Aborted => { return Ok(()); } - Err(DecodeError::Unknown(error_code)) => { + DecodeOutcome::Errored(error_code) => { return Err(anyhow!( "Decode failed with unknown error code: {error_code}" )); @@ -823,254 +1016,11 @@ impl ContinuousBatchScheduler { } fn advance_generating_requests(&mut self) { - for active_request in &mut self.active_requests { - if !matches!( - active_request.phase, - ContinuousBatchRequestPhase::Generating - ) { - continue; - } - - if active_request.pending_sampled_token.is_some() { - continue; - } - - let Some(batch_index) = active_request.i_batch else { - continue; - }; - - let sampled_token = match sample_token_at_batch_index( - &self.llama_context, - batch_index, - &mut active_request.chain, - &mut active_request.grammar_sampler, - ) { - Ok(SamplingOutcome::Token(sampled_token)) => sampled_token, - Ok(SamplingOutcome::AllCandidatesEliminated) => { - error!( - "{:?}: sequence {} sampling exhausted candidates", - self.scheduler_context.agent_name, active_request.sequence_id - ); - active_request.complete_with_outcome( - &self.scheduler_context.agent_name, - GeneratedTokenResult::SamplerError( - "all token candidates were eliminated during sampling".to_owned(), - ), - ); - continue; - } - Ok(SamplingOutcome::GrammarRejectedModelOutput(message)) => { - error!( - "{:?}: sequence {} grammar rejected sampled token: {message}", - self.scheduler_context.agent_name, active_request.sequence_id - ); - active_request.complete_with_outcome( - &self.scheduler_context.agent_name, - GeneratedTokenResult::GrammarRejectedModelOutput(message), - ); - continue; - } - Err(err) => { - error!( - "{:?}: sequence {} sampling error: {err:#}", - self.scheduler_context.agent_name, active_request.sequence_id - ); - active_request.complete_with_outcome( - &self.scheduler_context.agent_name, - GeneratedTokenResult::SamplerError(err.to_string()), - ); - continue; - } - }; - - if self.scheduler_context.model.is_eog_token(sampled_token) { - active_request.complete_with_outcome( - &self.scheduler_context.agent_name, - GeneratedTokenResult::Done, - ); - continue; - } - - let output_string = match self.scheduler_context.model.token_to_piece( - sampled_token, - &mut active_request.utf8_decoder, - true, - None, - ) { - Ok(output_string) => output_string, - Err(err) => { - error!( - "{:?}: sequence {} token_to_piece failed: {err}", - self.scheduler_context.agent_name, active_request.sequence_id - ); - active_request.complete_with_outcome( - &self.scheduler_context.agent_name, - GeneratedTokenResult::SamplerError(format!( - "Failed to convert token to string: {err}" - )), - ); - continue; - } - }; - - if active_request - .generated_tokens_tx - .send(GeneratedTokenResult::Token(output_string)) - .is_err() - { - warn!( - "{:?}: sequence {} client disconnected (receiver dropped)", - self.scheduler_context.agent_name, active_request.sequence_id - ); - - active_request.i_batch = None; - active_request.phase = ContinuousBatchRequestPhase::Completed; - - continue; - } - - active_request.generated_tokens_count += 1; - - if active_request.generated_tokens_count >= active_request.max_tokens { - active_request.complete_with_outcome( - &self.scheduler_context.agent_name, - 
GeneratedTokenResult::Done, - ); - continue; - } - - active_request.pending_sampled_token = Some(sampled_token); - } - } - - fn add_generating_pending_tokens_to_batch( - &self, - batch: &mut LlamaBatch, - batch_n_tokens: usize, - contributions: &mut Vec<GeneratingContribution>, - ) -> Result<usize> { - let mut tokens_added: usize = 0; - - for (request_index, active_request) in self.active_requests.iter().enumerate() { - if !matches!( - active_request.phase, - ContinuousBatchRequestPhase::Generating - ) { - continue; - } - - let Some(pending_token) = active_request.pending_sampled_token else { - continue; - }; - - if tokens_added >= batch_n_tokens { - break; - } - - let batch_position = batch.n_tokens(); - - batch.add( - pending_token, - active_request.current_token_position, - &[active_request.sequence_id], - true, - )?; - - contributions.push(GeneratingContribution { - request_index, - batch_position, - }); - - tokens_added += 1; - } - - Ok(tokens_added) - } - - #[expect( - clippy::cast_possible_truncation, - clippy::cast_possible_wrap, - reason = "token counts and positions fit in i32 for llama.cpp FFI" - )] - fn add_ingesting_prompt_chunks_to_batch( - &self, - batch: &mut LlamaBatch, - batch_n_tokens: usize, - mut current_batch_token_count: usize, - contributions: &mut Vec<IngestingContribution>, - ) -> Result<()> { - for (request_index, active_request) in self.active_requests.iter().enumerate() { - if !matches!(active_request.phase, ContinuousBatchRequestPhase::Ingesting) { - continue; - } - - let remaining = active_request.remaining_prompt_tokens(); - let available_space = batch_n_tokens.saturating_sub(current_batch_token_count); - let chunk_size = remaining.len().min(available_space); - - if chunk_size == 0 { - continue; - } - - let chunk = &active_request.prompt_tokens[active_request.prompt_tokens_ingested - ..active_request.prompt_tokens_ingested + chunk_size]; - let is_last_chunk = active_request.prompt_tokens_ingested + chunk_size - >= active_request.prompt_tokens.len(); - - for (offset, token) in chunk.iter().enumerate() { - let position = active_request.current_token_position + offset as i32; - let is_last_token_of_prompt = is_last_chunk && offset == chunk_size - 1; - - batch.add( - *token, - position, - &[active_request.sequence_id], - is_last_token_of_prompt, - )?; - } - - contributions.push(IngestingContribution { - request_index, - chunk_size, - is_last_chunk, - last_batch_position: batch.n_tokens() - 1, - }); - - current_batch_token_count += chunk_size; - } - - Ok(()) - } - - #[expect( - clippy::cast_possible_truncation, - clippy::cast_possible_wrap, - reason = "chunk sizes fit in i32 for llama.cpp position arithmetic" - )] - fn commit_contributions( - &mut self, - generating_contributions: &[GeneratingContribution], - ingesting_contributions: &[IngestingContribution], - ) { - for contribution in generating_contributions { - let request = &mut self.active_requests[contribution.request_index]; - - request.pending_sampled_token = None; - request.i_batch = Some(contribution.batch_position); - request.current_token_position += 1; - } - - for contribution in ingesting_contributions { - let request = &mut self.active_requests[contribution.request_index]; - - request.prompt_tokens_ingested += contribution.chunk_size; - request.current_token_position += contribution.chunk_size as i32; - - if contribution.is_last_chunk { - request.i_batch = Some(contribution.last_batch_position); - request.phase = ContinuousBatchRequestPhase::Generating; - } + AdvanceGeneratingPhase { + scheduler_context: &self.scheduler_context, + llama_context: &self.llama_context, } + .run(&mut self.active_requests); }
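// For orientation, a hedged summary (paraphrase of the code above, not new
// behavior) of the loop in `execute_one_iteration` that now drives the
// extracted phase modules:
//
//     loop {
//         assemble_phase.run(&mut pass, &mut self.active_requests)?; // fill contributions
//         match decode_batch_phase::run(&mut pass, &mut self.llama_context) {
//             DecodeOutcome::Decoded => { commit_phase::run(pass, ..); return Ok(()); }
//             DecodeOutcome::NeedsEviction => self.evict_largest_sequence(), // then retry
//             DecodeOutcome::Aborted => return Ok(()),
//             DecodeOutcome::Errored(code) => return Err(..),
//         }
//     }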
fn evict_largest_sequence(&mut self) { @@ -1151,13 +1101,14 @@ impl ContinuousBatchScheduler { } self.sequence_id_pool.release(removed_request.sequence_id); - self.slot_aggregated_status.release_slot(); + + let usage = removed_request.token_classifier.usage(); debug!( - "{:?}: cleaned up sequence {} ({} tokens generated)", + "{:?}: cleaned up sequence {} ({} completion tokens generated)", self.scheduler_context.agent_name, removed_request.sequence_id, - removed_request.generated_tokens_count, + usage.content_tokens + usage.reasoning_tokens + usage.undeterminable_tokens, ); } } diff --git a/paddler/src/agent/continuous_batch_scheduler/sample_outcome.rs b/paddler/src/agent/continuous_batch_scheduler/sample_outcome.rs new file mode 100644 index 00000000..67a7d951 --- /dev/null +++ b/paddler/src/agent/continuous_batch_scheduler/sample_outcome.rs @@ -0,0 +1,8 @@ +use llama_cpp_bindings::token::LlamaToken; + +pub enum SampleOutcome { + Sampled(LlamaToken), + AllCandidatesEliminated, + GrammarRejected(String), + Failed(String), +} diff --git a/paddler/src/agent/continuous_batch_scheduler/sample_token_phase.rs b/paddler/src/agent/continuous_batch_scheduler/sample_token_phase.rs new file mode 100644 index 00000000..14f9c921 --- /dev/null +++ b/paddler/src/agent/continuous_batch_scheduler/sample_token_phase.rs @@ -0,0 +1,32 @@ +use llama_cpp_bindings::context::LlamaContext; + +use crate::agent::continuous_batch_active_request::ContinuousBatchActiveRequest; +use crate::agent::continuous_batch_scheduler::sample_outcome::SampleOutcome; +use crate::agent::sample_token_at_batch_index::sample_token_at_batch_index; +use crate::agent::sampling_outcome::SamplingOutcome; + +pub struct SampleTokenPhase<'context> { + pub context: &'context LlamaContext<'context>, +} + +impl SampleTokenPhase<'_> { + pub fn run( + &self, + request: &mut ContinuousBatchActiveRequest, + batch_index: i32, + ) -> SampleOutcome { + match sample_token_at_batch_index( + self.context, + batch_index, + &mut request.chain, + &mut request.grammar_sampler, + ) { + Ok(SamplingOutcome::Token(token)) => SampleOutcome::Sampled(token), + Ok(SamplingOutcome::AllCandidatesEliminated) => SampleOutcome::AllCandidatesEliminated, + Ok(SamplingOutcome::GrammarRejectedModelOutput(message)) => { + SampleOutcome::GrammarRejected(message) + } + Err(err) => SampleOutcome::Failed(err.to_string()), + } + } +} diff --git a/paddler/src/agent/continuous_batch_scheduler/tool_call_pass.rs b/paddler/src/agent/continuous_batch_scheduler/tool_call_pass.rs new file mode 100644 index 00000000..4373c1c7 --- /dev/null +++ b/paddler/src/agent/continuous_batch_scheduler/tool_call_pass.rs @@ -0,0 +1,72 @@ +use llama_cpp_bindings::SampledToken; +use paddler_types::generated_token_result::GeneratedTokenResult; + +use crate::agent::continuous_batch_scheduler::classified_token::ClassifiedToken; +use crate::tool_call_pipeline::ToolCallPipeline; + +#[must_use] +pub fn run( + pipeline: Option<&mut ToolCallPipeline>, + classified: &ClassifiedToken, +) -> Option<GeneratedTokenResult> { + let pipeline = pipeline?; + + if matches!(classified.sampled_token, SampledToken::ToolCall(_)) { + pipeline.feed(&classified.raw_piece); + } + + if !classified.was_in_tool_call || classified.is_in_tool_call { + return None; + } + + pipeline.finalize_to_generated_event() +}
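// A hedged walk-through of the transition logic above, using the
// (was_in_tool_call, is_in_tool_call) flags produced by the classifier:
//
//     (false, true)  first ToolCall token   -> feed(raw_piece), no event
//     (true,  true)  inside the tool call   -> feed(raw_piece), no event
//     (true,  false) section just ended     -> finalize_to_generated_event()
//     (false, false) ordinary content       -> nothing to do
//
// Note that `feed` only ever sees tokens classified as
// SampledToken::ToolCall, so the buffered tool-call payload never includes
// surrounding content or reasoning pieces.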
+ +#[cfg(test)] +mod tests { + use llama_cpp_bindings::SampledToken; + use llama_cpp_bindings::token::LlamaToken; + + use super::run; + use crate::agent::continuous_batch_scheduler::classified_token::ClassifiedToken; + + fn classified(was: bool, is: bool, sampled: SampledToken) -> ClassifiedToken { + ClassifiedToken { + sampled_token: sampled, + was_in_tool_call: was, + is_in_tool_call: is, + visible_piece: String::new(), + raw_piece: String::new(), + } + } + + #[test] + fn pipeline_none_returns_none_for_content_token() { + let result = run( + None, + &classified(false, false, SampledToken::Content(LlamaToken::new(1))), + ); + + assert!(result.is_none()); + } + + #[test] + fn pipeline_none_returns_none_for_tool_call_token() { + let result = run( + None, + &classified(true, true, SampledToken::ToolCall(LlamaToken::new(2))), + ); + + assert!(result.is_none()); + } + + #[test] + fn pipeline_none_returns_none_on_transition_out() { + let result = run( + None, + &classified(true, false, SampledToken::ToolCall(LlamaToken::new(3))), + ); + + assert!(result.is_none()); + } +} diff --git a/paddler/src/agent/continuous_batch_scheduler/tool_call_pipeline_build_outcome.rs b/paddler/src/agent/continuous_batch_scheduler/tool_call_pipeline_build_outcome.rs new file mode 100644 index 00000000..5304d131 --- /dev/null +++ b/paddler/src/agent/continuous_batch_scheduler/tool_call_pipeline_build_outcome.rs @@ -0,0 +1,7 @@ +use crate::tool_call_pipeline::ToolCallPipeline; + +pub enum ToolCallPipelineBuildOutcome { + Disabled, + Ready(ToolCallPipeline), + SchemaInvalid(String), +} diff --git a/paddler/src/agent/from_request_params.rs b/paddler/src/agent/from_request_params.rs index 6e7d79b4..db9371d2 100644 --- a/paddler/src/agent/from_request_params.rs +++ b/paddler/src/agent/from_request_params.rs @@ -1,6 +1,9 @@ +use std::sync::Arc; + use tokio::sync::mpsc; use crate::agent::jsonrpc::response::Response; +use crate::slot_aggregated_status::SlotAggregatedStatus; pub trait FromRequestParams: Send + Sync { type RequestParams; @@ -10,5 +13,6 @@ params: Self::RequestParams, response_tx: mpsc::UnboundedSender, stop_rx: mpsc::UnboundedReceiver<()>, + slot_aggregated_status: Arc<SlotAggregatedStatus>, ) -> Self; } diff --git a/paddler/src/agent/generate_embedding_batch_request.rs b/paddler/src/agent/generate_embedding_batch_request.rs index 112b6228..f7139022 100644 --- a/paddler/src/agent/generate_embedding_batch_request.rs +++ b/paddler/src/agent/generate_embedding_batch_request.rs @@ -1,13 +1,18 @@ +use std::sync::Arc; + use paddler_types::embedding_result::EmbeddingResult; use paddler_types::request_params::GenerateEmbeddingBatchParams; use tokio::sync::mpsc; use crate::agent::from_request_params::FromRequestParams; +use crate::agent::slot_guard::SlotGuard; +use crate::slot_aggregated_status::SlotAggregatedStatus; pub struct GenerateEmbeddingBatchRequest { pub generate_embedding_stop_rx: mpsc::UnboundedReceiver<()>, pub generated_embedding_tx: mpsc::UnboundedSender<EmbeddingResult>, pub params: GenerateEmbeddingBatchParams, + pub slot_guard: SlotGuard, } impl FromRequestParams for GenerateEmbeddingBatchRequest { @@ -18,11 +23,13 @@ params: Self::RequestParams, generated_embedding_tx: mpsc::UnboundedSender<EmbeddingResult>, generate_embedding_stop_rx: mpsc::UnboundedReceiver<()>, + slot_aggregated_status: Arc<SlotAggregatedStatus>, ) -> Self { Self { generate_embedding_stop_rx, generated_embedding_tx, params, + slot_guard: SlotGuard::new(slot_aggregated_status), } } } diff --git a/paddler/src/agent/jsonrpc/notification_params/set_state_params.rs b/paddler/src/agent/jsonrpc/notification_params/set_state_params.rs index 
00060ea4..99876259 100644 --- a/paddler/src/agent/jsonrpc/notification_params/set_state_params.rs +++ b/paddler/src/agent/jsonrpc/notification_params/set_state_params.rs @@ -1,8 +1,7 @@ +use paddler_types::agent_desired_state::AgentDesiredState; use serde::Deserialize; use serde::Serialize; -use crate::agent_desired_state::AgentDesiredState; - #[derive(Clone, Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields)] pub struct SetStateParams { diff --git a/paddler/src/agent/jsonrpc/request.rs b/paddler/src/agent/jsonrpc/request.rs index 72243fb8..88d95652 100644 --- a/paddler/src/agent/jsonrpc/request.rs +++ b/paddler/src/agent/jsonrpc/request.rs @@ -1,9 +1,9 @@ use serde::Deserialize; use serde::Serialize; -use paddler_types::request_params::ContinueFromConversationHistoryParams; use paddler_types::request_params::ContinueFromRawPromptParams; use paddler_types::request_params::GenerateEmbeddingBatchParams; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; #[derive(Deserialize, Serialize)] diff --git a/paddler/src/agent/llamacpp_arbiter_service.rs b/paddler/src/agent/llamacpp_arbiter_service.rs index b3cf27df..8d60fb56 100644 --- a/paddler/src/agent/llamacpp_arbiter_service.rs +++ b/paddler/src/agent/llamacpp_arbiter_service.rs @@ -2,15 +2,11 @@ use std::sync::Arc; use anyhow::Context as _; use anyhow::Result; -use anyhow::anyhow; use async_trait::async_trait; use log::error; use log::info; use log::warn; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::agent_issue_params::ModelPath; use paddler_types::agent_state_application_status::AgentStateApplicationStatus; -use tokio::fs; use tokio::sync::mpsc; use tokio::time::Duration; use tokio::time::MissedTickBehavior; @@ -20,6 +16,7 @@ use tokio_util::sync::CancellationToken; use crate::agent::continue_from_conversation_history_request::ContinueFromConversationHistoryRequest; use crate::agent::continue_from_raw_prompt_request::ContinueFromRawPromptRequest; use crate::agent::continuous_batch_arbiter::ContinuousBatchArbiter; +use crate::agent::continuous_batch_arbiter_build_outcome::ContinuousBatchArbiterBuildOutcome; use crate::agent::continuous_batch_arbiter_handle::ContinuousBatchArbiterHandle; use crate::agent::continuous_batch_scheduler_command::ContinuousBatchSchedulerCommand; use crate::agent::drain_in_flight_requests::drain_in_flight_requests; @@ -27,7 +24,6 @@ use crate::agent::generate_embedding_batch_request::GenerateEmbeddingBatchReques use crate::agent::model_metadata_holder::ModelMetadataHolder; use crate::agent_applicable_state::AgentApplicableState; use crate::agent_applicable_state_holder::AgentApplicableStateHolder; -use crate::agent_issue_fix::AgentIssueFix; use crate::service::Service; use crate::slot_aggregated_status_manager::SlotAggregatedStatusManager; @@ -47,127 +43,29 @@ pub struct LlamaCppArbiterService { impl LlamaCppArbiterService { async fn apply_state(&mut self, shutdown: &CancellationToken) -> Result<()> { - if self.continuous_batch_arbiter_handle.is_some() { - drain_in_flight_requests(&self.slot_aggregated_status_manager, shutdown).await?; - } - - if let Some(arbiter_handle) = self.continuous_batch_arbiter_handle.take() { - arbiter_handle - .shutdown() - .context("Unable to stop arbiter controller")?; - } + 
self.wait_for_in_flight_requests_to_finish(shutdown).await?; + self.tear_down_arbiter()?; - if let Some(AgentApplicableState { - chat_template_override, - inference_parameters, - multimodal_projection_path, - model_path, - }) = self.agent_applicable_state.clone() - { + if let Some(applicable_state) = self.agent_applicable_state.clone() { self.slot_aggregated_status_manager.reset(); - if let Some(model_path) = model_path { - if !fs::try_exists(&model_path).await? { - self.slot_aggregated_status_manager - .slot_aggregated_status - .register_issue(AgentIssue::ModelFileDoesNotExist(ModelPath { - model_path: model_path.display().to_string(), - })); - - return Err(anyhow!( - "Model path does not exist: {}", - model_path.display() - )); + match ContinuousBatchArbiter::build_from_applicable_state( + applicable_state, + self.agent_name.clone(), + self.desired_slots_total, + self.model_metadata_holder.clone(), + self.slot_aggregated_status_manager.clone(), + ) { + ContinuousBatchArbiterBuildOutcome::ReadyToSpawn(arbiter) => { + self.continuous_batch_arbiter_handle = Some(arbiter.spawn().await?); + info!("Reconciled state change applied successfully"); } - - let model_path_string = model_path.display().to_string(); - - if self - .slot_aggregated_status_manager - .slot_aggregated_status - .has_issue(&AgentIssue::UnableToFindChatTemplate(ModelPath { - model_path: model_path_string.clone(), - })) - { - self.slot_aggregated_status_manager - .slot_aggregated_status - .set_state_application_status( - AgentStateApplicationStatus::AttemptedAndNotAppliable, - ); - - return Err(anyhow!( - "Unable to establish chat template for model at path: {model_path_string}" - )); - } - - if self - .slot_aggregated_status_manager - .slot_aggregated_status - .has_issue_like(|issue| { - matches!(issue, AgentIssue::ChatTemplateDoesNotCompile(_)) - }) - { - self.slot_aggregated_status_manager - .slot_aggregated_status - .set_state_application_status( - AgentStateApplicationStatus::AttemptedAndNotAppliable, - ); - - return Err(anyhow!( - "Chat template does not compile for model at path: {model_path_string}" - )); - } - - if self - .slot_aggregated_status_manager - .slot_aggregated_status - .has_issue_like(|issue| { - matches!(issue, AgentIssue::MultimodalProjectionCannotBeLoaded(_)) - }) - { - self.slot_aggregated_status_manager - .slot_aggregated_status - .set_state_application_status( - AgentStateApplicationStatus::AttemptedAndNotAppliable, - ); - - return Err(anyhow!( - "Multimodal projection cannot be loaded: {}", - multimodal_projection_path.map_or_else( - || "*cannot establish path*".to_owned(), - |multimodal_projection_path| multimodal_projection_path - .display() - .to_string() - ) - )); + ContinuousBatchArbiterBuildOutcome::NoModelConfigured => { + warn!( + "No model configured in applicable state; skipping llama.cpp initialization" + ); } - - self.slot_aggregated_status_manager - .slot_aggregated_status - .register_fix(&AgentIssueFix::ModelFileExists(ModelPath { - model_path: model_path_string.clone(), - })); - - self.continuous_batch_arbiter_handle = Some( - ContinuousBatchArbiter { - agent_name: self.agent_name.clone(), - chat_template_override, - desired_slots_total: self.desired_slots_total, - inference_parameters, - multimodal_projection_path, - model_metadata_holder: self.model_metadata_holder.clone(), - model_path, - model_path_string, - slot_aggregated_status_manager: self.slot_aggregated_status_manager.clone(), - } - .spawn() - .await?, - ); - } else { - warn!("Model path is not set, skipping llama.cpp 
initialization"); } - - info!("Reconciled state change applied successfully"); } self.slot_aggregated_status_manager @@ -177,6 +75,27 @@ impl LlamaCppArbiterService { Ok(()) } + async fn wait_for_in_flight_requests_to_finish( + &self, + shutdown: &CancellationToken, + ) -> Result<()> { + if self.continuous_batch_arbiter_handle.is_some() { + drain_in_flight_requests(&self.slot_aggregated_status_manager, shutdown).await?; + } + + Ok(()) + } + + fn tear_down_arbiter(&mut self) -> Result<()> { + if let Some(arbiter_handle) = self.continuous_batch_arbiter_handle.take() { + arbiter_handle + .shutdown() + .context("Unable to stop arbiter controller")?; + } + + Ok(()) + } + fn forward_command(&self, command: ContinuousBatchSchedulerCommand) { if let Some(arbiter_handle) = &self.continuous_batch_arbiter_handle { if let Err(err) = arbiter_handle.command_tx.send(command) { diff --git a/paddler/src/agent/management_socket_client_service.rs b/paddler/src/agent/management_socket_client_service.rs index 49762e90..155f2f04 100644 --- a/paddler/src/agent/management_socket_client_service.rs +++ b/paddler/src/agent/management_socket_client_service.rs @@ -19,6 +19,7 @@ use tokio_tungstenite::connect_async; use tokio_tungstenite::tungstenite::protocol::Message; use tokio_util::sync::CancellationToken; +use paddler_types::agent_desired_state::AgentDesiredState; use paddler_types::jsonrpc::Error as JsonRpcError; use paddler_types::jsonrpc::ErrorEnvelope; use paddler_types::jsonrpc::RequestEnvelope; @@ -36,7 +37,6 @@ use crate::agent::jsonrpc::notification_params::VersionParams; use crate::agent::model_metadata_holder::ModelMetadataHolder; use crate::agent::receive_stream_stopper_collection::ReceiveStreamStopperCollection; use crate::agent_applicable_state_holder::AgentApplicableStateHolder; -use crate::agent_desired_state::AgentDesiredState; use crate::balancer::management_service::http_route::api::ws_agent_socket::jsonrpc::Message as ManagementJsonRpcMessage; use crate::balancer::management_service::http_route::api::ws_agent_socket::jsonrpc::Notification as ManagementJsonRpcNotification; use crate::balancer::management_service::http_route::api::ws_agent_socket::jsonrpc::notification_params::RegisterAgentParams; @@ -57,6 +57,7 @@ struct IncomingMessageContext { model_metadata_holder: Arc, receive_stream_stopper_collection: Arc, message_tx: mpsc::UnboundedSender, + slot_aggregated_status: Arc, } pub struct ManagementSocketClientService { @@ -81,6 +82,7 @@ impl ManagementSocketClientService { request_params: TRequest::RequestParams, receive_stream_stopper_collection: Arc, request_tx: mpsc::UnboundedSender, + slot_aggregated_status: Arc, ) -> Result<()> { let (response_tx, mut response_rx) = mpsc::unbounded_channel::(); let (stop_tx, stop_rx) = mpsc::unbounded_channel::<()>(); @@ -93,6 +95,7 @@ impl ManagementSocketClientService { request_params, response_tx, stop_rx, + slot_aggregated_status, ))?; loop { @@ -104,6 +107,7 @@ impl ManagementSocketClientService { message_tx.send( ManagementJsonRpcMessage::Response( ResponseEnvelope { + generated_by: None, request_id: id.clone(), response: response.into(), } @@ -130,6 +134,7 @@ impl ManagementSocketClientService { message_tx, model_metadata_holder, receive_stream_stopper_collection, + slot_aggregated_status, }: IncomingMessageContext, deserialized_message: JsonRpcMessage, ) -> Result<()> { @@ -185,6 +190,7 @@ impl ManagementSocketClientService { continue_from_conversation_history_params, receive_stream_stopper_collection, continue_from_conversation_history_request_tx, 
@@ -199,6 +205,7 @@ generate_tokens_params, receive_stream_stopper_collection, continue_from_raw_prompt_request_tx, + slot_aggregated_status, ) .await } @@ -213,6 +220,7 @@ generate_embedding_batch_params, receive_stream_stopper_collection, generate_embedding_batch_request_tx, + slot_aggregated_status, ) .await } @@ -221,6 +229,7 @@ request: JsonRpcRequest::GetChatTemplateOverride, }) => Ok( message_tx.send(ManagementJsonRpcMessage::Response(ResponseEnvelope { + generated_by: None, request_id: id, response: JsonRpcResponse::ChatTemplateOverride( if let Some(agent_applicable_state) = @@ -238,6 +247,7 @@ request: JsonRpcRequest::GetModelMetadata, }) => Ok( message_tx.send(ManagementJsonRpcMessage::Response(ResponseEnvelope { + generated_by: None, request_id: id, response: JsonRpcResponse::ModelMetadata( model_metadata_holder.get_model_metadata(), @@ -445,6 +455,7 @@ model_metadata_holder: self.model_metadata_holder.clone(), receive_stream_stopper_collection: self.receive_stream_stopper_collection.clone(), message_tx: message_tx.clone(), + slot_aggregated_status: self.slot_aggregated_status.clone(), }, msg, &pong_tx, diff --git a/paddler/src/agent/mod.rs b/paddler/src/agent/mod.rs index 0f5a7789..44bcaa90 100644 --- a/paddler/src/agent/mod.rs +++ b/paddler/src/agent/mod.rs @@ -2,6 +2,7 @@ pub mod continue_from_conversation_history_request; pub mod continue_from_raw_prompt_request; pub mod continuous_batch_active_request; pub mod continuous_batch_arbiter; +pub mod continuous_batch_arbiter_build_outcome; pub mod continuous_batch_arbiter_handle; pub mod continuous_batch_embedding_processor; pub mod continuous_batch_request_phase; @@ -28,3 +29,4 @@ pub mod resolved_grammar; pub mod sample_token_at_batch_index; pub mod sampling_outcome; pub mod sequence_id_pool; +pub mod slot_guard; diff --git a/paddler/src/agent/plan_embedding_batches.rs b/paddler/src/agent/plan_embedding_batches.rs index 0172c3be..d481cf07 100644 --- a/paddler/src/agent/plan_embedding_batches.rs +++ b/paddler/src/agent/plan_embedding_batches.rs @@ -3,7 +3,7 @@ use std::ops::Range; #[must_use] pub fn plan_embedding_batches( token_counts: &[usize], - batch_n_tokens: usize, + n_batch: usize, max_sequences_per_batch: i32, ) -> Vec<Range<usize>> { let mut batches = Vec::new(); @@ -12,7 +12,7 @@ let mut current_sequences: i32 = 0; for (index, &token_count) in token_counts.iter().enumerate() { - let would_exceed_tokens = current_tokens + token_count > batch_n_tokens; + let would_exceed_tokens = current_tokens + token_count > n_batch; let would_exceed_sequences = current_sequences >= max_sequences_per_batch; if (would_exceed_tokens || would_exceed_sequences) && current_sequences > 0 { diff --git a/paddler/src/agent/prepare_conversation_history_request.rs b/paddler/src/agent/prepare_conversation_history_request.rs index ac49edb1..87be81d9 100644 --- a/paddler/src/agent/prepare_conversation_history_request.rs +++ b/paddler/src/agent/prepare_conversation_history_request.rs @@ -6,7 +6,7 @@ use log::warn; use minijinja::context; use paddler_types::generated_token_result::GeneratedTokenResult; use paddler_types::media_marker::MediaMarker; -use paddler_types::request_params::ContinueFromConversationHistoryParams; +use 
paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; use tokio::sync::mpsc; @@ -23,6 +23,7 @@ pub fn prepare_conversation_history_request( grammar, conversation_history, max_tokens, + parse_tool_calls, tools, }: ContinueFromConversationHistoryParams, generated_tokens_tx: &mpsc::UnboundedSender<GeneratedTokenResult>, @@ -37,8 +38,7 @@ .iter() .map(|image_url| { DecodedImage::from_data_uri(image_url) - .and_then(|image| image.converted_to_png_if_necessary(image_resize_to_fit)) - .and_then(|image| image.resized_to_fit(image_resize_to_fit)) + .and_then(|image| image.prepared_for_inference(image_resize_to_fit)) }) .collect::<Result<Vec<_>, DecodedImageError>>() .map_err(|err| { @@ -129,6 +129,8 @@ images, max_tokens, grammar_sampler, + parse_tool_calls, + tools, }); } @@ -136,5 +138,7 @@ raw_prompt, max_tokens, grammar_sampler, + parse_tool_calls, + tools, }) } diff --git a/paddler/src/agent/prepared_conversation_history_request.rs b/paddler/src/agent/prepared_conversation_history_request.rs index 5485feb3..61092da5 100644 --- a/paddler/src/agent/prepared_conversation_history_request.rs +++ b/paddler/src/agent/prepared_conversation_history_request.rs @@ -1,3 +1,6 @@ +use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; + use crate::agent::grammar_sampler::GrammarSampler; use crate::decoded_image::DecodedImage; @@ -6,11 +9,15 @@ pub enum PreparedConversationHistoryRequest { raw_prompt: String, max_tokens: i32, grammar_sampler: Option<GrammarSampler>, + parse_tool_calls: bool, + tools: Vec<Tool<ValidatedParametersSchema>>, }, MultimodalPrompt { raw_prompt: String, images: Vec<DecodedImage>, max_tokens: i32, grammar_sampler: Option<GrammarSampler>, + parse_tool_calls: bool, + tools: Vec<Tool<ValidatedParametersSchema>>, }, } diff --git a/paddler/src/agent/reconciliation_service.rs b/paddler/src/agent/reconciliation_service.rs index 08980c6f..e4a14574 100644 --- a/paddler/src/agent/reconciliation_service.rs +++ b/paddler/src/agent/reconciliation_service.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use anyhow::Result; use async_trait::async_trait; use log::error; +use paddler_types::agent_desired_state::AgentDesiredState; use tokio::sync::mpsc; use tokio::time::Duration; use tokio::time::MissedTickBehavior; @@ -10,7 +11,6 @@ use tokio::time::interval; use tokio_util::sync::CancellationToken; use crate::agent_applicable_state_holder::AgentApplicableStateHolder; -use crate::agent_desired_state::AgentDesiredState; use crate::agent_issue_fix::AgentIssueFix; use crate::converts_to_applicable_state::ConvertsToApplicableState as _; use crate::service::Service; @@ -28,11 +28,11 @@ impl ReconciliationService { pub async fn convert_to_applicable_state(&mut self) -> Result<()> { let applicable_state = match &self.agent_desired_state { None => None, - Some(agent_desired_state) => { + Some(agent_desired_state) => Some( agent_desired_state .to_applicable_state(self.slot_aggregated_status.clone()) - .await? 
diff --git a/paddler/src/agent/slot_guard.rs b/paddler/src/agent/slot_guard.rs new file mode 100644 index 00000000..2503739a --- /dev/null +++ b/paddler/src/agent/slot_guard.rs @@ -0,0 +1,106 @@ +use std::sync::Arc; + +use crate::dispenses_slots::DispensesSlots as _; +use crate::slot_aggregated_status::SlotAggregatedStatus; + +pub struct SlotGuard { + slot_aggregated_status: Arc<SlotAggregatedStatus>, +} + +impl SlotGuard { + #[must_use] + pub fn new(slot_aggregated_status: Arc<SlotAggregatedStatus>) -> Self { + slot_aggregated_status.take_slot(); + + Self { + slot_aggregated_status, + } + } +} + +impl Drop for SlotGuard { + fn drop(&mut self) { + self.slot_aggregated_status.release_slot(); + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + use std::time::Duration; + + use anyhow::Result; + use tokio_util::sync::CancellationToken; + + use crate::agent::drain_in_flight_requests::drain_in_flight_requests; + use crate::agent::slot_guard::SlotGuard; + use crate::slot_aggregated_status_manager::SlotAggregatedStatusManager; + + #[tokio::test] + async fn increments_slot_on_construct_and_releases_on_drop() -> Result<()> { + let slot_aggregated_status_manager = Arc::new(SlotAggregatedStatusManager::new(4)); + + assert_eq!( + slot_aggregated_status_manager + .slot_aggregated_status + .slots_processing_count(), + 0 + ); + + { + let _guard = SlotGuard::new( + slot_aggregated_status_manager + .slot_aggregated_status + .clone(), + ); + + assert_eq!( + slot_aggregated_status_manager + .slot_aggregated_status + .slots_processing_count(), + 1 + ); + } + + assert_eq!( + slot_aggregated_status_manager + .slot_aggregated_status + .slots_processing_count(), + 0 + ); + + Ok(()) + } + + #[tokio::test] + async fn drain_in_flight_requests_blocks_until_guard_dropped() -> Result<()> { + let slot_aggregated_status_manager = Arc::new(SlotAggregatedStatusManager::new(4)); + let shutdown = CancellationToken::new(); + + let guard = SlotGuard::new( + slot_aggregated_status_manager + .slot_aggregated_status + .clone(), + ); + + let manager_for_drain = slot_aggregated_status_manager.clone(); + let shutdown_for_drain = shutdown.clone(); + let mut drain_task = tokio::spawn(async move { + drain_in_flight_requests(&manager_for_drain, &shutdown_for_drain).await + }); + + let blocking_window = Duration::from_millis(50); + let timeout_result = tokio::time::timeout(blocking_window, &mut drain_task).await; + assert!( + timeout_result.is_err(), + "drain_in_flight_requests returned while a SlotGuard was still held" + ); + + drop(guard); + + let unblock_window = Duration::from_millis(500); + tokio::time::timeout(unblock_window, drain_task).await???; + + Ok(()) + } +} diff --git a/paddler/src/agent_desired_model.rs b/paddler/src/agent_desired_model.rs deleted file mode 100644 index 5455eeed..00000000 --- a/paddler/src/agent_desired_model.rs +++ /dev/null @@ -1,147 +0,0 @@ -use std::path::PathBuf; -use std::sync::Arc; - -use anyhow::Result; -use anyhow::anyhow; -use async_trait::async_trait; -use hf_hub::Cache; -use hf_hub::Repo; -use hf_hub::RepoType; -use hf_hub::api::tokio::ApiBuilder; -use hf_hub::api::tokio::ApiError; -use log::warn; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::agent_issue_params::HuggingFaceDownloadLock; -use paddler_types::agent_issue_params::ModelPath; -use paddler_types::huggingface_model_reference::HuggingFaceModelReference; -use tokio::time::Duration; -use tokio::time::sleep; - -use
crate::agent_issue_fix::AgentIssueFix; -use crate::converts_to_applicable_state::ConvertsToApplicableState; -use crate::slot_aggregated_status::SlotAggregatedStatus; -use crate::slot_aggregated_status_download_progress::SlotAggregatedStatusDownloadProgress; - -const LOCK_RETRY_TIMEOUT: Duration = Duration::from_secs(10); - -#[async_trait] -impl ConvertsToApplicableState for AgentDesiredModel { - type ApplicableState = PathBuf; - type Context = Arc; - - async fn to_applicable_state( - &self, - slot_aggregated_status: Self::Context, - ) -> Result> { - Ok(match self { - Self::HuggingFace(HuggingFaceModelReference { - filename, - repo_id, - revision, - }) => { - let model_path = format!("{repo_id}/{revision}/{filename}"); - - if slot_aggregated_status.has_issue(&AgentIssue::HuggingFaceModelDoesNotExist( - ModelPath { - model_path: model_path.clone(), - }, - )) { - return Err(anyhow!( - "Model '{model_path}' does not exist on Hugging Face. Not attempting to download it again." - )); - } - - let hf_cache = Cache::from_env(); - let hf_api = ApiBuilder::from_cache(hf_cache.clone()).build()?; - let hf_repo = hf_api.repo(Repo::with_revision( - repo_id.to_owned(), - RepoType::Model, - revision.to_owned(), - )); - - if let Some(cached_path) = hf_cache - .repo(Repo::new(repo_id.to_owned(), RepoType::Model)) - .get(filename) - { - slot_aggregated_status.reset_download(); - - return Ok(Some(cached_path)); - } - - let weights_filename = match hf_repo - .download_with_progress( - filename, - SlotAggregatedStatusDownloadProgress::new(slot_aggregated_status.clone()), - ) - .await - { - Ok(resolved_filename) => { - slot_aggregated_status.register_fix( - &AgentIssueFix::HuggingFaceDownloadedModel(ModelPath { model_path }), - ); - - resolved_filename - } - Err(ApiError::LockAcquisition(lock_path)) => { - slot_aggregated_status.register_issue( - AgentIssue::HuggingFaceCannotAcquireLock(HuggingFaceDownloadLock { - lock_path: lock_path.display().to_string(), - model_path: ModelPath { model_path }, - }), - ); - - warn!( - "Waiting to acquire download lock for '{}'. Sleeping for {} secs", - lock_path.display(), - LOCK_RETRY_TIMEOUT.as_secs() - ); - - sleep(LOCK_RETRY_TIMEOUT).await; - - return Err(anyhow!( - "Failed to acquire download lock '{}'. Is more than one agent running on this machine?", - lock_path.display() - )); - } - Err(ApiError::RequestError(reqwest_error)) => match reqwest_error.status() { - Some(reqwest::StatusCode::NOT_FOUND) => { - slot_aggregated_status.register_issue( - AgentIssue::HuggingFaceModelDoesNotExist(ModelPath { - model_path: model_path.clone(), - }), - ); - - return Err(anyhow!( - "Model '{model_path}' does not exist on Hugging Face." - )); - } - Some( - reqwest::StatusCode::FORBIDDEN | reqwest::StatusCode::UNAUTHORIZED, - ) => { - slot_aggregated_status.register_issue( - AgentIssue::HuggingFacePermissions(ModelPath { - model_path: model_path.clone(), - }), - ); - - return Err(anyhow!( - "You do not have enough permissions to download '{model_path}' from Hugging Face." 
- )); - } - _ => { - return Err(anyhow!( - "Failed to download model from Hugging Face: {reqwest_error}" - )); - } - }, - Err(err_other) => return Err(err_other.into()), - }; - - Some(weights_filename) - } - Self::LocalToAgent(path) => Some(PathBuf::from(path)), - Self::None => None, - }) - } -} diff --git a/paddler/src/agent_desired_state.rs b/paddler/src/agent_desired_state.rs index a9ef0bc2..5313dd59 100644 --- a/paddler/src/agent_desired_state.rs +++ b/paddler/src/agent_desired_state.rs @@ -1,13 +1,43 @@ +use std::path::PathBuf; use std::sync::Arc; use anyhow::Result; +use anyhow::anyhow; use async_trait::async_trait; -pub use paddler_types::agent_desired_state::AgentDesiredState; +use paddler_types::agent_desired_model::AgentDesiredModel; +use paddler_types::agent_desired_state::AgentDesiredState; +use paddler_types::agent_issue::AgentIssue; +use paddler_types::agent_issue_params::ModelPath; use crate::agent_applicable_state::AgentApplicableState; use crate::converts_to_applicable_state::ConvertsToApplicableState; +use crate::desired_model_resolution::DesiredModelResolution; +use crate::resolve_desired_model::resolve_desired_model; use crate::slot_aggregated_status::SlotAggregatedStatus; +async fn resolve_into_optional_path( + desired: &AgentDesiredModel, + slot_aggregated_status: &Arc, + on_local_missing: TLocalMissingIssue, +) -> Result> +where + TLocalMissingIssue: FnOnce(ModelPath) -> AgentIssue, +{ + match resolve_desired_model(desired, slot_aggregated_status.clone()).await? { + DesiredModelResolution::NotConfigured => Ok(None), + DesiredModelResolution::Resolved(path) => Ok(Some(path)), + DesiredModelResolution::LocalFileMissing(path) => { + let model_path_string = path.display().to_string(); + + slot_aggregated_status.register_issue(on_local_missing(ModelPath { + model_path: model_path_string.clone(), + })); + + Err(anyhow!("Local file does not exist: {model_path_string}")) + } + } +} + #[async_trait] impl ConvertsToApplicableState for AgentDesiredState { type ApplicableState = AgentApplicableState; @@ -16,21 +46,129 @@ impl ConvertsToApplicableState for AgentDesiredState { async fn to_applicable_state( &self, slot_aggregated_status: Self::Context, - ) -> Result> { - let model_path = self - .model - .to_applicable_state(slot_aggregated_status.clone()) - .await?; - let multimodal_projection_path = self - .multimodal_projection - .to_applicable_state(slot_aggregated_status) - .await?; - - Ok(Some(AgentApplicableState { + ) -> Result { + let model_path = resolve_into_optional_path( + &self.model, + &slot_aggregated_status, + AgentIssue::ModelFileDoesNotExist, + ) + .await?; + + let multimodal_projection_path = resolve_into_optional_path( + &self.multimodal_projection, + &slot_aggregated_status, + AgentIssue::MultimodalProjectionCannotBeLoaded, + ) + .await?; + + Ok(AgentApplicableState { chat_template_override: self.chat_template_override.clone(), inference_parameters: self.inference_parameters.clone(), model_path, multimodal_projection_path, - })) + }) + } +} + +#[cfg(test)] +mod tests { + use std::path::PathBuf; + use std::sync::Arc; + + use anyhow::Result; + use paddler_types::agent_desired_model::AgentDesiredModel; + use paddler_types::agent_desired_state::AgentDesiredState; + use paddler_types::agent_issue::AgentIssue; + use paddler_types::agent_issue_params::ModelPath; + use paddler_types::inference_parameters::InferenceParameters; + use tempfile::TempDir; + + use crate::converts_to_applicable_state::ConvertsToApplicableState; + use 
crate::slot_aggregated_status::SlotAggregatedStatus; + + fn fresh_status() -> Arc { + Arc::new(SlotAggregatedStatus::new(1)) + } + + fn nonexistent_path_in_temp_dir(label: &str) -> Result<(TempDir, PathBuf)> { + let dir = tempfile::tempdir()?; + let path = dir.path().join(format!("missing-{label}.gguf")); + + Ok((dir, path)) + } + + fn desired_state( + model: AgentDesiredModel, + multimodal_projection: AgentDesiredModel, + ) -> AgentDesiredState { + AgentDesiredState { + chat_template_override: None, + inference_parameters: InferenceParameters::default(), + model, + multimodal_projection, + } + } + + #[tokio::test] + async fn local_missing_model_registers_model_file_does_not_exist_and_errs() -> Result<()> { + let status = fresh_status(); + let (_dir_guard, missing_path) = nonexistent_path_in_temp_dir("model")?; + let desired = desired_state( + AgentDesiredModel::LocalToAgent(missing_path.display().to_string()), + AgentDesiredModel::None, + ); + + let outcome = desired.to_applicable_state(status.clone()).await; + + assert!( + outcome.is_err(), + "AgentDesiredState::to_applicable_state must Err when the model's local path is missing" + ); + assert!( + status.has_issue(&AgentIssue::ModelFileDoesNotExist(ModelPath { + model_path: missing_path.display().to_string(), + })), + "ModelFileDoesNotExist must be registered for a missing local model file" + ); + assert!( + !status.has_issue(&AgentIssue::MultimodalProjectionCannotBeLoaded(ModelPath { + model_path: missing_path.display().to_string(), + })), + "MultimodalProjectionCannotBeLoaded must NOT be registered for a missing model" + ); + + Ok(()) + } + + #[tokio::test] + async fn local_missing_multimodal_projection_registers_multimodal_projection_cannot_be_loaded_and_errs() + -> Result<()> { + let status = fresh_status(); + let (_dir_guard, missing_path) = nonexistent_path_in_temp_dir("projection")?; + let desired = desired_state( + AgentDesiredModel::None, + AgentDesiredModel::LocalToAgent(missing_path.display().to_string()), + ); + + let outcome = desired.to_applicable_state(status.clone()).await; + + assert!( + outcome.is_err(), + "AgentDesiredState::to_applicable_state must Err when the projection's local path is missing" + ); + assert!( + status.has_issue(&AgentIssue::MultimodalProjectionCannotBeLoaded(ModelPath { + model_path: missing_path.display().to_string(), + })), + "MultimodalProjectionCannotBeLoaded must be registered for a missing local projection file" + ); + assert!( + !status.has_issue(&AgentIssue::ModelFileDoesNotExist(ModelPath { + model_path: missing_path.display().to_string(), + })), + "ModelFileDoesNotExist must NOT be registered for a missing projection" + ); + + Ok(()) } } diff --git a/paddler/src/atomic_value.rs b/paddler/src/atomic_value.rs index 4b2a9ffa..9f0202d7 100644 --- a/paddler/src/atomic_value.rs +++ b/paddler/src/atomic_value.rs @@ -77,20 +77,6 @@ impl AtomicValue { true } } - - pub fn try_increment_below(&self, limit: i32) -> bool { - loop { - let current = self.get(); - - if current >= limit { - return false; - } - - if self.compare_and_swap(current, current + 1) { - return true; - } - } - } } impl AtomicValue { @@ -123,33 +109,3 @@ impl AtomicValue { } } } - -#[cfg(test)] -mod tests { - use super::AtomicValue; - use std::sync::atomic::AtomicI32; - - #[test] - fn try_increment_below_increments_when_below_limit() { - let value = AtomicValue::::new(0); - - assert!(value.try_increment_below(1)); - assert_eq!(value.get(), 1); - } - - #[test] - fn try_increment_below_refuses_at_limit() { - let value = 
AtomicValue::::new(1); - - assert!(!value.try_increment_below(1)); - assert_eq!(value.get(), 1); - } - - #[test] - fn try_increment_below_refuses_above_limit() { - let value = AtomicValue::::new(5); - - assert!(!value.try_increment_below(3)); - assert_eq!(value.get(), 5); - } -} diff --git a/paddler/src/balancer/agent_controller.rs b/paddler/src/balancer/agent_controller.rs index 11c7cb3b..6b852af8 100644 --- a/paddler/src/balancer/agent_controller.rs +++ b/paddler/src/balancer/agent_controller.rs @@ -13,11 +13,12 @@ use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; use paddler_types::agent_controller_snapshot::AgentControllerSnapshot; +use paddler_types::agent_desired_state::AgentDesiredState; use paddler_types::agent_issue::AgentIssue; use paddler_types::jsonrpc::RequestEnvelope; -use paddler_types::request_params::ContinueFromConversationHistoryParams; use paddler_types::request_params::ContinueFromRawPromptParams; use paddler_types::request_params::GenerateEmbeddingBatchParams; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; use paddler_types::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot; @@ -25,7 +26,6 @@ use crate::agent::jsonrpc::Message as AgentJsonRpcMessage; use crate::agent::jsonrpc::Notification as AgentJsonRpcNotification; use crate::agent::jsonrpc::Request as AgentJsonRpcRequest; use crate::agent::jsonrpc::notification_params::SetStateParams; -use crate::agent_desired_state::AgentDesiredState; use crate::atomic_value::AtomicValue; use crate::balancer::agent_controller_update_result::AgentControllerUpdateResult; use crate::balancer::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; @@ -165,18 +165,16 @@ impl AgentController { let mut changed = false; - changed = changed || self.desired_slots_total.set_check(desired_slots_total); - changed = changed || self.download_current.set_check(download_current); - changed = changed || self.download_total.set_check(download_total); - changed = changed || self.slots_total.set_check(slots_total); - changed = changed - || self - .state_application_status_code - .set_check(state_application_status as i32); - changed = changed - || self - .uses_chat_template_override - .set_check(uses_chat_template_override); + changed |= self.desired_slots_total.set_check(desired_slots_total); + changed |= self.download_current.set_check(download_current); + changed |= self.download_total.set_check(download_total); + changed |= self.slots_total.set_check(slots_total); + changed |= self + .state_application_status_code + .set_check(state_application_status as i32); + changed |= self + .uses_chat_template_override + .set_check(uses_chat_template_override); self.newest_update_version .compare_and_swap(newest_update_version, version); @@ -348,3 +346,95 @@ impl SetsDesiredState for AgentController { .await } } + +#[cfg(test)] +mod tests { + use paddler_types::agent_state_application_status::AgentStateApplicationStatus; + + use super::*; + + fn fresh_agent_controller() -> AgentController { + let (agent_message_tx, _agent_message_rx) = mpsc::unbounded_channel(); + + AgentController { + agent_message_tx, + chat_template_override_sender_collection: Arc::new( + ChatTemplateOverrideSenderCollection::default(), + ), + connection_close: 
CancellationToken::new(), + desired_slots_total: AtomicValue::<AtomicI32>::new(0), + download_current: AtomicValue::<AtomicI32>::new(0), + download_filename: RwLock::new(None), + download_total: AtomicValue::<AtomicI32>::new(0), + embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), + generate_tokens_sender_collection: Arc::new(GenerateTokensSenderCollection::default()), + id: "agent-test".to_owned(), + issues: RwLock::new(BTreeSet::new()), + model_metadata_sender_collection: Arc::new(ModelMetadataSenderCollection::default()), + model_path: RwLock::new(None), + name: None, + newest_update_version: AtomicValue::<AtomicUsize>::new(0), + slots_processing: AtomicValue::<AtomicI32>::new(0), + slots_total: AtomicValue::<AtomicI32>::new(0), + state_application_status_code: AtomicValue::<AtomicI32>::new( + AgentStateApplicationStatus::Fresh as i32, + ), + uses_chat_template_override: AtomicValue::<AtomicBool>::new(false), + } + } + + #[test] + fn multi_field_update_stores_all_changed_atomic_fields() -> Result<()> { + let agent_controller = fresh_agent_controller(); + + let snapshot = SlotAggregatedStatusSnapshot { + desired_slots_total: 4, + download_current: 10, + download_filename: None, + download_total: 100, + issues: BTreeSet::new(), + model_path: None, + slots_processing: 0, + slots_total: 4, + state_application_status: AgentStateApplicationStatus::Fresh, + uses_chat_template_override: true, + version: 1, + }; + + let result = agent_controller.update_from_slot_aggregated_status_snapshot(snapshot); + + if !matches!(result, AgentControllerUpdateResult::Updated) { + anyhow::bail!("update with multiple changed fields must return Updated"); + } + + if agent_controller.desired_slots_total.get() != 4 { + anyhow::bail!( + "desired_slots_total must be stored: expected 4, got {}", + agent_controller.desired_slots_total.get() + ); + } + if agent_controller.download_current.get() != 10 { + anyhow::bail!( + "download_current must be stored: expected 10, got {}", + agent_controller.download_current.get() + ); + } + if agent_controller.download_total.get() != 100 { + anyhow::bail!( + "download_total must be stored: expected 100, got {}", + agent_controller.download_total.get() + ); + } + if agent_controller.slots_total.get() != 4 { + anyhow::bail!( + "slots_total must be stored: expected 4, got {}", + agent_controller.slots_total.get() + ); + } + if !agent_controller.uses_chat_template_override.get() { + anyhow::bail!("uses_chat_template_override must be stored: expected true, got false"); + } + + Ok(()) + } +}
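The pool refactor below splits dispatch into a read-only scan (select_least_busy_with_capacity) and an optimistic claim (try_claim) that succeeds only if slots_processing still holds the snapshotted value, retrying the scan otherwise. A standalone sketch of that snapshot-then-CAS loop, with a bare AtomicI32 in place of the project's AtomicValue wrapper:

use std::sync::atomic::{AtomicI32, Ordering};

struct Slot {
    processing: AtomicI32,
    total: i32,
}

fn claim(slot: &Slot) -> bool {
    loop {
        // Scan phase: observe the current load without taking anything.
        let snapshot = slot.processing.load(Ordering::SeqCst);
        if snapshot >= slot.total {
            return false; // no capacity left
        }
        // Claim phase: succeed only if nobody raced us since the scan.
        if slot
            .processing
            .compare_exchange(snapshot, snapshot + 1, Ordering::SeqCst, Ordering::SeqCst)
            .is_ok()
        {
            return true;
        }
        // Lost the race: rescan and retry.
    }
}

fn main() {
    let slot = Slot { processing: AtomicI32::new(0), total: 1 };
    assert!(claim(&slot));
    assert!(!claim(&slot)); // capacity exhausted
}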
diff --git a/paddler/src/balancer/agent_controller_pool.rs b/paddler/src/balancer/agent_controller_pool.rs index 29ba23a0..8c4188ba 100644 --- a/paddler/src/balancer/agent_controller_pool.rs +++ b/paddler/src/balancer/agent_controller_pool.rs @@ -5,12 +5,13 @@ use async_trait::async_trait; use dashmap::DashMap; use paddler_types::agent_controller_pool_snapshot::AgentControllerPoolSnapshot; use paddler_types::agent_controller_snapshot::AgentControllerSnapshot; +use paddler_types::agent_desired_state::AgentDesiredState; use tokio::sync::watch; use super::agent_controller::AgentController; use super::agent_controller_pool_total_slots::AgentControllerPoolTotalSlots; -use crate::agent_desired_state::AgentDesiredState; use crate::balancer::agent_controller_slot_guard::AgentControllerSlotGuard; +use crate::balancer::dispatch_candidate::DispatchCandidate; use crate::balancer::dispatched_agent::DispatchedAgent; use crate::produces_snapshot::ProducesSnapshot; use crate::sets_desired_state::SetsDesiredState; @@ -23,31 +24,60 @@ pub struct AgentControllerPool { impl AgentControllerPool { #[must_use] - pub fn take_least_busy_agent_controller(&self) -> Option<DispatchedAgent> { - let mut candidates: Vec<Arc<AgentController>> = self - .agents - .iter() - .map(|entry| entry.value().clone()) - .collect(); + pub fn select_least_busy_with_capacity(&self) -> Option<DispatchCandidate> { + let mut best: Option<DispatchCandidate> = None; - candidates.sort_by_key(|agent| agent.slots_processing.get()); + for entry in &self.agents { + let agent_controller = entry.value().clone(); + let snapshot = agent_controller.slots_processing.get(); - for agent_controller in candidates { - let limit = agent_controller.slots_total.get(); + if snapshot >= agent_controller.slots_total.get() { + continue; + } - if agent_controller.slots_processing.try_increment_below(limit) { - self.update_tx.send_replace(()); + best = Some(match best { + Some(current) if current.snapshot <= snapshot => current, + _ => DispatchCandidate { + agent_controller, + snapshot, + }, + }); + } - let slot_guard = AgentControllerSlotGuard::new( - agent_controller.clone(), - self.update_tx.clone(), - ); + best + } - return Some(DispatchedAgent::new(agent_controller, slot_guard)); - } + pub fn try_claim( + &self, + candidate: DispatchCandidate, + ) -> Result<DispatchedAgent, DispatchCandidate> { + if candidate + .agent_controller + .slots_processing + .compare_and_swap(candidate.snapshot, candidate.snapshot + 1) + { + self.update_tx.send_replace(()); + + let slot_guard = AgentControllerSlotGuard::new( + candidate.agent_controller.clone(), + self.update_tx.clone(), + ); + + Ok(DispatchedAgent::new(candidate.agent_controller, slot_guard)) + } else { + Err(candidate) } + } + + #[must_use] + pub fn take_least_busy_agent_controller(&self) -> Option<DispatchedAgent> { + loop { + let candidate = self.select_least_busy_with_capacity()?; - None + if let Ok(dispatched) = self.try_claim(candidate) { + return Some(dispatched); + } + } } #[must_use] diff --git a/paddler/src/balancer/buffered_request_manager.rs b/paddler/src/balancer/buffered_request_manager.rs index c02d935f..b6429916 100644 --- a/paddler/src/balancer/buffered_request_manager.rs +++ b/paddler/src/balancer/buffered_request_manager.rs @@ -96,7 +96,25 @@ impl SubscribesToUpdates for BufferedRequestManager { #[cfg(test)] mod tests { + use std::collections::BTreeSet; + use std::sync::RwLock; + use std::sync::atomic::AtomicBool; + use std::sync::atomic::AtomicI32; + use std::sync::atomic::AtomicUsize; + use std::task::Poll; + + use paddler_types::agent_state_application_status::AgentStateApplicationStatus; + use tokio::sync::mpsc; + use tokio_util::sync::CancellationToken; + use super::*; + use crate::atomic_value::AtomicValue; + use crate::balancer::agent_controller::AgentController; + use crate::balancer::buffered_request_agent_wait_result::BufferedRequestAgentWaitResult; + use crate::balancer::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; + use crate::balancer::embedding_sender_collection::EmbeddingSenderCollection; + use crate::balancer::generate_tokens_sender_collection::GenerateTokensSenderCollection; + use crate::balancer::model_metadata_sender_collection::ModelMetadataSenderCollection; #[tokio::test] async fn counter_increment_wakes_subscribed_waiter() -> Result<()> { @@ -118,4 +136,121 @@ mod tests { Ok(()) } + + #[tokio::test(flavor = "current_thread")] + async fn waiter_returns_found_after_agent_registration_with_no_initial_agents() -> Result<()> { + let pool = Arc::new(AgentControllerPool::default()); + let manager = Arc::new(BufferedRequestManager::new( + pool.clone(), + Duration::from_secs(60), + 10, + )); + + let mut waiter =
tokio_test::task::spawn(async move { manager.wait_for_available_agent().await }); + + assert!( + waiter.poll().is_pending(), + "waiter must be Pending while pool has no agents" + ); + + let (agent_message_tx, _agent_message_rx) = mpsc::unbounded_channel(); + let agent = Arc::new(AgentController { + agent_message_tx, + chat_template_override_sender_collection: Arc::new( + ChatTemplateOverrideSenderCollection::default(), + ), + connection_close: CancellationToken::new(), + desired_slots_total: AtomicValue::::new(1), + download_current: AtomicValue::::new(0), + download_filename: RwLock::new(None), + download_total: AtomicValue::::new(0), + embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), + generate_tokens_sender_collection: Arc::new(GenerateTokensSenderCollection::default()), + id: "agent-1".to_owned(), + issues: RwLock::new(BTreeSet::new()), + model_metadata_sender_collection: Arc::new(ModelMetadataSenderCollection::default()), + model_path: RwLock::new(None), + name: None, + newest_update_version: AtomicValue::::new(0), + slots_processing: AtomicValue::::new(0), + slots_total: AtomicValue::::new(1), + state_application_status_code: AtomicValue::::new( + AgentStateApplicationStatus::Fresh as i32, + ), + uses_chat_template_override: AtomicValue::::new(false), + }); + + pool.register_agent_controller("agent-1".to_owned(), agent)?; + + assert!( + waiter.is_woken(), + "register_agent_controller must wake the subscribed waiter" + ); + + let Poll::Ready(result) = waiter.poll() else { + anyhow::bail!("waiter must be Ready after register_agent_controller, got Pending"); + }; + + if !matches!(result?, BufferedRequestAgentWaitResult::Found(_)) { + anyhow::bail!("waiter must return Found after register_agent_controller"); + } + + Ok(()) + } + + #[tokio::test(flavor = "current_thread")] + async fn waiter_returns_found_when_agent_was_registered_before_call() -> Result<()> { + let pool = Arc::new(AgentControllerPool::default()); + + let (agent_message_tx, _agent_message_rx) = mpsc::unbounded_channel(); + let agent = Arc::new(AgentController { + agent_message_tx, + chat_template_override_sender_collection: Arc::new( + ChatTemplateOverrideSenderCollection::default(), + ), + connection_close: CancellationToken::new(), + desired_slots_total: AtomicValue::::new(1), + download_current: AtomicValue::::new(0), + download_filename: RwLock::new(None), + download_total: AtomicValue::::new(0), + embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), + generate_tokens_sender_collection: Arc::new(GenerateTokensSenderCollection::default()), + id: "agent-pre".to_owned(), + issues: RwLock::new(BTreeSet::new()), + model_metadata_sender_collection: Arc::new(ModelMetadataSenderCollection::default()), + model_path: RwLock::new(None), + name: None, + newest_update_version: AtomicValue::::new(0), + slots_processing: AtomicValue::::new(0), + slots_total: AtomicValue::::new(1), + state_application_status_code: AtomicValue::::new( + AgentStateApplicationStatus::Fresh as i32, + ), + uses_chat_template_override: AtomicValue::::new(false), + }); + + pool.register_agent_controller("agent-pre".to_owned(), agent)?; + + let manager = Arc::new(BufferedRequestManager::new( + pool, + Duration::from_secs(60), + 10, + )); + + let mut waiter = + tokio_test::task::spawn(async move { manager.wait_for_available_agent().await }); + + let Poll::Ready(result) = waiter.poll() else { + anyhow::bail!( + "waiter must be Ready on first poll when agent was registered before call" + ); + }; + + if 
!matches!(result?, BufferedRequestAgentWaitResult::Found(_)) { + anyhow::bail!("waiter must return Found when an agent is already in the pool"); + } + + Ok(()) + } } diff --git a/paddler/src/balancer/chunk_forwarding_session_controller/identity_transformer.rs b/paddler/src/balancer/chunk_forwarding_session_controller/identity_transformer.rs index f69872c3..f9d964af 100644 --- a/paddler/src/balancer/chunk_forwarding_session_controller/identity_transformer.rs +++ b/paddler/src/balancer/chunk_forwarding_session_controller/identity_transformer.rs @@ -5,10 +5,11 @@ use paddler_types::inference_client::Message as OutgoingMessage; use super::transform_result::TransformResult; use super::transforms_outgoing_message::TransformsOutgoingMessage; -#[derive(Clone)] +#[derive(Clone, Default)] pub struct IdentityTransformer; impl IdentityTransformer { + #[must_use] pub const fn new() -> Self { Self {} } @@ -16,9 +17,9 @@ impl IdentityTransformer { - async fn transform(&self, message: OutgoingMessage) -> Result<TransformResult> { + async fn transform(&self, message: OutgoingMessage) -> Result<Vec<TransformResult>> { let serialized = serde_json::to_string(&message)?; - Ok(TransformResult::Chunk(serialized)) + Ok(vec![TransformResult::Chunk(serialized)]) } }
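The mod.rs change below consumes the new one-to-many transform contract: a transformer may swallow a message entirely (empty Vec), forward one chunk, or fan out several, e.g. a finish chunk followed by a usage chunk. A toy illustration of the shape of that contract (names here are illustrative, not the project's API):

// Toy fan-out transform: zero, one, or many output chunks per input.
fn transform(message: &str) -> Vec<String> {
    match message {
        // Buffered internally: nothing is emitted for this message yet.
        "tool-call-fragment" => vec![],
        // A terminal message can fan out into several chunks.
        "done" => vec!["finish".to_owned(), "usage".to_owned()],
        other => vec![other.to_owned()],
    }
}

fn main() {
    assert!(transform("tool-call-fragment").is_empty());
    assert_eq!(transform("done").len(), 2);
    assert_eq!(transform("hello"), vec!["hello".to_owned()]);
}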
diff --git a/paddler/src/balancer/chunk_forwarding_session_controller/mod.rs b/paddler/src/balancer/chunk_forwarding_session_controller/mod.rs index 02dfd12f..3b147333 100644 --- a/paddler/src/balancer/chunk_forwarding_session_controller/mod.rs +++ b/paddler/src/balancer/chunk_forwarding_session_controller/mod.rs @@ -41,12 +41,15 @@ where TTransformsOutgoingMessage: Clone + TransformsOutgoingMessage + Send + Sync, { async fn send_response(&mut self, message: OutgoingMessage) -> anyhow::Result<()> { - match self.transformer.transform(message).await? { - TransformResult::Discard => Ok(()), - forwarded @ (TransformResult::Chunk(_) | TransformResult::Error(_)) => { - self.chunk_tx.send(forwarded)?; - Ok(()) + for transform_result in self.transformer.transform(message).await? { + match transform_result { + TransformResult::Discard => {} + forwarded @ (TransformResult::Chunk(_) | TransformResult::Error(_)) => { + self.chunk_tx.send(forwarded)?; + } } } + + Ok(()) } } diff --git a/paddler/src/balancer/chunk_forwarding_session_controller/transforms_outgoing_message.rs b/paddler/src/balancer/chunk_forwarding_session_controller/transforms_outgoing_message.rs index 2d3e27b4..ccfc438e 100644 --- a/paddler/src/balancer/chunk_forwarding_session_controller/transforms_outgoing_message.rs +++ b/paddler/src/balancer/chunk_forwarding_session_controller/transforms_outgoing_message.rs @@ -6,5 +6,5 @@ use super::transform_result::TransformResult; #[async_trait] pub trait TransformsOutgoingMessage { - async fn transform(&self, message: OutgoingMessage) -> Result<TransformResult>; + async fn transform(&self, message: OutgoingMessage) -> Result<Vec<TransformResult>>; } diff --git a/paddler/src/balancer/compatibility/openai_service/http_route/post_chat_completions.rs b/paddler/src/balancer/compatibility/openai_service/http_route/post_chat_completions.rs index 384085a9..1abdca71 100644 --- a/paddler/src/balancer/compatibility/openai_service/http_route/post_chat_completions.rs +++ b/paddler/src/balancer/compatibility/openai_service/http_route/post_chat_completions.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; +use std::sync::Mutex; use std::time::SystemTime; use std::time::UNIX_EPOCH; @@ -5,17 +7,29 @@ use actix_web::Error; use actix_web::HttpResponse; use actix_web::post; use actix_web::web; +use anyhow::Context as _; +use anyhow::Result; +use anyhow::anyhow; use async_trait::async_trait; use nanoid::nanoid; use paddler_types::conversation_history::ConversationHistory; use paddler_types::conversation_message::ConversationMessage; use paddler_types::conversation_message_content::ConversationMessageContent; use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::generation_summary::GenerationSummary; use paddler_types::inference_client::Message as OutgoingMessage; use paddler_types::inference_client::Response as OutgoingResponse; use paddler_types::jsonrpc::ErrorEnvelope; use paddler_types::jsonrpc::ResponseEnvelope; -use paddler_types::request_params::ContinueFromConversationHistoryParams; +use llama_cpp_bindings::ParsedToolCall; +use llama_cpp_bindings::TokenUsage; +use llama_cpp_bindings::ToolCallArguments; +use paddler_types::oversized_image_details::OversizedImageDetails; +use paddler_types::raw_tool_call_tokens::RawToolCallTokens; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::raw_parameters_schema::RawParametersSchema; +use paddler_types::validates::Validates; use serde::Deserialize; use serde_json::json; use tokio_stream::StreamExt as _; @@ -41,6 +55,22 @@ fn openai_error_json(error_type: &str, message: &str) -> serde_json::Value { }) } +fn openai_usage_json(usage: &TokenUsage) -> serde_json::Value { + json!({ "prompt_tokens": usage.prompt_tokens, "completion_tokens": usage.completion_tokens(), "total_tokens": usage.total_tokens(), "prompt_tokens_details": { "cached_tokens": usage.cached_prompt_tokens, "audio_tokens": usage.input_audio_tokens, "image_tokens": usage.input_image_tokens, }, "completion_tokens_details": { "reasoning_tokens": usage.reasoning_tokens, } }) +}
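openai_usage_json above leans on TokenUsage's derived counts. Judging from the tests further down (summary_with_counts(7, 4, 1) must report completion_tokens 5 and total_tokens 12), the accounting appears to be completion = content + reasoning and total = prompt + completion. A sketch with a stand-in struct (the real llama_cpp_bindings::TokenUsage may differ):

// Stand-in for TokenUsage, inferred from the test expectations below.
struct TokenUsageSketch {
    prompt_tokens: u64,
    content_tokens: u64,
    reasoning_tokens: u64,
}

impl TokenUsageSketch {
    // Everything generated: visible content plus hidden reasoning.
    fn completion_tokens(&self) -> u64 {
        self.content_tokens + self.reasoning_tokens
    }

    // The whole request: prompt plus completion.
    fn total_tokens(&self) -> u64 {
        self.prompt_tokens + self.completion_tokens()
    }
}

fn main() {
    let usage = TokenUsageSketch { prompt_tokens: 7, content_tokens: 4, reasoning_tokens: 1 };
    assert_eq!(usage.completion_tokens(), 5);
    assert_eq!(usage.total_tokens(), 12);
}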
#[expect( clippy::expect_used, reason = "system time before UNIX_EPOCH means we are moving back in time" @@ -52,10 +82,100 @@ fn current_timestamp() -> u64 { .as_secs() } +fn validation_failure_message(errors: &[String]) -> String { + errors + .first() + .cloned() + .unwrap_or_else(|| "tool call failed validation".to_owned()) +} + +fn unrecognized_tool_call_format_message(raw: &RawToolCallTokens) -> String { + format!( + "model produced output the parser did not recognise as any registered tool-call format; \ + FFI error: {}; raw text: {}", + raw.ffi_error_message, raw.text, + ) +} + +fn image_exceeds_batch_size_message(details: &OversizedImageDetails) -> String { + format!( + "image required {} tokens but agent n_batch is {}; rerun with a larger n_batch", + details.image_tokens, details.n_batch, + ) +} + +fn arguments_to_openai_string(arguments: &ToolCallArguments) -> Result { + match arguments { + ToolCallArguments::ValidJson(value) => { + serde_json::to_string(value).context("serializing tool-call arguments to OpenAI string") + } + ToolCallArguments::InvalidJson(raw) => Ok(raw.clone()), + } +} + +fn server_error_chunk(description: &str) -> TransformResult { + TransformResult::Error(openai_error_json("server_error", description).to_string()) +} + +fn timeout_response_chunk() -> TransformResult { + TransformResult::Error(openai_error_json("timeout", "request timed out").to_string()) +} + +fn rate_limit_response_chunk() -> TransformResult { + TransformResult::Error( + openai_error_json("rate_limit_error", "too many buffered requests").to_string(), + ) +} + +fn unexpected_embedding_response_chunk() -> TransformResult { + TransformResult::Error( + openai_error_json( + "invalid_request_error", + "unexpected embedding response in chat completions", + ) + .to_string(), + ) +} + +fn description_from_error_token(token: &GeneratedTokenResult) -> Option<&str> { + match token { + GeneratedTokenResult::ChatTemplateError(description) + | GeneratedTokenResult::GrammarIncompatibleWithThinking(description) + | GeneratedTokenResult::GrammarRejectedModelOutput(description) + | GeneratedTokenResult::GrammarInitializationFailed(description) + | GeneratedTokenResult::GrammarSyntaxError(description) + | GeneratedTokenResult::ImageDecodingFailed(description) + | GeneratedTokenResult::MultimodalNotSupported(description) + | GeneratedTokenResult::SamplerError(description) + | GeneratedTokenResult::ToolCallParseFailed(description) + | GeneratedTokenResult::ToolSchemaInvalid(description) => Some(description), + _ => None, + } +} + +fn try_universal_error_chunk(message: &OutgoingMessage) -> Option { + match message { + OutgoingMessage::Error(ErrorEnvelope { + error: paddler_types::jsonrpc::Error { description, .. }, + .. + }) => Some(server_error_chunk(description)), + OutgoingMessage::Response(ResponseEnvelope { response, .. 
}) => match response { + OutgoingResponse::GeneratedToken(GeneratedTokenResult::ImageExceedsBatchSize( + details, + )) => Some(server_error_chunk(&image_exceeds_batch_size_message( + details, + ))), + OutgoingResponse::GeneratedToken(token) => { + description_from_error_token(token).map(server_error_chunk) + } + OutgoingResponse::Timeout => Some(timeout_response_chunk()), + OutgoingResponse::TooManyBufferedRequests => Some(rate_limit_response_chunk()), + OutgoingResponse::Embedding(_) => Some(unexpected_embedding_response_chunk()), + }, + } +} + #[derive(Deserialize)] -/// Although fields are same as in Paddler's conversation message for the moment, -/// it would be better if this struct stayed independent from ours just in case -/// to avoid any potential side effects in the future. struct OpenAIMessage { content: ConversationMessageContent, role: String, @@ -70,6 +190,13 @@ impl From<&OpenAIMessage> for ConversationMessage { } } +#[derive(Default, Deserialize)] +#[serde(deny_unknown_fields)] +struct StreamOptions { + #[serde(default)] + include_usage: bool, +} + #[derive(Deserialize)] struct OpenAICompletionRequestParams { max_completion_tokens: Option, @@ -77,159 +204,432 @@ struct OpenAICompletionRequestParams { /// This parameter is ignored here, but is required by the `OpenAI` API. model: String, stream: Option, + stream_options: Option, + #[serde(default)] + tools: Vec>, +} + +#[derive(Default)] +struct OpenAIStreamingState { + saw_tool_call: bool, } #[derive(Clone)] struct OpenAIStreamingResponseTransformer { + include_usage: bool, model: String, + state: Arc>, system_fingerprint: String, } +impl OpenAIStreamingResponseTransformer { + fn content_chunk(&self, request_id: &str, text: &str) -> Result { + Ok(serde_json::to_string(&json!({ + "id": request_id, + "object": "chat.completion.chunk", + "created": current_timestamp(), + "model": self.model, + "system_fingerprint": self.system_fingerprint, + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + "content": text, + }, + "logprobs": null, + "finish_reason": null + } + ] + }))?) + } + + fn reasoning_chunk(&self, request_id: &str, text: &str) -> Result { + Ok(serde_json::to_string(&json!({ + "id": request_id, + "object": "chat.completion.chunk", + "created": current_timestamp(), + "model": self.model, + "system_fingerprint": self.system_fingerprint, + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + "reasoning_content": text, + }, + "logprobs": null, + "finish_reason": null + } + ] + }))?) + } + + fn tool_calls_chunk( + &self, + request_id: &str, + parsed_calls: &[ParsedToolCall], + ) -> Result { + let tool_calls = parsed_calls + .iter() + .enumerate() + .map(|(index, call)| -> Result { + let arguments = arguments_to_openai_string(&call.arguments)?; + Ok(json!({ + "index": index, + "id": call.id, + "type": "function", + "function": { + "name": call.name, + "arguments": arguments, + } + })) + }) + .collect::>>()?; + + Ok(serde_json::to_string(&json!({ + "id": request_id, + "object": "chat.completion.chunk", + "created": current_timestamp(), + "model": self.model, + "system_fingerprint": self.system_fingerprint, + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + "tool_calls": tool_calls, + }, + "logprobs": null, + "finish_reason": null + } + ] + }))?) 
+ } + + fn finish_chunk(&self, request_id: &str, finish_reason: &str) -> Result { + Ok(serde_json::to_string(&json!({ + "id": request_id, + "object": "chat.completion.chunk", + "created": current_timestamp(), + "model": self.model, + "system_fingerprint": self.system_fingerprint, + "choices": [ + { + "index": 0, + "delta": {}, + "logprobs": null, + "finish_reason": finish_reason + } + ] + }))?) + } + + fn usage_chunk(&self, request_id: &str, usage: &TokenUsage) -> Result { + Ok(serde_json::to_string(&json!({ + "id": request_id, + "object": "chat.completion.chunk", + "created": current_timestamp(), + "model": self.model, + "system_fingerprint": self.system_fingerprint, + "choices": [], + "usage": openai_usage_json(usage), + }))?) + } + + fn handle_content(&self, request_id: &str, text: &str) -> Result> { + Ok(vec![TransformResult::Chunk( + self.content_chunk(request_id, text)?, + )]) + } + + fn handle_reasoning(&self, request_id: &str, text: &str) -> Result> { + Ok(vec![TransformResult::Chunk( + self.reasoning_chunk(request_id, text)?, + )]) + } + + fn handle_tool_call_parsed( + &self, + request_id: &str, + parsed_calls: &[ParsedToolCall], + ) -> Result> { + if parsed_calls.is_empty() { + return Ok(vec![]); + } + + self.state + .lock() + .map_err(|err| anyhow!("streaming state mutex poisoned: {err}"))? + .saw_tool_call = true; + + Ok(vec![TransformResult::Chunk( + self.tool_calls_chunk(request_id, parsed_calls)?, + )]) + } + + fn handle_done( + &self, + request_id: &str, + summary: &GenerationSummary, + ) -> Result> { + let saw_tool_call = self + .state + .lock() + .map_err(|err| anyhow!("streaming state mutex poisoned: {err}"))? + .saw_tool_call; + + let finish_reason = if saw_tool_call { "tool_calls" } else { "stop" }; + let finish = TransformResult::Chunk(self.finish_chunk(request_id, finish_reason)?); + + if self.include_usage { + let usage = TransformResult::Chunk(self.usage_chunk(request_id, &summary.usage)?); + Ok(vec![finish, usage]) + } else { + Ok(vec![finish]) + } + } +} + #[async_trait] impl TransformsOutgoingMessage for OpenAIStreamingResponseTransformer { - async fn transform(&self, message: OutgoingMessage) -> anyhow::Result { + async fn transform(&self, message: OutgoingMessage) -> Result> { + if let Some(error_chunk) = try_universal_error_chunk(&message) { + return Ok(vec![error_chunk]); + } + match message { OutgoingMessage::Response(ResponseEnvelope { request_id, - response: OutgoingResponse::GeneratedToken(GeneratedTokenResult::Done), - }) => Ok(TransformResult::Chunk(serde_json::to_string(&json!({ - "id": request_id, - "object": "chat.completion.chunk", - "created": current_timestamp(), - "model": self.model, - "system_fingerprint": self.system_fingerprint, - "choices": [ - { - "index": 0, - "delta": {}, - "logprobs": null, - "finish_reason": "stop" - } - ] - }))?)), - OutgoingMessage::Response(ResponseEnvelope { - request_id, - response: OutgoingResponse::GeneratedToken(GeneratedTokenResult::Token(token)), - }) => Ok(TransformResult::Chunk(serde_json::to_string(&json!({ - "id": request_id, - "object": "chat.completion.chunk", - "created": current_timestamp(), - "model": self.model, - "system_fingerprint": self.system_fingerprint, - "choices": [ - { - "index": 0, - "delta": { - "role": "assistant", - "content": token, - }, - "logprobs": null, - "finish_reason": null - } - ] - }))?)), - OutgoingMessage::Response(ResponseEnvelope { response: OutgoingResponse::GeneratedToken( - GeneratedTokenResult::ChatTemplateError(description) - | 
GeneratedTokenResult::GrammarIncompatibleWithThinking(description) - | GeneratedTokenResult::GrammarRejectedModelOutput(description) - | GeneratedTokenResult::GrammarInitializationFailed(description) - | GeneratedTokenResult::GrammarSyntaxError(description) - | GeneratedTokenResult::ImageDecodingFailed(description) - | GeneratedTokenResult::MultimodalNotSupported(description) - | GeneratedTokenResult::SamplerError(description), + GeneratedTokenResult::ContentToken(text) + | GeneratedTokenResult::UndeterminableToken(text), ), .. - }) - | OutgoingMessage::Error(ErrorEnvelope { - error: paddler_types::jsonrpc::Error { description, .. }, + }) => self.handle_content(&request_id, &text), + OutgoingMessage::Response(ResponseEnvelope { + request_id, + response: + OutgoingResponse::GeneratedToken(GeneratedTokenResult::ReasoningToken(text)), .. - }) => Ok(TransformResult::Error( - openai_error_json("server_error", &description).to_string(), - )), + }) => self.handle_reasoning(&request_id, &text), OutgoingMessage::Response(ResponseEnvelope { - response: OutgoingResponse::Timeout, + response: OutgoingResponse::GeneratedToken(GeneratedTokenResult::ToolCallToken(_)), .. - }) => Ok(TransformResult::Error( - openai_error_json("timeout", "request timed out").to_string(), - )), + }) => Ok(vec![]), OutgoingMessage::Response(ResponseEnvelope { - response: OutgoingResponse::TooManyBufferedRequests, + request_id, + response: + OutgoingResponse::GeneratedToken(GeneratedTokenResult::ToolCallParsed(parsed_calls)), .. - }) => Ok(TransformResult::Error( - openai_error_json("rate_limit_error", "too many buffered requests").to_string(), - )), + }) => self.handle_tool_call_parsed(&request_id, &parsed_calls), + OutgoingMessage::Response(ResponseEnvelope { + response: + OutgoingResponse::GeneratedToken(GeneratedTokenResult::ToolCallValidationFailed( + errors, + )), + .. + }) => Ok(vec![server_error_chunk(&validation_failure_message( + &errors, + ))]), + OutgoingMessage::Response(ResponseEnvelope { + response: + OutgoingResponse::GeneratedToken(GeneratedTokenResult::UnrecognizedToolCallFormat( + raw, + )), + .. + }) => Ok(vec![server_error_chunk( + &unrecognized_tool_call_format_message(&raw), + )]), OutgoingMessage::Response(ResponseEnvelope { - response: OutgoingResponse::Embedding(_), + request_id, + response: OutgoingResponse::GeneratedToken(GeneratedTokenResult::Done(summary)), .. - }) => Ok(TransformResult::Error( - openai_error_json( - "invalid_request_error", - "unexpected embedding response in chat completions", - ) - .to_string(), + }) => self.handle_done(&request_id, &summary), + other => Err(anyhow!( + "OpenAIStreamingResponseTransformer received an outgoing message it does not know how to handle: {other:?}" )), } } } +#[derive(Clone, Default)] +struct OpenAINonStreamingState { + content: String, + reasoning: String, + tool_calls: Vec, +} + #[derive(Clone)] -struct OpenAICombinedResponseTransformer {} +struct OpenAINonStreamingResponseTransformer { + model: String, + state: Arc>, +} + +impl OpenAINonStreamingResponseTransformer { + fn append_content(&self, text: &str) -> Result<()> { + self.state + .lock() + .map_err(|err| anyhow!("non-streaming state mutex poisoned: {err}"))? + .content + .push_str(text); + Ok(()) + } + + fn append_reasoning(&self, text: &str) -> Result<()> { + self.state + .lock() + .map_err(|err| anyhow!("non-streaming state mutex poisoned: {err}"))? 
+ .reasoning + .push_str(text); + Ok(()) + } + + fn append_tool_calls(&self, parsed_calls: Vec) -> Result<()> { + self.state + .lock() + .map_err(|err| anyhow!("non-streaming state mutex poisoned: {err}"))? + .tool_calls + .extend(parsed_calls); + Ok(()) + } + + fn build_done_chunk(&self, request_id: &str, summary: &GenerationSummary) -> Result { + let snapshot = self.snapshot_state()?; + + let has_tool_calls = !snapshot.tool_calls.is_empty(); + let finish_reason = if has_tool_calls { "tool_calls" } else { "stop" }; + + let mut message_obj = json!({ + "role": "assistant", + "content": if snapshot.content.is_empty() && has_tool_calls { + serde_json::Value::Null + } else { + json!(snapshot.content) + }, + "refusal": null, + "annotations": [] + }); + + if !snapshot.reasoning.is_empty() + && let Some(map) = message_obj.as_object_mut() + { + map.insert("reasoning_content".to_owned(), json!(snapshot.reasoning)); + } + + if has_tool_calls && let Some(map) = message_obj.as_object_mut() { + let tool_calls_json = snapshot + .tool_calls + .iter() + .map(|call| -> Result { + let arguments = arguments_to_openai_string(&call.arguments)?; + Ok(json!({ + "id": call.id, + "type": "function", + "function": { + "name": call.name, + "arguments": arguments, + } + })) + }) + .collect::>>()?; + map.insert("tool_calls".to_owned(), json!(tool_calls_json)); + } + + Ok(serde_json::to_string(&json!({ + "id": request_id, + "object": "chat.completion", + "created": current_timestamp(), + "model": self.model, + "choices": [ + { + "index": 0, + "message": message_obj, + "logprobs": null, + "finish_reason": finish_reason + } + ], + "usage": openai_usage_json(&summary.usage), + "service_tier": "default" + }))?) + } + + fn snapshot_state(&self) -> Result { + let state = self + .state + .lock() + .map_err(|err| anyhow!("non-streaming state mutex poisoned: {err}"))?; + Ok(state.clone()) + } +} #[async_trait] -impl TransformsOutgoingMessage for OpenAICombinedResponseTransformer { - async fn transform(&self, message: OutgoingMessage) -> anyhow::Result { +impl TransformsOutgoingMessage for OpenAINonStreamingResponseTransformer { + async fn transform(&self, message: OutgoingMessage) -> Result> { + if let Some(error_chunk) = try_universal_error_chunk(&message) { + return Ok(vec![error_chunk]); + } + match message { OutgoingMessage::Response(ResponseEnvelope { - response: OutgoingResponse::GeneratedToken(GeneratedTokenResult::Done), + response: + OutgoingResponse::GeneratedToken( + GeneratedTokenResult::ContentToken(text) + | GeneratedTokenResult::UndeterminableToken(text), + ), .. - }) => Ok(TransformResult::Chunk(String::new())), + }) => { + self.append_content(&text)?; + Ok(vec![]) + } OutgoingMessage::Response(ResponseEnvelope { - response: OutgoingResponse::GeneratedToken(GeneratedTokenResult::Token(token)), + response: + OutgoingResponse::GeneratedToken(GeneratedTokenResult::ReasoningToken(text)), .. 
- }) => Ok(TransformResult::Chunk(token)), + }) => { + self.append_reasoning(&text)?; + Ok(vec![]) + } OutgoingMessage::Response(ResponseEnvelope { - response: - OutgoingResponse::GeneratedToken( - GeneratedTokenResult::ChatTemplateError(description) - | GeneratedTokenResult::GrammarIncompatibleWithThinking(description) - | GeneratedTokenResult::GrammarRejectedModelOutput(description) - | GeneratedTokenResult::GrammarInitializationFailed(description) - | GeneratedTokenResult::GrammarSyntaxError(description) - | GeneratedTokenResult::ImageDecodingFailed(description) - | GeneratedTokenResult::MultimodalNotSupported(description) - | GeneratedTokenResult::SamplerError(description), - ), + response: OutgoingResponse::GeneratedToken(GeneratedTokenResult::ToolCallToken(_)), .. - }) - | OutgoingMessage::Error(ErrorEnvelope { - error: paddler_types::jsonrpc::Error { description, .. }, + }) => Ok(vec![]), + OutgoingMessage::Response(ResponseEnvelope { + response: + OutgoingResponse::GeneratedToken(GeneratedTokenResult::ToolCallParsed(parsed_calls)), .. - }) => Ok(TransformResult::Error( - openai_error_json("server_error", &description).to_string(), - )), + }) => { + self.append_tool_calls(parsed_calls)?; + Ok(vec![]) + } OutgoingMessage::Response(ResponseEnvelope { - response: OutgoingResponse::Timeout, + response: + OutgoingResponse::GeneratedToken(GeneratedTokenResult::ToolCallValidationFailed( + errors, + )), .. - }) => Ok(TransformResult::Error( - openai_error_json("timeout", "request timed out").to_string(), - )), + }) => Ok(vec![server_error_chunk(&validation_failure_message( + &errors, + ))]), OutgoingMessage::Response(ResponseEnvelope { - response: OutgoingResponse::TooManyBufferedRequests, + response: + OutgoingResponse::GeneratedToken(GeneratedTokenResult::UnrecognizedToolCallFormat( + raw, + )), .. - }) => Ok(TransformResult::Error( - openai_error_json("rate_limit_error", "too many buffered requests").to_string(), - )), + }) => Ok(vec![server_error_chunk( + &unrecognized_tool_call_format_message(&raw), + )]), OutgoingMessage::Response(ResponseEnvelope { - response: OutgoingResponse::Embedding(_), + request_id, + response: OutgoingResponse::GeneratedToken(GeneratedTokenResult::Done(summary)), .. 
- }) => Ok(TransformResult::Error( - openai_error_json( - "invalid_request_error", - "unexpected embedding response in chat completions", - ) - .to_string(), + }) => Ok(vec![TransformResult::Chunk( + self.build_done_chunk(&request_id, &summary)?, + )]), + other => Err(anyhow!( + "OpenAINonStreamingResponseTransformer received an outgoing message it does not know how to handle: {other:?}" )), } } @@ -240,6 +640,23 @@ async fn respond( app_data: web::Data, openai_params: web::Json, ) -> Result { + let openai_params = openai_params.into_inner(); + + let validated_tools = match openai_params + .tools + .into_iter() + .map(Validates::validate) + .collect::, _>>() + { + Ok(tools) => tools, + Err(err) => { + return Ok(HttpResponse::BadRequest() + .content_type("application/json") + .body(openai_error_json("invalid_request_error", &err.to_string()).to_string())); + } + }; + + let parse_tool_calls = !validated_tools.is_empty(); let paddler_params = ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new( @@ -252,16 +669,24 @@ async fn respond( enable_thinking: true, grammar: None, max_tokens: openai_params.max_completion_tokens.unwrap_or(2000), - tools: vec![], + parse_tool_calls, + tools: validated_tools, }; if openai_params.stream.unwrap_or(false) { + let include_usage = openai_params + .stream_options + .as_ref() + .is_some_and(|options| options.include_usage); + Ok(http_stream_from_agent( app_data.buffered_request_manager.clone(), app_data.inference_service_configuration.clone(), paddler_params, OpenAIStreamingResponseTransformer { + include_usage, model: openai_params.model.clone(), + state: Arc::new(Mutex::new(OpenAIStreamingState::default())), system_fingerprint: nanoid!(), }, )) @@ -270,7 +695,10 @@ async fn respond( app_data.buffered_request_manager.clone(), app_data.inference_service_configuration.clone(), paddler_params, - OpenAICombinedResponseTransformer {}, + OpenAINonStreamingResponseTransformer { + model: openai_params.model.clone(), + state: Arc::new(Mutex::new(OpenAINonStreamingState::default())), + }, ) .collect() .await; @@ -284,68 +712,51 @@ async fn respond( .body(error_json.clone())); } - let combined_response: String = results - .into_iter() - .filter_map(|result| match result { - TransformResult::Chunk(content) => Some(content), - TransformResult::Discard | TransformResult::Error(_) => None, - }) - .collect(); - - Ok(HttpResponse::Ok().json(json!({ - "id": nanoid!(), - "object": "chat.completion", - "created": current_timestamp(), - "model": openai_params.model, - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": combined_response, - "refusal": null, - "annotations": [] - }, - "logprobs": null, - "finish_reason": "stop" - } - ], - "usage": { - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0, - "prompt_tokens_details": { - "cached_tokens": 0, - "audio_tokens": 0 + let body = results.into_iter().find_map(|result| match result { + TransformResult::Chunk(content) => Some(content), + TransformResult::Discard | TransformResult::Error(_) => None, + }); + + Ok(body.map_or_else( + || { + HttpResponse::InternalServerError() + .content_type("application/json") + .body(openai_error_json("server_error", "no completion produced").to_string()) }, - "completion_tokens_details": { - "reasoning_tokens": 0, - "audio_tokens": 0, - "accepted_prediction_tokens": 0, - "rejected_prediction_tokens": 0 - } - }, - "service_tier": "default" - }))) + |json_body| { + HttpResponse::Ok() + 
.content_type("application/json") + .body(json_body) + }, + )) } } #[cfg(test)] mod tests { + use std::sync::Arc; + use std::sync::Mutex; + use anyhow::Result; + use llama_cpp_bindings::ParsedToolCall; + use llama_cpp_bindings::TokenUsage; + use llama_cpp_bindings::ToolCallArguments; use paddler_types::generated_token_result::GeneratedTokenResult; + use paddler_types::generation_summary::GenerationSummary; use paddler_types::inference_client::Message as OutgoingMessage; use paddler_types::inference_client::Response as OutgoingResponse; use paddler_types::jsonrpc::ErrorEnvelope; use paddler_types::jsonrpc::ResponseEnvelope; - use super::OpenAICombinedResponseTransformer; + use super::OpenAINonStreamingResponseTransformer; + use super::OpenAINonStreamingState; use super::OpenAIStreamingResponseTransformer; use crate::balancer::chunk_forwarding_session_controller::transform_result::TransformResult; use crate::balancer::chunk_forwarding_session_controller::transforms_outgoing_message::TransformsOutgoingMessage; fn make_token_message(token_result: GeneratedTokenResult) -> OutgoingMessage { OutgoingMessage::Response(ResponseEnvelope { + generated_by: None, request_id: "test-request".to_owned(), response: OutgoingResponse::GeneratedToken(token_result), }) @@ -363,11 +774,28 @@ mod tests { fn make_response_message(response: OutgoingResponse) -> OutgoingMessage { OutgoingMessage::Response(ResponseEnvelope { + generated_by: None, request_id: "test-request".to_owned(), response, }) } + fn streaming_transformer(include_usage: bool) -> OpenAIStreamingResponseTransformer { + OpenAIStreamingResponseTransformer { + include_usage, + model: "test-model".to_owned(), + state: Arc::new(Mutex::new(super::OpenAIStreamingState::default())), + system_fingerprint: "test-fingerprint".to_owned(), + } + } + + fn non_streaming_transformer() -> OpenAINonStreamingResponseTransformer { + OpenAINonStreamingResponseTransformer { + model: "test-model".to_owned(), + state: Arc::new(Mutex::new(OpenAINonStreamingState::default())), + } + } + fn assert_chunk_contains(result: &TransformResult, expected: &str) -> Result<()> { let TransformResult::Chunk(content) = result else { anyhow::bail!("expected TransformResult::Chunk, got TransformResult::Error"); @@ -381,6 +809,19 @@ mod tests { Ok(()) } + fn assert_chunk_does_not_contain(result: &TransformResult, expected: &str) -> Result<()> { + let TransformResult::Chunk(content) = result else { + anyhow::bail!("expected TransformResult::Chunk, got TransformResult::Error"); + }; + + assert!( + !content.contains(expected), + "chunk unexpectedly contains '{expected}': {content}" + ); + + Ok(()) + } + fn assert_error_contains(result: &TransformResult, expected: &str) -> Result<()> { let TransformResult::Error(content) = result else { anyhow::bail!("expected TransformResult::Error, got TransformResult::Chunk"); @@ -394,255 +835,625 @@ mod tests { Ok(()) } + fn summary_with_counts( + prompt_tokens: u64, + content_tokens: u64, + reasoning_tokens: u64, + ) -> GenerationSummary { + GenerationSummary { + usage: TokenUsage { + prompt_tokens, + content_tokens, + reasoning_tokens, + ..TokenUsage::default() + }, + } + } + + fn weather_call() -> ParsedToolCall { + ParsedToolCall::new( + "call_x".to_owned(), + "get_weather".to_owned(), + ToolCallArguments::ValidJson(serde_json::json!({"location": "Paris"})), + ) + } + #[actix_web::test] - async fn streaming_token_emits_content_delta() -> Result<()> { - let transformer = OpenAIStreamingResponseTransformer { - model: "test-model".to_owned(), - 
system_fingerprint: "test-fingerprint".to_owned(), - }; + async fn streaming_content_token_emits_content_delta() -> Result<()> { + let transformer = streaming_transformer(false); - let message = make_token_message(GeneratedTokenResult::Token("hello".to_owned())); - let result = transformer.transform(message).await?; + let message = make_token_message(GeneratedTokenResult::ContentToken("hello".to_owned())); + let chunks = transformer.transform(message).await?; - assert_chunk_contains(&result, "\"content\":\"hello\"")?; - assert_chunk_contains(&result, "\"role\":\"assistant\"")?; + assert_eq!(chunks.len(), 1); + assert_chunk_contains(&chunks[0], "\"content\":\"hello\"")?; + assert_chunk_contains(&chunks[0], "\"role\":\"assistant\"")?; + assert_chunk_does_not_contain(&chunks[0], "reasoning_content")?; Ok(()) } #[actix_web::test] - async fn streaming_done_emits_stop_finish_reason() -> Result<()> { - let transformer = OpenAIStreamingResponseTransformer { - model: "test-model".to_owned(), - system_fingerprint: "test-fingerprint".to_owned(), - }; + async fn streaming_reasoning_token_emits_reasoning_content_delta() -> Result<()> { + let transformer = streaming_transformer(false); - let message = make_token_message(GeneratedTokenResult::Done); - let result = transformer.transform(message).await?; + let message = + make_token_message(GeneratedTokenResult::ReasoningToken("thought".to_owned())); + let chunks = transformer.transform(message).await?; - assert_chunk_contains(&result, "\"finish_reason\":\"stop\"")?; + assert_eq!(chunks.len(), 1); + assert_chunk_contains(&chunks[0], "\"reasoning_content\":\"thought\"")?; + assert_chunk_contains(&chunks[0], "\"role\":\"assistant\"")?; + assert_chunk_does_not_contain(&chunks[0], "\"content\":")?; Ok(()) } #[actix_web::test] - async fn combined_token_returns_content() -> Result<()> { - let transformer = OpenAICombinedResponseTransformer {}; + async fn streaming_undeterminable_token_emits_content_delta() -> Result<()> { + let transformer = streaming_transformer(false); - let message = make_token_message(GeneratedTokenResult::Token("hello".to_owned())); - let result = transformer.transform(message).await?; + let message = make_token_message(GeneratedTokenResult::UndeterminableToken( + "ambig".to_owned(), + )); + let chunks = transformer.transform(message).await?; - assert!(matches!(result, TransformResult::Chunk(ref content) if content == "hello")); + assert_eq!(chunks.len(), 1); + assert_chunk_contains(&chunks[0], "\"content\":\"ambig\"")?; + assert_chunk_does_not_contain(&chunks[0], "reasoning_content")?; Ok(()) } #[actix_web::test] - async fn combined_done_returns_empty_chunk() -> Result<()> { - let transformer = OpenAICombinedResponseTransformer {}; + async fn streaming_tool_call_token_is_silently_dropped() -> Result<()> { + let transformer = streaming_transformer(false); - let message = make_token_message(GeneratedTokenResult::Done); - let result = transformer.transform(message).await?; + let chunks = transformer + .transform(make_token_message(GeneratedTokenResult::ToolCallToken( + "{".to_owned(), + ))) + .await?; - assert!(matches!(result, TransformResult::Chunk(ref content) if content.is_empty())); + assert_eq!(chunks.len(), 0); Ok(()) } #[actix_web::test] - async fn streaming_error_message_returns_error_variant() -> Result<()> { - let transformer = OpenAIStreamingResponseTransformer { - model: "test-model".to_owned(), - system_fingerprint: "test-fingerprint".to_owned(), - }; + async fn streaming_tool_call_parsed_emits_structured_tool_calls_chunk() -> 
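// Token-kind routing pinned by the four streaming tests above: ContentToken
// and UndeterminableToken surface as a "content" delta, ReasoningToken as a
// "reasoning_content" delta, and raw ToolCallToken fragments produce no chunk
// at all (empty Vec) -- only the fully parsed tool call, handled next, is
// worth emitting as a structured delta.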
Result<()> { + let transformer = streaming_transformer(false); + + let chunks = transformer + .transform(make_token_message(GeneratedTokenResult::ToolCallParsed( + vec![weather_call()], + ))) + .await?; + + assert_eq!(chunks.len(), 1); + assert_chunk_contains(&chunks[0], "\"tool_calls\"")?; + assert_chunk_contains(&chunks[0], "\"id\":\"call_x\"")?; + assert_chunk_contains(&chunks[0], "\"name\":\"get_weather\"")?; + assert_chunk_contains( + &chunks[0], + "\"arguments\":\"{\\\"location\\\":\\\"Paris\\\"}\"", + )?; - let message = make_error_message(500, "internal server error"); - let result = transformer.transform(message).await?; + Ok(()) + } + + #[actix_web::test] + async fn streaming_done_after_tool_call_uses_tool_calls_finish_reason() -> Result<()> { + let transformer = streaming_transformer(false); + + transformer + .transform(make_token_message(GeneratedTokenResult::ToolCallParsed( + vec![weather_call()], + ))) + .await?; + + let summary = summary_with_counts(2, 0, 0); + let chunks = transformer + .transform(make_token_message(GeneratedTokenResult::Done(summary))) + .await?; + + assert_eq!(chunks.len(), 1); + assert_chunk_contains(&chunks[0], "\"finish_reason\":\"tool_calls\"")?; + + Ok(()) + } + + #[actix_web::test] + async fn streaming_done_without_tool_call_uses_stop_finish_reason() -> Result<()> { + let transformer = streaming_transformer(false); + + transformer + .transform(make_token_message(GeneratedTokenResult::ContentToken( + "hi".to_owned(), + ))) + .await?; + + let summary = summary_with_counts(2, 1, 0); + let chunks = transformer + .transform(make_token_message(GeneratedTokenResult::Done(summary))) + .await?; + + assert_eq!(chunks.len(), 1); + assert_chunk_contains(&chunks[0], "\"finish_reason\":\"stop\"")?; + + Ok(()) + } + + #[actix_web::test] + async fn streaming_done_with_include_usage_emits_finish_then_usage_chunk() -> Result<()> { + let transformer = streaming_transformer(true); + let summary = summary_with_counts(7, 4, 1); + + let chunks = transformer + .transform(make_token_message(GeneratedTokenResult::Done(summary))) + .await?; + + assert_eq!(chunks.len(), 2); + assert_chunk_contains(&chunks[0], "\"finish_reason\":\"stop\"")?; + assert_chunk_does_not_contain(&chunks[0], "usage")?; + assert_chunk_contains(&chunks[1], "\"prompt_tokens\":7")?; + assert_chunk_contains(&chunks[1], "\"completion_tokens\":5")?; + assert_chunk_contains(&chunks[1], "\"total_tokens\":12")?; + assert_chunk_contains(&chunks[1], "\"choices\":[]")?; + + Ok(()) + } + + #[actix_web::test] + async fn streaming_done_without_include_usage_emits_only_finish_chunk() -> Result<()> { + let transformer = streaming_transformer(false); + let summary = summary_with_counts(5, 3, 2); + + let chunks = transformer + .transform(make_token_message(GeneratedTokenResult::Done(summary))) + .await?; - assert_error_contains(&result, "internal server error")?; - assert_error_contains(&result, "server_error")?; + assert_eq!(chunks.len(), 1); + assert_chunk_contains(&chunks[0], "\"finish_reason\":\"stop\"")?; + assert_chunk_does_not_contain(&chunks[0], "usage")?; Ok(()) } #[actix_web::test] - async fn combined_error_message_returns_error_variant() -> Result<()> { - let transformer = OpenAICombinedResponseTransformer {}; + async fn streaming_tool_call_parse_failed_emits_server_error() -> Result<()> { + let transformer = streaming_transformer(false); + + let chunks = transformer + .transform(make_token_message( + GeneratedTokenResult::ToolCallParseFailed("bad payload".to_owned()), + )) + .await?; + + 
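// Usage arithmetic asserted above for summary_with_counts(7, 4, 1):
//   completion_tokens = content_tokens + reasoning_tokens = 4 + 1 = 5
//   total_tokens      = prompt_tokens + completion_tokens = 7 + 5 = 12
// With include_usage the usage payload arrives as a second, trailing chunk
// whose "choices" array is empty, mirroring OpenAI's
// stream_options.include_usage behaviour.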
assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "bad payload")?; + assert_error_contains(&chunks[0], "server_error")?; + + Ok(()) + } + + #[actix_web::test] + async fn streaming_tool_call_validation_failed_emits_server_error() -> Result<()> { + let transformer = streaming_transformer(false); + + let chunks = transformer + .transform(make_token_message( + GeneratedTokenResult::ToolCallValidationFailed(vec!["missing field x".to_owned()]), + )) + .await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "missing field x")?; + + Ok(()) + } + + #[actix_web::test] + async fn streaming_unrecognized_tool_call_format_emits_server_error() -> Result<()> { + let transformer = streaming_transformer(false); + + let chunks = transformer + .transform(make_token_message( + GeneratedTokenResult::UnrecognizedToolCallFormat( + paddler_types::raw_tool_call_tokens::RawToolCallTokens { + text: "blah".to_owned(), + ffi_error_message: "common_chat_parse failed: no parser".to_owned(), + }, + ), + )) + .await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "common_chat_parse failed: no parser")?; + assert_error_contains(&chunks[0], "blah")?; + assert_error_contains(&chunks[0], "server_error")?; + + Ok(()) + } + + #[actix_web::test] + async fn streaming_error_message_returns_error_variant() -> Result<()> { + let transformer = streaming_transformer(false); let message = make_error_message(500, "internal server error"); - let result = transformer.transform(message).await?; + let chunks = transformer.transform(message).await?; - assert_error_contains(&result, "internal server error")?; - assert_error_contains(&result, "server_error")?; + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "internal server error")?; + assert_error_contains(&chunks[0], "server_error")?; Ok(()) } #[actix_web::test] async fn streaming_chat_template_error_returns_error_variant() -> Result<()> { - let transformer = OpenAIStreamingResponseTransformer { - model: "test-model".to_owned(), - system_fingerprint: "test-fingerprint".to_owned(), - }; + let transformer = streaming_transformer(false); let message = make_token_message(GeneratedTokenResult::ChatTemplateError( "bad template".to_owned(), )); - let result = transformer.transform(message).await?; + let chunks = transformer.transform(message).await?; - assert_error_contains(&result, "bad template")?; - assert_error_contains(&result, "server_error")?; + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "bad template")?; + assert_error_contains(&chunks[0], "server_error")?; Ok(()) } #[actix_web::test] - async fn combined_chat_template_error_returns_error_variant() -> Result<()> { - let transformer = OpenAICombinedResponseTransformer {}; + async fn streaming_timeout_returns_error_variant() -> Result<()> { + let transformer = streaming_transformer(false); - let message = make_token_message(GeneratedTokenResult::ChatTemplateError( - "bad template".to_owned(), + let message = make_response_message(OutgoingResponse::Timeout); + let chunks = transformer.transform(message).await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "request timed out")?; + assert_error_contains(&chunks[0], "timeout")?; + + Ok(()) + } + + #[actix_web::test] + async fn streaming_too_many_buffered_requests_returns_error_variant() -> Result<()> { + let transformer = streaming_transformer(false); + + let message = make_response_message(OutgoingResponse::TooManyBufferedRequests); + let chunks = 
transformer.transform(message).await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "too many buffered requests")?; + assert_error_contains(&chunks[0], "rate_limit_error")?; + + Ok(()) + } + + #[actix_web::test] + async fn streaming_image_decoding_failed_returns_error_variant() -> Result<()> { + let transformer = streaming_transformer(false); + + let message = make_token_message(GeneratedTokenResult::ImageDecodingFailed( + "unsupported format".to_owned(), )); - let result = transformer.transform(message).await?; + let chunks = transformer.transform(message).await?; - assert_error_contains(&result, "bad template")?; - assert_error_contains(&result, "server_error")?; + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "unsupported format")?; + assert_error_contains(&chunks[0], "server_error")?; Ok(()) } #[actix_web::test] - async fn streaming_timeout_returns_error_variant() -> Result<()> { - let transformer = OpenAIStreamingResponseTransformer { - model: "test-model".to_owned(), - system_fingerprint: "test-fingerprint".to_owned(), - }; + async fn streaming_multimodal_not_supported_returns_error_variant() -> Result<()> { + let transformer = streaming_transformer(false); - let message = make_response_message(OutgoingResponse::Timeout); - let result = transformer.transform(message).await?; + let message = make_token_message(GeneratedTokenResult::MultimodalNotSupported( + "model does not support images".to_owned(), + )); + let chunks = transformer.transform(message).await?; - assert_error_contains(&result, "request timed out")?; - assert_error_contains(&result, "timeout")?; + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "model does not support images")?; + assert_error_contains(&chunks[0], "server_error")?; Ok(()) } #[actix_web::test] - async fn combined_timeout_returns_error_variant() -> Result<()> { - let transformer = OpenAICombinedResponseTransformer {}; + async fn streaming_image_exceeds_batch_size_returns_error_variant() -> Result<()> { + let transformer = streaming_transformer(false); - let message = make_response_message(OutgoingResponse::Timeout); - let result = transformer.transform(message).await?; + let message = make_token_message(GeneratedTokenResult::ImageExceedsBatchSize( + paddler_types::oversized_image_details::OversizedImageDetails { + image_tokens: 368, + n_batch: 100, + }, + )); + let chunks = transformer.transform(message).await?; - assert_error_contains(&result, "request timed out")?; - assert_error_contains(&result, "timeout")?; + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "368")?; + assert_error_contains(&chunks[0], "100")?; + assert_error_contains(&chunks[0], "server_error")?; Ok(()) } #[actix_web::test] - async fn streaming_too_many_buffered_requests_returns_error_variant() -> Result<()> { - let transformer = OpenAIStreamingResponseTransformer { - model: "test-model".to_owned(), - system_fingerprint: "test-fingerprint".to_owned(), - }; + async fn non_streaming_aggregates_content_only_when_no_reasoning() -> Result<()> { + let transformer = non_streaming_transformer(); + + transformer + .transform(make_token_message(GeneratedTokenResult::ContentToken( + "hel".to_owned(), + ))) + .await?; + transformer + .transform(make_token_message(GeneratedTokenResult::ContentToken( + "lo".to_owned(), + ))) + .await?; + + let summary = summary_with_counts(4, 2, 0); + let final_chunks = transformer + .transform(make_token_message(GeneratedTokenResult::Done(summary))) + .await?; + + 
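// Taken together, the error-path tests above fix the mapping from transport
// and engine failures onto OpenAI error types:
//   Timeout                  -> "timeout"
//   TooManyBufferedRequests  -> "rate_limit_error"
//   engine-side failures (image decoding, multimodal support, oversized
//   image batches, bad chat templates, tool-call parsing) -> "server_error"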
assert_eq!(final_chunks.len(), 1); + assert_chunk_contains(&final_chunks[0], "\"content\":\"hello\"")?; + assert_chunk_does_not_contain(&final_chunks[0], "reasoning_content")?; + assert_chunk_contains(&final_chunks[0], "\"prompt_tokens\":4")?; + assert_chunk_contains(&final_chunks[0], "\"completion_tokens\":2")?; - let message = make_response_message(OutgoingResponse::TooManyBufferedRequests); - let result = transformer.transform(message).await?; + Ok(()) + } - assert_error_contains(&result, "too many buffered requests")?; - assert_error_contains(&result, "rate_limit_error")?; + #[actix_web::test] + async fn non_streaming_separates_reasoning_from_content() -> Result<()> { + let transformer = non_streaming_transformer(); + + transformer + .transform(make_token_message(GeneratedTokenResult::ReasoningToken( + "think".to_owned(), + ))) + .await?; + transformer + .transform(make_token_message(GeneratedTokenResult::ContentToken( + "answer".to_owned(), + ))) + .await?; + + let summary = summary_with_counts(3, 1, 1); + let final_chunks = transformer + .transform(make_token_message(GeneratedTokenResult::Done(summary))) + .await?; + + assert_eq!(final_chunks.len(), 1); + assert_chunk_contains(&final_chunks[0], "\"content\":\"answer\"")?; + assert_chunk_contains(&final_chunks[0], "\"reasoning_content\":\"think\"")?; + assert_chunk_contains(&final_chunks[0], "\"reasoning_tokens\":1")?; Ok(()) } #[actix_web::test] - async fn combined_too_many_buffered_requests_returns_error_variant() -> Result<()> { - let transformer = OpenAICombinedResponseTransformer {}; + async fn non_streaming_undeterminable_routes_to_content() -> Result<()> { + let transformer = non_streaming_transformer(); - let message = make_response_message(OutgoingResponse::TooManyBufferedRequests); - let result = transformer.transform(message).await?; + transformer + .transform(make_token_message( + GeneratedTokenResult::UndeterminableToken("amb".to_owned()), + )) + .await?; - assert_error_contains(&result, "too many buffered requests")?; - assert_error_contains(&result, "rate_limit_error")?; + let summary = summary_with_counts(2, 0, 0); + let final_chunks = transformer + .transform(make_token_message(GeneratedTokenResult::Done(summary))) + .await?; + + assert_eq!(final_chunks.len(), 1); + assert_chunk_contains(&final_chunks[0], "\"content\":\"amb\"")?; + assert_chunk_does_not_contain(&final_chunks[0], "reasoning_content")?; Ok(()) } #[actix_web::test] - async fn openai_error_json_has_correct_structure() -> Result<()> { - let error = super::openai_error_json("server_error", "something went wrong"); + async fn non_streaming_tool_call_parsed_populates_message_tool_calls() -> Result<()> { + let transformer = non_streaming_transformer(); + + transformer + .transform(make_token_message(GeneratedTokenResult::ToolCallParsed( + vec![weather_call()], + ))) + .await?; + + let summary = summary_with_counts(4, 0, 0); + let final_chunks = transformer + .transform(make_token_message(GeneratedTokenResult::Done(summary))) + .await?; + + assert_eq!(final_chunks.len(), 1); + assert_chunk_contains(&final_chunks[0], "\"tool_calls\":")?; + assert_chunk_contains(&final_chunks[0], "\"name\":\"get_weather\"")?; + assert_chunk_contains( + &final_chunks[0], + "\"arguments\":\"{\\\"location\\\":\\\"Paris\\\"}\"", + )?; + assert_chunk_contains(&final_chunks[0], "\"finish_reason\":\"tool_calls\"")?; - assert_eq!(error["error"]["type"], "server_error"); - assert_eq!(error["error"]["message"], "something went wrong"); - assert!(error["error"]["param"].is_null()); - 
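// The aggregation tests above suggest the non-streaming state keeps separate
// content and reasoning buffers: ContentToken and UndeterminableToken append
// to content, ReasoningToken to reasoning, and "reasoning_content" only
// appears in the final body when the reasoning buffer is non-empty -- exactly
// what the assert_chunk_does_not_contain checks verify.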
assert!(error["error"]["code"].is_null()); + Ok(()) + } + + #[actix_web::test] + async fn non_streaming_tool_call_parse_failed_emits_error() -> Result<()> { + let transformer = non_streaming_transformer(); + + let chunks = transformer + .transform(make_token_message( + GeneratedTokenResult::ToolCallParseFailed("bad payload".to_owned()), + )) + .await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "bad payload")?; Ok(()) } #[actix_web::test] - async fn streaming_image_decoding_failed_returns_error_variant() -> Result<()> { - let transformer = OpenAIStreamingResponseTransformer { - model: "test-model".to_owned(), - system_fingerprint: "test-fingerprint".to_owned(), - }; + async fn non_streaming_tool_call_validation_failed_emits_error() -> Result<()> { + let transformer = non_streaming_transformer(); - let message = make_token_message(GeneratedTokenResult::ImageDecodingFailed( - "unsupported format".to_owned(), + let chunks = transformer + .transform(make_token_message( + GeneratedTokenResult::ToolCallValidationFailed(vec!["bad shape".to_owned()]), + )) + .await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "bad shape")?; + + Ok(()) + } + + #[actix_web::test] + async fn non_streaming_unrecognized_tool_call_format_emits_server_error() -> Result<()> { + let transformer = non_streaming_transformer(); + + let chunks = transformer + .transform(make_token_message( + GeneratedTokenResult::UnrecognizedToolCallFormat( + paddler_types::raw_tool_call_tokens::RawToolCallTokens { + text: "blah".to_owned(), + ffi_error_message: "common_chat_parse failed: no parser".to_owned(), + }, + ), + )) + .await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "common_chat_parse failed: no parser")?; + assert_error_contains(&chunks[0], "blah")?; + assert_error_contains(&chunks[0], "server_error")?; + + Ok(()) + } + + #[actix_web::test] + async fn non_streaming_error_message_returns_error_variant() -> Result<()> { + let transformer = non_streaming_transformer(); + + let chunks = transformer + .transform(make_error_message(500, "internal server error")) + .await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "internal server error")?; + assert_error_contains(&chunks[0], "server_error")?; + + Ok(()) + } + + #[actix_web::test] + async fn non_streaming_chat_template_error_returns_error_variant() -> Result<()> { + let transformer = non_streaming_transformer(); + + let message = make_token_message(GeneratedTokenResult::ChatTemplateError( + "bad template".to_owned(), )); - let result = transformer.transform(message).await?; + let chunks = transformer.transform(message).await?; - assert_error_contains(&result, "unsupported format")?; - assert_error_contains(&result, "server_error")?; + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "bad template")?; + assert_error_contains(&chunks[0], "server_error")?; Ok(()) } #[actix_web::test] - async fn combined_image_decoding_failed_returns_error_variant() -> Result<()> { - let transformer = OpenAICombinedResponseTransformer {}; + async fn non_streaming_image_decoding_failed_returns_error_variant() -> Result<()> { + let transformer = non_streaming_transformer(); let message = make_token_message(GeneratedTokenResult::ImageDecodingFailed( "unsupported format".to_owned(), )); - let result = transformer.transform(message).await?; + let chunks = transformer.transform(message).await?; - assert_error_contains(&result, "unsupported format")?; - assert_error_contains(&result, 
"server_error")?; + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "unsupported format")?; + assert_error_contains(&chunks[0], "server_error")?; Ok(()) } #[actix_web::test] - async fn streaming_multimodal_not_supported_returns_error_variant() -> Result<()> { - let transformer = OpenAIStreamingResponseTransformer { - model: "test-model".to_owned(), - system_fingerprint: "test-fingerprint".to_owned(), - }; + async fn non_streaming_multimodal_not_supported_returns_error_variant() -> Result<()> { + let transformer = non_streaming_transformer(); let message = make_token_message(GeneratedTokenResult::MultimodalNotSupported( "model does not support images".to_owned(), )); - let result = transformer.transform(message).await?; + let chunks = transformer.transform(message).await?; - assert_error_contains(&result, "model does not support images")?; - assert_error_contains(&result, "server_error")?; + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "model does not support images")?; + assert_error_contains(&chunks[0], "server_error")?; Ok(()) } #[actix_web::test] - async fn combined_multimodal_not_supported_returns_error_variant() -> Result<()> { - let transformer = OpenAICombinedResponseTransformer {}; + async fn non_streaming_image_exceeds_batch_size_returns_error_variant() -> Result<()> { + let transformer = non_streaming_transformer(); - let message = make_token_message(GeneratedTokenResult::MultimodalNotSupported( - "model does not support images".to_owned(), + let message = make_token_message(GeneratedTokenResult::ImageExceedsBatchSize( + paddler_types::oversized_image_details::OversizedImageDetails { + image_tokens: 368, + n_batch: 100, + }, )); - let result = transformer.transform(message).await?; + let chunks = transformer.transform(message).await?; - assert_error_contains(&result, "model does not support images")?; - assert_error_contains(&result, "server_error")?; + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "368")?; + assert_error_contains(&chunks[0], "100")?; + assert_error_contains(&chunks[0], "server_error")?; + + Ok(()) + } + + #[actix_web::test] + async fn non_streaming_timeout_returns_error_variant() -> Result<()> { + let transformer = non_streaming_transformer(); + + let message = make_response_message(OutgoingResponse::Timeout); + let chunks = transformer.transform(message).await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "request timed out")?; + assert_error_contains(&chunks[0], "timeout")?; + + Ok(()) + } + + #[actix_web::test] + async fn non_streaming_too_many_buffered_requests_returns_error_variant() -> Result<()> { + let transformer = non_streaming_transformer(); + + let message = make_response_message(OutgoingResponse::TooManyBufferedRequests); + let chunks = transformer.transform(message).await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "too many buffered requests")?; + assert_error_contains(&chunks[0], "rate_limit_error")?; Ok(()) } @@ -666,6 +1477,41 @@ mod tests { Ok(()) } + #[test] + fn deserialize_request_with_stream_options_include_usage_true() -> Result<()> { + let input = serde_json::json!({ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "stream": true, + "stream_options": {"include_usage": true} + }); + + let params: super::OpenAICompletionRequestParams = serde_json::from_value(input)?; + + let stream_options = params + .stream_options + .ok_or_else(|| anyhow::anyhow!("expected stream_options"))?; + + 
assert!(stream_options.include_usage); + + Ok(()) + } + + #[test] + fn deserialize_request_without_stream_options_defaults_to_none() -> Result<()> { + let input = serde_json::json!({ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "stream": true + }); + + let params: super::OpenAICompletionRequestParams = serde_json::from_value(input)?; + + assert!(params.stream_options.is_none()); + + Ok(()) + } + #[test] fn deserialize_multimodal_request_with_image() -> Result<()> { let input = serde_json::json!({ @@ -712,10 +1558,6 @@ mod tests { let params: super::OpenAICompletionRequestParams = serde_json::from_value(input)?; assert_eq!(params.messages.len(), 4); - assert_eq!(params.messages[0].role, "system"); - assert_eq!(params.messages[1].role, "user"); - assert_eq!(params.messages[2].role, "assistant"); - assert_eq!(params.messages[3].role, "user"); Ok(()) } @@ -741,4 +1583,31 @@ mod tests { Ok(()) } + + #[test] + fn openai_error_json_has_correct_structure() { + let error = super::openai_error_json("server_error", "something went wrong"); + + assert_eq!(error["error"]["type"], "server_error"); + assert_eq!(error["error"]["message"], "something went wrong"); + assert!(error["error"]["param"].is_null()); + assert!(error["error"]["code"].is_null()); + } + + #[test] + fn validation_failure_message_returns_first_error() { + let message = super::validation_failure_message(&[ + "first issue".to_owned(), + "second issue".to_owned(), + ]); + + assert_eq!(message, "first issue"); + } + + #[test] + fn validation_failure_message_falls_back_when_no_errors() { + let message = super::validation_failure_message(&[]); + + assert!(message.contains("validation")); + } } diff --git a/paddler/src/balancer/dispatch_candidate.rs b/paddler/src/balancer/dispatch_candidate.rs new file mode 100644 index 00000000..7d8785cc --- /dev/null +++ b/paddler/src/balancer/dispatch_candidate.rs @@ -0,0 +1,8 @@ +use std::sync::Arc; + +use crate::balancer::agent_controller::AgentController; + +pub struct DispatchCandidate { + pub agent_controller: Arc, + pub snapshot: i32, +} diff --git a/paddler/src/balancer/inference_service/app_data.rs b/paddler/src/balancer/inference_service/app_data.rs index 79c44127..26e2dee3 100644 --- a/paddler/src/balancer/inference_service/app_data.rs +++ b/paddler/src/balancer/inference_service/app_data.rs @@ -2,11 +2,13 @@ use std::sync::Arc; use tokio_util::sync::CancellationToken; +use crate::balancer::agent_controller_pool::AgentControllerPool; use crate::balancer::buffered_request_manager::BufferedRequestManager; use crate::balancer::inference_service::configuration::Configuration; use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; pub struct AppData { + pub agent_controller_pool: Arc, pub balancer_applicable_state_holder: Arc, pub buffered_request_manager: Arc, pub inference_service_configuration: Configuration, diff --git a/paddler/src/balancer/inference_service/http_route/api/post_continue_from_conversation_history.rs b/paddler/src/balancer/inference_service/http_route/api/post_continue_from_conversation_history.rs index cea36b27..ffe33394 100644 --- a/paddler/src/balancer/inference_service/http_route/api/post_continue_from_conversation_history.rs +++ b/paddler/src/balancer/inference_service/http_route/api/post_continue_from_conversation_history.rs @@ -4,7 +4,7 @@ use actix_web::error::ErrorBadRequest; use actix_web::post; use actix_web::web; -use paddler_types::request_params::ContinueFromConversationHistoryParams; +use 
paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::raw_parameters_schema::RawParametersSchema; use paddler_types::validates::Validates as _; diff --git a/paddler/src/balancer/inference_service/http_route/api/post_generate_embedding_batch.rs b/paddler/src/balancer/inference_service/http_route/api/post_generate_embedding_batch.rs index 918943a6..efd59ecc 100644 --- a/paddler/src/balancer/inference_service/http_route/api/post_generate_embedding_batch.rs +++ b/paddler/src/balancer/inference_service/http_route/api/post_generate_embedding_batch.rs @@ -1,6 +1,7 @@ use actix_web::Error; use actix_web::HttpResponse; use actix_web::Responder; +use actix_web::error::ErrorInternalServerError; use actix_web::error::ErrorNotImplemented; use actix_web::error::ErrorServiceUnavailable; use actix_web::http::header; @@ -19,6 +20,7 @@ use paddler_types::inference_client::Response as OutgoingResponse; use paddler_types::jsonrpc::Error as JsonRpcError; use paddler_types::jsonrpc::ErrorEnvelope; use paddler_types::jsonrpc::ResponseEnvelope; +use paddler_types::request_params::ChunkEvenlyWithCapError; use paddler_types::request_params::GenerateEmbeddingBatchParams; use tokio::sync::mpsc; use tokio::task::JoinSet; @@ -34,25 +36,23 @@ use crate::balancer::request_from_agent::request_from_agent; use crate::cancellation_token_stream_guard::CancellationTokenStreamGuard; use crate::controls_session::ControlsSession as _; -const CHARACTERS_PER_TOKEN_APPROXIMATELY: usize = 3; - #[derive(Clone)] struct EmbeddingChunkBodyTransformer; #[async_trait] impl TransformsOutgoingMessage for EmbeddingChunkBodyTransformer { - async fn transform(&self, message: OutgoingMessage) -> Result { + async fn transform(&self, message: OutgoingMessage) -> Result> { if let OutgoingMessage::Response(ResponseEnvelope { response: OutgoingResponse::Embedding(EmbeddingResult::Done), .. 
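// TransformsOutgoingMessage::transform now returns Vec<TransformResult>
// instead of a single result, letting transformers emit zero chunks (swallowed
// tool-call tokens) or several (finish chunk plus usage chunk). The embedding
// transformer only ever needs one element, hence the vec![...] wrappers here.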
}) = &message { - return Ok(TransformResult::Discard); + return Ok(vec![TransformResult::Discard]); } let serialized = serde_json::to_string(&message)?; - Ok(TransformResult::Chunk(serialized)) + Ok(vec![TransformResult::Chunk(serialized)]) } } @@ -79,15 +79,32 @@ async fn respond( )); } + let agent_count = app_data.agent_controller_pool.agents.len(); + let embedding_batch_size = agent_desired_state + .inference_parameters + .embedding_batch_size; + let connection_close = CancellationToken::new(); let (chunk_tx, chunk_rx) = mpsc::unbounded_channel(); let mut chunk_tasks: JoinSet<()> = JoinSet::new(); - for batch in params.chunk_by_input_size( - agent_desired_state.inference_parameters.batch_n_tokens - * CHARACTERS_PER_TOKEN_APPROXIMATELY, - ) { + let batches = match params + .into_inner() + .chunk_evenly_with_cap(agent_count, embedding_batch_size) + { + Ok(batches) => batches, + Err(ChunkEvenlyWithCapError::ZeroAgentCount) => { + return Err(ErrorServiceUnavailable("No agents are currently connected")); + } + Err(ChunkEvenlyWithCapError::ZeroMaxDocumentsPerChunk) => { + return Err(ErrorInternalServerError( + "embedding_batch_size is zero despite validation", + )); + } + }; + + for batch in batches { let buffered_request_manager_clone = app_data.buffered_request_manager.clone(); let chunk_tx_clone = chunk_tx.clone(); let connection_close_clone = connection_close.clone(); @@ -136,6 +153,7 @@ async fn respond( final_session .send_response_safe(OutgoingMessage::Response(ResponseEnvelope { + generated_by: None, request_id: final_request_id, response: OutgoingResponse::Embedding(EmbeddingResult::Done), })) @@ -144,18 +162,16 @@ async fn respond( drop(chunk_tx); - let stream = CancellationTokenStreamGuard::new( - UnboundedReceiverStream::new(chunk_rx), - connection_close, - ) - .filter_map(|transform_result| async move { - match transform_result { - TransformResult::Chunk(content) | TransformResult::Error(content) => { - Some(Ok::<_, Error>(Bytes::from(format!("{content}\n")))) - } - TransformResult::Discard => None, - } - }); + let stream = + CancellationTokenStreamGuard::new(UnboundedReceiverStream::new(chunk_rx), connection_close) + .filter_map(|transform_result| async move { + match transform_result { + TransformResult::Chunk(content) | TransformResult::Error(content) => { + Some(Ok::<_, Error>(Bytes::from(format!("{content}\n")))) + } + TransformResult::Discard => None, + } + }); Ok(HttpResponse::Ok() .insert_header(header::ContentType::json()) diff --git a/paddler/src/balancer/inference_service/mod.rs b/paddler/src/balancer/inference_service/mod.rs index 2a9ec2f7..844a004a 100644 --- a/paddler/src/balancer/inference_service/mod.rs +++ b/paddler/src/balancer/inference_service/mod.rs @@ -11,6 +11,7 @@ use anyhow::Result; use async_trait::async_trait; use tokio_util::sync::CancellationToken; +use crate::balancer::agent_controller_pool::AgentControllerPool; use crate::balancer::buffered_request_manager::BufferedRequestManager; use crate::balancer::http_route as common_http_route; use crate::balancer::inference_service::app_data::AppData; @@ -22,6 +23,7 @@ use crate::create_cors_middleware::create_cors_middleware; use crate::service::Service; pub struct InferenceService { + pub agent_controller_pool: Arc, pub balancer_applicable_state_holder: Arc, pub buffered_request_manager: Arc, pub configuration: InferenceServiceConfiguration, @@ -47,6 +49,7 @@ impl Service for InferenceService { let cors_allowed_hosts_arc = Arc::new(cors_allowed_hosts); let app_data = Data::new(AppData { + 
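// Assumed semantics of chunk_evenly_with_cap, inferred from the call site
// above: rather than sizing embedding batches by the old characters-per-token
// heuristic, inputs are spread evenly across the connected agents with
// embedding_batch_size as a hard per-chunk cap, and the two error variants
// make the degenerate inputs (no agents, zero cap) explicit instead of
// silently yielding empty batches.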
agent_controller_pool: self.agent_controller_pool.clone(), balancer_applicable_state_holder: self.balancer_applicable_state_holder.clone(), buffered_request_manager: self.buffered_request_manager.clone(), inference_service_configuration: self.configuration.clone(), diff --git a/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/mod.rs b/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/mod.rs index 1790f42f..db057885 100644 --- a/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/mod.rs +++ b/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/mod.rs @@ -233,6 +233,7 @@ impl ControlsWebSocketEndpoint for AgentSocketController { ManagementJsonRpcMessage::Response(ResponseEnvelope { request_id, response: AgentJsonRpcResponse::ChatTemplateOverride(chat_template_override), + .. }) => { context .chat_template_override_sender_collection @@ -244,6 +245,7 @@ impl ControlsWebSocketEndpoint for AgentSocketController { ManagementJsonRpcMessage::Response(ResponseEnvelope { request_id, response: AgentJsonRpcResponse::Embedding(embedding_result), + .. }) => { context .embedding_sender_collection @@ -255,6 +257,7 @@ impl ControlsWebSocketEndpoint for AgentSocketController { ManagementJsonRpcMessage::Response(ResponseEnvelope { request_id, response: AgentJsonRpcResponse::GeneratedToken(generated_token_envelope), + .. }) => { context .generate_tokens_sender_collection @@ -266,6 +269,7 @@ impl ControlsWebSocketEndpoint for AgentSocketController { ManagementJsonRpcMessage::Response(ResponseEnvelope { request_id, response: AgentJsonRpcResponse::ModelMetadata(model_metadata), + .. }) => { context .model_metadata_sender_collection diff --git a/paddler/src/balancer/mod.rs b/paddler/src/balancer/mod.rs index 79047014..f0011cbf 100644 --- a/paddler/src/balancer/mod.rs +++ b/paddler/src/balancer/mod.rs @@ -8,9 +8,10 @@ mod buffered_request_count_guard; mod buffered_request_counter; pub mod buffered_request_manager; pub mod chat_template_override_sender_collection; -mod chunk_forwarding_session_controller; +pub mod chunk_forwarding_session_controller; pub mod compatibility; mod controls_manages_senders_endpoint; +pub mod dispatch_candidate; pub mod dispatched_agent; pub mod embedding_sender_collection; pub mod generate_tokens_sender_collection; @@ -19,11 +20,11 @@ mod http_route; mod http_stream_from_agent; pub mod inference_service; pub mod management_service; -mod manages_senders; -mod manages_senders_controller; +pub mod manages_senders; +pub mod manages_senders_controller; pub mod model_metadata_sender_collection; pub mod reconciliation_service; -mod request_from_agent; +pub mod request_from_agent; #[cfg(feature = "web_admin_panel")] mod response; pub mod state_database; diff --git a/paddler/src/balancer/reconciliation_service.rs b/paddler/src/balancer/reconciliation_service.rs index e6fab0a5..11f95b03 100644 --- a/paddler/src/balancer/reconciliation_service.rs +++ b/paddler/src/balancer/reconciliation_service.rs @@ -26,15 +26,13 @@ pub struct ReconciliationService { impl ReconciliationService { pub async fn convert_to_applicable_state(&mut self) -> Result<()> { - if let Some(balancer_applicable_state) = - self.balancer_desired_state.to_applicable_state(()).await? 
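// ConvertsToApplicableState::to_applicable_state now returns
// Result<Self::ApplicableState> instead of Result<Option<...>>, so the
// reconciliation code below drops its `if let Some(...)` guard: a balancer
// desired state always converts, and the signature now says so.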
- { - self.agent_controller_pool - .set_desired_state(balancer_applicable_state.agent_desired_state.clone()) - .await?; - self.balancer_applicable_state_holder - .set_balancer_applicable_state(Some(balancer_applicable_state)); - } + let balancer_applicable_state = self.balancer_desired_state.to_applicable_state(()).await?; + + self.agent_controller_pool + .set_desired_state(balancer_applicable_state.agent_desired_state.clone()) + .await?; + self.balancer_applicable_state_holder + .set_balancer_applicable_state(Some(balancer_applicable_state)); self.is_converted_to_applicable_state = true; diff --git a/paddler/src/balancer/request_from_agent.rs b/paddler/src/balancer/request_from_agent.rs index af1e8971..d66bad35 100644 --- a/paddler/src/balancer/request_from_agent.rs +++ b/paddler/src/balancer/request_from_agent.rs @@ -87,7 +87,7 @@ where } } -async fn forward_responses_stream( +pub async fn forward_responses_stream( agent_controller: Arc, connection_close: CancellationToken, inference_service_configuration: InferenceServiceConfiguration, @@ -154,22 +154,35 @@ where break; } response = receive_response_controller.response_rx.recv() => { - match response { - Some(response) => { - let is_done = response.is_done(); - - let send_succeeded = send_response_to_client( - agent_controller.clone(), - response, - request_id.clone(), - &mut session_controller, - ).await; - - if is_done || !send_succeeded { - break; - } + if let Some(response) = response { + let is_done = response.is_done(); + + let send_succeeded = send_response_to_client( + agent_controller.clone(), + response, + request_id.clone(), + &mut session_controller, + ).await; + + if is_done || !send_succeeded { + break; } - None => break, + } else { + error!( + "Response channel closed before terminator for request {request_id:?}" + ); + + respond_with_error( + JsonRpcError { + code: 502, + description: + "Response channel closed before terminator".to_owned(), + }, + request_id, + &mut session_controller, + ).await; + + break; } } } @@ -178,7 +191,7 @@ where Ok(()) } -async fn respond_with_error( +pub async fn respond_with_error( error: JsonRpcError, request_id: String, session_controller: &mut TControlsSession, @@ -208,6 +221,7 @@ where { if let Err(err) = session_controller .send_response(OutgoingMessage::Response(ResponseEnvelope { + generated_by: agent_controller.name.clone(), request_id: request_id.clone(), response: response.into(), })) diff --git a/paddler/src/balancer_applicable_state.rs b/paddler/src/balancer_applicable_state.rs index 91961b44..eafc71ac 100644 --- a/paddler/src/balancer_applicable_state.rs +++ b/paddler/src/balancer_applicable_state.rs @@ -1,4 +1,4 @@ -use crate::agent_desired_state::AgentDesiredState; +use paddler_types::agent_desired_state::AgentDesiredState; #[derive(Clone, Debug)] pub struct BalancerApplicableState { diff --git a/paddler/src/balancer_applicable_state_holder.rs b/paddler/src/balancer_applicable_state_holder.rs index b5a6407a..94902204 100644 --- a/paddler/src/balancer_applicable_state_holder.rs +++ b/paddler/src/balancer_applicable_state_holder.rs @@ -1,8 +1,8 @@ use std::sync::RwLock; +use paddler_types::agent_desired_state::AgentDesiredState; use tokio::sync::watch; -use crate::agent_desired_state::AgentDesiredState; use crate::balancer_applicable_state::BalancerApplicableState; use crate::subscribes_to_updates::SubscribesToUpdates; diff --git a/paddler/src/balancer_desired_state.rs b/paddler/src/balancer_desired_state.rs index ad810933..d714027e 100644 --- 
a/paddler/src/balancer_desired_state.rs +++ b/paddler/src/balancer_desired_state.rs @@ -10,12 +10,9 @@ impl ConvertsToApplicableState for BalancerDesiredState { type ApplicableState = BalancerApplicableState; type Context = (); - async fn to_applicable_state( - &self, - _context: Self::Context, - ) -> Result> { - Ok(Some(BalancerApplicableState { + async fn to_applicable_state(&self, _context: Self::Context) -> Result { + Ok(BalancerApplicableState { agent_desired_state: self.to_agent_desired_state(), - })) + }) } } diff --git a/paddler/src/cancellation_token_stream_guard.rs b/paddler/src/cancellation_token_stream_guard.rs index 71eb035f..dc836e1b 100644 --- a/paddler/src/cancellation_token_stream_guard.rs +++ b/paddler/src/cancellation_token_stream_guard.rs @@ -25,10 +25,7 @@ where { type Item = TStream::Item; - fn poll_next( - mut self: Pin<&mut Self>, - context: &mut Context<'_>, - ) -> Poll> { + fn poll_next(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll> { Pin::new(&mut self.stream).poll_next(context) } } diff --git a/paddler/src/chat_template_renderer.rs b/paddler/src/chat_template_renderer/mod.rs similarity index 72% rename from paddler/src/chat_template_renderer.rs rename to paddler/src/chat_template_renderer/mod.rs index 4f111c78..94f5445e 100644 --- a/paddler/src/chat_template_renderer.rs +++ b/paddler/src/chat_template_renderer/mod.rs @@ -1,21 +1,16 @@ +pub mod pyjinja_tojson; +pub mod raise_exception; + use anyhow::Result; use minijinja::Environment; -use minijinja::Error; -use minijinja::ErrorKind; use minijinja_contrib::pycompat::unknown_method_callback; use paddler_types::chat_template::ChatTemplate; use serde::ser::Serialize; -const CHAT_TEMPLATE_NAME: &str = "chat_template"; +use self::pyjinja_tojson::pyjinja_tojson; +use self::raise_exception::raise_exception; -// Known uses: -// https://huggingface.co/bartowski/Mistral-7B-Instruct-v0.3-GGUF -fn minijinja_raise_exception(message: &str) -> std::result::Result { - Err(Error::new::( - ErrorKind::InvalidOperation, - format!("Model's chat template raised an exception: '{message}'"), - )) -} +const CHAT_TEMPLATE_NAME: &str = "chat_template"; pub struct ChatTemplateRenderer { minijinja_env: Environment<'static>, @@ -25,11 +20,12 @@ impl ChatTemplateRenderer { pub fn new(ChatTemplate { content }: ChatTemplate) -> Result { let mut minijinja_env = Environment::new(); - minijinja_env.add_function("raise_exception", minijinja_raise_exception); + minijinja_env.add_function("raise_exception", raise_exception); minijinja_env.add_template_owned(CHAT_TEMPLATE_NAME, content)?; minijinja_env.set_unknown_method_callback(unknown_method_callback); minijinja_contrib::add_to_environment(&mut minijinja_env); + minijinja_env.add_filter("tojson", pyjinja_tojson); Ok(Self { minijinja_env }) } @@ -127,4 +123,40 @@ mod tests { Ok(()) } + + #[test] + fn registers_pyjinja_tojson_filter() -> Result<()> { + let template = ChatTemplate { + content: "{{ value | tojson(ensure_ascii=False) }}".to_owned(), + }; + let renderer = ChatTemplateRenderer::new(template)?; + + let result = renderer.render(context! { value => "café" })?; + + assert_eq!(result, "\"café\""); + + Ok(()) + } + + #[test] + fn registers_raise_exception_function() -> Result<()> { + let template = ChatTemplate { + content: "{{ raise_exception('boom') }}".to_owned(), + }; + let template_renderer = ChatTemplateRenderer::new(template)?; + + let err = template_renderer + .render(context! 
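// chat_template_renderer is now a module directory: raise_exception and the
// Python-compatible tojson filter each live in their own file, and the two
// tests around this point pin the Environment wiring --
// add_filter("tojson", pyjinja_tojson) and
// add_function("raise_exception", raise_exception).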
{}) + .err() + .ok_or_else(|| anyhow::anyhow!("expected Err, got Ok"))?; + let error_message = err.to_string(); + + if !error_message.contains("boom") { + return Err(anyhow::anyhow!( + "raise_exception must surface its message; got: {error_message}" + )); + } + + Ok(()) + } } diff --git a/paddler/src/chat_template_renderer/pyjinja_tojson.rs b/paddler/src/chat_template_renderer/pyjinja_tojson.rs new file mode 100644 index 00000000..cfcc8db6 --- /dev/null +++ b/paddler/src/chat_template_renderer/pyjinja_tojson.rs @@ -0,0 +1,210 @@ +use minijinja::Error; +use minijinja::ErrorKind; +use minijinja::Value; +use minijinja::filters::tojson; +use minijinja::value::Kwargs; + +#[expect( + clippy::needless_pass_by_value, + reason = "minijinja's Filter trait requires Kwargs by value; taking &Kwargs makes the \ + function unregisterable as a filter" +)] +pub fn pyjinja_tojson(value: &Value, kwargs: Kwargs) -> Result { + let indent: Option = kwargs.get("indent")?; + + let ensure_ascii: Option = kwargs.get("ensure_ascii")?; + if matches!(ensure_ascii, Some(true)) { + return Err(Error::new( + ErrorKind::InvalidOperation, + "tojson(ensure_ascii=True) is not supported by minijinja: object output already \ + emits non-ASCII characters unescaped (matching ensure_ascii=False). Drop the \ + kwarg or set it to False.", + )); + } + + let sort_keys: Option = kwargs.get("sort_keys")?; + if matches!(sort_keys, Some(true)) { + return Err(Error::new( + ErrorKind::InvalidOperation, + "tojson(sort_keys=True) is not supported by minijinja: object key ordering follows \ + insertion order. Drop the kwarg or set it to False.", + )); + } + + let separators: Option = kwargs.get("separators")?; + if separators.is_some() { + return Err(Error::new( + ErrorKind::InvalidOperation, + "tojson(separators=...) is not supported by minijinja: separator strings are fixed.", + )); + } + + kwargs.assert_all_used()?; + + let forwarded_kwargs: Kwargs = Kwargs::from_iter(Vec::<(String, Value)>::new()); + + tojson(value, indent, forwarded_kwargs) +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use anyhow::anyhow; + use minijinja::Environment; + use minijinja::context; + + use super::pyjinja_tojson; + + fn render(template_source: &str, scope: minijinja::Value) -> Result { + let mut env = Environment::new(); + env.add_filter("tojson", pyjinja_tojson); + env.add_template_owned("t", template_source.to_owned())?; + Ok(env.get_template("t")?.render(scope)?) + } + + fn render_expecting_error( + template_source: &str, + scope: minijinja::Value, + ) -> Result { + let mut env = Environment::new(); + env.add_filter("tojson", pyjinja_tojson); + env.add_template_owned("t", template_source.to_owned())?; + let outcome = env.get_template("t")?.render(scope); + + outcome.err().ok_or_else(|| anyhow!("expected Err, got Ok")) + } + + #[test] + fn no_kwargs_emits_quoted_json_string() -> Result<()> { + let result = render("{{ value | tojson }}", context! { value => "hello" })?; + + assert_eq!(result, "\"hello\""); + + Ok(()) + } + + #[test] + fn ensure_ascii_false_matches_default_output() -> Result<()> { + let with_kwarg = render( + "{{ value | tojson(ensure_ascii=False) }}", + context! { value => "café" }, + )?; + let without_kwarg = render("{{ value | tojson }}", context! 
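// Kwarg policy of pyjinja_tojson, summarized from the branches above: only
// `indent` is forwarded to minijinja's built-in tojson; kwargs whose Python
// semantics minijinja cannot honour fail loudly instead of being ignored.
//   ensure_ascii=False / sort_keys=False -> no-op, same output as plain tojson
//   ensure_ascii=True / sort_keys=True / separators=... -> InvalidOperation
//   indent=N -> pretty-printed output
// kwargs.assert_all_used() is what turns any unknown kwarg into an error.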
{ value => "café" })?; + + assert_eq!(with_kwarg, without_kwarg); + assert_eq!(with_kwarg, "\"café\""); + + Ok(()) + } + + #[test] + fn ensure_ascii_true_returns_error_naming_the_kwarg() -> Result<()> { + let err = render_expecting_error( + "{{ value | tojson(ensure_ascii=True) }}", + context! { value => "x" }, + )?; + let rendered = err.to_string(); + + if !rendered.contains("ensure_ascii=True") { + return Err(anyhow!( + "error must name the rejected kwarg; got: {rendered}" + )); + } + + Ok(()) + } + + #[test] + fn sort_keys_false_matches_default_output() -> Result<()> { + let with_kwarg = render( + "{{ value | tojson(sort_keys=False) }}", + context! { value => "x" }, + )?; + + assert_eq!(with_kwarg, "\"x\""); + + Ok(()) + } + + #[test] + fn sort_keys_true_returns_error_naming_the_kwarg() -> Result<()> { + let err = render_expecting_error( + "{{ value | tojson(sort_keys=True) }}", + context! { value => "x" }, + )?; + let rendered = err.to_string(); + + if !rendered.contains("sort_keys=True") { + return Err(anyhow!( + "error must name the rejected kwarg; got: {rendered}" + )); + } + + Ok(()) + } + + #[test] + fn separators_returns_error_naming_the_kwarg() -> Result<()> { + let err = render_expecting_error( + "{{ value | tojson(separators=[',', ':']) }}", + context! { value => "x" }, + )?; + let rendered = err.to_string(); + + if !rendered.contains("separators") { + return Err(anyhow!( + "error must name the rejected kwarg; got: {rendered}" + )); + } + + Ok(()) + } + + #[test] + fn indent_kwarg_emits_pretty_printed_json() -> Result<()> { + let result = render( + "{{ value | tojson(indent=2) }}", + context! { value => context! { k => "v" } }, + )?; + + assert_eq!(result, "{\n \"k\": \"v\"\n}"); + + Ok(()) + } + + #[test] + fn indent_kwarg_combines_with_ensure_ascii_false() -> Result<()> { + let result = render( + "{{ value | tojson(ensure_ascii=False, indent=2) }}", + context! { value => context! { k => "café" } }, + )?; + + assert_eq!(result, "{\n \"k\": \"café\"\n}"); + + Ok(()) + } + + #[test] + fn unknown_kwarg_returns_error() -> Result<()> { + let err = + render_expecting_error("{{ value | tojson(bogus=42) }}", context! { value => "x" })?; + let rendered = err.to_string(); + + if !rendered.contains("bogus") { + return Err(anyhow!( + "error must name the unknown kwarg; got: {rendered}" + )); + } + + Ok(()) + } + + #[test] + fn non_ascii_codepoints_emitted_unescaped() -> Result<()> { + let result = render("{{ value | tojson }}", context! { value => "日本語" })?; + + assert_eq!(result, "\"日本語\""); + + Ok(()) + } +} diff --git a/paddler/src/chat_template_renderer/raise_exception.rs b/paddler/src/chat_template_renderer/raise_exception.rs new file mode 100644 index 00000000..30a4efe4 --- /dev/null +++ b/paddler/src/chat_template_renderer/raise_exception.rs @@ -0,0 +1,46 @@ +use minijinja::Error; +use minijinja::ErrorKind; + +// Surfaces errors raised explicitly inside a chat template. 
Known uses: +// https://huggingface.co/bartowski/Mistral-7B-Instruct-v0.3-GGUF +pub fn raise_exception(message: &str) -> Result { + Err(Error::new::( + ErrorKind::InvalidOperation, + format!("Model's chat template raised an exception: '{message}'"), + )) +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use anyhow::anyhow; + + use super::raise_exception; + + #[test] + fn returns_err_with_supplied_message_quoted() -> Result<()> { + let err = raise_exception("template is invalid") + .err() + .ok_or_else(|| anyhow!("expected Err, got Ok"))?; + let rendered = err.to_string(); + + if !rendered.contains("template is invalid") { + return Err(anyhow!( + "error must include the supplied message; got: {rendered}" + )); + } + + Ok(()) + } + + #[test] + fn returns_err_with_invalid_operation_kind() -> Result<()> { + let err = raise_exception("anything") + .err() + .ok_or_else(|| anyhow!("expected Err, got Ok"))?; + + assert_eq!(err.kind(), minijinja::ErrorKind::InvalidOperation); + + Ok(()) + } +} diff --git a/paddler/src/controls_websocket_endpoint.rs b/paddler/src/controls_websocket_endpoint.rs index 8a8a884f..8185053d 100644 --- a/paddler/src/controls_websocket_endpoint.rs +++ b/paddler/src/controls_websocket_endpoint.rs @@ -16,6 +16,7 @@ use async_trait::async_trait; use futures_util::StreamExt as _; use log::debug; use log::error; +use log::warn; use paddler_types::rpc_message::RpcMessage; use serde::de::DeserializeOwned; use tokio::time::Duration; @@ -231,14 +232,22 @@ pub trait ControlsWebSocketEndpoint: Send + Sync + 'static { Ok(ContinuationDecision::Stop(stop_parameters)) => { close_reason = stop_parameters.close_reason; - let _ = session.close(close_reason).await; + if let Err(close_err) = session.close(close_reason).await { + warn!( + "WebSocket session close failed after Stop decision (peer likely already disconnected): {close_err:?}" + ); + } return; } Err(err) => { error!("Error in connection start handler: {err:?}"); - let _ = session.close(close_reason).await; + if let Err(close_err) = session.close(close_reason).await { + warn!( + "WebSocket session close failed after start-handler error (peer likely already disconnected): {close_err:?}" + ); + } return; } @@ -291,7 +300,11 @@ pub trait ControlsWebSocketEndpoint: Send + Sync + 'static { connection_close.cancel(); - let _ = session.close(close_reason).await; + if let Err(close_err) = session.close(close_reason).await { + warn!( + "WebSocket session close failed at end of message loop (peer likely already disconnected): {close_err:?}" + ); + } }); Ok(res) diff --git a/paddler/src/converts_to_applicable_state.rs b/paddler/src/converts_to_applicable_state.rs index 210176bf..4e8a065d 100644 --- a/paddler/src/converts_to_applicable_state.rs +++ b/paddler/src/converts_to_applicable_state.rs @@ -6,8 +6,5 @@ pub trait ConvertsToApplicableState { type ApplicableState; type Context; - async fn to_applicable_state( - &self, - context: Self::Context, - ) -> Result>; + async fn to_applicable_state(&self, context: Self::Context) -> Result; } diff --git a/paddler/src/decoded_image.rs b/paddler/src/decoded_image.rs index 8635a782..1d1d5b74 100644 --- a/paddler/src/decoded_image.rs +++ b/paddler/src/decoded_image.rs @@ -2,6 +2,7 @@ use std::io::Cursor; use base64::Engine as _; use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; +use image::DynamicImage; use image::ImageFormat; use image::imageops::FilterType; use log::info; @@ -38,7 +39,10 @@ fn compute_target_dimension(svg_dim: f64, scale: f64) -> Result Result, DecodedImageError> 
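// SVG rasterization now yields a DynamicImage (hence the rename to
// rasterize_svg_to_dynamic_image) instead of PNG-encoded bytes, so the SVG
// path can share the same resize-and-encode pipeline as every other format;
// encoding happens once, at the end of prepared_for_inference, rather than
// inside each branch.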
{ +fn rasterize_svg_to_dynamic_image( + data: &[u8], + max_dimension: u32, +) -> Result { let svg_tree = SvgTree::from_data(data, &Options::default()).map_err(|err| { DecodedImageError::ConversionFailed { message: format!("Failed to parse SVG: {err}"), @@ -73,22 +77,60 @@ fn rasterize_svg_to_png(data: &[u8], max_dimension: u32) -> Result, Deco resvg::render(&svg_tree, transform, &mut pixmap.as_mut()); - pixmap - .encode_png() - .map_err(|err| DecodedImageError::ConversionFailed { - message: format!("Failed to encode SVG rasterization to PNG: {err}"), - }) + let rgba = image::RgbaImage::from_raw(target_width, target_height, pixmap.data().to_vec()) + .ok_or_else(|| DecodedImageError::ConversionFailed { + message: "rasterized SVG buffer did not match target dimensions".to_owned(), + })?; + + Ok(DynamicImage::ImageRgba8(rgba)) } -fn reencode_to_png(data: &[u8]) -> Result, DecodedImageError> { - let dynamic_image = +enum LoadedImageOrigin { + PassThroughEligible, + NeedsReencode, +} + +fn load_supported_image( + data: &[u8], + max_dimension: u32, +) -> Result<(DynamicImage, LoadedImageOrigin), DecodedImageError> { + if is_svg(data) { + info!("Rasterizing SVG (max_dimension: {max_dimension})"); + let image = rasterize_svg_to_dynamic_image(data, max_dimension)?; + return Ok((image, LoadedImageOrigin::NeedsReencode)); + } + + let format = image::guess_format(data).map_err(|err| DecodedImageError::ConversionFailed { + message: err.to_string(), + })?; + + let origin = match format { + ImageFormat::Png | ImageFormat::Jpeg | ImageFormat::Gif | ImageFormat::Bmp => { + LoadedImageOrigin::PassThroughEligible + } + unsupported if !unsupported.reading_enabled() => { + return Err(DecodedImageError::UnsupportedFormat { + format: format!("{unsupported:?}"), + }); + } + convertible_format => { + info!("Converting {convertible_format:?} image to PNG for llama.cpp compatibility"); + LoadedImageOrigin::NeedsReencode + } + }; + + let image = image::load_from_memory(data).map_err(|err| DecodedImageError::ConversionFailed { message: err.to_string(), })?; + Ok((image, origin)) +} + +fn encode_png(image: &DynamicImage) -> Result, DecodedImageError> { let mut output_buffer = Cursor::new(Vec::new()); - dynamic_image + image .write_to(&mut output_buffer, ImageFormat::Png) .map_err(|err| DecodedImageError::ConversionFailed { message: err.to_string(), @@ -97,50 +139,24 @@ fn reencode_to_png(data: &[u8]) -> Result, DecodedImageError> { Ok(output_buffer.into_inner()) } +fn encode_jpeg(image: &DynamicImage) -> Result, DecodedImageError> { + let mut output_buffer = Cursor::new(Vec::new()); + + image + .write_to(&mut output_buffer, ImageFormat::Jpeg) + .map_err(|err| DecodedImageError::ResizeFailed { + message: err.to_string(), + })?; + + Ok(output_buffer.into_inner()) +} + #[derive(Debug)] pub struct DecodedImage { pub data: Vec, } impl DecodedImage { - pub fn converted_to_png_if_necessary( - self, - max_dimension: u32, - ) -> Result { - if max_dimension == 0 { - return Err(DecodedImageError::InvalidMaxDimension); - } - - if is_svg(&self.data) { - info!("Converting SVG to PNG (max_dimension: {max_dimension})"); - - let png_data = rasterize_svg_to_png(&self.data, max_dimension)?; - - return Ok(Self { data: png_data }); - } - - let format = - image::guess_format(&self.data).map_err(|err| DecodedImageError::ConversionFailed { - message: err.to_string(), - })?; - - match format { - ImageFormat::Png | ImageFormat::Jpeg | ImageFormat::Gif | ImageFormat::Bmp => Ok(self), - unsupported_format if !unsupported_format.reading_enabled() 
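// LoadedImageOrigin records whether the source bytes are already in a
// llama.cpp-friendly container (PNG/JPEG/GIF/BMP) or must be re-encoded
// (SVG, TIFF, WebP, ...). prepared_for_inference then applies one rule set:
// oversized images are resized and written as JPEG, small pass-through
// eligible images are returned byte-for-byte, and everything else is encoded
// to PNG exactly once.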
=> { - Err(DecodedImageError::UnsupportedFormat { - format: format!("{unsupported_format:?}"), - }) - } - convertible_format => { - info!("Converting {convertible_format:?} image to PNG for llama.cpp compatibility"); - - let png_data = reencode_to_png(&self.data)?; - - Ok(Self { data: png_data }) - } - } - } - pub fn from_data_uri(image_url: &ImageUrl) -> Result { let url = &image_url.url; @@ -161,36 +177,30 @@ impl DecodedImage { Ok(Self { data }) } - pub fn resized_to_fit(self, max_dimension: u32) -> Result { + pub fn prepared_for_inference(self, max_dimension: u32) -> Result { if max_dimension == 0 { return Err(DecodedImageError::InvalidMaxDimension); } - let dynamic_image = - image::load_from_memory(&self.data).map_err(|err| DecodedImageError::ResizeFailed { - message: err.to_string(), - })?; + let (image, origin) = load_supported_image(&self.data, max_dimension)?; - let width = dynamic_image.width(); - let height = dynamic_image.height(); + let width = image.width(); + let height = image.height(); + let needs_resize = width > max_dimension || height > max_dimension; - if width <= max_dimension && height <= max_dimension { - return Ok(self); + if needs_resize { + let resized = image.resize(max_dimension, max_dimension, FilterType::Lanczos3); + return Ok(Self { + data: encode_jpeg(&resized)?, + }); } - let resized = dynamic_image.resize(max_dimension, max_dimension, FilterType::Lanczos3); - - let mut output_buffer = Cursor::new(Vec::new()); - - resized - .write_to(&mut output_buffer, ImageFormat::Jpeg) - .map_err(|err| DecodedImageError::ResizeFailed { - message: err.to_string(), - })?; - - Ok(Self { - data: output_buffer.into_inner(), - }) + match origin { + LoadedImageOrigin::PassThroughEligible => Ok(self), + LoadedImageOrigin::NeedsReencode => Ok(Self { + data: encode_png(&image)?, + }), + } } } @@ -334,252 +344,193 @@ mod tests { } #[test] - fn test_resized_to_fit_shrinks_oversized_image() -> Result<()> { - let original_data = create_test_jpeg(2000, 1500)?; - let decoded_image = DecodedImage { - data: original_data, - }; - - let resized = decoded_image.resized_to_fit(1024)?; - - let result_image = image::load_from_memory(&resized.data)?; - - assert!(result_image.width() <= 1024); - assert!(result_image.height() <= 1024); - - Ok(()) - } - - #[test] - fn test_resized_to_fit_preserves_aspect_ratio() -> Result<()> { - let original_data = create_test_jpeg(2000, 1000)?; - let decoded_image = DecodedImage { - data: original_data, - }; - - let resized = decoded_image.resized_to_fit(1000)?; - - let result_image = image::load_from_memory(&resized.data)?; - - assert_eq!(result_image.width(), 1000); - assert_eq!(result_image.height(), 500); - - Ok(()) - } - - #[test] - fn test_resized_to_fit_skips_small_image() -> Result<()> { - let original_data = create_test_jpeg(512, 256)?; - let original_len = original_data.len(); - let decoded_image = DecodedImage { - data: original_data, - }; - - let resized = decoded_image.resized_to_fit(1024)?; - - assert_eq!(resized.data.len(), original_len); - - Ok(()) - } - - #[test] - fn test_resized_to_fit_with_llamas_fixture() -> Result<()> { - let fixture_data = std::fs::read(concat!( - env!("CARGO_MANIFEST_DIR"), - "/../fixtures/llamas.jpg" - ))?; - - let original_image = image::load_from_memory(&fixture_data)?; - - assert_eq!(original_image.width(), 640); - assert_eq!(original_image.height(), 427); - - let decoded_image = DecodedImage { data: fixture_data }; - let resized = decoded_image.resized_to_fit(320)?; - - let result_image = 
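// Editor's aside — a minimal usage sketch, not part of the diff. It assumes
// `from_data_uri` keeps returning Result<Self, DecodedImageError>, and
// `attachment_url` is a hypothetical ImageUrl taken from a request. The new
// `prepared_for_inference` folds the old `converted_to_png_if_necessary` +
// `resized_to_fit` pair into one call: small PNG/JPEG/GIF/BMP bytes pass
// through untouched, oversized images come back resized as JPEG, and
// SVG/TIFF/WebP inputs are re-encoded to PNG.
fn prepare_attachment(attachment_url: &ImageUrl) -> Result<Vec<u8>, DecodedImageError> {
    let decoded = DecodedImage::from_data_uri(attachment_url)?;

    // One call now decides pass-through vs. resize vs. re-encode.
    let prepared = decoded.prepared_for_inference(1024)?;

    Ok(prepared.data)
}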
image::load_from_memory(&resized.data)?; - - assert_eq!(result_image.width(), 320); - assert_eq!(result_image.height(), 214); - - Ok(()) - } - - #[test] - fn test_converted_to_png_passes_through_jpeg() -> Result<()> { + fn test_prepared_passes_through_small_jpeg() -> Result<()> { let jpeg_data = create_test_jpeg(100, 100)?; let original_len = jpeg_data.len(); - let decoded_image = DecodedImage { data: jpeg_data }; - let result = decoded_image.converted_to_png_if_necessary(1024)?; + let result = decoded_image.prepared_for_inference(1024)?; assert_eq!(result.data.len(), original_len); - Ok(()) } #[test] - fn test_converted_to_png_passes_through_png() -> Result<()> { + fn test_prepared_passes_through_small_png() -> Result<()> { let png_data = create_test_png(100, 100)?; let original_len = png_data.len(); - let decoded_image = DecodedImage { data: png_data }; - let result = decoded_image.converted_to_png_if_necessary(1024)?; + let result = decoded_image.prepared_for_inference(1024)?; assert_eq!(result.data.len(), original_len); - Ok(()) } #[test] - fn test_converted_to_png_passes_through_gif() -> Result<()> { + fn test_prepared_passes_through_small_gif() -> Result<()> { let gif_data = create_test_gif(100, 100)?; let original_len = gif_data.len(); - let decoded_image = DecodedImage { data: gif_data }; - let result = decoded_image.converted_to_png_if_necessary(1024)?; + let result = decoded_image.prepared_for_inference(1024)?; assert_eq!(result.data.len(), original_len); - Ok(()) } #[test] - fn test_converted_to_png_passes_through_bmp() -> Result<()> { + fn test_prepared_passes_through_small_bmp() -> Result<()> { let bmp_data = create_test_bmp(100, 100)?; let original_len = bmp_data.len(); - let decoded_image = DecodedImage { data: bmp_data }; - let result = decoded_image.converted_to_png_if_necessary(1024)?; + let result = decoded_image.prepared_for_inference(1024)?; assert_eq!(result.data.len(), original_len); + Ok(()) + } + + #[test] + fn test_prepared_converts_small_tiff_to_png() -> Result<()> { + let tiff_data = create_test_tiff(100, 100)?; + let decoded_image = DecodedImage { data: tiff_data }; + let result = decoded_image.prepared_for_inference(1024)?; + + let result_format = image::guess_format(&result.data)?; + assert_eq!(result_format, ImageFormat::Png); Ok(()) } #[test] - fn test_converted_to_png_converts_webp_fixture() -> Result<()> { + fn test_prepared_converts_small_webp_fixture_to_png() -> Result<()> { let webp_data = load_fixture("llamas.webp")?; - let decoded_image = DecodedImage { data: webp_data }; - let result = decoded_image.converted_to_png_if_necessary(1024)?; - let result_format = image::guess_format(&result.data)?; + let result = decoded_image.prepared_for_inference(1024)?; + let result_format = image::guess_format(&result.data)?; assert_eq!(result_format, ImageFormat::Png); let result_image = image::load_from_memory(&result.data)?; - assert_eq!(result_image.width(), 640); assert_eq!(result_image.height(), 427); - Ok(()) } #[test] - fn test_converted_to_png_rasterizes_svg_fixture() -> Result<()> { - let svg_data = load_fixture("llamas.svg")?; + fn test_prepared_rasterizes_small_svg() -> Result<()> { + let svg_data = br#" + + "#; + let decoded_image = DecodedImage { + data: svg_data.to_vec(), + }; - let decoded_image = DecodedImage { data: svg_data }; - let result = decoded_image.converted_to_png_if_necessary(320)?; + let result = decoded_image.prepared_for_inference(1024)?; let result_format = image::guess_format(&result.data)?; - assert_eq!(result_format, 
ImageFormat::Png); let result_image = image::load_from_memory(&result.data)?; - - assert!(result_image.width() <= 320); - assert!(result_image.height() <= 320); - + assert_eq!(result_image.width(), 50); + assert_eq!(result_image.height(), 50); Ok(()) } #[test] - fn test_converted_to_png_rejects_zero_max_dimension() -> Result<()> { + fn test_prepared_rasterizes_svg_fixture_within_bound() -> Result<()> { let svg_data = load_fixture("llamas.svg")?; - let decoded_image = DecodedImage { data: svg_data }; - let result = decoded_image.converted_to_png_if_necessary(0); + let result = decoded_image.prepared_for_inference(320)?; + + let result_format = image::guess_format(&result.data)?; + let result_image = image::load_from_memory(&result.data)?; + + assert!(result_image.width() <= 320); + assert!(result_image.height() <= 320); assert!(matches!( - result, - Err(DecodedImageError::InvalidMaxDimension) + result_format, + ImageFormat::Png | ImageFormat::Jpeg )); - Ok(()) } #[test] - fn test_resized_to_fit_rejects_zero_max_dimension() -> Result<()> { - let original_data = create_test_jpeg(200, 100)?; - let decoded_image = DecodedImage { - data: original_data, - }; + fn test_prepared_resizes_oversized_jpeg_to_jpeg() -> Result<()> { + let jpeg_data = create_test_jpeg(2000, 1500)?; + let decoded_image = DecodedImage { data: jpeg_data }; - let result = decoded_image.resized_to_fit(0); + let result = decoded_image.prepared_for_inference(1024)?; - assert!(matches!( - result, - Err(DecodedImageError::InvalidMaxDimension) - )); + let result_format = image::guess_format(&result.data)?; + assert_eq!(result_format, ImageFormat::Jpeg); + let result_image = image::load_from_memory(&result.data)?; + assert!(result_image.width() <= 1024); + assert!(result_image.height() <= 1024); Ok(()) } #[test] - fn test_converted_to_png_converts_tiff() -> Result<()> { - let tiff_data = create_test_tiff(100, 100)?; - - let decoded_image = DecodedImage { data: tiff_data }; - let result = decoded_image.converted_to_png_if_necessary(1024)?; + fn test_prepared_preserves_aspect_ratio_on_resize() -> Result<()> { + let jpeg_data = create_test_jpeg(2000, 1000)?; + let decoded_image = DecodedImage { data: jpeg_data }; - let result_format = image::guess_format(&result.data)?; - - assert_eq!(result_format, ImageFormat::Png); + let result = decoded_image.prepared_for_inference(1000)?; + let result_image = image::load_from_memory(&result.data)?; + assert_eq!(result_image.width(), 1000); + assert_eq!(result_image.height(), 500); Ok(()) } #[test] - fn test_converted_to_png_rasterizes_small_svg() -> Result<()> { - let svg_data = br#" - - "#; - - let decoded_image = DecodedImage { - data: svg_data.to_vec(), - }; - - let result = decoded_image.converted_to_png_if_necessary(1024)?; + fn test_prepared_with_jpg_fixture_within_bound() -> Result<()> { + let fixture_data = std::fs::read(concat!( + env!("CARGO_MANIFEST_DIR"), + "/../fixtures/llamas.jpg" + ))?; - let result_format = image::guess_format(&result.data)?; + let original_image = image::load_from_memory(&fixture_data)?; + assert_eq!(original_image.width(), 640); + assert_eq!(original_image.height(), 427); - assert_eq!(result_format, ImageFormat::Png); + let decoded_image = DecodedImage { data: fixture_data }; + let result = decoded_image.prepared_for_inference(320)?; let result_image = image::load_from_memory(&result.data)?; + assert_eq!(result_image.width(), 320); + assert_eq!(result_image.height(), 214); + Ok(()) + } - assert_eq!(result_image.width(), 50); - assert_eq!(result_image.height(), 
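// Editor's aside — the dimension arithmetic these tests pin down, not part of
// the diff. `resize` fits the longer edge to `max_dimension`, keeps the aspect
// ratio, and rounds the shorter edge:
//
//   2000 x 1000, bound 1000:  scale = 1000 / 2000 = 0.5  ->  1000 x 500
//    640 x  427, bound  320:  scale =  320 /  640 = 0.5  ->  320 x round(213.5) = 320 x 214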
50);
+    }
+
+    #[test]
+    fn test_prepared_rejects_zero_max_dimension() -> Result<()> {
+        let png_data = create_test_png(50, 50)?;
+        let decoded_image = DecodedImage { data: png_data };
+        let result = decoded_image.prepared_for_inference(0);
+
+        assert!(matches!(
+            result,
+            Err(DecodedImageError::InvalidMaxDimension)
+        ));
         Ok(())
     }
 
     #[test]
-    fn test_converted_to_png_rejects_zero_dimension_svg() {
+    fn test_prepared_rejects_zero_dimension_svg() {
         let svg_data = br#"<svg xmlns="http://www.w3.org/2000/svg" width="0" height="0"></svg>"#;
         let decoded_image = DecodedImage {
             data: svg_data.to_vec(),
         };
 
-        let result = decoded_image.converted_to_png_if_necessary(1024);
+        let result = decoded_image.prepared_for_inference(1024);
 
         assert!(matches!(
             result,
diff --git a/paddler/src/desired_model_resolution.rs b/paddler/src/desired_model_resolution.rs
new file mode 100644
index 00000000..a825b19d
--- /dev/null
+++ b/paddler/src/desired_model_resolution.rs
@@ -0,0 +1,7 @@
+use std::path::PathBuf;
+
+pub enum DesiredModelResolution {
+    NotConfigured,
+    Resolved(PathBuf),
+    LocalFileMissing(PathBuf),
+}
diff --git a/paddler/src/download_huggingface_model.rs b/paddler/src/download_huggingface_model.rs
new file mode 100644
index 00000000..787d8d40
--- /dev/null
+++ b/paddler/src/download_huggingface_model.rs
@@ -0,0 +1,125 @@
+use std::path::PathBuf;
+use std::sync::Arc;
+
+use anyhow::Result;
+use anyhow::anyhow;
+use hf_hub::Cache;
+use hf_hub::Repo;
+use hf_hub::RepoType;
+use hf_hub::api::tokio::ApiBuilder;
+use hf_hub::api::tokio::ApiError;
+use log::warn;
+use paddler_types::agent_issue::AgentIssue;
+use paddler_types::agent_issue_params::HuggingFaceDownloadLock;
+use paddler_types::agent_issue_params::ModelPath;
+use paddler_types::huggingface_model_reference::HuggingFaceModelReference;
+use tokio::time::Duration;
+use tokio::time::sleep;
+
+use crate::agent_issue_fix::AgentIssueFix;
+use crate::slot_aggregated_status::SlotAggregatedStatus;
+use crate::slot_aggregated_status_download_progress::SlotAggregatedStatusDownloadProgress;
+
+const LOCK_RETRY_TIMEOUT: Duration = Duration::from_secs(10);
+
+pub async fn download_huggingface_model(
+    reference: &HuggingFaceModelReference,
+    slot_aggregated_status: Arc<SlotAggregatedStatus>,
+) -> Result<PathBuf> {
+    let HuggingFaceModelReference {
+        filename,
+        repo_id,
+        revision,
+    } = reference;
+    let model_path = format!("{repo_id}/{revision}/{filename}");
+
+    if slot_aggregated_status.has_issue(&AgentIssue::HuggingFaceModelDoesNotExist(ModelPath {
+        model_path: model_path.clone(),
+    })) {
+        return Err(anyhow!(
+            "Model '{model_path}' does not exist on Hugging Face. Not attempting to download it again."
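// Editor's aside — a sketch, not part of the diff, of how a caller can consume
// the three-way resolution instead of collapsing "not configured" and
// "configured but missing on disk" into one error. `load_model`,
// `report_missing_model`, and `clear_loaded_model` are hypothetical
// caller-side helpers, not names this diff introduces.
fn apply_resolution(resolution: DesiredModelResolution) {
    match resolution {
        DesiredModelResolution::Resolved(path) => load_model(&path),
        DesiredModelResolution::LocalFileMissing(path) => report_missing_model(&path),
        DesiredModelResolution::NotConfigured => clear_loaded_model(),
    }
}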
+ )); + } + + let hf_cache = Cache::from_env(); + let hf_api = ApiBuilder::from_cache(hf_cache.clone()).build()?; + let hf_repo = hf_api.repo(Repo::with_revision( + repo_id.to_owned(), + RepoType::Model, + revision.to_owned(), + )); + + if let Some(cached_path) = hf_cache + .repo(Repo::new(repo_id.to_owned(), RepoType::Model)) + .get(filename) + { + slot_aggregated_status.reset_download(); + + return Ok(cached_path); + } + + match hf_repo + .download_with_progress( + filename, + SlotAggregatedStatusDownloadProgress::new(slot_aggregated_status.clone()), + ) + .await + { + Ok(resolved_filename) => { + slot_aggregated_status.register_fix(&AgentIssueFix::HuggingFaceDownloadedModel( + ModelPath { model_path }, + )); + + Ok(resolved_filename) + } + Err(ApiError::LockAcquisition(lock_path)) => { + slot_aggregated_status.register_issue(AgentIssue::HuggingFaceCannotAcquireLock( + HuggingFaceDownloadLock { + lock_path: lock_path.display().to_string(), + model_path: ModelPath { model_path }, + }, + )); + + warn!( + "Waiting to acquire download lock for '{}'. Sleeping for {} secs", + lock_path.display(), + LOCK_RETRY_TIMEOUT.as_secs() + ); + + sleep(LOCK_RETRY_TIMEOUT).await; + + Err(anyhow!( + "Failed to acquire download lock '{}'. Is more than one agent running on this machine?", + lock_path.display() + )) + } + Err(ApiError::RequestError(reqwest_error)) => match reqwest_error.status() { + Some(reqwest::StatusCode::NOT_FOUND) => { + slot_aggregated_status.register_issue(AgentIssue::HuggingFaceModelDoesNotExist( + ModelPath { + model_path: model_path.clone(), + }, + )); + + Err(anyhow!( + "Model '{model_path}' does not exist on Hugging Face." + )) + } + Some(reqwest::StatusCode::FORBIDDEN | reqwest::StatusCode::UNAUTHORIZED) => { + slot_aggregated_status.register_issue(AgentIssue::HuggingFacePermissions( + ModelPath { + model_path: model_path.clone(), + }, + )); + + Err(anyhow!( + "You do not have enough permissions to download '{model_path}' from Hugging Face." 
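// Editor's aside — a sketch, not part of the diff. On ApiError::LockAcquisition
// the function above sleeps for LOCK_RETRY_TIMEOUT and still returns Err, so
// retrying is deliberately left to the caller. A caller wanting bounded retries
// could wrap it like this; MAX_DOWNLOAD_ATTEMPTS is a hypothetical knob.
async fn download_with_retries(
    reference: &HuggingFaceModelReference,
    slot_aggregated_status: Arc<SlotAggregatedStatus>,
) -> Result<PathBuf> {
    const MAX_DOWNLOAD_ATTEMPTS: usize = 3;

    let mut last_error = anyhow!("no download attempt was made");

    for _ in 0..MAX_DOWNLOAD_ATTEMPTS {
        match download_huggingface_model(reference, slot_aggregated_status.clone()).await {
            Ok(path) => return Ok(path),
            Err(err) => last_error = err,
        }
    }

    Err(last_error)
}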
+ )) + } + _ => Err(anyhow!( + "Failed to download model from Hugging Face: {reqwest_error}" + )), + }, + Err(err_other) => Err(err_other.into()), + } +} diff --git a/paddler/src/lib.rs b/paddler/src/lib.rs index 5dd9c630..88a18d87 100644 --- a/paddler/src/lib.rs +++ b/paddler/src/lib.rs @@ -1,9 +1,6 @@ -pub use llama_cpp_bindings; - pub mod agent; pub mod agent_applicable_state; pub mod agent_applicable_state_holder; -pub mod agent_desired_model; pub mod agent_desired_state; pub mod agent_issue_fix; pub mod atomic_value; @@ -23,9 +20,12 @@ pub mod converts_to_llama_pooling_type; pub mod create_cors_middleware; pub mod decoded_image; pub mod decoded_image_error; +pub mod desired_model_resolution; pub mod dispenses_slots; +pub mod download_huggingface_model; pub mod embedding_input_tokenized; pub mod produces_snapshot; +pub mod resolve_desired_model; pub mod resolved_socket_addr; pub mod sends_rpc_message; pub mod service; @@ -38,4 +38,10 @@ pub mod snapshots_stream; #[cfg(feature = "web_admin_panel")] pub mod static_files; pub mod subscribes_to_updates; +pub mod tool_call_buffer; +pub mod tool_call_event; +pub mod tool_call_pipeline; +pub mod tool_call_pipeline_error; +pub mod tool_call_validation_error; +pub mod tool_call_validator; pub mod websocket_session_controller; diff --git a/paddler/src/resolve_desired_model.rs b/paddler/src/resolve_desired_model.rs new file mode 100644 index 00000000..25683ea8 --- /dev/null +++ b/paddler/src/resolve_desired_model.rs @@ -0,0 +1,103 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use anyhow::Result; +use paddler_types::agent_desired_model::AgentDesiredModel; + +use crate::desired_model_resolution::DesiredModelResolution; +use crate::download_huggingface_model::download_huggingface_model; +use crate::slot_aggregated_status::SlotAggregatedStatus; + +pub async fn resolve_desired_model( + desired: &AgentDesiredModel, + slot_aggregated_status: Arc, +) -> Result { + match desired { + AgentDesiredModel::HuggingFace(reference) => { + let path = download_huggingface_model(reference, slot_aggregated_status).await?; + + Ok(DesiredModelResolution::Resolved(path)) + } + AgentDesiredModel::LocalToAgent(path) => { + let local_path = PathBuf::from(path); + + if tokio::fs::try_exists(&local_path).await? 
{ + Ok(DesiredModelResolution::Resolved(local_path)) + } else { + Ok(DesiredModelResolution::LocalFileMissing(local_path)) + } + } + AgentDesiredModel::None => Ok(DesiredModelResolution::NotConfigured), + } +} + +#[cfg(test)] +mod tests { + use std::path::PathBuf; + use std::sync::Arc; + + use anyhow::Result; + use paddler_types::agent_desired_model::AgentDesiredModel; + use tempfile::NamedTempFile; + use tempfile::TempDir; + + use crate::desired_model_resolution::DesiredModelResolution; + use crate::resolve_desired_model::resolve_desired_model; + use crate::slot_aggregated_status::SlotAggregatedStatus; + + fn fresh_status() -> Arc { + Arc::new(SlotAggregatedStatus::new(1)) + } + + fn nonexistent_path_in_temp_dir(label: &str) -> Result<(TempDir, PathBuf)> { + let dir = tempfile::tempdir()?; + let path = dir.path().join(format!("missing-{label}.gguf")); + + Ok((dir, path)) + } + + #[tokio::test] + async fn local_existing_file_resolves_to_resolved_with_that_path() -> Result<()> { + let status = fresh_status(); + let temp_file = NamedTempFile::new()?; + let path = temp_file.path().to_path_buf(); + let desired = AgentDesiredModel::LocalToAgent(path.display().to_string()); + + let resolution = resolve_desired_model(&desired, status).await?; + + assert!(matches!( + resolution, + DesiredModelResolution::Resolved(ref resolved) if *resolved == path + )); + + Ok(()) + } + + #[tokio::test] + async fn local_missing_file_resolves_to_local_file_missing_with_that_path() -> Result<()> { + let status = fresh_status(); + let (_dir_guard, path) = nonexistent_path_in_temp_dir("desired")?; + let desired = AgentDesiredModel::LocalToAgent(path.display().to_string()); + + let resolution = resolve_desired_model(&desired, status).await?; + + assert!(matches!( + resolution, + DesiredModelResolution::LocalFileMissing(ref missing) if *missing == path + )); + + Ok(()) + } + + #[tokio::test] + async fn none_variant_resolves_to_not_configured() -> Result<()> { + let status = fresh_status(); + let desired = AgentDesiredModel::None; + + let resolution = resolve_desired_model(&desired, status).await?; + + assert!(matches!(resolution, DesiredModelResolution::NotConfigured)); + + Ok(()) + } +} diff --git a/paddler/src/sets_desired_state.rs b/paddler/src/sets_desired_state.rs index 9e78aa5d..8d182a39 100644 --- a/paddler/src/sets_desired_state.rs +++ b/paddler/src/sets_desired_state.rs @@ -1,7 +1,6 @@ use anyhow::Result; use async_trait::async_trait; - -use crate::agent_desired_state::AgentDesiredState; +use paddler_types::agent_desired_state::AgentDesiredState; #[async_trait] pub trait SetsDesiredState { diff --git a/paddler/src/tool_call_buffer.rs b/paddler/src/tool_call_buffer.rs new file mode 100644 index 00000000..cb214933 --- /dev/null +++ b/paddler/src/tool_call_buffer.rs @@ -0,0 +1,116 @@ +#[derive(Debug, Default)] +pub struct ToolCallBuffer { + accumulated: String, +} + +impl ToolCallBuffer { + #[must_use] + pub const fn new() -> Self { + Self { + accumulated: String::new(), + } + } + + pub fn append(&mut self, fragment: &str) { + self.accumulated.push_str(fragment); + } + + #[must_use] + pub fn as_str(&self) -> &str { + &self.accumulated + } + + pub fn clear(&mut self) { + self.accumulated.clear(); + } + + #[must_use] + pub fn take(&mut self) -> String { + std::mem::take(&mut self.accumulated) + } + + #[must_use] + pub const fn is_empty(&self) -> bool { + self.accumulated.is_empty() + } + + #[must_use] + pub const fn len(&self) -> usize { + self.accumulated.len() + } +} + +#[cfg(test)] +mod tests { + use 
super::ToolCallBuffer;
+
+    #[test]
+    fn new_is_empty() {
+        let buffer = ToolCallBuffer::new();
+
+        assert!(buffer.is_empty());
+        assert_eq!(buffer.len(), 0);
+        assert_eq!(buffer.as_str(), "");
+    }
+
+    #[test]
+    fn append_extends_buffer() {
+        let mut buffer = ToolCallBuffer::new();
+        buffer.append("hello");
+
+        assert_eq!(buffer.as_str(), "hello");
+        assert_eq!(buffer.len(), 5);
+        assert!(!buffer.is_empty());
+    }
+
+    #[test]
+    fn multiple_appends_concatenate() {
+        let mut buffer = ToolCallBuffer::new();
+        buffer.append("<tool_call>\n");
+        buffer.append("{\"name\":\"x\"}");
+        buffer.append("\n</tool_call>");
+
+        assert_eq!(
+            buffer.as_str(),
+            "<tool_call>\n{\"name\":\"x\"}\n</tool_call>"
+        );
+    }
+
+    #[test]
+    fn clear_resets_to_empty() {
+        let mut buffer = ToolCallBuffer::new();
+        buffer.append("data");
+        buffer.clear();
+
+        assert!(buffer.is_empty());
+        assert_eq!(buffer.len(), 0);
+    }
+
+    #[test]
+    fn take_returns_contents_and_clears() {
+        let mut buffer = ToolCallBuffer::new();
+        buffer.append("hello");
+        let taken = buffer.take();
+
+        assert_eq!(taken, "hello");
+        assert!(buffer.is_empty());
+    }
+
+    #[test]
+    fn take_on_empty_returns_empty_string() {
+        let mut buffer = ToolCallBuffer::new();
+        let taken = buffer.take();
+
+        assert!(taken.is_empty());
+        assert!(buffer.is_empty());
+    }
+
+    #[test]
+    fn append_handles_unicode() {
+        let mut buffer = ToolCallBuffer::new();
+        buffer.append("héllo");
+
+        assert_eq!(buffer.as_str(), "héllo");
+        assert_eq!(buffer.len(), "héllo".len());
+    }
+}
diff --git a/paddler/src/tool_call_event.rs b/paddler/src/tool_call_event.rs
new file mode 100644
index 00000000..54a692c1
--- /dev/null
+++ b/paddler/src/tool_call_event.rs
@@ -0,0 +1,185 @@
+use llama_cpp_bindings::ParsedToolCall;
+use paddler_types::generated_token_result::GeneratedTokenResult;
+use paddler_types::raw_tool_call_tokens::RawToolCallTokens;
+
+use crate::tool_call_pipeline_error::ToolCallPipelineError;
+use crate::tool_call_validation_error::ToolCallValidationError;
+
+#[derive(Debug)]
+pub enum ToolCallEvent {
+    Pending,
+    Resolved(Vec<ParsedToolCall>),
+    ParseFailed(ToolCallPipelineError),
+    ValidationFailed(Vec<ToolCallValidationError>),
+    UnrecognizedFormat(RawToolCallTokens),
+}
+
+impl ToolCallEvent {
+    #[must_use]
+    pub const fn is_resolved(&self) -> bool {
+        matches!(self, Self::Resolved(_))
+    }
+
+    #[must_use]
+    pub const fn is_failure(&self) -> bool {
+        matches!(self, Self::ParseFailed(_) | Self::ValidationFailed(_))
+    }
+
+    #[must_use]
+    pub const fn is_pending(&self) -> bool {
+        matches!(self, Self::Pending)
+    }
+
+    #[must_use]
+    pub fn into_generated_token_result(self) -> Option<GeneratedTokenResult> {
+        match self {
+            Self::Resolved(parsed) => Some(GeneratedTokenResult::ToolCallParsed(parsed)),
+            Self::ParseFailed(err) => {
+                Some(GeneratedTokenResult::ToolCallParseFailed(err.to_string()))
+            }
+            Self::ValidationFailed(errors) => Some(GeneratedTokenResult::ToolCallValidationFailed(
+                errors.into_iter().map(|err| err.to_string()).collect(),
+            )),
+            Self::UnrecognizedFormat(raw) => {
+                Some(GeneratedTokenResult::UnrecognizedToolCallFormat(raw))
+            }
+            Self::Pending => None,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use anyhow::Result;
+    use anyhow::bail;
+    use llama_cpp_bindings::ParsedToolCall;
+    use llama_cpp_bindings::ToolCallArguments;
+    use paddler_types::generated_token_result::GeneratedTokenResult;
+    use paddler_types::raw_tool_call_tokens::RawToolCallTokens;
+    use serde_json::json;
+
+    use super::ToolCallEvent;
+    use crate::tool_call_pipeline_error::ToolCallPipelineError;
+    use crate::tool_call_validation_error::ToolCallValidationError;
+
+    #[test]
+    fn
pending_classifies_as_pending() { + let event = ToolCallEvent::Pending; + + assert!(event.is_pending()); + assert!(!event.is_resolved()); + assert!(!event.is_failure()); + } + + #[test] + fn resolved_classifies_as_resolved() { + let event = ToolCallEvent::Resolved(vec![ParsedToolCall::default()]); + + assert!(event.is_resolved()); + assert!(!event.is_pending()); + assert!(!event.is_failure()); + } + + #[test] + fn parse_failed_classifies_as_failure() { + let event = ToolCallEvent::ParseFailed(ToolCallPipelineError::EmptyBuffer); + + assert!(event.is_failure()); + assert!(!event.is_resolved()); + assert!(!event.is_pending()); + } + + #[test] + fn validation_failed_classifies_as_failure() { + let event = + ToolCallEvent::ValidationFailed(vec![ToolCallValidationError::UnknownToolName( + "x".to_owned(), + )]); + + assert!(event.is_failure()); + assert!(!event.is_resolved()); + } + + #[test] + fn pending_converts_to_none() { + assert!( + ToolCallEvent::Pending + .into_generated_token_result() + .is_none() + ); + } + + #[test] + fn resolved_converts_to_tool_call_parsed() -> Result<()> { + let parsed = ParsedToolCall::new( + "id".to_owned(), + "tool".to_owned(), + ToolCallArguments::ValidJson(json!({})), + ); + let event = ToolCallEvent::Resolved(vec![parsed.clone()]); + + match event.into_generated_token_result() { + Some(GeneratedTokenResult::ToolCallParsed(calls)) if calls == vec![parsed] => Ok(()), + other => bail!("expected ToolCallParsed with one call, got {other:?}"), + } + } + + #[test] + fn parse_failed_converts_to_tool_call_parse_failed_with_message() -> Result<()> { + let event = ToolCallEvent::ParseFailed(ToolCallPipelineError::EmptyBuffer); + + match event.into_generated_token_result() { + Some(GeneratedTokenResult::ToolCallParseFailed(message)) if !message.is_empty() => { + Ok(()) + } + other => bail!("expected ToolCallParseFailed with non-empty message, got {other:?}"), + } + } + + #[test] + fn validation_failed_converts_to_tool_call_validation_failed_with_messages() -> Result<()> { + let event = + ToolCallEvent::ValidationFailed(vec![ToolCallValidationError::UnknownToolName( + "missing".to_owned(), + )]); + + match event.into_generated_token_result() { + Some(GeneratedTokenResult::ToolCallValidationFailed(messages)) + if messages.len() == 1 && messages[0].contains("missing") => + { + Ok(()) + } + other => bail!("expected ToolCallValidationFailed mentioning 'missing', got {other:?}"), + } + } + + #[test] + fn unrecognized_format_classifies_as_neither_resolved_nor_failure_nor_pending() { + let event = ToolCallEvent::UnrecognizedFormat(RawToolCallTokens { + text: "raw".to_owned(), + ffi_error_message: "bailed".to_owned(), + }); + + assert!(!event.is_pending()); + assert!(!event.is_resolved()); + assert!(!event.is_failure()); + } + + #[test] + fn unrecognized_format_converts_to_unrecognized_tool_call_format_preserving_payload() + -> Result<()> { + let event = ToolCallEvent::UnrecognizedFormat(RawToolCallTokens { + text: "raw output".to_owned(), + ffi_error_message: "parser bailed".to_owned(), + }); + + match event.into_generated_token_result() { + Some(GeneratedTokenResult::UnrecognizedToolCallFormat(raw)) => { + assert_eq!(raw.text, "raw output"); + assert_eq!(raw.ffi_error_message, "parser bailed"); + Ok(()) + } + other => bail!("expected UnrecognizedToolCallFormat preserving payload, got {other:?}"), + } + } +} diff --git a/paddler/src/tool_call_pipeline.rs b/paddler/src/tool_call_pipeline.rs new file mode 100644 index 00000000..016de45f --- /dev/null +++ 
b/paddler/src/tool_call_pipeline.rs @@ -0,0 +1,120 @@ +use std::sync::Arc; + +use llama_cpp_bindings::ChatMessageParseOutcome; +use llama_cpp_bindings::ParsedToolCall; +use llama_cpp_bindings::RawChatMessage; +use llama_cpp_bindings::model::LlamaModel; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::raw_tool_call_tokens::RawToolCallTokens; + +use crate::tool_call_buffer::ToolCallBuffer; +use crate::tool_call_event::ToolCallEvent; +use crate::tool_call_pipeline_error::ToolCallPipelineError; +use crate::tool_call_validator::ToolCallValidator; + +pub struct ToolCallPipeline { + buffer: ToolCallBuffer, + model: Arc, + tools_json: Arc, + validator: ToolCallValidator, +} + +impl ToolCallPipeline { + pub fn new( + model: Arc, + tools: &[serde_json::Value], + validator: ToolCallValidator, + ) -> Result { + let tools_json = Arc::from(serde_json::to_string(tools)?); + + Ok(Self { + buffer: ToolCallBuffer::new(), + model, + tools_json, + validator, + }) + } + + pub fn feed(&mut self, fragment: &str) { + self.buffer.append(fragment); + } + + pub fn finalize(&mut self) -> ToolCallEvent { + let input = self.buffer.take(); + if input.is_empty() { + return ToolCallEvent::Resolved(Vec::new()); + } + + match self + .model + .parse_chat_message(&self.tools_json, &input, false) + { + Ok(ChatMessageParseOutcome::Recognized(parsed)) => { + self.validate_resolved(parsed.tool_calls) + } + Ok(ChatMessageParseOutcome::Unrecognized(RawChatMessage { + text, + ffi_error_message, + .. + })) => ToolCallEvent::UnrecognizedFormat(RawToolCallTokens { + text, + ffi_error_message, + }), + Err(err) => ToolCallEvent::ParseFailed(ToolCallPipelineError::Bindings(err)), + } + } + + pub fn finalize_to_generated_event(&mut self) -> Option { + self.finalize().into_generated_token_result() + } + + #[must_use] + pub fn try_partial(&self) -> ToolCallEvent { + let input = self.buffer.as_str(); + if input.is_empty() { + return ToolCallEvent::Pending; + } + + match self.model.parse_chat_message(&self.tools_json, input, true) { + Ok(ChatMessageParseOutcome::Recognized(parsed)) if parsed.tool_calls.is_empty() => { + ToolCallEvent::Pending + } + Ok(ChatMessageParseOutcome::Recognized(parsed)) => { + self.validate_resolved(parsed.tool_calls) + } + Ok(ChatMessageParseOutcome::Unrecognized(RawChatMessage { + text, + ffi_error_message, + .. 
+ })) => ToolCallEvent::UnrecognizedFormat(RawToolCallTokens { + text, + ffi_error_message, + }), + Err(err) => ToolCallEvent::ParseFailed(ToolCallPipelineError::Bindings(err)), + } + } + + pub fn reset(&mut self) { + self.buffer.clear(); + } + + #[must_use] + pub const fn buffer_is_empty(&self) -> bool { + self.buffer.is_empty() + } + + fn validate_resolved(&self, tool_calls: Vec) -> ToolCallEvent { + let mut errors = Vec::new(); + for call in &tool_calls { + if let Err(err) = self.validator.validate(call) { + errors.push(err); + } + } + + if errors.is_empty() { + ToolCallEvent::Resolved(tool_calls) + } else { + ToolCallEvent::ValidationFailed(errors) + } + } +} diff --git a/paddler/src/tool_call_pipeline_error.rs b/paddler/src/tool_call_pipeline_error.rs new file mode 100644 index 00000000..da6ac110 --- /dev/null +++ b/paddler/src/tool_call_pipeline_error.rs @@ -0,0 +1,9 @@ +use llama_cpp_bindings::ParseChatMessageError; + +#[derive(Debug, thiserror::Error)] +pub enum ToolCallPipelineError { + #[error("tool-call pipeline invoked on empty buffer")] + EmptyBuffer, + #[error("bindings parse failed: {0}")] + Bindings(#[from] ParseChatMessageError), +} diff --git a/paddler/src/tool_call_validation_error.rs b/paddler/src/tool_call_validation_error.rs new file mode 100644 index 00000000..798cf38c --- /dev/null +++ b/paddler/src/tool_call_validation_error.rs @@ -0,0 +1,7 @@ +#[derive(Debug, thiserror::Error)] +pub enum ToolCallValidationError { + #[error("unknown tool name {0:?}")] + UnknownToolName(String), + #[error("arguments for tool {tool_name:?} failed schema check: {message}")] + SchemaMismatch { tool_name: String, message: String }, +} diff --git a/paddler/src/tool_call_validator.rs b/paddler/src/tool_call_validator.rs new file mode 100644 index 00000000..63f48a80 --- /dev/null +++ b/paddler/src/tool_call_validator.rs @@ -0,0 +1,344 @@ +use std::collections::HashMap; + +use jsonschema::Validator; +use llama_cpp_bindings::ParsedToolCall; +use llama_cpp_bindings::ToolCallArguments; +use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; + +use crate::tool_call_validation_error::ToolCallValidationError; + +#[derive(Debug, thiserror::Error)] +pub enum ValidatorBuildError { + #[error("could not serialize tool {tool_name:?} parameters to JSON: {message}")] + SerializationFailed { tool_name: String, message: String }, + #[error("tool {tool_name:?} parameters are not a valid JSON Schema: {message}")] + InvalidSchema { tool_name: String, message: String }, +} + +enum ValidationStrategy { + JsonObjectOnly, + Schema(Box), +} + +pub struct ToolCallValidator { + strategies: HashMap, +} + +impl ToolCallValidator { + pub fn from_tools( + tools: &[Tool], + ) -> Result { + let mut strategies = HashMap::with_capacity(tools.len()); + + for tool in tools { + let Tool::Function(function_call) = tool; + let function = &function_call.function; + + let strategy = match &function.parameters { + Parameters::Empty => ValidationStrategy::JsonObjectOnly, + Parameters::Schema(schema) => { + let schema_value = serde_json::to_value(schema).map_err(|err| { + ValidatorBuildError::SerializationFailed { + tool_name: function.name.clone(), + message: err.to_string(), + } 
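// Editor's aside — the intended per-request lifecycle, as a sketch rather than
// anything this diff ships. `model`, `tools`, and `validator` stand for values
// the serving slot already owns; `emit` is a hypothetical sink for
// GeneratedTokenResult values.
let mut pipeline = ToolCallPipeline::new(model.clone(), &tools, validator)?;

// Buffer each decoded fragment as it streams in.
pipeline.feed(token_fragment);

// Optionally probe mid-stream; Pending means "keep feeding".
if let ToolCallEvent::Resolved(calls) = pipeline.try_partial() {
    emit(GeneratedTokenResult::ToolCallParsed(calls));
}

// At end of generation, force a verdict; note that an empty buffer finalizes
// to Resolved(vec![]) rather than to ParseFailed.
if let Some(result) = pipeline.finalize_to_generated_event() {
    emit(result);
}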
+ })?; + let compiled = jsonschema::validator_for(&schema_value).map_err(|err| { + ValidatorBuildError::InvalidSchema { + tool_name: function.name.clone(), + message: err.to_string(), + } + })?; + ValidationStrategy::Schema(Box::new(compiled)) + } + }; + + strategies.insert(function.name.clone(), strategy); + } + + Ok(Self { strategies }) + } + + pub fn validate(&self, parsed: &ParsedToolCall) -> Result<(), ToolCallValidationError> { + let strategy = self + .strategies + .get(&parsed.name) + .ok_or_else(|| ToolCallValidationError::UnknownToolName(parsed.name.clone()))?; + + let arguments_value = match &parsed.arguments { + ToolCallArguments::ValidJson(value) => value, + ToolCallArguments::InvalidJson(_) => return Ok(()), + }; + + match strategy { + ValidationStrategy::JsonObjectOnly => Ok(()), + ValidationStrategy::Schema(validator) => { + let mut messages: Vec = validator + .iter_errors(arguments_value) + .map(|err| err.to_string()) + .collect(); + + if messages.is_empty() { + Ok(()) + } else { + Err(ToolCallValidationError::SchemaMismatch { + tool_name: parsed.name.clone(), + message: messages.remove(0), + }) + } + } + } + } + + #[must_use] + pub fn known_tool_names(&self) -> Vec<&str> { + self.strategies.keys().map(String::as_str).collect() + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use anyhow::bail; + use llama_cpp_bindings::ParsedToolCall; + use llama_cpp_bindings::ToolCallArguments; + use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; + use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; + use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; + use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; + use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; + use serde_json::Map; + use serde_json::Value; + use serde_json::json; + + use super::ToolCallValidator; + use crate::tool_call_validation_error::ToolCallValidationError; + + fn valid_json_arguments(value: Value) -> ToolCallArguments { + ToolCallArguments::ValidJson(value) + } + + fn weather_tool_with_schema() -> Tool { + let mut properties = Map::new(); + properties.insert( + "location".to_owned(), + serde_json::json!({"type": "string", "description": "city"}), + ); + + Tool::Function(FunctionCall { + function: Function { + name: "get_weather".to_owned(), + description: "fetch weather".to_owned(), + parameters: Parameters::Schema(ValidatedParametersSchema { + schema_type: "object".to_owned(), + properties: Some(properties), + required: Some(vec!["location".to_owned()]), + additional_properties: Some(Value::Bool(false)), + }), + }, + }) + } + + fn schemaless_tool() -> Tool { + Tool::Function(FunctionCall { + function: Function { + name: "freeform".to_owned(), + description: "tool with no schema".to_owned(), + parameters: Parameters::Empty, + }, + }) + } + + #[test] + fn schema_validator_accepts_matching_arguments() -> Result<()> { + let validator = ToolCallValidator::from_tools(&[weather_tool_with_schema()])?; + let parsed = ParsedToolCall::new( + "id".to_owned(), + "get_weather".to_owned(), + valid_json_arguments(json!({"location": "Paris"})), + ); + + validator.validate(&parsed)?; + + Ok(()) + } + + #[test] + fn 
schema_validator_rejects_missing_required_field() -> Result<()> { + let validator = ToolCallValidator::from_tools(&[weather_tool_with_schema()])?; + let parsed = ParsedToolCall::new( + "id".to_owned(), + "get_weather".to_owned(), + valid_json_arguments(json!({})), + ); + + match validator.validate(&parsed) { + Err(ToolCallValidationError::SchemaMismatch { tool_name, .. }) => { + assert_eq!(tool_name, "get_weather"); + Ok(()) + } + other => bail!("expected SchemaMismatch, got {other:?}"), + } + } + + #[test] + fn schema_validator_rejects_wrong_type() -> Result<()> { + let validator = ToolCallValidator::from_tools(&[weather_tool_with_schema()])?; + let parsed = ParsedToolCall::new( + "id".to_owned(), + "get_weather".to_owned(), + valid_json_arguments(json!({"location": 42})), + ); + + match validator.validate(&parsed) { + Err(ToolCallValidationError::SchemaMismatch { .. }) => Ok(()), + other => bail!("expected SchemaMismatch, got {other:?}"), + } + } + + #[test] + fn unknown_tool_name_returns_error() -> Result<()> { + let validator = ToolCallValidator::from_tools(&[weather_tool_with_schema()])?; + let parsed = ParsedToolCall::new( + "id".to_owned(), + "set_thermostat".to_owned(), + valid_json_arguments(json!({"value": 21})), + ); + + match validator.validate(&parsed) { + Err(ToolCallValidationError::UnknownToolName(name)) => { + assert_eq!(name, "set_thermostat"); + Ok(()) + } + other => bail!("expected UnknownToolName, got {other:?}"), + } + } + + #[test] + fn invalid_json_arguments_pass_validation_silently() -> Result<()> { + let validator = ToolCallValidator::from_tools(&[weather_tool_with_schema()])?; + let parsed = ParsedToolCall::new( + "id".to_owned(), + "get_weather".to_owned(), + ToolCallArguments::InvalidJson("not json".to_owned()), + ); + + validator.validate(&parsed)?; + + Ok(()) + } + + #[test] + fn schemaless_tool_accepts_any_object() -> Result<()> { + let validator = ToolCallValidator::from_tools(&[schemaless_tool()])?; + let parsed = ParsedToolCall::new( + "id".to_owned(), + "freeform".to_owned(), + valid_json_arguments(json!({"x": 1, "y": 2})), + ); + + validator.validate(&parsed)?; + + Ok(()) + } + + #[test] + fn known_tool_names_returns_all_registered_names() -> Result<()> { + let validator = + ToolCallValidator::from_tools(&[weather_tool_with_schema(), schemaless_tool()])?; + + let mut names = validator.known_tool_names(); + names.sort_unstable(); + + assert_eq!(names, vec!["freeform", "get_weather"]); + + Ok(()) + } + + #[test] + fn empty_tools_yields_validator_that_rejects_any_call() -> Result<()> { + let validator = ToolCallValidator::from_tools(&[])?; + let parsed = ParsedToolCall::new( + "id".to_owned(), + "anything".to_owned(), + valid_json_arguments(json!({})), + ); + + assert!(matches!( + validator.validate(&parsed), + Err(ToolCallValidationError::UnknownToolName(_)) + )); + + Ok(()) + } + + fn tool_with_invalid_property_schema() -> Tool { + let mut properties = Map::new(); + properties.insert("location".to_owned(), serde_json::json!({"type": 42})); + + Tool::Function(FunctionCall { + function: Function { + name: "broken_tool".to_owned(), + description: "tool whose property schema is not valid JSON Schema".to_owned(), + parameters: Parameters::Schema(ValidatedParametersSchema { + schema_type: "object".to_owned(), + properties: Some(properties), + required: None, + additional_properties: None, + }), + }, + }) + } + + #[test] + fn invalid_property_schema_rejects_validator_build() -> Result<()> { + let error = 
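// Editor's aside — a sketch, not part of the diff, of the expected wiring:
// build one validator per request from the declared tools, then gate every
// parsed call before it is surfaced. `request_tools` and `parsed_calls` are
// hypothetical values owned by the surrounding request handler.
let validator = ToolCallValidator::from_tools(&request_tools)?;

for call in &parsed_calls {
    if let Err(validation_error) = validator.validate(call) {
        log::warn!("rejecting tool call {:?}: {validation_error}", call.name);
    }
}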
ToolCallValidator::from_tools(&[tool_with_invalid_property_schema()]) + .err() + .ok_or_else(|| anyhow::anyhow!("expected ValidatorBuildError, got Ok"))?; + + match error { + super::ValidatorBuildError::InvalidSchema { tool_name, .. } => { + assert_eq!(tool_name, "broken_tool"); + Ok(()) + } + super::ValidatorBuildError::SerializationFailed { .. } => { + bail!("expected InvalidSchema, got SerializationFailed: {error:?}"); + } + } + } + + fn tool_with_invalid_additional_properties_schema() -> Tool { + Tool::Function(FunctionCall { + function: Function { + name: "broken_additional".to_owned(), + description: "tool whose additionalProperties schema is invalid".to_owned(), + parameters: Parameters::Schema(ValidatedParametersSchema { + schema_type: "object".to_owned(), + properties: None, + required: None, + additional_properties: Some(json!({"type": "not_a_type"})), + }), + }, + }) + } + + #[test] + fn invalid_additional_properties_schema_rejects_validator_build() -> Result<()> { + let error = + ToolCallValidator::from_tools(&[tool_with_invalid_additional_properties_schema()]) + .err() + .ok_or_else(|| anyhow::anyhow!("expected ValidatorBuildError, got Ok"))?; + + match error { + super::ValidatorBuildError::InvalidSchema { tool_name, .. } => { + assert_eq!(tool_name, "broken_additional"); + Ok(()) + } + super::ValidatorBuildError::SerializationFailed { .. } => { + bail!("expected InvalidSchema, got SerializationFailed: {error:?}"); + } + } + } +} diff --git a/paddler_bootstrap/src/bootstrapped_agent_handle.rs b/paddler_bootstrap/src/bootstrapped_agent_handle.rs index 3b2408f0..9d87fdb5 100644 --- a/paddler_bootstrap/src/bootstrapped_agent_handle.rs +++ b/paddler_bootstrap/src/bootstrapped_agent_handle.rs @@ -9,10 +9,10 @@ use paddler::agent::management_socket_client_service::ManagementSocketClientServ use paddler::agent::model_metadata_holder::ModelMetadataHolder; use paddler::agent::reconciliation_service::ReconciliationService; use paddler::agent_applicable_state_holder::AgentApplicableStateHolder; -use paddler::agent_desired_state::AgentDesiredState; use paddler::service_manager::ServiceManager; use paddler::slot_aggregated_status::SlotAggregatedStatus; use paddler::slot_aggregated_status_manager::SlotAggregatedStatusManager; +use paddler_types::agent_desired_state::AgentDesiredState; use tokio::sync::mpsc; pub struct BootstrappedAgentHandle { diff --git a/paddler_bootstrap/src/bootstrapped_balancer_handle.rs b/paddler_bootstrap/src/bootstrapped_balancer_handle.rs index ee706483..1e208648 100644 --- a/paddler_bootstrap/src/bootstrapped_balancer_handle.rs +++ b/paddler_bootstrap/src/bootstrapped_balancer_handle.rs @@ -90,6 +90,7 @@ pub async fn bootstrap_balancer( }; service_manager.add_service(InferenceService { + agent_controller_pool: agent_controller_pool.clone(), balancer_applicable_state_holder: balancer_applicable_state_holder.clone(), buffered_request_manager: buffered_request_manager.clone(), configuration: inference_service_configuration.clone(), diff --git a/paddler_bootstrap/src/service_thread.rs b/paddler_bootstrap/src/service_thread.rs index b0b889da..f3256a65 100644 --- a/paddler_bootstrap/src/service_thread.rs +++ b/paddler_bootstrap/src/service_thread.rs @@ -4,6 +4,7 @@ use std::thread; use anyhow::Result; use anyhow::anyhow; use log::error; +use log::warn; use tokio::sync::oneshot; use tokio_util::sync::CancellationToken; @@ -24,7 +25,16 @@ impl ServiceThread { let thread = thread::spawn(move || { let result = actix_web::rt::System::new().block_on(run(task_token)); - let _ 
= completion_tx.send(result); + if let Err(unsent) = completion_tx.send(result) { + match unsent { + Ok(()) => warn!( + "service thread completion receiver dropped before delivery; run() succeeded but result was not observed by the caller" + ), + Err(run_err) => error!( + "service thread completion receiver dropped before delivery; lost run() error: {run_err:?}" + ), + } + } }); Self { diff --git a/paddler_bootstrap/src/shutdown_signal/mod.rs b/paddler_bootstrap/src/shutdown_signal/mod.rs index 2aa84865..9453fb58 100644 --- a/paddler_bootstrap/src/shutdown_signal/mod.rs +++ b/paddler_bootstrap/src/shutdown_signal/mod.rs @@ -4,6 +4,10 @@ mod unix; mod windows; #[cfg(unix)] -pub use unix::wait_for_shutdown_signal; +pub use unix::ShutdownSignals; +#[cfg(unix)] +pub use unix::register_shutdown_signals; +#[cfg(windows)] +pub use windows::ShutdownSignals; #[cfg(windows)] -pub use windows::wait_for_shutdown_signal; +pub use windows::register_shutdown_signals; diff --git a/paddler_bootstrap/src/shutdown_signal/unix.rs b/paddler_bootstrap/src/shutdown_signal/unix.rs index 969f9c37..411d5531 100644 --- a/paddler_bootstrap/src/shutdown_signal/unix.rs +++ b/paddler_bootstrap/src/shutdown_signal/unix.rs @@ -1,19 +1,36 @@ use anyhow::Context as _; use anyhow::Result; use log::info; +use tokio::signal::unix::Signal; use tokio::signal::unix::SignalKind; use tokio::signal::unix::signal; -pub async fn wait_for_shutdown_signal() -> Result<()> { - let mut sigterm = signal(SignalKind::terminate()).context("failed to listen for SIGTERM")?; - let mut sigint = signal(SignalKind::interrupt()).context("failed to listen for SIGINT")?; - let mut sighup = signal(SignalKind::hangup()).context("failed to listen for SIGHUP")?; +pub struct ShutdownSignals { + sigterm: Signal, + sigint: Signal, + sighup: Signal, +} + +impl ShutdownSignals { + pub async fn wait(mut self) -> Result<()> { + tokio::select! { + _ = self.sigterm.recv() => info!("Received SIGTERM"), + _ = self.sigint.recv() => info!("Received SIGINT"), + _ = self.sighup.recv() => info!("Received SIGHUP"), + } - tokio::select! 
{ - _ = sigterm.recv() => info!("Received SIGTERM"), - _ = sigint.recv() => info!("Received SIGINT"), - _ = sighup.recv() => info!("Received SIGHUP"), + Ok(()) } +} + +pub fn register_shutdown_signals() -> Result { + let sigterm = signal(SignalKind::terminate()).context("failed to listen for SIGTERM")?; + let sigint = signal(SignalKind::interrupt()).context("failed to listen for SIGINT")?; + let sighup = signal(SignalKind::hangup()).context("failed to listen for SIGHUP")?; - Ok(()) + Ok(ShutdownSignals { + sigterm, + sigint, + sighup, + }) } diff --git a/paddler_bootstrap/src/shutdown_signal/windows.rs b/paddler_bootstrap/src/shutdown_signal/windows.rs index 862186fa..b019c78f 100644 --- a/paddler_bootstrap/src/shutdown_signal/windows.rs +++ b/paddler_bootstrap/src/shutdown_signal/windows.rs @@ -1,23 +1,45 @@ use anyhow::Context as _; use anyhow::Result; use log::info; +use tokio::signal::windows::CtrlBreak; +use tokio::signal::windows::CtrlC; +use tokio::signal::windows::CtrlClose; +use tokio::signal::windows::CtrlShutdown; use tokio::signal::windows::ctrl_break; use tokio::signal::windows::ctrl_c; use tokio::signal::windows::ctrl_close; use tokio::signal::windows::ctrl_shutdown; -pub async fn wait_for_shutdown_signal() -> Result<()> { - let mut ctrl_c = ctrl_c().context("failed to listen for Ctrl+C")?; - let mut ctrl_break = ctrl_break().context("failed to listen for Ctrl+Break")?; - let mut ctrl_close = ctrl_close().context("failed to listen for console close")?; - let mut ctrl_shutdown = ctrl_shutdown().context("failed to listen for system shutdown")?; +pub struct ShutdownSignals { + ctrl_c: CtrlC, + ctrl_break: CtrlBreak, + ctrl_close: CtrlClose, + ctrl_shutdown: CtrlShutdown, +} + +impl ShutdownSignals { + pub async fn wait(mut self) -> Result<()> { + tokio::select! { + _ = self.ctrl_c.recv() => info!("Received Ctrl+C"), + _ = self.ctrl_break.recv() => info!("Received Ctrl+Break"), + _ = self.ctrl_close.recv() => info!("Received console close"), + _ = self.ctrl_shutdown.recv() => info!("Received system shutdown"), + } - tokio::select! 
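// Editor's aside — a sketch, not part of the diff, of the register-then-wait
// split this refactor enables: handlers are installed once at startup (so a
// signal arriving before the supervisor awaits is not missed), and the wait
// itself is pushed into a background task. Assumes
// tokio_util::sync::CancellationToken, as used elsewhere in this diff.
let signals = register_shutdown_signals()?;
let shutdown = CancellationToken::new();

let signal_shutdown = shutdown.clone();
tokio::spawn(async move {
    if let Err(error) = signals.wait().await {
        log::error!("shutdown signal listener failed: {error}");
        return;
    }

    signal_shutdown.cancel();
});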
{ - _ = ctrl_c.recv() => info!("Received Ctrl+C"), - _ = ctrl_break.recv() => info!("Received Ctrl+Break"), - _ = ctrl_close.recv() => info!("Received console close"), - _ = ctrl_shutdown.recv() => info!("Received system shutdown"), + Ok(()) } +} + +pub fn register_shutdown_signals() -> Result { + let ctrl_c = ctrl_c().context("failed to listen for Ctrl+C")?; + let ctrl_break = ctrl_break().context("failed to listen for Ctrl+Break")?; + let ctrl_close = ctrl_close().context("failed to listen for console close")?; + let ctrl_shutdown = ctrl_shutdown().context("failed to listen for system shutdown")?; - Ok(()) + Ok(ShutdownSignals { + ctrl_c, + ctrl_break, + ctrl_close, + ctrl_shutdown, + }) } diff --git a/paddler_bootstrap/tests/runners.rs b/paddler_bootstrap/tests/runners.rs index 588f9214..3567cb91 100644 --- a/paddler_bootstrap/tests/runners.rs +++ b/paddler_bootstrap/tests/runners.rs @@ -227,10 +227,7 @@ async fn agent_runner_cancels_from_parent_token() -> Result<()> { let parent = CancellationToken::new(); - let runner = AgentRunner::start(make_agent_runner_params( - management_addr, - parent.clone(), - )); + let runner = AgentRunner::start(make_agent_runner_params(management_addr, parent.clone())); parent.cancel(); drop(runner); diff --git a/paddler_cli/src/main.rs b/paddler_cli/src/main.rs index b406ea23..0e06b294 100644 --- a/paddler_cli/src/main.rs +++ b/paddler_cli/src/main.rs @@ -1,14 +1,14 @@ +mod cmd; + use anyhow::Result; use clap::Parser; use clap::Subcommand; -#[cfg(feature = "web_admin_panel")] -use esbuild_metafile::instance::initialize_instance; -mod cmd; - use cmd::agent::Agent; use cmd::balancer::Balancer; use cmd::handler::Handler as _; -use paddler_bootstrap::shutdown_signal::wait_for_shutdown_signal; +#[cfg(feature = "web_admin_panel")] +use esbuild_metafile::instance::initialize_instance; +use paddler_bootstrap::shutdown_signal::register_shutdown_signals; use tokio_util::sync::CancellationToken; #[cfg(feature = "web_admin_panel")] @@ -42,11 +42,12 @@ enum Commands { async fn main() -> Result<()> { env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); + let shutdown_signals = register_shutdown_signals()?; let shutdown = CancellationToken::new(); let signal_shutdown = shutdown.clone(); tokio::spawn(async move { - if let Err(error) = wait_for_shutdown_signal().await { + if let Err(error) = shutdown_signals.wait().await { log::error!("shutdown signal listener failed: {error}"); return; } diff --git a/paddler_client/src/client_inference.rs b/paddler_client/src/client_inference.rs index 9b84aea1..86fe6ddf 100644 --- a/paddler_client/src/client_inference.rs +++ b/paddler_client/src/client_inference.rs @@ -5,9 +5,9 @@ use paddler_types::inference_client::Message as InferenceMessage; use paddler_types::inference_server::Message as InferenceServerMessage; use paddler_types::inference_server::Request as InferenceServerRequest; use paddler_types::jsonrpc::RequestEnvelope; -use paddler_types::request_params::ContinueFromConversationHistoryParams; use paddler_types::request_params::ContinueFromRawPromptParams; use paddler_types::request_params::GenerateEmbeddingBatchParams; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; use reqwest::Client; use 
tokio_stream::wrappers::UnboundedReceiverStream; diff --git a/paddler_client/src/lib.rs b/paddler_client/src/lib.rs index 2c805849..0939d664 100644 --- a/paddler_client/src/lib.rs +++ b/paddler_client/src/lib.rs @@ -1,10 +1,10 @@ -pub mod agents_stream; -pub mod buffered_requests_stream; -pub mod client_inference; -pub mod client_management; -pub mod error; +mod agents_stream; +mod buffered_requests_stream; +mod client_inference; +mod client_management; +mod error; mod format_api_url; -pub mod inference_message_stream; +mod inference_message_stream; mod inference_socket; mod stream; diff --git a/paddler_client_cli/Cargo.toml b/paddler_client_cli/Cargo.toml new file mode 100644 index 00000000..9b56ade1 --- /dev/null +++ b/paddler_client_cli/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "paddler_client_cli" +version.workspace = true +edition.workspace = true +authors.workspace = true +description = "Client CLI/TUI binary for Paddler" +license.workspace = true + +[[bin]] +name = "paddler_client_cli" +path = "src/main.rs" + +[dependencies] +anyhow = { workspace = true } +async-trait = { workspace = true } +clap = { workspace = true } +crossterm = { workspace = true } +env_logger = { workspace = true } +futures-util = { workspace = true } +llama-cpp-bindings-types = { workspace = true } +log = { workspace = true } +paddler_bootstrap = { workspace = true } +paddler_client = { workspace = true } +paddler_types = { workspace = true } +ratatui = { workspace = true } +reqwest = { workspace = true } +serde_json = { workspace = true } +tokio = { workspace = true } +tokio-util = { workspace = true } +url = { workspace = true } + +[lints] +workspace = true diff --git a/paddler_client_cli/examples/calculator.json b/paddler_client_cli/examples/calculator.json new file mode 100644 index 00000000..a857e795 --- /dev/null +++ b/paddler_client_cli/examples/calculator.json @@ -0,0 +1,27 @@ +{ + "type": "function", + "function": { + "name": "calculator", + "description": "Perform a basic arithmetic operation on two numbers", + "parameters": { + "type": "object", + "properties": { + "operation": { + "type": "string", + "enum": ["add", "subtract", "multiply", "divide"], + "description": "The operation to perform" + }, + "a": { + "type": "number", + "description": "First operand" + }, + "b": { + "type": "number", + "description": "Second operand" + } + }, + "required": ["operation", "a", "b"], + "additionalProperties": false + } + } +} diff --git a/paddler_client_cli/examples/get_weather.json b/paddler_client_cli/examples/get_weather.json new file mode 100644 index 00000000..d3948491 --- /dev/null +++ b/paddler_client_cli/examples/get_weather.json @@ -0,0 +1,23 @@ +{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather for a city", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City name, e.g. 'Paris' or 'Tokyo'" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature unit" + } + }, + "required": ["location"], + "additionalProperties": false + } + } +} diff --git a/paddler_client_cli/examples/negotiate_with_cat.json b/paddler_client_cli/examples/negotiate_with_cat.json new file mode 100644 index 00000000..6b74c15f --- /dev/null +++ b/paddler_client_cli/examples/negotiate_with_cat.json @@ -0,0 +1,29 @@ +{ + "type": "function", + "function": { + "name": "negotiate_with_cat", + "description": "Attempt to negotiate with a cat. 
Outcomes are not guaranteed and may include the silent treatment.", + "parameters": { + "type": "object", + "properties": { + "topic": { + "type": "string", + "description": "What you are trying to negotiate, e.g. 'get off the keyboard' or 'stop knocking things off the table'" + }, + "bribe": { + "type": "string", + "enum": ["tuna", "salmon", "treats", "ear_scritches", "cardboard_box", "none"], + "description": "What you are offering in exchange" + }, + "desperation_level": { + "type": "integer", + "description": "How desperate you are, on a scale from 1 (mildly annoyed human) to 10 (it is 3am)", + "minimum": 1, + "maximum": 10 + } + }, + "required": ["topic"], + "additionalProperties": false + } + } +} diff --git a/paddler_client_cli/src/chat_session.rs b/paddler_client_cli/src/chat_session.rs new file mode 100644 index 00000000..5742499a --- /dev/null +++ b/paddler_client_cli/src/chat_session.rs @@ -0,0 +1,172 @@ +use std::io; + +use anyhow::Result; +use anyhow::anyhow; +use crossterm::event::Event as CrosstermEvent; +use crossterm::event::EventStream; +use crossterm::event::KeyCode; +use crossterm::event::KeyEvent; +use crossterm::event::KeyModifiers; +use crossterm::event::MouseButton; +use crossterm::event::MouseEvent; +use crossterm::event::MouseEventKind; +use futures_util::StreamExt; +use paddler_client::InferenceMessageStream; +use ratatui::Terminal; +use ratatui::backend::CrosstermBackend; +use ratatui::layout::Rect; +use tokio_util::sync::CancellationToken; + +use crate::chat_session_event::ChatSessionEvent; +use crate::streaming_response::StreamingResponse; +use crate::view_chat_panels::view_chat_panels; +use crate::view_panel_layout::ViewPanelLayout; +use crate::view_panel_navigation::ViewPanelNavigation; +use crate::view_terminal_guard::ViewTerminalGuard; + +const MOUSE_WHEEL_LINES: u16 = 3; +const ARROW_KEY_LINES: u16 = 1; + +pub struct ChatSession { + inference_stream: InferenceMessageStream, + state: StreamingResponse, + navigation: ViewPanelNavigation, + shutdown: CancellationToken, +} + +impl ChatSession { + pub fn new(inference_stream: InferenceMessageStream, shutdown: CancellationToken) -> Self { + Self { + inference_stream, + state: StreamingResponse::default(), + navigation: ViewPanelNavigation::default(), + shutdown, + } + } + + pub async fn run(mut self) -> Result<()> { + let _terminal_guard = ViewTerminalGuard::enter()?; + let mut terminal = Terminal::new(CrosstermBackend::new(io::stdout()))?; + let mut events = EventStream::new(); + + let mut layout = compute_layout(&terminal)?; + terminal.draw(|frame| { + view_chat_panels(&self.state, &mut self.navigation, &layout, frame); + })?; + + loop { + match self.next_event(&mut events).await { + ChatSessionEvent::InferenceMessage(message) => { + self.state.apply_message(message); + } + ChatSessionEvent::InferenceStreamEnded => { + if !self.state.is_finished() { + self.state.record_wire_error(&anyhow!( + "inference stream ended before sending Done" + )); + } + } + ChatSessionEvent::InferenceStreamError(error) => { + self.state.record_wire_error(&error); + } + ChatSessionEvent::Key(key_event) => { + if is_quit(key_event) { + return Ok(()); + } + self.handle_navigation_key(key_event, &layout); + } + ChatSessionEvent::Mouse(mouse_event) => { + self.handle_mouse(mouse_event, &layout); + } + ChatSessionEvent::Repaint => {} + ChatSessionEvent::Shutdown => return Ok(()), + } + layout = compute_layout(&terminal)?; + terminal.draw(|frame| { + view_chat_panels(&self.state, &mut self.navigation, &layout, frame); + })?; + } + } + + 
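// Editor's aside — a launch sketch, not part of the diff. `inference_stream`
// is a hypothetical InferenceMessageStream obtained via paddler_client, and
// `shutdown` is the process-wide CancellationToken; `run` takes over the
// terminal until quit, stream end, or cancellation.
let session = ChatSession::new(inference_stream, shutdown.clone());

session.run().await?;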
async fn next_event(&mut self, events: &mut EventStream) -> ChatSessionEvent {
+        let inference_active = !self.state.is_finished();
+        loop {
+            tokio::select! {
+                biased;
+                () = self.shutdown.cancelled() => return ChatSessionEvent::Shutdown,
+                maybe_event = events.next() => match maybe_event {
+                    Some(Ok(CrosstermEvent::Key(key))) => return ChatSessionEvent::Key(key),
+                    Some(Ok(CrosstermEvent::Mouse(mouse))) => return ChatSessionEvent::Mouse(mouse),
+                    Some(Ok(CrosstermEvent::Resize(_, _))) => return ChatSessionEvent::Repaint,
+                    Some(Ok(_)) => {}
+                    Some(Err(read_error)) => {
+                        log::error!("terminal event read error: {read_error}");
+                        return ChatSessionEvent::Shutdown;
+                    }
+                    None => return ChatSessionEvent::Shutdown,
+                },
+                maybe_message = self.inference_stream.next(), if inference_active => match maybe_message {
+                    Some(Ok(message)) => return ChatSessionEvent::InferenceMessage(message),
+                    Some(Err(stream_error)) => return ChatSessionEvent::InferenceStreamError(stream_error.into()),
+                    None => return ChatSessionEvent::InferenceStreamEnded,
+                },
+            }
+        }
+    }
+
+    fn handle_navigation_key(&mut self, key_event: KeyEvent, layout: &ViewPanelLayout) {
+        let focused = self.navigation.focused();
+        let viewport_rows = layout.viewport_rows(focused);
+        let page_lines = viewport_rows.saturating_sub(1).max(1);
+        match key_event.code {
+            KeyCode::Up => self.navigation.scroll_up(focused, ARROW_KEY_LINES),
+            KeyCode::Down => self.navigation.scroll_down(focused, ARROW_KEY_LINES),
+            KeyCode::PageUp => self.navigation.scroll_up(focused, page_lines),
+            KeyCode::PageDown => self.navigation.scroll_down(focused, page_lines),
+            KeyCode::Home => self.navigation.jump_to_top(focused),
+            KeyCode::End => self.navigation.jump_to_bottom(focused),
+            KeyCode::Tab => self.navigation.cycle_focus_forward(),
+            KeyCode::BackTab => self.navigation.cycle_focus_backward(),
+            _ => {}
+        }
+    }
+
+    fn handle_mouse(&mut self, mouse_event: MouseEvent, layout: &ViewPanelLayout) {
+        let Some(panel) = layout.panel_at(mouse_event.column, mouse_event.row) else {
+            return;
+        };
+        match mouse_event.kind {
+            MouseEventKind::ScrollUp => {
+                self.navigation.focus(panel);
+                self.navigation.scroll_up(panel, MOUSE_WHEEL_LINES);
+            }
+            MouseEventKind::ScrollDown => {
+                self.navigation.focus(panel);
+                self.navigation.scroll_down(panel, MOUSE_WHEEL_LINES);
+            }
+            MouseEventKind::Down(MouseButton::Left) => {
+                self.navigation.focus(panel);
+            }
+            _ => {}
+        }
+    }
+}
+
+fn compute_layout(terminal: &Terminal<CrosstermBackend<io::Stdout>>) -> Result<ViewPanelLayout> {
+    let size = terminal.size()?;
+    Ok(ViewPanelLayout::compute(Rect::new(
+        0,
+        0,
+        size.width,
+        size.height,
+    )))
+}
+
+const fn is_quit(key_event: KeyEvent) -> bool {
+    if key_event.modifiers.contains(KeyModifiers::CONTROL)
+        && matches!(key_event.code, KeyCode::Char('c'))
+    {
+        return true;
+    }
+    matches!(key_event.code, KeyCode::Char('q' | 'Q') | KeyCode::Esc)
+}
diff --git a/paddler_client_cli/src/chat_session_event.rs b/paddler_client_cli/src/chat_session_event.rs
new file mode 100644
index 00000000..8c1e0edd
--- /dev/null
+++ b/paddler_client_cli/src/chat_session_event.rs
@@ -0,0 +1,14 @@
+use crossterm::event::KeyEvent;
+use crossterm::event::MouseEvent;
+use paddler_types::inference_client::Message;
+
+#[derive(Debug)]
+pub enum ChatSessionEvent {
+    InferenceMessage(Message),
+    InferenceStreamEnded,
+    InferenceStreamError(anyhow::Error),
+    Key(KeyEvent),
+    Mouse(MouseEvent),
+    Repaint,
+    Shutdown,
+}
diff --git a/paddler_client_cli/src/cmd/handler.rs b/paddler_client_cli/src/cmd/handler.rs
new file mode 100644
index 00000000..2840065c
--- /dev/null
+++ b/paddler_client_cli/src/cmd/handler.rs
@@ -0,0 +1,8 @@
+use anyhow::Result;
+use async_trait::async_trait;
+use tokio_util::sync::CancellationToken;
+
+#[async_trait]
+pub trait Handler {
+    async fn handle(&self, shutdown: CancellationToken) -> Result<()>;
+}
diff --git a/paddler_client_cli/src/cmd/mod.rs b/paddler_client_cli/src/cmd/mod.rs
new file mode 100644
index 00000000..bc27b40c
--- /dev/null
+++ b/paddler_client_cli/src/cmd/mod.rs
@@ -0,0 +1,2 @@
+pub mod handler;
+pub mod prompt;
diff --git a/paddler_client_cli/src/cmd/prompt.rs b/paddler_client_cli/src/cmd/prompt.rs
new file mode 100644
index 00000000..13efb6ed
--- /dev/null
+++ b/paddler_client_cli/src/cmd/prompt.rs
@@ -0,0 +1,73 @@
+use std::path::PathBuf;
+
+use anyhow::Result;
+use async_trait::async_trait;
+use clap::Parser;
+use paddler_client::ClientInference;
+use paddler_types::conversation_history::ConversationHistory;
+use paddler_types::conversation_message::ConversationMessage;
+use paddler_types::conversation_message_content::ConversationMessageContent;
+use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams;
+use reqwest::Client;
+use tokio_util::sync::CancellationToken;
+use url::Url;
+
+use super::handler::Handler;
+use crate::chat_session::ChatSession;
+use crate::prompt_load_tool::prompt_load_tool;
+use crate::prompt_parse_inference_url::prompt_parse_inference_url;
+use crate::prompt_thinking_mode::PromptThinkingMode;
+
+#[derive(Parser)]
+pub struct Prompt {
+    #[arg(long, value_parser = prompt_parse_inference_url)]
+    /// Address of the inference server (e.g. 127.0.0.1:8061)
+    inference_addr: Url,
+
+    #[arg(long)]
+    /// Maximum number of tokens to generate
+    max_tokens: i32,
+
+    #[arg(long, value_enum)]
+    /// Whether chain-of-thought thinking is on or off
+    thinking: PromptThinkingMode,
+
+    #[arg(long, action = clap::ArgAction::Append)]
+    /// Path to a JSON file describing one tool (repeatable)
+    tool: Vec<PathBuf>,
+
+    /// Prompt to send to the model
+    message: String,
+}
+
+#[async_trait]
+impl Handler for Prompt {
+    async fn handle(&self, shutdown: CancellationToken) -> Result<()> {
+        let tools = self
+            .tool
+            .iter()
+            .map(|path| prompt_load_tool(path))
+            .collect::<Result<Vec<_>>>()?;
+
+        let request = ContinueFromConversationHistoryParams {
+            add_generation_prompt: true,
+            conversation_history: ConversationHistory::new(vec![ConversationMessage {
+                content: ConversationMessageContent::Text(self.message.clone()),
+                role: "user".to_owned(),
+            }]),
+            enable_thinking: self.thinking.is_enabled(),
+            grammar: None,
+            max_tokens: self.max_tokens,
+            parse_tool_calls: !tools.is_empty(),
+            tools,
+        };
+
+        let http_client = Client::new();
+        let inference = ClientInference::new(&self.inference_addr, &http_client, 1);
+        let stream = inference
+            .post_continue_from_conversation_history(&request)
+            .await?;
+
+        ChatSession::new(stream, shutdown).run().await
+    }
+}
diff --git a/paddler_client_cli/src/main.rs b/paddler_client_cli/src/main.rs
new file mode 100644
index 00000000..fb47974c
--- /dev/null
+++ b/paddler_client_cli/src/main.rs
@@ -0,0 +1,54 @@
+mod chat_session;
+mod chat_session_event;
+mod cmd;
+mod prompt_load_tool;
+mod prompt_parse_inference_url;
+mod prompt_thinking_mode;
+mod stop_reason;
+mod streaming_response;
+mod view_chat_panels;
+mod view_panel_kind;
+mod view_panel_layout;
+mod view_panel_navigation;
+mod view_terminal_guard;
+
+use anyhow::Result;
+use clap::Parser;
+use clap::Subcommand;
+use cmd::handler::Handler as _;
+use cmd::prompt::Prompt;
+use paddler_bootstrap::shutdown_signal::register_shutdown_signals;
+use tokio_util::sync::CancellationToken;
+
+#[derive(Parser)]
+#[command(arg_required_else_help(true), version, about, long_about = None)]
+struct Cli {
+    #[command(subcommand)]
+    command: Commands,
+}
+
+#[derive(Subcommand)]
+enum Commands {
+    Prompt(Prompt),
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
+
+    let shutdown_signals = register_shutdown_signals()?;
+    let shutdown = CancellationToken::new();
+    let signal_shutdown = shutdown.clone();
+
+    tokio::spawn(async move {
+        if let Err(error) = shutdown_signals.wait().await {
+            log::error!("shutdown signal listener failed: {error}");
+            return;
+        }
+        signal_shutdown.cancel();
+    });
+
+    match Cli::parse().command {
+        Commands::Prompt(handler) => handler.handle(shutdown).await,
+    }
+}
diff --git a/paddler_client_cli/src/prompt_load_tool.rs b/paddler_client_cli/src/prompt_load_tool.rs
new file mode 100644
index 00000000..54a20786
--- /dev/null
+++ b/paddler_client_cli/src/prompt_load_tool.rs
@@ -0,0 +1,17 @@
+use std::fs::File;
+use std::path::Path;
+
+use anyhow::Context;
+use anyhow::Result;
+use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool;
+use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::raw_parameters_schema::RawParametersSchema;
+use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema;
+use paddler_types::validates::Validates;
+
+pub fn prompt_load_tool(path: &Path) -> Result<Tool<ValidatedParametersSchema>> {
+    let file = File::open(path).with_context(|| format!("opening tool file {}", path.display()))?;
+    let raw: Tool<RawParametersSchema> = serde_json::from_reader(file)
+        .with_context(|| format!("parsing tool file {}", path.display()))?;
+    raw.validate()
+        .with_context(|| format!("validating tool from {}", path.display()))
+}
diff --git a/paddler_client_cli/src/prompt_parse_inference_url.rs b/paddler_client_cli/src/prompt_parse_inference_url.rs
new file mode 100644
index 00000000..2678ed7c
--- /dev/null
+++ b/paddler_client_cli/src/prompt_parse_inference_url.rs
@@ -0,0 +1,6 @@
+use url::Url;
+
+pub fn prompt_parse_inference_url(input_addr: &str) -> Result<Url, String> {
+    Url::parse(&format!("http://{input_addr}"))
+        .map_err(|err| format!("invalid address '{input_addr}': {err}"))
+}
diff --git a/paddler_client_cli/src/prompt_thinking_mode.rs b/paddler_client_cli/src/prompt_thinking_mode.rs
new file mode 100644
index 00000000..573a2a64
--- /dev/null
+++ b/paddler_client_cli/src/prompt_thinking_mode.rs
@@ -0,0 +1,14 @@
+use clap::ValueEnum;
+
+#[derive(Clone, Copy, ValueEnum)]
+pub enum PromptThinkingMode {
+    On,
+    Off,
+}
+
+impl PromptThinkingMode {
+    #[must_use]
+    pub const fn is_enabled(self) -> bool {
+        matches!(self, Self::On)
+    }
+}
diff --git a/paddler_client_cli/src/stop_reason.rs b/paddler_client_cli/src/stop_reason.rs
new file mode 100644
index 00000000..391ca9f4
--- /dev/null
+++ b/paddler_client_cli/src/stop_reason.rs
@@ -0,0 +1,84 @@
+use std::fmt;
+
+use paddler_types::oversized_image_details::OversizedImageDetails;
+
+#[derive(Debug)]
+pub enum StopReason {
+    Completed,
+    ChatTemplateError(String),
+    GrammarIncompatibleWithThinking(String),
+    GrammarInitializationFailed(String),
+    GrammarRejectedModelOutput(String),
+    GrammarSyntaxError(String),
+    ImageDecodingFailed(String),
+    ImageExceedsBatchSize(OversizedImageDetails),
+    InferenceError { code: i32, description: String },
+    MultimodalNotSupported(String),
+    SamplerError(String),
+    Timeout,
+    TooManyBufferedRequests,
+    ToolCallParseFailed(String),
+    ToolCallValidationFailed(Vec<String>),
+    ToolSchemaInvalid(String),
+    WireStreamError(String),
+}
+
+impl fmt::Display for StopReason {
+    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Self::Completed => formatter.write_str("completed"),
+            Self::ChatTemplateError(detail) => {
+                write!(formatter, "chat template error: {detail}")
+            }
+            Self::GrammarIncompatibleWithThinking(detail) => {
+                write!(formatter, "grammar incompatible with thinking: {detail}")
+            }
+            Self::GrammarInitializationFailed(detail) => {
+                write!(formatter, "grammar initialization failed: {detail}")
+            }
+            Self::GrammarRejectedModelOutput(detail) => {
+                write!(formatter, "grammar rejected model output: {detail}")
+            }
+            Self::GrammarSyntaxError(detail) => {
+                write!(formatter, "grammar syntax error: {detail}")
+            }
+            Self::ImageDecodingFailed(detail) => {
+                write!(formatter, "image decoding failed: {detail}")
+            }
+            Self::ImageExceedsBatchSize(details) => {
+                write!(
+                    formatter,
+                    "image required {} tokens but agent n_batch is {}",
+                    details.image_tokens, details.n_batch,
+                )
+            }
+            Self::InferenceError { code, description } => {
+                write!(formatter, "inference error {code}: {description}")
+            }
+            Self::MultimodalNotSupported(detail) => {
+                write!(formatter, "multimodal input not supported: {detail}")
+            }
+            Self::SamplerError(detail) => write!(formatter, "sampler error: {detail}"),
+            Self::Timeout => formatter.write_str("balancer timed out the request"),
+            Self::TooManyBufferedRequests => {
+                formatter.write_str("balancer rejected the request: queue is full")
+            }
+            Self::ToolCallParseFailed(detail) => {
+                write!(formatter, "tool-call parse failed: {detail}")
+            }
+            Self::ToolCallValidationFailed(field_errors) => {
+                write!(
+                    formatter,
+                    "tool-call validation failed: {}",
+                    field_errors.join("; ")
+                )
+            }
+            Self::ToolSchemaInvalid(detail) => {
+                write!(formatter, "tool schema invalid: {detail}")
+            }
+            Self::WireStreamError(detail) => {
+                write!(formatter, "wire stream error: {detail}")
+            }
+        }
+    }
+}
diff --git a/paddler_client_cli/src/streaming_response.rs b/paddler_client_cli/src/streaming_response.rs
new file mode 100644
index 00000000..eb321e53
--- /dev/null
+++ b/paddler_client_cli/src/streaming_response.rs
@@ -0,0 +1,229 @@
+use llama_cpp_bindings_types::ParsedToolCall;
+use paddler_types::generated_token_result::GeneratedTokenResult;
+use paddler_types::generation_summary::GenerationSummary;
+use paddler_types::inference_client::Message;
+use paddler_types::inference_client::Response;
+use paddler_types::raw_tool_call_tokens::RawToolCallTokens;
+
+use crate::stop_reason::StopReason;
+
+#[derive(Debug, Default)]
+pub struct StreamingResponse {
+    pub thinking: Vec<String>,
+    pub response: Vec<String>,
+    pub tool_call_tokens: Vec<String>,
+    pub tool_calls: Vec<ParsedToolCall>,
+    pub undetermined: Vec<String>,
+    pub unrecognized_tool_call_format: Vec<RawToolCallTokens>,
+    pub summary: Option<GenerationSummary>,
+    pub stop_reason: Option<StopReason>,
+}
+
+impl StreamingResponse {
+    pub fn apply_message(&mut self, message: Message) {
+        match message {
+            Message::Error(envelope) => {
+                self.stop_reason = Some(StopReason::InferenceError {
+                    code: envelope.error.code,
+                    description: envelope.error.description,
+                });
+            }
+            Message::Response(envelope) => self.apply_response(envelope.response),
+        }
+    }
+
+    pub fn 
record_wire_error(&mut self, error: &anyhow::Error) { + self.stop_reason = Some(StopReason::WireStreamError(error.to_string())); + } + + pub const fn is_finished(&self) -> bool { + self.stop_reason.is_some() + } + + fn apply_response(&mut self, response: Response) { + match response { + Response::GeneratedToken(token_result) => self.apply_token_result(token_result), + Response::Timeout => { + self.stop_reason = Some(StopReason::Timeout); + } + Response::TooManyBufferedRequests => { + self.stop_reason = Some(StopReason::TooManyBufferedRequests); + } + Response::Embedding(_) => { + unreachable!("server sent an embedding response on a token-generation stream") + } + } + } + + fn apply_token_result(&mut self, token_result: GeneratedTokenResult) { + match token_result { + GeneratedTokenResult::ContentToken(piece) => self.response.push(piece), + GeneratedTokenResult::ReasoningToken(piece) => self.thinking.push(piece), + GeneratedTokenResult::UndeterminableToken(piece) => self.undetermined.push(piece), + GeneratedTokenResult::ToolCallToken(piece) => self.tool_call_tokens.push(piece), + GeneratedTokenResult::ToolCallParsed(calls) => { + self.tool_calls.extend(calls); + } + GeneratedTokenResult::Done(summary) => { + self.summary = Some(summary); + self.stop_reason = Some(StopReason::Completed); + } + GeneratedTokenResult::ChatTemplateError(detail) => { + self.stop_reason = Some(StopReason::ChatTemplateError(detail)); + } + GeneratedTokenResult::GrammarIncompatibleWithThinking(detail) => { + self.stop_reason = Some(StopReason::GrammarIncompatibleWithThinking(detail)); + } + GeneratedTokenResult::GrammarInitializationFailed(detail) => { + self.stop_reason = Some(StopReason::GrammarInitializationFailed(detail)); + } + GeneratedTokenResult::GrammarRejectedModelOutput(detail) => { + self.stop_reason = Some(StopReason::GrammarRejectedModelOutput(detail)); + } + GeneratedTokenResult::GrammarSyntaxError(detail) => { + self.stop_reason = Some(StopReason::GrammarSyntaxError(detail)); + } + GeneratedTokenResult::ImageDecodingFailed(detail) => { + self.stop_reason = Some(StopReason::ImageDecodingFailed(detail)); + } + GeneratedTokenResult::ImageExceedsBatchSize(details) => { + self.stop_reason = Some(StopReason::ImageExceedsBatchSize(details)); + } + GeneratedTokenResult::MultimodalNotSupported(detail) => { + self.stop_reason = Some(StopReason::MultimodalNotSupported(detail)); + } + GeneratedTokenResult::SamplerError(detail) => { + self.stop_reason = Some(StopReason::SamplerError(detail)); + } + GeneratedTokenResult::ToolCallParseFailed(detail) => { + self.stop_reason = Some(StopReason::ToolCallParseFailed(detail)); + } + GeneratedTokenResult::ToolCallValidationFailed(field_errors) => { + self.stop_reason = Some(StopReason::ToolCallValidationFailed(field_errors)); + } + GeneratedTokenResult::ToolSchemaInvalid(detail) => { + self.stop_reason = Some(StopReason::ToolSchemaInvalid(detail)); + } + GeneratedTokenResult::UnrecognizedToolCallFormat(raw) => { + self.unrecognized_tool_call_format.push(raw); + } + } + } +} + +#[cfg(test)] +mod tests { + use anyhow::anyhow; + use paddler_types::jsonrpc::Error; + use paddler_types::jsonrpc::ErrorEnvelope; + use paddler_types::jsonrpc::ResponseEnvelope; + + use super::*; + + fn token_message(token_result: GeneratedTokenResult) -> Message { + Message::Response(ResponseEnvelope { + generated_by: None, + request_id: "test-request".to_owned(), + response: Response::GeneratedToken(token_result), + }) + } + + #[test] + fn content_token_appended_to_response_stream() { + let mut 
state = StreamingResponse::default(); + + state.apply_message(token_message(GeneratedTokenResult::ContentToken( + "hello ".to_owned(), + ))); + state.apply_message(token_message(GeneratedTokenResult::ContentToken( + "world".to_owned(), + ))); + + assert_eq!( + state.response, + vec!["hello ".to_owned(), "world".to_owned()] + ); + assert!(state.thinking.is_empty()); + assert!(state.undetermined.is_empty()); + assert!(!state.is_finished()); + } + + #[test] + fn raw_tool_call_token_appended_to_token_stream() { + let mut state = StreamingResponse::default(); + + state.apply_message(token_message(GeneratedTokenResult::ToolCallToken( + "{\"name\":".to_owned(), + ))); + state.apply_message(token_message(GeneratedTokenResult::ToolCallToken( + "\"calc\"}".to_owned(), + ))); + + assert_eq!( + state.tool_call_tokens, + vec!["{\"name\":".to_owned(), "\"calc\"}".to_owned()] + ); + assert!(state.tool_calls.is_empty()); + } + + #[test] + fn tool_call_parsed_extends_calls_without_dropping_token_stream() { + let mut state = StreamingResponse::default(); + state.apply_message(token_message(GeneratedTokenResult::ToolCallToken( + "{\"name\":\"calc\"}".to_owned(), + ))); + let parsed = vec![ParsedToolCall::default()]; + + state.apply_message(token_message(GeneratedTokenResult::ToolCallParsed( + parsed.clone(), + ))); + + assert_eq!(state.tool_calls, parsed); + assert_eq!( + state.tool_call_tokens, + vec!["{\"name\":\"calc\"}".to_owned()] + ); + } + + #[test] + fn done_records_summary_and_completed_stop_reason() { + let mut state = StreamingResponse::default(); + let summary = GenerationSummary::default(); + + state.apply_message(token_message(GeneratedTokenResult::Done(summary))); + + assert!(state.summary.is_some()); + assert!(matches!(state.stop_reason, Some(StopReason::Completed))); + assert!(state.is_finished()); + } + + #[test] + fn message_error_sets_inference_error_stop_reason() { + let mut state = StreamingResponse::default(); + + state.apply_message(Message::Error(ErrorEnvelope { + request_id: "test-request".to_owned(), + error: Error { + code: 503, + description: "agent unavailable".to_owned(), + }, + })); + + assert!(matches!( + state.stop_reason, + Some(StopReason::InferenceError { code: 503, .. 
}) + )); + } + + #[test] + fn wire_error_sets_wire_stream_error_stop_reason() { + let mut state = StreamingResponse::default(); + + state.record_wire_error(&anyhow!("connection reset")); + + assert!(matches!( + state.stop_reason, + Some(StopReason::WireStreamError(ref message)) if message.contains("connection reset") + )); + } +} diff --git a/paddler_client_cli/src/view_chat_panels.rs b/paddler_client_cli/src/view_chat_panels.rs new file mode 100644 index 00000000..80e99892 --- /dev/null +++ b/paddler_client_cli/src/view_chat_panels.rs @@ -0,0 +1,382 @@ +use llama_cpp_bindings_types::ParsedToolCall; +use llama_cpp_bindings_types::ToolCallArguments; +use paddler_types::generation_summary::GenerationSummary; +use ratatui::Frame; +use ratatui::layout::Margin; +use ratatui::layout::Rect; +use ratatui::style::Color; +use ratatui::style::Style; +use ratatui::text::Line; +use ratatui::text::Span; +use ratatui::text::Text; +use ratatui::widgets::Block; +use ratatui::widgets::Paragraph; +use ratatui::widgets::Scrollbar; +use ratatui::widgets::ScrollbarOrientation; +use ratatui::widgets::ScrollbarState; +use ratatui::widgets::Wrap; + +use crate::streaming_response::StreamingResponse; +use crate::view_panel_kind::ViewPanelKind; +use crate::view_panel_layout::ViewPanelLayout; +use crate::view_panel_navigation::ViewPanelNavigation; + +const TOKEN_PALETTE: [Color; 6] = [ + Color::LightCyan, + Color::LightYellow, + Color::LightMagenta, + Color::LightGreen, + Color::LightBlue, + Color::LightRed, +]; + +pub fn view_chat_panels( + state: &StreamingResponse, + navigation: &mut ViewPanelNavigation, + layout: &ViewPanelLayout, + frame: &mut Frame<'_>, +) { + render_panel_text( + frame, + layout.thinking, + ViewPanelKind::Thinking, + Text::from(build_colored_token_lines(&state.thinking)), + navigation, + ); + render_panel_text( + frame, + layout.response, + ViewPanelKind::Response, + Text::from(build_colored_token_lines(&state.response)), + navigation, + ); + render_panel_text( + frame, + layout.tool_calls, + ViewPanelKind::ToolCalls, + Text::from(build_tool_calls_lines( + &state.tool_call_tokens, + &state.tool_calls, + Block::bordered().inner(layout.tool_calls).width, + )), + navigation, + ); + render_panel_text( + frame, + layout.undetermined, + ViewPanelKind::Undetermined, + Text::from(build_colored_token_lines(&state.undetermined)), + navigation, + ); + render_status_bar(frame, layout.status_bar, state); +} + +fn render_panel_text( + frame: &mut Frame<'_>, + area: Rect, + panel: ViewPanelKind, + text: Text<'_>, + navigation: &mut ViewPanelNavigation, +) { + let title = if navigation.focused() == panel { + format!("[ {} ]", panel.label()) + } else { + format!(" {} ", panel.label()) + }; + let block = Block::bordered().title(title); + let inner = block.inner(area); + let visible_rows = count_text_rows(&text, inner.width); + navigation.settle(panel, visible_rows.into(), inner.height.into()); + let position = navigation.position(panel); + let scroll_offset = u16::try_from(position).unwrap_or(u16::MAX); + + let paragraph = Paragraph::new(text) + .wrap(Wrap { trim: false }) + .scroll((scroll_offset, 0)) + .block(block); + frame.render_widget(paragraph, area); + + if visible_rows > inner.height { + let mut scrollbar_state = ScrollbarState::new(visible_rows.into()) + .position(position) + .viewport_content_length(inner.height.into()); + let scrollbar = Scrollbar::new(ScrollbarOrientation::VerticalRight) + .thumb_symbol("┃") + .track_symbol(Some("│")) + .begin_symbol(None) + .end_symbol(None); + 
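// The one-row vertical margin keeps the scrollbar clear of the block's corner borders.
+        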
frame.render_stateful_widget(
+            scrollbar,
+            area.inner(Margin {
+                vertical: 1,
+                horizontal: 0,
+            }),
+            &mut scrollbar_state,
+        );
+    }
+}
+
+fn build_colored_token_lines(tokens: &[String]) -> Vec<Line<'static>> {
+    let mut lines: Vec<Line<'static>> = Vec::new();
+    let mut current_line: Vec<Span<'static>> = Vec::new();
+    let mut had_token = false;
+
+    for (token_index, token) in tokens.iter().enumerate() {
+        if token.is_empty() {
+            continue;
+        }
+        had_token = true;
+        let is_whitespace_only = token.chars().all(char::is_whitespace);
+        let style = if is_whitespace_only {
+            Style::default()
+                .bg(palette_color(token_index))
+                .fg(Color::Black)
+        } else {
+            Style::default().fg(palette_color(token_index))
+        };
+        for (piece_index, piece) in token.split('\n').enumerate() {
+            if piece_index > 0 {
+                if is_whitespace_only {
+                    current_line.push(Span::styled("↵", style));
+                }
+                lines.push(Line::from(std::mem::take(&mut current_line)));
+            }
+            if !piece.is_empty() {
+                let rendered = if is_whitespace_only {
+                    piece.replace('\t', "→")
+                } else {
+                    piece.to_owned()
+                };
+                if !rendered.is_empty() {
+                    current_line.push(Span::styled(rendered, style));
+                }
+            }
+        }
+    }
+
+    if had_token {
+        lines.push(Line::from(current_line));
+    }
+
+    lines
+}
+
+fn build_tool_calls_lines(
+    token_stream: &[String],
+    parsed_calls: &[ParsedToolCall],
+    inner_width: u16,
+) -> Vec<Line<'static>> {
+    let mut lines = build_colored_token_lines(token_stream);
+    let has_tokens = !token_stream.is_empty();
+    let has_parsed = !parsed_calls.is_empty();
+    if has_tokens && has_parsed {
+        lines.push(divider_line(inner_width));
+    }
+    lines.extend(parsed_call_lines(parsed_calls));
+    lines
+}
+
+fn divider_line(width: u16) -> Line<'static> {
+    let label = " parsed ";
+    let label_width = u16::try_from(label.chars().count()).unwrap_or(u16::MAX);
+    let total = width.max(label_width.saturating_add(2));
+    let dash_count = total.saturating_sub(label_width);
+    let left = dash_count / 2;
+    let right = dash_count - left;
+    let mut text = String::new();
+    for _ in 0..left {
+        text.push('─');
+    }
+    text.push_str(label);
+    for _ in 0..right {
+        text.push('─');
+    }
+    Line::from(Span::styled(text, Style::default().fg(Color::Gray)))
+}
+
+fn parsed_call_lines(calls: &[ParsedToolCall]) -> Vec<Line<'static>> {
+    let mut lines = Vec::new();
+    for call in calls {
+        lines.push(Line::raw(call.name.clone()));
+        match &call.arguments {
+            ToolCallArguments::ValidJson(value) => match serde_json::to_string_pretty(value) {
+                Ok(formatted) => {
+                    for inner_line in formatted.lines() {
+                        lines.push(Line::raw(format!(" {inner_line}")));
+                    }
+                }
+                Err(format_error) => {
+                    log::error!(
+                        "failed to pretty-print tool-call arguments for {name}: {format_error}",
+                        name = call.name
+                    );
+                    lines.push(Line::raw(format!(" {value}")));
+                }
+            },
+            ToolCallArguments::InvalidJson(raw) => {
+                lines.push(Line::raw(format!(" invalid JSON: {raw}")));
+            }
+        }
+    }
+    lines
+}
+
+const fn palette_color(token_index: usize) -> Color {
+    TOKEN_PALETTE[token_index % TOKEN_PALETTE.len()]
+}
+
+fn count_text_rows(text: &Text<'_>, width: u16) -> u16 {
+    if width == 0 {
+        return 0;
+    }
+    let mut total: u16 = 0;
+    for line in &text.lines {
+        let chars_count: usize = line
+            .spans
+            .iter()
+            .map(|span| span.content.chars().count())
+            .sum();
+        let chars = u16::try_from(chars_count).unwrap_or(u16::MAX);
+        let rows = chars.div_ceil(width).max(1);
+        total = total.saturating_add(rows);
+    }
+    total
+}
+
+fn render_status_bar(frame: &mut Frame<'_>, area: Rect, state: &StreamingResponse) {
+    let text = match (&state.stop_reason, &state.summary) {
+        (None, _) => {
+            "generating… · tab/shift-tab focus · ↑↓ pgup/pgdn home/end scroll · q quit".to_owned()
+        }
+        (Some(_), Some(summary)) => format_completion_status(summary),
+        (Some(reason), None) => format!("stopped — {reason} · press q to quit"),
+    };
+    frame.render_widget(Paragraph::new(text), area);
+}
+
+fn format_completion_status(summary: &GenerationSummary) -> String {
+    let usage = summary.usage;
+    format!(
+        "done · response {response} · thinking {thinking} · tools {tools} · undet {undet} · prompt {prompt} · total {total} · press q to quit",
+        response = usage.content_tokens,
+        thinking = usage.reasoning_tokens,
+        tools = usage.tool_call_tokens,
+        undet = usage.undeterminable_tokens,
+        prompt = usage.prompt_tokens,
+        total = usage.total_tokens(),
+    )
+}
+
+#[cfg(test)]
+mod tests {
+    use anyhow::Result;
+    use paddler_types::generation_summary::GenerationSummary;
+    use ratatui::Terminal;
+    use ratatui::backend::TestBackend;
+    use ratatui::buffer::Buffer;
+
+    use super::*;
+    use crate::stop_reason::StopReason;
+
+    fn render_to_string(state: &StreamingResponse, width: u16, height: u16) -> Result<String> {
+        let mut navigation = ViewPanelNavigation::default();
+        let mut terminal = Terminal::new(TestBackend::new(width, height))?;
+        terminal.draw(|frame| {
+            let layout = ViewPanelLayout::compute(frame.area());
+            view_chat_panels(state, &mut navigation, &layout, frame);
+        })?;
+        Ok(buffer_text(terminal.backend().buffer()))
+    }
+
+    fn buffer_text(buffer: &Buffer) -> String {
+        let area = buffer.area;
+        let mut output = String::with_capacity((area.width as usize + 1) * area.height as usize);
+        for y in 0..area.height {
+            for x in 0..area.width {
+                output.push_str(buffer[(x, y)].symbol());
+            }
+            output.push('\n');
+        }
+        output
+    }
+
+    #[test]
+    fn empty_state_shows_all_four_panels_and_generating_status() -> Result<()> {
+        let state = StreamingResponse::default();
+
+        let rendered = render_to_string(&state, 100, 30)?;
+
+        assert!(rendered.contains("Thinking"));
+        assert!(rendered.contains("Response"));
+        assert!(rendered.contains("Tool Calls"));
+        assert!(rendered.contains("Undetermined"));
+        assert!(rendered.contains("generating"));
+        assert!(!rendered.contains("done"));
+        Ok(())
+    }
+
+    #[test]
+    fn focused_panel_title_uses_brackets() -> Result<()> {
+        let state = StreamingResponse::default();
+
+        let rendered = render_to_string(&state, 100, 30)?;
+
+        assert!(rendered.contains("[ Response ]"));
+        assert!(rendered.contains(" Thinking "));
+        Ok(())
+    }
+
+    #[test]
+    fn response_buffer_text_is_visible() -> Result<()> {
+        let mut state = StreamingResponse::default();
+        state.response.push("hello world".to_owned());
+
+        let rendered = render_to_string(&state, 80, 30)?;
+
+        assert!(rendered.contains("hello world"));
+        Ok(())
+    }
+
+    #[test]
+    fn completed_state_shows_summary_and_quit_hint() -> Result<()> {
+        let state = StreamingResponse {
+            summary: Some(GenerationSummary::default()),
+            stop_reason: Some(StopReason::Completed),
+            ..StreamingResponse::default()
+        };
+
+        let rendered = render_to_string(&state, 140, 30)?;
+
+        assert!(rendered.contains("done"));
+        assert!(rendered.contains("press q to quit"));
+        assert!(!rendered.contains("generating"));
+        Ok(())
+    }
+
+    #[test]
+    fn whitespace_only_newline_token_renders_return_marker() -> Result<()> {
+        let mut state = StreamingResponse::default();
+        state.response.push("hello".to_owned());
+        state.response.push("\n".to_owned());
+        state.response.push("world".to_owned());
+
+        let rendered = render_to_string(&state, 80, 30)?;
+
+        assert!(rendered.contains("↵"));
+        Ok(())
+    }
+
+    #[test]
+    fn tool_calls_panel_shows_divider_when_tokens_and_parsed_both_present() -> Result<()> {
+        let mut state = StreamingResponse::default();
+        state
+            .tool_call_tokens
+            .push("{\"name\":\"calc\"}".to_owned());
+        state.tool_calls.push(ParsedToolCall::default());
+
+        let rendered = render_to_string(&state, 120, 30)?;
+
+        assert!(rendered.contains("parsed"));
+        Ok(())
+    }
+}
diff --git a/paddler_client_cli/src/view_panel_kind.rs b/paddler_client_cli/src/view_panel_kind.rs
new file mode 100644
index 00000000..0a8eac0e
--- /dev/null
+++ b/paddler_client_cli/src/view_panel_kind.rs
@@ -0,0 +1,19 @@
+#[repr(u8)]
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+pub enum ViewPanelKind {
+    Thinking = 0,
+    Response = 1,
+    ToolCalls = 2,
+    Undetermined = 3,
+}
+
+impl ViewPanelKind {
+    pub const fn label(self) -> &'static str {
+        match self {
+            Self::Thinking => "Thinking",
+            Self::Response => "Response",
+            Self::ToolCalls => "Tool Calls",
+            Self::Undetermined => "Undetermined",
+        }
+    }
+}
diff --git a/paddler_client_cli/src/view_panel_layout.rs b/paddler_client_cli/src/view_panel_layout.rs
new file mode 100644
index 00000000..0873dda1
--- /dev/null
+++ b/paddler_client_cli/src/view_panel_layout.rs
@@ -0,0 +1,61 @@
+use ratatui::layout::Constraint;
+use ratatui::layout::Layout;
+use ratatui::layout::Position;
+use ratatui::layout::Rect;
+
+use crate::view_panel_kind::ViewPanelKind;
+
+const STATUS_BAR_HEIGHT: u16 = 1;
+
+pub struct ViewPanelLayout {
+    pub thinking: Rect,
+    pub response: Rect,
+    pub tool_calls: Rect,
+    pub undetermined: Rect,
+    pub status_bar: Rect,
+}
+
+impl ViewPanelLayout {
+    pub fn compute(area: Rect) -> Self {
+        let outer = Layout::vertical([Constraint::Min(0), Constraint::Length(STATUS_BAR_HEIGHT)])
+            .split(area);
+        let rows = Layout::vertical([Constraint::Percentage(50), Constraint::Percentage(50)])
+            .split(outer[0]);
+        let top = Layout::horizontal([Constraint::Percentage(50), Constraint::Percentage(50)])
+            .split(rows[0]);
+        let bottom = Layout::horizontal([Constraint::Percentage(50), Constraint::Percentage(50)])
+            .split(rows[1]);
+        Self {
+            thinking: top[0],
+            response: top[1],
+            tool_calls: bottom[0],
+            undetermined: bottom[1],
+            status_bar: outer[1],
+        }
+    }
+
+    pub const fn rect_for(&self, panel: ViewPanelKind) -> Rect {
+        match panel {
+            ViewPanelKind::Thinking => self.thinking,
+            ViewPanelKind::Response => self.response,
+            ViewPanelKind::ToolCalls => self.tool_calls,
+            ViewPanelKind::Undetermined => self.undetermined,
+        }
+    }
+
+    pub const fn viewport_rows(&self, panel: ViewPanelKind) -> u16 {
+        self.rect_for(panel).height.saturating_sub(2)
+    }
+
+    pub fn panel_at(&self, column: u16, row: u16) -> Option<ViewPanelKind> {
+        let position = Position { x: column, y: row };
+        [
+            ViewPanelKind::Thinking,
+            ViewPanelKind::Response,
+            ViewPanelKind::ToolCalls,
+            ViewPanelKind::Undetermined,
+        ]
+        .into_iter()
+        .find(|panel| self.rect_for(*panel).contains(position))
+    }
+}
diff --git a/paddler_client_cli/src/view_panel_navigation.rs b/paddler_client_cli/src/view_panel_navigation.rs
new file mode 100644
index 00000000..62416c7b
--- /dev/null
+++ b/paddler_client_cli/src/view_panel_navigation.rs
@@ -0,0 +1,185 @@
+use crate::view_panel_kind::ViewPanelKind;
+
+const PANEL_COUNT: usize = 4;
+
+pub struct ViewPanelNavigation {
+    focused: ViewPanelKind,
+    views: [PanelView; PANEL_COUNT],
+}
+
+#[derive(Clone, Copy)]
+struct PanelView {
+    position: usize,
+    follow_bottom: bool,
+}
+
+impl Default for PanelView {
+    fn default() -> Self {
+        Self {
+            position: 0,
+            follow_bottom: true,
+        }
+    }
+}
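+
+// `follow_bottom` defaults to `true`, so each panel tails freshly streamed
+// output until the user scrolls up; `settle` re-engages it once the view
+// reaches the bottom again.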
+ +impl Default for ViewPanelNavigation { + fn default() -> Self { + Self { + focused: ViewPanelKind::Response, + views: [PanelView::default(); PANEL_COUNT], + } + } +} + +impl ViewPanelNavigation { + pub const fn focused(&self) -> ViewPanelKind { + self.focused + } + + pub const fn focus(&mut self, panel: ViewPanelKind) { + self.focused = panel; + } + + pub const fn cycle_focus_forward(&mut self) { + self.focused = match self.focused { + ViewPanelKind::Thinking => ViewPanelKind::Response, + ViewPanelKind::Response => ViewPanelKind::ToolCalls, + ViewPanelKind::ToolCalls => ViewPanelKind::Undetermined, + ViewPanelKind::Undetermined => ViewPanelKind::Thinking, + }; + } + + pub const fn cycle_focus_backward(&mut self) { + self.focused = match self.focused { + ViewPanelKind::Thinking => ViewPanelKind::Undetermined, + ViewPanelKind::Response => ViewPanelKind::Thinking, + ViewPanelKind::ToolCalls => ViewPanelKind::Response, + ViewPanelKind::Undetermined => ViewPanelKind::ToolCalls, + }; + } + + pub fn scroll_up(&mut self, panel: ViewPanelKind, lines: u16) { + let view = &mut self.views[panel as usize]; + view.follow_bottom = false; + view.position = view.position.saturating_sub(lines.into()); + } + + pub fn scroll_down(&mut self, panel: ViewPanelKind, lines: u16) { + let view = &mut self.views[panel as usize]; + view.position = view.position.saturating_add(lines.into()); + } + + pub const fn jump_to_top(&mut self, panel: ViewPanelKind) { + let view = &mut self.views[panel as usize]; + view.follow_bottom = false; + view.position = 0; + } + + pub const fn jump_to_bottom(&mut self, panel: ViewPanelKind) { + self.views[panel as usize].follow_bottom = true; + } + + pub const fn settle( + &mut self, + panel: ViewPanelKind, + content_rows: usize, + viewport_rows: usize, + ) { + let view = &mut self.views[panel as usize]; + let max_position = content_rows.saturating_sub(viewport_rows); + if view.follow_bottom || view.position >= max_position { + view.position = max_position; + view.follow_bottom = true; + } + } + + pub const fn position(&self, panel: ViewPanelKind) -> usize { + self.views[panel as usize].position + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn defaults_focus_response_and_follow_bottom() { + let mut nav = ViewPanelNavigation::default(); + + assert_eq!(nav.focused(), ViewPanelKind::Response); + nav.settle(ViewPanelKind::Response, 100, 10); + assert_eq!(nav.position(ViewPanelKind::Response), 90); + } + + #[test] + fn scroll_up_disengages_follow_and_decrements_position() { + let mut nav = ViewPanelNavigation::default(); + + nav.scroll_up(ViewPanelKind::Response, 5); + + nav.settle(ViewPanelKind::Response, 100, 10); + assert_eq!(nav.position(ViewPanelKind::Response), 0); + } + + #[test] + fn scroll_down_after_scroll_up_advances_within_content() { + let mut nav = ViewPanelNavigation::default(); + nav.scroll_up(ViewPanelKind::Response, 50); + + nav.scroll_down(ViewPanelKind::Response, 10); + + nav.settle(ViewPanelKind::Response, 100, 10); + assert_eq!(nav.position(ViewPanelKind::Response), 10); + } + + #[test] + fn jump_to_bottom_re_engages_auto_follow() { + let mut nav = ViewPanelNavigation::default(); + nav.scroll_up(ViewPanelKind::Response, 50); + + nav.jump_to_bottom(ViewPanelKind::Response); + + nav.settle(ViewPanelKind::Response, 200, 10); + assert_eq!(nav.position(ViewPanelKind::Response), 190); + } + + #[test] + fn cycle_focus_forward_walks_panels_in_reading_order() { + let mut nav = ViewPanelNavigation::default(); + + nav.cycle_focus_forward(); + 
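// Default focus is Response, so the first forward step lands on ToolCalls.
+        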
assert_eq!(nav.focused(), ViewPanelKind::ToolCalls);
+        nav.cycle_focus_forward();
+        assert_eq!(nav.focused(), ViewPanelKind::Undetermined);
+        nav.cycle_focus_forward();
+        assert_eq!(nav.focused(), ViewPanelKind::Thinking);
+        nav.cycle_focus_forward();
+        assert_eq!(nav.focused(), ViewPanelKind::Response);
+    }
+
+    #[test]
+    fn scrolling_back_to_bottom_re_engages_auto_follow_for_subsequent_growth() {
+        let mut nav = ViewPanelNavigation::default();
+        nav.settle(ViewPanelKind::Response, 100, 10);
+
+        nav.scroll_up(ViewPanelKind::Response, 5);
+        nav.settle(ViewPanelKind::Response, 100, 10);
+
+        nav.scroll_down(ViewPanelKind::Response, 10);
+        nav.settle(ViewPanelKind::Response, 100, 10);
+
+        nav.settle(ViewPanelKind::Response, 110, 10);
+
+        assert_eq!(nav.position(ViewPanelKind::Response), 100);
+    }
+
+    #[test]
+    fn position_is_clamped_when_content_shorter_than_stored_offset() {
+        let mut nav = ViewPanelNavigation::default();
+        nav.scroll_up(ViewPanelKind::Response, 0);
+        nav.scroll_down(ViewPanelKind::Response, 80);
+
+        nav.settle(ViewPanelKind::Response, 30, 10);
+        assert_eq!(nav.position(ViewPanelKind::Response), 20);
+    }
+}
diff --git a/paddler_client_cli/src/view_terminal_guard.rs b/paddler_client_cli/src/view_terminal_guard.rs
new file mode 100644
index 00000000..d23f6bba
--- /dev/null
+++ b/paddler_client_cli/src/view_terminal_guard.rs
@@ -0,0 +1,57 @@
+use std::io;
+
+use anyhow::Context;
+use anyhow::Result;
+use crossterm::ExecutableCommand;
+use crossterm::event::DisableMouseCapture;
+use crossterm::event::EnableMouseCapture;
+use crossterm::terminal::EnterAlternateScreen;
+use crossterm::terminal::LeaveAlternateScreen;
+use crossterm::terminal::disable_raw_mode;
+use crossterm::terminal::enable_raw_mode;
+
+pub struct ViewTerminalGuard;
+
+impl ViewTerminalGuard {
+    pub fn enter() -> Result<Self> {
+        enable_raw_mode().context("enabling raw mode")?;
+        if let Err(enter_alt_screen_error) = io::stdout().execute(EnterAlternateScreen) {
+            if let Err(rollback_error) = disable_raw_mode() {
+                log::error!(
+                    "failed to disable raw mode while rolling back alt-screen entry: {rollback_error}"
+                );
+            }
+            return Err(
+                anyhow::Error::from(enter_alt_screen_error).context("entering alternate screen")
+            );
+        }
+        if let Err(enable_mouse_error) = io::stdout().execute(EnableMouseCapture) {
+            if let Err(leave_alt_screen_error) = io::stdout().execute(LeaveAlternateScreen) {
+                log::error!(
+                    "failed to leave alt screen while rolling back mouse-capture: {leave_alt_screen_error}"
+                );
+            }
+            if let Err(rollback_error) = disable_raw_mode() {
+                log::error!(
+                    "failed to disable raw mode while rolling back mouse-capture: {rollback_error}"
+                );
+            }
+            return Err(anyhow::Error::from(enable_mouse_error).context("enabling mouse capture"));
+        }
+        Ok(Self)
+    }
+}
+
+impl Drop for ViewTerminalGuard {
+    fn drop(&mut self) {
+        if let Err(disable_mouse_error) = io::stdout().execute(DisableMouseCapture) {
+            log::error!("failed to disable mouse capture: {disable_mouse_error}");
+        }
+        if let Err(leave_alt_screen_error) = io::stdout().execute(LeaveAlternateScreen) {
+            log::error!("failed to leave alternate screen: {leave_alt_screen_error}");
+        }
+        if let Err(disable_raw_mode_error) = disable_raw_mode() {
+            log::error!("failed to disable raw mode: {disable_raw_mode_error}");
+        }
+    }
+}
diff --git a/paddler_client_javascript/.gitignore b/paddler_client_javascript/.gitignore
new file mode 100644
index 00000000..10ce1293
--- /dev/null
+++ b/paddler_client_javascript/.gitignore
@@ -0,0 +1,6 @@
+node_modules/
+dist/
+*.tsbuildinfo
+*.log
+coverage/ +.nyc_output/ diff --git a/paddler_client_javascript/LICENSE b/paddler_client_javascript/LICENSE new file mode 100644 index 00000000..261eeb9e --- /dev/null +++ b/paddler_client_javascript/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/paddler_client_javascript/Makefile b/paddler_client_javascript/Makefile new file mode 100644 index 00000000..97147eae --- /dev/null +++ b/paddler_client_javascript/Makefile @@ -0,0 +1,10 @@ +.PHONY: build +build: + npm run build + +.PHONY: test +test: + npm test + +.PHONY: check +check: test diff --git a/paddler_client_javascript/README.md b/paddler_client_javascript/README.md new file mode 100644 index 00000000..1f1e23d0 --- /dev/null +++ b/paddler_client_javascript/README.md @@ -0,0 +1,88 @@ +# @intentee/paddler-client + +JavaScript/TypeScript client for the [Paddler](https://github.com/intentee/paddler) LLM load balancer. + +## Install + +```sh +npm install @intentee/paddler-client rxjs zod +``` + +`rxjs` and `zod` are peer dependencies. + +## Quick start + +### WebSocket inference (multiplexed, request-id-keyed) + +```ts +import { inferenceSocketClient } from "@intentee/paddler-client"; + +const webSocket = new WebSocket("ws://localhost:8061/api/v1/inference_socket"); + +const { continueConversation } = inferenceSocketClient({ webSocket }); + +continueConversation({ + enableThinking: true, + messages: [ + { role: "system", content: "You are a helpful assistant." }, + { role: "user", content: "Hello" }, + ], +}).subscribe((chunk) => { + if (chunk.error) { + console.error(chunk.error); + return; + } + if (chunk.done) { + console.log("done", chunk.summary); + return; + } + if (chunk.token !== null) { + process.stdout.write(chunk.token); + } +}); +``` + +### HTTP NDJSON streaming + +```ts +import { streamHttpNdjson, InferenceServiceGenerateTokensResponseSchema } from "@intentee/paddler-client"; + +const controller = new AbortController(); + +streamHttpNdjson({ + url: new URL("http://localhost:8061/api/v1/continue_from_conversation_history"), + body: { add_generation_prompt: true, conversation_history: [...], max_tokens: 200 }, + signal: controller.signal, + schema: InferenceServiceGenerateTokensResponseSchema, +}).subscribe(/* ... 
*/); +``` + +### SSE management stream + +```ts +import { streamEventSource, AgentsResponseSchema, matchEventSourceUpdateState } from "@intentee/paddler-client"; + +streamEventSource({ + url: new URL("http://localhost:8062/api/v1/agents/stream"), + schema: AgentsResponseSchema, +}).subscribe((state) => { + matchEventSourceUpdateState(state, { + initial: () => console.log("connecting"), + connected: () => console.log("connected"), + dataSnapshot: ({ data }) => console.log("agents", data.agents.length), + connectionError: () => console.error("connection lost"), + deserializationError: () => console.error("invalid payload"), + }); +}); +``` + +## Coverage + +- Transport: WebSocket (multiplexed), HTTP NDJSON, HTTP JSON, Server-Sent Events +- Schemas: every Paddler wire-format type (validated via zod) +- State machines + exhaustive matchers for connection/stream/fetch states +- Specialized error types per failure mode + +## License + +Apache-2.0 diff --git a/paddler_client_javascript/package.json b/paddler_client_javascript/package.json new file mode 100644 index 00000000..64d92f4e --- /dev/null +++ b/paddler_client_javascript/package.json @@ -0,0 +1,41 @@ +{ + "name": "@intentee/paddler-client", + "version": "3.1.2", + "description": "JavaScript/TypeScript client for the Paddler LLM load balancer", + "license": "Apache-2.0", + "homepage": "https://github.com/intentee/paddler", + "repository": "https://github.com/intentee/paddler", + "type": "module", + "exports": { + "./*": "./src/*.ts" + }, + "publishConfig": { + "exports": { + "./*": { + "types": "./dist/*.d.ts", + "default": "./dist/*.js" + } + } + }, + "files": ["dist", "README.md", "LICENSE"], + "scripts": { + "build": "tsc -p .", + "test": "tsc -p tsconfig.test.json && tsx --test", + "prepack": "tsc -p ." + }, + "peerDependencies": { + "rxjs": "^7.8", + "zod": "^4" + }, + "dependencies": { + "nanoid": "^5.1.11" + }, + "devDependencies": { + "@types/node": "^22", + "tsx": "^4.21.0", + "typescript": "^5.9.2" + }, + "engines": { + "node": ">=22" + } +} diff --git a/paddler_client_javascript/shell.nix b/paddler_client_javascript/shell.nix new file mode 100644 index 00000000..1ebe6e4a --- /dev/null +++ b/paddler_client_javascript/shell.nix @@ -0,0 +1,7 @@ +{ pkgs ? 
import <nixpkgs> {} }:
+
+pkgs.mkShell {
+  nativeBuildInputs = with pkgs; [
+    nodejs_22
+  ];
+}
diff --git a/paddler_client_javascript/src/ConnectionDroppedError.ts b/paddler_client_javascript/src/ConnectionDroppedError.ts
new file mode 100644
index 00000000..b8e315f8
--- /dev/null
+++ b/paddler_client_javascript/src/ConnectionDroppedError.ts
@@ -0,0 +1,9 @@
+import { PaddlerError } from "./PaddlerError";
+
+export class ConnectionDroppedError extends PaddlerError {
+  override name = "ConnectionDroppedError";
+
+  constructor(public readonly requestId: string) {
+    super(`Connection dropped while streaming request ${requestId}`);
+  }
+}
diff --git a/paddler_client_javascript/src/EventSourceConnectedState.ts b/paddler_client_javascript/src/EventSourceConnectedState.ts
new file mode 100644
index 00000000..25b531d2
--- /dev/null
+++ b/paddler_client_javascript/src/EventSourceConnectedState.ts
@@ -0,0 +1,18 @@
+export type EventSourceConnectedState = {
+  data: undefined;
+  isConnected: true;
+  isConnectionError: false;
+  isDeserializationError: false;
+  isInitial: false;
+  isOk: false;
+};
+
+export const eventSourceConnectedState: EventSourceConnectedState =
+  Object.freeze({
+    data: undefined,
+    isConnected: true,
+    isConnectionError: false,
+    isDeserializationError: false,
+    isInitial: false,
+    isOk: false,
+  });
diff --git a/paddler_client_javascript/src/EventSourceConnectionErrorState.ts b/paddler_client_javascript/src/EventSourceConnectionErrorState.ts
new file mode 100644
index 00000000..6891fb63
--- /dev/null
+++ b/paddler_client_javascript/src/EventSourceConnectionErrorState.ts
@@ -0,0 +1,18 @@
+export type EventSourceConnectionErrorState = {
+  data: undefined;
+  isConnected: false;
+  isConnectionError: true;
+  isDeserializationError: false;
+  isInitial: false;
+  isOk: false;
+};
+
+export const eventSourceConnectionErrorState: EventSourceConnectionErrorState =
+  Object.freeze({
+    data: undefined,
+    isConnected: false,
+    isConnectionError: true,
+    isDeserializationError: false,
+    isInitial: false,
+    isOk: false,
+  });
diff --git a/paddler_client_javascript/src/EventSourceDataSnapshotState.ts b/paddler_client_javascript/src/EventSourceDataSnapshotState.ts
new file mode 100644
index 00000000..a97331b5
--- /dev/null
+++ b/paddler_client_javascript/src/EventSourceDataSnapshotState.ts
@@ -0,0 +1,10 @@
+import type { z } from "zod";
+
+export type EventSourceDataSnapshotState<TSchema extends z.ZodType> = {
+  data: z.infer<TSchema>;
+  isConnected: true;
+  isConnectionError: false;
+  isDeserializationError: false;
+  isInitial: false;
+  isOk: true;
+};
diff --git a/paddler_client_javascript/src/EventSourceDeserializationErrorState.ts b/paddler_client_javascript/src/EventSourceDeserializationErrorState.ts
new file mode 100644
index 00000000..c24fe02d
--- /dev/null
+++ b/paddler_client_javascript/src/EventSourceDeserializationErrorState.ts
@@ -0,0 +1,18 @@
+export type EventSourceDeserializationErrorState = {
+  data: undefined;
+  isConnected: true;
+  isConnectionError: false;
+  isDeserializationError: true;
+  isInitial: false;
+  isOk: false;
+};
+
+export const eventSourceDeserializationErrorState: EventSourceDeserializationErrorState =
+  Object.freeze({
+    data: undefined,
+    isConnected: true,
+    isConnectionError: false,
+    isDeserializationError: true,
+    isInitial: false,
+    isOk: false,
+  });
diff --git a/paddler_client_javascript/src/EventSourceInitialState.ts b/paddler_client_javascript/src/EventSourceInitialState.ts
new file mode 100644
index 00000000..f00fad6b
--- /dev/null
+++ b/paddler_client_javascript/src/EventSourceInitialState.ts
@@ -0,0 +1,17 @@
-0,0 +1,17 @@ +export type EventSourceInitialState = { + data: undefined; + isConnected: false; + isConnectionError: false; + isDeserializationError: false; + isInitial: true; + isOk: false; +}; + +export const eventSourceInitialState: EventSourceInitialState = Object.freeze({ + data: undefined, + isConnected: false, + isConnectionError: false, + isDeserializationError: false, + isInitial: true, + isOk: false, +}); diff --git a/paddler_client_javascript/src/EventSourceState.ts b/paddler_client_javascript/src/EventSourceState.ts new file mode 100644 index 00000000..106d11cc --- /dev/null +++ b/paddler_client_javascript/src/EventSourceState.ts @@ -0,0 +1,14 @@ +import type { z } from "zod"; + +import type { EventSourceConnectedState } from "./EventSourceConnectedState"; +import type { EventSourceConnectionErrorState } from "./EventSourceConnectionErrorState"; +import type { EventSourceDataSnapshotState } from "./EventSourceDataSnapshotState"; +import type { EventSourceDeserializationErrorState } from "./EventSourceDeserializationErrorState"; +import type { EventSourceInitialState } from "./EventSourceInitialState"; + +export type EventSourceState<TSchema extends z.ZodType> = + | EventSourceConnectedState + | EventSourceConnectionErrorState + | EventSourceDataSnapshotState<TSchema> + | EventSourceDeserializationErrorState + | EventSourceInitialState; diff --git a/paddler_client_javascript/src/FetchJsonEmptyState.ts b/paddler_client_javascript/src/FetchJsonEmptyState.ts new file mode 100644 index 00000000..ffa04ef3 --- /dev/null +++ b/paddler_client_javascript/src/FetchJsonEmptyState.ts @@ -0,0 +1,15 @@ +export type FetchJsonEmptyState = { + empty: true; + error: null; + loading: false; + ok: false; + response: null; +}; + +export const fetchJsonEmptyState: FetchJsonEmptyState = Object.freeze({ + empty: true, + error: null, + loading: false, + ok: false, + response: null, +}); diff --git a/paddler_client_javascript/src/FetchJsonErrorState.ts b/paddler_client_javascript/src/FetchJsonErrorState.ts new file mode 100644 index 00000000..a5a9aff5 --- /dev/null +++ b/paddler_client_javascript/src/FetchJsonErrorState.ts @@ -0,0 +1,7 @@ +export type FetchJsonErrorState = { + empty: false; + error: string; + loading: false; + ok: false; + response: null; +}; diff --git a/paddler_client_javascript/src/FetchJsonLoadingState.ts b/paddler_client_javascript/src/FetchJsonLoadingState.ts new file mode 100644 index 00000000..71b5bc4b --- /dev/null +++ b/paddler_client_javascript/src/FetchJsonLoadingState.ts @@ -0,0 +1,15 @@ +export type FetchJsonLoadingState = { + empty: false; + error: null; + loading: true; + ok: false; + response: null; +}; + +export const fetchJsonLoadingState: FetchJsonLoadingState = Object.freeze({ + empty: false, + error: null, + loading: true, + ok: false, + response: null, +}); diff --git a/paddler_client_javascript/src/FetchJsonState.ts b/paddler_client_javascript/src/FetchJsonState.ts new file mode 100644 index 00000000..4112640b --- /dev/null +++ b/paddler_client_javascript/src/FetchJsonState.ts @@ -0,0 +1,10 @@ +import type { FetchJsonEmptyState } from "./FetchJsonEmptyState"; +import type { FetchJsonErrorState } from "./FetchJsonErrorState"; +import type { FetchJsonLoadingState } from "./FetchJsonLoadingState"; +import type { FetchJsonSuccessState } from "./FetchJsonSuccessState"; + +export type FetchJsonState<TResult> = + | FetchJsonEmptyState + | FetchJsonErrorState + | FetchJsonLoadingState + | FetchJsonSuccessState<TResult>;
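The state modules above assemble into two discriminated unions (EventSourceState and FetchJsonState) whose mutually exclusive boolean flags drive narrowing. A minimal consumer-side sketch, not part of this diff; the pairing with AgentsResponseSchema (and its export name) is an assumption for illustration:

import { AgentsResponseSchema } from "./schemas/AgentsResponse";
import type { EventSourceState } from "./EventSourceState";

function describeState(
  state: EventSourceState<typeof AgentsResponseSchema>,
): string {
  if (state.isOk) {
    // Only the data-snapshot variant carries `data`, typed by the schema.
    return `snapshot: ${JSON.stringify(state.data)}`;
  }

  if (state.isConnectionError) {
    return "connection error";
  }

  if (state.isDeserializationError) {
    return "payload failed JSON.parse or schema validation";
  }

  return state.isInitial ? "connecting" : "connected, waiting for data";
}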
diff --git a/paddler_client_javascript/src/FetchJsonSuccessState.ts b/paddler_client_javascript/src/FetchJsonSuccessState.ts new file mode 100644 index 00000000..f5ec89e9 --- /dev/null +++ b/paddler_client_javascript/src/FetchJsonSuccessState.ts @@ -0,0 +1,7 @@ +export type FetchJsonSuccessState<TResult> = { + empty: false; + error: null; + loading: false; + ok: true; + response: TResult; +}; diff --git a/paddler_client_javascript/src/HttpError.ts b/paddler_client_javascript/src/HttpError.ts new file mode 100644 index 00000000..285d5b06 --- /dev/null +++ b/paddler_client_javascript/src/HttpError.ts @@ -0,0 +1,12 @@ +import { PaddlerError } from "./PaddlerError"; + +export class HttpError extends PaddlerError { + override name = "HttpError"; + + constructor( + public readonly statusCode: number, + message: string, + ) { + super(message); + } +} diff --git a/paddler_client_javascript/src/JsonError.ts b/paddler_client_javascript/src/JsonError.ts new file mode 100644 index 00000000..b6c03ccc --- /dev/null +++ b/paddler_client_javascript/src/JsonError.ts @@ -0,0 +1,12 @@ +import { PaddlerError } from "./PaddlerError"; + +export class JsonError extends PaddlerError { + override name = "JsonError"; + + constructor( + message: string, + public readonly raw: string, + ) { + super(message); + } +} diff --git a/paddler_client_javascript/src/PaddlerError.ts b/paddler_client_javascript/src/PaddlerError.ts new file mode 100644 index 00000000..b470ce4c --- /dev/null +++ b/paddler_client_javascript/src/PaddlerError.ts @@ -0,0 +1,3 @@ +export class PaddlerError extends Error { + override name = "PaddlerError"; +} diff --git a/paddler_client_javascript/src/ServerError.ts b/paddler_client_javascript/src/ServerError.ts new file mode 100644 index 00000000..38d04232 --- /dev/null +++ b/paddler_client_javascript/src/ServerError.ts @@ -0,0 +1,12 @@ +import { PaddlerError } from "./PaddlerError"; + +export class ServerError extends PaddlerError { + override name = "ServerError"; + + constructor( + public readonly code: number, + message: string, + ) { + super(message); + } +}
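All of the error classes above share the PaddlerError base, so callers can separate library failures from unrelated exceptions with a single instanceof check. A sketch of the intended consumption, not part of this diff:

import { HttpError } from "./HttpError";
import { JsonError } from "./JsonError";
import { PaddlerError } from "./PaddlerError";

function explainError(error: unknown): string {
  if (error instanceof HttpError) {
    return `HTTP ${error.statusCode}: ${error.message}`;
  }

  if (error instanceof JsonError) {
    // `raw` preserves the payload that failed to parse.
    return `invalid JSON (${error.message}): ${error.raw}`;
  }

  if (error instanceof PaddlerError) {
    return `paddler client error (${error.name}): ${error.message}`;
  }

  return "not a paddler client error";
}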
diff --git a/paddler_client_javascript/src/WebSocketConnectingState.ts b/paddler_client_javascript/src/WebSocketConnectingState.ts new file mode 100644 index 00000000..f7271516 --- /dev/null +++ b/paddler_client_javascript/src/WebSocketConnectingState.ts @@ -0,0 +1,15 @@ +export type WebSocketConnectingState = { + isConnected: false; + isConnectionClosed: false; + isConnectionError: false; + webSocket: null; +}; + +export const webSocketConnectingState: WebSocketConnectingState = Object.freeze( + { + isConnected: false, + isConnectionClosed: false, + isConnectionError: false, + webSocket: null, + }, +); diff --git a/paddler_client_javascript/src/WebSocketConnectionClosedState.ts b/paddler_client_javascript/src/WebSocketConnectionClosedState.ts new file mode 100644 index 00000000..99ddf763 --- /dev/null +++ b/paddler_client_javascript/src/WebSocketConnectionClosedState.ts @@ -0,0 +1,14 @@ +export type WebSocketConnectionClosedState = { + isConnected: false; + isConnectionClosed: true; + isConnectionError: false; + webSocket: null; +}; + +export const webSocketConnectionClosedState: WebSocketConnectionClosedState = + Object.freeze({ + isConnected: false, + isConnectionClosed: true, + isConnectionError: false, + webSocket: null, + }); diff --git a/paddler_client_javascript/src/WebSocketConnectionErrorState.ts b/paddler_client_javascript/src/WebSocketConnectionErrorState.ts new file mode 100644 index 00000000..1e01bda4 --- /dev/null +++ b/paddler_client_javascript/src/WebSocketConnectionErrorState.ts @@ -0,0 +1,14 @@ +export type WebSocketConnectionErrorState = { + isConnected: false; + isConnectionClosed: false; + isConnectionError: true; + webSocket: null; +}; + +export const webSocketConnectionErrorState: WebSocketConnectionErrorState = + Object.freeze({ + isConnected: false, + isConnectionClosed: false, + isConnectionError: true, + webSocket: null, + }); diff --git a/paddler_client_javascript/src/WebSocketConnectionOpenedState.ts b/paddler_client_javascript/src/WebSocketConnectionOpenedState.ts new file mode 100644 index 00000000..435e532f --- /dev/null +++ b/paddler_client_javascript/src/WebSocketConnectionOpenedState.ts @@ -0,0 +1,6 @@ +export type WebSocketConnectionOpenedState = { + isConnected: true; + isConnectionClosed: false; + isConnectionError: false; + webSocket: WebSocket; +}; diff --git a/paddler_client_javascript/src/WebSocketError.ts b/paddler_client_javascript/src/WebSocketError.ts new file mode 100644 index 00000000..7274ed4d --- /dev/null +++ b/paddler_client_javascript/src/WebSocketError.ts @@ -0,0 +1,5 @@ +import { PaddlerError } from "./PaddlerError"; + +export class WebSocketError extends PaddlerError { + override name = "WebSocketError"; +} diff --git a/paddler_client_javascript/src/WebSocketState.ts b/paddler_client_javascript/src/WebSocketState.ts new file mode 100644 index 00000000..ca2968ae --- /dev/null +++ b/paddler_client_javascript/src/WebSocketState.ts @@ -0,0 +1,10 @@ +import type { WebSocketConnectingState } from "./WebSocketConnectingState"; +import type { WebSocketConnectionClosedState } from "./WebSocketConnectionClosedState"; +import type { WebSocketConnectionErrorState } from "./WebSocketConnectionErrorState"; +import type { WebSocketConnectionOpenedState } from "./WebSocketConnectionOpenedState"; + +export type WebSocketState = + | WebSocketConnectingState + | WebSocketConnectionClosedState + | WebSocketConnectionErrorState + | WebSocketConnectionOpenedState; diff --git a/paddler_client_javascript/src/extractHuggingFaceUrlParts.ts b/paddler_client_javascript/src/extractHuggingFaceUrlParts.ts new file mode 100644 index 00000000..ec7e6f5d --- /dev/null +++ b/paddler_client_javascript/src/extractHuggingFaceUrlParts.ts @@ -0,0 +1,38 @@ +import type { HuggingFaceModelReference } from "./schemas/HuggingFaceModelReference"; + +export function extractHuggingFaceUrlParts({ + pathname, +}: URL): HuggingFaceModelReference { + const segments = pathname.split("/").filter(function (segment) { + return segment.length > 0; + }); + + if (segments.length < 5) { + throw new Error(`Invalid Hugging Face URL format: ${pathname}`); + } + + const [owner, repo, resourceKind, revision, ...filenameSegments] = segments; + + if ( + owner === undefined + || repo === undefined + || resourceKind === undefined + || revision === undefined + ) { + throw new Error(`Invalid Hugging Face URL format: ${pathname}`); + } + + if (resourceKind !== "blob" && resourceKind !== "resolve") { + throw new Error(`Invalid Hugging Face URL format: ${pathname}`); + } + + if (filenameSegments.length < 1) { + throw new Error(`Invalid Hugging Face URL format: ${pathname}`); + } + + return { + filename: filenameSegments.join("/"), + repo_id: `${owner}/${repo}`, + revision, + }; +} diff --git a/paddler_client_javascript/src/fetchJson.ts b/paddler_client_javascript/src/fetchJson.ts new file mode 100644 index 00000000..3f263b3e --- /dev/null +++ b/paddler_client_javascript/src/fetchJson.ts @@ -0,0 +1,26 @@ +import type { z } from "zod"; + +import { HttpError } from "./HttpError"; + +export async function fetchJson<TSchema extends z.ZodType>({ + url, + signal, + schema, +}: { + url: URL | string; + signal: AbortSignal; + schema: TSchema; +}): Promise<z.infer<TSchema>> { + const response = await fetch(url, { signal }); + + if (!response.ok) { + throw new HttpError( + response.status, + `HTTP ${response.status} ${response.statusText}`, + ); + } + + const payload: unknown = await response.json(); + + return schema.parse(payload); +}
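fetchJson above pairs a plain fetch with zod validation, so the resolved value is typed by the schema the caller passes in. A usage sketch, not part of this diff; the URL and StatusSchema are placeholders:

import { z } from "zod";

import { fetchJson } from "./fetchJson";

const StatusSchema = z.object({ version: z.string() });

const controller = new AbortController();

// Resolves to z.infer<typeof StatusSchema>; a non-2xx status throws
// HttpError, and a shape mismatch surfaces as zod's own parse error.
const status = await fetchJson({
  url: "http://127.0.0.1:8095/api/v1/status",
  signal: controller.signal,
  schema: StatusSchema,
});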
diff --git a/resources/ts/InferenceSocketClient.ts b/paddler_client_javascript/src/inferenceSocketClient.ts similarity index 86% rename from resources/ts/InferenceSocketClient.ts rename to paddler_client_javascript/src/inferenceSocketClient.ts index 8b3bf261..82b95259 100644 --- a/resources/ts/InferenceSocketClient.ts +++ b/paddler_client_javascript/src/inferenceSocketClient.ts @@ -1,14 +1,20 @@ import { nanoid } from "nanoid"; import { filter, fromEvent, map, takeWhile, type Observable } from "rxjs"; -import { type ConversationMessage } from "./ConversationMessage.type"; -import { type InferenceSocketClient } from "./InferenceSocketClient.interface"; import { InferenceServiceGenerateTokensResponseSchema, type InferenceServiceGenerateTokensResponse, } from "./schemas/InferenceServiceGenerateTokensResponse"; +import type { ConversationMessage } from "./schemas/ConversationMessage"; -export function InferenceSocketClient({ +export interface InferenceSocketClient { + continueConversation(params: { + enableThinking: boolean; + messages: ConversationMessage[]; + }): Observable<InferenceServiceGenerateTokensResponse>; +} + +export function inferenceSocketClient({ webSocket, }: { webSocket: WebSocket; diff --git a/resources/ts/schemas/Agent.ts b/paddler_client_javascript/src/schemas/Agent.ts similarity index 100% rename from resources/ts/schemas/Agent.ts rename to paddler_client_javascript/src/schemas/Agent.ts diff --git a/resources/ts/schemas/AgentDesiredModel.ts b/paddler_client_javascript/src/schemas/AgentDesiredModel.ts similarity index 100% rename from resources/ts/schemas/AgentDesiredModel.ts rename to paddler_client_javascript/src/schemas/AgentDesiredModel.ts diff --git a/resources/ts/schemas/AgentIssue.ts b/paddler_client_javascript/src/schemas/AgentIssue.ts similarity index 100% rename from resources/ts/schemas/AgentIssue.ts rename to paddler_client_javascript/src/schemas/AgentIssue.ts diff --git a/resources/ts/schemas/AgentIssueModelPath.ts b/paddler_client_javascript/src/schemas/AgentIssueModelPath.ts similarity index 100% rename from resources/ts/schemas/AgentIssueModelPath.ts rename to paddler_client_javascript/src/schemas/AgentIssueModelPath.ts diff --git a/resources/ts/schemas/AgentsResponse.ts b/paddler_client_javascript/src/schemas/AgentsResponse.ts similarity index 100% rename from resources/ts/schemas/AgentsResponse.ts rename to paddler_client_javascript/src/schemas/AgentsResponse.ts diff --git a/resources/ts/schemas/BalancerDesiredState.ts b/paddler_client_javascript/src/schemas/BalancerDesiredState.ts similarity index 100% rename from resources/ts/schemas/BalancerDesiredState.ts rename to paddler_client_javascript/src/schemas/BalancerDesiredState.ts diff --git a/resources/ts/schemas/BufferedRequestsResponse.ts b/paddler_client_javascript/src/schemas/BufferedRequestsResponse.ts similarity index 100% rename from resources/ts/schemas/BufferedRequestsResponse.ts rename to paddler_client_javascript/src/schemas/BufferedRequestsResponse.ts diff --git a/resources/ts/schemas/ChatTemplate.ts b/paddler_client_javascript/src/schemas/ChatTemplate.ts similarity index 100% rename from resources/ts/schemas/ChatTemplate.ts rename to
paddler_client_javascript/src/schemas/ChatTemplate.ts diff --git a/paddler_client_javascript/src/schemas/ContinueFromConversationHistoryParams.ts b/paddler_client_javascript/src/schemas/ContinueFromConversationHistoryParams.ts new file mode 100644 index 00000000..6b019e9c --- /dev/null +++ b/paddler_client_javascript/src/schemas/ContinueFromConversationHistoryParams.ts @@ -0,0 +1,21 @@ +import { z } from "zod"; + +import { ConversationMessageSchema } from "./ConversationMessage"; +import { GrammarConstraintSchema } from "./GrammarConstraint"; +import { ToolSchema } from "./Tool"; + +export const ContinueFromConversationHistoryParamsSchema = z + .object({ + add_generation_prompt: z.boolean(), + conversation_history: z.array(ConversationMessageSchema), + enable_thinking: z.boolean(), + grammar: GrammarConstraintSchema.nullable().optional(), + max_tokens: z.number().int(), + parse_tool_calls: z.boolean().optional(), + tools: z.array(ToolSchema).optional(), + }) + .strict(); + +export type ContinueFromConversationHistoryParams = z.infer< + typeof ContinueFromConversationHistoryParamsSchema +>; diff --git a/paddler_client_javascript/src/schemas/ContinueFromRawPromptParams.ts b/paddler_client_javascript/src/schemas/ContinueFromRawPromptParams.ts new file mode 100644 index 00000000..b6716ec1 --- /dev/null +++ b/paddler_client_javascript/src/schemas/ContinueFromRawPromptParams.ts @@ -0,0 +1,15 @@ +import { z } from "zod"; + +import { GrammarConstraintSchema } from "./GrammarConstraint"; + +export const ContinueFromRawPromptParamsSchema = z + .object({ + grammar: GrammarConstraintSchema.nullable().optional(), + max_tokens: z.number().int(), + raw_prompt: z.string(), + }) + .strict(); + +export type ContinueFromRawPromptParams = z.infer< + typeof ContinueFromRawPromptParamsSchema +>; diff --git a/paddler_client_javascript/src/schemas/ConversationMessage.ts b/paddler_client_javascript/src/schemas/ConversationMessage.ts new file mode 100644 index 00000000..64d18333 --- /dev/null +++ b/paddler_client_javascript/src/schemas/ConversationMessage.ts @@ -0,0 +1,13 @@ +import { z } from "zod"; + +import { ConversationMessageContentPartSchema } from "./ConversationMessageContentPart"; + +export const ConversationMessageSchema = z.object({ + role: z.string(), + content: z.union([ + z.string(), + z.array(ConversationMessageContentPartSchema), + ]), +}); + +export type ConversationMessage = z.infer<typeof ConversationMessageSchema>; diff --git a/paddler_client_javascript/src/schemas/ConversationMessageContentPart.ts b/paddler_client_javascript/src/schemas/ConversationMessageContentPart.ts new file mode 100644 index 00000000..1bd08398 --- /dev/null +++ b/paddler_client_javascript/src/schemas/ConversationMessageContentPart.ts @@ -0,0 +1,16 @@ +import { z } from "zod"; + +export const ConversationMessageContentPartSchema = z.union([ + z.object({ + type: z.literal("text"), + text: z.string(), + }), + z.object({ + type: z.literal("image_url"), + image_url: z.object({ url: z.string() }), + }), +]); + +export type ConversationMessageContentPart = z.infer< + typeof ConversationMessageContentPartSchema +>; diff --git a/paddler_client_javascript/src/schemas/Embedding.ts b/paddler_client_javascript/src/schemas/Embedding.ts new file mode 100644 index 00000000..1129536b --- /dev/null +++ b/paddler_client_javascript/src/schemas/Embedding.ts @@ -0,0 +1,13 @@ +import { z } from "zod"; + +import { EmbeddingNormalizationMethodSchema } from "./EmbeddingNormalizationMethod"; +import { PoolingTypeSchema } from "./PoolingType"; + +export const EmbeddingSchema =
z.object({ + embedding: z.array(z.number()), + normalization_method: EmbeddingNormalizationMethodSchema, + pooling_type: PoolingTypeSchema, + source_document_id: z.string(), +}); + +export type Embedding = z.infer<typeof EmbeddingSchema>; diff --git a/paddler_client_javascript/src/schemas/EmbeddingInputDocument.ts b/paddler_client_javascript/src/schemas/EmbeddingInputDocument.ts new file mode 100644 index 00000000..60924aa0 --- /dev/null +++ b/paddler_client_javascript/src/schemas/EmbeddingInputDocument.ts @@ -0,0 +1,10 @@ +import { z } from "zod"; + +export const EmbeddingInputDocumentSchema = z.object({ + content: z.string(), + id: z.string(), +}); + +export type EmbeddingInputDocument = z.infer< + typeof EmbeddingInputDocumentSchema +>; diff --git a/paddler_client_javascript/src/schemas/EmbeddingNormalizationMethod.ts b/paddler_client_javascript/src/schemas/EmbeddingNormalizationMethod.ts new file mode 100644 index 00000000..68a4729a --- /dev/null +++ b/paddler_client_javascript/src/schemas/EmbeddingNormalizationMethod.ts @@ -0,0 +1,15 @@ +import { z } from "zod"; + +export const EmbeddingNormalizationMethodSchema = z.union([ + z.literal("L2"), + z.literal("None"), + z.object({ + RmsNorm: z.object({ + epsilon: z.number(), + }), + }), +]); + +export type EmbeddingNormalizationMethod = z.infer< + typeof EmbeddingNormalizationMethodSchema +>; diff --git a/paddler_client_javascript/src/schemas/GenerateEmbeddingBatchParams.ts b/paddler_client_javascript/src/schemas/GenerateEmbeddingBatchParams.ts new file mode 100644 index 00000000..0d3194d4 --- /dev/null +++ b/paddler_client_javascript/src/schemas/GenerateEmbeddingBatchParams.ts @@ -0,0 +1,13 @@ +import { z } from "zod"; + +import { EmbeddingInputDocumentSchema } from "./EmbeddingInputDocument"; +import { EmbeddingNormalizationMethodSchema } from "./EmbeddingNormalizationMethod"; + +export const GenerateEmbeddingBatchParamsSchema = z.object({ + input_documents: z.array(EmbeddingInputDocumentSchema), + normalization_method: EmbeddingNormalizationMethodSchema, +}); + +export type GenerateEmbeddingBatchParams = z.infer< + typeof GenerateEmbeddingBatchParamsSchema +>; diff --git a/paddler_client_javascript/src/schemas/GrammarConstraint.ts b/paddler_client_javascript/src/schemas/GrammarConstraint.ts new file mode 100644 index 00000000..e2c7a42a --- /dev/null +++ b/paddler_client_javascript/src/schemas/GrammarConstraint.ts @@ -0,0 +1,15 @@ +import { z } from "zod"; + +export const GrammarConstraintSchema = z.discriminatedUnion("type", [ + z.object({ + type: z.literal("gbnf"), + grammar: z.string(), + root: z.string(), + }), + z.object({ + type: z.literal("json_schema"), + schema: z.string(), + }), +]); + +export type GrammarConstraint = z.infer<typeof GrammarConstraintSchema>;
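GrammarConstraintSchema and the strict params schemas above validate request payloads before they leave the client. A sketch of assembling a constrained-generation request, not part of this diff; all values are illustrative:

import { ContinueFromConversationHistoryParamsSchema } from "./ContinueFromConversationHistoryParams";
import type { GrammarConstraint } from "./GrammarConstraint";

const grammar: GrammarConstraint = {
  type: "json_schema",
  schema: JSON.stringify({
    type: "object",
    properties: { answer: { type: "string" } },
  }),
};

// `.strict()` on the schema rejects unknown keys, so a typo fails fast
// here instead of being silently dropped by the server.
const params = ContinueFromConversationHistoryParamsSchema.parse({
  add_generation_prompt: true,
  conversation_history: [{ role: "user", content: "Reply in JSON." }],
  enable_thinking: false,
  grammar,
  max_tokens: 256,
});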
diff --git a/resources/ts/schemas/HuggingFaceDownloadLock.ts b/paddler_client_javascript/src/schemas/HuggingFaceDownloadLock.ts similarity index 99% rename from resources/ts/schemas/HuggingFaceDownloadLock.ts rename to paddler_client_javascript/src/schemas/HuggingFaceDownloadLock.ts index 24b043e9..1dd570a9 100644 --- a/resources/ts/schemas/HuggingFaceDownloadLock.ts +++ b/paddler_client_javascript/src/schemas/HuggingFaceDownloadLock.ts @@ -1,4 +1,5 @@ import { z } from "zod"; + import { AgentIssueModelPathSchema } from "./AgentIssueModelPath"; export const HuggingFaceDownloadLockSchema = z.object({ diff --git a/resources/ts/schemas/HuggingFaceModelReference.ts b/paddler_client_javascript/src/schemas/HuggingFaceModelReference.ts similarity index 100% rename from resources/ts/schemas/HuggingFaceModelReference.ts rename to paddler_client_javascript/src/schemas/HuggingFaceModelReference.ts diff --git a/resources/ts/schemas/InferenceParameters.ts b/paddler_client_javascript/src/schemas/InferenceParameters.ts similarity index 73% rename from resources/ts/schemas/InferenceParameters.ts rename to paddler_client_javascript/src/schemas/InferenceParameters.ts index ce5cdb6f..b7c0d206 100644 --- a/resources/ts/schemas/InferenceParameters.ts +++ b/paddler_client_javascript/src/schemas/InferenceParameters.ts @@ -23,8 +23,9 @@ export const poolingTypes = [ export const InferenceParametersSchema = z .object({ - batch_n_tokens: z.number(), + n_batch: z.number(), context_size: z.number(), + embedding_batch_size: z.number().int().min(1), enable_embeddings: z.boolean(), image_resize_to_fit: z.number().int().min(1), k_cache_dtype: z.enum(cacheDtypes), @@ -43,14 +44,3 @@ export const InferenceParametersSchema = z .strict(); export type InferenceParameters = z.infer<typeof InferenceParametersSchema>; - -export type BooleanKeys = { - [K in keyof InferenceParameters]: InferenceParameters[K] extends boolean - ? K - : never; -}[keyof InferenceParameters]; -export type NumberKeys = { - [K in keyof InferenceParameters]: InferenceParameters[K] extends number - ? K - : never; -}[keyof InferenceParameters]; diff --git a/paddler_client_javascript/src/schemas/InferenceServiceGenerateTokensResponse.ts b/paddler_client_javascript/src/schemas/InferenceServiceGenerateTokensResponse.ts new file mode 100644 index 00000000..ec0ae68b --- /dev/null +++ b/paddler_client_javascript/src/schemas/InferenceServiceGenerateTokensResponse.ts @@ -0,0 +1,375 @@ +import { z } from "zod"; + +import { ParsedToolCallSchema } from "./ParsedToolCall"; + +export type GeneratedTokenKind = + | "content" + | "reasoning" + | "tool_call" + | "undeterminable"; + +const TokenUsageSchema = z.object({ + prompt_tokens: z.number(), + cached_prompt_tokens: z.number(), + input_image_tokens: z.number(), + input_audio_tokens: z.number(), + content_tokens: z.number(), + reasoning_tokens: z.number(), + tool_call_tokens: z.number(), + undeterminable_tokens: z.number(), +}); + +const GenerationSummarySchema = z.object({ + usage: TokenUsageSchema, +}); + +const RawToolCallTokensSchema = z.object({ + text: z.string(), + ffi_error_message: z.string(), +}); + +const OversizedImageDetailsSchema = z.object({ + image_tokens: z.number(), + n_batch: z.number(), +}); + +const GeneratedTokenResultSchema = z.union([ + z.object({ ContentToken: z.string() }), + z.object({ ReasoningToken: z.string() }), + z.object({ ToolCallToken: z.string() }), + z.object({ UndeterminableToken: z.string() }), + z.object({ Done: GenerationSummarySchema }), + z.object({ ChatTemplateError: z.string() }), + z.object({ GrammarIncompatibleWithThinking: z.string() }), + z.object({ GrammarInitializationFailed: z.string() }), + z.object({ GrammarRejectedModelOutput: z.string() }), + z.object({ GrammarSyntaxError: z.string() }), + z.object({ ImageDecodingFailed: z.string() }), + z.object({ ImageExceedsBatchSize: OversizedImageDetailsSchema }), + z.object({ MultimodalNotSupported: z.string() }), + z.object({ SamplerError: z.string() }), + z.object({ ToolCallParsed: z.array(ParsedToolCallSchema) }), + z.object({ ToolCallParseFailed: z.string() }), + z.object({ ToolCallValidationFailed: z.array(z.string()) }), + z.object({ ToolCallValidatorBuildFailed: z.string() }), + z.object({ UnrecognizedToolCallFormat: RawToolCallTokensSchema }), +]); + +type Normalised = + | { + done: true; + error: null; + generated_by: string | null; + ok: true; + rawToolCallTokens:
null; + request_id: string; + summary: z.infer<typeof GenerationSummarySchema>; + token: null; + tokenKind: null; + toolCalls: null; + } + | { + done: false; + error: null; + generated_by: string | null; + ok: true; + rawToolCallTokens: null; + request_id: string; + summary: null; + token: string; + tokenKind: GeneratedTokenKind; + toolCalls: null; + } + | { + done: false; + error: null; + generated_by: string | null; + ok: true; + rawToolCallTokens: null; + request_id: string; + summary: null; + token: null; + tokenKind: null; + toolCalls: ReadonlyArray<z.infer<typeof ParsedToolCallSchema>>; + } + | { + done: false; + error: null; + generated_by: string | null; + ok: true; + rawToolCallTokens: z.infer<typeof RawToolCallTokensSchema>; + request_id: string; + summary: null; + token: null; + tokenKind: null; + toolCalls: null; + } + | { + done: true; + error: { code: number; description: string }; + generated_by: string | null; + ok: false; + rawToolCallTokens: null; + request_id: string; + summary: null; + token: null; + tokenKind: null; + toolCalls: null; + } + | { + done: false; + error: { code: number; description: string }; + generated_by: string | null; + ok: false; + rawToolCallTokens: null; + request_id: string; + summary: null; + token: null; + tokenKind: null; + toolCalls: null; + }; + +function terminalError( + request_id: string, + generated_by: string | null, + code: number, + description: string, +): Normalised { + return Object.freeze({ + done: true, + error: Object.freeze({ code, description }), + generated_by, + ok: false, + rawToolCallTokens: null, + request_id, + summary: null, + token: null, + tokenKind: null, + toolCalls: null, + }); +} + +function nonTerminalError( + request_id: string, + generated_by: string | null, + code: number, + description: string, +): Normalised { + return Object.freeze({ + done: false, + error: Object.freeze({ code, description }), + generated_by, + ok: false, + rawToolCallTokens: null, + request_id, + summary: null, + token: null, + tokenKind: null, + toolCalls: null, + }); +} + +function streamingToken( + request_id: string, + generated_by: string | null, + token: string, + tokenKind: GeneratedTokenKind, +): Normalised { + return Object.freeze({ + done: false, + error: null, + generated_by, + ok: true, + rawToolCallTokens: null, + request_id, + summary: null, + token, + tokenKind, + toolCalls: null, + }); +} + +function unrecognizedToolCallFormat( + request_id: string, + generated_by: string | null, + raw: z.infer<typeof RawToolCallTokensSchema>, +): Normalised { + return Object.freeze({ + done: false, + error: null, + generated_by, + ok: true, + rawToolCallTokens: Object.freeze(raw), + request_id, + summary: null, + token: null, + tokenKind: null, + toolCalls: null, + }); +} + +export const InferenceServiceGenerateTokensResponseSchema = z + .union([ + z.object({ + Error: z.object({ + error: z.object({ + code: z.number(), + description: z.string(), + }), + request_id: z.string(), + }), + }), + z.object({ + Response: z.object({ + generated_by: z.string().nullable(), + request_id: z.string(), + response: z.object({ + GeneratedToken: GeneratedTokenResultSchema, + }), + }), + }), + ]) + .transform(function (data): Normalised { + if ("Error" in data) { + return terminalError( + data.Error.request_id, + null, + data.Error.error.code, + data.Error.error.description, + ); + } + + const request_id = data.Response.request_id; + const generated_by = data.Response.generated_by; + const variant = data.Response.response.GeneratedToken; + + if ("ContentToken" in variant) { + return streamingToken(request_id, generated_by, variant.ContentToken, "content"); + } + + if ("ReasoningToken" in
variant) { + return streamingToken(request_id, generated_by, variant.ReasoningToken, "reasoning"); + } + + if ("ToolCallToken" in variant) { + return streamingToken(request_id, generated_by, variant.ToolCallToken, "tool_call"); + } + + if ("UndeterminableToken" in variant) { + return streamingToken( + request_id, + generated_by, + variant.UndeterminableToken, + "undeterminable", + ); + } + + if ("Done" in variant) { + return Object.freeze({ + done: true, + error: null, + generated_by, + ok: true, + rawToolCallTokens: null, + request_id, + summary: variant.Done, + token: null, + tokenKind: null, + toolCalls: null, + }); + } + + if ("ToolCallParsed" in variant) { + return Object.freeze({ + done: false, + error: null, + generated_by, + ok: true, + rawToolCallTokens: null, + request_id, + summary: null, + token: null, + tokenKind: null, + toolCalls: Object.freeze(variant.ToolCallParsed), + }); + } + + if ("UnrecognizedToolCallFormat" in variant) { + return unrecognizedToolCallFormat( + request_id, + generated_by, + variant.UnrecognizedToolCallFormat, + ); + } + + if ("ToolCallParseFailed" in variant) { + return nonTerminalError(request_id, generated_by, 422, variant.ToolCallParseFailed); + } + + if ("ToolCallValidationFailed" in variant) { + return nonTerminalError( + request_id, + generated_by, + 422, + variant.ToolCallValidationFailed.join("; "), + ); + } + + if ("ToolCallValidatorBuildFailed" in variant) { + return terminalError( + request_id, + generated_by, + 400, + variant.ToolCallValidatorBuildFailed, + ); + } + + if ("ChatTemplateError" in variant) { + return terminalError(request_id, generated_by, 500, variant.ChatTemplateError); + } + + if ("GrammarIncompatibleWithThinking" in variant) { + return terminalError( + request_id, + generated_by, + 400, + variant.GrammarIncompatibleWithThinking, + ); + } + + if ("GrammarInitializationFailed" in variant) { + return terminalError(request_id, generated_by, 500, variant.GrammarInitializationFailed); + } + + if ("GrammarRejectedModelOutput" in variant) { + return terminalError(request_id, generated_by, 500, variant.GrammarRejectedModelOutput); + } + + if ("GrammarSyntaxError" in variant) { + return terminalError(request_id, generated_by, 400, variant.GrammarSyntaxError); + } + + if ("ImageDecodingFailed" in variant) { + return terminalError(request_id, generated_by, 400, variant.ImageDecodingFailed); + } + + if ("ImageExceedsBatchSize" in variant) { + const details = variant.ImageExceedsBatchSize; + return terminalError( + request_id, + generated_by, + 400, + `image required ${details.image_tokens} tokens but n_batch is ${details.n_batch}`, + ); + } + + if ("MultimodalNotSupported" in variant) { + return terminalError(request_id, generated_by, 400, variant.MultimodalNotSupported); + } + + return terminalError(request_id, generated_by, 500, variant.SamplerError); + }); + +export type InferenceServiceGenerateTokensResponse = z.infer< + typeof InferenceServiceGenerateTokensResponseSchema +>;
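The transform above flattens every server variant into one Normalised record whose null/non-null fields discriminate the cases. A sketch of consuming it, not part of this diff; where the messages come from (for example, inferenceSocketClient's continueConversation stream) is up to the caller:

import type { InferenceServiceGenerateTokensResponse } from "./schemas/InferenceServiceGenerateTokensResponse";

function handleMessage(
  message: InferenceServiceGenerateTokensResponse,
): void {
  if (message.token !== null) {
    // tokenKind is "content", "reasoning", "tool_call", or "undeterminable".
    console.log(`${message.tokenKind}: ${message.token}`);
    return;
  }

  if (message.toolCalls !== null) {
    console.log("parsed tool calls:", message.toolCalls);
    return;
  }

  if (message.error !== null) {
    console.error(`error ${message.error.code}: ${message.error.description}`);
    return;
  }

  if (message.done) {
    // The terminal success variant carries the usage summary.
    console.log("usage:", message.summary?.usage);
  }
}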
diff --git a/paddler_client_javascript/src/schemas/ModelMetadata.ts b/paddler_client_javascript/src/schemas/ModelMetadata.ts new file mode 100644 index 00000000..3dc4c953 --- /dev/null +++ b/paddler_client_javascript/src/schemas/ModelMetadata.ts @@ -0,0 +1,5 @@ +import { z } from "zod"; + +export const ModelMetadataSchema = z.record(z.string(), z.string()); + +export type ModelMetadata = z.infer<typeof ModelMetadataSchema>; diff --git a/paddler_client_javascript/src/schemas/ParsedToolCall.ts b/paddler_client_javascript/src/schemas/ParsedToolCall.ts new file mode 100644 index 00000000..073ece92 --- /dev/null +++ b/paddler_client_javascript/src/schemas/ParsedToolCall.ts @@ -0,0 +1,15 @@ +import { z } from "zod"; + +export const ToolCallArgumentsSchema = z.union([ + z.strictObject({ InvalidJson: z.string() }), + z.strictObject({ ValidJson: z.unknown() }), +]); + +export const ParsedToolCallSchema = z.object({ + id: z.string(), + name: z.string(), + arguments: ToolCallArgumentsSchema, +}); + +export type ParsedToolCall = z.infer<typeof ParsedToolCallSchema>; +export type ToolCallArguments = z.infer<typeof ToolCallArgumentsSchema>; diff --git a/paddler_client_javascript/src/schemas/PoolingType.ts b/paddler_client_javascript/src/schemas/PoolingType.ts new file mode 100644 index 00000000..ed94f496 --- /dev/null +++ b/paddler_client_javascript/src/schemas/PoolingType.ts @@ -0,0 +1,12 @@ +import { z } from "zod"; + +export const PoolingTypeSchema = z.enum([ + "Cls", + "Last", + "Mean", + "None", + "Rank", + "Unspecified", +]); + +export type PoolingType = z.infer<typeof PoolingTypeSchema>; diff --git a/paddler_client_javascript/src/schemas/Tool.ts b/paddler_client_javascript/src/schemas/Tool.ts new file mode 100644 index 00000000..411ff36d --- /dev/null +++ b/paddler_client_javascript/src/schemas/Tool.ts @@ -0,0 +1,20 @@ +import { z } from "zod"; + +import { ValidatedParametersSchemaSchema } from "./ValidatedParametersSchema"; + +export const FunctionDefinitionSchema = z.object({ + name: z.string(), + description: z.string(), + parameters: ValidatedParametersSchemaSchema.optional(), +}); + +export const FunctionCallToolSchema = z.object({ + type: z.literal("function"), + function: FunctionDefinitionSchema, +}); + +export const ToolSchema = FunctionCallToolSchema; + +export type FunctionDefinition = z.infer<typeof FunctionDefinitionSchema>; +export type FunctionCallTool = z.infer<typeof FunctionCallToolSchema>; +export type Tool = z.infer<typeof ToolSchema>; diff --git a/paddler_client_javascript/src/schemas/ValidatedParametersSchema.ts b/paddler_client_javascript/src/schemas/ValidatedParametersSchema.ts new file mode 100644 index 00000000..fec87b41 --- /dev/null +++ b/paddler_client_javascript/src/schemas/ValidatedParametersSchema.ts @@ -0,0 +1,14 @@ +import { z } from "zod"; + +export const ValidatedParametersSchemaSchema = z + .object({ + type: z.string(), + properties: z.record(z.string(), z.unknown()).optional(), + required: z.array(z.string()).optional(), + additionalProperties: z.unknown().optional(), + }) + .strict(); + +export type ValidatedParametersSchema = z.infer< + typeof ValidatedParametersSchemaSchema +>; diff --git a/paddler_client_javascript/src/streamEventSource.ts b/paddler_client_javascript/src/streamEventSource.ts new file mode 100644 index 00000000..6f9b1c1d --- /dev/null +++ b/paddler_client_javascript/src/streamEventSource.ts @@ -0,0 +1,69 @@ +import { Observable } from "rxjs"; +import type { z } from "zod"; + +import { eventSourceConnectedState } from "./EventSourceConnectedState"; +import { eventSourceConnectionErrorState } from "./EventSourceConnectionErrorState"; +import { eventSourceDeserializationErrorState } from "./EventSourceDeserializationErrorState"; +import { eventSourceInitialState } from "./EventSourceInitialState"; +import type { EventSourceState } from "./EventSourceState"; + +export function streamEventSource<TSchema extends z.ZodType>({ + url, + schema, +}: { + url: URL | string; + schema: TSchema; +}): Observable<EventSourceState<TSchema>> { + return new Observable<EventSourceState<TSchema>>(function (subscriber) { + subscriber.next(eventSourceInitialState); + + const eventSource = new EventSource(url); + + eventSource.addEventListener("open", function () { + subscriber.next(eventSourceConnectedState); + }); + + eventSource.addEventListener("error", function () { + subscriber.next(eventSourceConnectionErrorState); + }); + + eventSource.addEventListener("message", function (event) { + if ("string" !== typeof event.data) { + subscriber.next(eventSourceDeserializationErrorState); + + return; + } + + let parsedJson: unknown; + + try { + parsedJson = JSON.parse(event.data); + } catch { + subscriber.next(eventSourceDeserializationErrorState); + + return; + } + + const result = schema.safeParse(parsedJson); + + if (!result.success) { + subscriber.next(eventSourceDeserializationErrorState); + + return; + } + + subscriber.next({ + data: result.data, + isConnected: true, + isConnectionError: false, + isDeserializationError: false, + isInitial: false, + isOk: true, + }); + }); + + return function () { + eventSource.close(); + }; + }); +}
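streamEventSource wraps the browser EventSource in an rxjs Observable of the states defined earlier, closing the connection on unsubscribe. A usage sketch, not part of this diff; the URL and the AgentsResponseSchema pairing are assumptions:

import { AgentsResponseSchema } from "./schemas/AgentsResponse";
import { streamEventSource } from "./streamEventSource";

const subscription = streamEventSource({
  url: "http://127.0.0.1:8095/api/v1/agents/stream",
  schema: AgentsResponseSchema,
}).subscribe(function (state) {
  if (state.isOk) {
    console.log("agents snapshot:", state.data);
  }
});

// Tearing down the subscription closes the underlying EventSource.
subscription.unsubscribe();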
diff --git a/paddler_client_javascript/src/streamHttpNdjson.ts b/paddler_client_javascript/src/streamHttpNdjson.ts new file mode 100644 index 00000000..32e886da --- /dev/null +++ b/paddler_client_javascript/src/streamHttpNdjson.ts @@ -0,0 +1,104 @@ +import { Observable } from "rxjs"; +import type { z } from "zod"; + +import { HttpError } from "./HttpError"; +import { JsonError } from "./JsonError"; + +export function streamHttpNdjson<TSchema extends z.ZodType>({ + url, + body, + signal, + schema, +}: { + url: URL | string; + body: unknown; + signal: AbortSignal; + schema: TSchema; +}): Observable<z.infer<TSchema>> { + return new Observable<z.infer<TSchema>>(function (subscriber) { + fetch(url, { + body: JSON.stringify(body), + headers: { "Content-Type": "application/json" }, + method: "POST", + signal, + }) + .then(async function (response) { + if (!response.ok) { + throw new HttpError( + response.status, + `HTTP ${response.status} ${response.statusText}`, + ); + } + + if (!response.body) { + throw new HttpError(response.status, "Response has no body"); + } + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + + while (!signal.aborted) { + const { done, value } = await reader.read(); + + if (done) { + break; + } + + buffer += decoder.decode(value, { stream: true }); + + let newlineIndex = buffer.indexOf("\n"); + + while (newlineIndex !== -1) { + const line = buffer.slice(0, newlineIndex).trim(); + buffer = buffer.slice(newlineIndex + 1); + + if (line.length > 0) { + let parsedJson: unknown; + + try { + parsedJson = JSON.parse(line); + } catch (error: unknown) { + throw new JsonError( + `Failed to parse NDJSON line: ${String(error)}`, + line, + ); + } + + subscriber.next(schema.parse(parsedJson)); + } + + newlineIndex = buffer.indexOf("\n"); + } + } + + const trailing = buffer.trim(); + + if (trailing.length > 0) { + let parsedJson: unknown; + + try { + parsedJson = JSON.parse(trailing); + } catch (error: unknown) { + throw new JsonError( + `Failed to parse trailing NDJSON line: ${String(error)}`, + trailing, + ); + } + + subscriber.next(schema.parse(parsedJson)); + } + + subscriber.complete(); + }) + .catch(function (error: unknown) { + if (signal.aborted) { + subscriber.complete(); + + return; + } + + subscriber.error(error); + }); + }); +} diff --git a/resources/ts/urlToAgentDesiredModel.ts b/paddler_client_javascript/src/urlToAgentDesiredModel.ts similarity index 66% rename from resources/ts/urlToAgentDesiredModel.ts rename to paddler_client_javascript/src/urlToAgentDesiredModel.ts index b094a00c..9372ad44 100644 --- a/resources/ts/urlToAgentDesiredModel.ts +++ b/paddler_client_javascript/src/urlToAgentDesiredModel.ts @@ -1,16 +1,18 @@ import { extractHuggingFaceUrlParts } from "./extractHuggingFaceUrlParts"; -import { type
AgentDesiredModel } from "./schemas/AgentDesiredModel"; +import type { AgentDesiredModel } from "./schemas/AgentDesiredModel"; export function urlToAgentDesiredModel(url: URL): AgentDesiredModel { if (url.hostname === "huggingface.co") { return { HuggingFace: extractHuggingFaceUrlParts(url), }; - } else if (url.protocol === "agent:") { + } + + if (url.protocol === "agent:") { return { LocalToAgent: url.pathname, }; - } else { - throw new Error("Unsupported URL format"); } + + throw new Error("Unsupported URL format"); } diff --git a/paddler_client_javascript/src/webSocketProtocol.ts b/paddler_client_javascript/src/webSocketProtocol.ts new file mode 100644 index 00000000..0b3435a7 --- /dev/null +++ b/paddler_client_javascript/src/webSocketProtocol.ts @@ -0,0 +1,3 @@ +export function webSocketProtocol(httpProtocol: string): string { + return httpProtocol === "https:" ? "wss:" : "ws:"; +} diff --git a/paddler_client_javascript/tests/PaddlerError.test.ts b/paddler_client_javascript/tests/PaddlerError.test.ts new file mode 100644 index 00000000..727f8363 --- /dev/null +++ b/paddler_client_javascript/tests/PaddlerError.test.ts @@ -0,0 +1,50 @@ +import { ok, strictEqual } from "node:assert/strict"; +import { test } from "node:test"; + +import { ConnectionDroppedError } from "../src/ConnectionDroppedError"; +import { HttpError } from "../src/HttpError"; +import { JsonError } from "../src/JsonError"; +import { PaddlerError } from "../src/PaddlerError"; +import { ServerError } from "../src/ServerError"; +import { WebSocketError } from "../src/WebSocketError"; + +test("HttpError extends PaddlerError and carries the status code", function () { + const err = new HttpError(503, "Service Unavailable"); + + ok(err instanceof PaddlerError); + ok(err instanceof Error); + strictEqual(err.statusCode, 503); + strictEqual(err.message, "Service Unavailable"); + strictEqual(err.name, "HttpError"); +}); + +test("JsonError carries raw payload alongside its message", function () { + const err = new JsonError("unexpected token", "{not-json"); + + ok(err instanceof PaddlerError); + strictEqual(err.raw, "{not-json"); + strictEqual(err.name, "JsonError"); +}); + +test("WebSocketError is a distinct subclass", function () { + const err = new WebSocketError("socket closed"); + + ok(err instanceof PaddlerError); + strictEqual(err.name, "WebSocketError"); +}); + +test("ConnectionDroppedError carries the request id", function () { + const err = new ConnectionDroppedError("req-1"); + + ok(err instanceof PaddlerError); + strictEqual(err.requestId, "req-1"); + ok(err.message.includes("req-1")); +}); + +test("ServerError carries an integer code", function () { + const err = new ServerError(429, "rate limit"); + + ok(err instanceof PaddlerError); + strictEqual(err.code, 429); + strictEqual(err.message, "rate limit"); +}); diff --git a/paddler_client_javascript/tests/extractHuggingFaceUrlParts.test.ts b/paddler_client_javascript/tests/extractHuggingFaceUrlParts.test.ts new file mode 100644 index 00000000..49e3d557 --- /dev/null +++ b/paddler_client_javascript/tests/extractHuggingFaceUrlParts.test.ts @@ -0,0 +1,48 @@ +import { deepStrictEqual, throws } from "node:assert/strict"; +import { test } from "node:test"; + +import { extractHuggingFaceUrlParts } from "../src/extractHuggingFaceUrlParts"; + +test("blob URL extracts owner, repo, revision and filename", function () { + const url = new URL( + "https://huggingface.co/Qwen/Qwen3-0.6B-GGUF/blob/main/Qwen3-0.6B-Q8_0.gguf", + ); + + deepStrictEqual(extractHuggingFaceUrlParts(url), { + 
filename: "Qwen3-0.6B-Q8_0.gguf", + repo_id: "Qwen/Qwen3-0.6B-GGUF", + revision: "main", + }); +}); + +test("resolve URL extracts the same fields", function () { + const url = new URL( + "https://huggingface.co/Qwen/Qwen3-0.6B-GGUF/resolve/main/Qwen3-0.6B-Q8_0.gguf", + ); + + deepStrictEqual(extractHuggingFaceUrlParts(url), { + filename: "Qwen3-0.6B-Q8_0.gguf", + repo_id: "Qwen/Qwen3-0.6B-GGUF", + revision: "main", + }); +}); + +test("nested filename paths preserve every segment", function () { + const url = new URL( + "https://huggingface.co/owner/repo/blob/main/dir/sub/file.gguf", + ); + + deepStrictEqual(extractHuggingFaceUrlParts(url), { + filename: "dir/sub/file.gguf", + repo_id: "owner/repo", + revision: "main", + }); +}); + +test("malformed URLs throw", function () { + const url = new URL("https://huggingface.co/owner/repo"); + + throws(function () { + extractHuggingFaceUrlParts(url); + }); +}); diff --git a/paddler_client_javascript/tests/fetchJson.test.ts b/paddler_client_javascript/tests/fetchJson.test.ts new file mode 100644 index 00000000..c1b0a1fb --- /dev/null +++ b/paddler_client_javascript/tests/fetchJson.test.ts @@ -0,0 +1,72 @@ +import { deepStrictEqual, rejects } from "node:assert/strict"; +import { createServer, type RequestListener, type Server } from "node:http"; +import { test } from "node:test"; +import { z } from "zod"; + +import { fetchJson } from "../src/fetchJson"; +import { HttpError } from "../src/HttpError"; + +const Schema = z.object({ ok: z.boolean(), count: z.number() }); + +function listenOnce(handler: RequestListener): Promise<{ + server: Server; + url: URL; +}> { + return new Promise(function (resolve) { + const server = createServer(handler); + + server.listen(0, "127.0.0.1", function () { + const address = server.address(); + + if (typeof address !== "object" || address === null) { + throw new Error("server did not bind"); + } + + resolve({ + server, + url: new URL(`http://127.0.0.1:${address.port}/json`), + }); + }); + }); +} + +test("parses JSON body against the schema", async function () { + const { server, url } = await listenOnce(function (_req, res) { + res.writeHead(200, { "Content-Type": "application/json" }); + res.end(JSON.stringify({ ok: true, count: 7 })); + }); + + try { + const result = await fetchJson({ + url, + signal: new AbortController().signal, + schema: Schema, + }); + + deepStrictEqual(result, { ok: true, count: 7 }); + } finally { + server.close(); + } +}); + +test("non-2xx status throws HttpError", async function () { + const { server, url } = await listenOnce(function (_req, res) { + res.writeHead(404); + res.end(); + }); + + try { + await rejects( + async function () { + return fetchJson({ + url, + signal: new AbortController().signal, + schema: Schema, + }); + }, + HttpError, + ); + } finally { + server.close(); + } +}); diff --git a/paddler_client_javascript/tests/schemas/Agent.test.ts b/paddler_client_javascript/tests/schemas/Agent.test.ts new file mode 100644 index 00000000..a01ff9c3 --- /dev/null +++ b/paddler_client_javascript/tests/schemas/Agent.test.ts @@ -0,0 +1,43 @@ +import { strictEqual, throws } from "node:assert/strict"; +import { test } from "node:test"; + +import { AgentSchema } from "../../src/schemas/Agent"; + +test("parses a fully populated agent payload", function () { + const parsed = AgentSchema.parse({ + desired_slots_total: 4, + download_current: 0, + download_filename: null, + download_total: 0, + id: "agent-0", + issues: [], + model_path: "/models/qwen.gguf", + name: "agent-0", + slots_processing: 1, 
+ slots_total: 4, + state_application_status: "Applied", + uses_chat_template_override: false, + }); + + strictEqual(parsed.id, "agent-0"); + strictEqual(parsed.state_application_status, "Applied"); +}); + +test("rejects an unknown state_application_status", function () { + throws(function () { + AgentSchema.parse({ + desired_slots_total: 1, + download_current: 0, + download_filename: null, + download_total: 0, + id: "agent-x", + issues: [], + model_path: null, + name: null, + slots_processing: 0, + slots_total: 1, + state_application_status: "Unknown", + uses_chat_template_override: false, + }); + }); +}); diff --git a/paddler_client_javascript/tests/schemas/InferenceServiceGenerateTokensResponse.test.ts b/paddler_client_javascript/tests/schemas/InferenceServiceGenerateTokensResponse.test.ts new file mode 100644 index 00000000..b9536cb3 --- /dev/null +++ b/paddler_client_javascript/tests/schemas/InferenceServiceGenerateTokensResponse.test.ts @@ -0,0 +1,145 @@ +import { + deepStrictEqual, + notStrictEqual, + ok, + strictEqual, +} from "node:assert/strict"; +import { test } from "node:test"; + +import { InferenceServiceGenerateTokensResponseSchema } from "../../src/schemas/InferenceServiceGenerateTokensResponse"; + +test("ContentToken normalises into a streaming token with content kind", function () { + const parsed = InferenceServiceGenerateTokensResponseSchema.parse({ + Response: { + generated_by: null, + request_id: "req-1", + response: { GeneratedToken: { ContentToken: "Hello" } }, + }, + }); + + strictEqual(parsed.done, false); + strictEqual(parsed.error, null); + strictEqual(parsed.token, "Hello"); + strictEqual(parsed.tokenKind, "content"); + strictEqual(parsed.toolCalls, null); +}); + +test("ReasoningToken maps to reasoning kind", function () { + const parsed = InferenceServiceGenerateTokensResponseSchema.parse({ + Response: { + generated_by: null, + request_id: "req-2", + response: { GeneratedToken: { ReasoningToken: "thinking..." 
} }, + }, + }); + + strictEqual(parsed.token, "thinking..."); + strictEqual(parsed.tokenKind, "reasoning"); +}); + +test("Done normalises with the full usage summary", function () { + const parsed = InferenceServiceGenerateTokensResponseSchema.parse({ + Response: { + generated_by: null, + request_id: "req-3", + response: { + GeneratedToken: { + Done: { + usage: { + prompt_tokens: 10, + cached_prompt_tokens: 0, + input_image_tokens: 0, + input_audio_tokens: 0, + content_tokens: 5, + reasoning_tokens: 0, + tool_call_tokens: 0, + undeterminable_tokens: 0, + }, + }, + }, + }, + }, + }); + + strictEqual(parsed.done, true); + strictEqual(parsed.error, null); + deepStrictEqual(parsed.summary?.usage.prompt_tokens, 10); +}); + +test("ToolCallValidatorBuildFailed normalises to a terminal error", function () { + const parsed = InferenceServiceGenerateTokensResponseSchema.parse({ + Response: { + generated_by: null, + request_id: "req-4", + response: { + GeneratedToken: { + ToolCallValidatorBuildFailed: "schema invalid", + }, + }, + }, + }); + + strictEqual(parsed.done, true); + deepStrictEqual(parsed.error, { code: 400, description: "schema invalid" }); +}); + +test("Top-level Error envelope normalises to terminal error", function () { + const parsed = InferenceServiceGenerateTokensResponseSchema.parse({ + Error: { + request_id: "req-5", + error: { code: 500, description: "boom" }, + }, + }); + + strictEqual(parsed.done, true); + deepStrictEqual(parsed.error, { code: 500, description: "boom" }); +}); + +test("UnrecognizedToolCallFormat preserves text and FFI error message", function () { + const parsed = InferenceServiceGenerateTokensResponseSchema.parse({ + Response: { + generated_by: null, + request_id: "req-6", + response: { + GeneratedToken: { + UnrecognizedToolCallFormat: { + text: "raw", + ffi_error_message: "common_chat_parse failed: no parser", + }, + }, + }, + }, + }); + + strictEqual(parsed.done, false); + strictEqual(parsed.error, null); + strictEqual(parsed.ok, true); + strictEqual(parsed.token, null); + strictEqual(parsed.tokenKind, null); + strictEqual(parsed.toolCalls, null); + deepStrictEqual(parsed.rawToolCallTokens, { + text: "raw", + ffi_error_message: "common_chat_parse failed: no parser", + }); +}); + +test("ImageExceedsBatchSize is terminal and describes token counts", function () { + const parsed = InferenceServiceGenerateTokensResponseSchema.parse({ + Response: { + generated_by: null, + request_id: "req-7", + response: { + GeneratedToken: { + ImageExceedsBatchSize: { image_tokens: 368, n_batch: 100 }, + }, + }, + }, + }); + + strictEqual(parsed.done, true); + strictEqual(parsed.ok, false); + notStrictEqual(parsed.error, null); + strictEqual(parsed.error?.code, 400); + ok(parsed.error?.description.includes("368")); + ok(parsed.error?.description.includes("100")); +}); diff --git a/paddler_client_javascript/tests/schemas/ParsedToolCall.test.ts b/paddler_client_javascript/tests/schemas/ParsedToolCall.test.ts new file mode 100644 index 00000000..b43ecebb --- /dev/null +++ b/paddler_client_javascript/tests/schemas/ParsedToolCall.test.ts @@ -0,0 +1,36 @@ +import { deepStrictEqual, strictEqual, throws } from "node:assert/strict"; +import { test } from "node:test"; + +import { ParsedToolCallSchema } from "../../src/schemas/ParsedToolCall"; + +test("ValidJson arguments parse with the inner JSON kept intact", function () { + const parsed = ParsedToolCallSchema.parse({ + id: "call_0", + name: "get_weather", + arguments: { ValidJson: { location: "Paris" } }, + }); + + 
strictEqual(parsed.id, "call_0"); + strictEqual(parsed.name, "get_weather"); + deepStrictEqual(parsed.arguments, { ValidJson: { location: "Paris" } }); +}); + +test("InvalidJson arguments preserve the raw string", function () { + const parsed = ParsedToolCallSchema.parse({ + id: "call_1", + name: "get_weather", + arguments: { InvalidJson: "not json" }, + }); + + deepStrictEqual(parsed.arguments, { InvalidJson: "not json" }); +}); + +test("rejects payloads missing the discriminated arguments wrapper", function () { + throws(function () { + ParsedToolCallSchema.parse({ + id: "call_2", + name: "get_weather", + arguments: { location: "Paris" }, + }); + }); +}); diff --git a/paddler_client_javascript/tests/streamHttpNdjson.test.ts b/paddler_client_javascript/tests/streamHttpNdjson.test.ts new file mode 100644 index 00000000..677a79d6 --- /dev/null +++ b/paddler_client_javascript/tests/streamHttpNdjson.test.ts @@ -0,0 +1,86 @@ +import { deepStrictEqual, rejects } from "node:assert/strict"; +import { createServer, type RequestListener, type Server } from "node:http"; +import { test } from "node:test"; +import { firstValueFrom, lastValueFrom, toArray } from "rxjs"; +import { z } from "zod"; + +import { HttpError } from "../src/HttpError"; +import { streamHttpNdjson } from "../src/streamHttpNdjson"; + +const Schema = z.object({ index: z.number() }); + +function listenOnce(handler: RequestListener): Promise<{ + server: Server; + url: URL; +}> { + return new Promise(function (resolve) { + const server = createServer(handler); + + server.listen(0, "127.0.0.1", function () { + const address = server.address(); + + if (typeof address !== "object" || address === null) { + throw new Error("server did not bind"); + } + + resolve({ + server, + url: new URL(`http://127.0.0.1:${address.port}/stream`), + }); + }); + }); +} + +test("yields parsed messages from an NDJSON stream", async function () { + const { server, url } = await listenOnce(function (_req, res) { + res.writeHead(200, { "Content-Type": "application/x-ndjson" }); + res.write(`${JSON.stringify({ index: 0 })}\n`); + res.write(`${JSON.stringify({ index: 1 })}\n`); + res.write(`${JSON.stringify({ index: 2 })}\n`); + res.end(); + }); + + try { + const messages = await lastValueFrom( + streamHttpNdjson({ + url, + body: {}, + signal: new AbortController().signal, + schema: Schema, + }).pipe(toArray()), + ); + + deepStrictEqual(messages, [ + { index: 0 }, + { index: 1 }, + { index: 2 }, + ]); + } finally { + server.close(); + } +}); + +test("non-2xx response throws HttpError", async function () { + const { server, url } = await listenOnce(function (_req, res) { + res.writeHead(503); + res.end(); + }); + + try { + await rejects( + async function () { + await firstValueFrom( + streamHttpNdjson({ + url, + body: {}, + signal: new AbortController().signal, + schema: Schema, + }), + ); + }, + HttpError, + ); + } finally { + server.close(); + } +}); diff --git a/paddler_client_javascript/tests/urlToAgentDesiredModel.test.ts b/paddler_client_javascript/tests/urlToAgentDesiredModel.test.ts new file mode 100644 index 00000000..06e6b517 --- /dev/null +++ b/paddler_client_javascript/tests/urlToAgentDesiredModel.test.ts @@ -0,0 +1,37 @@ +import { deepStrictEqual, throws } from "node:assert/strict"; +import { test } from "node:test"; + +import { urlToAgentDesiredModel } from "../src/urlToAgentDesiredModel"; + +test("recognizes Hugging Face URLs as HuggingFace variant", function () { + const url = new URL( + 
"https://huggingface.co/Qwen/Qwen3-0.6B-GGUF/blob/main/Qwen3-0.6B-Q8_0.gguf", + ); + + deepStrictEqual(urlToAgentDesiredModel(url), { + HuggingFace: { + filename: "Qwen3-0.6B-Q8_0.gguf", + repo_id: "Qwen/Qwen3-0.6B-GGUF", + revision: "main", + }, + }); +}); + +test("agent: URLs become LocalToAgent variant", function () { + const url = new URL("agent:///home/user/models/Qwen3-0.6B-Q8_0.gguf"); + + deepStrictEqual(urlToAgentDesiredModel(url), { + LocalToAgent: "/home/user/models/Qwen3-0.6B-Q8_0.gguf", + }); +}); + +test("unsupported URLs throw", function () { + const url = new URL("https://example.com/some/path"); + + throws( + function () { + urlToAgentDesiredModel(url); + }, + { message: "Unsupported URL format" }, + ); +}); diff --git a/paddler_client_javascript/tests/webSocketProtocol.test.ts b/paddler_client_javascript/tests/webSocketProtocol.test.ts new file mode 100644 index 00000000..73d4e8e1 --- /dev/null +++ b/paddler_client_javascript/tests/webSocketProtocol.test.ts @@ -0,0 +1,16 @@ +import { strictEqual } from "node:assert/strict"; +import { test } from "node:test"; + +import { webSocketProtocol } from "../src/webSocketProtocol"; + +test("https: maps to wss:", function () { + strictEqual(webSocketProtocol("https:"), "wss:"); +}); + +test("http: maps to ws:", function () { + strictEqual(webSocketProtocol("http:"), "ws:"); +}); + +test("anything other than https: maps to ws:", function () { + strictEqual(webSocketProtocol("file:"), "ws:"); +}); diff --git a/paddler_client_javascript/tsconfig.json b/paddler_client_javascript/tsconfig.json new file mode 100644 index 00000000..0f1ab405 --- /dev/null +++ b/paddler_client_javascript/tsconfig.json @@ -0,0 +1,20 @@ +{ + "compilerOptions": { + "target": "ES2022", + "module": "ESNext", + "moduleResolution": "Bundler", + "lib": ["ES2022", "DOM"], + "strict": true, + "noUncheckedIndexedAccess": true, + "exactOptionalPropertyTypes": true, + "declaration": true, + "declarationMap": true, + "outDir": "dist", + "rootDir": "src", + "isolatedModules": true, + "skipLibCheck": true, + "resolveJsonModule": true, + "forceConsistentCasingInFileNames": true + }, + "include": ["src/**/*"] +} diff --git a/paddler_client_javascript/tsconfig.test.json b/paddler_client_javascript/tsconfig.test.json new file mode 100644 index 00000000..908772ec --- /dev/null +++ b/paddler_client_javascript/tsconfig.test.json @@ -0,0 +1,8 @@ +{ + "extends": "./tsconfig.json", + "compilerOptions": { + "noEmit": true, + "rootDir": "." 
+ }, + "include": ["src/**/*", "tests/**/*"] +} diff --git a/paddler_client_python/.gitignore b/paddler_client_python/.gitignore index c8c7d31d..6b152e7e 100644 --- a/paddler_client_python/.gitignore +++ b/paddler_client_python/.gitignore @@ -8,6 +8,7 @@ build/ *.egg .eggs/ *.whl +.venv .venv/ venv/ .mypy_cache/ diff --git a/paddler_client_python/Makefile b/paddler_client_python/Makefile deleted file mode 100644 index 062f5ee7..00000000 --- a/paddler_client_python/Makefile +++ /dev/null @@ -1,33 +0,0 @@ -.venv: poetry.lock - poetry sync - touch .venv - -poetry.lock: pyproject.toml - poetry lock - touch poetry.lock - -.PHONY: check -check: lint test - -.PHONY: lint -lint: mypy pyright ruff - -.PHONY: mypy -mypy: .venv - poetry run mypy paddler_client tests - -.PHONY: pyright -pyright: .venv - poetry run pyright - -.PHONY: ruff -ruff: .venv - poetry run ruff check paddler_client tests - -.PHONY: fmt -fmt: - poetry run ruff format paddler_client tests - -.PHONY: test -test: .venv - poetry run pytest diff --git a/paddler_client_python/paddler_client/inference_message.py b/paddler_client_python/paddler_client/inference_message.py index ec872484..c8e63f1e 100644 --- a/paddler_client_python/paddler_client/inference_message.py +++ b/paddler_client_python/paddler_client/inference_message.py @@ -1,30 +1,104 @@ from __future__ import annotations import json +from collections.abc import Callable from dataclasses import dataclass from enum import StrEnum -from typing import Any +from typing import Any, cast from paddler_client.embedding import Embedding +from paddler_client.oversized_image_details import OversizedImageDetails +from paddler_client.parsed_tool_call import ParsedToolCall +from paddler_client.raw_tool_call_tokens import RawToolCallTokens class InferenceMessageKind(StrEnum): CHAT_TEMPLATE_ERROR = "chat_template_error" + CONTENT_TOKEN = "content_token" DONE = "done" EMBEDDING = "embedding" EMBEDDING_DONE = "embedding_done" EMBEDDING_ERROR = "embedding_error" + EMBEDDING_REJECTED_DUE_TO_ACTIVE_TOKEN_GENERATION = ( + "embedding_rejected_due_to_active_token_generation" + ) + EMBEDDING_NO_EMBEDDINGS_PRODUCED = "embedding_no_embeddings_produced" GRAMMAR_INCOMPATIBLE_WITH_THINKING = "grammar_incompatible_with_thinking" GRAMMAR_INITIALIZATION_FAILED = "grammar_initialization_failed" GRAMMAR_REJECTED_MODEL_OUTPUT = "grammar_rejected_model_output" GRAMMAR_SYNTAX_ERROR = "grammar_syntax_error" IMAGE_DECODING_FAILED = "image_decoding_failed" + IMAGE_EXCEEDS_BATCH_SIZE = "image_exceeds_batch_size" MULTIMODAL_NOT_SUPPORTED = "multimodal_not_supported" + REASONING_TOKEN = "reasoning_token" SAMPLER_ERROR = "sampler_error" SERVER_ERROR = "server_error" TIMEOUT = "timeout" - TOKEN = "token" + TOOL_CALL_PARSED = "tool_call_parsed" + TOOL_CALL_PARSE_FAILED = "tool_call_parse_failed" + TOOL_CALL_TOKEN = "tool_call_token" + TOOL_CALL_VALIDATION_FAILED = "tool_call_validation_failed" + TOOL_CALL_VALIDATOR_BUILD_FAILED = "tool_call_validator_build_failed" TOO_MANY_BUFFERED_REQUESTS = "too_many_buffered_requests" + UNDETERMINABLE_TOKEN = "undeterminable_token" + UNRECOGNIZED_TOOL_CALL_FORMAT = "unrecognized_tool_call_format" + + +_TOKEN_KINDS: frozenset[InferenceMessageKind] = frozenset( + { + InferenceMessageKind.CONTENT_TOKEN, + InferenceMessageKind.REASONING_TOKEN, + InferenceMessageKind.TOOL_CALL_TOKEN, + InferenceMessageKind.UNDETERMINABLE_TOKEN, + }, +) + + +@dataclass(frozen=True) +class TokenUsage: + prompt_tokens: int = 0 + cached_prompt_tokens: int = 0 + input_image_tokens: int = 0 + input_audio_tokens: 
int = 0 + content_tokens: int = 0 + reasoning_tokens: int = 0 + tool_call_tokens: int = 0 + undeterminable_tokens: int = 0 + + @property + def completion_tokens(self) -> int: + return ( + self.content_tokens + + self.reasoning_tokens + + self.tool_call_tokens + + self.undeterminable_tokens + ) + + @property + def total_tokens(self) -> int: + return self.prompt_tokens + self.completion_tokens + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> TokenUsage: + return cls( + prompt_tokens=int(data.get("prompt_tokens", 0)), + cached_prompt_tokens=int(data.get("cached_prompt_tokens", 0)), + input_image_tokens=int(data.get("input_image_tokens", 0)), + input_audio_tokens=int(data.get("input_audio_tokens", 0)), + content_tokens=int(data.get("content_tokens", 0)), + reasoning_tokens=int(data.get("reasoning_tokens", 0)), + tool_call_tokens=int(data.get("tool_call_tokens", 0)), + undeterminable_tokens=int(data.get("undeterminable_tokens", 0)), + ) + + +@dataclass(frozen=True) +class GenerationSummary: + usage: TokenUsage + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> GenerationSummary: + return cls(usage=TokenUsage.from_dict(data.get("usage", {}))) @dataclass(frozen=True) @@ -35,10 +109,15 @@ class InferenceMessage: embedding_data: Embedding | None = None error_message: str | None = None error_code: int | None = None + summary: GenerationSummary | None = None + parsed_tool_calls: list[ParsedToolCall] | None = None + raw_tool_call_tokens: RawToolCallTokens | None = None + oversized_image_details: OversizedImageDetails | None = None + generated_by: str | None = None @property def is_token(self) -> bool: - return self.kind == InferenceMessageKind.TOKEN + return self.kind in _TOKEN_KINDS @property def is_done(self) -> bool: @@ -46,10 +125,7 @@ def is_done(self) -> bool: @property def is_terminal(self) -> bool: - return self.kind not in ( - InferenceMessageKind.TOKEN, - InferenceMessageKind.EMBEDDING, - ) + return not self.is_token and self.kind != InferenceMessageKind.EMBEDDING def parse_inference_client_message( @@ -71,6 +147,7 @@ def parse_inference_client_message( return _parse_response( response_envelope["request_id"], response_envelope["response"], + response_envelope.get("generated_by"), ) msg = f"Unknown inference client message format: {data}" @@ -93,18 +170,21 @@ def _parse_error_envelope( def _parse_response( request_id: str, response: str | dict[str, Any], + generated_by: str | None, ) -> InferenceMessage: if isinstance(response, str): if response == "Timeout": return InferenceMessage( request_id=request_id, kind=InferenceMessageKind.TIMEOUT, + generated_by=generated_by, ) if response == "TooManyBufferedRequests": return InferenceMessage( request_id=request_id, kind=InferenceMessageKind.TOO_MANY_BUFFERED_REQUESTS, + generated_by=generated_by, ) msg = f"Unknown response variant: {response}" @@ -114,12 +194,14 @@ def _parse_response( return _parse_generated_token_result( request_id, response["GeneratedToken"], + generated_by, ) if "Embedding" in response: return _parse_embedding_result( request_id, response["Embedding"], + generated_by, ) msg = f"Unknown response: {response}" @@ -137,35 +219,185 @@ def _parse_response( "ImageDecodingFailed": InferenceMessageKind.IMAGE_DECODING_FAILED, "MultimodalNotSupported": InferenceMessageKind.MULTIMODAL_NOT_SUPPORTED, "SamplerError": InferenceMessageKind.SAMPLER_ERROR, + "ToolCallValidatorBuildFailed": ( + InferenceMessageKind.TOOL_CALL_VALIDATOR_BUILD_FAILED + ), } -def _parse_generated_token_result( +_GENERATED_TOKEN_KINDS: 
dict[str, InferenceMessageKind] = { + "ContentToken": InferenceMessageKind.CONTENT_TOKEN, + "ReasoningToken": InferenceMessageKind.REASONING_TOKEN, + "ToolCallToken": InferenceMessageKind.TOOL_CALL_TOKEN, + "UndeterminableToken": InferenceMessageKind.UNDETERMINABLE_TOKEN, +} + + +def _build_done_message( request_id: str, - data: str | dict[str, Any], + payload: Any, + generated_by: str | None, ) -> InferenceMessage: - if data == "Done": - return InferenceMessage( - request_id=request_id, - kind=InferenceMessageKind.DONE, - ) + return InferenceMessage( + request_id=request_id, + kind=InferenceMessageKind.DONE, + summary=GenerationSummary.from_dict(payload), + generated_by=generated_by, + ) - if isinstance(data, dict): - if "Token" in data: - return InferenceMessage( - request_id=request_id, - kind=InferenceMessageKind.TOKEN, - token=data["Token"], - ) - for key, kind in _GENERATED_TOKEN_ERROR_KINDS.items(): - if key in data: - return InferenceMessage( - request_id=request_id, - kind=kind, - error_message=data[key], - ) +def _build_tool_call_parsed_message( + request_id: str, + payload: Any, + generated_by: str | None, +) -> InferenceMessage: + if not isinstance(payload, list): + msg = f"ToolCallParsed payload is not a list: {payload}" + raise TypeError(msg) + typed_calls = cast("list[dict[str, Any]]", payload) + parsed_calls: list[ParsedToolCall] = [ + ParsedToolCall.from_dict(call) for call in typed_calls + ] + return InferenceMessage( + request_id=request_id, + kind=InferenceMessageKind.TOOL_CALL_PARSED, + parsed_tool_calls=parsed_calls, + generated_by=generated_by, + ) + + +def _build_tool_call_parse_failed_message( + request_id: str, + payload: Any, + generated_by: str | None, +) -> InferenceMessage: + return InferenceMessage( + request_id=request_id, + kind=InferenceMessageKind.TOOL_CALL_PARSE_FAILED, + error_message=str(payload), + generated_by=generated_by, + ) + + +def _build_tool_call_validation_failed_message( + request_id: str, + payload: Any, + generated_by: str | None, +) -> InferenceMessage: + if not isinstance(payload, list): + msg = f"ToolCallValidationFailed payload is not a list: {payload}" + raise TypeError(msg) + typed_errors = cast("list[object]", payload) + joined_errors: str = "; ".join(str(error) for error in typed_errors) + return InferenceMessage( + request_id=request_id, + kind=InferenceMessageKind.TOOL_CALL_VALIDATION_FAILED, + error_message=joined_errors, + generated_by=generated_by, + ) + +def _build_unrecognized_tool_call_format_message( + request_id: str, + payload: Any, + generated_by: str | None, +) -> InferenceMessage: + if not isinstance(payload, dict): + msg = f"UnrecognizedToolCallFormat payload is not a dict: {payload!r}" + raise TypeError(msg) + typed_raw = cast("dict[str, Any]", payload) + return InferenceMessage( + request_id=request_id, + kind=InferenceMessageKind.UNRECOGNIZED_TOOL_CALL_FORMAT, + raw_tool_call_tokens=RawToolCallTokens.from_dict(typed_raw), + generated_by=generated_by, + ) + + +def _build_image_exceeds_batch_size_message( + request_id: str, + payload: Any, + generated_by: str | None, +) -> InferenceMessage: + if not isinstance(payload, dict): + msg = f"ImageExceedsBatchSize payload is not a dict: {payload!r}" + raise TypeError(msg) + typed_details = cast("dict[str, Any]", payload) + return InferenceMessage( + request_id=request_id, + kind=InferenceMessageKind.IMAGE_EXCEEDS_BATCH_SIZE, + oversized_image_details=OversizedImageDetails.from_dict(typed_details), + generated_by=generated_by, + ) + + +def _build_token_kind_message( + 
request_id: str, + kind: InferenceMessageKind, + payload: Any, + generated_by: str | None, +) -> InferenceMessage: + return InferenceMessage( + request_id=request_id, + kind=kind, + token=payload, + generated_by=generated_by, + ) + + +def _build_error_kind_message( + request_id: str, + kind: InferenceMessageKind, + payload: Any, + generated_by: str | None, +) -> InferenceMessage: + return InferenceMessage( + request_id=request_id, + kind=kind, + error_message=payload, + generated_by=generated_by, + ) + + +_StructuredHandler = Callable[[str, Any, str | None], InferenceMessage] + +_STRUCTURED_HANDLERS: dict[str, _StructuredHandler] = { + "Done": _build_done_message, + "ToolCallParsed": _build_tool_call_parsed_message, + "ToolCallParseFailed": _build_tool_call_parse_failed_message, + "ToolCallValidationFailed": _build_tool_call_validation_failed_message, + "UnrecognizedToolCallFormat": _build_unrecognized_tool_call_format_message, + "ImageExceedsBatchSize": _build_image_exceeds_batch_size_message, +} + + +def _parse_generated_token_result( + request_id: str, + data: str | dict[str, Any], + generated_by: str | None, +) -> InferenceMessage: + if not isinstance(data, dict): + msg = f"Unknown GeneratedTokenResult: {data}" + raise TypeError(msg) + for structured_key, handler in _STRUCTURED_HANDLERS.items(): + if structured_key in data: + return handler(request_id, data[structured_key], generated_by) + for token_key, token_kind in _GENERATED_TOKEN_KINDS.items(): + if token_key in data: + return _build_token_kind_message( + request_id, + token_kind, + data[token_key], + generated_by, + ) + for error_key, error_kind in _GENERATED_TOKEN_ERROR_KINDS.items(): + if error_key in data: + return _build_error_kind_message( + request_id, + error_kind, + data[error_key], + generated_by, + ) msg = f"Unknown GeneratedTokenResult: {data}" raise ValueError(msg) @@ -173,11 +405,27 @@ def _parse_generated_token_result( def _parse_embedding_result( request_id: str, data: str | dict[str, Any], + generated_by: str | None, ) -> InferenceMessage: if data == "Done": return InferenceMessage( request_id=request_id, kind=InferenceMessageKind.EMBEDDING_DONE, + generated_by=generated_by, + ) + + if data == "EmbeddingRejectedDueToActiveTokenGeneration": + return InferenceMessage( + request_id=request_id, + kind=InferenceMessageKind.EMBEDDING_REJECTED_DUE_TO_ACTIVE_TOKEN_GENERATION, + generated_by=generated_by, + ) + + if data == "NoEmbeddingsProduced": + return InferenceMessage( + request_id=request_id, + kind=InferenceMessageKind.EMBEDDING_NO_EMBEDDINGS_PRODUCED, + generated_by=generated_by, ) if isinstance(data, dict): @@ -188,6 +436,7 @@ def _parse_embedding_result( request_id=request_id, kind=InferenceMessageKind.EMBEDDING, embedding_data=embedding, + generated_by=generated_by, ) if "Error" in data: @@ -195,6 +444,7 @@ def _parse_embedding_result( request_id=request_id, kind=InferenceMessageKind.EMBEDDING_ERROR, error_message=data["Error"], + generated_by=generated_by, ) msg = f"Unknown EmbeddingResult: {data}" diff --git a/paddler_client_python/paddler_client/inference_parameters.py b/paddler_client_python/paddler_client/inference_parameters.py index 73d20f20..44e674f2 100644 --- a/paddler_client_python/paddler_client/inference_parameters.py +++ b/paddler_client_python/paddler_client/inference_parameters.py @@ -4,7 +4,7 @@ class InferenceParameters(BaseModel): - batch_n_tokens: int = 512 + n_batch: int = 2048 context_size: int = 8192 enable_embeddings: bool = False image_resize_to_fit: int = 1024 diff --git 
a/paddler_client_python/paddler_client/oversized_image_details.py b/paddler_client_python/paddler_client/oversized_image_details.py new file mode 100644 index 00000000..938a1d2c --- /dev/null +++ b/paddler_client_python/paddler_client/oversized_image_details.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True) +class OversizedImageDetails: + image_tokens: int + n_batch: int + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> OversizedImageDetails: + return cls( + image_tokens=int(data["image_tokens"]), + n_batch=int(data["n_batch"]), + ) diff --git a/paddler_client_python/paddler_client/parsed_tool_call.py b/paddler_client_python/paddler_client/parsed_tool_call.py new file mode 100644 index 00000000..969f2004 --- /dev/null +++ b/paddler_client_python/paddler_client/parsed_tool_call.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, cast + +from paddler_client.tool_call_arguments import ( + ToolCallArguments, + parse_tool_call_arguments, +) + + +@dataclass(frozen=True) +class ParsedToolCall: + id: str + name: str + arguments: ToolCallArguments + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> ParsedToolCall: + arguments_payload = data["arguments"] + if not isinstance(arguments_payload, dict): + msg = ( + f"arguments field must be a dict (tagged enum), " + f"got: {arguments_payload!r}" + ) + raise TypeError(msg) + typed_payload = cast("dict[str, Any]", arguments_payload) + return cls( + id=str(data["id"]), + name=str(data["name"]), + arguments=parse_tool_call_arguments(typed_payload), + ) diff --git a/paddler_client_python/paddler_client/raw_tool_call_tokens.py b/paddler_client_python/paddler_client/raw_tool_call_tokens.py new file mode 100644 index 00000000..37f4404c --- /dev/null +++ b/paddler_client_python/paddler_client/raw_tool_call_tokens.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True) +class RawToolCallTokens: + text: str + ffi_error_message: str + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> RawToolCallTokens: + return cls( + text=str(data["text"]), + ffi_error_message=str(data["ffi_error_message"]), + ) diff --git a/paddler_client_python/paddler_client/tool_call_arguments.py b/paddler_client_python/paddler_client/tool_call_arguments.py new file mode 100644 index 00000000..2199fb31 --- /dev/null +++ b/paddler_client_python/paddler_client/tool_call_arguments.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True) +class ValidJson: + value: Any + + +@dataclass(frozen=True) +class InvalidJson: + raw: str + + +ToolCallArguments = ValidJson | InvalidJson + + +def parse_tool_call_arguments(payload: dict[str, Any]) -> ToolCallArguments: + if "ValidJson" in payload: + return ValidJson(payload["ValidJson"]) + if "InvalidJson" in payload: + return InvalidJson(str(payload["InvalidJson"])) + msg = f"Unknown ToolCallArguments shape: {payload}" + raise ValueError(msg) diff --git a/paddler_client_python/poetry.lock b/paddler_client_python/poetry.lock index 330fda2c..dae3d93c 100644 --- a/paddler_client_python/poetry.lock +++ b/paddler_client_python/poetry.lock @@ -749,30 +749,30 @@ testing = ["process-tests", "pytest-xdist", "virtualenv"] [[package]] name = "ruff" -version = "0.11.13" +version = "0.15.12" description 
= "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" groups = ["dev"] files = [ - {file = "ruff-0.11.13-py3-none-linux_armv6l.whl", hash = "sha256:4bdfbf1240533f40042ec00c9e09a3aade6f8c10b6414cf11b519488d2635d46"}, - {file = "ruff-0.11.13-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:aef9c9ed1b5ca28bb15c7eac83b8670cf3b20b478195bd49c8d756ba0a36cf48"}, - {file = "ruff-0.11.13-py3-none-macosx_11_0_arm64.whl", hash = "sha256:53b15a9dfdce029c842e9a5aebc3855e9ab7771395979ff85b7c1dedb53ddc2b"}, - {file = "ruff-0.11.13-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ab153241400789138d13f362c43f7edecc0edfffce2afa6a68434000ecd8f69a"}, - {file = "ruff-0.11.13-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6c51f93029d54a910d3d24f7dd0bb909e31b6cd989a5e4ac513f4eb41629f0dc"}, - {file = "ruff-0.11.13-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1808b3ed53e1a777c2ef733aca9051dc9bf7c99b26ece15cb59a0320fbdbd629"}, - {file = "ruff-0.11.13-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:d28ce58b5ecf0f43c1b71edffabe6ed7f245d5336b17805803312ec9bc665933"}, - {file = "ruff-0.11.13-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:55e4bc3a77842da33c16d55b32c6cac1ec5fb0fbec9c8c513bdce76c4f922165"}, - {file = "ruff-0.11.13-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:633bf2c6f35678c56ec73189ba6fa19ff1c5e4807a78bf60ef487b9dd272cc71"}, - {file = "ruff-0.11.13-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4ffbc82d70424b275b089166310448051afdc6e914fdab90e08df66c43bb5ca9"}, - {file = "ruff-0.11.13-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4a9ddd3ec62a9a89578c85842b836e4ac832d4a2e0bfaad3b02243f930ceafcc"}, - {file = "ruff-0.11.13-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:d237a496e0778d719efb05058c64d28b757c77824e04ffe8796c7436e26712b7"}, - {file = "ruff-0.11.13-py3-none-musllinux_1_2_i686.whl", hash = "sha256:26816a218ca6ef02142343fd24c70f7cd8c5aa6c203bca284407adf675984432"}, - {file = "ruff-0.11.13-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:51c3f95abd9331dc5b87c47ac7f376db5616041173826dfd556cfe3d4977f492"}, - {file = "ruff-0.11.13-py3-none-win32.whl", hash = "sha256:96c27935418e4e8e77a26bb05962817f28b8ef3843a6c6cc49d8783b5507f250"}, - {file = "ruff-0.11.13-py3-none-win_amd64.whl", hash = "sha256:29c3189895a8a6a657b7af4e97d330c8a3afd2c9c8f46c81e2fc5a31866517e3"}, - {file = "ruff-0.11.13-py3-none-win_arm64.whl", hash = "sha256:b4385285e9179d608ff1d2fb9922062663c658605819a6876d8beef0c30b7f3b"}, - {file = "ruff-0.11.13.tar.gz", hash = "sha256:26fa247dc68d1d4e72c179e08889a25ac0c7ba4d78aecfc835d49cbfd60bf514"}, + {file = "ruff-0.15.12-py3-none-linux_armv6l.whl", hash = "sha256:f86f176e188e94d6bdbc09f09bfd9dc729059ad93d0e7390b5a73efe19f8861c"}, + {file = "ruff-0.15.12-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:e3bcd123364c3770b8e1b7baaf343cc99a35f197c5c6e8af79015c666c423a6c"}, + {file = "ruff-0.15.12-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fe87510d000220aa1ed530d4448a7c696a0cae1213e5ec30e5874287b66557b5"}, + {file = "ruff-0.15.12-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:84a1630093121375a3e2a95b4a6dc7b59e2b4ee76216e32d81aae550a832d002"}, + {file = "ruff-0.15.12-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fb129f40f114f089ebe0ca56c0d251cf2061b17651d464bb6478dc01e69f11f5"}, + {file = 
"ruff-0.15.12-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b0c862b172d695db7598426b8af465e7e9ac00a3ea2a3630ee67eb82e366aaa6"}, + {file = "ruff-0.15.12-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2849ea9f3484c3aca43a82f484210370319e7170df4dfe4843395ddf6c57bc33"}, + {file = "ruff-0.15.12-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9e77c7e51c07fe396826d5969a5b846d9cd4c402535835fb6e21ce8b28fef847"}, + {file = "ruff-0.15.12-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83b2f4f2f3b1026b5fb449b467d9264bf22067b600f7b6f41fc5958909f449d0"}, + {file = "ruff-0.15.12-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:9ba3b8f1afd7e2e43d8943e55f249e13f9682fde09711644a6e7290eb4f3e339"}, + {file = "ruff-0.15.12-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e852ba9fdc890655e1d78f2df1499efbe0e54126bd405362154a75e2bde159c5"}, + {file = "ruff-0.15.12-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:dd8aed930da53780d22fc70bdf84452c843cf64f8cb4eb38984319c24c5cd5fd"}, + {file = "ruff-0.15.12-py3-none-musllinux_1_2_i686.whl", hash = "sha256:01da3988d225628b709493d7dc67c3b9b12c0210016b08690ef9bd27970b262b"}, + {file = "ruff-0.15.12-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:9cae0f92bd5700d1213188b31cd3bdd2b315361296d10b96b8e2337d3d11f53e"}, + {file = "ruff-0.15.12-py3-none-win32.whl", hash = "sha256:d0185894e038d7043ba8fd6aee7499ece6462dc0ea9f1e260c7451807c714c20"}, + {file = "ruff-0.15.12-py3-none-win_amd64.whl", hash = "sha256:c87a162d61ab3adca47c03f7f717c68672edec7d1b5499e652331780fe74950d"}, + {file = "ruff-0.15.12-py3-none-win_arm64.whl", hash = "sha256:a538f7a82d061cee7be55542aca1d86d1393d55d81d4fcc314370f4340930d4f"}, + {file = "ruff-0.15.12.tar.gz", hash = "sha256:ecea26adb26b4232c0c2ca19ccbc0083a68344180bba2a600605538ce51a40a6"}, ] [[package]] @@ -884,4 +884,4 @@ files = [ [metadata] lock-version = "2.1" python-versions = "^3.11" -content-hash = "583a72dfa1b3b77080e861788d046f3b896d7e0d4045cfb371fac424ad51d54b" +content-hash = "8ede1b613e2e51eabdae4fd71df2ce16d21bf2c3063b3769d016d9ecf452d629" diff --git a/paddler_client_python/pyproject.toml b/paddler_client_python/pyproject.toml index 910f011a..f6699452 100644 --- a/paddler_client_python/pyproject.toml +++ b/paddler_client_python/pyproject.toml @@ -25,7 +25,7 @@ pytest-asyncio = "^1.3" pytest-cov = "^7" pygments = "^2.20" pyright = "^1" -ruff = "^0.11" +ruff = "0.15.12" mypy = "^1" [build-system] diff --git a/paddler_client_python/tests/test_client_inference.py b/paddler_client_python/tests/test_client_inference.py index b17ae7d0..f22f5b88 100644 --- a/paddler_client_python/tests/test_client_inference.py +++ b/paddler_client_python/tests/test_client_inference.py @@ -27,7 +27,7 @@ def _make_ndjson_token_response( { "Response": { "request_id": request_id, - "response": {"GeneratedToken": {"Token": token}}, + "response": {"GeneratedToken": {"ContentToken": token}}, } } ) @@ -38,7 +38,22 @@ def _make_ndjson_done_response(request_id: str) -> str: { "Response": { "request_id": request_id, - "response": {"GeneratedToken": "Done"}, + "response": { + "GeneratedToken": { + "Done": { + "usage": { + "prompt_tokens": 0, + "cached_prompt_tokens": 0, + "input_image_tokens": 0, + "input_audio_tokens": 0, + "content_tokens": 0, + "reasoning_tokens": 0, + "tool_call_tokens": 0, + "undeterminable_tokens": 0, + } + } + } + }, } } ) @@ -119,7 +134,7 @@ def handler(request: httpx.Request) -> httpx.Response: messages.append(message) assert len(messages) 
== 2 - assert messages[0].kind == InferenceMessageKind.TOKEN + assert messages[0].kind == InferenceMessageKind.CONTENT_TOKEN assert messages[0].token == "hi" assert messages[1].kind == InferenceMessageKind.DONE diff --git a/paddler_client_python/tests/test_client_management.py b/paddler_client_python/tests/test_client_management.py index fdacccfa..603f2903 100644 --- a/paddler_client_python/tests/test_client_management.py +++ b/paddler_client_python/tests/test_client_management.py @@ -90,7 +90,7 @@ async def test_get_balancer_desired_state_deserializes() -> None: response_data = { "chat_template_override": None, "inference_parameters": { - "batch_n_tokens": 512, + "n_batch": 512, "context_size": 8192, "enable_embeddings": False, "image_resize_to_fit": 1024, diff --git a/paddler_client_python/tests/test_inference_message.py b/paddler_client_python/tests/test_inference_message.py index 10ae28b3..4f053da1 100644 --- a/paddler_client_python/tests/test_inference_message.py +++ b/paddler_client_python/tests/test_inference_message.py @@ -6,27 +6,260 @@ ) -def test_parse_token_response() -> None: +def test_parse_content_token_response() -> None: data = { "Response": { "request_id": "req-1", - "response": {"GeneratedToken": {"Token": "hello"}}, + "response": {"GeneratedToken": {"ContentToken": "hello"}}, } } message = parse_inference_client_message(data) assert message.request_id == "req-1" - assert message.kind == InferenceMessageKind.TOKEN + assert message.kind == InferenceMessageKind.CONTENT_TOKEN assert message.token == "hello" assert message.is_token assert not message.is_terminal -def test_parse_done_response() -> None: +def test_parse_reasoning_token_response() -> None: data = { "Response": { "request_id": "req-1", - "response": {"GeneratedToken": "Done"}, + "response": {"GeneratedToken": {"ReasoningToken": "thinking"}}, + } + } + message = parse_inference_client_message(data) + + assert message.kind == InferenceMessageKind.REASONING_TOKEN + assert message.token == "thinking" + assert message.is_token + assert not message.is_terminal + + +def test_parse_tool_call_token_response() -> None: + data = { + "Response": { + "request_id": "req-1", + "response": {"GeneratedToken": {"ToolCallToken": '{"name":'}}, + } + } + message = parse_inference_client_message(data) + + assert message.kind == InferenceMessageKind.TOOL_CALL_TOKEN + assert message.token == '{"name":' + assert message.is_token + assert not message.is_terminal + + +def test_parse_tool_call_parsed_response_carries_structured_calls() -> None: + data = { + "Response": { + "request_id": "req-1", + "response": { + "GeneratedToken": { + "ToolCallParsed": [ + { + "id": "call_42", + "name": "get_weather", + "arguments": {"ValidJson": {"location": "Paris"}}, + }, + ], + }, + }, + }, + } + message = parse_inference_client_message(data) + + assert message.kind == InferenceMessageKind.TOOL_CALL_PARSED + assert message.parsed_tool_calls is not None + assert len(message.parsed_tool_calls) == 1 + assert message.parsed_tool_calls[0].id == "call_42" + assert message.parsed_tool_calls[0].name == "get_weather" + assert not message.is_token + + +def test_parse_tool_call_parse_failed_response_carries_error() -> None: + data = { + "Response": { + "request_id": "req-1", + "response": { + "GeneratedToken": {"ToolCallParseFailed": "syntax error at 12"}, + }, + }, + } + message = parse_inference_client_message(data) + + assert message.kind == InferenceMessageKind.TOOL_CALL_PARSE_FAILED + assert message.error_message == "syntax error at 12" + + +def 
test_parse_tool_call_validation_failed_response_joins_errors() -> None: + data = { + "Response": { + "request_id": "req-1", + "response": { + "GeneratedToken": { + "ToolCallValidationFailed": [ + "missing field 'location'", + "extra field 'foo'", + ], + }, + }, + }, + } + message = parse_inference_client_message(data) + + assert message.kind == InferenceMessageKind.TOOL_CALL_VALIDATION_FAILED + assert message.error_message == "missing field 'location'; extra field 'foo'" + + +def test_parse_tool_call_parsed_with_non_list_payload_raises() -> None: + data = { + "Response": { + "request_id": "req-1", + "response": {"GeneratedToken": {"ToolCallParsed": "not a list"}}, + }, + } + + with pytest.raises(TypeError, match="ToolCallParsed payload is not a list"): + parse_inference_client_message(data) + + +def test_parse_tool_call_validation_failed_with_non_list_payload_raises() -> None: + data = { + "Response": { + "request_id": "req-1", + "response": {"GeneratedToken": {"ToolCallValidationFailed": "oops"}}, + }, + } + + with pytest.raises( + TypeError, + match="ToolCallValidationFailed payload is not a list", + ): + parse_inference_client_message(data) + + +def test_parse_unrecognized_tool_call_format_response_carries_text_and_ffi_error() -> ( + None +): + data = { + "Response": { + "request_id": "req-1", + "response": { + "GeneratedToken": { + "UnrecognizedToolCallFormat": { + "text": "blah", + "ffi_error_message": "common_chat_parse failed: no parser", + }, + }, + }, + }, + } + message = parse_inference_client_message(data) + + assert message.kind == InferenceMessageKind.UNRECOGNIZED_TOOL_CALL_FORMAT + assert message.raw_tool_call_tokens is not None + assert message.raw_tool_call_tokens.text == "blah" + assert ( + message.raw_tool_call_tokens.ffi_error_message + == "common_chat_parse failed: no parser" + ) + assert not message.is_token + + +def test_parse_unrecognized_tool_call_format_with_non_dict_payload_raises() -> None: + data = { + "Response": { + "request_id": "req-1", + "response": { + "GeneratedToken": {"UnrecognizedToolCallFormat": "raw text only"}, + }, + }, + } + + with pytest.raises( + TypeError, + match="UnrecognizedToolCallFormat payload is not a dict", + ): + parse_inference_client_message(data) + + +def test_parse_image_exceeds_batch_size_response_carries_token_counts() -> None: + data = { + "Response": { + "request_id": "req-1", + "response": { + "GeneratedToken": { + "ImageExceedsBatchSize": { + "image_tokens": 368, + "n_batch": 100, + }, + }, + }, + }, + } + message = parse_inference_client_message(data) + + assert message.kind == InferenceMessageKind.IMAGE_EXCEEDS_BATCH_SIZE + assert message.oversized_image_details is not None + assert message.oversized_image_details.image_tokens == 368 + assert message.oversized_image_details.n_batch == 100 + assert not message.is_token + + +def test_parse_image_exceeds_batch_size_with_non_dict_payload_raises() -> None: + data = { + "Response": { + "request_id": "req-1", + "response": { + "GeneratedToken": {"ImageExceedsBatchSize": "scalar payload"}, + }, + }, + } + + with pytest.raises( + TypeError, + match="ImageExceedsBatchSize payload is not a dict", + ): + parse_inference_client_message(data) + + +def test_parse_undeterminable_token_response() -> None: + data = { + "Response": { + "request_id": "req-1", + "response": {"GeneratedToken": {"UndeterminableToken": "raw"}}, + } + } + message = parse_inference_client_message(data) + + assert message.kind == InferenceMessageKind.UNDETERMINABLE_TOKEN + assert message.token == "raw" + assert 
message.is_token + + +def test_parse_done_response_carries_summary() -> None: + data = { + "Response": { + "request_id": "req-1", + "response": { + "GeneratedToken": { + "Done": { + "usage": { + "prompt_tokens": 4, + "cached_prompt_tokens": 0, + "input_image_tokens": 0, + "input_audio_tokens": 0, + "content_tokens": 6, + "reasoning_tokens": 1, + "tool_call_tokens": 0, + "undeterminable_tokens": 0, + } + } + } + }, } } message = parse_inference_client_message(data) @@ -34,6 +267,12 @@ def test_parse_done_response() -> None: assert message.kind == InferenceMessageKind.DONE assert message.is_done assert message.is_terminal + assert message.summary is not None + assert message.summary.usage.prompt_tokens == 4 + assert message.summary.usage.content_tokens == 6 + assert message.summary.usage.reasoning_tokens == 1 + assert message.summary.usage.completion_tokens == 7 + assert message.summary.usage.total_tokens == 11 def test_parse_timeout() -> None: @@ -95,11 +334,7 @@ def test_parse_grammar_incompatible_with_thinking() -> None: data = { "Response": { "request_id": "req-1", - "response": { - "GeneratedToken": { - "GrammarIncompatibleWithThinking": "err" - } - }, + "response": {"GeneratedToken": {"GrammarIncompatibleWithThinking": "err"}}, } } message = parse_inference_client_message(data) @@ -114,9 +349,7 @@ def test_parse_grammar_initialization_failed() -> None: "Response": { "request_id": "req-1", "response": { - "GeneratedToken": { - "GrammarInitializationFailed": "null grammar" - } + "GeneratedToken": {"GrammarInitializationFailed": "null grammar"} }, } } @@ -132,9 +365,7 @@ def test_parse_grammar_rejected_model_output() -> None: "Response": { "request_id": "req-1", "response": { - "GeneratedToken": { - "GrammarRejectedModelOutput": "token rejected" - } + "GeneratedToken": {"GrammarRejectedModelOutput": "token rejected"} }, } } @@ -149,9 +380,7 @@ def test_parse_grammar_syntax_error() -> None: data = { "Response": { "request_id": "req-1", - "response": { - "GeneratedToken": {"GrammarSyntaxError": "invalid schema"} - }, + "response": {"GeneratedToken": {"GrammarSyntaxError": "invalid schema"}}, } } message = parse_inference_client_message(data) @@ -161,6 +390,28 @@ def test_parse_grammar_syntax_error() -> None: assert message.is_terminal +def test_parse_tool_call_validator_build_failed_response_carries_error() -> None: + data = { + "Response": { + "request_id": "req-1", + "response": { + "GeneratedToken": { + "ToolCallValidatorBuildFailed": ( + 'tool "get_weather" parameters are not a valid JSON Schema' + ), + }, + }, + }, + } + message = parse_inference_client_message(data) + + assert message.kind == InferenceMessageKind.TOOL_CALL_VALIDATOR_BUILD_FAILED + assert message.error_message == ( + 'tool "get_weather" parameters are not a valid JSON Schema' + ) + assert message.is_terminal + + def test_parse_image_decoding_failed() -> None: data = { "Response": { @@ -255,9 +506,41 @@ def test_parse_embedding_error() -> None: assert message.is_terminal +def test_parse_embedding_no_embeddings_produced() -> None: + data = { + "Response": { + "request_id": "req-1", + "response": {"Embedding": "NoEmbeddingsProduced"}, + } + } + message = parse_inference_client_message(data) + + assert message.kind == InferenceMessageKind.EMBEDDING_NO_EMBEDDINGS_PRODUCED + assert message.is_terminal + + +def test_parse_embedding_rejected_due_to_active_token_generation() -> None: + data = { + "Response": { + "request_id": "req-1", + "response": {"Embedding": "EmbeddingRejectedDueToActiveTokenGeneration"}, + } + } + message 
= parse_inference_client_message(data) + + assert ( + message.kind + == InferenceMessageKind.EMBEDDING_REJECTED_DUE_TO_ACTIVE_TOKEN_GENERATION + ) + assert message.is_terminal + + def test_parse_json_string() -> None: json_str = ( - '{"Response": {"request_id": "req-1", "response": {"GeneratedToken": "Done"}}}' + '{"Response": {"request_id": "req-1", "response": {"GeneratedToken": ' + '{"Done": {"usage": {"prompt_tokens": 0, "cached_prompt_tokens": 0, ' + '"input_image_tokens": 0, "input_audio_tokens": 0, "content_tokens": 0, ' + '"reasoning_tokens": 0, "tool_call_tokens": 0, "undeterminable_tokens": 0}}}}}}' ) message = parse_inference_client_message(json_str) @@ -310,6 +593,18 @@ def test_parse_unknown_generated_token_result_raises() -> None: parse_inference_client_message(data) +def test_parse_string_generated_token_result_raises() -> None: + data = { + "Response": { + "request_id": "req-1", + "response": {"GeneratedToken": "Done"}, + } + } + + with pytest.raises(TypeError, match="Unknown GeneratedTokenResult"): + parse_inference_client_message(data) + + def test_parse_unknown_embedding_result_raises() -> None: data = { "Response": { diff --git a/paddler_client_python/tests/test_integration_inference.py b/paddler_client_python/tests/test_integration_inference.py index 3293f083..ccb03a69 100644 --- a/paddler_client_python/tests/test_integration_inference.py +++ b/paddler_client_python/tests/test_integration_inference.py @@ -60,7 +60,7 @@ async def test_http_continue_from_conversation_history( ): _assert_not_error(message) - if message.kind == InferenceMessageKind.TOKEN: + if message.kind == InferenceMessageKind.CONTENT_TOKEN: assert message.token is not None tokens.append(message.token) elif message.is_terminal: @@ -89,7 +89,7 @@ async def test_websocket_continue_from_conversation_history( async for message in stream: _assert_not_error(message) - if message.kind == InferenceMessageKind.TOKEN: + if message.kind == InferenceMessageKind.CONTENT_TOKEN: assert message.token is not None tokens.append(message.token) elif message.is_terminal: @@ -114,7 +114,7 @@ async def test_websocket_continue_from_raw_prompt( async for message in stream: _assert_not_error(message) - if message.kind == InferenceMessageKind.TOKEN: + if message.kind == InferenceMessageKind.CONTENT_TOKEN: assert message.token is not None tokens.append(message.token) elif message.is_terminal: diff --git a/paddler_client_python/tests/test_parsed_tool_call.py b/paddler_client_python/tests/test_parsed_tool_call.py new file mode 100644 index 00000000..e2dd0aba --- /dev/null +++ b/paddler_client_python/tests/test_parsed_tool_call.py @@ -0,0 +1,41 @@ +import pytest + +from paddler_client.parsed_tool_call import ParsedToolCall +from paddler_client.tool_call_arguments import InvalidJson, ValidJson + + +def test_from_dict_with_valid_json_arguments() -> None: + parsed = ParsedToolCall.from_dict( + { + "id": "call_42", + "name": "get_weather", + "arguments": {"ValidJson": {"location": "Paris"}}, + }, + ) + + assert parsed.id == "call_42" + assert parsed.name == "get_weather" + assert parsed.arguments == ValidJson({"location": "Paris"}) + + +def test_from_dict_with_invalid_json_arguments() -> None: + parsed = ParsedToolCall.from_dict( + { + "id": "call_99", + "name": "freeform", + "arguments": {"InvalidJson": "{half a json"}, + }, + ) + + assert parsed.arguments == InvalidJson("{half a json") + + +def test_from_dict_with_non_dict_arguments_raises() -> None: + with pytest.raises(TypeError, match="arguments field must be a dict"): + 
ParsedToolCall.from_dict( + { + "id": "x", + "name": "y", + "arguments": "not a dict", + }, + ) diff --git a/paddler_client_python/tests/test_response_stream.py b/paddler_client_python/tests/test_response_stream.py index ff0d7943..eb3fbbc2 100644 --- a/paddler_client_python/tests/test_response_stream.py +++ b/paddler_client_python/tests/test_response_stream.py @@ -11,7 +11,7 @@ def token_message() -> InferenceMessage: return InferenceMessage( request_id="req-1", - kind=InferenceMessageKind.TOKEN, + kind=InferenceMessageKind.CONTENT_TOKEN, token="hello", ) diff --git a/paddler_client_python/tests/test_stream_ndjson.py b/paddler_client_python/tests/test_stream_ndjson.py index f985c590..0d9fa036 100644 --- a/paddler_client_python/tests/test_stream_ndjson.py +++ b/paddler_client_python/tests/test_stream_ndjson.py @@ -13,7 +13,7 @@ def _make_token_line(request_id: str, token: str) -> str: { "Response": { "request_id": request_id, - "response": {"GeneratedToken": {"Token": token}}, + "response": {"GeneratedToken": {"ContentToken": token}}, } } ) @@ -24,7 +24,22 @@ def _make_done_line(request_id: str) -> str: { "Response": { "request_id": request_id, - "response": {"GeneratedToken": "Done"}, + "response": { + "GeneratedToken": { + "Done": { + "usage": { + "prompt_tokens": 0, + "cached_prompt_tokens": 0, + "input_image_tokens": 0, + "input_audio_tokens": 0, + "content_tokens": 0, + "reasoning_tokens": 0, + "tool_call_tokens": 0, + "undeterminable_tokens": 0, + } + } + } + }, } } ) @@ -48,7 +63,7 @@ def handler(request: httpx.Request) -> httpx.Response: messages.append(message) assert len(messages) == 2 - assert messages[0].kind == InferenceMessageKind.TOKEN + assert messages[0].kind == InferenceMessageKind.CONTENT_TOKEN assert messages[0].token == "hello" assert messages[1].kind == InferenceMessageKind.DONE diff --git a/paddler_client_python/tests/test_tool.py b/paddler_client_python/tests/test_tool.py index 8c53127d..bf96cfc2 100644 --- a/paddler_client_python/tests/test_tool.py +++ b/paddler_client_python/tests/test_tool.py @@ -27,11 +27,13 @@ def test_tool_with_parameters_serialization() -> None: function=Function( name="get_weather", description="Get weather", - parameters=ValidatedParametersSchema.model_validate({ - "type": "object", - "properties": {"location": {"type": "string"}}, - "required": ["location"], - }), + parameters=ValidatedParametersSchema.model_validate( + { + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + } + ), ) ) dumped = tool.model_dump( diff --git a/paddler_client_python/tests/test_tool_call_arguments.py b/paddler_client_python/tests/test_tool_call_arguments.py new file mode 100644 index 00000000..b9df15ed --- /dev/null +++ b/paddler_client_python/tests/test_tool_call_arguments.py @@ -0,0 +1,30 @@ +import pytest + +from paddler_client.tool_call_arguments import ( + InvalidJson, + ValidJson, + parse_tool_call_arguments, +) + + +def test_parse_valid_json_with_object() -> None: + result = parse_tool_call_arguments({"ValidJson": {"location": "Paris"}}) + + assert result == ValidJson({"location": "Paris"}) + + +def test_parse_valid_json_with_array() -> None: + result = parse_tool_call_arguments({"ValidJson": [1, 2, 3]}) + + assert result == ValidJson([1, 2, 3]) + + +def test_parse_invalid_json_carries_raw_text() -> None: + result = parse_tool_call_arguments({"InvalidJson": "{half a json"}) + + assert result == InvalidJson("{half a json") + + +def test_parse_unknown_shape_raises() -> None: + with pytest.raises(ValueError, 
match="Unknown ToolCallArguments shape"): + parse_tool_call_arguments({"SomethingElse": "x"}) diff --git a/paddler_gui/src/app.rs b/paddler_gui/src/app.rs index a26c793d..7bfa8024 100644 --- a/paddler_gui/src/app.rs +++ b/paddler_gui/src/app.rs @@ -34,7 +34,7 @@ use paddler_bootstrap::agent_runner::AgentRunner; use paddler_bootstrap::agent_runner::AgentRunnerParams; use paddler_bootstrap::balancer_runner::BalancerRunner; use paddler_bootstrap::balancer_runner::BalancerRunnerParams; -use paddler_bootstrap::shutdown_signal::wait_for_shutdown_signal; +use paddler_bootstrap::shutdown_signal::register_shutdown_signals; use paddler_types::balancer_desired_state::BalancerDesiredState; use tokio::sync::broadcast; use tokio_util::sync::CancellationToken; @@ -64,13 +64,24 @@ static BETA_IMAGE: LazyLock = LazyLock::new(|| { fn shutdown_signal_stream() -> impl iced::futures::Stream { iced::stream::channel(1, async move |mut output| { - if let Err(error) = wait_for_shutdown_signal().await { + let shutdown_signals = match register_shutdown_signals() { + Ok(shutdown_signals) => shutdown_signals, + Err(error) => { + log::error!("failed to register shutdown signal handlers: {error}"); + + return; + } + }; + + if let Err(error) = shutdown_signals.wait().await { log::error!("shutdown signal listener failed: {error}"); return; } - let _ = output.send(Message::Quit).await; + if let Err(err) = output.send(Message::Quit).await { + log::warn!("Failed to deliver Quit message to iced runtime (receiver dropped): {err}"); + } }) } @@ -402,11 +413,26 @@ impl App { } } result = &mut completion_future => { - let message = match result { - Ok(()) => Message::AgentStopped, - Err(error) => Message::AgentFailed(error.to_string()), - }; - let _ = output.send(message).await; + match result { + Ok(()) => { + if let Err(err) = output.send(Message::AgentStopped).await { + log::warn!( + "Failed to deliver AgentStopped to UI (receiver dropped): {err}" + ); + } + } + Err(error) => { + let detail = error.to_string(); + if let Err(err) = output + .send(Message::AgentFailed(detail.clone())) + .await + { + log::error!( + "Failed to deliver AgentFailed to UI (receiver dropped); lost detail: {detail}; send err: {err}" + ); + } + } + } return; } @@ -488,9 +514,12 @@ impl App { let mut runner = match BalancerRunner::start(params).await { Ok(runner) => runner, Err(error) => { - let _ = output - .send(Message::BalancerFailed(error.to_string())) - .await; + let detail = error.to_string(); + if let Err(err) = output.send(Message::BalancerFailed(detail.clone())).await { + log::error!( + "Failed to deliver BalancerFailed to UI (receiver dropped); lost detail: {detail}; send err: {err}" + ); + } return; } @@ -568,11 +597,26 @@ impl App { } } result = &mut completion_future => { - let message = match result { - Ok(()) => Message::BalancerStopped, - Err(error) => Message::BalancerFailed(error.to_string()), - }; - let _ = output.send(message).await; + match result { + Ok(()) => { + if let Err(err) = output.send(Message::BalancerStopped).await { + log::warn!( + "Failed to deliver BalancerStopped to UI (receiver dropped): {err}" + ); + } + } + Err(error) => { + let detail = error.to_string(); + if let Err(err) = output + .send(Message::BalancerFailed(detail.clone())) + .await + { + log::error!( + "Failed to deliver BalancerFailed to UI (receiver dropped); lost detail: {detail}; send err: {err}" + ); + } + } + } return; } diff --git a/paddler_gui/src/running_balancer_snapshot.rs b/paddler_gui/src/running_balancer_snapshot.rs index 5782ff99..dcf57484 
100644 --- a/paddler_gui/src/running_balancer_snapshot.rs +++ b/paddler_gui/src/running_balancer_snapshot.rs @@ -53,7 +53,6 @@ mod tests { use std::sync::atomic::AtomicUsize; use anyhow::Result; - use paddler::agent_desired_state::AgentDesiredState; use paddler::atomic_value::AtomicValue; use paddler::balancer::agent_controller::AgentController; use paddler::balancer::agent_controller_pool::AgentControllerPool; @@ -64,6 +63,7 @@ mod tests { use paddler::balancer_applicable_state::BalancerApplicableState; use paddler::balancer_applicable_state_holder::BalancerApplicableStateHolder; use paddler_types::agent_desired_model::AgentDesiredModel; + use paddler_types::agent_desired_state::AgentDesiredState; use paddler_types::agent_state_application_status::AgentStateApplicationStatus; use paddler_types::inference_parameters::InferenceParameters; use tokio::sync::mpsc; diff --git a/paddler_gui/src/start_balancer_form_handler.rs b/paddler_gui/src/start_balancer_form_handler.rs index 051faaf6..f3254518 100644 --- a/paddler_gui/src/start_balancer_form_handler.rs +++ b/paddler_gui/src/start_balancer_form_handler.rs @@ -41,6 +41,10 @@ fn validate_required_address(raw: &str) -> Result<String, String> { validate_optional_address(raw)?.ok_or_else(|| "Address is required.".to_owned()) } +#[expect( + clippy::large_enum_variant, + reason = "ephemeral value, immediately consumed" +)] #[derive(Debug, Clone)] pub enum Message { SetBalancerAddress(String), diff --git a/paddler_tests/Cargo.toml b/paddler_tests/Cargo.toml index df61f89e..cd927f8f 100644 --- a/paddler_tests/Cargo.toml +++ b/paddler_tests/Cargo.toml @@ -21,6 +21,8 @@ anyhow = { workspace = true } async-stream = { workspace = true } base64 = { workspace = true } futures-util = { workspace = true } +hf-hub = { workspace = true } +llama-cpp-bindings = { workspace = true } log = { workspace = true } nix = { workspace = true } paddler = { workspace = true } diff --git a/paddler_tests/src/agent_config.rs b/paddler_tests/src/agent_config.rs new file mode 100644 index 00000000..bc3cb9f9 --- /dev/null +++ b/paddler_tests/src/agent_config.rs @@ -0,0 +1,25 @@ +#[derive(Clone, Debug)] +pub struct AgentConfig { + pub name: String, + pub slot_count: i32, +} + +impl AgentConfig { + #[must_use] + pub fn single(slot_count: i32) -> Self { + Self { + name: "test-agent".to_owned(), + slot_count, + } + } + + #[must_use] + pub fn uniform(count: usize, slot_count: i32) -> Vec<Self> { + (0..count) + .map(|agent_index| Self { + name: format!("test-agent-{agent_index}"), + slot_count, + }) + .collect() + } +} diff --git a/paddler_tests/src/agents_stream_watcher.rs b/paddler_tests/src/agents_stream_watcher.rs index 5b79dcdc..0f93aba9 100644 --- a/paddler_tests/src/agents_stream_watcher.rs +++ b/paddler_tests/src/agents_stream_watcher.rs @@ -55,20 +55,98 @@ impl AgentsStreamWatcher { )) } - pub async fn wait_for_slots_ready( + pub async fn until_agent<TPredicate>( &mut self, - expected_agent_count: usize, - slots_per_agent: i32, - ) -> Result<()> { + agent_id: &str, + mut predicate: TPredicate, + ) -> Result<AgentControllerPoolSnapshot> + where + TPredicate: FnMut(&AgentControllerPoolSnapshot) -> bool, + { + while let Some(item) = self.stream.next().await { + let snapshot = item.context("agents stream yielded an error")?; + + let agent_present = snapshot + .agents + .iter() + .any(|registered_agent| registered_agent.id == agent_id); + + if !agent_present { + bail!( + "agent {agent_id} disappeared from the balancer's agent pool before the predicate was satisfied; this means the agent subprocess died or its WebSocket dropped" + ); + } + + if predicate(&snapshot) { + return Ok(snapshot); + } + } + + Err(anyhow!( + "agents stream closed before predicate was satisfied" + )) + } + + pub async fn wait_for_agent_ready( + &mut self, + agent_name: &str, + expected_slot_count: i32, + ) -> Result<AgentControllerPoolSnapshot> { + let predicate_name = agent_name.to_owned(); + let snapshot = self + .until(move |snapshot| { + snapshot.agents.iter().any(|registered_agent| { + registered_agent.name.as_deref() == Some(predicate_name.as_str()) + && (registered_agent.slots_total == expected_slot_count + || !registered_agent.issues.is_empty()) + }) + }) + .await + .with_context(|| format!("agent {agent_name:?} did not reach slot readiness"))?; + + let agent_with_issues = snapshot.agents.iter().find(|registered_agent| { + registered_agent.name.as_deref() == Some(agent_name) + && !registered_agent.issues.is_empty() + }); + + if let Some(failing_agent) = agent_with_issues { + bail!( + "agent {agent_name:?} reported issues during startup: {issues:?}", + issues = failing_agent.issues, + ); + } + + Ok(snapshot) + } + + pub async fn wait_for_slots_ready(&mut self, expected_slot_counts: &[i32]) -> Result<()> { + let mut expected_sorted: Vec<i32> = expected_slot_counts.to_vec(); + expected_sorted.sort_unstable(); + let expected_agent_count = expected_sorted.len(); + let snapshot = self .until(move |snapshot| { - snapshot.agents.len() >= expected_agent_count - && snapshot.agents.iter().all(|agent| { - agent.slots_total >= slots_per_agent || !agent.issues.is_empty() - }) + if snapshot.agents.len() < expected_agent_count { + return false; + } + + let any_with_issues = snapshot.agents.iter().any(|agent| !agent.issues.is_empty()); + + if any_with_issues { + return true; + } + + let mut observed_slot_counts: Vec<i32> = snapshot + .agents + .iter() + .map(|agent| agent.slots_total) + .collect(); + observed_slot_counts.sort_unstable(); + + observed_slot_counts == expected_sorted }) .await - .context("agents did not reach the requested slot count")?; + .context("agents did not reach the requested slot counts")?; let agents_with_issues: Vec = snapshot .agents @@ -87,3 +165,233 @@ impl AgentsStreamWatcher { Ok(()) } } + +#[cfg(test)] +mod tests { + use std::collections::BTreeSet; + + use futures_util::stream; + use paddler_types::agent_controller_snapshot::AgentControllerSnapshot; + use paddler_types::agent_issue::AgentIssue; + use paddler_types::agent_issue_params::ModelPath; + use paddler_types::agent_state_application_status::AgentStateApplicationStatus; + + use super::*; + + fn snapshot_with_agent( + agent_id: &str, + issues: BTreeSet<AgentIssue>, + ) -> AgentControllerSnapshot { + snapshot_with_agent_and_slots(agent_id, issues, 0) + } + + fn snapshot_with_agent_and_slots( + agent_id: &str, + issues: BTreeSet<AgentIssue>, + slots_total: i32, + ) -> AgentControllerSnapshot { + AgentControllerSnapshot { + desired_slots_total: 1, + download_current: 0, + download_filename: None, + download_total: 0, + id: agent_id.to_owned(), + issues, + model_path: None, + name: Some(agent_id.to_owned()), + slots_processing: 0, + slots_total, + state_application_status: AgentStateApplicationStatus::Fresh, + uses_chat_template_override: false, + } + } + + fn unable_to_find_chat_template_issue() -> BTreeSet<AgentIssue> { + let mut issues = BTreeSet::new(); + issues.insert(AgentIssue::UnableToFindChatTemplate(ModelPath { + model_path: "/models/embed.gguf".to_owned(), + })); + issues + } + + fn make_watcher(snapshots: Vec<AgentControllerPoolSnapshot>) -> AgentsStreamWatcher { + AgentsStreamWatcher::from_stream(Box::pin(stream::iter(snapshots.into_iter().map(Ok)))) + } + + #[tokio::test]
+ async fn until_agent_returns_ok_when_predicate_matches_with_agent_present() -> Result<()> { + let agent_id = "agent-x"; + let snapshots = vec![ + AgentControllerPoolSnapshot { + agents: vec![snapshot_with_agent(agent_id, BTreeSet::new())], + }, + AgentControllerPoolSnapshot { + agents: vec![snapshot_with_agent( + agent_id, + unable_to_find_chat_template_issue(), + )], + }, + ]; + + let mut watcher = make_watcher(snapshots); + + let predicate_agent_id = agent_id.to_owned(); + let observed = watcher + .until_agent(agent_id, move |snapshot| { + snapshot.agents.iter().any(|agent| { + agent.id == predicate_agent_id + && agent + .issues + .iter() + .any(|issue| matches!(issue, AgentIssue::UnableToFindChatTemplate(_))) + }) + }) + .await?; + + assert!( + observed.agents.iter().any(|agent| agent.id == agent_id), + "matched snapshot must contain the watched agent" + ); + + Ok(()) + } + + #[tokio::test] + async fn until_agent_errors_when_agent_disappears_mid_stream() -> Result<()> { + let agent_id = "agent-y"; + let snapshots = vec![ + AgentControllerPoolSnapshot { + agents: vec![snapshot_with_agent(agent_id, BTreeSet::new())], + }, + AgentControllerPoolSnapshot { agents: vec![] }, + ]; + + let mut watcher = make_watcher(snapshots); + + let error = watcher + .until_agent(agent_id, |_snapshot| false) + .await + .err() + .context("until_agent must surface the disappearance as an error")?; + let rendered = format!("{error:#}"); + + assert!( + rendered.contains("disappeared"), + "error must explicitly call out the disappearance, got: {rendered}" + ); + assert!( + rendered.contains(agent_id), + "error must name the missing agent, got: {rendered}" + ); + + Ok(()) + } + + #[tokio::test] + async fn until_agent_errors_when_stream_closes_before_predicate_matches() -> Result<()> { + let agent_id = "agent-z"; + let snapshots = vec![AgentControllerPoolSnapshot { + agents: vec![snapshot_with_agent(agent_id, BTreeSet::new())], + }]; + + let mut watcher = make_watcher(snapshots); + + let error = watcher + .until_agent(agent_id, |_snapshot| false) + .await + .err() + .context("until_agent must error when the stream ends without a match")?; + let rendered = format!("{error:#}"); + + assert!( + rendered.contains("stream closed"), + "error must surface the stream-closed condition, got: {rendered}" + ); + + Ok(()) + } + + #[tokio::test] + async fn wait_for_agent_ready_returns_snapshot_when_named_agent_reaches_slot_count() + -> Result<()> { + let agent_id = "agent-warm-0"; + let snapshots = vec![ + AgentControllerPoolSnapshot { + agents: vec![snapshot_with_agent_and_slots(agent_id, BTreeSet::new(), 0)], + }, + AgentControllerPoolSnapshot { + agents: vec![snapshot_with_agent_and_slots(agent_id, BTreeSet::new(), 2)], + }, + ]; + + let mut watcher = make_watcher(snapshots); + + let snapshot = watcher.wait_for_agent_ready(agent_id, 2).await?; + + assert!( + snapshot + .agents + .iter() + .any(|agent| { agent.name.as_deref() == Some(agent_id) && agent.slots_total == 2 }), + "returned snapshot must contain the named agent at its target slot count" + ); + + Ok(()) + } + + #[tokio::test] + async fn wait_for_agent_ready_errors_when_named_agent_reports_issues() -> Result<()> { + let agent_id = "agent-warm-1"; + let snapshots = vec![AgentControllerPoolSnapshot { + agents: vec![snapshot_with_agent_and_slots( + agent_id, + unable_to_find_chat_template_issue(), + 0, + )], + }]; + + let mut watcher = make_watcher(snapshots); + + let error = watcher + .wait_for_agent_ready(agent_id, 2) + .await + .err() + 
.context("wait_for_agent_ready must surface agent-side issues as an error")?; + let rendered = format!("{error:#}"); + + assert!( + rendered.contains(agent_id), + "error must name the failing agent, got: {rendered}" + ); + assert!( + rendered.contains("issues"), + "error must mention that issues were registered, got: {rendered}" + ); + + Ok(()) + } + + #[tokio::test] + async fn wait_for_agent_ready_errors_when_stream_closes_before_match() -> Result<()> { + let agent_id = "agent-warm-2"; + let snapshots = vec![AgentControllerPoolSnapshot { + agents: vec![snapshot_with_agent_and_slots(agent_id, BTreeSet::new(), 0)], + }]; + + let mut watcher = make_watcher(snapshots); + + let error = watcher + .wait_for_agent_ready(agent_id, 2) + .await + .err() + .context("wait_for_agent_ready must error when the stream ends without a match")?; + let rendered = format!("{error:#}"); + + assert!( + rendered.contains("slot readiness"), + "error must mention that slot readiness was not reached, got: {rendered}" + ); + + Ok(()) + } +} diff --git a/paddler_tests/src/cluster_handle.rs b/paddler_tests/src/cluster_handle.rs index 52bf8ccf..7374fadb 100644 --- a/paddler_tests/src/cluster_handle.rs +++ b/paddler_tests/src/cluster_handle.rs @@ -1,4 +1,5 @@ use anyhow::Result; +use log::warn; use paddler_client::PaddlerClient; use tokio_util::sync::CancellationToken; @@ -43,37 +44,25 @@ impl ClusterHandle { } } - pub async fn shutdown(self) -> Result<()> { - let Self { - cancel_token, - completion, - .. - } = self; - - cancel_token.cancel(); + pub async fn shutdown(mut self) -> Result<()> { + self.cancel_token.cancel(); - match completion { - ClusterCompletion::InProcess { - mut agents, - mut balancer, - } => { - for agent_runner in &mut agents { + match &mut self.completion { + ClusterCompletion::InProcess { agents, balancer } => { + for agent_runner in agents.iter_mut() { agent_runner.wait_for_completion().await?; } balancer.wait_for_completion().await?; } - ClusterCompletion::Subprocess { - mut agents, - mut balancer, - } => { - for child in &mut agents { + ClusterCompletion::Subprocess { agents, balancer } => { + for child in agents.iter_mut() { terminate_child(child)?; } - terminate_child(&mut balancer)?; + terminate_child(balancer)?; - for agent in &mut agents { + for agent in agents.iter_mut() { agent.wait().await?; } @@ -84,3 +73,20 @@ impl ClusterHandle { Ok(()) } } + +impl Drop for ClusterHandle { + fn drop(&mut self) { + self.cancel_token.cancel(); + + if let ClusterCompletion::Subprocess { agents, balancer } = &mut self.completion { + for child in agents.iter_mut() { + if let Err(error) = terminate_child(child) { + warn!("ClusterHandle drop: failed to terminate agent subprocess: {error:#}"); + } + } + if let Err(error) = terminate_child(balancer) { + warn!("ClusterHandle drop: failed to terminate balancer subprocess: {error:#}"); + } + } + } +} diff --git a/paddler_tests/src/collect_embedding_results.rs b/paddler_tests/src/collect_embedding_results.rs index 08bbc0fe..0b846d05 100644 --- a/paddler_tests/src/collect_embedding_results.rs +++ b/paddler_tests/src/collect_embedding_results.rs @@ -2,64 +2,93 @@ use anyhow::Context as _; use anyhow::Result; use anyhow::anyhow; use futures_util::StreamExt as _; -use paddler_types::embedding::Embedding; use paddler_types::embedding_result::EmbeddingResult; use paddler_types::inference_client::Message as InferenceMessage; use paddler_types::inference_client::Response as InferenceResponse; use crate::collected_embedding_results::CollectedEmbeddingResults; +use 
crate::embedding_with_producer::EmbeddingWithProducer; use crate::inference_message_stream::InferenceMessageStream; pub async fn collect_embedding_results( mut stream: InferenceMessageStream, ) -> Result<CollectedEmbeddingResults> { - let mut embeddings: Vec<Embedding> = Vec::new(); + let mut embeddings: Vec<EmbeddingWithProducer> = Vec::new(); + let mut embeddings_disabled = false; let mut errors: Vec<String> = Vec::new(); + let mut embedding_rejected_due_to_active_token_generation_count: usize = 0; + let mut no_embeddings_produced_count: usize = 0; + let mut oversized_documents = Vec::new(); let mut saw_done = false; + let mut wire_errors = Vec::new(); while let Some(item) = stream.next().await { let message = item.context("embedding stream yielded an error")?; match message { - InferenceMessage::Response(envelope) => match envelope.response { - InferenceResponse::Embedding(EmbeddingResult::Done) => { - saw_done = true; + InferenceMessage::Response(envelope) => { + let generated_by = envelope.generated_by.clone(); - break; - } - InferenceResponse::Embedding(EmbeddingResult::Embedding(embedding)) => { - embeddings.push(embedding); - } - InferenceResponse::Embedding(EmbeddingResult::Error(message)) => { - errors.push(message); - } - InferenceResponse::GeneratedToken(_) => { - return Err(anyhow!( - "unexpected generated-token response on an embedding stream" - )); - } - InferenceResponse::Timeout => { - return Err(anyhow!("embedding request timed out on balancer")); - } - InferenceResponse::TooManyBufferedRequests => { - return Err(anyhow!( - "balancer rejected embedding request: too many buffered" - )); + match envelope.response { + InferenceResponse::Embedding(EmbeddingResult::Done) => { + saw_done = true; + + break; + } + InferenceResponse::Embedding(EmbeddingResult::Embedding(embedding)) => { + embeddings.push(EmbeddingWithProducer { + embedding, + generated_by, + }); + } + InferenceResponse::Embedding(EmbeddingResult::DocumentExceedsBatchSize( + details, + )) => { + oversized_documents.push(details); + } + InferenceResponse::Embedding(EmbeddingResult::EmbeddingsDisabled) => { + embeddings_disabled = true; + } + InferenceResponse::Embedding(EmbeddingResult::Error(message)) => { + errors.push(message); + } + InferenceResponse::Embedding( + EmbeddingResult::EmbeddingRejectedDueToActiveTokenGeneration, + ) => { + embedding_rejected_due_to_active_token_generation_count += 1; + } + InferenceResponse::Embedding(EmbeddingResult::NoEmbeddingsProduced) => { + no_embeddings_produced_count += 1; + } + InferenceResponse::GeneratedToken(_) => { + return Err(anyhow!( + "unexpected generated-token response on an embedding stream" + )); + } + InferenceResponse::Timeout => { + return Err(anyhow!("embedding request timed out on balancer")); + } + InferenceResponse::TooManyBufferedRequests => { + return Err(anyhow!( + "balancer rejected embedding request: too many buffered" + )); + } } - }, + } InferenceMessage::Error(error_envelope) => { - return Err(anyhow!( - "embedding stream returned JSON-RPC error code {} ({})", - error_envelope.error.code, - error_envelope.error.description - )); + wire_errors.push(error_envelope.error); } } } Ok(CollectedEmbeddingResults { embeddings, + embeddings_disabled, errors, + embedding_rejected_due_to_active_token_generation_count, + no_embeddings_produced_count, + oversized_documents, saw_done, + wire_errors, }) } diff --git a/paddler_tests/src/collect_generated_tokens.rs b/paddler_tests/src/collect_generated_tokens.rs index 57eba62a..baddceea 100644 --- a/paddler_tests/src/collect_generated_tokens.rs +++
b/paddler_tests/src/collect_generated_tokens.rs @@ -2,50 +2,57 @@ use anyhow::Context as _; use anyhow::Result; use anyhow::anyhow; use futures_util::StreamExt as _; -use paddler_types::generated_token_result::GeneratedTokenResult; use paddler_types::inference_client::Message as InferenceMessage; use paddler_types::inference_client::Response as InferenceResponse; use paddler_types::streamable_result::StreamableResult as _; use crate::collected_generated_tokens::CollectedGeneratedTokens; use crate::inference_message_stream::InferenceMessageStream; +use crate::token_result_with_producer::TokenResultWithProducer; pub async fn collect_generated_tokens( mut stream: InferenceMessageStream, ) -> Result<CollectedGeneratedTokens> { let mut text = String::new(); - let mut token_results: Vec<GeneratedTokenResult> = Vec::new(); + let mut token_results: Vec<TokenResultWithProducer> = Vec::new(); while let Some(item) = stream.next().await { let message = item.context("inference stream yielded an error")?; match message { - InferenceMessage::Response(envelope) => match envelope.response { - InferenceResponse::GeneratedToken(token_result) => { - if let GeneratedTokenResult::Token(token_text) = &token_result { - text.push_str(token_text); - } + InferenceMessage::Response(envelope) => { + let generated_by = envelope.generated_by.clone(); + + match envelope.response { + InferenceResponse::GeneratedToken(token_result) => { + if let Some(token_text) = token_result.token_text() { + text.push_str(token_text); + } - let is_done = token_result.is_done(); + let is_done = token_result.is_done(); - token_results.push(token_result); + token_results.push(TokenResultWithProducer { + token_result, + generated_by, + }); - if is_done { - break; + if is_done { + break; + } + } + InferenceResponse::Embedding(_) => { + return Err(anyhow!( + "unexpected embedding response on a token-generation stream" + )); + } + InferenceResponse::Timeout => { + return Err(anyhow!("inference request timed out on balancer")); + } + InferenceResponse::TooManyBufferedRequests => { + return Err(anyhow!("balancer rejected request: too many buffered")); } } - InferenceResponse::Embedding(_) => { - return Err(anyhow!( - "unexpected embedding response on a token-generation stream" - )); - } - InferenceResponse::Timeout => { - return Err(anyhow!("inference request timed out on balancer")); - } - InferenceResponse::TooManyBufferedRequests => { - return Err(anyhow!("balancer rejected request: too many buffered")); - } - }, + } InferenceMessage::Error(error_envelope) => { return Err(anyhow!( "inference stream returned JSON-RPC error code {} ({})", diff --git a/paddler_tests/src/collected_embedding_results.rs b/paddler_tests/src/collected_embedding_results.rs index 15aecdac..4d785d94 100644 --- a/paddler_tests/src/collected_embedding_results.rs +++ b/paddler_tests/src/collected_embedding_results.rs @@ -1,7 +1,15 @@ -use paddler_types::embedding::Embedding; +use paddler_types::jsonrpc::Error as JsonRpcError; +use paddler_types::oversized_embedding_document_details::OversizedEmbeddingDocumentDetails; + +use crate::embedding_with_producer::EmbeddingWithProducer; pub struct CollectedEmbeddingResults { - pub embeddings: Vec<Embedding>, + pub embeddings: Vec<EmbeddingWithProducer>, + pub embeddings_disabled: bool, pub errors: Vec<String>, + pub embedding_rejected_due_to_active_token_generation_count: usize, + pub no_embeddings_produced_count: usize, + pub oversized_documents: Vec<OversizedEmbeddingDocumentDetails>, pub saw_done: bool, + pub wire_errors: Vec<JsonRpcError>, } diff --git a/paddler_tests/src/collected_generated_tokens.rs b/paddler_tests/src/collected_generated_tokens.rs index d725fe64..59779702 100644 ---
a/paddler_tests/src/collected_generated_tokens.rs +++ b/paddler_tests/src/collected_generated_tokens.rs @@ -1,6 +1,6 @@ -use paddler_types::generated_token_result::GeneratedTokenResult; +use crate::token_result_with_producer::TokenResultWithProducer; pub struct CollectedGeneratedTokens { pub text: String, - pub token_results: Vec<GeneratedTokenResult>, + pub token_results: Vec<TokenResultWithProducer>, } diff --git a/paddler_tests/src/embedding_with_producer.rs b/paddler_tests/src/embedding_with_producer.rs new file mode 100644 index 00000000..cc3bf3d6 --- /dev/null +++ b/paddler_tests/src/embedding_with_producer.rs @@ -0,0 +1,7 @@ +use paddler_types::embedding::Embedding; + +#[derive(Debug)] +pub struct EmbeddingWithProducer { + pub embedding: Embedding, + pub generated_by: Option<String>, +} diff --git a/paddler_tests/src/in_process_cluster_params.rs b/paddler_tests/src/in_process_cluster_params.rs index 891982dd..2585fe62 100644 --- a/paddler_tests/src/in_process_cluster_params.rs +++ b/paddler_tests/src/in_process_cluster_params.rs @@ -2,31 +2,32 @@ use std::time::Duration; use paddler_types::balancer_desired_state::BalancerDesiredState; +use crate::agent_config::AgentConfig; + pub struct InProcessClusterParams { - pub agent_name: String, + pub agent: Option<AgentConfig>, pub buffered_request_timeout: Duration, pub desired_state: BalancerDesiredState, pub inference_cors_allowed_hosts: Vec<String>, pub inference_item_timeout: Duration, pub management_cors_allowed_hosts: Vec<String>, pub max_buffered_requests: i32, - pub slots_per_agent: i32, - pub spawn_agent: bool, pub wait_for_slots_ready: bool, } impl Default for InProcessClusterParams { fn default() -> Self { Self { - agent_name: "test-agent".to_owned(), + agent: Some(AgentConfig { + name: "test-agent".to_owned(), + slot_count: 4, + }), buffered_request_timeout: Duration::from_secs(10), desired_state: BalancerDesiredState::default(), inference_cors_allowed_hosts: Vec::new(), inference_item_timeout: Duration::from_secs(30), management_cors_allowed_hosts: Vec::new(), max_buffered_requests: 10, - slots_per_agent: 4, - spawn_agent: true, wait_for_slots_ready: true, } } diff --git a/paddler_tests/src/inference_http_client.rs b/paddler_tests/src/inference_http_client.rs index f39376e5..a17436c6 100644 --- a/paddler_tests/src/inference_http_client.rs +++ b/paddler_tests/src/inference_http_client.rs @@ -4,15 +4,16 @@ use async_stream::try_stream; use futures_util::Stream; use futures_util::StreamExt as _; use paddler_types::inference_client::Message as InferenceMessage; -use paddler_types::request_params::ContinueFromConversationHistoryParams; use paddler_types::request_params::ContinueFromRawPromptParams; use paddler_types::request_params::GenerateEmbeddingBatchParams; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; use reqwest::Client; use url::Url; use crate::inference_message_stream::InferenceMessageStream; +#[derive(Clone)] pub struct InferenceHttpClient { http_client: Client, inference_base_url: Url, diff --git a/paddler_tests/src/lib.rs b/paddler_tests/src/lib.rs index e19b2533..6ce9c02a 100644 --- a/paddler_tests/src/lib.rs +++ b/paddler_tests/src/lib.rs @@ -1,3 +1,4 @@ +pub mod agent_config; pub mod agents_status; pub mod agents_stream_watcher; pub mod balancer_addresses; @@ -11,21 +12,33 @@ pub mod collect_generated_tokens; pub mod collected_embedding_results;
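// Each test helper lives in its own module, declared once in this list, so a
// test pulls in exactly the pieces it needs. A hypothetical consumer, sketch
// only, using paths and helpers that appear elsewhere in this patch:
//
//     use paddler_tests::agent_config::AgentConfig;
//     use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3;
//
//     let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?;
//     cluster.shutdown().await?;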
pub mod collected_generated_tokens; pub mod current_test_device; +pub mod embedding_with_producer; pub mod in_process_cluster_params; pub mod inference_http_client; pub mod inference_message_stream; pub mod load_test_image_data_uri; pub mod make_agent_controller_without_remote_agent; pub mod model_card; +pub mod openai_chat_completions_client; pub mod paddler_command; pub mod parse_test_device_value; +pub mod qwen3_embedding_cluster_params; pub mod spawn_agent_subprocess; pub mod spawn_agent_subprocess_params; pub mod start_in_process_cluster; +pub mod start_in_process_cluster_with_deepseek_r1_distill_llama_8b; +pub mod start_in_process_cluster_with_gemma_4; +pub mod start_in_process_cluster_with_gemma_4_and_mmproj; +pub mod start_in_process_cluster_with_glm_4_7_flash; +pub mod start_in_process_cluster_with_ministral_3; +pub mod start_in_process_cluster_with_ministral_3_and_mmproj; pub mod start_in_process_cluster_with_qwen2_5_vl; pub mod start_in_process_cluster_with_qwen3; pub mod start_in_process_cluster_with_qwen3_5; +pub mod start_in_process_cluster_with_qwen3_6; +pub mod start_in_process_cluster_with_qwen3_6_and_mmproj; pub mod start_in_process_cluster_with_smolvlm2; +pub mod start_in_process_cluster_with_smolvlm2_and_n_batch; pub mod start_in_process_embedding_cluster; pub mod start_subprocess_cluster; pub mod start_subprocess_cluster_with_qwen2_5_vl; @@ -36,4 +49,5 @@ pub mod state_database_file; pub mod subprocess_cluster_params; pub mod terminate_child; pub mod test_device; +pub mod token_result_with_producer; pub mod wait_until_healthy; diff --git a/paddler_tests/src/model_card/deepseek_r1_distill_llama_8b.rs b/paddler_tests/src/model_card/deepseek_r1_distill_llama_8b.rs new file mode 100644 index 00000000..993f116b --- /dev/null +++ b/paddler_tests/src/model_card/deepseek_r1_distill_llama_8b.rs @@ -0,0 +1,15 @@ +use paddler_types::huggingface_model_reference::HuggingFaceModelReference; + +use crate::model_card::ModelCard; + +#[must_use] +pub fn deepseek_r1_distill_llama_8b() -> ModelCard { + ModelCard { + gpu_layer_count: 999, + reference: HuggingFaceModelReference { + filename: "DeepSeek-R1-Distill-Llama-8B-Q4_K_M.gguf".to_owned(), + repo_id: "unsloth/DeepSeek-R1-Distill-Llama-8B-GGUF".to_owned(), + revision: "main".to_owned(), + }, + } +} diff --git a/paddler_tests/src/model_card/gemma_4_e4b_it.rs b/paddler_tests/src/model_card/gemma_4_e4b_it.rs new file mode 100644 index 00000000..2959b2cc --- /dev/null +++ b/paddler_tests/src/model_card/gemma_4_e4b_it.rs @@ -0,0 +1,15 @@ +use paddler_types::huggingface_model_reference::HuggingFaceModelReference; + +use crate::model_card::ModelCard; + +#[must_use] +pub fn gemma_4_e4b_it() -> ModelCard { + ModelCard { + gpu_layer_count: 999, + reference: HuggingFaceModelReference { + filename: "gemma-4-E4B-it-Q4_K_M.gguf".to_owned(), + repo_id: "unsloth/gemma-4-E4B-it-GGUF".to_owned(), + revision: "main".to_owned(), + }, + } +} diff --git a/paddler_tests/src/model_card/gemma_4_e4b_it_mmproj.rs b/paddler_tests/src/model_card/gemma_4_e4b_it_mmproj.rs new file mode 100644 index 00000000..083db911 --- /dev/null +++ b/paddler_tests/src/model_card/gemma_4_e4b_it_mmproj.rs @@ -0,0 +1,15 @@ +use paddler_types::huggingface_model_reference::HuggingFaceModelReference; + +use crate::model_card::ModelCard; + +#[must_use] +pub fn gemma_4_e4b_it_mmproj() -> ModelCard { + ModelCard { + gpu_layer_count: 999, + reference: HuggingFaceModelReference { + filename: "mmproj-F16.gguf".to_owned(), + repo_id: "unsloth/gemma-4-E4B-it-GGUF".to_owned(), + revision: 
"main".to_owned(), + }, + } +} diff --git a/paddler_tests/src/model_card/glm_4_7_flash.rs b/paddler_tests/src/model_card/glm_4_7_flash.rs new file mode 100644 index 00000000..5d5bba3e --- /dev/null +++ b/paddler_tests/src/model_card/glm_4_7_flash.rs @@ -0,0 +1,15 @@ +use paddler_types::huggingface_model_reference::HuggingFaceModelReference; + +use crate::model_card::ModelCard; + +#[must_use] +pub fn glm_4_7_flash() -> ModelCard { + ModelCard { + gpu_layer_count: 999, + reference: HuggingFaceModelReference { + filename: "GLM-4.7-Flash-Q4_K_M.gguf".to_owned(), + repo_id: "unsloth/GLM-4.7-Flash-GGUF".to_owned(), + revision: "main".to_owned(), + }, + } +} diff --git a/paddler_tests/src/model_card/ministral_3_14b_reasoning.rs b/paddler_tests/src/model_card/ministral_3_14b_reasoning.rs new file mode 100644 index 00000000..0718b7fa --- /dev/null +++ b/paddler_tests/src/model_card/ministral_3_14b_reasoning.rs @@ -0,0 +1,15 @@ +use paddler_types::huggingface_model_reference::HuggingFaceModelReference; + +use crate::model_card::ModelCard; + +#[must_use] +pub fn ministral_3_14b_reasoning() -> ModelCard { + ModelCard { + gpu_layer_count: 999, + reference: HuggingFaceModelReference { + filename: "Ministral-3-14B-Reasoning-2512-Q4_K_M.gguf".to_owned(), + repo_id: "unsloth/Ministral-3-14B-Reasoning-2512-GGUF".to_owned(), + revision: "main".to_owned(), + }, + } +} diff --git a/paddler_tests/src/model_card/ministral_3_14b_reasoning_mmproj.rs b/paddler_tests/src/model_card/ministral_3_14b_reasoning_mmproj.rs new file mode 100644 index 00000000..be0c5b76 --- /dev/null +++ b/paddler_tests/src/model_card/ministral_3_14b_reasoning_mmproj.rs @@ -0,0 +1,15 @@ +use paddler_types::huggingface_model_reference::HuggingFaceModelReference; + +use crate::model_card::ModelCard; + +#[must_use] +pub fn ministral_3_14b_reasoning_mmproj() -> ModelCard { + ModelCard { + gpu_layer_count: 999, + reference: HuggingFaceModelReference { + filename: "mmproj-F16.gguf".to_owned(), + repo_id: "unsloth/Ministral-3-14B-Reasoning-2512-GGUF".to_owned(), + revision: "main".to_owned(), + }, + } +} diff --git a/paddler_tests/src/model_card/mod.rs b/paddler_tests/src/model_card/mod.rs index d064e778..dedd03a3 100644 --- a/paddler_tests/src/model_card/mod.rs +++ b/paddler_tests/src/model_card/mod.rs @@ -1,15 +1,23 @@ -use paddler_types::huggingface_model_reference::HuggingFaceModelReference; - +pub mod deepseek_r1_distill_llama_8b; +pub mod gemma_4_e4b_it; +pub mod gemma_4_e4b_it_mmproj; +pub mod glm_4_7_flash; +pub mod ministral_3_14b_reasoning; +pub mod ministral_3_14b_reasoning_mmproj; pub mod nomic_embed_text_v1_5; pub mod qwen2_5_vl_3b; pub mod qwen2_5_vl_3b_mmproj; pub mod qwen3_0_6b; pub mod qwen3_5_0_8b; pub mod qwen3_5_0_8b_mmproj; +pub mod qwen3_6_35b_a3b; +pub mod qwen3_6_35b_a3b_mmproj; pub mod qwen3_embedding_0_6b; pub mod smolvlm2_256m; pub mod smolvlm2_256m_mmproj; +use paddler_types::huggingface_model_reference::HuggingFaceModelReference; + pub struct ModelCard { pub gpu_layer_count: u32, pub reference: HuggingFaceModelReference, diff --git a/paddler_tests/src/model_card/qwen3_6_35b_a3b.rs b/paddler_tests/src/model_card/qwen3_6_35b_a3b.rs new file mode 100644 index 00000000..c75a5f8a --- /dev/null +++ b/paddler_tests/src/model_card/qwen3_6_35b_a3b.rs @@ -0,0 +1,15 @@ +use paddler_types::huggingface_model_reference::HuggingFaceModelReference; + +use crate::model_card::ModelCard; + +#[must_use] +pub fn qwen3_6_35b_a3b() -> ModelCard { + ModelCard { + gpu_layer_count: 999, + reference: HuggingFaceModelReference { + filename: 
"Qwen3.6-35B-A3B-UD-Q4_K_M.gguf".to_owned(), + repo_id: "unsloth/Qwen3.6-35B-A3B-GGUF".to_owned(), + revision: "main".to_owned(), + }, + } +} diff --git a/paddler_tests/src/model_card/qwen3_6_35b_a3b_mmproj.rs b/paddler_tests/src/model_card/qwen3_6_35b_a3b_mmproj.rs new file mode 100644 index 00000000..5d6a5b55 --- /dev/null +++ b/paddler_tests/src/model_card/qwen3_6_35b_a3b_mmproj.rs @@ -0,0 +1,15 @@ +use paddler_types::huggingface_model_reference::HuggingFaceModelReference; + +use crate::model_card::ModelCard; + +#[must_use] +pub fn qwen3_6_35b_a3b_mmproj() -> ModelCard { + ModelCard { + gpu_layer_count: 999, + reference: HuggingFaceModelReference { + filename: "mmproj-F16.gguf".to_owned(), + repo_id: "unsloth/Qwen3.6-35B-A3B-GGUF".to_owned(), + revision: "main".to_owned(), + }, + } +} diff --git a/paddler_tests/src/openai_chat_completions_client.rs b/paddler_tests/src/openai_chat_completions_client.rs new file mode 100644 index 00000000..c0571661 --- /dev/null +++ b/paddler_tests/src/openai_chat_completions_client.rs @@ -0,0 +1,86 @@ +use anyhow::Context as _; +use anyhow::Result; +use futures_util::StreamExt as _; +use reqwest::Client; +use serde_json::Value; +use url::Url; + +pub struct OpenAIChatCompletionsClient { + http_client: Client, + completions_url: Url, +} + +impl OpenAIChatCompletionsClient { + pub fn new(http_client: Client, openai_base_url: &Url) -> Result { + Ok(Self { + http_client, + completions_url: openai_base_url + .join("v1/chat/completions") + .context("failed to build /v1/chat/completions URL")?, + }) + } + + pub async fn post_streaming(&self, body: &Value) -> Result> { + let response = self + .http_client + .post(self.completions_url.clone()) + .json(body) + .send() + .await + .context("failed to POST OpenAI streaming chat completion")? + .error_for_status() + .context("non-success status from OpenAI streaming endpoint")?; + + let mut bytes_stream = response.bytes_stream(); + let mut buffer: Vec = Vec::new(); + let mut chunks: Vec = Vec::new(); + + while let Some(chunk_result) = bytes_stream.next().await { + let chunk = chunk_result.context("failed to read OpenAI streaming chunk")?; + + buffer.extend_from_slice(&chunk); + + while let Some(newline_position) = buffer.iter().position(|byte| *byte == b'\n') { + let line_bytes: Vec = buffer.drain(..=newline_position).collect(); + let line_text = std::str::from_utf8(&line_bytes[..newline_position]) + .context("OpenAI stream produced non-UTF8 bytes")? + .trim(); + + if line_text.is_empty() { + continue; + } + + chunks.push(serde_json::from_str(line_text).with_context(|| { + format!("failed to parse OpenAI streaming chunk: {line_text}") + })?); + } + } + + let trailing_text = std::str::from_utf8(&buffer) + .context("OpenAI stream produced trailing non-UTF8 bytes")? + .trim(); + + if !trailing_text.is_empty() { + chunks.push( + serde_json::from_str(trailing_text) + .with_context(|| format!("failed to parse trailing chunk: {trailing_text}"))?, + ); + } + + Ok(chunks) + } + + pub async fn post_non_streaming(&self, body: &Value) -> Result { + self.http_client + .post(self.completions_url.clone()) + .json(body) + .send() + .await + .context("failed to POST OpenAI non-streaming chat completion")? + .error_for_status() + .context("non-success status from OpenAI non-streaming endpoint")? 
+ .json::<Value>() + .await + .context("failed to parse OpenAI non-streaming JSON response") + } +} diff --git a/paddler_tests/src/qwen3_embedding_cluster_params.rs b/paddler_tests/src/qwen3_embedding_cluster_params.rs new file mode 100644 index 00000000..64307c56 --- /dev/null +++ b/paddler_tests/src/qwen3_embedding_cluster_params.rs @@ -0,0 +1,23 @@ +use std::time::Duration; + +use paddler_types::inference_parameters::InferenceParameters; + +use crate::agent_config::AgentConfig; + +pub struct Qwen3EmbeddingClusterParams { + pub agents: Vec<AgentConfig>, + pub buffered_request_timeout: Duration, + pub inference_parameters: InferenceParameters, + pub max_buffered_requests: i32, +} + +impl Default for Qwen3EmbeddingClusterParams { + fn default() -> Self { + Self { + agents: AgentConfig::uniform(1, 4), + buffered_request_timeout: Duration::from_secs(10), + inference_parameters: InferenceParameters::default(), + max_buffered_requests: 10, + } + } +} diff --git a/paddler_tests/src/start_in_process_cluster.rs b/paddler_tests/src/start_in_process_cluster.rs index 7399bad2..6e679c10 100644 --- a/paddler_tests/src/start_in_process_cluster.rs +++ b/paddler_tests/src/start_in_process_cluster.rs @@ -22,15 +22,13 @@ use crate::wait_until_healthy::wait_until_healthy; pub async fn start_in_process_cluster( InProcessClusterParams { - agent_name, + agent, buffered_request_timeout, desired_state, inference_cors_allowed_hosts, inference_item_timeout, management_cors_allowed_hosts, max_buffered_requests, - slots_per_agent, - spawn_agent, wait_for_slots_ready, }: InProcessClusterParams, ) -> Result { @@ -82,15 +80,15 @@ pub async fn start_in_process_cluster( let buffered_requests_watcher = BufferedRequestsStreamWatcher::connect(&paddler_client.management()).await?; - let expected_agent_count: usize = usize::from(spawn_agent); + let expected_agent_count: usize = usize::from(agent.is_some()); let mut agent_runners: Vec<AgentRunner> = Vec::with_capacity(expected_agent_count); - if spawn_agent { + if let Some(agent_config) = agent.as_ref() { let agent_runner = AgentRunner::start(AgentRunnerParams { - agent_name: Some(agent_name), + agent_name: Some(agent_config.name.clone()), management_address: addresses.management.to_string(), cancellation_token: cancel_token.clone(), - slots: slots_per_agent, + slots: agent_config.slot_count, }); agent_runners.push(agent_runner); @@ -104,12 +102,12 @@ pub async fn start_in_process_cluster( let agent_ids: Vec<String> = registered_snapshot .agents .iter() - .map(|agent| agent.id.clone()) + .map(|registered_agent| registered_agent.id.clone()) .collect(); - if wait_for_slots_ready && spawn_agent { + if wait_for_slots_ready && let Some(agent_config) = agent.as_ref() { agents_watcher - .wait_for_slots_ready(expected_agent_count, slots_per_agent) + .wait_for_slots_ready(&[agent_config.slot_count]) .await?; } diff --git a/paddler_tests/src/start_in_process_cluster_with_deepseek_r1_distill_llama_8b.rs b/paddler_tests/src/start_in_process_cluster_with_deepseek_r1_distill_llama_8b.rs new file mode 100644 index 00000000..f783b62b --- /dev/null +++ b/paddler_tests/src/start_in_process_cluster_with_deepseek_r1_distill_llama_8b.rs @@ -0,0 +1,38 @@ +use anyhow::Result; +use paddler_types::agent_desired_model::AgentDesiredModel; +use paddler_types::balancer_desired_state::BalancerDesiredState; + +use crate::agent_config::AgentConfig; +use crate::cluster_handle::ClusterHandle; +use crate::current_test_device::current_test_device; +use crate::in_process_cluster_params::InProcessClusterParams; +use crate::model_card::ModelCard; +use
crate::model_card::deepseek_r1_distill_llama_8b::deepseek_r1_distill_llama_8b; +use crate::start_in_process_cluster::start_in_process_cluster; + +pub async fn start_in_process_cluster_with_deepseek_r1_distill_llama_8b( + agent: AgentConfig, +) -> Result<ClusterHandle> { + let device = current_test_device()?; + + device.require_available()?; + + let ModelCard { + gpu_layer_count, + reference, + } = deepseek_r1_distill_llama_8b(); + + start_in_process_cluster(InProcessClusterParams { + agent: Some(agent), + desired_state: BalancerDesiredState { + chat_template_override: None, + inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), + model: AgentDesiredModel::HuggingFace(reference), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: false, + }, + wait_for_slots_ready: true, + ..InProcessClusterParams::default() + }) + .await +} diff --git a/paddler_tests/src/start_in_process_cluster_with_gemma_4.rs b/paddler_tests/src/start_in_process_cluster_with_gemma_4.rs new file mode 100644 index 00000000..bc2762d5 --- /dev/null +++ b/paddler_tests/src/start_in_process_cluster_with_gemma_4.rs @@ -0,0 +1,36 @@ +use anyhow::Result; +use paddler_types::agent_desired_model::AgentDesiredModel; +use paddler_types::balancer_desired_state::BalancerDesiredState; + +use crate::agent_config::AgentConfig; +use crate::cluster_handle::ClusterHandle; +use crate::current_test_device::current_test_device; +use crate::in_process_cluster_params::InProcessClusterParams; +use crate::model_card::ModelCard; +use crate::model_card::gemma_4_e4b_it::gemma_4_e4b_it; +use crate::start_in_process_cluster::start_in_process_cluster; + +pub async fn start_in_process_cluster_with_gemma_4(agent: AgentConfig) -> Result<ClusterHandle> { + let device = current_test_device()?; + + device.require_available()?; + + let ModelCard { + gpu_layer_count, + reference, + } = gemma_4_e4b_it(); + + start_in_process_cluster(InProcessClusterParams { + agent: Some(agent), + desired_state: BalancerDesiredState { + chat_template_override: None, + inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), + model: AgentDesiredModel::HuggingFace(reference), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: false, + }, + wait_for_slots_ready: true, + ..InProcessClusterParams::default() + }) + .await +} diff --git a/paddler_tests/src/start_in_process_cluster_with_gemma_4_and_mmproj.rs b/paddler_tests/src/start_in_process_cluster_with_gemma_4_and_mmproj.rs new file mode 100644 index 00000000..f4297d49 --- /dev/null +++ b/paddler_tests/src/start_in_process_cluster_with_gemma_4_and_mmproj.rs @@ -0,0 +1,43 @@ +use anyhow::Result; +use paddler_types::agent_desired_model::AgentDesiredModel; +use paddler_types::balancer_desired_state::BalancerDesiredState; + +use crate::agent_config::AgentConfig; +use crate::cluster_handle::ClusterHandle; +use crate::current_test_device::current_test_device; +use crate::in_process_cluster_params::InProcessClusterParams; +use crate::model_card::ModelCard; +use crate::model_card::gemma_4_e4b_it::gemma_4_e4b_it; +use crate::model_card::gemma_4_e4b_it_mmproj::gemma_4_e4b_it_mmproj; +use crate::start_in_process_cluster::start_in_process_cluster; + +pub async fn start_in_process_cluster_with_gemma_4_and_mmproj( + agent: AgentConfig, +) -> Result<ClusterHandle> { + let device = current_test_device()?; + + device.require_available()?; + + let ModelCard { + gpu_layer_count, + reference: primary_reference, + } = gemma_4_e4b_it(); + let ModelCard { + reference:
mmproj_reference, + .. + } = gemma_4_e4b_it_mmproj(); + + start_in_process_cluster(InProcessClusterParams { + agent: Some(agent), + desired_state: BalancerDesiredState { + chat_template_override: None, + inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), + model: AgentDesiredModel::HuggingFace(primary_reference), + multimodal_projection: AgentDesiredModel::HuggingFace(mmproj_reference), + use_chat_template_override: false, + }, + wait_for_slots_ready: true, + ..InProcessClusterParams::default() + }) + .await +} diff --git a/paddler_tests/src/start_in_process_cluster_with_glm_4_7_flash.rs b/paddler_tests/src/start_in_process_cluster_with_glm_4_7_flash.rs new file mode 100644 index 00000000..7d055561 --- /dev/null +++ b/paddler_tests/src/start_in_process_cluster_with_glm_4_7_flash.rs @@ -0,0 +1,38 @@ +use anyhow::Result; +use paddler_types::agent_desired_model::AgentDesiredModel; +use paddler_types::balancer_desired_state::BalancerDesiredState; + +use crate::agent_config::AgentConfig; +use crate::cluster_handle::ClusterHandle; +use crate::current_test_device::current_test_device; +use crate::in_process_cluster_params::InProcessClusterParams; +use crate::model_card::ModelCard; +use crate::model_card::glm_4_7_flash::glm_4_7_flash; +use crate::start_in_process_cluster::start_in_process_cluster; + +pub async fn start_in_process_cluster_with_glm_4_7_flash( + agent: AgentConfig, +) -> Result<ClusterHandle> { + let device = current_test_device()?; + + device.require_available()?; + + let ModelCard { + gpu_layer_count, + reference, + } = glm_4_7_flash(); + + start_in_process_cluster(InProcessClusterParams { + agent: Some(agent), + desired_state: BalancerDesiredState { + chat_template_override: None, + inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), + model: AgentDesiredModel::HuggingFace(reference), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: false, + }, + wait_for_slots_ready: true, + ..InProcessClusterParams::default() + }) + .await +} diff --git a/paddler_tests/src/start_in_process_cluster_with_ministral_3.rs b/paddler_tests/src/start_in_process_cluster_with_ministral_3.rs new file mode 100644 index 00000000..a9cd6b7c --- /dev/null +++ b/paddler_tests/src/start_in_process_cluster_with_ministral_3.rs @@ -0,0 +1,38 @@ +use anyhow::Result; +use paddler_types::agent_desired_model::AgentDesiredModel; +use paddler_types::balancer_desired_state::BalancerDesiredState; + +use crate::agent_config::AgentConfig; +use crate::cluster_handle::ClusterHandle; +use crate::current_test_device::current_test_device; +use crate::in_process_cluster_params::InProcessClusterParams; +use crate::model_card::ModelCard; +use crate::model_card::ministral_3_14b_reasoning::ministral_3_14b_reasoning; +use crate::start_in_process_cluster::start_in_process_cluster; + +pub async fn start_in_process_cluster_with_ministral_3( + agent: AgentConfig, +) -> Result<ClusterHandle> { + let device = current_test_device()?; + + device.require_available()?; + + let ModelCard { + gpu_layer_count, + reference, + } = ministral_3_14b_reasoning(); + + start_in_process_cluster(InProcessClusterParams { + agent: Some(agent), + desired_state: BalancerDesiredState { + chat_template_override: None, + inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), + model: AgentDesiredModel::HuggingFace(reference), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: false, + }, + wait_for_slots_ready: true, +
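// The struct-update line below fills every field not listed above from
// `InProcessClusterParams::default()`, so each model-specific starter only
// spells out what differs from the shared baseline. Sketch of the pattern
// (values hypothetical):
//
//     let params = InProcessClusterParams {
//         agent: Some(AgentConfig { name: "demo".to_owned(), slot_count: 2 }),
//         ..InProcessClusterParams::default()
//     };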
..InProcessClusterParams::default() + }) + .await +} diff --git a/paddler_tests/src/start_in_process_cluster_with_ministral_3_and_mmproj.rs b/paddler_tests/src/start_in_process_cluster_with_ministral_3_and_mmproj.rs new file mode 100644 index 00000000..515ac248 --- /dev/null +++ b/paddler_tests/src/start_in_process_cluster_with_ministral_3_and_mmproj.rs @@ -0,0 +1,43 @@ +use anyhow::Result; +use paddler_types::agent_desired_model::AgentDesiredModel; +use paddler_types::balancer_desired_state::BalancerDesiredState; + +use crate::agent_config::AgentConfig; +use crate::cluster_handle::ClusterHandle; +use crate::current_test_device::current_test_device; +use crate::in_process_cluster_params::InProcessClusterParams; +use crate::model_card::ModelCard; +use crate::model_card::ministral_3_14b_reasoning::ministral_3_14b_reasoning; +use crate::model_card::ministral_3_14b_reasoning_mmproj::ministral_3_14b_reasoning_mmproj; +use crate::start_in_process_cluster::start_in_process_cluster; + +pub async fn start_in_process_cluster_with_ministral_3_and_mmproj( + agent: AgentConfig, +) -> Result<ClusterHandle> { + let device = current_test_device()?; + + device.require_available()?; + + let ModelCard { + gpu_layer_count, + reference: primary_reference, + } = ministral_3_14b_reasoning(); + let ModelCard { + reference: mmproj_reference, + .. + } = ministral_3_14b_reasoning_mmproj(); + + start_in_process_cluster(InProcessClusterParams { + agent: Some(agent), + desired_state: BalancerDesiredState { + chat_template_override: None, + inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), + model: AgentDesiredModel::HuggingFace(primary_reference), + multimodal_projection: AgentDesiredModel::HuggingFace(mmproj_reference), + use_chat_template_override: false, + }, + wait_for_slots_ready: true, + ..InProcessClusterParams::default() + }) + .await +} diff --git a/paddler_tests/src/start_in_process_cluster_with_qwen2_5_vl.rs b/paddler_tests/src/start_in_process_cluster_with_qwen2_5_vl.rs index ddb6b8a7..cb5dc0a9 100644 --- a/paddler_tests/src/start_in_process_cluster_with_qwen2_5_vl.rs +++ b/paddler_tests/src/start_in_process_cluster_with_qwen2_5_vl.rs @@ -2,6 +2,7 @@ use anyhow::Result; use paddler_types::agent_desired_model::AgentDesiredModel; use paddler_types::balancer_desired_state::BalancerDesiredState; +use crate::agent_config::AgentConfig; use crate::cluster_handle::ClusterHandle; use crate::current_test_device::current_test_device; use crate::in_process_cluster_params::InProcessClusterParams; @@ -10,9 +11,7 @@ use crate::model_card::qwen2_5_vl_3b::qwen2_5_vl_3b; use crate::model_card::qwen2_5_vl_3b_mmproj::qwen2_5_vl_3b_mmproj; use crate::start_in_process_cluster::start_in_process_cluster; -pub async fn start_in_process_cluster_with_qwen2_5_vl( - slots_per_agent: i32, -) -> Result<ClusterHandle> { +pub async fn start_in_process_cluster_with_qwen2_5_vl(agent: AgentConfig) -> Result<ClusterHandle> { let device = current_test_device()?; device.require_available()?; @@ -27,7 +26,7 @@ pub async fn start_in_process_cluster_with_qwen2_5_vl( } = qwen2_5_vl_3b_mmproj(); start_in_process_cluster(InProcessClusterParams { - slots_per_agent, + agent: Some(agent), desired_state: BalancerDesiredState { chat_template_override: None, inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), diff --git a/paddler_tests/src/start_in_process_cluster_with_qwen3.rs b/paddler_tests/src/start_in_process_cluster_with_qwen3.rs index b42c2d2d..befceb8e 100644 --- a/paddler_tests/src/start_in_process_cluster_with_qwen3.rs +++
b/paddler_tests/src/start_in_process_cluster_with_qwen3.rs @@ -2,6 +2,7 @@ use anyhow::Result; use paddler_types::agent_desired_model::AgentDesiredModel; use paddler_types::balancer_desired_state::BalancerDesiredState; +use crate::agent_config::AgentConfig; use crate::cluster_handle::ClusterHandle; use crate::current_test_device::current_test_device; use crate::in_process_cluster_params::InProcessClusterParams; @@ -9,7 +10,7 @@ use crate::model_card::ModelCard; use crate::model_card::qwen3_0_6b::qwen3_0_6b; use crate::start_in_process_cluster::start_in_process_cluster; -pub async fn start_in_process_cluster_with_qwen3(slots_per_agent: i32) -> Result<ClusterHandle> { +pub async fn start_in_process_cluster_with_qwen3(agent: AgentConfig) -> Result<ClusterHandle> { let device = current_test_device()?; device.require_available()?; @@ -20,7 +21,7 @@ pub async fn start_in_process_cluster_with_qwen3(slots_per_agent: i32) -> Result } = qwen3_0_6b(); start_in_process_cluster(InProcessClusterParams { - slots_per_agent, + agent: Some(agent), desired_state: BalancerDesiredState { chat_template_override: None, inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), diff --git a/paddler_tests/src/start_in_process_cluster_with_qwen3_5.rs b/paddler_tests/src/start_in_process_cluster_with_qwen3_5.rs index 1f75faf5..3d9189bd 100644 --- a/paddler_tests/src/start_in_process_cluster_with_qwen3_5.rs +++ b/paddler_tests/src/start_in_process_cluster_with_qwen3_5.rs @@ -2,6 +2,7 @@ use anyhow::Result; use paddler_types::agent_desired_model::AgentDesiredModel; use paddler_types::balancer_desired_state::BalancerDesiredState; +use crate::agent_config::AgentConfig; use crate::cluster_handle::ClusterHandle; use crate::current_test_device::current_test_device; use crate::in_process_cluster_params::InProcessClusterParams; @@ -11,7 +12,7 @@ use crate::model_card::qwen3_5_0_8b_mmproj::qwen3_5_0_8b_mmproj; use crate::start_in_process_cluster::start_in_process_cluster; pub async fn start_in_process_cluster_with_qwen3_5( - slots_per_agent: i32, + agent: AgentConfig, with_mmproj: bool, ) -> Result<ClusterHandle> { let device = current_test_device()?; @@ -35,7 +36,7 @@ pub async fn start_in_process_cluster_with_qwen3_5( }; start_in_process_cluster(InProcessClusterParams { - slots_per_agent, + agent: Some(agent), desired_state: BalancerDesiredState { chat_template_override: None, inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), diff --git a/paddler_tests/src/start_in_process_cluster_with_qwen3_6.rs b/paddler_tests/src/start_in_process_cluster_with_qwen3_6.rs new file mode 100644 index 00000000..7f6e765b --- /dev/null +++ b/paddler_tests/src/start_in_process_cluster_with_qwen3_6.rs @@ -0,0 +1,36 @@ +use anyhow::Result; +use paddler_types::agent_desired_model::AgentDesiredModel; +use paddler_types::balancer_desired_state::BalancerDesiredState; + +use crate::agent_config::AgentConfig; +use crate::cluster_handle::ClusterHandle; +use crate::current_test_device::current_test_device; +use crate::in_process_cluster_params::InProcessClusterParams; +use crate::model_card::ModelCard; +use crate::model_card::qwen3_6_35b_a3b::qwen3_6_35b_a3b; +use crate::start_in_process_cluster::start_in_process_cluster; + +pub async fn start_in_process_cluster_with_qwen3_6(agent: AgentConfig) -> Result<ClusterHandle> { + let device = current_test_device()?; + + device.require_available()?; + + let ModelCard { + gpu_layer_count, + reference, + } = qwen3_6_35b_a3b(); + + start_in_process_cluster(InProcessClusterParams { + agent: Some(agent), +
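// `desired_state` below pins the model this cluster should converge on, and
// `inference_parameters_for_full_offload(gpu_layer_count)` asks the current
// test device for settings that offload that many layers. The model cards'
// `gpu_layer_count: 999` exceeds any real layer count, which llama.cpp-style
// loaders conventionally read as "offload every layer"; that reading is an
// assumption about intent, not something this diff states.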
desired_state: BalancerDesiredState { + chat_template_override: None, + inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), + model: AgentDesiredModel::HuggingFace(reference), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: false, + }, + wait_for_slots_ready: true, + ..InProcessClusterParams::default() + }) + .await +} diff --git a/paddler_tests/src/start_in_process_cluster_with_qwen3_6_and_mmproj.rs b/paddler_tests/src/start_in_process_cluster_with_qwen3_6_and_mmproj.rs new file mode 100644 index 00000000..2d5c5dad --- /dev/null +++ b/paddler_tests/src/start_in_process_cluster_with_qwen3_6_and_mmproj.rs @@ -0,0 +1,43 @@ +use anyhow::Result; +use paddler_types::agent_desired_model::AgentDesiredModel; +use paddler_types::balancer_desired_state::BalancerDesiredState; + +use crate::agent_config::AgentConfig; +use crate::cluster_handle::ClusterHandle; +use crate::current_test_device::current_test_device; +use crate::in_process_cluster_params::InProcessClusterParams; +use crate::model_card::ModelCard; +use crate::model_card::qwen3_6_35b_a3b::qwen3_6_35b_a3b; +use crate::model_card::qwen3_6_35b_a3b_mmproj::qwen3_6_35b_a3b_mmproj; +use crate::start_in_process_cluster::start_in_process_cluster; + +pub async fn start_in_process_cluster_with_qwen3_6_and_mmproj( + agent: AgentConfig, +) -> Result<ClusterHandle> { + let device = current_test_device()?; + + device.require_available()?; + + let ModelCard { + gpu_layer_count, + reference: primary_reference, + } = qwen3_6_35b_a3b(); + let ModelCard { + reference: mmproj_reference, + .. + } = qwen3_6_35b_a3b_mmproj(); + + start_in_process_cluster(InProcessClusterParams { + agent: Some(agent), + desired_state: BalancerDesiredState { + chat_template_override: None, + inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), + model: AgentDesiredModel::HuggingFace(primary_reference), + multimodal_projection: AgentDesiredModel::HuggingFace(mmproj_reference), + use_chat_template_override: false, + }, + wait_for_slots_ready: true, + ..InProcessClusterParams::default() + }) + .await +} diff --git a/paddler_tests/src/start_in_process_cluster_with_smolvlm2.rs b/paddler_tests/src/start_in_process_cluster_with_smolvlm2.rs index 96d86d4c..b92fa31c 100644 --- a/paddler_tests/src/start_in_process_cluster_with_smolvlm2.rs +++ b/paddler_tests/src/start_in_process_cluster_with_smolvlm2.rs @@ -2,6 +2,7 @@ use anyhow::Result; use paddler_types::agent_desired_model::AgentDesiredModel; use paddler_types::balancer_desired_state::BalancerDesiredState; +use crate::agent_config::AgentConfig; use crate::cluster_handle::ClusterHandle; use crate::current_test_device::current_test_device; use crate::in_process_cluster_params::InProcessClusterParams; @@ -10,7 +11,7 @@ use crate::model_card::smolvlm2_256m::smolvlm2_256m; use crate::model_card::smolvlm2_256m_mmproj::smolvlm2_256m_mmproj; use crate::start_in_process_cluster::start_in_process_cluster; -pub async fn start_in_process_cluster_with_smolvlm2(slots_per_agent: i32) -> Result<ClusterHandle> { +pub async fn start_in_process_cluster_with_smolvlm2(agent: AgentConfig) -> Result<ClusterHandle> { let device = current_test_device()?; device.require_available()?; @@ -25,7 +26,7 @@ pub async fn start_in_process_cluster_with_smolvlm2(slots_per_agent: i32) -> Res } = smolvlm2_256m_mmproj(); start_in_process_cluster(InProcessClusterParams { - slots_per_agent, + agent: Some(agent), desired_state: BalancerDesiredState { chat_template_override: None, inference_parameters:
device.inference_parameters_for_full_offload(gpu_layer_count), diff --git a/paddler_tests/src/start_in_process_cluster_with_smolvlm2_and_n_batch.rs b/paddler_tests/src/start_in_process_cluster_with_smolvlm2_and_n_batch.rs new file mode 100644 index 00000000..b6bbd67a --- /dev/null +++ b/paddler_tests/src/start_in_process_cluster_with_smolvlm2_and_n_batch.rs @@ -0,0 +1,47 @@ +use anyhow::Result; +use paddler_types::agent_desired_model::AgentDesiredModel; +use paddler_types::balancer_desired_state::BalancerDesiredState; + +use crate::agent_config::AgentConfig; +use crate::cluster_handle::ClusterHandle; +use crate::current_test_device::current_test_device; +use crate::in_process_cluster_params::InProcessClusterParams; +use crate::model_card::ModelCard; +use crate::model_card::smolvlm2_256m::smolvlm2_256m; +use crate::model_card::smolvlm2_256m_mmproj::smolvlm2_256m_mmproj; +use crate::start_in_process_cluster::start_in_process_cluster; + +pub async fn start_in_process_cluster_with_smolvlm2_and_n_batch( + agent: AgentConfig, + n_batch: usize, +) -> Result<ClusterHandle> { + let device = current_test_device()?; + + device.require_available()?; + + let ModelCard { + gpu_layer_count, + reference: primary_reference, + } = smolvlm2_256m(); + let ModelCard { + reference: mmproj_reference, + .. + } = smolvlm2_256m_mmproj(); + + let mut inference_parameters = device.inference_parameters_for_full_offload(gpu_layer_count); + inference_parameters.n_batch = n_batch; + + start_in_process_cluster(InProcessClusterParams { + agent: Some(agent), + desired_state: BalancerDesiredState { + chat_template_override: None, + inference_parameters, + model: AgentDesiredModel::HuggingFace(primary_reference), + multimodal_projection: AgentDesiredModel::HuggingFace(mmproj_reference), + use_chat_template_override: false, + }, + wait_for_slots_ready: true, + ..InProcessClusterParams::default() + }) + .await +} diff --git a/paddler_tests/src/start_in_process_embedding_cluster.rs b/paddler_tests/src/start_in_process_embedding_cluster.rs index 800e8237..827dbf8c 100644 --- a/paddler_tests/src/start_in_process_embedding_cluster.rs +++ b/paddler_tests/src/start_in_process_embedding_cluster.rs @@ -3,6 +3,7 @@ use paddler_types::agent_desired_model::AgentDesiredModel; use paddler_types::balancer_desired_state::BalancerDesiredState; use paddler_types::inference_parameters::InferenceParameters; +use crate::agent_config::AgentConfig; use crate::cluster_handle::ClusterHandle; use crate::in_process_cluster_params::InProcessClusterParams; use crate::model_card::ModelCard; @@ -11,12 +12,12 @@ use crate::start_in_process_cluster::start_in_process_cluster; pub async fn start_in_process_embedding_cluster( inference_parameters: InferenceParameters, - slots_per_agent: i32, + agent: AgentConfig, ) -> Result<ClusterHandle> { let ModelCard { reference, ..
} = qwen3_embedding_0_6b(); start_in_process_cluster(InProcessClusterParams { - slots_per_agent, + agent: Some(agent), desired_state: BalancerDesiredState { chat_template_override: None, inference_parameters, diff --git a/paddler_tests/src/start_subprocess_cluster.rs b/paddler_tests/src/start_subprocess_cluster.rs index 2677d9dc..b676e8c7 100644 --- a/paddler_tests/src/start_subprocess_cluster.rs +++ b/paddler_tests/src/start_subprocess_cluster.rs @@ -3,6 +3,7 @@ use std::process::Stdio; use anyhow::Context as _; use anyhow::Result; use paddler_client::PaddlerClient; +use paddler_types::agent_controller_pool_snapshot::AgentControllerPoolSnapshot; use tokio::process::Child; use tokio_util::sync::CancellationToken; @@ -18,15 +19,13 @@ use crate::wait_until_healthy::wait_until_healthy; pub async fn start_subprocess_cluster( SubprocessClusterParams { - agent_count, - agent_name_prefix, + agents, buffered_request_timeout, desired_state, inference_cors_allowed_hosts, inference_item_timeout, management_cors_allowed_hosts, max_buffered_requests, - slots_per_agent, state_database_url, wait_for_slots_ready, }: SubprocessClusterParams, @@ -92,44 +91,49 @@ pub async fn start_subprocess_cluster( let buffered_requests_watcher = BufferedRequestsStreamWatcher::connect(&paddler_client.management()).await?; - let mut agent_children: Vec<Child> = Vec::with_capacity(agent_count); - - for agent_index in 0..agent_count { - let agent_name = format!("{agent_name_prefix}-{agent_index}"); + let expected_agent_count = agents.len(); + let mut agent_children: Vec<Child> = Vec::with_capacity(expected_agent_count); + let mut last_ready_snapshot: Option<AgentControllerPoolSnapshot> = None; + for agent in &agents { let agent_child = paddler_command() .arg("agent") .arg("--management-addr") .arg(addresses.management.to_string()) .arg("--name") - .arg(agent_name) + .arg(&agent.name) .arg("--slots") - .arg(slots_per_agent.to_string()) + .arg(agent.slot_count.to_string()) .stdout(Stdio::null()) .stderr(Stdio::null()) .spawn() .context("failed to spawn paddler agent subprocess")?; agent_children.push(agent_child); + + if wait_for_slots_ready { + last_ready_snapshot = Some( + agents_watcher + .wait_for_agent_ready(&agent.name, agent.slot_count) + .await?, + ); + } } - let registered_snapshot = agents_watcher - .until(move |snapshot| snapshot.agents.len() >= agent_count) - .await - .context("not all subprocess agents registered")?; + let registered_snapshot = match last_ready_snapshot { + Some(snapshot) => snapshot, + None => agents_watcher + .until(move |snapshot| snapshot.agents.len() >= expected_agent_count) + .await + .context("not all subprocess agents registered")?, + }; let agent_ids: Vec<String> = registered_snapshot .agents .iter() - .map(|agent| agent.id.clone()) + .map(|registered_agent| registered_agent.id.clone()) .collect(); - if wait_for_slots_ready { - agents_watcher - .wait_for_slots_ready(agent_count, slots_per_agent) - .await?; - } - Ok(ClusterHandle::new(ClusterHandleParams { addresses, agent_ids, diff --git a/paddler_tests/src/start_subprocess_cluster_with_qwen2_5_vl.rs b/paddler_tests/src/start_subprocess_cluster_with_qwen2_5_vl.rs index e5657500..c898cc8e 100644 --- a/paddler_tests/src/start_subprocess_cluster_with_qwen2_5_vl.rs +++ b/paddler_tests/src/start_subprocess_cluster_with_qwen2_5_vl.rs @@ -2,6 +2,7 @@ use anyhow::Result; use paddler_types::agent_desired_model::AgentDesiredModel; use paddler_types::balancer_desired_state::BalancerDesiredState; +use crate::agent_config::AgentConfig; use crate::cluster_handle::ClusterHandle; use
crate::current_test_device::current_test_device; use crate::model_card::ModelCard; @@ -11,8 +12,7 @@ use crate::start_subprocess_cluster::start_subprocess_cluster; use crate::subprocess_cluster_params::SubprocessClusterParams; pub async fn start_subprocess_cluster_with_qwen2_5_vl( - slots_per_agent: i32, - agent_count: usize, + agents: Vec<AgentConfig>, ) -> Result<ClusterHandle> { let device = current_test_device()?; @@ -28,8 +28,7 @@ pub async fn start_subprocess_cluster_with_qwen2_5_vl( } = qwen2_5_vl_3b_mmproj(); start_subprocess_cluster(SubprocessClusterParams { - agent_count, - slots_per_agent, + agents, desired_state: Some(BalancerDesiredState { chat_template_override: None, inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), diff --git a/paddler_tests/src/start_subprocess_cluster_with_qwen3.rs b/paddler_tests/src/start_subprocess_cluster_with_qwen3.rs index 3f623f35..36db2eff 100644 --- a/paddler_tests/src/start_subprocess_cluster_with_qwen3.rs +++ b/paddler_tests/src/start_subprocess_cluster_with_qwen3.rs @@ -2,6 +2,7 @@ use anyhow::Result; use paddler_types::agent_desired_model::AgentDesiredModel; use paddler_types::balancer_desired_state::BalancerDesiredState; +use crate::agent_config::AgentConfig; use crate::cluster_handle::ClusterHandle; use crate::current_test_device::current_test_device; use crate::model_card::ModelCard; @@ -10,8 +11,7 @@ use crate::start_subprocess_cluster::start_subprocess_cluster; use crate::subprocess_cluster_params::SubprocessClusterParams; pub async fn start_subprocess_cluster_with_qwen3( - slots_per_agent: i32, - agent_count: usize, + agents: Vec<AgentConfig>, ) -> Result<ClusterHandle> { let device = current_test_device()?; @@ -23,8 +23,7 @@ pub async fn start_subprocess_cluster_with_qwen3( } = qwen3_0_6b(); start_subprocess_cluster(SubprocessClusterParams { - agent_count, - slots_per_agent, + agents, desired_state: Some(BalancerDesiredState { chat_template_override: None, inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), diff --git a/paddler_tests/src/start_subprocess_cluster_with_qwen3_embedding.rs b/paddler_tests/src/start_subprocess_cluster_with_qwen3_embedding.rs index e7852b04..ab5e7069 100644 --- a/paddler_tests/src/start_subprocess_cluster_with_qwen3_embedding.rs +++ b/paddler_tests/src/start_subprocess_cluster_with_qwen3_embedding.rs @@ -4,28 +4,47 @@ use paddler_types::balancer_desired_state::BalancerDesiredState; use paddler_types::inference_parameters::InferenceParameters; use crate::cluster_handle::ClusterHandle; +use crate::current_test_device::current_test_device; use crate::model_card::ModelCard; use crate::model_card::qwen3_embedding_0_6b::qwen3_embedding_0_6b; +use crate::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; use crate::start_subprocess_cluster::start_subprocess_cluster; use crate::subprocess_cluster_params::SubprocessClusterParams; pub async fn start_subprocess_cluster_with_qwen3_embedding( - inference_parameters: InferenceParameters, - slots_per_agent: i32, - agent_count: usize, + Qwen3EmbeddingClusterParams { + agents, + buffered_request_timeout, + inference_parameters, + max_buffered_requests, + }: Qwen3EmbeddingClusterParams, ) -> Result<ClusterHandle> { - let ModelCard { reference, ..
} = qwen3_embedding_0_6b(); + let ModelCard { + gpu_layer_count, + reference, + } = qwen3_embedding_0_6b(); + + let test_device = current_test_device()?; + test_device.require_available()?; + let device_offload_parameters = + test_device.inference_parameters_for_full_offload(gpu_layer_count); + + let inference_parameters_with_offload = InferenceParameters { + n_gpu_layers: device_offload_parameters.n_gpu_layers, + ..inference_parameters + }; start_subprocess_cluster(SubprocessClusterParams { - agent_count, - slots_per_agent, + agents, + buffered_request_timeout, desired_state: Some(BalancerDesiredState { chat_template_override: None, - inference_parameters, + inference_parameters: inference_parameters_with_offload, model: AgentDesiredModel::HuggingFace(reference), multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, }), + max_buffered_requests, wait_for_slots_ready: true, ..SubprocessClusterParams::default() }) diff --git a/paddler_tests/src/start_subprocess_cluster_with_smolvlm2.rs b/paddler_tests/src/start_subprocess_cluster_with_smolvlm2.rs index 8c02df63..88d324ef 100644 --- a/paddler_tests/src/start_subprocess_cluster_with_smolvlm2.rs +++ b/paddler_tests/src/start_subprocess_cluster_with_smolvlm2.rs @@ -2,6 +2,7 @@ use anyhow::Result; use paddler_types::agent_desired_model::AgentDesiredModel; use paddler_types::balancer_desired_state::BalancerDesiredState; +use crate::agent_config::AgentConfig; use crate::cluster_handle::ClusterHandle; use crate::current_test_device::current_test_device; use crate::model_card::ModelCard; @@ -11,8 +12,7 @@ use crate::start_subprocess_cluster::start_subprocess_cluster; use crate::subprocess_cluster_params::SubprocessClusterParams; pub async fn start_subprocess_cluster_with_smolvlm2( - slots_per_agent: i32, - agent_count: usize, + agents: Vec<AgentConfig>, ) -> Result<ClusterHandle> { let device = current_test_device()?; @@ -28,8 +28,7 @@ pub async fn start_subprocess_cluster_with_smolvlm2( } = smolvlm2_256m_mmproj(); start_subprocess_cluster(SubprocessClusterParams { - agent_count, - slots_per_agent, + agents, desired_state: Some(BalancerDesiredState { chat_template_override: None, inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), diff --git a/paddler_tests/src/subprocess_cluster_params.rs b/paddler_tests/src/subprocess_cluster_params.rs index b33bdbbb..03fa0739 100644 --- a/paddler_tests/src/subprocess_cluster_params.rs +++ b/paddler_tests/src/subprocess_cluster_params.rs @@ -2,16 +2,16 @@ use std::time::Duration; use paddler_types::balancer_desired_state::BalancerDesiredState; +use crate::agent_config::AgentConfig; + pub struct SubprocessClusterParams { - pub agent_count: usize, - pub agent_name_prefix: String, + pub agents: Vec<AgentConfig>, pub buffered_request_timeout: Duration, pub desired_state: Option<BalancerDesiredState>, pub inference_cors_allowed_hosts: Vec<String>, pub inference_item_timeout: Duration, pub management_cors_allowed_hosts: Vec<String>, pub max_buffered_requests: i32, - pub slots_per_agent: i32, pub state_database_url: String, pub wait_for_slots_ready: bool, } @@ -19,15 +19,13 @@ impl Default for SubprocessClusterParams { fn default() -> Self { Self { - agent_count: 1, - agent_name_prefix: "test-agent".to_owned(), + agents: AgentConfig::uniform(1, 4), buffered_request_timeout: Duration::from_secs(10), desired_state: Some(BalancerDesiredState::default()), inference_cors_allowed_hosts: Vec::new(), - inference_item_timeout: Duration::from_secs(30), + inference_item_timeout: Duration::from_secs(60),
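// The per-item timeout above was raised from 30s (still the in-process
// default) to 60s for subprocess clusters, presumably because spawned agents
// pay process-startup and model-load costs before their first response; the
// diff itself does not state the reason.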
management_cors_allowed_hosts: Vec::new(), max_buffered_requests: 10, - slots_per_agent: 4, state_database_url: "memory://".to_owned(), wait_for_slots_ready: true, } diff --git a/paddler_tests/src/test_device.rs b/paddler_tests/src/test_device.rs index 569b42e2..4468ffae 100644 --- a/paddler_tests/src/test_device.rs +++ b/paddler_tests/src/test_device.rs @@ -2,11 +2,11 @@ use anyhow::Result; #[cfg(any(feature = "cuda", feature = "metal"))] use anyhow::bail; #[cfg(any(feature = "cuda", feature = "metal"))] -use paddler::llama_cpp_bindings::llama_backend::LlamaBackend; +use llama_cpp_bindings::llama_backend::LlamaBackend; #[cfg(any(feature = "cuda", feature = "metal"))] -use paddler::llama_cpp_bindings::llama_backend_device::LlamaBackendDeviceType; +use llama_cpp_bindings::llama_backend_device::LlamaBackendDeviceType; #[cfg(any(feature = "cuda", feature = "metal"))] -use paddler::llama_cpp_bindings::llama_backend_device::list_llama_ggml_backend_devices; +use llama_cpp_bindings::llama_backend_device::list_llama_ggml_backend_devices; use paddler_types::inference_parameters::InferenceParameters; #[derive(Clone, Copy, Debug, Eq, PartialEq)] diff --git a/paddler_tests/src/token_result_with_producer.rs b/paddler_tests/src/token_result_with_producer.rs new file mode 100644 index 00000000..6687eb4a --- /dev/null +++ b/paddler_tests/src/token_result_with_producer.rs @@ -0,0 +1,7 @@ +use paddler_types::generated_token_result::GeneratedTokenResult; + +#[derive(Debug)] +pub struct TokenResultWithProducer { + pub token_result: GeneratedTokenResult, + pub generated_by: Option<String>, +} diff --git a/paddler_tests/tests/agent_chunks_embedding_batch_larger_than_slot_count.rs b/paddler_tests/tests/agent_chunks_embedding_batch_larger_than_slot_count.rs index 75aec4ef..aed646c8 100644 --- a/paddler_tests/tests/agent_chunks_embedding_batch_larger_than_slot_count.rs +++ b/paddler_tests/tests/agent_chunks_embedding_batch_larger_than_slot_count.rs @@ -3,6 +3,7 @@ use std::collections::BTreeSet; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_embedding_results::collect_embedding_results; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_embedding_cluster::start_in_process_embedding_cluster; @@ -20,7 +21,7 @@ async fn agent_chunks_embedding_batch_larger_than_slot_count() -> Result<()> { enable_embeddings: true, ..InferenceParameters::default() }, - 4, + AgentConfig::single(4), ) .await?; @@ -50,16 +51,16 @@ async fn agent_chunks_embedding_batch_larger_than_slot_count() -> Result<()> { let returned_ids: BTreeSet<String> = collected .embeddings .iter() - .map(|embedding| embedding.source_document_id.clone()) + .map(|produced| produced.embedding.source_document_id.clone()) .collect(); let expected_ids: BTreeSet<String> = (0..12).map(|index| format!("doc-{index}")).collect(); assert_eq!(returned_ids, expected_ids); - let first_dimension = collected.embeddings[0].embedding.len(); + let first_dimension = collected.embeddings[0].embedding.embedding.len(); - for embedding in &collected.embeddings { - assert_eq!(embedding.embedding.len(), first_dimension); + for produced in &collected.embeddings { + assert_eq!(produced.embedding.embedding.len(), first_dimension); } cluster.shutdown().await?; diff --git a/paddler_tests/tests/agent_completes_generation_with_adequate_n_batch.rs b/paddler_tests/tests/agent_completes_generation_with_adequate_n_batch.rs new file mode 100644 index 00000000..abfddfe9 --- /dev/null +++
b/paddler_tests/tests/agent_completes_generation_with_adequate_n_batch.rs @@ -0,0 +1,94 @@ +#![cfg(feature = "tests_that_use_llms")] + +use std::fs; + +use anyhow::Context as _; +use anyhow::Result; +use base64::Engine as _; +use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_cluster_with_smolvlm2::start_in_process_cluster_with_smolvlm2; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::conversation_message_content_part::ConversationMessageContentPart; +use paddler_types::image_url::ImageUrl; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use reqwest::Client; + +fn load_fixture_as_data_uri(fixture_name: &str, mime_type: &str) -> Result<String> { + let fixture_path = format!("{}/../fixtures/{fixture_name}", env!("CARGO_MANIFEST_DIR")); + let bytes = fs::read(&fixture_path) + .with_context(|| format!("failed to read test fixture {fixture_path}"))?; + let encoded = BASE64_STANDARD.encode(&bytes); + + Ok(format!("data:{mime_type};base64,{encoded}")) +} + +async fn drive_normal_image_fixture( + inference_client: &InferenceHttpClient, + fixture_name: &str, + mime_type: &str, +) -> Result<()> { + let image_data_uri = load_fixture_as_data_uri(fixture_name, mime_type)?; + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Parts(vec![ + ConversationMessageContentPart::ImageUrl { + image_url: ImageUrl { + url: image_data_uri, + }, + }, + ConversationMessageContentPart::Text { + text: "What do you see in this image?".to_owned(), + }, + ]), + role: "user".to_owned(), + }]), + enable_thinking: false, + grammar: None, + max_tokens: 20, + parse_tool_calls: false, + tools: vec![], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let saw_token = collected + .token_results + .iter() + .any(|result| result.token_result.is_token()); + + assert!( + saw_token, + "fixture {fixture_name} should produce at least one content/reasoning/tool-call token with adequate n_batch; got {:?}", + collected + .token_results + .iter() + .map(|result| &result.token_result) + .collect::<Vec<_>>(), + ); + + Ok(()) +} + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn agent_completes_generation_with_adequate_n_batch() -> Result<()> { + let cluster = start_in_process_cluster_with_smolvlm2(AgentConfig::single(1)).await?; + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + drive_normal_image_fixture(&inference_client, "sarnow.jpeg", "image/jpeg").await?; + drive_normal_image_fixture(&inference_client, "llamas.webp", "image/webp").await?; + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/agent_controller_pool_distributes_concurrent_dispatch_evenly_across_idle_agents.rs b/paddler_tests/tests/agent_controller_pool_distributes_concurrent_dispatch_evenly_across_idle_agents.rs new 
file mode 100644 index 00000000..e9609871 --- /dev/null +++ b/paddler_tests/tests/agent_controller_pool_distributes_concurrent_dispatch_evenly_across_idle_agents.rs @@ -0,0 +1,65 @@ +use std::sync::Arc; + +use anyhow::Result; +use paddler::balancer::agent_controller_pool::AgentControllerPool; +use paddler_tests::make_agent_controller_without_remote_agent::make_agent_controller_without_remote_agent; +use tokio::sync::Barrier; + +#[tokio::test(flavor = "multi_thread", worker_threads = 8)] +async fn agent_controller_pool_distributes_concurrent_dispatch_evenly_across_idle_agents() +-> Result<()> { + const AGENT_COUNT: usize = 4; + const SLOTS_PER_AGENT: i32 = 4; + const PARALLEL_CALLERS: usize = AGENT_COUNT * (SLOTS_PER_AGENT as usize); + + let pool = Arc::new(AgentControllerPool::default()); + let mut controllers = Vec::with_capacity(AGENT_COUNT); + + for index in 0..AGENT_COUNT { + let agent_id = format!("agent-{index}"); + let controller = Arc::new(make_agent_controller_without_remote_agent(&agent_id)); + + controller.slots_total.set(SLOTS_PER_AGENT); + pool.register_agent_controller(agent_id, controller.clone())?; + controllers.push(controller); + } + + let barrier = Arc::new(Barrier::new(PARALLEL_CALLERS)); + let mut handles = Vec::with_capacity(PARALLEL_CALLERS); + + for _ in 0..PARALLEL_CALLERS { + let pool_for_task = pool.clone(); + let barrier_for_task = barrier.clone(); + + handles.push(tokio::spawn(async move { + barrier_for_task.wait().await; + + pool_for_task.take_least_busy_agent_controller() + })); + } + + let mut acquired = Vec::with_capacity(PARALLEL_CALLERS); + + for handle in handles { + if let Some(dispatched_agent) = handle.await? { + acquired.push(dispatched_agent); + } + } + + assert_eq!( + acquired.len(), + PARALLEL_CALLERS, + "every caller must acquire a slot when total capacity equals concurrency" + ); + + for controller in &controllers { + assert_eq!( + controller.slots_processing.get(), + SLOTS_PER_AGENT, + "agent {} should be filled to capacity under fair burst dispatch", + controller.id, + ); + } + + Ok(()) +} diff --git a/paddler_tests/tests/agent_controller_pool_re_selects_after_contended_claim.rs b/paddler_tests/tests/agent_controller_pool_re_selects_after_contended_claim.rs new file mode 100644 index 00000000..380a2805 --- /dev/null +++ b/paddler_tests/tests/agent_controller_pool_re_selects_after_contended_claim.rs @@ -0,0 +1,68 @@ +use std::sync::Arc; + +use anyhow::Result; +use anyhow::anyhow; +use paddler::balancer::agent_controller_pool::AgentControllerPool; +use paddler_tests::make_agent_controller_without_remote_agent::make_agent_controller_without_remote_agent; + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn agent_controller_pool_re_selects_after_contended_claim() -> Result<()> { + let pool = Arc::new(AgentControllerPool::default()); + + let controller_a = Arc::new(make_agent_controller_without_remote_agent("agent-a")); + controller_a.slots_total.set(4); + pool.register_agent_controller("agent-a".to_owned(), controller_a)?; + + let controller_b = Arc::new(make_agent_controller_without_remote_agent("agent-b")); + controller_b.slots_total.set(4); + pool.register_agent_controller("agent-b".to_owned(), controller_b)?; + + let candidate_first = pool + .select_least_busy_with_capacity() + .ok_or_else(|| anyhow!("expected a candidate when both agents have free capacity"))?; + let first_pick_id = candidate_first.agent_controller.id.clone(); + + assert_eq!( + candidate_first.snapshot, 0, + "snapshot must capture the value observed at 
selection time" + ); + + assert!( + candidate_first + .agent_controller + .slots_processing + .compare_and_swap(0, 1), + "simulated contender must succeed at incrementing the targeted agent before our claim" + ); + + let claim_outcome = pool.try_claim(candidate_first); + + assert!( + claim_outcome.is_err(), + "stale snapshot must produce a Contended (Err) outcome, not a successful claim" + ); + + let candidate_second = pool + .select_least_busy_with_capacity() + .ok_or_else(|| anyhow!("expected a candidate after re-selection"))?; + + assert_ne!( + candidate_second.agent_controller.id, first_pick_id, + "after a contended claim, re-selection must pick the truly-least-busy agent (the other one)" + ); + assert_eq!( + candidate_second.snapshot, 0, + "the un-contended agent's snapshot must still be 0" + ); + + let dispatched = pool + .try_claim(candidate_second) + .map_err(|_| anyhow!("fresh selection must claim the un-contended agent"))?; + + assert_ne!( + dispatched.agent_controller.id, first_pick_id, + "the dispatched agent must be the one selected after re-selection" + ); + + Ok(()) +} diff --git a/paddler_tests/tests/agent_conversation_accepts_empty_tools_list.rs b/paddler_tests/tests/agent_conversation_accepts_empty_tools_list.rs index 1b39a422..fe4ebffc 100644 --- a/paddler_tests/tests/agent_conversation_accepts_empty_tools_list.rs +++ b/paddler_tests/tests/agent_conversation_accepts_empty_tools_list.rs @@ -4,20 +4,20 @@ ))] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; use paddler_types::conversation_history::ConversationHistory; use paddler_types::conversation_message::ConversationMessage; use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_conversation_accepts_empty_tools_list() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(2, 1).await?; + let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -32,6 +32,7 @@ async fn agent_conversation_accepts_empty_tools_list() -> Result<()> { enable_thinking: true, grammar: None, max_tokens: 10, + parse_tool_calls: false, tools: vec![], }) .await?; @@ -41,7 +42,7 @@ async fn agent_conversation_accepts_empty_tools_list() -> Result<()> { let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!(token_count > 0); diff --git a/paddler_tests/tests/agent_conversation_history_respects_max_tokens.rs b/paddler_tests/tests/agent_conversation_history_respects_max_tokens.rs index e62d19d1..a77054b9 100644 --- a/paddler_tests/tests/agent_conversation_history_respects_max_tokens.rs +++ b/paddler_tests/tests/agent_conversation_history_respects_max_tokens.rs @@ -4,20 +4,20 @@ ))] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use 
paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; use paddler_types::conversation_history::ConversationHistory; use paddler_types::conversation_message::ConversationMessage; use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_conversation_history_respects_max_tokens() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(2, 1).await?; + let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -32,6 +32,7 @@ async fn agent_conversation_history_respects_max_tokens() -> Result<()> { enable_thinking: false, grammar: None, max_tokens: 20, + parse_tool_calls: false, tools: vec![], }) .await?; @@ -41,7 +42,7 @@ async fn agent_conversation_history_respects_max_tokens() -> Result<()> { let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!(token_count > 0); diff --git a/paddler_tests/tests/agent_conversation_with_function_tool_succeeds.rs b/paddler_tests/tests/agent_conversation_with_function_tool_succeeds.rs index 5507fd2d..f2cbf393 100644 --- a/paddler_tests/tests/agent_conversation_with_function_tool_succeeds.rs +++ b/paddler_tests/tests/agent_conversation_with_function_tool_succeeds.rs @@ -4,6 +4,7 @@ ))] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; @@ -12,7 +13,7 @@ use paddler_types::conversation_message::ConversationMessage; use paddler_types::conversation_message_content::ConversationMessageContent; use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::FunctionCall; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; @@ -23,7 +24,7 @@ use serde_json::Value; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_conversation_with_function_tool_succeeds() -> Result<()> { - let cluster = 
start_subprocess_cluster_with_qwen3(2, 1).await?; + let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -45,6 +46,7 @@ async fn agent_conversation_with_function_tool_succeeds() -> Result<()> { enable_thinking: true, grammar: None, max_tokens: 50, + parse_tool_calls: true, tools: vec![Tool::Function(FunctionCall { function: Function { name: "get_weather".to_owned(), diff --git a/paddler_tests/tests/agent_conversation_with_gbnf_grammar_constrains_output.rs b/paddler_tests/tests/agent_conversation_with_gbnf_grammar_constrains_output.rs index c6139515..d38d1f83 100644 --- a/paddler_tests/tests/agent_conversation_with_gbnf_grammar_constrains_output.rs +++ b/paddler_tests/tests/agent_conversation_with_gbnf_grammar_constrains_output.rs @@ -4,6 +4,7 @@ ))] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; @@ -17,7 +18,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_conversation_with_gbnf_grammar_constrains_output() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(2, 1).await?; + let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -37,6 +38,7 @@ async fn agent_conversation_with_gbnf_grammar_constrains_output() -> Result<()> root: "root".to_owned(), }), max_tokens: 10, + parse_tool_calls: false, tools: vec![], }) .await?; diff --git a/paddler_tests/tests/agent_conversation_with_json_schema_grammar_returns_valid_json.rs b/paddler_tests/tests/agent_conversation_with_json_schema_grammar_returns_valid_json.rs index 25939315..727eabcc 100644 --- a/paddler_tests/tests/agent_conversation_with_json_schema_grammar_returns_valid_json.rs +++ b/paddler_tests/tests/agent_conversation_with_json_schema_grammar_returns_valid_json.rs @@ -4,6 +4,7 @@ ))] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; @@ -17,7 +18,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_conversation_with_json_schema_grammar_returns_valid_json() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(2, 1).await?; + let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -34,6 +35,7 @@ async fn agent_conversation_with_json_schema_grammar_returns_valid_json() -> Res schema: r#"{"type": "object", "properties": {"answer": {"type": "string"}}, "required": ["answer"]}"#.to_owned(), }), max_tokens: 50, + parse_tool_calls: false, tools: vec![], }) .await?; diff --git a/paddler_tests/tests/agent_conversation_without_grammar_field_succeeds.rs 
b/paddler_tests/tests/agent_conversation_without_grammar_field_succeeds.rs index 700e7640..4bc10a38 100644 --- a/paddler_tests/tests/agent_conversation_without_grammar_field_succeeds.rs +++ b/paddler_tests/tests/agent_conversation_without_grammar_field_succeeds.rs @@ -5,13 +5,14 @@ use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; use serde_json::json; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_conversation_without_grammar_field_succeeds() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(2, 1).await?; + let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let inference_url = cluster .addresses diff --git a/paddler_tests/tests/agent_does_not_crash_on_oversized_image.rs b/paddler_tests/tests/agent_does_not_crash_on_oversized_image.rs new file mode 100644 index 00000000..eb1daf1a --- /dev/null +++ b/paddler_tests/tests/agent_does_not_crash_on_oversized_image.rs @@ -0,0 +1,98 @@ +#![cfg(feature = "tests_that_use_llms")] + +use std::fs; + +use anyhow::Context as _; +use anyhow::Result; +use base64::Engine as _; +use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_cluster_with_smolvlm2_and_n_batch::start_in_process_cluster_with_smolvlm2_and_n_batch; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::conversation_message_content_part::ConversationMessageContentPart; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::image_url::ImageUrl; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use reqwest::Client; + +fn load_fixture_as_data_uri(fixture_name: &str, mime_type: &str) -> Result<String> { + let fixture_path = format!("{}/../fixtures/{fixture_name}", env!("CARGO_MANIFEST_DIR")); + let bytes = fs::read(&fixture_path) + .with_context(|| format!("failed to read test fixture {fixture_path}"))?; + let encoded = BASE64_STANDARD.encode(&bytes); + + Ok(format!("data:{mime_type};base64,{encoded}")) +} + +async fn drive_oversized_image_fixture( + inference_client: &InferenceHttpClient, + fixture_name: &str, + mime_type: &str, +) -> Result<()> { + let image_data_uri = load_fixture_as_data_uri(fixture_name, mime_type)?; + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Parts(vec![ + ConversationMessageContentPart::ImageUrl { + image_url: ImageUrl { + url: image_data_uri, + }, + }, + ConversationMessageContentPart::Text { + text: "Describe this image.".to_owned(), + }, + ]), + role: "user".to_owned(), + }]), + enable_thinking: false, + grammar: None, + max_tokens: 20, + parse_tool_calls: false, + tools: vec![], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let saw_oversized = 
collected.token_results.iter().any(|result| { + matches!( + result.token_result, + GeneratedTokenResult::ImageExceedsBatchSize(_), + ) + }); + + assert!( + saw_oversized, + "fixture {fixture_name} must produce GeneratedTokenResult::ImageExceedsBatchSize when n_batch < image tokens; got {:?}", + collected + .token_results + .iter() + .map(|result| &result.token_result) + .collect::<Vec<_>>(), + ); + + Ok(()) +} + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn agent_does_not_crash_on_oversized_image() -> Result<()> { + let cluster = + start_in_process_cluster_with_smolvlm2_and_n_batch(AgentConfig::single(1), 32).await?; + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + drive_oversized_image_fixture(&inference_client, "sarnow.jpeg", "image/jpeg").await?; + drive_oversized_image_fixture(&inference_client, "llamas.webp", "image/webp").await?; + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/agent_embedding_batch_distribution_independent_of_context_size.rs b/paddler_tests/tests/agent_embedding_batch_distribution_independent_of_context_size.rs index d559046e..dbefb692 100644 --- a/paddler_tests/tests/agent_embedding_batch_distribution_independent_of_context_size.rs +++ b/paddler_tests/tests/agent_embedding_batch_distribution_independent_of_context_size.rs @@ -3,6 +3,7 @@ use std::collections::BTreeSet; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_embedding_results::collect_embedding_results; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_embedding_cluster::start_in_process_embedding_cluster; @@ -17,12 +18,12 @@ async fn agent_embedding_batch_distribution_independent_of_context_size() -> Result<()> { let cluster = start_in_process_embedding_cluster( InferenceParameters { - batch_n_tokens: 64, + n_batch: 64, context_size: 512, enable_embeddings: true, ..InferenceParameters::default() }, - 4, + AgentConfig::single(4), ) .await?; @@ -62,7 +63,7 @@ async fn agent_embedding_batch_distribution_independent_of_context_size() -> Res let returned_ids: BTreeSet<String> = collected .embeddings .iter() - .map(|embedding| embedding.source_document_id.clone()) + .map(|produced| produced.embedding.source_document_id.clone()) .collect(); let expected_ids: BTreeSet<String> = BTreeSet::from([ diff --git a/paddler_tests/tests/agent_embedding_batch_returns_one_embedding_per_input_document.rs b/paddler_tests/tests/agent_embedding_batch_returns_one_embedding_per_input_document.rs index 59a8e670..d09b1842 100644 --- a/paddler_tests/tests/agent_embedding_batch_returns_one_embedding_per_input_document.rs +++ b/paddler_tests/tests/agent_embedding_batch_returns_one_embedding_per_input_document.rs @@ -3,6 +3,7 @@ use std::collections::BTreeSet; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_embedding_results::collect_embedding_results; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_embedding_cluster::start_in_process_embedding_cluster; @@ -20,7 +21,7 @@ async fn agent_embedding_batch_returns_one_embedding_per_input_document() -> Res enable_embeddings: true, ..InferenceParameters::default() }, - 1, + AgentConfig::single(1), ) .await?; @@ -52,7 +53,7 @@ async fn agent_embedding_batch_returns_one_embedding_per_input_document() -> Res let returned_ids: BTreeSet<String> = 
collected .embeddings .iter() - .map(|embedding| embedding.source_document_id.clone()) + .map(|produced| produced.embedding.source_document_id.clone()) .collect(); let expected_ids: BTreeSet<String> = diff --git a/paddler_tests/tests/agent_embedding_batch_with_all_oversized_documents_reports_error.rs b/paddler_tests/tests/agent_embedding_batch_with_all_oversized_documents_reports_error.rs new file mode 100644 index 00000000..96c04184 --- /dev/null +++ b/paddler_tests/tests/agent_embedding_batch_with_all_oversized_documents_reports_error.rs @@ -0,0 +1,82 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_embedding_results::collect_embedding_results; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_embedding_cluster::start_in_process_embedding_cluster; +use paddler_types::embedding_input_document::EmbeddingInputDocument; +use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_types::inference_parameters::InferenceParameters; +use paddler_types::request_params::GenerateEmbeddingBatchParams; +use reqwest::Client; + +const N_BATCH: u32 = 64; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn agent_embedding_batch_with_all_oversized_documents_reports_error() -> Result<()> { + let cluster = start_in_process_embedding_cluster( + InferenceParameters { + n_batch: N_BATCH as usize, + context_size: 4096, + enable_embeddings: true, + ..InferenceParameters::default() + }, + AgentConfig::single(1), + ) + .await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let huge_content = "The quick brown fox jumps over the lazy dog. 
".repeat(40); + + let stream = inference_client + .post_generate_embedding_batch(&GenerateEmbeddingBatchParams { + input_batch: vec![ + EmbeddingInputDocument { + content: huge_content.clone(), + id: "huge-1".to_owned(), + }, + EmbeddingInputDocument { + content: huge_content, + id: "huge-2".to_owned(), + }, + ], + normalization_method: EmbeddingNormalizationMethod::None, + }) + .await?; + + let collected = collect_embedding_results(stream).await?; + + assert_eq!( + collected.embeddings.len(), + 0, + "no embeddings should be produced when all documents are oversized", + ); + assert_eq!( + collected.oversized_documents.len(), + 2, + "both oversized documents should be reported", + ); + assert!( + collected.saw_done, + "stream must terminate with the balancer's final Done so the client unblocks", + ); + assert_eq!( + collected.no_embeddings_produced_count, + 1, + "the agent must terminate its sub-stream with a single NoEmbeddingsProduced variant when zero embeddings are produced; got oversized_documents: {:?}, errors: {:?}", + collected + .oversized_documents + .iter() + .map(|details| &details.source_document_id) + .collect::>(), + collected.errors, + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/agent_embedding_document_exceeds_n_batch.rs b/paddler_tests/tests/agent_embedding_document_exceeds_n_batch.rs new file mode 100644 index 00000000..4c238f4a --- /dev/null +++ b/paddler_tests/tests/agent_embedding_document_exceeds_n_batch.rs @@ -0,0 +1,100 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_embedding_results::collect_embedding_results; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_embedding_cluster::start_in_process_embedding_cluster; +use paddler_types::embedding_input_document::EmbeddingInputDocument; +use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_types::inference_parameters::InferenceParameters; +use paddler_types::request_params::GenerateEmbeddingBatchParams; +use reqwest::Client; + +const N_BATCH: u32 = 64; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn agent_embedding_document_exceeds_n_batch() -> Result<()> { + let cluster = start_in_process_embedding_cluster( + InferenceParameters { + n_batch: N_BATCH as usize, + context_size: 4096, + enable_embeddings: true, + ..InferenceParameters::default() + }, + AgentConfig::single(1), + ) + .await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let huge_content = "The quick brown fox jumps over the lazy dog. 
".repeat(40); + + let stream = inference_client + .post_generate_embedding_batch(&GenerateEmbeddingBatchParams { + input_batch: vec![ + EmbeddingInputDocument { + content: "ok".to_owned(), + id: "tiny".to_owned(), + }, + EmbeddingInputDocument { + content: huge_content, + id: "huge".to_owned(), + }, + ], + normalization_method: EmbeddingNormalizationMethod::None, + }) + .await?; + + let collected = collect_embedding_results(stream).await?; + + assert!( + collected.saw_done, + "stream must terminate with Done even when one document is oversized", + ); + assert!( + collected.errors.is_empty(), + "no generic EmbeddingResult::Error events should be emitted; got {:?}", + collected.errors, + ); + + assert_eq!( + collected.oversized_documents.len(), + 1, + "exactly one DocumentExceedsBatchSize event expected; got {:?}", + collected + .oversized_documents + .iter() + .map(|details| &details.source_document_id) + .collect::>(), + ); + + let oversized = &collected.oversized_documents[0]; + + assert_eq!(oversized.source_document_id, "huge"); + assert_eq!(oversized.n_batch, N_BATCH); + assert!( + oversized.document_tokens > oversized.n_batch, + "document_tokens ({}) must exceed n_batch ({}) for the assertion to be meaningful", + oversized.document_tokens, + oversized.n_batch, + ); + + assert_eq!( + collected.embeddings.len(), + 1, + "the small document must still be embedded; got {:?}", + collected + .embeddings + .iter() + .map(|produced| &produced.embedding.source_document_id) + .collect::>(), + ); + assert_eq!(collected.embeddings[0].embedding.source_document_id, "tiny",); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/agent_embeddings_share_dimension_across_inputs_of_varying_length.rs b/paddler_tests/tests/agent_embeddings_share_dimension_across_inputs_of_varying_length.rs index 713d77a1..d9eae9d6 100644 --- a/paddler_tests/tests/agent_embeddings_share_dimension_across_inputs_of_varying_length.rs +++ b/paddler_tests/tests/agent_embeddings_share_dimension_across_inputs_of_varying_length.rs @@ -1,6 +1,7 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_embedding_results::collect_embedding_results; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_embedding_cluster::start_in_process_embedding_cluster; @@ -18,7 +19,7 @@ async fn agent_embeddings_share_dimension_across_inputs_of_varying_length() -> R enable_embeddings: true, ..InferenceParameters::default() }, - 1, + AgentConfig::single(1), ) .await?; @@ -50,17 +51,17 @@ async fn agent_embeddings_share_dimension_across_inputs_of_varying_length() -> R assert_eq!(collected.embeddings.len(), 3); assert!(collected.saw_done); - let first_dimension = collected.embeddings[0].embedding.len(); + let first_dimension = collected.embeddings[0].embedding.embedding.len(); assert!(first_dimension > 0, "embedding dimension must be positive"); - for embedding in &collected.embeddings { + for produced in &collected.embeddings { assert_eq!( - embedding.embedding.len(), + produced.embedding.embedding.len(), first_dimension, "all embeddings must share dimension; {} has {} instead of {}", - embedding.source_document_id, - embedding.embedding.len(), + produced.embedding.source_document_id, + produced.embedding.embedding.len(), first_dimension ); } diff --git a/paddler_tests/tests/agent_exits_cleanly_on_sigterm_during_multimodal_inference.rs 
b/paddler_tests/tests/agent_exits_cleanly_on_sigterm_during_multimodal_inference.rs index 41da9dac..5ad7e18c 100644 --- a/paddler_tests/tests/agent_exits_cleanly_on_sigterm_during_multimodal_inference.rs +++ b/paddler_tests/tests/agent_exits_cleanly_on_sigterm_during_multimodal_inference.rs @@ -46,7 +46,7 @@ async fn agent_exits_cleanly_on_sigterm_during_multimodal_inference() -> Result< } = qwen2_5_vl_3b_mmproj(); let mut cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { chat_template_override: None, @@ -104,6 +104,7 @@ async fn agent_exits_cleanly_on_sigterm_during_multimodal_inference() -> Result< enable_thinking: false, grammar: None, max_tokens: 200, + parse_tool_calls: false, tools: vec![], }) .await?; diff --git a/paddler_tests/tests/agent_exits_cleanly_when_killed_during_generation.rs b/paddler_tests/tests/agent_exits_cleanly_when_killed_during_generation.rs index 82b48cb2..15a341da 100644 --- a/paddler_tests/tests/agent_exits_cleanly_when_killed_during_generation.rs +++ b/paddler_tests/tests/agent_exits_cleanly_when_killed_during_generation.rs @@ -35,7 +35,7 @@ async fn agent_exits_cleanly_when_killed_during_generation() -> Result<()> { } = qwen3_0_6b(); let mut cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { chat_template_override: None, diff --git a/paddler_tests/tests/agent_grammar_with_thinking_returns_incompatible_error.rs b/paddler_tests/tests/agent_grammar_with_thinking_returns_incompatible_error.rs index 9f29a53a..649c38f6 100644 --- a/paddler_tests/tests/agent_grammar_with_thinking_returns_incompatible_error.rs +++ b/paddler_tests/tests/agent_grammar_with_thinking_returns_incompatible_error.rs @@ -4,6 +4,7 @@ ))] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; @@ -18,7 +19,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_grammar_with_thinking_returns_incompatible_error() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(2, 1).await?; + let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -35,6 +36,7 @@ async fn agent_grammar_with_thinking_returns_incompatible_error() -> Result<()> schema: r#"{"type": "object", "properties": {"answer": {"type": "string"}}, "required": ["answer"]}"#.to_owned(), }), max_tokens: 50, + parse_tool_calls: false, tools: vec![], }) .await; @@ -45,7 +47,7 @@ async fn agent_grammar_with_thinking_returns_incompatible_error() -> Result<()> if let Ok(collected) = collected { assert!( collected.token_results.iter().any(|result| matches!( - result, + result.token_result, GeneratedTokenResult::GrammarIncompatibleWithThinking(_) )), "expected GrammarIncompatibleWithThinking error" diff --git a/paddler_tests/tests/agent_isolates_concurrent_embedding_requests_per_client.rs b/paddler_tests/tests/agent_isolates_concurrent_embedding_requests_per_client.rs index 7156eb9a..fe6094d2 100644 --- 
a/paddler_tests/tests/agent_isolates_concurrent_embedding_requests_per_client.rs +++ b/paddler_tests/tests/agent_isolates_concurrent_embedding_requests_per_client.rs @@ -3,6 +3,7 @@ use std::collections::BTreeSet; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_embedding_results::collect_embedding_results; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_embedding_cluster::start_in_process_embedding_cluster; @@ -23,7 +24,7 @@ async fn agent_isolates_concurrent_embedding_requests_per_client() -> Result<()> enable_embeddings: true, ..InferenceParameters::default() }, - 4, + AgentConfig::single(4), ) .await?; @@ -71,7 +72,7 @@ async fn agent_isolates_concurrent_embedding_requests_per_client() -> Result<()> let returned_ids: BTreeSet<String> = collected .embeddings .iter() - .map(|embedding| embedding.source_document_id.clone()) + .map(|produced| produced.embedding.source_document_id.clone()) .collect(); let expected_ids: BTreeSet<String> = (0..docs_per_client) .map(|document_index| format!("client-{client_index}-doc-{document_index}")) diff --git a/paddler_tests/tests/agent_l2_normalized_embeddings_have_unit_norm.rs b/paddler_tests/tests/agent_l2_normalized_embeddings_have_unit_norm.rs index 1304674b..153ee20d 100644 --- a/paddler_tests/tests/agent_l2_normalized_embeddings_have_unit_norm.rs +++ b/paddler_tests/tests/agent_l2_normalized_embeddings_have_unit_norm.rs @@ -1,6 +1,7 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_embedding_results::collect_embedding_results; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_embedding_cluster::start_in_process_embedding_cluster; @@ -18,7 +19,7 @@ async fn agent_l2_normalized_embeddings_have_unit_norm() -> Result<()> { enable_embeddings: true, ..InferenceParameters::default() }, - 1, + AgentConfig::single(1), ) .await?; @@ -40,14 +41,15 @@ async fn agent_l2_normalized_embeddings_have_unit_norm() -> Result<()> { assert_eq!(collected.embeddings.len(), 1); assert!(collected.saw_done); - let embedding = &collected.embeddings[0]; + let produced = &collected.embeddings[0]; assert!(matches!( - embedding.normalization_method, + produced.embedding.normalization_method, EmbeddingNormalizationMethod::L2 )); - let l2_norm: f32 = embedding + let l2_norm: f32 = produced + .embedding .embedding .iter() .map(|value| value * value) diff --git a/paddler_tests/tests/agent_openai_chat_completions_non_streaming_returns_text.rs b/paddler_tests/tests/agent_openai_chat_completions_non_streaming_returns_text.rs index 32ccde2b..ece35f61 100644 --- a/paddler_tests/tests/agent_openai_chat_completions_non_streaming_returns_text.rs +++ b/paddler_tests/tests/agent_openai_chat_completions_non_streaming_returns_text.rs @@ -5,13 +5,14 @@ use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; use serde_json::json; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_openai_chat_completions_non_streaming_returns_text() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(2, 1).await?; + let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let openai_url = cluster .addresses @@ -23,8 +24,9 @@ async fn 
agent_openai_chat_completions_non_streaming_returns_text() -> Result<() .json(&json!({ "model": "test", "messages": [{"role": "user", "content": "Say hello"}], - "max_completion_tokens": 10, + "max_completion_tokens": 200, "stream": false, + "chat_template_kwargs": {"enable_thinking": false}, })) .send() .await diff --git a/paddler_tests/tests/agent_openai_chat_completions_streaming_returns_chunks.rs b/paddler_tests/tests/agent_openai_chat_completions_streaming_returns_chunks.rs index 7331ed83..a912ea54 100644 --- a/paddler_tests/tests/agent_openai_chat_completions_streaming_returns_chunks.rs +++ b/paddler_tests/tests/agent_openai_chat_completions_streaming_returns_chunks.rs @@ -5,13 +5,14 @@ use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; use serde_json::json; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_openai_chat_completions_streaming_returns_chunks() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(2, 1).await?; + let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let openai_url = cluster .addresses diff --git a/paddler_tests/tests/agent_pipeline_recognizes_duck_typed_tool_call_format_when_template_is_not_registered.rs b/paddler_tests/tests/agent_pipeline_recognizes_duck_typed_tool_call_format_when_template_is_not_registered.rs new file mode 100644 index 00000000..714d81c3 --- /dev/null +++ b/paddler_tests/tests/agent_pipeline_recognizes_duck_typed_tool_call_format_when_template_is_not_registered.rs @@ -0,0 +1,112 @@ +#![cfg(feature = "tests_that_use_llms")] + +use std::sync::Arc; + +use anyhow::Result; +use anyhow::bail; +use llama_cpp_bindings::ToolCallArguments; +use llama_cpp_bindings::llama_backend::LlamaBackend; +use llama_cpp_bindings::model::LlamaModel; +use llama_cpp_bindings::model::params::LlamaModelParams; +use paddler::tool_call_event::ToolCallEvent; +use paddler::tool_call_pipeline::ToolCallPipeline; +use paddler::tool_call_validator::ToolCallValidator; +use paddler_tests::model_card::ModelCard; +use paddler_tests::model_card::deepseek_r1_distill_llama_8b::deepseek_r1_distill_llama_8b; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; +use serde_json::Map; + +const QWEN_XML_PAYLOAD: &str = "<tool_call>\n\ +<function=get_weather>\n\ +<parameter=location>\n\ +Paris\n\ +</parameter>\n\ +</function>\n\ +</tool_call>"; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[test] +fn agent_pipeline_recognizes_duck_typed_tool_call_format_when_template_is_not_registered() +-> Result<()> { + let backend = LlamaBackend::init()?; + + let ModelCard { + gpu_layer_count, + reference, + } = deepseek_r1_distill_llama_8b(); + + let path = hf_hub::api::sync::ApiBuilder::from_env() + .build()? 
+ .model(reference.repo_id.clone()) + .get(&reference.filename)?; + + let model_params = LlamaModelParams::default().with_n_gpu_layers(gpu_layer_count); + let model = Arc::new(LlamaModel::load_from_file(&backend, &path, &model_params)?); + + let mut location_properties = Map::new(); + location_properties.insert( + "location".to_owned(), + serde_json::json!({"type": "string", "description": "The city name"}), + ); + let tools = vec![Tool::Function(FunctionCall { + function: Function { + name: "get_weather".to_owned(), + description: "Get the current weather for a location".to_owned(), + parameters: Parameters::Schema(ValidatedParametersSchema { + schema_type: "object".to_owned(), + properties: Some(location_properties), + required: Some(vec!["location".to_owned()]), + additional_properties: Some(serde_json::Value::Bool(false)), + }), + }, + })]; + + let validator = ToolCallValidator::from_tools(&tools)?; + let tools_json: Vec<serde_json::Value> = tools + .iter() + .map(serde_json::to_value) + .collect::<Result<_, _>>()?; + let mut pipeline = ToolCallPipeline::new(model, &tools_json, validator)?; + + pipeline.feed(QWEN_XML_PAYLOAD); + let event = pipeline.finalize(); + + let ToolCallEvent::Resolved(parsed_calls) = event else { + bail!( + "duck-type pass must recover Qwen XML on a model with no registered template; \ + expected ToolCallEvent::Resolved, got {event:?}" + ); + }; + assert_eq!( + parsed_calls.len(), + 1, + "expected exactly one parsed tool call; got {parsed_calls:?}" + ); + assert_eq!(parsed_calls[0].name, "get_weather"); + + let mapped = ToolCallEvent::Resolved(parsed_calls) + .into_generated_token_result() + .ok_or_else(|| anyhow::anyhow!("Resolved must produce a GeneratedTokenResult variant"))?; + let GeneratedTokenResult::ToolCallParsed(wire_calls) = mapped else { + bail!("expected GeneratedTokenResult::ToolCallParsed after mapping"); + }; + assert_eq!(wire_calls.len(), 1); + assert_eq!(wire_calls[0].name, "get_weather"); + let location = match &wire_calls[0].arguments { + ToolCallArguments::ValidJson(value) => value + .get("location") + .and_then(|v| v.as_str()) + .map(str::to_owned), + ToolCallArguments::InvalidJson(raw) => { + bail!("expected ValidJson, got InvalidJson: {raw}"); + } + }; + assert_eq!(location.as_deref(), Some("Paris")); + + Ok(()) +} diff --git a/paddler_tests/tests/agent_raw_prompt_respects_max_tokens.rs b/paddler_tests/tests/agent_raw_prompt_respects_max_tokens.rs index ec1922a4..288ccaa8 100644 --- a/paddler_tests/tests/agent_raw_prompt_respects_max_tokens.rs +++ b/paddler_tests/tests/agent_raw_prompt_respects_max_tokens.rs @@ -4,17 +4,17 @@ ))] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; -use paddler_types::generated_token_result::GeneratedTokenResult; use paddler_types::request_params::ContinueFromRawPromptParams; use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_raw_prompt_respects_max_tokens() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(2, 1).await?; + let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -32,7 +32,7 @@ async fn agent_raw_prompt_respects_max_tokens() -> 
Result<()> { let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!(token_count > 0); diff --git a/paddler_tests/tests/agent_raw_prompt_with_gbnf_grammar_constrains_output.rs b/paddler_tests/tests/agent_raw_prompt_with_gbnf_grammar_constrains_output.rs index 6bef4a8a..696a0db0 100644 --- a/paddler_tests/tests/agent_raw_prompt_with_gbnf_grammar_constrains_output.rs +++ b/paddler_tests/tests/agent_raw_prompt_with_gbnf_grammar_constrains_output.rs @@ -4,6 +4,7 @@ ))] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; @@ -14,7 +15,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_raw_prompt_with_gbnf_grammar_constrains_output() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(2, 1).await?; + let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); diff --git a/paddler_tests/tests/agent_raw_prompt_without_grammar_field_succeeds.rs b/paddler_tests/tests/agent_raw_prompt_without_grammar_field_succeeds.rs index 11fa358c..44611f21 100644 --- a/paddler_tests/tests/agent_raw_prompt_without_grammar_field_succeeds.rs +++ b/paddler_tests/tests/agent_raw_prompt_without_grammar_field_succeeds.rs @@ -5,13 +5,14 @@ use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; use serde_json::json; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_raw_prompt_without_grammar_field_succeeds() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(2, 1).await?; + let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let inference_url = cluster .addresses diff --git a/paddler_tests/tests/agent_rejects_tool_with_invalid_required_field_in_schema.rs b/paddler_tests/tests/agent_rejects_tool_with_invalid_required_field_in_schema.rs index 1b6ca8fe..7b4b6835 100644 --- a/paddler_tests/tests/agent_rejects_tool_with_invalid_required_field_in_schema.rs +++ b/paddler_tests/tests/agent_rejects_tool_with_invalid_required_field_in_schema.rs @@ -4,6 +4,7 @@ ))] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; use paddler_types::conversation_history::ConversationHistory; @@ -11,7 +12,7 @@ use paddler_types::conversation_message::ConversationMessage; use paddler_types::conversation_message_content::ConversationMessageContent; use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::FunctionCall; +use 
paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; @@ -21,7 +22,7 @@ use serde_json::Map; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_rejects_tool_with_invalid_required_field_in_schema() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(2, 1).await?; + let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -40,6 +41,7 @@ async fn agent_rejects_tool_with_invalid_required_field_in_schema() -> Result<() enable_thinking: true, grammar: None, max_tokens: 10, + parse_tool_calls: true, tools: vec![Tool::Function(FunctionCall { function: Function { name: "test_fn".to_owned(), diff --git a/paddler_tests/tests/agent_releases_slot_when_websocket_client_disconnects.rs b/paddler_tests/tests/agent_releases_slot_when_websocket_client_disconnects.rs index 914ad4a9..5cb088f0 100644 --- a/paddler_tests/tests/agent_releases_slot_when_websocket_client_disconnects.rs +++ b/paddler_tests/tests/agent_releases_slot_when_websocket_client_disconnects.rs @@ -6,6 +6,7 @@ use anyhow::Context as _; use anyhow::Result; use futures_util::StreamExt as _; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::agents_status::assert_slots_processing::assert_slots_processing; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; @@ -15,7 +16,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_releases_slot_when_websocket_client_disconnects() -> Result<()> { - let mut cluster = start_subprocess_cluster_with_qwen3(1, 1).await?; + let mut cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 1)).await?; let agent_id = cluster .agent_ids diff --git a/paddler_tests/tests/agent_reports_slot_cannot_start_for_excessive_slots_in_process.rs b/paddler_tests/tests/agent_reports_slot_cannot_start_for_excessive_slots_in_process.rs index c8295928..6ea643af 100644 --- a/paddler_tests/tests/agent_reports_slot_cannot_start_for_excessive_slots_in_process.rs +++ b/paddler_tests/tests/agent_reports_slot_cannot_start_for_excessive_slots_in_process.rs @@ -7,6 +7,7 @@ use std::time::Duration; use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::current_test_device::current_test_device; use paddler_tests::in_process_cluster_params::InProcessClusterParams; use paddler_tests::model_card::ModelCard; @@ -31,8 +32,10 @@ async fn agent_reports_slot_cannot_start_for_excessive_slots_in_process() -> Res let inference_parameters = device.inference_parameters_for_full_offload(gpu_layer_count); let mut cluster = start_in_process_cluster(InProcessClusterParams { - spawn_agent: true, - slots_per_agent: 257, + agent: 
Some(AgentConfig { + name: "test-agent".to_owned(), + slot_count: 257, + }), desired_state: BalancerDesiredState { chat_template_override: None, inference_parameters, diff --git a/paddler_tests/tests/agent_reports_slot_cannot_start_for_excessive_slots_subprocess.rs b/paddler_tests/tests/agent_reports_slot_cannot_start_for_excessive_slots_subprocess.rs index ec17177c..5e5cd051 100644 --- a/paddler_tests/tests/agent_reports_slot_cannot_start_for_excessive_slots_subprocess.rs +++ b/paddler_tests/tests/agent_reports_slot_cannot_start_for_excessive_slots_subprocess.rs @@ -7,6 +7,7 @@ use std::time::Duration; use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::current_test_device::current_test_device; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; @@ -31,8 +32,7 @@ async fn agent_reports_slot_cannot_start_for_excessive_slots_subprocess() -> Res let inference_parameters = device.inference_parameters_for_full_offload(gpu_layer_count); let mut cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 1, - slots_per_agent: 257, + agents: AgentConfig::uniform(1, 257), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { chat_template_override: None, diff --git a/paddler_tests/tests/agent_reports_slot_cannot_start_for_metal_quantized_distinct_kv_in_process.rs b/paddler_tests/tests/agent_reports_slot_cannot_start_for_metal_quantized_distinct_kv_in_process.rs index eba1e05c..9289c53d 100644 --- a/paddler_tests/tests/agent_reports_slot_cannot_start_for_metal_quantized_distinct_kv_in_process.rs +++ b/paddler_tests/tests/agent_reports_slot_cannot_start_for_metal_quantized_distinct_kv_in_process.rs @@ -8,6 +8,7 @@ use std::time::Duration; use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::current_test_device::current_test_device; use paddler_tests::in_process_cluster_params::InProcessClusterParams; use paddler_tests::model_card::ModelCard; @@ -42,8 +43,10 @@ async fn agent_reports_slot_cannot_start_for_metal_quantized_distinct_kv_in_proc inference_parameters.v_cache_dtype = KvCacheDtype::Q4_0; let mut cluster = start_in_process_cluster(InProcessClusterParams { - spawn_agent: true, - slots_per_agent: 1, + agent: Some(AgentConfig { + name: "test-agent".to_owned(), + slot_count: 1, + }), desired_state: BalancerDesiredState { chat_template_override: None, inference_parameters, diff --git a/paddler_tests/tests/agent_reports_slot_cannot_start_for_metal_quantized_distinct_kv_subprocess.rs b/paddler_tests/tests/agent_reports_slot_cannot_start_for_metal_quantized_distinct_kv_subprocess.rs index 9f9d1481..cbe534ab 100644 --- a/paddler_tests/tests/agent_reports_slot_cannot_start_for_metal_quantized_distinct_kv_subprocess.rs +++ b/paddler_tests/tests/agent_reports_slot_cannot_start_for_metal_quantized_distinct_kv_subprocess.rs @@ -8,6 +8,7 @@ use std::time::Duration; use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::current_test_device::current_test_device; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; @@ -42,8 +43,7 @@ async fn agent_reports_slot_cannot_start_for_metal_quantized_distinct_kv_subproc inference_parameters.v_cache_dtype = KvCacheDtype::Q4_0; let mut cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 1, - slots_per_agent: 1, + agents: AgentConfig::uniform(1, 1), 
wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { chat_template_override: None, diff --git a/paddler_tests/tests/agent_returns_identical_embeddings_for_identical_documents.rs b/paddler_tests/tests/agent_returns_identical_embeddings_for_identical_documents.rs index 7f3dc1b8..cc61d8b2 100644 --- a/paddler_tests/tests/agent_returns_identical_embeddings_for_identical_documents.rs +++ b/paddler_tests/tests/agent_returns_identical_embeddings_for_identical_documents.rs @@ -2,6 +2,7 @@ use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_embedding_results::collect_embedding_results; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_embedding_cluster::start_in_process_embedding_cluster; @@ -19,7 +20,7 @@ async fn agent_returns_identical_embeddings_for_identical_documents() -> Result< enable_embeddings: true, ..InferenceParameters::default() }, - 1, + AgentConfig::single(1), ) .await?; @@ -52,17 +53,17 @@ async fn agent_returns_identical_embeddings_for_identical_documents() -> Result< let first = collected .embeddings .iter() - .find(|embedding| embedding.source_document_id == "doc-first") + .find(|produced| produced.embedding.source_document_id == "doc-first") .context("first embedding missing")?; let second = collected .embeddings .iter() - .find(|embedding| embedding.source_document_id == "doc-second") + .find(|produced| produced.embedding.source_document_id == "doc-second") .context("second embedding missing")?; assert_eq!( - first.embedding, second.embedding, + first.embedding.embedding, second.embedding.embedding, "identical documents must produce identical embedding vectors" ); diff --git a/paddler_tests/tests/agent_returns_image_decoding_error_for_invalid_base64.rs b/paddler_tests/tests/agent_returns_image_decoding_error_for_invalid_base64.rs index 2d7121b3..e1e14b7e 100644 --- a/paddler_tests/tests/agent_returns_image_decoding_error_for_invalid_base64.rs +++ b/paddler_tests/tests/agent_returns_image_decoding_error_for_invalid_base64.rs @@ -4,6 +4,7 @@ ))] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; @@ -19,7 +20,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_returns_image_decoding_error_for_invalid_base64() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(2, 1).await?; + let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -43,6 +44,7 @@ async fn agent_returns_image_decoding_error_for_invalid_base64() -> Result<()> { enable_thinking: false, grammar: None, max_tokens: 20, + parse_tool_calls: false, tools: vec![], }) .await; @@ -51,10 +53,12 @@ async fn agent_returns_image_decoding_error_for_invalid_base64() -> Result<()> { let collected = collect_generated_tokens(stream).await; if let Ok(collected) = collected { - let saw_decoding_error = collected - .token_results - .iter() - .any(|result| matches!(result, GeneratedTokenResult::ImageDecodingFailed(_))); + let saw_decoding_error = collected.token_results.iter().any(|result| { + matches!( + 
result.token_result, + GeneratedTokenResult::ImageDecodingFailed(_) + ) + }); assert!( saw_decoding_error, diff --git a/paddler_tests/tests/agent_returns_image_decoding_error_for_malformed_data_uri.rs b/paddler_tests/tests/agent_returns_image_decoding_error_for_malformed_data_uri.rs index ffffb651..55a4a84b 100644 --- a/paddler_tests/tests/agent_returns_image_decoding_error_for_malformed_data_uri.rs +++ b/paddler_tests/tests/agent_returns_image_decoding_error_for_malformed_data_uri.rs @@ -4,6 +4,7 @@ ))] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; @@ -19,7 +20,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_returns_image_decoding_error_for_malformed_data_uri() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(2, 1).await?; + let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -43,6 +44,7 @@ async fn agent_returns_image_decoding_error_for_malformed_data_uri() -> Result<( enable_thinking: false, grammar: None, max_tokens: 20, + parse_tool_calls: false, tools: vec![], }) .await; @@ -51,10 +53,12 @@ async fn agent_returns_image_decoding_error_for_malformed_data_uri() -> Result<( let collected = collect_generated_tokens(stream).await; if let Ok(collected) = collected { - let saw_decoding_error = collected - .token_results - .iter() - .any(|result| matches!(result, GeneratedTokenResult::ImageDecodingFailed(_))); + let saw_decoding_error = collected.token_results.iter().any(|result| { + matches!( + result.token_result, + GeneratedTokenResult::ImageDecodingFailed(_) + ) + }); assert!( saw_decoding_error, diff --git a/paddler_tests/tests/agent_returns_image_decoding_error_for_remote_url.rs b/paddler_tests/tests/agent_returns_image_decoding_error_for_remote_url.rs index 2dc20b59..8ba81670 100644 --- a/paddler_tests/tests/agent_returns_image_decoding_error_for_remote_url.rs +++ b/paddler_tests/tests/agent_returns_image_decoding_error_for_remote_url.rs @@ -4,6 +4,7 @@ ))] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; @@ -19,7 +20,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_returns_image_decoding_error_for_remote_url() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(2, 1).await?; + let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -43,6 +44,7 @@ async fn agent_returns_image_decoding_error_for_remote_url() -> Result<()> { enable_thinking: false, grammar: None, max_tokens: 20, + parse_tool_calls: false, tools: vec![], }) .await; @@ -51,10 +53,12 @@ async fn agent_returns_image_decoding_error_for_remote_url() -> Result<()> { let collected = collect_generated_tokens(stream).await; 
if let Ok(collected) = collected { - let saw_decoding_error = collected - .token_results - .iter() - .any(|result| matches!(result, GeneratedTokenResult::ImageDecodingFailed(_))); + let saw_decoding_error = collected.token_results.iter().any(|result| { + matches!( + result.token_result, + GeneratedTokenResult::ImageDecodingFailed(_) + ) + }); assert!( saw_decoding_error, diff --git a/paddler_tests/tests/agent_returns_rms_normalized_embeddings_when_requested.rs b/paddler_tests/tests/agent_returns_rms_normalized_embeddings_when_requested.rs index 7d090b4f..31d0b471 100644 --- a/paddler_tests/tests/agent_returns_rms_normalized_embeddings_when_requested.rs +++ b/paddler_tests/tests/agent_returns_rms_normalized_embeddings_when_requested.rs @@ -1,6 +1,7 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_embedding_results::collect_embedding_results; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_embedding_cluster::start_in_process_embedding_cluster; @@ -18,7 +19,7 @@ async fn agent_returns_rms_normalized_embeddings_when_requested() -> Result<()> enable_embeddings: true, ..InferenceParameters::default() }, - 1, + AgentConfig::single(1), ) .await?; @@ -40,7 +41,7 @@ async fn agent_returns_rms_normalized_embeddings_when_requested() -> Result<()> assert_eq!(collected.embeddings.len(), 1); assert!(collected.saw_done); assert!(matches!( - collected.embeddings[0].normalization_method, + collected.embeddings[0].embedding.normalization_method, EmbeddingNormalizationMethod::RmsNorm { .. } )); diff --git a/paddler_tests/tests/agent_returns_unnormalized_embeddings_when_requested.rs b/paddler_tests/tests/agent_returns_unnormalized_embeddings_when_requested.rs index 5dad152f..45f68685 100644 --- a/paddler_tests/tests/agent_returns_unnormalized_embeddings_when_requested.rs +++ b/paddler_tests/tests/agent_returns_unnormalized_embeddings_when_requested.rs @@ -1,6 +1,7 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_embedding_results::collect_embedding_results; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_embedding_cluster::start_in_process_embedding_cluster; @@ -18,7 +19,7 @@ async fn agent_returns_unnormalized_embeddings_when_requested() -> Result<()> { enable_embeddings: true, ..InferenceParameters::default() }, - 1, + AgentConfig::single(1), ) .await?; @@ -40,7 +41,7 @@ async fn agent_returns_unnormalized_embeddings_when_requested() -> Result<()> { assert_eq!(collected.embeddings.len(), 1); assert!(collected.saw_done); assert!(matches!( - collected.embeddings[0].normalization_method, + collected.embeddings[0].embedding.normalization_method, EmbeddingNormalizationMethod::None )); diff --git a/paddler_tests/tests/agent_serves_four_concurrent_clients_streaming_tokens.rs b/paddler_tests/tests/agent_serves_four_concurrent_clients_streaming_tokens.rs index 99f0069e..7a3a2bd9 100644 --- a/paddler_tests/tests/agent_serves_four_concurrent_clients_streaming_tokens.rs +++ b/paddler_tests/tests/agent_serves_four_concurrent_clients_streaming_tokens.rs @@ -4,17 +4,17 @@ ))] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use 
paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; -use paddler_types::generated_token_result::GeneratedTokenResult; use paddler_types::request_params::ContinueFromRawPromptParams; use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_serves_four_concurrent_clients_streaming_tokens() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(4, 1).await?; + let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 4)).await?; let inference_base_url = cluster.addresses.inference_base_url()?; @@ -46,7 +46,7 @@ async fn agent_serves_four_concurrent_clients_streaming_tokens() -> Result<()> { let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!(token_count > 0); diff --git a/paddler_tests/tests/agent_streams_tokens_from_conversation_history_over_http.rs b/paddler_tests/tests/agent_streams_tokens_from_conversation_history_over_http.rs index 34ddf0d3..fdb10902 100644 --- a/paddler_tests/tests/agent_streams_tokens_from_conversation_history_over_http.rs +++ b/paddler_tests/tests/agent_streams_tokens_from_conversation_history_over_http.rs @@ -4,20 +4,20 @@ ))] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; use paddler_types::conversation_history::ConversationHistory; use paddler_types::conversation_message::ConversationMessage; use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_streams_tokens_from_conversation_history_over_http() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(2, 1).await?; + let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -32,6 +32,7 @@ async fn agent_streams_tokens_from_conversation_history_over_http() -> Result<() enable_thinking: true, grammar: None, max_tokens: 50, + parse_tool_calls: false, tools: vec![], }) .await?; @@ -41,7 +42,7 @@ async fn agent_streams_tokens_from_conversation_history_over_http() -> Result<() let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!(token_count > 0); diff --git a/paddler_tests/tests/agent_streams_tokens_from_image_data_uri.rs b/paddler_tests/tests/agent_streams_tokens_from_image_data_uri.rs index 3a129f8e..d9074177 100644 --- a/paddler_tests/tests/agent_streams_tokens_from_image_data_uri.rs +++ b/paddler_tests/tests/agent_streams_tokens_from_image_data_uri.rs @@ -4,6 +4,7 @@ ))] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use 
paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; @@ -12,7 +13,6 @@ use paddler_types::conversation_history::ConversationHistory; use paddler_types::conversation_message::ConversationMessage; use paddler_types::conversation_message_content::ConversationMessageContent; use paddler_types::conversation_message_content_part::ConversationMessageContentPart; -use paddler_types::generated_token_result::GeneratedTokenResult; use paddler_types::image_url::ImageUrl; use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; use reqwest::Client; @@ -20,7 +20,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_streams_tokens_from_image_data_uri() -> Result<()> { - let cluster = start_subprocess_cluster_with_smolvlm2(4, 1).await?; + let cluster = start_subprocess_cluster_with_smolvlm2(AgentConfig::uniform(1, 4)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -46,6 +46,7 @@ async fn agent_streams_tokens_from_image_data_uri() -> Result<()> { enable_thinking: true, grammar: None, max_tokens: 100, + parse_tool_calls: false, tools: vec![], }) .await?; @@ -55,7 +56,7 @@ async fn agent_streams_tokens_from_image_data_uri() -> Result<()> { let received_tokens = collected .token_results .iter() - .any(|result| matches!(result, GeneratedTokenResult::Token(_))); + .any(|result| result.token_result.is_token()); assert!(received_tokens); diff --git a/paddler_tests/tests/agent_streams_tokens_from_raw_prompt.rs b/paddler_tests/tests/agent_streams_tokens_from_raw_prompt.rs index 531b8179..e3baa1f2 100644 --- a/paddler_tests/tests/agent_streams_tokens_from_raw_prompt.rs +++ b/paddler_tests/tests/agent_streams_tokens_from_raw_prompt.rs @@ -4,17 +4,17 @@ ))] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; -use paddler_types::generated_token_result::GeneratedTokenResult; use paddler_types::request_params::ContinueFromRawPromptParams; use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_streams_tokens_from_raw_prompt() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(2, 1).await?; + let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -32,7 +32,7 @@ async fn agent_streams_tokens_from_raw_prompt() -> Result<()> { let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!(token_count > 0); diff --git a/paddler_tests/tests/agent_text_only_model_rejects_image_input.rs b/paddler_tests/tests/agent_text_only_model_rejects_image_input.rs index bf0f3f4b..7bf0facc 100644 --- a/paddler_tests/tests/agent_text_only_model_rejects_image_input.rs +++ b/paddler_tests/tests/agent_text_only_model_rejects_image_input.rs @@ -4,6 +4,7 @@ ))] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use 
paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; @@ -20,7 +21,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_text_only_model_rejects_image_input() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(2, 1).await?; + let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -46,6 +47,7 @@ async fn agent_text_only_model_rejects_image_input() -> Result<()> { enable_thinking: false, grammar: None, max_tokens: 20, + parse_tool_calls: false, tools: vec![], }) .await; @@ -56,7 +58,7 @@ async fn agent_text_only_model_rejects_image_input() -> Result<()> { if let Ok(collected) = collected { let saw_rejection = collected.token_results.iter().any(|result| { matches!( - result, + result.token_result, GeneratedTokenResult::ChatTemplateError(_) | GeneratedTokenResult::MultimodalNotSupported(_) ) diff --git a/paddler_tests/tests/balancer_closes_management_websocket_on_sigterm.rs b/paddler_tests/tests/balancer_closes_management_websocket_on_sigterm.rs index 588fff33..7ed99c42 100644 --- a/paddler_tests/tests/balancer_closes_management_websocket_on_sigterm.rs +++ b/paddler_tests/tests/balancer_closes_management_websocket_on_sigterm.rs @@ -12,7 +12,7 @@ use tokio_tungstenite::tungstenite::protocol::Message; #[tokio::test(flavor = "multi_thread")] async fn balancer_closes_management_websocket_on_sigterm() -> Result<()> { let cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), wait_for_slots_ready: false, ..SubprocessClusterParams::default() }) diff --git a/paddler_tests/tests/balancer_completes_buffered_request_after_agent_joins.rs b/paddler_tests/tests/balancer_completes_buffered_request_after_agent_joins.rs index 30cc6423..8dca6351 100644 --- a/paddler_tests/tests/balancer_completes_buffered_request_after_agent_joins.rs +++ b/paddler_tests/tests/balancer_completes_buffered_request_after_agent_joins.rs @@ -36,7 +36,7 @@ async fn balancer_completes_buffered_request_after_agent_joins() -> Result<()> { } = qwen3_0_6b(); let mut cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), wait_for_slots_ready: false, buffered_request_timeout: Duration::from_secs(120), max_buffered_requests: 1, diff --git a/paddler_tests/tests/balancer_completes_in_flight_inference_during_model_switch.rs b/paddler_tests/tests/balancer_completes_in_flight_inference_during_model_switch.rs index dee67c7b..3cee6e2b 100644 --- a/paddler_tests/tests/balancer_completes_in_flight_inference_during_model_switch.rs +++ b/paddler_tests/tests/balancer_completes_in_flight_inference_during_model_switch.rs @@ -4,12 +4,17 @@ ))] use anyhow::Result; +use anyhow::anyhow; +use futures_util::StreamExt as _; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; use paddler_types::agent_desired_model::AgentDesiredModel; use paddler_types::balancer_desired_state::BalancerDesiredState; use paddler_types::grammar_constraint::GrammarConstraint; 
+use paddler_types::inference_client::Message as InferenceMessage; +use paddler_types::inference_client::Response as InferenceResponse; use paddler_types::inference_parameters::InferenceParameters; use paddler_types::request_params::ContinueFromRawPromptParams; use reqwest::Client; @@ -17,14 +22,14 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn balancer_completes_in_flight_inference_during_model_switch() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(1, 1).await?; + let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 1)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); let expected_output = "the quick brown fox jumps over the lazy dog"; - let stream = inference_client + let mut stream = inference_client .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { grammar: Some(GrammarConstraint::Gbnf { grammar: format!("root ::= \"{expected_output}\""), @@ -35,7 +40,29 @@ async fn balancer_completes_in_flight_inference_during_model_switch() -> Result< }) .await?; - // Trigger model switch to a nonexistent path while the request is in flight + // Wait for the first generated-token message before triggering the model + // switch. This guarantees the agent has acquired its inference slot and + // entered the generating phase, so the agent's `drain_in_flight_requests` + // correctly waits for the in-flight request to finish before tearing + // down the arbiter. Without this wait, the model-switch can race the + // request through the scheduler queue: drain sees zero slots in use, + // returns immediately, the arbiter is shut down, and the queued request + // times out with no scheduler to process it. 
+ let mut buffered_text = String::new(); + loop { + let next = stream + .next() + .await + .ok_or_else(|| anyhow!("inference stream ended before producing any token"))??; + if let InferenceMessage::Response(envelope) = next + && let InferenceResponse::GeneratedToken(token_result) = envelope.response + && let Some(token_text) = token_result.token_text() + { + buffered_text.push_str(token_text); + break; + } + } + let switch_state = BalancerDesiredState { chat_template_override: None, inference_parameters: InferenceParameters::default(), @@ -53,8 +80,11 @@ async fn balancer_completes_in_flight_inference_during_model_switch() -> Result< let collected = collect_generated_tokens(stream).await?; + let mut full_text = buffered_text; + full_text.push_str(&collected.text); + assert_eq!( - collected.text, expected_output, + full_text, expected_output, "grammar-constrained output must complete despite concurrent model switch" ); diff --git a/paddler_tests/tests/balancer_distributes_buffered_requests_across_two_agents.rs b/paddler_tests/tests/balancer_distributes_buffered_requests_across_two_agents.rs index 8b887155..baccd731 100644 --- a/paddler_tests/tests/balancer_distributes_buffered_requests_across_two_agents.rs +++ b/paddler_tests/tests/balancer_distributes_buffered_requests_across_two_agents.rs @@ -7,6 +7,7 @@ use std::time::Duration; use anyhow::Result; use futures_util::StreamExt as _; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::current_test_device::current_test_device; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::model_card::ModelCard; @@ -32,12 +33,19 @@ async fn balancer_distributes_buffered_requests_across_two_agents() -> Result<() } = qwen3_0_6b(); let cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 2, - slots_per_agent: 2, + agents: vec![ + AgentConfig { + name: "distributed-agent-0".to_owned(), + slot_count: 2, + }, + AgentConfig { + name: "distributed-agent-1".to_owned(), + slot_count: 2, + }, + ], wait_for_slots_ready: true, buffered_request_timeout: Duration::from_secs(120), max_buffered_requests: 10, - agent_name_prefix: "distributed-agent".to_owned(), desired_state: Some(BalancerDesiredState { chat_template_override: None, inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), diff --git a/paddler_tests/tests/balancer_distributes_embedding_batch_across_agents.rs b/paddler_tests/tests/balancer_distributes_embedding_batch_across_agents.rs index 802f9b2d..9ccf57c4 100644 --- a/paddler_tests/tests/balancer_distributes_embedding_batch_across_agents.rs +++ b/paddler_tests/tests/balancer_distributes_embedding_batch_across_agents.rs @@ -6,8 +6,10 @@ use std::collections::BTreeSet; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_embedding_results::collect_embedding_results; use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; use paddler_tests::start_subprocess_cluster_with_qwen3_embedding::start_subprocess_cluster_with_qwen3_embedding; use paddler_types::embedding_input_document::EmbeddingInputDocument; use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; @@ -18,14 +20,14 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn balancer_distributes_embedding_batch_across_agents() -> Result<()> { - let mut cluster = 
start_subprocess_cluster_with_qwen3_embedding( - InferenceParameters { + let cluster = start_subprocess_cluster_with_qwen3_embedding(Qwen3EmbeddingClusterParams { + agents: AgentConfig::uniform(2, 4), + inference_parameters: InferenceParameters { enable_embeddings: true, ..InferenceParameters::default() }, - 4, - 2, - ) + ..Qwen3EmbeddingClusterParams::default() + }) .await?; let inference_client = @@ -43,44 +45,24 @@ async fn balancer_distributes_embedding_batch_across_agents() -> Result<()> { normalization_method: EmbeddingNormalizationMethod::None, }; - let mut seen_busy_agents: BTreeSet<String> = BTreeSet::new(); - let mut have_seen_any_activity = false; - - let request_future = async { - let stream = inference_client - .post_generate_embedding_batch(&params) - .await?; - collect_embedding_results(stream).await - }; - - let observation_future = cluster.agents.until(|snapshot| { - let any_busy_now = snapshot - .agents - .iter() - .any(|agent| agent.slots_processing > 0); - - if any_busy_now { - have_seen_any_activity = true; - for agent in &snapshot.agents { - if agent.slots_processing > 0 { - seen_busy_agents.insert(agent.id.clone()); - } - } - } - - seen_busy_agents.len() >= 2 || (have_seen_any_activity && !any_busy_now) - }); - - let (request_result, observation_result) = tokio::join!(request_future, observation_future); - let collected = request_result?; - observation_result?; + let stream = inference_client + .post_generate_embedding_batch(&params) + .await?; + let collected = collect_embedding_results(stream).await?; assert_eq!(collected.embeddings.len(), 12); assert!(collected.saw_done); assert!(collected.errors.is_empty()); + + let producers: BTreeSet<&str> = collected + .embeddings + .iter() + .filter_map(|produced| produced.generated_by.as_deref()) + .collect(); + assert!( - seen_busy_agents.len() >= 2, - "expected the embedding batch to be distributed across at least two agents, but only saw activity on: {seen_busy_agents:?}" + producers.len() >= 2, + "expected the embedding batch to be distributed across at least two agents, but only saw producers: {producers:?}" ); cluster.shutdown().await?; diff --git a/paddler_tests/tests/balancer_distributes_embedding_batch_across_agents_with_uneven_slots.rs b/paddler_tests/tests/balancer_distributes_embedding_batch_across_agents_with_uneven_slots.rs new file mode 100644 index 00000000..a00c5ae6 --- /dev/null +++ b/paddler_tests/tests/balancer_distributes_embedding_batch_across_agents_with_uneven_slots.rs @@ -0,0 +1,97 @@ +#![cfg(all( + feature = "tests_that_use_compiled_paddler", + feature = "tests_that_use_llms" +))] + +use std::collections::BTreeSet; + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_embedding_results::collect_embedding_results; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; +use paddler_tests::start_subprocess_cluster_with_qwen3_embedding::start_subprocess_cluster_with_qwen3_embedding; +use paddler_types::embedding_input_document::EmbeddingInputDocument; +use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_types::inference_parameters::InferenceParameters; +use paddler_types::request_params::GenerateEmbeddingBatchParams; +use reqwest::Client; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn balancer_distributes_embedding_batch_across_agents_with_uneven_slots() -> 
Result<()> { + let cluster = start_subprocess_cluster_with_qwen3_embedding(Qwen3EmbeddingClusterParams { + agents: vec![ + AgentConfig { + name: "agent-fat".to_owned(), + slot_count: 4, + }, + AgentConfig { + name: "agent-thin-a".to_owned(), + slot_count: 1, + }, + AgentConfig { + name: "agent-medium".to_owned(), + slot_count: 2, + }, + AgentConfig { + name: "agent-thin-b".to_owned(), + slot_count: 1, + }, + ], + inference_parameters: InferenceParameters { + enable_embeddings: true, + ..InferenceParameters::default() + }, + ..Qwen3EmbeddingClusterParams::default() + }) + .await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let input_batch: Vec<EmbeddingInputDocument> = (0..8) + .map(|index| EmbeddingInputDocument { + content: format!("Uneven-slot document number {index}."), + id: format!("doc-{index}"), + }) + .collect(); + + let stream = inference_client + .post_generate_embedding_batch(&GenerateEmbeddingBatchParams { + input_batch, + normalization_method: EmbeddingNormalizationMethod::None, + }) + .await?; + + let collected = collect_embedding_results(stream).await?; + + assert_eq!(collected.embeddings.len(), 8); + assert!(collected.saw_done); + assert!(collected.errors.is_empty()); + + let returned_document_ids: BTreeSet<String> = collected + .embeddings + .iter() + .map(|produced| produced.embedding.source_document_id.clone()) + .collect(); + let expected_document_ids: BTreeSet<String> = + (0..8).map(|index| format!("doc-{index}")).collect(); + assert_eq!(returned_document_ids, expected_document_ids); + + let producers: BTreeSet<&str> = collected + .embeddings + .iter() + .filter_map(|produced| produced.generated_by.as_deref()) + .collect(); + + assert_eq!( + producers.len(), + 4, + "embedding batch must fan out across all agents even when slot counts are uneven, but only saw producers: {producers:?}", + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/balancer_distributes_embedding_burst_evenly_across_agents.rs b/paddler_tests/tests/balancer_distributes_embedding_burst_evenly_across_agents.rs new file mode 100644 index 00000000..9e114ab5 --- /dev/null +++ b/paddler_tests/tests/balancer_distributes_embedding_burst_evenly_across_agents.rs @@ -0,0 +1,91 @@ +#![cfg(all( + feature = "tests_that_use_compiled_paddler", + feature = "tests_that_use_llms" +))] + +use std::collections::BTreeSet; + +use std::time::Duration; + +use anyhow::Result; +use futures_util::future; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_embedding_results::collect_embedding_results; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; +use paddler_tests::start_subprocess_cluster_with_qwen3_embedding::start_subprocess_cluster_with_qwen3_embedding; +use paddler_types::embedding_input_document::EmbeddingInputDocument; +use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_types::inference_parameters::InferenceParameters; +use paddler_types::request_params::GenerateEmbeddingBatchParams; +use reqwest::Client; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn balancer_distributes_embedding_burst_evenly_across_agents() -> Result<()> { + const AGENT_COUNT: usize = 4; + const SLOTS_PER_AGENT: i32 = 2; + const CONCURRENT_REQUESTS: usize = 8; + + let cluster = 
start_subprocess_cluster_with_qwen3_embedding(Qwen3EmbeddingClusterParams { + agents: AgentConfig::uniform(AGENT_COUNT, SLOTS_PER_AGENT), + buffered_request_timeout: Duration::from_secs(60), + inference_parameters: InferenceParameters { + enable_embeddings: true, + ..InferenceParameters::default() + }, + max_buffered_requests: 32, + }) + .await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let collection_futures = (0..CONCURRENT_REQUESTS).map(|request_index| { + let inference_client = inference_client.clone(); + async move { + let input_batch: Vec<EmbeddingInputDocument> = (0..4) + .map(|document_index| EmbeddingInputDocument { + content: format!( + "Burst request {request_index}, document {document_index}: \ + provide an embedding for evaluation." + ), + id: format!("req-{request_index}-doc-{document_index}"), + }) + .collect(); + + let stream = inference_client + .post_generate_embedding_batch(&GenerateEmbeddingBatchParams { + input_batch, + normalization_method: EmbeddingNormalizationMethod::None, + }) + .await?; + + collect_embedding_results(stream).await + } + }); + + let collected_streams = future::try_join_all(collection_futures).await?; + + let producers_across_streams: BTreeSet<&str> = collected_streams + .iter() + .flat_map(|collected| collected.embeddings.iter()) + .filter_map(|produced| produced.generated_by.as_deref()) + .collect(); + + assert_eq!( + producers_across_streams.len(), + AGENT_COUNT, + "burst of {CONCURRENT_REQUESTS} embedding batches across {AGENT_COUNT} agents must reach every agent, but saw producers: {producers_across_streams:?}", + ); + + for collected in &collected_streams { + assert!(collected.saw_done); + assert!(collected.errors.is_empty()); + assert_eq!(collected.embeddings.len(), 4); + } + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/balancer_distributes_token_burst_evenly_across_agents.rs b/paddler_tests/tests/balancer_distributes_token_burst_evenly_across_agents.rs new file mode 100644 index 00000000..cf5cded4 --- /dev/null +++ b/paddler_tests/tests/balancer_distributes_token_burst_evenly_across_agents.rs @@ -0,0 +1,88 @@ +#![cfg(all( + feature = "tests_that_use_compiled_paddler", + feature = "tests_that_use_llms" +))] + +use std::collections::BTreeSet; + +use anyhow::Result; +use anyhow::anyhow; +use futures_util::future; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; +use paddler_types::request_params::ContinueFromRawPromptParams; +use reqwest::Client; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn balancer_distributes_token_burst_evenly_across_agents() -> Result<()> { + const AGENT_COUNT: usize = 4; + const SLOTS_PER_AGENT: i32 = 1; + + let cluster = + start_subprocess_cluster_with_qwen3(AgentConfig::uniform(AGENT_COUNT, SLOTS_PER_AGENT)) + .await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let prompts: Vec<String> = (0..AGENT_COUNT) + .map(|index| format!("Burst request number {index}: Count from one to five.")) + .collect(); + + let collection_futures = prompts.iter().map(|prompt| { + let inference_client = inference_client.clone(); + let raw_prompt = prompt.clone(); + async move { + let stream = 
inference_client + .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + grammar: None, + max_tokens: 16, + raw_prompt, + }) + .await?; + + collect_generated_tokens(stream).await + } + }); + + let collected_streams = future::try_join_all(collection_futures).await?; + + let mut producer_per_stream: Vec<String> = Vec::with_capacity(AGENT_COUNT); + + for (stream_index, collected) in collected_streams.iter().enumerate() { + let producers_for_stream: BTreeSet<&str> = collected + .token_results + .iter() + .filter_map(|chunk| chunk.generated_by.as_deref()) + .collect(); + + assert_eq!( + producers_for_stream.len(), + 1, + "stream {stream_index} must be served by exactly one agent, but saw producers: {producers_for_stream:?}", + ); + + let producer = producers_for_stream + .into_iter() + .next() + .ok_or_else(|| anyhow!("stream {stream_index} produced no attributable tokens"))? + .to_owned(); + + producer_per_stream.push(producer); + } + + let unique_producers: BTreeSet<&str> = producer_per_stream.iter().map(String::as_str).collect(); + + assert_eq!( + unique_producers.len(), + AGENT_COUNT, + "burst of {AGENT_COUNT} requests with {SLOTS_PER_AGENT} slot per agent must fan out across all agents, but stream-to-producer map was: {producer_per_stream:?}", + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/balancer_emits_overflow_errors_when_embedding_burst_exceeds_max_buffered_requests.rs b/paddler_tests/tests/balancer_emits_overflow_errors_when_embedding_burst_exceeds_max_buffered_requests.rs new file mode 100644 index 00000000..a8b42453 --- /dev/null +++ b/paddler_tests/tests/balancer_emits_overflow_errors_when_embedding_burst_exceeds_max_buffered_requests.rs @@ -0,0 +1,98 @@ +#![cfg(all( + feature = "tests_that_use_compiled_paddler", + feature = "tests_that_use_llms" +))] + +use std::time::Duration; + +use anyhow::Result; +use anyhow::anyhow; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_embedding_results::collect_embedding_results; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; +use paddler_tests::start_subprocess_cluster_with_qwen3_embedding::start_subprocess_cluster_with_qwen3_embedding; +use paddler_types::embedding_input_document::EmbeddingInputDocument; +use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_types::inference_parameters::InferenceParameters; +use paddler_types::request_params::GenerateEmbeddingBatchParams; +use reqwest::Client; +use tokio::time::timeout; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn balancer_emits_overflow_errors_when_embedding_burst_exceeds_max_buffered_requests() +-> Result<()> { + const TOTAL_DOCUMENTS: usize = 16; + + let cluster = start_subprocess_cluster_with_qwen3_embedding(Qwen3EmbeddingClusterParams { + agents: AgentConfig::uniform(4, 1), + buffered_request_timeout: Duration::from_secs(2), + inference_parameters: InferenceParameters { + embedding_batch_size: 1, + enable_embeddings: true, + ..InferenceParameters::default() + }, + max_buffered_requests: 4, + }) + .await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let input_batch: Vec<EmbeddingInputDocument> = (0..TOTAL_DOCUMENTS) + .map(|index| EmbeddingInputDocument { + content: format!("Overflow probe document {index}."), + id: format!("doc-{index}"), + }) + .collect(); + + 
let stream = inference_client + .post_generate_embedding_batch(&GenerateEmbeddingBatchParams { + input_batch, + normalization_method: EmbeddingNormalizationMethod::None, + }) + .await?; + + let collected = timeout(Duration::from_secs(15), collect_embedding_results(stream)) + .await + .map_err(|_| anyhow!("burst-overflow embedding stream did not finish within 15s"))??; + + let overflow_errors: Vec<_> = collected + .wire_errors + .iter() + .filter(|wire_error| wire_error.code == 503) + .collect(); + + assert!( + !overflow_errors.is_empty(), + "expected at least one HTTP 503 \"Buffered requests overflow\" envelope, but saw none; wire_errors = {:?}", + collected.wire_errors, + ); + + for overflow in &overflow_errors { + assert!( + overflow.description.contains("Buffered requests overflow"), + "expected 503 envelope description to mention overflow, got {:?}", + overflow.description, + ); + } + + assert!( + collected.saw_done, + "stream must terminate cleanly with Done even when some sub-batches overflow", + ); + + assert_eq!( + collected.embeddings.len() + collected.wire_errors.len(), + TOTAL_DOCUMENTS, + "every sub-batch must be accounted for as either a successful embedding or a wire error (503 overflow or 504 timeout): {} embeddings + {} wire errors ({} of which are 503 overflow) ≠ {TOTAL_DOCUMENTS}", + collected.embeddings.len(), + collected.wire_errors.len(), + overflow_errors.len(), + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/balancer_fans_out_embedding_batch_to_all_agents.rs b/paddler_tests/tests/balancer_fans_out_embedding_batch_to_all_agents.rs index 055fbb56..afe3c64f 100644 --- a/paddler_tests/tests/balancer_fans_out_embedding_batch_to_all_agents.rs +++ b/paddler_tests/tests/balancer_fans_out_embedding_batch_to_all_agents.rs @@ -6,8 +6,10 @@ use std::collections::BTreeSet; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_embedding_results::collect_embedding_results; use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; use paddler_tests::start_subprocess_cluster_with_qwen3_embedding::start_subprocess_cluster_with_qwen3_embedding; use paddler_types::embedding_input_document::EmbeddingInputDocument; use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; @@ -20,14 +22,14 @@ use reqwest::Client; async fn balancer_fans_out_embedding_batch_to_all_agents() -> Result<()> { let agent_count: usize = 4; - let mut cluster = start_subprocess_cluster_with_qwen3_embedding( - InferenceParameters { + let cluster = start_subprocess_cluster_with_qwen3_embedding(Qwen3EmbeddingClusterParams { + agents: AgentConfig::uniform(agent_count, 2), + inference_parameters: InferenceParameters { enable_embeddings: true, ..InferenceParameters::default() }, - 2, - agent_count, - ) + ..Qwen3EmbeddingClusterParams::default() + }) .await?; let inference_client = @@ -45,45 +47,25 @@ async fn balancer_fans_out_embedding_batch_to_all_agents() -> Result<()> { normalization_method: EmbeddingNormalizationMethod::None, }; - let mut seen_busy_agents: BTreeSet<String> = BTreeSet::new(); - let mut have_seen_any_activity = false; - - let request_future = async { - let stream = inference_client - .post_generate_embedding_batch(&params) - .await?; - collect_embedding_results(stream).await - }; - - let observation_future = cluster.agents.until(|snapshot| { - let any_busy_now = snapshot - .agents - .iter() - .any(|agent| agent.slots_processing > 0); 
- - if any_busy_now { - have_seen_any_activity = true; - for agent in &snapshot.agents { - if agent.slots_processing > 0 { - seen_busy_agents.insert(agent.id.clone()); - } - } - } - - seen_busy_agents.len() >= agent_count || (have_seen_any_activity && !any_busy_now) - }); - - let (request_result, observation_result) = tokio::join!(request_future, observation_future); - let collected = request_result?; - observation_result?; + let stream = inference_client + .post_generate_embedding_batch(&params) + .await?; + let collected = collect_embedding_results(stream).await?; assert_eq!(collected.embeddings.len(), 16); assert!(collected.saw_done); assert!(collected.errors.is_empty()); + + let producers: BTreeSet<&str> = collected + .embeddings + .iter() + .filter_map(|produced| produced.generated_by.as_deref()) + .collect(); + assert_eq!( - seen_busy_agents.len(), + producers.len(), agent_count, - "expected the embedding batch to fan out across every agent, but only saw activity on: {seen_busy_agents:?}" + "expected the embedding batch to fan out across every agent, but only saw producers: {producers:?}" ); cluster.shutdown().await?; diff --git a/paddler_tests/tests/balancer_forwards_504_timeout_error_when_agent_stops_emitting_chunks.rs b/paddler_tests/tests/balancer_forwards_504_timeout_error_when_agent_stops_emitting_chunks.rs new file mode 100644 index 00000000..7a734880 --- /dev/null +++ b/paddler_tests/tests/balancer_forwards_504_timeout_error_when_agent_stops_emitting_chunks.rs @@ -0,0 +1,92 @@ +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use anyhow::Context as _; +use anyhow::Result; +use anyhow::anyhow; +use paddler::balancer::chunk_forwarding_session_controller::ChunkForwardingSessionController; +use paddler::balancer::chunk_forwarding_session_controller::identity_transformer::IdentityTransformer; +use paddler::balancer::chunk_forwarding_session_controller::transform_result::TransformResult; +use paddler::balancer::embedding_sender_collection::EmbeddingSenderCollection; +use paddler::balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; +use paddler::balancer::manages_senders_controller::ManagesSendersController; +use paddler::balancer::request_from_agent::forward_responses_stream; +use paddler_tests::make_agent_controller_without_remote_agent::make_agent_controller_without_remote_agent; +use paddler_types::inference_client::Message as OutgoingMessage; +use paddler_types::jsonrpc::ErrorEnvelope; +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; + +#[tokio::test(flavor = "multi_thread")] +async fn balancer_forwards_504_timeout_error_when_agent_stops_emitting_chunks() -> Result<()> { + let agent_controller = Arc::new(make_agent_controller_without_remote_agent("test-agent")); + let request_id = "timeout-test-request".to_owned(); + let receive_response_controller: ManagesSendersController<EmbeddingSenderCollection> = + ManagesSendersController::from_request_id( + request_id.clone(), + agent_controller.embedding_sender_collection.clone(), + )?; + + let (chunk_tx, mut chunk_rx) = mpsc::unbounded_channel::<TransformResult>(); + let session_controller = + ChunkForwardingSessionController::new(chunk_tx, IdentityTransformer::new()); + + let connection_close = CancellationToken::new(); + let inference_item_timeout = Duration::from_millis(150); + let configuration = InferenceServiceConfiguration { + addr: SocketAddr::from(([127, 0, 0, 1], 0)), + cors_allowed_hosts: Vec::new(), + inference_item_timeout, + }; + + let agent_controller_clone = agent_controller.clone(); 
+ let request_id_clone = request_id.clone(); + let forward_handle: tokio::task::JoinHandle<Result<()>> = tokio::spawn(async move { + forward_responses_stream::<_, EmbeddingSenderCollection>( + agent_controller_clone, + connection_close, + configuration, + receive_response_controller, + request_id_clone, + session_controller, + ) + .await + }); + + let forward_completed_within = inference_item_timeout * 10; + tokio::time::timeout(forward_completed_within, forward_handle) + .await + .context("forward_responses_stream did not return within the 504-timeout budget")? + .context("forward_responses_stream task panicked")? + .context("forward_responses_stream returned an error")?; + + let chunk = chunk_rx + .recv() + .await + .ok_or_else(|| anyhow!("expected a 504 timeout envelope to be forwarded to the client"))?; + + let serialized = match chunk { + TransformResult::Chunk(serialized) | TransformResult::Error(serialized) => serialized, + TransformResult::Discard => { + return Err(anyhow!( + "expected a Chunk or Error transform result, got Discard" + )); + } + }; + + let envelope: OutgoingMessage = + serde_json::from_str(&serialized).context("failed to parse forwarded envelope as JSON")?; + + let OutgoingMessage::Error(ErrorEnvelope { error, .. }) = envelope else { + return Err(anyhow!("expected an Error envelope")); + }; + + assert_eq!( + error.code, 504, + "expected 504 error code for inference-item timeout, got {}", + error.code, + ); + + Ok(()) +} diff --git a/paddler_tests/tests/balancer_forwards_error_when_response_channel_closes_before_terminator.rs b/paddler_tests/tests/balancer_forwards_error_when_response_channel_closes_before_terminator.rs new file mode 100644 index 00000000..38c16a9f --- /dev/null +++ b/paddler_tests/tests/balancer_forwards_error_when_response_channel_closes_before_terminator.rs @@ -0,0 +1,96 @@ +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use anyhow::Context as _; +use anyhow::Result; +use anyhow::anyhow; +use paddler::balancer::chunk_forwarding_session_controller::ChunkForwardingSessionController; +use paddler::balancer::chunk_forwarding_session_controller::identity_transformer::IdentityTransformer; +use paddler::balancer::chunk_forwarding_session_controller::transform_result::TransformResult; +use paddler::balancer::embedding_sender_collection::EmbeddingSenderCollection; +use paddler::balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; +use paddler::balancer::manages_senders::ManagesSenders as _; +use paddler::balancer::manages_senders_controller::ManagesSendersController; +use paddler::balancer::request_from_agent::forward_responses_stream; +use paddler_tests::make_agent_controller_without_remote_agent::make_agent_controller_without_remote_agent; +use paddler_types::inference_client::Message as OutgoingMessage; +use paddler_types::jsonrpc::ErrorEnvelope; +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; + +#[tokio::test(flavor = "multi_thread")] +async fn forward_responses_stream_emits_error_envelope_when_response_channel_closes_before_terminator() +-> Result<()> { + let agent_controller = Arc::new(make_agent_controller_without_remote_agent("test-agent")); + let request_id = "test-request".to_owned(); + let receive_response_controller: ManagesSendersController<EmbeddingSenderCollection> = + ManagesSendersController::from_request_id( + request_id.clone(), + agent_controller.embedding_sender_collection.clone(), + )?; + + let (chunk_tx, mut chunk_rx) = mpsc::unbounded_channel::<TransformResult>(); + let session_controller = + 
ChunkForwardingSessionController::new(chunk_tx, IdentityTransformer::new()); + + let connection_close = CancellationToken::new(); + let configuration = InferenceServiceConfiguration { + addr: SocketAddr::from(([127, 0, 0, 1], 0)), + cors_allowed_hosts: Vec::new(), + inference_item_timeout: Duration::from_secs(30), + }; + + let agent_controller_clone = agent_controller.clone(); + let request_id_clone = request_id.clone(); + let forward_handle: tokio::task::JoinHandle<Result<()>> = tokio::spawn(async move { + forward_responses_stream::<_, EmbeddingSenderCollection>( + agent_controller_clone, + connection_close, + configuration, + receive_response_controller, + request_id_clone, + session_controller, + ) + .await + }); + + tokio::time::sleep(Duration::from_millis(50)).await; + + agent_controller + .embedding_sender_collection + .deregister_sender(request_id.clone())?; + + tokio::time::timeout(Duration::from_secs(5), forward_handle) + .await + .context("forward_responses_stream did not complete in time")? + .context("forward_responses_stream task panicked")? + .context("forward_responses_stream returned an error")?; + + let chunk = chunk_rx + .recv() + .await + .ok_or_else(|| anyhow!("expected an error envelope to be forwarded to the client"))?; + + let serialized = match chunk { + TransformResult::Chunk(serialized) | TransformResult::Error(serialized) => serialized, + TransformResult::Discard => { + return Err(anyhow!("expected a Chunk transform result, got Discard")); + } + }; + + let envelope: OutgoingMessage = + serde_json::from_str(&serialized).context("failed to parse forwarded envelope as JSON")?; + + let OutgoingMessage::Error(ErrorEnvelope { error, .. }) = envelope else { + return Err(anyhow!("expected an Error envelope")); + }; + + assert_eq!( + error.code, 502, + "expected 502 error code for premature channel close, got {}", + error.code + ); + + Ok(()) +} diff --git a/paddler_tests/tests/balancer_in_process_shutdown_with_open_sse_subscriber_completes_within_one_second.rs b/paddler_tests/tests/balancer_in_process_shutdown_with_open_sse_subscriber_completes_within_one_second.rs new file mode 100644 index 00000000..0fa01396 --- /dev/null +++ b/paddler_tests/tests/balancer_in_process_shutdown_with_open_sse_subscriber_completes_within_one_second.rs @@ -0,0 +1,43 @@ +use std::time::Duration; + +use anyhow::Result; +use anyhow::anyhow; +use futures_util::StreamExt as _; +use paddler_tests::in_process_cluster_params::InProcessClusterParams; +use paddler_tests::start_in_process_cluster::start_in_process_cluster; +use tokio::time::timeout; + +#[tokio::test(flavor = "multi_thread")] +async fn balancer_in_process_shutdown_with_open_sse_subscriber_completes_within_one_second() +-> Result<()> { + let cluster = start_in_process_cluster(InProcessClusterParams { + agent: None, + wait_for_slots_ready: false, + ..InProcessClusterParams::default() + }) + .await?; + + let mut sse_stream = cluster + .paddler_client + .management() + .get_buffered_requests_stream() + .await + .map_err(anyhow::Error::new)?; + + let _first_snapshot = timeout(Duration::from_secs(1), sse_stream.next()) + .await + .map_err(|elapsed| anyhow!("first SSE snapshot must arrive within 1s: {elapsed}"))? + .ok_or_else(|| anyhow!("SSE stream closed before first snapshot"))? 
+ .map_err(anyhow::Error::new)?; + + timeout(Duration::from_secs(1), cluster.shutdown()) + .await + .map_err(|elapsed| { + anyhow!( + "balancer in-process shutdown with an open SSE subscriber must complete within \ + 1s after cancel; got: {elapsed}" + ) + })??; + + Ok(()) +} diff --git a/paddler_tests/tests/balancer_inference_health_returns_ok.rs b/paddler_tests/tests/balancer_inference_health_returns_ok.rs index 49f0f1ef..f25ff347 100644 --- a/paddler_tests/tests/balancer_inference_health_returns_ok.rs +++ b/paddler_tests/tests/balancer_inference_health_returns_ok.rs @@ -8,7 +8,7 @@ use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; #[tokio::test(flavor = "multi_thread")] async fn balancer_inference_health_returns_ok() -> Result<()> { let cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), wait_for_slots_ready: false, ..SubprocessClusterParams::default() }) diff --git a/paddler_tests/tests/balancer_inference_service_replies_with_configured_cors_origin.rs b/paddler_tests/tests/balancer_inference_service_replies_with_configured_cors_origin.rs index 164457d7..68cc3b26 100644 --- a/paddler_tests/tests/balancer_inference_service_replies_with_configured_cors_origin.rs +++ b/paddler_tests/tests/balancer_inference_service_replies_with_configured_cors_origin.rs @@ -10,7 +10,7 @@ const ALLOWED_ORIGIN: &str = "http://example.com"; #[tokio::test(flavor = "multi_thread")] async fn balancer_inference_service_replies_with_configured_cors_origin() -> Result<()> { let cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), inference_cors_allowed_hosts: vec![ALLOWED_ORIGIN.to_owned()], wait_for_slots_ready: false, ..SubprocessClusterParams::default() diff --git a/paddler_tests/tests/balancer_management_service_replies_with_configured_cors_origin.rs b/paddler_tests/tests/balancer_management_service_replies_with_configured_cors_origin.rs index cf166dbc..003b4993 100644 --- a/paddler_tests/tests/balancer_management_service_replies_with_configured_cors_origin.rs +++ b/paddler_tests/tests/balancer_management_service_replies_with_configured_cors_origin.rs @@ -10,7 +10,7 @@ const ALLOWED_ORIGIN: &str = "http://example.com"; #[tokio::test(flavor = "multi_thread")] async fn balancer_management_service_replies_with_configured_cors_origin() -> Result<()> { let cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), management_cors_allowed_hosts: vec![ALLOWED_ORIGIN.to_owned()], wait_for_slots_ready: false, ..SubprocessClusterParams::default() diff --git a/paddler_tests/tests/balancer_memory_storage_persists_desired_state.rs b/paddler_tests/tests/balancer_memory_storage_persists_desired_state.rs index cc6065ec..2a99f0ec 100644 --- a/paddler_tests/tests/balancer_memory_storage_persists_desired_state.rs +++ b/paddler_tests/tests/balancer_memory_storage_persists_desired_state.rs @@ -27,7 +27,7 @@ async fn balancer_memory_storage_persists_desired_state() -> Result<()> { }; let cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), wait_for_slots_ready: false, state_database_url: "memory://".to_owned(), desired_state: Some(desired_state.clone()), diff --git a/paddler_tests/tests/balancer_openai_compat_health_returns_ok.rs b/paddler_tests/tests/balancer_openai_compat_health_returns_ok.rs index ed9d49b1..0c5876c5 100644 --- a/paddler_tests/tests/balancer_openai_compat_health_returns_ok.rs +++ 
b/paddler_tests/tests/balancer_openai_compat_health_returns_ok.rs @@ -8,7 +8,7 @@ use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; #[tokio::test(flavor = "multi_thread")] async fn balancer_openai_compat_health_returns_ok() -> Result<()> { let cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), wait_for_slots_ready: false, ..SubprocessClusterParams::default() }) diff --git a/paddler_tests/tests/balancer_persists_chat_template_override_across_restart.rs b/paddler_tests/tests/balancer_persists_chat_template_override_across_restart.rs index a8d67c5f..7452f91e 100644 --- a/paddler_tests/tests/balancer_persists_chat_template_override_across_restart.rs +++ b/paddler_tests/tests/balancer_persists_chat_template_override_across_restart.rs @@ -35,7 +35,7 @@ async fn balancer_persists_chat_template_override_across_restart() -> Result<()> }; let first_cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), wait_for_slots_ready: false, state_database_url: database.url.clone(), desired_state: Some(desired_state.clone()), @@ -46,7 +46,7 @@ async fn balancer_persists_chat_template_override_across_restart() -> Result<()> first_cluster.shutdown().await?; let second_cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), wait_for_slots_ready: false, state_database_url: database.url.clone(), desired_state: None, diff --git a/paddler_tests/tests/balancer_persists_desired_state_across_restart.rs b/paddler_tests/tests/balancer_persists_desired_state_across_restart.rs index bb5675e6..278f3f25 100644 --- a/paddler_tests/tests/balancer_persists_desired_state_across_restart.rs +++ b/paddler_tests/tests/balancer_persists_desired_state_across_restart.rs @@ -30,7 +30,7 @@ async fn balancer_persists_desired_state_across_restart() -> Result<()> { }; let first_cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), wait_for_slots_ready: false, state_database_url: database.url.clone(), desired_state: Some(desired_state.clone()), @@ -41,7 +41,7 @@ async fn balancer_persists_desired_state_across_restart() -> Result<()> { first_cluster.shutdown().await?; let second_cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), wait_for_slots_ready: false, state_database_url: database.url.clone(), desired_state: None, diff --git a/paddler_tests/tests/balancer_persists_huggingface_mmproj_in_desired_state.rs b/paddler_tests/tests/balancer_persists_huggingface_mmproj_in_desired_state.rs index bd96cd33..38b2fb26 100644 --- a/paddler_tests/tests/balancer_persists_huggingface_mmproj_in_desired_state.rs +++ b/paddler_tests/tests/balancer_persists_huggingface_mmproj_in_desired_state.rs @@ -27,7 +27,7 @@ async fn balancer_persists_huggingface_mmproj_in_desired_state() -> Result<()> { } = smolvlm2_256m_mmproj(); let cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { chat_template_override: None, diff --git a/paddler_tests/tests/balancer_persists_local_mmproj_path_in_desired_state.rs b/paddler_tests/tests/balancer_persists_local_mmproj_path_in_desired_state.rs index 08ca0fca..db0ba63a 100644 --- a/paddler_tests/tests/balancer_persists_local_mmproj_path_in_desired_state.rs +++ b/paddler_tests/tests/balancer_persists_local_mmproj_path_in_desired_state.rs @@ -21,7 +21,7 @@ async fn 
balancer_persists_local_mmproj_path_in_desired_state() -> Result<()> { let local_mmproj_path = "/tmp/test-mmproj.gguf".to_owned(); let cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { chat_template_override: None, diff --git a/paddler_tests/tests/balancer_persists_model_switch_in_storage.rs b/paddler_tests/tests/balancer_persists_model_switch_in_storage.rs index 7f3cba23..f69a1133 100644 --- a/paddler_tests/tests/balancer_persists_model_switch_in_storage.rs +++ b/paddler_tests/tests/balancer_persists_model_switch_in_storage.rs @@ -30,7 +30,7 @@ async fn balancer_persists_model_switch_in_storage() -> Result<()> { }; let cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), wait_for_slots_ready: false, state_database_url: database.url.clone(), desired_state: Some(initial_state.clone()), diff --git a/paddler_tests/tests/balancer_registers_multiple_agents_over_time.rs b/paddler_tests/tests/balancer_registers_multiple_agents_over_time.rs index 7e2171df..d86bc932 100644 --- a/paddler_tests/tests/balancer_registers_multiple_agents_over_time.rs +++ b/paddler_tests/tests/balancer_registers_multiple_agents_over_time.rs @@ -11,7 +11,7 @@ use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; #[tokio::test(flavor = "multi_thread")] async fn balancer_registers_multiple_agents_over_time() -> Result<()> { let mut cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), wait_for_slots_ready: false, ..SubprocessClusterParams::default() }) diff --git a/paddler_tests/tests/balancer_reports_chat_template_does_not_compile_for_invalid_jinja.rs b/paddler_tests/tests/balancer_reports_chat_template_does_not_compile_for_invalid_jinja.rs index 907c3543..709af576 100644 --- a/paddler_tests/tests/balancer_reports_chat_template_does_not_compile_for_invalid_jinja.rs +++ b/paddler_tests/tests/balancer_reports_chat_template_does_not_compile_for_invalid_jinja.rs @@ -5,6 +5,7 @@ use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; @@ -21,8 +22,7 @@ async fn balancer_reports_chat_template_does_not_compile_for_invalid_jinja() -> let ModelCard { reference, .. 
} = qwen3_0_6b(); let mut cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 1, - slots_per_agent: 1, + agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { chat_template_override: Some(ChatTemplate { diff --git a/paddler_tests/tests/balancer_reports_huggingface_model_does_not_exist.rs b/paddler_tests/tests/balancer_reports_huggingface_model_does_not_exist.rs index 1283fcad..411452b1 100644 --- a/paddler_tests/tests/balancer_reports_huggingface_model_does_not_exist.rs +++ b/paddler_tests/tests/balancer_reports_huggingface_model_does_not_exist.rs @@ -5,6 +5,7 @@ use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; use paddler_types::agent_desired_model::AgentDesiredModel; @@ -17,8 +18,7 @@ use paddler_types::inference_parameters::InferenceParameters; #[tokio::test(flavor = "multi_thread")] async fn balancer_reports_huggingface_model_does_not_exist() -> Result<()> { let mut cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 1, - slots_per_agent: 1, + agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { chat_template_override: None, diff --git a/paddler_tests/tests/balancer_reports_mmproj_cannot_be_loaded_for_invalid_file.rs b/paddler_tests/tests/balancer_reports_mmproj_cannot_be_loaded_for_invalid_file.rs index c2059079..8e75911c 100644 --- a/paddler_tests/tests/balancer_reports_mmproj_cannot_be_loaded_for_invalid_file.rs +++ b/paddler_tests/tests/balancer_reports_mmproj_cannot_be_loaded_for_invalid_file.rs @@ -5,6 +5,7 @@ use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; @@ -25,8 +26,7 @@ async fn balancer_reports_mmproj_cannot_be_loaded_for_invalid_file() -> Result<( let ModelCard { reference, .. 
} = qwen3_0_6b(); let mut cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 1, - slots_per_agent: 1, + agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { chat_template_override: None, diff --git a/paddler_tests/tests/balancer_reports_model_cannot_be_loaded_for_corrupt_file.rs b/paddler_tests/tests/balancer_reports_model_cannot_be_loaded_for_corrupt_file.rs index 5c6e94e8..0c74809d 100644 --- a/paddler_tests/tests/balancer_reports_model_cannot_be_loaded_for_corrupt_file.rs +++ b/paddler_tests/tests/balancer_reports_model_cannot_be_loaded_for_corrupt_file.rs @@ -7,6 +7,7 @@ use std::io::Write as _; use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; use paddler_types::agent_desired_model::AgentDesiredModel; @@ -29,8 +30,7 @@ async fn balancer_reports_model_cannot_be_loaded_for_corrupt_file() -> Result<() .to_owned(); let mut cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 1, - slots_per_agent: 1, + agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { chat_template_override: None, diff --git a/paddler_tests/tests/balancer_reports_model_cannot_be_loaded_for_invalid_gguf.rs b/paddler_tests/tests/balancer_reports_model_cannot_be_loaded_for_invalid_gguf.rs index 2167da44..000db231 100644 --- a/paddler_tests/tests/balancer_reports_model_cannot_be_loaded_for_invalid_gguf.rs +++ b/paddler_tests/tests/balancer_reports_model_cannot_be_loaded_for_invalid_gguf.rs @@ -5,6 +5,7 @@ use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; use paddler_types::agent_desired_model::AgentDesiredModel; @@ -18,8 +19,7 @@ async fn balancer_reports_model_cannot_be_loaded_for_invalid_gguf() -> Result<() let invalid_gguf_path = concat!(env!("CARGO_MANIFEST_DIR"), "/../fixtures/invalid.gguf"); let mut cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 1, - slots_per_agent: 1, + agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { chat_template_override: None, diff --git a/paddler_tests/tests/balancer_reports_model_file_does_not_exist.rs b/paddler_tests/tests/balancer_reports_model_file_does_not_exist.rs index fd8303c7..f84aad13 100644 --- a/paddler_tests/tests/balancer_reports_model_file_does_not_exist.rs +++ b/paddler_tests/tests/balancer_reports_model_file_does_not_exist.rs @@ -5,6 +5,7 @@ use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; use paddler_types::agent_desired_model::AgentDesiredModel; @@ -16,8 +17,7 @@ use paddler_types::inference_parameters::InferenceParameters; #[tokio::test(flavor = "multi_thread")] async fn balancer_reports_model_file_does_not_exist() -> Result<()> { let mut cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 1, - slots_per_agent: 1, + agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { chat_template_override: None, diff --git 
a/paddler_tests/tests/balancer_reports_multimodal_projection_cannot_be_loaded.rs b/paddler_tests/tests/balancer_reports_multimodal_projection_cannot_be_loaded.rs index d50caeaa..261a88fe 100644 --- a/paddler_tests/tests/balancer_reports_multimodal_projection_cannot_be_loaded.rs +++ b/paddler_tests/tests/balancer_reports_multimodal_projection_cannot_be_loaded.rs @@ -5,6 +5,7 @@ use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; @@ -20,8 +21,7 @@ async fn balancer_reports_multimodal_projection_cannot_be_loaded() -> Result<()> let ModelCard { reference, .. } = qwen3_0_6b(); let mut cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 1, - slots_per_agent: 1, + agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { chat_template_override: None, diff --git a/paddler_tests/tests/balancer_reports_unable_to_find_chat_template_for_embedding_model.rs b/paddler_tests/tests/balancer_reports_unable_to_find_chat_template_for_embedding_model.rs index 33d90d74..35a5c442 100644 --- a/paddler_tests/tests/balancer_reports_unable_to_find_chat_template_for_embedding_model.rs +++ b/paddler_tests/tests/balancer_reports_unable_to_find_chat_template_for_embedding_model.rs @@ -5,6 +5,7 @@ use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::nomic_embed_text_v1_5::nomic_embed_text_v1_5; use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; @@ -20,8 +21,7 @@ async fn balancer_reports_unable_to_find_chat_template_for_embedding_model() -> let ModelCard { reference, .. } = nomic_embed_text_v1_5(); let mut cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 1, - slots_per_agent: 1, + agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { chat_template_override: None, @@ -40,11 +40,12 @@ async fn balancer_reports_unable_to_find_chat_template_for_embedding_model() -> .context("cluster must have one registered agent")? 
.clone(); + let predicate_agent_id = agent_id.clone(); cluster .agents - .until(move |snapshot| { + .until_agent(&agent_id, move |snapshot| { snapshot.agents.iter().any(|agent| { - agent.id == agent_id + agent.id == predicate_agent_id && agent .issues .iter() diff --git a/paddler_tests/tests/balancer_resolves_buffered_requests_after_agent_killed.rs b/paddler_tests/tests/balancer_resolves_buffered_requests_after_agent_killed.rs index c0bbc611..8978d19b 100644 --- a/paddler_tests/tests/balancer_resolves_buffered_requests_after_agent_killed.rs +++ b/paddler_tests/tests/balancer_resolves_buffered_requests_after_agent_killed.rs @@ -8,6 +8,7 @@ use std::time::Duration; use anyhow::Context as _; use anyhow::Result; use futures_util::StreamExt as _; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::current_test_device::current_test_device; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::model_card::ModelCard; @@ -35,12 +36,13 @@ async fn balancer_resolves_buffered_requests_after_agent_killed() -> Result<()> } = qwen3_0_6b(); let mut cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 1, - slots_per_agent: 2, + agents: vec![AgentConfig { + name: "removal-agent-primary".to_owned(), + slot_count: 2, + }], wait_for_slots_ready: true, buffered_request_timeout: Duration::from_secs(120), max_buffered_requests: 10, - agent_name_prefix: "removal-agent-primary".to_owned(), desired_state: Some(BalancerDesiredState { chat_template_override: None, inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), diff --git a/paddler_tests/tests/balancer_returns_503_when_request_buffering_disabled.rs b/paddler_tests/tests/balancer_returns_503_when_request_buffering_disabled.rs index b14301b0..3d661ad0 100644 --- a/paddler_tests/tests/balancer_returns_503_when_request_buffering_disabled.rs +++ b/paddler_tests/tests/balancer_returns_503_when_request_buffering_disabled.rs @@ -21,7 +21,7 @@ use reqwest::Client; #[tokio::test(flavor = "multi_thread")] async fn balancer_returns_503_when_request_buffering_disabled() -> Result<()> { let cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), wait_for_slots_ready: false, buffered_request_timeout: Duration::from_millis(50), max_buffered_requests: 0, diff --git a/paddler_tests/tests/balancer_returns_504_when_inference_item_timeout_is_zero.rs b/paddler_tests/tests/balancer_returns_504_when_inference_item_timeout_is_zero.rs index a3e7068d..9e7555ea 100644 --- a/paddler_tests/tests/balancer_returns_504_when_inference_item_timeout_is_zero.rs +++ b/paddler_tests/tests/balancer_returns_504_when_inference_item_timeout_is_zero.rs @@ -8,6 +8,7 @@ use std::time::Duration; use anyhow::Context as _; use anyhow::Result; use futures_util::StreamExt as _; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::current_test_device::current_test_device; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::model_card::ModelCard; @@ -33,8 +34,7 @@ async fn balancer_returns_504_when_inference_item_timeout_is_zero() -> Result<() } = qwen3_0_6b(); let cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 1, - slots_per_agent: 2, + agents: AgentConfig::uniform(1, 2), inference_item_timeout: Duration::ZERO, wait_for_slots_ready: true, desired_state: Some(BalancerDesiredState { diff --git a/paddler_tests/tests/balancer_returns_504_when_no_agents_registered.rs 
b/paddler_tests/tests/balancer_returns_504_when_no_agents_registered.rs index 85f13099..9a3071b5 100644 --- a/paddler_tests/tests/balancer_returns_504_when_no_agents_registered.rs +++ b/paddler_tests/tests/balancer_returns_504_when_no_agents_registered.rs @@ -26,7 +26,7 @@ async fn balancer_returns_504_when_no_agents_registered() -> Result<()> { let ModelCard { reference, .. } = qwen3_0_6b(); let cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), wait_for_slots_ready: false, buffered_request_timeout: Duration::from_millis(50), max_buffered_requests: 1, diff --git a/paddler_tests/tests/balancer_returns_504_when_no_model_configured.rs b/paddler_tests/tests/balancer_returns_504_when_no_model_configured.rs index 24dc75af..041910ce 100644 --- a/paddler_tests/tests/balancer_returns_504_when_no_model_configured.rs +++ b/paddler_tests/tests/balancer_returns_504_when_no_model_configured.rs @@ -13,7 +13,7 @@ use reqwest::Client; #[tokio::test(flavor = "multi_thread")] async fn balancer_returns_504_when_no_model_configured() -> Result<()> { let cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), wait_for_slots_ready: false, ..SubprocessClusterParams::default() }) diff --git a/paddler_tests/tests/balancer_serves_request_after_agent_with_capacity_registers.rs b/paddler_tests/tests/balancer_serves_request_after_agent_with_capacity_registers.rs index bdc9e57a..e584be7d 100644 --- a/paddler_tests/tests/balancer_serves_request_after_agent_with_capacity_registers.rs +++ b/paddler_tests/tests/balancer_serves_request_after_agent_with_capacity_registers.rs @@ -35,7 +35,7 @@ async fn balancer_serves_request_after_agent_with_capacity_registers() -> Result } = qwen3_0_6b(); let mut cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), wait_for_slots_ready: false, buffered_request_timeout: Duration::from_millis(50), max_buffered_requests: 10, diff --git a/paddler_tests/tests/chat_template_drains_in_flight_inference_before_swap.rs b/paddler_tests/tests/chat_template_drains_in_flight_inference_before_swap.rs index c9abfc65..ab8c85fa 100644 --- a/paddler_tests/tests/chat_template_drains_in_flight_inference_before_swap.rs +++ b/paddler_tests/tests/chat_template_drains_in_flight_inference_before_swap.rs @@ -5,6 +5,7 @@ use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::current_test_device::current_test_device; use paddler_tests::inference_http_client::InferenceHttpClient; @@ -12,6 +13,7 @@ use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; +use paddler_tests::token_result_with_producer::TokenResultWithProducer; use paddler_types::agent_desired_model::AgentDesiredModel; use paddler_types::balancer_desired_state::BalancerDesiredState; use paddler_types::chat_template::ChatTemplate; @@ -42,8 +44,7 @@ async fn chat_template_drains_in_flight_inference_before_swap() -> Result<()> { }; let cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 1, - slots_per_agent: 1, + agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: true, desired_state: Some(BalancerDesiredState { chat_template_override: Some(template_a.clone()), @@ -75,6 +76,7 @@ async 
fn chat_template_drains_in_flight_inference_before_swap() -> Result<()> { enable_thinking: false, grammar: None, max_tokens: 10, + parse_tool_calls: false, tools: vec![], }) .await?; @@ -100,21 +102,24 @@ async fn chat_template_drains_in_flight_inference_before_swap() -> Result<()> { collected .token_results .iter() - .any(|result| matches!(result, GeneratedTokenResult::Token(_))), + .any(|result| result.token_result.is_token()), "in-flight request must continue producing tokens during template swap" ); assert!( - !collected - .token_results - .iter() - .any(|result| matches!(result, GeneratedTokenResult::ChatTemplateError(_))), + !collected.token_results.iter().any(|result| matches!( + result.token_result, + GeneratedTokenResult::ChatTemplateError(_) + )), "in-flight request must not see ChatTemplateError during swap" ); assert!(matches!( collected.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. + }) )); let retrieved = cluster diff --git a/paddler_tests/tests/chat_template_override_applied_to_embedding_model.rs b/paddler_tests/tests/chat_template_override_applied_to_embedding_model.rs index 73b97a7b..3f8cf97b 100644 --- a/paddler_tests/tests/chat_template_override_applied_to_embedding_model.rs +++ b/paddler_tests/tests/chat_template_override_applied_to_embedding_model.rs @@ -5,6 +5,7 @@ use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::nomic_embed_text_v1_5::nomic_embed_text_v1_5; use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; @@ -24,8 +25,7 @@ async fn chat_template_override_applied_to_embedding_model() -> Result<()> { }; let cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 1, - slots_per_agent: 1, + agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { chat_template_override: Some(chat_template.clone()), diff --git a/paddler_tests/tests/chat_template_override_replaces_model_builtin.rs b/paddler_tests/tests/chat_template_override_replaces_model_builtin.rs index 2fc982a1..f759a776 100644 --- a/paddler_tests/tests/chat_template_override_replaces_model_builtin.rs +++ b/paddler_tests/tests/chat_template_override_replaces_model_builtin.rs @@ -5,6 +5,7 @@ use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::current_test_device::current_test_device; use paddler_tests::inference_http_client::InferenceHttpClient; @@ -18,7 +19,6 @@ use paddler_types::chat_template::ChatTemplate; use paddler_types::conversation_history::ConversationHistory; use paddler_types::conversation_message::ConversationMessage; use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; use reqwest::Client; @@ -39,8 +39,7 @@ async fn chat_template_override_replaces_model_builtin() -> Result<()> { }; let cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 1, - slots_per_agent: 1, + agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: true, desired_state: Some(BalancerDesiredState { chat_template_override: Some(chat_template.clone()), @@ -82,6 +81,7 @@ async fn 
chat_template_override_replaces_model_builtin() -> Result<()> { enable_thinking: false, grammar: None, max_tokens: 10, + parse_tool_calls: false, tools: vec![], }) .await?; @@ -91,7 +91,7 @@ async fn chat_template_override_replaces_model_builtin() -> Result<()> { let received_tokens = collected .token_results .iter() - .any(|result| matches!(result, GeneratedTokenResult::Token(_))); + .any(|result| result.token_result.is_token()); assert!( received_tokens, diff --git a/paddler_tests/tests/chat_template_swaps_between_inference_calls.rs b/paddler_tests/tests/chat_template_swaps_between_inference_calls.rs index 37e7986f..fcc7fa72 100644 --- a/paddler_tests/tests/chat_template_swaps_between_inference_calls.rs +++ b/paddler_tests/tests/chat_template_swaps_between_inference_calls.rs @@ -5,6 +5,7 @@ use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::current_test_device::current_test_device; use paddler_tests::inference_http_client::InferenceHttpClient; @@ -18,7 +19,6 @@ use paddler_types::chat_template::ChatTemplate; use paddler_types::conversation_history::ConversationHistory; use paddler_types::conversation_message::ConversationMessage; use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; use reqwest::Client; @@ -33,6 +33,7 @@ async fn run_inference_after_template_swap(inference_client: &InferenceHttpClien enable_thinking: false, grammar: None, max_tokens: 10, + parse_tool_calls: false, tools: vec![], }) .await?; @@ -42,7 +43,7 @@ async fn run_inference_after_template_swap(inference_client: &InferenceHttpClien Ok(collected .token_results .iter() - .any(|result| matches!(result, GeneratedTokenResult::Token(_)))) + .any(|result| result.token_result.is_token())) } #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] @@ -65,8 +66,7 @@ async fn chat_template_swaps_between_inference_calls() -> Result<()> { }; let cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 1, - slots_per_agent: 1, + agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: true, desired_state: Some(BalancerDesiredState { chat_template_override: Some(template_a.clone()), diff --git a/paddler_tests/tests/continuous_batch_concurrent_conversation_history_requests_complete.rs b/paddler_tests/tests/continuous_batch_concurrent_conversation_history_requests_complete.rs index 7266bb52..971f38e6 100644 --- a/paddler_tests/tests/continuous_batch_concurrent_conversation_history_requests_complete.rs +++ b/paddler_tests/tests/continuous_batch_concurrent_conversation_history_requests_complete.rs @@ -1,9 +1,11 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; +use paddler_tests::token_result_with_producer::TokenResultWithProducer; use paddler_types::conversation_history::ConversationHistory; use paddler_types::conversation_message::ConversationMessage; use paddler_types::conversation_message_content::ConversationMessageContent; @@ -21,7 +23,7 @@ fn user_message(text: &str) -> 
ConversationMessage { #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_concurrent_conversation_history_requests_complete() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(2).await?; + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(2)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -33,6 +35,7 @@ async fn continuous_batch_concurrent_conversation_history_requests_complete() -> enable_thinking: false, grammar: None, max_tokens: 20, + parse_tool_calls: false, tools: vec![], }) .await?; @@ -44,6 +47,7 @@ async fn continuous_batch_concurrent_conversation_history_requests_complete() -> enable_thinking: false, grammar: None, max_tokens: 20, + parse_tool_calls: false, tools: vec![], }) .await?; @@ -59,23 +63,29 @@ async fn continuous_batch_concurrent_conversation_history_requests_complete() -> let tokens_a = collected_a .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); let tokens_b = collected_b .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!(tokens_a > 0); assert!(tokens_b > 0); assert!(matches!( collected_a.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. + }) )); assert!(matches!( collected_b.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. + }) )); cluster.shutdown().await?; diff --git a/paddler_tests/tests/continuous_batch_distinct_output.rs b/paddler_tests/tests/continuous_batch_distinct_output.rs index e5c1c2a6..9c4d43ef 100644 --- a/paddler_tests/tests/continuous_batch_distinct_output.rs +++ b/paddler_tests/tests/continuous_batch_distinct_output.rs @@ -1,6 +1,7 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::current_test_device::current_test_device; use paddler_tests::in_process_cluster_params::InProcessClusterParams; @@ -34,8 +35,10 @@ async fn two_concurrent_prompts_produce_distinct_outputs() -> Result<()> { }; let cluster = start_in_process_cluster(InProcessClusterParams { - spawn_agent: true, - slots_per_agent: 2, + agent: Some(AgentConfig { + name: "test-agent".to_owned(), + slot_count: 2, + }), desired_state, wait_for_slots_ready: true, ..InProcessClusterParams::default() diff --git a/paddler_tests/tests/continuous_batch_evicts_long_sequence_under_kv_pressure.rs b/paddler_tests/tests/continuous_batch_evicts_long_sequence_under_kv_pressure.rs index 5cb4986f..1d5e3115 100644 --- a/paddler_tests/tests/continuous_batch_evicts_long_sequence_under_kv_pressure.rs +++ b/paddler_tests/tests/continuous_batch_evicts_long_sequence_under_kv_pressure.rs @@ -1,6 +1,7 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::current_test_device::current_test_device; use paddler_tests::in_process_cluster_params::InProcessClusterParams; @@ -8,6 +9,7 @@ use paddler_tests::inference_http_client::InferenceHttpClient; 
use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; use paddler_tests::start_in_process_cluster::start_in_process_cluster; +use paddler_tests::token_result_with_producer::TokenResultWithProducer; use paddler_types::agent_desired_model::AgentDesiredModel; use paddler_types::balancer_desired_state::BalancerDesiredState; use paddler_types::generated_token_result::GeneratedTokenResult; @@ -28,13 +30,15 @@ async fn continuous_batch_evicts_long_sequence_under_kv_pressure() -> Result<()> let mut inference_parameters = device.inference_parameters_for_full_offload(gpu_layer_count); - inference_parameters.batch_n_tokens = 256; + inference_parameters.n_batch = 256; inference_parameters.context_size = 256; inference_parameters.temperature = 0.0; let cluster = start_in_process_cluster(InProcessClusterParams { - spawn_agent: true, - slots_per_agent: 2, + agent: Some(AgentConfig { + name: "test-agent".to_owned(), + slot_count: 2, + }), desired_state: BalancerDesiredState { chat_template_override: None, inference_parameters, @@ -77,7 +81,7 @@ async fn continuous_batch_evicts_long_sequence_under_kv_pressure() -> Result<()> let short_collected = short_collected?; let long_was_evicted = long_collected.token_results.iter().any(|result| { - matches!(result, GeneratedTokenResult::SamplerError(message) if message.contains("evicted")) + matches!(&result.token_result, GeneratedTokenResult::SamplerError(message) if message.contains("evicted")) }); assert!( @@ -86,7 +90,10 @@ async fn continuous_batch_evicts_long_sequence_under_kv_pressure() -> Result<()> ); assert!(matches!( short_collected.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. 
+ }) )); cluster.shutdown().await?; diff --git a/paddler_tests/tests/continuous_batch_generates_tokens_with_distinct_k_and_v_cache_dtypes.rs b/paddler_tests/tests/continuous_batch_generates_tokens_with_distinct_k_and_v_cache_dtypes.rs index 4b08a0ce..e5fc6a90 100644 --- a/paddler_tests/tests/continuous_batch_generates_tokens_with_distinct_k_and_v_cache_dtypes.rs +++ b/paddler_tests/tests/continuous_batch_generates_tokens_with_distinct_k_and_v_cache_dtypes.rs @@ -1,6 +1,7 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::current_test_device::current_test_device; use paddler_tests::in_process_cluster_params::InProcessClusterParams; @@ -8,6 +9,7 @@ use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; use paddler_tests::start_in_process_cluster::start_in_process_cluster; +use paddler_tests::token_result_with_producer::TokenResultWithProducer; use paddler_types::agent_desired_model::AgentDesiredModel; use paddler_types::balancer_desired_state::BalancerDesiredState; use paddler_types::generated_token_result::GeneratedTokenResult; @@ -33,8 +35,10 @@ async fn continuous_batch_generates_tokens_with_distinct_k_and_v_cache_dtypes() inference_parameters.v_cache_dtype = KvCacheDtype::F16; let cluster = start_in_process_cluster(InProcessClusterParams { - spawn_agent: true, - slots_per_agent: 1, + agent: Some(AgentConfig { + name: "test-agent".to_owned(), + slot_count: 1, + }), desired_state: BalancerDesiredState { chat_template_override: None, inference_parameters, @@ -63,13 +67,16 @@ async fn continuous_batch_generates_tokens_with_distinct_k_and_v_cache_dtypes() let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!(token_count > 0); assert!(matches!( collected.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. 
+ }) )); cluster.shutdown().await?; diff --git a/paddler_tests/tests/continuous_batch_generates_tokens_with_partial_layer_offload.rs b/paddler_tests/tests/continuous_batch_generates_tokens_with_partial_layer_offload.rs index 7b1080fb..20227f34 100644 --- a/paddler_tests/tests/continuous_batch_generates_tokens_with_partial_layer_offload.rs +++ b/paddler_tests/tests/continuous_batch_generates_tokens_with_partial_layer_offload.rs @@ -1,6 +1,7 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::current_test_device::current_test_device; use paddler_tests::in_process_cluster_params::InProcessClusterParams; @@ -8,6 +9,7 @@ use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; use paddler_tests::start_in_process_cluster::start_in_process_cluster; +use paddler_tests::token_result_with_producer::TokenResultWithProducer; use paddler_types::agent_desired_model::AgentDesiredModel; use paddler_types::balancer_desired_state::BalancerDesiredState; use paddler_types::generated_token_result::GeneratedTokenResult; @@ -29,8 +31,10 @@ async fn continuous_batch_generates_tokens_with_partial_layer_offload() -> Resul device.inference_parameters_for_full_offload(PARTIAL_GPU_LAYER_COUNT); let cluster = start_in_process_cluster(InProcessClusterParams { - spawn_agent: true, - slots_per_agent: 1, + agent: Some(AgentConfig { + name: "test-agent".to_owned(), + slot_count: 1, + }), desired_state: BalancerDesiredState { chat_template_override: None, inference_parameters, @@ -59,13 +63,16 @@ async fn continuous_batch_generates_tokens_with_partial_layer_offload() -> Resul let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!(token_count > 0); assert!(matches!( collected.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. 
+ }) )); cluster.shutdown().await?; diff --git a/paddler_tests/tests/continuous_batch_long_and_short_prompts_complete_concurrently.rs b/paddler_tests/tests/continuous_batch_long_and_short_prompts_complete_concurrently.rs index 8b02d364..a6f37ab7 100644 --- a/paddler_tests/tests/continuous_batch_long_and_short_prompts_complete_concurrently.rs +++ b/paddler_tests/tests/continuous_batch_long_and_short_prompts_complete_concurrently.rs @@ -1,9 +1,11 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; +use paddler_tests::token_result_with_producer::TokenResultWithProducer; use paddler_types::generated_token_result::GeneratedTokenResult; use paddler_types::request_params::ContinueFromRawPromptParams; use reqwest::Client; @@ -11,7 +13,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_long_and_short_prompts_complete_concurrently() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(2).await?; + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(2)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -45,23 +47,29 @@ async fn continuous_batch_long_and_short_prompts_complete_concurrently() -> Resu let long_tokens = long_collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); let short_tokens = short_collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!(long_tokens > 0); assert!(short_tokens > 0); assert!(matches!( long_collected.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. + }) )); assert!(matches!( short_collected.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. 
+ }) )); cluster.shutdown().await?; diff --git a/paddler_tests/tests/continuous_batch_plain_and_multimodal_run_concurrently.rs b/paddler_tests/tests/continuous_batch_plain_and_multimodal_run_concurrently.rs index 4d58a210..234d7f35 100644 --- a/paddler_tests/tests/continuous_batch_plain_and_multimodal_run_concurrently.rs +++ b/paddler_tests/tests/continuous_batch_plain_and_multimodal_run_concurrently.rs @@ -1,10 +1,12 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; +use paddler_tests::token_result_with_producer::TokenResultWithProducer; use paddler_types::conversation_history::ConversationHistory; use paddler_types::conversation_message::ConversationMessage; use paddler_types::conversation_message_content::ConversationMessageContent; @@ -18,7 +20,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_plain_and_multimodal_run_concurrently() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_5(4, true).await?; + let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(4), true).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -68,6 +70,7 @@ async fn continuous_batch_plain_and_multimodal_run_concurrently() -> Result<()> enable_thinking: false, grammar: None, max_tokens: 32, + parse_tool_calls: false, tools: vec![], }) .await?; @@ -87,7 +90,7 @@ async fn continuous_batch_plain_and_multimodal_run_concurrently() -> Result<()> let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!( @@ -98,12 +101,15 @@ async fn continuous_batch_plain_and_multimodal_run_concurrently() -> Result<()> !collected .token_results .iter() - .any(|result| matches!(result, GeneratedTokenResult::SamplerError(_))), + .any(|result| matches!(result.token_result, GeneratedTokenResult::SamplerError(_))), "concurrent {label} request must not surface a SamplerError" ); assert!(matches!( collected.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. 
+ }) )); } diff --git a/paddler_tests/tests/continuous_batch_rejects_embedding_during_active_generation.rs b/paddler_tests/tests/continuous_batch_rejects_embedding_during_active_generation.rs index 47a2a6bc..099e2baa 100644 --- a/paddler_tests/tests/continuous_batch_rejects_embedding_during_active_generation.rs +++ b/paddler_tests/tests/continuous_batch_rejects_embedding_during_active_generation.rs @@ -2,6 +2,7 @@ use anyhow::Result; use futures_util::StreamExt as _; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_embedding_results::collect_embedding_results; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; @@ -15,7 +16,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_rejects_embedding_during_active_generation() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(2).await?; + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(2)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); diff --git a/paddler_tests/tests/continuous_batch_rejects_second_request_when_only_slot_busy.rs b/paddler_tests/tests/continuous_batch_rejects_second_request_when_only_slot_busy.rs index bb22c61b..2a64475a 100644 --- a/paddler_tests/tests/continuous_batch_rejects_second_request_when_only_slot_busy.rs +++ b/paddler_tests/tests/continuous_batch_rejects_second_request_when_only_slot_busy.rs @@ -3,6 +3,7 @@ use anyhow::Context as _; use anyhow::Result; use futures_util::StreamExt as _; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::agents_status::assert_slots_processing::assert_slots_processing; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::current_test_device::current_test_device; @@ -29,8 +30,10 @@ async fn continuous_batch_rejects_second_request_when_only_slot_busy() -> Result } = qwen3_0_6b(); let mut cluster = start_in_process_cluster(InProcessClusterParams { - spawn_agent: true, - slots_per_agent: 1, + agent: Some(AgentConfig { + name: "test-agent".to_owned(), + slot_count: 1, + }), max_buffered_requests: 0, desired_state: BalancerDesiredState { chat_template_override: None, diff --git a/paddler_tests/tests/continuous_batch_releases_slot_when_client_disconnects.rs b/paddler_tests/tests/continuous_batch_releases_slot_when_client_disconnects.rs index 3f1b4a7f..f078eaab 100644 --- a/paddler_tests/tests/continuous_batch_releases_slot_when_client_disconnects.rs +++ b/paddler_tests/tests/continuous_batch_releases_slot_when_client_disconnects.rs @@ -3,6 +3,7 @@ use anyhow::Context as _; use anyhow::Result; use futures_util::StreamExt as _; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::agents_status::assert_slots_processing::assert_slots_processing; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; @@ -12,7 +13,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_releases_slot_when_client_disconnects() -> Result<()> { - let mut cluster = start_in_process_cluster_with_qwen3(1).await?; + let mut cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; let agent_id = cluster 
.agent_ids diff --git a/paddler_tests/tests/continuous_batch_releases_slots_on_shutdown_with_active_request.rs b/paddler_tests/tests/continuous_batch_releases_slots_on_shutdown_with_active_request.rs index ba3873c7..cc3f9e2c 100644 --- a/paddler_tests/tests/continuous_batch_releases_slots_on_shutdown_with_active_request.rs +++ b/paddler_tests/tests/continuous_batch_releases_slots_on_shutdown_with_active_request.rs @@ -2,6 +2,7 @@ use anyhow::Result; use futures_util::StreamExt as _; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; use paddler_types::request_params::ContinueFromRawPromptParams; @@ -10,7 +11,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_releases_slots_on_shutdown_with_active_request() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(1).await?; + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); diff --git a/paddler_tests/tests/continuous_batch_reuses_slot_after_request_completes.rs b/paddler_tests/tests/continuous_batch_reuses_slot_after_request_completes.rs index 8a5e8a40..0c07db46 100644 --- a/paddler_tests/tests/continuous_batch_reuses_slot_after_request_completes.rs +++ b/paddler_tests/tests/continuous_batch_reuses_slot_after_request_completes.rs @@ -1,9 +1,11 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; +use paddler_tests::token_result_with_producer::TokenResultWithProducer; use paddler_types::generated_token_result::GeneratedTokenResult; use paddler_types::request_params::ContinueFromRawPromptParams; use reqwest::Client; @@ -11,7 +13,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_reuses_slot_after_request_completes() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(1).await?; + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -28,7 +30,10 @@ async fn continuous_batch_reuses_slot_after_request_completes() -> Result<()> { assert!(matches!( first_collected.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. + }) )); let second_stream = inference_client @@ -43,13 +48,16 @@ async fn continuous_batch_reuses_slot_after_request_completes() -> Result<()> { assert!(matches!( second_collected.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. 
+ }) )); let second_token_count = second_collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!( diff --git a/paddler_tests/tests/continuous_batch_serves_four_concurrent_requests.rs b/paddler_tests/tests/continuous_batch_serves_four_concurrent_requests.rs index 67a84098..7dd5cfff 100644 --- a/paddler_tests/tests/continuous_batch_serves_four_concurrent_requests.rs +++ b/paddler_tests/tests/continuous_batch_serves_four_concurrent_requests.rs @@ -1,9 +1,11 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; +use paddler_tests::token_result_with_producer::TokenResultWithProducer; use paddler_types::generated_token_result::GeneratedTokenResult; use paddler_types::request_params::ContinueFromRawPromptParams; use reqwest::Client; @@ -11,7 +13,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_serves_four_concurrent_requests() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(4).await?; + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(4)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -43,7 +45,7 @@ async fn continuous_batch_serves_four_concurrent_requests() -> Result<()> { let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!( @@ -52,7 +54,10 @@ async fn continuous_batch_serves_four_concurrent_requests() -> Result<()> { ); assert!(matches!( collected.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. 
+ }) )); } diff --git a/paddler_tests/tests/continuous_batch_smoke.rs b/paddler_tests/tests/continuous_batch_smoke.rs index 8692b8ad..f91c3f9c 100644 --- a/paddler_tests/tests/continuous_batch_smoke.rs +++ b/paddler_tests/tests/continuous_batch_smoke.rs @@ -2,6 +2,7 @@ use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::current_test_device::current_test_device; use paddler_tests::in_process_cluster_params::InProcessClusterParams; @@ -9,6 +10,7 @@ use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; use paddler_tests::start_in_process_cluster::start_in_process_cluster; +use paddler_tests::token_result_with_producer::TokenResultWithProducer; use paddler_types::agent_desired_model::AgentDesiredModel; use paddler_types::balancer_desired_state::BalancerDesiredState; use paddler_types::generated_token_result::GeneratedTokenResult; @@ -38,8 +40,10 @@ async fn continuous_batch_smoke_generates_tokens() -> Result<()> { }; let cluster = start_in_process_cluster(InProcessClusterParams { - spawn_agent: true, - slots_per_agent: 1, + agent: Some(AgentConfig { + name: "test-agent".to_owned(), + slot_count: 1, + }), desired_state, wait_for_slots_ready: true, ..InProcessClusterParams::default() @@ -64,7 +68,7 @@ async fn continuous_batch_smoke_generates_tokens() -> Result<()> { let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!( @@ -76,7 +80,10 @@ async fn continuous_batch_smoke_generates_tokens() -> Result<()> { assert!( matches!( collected.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. 
+ }) ), "smoke test stream did not terminate with Done" ); diff --git a/paddler_tests/tests/continuous_batch_stop_signal_terminates_generation_before_max_tokens.rs b/paddler_tests/tests/continuous_batch_stop_signal_terminates_generation_before_max_tokens.rs index 09a3a698..307cd479 100644 --- a/paddler_tests/tests/continuous_batch_stop_signal_terminates_generation_before_max_tokens.rs +++ b/paddler_tests/tests/continuous_batch_stop_signal_terminates_generation_before_max_tokens.rs @@ -3,6 +3,7 @@ use anyhow::Context as _; use anyhow::Result; use futures_util::StreamExt as _; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::agents_status::assert_slots_processing::assert_slots_processing; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; @@ -12,7 +13,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_stop_signal_terminates_generation_before_max_tokens() -> Result<()> { - let mut cluster = start_in_process_cluster_with_qwen3(1).await?; + let mut cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; let agent_id = cluster .agent_ids diff --git a/paddler_tests/tests/continuous_batch_stops_at_max_tokens_boundary.rs b/paddler_tests/tests/continuous_batch_stops_at_max_tokens_boundary.rs index fbc12afb..61423f02 100644 --- a/paddler_tests/tests/continuous_batch_stops_at_max_tokens_boundary.rs +++ b/paddler_tests/tests/continuous_batch_stops_at_max_tokens_boundary.rs @@ -1,9 +1,11 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; +use paddler_tests::token_result_with_producer::TokenResultWithProducer; use paddler_types::generated_token_result::GeneratedTokenResult; use paddler_types::request_params::ContinueFromRawPromptParams; use reqwest::Client; @@ -11,7 +13,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_stops_at_max_tokens_boundary() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(1).await?; + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -29,7 +31,7 @@ async fn continuous_batch_stops_at_max_tokens_boundary() -> Result<()> { let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert_eq!( @@ -38,7 +40,10 @@ async fn continuous_batch_stops_at_max_tokens_boundary() -> Result<()> { ); assert!(matches!( collected.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. 
+ }) )); cluster.shutdown().await?; diff --git a/paddler_tests/tests/continuous_batch_stops_generation_when_stop_sender_dropped.rs b/paddler_tests/tests/continuous_batch_stops_generation_when_stop_sender_dropped.rs index eeaea187..445db487 100644 --- a/paddler_tests/tests/continuous_batch_stops_generation_when_stop_sender_dropped.rs +++ b/paddler_tests/tests/continuous_batch_stops_generation_when_stop_sender_dropped.rs @@ -2,9 +2,11 @@ use anyhow::Result; use futures_util::StreamExt as _; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; +use paddler_tests::token_result_with_producer::TokenResultWithProducer; use paddler_types::generated_token_result::GeneratedTokenResult; use paddler_types::request_params::ContinueFromRawPromptParams; use reqwest::Client; @@ -12,7 +14,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_stops_generation_when_stop_sender_dropped() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(2).await?; + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(2)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -44,13 +46,16 @@ async fn continuous_batch_stops_generation_when_stop_sender_dropped() -> Result< assert!(matches!( second_collected.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. + }) )); let second_token_count = second_collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!( diff --git a/paddler_tests/tests/continuous_batch_two_concurrent_multimodal_requests_produce_tokens.rs b/paddler_tests/tests/continuous_batch_two_concurrent_multimodal_requests_produce_tokens.rs index d31ddca0..7de37fd7 100644 --- a/paddler_tests/tests/continuous_batch_two_concurrent_multimodal_requests_produce_tokens.rs +++ b/paddler_tests/tests/continuous_batch_two_concurrent_multimodal_requests_produce_tokens.rs @@ -1,10 +1,12 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; +use paddler_tests::token_result_with_producer::TokenResultWithProducer; use paddler_types::conversation_history::ConversationHistory; use paddler_types::conversation_message::ConversationMessage; use paddler_types::conversation_message_content::ConversationMessageContent; @@ -47,7 +49,7 @@ fn build_multimodal_conversation(image_data_uri: &str) -> ConversationHistory { #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_two_concurrent_multimodal_requests_produce_tokens() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_5(4, true).await?; + let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(4), 
true).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -61,6 +63,7 @@ async fn continuous_batch_two_concurrent_multimodal_requests_produce_tokens() -> enable_thinking: false, grammar: None, max_tokens: 32, + parse_tool_calls: false, tools: vec![], }) .await?; @@ -72,6 +75,7 @@ async fn continuous_batch_two_concurrent_multimodal_requests_produce_tokens() -> enable_thinking: false, grammar: None, max_tokens: 32, + parse_tool_calls: false, tools: vec![], }) .await?; @@ -88,7 +92,7 @@ async fn continuous_batch_two_concurrent_multimodal_requests_produce_tokens() -> let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!( @@ -99,12 +103,15 @@ async fn continuous_batch_two_concurrent_multimodal_requests_produce_tokens() -> !collected .token_results .iter() - .any(|result| matches!(result, GeneratedTokenResult::SamplerError(_))), + .any(|result| matches!(result.token_result, GeneratedTokenResult::SamplerError(_))), "concurrent multimodal request must not surface a SamplerError" ); assert!(matches!( collected.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. + }) )); } diff --git a/paddler_tests/tests/deepseek_r1_distill_llama_8b_internal_endpoint_emits_reasoning_tokens.rs b/paddler_tests/tests/deepseek_r1_distill_llama_8b_internal_endpoint_emits_reasoning_tokens.rs new file mode 100644 index 00000000..96fe5bbe --- /dev/null +++ b/paddler_tests/tests/deepseek_r1_distill_llama_8b_internal_endpoint_emits_reasoning_tokens.rs @@ -0,0 +1,104 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_cluster_with_deepseek_r1_distill_llama_8b::start_in_process_cluster_with_deepseek_r1_distill_llama_8b; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use reqwest::Client; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn deepseek_r1_distill_llama_8b_internal_endpoint_emits_reasoning_tokens() -> Result<()> { + let cluster = + start_in_process_cluster_with_deepseek_r1_distill_llama_8b(AgentConfig::single(1)).await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Text( + "What is two plus two? 
 Think step by step.".to_owned(), + ), + role: "user".to_owned(), + }]), + enable_thinking: true, + grammar: None, + max_tokens: 400, + parse_tool_calls: false, + tools: vec![], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let reasoning_count = collected + .token_results + .iter() + .filter(|result| matches!(result.token_result, GeneratedTokenResult::ReasoningToken(_))) + .count(); + + assert!( + reasoning_count > 0, + "DeepSeek-R1-8B: expected at least one reasoning token from a `<think>` block (got {reasoning_count})" + ); + + let last = collected + .token_results + .last() + .ok_or_else(|| anyhow::anyhow!("no token results received"))?; + let GeneratedTokenResult::Done(summary) = &last.token_result else { + anyhow::bail!("last result was not Done: {last:?}"); + }; + + assert!(summary.usage.prompt_tokens > 0); + assert!(summary.usage.reasoning_tokens > 0); + assert_eq!( + summary.usage.completion_tokens(), + summary.usage.content_tokens + + summary.usage.reasoning_tokens + + summary.usage.undeterminable_tokens + ); + + let reasoning_stream: String = collected + .token_results + .iter() + .filter_map(|result| match &result.token_result { + GeneratedTokenResult::ReasoningToken(piece) => Some(piece.as_str()), + _ => None, + }) + .collect(); + let content_stream: String = collected + .token_results + .iter() + .filter_map(|result| match &result.token_result { + GeneratedTokenResult::ContentToken(piece) => Some(piece.as_str()), + _ => None, + }) + .collect(); + + for forbidden in ["<think>", "</think>"] { + assert!( + !reasoning_stream.contains(forbidden), + "DeepSeek-R1-8B: reasoning stream leaked marker {forbidden:?}; \ + reasoning_stream={reasoning_stream:?}" + ); + assert!( + !content_stream.contains(forbidden), + "DeepSeek-R1-8B: content stream leaked marker {forbidden:?}; \ + content_stream={content_stream:?}" + ); + } + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/agent_returns_error_when_embeddings_disabled_in_parameters.rs b/paddler_tests/tests/endpoint_rejects_embedding_request_when_embeddings_disabled_in_parameters.rs similarity index 64% rename from paddler_tests/tests/agent_returns_error_when_embeddings_disabled_in_parameters.rs rename to paddler_tests/tests/endpoint_rejects_embedding_request_when_embeddings_disabled_in_parameters.rs index 25d26b69..f3245126 100644 --- a/paddler_tests/tests/agent_returns_error_when_embeddings_disabled_in_parameters.rs +++ b/paddler_tests/tests/endpoint_rejects_embedding_request_when_embeddings_disabled_in_parameters.rs @@ -1,9 +1,8 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::collect_embedding_results::collect_embedding_results; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::in_process_cluster_params::InProcessClusterParams; -use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; use paddler_tests::start_in_process_cluster::start_in_process_cluster; @@ -14,15 +13,18 @@ use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; use paddler_types::inference_parameters::InferenceParameters; use paddler_types::request_params::GenerateEmbeddingBatchParams; use reqwest::Client; +use reqwest::StatusCode; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] -async fn agent_returns_error_when_embeddings_disabled_in_parameters() -> Result<()> { +async fn
endpoint_rejects_embedding_request_when_embeddings_disabled_in_parameters() -> Result<()> { let ModelCard { reference, .. } = qwen3_0_6b(); let cluster = start_in_process_cluster(InProcessClusterParams { - spawn_agent: true, - slots_per_agent: 1, + agent: Some(AgentConfig { + name: "test-agent".to_owned(), + slot_count: 1, + }), desired_state: BalancerDesiredState { chat_template_override: None, inference_parameters: InferenceParameters::default(), @@ -35,31 +37,26 @@ async fn agent_returns_error_when_embeddings_disabled_in_parameters() -> Result< }) .await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let inference_base_url = cluster.addresses.inference_base_url()?; + let request_url = inference_base_url.join("api/v1/generate_embedding_batch")?; - let outcome = inference_client - .post_generate_embedding_batch(&GenerateEmbeddingBatchParams { + let response = Client::new() + .post(request_url) + .json(&GenerateEmbeddingBatchParams { input_batch: vec![EmbeddingInputDocument { content: "Hello world".to_owned(), id: "doc-1".to_owned(), }], normalization_method: EmbeddingNormalizationMethod::None, }) - .await; + .send() + .await?; - if let Ok(stream) = outcome { - let collected = collect_embedding_results(stream).await?; - - assert!( - collected.embeddings.is_empty(), - "no embeddings should be returned when embeddings are disabled" - ); - assert!( - !collected.errors.is_empty(), - "stream must report at least one embedding error when embeddings are disabled" - ); - } + assert_eq!( + response.status(), + StatusCode::NOT_IMPLEMENTED, + "endpoint must reject embedding requests with HTTP 501 when embeddings are disabled", + ); cluster.shutdown().await?; diff --git a/paddler_tests/tests/gemma4_internal_endpoint_emits_reasoning_tokens.rs b/paddler_tests/tests/gemma4_internal_endpoint_emits_reasoning_tokens.rs new file mode 100644 index 00000000..1f7df4a6 --- /dev/null +++ b/paddler_tests/tests/gemma4_internal_endpoint_emits_reasoning_tokens.rs @@ -0,0 +1,103 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_cluster_with_gemma_4::start_in_process_cluster_with_gemma_4; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use reqwest::Client; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn gemma4_internal_endpoint_emits_reasoning_tokens() -> Result<()> { + let cluster = start_in_process_cluster_with_gemma_4(AgentConfig::single(1)).await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Text( + "What is two plus two? 
Think step by step.".to_owned(), + ), + role: "user".to_owned(), + }]), + enable_thinking: true, + grammar: None, + max_tokens: 800, + parse_tool_calls: false, + tools: vec![], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let reasoning_count = collected + .token_results + .iter() + .filter(|result| matches!(result.token_result, GeneratedTokenResult::ReasoningToken(_))) + .count(); + + assert!( + reasoning_count > 0, + "Gemma 4: expected at least one reasoning token from a `<|channel>thought` block (got {reasoning_count})" + ); + + let last = collected + .token_results + .last() + .ok_or_else(|| anyhow::anyhow!("no token results received"))?; + let GeneratedTokenResult::Done(summary) = &last.token_result else { + anyhow::bail!("last result was not Done: {last:?}"); + }; + + assert!(summary.usage.prompt_tokens > 0); + assert!(summary.usage.reasoning_tokens > 0); + assert_eq!( + summary.usage.completion_tokens(), + summary.usage.content_tokens + + summary.usage.reasoning_tokens + + summary.usage.undeterminable_tokens + ); + + let reasoning_stream: String = collected + .token_results + .iter() + .filter_map(|result| match &result.token_result { + GeneratedTokenResult::ReasoningToken(piece) => Some(piece.as_str()), + _ => None, + }) + .collect(); + let content_stream: String = collected + .token_results + .iter() + .filter_map(|result| match &result.token_result { + GeneratedTokenResult::ContentToken(piece) => Some(piece.as_str()), + _ => None, + }) + .collect(); + + for forbidden in ["<|channel>thought", ""] { + assert!( + !reasoning_stream.contains(forbidden), + "Gemma 4: reasoning stream leaked marker {forbidden:?}; \ + reasoning_stream={reasoning_stream:?}" + ); + assert!( + !content_stream.contains(forbidden), + "Gemma 4: content stream leaked marker {forbidden:?}; \ + content_stream={content_stream:?}" + ); + } + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/gemma4_internal_endpoint_emits_reasoning_tokens_for_image_request.rs b/paddler_tests/tests/gemma4_internal_endpoint_emits_reasoning_tokens_for_image_request.rs new file mode 100644 index 00000000..235c4679 --- /dev/null +++ b/paddler_tests/tests/gemma4_internal_endpoint_emits_reasoning_tokens_for_image_request.rs @@ -0,0 +1,81 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; +use paddler_tests::start_in_process_cluster_with_gemma_4_and_mmproj::start_in_process_cluster_with_gemma_4_and_mmproj; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::conversation_message_content_part::ConversationMessageContentPart; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::image_url::ImageUrl; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use reqwest::Client; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn gemma4_internal_endpoint_emits_reasoning_tokens_for_image_request() -> Result<()> { + let cluster = 
start_in_process_cluster_with_gemma_4_and_mmproj(AgentConfig::single(1)).await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let image_data_uri = load_test_image_data_uri()?; + + let conversation_history = ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Parts(vec![ + ConversationMessageContentPart::ImageUrl { + image_url: ImageUrl { + url: image_data_uri, + }, + }, + ConversationMessageContentPart::Text { + text: "What animals do you see in this image? Think step by step.".to_owned(), + }, + ]), + role: "user".to_owned(), + }]); + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history, + enable_thinking: true, + grammar: None, + max_tokens: 200, + parse_tool_calls: false, + tools: vec![], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let reasoning_count = collected + .token_results + .iter() + .filter(|result| matches!(result.token_result, GeneratedTokenResult::ReasoningToken(_))) + .count(); + + assert!( + reasoning_count > 0, + "Gemma 4: expected at least one reasoning token from a `<|channel>thought` block when an image is attached (got {reasoning_count})" + ); + + let last = collected + .token_results + .last() + .ok_or_else(|| anyhow::anyhow!("no token results received"))?; + let GeneratedTokenResult::Done(summary) = &last.token_result else { + anyhow::bail!("last result was not Done: {last:?}"); + }; + + assert!(summary.usage.reasoning_tokens > 0); + assert!(summary.usage.input_image_tokens > 0); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/gemma4_internal_endpoint_emits_tool_call_parsed_event.rs b/paddler_tests/tests/gemma4_internal_endpoint_emits_tool_call_parsed_event.rs new file mode 100644 index 00000000..4b2dccdc --- /dev/null +++ b/paddler_tests/tests/gemma4_internal_endpoint_emits_tool_call_parsed_event.rs @@ -0,0 +1,93 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_cluster_with_gemma_4::start_in_process_cluster_with_gemma_4; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; +use reqwest::Client; +use serde_json::Map; +use serde_json::Value; + 
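+// Exercises end-to-end tool-call parsing: the request advertises a single
+// `get_weather` tool (a JSON schema with one required string property and no
+// additional properties) and sets `parse_tool_calls: true`; the assertions
+// below require the stream to carry a ToolCallParsed event naming that tool.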
+#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn gemma4_internal_endpoint_emits_tool_call_parsed_event() -> Result<()> { + let cluster = start_in_process_cluster_with_gemma_4(AgentConfig::single(1)).await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let mut location_properties = Map::new(); + location_properties.insert( + "location".to_owned(), + serde_json::json!({"type": "string", "description": "The city name"}), + ); + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Text( + "What is the weather in Paris? Use the get_weather tool to find out." + .to_owned(), + ), + role: "user".to_owned(), + }]), + enable_thinking: false, + grammar: None, + max_tokens: 400, + parse_tool_calls: true, + tools: vec![Tool::Function(FunctionCall { + function: Function { + name: "get_weather".to_owned(), + description: "Get the current weather for a location".to_owned(), + parameters: Parameters::Schema(ValidatedParametersSchema { + schema_type: "object".to_owned(), + properties: Some(location_properties), + required: Some(vec!["location".to_owned()]), + additional_properties: Some(Value::Bool(false)), + }), + }, + })], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let parsed_events: Vec<&Vec<_>> = collected + .token_results + .iter() + .filter_map(|event| match &event.token_result { + GeneratedTokenResult::ToolCallParsed(parsed) => Some(parsed), + _ => None, + }) + .collect(); + + assert!( + !parsed_events.is_empty(), + "Gemma 4: expected at least one ToolCallParsed event; got tokens:\n{}", + collected.text + ); + + let first_call = parsed_events + .iter() + .flat_map(|calls| calls.iter()) + .next() + .ok_or_else(|| anyhow::anyhow!("no parsed tool calls in any event"))?; + + assert_eq!(first_call.name, "get_weather"); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/glm_4_7_flash_internal_endpoint_emits_reasoning_tokens.rs b/paddler_tests/tests/glm_4_7_flash_internal_endpoint_emits_reasoning_tokens.rs new file mode 100644 index 00000000..896ade28 --- /dev/null +++ b/paddler_tests/tests/glm_4_7_flash_internal_endpoint_emits_reasoning_tokens.rs @@ -0,0 +1,103 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_cluster_with_glm_4_7_flash::start_in_process_cluster_with_glm_4_7_flash; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use reqwest::Client; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn glm_4_7_flash_internal_endpoint_emits_reasoning_tokens() -> Result<()> { + let cluster =
start_in_process_cluster_with_glm_4_7_flash(AgentConfig::single(1)).await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Text( + "What is two plus two? Think step by step.".to_owned(), + ), + role: "user".to_owned(), + }]), + enable_thinking: true, + grammar: None, + max_tokens: 400, + parse_tool_calls: false, + tools: vec![], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let reasoning_count = collected + .token_results + .iter() + .filter(|result| matches!(result.token_result, GeneratedTokenResult::ReasoningToken(_))) + .count(); + + assert!( + reasoning_count > 0, + "GLM-4.7: expected at least one reasoning token from a `<think>` block (got {reasoning_count})" + ); + + let last = collected + .token_results + .last() + .ok_or_else(|| anyhow::anyhow!("no token results received"))?; + let GeneratedTokenResult::Done(summary) = &last.token_result else { + anyhow::bail!("last result was not Done: {last:?}"); + }; + + assert!(summary.usage.prompt_tokens > 0); + assert!(summary.usage.reasoning_tokens > 0); + assert_eq!( + summary.usage.completion_tokens(), + summary.usage.content_tokens + + summary.usage.reasoning_tokens + + summary.usage.undeterminable_tokens + ); + + let reasoning_stream: String = collected + .token_results + .iter() + .filter_map(|result| match &result.token_result { + GeneratedTokenResult::ReasoningToken(piece) => Some(piece.as_str()), + _ => None, + }) + .collect(); + let content_stream: String = collected + .token_results + .iter() + .filter_map(|result| match &result.token_result { + GeneratedTokenResult::ContentToken(piece) => Some(piece.as_str()), + _ => None, + }) + .collect(); + + for forbidden in ["<think>", "</think>"] { + assert!( + !reasoning_stream.contains(forbidden), + "GLM-4.7: reasoning stream leaked marker {forbidden:?}; \ + reasoning_stream={reasoning_stream:?}" + ); + assert!( + !content_stream.contains(forbidden), + "GLM-4.7: content stream leaked marker {forbidden:?}; \ + content_stream={content_stream:?}" + ); + } + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/glm_4_7_flash_internal_endpoint_emits_tool_call_parsed_event.rs b/paddler_tests/tests/glm_4_7_flash_internal_endpoint_emits_tool_call_parsed_event.rs new file mode 100644 index 00000000..eb5d9159 --- /dev/null +++ b/paddler_tests/tests/glm_4_7_flash_internal_endpoint_emits_tool_call_parsed_event.rs @@ -0,0 +1,93 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_cluster_with_glm_4_7_flash::start_in_process_cluster_with_glm_4_7_flash; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use
paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; +use reqwest::Client; +use serde_json::Map; +use serde_json::Value; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn glm_4_7_flash_internal_endpoint_emits_tool_call_parsed_event() -> Result<()> { + let cluster = start_in_process_cluster_with_glm_4_7_flash(AgentConfig::single(1)).await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let mut location_properties = Map::new(); + location_properties.insert( + "location".to_owned(), + serde_json::json!({"type": "string", "description": "The city name"}), + ); + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Text( + "What is the weather in Paris? Use the get_weather tool to find out." + .to_owned(), + ), + role: "user".to_owned(), + }]), + enable_thinking: true, + grammar: None, + max_tokens: 400, + parse_tool_calls: true, + tools: vec![Tool::Function(FunctionCall { + function: Function { + name: "get_weather".to_owned(), + description: "Get the current weather for a location".to_owned(), + parameters: Parameters::Schema(ValidatedParametersSchema { + schema_type: "object".to_owned(), + properties: Some(location_properties), + required: Some(vec!["location".to_owned()]), + additional_properties: Some(Value::Bool(false)), + }), + }, + })], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let parsed_events: Vec<&Vec<_>> = collected + .token_results + .iter() + .filter_map(|event| match &event.token_result { + GeneratedTokenResult::ToolCallParsed(parsed) => Some(parsed), + _ => None, + }) + .collect(); + + assert!( + !parsed_events.is_empty(), + "GLM-4.7: expected at least one ToolCallParsed event; got tokens:\n{}", + collected.text + ); + + let first_call = parsed_events + .iter() + .flat_map(|calls| calls.iter()) + .next() + .ok_or_else(|| anyhow::anyhow!("no parsed tool calls in any event"))?; + + assert_eq!(first_call.name, "get_weather"); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/harness_agents_watcher.rs b/paddler_tests/tests/harness_agents_watcher.rs index e1c805a0..fac8d7d9 100644 --- a/paddler_tests/tests/harness_agents_watcher.rs +++ b/paddler_tests/tests/harness_agents_watcher.rs @@ -101,7 +101,7 @@ async fn wait_for_slots_ready_includes_agent_id_in_error() { let fixture = stream::iter(vec![Ok(snapshot)]); let mut watcher = AgentsStreamWatcher::from_stream(Box::pin(fixture)); - let outcome = watcher.wait_for_slots_ready(1, 1).await; + let outcome = watcher.wait_for_slots_ready(&[1]).await; assert!( outcome.is_err(), diff --git
a/paddler_tests/tests/harness_in_process_cluster_shutdown.rs b/paddler_tests/tests/harness_in_process_cluster_shutdown.rs index 8a023ce1..7b619834 100644 --- a/paddler_tests/tests/harness_in_process_cluster_shutdown.rs +++ b/paddler_tests/tests/harness_in_process_cluster_shutdown.rs @@ -5,7 +5,7 @@ use paddler_tests::start_in_process_cluster::start_in_process_cluster; #[tokio::test(flavor = "multi_thread")] async fn empty_cluster_starts_and_shuts_down_without_timeout() -> Result<()> { let cluster = start_in_process_cluster(InProcessClusterParams { - spawn_agent: false, + agent: None, wait_for_slots_ready: false, ..InProcessClusterParams::default() }) @@ -19,7 +19,6 @@ async fn empty_cluster_starts_and_shuts_down_without_timeout() -> Result<()> { #[tokio::test(flavor = "multi_thread")] async fn single_agent_registers_and_shuts_down_without_timeout() -> Result<()> { let cluster = start_in_process_cluster(InProcessClusterParams { - spawn_agent: true, wait_for_slots_ready: false, ..InProcessClusterParams::default() }) diff --git a/paddler_tests/tests/harness_subprocess_cluster_shutdown.rs b/paddler_tests/tests/harness_subprocess_cluster_shutdown.rs index 6c363ccb..c8a4e6a2 100644 --- a/paddler_tests/tests/harness_subprocess_cluster_shutdown.rs +++ b/paddler_tests/tests/harness_subprocess_cluster_shutdown.rs @@ -1,13 +1,14 @@ #![cfg(feature = "tests_that_use_compiled_paddler")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; #[tokio::test(flavor = "multi_thread")] async fn empty_subprocess_cluster_starts_and_exits_after_sigterm() -> Result<()> { let cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), wait_for_slots_ready: false, ..SubprocessClusterParams::default() }) @@ -21,7 +22,7 @@ async fn empty_subprocess_cluster_starts_and_exits_after_sigterm() -> Result<()> #[tokio::test(flavor = "multi_thread")] async fn single_subprocess_agent_registers_and_exits_after_sigterm() -> Result<()> { let cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 1, + agents: AgentConfig::uniform(1, 4), wait_for_slots_ready: false, ..SubprocessClusterParams::default() }) diff --git a/paddler_tests/tests/management_agents_stream_yields_initial_snapshot.rs b/paddler_tests/tests/management_agents_stream_yields_initial_snapshot.rs index 3e64ff08..3593057f 100644 --- a/paddler_tests/tests/management_agents_stream_yields_initial_snapshot.rs +++ b/paddler_tests/tests/management_agents_stream_yields_initial_snapshot.rs @@ -6,12 +6,13 @@ use anyhow::Context as _; use anyhow::Result; use futures_util::StreamExt as _; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn management_agents_stream_yields_initial_snapshot() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(2, 1).await?; + let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let mut stream = cluster .paddler_client diff --git a/paddler_tests/tests/management_buffered_requests_stream_yields_initial_snapshot.rs b/paddler_tests/tests/management_buffered_requests_stream_yields_initial_snapshot.rs index 35a02d31..43028a6c 100644 --- 
a/paddler_tests/tests/management_buffered_requests_stream_yields_initial_snapshot.rs +++ b/paddler_tests/tests/management_buffered_requests_stream_yields_initial_snapshot.rs @@ -9,7 +9,7 @@ use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; #[tokio::test(flavor = "multi_thread")] async fn management_buffered_requests_stream_yields_initial_snapshot() -> Result<()> { let cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), wait_for_slots_ready: false, ..SubprocessClusterParams::default() }) diff --git a/paddler_tests/tests/management_health_endpoint_returns_ok.rs b/paddler_tests/tests/management_health_endpoint_returns_ok.rs index 2fc1d631..367258f8 100644 --- a/paddler_tests/tests/management_health_endpoint_returns_ok.rs +++ b/paddler_tests/tests/management_health_endpoint_returns_ok.rs @@ -8,7 +8,7 @@ use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; #[tokio::test(flavor = "multi_thread")] async fn management_health_endpoint_returns_ok() -> Result<()> { let cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), wait_for_slots_ready: false, ..SubprocessClusterParams::default() }) diff --git a/paddler_tests/tests/management_metrics_endpoint_exposes_prometheus_gauges.rs b/paddler_tests/tests/management_metrics_endpoint_exposes_prometheus_gauges.rs index ca31e92a..98ecc865 100644 --- a/paddler_tests/tests/management_metrics_endpoint_exposes_prometheus_gauges.rs +++ b/paddler_tests/tests/management_metrics_endpoint_exposes_prometheus_gauges.rs @@ -8,7 +8,7 @@ use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; #[tokio::test(flavor = "multi_thread")] async fn management_metrics_endpoint_exposes_prometheus_gauges() -> Result<()> { let cluster = start_subprocess_cluster(SubprocessClusterParams { - agent_count: 0, + agents: Vec::new(), wait_for_slots_ready: false, ..SubprocessClusterParams::default() }) diff --git a/paddler_tests/tests/management_reports_zero_download_progress_after_load_complete.rs b/paddler_tests/tests/management_reports_zero_download_progress_after_load_complete.rs index 6248afc1..22de1ecd 100644 --- a/paddler_tests/tests/management_reports_zero_download_progress_after_load_complete.rs +++ b/paddler_tests/tests/management_reports_zero_download_progress_after_load_complete.rs @@ -5,12 +5,13 @@ use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn management_reports_zero_download_progress_after_load_complete() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(2, 1).await?; + let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let snapshot = cluster .paddler_client diff --git a/paddler_tests/tests/management_returns_model_metadata_for_loaded_agent.rs b/paddler_tests/tests/management_returns_model_metadata_for_loaded_agent.rs index 5d17a401..bdd49fc7 100644 --- a/paddler_tests/tests/management_returns_model_metadata_for_loaded_agent.rs +++ b/paddler_tests/tests/management_returns_model_metadata_for_loaded_agent.rs @@ -5,12 +5,13 @@ use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; 
#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn management_returns_model_metadata_for_loaded_agent() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(2, 1).await?; + let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let agent_id = cluster .agent_ids diff --git a/paddler_tests/tests/management_two_agents_stream_subscribers_receive_slot_usage_updates.rs b/paddler_tests/tests/management_two_agents_stream_subscribers_receive_slot_usage_updates.rs index 6b1ab8ab..b5a7ef4a 100644 --- a/paddler_tests/tests/management_two_agents_stream_subscribers_receive_slot_usage_updates.rs +++ b/paddler_tests/tests/management_two_agents_stream_subscribers_receive_slot_usage_updates.rs @@ -2,6 +2,7 @@ use anyhow::Context as _; use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::agents_status::assert_slots_processing::assert_slots_processing; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::current_test_device::current_test_device; @@ -36,8 +37,10 @@ async fn management_two_agents_stream_subscribers_receive_slot_usage_updates() - }; let mut cluster = start_in_process_cluster(InProcessClusterParams { - spawn_agent: true, - slots_per_agent: 1, + agent: Some(AgentConfig { + name: "test-agent".to_owned(), + slot_count: 1, + }), desired_state, wait_for_slots_ready: true, ..InProcessClusterParams::default() diff --git a/paddler_tests/tests/mistral3_internal_endpoint_emits_reasoning_tokens.rs b/paddler_tests/tests/mistral3_internal_endpoint_emits_reasoning_tokens.rs new file mode 100644 index 00000000..b4e3ad10 --- /dev/null +++ b/paddler_tests/tests/mistral3_internal_endpoint_emits_reasoning_tokens.rs @@ -0,0 +1,103 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_cluster_with_ministral_3::start_in_process_cluster_with_ministral_3; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use reqwest::Client; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn mistral3_internal_endpoint_emits_reasoning_tokens() -> Result<()> { + let cluster = start_in_process_cluster_with_ministral_3(AgentConfig::single(1)).await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Text( + "What is two plus two? 
Think step by step.".to_owned(), + ), + role: "user".to_owned(), + }]), + enable_thinking: true, + grammar: None, + max_tokens: 800, + parse_tool_calls: false, + tools: vec![], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let reasoning_count = collected + .token_results + .iter() + .filter(|result| matches!(result.token_result, GeneratedTokenResult::ReasoningToken(_))) + .count(); + + assert!( + reasoning_count > 0, + "Mistral 3: expected at least one reasoning token from a [THINK]-emitting model (got {reasoning_count})" + ); + + let last = collected + .token_results + .last() + .ok_or_else(|| anyhow::anyhow!("no token results received"))?; + let GeneratedTokenResult::Done(summary) = &last.token_result else { + anyhow::bail!("last result was not Done: {last:?}"); + }; + + assert!(summary.usage.prompt_tokens > 0); + assert!(summary.usage.reasoning_tokens > 0); + assert_eq!( + summary.usage.completion_tokens(), + summary.usage.content_tokens + + summary.usage.reasoning_tokens + + summary.usage.undeterminable_tokens + ); + + let reasoning_stream: String = collected + .token_results + .iter() + .filter_map(|result| match &result.token_result { + GeneratedTokenResult::ReasoningToken(piece) => Some(piece.as_str()), + _ => None, + }) + .collect(); + let content_stream: String = collected + .token_results + .iter() + .filter_map(|result| match &result.token_result { + GeneratedTokenResult::ContentToken(piece) => Some(piece.as_str()), + _ => None, + }) + .collect(); + + for forbidden in ["[THINK]", "[/THINK]"] { + assert!( + !reasoning_stream.contains(forbidden), + "Mistral 3: reasoning stream leaked marker {forbidden:?}; \ + reasoning_stream={reasoning_stream:?}" + ); + assert!( + !content_stream.contains(forbidden), + "Mistral 3: content stream leaked marker {forbidden:?}; \ + content_stream={content_stream:?}" + ); + } + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/mistral3_internal_endpoint_emits_reasoning_tokens_for_image_request.rs b/paddler_tests/tests/mistral3_internal_endpoint_emits_reasoning_tokens_for_image_request.rs new file mode 100644 index 00000000..7a5764fe --- /dev/null +++ b/paddler_tests/tests/mistral3_internal_endpoint_emits_reasoning_tokens_for_image_request.rs @@ -0,0 +1,82 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; +use paddler_tests::start_in_process_cluster_with_ministral_3_and_mmproj::start_in_process_cluster_with_ministral_3_and_mmproj; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::conversation_message_content_part::ConversationMessageContentPart; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::image_url::ImageUrl; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use reqwest::Client; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn mistral3_internal_endpoint_emits_reasoning_tokens_for_image_request() -> Result<()> { + let cluster = + 
start_in_process_cluster_with_ministral_3_and_mmproj(AgentConfig::single(1)).await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let image_data_uri = load_test_image_data_uri()?; + + let conversation_history = ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Parts(vec![ + ConversationMessageContentPart::ImageUrl { + image_url: ImageUrl { + url: image_data_uri, + }, + }, + ConversationMessageContentPart::Text { + text: "What animals do you see in this image? Think step by step.".to_owned(), + }, + ]), + role: "user".to_owned(), + }]); + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history, + enable_thinking: true, + grammar: None, + max_tokens: 400, + parse_tool_calls: false, + tools: vec![], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let reasoning_count = collected + .token_results + .iter() + .filter(|result| matches!(result.token_result, GeneratedTokenResult::ReasoningToken(_))) + .count(); + + assert!( + reasoning_count > 0, + "Mistral 3: expected at least one reasoning token from a `[THINK]` block when an image is attached (got {reasoning_count})" + ); + + let last = collected + .token_results + .last() + .ok_or_else(|| anyhow::anyhow!("no token results received"))?; + let GeneratedTokenResult::Done(summary) = &last.token_result else { + anyhow::bail!("last result was not Done: {last:?}"); + }; + + assert!(summary.usage.reasoning_tokens > 0); + assert!(summary.usage.input_image_tokens > 0); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/mistral3_internal_endpoint_emits_tool_call_parsed_event.rs b/paddler_tests/tests/mistral3_internal_endpoint_emits_tool_call_parsed_event.rs new file mode 100644 index 00000000..bb8912b6 --- /dev/null +++ b/paddler_tests/tests/mistral3_internal_endpoint_emits_tool_call_parsed_event.rs @@ -0,0 +1,93 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_cluster_with_ministral_3::start_in_process_cluster_with_ministral_3; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; +use reqwest::Client; +use serde_json::Map; +use serde_json::Value; + 
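+// Same scenario as the Gemma 4 and GLM-4.7 tool-call tests: with
+// `parse_tool_calls` enabled, Ministral 3 must surface the `get_weather`
+// call as a structured ToolCallParsed event in the token stream.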
+#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn mistral3_internal_endpoint_emits_tool_call_parsed_event() -> Result<()> { + let cluster = start_in_process_cluster_with_ministral_3(AgentConfig::single(1)).await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let mut location_properties = Map::new(); + location_properties.insert( + "location".to_owned(), + serde_json::json!({"type": "string", "description": "The city name"}), + ); + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Text( + "What is the weather in Paris? Use the get_weather tool to find out." + .to_owned(), + ), + role: "user".to_owned(), + }]), + enable_thinking: false, + grammar: None, + max_tokens: 400, + parse_tool_calls: true, + tools: vec![Tool::Function(FunctionCall { + function: Function { + name: "get_weather".to_owned(), + description: "Get the current weather for a location".to_owned(), + parameters: Parameters::Schema(ValidatedParametersSchema { + schema_type: "object".to_owned(), + properties: Some(location_properties), + required: Some(vec!["location".to_owned()]), + additional_properties: Some(Value::Bool(false)), + }), + }, + })], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let parsed_events: Vec<&Vec<_>> = collected + .token_results + .iter() + .filter_map(|event| match &event.token_result { + GeneratedTokenResult::ToolCallParsed(parsed) => Some(parsed), + _ => None, + }) + .collect(); + + assert!( + !parsed_events.is_empty(), + "Mistral 3: expected at least one ToolCallParsed event; got tokens:\n{}", + collected.text + ); + + let first_call = parsed_events + .iter() + .flat_map(|calls| calls.iter()) + .next() + .ok_or_else(|| anyhow::anyhow!("no parsed tool calls in any event"))?; + + assert_eq!(first_call.name, "get_weather"); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/paddler_subprocess_cluster_does_not_leak_sigurg_to_parent_process.rs b/paddler_tests/tests/paddler_subprocess_cluster_does_not_leak_sigurg_to_parent_process.rs new file mode 100644 index 00000000..95c0effe --- /dev/null +++ b/paddler_tests/tests/paddler_subprocess_cluster_does_not_leak_sigurg_to_parent_process.rs @@ -0,0 +1,102 @@ +#![cfg(all( + unix, + feature = "tests_that_use_compiled_paddler", + feature = "tests_that_use_llms" +))] + +use std::sync::Arc; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; + +use anyhow::Context as _; +use anyhow::Result; +use nix::sys::signal::Signal; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_embedding_results::collect_embedding_results; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; +use paddler_tests::start_subprocess_cluster_with_qwen3_embedding::start_subprocess_cluster_with_qwen3_embedding; +use paddler_types::embedding_input_document::EmbeddingInputDocument; +use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_types::inference_parameters::InferenceParameters; +use paddler_types::request_params::GenerateEmbeddingBatchParams; +use reqwest::Client; +use tokio::signal::unix::SignalKind; +use
tokio::signal::unix::signal; +use tokio_util::sync::CancellationToken; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn paddler_subprocess_cluster_does_not_leak_sigurg_to_parent_process() -> Result<()> { + let observed_sigurg_count = Arc::new(AtomicUsize::new(0)); + let observer_shutdown = CancellationToken::new(); + + let observer_count = observed_sigurg_count.clone(); + let observer_token = observer_shutdown.clone(); + let mut sigurg_stream = signal(SignalKind::from_raw(Signal::SIGURG as i32)) + .context("failed to install SIGURG observer on the test process")?; + + let observer_handle = tokio::spawn(async move { + loop { + tokio::select! { + () = observer_token.cancelled() => break, + signal_event = sigurg_stream.recv() => match signal_event { + Some(()) => { + observer_count.fetch_add(1, Ordering::SeqCst); + } + None => break, + }, + } + } + }); + + let cluster = start_subprocess_cluster_with_qwen3_embedding(Qwen3EmbeddingClusterParams { + agents: AgentConfig::uniform(2, 2), + inference_parameters: InferenceParameters { + enable_embeddings: true, + ..InferenceParameters::default() + }, + ..Qwen3EmbeddingClusterParams::default() + }) + .await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let input_batch: Vec<EmbeddingInputDocument> = (0..4) + .map(|document_index| EmbeddingInputDocument { + content: format!("SIGURG regression document number {document_index:02}"), + id: format!("doc-{document_index}"), + }) + .collect(); + let params = GenerateEmbeddingBatchParams { + input_batch, + normalization_method: EmbeddingNormalizationMethod::None, + }; + + let stream = inference_client + .post_generate_embedding_batch(&params) + .await?; + let collected = collect_embedding_results(stream).await?; + + assert_eq!(collected.embeddings.len(), 4); + assert!(collected.errors.is_empty()); + + cluster.shutdown().await?; + + observer_shutdown.cancel(); + observer_handle + .await + .context("SIGURG observer task panicked")?; + + let final_sigurg_count = observed_sigurg_count.load(Ordering::SeqCst); + + assert_eq!( + final_sigurg_count, 0, + "paddler subprocesses leaked {final_sigurg_count} SIGURG signals to the parent process; \ + this would kill bash test harness loops that rely on SIGURG's default ignore action being honored. \ + The observer ran throughout cluster startup, an embedding inference, and cluster shutdown."
+ ); + + Ok(()) +} diff --git a/paddler_tests/tests/qwen25vl_generates_tokens_from_image_input.rs b/paddler_tests/tests/qwen25vl_generates_tokens_from_image_input.rs index 5aab599d..0d2919f7 100644 --- a/paddler_tests/tests/qwen25vl_generates_tokens_from_image_input.rs +++ b/paddler_tests/tests/qwen25vl_generates_tokens_from_image_input.rs @@ -1,10 +1,12 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; use paddler_tests::start_in_process_cluster_with_qwen2_5_vl::start_in_process_cluster_with_qwen2_5_vl; +use paddler_tests::token_result_with_producer::TokenResultWithProducer; use paddler_types::conversation_history::ConversationHistory; use paddler_types::conversation_message::ConversationMessage; use paddler_types::conversation_message_content::ConversationMessageContent; @@ -17,7 +19,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen25vl_generates_tokens_from_image_input() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen2_5_vl(1).await?; + let cluster = start_in_process_cluster_with_qwen2_5_vl(AgentConfig::single(1)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -45,6 +47,7 @@ async fn qwen25vl_generates_tokens_from_image_input() -> Result<()> { enable_thinking: false, grammar: None, max_tokens: 200, + parse_tool_calls: false, tools: vec![], }) .await?; @@ -54,13 +57,16 @@ async fn qwen25vl_generates_tokens_from_image_input() -> Result<()> { let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!(token_count > 0); assert!(matches!( collected.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. 
+ }) )); cluster.shutdown().await?; diff --git a/paddler_tests/tests/qwen35_generates_tokens_for_long_system_and_user_prompt.rs b/paddler_tests/tests/qwen35_generates_tokens_for_long_system_and_user_prompt.rs index dd8ad4e8..aee344cf 100644 --- a/paddler_tests/tests/qwen35_generates_tokens_for_long_system_and_user_prompt.rs +++ b/paddler_tests/tests/qwen35_generates_tokens_for_long_system_and_user_prompt.rs @@ -1,9 +1,11 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; +use paddler_tests::token_result_with_producer::TokenResultWithProducer; use paddler_types::conversation_history::ConversationHistory; use paddler_types::conversation_message::ConversationMessage; use paddler_types::conversation_message_content::ConversationMessageContent; @@ -51,7 +53,7 @@ fn build_long_link_list() -> String { #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen35_generates_tokens_for_long_system_and_user_prompt() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_5(1, false).await?; + let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(1), false).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -81,6 +83,7 @@ async fn qwen35_generates_tokens_for_long_system_and_user_prompt() -> Result<()> enable_thinking: false, grammar: None, max_tokens: 512, + parse_tool_calls: false, tools: vec![], }) .await?; @@ -90,13 +93,16 @@ async fn qwen35_generates_tokens_for_long_system_and_user_prompt() -> Result<()> let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!(token_count > 0); assert!(matches!( collected.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. 
+ }) )); cluster.shutdown().await?; diff --git a/paddler_tests/tests/qwen35_generation_stops_at_eog_before_max_tokens.rs b/paddler_tests/tests/qwen35_generation_stops_at_eog_before_max_tokens.rs index 17a227df..d9b53524 100644 --- a/paddler_tests/tests/qwen35_generation_stops_at_eog_before_max_tokens.rs +++ b/paddler_tests/tests/qwen35_generation_stops_at_eog_before_max_tokens.rs @@ -1,9 +1,11 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; +use paddler_tests::token_result_with_producer::TokenResultWithProducer; use paddler_types::conversation_history::ConversationHistory; use paddler_types::conversation_message::ConversationMessage; use paddler_types::conversation_message_content::ConversationMessageContent; @@ -14,7 +16,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen35_generation_stops_at_eog_before_max_tokens() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_5(1, false).await?; + let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(1), false).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -29,6 +31,7 @@ async fn qwen35_generation_stops_at_eog_before_max_tokens() -> Result<()> { enable_thinking: false, grammar: None, max_tokens: 500, + parse_tool_calls: false, tools: vec![], }) .await?; @@ -38,7 +41,7 @@ async fn qwen35_generation_stops_at_eog_before_max_tokens() -> Result<()> { let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!(token_count > 0); @@ -48,7 +51,10 @@ async fn qwen35_generation_stops_at_eog_before_max_tokens() -> Result<()> { ); assert!(matches!( collected.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. 
+ }) )); cluster.shutdown().await?; diff --git a/paddler_tests/tests/qwen35_internal_endpoint_emits_reasoning_tokens_for_image_request.rs b/paddler_tests/tests/qwen35_internal_endpoint_emits_reasoning_tokens_for_image_request.rs new file mode 100644 index 00000000..2dbfe3e3 --- /dev/null +++ b/paddler_tests/tests/qwen35_internal_endpoint_emits_reasoning_tokens_for_image_request.rs @@ -0,0 +1,81 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; +use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::conversation_message_content_part::ConversationMessageContentPart; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::image_url::ImageUrl; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use reqwest::Client; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn qwen35_internal_endpoint_emits_reasoning_tokens_for_image_request() -> Result<()> { + let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(1), true).await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let image_data_uri = load_test_image_data_uri()?; + + let conversation_history = ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Parts(vec![ + ConversationMessageContentPart::ImageUrl { + image_url: ImageUrl { + url: image_data_uri, + }, + }, + ConversationMessageContentPart::Text { + text: "What animals do you see in this image? 
Think step by step.".to_owned(), + }, + ]), + role: "user".to_owned(), + }]); + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history, + enable_thinking: true, + grammar: None, + max_tokens: 200, + parse_tool_calls: false, + tools: vec![], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let reasoning_count = collected + .token_results + .iter() + .filter(|result| matches!(result.token_result, GeneratedTokenResult::ReasoningToken(_))) + .count(); + + assert!( + reasoning_count > 0, + "Qwen 3.5: expected at least one reasoning token from a `` block when an image is attached (got {reasoning_count})" + ); + + let last = collected + .token_results + .last() + .ok_or_else(|| anyhow::anyhow!("no token results received"))?; + let GeneratedTokenResult::Done(summary) = &last.token_result else { + anyhow::bail!("last result was not Done: {last:?}"); + }; + + assert!(summary.usage.reasoning_tokens > 0); + assert!(summary.usage.input_image_tokens > 0); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/qwen35_internal_endpoint_emits_tool_call_parsed_event.rs b/paddler_tests/tests/qwen35_internal_endpoint_emits_tool_call_parsed_event.rs new file mode 100644 index 00000000..aa06133c --- /dev/null +++ b/paddler_tests/tests/qwen35_internal_endpoint_emits_tool_call_parsed_event.rs @@ -0,0 +1,93 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; +use reqwest::Client; +use serde_json::Map; +use serde_json::Value; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn qwen35_internal_endpoint_emits_tool_call_parsed_event() -> Result<()> { + let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(1), false).await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let mut location_properties = Map::new(); + location_properties.insert( + "location".to_owned(), + serde_json::json!({"type": "string", "description": "The city name"}), + ); + + let 
stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Text( + "What is the weather in Paris? Use the get_weather tool to find out." + .to_owned(), + ), + role: "user".to_owned(), + }]), + enable_thinking: false, + grammar: None, + max_tokens: 400, + parse_tool_calls: true, + tools: vec![Tool::Function(FunctionCall { + function: Function { + name: "get_weather".to_owned(), + description: "Get the current weather for a location".to_owned(), + parameters: Parameters::Schema(ValidatedParametersSchema { + schema_type: "object".to_owned(), + properties: Some(location_properties), + required: Some(vec!["location".to_owned()]), + additional_properties: Some(Value::Bool(false)), + }), + }, + })], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let parsed_events: Vec<&Vec<_>> = collected + .token_results + .iter() + .filter_map(|event| match &event.token_result { + GeneratedTokenResult::ToolCallParsed(parsed) => Some(parsed), + _ => None, + }) + .collect(); + + assert!( + !parsed_events.is_empty(), + "Qwen3.5: expected at least one ToolCallParsed event; got tokens:\n{}", + collected.text + ); + + let first_call = parsed_events + .iter() + .flat_map(|calls| calls.iter()) + .next() + .ok_or_else(|| anyhow::anyhow!("no parsed tool calls in any event"))?; + + assert_eq!(first_call.name, "get_weather"); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/qwen35_internal_endpoint_with_thinking_disabled_emits_only_content_tokens.rs b/paddler_tests/tests/qwen35_internal_endpoint_with_thinking_disabled_emits_only_content_tokens.rs new file mode 100644 index 00000000..e28c8746 --- /dev/null +++ b/paddler_tests/tests/qwen35_internal_endpoint_with_thinking_disabled_emits_only_content_tokens.rs @@ -0,0 +1,138 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use reqwest::Client; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn qwen35_internal_endpoint_with_thinking_disabled_emits_only_content_tokens() -> Result<()> { + let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(1), false).await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Text("What is two plus two?".to_owned()), + role: "user".to_owned(), + }]), + enable_thinking: false, + grammar: None, +
max_tokens: 200, + parse_tool_calls: false, + tools: vec![], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let reasoning_count = collected + .token_results + .iter() + .filter(|result| matches!(result.token_result, GeneratedTokenResult::ReasoningToken(_))) + .count(); + let content_count = collected + .token_results + .iter() + .filter(|result| matches!(result.token_result, GeneratedTokenResult::ContentToken(_))) + .count(); + let undeterminable_count = collected + .token_results + .iter() + .filter(|result| { + matches!( + result.token_result, + GeneratedTokenResult::UndeterminableToken(_) + ) + }) + .count(); + + assert_eq!( + reasoning_count, 0, + "Qwen3.5 thinking-disabled: classifier must not stream any reasoning tokens \ + (got reasoning_count={reasoning_count}, content_count={content_count}, \ + undeterminable_count={undeterminable_count})" + ); + assert_eq!( + undeterminable_count, 0, + "Qwen3.5 thinking-disabled: prompt-token replay must move section to Content \ + before generation, so no UndeterminableToken may stream; \ + (got reasoning_count={reasoning_count}, content_count={content_count}, \ + undeterminable_count={undeterminable_count})" + ); + assert!( + content_count > 0, + "Qwen3.5 thinking-disabled: classifier must stream at least one content token" + ); + + let last = collected + .token_results + .last() + .ok_or_else(|| anyhow::anyhow!("no token results received"))?; + let GeneratedTokenResult::Done(summary) = &last.token_result else { + anyhow::bail!("last result was not Done: {last:?}"); + }; + + assert!(summary.usage.prompt_tokens > 0); + assert_eq!( + summary.usage.reasoning_tokens, 0, + "Qwen3.5 thinking-disabled: usage.reasoning_tokens must be zero; usage={:?}", + summary.usage + ); + assert_eq!( + summary.usage.undeterminable_tokens, 0, + "Qwen3.5 thinking-disabled: usage.undeterminable_tokens must be zero; usage={:?}", + summary.usage + ); + assert_eq!( + summary.usage.completion_tokens(), + summary.usage.content_tokens, + "Qwen3.5 thinking-disabled: completion tokens equal content tokens since \ + reasoning and undeterminable are zero" + ); + + let reasoning_stream: String = collected + .token_results + .iter() + .filter_map(|result| match &result.token_result { + GeneratedTokenResult::ReasoningToken(piece) => Some(piece.as_str()), + _ => None, + }) + .collect(); + let content_stream: String = collected + .token_results + .iter() + .filter_map(|result| match &result.token_result { + GeneratedTokenResult::ContentToken(piece) => Some(piece.as_str()), + _ => None, + }) + .collect(); + + for forbidden in ["<think>", "</think>"] { + assert!( + !reasoning_stream.contains(forbidden), + "Qwen3.5 thinking-disabled: reasoning stream leaked marker {forbidden:?}; \ + reasoning_stream={reasoning_stream:?}" + ); + assert!( + !content_stream.contains(forbidden), + "Qwen3.5 thinking-disabled: content stream leaked marker {forbidden:?}; \ + content_stream={content_stream:?}" + ); + } + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/qwen35_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens.rs b/paddler_tests/tests/qwen35_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens.rs new file mode 100644 index 00000000..4089d30f --- /dev/null +++ b/paddler_tests/tests/qwen35_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens.rs @@ -0,0 +1,103 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use
paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use reqwest::Client; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn qwen35_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens() -> Result<()> { + let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(1), false).await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Text( + "What is two plus two? Think step by step.".to_owned(), + ), + role: "user".to_owned(), + }]), + enable_thinking: true, + grammar: None, + max_tokens: 600, + parse_tool_calls: false, + tools: vec![], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let reasoning_count = collected + .token_results + .iter() + .filter(|result| matches!(result.token_result, GeneratedTokenResult::ReasoningToken(_))) + .count(); + + assert!( + reasoning_count > 0, + "Qwen3.5: expected at least one reasoning token when thinking is enabled (got {reasoning_count})" + ); + + let last = collected + .token_results + .last() + .ok_or_else(|| anyhow::anyhow!("no token results received"))?; + let GeneratedTokenResult::Done(summary) = &last.token_result else { + anyhow::bail!("last result was not Done: {last:?}"); + }; + + assert!(summary.usage.prompt_tokens > 0); + assert!(summary.usage.reasoning_tokens > 0); + assert_eq!( + summary.usage.completion_tokens(), + summary.usage.content_tokens + + summary.usage.reasoning_tokens + + summary.usage.undeterminable_tokens + ); + + let reasoning_stream: String = collected + .token_results + .iter() + .filter_map(|result| match &result.token_result { + GeneratedTokenResult::ReasoningToken(piece) => Some(piece.as_str()), + _ => None, + }) + .collect(); + let content_stream: String = collected + .token_results + .iter() + .filter_map(|result| match &result.token_result { + GeneratedTokenResult::ContentToken(piece) => Some(piece.as_str()), + _ => None, + }) + .collect(); + + for forbidden in ["<think>", "</think>"] { + assert!( + !reasoning_stream.contains(forbidden), + "Qwen3.5: reasoning stream leaked marker {forbidden:?}; \ + reasoning_stream={reasoning_stream:?}" + ); + assert!( + !content_stream.contains(forbidden), + "Qwen3.5: content stream leaked marker {forbidden:?}; \ + content_stream={content_stream:?}" + ); + } + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/qwen35_thinking_mode_stops_cleanly_before_max_tokens.rs b/paddler_tests/tests/qwen35_thinking_mode_stops_cleanly_before_max_tokens.rs index a7faa478..fd9e2a62 100644 --- a/paddler_tests/tests/qwen35_thinking_mode_stops_cleanly_before_max_tokens.rs +++
b/paddler_tests/tests/qwen35_thinking_mode_stops_cleanly_before_max_tokens.rs @@ -1,9 +1,11 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; +use paddler_tests::token_result_with_producer::TokenResultWithProducer; use paddler_types::conversation_history::ConversationHistory; use paddler_types::conversation_message::ConversationMessage; use paddler_types::conversation_message_content::ConversationMessageContent; @@ -14,7 +16,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen35_thinking_mode_stops_cleanly_before_max_tokens() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_5(1, false).await?; + let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(1), false).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -29,6 +31,7 @@ async fn qwen35_thinking_mode_stops_cleanly_before_max_tokens() -> Result<()> { enable_thinking: true, grammar: None, max_tokens: 2000, + parse_tool_calls: false, tools: vec![], }) .await?; @@ -38,14 +41,17 @@ async fn qwen35_thinking_mode_stops_cleanly_before_max_tokens() -> Result<()> { let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!(token_count > 0); assert!(token_count <= 2000); assert!(matches!( collected.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. 
+ }) )); cluster.shutdown().await?; diff --git a/paddler_tests/tests/qwen35_thinking_multi_turn_conversation_stops_cleanly.rs b/paddler_tests/tests/qwen35_thinking_multi_turn_conversation_stops_cleanly.rs index a8830763..6fb55f20 100644 --- a/paddler_tests/tests/qwen35_thinking_multi_turn_conversation_stops_cleanly.rs +++ b/paddler_tests/tests/qwen35_thinking_multi_turn_conversation_stops_cleanly.rs @@ -1,9 +1,11 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; +use paddler_tests::token_result_with_producer::TokenResultWithProducer; use paddler_types::conversation_history::ConversationHistory; use paddler_types::conversation_message::ConversationMessage; use paddler_types::conversation_message_content::ConversationMessageContent; @@ -14,7 +16,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen35_thinking_multi_turn_conversation_stops_cleanly() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_5(1, false).await?; + let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(1), false).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -45,6 +47,7 @@ async fn qwen35_thinking_multi_turn_conversation_stops_cleanly() -> Result<()> { enable_thinking: true, grammar: None, max_tokens: 1000, + parse_tool_calls: false, tools: vec![], }) .await?; @@ -54,14 +57,17 @@ async fn qwen35_thinking_multi_turn_conversation_stops_cleanly() -> Result<()> { let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!(token_count > 0); assert!(token_count <= 1000); assert!(matches!( collected.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. 
+ }) )); cluster.shutdown().await?; diff --git a/paddler_tests/tests/qwen35_with_mmproj_generates_tokens_from_image.rs b/paddler_tests/tests/qwen35_with_mmproj_generates_tokens_from_image.rs index 0f561f11..54292ff7 100644 --- a/paddler_tests/tests/qwen35_with_mmproj_generates_tokens_from_image.rs +++ b/paddler_tests/tests/qwen35_with_mmproj_generates_tokens_from_image.rs @@ -1,10 +1,12 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; +use paddler_tests::token_result_with_producer::TokenResultWithProducer; use paddler_types::conversation_history::ConversationHistory; use paddler_types::conversation_message::ConversationMessage; use paddler_types::conversation_message_content::ConversationMessageContent; @@ -17,7 +19,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen35_with_mmproj_generates_tokens_from_image() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_5(1, true).await?; + let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(1), true).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -45,6 +47,7 @@ async fn qwen35_with_mmproj_generates_tokens_from_image() -> Result<()> { enable_thinking: false, grammar: None, max_tokens: 200, + parse_tool_calls: false, tools: vec![], }) .await?; @@ -54,13 +57,16 @@ async fn qwen35_with_mmproj_generates_tokens_from_image() -> Result<()> { let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!(token_count > 0); assert!(matches!( collected.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. 
+ }) )); cluster.shutdown().await?; diff --git a/paddler_tests/tests/qwen35_with_system_message_completes_with_thinking.rs b/paddler_tests/tests/qwen35_with_system_message_completes_with_thinking.rs index 0303b1a7..a92205eb 100644 --- a/paddler_tests/tests/qwen35_with_system_message_completes_with_thinking.rs +++ b/paddler_tests/tests/qwen35_with_system_message_completes_with_thinking.rs @@ -1,9 +1,11 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; +use paddler_tests::token_result_with_producer::TokenResultWithProducer; use paddler_types::conversation_history::ConversationHistory; use paddler_types::conversation_message::ConversationMessage; use paddler_types::conversation_message_content::ConversationMessageContent; @@ -14,7 +16,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen35_with_system_message_completes_with_thinking() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_5(1, false).await?; + let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(1), false).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -41,6 +43,7 @@ async fn qwen35_with_system_message_completes_with_thinking() -> Result<()> { enable_thinking: true, grammar: None, max_tokens: 2000, + parse_tool_calls: false, tools: vec![], }) .await?; @@ -50,13 +53,16 @@ async fn qwen35_with_system_message_completes_with_thinking() -> Result<()> { let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!(token_count > 0); assert!(matches!( collected.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. 
+ }) )); cluster.shutdown().await?; diff --git a/paddler_tests/tests/qwen35_with_system_message_completes_without_thinking.rs b/paddler_tests/tests/qwen35_with_system_message_completes_without_thinking.rs index 94d8208e..6c2d0560 100644 --- a/paddler_tests/tests/qwen35_with_system_message_completes_without_thinking.rs +++ b/paddler_tests/tests/qwen35_with_system_message_completes_without_thinking.rs @@ -1,9 +1,11 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; +use paddler_tests::token_result_with_producer::TokenResultWithProducer; use paddler_types::conversation_history::ConversationHistory; use paddler_types::conversation_message::ConversationMessage; use paddler_types::conversation_message_content::ConversationMessageContent; @@ -14,7 +16,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen35_with_system_message_completes_without_thinking() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_5(1, false).await?; + let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(1), false).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -41,6 +43,7 @@ async fn qwen35_with_system_message_completes_without_thinking() -> Result<()> { enable_thinking: false, grammar: None, max_tokens: 512, + parse_tool_calls: false, tools: vec![], }) .await?; @@ -50,13 +53,16 @@ async fn qwen35_with_system_message_completes_without_thinking() -> Result<()> { let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!(token_count > 0); assert!(matches!( collected.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. 
+ }) )); cluster.shutdown().await?; diff --git a/paddler_tests/tests/qwen35_without_mmproj_rejects_image_with_multimodal_not_supported.rs b/paddler_tests/tests/qwen35_without_mmproj_rejects_image_with_multimodal_not_supported.rs index fc286aec..7f2b6c28 100644 --- a/paddler_tests/tests/qwen35_without_mmproj_rejects_image_with_multimodal_not_supported.rs +++ b/paddler_tests/tests/qwen35_without_mmproj_rejects_image_with_multimodal_not_supported.rs @@ -1,6 +1,7 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; @@ -17,7 +18,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen35_without_mmproj_rejects_image_with_multimodal_not_supported() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_5(1, false).await?; + let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(1), false).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -45,6 +46,7 @@ async fn qwen35_without_mmproj_rejects_image_with_multimodal_not_supported() -> enable_thinking: false, grammar: None, max_tokens: 100, + parse_tool_calls: false, tools: vec![], }) .await; @@ -55,7 +57,7 @@ async fn qwen35_without_mmproj_rejects_image_with_multimodal_not_supported() -> if let Ok(collected) = collected { assert!( collected.token_results.iter().any(|result| matches!( - result, + result.token_result, GeneratedTokenResult::MultimodalNotSupported(_) )), "expected MultimodalNotSupported, got: {:?}", diff --git a/paddler_tests/tests/qwen36_internal_endpoint_emits_reasoning_tokens_for_image_request.rs b/paddler_tests/tests/qwen36_internal_endpoint_emits_reasoning_tokens_for_image_request.rs new file mode 100644 index 00000000..c691ac7b --- /dev/null +++ b/paddler_tests/tests/qwen36_internal_endpoint_emits_reasoning_tokens_for_image_request.rs @@ -0,0 +1,81 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; +use paddler_tests::start_in_process_cluster_with_qwen3_6_and_mmproj::start_in_process_cluster_with_qwen3_6_and_mmproj; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::conversation_message_content_part::ConversationMessageContentPart; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::image_url::ImageUrl; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use reqwest::Client; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn qwen36_internal_endpoint_emits_reasoning_tokens_for_image_request() -> Result<()> { + let cluster = start_in_process_cluster_with_qwen3_6_and_mmproj(AgentConfig::single(1)).await?; + + let inference_client = + 
InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let image_data_uri = load_test_image_data_uri()?; + + let conversation_history = ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Parts(vec![ + ConversationMessageContentPart::ImageUrl { + image_url: ImageUrl { + url: image_data_uri, + }, + }, + ConversationMessageContentPart::Text { + text: "What animals do you see in this image? Think step by step.".to_owned(), + }, + ]), + role: "user".to_owned(), + }]); + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history, + enable_thinking: true, + grammar: None, + max_tokens: 200, + parse_tool_calls: false, + tools: vec![], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let reasoning_count = collected + .token_results + .iter() + .filter(|result| matches!(result.token_result, GeneratedTokenResult::ReasoningToken(_))) + .count(); + + assert!( + reasoning_count > 0, + "Qwen 3.6: expected at least one reasoning token from a `<think>` block when an image is attached (got {reasoning_count})" + ); + + let last = collected + .token_results + .last() + .ok_or_else(|| anyhow::anyhow!("no token results received"))?; + let GeneratedTokenResult::Done(summary) = &last.token_result else { + anyhow::bail!("last result was not Done: {last:?}"); + }; + + assert!(summary.usage.reasoning_tokens > 0); + assert!(summary.usage.input_image_tokens > 0); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/qwen36_internal_endpoint_emits_tool_call_parsed_event.rs b/paddler_tests/tests/qwen36_internal_endpoint_emits_tool_call_parsed_event.rs new file mode 100644 index 00000000..2e530298 --- /dev/null +++ b/paddler_tests/tests/qwen36_internal_endpoint_emits_tool_call_parsed_event.rs @@ -0,0 +1,93 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_cluster_with_qwen3_6::start_in_process_cluster_with_qwen3_6; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; +use reqwest::Client; +use serde_json::Map; +use serde_json::Value; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn
qwen36_internal_endpoint_emits_tool_call_parsed_event() -> Result<()> { + let cluster = start_in_process_cluster_with_qwen3_6(AgentConfig::single(1)).await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let mut location_properties = Map::new(); + location_properties.insert( + "location".to_owned(), + serde_json::json!({"type": "string", "description": "The city name"}), + ); + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Text( + "What is the weather in Paris? Use the get_weather tool to find out." + .to_owned(), + ), + role: "user".to_owned(), + }]), + enable_thinking: false, + grammar: None, + max_tokens: 400, + parse_tool_calls: true, + tools: vec![Tool::Function(FunctionCall { + function: Function { + name: "get_weather".to_owned(), + description: "Get the current weather for a location".to_owned(), + parameters: Parameters::Schema(ValidatedParametersSchema { + schema_type: "object".to_owned(), + properties: Some(location_properties), + required: Some(vec!["location".to_owned()]), + additional_properties: Some(Value::Bool(false)), + }), + }, + })], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let parsed_events: Vec<&Vec<_>> = collected + .token_results + .iter() + .filter_map(|event| match &event.token_result { + GeneratedTokenResult::ToolCallParsed(parsed) => Some(parsed), + _ => None, + }) + .collect(); + + assert!( + !parsed_events.is_empty(), + "Qwen3.6: expected at least one ToolCallParsed event; got tokens:\n{}", + collected.text + ); + + let first_call = parsed_events + .iter() + .flat_map(|calls| calls.iter()) + .next() + .ok_or_else(|| anyhow::anyhow!("no parsed tool calls in any event"))?; + + assert_eq!(first_call.name, "get_weather"); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/qwen36_internal_endpoint_with_thinking_disabled_emits_only_content_tokens.rs b/paddler_tests/tests/qwen36_internal_endpoint_with_thinking_disabled_emits_only_content_tokens.rs new file mode 100644 index 00000000..65006fc7 --- /dev/null +++ b/paddler_tests/tests/qwen36_internal_endpoint_with_thinking_disabled_emits_only_content_tokens.rs @@ -0,0 +1,138 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_cluster_with_qwen3_6::start_in_process_cluster_with_qwen3_6; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use reqwest::Client; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn qwen36_internal_endpoint_with_thinking_disabled_emits_only_content_tokens() -> Result<()> { + let cluster = start_in_process_cluster_with_qwen3_6(AgentConfig::single(1)).await?; + + let inference_client = +
InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Text("What is two plus two?".to_owned()), + role: "user".to_owned(), + }]), + enable_thinking: false, + grammar: None, + max_tokens: 200, + parse_tool_calls: false, + tools: vec![], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let reasoning_count = collected + .token_results + .iter() + .filter(|result| matches!(result.token_result, GeneratedTokenResult::ReasoningToken(_))) + .count(); + let content_count = collected + .token_results + .iter() + .filter(|result| matches!(result.token_result, GeneratedTokenResult::ContentToken(_))) + .count(); + let undeterminable_count = collected + .token_results + .iter() + .filter(|result| { + matches!( + result.token_result, + GeneratedTokenResult::UndeterminableToken(_) + ) + }) + .count(); + + assert_eq!( + reasoning_count, 0, + "Qwen3.6 thinking-disabled: classifier must not stream any reasoning tokens \ + (got reasoning_count={reasoning_count}, content_count={content_count}, \ + undeterminable_count={undeterminable_count})" + ); + assert_eq!( + undeterminable_count, 0, + "Qwen3.6 thinking-disabled: prompt-token replay must move section to Content \ + before generation, so no UndeterminableToken may stream; \ + (got reasoning_count={reasoning_count}, content_count={content_count}, \ + undeterminable_count={undeterminable_count})" + ); + assert!( + content_count > 0, + "Qwen3.6 thinking-disabled: classifier must stream at least one content token" + ); + + let last = collected + .token_results + .last() + .ok_or_else(|| anyhow::anyhow!("no token results received"))?; + let GeneratedTokenResult::Done(summary) = &last.token_result else { + anyhow::bail!("last result was not Done: {last:?}"); + }; + + assert!(summary.usage.prompt_tokens > 0); + assert_eq!( + summary.usage.reasoning_tokens, 0, + "Qwen3.6 thinking-disabled: usage.reasoning_tokens must be zero; usage={:?}", + summary.usage + ); + assert_eq!( + summary.usage.undeterminable_tokens, 0, + "Qwen3.6 thinking-disabled: usage.undeterminable_tokens must be zero; usage={:?}", + summary.usage + ); + assert_eq!( + summary.usage.completion_tokens(), + summary.usage.content_tokens, + "Qwen3.6 thinking-disabled: completion tokens equal content tokens since \ + reasoning and undeterminable are zero" + ); + + let reasoning_stream: String = collected + .token_results + .iter() + .filter_map(|result| match &result.token_result { + GeneratedTokenResult::ReasoningToken(piece) => Some(piece.as_str()), + _ => None, + }) + .collect(); + let content_stream: String = collected + .token_results + .iter() + .filter_map(|result| match &result.token_result { + GeneratedTokenResult::ContentToken(piece) => Some(piece.as_str()), + _ => None, + }) + .collect(); + + for forbidden in ["<think>", "</think>"] { + assert!( + !reasoning_stream.contains(forbidden), + "Qwen3.6 thinking-disabled: reasoning stream leaked marker {forbidden:?}; \ + reasoning_stream={reasoning_stream:?}" + ); + assert!( + !content_stream.contains(forbidden), + "Qwen3.6 thinking-disabled: content stream leaked marker {forbidden:?}; \ + content_stream={content_stream:?}" + ); + } + + cluster.shutdown().await?; + + Ok(()) +} diff --git
a/paddler_tests/tests/qwen36_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens.rs b/paddler_tests/tests/qwen36_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens.rs new file mode 100644 index 00000000..b0161740 --- /dev/null +++ b/paddler_tests/tests/qwen36_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens.rs @@ -0,0 +1,103 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_cluster_with_qwen3_6::start_in_process_cluster_with_qwen3_6; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use reqwest::Client; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn qwen36_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens() -> Result<()> { + let cluster = start_in_process_cluster_with_qwen3_6(AgentConfig::single(1)).await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Text( + "What is two plus two? 
Think step by step.".to_owned(), + ), + role: "user".to_owned(), + }]), + enable_thinking: true, + grammar: None, + max_tokens: 600, + parse_tool_calls: false, + tools: vec![], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let reasoning_count = collected + .token_results + .iter() + .filter(|result| matches!(result.token_result, GeneratedTokenResult::ReasoningToken(_))) + .count(); + + assert!( + reasoning_count > 0, + "Qwen3.6: expected at least one reasoning token when thinking is enabled (got {reasoning_count})" + ); + + let last = collected + .token_results + .last() + .ok_or_else(|| anyhow::anyhow!("no token results received"))?; + let GeneratedTokenResult::Done(summary) = &last.token_result else { + anyhow::bail!("last result was not Done: {last:?}"); + }; + + assert!(summary.usage.prompt_tokens > 0); + assert!(summary.usage.reasoning_tokens > 0); + assert_eq!( + summary.usage.completion_tokens(), + summary.usage.content_tokens + + summary.usage.reasoning_tokens + + summary.usage.undeterminable_tokens + ); + + let reasoning_stream: String = collected + .token_results + .iter() + .filter_map(|result| match &result.token_result { + GeneratedTokenResult::ReasoningToken(piece) => Some(piece.as_str()), + _ => None, + }) + .collect(); + let content_stream: String = collected + .token_results + .iter() + .filter_map(|result| match &result.token_result { + GeneratedTokenResult::ContentToken(piece) => Some(piece.as_str()), + _ => None, + }) + .collect(); + + for forbidden in ["<think>", "</think>"] { + assert!( + !reasoning_stream.contains(forbidden), + "Qwen3.6: reasoning stream leaked marker {forbidden:?}; \ + reasoning_stream={reasoning_stream:?}" + ); + assert!( + !content_stream.contains(forbidden), + "Qwen3.6: content stream leaked marker {forbidden:?}; \ + content_stream={content_stream:?}" + ); + } + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/qwen3_gbnf_grammar_constrains_output_to_yes_or_no.rs b/paddler_tests/tests/qwen3_gbnf_grammar_constrains_output_to_yes_or_no.rs index 6baef0fc..2f77f78d 100644 --- a/paddler_tests/tests/qwen3_gbnf_grammar_constrains_output_to_yes_or_no.rs +++ b/paddler_tests/tests/qwen3_gbnf_grammar_constrains_output_to_yes_or_no.rs @@ -1,6 +1,7 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; @@ -11,7 +12,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_gbnf_grammar_constrains_output_to_yes_or_no() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(1).await?; + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); diff --git a/paddler_tests/tests/qwen3_generates_tokens_from_conversation_history.rs b/paddler_tests/tests/qwen3_generates_tokens_from_conversation_history.rs index c3f9b029..afb3c816 100644 --- a/paddler_tests/tests/qwen3_generates_tokens_from_conversation_history.rs +++ b/paddler_tests/tests/qwen3_generates_tokens_from_conversation_history.rs @@ -1,9 +1,11 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use
paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; +use paddler_tests::token_result_with_producer::TokenResultWithProducer; use paddler_types::conversation_history::ConversationHistory; use paddler_types::conversation_message::ConversationMessage; use paddler_types::conversation_message_content::ConversationMessageContent; @@ -14,7 +16,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_generates_tokens_from_conversation_history() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(1).await?; + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -29,6 +31,7 @@ async fn qwen3_generates_tokens_from_conversation_history() -> Result<()> { enable_thinking: false, grammar: None, max_tokens: 500, + parse_tool_calls: false, tools: vec![], }) .await?; @@ -38,14 +41,17 @@ async fn qwen3_generates_tokens_from_conversation_history() -> Result<()> { let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!(token_count > 0); assert!(token_count < 500, "EOG should stop generation early"); assert!(matches!( collected.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. + }) )); cluster.shutdown().await?; diff --git a/paddler_tests/tests/qwen3_generates_tokens_from_raw_prompt.rs b/paddler_tests/tests/qwen3_generates_tokens_from_raw_prompt.rs index 46645848..994d4d04 100644 --- a/paddler_tests/tests/qwen3_generates_tokens_from_raw_prompt.rs +++ b/paddler_tests/tests/qwen3_generates_tokens_from_raw_prompt.rs @@ -1,9 +1,11 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; +use paddler_tests::token_result_with_producer::TokenResultWithProducer; use paddler_types::generated_token_result::GeneratedTokenResult; use paddler_types::request_params::ContinueFromRawPromptParams; use reqwest::Client; @@ -11,7 +13,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_generates_tokens_from_raw_prompt() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(1).await?; + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -31,13 +33,16 @@ async fn qwen3_generates_tokens_from_raw_prompt() -> Result<()> { let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!(token_count > 0); assert!(matches!( collected.token_results.last(), - Some(GeneratedTokenResult::Done) + 
Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. + }) )); cluster.shutdown().await?; diff --git a/paddler_tests/tests/qwen3_grammar_with_thinking_returns_incompatible_error.rs b/paddler_tests/tests/qwen3_grammar_with_thinking_returns_incompatible_error.rs index c829367e..1d1ce2c0 100644 --- a/paddler_tests/tests/qwen3_grammar_with_thinking_returns_incompatible_error.rs +++ b/paddler_tests/tests/qwen3_grammar_with_thinking_returns_incompatible_error.rs @@ -1,6 +1,7 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; @@ -15,7 +16,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_grammar_with_thinking_returns_incompatible_error() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(1).await?; + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -32,6 +33,7 @@ async fn qwen3_grammar_with_thinking_returns_incompatible_error() -> Result<()> schema: r#"{"type": "object", "properties": {"answer": {"type": "string"}}, "required": ["answer"]}"#.to_owned(), }), max_tokens: 50, + parse_tool_calls: false, tools: vec![], }) .await; @@ -41,7 +43,7 @@ async fn qwen3_grammar_with_thinking_returns_incompatible_error() -> Result<()> if let Ok(collected) = collected { assert!( collected.token_results.iter().any(|result| matches!( - result, + result.token_result, GeneratedTokenResult::GrammarIncompatibleWithThinking(_) )), "expected GrammarIncompatibleWithThinking, got: {:?}", diff --git a/paddler_tests/tests/qwen3_internal_endpoint_concurrent_requests_independent_usage.rs b/paddler_tests/tests/qwen3_internal_endpoint_concurrent_requests_independent_usage.rs new file mode 100644 index 00000000..eac8d072 --- /dev/null +++ b/paddler_tests/tests/qwen3_internal_endpoint_concurrent_requests_independent_usage.rs @@ -0,0 +1,80 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use futures_util::future; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::generation_summary::GenerationSummary; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use reqwest::Client; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn qwen3_internal_endpoint_concurrent_requests_keep_independent_usage() -> Result<()> { + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(2)).await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), 
cluster.addresses.inference_base_url()?); + + let prompts = ["Say hi.", "Count to three."]; + + let futures = prompts.iter().map(|prompt| { + let client = inference_client.clone(); + let prompt = (*prompt).to_owned(); + async move { + let stream = client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Text(prompt), + role: "user".to_owned(), + }]), + enable_thinking: false, + grammar: None, + max_tokens: 30, + parse_tool_calls: false, + tools: vec![], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let last = collected + .token_results + .last() + .ok_or_else(|| anyhow::anyhow!("no token results received"))?; + match &last.token_result { + GeneratedTokenResult::Done(summary) => { + Ok::<GenerationSummary, anyhow::Error>(*summary) + } + other => Err(anyhow::anyhow!("last result was not Done: {other:?}")), + } + } + }); + + let summaries: Vec<GenerationSummary> = future::try_join_all(futures).await?; + + assert_eq!(summaries.len(), 2); + + for summary in &summaries { + assert!(summary.usage.prompt_tokens > 0); + assert!(summary.usage.completion_tokens() > 0); + } + + // The two requests have different prompts and different generations; + // their usage breakdowns must not be byte-identical. + assert_ne!( + summaries[0].usage, summaries[1].usage, + "concurrent requests reported identical usage; counters likely shared" + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/qwen3_internal_endpoint_emits_tool_call_parsed_event.rs b/paddler_tests/tests/qwen3_internal_endpoint_emits_tool_call_parsed_event.rs new file mode 100644 index 00000000..d36384e0 --- /dev/null +++ b/paddler_tests/tests/qwen3_internal_endpoint_emits_tool_call_parsed_event.rs @@ -0,0 +1,94 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use llama_cpp_bindings_types::ParsedToolCall; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; +use reqwest::Client; +use serde_json::Map; +use serde_json::Value; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn qwen3_internal_endpoint_emits_tool_call_parsed_event() -> Result<()> { + 
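// A single agent is enough here; parse_tool_calls: true below asks the stream for structured ToolCallParsed events. + 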
let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let mut location_properties = Map::new(); + location_properties.insert( + "location".to_owned(), + serde_json::json!({"type": "string", "description": "The city name"}), + ); + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Text( + "What is the weather in Paris? Use the get_weather tool to find out." + .to_owned(), + ), + role: "user".to_owned(), + }]), + enable_thinking: false, + grammar: None, + max_tokens: 400, + parse_tool_calls: true, + tools: vec![Tool::Function(FunctionCall { + function: Function { + name: "get_weather".to_owned(), + description: "Get the current weather for a location".to_owned(), + parameters: Parameters::Schema(ValidatedParametersSchema { + schema_type: "object".to_owned(), + properties: Some(location_properties), + required: Some(vec!["location".to_owned()]), + additional_properties: Some(Value::Bool(false)), + }), + }, + })], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let parsed_events: Vec<&Vec<ParsedToolCall>> = collected + .token_results + .iter() + .filter_map(|event| match &event.token_result { + GeneratedTokenResult::ToolCallParsed(parsed) => Some(parsed), + _ => None, + }) + .collect(); + + assert!( + !parsed_events.is_empty(), + "expected at least one ToolCallParsed event; got tokens:\n{}", + collected.text + ); + + let first_call = parsed_events + .iter() + .flat_map(|calls| calls.iter()) + .next() + .ok_or_else(|| anyhow::anyhow!("no parsed tool calls in any event"))?; + + assert_eq!(first_call.name, "get_weather"); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/qwen3_internal_endpoint_emits_tool_call_tokens.rs b/paddler_tests/tests/qwen3_internal_endpoint_emits_tool_call_tokens.rs new file mode 100644 index 00000000..683e970a --- /dev/null +++ b/paddler_tests/tests/qwen3_internal_endpoint_emits_tool_call_tokens.rs @@ -0,0 +1,98 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use 
paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; +use reqwest::Client; +use serde_json::Map; +use serde_json::Value; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn qwen3_internal_endpoint_emits_tool_call_tokens() -> Result<()> { + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let mut location_properties = Map::new(); + location_properties.insert( + "location".to_owned(), + serde_json::json!({"type": "string", "description": "The city name"}), + ); + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Text( + "What is the weather in Paris? Use the get_weather tool to find out." + .to_owned(), + ), + role: "user".to_owned(), + }]), + enable_thinking: false, + grammar: None, + max_tokens: 400, + parse_tool_calls: true, + tools: vec![Tool::Function(FunctionCall { + function: Function { + name: "get_weather".to_owned(), + description: "Get the current weather for a location".to_owned(), + parameters: Parameters::Schema(ValidatedParametersSchema { + schema_type: "object".to_owned(), + properties: Some(location_properties), + required: Some(vec!["location".to_owned()]), + additional_properties: Some(Value::Bool(false)), + }), + }, + })], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let tool_call_count = collected + .token_results + .iter() + .filter(|result| matches!(result.token_result, GeneratedTokenResult::ToolCallToken(_))) + .count(); + let content_count = collected + .token_results + .iter() + .filter(|result| matches!(result.token_result, GeneratedTokenResult::ContentToken(_))) + .count(); + + let last = collected + .token_results + .last() + .ok_or_else(|| anyhow::anyhow!("no token results received"))?; + let GeneratedTokenResult::Done(summary) = &last.token_result else { + anyhow::bail!("last result was not Done: {last:?}"); + }; + + assert!(summary.usage.prompt_tokens > 0); + assert!( + tool_call_count > 0, + "expected ToolCallToken (got {tool_call_count}); content_count={content_count}; usage={:?}; generated text:\n{}", + summary.usage, + collected.text, + ); + assert!(summary.usage.tool_call_tokens > 0); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/qwen3_internal_endpoint_max_tokens_usage_matches.rs b/paddler_tests/tests/qwen3_internal_endpoint_max_tokens_usage_matches.rs new file mode 100644 index 00000000..66119a7c --- /dev/null +++ b/paddler_tests/tests/qwen3_internal_endpoint_max_tokens_usage_matches.rs @@ -0,0 +1,67 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use 
paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use reqwest::Client; + +const MAX_TOKENS: i32 = 20; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn qwen3_internal_endpoint_max_tokens_usage_matches_streamed_count() -> Result<()> { + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Text("Tell me a long story.".to_owned()), + role: "user".to_owned(), + }]), + enable_thinking: false, + grammar: None, + max_tokens: MAX_TOKENS, + parse_tool_calls: false, + tools: vec![], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let streamed_token_count = collected + .token_results + .iter() + .filter(|result| result.token_result.is_token()) + .count() as u64; + + let last = collected + .token_results + .last() + .ok_or_else(|| anyhow::anyhow!("no token results received"))?; + let GeneratedTokenResult::Done(summary) = &last.token_result else { + anyhow::bail!("last result was not Done: {last:?}"); + }; + + assert!(streamed_token_count > 0); + assert!(streamed_token_count <= MAX_TOKENS as u64); + assert_eq!( + summary.usage.completion_tokens(), + streamed_token_count, + "Done.usage.completion_tokens must match the count of streamed token deltas" + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/qwen3_internal_endpoint_pure_content_usage.rs b/paddler_tests/tests/qwen3_internal_endpoint_pure_content_usage.rs new file mode 100644 index 00000000..52254469 --- /dev/null +++ b/paddler_tests/tests/qwen3_internal_endpoint_pure_content_usage.rs @@ -0,0 +1,60 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use reqwest::Client; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn qwen3_internal_endpoint_pure_content_usage_breakdown() -> Result<()> { + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: 
ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Text("Say hello.".to_owned()), + role: "user".to_owned(), + }]), + enable_thinking: false, + grammar: None, + max_tokens: 60, + parse_tool_calls: false, + tools: vec![], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let last = collected + .token_results + .last() + .ok_or_else(|| anyhow::anyhow!("no token results received"))?; + let GeneratedTokenResult::Done(summary) = &last.token_result else { + anyhow::bail!("last result was not Done: {last:?}"); + }; + + assert!(summary.usage.prompt_tokens > 0); + assert!(summary.usage.content_tokens > 0); + assert_eq!(summary.usage.reasoning_tokens, 0); + assert_eq!(summary.usage.tool_call_tokens, 0); + assert_eq!( + summary.usage.completion_tokens(), + summary.usage.content_tokens + summary.usage.undeterminable_tokens + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/qwen3_internal_endpoint_tools_without_parse_flag_emit_only_raw_tokens.rs b/paddler_tests/tests/qwen3_internal_endpoint_tools_without_parse_flag_emit_only_raw_tokens.rs new file mode 100644 index 00000000..2280be84 --- /dev/null +++ b/paddler_tests/tests/qwen3_internal_endpoint_tools_without_parse_flag_emit_only_raw_tokens.rs @@ -0,0 +1,92 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; +use reqwest::Client; +use serde_json::Map; +use serde_json::Value; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn qwen3_internal_endpoint_tools_without_parse_flag_emit_only_raw_tokens() -> Result<()> { + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let mut location_properties = Map::new(); + location_properties.insert( + "location".to_owned(), + serde_json::json!({"type": "string", "description": "The city name"}), + ); + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + 
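// A tool is declared below while parse_tool_calls stays false, so the stream must surface only raw token events. + 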
conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Text( + "What is the weather in Paris? Use the get_weather tool to find out." + .to_owned(), + ), + role: "user".to_owned(), + }]), + enable_thinking: false, + grammar: None, + max_tokens: 400, + parse_tool_calls: false, + tools: vec![Tool::Function(FunctionCall { + function: Function { + name: "get_weather".to_owned(), + description: "Get the current weather for a location".to_owned(), + parameters: Parameters::Schema(ValidatedParametersSchema { + schema_type: "object".to_owned(), + properties: Some(location_properties), + required: Some(vec!["location".to_owned()]), + additional_properties: Some(Value::Bool(false)), + }), + }, + })], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + for event in &collected.token_results { + match &event.token_result { + GeneratedTokenResult::ToolCallParsed(_) + | GeneratedTokenResult::ToolCallParseFailed(_) + | GeneratedTokenResult::ToolSchemaInvalid(_) + | GeneratedTokenResult::ToolCallValidationFailed(_) => { + anyhow::bail!( + "expected no parsed/parse-failed/schema-invalid/validation-failed events when parse_tool_calls=false, got: {event:?}" + ); + } + _ => {} + } + } + + let last = collected + .token_results + .last() + .ok_or_else(|| anyhow::anyhow!("no token results received"))?; + let GeneratedTokenResult::Done(_) = &last.token_result else { + anyhow::bail!("last result was not Done: {last:?}"); + }; + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/qwen3_internal_endpoint_with_thinking_disabled_emits_no_reasoning_tokens.rs b/paddler_tests/tests/qwen3_internal_endpoint_with_thinking_disabled_emits_no_reasoning_tokens.rs new file mode 100644 index 00000000..270611b4 --- /dev/null +++ b/paddler_tests/tests/qwen3_internal_endpoint_with_thinking_disabled_emits_no_reasoning_tokens.rs @@ -0,0 +1,72 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use reqwest::Client; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn qwen3_internal_endpoint_with_thinking_disabled_emits_no_reasoning_tokens() -> Result<()> { + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Text("Say hello.".to_owned()), + role: "user".to_owned(), + }]), + enable_thinking: false, + grammar: None, + max_tokens: 100, + parse_tool_calls: false, + tools: vec![], + }) + 
.await?; + + let collected = collect_generated_tokens(stream).await?; + + let reasoning_count = collected + .token_results + .iter() + .filter(|result| matches!(result.token_result, GeneratedTokenResult::ReasoningToken(_))) + .count(); + let content_count = collected + .token_results + .iter() + .filter(|result| matches!(result.token_result, GeneratedTokenResult::ContentToken(_))) + .count(); + + assert_eq!( + reasoning_count, 0, + "expected no reasoning tokens when thinking is disabled" + ); + assert!(content_count > 0, "expected content tokens to be produced"); + + let last = collected + .token_results + .last() + .ok_or_else(|| anyhow::anyhow!("no token results received"))?; + let GeneratedTokenResult::Done(summary) = &last.token_result else { + anyhow::bail!("last result was not Done: {last:?}"); + }; + + assert!(summary.usage.prompt_tokens > 0); + assert_eq!(summary.usage.reasoning_tokens, 0); + assert!(summary.usage.content_tokens > 0); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/qwen3_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens.rs b/paddler_tests/tests/qwen3_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens.rs new file mode 100644 index 00000000..23bb7f11 --- /dev/null +++ b/paddler_tests/tests/qwen3_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens.rs @@ -0,0 +1,73 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; +use paddler_types::conversation_history::ConversationHistory; +use paddler_types::conversation_message::ConversationMessage; +use paddler_types::conversation_message_content::ConversationMessageContent; +use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use reqwest::Client; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn qwen3_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens() -> Result<()> { + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + + let inference_client = + InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + + let stream = inference_client + .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Text( + "What is two plus two? 
Think step by step.".to_owned(), + ), + role: "user".to_owned(), + }]), + enable_thinking: true, + grammar: None, + max_tokens: 600, + parse_tool_calls: false, + tools: vec![], + }) + .await?; + + let collected = collect_generated_tokens(stream).await?; + + let reasoning_count = collected + .token_results + .iter() + .filter(|result| matches!(result.token_result, GeneratedTokenResult::ReasoningToken(_))) + .count(); + + assert!( + reasoning_count > 0, + "expected at least one reasoning token when thinking is enabled (got {reasoning_count})" + ); + + let last = collected + .token_results + .last() + .ok_or_else(|| anyhow::anyhow!("no token results received"))?; + let GeneratedTokenResult::Done(summary) = &last.token_result else { + anyhow::bail!("last result was not Done: {last:?}"); + }; + + assert!(summary.usage.prompt_tokens > 0); + assert!(summary.usage.reasoning_tokens > 0); + assert_eq!( + summary.usage.completion_tokens(), + summary.usage.content_tokens + + summary.usage.reasoning_tokens + + summary.usage.undeterminable_tokens + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/qwen3_json_schema_grammar_returns_valid_json.rs b/paddler_tests/tests/qwen3_json_schema_grammar_returns_valid_json.rs index 9006648b..ef71c2b7 100644 --- a/paddler_tests/tests/qwen3_json_schema_grammar_returns_valid_json.rs +++ b/paddler_tests/tests/qwen3_json_schema_grammar_returns_valid_json.rs @@ -1,6 +1,7 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; @@ -11,7 +12,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_json_schema_grammar_returns_valid_json() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(1).await?; + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); diff --git a/paddler_tests/tests/qwen3_openai_non_streaming_emits_tool_calls_for_function_tool.rs b/paddler_tests/tests/qwen3_openai_non_streaming_emits_tool_calls_for_function_tool.rs new file mode 100644 index 00000000..9117d69f --- /dev/null +++ b/paddler_tests/tests/qwen3_openai_non_streaming_emits_tool_calls_for_function_tool.rs @@ -0,0 +1,105 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::openai_chat_completions_client::OpenAIChatCompletionsClient; +use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; +use reqwest::Client; +use serde_json::Value; +use serde_json::json; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn qwen3_openai_non_streaming_emits_tool_calls_for_function_tool() -> Result<()> { + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + let openai_client = OpenAIChatCompletionsClient::new( + Client::new(), + &cluster.addresses.compat_openai_base_url()?, + )?; + + let response = openai_client + .post_non_streaming(&json!({ + "model": "qwen3-test", + "messages": [{ + "role": "user", + "content": "What is the weather in 
Paris? Use the get_weather tool." + }], + "max_completion_tokens": 400, + "tools": [{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city name"} + }, + "required": ["location"], + "additionalProperties": false + } + } + }] + })) + .await?; + + let tool_calls = response + .pointer("/choices/0/message/tool_calls") + .and_then(Value::as_array) + .ok_or_else(|| anyhow::anyhow!("response missing message.tool_calls: {response}"))?; + + assert_eq!( + tool_calls.len(), + 1, + "expected exactly one structured tool call in non-streaming response (got {})", + tool_calls.len() + ); + + let first_call = &tool_calls[0]; + + assert_eq!( + first_call.pointer("/type").and_then(Value::as_str), + Some("function"), + ); + + let id = first_call + .pointer("/id") + .and_then(Value::as_str) + .ok_or_else(|| anyhow::anyhow!("tool call missing id"))?; + assert!(!id.is_empty(), "tool call id must not be empty"); + + let function_name = first_call + .pointer("/function/name") + .and_then(Value::as_str) + .ok_or_else(|| anyhow::anyhow!("tool call missing function.name"))?; + + assert_eq!(function_name, "get_weather"); + + let function_arguments = first_call + .pointer("/function/arguments") + .and_then(Value::as_str) + .ok_or_else(|| anyhow::anyhow!("tool call missing function.arguments"))?; + + let parsed_arguments: Value = serde_json::from_str(function_arguments)?; + assert!( + parsed_arguments.get("location").is_some(), + "tool-call arguments JSON missing 'location' field: {function_arguments}" + ); + + let finish_reason = response + .pointer("/choices/0/finish_reason") + .and_then(Value::as_str) + .ok_or_else(|| anyhow::anyhow!("response missing finish_reason"))?; + + assert_eq!(finish_reason, "tool_calls"); + + let completion_tokens = response + .pointer("/usage/completion_tokens") + .and_then(Value::as_u64) + .ok_or_else(|| anyhow::anyhow!("response missing usage.completion_tokens"))?; + assert!(completion_tokens > 0); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/qwen3_openai_non_streaming_returns_usage.rs b/paddler_tests/tests/qwen3_openai_non_streaming_returns_usage.rs new file mode 100644 index 00000000..09793da3 --- /dev/null +++ b/paddler_tests/tests/qwen3_openai_non_streaming_returns_usage.rs @@ -0,0 +1,62 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::openai_chat_completions_client::OpenAIChatCompletionsClient; +use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; +use reqwest::Client; +use serde_json::Value; +use serde_json::json; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn qwen3_openai_non_streaming_returns_usage() -> Result<()> { + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + let openai_client = OpenAIChatCompletionsClient::new( + Client::new(), + &cluster.addresses.compat_openai_base_url()?, + )?; + + let response = openai_client + .post_non_streaming(&json!({ + "model": "qwen3-test", + "messages": [{"role": "user", "content": "Say hi briefly."}], + "max_completion_tokens": 600 + })) + .await?; + + let usage = response + .get("usage") + .ok_or_else(|| anyhow::anyhow!("non-streaming response missing usage: {response}"))?; + + let 
prompt_tokens = usage + .get("prompt_tokens") + .and_then(Value::as_u64) + .ok_or_else(|| anyhow::anyhow!("usage.prompt_tokens missing"))?; + let completion_tokens = usage + .get("completion_tokens") + .and_then(Value::as_u64) + .ok_or_else(|| anyhow::anyhow!("usage.completion_tokens missing"))?; + let total_tokens = usage + .get("total_tokens") + .and_then(Value::as_u64) + .ok_or_else(|| anyhow::anyhow!("usage.total_tokens missing"))?; + + assert!(prompt_tokens > 0); + assert!(completion_tokens > 0); + assert_eq!(total_tokens, prompt_tokens + completion_tokens); + + let content = response + .pointer("/choices/0/message/content") + .and_then(Value::as_str) + .ok_or_else(|| anyhow::anyhow!("non-streaming response missing message content"))?; + + assert!( + !content.is_empty(), + "non-streaming content must not be empty" + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/qwen3_openai_non_streaming_usage_with_tool_calls.rs b/paddler_tests/tests/qwen3_openai_non_streaming_usage_with_tool_calls.rs new file mode 100644 index 00000000..366064bb --- /dev/null +++ b/paddler_tests/tests/qwen3_openai_non_streaming_usage_with_tool_calls.rs @@ -0,0 +1,82 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::openai_chat_completions_client::OpenAIChatCompletionsClient; +use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; +use reqwest::Client; +use serde_json::Value; +use serde_json::json; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn qwen3_openai_non_streaming_usage_with_tool_calls() -> Result<()> { + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + let openai_client = OpenAIChatCompletionsClient::new( + Client::new(), + &cluster.addresses.compat_openai_base_url()?, + )?; + + let response = openai_client + .post_non_streaming(&json!({ + "model": "qwen3-test", + "messages": [{ + "role": "user", + "content": "What is the weather in Paris? Use the get_weather tool." + }], + "max_completion_tokens": 400, + "tools": [{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string"} + }, + "required": ["location"], + "additionalProperties": false + } + } + }] + })) + .await?; + + let tool_calls = response + .pointer("/choices/0/message/tool_calls") + .and_then(Value::as_array) + .ok_or_else(|| anyhow::anyhow!("response missing message.tool_calls: {response}"))?; + assert!(!tool_calls.is_empty()); + + let usage = response + .get("usage") + .ok_or_else(|| anyhow::anyhow!("response missing usage: {response}"))?; + + let prompt_tokens = usage + .get("prompt_tokens") + .and_then(Value::as_u64) + .ok_or_else(|| anyhow::anyhow!("usage.prompt_tokens missing"))?; + let completion_tokens = usage + .get("completion_tokens") + .and_then(Value::as_u64) + .ok_or_else(|| anyhow::anyhow!("usage.completion_tokens missing"))?; + let total_tokens = usage + .get("total_tokens") + .and_then(Value::as_u64) + .ok_or_else(|| anyhow::anyhow!("usage.total_tokens missing"))?; + + assert!(prompt_tokens > 0); + // A request that produced a tool call must have spent tokens generating + // the tool-call payload and any wrapping markers; completion_tokens + // therefore cannot be zero. 
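+    // (A regression that dropped tool-call tokens from the count would surface here as zero.)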
+ assert!( + completion_tokens > 0, + "expected non-zero completion_tokens for a tool-call response (got {completion_tokens})" + ); + assert_eq!(total_tokens, prompt_tokens + completion_tokens); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/qwen3_openai_streaming_emits_tool_calls_for_function_tool.rs b/paddler_tests/tests/qwen3_openai_streaming_emits_tool_calls_for_function_tool.rs new file mode 100644 index 00000000..dd8dc261 --- /dev/null +++ b/paddler_tests/tests/qwen3_openai_streaming_emits_tool_calls_for_function_tool.rs @@ -0,0 +1,94 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::openai_chat_completions_client::OpenAIChatCompletionsClient; +use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; +use reqwest::Client; +use serde_json::Value; +use serde_json::json; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn qwen3_openai_streaming_emits_tool_calls_for_function_tool() -> Result<()> { + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + let openai_client = OpenAIChatCompletionsClient::new( + Client::new(), + &cluster.addresses.compat_openai_base_url()?, + )?; + + let chunks = openai_client + .post_streaming(&json!({ + "model": "qwen3-test", + "messages": [{ + "role": "user", + "content": "What is the weather in Paris? Use the get_weather tool." + }], + "stream": true, + "max_completion_tokens": 400, + "tools": [{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city name"} + }, + "required": ["location"], + "additionalProperties": false + } + } + }] + })) + .await?; + + let chunks_with_tool_calls: Vec<&Value> = chunks + .iter() + .filter(|chunk| { + chunk + .pointer("/choices/0/delta/tool_calls") + .and_then(Value::as_array) + .is_some_and(|calls| !calls.is_empty()) + }) + .collect(); + + assert_eq!( + chunks_with_tool_calls.len(), + 1, + "expected exactly one structured tool-call chunk per call (got {})", + chunks_with_tool_calls.len() + ); + + let structured_chunk = chunks_with_tool_calls[0]; + + let function_name = structured_chunk + .pointer("/choices/0/delta/tool_calls/0/function/name") + .and_then(Value::as_str) + .ok_or_else(|| { + anyhow::anyhow!("structured tool-call chunk missing function.name: {structured_chunk}") + })?; + + assert_eq!(function_name, "get_weather"); + + let function_arguments = structured_chunk + .pointer("/choices/0/delta/tool_calls/0/function/arguments") + .and_then(Value::as_str) + .ok_or_else(|| { + anyhow::anyhow!( + "structured tool-call chunk missing function.arguments: {structured_chunk}" + ) + })?; + + let parsed_arguments: Value = serde_json::from_str(function_arguments)?; + + assert!( + parsed_arguments.get("location").is_some(), + "tool-call arguments JSON missing 'location' field: {function_arguments}" + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/qwen3_openai_streaming_emits_usage_when_requested.rs b/paddler_tests/tests/qwen3_openai_streaming_emits_usage_when_requested.rs new file mode 100644 index 00000000..fd2afac4 --- /dev/null +++ b/paddler_tests/tests/qwen3_openai_streaming_emits_usage_when_requested.rs @@ -0,0 +1,68 @@ +#![cfg(feature = 
"tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::openai_chat_completions_client::OpenAIChatCompletionsClient; +use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; +use reqwest::Client; +use serde_json::Value; +use serde_json::json; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn qwen3_openai_streaming_emits_usage_when_requested() -> Result<()> { + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + let openai_client = OpenAIChatCompletionsClient::new( + Client::new(), + &cluster.addresses.compat_openai_base_url()?, + )?; + + let chunks = openai_client + .post_streaming(&json!({ + "model": "qwen3-test", + "messages": [{"role": "user", "content": "Say hi briefly."}], + "stream": true, + "stream_options": {"include_usage": true}, + "max_completion_tokens": 80 + })) + .await?; + + let last_chunk = chunks + .last() + .ok_or_else(|| anyhow::anyhow!("no chunks received from streaming endpoint"))?; + + let usage = last_chunk + .get("usage") + .ok_or_else(|| anyhow::anyhow!("trailing chunk lacks usage field: {last_chunk}"))?; + + let prompt_tokens = usage + .get("prompt_tokens") + .and_then(Value::as_u64) + .ok_or_else(|| anyhow::anyhow!("usage.prompt_tokens missing or not u64"))?; + let completion_tokens = usage + .get("completion_tokens") + .and_then(Value::as_u64) + .ok_or_else(|| anyhow::anyhow!("usage.completion_tokens missing or not u64"))?; + let total_tokens = usage + .get("total_tokens") + .and_then(Value::as_u64) + .ok_or_else(|| anyhow::anyhow!("usage.total_tokens missing or not u64"))?; + + assert!(prompt_tokens > 0); + assert!(completion_tokens > 0); + assert_eq!(total_tokens, prompt_tokens + completion_tokens); + + let trailing_choices = last_chunk + .get("choices") + .and_then(Value::as_array) + .ok_or_else(|| anyhow::anyhow!("trailing chunk lacks choices array"))?; + + assert!( + trailing_choices.is_empty(), + "OpenAI usage chunk must have empty choices array, got: {trailing_choices:?}" + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/qwen3_openai_streaming_omits_usage_when_not_requested.rs b/paddler_tests/tests/qwen3_openai_streaming_omits_usage_when_not_requested.rs new file mode 100644 index 00000000..1e7e7d4c --- /dev/null +++ b/paddler_tests/tests/qwen3_openai_streaming_omits_usage_when_not_requested.rs @@ -0,0 +1,43 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::openai_chat_completions_client::OpenAIChatCompletionsClient; +use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; +use reqwest::Client; +use serde_json::json; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn qwen3_openai_streaming_omits_usage_when_not_requested() -> Result<()> { + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + let openai_client = OpenAIChatCompletionsClient::new( + Client::new(), + &cluster.addresses.compat_openai_base_url()?, + )?; + + let chunks = openai_client + .post_streaming(&json!({ + "model": "qwen3-test", + "messages": [{"role": "user", "content": "Say hi briefly."}], + "stream": true, + "max_completion_tokens": 50 + })) + .await?; + + assert!(!chunks.is_empty(), "expected at least one chunk"); + + 
let chunks_with_usage = chunks + .iter() + .filter(|chunk| chunk.get("usage").is_some()) + .count(); + + assert_eq!( + chunks_with_usage, 0, + "expected no usage chunks when stream_options.include_usage is absent, got {chunks_with_usage}" + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/qwen3_openai_streaming_routes_reasoning_to_reasoning_content.rs b/paddler_tests/tests/qwen3_openai_streaming_routes_reasoning_to_reasoning_content.rs new file mode 100644 index 00000000..2fd50824 --- /dev/null +++ b/paddler_tests/tests/qwen3_openai_streaming_routes_reasoning_to_reasoning_content.rs @@ -0,0 +1,47 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::openai_chat_completions_client::OpenAIChatCompletionsClient; +use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; +use reqwest::Client; +use serde_json::Value; +use serde_json::json; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn qwen3_openai_streaming_routes_reasoning_to_reasoning_content() -> Result<()> { + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + let openai_client = OpenAIChatCompletionsClient::new( + Client::new(), + &cluster.addresses.compat_openai_base_url()?, + )?; + + let chunks = openai_client + .post_streaming(&json!({ + "model": "qwen3-test", + "messages": [{"role": "user", "content": "What is two plus two? Think step by step."}], + "stream": true, + "max_completion_tokens": 600 + })) + .await?; + + let reasoning_chunks = chunks + .iter() + .filter(|chunk| { + chunk + .pointer("/choices/0/delta/reasoning_content") + .and_then(Value::as_str) + .is_some() + }) + .count(); + + assert!( + reasoning_chunks > 0, + "expected at least one delta.reasoning_content chunk; got {reasoning_chunks}" + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/qwen3_openai_streaming_usage_breakdown_with_thinking.rs b/paddler_tests/tests/qwen3_openai_streaming_usage_breakdown_with_thinking.rs new file mode 100644 index 00000000..227586f8 --- /dev/null +++ b/paddler_tests/tests/qwen3_openai_streaming_usage_breakdown_with_thinking.rs @@ -0,0 +1,60 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::openai_chat_completions_client::OpenAIChatCompletionsClient; +use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; +use reqwest::Client; +use serde_json::Value; +use serde_json::json; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn qwen3_openai_streaming_usage_breakdown_with_thinking() -> Result<()> { + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + let openai_client = OpenAIChatCompletionsClient::new( + Client::new(), + &cluster.addresses.compat_openai_base_url()?, + )?; + + let chunks = openai_client + .post_streaming(&json!({ + "model": "qwen3-test", + "messages": [{ + "role": "user", + "content": "What is two plus two? Think briefly before answering." 
+ }], + "stream": true, + "stream_options": {"include_usage": true}, + "max_completion_tokens": 200, + "chat_template_kwargs": {"enable_thinking": true} + })) + .await?; + + let usage_chunk = chunks + .iter() + .rev() + .find(|chunk| chunk.get("usage").is_some_and(|usage| !usage.is_null())) + .ok_or_else(|| anyhow::anyhow!("no chunk contained usage information"))?; + + let prompt_tokens = usage_chunk + .pointer("/usage/prompt_tokens") + .and_then(Value::as_u64) + .ok_or_else(|| anyhow::anyhow!("usage chunk missing prompt_tokens"))?; + let completion_tokens = usage_chunk + .pointer("/usage/completion_tokens") + .and_then(Value::as_u64) + .ok_or_else(|| anyhow::anyhow!("usage chunk missing completion_tokens"))?; + let total_tokens = usage_chunk + .pointer("/usage/total_tokens") + .and_then(Value::as_u64) + .ok_or_else(|| anyhow::anyhow!("usage chunk missing total_tokens"))?; + + assert!(prompt_tokens > 0); + assert!(completion_tokens > 0); + assert_eq!(total_tokens, prompt_tokens + completion_tokens); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/qwen3_without_grammar_generates_unconstrained_output.rs b/paddler_tests/tests/qwen3_without_grammar_generates_unconstrained_output.rs index 304f4f57..783ffcdc 100644 --- a/paddler_tests/tests/qwen3_without_grammar_generates_unconstrained_output.rs +++ b/paddler_tests/tests/qwen3_without_grammar_generates_unconstrained_output.rs @@ -1,17 +1,17 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_types::generated_token_result::GeneratedTokenResult; use paddler_types::request_params::ContinueFromRawPromptParams; use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_without_grammar_generates_unconstrained_output() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(1).await?; + let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -29,7 +29,7 @@ async fn qwen3_without_grammar_generates_unconstrained_output() -> Result<()> { let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!(token_count > 0); diff --git a/paddler_tests/tests/smolvlm2_generates_tokens_from_image_input.rs b/paddler_tests/tests/smolvlm2_generates_tokens_from_image_input.rs index 380bd029..bbfe0466 100644 --- a/paddler_tests/tests/smolvlm2_generates_tokens_from_image_input.rs +++ b/paddler_tests/tests/smolvlm2_generates_tokens_from_image_input.rs @@ -1,10 +1,12 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; use paddler_tests::collect_generated_tokens::collect_generated_tokens; use paddler_tests::inference_http_client::InferenceHttpClient; use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; use paddler_tests::start_in_process_cluster_with_smolvlm2::start_in_process_cluster_with_smolvlm2; +use paddler_tests::token_result_with_producer::TokenResultWithProducer; use 
paddler_types::conversation_history::ConversationHistory; use paddler_types::conversation_message::ConversationMessage; use paddler_types::conversation_message_content::ConversationMessageContent; @@ -17,7 +19,7 @@ use reqwest::Client; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn smolvlm2_generates_tokens_from_image_input() -> Result<()> { - let cluster = start_in_process_cluster_with_smolvlm2(1).await?; + let cluster = start_in_process_cluster_with_smolvlm2(AgentConfig::single(1)).await?; let inference_client = InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); @@ -45,6 +47,7 @@ async fn smolvlm2_generates_tokens_from_image_input() -> Result<()> { enable_thinking: false, grammar: None, max_tokens: 200, + parse_tool_calls: false, tools: vec![], }) .await?; @@ -54,13 +57,16 @@ async fn smolvlm2_generates_tokens_from_image_input() -> Result<()> { let token_count = collected .token_results .iter() - .filter(|result| matches!(result, GeneratedTokenResult::Token(_))) + .filter(|result| result.token_result.is_token()) .count(); assert!(token_count > 0); assert!(matches!( collected.token_results.last(), - Some(GeneratedTokenResult::Done) + Some(TokenResultWithProducer { + token_result: GeneratedTokenResult::Done(_), + .. + }) )); cluster.shutdown().await?; diff --git a/paddler_tests/tests/subprocess_cluster_starts_four_agents_within_sequential_spawn_budget.rs b/paddler_tests/tests/subprocess_cluster_starts_four_agents_within_sequential_spawn_budget.rs new file mode 100644 index 00000000..6057a314 --- /dev/null +++ b/paddler_tests/tests/subprocess_cluster_starts_four_agents_within_sequential_spawn_budget.rs @@ -0,0 +1,60 @@ +#![cfg(all( + feature = "tests_that_use_compiled_paddler", + feature = "tests_that_use_llms" +))] + +use std::time::Duration; +use std::time::Instant; + +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; +use paddler_tests::start_subprocess_cluster_with_qwen3_embedding::start_subprocess_cluster_with_qwen3_embedding; +use paddler_types::inference_parameters::InferenceParameters; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn subprocess_cluster_starts_four_agents_within_sequential_spawn_budget() -> Result<()> { + let agent_count: usize = 4; + let single_agent_init_budget = Duration::from_secs(8); + let cluster_overhead_budget = Duration::from_secs(8); + #[expect( + clippy::cast_possible_truncation, + reason = "agent_count is a fixed test constant that fits in u32" + )] + let cluster_startup_budget = + single_agent_init_budget * (agent_count as u32) + cluster_overhead_budget; + + let cluster_startup_started_at = Instant::now(); + + let cluster = start_subprocess_cluster_with_qwen3_embedding(Qwen3EmbeddingClusterParams { + agents: AgentConfig::uniform(agent_count, 2), + inference_parameters: InferenceParameters { + enable_embeddings: true, + ..InferenceParameters::default() + }, + ..Qwen3EmbeddingClusterParams::default() + }) + .await?; + + let cluster_startup_elapsed = cluster_startup_started_at.elapsed(); + + assert_eq!( + cluster.agent_ids.len(), + agent_count, + "expected {agent_count} agents to register; got {actual}", + actual = cluster.agent_ids.len(), + ); + + cluster.shutdown().await?; + + assert!( + cluster_startup_elapsed <= cluster_startup_budget, + "cluster startup took 
{cluster_startup_elapsed:?}, expected within {cluster_startup_budget:?}. \ + Under concurrent agent spawn on Metal, kernel-compile contention can starve a single agent \ + for 60-120s. Sequential spawn isolates each agent's Metal init and keeps total startup \ + within {single_agent_init_budget:?} per agent plus {cluster_overhead_budget:?} of overhead." + ); + + Ok(()) +} diff --git a/paddler_types/Cargo.toml b/paddler_types/Cargo.toml index 765c9748..e6432612 100644 --- a/paddler_types/Cargo.toml +++ b/paddler_types/Cargo.toml @@ -11,8 +11,10 @@ version.workspace = true [dependencies] anyhow = { workspace = true } jsonschema = { workspace = true } +llama-cpp-bindings-types = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } +thiserror = { workspace = true } [lints] workspace = true diff --git a/paddler_types/src/embedding_result.rs b/paddler_types/src/embedding_result.rs index b69cc698..caf1a37a 100644 --- a/paddler_types/src/embedding_result.rs +++ b/paddler_types/src/embedding_result.rs @@ -2,19 +2,31 @@ use serde::Deserialize; use serde::Serialize; use crate::embedding::Embedding; +use crate::oversized_embedding_document_details::OversizedEmbeddingDocumentDetails; use crate::streamable_result::StreamableResult; #[derive(Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields)] pub enum EmbeddingResult { + DocumentExceedsBatchSize(OversizedEmbeddingDocumentDetails), Done, Embedding(Embedding), + EmbeddingsDisabled, Error(String), + EmbeddingRejectedDueToActiveTokenGeneration, + NoEmbeddingsProduced, } impl StreamableResult for EmbeddingResult { fn is_done(&self) -> bool { - matches!(self, Self::Done | Self::Error(_)) + matches!( + self, + Self::Done + | Self::EmbeddingsDisabled + | Self::Error(_) + | Self::EmbeddingRejectedDueToActiveTokenGeneration + | Self::NoEmbeddingsProduced, + ) } } @@ -34,6 +46,32 @@ mod tests { assert!(EmbeddingResult::Error("fail".to_owned()).is_done()); } + #[test] + fn embeddings_disabled_is_done() { + assert!(EmbeddingResult::EmbeddingsDisabled.is_done()); + } + + #[test] + fn embedding_rejected_due_to_active_token_generation_is_done() { + assert!(EmbeddingResult::EmbeddingRejectedDueToActiveTokenGeneration.is_done()); + } + + #[test] + fn no_embeddings_produced_is_done() { + assert!(EmbeddingResult::NoEmbeddingsProduced.is_done()); + } + + #[test] + fn document_exceeds_batch_size_is_not_done() { + let result = EmbeddingResult::DocumentExceedsBatchSize(OversizedEmbeddingDocumentDetails { + document_tokens: 4096, + n_batch: 2048, + source_document_id: "huge".to_owned(), + }); + + assert!(!result.is_done()); + } + #[test] fn embedding_is_not_done() { let result = EmbeddingResult::Embedding(Embedding { diff --git a/paddler_types/src/generated_token_result.rs b/paddler_types/src/generated_token_result.rs index 52a4109e..63abed30 100644 --- a/paddler_types/src/generated_token_result.rs +++ b/paddler_types/src/generated_token_result.rs @@ -1,21 +1,72 @@ use serde::Deserialize; use serde::Serialize; +use llama_cpp_bindings_types::ParsedToolCall; + +use crate::generation_summary::GenerationSummary; +use crate::oversized_image_details::OversizedImageDetails; +use crate::raw_tool_call_tokens::RawToolCallTokens; use crate::streamable_result::StreamableResult; #[derive(Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields)] pub enum GeneratedTokenResult { ChatTemplateError(String), - Done, + ContentToken(String), + Done(GenerationSummary), GrammarIncompatibleWithThinking(String), GrammarInitializationFailed(String), 
GrammarRejectedModelOutput(String), GrammarSyntaxError(String), ImageDecodingFailed(String), + ImageExceedsBatchSize(OversizedImageDetails), MultimodalNotSupported(String), + ReasoningToken(String), SamplerError(String), - Token(String), + ToolCallParseFailed(String), + ToolCallParsed(Vec<ParsedToolCall>), + ToolCallToken(String), + ToolCallValidationFailed(Vec<String>), + ToolSchemaInvalid(String), + UndeterminableToken(String), + UnrecognizedToolCallFormat(RawToolCallTokens), +} + +impl GeneratedTokenResult { + #[must_use] + pub const fn is_token(&self) -> bool { + matches!( + self, + Self::ContentToken(_) + | Self::ReasoningToken(_) + | Self::ToolCallToken(_) + | Self::UndeterminableToken(_) + ) + } + + #[must_use] + pub fn token_text(&self) -> Option<&str> { + match self { + Self::ContentToken(text) + | Self::ReasoningToken(text) + | Self::ToolCallToken(text) + | Self::UndeterminableToken(text) => Some(text), + _ => None, + } + } + + #[must_use] + pub const fn is_tool_call_parsed(&self) -> bool { + matches!(self, Self::ToolCallParsed(_)) + } + + #[must_use] + pub const fn is_tool_call_failure(&self) -> bool { + matches!( + self, + Self::ToolCallParseFailed(_) | Self::ToolCallValidationFailed(_) + ) + } } impl StreamableResult for GeneratedTokenResult { @@ -23,14 +74,16 @@ impl StreamableResult for GeneratedTokenResult { matches!( self, Self::ChatTemplateError(_) - | Self::Done + | Self::Done(_) | Self::GrammarIncompatibleWithThinking(_) | Self::GrammarInitializationFailed(_) | Self::GrammarRejectedModelOutput(_) | Self::GrammarSyntaxError(_) | Self::ImageDecodingFailed(_) + | Self::ImageExceedsBatchSize(_) | Self::MultimodalNotSupported(_) | Self::SamplerError(_) + | Self::ToolSchemaInvalid(_) ) } } @@ -41,7 +94,7 @@ mod tests { #[test] fn done_is_done() { - assert!(GeneratedTokenResult::Done.is_done()); + assert!(GeneratedTokenResult::Done(GenerationSummary::default()).is_done()); } #[test] @@ -74,6 +127,18 @@ mod tests { assert!(GeneratedTokenResult::ImageDecodingFailed("err".to_owned()).is_done()); } + #[test] + fn image_exceeds_batch_size_is_done_and_not_classified_as_token() { + let event = GeneratedTokenResult::ImageExceedsBatchSize(OversizedImageDetails { + image_tokens: 368, + n_batch: 100, + }); + + assert!(event.is_done()); + assert!(!event.is_token()); + assert!(event.token_text().is_none()); + } + #[test] fn multimodal_not_supported_is_done() { assert!(GeneratedTokenResult::MultimodalNotSupported("err".to_owned()).is_done()); @@ -85,7 +150,63 @@ mod tests { } #[test] - fn token_is_not_done() { - assert!(!GeneratedTokenResult::Token("hello".to_owned()).is_done()); + fn tool_schema_invalid_is_done() { + assert!(GeneratedTokenResult::ToolSchemaInvalid("invalid schema".to_owned()).is_done()); + } + + #[test] + fn content_token_is_not_done() { + assert!(!GeneratedTokenResult::ContentToken("hello".to_owned()).is_done()); + } + + #[test] + fn reasoning_token_is_not_done() { + assert!(!GeneratedTokenResult::ReasoningToken("thinking".to_owned()).is_done()); + } + + #[test] + fn undeterminable_token_is_not_done() { + assert!(!GeneratedTokenResult::UndeterminableToken("ambiguous".to_owned()).is_done()); + } + + #[test] + fn tool_call_parsed_is_not_done() { + let event = GeneratedTokenResult::ToolCallParsed(vec![]); + + assert!(!event.is_done()); + assert!(event.is_tool_call_parsed()); + assert!(!event.is_tool_call_failure()); + } + + #[test] + fn tool_call_parse_failed_is_failure_but_not_done() { + let event = GeneratedTokenResult::ToolCallParseFailed("oops".to_owned()); + + assert!(!event.is_done()); + 
assert!(!event.is_tool_call_parsed()); + assert!(event.is_tool_call_failure()); + } + + #[test] + fn tool_call_validation_failed_is_failure_but_not_done() { + let event = GeneratedTokenResult::ToolCallValidationFailed(vec!["missing".to_owned()]); + + assert!(!event.is_done()); + assert!(!event.is_tool_call_parsed()); + assert!(event.is_tool_call_failure()); + } + + #[test] + fn unrecognized_tool_call_format_is_not_done_and_not_classified_as_token() { + let event = GeneratedTokenResult::UnrecognizedToolCallFormat(RawToolCallTokens { + text: "raw output".to_owned(), + ffi_error_message: "parser bailed".to_owned(), + }); + + assert!(!event.is_done()); + assert!(!event.is_token()); + assert!(event.token_text().is_none()); + assert!(!event.is_tool_call_parsed()); + assert!(!event.is_tool_call_failure()); } } diff --git a/paddler_types/src/generation_summary.rs b/paddler_types/src/generation_summary.rs new file mode 100644 index 00000000..8592077c --- /dev/null +++ b/paddler_types/src/generation_summary.rs @@ -0,0 +1,10 @@ +use serde::Deserialize; +use serde::Serialize; + +use llama_cpp_bindings_types::TokenUsage; + +#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] +pub struct GenerationSummary { + pub usage: TokenUsage, +} diff --git a/paddler_types/src/inference_parameters.rs b/paddler_types/src/inference_parameters.rs index dd86047b..63e7a962 100644 --- a/paddler_types/src/inference_parameters.rs +++ b/paddler_types/src/inference_parameters.rs @@ -10,8 +10,9 @@ use crate::validates::Validates; #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] #[serde(deny_unknown_fields)] pub struct InferenceParameters { - pub batch_n_tokens: usize, + pub n_batch: usize, pub context_size: u32, + pub embedding_batch_size: usize, pub enable_embeddings: bool, pub image_resize_to_fit: u32, pub k_cache_dtype: KvCacheDtype, @@ -42,6 +43,10 @@ impl Validates for InferenceParameters { bail!("image_resize_to_fit must be greater than zero"); } + if self.embedding_batch_size == 0 { + bail!("embedding_batch_size must be greater than zero"); + } + Ok(self) } } @@ -49,8 +54,9 @@ impl Validates for InferenceParameters { impl Default for InferenceParameters { fn default() -> Self { Self { - batch_n_tokens: 512, + n_batch: 2048, context_size: 8192, + embedding_batch_size: 256, enable_embeddings: false, image_resize_to_fit: 1024, k_cache_dtype: KvCacheDtype::Q8_0, @@ -89,4 +95,21 @@ mod tests { assert!(params.validate().is_err()); } + + #[test] + fn validate_fails_when_embedding_batch_size_is_zero() { + let params = InferenceParameters { + embedding_batch_size: 0, + ..InferenceParameters::default() + }; + + assert!(params.validate().is_err()); + } + + #[test] + fn default_embedding_batch_size_is_256() { + let params = InferenceParameters::default(); + + assert_eq!(params.embedding_batch_size, 256); + } } diff --git a/paddler_types/src/inference_server/request.rs b/paddler_types/src/inference_server/request.rs index c6faab84..ba47597c 100644 --- a/paddler_types/src/inference_server/request.rs +++ b/paddler_types/src/inference_server/request.rs @@ -1,8 +1,8 @@ use serde::Deserialize; use serde::Serialize; -use crate::request_params::ContinueFromConversationHistoryParams; use crate::request_params::ContinueFromRawPromptParams; +use crate::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; #[derive(Deserialize, Serialize)] #[serde(deny_unknown_fields)] diff --git a/paddler_types/src/jsonrpc/error.rs 
b/paddler_types/src/jsonrpc/error.rs index 9396d0c9..0432e9eb 100644 --- a/paddler_types/src/jsonrpc/error.rs +++ b/paddler_types/src/jsonrpc/error.rs @@ -5,7 +5,7 @@ use std::fmt::Formatter; use serde::Deserialize; use serde::Serialize; -#[derive(Debug, Deserialize, Serialize)] +#[derive(Clone, Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields)] pub struct Error { pub code: i32, diff --git a/paddler_types/src/jsonrpc/response_envelope.rs b/paddler_types/src/jsonrpc/response_envelope.rs index 27b3458f..00818344 100644 --- a/paddler_types/src/jsonrpc/response_envelope.rs +++ b/paddler_types/src/jsonrpc/response_envelope.rs @@ -4,6 +4,7 @@ use serde::Serialize; #[derive(Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields)] pub struct ResponseEnvelope<TResponse> { + pub generated_by: Option<String>, pub request_id: String, pub response: TResponse, } diff --git a/paddler_types/src/lib.rs b/paddler_types/src/lib.rs index 6fd11797..e2a333e0 100644 --- a/paddler_types/src/lib.rs +++ b/paddler_types/src/lib.rs @@ -21,6 +21,7 @@ pub mod embedding_input_document; pub mod embedding_normalization_method; pub mod embedding_result; pub mod generated_token_result; +pub mod generation_summary; pub mod grammar_constraint; pub mod huggingface_model_reference; pub mod image_url; @@ -32,7 +33,10 @@ pub mod kv_cache_dtype; pub mod media_marker; pub mod model_metadata; pub mod normalization; +pub mod oversized_embedding_document_details; +pub mod oversized_image_details; pub mod pooling_type; +pub mod raw_tool_call_tokens; pub mod request_params; pub mod rpc_message; pub mod slot_aggregated_status_snapshot; diff --git a/paddler_types/src/oversized_embedding_document_details.rs b/paddler_types/src/oversized_embedding_document_details.rs new file mode 100644 index 00000000..b2fc4096 --- /dev/null +++ b/paddler_types/src/oversized_embedding_document_details.rs @@ -0,0 +1,10 @@ +use serde::Deserialize; +use serde::Serialize; + +#[derive(Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] +pub struct OversizedEmbeddingDocumentDetails { + pub document_tokens: u32, + pub n_batch: u32, + pub source_document_id: String, +} diff --git a/paddler_types/src/oversized_image_details.rs b/paddler_types/src/oversized_image_details.rs new file mode 100644 index 00000000..08240d75 --- /dev/null +++ b/paddler_types/src/oversized_image_details.rs @@ -0,0 +1,9 @@ +use serde::Deserialize; +use serde::Serialize; + +#[derive(Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] +pub struct OversizedImageDetails { + pub image_tokens: u32, + pub n_batch: u32, +} diff --git a/paddler_types/src/raw_tool_call_tokens.rs b/paddler_types/src/raw_tool_call_tokens.rs new file mode 100644 index 00000000..88f39ebf --- /dev/null +++ b/paddler_types/src/raw_tool_call_tokens.rs @@ -0,0 +1,25 @@ +use serde::Deserialize; +use serde::Serialize; + +#[derive(Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] +pub struct RawToolCallTokens { + pub text: String, + pub ffi_error_message: String, +} + +#[cfg(test)] +mod tests { + use super::RawToolCallTokens; + + #[test] + fn carries_text_and_ffi_error_message() { + let tokens = RawToolCallTokens { + text: "raw payload".to_owned(), + ffi_error_message: "parser bailed".to_owned(), + }; + + assert_eq!(tokens.text, "raw payload"); + assert_eq!(tokens.ffi_error_message, "parser bailed"); + } +} diff --git a/paddler_types/src/request_params/continue_from_conversation_history_params/mod.rs b/paddler_types/src/request_params/continue_from_conversation_history_params/mod.rs index
64fbefcb..b396bff5 100644 --- a/paddler_types/src/request_params/continue_from_conversation_history_params/mod.rs +++ b/paddler_types/src/request_params/continue_from_conversation_history_params/mod.rs @@ -22,6 +22,8 @@ pub struct ContinueFromConversationHistoryParams { pub grammar: Option, pub max_tokens: i32, #[serde(default)] + pub parse_tool_calls: bool, + #[serde(default)] pub tools: Vec>, } @@ -35,6 +37,7 @@ impl Validates> enable_thinking: self.enable_thinking, grammar: self.grammar, max_tokens: self.max_tokens, + parse_tool_calls: self.parse_tool_calls, tools: self .tools .into_iter() diff --git a/paddler_types/src/request_params/continue_from_conversation_history_params/tool/mod.rs b/paddler_types/src/request_params/continue_from_conversation_history_params/tool/mod.rs index f0442641..af591de7 100644 --- a/paddler_types/src/request_params/continue_from_conversation_history_params/tool/mod.rs +++ b/paddler_types/src/request_params/continue_from_conversation_history_params/tool/mod.rs @@ -4,10 +4,10 @@ use anyhow::Result; use serde::Deserialize; use serde::Serialize; -use self::tool_params::FunctionCall; -use crate::validates::Validates; +use self::tool_params::function_call::FunctionCall; use crate::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::raw_parameters_schema::RawParametersSchema; use crate::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; +use crate::validates::Validates; #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] #[serde(deny_unknown_fields)] diff --git a/paddler_types/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/parameters_schema/raw_parameters_schema.rs b/paddler_types/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/parameters_schema/raw_parameters_schema.rs index 4fb4b907..023a6680 100644 --- a/paddler_types/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/parameters_schema/raw_parameters_schema.rs +++ b/paddler_types/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/parameters_schema/raw_parameters_schema.rs @@ -1,6 +1,5 @@ use anyhow::Result; use anyhow::anyhow; -use jsonschema::validator_for; use serde::Deserialize; use serde::Serialize; use serde_json::Map; @@ -9,12 +8,6 @@ use serde_json::Value; use super::validated_parameters_schema::ValidatedParametersSchema; use crate::validates::Validates; -fn validate_schema(schema: &Value) -> Result<()> { - validator_for(schema).map_err(|err| anyhow!("{err}"))?; - - Ok(()) -} - #[derive(Default, Deserialize, Serialize)] #[serde(deny_unknown_fields)] pub struct RawParametersSchema { @@ -36,20 +29,6 @@ impl Validates for RawParametersSchema { } } - if let Some(ref properties) = self.properties { - for (key, schema) in properties { - validate_schema(schema) - .map_err(|err| anyhow!("Invalid schema for property '{key}': {err}"))?; - } - } - - if let Some(ref additional) = self.additional_properties - && !additional.is_boolean() - { - validate_schema(additional) - .map_err(|err| anyhow!("Invalid additionalProperties schema: {err}"))?; - } - Ok(ValidatedParametersSchema { schema_type: self.schema_type, properties: self.properties, @@ -96,31 +75,6 @@ mod tests { Ok(()) } - #[test] - fn test_deserialize_with_invalid_property_schema() -> Result<()> { - let 
input = json!({ - "type": "object", - "properties": { - "name": {"type": "invalid_type"} - } - }); - - let raw_schema: RawParametersSchema = from_value(input)?; - let result: Result<ValidatedParametersSchema> = raw_schema.validate(); - - assert!(result.is_err()); - - if let Err(error) = &result { - assert!( - error - .to_string() - .contains("Invalid schema for property 'name'") - ); - } - - Ok(()) - } - #[test] fn test_deserialize_required_field_not_in_properties() -> Result<()> { let input = json!({ @@ -146,27 +100,4 @@ Ok(()) } - - #[test] - fn test_deserialize_invalid_additional_properties_schema() -> Result<()> { - let input = json!({ - "type": "object", - "additionalProperties": {"type": "not_a_type"} - }); - - let raw_schema: RawParametersSchema = from_value(input)?; - let result: Result<ValidatedParametersSchema> = raw_schema.validate(); - - assert!(result.is_err()); - - if let Err(error) = &result { - assert!( - error - .to_string() - .contains("Invalid additionalProperties schema") - ); - } - - Ok(()) - } } diff --git a/paddler_types/src/request_params/continue_from_conversation_history_params/tool/tool_params/mod.rs b/paddler_types/src/request_params/continue_from_conversation_history_params/tool/tool_params/mod.rs index d60596c3..15415c17 100644 --- a/paddler_types/src/request_params/continue_from_conversation_history_params/tool/tool_params/mod.rs +++ b/paddler_types/src/request_params/continue_from_conversation_history_params/tool/tool_params/mod.rs @@ -1,3 +1 @@ pub mod function_call; - -pub use function_call::FunctionCall; diff --git a/paddler_types/src/request_params/generate_embedding_batch_params/chunk_by_input_size_iter.rs b/paddler_types/src/request_params/generate_embedding_batch_params/chunk_by_input_size_iter.rs deleted file mode 100644 index c793da06..00000000 --- a/paddler_types/src/request_params/generate_embedding_batch_params/chunk_by_input_size_iter.rs +++ /dev/null @@ -1,45 +0,0 @@ -use super::GenerateEmbeddingBatchParams; -use crate::embedding_input_document::EmbeddingInputDocument; -use crate::embedding_normalization_method::EmbeddingNormalizationMethod; - -pub struct ChunkByInputSizeIter<'embedding_batch> { - pub chunk_size: usize, - pub current_index: usize, - pub input_batch: &'embedding_batch [EmbeddingInputDocument], - pub normalization_method: &'embedding_batch EmbeddingNormalizationMethod, -} - -impl Iterator for ChunkByInputSizeIter<'_> { - type Item = GenerateEmbeddingBatchParams; - - fn next(&mut self) -> Option<Self::Item> { - if self.current_index >= self.input_batch.len() { - return None; - } - - let mut current_batch = Vec::new(); - let mut current_size = 0; - - while self.current_index < self.input_batch.len() { - let input = &self.input_batch[self.current_index]; - let input_size = input.content.chars().count(); - - if current_size + input_size > self.chunk_size && !current_batch.is_empty() { - break; - } - - current_batch.push(input.clone()); - current_size += input_size; - self.current_index += 1; - } - - if current_batch.is_empty() { - None - } else { - Some(GenerateEmbeddingBatchParams { - input_batch: current_batch, - normalization_method: self.normalization_method.clone(), - }) - } - } -} diff --git a/paddler_types/src/request_params/generate_embedding_batch_params/chunk_evenly_with_cap_error.rs b/paddler_types/src/request_params/generate_embedding_batch_params/chunk_evenly_with_cap_error.rs new file mode 100644 index 00000000..0e42e2c6 --- /dev/null +++ b/paddler_types/src/request_params/generate_embedding_batch_params/chunk_evenly_with_cap_error.rs @@ -0,0 +1,7 @@ +#[derive(Debug,
thiserror::Error)] +pub enum ChunkEvenlyWithCapError { + #[error("agent_count must be non-zero")] + ZeroAgentCount, + #[error("max_documents_per_chunk must be non-zero")] + ZeroMaxDocumentsPerChunk, +} diff --git a/paddler_types/src/request_params/generate_embedding_batch_params/mod.rs b/paddler_types/src/request_params/generate_embedding_batch_params/mod.rs index 58b4a1e0..b9a0fc42 100644 --- a/paddler_types/src/request_params/generate_embedding_batch_params/mod.rs +++ b/paddler_types/src/request_params/generate_embedding_batch_params/mod.rs @@ -1,9 +1,9 @@ -mod chunk_by_input_size_iter; +mod chunk_evenly_with_cap_error; use serde::Deserialize; use serde::Serialize; -use self::chunk_by_input_size_iter::ChunkByInputSizeIter; +pub use self::chunk_evenly_with_cap_error::ChunkEvenlyWithCapError; use crate::embedding_input_document::EmbeddingInputDocument; use crate::embedding_normalization_method::EmbeddingNormalizationMethod; @@ -15,20 +15,58 @@ } impl GenerateEmbeddingBatchParams { - /// Input size is the total number of characters in the resulting batches. - #[must_use] - pub fn chunk_by_input_size(&self, chunk_size: usize) -> ChunkByInputSizeIter<'_> { - ChunkByInputSizeIter { - input_batch: &self.input_batch, - normalization_method: &self.normalization_method, - chunk_size, - current_index: 0, + /// Splits the batch into one sub-batch per agent, creating additional + /// sub-batches only when a sub-batch would otherwise exceed + /// `max_documents_per_chunk` documents. + pub fn chunk_evenly_with_cap( + &self, + agent_count: usize, + max_documents_per_chunk: usize, + ) -> Result<Vec<Self>, ChunkEvenlyWithCapError> { + if agent_count == 0 { + return Err(ChunkEvenlyWithCapError::ZeroAgentCount); } + if max_documents_per_chunk == 0 { + return Err(ChunkEvenlyWithCapError::ZeroMaxDocumentsPerChunk); + } + + let document_count = self.input_batch.len(); + + if document_count == 0 { + return Ok(Vec::new()); + } + + // div_ceil rounds up, so there are always enough chunks that none has to exceed the cap. + let chunks_to_honor_cap = document_count.div_ceil(max_documents_per_chunk); + let chunk_count = document_count.min(agent_count.max(chunks_to_honor_cap)); + + // Distribute evenly: the first `remainder` chunks take one extra document. + let quotient = document_count / chunk_count; + let remainder = document_count % chunk_count; + + let mut sub_batches = Vec::with_capacity(chunk_count); + let mut start_index = 0; + + for chunk_index in 0..chunk_count { + let chunk_size = if chunk_index < remainder { + quotient + 1 + } else { + quotient + }; + + let end_index = start_index + chunk_size; + + sub_batches.push(Self { + input_batch: self.input_batch[start_index..end_index].to_vec(), + normalization_method: self.normalization_method.clone(), + }); + + start_index = end_index; + } + + Ok(sub_batches) } } #[cfg(test)] mod tests { + use anyhow::Result; + use super::*; fn make_doc(id: &str, content: &str) -> EmbeddingInputDocument { @@ -45,103 +83,310 @@ } } + fn make_docs(count: usize) -> Vec<EmbeddingInputDocument> { + (0..count) + .map(|index| make_doc(&format!("doc{index:05}"), "x")) + .collect() + } + #[test] - fn test_chunk_by_input_size() { - let params = make_params(vec![ - make_doc("1", "Hello"), - make_doc("2", "World"), - make_doc("3", "This is a test"), - ]); + fn chunk_evenly_with_cap_empty_input() -> Result<()> { + let params = make_params(vec![]); + + let sub_batches = params.chunk_evenly_with_cap(4, 256)?; - let batches = params.chunk_by_input_size(10).collect::<Vec<_>>(); + assert!(sub_batches.is_empty()); - assert_eq!(batches.len(), 2); - assert_eq!(batches[0].input_batch.len(), 2); - assert_eq!(batches[0].input_batch[0].id, "1"); - assert_eq!(batches[0].input_batch[1].id, "2"); - assert_eq!(batches[1].input_batch.len(), 1); - assert_eq!(batches[1].input_batch[0].id, "3"); + Ok(()) } #[test] - fn
test_chunk_empty_batch() { - let params = make_params(vec![]); + fn chunk_evenly_with_cap_single_doc_single_agent() -> Result<()> { + let params = make_params(vec![make_doc("only", "content")]); + + let sub_batches = params.chunk_evenly_with_cap(1, 256)?; + + assert_eq!(sub_batches.len(), 1); + assert_eq!(sub_batches[0].input_batch.len(), 1); + assert_eq!(sub_batches[0].input_batch[0].id, "only"); + + Ok(()) + } + + #[test] + fn chunk_evenly_with_cap_single_doc_many_agents() -> Result<()> { + let params = make_params(vec![make_doc("only", "content")]); + + let sub_batches = params.chunk_evenly_with_cap(5, 256)?; + + assert_eq!(sub_batches.len(), 1); + assert_eq!(sub_batches[0].input_batch.len(), 1); + assert_eq!(sub_batches[0].input_batch[0].id, "only"); - assert!(params.chunk_by_input_size(100).next().is_none()); + Ok(()) } #[test] - fn test_chunk_single_item_larger_than_chunk_size() { - let params = make_params(vec![make_doc("1", "This content exceeds the chunk limit")]); + fn chunk_evenly_with_cap_more_agents_than_docs_uses_n_chunks() -> Result<()> { + let params = make_params(make_docs(3)); - let batches = params.chunk_by_input_size(5).collect::<Vec<_>>(); + let sub_batches = params.chunk_evenly_with_cap(5, 256)?; - assert_eq!(batches.len(), 1); - assert_eq!(batches[0].input_batch.len(), 1); - assert_eq!(batches[0].input_batch[0].id, "1"); + assert_eq!(sub_batches.len(), 3); + for sub_batch in &sub_batches { + assert_eq!(sub_batch.input_batch.len(), 1); + } + + Ok(()) + } + + #[test] + fn chunk_evenly_with_cap_rejects_zero_agent_count() { + let params = make_params(make_docs(5)); + + let result = params.chunk_evenly_with_cap(0, 256); + + assert!(matches!( + result, + Err(ChunkEvenlyWithCapError::ZeroAgentCount) + )); } #[test] - fn test_chunk_oversized_item_does_not_merge_with_next() { - let params = make_params(vec![ - make_doc("1", "This is way too long for chunk"), - make_doc("2", "Short"), - ]); + fn chunk_evenly_with_cap_rejects_zero_max_documents_per_chunk() { + let params = make_params(make_docs(4)); - let batches = params.chunk_by_input_size(5).collect::<Vec<_>>(); + let result = params.chunk_evenly_with_cap(2, 0); - assert_eq!(batches.len(), 2); - assert_eq!(batches[0].input_batch[0].id, "1"); - assert_eq!(batches[1].input_batch[0].id, "2"); + assert!(matches!( + result, + Err(ChunkEvenlyWithCapError::ZeroMaxDocumentsPerChunk) + )); } #[test] - fn test_chunk_exact_fit() { - let params = make_params(vec![make_doc("1", "12345"), make_doc("2", "67890")]); + fn chunk_evenly_with_cap_below_cap_splits_per_agent() -> Result<()> { + let params = make_params(make_docs(4)); - // 5 + 5 = 10, exactly the chunk size - let batches = params.chunk_by_input_size(10).collect::<Vec<_>>(); + let sub_batches = params.chunk_evenly_with_cap(4, 256)?; - assert_eq!(batches.len(), 1); - assert_eq!(batches[0].input_batch.len(), 2); + assert_eq!(sub_batches.len(), 4); + for sub_batch in &sub_batches { + assert_eq!(sub_batch.input_batch.len(), 1); + } + + Ok(()) } #[test] - fn test_chunk_one_over_limit_splits() { - let params = make_params(vec![make_doc("1", "12345"), make_doc("2", "678901")]); + fn chunk_evenly_with_cap_below_cap_uneven_split() -> Result<()> { + let params = make_params(make_docs(11)); + + let sub_batches = params.chunk_evenly_with_cap(4, 256)?; - // 5 + 6 = 11, exceeds chunk_size of 10 - let batches = params.chunk_by_input_size(10).collect::<Vec<_>>(); + assert_eq!(sub_batches.len(), 4); + assert_eq!(sub_batches[0].input_batch.len(), 3); + assert_eq!(sub_batches[1].input_batch.len(), 3); +
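// 11 documents over 4 chunks: quotient 2, remainder 3, so the sizes come out 3, 3, 3, 2. +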
assert_eq!(sub_batches[2].input_batch.len(), 3); + assert_eq!(sub_batches[3].input_batch.len(), 2); - assert_eq!(batches.len(), 2); - assert_eq!(batches[0].input_batch.len(), 1); - assert_eq!(batches[1].input_batch.len(), 1); + Ok(()) } #[test] - fn test_chunk_preserves_normalization_method() { + fn chunk_evenly_with_cap_user_example_80_docs_4_agents_cap_100() -> Result<()> { + let params = make_params(make_docs(80)); + + let sub_batches = params.chunk_evenly_with_cap(4, 100)?; + + assert_eq!(sub_batches.len(), 4); + for sub_batch in &sub_batches { + assert_eq!(sub_batch.input_batch.len(), 20); + } + + Ok(()) + } + + #[test] + fn chunk_evenly_with_cap_user_example_1000_docs_4_agents_cap_100() -> Result<()> { + let params = make_params(make_docs(1000)); + + let sub_batches = params.chunk_evenly_with_cap(4, 100)?; + + assert_eq!(sub_batches.len(), 10); + for sub_batch in &sub_batches { + assert_eq!(sub_batch.input_batch.len(), 100); + } + + Ok(()) + } + + #[test] + fn chunk_evenly_with_cap_at_cap_boundary_uses_agent_count() -> Result<()> { + let params = make_params(make_docs(1024)); + + let sub_batches = params.chunk_evenly_with_cap(4, 256)?; + + assert_eq!(sub_batches.len(), 4); + for sub_batch in &sub_batches { + assert_eq!(sub_batch.input_batch.len(), 256); + } + + Ok(()) + } + + #[test] + fn chunk_evenly_with_cap_above_cap_boundary_creates_extra_chunks() -> Result<()> { + let params = make_params(make_docs(2000)); + + let sub_batches = params.chunk_evenly_with_cap(4, 256)?; + + assert_eq!(sub_batches.len(), 8); + for sub_batch in &sub_batches { + assert_eq!(sub_batch.input_batch.len(), 250); + } + + Ok(()) + } + + #[test] + fn chunk_evenly_with_cap_far_above_cap_distributes_evenly() -> Result<()> { + let params = make_params(make_docs(1100)); + + let sub_batches = params.chunk_evenly_with_cap(4, 256)?; + + assert_eq!(sub_batches.len(), 5); + for sub_batch in &sub_batches { + assert_eq!(sub_batch.input_batch.len(), 220); + } + + Ok(()) + } + + #[test] + fn chunk_evenly_with_cap_extreme_large_n_small_cap() -> Result<()> { + let params = make_params(make_docs(10_000)); + + let sub_batches = params.chunk_evenly_with_cap(4, 1)?; + + assert_eq!(sub_batches.len(), 10_000); + for sub_batch in &sub_batches { + assert_eq!(sub_batch.input_batch.len(), 1); + } + + Ok(()) + } + + #[test] + fn chunk_evenly_with_cap_extreme_one_doc_per_chunk() -> Result<()> { + let params = make_params(make_docs(100)); + + let sub_batches = params.chunk_evenly_with_cap(100, 256)?; + + assert_eq!(sub_batches.len(), 100); + for sub_batch in &sub_batches { + assert_eq!(sub_batch.input_batch.len(), 1); + } + + Ok(()) + } + + #[test] + fn chunk_evenly_with_cap_no_sub_batch_exceeds_cap_sweep() -> Result<()> { + let document_counts: Vec<usize> = (0..=50).chain([256, 257, 1000, 2001]).collect(); + let agent_counts: Vec<usize> = (1..=8).collect(); + let caps: Vec<usize> = vec![1, 2, 4, 100, 256]; + + for &document_count in &document_counts { + for &agent_count in &agent_counts { + for &cap in &caps { + let params = make_params(make_docs(document_count)); + + let sub_batches = params.chunk_evenly_with_cap(agent_count, cap)?; + + let total_documents: usize = + sub_batches.iter().map(|sub| sub.input_batch.len()).sum(); + assert_eq!( + total_documents, document_count, + "total documents must equal N (N={document_count}, agents={agent_count}, cap={cap})", + ); + + for sub_batch in &sub_batches { + assert!( + sub_batch.input_batch.len() <= cap, + "sub-batch size {} exceeds cap {} (N={document_count}, agents={agent_count}, cap={cap})", +
sub_batch.input_batch.len(), + cap, + ); + } + + let collected_ids: Vec<String> = sub_batches + .iter() + .flat_map(|sub| sub.input_batch.iter().map(|doc| doc.id.clone())) + .collect(); + let expected_ids: Vec<String> = (0..document_count) + .map(|index| format!("doc{index:05}")) + .collect(); + assert_eq!( + collected_ids, expected_ids, + "concatenated IDs must equal original order (N={document_count}, agents={agent_count}, cap={cap})", + ); + + if document_count > 0 { + assert!( + !sub_batches.is_empty(), + "non-empty input must produce at least one sub-batch (N={document_count}, agents={agent_count}, cap={cap})", + ); + for sub_batch in &sub_batches { + assert!( + !sub_batch.input_batch.is_empty(), + "no sub-batch may be empty (N={document_count}, agents={agent_count}, cap={cap})", + ); + } + } else { + assert!(sub_batches.is_empty(), "empty input must produce empty Vec",); + } + } + } + } + + Ok(()) + } + + #[test] + fn chunk_evenly_with_cap_preserves_normalization_method() -> Result<()> { let params = GenerateEmbeddingBatchParams { - input_batch: vec![make_doc("1", "test")], + input_batch: make_docs(8), normalization_method: EmbeddingNormalizationMethod::L2, }; - let batches = params.chunk_by_input_size(100).collect::<Vec<_>>(); + let sub_batches = params.chunk_evenly_with_cap(4, 256)?; - assert!(matches!( - batches[0].normalization_method, - EmbeddingNormalizationMethod::L2 - )); + assert_eq!(sub_batches.len(), 4); + for sub_batch in &sub_batches { + assert!(matches!( + sub_batch.normalization_method, + EmbeddingNormalizationMethod::L2 + )); + } + + Ok(()) } #[test] - fn test_chunk_counts_unicode_chars_not_bytes() { - // "café" is 4 chars but 5 bytes (é is 2 bytes) - let params = make_params(vec![make_doc("1", "café"), make_doc("2", "naïve")]); + fn chunk_evenly_with_cap_preserves_document_ids_and_order() -> Result<()> { + let params = make_params(make_docs(12)); + + let sub_batches = params.chunk_evenly_with_cap(5, 256)?; + + let collected_ids: Vec<String> = sub_batches + .iter() + .flat_map(|sub| sub.input_batch.iter().map(|doc| doc.id.clone())) + .collect(); + let expected_ids: Vec<String> = (0..12).map(|index| format!("doc{index:05}")).collect(); - // 4 chars + 5 chars = 9, fits in chunk_size of 9 - let batches = params.chunk_by_input_size(9).collect::<Vec<_>>(); + assert_eq!(collected_ids, expected_ids); - assert_eq!(batches.len(), 1); - assert_eq!(batches[0].input_batch.len(), 2); + Ok(()) } } diff --git a/paddler_types/src/request_params/mod.rs b/paddler_types/src/request_params/mod.rs index 8c3691ae..dda6195e 100644 --- a/paddler_types/src/request_params/mod.rs +++ b/paddler_types/src/request_params/mod.rs @@ -2,6 +2,6 @@ pub mod continue_from_conversation_history_params; mod continue_from_raw_prompt_params; mod generate_embedding_batch_params; -pub use continue_from_conversation_history_params::ContinueFromConversationHistoryParams; pub use continue_from_raw_prompt_params::ContinueFromRawPromptParams; +pub use generate_embedding_batch_params::ChunkEvenlyWithCapError; pub use generate_embedding_batch_params::GenerateEmbeddingBatchParams; diff --git a/resources/ts/ConversationMessage.type.ts b/resources/ts/ConversationMessage.type.ts deleted file mode 100644 index accbb3fe..00000000 --- a/resources/ts/ConversationMessage.type.ts +++ /dev/null @@ -1,6 +0,0 @@ -import { type ConversationMessageContentPart } from "./ConversationMessageContentPart.type"; - -export type ConversationMessage = { - role: string; - content: string | ConversationMessageContentPart[]; -}; diff --git
a/resources/ts/ConversationMessageContentPart.type.ts b/resources/ts/ConversationMessageContentPart.type.ts deleted file mode 100644 index b99ec455..00000000 --- a/resources/ts/ConversationMessageContentPart.type.ts +++ /dev/null @@ -1,3 +0,0 @@ -export type ConversationMessageContentPart = - | { type: "text"; text: string } - | { type: "image_url"; image_url: { url: string } }; diff --git a/resources/ts/InferenceSocketClient.interface.ts b/resources/ts/InferenceSocketClient.interface.ts deleted file mode 100644 index 3700aac0..00000000 --- a/resources/ts/InferenceSocketClient.interface.ts +++ /dev/null @@ -1,11 +0,0 @@ -import { type Observable } from "rxjs"; - -import { type ConversationMessage } from "./ConversationMessage.type"; -import { type InferenceServiceGenerateTokensResponse } from "./schemas/InferenceServiceGenerateTokensResponse"; - -export interface InferenceSocketClient { - continueConversation(params: { - enableThinking: boolean; - messages: ConversationMessage[]; - }): Observable; -} diff --git a/resources/ts/components/AgentIssues.tsx b/resources/ts/components/AgentIssues.tsx index ceb723e6..013cdf95 100644 --- a/resources/ts/components/AgentIssues.tsx +++ b/resources/ts/components/AgentIssues.tsx @@ -1,7 +1,7 @@ import React from "react"; import { Link } from "wouter"; -import { type AgentIssue } from "../schemas/AgentIssue"; +import { type AgentIssue } from "@intentee/paddler-client/schemas/AgentIssue"; import { agentIssues, agentIssues__issue } from "./AgentIssues.module.css"; diff --git a/resources/ts/components/AgentIssuesPreviewButton.tsx b/resources/ts/components/AgentIssuesPreviewButton.tsx index bb8a3006..f8e9018c 100644 --- a/resources/ts/components/AgentIssuesPreviewButton.tsx +++ b/resources/ts/components/AgentIssuesPreviewButton.tsx @@ -1,6 +1,6 @@ import React, { useCallback, useState, type MouseEvent } from "react"; -import { type AgentIssue } from "../schemas/AgentIssue"; +import { type AgentIssue } from "@intentee/paddler-client/schemas/AgentIssue"; import { AgentIssues } from "./AgentIssues"; import { agentIssuesPreviewButton } from "./AgentIssuesPreviewButton.module.css"; import { ModalWindow } from "./ModalWindow"; diff --git a/resources/ts/components/AgentList.tsx b/resources/ts/components/AgentList.tsx index da18bca9..7b082544 100644 --- a/resources/ts/components/AgentList.tsx +++ b/resources/ts/components/AgentList.tsx @@ -1,7 +1,7 @@ import clsx from "clsx"; import React from "react"; -import { type Agent } from "../schemas/Agent"; +import { type Agent } from "@intentee/paddler-client/schemas/Agent"; import { AgentIssuesPreviewButton } from "./AgentIssuesPreviewButton"; import { AgentListAgentStatus } from "./AgentListAgentStatus"; import { ModelChatTemplateOverridePreviewButton } from "./ModelChatTemplateOverridePreviewButton"; diff --git a/resources/ts/components/AgentListAgentStatus.tsx b/resources/ts/components/AgentListAgentStatus.tsx index 19e30fe1..ceb51f14 100644 --- a/resources/ts/components/AgentListAgentStatus.tsx +++ b/resources/ts/components/AgentListAgentStatus.tsx @@ -1,6 +1,6 @@ import React, { CSSProperties } from "react"; -import { type Agent } from "../schemas/Agent"; +import { type Agent } from "@intentee/paddler-client/schemas/Agent"; import { agentListAgentStatus__progress } from "./AgentListAgentStatus.module.css"; diff --git a/resources/ts/components/AgentListStream.tsx b/resources/ts/components/AgentListStream.tsx index 1937a5f7..b2eceecc 100644 --- a/resources/ts/components/AgentListStream.tsx +++ 
b/resources/ts/components/AgentListStream.tsx @@ -1,9 +1,9 @@ import React, { useContext } from "react"; +import { AgentsResponseSchema } from "@intentee/paddler-client/schemas/AgentsResponse"; import { PaddlerConfigurationContext } from "../contexts/PaddlerConfigurationContext"; import { useEventSourceUpdates } from "../hooks/useEventSourceUpdates"; import { matchEventSourceUpdateState } from "../matchEventSourceUpdateState"; -import { AgentsResponseSchema } from "../schemas/AgentsResponse"; import { AgentList } from "./AgentList"; import { agentListStream__placeholder } from "./AgentListStream.module.css"; diff --git a/resources/ts/components/BufferedRequestsStream.tsx b/resources/ts/components/BufferedRequestsStream.tsx index 9dfa9dbd..85b43260 100644 --- a/resources/ts/components/BufferedRequestsStream.tsx +++ b/resources/ts/components/BufferedRequestsStream.tsx @@ -1,9 +1,9 @@ import React, { useContext } from "react"; +import { BufferedRequestsResponseSchema } from "@intentee/paddler-client/schemas/BufferedRequestsResponse"; import { PaddlerConfigurationContext } from "../contexts/PaddlerConfigurationContext"; import { useEventSourceUpdates } from "../hooks/useEventSourceUpdates"; import { matchEventSourceUpdateState } from "../matchEventSourceUpdateState"; -import { BufferedRequestsResponseSchema } from "../schemas/BufferedRequestsResponse"; import { BufferedRequests } from "./BufferedRequests"; import { dashboardSectionStreamLoader } from "./dashboardSectionStreamLoader.module.css"; diff --git a/resources/ts/components/ChangeModelForm.tsx b/resources/ts/components/ChangeModelForm.tsx index 1832031d..cfbe1bc7 100644 --- a/resources/ts/components/ChangeModelForm.tsx +++ b/resources/ts/components/ChangeModelForm.tsx @@ -7,11 +7,11 @@ import React, { } from "react"; import { useLocation } from "wouter"; +import { type BalancerDesiredState } from "@intentee/paddler-client/schemas/BalancerDesiredState"; import { ChatTemplateContext } from "../contexts/ChatTemplateContext"; import { InferenceParametersContext } from "../contexts/InferenceParametersContext"; import { PaddlerConfigurationContext } from "../contexts/PaddlerConfigurationContext"; import { useAgentDesiredModelUrl } from "../hooks/useAgentDesiredModelUrl"; -import { type BalancerDesiredState } from "../schemas/BalancerDesiredState"; import { ChatTemplateBehavior } from "./ChatTemplateBehavior"; import { InferenceParameterCacheDtype } from "./InferenceParameterCacheDtype"; import { InferenceParameterCheckbox } from "./InferenceParameterCheckbox"; @@ -234,7 +234,7 @@ export function ChangeModelForm({ + (defaultMessage); - const inferenceSocketClient = useMemo( + const socketClient = useMemo( function () { - return InferenceSocketClient({ webSocket }); + return inferenceSocketClient({ webSocket }); }, [webSocket], ); @@ -78,7 +78,7 @@ export const ConversationMessagePromptGeneratedTokens = memo( return; } - const subscription = inferenceSocketClient + const subscription = socketClient .continueConversation({ enableThinking: submittedIsThinkingEnabled, messages: [ @@ -97,17 +97,17 @@ export const ConversationMessagePromptGeneratedTokens = memo( .pipe( scan(function ( message: Message, - { done, error, token }: InferenceServiceGenerateTokensResponse, + chunk: InferenceServiceGenerateTokensResponse, ) { - if (error) { + if (chunk.error) { return Object.freeze({ ...message, - errors: [...message.errors, error], + errors: [...message.errors, chunk.error], isEmpty: false, }); } - if (done) { + if (chunk.done) { return Object.freeze({ 
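// Final chunk: freeze the completed message as accumulated so far.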
errors: message.errors, isEmpty: false, @@ -117,33 +117,25 @@ } - if ("<think>" === token) { - return Object.freeze({ - errors: message.errors, - isEmpty: false, - isThinking: true, - response: message.response, - thoughts: message.thoughts, - }); + if (null === chunk.token) { + return message; } - if ("</think>" === token) { + if ("reasoning" === chunk.tokenKind) { return Object.freeze({ errors: message.errors, isEmpty: false, - isThinking: false, + isThinking: true, response: message.response, - thoughts: message.thoughts, + thoughts: `${message.thoughts}${chunk.token}`, }); } - if (message.isThinking) { + if ("tool_call" === chunk.tokenKind) { return Object.freeze({ - errors: message.errors, + ...message, isEmpty: false, - isThinking: true, - response: message.response, - thoughts: `${message.thoughts}${token}`, + isThinking: false, }); } @@ -151,7 +143,7 @@ errors: message.errors, isEmpty: false, isThinking: false, - response: `${message.response}${token}`, + response: `${message.response}${chunk.token}`, thoughts: message.thoughts, }); }, defaultMessage), @@ -163,7 +155,7 @@ }; }, [ - inferenceSocketClient, + socketClient, setMessage, submittedImageDataUri, submittedIsThinkingEnabled, diff --git a/resources/ts/components/InferenceParameterCacheDtype.tsx b/resources/ts/components/InferenceParameterCacheDtype.tsx index 803f0899..6c68645f 100644 --- a/resources/ts/components/InferenceParameterCacheDtype.tsx +++ b/resources/ts/components/InferenceParameterCacheDtype.tsx @@ -1,7 +1,7 @@ import React, { useCallback, useContext, type ChangeEvent } from "react"; +import { cacheDtypes } from "@intentee/paddler-client/schemas/InferenceParameters"; import { InferenceParametersContext } from "../contexts/InferenceParametersContext"; -import { cacheDtypes } from "../schemas/InferenceParameters"; import { inferenceParameterInput, inferenceParameterInput__label, diff --git a/resources/ts/components/InferenceParameterCheckbox.tsx b/resources/ts/components/InferenceParameterCheckbox.tsx index 133f3a16..3e251ef7 100644 --- a/resources/ts/components/InferenceParameterCheckbox.tsx +++ b/resources/ts/components/InferenceParameterCheckbox.tsx @@ -1,20 +1,19 @@ import React, { useCallback, useContext } from "react"; import { InferenceParametersContext } from "../contexts/InferenceParametersContext"; -import { type BooleanKeys } from "../schemas/InferenceParameters"; +import { type InferenceParametersBooleanKeys } from "../inferenceParametersFormKeys"; import { inferenceParameterInput, inferenceParameterInput__checkbox, inferenceParameterInput__label, } from "./inferenceParameterInput.module.css"; -// eslint-disable-next-line @typescript-eslint/no-unnecessary-type-parameters -export function InferenceParameterCheckbox<TKey extends BooleanKeys>({ +export function InferenceParameterCheckbox({ description, name, }: { description: string; - name: TKey; + name: InferenceParametersBooleanKeys; }) { const { parameters, setParameter } = useContext(InferenceParametersContext); diff --git a/resources/ts/components/InferenceParameterInput.tsx b/resources/ts/components/InferenceParameterInput.tsx index 41a05402..0f6fd524 100644 --- a/resources/ts/components/InferenceParameterInput.tsx +++ b/resources/ts/components/InferenceParameterInput.tsx @@ -1,23 +1,19 @@ import React, { useCallback, useContext, type FormEvent } from "react"; import { InferenceParametersContext } from
"../contexts/InferenceParametersContext"; -import { - type InferenceParameters, - type NumberKeys, -} from "../schemas/InferenceParameters"; +import { type InferenceParametersNumberKeys } from "../inferenceParametersFormKeys"; import { inferenceParameterInput, inferenceParameterInput__input, inferenceParameterInput__label, } from "./inferenceParameterInput.module.css"; -// eslint-disable-next-line @typescript-eslint/no-unnecessary-type-parameters -export function InferenceParameterInput({ +export function InferenceParameterInput({ description, name, }: { description: string; - name: TKey; + name: InferenceParametersNumberKeys; }) { const { parameters, setParameter } = useContext(InferenceParametersContext); @@ -25,10 +21,7 @@ export function InferenceParameterInput({ function (event: FormEvent) { event.preventDefault(); - setParameter( - name, - parseFloat(event.currentTarget.value) as InferenceParameters[TKey], - ); + setParameter(name, parseFloat(event.currentTarget.value)); }, [name, setParameter], ); diff --git a/resources/ts/components/InferenceParameterPoolingType.tsx b/resources/ts/components/InferenceParameterPoolingType.tsx index 93799d7b..0b83e0c2 100644 --- a/resources/ts/components/InferenceParameterPoolingType.tsx +++ b/resources/ts/components/InferenceParameterPoolingType.tsx @@ -1,7 +1,7 @@ import React, { useCallback, useContext, type ChangeEvent } from "react"; +import { poolingTypes } from "@intentee/paddler-client/schemas/InferenceParameters"; import { InferenceParametersContext } from "../contexts/InferenceParametersContext"; -import { poolingTypes } from "../schemas/InferenceParameters"; import { inferenceParameterInput, inferenceParameterInput__disabledHint, diff --git a/resources/ts/components/InferenceParametersContextProvider.tsx b/resources/ts/components/InferenceParametersContextProvider.tsx index 20b2b1eb..b67e2ed4 100644 --- a/resources/ts/components/InferenceParametersContextProvider.tsx +++ b/resources/ts/components/InferenceParametersContextProvider.tsx @@ -1,10 +1,10 @@ import React, { useMemo, useState, type ReactNode } from "react"; +import { type InferenceParameters } from "@intentee/paddler-client/schemas/InferenceParameters"; import { InferenceParametersContext, type InferenceParametersContextValue, } from "../contexts/InferenceParametersContext"; -import { type InferenceParameters } from "../schemas/InferenceParameters"; export function InferenceParametersContextProvider({ children, diff --git a/resources/ts/components/ModelChatTemplateOverridePreviewButton.tsx b/resources/ts/components/ModelChatTemplateOverridePreviewButton.tsx index b7e88366..3f87f3e4 100644 --- a/resources/ts/components/ModelChatTemplateOverridePreviewButton.tsx +++ b/resources/ts/components/ModelChatTemplateOverridePreviewButton.tsx @@ -1,6 +1,6 @@ import React, { useCallback, useState, type MouseEvent } from "react"; -import { type Agent } from "../schemas/Agent"; +import { type Agent } from "@intentee/paddler-client/schemas/Agent"; import { ChatTemplateOverrideLoader } from "./ChatTemplateOverrideLoader"; import { modelChatTemplateOverridePreviewButton } from "./ModelChatTemplateOverridePreviewButton.module.css"; diff --git a/resources/ts/components/ModelMetadata.tsx b/resources/ts/components/ModelMetadata.tsx index 2e8a23f5..f6fa84a7 100644 --- a/resources/ts/components/ModelMetadata.tsx +++ b/resources/ts/components/ModelMetadata.tsx @@ -1,7 +1,7 @@ import React, { useContext } from "react"; +import { type Agent } from "@intentee/paddler-client/schemas/Agent"; import { 
ModelMetadataContext } from "../contexts/ModelMetadataContext"; -import { type Agent } from "../schemas/Agent"; import { ModalWindow } from "./ModalWindow"; import { ModelChatTemplatePreviewButton } from "./ModelChatTemplatePreviewButton"; import { ModelMetadataFocusedParameter } from "./ModelMetadataFocusedParameter"; diff --git a/resources/ts/components/ModelMetadataLoader.tsx b/resources/ts/components/ModelMetadataLoader.tsx index 81df81b1..0fa1b3fc 100644 --- a/resources/ts/components/ModelMetadataLoader.tsx +++ b/resources/ts/components/ModelMetadataLoader.tsx @@ -1,8 +1,8 @@ import React from "react"; +import { type Agent } from "@intentee/paddler-client/schemas/Agent"; import { useModelMetadata } from "../hooks/useModelMetadata"; import { matchFetchJsonState } from "../matchFetchJsonState"; -import { type Agent } from "../schemas/Agent"; import { ModalWindow } from "./ModalWindow"; import { ModelMetadata } from "./ModelMetadata"; import { ModelMetadataContextProvider } from "./ModelMetadataContextProvider"; diff --git a/resources/ts/components/ModelMetadataPreviewButton.tsx b/resources/ts/components/ModelMetadataPreviewButton.tsx index 24c54862..eeb0bfd7 100644 --- a/resources/ts/components/ModelMetadataPreviewButton.tsx +++ b/resources/ts/components/ModelMetadataPreviewButton.tsx @@ -1,6 +1,6 @@ import React, { useCallback, useState, type MouseEvent } from "react"; -import { type Agent } from "../schemas/Agent"; +import { type Agent } from "@intentee/paddler-client/schemas/Agent"; import { ModelMetadataLoader } from "./ModelMetadataLoader"; import { modelMetadataPreviewButton } from "./ModelMetadataPreviewButton.module.css"; diff --git a/resources/ts/components/PromptPage.tsx b/resources/ts/components/PromptPage.tsx index 18145056..022ce577 100644 --- a/resources/ts/components/PromptPage.tsx +++ b/resources/ts/components/PromptPage.tsx @@ -1,10 +1,10 @@ import React, { useContext } from "react"; +import { webSocketProtocol } from "@intentee/paddler-client/webSocketProtocol"; import { PaddlerConfigurationContext } from "../contexts/PaddlerConfigurationContext"; import { PromptContext } from "../contexts/PromptContext"; import { useWebSocket } from "../hooks/useWebSocket"; import { matchWebSocketState } from "../matchWebSocketState"; -import { webSocketProtocol } from "../webSocketProtocol"; import { ConversationMessage } from "./ConversationMessage"; import { ConversationMessagePromptGeneratedTokens } from "./ConversationMessagePromptGeneratedTokens"; import { ConversationPromptInput } from "./ConversationPromptInput"; diff --git a/resources/ts/contexts/ChatTemplateContext.ts b/resources/ts/contexts/ChatTemplateContext.ts index 7cb79dd8..f330d7fb 100644 --- a/resources/ts/contexts/ChatTemplateContext.ts +++ b/resources/ts/contexts/ChatTemplateContext.ts @@ -1,6 +1,6 @@ import { createContext } from "react"; -import { type ChatTemplate } from "../schemas/ChatTemplate"; +import { type ChatTemplate } from "@intentee/paddler-client/schemas/ChatTemplate"; export type ChatTemplateContextValue = { chatTemplateOverride: null | ChatTemplate; diff --git a/resources/ts/contexts/InferenceParametersContext.ts b/resources/ts/contexts/InferenceParametersContext.ts index 1363f90a..23e29459 100644 --- a/resources/ts/contexts/InferenceParametersContext.ts +++ b/resources/ts/contexts/InferenceParametersContext.ts @@ -1,6 +1,6 @@ import { createContext } from "react"; -import { type InferenceParameters } from "../schemas/InferenceParameters"; +import { type InferenceParameters } from 
"@intentee/paddler-client/schemas/InferenceParameters"; export type InferenceParametersContextValue = { parameters: InferenceParameters; diff --git a/resources/ts/extractHuggingFaceUrlParts.ts b/resources/ts/extractHuggingFaceUrlParts.ts deleted file mode 100644 index ba35d7f4..00000000 --- a/resources/ts/extractHuggingFaceUrlParts.ts +++ /dev/null @@ -1,46 +0,0 @@ -import { match } from "path-to-regexp"; - -import { type HuggingFaceModelReference } from "./schemas/HuggingFaceModelReference"; - -type UrlParams = { - owner: string; - repo: string; - revision: string; - filename: string; -}; - -const blobMatcher = match("/:owner/:repo/blob/:revision/:filename"); -const resolveMatcher = match( - "/:owner/:repo/resolve/:revision/:filename", -); - -function urlParamsToHuggingFaceUrlParts({ - owner, - repo, - revision, - filename, -}: UrlParams): HuggingFaceModelReference { - return { - filename: filename.startsWith("/") ? filename.slice(1) : filename, - repo_id: `${owner}/${repo}`, - revision, - }; -} - -export function extractHuggingFaceUrlParts({ - pathname, -}: URL): HuggingFaceModelReference { - const blobMatch = blobMatcher(pathname); - - if (blobMatch) { - return urlParamsToHuggingFaceUrlParts(blobMatch.params); - } - - const resolveMatch = resolveMatcher(pathname); - - if (resolveMatch) { - return urlParamsToHuggingFaceUrlParts(resolveMatch.params); - } - - throw new Error(`Invalid Hugging Face URL format: ${pathname}`); -} diff --git a/resources/ts/hooks/useAgentDesiredModelUrl.ts b/resources/ts/hooks/useAgentDesiredModelUrl.ts index d20d691f..dc823592 100644 --- a/resources/ts/hooks/useAgentDesiredModelUrl.ts +++ b/resources/ts/hooks/useAgentDesiredModelUrl.ts @@ -1,7 +1,7 @@ import { useMemo, useState } from "react"; -import { type AgentDesiredModel } from "../schemas/AgentDesiredModel"; -import { urlToAgentDesiredModel } from "../urlToAgentDesiredModel"; +import { type AgentDesiredModel } from "@intentee/paddler-client/schemas/AgentDesiredModel"; +import { urlToAgentDesiredModel } from "@intentee/paddler-client/urlToAgentDesiredModel"; type EmptyState = { agentDesiredModel: "None"; diff --git a/resources/ts/hooks/useBalancerDesiredState.ts b/resources/ts/hooks/useBalancerDesiredState.ts index 03f09e29..b1d8f544 100644 --- a/resources/ts/hooks/useBalancerDesiredState.ts +++ b/resources/ts/hooks/useBalancerDesiredState.ts @@ -1,6 +1,6 @@ import { useCallback } from "react"; -import { BalancerDesiredStateSchema } from "../schemas/BalancerDesiredState"; +import { BalancerDesiredStateSchema } from "@intentee/paddler-client/schemas/BalancerDesiredState"; import { useFetchJson } from "./useFetchJson"; export function useBalancerDesiredState({ diff --git a/resources/ts/hooks/useChatTemplateOverride.ts b/resources/ts/hooks/useChatTemplateOverride.ts index cb66bd63..f875ffd4 100644 --- a/resources/ts/hooks/useChatTemplateOverride.ts +++ b/resources/ts/hooks/useChatTemplateOverride.ts @@ -1,6 +1,6 @@ import { useCallback } from "react"; -import { ChatTemplateSchema } from "../schemas/ChatTemplate"; +import { ChatTemplateSchema } from "@intentee/paddler-client/schemas/ChatTemplate"; import { useFetchJson } from "./useFetchJson"; const responseSchema = ChatTemplateSchema.nullable(); diff --git a/resources/ts/hooks/useEventSourceUpdates.ts b/resources/ts/hooks/useEventSourceUpdates.ts index 30632f2f..45a79160 100644 --- a/resources/ts/hooks/useEventSourceUpdates.ts +++ b/resources/ts/hooks/useEventSourceUpdates.ts @@ -1,93 +1,9 @@ import { useEffect, useState } from "react"; -import { z } from 
"zod"; +import type { z } from "zod"; -export type ConnectedState = { - data: undefined; - isConnected: true; - isConnectionError: false; - isDeserializationError: false; - isInitial: false; - isOk: false; -}; - -export type ConnectionErrorState = { - data: undefined; - isConnected: false; - isConnectionError: true; - isDeserializationError: false; - isInitial: false; - isOk: false; -}; - -export type DataSnapshotState = { - data: z.infer; - isConnected: true; - isConnectionError: false; - isDeserializationError: false; - isInitial: false; - isOk: true; -}; - -export type DeserializationErrorState = { - data: undefined; - isConnected: true; - isConnectionError: false; - isDeserializationError: true; - isInitial: false; - isOk: false; -}; - -export type InitialStreamState = { - data: undefined; - isConnected: false; - isConnectionError: false; - isDeserializationError: false; - isInitial: true; - isOk: false; -}; - -export type StreamState = - | ConnectedState - | ConnectionErrorState - | DataSnapshotState - | DeserializationErrorState - | InitialStreamState; - -const connectedState: ConnectedState = Object.freeze({ - data: undefined, - isConnected: true, - isConnectionError: false, - isDeserializationError: false, - isInitial: false, - isOk: false, -}); - -const connectionErrorState: ConnectionErrorState = Object.freeze({ - data: undefined, - isConnected: false, - isConnectionError: true, - isDeserializationError: false, - isInitial: false, - isOk: false, -}); - -const deserializationErrorState: DeserializationErrorState = Object.freeze({ - data: undefined, - isConnected: true, - isConnectionError: false, - isDeserializationError: true, - isInitial: false, - isOk: false, -}); - -const defaultStreamState: InitialStreamState = Object.freeze({ - data: undefined, - isConnected: false, - isConnectionError: false, - isDeserializationError: false, - isInitial: true, - isOk: false, -}); +import { eventSourceInitialState } from "@intentee/paddler-client/EventSourceInitialState"; +import { type EventSourceState } from "@intentee/paddler-client/EventSourceState"; +import { streamEventSource } from "@intentee/paddler-client/streamEventSource"; export function useEventSourceUpdates({ endpoint, @@ -95,57 +11,23 @@ export function useEventSourceUpdates({ }: { endpoint: string; schema: TSchema; -}): StreamState { - const [streamState, setStreamState] = - useState>(defaultStreamState); +}): EventSourceState { + const [streamState, setEventSourceState] = useState< + EventSourceState + >(eventSourceInitialState); useEffect( function () { - const eventSource = new EventSource(endpoint); - - eventSource.addEventListener("error", function () { - setStreamState(connectionErrorState); - }); - - eventSource.addEventListener("message", function (event) { - if ("string" !== typeof event.data) { - console.error("Received non-string data from SSE:", event.data); - setStreamState(deserializationErrorState); - - return; - } - - const parsed = JSON.parse(event.data); - const result = schema.safeParse(parsed); - - if (!result.success) { - console.error( - "Deserialization error:", - JSON.stringify(parsed, null, " "), - result.error.issues, - ); - setStreamState(deserializationErrorState); - } else { - setStreamState({ - data: result.data, - isConnected: true, - isConnectionError: false, - isDeserializationError: false, - isInitial: false, - isOk: true, - }); - } - }); - - eventSource.addEventListener("open", function () { - setStreamState(connectedState); - }); + const subscription = streamEventSource({ + url: endpoint, + 
schema, + }).subscribe(setEventSourceState); return function () { - eventSource.close(); + subscription.unsubscribe(); }; }, - [endpoint, schema, setStreamState], + [endpoint, schema, setEventSourceState], ); return streamState; diff --git a/resources/ts/hooks/useFetchJson.ts b/resources/ts/hooks/useFetchJson.ts index f1158759..19589d93 100644 --- a/resources/ts/hooks/useFetchJson.ts +++ b/resources/ts/hooks/useFetchJson.ts @@ -1,59 +1,9 @@ import { useEffect, useState } from "react"; -import { z } from "zod"; +import type { z } from "zod"; -export type EmptyState = { - empty: true; - error: null; - loading: false; - ok: false; - response: null; -}; - -export type ErrorState = { - empty: false; - error: string; - loading: false; - response: null; - ok: false; -}; - -export type LoadingState = { - empty: false; - error: null; - loading: true; - response: null; - ok: false; -}; - -export type SuccessState = { - empty: false; - error: null; - loading: false; - response: TResult; - ok: true; -}; - -export type FetchJsonState = - | EmptyState - | ErrorState - | LoadingState - | SuccessState; - -const emptyState: EmptyState = Object.freeze({ - empty: true, - error: null, - loading: false, - response: null, - ok: false, -}); - -const loadingState: LoadingState = Object.freeze({ - empty: false, - error: null, - loading: true, - response: null, - ok: false, -}); +import { fetchJsonEmptyState } from "@intentee/paddler-client/FetchJsonEmptyState"; +import { fetchJsonLoadingState } from "@intentee/paddler-client/FetchJsonLoadingState"; +import { type FetchJsonState } from "@intentee/paddler-client/FetchJsonState"; export function useFetchJson({ produceFetchPromise, @@ -65,8 +15,9 @@ export function useFetchJson({ ): null | Promise; responseSchema: TResponseSchema; }): FetchJsonState> { - const [fetchState, setFetchState] = - useState>>(loadingState); + const [fetchState, setFetchState] = useState< + FetchJsonState> + >(fetchJsonLoadingState); useEffect( function () { @@ -74,14 +25,14 @@ export function useFetchJson({ const fetchPromise = produceFetchPromise(abortController.signal); if (!fetchPromise) { - setFetchState(emptyState); + setFetchState(fetchJsonEmptyState); return function () { abortController.abort("Fetch promise was not provided."); }; } - setFetchState(loadingState); + setFetchState(fetchJsonLoadingState); fetchPromise .then(function (response) { @@ -99,8 +50,8 @@ export function useFetchJson({ empty: false, error: null, loading: false, - response: result, ok: true, + response: result, }); }) .catch(function (error: unknown) { @@ -108,8 +59,8 @@ export function useFetchJson({ empty: false, error: String(error), loading: false, - response: null, ok: false, + response: null, }); }); diff --git a/resources/ts/hooks/usePrompt.ts b/resources/ts/hooks/usePrompt.ts index 81e69527..e5e68edd 100644 --- a/resources/ts/hooks/usePrompt.ts +++ b/resources/ts/hooks/usePrompt.ts @@ -1,6 +1,7 @@ import { useEffect, useState } from "react"; -import { InferenceServiceGenerateTokensResponseSchema } from "../schemas/InferenceServiceGenerateTokensResponse"; +import { InferenceServiceGenerateTokensResponseSchema } from "@intentee/paddler-client/schemas/InferenceServiceGenerateTokensResponse"; +import { streamHttpNdjson } from "@intentee/paddler-client/streamHttpNdjson"; export function usePrompt({ inferenceAddr, @@ -19,8 +20,9 @@ export function usePrompt({ setMessage(""); - fetch(`//${inferenceAddr}/api/v1/continue_from_conversation_history`, { - body: JSON.stringify({ + const subscription = 
streamHttpNdjson({ + url: `//${inferenceAddr}/api/v1/continue_from_conversation_history`, + body: { add_generation_prompt: true, conversation_history: [ { role: "assistant", content: systemPrompt }, @@ -28,63 +30,34 @@ export function usePrompt({ ], enable_thinking: false, max_tokens: 300, - }), - headers: { - "Content-Type": "application/json", }, - method: "POST", signal: abortController.signal, - }) - .then(function ({ body }) { - if (!body) { - throw new Error("No response body"); + schema: InferenceServiceGenerateTokensResponseSchema, + }).subscribe({ + next(validatedMessage) { + if (validatedMessage.done) { + return; } - return body.getReader(); - }) - .then(async function (reader) { - const decoder = new TextDecoder(); - - // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition - while (true) { - const { done, value } = await reader.read(); - - if (done || abortController.signal.aborted) { - return; - } - - const chunk = decoder.decode(value, { - stream: true, - }); - - const lines = chunk.split("\n").filter(function (line) { - return line.trim(); - }); - - for (const line of lines) { - try { - const message = JSON.parse(line); - const validatedMessage = - InferenceServiceGenerateTokensResponseSchema.parse(message); - - if (validatedMessage.done) { - return; - } + if (null === validatedMessage.token) { + return; + } - setMessage(function (prevMessage) { - return `${prevMessage}${validatedMessage.token}`; - }); - } catch (err) { - console.error("Error:", err); - } - } + if ("content" !== validatedMessage.tokenKind) { + return; } - }) - .catch(function (error: unknown) { + + setMessage(function (prevMessage) { + return `${prevMessage}${validatedMessage.token}`; + }); + }, + error(error: unknown) { console.error("Error during fetch:", error); - }); + }, + }); return function () { + subscription.unsubscribe(); abortController.abort(); }; }, diff --git a/resources/ts/hooks/useWebSocket.ts b/resources/ts/hooks/useWebSocket.ts index 74afcba9..fe778989 100644 --- a/resources/ts/hooks/useWebSocket.ts +++ b/resources/ts/hooks/useWebSocket.ts @@ -1,59 +1,9 @@ import { useEffect, useRef, useState } from "react"; -export type ConnectingState = { - isConnected: false; - isConnectionClosed: false; - isConnectionError: false; - webSocket: null; -}; - -export type ConnectionClosedState = { - isConnected: false; - isConnectionClosed: true; - isConnectionError: false; - webSocket: null; -}; - -export type ConnectionErrorState = { - isConnected: false; - isConnectionClosed: false; - isConnectionError: true; - webSocket: null; -}; - -export type ConnectionOpenedState = { - isConnected: true; - isConnectionClosed: false; - isConnectionError: false; - webSocket: WebSocket; -}; - -export type SocketState = - | ConnectingState - | ConnectionClosedState - | ConnectionErrorState - | ConnectionOpenedState; - -const connectionClosedState: ConnectionClosedState = Object.freeze({ - isConnected: false, - isConnectionClosed: true, - isConnectionError: false, - webSocket: null, -}); - -const connectionErrorState: ConnectionErrorState = Object.freeze({ - isConnected: false, - isConnectionClosed: false, - isConnectionError: true, - webSocket: null, -}); - -const defaultSocketState: ConnectingState = Object.freeze({ - isConnected: false, - isConnectionClosed: false, - isConnectionError: false, - webSocket: null, -}); +import { webSocketConnectingState } from "@intentee/paddler-client/WebSocketConnectingState"; +import { webSocketConnectionClosedState } from 
"@intentee/paddler-client/WebSocketConnectionClosedState"; +import { webSocketConnectionErrorState } from "@intentee/paddler-client/WebSocketConnectionErrorState"; +import { type WebSocketState } from "@intentee/paddler-client/WebSocketState"; const MAX_RECONNECT_DEBOUNCE_TIME_INCREASE = 3; const RECONNECT_DELAY = 600; @@ -62,9 +12,14 @@ function incrementVersion(version: number): number { return version + 1; } -export function useWebSocket({ endpoint }: { endpoint: string }): SocketState { - const [socketState, setSocketState] = - useState(defaultSocketState); +export function useWebSocket({ + endpoint, +}: { + endpoint: string; +}): WebSocketState { + const [socketState, setSocketState] = useState( + webSocketConnectingState, + ); const [version, setVersion] = useState(0); const [webSocket, setWebSocket] = useState(null); const reconnectAttempts = useRef(0); @@ -120,13 +75,13 @@ export function useWebSocket({ endpoint }: { endpoint: string }): SocketState { } webSocket.addEventListener("close", function () { - setSocketState(connectionClosedState); + setSocketState(webSocketConnectionClosedState); setVersion(incrementVersion); }); webSocket.addEventListener("error", function (event) { console.error("WebSocket error:", event); - setSocketState(connectionErrorState); + setSocketState(webSocketConnectionErrorState); setVersion(incrementVersion); }); diff --git a/resources/ts/inferenceParametersFormKeys.ts b/resources/ts/inferenceParametersFormKeys.ts new file mode 100644 index 00000000..9ebc086d --- /dev/null +++ b/resources/ts/inferenceParametersFormKeys.ts @@ -0,0 +1,17 @@ +import type { InferenceParameters } from "@intentee/paddler-client/schemas/InferenceParameters"; + +export type InferenceParametersBooleanKeys = { + [TKey in keyof InferenceParameters]: TKey extends string + ? InferenceParameters[TKey] extends boolean + ? TKey + : never + : never; +}[keyof InferenceParameters]; + +export type InferenceParametersNumberKeys = { + [TKey in keyof InferenceParameters]: TKey extends string + ? InferenceParameters[TKey] extends number + ? 
TKey + : never + : never; +}[keyof InferenceParameters]; diff --git a/resources/ts/matchEventSourceUpdateState.ts b/resources/ts/matchEventSourceUpdateState.ts index 4883e50a..065549be 100644 --- a/resources/ts/matchEventSourceUpdateState.ts +++ b/resources/ts/matchEventSourceUpdateState.ts @@ -1,24 +1,23 @@ import { type ReactNode } from "react"; import { z } from "zod"; -import { - type ConnectedState, - type ConnectionErrorState, - type DataSnapshotState, - type DeserializationErrorState, - type InitialStreamState, - type StreamState, -} from "./hooks/useEventSourceUpdates"; + +import { type EventSourceConnectedState } from "@intentee/paddler-client/EventSourceConnectedState"; +import { type EventSourceConnectionErrorState } from "@intentee/paddler-client/EventSourceConnectionErrorState"; +import { type EventSourceDataSnapshotState } from "@intentee/paddler-client/EventSourceDataSnapshotState"; +import { type EventSourceDeserializationErrorState } from "@intentee/paddler-client/EventSourceDeserializationErrorState"; +import { type EventSourceInitialState } from "@intentee/paddler-client/EventSourceInitialState"; +import { type EventSourceState } from "@intentee/paddler-client/EventSourceState"; interface Handlers<TData> { - connected(state: ConnectedState): ReactNode; - connectionError(state: ConnectionErrorState): ReactNode; - dataSnapshot(state: DataSnapshotState<TData>): ReactNode; - deserializationError(state: DeserializationErrorState): ReactNode; - initial(state: InitialStreamState): ReactNode; + connected(state: EventSourceConnectedState): ReactNode; + connectionError(state: EventSourceConnectionErrorState): ReactNode; + dataSnapshot(state: EventSourceDataSnapshotState<TData>): ReactNode; + deserializationError(state: EventSourceDeserializationErrorState): ReactNode; + initial(state: EventSourceInitialState): ReactNode; } export function matchEventSourceUpdateState<TSchema extends z.ZodType>( - streamState: StreamState<TSchema>, + streamState: EventSourceState<TSchema>, handlers: Handlers<z.infer<TSchema>>, ): ReactNode { if (streamState.isInitial) { diff --git a/resources/ts/matchFetchJsonState.ts b/resources/ts/matchFetchJsonState.ts index c0daa65b..f586f0a1 100644 --- a/resources/ts/matchFetchJsonState.ts +++ b/resources/ts/matchFetchJsonState.ts @@ -1,18 +1,16 @@ import { type ReactNode } from "react"; -import { - type EmptyState, - type ErrorState, - type FetchJsonState, - type LoadingState, - type SuccessState, -} from "./hooks/useFetchJson"; +import { type FetchJsonEmptyState } from "@intentee/paddler-client/FetchJsonEmptyState"; +import { type FetchJsonErrorState } from "@intentee/paddler-client/FetchJsonErrorState"; +import { type FetchJsonLoadingState } from "@intentee/paddler-client/FetchJsonLoadingState"; +import { type FetchJsonState } from "@intentee/paddler-client/FetchJsonState"; +import { type FetchJsonSuccessState } from "@intentee/paddler-client/FetchJsonSuccessState"; interface Handlers<TResult> { - empty(state: EmptyState): ReactNode; - error(state: ErrorState): ReactNode; - loading(state: LoadingState): ReactNode; - ok(state: SuccessState<TResult>): ReactNode; + empty(state: FetchJsonEmptyState): ReactNode; + error(state: FetchJsonErrorState): ReactNode; + loading(state: FetchJsonLoadingState): ReactNode; + ok(state: FetchJsonSuccessState<TResult>): ReactNode; } export function matchFetchJsonState<TResult>( @@ -27,13 +25,9 @@ export function matchFetchJsonState<TResult>( return handlers.loading(state); } - if (state.error) { + if (state.error !== null) { return handlers.error(state); } - if (state.ok) { - return handlers.ok(state); - } - - throw new Error(`Invalid state:
${JSON.stringify(state)}`); + return handlers.ok(state); } diff --git a/resources/ts/matchWebSocketState.ts b/resources/ts/matchWebSocketState.ts index 3726e647..2fa81b23 100644 --- a/resources/ts/matchWebSocketState.ts +++ b/resources/ts/matchWebSocketState.ts @@ -1,21 +1,20 @@ import { type ReactNode } from "react"; -import { - type ConnectingState, - type ConnectionClosedState, - type ConnectionErrorState, - type ConnectionOpenedState, - type SocketState, -} from "./hooks/useWebSocket"; + +import { type WebSocketConnectingState } from "@intentee/paddler-client/WebSocketConnectingState"; +import { type WebSocketConnectionClosedState } from "@intentee/paddler-client/WebSocketConnectionClosedState"; +import { type WebSocketConnectionErrorState } from "@intentee/paddler-client/WebSocketConnectionErrorState"; +import { type WebSocketConnectionOpenedState } from "@intentee/paddler-client/WebSocketConnectionOpenedState"; +import { type WebSocketState } from "@intentee/paddler-client/WebSocketState"; interface Handlers { - connected(socketState: ConnectionOpenedState): ReactNode; - connecting(socketState: ConnectingState): ReactNode; - connectionClosed(socketState: ConnectionClosedState): ReactNode; - connectionError(socketState: ConnectionErrorState): ReactNode; + connected(state: WebSocketConnectionOpenedState): ReactNode; + connecting(state: WebSocketConnectingState): ReactNode; + connectionClosed(state: WebSocketConnectionClosedState): ReactNode; + connectionError(state: WebSocketConnectionErrorState): ReactNode; } export function matchWebSocketState( - socketState: SocketState, + socketState: WebSocketState, handlers: Handlers, ): ReactNode { if (socketState.isConnected) { diff --git a/resources/ts/schemas/InferenceServiceGenerateTokensResponse.ts b/resources/ts/schemas/InferenceServiceGenerateTokensResponse.ts deleted file mode 100644 index 2f8bf930..00000000 --- a/resources/ts/schemas/InferenceServiceGenerateTokensResponse.ts +++ /dev/null @@ -1,144 +0,0 @@ -import { z } from "zod"; - -export const InferenceServiceGenerateTokensResponseSchema = z - .union([ - z.object({ - Error: z.object({ - error: z.object({ - code: z.number(), - description: z.string(), - }), - request_id: z.string(), - }), - }), - z.object({ - Response: z.object({ - request_id: z.string(), - response: z.object({ - GeneratedToken: z.union([ - z.object({ - ChatTemplateError: z.string(), - }), - z.literal("Done"), - z.object({ - ImageDecodingFailed: z.string(), - }), - z.object({ - MultimodalNotSupported: z.string(), - }), - z.object({ - Token: z.string(), - }), - ]), - }), - }), - }), - ]) - .transform(function (data): - | { - done: true; - error: null; - ok: true; - request_id: string; - token: null; - } - | { - done: false; - error: null; - ok: true; - request_id: string; - token: string; - } - | { - done: true; - error: { - code: number; - description: string; - }; - ok: false; - request_id: string; - token: null; - } { - if ("Error" in data) { - return Object.freeze({ - done: true, - error: data.Error.error, - ok: false, - request_id: data.Error.request_id, - token: null, - }); - } - - if (data.Response.response.GeneratedToken === "Done") { - return Object.freeze({ - done: true, - error: null, - ok: true, - request_id: data.Response.request_id, - token: null, - }); - } - - if ("ChatTemplateError" in data.Response.response.GeneratedToken) { - return Object.freeze({ - done: true, - error: Object.freeze({ - code: 500, - description: data.Response.response.GeneratedToken.ChatTemplateError, - }), - ok: false, - 
request_id: data.Response.request_id, - token: null, - }); - } - - if ("ImageDecodingFailed" in data.Response.response.GeneratedToken) { - return Object.freeze({ - done: true, - error: Object.freeze({ - code: 400, - description: - data.Response.response.GeneratedToken.ImageDecodingFailed, - }), - ok: false, - request_id: data.Response.request_id, - token: null, - }); - } - - if ("MultimodalNotSupported" in data.Response.response.GeneratedToken) { - return Object.freeze({ - done: true, - error: Object.freeze({ - code: 400, - description: - data.Response.response.GeneratedToken.MultimodalNotSupported, - }), - ok: false, - request_id: data.Response.request_id, - token: null, - }); - } - - if ("Token" in data.Response.response.GeneratedToken) { - return Object.freeze({ - done: false, - error: null, - ok: true, - request_id: data.Response.request_id, - token: data.Response.response.GeneratedToken.Token, - }); - } - - return Object.freeze({ - done: true, - error: null, - ok: true, - request_id: data.Response.request_id, - token: null, - }); - }); - -export type InferenceServiceGenerateTokensResponse = z.infer< - typeof InferenceServiceGenerateTokensResponseSchema ->; diff --git a/resources/ts/urlToAgentDesiredModel_test.ts b/resources/ts/urlToAgentDesiredModel_test.ts deleted file mode 100644 index 473c7e72..00000000 --- a/resources/ts/urlToAgentDesiredModel_test.ts +++ /dev/null @@ -1,24 +0,0 @@ -import test from "ava"; -import { urlToAgentDesiredModel } from "./urlToAgentDesiredModel"; - -test("recognizes Hugging Face urls", function (test) { - const url = new URL( - "https://huggingface.co/Qwen/Qwen3-0.6B-GGUF/blob/main/Qwen3-0.6B-Q8_0.gguf", - ); - - test.deepEqual(urlToAgentDesiredModel(url), { - HuggingFace: { - filename: "Qwen3-0.6B-Q8_0.gguf", - repo_id: "Qwen/Qwen3-0.6B-GGUF", - revision: "main", - }, - }); -}); - -test("uses local urls", function (test) { - const url = new URL("agent:///home/user/models/Qwen3-0.6B-Q8_0.gguf"); - - test.deepEqual(urlToAgentDesiredModel(url), { - LocalToAgent: "/home/user/models/Qwen3-0.6B-Q8_0.gguf", - }); -}); diff --git a/resources/ts/webSocketProtocol.ts b/resources/ts/webSocketProtocol.ts deleted file mode 100644 index 8f00b582..00000000 --- a/resources/ts/webSocketProtocol.ts +++ /dev/null @@ -1,3 +0,0 @@ -export function webSocketProtocol(windowProtocol: string): string { - return windowProtocol === "https:" ? "wss:" : "ws:"; -} diff --git a/tsconfig.json b/tsconfig.json index e5d0b138..3c5e36d8 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -18,7 +18,7 @@ "esnext.disposable" ], "module": "esnext", - "moduleResolution": "node", + "moduleResolution": "bundler", "noEmit": true, "noFallthroughCasesInSwitch": true, "noImplicitAny": true,