Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ htmlcov
*.env
build/
dist/
*.log
99 changes: 96 additions & 3 deletions backend/app/api/deps.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,32 @@
import uuid as _uuid
from collections.abc import Generator
from typing import Annotated

import jwt as pyjwt
from fastapi import Depends, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from psycopg import Cursor

from app.core.config import Settings
from app.core.exceptions import ServiceUnavailableException
from app.core.exceptions import (
AuthenticationException,
ServiceUnavailableException,
)
from app.core.security import decode_token


_bearer_scheme = HTTPBearer(auto_error=False)


def get_db(request: Request) -> Generator[Cursor, None, None]:
"""Return a database cursor for the request."""
"""Return a database cursor for the request.

Yields:
Cursor: A database cursor.

Raises:
ServiceUnavailableException: If the database pool is not initialized.
"""
pool = getattr(request.app.state, "db_pool", None)
if pool is None:
raise ServiceUnavailableException("Database pool not initialized")
Expand All @@ -19,7 +36,14 @@ def get_db(request: Request) -> Generator[Cursor, None, None]:


def get_settings(request: Request) -> Settings:
"""Return settings from app state."""
"""Return settings from app state.

Returns:
Settings: Application settings.

Raises:
ServiceUnavailableException: If settings are not initialized.
"""
settings = getattr(request.app.state, "settings", None)
if settings is None:
raise ServiceUnavailableException("Settings not initialized")
Expand All @@ -28,3 +52,72 @@ def get_settings(request: Request) -> Settings:

CursorDep = Annotated[Cursor, Depends(get_db)]
SettingsDep = Annotated[Settings, Depends(get_settings)]


def get_current_vendor_id(
credentials: Annotated[
HTTPAuthorizationCredentials | None, Depends(_bearer_scheme)
],
settings: SettingsDep,
) -> str:
"""Extract and validate vendor_id from the Authorization: Bearer token.

Returns:
str: The validated vendor_id from the JWT.

Raises:
AuthenticationException: On missing / invalid / expired tokens
or if the token is not an access token.
"""
if credentials is None:
raise AuthenticationException("Missing authentication token")

try:
payload = decode_token(credentials.credentials, settings)
except pyjwt.PyJWTError:
raise AuthenticationException("Invalid or expired token") from None

if payload.get("token_type") != "access":
raise AuthenticationException("Invalid token type")

vendor_id: str | None = payload.get("vendor_id")
if vendor_id is None:
raise AuthenticationException("Invalid token payload")

# Validate that vendor_id is a well-formed UUID before it reaches
# downstream consumers such as app.set_app_context().
try:
_uuid.UUID(vendor_id)
except (ValueError, AttributeError):
raise AuthenticationException("Invalid token payload") from None

return vendor_id


def get_rls_cursor(
request: Request, vendor_id: Annotated[str, Depends(get_current_vendor_id)]
) -> Generator[Cursor, None, None]:
"""Return a database cursor with app.vendor_id set for RLS.

After authentication, this dependency:
1. Obtains a connection from the pool
2. Calls app.set_app_context(vendor_id) to set the RLS context
3. Yields the cursor for use in route handlers

Yields:
Cursor: A database cursor with RLS context set.

Raises:
ServiceUnavailableException: If the database pool is not initialized.
"""
pool = getattr(request.app.state, "db_pool", None)
if pool is None:
raise ServiceUnavailableException("Database pool not initialized")
with pool.connection() as conn:
with conn.cursor() as cursor:
cursor.execute("SELECT app.set_app_context(%s)", (vendor_id,))
yield cursor


CurrentVendorId = Annotated[str, Depends(get_current_vendor_id)]
RLSCursorDep = Annotated[Cursor, Depends(get_rls_cursor)]
5 changes: 3 additions & 2 deletions backend/app/api/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from fastapi import APIRouter

from app.api.routes import health, login
from app.api.routes import auth, health


api_router = APIRouter()
api_router.include_router(health.router, tags=["health"])
api_router.include_router(login.router, prefix="/login", tags=["login"])
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
36 changes: 31 additions & 5 deletions backend/app/api/middlewares.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,39 @@
import re
import uuid

from fastapi import Request
from fastapi import Request, Response
from starlette.middleware.base import RequestResponseEndpoint


async def add_request_id(request: Request, call_next):
"""Add a unique request_id to each request context"""
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}"
r"-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
)


