From 02f382aacfd718443c2f0d35e1ca6f69bef0e353 Mon Sep 17 00:00:00 2001 From: johnnywalee Date: Sun, 11 Jan 2026 16:05:56 +0800 Subject: [PATCH] api support for image , spars and rerank --- .github/workflows/test-and-tag.yml | 69 --- .github/workflows/test.yml | 167 ++++++ Cargo.lock | 1 + Cargo.toml | 24 +- README.md | 456 ++++++++++++++--- src/bin/cli.rs | 251 +++++++++ src/bin/preload.rs | 144 +++++- src/core/embeddings.rs | 790 ++++++++++++++++++++++++++++- src/core/image_utils.rs | 287 +++++++++++ src/core/mod.rs | 33 +- src/core/types.rs | 195 +++++++ src/lambda.rs | 395 +++++++++++++-- src/lib.rs | 26 +- tests/integration_image_tests.rs | 323 ++++++++++++ tests/integration_rerank_tests.rs | 300 +++++++++++ tests/integration_sparse_tests.rs | 226 +++++++++ tests/integration_tests.rs | 695 ------------------------- tests/integration_text_tests.rs | 288 +++++++++++ tests/unit_tests.rs | 711 ++++---------------------- 19 files changed, 3873 insertions(+), 1508 deletions(-) delete mode 100644 .github/workflows/test-and-tag.yml create mode 100644 .github/workflows/test.yml create mode 100644 src/core/image_utils.rs create mode 100644 tests/integration_image_tests.rs create mode 100644 tests/integration_rerank_tests.rs create mode 100644 tests/integration_sparse_tests.rs delete mode 100644 tests/integration_tests.rs create mode 100644 tests/integration_text_tests.rs diff --git a/.github/workflows/test-and-tag.yml b/.github/workflows/test-and-tag.yml deleted file mode 100644 index b0ec2a8..0000000 --- a/.github/workflows/test-and-tag.yml +++ /dev/null @@ -1,69 +0,0 @@ -name: Test and Tag - -on: - pull_request: - branches: - - main -env: - CARGO_TERM_COLOR: always - RUST_BACKTRACE: 1 - -jobs: - test: - name: Test Suite - runs-on: ubuntu-latest - services: - localstack: - image: localstack/localstack:4.11.1 - ports: - - 4566:4566 - - 4571:4571 - env: - SERVICES: s3 - DEBUG: 1 - DATA_DIR: /tmp/localstack/data - options: >- - --health-cmd "awslocal s3 ls || exit 1" - --health-interval 10s - --health-timeout 5s - --health-retries 5 - - steps: - - name: Checkout code - uses: actions/checkout@v4 - with: - lfs: 'true' - - - name: Install Rust toolchain - uses: dtolnay/rust-toolchain@stable - with: - components: rustfmt, clippy - - - name: Cache cargo registry - uses: actions/cache@v3 - with: - path: ~/.cargo/registry - key: ${{ runner.os }}-cargo-registry-${{ hashFiles('Cargo.lock') }} - restore-keys: | - ${{ runner.os }}-cargo-registry- - - - name: Cache cargo index - uses: actions/cache@v3 - with: - path: ~/.cargo/git - key: ${{ runner.os }}-cargo-index-${{ hashFiles('Cargo.lock') }} - restore-keys: | - ${{ runner.os }}-cargo-index- - - - name: Cache cargo build - uses: actions/cache@v3 - with: - path: target - key: ${{ runner.os }}-cargo-build-target-${{ hashFiles('Cargo.lock') }} - restore-keys: | - ${{ runner.os }}-cargo-build-target- - - - name: Run tests - env: - LOCALSTACK_ENDPOINT: http://127.0.0.1:4566 - run: cargo test --verbose \ No newline at end of file diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..f10bb94 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,167 @@ +name: Test + +on: + pull_request: + branches: + - main +env: + CARGO_TERM_COLOR: always + RUST_BACKTRACE: 1 + RUST_LOG: info + +jobs: + test: + name: Test Suite + runs-on: ubuntu-latest + services: + localstack: + image: localstack/localstack:4.11.1 + ports: + - 4566:4566 + - 4571:4571 + env: + SERVICES: s3 + DEBUG: 1 + DATA_DIR: /tmp/localstack/data + options: >- + --health-cmd "awslocal s3 ls || exit 1" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + lfs: 'true' + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + toolchain: stable + components: rustfmt, clippy + + - name: Cache cargo registry + uses: actions/cache@v3 + with: + path: ~/.cargo/registry + key: ${{ runner.os }}-cargo-registry-${{ hashFiles('Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-registry- + + - name: Cache cargo index + uses: actions/cache@v3 + with: + path: ~/.cargo/git + key: ${{ runner.os }}-cargo-index-${{ hashFiles('Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-index- + + - name: Cache cargo build + uses: actions/cache@v3 + with: + path: target + key: ${{ runner.os }}-cargo-build-target-${{ hashFiles('Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-build-target- + + - name: Cache fastembed models + uses: actions/cache@v3 + with: + path: | + ~/.cache/huggingface + .fastembed_cache + key: ${{ runner.os }}-fastembed-models-v2 + restore-keys: | + ${{ runner.os }}-fastembed-models- + + - name: Run unit tests + env: + LOCALSTACK_ENDPOINT: http://127.0.0.1:4566 + FASTEMBED_CACHE_PATH: .fastembed_cache + run: cargo test --verbose + + integration-tests: + name: Integration Tests - ${{ matrix.test }} + runs-on: ubuntu-latest + needs: test + strategy: + matrix: + test: [ text, image, sparse, rerank ] + include: + - test: text + timeout: 10 + - test: image + timeout: 15 + - test: sparse + timeout: 10 + - test: rerank + timeout: 10 + services: + localstack: + image: localstack/localstack:4.11.1 + ports: + - 4566:4566 + - 4571:4571 + env: + SERVICES: s3 + DEBUG: 1 + DATA_DIR: /tmp/localstack/data + options: >- + --health-cmd "awslocal s3 ls || exit 1" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + lfs: 'true' + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + toolchain: stable + + - name: Cache cargo registry + uses: actions/cache@v3 + with: + path: ~/.cargo/registry + key: ${{ runner.os }}-cargo-registry-${{ hashFiles('Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-registry- + + - name: Cache cargo index + uses: actions/cache@v3 + with: + path: ~/.cargo/git + key: ${{ runner.os }}-cargo-index-${{ hashFiles('Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-index- + + - name: Cache cargo build + uses: actions/cache@v3 + with: + path: target + key: ${{ runner.os }}-cargo-build-target-${{ hashFiles('Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-build-target- + + - name: Cache fastembed models + uses: actions/cache@v3 + with: + path: | + ~/.cache/huggingface + .fastembed_cache + key: ${{ runner.os }}-fastembed-models-v2-${{ matrix.test }} + restore-keys: | + ${{ runner.os }}-fastembed-models-v2- + ${{ runner.os }}-fastembed-models- + + - name: Run ${{ matrix.test }} integration tests + env: + LOCALSTACK_ENDPOINT: http://127.0.0.1:4566 + FASTEMBED_CACHE_PATH: .fastembed_cache + run: cargo test --features aws --test integration_${{ matrix.test }}_tests -- --nocapture --test-threads=1 + timeout-minutes: ${{ matrix.timeout }} \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 4fd4467..ad1f9c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3580,6 +3580,7 @@ dependencies = [ "aws-config", "aws-credential-types", "aws-sdk-s3", + "base64 0.22.1", "clap", "fastembed", "lambda_runtime", diff --git a/Cargo.toml b/Cargo.toml index fc1f585..8ef1820 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ fastembed = "5" once_cell = "1" clap = { version = "4", features = ["derive"] } thiserror = "2" +base64 = "0.22" # PDF support (optional) pdf-extract = { version = "0.10", optional = true } @@ -59,13 +60,28 @@ strip = true tokio-test = "0.4" [[test]] -name = "integration_tests" -path = "tests/integration_tests.rs" +name = "unit_tests" +path = "tests/unit_tests.rs" required-features = ["aws"] [[test]] -name = "unit_tests" -path = "tests/unit_tests.rs" +name = "integration_text_tests" +path = "tests/integration_text_tests.rs" +required-features = ["aws"] + +[[test]] +name = "integration_image_tests" +path = "tests/integration_image_tests.rs" +required-features = ["aws"] + +[[test]] +name = "integration_sparse_tests" +path = "tests/integration_sparse_tests.rs" +required-features = ["aws"] + +[[test]] +name = "integration_rerank_tests" +path = "tests/integration_rerank_tests.rs" required-features = ["aws"] [build-dependencies] diff --git a/README.md b/README.md index e6d943b..aedf279 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Serverless Vectorizer -[![CI](https://github.com/johnnywale/serverless-vectorizer/actions/workflows/test-and-tag.yml/badge.svg)](https://github.com/johnnywale/serverless-vectorizer/actions/workflows/ci.yml) +[![CI](https://github.com/johnnywale/serverless-vectorizer/actions/workflows/test.yml/badge.svg)](https://github.com/johnnywale/serverless-vectorizer/actions/workflows/ci.yml) [![Release](https://img.shields.io/github/v/release/johnnywale/serverless-vectorizer)](https://github.com/johnnywale/serverless-vectorizer/releases) [![License](https://img.shields.io/github/license/johnnywale/serverless-vectorizer)](LICENSE-MIT) [![Docker Pulls](https://img.shields.io/docker/pulls/johnnywale/serverless-vectorizer)](https://hub.docker.com/r/johnnywale/serverless-vectorizer) @@ -196,11 +196,58 @@ cargo run --bin list-models -- -c all -## Usage +## Lambda API Reference -The Lambda supports two invocation methods: +The Lambda automatically detects the model type from the `MODEL_ID` environment variable and routes requests accordingly. Each model type has its own request/response format. -### Direct Lambda Invocation +### Model Type Auto-Detection + +| MODEL_ID Pattern | Model Type | Use Case | +|-----------------|------------|----------| +| Text embedding models | `text` | Semantic search, similarity | +| `Qdrant/clip-ViT-B-32-vision`, etc. | `image` | Image similarity, visual search | +| `Qdrant/Splade_PP_en_v1`, etc. | `sparse` | Hybrid search, keyword matching | +| `BAAI/bge-reranker-*`, etc. | `rerank` | Re-ranking search results | + +--- + +## Text Embeddings + +Generate dense vector embeddings for text. Default model type. + +### Request + +```json +{ + "messages": ["Hello world", "How are you?"] +} +``` + +Or read from S3: + +```json +{ + "s3_file": "my-bucket/path/to/texts.json" +} +``` + +### Response + +```json +{ + "embeddings": [ + [0.123, 0.456, -0.789, ...], + [0.321, 0.654, -0.987, ...] + ], + "dimension": 384, + "model_type": "text", + "count": 2 +} +``` + +### Examples + +**Direct Lambda Invocation:** ```bash aws lambda invoke \ @@ -209,55 +256,308 @@ aws lambda invoke \ response.json ``` -**Response:** +**Local Docker Testing:** + +```bash +# Start the container +docker run -p 9000:8080 johnnywalee/serverless-vectorizer:latest-Xenova/bge-small-en-v1.5 + +# Send request +curl -X POST "http://localhost:9000/2015-03-31/functions/function/invocations" \ + -H "Content-Type: application/json" \ + -d '{"messages": ["Hello world", "How are you?"]}' +``` + +**API Gateway:** + +```bash +curl -X POST https://your-api.execute-api.region.amazonaws.com/embed \ + -H "Content-Type: application/json" \ + -d '{"messages": ["Hello world", "How are you?"]}' +``` + +--- + +## Image Embeddings + +Generate dense vector embeddings for images. Requires an image embedding model (e.g., `Qdrant/clip-ViT-B-32-vision`). + +### Request + +Images can be provided as base64-encoded data or S3 paths: + +```json +{ + "images": [ + {"base64": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJ..."}, + {"s3_path": "my-bucket/images/photo.jpg"} + ] +} +``` + +Or using `s3_images` for multiple S3 paths: + +```json +{ + "s3_images": [ + "my-bucket/images/photo1.jpg", + "my-bucket/images/photo2.png" + ] +} +``` + +### Response ```json { "embeddings": [ - [ - 0.123, - 0.456, - ... - ], - [ - 0.789, - 0.012, - ... + [0.123, 0.456, -0.789, ...], + [0.321, 0.654, -0.987, ...] + ], + "dimension": 512, + "model_type": "image", + "count": 2 +} +``` + +### Examples + +**Build Image Embedding Container:** + +```bash +docker build \ + --build-arg BASE_IMAGE=johnnywalee/serverless-vectorizer:base-latest \ + --build-arg MODEL_ID=Qdrant/clip-ViT-B-32-vision \ + -f Dockerfile.variant \ + -t serverless-vectorizer:clip . +``` + +**Local Docker Testing:** + +```bash +# Start the container +docker run -p 9000:8080 serverless-vectorizer:clip + +# Send request with base64 image +curl -X POST "http://localhost:9000/2015-03-31/functions/function/invocations" \ + -H "Content-Type: application/json" \ + -d '{ + "images": [ + {"base64": "'"$(base64 -w 0 image.png)"'"} ] + }' +``` + +**Lambda with S3 Images:** + +```bash +aws lambda invoke \ + --function-name serverless-vectorizer-image \ + --payload '{ + "s3_images": ["my-bucket/images/photo1.jpg", "my-bucket/images/photo2.jpg"] + }' \ + response.json +``` + +--- + +## Sparse Embeddings + +Generate sparse vector embeddings for text using SPLADE models. Useful for hybrid search combining dense and sparse vectors. + +### Request + +```json +{ + "messages": ["The quick brown fox jumps over the lazy dog"] +} +``` + +### Response + +```json +{ + "sparse_embeddings": [ + { + "indices": [102, 456, 789, 1234, 5678], + "values": [0.5, 0.3, 0.8, 0.2, 0.9] + } ], - "dimension": 384 + "model_type": "sparse", + "count": 1 } ``` -### API Gateway (POST /embed) +The sparse embedding contains: +- `indices`: Token indices with non-zero weights +- `values`: Corresponding weights for each token + +### Examples + +**Build Sparse Embedding Container:** ```bash -curl -X POST https://your-api.execute-api.region.amazonaws.com/embed \ +docker build \ + --build-arg BASE_IMAGE=johnnywalee/serverless-vectorizer:base-latest \ + --build-arg MODEL_ID=Qdrant/Splade_PP_en_v1 \ + -f Dockerfile.variant \ + -t serverless-vectorizer:splade . +``` + +**Local Docker Testing:** + +```bash +# Start the container +docker run -p 9000:8080 serverless-vectorizer:splade + +# Send request +curl -X POST "http://localhost:9000/2015-03-31/functions/function/invocations" \ -H "Content-Type: application/json" \ - -d '{"messages": ["Hello world", "How are you?"]}' + -d '{"messages": ["Machine learning is a subset of artificial intelligence"]}' +``` + +**Lambda Invocation:** + +```bash +aws lambda invoke \ + --function-name serverless-vectorizer-sparse \ + --payload '{"messages": ["Machine learning is a subset of artificial intelligence"]}' \ + response.json +``` + +--- + +## Reranking + +Re-rank documents based on relevance to a query. Useful for improving search results. + +### Request + +```json +{ + "query": "What is machine learning?", + "documents": [ + "Machine learning is a subset of AI that enables computers to learn from data.", + "The weather today is sunny and warm.", + "Deep learning uses neural networks with many layers." + ], + "top_k": 2, + "return_documents": true +} +``` + +Parameters: +- `query`: The search query +- `documents`: Array of documents to rank +- `top_k` (optional): Return only top K results +- `return_documents` (optional, default: true): Include document text in response + +### Response + +```json +{ + "rankings": [ + { + "index": 0, + "score": 0.95, + "document": "Machine learning is a subset of AI that enables computers to learn from data." + }, + { + "index": 2, + "score": 0.82, + "document": "Deep learning uses neural networks with many layers." + } + ], + "model_type": "rerank", + "count": 2 +} +``` + +Results are sorted by score in descending order. + +### Examples + +**Build Reranking Container:** + +```bash +docker build \ + --build-arg BASE_IMAGE=johnnywalee/serverless-vectorizer:base-latest \ + --build-arg MODEL_ID=BAAI/bge-reranker-base \ + -f Dockerfile.variant \ + -t serverless-vectorizer:reranker . +``` + +**Local Docker Testing:** + +```bash +# Start the container +docker run -p 9000:8080 serverless-vectorizer:reranker + +# Send request +curl -X POST "http://localhost:9000/2015-03-31/functions/function/invocations" \ + -H "Content-Type: application/json" \ + -d '{ + "query": "What is machine learning?", + "documents": [ + "Machine learning is a subset of AI that enables computers to learn from data.", + "The weather today is sunny and warm.", + "Deep learning uses neural networks with many layers." + ], + "top_k": 2 + }' +``` + +**Lambda Invocation:** + +```bash +aws lambda invoke \ + --function-name serverless-vectorizer-rerank \ + --payload '{ + "query": "Capital cities of Europe", + "documents": [ + "Paris is the capital of France.", + "Pizza is a popular Italian food.", + "London is the capital of England.", + "Berlin is the capital of Germany." + ], + "top_k": 3 + }' \ + response.json ``` +--- + ## S3 Integration +All model types support reading from and writing to S3. + ### Read Input from S3 -Instead of passing messages directly, read text from an S3 file: +**Text Embeddings:** ```bash aws lambda invoke \ --function-name serverless-vectorizer \ - --payload '{"s3_file": "my-bucket/path/to/input.txt"}' \ + --payload '{"s3_file": "my-bucket/path/to/texts.json"}' \ response.json ``` The S3 file can contain: - - Plain text (embedded as single document) - JSON array of strings (each string embedded separately) +**Image Embeddings:** + +```bash +aws lambda invoke \ + --function-name serverless-vectorizer-image \ + --payload '{"s3_images": ["my-bucket/images/photo1.jpg", "my-bucket/images/photo2.png"]}' \ + response.json +``` + ### Save Output to S3 -Save embeddings directly to S3: +Save embeddings directly to S3 (text and image models only): ```bash aws lambda invoke \ @@ -276,14 +576,10 @@ aws lambda invoke \ ```json { - "embeddings": [ - [ - 0.123, - 0.456, - ... - ] - ], + "embeddings": [[0.123, 0.456, ...]], "dimension": 384, + "model_type": "text", + "count": 1, "s3_location": "s3://my-output-bucket/embeddings/output.json" } ``` @@ -305,44 +601,72 @@ aws lambda invoke \ response.json ``` -## Request Schema +--- + +## Complete Request Schema ```json { - "messages": [ - "text1", - "text2" + // === Text Embedding Input === + "messages": ["text1", "text2"], // Direct text input + "s3_file": "bucket/key", // OR read text from S3 + + // === Image Embedding Input === + "images": [ // Image input array + {"base64": "..."}, // Base64 encoded image + {"s3_path": "bucket/key"} // OR S3 path to image ], - // Direct text input (array of strings) - "s3_file": "bucket/key", - // OR read input from S3 - "save_to_s3": { - // Optional: save embeddings to S3 + "s3_images": ["bucket/key1", "bucket/key2"], // OR S3 paths array + + // === Reranking Input === + "query": "search query", // Query for reranking + "documents": ["doc1", "doc2"], // Documents to rank + "top_k": 5, // Return top K results (optional) + "return_documents": true, // Include docs in response (optional) + + // === Output Options === + "save_to_s3": { // Save results to S3 (optional) "bucket": "bucket-name", "key": "path/to/output.json" } } ``` -Either `messages` or `s3_file` must be provided. `save_to_s3` is optional. +## Complete Response Schema -## Response Schema +**Text/Image Embeddings:** ```json { - "embeddings": [ - [ - ... - ], - [ - ... - ] + "embeddings": [[...], [...]], // Dense embedding vectors + "dimension": 384, // Vector dimension + "model_type": "text", // "text" or "image" + "count": 2, // Number of embeddings + "s3_location": "s3://..." // If save_to_s3 was used +} +``` + +**Sparse Embeddings:** + +```json +{ + "sparse_embeddings": [ + {"indices": [...], "values": [...]} ], - // Array of embedding vectors - "dimension": 384, - // Vector dimension - "s3_location": "s3://..." - // Only present if save_to_s3 was used + "model_type": "sparse", + "count": 1 +} +``` + +**Reranking:** + +```json +{ + "rankings": [ + {"index": 0, "score": 0.95, "document": "..."} + ], + "model_type": "rerank", + "count": 2 } ``` @@ -363,12 +687,18 @@ cargo test --test unit_tests --features aws # Start LocalStack for integration tests docker-compose up -d -# Run integration tests -AWS_ENDPOINT_URL=http://localhost:4566 \ -AWS_ACCESS_KEY_ID=test \ -AWS_SECRET_ACCESS_KEY=test \ -AWS_DEFAULT_REGION=us-east-1 \ -cargo test --test integration_tests --features aws -- --test-threads=1 +# Run all integration tests (can run in parallel - each has its own MODEL_ID) +cargo test --features aws --test integration_text_tests & +cargo test --features aws --test integration_image_tests & +cargo test --features aws --test integration_sparse_tests & +cargo test --features aws --test integration_rerank_tests & +wait + +# Or run specific model type tests +cargo test --features aws --test integration_text_tests # Text embeddings +cargo test --features aws --test integration_image_tests # Image embeddings +cargo test --features aws --test integration_sparse_tests # Sparse embeddings +cargo test --features aws --test integration_rerank_tests # Reranking # Stop LocalStack docker-compose down @@ -384,11 +714,15 @@ docker-compose down │ ├── lib.rs # Library exports │ └── core/ │ ├── model.rs # Model definitions and registry -│ ├── embeddings.rs # Embedding service +│ ├── embeddings.rs # Embedding services (text, image, sparse, rerank) +│ ├── image_utils.rs # Image loading utilities │ └── ... ├── tests/ -│ ├── unit_tests.rs -│ └── integration_tests.rs +│ ├── unit_tests.rs # Unit tests +│ ├── integration_text_tests.rs # Text embedding integration tests +│ ├── integration_image_tests.rs # Image embedding integration tests +│ ├── integration_sparse_tests.rs # Sparse embedding integration tests +│ └── integration_rerank_tests.rs # Reranking integration tests ├── Dockerfile # Base image ├── Dockerfile.variant # Model-specific variant builder └── docker-compose.yaml # LocalStack for testing diff --git a/src/bin/cli.rs b/src/bin/cli.rs index 35063a5..928be0a 100644 --- a/src/bin/cli.rs +++ b/src/bin/cli.rs @@ -12,6 +12,11 @@ use serverless_vectorizer::{ // Types EmbeddingOutput, SearchResult, SearchResponse, ClusterResponse, DistanceMatrixResponse, BenchmarkResult, + // New types for image/sparse/rerank + ImageEmbeddingOutput, SparseEmbeddingOutput, + RerankOutput, + // New services + ImageEmbeddingService, SparseEmbeddingService, RerankService, }; #[cfg(feature = "pdf")] use serverless_vectorizer::{extract_text_from_file, is_pdf_file}; @@ -231,6 +236,98 @@ enum Commands { #[arg(long, default_value = "Xenova/bge-small-en-v1.5")] model: String, }, + + /// Generate embeddings for images + EmbedImage { + /// Path to image file + #[arg(short, long)] + image: PathBuf, + + /// Output file path + #[arg(short, long)] + output: Option, + + /// Image embedding model to use (use 'info --category image' to list) + #[arg(long, default_value = "Qdrant/clip-ViT-B-32-vision")] + model: String, + + /// Pretty print JSON output + #[arg(long)] + pretty: bool, + }, + + /// Batch embed multiple images from a directory + BatchImages { + /// Directory containing images + #[arg(short, long)] + directory: PathBuf, + + /// Output file path + #[arg(short, long)] + output: Option, + + /// Image embedding model to use + #[arg(long, default_value = "Qdrant/clip-ViT-B-32-vision")] + model: String, + + /// Pretty print JSON output + #[arg(long)] + pretty: bool, + }, + + /// Generate sparse embeddings for text (useful for hybrid search) + EmbedSparse { + /// Text to embed + #[arg(short, long)] + text: Option, + + /// File containing texts (one per line or JSON array) + #[arg(short, long)] + file: Option, + + /// Output file path + #[arg(short, long)] + output: Option, + + /// Sparse embedding model to use (use 'info --category sparse' to list) + #[arg(long, default_value = "Qdrant/bm42-all-minilm-l6-v2-attentions")] + model: String, + + /// Pretty print JSON output + #[arg(long)] + pretty: bool, + }, + + /// Rerank documents by relevance to a query + Rerank { + /// The search query + #[arg(short, long)] + query: String, + + /// Documents to rerank (can be specified multiple times) + #[arg(short, long)] + documents: Vec, + + /// File containing documents (one per line or JSON array) + #[arg(short, long)] + file: Option, + + /// Output file path + #[arg(short, long)] + output: Option, + + /// Number of top results to return + #[arg(short = 'k', long)] + top_k: Option, + + /// Reranking model to use (use 'info --category rerank' to list) + #[arg(long, default_value = "BAAI/bge-reranker-base")] + model: String, + + /// Pretty print JSON output + #[arg(long)] + pretty: bool, + }, } // Embedding service with model caching @@ -767,6 +864,160 @@ fn main() -> Result<(), Box> { write_output(&output_str, output)?; } } + + Commands::EmbedImage { image, output, model, pretty } => { + // Resolve image model + let image_model = ModelRegistry::find_image_model(&model) + .ok_or_else(|| format!("Unknown image model: '{}'\n\nUse 'info --category image' to list available models", model))?; + + // Load image bytes + let image_bytes = fs::read(&image)?; + eprintln!("Loaded image: {:?} ({} bytes)", image, image_bytes.len()); + + // Create service and generate embedding + let service = ImageEmbeddingService::new().with_progress(true); + let embeddings = service.embed_images_with_model(&[image_bytes], image_model) + .map_err(|e| format!("Image embedding failed: {}", e))?; + + let out = ImageEmbeddingOutput::new(embeddings).with_model(&model); + let output_str = if pretty { + serde_json::to_string_pretty(&out)? + } else { + serde_json::to_string(&out)? + }; + write_output(&output_str, output)?; + } + + Commands::BatchImages { directory, output, model, pretty } => { + // Resolve image model + let image_model = ModelRegistry::find_image_model(&model) + .ok_or_else(|| format!("Unknown image model: '{}'\n\nUse 'info --category image' to list available models", model))?; + + // Find all images in directory + let extensions = ["jpg", "jpeg", "png", "gif", "webp", "bmp"]; + let mut image_paths: Vec = Vec::new(); + + for entry in fs::read_dir(&directory)? { + let entry = entry?; + let path = entry.path(); + if let Some(ext) = path.extension() { + if extensions.iter().any(|e| ext.eq_ignore_ascii_case(e)) { + image_paths.push(path); + } + } + } + + if image_paths.is_empty() { + eprintln!("No images found in {:?}", directory); + std::process::exit(1); + } + + eprintln!("Found {} images in {:?}", image_paths.len(), directory); + + // Load all images + let mut image_bytes_list: Vec> = Vec::new(); + for path in &image_paths { + let bytes = fs::read(path)?; + image_bytes_list.push(bytes); + } + + // Generate embeddings + let service = ImageEmbeddingService::new().with_progress(true); + let embeddings = service.embed_images_with_model(&image_bytes_list.iter().map(|v| v.clone()).collect::>().as_slice(), image_model) + .map_err(|e| format!("Image embedding failed: {}", e))?; + + let result = serde_json::json!({ + "images": image_paths.iter().map(|p| p.to_string_lossy().to_string()).collect::>(), + "embeddings": embeddings, + "dimension": embeddings.first().map(|e| e.len()).unwrap_or(0), + "count": embeddings.len(), + "model": model + }); + + let output_str = if pretty { + serde_json::to_string_pretty(&result)? + } else { + serde_json::to_string(&result)? + }; + write_output(&output_str, output)?; + } + + Commands::EmbedSparse { text, file, output, model, pretty } => { + // Resolve sparse model + let sparse_model = ModelRegistry::find_sparse_model(&model) + .ok_or_else(|| format!("Unknown sparse model: '{}'\n\nUse 'info --category sparse' to list available models", model))?; + + // Get texts to embed + let texts = get_input(text, file)?; + + if texts.is_empty() { + eprintln!("No texts provided"); + std::process::exit(1); + } + + eprintln!("Generating sparse embeddings for {} text(s)...", texts.len()); + + // Generate sparse embeddings + let service = SparseEmbeddingService::new().with_progress(true); + let sparse_embeddings = service.embed_with_model(texts, sparse_model) + .map_err(|e| format!("Sparse embedding failed: {}", e))?; + + let out = SparseEmbeddingOutput::new(sparse_embeddings).with_model(&model); + let output_str = if pretty { + serde_json::to_string_pretty(&out)? + } else { + serde_json::to_string(&out)? + }; + write_output(&output_str, output)?; + } + + Commands::Rerank { query, documents, file, output, top_k, model, pretty } => { + // Resolve rerank model + let rerank_model = ModelRegistry::find_rerank_model(&model) + .ok_or_else(|| format!("Unknown rerank model: '{}'\n\nUse 'info --category rerank' to list available models", model))?; + + // Collect documents + let mut docs: Vec = documents; + + // Add documents from file if provided + if let Some(file_path) = file { + let content = fs::read_to_string(&file_path)?; + let file_docs: Vec = if content.trim().starts_with('[') { + serde_json::from_str(&content)? + } else { + content.lines() + .filter(|l| !l.trim().is_empty()) + .map(|l| l.to_string()) + .collect() + }; + docs.extend(file_docs); + } + + if docs.is_empty() { + eprintln!("No documents provided"); + std::process::exit(1); + } + + eprintln!("Reranking {} document(s) for query: \"{}\"", docs.len(), query); + + // Perform reranking + let service = RerankService::new().with_progress(true); + let mut results = service.rerank_with_model(&query, docs, true, rerank_model) + .map_err(|e| format!("Reranking failed: {}", e))?; + + // Apply top_k if specified + if let Some(k) = top_k { + results.truncate(k); + } + + let out = RerankOutput::new(query, results).with_model(&model); + let output_str = if pretty { + serde_json::to_string_pretty(&out)? + } else { + serde_json::to_string(&out)? + }; + write_output(&output_str, output)?; + } } Ok(()) diff --git a/src/bin/preload.rs b/src/bin/preload.rs index 5885c4e..c1c7d9e 100644 --- a/src/bin/preload.rs +++ b/src/bin/preload.rs @@ -1,4 +1,10 @@ -use fastembed::{InitOptions, TextEmbedding}; +// Preload binary for all model types (text, image, sparse, rerank) +// Detects model type automatically from MODEL_ID + +use fastembed::{ + ImageEmbedding, ImageInitOptions, InitOptions, RerankInitOptions, SparseInitOptions, + SparseTextEmbedding, TextEmbedding, TextRerank, +}; use serverless_vectorizer::ModelRegistry; use std::env; @@ -10,33 +16,121 @@ fn main() { .map(String::as_str) .unwrap_or("Xenova/bge-small-en-v1.5"); - // Use ModelRegistry to dynamically find the model - let embedding_model = match ModelRegistry::find_text_model(model_id) { - Some(model) => model, - None => { - eprintln!( - "Warning: Unknown model '{}', falling back to Xenova/bge-small-en-v1.5", - model_id - ); - eprintln!("\nAvailable text embedding models:"); - for model_info in ModelRegistry::text_embedding_models() { - eprintln!(" - {} ({}D)", model_info.model_id, model_info.dimension.unwrap_or(0)); - } - ModelRegistry::default_text_model() - } - }; - println!("======================================"); - println!("Preloading embedding model: {}", model_id); + println!("Preloading model: {}", model_id); - let mut model = - TextEmbedding::try_new(InitOptions::new(embedding_model).with_show_download_progress(true)) - .expect("Failed to initialize model"); + // Auto-detect model category and preload accordingly + if let Some(model) = ModelRegistry::find_text_model(model_id) { + println!("Model type: Text Embedding"); + preload_text_model(model, model_id); + } else if let Some(model) = ModelRegistry::find_image_model(model_id) { + println!("Model type: Image Embedding"); + preload_image_model(model, model_id); + } else if let Some(model) = ModelRegistry::find_sparse_model(model_id) { + println!("Model type: Sparse Text Embedding"); + preload_sparse_model(model, model_id); + } else if let Some(model) = ModelRegistry::find_rerank_model(model_id) { + println!("Model type: Reranking"); + preload_rerank_model(model, model_id); + } else { + eprintln!("Error: Unknown model '{}'", model_id); + eprintln!("\nAvailable models by category:"); + list_all_models(); + std::process::exit(1); + } - let embeddings = model + println!("Done!"); +} + +fn preload_text_model(model: fastembed::EmbeddingModel, _model_id: &str) { + let mut embedding = TextEmbedding::try_new( + InitOptions::new(model).with_show_download_progress(true), + ) + .expect("Failed to initialize text embedding model"); + + // Warm up with a test embedding + let embeddings = embedding .embed(vec!["test".to_string()], None) - .expect("Failed to generate embedding"); + .expect("Failed to generate test embedding"); - println!("Embedding dimension = {}", embeddings[0].len()); - println!("Done!"); + println!("Embedding dimension: {}", embeddings[0].len()); +} + +fn preload_image_model(model: fastembed::ImageEmbeddingModel, model_id: &str) { + let _embedding = ImageEmbedding::try_new( + ImageInitOptions::new(model).with_show_download_progress(true), + ) + .expect("Failed to initialize image embedding model"); + + // Get dimension from model info + let dim = ImageEmbedding::list_supported_models() + .into_iter() + .find(|info| info.model_code == model_id) + .map(|info| info.dim) + .unwrap_or(512); + + println!("Embedding dimension: {}", dim); + println!("Note: Image model loaded (no test embedding without image file)"); +} + +fn preload_sparse_model(model: fastembed::SparseModel, _model_id: &str) { + let mut embedding = SparseTextEmbedding::try_new( + SparseInitOptions::new(model).with_show_download_progress(true), + ) + .expect("Failed to initialize sparse embedding model"); + + // Warm up with a test embedding + let embeddings = embedding + .embed(vec!["test".to_string()], None) + .expect("Failed to generate test sparse embedding"); + + println!( + "Sparse embedding non-zero elements: {}", + embeddings[0].indices.len() + ); +} + +fn preload_rerank_model(model: fastembed::RerankerModel, _model_id: &str) { + let mut reranker = TextRerank::try_new( + RerankInitOptions::new(model).with_show_download_progress(true), + ) + .expect("Failed to initialize reranking model"); + + // Warm up with a test rerank + let results = reranker + .rerank("test query", vec!["test document"], true, None) + .expect("Failed to run test rerank"); + + println!("Rerank test score: {:.4}", results[0].score); +} + +fn list_all_models() { + println!("\n## Text Embedding Models:"); + for model in ModelRegistry::text_embedding_models().iter().take(10) { + println!( + " - {} ({}D)", + model.model_id, + model.dimension.unwrap_or(0) + ); + } + println!(" ... and more (use list-models for full list)"); + + println!("\n## Image Embedding Models:"); + for model in ModelRegistry::image_embedding_models() { + println!( + " - {} ({}D)", + model.model_id, + model.dimension.unwrap_or(0) + ); + } + + println!("\n## Sparse Text Embedding Models:"); + for model in ModelRegistry::sparse_text_embedding_models() { + println!(" - {}", model.model_id); + } + + println!("\n## Reranking Models:"); + for model in ModelRegistry::rerank_models() { + println!(" - {}", model.model_id); + } } diff --git a/src/core/embeddings.rs b/src/core/embeddings.rs index a6dfc0b..bdcb601 100644 --- a/src/core/embeddings.rs +++ b/src/core/embeddings.rs @@ -1,7 +1,11 @@ -// Embedding generation service +// Embedding generation service supporting multiple model types -use crate::core::model::ModelType; -use fastembed::{InitOptions, TextEmbedding}; +use crate::core::model::{ModelCategory, ModelRegistry, ModelType}; +use crate::core::types::{RerankResult, SparseEmbedding}; +use fastembed::{ + ImageEmbedding, ImageEmbeddingModel, ImageInitOptions, InitOptions, + RerankerModel, SparseInitOptions, SparseModel, SparseTextEmbedding, TextEmbedding, TextRerank, +}; use std::collections::HashMap; use std::sync::Mutex; use thiserror::Error; @@ -23,9 +27,22 @@ pub enum EmbeddingError { #[error("Invalid input: {0}")] InvalidInput(String), + + #[error("Model not found: {0}")] + ModelNotFound(String), + + #[error("Wrong model type: expected {expected}, got {actual}")] + WrongModelType { expected: String, actual: String }, + + #[error("Image loading error: {0}")] + ImageLoadError(String), } -/// Thread-safe embedding service with model caching +// ============================================================================ +// Text Embedding Service (existing, refactored) +// ============================================================================ + +/// Thread-safe text embedding service with model caching pub struct EmbeddingService { models: Mutex>, show_progress: bool, @@ -122,6 +139,63 @@ impl EmbeddingService { pub fn dimension(&self, model_type: ModelType) -> usize { model_type.dimension() } + + /// Unload a specific model from cache + pub fn unload(&self, model_type: ModelType) -> Result { + let mut models = self + .models + .lock() + .map_err(|e| EmbeddingError::LockError(e.to_string()))?; + + Ok(models.remove(&model_type).is_some()) + } + + /// Unload all models from cache + pub fn unload_all(&self) -> Result { + let mut models = self + .models + .lock() + .map_err(|e| EmbeddingError::LockError(e.to_string()))?; + + let count = models.len(); + models.clear(); + Ok(count) + } + + /// Reload a model (unload then preload) + pub fn reload(&self, model_type: ModelType) -> Result<(), EmbeddingError> { + self.unload(model_type)?; + self.preload(model_type) + } + + /// Get list of currently loaded models + pub fn loaded_models(&self) -> Vec { + self.models + .lock() + .map(|models| models.keys().cloned().collect()) + .unwrap_or_default() + } + + /// Generate embeddings using a fastembed model directly + pub fn embed_with_model( + &self, + texts: Vec, + model: fastembed::EmbeddingModel, + ) -> Result>, EmbeddingError> { + if texts.is_empty() { + return Err(EmbeddingError::EmptyInput); + } + + // Create a new TextEmbedding for this model + let mut embedding = TextEmbedding::try_new( + InitOptions::new(model).with_show_download_progress(self.show_progress), + ) + .map_err(|e| EmbeddingError::ModelInitError(e.to_string()))?; + + embedding + .embed(texts, None) + .map_err(|e| EmbeddingError::EmbeddingFailed(e.to_string())) + } } impl Default for EmbeddingService { @@ -149,6 +223,685 @@ pub fn embed_one(text: &str, model_type: ModelType) -> Result, Embeddin global_service().embed_one(text, model_type) } +// ============================================================================ +// Image Embedding Service +// ============================================================================ + +/// Thread-safe image embedding service +pub struct ImageEmbeddingService { + model: Mutex>, + show_progress: bool, +} + +impl ImageEmbeddingService { + pub fn new() -> Self { + ImageEmbeddingService { + model: Mutex::new(None), + show_progress: false, + } + } + + pub fn with_progress(mut self, show: bool) -> Self { + self.show_progress = show; + self + } + + /// Initialize with a specific model + pub fn init(&self, model_id: &str) -> Result<(), EmbeddingError> { + let embedding_model = ModelRegistry::find_image_model(model_id) + .ok_or_else(|| EmbeddingError::ModelNotFound(model_id.to_string()))?; + + let model = ImageEmbedding::try_new( + ImageInitOptions::new(embedding_model.clone()) + .with_show_download_progress(self.show_progress), + ) + .map_err(|e| EmbeddingError::ModelInitError(e.to_string()))?; + + let mut guard = self + .model + .lock() + .map_err(|e| EmbeddingError::LockError(e.to_string()))?; + *guard = Some((embedding_model, model)); + + Ok(()) + } + + /// Generate embeddings for images (as byte arrays) + /// Note: fastembed requires file paths, so we write temp files + pub fn embed_images(&self, images: Vec>) -> Result>, EmbeddingError> { + if images.is_empty() { + return Err(EmbeddingError::EmptyInput); + } + + // Write images to temp files since fastembed expects paths + // Use unique ID to avoid collisions when running in parallel + let temp_dir = std::env::temp_dir(); + let unique_id = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_nanos()) + .unwrap_or(0); + let mut temp_paths: Vec = Vec::with_capacity(images.len()); + + for (i, image_bytes) in images.iter().enumerate() { + // Detect image format and use appropriate extension + let ext = crate::core::image_utils::detect_image_format(image_bytes).unwrap_or("png"); + let temp_path = temp_dir.join(format!("fastembed_img_{}_{}.{}", unique_id, i, ext)); + std::fs::write(&temp_path, image_bytes) + .map_err(|e| EmbeddingError::ImageLoadError(e.to_string()))?; + temp_paths.push(temp_path); + } + + let mut guard = self + .model + .lock() + .map_err(|e| EmbeddingError::LockError(e.to_string()))?; + + let (_, model) = guard + .as_mut() + .ok_or_else(|| EmbeddingError::ModelInitError("Model not initialized".to_string()))?; + + let result = model + .embed(temp_paths.clone(), None) + .map_err(|e| EmbeddingError::EmbeddingFailed(e.to_string())); + + // Clean up temp files + for path in temp_paths { + let _ = std::fs::remove_file(path); + } + + result + } + + /// Generate embeddings for images from file paths + pub fn embed_from_paths + Send + Sync>( + &self, + paths: Vec

