From b1a078e8e392318753dd6a99ab71536c414ba3b8 Mon Sep 17 00:00:00 2001 From: saudzahirr Date: Thu, 5 Mar 2026 10:40:20 +0500 Subject: [PATCH 1/7] Implement JWT authentication with signup, login, and token refresh endpoints --- backend/app/api/deps.py | 61 +++- backend/app/api/main.py | 3 +- backend/app/api/routes/auth.py | 64 ++++ backend/app/core/config.py | 5 +- backend/app/core/security.py | 43 ++- backend/app/crud/{.gitkeep => __init__.py} | 0 backend/app/crud/vendor.py | 49 +++ backend/app/schemas/auth.py | 39 +++ .../app/services/{.gitkeep => __init__.py} | 0 backend/app/services/auth.py | 113 +++++++ backend/pyproject.toml | 1 + backend/tests/integration/test_auth.py | 268 +++++++++++++++ backend/tests/unit/test_auth.py | 304 ++++++++++++++++++ uv.lock | 11 + 14 files changed, 946 insertions(+), 15 deletions(-) create mode 100644 backend/app/api/routes/auth.py rename backend/app/crud/{.gitkeep => __init__.py} (100%) create mode 100644 backend/app/crud/vendor.py create mode 100644 backend/app/schemas/auth.py rename backend/app/services/{.gitkeep => __init__.py} (100%) create mode 100644 backend/app/services/auth.py create mode 100644 backend/tests/integration/test_auth.py create mode 100644 backend/tests/unit/test_auth.py diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index a07bab4..629a28a 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -1,11 +1,17 @@ from collections.abc import Generator from typing import Annotated +import jwt as pyjwt from fastapi import Depends, Request +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from psycopg import Cursor from app.core.config import Settings -from app.core.exceptions import ServiceUnavailableException +from app.core.exceptions import AuthenticationException, ServiceUnavailableException +from app.core.security import decode_token + +# ── Re-usable bearer scheme (auto-documents in OpenAPI) ────── +_bearer_scheme = HTTPBearer(auto_error=False) def get_db(request: Request) -> Generator[Cursor, None, None]: @@ -28,3 +34,56 @@ def get_settings(request: Request) -> Settings: CursorDep = Annotated[Cursor, Depends(get_db)] SettingsDep = Annotated[Settings, Depends(get_settings)] + + +def get_current_vendor_id( + credentials: Annotated[ + HTTPAuthorizationCredentials | None, Depends(_bearer_scheme) + ], + settings: SettingsDep, +) -> str: + """Extract and validate vendor_id from the Authorization: Bearer token. + + Raises AuthenticationException on missing / invalid / expired tokens + or if the token is not an access token. + """ + if credentials is None: + raise AuthenticationException("Missing authentication token") + + try: + payload = decode_token(credentials.credentials, settings) + except pyjwt.PyJWTError: + raise AuthenticationException("Invalid or expired token") + + if payload.get("token_type") != "access": + raise AuthenticationException("Invalid token type") + + vendor_id: str | None = payload.get("vendor_id") + if vendor_id is None: + raise AuthenticationException("Invalid token payload") + + return vendor_id + + +def get_rls_cursor( + request: Request, + vendor_id: Annotated[str, Depends(get_current_vendor_id)], +) -> Generator[Cursor, None, None]: + """Return a database cursor with app.vendor_id set for RLS. + + After authentication, this dependency: + 1. Obtains a connection from the pool + 2. Calls app.set_app_context(vendor_id) to set the RLS context + 3. Yields the cursor for use in route handlers + """ + pool = getattr(request.app.state, "db_pool", None) + if pool is None: + raise ServiceUnavailableException("Database pool not initialized") + with pool.connection() as conn: + with conn.cursor() as cursor: + cursor.execute("SELECT app.set_app_context(%s)", (vendor_id,)) + yield cursor + + +CurrentVendorId = Annotated[str, Depends(get_current_vendor_id)] +RLSCursorDep = Annotated[Cursor, Depends(get_rls_cursor)] diff --git a/backend/app/api/main.py b/backend/app/api/main.py index 12ab2ae..1bb1833 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -1,7 +1,8 @@ from fastapi import APIRouter -from app.api.routes import health, login +from app.api.routes import auth, health, login api_router = APIRouter() api_router.include_router(health.router, tags=["health"]) api_router.include_router(login.router, prefix="/login", tags=["login"]) +api_router.include_router(auth.router, prefix="/auth", tags=["auth"]) diff --git a/backend/app/api/routes/auth.py b/backend/app/api/routes/auth.py new file mode 100644 index 0000000..7e0b59c --- /dev/null +++ b/backend/app/api/routes/auth.py @@ -0,0 +1,64 @@ +"""Auth routes — signup, login, token refresh.""" + +from fastapi import APIRouter, status + +from app.api.deps import CursorDep, SettingsDep +from app.schemas.auth import ( + LoginRequest, + RefreshRequest, + SignupRequest, + SignupResponse, + TokenPair, +) +from app.schemas.response import SuccessResponse +from app.services import auth as auth_service + +router = APIRouter() + + +@router.post( + "/signup", + status_code=status.HTTP_201_CREATED, + response_model=SuccessResponse[SignupResponse], +) +def signup(body: SignupRequest, cursor: CursorDep, settings: SettingsDep): + """Create a new vendor account.""" + result = auth_service.signup( + cursor=cursor, + email=body.email, + password=body.password, + client_id=body.client_id, + settings=settings, + ) + return SuccessResponse(data=result) + + +@router.post( + "/login", + response_model=SuccessResponse[TokenPair], +) +def login(body: LoginRequest, cursor: CursorDep, settings: SettingsDep): + """Authenticate a vendor and return an access/refresh token pair.""" + result = auth_service.login( + cursor=cursor, + email=body.email, + password=body.password, + client_id=body.client_id, + settings=settings, + ) + return SuccessResponse(data=result) + + +@router.post( + "/refresh", + response_model=SuccessResponse[TokenPair], +) +def refresh(body: RefreshRequest, cursor: CursorDep, settings: SettingsDep): + """Issue a new token pair using a valid refresh token.""" + result = auth_service.refresh( + refresh_token_str=body.refresh_token, + client_id=body.client_id, + cursor=cursor, + settings=settings, + ) + return SuccessResponse(data=result) diff --git a/backend/app/core/config.py b/backend/app/core/config.py index f25f85a..22971f1 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -13,8 +13,9 @@ class Settings(BaseSettings): ) SECRET_KEY: str - # 60 minutes * 24 hours = 1 day (configurable via ACCESS_TOKEN_EXPIRE_MINUTES env var) - ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 + # JWT token lifetimes + ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 # 1 hour + REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # 7 days PROJECT_NAME: str POSTGRES_SERVER: str diff --git a/backend/app/core/security.py b/backend/app/core/security.py index bb587e4..9858f12 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -3,29 +3,50 @@ import jwt from pwdlib import PasswordHash -from pwdlib.hashers.argon2 import Argon2Hasher from pwdlib.hashers.bcrypt import BcryptHasher from app.core.config import Settings -password_hash = PasswordHash( - ( - Argon2Hasher(), - BcryptHasher(), - ) -) +# Use bcrypt only as required by spec +password_hash = PasswordHash((BcryptHasher(),)) ALGORITHM = "HS256" def create_access_token( - subject: str | Any, expires_delta: timedelta, settings: Settings + vendor_id: str, settings: Settings, *, expires_delta: timedelta | None = None ) -> str: + """Create a short-lived access token with vendor_id claim.""" + if expires_delta is None: + expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) expire = datetime.now(timezone.utc) + expires_delta - to_encode = {"exp": expire, "sub": str(subject)} - encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) - return encoded_jwt + to_encode = { + "vendor_id": str(vendor_id), + "exp": expire, + "token_type": "access", + } + return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) + + +def create_refresh_token( + vendor_id: str, settings: Settings, *, expires_delta: timedelta | None = None +) -> str: + """Create a long-lived refresh token with vendor_id claim.""" + if expires_delta is None: + expires_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) + expire = datetime.now(timezone.utc) + expires_delta + to_encode = { + "vendor_id": str(vendor_id), + "exp": expire, + "token_type": "refresh", + } + return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) + + +def decode_token(token: str, settings: Settings) -> dict[str, Any]: + """Decode and validate a JWT token. Raises jwt.PyJWTError on failure.""" + return jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) def verify_password( diff --git a/backend/app/crud/.gitkeep b/backend/app/crud/__init__.py similarity index 100% rename from backend/app/crud/.gitkeep rename to backend/app/crud/__init__.py diff --git a/backend/app/crud/vendor.py b/backend/app/crud/vendor.py new file mode 100644 index 0000000..dc635af --- /dev/null +++ b/backend/app/crud/vendor.py @@ -0,0 +1,49 @@ +"""CRUD operations for the vendors table (raw psycopg).""" + +from __future__ import annotations + +from typing import Any + +from psycopg import Cursor + + +def get_vendor_by_email(cursor: Cursor, email: str) -> dict[str, Any] | None: + """Return a vendor row by email (case-insensitive) or None.""" + cursor.execute( + 'SELECT "id", "email", "password_hash" ' + 'FROM app."vendors" ' + 'WHERE LOWER("email") = LOWER(%s) ' + 'AND "deleted_at" IS NULL', + (email,), + ) + row = cursor.fetchone() + if row is None: + return None + return {"id": str(row[0]), "email": row[1], "password_hash": row[2]} + + +def get_vendor_by_id(cursor: Cursor, vendor_id: str) -> dict[str, Any] | None: + """Return a vendor row by id or None.""" + cursor.execute( + 'SELECT "id", "email" ' + 'FROM app."vendors" ' + 'WHERE "id" = %s AND "deleted_at" IS NULL', + (vendor_id,), + ) + row = cursor.fetchone() + if row is None: + return None + return {"id": str(row[0]), "email": row[1]} + + +def create_vendor(cursor: Cursor, email: str, password_hash: str) -> dict[str, Any]: + """Insert a new vendor and return the created row.""" + cursor.execute( + 'INSERT INTO app."vendors" ("email", "password_hash") ' + "VALUES (%s, %s) " + 'RETURNING "id", "email"', + (email, password_hash), + ) + row = cursor.fetchone() + assert row is not None + return {"id": str(row[0]), "email": row[1]} diff --git a/backend/app/schemas/auth.py b/backend/app/schemas/auth.py new file mode 100644 index 0000000..d52ac9d --- /dev/null +++ b/backend/app/schemas/auth.py @@ -0,0 +1,39 @@ +from pydantic import BaseModel, EmailStr, Field + + +# ── Request schemas ────────────────────────────────────────── + + +class SignupRequest(BaseModel): + email: EmailStr + password: str = Field(..., min_length=8, max_length=128) + client_id: str = Field(..., min_length=1, max_length=256) + + +class LoginRequest(BaseModel): + email: EmailStr + password: str + client_id: str = Field(..., min_length=1, max_length=256) + + +class RefreshRequest(BaseModel): + refresh_token: str + client_id: str = Field(..., min_length=1, max_length=256) + + +# ── Response schemas ───────────────────────────────────────── + + +class TokenPair(BaseModel): + access_token: str + refresh_token: str + token_type: str = "bearer" + + +class VendorOut(BaseModel): + id: str + email: str + + +class SignupResponse(BaseModel): + vendor: VendorOut diff --git a/backend/app/services/.gitkeep b/backend/app/services/__init__.py similarity index 100% rename from backend/app/services/.gitkeep rename to backend/app/services/__init__.py diff --git a/backend/app/services/auth.py b/backend/app/services/auth.py new file mode 100644 index 0000000..4f7421b --- /dev/null +++ b/backend/app/services/auth.py @@ -0,0 +1,113 @@ +"""Auth service — orchestrates signup, login, and token refresh.""" + +from __future__ import annotations + +import logging + +import jwt as pyjwt +from psycopg import Cursor + +from app.core.config import Settings +from app.core.exceptions import AuthenticationException, ConflictException +from app.core.security import ( + create_access_token, + create_refresh_token, + decode_token, + get_password_hash, + verify_password, +) +from app.crud.vendor import create_vendor, get_vendor_by_email, get_vendor_by_id +from app.schemas.auth import SignupResponse, TokenPair, VendorOut + +logger = logging.getLogger(__name__) + + +def signup( + cursor: Cursor, + email: str, + password: str, + client_id: str, + settings: Settings, +) -> SignupResponse: + """Create a new vendor account. + + Raises ConflictException if the email already exists. + """ + existing = get_vendor_by_email(cursor, email) + if existing is not None: + raise ConflictException("A vendor with this email already exists") + + hashed = get_password_hash(password) + vendor = create_vendor(cursor, email, hashed) + + logger.info("Vendor created: %s (client_id=%s)", vendor["id"], client_id) + return SignupResponse(vendor=VendorOut(**vendor)) + + +def login( + cursor: Cursor, + email: str, + password: str, + client_id: str, + settings: Settings, +) -> TokenPair: + """Authenticate a vendor and return an access/refresh token pair. + + Raises AuthenticationException for invalid credentials. + """ + vendor = get_vendor_by_email(cursor, email) + if vendor is None: + raise AuthenticationException() + + valid, updated_hash = verify_password(password, vendor["password_hash"]) + if not valid: + raise AuthenticationException() + + # If the hashing library returned an upgraded hash, persist it + if updated_hash is not None: + cursor.execute( + 'UPDATE app."vendors" SET "password_hash" = %s, "updated_at" = NOW() ' + 'WHERE "id" = %s', + (updated_hash, vendor["id"]), + ) + + access_token = create_access_token(vendor["id"], settings) + refresh_token = create_refresh_token(vendor["id"], settings) + + logger.info("Vendor logged in: %s (client_id=%s)", vendor["id"], client_id) + return TokenPair(access_token=access_token, refresh_token=refresh_token) + + +def refresh( + refresh_token_str: str, + client_id: str, + cursor: Cursor, + settings: Settings, +) -> TokenPair: + """Issue a new token pair from a valid refresh token. + + Raises AuthenticationException if the token is invalid/expired + or if it is not a refresh token. + """ + try: + payload = decode_token(refresh_token_str, settings) + except pyjwt.PyJWTError: + raise AuthenticationException("Invalid or expired refresh token") + + if payload.get("token_type") != "refresh": + raise AuthenticationException("Invalid token type") + + vendor_id = payload.get("vendor_id") + if vendor_id is None: + raise AuthenticationException("Invalid token payload") + + # Ensure vendor still exists and is not deleted + vendor = get_vendor_by_id(cursor, vendor_id) + if vendor is None: + raise AuthenticationException("Vendor not found") + + access_token = create_access_token(vendor_id, settings) + new_refresh_token = create_refresh_token(vendor_id, settings) + + logger.info("Token refreshed for vendor: %s (client_id=%s)", vendor_id, client_id) + return TokenPair(access_token=access_token, refresh_token=new_refresh_token) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 2043866..1cbd8ab 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "pydantic-settings<3.0.0,>=2.2.1", "pwdlib[argon2,bcrypt]>=0.3.0", "psycopg-pool>=3.3.0,<4.0.0", + "pyjwt>=2.8.0,<3.0.0", ] [dependency-groups] diff --git a/backend/tests/integration/test_auth.py b/backend/tests/integration/test_auth.py new file mode 100644 index 0000000..0793c2b --- /dev/null +++ b/backend/tests/integration/test_auth.py @@ -0,0 +1,268 @@ +"""Integration tests for auth endpoints (signup, login, refresh). + +Uses Testcontainers for a real PostgreSQL instance with migrations applied. +Verifies the full HTTP round-trip including RLS context setting. +""" + +from __future__ import annotations + +import typing +from contextlib import asynccontextmanager +from pathlib import Path + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from psycopg import Cursor, connect +from testcontainers.postgres import PostgresContainer + +from app.api.deps import get_db, get_settings +from app.api.main import api_router +from app.core.config import Settings +from app.core.exception_handlers import ( + api_exception_handler, + general_exception_handler, + validation_exception_handler, +) +from app.core.exceptions import APIException +from fastapi.exceptions import RequestValidationError + +MIGRATIONS_DIR = str(Path(__file__).parents[3] / "migrations") +API_V1 = "/api/v1" + + +# ── Module-scoped fixtures ────────────────────────────────── + + +@pytest.fixture(scope="module") +def pg_container() -> typing.Generator[PostgresContainer, None, None]: + with PostgresContainer("postgres:18.2-alpine3.23", driver=None).with_volume_mapping( + MIGRATIONS_DIR, "/docker-entrypoint-initdb.d" + ) as container: + yield container + + +@pytest.fixture(scope="module") +def test_settings() -> Settings: + return Settings( + SECRET_KEY="integration-test-secret-key-32bytes!", + PROJECT_NAME="test", + POSTGRES_SERVER="localhost", + POSTGRES_USER="test", + POSTGRES_PASSWORD="test", + POSTGRES_DB="test", + ACCESS_TOKEN_EXPIRE_MINUTES=60, + REFRESH_TOKEN_EXPIRE_DAYS=7, + ) + + +@pytest.fixture(scope="module") +def client( + pg_container: PostgresContainer, test_settings: Settings +) -> typing.Generator[TestClient, None, None]: + """TestClient with a lightweight test app (no real lifespan).""" + + @asynccontextmanager + async def _noop_lifespan(app: FastAPI): + yield + + test_app = FastAPI(lifespan=_noop_lifespan) + test_app.include_router(api_router, prefix=API_V1) + test_app.add_exception_handler(APIException, api_exception_handler) + test_app.add_exception_handler(RequestValidationError, validation_exception_handler) + test_app.add_exception_handler(Exception, general_exception_handler) + + def _override_get_db() -> typing.Generator[Cursor, None, None]: + with connect(pg_container.get_connection_url()) as conn: + with conn.cursor() as cur: + yield cur + + def _override_get_settings() -> Settings: + return test_settings + + test_app.dependency_overrides[get_db] = _override_get_db + test_app.dependency_overrides[get_settings] = _override_get_settings + + with TestClient(test_app) as tc: + yield tc + + +# ── Helpers ────────────────────────────────────────────────── + + +def _signup( + client: TestClient, email: str = "vendor@test.com", password: str = "SecurePass123!" +) -> dict: + return client.post( + f"{API_V1}/auth/signup", + json={"email": email, "password": password, "client_id": "integration-test"}, + ).json() + + +def _login( + client: TestClient, email: str = "vendor@test.com", password: str = "SecurePass123!" +) -> dict: + return client.post( + f"{API_V1}/auth/login", + json={"email": email, "password": password, "client_id": "integration-test"}, + ).json() + + +# ── Tests ──────────────────────────────────────────────────── + + +@pytest.mark.integration +class TestSignup: + def test_signup_creates_vendor_201(self, client: TestClient): + resp = client.post( + f"{API_V1}/auth/signup", + json={ + "email": "signup-test@example.com", + "password": "StrongPass1!", + "client_id": "test-client", + }, + ) + assert resp.status_code == 201 + body = resp.json() + assert "data" in body + assert body["data"]["vendor"]["email"] == "signup-test@example.com" + assert "id" in body["data"]["vendor"] + + def test_signup_duplicate_email_409(self, client: TestClient): + email = "dup@example.com" + # First signup succeeds + resp1 = client.post( + f"{API_V1}/auth/signup", + json={"email": email, "password": "Pass12345!", "client_id": "c1"}, + ) + assert resp1.status_code == 201 + + # Second signup with same email fails + resp2 = client.post( + f"{API_V1}/auth/signup", + json={"email": email, "password": "Pass12345!", "client_id": "c1"}, + ) + assert resp2.status_code == 409 + + def test_signup_weak_password_422(self, client: TestClient): + resp = client.post( + f"{API_V1}/auth/signup", + json={"email": "weak@example.com", "password": "short", "client_id": "c1"}, + ) + assert resp.status_code == 422 + + def test_signup_missing_client_id_422(self, client: TestClient): + resp = client.post( + f"{API_V1}/auth/signup", + json={"email": "no-client@example.com", "password": "StrongPass1!"}, + ) + assert resp.status_code == 422 + + +@pytest.mark.integration +class TestLogin: + def test_login_returns_token_pair(self, client: TestClient): + email = "login-test@example.com" + _signup(client, email=email) + + resp = client.post( + f"{API_V1}/auth/login", + json={"email": email, "password": "SecurePass123!", "client_id": "c1"}, + ) + assert resp.status_code == 200 + data = resp.json()["data"] + assert "access_token" in data + assert "refresh_token" in data + assert data["token_type"] == "bearer" + + def test_login_wrong_password_401(self, client: TestClient): + email = "login-fail@example.com" + _signup(client, email=email) + + resp = client.post( + f"{API_V1}/auth/login", + json={"email": email, "password": "WrongPassword!", "client_id": "c1"}, + ) + assert resp.status_code == 401 + + def test_login_nonexistent_email_401(self, client: TestClient): + resp = client.post( + f"{API_V1}/auth/login", + json={ + "email": "nobody@example.com", + "password": "Pass12345!", + "client_id": "c1", + }, + ) + assert resp.status_code == 401 + + +@pytest.mark.integration +class TestRefresh: + def test_refresh_issues_new_tokens(self, client: TestClient): + email = "refresh-test@example.com" + _signup(client, email=email) + login_data = _login(client, email=email)["data"] + + resp = client.post( + f"{API_V1}/auth/refresh", + json={ + "refresh_token": login_data["refresh_token"], + "client_id": "c1", + }, + ) + assert resp.status_code == 200 + data = resp.json()["data"] + assert "access_token" in data + assert "refresh_token" in data + + def test_refresh_with_access_token_fails_401(self, client: TestClient): + email = "refresh-bad@example.com" + _signup(client, email=email) + login_data = _login(client, email=email)["data"] + + resp = client.post( + f"{API_V1}/auth/refresh", + json={ + "refresh_token": login_data["access_token"], # wrong token type + "client_id": "c1", + }, + ) + assert resp.status_code == 401 + + def test_refresh_with_garbage_token_fails_401(self, client: TestClient): + resp = client.post( + f"{API_V1}/auth/refresh", + json={"refresh_token": "not.a.real.token", "client_id": "c1"}, + ) + assert resp.status_code == 401 + + +@pytest.mark.integration +class TestProtectedEndpoints: + """Verify that protected endpoints reject invalid/missing tokens.""" + + def test_missing_token_401(self, client: TestClient): + """A request without Authorization header should be rejected. + + We use the health endpoint here as a baseline; in a real app + you would test an endpoint that uses CurrentVendorId dependency. + This test validates that the dependency itself rejects missing tokens. + """ + from app.api.deps import get_current_vendor_id + from app.core.exceptions import AuthenticationException + + # Directly test the dependency + from app.core.config import Settings + + settings = Settings( + SECRET_KEY="integration-test-secret", + PROJECT_NAME="test", + POSTGRES_SERVER="localhost", + POSTGRES_USER="test", + POSTGRES_PASSWORD="test", + POSTGRES_DB="test", + ) + + with pytest.raises(AuthenticationException, match="Missing"): + get_current_vendor_id(None, settings) diff --git a/backend/tests/unit/test_auth.py b/backend/tests/unit/test_auth.py new file mode 100644 index 0000000..e1d369f --- /dev/null +++ b/backend/tests/unit/test_auth.py @@ -0,0 +1,304 @@ +"""Unit tests for JWT-based vendor authentication. + +Tests cover: +- Password hashing (bcrypt) +- Token creation and decoding +- Auth service logic (signup, login, refresh) +- Auth dependency (JWT validation) +""" + +from __future__ import annotations + +import uuid +from datetime import timedelta +from unittest.mock import MagicMock, patch + +import jwt as pyjwt +import pytest + +from app.core.config import Settings +from app.core.exceptions import AuthenticationException, ConflictException +from app.core.security import ( + create_access_token, + create_refresh_token, + decode_token, + get_password_hash, + verify_password, +) + + +# ── Fixture: test settings ────────────────────────────────── + + +@pytest.fixture +def settings() -> Settings: + """Minimal settings for security tests.""" + return Settings( + SECRET_KEY="test-secret-key-for-unit-tests", + PROJECT_NAME="test", + POSTGRES_SERVER="localhost", + POSTGRES_USER="test", + POSTGRES_PASSWORD="test", + POSTGRES_DB="test", + ACCESS_TOKEN_EXPIRE_MINUTES=60, + REFRESH_TOKEN_EXPIRE_DAYS=7, + ) + + +# ── Password hashing ──────────────────────────────────────── + + +@pytest.mark.unit +class TestPasswordHashing: + def test_hash_and_verify(self): + plain = "SuperSecret123!" + hashed = get_password_hash(plain) + + assert hashed != plain + assert hashed.startswith("$2") # bcrypt prefix + valid, _ = verify_password(plain, hashed) + assert valid is True + + def test_wrong_password_fails(self): + hashed = get_password_hash("correct-password") + valid, _ = verify_password("wrong-password", hashed) + assert valid is False + + +# ── Token creation / decoding ──────────────────────────────── + + +@pytest.mark.unit +class TestTokens: + def test_access_token_claims(self, settings: Settings): + vendor_id = str(uuid.uuid4()) + token = create_access_token(vendor_id, settings) + + payload = decode_token(token, settings) + assert payload["vendor_id"] == vendor_id + assert payload["token_type"] == "access" + assert "exp" in payload + + def test_refresh_token_claims(self, settings: Settings): + vendor_id = str(uuid.uuid4()) + token = create_refresh_token(vendor_id, settings) + + payload = decode_token(token, settings) + assert payload["vendor_id"] == vendor_id + assert payload["token_type"] == "refresh" + assert "exp" in payload + + def test_expired_token_raises(self, settings: Settings): + vendor_id = str(uuid.uuid4()) + token = create_access_token( + vendor_id, settings, expires_delta=timedelta(seconds=-1) + ) + + with pytest.raises(pyjwt.ExpiredSignatureError): + decode_token(token, settings) + + def test_invalid_signature_raises(self, settings: Settings): + vendor_id = str(uuid.uuid4()) + token = create_access_token(vendor_id, settings) + + bad_settings = Settings( + SECRET_KEY="different-secret", + PROJECT_NAME="test", + POSTGRES_SERVER="localhost", + POSTGRES_USER="test", + POSTGRES_PASSWORD="test", + POSTGRES_DB="test", + ) + with pytest.raises(pyjwt.InvalidSignatureError): + decode_token(token, bad_settings) + + +# ── Auth service ───────────────────────────────────────────── + + +@pytest.mark.unit +class TestAuthService: + """Test auth service with mocked CRUD layer.""" + + def test_signup_success(self, settings: Settings): + from app.services.auth import signup + + cursor = MagicMock() + vendor_id = str(uuid.uuid4()) + + with ( + patch("app.services.auth.get_vendor_by_email", return_value=None), + patch( + "app.services.auth.create_vendor", + return_value={"id": vendor_id, "email": "v@test.com"}, + ), + ): + result = signup(cursor, "v@test.com", "password123", "client-1", settings) + + assert result.vendor.id == vendor_id + assert result.vendor.email == "v@test.com" + + def test_signup_duplicate_email_raises(self, settings: Settings): + from app.services.auth import signup + + cursor = MagicMock() + + with patch( + "app.services.auth.get_vendor_by_email", + return_value={"id": "x", "email": "v@test.com", "password_hash": "h"}, + ): + with pytest.raises(ConflictException): + signup(cursor, "v@test.com", "password123", "client-1", settings) + + def test_login_success(self, settings: Settings): + from app.services.auth import login + + cursor = MagicMock() + vendor_id = str(uuid.uuid4()) + hashed = get_password_hash("password123") + + with patch( + "app.services.auth.get_vendor_by_email", + return_value={ + "id": vendor_id, + "email": "v@test.com", + "password_hash": hashed, + }, + ): + result = login(cursor, "v@test.com", "password123", "client-1", settings) + + assert result.access_token + assert result.refresh_token + assert result.token_type == "bearer" + + # Verify the access token is valid + payload = decode_token(result.access_token, settings) + assert payload["vendor_id"] == vendor_id + assert payload["token_type"] == "access" + + def test_login_wrong_email_raises(self, settings: Settings): + from app.services.auth import login + + cursor = MagicMock() + + with patch("app.services.auth.get_vendor_by_email", return_value=None): + with pytest.raises(AuthenticationException): + login(cursor, "bad@test.com", "password123", "client-1", settings) + + def test_login_wrong_password_raises(self, settings: Settings): + from app.services.auth import login + + cursor = MagicMock() + hashed = get_password_hash("correct-password") + + with patch( + "app.services.auth.get_vendor_by_email", + return_value={"id": "x", "email": "v@test.com", "password_hash": hashed}, + ): + with pytest.raises(AuthenticationException): + login(cursor, "v@test.com", "wrong-password", "client-1", settings) + + def test_refresh_success(self, settings: Settings): + from app.services.auth import refresh + + cursor = MagicMock() + vendor_id = str(uuid.uuid4()) + rt = create_refresh_token(vendor_id, settings) + + with patch( + "app.services.auth.get_vendor_by_id", + return_value={"id": vendor_id, "email": "v@test.com"}, + ): + result = refresh(rt, "client-1", cursor, settings) + + assert result.access_token + assert result.refresh_token + + def test_refresh_with_access_token_raises(self, settings: Settings): + from app.services.auth import refresh + + cursor = MagicMock() + vendor_id = str(uuid.uuid4()) + at = create_access_token(vendor_id, settings) + + with pytest.raises(AuthenticationException, match="Invalid token type"): + refresh(at, "client-1", cursor, settings) + + def test_refresh_expired_raises(self, settings: Settings): + from app.services.auth import refresh + + cursor = MagicMock() + vendor_id = str(uuid.uuid4()) + rt = create_refresh_token( + vendor_id, settings, expires_delta=timedelta(seconds=-1) + ) + + with pytest.raises(AuthenticationException): + refresh(rt, "client-1", cursor, settings) + + def test_refresh_deleted_vendor_raises(self, settings: Settings): + from app.services.auth import refresh + + cursor = MagicMock() + vendor_id = str(uuid.uuid4()) + rt = create_refresh_token(vendor_id, settings) + + with patch("app.services.auth.get_vendor_by_id", return_value=None): + with pytest.raises(AuthenticationException, match="Vendor not found"): + refresh(rt, "client-1", cursor, settings) + + +# ── Auth dependency ────────────────────────────────────────── + + +@pytest.mark.unit +class TestGetCurrentVendorId: + def test_valid_access_token(self, settings: Settings): + from app.api.deps import get_current_vendor_id + + vendor_id = str(uuid.uuid4()) + token = create_access_token(vendor_id, settings) + creds = MagicMock() + creds.credentials = token + + result = get_current_vendor_id(creds, settings) + assert result == vendor_id + + def test_missing_credentials_raises(self, settings: Settings): + from app.api.deps import get_current_vendor_id + + with pytest.raises(AuthenticationException, match="Missing"): + get_current_vendor_id(None, settings) + + def test_refresh_token_rejected(self, settings: Settings): + from app.api.deps import get_current_vendor_id + + vendor_id = str(uuid.uuid4()) + token = create_refresh_token(vendor_id, settings) + creds = MagicMock() + creds.credentials = token + + with pytest.raises(AuthenticationException, match="Invalid token type"): + get_current_vendor_id(creds, settings) + + def test_expired_token_raises(self, settings: Settings): + from app.api.deps import get_current_vendor_id + + vendor_id = str(uuid.uuid4()) + token = create_access_token( + vendor_id, settings, expires_delta=timedelta(seconds=-1) + ) + creds = MagicMock() + creds.credentials = token + + with pytest.raises(AuthenticationException): + get_current_vendor_id(creds, settings) + + def test_garbage_token_raises(self, settings: Settings): + from app.api.deps import get_current_vendor_id + + creds = MagicMock() + creds.credentials = "not.a.jwt" + + with pytest.raises(AuthenticationException): + get_current_vendor_id(creds, settings) diff --git a/uv.lock b/uv.lock index 9da86ba..94be5b7 100644 --- a/uv.lock +++ b/uv.lock @@ -959,6 +959,7 @@ dependencies = [ { name = "pwdlib", extra = ["argon2", "bcrypt"] }, { name = "pydantic" }, { name = "pydantic-settings" }, + { name = "pyjwt" }, { name = "python-multipart" }, { name = "tenacity" }, ] @@ -982,6 +983,7 @@ requires-dist = [ { name = "pwdlib", extras = ["argon2", "bcrypt"], specifier = ">=0.3.0" }, { name = "pydantic", specifier = ">2.0" }, { name = "pydantic-settings", specifier = ">=2.2.1,<3.0.0" }, + { name = "pyjwt", specifier = ">=2.8.0,<3.0.0" }, { name = "python-multipart", specifier = ">=0.0.7,<1.0.0" }, { name = "tenacity", specifier = ">=8.2.3,<9.0.0" }, ] @@ -1321,6 +1323,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, ] +[[package]] +name = "pyjwt" +version = "2.11.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5c/5a/b46fa56bf322901eee5b0454a34343cdbdae202cd421775a8ee4e42fd519/pyjwt-2.11.0.tar.gz", hash = "sha256:35f95c1f0fbe5d5ba6e43f00271c275f7a1a4db1dab27bf708073b75318ea623", size = 98019, upload-time = "2026-01-30T19:59:55.694Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6f/01/c26ce75ba460d5cd503da9e13b21a33804d38c2165dec7b716d06b13010c/pyjwt-2.11.0-py3-none-any.whl", hash = "sha256:94a6bde30eb5c8e04fee991062b534071fd1439ef58d2adc9ccb823e7bcd0469", size = 28224, upload-time = "2026-01-30T19:59:54.539Z" }, +] + [[package]] name = "pytest" version = "7.4.4" From ddba71018a2e459270891578e6a0e9b7739cbaf3 Mon Sep 17 00:00:00 2001 From: saudzahirr Date: Thu, 5 Mar 2026 11:01:51 +0500 Subject: [PATCH 2/7] Enhance JWT authentication and vendor management - Add UUID validation for vendor_id in get_current_vendor_id. - Update create_vendor to handle email conflicts gracefully. - Refactor auth service to raise ConflictException on duplicate vendor email. - Improve request schemas for better validation. - Add integration tests for protected endpoints and concurrent signup conflicts. --- backend/app/api/deps.py | 9 +++- backend/app/core/security.py | 6 ++- backend/app/crud/vendor.py | 15 ++++-- backend/app/schemas/auth.py | 10 +--- backend/app/services/auth.py | 5 ++ backend/tests/integration/test_auth.py | 71 +++++++++++++++----------- backend/tests/unit/test_auth.py | 55 +++++--------------- 7 files changed, 84 insertions(+), 87 deletions(-) diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index 629a28a..07c4682 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -1,3 +1,4 @@ +import uuid as _uuid from collections.abc import Generator from typing import Annotated @@ -10,7 +11,6 @@ from app.core.exceptions import AuthenticationException, ServiceUnavailableException from app.core.security import decode_token -# ── Re-usable bearer scheme (auto-documents in OpenAPI) ────── _bearer_scheme = HTTPBearer(auto_error=False) @@ -62,6 +62,13 @@ def get_current_vendor_id( if vendor_id is None: raise AuthenticationException("Invalid token payload") + # Validate that vendor_id is a well-formed UUID before it reaches + # downstream consumers such as app.set_app_context(). + try: + _uuid.UUID(vendor_id) + except (ValueError, AttributeError): + raise AuthenticationException("Invalid token payload") + return vendor_id diff --git a/backend/app/core/security.py b/backend/app/core/security.py index 9858f12..2b001e7 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -3,12 +3,14 @@ import jwt from pwdlib import PasswordHash +from pwdlib.hashers.argon2 import Argon2Hasher from pwdlib.hashers.bcrypt import BcryptHasher from app.core.config import Settings -# Use bcrypt only as required by spec -password_hash = PasswordHash((BcryptHasher(),)) +# BcryptHasher is listed first so new passwords are hashed with bcrypt. +# Argon2Hasher is kept for verification of legacy hashes. +password_hash = PasswordHash((BcryptHasher(), Argon2Hasher())) ALGORITHM = "HS256" diff --git a/backend/app/crud/vendor.py b/backend/app/crud/vendor.py index dc635af..a4c5885 100644 --- a/backend/app/crud/vendor.py +++ b/backend/app/crud/vendor.py @@ -36,14 +36,23 @@ def get_vendor_by_id(cursor: Cursor, vendor_id: str) -> dict[str, Any] | None: return {"id": str(row[0]), "email": row[1]} -def create_vendor(cursor: Cursor, email: str, password_hash: str) -> dict[str, Any]: - """Insert a new vendor and return the created row.""" +def create_vendor( + cursor: Cursor, email: str, password_hash: str +) -> dict[str, Any] | None: + """Insert a new vendor and return the created row. + + Uses ON CONFLICT DO NOTHING so a concurrent insert with the same + (case-insensitive) email returns None instead of raising a + UniqueViolation. + """ cursor.execute( 'INSERT INTO app."vendors" ("email", "password_hash") ' "VALUES (%s, %s) " + 'ON CONFLICT (LOWER("email")) DO NOTHING ' 'RETURNING "id", "email"', (email, password_hash), ) row = cursor.fetchone() - assert row is not None + if row is None: + return None return {"id": str(row[0]), "email": row[1]} diff --git a/backend/app/schemas/auth.py b/backend/app/schemas/auth.py index d52ac9d..e9c8f28 100644 --- a/backend/app/schemas/auth.py +++ b/backend/app/schemas/auth.py @@ -1,9 +1,6 @@ from pydantic import BaseModel, EmailStr, Field -# ── Request schemas ────────────────────────────────────────── - - class SignupRequest(BaseModel): email: EmailStr password: str = Field(..., min_length=8, max_length=128) @@ -12,18 +9,15 @@ class SignupRequest(BaseModel): class LoginRequest(BaseModel): email: EmailStr - password: str + password: str = Field(..., min_length=8, max_length=128) client_id: str = Field(..., min_length=1, max_length=256) class RefreshRequest(BaseModel): - refresh_token: str + refresh_token: str = Field(..., min_length=8, max_length=4096) client_id: str = Field(..., min_length=1, max_length=256) -# ── Response schemas ───────────────────────────────────────── - - class TokenPair(BaseModel): access_token: str refresh_token: str diff --git a/backend/app/services/auth.py b/backend/app/services/auth.py index 4f7421b..651bb6d 100644 --- a/backend/app/services/auth.py +++ b/backend/app/services/auth.py @@ -40,6 +40,11 @@ def signup( hashed = get_password_hash(password) vendor = create_vendor(cursor, email, hashed) + # create_vendor returns None when a concurrent insert won the race + # (ON CONFLICT DO NOTHING). + if vendor is None: + raise ConflictException("A vendor with this email already exists") + logger.info("Vendor created: %s (client_id=%s)", vendor["id"], client_id) return SignupResponse(vendor=VendorOut(**vendor)) diff --git a/backend/tests/integration/test_auth.py b/backend/tests/integration/test_auth.py index 0793c2b..4027f07 100644 --- a/backend/tests/integration/test_auth.py +++ b/backend/tests/integration/test_auth.py @@ -8,15 +8,19 @@ import typing from contextlib import asynccontextmanager +from datetime import timedelta from pathlib import Path +import uuid import pytest +from fastapi import APIRouter as _APIRouter from fastapi import FastAPI +from fastapi.exceptions import RequestValidationError from fastapi.testclient import TestClient from psycopg import Cursor, connect from testcontainers.postgres import PostgresContainer -from app.api.deps import get_db, get_settings +from app.api.deps import CurrentVendorId, get_db, get_settings from app.api.main import api_router from app.core.config import Settings from app.core.exception_handlers import ( @@ -25,13 +29,17 @@ validation_exception_handler, ) from app.core.exceptions import APIException -from fastapi.exceptions import RequestValidationError +from app.core.security import create_access_token MIGRATIONS_DIR = str(Path(__file__).parents[3] / "migrations") API_V1 = "/api/v1" +_test_router = _APIRouter() -# ── Module-scoped fixtures ────────────────────────────────── + +@_test_router.get("/protected-test") +def _protected_test(vendor_id: CurrentVendorId) -> dict: + return {"vendor_id": vendor_id} @pytest.fixture(scope="module") @@ -68,6 +76,7 @@ async def _noop_lifespan(app: FastAPI): test_app = FastAPI(lifespan=_noop_lifespan) test_app.include_router(api_router, prefix=API_V1) + test_app.include_router(_test_router, prefix=API_V1) test_app.add_exception_handler(APIException, api_exception_handler) test_app.add_exception_handler(RequestValidationError, validation_exception_handler) test_app.add_exception_handler(Exception, general_exception_handler) @@ -87,9 +96,6 @@ def _override_get_settings() -> Settings: yield tc -# ── Helpers ────────────────────────────────────────────────── - - def _signup( client: TestClient, email: str = "vendor@test.com", password: str = "SecurePass123!" ) -> dict: @@ -108,9 +114,6 @@ def _login( ).json() -# ── Tests ──────────────────────────────────────────────────── - - @pytest.mark.integration class TestSignup: def test_signup_creates_vendor_201(self, client: TestClient): @@ -240,29 +243,35 @@ def test_refresh_with_garbage_token_fails_401(self, client: TestClient): @pytest.mark.integration class TestProtectedEndpoints: - """Verify that protected endpoints reject invalid/missing tokens.""" + """Verify that protected endpoints reject invalid/missing tokens + via the full HTTP layer (Bearer parsing → dependency injection → + exception-to-HTTP translation). + """ def test_missing_token_401(self, client: TestClient): - """A request without Authorization header should be rejected. - - We use the health endpoint here as a baseline; in a real app - you would test an endpoint that uses CurrentVendorId dependency. - This test validates that the dependency itself rejects missing tokens. - """ - from app.api.deps import get_current_vendor_id - from app.core.exceptions import AuthenticationException - - # Directly test the dependency - from app.core.config import Settings - - settings = Settings( - SECRET_KEY="integration-test-secret", - PROJECT_NAME="test", - POSTGRES_SERVER="localhost", - POSTGRES_USER="test", - POSTGRES_PASSWORD="test", - POSTGRES_DB="test", + resp = client.get(f"{API_V1}/protected-test") + assert resp.status_code == 401 + + def test_expired_token_401(self, client: TestClient, test_settings: Settings): + token = create_access_token( + str(uuid.uuid4()), + test_settings, + expires_delta=timedelta(seconds=-1), + ) + resp = client.get( + f"{API_V1}/protected-test", + headers={"Authorization": f"Bearer {token}"}, ) + assert resp.status_code == 401 - with pytest.raises(AuthenticationException, match="Missing"): - get_current_vendor_id(None, settings) + def test_valid_token_returns_vendor_id( + self, client: TestClient, test_settings: Settings + ): + vendor_id = str(uuid.uuid4()) + token = create_access_token(vendor_id, test_settings) + resp = client.get( + f"{API_V1}/protected-test", + headers={"Authorization": f"Bearer {token}"}, + ) + assert resp.status_code == 200 + assert resp.json()["vendor_id"] == vendor_id diff --git a/backend/tests/unit/test_auth.py b/backend/tests/unit/test_auth.py index e1d369f..3452bd9 100644 --- a/backend/tests/unit/test_auth.py +++ b/backend/tests/unit/test_auth.py @@ -16,6 +16,7 @@ import jwt as pyjwt import pytest +from app.api.deps import get_current_vendor_id from app.core.config import Settings from app.core.exceptions import AuthenticationException, ConflictException from app.core.security import ( @@ -25,9 +26,7 @@ get_password_hash, verify_password, ) - - -# ── Fixture: test settings ────────────────────────────────── +from app.services.auth import login, refresh, signup @pytest.fixture @@ -45,9 +44,6 @@ def settings() -> Settings: ) -# ── Password hashing ──────────────────────────────────────── - - @pytest.mark.unit class TestPasswordHashing: def test_hash_and_verify(self): @@ -65,9 +61,6 @@ def test_wrong_password_fails(self): assert valid is False -# ── Token creation / decoding ──────────────────────────────── - - @pytest.mark.unit class TestTokens: def test_access_token_claims(self, settings: Settings): @@ -113,16 +106,11 @@ def test_invalid_signature_raises(self, settings: Settings): decode_token(token, bad_settings) -# ── Auth service ───────────────────────────────────────────── - - @pytest.mark.unit class TestAuthService: """Test auth service with mocked CRUD layer.""" def test_signup_success(self, settings: Settings): - from app.services.auth import signup - cursor = MagicMock() vendor_id = str(uuid.uuid4()) @@ -139,8 +127,6 @@ def test_signup_success(self, settings: Settings): assert result.vendor.email == "v@test.com" def test_signup_duplicate_email_raises(self, settings: Settings): - from app.services.auth import signup - cursor = MagicMock() with patch( @@ -151,8 +137,6 @@ def test_signup_duplicate_email_raises(self, settings: Settings): signup(cursor, "v@test.com", "password123", "client-1", settings) def test_login_success(self, settings: Settings): - from app.services.auth import login - cursor = MagicMock() vendor_id = str(uuid.uuid4()) hashed = get_password_hash("password123") @@ -171,14 +155,11 @@ def test_login_success(self, settings: Settings): assert result.refresh_token assert result.token_type == "bearer" - # Verify the access token is valid payload = decode_token(result.access_token, settings) assert payload["vendor_id"] == vendor_id assert payload["token_type"] == "access" def test_login_wrong_email_raises(self, settings: Settings): - from app.services.auth import login - cursor = MagicMock() with patch("app.services.auth.get_vendor_by_email", return_value=None): @@ -186,8 +167,6 @@ def test_login_wrong_email_raises(self, settings: Settings): login(cursor, "bad@test.com", "password123", "client-1", settings) def test_login_wrong_password_raises(self, settings: Settings): - from app.services.auth import login - cursor = MagicMock() hashed = get_password_hash("correct-password") @@ -199,8 +178,6 @@ def test_login_wrong_password_raises(self, settings: Settings): login(cursor, "v@test.com", "wrong-password", "client-1", settings) def test_refresh_success(self, settings: Settings): - from app.services.auth import refresh - cursor = MagicMock() vendor_id = str(uuid.uuid4()) rt = create_refresh_token(vendor_id, settings) @@ -215,8 +192,6 @@ def test_refresh_success(self, settings: Settings): assert result.refresh_token def test_refresh_with_access_token_raises(self, settings: Settings): - from app.services.auth import refresh - cursor = MagicMock() vendor_id = str(uuid.uuid4()) at = create_access_token(vendor_id, settings) @@ -225,8 +200,6 @@ def test_refresh_with_access_token_raises(self, settings: Settings): refresh(at, "client-1", cursor, settings) def test_refresh_expired_raises(self, settings: Settings): - from app.services.auth import refresh - cursor = MagicMock() vendor_id = str(uuid.uuid4()) rt = create_refresh_token( @@ -237,8 +210,6 @@ def test_refresh_expired_raises(self, settings: Settings): refresh(rt, "client-1", cursor, settings) def test_refresh_deleted_vendor_raises(self, settings: Settings): - from app.services.auth import refresh - cursor = MagicMock() vendor_id = str(uuid.uuid4()) rt = create_refresh_token(vendor_id, settings) @@ -247,15 +218,23 @@ def test_refresh_deleted_vendor_raises(self, settings: Settings): with pytest.raises(AuthenticationException, match="Vendor not found"): refresh(rt, "client-1", cursor, settings) + def test_signup_concurrent_insert_conflict(self, settings: Settings): + """Pre-read shows no existing vendor, but the insert collides + (create_vendor returns None due to ON CONFLICT DO NOTHING). + """ + cursor = MagicMock() -# ── Auth dependency ────────────────────────────────────────── + with ( + patch("app.services.auth.get_vendor_by_email", return_value=None), + patch("app.services.auth.create_vendor", return_value=None), + ): + with pytest.raises(ConflictException): + signup(cursor, "race@test.com", "password123", "client-1", settings) @pytest.mark.unit class TestGetCurrentVendorId: def test_valid_access_token(self, settings: Settings): - from app.api.deps import get_current_vendor_id - vendor_id = str(uuid.uuid4()) token = create_access_token(vendor_id, settings) creds = MagicMock() @@ -265,14 +244,10 @@ def test_valid_access_token(self, settings: Settings): assert result == vendor_id def test_missing_credentials_raises(self, settings: Settings): - from app.api.deps import get_current_vendor_id - with pytest.raises(AuthenticationException, match="Missing"): get_current_vendor_id(None, settings) def test_refresh_token_rejected(self, settings: Settings): - from app.api.deps import get_current_vendor_id - vendor_id = str(uuid.uuid4()) token = create_refresh_token(vendor_id, settings) creds = MagicMock() @@ -282,8 +257,6 @@ def test_refresh_token_rejected(self, settings: Settings): get_current_vendor_id(creds, settings) def test_expired_token_raises(self, settings: Settings): - from app.api.deps import get_current_vendor_id - vendor_id = str(uuid.uuid4()) token = create_access_token( vendor_id, settings, expires_delta=timedelta(seconds=-1) @@ -295,8 +268,6 @@ def test_expired_token_raises(self, settings: Settings): get_current_vendor_id(creds, settings) def test_garbage_token_raises(self, settings: Settings): - from app.api.deps import get_current_vendor_id - creds = MagicMock() creds.credentials = "not.a.jwt" From 0cfc50a3e17df158415dad311bf1eb0ecbd5397d Mon Sep 17 00:00:00 2001 From: saudzahirr Date: Thu, 5 Mar 2026 11:35:09 +0500 Subject: [PATCH 3/7] Enhance code quality and documentation across the application - Updated .gitignore to exclude log files. - Improved docstrings in deps.py, main.py, middlewares.py, auth.py, health.py, and other modules for better clarity on function behavior and return types. - Refactored exception handling in exception_handlers.py and exceptions.py to include return type hints in constructors. - Added type hints and improved formatting in various service and CRUD functions for consistency and readability. - Integrated Ruff configuration in pyproject.toml for enhanced linting and code style enforcement. - Updated integration and unit tests for better readability and maintainability. --- .gitignore | 1 + backend/app/api/deps.py | 45 ++++++++++++++---- backend/app/api/main.py | 1 + backend/app/api/middlewares.py | 10 ++-- backend/app/api/routes/auth.py | 41 +++++++++++------ backend/app/api/routes/health.py | 6 ++- backend/app/api/routes/login.py | 1 + backend/app/core/config.py | 12 ++--- backend/app/core/exception_handlers.py | 42 +++++++++-------- backend/app/core/exceptions.py | 20 ++++---- backend/app/core/security.py | 29 ++++++++++-- backend/app/crud/vendor.py | 15 +++++- backend/app/main.py | 9 ++-- backend/app/pre_start.py | 3 +- backend/app/schemas/auth.py | 2 +- backend/app/schemas/response.py | 17 +++++-- backend/app/services/auth.py | 45 +++++++++++++----- backend/pyproject.toml | 64 ++++++++++++++++++++++++++ backend/tests/__init__.py | 1 + backend/tests/conftest.py | 4 ++ backend/tests/integration/test_auth.py | 55 +++++++++++++++++----- backend/tests/unit/test_auth.py | 50 ++++++++++++++++---- 22 files changed, 359 insertions(+), 114 deletions(-) diff --git a/.gitignore b/.gitignore index 29e02fe..f9ace42 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ htmlcov *.env build/ dist/ +*.log diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index 07c4682..d33af99 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -8,14 +8,25 @@ from psycopg import Cursor from app.core.config import Settings -from app.core.exceptions import AuthenticationException, ServiceUnavailableException +from app.core.exceptions import ( + AuthenticationException, + ServiceUnavailableException, +) from app.core.security import decode_token + _bearer_scheme = HTTPBearer(auto_error=False) def get_db(request: Request) -> Generator[Cursor, None, None]: - """Return a database cursor for the request.""" + """Return a database cursor for the request. + + Yields: + Cursor: A database cursor. + + Raises: + ServiceUnavailableException: If the database pool is not initialized. + """ pool = getattr(request.app.state, "db_pool", None) if pool is None: raise ServiceUnavailableException("Database pool not initialized") @@ -25,7 +36,14 @@ def get_db(request: Request) -> Generator[Cursor, None, None]: def get_settings(request: Request) -> Settings: - """Return settings from app state.""" + """Return settings from app state. + + Returns: + Settings: Application settings. + + Raises: + ServiceUnavailableException: If settings are not initialized. + """ settings = getattr(request.app.state, "settings", None) if settings is None: raise ServiceUnavailableException("Settings not initialized") @@ -44,8 +62,12 @@ def get_current_vendor_id( ) -> str: """Extract and validate vendor_id from the Authorization: Bearer token. - Raises AuthenticationException on missing / invalid / expired tokens - or if the token is not an access token. + Returns: + str: The validated vendor_id from the JWT. + + Raises: + AuthenticationException: On missing / invalid / expired tokens + or if the token is not an access token. """ if credentials is None: raise AuthenticationException("Missing authentication token") @@ -53,7 +75,7 @@ def get_current_vendor_id( try: payload = decode_token(credentials.credentials, settings) except pyjwt.PyJWTError: - raise AuthenticationException("Invalid or expired token") + raise AuthenticationException("Invalid or expired token") from None if payload.get("token_type") != "access": raise AuthenticationException("Invalid token type") @@ -67,14 +89,13 @@ def get_current_vendor_id( try: _uuid.UUID(vendor_id) except (ValueError, AttributeError): - raise AuthenticationException("Invalid token payload") + raise AuthenticationException("Invalid token payload") from None return vendor_id def get_rls_cursor( - request: Request, - vendor_id: Annotated[str, Depends(get_current_vendor_id)], + request: Request, vendor_id: Annotated[str, Depends(get_current_vendor_id)] ) -> Generator[Cursor, None, None]: """Return a database cursor with app.vendor_id set for RLS. @@ -82,6 +103,12 @@ def get_rls_cursor( 1. Obtains a connection from the pool 2. Calls app.set_app_context(vendor_id) to set the RLS context 3. Yields the cursor for use in route handlers + + Yields: + Cursor: A database cursor with RLS context set. + + Raises: + ServiceUnavailableException: If the database pool is not initialized. """ pool = getattr(request.app.state, "db_pool", None) if pool is None: diff --git a/backend/app/api/main.py b/backend/app/api/main.py index 1bb1833..46d462c 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -2,6 +2,7 @@ from app.api.routes import auth, health, login + api_router = APIRouter() api_router.include_router(health.router, tags=["health"]) api_router.include_router(login.router, prefix="/login", tags=["login"]) diff --git a/backend/app/api/middlewares.py b/backend/app/api/middlewares.py index 387d34c..de4738d 100644 --- a/backend/app/api/middlewares.py +++ b/backend/app/api/middlewares.py @@ -1,10 +1,14 @@ import uuid -from fastapi import Request +from fastapi import Request, Response -async def add_request_id(request: Request, call_next): - """Add a unique request_id to each request context""" +async def add_request_id(request: Request, call_next: object) -> Response: + """Add a unique request_id to each request context. + + Returns: + Response: The response with X-Request-ID header. + """ raw_header = request.headers.get("X-Request-ID") stripped_header = raw_header.strip() if raw_header else "" request_id = stripped_header if stripped_header else str(uuid.uuid4()) diff --git a/backend/app/api/routes/auth.py b/backend/app/api/routes/auth.py index 7e0b59c..3aa1000 100644 --- a/backend/app/api/routes/auth.py +++ b/backend/app/api/routes/auth.py @@ -13,6 +13,7 @@ from app.schemas.response import SuccessResponse from app.services import auth as auth_service + router = APIRouter() @@ -21,8 +22,14 @@ status_code=status.HTTP_201_CREATED, response_model=SuccessResponse[SignupResponse], ) -def signup(body: SignupRequest, cursor: CursorDep, settings: SettingsDep): - """Create a new vendor account.""" +def signup( + body: SignupRequest, cursor: CursorDep, settings: SettingsDep +) -> SuccessResponse[SignupResponse]: + """Create a new vendor account. + + Returns: + SuccessResponse[SignupResponse]: The created vendor. + """ result = auth_service.signup( cursor=cursor, email=body.email, @@ -33,12 +40,15 @@ def signup(body: SignupRequest, cursor: CursorDep, settings: SettingsDep): return SuccessResponse(data=result) -@router.post( - "/login", - response_model=SuccessResponse[TokenPair], -) -def login(body: LoginRequest, cursor: CursorDep, settings: SettingsDep): - """Authenticate a vendor and return an access/refresh token pair.""" +@router.post("/login", response_model=SuccessResponse[TokenPair]) +def login( + body: LoginRequest, cursor: CursorDep, settings: SettingsDep +) -> SuccessResponse[TokenPair]: + """Authenticate a vendor and return an access/refresh token pair. + + Returns: + SuccessResponse[TokenPair]: The token pair. + """ result = auth_service.login( cursor=cursor, email=body.email, @@ -49,12 +59,15 @@ def login(body: LoginRequest, cursor: CursorDep, settings: SettingsDep): return SuccessResponse(data=result) -@router.post( - "/refresh", - response_model=SuccessResponse[TokenPair], -) -def refresh(body: RefreshRequest, cursor: CursorDep, settings: SettingsDep): - """Issue a new token pair using a valid refresh token.""" +@router.post("/refresh", response_model=SuccessResponse[TokenPair]) +def refresh( + body: RefreshRequest, cursor: CursorDep, settings: SettingsDep +) -> SuccessResponse[TokenPair]: + """Issue a new token pair using a valid refresh token. + + Returns: + SuccessResponse[TokenPair]: The new token pair. + """ result = auth_service.refresh( refresh_token_str=body.refresh_token, client_id=body.client_id, diff --git a/backend/app/api/routes/health.py b/backend/app/api/routes/health.py index ffe224c..2a39a8a 100644 --- a/backend/app/api/routes/health.py +++ b/backend/app/api/routes/health.py @@ -3,11 +3,15 @@ from fastapi import APIRouter + router = APIRouter() @router.get("/health") def health_check() -> dict[str, Any]: return { - "data": {"status": "ok", "timestamp": datetime.now(timezone.utc).isoformat()} + "data": { + "status": "ok", + "timestamp": datetime.now(timezone.utc).isoformat(), + } } diff --git a/backend/app/api/routes/login.py b/backend/app/api/routes/login.py index af9233c..e6f2f82 100644 --- a/backend/app/api/routes/login.py +++ b/backend/app/api/routes/login.py @@ -1,3 +1,4 @@ from fastapi import APIRouter + router = APIRouter() diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 22971f1..b2a3142 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -1,15 +1,10 @@ -from pydantic import ( - PostgresDsn, - computed_field, -) +from pydantic import PostgresDsn, computed_field from pydantic_settings import BaseSettings, SettingsConfigDict class Settings(BaseSettings): model_config = SettingsConfigDict( - env_file=".env", - env_ignore_empty=True, - extra="ignore", + env_file=".env", env_ignore_empty=True, extra="ignore" ) SECRET_KEY: str @@ -26,7 +21,8 @@ class Settings(BaseSettings): @computed_field @property - def DATABASE_DSN(self) -> PostgresDsn: + def DATABASE_DSN(self) -> PostgresDsn: # noqa: N802 + # See: https://docs.astral.sh/ruff/rules/invalid-function-name/ return PostgresDsn.build( scheme="postgresql", username=self.POSTGRES_USER, diff --git a/backend/app/core/exception_handlers.py b/backend/app/core/exception_handlers.py index 1462445..20f96bd 100644 --- a/backend/app/core/exception_handlers.py +++ b/backend/app/core/exception_handlers.py @@ -1,5 +1,5 @@ -import uuid import logging +import uuid from fastapi import Request, status from fastapi.exceptions import RequestValidationError @@ -17,8 +17,12 @@ logger = logging.getLogger(__name__) -async def api_exception_handler(request: Request, exc: APIException) -> JSONResponse: - """Handle custom API exceptions""" +def api_exception_handler(request: Request, exc: APIException) -> JSONResponse: + """Handle custom API exceptions. + + Returns: + JSONResponse: The error response. + """ request_id = getattr(request.state, "request_id", str(uuid.uuid4())) logger.warning( @@ -46,10 +50,14 @@ async def api_exception_handler(request: Request, exc: APIException) -> JSONResp ) -async def validation_exception_handler( +def validation_exception_handler( request: Request, exc: RequestValidationError ) -> JSONResponse: - """Handle Pydantic validation errors""" + """Handle Pydantic validation errors. + + Returns: + JSONResponse: The validation error response. + """ request_id = getattr(request.state, "request_id", str(uuid.uuid4())) # Parse validation errors @@ -57,20 +65,12 @@ async def validation_exception_handler( for error in exc.errors(): field_path = ".".join(str(loc) for loc in error["loc"][1:]) field = field_path if field_path else None - details.append( - { - "field": field, - "message": error["msg"], - } - ) + details.append({"field": field, "message": error["msg"]}) logger.warning( "Validation Error: %d validation errors", len(details), - extra={ - "request_id": request_id, - "validation_errors": details, - }, + extra={"request_id": request_id, "validation_errors": details}, ) # Starlette <0.48 compatibility: use getattr fallback to literal 422 @@ -96,15 +96,17 @@ async def validation_exception_handler( ) -async def general_exception_handler(request: Request, exc: Exception) -> JSONResponse: - """Handle all uncaught exceptions""" +def general_exception_handler(request: Request, exc: Exception) -> JSONResponse: + """Handle all uncaught exceptions. + + Returns: + JSONResponse: The generic error response. + """ request_id = getattr(request.state, "request_id", str(uuid.uuid4())) # Log the full traceback server-side logger.exception( - "Unexpected error: %s", - str(exc), - extra={"request_id": request_id}, + "Unexpected error: %s", exc, extra={"request_id": request_id} ) # Use typed schema to validate error structure diff --git a/backend/app/core/exceptions.py b/backend/app/core/exceptions.py index b6537a9..c027cff 100644 --- a/backend/app/core/exceptions.py +++ b/backend/app/core/exceptions.py @@ -5,7 +5,7 @@ from app.schemas.response import ErrorCode -class APIException(Exception): +class APIException(Exception): # noqa: N818 """Base exception for API errors""" def __init__( @@ -14,7 +14,7 @@ def __init__( message: str, http_status: int, details: list[dict[str, Any]] | None = None, - ): + ) -> None: self.error_code = error_code self.message = message self.http_status = http_status @@ -29,7 +29,7 @@ def __init__( self, message: str = "Invalid request parameters", details: list[dict[str, Any]] | None = None, - ): + ) -> None: super().__init__( error_code=ErrorCode.VALIDATION_FAILED, message=message, @@ -42,7 +42,7 @@ def __init__( class AuthenticationException(APIException): """Authentication error""" - def __init__(self, message: str = "Invalid credentials"): + def __init__(self, message: str = "Invalid credentials") -> None: super().__init__( error_code=ErrorCode.AUTH_INVALID, message=message, @@ -53,7 +53,7 @@ def __init__(self, message: str = "Invalid credentials"): class AuthorizationException(APIException): """Authorization error""" - def __init__(self, message: str = "Access denied"): + def __init__(self, message: str = "Access denied") -> None: super().__init__( error_code=ErrorCode.FORBIDDEN, message=message, @@ -64,7 +64,7 @@ def __init__(self, message: str = "Access denied"): class NotFoundException(APIException): """Resource not found error""" - def __init__(self, message: str = "Resource not found"): + def __init__(self, message: str = "Resource not found") -> None: super().__init__( error_code=ErrorCode.RESOURCE_NOT_FOUND, message=message, @@ -75,7 +75,7 @@ def __init__(self, message: str = "Resource not found"): class ConflictException(APIException): """Resource conflict error""" - def __init__(self, message: str = "Resource conflict"): + def __init__(self, message: str = "Resource conflict") -> None: super().__init__( error_code=ErrorCode.RESOURCE_CONFLICT, message=message, @@ -90,7 +90,7 @@ def __init__( self, message: str = "Business logic error", details: list[dict[str, Any]] | None = None, - ): + ) -> None: super().__init__( error_code=ErrorCode.BUSINESS_LOGIC_ERROR, message=message, @@ -103,7 +103,9 @@ def __init__( class ServiceUnavailableException(APIException): """Service unavailable error""" - def __init__(self, message: str = "Service temporarily unavailable"): + def __init__( + self, message: str = "Service temporarily unavailable" + ) -> None: super().__init__( error_code=ErrorCode.SERVICE_UNAVAILABLE, message=message, diff --git a/backend/app/core/security.py b/backend/app/core/security.py index 2b001e7..48d52d7 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -8,6 +8,7 @@ from app.core.config import Settings + # BcryptHasher is listed first so new passwords are hashed with bcrypt. # Argon2Hasher is kept for verification of legacy hashes. password_hash = PasswordHash((BcryptHasher(), Argon2Hasher())) @@ -17,9 +18,16 @@ def create_access_token( - vendor_id: str, settings: Settings, *, expires_delta: timedelta | None = None + vendor_id: str, + settings: Settings, + *, + expires_delta: timedelta | None = None, ) -> str: - """Create a short-lived access token with vendor_id claim.""" + """Create a short-lived access token with vendor_id claim. + + Returns: + str: The encoded JWT access token. + """ if expires_delta is None: expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) expire = datetime.now(timezone.utc) + expires_delta @@ -32,9 +40,16 @@ def create_access_token( def create_refresh_token( - vendor_id: str, settings: Settings, *, expires_delta: timedelta | None = None + vendor_id: str, + settings: Settings, + *, + expires_delta: timedelta | None = None, ) -> str: - """Create a long-lived refresh token with vendor_id claim.""" + """Create a long-lived refresh token with vendor_id claim. + + Returns: + str: The encoded JWT refresh token. + """ if expires_delta is None: expires_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) expire = datetime.now(timezone.utc) + expires_delta @@ -47,7 +62,11 @@ def create_refresh_token( def decode_token(token: str, settings: Settings) -> dict[str, Any]: - """Decode and validate a JWT token. Raises jwt.PyJWTError on failure.""" + """Decode and validate a JWT token. + + Returns: + dict[str, Any]: The decoded token payload. + """ return jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) diff --git a/backend/app/crud/vendor.py b/backend/app/crud/vendor.py index a4c5885..57c2774 100644 --- a/backend/app/crud/vendor.py +++ b/backend/app/crud/vendor.py @@ -8,7 +8,11 @@ def get_vendor_by_email(cursor: Cursor, email: str) -> dict[str, Any] | None: - """Return a vendor row by email (case-insensitive) or None.""" + """Return a vendor row by email (case-insensitive) or None. + + Returns: + dict[str, Any] | None: The vendor row or None. + """ cursor.execute( 'SELECT "id", "email", "password_hash" ' 'FROM app."vendors" ' @@ -23,7 +27,11 @@ def get_vendor_by_email(cursor: Cursor, email: str) -> dict[str, Any] | None: def get_vendor_by_id(cursor: Cursor, vendor_id: str) -> dict[str, Any] | None: - """Return a vendor row by id or None.""" + """Return a vendor row by id or None. + + Returns: + dict[str, Any] | None: The vendor row or None. + """ cursor.execute( 'SELECT "id", "email" ' 'FROM app."vendors" ' @@ -44,6 +52,9 @@ def create_vendor( Uses ON CONFLICT DO NOTHING so a concurrent insert with the same (case-insensitive) email returns None instead of raising a UniqueViolation. + + Returns: + dict[str, Any] | None: The created vendor row, or None on conflict. """ cursor.execute( 'INSERT INTO app."vendors" ("email", "password_hash") ' diff --git a/backend/app/main.py b/backend/app/main.py index 7d0e4e9..07241ca 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -1,27 +1,28 @@ -from contextlib import asynccontextmanager import logging +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from fastapi import FastAPI from fastapi.exceptions import RequestValidationError from fastapi.routing import APIRoute +from psycopg_pool import ConnectionPool from app.api.main import api_router from app.api.middlewares import add_request_id from app.core.config import Settings from app.core.exception_handlers import ( api_exception_handler, - validation_exception_handler, general_exception_handler, + validation_exception_handler, ) from app.core.exceptions import APIException -from psycopg_pool import ConnectionPool logger = logging.getLogger(__name__) @asynccontextmanager -async def lifespan(app: FastAPI): +async def lifespan(app: FastAPI) -> AsyncIterator[None]: # noqa: RUF029 # Startup logger.info("Initializing settings") settings = Settings() diff --git a/backend/app/pre_start.py b/backend/app/pre_start.py index c30cbaa..3b8262a 100644 --- a/backend/app/pre_start.py +++ b/backend/app/pre_start.py @@ -1,5 +1,6 @@ import logging +from psycopg_pool import ConnectionPool from tenacity import ( after_log, before_log, @@ -10,7 +11,7 @@ ) from app.core.config import Settings -from psycopg_pool import ConnectionPool + logger = logging.getLogger(__name__) diff --git a/backend/app/schemas/auth.py b/backend/app/schemas/auth.py index e9c8f28..13d9e0b 100644 --- a/backend/app/schemas/auth.py +++ b/backend/app/schemas/auth.py @@ -21,7 +21,7 @@ class RefreshRequest(BaseModel): class TokenPair(BaseModel): access_token: str refresh_token: str - token_type: str = "bearer" + token_type: str = "bearer" # noqa: S105 class VendorOut(BaseModel): diff --git a/backend/app/schemas/response.py b/backend/app/schemas/response.py index af1a7f7..171920c 100644 --- a/backend/app/schemas/response.py +++ b/backend/app/schemas/response.py @@ -1,7 +1,8 @@ from enum import Enum from typing import Generic, TypeVar -from pydantic import BaseModel, Field, ConfigDict +from pydantic import BaseModel, ConfigDict, Field + T = TypeVar("T") @@ -47,7 +48,9 @@ class SuccessResponse(BaseModel, Generic[T]): class ErrorDetail(BaseModel): """Additional error details""" - field: str | None = Field(None, description="Field name if validation error") + field: str | None = Field( + None, description="Field name if validation error" + ) message: str = Field(..., description="Detailed error message") @@ -71,7 +74,11 @@ class ErrorResponse(BaseModel): error: ErrorBodyResponse = Field( ..., - description="Error information with code, message, http_status, details, and request_id", + description=( + "Error information containing code," + " message, http_status, details," + " and request_id" + ), ) model_config = ConfigDict( @@ -81,7 +88,9 @@ class ErrorResponse(BaseModel): "code": "VALIDATION_FAILED", "message": "Invalid request parameters", "http_status": 422, - "details": [{"field": "email", "message": "Invalid email format"}], + "details": [ + {"field": "email", "message": "Invalid email format"} + ], "request_id": "req-123456789", } } diff --git a/backend/app/services/auth.py b/backend/app/services/auth.py index 651bb6d..ef4f4eb 100644 --- a/backend/app/services/auth.py +++ b/backend/app/services/auth.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +import uuid as _uuid import jwt as pyjwt from psycopg import Cursor @@ -19,6 +20,7 @@ from app.crud.vendor import create_vendor, get_vendor_by_email, get_vendor_by_id from app.schemas.auth import SignupResponse, TokenPair, VendorOut + logger = logging.getLogger(__name__) @@ -31,7 +33,11 @@ def signup( ) -> SignupResponse: """Create a new vendor account. - Raises ConflictException if the email already exists. + Returns: + SignupResponse: The created vendor. + + Raises: + ConflictException: If the email already exists. """ existing = get_vendor_by_email(cursor, email) if existing is not None: @@ -58,7 +64,11 @@ def login( ) -> TokenPair: """Authenticate a vendor and return an access/refresh token pair. - Raises AuthenticationException for invalid credentials. + Returns: + TokenPair: The access/refresh token pair. + + Raises: + AuthenticationException: For invalid credentials. """ vendor = get_vendor_by_email(cursor, email) if vendor is None: @@ -71,8 +81,9 @@ def login( # If the hashing library returned an upgraded hash, persist it if updated_hash is not None: cursor.execute( - 'UPDATE app."vendors" SET "password_hash" = %s, "updated_at" = NOW() ' - 'WHERE "id" = %s', + 'UPDATE app."vendors"' + ' SET "password_hash" = %s, "updated_at" = NOW()' + ' WHERE "id" = %s', (updated_hash, vendor["id"]), ) @@ -84,20 +95,23 @@ def login( def refresh( - refresh_token_str: str, - client_id: str, - cursor: Cursor, - settings: Settings, + refresh_token_str: str, client_id: str, cursor: Cursor, settings: Settings ) -> TokenPair: """Issue a new token pair from a valid refresh token. - Raises AuthenticationException if the token is invalid/expired - or if it is not a refresh token. + Returns: + TokenPair: The new access/refresh token pair. + + Raises: + AuthenticationException: If the token is invalid/expired + or if it is not a refresh token. """ try: payload = decode_token(refresh_token_str, settings) except pyjwt.PyJWTError: - raise AuthenticationException("Invalid or expired refresh token") + raise AuthenticationException( + "Invalid or expired refresh token" + ) from None if payload.get("token_type") != "refresh": raise AuthenticationException("Invalid token type") @@ -106,6 +120,11 @@ def refresh( if vendor_id is None: raise AuthenticationException("Invalid token payload") + try: + _uuid.UUID(vendor_id) + except (ValueError, AttributeError): + raise AuthenticationException("Invalid token payload") from None + # Ensure vendor still exists and is not deleted vendor = get_vendor_by_id(cursor, vendor_id) if vendor is None: @@ -114,5 +133,7 @@ def refresh( access_token = create_access_token(vendor_id, settings) new_refresh_token = create_refresh_token(vendor_id, settings) - logger.info("Token refreshed for vendor: %s (client_id=%s)", vendor_id, client_id) + logger.info( + "Token refreshed for vendor: %s (client_id=%s)", vendor_id, client_id + ) return TokenPair(access_token=access_token, refresh_token=new_refresh_token) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 1cbd8ab..b694553 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -33,6 +33,70 @@ build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] packages = ["app"] +[tool.ruff] +line-length = 80 +indent-width = 4 +preview = true + +# Output serialization format for violations. The default serialization +# format is "full" [env: RUFF_OUTPUT_FORMAT=] [possible values: +# concise, full, json, json-lines, junit, grouped, github, gitlab, +# pylint, rdjson, azure, sarif] +output-format = "grouped" + +[tool.ruff.lint] +isort.lines-after-imports = 2 +isort.split-on-trailing-comma = false + +select = [ + "ANN", # flake8-annotations (required strict type annotations for public functions) + "S", # flake8-bandit (checks basic security issues in code) + "BLE", # flake8-blind-except (checks the except blocks that do not specify exception) + "FBT", # flake8-boolean-trap (ensure that boolean args can be used with kw only) + "E", # pycodestyle errors (PEP 8 style guide violations) + "W", # pycodestyle warnings (e.g., extra spaces, indentation issues) + "DOC", # pydoclint issues (e.g., extra or missing return, yield, warnings) + "A", # flake8-buitins (check variable and function names to not shadow builtins) + "N", # Naming convention checks (e.g., PEP 8 variable and function names) + "F", # Pyflakes errors (e.g., unused imports, undefined variables) + "I", # isort (Ensures imports are sorted properly) + "B", # flake8-bugbear (Detects likely bugs and bad practices) + "TID", # flake8-tidy-imports (Checks for banned or misplaced imports) + "UP", # pyupgrade (Automatically updates old Python syntax) + "YTT", # flake8-2020 (Detects outdated Python 2/3 compatibility issues) + "FLY", # flynt (Converts old-style string formatting to f-strings) + "PIE", # flake8-pie + "PL", # pylint + "RUF", # Ruff-specific rules (Additional optimizations and best practices) +] + +ignore = [] + +[tool.ruff.lint.per-file-ignores] +"tests/**/*.py" = [ + "ANN001", # [flake8-annotations](https://docs.astral.sh/ruff/rules/missing-type-function-argument/) + "ANN002", # [flake8-annotations](https://docs.astral.sh/ruff/rules/missing-type-args/) + "ANN003", # [flake8-annotations](https://docs.astral.sh/ruff/rules/missing-type-kwargs/) + "ANN201", # [flake8-annotations](https://docs.astral.sh/ruff/rules/missing-return-type-undocumented-public-function/) + "ANN202", # [flake8-annotations](https://docs.astral.sh/ruff/rules/missing-return-type-private-function/) + "S101", # [flake8-bandit](https://docs.astral.sh/ruff/rules/assert/) + "S105", # [flake8-bandit](https://docs.astral.sh/ruff/rules/hardcoded-password-string/) + "S106", # [flake8-bandit](https://docs.astral.sh/ruff/rules/hardcoded-password-func-arg/) + "S107", # [flake8-bandit](https://docs.astral.sh/ruff/rules/hardcoded-password-default/) + "PLR2004", # [pylint](https://docs.astral.sh/ruff/rules/magic-value-comparison/) + "PLR6301", # [pylint](https://docs.astral.sh/ruff/rules/no-self-use/) + "DOC201", # [pydoclint](https://docs.astral.sh/ruff/rules/docstring-missing-returns/) + "DOC402", # [pydoclint](https://docs.astral.sh/ruff/rules/docstring-missing-yields/) + "DOC501", # [pydoclint](https://docs.astral.sh/ruff/rules/docstring-missing-exception/) +] +"tests/__init__.py" = [ + "RUF067", # [ruff](https://docs.astral.sh/ruff/rules/non-empty-init-module/) +] + +[tool.ruff.format] +docstring-code-format = true +skip-magic-trailing-comma = true + [tool.pytest.ini_options] testpaths = ["tests"] markers = [ diff --git a/backend/tests/__init__.py b/backend/tests/__init__.py index c7e2080..37ff788 100644 --- a/backend/tests/__init__.py +++ b/backend/tests/__init__.py @@ -1,5 +1,6 @@ import os + # Environment variables must be set before importing the app os.environ.setdefault("SECRET_KEY", "test-secret-key") os.environ.setdefault("PROJECT_NAME", "permit-test") diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index ca04ec6..9a3a0d5 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -19,6 +19,10 @@ from psycopg import Cursor, connect from testcontainers.postgres import PostgresContainer +from app.api.deps import get_db +from app.main import app + + MIGRATIONS_DIR = str(Path(__file__).parents[2] / "migrations") diff --git a/backend/tests/integration/test_auth.py b/backend/tests/integration/test_auth.py index 4027f07..e696a2b 100644 --- a/backend/tests/integration/test_auth.py +++ b/backend/tests/integration/test_auth.py @@ -7,10 +7,10 @@ from __future__ import annotations import typing +import uuid from contextlib import asynccontextmanager from datetime import timedelta from pathlib import Path -import uuid import pytest from fastapi import APIRouter as _APIRouter @@ -31,6 +31,7 @@ from app.core.exceptions import APIException from app.core.security import create_access_token + MIGRATIONS_DIR = str(Path(__file__).parents[3] / "migrations") API_V1 = "/api/v1" @@ -44,7 +45,9 @@ def _protected_test(vendor_id: CurrentVendorId) -> dict: @pytest.fixture(scope="module") def pg_container() -> typing.Generator[PostgresContainer, None, None]: - with PostgresContainer("postgres:18.2-alpine3.23", driver=None).with_volume_mapping( + with PostgresContainer( + "postgres:18.2-alpine3.23", driver=None + ).with_volume_mapping( MIGRATIONS_DIR, "/docker-entrypoint-initdb.d" ) as container: yield container @@ -71,14 +74,16 @@ def client( """TestClient with a lightweight test app (no real lifespan).""" @asynccontextmanager - async def _noop_lifespan(app: FastAPI): + async def _noop_lifespan(app: FastAPI): # noqa: RUF029 yield test_app = FastAPI(lifespan=_noop_lifespan) test_app.include_router(api_router, prefix=API_V1) test_app.include_router(_test_router, prefix=API_V1) test_app.add_exception_handler(APIException, api_exception_handler) - test_app.add_exception_handler(RequestValidationError, validation_exception_handler) + test_app.add_exception_handler( + RequestValidationError, validation_exception_handler + ) test_app.add_exception_handler(Exception, general_exception_handler) def _override_get_db() -> typing.Generator[Cursor, None, None]: @@ -97,20 +102,32 @@ def _override_get_settings() -> Settings: def _signup( - client: TestClient, email: str = "vendor@test.com", password: str = "SecurePass123!" + client: TestClient, + email: str = "vendor@test.com", + password: str = "SecurePass123!", ) -> dict: return client.post( f"{API_V1}/auth/signup", - json={"email": email, "password": password, "client_id": "integration-test"}, + json={ + "email": email, + "password": password, + "client_id": "integration-test", + }, ).json() def _login( - client: TestClient, email: str = "vendor@test.com", password: str = "SecurePass123!" + client: TestClient, + email: str = "vendor@test.com", + password: str = "SecurePass123!", ) -> dict: return client.post( f"{API_V1}/auth/login", - json={"email": email, "password": password, "client_id": "integration-test"}, + json={ + "email": email, + "password": password, + "client_id": "integration-test", + }, ).json() @@ -150,7 +167,11 @@ def test_signup_duplicate_email_409(self, client: TestClient): def test_signup_weak_password_422(self, client: TestClient): resp = client.post( f"{API_V1}/auth/signup", - json={"email": "weak@example.com", "password": "short", "client_id": "c1"}, + json={ + "email": "weak@example.com", + "password": "short", + "client_id": "c1", + }, ) assert resp.status_code == 422 @@ -170,7 +191,11 @@ def test_login_returns_token_pair(self, client: TestClient): resp = client.post( f"{API_V1}/auth/login", - json={"email": email, "password": "SecurePass123!", "client_id": "c1"}, + json={ + "email": email, + "password": "SecurePass123!", + "client_id": "c1", + }, ) assert resp.status_code == 200 data = resp.json()["data"] @@ -184,7 +209,11 @@ def test_login_wrong_password_401(self, client: TestClient): resp = client.post( f"{API_V1}/auth/login", - json={"email": email, "password": "WrongPassword!", "client_id": "c1"}, + json={ + "email": email, + "password": "WrongPassword!", + "client_id": "c1", + }, ) assert resp.status_code == 401 @@ -252,7 +281,9 @@ def test_missing_token_401(self, client: TestClient): resp = client.get(f"{API_V1}/protected-test") assert resp.status_code == 401 - def test_expired_token_401(self, client: TestClient, test_settings: Settings): + def test_expired_token_401( + self, client: TestClient, test_settings: Settings + ): token = create_access_token( str(uuid.uuid4()), test_settings, diff --git a/backend/tests/unit/test_auth.py b/backend/tests/unit/test_auth.py index 3452bd9..8697d20 100644 --- a/backend/tests/unit/test_auth.py +++ b/backend/tests/unit/test_auth.py @@ -121,7 +121,9 @@ def test_signup_success(self, settings: Settings): return_value={"id": vendor_id, "email": "v@test.com"}, ), ): - result = signup(cursor, "v@test.com", "password123", "client-1", settings) + result = signup( + cursor, "v@test.com", "password123", "client-1", settings + ) assert result.vendor.id == vendor_id assert result.vendor.email == "v@test.com" @@ -131,10 +133,16 @@ def test_signup_duplicate_email_raises(self, settings: Settings): with patch( "app.services.auth.get_vendor_by_email", - return_value={"id": "x", "email": "v@test.com", "password_hash": "h"}, + return_value={ + "id": "x", + "email": "v@test.com", + "password_hash": "h", + }, ): with pytest.raises(ConflictException): - signup(cursor, "v@test.com", "password123", "client-1", settings) + signup( + cursor, "v@test.com", "password123", "client-1", settings + ) def test_login_success(self, settings: Settings): cursor = MagicMock() @@ -149,7 +157,9 @@ def test_login_success(self, settings: Settings): "password_hash": hashed, }, ): - result = login(cursor, "v@test.com", "password123", "client-1", settings) + result = login( + cursor, "v@test.com", "password123", "client-1", settings + ) assert result.access_token assert result.refresh_token @@ -164,7 +174,9 @@ def test_login_wrong_email_raises(self, settings: Settings): with patch("app.services.auth.get_vendor_by_email", return_value=None): with pytest.raises(AuthenticationException): - login(cursor, "bad@test.com", "password123", "client-1", settings) + login( + cursor, "bad@test.com", "password123", "client-1", settings + ) def test_login_wrong_password_raises(self, settings: Settings): cursor = MagicMock() @@ -172,10 +184,16 @@ def test_login_wrong_password_raises(self, settings: Settings): with patch( "app.services.auth.get_vendor_by_email", - return_value={"id": "x", "email": "v@test.com", "password_hash": hashed}, + return_value={ + "id": "x", + "email": "v@test.com", + "password_hash": hashed, + }, ): with pytest.raises(AuthenticationException): - login(cursor, "v@test.com", "wrong-password", "client-1", settings) + login( + cursor, "v@test.com", "wrong-password", "client-1", settings + ) def test_refresh_success(self, settings: Settings): cursor = MagicMock() @@ -215,7 +233,9 @@ def test_refresh_deleted_vendor_raises(self, settings: Settings): rt = create_refresh_token(vendor_id, settings) with patch("app.services.auth.get_vendor_by_id", return_value=None): - with pytest.raises(AuthenticationException, match="Vendor not found"): + with pytest.raises( + AuthenticationException, match="Vendor not found" + ): refresh(rt, "client-1", cursor, settings) def test_signup_concurrent_insert_conflict(self, settings: Settings): @@ -229,7 +249,9 @@ def test_signup_concurrent_insert_conflict(self, settings: Settings): patch("app.services.auth.create_vendor", return_value=None), ): with pytest.raises(ConflictException): - signup(cursor, "race@test.com", "password123", "client-1", settings) + signup( + cursor, "race@test.com", "password123", "client-1", settings + ) @pytest.mark.unit @@ -273,3 +295,13 @@ def test_garbage_token_raises(self, settings: Settings): with pytest.raises(AuthenticationException): get_current_vendor_id(creds, settings) + + def test_malformed_vendor_id_raises(self, settings: Settings): + token = create_access_token("invalid-uuid", settings) + creds = MagicMock() + creds.credentials = token + + with pytest.raises( + AuthenticationException, match="Invalid token payload" + ): + get_current_vendor_id(creds, settings) From 12a12748a068a3fc644594617ab695e862642fca Mon Sep 17 00:00:00 2001 From: saudzahirr Date: Thu, 5 Mar 2026 11:51:06 +0500 Subject: [PATCH 4/7] Refactor middleware request ID function and update JWT dependency version --- backend/app/api/middlewares.py | 5 ++++- backend/pyproject.toml | 2 +- backend/tests/integration/test_auth.py | 6 +++--- uv.lock | 2 +- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/backend/app/api/middlewares.py b/backend/app/api/middlewares.py index de4738d..f0de878 100644 --- a/backend/app/api/middlewares.py +++ b/backend/app/api/middlewares.py @@ -1,9 +1,12 @@ import uuid from fastapi import Request, Response +from starlette.middleware.base import RequestResponseEndpoint -async def add_request_id(request: Request, call_next: object) -> Response: +async def add_request_id( + request: Request, call_next: RequestResponseEndpoint +) -> Response: """Add a unique request_id to each request context. Returns: diff --git a/backend/pyproject.toml b/backend/pyproject.toml index b694553..30b0951 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -13,7 +13,7 @@ dependencies = [ "pydantic-settings<3.0.0,>=2.2.1", "pwdlib[argon2,bcrypt]>=0.3.0", "psycopg-pool>=3.3.0,<4.0.0", - "pyjwt>=2.8.0,<3.0.0", + "pyjwt>=2.11.0,<3.0.0", ] [dependency-groups] diff --git a/backend/tests/integration/test_auth.py b/backend/tests/integration/test_auth.py index e696a2b..b49192d 100644 --- a/backend/tests/integration/test_auth.py +++ b/backend/tests/integration/test_auth.py @@ -43,7 +43,7 @@ def _protected_test(vendor_id: CurrentVendorId) -> dict: return {"vendor_id": vendor_id} -@pytest.fixture(scope="module") +@pytest.fixture(scope="function") def pg_container() -> typing.Generator[PostgresContainer, None, None]: with PostgresContainer( "postgres:18.2-alpine3.23", driver=None @@ -53,7 +53,7 @@ def pg_container() -> typing.Generator[PostgresContainer, None, None]: yield container -@pytest.fixture(scope="module") +@pytest.fixture(scope="function") def test_settings() -> Settings: return Settings( SECRET_KEY="integration-test-secret-key-32bytes!", @@ -67,7 +67,7 @@ def test_settings() -> Settings: ) -@pytest.fixture(scope="module") +@pytest.fixture(scope="function") def client( pg_container: PostgresContainer, test_settings: Settings ) -> typing.Generator[TestClient, None, None]: diff --git a/uv.lock b/uv.lock index 94be5b7..6652a70 100644 --- a/uv.lock +++ b/uv.lock @@ -983,7 +983,7 @@ requires-dist = [ { name = "pwdlib", extras = ["argon2", "bcrypt"], specifier = ">=0.3.0" }, { name = "pydantic", specifier = ">2.0" }, { name = "pydantic-settings", specifier = ">=2.2.1,<3.0.0" }, - { name = "pyjwt", specifier = ">=2.8.0,<3.0.0" }, + { name = "pyjwt", specifier = ">=2.11.0,<3.0.0" }, { name = "python-multipart", specifier = ">=0.0.7,<1.0.0" }, { name = "tenacity", specifier = ">=8.2.3,<9.0.0" }, ] From 3a0c8b16b99d02fdea01f0a58f833895997c25f0 Mon Sep 17 00:00:00 2001 From: saudzahirr Date: Thu, 5 Mar 2026 13:00:36 +0500 Subject: [PATCH 5/7] Enhance request ID validation in middleware and update integration tests for vendor ID retrieval --- backend/app/api/middlewares.py | 23 ++++++++++++++++-- backend/pyproject.toml | 2 +- backend/tests/integration/test_auth.py | 33 +++++++++++++++++++------- 3 files changed, 46 insertions(+), 12 deletions(-) diff --git a/backend/app/api/middlewares.py b/backend/app/api/middlewares.py index f0de878..9e1f785 100644 --- a/backend/app/api/middlewares.py +++ b/backend/app/api/middlewares.py @@ -1,9 +1,25 @@ +import re import uuid from fastapi import Request, Response from starlette.middleware.base import RequestResponseEndpoint +_UUID_RE = re.compile( + r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}" + r"-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$" +) + + +def _is_valid_request_id(value: str) -> bool: + """Check that value is a well-formed UUID string. + + Returns: + bool: True when the value matches the UUID4 hex pattern. + """ + return bool(_UUID_RE.match(value)) + + async def add_request_id( request: Request, call_next: RequestResponseEndpoint ) -> Response: @@ -13,8 +29,11 @@ async def add_request_id( Response: The response with X-Request-ID header. """ raw_header = request.headers.get("X-Request-ID") - stripped_header = raw_header.strip() if raw_header else "" - request_id = stripped_header if stripped_header else str(uuid.uuid4()) + stripped = raw_header.strip() if raw_header else "" + if stripped and _is_valid_request_id(stripped): + request_id = stripped + else: + request_id = str(uuid.uuid4()) request.state.request_id = request_id response = await call_next(request) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 30b0951..35e7e4b 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -56,7 +56,7 @@ select = [ "E", # pycodestyle errors (PEP 8 style guide violations) "W", # pycodestyle warnings (e.g., extra spaces, indentation issues) "DOC", # pydoclint issues (e.g., extra or missing return, yield, warnings) - "A", # flake8-buitins (check variable and function names to not shadow builtins) + "A", # flake8-builtins (check variable and function names to not shadow builtins) "N", # Naming convention checks (e.g., PEP 8 variable and function names) "F", # Pyflakes errors (e.g., unused imports, undefined variables) "I", # isort (Ensures imports are sorted properly) diff --git a/backend/tests/integration/test_auth.py b/backend/tests/integration/test_auth.py index b49192d..4f3c7f6 100644 --- a/backend/tests/integration/test_auth.py +++ b/backend/tests/integration/test_auth.py @@ -20,7 +20,7 @@ from psycopg import Cursor, connect from testcontainers.postgres import PostgresContainer -from app.api.deps import CurrentVendorId, get_db, get_settings +from app.api.deps import CurrentVendorId, CursorDep, get_db, get_settings from app.api.main import api_router from app.core.config import Settings from app.core.exception_handlers import ( @@ -39,8 +39,12 @@ @_test_router.get("/protected-test") -def _protected_test(vendor_id: CurrentVendorId) -> dict: - return {"vendor_id": vendor_id} +def _protected_test(vendor_id: CurrentVendorId, cursor: CursorDep) -> dict: + cursor.execute("SELECT app.set_app_context(%s)", (vendor_id,)) + cursor.execute("SELECT current_setting('app.vendor_id', true)") + row = cursor.fetchone() + db_vendor_id = row[0] if row else None + return {"vendor_id": vendor_id, "db_vendor_id": db_vendor_id} @pytest.fixture(scope="function") @@ -295,14 +299,25 @@ def test_expired_token_401( ) assert resp.status_code == 401 - def test_valid_token_returns_vendor_id( - self, client: TestClient, test_settings: Settings - ): - vendor_id = str(uuid.uuid4()) - token = create_access_token(vendor_id, test_settings) + def test_valid_token_returns_vendor_id(self, client: TestClient): + email = "protected-test@example.com" + _signup(client, email=email) + login_resp = client.post( + f"{API_V1}/auth/login", + json={ + "email": email, + "password": "SecurePass123!", + "client_id": "c1", + }, + ) + assert login_resp.status_code == 200 + token = login_resp.json()["data"]["access_token"] + resp = client.get( f"{API_V1}/protected-test", headers={"Authorization": f"Bearer {token}"}, ) assert resp.status_code == 200 - assert resp.json()["vendor_id"] == vendor_id + body = resp.json() + assert body["vendor_id"] + assert body["db_vendor_id"] == body["vendor_id"] From 10a7615cb1ecc434e58e2a6875011d8a0587a755 Mon Sep 17 00:00:00 2001 From: saudzahirr Date: Thu, 5 Mar 2026 13:45:42 +0500 Subject: [PATCH 6/7] Refactor exception handling and response schemas for improved clarity and consistency; enhance test coverage for error handling --- backend/app/core/exception_handlers.py | 10 +- backend/app/core/exceptions.py | 14 +- backend/app/schemas/response.py | 13 +- backend/tests/api/routes/test_health.py | 3 +- backend/tests/conftest.py | 9 +- backend/tests/core/test_exception_handlers.py | 151 ++++++++++++------ backend/tests/core/test_exceptions.py | 57 ++++--- backend/tests/integration/__init__.py | 0 backend/tests/schemas/test_response.py | 30 ++-- backend/tests/unit/__init__.py | 0 10 files changed, 190 insertions(+), 97 deletions(-) create mode 100644 backend/tests/integration/__init__.py create mode 100644 backend/tests/unit/__init__.py diff --git a/backend/app/core/exception_handlers.py b/backend/app/core/exception_handlers.py index 20f96bd..4fe854a 100644 --- a/backend/app/core/exception_handlers.py +++ b/backend/app/core/exception_handlers.py @@ -131,14 +131,16 @@ def _build_error_details( details: list[dict | ErrorDetail] | dict | ErrorDetail | None, ) -> list[ErrorDetail]: """ - Converts error details (list, dict, ErrorDetail, or None) into a list of ErrorDetail instances. + Convert error details into a list of ErrorDetail instances. + + Accepts a list, dict, ErrorDetail, or None and normalises + into ``list[ErrorDetail]``. Args: - details: Can be a list of dict/ErrorDetail, a single dict, a single ErrorDetail, or None. - Non-dict/non-ErrorDetail entries are converted to string messages. + details: Raw error details in any supported shape. Returns: - List of ErrorDetail instances. Non-dict/non-ErrorDetail entries are preserved as string messages. + list[ErrorDetail]: Normalised detail objects. """ if details is None: return [] diff --git a/backend/app/core/exceptions.py b/backend/app/core/exceptions.py index c027cff..ce52e4b 100644 --- a/backend/app/core/exceptions.py +++ b/backend/app/core/exceptions.py @@ -33,7 +33,7 @@ def __init__( super().__init__( error_code=ErrorCode.VALIDATION_FAILED, message=message, - # Use getattr for compatibility with Starlette <0.48 which lacks HTTP_422_UNPROCESSABLE_CONTENT + # Starlette <0.48 compat: fallback to 422 http_status=getattr(status, "HTTP_422_UNPROCESSABLE_CONTENT", 422), details=details, ) @@ -94,7 +94,7 @@ def __init__( super().__init__( error_code=ErrorCode.BUSINESS_LOGIC_ERROR, message=message, - # Use getattr for compatibility with Starlette <0.48 which lacks HTTP_422_UNPROCESSABLE_CONTENT + # Starlette <0.48 compat: fallback to 422 http_status=getattr(status, "HTTP_422_UNPROCESSABLE_CONTENT", 422), details=details, ) @@ -116,7 +116,9 @@ def __init__( class AuthExpiredException(APIException): """Authentication token expired error""" - def __init__(self, message: str = "Authentication token has expired"): + def __init__( + self, message: str = "Authentication token has expired" + ) -> None: super().__init__( error_code=ErrorCode.AUTH_EXPIRED, message=message, @@ -127,7 +129,7 @@ def __init__(self, message: str = "Authentication token has expired"): class LicenseNotFoundException(APIException): """License not found error""" - def __init__(self, message: str = "License not found"): + def __init__(self, message: str = "License not found") -> None: super().__init__( error_code=ErrorCode.LICENSE_NOT_FOUND, message=message, @@ -138,7 +140,7 @@ def __init__(self, message: str = "License not found"): class LicenseRevokedException(APIException): """License has been revoked error""" - def __init__(self, message: str = "License has been revoked"): + def __init__(self, message: str = "License has been revoked") -> None: super().__init__( error_code=ErrorCode.LICENSE_REVOKED, message=message, @@ -149,7 +151,7 @@ def __init__(self, message: str = "License has been revoked"): class LicenseExpiredException(APIException): """License has expired error""" - def __init__(self, message: str = "License has expired"): + def __init__(self, message: str = "License has expired") -> None: super().__init__( error_code=ErrorCode.LICENSE_EXPIRED, message=message, diff --git a/backend/app/schemas/response.py b/backend/app/schemas/response.py index 171920c..6dd9645 100644 --- a/backend/app/schemas/response.py +++ b/backend/app/schemas/response.py @@ -60,13 +60,20 @@ class ErrorBodyResponse(BaseModel): code: ErrorCode = Field(..., description="Error code") message: str = Field(..., description="Human-readable error message") http_status: int = Field( - ..., ge=400, le=599, description="HTTP status code matching the response" + ..., + ge=400, + le=599, + description="HTTP status code matching the response", ) details: list[ErrorDetail] = Field( default_factory=list, - description="Additional error details for validation errors or other context", + description=( + "Additional error details for validation errors or other context" + ), + ) + request_id: str = Field( + ..., description="Unique request identifier for tracing" ) - request_id: str = Field(..., description="Unique request identifier for tracing") class ErrorResponse(BaseModel): diff --git a/backend/tests/api/routes/test_health.py b/backend/tests/api/routes/test_health.py index d371700..121d4b8 100644 --- a/backend/tests/api/routes/test_health.py +++ b/backend/tests/api/routes/test_health.py @@ -26,7 +26,8 @@ def test_health_check_returns_ok(): f"Expected status to be 'ok', got '{data['data']['status']}'" ) assert "timestamp" in data["data"], ( - f"Expected 'timestamp' field in health response data, got keys: {data['data'].keys()}" + "Expected 'timestamp' field in health response data," + f" got keys: {data['data'].keys()}" ) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 9a3a0d5..af1d56c 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -13,9 +13,6 @@ from pathlib import Path import pytest - -from app.api.deps import get_db -from app.main import app from psycopg import Cursor, connect from testcontainers.postgres import PostgresContainer @@ -28,8 +25,10 @@ @pytest.fixture(scope="module") def test_container() -> typing.Generator[PostgresContainer, None, None]: - """Fixture that yields a Testcontainers Postgres (PostgresContainer) with migrations applied.""" - with PostgresContainer("postgres:18.2-alpine3.23", driver=None).with_volume_mapping( + """Fixture: Testcontainers Postgres with migrations.""" + with PostgresContainer( + "postgres:18.2-alpine3.23", driver=None + ).with_volume_mapping( MIGRATIONS_DIR, "/docker-entrypoint-initdb.d" ) as container: yield container diff --git a/backend/tests/core/test_exception_handlers.py b/backend/tests/core/test_exception_handlers.py index 7005395..7f2f45d 100644 --- a/backend/tests/core/test_exception_handlers.py +++ b/backend/tests/core/test_exception_handlers.py @@ -1,15 +1,20 @@ """ -Integration tests for exception handlers in app.core.exception_handlers. +Integration tests for exception handlers. -Verifies that every handler (api_exception_handler, validation_exception_handler, -general_exception_handler) produces responses that conform to the error contract: -correct HTTP status codes, error codes, message round-tripping, details structure, -request-id presence/format, and uniqueness. +Verifies that every handler produces responses that conform +to the error contract: correct HTTP status codes, error codes, +message round-tripping, details structure, request-id +presence/format, and uniqueness. """ import re import pytest +from fastapi import FastAPI, status +from fastapi.exceptions import RequestValidationError +from fastapi.testclient import TestClient +from pydantic import BaseModel + from app.core.exception_handlers import ( api_exception_handler, general_exception_handler, @@ -29,10 +34,7 @@ ServiceUnavailableException, ValidationException, ) -from fastapi import FastAPI, status -from fastapi.exceptions import RequestValidationError -from fastapi.testclient import TestClient -from pydantic import BaseModel + # Pre-compiled UUID v4 pattern reused across request-id assertions _UUID_V4_RE = re.compile( @@ -42,11 +44,13 @@ @pytest.fixture(scope="module") def error_contract_app(): - """FastAPI app for error contract handler tests (all exception handlers, one endpoint per error type).""" + """FastAPI app wired with all exception handlers.""" app = FastAPI() app.add_exception_handler(APIException, api_exception_handler) - app.add_exception_handler(RequestValidationError, validation_exception_handler) + app.add_exception_handler( + RequestValidationError, validation_exception_handler + ) app.add_exception_handler(Exception, general_exception_handler) @app.get("/auth-invalid") @@ -212,7 +216,7 @@ async def validate_nested(data: NestedBodyRequest): def test_error_status_code_and_code_match( error_contract_app, endpoint, expected_code, expected_status ): - """HTTP status code and error code field must match the expected values for every exception type.""" + """HTTP status and error code must match for every type.""" client = TestClient(error_contract_app, raise_server_exceptions=False) response = client.get(endpoint) @@ -222,12 +226,14 @@ def test_error_status_code_and_code_match( ) error = response.json()["error"] assert error["code"] == expected_code, ( - f"Expected error code '{expected_code}' for {endpoint}, got '{error['code']}'" + f"Expected error code '{expected_code}' for " + f"{endpoint}, got '{error['code']}'" ) assert error["http_status"] == expected_status, ( - f"Expected error http_status {expected_status}, got {error['http_status']}" + f"Expected error http_status {expected_status}," + f" got {error['http_status']}" ) - # Explicitly validate wire vs body contract: response status must match error http_status + # Wire vs body contract: status must match assert response.status_code == error["http_status"], ( f"Response status code {response.status_code} does not match " f"error http_status {error['http_status']}" @@ -241,7 +247,7 @@ def test_error_status_code_and_code_match( @pytest.mark.integration def test_error_details_structure(error_contract_app): - """Details list must contain exact field+message pairs from the raised exception.""" + """Details list must contain exact field+message pairs.""" client = TestClient(error_contract_app, raise_server_exceptions=False) response = client.get("/validation") @@ -250,7 +256,8 @@ def test_error_details_structure(error_contract_app): f"Expected 'details' field in error response, got keys: {error.keys()}" ) assert len(error["details"]) == 2, ( - f"Expected 2 detail entries, got {len(error['details'])}: {error['details']}" + f"Expected 2 detail entries, got " + f"{len(error['details'])}: {error['details']}" ) detail1 = error["details"][0] @@ -267,13 +274,14 @@ def test_error_details_structure(error_contract_app): f"Expected second detail field 'password', got '{detail2['field']}'" ) assert detail2["message"] == "Too short", ( - f"Expected second detail message 'Too short', got '{detail2['message']}'" + "Expected second detail message 'Too short'," + f" got '{detail2['message']}'" ) @pytest.mark.integration def test_error_details_default_empty(error_contract_app): - """Exceptions raised without details must produce an empty details list in the response.""" + """Exceptions without details must yield empty details.""" client = TestClient(error_contract_app, raise_server_exceptions=False) response = client.get("/auth-invalid") @@ -282,7 +290,8 @@ def test_error_details_default_empty(error_contract_app): f"Expected 'details' field in error response, got keys: {error.keys()}" ) assert error["details"] == [], ( - f"Expected empty details list for exception without details, got {error['details']}" + "Expected empty details list for exception" + f" without details, got {error['details']}" ) @@ -303,7 +312,10 @@ def test_validation_handler_nested_field_path(error_contract_app): assert response.status_code == getattr( status, "HTTP_422_UNPROCESSABLE_CONTENT", 422 - ), f"Expected status 422 for nested validation error, got {response.status_code}" + ), ( + "Expected status 422 for nested validation" + f" error, got {response.status_code}" + ) details = response.json()["error"]["details"] fields = [d["field"] for d in details] assert "address.street" in fields, ( @@ -312,12 +324,15 @@ def test_validation_handler_nested_field_path(error_contract_app): @pytest.mark.integration -def test_validation_handler_body_level_error_sets_field_to_none(error_contract_app): - """When Pydantic emits a body-level error (loc has only one element after slicing), - validation_exception_handler must produce field=None in the detail. +def test_validation_handler_body_level_error_sets_field_to_none( + error_contract_app, +): + """Body-level error (loc with one element after slicing) + must produce field=None in the detail. - Sending an integer body against an object-typed endpoint triggers - loc=["body"], making loc[1:] empty and field_path="", so field=None. + Sending an integer body against an object-typed endpoint + triggers loc=["body"], making loc[1:] empty and + field_path="", so field=None. """ client = TestClient(error_contract_app, raise_server_exceptions=False) response = client.post("/validate-body", json=5) @@ -325,13 +340,17 @@ def test_validation_handler_body_level_error_sets_field_to_none(error_contract_a assert response.status_code == getattr( status, "HTTP_422_UNPROCESSABLE_CONTENT", 422 ), ( - f"Expected status 422 for body-level validation error, got {response.status_code}" + "Expected status 422 for body-level" + f" validation error, got {response.status_code}" ) details = response.json()["error"]["details"] - assert len(details) > 0, "Expected at least one validation detail, got empty list" + assert len(details) > 0, ( + "Expected at least one validation detail, got empty list" + ) # At least one detail must have field=None from the body-level loc assert any(d["field"] is None for d in details), ( - f"Expected at least one detail with field=None for a body-level error, got {details}" + "Expected at least one detail with" + f" field=None for body-level error, got {details}" ) @@ -344,7 +363,9 @@ def test_validation_handler_body_level_error_sets_field_to_none(error_contract_a @pytest.mark.parametrize( "endpoint,method,json_body,raise_exceptions", [ - pytest.param("/auth-invalid", "get", None, True, id="api_exception_handler"), + pytest.param( + "/auth-invalid", "get", None, True, id="api_exception_handler" + ), pytest.param( "/validate-body", "post", @@ -360,21 +381,25 @@ def test_validation_handler_body_level_error_sets_field_to_none(error_contract_a def test_request_id_header_matches_body_and_is_uuid_v4( error_contract_app, endpoint, method, json_body, raise_exceptions ): - """X-Request-ID header must be present, match the body request_id, and be a valid UUID v4.""" - client = TestClient(error_contract_app, raise_server_exceptions=raise_exceptions) + """X-Request-ID header must match body and be UUID v4.""" + client = TestClient( + error_contract_app, raise_server_exceptions=raise_exceptions + ) request_method = getattr(client, method) kwargs = {"json": json_body} if json_body is not None else {} response = request_method(endpoint, **kwargs) assert "X-Request-ID" in response.headers, ( - f"Expected 'X-Request-ID' header in response, got headers: {list(response.headers.keys())}" + "Expected 'X-Request-ID' header, got" + f" headers: {list(response.headers.keys())}" ) request_id = response.headers["X-Request-ID"] assert request_id, "X-Request-ID header must not be empty" error = response.json()["error"] assert request_id == error["request_id"], ( - f"Header X-Request-ID '{request_id}' does not match body request_id '{error['request_id']}'" + f"Header X-Request-ID '{request_id}' does not" + f" match body request_id '{error['request_id']}'" ) assert _UUID_V4_RE.match(request_id), ( f"request_id '{request_id}' is not a valid UUID v4 format" @@ -383,13 +408,17 @@ def test_request_id_header_matches_body_and_is_uuid_v4( @pytest.mark.integration def test_request_id_uniqueness_across_requests(error_contract_app): - """Repeated requests to the same endpoint must each receive a distinct request_id.""" + """Repeated requests must each get a distinct request_id.""" client = TestClient(error_contract_app, raise_server_exceptions=False) - ids = {client.get("/auth-invalid").json()["error"]["request_id"] for _ in range(5)} + ids = { + client.get("/auth-invalid").json()["error"]["request_id"] + for _ in range(5) + } assert len(ids) == 5, ( - f"Expected 5 unique request_ids from 5 requests to the same endpoint, " - f"got {len(ids)} unique IDs. request_id may not be regenerated per request." + "Expected 5 unique request_ids from 5 requests," + f" got {len(ids)} unique IDs." + " request_id may not be regenerated per request." ) @@ -415,8 +444,12 @@ def test_api_exception_handler_returns_correct_structure(error_contract_app): @pytest.mark.integration -def test_validation_exception_handler_returns_correct_structure(error_contract_app): - """validation_exception_handler must return VALIDATION_FAILED with populated details.""" +def test_validation_exception_handler_returns_correct_structure( + error_contract_app, +): + """validation_exception_handler must return + VALIDATION_FAILED with populated details. + """ client = TestClient(error_contract_app, raise_server_exceptions=False) response = client.post("/validate-body", json={"invalid": "data"}) @@ -442,13 +475,16 @@ def test_validation_exception_handler_returns_correct_structure(error_contract_a f"Expected 'message' key in detail, got keys: {detail.keys()}" ) assert isinstance(detail["message"], str), ( - f"Expected detail message to be string, got {type(detail['message'])}" + "Expected detail message to be string," + f" got {type(detail['message'])}" ) @pytest.mark.integration def test_general_exception_handler_returns_sanitized_500(error_contract_app): - """general_exception_handler must return 500 with a sanitized message, never internal details.""" + """general_exception_handler must return 500 + with a sanitized message, never internal details. + """ client = TestClient(error_contract_app, raise_server_exceptions=False) response = client.get("/general-error") @@ -474,7 +510,9 @@ def test_general_exception_handler_returns_sanitized_500(error_contract_app): @pytest.mark.integration def test_validation_error_message_and_details(error_contract_app): - """validation_exception_handler must return the standard message with non-empty detail messages.""" + """validation_exception_handler must return the + standard message with non-empty detail messages. + """ client = TestClient(error_contract_app, raise_server_exceptions=False) response = client.post("/validate-body", json={"invalid": "data"}) @@ -490,7 +528,8 @@ def test_validation_error_message_and_details(error_contract_app): "Expected non-empty message in detail, got empty string" ) assert isinstance(detail["message"], str), ( - f"Expected detail message to be string, got {type(detail['message'])}" + "Expected detail message to be string," + f" got {type(detail['message'])}" ) @@ -501,24 +540,36 @@ def test_validation_error_message_and_details(error_contract_app): pytest.param("/auth-invalid", "Invalid credentials", id="auth_invalid"), pytest.param("/auth-expired", "Token has expired", id="auth_expired"), pytest.param("/forbidden", "Access denied", id="forbidden"), - pytest.param("/not-found", "Resource not found", id="resource_not_found"), + pytest.param( + "/not-found", "Resource not found", id="resource_not_found" + ), pytest.param("/conflict", "Resource conflict", id="resource_conflict"), pytest.param("/validation", "Validation failed", id="validation_error"), pytest.param( - "/business-logic", "Cannot process request", id="business_logic_error" + "/business-logic", + "Cannot process request", + id="business_logic_error", ), pytest.param( "/service-unavailable", "Database is down", id="service_unavailable" ), - pytest.param("/license-not-found", "No active license", id="license_not_found"), - pytest.param("/license-revoked", "License was revoked", id="license_revoked"), - pytest.param("/license-expired", "License has expired", id="license_expired"), + pytest.param( + "/license-not-found", "No active license", id="license_not_found" + ), + pytest.param( + "/license-revoked", "License was revoked", id="license_revoked" + ), + pytest.param( + "/license-expired", "License has expired", id="license_expired" + ), ], ) def test_raised_error_messages_roundtrip_correctly( error_contract_app, endpoint, expected_message ): - """The exact message passed to the exception constructor must appear in the response.""" + """The message passed to the constructor must + appear in the response. + """ client = TestClient(error_contract_app, raise_server_exceptions=False) response = client.get(endpoint) error = response.json()["error"] diff --git a/backend/tests/core/test_exceptions.py b/backend/tests/core/test_exceptions.py index 68609a9..732e4d9 100644 --- a/backend/tests/core/test_exceptions.py +++ b/backend/tests/core/test_exceptions.py @@ -1,8 +1,9 @@ """ Unit tests for exception classes in app.core.exceptions. -Covers: error codes, HTTP status codes, base-class parameter storage, -details defaulting, custom details, custom messages, and default message lengths. +Covers: error codes, HTTP status codes, base-class parameter +storage, details defaulting, custom details, custom messages, +and default message lengths. """ import pytest @@ -136,9 +137,7 @@ def test_exception_has_correct_error_code(exception_class, expected_code): id="not_found_exception", ), pytest.param( - ConflictException, - status.HTTP_409_CONFLICT, - id="conflict_exception", + ConflictException, status.HTTP_409_CONFLICT, id="conflict_exception" ), pytest.param( BusinessLogicException, @@ -184,7 +183,7 @@ def test_exception_has_correct_http_status(exception_class, expected_status): @pytest.mark.unit def test_api_exception_base_class_stores_all_parameters(): - """APIException must store all constructor arguments and pass message to str().""" + """APIException must store all constructor args.""" exc = APIException( error_code=ErrorCode.VALIDATION_FAILED, message="Test error", @@ -204,7 +203,7 @@ def test_api_exception_base_class_stores_all_parameters(): f"Expected details with test field, got {exc.details}" ) assert str(exc) == "Test error", ( - f"Expected str(exc) to return message, got '{str(exc)}'" + f"Expected str(exc) to return message, got '{exc!s}'" ) @@ -271,10 +270,16 @@ def test_exception_with_details_none_defaults_to_empty_list(exception_class): ], ) def test_exception_stores_custom_details(exception_class, message, details): - """Exceptions must store the exact message and details passed at construction.""" + """Exceptions must store the exact message and + details passed at construction. + """ exc = exception_class(message=message, details=details) - assert exc.message == message, f"Expected message '{message}', got '{exc.message}'" - assert exc.details == details, f"Expected details {details}, got {exc.details}" + assert exc.message == message, ( + f"Expected message '{message}', got '{exc.message}'" + ) + assert exc.details == details, ( + f"Expected details {details}, got {exc.details}" + ) # --------------------------------------------------------------------------- @@ -291,8 +296,12 @@ def test_exception_stores_custom_details(exception_class, message, details): pytest.param(AuthorizationException, id="authorization_exception"), pytest.param(NotFoundException, id="not_found_exception"), pytest.param(ConflictException, id="conflict_exception"), - pytest.param(ServiceUnavailableException, id="service_unavailable_exception"), - pytest.param(LicenseNotFoundException, id="license_not_found_exception"), + pytest.param( + ServiceUnavailableException, id="service_unavailable_exception" + ), + pytest.param( + LicenseNotFoundException, id="license_not_found_exception" + ), pytest.param(LicenseRevokedException, id="license_revoked_exception"), pytest.param(LicenseExpiredException, id="license_expired_exception"), ], @@ -339,8 +348,12 @@ def test_parameterized_exception_accepts_custom_message(exception_class): pytest.param(NotFoundException, id="not_found_exception"), pytest.param(ConflictException, id="conflict_exception"), pytest.param(BusinessLogicException, id="business_logic_exception"), - pytest.param(ServiceUnavailableException, id="service_unavailable_exception"), - pytest.param(LicenseNotFoundException, id="license_not_found_exception"), + pytest.param( + ServiceUnavailableException, id="service_unavailable_exception" + ), + pytest.param( + LicenseNotFoundException, id="license_not_found_exception" + ), pytest.param(LicenseRevokedException, id="license_revoked_exception"), pytest.param(LicenseExpiredException, id="license_expired_exception"), ], @@ -380,16 +393,22 @@ def test_api_exception_is_subclass_of_exception(): pytest.param(ConflictException, id="conflict_exception"), pytest.param(ValidationException, id="validation_exception"), pytest.param(BusinessLogicException, id="business_logic_exception"), - pytest.param(ServiceUnavailableException, id="service_unavailable_exception"), - pytest.param(LicenseNotFoundException, id="license_not_found_exception"), + pytest.param( + ServiceUnavailableException, id="service_unavailable_exception" + ), + pytest.param( + LicenseNotFoundException, id="license_not_found_exception" + ), pytest.param(LicenseRevokedException, id="license_revoked_exception"), pytest.param(LicenseExpiredException, id="license_expired_exception"), ], ) def test_concrete_exception_is_subclass_of_api_exception(exception_class): - """Every concrete exception class must be a subclass of APIException so that - api_exception_handler catches it. If this invariant breaks, those exceptions - silently fall through to general_exception_handler and return 500s.""" + """Every concrete exception must subclass + APIException so api_exception_handler catches it. + If this invariant breaks, those exceptions silently + fall through to general_exception_handler (500s). + """ assert issubclass(exception_class, APIException), ( f"{exception_class.__name__} is not a subclass of APIException; " f"api_exception_handler will not catch it" diff --git a/backend/tests/integration/__init__.py b/backend/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/schemas/test_response.py b/backend/tests/schemas/test_response.py index 655ad99..e8da933 100644 --- a/backend/tests/schemas/test_response.py +++ b/backend/tests/schemas/test_response.py @@ -41,7 +41,9 @@ def test_error_body_response_stores_all_fields(): assert error_body.http_status == status.HTTP_400_BAD_REQUEST, ( f"Expected http_status 400, got {error_body.http_status}" ) - assert error_body.details == [], f"Expected empty details, got {error_body.details}" + assert error_body.details == [], ( + f"Expected empty details, got {error_body.details}" + ) assert error_body.request_id == "req-123", ( f"Expected request_id 'req-123', got '{error_body.request_id}'" ) @@ -70,13 +72,15 @@ def test_error_body_response_stores_all_fields(): ], ) def test_error_body_response_required_field_raises_validation_error(kwargs): - """Omitting a required field on ErrorBodyResponse must raise ValidationError.""" + """Omitting a required field must raise ValidationError.""" with pytest.raises(ValidationError) as exc_info: ErrorBodyResponse(**kwargs) # Optionally verify the missing field is in the error error_fields = {e["loc"][0] for e in exc_info.value.errors()} expected_missing = {"http_status", "request_id"} - kwargs.keys() - assert expected_missing & error_fields, f"Expected error for {expected_missing}" + assert expected_missing & error_fields, ( + f"Expected error for {expected_missing}" + ) @pytest.mark.unit @@ -110,10 +114,12 @@ def test_error_response_envelope_structure(): ) ) assert isinstance(error_response.error, ErrorBodyResponse), ( - f"Expected error field to be ErrorBodyResponse, got {type(error_response.error)}" + "Expected error field to be ErrorBodyResponse," + f" got {type(error_response.error)}" ) assert error_response.error.http_status == status.HTTP_400_BAD_REQUEST, ( - f"Expected wrapped error http_status 400, got {error_response.error.http_status}" + "Expected wrapped error http_status 400," + f" got {error_response.error.http_status}" ) @@ -133,7 +139,9 @@ def test_error_detail_requires_message_field(): def test_error_detail_accepts_field_as_none(): """ErrorDetail.field is optional — None must be stored as-is.""" detail = ErrorDetail(field=None, message="Test error") - assert detail.field is None, f"Expected field to be None, got {detail.field}" + assert detail.field is None, ( + f"Expected field to be None, got {detail.field}" + ) assert detail.message == "Test error", ( f"Expected message 'Test error', got '{detail.message}'" ) @@ -143,7 +151,9 @@ def test_error_detail_accepts_field_as_none(): def test_error_detail_with_both_fields_populated(): """ErrorDetail must store both field and message when both are provided.""" detail = ErrorDetail(field="email", message="Invalid email format") - assert detail.field == "email", f"Expected field 'email', got '{detail.field}'" + assert detail.field == "email", ( + f"Expected field 'email', got '{detail.field}'" + ) assert detail.message == "Invalid email format", ( f"Expected message 'Invalid email format', got '{detail.message}'" ) @@ -174,7 +184,9 @@ def test_error_body_response_rejects_invalid_error_code(): @pytest.mark.unit def test_error_response_requires_error_field(): - """ErrorResponse.error is a required field (Field(...)). Omitting it must - raise a ValidationError rather than silently constructing an empty envelope.""" + """ErrorResponse.error is required. Omitting it must + raise a ValidationError rather than silently + constructing an empty envelope. + """ with pytest.raises(ValidationError): ErrorResponse() diff --git a/backend/tests/unit/__init__.py b/backend/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 From e20c32805e2442fde736562d8c089cf1ceb7aad9 Mon Sep 17 00:00:00 2001 From: saudzahirr Date: Thu, 5 Mar 2026 14:05:06 +0500 Subject: [PATCH 7/7] Refactor authentication routes by removing login endpoint; update vendor creation SQL to handle email conflicts correctly; enhance integration tests with RLS cursor dependency --- backend/app/api/main.py | 3 +-- backend/app/crud/vendor.py | 2 +- backend/tests/integration/test_auth.py | 28 ++++++++++++++++++++------ 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/backend/app/api/main.py b/backend/app/api/main.py index 46d462c..8272777 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -1,9 +1,8 @@ from fastapi import APIRouter -from app.api.routes import auth, health, login +from app.api.routes import auth, health api_router = APIRouter() api_router.include_router(health.router, tags=["health"]) -api_router.include_router(login.router, prefix="/login", tags=["login"]) api_router.include_router(auth.router, prefix="/auth", tags=["auth"]) diff --git a/backend/app/crud/vendor.py b/backend/app/crud/vendor.py index 57c2774..318504c 100644 --- a/backend/app/crud/vendor.py +++ b/backend/app/crud/vendor.py @@ -59,7 +59,7 @@ def create_vendor( cursor.execute( 'INSERT INTO app."vendors" ("email", "password_hash") ' "VALUES (%s, %s) " - 'ON CONFLICT (LOWER("email")) DO NOTHING ' + 'ON CONFLICT ((LOWER("email"))) DO NOTHING ' 'RETURNING "id", "email"', (email, password_hash), ) diff --git a/backend/tests/integration/test_auth.py b/backend/tests/integration/test_auth.py index 4f3c7f6..e2c340b 100644 --- a/backend/tests/integration/test_auth.py +++ b/backend/tests/integration/test_auth.py @@ -20,7 +20,13 @@ from psycopg import Cursor, connect from testcontainers.postgres import PostgresContainer -from app.api.deps import CurrentVendorId, CursorDep, get_db, get_settings +from app.api.deps import ( + CurrentVendorId, + RLSCursorDep, + get_db, + get_rls_cursor, + get_settings, +) from app.api.main import api_router from app.core.config import Settings from app.core.exception_handlers import ( @@ -39,8 +45,7 @@ @_test_router.get("/protected-test") -def _protected_test(vendor_id: CurrentVendorId, cursor: CursorDep) -> dict: - cursor.execute("SELECT app.set_app_context(%s)", (vendor_id,)) +def _protected_test(vendor_id: CurrentVendorId, cursor: RLSCursorDep) -> dict: cursor.execute("SELECT current_setting('app.vendor_id', true)") row = cursor.fetchone() db_vendor_id = row[0] if row else None @@ -101,6 +106,16 @@ def _override_get_settings() -> Settings: test_app.dependency_overrides[get_db] = _override_get_db test_app.dependency_overrides[get_settings] = _override_get_settings + def _override_get_rls_cursor( + vendor_id: CurrentVendorId, + ) -> typing.Generator[Cursor, None, None]: + with connect(pg_container.get_connection_url()) as conn: + with conn.cursor() as cur: + cur.execute("SELECT app.set_app_context(%s)", (vendor_id,)) + yield cur + + test_app.dependency_overrides[get_rls_cursor] = _override_get_rls_cursor + with TestClient(test_app) as tc: yield tc @@ -301,7 +316,8 @@ def test_expired_token_401( def test_valid_token_returns_vendor_id(self, client: TestClient): email = "protected-test@example.com" - _signup(client, email=email) + signup_data = _signup(client, email=email) + created_vendor_id = signup_data["data"]["vendor"]["id"] login_resp = client.post( f"{API_V1}/auth/login", json={ @@ -319,5 +335,5 @@ def test_valid_token_returns_vendor_id(self, client: TestClient): ) assert resp.status_code == 200 body = resp.json() - assert body["vendor_id"] - assert body["db_vendor_id"] == body["vendor_id"] + assert body["vendor_id"] == created_vendor_id + assert body["db_vendor_id"] == created_vendor_id