diff --git a/.github/workflows/build-on-release.yml b/.github/workflows/build-on-release.yml index 9a133ac..90b42ee 100644 --- a/.github/workflows/build-on-release.yml +++ b/.github/workflows/build-on-release.yml @@ -69,21 +69,222 @@ jobs: strategy: matrix: include: - - variant: models--Qdrant--multilingual-e5-large-onnx - model_type: multilingual-e5-large - platform: linux/amd64 - variant: models--Xenova--all-mpnet-base-v2 model_type: all-mpnet-base-v2 + model_id: Xenova/all-mpnet-base-v2 + dimension: 768 platform: linux/amd64 - - variant: models--Xenova--bge-base-en-v1.5 - model_type: bge-base-en-v1.5 + - variant: models--Alibaba-NLP--gte-large-en-v1.5 + model_type: gte-large-en-v1.5 + model_id: Alibaba-NLP/gte-large-en-v1.5 + dimension: 1024 + platform: linux/amd64 + - variant: models--jinaai--jina-embeddings-v2-base-code + model_type: jina-embeddings-v2-base-code + model_id: jinaai/jina-embeddings-v2-base-code + dimension: 768 platform: linux/amd64 - variant: models--Xenova--bge-large-en-v1.5 model_type: bge-large-en-v1.5 + model_id: Xenova/bge-large-en-v1.5 + dimension: 1024 + platform: linux/amd64 + - variant: models--intfloat--multilingual-e5-small + model_type: multilingual-e5-small + model_id: intfloat/multilingual-e5-small + dimension: 384 + platform: linux/amd64 + - variant: models--mixedbread-ai--mxbai-embed-large-v1 + model_type: mxbai-embed-large-v1 + model_id: mixedbread-ai/mxbai-embed-large-v1 + dimension: 1024 + platform: linux/amd64 + - variant: models--snowflake--snowflake-arctic-embed-m-long + model_type: snowflake-arctic-embed-m-long + model_id: snowflake/snowflake-arctic-embed-m-long + dimension: 768 + platform: linux/amd64 + - variant: models--nomic-ai--nomic-embed-text-v1.5 + model_type: nomic-embed-text-v1.5 + model_id: nomic-ai/nomic-embed-text-v1.5 + dimension: 768 + platform: linux/amd64 + - variant: models--Snowflake--snowflake-arctic-embed-m + model_type: snowflake-arctic-embed-m + model_id: Snowflake/snowflake-arctic-embed-m + 
dimension: 768 + platform: linux/amd64 + - variant: models--Xenova--bge-base-en-v1.5 + model_type: bge-base-en-v1.5 + model_id: Xenova/bge-base-en-v1.5 + dimension: 768 platform: linux/amd64 - variant: models--Xenova--bge-small-en-v1.5 model_type: bge-small-en-v1.5 + model_id: Xenova/bge-small-en-v1.5 + dimension: 384 + platform: linux/amd64 + - variant: models--mixedbread-ai--mxbai-embed-large-v1 + model_type: mxbai-embed-large-v1 + model_id: mixedbread-ai/mxbai-embed-large-v1 + dimension: 1024 + platform: linux/amd64 + - variant: models--onnx-community--embeddinggemma-300m-ONNX + model_type: embeddinggemma-300m-ONNX + model_id: onnx-community/embeddinggemma-300m-ONNX + dimension: 768 + platform: linux/amd64 + - variant: models--Qdrant--multilingual-e5-large-onnx + model_type: multilingual-e5-large-onnx + model_id: Qdrant/multilingual-e5-large-onnx + dimension: 1024 + platform: linux/amd64 + - variant: models--Qdrant--paraphrase-multilingual-MiniLM-L12-v2-onnx-Q + model_type: paraphrase-multilingual-MiniLM-L12-v2-onnx-Q + model_id: Qdrant/paraphrase-multilingual-MiniLM-L12-v2-onnx-Q + dimension: 384 + platform: linux/amd64 + - variant: models--Xenova--paraphrase-multilingual-mpnet-base-v2 + model_type: paraphrase-multilingual-mpnet-base-v2 + model_id: Xenova/paraphrase-multilingual-mpnet-base-v2 + dimension: 768 + platform: linux/amd64 + - variant: models--lightonai--modernbert-embed-large + model_type: modernbert-embed-large + model_id: lightonai/modernbert-embed-large + dimension: 1024 + platform: linux/amd64 + - variant: models--Xenova--all-MiniLM-L12-v2 + model_type: all-MiniLM-L12-v2 + model_id: Xenova/all-MiniLM-L12-v2 + dimension: 384 + platform: linux/amd64 + - variant: models--snowflake--snowflake-arctic-embed-m-long + model_type: snowflake-arctic-embed-m-long + model_id: snowflake/snowflake-arctic-embed-m-long + dimension: 768 + platform: linux/amd64 + - variant: models--snowflake--snowflake-arctic-embed-l + model_type: snowflake-arctic-embed-l + 
model_id: snowflake/snowflake-arctic-embed-l + dimension: 1024 + platform: linux/amd64 + - variant: models--Xenova--bge-large-zh-v1.5 + model_type: bge-large-zh-v1.5 + model_id: Xenova/bge-large-zh-v1.5 + dimension: 1024 platform: linux/amd64 + - variant: models--Qdrant--all-MiniLM-L6-v2-onnx + model_type: all-MiniLM-L6-v2-onnx + model_id: Qdrant/all-MiniLM-L6-v2-onnx + dimension: 384 + platform: linux/amd64 + - variant: models--intfloat--multilingual-e5-base + model_type: multilingual-e5-base + model_id: intfloat/multilingual-e5-base + dimension: 768 + platform: linux/amd64 + - variant: models--Alibaba-NLP--gte-large-en-v1.5 + model_type: gte-large-en-v1.5 + model_id: Alibaba-NLP/gte-large-en-v1.5 + dimension: 1024 + platform: linux/amd64 + - variant: models--Snowflake--snowflake-arctic-embed-m + model_type: snowflake-arctic-embed-m + model_id: Snowflake/snowflake-arctic-embed-m + dimension: 768 + platform: linux/amd64 + - variant: models--snowflake--snowflake-arctic-embed-xs + model_type: snowflake-arctic-embed-xs + model_id: snowflake/snowflake-arctic-embed-xs + dimension: 384 + platform: linux/amd64 + - variant: models--BAAI--bge-m3 + model_type: bge-m3 + model_id: BAAI/bge-m3 + dimension: 1024 + platform: linux/amd64 + - variant: models--snowflake--snowflake-arctic-embed-s + model_type: snowflake-arctic-embed-s + model_id: snowflake/snowflake-arctic-embed-s + dimension: 384 + platform: linux/amd64 + - variant: models--Qdrant--clip-ViT-B-32-text + model_type: clip-ViT-B-32-text + model_id: Qdrant/clip-ViT-B-32-text + dimension: 512 + platform: linux/amd64 + - variant: models--nomic-ai--nomic-embed-text-v1 + model_type: nomic-embed-text-v1 + model_id: nomic-ai/nomic-embed-text-v1 + dimension: 768 + platform: linux/amd64 + - variant: models--Xenova--all-MiniLM-L6-v2 + model_type: all-MiniLM-L6-v2 + model_id: Xenova/all-MiniLM-L6-v2 + dimension: 384 + platform: linux/amd64 + - variant: models--Alibaba-NLP--gte-base-en-v1.5 + model_type: gte-base-en-v1.5 + 
model_id: Alibaba-NLP/gte-base-en-v1.5 + dimension: 768 + platform: linux/amd64 + - variant: models--Xenova--all-MiniLM-L12-v2 + model_type: all-MiniLM-L12-v2 + model_id: Xenova/all-MiniLM-L12-v2 + dimension: 384 + platform: linux/amd64 + - variant: models--Qdrant--bge-base-en-v1.5-onnx-Q + model_type: bge-base-en-v1.5-onnx-Q + model_id: Qdrant/bge-base-en-v1.5-onnx-Q + dimension: 768 + platform: linux/amd64 + - variant: models--snowflake--snowflake-arctic-embed-s + model_type: snowflake-arctic-embed-s + model_id: snowflake/snowflake-arctic-embed-s + dimension: 384 + platform: linux/amd64 + - variant: models--Qdrant--bge-large-en-v1.5-onnx-Q + model_type: bge-large-en-v1.5-onnx-Q + model_id: Qdrant/bge-large-en-v1.5-onnx-Q + dimension: 1024 + platform: linux/amd64 + - variant: models--Qdrant--bge-small-en-v1.5-onnx-Q + model_type: bge-small-en-v1.5-onnx-Q + model_id: Qdrant/bge-small-en-v1.5-onnx-Q + dimension: 384 + platform: linux/amd64 + - variant: models--Alibaba-NLP--gte-base-en-v1.5 + model_type: gte-base-en-v1.5 + model_id: Alibaba-NLP/gte-base-en-v1.5 + dimension: 768 + platform: linux/amd64 + - variant: models--snowflake--snowflake-arctic-embed-xs + model_type: snowflake-arctic-embed-xs + model_id: snowflake/snowflake-arctic-embed-xs + dimension: 384 + platform: linux/amd64 + - variant: models--snowflake--snowflake-arctic-embed-l + model_type: snowflake-arctic-embed-l + model_id: snowflake/snowflake-arctic-embed-l + dimension: 1024 + platform: linux/amd64 + - variant: models--Xenova--paraphrase-multilingual-MiniLM-L12-v2 + model_type: paraphrase-multilingual-MiniLM-L12-v2 + model_id: Xenova/paraphrase-multilingual-MiniLM-L12-v2 + dimension: 384 + platform: linux/amd64 + - variant: models--nomic-ai--nomic-embed-text-v1.5 + model_type: nomic-embed-text-v1.5 + model_id: nomic-ai/nomic-embed-text-v1.5 + dimension: 768 + platform: linux/amd64 + - variant: models--Xenova--bge-small-zh-v1.5 + model_type: bge-small-zh-v1.5 + model_id: Xenova/bge-small-zh-v1.5 + 
dimension: 512 + platform: linux/amd64 + steps: - name: Checkout code uses: actions/checkout@v4 @@ -119,14 +320,13 @@ jobs: with: images: ${{ secrets.DOCKER_USERNAME }}/serverless-vectorizer tags: | - type=raw,value=${{ steps.get_tag.outputs.tag_name }}-${{ matrix.model_type }} - type=raw,value=latest-${{ matrix.model_type }} - type=raw,value=${{ matrix.model_type }} + type=raw,value=${{ steps.get_tag.outputs.tag_name }}-${{ matrix.model_id }} + type=raw,value=latest-${{ matrix.model_id }} + type=raw,value=${{ matrix.model_id }} - name: Build and push variant image uses: docker/build-push-action@v6 with: - file: Dockerfile.variant platforms: ${{ matrix.platform }} push: true @@ -134,7 +334,6 @@ jobs: labels: ${{ steps.meta.outputs.labels }} build-args: | BASE_IMAGE=${{ secrets.DOCKER_USERNAME }}/serverless-vectorizer:base-${{ steps.get_tag.outputs.tag_name }} - VARIANT=${{ matrix.variant }} - MODEL_TYPE=${{ matrix.model_type }} + MODEL_ID=${{ matrix.model_id }} cache-from: type=gha cache-to: type=gha,mode=max \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 6756956..fc1f585 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,6 +45,10 @@ path = "src/bin/preload.rs" name = "embed-cli" path = "src/bin/cli.rs" +[[bin]] +name = "list-models" +path = "src/bin/list-models.rs" + [profile.release] opt-level = "z" lto = true diff --git a/Dockerfile.variant b/Dockerfile.variant index 8e12ac5..e7757a5 100644 --- a/Dockerfile.variant +++ b/Dockerfile.variant @@ -2,12 +2,10 @@ ARG BASE_IMAGE FROM ${BASE_IMAGE} -ARG VARIANT -ARG MODEL_TYPE +ARG MODEL_ID -ENV MODEL_VARIANT=${VARIANT} -ENV MODEL_TYPE=${MODEL_TYPE} +ENV MODEL_ID=${MODEL_ID} -RUN cd ${LAMBDA_TASK_ROOT}/ && ${LAMBDA_RUNTIME_DIR}/preload ${MODEL_TYPE} +RUN cd ${LAMBDA_TASK_ROOT}/ && ${LAMBDA_RUNTIME_DIR}/preload ${MODEL_ID} CMD [ "bootstrap" ] \ No newline at end of file diff --git a/README.md b/README.md index e51971c..e6d943b 100644 --- a/README.md +++ b/README.md @@ -6,84 +6,195 @@ [![Docker 
Pulls](https://img.shields.io/docker/pulls/johnnywale/serverless-vectorizer)](https://hub.docker.com/r/johnnywale/serverless-vectorizer) [![Docker Image Size](https://img.shields.io/docker/image-size/johnnywale/serverless-vectorizer/latest)](https://hub.docker.com/r/johnnywale/serverless-vectorizer) -AWS Lambda container image for generating text embeddings. Models are pre-loaded into Docker images for fast cold -starts - one image per model variant. +AWS Lambda container image for generating embeddings using [fastembed-rs](https://github.com/Anush008/fastembed-rs). Supports **text embeddings**, **image embeddings**, **sparse embeddings**, and **reranking models**. Models are pre-loaded into Docker images for fast cold starts. + +## Prebuilt Docker Images + +The following text embedding models have prebuilt Docker images available on Docker Hub. You can pull and use them directly: + +| Model | Model ID | Dimension | Description | Docker Image | +|----------------------------------------------|-------------------------------------------------------|-----------|----------------------------------------------------------------------------|------------------------------------------------------------------------------------------------| +| All-MINILM-L12-v2 | `Xenova/all-MiniLM-L12-v2` | 384 | `Quantized Sentence Transformer model, MiniLM-L12-v2` | `johnnywalee/serverless-vectorizer:latest-Xenova/all-MiniLM-L12-v2` | +| Snowflake-Arctic-Embed-Xs | `snowflake/snowflake-arctic-embed-xs` | 384 | `Snowflake Arctic embed model, xs` | `johnnywalee/serverless-vectorizer:latest-snowflake/snowflake-arctic-embed-xs` | +| BGE-Small-ZH-v1.5 | `Xenova/bge-small-zh-v1.5` | 512 | `v1.5 release of the small Chinese model` | `johnnywalee/serverless-vectorizer:latest-Xenova/bge-small-zh-v1.5` | +| BGE-Small-EN-v1.5-Onnx-Q | `Qdrant/bge-small-en-v1.5-onnx-Q` | 384 | `Quantized v1.5 release of the fast and default English model` | 
`johnnywalee/serverless-vectorizer:latest-Qdrant/bge-small-en-v1.5-onnx-Q` | +| Snowflake-Arctic-Embed-S | `snowflake/snowflake-arctic-embed-s` | 384 | `Quantized Snowflake Arctic embed model, small` | `johnnywalee/serverless-vectorizer:latest-snowflake/snowflake-arctic-embed-s` | +| Snowflake-Arctic-Embed-M-Long | `snowflake/snowflake-arctic-embed-m-long` | 768 | `Snowflake Arctic embed model, medium with 2048 context` | `johnnywalee/serverless-vectorizer:latest-snowflake/snowflake-arctic-embed-m-long` | +| BGE-Base-EN-v1.5 | `Xenova/bge-base-en-v1.5` | 768 | `v1.5 release of the base English model` | `johnnywalee/serverless-vectorizer:latest-Xenova/bge-base-en-v1.5` | +| Snowflake-Arctic-Embed-M-Long | `snowflake/snowflake-arctic-embed-m-long` | 768 | `Quantized Snowflake Arctic embed model, medium with 2048 context` | `johnnywalee/serverless-vectorizer:latest-snowflake/snowflake-arctic-embed-m-long` | +| Paraphrase-Multilingual-MPNET-Base-v2 | `Xenova/paraphrase-multilingual-mpnet-base-v2` | 768 | `Sentence-transformers model for tasks like clustering or semantic search` | `johnnywalee/serverless-vectorizer:latest-Xenova/paraphrase-multilingual-mpnet-base-v2` | +| BGE-Large-ZH-v1.5 | `Xenova/bge-large-zh-v1.5` | 1024 | `v1.5 release of the large Chinese model` | `johnnywalee/serverless-vectorizer:latest-Xenova/bge-large-zh-v1.5` | +| Modernbert-Embed-Large | `lightonai/modernbert-embed-large` | 1024 | `Large model of ModernBert Text Embeddings` | `johnnywalee/serverless-vectorizer:latest-lightonai/modernbert-embed-large` | +| Multilingual-E5-Large-Onnx | `Qdrant/multilingual-e5-large-onnx` | 1024 | `Large model of multilingual E5 Text Embeddings` | `johnnywalee/serverless-vectorizer:latest-Qdrant/multilingual-e5-large-onnx` | +| BGE-Large-EN-v1.5 | `Xenova/bge-large-en-v1.5` | 1024 | `v1.5 release of the large English model` | `johnnywalee/serverless-vectorizer:latest-Xenova/bge-large-en-v1.5` | +| Multilingual-E5-Small | `intfloat/multilingual-e5-small` | 384 | 
`Small model of multilingual E5 Text Embeddings` | `johnnywalee/serverless-vectorizer:latest-intfloat/multilingual-e5-small` | +| Snowflake-Arctic-Embed-M | `Snowflake/snowflake-arctic-embed-m` | 768 | `Snowflake Arctic embed model, medium` | `johnnywalee/serverless-vectorizer:latest-Snowflake/snowflake-arctic-embed-m` | +| GTE-Large-EN-v1.5 | `Alibaba-NLP/gte-large-en-v1.5` | 1024 | `Large multilingual embedding model from Alibaba` | `johnnywalee/serverless-vectorizer:latest-Alibaba-NLP/gte-large-en-v1.5` | +| All-MPNET-Base-v2 | `Xenova/all-mpnet-base-v2` | 768 | `Sentence Transformer model, mpnet-base-v2` | `johnnywalee/serverless-vectorizer:latest-Xenova/all-mpnet-base-v2` | +| Nomic-Embed-Text-v1 | `nomic-ai/nomic-embed-text-v1` | 768 | `8192 context length english model` | `johnnywalee/serverless-vectorizer:latest-nomic-ai/nomic-embed-text-v1` | +| All-MINILM-L6-v2 | `Xenova/all-MiniLM-L6-v2` | 384 | `Quantized Sentence Transformer model, MiniLM-L6-v2` | `johnnywalee/serverless-vectorizer:latest-Xenova/all-MiniLM-L6-v2` | +| GTE-Base-EN-v1.5 | `Alibaba-NLP/gte-base-en-v1.5` | 768 | `Large multilingual embedding model from Alibaba` | `johnnywalee/serverless-vectorizer:latest-Alibaba-NLP/gte-base-en-v1.5` | +| GTE-Large-EN-v1.5 | `Alibaba-NLP/gte-large-en-v1.5` | 1024 | `Quantized Large multilingual embedding model from Alibaba` | `johnnywalee/serverless-vectorizer:latest-Alibaba-NLP/gte-large-en-v1.5` | +| Clip-ViT-B-32-Text | `Qdrant/clip-ViT-B-32-text` | 512 | `CLIP text encoder based on ViT-B/32` | `johnnywalee/serverless-vectorizer:latest-Qdrant/clip-ViT-B-32-text` | +| BGE-Base-EN-v1.5-Onnx-Q | `Qdrant/bge-base-en-v1.5-onnx-Q` | 768 | `Quantized v1.5 release of the large English model` | `johnnywalee/serverless-vectorizer:latest-Qdrant/bge-base-en-v1.5-onnx-Q` | +| BGE-Small-EN-v1.5 | `Xenova/bge-small-en-v1.5` | 384 | `v1.5 release of the fast and default English model` | `johnnywalee/serverless-vectorizer:latest-Xenova/bge-small-en-v1.5` | +| 
Snowflake-Arctic-Embed-S | `snowflake/snowflake-arctic-embed-s` | 384 | `Snowflake Arctic embed model, small` | `johnnywalee/serverless-vectorizer:latest-snowflake/snowflake-arctic-embed-s` | +| JINA-Embeddings-v2-Base-Code | `jinaai/jina-embeddings-v2-base-code` | 768 | `Jina embeddings v2 base code` | `johnnywalee/serverless-vectorizer:latest-jinaai/jina-embeddings-v2-base-code` | +| Snowflake-Arctic-Embed-L | `snowflake/snowflake-arctic-embed-l` | 1024 | `Quantized Snowflake Arctic embed model, large` | `johnnywalee/serverless-vectorizer:latest-snowflake/snowflake-arctic-embed-l` | +| All-MINILM-L6-v2-Onnx | `Qdrant/all-MiniLM-L6-v2-onnx` | 384 | `Sentence Transformer model, MiniLM-L6-v2` | `johnnywalee/serverless-vectorizer:latest-Qdrant/all-MiniLM-L6-v2-onnx` | +| Multilingual-E5-Base | `intfloat/multilingual-e5-base` | 768 | `Base model of multilingual E5 Text Embeddings` | `johnnywalee/serverless-vectorizer:latest-intfloat/multilingual-e5-base` | +| Paraphrase-Multilingual-MINILM-L12-v2-Onnx-Q | `Qdrant/paraphrase-multilingual-MiniLM-L12-v2-onnx-Q` | 384 | `Quantized Multi-lingual model` | `johnnywalee/serverless-vectorizer:latest-Qdrant/paraphrase-multilingual-MiniLM-L12-v2-onnx-Q` | +| GTE-Base-EN-v1.5 | `Alibaba-NLP/gte-base-en-v1.5` | 768 | `Quantized Large multilingual embedding model from Alibaba` | `johnnywalee/serverless-vectorizer:latest-Alibaba-NLP/gte-base-en-v1.5` | +| All-MINILM-L12-v2 | `Xenova/all-MiniLM-L12-v2` | 384 | `Sentence Transformer model, MiniLM-L12-v2` | `johnnywalee/serverless-vectorizer:latest-Xenova/all-MiniLM-L12-v2` | +| BGE-M3 | `BAAI/bge-m3` | 1024 | `Multilingual M3 model with 8192 context length, supports 100+ languages` | `johnnywalee/serverless-vectorizer:latest-BAAI/bge-m3` | +| Nomic-Embed-Text-v1.5 | `nomic-ai/nomic-embed-text-v1.5` | 768 | `v1.5 release of the 8192 context length english model` | `johnnywalee/serverless-vectorizer:latest-nomic-ai/nomic-embed-text-v1.5` | +| Nomic-Embed-Text-v1.5 | 
`nomic-ai/nomic-embed-text-v1.5` | 768 | `Quantized v1.5 release of the 8192 context length english model` | `johnnywalee/serverless-vectorizer:latest-nomic-ai/nomic-embed-text-v1.5` | +| Mxbai-Embed-Large-v1 | `mixedbread-ai/mxbai-embed-large-v1` | 1024 | `Large English embedding model from Mixedbread.ai` | `johnnywalee/serverless-vectorizer:latest-mixedbread-ai/mxbai-embed-large-v1` | +| Embeddinggemma-300m-ONNX | `onnx-community/embeddinggemma-300m-ONNX` | 768 | `EmbeddingGemma is a 300M parameter model from Google` | `johnnywalee/serverless-vectorizer:latest-onnx-community/embeddinggemma-300m-ONNX` | +| Snowflake-Arctic-Embed-Xs | `snowflake/snowflake-arctic-embed-xs` | 384 | `Quantized Snowflake Arctic embed model, xs` | `johnnywalee/serverless-vectorizer:latest-snowflake/snowflake-arctic-embed-xs` | +| Mxbai-Embed-Large-v1 | `mixedbread-ai/mxbai-embed-large-v1` | 1024 | `Quantized Large English embedding model from Mixedbread.ai` | `johnnywalee/serverless-vectorizer:latest-mixedbread-ai/mxbai-embed-large-v1` | +| Snowflake-Arctic-Embed-M | `Snowflake/snowflake-arctic-embed-m` | 768 | `Quantized Snowflake Arctic embed model, medium` | `johnnywalee/serverless-vectorizer:latest-Snowflake/snowflake-arctic-embed-m` | +| Snowflake-Arctic-Embed-L | `snowflake/snowflake-arctic-embed-l` | 1024 | `Snowflake Arctic embed model, large` | `johnnywalee/serverless-vectorizer:latest-snowflake/snowflake-arctic-embed-l` | +| Paraphrase-Multilingual-MINILM-L12-v2 | `Xenova/paraphrase-multilingual-MiniLM-L12-v2` | 384 | `Multi-lingual model` | `johnnywalee/serverless-vectorizer:latest-Xenova/paraphrase-multilingual-MiniLM-L12-v2` | +| BGE-Large-EN-v1.5-Onnx-Q | `Qdrant/bge-large-en-v1.5-onnx-Q` | 1024 | `Quantized v1.5 release of the large English model` | `johnnywalee/serverless-vectorizer:latest-Qdrant/bge-large-en-v1.5-onnx-Q` | + + +## Additional Supported Models + +The following models are supported by fastembed-rs and can be built using the [Building Your Own
Image](#building-your-own-image) instructions below. Prebuilt images are not yet available for these models. + +### Image Embedding Models + +| Model | Model ID | Dimension | Description | +|-------|----------|-----------|-------------| +| Clip-ViT-B-32-Vision | `Qdrant/clip-ViT-B-32-vision` | 512 | CLIP vision encoder based on ViT-B/32 | +| Resnet50-Onnx | `Qdrant/resnet50-onnx` | 2048 | ResNet-50 from "Deep Residual Learning for Image Recognition" | +| Unicom-ViT-B-16 | `Qdrant/Unicom-ViT-B-16` | 768 | Unicom Unicom-ViT-B-16 from open-metric-learning | +| Unicom-ViT-B-32 | `Qdrant/Unicom-ViT-B-32` | 512 | Unicom Unicom-ViT-B-32 from open-metric-learning | +| Nomic-Embed-Vision-v1.5 | `nomic-ai/nomic-embed-vision-v1.5` | 768 | Nomic NomicEmbedVisionV15 | + +### Sparse Text Embedding Models + +| Model | Model ID | Dimension | Description | +|-------|----------|-----------|-------------| +| Splade_PP_en_v1 | `Qdrant/Splade_PP_en_v1` | - | Splade sparse vector model for commercial use, v1 | +| BGE-M3 | `BAAI/bge-m3` | - | BGE-M3 sparse embedding model with 8192 context, supports 100+ languages | + +### Reranking Models + +| Model | Model ID | Dimension | Description | +|-------|----------|-----------|-------------| +| BGE-Reranker-Base | `BAAI/bge-reranker-base` | - | reranker model for English and Chinese | +| BGE-Reranker-v2-M3 | `rozgo/bge-reranker-v2-m3` | - | reranker model for multilingual | +| JINA-Reranker-v1-Turbo-EN | `jinaai/jina-reranker-v1-turbo-en` | - | reranker model for English | +| JINA-Reranker-v2-Base-Multilingual | `jinaai/jina-reranker-v2-base-multilingual` | - | reranker model for multilingual | + + + + +## Building Your Own Image + +The build process uses a two-stage approach: + +1. **Base image** - Contains the Lambda runtime and embedding binaries +2. 
**Variant image** - Extends the base image with a pre-loaded model for fast cold starts + +### Option 1: Use Pre-built Base Image (Recommended) + +Use the pre-built base image from Docker Hub to skip the base build step: -## Supported Models - -| Model | ID | Dimension | Language | -|-----------------------|-------------------------|-----------|--------------| -| BGE-Small-EN-v1.5 | `bge-small-en-v1.5` | 384 | English | -| BGE-Base-EN-v1.5 | `bge-base-en-v1.5` | 768 | English | -| BGE-Large-EN-v1.5 | `bge-large-en-v1.5` | 1024 | English | -| Multilingual-E5-Large | `multilingual-e5-large` | 1024 | Multilingual | -| All-MpNet-Base-v2 | `all-mpnet-base-v2` | 768 | English | - -All models support a maximum of 512 tokens per input text. +```bash +# Build a model variant using the pre-built base image +docker build \ + --build-arg BASE_IMAGE=johnnywalee/serverless-vectorizer:base-latest \ + --build-arg MODEL_ID=Xenova/all-MiniLM-L12-v2 \ + -f Dockerfile.variant \ + -t my-vectorizer:minilm . +``` -## Building +### Option 2: Build Everything from Source -### Base Image +#### Step 1: Build the Base Image ```bash docker build -t serverless-vectorizer:base . ``` -### Model-Specific Variants +#### Step 2: Build a Model Variant -Each model variant bakes the model files into the image for faster Lambda cold starts: +Use `Dockerfile.variant` with the following build arguments: + +- `BASE_IMAGE` - The base image to extend (your local build or `johnnywalee/serverless-vectorizer:base-latest`) +- `MODEL_ID` - The model ID from the [Supported Models](#supported-models) table above ```bash -# Build BGE-Small variant docker build \ --build-arg BASE_IMAGE=serverless-vectorizer:base \ - --build-arg VARIANT=bge-small \ - --build-arg MODEL_TYPE=bge-small-en-v1.5 \ + --build-arg MODEL_ID= \ + -f Dockerfile.variant \ + -t serverless-vectorizer: . 
+``` + +### Build Examples + +```bash +# BGE-Small (384 dimensions, English) +docker build \ + --build-arg BASE_IMAGE=johnnywalee/serverless-vectorizer:base-latest \ + --build-arg MODEL_ID=Xenova/bge-small-en-v1.5 \ -f Dockerfile.variant \ -t serverless-vectorizer:bge-small . -# Build BGE-Base variant +# BGE-M3 (1024 dimensions, Multilingual) docker build \ - --build-arg BASE_IMAGE=serverless-vectorizer:base \ - --build-arg VARIANT=bge-base \ - --build-arg MODEL_TYPE=bge-base-en-v1.5 \ + --build-arg BASE_IMAGE=johnnywalee/serverless-vectorizer:base-latest \ + --build-arg MODEL_ID=BAAI/bge-m3 \ -f Dockerfile.variant \ - -t serverless-vectorizer:bge-base . + -t serverless-vectorizer:bge-m3 . -# Build BGE-Large variant +# Snowflake Arctic Embed Large (1024 dimensions) docker build \ - --build-arg BASE_IMAGE=serverless-vectorizer:base \ - --build-arg VARIANT=bge-large \ - --build-arg MODEL_TYPE=bge-large-en-v1.5 \ + --build-arg BASE_IMAGE=johnnywalee/serverless-vectorizer:base-latest \ + --build-arg MODEL_ID=snowflake/snowflake-arctic-embed-l \ -f Dockerfile.variant \ - -t serverless-vectorizer:bge-large . + -t serverless-vectorizer:arctic-l . -# Build Multilingual-E5-Large variant +# Multilingual E5 Large (1024 dimensions) docker build \ - --build-arg BASE_IMAGE=serverless-vectorizer:base \ - --build-arg VARIANT=e5-large \ - --build-arg MODEL_TYPE=multilingual-e5-large \ + --build-arg BASE_IMAGE=johnnywalee/serverless-vectorizer:base-latest \ + --build-arg MODEL_ID=Qdrant/multilingual-e5-large-onnx \ -f Dockerfile.variant \ -t serverless-vectorizer:e5-large . -# Build All-MpNet variant +# All-MiniLM (384 dimensions, lightweight) docker build \ - --build-arg BASE_IMAGE=serverless-vectorizer:base \ - --build-arg VARIANT=mpnet \ - --build-arg MODEL_TYPE=all-mpnet-base-v2 \ + --build-arg BASE_IMAGE=johnnywalee/serverless-vectorizer:base-latest \ + --build-arg MODEL_ID=Xenova/all-MiniLM-L6-v2 \ -f Dockerfile.variant \ - -t serverless-vectorizer:mpnet . 
+ -t serverless-vectorizer:minilm . ``` -## Configuration +### List Available Models -Set the model via environment variable in your Lambda configuration: +Use the included CLI tool to list all supported models: ```bash -EMBEDDING_MODEL=bge-small-en-v1.5 +# Build and run the list-models tool +cargo run --bin list-models + +# Output as markdown table +cargo run --bin list-models -- -f markdown + +# Output as JSON +cargo run --bin list-models -- -f json + +# List all model categories (text, image, sparse, rerank) +cargo run --bin list-models -- -c all ``` -If not specified, defaults to `bge-small-en-v1.5`. + + + ## Usage @@ -310,8 +421,12 @@ aws lambda update-function-code \ ## Acknowledgments -This project is powered by [fastembed-rs](https://github.com/Anush008/fastembed-rs), a Rust library for fast, -lightweight text embedding generation. +This project is powered by [fastembed-rs](https://github.com/Anush008/fastembed-rs), a Rust library for fast, lightweight embedding generation. fastembed-rs supports: + +- **Text Embeddings** - Dense vector representations for semantic search and similarity +- **Image Embeddings** - Vision encoders like CLIP and ResNet for image similarity +- **Sparse Text Embeddings** - SPLADE models for hybrid search +- **Reranking Models** - Cross-encoder models for result reranking ## License diff --git a/src/bin/cli.rs b/src/bin/cli.rs index 260d9de..35063a5 100644 --- a/src/bin/cli.rs +++ b/src/bin/cli.rs @@ -1,10 +1,11 @@ // CLI tool for text embedding operations // Uses shared core functionality from the embedding_service library -use clap::{Parser, Subcommand, ValueEnum}; +use clap::{Parser, Subcommand}; +use fastembed::{EmbeddingModel, InitOptions, TextEmbedding}; use serverless_vectorizer::{ // Core functionality - ModelType, EmbeddingService, + ModelRegistry, cosine_similarity, l2_normalize, pairwise_similarity_matrix, top_k_similar, compute_stats, validate_embeddings, kmeans_cluster, chunk_text, @@ -13,32 +14,30 @@ use 
serverless_vectorizer::{ ClusterResponse, DistanceMatrixResponse, BenchmarkResult, }; #[cfg(feature = "pdf")] -use embedding_service::{extract_text_from_file, is_pdf_file}; +use serverless_vectorizer::{extract_text_from_file, is_pdf_file}; +use std::collections::HashMap; use std::fs; use std::io::{self, Read}; use std::path::PathBuf; +use std::sync::Mutex; use std::time::Instant; -/// CLI model choice enum (maps to core ModelType) -#[derive(Debug, Clone, ValueEnum)] -enum ModelChoice { - BgeSmall, - BgeBase, - BgeLarge, - MultilingualE5, - AllMpnet, +/// Resolve model name to EmbeddingModel, with helpful error on failure +fn resolve_model(model_name: &str) -> Result { + ModelRegistry::find_text_model(model_name).ok_or_else(|| { + let mut msg = format!("Unknown model: '{}'\n\nAvailable text embedding models:\n", model_name); + for info in ModelRegistry::text_embedding_models() { + msg.push_str(&format!(" - {} ({}D)\n", info.model_id, info.dimension.unwrap_or(0))); + } + msg + }) } -impl ModelChoice { - fn to_model_type(&self) -> ModelType { - match self { - ModelChoice::BgeSmall => ModelType::BgeSmallEnV15, - ModelChoice::BgeBase => ModelType::BgeBaseEnV15, - ModelChoice::BgeLarge => ModelType::BgeLargeEnV15, - ModelChoice::MultilingualE5 => ModelType::MultilingualE5Large, - ModelChoice::AllMpnet => ModelType::AllMpnetBaseV2, - } - } +/// Get model info (id and dimension) for display +fn get_model_display_info(model: &EmbeddingModel) -> (String, usize) { + ModelRegistry::get_text_model_info(model) + .map(|info| (info.model_id, info.dimension.unwrap_or(0))) + .unwrap_or_else(|| ("unknown".to_string(), 0)) } #[derive(Parser)] @@ -67,8 +66,9 @@ enum Commands { pretty: bool, #[arg(long)] vectors_only: bool, - #[arg(long, value_enum, default_value = "bge-small")] - model: ModelChoice, + /// Model name (use 'info' command to list available models) + #[arg(long, default_value = "Xenova/bge-small-en-v1.5")] + model: String, }, /// Compute similarity between two texts @@ 
-77,8 +77,8 @@ enum Commands { text1: String, #[arg(long)] text2: String, - #[arg(long, value_enum, default_value = "bge-small")] - model: ModelChoice, + #[arg(long, default_value = "Xenova/bge-small-en-v1.5")] + model: String, }, /// Embed multiple texts from a file @@ -91,12 +91,16 @@ enum Commands { input_format: String, #[arg(long)] pretty: bool, - #[arg(long, value_enum, default_value = "bge-small")] - model: ModelChoice, + #[arg(long, default_value = "Xenova/bge-small-en-v1.5")] + model: String, }, /// Show model information - Info, + Info { + /// Filter by model category (text, image, sparse, rerank) + #[arg(long)] + category: Option, + }, /// Search for similar texts in a corpus Search { @@ -108,8 +112,8 @@ enum Commands { top_k: usize, #[arg(long, default_value = "texts")] corpus_format: String, - #[arg(long, value_enum, default_value = "bge-small")] - model: ModelChoice, + #[arg(long, default_value = "Xenova/bge-small-en-v1.5")] + model: String, }, /// Compute pairwise similarity matrix @@ -122,8 +126,8 @@ enum Commands { input_format: String, #[arg(long)] pretty: bool, - #[arg(long, value_enum, default_value = "bge-small")] - model: ModelChoice, + #[arg(long, default_value = "Xenova/bge-small-en-v1.5")] + model: String, }, /// L2 normalize embeddings @@ -170,8 +174,8 @@ enum Commands { iterations: usize, #[arg(long, default_value = "50")] text_length: usize, - #[arg(long, value_enum, default_value = "bge-small")] - model: ModelChoice, + #[arg(long, default_value = "Xenova/bge-small-en-v1.5")] + model: String, }, /// Cluster texts using k-means @@ -188,8 +192,8 @@ enum Commands { input_format: String, #[arg(long)] pretty: bool, - #[arg(long, value_enum, default_value = "bge-small")] - model: ModelChoice, + #[arg(long, default_value = "Xenova/bge-small-en-v1.5")] + model: String, }, /// Extract text from PDF and optionally embed (requires 'pdf' feature) @@ -224,14 +228,50 @@ enum Commands { pretty: bool, /// Embedding model to use (when --embed is used) - 
#[arg(long, value_enum, default_value = "bge-small")] - model: ModelChoice, + #[arg(long, default_value = "Xenova/bge-small-en-v1.5")] + model: String, }, } +// Embedding service with model caching +struct EmbeddingService { + models: Mutex>, + show_progress: bool, +} + +impl EmbeddingService { + fn new() -> Self { + Self { + models: Mutex::new(HashMap::new()), + show_progress: true, + } + } + + fn embed(&self, texts: Vec, model: EmbeddingModel) -> Result>, Box> { + let model_key = format!("{:?}", model); + let mut models = self.models.lock().map_err(|e| e.to_string())?; + + if !models.contains_key(&model_key) { + let text_embedding = TextEmbedding::try_new( + InitOptions::new(model.clone()).with_show_download_progress(self.show_progress), + )?; + models.insert(model_key.clone(), text_embedding); + } + + let text_embedding = models.get_mut(&model_key).unwrap(); + let embeddings = text_embedding.embed(texts, None)?; + Ok(embeddings) + } + + fn embed_one(&self, text: &str, model: EmbeddingModel) -> Result, Box> { + let embeddings = self.embed(vec![text.to_string()], model)?; + embeddings.into_iter().next().ok_or_else(|| "No embedding generated".into()) + } +} + // Global embedding service static SERVICE: std::sync::LazyLock = - std::sync::LazyLock::new(|| EmbeddingService::new().with_progress(true)); + std::sync::LazyLock::new(|| EmbeddingService::new()); fn get_input(text: Option, file: Option) -> Result, Box> { if let Some(t) = text { @@ -286,9 +326,10 @@ fn main() -> Result<(), Box> { match cli.command { Commands::Embed { text, file, output, format, pretty, vectors_only, model } => { - let model_type = model.to_model_type(); + let embedding_model = resolve_model(&model)?; + let (model_id, _) = get_model_display_info(&embedding_model); let texts = get_input(text, file)?; - let embeddings = SERVICE.embed(texts, model_type)?; + let embeddings = SERVICE.embed(texts, embedding_model)?; if format == "binary" { let mut bytes = Vec::new(); @@ -307,7 +348,7 @@ fn main() -> 
Result<(), Box> { let result = if vectors_only { serde_json::to_value(&embeddings)? } else { - let out = EmbeddingOutput::new(embeddings).with_model(model_type.display_name()); + let out = EmbeddingOutput::new(embeddings).with_model(&model_id); serde_json::to_value(&out)? }; let output_str = if pretty { @@ -320,14 +361,15 @@ fn main() -> Result<(), Box> { } Commands::Similarity { text1, text2, model } => { - let model_type = model.to_model_type(); - let embeddings = SERVICE.embed(vec![text1, text2], model_type)?; + let embedding_model = resolve_model(&model)?; + let embeddings = SERVICE.embed(vec![text1, text2], embedding_model)?; let similarity = cosine_similarity(&embeddings[0], &embeddings[1]); println!("{:.6}", similarity); } Commands::Batch { file, output, input_format, pretty, model } => { - let model_type = model.to_model_type(); + let embedding_model = resolve_model(&model)?; + let (model_id, _) = get_model_display_info(&embedding_model); let texts = parse_texts_from_file(&file, &input_format)?; if texts.is_empty() { @@ -336,10 +378,10 @@ fn main() -> Result<(), Box> { } eprintln!("Embedding {} texts...", texts.len()); - let embeddings = SERVICE.embed(texts, model_type)?; + let embeddings = SERVICE.embed(texts, embedding_model)?; eprintln!("Done."); - let out = EmbeddingOutput::new(embeddings).with_model(model_type.display_name()); + let out = EmbeddingOutput::new(embeddings).with_model(&model_id); let output_str = if pretty { serde_json::to_string_pretty(&out)? 
} else { @@ -348,21 +390,57 @@ fn main() -> Result<(), Box> { write_output(&output_str, output)?; } - Commands::Info => { - println!("Embedding Model Information"); - println!("===========================\n"); - println!("Available Models:"); - for model_type in ModelType::all() { - println!(" {:20} {:25} ({} dim, {})", - model_type.id(), - model_type.display_name(), - model_type.dimension(), - model_type.language() - ); + Commands::Info { category } => { + println!("Embedding Model Information (fastembed)"); + println!("=======================================\n"); + + let show_text = category.is_none() || category.as_deref() == Some("text"); + let show_image = category.is_none() || category.as_deref() == Some("image"); + let show_sparse = category.is_none() || category.as_deref() == Some("sparse"); + let show_rerank = category.is_none() || category.as_deref() == Some("rerank"); + + if show_text { + let text_models = ModelRegistry::text_embedding_models(); + println!("Text Embedding Models ({}):", text_models.len()); + println!("{:-<60}", ""); + for info in text_models { + println!(" {:45} {:>5}D", info.model_id, info.dimension.unwrap_or(0)); + } + println!(); + } + + if show_image { + let image_models = ModelRegistry::image_embedding_models(); + println!("Image Embedding Models ({}):", image_models.len()); + println!("{:-<60}", ""); + for info in image_models { + println!(" {:45} {:>5}D", info.model_id, info.dimension.unwrap_or(0)); + } + println!(); + } + + if show_sparse { + let sparse_models = ModelRegistry::sparse_text_embedding_models(); + println!("Sparse Text Embedding Models ({}):", sparse_models.len()); + println!("{:-<60}", ""); + for info in sparse_models { + println!(" {}", info.model_id); + } + println!(); } - println!("\nDefault: {}", ModelType::default().id()); - println!("Max Tokens: 512"); - println!("Provider: BAAI/Sentence-Transformers (via fastembed)"); + + if show_rerank { + let rerank_models = ModelRegistry::rerank_models(); + 
println!("Reranking Models ({}):", rerank_models.len()); + println!("{:-<60}", ""); + for info in rerank_models { + println!(" {}", info.model_id); + } + println!(); + } + + println!("Default: Xenova/bge-small-en-v1.5"); + println!("Provider: fastembed (https://github.com/Anush008/fastembed-rs)"); println!("\nModels are downloaded on first use if not cached."); #[cfg(feature = "pdf")] @@ -372,7 +450,7 @@ fn main() -> Result<(), Box> { } Commands::Search { query, corpus, top_k, corpus_format, model } => { - let model_type = model.to_model_type(); + let embedding_model = resolve_model(&model)?; let corpus_content = fs::read_to_string(&corpus)?; let (corpus_texts, corpus_embeddings): (Vec, Vec>) = @@ -395,7 +473,7 @@ fn main() -> Result<(), Box> { .collect() }; eprintln!("Embedding {} corpus texts...", texts.len()); - let embeddings = SERVICE.embed(texts.clone(), model_type)?; + let embeddings = SERVICE.embed(texts.clone(), embedding_model.clone())?; (texts, embeddings) }; @@ -404,7 +482,7 @@ fn main() -> Result<(), Box> { std::process::exit(1); } - let query_emb = SERVICE.embed_one(&query, model_type)?; + let query_emb = SERVICE.embed_one(&query, embedding_model)?; let similar = top_k_similar(&query_emb, &corpus_embeddings, top_k); let results: Vec = similar.iter().enumerate() @@ -421,7 +499,7 @@ fn main() -> Result<(), Box> { } Commands::DistanceMatrix { file, output, input_format, pretty, model } => { - let model_type = model.to_model_type(); + let embedding_model = resolve_model(&model)?; let texts = parse_texts_from_file(&file, &input_format)?; if texts.is_empty() { @@ -430,7 +508,7 @@ fn main() -> Result<(), Box> { } eprintln!("Embedding {} texts...", texts.len()); - let embeddings = SERVICE.embed(texts.clone(), model_type)?; + let embeddings = SERVICE.embed(texts.clone(), embedding_model)?; let matrix = pairwise_similarity_matrix(&embeddings); let response = DistanceMatrixResponse { @@ -551,7 +629,8 @@ fn main() -> Result<(), Box> { } Commands::Benchmark { 
iterations, text_length, model } => { - let model_type = model.to_model_type(); + let embedding_model = resolve_model(&model)?; + let (model_id, dimension) = get_model_display_info(&embedding_model); let sample_words = ["the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog", "hello", "world", "machine", "learning", "embedding", "vector", @@ -563,28 +642,28 @@ fn main() -> Result<(), Box> { println!("Benchmark Configuration"); println!("======================="); - println!("Model: {}", model_type.display_name()); + println!("Model: {}", model_id); println!("Iterations: {}", iterations); println!("Text length: {} words\n", text_length); eprintln!("Warming up model..."); - let _ = SERVICE.embed(vec![sample_text.clone()], model_type)?; + let _ = SERVICE.embed(vec![sample_text.clone()], embedding_model.clone())?; eprintln!("Running benchmark..."); let start = Instant::now(); for _ in 0..iterations { - let _ = SERVICE.embed(vec![sample_text.clone()], model_type)?; + let _ = SERVICE.embed(vec![sample_text.clone()], embedding_model.clone())?; } let elapsed = start.elapsed(); let result = BenchmarkResult { - model: model_type.display_name().to_string(), + model: model_id, iterations, text_length, total_ms: elapsed.as_millis(), avg_ms: elapsed.as_millis() as f64 / iterations as f64, throughput: 1000.0 / (elapsed.as_millis() as f64 / iterations as f64), - dimension: model_type.dimension(), + dimension, }; println!("Results"); @@ -596,7 +675,7 @@ fn main() -> Result<(), Box> { } Commands::Cluster { file, k, max_iter, output, input_format, pretty, model } => { - let model_type = model.to_model_type(); + let embedding_model = resolve_model(&model)?; let texts = parse_texts_from_file(&file, &input_format)?; if texts.is_empty() { @@ -610,7 +689,7 @@ fn main() -> Result<(), Box> { } eprintln!("Embedding {} texts...", texts.len()); - let embeddings = SERVICE.embed(texts.clone(), model_type)?; + let embeddings = SERVICE.embed(texts.clone(), embedding_model)?; 
eprintln!("Clustering into {} clusters...", k); let cluster_result = kmeans_cluster(&embeddings, k, max_iter); @@ -635,7 +714,8 @@ fn main() -> Result<(), Box> { pdf_doc.text.len(), pdf_doc.page_count); if embed { - let model_type = model.to_model_type(); + let embedding_model = resolve_model(&model)?; + let (model_id, _) = get_model_display_info(&embedding_model); // Determine texts to embed let texts: Vec = if chunk { @@ -651,7 +731,7 @@ fn main() -> Result<(), Box> { } eprintln!("Embedding {} text segment(s)...", texts.len()); - let embeddings = SERVICE.embed(texts.clone(), model_type)?; + let embeddings = SERVICE.embed(texts.clone(), embedding_model)?; let result = serde_json::json!({ "source": file.to_string_lossy(), @@ -660,7 +740,7 @@ fn main() -> Result<(), Box> { "embeddings": embeddings, "dimension": embeddings.first().map(|e| e.len()).unwrap_or(0), "count": embeddings.len(), - "model": model_type.display_name(), + "model": model_id, "chunked": chunk }); diff --git a/src/bin/list-models.rs b/src/bin/list-models.rs new file mode 100644 index 0000000..a9456ca --- /dev/null +++ b/src/bin/list-models.rs @@ -0,0 +1,259 @@ +// CLI tool to list all supported fastembed models in various formats +// Useful for CI/CD matrix generation (GitHub Actions, etc.) 
+ +use clap::{Parser, ValueEnum}; +use serverless_vectorizer::{ModelInfo, ModelRegistry}; + +#[derive(Debug, Clone, ValueEnum)] +enum OutputFormat { + /// Both Markdown and GitHub Actions matrix format (default) + All, + /// Markdown table format + Markdown, + /// GitHub Actions matrix format (YAML) + GithubMatrix, + /// Simple YAML list + Yaml, + /// JSON array + Json, + /// One model per line + Plain, +} + +#[derive(Debug, Clone, ValueEnum)] +enum Category { + /// Text embedding models + Text, + /// Image embedding models + Image, + /// Sparse text embedding models + Sparse, + /// Reranking models + Rerank, + /// All model categories + All, +} + +#[derive(Parser)] +#[command(name = "list-models")] +#[command(version = "0.1.0")] +#[command(about = "List all supported fastembed models in various formats")] +struct Cli { + /// Output format + #[arg(short, long, default_value = "all")] + format: OutputFormat, + + /// Model category to list + #[arg(short, long, default_value = "all")] + category: Category, + + /// Platform for matrix output + #[arg(long, default_value = "linux/amd64")] + platform: String, + + /// Only output the matrix content (no 'strategy:' wrapper) + #[arg(long)] + matrix_only: bool, +} + +fn model_id_to_variant(model_id: &str) -> String { + // Convert "BAAI/bge-small-en-v1.5" -> "models--BAAI--bge-small-en-v1.5" + format!("models--{}", model_id.replace("/", "--")) +} + +fn model_id_to_type(model_id: &str) -> String { + // Convert "BAAI/bge-small-en-v1.5" -> "bge-small-en-v1.5" + model_id.split('/').last().unwrap_or(model_id).to_string() +} + +fn model_id_to_display_name(model_id: &str) -> String { + // Convert "BAAI/bge-small-en-v1.5" -> "BGE-Small-EN-v1.5" + let short_id = model_id.split('/').last().unwrap_or(model_id); + + // Split by hyphens and capitalize appropriately + short_id + .split('-') + .map(|part| { + // Handle version numbers (v1.5, v2, etc.) 
+ if part.starts_with('v') + && part.len() > 1 + && part + .chars() + .skip(1) + .next() + .map_or(false, |c| c.is_numeric()) + { + part.to_string() + } + // Handle common abbreviations that should be uppercase + else if part.eq_ignore_ascii_case("en") + || part.eq_ignore_ascii_case("e5") + || part.eq_ignore_ascii_case("bge") + || part.eq_ignore_ascii_case("gte") + || part.eq_ignore_ascii_case("jina") + || part.eq_ignore_ascii_case("xl") + || part.eq_ignore_ascii_case("zh") + || part.eq_ignore_ascii_case("ml") + || part.eq_ignore_ascii_case("l6") + || part.eq_ignore_ascii_case("l12") + || part.eq_ignore_ascii_case("mpnet") + || part.eq_ignore_ascii_case("minilm") + { + part.to_uppercase() + } + // Capitalize first letter of other parts + else { + let mut chars = part.chars(); + match chars.next() { + None => String::new(), + Some(first) => first.to_uppercase().chain(chars).collect(), + } + } + }) + .collect::>() + .join("-") +} + + + +fn print_markdown_table(models: &[ModelInfo]) { + use std::collections::HashMap; + + // Group models by category + let mut grouped: HashMap> = HashMap::new(); + for model in models { + let category = format!("{}", model.model_type); + grouped.entry(category).or_default().push(model); + } + + // Define category order + let category_order = [ + ("Text Embedding", "Text Embedding Models"), + ("Image Embedding", "Image Embedding Models"), + ("Sparse Text Embedding", "Sparse Text Embedding Models"), + ("Text Rerank", "Reranking Models"), + ]; + + println!("## Supported Models\n"); + + for (category_key, category_title) in &category_order { + if let Some(category_models) = grouped.get(*category_key) { + println!("### {}\n", category_title); + println!("| Model | Model ID | Dimension | Description |"); + println!("|-------|----------|-----------|-------------|"); + + for model in category_models { + let display_name = model_id_to_display_name(&model.model_id); + let short_id = model.model_id.as_str(); + let dimension = model + .dimension + 
.map(|d| d.to_string()) + .unwrap_or_else(|| "-".to_string()); + + println!( + "| {} | `{}` | {} | {} |", + display_name, short_id, dimension, model.description + ); + } + println!(); + } + } +} + +fn print_github_matrix(models: &[ModelInfo], platform: &str, matrix_only: bool) { + if !matrix_only { + println!("strategy:"); + println!(" matrix:"); + println!(" include:"); + } else { + println!("matrix:"); + println!(" include:"); + } + for model in models { + let variant = model_id_to_variant(&model.model_id); + let model_type = model_id_to_type(&model.model_id); + let indent = if matrix_only { " " } else { " " }; + println!("{}- variant: {}", indent, variant); + println!("{} model_type: {}", indent, model_type); + println!("{} model_id: {}", indent, model.model_id); + if let Some(dim) = model.dimension { + println!("{} dimension: {}", indent, dim); + } + println!("{} platform: {}", indent, platform); + } +} + +fn main() { + let cli = Cli::parse(); + + // Collect models based on category + let models = match cli.category { + Category::Text => ModelRegistry::text_embedding_models(), + Category::Image => ModelRegistry::image_embedding_models(), + Category::Sparse => ModelRegistry::sparse_text_embedding_models(), + Category::Rerank => ModelRegistry::rerank_models(), + Category::All => ModelRegistry::all_models(), + }; + + match cli.format { + OutputFormat::All => { + print_markdown_table(&models); + println!("\n---\n"); + print_github_matrix(&models, &cli.platform, cli.matrix_only); + } + + OutputFormat::Markdown => { + print_markdown_table(&models); + } + + OutputFormat::GithubMatrix => { + print_github_matrix(&models, &cli.platform, cli.matrix_only); + } + + OutputFormat::Yaml => { + println!("models:"); + for model in &models { + println!(" - id: \"{}\"", model.model_id); + println!(" type: \"{}\"", model_id_to_type(&model.model_id)); + if let Some(dim) = model.dimension { + println!(" dimension: {}", dim); + } + println!(" category: {:?}", model.model_type); + 
println!( + " description: \"{}\"", + model.description.replace("\"", "\\\"") + ); + } + } + + OutputFormat::Json => { + let json_models: Vec = models + .iter() + .map(|m| { + serde_json::json!({ + "id": m.model_id, + "type": model_id_to_type(&m.model_id), + "variant": model_id_to_variant(&m.model_id), + "dimension": m.dimension, + "category": format!("{:?}", m.model_type), + "description": m.description + }) + }) + .collect(); + println!("{}", serde_json::to_string_pretty(&json_models).unwrap()); + } + + OutputFormat::Plain => { + for model in &models { + if let Some(dim) = model.dimension { + println!("{}\t{}D\t{:?}", model.model_id, dim, model.model_type); + } else { + println!("{}\t-\t{:?}", model.model_id, model.model_type); + } + } + } + } + + // Print summary to stderr + eprintln!("\n# Total: {} models", models.len()); +} diff --git a/src/bin/preload.rs b/src/bin/preload.rs index 7a4b761..5885c4e 100644 --- a/src/bin/preload.rs +++ b/src/bin/preload.rs @@ -1,30 +1,34 @@ -use fastembed::{EmbeddingModel, InitOptions, TextEmbedding}; +use fastembed::{InitOptions, TextEmbedding}; +use serverless_vectorizer::ModelRegistry; use std::env; fn main() { let args: Vec = env::args().collect(); - let model_type = args + let model_id = args .get(1) .map(String::as_str) - .unwrap_or("bge-small-en-v1.5"); + .unwrap_or("Xenova/bge-small-en-v1.5"); - let embedding_model = match model_type { - "bge-small-en-v1.5" => EmbeddingModel::BGESmallENV15, - "bge-base-en-v1.5" => EmbeddingModel::BGEBaseENV15, - "bge-large-en-v1.5" => EmbeddingModel::BGELargeENV15, - "multilingual-e5-large" => EmbeddingModel::MultilingualE5Large, - "all-mpnet-base-v2" => EmbeddingModel::AllMpnetBaseV2, - _ => { + // Use ModelRegistry to dynamically find the model + let embedding_model = match ModelRegistry::find_text_model(model_id) { + Some(model) => model, + None => { eprintln!( - "Warning: Unknown model '{}', falling back to bge-small-en-v1.5", - model_type + "Warning: Unknown model '{}', falling back 
to Xenova/bge-small-en-v1.5", + model_id ); - EmbeddingModel::BGESmallENV15 + eprintln!("\nAvailable text embedding models:"); + for model_info in ModelRegistry::text_embedding_models() { + eprintln!(" - {} ({}D)", model_info.model_id, model_info.dimension.unwrap_or(0)); + } + ModelRegistry::default_text_model() } }; + println!("======================================"); - println!("Using embedding model: {}", model_type); + println!("Preloading embedding model: {}", model_id); + let mut model = TextEmbedding::try_new(InitOptions::new(embedding_model).with_show_download_progress(true)) .expect("Failed to initialize model"); @@ -34,5 +38,5 @@ fn main() { .expect("Failed to generate embedding"); println!("Embedding dimension = {}", embeddings[0].len()); - println!("Done 🎉"); + println!("Done!"); } diff --git a/src/core/mod.rs b/src/core/mod.rs index bf7cced..19199fb 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -11,7 +11,7 @@ pub mod types; pub mod pdf; // Re-export model types -pub use model::{ModelType, ModelInfo, ModelRegistry, MODEL_REGISTRY}; +pub use model::{ModelType, ModelInfo, ModelRegistry, ModelCategory, TextModel, MODEL_REGISTRY}; // Re-export embedding service pub use embeddings::{EmbeddingService, EmbeddingError, global_service, embed, embed_one}; diff --git a/src/core/model.rs b/src/core/model.rs index fab567d..c889d02 100644 --- a/src/core/model.rs +++ b/src/core/model.rs @@ -1,10 +1,193 @@ -// Model selection and configuration +// Model selection and configuration using fastembed's native model discovery -use fastembed::EmbeddingModel; +use fastembed::{ + EmbeddingModel, ImageEmbeddingModel, SparseModel, RerankerModel, + TextEmbedding, ImageEmbedding, SparseTextEmbedding, TextRerank, +}; use serde::{Deserialize, Serialize}; use std::fmt; -/// Supported embedding model types +/// Re-export fastembed's EmbeddingModel for text embeddings +pub use fastembed::EmbeddingModel as TextModel; + +/// Unified model information structure +#[derive(Debug, 
Clone, Serialize, Deserialize)] +pub struct ModelInfo { + pub model_id: String, + pub description: String, + pub dimension: Option, + pub model_type: ModelCategory, +} + +/// Category of embedding model +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ModelCategory { + TextEmbedding, + ImageEmbedding, + SparseTextEmbedding, + TextRerank, +} + +impl fmt::Display for ModelCategory { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ModelCategory::TextEmbedding => write!(f, "Text Embedding"), + ModelCategory::ImageEmbedding => write!(f, "Image Embedding"), + ModelCategory::SparseTextEmbedding => write!(f, "Sparse Text Embedding"), + ModelCategory::TextRerank => write!(f, "Text Rerank"), + } + } +} + +/// Registry for discovering all supported models from fastembed +pub struct ModelRegistry; + +impl ModelRegistry { + /// Get all supported text embedding models + pub fn text_embedding_models() -> Vec { + TextEmbedding::list_supported_models() + .into_iter() + .map(|info| ModelInfo { + model_id: info.model_code.to_string(), + description: info.description.to_string(), + dimension: Some(info.dim), + model_type: ModelCategory::TextEmbedding, + }) + .collect() + } + + /// Get all supported image embedding models + pub fn image_embedding_models() -> Vec { + ImageEmbedding::list_supported_models() + .into_iter() + .map(|info| ModelInfo { + model_id: info.model_code.to_string(), + description: info.description.to_string(), + dimension: Some(info.dim), + model_type: ModelCategory::ImageEmbedding, + }) + .collect() + } + + /// Get all supported sparse text embedding models + pub fn sparse_text_embedding_models() -> Vec { + SparseTextEmbedding::list_supported_models() + .into_iter() + .map(|info| ModelInfo { + model_id: info.model_code.to_string(), + description: info.description.to_string(), + dimension: None, // Sparse models don't have fixed dimension + model_type: 
ModelCategory::SparseTextEmbedding, + }) + .collect() + } + + /// Get all supported reranking models + pub fn rerank_models() -> Vec { + TextRerank::list_supported_models() + .into_iter() + .map(|info| ModelInfo { + model_id: info.model_code.to_string(), + description: info.description.to_string(), + dimension: None, // Rerank models don't produce embeddings + model_type: ModelCategory::TextRerank, + }) + .collect() + } + + /// Get all supported models across all categories + pub fn all_models() -> Vec { + let mut models = Vec::new(); + models.extend(Self::text_embedding_models()); + models.extend(Self::image_embedding_models()); + models.extend(Self::sparse_text_embedding_models()); + models.extend(Self::rerank_models()); + models + } + + /// Find a text embedding model by its model code/id + pub fn find_text_model(model_id: &str) -> Option { + let model_id_lower = model_id.to_lowercase(); + TextEmbedding::list_supported_models() + .into_iter() + .find(|info| { + info.model_code.to_lowercase() == model_id_lower + || info.model_code.to_lowercase().contains(&model_id_lower) + || model_id_lower.contains(&info.model_code.to_lowercase().replace("/", "-").replace("_", "-")) + }) + .map(|info| info.model) + } + + /// Find an image embedding model by its model code/id + pub fn find_image_model(model_id: &str) -> Option { + let model_id_lower = model_id.to_lowercase(); + ImageEmbedding::list_supported_models() + .into_iter() + .find(|info| { + info.model_code.to_lowercase() == model_id_lower + || info.model_code.to_lowercase().contains(&model_id_lower) + }) + .map(|info| info.model) + } + + /// Find a sparse text embedding model by its model code/id + pub fn find_sparse_model(model_id: &str) -> Option { + let model_id_lower = model_id.to_lowercase(); + SparseTextEmbedding::list_supported_models() + .into_iter() + .find(|info| { + info.model_code.to_lowercase() == model_id_lower + || info.model_code.to_lowercase().contains(&model_id_lower) + }) + .map(|info| info.model) + } 
+ + /// Find a reranking model by its model code/id + pub fn find_rerank_model(model_id: &str) -> Option { + let model_id_lower = model_id.to_lowercase(); + TextRerank::list_supported_models() + .into_iter() + .find(|info| { + info.model_code.to_lowercase() == model_id_lower + || info.model_code.to_lowercase().contains(&model_id_lower) + }) + .map(|info| info.model) + } + + /// Get model info for a text embedding model + pub fn get_text_model_info(model: &EmbeddingModel) -> Option { + TextEmbedding::list_supported_models() + .into_iter() + .find(|info| std::mem::discriminant(&info.model) == std::mem::discriminant(model)) + .map(|info| ModelInfo { + model_id: info.model_code.to_string(), + description: info.description.to_string(), + dimension: Some(info.dim), + model_type: ModelCategory::TextEmbedding, + }) + } + + /// Get dimension for a text embedding model + pub fn get_text_model_dimension(model: &EmbeddingModel) -> Option { + TextEmbedding::list_supported_models() + .into_iter() + .find(|info| std::mem::discriminant(&info.model) == std::mem::discriminant(model)) + .map(|info| info.dim) + } + + /// Get default text embedding model + pub fn default_text_model() -> EmbeddingModel { + EmbeddingModel::BGESmallENV15 + } +} + +// ============================================================================ +// Backward compatibility layer for ModelType +// ============================================================================ + +/// Legacy ModelType enum for backward compatibility +/// Maps to fastembed's EmbeddingModel #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(rename_all = "kebab-case")] pub enum ModelType { @@ -23,19 +206,13 @@ impl ModelType { ModelType::BgeBaseEnV15 => EmbeddingModel::BGEBaseENV15, ModelType::BgeLargeEnV15 => EmbeddingModel::BGELargeENV15, ModelType::MultilingualE5Large => EmbeddingModel::MultilingualE5Large, - ModelType::AllMpnetBaseV2 => EmbeddingModel::AllMpnetBaseV2, + ModelType::AllMpnetBaseV2 => 
EmbeddingModel::AllMpnetBaseV2, // 768-dim; mapping to AllMiniLML6V2 would silently change dimension() to 384 for this legacy type } } - /// Get the embedding dimension for this model + /// Get the embedding dimension for this model (from fastembed) pub fn dimension(&self) -> usize { - match self { - ModelType::BgeSmallEnV15 => 384, - ModelType::BgeBaseEnV15 => 768, - ModelType::BgeLargeEnV15 => 1024, - ModelType::MultilingualE5Large => 1024, - ModelType::AllMpnetBaseV2 => 768, - } + ModelRegistry::get_text_model_dimension(&self.to_fastembed()).unwrap_or(384) } /// Get human-readable model name @@ -88,7 +265,7 @@ impl ModelType { } } - /// Get all available model types + /// Get all available legacy model types pub fn all() -> &'static [ModelType] { &[ ModelType::BgeSmallEnV15, @@ -117,48 +294,6 @@ impl fmt::Display for ModelType { } } -/// Detailed model information -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ModelInfo { - pub model_type: ModelType, - pub name: String, - pub dimension: usize, - pub max_tokens: usize, - pub language: String, -} - -impl From for ModelInfo { - fn from(model_type: ModelType) -> Self { - ModelInfo { - model_type, - name: model_type.display_name().to_string(), - dimension: model_type.dimension(), - max_tokens: model_type.max_tokens(), - language: model_type.language().to_string(), - } - } -} - -/// Registry of all available models -pub struct ModelRegistry; - -impl ModelRegistry { - /// Get all available models with their info - pub fn all_models() -> Vec { - ModelType::all().iter().map(|&m| m.into()).collect() - } - - /// Get info for a specific model - pub fn get_info(model_type: ModelType) -> ModelInfo { - model_type.into() - } - - /// Find model by string identifier - pub fn find(name: &str) -> Option { - ModelType::from_str(name) - } -} - /// Convenience constant for model registry access pub const MODEL_REGISTRY: ModelRegistry = ModelRegistry; @@ -175,9 +310,32 @@ mod tests { } #[test] - fn test_model_dimensions() - 
assert_eq!(ModelType::BgeSmallEnV15.dimension(), 384); - assert_eq!(ModelType::BgeBaseEnV15.dimension(), 768); - assert_eq!(ModelType::BgeLargeEnV15.dimension(), 1024); + fn test_list_text_embedding_models() { + let models = ModelRegistry::text_embedding_models(); + assert!(!models.is_empty(), "Should have text embedding models"); + + // Check that each model has required fields + for model in &models { + assert!(!model.model_id.is_empty()); + assert!(model.dimension.is_some()); + assert_eq!(model.model_type, ModelCategory::TextEmbedding); + } + } + + #[test] + fn test_find_text_model() { + // Test finding BGE small model + let model = ModelRegistry::find_text_model("bge-small-en-v1.5"); + assert!(model.is_some()); + } + + #[test] + fn test_all_models() { + let models = ModelRegistry::all_models(); + assert!(!models.is_empty(), "Should have some models"); + + // Check we have multiple categories + let has_text = models.iter().any(|m| m.model_type == ModelCategory::TextEmbedding); + assert!(has_text, "Should have text embedding models"); } } diff --git a/src/lambda.rs b/src/lambda.rs index b893d9d..e7f9742 100644 --- a/src/lambda.rs +++ b/src/lambda.rs @@ -1,12 +1,12 @@ - // Lambda-specific handler and AWS integration +use crate::core::model::ModelRegistry; use crate::core::{EmbeddingService, ModelType}; use aws_config; use aws_sdk_s3 as s3; use lambda_runtime::{Error, LambdaEvent}; use serde::{Deserialize, Serialize}; -use serde_json::{json, Value}; +use serde_json::{Value, json}; use std::env; use std::sync::LazyLock; @@ -44,22 +44,31 @@ pub struct ApiGatewayResponse { } // Global embedding service for Lambda (initialized once per cold start) -static EMBEDDING_SERVICE: LazyLock = LazyLock::new(|| { - EmbeddingService::new() -}); +static EMBEDDING_SERVICE: LazyLock = LazyLock::new(|| EmbeddingService::new()); /// Get the model type from environment or default fn get_model_type() -> ModelType { - let model_str = env::var("EMBEDDING_MODEL") - .unwrap_or_else(|_| 
"bge-small-en-v1.5".to_string()); + let model_str = env::var("MODEL_ID").unwrap_or_else(|_| "Xenova/bge-small-zh-v1.5".to_string()); + + // Try to parse as legacy ModelType first + if let Some(model_type) = ModelType::from_str(&model_str) { + return model_type; + } - ModelType::from_str(&model_str).unwrap_or_else(|| { + // Check if it's a valid model in fastembed's registry + if ModelRegistry::find_text_model(&model_str).is_some() { + eprintln!( + "Note: Model '{}' is supported by fastembed but not in legacy ModelType. Using BGE-Small as fallback.", + model_str + ); + } else { eprintln!( "Warning: Unknown model '{}', falling back to BGE-Small", model_str ); - ModelType::default() - }) + } + + ModelType::default() } async fn read_from_s3(s3_client: &s3::Client, s3_path: &str) -> Result { diff --git a/src/lib.rs b/src/lib.rs index 0c23196..04eae15 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,7 +10,7 @@ mod lambda; // Re-export core functionality for external use pub use core::{ // Model types - ModelType, ModelInfo, ModelRegistry, MODEL_REGISTRY, + ModelType, ModelInfo, ModelRegistry, ModelCategory, TextModel, MODEL_REGISTRY, // Embedding service EmbeddingService, EmbeddingError, global_service, embed, embed_one, // Similarity functions