From befd6d12405f75c01a2e9cad4f49419ebfd74385 Mon Sep 17 00:00:00 2001 From: uqio <276879906+uqio@users.noreply.github.com> Date: Fri, 1 May 2026 23:05:53 +1200 Subject: [PATCH 01/11] fix --- .github/workflows/ci.yml | 59 +++-- .gitignore | 4 + Cargo.toml | 138 +++++++++--- benches/foo.rs | 1 - examples/embed_text.rs | 73 ++++++ examples/foo.rs | 1 - src/embedding.rs | 333 +++++++++++++++++++++++++++ src/error.rs | 99 ++++++++ src/lib.rs | 51 ++++- src/options.rs | 473 +++++++++++++++++++++++++++++++++++++++ src/session.rs | 101 +++++++++ src/simd/mod.rs | 148 ++++++++++++ src/simd/neon.rs | 101 +++++++++ src/simd/scalar.rs | 57 +++++ src/simd/x86.rs | 105 +++++++++ src/text_enc.rs | 438 ++++++++++++++++++++++++++++++++++++ tests/foo.rs | 1 - tests/integration.rs | 166 ++++++++++++++ 18 files changed, 2272 insertions(+), 77 deletions(-) delete mode 100644 benches/foo.rs create mode 100644 examples/embed_text.rs delete mode 100644 examples/foo.rs create mode 100644 src/embedding.rs create mode 100644 src/error.rs create mode 100644 src/options.rs create mode 100644 src/session.rs create mode 100644 src/simd/mod.rs create mode 100644 src/simd/neon.rs create mode 100644 src/simd/scalar.rs create mode 100644 src/simd/x86.rs create mode 100644 src/text_enc.rs delete mode 100644 tests/foo.rs create mode 100644 tests/integration.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 545e1d8..4b7e55e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -55,7 +55,12 @@ jobs: - name: Install cargo-hack run: cargo install cargo-hack - name: Apply clippy lints - run: cargo hack clippy --each-feature --exclude-no-default-features + # Skip the opt-in EP features (`cuda`, `tensorrt`, `directml`, + # `rocm`, `coreml`) — those activate `ort/`, which requires + # the corresponding vendor SDK at build time. Standard CI runners + # don't have CUDA / TensorRT / ROCm installed, and `directml` / + # `coreml` are platform-restricted. 
+ run: cargo hack clippy --each-feature --exclude-no-default-features --exclude-features cuda,tensorrt,directml,rocm,coreml # Run tests on some extra platforms cross: @@ -94,9 +99,19 @@ jobs: - name: Install Rust run: rustup update stable && rustup default stable - name: cargo build --target ${{ matrix.target }} + # The cross matrix verifies the no-inference subset (the + # `Embedding`, `Options`, `Error` types) compiles on every + # listed target, including wasm32-* and tier-2/3 native + # platforms where ort prebuilds aren't available. The default + # `inference` feature pulls `ort` + `tokenizers`, neither of + # which targets wasm or most non-tier-1 native targets, so + # passing default features here would only re-fail the build + # in their C/FFI prereqs. The native inference path is + # exercised by the `build` / `test` jobs below on + # ubuntu-latest / macos-latest / windows-latest. run: | rustup target add ${{ matrix.target }} - cargo build --target ${{ matrix.target }} + cargo build --target ${{ matrix.target }} --no-default-features build: name: build @@ -125,7 +140,8 @@ jobs: - name: Install cargo-hack run: cargo install cargo-hack - name: Run build - run: cargo hack build --feature-powerset --exclude-no-default-features + # See the clippy job for why the EP features are excluded. + run: cargo hack build --feature-powerset --exclude-no-default-features --exclude-features cuda,tensorrt,directml,rocm,coreml test: name: test @@ -154,7 +170,8 @@ jobs: - name: Install cargo-hack run: cargo install cargo-hack - name: Run test - run: cargo hack test --feature-powerset --exclude-no-default-features --exclude-features loom + # See the clippy job for why the EP features are excluded. 
+ run: cargo hack test --feature-powerset --exclude-no-default-features --exclude-features cuda,tensorrt,directml,rocm,coreml sanitizer: name: sanitizer @@ -250,32 +267,6 @@ jobs: run: | bash ci/miri_sb.sh "${{ matrix.target }}" - loom: - name: loom - strategy: - matrix: - os: - - ubuntu-latest - - macos-latest - - windows-latest - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v6 - - name: Cache cargo build and registry - uses: actions/cache@v5 - with: - path: | - ~/.cargo/registry - ~/.cargo/git - target - key: ${{ runner.os }}-loom-${{ hashFiles('**/Cargo.lock') }} - restore-keys: | - ${{ runner.os }}-loom- - - name: Install Rust - run: rustup update nightly --no-self-update && rustup default nightly - - name: Loom tests - run: cargo test --tests --features loom - # valgrind: # name: valgrind # runs-on: ubuntu-latest @@ -315,7 +306,6 @@ jobs: - cross - test - sanitizer - - loom steps: - uses: actions/checkout@v6 - name: Install Rust @@ -335,7 +325,12 @@ jobs: - name: Run tarpaulin env: RUSTFLAGS: "--cfg tarpaulin" - run: cargo tarpaulin --all-features --run-types tests --run-types doctests --workspace --out xml + # `--all-features` would activate `cuda` / `tensorrt` / `rocm` + # / `directml` / `coreml`, which require the corresponding + # vendor SDKs to compile — see the clippy job. Cover only the + # features that actually compile on a stock ubuntu-latest + # runner. 
+ run: cargo tarpaulin --features inference,serde --run-types tests --run-types doctests --workspace --out xml - name: Upload to codecov.io uses: codecov/codecov-action@v6 with: diff --git a/.gitignore b/.gitignore index 01e0c11..4f54896 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,7 @@ /target Cargo.lock + +**.claude/ +docs/ + diff --git a/Cargo.toml b/Cargo.toml index ff7fe91..d7f2aeb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,48 +1,118 @@ [package] -name = "template-rs" -version = "0.0.0" -edition = "2021" -repository = "https://github.com/al8n/template-rs" -homepage = "https://github.com/al8n/template-rs" -documentation = "https://docs.rs/template-rs" -description = "A template for creating Rust open-source repo on GitHub" -license = "MIT OR Apache-2.0" -rust-version = "1.73" - -[[bench]] -path = "benches/foo.rs" -name = "foo" -harness = false - -[features] -default = ["std"] -alloc = [] -std = [] +name = "egemma" +version = "0.1.0" +edition = "2024" +rust-version = "1.95" +description = "Rust ONNX inference library for Google's EmbeddingGemma (text embeddings)" +license = "MIT OR Apache-2.0" +repository = "https://github.com/Findit-AI/egemma" +keywords = ["embedding-gemma", "gemma", "embedding", "onnx", "ml"] +categories = ["science"] +include = [ + "src/**/*.rs", + "examples/**/*.rs", + "Cargo.toml", + "README.md", + "CHANGELOG.md", + "LICENSE-*", +] [dependencies] +# `ort` and `tokenizers` are gated on the `inference` feature so the +# crate compiles for wasm32 (where neither builds today — `ort` needs +# native ONNX Runtime FFI and `tokenizers` pulls `onig_sys` / +# `esaxx-rs` which lack wasm targets). Without `inference` the crate +# exposes the `Embedding` type, the `Options` surface, and the error +# enum — useful in environments where inference happens elsewhere. 
+ort = { version = "2.0.0-rc.12", optional = true } +tokenizers = { version = "0.23", optional = true } +thiserror = "2" +# `serde` is opt-in: most consumers use `egemma` purely as an inference +# library and don't need to (de)serialize `Options`. Pulling serde +# unconditionally adds compile time and binary size for no benefit on +# the common path. Mirrors the `silero` pattern — gate the dependency, +# the `derive(Serialize, Deserialize)` calls, and every field-level +# `#[serde(...)]` attribute on `feature = "serde"`. +serde = { version = "1", features = ["derive"], optional = true } [dev-dependencies] -criterion = "0.8" -tempfile = "3" +serde_json = "1" +tempfile = "3" + +# ===================================================================== +# Target / feature contract. +# +# Default features include `inference`, which pulls `ort` + `tokenizers`. +# Both require native build scripts (`ort` needs ONNX Runtime FFI; +# `tokenizers` pulls `onig_sys` / `esaxx-rs`, which fail on wasm32 +# because their C deps don't have wasm targets). +# +# **Wasm builds must use `--no-default-features`.** Building wasm with +# default features fails inside `getrandom` / `onig_sys` long before +# this crate's code is ever touched, with a confusing C-toolchain +# error. The `--no-default-features` build compiles the `Embedding`, +# `Options`, and `Error` subset — useful when inference happens +# elsewhere (a server, a different runtime). This mirrors the +# `siglip2` pattern: native = inference by default; wasm = explicit +# opt-out, no-inference subset. +# +# CI / build-matrix consumers: gate the wasm job on +# `--no-default-features`, e.g. +# `cargo check --target wasm32-unknown-unknown --no-default-features`. +# ===================================================================== +[features] +default = ["inference"] +# Activates the ONNX-backed inference path (`TextEncoder`). Native +# targets only — see the target / feature contract block above for the +# wasm escape hatch. 
+inference = ["dep:ort", "dep:tokenizers"] +# Pulls the `serde` dependency and activates `Serialize` / `Deserialize` +# on `Options`, `BatchOptions`, and `ThreadOptions`. `Embedding` +# deliberately does NOT carry these derives — round-trip via the inner +# slice (see `Embedding` docs) so the dim and L2-norm invariants that +# `TryFrom>` exists to enforce can't be bypassed. +serde = ["dep:serde"] + +# ----- Opt-in execution providers -------------------------------------- +# +# Feature names mirror ort's own EP feature flags so a downstream +# consumer can reason about both crates with the same vocabulary +# (`egemma/cuda` ↔ `ort/cuda`, etc.). Each feature activates the +# corresponding ort sub-feature so the prebuilt ONNX Runtime that +# gets linked includes that EP, and `session::build_session` +# cfg-registers the EP at session-build time. None are enabled by +# default — see `siglip2/Cargo.toml` for the measurement methodology. +cuda = ["inference", "ort/cuda"] +tensorrt = ["inference", "ort/tensorrt"] +directml = ["inference", "ort/directml"] +rocm = ["inference", "ort/rocm"] +coreml = ["inference", "ort/coreml"] + +[[test]] +name = "integration" +path = "tests/integration.rs" +required-features = ["inference"] + +[[example]] +name = "embed_text" +path = "examples/embed_text.rs" +required-features = ["inference"] [profile.bench] -opt-level = 3 -debug = false -codegen-units = 1 -lto = 'thin' -incremental = false -debug-assertions = false -overflow-checks = false -rpath = false +opt-level = 3 +debug = false +codegen-units = 1 +lto = 'thin' +incremental = false +debug-assertions = false +overflow-checks = false +rpath = false [package.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs"] [lints.rust] -rust_2018_idioms = "warn" +rust_2018_idioms = "warn" single_use_lifetimes = "warn" -unexpected_cfgs = { level = "warn", check-cfg = [ - 'cfg(all_tests)', - 'cfg(tarpaulin)', -] } +unexpected_cfgs = { level = "warn", check-cfg = ['cfg(all_tests)', 
'cfg(tarpaulin)', 'cfg(docsrs)'] } diff --git a/benches/foo.rs b/benches/foo.rs deleted file mode 100644 index f328e4d..0000000 --- a/benches/foo.rs +++ /dev/null @@ -1 +0,0 @@ -fn main() {} diff --git a/examples/embed_text.rs b/examples/embed_text.rs new file mode 100644 index 0000000..cca3bf5 --- /dev/null +++ b/examples/embed_text.rs @@ -0,0 +1,73 @@ +//! Embed a few sentences with `embedding-gemma` and print pairwise cosine +//! similarity. Run from the crate root with the canonical fp32 export +//! from `onnx-community/embeddinggemma-300m-ONNX`: +//! +//! ```bash +//! cargo run --example embed_text -- \ +//! /path/to/model.onnx /path/to/tokenizer.json +//! ``` +//! +//! The upstream model card flags fp16 as an unsupported activation +//! dtype for this graph; pass `model_fp16.onnx` only if you've +//! validated quality for your specific workload. + +use std::{env, path::PathBuf, process::ExitCode}; + +use egemma::TextEncoder; + +fn main() -> ExitCode { + let mut args = env::args().skip(1); + let graph: PathBuf = match args.next() { + Some(p) => p.into(), + None => { + eprintln!("usage: embed_text "); + return ExitCode::from(2); + } + }; + let tokenizer: PathBuf = match args.next() { + Some(p) => p.into(), + None => { + eprintln!("usage: embed_text "); + return ExitCode::from(2); + } + }; + + let mut encoder = match TextEncoder::from_files(&graph, &tokenizer) { + Ok(e) => e, + Err(err) => { + eprintln!("failed to load encoder: {err}"); + return ExitCode::FAILURE; + } + }; + + let prompts = [ + "task: search result | query: how do I build a Rust ONNX inference library?", + "Rust crates that wrap ONNX Runtime for embedding generation.", + "Today's weather forecast for Singapore.", + ]; + + let embeddings = match encoder.embed_batch(&prompts) { + Ok(v) => v, + Err(err) => { + eprintln!("embed failed: {err}"); + return ExitCode::FAILURE; + } + }; + + for (i, e) in embeddings.iter().enumerate() { + println!("[{i}] {:?}", &e.as_slice()[..6]); + } + println!(); + for 
i in 0..embeddings.len() { + for j in (i + 1)..embeddings.len() { + match embeddings[i].try_cosine(&embeddings[j]) { + Ok(cos) => println!("cos({i}, {j}) = {cos:.4}"), + Err(err) => { + eprintln!("cos({i}, {j}) failed: {err}"); + return ExitCode::FAILURE; + } + } + } + } + ExitCode::SUCCESS +} diff --git a/examples/foo.rs b/examples/foo.rs deleted file mode 100644 index f328e4d..0000000 --- a/examples/foo.rs +++ /dev/null @@ -1 +0,0 @@ -fn main() {} diff --git a/src/embedding.rs b/src/embedding.rs new file mode 100644 index 0000000..becc837 --- /dev/null +++ b/src/embedding.rs @@ -0,0 +1,333 @@ +//! `Embedding` — L2-normalized 768-dim sentence embedding. + +use std::sync::Arc; + +use crate::error::{Error, Result}; + +/// L2-normalized embedding. Length is `EMBED_DIM` (768) in 0.1.0. +/// +/// `Embedding` deliberately does **not** implement `Serialize` or `Deserialize`. +/// An auto-derived `Deserialize` would bypass the dim and L2-norm invariants +/// that `TryFrom>` exists to enforce. Round-trip via the inner +/// representation: +/// +/// ```ignore +/// // Serialize via the inner slice (`&[f32]: Serialize`): +/// let json = serde_json::to_string(embedding.as_slice())?; +/// +/// // Deserialize via the validated path: +/// let v: Vec = serde_json::from_str(&json)?; +/// let embedding = Embedding::try_from(v)?; // validates dim + L2-norm +/// ``` +#[derive(Clone, Debug)] +pub struct Embedding(Arc<[f32]>); + +impl Embedding { + /// 0.1.0 supports only the 768-dim base export. + pub const EMBED_DIM: usize = 768; + + /// L2-norm tolerance for the unit-norm invariant. + pub const NORM_EPSILON: f32 = 5e-4; + + pub fn dim(&self) -> usize { + self.0.len() + } + + pub fn as_slice(&self) -> &[f32] { + &self.0 + } + + /// Returns the inner `Arc<[f32]>`. O(1) — atomic refcount only, no + /// data copy. Callers who need a fresh `Vec` can write + /// `embedding.into_inner().to_vec()` so the allocation is explicit. 
+ pub fn into_inner(self) -> Arc<[f32]> { + self.0 + } + + /// Cosine similarity. Both operands must be unit-norm; valid because every + /// `Embedding` in this crate is L2-normalized at construction. + /// + /// Returns [`crate::Error::EmbeddingDim`] when `self.dim() != other.dim()` + /// or when either operand's dim doesn't equal [`Self::EMBED_DIM`]. In + /// 0.1.0 every public constructor (`try_from`, `from_model_output` via + /// `TextEncoder`) produces a 768-d `Embedding`, so the error path is + /// only reachable in-crate; the check is forward-compatibility for + /// variable-dim embeddings and a guard against future internal misuse. + /// + /// Internally dispatches through [`crate::simd::dot_768`] — picks NEON + /// on aarch64, AVX2+FMA on x86_64 (when the runtime CPU advertises + /// both), or a four-accumulator scalar fallback on every other target. + pub fn try_cosine(&self, other: &Embedding) -> Result { + if self.dim() != other.dim() { + return Err(Error::EmbeddingDim { + expected: self.dim(), + got: other.dim(), + }); + } + let a: &[f32; Self::EMBED_DIM] = + self + .as_slice() + .try_into() + .map_err(|_| Error::EmbeddingDim { + expected: Self::EMBED_DIM, + got: self.dim(), + })?; + let b: &[f32; Self::EMBED_DIM] = + other + .as_slice() + .try_into() + .map_err(|_| Error::EmbeddingDim { + expected: Self::EMBED_DIM, + got: other.dim(), + })?; + Ok(crate::simd::dot_768(a, b)) + } + + /// Crate-internal: build an `Embedding` from raw model output. The + /// `embedding-gemma` ONNX export emits `sentence_embedding` that may + /// or may not be L2-normalized depending on the optimum-export pipeline + /// — we re-normalize unconditionally so downstream cosine code is + /// always operating on unit-norm vectors. Rejection only happens for + /// dim mismatch, all-zero output (degenerate model state), or + /// non-finite components. 
+ /// + /// The `TryFrom>` path keeps the strict near-unit-norm check + /// — that's for *caller-supplied* embeddings (e.g., deserialized from + /// a vector store) which should already be unit-norm; silent renorm + /// there would mask data corruption. + #[cfg(feature = "inference")] + pub(crate) fn from_model_output(data: &[f32]) -> Result { + let arr: &[f32; Self::EMBED_DIM] = data.try_into().map_err(|_| Error::EmbeddingDim { + expected: Self::EMBED_DIM, + got: data.len(), + })?; + let norm_sq = crate::simd::dot_768(arr, arr); + let norm = norm_sq.sqrt(); + if !norm.is_finite() || norm == 0.0 { + return Err(Error::NotNormalized { + norm, + epsilon: Self::NORM_EPSILON, + }); + } + let factor = 1.0 / norm; + let arc: Arc<[f32]> = data.iter().map(|&x| x * factor).collect(); + Ok(Self(arc)) + } +} + +impl TryFrom> for Embedding { + type Error = Error; + + /// Validates dim (`Error::EmbeddingDim`) and L2-norm + /// (`Error::NotNormalized`, tolerance `NORM_EPSILON`). This path is for + /// **caller-supplied** embeddings — typically deserialized from a + /// vector store — that should already be unit-norm; we reject (rather + /// than silently renormalize) so corruption can't slip through. + /// + /// Vectors whose `||v||₂` is within `NORM_EPSILON` of 1.0 are + /// snapped to exactly 1.0 (in-place renorm preserves the cosine + /// invariant under tiny f32 drift). 
+ fn try_from(mut v: Vec) -> Result { + let norm_sq = { + let arr: &[f32; Self::EMBED_DIM] = + v.as_slice().try_into().map_err(|_| Error::EmbeddingDim { + expected: Self::EMBED_DIM, + got: v.len(), + })?; + crate::simd::dot_768(arr, arr) + }; + let norm = norm_sq.sqrt(); + if !norm.is_finite() || (norm - 1.0).abs() > Self::NORM_EPSILON { + return Err(Error::NotNormalized { + norm, + epsilon: Self::NORM_EPSILON, + }); + } + let factor = 1.0 / norm; + for x in &mut v { + *x *= factor; + } + Ok(Self(v.into())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn unit_vec(dim: usize) -> Vec { + let mut v = vec![0.0f32; dim]; + v[0] = 1.0; + v + } + + #[test] + fn try_from_accepts_unit_norm_768() { + let v = unit_vec(768); + let e = Embedding::try_from(v).expect("unit-norm 768-dim should succeed"); + assert_eq!(e.dim(), 768); + let cos = e.try_cosine(&e).expect("happy path"); + assert!((cos - 1.0).abs() < 1e-5); + } + + #[test] + fn try_from_rejects_wrong_dim() { + let v = vec![0.0; 100]; + let err = Embedding::try_from(v).unwrap_err(); + match err { + Error::EmbeddingDim { expected, got } => { + assert_eq!(expected, 768); + assert_eq!(got, 100); + } + _ => panic!("expected EmbeddingDim, got {err}"), + } + } + + #[test] + fn try_from_rejects_non_unit_norm() { + let v = vec![0.5f32; 768]; + let err = Embedding::try_from(v).unwrap_err(); + match err { + Error::NotNormalized { .. 
} => {} + _ => panic!("expected NotNormalized, got {err}"), + } + } + + #[cfg(feature = "inference")] + #[test] + fn from_model_output_normalizes_arbitrary_norm() { + let v = vec![1.0f32; 768]; + let e = Embedding::from_model_output(&v).expect("arbitrary-norm output must be normalized"); + let cos = e.try_cosine(&e).expect("happy path"); + assert!( + (cos - 1.0).abs() < 1e-5, + "post-norm cosine should be 1.0; got {cos}" + ); + assert!((e.as_slice()[0] - (1.0 / (768.0_f32).sqrt())).abs() < 1e-6); + } + + /// The SIMD boundary takes `&[f32; 768]`, so a wrong-length slice + /// can never reach the unsafe kernels — `from_model_output` rejects + /// it at the conversion site with `Error::EmbeddingDim`. This test + /// pins the rejection path so a future refactor that re-loosens the + /// signature back to `&[f32]` would surface as a unit-test failure. + #[cfg(feature = "inference")] + #[test] + fn from_model_output_rejects_wrong_dim() { + let v = vec![0.5f32; 100]; + let err = Embedding::from_model_output(&v).unwrap_err(); + match err { + Error::EmbeddingDim { expected, got } => { + assert_eq!(expected, 768); + assert_eq!(got, 100); + } + _ => panic!("expected EmbeddingDim, got {err}"), + } + } + + #[cfg(feature = "inference")] + #[test] + fn from_model_output_rejects_zero_norm() { + let v = vec![0.0f32; 768]; + let err = Embedding::from_model_output(&v).unwrap_err(); + match err { + Error::NotNormalized { norm, .. } => assert_eq!(norm, 0.0), + _ => panic!("expected NotNormalized for zero output, got {err}"), + } + } + + #[cfg(feature = "inference")] + #[test] + fn from_model_output_rejects_nan_component() { + let mut v = vec![0.5f32; 768]; + v[100] = f32::NAN; + let err = Embedding::from_model_output(&v).unwrap_err(); + match err { + Error::NotNormalized { norm, .. 
} => assert!(norm.is_nan()), + _ => panic!("expected NotNormalized for NaN, got {err}"), + } + } + + #[test] + fn try_from_renormalizes_within_tolerance() { + let mut v = unit_vec(768); + v[1] = Embedding::NORM_EPSILON / 2.0; + let e = Embedding::try_from(v).expect("near-unit norm should be accepted"); + let dot = e.try_cosine(&e).expect("happy path"); + assert!( + (dot - 1.0).abs() < 1e-5, + "renormalized cosine should be 1.0; got {dot}" + ); + } + + /// `try_cosine` must surface dim mismatches as `Error::EmbeddingDim` + /// rather than panicking. Pins the contract that callers who want a + /// panic-free surface can rely on it never panicking on dim differences. + #[test] + fn try_cosine_returns_dim_error_on_mismatch() { + let a = Embedding(vec![1.0f32, 0.0].into()); + let b = Embedding(vec![1.0f32, 0.0, 0.0].into()); + let err = a + .try_cosine(&b) + .expect_err("dim mismatch must surface as Err"); + match err { + Error::EmbeddingDim { expected, got } => { + assert_eq!(expected, 2, "lhs dim"); + assert_eq!(got, 3, "rhs dim"); + } + other => panic!("expected Error::EmbeddingDim, got {other}"), + } + } + + /// `try_cosine` must also reject same-dim-but-non-768 pairs (the + /// `try_into::<&[f32; EMBED_DIM]>` failure path inside the kernel + /// boundary). This pair has matching dims (both 4), so the + /// dim-equality check passes, but the typed-array conversion still + /// fails — and `try_cosine` translates that into `EmbeddingDim` + /// rather than panicking. 
+ #[test] + fn try_cosine_returns_dim_error_when_both_wrong_size() { + let a = Embedding(vec![1.0f32, 0.0, 0.0, 0.0].into()); + let b = Embedding(vec![0.0f32, 1.0, 0.0, 0.0].into()); + let err = a + .try_cosine(&b) + .expect_err("non-EMBED_DIM operands must error"); + match err { + Error::EmbeddingDim { expected, got } => { + assert_eq!(expected, Embedding::EMBED_DIM); + assert_eq!(got, 4); + } + other => panic!("expected Error::EmbeddingDim, got {other}"), + } + } + + /// Happy path: when both operands are valid 768-d unit vectors, + /// `try_cosine` returns `Ok(_)` close to 1.0 for the self-pair. + #[test] + fn try_cosine_self_unit_pair() { + let v = unit_vec(768); + let e = Embedding::try_from(v).expect("unit-norm 768-d should succeed"); + let cos = e.try_cosine(&e).expect("happy path must be Ok"); + assert!((cos - 1.0).abs() < 1e-5); + } + + /// `into_inner` exposes the storage `Arc<[f32]>` cheaply (no copy), + /// and the inner slice round-trips through the renormalization + /// performed by `try_from`. Replaces the old `into_vec_round_trips` + /// test which exercised an API that was removed in favor of the + /// allocation-free `into_inner`. + #[test] + fn into_inner_exposes_arc_unchanged() { + let v = unit_vec(768); + let e = Embedding::try_from(v).expect("unit-norm 768-d should succeed"); + let arc = e.into_inner(); + assert_eq!(arc.len(), 768); + assert!((arc[0] - 1.0).abs() < 1e-6); + } + + #[test] + fn embedding_is_send_sync() { + fn _req() {} + _req::(); + } +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..7f26a5d --- /dev/null +++ b/src/error.rs @@ -0,0 +1,99 @@ +//! Error type for the `egemma` crate. + +#[cfg(feature = "inference")] +use std::path::PathBuf; +use thiserror::Error; + +#[derive(Debug, Error)] +#[non_exhaustive] +pub enum Error { + /// ORT-backed graph load failure. Gated on the `inference` feature + /// because `ort::Error` doesn't exist when the feature is off. 
+ #[cfg(feature = "inference")] + #[error("failed to load ONNX graph at {path}: {source}")] + LoadGraph { path: PathBuf, source: ort::Error }, + + /// Required ONNX output tensor was not present in the session output map. + /// Indicates an unexpected re-export or a corrupted graph. + #[error("required ONNX output `{name}` was missing from session run")] + MissingOnnxOutput { name: &'static str }, + + #[error("tokenizer load failed: {0}")] + Tokenizer(String), + + #[error("unexpected output rank: expected 2, got {rank} with shape {shape:?}")] + OutputRank { rank: usize, shape: Vec }, + + #[error("session shape mismatch on `{input}`: expected {expected}, got {got:?}")] + SessionShapeMismatch { + input: &'static str, + expected: &'static str, + got: Vec, + }, + + #[error("embedding dimension mismatch: expected {expected}, got {got}")] + EmbeddingDim { expected: usize, got: usize }, + + #[error("embedding is not unit-norm (got ||v||₂ = {norm}, tolerance ε = {epsilon})")] + NotNormalized { norm: f32, epsilon: f32 }, + + #[error("text input is empty")] + EmptyText, + + #[error("batch size {got} exceeds maximum {max}")] + BatchTooLarge { got: usize, max: usize }, + + /// `BatchOptions::batch_size` was outside the legal range + /// `1..=max_batch_size` at encoder construction. + #[error("invalid batch_size {batch_size}: must be in 1..={max_batch_size}")] + InvalidBatchSize { + batch_size: usize, + max_batch_size: usize, + }, + + #[error("batch index {index}: {source}")] + Batch { index: usize, source: Box }, + + /// ORT runtime error pass-through. Gated on the `inference` feature + /// because `ort::Error` doesn't exist when the feature is off. 
+ #[cfg(feature = "inference")] + #[error(transparent)] + Ort(#[from] ort::Error), + + #[error(transparent)] + Io(#[from] std::io::Error), +} + +pub type Result = core::result::Result; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn empty_text_displays_message() { + assert_eq!(Error::EmptyText.to_string(), "text input is empty"); + } + + #[test] + fn batch_wraps_inner_error() { + let inner = Error::EmptyText; + let wrapped = Error::Batch { + index: 3, + source: Box::new(inner), + }; + assert_eq!(wrapped.to_string(), "batch index 3: text input is empty"); + } + + #[test] + fn embedding_dim_mismatch_shows_expected_and_got() { + let err = Error::EmbeddingDim { + expected: 768, + got: 512, + }; + assert_eq!( + err.to_string(), + "embedding dimension mismatch: expected 768, got 512" + ); + } +} diff --git a/src/lib.rs b/src/lib.rs index 0a58390..800b2dd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,46 @@ -//! A template for creating Rust open-source repo on GitHub -#![cfg_attr(not(feature = "std"), no_std)] +//! EmbeddingGemma inference library — produces 768-dim L2-normalized +//! sentence embeddings from Google's `embedding-gemma` ONNX export. +//! +//! Mirrors the `siglip2` text-tower API surface: a [`TextEncoder`] with +//! `from_files` / `from_files_with_options` / `from_ort_session` +//! constructors, plus `embed`, `embed_batch`, and `warmup`. +//! +//! # Target / feature contract +//! +//! The `inference` feature is **on by default** and is **native-only**. +//! It pulls `ort` (ONNX Runtime FFI) and `tokenizers` (which transitively +//! depends on C-only libraries like `onig_sys`); neither builds on +//! `wasm32-*` today. Building wasm with default features therefore fails +//! deep in `getrandom` / `onig_sys` before this crate's code is reached. +//! +//! **Wasm consumers must opt out:** +//! +//! ```bash +//! cargo check --target wasm32-unknown-unknown --no-default-features +//! ``` +//! +//! 
Without `inference`, the public surface is the [`Embedding`] type, +//! [`Options`] / [`BatchOptions`] / [`ThreadOptions`], and the +//! [`Error`] enum — useful when inference itself happens elsewhere +//! (a server, a different runtime) and only the value types and +//! similarity primitive need to be present. + #![cfg_attr(docsrs, feature(doc_cfg))] -#![cfg_attr(docsrs, allow(unused_attributes))] -#![deny(missing_docs)] +#![deny(rust_2018_idioms, single_use_lifetimes)] -#[cfg(all(not(feature = "std"), feature = "alloc"))] -extern crate alloc as std; +pub mod embedding; +pub mod error; +pub mod options; +#[cfg(feature = "inference")] +pub(crate) mod session; +pub(crate) mod simd; +#[cfg(feature = "inference")] +pub mod text_enc; -#[cfg(feature = "std")] -extern crate std; +pub use embedding::Embedding; +pub use error::{Error, Result}; +#[cfg(feature = "inference")] +pub use options::GraphOptimizationLevel; +pub use options::{BatchOptions, Options, ThreadOptions}; +#[cfg(feature = "inference")] +pub use text_enc::TextEncoder; diff --git a/src/options.rs b/src/options.rs new file mode 100644 index 0000000..b8d9a01 --- /dev/null +++ b/src/options.rs @@ -0,0 +1,473 @@ +//! Session, batch, and threading options for [`crate::TextEncoder`]. +//! +//! `GraphOptimizationLevel` and `Options::optimization_level` are +//! re-exported / present only with `feature = "inference"` — they +//! reach into `ort` types that don't exist on wasm builds. +//! +//! `serde::{Serialize, Deserialize}` derives on `Options`, +//! `BatchOptions`, and `ThreadOptions` are gated on `feature = "serde"` +//! so consumers who don't need config (de)serialization don't pay the +//! serde compile-time cost. 
+ +#[cfg(feature = "inference")] +pub use ort::session::builder::GraphOptimizationLevel; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +// `optimization_level`'s `serialize` / `deserialize` adapters depend on +// both `inference` (for the `GraphOptimizationLevel` type itself) and +// `serde` (for the trait machinery). `Options::optimization_level` +// references this module only under the same conjunction. +#[cfg(all(feature = "inference", feature = "serde"))] +mod optimization_level { + use super::GraphOptimizationLevel; + use serde::*; + + #[derive( + Debug, Default, Clone, Copy, Eq, PartialEq, Hash, Ord, PartialOrd, Serialize, Deserialize, + )] + #[serde(rename_all = "snake_case")] + enum OptimizationLevel { + Disable, + #[default] + Level1, + Level2, + Level3, + All, + } + + impl From for OptimizationLevel { + #[inline] + fn from(value: GraphOptimizationLevel) -> Self { + match value { + GraphOptimizationLevel::Disable => Self::Disable, + GraphOptimizationLevel::Level1 => Self::Level1, + GraphOptimizationLevel::Level2 => Self::Level2, + GraphOptimizationLevel::Level3 => Self::Level3, + GraphOptimizationLevel::All => Self::All, + } + } + } + + impl From for GraphOptimizationLevel { + #[inline] + fn from(value: OptimizationLevel) -> Self { + match value { + OptimizationLevel::Disable => Self::Disable, + OptimizationLevel::Level1 => Self::Level1, + OptimizationLevel::Level2 => Self::Level2, + OptimizationLevel::Level3 => Self::Level3, + OptimizationLevel::All => Self::All, + } + } + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub fn serialize(level: &GraphOptimizationLevel, serializer: S) -> Result + where + S: Serializer, + { + OptimizationLevel::from(*level).serialize(serializer) + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + OptimizationLevel::deserialize(deserializer).map(Into::into) + } + + // Must stay in lock-step with 
`Options::new()` so that deserializing a + // config that omits `optimization_level` yields the same baseline level + // a normal `Options::default()` would. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn default() -> GraphOptimizationLevel { + GraphOptimizationLevel::Level1 + } +} + +#[cfg_attr(not(tarpaulin), inline(always))] +const fn default_max_seq_len() -> usize { + 2048 +} + +#[cfg_attr(not(tarpaulin), inline(always))] +const fn default_batch_size() -> usize { + 8 +} + +#[cfg_attr(not(tarpaulin), inline(always))] +const fn default_max_batch_size() -> usize { + 1024 +} + +#[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct BatchOptions { + #[cfg_attr(feature = "serde", serde(default = "default_max_seq_len"))] + max_seq_len: usize, + #[cfg_attr(feature = "serde", serde(default = "default_batch_size"))] + batch_size: usize, + #[cfg_attr(feature = "serde", serde(default = "default_max_batch_size"))] + max_batch_size: usize, +} + +impl BatchOptions { + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn new() -> Self { + Self { + max_seq_len: default_max_seq_len(), + batch_size: default_batch_size(), + max_batch_size: default_max_batch_size(), + } + } + + /// Maximum number of tokens per input. Long inputs are truncated to + /// this length by the tokenizer. Defaults to 2048 — `embedding-gemma`'s + /// trained context window. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn max_seq_len(&self) -> usize { + self.max_seq_len + } + + /// Inputs per ORT inference call (chunk size). + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn batch_size(&self) -> usize { + self.batch_size + } + + /// Hard upper bound on `texts.len()` accepted by `embed_batch`. + /// Inputs above this are rejected with `Error::BatchTooLarge`. 
+ #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn max_batch_size(&self) -> usize { + self.max_batch_size + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn with_max_seq_len(mut self, n: usize) -> Self { + self.max_seq_len = n; + self + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn with_batch_size(mut self, n: usize) -> Self { + self.batch_size = n; + self + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn with_max_batch_size(mut self, n: usize) -> Self { + self.max_batch_size = n; + self + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_max_seq_len(&mut self, n: usize) -> &mut Self { + self.max_seq_len = n; + self + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_batch_size(&mut self, n: usize) -> &mut Self { + self.batch_size = n; + self + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_max_batch_size(&mut self, n: usize) -> &mut Self { + self.max_batch_size = n; + self + } + + /// Reject `batch_size == 0` (the silent `.max(1)` coercion footgun) and + /// `batch_size > max_batch_size` (a config error that wastes scratch + /// memory and never produces a chunk that large in practice). 
+ #[cfg_attr(not(any(feature = "inference", test)), allow(dead_code))] + pub(crate) fn validate(&self) -> Result<(), crate::Error> { + if self.batch_size == 0 || self.batch_size > self.max_batch_size { + return Err(crate::Error::InvalidBatchSize { + batch_size: self.batch_size, + max_batch_size: self.max_batch_size, + }); + } + Ok(()) + } +} + +impl Default for BatchOptions { + #[cfg_attr(not(tarpaulin), inline(always))] + fn default() -> Self { + Self::new() + } +} + +#[cfg_attr(not(tarpaulin), inline(always))] +const fn default_intra_threads() -> usize { + 1 +} + +#[cfg_attr(not(tarpaulin), inline(always))] +const fn default_inter_threads() -> usize { + 1 +} + +#[cfg_attr(not(tarpaulin), inline(always))] +const fn default_parallel_execution() -> bool { + false +} + +#[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct ThreadOptions { + #[cfg_attr(feature = "serde", serde(default = "default_intra_threads"))] + intra_threads: usize, + #[cfg_attr(feature = "serde", serde(default = "default_inter_threads"))] + inter_threads: usize, + #[cfg_attr(feature = "serde", serde(default = "default_parallel_execution"))] + parallel_execution: bool, +} + +impl ThreadOptions { + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn new() -> Self { + Self { + intra_threads: default_intra_threads(), + inter_threads: default_inter_threads(), + parallel_execution: default_parallel_execution(), + } + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn intra_threads(&self) -> usize { + self.intra_threads + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn inter_threads(&self) -> usize { + self.inter_threads + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn parallel_execution(&self) -> bool { + self.parallel_execution + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn with_intra_threads(mut self, n: usize) -> Self { + self.intra_threads = n; + self + } + + 
#[cfg_attr(not(tarpaulin), inline(always))] + pub const fn with_inter_threads(mut self, n: usize) -> Self { + self.inter_threads = n; + self + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn with_parallel_execution(mut self, p: bool) -> Self { + self.parallel_execution = p; + self + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_intra_threads(&mut self, n: usize) -> &mut Self { + self.intra_threads = n; + self + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_inter_threads(&mut self, n: usize) -> &mut Self { + self.inter_threads = n; + self + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_parallel_execution(&mut self, p: bool) -> &mut Self { + self.parallel_execution = p; + self + } +} + +impl Default for ThreadOptions { + #[cfg_attr(not(tarpaulin), inline(always))] + fn default() -> Self { + Self::new() + } +} + +#[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct Options { + #[cfg(feature = "inference")] + #[cfg_attr( + feature = "serde", + serde(with = "optimization_level", default = "optimization_level::default") + )] + optimization_level: GraphOptimizationLevel, + #[cfg_attr(feature = "serde", serde(default))] + batch: BatchOptions, + #[cfg_attr(feature = "serde", serde(default))] + threads: ThreadOptions, +} + +impl Options { + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn new() -> Self { + Self { + #[cfg(feature = "inference")] + optimization_level: GraphOptimizationLevel::Level1, + batch: BatchOptions::new(), + threads: ThreadOptions::new(), + } + } + + #[cfg(feature = "inference")] + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn optimization_level(&self) -> GraphOptimizationLevel { + self.optimization_level + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn batch(&self) -> BatchOptions { + self.batch + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn threads(&self) 
-> ThreadOptions { + self.threads + } + + #[cfg(feature = "inference")] + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn with_optimization_level(mut self, l: GraphOptimizationLevel) -> Self { + self.optimization_level = l; + self + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn with_batch(mut self, b: BatchOptions) -> Self { + self.batch = b; + self + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn with_threads(mut self, t: ThreadOptions) -> Self { + self.threads = t; + self + } + + #[cfg(feature = "inference")] + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_optimization_level(&mut self, l: GraphOptimizationLevel) -> &mut Self { + self.optimization_level = l; + self + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_batch(&mut self, b: BatchOptions) -> &mut Self { + self.batch = b; + self + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_threads(&mut self, t: ThreadOptions) -> &mut Self { + self.threads = t; + self + } +} + +impl Default for Options { + #[cfg_attr(not(tarpaulin), inline(always))] + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(feature = "inference")] + #[test] + fn defaults_match_spec() { + let o = Options::default(); + assert_eq!(o.optimization_level(), GraphOptimizationLevel::Level1); + assert_eq!(o.batch().max_seq_len(), 2048); + assert_eq!(o.batch().batch_size(), 8); + assert_eq!(o.batch().max_batch_size(), 1024); + assert_eq!(o.threads().intra_threads(), 1); + assert_eq!(o.threads().inter_threads(), 1); + assert!(!o.threads().parallel_execution()); + } + + #[cfg(feature = "inference")] + #[test] + fn builder_chains_compose() { + let o = Options::default() + .with_optimization_level(GraphOptimizationLevel::Level3) + .with_batch(BatchOptions::default().with_batch_size(32)) + .with_threads(ThreadOptions::default().with_intra_threads(4)); + + assert_eq!(o.optimization_level(), 
GraphOptimizationLevel::Level3); + assert_eq!(o.batch().batch_size(), 32); + assert_eq!(o.threads().intra_threads(), 4); + } + + #[test] + fn options_is_copy() { + fn _require_copy() {} + _require_copy::(); + _require_copy::(); + _require_copy::(); + } + + #[test] + fn validate_rejects_zero_batch_size() { + let bad = BatchOptions::default().with_batch_size(0); + match bad.validate() { + Err(crate::Error::InvalidBatchSize { + batch_size: 0, + max_batch_size: 1024, + }) => {} + other => panic!("expected InvalidBatchSize {{ 0, 1024 }}, got {other:?}"), + } + } + + #[test] + fn validate_rejects_batch_size_above_max() { + let bad = BatchOptions::default() + .with_batch_size(2048) + .with_max_batch_size(1024); + match bad.validate() { + Err(crate::Error::InvalidBatchSize { + batch_size: 2048, + max_batch_size: 1024, + }) => {} + other => panic!("expected InvalidBatchSize {{ 2048, 1024 }}, got {other:?}"), + } + } + + #[test] + fn validate_accepts_default() { + BatchOptions::default() + .validate() + .expect("default BatchOptions must validate (8 / 1024)"); + } + + #[cfg(all(feature = "inference", feature = "serde"))] + #[test] + fn deserializing_empty_object_equals_default() { + let from_empty: Options = serde_json::from_str("{}").expect("empty options"); + let dflt = Options::default(); + assert_eq!(from_empty.optimization_level(), dflt.optimization_level()); + assert_eq!(from_empty.batch().max_seq_len(), dflt.batch().max_seq_len()); + assert_eq!(from_empty.batch().batch_size(), dflt.batch().batch_size()); + assert_eq!( + from_empty.batch().max_batch_size(), + dflt.batch().max_batch_size() + ); + } +} diff --git a/src/session.rs b/src/session.rs new file mode 100644 index 0000000..2a55da5 --- /dev/null +++ b/src/session.rs @@ -0,0 +1,101 @@ +//! Shared `ort::Session` constructor for [`crate::TextEncoder`]. +//! +//! Gated on `feature = "inference"` because every type touched here +//! (`ort::Session`, `ort::ep::*`) only exists when ort is in the +//! dependency graph. 
+
+use std::path::Path;
+
+use crate::{
+    error::{Error, Result},
+    options::Options,
+};
+
+/// Build an `ort::Session` from the graph at `path` with the
+/// caller-supplied `Options`. Registers any execution providers the
+/// caller opted into (`cuda` / `tensorrt` / `directml` / `rocm` /
+/// `coreml` Cargo features) before committing the graph file. The
+/// implicit CPU EP is always available as the final fallback.
+pub(crate) fn build_session(graph: &Path, opts: Options) -> Result<ort::session::Session> {
+    use ort::session::Session;
+
+    let level = opts.optimization_level();
+
+    let mut builder = Session::builder()
+        .map_err(|source| Error::LoadGraph {
+            path: graph.to_path_buf(),
+            source,
+        })?
+        .with_optimization_level(level)
+        .map_err(|source| Error::LoadGraph {
+            path: graph.to_path_buf(),
+            source: ort::Error::from(source),
+        })?
+        .with_intra_threads(opts.threads().intra_threads())
+        .map_err(|source| Error::LoadGraph {
+            path: graph.to_path_buf(),
+            source: ort::Error::from(source),
+        })?
+        .with_inter_threads(opts.threads().inter_threads())
+        .map_err(|source| Error::LoadGraph {
+            path: graph.to_path_buf(),
+            source: ort::Error::from(source),
+        })?
+        .with_parallel_execution(opts.threads().parallel_execution())
+        .map_err(|source| Error::LoadGraph {
+            path: graph.to_path_buf(),
+            source: ort::Error::from(source),
+        })?;
+
+    let providers = collect_execution_providers();
+    if !providers.is_empty() {
+        builder = builder
+            .with_execution_providers(providers)
+            .map_err(|source| Error::LoadGraph {
+                path: graph.to_path_buf(),
+                source: ort::Error::from(source),
+            })?;
+    }
+
+    builder
+        .commit_from_file(graph)
+        .map_err(|source| Error::LoadGraph {
+            path: graph.to_path_buf(),
+            source,
+        })
+}
+
+/// Collect the execution-provider dispatchers active under the
+/// current target + feature configuration. Order matters: ort tries
+/// each in the supplied list before falling back to the implicit
+/// CPU EP, so the first registered EP gets first refusal on each op. 
+fn collect_execution_providers() -> Vec { + #[allow(unused_mut)] + let mut providers: Vec = Vec::new(); + + // TensorRT before CUDA when both are enabled: TensorRT typically + // beats raw CUDA on supported ops, and the unsupported ones fall + // back to CUDA's general execution path. + #[cfg(feature = "tensorrt")] + { + providers.push(ort::ep::TensorRT::default().build()); + } + #[cfg(feature = "cuda")] + { + providers.push(ort::ep::CUDA::default().build()); + } + #[cfg(feature = "directml")] + { + providers.push(ort::ep::DirectML::default().build()); + } + #[cfg(feature = "rocm")] + { + providers.push(ort::ep::ROCm::default().build()); + } + #[cfg(feature = "coreml")] + { + providers.push(ort::ep::CoreML::default().build()); + } + + providers +} diff --git a/src/simd/mod.rs b/src/simd/mod.rs new file mode 100644 index 0000000..32d89d8 --- /dev/null +++ b/src/simd/mod.rs @@ -0,0 +1,148 @@ +//! Crate-internal SIMD primitives. Only one operation is hot enough to +//! be worth hand-vectorizing: the 768-element f32 dot product +//! ([`Embedding::cosine`], `||v||²` during normalization). Pointwise +//! scales and integer widenings auto-vectorize under `-O3`, so they +//! stay in scalar form. +//! +//! Backends: +//! - `scalar` — always compiled, reference implementation. +//! - `neon` — aarch64 NEON + FMA. NEON is baseline on aarch64, so the +//! dispatcher invokes it unconditionally on that target. +//! - `x86` — x86_64 AVX2 + FMA. Selected at runtime when +//! `is_x86_feature_detected!("avx2")` and `…("fma")` both succeed. +//! +//! Numerical contract: SIMD backends are not byte-identical to scalar +//! (different summation order changes f32 rounding) but agree within +//! `1e-3` absolute for `dot_768` on unit-norm 768-d vectors. Tests in +//! each backend module enforce this. +//! +//! Safety boundary: `dot_768` takes `&[f32; 768]` rather than `&[f32]`. +//! The unsafe per-arch kernels read exactly 768 elements via raw +//! 
pointer offsets, and the type-level length invariant is what makes +//! that read sound. A `&[f32]`-typed parameter would only be checked +//! by `debug_assert!`, which is stripped in release — the type-level +//! version eliminates that release-mode footgun by construction. +//! +//! Miri escape hatch: every per-arch dispatcher short-circuits to +//! scalar under `cfg!(miri)`. Miri cannot evaluate target-specific +//! LLVM intrinsics (`vfmaq_f32`, `_mm256_fmadd_ps`, …) and would +//! abort with "unsupported operation: can't call foreign function" +//! the moment a normal test went through `Embedding::cosine`. +//! Routing through scalar lets the Miri matrix exercise the same +//! call sites as native CI — and validate the *unsafe-free* path — +//! without ever entering the SIMD backends. The per-arch backend +//! tests that call the unsafe kernels directly are gated with +//! `#[cfg(not(miri))]` for the same reason. + +pub(crate) mod scalar; + +#[cfg(target_arch = "aarch64")] +pub(crate) mod neon; + +#[cfg(target_arch = "x86_64")] +pub(crate) mod x86; + +/// Dispatch to the best available 768-d f32 dot product. The fixed-size +/// array parameter is the safety contract: the unsafe per-arch backends +/// rely on exactly 768 elements being readable. +#[cfg_attr(not(tarpaulin), inline(always))] +pub(crate) fn dot_768(a: &[f32; 768], b: &[f32; 768]) -> f32 { + dot_768_dispatch(a, b) +} + +#[cfg(target_arch = "aarch64")] +#[cfg_attr(not(tarpaulin), inline(always))] +fn dot_768_dispatch(a: &[f32; 768], b: &[f32; 768]) -> f32 { + // Miri can't evaluate `vfmaq_f32` / `vld1q_f32` and would abort with + // "unsupported operation: can't call foreign function" — see the + // module-level docstring. Route through scalar so Miri-driven jobs + // still exercise `Embedding::cosine` and the surrounding logic. + if cfg!(miri) { + return scalar::dot_768(a, b); + } + // SAFETY: NEON is a baseline aarch64 feature — every aarch64 CPU has + // it. 
The 768-element precondition is encoded in the parameter type. + unsafe { neon::dot_768(a, b) } +} + +#[cfg(target_arch = "x86_64")] +#[cfg_attr(not(tarpaulin), inline(always))] +fn dot_768_dispatch(a: &[f32; 768], b: &[f32; 768]) -> f32 { + // Same Miri reasoning as the aarch64 path — bypass the AVX2+FMA + // intrinsics under Miri. + if cfg!(miri) { + return scalar::dot_768(a, b); + } + // `is_x86_feature_detected!` caches its result behind an atomic, so + // the per-call cost is a relaxed load + branch. + if std::arch::is_x86_feature_detected!("avx2") && std::arch::is_x86_feature_detected!("fma") { + // SAFETY: feature detection above; 768-element precondition encoded + // in the parameter type. + unsafe { x86::dot_768_avx2_fma(a, b) } + } else { + scalar::dot_768(a, b) + } +} + +#[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))] +#[cfg_attr(not(tarpaulin), inline(always))] +fn dot_768_dispatch(a: &[f32; 768], b: &[f32; 768]) -> f32 { + scalar::dot_768(a, b) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn fixture() -> (Box<[f32; 768]>, Box<[f32; 768]>) { + let a: Box<[f32; 768]> = (0..768) + .map(|i| ((i as f32) * 0.013).sin()) + .collect::>() + .into_boxed_slice() + .try_into() + .unwrap(); + let b: Box<[f32; 768]> = (0..768) + .map(|i| ((i as f32) * 0.017).cos()) + .collect::>() + .into_boxed_slice() + .try_into() + .unwrap(); + (a, b) + } + + #[test] + fn dispatch_agrees_with_scalar_within_tolerance() { + let (a, b) = fixture(); + let s = scalar::dot_768(&a, &b); + let d = dot_768(&a, &b); + assert!( + (s - d).abs() < 1e-3, + "dispatch dot ({d}) disagrees with scalar ({s})" + ); + } + + #[test] + fn dispatch_zero_for_orthogonal_axes() { + // e_0 vs e_1 → exactly 0 in both scalar and SIMD (no rounding). 
+ let mut a = Box::new([0.0f32; 768]); + let mut b = Box::new([0.0f32; 768]); + a[0] = 1.0; + b[1] = 1.0; + assert_eq!(dot_768(&a, &b), 0.0); + } + + /// Short slices can never reach the SIMD boundary: the parameter + /// type `&[f32; 768]` rejects them at compile time. This test + /// documents the conversion-site failure mode that callers see when + /// they pass a wrong-length slice — the SIMD backends never get a + /// chance to read OOB. + #[test] + fn short_slice_cannot_be_converted_to_768_array() { + let v = vec![0.0f32; 100]; + let arr: Result<&[f32; 768], _> = v.as_slice().try_into(); + assert!( + arr.is_err(), + "100-element slice must not convert to [f32; 768]" + ); + } +} diff --git a/src/simd/neon.rs b/src/simd/neon.rs new file mode 100644 index 0000000..09b9a84 --- /dev/null +++ b/src/simd/neon.rs @@ -0,0 +1,101 @@ +//! aarch64 NEON backend — selected unconditionally on aarch64 (NEON is +//! a baseline feature). The kernel carries +//! `#[target_feature(enable = "neon")]` so its intrinsics execute in +//! an explicitly NEON-enabled context rather than one merely inherited +//! from the aarch64 target's default features. + +use core::arch::aarch64::*; + +/// 768-element f32 dot product using NEON FMA. +/// +/// Four parallel 4-lane accumulators (16 lanes total). Each +/// `vfmaq_f32` multiplies 4 f32s and adds into the accumulator, fully +/// pipelined across the four chains. +/// +/// # Safety +/// +/// NEON must be available — guaranteed on aarch64 by the ISA, but we +/// keep the `target_feature` annotation so the call site is explicitly +/// typed as a NEON context. The 768-element length precondition is +/// encoded in the parameter type (`&[f32; 768]`), not asserted at +/// runtime — this is what makes the raw-pointer reads sound in +/// release builds where `debug_assert!`s would have been stripped. 
+#[inline] +#[target_feature(enable = "neon")] +pub(crate) unsafe fn dot_768(a: &[f32; 768], b: &[f32; 768]) -> f32 { + let mut acc0 = vdupq_n_f32(0.0); + let mut acc1 = vdupq_n_f32(0.0); + let mut acc2 = vdupq_n_f32(0.0); + let mut acc3 = vdupq_n_f32(0.0); + + let pa = a.as_ptr(); + let pb = b.as_ptr(); + + // 768 / 16 = 48 outer iterations: 4 × 4-lane FMAs across 4 + // independent dependency chains = 192 FMAs total. + let mut i = 0usize; + while i < 768 { + // SAFETY: `i + 16 ≤ 768` each iteration; pa/pb point to fixed-size + // arrays of length 768 by the parameter type. NEON loads/FMAs are + // sound under `#[target_feature(enable = "neon")]`. + unsafe { + let a0 = vld1q_f32(pa.add(i)); + let a1 = vld1q_f32(pa.add(i + 4)); + let a2 = vld1q_f32(pa.add(i + 8)); + let a3 = vld1q_f32(pa.add(i + 12)); + let b0 = vld1q_f32(pb.add(i)); + let b1 = vld1q_f32(pb.add(i + 4)); + let b2 = vld1q_f32(pb.add(i + 8)); + let b3 = vld1q_f32(pb.add(i + 12)); + acc0 = vfmaq_f32(acc0, a0, b0); + acc1 = vfmaq_f32(acc1, a1, b1); + acc2 = vfmaq_f32(acc2, a2, b2); + acc3 = vfmaq_f32(acc3, a3, b3); + } + i += 16; + } + + // Pairwise reduce 4 vectors → 1 vector → scalar. + let s01 = vaddq_f32(acc0, acc1); + let s23 = vaddq_f32(acc2, acc3); + let s = vaddq_f32(s01, s23); + vaddvq_f32(s) +} + +// Direct calls to the unsafe NEON kernel. Miri can't evaluate the +// `vfmaq_f32` / `vld1q_f32` intrinsics; the same coverage of the +// public API (via `Embedding::cosine`) under Miri is provided by the +// scalar fallback the dispatcher routes to under `cfg!(miri)`. 
+#[cfg(all(test, not(miri)))] +mod tests { + use super::*; + + #[test] + fn agrees_with_scalar_within_tolerance() { + let a: Box<[f32; 768]> = (0..768) + .map(|i| ((i as f32) * 0.013).sin()) + .collect::>() + .into_boxed_slice() + .try_into() + .unwrap(); + let b: Box<[f32; 768]> = (0..768) + .map(|i| ((i as f32) * 0.017).cos()) + .collect::>() + .into_boxed_slice() + .try_into() + .unwrap(); + let s = crate::simd::scalar::dot_768(&a, &b); + // SAFETY: NEON is baseline on aarch64; length is type-encoded. + let n = unsafe { dot_768(&a, &b) }; + assert!((s - n).abs() < 1e-3, "neon ({n}) vs scalar ({s})"); + } + + #[test] + fn unit_vector_self_dot_is_one() { + let mut a = Box::new([0.0f32; 768]); + a[42] = 1.0; + // SAFETY: NEON baseline; type-encoded length. + let got = unsafe { dot_768(&a, &a) }; + assert_eq!(got, 1.0); + } +} diff --git a/src/simd/scalar.rs b/src/simd/scalar.rs new file mode 100644 index 0000000..9f227f0 --- /dev/null +++ b/src/simd/scalar.rs @@ -0,0 +1,57 @@ +//! Always-compiled scalar reference implementation. Acts as the +//! fallback backend on targets without a SIMD path and as the +//! agreement baseline for the per-arch backends' tests. + +/// Four-accumulator scalar dot product. The four independent reduction +/// chains let the compiler overlap multiplies/adds across iterations +/// even without SIMD intrinsics. +/// +/// `#[allow(dead_code)]`: on aarch64 the dispatcher always picks NEON +/// (a baseline ISA feature), so this baseline is unreachable from +/// non-test builds on that target — but we keep it compiled as the +/// agreement reference for the per-arch backends' tests and as the +/// fallback path on non-aarch64 / non-x86_64 targets. 
+#[allow(dead_code)] +#[cfg_attr(not(tarpaulin), inline(always))] +pub(crate) fn dot_768(a: &[f32; 768], b: &[f32; 768]) -> f32 { + let mut acc = [0.0_f32; 4]; + let mut i = 0; + while i < 768 { + acc[0] += a[i] * b[i]; + acc[1] += a[i + 1] * b[i + 1]; + acc[2] += a[i + 2] * b[i + 2]; + acc[3] += a[i + 3] * b[i + 3]; + i += 4; + } + acc[0] + acc[1] + acc[2] + acc[3] +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn dot_orthogonal_unit_vectors_is_zero() { + let mut a = Box::new([0.0f32; 768]); + let mut b = Box::new([0.0f32; 768]); + a[0] = 1.0; + b[1] = 1.0; + assert_eq!(dot_768(&a, &b), 0.0); + } + + #[test] + fn dot_self_unit_vector_is_one() { + let mut a = Box::new([0.0f32; 768]); + a[100] = 1.0; + assert_eq!(dot_768(&a, &a), 1.0); + } + + #[test] + fn dot_constant_vectors_matches_known_sum() { + let a = Box::new([0.5f32; 768]); + let b = Box::new([0.25f32; 768]); + // 768 × 0.5 × 0.25 = 96.0 + let got = dot_768(&a, &b); + assert!((got - 96.0).abs() < 1e-4, "expected ≈96.0, got {got}"); + } +} diff --git a/src/simd/x86.rs b/src/simd/x86.rs new file mode 100644 index 0000000..f4a9ee7 --- /dev/null +++ b/src/simd/x86.rs @@ -0,0 +1,105 @@ +//! x86_64 AVX2 + FMA backend. Selected by the dispatcher when both +//! `is_x86_feature_detected!("avx2")` and `…("fma")` succeed at +//! runtime. Carries `#[target_feature(enable = "avx2,fma")]` so the +//! intrinsics execute in an explicitly AVX2+FMA context. + +use core::arch::x86_64::*; + +/// 768-element f32 dot product using AVX2 256-bit registers + FMA. +/// +/// 768 = 8 × 96 → 24 outer iterations × 4 independent FMA chains +/// (4 × 8 = 32 elements per iteration) = 96 FMAs total. +/// +/// # Safety +/// +/// AVX2 + FMA must be present (dispatcher-verified via +/// `is_x86_feature_detected!`). 
The 768-element length precondition is +/// encoded in the parameter type (`&[f32; 768]`), not asserted at +/// runtime — this is what makes the raw-pointer reads sound in release +/// builds where `debug_assert!`s would have been stripped. +#[inline] +#[target_feature(enable = "avx2,fma")] +pub(crate) unsafe fn dot_768_avx2_fma(a: &[f32; 768], b: &[f32; 768]) -> f32 { + let mut acc0 = _mm256_setzero_ps(); + let mut acc1 = _mm256_setzero_ps(); + let mut acc2 = _mm256_setzero_ps(); + let mut acc3 = _mm256_setzero_ps(); + + let pa = a.as_ptr(); + let pb = b.as_ptr(); + + // 768 / 32 = 24 outer iterations, each loading 4 × 8-lane vectors + // per operand into 4 parallel accumulators. + let mut i = 0usize; + while i < 768 { + // SAFETY: `i + 32 ≤ 768` each iteration; pa/pb point to fixed-size + // arrays of length 768 by the parameter type. AVX2+FMA loads/FMAs + // are sound under `#[target_feature(enable = "avx2,fma")]`. + unsafe { + let a0 = _mm256_loadu_ps(pa.add(i)); + let a1 = _mm256_loadu_ps(pa.add(i + 8)); + let a2 = _mm256_loadu_ps(pa.add(i + 16)); + let a3 = _mm256_loadu_ps(pa.add(i + 24)); + let b0 = _mm256_loadu_ps(pb.add(i)); + let b1 = _mm256_loadu_ps(pb.add(i + 8)); + let b2 = _mm256_loadu_ps(pb.add(i + 16)); + let b3 = _mm256_loadu_ps(pb.add(i + 24)); + acc0 = _mm256_fmadd_ps(a0, b0, acc0); + acc1 = _mm256_fmadd_ps(a1, b1, acc1); + acc2 = _mm256_fmadd_ps(a2, b2, acc2); + acc3 = _mm256_fmadd_ps(a3, b3, acc3); + } + i += 32; + } + + // Reduce: 4 vectors → 1 vector → scalar. These intrinsics are + // pure lane/arithmetic ops (no memory access), so they are safe to + // call inside an `unsafe fn` body without an inner `unsafe { ... }` + // wrapper — the `target_feature(enable = "avx2,fma")` scope already + // guarantees the SIMD context they need. 
+ let s01 = _mm256_add_ps(acc0, acc1); + let s23 = _mm256_add_ps(acc2, acc3); + let s = _mm256_add_ps(s01, s23); + let lo = _mm256_castps256_ps128(s); + let hi = _mm256_extractf128_ps(s, 1); + let sum128 = _mm_add_ps(lo, hi); + // sum128 = [a, b, c, d]; want a + b + c + d. + let shuf = _mm_movehdup_ps(sum128); // [b, b, d, d] + let sums = _mm_add_ps(sum128, shuf); // [a+b, _, c+d, _] + let shuf2 = _mm_movehl_ps(sums, sums); // [c+d, …] + let total = _mm_add_ss(sums, shuf2); + _mm_cvtss_f32(total) +} + +// Direct call to the unsafe AVX2+FMA kernel. Miri can't evaluate +// `_mm256_loadu_ps` / `_mm256_fmadd_ps`; under `cfg(miri)` the +// dispatcher routes through scalar, so the public API is still +// covered. +#[cfg(all(test, not(miri)))] +mod tests { + use super::*; + + #[test] + fn agrees_with_scalar_within_tolerance() { + if !std::arch::is_x86_feature_detected!("avx2") || !std::arch::is_x86_feature_detected!("fma") { + eprintln!("skipping: AVX2/FMA not available on this host"); + return; + } + let a: Box<[f32; 768]> = (0..768) + .map(|i| ((i as f32) * 0.013).sin()) + .collect::>() + .into_boxed_slice() + .try_into() + .unwrap(); + let b: Box<[f32; 768]> = (0..768) + .map(|i| ((i as f32) * 0.017).cos()) + .collect::>() + .into_boxed_slice() + .try_into() + .unwrap(); + let s = crate::simd::scalar::dot_768(&a, &b); + // SAFETY: AVX2+FMA detected above; type-encoded length. + let v = unsafe { dot_768_avx2_fma(&a, &b) }; + assert!((s - v).abs() < 1e-3, "avx2+fma ({v}) vs scalar ({s})"); + } +} diff --git a/src/text_enc.rs b/src/text_enc.rs new file mode 100644 index 0000000..32a5e67 --- /dev/null +++ b/src/text_enc.rs @@ -0,0 +1,438 @@ +//! Text encoder for `embedding-gemma`. 
+
+use std::path::Path;
+
+use tokenizers::{
+    PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer, TruncationDirection,
+    TruncationParams, TruncationStrategy,
+};
+
+use crate::{
+    embedding::Embedding,
+    error::{Error, Result},
+    options::Options,
+};
+
+const EMBED_DIM: usize = Embedding::EMBED_DIM;
+const PAD_TOKEN: &str = "<pad>";
+
+/// `embedding-gemma` text-tower inference. Owns one `ort::Session` and one
+/// `tokenizers::Tokenizer`.
+///
+/// `TextEncoder: Send + !Sync` — `ort::Session` is `!Sync`. Workers wanting
+/// parallelism instantiate one `TextEncoder` per thread, or share one behind
+/// a `Mutex`.
+pub struct TextEncoder {
+    session: ort::session::Session,
+    tokenizer: Tokenizer,
+    opts: Options,
+}
+
+impl TextEncoder {
+    /// **Not available on wasm32.** `ort 2.0.0-rc.12` cfg-gates
+    /// `commit_from_file` out of wasm32 builds. On wasm callers must
+    /// construct the `ort::session::Session` via the wasm-specific async
+    /// APIs and pass it to [`Self::from_ort_session`].
+    #[cfg(not(target_arch = "wasm32"))]
+    pub fn from_files(graph: &Path, tokenizer: &Path) -> Result<Self> {
+        Self::from_files_with_options(graph, tokenizer, Options::default())
+    }
+
+    /// Same wasm32 caveat as [`Self::from_files`]. 
+    #[cfg(not(target_arch = "wasm32"))]
+    pub fn from_files_with_options(graph: &Path, tokenizer: &Path, opts: Options) -> Result<Self> {
+        let session = crate::session::build_session(graph, opts)?;
+        let tokenizer = Tokenizer::from_file(tokenizer).map_err(|e| Error::Tokenizer(e.to_string()))?;
+        let tokenizer = configure_tokenizer(tokenizer, opts.batch().max_seq_len())?;
+        Self::from_ort_session_with_options(session, tokenizer, opts)
+    }
+
+    pub fn from_ort_session(session: ort::session::Session, tokenizer: Tokenizer) -> Result<Self> {
+        let opts = Options::default();
+        let tokenizer = configure_tokenizer(tokenizer, opts.batch().max_seq_len())?;
+        Self::from_ort_session_with_options(session, tokenizer, opts)
+    }
+
+    fn from_ort_session_with_options(
+        session: ort::session::Session,
+        tokenizer: Tokenizer,
+        opts: Options,
+    ) -> Result<Self> {
+        validate_text_session(&session)?;
+        opts.batch().validate()?;
+        Ok(Self {
+            session,
+            tokenizer,
+            opts,
+        })
+    }
+
+    pub fn embed(&mut self, text: &str) -> Result<Embedding> {
+        if text.is_empty() {
+            return Err(Error::EmptyText);
+        }
+        let mut out = self.embed_batch(&[text])?;
+        Ok(out.remove(0))
+    }
+
+    /// Returns `Ok(vec![])` for an empty input slice (no ORT call).
+    /// Returns `Error::BatchTooLarge` when `texts.len() > opts.batch.max_batch_size`.
+    /// Internally chunks `texts` into groups of size `BatchOptions::batch_size`
+    /// and runs one ORT inference per chunk; the returned `Vec` preserves
+    /// input order and has the same length as `texts` on success.
+    ///
+    /// **Failure semantics.** Aborts on the first failing input and returns
+    /// `Error::Batch { index, source }` carrying the offending zero-based
+    /// index. Already-computed embeddings from earlier chunks are dropped. 
+    pub fn embed_batch(&mut self, texts: &[&str]) -> Result<Vec<Embedding>> {
+        if texts.is_empty() {
+            return Ok(Vec::new());
+        }
+        let max = self.opts.batch().max_batch_size();
+        if texts.len() > max {
+            return Err(Error::BatchTooLarge {
+                got: texts.len(),
+                max,
+            });
+        }
+        if let Some((index, _)) = texts.iter().enumerate().find(|(_, t)| t.is_empty()) {
+            return Err(Error::Batch {
+                index,
+                source: Box::new(Error::EmptyText),
+            });
+        }
+        let chunk = self.opts.batch().batch_size();
+        let mut out = Vec::with_capacity(texts.len());
+        for (chunk_idx, group) in texts.chunks(chunk).enumerate() {
+            let base_index = chunk_idx * chunk;
+            let chunk_emb = embed_chunk(&mut self.session, &self.tokenizer, group, base_index)?;
+            out.extend(chunk_emb);
+        }
+        Ok(out)
+    }
+
+    pub fn warmup(&mut self) -> Result<()> {
+        let _ = self.embed("warmup")?;
+        Ok(())
+    }
+}
+
+fn embed_chunk(
+    session: &mut ort::session::Session,
+    tokenizer: &Tokenizer,
+    group: &[&str],
+    base_index: usize,
+) -> Result<Vec<Embedding>> {
+    let encodings = tokenizer
+        .encode_batch(group.to_vec(), true)
+        .map_err(|e| Error::Batch {
+            index: base_index,
+            source: Box::new(Error::Tokenizer(e.to_string())),
+        })?;
+
+    let batch = group.len();
+    // BatchLongest pads every encoding in the chunk to the same length.
+    let seq_len = encodings.first().map(|e| e.get_ids().len()).unwrap_or(0);
+    if seq_len == 0 {
+        return Err(Error::Batch {
+            index: base_index,
+            source: Box::new(Error::EmptyText),
+        });
+    }
+
+    let mut input_ids = Vec::with_capacity(batch * seq_len);
+    let mut attention_mask = Vec::with_capacity(batch * seq_len);
+    for (i, enc) in encodings.iter().enumerate() {
+        let ids = enc.get_ids();
+        let mask = enc.get_attention_mask();
+        if ids.len() != seq_len || mask.len() != seq_len {
+            return Err(Error::Batch {
+                index: base_index + i,
+                source: Box::new(Error::Tokenizer(format!(
+                    "tokenizer produced uneven row {} (ids={}, mask={}, expected {})",
+                    i,
+                    ids.len(),
+                    mask.len(),
+                    seq_len
+                ))),
+            });
+        }
+        input_ids.extend(ids.iter().map(|&u| u as i64));
+        attention_mask.extend(mask.iter().map(|&u| u as i64));
+    }
+
+    run_session(
+        session,
+        &input_ids,
+        &attention_mask,
+        batch,
+        seq_len,
+        base_index,
+    )
+}
+
+fn run_session(
+    session: &mut ort::session::Session,
+    input_ids: &[i64],
+    attention_mask: &[i64],
+    batch: usize,
+    seq_len: usize,
+    base_index: usize,
+) -> Result<Vec<Embedding>> {
+    use ort::value::TensorRef;
+
+    // Wrap chunk-level errors (tensor build, ORT run, output extraction,
+    // shape validation) with `Error::Batch { index: base_index }` so the
+    // caller can identify which chunk failed even when the failure
+    // doesn't pin to a specific row. Per-row normalization failures get
+    // a precise `base_index + i` further down. This mirrors siglip2's
+    // text_enc batch-failure semantics — a documented contract that
+    // `embed_batch` reports failures via `Error::Batch`.
+    let wrap_chunk = |source: Error| Error::Batch {
+        index: base_index,
+        source: Box::new(source),
+    };
+
+    let shape: [usize; 2] = [batch, seq_len];
+    let ids_val =
+        TensorRef::from_array_view((shape, input_ids)).map_err(|e| wrap_chunk(Error::Ort(e)))?;
+    let mask_val =
+        TensorRef::from_array_view((shape, attention_mask)).map_err(|e| wrap_chunk(Error::Ort(e)))?;
+
+    let outputs = session
+        .run(ort::inputs![
+            "input_ids" => ids_val,
+            "attention_mask" => mask_val,
+        ])
+        .map_err(|e| wrap_chunk(Error::Ort(e)))?;
+
+    let out = outputs.get("sentence_embedding").ok_or_else(|| {
+        wrap_chunk(Error::MissingOnnxOutput {
+            name: "sentence_embedding",
+        })
+    })?;
+    let (shape, data) = out
+        .try_extract_tensor::<f32>()
+        .map_err(|e| wrap_chunk(Error::Ort(e)))?;
+
+    if shape.len() != 2 {
+        return Err(wrap_chunk(Error::OutputRank {
+            rank: shape.len(),
+            shape: shape.to_vec(),
+        }));
+    }
+    if shape[0] != batch as i64 || shape[1] != EMBED_DIM as i64 {
+        return Err(wrap_chunk(Error::SessionShapeMismatch {
+            input: "sentence_embedding",
+            expected: "[batch, 768]",
+            got: shape.to_vec(),
+        }));
+    }
+
+    embeddings_from_chunk(data, batch, base_index)
+}
+
+/// Convert a flat `[batch * EMBED_DIM]` model-output buffer into
+/// `batch` `Embedding`s, wrapping per-row normalization failures as
+/// `Error::Batch { index: base_index + i, source }` so callers can
+/// quarantine the offending row. Pulled out of `run_session` so the
+/// indexed wrapping is unit-testable without an ORT session.
+fn embeddings_from_chunk(data: &[f32], batch: usize, base_index: usize) -> Result<Vec<Embedding>> {
+    debug_assert_eq!(data.len(), batch * EMBED_DIM);
+    let mut embeddings = Vec::with_capacity(batch);
+    for i in 0..batch {
+        let row = &data[i * EMBED_DIM..(i + 1) * EMBED_DIM];
+        let emb = Embedding::from_model_output(row).map_err(|source| Error::Batch {
+            index: base_index + i,
+            source: Box::new(source),
+        })?;
+        embeddings.push(emb);
+    }
+    Ok(embeddings)
+}
+
+fn validate_text_session(session: &ort::session::Session) -> Result<()> {
+    use ort::value::TensorElementType;
+
+    let inputs = session.inputs();
+    let outputs = session.outputs();
+
+    // Both inputs are `[batch, seq]` with dynamic batch and dynamic seq.
+    check_outlet(inputs, "input_ids", TensorElementType::Int64, &[-1, -1])?;
+    check_outlet(
+        inputs,
+        "attention_mask",
+        TensorElementType::Int64,
+        &[-1, -1],
+    )?;
+    // Output is `[batch, EMBED_DIM]` with dynamic batch.
+    check_outlet(
+        outputs,
+        "sentence_embedding",
+        TensorElementType::Float32,
+        &[-1, EMBED_DIM as i64],
+    )?;
+    Ok(())
+}
+
+/// Verify an `Outlet` exists with the expected dtype and shape.
+///
+/// `expected_shape` semantics: a value of `-1` is a wildcard (matches any
+/// dim including the graph's own `-1` dynamic marker). Any other value
+/// must match exactly. The graph's declared shape may itself contain `-1`
+/// for dynamic axes; in that case we still accept it (the runtime will
+/// catch shape mismatches at inference time).
+fn check_outlet(
+    outlets: &[ort::value::Outlet],
+    name: &'static str,
+    expected_dtype: ort::value::TensorElementType,
+    expected_shape: &[i64],
+) -> Result<()> {
+    use ort::value::ValueType;
+
+    let outlet = outlets
+        .iter()
+        .find(|o| o.name() == name)
+        .ok_or(Error::SessionShapeMismatch {
+            input: name,
+            expected: "outlet present in session",
+            got: vec![],
+        })?;
+
+    match outlet.dtype() {
+        ValueType::Tensor { ty, shape, ..
} => {
+            if *ty != expected_dtype {
+                return Err(Error::SessionShapeMismatch {
+                    input: name,
+                    expected: "matching tensor dtype",
+                    got: shape.to_vec(),
+                });
+            }
+            let actual: &[i64] = shape;
+            if actual.len() != expected_shape.len() {
+                return Err(Error::SessionShapeMismatch {
+                    input: name,
+                    expected: "matching tensor rank",
+                    got: actual.to_vec(),
+                });
+            }
+            for (i, &want) in expected_shape.iter().enumerate() {
+                let act = actual[i];
+                if want != -1 && act != -1 && act != want {
+                    return Err(Error::SessionShapeMismatch {
+                        input: name,
+                        expected: "matching static dim",
+                        got: actual.to_vec(),
+                    });
+                }
+            }
+            Ok(())
+        }
+        _ => Err(Error::SessionShapeMismatch {
+            input: name,
+            expected: "tensor",
+            got: vec![],
+        }),
+    }
+}
+
+fn configure_tokenizer(mut tokenizer: Tokenizer, max_seq_len: usize) -> Result<Tokenizer> {
+    let pad_id = tokenizer
+        .token_to_id(PAD_TOKEN)
+        .ok_or_else(|| Error::Tokenizer(format!("loaded tokenizer has no `{PAD_TOKEN}` token")))?;
+
+    // Pad to the longest input in each batch. The model's `attention_mask`
+    // input lets us mask out the padding tokens cleanly, so we don't need
+    // a fixed sequence length — every chunk pads to its own longest row.
+    tokenizer.with_padding(Some(PaddingParams {
+        strategy: PaddingStrategy::BatchLongest,
+        direction: PaddingDirection::Right,
+        pad_id,
+        pad_token: PAD_TOKEN.to_string(),
+        pad_type_id: 0,
+        pad_to_multiple_of: None,
+    }));
+
+    // Truncate long inputs to `max_seq_len`. `with_truncation` returns
+    // `Result<&mut Self>` and only fails when `stride > effective_max_length`;
+    // with `stride = 0` and `max_length > 0` this is infallible.
+    if max_seq_len == 0 {
+        return Err(Error::Tokenizer(
+            "max_seq_len must be > 0 (BatchOptions::with_max_seq_len)".to_string(),
+        ));
+    }
+    tokenizer
+        .with_truncation(Some(TruncationParams {
+            direction: TruncationDirection::Right,
+            max_length: max_seq_len,
+            strategy: TruncationStrategy::LongestFirst,
+            stride: 0,
+        }))
+        .map_err(|e| Error::Tokenizer(e.to_string()))?;
+    Ok(tokenizer)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn pad_token_constant_matches_gemma_vocab() {
+        // The `embedding-gemma` tokenizer.json has `<pad>` at id 0; this
+        // pin keeps the constant in sync with the assumption used by
+        // `from_files_with_options`.
+        assert_eq!(PAD_TOKEN, "<pad>");
+    }
+
+    #[test]
+    fn embed_dim_constant_matches_embedding_module() {
+        assert_eq!(EMBED_DIM, 768);
+    }
+
+    /// Codex review finding: `embed_batch` documents that failures
+    /// surface as `Error::Batch { index, source }` carrying the
+    /// offending zero-based index, but the previous implementation
+    /// propagated `Embedding::from_model_output` errors unwrapped via
+    /// `?` — so a degenerate row in the middle of a batch would lose
+    /// its position. This test fakes a 3-row chunk where the middle row
+    /// is all zero (→ `NotNormalized`) and asserts the wrapped index is
+    /// `base_index + 1`, proving the row context is preserved across
+    /// the boundary.
+    #[test]
+    fn embeddings_from_chunk_wraps_row_error_with_index() {
+        // 3 rows × 768. Rows 0 and 2 are unit vectors (normalize fine);
+        // row 1 is all-zero, which `from_model_output` rejects as
+        // `Error::NotNormalized` — the row we want to surface.
+        let mut data = vec![0.0f32; 3 * EMBED_DIM];
+        data[0] = 1.0;
+        data[2 * EMBED_DIM] = 1.0;
+
+        let err = embeddings_from_chunk(&data, 3, 100).expect_err("row 1 must fail");
+        match err {
+            Error::Batch { index, source } => {
+                assert_eq!(index, 101, "expected base_index + 1, got {index}");
+                match *source {
+                    Error::NotNormalized { norm, ..
} => assert_eq!(norm, 0.0), + other => panic!("expected NotNormalized inside Batch, got {other}"), + } + } + other => panic!("expected Error::Batch, got {other}"), + } + } + + /// Sibling check: when every row is well-formed, + /// `embeddings_from_chunk` returns the full batch with no wrapping. + #[test] + fn embeddings_from_chunk_succeeds_for_clean_batch() { + let mut data = vec![0.0f32; 2 * EMBED_DIM]; + data[0] = 1.0; + data[EMBED_DIM] = 1.0; + let out = embeddings_from_chunk(&data, 2, 0).expect("clean batch must succeed"); + assert_eq!(out.len(), 2); + for e in &out { + assert_eq!(e.dim(), EMBED_DIM); + let cos = e.try_cosine(e).expect("happy path"); + assert!((cos - 1.0).abs() < 1e-5); + } + } +} diff --git a/tests/foo.rs b/tests/foo.rs deleted file mode 100644 index 8b13789..0000000 --- a/tests/foo.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/tests/integration.rs b/tests/integration.rs new file mode 100644 index 0000000..2c9b879 --- /dev/null +++ b/tests/integration.rs @@ -0,0 +1,166 @@ +//! End-to-end integration test against the released `embedding-gemma` +//! ONNX export. +//! +//! Requires the model files at runtime; gated behind the `EGEMMA_MODEL_DIR` +//! env var so `cargo test` works without the assets present. Set it to a +//! directory containing `model.onnx` (with its `.onnx_data` sidecar) +//! and a `tokenizer.json`. +//! +//! `model.onnx` is the canonical fp32 export from +//! `onnx-community/embeddinggemma-300m-ONNX`. The model card flags +//! fp16 as an unsupported activation dtype for this graph; we don't +//! auto-discover `model_fp16.onnx`. If you have only the fp16 file +//! locally and have validated it on your tokenizer/quality bar, +//! point at it explicitly via `EGEMMA_MODEL_FILE=model_fp16.onnx`. +//! +//! ```bash +//! EGEMMA_MODEL_DIR=/path/to/embedding-gemma cargo test --test integration +//! ``` +//! +//! # CI contract — read this before assuming "tests passed = inference works" +//! +//! 
GitHub Actions does **not** set `EGEMMA_MODEL_DIR`. When unset, every +//! test in this file emits a `[INTEGRATION-SKIP]` banner and returns +//! `Ok(())` without loading a model. CI therefore reports them as +//! `ok` even though no `ort::Session::run` ever happened. This is a +//! deliberate trade-off: the alternative — pulling ~600 MB of model +//! assets per CI run, or maintaining a synthetic ONNX fixture — costs +//! more than it catches, because the structural risks (input/output +//! name drift, dtype drift, dim drift) are already enforced at +//! construction time by `validate_text_session` (see `src/text_enc.rs`), +//! and the unit tests pin the constant assumptions (`PAD_TOKEN`, +//! `EMBED_DIM`). +//! +//! **Developer responsibility.** Before merging changes that touch +//! `src/text_enc.rs`, `src/session.rs`, `src/simd/`, `src/embedding.rs`, +//! or this file, run: +//! +//! ```bash +//! EGEMMA_MODEL_DIR=/path/to/embedding-gemma cargo test --test integration +//! ``` +//! +//! and confirm 4/4 pass. Grep CI logs for `[INTEGRATION-SKIP]` to +//! verify whether a given run actually exercised this path. 
+
+#![cfg(feature = "inference")]
+
+use std::path::PathBuf;
+
+use egemma::{Embedding, TextEncoder};
+
+fn model_dir() -> Option<PathBuf> {
+    std::env::var_os("EGEMMA_MODEL_DIR").map(PathBuf::from)
+}
+
+fn discover_graph(dir: &std::path::Path) -> PathBuf {
+    if let Some(name) = std::env::var_os("EGEMMA_MODEL_FILE") {
+        return dir.join(name);
+    }
+    let canonical = dir.join("model.onnx");
+    if canonical.exists() {
+        return canonical;
+    }
+    panic!(
+        "no `model.onnx` found in {}; set `EGEMMA_MODEL_FILE` to point at \
+         a different filename (the upstream model card flags fp16 as an \
+         unsupported activation dtype for this graph, so `model_fp16.onnx` \
+         is not auto-discovered — pass it explicitly only if you've \
+         validated it for your workload)",
+        dir.display()
+    );
+}
+
+/// Centralizes the skip-or-load decision so every test prints the same
+/// `[INTEGRATION-SKIP]` banner — searchable in CI logs to distinguish
+/// "real run with assertions" from "skipped, env var unset" runs.
+/// Returns `None` when `EGEMMA_MODEL_DIR` is unset; the caller should
+/// `return` immediately.
+fn try_load_encoder(test_name: &str) -> Option<TextEncoder> {
+    if model_dir().is_none() {
+        eprintln!(
+            "[INTEGRATION-SKIP] {test_name}: EGEMMA_MODEL_DIR unset — skipping. \
+             Run locally with `EGEMMA_MODEL_DIR=/path/to/embedding-gemma cargo test \
+             --test integration` before merging inference-path changes."
+ ); + return None; + } + let dir = model_dir().expect("model_dir checked above"); + let graph = discover_graph(&dir); + let tokenizer = dir.join("tokenizer.json"); + Some(TextEncoder::from_files(&graph, &tokenizer).expect("loading encoder must succeed")) +} + +#[test] +fn embed_single_returns_unit_norm_vector() { + let Some(mut encoder) = try_load_encoder("embed_single_returns_unit_norm_vector") else { + return; + }; + let e = encoder + .embed("hello world") + .expect("single embed must succeed"); + assert_eq!(e.dim(), Embedding::EMBED_DIM); + let cos = e.try_cosine(&e).expect("self-cosine on valid embedding"); + assert!( + (cos - 1.0).abs() < 1e-4, + "self-cosine should be 1.0; got {cos}" + ); +} + +#[test] +fn embed_batch_preserves_order_and_self_cosine() { + let Some(mut encoder) = try_load_encoder("embed_batch_preserves_order_and_self_cosine") else { + return; + }; + let prompts = ["alpha", "the quick brown fox", "lorem ipsum dolor sit amet"]; + let embeddings = encoder.embed_batch(&prompts).expect("batch embed"); + assert_eq!(embeddings.len(), prompts.len()); + for e in &embeddings { + assert_eq!(e.dim(), Embedding::EMBED_DIM); + let cos = e.try_cosine(e).expect("self-cosine on valid embedding"); + assert!((cos - 1.0).abs() < 1e-4); + } + let single = encoder + .embed("alpha") + .expect("single embed for parity check"); + let parity = single + .try_cosine(&embeddings[0]) + .expect("parity cosine on valid pair"); + assert!( + (parity - 1.0).abs() < 1e-3, + "single vs batched embedding for the same prompt should match" + ); +} + +#[test] +fn related_prompts_more_similar_than_unrelated() { + let Some(mut encoder) = try_load_encoder("related_prompts_more_similar_than_unrelated") else { + return; + }; + let v = encoder + .embed_batch(&[ + "task: search result | query: how do birds fly?", + "Birds use lift generated by their wings to fly.", + "The price of bananas in Tokyo is rising.", + ]) + .expect("batch embed"); + + let related = v[0] + .try_cosine(&v[1]) + 
.expect("related cosine on valid pair"); + let unrelated = v[0] + .try_cosine(&v[2]) + .expect("unrelated cosine on valid pair"); + assert!( + related > unrelated, + "expected related > unrelated; got related={related}, unrelated={unrelated}" + ); +} + +#[test] +fn empty_text_rejected() { + let Some(mut encoder) = try_load_encoder("empty_text_rejected") else { + return; + }; + let err = encoder.embed("").expect_err("empty text must error"); + assert!(matches!(err, egemma::Error::EmptyText)); +} From e1d64c2fc920f8d63888cca95d07075614f65d31 Mon Sep 17 00:00:00 2001 From: uqio <276879906+uqio@users.noreply.github.com> Date: Sat, 2 May 2026 15:17:16 +1200 Subject: [PATCH 02/11] cleanup --- Cargo.toml | 9 ++++++++- ci/miri_sb.sh | 7 ++++++- ci/miri_tb.sh | 10 +++++++++- ci/sanitizer.sh | 18 ++++++++++++++---- src/embedding.rs | 7 ++++--- 5 files changed, 41 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d7f2aeb..2f72287 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -109,7 +109,14 @@ overflow-checks = false rpath = false [package.metadata.docs.rs] -all-features = true +# `all-features = true` would activate the opt-in execution-provider +# features (`cuda`, `tensorrt`, `directml`, `rocm`, `coreml`), each of +# which requires the corresponding vendor SDK at build time. docs.rs's +# Linux builder doesn't have any of them, so `all-features` would fail +# the docs build even when the actual code on docs.rs is fine. List +# the features that compile on a stock Linux runner; EP-specific docs +# coverage would need a separately provisioned builder. 
+features = ["inference", "serde"] rustdoc-args = ["--cfg", "docsrs"] [lints.rust] diff --git a/ci/miri_sb.sh b/ci/miri_sb.sh index cc3c6e0..73ad07a 100755 --- a/ci/miri_sb.sh +++ b/ci/miri_sb.sh @@ -35,4 +35,9 @@ cargo miri setup export MIRIFLAGS="-Zmiri-strict-provenance -Zmiri-disable-isolation -Zmiri-symbolic-alignment-check" -cargo miri test --all-targets --target "$TARGET" +# See `ci/miri_tb.sh` for why `--no-default-features` is required: +# Miri can't evaluate ort / tokenizers FFI, and the SIMD dispatcher's +# `cfg!(miri)` short-circuit routes through scalar so the unsafe +# NEON / AVX2 kernel boundaries are covered indirectly through the +# embedding API without entering platform intrinsics. +cargo miri test --all-targets --no-default-features --target "$TARGET" diff --git a/ci/miri_tb.sh b/ci/miri_tb.sh index 5d374c7..54a9976 100755 --- a/ci/miri_tb.sh +++ b/ci/miri_tb.sh @@ -35,4 +35,12 @@ cargo miri setup export MIRIFLAGS="-Zmiri-strict-provenance -Zmiri-disable-isolation -Zmiri-symbolic-alignment-check -Zmiri-tree-borrows" -cargo miri test --all-targets --target "$TARGET" +# Miri can't evaluate the FFI in `ort` / `tokenizers` (the `inference` +# default-feature dependencies), and most of the matrix targets +# (powerpc64, s390x, riscv64, i686) have no ort prebuilds. With +# `--no-default-features` Miri exercises the embedding + options + simd +# subset it actually validates: the SIMD dispatcher routes through the +# scalar fallback under `cfg!(miri)`, so the unsafe NEON / AVX2 kernel +# boundaries are indirectly covered through `Embedding::try_cosine` +# without ever entering the platform intrinsics Miri can't model. 
+cargo miri test --all-targets --no-default-features --target "$TARGET" diff --git a/ci/sanitizer.sh b/ci/sanitizer.sh index 4ff6819..a400f34 100755 --- a/ci/sanitizer.sh +++ b/ci/sanitizer.sh @@ -5,18 +5,28 @@ export ASAN_OPTIONS="detect_odr_violation=0 detect_leaks=0" TARGET="x86_64-unknown-linux-gnu" +# Sanitizer feature set: matches the coverage job. `--all-features` is +# unsafe here — it would activate the opt-in execution-provider features +# (`cuda`, `tensorrt`, `directml`, `rocm`, `coreml`), each of which +# requires the corresponding vendor SDK to compile. Stock GitHub +# runners don't have any of them, so `--all-features` would fail in +# `ort-sys`'s build script before the unsafe SIMD code is ever +# instrumented. EP-specific sanitizer coverage belongs on separately +# provisioned runners. +FEATURES="inference,serde" + # Run address sanitizer RUSTFLAGS="-Z sanitizer=address" \ -cargo test --tests --target "$TARGET" --all-features +cargo test --tests --target "$TARGET" --no-default-features --features "$FEATURES" # Run leak sanitizer RUSTFLAGS="-Z sanitizer=leak" \ -cargo test --tests --target "$TARGET" --all-features +cargo test --tests --target "$TARGET" --no-default-features --features "$FEATURES" # Run memory sanitizer (requires -Zbuild-std for instrumented std) RUSTFLAGS="-Z sanitizer=memory" \ -cargo -Zbuild-std test --tests --target "$TARGET" --all-features +cargo -Zbuild-std test --tests --target "$TARGET" --no-default-features --features "$FEATURES" # Run thread sanitizer (requires -Zbuild-std for instrumented std) RUSTFLAGS="-Z sanitizer=thread" \ -cargo -Zbuild-std test --tests --target "$TARGET" --all-features +cargo -Zbuild-std test --tests --target "$TARGET" --no-default-features --features "$FEATURES" diff --git a/src/embedding.rs b/src/embedding.rs index becc837..3131696 100644 --- a/src/embedding.rs +++ b/src/embedding.rs @@ -54,9 +54,10 @@ impl Embedding { /// only reachable in-crate; the check is forward-compatibility for /// 
variable-dim embeddings and a guard against future internal misuse. /// - /// Internally dispatches through [`crate::simd::dot_768`] — picks NEON - /// on aarch64, AVX2+FMA on x86_64 (when the runtime CPU advertises - /// both), or a four-accumulator scalar fallback on every other target. + /// Internally dispatches through the crate-private SIMD layer — picks + /// NEON on aarch64, AVX2+FMA on x86_64 (when the runtime CPU + /// advertises both), or a four-accumulator scalar fallback on every + /// other target. pub fn try_cosine(&self, other: &Embedding) -> Result { if self.dim() != other.dim() { return Err(Error::EmbeddingDim { From 06ba1fe75969c80dfabcc4182369dff808898fa8 Mon Sep 17 00:00:00 2001 From: uqio <276879906+uqio@users.noreply.github.com> Date: Sat, 2 May 2026 15:38:00 +1200 Subject: [PATCH 03/11] cleanup --- CHANGELOG.md | 10 ++-- README-zh_CN.md | 51 ----------------- README.md | 142 +++++++++++++++++++++++++++++++++++++----------- src/simd/x86.rs | 114 ++++++++++++++++++++++++++++++++------ 4 files changed, 213 insertions(+), 104 deletions(-) delete mode 100644 README-zh_CN.md diff --git a/CHANGELOG.md b/CHANGELOG.md index bd7a668..80fde44 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,9 @@ -# UNRELEASED +# Changelog -# 0.1.2 (January 6th, 2022) - -FEATURES +All notable changes to `egemma` will be documented in this file. The +format is loosely based on [Keep a Changelog](https://keepachangelog.com/), +and the project adheres to [Semantic Versioning](https://semver.org/). +## [Unreleased] +Initial release. See `Cargo.toml` for the public surface. diff --git a/README-zh_CN.md b/README-zh_CN.md deleted file mode 100644 index 7a07f4d..0000000 --- a/README-zh_CN.md +++ /dev/null @@ -1,51 +0,0 @@ -
-

