From 6d0305b857fe1bed57e13f5aa3c7eb21648d0fe1 Mon Sep 17 00:00:00 2001 From: marianbasti <31198560+marianbasti@users.noreply.github.com> Date: Fri, 21 Feb 2025 09:30:46 -0300 Subject: [PATCH 1/3] Add API key validation and require API_KEY environment variable --- entrypoint.sh | 2 +- main.py | 24 +++++++++++++++++++++--- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/entrypoint.sh b/entrypoint.sh index 750a3c6..4e42288 100755 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -1,3 +1,3 @@ #!/usr/bin/env bash - +export API_KEY=${API_KEY:?"API_KEY environment variable is required"} uvicorn main:app --host 0.0.0.0 --port $PORT \ No newline at end of file diff --git a/main.py b/main.py index 12ec834..2f7f1aa 100644 --- a/main.py +++ b/main.py @@ -1,13 +1,28 @@ from typing import Union, List, Dict from contextlib import asynccontextmanager import os - -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, Security, Depends +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from pydantic import BaseModel, Field from sentence_transformers import SentenceTransformer models: Dict[str, SentenceTransformer] = {} model_name = os.getenv("MODEL", "all-MiniLM-L6-v2") +api_key = os.getenv("API_KEY") + +if not api_key: + raise ValueError("API_KEY environment variable must be set") + +security = HTTPBearer() + +def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)): + if credentials.credentials != api_key: + raise HTTPException( + status_code=401, + detail="Invalid API key", + headers={"WWW-Authenticate": "Bearer"}, + ) + return credentials.credentials class EmbeddingRequest(BaseModel): @@ -48,7 +63,10 @@ async def lifespan(app: FastAPI): @app.post("/v1/embeddings") -async def embedding(item: EmbeddingRequest) -> EmbeddingResponse: +async def embedding( + item: EmbeddingRequest, + api_key: str = Depends(verify_api_key) +) -> EmbeddingResponse: model: SentenceTransformer = models[model_name] if isinstance(item.input, str): vectors = model.encode(item.input) From 6f356884d00b919b4a860892668f61d9900e95f0 Mon Sep 17 00:00:00 2001 From: marianbasti <31198560+marianbasti@users.noreply.github.com> Date: Fri, 21 Feb 2025 09:33:21 -0300 Subject: [PATCH 2/3] Add authorization header example to CURL command in README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index ff85f4b..2caf2d5 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,7 @@ You can also use CURL to get embeddings: ```bash curl http://localhost:8080/v1/embeddings \ -H "Content-Type: application/json" \ + -H "Authorization: Bearer your-secret-key" \ -d '{ "input": "Your text string goes here", "model": "all-MiniLM-L6-v2" From 959bb09f058a4b23e250ee045518394569bb1656 Mon Sep 17 00:00:00 2001 From: marianbasti <31198560+marianbasti@users.noreply.github.com> Date: Tue, 25 Feb 2025 10:16:49 -0300 Subject: [PATCH 3/3] Refactor API key handling to allow optional credentials --- entrypoint.sh | 2 +- main.py | 33 +++++++++++++++++++-------------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/entrypoint.sh b/entrypoint.sh index 4e42288..4cc13eb 100755 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -1,3 +1,3 @@ #!/usr/bin/env bash -export API_KEY=${API_KEY:?"API_KEY environment variable is required"} +export API_KEY=${API_KEY:-""} uvicorn main:app --host 0.0.0.0 --port $PORT \ No newline at end of file diff --git a/main.py b/main.py index 2f7f1aa..5212ff0 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,4 @@ -from typing import Union, List, Dict +from typing import Union, List, Dict, Optional from contextlib import asynccontextmanager import os from fastapi import FastAPI, HTTPException, Security, Depends @@ -10,19 +10,24 @@ model_name = os.getenv("MODEL", "all-MiniLM-L6-v2") api_key = os.getenv("API_KEY") -if not api_key: - raise ValueError("API_KEY environment variable must be set") +security = HTTPBearer(auto_error=False) -security = HTTPBearer() - -def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)): - if credentials.credentials != api_key: - raise HTTPException( - status_code=401, - detail="Invalid API key", - headers={"WWW-Authenticate": "Bearer"}, - ) - return credentials.credentials +def verify_api_key(credentials: Optional[HTTPAuthorizationCredentials] = Security(security)): + if api_key: + if not credentials: + raise HTTPException( + status_code=401, + detail="API key required", + headers={"WWW-Authenticate": "Bearer"}, + ) + if credentials.credentials != api_key: + raise HTTPException( + status_code=401, + detail="Invalid API key", + headers={"WWW-Authenticate": "Bearer"}, + ) + return credentials.credentials + return None class EmbeddingRequest(BaseModel): @@ -65,7 +70,7 @@ async def lifespan(app: FastAPI): @app.post("/v1/embeddings") async def embedding( item: EmbeddingRequest, - api_key: str = Depends(verify_api_key) + api_key: Optional[str] = Depends(verify_api_key) ) -> EmbeddingResponse: model: SentenceTransformer = models[model_name] if isinstance(item.input, str):