diff --git a/README.md b/README.md index ff85f4b..2caf2d5 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,7 @@ You can also use CURL to get embeddings: ```bash curl http://localhost:8080/v1/embeddings \ -H "Content-Type: application/json" \ + -H "Authorization: Bearer your-secret-key" \ -d '{ "input": "Your text string goes here", "model": "all-MiniLM-L6-v2" diff --git a/entrypoint.sh b/entrypoint.sh index 750a3c6..4cc13eb 100755 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -1,3 +1,3 @@ #!/usr/bin/env bash - +export API_KEY=${API_KEY:-""} uvicorn main:app --host 0.0.0.0 --port $PORT \ No newline at end of file diff --git a/main.py b/main.py index 12ec834..5212ff0 100644 --- a/main.py +++ b/main.py @@ -1,13 +1,33 @@ -from typing import Union, List, Dict +from typing import Union, List, Dict, Optional from contextlib import asynccontextmanager import os - -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, Security, Depends +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from pydantic import BaseModel, Field from sentence_transformers import SentenceTransformer models: Dict[str, SentenceTransformer] = {} model_name = os.getenv("MODEL", "all-MiniLM-L6-v2") +api_key = os.getenv("API_KEY") + +security = HTTPBearer(auto_error=False) + +def verify_api_key(credentials: Optional[HTTPAuthorizationCredentials] = Security(security)): + if api_key: + if not credentials: + raise HTTPException( + status_code=401, + detail="API key required", + headers={"WWW-Authenticate": "Bearer"}, + ) + if credentials.credentials != api_key: + raise HTTPException( + status_code=401, + detail="Invalid API key", + headers={"WWW-Authenticate": "Bearer"}, + ) + return credentials.credentials + return None class EmbeddingRequest(BaseModel): @@ -48,7 +68,10 @@ async def lifespan(app: FastAPI): @app.post("/v1/embeddings") -async def embedding(item: EmbeddingRequest) -> EmbeddingResponse: +async def embedding( + item: EmbeddingRequest, + api_key: Optional[str] = Depends(verify_api_key) +) -> EmbeddingResponse: model: SentenceTransformer = models[model_name] if isinstance(item.input, str): vectors = model.encode(item.input)