template-rs

-
-
- -开源Rust代码库GitHub模版 - -[github][Github-url] -LoC -[Build][CI-url] -[codecov][codecov-url] - -[docs.rs][doc-url] -[crates.io][crates-url] -[crates.io][crates-url] -license - -[English][en-url] | 简体中文 - -
- -## Installation - -```toml -[dependencies] -template_rs = "0.1" -``` - -## Features - -- [x] 更快的创建GitHub开源Rust代码库 - -#### License - -`Template-rs` is under the terms of both the MIT license and the -Apache License (Version 2.0). - -See [LICENSE-APACHE](LICENSE-APACHE), [LICENSE-MIT](LICENSE-MIT) for details. - -Copyright (c) 2021 Al Liu. - -[Github-url]: https://github.com/al8n/template-rs/ -[CI-url]: https://github.com/al8n/template/actions/workflows/template.yml -[doc-url]: https://docs.rs/template-rs -[crates-url]: https://crates.io/crates/template-rs -[codecov-url]: https://app.codecov.io/gh/al8n/template-rs/ -[license-url]: https://opensource.org/licenses/Apache-2.0 -[rustc-url]: https://github.com/rust-lang/rust/blob/master/RELEASES.md -[license-apache-url]: https://opensource.org/licenses/Apache-2.0 -[license-mit-url]: https://opensource.org/licenses/MIT -[en-url]: https://github.com/al8n/template-rs/tree/main/README.md diff --git a/README.md b/README.md index 1af27e2..181ac23 100644 --- a/README.md +++ b/README.md @@ -1,46 +1,124 @@ -
-

