From bac0f459d2518dfacfb04c9391d398e7214e7c4f Mon Sep 17 00:00:00 2001 From: Code With Me Date: Sun, 15 Mar 2026 23:14:13 +0300 Subject: [PATCH] feat: RBAC implementation, ownership protection and unification of services --- .github/workflows/ci.yaml | 4 +- app/core/exception_handlers.py | 10 + app/core/exceptions.py | 7 + app/core/security.py | 37 ++++ app/core/setup.py | 6 + app/services/inventory/models.py | 14 ++ app/services/inventory/routes.py | 125 ++++++++++- app/services/inventory/schemas.py | 33 +++ app/services/inventory/service.py | 182 ++++++++++++---- app/services/orders/models.py | 14 +- app/services/orders/routes.py | 37 ++-- app/services/orders/service.py | 205 ++++++++---------- app/services/user/models.py | 14 +- app/services/user/schemas.py | 3 + .../8e607d73b2b3_add_product_owner_id.py | 39 ++++ ...02cfb988689_add_rbac_and_product_status.py | 70 ++++++ tests/conftest.py | 6 + tests/test_arq_expiry.py | 13 +- tests/test_concurrency_inventory.py | 21 +- tests/test_integration_orders.py | 15 +- tests/test_integration_reserve.py | 29 ++- 21 files changed, 664 insertions(+), 220 deletions(-) create mode 100644 migrations/versions/8e607d73b2b3_add_product_owner_id.py create mode 100644 migrations/versions/902cfb988689_add_rbac_and_product_status.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 32297e2..3aced35 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -41,10 +41,10 @@ jobs: # ===== steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up UV (without pip) - uses: astral-sh/setup-uv@v5 + uses: astral-sh/setup-uv@v7 with: enable-cache: true diff --git a/app/core/exception_handlers.py b/app/core/exception_handlers.py index 62c1d5d..87e1209 100644 --- a/app/core/exception_handlers.py +++ b/app/core/exception_handlers.py @@ -6,6 +6,7 @@ CredentialsError, InsufficientInventoryError, NotFoundError, + PermissionDeniedError, UserAlreadyExists, ) @@ -56,3 +57,12 @@ async def conflict_error_handler(request: Request, exc: ConflictError) -> JSONRe status_code=status.HTTP_409_CONFLICT, content={'detail': str(exc) or CONFLICT_MESSAGE}, ) + + +async def permission_denied_handler( + request: Request, exc: PermissionDeniedError +) -> JSONResponse: + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={'detail': str(exc) or 'Permission denied'}, + ) diff --git a/app/core/exceptions.py b/app/core/exceptions.py index 5d40866..e6acc45 100644 --- a/app/core/exceptions.py +++ b/app/core/exceptions.py @@ -37,3 +37,10 @@ class InsufficientInventoryError(AppError): def __init__(self, message: str = 'Insufficient inventory'): super().__init__(message=message) + + +class PermissionDeniedError(AppError): + """Permission denied.""" + + def __init__(self, message: str = 'Permission denied'): + super().__init__(message=message) diff --git a/app/core/security.py b/app/core/security.py index 3dcca39..49d4570 100644 --- a/app/core/security.py +++ b/app/core/security.py @@ -1,10 +1,15 @@ import asyncio from datetime import UTC, datetime, timedelta +from typing import Any +from fastapi import Depends from jose import jwt from passlib.context import CryptContext from app.core.config import settings +from app.core.exceptions import PermissionDeniedError +from app.services.user.models import User, UserRole +from app.shared.deps import get_current_user pwd_context = CryptContext(schemes=['bcrypt'], deprecated='auto') @@ -40,3 +45,35 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None) -> s to_encode, settings.secret_key, algorithm=settings.jwt_algorithm ) return str(encoded_jwt) + + +async def check_permission( + user: User, allowed_roles: list[UserRole], required_verified: bool = False +) -> None: + if user.role == UserRole.ADMIN: + return + if required_verified and not user.is_verified: + raise PermissionDeniedError('User is not verified') + if user.role not in allowed_roles: + raise PermissionDeniedError( + 'User does not have permission to perform this action' + ) + + +def check_ownership(user: User, obj: Any) -> None: + if user.role in (UserRole.ADMIN, UserRole.MODERATOR): + return + if not hasattr(obj, 'owner_id'): + raise ValueError(f'Object {type(obj)} does not have owner_id') + if obj.owner_id != user.id: + raise PermissionDeniedError + + +class RoleChecker: + def __init__(self, allowed_roles: list[UserRole], required_verified: bool = False): + self.allowed_roles = allowed_roles + self.required_verified = required_verified + + async def __call__(self, user: User = Depends(get_current_user)) -> User: + await check_permission(user, self.allowed_roles, self.required_verified) + return user diff --git a/app/core/setup.py b/app/core/setup.py index 221646b..104cc82 100644 --- a/app/core/setup.py +++ b/app/core/setup.py @@ -5,6 +5,7 @@ credentials_error_handler, insufficient_inventory_error_handler, not_found_error_handler, + permission_denied_handler, user_already_exists_handler, ) from .exceptions import ( @@ -12,6 +13,7 @@ CredentialsError, InsufficientInventoryError, NotFoundError, + PermissionDeniedError, UserAlreadyExists, ) @@ -25,3 +27,7 @@ def setup_exception_handlers(app: FastAPI) -> None: InsufficientInventoryError, insufficient_inventory_error_handler, # type: ignore[arg-type] ) + app.add_exception_handler( + PermissionDeniedError, + permission_denied_handler, # type: ignore[arg-type] + ) diff --git a/app/services/inventory/models.py b/app/services/inventory/models.py index cedf687..659f591 100644 --- a/app/services/inventory/models.py +++ b/app/services/inventory/models.py @@ -1,8 +1,10 @@ from datetime import datetime from decimal import Decimal +from enum import StrEnum from uuid import UUID, uuid4 from sqlalchemy import CheckConstraint, ForeignKey, Numeric +from sqlalchemy import Enum as SQLEnum from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.sql import func from sqlalchemy.types import DateTime, Integer, String, Text @@ -13,10 +15,19 @@ DECIMAL_SCALE = 2 +class ProductStatus(StrEnum): + DRAFT = 'DRAFT' + ACTIVE = 'ACTIVE' + ARCHIVED = 'ARCHIVED' + + class Product(Base): __tablename__ = 'products' id: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4) + owner_id: Mapped[UUID] = mapped_column( + ForeignKey('users.id'), nullable=False, index=True + ) name: Mapped[str] = mapped_column(String(), nullable=False) description: Mapped[str | None] = mapped_column(Text(), nullable=True) price: Mapped[Decimal] = mapped_column( @@ -31,6 +42,9 @@ class Product(Base): updated_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now() ) + status: Mapped[ProductStatus] = mapped_column( + SQLEnum(ProductStatus), nullable=False, default=ProductStatus.DRAFT + ) __table_args__ = ( CheckConstraint('qty_available >= 0', name='check_qty_non_negative'), diff --git a/app/services/inventory/routes.py b/app/services/inventory/routes.py index 705db4a..f4e6a61 100644 --- a/app/services/inventory/routes.py +++ b/app/services/inventory/routes.py @@ -1,16 +1,133 @@ -from fastapi import APIRouter, Depends, Header, Request +from typing import Annotated, Any +from uuid import UUID + +from fastapi import APIRouter, Depends, Header, Request, status from sqlalchemy.ext.asyncio import AsyncSession from app.core.database import get_session +from app.core.security import RoleChecker, UserRole +from app.services.inventory.models import ProductStatus from app.services.inventory.rate_limit import check_rate_limit -from app.services.inventory.schemas import ReservationCreate, ReservationResponse -from app.services.inventory.service import reserve_items +from app.services.inventory.schemas import ( + ProductCreate, + ProductRead, + ProductUpdate, + ReservationCreate, + ReservationResponse, +) +from app.services.inventory.service import InventoryService from app.services.user.models import User from app.shared.decorators import idempotent from app.shared.deps import get_current_user router_v1 = APIRouter(prefix='/inventory', tags=['Inventory']) +SELLER_DEPENDENCY = Depends( + RoleChecker( + allowed_roles=[UserRole.SELLER, UserRole.SELLER_B2B], + required_verified=True, + ) +) +ADMIN_DEPENDENCY = Depends( + RoleChecker( + allowed_roles=[UserRole.ADMIN, UserRole.MODERATOR], + ) +) +ADMIN_AND_SELLER_DEPENDENCY = Depends( + RoleChecker( + allowed_roles=[ + UserRole.ADMIN, + UserRole.MODERATOR, + UserRole.SELLER, + UserRole.SELLER_B2B, + ], + ) +) + + +@router_v1.get('/', response_model=list[ProductRead]) +async def get_active_products( + session: Annotated[AsyncSession, Depends(get_session)], + skip: int = 0, + limit: int = 50, +) -> list[ProductRead]: + products = await InventoryService.get_products( + status=ProductStatus.ACTIVE, + skip=skip, + limit=limit, + session=session, + ) + return [ProductRead.model_validate(p) for p in products] + + +@router_v1.post('/', response_model=ProductRead, status_code=status.HTTP_201_CREATED) +async def create_product( + product_data: ProductCreate, + session: Annotated[AsyncSession, Depends(get_session)], + current_user: Annotated[User, SELLER_DEPENDENCY], +) -> ProductRead: + product = await InventoryService.create_product( + session=session, + product_data=product_data, + owner_id=current_user.id, + ) + return ProductRead.model_validate(product) + + +@router_v1.patch('/{product_id}/activate', response_model=ProductRead) +async def activate_product( + product_id: UUID, + session: Annotated[AsyncSession, Depends(get_session)], + _: Any = ADMIN_DEPENDENCY, +) -> ProductRead: + product = await InventoryService.change_status( + session=session, + product_id=product_id, + status=ProductStatus.ACTIVE, + ) + return ProductRead.model_validate(product) + + +@router_v1.patch('/{product_id}', response_model=ProductRead) +async def update_product( + product_id: UUID, + product_data: ProductUpdate, + session: Annotated[AsyncSession, Depends(get_session)], + current_user: Annotated[User, ADMIN_AND_SELLER_DEPENDENCY], +) -> ProductRead: + product = await InventoryService.update_product( + session=session, + product_id=product_id, + product_data=product_data, + current_user=current_user, + ) + return ProductRead.model_validate(product) + + +@router_v1.delete('/{product_id}', status_code=status.HTTP_204_NO_CONTENT) +async def delete_product( + product_id: UUID, + session: Annotated[AsyncSession, Depends(get_session)], + current_user: Annotated[User, ADMIN_DEPENDENCY], +) -> None: + await InventoryService.delete_product( + session=session, + product_id=product_id, + current_user=current_user, + ) + + +@router_v1.get('/{product_id}', response_model=ProductRead) +async def get_product( + product_id: UUID, + session: Annotated[AsyncSession, Depends(get_session)], +) -> ProductRead: + product = await InventoryService.get_product( + session=session, + product_id=product_id, + ) + return ProductRead.model_validate(product) + @router_v1.post('/reserve', response_model=ReservationResponse) @idempotent() @@ -26,7 +143,7 @@ async def reservation_data( user_id=str(current_user.id), item_id=str(reservation_data.product_id), ) - result = await reserve_items( + result = await InventoryService.reserve_items( session=session, user_id=current_user.id, idempotency_key=x_idempotency_key, diff --git a/app/services/inventory/schemas.py b/app/services/inventory/schemas.py index 717221e..8a7d667 100644 --- a/app/services/inventory/schemas.py +++ b/app/services/inventory/schemas.py @@ -1,8 +1,41 @@ from datetime import datetime +from decimal import Decimal from uuid import UUID from pydantic import BaseModel, ConfigDict, Field +from app.services.inventory.models import ProductStatus + + +class ProductCreate(BaseModel): + name: str + description: str | None = None + price: Decimal = Field(gt=0, description='Price must be greater than 0') + qty_available: int = Field( + ge=0, description='Quantity must be greater than or equal to 0' + ) + + +class ProductUpdate(BaseModel): + name: str | None = None + description: str | None = None + price: Decimal | None = Field(gt=0, description='Price must be greater than 0') + qty_available: int | None = Field( + ge=0, description='Quantity must be greater than or equal to 0' + ) + + +class ProductRead(BaseModel): + model_config = ConfigDict(from_attributes=True) + id: UUID + name: str + description: str + price: Decimal + qty_available: int + status: ProductStatus + created_at: datetime + updated_at: datetime + class ReservationCreate(BaseModel): product_id: UUID diff --git a/app/services/inventory/service.py b/app/services/inventory/service.py index 3d9a5bb..d3352e9 100644 --- a/app/services/inventory/service.py +++ b/app/services/inventory/service.py @@ -1,4 +1,4 @@ -import datetime +from datetime import UTC, datetime, timedelta from uuid import UUID from sqlalchemy import select @@ -6,44 +6,144 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import settings -from app.core.exceptions import ConflictError, InsufficientInventoryError, NotFoundError -from app.services.inventory.models import Product, Reservation -from app.services.inventory.schemas import ReservationCreate - - -async def reserve_items( - session: AsyncSession, - user_id: UUID, - idempotency_key: str, - reservation_data: ReservationCreate, -) -> Reservation: - result = await session.execute( - select(Product) - .with_for_update() - .where(Product.id == reservation_data.product_id) - ) - product = result.scalar_one_or_none() - if not product: - raise NotFoundError - if product.qty_available < reservation_data.quantity: - raise InsufficientInventoryError - product.qty_available -= reservation_data.quantity - expires_at = datetime.datetime.now(datetime.UTC) + datetime.timedelta( - minutes=settings.reserve_timeout_minutes - ) - new_reservation = Reservation( - qty_reserved=reservation_data.quantity, - user_id=user_id, - product_id=reservation_data.product_id, - status='pending', - idempotency_key=idempotency_key, - expires_at=expires_at, - ) - session.add(new_reservation) - try: +from app.core.exceptions import ( + ConflictError, + InsufficientInventoryError, + NotFoundError, +) +from app.core.security import check_ownership +from app.services.inventory.models import Product, ProductStatus, Reservation +from app.services.inventory.schemas import ( + ProductCreate, + ProductUpdate, + ReservationCreate, +) +from app.services.orders.models import OrderStatus +from app.services.user.models import User + + +class InventoryService: + @staticmethod + async def _get_product( + session: AsyncSession, + product_id: UUID, + for_update: bool = False, + current_user: User | None = None, + ) -> Product: + query = select(Product).where(Product.id == product_id) + if for_update: + query = query.with_for_update() + result = await session.execute(query) + product = result.scalar_one_or_none() + if not product: + raise NotFoundError + if current_user: + check_ownership(current_user, product) + return product + + @staticmethod + async def change_status( + session: AsyncSession, product_id: UUID, status: ProductStatus + ) -> Product: + product = await InventoryService._get_product( + session, product_id, for_update=True + ) + product.status = status + await session.commit() + await session.refresh(product) + return product + + @staticmethod + async def create_product( + session: AsyncSession, owner_id: UUID, product_data: ProductCreate + ) -> Product: + new_product = Product(**product_data.model_dump()) + new_product.owner_id = owner_id + session.add(new_product) + await session.commit() + await session.refresh(new_product) + return new_product + + @staticmethod + async def update_product( + session: AsyncSession, + product_id: UUID, + product_data: ProductUpdate, + current_user: User, + ) -> Product: + product = await InventoryService._get_product( + session, product_id, for_update=True, current_user=current_user + ) + for field, value in product_data.model_dump(exclude_unset=True).items(): + setattr(product, field, value) + await session.commit() + await session.refresh(product) + return product + + @staticmethod + async def delete_product( + session: AsyncSession, + product_id: UUID, + current_user: User, + ) -> None: + product = await InventoryService._get_product( + session, product_id, for_update=True, current_user=current_user + ) + await session.delete(product) await session.commit() - await session.refresh(new_reservation) - return new_reservation - except IntegrityError: - await session.rollback() - raise ConflictError + + @staticmethod + async def get_product(session: AsyncSession, product_id: UUID) -> Product: + product = await InventoryService._get_product(session, product_id) + return product + + @staticmethod + async def get_products( + session: AsyncSession, + status: ProductStatus | None = None, + skip: int = 0, + limit: int = 50, + ) -> list[Product]: + query = select(Product) + if status: + query = query.where(Product.status == status) + result = await session.execute(query.offset(skip).limit(limit)) + return list(result.scalars().all()) + + @staticmethod + async def reserve_items( + session: AsyncSession, + user_id: UUID, + idempotency_key: str, + reservation_data: ReservationCreate, + ) -> Reservation: + result = await session.execute( + select(Product) + .with_for_update() + .where(Product.id == reservation_data.product_id) + ) + product = result.scalar_one_or_none() + if not product: + raise NotFoundError + if product.qty_available < reservation_data.quantity: + raise InsufficientInventoryError + product.qty_available -= reservation_data.quantity + expires_at = datetime.now(UTC) + timedelta( + minutes=settings.reserve_timeout_minutes + ) + new_reservation = Reservation( + qty_reserved=reservation_data.quantity, + user_id=user_id, + product_id=reservation_data.product_id, + status=OrderStatus.PENDING, + idempotency_key=idempotency_key, + expires_at=expires_at, + ) + session.add(new_reservation) + try: + await session.commit() + await session.refresh(new_reservation) + return new_reservation + except IntegrityError: + await session.rollback() + raise ConflictError diff --git a/app/services/orders/models.py b/app/services/orders/models.py index 6aba6f6..c052853 100644 --- a/app/services/orders/models.py +++ b/app/services/orders/models.py @@ -15,13 +15,13 @@ class OrderStatus(StrEnum): - PENDING = 'pending' - PAID = 'paid' - SHIPPED = 'shipped' - CANCELLED = 'cancelled' - FAILED = 'failed' - COMPLETED = 'completed' - EXPIRED = 'expired' + PENDING = 'PENDING' + PAID = 'PAID' + SHIPPED = 'SHIPPED' + CANCELLED = 'CANCELLED' + FAILED = 'FAILED' + COMPLETED = 'COMPLETED' + EXPIRED = 'EXPIRED' class Order(Base): diff --git a/app/services/orders/routes.py b/app/services/orders/routes.py index 618fb8e..cc5145a 100644 --- a/app/services/orders/routes.py +++ b/app/services/orders/routes.py @@ -1,3 +1,4 @@ +from typing import Annotated from uuid import UUID from fastapi import APIRouter, Depends, Header, Request @@ -6,11 +7,7 @@ from app.core.database import get_session from app.services.orders.models import Order from app.services.orders.schemas import OrderCreate, OrderResponse -from app.services.orders.service import ( - cancel_order, - confirm_order_payment, - create_order_from_reservation, -) +from app.services.orders.service import OrderService from app.services.user.models import User from app.shared.decorators import idempotent from app.shared.deps import get_current_user @@ -23,13 +20,13 @@ async def create_order_endpoint( request: Request, order_data: OrderCreate, - x_idempotency_key: str = Header(...), - session: AsyncSession = Depends(get_session), - current_user: User = Depends(get_current_user), + x_idempotency_key: Annotated[str, Header(...)], + session: Annotated[AsyncSession, Depends(get_session)], + current_user: Annotated[User, Depends(get_current_user)], ) -> Order: - return await create_order_from_reservation( + return await OrderService.create_order_from_reservation( session=session, - user_id=current_user.id, + current_user=current_user, order_data=order_data, ) @@ -39,14 +36,14 @@ async def create_order_endpoint( async def confirm_order_payment_endpoint( request: Request, order_id: UUID, - x_idempotency_key: str = Header(...), - session: AsyncSession = Depends(get_session), - current_user: User = Depends(get_current_user), + x_idempotency_key: Annotated[str, Header(...)], + session: Annotated[AsyncSession, Depends(get_session)], + current_user: Annotated[User, Depends(get_current_user)], ) -> Order: - return await confirm_order_payment( + return await OrderService.confirm_order_payment( session=session, order_id=order_id, - user_id=current_user.id, + current_user=current_user, ) @@ -55,12 +52,12 @@ async def confirm_order_payment_endpoint( async def cancel_order_endpoint( request: Request, order_id: UUID, - x_idempotency_key: str = Header(...), - session: AsyncSession = Depends(get_session), - current_user: User = Depends(get_current_user), + x_idempotency_key: Annotated[str, Header(...)], + session: Annotated[AsyncSession, Depends(get_session)], + current_user: Annotated[User, Depends(get_current_user)], ) -> Order: - return await cancel_order( + return await OrderService.cancel_order( session=session, order_id=order_id, - user_id=current_user.id, + current_user=current_user, ) diff --git a/app/services/orders/service.py b/app/services/orders/service.py index 4e7fcdd..e6db8d6 100644 --- a/app/services/orders/service.py +++ b/app/services/orders/service.py @@ -5,6 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.core.exceptions import ConflictError, NotFoundError +from app.core.security import check_ownership from app.services.inventory.internal import ( cancel_reservation_by_order_and_return_stock, mark_reservation_by_order_as_completed, @@ -12,128 +13,102 @@ from app.services.inventory.models import Product, Reservation from app.services.orders.models import Order, OrderItem, OrderStatus from app.services.orders.schemas import OrderCreate +from app.services.user.models import User -async def create_order_from_reservation( - session: AsyncSession, - user_id: UUID, - order_data: OrderCreate, -) -> Order: - reservation_result = await session.execute( - select(Reservation) - .with_for_update() - .where( - Reservation.id == order_data.reservation_id, - Reservation.user_id == user_id, - Reservation.status == OrderStatus.PENDING, +class OrderService: + @staticmethod + async def _get_order( + session: AsyncSession, + order_id: UUID, + current_user: User, + for_update: bool = False, + ) -> Order: + stmt = select(Order).where(Order.id == order_id) + if for_update: + stmt = stmt.with_for_update() + result = await session.execute(stmt) + order = result.scalar_one_or_none() + if not order: + raise NotFoundError + check_ownership(current_user, order) + return order + + @staticmethod + async def create_order_from_reservation( + session: AsyncSession, + current_user: User, + order_data: OrderCreate, + ) -> Order: + reservation_result = await session.execute( + select(Reservation) + .with_for_update() + .where( + Reservation.id == order_data.reservation_id, + Reservation.user_id == current_user.id, + Reservation.status == OrderStatus.PENDING, + ) ) - ) - reservation = reservation_result.scalar_one_or_none() - if not reservation: - raise NotFoundError - if reservation.expires_at < datetime.datetime.now(datetime.UTC): - raise ConflictError - if reservation.order_id is not None: - raise ConflictError - product = ( - await session.execute( - select(Product).where(Product.id == reservation.product_id) + reservation = reservation_result.scalar_one_or_none() + if not reservation: + raise NotFoundError + if reservation.expires_at < datetime.datetime.now(datetime.UTC): + raise ConflictError + if reservation.order_id is not None: + raise ConflictError + product = ( + await session.execute( + select(Product).where(Product.id == reservation.product_id) + ) + ).scalar_one_or_none() + if not product: + raise NotFoundError + create_order = Order( + user_id=current_user.id, + total_amount=product.price * reservation.qty_reserved, + status=OrderStatus.PENDING, + shipping_address=order_data.shipping_address, ) - ).scalar_one_or_none() - if not product: - raise NotFoundError - create_order = Order( - user_id=user_id, - total_amount=product.price * reservation.qty_reserved, - status=OrderStatus.PENDING, - shipping_address=order_data.shipping_address, - ) - session.add(create_order) - await session.flush() - create_order_item = OrderItem( - order_id=create_order.id, - product_id=reservation.product_id, - product_name=product.name, - quantity=reservation.qty_reserved, - price=product.price, - ) - session.add(create_order_item) - reservation.order_id = create_order.id - await session.commit() - await session.refresh(create_order, attribute_names=['items']) - return create_order - - -async def _get_locked_order_and_reservation( - session: AsyncSession, order_id: UUID, user_id: UUID -) -> tuple[Order, Reservation]: - order_result = await session.execute( - select(Order) - .with_for_update() - .where( - Order.id == order_id, - Order.user_id == user_id, + session.add(create_order) + await session.flush() + create_order_item = OrderItem( + order_id=create_order.id, + product_id=reservation.product_id, + product_name=product.name, + quantity=reservation.qty_reserved, + price=product.price, ) - ) - order = order_result.scalar_one_or_none() - if not order: - raise NotFoundError - if order.status != OrderStatus.PENDING: - raise ConflictError - res_result = await session.execute( - select(Reservation).with_for_update().where(Reservation.order_id == order_id) - ) - reservation = res_result.scalar_one_or_none() - if not reservation: - raise NotFoundError - return order, reservation + session.add(create_order_item) + reservation.order_id = create_order.id + await session.commit() + await session.refresh(create_order, attribute_names=['items']) + return create_order - -async def confirm_order_payment( - session: AsyncSession, - order_id: UUID, - user_id: UUID, -) -> Order: - order_result = await session.execute( - select(Order) - .with_for_update() - .where( - Order.id == order_id, - Order.user_id == user_id, + @staticmethod + async def confirm_order_payment( + session: AsyncSession, order_id: UUID, current_user: User + ) -> Order: + order = await OrderService._get_order( + session, order_id, current_user, for_update=True ) - ) - order = order_result.scalar_one_or_none() - if not order: - raise NotFoundError - if order.status != OrderStatus.PENDING: - raise ConflictError - - order.status = OrderStatus.PAID - await mark_reservation_by_order_as_completed(session, order_id) - await session.commit() - return order - + if order.status != OrderStatus.PENDING: + raise ConflictError + order.status = OrderStatus.PAID + await mark_reservation_by_order_as_completed(session, order_id) + await session.commit() + return order -async def cancel_order( - session: AsyncSession, - order_id: UUID, - user_id: UUID, -) -> Order: - order_result = await session.execute( - select(Order) - .with_for_update() - .where( - Order.id == order_id, - Order.user_id == user_id, + @staticmethod + async def cancel_order( + session: AsyncSession, order_id: UUID, current_user: User + ) -> Order: + order = await OrderService._get_order( + session, order_id, current_user, for_update=True ) - ) - order = order_result.scalar_one_or_none() - if not order: - raise NotFoundError - if order.status != OrderStatus.PENDING: - raise ConflictError + if order.status != OrderStatus.PENDING: + raise ConflictError - order.status = OrderStatus.CANCELLED - await cancel_reservation_by_order_and_return_stock(session, order_id) - await session.commit() - return order + order.status = OrderStatus.CANCELLED + await cancel_reservation_by_order_and_return_stock(session, order_id) + await session.commit() + return order diff --git a/app/services/user/models.py b/app/services/user/models.py index 4589ef4..a9b6e53 100644 --- a/app/services/user/models.py +++ b/app/services/user/models.py @@ -1,6 +1,8 @@ from datetime import datetime +from enum import StrEnum from uuid import UUID, uuid4 +from sqlalchemy import Enum as SQLEnum from sqlalchemy import ForeignKey from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.sql import func @@ -11,6 +13,15 @@ EMAIL_MAX_LENGTH = 255 +class UserRole(StrEnum): + ADMIN = 'ADMIN' + MODERATOR = 'MODERATOR' + USER = 'USER' + USER_B2B = 'USER_B2B' + SELLER = 'SELLER' + SELLER_B2B = 'SELLER_B2B' + + class User(Base): __tablename__ = 'users' id: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4) @@ -19,7 +30,8 @@ class User(Base): ) password_hash: Mapped[str] = mapped_column(String(), nullable=False) is_active: Mapped[bool] = mapped_column(Boolean, default=True) - is_superuser: Mapped[bool] = mapped_column(Boolean, default=False) + is_verified: Mapped[bool] = mapped_column(Boolean, default=False) + role: Mapped[UserRole] = mapped_column(SQLEnum(UserRole), default=UserRole.USER) created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now()) refresh_tokens: Mapped[list['RefreshToken']] = relationship(back_populates='user') diff --git a/app/services/user/schemas.py b/app/services/user/schemas.py index c9e02e6..87d97b3 100644 --- a/app/services/user/schemas.py +++ b/app/services/user/schemas.py @@ -2,6 +2,8 @@ from pydantic import BaseModel, ConfigDict, EmailStr +from .models import UserRole + class UserCreate(BaseModel): email: EmailStr @@ -12,6 +14,7 @@ class UserRead(BaseModel): model_config = ConfigDict(from_attributes=True) id: UUID email: EmailStr + role: UserRole class Token(BaseModel): diff --git a/migrations/versions/8e607d73b2b3_add_product_owner_id.py b/migrations/versions/8e607d73b2b3_add_product_owner_id.py new file mode 100644 index 0000000..5dd1ab3 --- /dev/null +++ b/migrations/versions/8e607d73b2b3_add_product_owner_id.py @@ -0,0 +1,39 @@ +"""add_product_owner_id + +Revision ID: 8e607d73b2b3 +Revises: 902cfb988689 +Create Date: 2026-03-15 20:41:22.381708 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = '8e607d73b2b3' +down_revision: str | Sequence[str] | None = '902cfb988689' +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.add_column('products', sa.Column('owner_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + 'fk_products_owner', 'products', 'users', ['owner_id'], ['id'] + ) + op.execute( + 'UPDATE products SET owner_id = (SELECT id FROM users ' + "WHERE role = 'ADMIN' LIMIT 1)" + ) + op.execute( + 'UPDATE products SET owner_id = (SELECT id FROM users LIMIT 1) ' + 'WHERE owner_id IS NULL' + ) + op.alter_column('products', 'owner_id', nullable=False) + + +def downgrade() -> None: + op.drop_constraint('fk_products_owner', 'products', type_='foreignkey') + op.drop_column('products', 'owner_id') diff --git a/migrations/versions/902cfb988689_add_rbac_and_product_status.py b/migrations/versions/902cfb988689_add_rbac_and_product_status.py new file mode 100644 index 0000000..f2c8ecd --- /dev/null +++ b/migrations/versions/902cfb988689_add_rbac_and_product_status.py @@ -0,0 +1,70 @@ +"""Add RBAC and Product status + +Revision ID: 902cfb988689 +Revises: 2a03cc82fdc7 +Create Date: 2026-03-13 16:17:58.458310 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = '902cfb988689' +down_revision: str | Sequence[str] | None = '2a03cc82fdc7' +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + user_role_enum = sa.Enum( + 'ADMIN', + 'MODERATOR', + 'USER', + 'USER_B2B', + 'SELLER', + 'SELLER_B2B', + name='userrole', + ) + user_role_enum.create(op.get_bind(), checkfirst=True) + + product_status_enum = sa.Enum('DRAFT', 'ACTIVE', 'ARCHIVED', name='productstatus') + product_status_enum.create(op.get_bind(), checkfirst=True) + + # 1. Add as nullable + op.add_column('products', sa.Column('status', product_status_enum, nullable=True)) + op.add_column('users', sa.Column('is_verified', sa.Boolean(), nullable=True)) + op.add_column('users', sa.Column('role', user_role_enum, nullable=True)) + + # 2. Populate data + op.execute("UPDATE products SET status = 'DRAFT'") + op.execute('UPDATE users SET is_verified = FALSE') + op.execute("UPDATE users SET role = 'USER'") + + # 3. Set NOT NULL + op.alter_column('products', 'status', nullable=False) + op.alter_column('users', 'is_verified', nullable=False) + op.alter_column('users', 'role', nullable=False) + + op.drop_column('users', 'is_superuser') + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + 'users', + sa.Column('is_superuser', sa.BOOLEAN(), autoincrement=False, nullable=False), + ) + op.drop_column('users', 'role') + op.drop_column('users', 'is_verified') + op.drop_column('products', 'status') + + sa.Enum(name='productstatus').drop(op.get_bind(), checkfirst=True) + sa.Enum(name='userrole').drop(op.get_bind(), checkfirst=True) + # ### end Alembic commands ### diff --git a/tests/conftest.py b/tests/conftest.py index d04e401..0118039 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -77,8 +77,14 @@ async def redis_client() -> AsyncGenerator[Redis, None]: async def async_client( db_session_factory: async_sessionmaker[AsyncSession], redis_client: Redis ) -> AsyncGenerator[AsyncClient, None]: + from app.core.database import get_session from app.core.lua_scripts import RATE_LIMIT_LUA_SCRIPT + async def override_get_session() -> AsyncGenerator[AsyncSession, None]: + async with db_session_factory() as session: + yield session + + main_app.dependency_overrides[get_session] = override_get_session main_app.state.redis = redis_client main_app.state.rate_limit_script = redis_client.register_script( RATE_LIMIT_LUA_SCRIPT diff --git a/tests/test_arq_expiry.py b/tests/test_arq_expiry.py index e712016..899b149 100644 --- a/tests/test_arq_expiry.py +++ b/tests/test_arq_expiry.py @@ -1,11 +1,14 @@ import asyncio import datetime +from decimal import Decimal from uuid import uuid4 from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker from app.services.inventory.models import Product, Reservation +from app.services.inventory.schemas import ProductCreate +from app.services.inventory.service import InventoryService from app.services.inventory.tasks import release_expired_reservations from app.services.orders.models import OrderStatus from app.services.user.models import User @@ -21,12 +24,14 @@ async def test_arq_concurrent_expiry_no_double_return( await clean_up_session.commit() async with db_session_factory() as setup_session: user = User(id=uuid4(), email=f'test_{uuid4()}@mail.com', password_hash='foo') - product = Product( - id=uuid4(), name='Test Plate carrier', price=100.0, qty_available=0 - ) setup_session.add(user) - setup_session.add(product) await setup_session.commit() + product_data = ProductCreate( + name='Test Plate carrier', price=Decimal('100.00'), qty_available=0 + ) + product = await InventoryService.create_product( + setup_session, user.id, product_data + ) for _ in range(10): reservation = Reservation( qty_reserved=1, diff --git a/tests/test_concurrency_inventory.py b/tests/test_concurrency_inventory.py index fc73d5e..29ba712 100644 --- a/tests/test_concurrency_inventory.py +++ b/tests/test_concurrency_inventory.py @@ -1,4 +1,5 @@ import asyncio +from decimal import Decimal from uuid import uuid4 import pytest @@ -6,8 +7,8 @@ from app.core.exceptions import InsufficientInventoryError from app.services.inventory.models import Product -from app.services.inventory.schemas import ReservationCreate -from app.services.inventory.service import reserve_items +from app.services.inventory.schemas import ProductCreate, ReservationCreate +from app.services.inventory.service import InventoryService from app.services.orders.models import Order # noqa: F401 from app.services.user.models import User @@ -18,14 +19,16 @@ async def test_concurrent_reservations_service_level( ) -> None: async with db_session_factory() as setup_session: user = User(id=uuid4(), email=f'test_{uuid4()}@mail.com', password_hash='foo') - product = Product( - id=uuid4(), name='Test Sneakers', price=100.0, qty_available=10 - ) setup_session.add(user) - setup_session.add(product) await setup_session.commit() - user_id = user.id + product_data = ProductCreate( + name='Test Sneakers', price=Decimal('100.00'), qty_available=10 + ) + product = await InventoryService.create_product( + setup_session, user.id, product_data + ) product_id = product.id + user_id = user.id concurrency_level = 50 async def worker() -> bool | Exception: @@ -33,7 +36,9 @@ async def worker() -> bool | Exception: request = ReservationCreate(product_id=product_id, quantity=1) idempotency_key = str(uuid4()) try: - await reserve_items(session, user_id, idempotency_key, request) + await InventoryService.reserve_items( + session, user_id, idempotency_key, request + ) return True except InsufficientInventoryError: return False diff --git a/tests/test_integration_orders.py b/tests/test_integration_orders.py index fa8cda6..421fc49 100644 --- a/tests/test_integration_orders.py +++ b/tests/test_integration_orders.py @@ -1,5 +1,5 @@ from http import HTTPStatus -from uuid import uuid4 +from uuid import UUID, uuid4 from httpx import AsyncClient from sqlalchemy.ext.asyncio import AsyncSession @@ -14,8 +14,15 @@ async def test_reserve_order_flow( test_password = 'super_secret_password' test_product_id = uuid4() idempotency_key = uuid4().hex + test_user_res = await async_client.post( + '/api/v1/users', json={'email': test_email, 'password': test_password} + ) + assert test_user_res.status_code == HTTPStatus.CREATED + created_user_id = UUID(test_user_res.json()['id']) + test_product = Product( id=test_product_id, + owner_id=created_user_id, name='Hatchback 02 blue', description='dayz car such as golf II blue', price='150', @@ -23,10 +30,6 @@ async def test_reserve_order_flow( ) db_session.add(test_product) await db_session.commit() - test_user = await async_client.post( - '/api/v1/users', json={'email': test_email, 'password': test_password} - ) - assert test_user.status_code == HTTPStatus.CREATED test_user_login = await async_client.post( '/api/v1/auth/token', data={'username': test_email, 'password': test_password} ) @@ -52,6 +55,6 @@ async def test_reserve_order_flow( ) assert test_order.status_code == HTTPStatus.OK test_order_data = test_order.json() - assert test_order_data['status'] == 'pending' + assert test_order_data['status'] == 'PENDING' assert test_order_data['total_amount'] == '150.00' assert 'id' in test_order_data diff --git a/tests/test_integration_reserve.py b/tests/test_integration_reserve.py index 01d53d3..d27f277 100644 --- a/tests/test_integration_reserve.py +++ b/tests/test_integration_reserve.py @@ -1,7 +1,8 @@ +import asyncio from http import HTTPStatus -from uuid import uuid4 +from uuid import UUID, uuid4 -from httpx import AsyncClient +from httpx import AsyncClient, Response from sqlalchemy.ext.asyncio import AsyncSession from app.services.inventory.models import Product @@ -15,19 +16,22 @@ async def test_reserve_flow( test_password = 'super_secret_password' test_product_id = uuid4() idempotency_key = uuid4().hex + test_user_res = await async_client.post( + '/api/v1/users', json={'email': test_email, 'password': test_password} + ) + assert test_user_res.status_code == HTTPStatus.CREATED + created_user_id = UUID(test_user_res.json()['id']) + test_product = Product( id=test_product_id, + owner_id=created_user_id, name='Hatchback 02 red', description='dayz car such as golf II', price='100', - qty_available=5, + qty_available=100, ) db_session.add(test_product) await db_session.commit() - test_user = await async_client.post( - '/api/v1/users', json={'email': test_email, 'password': test_password} - ) - assert test_user.status_code == HTTPStatus.CREATED test_user_login = await async_client.post( '/api/v1/auth/token', data={'username': test_email, 'password': test_password} ) @@ -46,7 +50,7 @@ async def test_reserve_flow( test_reserve_data = test_reserve.json() assert 'id' in test_reserve_data assert test_reserve_data['product_id'] == str(test_product_id) - assert test_reserve_data['status'] == 'pending' + assert test_reserve_data['status'] == 'PENDING' test_reserve = await async_client.post( '/api/v1/inventory/reserve', @@ -58,9 +62,8 @@ async def test_reserve_flow( assert 'id' in test_reserve_data assert test_reserve_data['product_id'] == str(test_product_id) - status_codes = [] - for _ in range(20): - response = await async_client.post( + async def make_request() -> Response: + return await async_client.post( '/api/v1/inventory/reserve', json={'product_id': str(test_product_id), 'quantity': 1}, headers={ @@ -68,6 +71,8 @@ async def test_reserve_flow( 'X-Idempotency-Key': uuid4().hex, }, ) - status_codes.append(response.status_code) + + responses = await asyncio.gather(*(make_request() for _ in range(50))) + status_codes = [r.status_code for r in responses] assert HTTPStatus.TOO_MANY_REQUESTS in status_codes