, + ) -> Result>, EmbeddingError> { + if paths.is_empty() { + return Err(EmbeddingError::EmptyInput); + } + + let mut guard = self + .model + .lock() + .map_err(|e| EmbeddingError::LockError(e.to_string()))?; + + let (_, model) = guard + .as_mut() + .ok_or_else(|| EmbeddingError::ModelInitError("Model not initialized".to_string()))?; + + model + .embed(paths, None) + .map_err(|e| EmbeddingError::EmbeddingFailed(e.to_string())) + } + + /// Generate embedding for a single image + pub fn embed_one(&self, image: &[u8]) -> Result, EmbeddingError> { + let embeddings = self.embed_images(vec![image.to_vec()])?; + embeddings + .into_iter() + .next() + .ok_or_else(|| EmbeddingError::EmbeddingFailed("No embedding generated".to_string())) + } + + /// Check if model is loaded + pub fn is_loaded(&self) -> bool { + self.model + .lock() + .map(|guard| guard.is_some()) + .unwrap_or(false) + } + + /// Get dimension of the loaded model + pub fn dimension(&self) -> Option { + self.model + .lock() + .ok() + .and_then(|guard| { + guard.as_ref().map(|(model_enum, _)| { + ImageEmbedding::list_supported_models() + .into_iter() + .find(|info| std::mem::discriminant(&info.model) == std::mem::discriminant(model_enum)) + .map(|info| info.dim) + .unwrap_or(512) + }) + }) + } + + /// Unload the current model + pub fn unload(&self) -> Result { + let mut guard = self + .model + .lock() + .map_err(|e| EmbeddingError::LockError(e.to_string()))?; + + let was_loaded = guard.is_some(); + *guard = None; + Ok(was_loaded) + } + + /// Reload the model with a new model ID + pub fn reload(&self, model_id: &str) -> Result<(), EmbeddingError> { + self.unload()?; + self.init(model_id) + } + + /// Get the currently loaded model ID + pub fn loaded_model_id(&self) -> Option { + self.model + .lock() + .ok() + .and_then(|guard| { + guard.as_ref().map(|(model_enum, _)| { + ImageEmbedding::list_supported_models() + .into_iter() + .find(|info| std::mem::discriminant(&info.model) == std::mem::discriminant(model_enum)) + .map(|info| info.model_code.to_string()) + .unwrap_or_default() + }) + }) + } + + /// Generate embeddings for images using a specific model + pub fn embed_images_with_model( + &self, + images: &[Vec], + model: ImageEmbeddingModel, + ) -> Result>, EmbeddingError> { + if images.is_empty() { + return Err(EmbeddingError::EmptyInput); + } + + // Write images to temp files since fastembed expects paths + // Use unique ID to avoid collisions when running in parallel + let temp_dir = std::env::temp_dir(); + let unique_id = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_nanos()) + .unwrap_or(0); + let mut temp_paths: Vec = Vec::with_capacity(images.len()); + + for (i, image_bytes) in images.iter().enumerate() { + // Detect image format and use appropriate extension + let ext = crate::core::image_utils::detect_image_format(image_bytes).unwrap_or("png"); + let temp_path = temp_dir.join(format!("fastembed_img_{}_{}.{}", unique_id, i, ext)); + std::fs::write(&temp_path, image_bytes) + .map_err(|e| EmbeddingError::ImageLoadError(e.to_string()))?; + temp_paths.push(temp_path); + } + + // Create embedding model + let mut embedding = ImageEmbedding::try_new( + ImageInitOptions::new(model).with_show_download_progress(self.show_progress), + ) + .map_err(|e| EmbeddingError::ModelInitError(e.to_string()))?; + + let result = embedding + .embed(temp_paths.clone(), None) + .map_err(|e| EmbeddingError::EmbeddingFailed(e.to_string())); + + // Clean up temp files + for path in temp_paths { + let _ = std::fs::remove_file(path); + } + + result + } +} + +impl Default for ImageEmbeddingService { + fn default() -> Self { + Self::new() + } +} + +// ============================================================================ +// Sparse Text Embedding Service +// ============================================================================ + +/// Thread-safe sparse text embedding service +pub struct SparseEmbeddingService { + model: Mutex>, + show_progress: bool, +} + +impl SparseEmbeddingService { + pub fn new() -> Self { + SparseEmbeddingService { + model: Mutex::new(None), + show_progress: false, + } + } + + pub fn with_progress(mut self, show: bool) -> Self { + self.show_progress = show; + self + } + + /// Initialize with a specific model + pub fn init(&self, model_id: &str) -> Result<(), EmbeddingError> { + let sparse_model = ModelRegistry::find_sparse_model(model_id) + .ok_or_else(|| EmbeddingError::ModelNotFound(model_id.to_string()))?; + + let model = SparseTextEmbedding::try_new( + SparseInitOptions::new(sparse_model.clone()) + .with_show_download_progress(self.show_progress), + ) + .map_err(|e| EmbeddingError::ModelInitError(e.to_string()))?; + + let mut guard = self + .model + .lock() + .map_err(|e| EmbeddingError::LockError(e.to_string()))?; + *guard = Some((sparse_model, model)); + + Ok(()) + } + + /// Generate sparse embeddings for texts + pub fn embed(&self, texts: Vec) -> Result, EmbeddingError> { + if texts.is_empty() { + return Err(EmbeddingError::EmptyInput); + } + + let mut guard = self + .model + .lock() + .map_err(|e| EmbeddingError::LockError(e.to_string()))?; + + let (_, model) = guard + .as_mut() + .ok_or_else(|| EmbeddingError::ModelInitError("Model not initialized".to_string()))?; + + let results = model + .embed(texts, None) + .map_err(|e| EmbeddingError::EmbeddingFailed(e.to_string()))?; + + // Convert fastembed's sparse format to our SparseEmbedding + Ok(results + .into_iter() + .map(|sparse| SparseEmbedding::new(sparse.indices, sparse.values)) + .collect()) + } + + /// Generate sparse embedding for a single text + pub fn embed_one(&self, text: &str) -> Result { + let embeddings = self.embed(vec![text.to_string()])?; + embeddings + .into_iter() + .next() + .ok_or_else(|| EmbeddingError::EmbeddingFailed("No embedding generated".to_string())) + } + + /// Check if model is loaded + pub fn is_loaded(&self) -> bool { + self.model + .lock() + .map(|guard| guard.is_some()) + .unwrap_or(false) + } + + /// Unload the current model + pub fn unload(&self) -> Result { + let mut guard = self + .model + .lock() + .map_err(|e| EmbeddingError::LockError(e.to_string()))?; + + let was_loaded = guard.is_some(); + *guard = None; + Ok(was_loaded) + } + + /// Reload the model with a new model ID + pub fn reload(&self, model_id: &str) -> Result<(), EmbeddingError> { + self.unload()?; + self.init(model_id) + } + + /// Get the currently loaded model ID + pub fn loaded_model_id(&self) -> Option { + self.model + .lock() + .ok() + .and_then(|guard| { + guard.as_ref().map(|(model_enum, _)| { + SparseTextEmbedding::list_supported_models() + .into_iter() + .find(|info| std::mem::discriminant(&info.model) == std::mem::discriminant(model_enum)) + .map(|info| info.model_code.to_string()) + .unwrap_or_default() + }) + }) + } + + /// Generate sparse embeddings using a specific model + pub fn embed_with_model( + &self, + texts: Vec, + model: SparseModel, + ) -> Result, EmbeddingError> { + if texts.is_empty() { + return Err(EmbeddingError::EmptyInput); + } + + // Create embedding model + let mut embedding = SparseTextEmbedding::try_new( + SparseInitOptions::new(model).with_show_download_progress(self.show_progress), + ) + .map_err(|e| EmbeddingError::ModelInitError(e.to_string()))?; + + let results = embedding + .embed(texts, None) + .map_err(|e| EmbeddingError::EmbeddingFailed(e.to_string()))?; + + // Convert fastembed's sparse format to our SparseEmbedding + Ok(results + .into_iter() + .map(|sparse| SparseEmbedding::new(sparse.indices, sparse.values)) + .collect()) + } +} + +impl Default for SparseEmbeddingService { + fn default() -> Self { + Self::new() + } +} + +// ============================================================================ +// Reranking Service +// ============================================================================ + +/// Thread-safe text reranking service +pub struct RerankService { + model: Mutex>, + show_progress: bool, +} + +impl RerankService { + pub fn new() -> Self { + RerankService { + model: Mutex::new(None), + show_progress: false, + } + } + + pub fn with_progress(mut self, show: bool) -> Self { + self.show_progress = show; + self + } + + /// Initialize with a specific model + pub fn init(&self, model_id: &str) -> Result<(), EmbeddingError> { + let reranker_model = ModelRegistry::find_rerank_model(model_id) + .ok_or_else(|| EmbeddingError::ModelNotFound(model_id.to_string()))?; + + let model = TextRerank::try_new( + fastembed::RerankInitOptions::new(reranker_model.clone()) + .with_show_download_progress(self.show_progress), + ) + .map_err(|e| EmbeddingError::ModelInitError(e.to_string()))?; + + let mut guard = self + .model + .lock() + .map_err(|e| EmbeddingError::LockError(e.to_string()))?; + *guard = Some((reranker_model, model)); + + Ok(()) + } + + /// Rerank documents given a query + pub fn rerank( + &self, + query: &str, + documents: Vec, + return_documents: bool, + ) -> Result, EmbeddingError> { + if documents.is_empty() { + return Err(EmbeddingError::EmptyInput); + } + + let mut guard = self + .model + .lock() + .map_err(|e| EmbeddingError::LockError(e.to_string()))?; + + let (_, model) = guard + .as_mut() + .ok_or_else(|| EmbeddingError::ModelInitError("Model not initialized".to_string()))?; + + let doc_refs: Vec<&str> = documents.iter().map(|s| s.as_str()).collect(); + + let results = model + .rerank(query, doc_refs, return_documents, None) + .map_err(|e| EmbeddingError::EmbeddingFailed(e.to_string()))?; + + Ok(results + .into_iter() + .map(|r| { + let mut result = RerankResult::new(r.index, r.score); + if let Some(doc) = r.document { + result = result.with_document(doc); + } + result + }) + .collect()) + } + + /// Rerank and return top K results + pub fn rerank_top_k( + &self, + query: &str, + documents: Vec, + k: usize, + return_documents: bool, + ) -> Result, EmbeddingError> { + let mut results = self.rerank(query, documents, return_documents)?; + results.truncate(k); + Ok(results) + } + + /// Check if model is loaded + pub fn is_loaded(&self) -> bool { + self.model + .lock() + .map(|guard| guard.is_some()) + .unwrap_or(false) + } + + /// Unload the current model + pub fn unload(&self) -> Result { + let mut guard = self + .model + .lock() + .map_err(|e| EmbeddingError::LockError(e.to_string()))?; + + let was_loaded = guard.is_some(); + *guard = None; + Ok(was_loaded) + } + + /// Reload the model with a new model ID + pub fn reload(&self, model_id: &str) -> Result<(), EmbeddingError> { + self.unload()?; + self.init(model_id) + } + + /// Get the currently loaded model ID + pub fn loaded_model_id(&self) -> Option { + self.model + .lock() + .ok() + .and_then(|guard| { + guard.as_ref().map(|(model_enum, _)| { + TextRerank::list_supported_models() + .into_iter() + .find(|info| std::mem::discriminant(&info.model) == std::mem::discriminant(model_enum)) + .map(|info| info.model_code.to_string()) + .unwrap_or_default() + }) + }) + } + + /// Rerank documents using a specific model + pub fn rerank_with_model( + &self, + query: &str, + documents: Vec, + return_documents: bool, + model: RerankerModel, + ) -> Result, EmbeddingError> { + if documents.is_empty() { + return Err(EmbeddingError::EmptyInput); + } + + // Create rerank model + let mut reranker = TextRerank::try_new( + fastembed::RerankInitOptions::new(model) + .with_show_download_progress(self.show_progress), + ) + .map_err(|e| EmbeddingError::ModelInitError(e.to_string()))?; + + let doc_refs: Vec<&str> = documents.iter().map(|s| s.as_str()).collect(); + + let results = reranker + .rerank(query, doc_refs, return_documents, None) + .map_err(|e| EmbeddingError::EmbeddingFailed(e.to_string()))?; + + Ok(results + .into_iter() + .map(|r| { + let mut result = RerankResult::new(r.index, r.score); + if let Some(doc) = r.document { + result = result.with_document(doc); + } + result + }) + .collect()) + } +} + +impl Default for RerankService { + fn default() -> Self { + Self::new() + } +} + +// ============================================================================ +// Unified Multi-Model Service +// ============================================================================ + +/// Unified service that can handle any model type +pub enum UnifiedModel { + Text(EmbeddingService), + Image(ImageEmbeddingService), + Sparse(SparseEmbeddingService), + Rerank(RerankService), +} + +/// Unified embedding service that auto-detects model type +pub struct UnifiedEmbeddingService { + model: Mutex>, + model_id: String, + category: ModelCategory, + show_progress: bool, +} + +impl UnifiedEmbeddingService { + /// Create a new unified service for the given model ID + pub fn new(model_id: &str) -> Result { + let category = Self::detect_category(model_id)?; + + Ok(UnifiedEmbeddingService { + model: Mutex::new(None), + model_id: model_id.to_string(), + category, + show_progress: false, + }) + } + + pub fn with_progress(mut self, show: bool) -> Self { + self.show_progress = show; + self + } + + /// Detect the model category from the model ID + fn detect_category(model_id: &str) -> Result { + if ModelRegistry::find_text_model(model_id).is_some() { + Ok(ModelCategory::TextEmbedding) + } else if ModelRegistry::find_image_model(model_id).is_some() { + Ok(ModelCategory::ImageEmbedding) + } else if ModelRegistry::find_sparse_model(model_id).is_some() { + Ok(ModelCategory::SparseTextEmbedding) + } else if ModelRegistry::find_rerank_model(model_id).is_some() { + Ok(ModelCategory::TextRerank) + } else { + Err(EmbeddingError::ModelNotFound(model_id.to_string())) + } + } + + /// Get the model category + pub fn category(&self) -> ModelCategory { + self.category + } + + /// Get the model ID + pub fn model_id(&self) -> &str { + &self.model_id + } + + /// Initialize the appropriate model + pub fn init(&self) -> Result<(), EmbeddingError> { + let mut guard = self + .model + .lock() + .map_err(|e| EmbeddingError::LockError(e.to_string()))?; + + if guard.is_some() { + return Ok(()); // Already initialized + } + + let model = match self.category { + ModelCategory::TextEmbedding => { + let service = EmbeddingService::new().with_progress(self.show_progress); + // Preload using the legacy type if possible + if let Some(model_type) = ModelType::from_str(&self.model_id) { + service.preload(model_type)?; + } + UnifiedModel::Text(service) + } + ModelCategory::ImageEmbedding => { + let service = ImageEmbeddingService::new().with_progress(self.show_progress); + service.init(&self.model_id)?; + UnifiedModel::Image(service) + } + ModelCategory::SparseTextEmbedding => { + let service = SparseEmbeddingService::new().with_progress(self.show_progress); + service.init(&self.model_id)?; + UnifiedModel::Sparse(service) + } + ModelCategory::TextRerank => { + let service = RerankService::new().with_progress(self.show_progress); + service.init(&self.model_id)?; + UnifiedModel::Rerank(service) + } + }; + + *guard = Some(model); + Ok(()) + } + + /// Check if the model is loaded + pub fn is_loaded(&self) -> bool { + self.model + .lock() + .map(|guard| guard.is_some()) + .unwrap_or(false) + } +} + #[cfg(test)] mod tests { use super::*; @@ -165,4 +918,33 @@ mod tests { let result = service.embed(vec![], ModelType::BgeSmallEnV15); assert!(matches!(result, Err(EmbeddingError::EmptyInput))); } + + #[test] + fn test_image_service_creation() { + let service = ImageEmbeddingService::new(); + assert!(!service.is_loaded()); + } + + #[test] + fn test_sparse_service_creation() { + let service = SparseEmbeddingService::new(); + assert!(!service.is_loaded()); + } + + #[test] + fn test_rerank_service_creation() { + let service = RerankService::new(); + assert!(!service.is_loaded()); + } + + #[test] + fn test_sparse_embedding_to_dense() { + let sparse = SparseEmbedding::new(vec![0, 2, 5], vec![1.0, 2.0, 3.0]); + let dense = sparse.to_dense(10); + assert_eq!(dense.len(), 10); + assert_eq!(dense[0], 1.0); + assert_eq!(dense[2], 2.0); + assert_eq!(dense[5], 3.0); + assert_eq!(dense[1], 0.0); + } } diff --git a/src/core/image_utils.rs b/src/core/image_utils.rs new file mode 100644 index 0000000..a1f2d16 --- /dev/null +++ b/src/core/image_utils.rs @@ -0,0 +1,287 @@ +// Image loading utilities for image embedding support + +use crate::core::types::ImageInput; +use base64::Engine; +use std::path::Path; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum ImageError { + #[error("Failed to decode base64 image: {0}")] + Base64DecodeError(String), + + #[error("Failed to load image from file: {0}")] + FileLoadError(String), + + #[error("Failed to load image from S3: {0}")] + S3LoadError(String), + + #[error("Invalid image format: {0}")] + InvalidFormat(String), + + #[error("Image input is empty or invalid")] + EmptyInput, + + #[error("IO error: {0}")] + IoError(#[from] std::io::Error), +} + +/// Load image bytes from various input sources +pub fn load_image_bytes(input: &ImageInput) -> Result, ImageError> { + match input { + ImageInput::Base64 { base64 } => decode_base64_image(base64), + ImageInput::FilePath { path } => load_image_from_file(path), + ImageInput::S3Path { .. } => { + // S3 loading requires async, so we return an error here + // The async version should be used for S3 + Err(ImageError::S3LoadError( + "Use load_image_bytes_async for S3 paths".to_string(), + )) + } + } +} + +/// Decode base64 encoded image data +pub fn decode_base64_image(data: &str) -> Result, ImageError> { + // Handle data URL format (e.g., "data:image/png;base64,...") + let base64_data = if data.contains(",") { + data.split(",").last().unwrap_or(data) + } else { + data + }; + + // Remove whitespace + let cleaned: String = base64_data.chars().filter(|c| !c.is_whitespace()).collect(); + + if cleaned.is_empty() { + return Err(ImageError::EmptyInput); + } + + base64::engine::general_purpose::STANDARD + .decode(&cleaned) + .map_err(|e| ImageError::Base64DecodeError(e.to_string())) +} + +/// Load image from local file path +pub fn load_image_from_file(path: &str) -> Result, ImageError> { + let path = Path::new(path); + + if !path.exists() { + return Err(ImageError::FileLoadError(format!( + "File not found: {}", + path.display() + ))); + } + + std::fs::read(path).map_err(|e| ImageError::FileLoadError(e.to_string())) +} + +/// Load multiple images from various input sources (sync, no S3) +pub fn load_images_bytes(inputs: &[ImageInput]) -> Result>, ImageError> { + inputs.iter().map(load_image_bytes).collect() +} + +// ============================================================================ +// Async S3 support (requires aws feature) +// ============================================================================ + +#[cfg(feature = "aws")] +pub mod s3 { + use super::*; + use aws_sdk_s3::Client as S3Client; + + /// Load image bytes from S3 + pub async fn load_image_from_s3( + client: &S3Client, + bucket: &str, + key: &str, + ) -> Result, ImageError> { + let response = client + .get_object() + .bucket(bucket) + .key(key) + .send() + .await + .map_err(|e| ImageError::S3LoadError(e.to_string()))?; + + let bytes = response + .body + .collect() + .await + .map_err(|e| ImageError::S3LoadError(e.to_string()))? + .into_bytes() + .to_vec(); + + Ok(bytes) + } + + /// Parse S3 path in format "bucket/key" or "s3://bucket/key" + pub fn parse_s3_path(path: &str) -> Result<(String, String), ImageError> { + let path = path.strip_prefix("s3://").unwrap_or(path); + + let parts: Vec<&str> = path.splitn(2, '/').collect(); + if parts.len() != 2 { + return Err(ImageError::S3LoadError(format!( + "Invalid S3 path format: {}. Expected: bucket/key", + path + ))); + } + + Ok((parts[0].to_string(), parts[1].to_string())) + } + + /// Load image bytes from ImageInput (async, supports S3) + pub async fn load_image_bytes_async( + input: &ImageInput, + s3_client: Option<&S3Client>, + ) -> Result, ImageError> { + match input { + ImageInput::Base64 { base64 } => decode_base64_image(base64), + ImageInput::FilePath { path } => load_image_from_file(path), + ImageInput::S3Path { s3_path } => { + let client = s3_client.ok_or_else(|| { + ImageError::S3LoadError("S3 client not provided".to_string()) + })?; + let (bucket, key) = parse_s3_path(s3_path)?; + load_image_from_s3(client, &bucket, &key).await + } + } + } + + /// Load multiple images (async, supports S3) + pub async fn load_images_bytes_async( + inputs: &[ImageInput], + s3_client: Option<&S3Client>, + ) -> Result>, ImageError> { + let mut results = Vec::with_capacity(inputs.len()); + for input in inputs { + results.push(load_image_bytes_async(input, s3_client).await?); + } + Ok(results) + } +} + +// ============================================================================ +// Image validation utilities +// ============================================================================ + +/// Check if bytes look like a valid image based on magic bytes +pub fn is_valid_image_bytes(bytes: &[u8]) -> bool { + if bytes.len() < 8 { + return false; + } + + // PNG: 89 50 4E 47 0D 0A 1A 0A + if bytes.starts_with(&[0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A]) { + return true; + } + + // JPEG: FF D8 FF + if bytes.starts_with(&[0xFF, 0xD8, 0xFF]) { + return true; + } + + // GIF: GIF87a or GIF89a + if bytes.starts_with(b"GIF87a") || bytes.starts_with(b"GIF89a") { + return true; + } + + // WebP: RIFF....WEBP + if bytes.len() >= 12 && bytes.starts_with(b"RIFF") && &bytes[8..12] == b"WEBP" { + return true; + } + + // BMP: BM + if bytes.starts_with(b"BM") { + return true; + } + + false +} + +/// Get image format from magic bytes +pub fn detect_image_format(bytes: &[u8]) -> Option<&'static str> { + if bytes.len() < 8 { + return None; + } + + if bytes.starts_with(&[0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A]) { + return Some("png"); + } + if bytes.starts_with(&[0xFF, 0xD8, 0xFF]) { + return Some("jpeg"); + } + if bytes.starts_with(b"GIF87a") || bytes.starts_with(b"GIF89a") { + return Some("gif"); + } + if bytes.len() >= 12 && bytes.starts_with(b"RIFF") && &bytes[8..12] == b"WEBP" { + return Some("webp"); + } + if bytes.starts_with(b"BM") { + return Some("bmp"); + } + + None +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_decode_base64_simple() { + // A minimal valid PNG (1x1 transparent pixel) + let png_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="; + let result = decode_base64_image(png_base64); + assert!(result.is_ok()); + + let bytes = result.unwrap(); + assert!(is_valid_image_bytes(&bytes)); + assert_eq!(detect_image_format(&bytes), Some("png")); + } + + #[test] + fn test_decode_base64_data_url() { + let data_url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="; + let result = decode_base64_image(data_url); + assert!(result.is_ok()); + } + + #[test] + fn test_empty_base64() { + let result = decode_base64_image(""); + assert!(matches!(result, Err(ImageError::EmptyInput))); + } + + #[test] + fn test_invalid_base64() { + let result = decode_base64_image("not-valid-base64!!!"); + assert!(matches!(result, Err(ImageError::Base64DecodeError(_)))); + } + + #[test] + fn test_image_input_constructors() { + let base64 = ImageInput::from_base64("test".to_string()); + assert!(matches!(base64, ImageInput::Base64 { .. })); + + let path = ImageInput::from_path("/path/to/image.png"); + assert!(matches!(path, ImageInput::FilePath { .. })); + + let s3 = ImageInput::from_s3("bucket/key"); + assert!(matches!(s3, ImageInput::S3Path { .. })); + } + + #[cfg(feature = "aws")] + #[test] + fn test_parse_s3_path() { + use s3::parse_s3_path; + + let (bucket, key) = parse_s3_path("my-bucket/path/to/image.png").unwrap(); + assert_eq!(bucket, "my-bucket"); + assert_eq!(key, "path/to/image.png"); + + let (bucket, key) = parse_s3_path("s3://my-bucket/image.jpg").unwrap(); + assert_eq!(bucket, "my-bucket"); + assert_eq!(key, "image.jpg"); + } +} diff --git a/src/core/mod.rs b/src/core/mod.rs index 19199fb..bcc579e 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -6,6 +6,7 @@ pub mod similarity; pub mod clustering; pub mod chunking; pub mod types; +pub mod image_utils; #[cfg(feature = "pdf")] pub mod pdf; @@ -13,8 +14,30 @@ pub mod pdf; // Re-export model types pub use model::{ModelType, ModelInfo, ModelRegistry, ModelCategory, TextModel, MODEL_REGISTRY}; -// Re-export embedding service -pub use embeddings::{EmbeddingService, EmbeddingError, global_service, embed, embed_one}; +// Re-export embedding services +pub use embeddings::{ + // Text embedding (existing) + EmbeddingService, EmbeddingError, global_service, embed, embed_one, + // Image embedding (new) + ImageEmbeddingService, + // Sparse embedding (new) + SparseEmbeddingService, + // Reranking (new) + RerankService, + // Unified service (new) + UnifiedEmbeddingService, UnifiedModel, +}; + +// Re-export image utilities +pub use image_utils::{ + ImageError, load_image_bytes, decode_base64_image, load_image_from_file, + load_images_bytes, is_valid_image_bytes, detect_image_format, +}; + +#[cfg(feature = "aws")] +pub use image_utils::s3::{ + load_image_from_s3, parse_s3_path, load_image_bytes_async, load_images_bytes_async, +}; // Re-export similarity functions pub use similarity::{ @@ -43,6 +66,12 @@ pub use types::{ EmbeddingOutput, SearchResult, SearchResponse, ClusterInfo, ClusterMember, ClusterResponse, DistanceMatrixResponse, BenchmarkResult, + // Image types (new) + ImageInput, ImageEmbeddingOutput, + // Sparse types (new) + SparseEmbedding, SparseEmbeddingOutput, + // Rerank types (new) + RerankResult, RerankOutput, }; // Re-export PDF utilities (when feature enabled) diff --git a/src/core/types.rs b/src/core/types.rs index b93ffba..eab1670 100644 --- a/src/core/types.rs +++ b/src/core/types.rs @@ -2,6 +2,201 @@ use serde::{Deserialize, Serialize}; +// ============================================================================ +// Image Embedding Types +// ============================================================================ + +/// Image input - supports both base64 encoded data and file paths +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ImageInput { + /// Base64 encoded image data + Base64 { base64: String }, + /// Path to image file (local or S3) + FilePath { path: String }, + /// S3 path to image + S3Path { s3_path: String }, +} + +impl ImageInput { + /// Create from base64 string + pub fn from_base64(data: String) -> Self { + ImageInput::Base64 { base64: data } + } + + /// Create from file path + pub fn from_path(path: impl Into) -> Self { + ImageInput::FilePath { path: path.into() } + } + + /// Create from S3 path + pub fn from_s3(s3_path: impl Into) -> Self { + ImageInput::S3Path { s3_path: s3_path.into() } + } +} + +/// Image embedding output +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ImageEmbeddingOutput { + /// The embedding vectors + pub embeddings: Vec>, + /// Dimension of each embedding + pub dimension: usize, + /// Number of embeddings + pub count: usize, + /// Model used + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, +} + +impl ImageEmbeddingOutput { + pub fn new(embeddings: Vec>) -> Self { + let dimension = embeddings.first().map(|e| e.len()).unwrap_or(0); + let count = embeddings.len(); + ImageEmbeddingOutput { + embeddings, + dimension, + count, + model: None, + } + } + + pub fn with_model(mut self, model: &str) -> Self { + self.model = Some(model.to_string()); + self + } +} + +// ============================================================================ +// Sparse Embedding Types +// ============================================================================ + +/// Sparse embedding with indices and values +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SparseEmbedding { + /// Token indices with non-zero values + pub indices: Vec, + /// Values at those indices + pub values: Vec, +} + +impl SparseEmbedding { + pub fn new(indices: Vec, values: Vec) -> Self { + SparseEmbedding { indices, values } + } + + /// Get the number of non-zero elements + pub fn nnz(&self) -> usize { + self.indices.len() + } + + /// Convert to dense vector of given size + pub fn to_dense(&self, size: usize) -> Vec { + let mut dense = vec![0.0; size]; + for (idx, val) in self.indices.iter().zip(self.values.iter()) { + if *idx < size { + dense[*idx] = *val; + } + } + dense + } +} + +/// Sparse embedding output with metadata +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SparseEmbeddingOutput { + /// Sparse embeddings + pub embeddings: Vec, + /// Number of embeddings + pub count: usize, + /// Model used + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, +} + +impl SparseEmbeddingOutput { + pub fn new(embeddings: Vec) -> Self { + let count = embeddings.len(); + SparseEmbeddingOutput { + embeddings, + count, + model: None, + } + } + + pub fn with_model(mut self, model: &str) -> Self { + self.model = Some(model.to_string()); + self + } +} + +// ============================================================================ +// Reranking Types +// ============================================================================ + +/// Single rerank result +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RerankResult { + /// Original index in the documents array + pub index: usize, + /// Relevance score (higher is more relevant) + pub score: f32, + /// The document text + #[serde(skip_serializing_if = "Option::is_none")] + pub document: Option, +} + +impl RerankResult { + pub fn new(index: usize, score: f32) -> Self { + RerankResult { + index, + score, + document: None, + } + } + + pub fn with_document(mut self, document: String) -> Self { + self.document = Some(document); + self + } +} + +/// Rerank response with all results +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RerankOutput { + /// The query used for reranking + pub query: String, + /// Ranked results (sorted by score, descending) + pub results: Vec, + /// Total number of documents reranked + pub count: usize, + /// Model used + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, +} + +impl RerankOutput { + pub fn new(query: String, results: Vec) -> Self { + let count = results.len(); + RerankOutput { + query, + results, + count, + model: None, + } + } + + pub fn with_model(mut self, model: &str) -> Self { + self.model = Some(model.to_string()); + self + } + + /// Get top K results + pub fn top_k(&self, k: usize) -> Vec<&RerankResult> { + self.results.iter().take(k).collect() + } +} + /// Embedding output with metadata #[derive(Debug, Clone, Serialize, Deserialize)] pub struct EmbeddingOutput { diff --git a/src/lambda.rs b/src/lambda.rs index e7f9742..f7eb758 100644 --- a/src/lambda.rs +++ b/src/lambda.rs @@ -1,39 +1,110 @@ // Lambda-specific handler and AWS integration -use crate::core::model::ModelRegistry; -use crate::core::{EmbeddingService, ModelType}; +use crate::core::model::{ModelCategory, ModelRegistry}; +use crate::core::{ + EmbeddingService, ImageEmbeddingService, ImageInput, RerankService, SparseEmbeddingService, +}; use aws_config; use aws_sdk_s3 as s3; use lambda_runtime::{Error, LambdaEvent}; use serde::{Deserialize, Serialize}; -use serde_json::{Value, json}; +use serde_json::{json, Value}; use std::env; use std::sync::LazyLock; -/// Lambda request structure -#[derive(Serialize, Deserialize, Clone)] +// ============================================================================ +// Request Types +// ============================================================================ + +/// Lambda request structure - supports all model types +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct Request { + // Text embedding input (existing) pub messages: Option>, + + // Image embedding input (new) + pub images: Option>, + + // Reranking input (new) + pub query: Option, + pub documents: Option>, + + // S3 input (existing, extended for images) pub s3_file: Option, + pub s3_images: Option>, + + // S3 output (existing) pub save_to_s3: Option, + + // Optional top_k for reranking + pub top_k: Option, + + // Include documents in rerank response + pub return_documents: Option, +} + +/// Image input for request - base64 or S3 path +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct ImageInputRequest { + pub base64: Option, + pub s3_path: Option, } /// S3 save configuration -#[derive(Serialize, Deserialize, Clone)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct SaveConfig { pub bucket: String, pub key: String, } -/// Lambda response structure +// ============================================================================ +// Response Types +// ============================================================================ + +/// Lambda response structure - supports all model types #[derive(Serialize, Deserialize, Debug)] pub struct Response { - pub embeddings: Vec>, - pub dimension: usize, + // Dense embeddings (text/image) + #[serde(skip_serializing_if = "Option::is_none")] + pub embeddings: Option>>, + + // Sparse embeddings + #[serde(skip_serializing_if = "Option::is_none")] + pub sparse_embeddings: Option>, + + // Rerank results + #[serde(skip_serializing_if = "Option::is_none")] + pub rankings: Option>, + + // Metadata + #[serde(skip_serializing_if = "Option::is_none")] + pub dimension: Option, + + pub model_type: String, + + #[serde(skip_serializing_if = "Option::is_none")] + pub count: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub s3_location: Option, } +/// Sparse embedding response format +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct SparseEmbeddingResponse { + pub indices: Vec, + pub values: Vec, +} + +/// Rerank response format +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct RerankResponse { + pub index: usize, + pub score: f32, + #[serde(skip_serializing_if = "Option::is_none")] + pub document: Option, +} + /// API Gateway response wrapper #[derive(Serialize, Deserialize)] pub struct ApiGatewayResponse { @@ -43,32 +114,51 @@ pub struct ApiGatewayResponse { pub body: String, } -// Global embedding service for Lambda (initialized once per cold start) -static EMBEDDING_SERVICE: LazyLock = LazyLock::new(|| EmbeddingService::new()); +// ============================================================================ +// Global Services (initialized once per cold start) +// ============================================================================ -/// Get the model type from environment or default -fn get_model_type() -> ModelType { - let model_str = env::var("MODEL_ID").unwrap_or_else(|_| "Xenova/bge-small-zh-v1.5".to_string()); +// Text embedding service +static TEXT_SERVICE: LazyLock = LazyLock::new(|| EmbeddingService::new()); - // Try to parse as legacy ModelType first - if let Some(model_type) = ModelType::from_str(&model_str) { - return model_type; - } +// Image embedding service +static IMAGE_SERVICE: LazyLock = + LazyLock::new(|| ImageEmbeddingService::new()); - // Check if it's a valid model in fastembed's registry - if ModelRegistry::find_text_model(&model_str).is_some() { - eprintln!( - "Note: Model '{}' is supported by fastembed but not in legacy ModelType. Using BGE-Small as fallback.", - model_str - ); +// Sparse embedding service +static SPARSE_SERVICE: LazyLock = + LazyLock::new(|| SparseEmbeddingService::new()); + +// Rerank service +static RERANK_SERVICE: LazyLock = LazyLock::new(|| RerankService::new()); + +// ============================================================================ +// Model Detection +// ============================================================================ + +/// Get the model ID from environment +fn get_model_id() -> String { + env::var("MODEL_ID").unwrap_or_else(|_| "Xenova/bge-small-en-v1.5".to_string()) +} + +/// Detect model category from MODEL_ID environment variable +fn detect_model_category(model_id: &str) -> ModelCategory { + if ModelRegistry::find_text_model(model_id).is_some() { + ModelCategory::TextEmbedding + } else if ModelRegistry::find_image_model(model_id).is_some() { + ModelCategory::ImageEmbedding + } else if ModelRegistry::find_sparse_model(model_id).is_some() { + ModelCategory::SparseTextEmbedding + } else if ModelRegistry::find_rerank_model(model_id).is_some() { + ModelCategory::TextRerank } else { + // Default to text embedding for unknown models eprintln!( - "Warning: Unknown model '{}', falling back to BGE-Small", - model_str + "Warning: Unknown model '{}', defaulting to TextEmbedding category", + model_id ); + ModelCategory::TextEmbedding } - - ModelType::default() } async fn read_from_s3(s3_client: &s3::Client, s3_path: &str) -> Result { @@ -135,9 +225,27 @@ fn create_api_gateway_response(status_code: u16, body: Value) -> ApiGatewayRespo } } +// ============================================================================ +// Request Processing - Routes to appropriate handler based on model type +// ============================================================================ + async fn process_request(request: Request) -> Result { - let model_type = get_model_type(); + let model_id = get_model_id(); + let category = detect_model_category(&model_id); + + match category { + ModelCategory::TextEmbedding => process_text_embedding(request, &model_id).await, + ModelCategory::ImageEmbedding => process_image_embedding(request, &model_id).await, + ModelCategory::SparseTextEmbedding => process_sparse_embedding(request, &model_id).await, + ModelCategory::TextRerank => process_rerank(request, &model_id).await, + } +} + +// ============================================================================ +// Text Embedding Handler +// ============================================================================ +async fn process_text_embedding(request: Request, model_id: &str) -> Result { // Get texts to embed let texts = if let Some(messages) = &request.messages { if messages.is_empty() { @@ -155,15 +263,20 @@ async fn process_request(request: Request) -> Result { Err(_) => vec![content], } } else { - return Err("Either 'messages' or 's3_file' must be provided".into()); + return Err("Either 'messages' or 's3_file' must be provided for text embedding".into()); }; + // Get the fastembed model enum + let model = ModelRegistry::find_text_model(model_id) + .ok_or_else(|| format!("Unknown text model: {}", model_id))?; + // Generate embeddings using shared service - let embeddings = EMBEDDING_SERVICE - .embed(texts, model_type) - .map_err(|e| format!("Embedding failed: {}", e))?; + let embeddings = TEXT_SERVICE + .embed_with_model(texts, model) + .map_err(|e| format!("Text embedding failed: {}", e))?; let dimension = embeddings.first().map(|e| e.len()).unwrap_or(0); + let count = embeddings.len(); // Optionally save to S3 let s3_location = if let Some(save_config) = &request.save_to_s3 { @@ -184,12 +297,224 @@ async fn process_request(request: Request) -> Result { }; Ok(Response { - embeddings, - dimension, + embeddings: Some(embeddings), + sparse_embeddings: None, + rankings: None, + dimension: Some(dimension), + model_type: "text".to_string(), + count: Some(count), s3_location, }) } +// ============================================================================ +// Image Embedding Handler +// ============================================================================ + +async fn process_image_embedding(request: Request, model_id: &str) -> Result { + use crate::core::image_utils::s3::load_image_bytes_async; + + // Check if we need S3 client (any S3 paths in request) + let has_s3_paths = request.s3_images.is_some() + || request + .images + .as_ref() + .map(|imgs| imgs.iter().any(|i| i.s3_path.is_some())) + .unwrap_or(false); + + let s3_client = if has_s3_paths { + let config = aws_config::load_from_env().await; + Some(s3::Client::new(&config)) + } else { + None + }; + + // Collect image inputs + let mut image_inputs: Vec = Vec::new(); + + // From direct image inputs + if let Some(images) = &request.images { + for img in images { + if let Some(base64_data) = &img.base64 { + image_inputs.push(ImageInput::from_base64(base64_data.clone())); + } else if let Some(s3_path) = &img.s3_path { + image_inputs.push(ImageInput::from_s3(s3_path.clone())); + } + } + } + + // From S3 image paths + if let Some(s3_images) = &request.s3_images { + for s3_path in s3_images { + image_inputs.push(ImageInput::from_s3(s3_path.clone())); + } + } + + if image_inputs.is_empty() { + return Err("Either 'images' or 's3_images' must be provided for image embedding".into()); + } + + // Load all image bytes + let mut image_bytes_list: Vec> = Vec::new(); + for input in &image_inputs { + let bytes = load_image_bytes_async(input, s3_client.as_ref()) + .await + .map_err(|e| format!("Failed to load image: {}", e))?; + image_bytes_list.push(bytes); + } + + // Get the fastembed model enum + let model = ModelRegistry::find_image_model(model_id) + .ok_or_else(|| format!("Unknown image model: {}", model_id))?; + + // Generate embeddings + let embeddings = IMAGE_SERVICE + .embed_images_with_model(&image_bytes_list, model) + .map_err(|e| format!("Image embedding failed: {}", e))?; + + let dimension = embeddings.first().map(|e| e.len()).unwrap_or(0); + let count = embeddings.len(); + + // Optionally save to S3 + let s3_location = if let Some(save_config) = &request.save_to_s3 { + let client = if let Some(ref c) = s3_client { + c.clone() + } else { + let config = aws_config::load_from_env().await; + s3::Client::new(&config) + }; + + save_to_s3(&client, &save_config.bucket, &save_config.key, &embeddings).await?; + + Some(format!("s3://{}/{}", save_config.bucket, save_config.key)) + } else { + None + }; + + Ok(Response { + embeddings: Some(embeddings), + sparse_embeddings: None, + rankings: None, + dimension: Some(dimension), + model_type: "image".to_string(), + count: Some(count), + s3_location, + }) +} + +// ============================================================================ +// Sparse Text Embedding Handler +// ============================================================================ + +async fn process_sparse_embedding(request: Request, model_id: &str) -> Result { + // Get texts to embed + let texts = if let Some(messages) = &request.messages { + if messages.is_empty() { + return Err("Messages array cannot be empty".into()); + } + messages.clone() + } else if let Some(s3_path) = &request.s3_file { + let config = aws_config::load_from_env().await; + let s3_client = s3::Client::new(&config); + let content = read_from_s3(&s3_client, s3_path).await?; + + match serde_json::from_str::>(&content) { + Ok(texts) => texts, + Err(_) => vec![content], + } + } else { + return Err("Either 'messages' or 's3_file' must be provided for sparse embedding".into()); + }; + + // Get the fastembed model enum + let model = ModelRegistry::find_sparse_model(model_id) + .ok_or_else(|| format!("Unknown sparse model: {}", model_id))?; + + // Generate sparse embeddings + let sparse_embeddings = SPARSE_SERVICE + .embed_with_model(texts, model) + .map_err(|e| format!("Sparse embedding failed: {}", e))?; + + let count = sparse_embeddings.len(); + + // Convert to response format + let sparse_responses: Vec = sparse_embeddings + .into_iter() + .map(|se| SparseEmbeddingResponse { + indices: se.indices, + values: se.values, + }) + .collect(); + + Ok(Response { + embeddings: None, + sparse_embeddings: Some(sparse_responses), + rankings: None, + dimension: None, + model_type: "sparse".to_string(), + count: Some(count), + s3_location: None, + }) +} + +// ============================================================================ +// Reranking Handler +// ============================================================================ + +async fn process_rerank(request: Request, model_id: &str) -> Result { + let query = request + .query + .ok_or("'query' is required for reranking")?; + + let documents = request + .documents + .ok_or("'documents' is required for reranking")?; + + if documents.is_empty() { + return Err("'documents' array cannot be empty".into()); + } + + let return_documents = request.return_documents.unwrap_or(true); + + // Get the fastembed model enum + let model = ModelRegistry::find_rerank_model(model_id) + .ok_or_else(|| format!("Unknown rerank model: {}", model_id))?; + + // Perform reranking + let results = RERANK_SERVICE + .rerank_with_model(&query, documents.clone(), return_documents, model) + .map_err(|e| format!("Reranking failed: {}", e))?; + + // Apply top_k if specified + let results = if let Some(top_k) = request.top_k { + results.into_iter().take(top_k).collect() + } else { + results + }; + + let count = results.len(); + + // Convert to response format + let rankings: Vec = results + .into_iter() + .map(|r| RerankResponse { + index: r.index, + score: r.score, + document: r.document, + }) + .collect(); + + Ok(Response { + embeddings: None, + sparse_embeddings: None, + rankings: Some(rankings), + dimension: None, + model_type: "rerank".to_string(), + count: Some(count), + s3_location: None, + }) +} + /// Main Lambda handler pub async fn handler(event: LambdaEvent) -> Result { if is_api_gateway_request(&event) { diff --git a/src/lib.rs b/src/lib.rs index 04eae15..4205233 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,8 +11,19 @@ mod lambda; pub use core::{ // Model types ModelType, ModelInfo, ModelRegistry, ModelCategory, TextModel, MODEL_REGISTRY, - // Embedding service + // Text embedding service EmbeddingService, EmbeddingError, global_service, embed, embed_one, + // Image embedding service + ImageEmbeddingService, + // Sparse embedding service + SparseEmbeddingService, + // Reranking service + RerankService, + // Unified service + UnifiedEmbeddingService, UnifiedModel, + // Image utilities + ImageError, load_image_bytes, decode_base64_image, load_image_from_file, + load_images_bytes, is_valid_image_bytes, detect_image_format, // Similarity functions cosine_similarity, dot_product, euclidean_distance, l2_normalize, magnitude, pairwise_similarity_matrix, @@ -26,10 +37,16 @@ pub use core::{ chunk_text, chunk_text_with_config, chunk_text_detailed, chunk_by_sentences, estimate_tokens, ChunkConfig, ChunkResult, - // Types + // Text embedding types EmbeddingOutput, SearchResult, SearchResponse, ClusterInfo, ClusterMember, ClusterResponse, DistanceMatrixResponse, BenchmarkResult, + // Image embedding types + ImageInput, ImageEmbeddingOutput, + // Sparse embedding types + SparseEmbedding, SparseEmbeddingOutput, + // Rerank types + RerankResult, RerankOutput, }; // Re-export PDF utilities (when feature enabled) @@ -43,4 +60,7 @@ pub use core::{ // Re-export Lambda handler (when feature enabled) #[cfg(feature = "aws")] -pub use lambda::{handler, Request, Response, SaveConfig, ApiGatewayResponse}; +pub use lambda::{ + handler, Request, Response, SaveConfig, ApiGatewayResponse, + ImageInputRequest, SparseEmbeddingResponse, RerankResponse, +}; diff --git a/tests/integration_image_tests.rs b/tests/integration_image_tests.rs new file mode 100644 index 0000000..1fe9d83 --- /dev/null +++ b/tests/integration_image_tests.rs @@ -0,0 +1,323 @@ +// tests/integration_image_tests.rs +// +// Image embedding integration tests +// MODEL_ID is set to an image embedding model + +use aws_config::{BehaviorVersion, Region}; +use aws_sdk_s3 as s3; +use aws_sdk_s3::config::Credentials; +use base64::{engine::general_purpose::STANDARD, Engine}; +use lambda_runtime::{Context, LambdaEvent}; +use serde_json::Value; +use serverless_vectorizer::{handler, ImageInputRequest, Request, Response}; + +const MODEL_ID: &str = "Qdrant/clip-ViT-B-32-vision"; +const TEST_BUCKET: &str = "test-images-bucket"; + +// Setup environment for tests +fn setup_env() { + let endpoint_url = + std::env::var("LOCALSTACK_ENDPOINT").unwrap_or_else(|_| "http://127.0.0.1:4566".to_string()); + + unsafe { + std::env::set_var("AWS_ENDPOINT_URL", &endpoint_url); + std::env::set_var("AWS_ACCESS_KEY_ID", "test"); + std::env::set_var("AWS_SECRET_ACCESS_KEY", "test"); + std::env::set_var("AWS_REGION", "us-east-1"); + std::env::set_var("MODEL_ID", MODEL_ID); + } + + // Print diagnostic info for CI debugging + eprintln!("[TEST] MODEL_ID set to: {}", MODEL_ID); + eprintln!("[TEST] AWS_ENDPOINT_URL: {}", endpoint_url); +} + +async fn create_s3_client() -> s3::Client { + let endpoint_url = + std::env::var("LOCALSTACK_ENDPOINT").unwrap_or_else(|_| "http://127.0.0.1:4566".to_string()); + + let creds = Credentials::new("test", "test", None, None, "test"); + + let config = aws_config::defaults(BehaviorVersion::latest()) + .region(Region::new("us-east-1")) + .credentials_provider(creds) + .endpoint_url(&endpoint_url) + .load() + .await; + + let s3_config = s3::config::Builder::from(&config) + .force_path_style(true) + .build(); + + s3::Client::from_conf(s3_config) +} + +fn create_lambda_event(request: Request) -> LambdaEvent { + let payload = serde_json::to_value(request).expect("Failed to serialize request"); + LambdaEvent { + payload, + context: Context::default(), + } +} + +fn create_image_request(images: Option>, s3_images: Option>) -> Request { + Request { + messages: None, + images, + query: None, + documents: None, + s3_file: None, + s3_images, + save_to_s3: None, + top_k: None, + return_documents: None, + } +} + +fn parse_response(value: Value) -> Result { + serde_json::from_value(value).map_err(|e| format!("Failed to parse response: {}", e)) +} + +async fn ensure_bucket_exists(s3_client: &s3::Client, bucket: &str) -> Result<(), Box> { + match s3_client.head_bucket().bucket(bucket).send().await { + Ok(_) => Ok(()), + Err(_) => { + s3_client.create_bucket().bucket(bucket).send().await?; + Ok(()) + } + } +} + +async fn upload_binary_to_s3( + s3_client: &s3::Client, + bucket: &str, + key: &str, + content: Vec, + content_type: &str, +) -> Result<(), Box> { + s3_client + .put_object() + .bucket(bucket) + .key(key) + .body(content.into()) + .content_type(content_type) + .send() + .await?; + Ok(()) +} + +// Create a minimal valid PNG image (1x1 transparent pixel) +fn create_test_png() -> Vec { + base64::engine::general_purpose::STANDARD + .decode("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==") + .expect("Failed to decode test PNG") +} + +#[tokio::test] +async fn test_image_embedding_base64() { + setup_env(); + + let png_bytes = create_test_png(); + eprintln!("[TEST] Created test PNG with {} bytes", png_bytes.len()); + + // Verify PNG magic bytes + if png_bytes.len() >= 8 { + eprintln!("[TEST] PNG header: {:02X?}", &png_bytes[..8]); + } + + let base64_image = STANDARD.encode(&png_bytes); + eprintln!("[TEST] Base64 encoded length: {}", base64_image.len()); + + let request = create_image_request( + Some(vec![ImageInputRequest { + base64: Some(base64_image), + s3_path: None, + }]), + None, + ); + + eprintln!("[TEST] Calling handler..."); + let event = create_lambda_event(request); + let result = handler(event).await; + + match &result { + Ok(value) => { + eprintln!("[TEST] Handler succeeded"); + eprintln!("[TEST] Response: {}", serde_json::to_string_pretty(value).unwrap_or_default()); + }, + Err(e) => { + eprintln!("[TEST] Handler error: {:?}", e); + eprintln!("[TEST] Error details: {}", e); + } + } + assert!(result.is_ok(), "Handler should succeed for image embedding: {:?}", result.err()); + + let response = parse_response(result.unwrap()).unwrap(); + assert_eq!(response.model_type, "image", "Should be image model type, got: {}", response.model_type); + assert!(response.embeddings.is_some(), "Should have embeddings"); + + let embeddings = response.embeddings.unwrap(); + assert_eq!(embeddings.len(), 1, "Should have one embedding"); + assert!(!embeddings[0].is_empty(), "Embedding should not be empty"); + assert_eq!(response.dimension.unwrap_or(0), 512, "CLIP ViT-B/32 has 512 dimensions"); +} + +#[tokio::test] +async fn test_image_embedding_from_s3() { + setup_env(); + let s3_client = create_s3_client().await; + ensure_bucket_exists(&s3_client, TEST_BUCKET).await.unwrap(); + + let png_bytes = create_test_png(); + let s3_key = "test-images/test.png"; + + upload_binary_to_s3(&s3_client, TEST_BUCKET, s3_key, png_bytes, "image/png") + .await + .unwrap(); + + let request = create_image_request( + None, + Some(vec![format!("{}/{}", TEST_BUCKET, s3_key)]), + ); + + let event = create_lambda_event(request); + let result = handler(event).await; + + assert!(result.is_ok(), "Handler should succeed for S3 image"); + + let response = parse_response(result.unwrap()).unwrap(); + assert_eq!(response.model_type, "image"); + assert!(response.embeddings.is_some()); +} + +#[tokio::test] +async fn test_image_embedding_multiple_base64() { + setup_env(); + + let png_bytes = create_test_png(); + let base64_image = STANDARD.encode(&png_bytes); + + let request = create_image_request( + Some(vec![ + ImageInputRequest { + base64: Some(base64_image.clone()), + s3_path: None, + }, + ImageInputRequest { + base64: Some(base64_image), + s3_path: None, + }, + ]), + None, + ); + + let event = create_lambda_event(request); + let result = handler(event).await; + + assert!(result.is_ok(), "Handler should succeed for multiple images"); + + let response = parse_response(result.unwrap()).unwrap(); + assert_eq!(response.model_type, "image"); + + let embeddings = response.embeddings.unwrap(); + assert_eq!(embeddings.len(), 2, "Should have 2 embeddings"); +} + +#[tokio::test] +async fn test_image_embedding_mixed_input() { + setup_env(); + let s3_client = create_s3_client().await; + ensure_bucket_exists(&s3_client, TEST_BUCKET).await.unwrap(); + + let png_bytes = create_test_png(); + let s3_key = "test-images/mixed-test.png"; + + upload_binary_to_s3(&s3_client, TEST_BUCKET, s3_key, png_bytes.clone(), "image/png") + .await + .unwrap(); + + let base64_image = STANDARD.encode(&png_bytes); + + let request = Request { + messages: None, + images: Some(vec![ImageInputRequest { + base64: Some(base64_image), + s3_path: None, + }]), + query: None, + documents: None, + s3_file: None, + s3_images: Some(vec![format!("{}/{}", TEST_BUCKET, s3_key)]), + save_to_s3: None, + top_k: None, + return_documents: None, + }; + + let event = create_lambda_event(request); + let result = handler(event).await; + + assert!(result.is_ok(), "Handler should succeed for mixed input"); + + let response = parse_response(result.unwrap()).unwrap(); + assert_eq!(response.model_type, "image"); + + let embeddings = response.embeddings.unwrap(); + assert_eq!(embeddings.len(), 2, "Should have embeddings for both images"); +} + +#[tokio::test] +async fn test_image_embedding_empty_input() { + setup_env(); + + let request = create_image_request(None, None); + + let event = create_lambda_event(request); + let result = handler(event).await; + + assert!(result.is_err(), "Should fail with no images"); +} + +#[tokio::test] +async fn test_image_embedding_consistency() { + setup_env(); + + let png_bytes = create_test_png(); + let base64_image = STANDARD.encode(&png_bytes); + + let request1 = create_image_request( + Some(vec![ImageInputRequest { + base64: Some(base64_image.clone()), + s3_path: None, + }]), + None, + ); + + let request2 = create_image_request( + Some(vec![ImageInputRequest { + base64: Some(base64_image), + s3_path: None, + }]), + None, + ); + + let event1 = create_lambda_event(request1); + let event2 = create_lambda_event(request2); + + let result1 = handler(event1).await; + let result2 = handler(event2).await; + + assert!(result1.is_ok() && result2.is_ok(), "Both requests should succeed"); + + let response1 = parse_response(result1.unwrap()).unwrap(); + let response2 = parse_response(result2.unwrap()).unwrap(); + + let emb1 = response1.embeddings.unwrap(); + let emb2 = response2.embeddings.unwrap(); + + let match_result = emb1[0] + .iter() + .zip(emb2[0].iter()) + .all(|(a, b)| (a - b).abs() < 1e-6); + + assert!(match_result, "Same image should produce identical embeddings"); +} diff --git a/tests/integration_rerank_tests.rs b/tests/integration_rerank_tests.rs new file mode 100644 index 0000000..cd5c760 --- /dev/null +++ b/tests/integration_rerank_tests.rs @@ -0,0 +1,300 @@ +// tests/integration_rerank_tests.rs +// +// Reranking integration tests +// MODEL_ID is set to a reranking model + +use lambda_runtime::{Context, LambdaEvent}; +use serde_json::Value; +use serverless_vectorizer::{handler, Request, Response}; + +const MODEL_ID: &str = "BAAI/bge-reranker-base"; + +// Setup environment for tests +fn setup_env() { + let endpoint_url = + std::env::var("LOCALSTACK_ENDPOINT").unwrap_or_else(|_| "http://127.0.0.1:4566".to_string()); + + unsafe { + std::env::set_var("AWS_ENDPOINT_URL", &endpoint_url); + std::env::set_var("AWS_ACCESS_KEY_ID", "test"); + std::env::set_var("AWS_SECRET_ACCESS_KEY", "test"); + std::env::set_var("AWS_REGION", "us-east-1"); + std::env::set_var("MODEL_ID", MODEL_ID); + } +} + +fn create_lambda_event(request: Request) -> LambdaEvent { + let payload = serde_json::to_value(request).expect("Failed to serialize request"); + LambdaEvent { + payload, + context: Context::default(), + } +} + +fn create_rerank_request(query: String, documents: Vec, top_k: Option) -> Request { + Request { + messages: None, + images: None, + query: Some(query), + documents: Some(documents), + s3_file: None, + s3_images: None, + save_to_s3: None, + top_k, + return_documents: Some(true), + } +} + +fn parse_response(value: Value) -> Result { + serde_json::from_value(value).map_err(|e| format!("Failed to parse response: {}", e)) +} + +#[tokio::test] +async fn test_rerank_basic() { + setup_env(); + + let request = create_rerank_request( + "What is machine learning?".to_string(), + vec![ + "Machine learning is a subset of AI that enables computers to learn from data.".to_string(), + "The weather today is sunny and warm.".to_string(), + "Deep learning uses neural networks with many layers.".to_string(), + ], + None, + ); + + let event = create_lambda_event(request); + let result = handler(event).await; + + assert!(result.is_ok(), "Handler should succeed for reranking"); + + let response = parse_response(result.unwrap()).unwrap(); + assert_eq!(response.model_type, "rerank", "Should be rerank model type"); + assert!(response.rankings.is_some(), "Should have rankings"); + + let rankings = response.rankings.unwrap(); + assert_eq!(rankings.len(), 3, "Should have rankings for all documents"); + + // Results should be sorted by score (descending) + for i in 1..rankings.len() { + assert!(rankings[i-1].score >= rankings[i].score, + "Rankings should be sorted by score descending"); + } + + // The ML-related document should rank higher than weather + let ml_doc_rank = rankings.iter().position(|r| r.index == 0).unwrap(); + let weather_doc_rank = rankings.iter().position(|r| r.index == 1).unwrap(); + assert!(ml_doc_rank < weather_doc_rank, + "ML document should rank higher than weather document"); +} + +#[tokio::test] +async fn test_rerank_with_top_k() { + setup_env(); + + let request = create_rerank_request( + "Capital cities".to_string(), + vec![ + "Paris is the capital of France.".to_string(), + "Pizza is a popular Italian food.".to_string(), + "London is the capital of England.".to_string(), + "Coffee comes from beans.".to_string(), + "Tokyo is the capital of Japan.".to_string(), + ], + Some(2), + ); + + let event = create_lambda_event(request); + let result = handler(event).await; + + assert!(result.is_ok(), "Handler should succeed"); + + let response = parse_response(result.unwrap()).unwrap(); + assert_eq!(response.model_type, "rerank"); + + let rankings = response.rankings.unwrap(); + assert_eq!(rankings.len(), 2, "Should only return top 2 results"); +} + +#[tokio::test] +async fn test_rerank_returns_documents() { + setup_env(); + + let docs = vec![ + "Document one about programming.".to_string(), + "Document two about cooking.".to_string(), + ]; + + let request = create_rerank_request( + "Programming tutorials".to_string(), + docs.clone(), + None, + ); + + let event = create_lambda_event(request); + let result = handler(event).await; + + assert!(result.is_ok(), "Handler should succeed"); + + let response = parse_response(result.unwrap()).unwrap(); + assert_eq!(response.model_type, "rerank"); + + let rankings = response.rankings.unwrap(); + + for ranking in &rankings { + assert!(ranking.document.is_some(), "Document should be included"); + let doc = ranking.document.as_ref().unwrap(); + assert!(docs.contains(doc), "Document should match original"); + } +} + +#[tokio::test] +async fn test_rerank_consistency() { + setup_env(); + + let query = "Technology news".to_string(); + let documents = vec![ + "Apple released new iPhone.".to_string(), + "Farmers harvested apples.".to_string(), + ]; + + let request1 = create_rerank_request(query.clone(), documents.clone(), None); + let request2 = create_rerank_request(query, documents, None); + + let event1 = create_lambda_event(request1); + let event2 = create_lambda_event(request2); + + let result1 = handler(event1).await; + let result2 = handler(event2).await; + + assert!(result1.is_ok() && result2.is_ok(), "Both requests should succeed"); + + let response1 = parse_response(result1.unwrap()).unwrap(); + let response2 = parse_response(result2.unwrap()).unwrap(); + + assert_eq!(response1.model_type, "rerank"); + assert_eq!(response2.model_type, "rerank"); + + let rankings1 = response1.rankings.unwrap(); + let rankings2 = response2.rankings.unwrap(); + + for (r1, r2) in rankings1.iter().zip(rankings2.iter()) { + assert_eq!(r1.index, r2.index, "Ranking order should be consistent"); + assert!((r1.score - r2.score).abs() < 1e-6, "Scores should be identical"); + } +} + +#[tokio::test] +async fn test_rerank_missing_query() { + setup_env(); + + let request = Request { + messages: None, + images: None, + query: None, + documents: Some(vec!["doc1".to_string()]), + s3_file: None, + s3_images: None, + save_to_s3: None, + top_k: None, + return_documents: None, + }; + + let event = create_lambda_event(request); + let result = handler(event).await; + + assert!(result.is_err(), "Should fail with missing query"); +} + +#[tokio::test] +async fn test_rerank_empty_documents() { + setup_env(); + + let request = create_rerank_request( + "Test query".to_string(), + vec![], + None, + ); + + let event = create_lambda_event(request); + let result = handler(event).await; + + assert!(result.is_err(), "Should fail with empty documents"); +} + +#[tokio::test] +async fn test_rerank_single_document() { + setup_env(); + + let request = create_rerank_request( + "Test query".to_string(), + vec!["Only one document here.".to_string()], + None, + ); + + let event = create_lambda_event(request); + let result = handler(event).await; + + assert!(result.is_ok(), "Handler should succeed with single document"); + + let response = parse_response(result.unwrap()).unwrap(); + assert_eq!(response.model_type, "rerank"); + + let rankings = response.rankings.unwrap(); + assert_eq!(rankings.len(), 1, "Should have one ranking"); + assert_eq!(rankings[0].index, 0); +} + +#[tokio::test] +async fn test_rerank_large_batch() { + setup_env(); + + let documents: Vec = (0..20) + .map(|i| format!("This is document number {} with some content.", i)) + .collect(); + + let request = create_rerank_request( + "Document with number".to_string(), + documents, + Some(5), + ); + + let event = create_lambda_event(request); + let result = handler(event).await; + + assert!(result.is_ok(), "Handler should succeed for large batch"); + + let response = parse_response(result.unwrap()).unwrap(); + assert_eq!(response.model_type, "rerank"); + + let rankings = response.rankings.unwrap(); + assert_eq!(rankings.len(), 5, "Should return only top 5"); +} + +#[tokio::test] +async fn test_rerank_relevance_ordering() { + setup_env(); + + let request = create_rerank_request( + "Programming languages".to_string(), + vec![ + "Rust is a systems programming language focused on safety.".to_string(), + "Python is popular for data science and machine learning.".to_string(), + "The cat sat on the mat.".to_string(), + "JavaScript runs in web browsers.".to_string(), + ], + None, + ); + + let event = create_lambda_event(request); + let result = handler(event).await; + + assert!(result.is_ok(), "Handler should succeed"); + + let response = parse_response(result.unwrap()).unwrap(); + let rankings = response.rankings.unwrap(); + + // The cat document (index 2) should be ranked last + let cat_rank = rankings.iter().position(|r| r.index == 2).unwrap(); + assert_eq!(cat_rank, rankings.len() - 1, "Irrelevant document should be ranked last"); +} diff --git a/tests/integration_sparse_tests.rs b/tests/integration_sparse_tests.rs new file mode 100644 index 0000000..4a59830 --- /dev/null +++ b/tests/integration_sparse_tests.rs @@ -0,0 +1,226 @@ +// tests/integration_sparse_tests.rs +// +// Sparse embedding integration tests +// MODEL_ID is set to a sparse embedding model (SPLADE) + +use lambda_runtime::{Context, LambdaEvent}; +use serde_json::Value; +use serverless_vectorizer::{handler, Request, Response}; + +const MODEL_ID: &str = "Qdrant/Splade_PP_en_v1"; + +// Setup environment for tests +fn setup_env() { + let endpoint_url = + std::env::var("LOCALSTACK_ENDPOINT").unwrap_or_else(|_| "http://127.0.0.1:4566".to_string()); + + unsafe { + std::env::set_var("AWS_ENDPOINT_URL", &endpoint_url); + std::env::set_var("AWS_ACCESS_KEY_ID", "test"); + std::env::set_var("AWS_SECRET_ACCESS_KEY", "test"); + std::env::set_var("AWS_REGION", "us-east-1"); + std::env::set_var("MODEL_ID", MODEL_ID); + } + + // Print diagnostic info for CI debugging + eprintln!("[SPARSE TEST] MODEL_ID set to: {}", MODEL_ID); +} + +fn create_lambda_event(request: Request) -> LambdaEvent { + let payload = serde_json::to_value(request).expect("Failed to serialize request"); + LambdaEvent { + payload, + context: Context::default(), + } +} + +fn create_sparse_request(messages: Vec) -> Request { + Request { + messages: Some(messages), + images: None, + query: None, + documents: None, + s3_file: None, + s3_images: None, + save_to_s3: None, + top_k: None, + return_documents: None, + } +} + +fn parse_response(value: Value) -> Result { + serde_json::from_value(value).map_err(|e| format!("Failed to parse response: {}", e)) +} + +#[tokio::test] +async fn test_sparse_embedding_single_text() { + setup_env(); + + let request = create_sparse_request(vec!["The quick brown fox jumps over the lazy dog".to_string()]); + + let event = create_lambda_event(request); + let result = handler(event).await; + + assert!(result.is_ok(), "Handler should succeed for sparse embedding"); + + let response = parse_response(result.unwrap()).unwrap(); + assert_eq!(response.model_type, "sparse", "Should be sparse model type"); + assert!(response.sparse_embeddings.is_some(), "Should have sparse embeddings"); + + let sparse = response.sparse_embeddings.unwrap(); + assert_eq!(sparse.len(), 1); + assert!(!sparse[0].indices.is_empty(), "Sparse embedding should have non-zero indices"); + assert_eq!(sparse[0].indices.len(), sparse[0].values.len(), "Indices and values should have same length"); +} + +#[tokio::test] +async fn test_sparse_embedding_batch() { + setup_env(); + + eprintln!("[SPARSE TEST] Running batch test with 3 texts"); + + let request = create_sparse_request(vec![ + "Machine learning is a subset of artificial intelligence".to_string(), + "Deep learning uses neural networks".to_string(), + "Natural language processing handles text".to_string(), + ]); + + let event = create_lambda_event(request); + let result = handler(event).await; + + match &result { + Ok(value) => { + eprintln!("[SPARSE TEST] Batch handler succeeded"); + eprintln!("[SPARSE TEST] Response: {}", serde_json::to_string_pretty(value).unwrap_or_default()); + }, + Err(e) => { + eprintln!("[SPARSE TEST] Batch handler error: {:?}", e); + eprintln!("[SPARSE TEST] Error details: {}", e); + } + } + + assert!(result.is_ok(), "Handler should succeed for batch sparse embedding: {:?}", result.err()); + + let response = parse_response(result.unwrap()).unwrap(); + assert_eq!(response.model_type, "sparse"); + + let sparse = response.sparse_embeddings.unwrap(); + assert_eq!(sparse.len(), 3, "Should have 3 sparse embeddings"); + + for emb in &sparse { + assert!(!emb.indices.is_empty(), "Each sparse embedding should have indices"); + assert_eq!(emb.indices.len(), emb.values.len()); + } +} + +#[tokio::test] +async fn test_sparse_embedding_consistency() { + setup_env(); + + let text = "Consistent test for sparse embeddings".to_string(); + + let request1 = create_sparse_request(vec![text.clone()]); + let request2 = create_sparse_request(vec![text]); + + let event1 = create_lambda_event(request1); + let event2 = create_lambda_event(request2); + + let result1 = handler(event1).await; + let result2 = handler(event2).await; + + assert!(result1.is_ok() && result2.is_ok(), "Both requests should succeed"); + + let response1 = parse_response(result1.unwrap()).unwrap(); + let response2 = parse_response(result2.unwrap()).unwrap(); + + assert_eq!(response1.model_type, "sparse"); + assert_eq!(response2.model_type, "sparse"); + + let sparse1 = response1.sparse_embeddings.unwrap(); + let sparse2 = response2.sparse_embeddings.unwrap(); + + assert_eq!(sparse1[0].indices, sparse2[0].indices, "Indices should match"); + + let values_match = sparse1[0].values.iter() + .zip(sparse2[0].values.iter()) + .all(|(a, b)| (a - b).abs() < 1e-6); + assert!(values_match, "Values should be identical"); +} + +#[tokio::test] +async fn test_sparse_embedding_different_texts() { + setup_env(); + + let request = create_sparse_request(vec![ + "Programming in Rust".to_string(), + "Cooking Italian pasta".to_string(), + ]); + + let event = create_lambda_event(request); + let result = handler(event).await; + + assert!(result.is_ok(), "Handler should succeed"); + + let response = parse_response(result.unwrap()).unwrap(); + assert_eq!(response.model_type, "sparse"); + + let sparse = response.sparse_embeddings.unwrap(); + + // Different texts should have different sparse representations + let indices_different = sparse[0].indices != sparse[1].indices; + assert!(indices_different, "Different texts should have different sparse indices"); +} + +#[tokio::test] +async fn test_sparse_embedding_empty_input() { + setup_env(); + + let request = create_sparse_request(vec![]); + + let event = create_lambda_event(request); + let result = handler(event).await; + + assert!(result.is_err(), "Handler should fail for empty input"); +} + +#[tokio::test] +async fn test_sparse_embedding_long_text() { + setup_env(); + + let long_text = "This is a much longer text that contains many words. ".repeat(50); + + let request = create_sparse_request(vec![long_text]); + + let event = create_lambda_event(request); + let result = handler(event).await; + + assert!(result.is_ok(), "Handler should succeed for long text"); + + let response = parse_response(result.unwrap()).unwrap(); + assert_eq!(response.model_type, "sparse"); + + let sparse = response.sparse_embeddings.unwrap(); + assert_eq!(sparse.len(), 1); + assert!(!sparse[0].indices.is_empty()); +} + +#[tokio::test] +async fn test_sparse_embedding_special_characters() { + setup_env(); + + let request = create_sparse_request(vec![ + "Text with special chars: @#$%^&*()".to_string(), + "Numbers: 12345 and symbols: <>?/".to_string(), + ]); + + let event = create_lambda_event(request); + let result = handler(event).await; + + assert!(result.is_ok(), "Handler should handle special characters"); + + let response = parse_response(result.unwrap()).unwrap(); + assert_eq!(response.model_type, "sparse"); + + let sparse = response.sparse_embeddings.unwrap(); + assert_eq!(sparse.len(), 2); +} diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs deleted file mode 100644 index 1d3b973..0000000 --- a/tests/integration_tests.rs +++ /dev/null @@ -1,695 +0,0 @@ -// tests/integration_tests.rs - -use aws_config::{BehaviorVersion, Region}; -use aws_sdk_s3 as s3; -use aws_sdk_s3::config::Credentials; -use serverless_vectorizer::{Request, Response, SaveConfig, handler}; -use lambda_runtime::{Context, LambdaEvent}; -use serde_json::Value; - -// Setup LocalStack environment variables so handler's aws_config::load_from_env() uses LocalStack -fn setup_localstack_env() { - let endpoint_url = std::env::var("LOCALSTACK_ENDPOINT") - .unwrap_or_else(|_| "http://127.0.0.1:4566".to_string()); - - // SAFETY: These tests run sequentially with --test-threads=1 or the env vars - // are set before any concurrent access. This is safe for test setup. - unsafe { - std::env::set_var("AWS_ENDPOINT_URL", &endpoint_url); - std::env::set_var("AWS_ACCESS_KEY_ID", "test"); - std::env::set_var("AWS_SECRET_ACCESS_KEY", "test"); - std::env::set_var("AWS_REGION", "us-east-1"); - } -} - -// Helper function to create S3 client for LocalStack -async fn create_localstack_s3_client() -> s3::Client { - let endpoint_url = std::env::var("LOCALSTACK_ENDPOINT") - .unwrap_or_else(|_| "http://127.0.0.1:4566".to_string()); - - let creds = Credentials::new("test", "test", None, None, "test"); - - let config = aws_config::defaults(BehaviorVersion::latest()) - .region(Region::new("us-east-1")) - .credentials_provider(creds) - .endpoint_url(&endpoint_url) - .load() - .await; - - // Create S3-specific config with force_path_style enabled - let s3_config = s3::config::Builder::from(&config) - .force_path_style(true) // Add this line! - .build(); - - s3::Client::from_conf(s3_config) -} - -// Helper function to create a test context -fn create_test_context() -> Context { - Context::default() -} - -// Helper function to create a LambdaEvent with proper Value payload -fn create_lambda_event(request: Request) -> LambdaEvent { - let payload = serde_json::to_value(request).expect("Failed to serialize request"); - LambdaEvent { - payload, - context: create_test_context(), - } -} - -// Helper function to parse response from handler -fn parse_response(value: Value) -> Result { - serde_json::from_value(value).map_err(|e| format!("Failed to parse response: {}", e)) -} - -// Helper function to create bucket if it doesn't exist -async fn ensure_bucket_exists( - s3_client: &s3::Client, - bucket: &str, -) -> Result<(), Box> { - match s3_client.head_bucket().bucket(bucket).send().await { - Ok(_) => Ok(()), - Err(_) => { - s3_client.create_bucket().bucket(bucket).send().await?; - Ok(()) - } - } -} - -// Helper function to upload text to S3 -async fn upload_text_to_s3( - s3_client: &s3::Client, - bucket: &str, - key: &str, - content: &str, -) -> Result<(), Box> { - s3_client - .put_object() - .bucket(bucket) - .key(key) - .body(content.as_bytes().to_vec().into()) - .send() - .await?; - Ok(()) -} - -// Helper function to upload JSON array to S3 -async fn upload_json_to_s3( - s3_client: &s3::Client, - bucket: &str, - key: &str, - messages: &[String], -) -> Result<(), Box> { - let json_content = serde_json::to_string(messages)?; - s3_client - .put_object() - .bucket(bucket) - .key(key) - .body(json_content.as_bytes().to_vec().into()) - .content_type("application/json") - .send() - .await?; - Ok(()) -} - -// Helper function to read from S3 -async fn read_from_s3( - s3_client: &s3::Client, - bucket: &str, - key: &str, -) -> Result> { - let resp = s3_client - .get_object() - .bucket(bucket) - .key(key) - .send() - .await?; - - let data = resp.body.collect().await?; - Ok(String::from_utf8(data.to_vec())?) -} - -#[cfg(test)] -mod integration_tests { - use super::*; - - const TEST_BUCKET: &str = "test-embeddings-bucket"; - - #[tokio::test] - async fn test_read_single_text_from_s3() { - // Setup LocalStack env for handler - setup_localstack_env(); - let s3_client = create_localstack_s3_client().await; - ensure_bucket_exists(&s3_client, TEST_BUCKET).await.unwrap(); - - let test_text = "Hello from S3!"; - let s3_key = "test-files/single-text.txt"; - - upload_text_to_s3(&s3_client, TEST_BUCKET, s3_key, test_text) - .await - .unwrap(); - - // Test - let request = Request { - messages: None, - s3_file: Some(format!("{}/{}", TEST_BUCKET, s3_key)), - save_to_s3: None, - }; - - let event = create_lambda_event(request); - let result = handler(event).await; - - assert!(result.is_ok(), "Handler should succeed reading from S3"); - - let response = parse_response(result.unwrap()).unwrap(); - assert_eq!(response.embeddings.len(), 1); - assert_eq!(response.dimension, 384); - } - - #[tokio::test] - async fn test_read_json_array_from_s3() { - // Setup LocalStack env for handler - setup_localstack_env(); - let s3_client = create_localstack_s3_client().await; - ensure_bucket_exists(&s3_client, TEST_BUCKET).await.unwrap(); - - let messages = vec![ - "First message from S3".to_string(), - "Second message from S3".to_string(), - "Third message from S3".to_string(), - ]; - let s3_key = "test-files/messages-array.json"; - - upload_json_to_s3(&s3_client, TEST_BUCKET, s3_key, &messages) - .await - .unwrap(); - - // Test - let request = Request { - messages: None, - s3_file: Some(format!("{}/{}", TEST_BUCKET, s3_key)), - save_to_s3: None, - }; - - let event = create_lambda_event(request); - let result = handler(event).await; - - assert!( - result.is_ok(), - "Handler should succeed reading JSON array from S3" - ); - - let response = parse_response(result.unwrap()).unwrap(); - assert_eq!(response.embeddings.len(), 3, "Should generate 3 embeddings"); - assert_eq!(response.dimension, 384); - } - - #[tokio::test] - async fn test_save_embeddings_to_s3() { - // Setup LocalStack env for handler - setup_localstack_env(); - let s3_client = create_localstack_s3_client().await; - ensure_bucket_exists(&s3_client, TEST_BUCKET).await.unwrap(); - - let output_key = "test-output/embeddings.json"; - - // Test - let request = Request { - messages: Some(vec!["Test message for saving".to_string()]), - s3_file: None, - save_to_s3: Some(SaveConfig { - bucket: TEST_BUCKET.to_string(), - key: output_key.to_string(), - }), - }; - - let event = create_lambda_event(request); - let result = handler(event).await; - - assert!(result.is_ok(), "Handler should succeed saving to S3"); - - let response = parse_response(result.unwrap()).unwrap(); - assert!(response.s3_location.is_some()); - assert_eq!( - response.s3_location.unwrap(), - format!("s3://{}/{}", TEST_BUCKET, output_key) - ); - - // Verify the file was actually saved - let saved_content = read_from_s3(&s3_client, TEST_BUCKET, output_key) - .await - .unwrap(); - - let saved_embeddings: Vec> = serde_json::from_str(&saved_content).unwrap(); - assert_eq!(saved_embeddings.len(), 1); - assert_eq!(saved_embeddings[0].len(), 384); - } - - #[tokio::test] - async fn test_save_multiple_embeddings_to_s3() { - // Setup LocalStack env for handler - setup_localstack_env(); - let s3_client = create_localstack_s3_client().await; - ensure_bucket_exists(&s3_client, TEST_BUCKET).await.unwrap(); - - let output_key = "test-output/batch-embeddings.json"; - - // Test - let request = Request { - messages: Some(vec![ - "First message".to_string(), - "Second message".to_string(), - "Third message".to_string(), - ]), - s3_file: None, - save_to_s3: Some(SaveConfig { - bucket: TEST_BUCKET.to_string(), - key: output_key.to_string(), - }), - }; - - let event = create_lambda_event(request); - let result = handler(event).await; - - assert!(result.is_ok(), "Handler should succeed saving batch to S3"); - - let response = parse_response(result.unwrap()).unwrap(); - assert_eq!(response.embeddings.len(), 3); - - // Verify the file was actually saved with all embeddings - let saved_content = read_from_s3(&s3_client, TEST_BUCKET, output_key) - .await - .unwrap(); - - let saved_embeddings: Vec> = serde_json::from_str(&saved_content).unwrap(); - assert_eq!(saved_embeddings.len(), 3); - - for embedding in &saved_embeddings { - assert_eq!(embedding.len(), 384); - } - } - - #[tokio::test] - async fn test_read_from_s3_and_save_to_s3() { - // Setup LocalStack env for handler - setup_localstack_env(); - let s3_client = create_localstack_s3_client().await; - ensure_bucket_exists(&s3_client, TEST_BUCKET).await.unwrap(); - - let input_key = "test-files/input-text.txt"; - let output_key = "test-output/result-embeddings.json"; - let test_text = "Process this text from S3 and save result"; - - upload_text_to_s3(&s3_client, TEST_BUCKET, input_key, test_text) - .await - .unwrap(); - - // Test - let request = Request { - messages: None, - s3_file: Some(format!("{}/{}", TEST_BUCKET, input_key)), - save_to_s3: Some(SaveConfig { - bucket: TEST_BUCKET.to_string(), - key: output_key.to_string(), - }), - }; - - let event = create_lambda_event(request); - let result = handler(event).await; - - assert!( - result.is_ok(), - "Handler should succeed with both S3 read and write" - ); - - let response = parse_response(result.unwrap()).unwrap(); - assert_eq!(response.embeddings.len(), 1); - assert!(response.s3_location.is_some()); - - // Verify output file - let saved_content = read_from_s3(&s3_client, TEST_BUCKET, output_key) - .await - .unwrap(); - - let saved_embeddings: Vec> = serde_json::from_str(&saved_content).unwrap(); - assert_eq!(saved_embeddings.len(), 1); - assert_eq!(saved_embeddings[0], response.embeddings[0]); - } - - #[tokio::test] - async fn test_read_json_array_and_save_to_s3() { - // Setup LocalStack env for handler - setup_localstack_env(); - let s3_client = create_localstack_s3_client().await; - ensure_bucket_exists(&s3_client, TEST_BUCKET).await.unwrap(); - - let input_key = "test-files/batch-input.json"; - let output_key = "test-output/batch-output.json"; - let messages = vec!["Batch message 1".to_string(), "Batch message 2".to_string()]; - - upload_json_to_s3(&s3_client, TEST_BUCKET, input_key, &messages) - .await - .unwrap(); - - // Test - let request = Request { - messages: None, - s3_file: Some(format!("{}/{}", TEST_BUCKET, input_key)), - save_to_s3: Some(SaveConfig { - bucket: TEST_BUCKET.to_string(), - key: output_key.to_string(), - }), - }; - - let event = create_lambda_event(request); - let result = handler(event).await; - - assert!( - result.is_ok(), - "Handler should succeed with batch S3 read and write" - ); - - let response = parse_response(result.unwrap()).unwrap(); - assert_eq!(response.embeddings.len(), 2); - - // Verify output - let saved_content = read_from_s3(&s3_client, TEST_BUCKET, output_key) - .await - .unwrap(); - - let saved_embeddings: Vec> = serde_json::from_str(&saved_content).unwrap(); - assert_eq!(saved_embeddings.len(), 2); - } - - #[tokio::test] - async fn test_invalid_s3_path() { - setup_localstack_env(); - let request = Request { - messages: None, - s3_file: Some("invalid-path".to_string()), - save_to_s3: None, - }; - - let event = create_lambda_event(request); - let result = handler(event).await; - - assert!(result.is_err(), "Handler should fail with invalid S3 path"); - assert!( - result - .unwrap_err() - .to_string() - .contains("Invalid S3 path format") - ); - } - - #[tokio::test] - async fn test_nonexistent_s3_file() { - setup_localstack_env(); - let request = Request { - messages: None, - s3_file: Some(format!("{}/nonexistent-file.txt", TEST_BUCKET)), - save_to_s3: None, - }; - - let event = create_lambda_event(request); - let result = handler(event).await; - - assert!( - result.is_err(), - "Handler should fail with nonexistent S3 file" - ); - } - - #[tokio::test] - async fn test_large_file_from_s3() { - // Setup LocalStack env for handler - setup_localstack_env(); - let s3_client = create_localstack_s3_client().await; - ensure_bucket_exists(&s3_client, TEST_BUCKET).await.unwrap(); - - let large_text = "This is a test sentence. ".repeat(100); - let s3_key = "test-files/large-file.txt"; - - upload_text_to_s3(&s3_client, TEST_BUCKET, s3_key, &large_text) - .await - .unwrap(); - - // Test - let request = Request { - messages: None, - s3_file: Some(format!("{}/{}", TEST_BUCKET, s3_key)), - save_to_s3: None, - }; - - let event = create_lambda_event(request); - let result = handler(event).await; - - assert!(result.is_ok(), "Handler should handle large files"); - - let response = parse_response(result.unwrap()).unwrap(); - assert_eq!(response.dimension, 384); - } - - #[tokio::test] - async fn test_unicode_content_from_s3() { - // Setup LocalStack env for handler - setup_localstack_env(); - let s3_client = create_localstack_s3_client().await; - ensure_bucket_exists(&s3_client, TEST_BUCKET).await.unwrap(); - - let unicode_text = "Hello 世界! 🌍 Привет مرحبا"; - let s3_key = "test-files/unicode.txt"; - - upload_text_to_s3(&s3_client, TEST_BUCKET, s3_key, unicode_text) - .await - .unwrap(); - - // Test - let request = Request { - messages: None, - s3_file: Some(format!("{}/{}", TEST_BUCKET, s3_key)), - save_to_s3: None, - }; - - let event = create_lambda_event(request); - let result = handler(event).await; - - assert!(result.is_ok(), "Handler should handle Unicode content"); - - let response = parse_response(result.unwrap()).unwrap(); - assert_eq!(response.dimension, 384); - } - - #[tokio::test] - async fn test_empty_file_from_s3() { - // Setup LocalStack env for handler - setup_localstack_env(); - let s3_client = create_localstack_s3_client().await; - ensure_bucket_exists(&s3_client, TEST_BUCKET).await.unwrap(); - - let s3_key = "test-files/empty.txt"; - - upload_text_to_s3(&s3_client, TEST_BUCKET, s3_key, "") - .await - .unwrap(); - - // Test - let request = Request { - messages: None, - s3_file: Some(format!("{}/{}", TEST_BUCKET, s3_key)), - save_to_s3: None, - }; - - let event = create_lambda_event(request); - let result = handler(event).await; - - assert!(result.is_ok(), "Handler should handle empty files"); - - let response = parse_response(result.unwrap()).unwrap(); - assert_eq!(response.dimension, 384); - } - - #[tokio::test] - async fn test_multiple_files_sequential() { - // Setup LocalStack env for handler - setup_localstack_env(); - let s3_client = create_localstack_s3_client().await; - ensure_bucket_exists(&s3_client, TEST_BUCKET).await.unwrap(); - - let files = vec![ - ("test-files/seq-1.txt", "First sequential file"), - ("test-files/seq-2.txt", "Second sequential file"), - ("test-files/seq-3.txt", "Third sequential file"), - ]; - - for (key, content) in &files { - upload_text_to_s3(&s3_client, TEST_BUCKET, key, content) - .await - .unwrap(); - } - - // Test each file - for (key, _) in files { - let request = Request { - messages: None, - s3_file: Some(format!("{}/{}", TEST_BUCKET, key)), - save_to_s3: None, - }; - - let event = create_lambda_event(request); - let result = handler(event).await; - - assert!(result.is_ok(), "Handler should succeed for file: {}", key); - - let response = parse_response(result.unwrap()).unwrap(); - assert_eq!(response.dimension, 384); - } - } - - #[tokio::test] - async fn test_nested_s3_paths() { - // Setup LocalStack env for handler - setup_localstack_env(); - let s3_client = create_localstack_s3_client().await; - ensure_bucket_exists(&s3_client, TEST_BUCKET).await.unwrap(); - - let nested_key = "level1/level2/level3/nested-file.txt"; - let test_text = "Content in nested path"; - - upload_text_to_s3(&s3_client, TEST_BUCKET, nested_key, test_text) - .await - .unwrap(); - - // Test - let request = Request { - messages: None, - s3_file: Some(format!("{}/{}", TEST_BUCKET, nested_key)), - save_to_s3: None, - }; - - let event = create_lambda_event(request); - let result = handler(event).await; - - assert!(result.is_ok(), "Handler should handle nested S3 paths"); - - let response = parse_response(result.unwrap()).unwrap(); - assert_eq!(response.dimension, 384); - } - - #[tokio::test] - async fn test_concurrent_s3_operations() { - use tokio::task; - - // Setup LocalStack env for handler - setup_localstack_env(); - let s3_client = create_localstack_s3_client().await; - ensure_bucket_exists(&s3_client, TEST_BUCKET).await.unwrap(); - - // Upload test files - for i in 0..5 { - let key = format!("test-files/concurrent-{}.txt", i); - let content = format!("Concurrent test message {}", i); - upload_text_to_s3(&s3_client, TEST_BUCKET, &key, &content) - .await - .unwrap(); - } - - // Test concurrent processing - let handles: Vec<_> = (0..5) - .map(|i| { - let bucket = TEST_BUCKET.to_string(); - task::spawn(async move { - let request = Request { - messages: None, - s3_file: Some(format!("{}/test-files/concurrent-{}.txt", bucket, i)), - save_to_s3: None, - }; - - let event = create_lambda_event(request); - handler(event).await - }) - }) - .collect(); - - // Verify all succeed - for handle in handles { - let result = handle.await.unwrap(); - assert!( - result.is_ok(), - "All concurrent S3 operations should succeed" - ); - - let response = parse_response(result.unwrap()).unwrap(); - assert_eq!(response.dimension, 384); - } - } - - #[tokio::test] - async fn test_save_to_nonexistent_bucket() { - setup_localstack_env(); - let request = Request { - messages: Some(vec!["Test message".to_string()]), - s3_file: None, - save_to_s3: Some(SaveConfig { - bucket: "nonexistent-bucket-12345".to_string(), - key: "test.json".to_string(), - }), - }; - - let event = create_lambda_event(request); - let result = handler(event).await; - - // Should fail because bucket doesn't exist - assert!( - result.is_err(), - "Handler should fail with nonexistent bucket" - ); - } - - #[tokio::test] - async fn test_overwrite_existing_s3_file() { - // Setup LocalStack env for handler - setup_localstack_env(); - let s3_client = create_localstack_s3_client().await; - ensure_bucket_exists(&s3_client, TEST_BUCKET).await.unwrap(); - - let output_key = "test-output/overwrite-test.json"; - - // First save - let request1 = Request { - messages: Some(vec!["First version".to_string()]), - s3_file: None, - save_to_s3: Some(SaveConfig { - bucket: TEST_BUCKET.to_string(), - key: output_key.to_string(), - }), - }; - - let event1 = create_lambda_event(request1); - handler(event1).await.unwrap(); - - // Second save (overwrite) - let request2 = Request { - messages: Some(vec!["Second version".to_string()]), - s3_file: None, - save_to_s3: Some(SaveConfig { - bucket: TEST_BUCKET.to_string(), - key: output_key.to_string(), - }), - }; - - let event2 = create_lambda_event(request2); - let result = handler(event2).await; - - assert!(result.is_ok(), "Handler should succeed overwriting file"); - - // Verify the file was overwritten - let saved_content = read_from_s3(&s3_client, TEST_BUCKET, output_key) - .await - .unwrap(); - - let saved_embeddings: Vec> = serde_json::from_str(&saved_content).unwrap(); - assert_eq!(saved_embeddings.len(), 1); - } -} diff --git a/tests/integration_text_tests.rs b/tests/integration_text_tests.rs new file mode 100644 index 0000000..b3422a1 --- /dev/null +++ b/tests/integration_text_tests.rs @@ -0,0 +1,288 @@ +// tests/integration_text_tests.rs +// +// Text embedding integration tests +// MODEL_ID is set to a text embedding model + +use aws_config::{BehaviorVersion, Region}; +use aws_sdk_s3 as s3; +use aws_sdk_s3::config::Credentials; +use lambda_runtime::{Context, LambdaEvent}; +use serde_json::Value; +use serverless_vectorizer::{handler, Request, Response}; + +const MODEL_ID: &str = "Xenova/bge-small-en-v1.5"; +const TEST_BUCKET: &str = "test-text-embeddings-bucket"; + +// Setup environment for tests +fn setup_env() { + let endpoint_url = + std::env::var("LOCALSTACK_ENDPOINT").unwrap_or_else(|_| "http://127.0.0.1:4566".to_string()); + + unsafe { + std::env::set_var("AWS_ENDPOINT_URL", &endpoint_url); + std::env::set_var("AWS_ACCESS_KEY_ID", "test"); + std::env::set_var("AWS_SECRET_ACCESS_KEY", "test"); + std::env::set_var("AWS_REGION", "us-east-1"); + std::env::set_var("MODEL_ID", MODEL_ID); + } +} + +async fn create_s3_client() -> s3::Client { + let endpoint_url = + std::env::var("LOCALSTACK_ENDPOINT").unwrap_or_else(|_| "http://127.0.0.1:4566".to_string()); + + let creds = Credentials::new("test", "test", None, None, "test"); + + let config = aws_config::defaults(BehaviorVersion::latest()) + .region(Region::new("us-east-1")) + .credentials_provider(creds) + .endpoint_url(&endpoint_url) + .load() + .await; + + let s3_config = s3::config::Builder::from(&config) + .force_path_style(true) + .build(); + + s3::Client::from_conf(s3_config) +} + +fn create_lambda_event(request: Request) -> LambdaEvent { + let payload = serde_json::to_value(request).expect("Failed to serialize request"); + LambdaEvent { + payload, + context: Context::default(), + } +} + +fn create_text_request(messages: Option>, s3_file: Option) -> Request { + Request { + messages, + images: None, + query: None, + documents: None, + s3_file, + s3_images: None, + save_to_s3: None, + top_k: None, + return_documents: None, + } +} + +fn parse_response(value: Value) -> Result { + serde_json::from_value(value).map_err(|e| format!("Failed to parse response: {}", e)) +} + +async fn ensure_bucket_exists(s3_client: &s3::Client, bucket: &str) -> Result<(), Box> { + match s3_client.head_bucket().bucket(bucket).send().await { + Ok(_) => Ok(()), + Err(_) => { + s3_client.create_bucket().bucket(bucket).send().await?; + Ok(()) + } + } +} + +async fn upload_text_to_s3( + s3_client: &s3::Client, + bucket: &str, + key: &str, + content: &str, +) -> Result<(), Box> { + s3_client + .put_object() + .bucket(bucket) + .key(key) + .body(content.as_bytes().to_vec().into()) + .send() + .await?; + Ok(()) +} + +async fn upload_json_to_s3( + s3_client: &s3::Client, + bucket: &str, + key: &str, + messages: &[String], +) -> Result<(), Box> { + let json_content = serde_json::to_string(messages)?; + s3_client + .put_object() + .bucket(bucket) + .key(key) + .body(json_content.as_bytes().to_vec().into()) + .content_type("application/json") + .send() + .await?; + Ok(()) +} + +#[tokio::test] +async fn test_text_embedding_direct_input() { + setup_env(); + + let request = create_text_request( + Some(vec!["Hello world".to_string(), "How are you?".to_string()]), + None, + ); + + let event = create_lambda_event(request); + let result = handler(event).await; + + assert!(result.is_ok(), "Handler should succeed for text embedding"); + + let response = parse_response(result.unwrap()).unwrap(); + assert_eq!(response.model_type, "text", "Should be text model type"); + assert!(response.embeddings.is_some(), "Should have embeddings"); + + let embeddings = response.embeddings.unwrap(); + assert_eq!(embeddings.len(), 2, "Should have 2 embeddings"); + assert_eq!(response.dimension.unwrap_or(0), 384); +} + +#[tokio::test] +async fn test_text_embedding_single_text_from_s3() { + setup_env(); + let s3_client = create_s3_client().await; + ensure_bucket_exists(&s3_client, TEST_BUCKET).await.unwrap(); + + let test_text = "Hello from S3!"; + let s3_key = "test-files/single-text.txt"; + + upload_text_to_s3(&s3_client, TEST_BUCKET, s3_key, test_text) + .await + .unwrap(); + + let request = create_text_request(None, Some(format!("{}/{}", TEST_BUCKET, s3_key))); + + let event = create_lambda_event(request); + let result = handler(event).await; + + assert!(result.is_ok(), "Handler should succeed reading from S3"); + + let response = parse_response(result.unwrap()).unwrap(); + let embeddings = response.embeddings.expect("Should have embeddings"); + assert_eq!(embeddings.len(), 1); + assert_eq!(response.dimension.unwrap_or(0), 384); +} + +#[tokio::test] +async fn test_text_embedding_json_array_from_s3() { + setup_env(); + let s3_client = create_s3_client().await; + ensure_bucket_exists(&s3_client, TEST_BUCKET).await.unwrap(); + + let messages = vec![ + "First message from S3".to_string(), + "Second message from S3".to_string(), + "Third message from S3".to_string(), + ]; + let s3_key = "test-files/messages-array.json"; + + upload_json_to_s3(&s3_client, TEST_BUCKET, s3_key, &messages) + .await + .unwrap(); + + let request = create_text_request(None, Some(format!("{}/{}", TEST_BUCKET, s3_key))); + + let event = create_lambda_event(request); + let result = handler(event).await; + + assert!(result.is_ok(), "Handler should succeed reading JSON array from S3"); + + let response = parse_response(result.unwrap()).unwrap(); + let embeddings = response.embeddings.expect("Should have embeddings"); + assert_eq!(embeddings.len(), 3); + assert_eq!(response.dimension.unwrap_or(0), 384); +} + +#[tokio::test] +async fn test_text_embedding_s3_consistency() { + setup_env(); + let s3_client = create_s3_client().await; + ensure_bucket_exists(&s3_client, TEST_BUCKET).await.unwrap(); + + let test_text = "Consistency test message"; + let s3_key = "test-files/consistency-test.txt"; + + upload_text_to_s3(&s3_client, TEST_BUCKET, s3_key, test_text) + .await + .unwrap(); + + // Get embedding from S3 + let s3_request = create_text_request(None, Some(format!("{}/{}", TEST_BUCKET, s3_key))); + let s3_event = create_lambda_event(s3_request); + let s3_response = parse_response(handler(s3_event).await.unwrap()).unwrap(); + let s3_embeddings = s3_response.embeddings.expect("Should have embeddings"); + + // Get embedding from direct input + let direct_request = create_text_request(Some(vec![test_text.to_string()]), None); + let direct_event = create_lambda_event(direct_request); + let direct_response = parse_response(handler(direct_event).await.unwrap()).unwrap(); + let direct_embeddings = direct_response.embeddings.expect("Should have embeddings"); + + // Compare embeddings + assert_eq!(s3_embeddings[0].len(), direct_embeddings[0].len()); + + let embeddings_match = s3_embeddings[0] + .iter() + .zip(direct_embeddings[0].iter()) + .all(|(a, b)| (a - b).abs() < 1e-6); + + assert!(embeddings_match, "S3 and direct input should produce identical embeddings"); +} + +#[tokio::test] +async fn test_text_embedding_invalid_s3_path() { + setup_env(); + + let request = create_text_request(None, Some("nonexistent-bucket/nonexistent-file.txt".to_string())); + + let event = create_lambda_event(request); + let result = handler(event).await; + + assert!(result.is_err(), "Handler should fail for invalid S3 path"); +} + +#[tokio::test] +async fn test_text_embedding_empty_messages() { + setup_env(); + + let request = create_text_request(Some(vec![]), None); + + let event = create_lambda_event(request); + let result = handler(event).await; + + assert!(result.is_err(), "Handler should fail for empty messages"); +} + +#[tokio::test] +async fn test_text_embedding_consistency() { + setup_env(); + + let text = "Consistent embedding test".to_string(); + + let request1 = create_text_request(Some(vec![text.clone()]), None); + let request2 = create_text_request(Some(vec![text]), None); + + let event1 = create_lambda_event(request1); + let event2 = create_lambda_event(request2); + + let result1 = handler(event1).await; + let result2 = handler(event2).await; + + assert!(result1.is_ok() && result2.is_ok(), "Both requests should succeed"); + + let response1 = parse_response(result1.unwrap()).unwrap(); + let response2 = parse_response(result2.unwrap()).unwrap(); + + let emb1 = response1.embeddings.unwrap(); + let emb2 = response2.embeddings.unwrap(); + + let match_result = emb1[0] + .iter() + .zip(emb2[0].iter()) + .all(|(a, b)| (a - b).abs() < 1e-6); + + assert!(match_result, "Same text should produce identical embeddings"); +} diff --git a/tests/unit_tests.rs b/tests/unit_tests.rs index 938ede8..d850846 100644 --- a/tests/unit_tests.rs +++ b/tests/unit_tests.rs @@ -1,8 +1,8 @@ // tests/unit_tests.rs use lambda_runtime::{Context, LambdaEvent}; -use serde_json::{Value, json}; -use serverless_vectorizer::{Request, Response, SaveConfig, handler}; +use serde_json::{json, Value}; +use serverless_vectorizer::{handler, Request, Response, SaveConfig}; // Helper function to create a test context fn create_test_context() -> Context { @@ -23,6 +23,21 @@ fn parse_response(value: Value) -> Result { serde_json::from_value(value).map_err(|e| format!("Failed to parse response: {}", e)) } +// Helper to create a text embedding request +fn create_text_request(messages: Vec) -> Request { + Request { + messages: Some(messages), + images: None, + query: None, + documents: None, + s3_file: None, + s3_images: None, + save_to_s3: None, + top_k: None, + return_documents: None, + } +} + #[cfg(test)] mod tests { use super::*; @@ -30,11 +45,7 @@ mod tests { #[tokio::test] async fn test_simple_text_embedding() { // Test basic text embedding - let request = Request { - messages: Some(vec!["Hello, world!".to_string()]), - s3_file: None, - save_to_s3: None, - }; + let request = create_text_request(vec!["Hello, world!".to_string()]); let event = create_lambda_event(request); let result = handler(event).await; @@ -42,21 +53,21 @@ mod tests { assert!(result.is_ok(), "Handler should succeed"); let response = parse_response(result.unwrap()).expect("Should parse response"); - assert_eq!(response.embeddings.len(), 1, "Should have one embedding"); - assert!( - response.embeddings[0].len() > 0, - "Embedding should not be empty" - ); + let embeddings = response.embeddings.expect("Should have embeddings"); + assert_eq!(embeddings.len(), 1, "Should have one embedding"); + assert!(embeddings[0].len() > 0, "Embedding should not be empty"); assert_eq!( - response.dimension, - response.embeddings[0].len(), + response.dimension.unwrap_or(0), + embeddings[0].len(), "Dimension should match embedding length" ); assert!(response.s3_location.is_none(), "S3 location should be None"); + assert_eq!(response.model_type, "text", "Model type should be text"); // BGE-small-en-v1.5 produces 384-dimensional embeddings assert_eq!( - response.dimension, 384, + response.dimension.unwrap_or(0), + 384, "Expected 384-dimensional embedding from BGE-small-en-v1.5" ); } @@ -64,15 +75,11 @@ mod tests { #[tokio::test] async fn test_multiple_messages_embedding() { // Test batch embedding with multiple messages - let request = Request { - messages: Some(vec![ - "First message".to_string(), - "Second message".to_string(), - "Third message".to_string(), - ]), - s3_file: None, - save_to_s3: None, - }; + let request = create_text_request(vec![ + "First message".to_string(), + "Second message".to_string(), + "Third message".to_string(), + ]); let event = create_lambda_event(request); let result = handler(event).await; @@ -83,18 +90,19 @@ mod tests { ); let response = parse_response(result.unwrap()).expect("Should parse response"); - assert_eq!(response.embeddings.len(), 3, "Should have three embeddings"); - assert_eq!(response.dimension, 384); + let embeddings = response.embeddings.expect("Should have embeddings"); + assert_eq!(embeddings.len(), 3, "Should have three embeddings"); + assert_eq!(response.dimension.unwrap_or(0), 384); // All embeddings should be 384-dimensional - for embedding in &response.embeddings { + for embedding in &embeddings { assert_eq!(embedding.len(), 384); } // Different messages should produce different embeddings - let embeddings_different = response.embeddings[0] + let embeddings_different = embeddings[0] .iter() - .zip(response.embeddings[1].iter()) + .zip(embeddings[1].iter()) .any(|(a, b)| (a - b).abs() > 0.01); assert!( embeddings_different, @@ -105,11 +113,7 @@ mod tests { #[tokio::test] async fn test_empty_message() { // Test with empty string - let request = Request { - messages: Some(vec!["".to_string()]), - s3_file: None, - save_to_s3: None, - }; + let request = create_text_request(vec!["".to_string()]); let event = create_lambda_event(request); let result = handler(event).await; @@ -118,7 +122,8 @@ mod tests { let response = parse_response(result.unwrap()).expect("Should parse response"); assert_eq!( - response.dimension, 384, + response.dimension.unwrap_or(0), + 384, "Should still produce 384-dimensional embedding" ); } @@ -130,11 +135,7 @@ mod tests { It tests whether the embedding model can handle longer inputs correctly. \ The embeddings should still be generated properly regardless of text length."; - let request = Request { - messages: Some(vec![long_text.to_string()]), - s3_file: None, - save_to_s3: None, - }; + let request = create_text_request(vec![long_text.to_string()]); let event = create_lambda_event(request); let result = handler(event).await; @@ -142,8 +143,9 @@ mod tests { assert!(result.is_ok(), "Handler should handle long text"); let response = parse_response(result.unwrap()).expect("Should parse response"); - assert_eq!(response.dimension, 384); - assert!(response.embeddings[0].len() == 384); + let embeddings = response.embeddings.expect("Should have embeddings"); + assert_eq!(response.dimension.unwrap_or(0), 384); + assert!(embeddings[0].len() == 384); } #[tokio::test] @@ -151,8 +153,14 @@ mod tests { // Test when neither messages nor s3_file is provided let request = Request { messages: None, + images: None, + query: None, + documents: None, s3_file: None, + s3_images: None, save_to_s3: None, + top_k: None, + return_documents: None, }; let event = create_lambda_event(request); @@ -164,21 +172,15 @@ mod tests { ); let error = result.unwrap_err(); - assert!( - error - .to_string() - .contains("Either 'messages' or 's3_file' must be provided") - ); + assert!(error + .to_string() + .contains("Either 'messages' or 's3_file' must be provided")); } #[tokio::test] async fn test_empty_messages_array() { // Test with empty messages array - let request = Request { - messages: Some(vec![]), - s3_file: None, - save_to_s3: None, - }; + let request = create_text_request(vec![]); let event = create_lambda_event(request); let result = handler(event).await; @@ -194,17 +196,8 @@ mod tests { // Test that the same input produces consistent embeddings let text = "Consistent test message"; - let request1 = Request { - messages: Some(vec![text.to_string()]), - s3_file: None, - save_to_s3: None, - }; - - let request2 = Request { - messages: Some(vec![text.to_string()]), - s3_file: None, - save_to_s3: None, - }; + let request1 = create_text_request(vec![text.to_string()]); + let request2 = create_text_request(vec![text.to_string()]); let event1 = create_lambda_event(request1); let event2 = create_lambda_event(request2); @@ -212,12 +205,15 @@ mod tests { let result1 = parse_response(handler(event1).await.unwrap()).unwrap(); let result2 = parse_response(handler(event2).await.unwrap()).unwrap(); - assert_eq!(result1.embeddings[0].len(), result2.embeddings[0].len()); + let emb1 = result1.embeddings.unwrap(); + let emb2 = result2.embeddings.unwrap(); + + assert_eq!(emb1[0].len(), emb2[0].len()); // Check that embeddings are identical (or very close due to floating point) - let embeddings_match = result1.embeddings[0] + let embeddings_match = emb1[0] .iter() - .zip(result2.embeddings[0].iter()) + .zip(emb2[0].iter()) .all(|(a, b)| (a - b).abs() < 1e-6); assert!( @@ -226,62 +222,22 @@ mod tests { ); } - #[tokio::test] - async fn test_batch_embedding_consistency() { - // Test that same messages in batch produce consistent embeddings - let messages = vec!["First message".to_string(), "Second message".to_string()]; - - let request1 = Request { - messages: Some(messages.clone()), - s3_file: None, - save_to_s3: None, - }; - - let request2 = Request { - messages: Some(messages), - s3_file: None, - save_to_s3: None, - }; - - let event1 = create_lambda_event(request1); - let event2 = create_lambda_event(request2); - - let result1 = parse_response(handler(event1).await.unwrap()).unwrap(); - let result2 = parse_response(handler(event2).await.unwrap()).unwrap(); - - // Check that all embeddings match - for i in 0..result1.embeddings.len() { - let embeddings_match = result1.embeddings[i] - .iter() - .zip(result2.embeddings[i].iter()) - .all(|(a, b)| (a - b).abs() < 1e-6); - assert!( - embeddings_match, - "Batch embedding {} should be consistent", - i - ); - } - } - #[tokio::test] async fn test_different_texts_produce_different_embeddings() { // Test that different texts produce different embeddings - let request = Request { - messages: Some(vec![ - "Hello, world!".to_string(), - "Goodbye, world!".to_string(), - ]), - s3_file: None, - save_to_s3: None, - }; + let request = create_text_request(vec![ + "Hello, world!".to_string(), + "Goodbye, world!".to_string(), + ]); let event = create_lambda_event(request); let result = parse_response(handler(event).await.unwrap()).unwrap(); + let embeddings = result.embeddings.unwrap(); // Check that embeddings are different - let embeddings_different = result.embeddings[0] + let embeddings_different = embeddings[0] .iter() - .zip(result.embeddings[1].iter()) + .zip(embeddings[1].iter()) .any(|(a, b)| (a - b).abs() > 0.01); assert!( @@ -293,17 +249,14 @@ mod tests { #[tokio::test] async fn test_embedding_vector_properties() { // Test mathematical properties of embedding vectors - let request = Request { - messages: Some(vec!["Test vector properties".to_string()]), - s3_file: None, - save_to_s3: None, - }; + let request = create_text_request(vec!["Test vector properties".to_string()]); let event = create_lambda_event(request); let result = handler(event).await.unwrap(); let response = parse_response(result).unwrap(); - let embedding = &response.embeddings[0]; + let embeddings = response.embeddings.unwrap(); + let embedding = &embeddings[0]; // Check all values are finite assert!( @@ -330,11 +283,7 @@ mod tests { // Test with a larger batch of messages let messages: Vec = (0..10).map(|i| format!("Message number {}", i)).collect(); - let request = Request { - messages: Some(messages.clone()), - s3_file: None, - save_to_s3: None, - }; + let request = create_text_request(messages); let event = create_lambda_event(request); let result = handler(event).await; @@ -342,9 +291,10 @@ mod tests { assert!(result.is_ok(), "Handler should handle large batch"); let response = parse_response(result.unwrap()).unwrap(); - assert_eq!(response.embeddings.len(), 10, "Should have 10 embeddings"); + let embeddings = response.embeddings.unwrap(); + assert_eq!(embeddings.len(), 10, "Should have 10 embeddings"); - for embedding in &response.embeddings { + for embedding in &embeddings { assert_eq!(embedding.len(), 384); } } @@ -352,11 +302,8 @@ mod tests { #[tokio::test] async fn test_special_characters() { // Test with special characters and Unicode - let request = Request { - messages: Some(vec!["Hello 世界! 🌍 Special chars: @#$%^&*()".to_string()]), - s3_file: None, - save_to_s3: None, - }; + let request = + create_text_request(vec!["Hello 世界! 🌍 Special chars: @#$%^&*()".to_string()]); let event = create_lambda_event(request); let result = handler(event).await; @@ -364,35 +311,13 @@ mod tests { assert!(result.is_ok(), "Handler should handle special characters"); let response = parse_response(result.unwrap()).unwrap(); - assert_eq!(response.dimension, 384); - } - - #[tokio::test] - async fn test_numeric_text() { - // Test with numeric content - let request = Request { - messages: Some(vec!["12345 67890".to_string()]), - s3_file: None, - save_to_s3: None, - }; - - let event = create_lambda_event(request); - let result = handler(event).await; - - assert!(result.is_ok(), "Handler should handle numeric text"); - - let response = parse_response(result.unwrap()).unwrap(); - assert_eq!(response.dimension, 384); + assert_eq!(response.dimension.unwrap_or(0), 384); } #[tokio::test] async fn test_response_serialization() { // Test that the response can be serialized to JSON - let request = Request { - messages: Some(vec!["Test serialization".to_string()]), - s3_file: None, - save_to_s3: None, - }; + let request = create_text_request(vec!["Test serialization".to_string()]); let event = create_lambda_event(request); let result = handler(event).await.unwrap(); @@ -404,16 +329,14 @@ mod tests { let json_value = json.unwrap(); assert!(json_value.contains("embeddings")); - assert!(json_value.contains("dimension")); + assert!(json_value.contains("model_type")); } #[tokio::test] async fn test_request_deserialization() { // Test that requests can be deserialized from JSON let json_str = r#"{ - "messages": ["Test message"], - "s3_file": null, - "save_to_s3": null + "messages": ["Test message"] }"#; let request: Result = serde_json::from_str(json_str); @@ -421,55 +344,20 @@ mod tests { let req = request.unwrap(); assert_eq!(req.messages.as_ref().unwrap()[0], "Test message"); - assert!(req.s3_file.is_none()); - assert!(req.save_to_s3.is_none()); - } - - #[tokio::test] - async fn test_concurrent_requests() { - // Test that multiple concurrent requests work correctly - use tokio::task; - - let handles: Vec<_> = (0..5) - .map(|i| { - task::spawn(async move { - let request = Request { - messages: Some(vec![format!("Concurrent request {}", i)]), - s3_file: None, - save_to_s3: None, - }; - - let event = create_lambda_event(request); - handler(event).await - }) - }) - .collect(); - - // Wait for all tasks to complete - for handle in handles { - let result = handle.await.unwrap(); - assert!(result.is_ok(), "All concurrent requests should succeed"); - - let response = parse_response(result.unwrap()).unwrap(); - assert_eq!(response.dimension, 384); - } } #[tokio::test] async fn test_semantic_similarity() { // Test that semantically similar texts have similar embeddings - let request = Request { - messages: Some(vec![ - "The cat sat on the mat".to_string(), - "A feline rested on the rug".to_string(), - "The weather is sunny today".to_string(), - ]), - s3_file: None, - save_to_s3: None, - }; + let request = create_text_request(vec![ + "The cat sat on the mat".to_string(), + "A feline rested on the rug".to_string(), + "The weather is sunny today".to_string(), + ]); let event = create_lambda_event(request); let result = parse_response(handler(event).await.unwrap()).unwrap(); + let embeddings = result.embeddings.unwrap(); // Calculate cosine similarity let cosine_sim = |a: &[f32], b: &[f32]| -> f32 { @@ -479,8 +367,8 @@ mod tests { dot / (mag_a * mag_b) }; - let sim_similar = cosine_sim(&result.embeddings[0], &result.embeddings[1]); - let sim_different = cosine_sim(&result.embeddings[0], &result.embeddings[2]); + let sim_similar = cosine_sim(&embeddings[0], &embeddings[1]); + let sim_different = cosine_sim(&embeddings[0], &embeddings[2]); // Similar sentences should have higher similarity than different ones assert!( @@ -489,73 +377,6 @@ mod tests { sim_similar, sim_different ); - - // Similar texts should have reasonably high similarity (typically > 0.5 for related texts) - assert!( - sim_similar > 0.3, - "Similar texts should have cosine similarity > 0.3, got: {}", - sim_similar - ); - } - - #[tokio::test] - async fn test_whitespace_handling() { - // Test various whitespace scenarios - let texts = vec![ - "Hello world".to_string(), - "Hello world".to_string(), // double space - " Hello world ".to_string(), // leading/trailing spaces - "Hello\nworld".to_string(), // newline - "Hello\tworld".to_string(), // tab - ]; - - let request = Request { - messages: Some(texts), - s3_file: None, - save_to_s3: None, - }; - - let event = create_lambda_event(request); - let result = handler(event).await; - - assert!( - result.is_ok(), - "Handler should handle whitespace variations" - ); - - let response = parse_response(result.unwrap()).unwrap(); - assert_eq!(response.embeddings.len(), 5); - assert_eq!(response.dimension, 384); - } - - #[tokio::test] - async fn test_mixed_length_messages() { - // Test batch with varying message lengths - let request = Request { - messages: Some(vec![ - "Short".to_string(), - "Medium length message here".to_string(), - "This is a much longer message that contains multiple sentences and covers various topics to test the embedding model's ability to handle varying input lengths effectively.".to_string(), - ]), - s3_file: None, - save_to_s3: None, - }; - - let event = create_lambda_event(request); - let result = handler(event).await; - - assert!( - result.is_ok(), - "Handler should handle mixed length messages" - ); - - let response = parse_response(result.unwrap()).unwrap(); - assert_eq!(response.embeddings.len(), 3); - - // All should be same dimension - for embedding in &response.embeddings { - assert_eq!(embedding.len(), 384); - } } } @@ -577,30 +398,6 @@ mod save_config_tests { assert_eq!(cfg.bucket, "my-bucket"); assert_eq!(cfg.key, "embeddings/test.json"); } - - #[test] - fn test_request_with_save_config() { - let json_str = r#"{ - "messages": ["Test"], - "save_to_s3": { - "bucket": "test-bucket", - "key": "test.json" - } - }"#; - - let request: Result = serde_json::from_str(json_str); - assert!( - request.is_ok(), - "Request with save_to_s3 should deserialize" - ); - - let req = request.unwrap(); - assert!(req.save_to_s3.is_some()); - - let save_cfg = req.save_to_s3.unwrap(); - assert_eq!(save_cfg.bucket, "test-bucket"); - assert_eq!(save_cfg.key, "test.json"); - } } #[cfg(test)] @@ -655,9 +452,10 @@ mod api_gateway_tests { // Parse the body to check the actual response let response: Response = serde_json::from_str(&api_response.body).unwrap(); - assert_eq!(response.dimension, 384); - assert_eq!(response.embeddings.len(), 1); - assert!(response.embeddings[0].len() > 0); + assert_eq!(response.dimension.unwrap_or(0), 384); + let embeddings = response.embeddings.unwrap(); + assert_eq!(embeddings.len(), 1); + assert!(embeddings[0].len() > 0); assert!(response.s3_location.is_none()); } @@ -676,8 +474,9 @@ mod api_gateway_tests { assert_eq!(api_response.status_code, 200); let response: Response = serde_json::from_str(&api_response.body).unwrap(); - assert_eq!(response.embeddings.len(), 3); - assert_eq!(response.dimension, 384); + let embeddings = response.embeddings.unwrap(); + assert_eq!(embeddings.len(), 3); + assert_eq!(response.dimension.unwrap_or(0), 384); } #[tokio::test] @@ -709,60 +508,10 @@ mod api_gateway_tests { let error_body: Value = serde_json::from_str(&api_response.body).unwrap(); assert!(error_body.get("error").is_some()); - assert!( - error_body["error"] - .as_str() - .unwrap() - .contains("Missing body") - ); - } - - #[tokio::test] - async fn test_api_gateway_invalid_json_body() { - let body = r#"{"messages": invalid json}"#; - let event = create_api_gateway_event(body); - - let result = handler(event).await; - assert!(result.is_ok(), "Handler should return error response"); - - let api_response = parse_api_gateway_response(result.unwrap()).unwrap(); - assert_eq!( - api_response.status_code, 400, - "Should return 400 for invalid JSON" - ); - - let error_body: Value = serde_json::from_str(&api_response.body).unwrap(); - assert!(error_body.get("error").is_some()); - assert!( - error_body["error"] - .as_str() - .unwrap() - .contains("Failed to parse body") - ); - } - - #[tokio::test] - async fn test_api_gateway_missing_required_fields() { - let body = r#"{"some_other_field": "value"}"#; - let event = create_api_gateway_event(body); - - let result = handler(event).await; - assert!(result.is_ok(), "Handler should return error response"); - - let api_response = parse_api_gateway_response(result.unwrap()).unwrap(); - assert_eq!( - api_response.status_code, 500, - "Should return 500 for processing error" - ); - - let error_body: Value = serde_json::from_str(&api_response.body).unwrap(); - assert!(error_body.get("error").is_some()); - assert!( - error_body["error"] - .as_str() - .unwrap() - .contains("Either 'messages' or 's3_file' must be provided") - ); + assert!(error_body["error"] + .as_str() + .unwrap() + .contains("Missing body")); } #[tokio::test] @@ -776,262 +525,4 @@ mod api_gateway_tests { let api_response = parse_api_gateway_response(result.unwrap()).unwrap(); assert_eq!(api_response.status_code, 500); } - - #[tokio::test] - async fn test_api_gateway_with_s3_file() { - let body = r#"{"s3_file": "my-bucket/test.txt"}"#; - let event = create_api_gateway_event(body); - - let result = handler(event).await; - - // Note: This will fail in unit tests because S3 isn't available - // But we can verify it's handled as an API Gateway request - assert!(result.is_ok(), "Handler should return a response"); - - let api_response = parse_api_gateway_response(result.unwrap()).unwrap(); - // Should be 500 because S3 operation will fail in test environment - assert!(api_response.status_code == 200 || api_response.status_code == 500); - } - - #[tokio::test] - async fn test_api_gateway_batch_processing() { - let messages: Vec = (0..5).map(|i| format!("Batch message {}", i)).collect(); - - let body = json!({"messages": messages}).to_string(); - let event = create_api_gateway_event(&body); - - let result = handler(event).await; - assert!(result.is_ok(), "Handler should handle batch processing"); - - let api_response = parse_api_gateway_response(result.unwrap()).unwrap(); - assert_eq!(api_response.status_code, 200); - - let response: Response = serde_json::from_str(&api_response.body).unwrap(); - assert_eq!(response.embeddings.len(), 5); - assert_eq!(response.dimension, 384); - } - - #[tokio::test] - async fn test_api_gateway_long_message() { - let long_text = - "This is a very long message that simulates a real-world API Gateway request. " - .repeat(50); - let body = json!({ - "messages": [long_text] - }) - .to_string(); - - let event = create_api_gateway_event(&body); - - let result = handler(event).await; - assert!(result.is_ok(), "Handler should handle long messages"); - - let api_response = parse_api_gateway_response(result.unwrap()).unwrap(); - assert_eq!(api_response.status_code, 200); - - let response: Response = serde_json::from_str(&api_response.body).unwrap(); - assert_eq!(response.dimension, 384); - } - - #[tokio::test] - async fn test_api_gateway_special_characters_in_body() { - let body = r#"{"messages": ["Special chars: 世界 🌍 @#$%^&*()"]}"#; - let event = create_api_gateway_event(body); - - let result = handler(event).await; - assert!(result.is_ok(), "Handler should handle special characters"); - - let api_response = parse_api_gateway_response(result.unwrap()).unwrap(); - assert_eq!(api_response.status_code, 200); - - let response: Response = serde_json::from_str(&api_response.body).unwrap(); - assert_eq!(response.dimension, 384); - } - - #[tokio::test] - async fn test_api_gateway_response_structure() { - let body = r#"{"messages": ["Test response structure"]}"#; - let event = create_api_gateway_event(body); - - let result = handler(event).await.unwrap(); - let api_response = parse_api_gateway_response(result).unwrap(); - - // Verify API Gateway response structure - assert!(api_response.status_code > 0); - assert!(api_response.headers.contains_key("Content-Type")); - assert!(!api_response.body.is_empty()); - - // Verify body is valid JSON - let parsed_body: Value = serde_json::from_str(&api_response.body).unwrap(); - assert!(parsed_body.is_object()); - } - - #[tokio::test] - async fn test_api_gateway_vs_direct_invocation() { - let message = "Compare invocation methods"; - - // API Gateway request - let api_body = json!({"messages": [message]}).to_string(); - let api_event = create_api_gateway_event(&api_body); - let api_result = handler(api_event).await.unwrap(); - let api_response = parse_api_gateway_response(api_result).unwrap(); - - // Direct Lambda invocation - let direct_request = Request { - messages: Some(vec![message.to_string()]), - s3_file: None, - save_to_s3: None, - }; - let direct_event = create_lambda_event(direct_request); - let direct_result = handler(direct_event).await.unwrap(); - let direct_response = parse_response(direct_result).unwrap(); - - // Parse API Gateway body - let api_body_response: Response = serde_json::from_str(&api_response.body).unwrap(); - - // Both should produce the same embedding - assert_eq!(api_body_response.dimension, direct_response.dimension); - assert_eq!( - api_body_response.embeddings[0].len(), - direct_response.embeddings[0].len() - ); - - // Embeddings should be identical - let embeddings_match = api_body_response.embeddings[0] - .iter() - .zip(direct_response.embeddings[0].iter()) - .all(|(a, b)| (a - b).abs() < 1e-6); - - assert!( - embeddings_match, - "Same message should produce identical embeddings regardless of invocation method" - ); - } - - #[tokio::test] - async fn test_api_gateway_multiple_sequential_requests() { - let messages = vec![ - vec!["First message"], - vec!["Second message"], - vec!["Third message"], - ]; - - for (i, msg) in messages.iter().enumerate() { - let body = json!({"messages": msg}).to_string(); - let event = create_api_gateway_event(&body); - - let result = handler(event).await; - assert!(result.is_ok(), "Request {} should succeed", i + 1); - - let api_response = parse_api_gateway_response(result.unwrap()).unwrap(); - assert_eq!(api_response.status_code, 200); - - let response: Response = serde_json::from_str(&api_response.body).unwrap(); - assert_eq!(response.dimension, 384); - } - } - - #[tokio::test] - async fn test_api_gateway_empty_message() { - let body = r#"{"messages": [""]}"#; - let event = create_api_gateway_event(body); - - let result = handler(event).await; - assert!(result.is_ok(), "Handler should handle empty message"); - - let api_response = parse_api_gateway_response(result.unwrap()).unwrap(); - assert_eq!(api_response.status_code, 200); - - let response: Response = serde_json::from_str(&api_response.body).unwrap(); - assert_eq!(response.dimension, 384); - } - - #[tokio::test] - async fn test_api_gateway_with_different_http_methods() { - let body = r#"{"messages": ["Test message"]}"#; - - // Test with different HTTP methods - for method in &["POST", "GET", "PUT", "DELETE"] { - let payload = json!({ - "httpMethod": method, - "path": "/embed", - "headers": { - "Content-Type": "application/json" - }, - "requestContext": { - "requestId": "test-request-id" - }, - "body": body - }); - - let event = LambdaEvent { - payload, - context: create_test_context(), - }; - - let result = handler(event).await; - assert!(result.is_ok(), "Handler should process {} request", method); - - let api_response = parse_api_gateway_response(result.unwrap()).unwrap(); - // Should succeed regardless of HTTP method (business logic doesn't check method) - assert!(api_response.status_code == 200 || api_response.status_code == 400); - } - } - - #[tokio::test] - async fn test_api_gateway_headers_in_response() { - let body = r#"{"messages": ["Test headers"]}"#; - let event = create_api_gateway_event(body); - - let result = handler(event).await.unwrap(); - let api_response = parse_api_gateway_response(result).unwrap(); - - // Verify headers are present and correct - assert!(api_response.headers.contains_key("Content-Type")); - assert_eq!( - api_response.headers.get("Content-Type").unwrap(), - "application/json" - ); - } - - #[tokio::test] - async fn test_api_gateway_request_with_save_config() { - let body = r#"{ - "messages": ["Test save to S3"], - "save_to_s3": { - "bucket": "test-bucket", - "key": "embeddings/test.json" - } - }"#; - let event = create_api_gateway_event(body); - - let result = handler(event).await; - assert!( - result.is_ok(), - "Handler should process request with save config" - ); - - let api_response = parse_api_gateway_response(result.unwrap()).unwrap(); - // Will likely be 500 in test environment due to S3, but request is processed - assert!(api_response.status_code == 200 || api_response.status_code == 500); - } - - #[tokio::test] - async fn test_api_gateway_mixed_batch() { - let body = r#"{"messages": ["Short", "Medium length text here", "This is a much longer message with multiple sentences and various content to test processing."]}"#; - let event = create_api_gateway_event(body); - - let result = handler(event).await; - assert!(result.is_ok(), "Handler should handle mixed length batch"); - - let api_response = parse_api_gateway_response(result.unwrap()).unwrap(); - assert_eq!(api_response.status_code, 200); - - let response: Response = serde_json::from_str(&api_response.body).unwrap(); - assert_eq!(response.embeddings.len(), 3); - - for embedding in &response.embeddings { - assert_eq!(embedding.len(), 384); - } - } }