diff --git a/.gitignore b/.gitignore index 29e02fe..f9ace42 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ htmlcov *.env build/ dist/ +*.log diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index a07bab4..d33af99 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -1,15 +1,32 @@ +import uuid as _uuid from collections.abc import Generator from typing import Annotated +import jwt as pyjwt from fastapi import Depends, Request +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from psycopg import Cursor from app.core.config import Settings -from app.core.exceptions import ServiceUnavailableException +from app.core.exceptions import ( + AuthenticationException, + ServiceUnavailableException, +) +from app.core.security import decode_token + + +_bearer_scheme = HTTPBearer(auto_error=False) def get_db(request: Request) -> Generator[Cursor, None, None]: - """Return a database cursor for the request.""" + """Return a database cursor for the request. + + Yields: + Cursor: A database cursor. + + Raises: + ServiceUnavailableException: If the database pool is not initialized. + """ pool = getattr(request.app.state, "db_pool", None) if pool is None: raise ServiceUnavailableException("Database pool not initialized") @@ -19,7 +36,14 @@ def get_db(request: Request) -> Generator[Cursor, None, None]: def get_settings(request: Request) -> Settings: - """Return settings from app state.""" + """Return settings from app state. + + Returns: + Settings: Application settings. + + Raises: + ServiceUnavailableException: If settings are not initialized. + """ settings = getattr(request.app.state, "settings", None) if settings is None: raise ServiceUnavailableException("Settings not initialized") @@ -28,3 +52,72 @@ def get_settings(request: Request) -> Settings: CursorDep = Annotated[Cursor, Depends(get_db)] SettingsDep = Annotated[Settings, Depends(get_settings)] + + +def get_current_vendor_id( + credentials: Annotated[ + HTTPAuthorizationCredentials | None, Depends(_bearer_scheme) + ], + settings: SettingsDep, +) -> str: + """Extract and validate vendor_id from the Authorization: Bearer token. + + Returns: + str: The validated vendor_id from the JWT. + + Raises: + AuthenticationException: On missing / invalid / expired tokens + or if the token is not an access token. + """ + if credentials is None: + raise AuthenticationException("Missing authentication token") + + try: + payload = decode_token(credentials.credentials, settings) + except pyjwt.PyJWTError: + raise AuthenticationException("Invalid or expired token") from None + + if payload.get("token_type") != "access": + raise AuthenticationException("Invalid token type") + + vendor_id: str | None = payload.get("vendor_id") + if vendor_id is None: + raise AuthenticationException("Invalid token payload") + + # Validate that vendor_id is a well-formed UUID before it reaches + # downstream consumers such as app.set_app_context(). + try: + _uuid.UUID(vendor_id) + except (ValueError, AttributeError): + raise AuthenticationException("Invalid token payload") from None + + return vendor_id + + +def get_rls_cursor( + request: Request, vendor_id: Annotated[str, Depends(get_current_vendor_id)] +) -> Generator[Cursor, None, None]: + """Return a database cursor with app.vendor_id set for RLS. + + After authentication, this dependency: + 1. Obtains a connection from the pool + 2. Calls app.set_app_context(vendor_id) to set the RLS context + 3. Yields the cursor for use in route handlers + + Yields: + Cursor: A database cursor with RLS context set. + + Raises: + ServiceUnavailableException: If the database pool is not initialized. + """ + pool = getattr(request.app.state, "db_pool", None) + if pool is None: + raise ServiceUnavailableException("Database pool not initialized") + with pool.connection() as conn: + with conn.cursor() as cursor: + cursor.execute("SELECT app.set_app_context(%s)", (vendor_id,)) + yield cursor + + +CurrentVendorId = Annotated[str, Depends(get_current_vendor_id)] +RLSCursorDep = Annotated[Cursor, Depends(get_rls_cursor)] diff --git a/backend/app/api/main.py b/backend/app/api/main.py index 12ab2ae..8272777 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -1,7 +1,8 @@ from fastapi import APIRouter -from app.api.routes import health, login +from app.api.routes import auth, health + api_router = APIRouter() api_router.include_router(health.router, tags=["health"]) -api_router.include_router(login.router, prefix="/login", tags=["login"]) +api_router.include_router(auth.router, prefix="/auth", tags=["auth"]) diff --git a/backend/app/api/middlewares.py b/backend/app/api/middlewares.py index 387d34c..9e1f785 100644 --- a/backend/app/api/middlewares.py +++ b/backend/app/api/middlewares.py @@ -1,13 +1,39 @@ +import re import uuid -from fastapi import Request +from fastapi import Request, Response +from starlette.middleware.base import RequestResponseEndpoint -async def add_request_id(request: Request, call_next): - """Add a unique request_id to each request context""" +_UUID_RE = re.compile( + r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}" + r"-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$" +) + + +def _is_valid_request_id(value: str) -> bool: + """Check that value is a well-formed UUID string. + + Returns: + bool: True when the value matches the UUID4 hex pattern. + """ + return bool(_UUID_RE.match(value)) + + +async def add_request_id( + request: Request, call_next: RequestResponseEndpoint +) -> Response: + """Add a unique request_id to each request context. + + Returns: + Response: The response with X-Request-ID header. + """ raw_header = request.headers.get("X-Request-ID") - stripped_header = raw_header.strip() if raw_header else "" - request_id = stripped_header if stripped_header else str(uuid.uuid4()) + stripped = raw_header.strip() if raw_header else "" + if stripped and _is_valid_request_id(stripped): + request_id = stripped + else: + request_id = str(uuid.uuid4()) request.state.request_id = request_id response = await call_next(request) diff --git a/backend/app/api/routes/auth.py b/backend/app/api/routes/auth.py new file mode 100644 index 0000000..3aa1000 --- /dev/null +++ b/backend/app/api/routes/auth.py @@ -0,0 +1,77 @@ +"""Auth routes — signup, login, token refresh.""" + +from fastapi import APIRouter, status + +from app.api.deps import CursorDep, SettingsDep +from app.schemas.auth import ( + LoginRequest, + RefreshRequest, + SignupRequest, + SignupResponse, + TokenPair, +) +from app.schemas.response import SuccessResponse +from app.services import auth as auth_service + + +router = APIRouter() + + +@router.post( + "/signup", + status_code=status.HTTP_201_CREATED, + response_model=SuccessResponse[SignupResponse], +) +def signup( + body: SignupRequest, cursor: CursorDep, settings: SettingsDep +) -> SuccessResponse[SignupResponse]: + """Create a new vendor account. + + Returns: + SuccessResponse[SignupResponse]: The created vendor. + """ + result = auth_service.signup( + cursor=cursor, + email=body.email, + password=body.password, + client_id=body.client_id, + settings=settings, + ) + return SuccessResponse(data=result) + + +@router.post("/login", response_model=SuccessResponse[TokenPair]) +def login( + body: LoginRequest, cursor: CursorDep, settings: SettingsDep +) -> SuccessResponse[TokenPair]: + """Authenticate a vendor and return an access/refresh token pair. + + Returns: + SuccessResponse[TokenPair]: The token pair. + """ + result = auth_service.login( + cursor=cursor, + email=body.email, + password=body.password, + client_id=body.client_id, + settings=settings, + ) + return SuccessResponse(data=result) + + +@router.post("/refresh", response_model=SuccessResponse[TokenPair]) +def refresh( + body: RefreshRequest, cursor: CursorDep, settings: SettingsDep +) -> SuccessResponse[TokenPair]: + """Issue a new token pair using a valid refresh token. + + Returns: + SuccessResponse[TokenPair]: The new token pair. + """ + result = auth_service.refresh( + refresh_token_str=body.refresh_token, + client_id=body.client_id, + cursor=cursor, + settings=settings, + ) + return SuccessResponse(data=result) diff --git a/backend/app/api/routes/health.py b/backend/app/api/routes/health.py index ffe224c..2a39a8a 100644 --- a/backend/app/api/routes/health.py +++ b/backend/app/api/routes/health.py @@ -3,11 +3,15 @@ from fastapi import APIRouter + router = APIRouter() @router.get("/health") def health_check() -> dict[str, Any]: return { - "data": {"status": "ok", "timestamp": datetime.now(timezone.utc).isoformat()} + "data": { + "status": "ok", + "timestamp": datetime.now(timezone.utc).isoformat(), + } } diff --git a/backend/app/api/routes/login.py b/backend/app/api/routes/login.py index af9233c..e6f2f82 100644 --- a/backend/app/api/routes/login.py +++ b/backend/app/api/routes/login.py @@ -1,3 +1,4 @@ from fastapi import APIRouter + router = APIRouter() diff --git a/backend/app/core/config.py b/backend/app/core/config.py index f25f85a..b2a3142 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -1,20 +1,16 @@ -from pydantic import ( - PostgresDsn, - computed_field, -) +from pydantic import PostgresDsn, computed_field from pydantic_settings import BaseSettings, SettingsConfigDict class Settings(BaseSettings): model_config = SettingsConfigDict( - env_file=".env", - env_ignore_empty=True, - extra="ignore", + env_file=".env", env_ignore_empty=True, extra="ignore" ) SECRET_KEY: str - # 60 minutes * 24 hours = 1 day (configurable via ACCESS_TOKEN_EXPIRE_MINUTES env var) - ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 + # JWT token lifetimes + ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 # 1 hour + REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # 7 days PROJECT_NAME: str POSTGRES_SERVER: str @@ -25,7 +21,8 @@ class Settings(BaseSettings): @computed_field @property - def DATABASE_DSN(self) -> PostgresDsn: + def DATABASE_DSN(self) -> PostgresDsn: # noqa: N802 + # See: https://docs.astral.sh/ruff/rules/invalid-function-name/ return PostgresDsn.build( scheme="postgresql", username=self.POSTGRES_USER, diff --git a/backend/app/core/exception_handlers.py b/backend/app/core/exception_handlers.py index 1462445..4fe854a 100644 --- a/backend/app/core/exception_handlers.py +++ b/backend/app/core/exception_handlers.py @@ -1,5 +1,5 @@ -import uuid import logging +import uuid from fastapi import Request, status from fastapi.exceptions import RequestValidationError @@ -17,8 +17,12 @@ logger = logging.getLogger(__name__) -async def api_exception_handler(request: Request, exc: APIException) -> JSONResponse: - """Handle custom API exceptions""" +def api_exception_handler(request: Request, exc: APIException) -> JSONResponse: + """Handle custom API exceptions. + + Returns: + JSONResponse: The error response. + """ request_id = getattr(request.state, "request_id", str(uuid.uuid4())) logger.warning( @@ -46,10 +50,14 @@ async def api_exception_handler(request: Request, exc: APIException) -> JSONResp ) -async def validation_exception_handler( +def validation_exception_handler( request: Request, exc: RequestValidationError ) -> JSONResponse: - """Handle Pydantic validation errors""" + """Handle Pydantic validation errors. + + Returns: + JSONResponse: The validation error response. + """ request_id = getattr(request.state, "request_id", str(uuid.uuid4())) # Parse validation errors @@ -57,20 +65,12 @@ async def validation_exception_handler( for error in exc.errors(): field_path = ".".join(str(loc) for loc in error["loc"][1:]) field = field_path if field_path else None - details.append( - { - "field": field, - "message": error["msg"], - } - ) + details.append({"field": field, "message": error["msg"]}) logger.warning( "Validation Error: %d validation errors", len(details), - extra={ - "request_id": request_id, - "validation_errors": details, - }, + extra={"request_id": request_id, "validation_errors": details}, ) # Starlette <0.48 compatibility: use getattr fallback to literal 422 @@ -96,15 +96,17 @@ async def validation_exception_handler( ) -async def general_exception_handler(request: Request, exc: Exception) -> JSONResponse: - """Handle all uncaught exceptions""" +def general_exception_handler(request: Request, exc: Exception) -> JSONResponse: + """Handle all uncaught exceptions. + + Returns: + JSONResponse: The generic error response. + """ request_id = getattr(request.state, "request_id", str(uuid.uuid4())) # Log the full traceback server-side logger.exception( - "Unexpected error: %s", - str(exc), - extra={"request_id": request_id}, + "Unexpected error: %s", exc, extra={"request_id": request_id} ) # Use typed schema to validate error structure @@ -129,14 +131,16 @@ def _build_error_details( details: list[dict | ErrorDetail] | dict | ErrorDetail | None, ) -> list[ErrorDetail]: """ - Converts error details (list, dict, ErrorDetail, or None) into a list of ErrorDetail instances. + Convert error details into a list of ErrorDetail instances. + + Accepts a list, dict, ErrorDetail, or None and normalises + into ``list[ErrorDetail]``. Args: - details: Can be a list of dict/ErrorDetail, a single dict, a single ErrorDetail, or None. - Non-dict/non-ErrorDetail entries are converted to string messages. + details: Raw error details in any supported shape. Returns: - List of ErrorDetail instances. Non-dict/non-ErrorDetail entries are preserved as string messages. + list[ErrorDetail]: Normalised detail objects. """ if details is None: return [] diff --git a/backend/app/core/exceptions.py b/backend/app/core/exceptions.py index b6537a9..ce52e4b 100644 --- a/backend/app/core/exceptions.py +++ b/backend/app/core/exceptions.py @@ -5,7 +5,7 @@ from app.schemas.response import ErrorCode -class APIException(Exception): +class APIException(Exception): # noqa: N818 """Base exception for API errors""" def __init__( @@ -14,7 +14,7 @@ def __init__( message: str, http_status: int, details: list[dict[str, Any]] | None = None, - ): + ) -> None: self.error_code = error_code self.message = message self.http_status = http_status @@ -29,11 +29,11 @@ def __init__( self, message: str = "Invalid request parameters", details: list[dict[str, Any]] | None = None, - ): + ) -> None: super().__init__( error_code=ErrorCode.VALIDATION_FAILED, message=message, - # Use getattr for compatibility with Starlette <0.48 which lacks HTTP_422_UNPROCESSABLE_CONTENT + # Starlette <0.48 compat: fallback to 422 http_status=getattr(status, "HTTP_422_UNPROCESSABLE_CONTENT", 422), details=details, ) @@ -42,7 +42,7 @@ def __init__( class AuthenticationException(APIException): """Authentication error""" - def __init__(self, message: str = "Invalid credentials"): + def __init__(self, message: str = "Invalid credentials") -> None: super().__init__( error_code=ErrorCode.AUTH_INVALID, message=message, @@ -53,7 +53,7 @@ def __init__(self, message: str = "Invalid credentials"): class AuthorizationException(APIException): """Authorization error""" - def __init__(self, message: str = "Access denied"): + def __init__(self, message: str = "Access denied") -> None: super().__init__( error_code=ErrorCode.FORBIDDEN, message=message, @@ -64,7 +64,7 @@ def __init__(self, message: str = "Access denied"): class NotFoundException(APIException): """Resource not found error""" - def __init__(self, message: str = "Resource not found"): + def __init__(self, message: str = "Resource not found") -> None: super().__init__( error_code=ErrorCode.RESOURCE_NOT_FOUND, message=message, @@ -75,7 +75,7 @@ def __init__(self, message: str = "Resource not found"): class ConflictException(APIException): """Resource conflict error""" - def __init__(self, message: str = "Resource conflict"): + def __init__(self, message: str = "Resource conflict") -> None: super().__init__( error_code=ErrorCode.RESOURCE_CONFLICT, message=message, @@ -90,11 +90,11 @@ def __init__( self, message: str = "Business logic error", details: list[dict[str, Any]] | None = None, - ): + ) -> None: super().__init__( error_code=ErrorCode.BUSINESS_LOGIC_ERROR, message=message, - # Use getattr for compatibility with Starlette <0.48 which lacks HTTP_422_UNPROCESSABLE_CONTENT + # Starlette <0.48 compat: fallback to 422 http_status=getattr(status, "HTTP_422_UNPROCESSABLE_CONTENT", 422), details=details, ) @@ -103,7 +103,9 @@ def __init__( class ServiceUnavailableException(APIException): """Service unavailable error""" - def __init__(self, message: str = "Service temporarily unavailable"): + def __init__( + self, message: str = "Service temporarily unavailable" + ) -> None: super().__init__( error_code=ErrorCode.SERVICE_UNAVAILABLE, message=message, @@ -114,7 +116,9 @@ def __init__(self, message: str = "Service temporarily unavailable"): class AuthExpiredException(APIException): """Authentication token expired error""" - def __init__(self, message: str = "Authentication token has expired"): + def __init__( + self, message: str = "Authentication token has expired" + ) -> None: super().__init__( error_code=ErrorCode.AUTH_EXPIRED, message=message, @@ -125,7 +129,7 @@ def __init__(self, message: str = "Authentication token has expired"): class LicenseNotFoundException(APIException): """License not found error""" - def __init__(self, message: str = "License not found"): + def __init__(self, message: str = "License not found") -> None: super().__init__( error_code=ErrorCode.LICENSE_NOT_FOUND, message=message, @@ -136,7 +140,7 @@ def __init__(self, message: str = "License not found"): class LicenseRevokedException(APIException): """License has been revoked error""" - def __init__(self, message: str = "License has been revoked"): + def __init__(self, message: str = "License has been revoked") -> None: super().__init__( error_code=ErrorCode.LICENSE_REVOKED, message=message, @@ -147,7 +151,7 @@ def __init__(self, message: str = "License has been revoked"): class LicenseExpiredException(APIException): """License has expired error""" - def __init__(self, message: str = "License has expired"): + def __init__(self, message: str = "License has expired") -> None: super().__init__( error_code=ErrorCode.LICENSE_EXPIRED, message=message, diff --git a/backend/app/core/security.py b/backend/app/core/security.py index bb587e4..48d52d7 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -8,24 +8,66 @@ from app.core.config import Settings -password_hash = PasswordHash( - ( - Argon2Hasher(), - BcryptHasher(), - ) -) + +# BcryptHasher is listed first so new passwords are hashed with bcrypt. +# Argon2Hasher is kept for verification of legacy hashes. +password_hash = PasswordHash((BcryptHasher(), Argon2Hasher())) ALGORITHM = "HS256" def create_access_token( - subject: str | Any, expires_delta: timedelta, settings: Settings + vendor_id: str, + settings: Settings, + *, + expires_delta: timedelta | None = None, +) -> str: + """Create a short-lived access token with vendor_id claim. + + Returns: + str: The encoded JWT access token. + """ + if expires_delta is None: + expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + expire = datetime.now(timezone.utc) + expires_delta + to_encode = { + "vendor_id": str(vendor_id), + "exp": expire, + "token_type": "access", + } + return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) + + +def create_refresh_token( + vendor_id: str, + settings: Settings, + *, + expires_delta: timedelta | None = None, ) -> str: + """Create a long-lived refresh token with vendor_id claim. + + Returns: + str: The encoded JWT refresh token. + """ + if expires_delta is None: + expires_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) expire = datetime.now(timezone.utc) + expires_delta - to_encode = {"exp": expire, "sub": str(subject)} - encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) - return encoded_jwt + to_encode = { + "vendor_id": str(vendor_id), + "exp": expire, + "token_type": "refresh", + } + return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) + + +def decode_token(token: str, settings: Settings) -> dict[str, Any]: + """Decode and validate a JWT token. + + Returns: + dict[str, Any]: The decoded token payload. + """ + return jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) def verify_password( diff --git a/backend/app/crud/.gitkeep b/backend/app/crud/__init__.py similarity index 100% rename from backend/app/crud/.gitkeep rename to backend/app/crud/__init__.py diff --git a/backend/app/crud/vendor.py b/backend/app/crud/vendor.py new file mode 100644 index 0000000..318504c --- /dev/null +++ b/backend/app/crud/vendor.py @@ -0,0 +1,69 @@ +"""CRUD operations for the vendors table (raw psycopg).""" + +from __future__ import annotations + +from typing import Any + +from psycopg import Cursor + + +def get_vendor_by_email(cursor: Cursor, email: str) -> dict[str, Any] | None: + """Return a vendor row by email (case-insensitive) or None. + + Returns: + dict[str, Any] | None: The vendor row or None. + """ + cursor.execute( + 'SELECT "id", "email", "password_hash" ' + 'FROM app."vendors" ' + 'WHERE LOWER("email") = LOWER(%s) ' + 'AND "deleted_at" IS NULL', + (email,), + ) + row = cursor.fetchone() + if row is None: + return None + return {"id": str(row[0]), "email": row[1], "password_hash": row[2]} + + +def get_vendor_by_id(cursor: Cursor, vendor_id: str) -> dict[str, Any] | None: + """Return a vendor row by id or None. + + Returns: + dict[str, Any] | None: The vendor row or None. + """ + cursor.execute( + 'SELECT "id", "email" ' + 'FROM app."vendors" ' + 'WHERE "id" = %s AND "deleted_at" IS NULL', + (vendor_id,), + ) + row = cursor.fetchone() + if row is None: + return None + return {"id": str(row[0]), "email": row[1]} + + +def create_vendor( + cursor: Cursor, email: str, password_hash: str +) -> dict[str, Any] | None: + """Insert a new vendor and return the created row. + + Uses ON CONFLICT DO NOTHING so a concurrent insert with the same + (case-insensitive) email returns None instead of raising a + UniqueViolation. + + Returns: + dict[str, Any] | None: The created vendor row, or None on conflict. + """ + cursor.execute( + 'INSERT INTO app."vendors" ("email", "password_hash") ' + "VALUES (%s, %s) " + 'ON CONFLICT ((LOWER("email"))) DO NOTHING ' + 'RETURNING "id", "email"', + (email, password_hash), + ) + row = cursor.fetchone() + if row is None: + return None + return {"id": str(row[0]), "email": row[1]} diff --git a/backend/app/main.py b/backend/app/main.py index 7d0e4e9..07241ca 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -1,27 +1,28 @@ -from contextlib import asynccontextmanager import logging +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from fastapi import FastAPI from fastapi.exceptions import RequestValidationError from fastapi.routing import APIRoute +from psycopg_pool import ConnectionPool from app.api.main import api_router from app.api.middlewares import add_request_id from app.core.config import Settings from app.core.exception_handlers import ( api_exception_handler, - validation_exception_handler, general_exception_handler, + validation_exception_handler, ) from app.core.exceptions import APIException -from psycopg_pool import ConnectionPool logger = logging.getLogger(__name__) @asynccontextmanager -async def lifespan(app: FastAPI): +async def lifespan(app: FastAPI) -> AsyncIterator[None]: # noqa: RUF029 # Startup logger.info("Initializing settings") settings = Settings() diff --git a/backend/app/pre_start.py b/backend/app/pre_start.py index c30cbaa..3b8262a 100644 --- a/backend/app/pre_start.py +++ b/backend/app/pre_start.py @@ -1,5 +1,6 @@ import logging +from psycopg_pool import ConnectionPool from tenacity import ( after_log, before_log, @@ -10,7 +11,7 @@ ) from app.core.config import Settings -from psycopg_pool import ConnectionPool + logger = logging.getLogger(__name__) diff --git a/backend/app/schemas/auth.py b/backend/app/schemas/auth.py new file mode 100644 index 0000000..13d9e0b --- /dev/null +++ b/backend/app/schemas/auth.py @@ -0,0 +1,33 @@ +from pydantic import BaseModel, EmailStr, Field + + +class SignupRequest(BaseModel): + email: EmailStr + password: str = Field(..., min_length=8, max_length=128) + client_id: str = Field(..., min_length=1, max_length=256) + + +class LoginRequest(BaseModel): + email: EmailStr + password: str = Field(..., min_length=8, max_length=128) + client_id: str = Field(..., min_length=1, max_length=256) + + +class RefreshRequest(BaseModel): + refresh_token: str = Field(..., min_length=8, max_length=4096) + client_id: str = Field(..., min_length=1, max_length=256) + + +class TokenPair(BaseModel): + access_token: str + refresh_token: str + token_type: str = "bearer" # noqa: S105 + + +class VendorOut(BaseModel): + id: str + email: str + + +class SignupResponse(BaseModel): + vendor: VendorOut diff --git a/backend/app/schemas/response.py b/backend/app/schemas/response.py index af1a7f7..6dd9645 100644 --- a/backend/app/schemas/response.py +++ b/backend/app/schemas/response.py @@ -1,7 +1,8 @@ from enum import Enum from typing import Generic, TypeVar -from pydantic import BaseModel, Field, ConfigDict +from pydantic import BaseModel, ConfigDict, Field + T = TypeVar("T") @@ -47,7 +48,9 @@ class SuccessResponse(BaseModel, Generic[T]): class ErrorDetail(BaseModel): """Additional error details""" - field: str | None = Field(None, description="Field name if validation error") + field: str | None = Field( + None, description="Field name if validation error" + ) message: str = Field(..., description="Detailed error message") @@ -57,13 +60,20 @@ class ErrorBodyResponse(BaseModel): code: ErrorCode = Field(..., description="Error code") message: str = Field(..., description="Human-readable error message") http_status: int = Field( - ..., ge=400, le=599, description="HTTP status code matching the response" + ..., + ge=400, + le=599, + description="HTTP status code matching the response", ) details: list[ErrorDetail] = Field( default_factory=list, - description="Additional error details for validation errors or other context", + description=( + "Additional error details for validation errors or other context" + ), + ) + request_id: str = Field( + ..., description="Unique request identifier for tracing" ) - request_id: str = Field(..., description="Unique request identifier for tracing") class ErrorResponse(BaseModel): @@ -71,7 +81,11 @@ class ErrorResponse(BaseModel): error: ErrorBodyResponse = Field( ..., - description="Error information with code, message, http_status, details, and request_id", + description=( + "Error information containing code," + " message, http_status, details," + " and request_id" + ), ) model_config = ConfigDict( @@ -81,7 +95,9 @@ class ErrorResponse(BaseModel): "code": "VALIDATION_FAILED", "message": "Invalid request parameters", "http_status": 422, - "details": [{"field": "email", "message": "Invalid email format"}], + "details": [ + {"field": "email", "message": "Invalid email format"} + ], "request_id": "req-123456789", } } diff --git a/backend/app/services/.gitkeep b/backend/app/services/__init__.py similarity index 100% rename from backend/app/services/.gitkeep rename to backend/app/services/__init__.py diff --git a/backend/app/services/auth.py b/backend/app/services/auth.py new file mode 100644 index 0000000..ef4f4eb --- /dev/null +++ b/backend/app/services/auth.py @@ -0,0 +1,139 @@ +"""Auth service — orchestrates signup, login, and token refresh.""" + +from __future__ import annotations + +import logging +import uuid as _uuid + +import jwt as pyjwt +from psycopg import Cursor + +from app.core.config import Settings +from app.core.exceptions import AuthenticationException, ConflictException +from app.core.security import ( + create_access_token, + create_refresh_token, + decode_token, + get_password_hash, + verify_password, +) +from app.crud.vendor import create_vendor, get_vendor_by_email, get_vendor_by_id +from app.schemas.auth import SignupResponse, TokenPair, VendorOut + + +logger = logging.getLogger(__name__) + + +def signup( + cursor: Cursor, + email: str, + password: str, + client_id: str, + settings: Settings, +) -> SignupResponse: + """Create a new vendor account. + + Returns: + SignupResponse: The created vendor. + + Raises: + ConflictException: If the email already exists. + """ + existing = get_vendor_by_email(cursor, email) + if existing is not None: + raise ConflictException("A vendor with this email already exists") + + hashed = get_password_hash(password) + vendor = create_vendor(cursor, email, hashed) + + # create_vendor returns None when a concurrent insert won the race + # (ON CONFLICT DO NOTHING). + if vendor is None: + raise ConflictException("A vendor with this email already exists") + + logger.info("Vendor created: %s (client_id=%s)", vendor["id"], client_id) + return SignupResponse(vendor=VendorOut(**vendor)) + + +def login( + cursor: Cursor, + email: str, + password: str, + client_id: str, + settings: Settings, +) -> TokenPair: + """Authenticate a vendor and return an access/refresh token pair. + + Returns: + TokenPair: The access/refresh token pair. + + Raises: + AuthenticationException: For invalid credentials. + """ + vendor = get_vendor_by_email(cursor, email) + if vendor is None: + raise AuthenticationException() + + valid, updated_hash = verify_password(password, vendor["password_hash"]) + if not valid: + raise AuthenticationException() + + # If the hashing library returned an upgraded hash, persist it + if updated_hash is not None: + cursor.execute( + 'UPDATE app."vendors"' + ' SET "password_hash" = %s, "updated_at" = NOW()' + ' WHERE "id" = %s', + (updated_hash, vendor["id"]), + ) + + access_token = create_access_token(vendor["id"], settings) + refresh_token = create_refresh_token(vendor["id"], settings) + + logger.info("Vendor logged in: %s (client_id=%s)", vendor["id"], client_id) + return TokenPair(access_token=access_token, refresh_token=refresh_token) + + +def refresh( + refresh_token_str: str, client_id: str, cursor: Cursor, settings: Settings +) -> TokenPair: + """Issue a new token pair from a valid refresh token. + + Returns: + TokenPair: The new access/refresh token pair. + + Raises: + AuthenticationException: If the token is invalid/expired + or if it is not a refresh token. + """ + try: + payload = decode_token(refresh_token_str, settings) + except pyjwt.PyJWTError: + raise AuthenticationException( + "Invalid or expired refresh token" + ) from None + + if payload.get("token_type") != "refresh": + raise AuthenticationException("Invalid token type") + + vendor_id = payload.get("vendor_id") + if vendor_id is None: + raise AuthenticationException("Invalid token payload") + + try: + _uuid.UUID(vendor_id) + except (ValueError, AttributeError): + raise AuthenticationException("Invalid token payload") from None + + # Ensure vendor still exists and is not deleted + vendor = get_vendor_by_id(cursor, vendor_id) + if vendor is None: + raise AuthenticationException("Vendor not found") + + access_token = create_access_token(vendor_id, settings) + new_refresh_token = create_refresh_token(vendor_id, settings) + + logger.info( + "Token refreshed for vendor: %s (client_id=%s)", vendor_id, client_id + ) + return TokenPair(access_token=access_token, refresh_token=new_refresh_token) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 2043866..35e7e4b 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "pydantic-settings<3.0.0,>=2.2.1", "pwdlib[argon2,bcrypt]>=0.3.0", "psycopg-pool>=3.3.0,<4.0.0", + "pyjwt>=2.11.0,<3.0.0", ] [dependency-groups] @@ -32,6 +33,70 @@ build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] packages = ["app"] +[tool.ruff] +line-length = 80 +indent-width = 4 +preview = true + +# Output serialization format for violations. The default serialization +# format is "full" [env: RUFF_OUTPUT_FORMAT=] [possible values: +# concise, full, json, json-lines, junit, grouped, github, gitlab, +# pylint, rdjson, azure, sarif] +output-format = "grouped" + +[tool.ruff.lint] +isort.lines-after-imports = 2 +isort.split-on-trailing-comma = false + +select = [ + "ANN", # flake8-annotations (required strict type annotations for public functions) + "S", # flake8-bandit (checks basic security issues in code) + "BLE", # flake8-blind-except (checks the except blocks that do not specify exception) + "FBT", # flake8-boolean-trap (ensure that boolean args can be used with kw only) + "E", # pycodestyle errors (PEP 8 style guide violations) + "W", # pycodestyle warnings (e.g., extra spaces, indentation issues) + "DOC", # pydoclint issues (e.g., extra or missing return, yield, warnings) + "A", # flake8-builtins (check variable and function names to not shadow builtins) + "N", # Naming convention checks (e.g., PEP 8 variable and function names) + "F", # Pyflakes errors (e.g., unused imports, undefined variables) + "I", # isort (Ensures imports are sorted properly) + "B", # flake8-bugbear (Detects likely bugs and bad practices) + "TID", # flake8-tidy-imports (Checks for banned or misplaced imports) + "UP", # pyupgrade (Automatically updates old Python syntax) + "YTT", # flake8-2020 (Detects outdated Python 2/3 compatibility issues) + "FLY", # flynt (Converts old-style string formatting to f-strings) + "PIE", # flake8-pie + "PL", # pylint + "RUF", # Ruff-specific rules (Additional optimizations and best practices) +] + +ignore = [] + +[tool.ruff.lint.per-file-ignores] +"tests/**/*.py" = [ + "ANN001", # [flake8-annotations](https://docs.astral.sh/ruff/rules/missing-type-function-argument/) + "ANN002", # [flake8-annotations](https://docs.astral.sh/ruff/rules/missing-type-args/) + "ANN003", # [flake8-annotations](https://docs.astral.sh/ruff/rules/missing-type-kwargs/) + "ANN201", # [flake8-annotations](https://docs.astral.sh/ruff/rules/missing-return-type-undocumented-public-function/) + "ANN202", # [flake8-annotations](https://docs.astral.sh/ruff/rules/missing-return-type-private-function/) + "S101", # [flake8-bandit](https://docs.astral.sh/ruff/rules/assert/) + "S105", # [flake8-bandit](https://docs.astral.sh/ruff/rules/hardcoded-password-string/) + "S106", # [flake8-bandit](https://docs.astral.sh/ruff/rules/hardcoded-password-func-arg/) + "S107", # [flake8-bandit](https://docs.astral.sh/ruff/rules/hardcoded-password-default/) + "PLR2004", # [pylint](https://docs.astral.sh/ruff/rules/magic-value-comparison/) + "PLR6301", # [pylint](https://docs.astral.sh/ruff/rules/no-self-use/) + "DOC201", # [pydoclint](https://docs.astral.sh/ruff/rules/docstring-missing-returns/) + "DOC402", # [pydoclint](https://docs.astral.sh/ruff/rules/docstring-missing-yields/) + "DOC501", # [pydoclint](https://docs.astral.sh/ruff/rules/docstring-missing-exception/) +] +"tests/__init__.py" = [ + "RUF067", # [ruff](https://docs.astral.sh/ruff/rules/non-empty-init-module/) +] + +[tool.ruff.format] +docstring-code-format = true +skip-magic-trailing-comma = true + [tool.pytest.ini_options] testpaths = ["tests"] markers = [ diff --git a/backend/tests/__init__.py b/backend/tests/__init__.py index c7e2080..37ff788 100644 --- a/backend/tests/__init__.py +++ b/backend/tests/__init__.py @@ -1,5 +1,6 @@ import os + # Environment variables must be set before importing the app os.environ.setdefault("SECRET_KEY", "test-secret-key") os.environ.setdefault("PROJECT_NAME", "permit-test") diff --git a/backend/tests/api/routes/test_health.py b/backend/tests/api/routes/test_health.py index d371700..121d4b8 100644 --- a/backend/tests/api/routes/test_health.py +++ b/backend/tests/api/routes/test_health.py @@ -26,7 +26,8 @@ def test_health_check_returns_ok(): f"Expected status to be 'ok', got '{data['data']['status']}'" ) assert "timestamp" in data["data"], ( - f"Expected 'timestamp' field in health response data, got keys: {data['data'].keys()}" + "Expected 'timestamp' field in health response data," + f" got keys: {data['data'].keys()}" ) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index ca04ec6..af1d56c 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -13,19 +13,22 @@ from pathlib import Path import pytest +from psycopg import Cursor, connect +from testcontainers.postgres import PostgresContainer from app.api.deps import get_db from app.main import app -from psycopg import Cursor, connect -from testcontainers.postgres import PostgresContainer + MIGRATIONS_DIR = str(Path(__file__).parents[2] / "migrations") @pytest.fixture(scope="module") def test_container() -> typing.Generator[PostgresContainer, None, None]: - """Fixture that yields a Testcontainers Postgres (PostgresContainer) with migrations applied.""" - with PostgresContainer("postgres:18.2-alpine3.23", driver=None).with_volume_mapping( + """Fixture: Testcontainers Postgres with migrations.""" + with PostgresContainer( + "postgres:18.2-alpine3.23", driver=None + ).with_volume_mapping( MIGRATIONS_DIR, "/docker-entrypoint-initdb.d" ) as container: yield container diff --git a/backend/tests/core/test_exception_handlers.py b/backend/tests/core/test_exception_handlers.py index 7005395..7f2f45d 100644 --- a/backend/tests/core/test_exception_handlers.py +++ b/backend/tests/core/test_exception_handlers.py @@ -1,15 +1,20 @@ """ -Integration tests for exception handlers in app.core.exception_handlers. +Integration tests for exception handlers. -Verifies that every handler (api_exception_handler, validation_exception_handler, -general_exception_handler) produces responses that conform to the error contract: -correct HTTP status codes, error codes, message round-tripping, details structure, -request-id presence/format, and uniqueness. +Verifies that every handler produces responses that conform +to the error contract: correct HTTP status codes, error codes, +message round-tripping, details structure, request-id +presence/format, and uniqueness. """ import re import pytest +from fastapi import FastAPI, status +from fastapi.exceptions import RequestValidationError +from fastapi.testclient import TestClient +from pydantic import BaseModel + from app.core.exception_handlers import ( api_exception_handler, general_exception_handler, @@ -29,10 +34,7 @@ ServiceUnavailableException, ValidationException, ) -from fastapi import FastAPI, status -from fastapi.exceptions import RequestValidationError -from fastapi.testclient import TestClient -from pydantic import BaseModel + # Pre-compiled UUID v4 pattern reused across request-id assertions _UUID_V4_RE = re.compile( @@ -42,11 +44,13 @@ @pytest.fixture(scope="module") def error_contract_app(): - """FastAPI app for error contract handler tests (all exception handlers, one endpoint per error type).""" + """FastAPI app wired with all exception handlers.""" app = FastAPI() app.add_exception_handler(APIException, api_exception_handler) - app.add_exception_handler(RequestValidationError, validation_exception_handler) + app.add_exception_handler( + RequestValidationError, validation_exception_handler + ) app.add_exception_handler(Exception, general_exception_handler) @app.get("/auth-invalid") @@ -212,7 +216,7 @@ async def validate_nested(data: NestedBodyRequest): def test_error_status_code_and_code_match( error_contract_app, endpoint, expected_code, expected_status ): - """HTTP status code and error code field must match the expected values for every exception type.""" + """HTTP status and error code must match for every type.""" client = TestClient(error_contract_app, raise_server_exceptions=False) response = client.get(endpoint) @@ -222,12 +226,14 @@ def test_error_status_code_and_code_match( ) error = response.json()["error"] assert error["code"] == expected_code, ( - f"Expected error code '{expected_code}' for {endpoint}, got '{error['code']}'" + f"Expected error code '{expected_code}' for " + f"{endpoint}, got '{error['code']}'" ) assert error["http_status"] == expected_status, ( - f"Expected error http_status {expected_status}, got {error['http_status']}" + f"Expected error http_status {expected_status}," + f" got {error['http_status']}" ) - # Explicitly validate wire vs body contract: response status must match error http_status + # Wire vs body contract: status must match assert response.status_code == error["http_status"], ( f"Response status code {response.status_code} does not match " f"error http_status {error['http_status']}" @@ -241,7 +247,7 @@ def test_error_status_code_and_code_match( @pytest.mark.integration def test_error_details_structure(error_contract_app): - """Details list must contain exact field+message pairs from the raised exception.""" + """Details list must contain exact field+message pairs.""" client = TestClient(error_contract_app, raise_server_exceptions=False) response = client.get("/validation") @@ -250,7 +256,8 @@ def test_error_details_structure(error_contract_app): f"Expected 'details' field in error response, got keys: {error.keys()}" ) assert len(error["details"]) == 2, ( - f"Expected 2 detail entries, got {len(error['details'])}: {error['details']}" + f"Expected 2 detail entries, got " + f"{len(error['details'])}: {error['details']}" ) detail1 = error["details"][0] @@ -267,13 +274,14 @@ def test_error_details_structure(error_contract_app): f"Expected second detail field 'password', got '{detail2['field']}'" ) assert detail2["message"] == "Too short", ( - f"Expected second detail message 'Too short', got '{detail2['message']}'" + "Expected second detail message 'Too short'," + f" got '{detail2['message']}'" ) @pytest.mark.integration def test_error_details_default_empty(error_contract_app): - """Exceptions raised without details must produce an empty details list in the response.""" + """Exceptions without details must yield empty details.""" client = TestClient(error_contract_app, raise_server_exceptions=False) response = client.get("/auth-invalid") @@ -282,7 +290,8 @@ def test_error_details_default_empty(error_contract_app): f"Expected 'details' field in error response, got keys: {error.keys()}" ) assert error["details"] == [], ( - f"Expected empty details list for exception without details, got {error['details']}" + "Expected empty details list for exception" + f" without details, got {error['details']}" ) @@ -303,7 +312,10 @@ def test_validation_handler_nested_field_path(error_contract_app): assert response.status_code == getattr( status, "HTTP_422_UNPROCESSABLE_CONTENT", 422 - ), f"Expected status 422 for nested validation error, got {response.status_code}" + ), ( + "Expected status 422 for nested validation" + f" error, got {response.status_code}" + ) details = response.json()["error"]["details"] fields = [d["field"] for d in details] assert "address.street" in fields, ( @@ -312,12 +324,15 @@ def test_validation_handler_nested_field_path(error_contract_app): @pytest.mark.integration -def test_validation_handler_body_level_error_sets_field_to_none(error_contract_app): - """When Pydantic emits a body-level error (loc has only one element after slicing), - validation_exception_handler must produce field=None in the detail. +def test_validation_handler_body_level_error_sets_field_to_none( + error_contract_app, +): + """Body-level error (loc with one element after slicing) + must produce field=None in the detail. - Sending an integer body against an object-typed endpoint triggers - loc=["body"], making loc[1:] empty and field_path="", so field=None. + Sending an integer body against an object-typed endpoint + triggers loc=["body"], making loc[1:] empty and + field_path="", so field=None. """ client = TestClient(error_contract_app, raise_server_exceptions=False) response = client.post("/validate-body", json=5) @@ -325,13 +340,17 @@ def test_validation_handler_body_level_error_sets_field_to_none(error_contract_a assert response.status_code == getattr( status, "HTTP_422_UNPROCESSABLE_CONTENT", 422 ), ( - f"Expected status 422 for body-level validation error, got {response.status_code}" + "Expected status 422 for body-level" + f" validation error, got {response.status_code}" ) details = response.json()["error"]["details"] - assert len(details) > 0, "Expected at least one validation detail, got empty list" + assert len(details) > 0, ( + "Expected at least one validation detail, got empty list" + ) # At least one detail must have field=None from the body-level loc assert any(d["field"] is None for d in details), ( - f"Expected at least one detail with field=None for a body-level error, got {details}" + "Expected at least one detail with" + f" field=None for body-level error, got {details}" ) @@ -344,7 +363,9 @@ def test_validation_handler_body_level_error_sets_field_to_none(error_contract_a @pytest.mark.parametrize( "endpoint,method,json_body,raise_exceptions", [ - pytest.param("/auth-invalid", "get", None, True, id="api_exception_handler"), + pytest.param( + "/auth-invalid", "get", None, True, id="api_exception_handler" + ), pytest.param( "/validate-body", "post", @@ -360,21 +381,25 @@ def test_validation_handler_body_level_error_sets_field_to_none(error_contract_a def test_request_id_header_matches_body_and_is_uuid_v4( error_contract_app, endpoint, method, json_body, raise_exceptions ): - """X-Request-ID header must be present, match the body request_id, and be a valid UUID v4.""" - client = TestClient(error_contract_app, raise_server_exceptions=raise_exceptions) + """X-Request-ID header must match body and be UUID v4.""" + client = TestClient( + error_contract_app, raise_server_exceptions=raise_exceptions + ) request_method = getattr(client, method) kwargs = {"json": json_body} if json_body is not None else {} response = request_method(endpoint, **kwargs) assert "X-Request-ID" in response.headers, ( - f"Expected 'X-Request-ID' header in response, got headers: {list(response.headers.keys())}" + "Expected 'X-Request-ID' header, got" + f" headers: {list(response.headers.keys())}" ) request_id = response.headers["X-Request-ID"] assert request_id, "X-Request-ID header must not be empty" error = response.json()["error"] assert request_id == error["request_id"], ( - f"Header X-Request-ID '{request_id}' does not match body request_id '{error['request_id']}'" + f"Header X-Request-ID '{request_id}' does not" + f" match body request_id '{error['request_id']}'" ) assert _UUID_V4_RE.match(request_id), ( f"request_id '{request_id}' is not a valid UUID v4 format" @@ -383,13 +408,17 @@ def test_request_id_header_matches_body_and_is_uuid_v4( @pytest.mark.integration def test_request_id_uniqueness_across_requests(error_contract_app): - """Repeated requests to the same endpoint must each receive a distinct request_id.""" + """Repeated requests must each get a distinct request_id.""" client = TestClient(error_contract_app, raise_server_exceptions=False) - ids = {client.get("/auth-invalid").json()["error"]["request_id"] for _ in range(5)} + ids = { + client.get("/auth-invalid").json()["error"]["request_id"] + for _ in range(5) + } assert len(ids) == 5, ( - f"Expected 5 unique request_ids from 5 requests to the same endpoint, " - f"got {len(ids)} unique IDs. request_id may not be regenerated per request." + "Expected 5 unique request_ids from 5 requests," + f" got {len(ids)} unique IDs." + " request_id may not be regenerated per request." ) @@ -415,8 +444,12 @@ def test_api_exception_handler_returns_correct_structure(error_contract_app): @pytest.mark.integration -def test_validation_exception_handler_returns_correct_structure(error_contract_app): - """validation_exception_handler must return VALIDATION_FAILED with populated details.""" +def test_validation_exception_handler_returns_correct_structure( + error_contract_app, +): + """validation_exception_handler must return + VALIDATION_FAILED with populated details. + """ client = TestClient(error_contract_app, raise_server_exceptions=False) response = client.post("/validate-body", json={"invalid": "data"}) @@ -442,13 +475,16 @@ def test_validation_exception_handler_returns_correct_structure(error_contract_a f"Expected 'message' key in detail, got keys: {detail.keys()}" ) assert isinstance(detail["message"], str), ( - f"Expected detail message to be string, got {type(detail['message'])}" + "Expected detail message to be string," + f" got {type(detail['message'])}" ) @pytest.mark.integration def test_general_exception_handler_returns_sanitized_500(error_contract_app): - """general_exception_handler must return 500 with a sanitized message, never internal details.""" + """general_exception_handler must return 500 + with a sanitized message, never internal details. + """ client = TestClient(error_contract_app, raise_server_exceptions=False) response = client.get("/general-error") @@ -474,7 +510,9 @@ def test_general_exception_handler_returns_sanitized_500(error_contract_app): @pytest.mark.integration def test_validation_error_message_and_details(error_contract_app): - """validation_exception_handler must return the standard message with non-empty detail messages.""" + """validation_exception_handler must return the + standard message with non-empty detail messages. + """ client = TestClient(error_contract_app, raise_server_exceptions=False) response = client.post("/validate-body", json={"invalid": "data"}) @@ -490,7 +528,8 @@ def test_validation_error_message_and_details(error_contract_app): "Expected non-empty message in detail, got empty string" ) assert isinstance(detail["message"], str), ( - f"Expected detail message to be string, got {type(detail['message'])}" + "Expected detail message to be string," + f" got {type(detail['message'])}" ) @@ -501,24 +540,36 @@ def test_validation_error_message_and_details(error_contract_app): pytest.param("/auth-invalid", "Invalid credentials", id="auth_invalid"), pytest.param("/auth-expired", "Token has expired", id="auth_expired"), pytest.param("/forbidden", "Access denied", id="forbidden"), - pytest.param("/not-found", "Resource not found", id="resource_not_found"), + pytest.param( + "/not-found", "Resource not found", id="resource_not_found" + ), pytest.param("/conflict", "Resource conflict", id="resource_conflict"), pytest.param("/validation", "Validation failed", id="validation_error"), pytest.param( - "/business-logic", "Cannot process request", id="business_logic_error" + "/business-logic", + "Cannot process request", + id="business_logic_error", ), pytest.param( "/service-unavailable", "Database is down", id="service_unavailable" ), - pytest.param("/license-not-found", "No active license", id="license_not_found"), - pytest.param("/license-revoked", "License was revoked", id="license_revoked"), - pytest.param("/license-expired", "License has expired", id="license_expired"), + pytest.param( + "/license-not-found", "No active license", id="license_not_found" + ), + pytest.param( + "/license-revoked", "License was revoked", id="license_revoked" + ), + pytest.param( + "/license-expired", "License has expired", id="license_expired" + ), ], ) def test_raised_error_messages_roundtrip_correctly( error_contract_app, endpoint, expected_message ): - """The exact message passed to the exception constructor must appear in the response.""" + """The message passed to the constructor must + appear in the response. + """ client = TestClient(error_contract_app, raise_server_exceptions=False) response = client.get(endpoint) error = response.json()["error"] diff --git a/backend/tests/core/test_exceptions.py b/backend/tests/core/test_exceptions.py index 68609a9..732e4d9 100644 --- a/backend/tests/core/test_exceptions.py +++ b/backend/tests/core/test_exceptions.py @@ -1,8 +1,9 @@ """ Unit tests for exception classes in app.core.exceptions. -Covers: error codes, HTTP status codes, base-class parameter storage, -details defaulting, custom details, custom messages, and default message lengths. +Covers: error codes, HTTP status codes, base-class parameter +storage, details defaulting, custom details, custom messages, +and default message lengths. """ import pytest @@ -136,9 +137,7 @@ def test_exception_has_correct_error_code(exception_class, expected_code): id="not_found_exception", ), pytest.param( - ConflictException, - status.HTTP_409_CONFLICT, - id="conflict_exception", + ConflictException, status.HTTP_409_CONFLICT, id="conflict_exception" ), pytest.param( BusinessLogicException, @@ -184,7 +183,7 @@ def test_exception_has_correct_http_status(exception_class, expected_status): @pytest.mark.unit def test_api_exception_base_class_stores_all_parameters(): - """APIException must store all constructor arguments and pass message to str().""" + """APIException must store all constructor args.""" exc = APIException( error_code=ErrorCode.VALIDATION_FAILED, message="Test error", @@ -204,7 +203,7 @@ def test_api_exception_base_class_stores_all_parameters(): f"Expected details with test field, got {exc.details}" ) assert str(exc) == "Test error", ( - f"Expected str(exc) to return message, got '{str(exc)}'" + f"Expected str(exc) to return message, got '{exc!s}'" ) @@ -271,10 +270,16 @@ def test_exception_with_details_none_defaults_to_empty_list(exception_class): ], ) def test_exception_stores_custom_details(exception_class, message, details): - """Exceptions must store the exact message and details passed at construction.""" + """Exceptions must store the exact message and + details passed at construction. + """ exc = exception_class(message=message, details=details) - assert exc.message == message, f"Expected message '{message}', got '{exc.message}'" - assert exc.details == details, f"Expected details {details}, got {exc.details}" + assert exc.message == message, ( + f"Expected message '{message}', got '{exc.message}'" + ) + assert exc.details == details, ( + f"Expected details {details}, got {exc.details}" + ) # --------------------------------------------------------------------------- @@ -291,8 +296,12 @@ def test_exception_stores_custom_details(exception_class, message, details): pytest.param(AuthorizationException, id="authorization_exception"), pytest.param(NotFoundException, id="not_found_exception"), pytest.param(ConflictException, id="conflict_exception"), - pytest.param(ServiceUnavailableException, id="service_unavailable_exception"), - pytest.param(LicenseNotFoundException, id="license_not_found_exception"), + pytest.param( + ServiceUnavailableException, id="service_unavailable_exception" + ), + pytest.param( + LicenseNotFoundException, id="license_not_found_exception" + ), pytest.param(LicenseRevokedException, id="license_revoked_exception"), pytest.param(LicenseExpiredException, id="license_expired_exception"), ], @@ -339,8 +348,12 @@ def test_parameterized_exception_accepts_custom_message(exception_class): pytest.param(NotFoundException, id="not_found_exception"), pytest.param(ConflictException, id="conflict_exception"), pytest.param(BusinessLogicException, id="business_logic_exception"), - pytest.param(ServiceUnavailableException, id="service_unavailable_exception"), - pytest.param(LicenseNotFoundException, id="license_not_found_exception"), + pytest.param( + ServiceUnavailableException, id="service_unavailable_exception" + ), + pytest.param( + LicenseNotFoundException, id="license_not_found_exception" + ), pytest.param(LicenseRevokedException, id="license_revoked_exception"), pytest.param(LicenseExpiredException, id="license_expired_exception"), ], @@ -380,16 +393,22 @@ def test_api_exception_is_subclass_of_exception(): pytest.param(ConflictException, id="conflict_exception"), pytest.param(ValidationException, id="validation_exception"), pytest.param(BusinessLogicException, id="business_logic_exception"), - pytest.param(ServiceUnavailableException, id="service_unavailable_exception"), - pytest.param(LicenseNotFoundException, id="license_not_found_exception"), + pytest.param( + ServiceUnavailableException, id="service_unavailable_exception" + ), + pytest.param( + LicenseNotFoundException, id="license_not_found_exception" + ), pytest.param(LicenseRevokedException, id="license_revoked_exception"), pytest.param(LicenseExpiredException, id="license_expired_exception"), ], ) def test_concrete_exception_is_subclass_of_api_exception(exception_class): - """Every concrete exception class must be a subclass of APIException so that - api_exception_handler catches it. If this invariant breaks, those exceptions - silently fall through to general_exception_handler and return 500s.""" + """Every concrete exception must subclass + APIException so api_exception_handler catches it. + If this invariant breaks, those exceptions silently + fall through to general_exception_handler (500s). + """ assert issubclass(exception_class, APIException), ( f"{exception_class.__name__} is not a subclass of APIException; " f"api_exception_handler will not catch it" diff --git a/backend/tests/integration/__init__.py b/backend/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/integration/test_auth.py b/backend/tests/integration/test_auth.py new file mode 100644 index 0000000..e2c340b --- /dev/null +++ b/backend/tests/integration/test_auth.py @@ -0,0 +1,339 @@ +"""Integration tests for auth endpoints (signup, login, refresh). + +Uses Testcontainers for a real PostgreSQL instance with migrations applied. +Verifies the full HTTP round-trip including RLS context setting. +""" + +from __future__ import annotations + +import typing +import uuid +from contextlib import asynccontextmanager +from datetime import timedelta +from pathlib import Path + +import pytest +from fastapi import APIRouter as _APIRouter +from fastapi import FastAPI +from fastapi.exceptions import RequestValidationError +from fastapi.testclient import TestClient +from psycopg import Cursor, connect +from testcontainers.postgres import PostgresContainer + +from app.api.deps import ( + CurrentVendorId, + RLSCursorDep, + get_db, + get_rls_cursor, + get_settings, +) +from app.api.main import api_router +from app.core.config import Settings +from app.core.exception_handlers import ( + api_exception_handler, + general_exception_handler, + validation_exception_handler, +) +from app.core.exceptions import APIException +from app.core.security import create_access_token + + +MIGRATIONS_DIR = str(Path(__file__).parents[3] / "migrations") +API_V1 = "/api/v1" + +_test_router = _APIRouter() + + +@_test_router.get("/protected-test") +def _protected_test(vendor_id: CurrentVendorId, cursor: RLSCursorDep) -> dict: + cursor.execute("SELECT current_setting('app.vendor_id', true)") + row = cursor.fetchone() + db_vendor_id = row[0] if row else None + return {"vendor_id": vendor_id, "db_vendor_id": db_vendor_id} + + +@pytest.fixture(scope="function") +def pg_container() -> typing.Generator[PostgresContainer, None, None]: + with PostgresContainer( + "postgres:18.2-alpine3.23", driver=None + ).with_volume_mapping( + MIGRATIONS_DIR, "/docker-entrypoint-initdb.d" + ) as container: + yield container + + +@pytest.fixture(scope="function") +def test_settings() -> Settings: + return Settings( + SECRET_KEY="integration-test-secret-key-32bytes!", + PROJECT_NAME="test", + POSTGRES_SERVER="localhost", + POSTGRES_USER="test", + POSTGRES_PASSWORD="test", + POSTGRES_DB="test", + ACCESS_TOKEN_EXPIRE_MINUTES=60, + REFRESH_TOKEN_EXPIRE_DAYS=7, + ) + + +@pytest.fixture(scope="function") +def client( + pg_container: PostgresContainer, test_settings: Settings +) -> typing.Generator[TestClient, None, None]: + """TestClient with a lightweight test app (no real lifespan).""" + + @asynccontextmanager + async def _noop_lifespan(app: FastAPI): # noqa: RUF029 + yield + + test_app = FastAPI(lifespan=_noop_lifespan) + test_app.include_router(api_router, prefix=API_V1) + test_app.include_router(_test_router, prefix=API_V1) + test_app.add_exception_handler(APIException, api_exception_handler) + test_app.add_exception_handler( + RequestValidationError, validation_exception_handler + ) + test_app.add_exception_handler(Exception, general_exception_handler) + + def _override_get_db() -> typing.Generator[Cursor, None, None]: + with connect(pg_container.get_connection_url()) as conn: + with conn.cursor() as cur: + yield cur + + def _override_get_settings() -> Settings: + return test_settings + + test_app.dependency_overrides[get_db] = _override_get_db + test_app.dependency_overrides[get_settings] = _override_get_settings + + def _override_get_rls_cursor( + vendor_id: CurrentVendorId, + ) -> typing.Generator[Cursor, None, None]: + with connect(pg_container.get_connection_url()) as conn: + with conn.cursor() as cur: + cur.execute("SELECT app.set_app_context(%s)", (vendor_id,)) + yield cur + + test_app.dependency_overrides[get_rls_cursor] = _override_get_rls_cursor + + with TestClient(test_app) as tc: + yield tc + + +def _signup( + client: TestClient, + email: str = "vendor@test.com", + password: str = "SecurePass123!", +) -> dict: + return client.post( + f"{API_V1}/auth/signup", + json={ + "email": email, + "password": password, + "client_id": "integration-test", + }, + ).json() + + +def _login( + client: TestClient, + email: str = "vendor@test.com", + password: str = "SecurePass123!", +) -> dict: + return client.post( + f"{API_V1}/auth/login", + json={ + "email": email, + "password": password, + "client_id": "integration-test", + }, + ).json() + + +@pytest.mark.integration +class TestSignup: + def test_signup_creates_vendor_201(self, client: TestClient): + resp = client.post( + f"{API_V1}/auth/signup", + json={ + "email": "signup-test@example.com", + "password": "StrongPass1!", + "client_id": "test-client", + }, + ) + assert resp.status_code == 201 + body = resp.json() + assert "data" in body + assert body["data"]["vendor"]["email"] == "signup-test@example.com" + assert "id" in body["data"]["vendor"] + + def test_signup_duplicate_email_409(self, client: TestClient): + email = "dup@example.com" + # First signup succeeds + resp1 = client.post( + f"{API_V1}/auth/signup", + json={"email": email, "password": "Pass12345!", "client_id": "c1"}, + ) + assert resp1.status_code == 201 + + # Second signup with same email fails + resp2 = client.post( + f"{API_V1}/auth/signup", + json={"email": email, "password": "Pass12345!", "client_id": "c1"}, + ) + assert resp2.status_code == 409 + + def test_signup_weak_password_422(self, client: TestClient): + resp = client.post( + f"{API_V1}/auth/signup", + json={ + "email": "weak@example.com", + "password": "short", + "client_id": "c1", + }, + ) + assert resp.status_code == 422 + + def test_signup_missing_client_id_422(self, client: TestClient): + resp = client.post( + f"{API_V1}/auth/signup", + json={"email": "no-client@example.com", "password": "StrongPass1!"}, + ) + assert resp.status_code == 422 + + +@pytest.mark.integration +class TestLogin: + def test_login_returns_token_pair(self, client: TestClient): + email = "login-test@example.com" + _signup(client, email=email) + + resp = client.post( + f"{API_V1}/auth/login", + json={ + "email": email, + "password": "SecurePass123!", + "client_id": "c1", + }, + ) + assert resp.status_code == 200 + data = resp.json()["data"] + assert "access_token" in data + assert "refresh_token" in data + assert data["token_type"] == "bearer" + + def test_login_wrong_password_401(self, client: TestClient): + email = "login-fail@example.com" + _signup(client, email=email) + + resp = client.post( + f"{API_V1}/auth/login", + json={ + "email": email, + "password": "WrongPassword!", + "client_id": "c1", + }, + ) + assert resp.status_code == 401 + + def test_login_nonexistent_email_401(self, client: TestClient): + resp = client.post( + f"{API_V1}/auth/login", + json={ + "email": "nobody@example.com", + "password": "Pass12345!", + "client_id": "c1", + }, + ) + assert resp.status_code == 401 + + +@pytest.mark.integration +class TestRefresh: + def test_refresh_issues_new_tokens(self, client: TestClient): + email = "refresh-test@example.com" + _signup(client, email=email) + login_data = _login(client, email=email)["data"] + + resp = client.post( + f"{API_V1}/auth/refresh", + json={ + "refresh_token": login_data["refresh_token"], + "client_id": "c1", + }, + ) + assert resp.status_code == 200 + data = resp.json()["data"] + assert "access_token" in data + assert "refresh_token" in data + + def test_refresh_with_access_token_fails_401(self, client: TestClient): + email = "refresh-bad@example.com" + _signup(client, email=email) + login_data = _login(client, email=email)["data"] + + resp = client.post( + f"{API_V1}/auth/refresh", + json={ + "refresh_token": login_data["access_token"], # wrong token type + "client_id": "c1", + }, + ) + assert resp.status_code == 401 + + def test_refresh_with_garbage_token_fails_401(self, client: TestClient): + resp = client.post( + f"{API_V1}/auth/refresh", + json={"refresh_token": "not.a.real.token", "client_id": "c1"}, + ) + assert resp.status_code == 401 + + +@pytest.mark.integration +class TestProtectedEndpoints: + """Verify that protected endpoints reject invalid/missing tokens + via the full HTTP layer (Bearer parsing → dependency injection → + exception-to-HTTP translation). + """ + + def test_missing_token_401(self, client: TestClient): + resp = client.get(f"{API_V1}/protected-test") + assert resp.status_code == 401 + + def test_expired_token_401( + self, client: TestClient, test_settings: Settings + ): + token = create_access_token( + str(uuid.uuid4()), + test_settings, + expires_delta=timedelta(seconds=-1), + ) + resp = client.get( + f"{API_V1}/protected-test", + headers={"Authorization": f"Bearer {token}"}, + ) + assert resp.status_code == 401 + + def test_valid_token_returns_vendor_id(self, client: TestClient): + email = "protected-test@example.com" + signup_data = _signup(client, email=email) + created_vendor_id = signup_data["data"]["vendor"]["id"] + login_resp = client.post( + f"{API_V1}/auth/login", + json={ + "email": email, + "password": "SecurePass123!", + "client_id": "c1", + }, + ) + assert login_resp.status_code == 200 + token = login_resp.json()["data"]["access_token"] + + resp = client.get( + f"{API_V1}/protected-test", + headers={"Authorization": f"Bearer {token}"}, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["vendor_id"] == created_vendor_id + assert body["db_vendor_id"] == created_vendor_id diff --git a/backend/tests/schemas/test_response.py b/backend/tests/schemas/test_response.py index 655ad99..e8da933 100644 --- a/backend/tests/schemas/test_response.py +++ b/backend/tests/schemas/test_response.py @@ -41,7 +41,9 @@ def test_error_body_response_stores_all_fields(): assert error_body.http_status == status.HTTP_400_BAD_REQUEST, ( f"Expected http_status 400, got {error_body.http_status}" ) - assert error_body.details == [], f"Expected empty details, got {error_body.details}" + assert error_body.details == [], ( + f"Expected empty details, got {error_body.details}" + ) assert error_body.request_id == "req-123", ( f"Expected request_id 'req-123', got '{error_body.request_id}'" ) @@ -70,13 +72,15 @@ def test_error_body_response_stores_all_fields(): ], ) def test_error_body_response_required_field_raises_validation_error(kwargs): - """Omitting a required field on ErrorBodyResponse must raise ValidationError.""" + """Omitting a required field must raise ValidationError.""" with pytest.raises(ValidationError) as exc_info: ErrorBodyResponse(**kwargs) # Optionally verify the missing field is in the error error_fields = {e["loc"][0] for e in exc_info.value.errors()} expected_missing = {"http_status", "request_id"} - kwargs.keys() - assert expected_missing & error_fields, f"Expected error for {expected_missing}" + assert expected_missing & error_fields, ( + f"Expected error for {expected_missing}" + ) @pytest.mark.unit @@ -110,10 +114,12 @@ def test_error_response_envelope_structure(): ) ) assert isinstance(error_response.error, ErrorBodyResponse), ( - f"Expected error field to be ErrorBodyResponse, got {type(error_response.error)}" + "Expected error field to be ErrorBodyResponse," + f" got {type(error_response.error)}" ) assert error_response.error.http_status == status.HTTP_400_BAD_REQUEST, ( - f"Expected wrapped error http_status 400, got {error_response.error.http_status}" + "Expected wrapped error http_status 400," + f" got {error_response.error.http_status}" ) @@ -133,7 +139,9 @@ def test_error_detail_requires_message_field(): def test_error_detail_accepts_field_as_none(): """ErrorDetail.field is optional — None must be stored as-is.""" detail = ErrorDetail(field=None, message="Test error") - assert detail.field is None, f"Expected field to be None, got {detail.field}" + assert detail.field is None, ( + f"Expected field to be None, got {detail.field}" + ) assert detail.message == "Test error", ( f"Expected message 'Test error', got '{detail.message}'" ) @@ -143,7 +151,9 @@ def test_error_detail_accepts_field_as_none(): def test_error_detail_with_both_fields_populated(): """ErrorDetail must store both field and message when both are provided.""" detail = ErrorDetail(field="email", message="Invalid email format") - assert detail.field == "email", f"Expected field 'email', got '{detail.field}'" + assert detail.field == "email", ( + f"Expected field 'email', got '{detail.field}'" + ) assert detail.message == "Invalid email format", ( f"Expected message 'Invalid email format', got '{detail.message}'" ) @@ -174,7 +184,9 @@ def test_error_body_response_rejects_invalid_error_code(): @pytest.mark.unit def test_error_response_requires_error_field(): - """ErrorResponse.error is a required field (Field(...)). Omitting it must - raise a ValidationError rather than silently constructing an empty envelope.""" + """ErrorResponse.error is required. Omitting it must + raise a ValidationError rather than silently + constructing an empty envelope. + """ with pytest.raises(ValidationError): ErrorResponse() diff --git a/backend/tests/unit/__init__.py b/backend/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/unit/test_auth.py b/backend/tests/unit/test_auth.py new file mode 100644 index 0000000..8697d20 --- /dev/null +++ b/backend/tests/unit/test_auth.py @@ -0,0 +1,307 @@ +"""Unit tests for JWT-based vendor authentication. + +Tests cover: +- Password hashing (bcrypt) +- Token creation and decoding +- Auth service logic (signup, login, refresh) +- Auth dependency (JWT validation) +""" + +from __future__ import annotations + +import uuid +from datetime import timedelta +from unittest.mock import MagicMock, patch + +import jwt as pyjwt +import pytest + +from app.api.deps import get_current_vendor_id +from app.core.config import Settings +from app.core.exceptions import AuthenticationException, ConflictException +from app.core.security import ( + create_access_token, + create_refresh_token, + decode_token, + get_password_hash, + verify_password, +) +from app.services.auth import login, refresh, signup + + +@pytest.fixture +def settings() -> Settings: + """Minimal settings for security tests.""" + return Settings( + SECRET_KEY="test-secret-key-for-unit-tests", + PROJECT_NAME="test", + POSTGRES_SERVER="localhost", + POSTGRES_USER="test", + POSTGRES_PASSWORD="test", + POSTGRES_DB="test", + ACCESS_TOKEN_EXPIRE_MINUTES=60, + REFRESH_TOKEN_EXPIRE_DAYS=7, + ) + + +@pytest.mark.unit +class TestPasswordHashing: + def test_hash_and_verify(self): + plain = "SuperSecret123!" + hashed = get_password_hash(plain) + + assert hashed != plain + assert hashed.startswith("$2") # bcrypt prefix + valid, _ = verify_password(plain, hashed) + assert valid is True + + def test_wrong_password_fails(self): + hashed = get_password_hash("correct-password") + valid, _ = verify_password("wrong-password", hashed) + assert valid is False + + +@pytest.mark.unit +class TestTokens: + def test_access_token_claims(self, settings: Settings): + vendor_id = str(uuid.uuid4()) + token = create_access_token(vendor_id, settings) + + payload = decode_token(token, settings) + assert payload["vendor_id"] == vendor_id + assert payload["token_type"] == "access" + assert "exp" in payload + + def test_refresh_token_claims(self, settings: Settings): + vendor_id = str(uuid.uuid4()) + token = create_refresh_token(vendor_id, settings) + + payload = decode_token(token, settings) + assert payload["vendor_id"] == vendor_id + assert payload["token_type"] == "refresh" + assert "exp" in payload + + def test_expired_token_raises(self, settings: Settings): + vendor_id = str(uuid.uuid4()) + token = create_access_token( + vendor_id, settings, expires_delta=timedelta(seconds=-1) + ) + + with pytest.raises(pyjwt.ExpiredSignatureError): + decode_token(token, settings) + + def test_invalid_signature_raises(self, settings: Settings): + vendor_id = str(uuid.uuid4()) + token = create_access_token(vendor_id, settings) + + bad_settings = Settings( + SECRET_KEY="different-secret", + PROJECT_NAME="test", + POSTGRES_SERVER="localhost", + POSTGRES_USER="test", + POSTGRES_PASSWORD="test", + POSTGRES_DB="test", + ) + with pytest.raises(pyjwt.InvalidSignatureError): + decode_token(token, bad_settings) + + +@pytest.mark.unit +class TestAuthService: + """Test auth service with mocked CRUD layer.""" + + def test_signup_success(self, settings: Settings): + cursor = MagicMock() + vendor_id = str(uuid.uuid4()) + + with ( + patch("app.services.auth.get_vendor_by_email", return_value=None), + patch( + "app.services.auth.create_vendor", + return_value={"id": vendor_id, "email": "v@test.com"}, + ), + ): + result = signup( + cursor, "v@test.com", "password123", "client-1", settings + ) + + assert result.vendor.id == vendor_id + assert result.vendor.email == "v@test.com" + + def test_signup_duplicate_email_raises(self, settings: Settings): + cursor = MagicMock() + + with patch( + "app.services.auth.get_vendor_by_email", + return_value={ + "id": "x", + "email": "v@test.com", + "password_hash": "h", + }, + ): + with pytest.raises(ConflictException): + signup( + cursor, "v@test.com", "password123", "client-1", settings + ) + + def test_login_success(self, settings: Settings): + cursor = MagicMock() + vendor_id = str(uuid.uuid4()) + hashed = get_password_hash("password123") + + with patch( + "app.services.auth.get_vendor_by_email", + return_value={ + "id": vendor_id, + "email": "v@test.com", + "password_hash": hashed, + }, + ): + result = login( + cursor, "v@test.com", "password123", "client-1", settings + ) + + assert result.access_token + assert result.refresh_token + assert result.token_type == "bearer" + + payload = decode_token(result.access_token, settings) + assert payload["vendor_id"] == vendor_id + assert payload["token_type"] == "access" + + def test_login_wrong_email_raises(self, settings: Settings): + cursor = MagicMock() + + with patch("app.services.auth.get_vendor_by_email", return_value=None): + with pytest.raises(AuthenticationException): + login( + cursor, "bad@test.com", "password123", "client-1", settings + ) + + def test_login_wrong_password_raises(self, settings: Settings): + cursor = MagicMock() + hashed = get_password_hash("correct-password") + + with patch( + "app.services.auth.get_vendor_by_email", + return_value={ + "id": "x", + "email": "v@test.com", + "password_hash": hashed, + }, + ): + with pytest.raises(AuthenticationException): + login( + cursor, "v@test.com", "wrong-password", "client-1", settings + ) + + def test_refresh_success(self, settings: Settings): + cursor = MagicMock() + vendor_id = str(uuid.uuid4()) + rt = create_refresh_token(vendor_id, settings) + + with patch( + "app.services.auth.get_vendor_by_id", + return_value={"id": vendor_id, "email": "v@test.com"}, + ): + result = refresh(rt, "client-1", cursor, settings) + + assert result.access_token + assert result.refresh_token + + def test_refresh_with_access_token_raises(self, settings: Settings): + cursor = MagicMock() + vendor_id = str(uuid.uuid4()) + at = create_access_token(vendor_id, settings) + + with pytest.raises(AuthenticationException, match="Invalid token type"): + refresh(at, "client-1", cursor, settings) + + def test_refresh_expired_raises(self, settings: Settings): + cursor = MagicMock() + vendor_id = str(uuid.uuid4()) + rt = create_refresh_token( + vendor_id, settings, expires_delta=timedelta(seconds=-1) + ) + + with pytest.raises(AuthenticationException): + refresh(rt, "client-1", cursor, settings) + + def test_refresh_deleted_vendor_raises(self, settings: Settings): + cursor = MagicMock() + vendor_id = str(uuid.uuid4()) + rt = create_refresh_token(vendor_id, settings) + + with patch("app.services.auth.get_vendor_by_id", return_value=None): + with pytest.raises( + AuthenticationException, match="Vendor not found" + ): + refresh(rt, "client-1", cursor, settings) + + def test_signup_concurrent_insert_conflict(self, settings: Settings): + """Pre-read shows no existing vendor, but the insert collides + (create_vendor returns None due to ON CONFLICT DO NOTHING). + """ + cursor = MagicMock() + + with ( + patch("app.services.auth.get_vendor_by_email", return_value=None), + patch("app.services.auth.create_vendor", return_value=None), + ): + with pytest.raises(ConflictException): + signup( + cursor, "race@test.com", "password123", "client-1", settings + ) + + +@pytest.mark.unit +class TestGetCurrentVendorId: + def test_valid_access_token(self, settings: Settings): + vendor_id = str(uuid.uuid4()) + token = create_access_token(vendor_id, settings) + creds = MagicMock() + creds.credentials = token + + result = get_current_vendor_id(creds, settings) + assert result == vendor_id + + def test_missing_credentials_raises(self, settings: Settings): + with pytest.raises(AuthenticationException, match="Missing"): + get_current_vendor_id(None, settings) + + def test_refresh_token_rejected(self, settings: Settings): + vendor_id = str(uuid.uuid4()) + token = create_refresh_token(vendor_id, settings) + creds = MagicMock() + creds.credentials = token + + with pytest.raises(AuthenticationException, match="Invalid token type"): + get_current_vendor_id(creds, settings) + + def test_expired_token_raises(self, settings: Settings): + vendor_id = str(uuid.uuid4()) + token = create_access_token( + vendor_id, settings, expires_delta=timedelta(seconds=-1) + ) + creds = MagicMock() + creds.credentials = token + + with pytest.raises(AuthenticationException): + get_current_vendor_id(creds, settings) + + def test_garbage_token_raises(self, settings: Settings): + creds = MagicMock() + creds.credentials = "not.a.jwt" + + with pytest.raises(AuthenticationException): + get_current_vendor_id(creds, settings) + + def test_malformed_vendor_id_raises(self, settings: Settings): + token = create_access_token("invalid-uuid", settings) + creds = MagicMock() + creds.credentials = token + + with pytest.raises( + AuthenticationException, match="Invalid token payload" + ): + get_current_vendor_id(creds, settings) diff --git a/uv.lock b/uv.lock index 9da86ba..6652a70 100644 --- a/uv.lock +++ b/uv.lock @@ -959,6 +959,7 @@ dependencies = [ { name = "pwdlib", extra = ["argon2", "bcrypt"] }, { name = "pydantic" }, { name = "pydantic-settings" }, + { name = "pyjwt" }, { name = "python-multipart" }, { name = "tenacity" }, ] @@ -982,6 +983,7 @@ requires-dist = [ { name = "pwdlib", extras = ["argon2", "bcrypt"], specifier = ">=0.3.0" }, { name = "pydantic", specifier = ">2.0" }, { name = "pydantic-settings", specifier = ">=2.2.1,<3.0.0" }, + { name = "pyjwt", specifier = ">=2.11.0,<3.0.0" }, { name = "python-multipart", specifier = ">=0.0.7,<1.0.0" }, { name = "tenacity", specifier = ">=8.2.3,<9.0.0" }, ] @@ -1321,6 +1323,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, ] +[[package]] +name = "pyjwt" +version = "2.11.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5c/5a/b46fa56bf322901eee5b0454a34343cdbdae202cd421775a8ee4e42fd519/pyjwt-2.11.0.tar.gz", hash = "sha256:35f95c1f0fbe5d5ba6e43f00271c275f7a1a4db1dab27bf708073b75318ea623", size = 98019, upload-time = "2026-01-30T19:59:55.694Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6f/01/c26ce75ba460d5cd503da9e13b21a33804d38c2165dec7b716d06b13010c/pyjwt-2.11.0-py3-none-any.whl", hash = "sha256:94a6bde30eb5c8e04fee991062b534071fd1439ef58d2adc9ccb823e7bcd0469", size = 28224, upload-time = "2026-01-30T19:59:54.539Z" }, +] + [[package]] name = "pytest" version = "7.4.4"