def _is_valid_request_id(value: str) -> bool:
"""Check that value is a well-formed UUID string.

Returns:
bool: True when the value matches the UUID4 hex pattern.
"""
return bool(_UUID_RE.match(value))


async def add_request_id(
request: Request, call_next: RequestResponseEndpoint
) -> Response:
"""Add a unique request_id to each request context.

Returns:
Response: The response with X-Request-ID header.
"""
raw_header = request.headers.get("X-Request-ID")
stripped_header = raw_header.strip() if raw_header else ""
request_id = stripped_header if stripped_header else str(uuid.uuid4())
stripped = raw_header.strip() if raw_header else ""
if stripped and _is_valid_request_id(stripped):
request_id = stripped
else:
request_id = str(uuid.uuid4())
request.state.request_id = request_id

response = await call_next(request)
Expand Down
77 changes: 77 additions & 0 deletions backend/app/api/routes/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Auth routes — signup, login, token refresh."""

from fastapi import APIRouter, status

from app.api.deps import CursorDep, SettingsDep
from app.schemas.auth import (
LoginRequest,
RefreshRequest,
SignupRequest,
SignupResponse,
TokenPair,
)
from app.schemas.response import SuccessResponse
from app.services import auth as auth_service


router = APIRouter()


@router.post(
"/signup",
status_code=status.HTTP_201_CREATED,
response_model=SuccessResponse[SignupResponse],
)
def signup(
body: SignupRequest, cursor: CursorDep, settings: SettingsDep
) -> SuccessResponse[SignupResponse]:
"""Create a new vendor account.

Returns:
SuccessResponse[SignupResponse]: The created vendor.
"""
result = auth_service.signup(
cursor=cursor,
email=body.email,
password=body.password,
client_id=body.client_id,
settings=settings,
)
return SuccessResponse(data=result)


@router.post("/login", response_model=SuccessResponse[TokenPair])
def login(
body: LoginRequest, cursor: CursorDep, settings: SettingsDep
) -> SuccessResponse[TokenPair]:
"""Authenticate a vendor and return an access/refresh token pair.

Returns:
SuccessResponse[TokenPair]: The token pair.
"""
result = auth_service.login(
cursor=cursor,
email=body.email,
password=body.password,
client_id=body.client_id,
settings=settings,
)
return SuccessResponse(data=result)


@router.post("/refresh", response_model=SuccessResponse[TokenPair])
def refresh(
body: RefreshRequest, cursor: CursorDep, settings: SettingsDep
) -> SuccessResponse[TokenPair]:
"""Issue a new token pair using a valid refresh token.

Returns:
SuccessResponse[TokenPair]: The new token pair.
"""
result = auth_service.refresh(
refresh_token_str=body.refresh_token,
client_id=body.client_id,
cursor=cursor,
settings=settings,
)
return SuccessResponse(data=result)
6 changes: 5 additions & 1 deletion backend/app/api/routes/health.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@

from fastapi import APIRouter


router = APIRouter()


@router.get("/health")
def health_check() -> dict[str, Any]:
return {
"data": {"status": "ok", "timestamp": datetime.now(timezone.utc).isoformat()}
"data": {
"status": "ok",
"timestamp": datetime.now(timezone.utc).isoformat(),
}
}
1 change: 1 addition & 0 deletions backend/app/api/routes/login.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from fastapi import APIRouter


router = APIRouter()
17 changes: 7 additions & 10 deletions backend/app/core/config.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
from pydantic import (
PostgresDsn,
computed_field,
)
from pydantic import PostgresDsn, computed_field
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_file=".env",
env_ignore_empty=True,
extra="ignore",
env_file=".env", env_ignore_empty=True, extra="ignore"
)

SECRET_KEY: str
# 60 minutes * 24 hours = 1 day (configurable via ACCESS_TOKEN_EXPIRE_MINUTES env var)
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24
# JWT token lifetimes
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 # 1 hour
REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # 7 days

PROJECT_NAME: str
POSTGRES_SERVER: str
Expand All @@ -25,7 +21,8 @@ class Settings(BaseSettings):

@computed_field
@property
def DATABASE_DSN(self) -> PostgresDsn:
def DATABASE_DSN(self) -> PostgresDsn: # noqa: N802
# See: https://docs.astral.sh/ruff/rules/invalid-function-name/
return PostgresDsn.build(
scheme="postgresql",
username=self.POSTGRES_USER,
Expand Down
Loading