template-rs

-
-
+# egemma -A template for creating Rust open-source GitHub repo. +Rust ONNX inference library for Google's [EmbeddingGemma] +(`google/embeddinggemma-300m`) text embeddings. Produces 768-dim +L2-normalized sentence embeddings via [`ort`] and [`tokenizers`]. -[github][Github-url] -LoC -[Build][CI-url] -[codecov][codecov-url] +[EmbeddingGemma]: https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX +[`ort`]: https://crates.io/crates/ort +[`tokenizers`]: https://crates.io/crates/tokenizers -[docs.rs][doc-url] -[crates.io][crates-url] -[crates.io][crates-url] -license +## Install -English | [简体中文][zh-cn-url] +```toml +[dependencies] +egemma = "0.1" +``` -
+## Quick start -## Installation +```rust,ignore +use egemma::TextEncoder; -```toml -[dependencies] -template_rs = "0.1" +let mut encoder = TextEncoder::from_files( + "model.onnx".as_ref(), + "tokenizer.json".as_ref(), +)?; + +let embeddings = encoder.embed_batch(&[ + "task: search result | query: how do birds fly?", + "Birds use lift generated by their wings to fly.", + "The price of bananas in Tokyo is rising.", +])?; + +let related = embeddings[0].try_cosine(&embeddings[1])?; +let unrelated = embeddings[0].try_cosine(&embeddings[2])?; +assert!(related > unrelated); +# Ok::<(), egemma::Error>(()) ``` -## Features -- [x] Create a Rust open-source repo fast +Download the canonical fp32 export from +[`onnx-community/embeddinggemma-300m-ONNX`][EmbeddingGemma] +(`model.onnx` plus its `model.onnx_data` sidecar, and `tokenizer.json`). +The model card flags fp16 as an unsupported activation dtype for this +graph; pass `model_fp16.onnx` only if you've validated it for your +workload. + +## Cargo features + +| Feature | Default | Effect | +|--------------|:-------:|-----------------------------------------------------------------------------------------------------------| +| `inference` | ✅ | Pulls `ort` + `tokenizers`; activates [`TextEncoder`]. Native targets only. | +| `serde` | | `Serialize` / `Deserialize` on `Options`, `BatchOptions`, `ThreadOptions`. | +| `cuda` | | NVIDIA GPUs (Linux/Windows). Requires CUDA toolkit + cuDNN at build and run time. | +| `tensorrt` | | NVIDIA, optimized inference. Falls back to CUDA, then CPU. Requires CUDA + TensorRT. | +| `directml` | | Windows GPUs (any vendor) via DirectX 12. | +| `rocm` | | AMD GPUs (Linux). Requires ROCm SDK. | +| `coreml` | | macOS / iOS via Core ML (Neural Engine + GPU + Metal Performance Shaders). | + +The execution-provider features are off by default — none are needed +for CPU inference, and each requires the corresponding vendor SDK at +build time. 
+ +[`TextEncoder`]: https://docs.rs/egemma/latest/egemma/struct.TextEncoder.html + +## Target / feature contract + +The `inference` feature is **native-only**. It pulls `ort` (ONNX +Runtime FFI) and `tokenizers` (which transitively depends on C-only +libraries like `onig_sys`); neither builds on `wasm32-*` today. +Building wasm with default features fails deep in `getrandom` / +`onig_sys` before this crate's code is reached. + +**Wasm consumers must opt out:** + +```bash +cargo check --target wasm32-unknown-unknown --no-default-features +``` + +Without `inference`, the public surface is the `Embedding` type, +`Options` / `BatchOptions` / `ThreadOptions`, and the `Error` enum +— useful when inference itself happens elsewhere (a server, a +different runtime) and only the value types and similarity primitive +need to be present. + +## API surface + +The crate exposes: + +- `TextEncoder` — owns one `ort::Session` and one + `tokenizers::Tokenizer`. `embed`, `embed_batch`, `warmup`. + `Send + !Sync` (mirrors `ort::Session`); for parallelism, instantiate + one encoder per thread, or share one behind a `Mutex`. +- `Embedding(Arc<[f32]>)` — 768-dim L2-normalized sentence embedding. + `try_cosine` returns `Result` (no panic on dim mismatch). +- `Options` / `BatchOptions` / `ThreadOptions` — session, batch, and + threading configuration. `with_*` / `set_*` builders are `const fn` + where the underlying types permit. +- `Error` (`#[non_exhaustive]`, `thiserror`-derived). + +`Embedding` deliberately does **not** implement `Serialize` / +`Deserialize` — see its docstring for the validated round-trip pattern +through the inner slice. + +## SIMD + +`Embedding::try_cosine` dispatches the 768-element f32 dot product +through a runtime-detected backend: -#### License +- **NEON** on aarch64 (baseline ISA feature, always available). +- **AVX2 + FMA** on x86_64 when both are detected. +- **Scalar** four-accumulator fallback elsewhere. 
-`template-rs` is under the terms of both the MIT license and the -Apache License (Version 2.0). +The unsafe per-arch kernels take `&[f32; 768]` rather than `&[f32]` — +the type-level length invariant is what makes the raw-pointer reads +sound, and a wrong-length slice can never reach the unsafe boundary. +The dispatcher short-circuits to scalar under `cfg!(miri)` so Miri +matrices exercise the same call sites without entering platform +intrinsics it can't model. -See [LICENSE-APACHE](LICENSE-APACHE), [LICENSE-MIT](LICENSE-MIT) for details. +## License -Copyright (c) 2021 Al Liu. +Dual-licensed under MIT or Apache-2.0, at your option. -[Github-url]: https://github.com/al8n/template-rs/ -[CI-url]: https://github.com/al8n/template-rs/actions/workflows/ci.yml -[doc-url]: https://docs.rs/template-rs -[crates-url]: https://crates.io/crates/template-rs -[codecov-url]: https://app.codecov.io/gh/al8n/template-rs/ -[zh-cn-url]: https://github.com/al8n/template-rs/tree/main/README-zh_CN.md +See [LICENSE-MIT](LICENSE-MIT) and [LICENSE-APACHE](LICENSE-APACHE). diff --git a/src/simd/x86.rs b/src/simd/x86.rs index f4a9ee7..75fe2d9 100644 --- a/src/simd/x86.rs +++ b/src/simd/x86.rs @@ -79,27 +79,107 @@ pub(crate) unsafe fn dot_768_avx2_fma(a: &[f32; 768], b: &[f32; 768]) -> f32 { mod tests { use super::*; - #[test] - fn agrees_with_scalar_within_tolerance() { - if !std::arch::is_x86_feature_detected!("avx2") || !std::arch::is_x86_feature_detected!("fma") { - eprintln!("skipping: AVX2/FMA not available on this host"); - return; - } - let a: Box<[f32; 768]> = (0..768) - .map(|i| ((i as f32) * 0.013).sin()) - .collect::>() - .into_boxed_slice() - .try_into() - .unwrap(); - let b: Box<[f32; 768]> = (0..768) - .map(|i| ((i as f32) * 0.017).cos()) + /// AVX2+FMA must be available on any x86_64 host that runs these + /// tests. 
If a contributor's runner lacks it, fail loudly rather + /// than silently passing — the unsafe SIMD code is what these tests + /// are *for*, and a skip-on-missing-feature was masking the lack of + /// coverage on non-AVX hosts (Codex finding [medium]: AVX2 backend + /// test could silently skip). x86_64 baselines in any CI runner + /// since ~2013 (Haswell+) have both; if this panics on a real + /// machine, the host is too old to test on. + fn require_avx2_fma() { + assert!( + std::arch::is_x86_feature_detected!("avx2") && std::arch::is_x86_feature_detected!("fma"), + "x86_64 SIMD tests require AVX2+FMA on the runner; this host has neither — skipping is \ + not safe because the unsafe `dot_768_avx2_fma` kernel would otherwise have zero \ + direct test coverage. Run on a Haswell-or-later x86_64 host.", + ); + } + + fn boxed_array(f: impl Fn(usize) -> f32) -> Box<[f32; 768]> { + (0..768) + .map(f) .collect::>() .into_boxed_slice() .try_into() - .unwrap(); + .expect("768 elements") + } + + /// Compares the AVX2+FMA kernel against the scalar reference on + /// the trigonometric fixture used by the cross-backend agreement + /// test in `simd::tests`. + #[test] + fn agrees_with_scalar_within_tolerance() { + require_avx2_fma(); + let a = boxed_array(|i| ((i as f32) * 0.013).sin()); + let b = boxed_array(|i| ((i as f32) * 0.017).cos()); + let s = crate::simd::scalar::dot_768(&a, &b); + // SAFETY: AVX2+FMA asserted above; type-encoded length. + let v = unsafe { dot_768_avx2_fma(&a, &b) }; + assert!((s - v).abs() < 1e-3, "avx2+fma ({v}) vs scalar ({s})",); + } + + /// Orthogonal axis vectors → exact 0 dot product. Pin: SIMD + /// summation order can't introduce drift on inputs that are + /// identically zero outside one slot, so this checks the kernel's + /// "no spurious accumulation" property bit-exactly (no tolerance). 
+ #[test] + fn orthogonal_axes_dot_to_exact_zero() { + require_avx2_fma(); + let mut a = Box::new([0.0f32; 768]); + let mut b = Box::new([0.0f32; 768]); + a[0] = 1.0; + b[1] = 1.0; + // SAFETY: AVX2+FMA asserted; type-encoded length. + let v = unsafe { dot_768_avx2_fma(&a, &b) }; + assert_eq!(v, 0.0, "orthogonal e0·e1 must be exactly 0; got {v}"); + } + + /// Self-dot of a unit-norm vector → exactly 1.0. Same bit-exact + /// reasoning as `orthogonal_axes_dot_to_exact_zero`: only one slot + /// contributes, no FP error from summation ordering. + #[test] + fn unit_vector_self_dot_is_one() { + require_avx2_fma(); + let mut a = Box::new([0.0f32; 768]); + a[123] = 1.0; + // SAFETY: AVX2+FMA asserted; type-encoded length. + let v = unsafe { dot_768_avx2_fma(&a, &a) }; + assert_eq!(v, 1.0, "unit-vector self-dot must be exactly 1.0; got {v}"); + } + + /// Constant-vector dot product: 768 × c × d. Catches FMA + /// accumulation bugs across the four chains (any chain + /// missing or double-counted lanes shows up here). + #[test] + fn constant_vectors_match_known_sum() { + require_avx2_fma(); + let a = Box::new([0.5f32; 768]); + let b = Box::new([0.25f32; 768]); + // 768 * 0.5 * 0.25 = 96.0 + // SAFETY: AVX2+FMA asserted; type-encoded length. + let v = unsafe { dot_768_avx2_fma(&a, &b) }; + assert!( + (v - 96.0).abs() < 1e-4, + "expected 96.0 from 768·0.5·0.25; got {v}", + ); + } + + /// Alternating-sign fixture: catches a class of reduction bugs + /// where a chain accidentally swaps subtract/add or where signs + /// drop during horizontal reduce. Also widens the tolerance check + /// past the all-positive trigonometric case. + #[test] + fn alternating_sign_agrees_with_scalar() { + require_avx2_fma(); + let a = boxed_array(|i| if i % 2 == 0 { 1.0 } else { -1.0 }); + let b = boxed_array(|i| if i % 3 == 0 { 1.0 } else { -1.0 }); let s = crate::simd::scalar::dot_768(&a, &b); - // SAFETY: AVX2+FMA detected above; type-encoded length. 
+ // SAFETY: AVX2+FMA asserted; type-encoded length. let v = unsafe { dot_768_avx2_fma(&a, &b) }; - assert!((s - v).abs() < 1e-3, "avx2+fma ({v}) vs scalar ({s})"); + assert!( + (s - v).abs() < 1e-3, + "alternating-sign avx2 ({v}) vs scalar ({s})", + ); } } From 268fdbdfc014eecc4097eb6740cdf06754b7de36 Mon Sep 17 00:00:00 2001 From: uqio <276879906+uqio@users.noreply.github.com> Date: Sat, 2 May 2026 15:57:58 +1200 Subject: [PATCH 04/11] cleanup --- src/simd/x86.rs | 64 +++++++++++++++++++++++++++++++++---------------- src/text_enc.rs | 51 ++++++++++++++++++++++++++++++--------- 2 files changed, 84 insertions(+), 31 deletions(-) diff --git a/src/simd/x86.rs b/src/simd/x86.rs index 75fe2d9..cc7b5a6 100644 --- a/src/simd/x86.rs +++ b/src/simd/x86.rs @@ -79,21 +79,35 @@ pub(crate) unsafe fn dot_768_avx2_fma(a: &[f32; 768], b: &[f32; 768]) -> f32 { mod tests { use super::*; - /// AVX2+FMA must be available on any x86_64 host that runs these - /// tests. If a contributor's runner lacks it, fail loudly rather - /// than silently passing — the unsafe SIMD code is what these tests - /// are *for*, and a skip-on-missing-feature was masking the lack of - /// coverage on non-AVX hosts (Codex finding [medium]: AVX2 backend - /// test could silently skip). x86_64 baselines in any CI runner - /// since ~2013 (Haswell+) have both; if this panics on a real - /// machine, the host is too old to test on. - fn require_avx2_fma() { - assert!( - std::arch::is_x86_feature_detected!("avx2") && std::arch::is_x86_feature_detected!("fma"), - "x86_64 SIMD tests require AVX2+FMA on the runner; this host has neither — skipping is \ - not safe because the unsafe `dot_768_avx2_fma` kernel would otherwise have zero \ - direct test coverage. Run on a Haswell-or-later x86_64 host.", - ); + /// Detect AVX2+FMA on the current host. Returns `true` if both + /// features are present, otherwise prints a `[SIMD-SKIP]` banner + /// and returns `false` so the caller can early-out with `Ok(())`. 
+ /// + /// **Why skip rather than panic.** The production dispatcher + /// supports non-AVX2 x86_64 hosts via the scalar fallback (see + /// `simd::dot_768_dispatch` for x86_64). Panicking here would + /// break `cargo test` on a configuration the library officially + /// handles — a real CI/contributor problem on virtualized envs + /// or older hardware. + /// + /// **What still covers the kernel.** GitHub Actions Linux x86_64 + /// runners (Skylake+) have AVX2+FMA, so the `test` job exercises + /// these tests for real. The scalar fallback is independently + /// covered by `simd::scalar::tests` regardless of host. Codex's + /// previous-round critique of "silent skip masks lack of + /// coverage" is mitigated by the `[SIMD-SKIP]` banner — CI logs + /// are searchable to verify which runs hit the kernel. + fn avx2_fma_available() -> bool { + let ok = + std::arch::is_x86_feature_detected!("avx2") && std::arch::is_x86_feature_detected!("fma"); + if !ok { + eprintln!( + "[SIMD-SKIP] AVX2/FMA unavailable on this x86_64 host — direct kernel tests skipped. \ + The dispatcher's scalar fallback handles this configuration; CI Linux x86_64 runners \ + exercise the AVX2 kernel separately." + ); + } + ok } fn boxed_array(f: impl Fn(usize) -> f32) -> Box<[f32; 768]> { @@ -110,7 +124,9 @@ mod tests { /// test in `simd::tests`. #[test] fn agrees_with_scalar_within_tolerance() { - require_avx2_fma(); + if !avx2_fma_available() { + return; + } let a = boxed_array(|i| ((i as f32) * 0.013).sin()); let b = boxed_array(|i| ((i as f32) * 0.017).cos()); let s = crate::simd::scalar::dot_768(&a, &b); @@ -125,7 +141,9 @@ mod tests { /// "no spurious accumulation" property bit-exactly (no tolerance). #[test] fn orthogonal_axes_dot_to_exact_zero() { - require_avx2_fma(); + if !avx2_fma_available() { + return; + } let mut a = Box::new([0.0f32; 768]); let mut b = Box::new([0.0f32; 768]); a[0] = 1.0; @@ -140,7 +158,9 @@ mod tests { /// contributes, no FP error from summation ordering. 
#[test] fn unit_vector_self_dot_is_one() { - require_avx2_fma(); + if !avx2_fma_available() { + return; + } let mut a = Box::new([0.0f32; 768]); a[123] = 1.0; // SAFETY: AVX2+FMA asserted; type-encoded length. @@ -153,7 +173,9 @@ mod tests { /// missing or double-counted lanes shows up here). #[test] fn constant_vectors_match_known_sum() { - require_avx2_fma(); + if !avx2_fma_available() { + return; + } let a = Box::new([0.5f32; 768]); let b = Box::new([0.25f32; 768]); // 768 * 0.5 * 0.25 = 96.0 @@ -171,7 +193,9 @@ mod tests { /// past the all-positive trigonometric case. #[test] fn alternating_sign_agrees_with_scalar() { - require_avx2_fma(); + if !avx2_fma_available() { + return; + } let a = boxed_array(|i| if i % 2 == 0 { 1.0 } else { -1.0 }); let b = boxed_array(|i| if i % 3 == 0 { 1.0 } else { -1.0 }); let s = crate::simd::scalar::dot_768(&a, &b); diff --git a/src/text_enc.rs b/src/text_enc.rs index 32a5e67..c9b2627 100644 --- a/src/text_enc.rs +++ b/src/text_enc.rs @@ -277,11 +277,23 @@ fn validate_text_session(session: &ort::session::Session) -> Result<()> { /// Verify an `Outlet` exists with the expected dtype and shape. /// -/// `expected_shape` semantics: a value of `-1` is a wildcard (matches any -/// dim including the graph's own `-1` dynamic marker). Any other value -/// must match exactly. The graph's declared shape may itself contain `-1` -/// for dynamic axes; in that case we still accept it (the runtime will -/// catch shape mismatches at inference time). +/// `expected_shape` semantics — match `siglip2::check_outlet`: +/// +/// - `-1` in `expected_shape` means **the graph MUST declare this axis +/// dynamic**. A static dim there is rejected. 
This is what we want +/// for `input_ids` / `attention_mask`: `embed_chunk` sends batches +/// of `[group.len(), BatchLongest seq_len]` where neither dim is +/// known at session-build time, so a graph baking in `[1, 2048]` +/// or `[8, 512]` would fail at first `Session::run` — surface that +/// at construction time instead. +/// - any other value in `expected_shape` is an **exact match** +/// requirement. The graph may either match exactly or declare the +/// axis dynamic (`-1`); both work at runtime. +/// +/// The previous wildcard semantics (where `-1` meant "any dim +/// acceptable") let static-shape exports load successfully and only +/// failed at first inference call — Codex finding [medium]: +/// `check_outlet` accepted incompatible static shapes. fn check_outlet( outlets: &[ort::value::Outlet], name: &'static str, @@ -318,12 +330,29 @@ fn check_outlet( } for (i, &want) in expected_shape.iter().enumerate() { let act = actual[i]; - if want != -1 && act != -1 && act != want { - return Err(Error::SessionShapeMismatch { - input: name, - expected: "matching static dim", - got: actual.to_vec(), - }); + if want == -1 { + // We require this axis to be dynamic. A graph baking in + // a concrete dim here would load successfully under the + // old wildcard semantics and only fail at `Session::run` + // when `embed_chunk` sends a different size. + if act != -1 { + return Err(Error::SessionShapeMismatch { + input: name, + expected: "dynamic axis (graph declares -1; static-shape \ + exports incompatible with chunked APIs)", + got: actual.to_vec(), + }); + } + } else { + // Concrete dim required. Graph may match exactly or declare + // the axis dynamic — both work at runtime. 
+ if act != -1 && act != want { + return Err(Error::SessionShapeMismatch { + input: name, + expected: "matching static dim", + got: actual.to_vec(), + }); + } } } Ok(()) From 342b564813af2351d8493b44044387a92302e905 Mon Sep 17 00:00:00 2001 From: uqio <276879906+uqio@users.noreply.github.com> Date: Sat, 2 May 2026 16:13:45 +1200 Subject: [PATCH 05/11] cleanup --- src/simd/x86.rs | 26 ++++++++------------------ src/text_enc.rs | 40 +++++++++++++++------------------------- 2 files changed, 23 insertions(+), 43 deletions(-) diff --git a/src/simd/x86.rs b/src/simd/x86.rs index cc7b5a6..822268d 100644 --- a/src/simd/x86.rs +++ b/src/simd/x86.rs @@ -79,24 +79,14 @@ pub(crate) unsafe fn dot_768_avx2_fma(a: &[f32; 768], b: &[f32; 768]) -> f32 { mod tests { use super::*; - /// Detect AVX2+FMA on the current host. Returns `true` if both - /// features are present, otherwise prints a `[SIMD-SKIP]` banner - /// and returns `false` so the caller can early-out with `Ok(())`. - /// - /// **Why skip rather than panic.** The production dispatcher - /// supports non-AVX2 x86_64 hosts via the scalar fallback (see - /// `simd::dot_768_dispatch` for x86_64). Panicking here would - /// break `cargo test` on a configuration the library officially - /// handles — a real CI/contributor problem on virtualized envs - /// or older hardware. - /// - /// **What still covers the kernel.** GitHub Actions Linux x86_64 - /// runners (Skylake+) have AVX2+FMA, so the `test` job exercises - /// these tests for real. The scalar fallback is independently - /// covered by `simd::scalar::tests` regardless of host. Codex's - /// previous-round critique of "silent skip masks lack of - /// coverage" is mitigated by the `[SIMD-SKIP]` banner — CI logs - /// are searchable to verify which runs hit the kernel. + /// Returns `true` if AVX2+FMA are both detected on the current + /// host; otherwise prints a `[SIMD-SKIP]` banner and returns + /// `false`. 
Skip-not-panic: the dispatcher supports non-AVX2 + /// x86_64 via the scalar fallback (see `simd::dot_768_dispatch`), + /// so panicking here would fail `cargo test` on a configuration + /// the library handles. CI Linux x86_64 runners have AVX2+FMA, + /// which is where the kernel coverage actually fires; the banner + /// is grep-able to verify in CI logs. fn avx2_fma_available() -> bool { let ok = std::arch::is_x86_feature_detected!("avx2") && std::arch::is_x86_feature_detected!("fma"); diff --git a/src/text_enc.rs b/src/text_enc.rs index c9b2627..295cb0c 100644 --- a/src/text_enc.rs +++ b/src/text_enc.rs @@ -277,23 +277,16 @@ fn validate_text_session(session: &ort::session::Session) -> Result<()> { /// Verify an `Outlet` exists with the expected dtype and shape. /// -/// `expected_shape` semantics — match `siglip2::check_outlet`: +/// `expected_shape` semantics: /// -/// - `-1` in `expected_shape` means **the graph MUST declare this axis -/// dynamic**. A static dim there is rejected. This is what we want -/// for `input_ids` / `attention_mask`: `embed_chunk` sends batches -/// of `[group.len(), BatchLongest seq_len]` where neither dim is -/// known at session-build time, so a graph baking in `[1, 2048]` -/// or `[8, 512]` would fail at first `Session::run` — surface that -/// at construction time instead. -/// - any other value in `expected_shape` is an **exact match** -/// requirement. The graph may either match exactly or declare the -/// axis dynamic (`-1`); both work at runtime. -/// -/// The previous wildcard semantics (where `-1` meant "any dim -/// acceptable") let static-shape exports load successfully and only -/// failed at first inference call — Codex finding [medium]: -/// `check_outlet` accepted incompatible static shapes. +/// - `-1` means **the graph MUST declare this axis dynamic**. A static +/// dim there is rejected. 
`embed_chunk` sends batches of +/// `[group.len(), BatchLongest seq_len]` where neither dim is known +/// at session-build time, so a graph baking in `[1, 2048]` or +/// `[8, 512]` would fail at first `Session::run`. +/// - any other value is an **exact match** requirement. The graph may +/// either match exactly or declare the axis dynamic (`-1`); both +/// work at runtime. fn check_outlet( outlets: &[ort::value::Outlet], name: &'static str, @@ -418,15 +411,12 @@ mod tests { assert_eq!(EMBED_DIM, 768); } - /// Codex review finding: `embed_batch` documents that failures - /// surface as `Error::Batch { index, source }` carrying the - /// offending zero-based index, but the previous implementation - /// propagated `Embedding::from_model_output` errors unwrapped via - /// `?` — so a degenerate row in the middle of a batch would lose - /// its position. This test fakes a 3-row chunk where the middle row - /// is all zero (→ `NotNormalized`) and asserts the wrapped index is - /// `base_index + 1`, proving the row context is preserved across - /// the boundary. + /// `embed_batch` documents that failures surface as + /// `Error::Batch { index, source }` carrying the offending zero-based + /// index. Pin: a degenerate row (here, all-zero → `NotNormalized`) + /// in the middle of a batch must have its position preserved across + /// the `embeddings_from_chunk` boundary instead of bubbling up as a + /// bare `NotNormalized`. #[test] fn embeddings_from_chunk_wraps_row_error_with_index() { // 3 rows × 768. 
Rows 0 and 2 are unit vectors (normalize fine); From f23655f6b3efeb01f5dd65eaae6460fdbd9567a2 Mon Sep 17 00:00:00 2001 From: uqio <276879906+uqio@users.noreply.github.com> Date: Sat, 2 May 2026 16:20:45 +1200 Subject: [PATCH 06/11] cleanup --- .github/workflows/ci.yml | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4b7e55e..d1602a9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -169,9 +169,29 @@ jobs: run: rustup update stable --no-self-update && rustup default stable - name: Install cargo-hack run: cargo install cargo-hack - - name: Run test - # See the clippy job for why the EP features are excluded. + # Linux / macOS: full feature-powerset minus the EP features (see + # the clippy job for that exclusion). + - name: Run test (Linux/macOS) + if: runner.os != 'Windows' run: cargo hack test --feature-powerset --exclude-no-default-features --exclude-features cuda,tensorrt,directml,rocm,coreml + # Windows: also exclude `inference` and `default`. ORT's prebuilt + # binary on `x86_64-pc-windows-msvc` is linked against the dynamic + # CRT (`/MD`), but the C/C++ build scripts in `tokenizers` deps + # (`onig_sys`, `esaxx-rs`) link the static CRT (`/MT`). When both + # end up in the same test executable — and `cargo test` compiles + # examples like `embed_text.rs` that link ort + tokenizers — MSVC + # emits `LNK2038` ("RuntimeLibrary mismatch") and the link aborts. + # `default` is excluded because it implies `inference`. + # + # This is an upstream dependency mismatch, not something this crate + # can resolve in its own code. The `build` and `clippy` Windows + # jobs above still compile the crate end-to-end (no final link), + # so lib-level regressions a Windows test binary would catch are + # still caught. The `inference` path is exercised on the Linux and + # macOS runners. 
+ - name: Run test (Windows — skip inference due to MSVC CRT conflict) + if: runner.os == 'Windows' + run: cargo hack test --feature-powerset --exclude-no-default-features --exclude-features cuda,tensorrt,directml,rocm,coreml,inference,default sanitizer: name: sanitizer From 860d2b855f623ef2b263cab637bdb3ba5b55f929 Mon Sep 17 00:00:00 2001 From: uqio <276879906+uqio@users.noreply.github.com> Date: Sat, 2 May 2026 16:34:21 +1200 Subject: [PATCH 07/11] cleanup --- .github/workflows/ci.yml | 10 +++++++--- Cargo.toml | 1 - src/error.rs | 27 +++++++++++++++++++++++++- src/options.rs | 13 ++++++++++--- src/simd/mod.rs | 6 +++--- src/text_enc.rs | 41 ++++++++++++++++++++++++++++++---------- 6 files changed, 77 insertions(+), 21 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d1602a9..0c7081d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -170,10 +170,14 @@ jobs: - name: Install cargo-hack run: cargo install cargo-hack # Linux / macOS: full feature-powerset minus the EP features (see - # the clippy job for that exclusion). + # the clippy job for that exclusion). `--exclude-no-default-features` + # is intentionally NOT set so the powerset also runs the + # `--no-default-features` and `--no-default-features --features serde` + # combos — those exercise the cfg-gated `serde` derive paths in + # `src/options.rs` that the default-features build can't reach. - name: Run test (Linux/macOS) if: runner.os != 'Windows' - run: cargo hack test --feature-powerset --exclude-no-default-features --exclude-features cuda,tensorrt,directml,rocm,coreml + run: cargo hack test --feature-powerset --exclude-features cuda,tensorrt,directml,rocm,coreml # Windows: also exclude `inference` and `default`. ORT's prebuilt # binary on `x86_64-pc-windows-msvc` is linked against the dynamic # CRT (`/MD`), but the C/C++ build scripts in `tokenizers` deps @@ -191,7 +195,7 @@ jobs: # macOS runners. 
- name: Run test (Windows — skip inference due to MSVC CRT conflict) if: runner.os == 'Windows' - run: cargo hack test --feature-powerset --exclude-no-default-features --exclude-features cuda,tensorrt,directml,rocm,coreml,inference,default + run: cargo hack test --feature-powerset --exclude-features cuda,tensorrt,directml,rocm,coreml,inference,default sanitizer: name: sanitizer diff --git a/Cargo.toml b/Cargo.toml index 2f72287..017a91d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,7 +37,6 @@ serde = { version = "1", features = ["derive"], optional = true } [dev-dependencies] serde_json = "1" -tempfile = "3" # ===================================================================== # Target / feature contract. diff --git a/src/error.rs b/src/error.rs index 7f26a5d..12aa22a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -18,7 +18,11 @@ pub enum Error { #[error("required ONNX output `{name}` was missing from session run")] MissingOnnxOutput { name: &'static str }, - #[error("tokenizer load failed: {0}")] + /// Tokenizer load OR runtime use failure. Covers `Tokenizer::from_file` + /// errors at construction, ``-token contract violations during + /// configuration, `encode_batch` failures during inference, and any + /// uneven-row anomalies surfaced from the tokenizers crate. + #[error("tokenizer error: {0}")] Tokenizer(String), #[error("unexpected output rank: expected 2, got {rank} with shape {shape:?}")] @@ -31,6 +35,19 @@ pub enum Error { got: Vec, }, + /// Session contract violation that isn't a shape mismatch — wrong + /// element type, missing outlet, or non-tensor outlet. Carries the + /// actual `TensorElementType` so users debugging a bad re-export + /// see the dtype, not a shape vector that doesn't apply. Gated on + /// `feature = "inference"` because the `got` field is an `ort` type. 
+ #[cfg(feature = "inference")] + #[error("session contract mismatch on `{input}`: expected {expected}, got {got:?}")] + SessionContractMismatch { + input: &'static str, + expected: &'static str, + got: ort::value::TensorElementType, + }, + #[error("embedding dimension mismatch: expected {expected}, got {got}")] EmbeddingDim { expected: usize, got: usize }, @@ -51,6 +68,14 @@ pub enum Error { max_batch_size: usize, }, + /// `BatchOptions::max_seq_len` was zero at encoder construction. + /// Tokenizer truncation requires `max_length > 0`; a zero-length + /// budget is meaningless. Caught alongside `InvalidBatchSize` so + /// shape-of-options errors stay together rather than leaking out + /// as opaque tokenizer-config errors. + #[error("invalid max_seq_len 0: must be > 0")] + InvalidMaxSeqLen, + #[error("batch index {index}: {source}")] Batch { index: usize, source: Box }, diff --git a/src/options.rs b/src/options.rs index b8d9a01..5154389 100644 --- a/src/options.rs +++ b/src/options.rs @@ -181,9 +181,13 @@ impl BatchOptions { self } - /// Reject `batch_size == 0` (the silent `.max(1)` coercion footgun) and - /// `batch_size > max_batch_size` (a config error that wastes scratch - /// memory and never produces a chunk that large in practice). 
+ /// Reject: + /// - `batch_size == 0` (the silent `.max(1)` coercion footgun) + /// - `batch_size > max_batch_size` (config error: wastes scratch + /// memory and never produces a chunk that large in practice) + /// - `max_seq_len == 0` (tokenizer truncation requires `max_length > 0`; + /// a zero-length budget is meaningless and would otherwise leak out + /// as an opaque tokenizer-config error from `configure_tokenizer`) #[cfg_attr(not(any(feature = "inference", test)), allow(dead_code))] pub(crate) fn validate(&self) -> Result<(), crate::Error> { if self.batch_size == 0 || self.batch_size > self.max_batch_size { @@ -192,6 +196,9 @@ impl BatchOptions { max_batch_size: self.max_batch_size, }); } + if self.max_seq_len == 0 { + return Err(crate::Error::InvalidMaxSeqLen); + } Ok(()) } } diff --git a/src/simd/mod.rs b/src/simd/mod.rs index 32d89d8..f6eea18 100644 --- a/src/simd/mod.rs +++ b/src/simd/mod.rs @@ -1,6 +1,6 @@ //! Crate-internal SIMD primitives. Only one operation is hot enough to //! be worth hand-vectorizing: the 768-element f32 dot product -//! ([`Embedding::cosine`], `||v||²` during normalization). Pointwise +//! ([`Embedding::try_cosine`], `||v||²` during normalization). Pointwise //! scales and integer widenings auto-vectorize under `-O3`, so they //! stay in scalar form. //! @@ -27,7 +27,7 @@ //! scalar under `cfg!(miri)`. Miri cannot evaluate target-specific //! LLVM intrinsics (`vfmaq_f32`, `_mm256_fmadd_ps`, …) and would //! abort with "unsupported operation: can't call foreign function" -//! the moment a normal test went through `Embedding::cosine`. +//! the moment a normal test went through `Embedding::try_cosine`. //! Routing through scalar lets the Miri matrix exercise the same //! call sites as native CI — and validate the *unsafe-free* path — //! without ever entering the SIMD backends. 
The per-arch backend @@ -56,7 +56,7 @@ fn dot_768_dispatch(a: &[f32; 768], b: &[f32; 768]) -> f32 { // Miri can't evaluate `vfmaq_f32` / `vld1q_f32` and would abort with // "unsupported operation: can't call foreign function" — see the // module-level docstring. Route through scalar so Miri-driven jobs - // still exercise `Embedding::cosine` and the surrounding logic. + // still exercise `Embedding::try_cosine` and the surrounding logic. if cfg!(miri) { return scalar::dot_768(a, b); } diff --git a/src/text_enc.rs b/src/text_enc.rs index 295cb0c..db334c0 100644 --- a/src/text_enc.rs +++ b/src/text_enc.rs @@ -43,23 +43,28 @@ impl TextEncoder { pub fn from_files_with_options(graph: &Path, tokenizer: &Path, opts: Options) -> Result { let session = crate::session::build_session(graph, opts)?; let tokenizer = Tokenizer::from_file(tokenizer).map_err(|e| Error::Tokenizer(e.to_string()))?; - let tokenizer = configure_tokenizer(tokenizer, opts.batch().max_seq_len())?; + // `configure_tokenizer` runs inside `from_ort_session_with_options`, + // so we don't apply it here. Self::from_ort_session_with_options(session, tokenizer, opts) } pub fn from_ort_session(session: ort::session::Session, tokenizer: Tokenizer) -> Result { - let opts = Options::default(); - let tokenizer = configure_tokenizer(tokenizer, opts.batch().max_seq_len())?; - Self::from_ort_session_with_options(session, tokenizer, opts) + Self::from_ort_session_with_options(session, tokenizer, Options::default()) } - fn from_ort_session_with_options( + /// Construct from a caller-built `ort::Session` and `Tokenizer` with + /// custom [`Options`]. Public so wasm32 callers (who can't use + /// [`Self::from_files_with_options`] because `ort 2.0.0-rc.12` + /// cfg-gates `commit_from_file` out of wasm builds) can still tune + /// `max_seq_len`, `batch_size`, and `max_batch_size`. 
+ pub fn from_ort_session_with_options( session: ort::session::Session, tokenizer: Tokenizer, opts: Options, ) -> Result { validate_text_session(&session)?; opts.batch().validate()?; + let tokenizer = configure_tokenizer(tokenizer, opts.batch().max_seq_len())?; Ok(Self { session, tokenizer, @@ -81,9 +86,22 @@ impl TextEncoder { /// and runs one ORT inference per chunk; the returned `Vec` preserves /// input order and has the same length as `texts` on success. /// - /// **Failure semantics.** Aborts on the first failing input and returns - /// `Error::Batch { index, source }` carrying the offending zero-based - /// index. Already-computed embeddings from earlier chunks are dropped. + /// **Failure semantics.** Aborts on the first failing chunk and returns + /// `Error::Batch { index, source }`. The wrapped `index` granularity + /// depends on where the failure originated: + /// + /// - **Row-precise** (`index = base + offending_row`) for failures + /// that pin to a specific input: empty-text guard, per-row + /// tokenizer-output length mismatch, and per-row embedding + /// normalization failures (`Error::NotNormalized` from + /// `from_model_output`). + /// - **Chunk-level** (`index = base`, the chunk's first input + /// position) for failures that don't pin to a single row: + /// `tokenizer.encode_batch` failures, ORT tensor-build / `run` / + /// output-extract errors, output-rank or output-shape mismatches. + /// Inspect `source` to disambiguate. + /// + /// Already-computed embeddings from earlier chunks are dropped. pub fn embed_batch(&mut self, texts: &[&str]) -> Result> { if texts.is_empty() { return Ok(Vec::new()); @@ -307,10 +325,13 @@ fn check_outlet( match outlet.dtype() { ValueType::Tensor { ty, shape, .. 
} => { if *ty != expected_dtype { - return Err(Error::SessionShapeMismatch { + // Use `SessionContractMismatch` so the actual dtype shows up + // in the message — `SessionShapeMismatch.got: Vec` would + // either be the shape (irrelevant for a dtype error) or empty. + return Err(Error::SessionContractMismatch { input: name, expected: "matching tensor dtype", - got: shape.to_vec(), + got: *ty, }); } let actual: &[i64] = shape; From a025d23b9fa23b8d78ba9aec35c13fd049c81c38 Mon Sep 17 00:00:00 2001 From: uqio <276879906+uqio@users.noreply.github.com> Date: Sat, 2 May 2026 16:45:08 +1200 Subject: [PATCH 08/11] cleanup --- .codecov.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.codecov.yml b/.codecov.yml index bfe19d3..81d9826 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -2,9 +2,9 @@ codecov: require_ci_to_pass: false ignore: - - **benches/* - - **examples/* - - **tests/* + - benches/* + - examples/* + - tests/* coverage: status: From 10859640d36eaf139692d13545ffcdbf768cd323 Mon Sep 17 00:00:00 2001 From: uqio <276879906+uqio@users.noreply.github.com> Date: Sat, 2 May 2026 18:08:21 +1200 Subject: [PATCH 09/11] cleanup --- Cargo.toml | 5 ++++- README.md | 28 +++++++++++++++++++++++----- src/lib.rs | 34 ++++++---------------------------- 3 files changed, 33 insertions(+), 34 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 017a91d..591b40a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -118,7 +118,10 @@ rpath = false features = ["inference", "serde"] rustdoc-args = ["--cfg", "docsrs"] -[lints.rust] +[lints] +workspace = true + +[workspace.lints.rust] rust_2018_idioms = "warn" single_use_lifetimes = "warn" unexpected_cfgs = { level = "warn", check-cfg = ['cfg(all_tests)', 'cfg(tarpaulin)', 'cfg(docsrs)'] } diff --git a/README.md b/README.md index 181ac23..382284c 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,27 @@ -# egemma +
+

egemma

+
+
+
+
+Rust ONNX inference library for Google's [EmbeddingGemma] (`google/embeddinggemma-300m`) text embeddings. Produces 768-dim
+L2-normalized sentence embeddings via [`ort`] and [`tokenizers`].
+
+[github][Github-url]
+LoC
+[Build][CI-url]
+[codecov][codecov-url]
+
+[docs.rs][doc-url]
+[crates.io][crates-url]
+[crates.io][crates-url]
+license
+
+
Rust ONNX inference library for Google's [EmbeddingGemma] (`google/embeddinggemma-300m`) text embeddings. Produces 768-dim L2-normalized sentence embeddings via [`ort`] and [`tokenizers`]. -[EmbeddingGemma]: https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX -[`ort`]: https://crates.io/crates/ort -[`tokenizers`]: https://crates.io/crates/tokenizers - ## Install ```toml @@ -122,3 +136,7 @@ intrinsics it can't model. Dual-licensed under MIT or Apache-2.0, at your option. See [LICENSE-MIT](LICENSE-MIT) and [LICENSE-APACHE](LICENSE-APACHE). + +[EmbeddingGemma]: https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX +[`ort`]: https://crates.io/crates/ort +[`tokenizers`]: https://crates.io/crates/tokenizers diff --git a/src/lib.rs b/src/lib.rs index 800b2dd..e4f06e3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,46 +1,24 @@ -//! EmbeddingGemma inference library — produces 768-dim L2-normalized -//! sentence embeddings from Google's `embedding-gemma` ONNX export. -//! -//! Mirrors the `siglip2` text-tower API surface: a [`TextEncoder`] with -//! `from_files` / `from_files_with_options` / `from_ort_session` -//! constructors, plus `embed`, `embed_batch`, and `warmup`. -//! -//! # Target / feature contract -//! -//! The `inference` feature is **on by default** and is **native-only**. -//! It pulls `ort` (ONNX Runtime FFI) and `tokenizers` (which transitively -//! depends on C-only libraries like `onig_sys`); neither builds on -//! `wasm32-*` today. Building wasm with default features therefore fails -//! deep in `getrandom` / `onig_sys` before this crate's code is reached. -//! -//! **Wasm consumers must opt out:** -//! -//! ```bash -//! cargo check --target wasm32-unknown-unknown --no-default-features -//! ``` -//! -//! Without `inference`, the public surface is the [`Embedding`] type, -//! [`Options`] / [`BatchOptions`] / [`ThreadOptions`], and the -//! [`Error`] enum — useful when inference itself happens elsewhere -//! 
(a server, a different runtime) and only the value types and -//! similarity primitive need to be present. - +#![doc = include_str!("../README.md")] #![cfg_attr(docsrs, feature(doc_cfg))] -#![deny(rust_2018_idioms, single_use_lifetimes)] +#![deny(rust_2018_idioms, single_use_lifetimes, missing_docs)] pub mod embedding; pub mod error; pub mod options; + #[cfg(feature = "inference")] pub(crate) mod session; pub(crate) mod simd; #[cfg(feature = "inference")] +#[cfg_attr(docsrs, doc(cfg(feature = "inference")))] pub mod text_enc; pub use embedding::Embedding; pub use error::{Error, Result}; #[cfg(feature = "inference")] +#[cfg_attr(docsrs, doc(cfg(feature = "inference")))] pub use options::GraphOptimizationLevel; pub use options::{BatchOptions, Options, ThreadOptions}; #[cfg(feature = "inference")] +#[cfg_attr(docsrs, doc(cfg(feature = "inference")))] pub use text_enc::TextEncoder; From 0297513adf759dbe096b001572ea6d1b224d38e3 Mon Sep 17 00:00:00 2001 From: uqio <276879906+uqio@users.noreply.github.com> Date: Sat, 2 May 2026 18:30:39 +1200 Subject: [PATCH 10/11] cleanup --- README.md | 5 +++ src/embedding.rs | 6 ++++ src/error.rs | 86 ++++++++++++++++++++++++++++++++++++++++++++---- src/options.rs | 56 +++++++++++++++++++++++++++++++ src/text_enc.rs | 15 +++++++++ 5 files changed, 161 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 382284c..f92ecfd 100644 --- a/README.md +++ b/README.md @@ -140,3 +140,8 @@ See [LICENSE-MIT](LICENSE-MIT) and [LICENSE-APACHE](LICENSE-APACHE). 
[EmbeddingGemma]: https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX [`ort`]: https://crates.io/crates/ort [`tokenizers`]: https://crates.io/crates/tokenizers +[Github-url]: https://github.com/Findit-AI/egemma +[CI-url]: https://github.com/Findit-AI/egemma/actions/workflows/ci.yml +[doc-url]: https://docs.rs/egemma +[crates-url]: https://crates.io/crates/egemma +[codecov-url]: https://app.codecov.io/gh/Findit-AI/egemma diff --git a/src/embedding.rs b/src/embedding.rs index 3131696..e803e52 100644 --- a/src/embedding.rs +++ b/src/embedding.rs @@ -29,10 +29,16 @@ impl Embedding { /// L2-norm tolerance for the unit-norm invariant. pub const NORM_EPSILON: f32 = 5e-4; + /// Number of `f32` lanes in the embedding. Always [`Self::EMBED_DIM`] + /// (768) for any `Embedding` produced by this crate's public + /// constructors. pub fn dim(&self) -> usize { self.0.len() } + /// Borrowed view of the underlying `f32` data. Cheap (no copy) and + /// the standard input for downstream similarity / vector-store code + /// that wants a `&[f32]`. pub fn as_slice(&self) -> &[f32] { &self.0 } diff --git a/src/error.rs b/src/error.rs index 12aa22a..d1a6d0f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -4,6 +4,11 @@ use std::path::PathBuf; use thiserror::Error; +/// All errors surfaced from the public API. +/// +/// `#[non_exhaustive]` so that adding variants in a future minor +/// release isn't a breaking change for `match` arms — downstream +/// callers must include a wildcard (`_ => ...`) branch. #[derive(Debug, Error)] #[non_exhaustive] pub enum Error { @@ -11,12 +16,20 @@ pub enum Error { /// because `ort::Error` doesn't exist when the feature is off. #[cfg(feature = "inference")] #[error("failed to load ONNX graph at {path}: {source}")] - LoadGraph { path: PathBuf, source: ort::Error }, + LoadGraph { + /// Path that was passed to `commit_from_file`. + path: PathBuf, + /// Underlying `ort` error from the session-builder pipeline. 
+ source: ort::Error,
+ },

/// Required ONNX output tensor was not present in the session output map.
/// Indicates an unexpected re-export or a corrupted graph.
#[error("required ONNX output `{name}` was missing from session run")]
- MissingOnnxOutput { name: &'static str },
+ MissingOnnxOutput {
+ /// Name of the missing output (e.g. `"sentence_embedding"`).
+ name: &'static str,
+ },

/// Tokenizer load OR runtime use failure. Covers `Tokenizer::from_file`
/// errors at construction, `<bos>`-token contract violations during
@@ -25,13 +38,26 @@ pub enum Error {
#[error("tokenizer error: {0}")]
Tokenizer(String),

+ /// ORT returned a tensor whose rank wasn't 2 (we expect
+ /// `[batch, EMBED_DIM]`).
#[error("unexpected output rank: expected 2, got {rank} with shape {shape:?}")]
- OutputRank { rank: usize, shape: Vec<i64> },
+ OutputRank {
+ /// Number of dimensions in the returned tensor.
+ rank: usize,
+ /// Full shape vector for diagnostics.
+ shape: Vec<i64>,
+ },

+ /// Session-level shape contract violation: a required outlet was
+ /// missing, had the wrong rank, had a static dim where we needed
+ /// a dynamic one, or had a static dim that didn't match expectations.
#[error("session shape mismatch on `{input}`: expected {expected}, got {got:?}")]
SessionShapeMismatch {
+ /// Outlet name that didn't satisfy the contract.
input: &'static str,
+ /// Human-readable expectation message.
expected: &'static str,
+ /// Actual shape from the session metadata.
got: Vec<i64>,
},

@@ -43,28 +69,60 @@ pub enum Error {
#[cfg(feature = "inference")]
#[error("session contract mismatch on `{input}`: expected {expected}, got {got:?}")]
SessionContractMismatch {
+ /// Outlet name that didn't satisfy the contract.
input: &'static str,
+ /// Human-readable expectation message.
expected: &'static str,
+ /// Actual tensor element type from the session metadata. 
got: ort::value::TensorElementType,
},

+ /// `Embedding` constructed from a `Vec<f32>` whose length didn't
+ /// equal [`crate::Embedding::EMBED_DIM`] (768).
#[error("embedding dimension mismatch: expected {expected}, got {got}")]
- EmbeddingDim { expected: usize, got: usize },
+ EmbeddingDim {
+ /// Required dim (always 768 in 0.1.0).
+ expected: usize,
+ /// Caller-supplied dim.
+ got: usize,
+ },

+ /// `Embedding::try_from(Vec<f32>)` rejected an input whose
+ /// `||v||₂` was outside `[1 - ε, 1 + ε]`. The encoder path
+ /// normalizes raw model output unconditionally — this variant
+ /// only fires for caller-supplied vectors that should already be
+ /// unit-norm (e.g. deserialized from a vector store).
#[error("embedding is not unit-norm (got ||v||₂ = {norm}, tolerance ε = {epsilon})")]
- NotNormalized { norm: f32, epsilon: f32 },
+ NotNormalized {
+ /// Computed L2 norm of the input vector.
+ norm: f32,
+ /// Tolerance window the norm had to fall inside.
+ epsilon: f32,
+ },

+ /// An empty string was passed to [`crate::TextEncoder::embed`] or
+ /// appeared inside the slice given to
+ /// [`crate::TextEncoder::embed_batch`].
#[error("text input is empty")]
EmptyText,

+ /// The slice passed to [`crate::TextEncoder::embed_batch`] exceeded
+ /// `BatchOptions::max_batch_size`.
#[error("batch size {got} exceeds maximum {max}")]
- BatchTooLarge { got: usize, max: usize },
+ BatchTooLarge {
+ /// Number of inputs in the call.
+ got: usize,
+ /// Configured upper bound.
+ max: usize,
+ },

/// `BatchOptions::batch_size` was outside the legal range
/// `1..=max_batch_size` at encoder construction.
#[error("invalid batch_size {batch_size}: must be in 1..={max_batch_size}")]
InvalidBatchSize {
+ /// The supplied (rejected) batch size.
batch_size: usize,
+ /// The configured upper bound. 
max_batch_size: usize,
},

@@ -76,8 +134,17 @@ pub enum Error {
#[error("invalid max_seq_len 0: must be > 0")]
InvalidMaxSeqLen,

+ /// Batched-failure envelope: wraps the underlying error with the
+ /// position of the offending input. See
+ /// [`crate::TextEncoder::embed_batch`] for the indexing
+ /// granularity (row-precise vs chunk-level).
#[error("batch index {index}: {source}")]
- Batch { index: usize, source: Box<Error> },
+ Batch {
+ /// Zero-based index into the input slice.
+ index: usize,
+ /// Underlying error.
+ source: Box<Error>,
+ },

/// ORT runtime error pass-through. Gated on the `inference` feature
/// because `ort::Error` doesn't exist when the feature is off.
@@ -85,10 +152,15 @@ pub enum Error {
#[error(transparent)]
Ort(#[from] ort::Error),

+ /// Filesystem / I/O error pass-through (e.g. when reading a model
+ /// file).
#[error(transparent)]
Io(#[from] std::io::Error),
}

+/// Crate-local `Result` alias parameterized on the [`Error`](enum@Error)
+/// enum. Disambiguated because `thiserror::Error` (the derive macro) is
+/// also in scope here.
pub type Result<T> = core::result::Result<T, Error>;

#[cfg(test)]
diff --git a/src/options.rs b/src/options.rs
index 5154389..d12d302 100644
--- a/src/options.rs
+++ b/src/options.rs
@@ -103,6 +103,10 @@ const fn default_max_batch_size() -> usize {
1024
}

+/// Sequence-length and batching policy for [`crate::TextEncoder`].
+/// Validated at encoder construction; see
+/// [`Self::with_max_seq_len`] / [`Self::with_batch_size`] /
+/// [`Self::with_max_batch_size`] for the tunables.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct BatchOptions {
@@ -115,6 +119,8 @@ pub struct BatchOptions {
}

impl BatchOptions {
+ /// Construct a `BatchOptions` with the crate defaults
+ /// (`max_seq_len = 2048`, `batch_size = 8`, `max_batch_size = 1024`). 
#[cfg_attr(not(tarpaulin), inline(always))] pub const fn new() -> Self { Self { @@ -145,36 +151,45 @@ impl BatchOptions { self.max_batch_size } + /// Returns a copy with [`Self::max_seq_len`] replaced. #[cfg_attr(not(tarpaulin), inline(always))] pub const fn with_max_seq_len(mut self, n: usize) -> Self { self.max_seq_len = n; self } + /// Returns a copy with [`Self::batch_size`] replaced. #[cfg_attr(not(tarpaulin), inline(always))] pub const fn with_batch_size(mut self, n: usize) -> Self { self.batch_size = n; self } + /// Returns a copy with [`Self::max_batch_size`] replaced. #[cfg_attr(not(tarpaulin), inline(always))] pub const fn with_max_batch_size(mut self, n: usize) -> Self { self.max_batch_size = n; self } + /// In-place setter for [`Self::max_seq_len`]; returns `&mut self` + /// so calls can chain. #[cfg_attr(not(tarpaulin), inline(always))] pub const fn set_max_seq_len(&mut self, n: usize) -> &mut Self { self.max_seq_len = n; self } + /// In-place setter for [`Self::batch_size`]; returns `&mut self` + /// so calls can chain. #[cfg_attr(not(tarpaulin), inline(always))] pub const fn set_batch_size(&mut self, n: usize) -> &mut Self { self.batch_size = n; self } + /// In-place setter for [`Self::max_batch_size`]; returns `&mut self` + /// so calls can chain. #[cfg_attr(not(tarpaulin), inline(always))] pub const fn set_max_batch_size(&mut self, n: usize) -> &mut Self { self.max_batch_size = n; @@ -225,6 +240,11 @@ const fn default_parallel_execution() -> bool { false } +/// ORT thread-pool configuration. Maps onto ORT session-builder +/// settings (`with_intra_threads` / `with_inter_threads` / +/// `with_parallel_execution`). All defaults are `1` / `false`, +/// matching ORT's CPU-friendly low-contention setup; tune up for +/// high-throughput offline batches. 
#[derive(Clone, Copy, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct ThreadOptions { @@ -237,6 +257,8 @@ pub struct ThreadOptions { } impl ThreadOptions { + /// Construct with crate defaults (1 intra-op thread, 1 inter-op + /// thread, parallel execution off). #[cfg_attr(not(tarpaulin), inline(always))] pub const fn new() -> Self { Self { @@ -246,51 +268,65 @@ impl ThreadOptions { } } + /// Intra-op thread count — ORT's per-operator parallelism. #[cfg_attr(not(tarpaulin), inline(always))] pub const fn intra_threads(&self) -> usize { self.intra_threads } + /// Inter-op thread count — ORT's between-operator parallelism. + /// Only meaningful when [`Self::parallel_execution`] is `true`. #[cfg_attr(not(tarpaulin), inline(always))] pub const fn inter_threads(&self) -> usize { self.inter_threads } + /// Whether ORT runs independent operators concurrently. Most + /// embedding workloads don't benefit; off by default. #[cfg_attr(not(tarpaulin), inline(always))] pub const fn parallel_execution(&self) -> bool { self.parallel_execution } + /// Returns a copy with [`Self::intra_threads`] replaced. #[cfg_attr(not(tarpaulin), inline(always))] pub const fn with_intra_threads(mut self, n: usize) -> Self { self.intra_threads = n; self } + /// Returns a copy with [`Self::inter_threads`] replaced. #[cfg_attr(not(tarpaulin), inline(always))] pub const fn with_inter_threads(mut self, n: usize) -> Self { self.inter_threads = n; self } + /// Returns a copy with [`Self::parallel_execution`] replaced. #[cfg_attr(not(tarpaulin), inline(always))] pub const fn with_parallel_execution(mut self, p: bool) -> Self { self.parallel_execution = p; self } + /// In-place setter for [`Self::intra_threads`]; returns `&mut self` + /// so calls can chain. 
#[cfg_attr(not(tarpaulin), inline(always))] pub const fn set_intra_threads(&mut self, n: usize) -> &mut Self { self.intra_threads = n; self } + /// In-place setter for [`Self::inter_threads`]; returns `&mut self` + /// so calls can chain. #[cfg_attr(not(tarpaulin), inline(always))] pub const fn set_inter_threads(&mut self, n: usize) -> &mut Self { self.inter_threads = n; self } + /// In-place setter for [`Self::parallel_execution`]; returns + /// `&mut self` so calls can chain. #[cfg_attr(not(tarpaulin), inline(always))] pub const fn set_parallel_execution(&mut self, p: bool) -> &mut Self { self.parallel_execution = p; @@ -305,6 +341,10 @@ impl Default for ThreadOptions { } } +/// Top-level configuration passed to [`crate::TextEncoder`] +/// constructors. Bundles ORT graph-optimization level +/// (gated on `feature = "inference"`), [`BatchOptions`], and +/// [`ThreadOptions`]. #[derive(Clone, Copy, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Options { @@ -321,6 +361,9 @@ pub struct Options { } impl Options { + /// Construct with crate defaults + /// (`Level1` optimization, default `BatchOptions`, default + /// `ThreadOptions`). #[cfg_attr(not(tarpaulin), inline(always))] pub const fn new() -> Self { Self { @@ -331,22 +374,27 @@ impl Options { } } + /// ORT graph-optimization level applied at session-build time. #[cfg(feature = "inference")] #[cfg_attr(not(tarpaulin), inline(always))] pub const fn optimization_level(&self) -> GraphOptimizationLevel { self.optimization_level } + /// The nested [`BatchOptions`] for sequence-length and chunking + /// policy. #[cfg_attr(not(tarpaulin), inline(always))] pub const fn batch(&self) -> BatchOptions { self.batch } + /// The nested [`ThreadOptions`] for ORT thread-pool tuning. #[cfg_attr(not(tarpaulin), inline(always))] pub const fn threads(&self) -> ThreadOptions { self.threads } + /// Returns a copy with [`Self::optimization_level`] replaced. 
#[cfg(feature = "inference")] #[cfg_attr(not(tarpaulin), inline(always))] pub const fn with_optimization_level(mut self, l: GraphOptimizationLevel) -> Self { @@ -354,18 +402,22 @@ impl Options { self } + /// Returns a copy with [`Self::batch`] replaced. #[cfg_attr(not(tarpaulin), inline(always))] pub const fn with_batch(mut self, b: BatchOptions) -> Self { self.batch = b; self } + /// Returns a copy with [`Self::threads`] replaced. #[cfg_attr(not(tarpaulin), inline(always))] pub const fn with_threads(mut self, t: ThreadOptions) -> Self { self.threads = t; self } + /// In-place setter for [`Self::optimization_level`]; returns + /// `&mut self` so calls can chain. #[cfg(feature = "inference")] #[cfg_attr(not(tarpaulin), inline(always))] pub const fn set_optimization_level(&mut self, l: GraphOptimizationLevel) -> &mut Self { @@ -373,12 +425,16 @@ impl Options { self } + /// In-place setter for [`Self::batch`]; returns `&mut self` so + /// calls can chain. #[cfg_attr(not(tarpaulin), inline(always))] pub const fn set_batch(&mut self, b: BatchOptions) -> &mut Self { self.batch = b; self } + /// In-place setter for [`Self::threads`]; returns `&mut self` so + /// calls can chain. #[cfg_attr(not(tarpaulin), inline(always))] pub const fn set_threads(&mut self, t: ThreadOptions) -> &mut Self { self.threads = t; diff --git a/src/text_enc.rs b/src/text_enc.rs index db334c0..32afca2 100644 --- a/src/text_enc.rs +++ b/src/text_enc.rs @@ -48,6 +48,13 @@ impl TextEncoder { Self::from_ort_session_with_options(session, tokenizer, opts) } + /// Construct from a caller-built `ort::Session` and `Tokenizer`, + /// using the crate-default [`Options`]. Equivalent to calling + /// [`Self::from_ort_session_with_options`] with `Options::default()`. + /// On wasm32 this is the supported entry point because + /// `ort 2.0.0-rc.12` cfg-gates `commit_from_file` out of wasm + /// builds — wasm callers must build the `ort::Session` themselves + /// (e.g. 
via the wasm-specific async APIs) and pass it in.
pub fn from_ort_session(session: ort::session::Session, tokenizer: Tokenizer) -> Result<Self> {
Self::from_ort_session_with_options(session, tokenizer, Options::default())
}
@@ -72,6 +79,10 @@ impl TextEncoder {
})
}

+ /// Encode a single string and return its 768-dim L2-normalized
+ /// [`Embedding`]. Empty input is rejected with [`Error::EmptyText`].
+ /// For multiple inputs, prefer [`Self::embed_batch`] — it amortizes
+ /// the per-call ORT overhead across the batch.
pub fn embed(&mut self, text: &str) -> Result<Embedding> {
if text.is_empty() {
return Err(Error::EmptyText);
@@ -129,6 +140,10 @@ impl TextEncoder {
Ok(out)
}

+ /// Run a single throwaway inference to amortize first-call ORT
+ /// graph compilation. Useful when latency-sensitive code wants to
+ /// pay the warm-up cost up-front rather than on the first user
+ /// request.
pub fn warmup(&mut self) -> Result<()> {
let _ = self.embed("warmup")?;
Ok(())

From 4c2ec5fd5214d39bc638c2a5c6ee4adc0bdcb007 Mon Sep 17 00:00:00 2001
From: uqio <276879906+uqio@users.noreply.github.com>
Date: Sat, 2 May 2026 18:53:11 +1200
Subject: [PATCH 11/11] cleanup

---
 README.md | 28 +---------------------------
 1 file changed, 1 insertion(+), 27 deletions(-)

diff --git a/README.md b/README.md
index f92ecfd..4c77aa5 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
-

egemma

+

E-gemma

@@ -18,10 +18,6 @@ L2-normalized sentence embeddings via [`ort`] and [`tokenizers`].
-Rust ONNX inference library for Google's [EmbeddingGemma] -(`google/embeddinggemma-300m`) text embeddings. Produces 768-dim -L2-normalized sentence embeddings via [`ort`] and [`tokenizers`]. - ## Install ```toml @@ -29,28 +25,6 @@ L2-normalized sentence embeddings via [`ort`] and [`tokenizers`]. egemma = "0.1" ``` -## Quick start - -```rust,ignore -use egemma::TextEncoder; - -let mut encoder = TextEncoder::from_files( - "model.onnx".as_ref(), - "tokenizer.json".as_ref(), -)?; - -let embeddings = encoder.embed_batch(&[ - "task: search result | query: how do birds fly?", - "Birds use lift generated by their wings to fly.", - "The price of bananas in Tokyo is rising.", -])?; - -let related = embeddings[0].try_cosine(&embeddings[1])?; -let unrelated = embeddings[0].try_cosine(&embeddings[2])?; -assert!(related > unrelated); -# Ok::<(), egemma::Error>(()) -``` - Download the canonical fp32 export from [`onnx-community/embeddinggemma-300m-ONNX`][EmbeddingGemma] (`model.onnx` plus its `model.onnx_data` sidecar, and `tokenizer.json`).