Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ You can also use CURL to get embeddings:
```bash
curl http://localhost:8080/v1/embeddings \
-H "Content-Type: application/json" \
-H "Authorization: Bearer your-secret-key" \
-d '{
"input": "Your text string goes here",
"model": "all-MiniLM-L6-v2"
Expand Down
2 changes: 1 addition & 1 deletion entrypoint.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#!/usr/bin/env bash

# Export API_KEY so the application process can read it; it defaults to an
# empty string when the caller did not set one.
export API_KEY=${API_KEY:-""}
# Quote PORT so its value is passed to uvicorn as exactly one argument and is
# never word-split or glob-expanded by the shell.
uvicorn main:app --host 0.0.0.0 --port "$PORT"
31 changes: 27 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,33 @@
from typing import Union, List, Dict
from typing import Union, List, Dict, Optional
from contextlib import asynccontextmanager
import os

from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Security, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer

# Mapping of model name -> SentenceTransformer instance. The request handler
# looks up models[model_name]; the code that fills it is not visible here.
models: Dict[str, SentenceTransformer] = {}
# Name of the model to serve, taken from the MODEL environment variable.
model_name = os.getenv("MODEL", "all-MiniLM-L6-v2")
# Expected bearer token, taken from the API_KEY environment variable.
# verify_api_key skips authentication when this is None or empty.
api_key = os.getenv("API_KEY")

# auto_error=False so verify_api_key can handle absent credentials itself
# (it checks `not credentials`) instead of the scheme rejecting the request.
security = HTTPBearer(auto_error=False)

def verify_api_key(
    credentials: Optional[HTTPAuthorizationCredentials] = Security(security),
) -> Optional[str]:
    """Validate a request's bearer token against the configured API key.

    Meant to be used as a FastAPI dependency. When the module-level
    ``api_key`` is unset or empty, authentication is disabled and every
    request is accepted.

    Args:
        credentials: Bearer credentials extracted by ``security``, or
            ``None`` when no usable credentials were supplied.

    Returns:
        The presented token when authentication is enabled and the token
        matches; ``None`` when authentication is disabled.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            the token is missing or does not match the configured key.
    """
    # Function-scope import: the block needs nothing new at file level.
    import secrets

    # No key configured -> authentication is turned off.
    if not api_key:
        return None

    if not credentials:
        raise HTTPException(
            status_code=401,
            detail="API key required",
            headers={"WWW-Authenticate": "Bearer"},
        )

    presented = credentials.credentials
    # Compare in constant time so response timing does not reveal how many
    # leading characters of the key were guessed correctly. Bytes are used
    # because compare_digest rejects non-ASCII str arguments.
    if not secrets.compare_digest(
        presented.encode("utf-8"), api_key.encode("utf-8")
    ):
        raise HTTPException(
            status_code=401,
            detail="Invalid API key",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return presented


class EmbeddingRequest(BaseModel):
Expand Down Expand Up @@ -48,7 +68,10 @@ async def lifespan(app: FastAPI):


@app.post("/v1/embeddings")
async def embedding(item: EmbeddingRequest) -> EmbeddingResponse:
async def embedding(
item: EmbeddingRequest,
api_key: Optional[str] = Depends(verify_api_key)
) -> EmbeddingResponse:
model: SentenceTransformer = models[model_name]
if isinstance(item.input, str):
vectors = model.encode(item.input)
Expand Down