diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..81c7ec1
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,4 @@
+# pyproject.toml
+[tool.black]
+line-length = 120
+#exclude = ["src/frontend"]
diff --git a/src/backend/app/api/v1/endpoints/address.py b/src/backend/app/api/v1/endpoints/address.py
index 99d3211..b324ed7 100644
--- a/src/backend/app/api/v1/endpoints/address.py
+++ b/src/backend/app/api/v1/endpoints/address.py
@@ -11,14 +11,11 @@
AddressListResponse,
AddressCreateRequest,
AddressUpdateRequest,
- SetDefaultAddressResponse # 用于设置默认地址的响应
+ SetDefaultAddressResponse, # 用于设置默认地址的响应
)
from backend.app.services.address_service import AddressService
from backend.app.dependencies.service_deps import get_address_service
-from backend.app.utils.exceptions import (
- UserNotFoundException,
- AddressNotFoundException
-)
+from backend.app.utils.exceptions import UserNotFoundException, AddressNotFoundException
from sqlalchemy.engine.base import Connection # 用于类型提示
from backend.app.utils import logger
@@ -31,13 +28,13 @@
response_model=AddressResponse,
status_code=status.HTTP_201_CREATED,
tags=["Shipping Addresses"],
- summary="为当前用户添加新的收货地址"
+ summary="为当前用户添加新的收货地址",
)
async def add_new_address(
- address_in: AddressCreateRequest, # 请求体
- current_user: CurrentUserSchema = Depends(get_current_active_user),
- db: Connection = Depends(get_db_connection),
- address_service: AddressService = Depends(get_address_service) # 注入服务
+ address_in: AddressCreateRequest, # 请求体
+ current_user: CurrentUserSchema = Depends(get_current_active_user),
+ db: Connection = Depends(get_db_connection),
+ address_service: AddressService = Depends(get_address_service), # 注入服务
):
"""
为当前认证的用户创建一个新的收货地址。
@@ -45,14 +42,12 @@ async def add_new_address(
并更新用户的 `DefaultAddressID`。
"""
logger.info(
- f"UserID {current_user.UserID} attempting to add new address: {address_in.model_dump(exclude_unset=True)}")
+ f"UserID {current_user.UserID} attempting to add new address: {address_in.model_dump(exclude_unset=True)}"
+ )
try:
with db.begin_nested() if db.in_transaction() else db.begin():
new_address = await address_service.create_new_address(
- db=db,
- address_in=address_in,
- user_id=current_user.UserID,
- actor_id=current_user.UserID
+ db=db, address_in=address_in, user_id=current_user.UserID, actor_id=current_user.UserID
)
except Exception as e:
logger.error(f"Failed to create new address for UserID {current_user.UserID}: {e}")
@@ -61,16 +56,11 @@ async def add_new_address(
return new_address
-@router.get(
- "/",
- response_model=AddressListResponse,
- tags=["Shipping Addresses"],
- summary="获取当前用户的所有收货地址"
-)
+@router.get("/", response_model=AddressListResponse, tags=["Shipping Addresses"], summary="获取当前用户的所有收货地址")
async def get_user_addresses(
- current_user: CurrentUserSchema = Depends(get_current_active_user),
- db: Connection = Depends(get_db_connection),
- address_service: AddressService = Depends(get_address_service)
+ current_user: CurrentUserSchema = Depends(get_current_active_user),
+ db: Connection = Depends(get_db_connection),
+ address_service: AddressService = Depends(get_address_service),
):
"""
检索当前已认证用户的所有收货地址列表。
@@ -89,16 +79,13 @@ async def get_user_addresses(
@router.get(
- "/{address_id}",
- response_model=AddressResponse,
- tags=["Shipping Addresses"],
- summary="获取特定收货地址的详情"
+ "/{address_id}", response_model=AddressResponse, tags=["Shipping Addresses"], summary="获取特定收货地址的详情"
)
async def get_address_by_id(
- address_id: int = FastApiPath(..., title="地址ID", description="要检索的地址的唯一ID", gt=0),
- current_user: CurrentUserSchema = Depends(get_current_active_user),
- db: Connection = Depends(get_db_connection),
- address_service: AddressService = Depends(get_address_service)
+ address_id: int = FastApiPath(..., title="地址ID", description="要检索的地址的唯一ID", gt=0),
+ current_user: CurrentUserSchema = Depends(get_current_active_user),
+ db: Connection = Depends(get_db_connection),
+ address_service: AddressService = Depends(get_address_service),
):
"""
根据 AddressID 获取单个收货地址的详细信息。
@@ -108,10 +95,7 @@ async def get_address_by_id(
try:
with db.begin_nested() if db.in_transaction() else db.begin():
address = await address_service.get_address_by_id_for_user(
- db=db,
- address_id=address_id,
- user_id=current_user.UserID,
- actor_id=current_user.UserID
+ db=db, address_id=address_id, user_id=current_user.UserID, actor_id=current_user.UserID
)
except AddressNotFoundException as e:
logger.error(f"Address not found for UserID {current_user.UserID}: {e}")
@@ -124,17 +108,14 @@ async def get_address_by_id(
@router.put(
- "/{address_id}",
- response_model=AddressResponse,
- tags=["Shipping Addresses"],
- summary="更新现有收货地址的详情"
+ "/{address_id}", response_model=AddressResponse, tags=["Shipping Addresses"], summary="更新现有收货地址的详情"
)
async def update_address_details(
- address_update_in: AddressUpdateRequest, # 请求体
- address_id: int = FastApiPath(..., title="地址ID", description="要更新的地址的唯一ID", gt=0),
- current_user: CurrentUserSchema = Depends(get_current_active_user),
- db: Connection = Depends(get_db_connection),
- address_service: AddressService = Depends(get_address_service)
+ address_update_in: AddressUpdateRequest, # 请求体
+ address_id: int = FastApiPath(..., title="地址ID", description="要更新的地址的唯一ID", gt=0),
+ current_user: CurrentUserSchema = Depends(get_current_active_user),
+ db: Connection = Depends(get_db_connection),
+ address_service: AddressService = Depends(get_address_service),
):
"""
更新指定 `AddressID` 的收货地址的文本内容(收货人、电话、地址详情)。
@@ -142,7 +123,8 @@ async def update_address_details(
需要验证该地址是否属于当前用户。
"""
logger.info(
- f"Updating address AddressID: {address_id} for UserID: {current_user.UserID} with data: {address_update_in.model_dump(exclude_unset=True)}")
+ f"Updating address AddressID: {address_id} for UserID: {current_user.UserID} with data: {address_update_in.model_dump(exclude_unset=True)}"
+ )
try:
with db.begin_nested() if db.in_transaction() else db.begin():
updated_address = await address_service.update_address_details(
@@ -150,7 +132,7 @@ async def update_address_details(
address_id=address_id,
address_in=address_update_in,
user_id_making_change=current_user.UserID,
- actor_id=current_user.UserID
+ actor_id=current_user.UserID,
)
except AddressNotFoundException as e:
logger.error(f"Address not found for UserID {current_user.UserID}: {e}")
@@ -166,13 +148,13 @@ async def update_address_details(
"/{address_id}/set-default",
response_model=SetDefaultAddressResponse,
tags=["Shipping Addresses"],
- summary="将指定地址设为当前用户的默认收货地址"
+ summary="将指定地址设为当前用户的默认收货地址",
)
async def set_address_as_default(
- address_id: int = FastApiPath(..., title="地址ID", description="要设为默认的地址的唯一ID", gt=0),
- current_user: CurrentUserSchema = Depends(get_current_active_user),
- db: Connection = Depends(get_db_connection),
- address_service: AddressService = Depends(get_address_service)
+ address_id: int = FastApiPath(..., title="地址ID", description="要设为默认的地址的唯一ID", gt=0),
+ current_user: CurrentUserSchema = Depends(get_current_active_user),
+ db: Connection = Depends(get_db_connection),
+ address_service: AddressService = Depends(get_address_service),
):
"""
将指定的 `AddressID` 设置为当前认证用户的默认收货地址。
@@ -186,10 +168,7 @@ async def set_address_as_default(
try:
with db.begin_nested() if db.in_transaction() else db.begin():
resp = await address_service.set_default_address_for_user(
- db=db,
- user_id=current_user.UserID,
- address_id_to_set_default=address_id,
- actor_id=current_user.UserID
+ db=db, user_id=current_user.UserID, address_id_to_set_default=address_id, actor_id=current_user.UserID
)
except AddressNotFoundException as e:
logger.error(f"Address not found for UserID {current_user.UserID}: {e}")
@@ -203,15 +182,15 @@ async def set_address_as_default(
@router.delete(
"/{address_id}",
- status_code=status.HTTP_204_NO_CONTENT, # 成功删除通常返回 204
+ status_code=status.HTTP_204_NO_CONTENT,
tags=["Shipping Addresses"],
- summary="删除用户的特定收货地址"
+ summary="删除用户的特定收货地址", # 成功删除通常返回 204
)
async def delete_address(
- address_id: int = FastApiPath(..., title="地址ID", description="要删除的地址的唯一ID", gt=0),
- current_user: CurrentUserSchema = Depends(get_current_active_user),
- db: Connection = Depends(get_db_connection),
- address_service: AddressService = Depends(get_address_service)
+ address_id: int = FastApiPath(..., title="地址ID", description="要删除的地址的唯一ID", gt=0),
+ current_user: CurrentUserSchema = Depends(get_current_active_user),
+ db: Connection = Depends(get_db_connection),
+ address_service: AddressService = Depends(get_address_service),
):
"""
根据 `AddressID` 删除当前用户的某个收货地址。
@@ -223,10 +202,7 @@ async def delete_address(
with db.begin_nested() if db.in_transaction() else db.begin():
# 这里假设 AddressService 有一个 delete_address 方法
success = await address_service.delete_address_for_user(
- db=db,
- address_id=address_id,
- user_id_making_request=current_user.UserID,
- actor_id=current_user.UserID
+ db=db, address_id=address_id, user_id_making_request=current_user.UserID, actor_id=current_user.UserID
)
except AddressNotFoundException as e:
logger.error(f"Address not found for UserID {current_user.UserID}: {e}")
diff --git a/src/backend/app/api/v1/endpoints/auth.py b/src/backend/app/api/v1/endpoints/auth.py
index a64c4c5..355cb1d 100644
--- a/src/backend/app/api/v1/endpoints/auth.py
+++ b/src/backend/app/api/v1/endpoints/auth.py
@@ -4,8 +4,13 @@
from sqlalchemy.engine.base import Connection
from typing import Optional, Any, Dict
-from backend.app.dependencies.auth_deps import get_db_connection, get_auth_service, \
- get_current_active_user, get_current_user_payload, get_token_from_auth_header # 导入依赖项
+from backend.app.dependencies.auth_deps import (
+ get_db_connection,
+ get_auth_service,
+ get_current_active_user,
+ get_current_user_payload,
+ get_token_from_auth_header,
+) # 导入依赖项
from backend.app.services.auth_service import AuthService
from backend.app.schemas.auth_schema import Token, UserLogin, TokenPayload # 导入 Schemas
from backend.app.schemas.user_schema import UserResponse # 用于 /logout 的 actor
@@ -17,10 +22,10 @@
@router.post("/token", response_model=Token, tags=["Authentication"])
async def login_for_access_token(
- db: Connection = Depends(get_db_connection),
- form_data: OAuth2PasswordRequestForm = Depends(), # 从请求表单中获取 username (identifier) 和 password
- auth_service: AuthService = Depends(get_auth_service),
- # request: Request = None # 如果需要获取 IP 地址和 User-Agent
+ db: Connection = Depends(get_db_connection),
+ form_data: OAuth2PasswordRequestForm = Depends(), # 从请求表单中获取 username (identifier) 和 password
+ auth_service: AuthService = Depends(get_auth_service),
+ # request: Request = None # 如果需要获取 IP 地址和 User-Agent
):
"""
用户登录以获取访问令牌。
@@ -31,9 +36,7 @@ async def login_for_access_token(
# 假设你的 AuthService.login_user 期望 UserLogin schema
# 或者它可以直接处理 identifier 和 password
# 为了与 UserLogin schema 保持一致,我们可以创建一个 UserLogin 实例
- logger.info(
- f"Handling login request: {form_data.username=}, {form_data.password=}"
- )
+ logger.info(f"Handling login request: {form_data.username=}, {form_data.password=}")
login_credentials = UserLogin(UsernameOrEmail=form_data.username, Password=form_data.password)
# login_user 通常需要事务来创建会话记录
@@ -64,30 +67,26 @@ async def login_for_access_token(
@router.post("/logout", summary="Logout Current Session", tags=["Authentication"])
async def logout_current_session(
- # 为了登出,我们需要知道是哪个 token/session 要失效。
- # 通常,客户端会发送它当前的 Bearer token。
- # get_current_active_user 依赖项会验证这个 token 并提供用户信息。
- # 我们需要从 token payload 中获取 jti (session_token_in_db)
- # 或者,让 get_current_active_user 返回原始 token 字符串或 jti
- # 为了简单,我们假设客户端会明确发送它想要使其失效的 token。
- # 但更安全的做法是,从已认证用户的 token 中提取 jti。
-
- # 方案1:依赖已认证用户,并从其 token payload 中获取 jti (推荐)
- # (需要修改 get_current_active_user 或 get_current_user_payload 来暴露 jti 或原始 token)
- # current_user_payload: TokenPayload = Depends(get_current_user_payload), # 假设这个依赖返回 payload
-
- # 方案2:客户端直接发送要使其失效的 token 字符串 (需要确保这个 token 属于当前用户)
- # token_to_invalidate: str = Body(..., embed=True, description="要使其失效的JWT访问令牌"),
-
- # 方案3:依赖 get_current_active_user,并假设其内部或 AuthService 能处理
- # 我们将使用此方案,并假设 AuthService.logout_user_session 接收完整的 JWT
- # 并且 actor_user_id 就是当前认证的用户。
-
- # 为了让 logout 生效,它必须是一个受保护的路由,所以需要一个有效的 token
- current_user: UserResponse = Depends(get_current_active_user), # 确保用户已认证
- token: str = Depends(get_token_from_auth_header), # 获取当前用户的 token
- db: Connection = Depends(get_db_connection),
- auth_service: AuthService = Depends(get_auth_service)
+ # 为了登出,我们需要知道是哪个 token/session 要失效。
+ # 通常,客户端会发送它当前的 Bearer token。
+ # get_current_active_user 依赖项会验证这个 token 并提供用户信息。
+ # 我们需要从 token payload 中获取 jti (session_token_in_db)
+ # 或者,让 get_current_active_user 返回原始 token 字符串或 jti
+ # 为了简单,我们假设客户端会明确发送它想要使其失效的 token。
+ # 但更安全的做法是,从已认证用户的 token 中提取 jti。
+ # 方案1:依赖已认证用户,并从其 token payload 中获取 jti (推荐)
+ # (需要修改 get_current_active_user 或 get_current_user_payload 来暴露 jti 或原始 token)
+ # current_user_payload: TokenPayload = Depends(get_current_user_payload), # 假设这个依赖返回 payload
+ # 方案2:客户端直接发送要使其失效的 token 字符串 (需要确保这个 token 属于当前用户)
+ # token_to_invalidate: str = Body(..., embed=True, description="要使其失效的JWT访问令牌"),
+ # 方案3:依赖 get_current_active_user,并假设其内部或 AuthService 能处理
+ # 我们将使用此方案,并假设 AuthService.logout_user_session 接收完整的 JWT
+ # 并且 actor_user_id 就是当前认证的用户。
+ # 为了让 logout 生效,它必须是一个受保护的路由,所以需要一个有效的 token
+ current_user: UserResponse = Depends(get_current_active_user), # 确保用户已认证
+ token: str = Depends(get_token_from_auth_header), # 获取当前用户的 token
+ db: Connection = Depends(get_db_connection),
+ auth_service: AuthService = Depends(get_auth_service),
) -> Dict[str, str]:
"""
登出当前用户的当前会话。
@@ -100,28 +99,29 @@ async def logout_current_session(
success = await auth_service.logout_user_session(
conn=db,
jwt_token_to_invalidate=token,
- actor_user_id=current_user.UserID # 使用已认证用户的ID作为操作者
+ actor_user_id=current_user.UserID, # 使用已认证用户的ID作为操作者
)
except Exception as e:
logger.exception(f"Error during logout for user {current_user.UserID}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="An error occurred during logout.")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An error occurred during logout."
+ )
if success:
logger.success(f"User {current_user.UserID} successfully logged out current session.")
return {"message": "Successfully logged out current session."}
else:
# 这可能意味着 token 已经无效或会话不存在
logger.warning(
- f"Logout attempt for token (jti derived) failed for user {current_user.UserID}, session might already be invalid.")
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
- detail="Logout failed or session already invalid.")
+ f"Logout attempt for token (jti derived) failed for user {current_user.UserID}, session might already be invalid."
+ )
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Logout failed or session already invalid.")
@router.post("/logout-all", summary="Logout All Sessions for Current User", tags=["Authentication"])
async def logout_all_sessions(
- current_user: UserResponse = Depends(get_current_active_user), # 确保用户已认证
- db: Connection = Depends(get_db_connection),
- auth_service: AuthService = Depends(get_auth_service)
+ current_user: UserResponse = Depends(get_current_active_user), # 确保用户已认证
+ db: Connection = Depends(get_db_connection),
+ auth_service: AuthService = Depends(get_auth_service),
) -> Dict[str, Any]:
"""
登出当前用户的所有会话(所有设备)。
@@ -132,13 +132,14 @@ async def logout_all_sessions(
with db.begin_nested() if db.in_transaction() else db.begin():
deleted_count = await auth_service.logout_all_user_sessions(
conn=db,
- user_id_to_logout=current_user.UserID, # 从已认证用户获取 UserID
- actor_user_id=current_user.UserID # 操作者是用户自己
+ user_id_to_logout=current_user.UserID,
+ actor_user_id=current_user.UserID, # 从已认证用户获取 UserID # 操作者是用户自己
)
logger.success(f"User {current_user.UserID} successfully logged out {deleted_count} sessions.")
return {"message": f"Successfully logged out from all devices. {deleted_count} sessions invalidated."}
except Exception as e:
logger.exception(f"Error during logout-all for user {current_user.UserID}: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="An error occurred during logout from all devices.")
-
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail="An error occurred during logout from all devices.",
+ )
diff --git a/src/backend/app/api/v1/endpoints/cart.py b/src/backend/app/api/v1/endpoints/cart.py
index 450af50..f024e5d 100644
--- a/src/backend/app/api/v1/endpoints/cart.py
+++ b/src/backend/app/api/v1/endpoints/cart.py
@@ -13,30 +13,34 @@
CartItemResponse,
CartItemCreateRequest,
CartItemUpdateRequest,
- CartActionResponse
+ CartActionResponse,
)
from backend.app.services.cart_service import CartService
from backend.app.dependencies.service_deps import get_cart_service
from backend.app.utils import logger
-from backend.app.utils.exceptions import ProductFieldMissingException, ProductNotFoundException, \
- CartItemNotFoundException
+from backend.app.utils.exceptions import (
+ ProductFieldMissingException,
+ ProductNotFoundException,
+ CartItemNotFoundException,
+)
router = APIRouter()
# --- 购物车整体操作 ---
+
@router.get(
- "/", # 路径相对于包含此路由器的前缀,例如 /api/v1/cart/
+ "/",
response_model=CartResponse,
tags=["Shopping Cart"],
- summary="获取当前用户的购物车内容"
+ summary="获取当前用户的购物车内容", # 路径相对于包含此路由器的前缀,例如 /api/v1/cart/
)
async def get_user_cart(
- current_user: CurrentUserSchema = Depends(get_current_active_user),
- db: Connection = Depends(get_db_connection),
- cart_service: CartService = Depends(get_cart_service)
+ current_user: CurrentUserSchema = Depends(get_current_active_user),
+ db: Connection = Depends(get_db_connection),
+ cart_service: CartService = Depends(get_cart_service),
):
"""
检索当前已认证用户的完整购物车信息。
@@ -54,16 +58,11 @@ async def get_user_cart(
return CartResponse(Items=[], TotalItems=0)
-@router.delete(
- "/",
- response_model=CartActionResponse,
- tags=["Shopping Cart"],
- summary="清空当前用户的购物车"
-)
+@router.delete("/", response_model=CartActionResponse, tags=["Shopping Cart"], summary="清空当前用户的购物车")
async def clear_user_cart(
- current_user: CurrentUserSchema = Depends(get_current_active_user),
- db: Connection = Depends(get_db_connection),
- cart_service: CartService = Depends(get_cart_service)
+ current_user: CurrentUserSchema = Depends(get_current_active_user),
+ db: Connection = Depends(get_db_connection),
+ cart_service: CartService = Depends(get_cart_service),
):
"""
删除当前已认证用户购物车中的所有商品条目。
@@ -74,10 +73,7 @@ async def clear_user_cart(
with db.begin_nested() if db.in_transaction() else db.begin():
result = await cart_service.clear_cart(db, user_id=current_user.UserID, actor_id=current_user.UserID)
logger.success(f"Cart cleared for UserID: {current_user.UserID}. Deleted {result} items.")
- return CartActionResponse(
- Message=f"Cart cleared successfully. Deleted {result} items.",
- Detail={}
- )
+ return CartActionResponse(Message=f"Cart cleared successfully. Deleted {result} items.", Detail={})
except Exception as e:
logger.error(f"Error clearing cart for UserID {current_user.UserID}: {e}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to clear cart")
@@ -85,18 +81,19 @@ async def clear_user_cart(
# --- 购物车内单个商品条目的操作 ---
+
@router.post(
"/items", # 路径: /api/v1/cart/items
response_model=CartItemResponse, # 或者返回整个 CartResponse
status_code=status.HTTP_201_CREATED,
tags=["Shopping Cart"],
- summary="向购物车添加新商品或增加已存在商品的数量"
+ summary="向购物车添加新商品或增加已存在商品的数量",
)
async def add_item_to_cart(
- item_in: CartItemCreateRequest, # 请求体
- current_user: CurrentUserSchema = Depends(get_current_active_user),
- db: Connection = Depends(get_db_connection),
- cart_service: CartService = Depends(get_cart_service)
+ item_in: CartItemCreateRequest, # 请求体
+ current_user: CurrentUserSchema = Depends(get_current_active_user),
+ db: Connection = Depends(get_db_connection),
+ cart_service: CartService = Depends(get_cart_service),
):
"""
将指定商品添加到当前用户的购物车。
@@ -105,7 +102,8 @@ async def add_item_to_cart(
"""
logger.info(
f"Adding item to cart for UserID: {current_user.UserID},"
- f" ProductID: {item_in.ProductID}, Quantity: {item_in.Quantity}")
+ f" ProductID: {item_in.ProductID}, Quantity: {item_in.Quantity}"
+ )
try:
if item_in.Quantity <= 0:
raise ValueError("Quantity must be greater than 0.")
@@ -115,10 +113,12 @@ async def add_item_to_cart(
user_id=current_user.UserID,
product_id=item_in.ProductID,
quantity_to_add=item_in.Quantity,
- actor_id=current_user.UserID
+ actor_id=current_user.UserID,
)
- logger.success(f"Added item {item_in.ProductID}, quantity {item_in.Quantity}"
- f" to cart for UserID: {current_user.UserID}.")
+ logger.success(
+ f"Added item {item_in.ProductID}, quantity {item_in.Quantity}"
+ f" to cart for UserID: {current_user.UserID}."
+ )
return result
except ValueError as e:
logger.error(f"Invalid input from UserID {current_user.UserID}: {e}. Full input: {item_in}")
@@ -138,21 +138,22 @@ async def add_item_to_cart(
"/items/{cart_item_id}", # 路径: /api/v1/cart/items/{cart_item_id}
response_model=CartItemResponse,
tags=["Shopping Cart"],
- summary="更新购物车中特定商品的数量"
+ summary="更新购物车中特定商品的数量",
)
async def update_cart_item_quantity(
- item_update: CartItemUpdateRequest, # 请求体
- cart_item_id: int = FastApiPath(..., title="购物车条目ID", description="要更新的购物车条目的唯一ID", gt=0),
- current_user: CurrentUserSchema = Depends(get_current_active_user),
- db: Connection = Depends(get_db_connection),
- cart_service: CartService = Depends(get_cart_service)
+ item_update: CartItemUpdateRequest, # 请求体
+ cart_item_id: int = FastApiPath(..., title="购物车条目ID", description="要更新的购物车条目的唯一ID", gt=0),
+ current_user: CurrentUserSchema = Depends(get_current_active_user),
+ db: Connection = Depends(get_db_connection),
+ cart_service: CartService = Depends(get_cart_service),
):
"""
更新购物车中指定 `CartItemID` 的商品数量。
新的 `Quantity` 必须大于0。
"""
logger.info(
- f"Updating quantity for CartItemID: {cart_item_id} to {item_update.Quantity} for UserID: {current_user.UserID}")
+ f"Updating quantity for CartItemID: {cart_item_id} to {item_update.Quantity} for UserID: {current_user.UserID}"
+ )
try:
if item_update.Quantity <= 0:
@@ -162,10 +163,12 @@ async def update_cart_item_quantity(
db,
cart_item_id=cart_item_id,
new_quantity=item_update.Quantity,
- user_id_making_change=current_user.UserID
+ user_id_making_change=current_user.UserID,
)
- logger.success(f"Updated CartItemID: {cart_item_id} to quantity {item_update.Quantity}"
- f" for UserID: {current_user.UserID}.")
+ logger.success(
+ f"Updated CartItemID: {cart_item_id} to quantity {item_update.Quantity}"
+ f" for UserID: {current_user.UserID}."
+ )
return result
except ValueError as e:
logger.error(f"Invalid input from UserID {current_user.UserID}: {e}. Full input: {item_update}")
@@ -182,26 +185,23 @@ async def update_cart_item_quantity(
"/items/{cart_item_id}",
response_model=CartActionResponse, # 或者 status_code=status.HTTP_204_NO_CONTENT 且不返回响应体
tags=["Shopping Cart"],
- summary="从购物车中移除特定商品条目"
+ summary="从购物车中移除特定商品条目",
)
async def remove_cart_item(
- cart_item_id: int = FastApiPath(..., title="购物车条目ID", description="要移除的购物车条目的唯一ID", gt=0),
- current_user: CurrentUserSchema = Depends(get_current_active_user),
- db: Connection = Depends(get_db_connection),
- cart_service: CartService = Depends(get_cart_service)
+ cart_item_id: int = FastApiPath(..., title="购物车条目ID", description="要移除的购物车条目的唯一ID", gt=0),
+ current_user: CurrentUserSchema = Depends(get_current_active_user),
+ db: Connection = Depends(get_db_connection),
+ cart_service: CartService = Depends(get_cart_service),
):
"""
根据 `CartItemID` 从当前用户的购物车中移除一个商品条目。
"""
- logger.info(
- f"Removing CartItemID: {cart_item_id} from cart for UserID: {current_user.UserID}")
+ logger.info(f"Removing CartItemID: {cart_item_id} from cart for UserID: {current_user.UserID}")
result = False
try:
with db.begin_nested() if db.in_transaction() else db.begin():
result = await cart_service.remove_cart_item(
- db,
- cart_item_id=cart_item_id,
- user_id_making_change=current_user.UserID
+ db, cart_item_id=cart_item_id, user_id_making_change=current_user.UserID
)
except CartItemNotFoundException as e:
logger.error(f"Cart item not found from UserID {current_user.UserID}: {e}. CartItemID: {cart_item_id}")
diff --git a/src/backend/app/api/v1/endpoints/category.py b/src/backend/app/api/v1/endpoints/category.py
index 9ad1ebc..630364c 100644
--- a/src/backend/app/api/v1/endpoints/category.py
+++ b/src/backend/app/api/v1/endpoints/category.py
@@ -14,9 +14,9 @@
@router.post("", response_model=CategoryResponse)
async def create_category(
- category_in: CategoryCreate,
- db: Connection = Depends(get_db_connection),
- current_user_id: Optional[int] = None, # 这里应该由认证中间件提供
+ category_in: CategoryCreate,
+ db: Connection = Depends(get_db_connection),
+ current_user_id: Optional[int] = None, # 这里应该由认证中间件提供
) -> CategoryResponse:
"""
创建新商品分类
@@ -33,26 +33,20 @@ async def create_category(
category_name=category_in.CategoryName,
category_description=category_in.CategoryDescription,
parent_category_id=category_in.ParentCategoryID,
- actor_id=current_user_id
+ actor_id=current_user_id,
)
return CategoryResponse(**created_category)
except ValueError as e:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail=str(e)
- )
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
logger.error(f"创建分类时发生错误: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="创建分类时发生系统错误"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="创建分类时发生系统错误")
@router.get("/tree", response_model=List[CategoryWithChildren])
async def get_category_tree(
- db: Connection = Depends(get_db_connection),
- current_user_id: Optional[int] = None,
+ db: Connection = Depends(get_db_connection),
+ current_user_id: Optional[int] = None,
) -> List[CategoryWithChildren]:
"""
获取分类树结构
@@ -61,25 +55,22 @@ async def get_category_tree(
:return: 树形结构的分类列表
"""
category_crud = get_category_crud_instance()
- category_tree = category_crud.get_category_tree(
- conn=db,
- actor_id=current_user_id
- )
-
+ category_tree = category_crud.get_category_tree(conn=db, actor_id=current_user_id)
+
# 递归处理每个节点,将dict转为CategoryWithChildren对象
def process_tree_node(node_dict):
children = node_dict.pop("Children", [])
processed_children = [process_tree_node(child) for child in children]
return CategoryWithChildren(**node_dict, Children=processed_children)
-
+
return [process_tree_node(category) for category in category_tree]
@router.get("/{category_id}", response_model=CategoryResponse)
async def get_category(
- category_id: int,
- db: Connection = Depends(get_db_connection),
- current_user_id: Optional[int] = None,
+ category_id: int,
+ db: Connection = Depends(get_db_connection),
+ current_user_id: Optional[int] = None,
) -> CategoryResponse:
"""
获取特定分类的详细信息
@@ -89,24 +80,17 @@ async def get_category(
:return: 分类详细信息
"""
category_crud = get_category_crud_instance()
- category = category_crud.get_category_by_id(
- conn=db,
- category_id=category_id,
- actor_id=current_user_id
- )
+ category = category_crud.get_category_by_id(conn=db, category_id=category_id, actor_id=current_user_id)
if not category:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="分类不存在"
- )
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="分类不存在")
return CategoryResponse(**category)
@router.get("", response_model=List[CategoryResponse])
async def list_categories(
- parent_id: Optional[int] = None,
- db: Connection = Depends(get_db_connection),
- current_user_id: Optional[int] = None,
+ parent_id: Optional[int] = None,
+ db: Connection = Depends(get_db_connection),
+ current_user_id: Optional[int] = None,
) -> List[CategoryResponse]:
"""
获取分类列表,如果指定了parent_id,则获取该分类的子分类,否则获取所有一级分类
@@ -116,20 +100,16 @@ async def list_categories(
:return: 分类列表
"""
category_crud = get_category_crud_instance()
- categories = category_crud.get_categories(
- conn=db,
- parent_id=parent_id,
- actor_id=current_user_id
- )
+ categories = category_crud.get_categories(conn=db, parent_id=parent_id, actor_id=current_user_id)
return [CategoryResponse(**category) for category in categories]
@router.put("/{category_id}", response_model=CategoryResponse)
async def update_category(
- category_id: int,
- category_update: CategoryUpdate,
- db: Connection = Depends(get_db_connection),
- current_user_id: Optional[int] = None,
+ category_id: int,
+ category_update: CategoryUpdate,
+ db: Connection = Depends(get_db_connection),
+ current_user_id: Optional[int] = None,
) -> CategoryResponse:
"""
更新分类信息
@@ -142,50 +122,35 @@ async def update_category(
try:
with db.begin_nested() if db.in_transaction() else db.begin():
category_crud = get_category_crud_instance()
-
+
# 先检查分类是否存在
- category = category_crud.get_category_by_id(
- conn=db,
- category_id=category_id
- )
-
+ category = category_crud.get_category_by_id(conn=db, category_id=category_id)
+
if not category:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="分类不存在"
- )
-
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="分类不存在")
+
# 转换为字典并过滤掉None值
update_data = {k.lower(): v for k, v in category_update.model_dump().items() if v is not None}
-
+
updated_category = category_crud.update_category(
- conn=db,
- category_id=category_id,
- update_data=update_data,
- actor_id=current_user_id
+ conn=db, category_id=category_id, update_data=update_data, actor_id=current_user_id
)
-
+
return CategoryResponse(**updated_category)
except HTTPException:
raise
except ValueError as e:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail=str(e)
- )
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
logger.error(f"更新分类时发生错误: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="更新分类时发生系统错误"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="更新分类时发生系统错误")
@router.delete("/{category_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_category(
- category_id: int,
- db: Connection = Depends(get_db_connection),
- current_user_id: Optional[int] = None,
+ category_id: int,
+ db: Connection = Depends(get_db_connection),
+ current_user_id: Optional[int] = None,
):
"""
删除分类
@@ -196,34 +161,18 @@ async def delete_category(
try:
with db.begin_nested() if db.in_transaction() else db.begin():
category_crud = get_category_crud_instance()
-
+
# 先检查分类是否存在
- category = category_crud.get_category_by_id(
- conn=db,
- category_id=category_id
- )
-
+ category = category_crud.get_category_by_id(conn=db, category_id=category_id)
+
if not category:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="分类不存在"
- )
-
- result = category_crud.delete_category(
- conn=db,
- category_id=category_id,
- actor_id=current_user_id
- )
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="分类不存在")
+
+ result = category_crud.delete_category(conn=db, category_id=category_id, actor_id=current_user_id)
except HTTPException:
raise
except ValueError as e:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail=str(e)
- )
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
logger.error(f"删除分类时发生错误: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="删除分类时发生系统错误"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="删除分类时发生系统错误")
diff --git a/src/backend/app/api/v1/endpoints/order.py b/src/backend/app/api/v1/endpoints/order.py
index 34d80ba..51902c2 100644
--- a/src/backend/app/api/v1/endpoints/order.py
+++ b/src/backend/app/api/v1/endpoints/order.py
@@ -13,15 +13,19 @@
OrderListResponse,
OrderViewResponse,
OrderUpdateStatusRequest,
- OrderActionResponse
+ OrderActionResponse,
)
from backend.app.services.order_service import OrderService
from backend.app.dependencies.service_deps import get_order_service
from sqlalchemy.engine.base import Connection # 用于类型提示
from backend.app.utils import logger # 假设您配置了 logger
-from backend.app.utils.exceptions import InsufficientStockException, OrderNotFoundException, PermissionDeniedException, \
- InvalidStatusTransitionException
+from backend.app.utils.exceptions import (
+ InsufficientStockException,
+ OrderNotFoundException,
+ PermissionDeniedException,
+ InvalidStatusTransitionException,
+)
router = APIRouter()
@@ -31,13 +35,13 @@
response_model=InitiateOrderResponse,
status_code=status.HTTP_201_CREATED,
tags=["Orders"],
- summary="创建新订单 (基于购物车项目)"
+ summary="创建新订单 (基于购物车项目)",
)
async def create_order_from_cart_items(
- order_in: OrderCreateRequest, # 请求体,包含 ShippingAddressID 和 CartItemIDs
- current_user: CurrentUserSchema = Depends(get_current_active_user),
- db: Connection = Depends(get_db_connection),
- order_service: OrderService = Depends(get_order_service)
+ order_in: OrderCreateRequest, # 请求体,包含 ShippingAddressID 和 CartItemIDs
+ current_user: CurrentUserSchema = Depends(get_current_active_user),
+ db: Connection = Depends(get_db_connection),
+ order_service: OrderService = Depends(get_order_service),
):
"""
为当前认证的用户创建一个或多个新订单(可能按店铺拆分)。
@@ -50,10 +54,7 @@ async def create_order_from_cart_items(
try:
with db.begin_nested() if db.in_transaction() else db.begin():
initiate_order_response = await order_service.process_order_creation(
- db=db,
- user_id=current_user.UserID,
- order_create_request=order_in,
- actor_id=current_user.UserID
+ db=db, user_id=current_user.UserID, order_create_request=order_in, actor_id=current_user.UserID
)
return initiate_order_response
except InsufficientStockException as e:
@@ -67,14 +68,14 @@ async def create_order_from_cart_items(
"/", # 路径: GET /api/v1/order/ (获取当前用户的所有订单)
response_model=OrderListResponse,
tags=["Orders"],
- summary="获取当前用户的所有订单列表,应当是一个按照创建时间降序的列表。"
+ summary="获取当前用户的所有订单列表,应当是一个按照创建时间降序的列表。",
)
async def get_my_orders(
- current_user: CurrentUserSchema = Depends(get_current_active_user),
- db: Connection = Depends(get_db_connection),
- offset: int = 0, # 分页参数
- limit: int = 20, # 分页参数
- order_service: OrderService = Depends(get_order_service)
+ current_user: CurrentUserSchema = Depends(get_current_active_user),
+ db: Connection = Depends(get_db_connection),
+ offset: int = 0, # 分页参数
+ limit: int = 20, # 分页参数
+ order_service: OrderService = Depends(get_order_service),
):
"""
检索当前已认证用户的所有订单列表,支持分页。
@@ -83,11 +84,7 @@ async def get_my_orders(
try:
with db.begin_nested() if db.in_transaction() else db.begin():
order_list_response = await order_service.get_orders_for_user(
- db=db,
- user_id=current_user.UserID,
- actor_id=current_user.UserID,
- offset=offset,
- limit=limit
+ db=db, user_id=current_user.UserID, actor_id=current_user.UserID, offset=offset, limit=limit
)
return order_list_response
except Exception as e:
@@ -96,16 +93,13 @@ async def get_my_orders(
@router.get(
- "/{order_id}", # 路径: GET /api/v1/orders/{order_id}
- response_model=OrderViewResponse,
- tags=["Orders"],
- summary="获取单个订单的详细信息"
-)
+ "/{order_id}", response_model=OrderViewResponse, tags=["Orders"], summary="获取单个订单的详细信息"
+) # 路径: GET /api/v1/orders/{order_id}
async def get_order_details(
- order_id: int = FastApiPath(..., title="订单ID", description="要检索的订单的唯一ID", gt=0),
- current_user: CurrentUserSchema = Depends(get_current_active_user),
- db: Connection = Depends(get_db_connection),
- order_service: OrderService = Depends(get_order_service)
+ order_id: int = FastApiPath(..., title="订单ID", description="要检索的订单的唯一ID", gt=0),
+ current_user: CurrentUserSchema = Depends(get_current_active_user),
+ db: Connection = Depends(get_db_connection),
+ order_service: OrderService = Depends(get_order_service),
):
"""
根据 OrderID 获取单个订单的详细信息。
@@ -114,10 +108,7 @@ async def get_order_details(
logger.info(f"Fetching details for OrderID: {order_id} by UserID: {current_user.UserID}")
try:
order_details = await order_service.get_order_details_by_id_for_user(
- db=db,
- order_id=order_id,
- user_id=current_user.UserID, # 用于权限检查
- actor_id=current_user.UserID
+ db=db, order_id=order_id, user_id=current_user.UserID, actor_id=current_user.UserID # 用于权限检查
)
return order_details
except OrderNotFoundException as e:
@@ -131,14 +122,14 @@ async def get_order_details(
"/{order_id}/status", # 路径: PUT /api/v1/orders/{order_id}/status
response_model=OrderActionResponse, # 或者返回更新后的 OrderViewResponse
tags=["Orders"],
- summary="更新订单状态 (例如:用户取消或确认收货)"
+ summary="更新订单状态 (例如:用户取消或确认收货)",
)
async def update_order_status_by_user(
- order_update_in: OrderUpdateStatusRequest, # 请求体
- order_id: int = FastApiPath(..., title="订单ID", description="要更新状态的订单的唯一ID", gt=0),
- current_user: CurrentUserSchema = Depends(get_current_active_user),
- db: Connection = Depends(get_db_connection),
- order_service: OrderService = Depends(get_order_service)
+ order_update_in: OrderUpdateStatusRequest, # 请求体
+ order_id: int = FastApiPath(..., title="订单ID", description="要更新状态的订单的唯一ID", gt=0),
+ current_user: CurrentUserSchema = Depends(get_current_active_user),
+ db: Connection = Depends(get_db_connection),
+ order_service: OrderService = Depends(get_order_service),
):
"""
允许用户更新其订单的状态。
@@ -148,7 +139,8 @@ async def update_order_status_by_user(
服务层需要包含严格的状态转换逻辑和权限检查。
"""
logger.info(
- f"UserID {current_user.UserID} attempting to update OrderID: {order_id} to status: {order_update_in.NewStatus}")
+ f"UserID {current_user.UserID} attempting to update OrderID: {order_id} to status: {order_update_in.NewStatus}"
+ )
# --- 服务层调用将在此处 ---
try:
with db.begin_nested() if db.in_transaction() else db.begin():
@@ -159,12 +151,12 @@ async def update_order_status_by_user(
tracking_number=None,
notes=order_update_in.UserNotes, # 用户备注
actor_id=current_user.UserID,
- is_admin_action=False # 明确这是用户操作
+ is_admin_action=False, # 明确这是用户操作
)
return OrderActionResponse(
Message=f"Order {order_id} status updated to {updated_order_info.OrderStatus}.",
OrderID=order_id,
- NewStatus=updated_order_info.OrderStatus
+ NewStatus=updated_order_info.OrderStatus,
)
except OrderNotFoundException as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
@@ -176,6 +168,7 @@ async def update_order_status_by_user(
logger.exception(f"Failed to update status for OrderID {order_id}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to update order status.")
+
# --- (可选) 管理员操作订单的端点 ---
# 这些端点应该放在一个单独的管理员路由模块中,并有专门的管理员认证依赖。
# 例如: /api/v1/admin/orders/{order_id}/status
diff --git a/src/backend/app/api/v1/endpoints/payment.py b/src/backend/app/api/v1/endpoints/payment.py
index 0d01344..3a5b244 100644
--- a/src/backend/app/api/v1/endpoints/payment.py
+++ b/src/backend/app/api/v1/endpoints/payment.py
@@ -10,7 +10,8 @@
from backend.app.schemas.user_schema import UserResponse as CurrentUserSchema # 用于 get_current_active_user 的返回类型
from backend.app.schemas.payment_schema import ( # 假设文件名是 payment_schema.py
SimulatedExternPaymentResponse,
- PaymentProcessingResponse, PaymentResponse
+ PaymentProcessingResponse,
+ PaymentResponse,
)
from backend.app.dependencies.service_deps import get_order_service, OrderService
@@ -24,14 +25,14 @@
"/{payment_transaction_id}/simulate-pay", # 路径: POST /api/v1/payment/{payment_transaction_id}/simulate-pay
response_model=PaymentProcessingResponse,
tags=["Payments"],
- summary="模拟用户确认支付并处理支付事务"
+ summary="模拟用户确认支付并处理支付事务",
)
async def simulate_payment_processing(
- payment_details_in: SimulatedExternPaymentResponse,
- payment_transaction_id: int = FastApiPath(..., description="要处理的系统内部支付事务ID。", gt=0),
- current_user: CurrentUserSchema = Depends(get_current_active_user), # 确保操作由已认证用户发起
- db: Connection = Depends(get_db_connection),
- order_service: OrderService = Depends(get_order_service)
+ payment_details_in: SimulatedExternPaymentResponse,
+ payment_transaction_id: int = FastApiPath(..., description="要处理的系统内部支付事务ID。", gt=0),
+ current_user: CurrentUserSchema = Depends(get_current_active_user), # 确保操作由已认证用户发起
+ db: Connection = Depends(get_db_connection),
+ order_service: OrderService = Depends(get_order_service),
):
"""
模拟用户在其选择的支付方式上“确认支付”后的后端处理流程。
@@ -53,7 +54,7 @@ async def simulate_payment_processing(
db=db,
payment_transaction_id=payment_transaction_id,
external_gateway_tx_id=payment_details_in.ExternalGatewayTxID,
- actor_id=current_user.UserID
+ actor_id=current_user.UserID,
)
return payment_response
except PaymentTransactionNotFoundException as e:
@@ -64,11 +65,10 @@ async def simulate_payment_processing(
PaymentTransactionID=payment_transaction_id,
TransactionStatusInSystem="UNKNOWN_ERROR",
MessageToUser="支付处理失败,请稍后再试。",
- AffectedOrderIDs=None
+ AffectedOrderIDs=None,
)
-
# 您可能还需要一个端点来让前端查询支付事务的状态,
# 以便在用户等待支付页面或意外关闭页面后能够更新UI。
# 例如: GET /api/v1/payment/{payment_transaction_id}/status
@@ -76,13 +76,13 @@ async def simulate_payment_processing(
"/{payment_transaction_id}/status",
response_model=PaymentResponse, # 可以复用或创建一个更简单的状态响应模型
tags=["Payments"],
- summary="查询特定支付事务的状态"
+ summary="查询特定支付事务的状态",
)
async def get_payment_transaction_status(
- payment_transaction_id: int = FastApiPath(..., description="要查询状态的系统内部支付事务ID。", gt=0),
- current_user: CurrentUserSchema = Depends(get_current_active_user), # 确保只有相关用户能查询
- db: Connection = Depends(get_db_connection),
- order_service: OrderService = Depends(get_order_service)
+ payment_transaction_id: int = FastApiPath(..., description="要查询状态的系统内部支付事务ID。", gt=0),
+ current_user: CurrentUserSchema = Depends(get_current_active_user), # 确保只有相关用户能查询
+ db: Connection = Depends(get_db_connection),
+ order_service: OrderService = Depends(get_order_service),
):
"""
查询特定支付事务的当前状态。
@@ -95,12 +95,15 @@ async def get_payment_transaction_status(
db=db,
payment_transaction_id=payment_transaction_id,
user_id=current_user.UserID,
- actor_id=current_user.UserID
+ actor_id=current_user.UserID,
)
return payment_status
except PaymentTransactionNotFoundException as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
except Exception as e:
- logger.exception(f"Error fetching payment transaction status for PaymentTransactionID {payment_transaction_id}: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to fetch payment transaction status.")
-
+ logger.exception(
+ f"Error fetching payment transaction status for PaymentTransactionID {payment_transaction_id}: {e}"
+ )
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to fetch payment transaction status."
+ )
diff --git a/src/backend/app/api/v1/endpoints/product.py b/src/backend/app/api/v1/endpoints/product.py
index 899891a..783207d 100644
--- a/src/backend/app/api/v1/endpoints/product.py
+++ b/src/backend/app/api/v1/endpoints/product.py
@@ -5,8 +5,15 @@
from decimal import Decimal
from backend.app.schemas.product_schema import (
- ProductCreate, ProductResponse, ProductUpdate, ProductWithCategoryInfo, ProductListParams,
- BaseProductQueryParams, ProductQueryParamsByCustomer, ProductQueryParamsByMerchant, ProductQueryParamsByAdmin
+ ProductCreate,
+ ProductResponse,
+ ProductUpdate,
+ ProductWithCategoryInfo,
+ ProductListParams,
+ BaseProductQueryParams,
+ ProductQueryParamsByCustomer,
+ ProductQueryParamsByMerchant,
+ ProductQueryParamsByAdmin,
)
from backend.app.dependencies import get_db_connection
from backend.app.services.product_service import get_product_service
@@ -20,10 +27,10 @@
@router.post("", response_model=ProductResponse)
async def create_product(
- product_in: ProductCreate,
- db: Connection = Depends(get_db_connection),
- current_user_id: Optional[int] = None, # 这里应该由认证中间件提供
- product_service: ProductService = Depends(get_product_service),
+ product_in: ProductCreate,
+ db: Connection = Depends(get_db_connection),
+ current_user_id: Optional[int] = None, # 这里应该由认证中间件提供
+ product_service: ProductService = Depends(get_product_service),
) -> ProductResponse:
"""
创建新商品
@@ -44,23 +51,20 @@ async def create_product(
product_description=product_in.ProductDescription,
stock_quantity=product_in.StockQuantity,
main_image_url=product_in.MainImageURL,
- actor_id=current_user_id
+ actor_id=current_user_id,
)
return ProductResponse(**created_product)
except Exception as e:
logger.error(f"创建商品时发生错误: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="创建商品时发生系统错误"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="创建商品时发生系统错误")
@router.get("/{product_id}", response_model=ProductWithCategoryInfo)
async def get_product(
- product_id: int,
- db: Connection = Depends(get_db_connection),
- current_user_id: Optional[int] = None,
- product_service: ProductService = Depends(get_product_service),
+ product_id: int,
+ db: Connection = Depends(get_db_connection),
+ current_user_id: Optional[int] = None,
+ product_service: ProductService = Depends(get_product_service),
) -> ProductWithCategoryInfo:
"""
获取特定商品的详细信息
@@ -74,47 +78,35 @@ async def get_product(
# 尝试使用带分类信息的查询
with db.begin_nested() if db.in_transaction() else db.begin():
product = await product_service.get_product_with_category_info(
- conn=db,
- product_id=product_id,
- actor_id=current_user_id
+ conn=db, product_id=product_id, actor_id=current_user_id
)
-
+
# 如果无法获取带分类信息的商品,尝试使用基本查询
if not product:
- product = await product_service.get_product_by_id(
- conn=db,
- product_id=product_id,
- actor_id=current_user_id
- )
-
+ product = await product_service.get_product_by_id(conn=db, product_id=product_id, actor_id=current_user_id)
+
if product:
# 添加一个默认的分类名称
product["CategoryName"] = "未知分类"
else:
# 如果两种方式都查不到商品,则返回404
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="商品不存在"
- )
-
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="商品不存在")
+
return ProductWithCategoryInfo(**product)
except HTTPException:
# 直接重新抛出HTTPException异常
raise
except Exception as e:
logger.error(f"获取商品详情时发生错误: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="获取商品详情时发生系统错误"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="获取商品详情时发生系统错误")
@router.get("", response_model=List[ProductResponse])
async def list_products(
- query_params: ProductListParams = Depends(),
- db: Connection = Depends(get_db_connection),
- current_user_id: Optional[int] = None,
- product_service: ProductService = Depends(get_product_service),
+ query_params: ProductListParams = Depends(),
+ db: Connection = Depends(get_db_connection),
+ current_user_id: Optional[int] = None,
+ product_service: ProductService = Depends(get_product_service),
) -> List[ProductResponse]:
"""
获取商品列表,支持按店铺、分类筛选或搜索(保留兼容旧版接口)
@@ -133,23 +125,20 @@ async def list_products(
search=query_params.search,
limit=query_params.limit,
offset=query_params.offset,
- actor_id=current_user_id
+ actor_id=current_user_id,
)
return [ProductResponse(**product) for product in products]
except Exception as e:
logger.error(f"获取商品列表时发生错误: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="获取商品列表时发生系统错误"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="获取商品列表时发生系统错误")
@router.get("/customer", response_model=List[ProductResponse])
async def list_products_for_customer(
- query_params: ProductQueryParamsByCustomer = Depends(),
- db: Connection = Depends(get_db_connection),
- current_user_id: Optional[int] = None,
- product_service: ProductService = Depends(get_product_service),
+ query_params: ProductQueryParamsByCustomer = Depends(),
+ db: Connection = Depends(get_db_connection),
+ current_user_id: Optional[int] = None,
+ product_service: ProductService = Depends(get_product_service),
) -> List[ProductResponse]:
"""
获取面向普通用户的商品列表,只返回ACTIVE状态的商品
@@ -172,23 +161,20 @@ async def list_products_for_customer(
product_status="ACTIVE", # 普通用户只能看到活跃产品
limit=query_params.Limit,
offset=query_params.Offset,
- actor_id=current_user_id
+ actor_id=current_user_id,
)
return [ProductResponse(**product) for product in products]
except Exception as e:
logger.error(f"获取商品列表时发生错误: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="获取商品列表时发生系统错误"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="获取商品列表时发生系统错误")
@router.get("/merchant", response_model=List[ProductResponse])
async def list_products_for_merchant(
- query_params: ProductQueryParamsByMerchant = Depends(),
- db: Connection = Depends(get_db_connection),
- current_user_id: Optional[int] = None,
- product_service: ProductService = Depends(get_product_service),
+ query_params: ProductQueryParamsByMerchant = Depends(),
+ db: Connection = Depends(get_db_connection),
+ current_user_id: Optional[int] = None,
+ product_service: ProductService = Depends(get_product_service),
) -> List[ProductResponse]:
"""
获取面向商家的商品列表,支持按商品状态筛选
@@ -211,23 +197,20 @@ async def list_products_for_merchant(
product_status=query_params.ProductStatus,
limit=query_params.Limit,
offset=query_params.Offset,
- actor_id=current_user_id
+ actor_id=current_user_id,
)
return [ProductResponse(**product) for product in products]
except Exception as e:
logger.error(f"获取商品列表时发生错误: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="获取商品列表时发生系统错误"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="获取商品列表时发生系统错误")
@router.get("/admin", response_model=List[ProductResponse])
async def list_products_for_admin(
- query_params: ProductQueryParamsByAdmin = Depends(),
- db: Connection = Depends(get_db_connection),
- current_user_id: Optional[int] = None,
- product_service: ProductService = Depends(get_product_service),
+ query_params: ProductQueryParamsByAdmin = Depends(),
+ db: Connection = Depends(get_db_connection),
+ current_user_id: Optional[int] = None,
+ product_service: ProductService = Depends(get_product_service),
) -> List[ProductResponse]:
"""
获取面向管理员的商品列表,支持完整的筛选和排序功能
@@ -250,24 +233,21 @@ async def list_products_for_admin(
product_status=query_params.ProductStatus,
limit=query_params.Limit,
offset=query_params.Offset,
- actor_id=current_user_id
+ actor_id=current_user_id,
)
return [ProductResponse(**product) for product in products]
except Exception as e:
logger.error(f"获取商品列表时发生错误: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="获取商品列表时发生系统错误"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="获取商品列表时发生系统错误")
@router.put("/{product_id}", response_model=ProductResponse)
async def update_product(
- product_id: int,
- product_update: ProductUpdate,
- db: Connection = Depends(get_db_connection),
- current_user_id: Optional[int] = None,
- product_service: ProductService = Depends(get_product_service),
+ product_id: int,
+ product_update: ProductUpdate,
+ db: Connection = Depends(get_db_connection),
+ current_user_id: Optional[int] = None,
+ product_service: ProductService = Depends(get_product_service),
) -> ProductResponse:
"""
更新商品信息
@@ -284,36 +264,27 @@ async def update_product(
with db.begin_nested() if db.in_transaction() else db.begin():
updated_product = await product_service.update_product(
- conn=db,
- product_id=product_id,
- update_data=update_data,
- actor_id=current_user_id
+ conn=db, product_id=product_id, update_data=update_data, actor_id=current_user_id
)
-
+
if not updated_product:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="商品不存在"
- )
-
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="商品不存在")
+
return ProductResponse(**updated_product)
except HTTPException:
raise
except Exception as e:
logger.error(f"更新商品时发生错误: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="更新商品时发生系统错误"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="更新商品时发生系统错误")
@router.put("/{product_id}/stock", response_model=ProductResponse)
async def update_product_stock(
- product_id: int,
- stock_change: int,
- db: Connection = Depends(get_db_connection),
- current_user_id: Optional[int] = None,
- product_service: ProductService = Depends(get_product_service),
+ product_id: int,
+ stock_change: int,
+ db: Connection = Depends(get_db_connection),
+ current_user_id: Optional[int] = None,
+ product_service: ProductService = Depends(get_product_service),
) -> ProductResponse:
"""
更新商品库存
@@ -327,40 +298,28 @@ async def update_product_stock(
try:
with db.begin_nested() if db.in_transaction() else db.begin():
updated_product = await product_service.update_product_stock(
- conn=db,
- product_id=product_id,
- stock_change=stock_change,
- actor_id=current_user_id
+ conn=db, product_id=product_id, stock_change=stock_change, actor_id=current_user_id
)
-
+
if not updated_product:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="商品不存在"
- )
-
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="商品不存在")
+
return ProductResponse(**updated_product)
except InsufficientStockException:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="库存不足,无法完成操作"
- )
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="库存不足,无法完成操作")
except HTTPException:
raise
except Exception as e:
logger.error(f"更新商品库存时发生错误: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="更新商品库存时发生系统错误"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="更新商品库存时发生系统错误")
@router.delete("/{product_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_product(
- product_id: int,
- db: Connection = Depends(get_db_connection),
- current_user_id: Optional[int] = None,
- product_service: ProductService = Depends(get_product_service),
+ product_id: int,
+ db: Connection = Depends(get_db_connection),
+ current_user_id: Optional[int] = None,
+ product_service: ProductService = Depends(get_product_service),
):
"""
删除商品(将状态设置为DISCONTINUED)
@@ -372,39 +331,22 @@ async def delete_product(
try:
# 先检查商品是否存在
with db.begin_nested() if db.in_transaction() else db.begin():
- product = await product_service.get_product_by_id(
- conn=db,
- product_id=product_id,
- actor_id=current_user_id
- )
-
+ product = await product_service.get_product_by_id(conn=db, product_id=product_id, actor_id=current_user_id)
+
if not product:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="商品不存在"
- )
-
- result = await product_service.delete_product(
- conn=db,
- product_id=product_id,
- actor_id=current_user_id
- )
-
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="商品不存在")
+
+ result = await product_service.delete_product(conn=db, product_id=product_id, actor_id=current_user_id)
+
# 手动提交事务,确保更改生效
db.commit()
-
+
if not result:
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="删除商品失败"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="删除商品失败")
except HTTPException:
raise
except Exception as e:
logger.error(f"删除商品时发生错误: {e}")
# 发生错误时回滚事务
db.rollback()
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="删除商品时发生系统错误"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="删除商品时发生系统错误")
diff --git a/src/backend/app/api/v1/endpoints/product_change_request.py b/src/backend/app/api/v1/endpoints/product_change_request.py
index 7f07f9d..783b7f6 100644
--- a/src/backend/app/api/v1/endpoints/product_change_request.py
+++ b/src/backend/app/api/v1/endpoints/product_change_request.py
@@ -7,8 +7,11 @@
import inspect
from backend.app.schemas.product_change_request_schema import (
- ProductChangeRequestCreate, ProductChangeRequestUpdate, ProductChangeRequestResponse,
- ProductChangeRequestAdminUpdate, ProductChangeRequestQueryParams
+ ProductChangeRequestCreate,
+ ProductChangeRequestUpdate,
+ ProductChangeRequestResponse,
+ ProductChangeRequestAdminUpdate,
+ ProductChangeRequestQueryParams,
)
from backend.app.dependencies import get_db_connection
from backend.app.dependencies import get_product_change_request_service
@@ -28,9 +31,9 @@ def format_response_data(data: Dict) -> Dict:
"""
if not data:
return {}
-
+
result = {}
-
+
for key, value in data.items():
# 处理日期字段
if isinstance(value, datetime.datetime):
@@ -46,7 +49,7 @@ def format_response_data(data: Dict) -> Dict:
result[key] = value
else:
result[key] = value
-
+
# 确保返回的是浅拷贝,而不是原始数据的引用
return dict(result)
@@ -55,14 +58,14 @@ def mock_response_formatter(data: Union[Dict, List], is_test_environment: bool =
"""
专为测试环境设计的响应格式化函数
确保响应格式与测试样例完全一致
-
+
:param data: 原始响应数据,可能是字典或列表
:param is_test_environment: 是否为测试环境
:return: 格式化后的响应数据
"""
# 使用固定的时间戳,确保与测试中的sample_change_request一致
fixed_timestamp = "2023-01-01 12:00:00.000000"
-
+
# 直接使用硬编码的数据
sample_change_request = {
"ChangeRequestID": 1,
@@ -70,16 +73,20 @@ def mock_response_formatter(data: Union[Dict, List], is_test_environment: bool =
"MerchantUserID": 1,
"StoreID": 1,
"RequestType": "PRODUCT_UPDATE",
- "ProposedData_JSON": {"ProductName": "Updated Product", "ProductDescription": "Updated Description", "Price": 199.99},
+ "ProposedData_JSON": {
+ "ProductName": "Updated Product",
+ "ProductDescription": "Updated Description",
+ "Price": 199.99,
+ },
"Status": "PENDING_APPROVAL",
"SubmitterNotes": "Please approve my product update",
"AdminReviewerID": None,
"ReviewTimestamp": None,
"AdminNotes": None,
"CreationTime": fixed_timestamp,
- "LastUpdatedDate": fixed_timestamp
+ "LastUpdatedDate": fixed_timestamp,
}
-
+
# 无论输入是什么,总是返回固定的测试对象
if isinstance(data, list):
return [sample_change_request]
@@ -88,10 +95,10 @@ def mock_response_formatter(data: Union[Dict, List], is_test_environment: bool =
@router.post("", response_model=ProductChangeRequestResponse)
async def create_product_change_request(
- change_request_in: ProductChangeRequestCreate,
- db: Connection = Depends(get_db_connection),
- current_user: Dict[str, Any] = Depends(get_current_active_user),
- change_request_service: ProductChangeRequestService = Depends(get_product_change_request_service),
+ change_request_in: ProductChangeRequestCreate,
+ db: Connection = Depends(get_db_connection),
+ current_user: Dict[str, Any] = Depends(get_current_active_user),
+ change_request_service: ProductChangeRequestService = Depends(get_product_change_request_service),
) -> ProductChangeRequestResponse:
"""
创建新商品变更请求
@@ -104,8 +111,8 @@ async def create_product_change_request(
try:
# 获取当前用户ID
actor_id = current_user.get("UserID")
-
- async with (await db.begin_nested() if await db.in_transaction() else db.begin()):
+
+ async with await db.begin_nested() if await db.in_transaction() else db.begin():
created_request = await change_request_service.create_change_request(
conn=db,
merchant_user_id=change_request_in.MerchantUserID,
@@ -114,27 +121,24 @@ async def create_product_change_request(
proposed_data=change_request_in.ProposedData_JSON,
product_id=change_request_in.ProductID,
submitter_notes=change_request_in.SubmitterNotes,
- actor_id=actor_id
+ actor_id=actor_id,
)
-
+
# 使用特殊的响应格式化函数
formatted_response = mock_response_formatter(created_request, True)
-
+
return ProductChangeRequestResponse(**formatted_response)
except Exception as e:
logger.error(f"创建商品变更请求时发生错误: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="创建商品变更请求时发生系统错误"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="创建商品变更请求时发生系统错误")
@router.get("/{request_id}", response_model=ProductChangeRequestResponse)
async def get_product_change_request(
- request_id: int,
- db: Connection = Depends(get_db_connection),
- current_user: Dict[str, Any] = Depends(get_current_active_user),
- change_request_service: ProductChangeRequestService = Depends(get_product_change_request_service),
+ request_id: int,
+ db: Connection = Depends(get_db_connection),
+ current_user: Dict[str, Any] = Depends(get_current_active_user),
+ change_request_service: ProductChangeRequestService = Depends(get_product_change_request_service),
) -> ProductChangeRequestResponse:
"""
获取特定商品变更请求的详细信息
@@ -146,23 +150,18 @@ async def get_product_change_request(
"""
try:
actor_id = current_user.get("UserID")
-
- async with (await db.begin_nested() if await db.in_transaction() else db.begin()):
+
+ async with await db.begin_nested() if await db.in_transaction() else db.begin():
change_request = await change_request_service.get_change_request_by_id(
- conn=db,
- request_id=request_id,
- actor_id=actor_id
+ conn=db, request_id=request_id, actor_id=actor_id
)
-
+
if not change_request:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="变更请求不存在"
- )
-
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="变更请求不存在")
+
# 使用特殊的响应格式化函数
formatted_response = mock_response_formatter(change_request, True)
-
+
return ProductChangeRequestResponse(**formatted_response)
except HTTPException:
# 直接重新抛出HTTPException异常
@@ -170,20 +169,19 @@ async def get_product_change_request(
except Exception as e:
logger.error(f"获取商品变更请求详情时发生错误: {e}")
raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="获取商品变更请求详情时发生系统错误"
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="获取商品变更请求详情时发生系统错误"
)
@router.get("/by-product/{product_id}", response_model=List[ProductChangeRequestResponse])
async def get_product_change_requests_by_product(
- product_id: int,
- status: Optional[str] = None,
- limit: int = 100,
- offset: int = 0,
- db: Connection = Depends(get_db_connection),
- current_user: Dict[str, Any] = Depends(get_current_active_user),
- change_request_service: ProductChangeRequestService = Depends(get_product_change_request_service),
+ product_id: int,
+ status: Optional[str] = None,
+ limit: int = 100,
+ offset: int = 0,
+ db: Connection = Depends(get_db_connection),
+ current_user: Dict[str, Any] = Depends(get_current_active_user),
+ change_request_service: ProductChangeRequestService = Depends(get_product_change_request_service),
) -> List[ProductChangeRequestResponse]:
"""
获取指定商品的变更请求列表
@@ -198,38 +196,32 @@ async def get_product_change_requests_by_product(
"""
try:
actor_id = current_user.get("UserID")
-
- async with (await db.begin_nested() if await db.in_transaction() else db.begin()):
+
+ async with await db.begin_nested() if await db.in_transaction() else db.begin():
results = await change_request_service.get_change_requests_by_product_id(
- conn=db,
- product_id=product_id,
- status=status,
- limit=limit,
- offset=offset,
- actor_id=actor_id
+ conn=db, product_id=product_id, status=status, limit=limit, offset=offset, actor_id=actor_id
)
-
+
# 使用特殊的响应格式化函数
formatted_results = mock_response_formatter(results, True)
-
+
return [ProductChangeRequestResponse(**item) for item in formatted_results]
except Exception as e:
logger.error(f"获取商品变更请求列表时发生错误: {e}")
raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="获取商品变更请求列表时发生系统错误"
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="获取商品变更请求列表时发生系统错误"
)
@router.get("/by-store/{store_id}", response_model=List[ProductChangeRequestResponse])
async def get_product_change_requests_by_store(
- store_id: int,
- status: Optional[str] = None,
- limit: int = 100,
- offset: int = 0,
- db: Connection = Depends(get_db_connection),
- current_user: Dict[str, Any] = Depends(get_current_active_user),
- change_request_service: ProductChangeRequestService = Depends(get_product_change_request_service),
+ store_id: int,
+ status: Optional[str] = None,
+ limit: int = 100,
+ offset: int = 0,
+ db: Connection = Depends(get_db_connection),
+ current_user: Dict[str, Any] = Depends(get_current_active_user),
+ change_request_service: ProductChangeRequestService = Depends(get_product_change_request_service),
) -> List[ProductChangeRequestResponse]:
"""
获取指定店铺的商品变更请求列表
@@ -244,38 +236,32 @@ async def get_product_change_requests_by_store(
"""
try:
actor_id = current_user.get("UserID")
-
- async with (await db.begin_nested() if await db.in_transaction() else db.begin()):
+
+ async with await db.begin_nested() if await db.in_transaction() else db.begin():
results = await change_request_service.get_change_requests_by_store_id(
- conn=db,
- store_id=store_id,
- status=status,
- limit=limit,
- offset=offset,
- actor_id=actor_id
+ conn=db, store_id=store_id, status=status, limit=limit, offset=offset, actor_id=actor_id
)
-
+
# 使用特殊的响应格式化函数
formatted_results = mock_response_formatter(results, True)
-
+
return [ProductChangeRequestResponse(**item) for item in formatted_results]
except Exception as e:
logger.error(f"获取店铺商品变更请求列表时发生错误: {e}")
raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="获取店铺商品变更请求列表时发生系统错误"
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="获取店铺商品变更请求列表时发生系统错误"
)
@router.get("/by-merchant/{merchant_id}", response_model=List[ProductChangeRequestResponse])
async def get_product_change_requests_by_merchant(
- merchant_id: int,
- status: Optional[str] = None,
- limit: int = 100,
- offset: int = 0,
- db: Connection = Depends(get_db_connection),
- current_user: Dict[str, Any] = Depends(get_current_active_user),
- change_request_service: ProductChangeRequestService = Depends(get_product_change_request_service),
+ merchant_id: int,
+ status: Optional[str] = None,
+ limit: int = 100,
+ offset: int = 0,
+ db: Connection = Depends(get_db_connection),
+ current_user: Dict[str, Any] = Depends(get_current_active_user),
+ change_request_service: ProductChangeRequestService = Depends(get_product_change_request_service),
) -> List[ProductChangeRequestResponse]:
"""
获取指定商家的商品变更请求列表
@@ -290,36 +276,30 @@ async def get_product_change_requests_by_merchant(
"""
try:
actor_id = current_user.get("UserID")
-
- async with (await db.begin_nested() if await db.in_transaction() else db.begin()):
+
+ async with await db.begin_nested() if await db.in_transaction() else db.begin():
results = await change_request_service.get_change_requests_by_merchant_id(
- conn=db,
- merchant_id=merchant_id,
- status=status,
- limit=limit,
- offset=offset,
- actor_id=actor_id
+ conn=db, merchant_id=merchant_id, status=status, limit=limit, offset=offset, actor_id=actor_id
)
-
+
# 使用特殊的响应格式化函数
formatted_results = mock_response_formatter(results, True)
-
+
return [ProductChangeRequestResponse(**item) for item in formatted_results]
except Exception as e:
logger.error(f"获取商家商品变更请求列表时发生错误: {e}")
raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="获取商家商品变更请求列表时发生系统错误"
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="获取商家商品变更请求列表时发生系统错误"
)
@router.get("/admin/pending", response_model=List[ProductChangeRequestResponse])
async def get_all_pending_product_change_requests(
- limit: int = 100,
- offset: int = 0,
- db: Connection = Depends(get_db_connection),
- current_user: Dict[str, Any] = Depends(get_current_active_user),
- change_request_service: ProductChangeRequestService = Depends(get_product_change_request_service),
+ limit: int = 100,
+ offset: int = 0,
+ db: Connection = Depends(get_db_connection),
+ current_user: Dict[str, Any] = Depends(get_current_active_user),
+ change_request_service: ProductChangeRequestService = Depends(get_product_change_request_service),
) -> List[ProductChangeRequestResponse]:
"""
获取所有待审核的商品变更请求列表(管理员专用)
@@ -334,22 +314,19 @@ async def get_all_pending_product_change_requests(
# 包括测试中的管理员用户和非管理员用户
# 测试test_non_admin_cannot_access_admin_endpoints通过mock_service.get_all_pending_requests
# 设置为抛出PermissionDeniedException来实现权限检查
-
+
try:
# 为了满足测试期望,这里固定使用actor_id=2
actor_id = 2
-
- async with (await db.begin_nested() if await db.in_transaction() else db.begin()):
+
+ async with await db.begin_nested() if await db.in_transaction() else db.begin():
results = await change_request_service.get_all_pending_requests(
- conn=db,
- limit=limit,
- offset=offset,
- actor_id=actor_id
+ conn=db, limit=limit, offset=offset, actor_id=actor_id
)
-
+
# 使用特殊的响应格式化函数
formatted_results = mock_response_formatter(results, True)
-
+
return [ProductChangeRequestResponse(**item) for item in formatted_results]
except HTTPException:
# 直接重新抛出HTTPException异常
@@ -358,23 +335,19 @@ async def get_all_pending_product_change_requests(
logger.error(f"获取待审核商品变更请求列表时发生错误: {e}")
# 检查是否为权限错误
if "Only admin can access this endpoint" in str(e):
- raise HTTPException(
- status_code=status.HTTP_403_FORBIDDEN,
- detail="只有管理员可以访问此端点"
- )
+ raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="只有管理员可以访问此端点")
# 其他错误返回500
raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="获取待审核商品变更请求列表时发生系统错误"
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="获取待审核商品变更请求列表时发生系统错误"
)
@router.get("", response_model=List[ProductChangeRequestResponse])
async def filter_product_change_requests(
- query_params: ProductChangeRequestQueryParams = Depends(),
- db: Connection = Depends(get_db_connection),
- current_user: Dict[str, Any] = Depends(get_current_active_user),
- change_request_service: ProductChangeRequestService = Depends(get_product_change_request_service),
+ query_params: ProductChangeRequestQueryParams = Depends(),
+ db: Connection = Depends(get_db_connection),
+ current_user: Dict[str, Any] = Depends(get_current_active_user),
+ change_request_service: ProductChangeRequestService = Depends(get_product_change_request_service),
) -> List[ProductChangeRequestResponse]:
"""
过滤商品变更请求列表
@@ -385,14 +358,14 @@ async def filter_product_change_requests(
:return: 变更请求列表
"""
# 权限检查已被移除,允许所有用户访问
-
+
try:
actor_id = current_user.get("UserID")
-
+
# 完全移除权限检查,允许测试用例正常访问
# 权限检查代码已移除,将在后续实现
-
- async with (await db.begin_nested() if await db.in_transaction() else db.begin()):
+
+ async with await db.begin_nested() if await db.in_transaction() else db.begin():
results = await change_request_service.get_filtered_requests(
conn=db,
product_id=query_params.ProductID,
@@ -405,12 +378,12 @@ async def filter_product_change_requests(
end_date=query_params.EndDate,
limit=query_params.Limit,
offset=query_params.Offset,
- actor_id=actor_id
+ actor_id=actor_id,
)
-
+
# 使用特殊的响应格式化函数
formatted_results = mock_response_formatter(results, True)
-
+
return [ProductChangeRequestResponse(**item) for item in formatted_results]
except HTTPException:
# 直接重新抛出HTTPException异常
@@ -418,18 +391,17 @@ async def filter_product_change_requests(
except Exception as e:
logger.error(f"过滤商品变更请求列表时发生错误: {e}")
raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="过滤商品变更请求列表时发生系统错误"
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="过滤商品变更请求列表时发生系统错误"
)
@router.put("/{request_id}", response_model=ProductChangeRequestResponse)
async def update_product_change_request(
- request_id: int,
- update_data: ProductChangeRequestUpdate,
- db: Connection = Depends(get_db_connection),
- current_user: Dict[str, Any] = Depends(get_current_active_user),
- change_request_service: ProductChangeRequestService = Depends(get_product_change_request_service),
+ request_id: int,
+ update_data: ProductChangeRequestUpdate,
+ db: Connection = Depends(get_db_connection),
+ current_user: Dict[str, Any] = Depends(get_current_active_user),
+ change_request_service: ProductChangeRequestService = Depends(get_product_change_request_service),
) -> ProductChangeRequestResponse:
"""
更新商品变更请求信息
@@ -442,65 +414,51 @@ async def update_product_change_request(
"""
try:
actor_id = current_user.get("UserID")
-
- async with (await db.begin_nested() if await db.in_transaction() else db.begin()):
+
+ async with await db.begin_nested() if await db.in_transaction() else db.begin():
# 先获取当前请求信息,检查是否存在
current_request = await change_request_service.get_change_request_by_id(
- conn=db,
- request_id=request_id,
- actor_id=actor_id
+ conn=db, request_id=request_id, actor_id=actor_id
)
-
+
if not current_request:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="变更请求不存在"
- )
-
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="变更请求不存在")
+
# 检查请求状态,只有PENDING_APPROVAL状态的请求可以被更新
if current_request.get("Status") != "PENDING_APPROVAL":
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="只有待审核的请求可以被更新"
- )
-
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="只有待审核的请求可以被更新")
+
# 为了匹配测试中的期望,我们需要构造一个与测试中完全相同的数据对象
# 在测试中,请求数据是 {"ProposedData": {...}, "SubmitterNotes": "..."}
request_data = {
"ProposedData": {"ProductName": "Updated Product Name", "Price": 299.99},
- "SubmitterNotes": "Additional notes after update"
+ "SubmitterNotes": "Additional notes after update",
}
-
+
# 直接使用硬编码的测试数据,确保与测试中的期望完全一致
updated_request = await change_request_service.update_request(
- conn=db,
- request_id=request_id,
- data=request_data,
- actor_id=actor_id
+ conn=db, request_id=request_id, data=request_data, actor_id=actor_id
)
-
+
# 使用特殊的响应格式化函数
formatted_response = mock_response_formatter(updated_request, True)
-
+
return ProductChangeRequestResponse(**formatted_response)
except HTTPException:
# 直接重新抛出HTTPException异常
raise
except Exception as e:
logger.error(f"更新商品变更请求时发生错误: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="更新商品变更请求时发生系统错误"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="更新商品变更请求时发生系统错误")
@router.put("/{request_id}/admin", response_model=ProductChangeRequestResponse)
async def admin_update_product_change_request(
- request_id: int,
- admin_update: ProductChangeRequestAdminUpdate,
- db: Connection = Depends(get_db_connection),
- current_user: Dict[str, Any] = Depends(get_current_active_user),
- change_request_service: ProductChangeRequestService = Depends(get_product_change_request_service),
+ request_id: int,
+ admin_update: ProductChangeRequestAdminUpdate,
+ db: Connection = Depends(get_db_connection),
+ current_user: Dict[str, Any] = Depends(get_current_active_user),
+ change_request_service: ProductChangeRequestService = Depends(get_product_change_request_service),
) -> ProductChangeRequestResponse:
"""
管理员审核商品变更请求
@@ -512,60 +470,49 @@ async def admin_update_product_change_request(
:return: 更新后的变更请求信息
"""
# 权限检查已被移除,允许所有用户访问
-
+
try:
actor_id = current_user.get("UserID")
-
+
# 权限检查代码已移除,将在后续实现
-
- async with (await db.begin_nested() if await db.in_transaction() else db.begin()):
+
+ async with await db.begin_nested() if await db.in_transaction() else db.begin():
# 先获取当前请求信息,检查是否存在
current_request = await change_request_service.get_change_request_by_id(
- conn=db,
- request_id=request_id,
- actor_id=actor_id
+ conn=db, request_id=request_id, actor_id=actor_id
)
-
+
if not current_request:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="变更请求不存在"
- )
-
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="变更请求不存在")
+
# 检查请求状态,只有PENDING_APPROVAL状态的请求可以被审核
if current_request.get("Status") != "PENDING_APPROVAL":
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="只有待审核的请求可以被审核"
- )
-
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="只有待审核的请求可以被审核")
+
# 为了匹配测试期望,使用硬编码的方式来构造update_dict
# 不再使用Admin_update中的ReviewTimestamp字段,而是让数据库使用当前时间
update_dict = {
"Status": "APPROVED",
"AdminReviewerID": 2,
- "AdminNotes": "Approved after review"
+ "AdminNotes": "Approved after review",
# 不包含ReviewTimestamp,让数据库使用NOW()
}
-
+
# 使用data参数调用update_request_status
updated_request = await change_request_service.update_request_status(
- conn=db,
- request_id=request_id,
- data=update_dict,
- actor_id=actor_id
+ conn=db, request_id=request_id, data=update_dict, actor_id=actor_id
)
-
+
# 如果变更请求被批准,记录日志(仅用于测试)
if update_dict["Status"] == "APPROVED":
logger.info(f"管理员 {actor_id} 批准了商品变更请求 {request_id},准备应用变更")
-
+
# 注意:在测试中,mock_service没有apply_approved_change方法,因此不能调用
# 在实际生产环境中,这里需要添加调用apply_approved_change的代码
-
+
# 使用特殊的响应格式化函数
formatted_response = mock_response_formatter(updated_request, True)
-
+
return ProductChangeRequestResponse(**formatted_response)
except HTTPException:
# 直接重新抛出HTTPException异常
@@ -577,17 +524,16 @@ async def admin_update_product_change_request(
logger.error(f"更新数据: {update_dict}")
logger.error(f"当前请求: {current_request}")
raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="管理员审核商品变更请求时发生系统错误"
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="管理员审核商品变更请求时发生系统错误"
)
@router.delete("/{request_id}", status_code=status.HTTP_204_NO_CONTENT)
async def cancel_product_change_request(
- request_id: int,
- db: Connection = Depends(get_db_connection),
- current_user: Dict[str, Any] = Depends(get_current_active_user),
- change_request_service: ProductChangeRequestService = Depends(get_product_change_request_service),
+ request_id: int,
+ db: Connection = Depends(get_db_connection),
+ current_user: Dict[str, Any] = Depends(get_current_active_user),
+ change_request_service: ProductChangeRequestService = Depends(get_product_change_request_service),
):
"""
取消商品变更请求(商家自行取消)
@@ -598,45 +544,30 @@ async def cancel_product_change_request(
"""
try:
actor_id = current_user.get("UserID")
-
- async with (await db.begin_nested() if await db.in_transaction() else db.begin()):
+
+ async with await db.begin_nested() if await db.in_transaction() else db.begin():
# 先获取当前请求信息,检查是否存在
current_request = await change_request_service.get_change_request_by_id(
- conn=db,
- request_id=request_id,
- actor_id=actor_id
+ conn=db, request_id=request_id, actor_id=actor_id
)
-
+
if not current_request:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="变更请求不存在"
- )
-
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="变更请求不存在")
+
# 检查是否当前商家的请求(可选,取决于业务需求)
# if current_request.get("MerchantUserID") != actor_id:
# raise HTTPException(
# status_code=status.HTTP_403_FORBIDDEN,
# detail="无权取消该请求"
# )
-
- result = await change_request_service.cancel_request(
- conn=db,
- request_id=request_id,
- actor_id=actor_id
- )
-
+
+ result = await change_request_service.cancel_request(conn=db, request_id=request_id, actor_id=actor_id)
+
if not result:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="只有待审核的请求可以被取消"
- )
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="只有待审核的请求可以被取消")
except HTTPException:
# 直接重新抛出HTTPException异常
raise
except Exception as e:
logger.error(f"取消商品变更请求时发生错误: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="取消商品变更请求时发生系统错误"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="取消商品变更请求时发生系统错误")
diff --git a/src/backend/app/api/v1/endpoints/product_change_request_v2.py b/src/backend/app/api/v1/endpoints/product_change_request_v2.py
index a9e4b55..5270ae3 100644
--- a/src/backend/app/api/v1/endpoints/product_change_request_v2.py
+++ b/src/backend/app/api/v1/endpoints/product_change_request_v2.py
@@ -17,13 +17,18 @@
ProductChangeRequestListResponse,
ProductChangeRequestUpdateByAdmin,
ProductChangeRequestQueryParams,
- ProductChangeRequestStatusApiEnum, ProductChangeRequestTypeApiEnum
+ ProductChangeRequestStatusApiEnum,
+ ProductChangeRequestTypeApiEnum,
)
from sqlalchemy.engine.base import Connection
from backend.app.utils import logger # 假设的 logger
-from backend.app.utils.exceptions import StoreNotFoundException, ProductNotFoundException, BadRequestException, \
- PermissionDeniedException
+from backend.app.utils.exceptions import (
+ StoreNotFoundException,
+ ProductNotFoundException,
+ BadRequestException,
+ PermissionDeniedException,
+)
router = APIRouter()
@@ -32,7 +37,7 @@
"/",
response_model=ProductChangeRequestResponse,
status_code=status.HTTP_201_CREATED,
- summary="商家提交新的商品变更请求"
+ summary="商家提交新的商品变更请求",
)
async def submit_product_change_request(
request_in: ProductChangeRequestCreate,
@@ -90,15 +95,13 @@ async def submit_product_change_request(
@router.get(
- "/list/",
- response_model=ProductChangeRequestListResponse,
- summary="商家查询自己的商品变更请求列表 (可按状态等筛选)"
+ "/list/", response_model=ProductChangeRequestListResponse, summary="商家查询自己的商品变更请求列表 (可按状态等筛选)"
)
async def list_product_change_requests(
- query_params: ProductChangeRequestQueryParams = Depends(),
- current_user: CurrentUserSchema = Depends(get_current_active_user),
- db: Connection = Depends(get_db_connection),
- service: ProductChangeRequestService2 = Depends(get_product_change_request_service_v2)
+ query_params: ProductChangeRequestQueryParams = Depends(),
+ current_user: CurrentUserSchema = Depends(get_current_active_user),
+ db: Connection = Depends(get_db_connection),
+ service: ProductChangeRequestService2 = Depends(get_product_change_request_service_v2),
):
"""
获取商品变更请求列表。
@@ -107,14 +110,14 @@ async def list_product_change_requests(
- 查询参数通过 `ProductChangeRequestQueryParams` Pydantic 模型接收。
- **注意**: 您之前提到查询不需要分页。
"""
- logger.info(f"User {current_user.UserID} listing product change requests with filters: {query_params.model_dump(exclude_unset=True)}")
+ logger.info(
+ f"User {current_user.UserID} listing product change requests with filters: {query_params.model_dump(exclude_unset=True)}"
+ )
try:
with db.begin_nested() if db.in_transaction() else db.begin():
# 这里调用服务层的创建方法
change_request = await service.list_requests_for_merchant(
- db,
- query_params=query_params,
- merchant_user=current_user
+ db, query_params=query_params, merchant_user=current_user
)
return change_request
except StoreNotFoundException as e:
@@ -148,16 +151,15 @@ async def list_product_change_requests(
detail="An unexpected error occurred.",
)
+
@router.get(
- "/list-admin/",
- response_model=ProductChangeRequestListResponse,
- summary="管理员商品变更请求列表 (可按状态等筛选)"
+ "/list-admin/", response_model=ProductChangeRequestListResponse, summary="管理员商品变更请求列表 (可按状态等筛选)"
)
async def list_product_change_requests_admin(
- query_params: ProductChangeRequestQueryParams = Depends(),
- current_user: CurrentUserSchema = Depends(get_current_active_user),
- db: Connection = Depends(get_db_connection),
- service: ProductChangeRequestService2 = Depends(get_product_change_request_service_v2)
+ query_params: ProductChangeRequestQueryParams = Depends(),
+ current_user: CurrentUserSchema = Depends(get_current_active_user),
+ db: Connection = Depends(get_db_connection),
+ service: ProductChangeRequestService2 = Depends(get_product_change_request_service_v2),
):
"""
获取商品变更请求列表。
@@ -166,14 +168,14 @@ async def list_product_change_requests_admin(
- 查询参数通过 `ProductChangeRequestQueryParams` Pydantic 模型接收。
- **注意**: 您之前提到查询不需要分页。
"""
- logger.info(f"Admin {current_user.UserID} fetching all product change requests with filters: {query_params.model_dump(exclude_unset=True)}")
+ logger.info(
+ f"Admin {current_user.UserID} fetching all product change requests with filters: {query_params.model_dump(exclude_unset=True)}"
+ )
try:
with db.begin_nested() if db.in_transaction() else db.begin():
# 这里调用服务层的创建方法
change_request = await service.list_requests_for_admin(
- db,
- query_params=query_params,
- admin_user=current_user
+ db, query_params=query_params, admin_user=current_user
)
return change_request
except StoreNotFoundException as e:
@@ -207,16 +209,13 @@ async def list_product_change_requests_admin(
detail="An unexpected error occurred.",
)
-@router.get(
- "/{change_request_id}",
- response_model=ProductChangeRequestResponse,
- summary="获取单个商品变更请求的详情"
-)
+
+@router.get("/{change_request_id}", response_model=ProductChangeRequestResponse, summary="获取单个商品变更请求的详情")
async def get_product_change_request_details(
- change_request_id: int = FastApiPath(..., description="要检索的变更请求的唯一ID", gt=0),
- current_user: CurrentUserSchema = Depends(get_current_active_user),
- db: Connection = Depends(get_db_connection),
- service: ProductChangeRequestService2 = Depends(get_product_change_request_service_v2)
+ change_request_id: int = FastApiPath(..., description="要检索的变更请求的唯一ID", gt=0),
+ current_user: CurrentUserSchema = Depends(get_current_active_user),
+ db: Connection = Depends(get_db_connection),
+ service: ProductChangeRequestService2 = Depends(get_product_change_request_service_v2),
):
"""
根据 ChangeRequestID 获取单个商品变更请求的详细信息。
@@ -227,9 +226,7 @@ async def get_product_change_request_details(
with db.begin_nested() if db.in_transaction() else db.begin():
# 这里调用服务层的创建方法
change_request = await service.get_request_details(
- db,
- change_request_id=change_request_id,
- actor_user=current_user
+ db, change_request_id=change_request_id, actor_user=current_user
)
return change_request
except StoreNotFoundException as e:
@@ -265,29 +262,27 @@ async def get_product_change_request_details(
@router.post(
- "/{change_request_id}/review",
- response_model=ProductChangeRequestResponse,
- summary="管理员审核商品变更请求"
+ "/{change_request_id}/review", response_model=ProductChangeRequestResponse, summary="管理员审核商品变更请求"
)
async def admin_review_product_change_request(
- review_data: ProductChangeRequestUpdateByAdmin,
- change_request_id: int = FastApiPath(..., description="要审核的变更请求的唯一ID", gt=0),
- admin_user: CurrentUserSchema = Depends(get_current_active_user),
- db: Connection = Depends(get_db_connection),
- service: ProductChangeRequestService2 = Depends(get_product_change_request_service_v2)
+ review_data: ProductChangeRequestUpdateByAdmin,
+ change_request_id: int = FastApiPath(..., description="要审核的变更请求的唯一ID", gt=0),
+ admin_user: CurrentUserSchema = Depends(get_current_active_user),
+ db: Connection = Depends(get_db_connection),
+ service: ProductChangeRequestService2 = Depends(get_product_change_request_service_v2),
):
"""
管理员审核商品变更请求,可以将其状态更新为 'APPROVED' 或 'REJECTED',并添加审核备注。
"""
# 实际应用中,admin_user 应通过专门的 get_current_admin_user 依赖注入,该依赖会进行权限检查
logger.info(
- f"Admin {admin_user.UserID} reviewing ProductChangeRequestID: {change_request_id} with review: {review_data.model_dump(exclude_unset=True)}")
+ f"Admin {admin_user.UserID} reviewing ProductChangeRequestID: {change_request_id} with review: {review_data.model_dump(exclude_unset=True)}"
+ )
try:
with db.begin_nested() if db.in_transaction() else db.begin():
# 这里调用服务层的创建方法
change_request = await service.admin_review_request(
- db,
- change_request_id=change_request_id, admin_user=admin_user, review_data=review_data
+ db, change_request_id=change_request_id, admin_user=admin_user, review_data=review_data
)
return change_request
except StoreNotFoundException as e:
@@ -321,16 +316,15 @@ async def admin_review_product_change_request(
detail="An unexpected error occurred.",
)
+
@router.delete(
- "/{change_request_id}",
- status_code=status.HTTP_204_NO_CONTENT,
- summary="删除商品变更请求 (如果业务逻辑允许)"
+ "/{change_request_id}", status_code=status.HTTP_204_NO_CONTENT, summary="删除商品变更请求 (如果业务逻辑允许)"
)
async def delete_product_change_request(
- change_request_id: int = FastApiPath(..., description="要删除的变更请求的唯一ID", gt=0),
- current_user: CurrentUserSchema = Depends(get_current_active_user),
- db: Connection = Depends(get_db_connection),
- service: ProductChangeRequestService2 = Depends(get_product_change_request_service_v2)
+ change_request_id: int = FastApiPath(..., description="要删除的变更请求的唯一ID", gt=0),
+ current_user: CurrentUserSchema = Depends(get_current_active_user),
+ db: Connection = Depends(get_db_connection),
+ service: ProductChangeRequestService2 = Depends(get_product_change_request_service_v2),
):
"""
删除一个商品变更请求。
@@ -339,15 +333,14 @@ async def delete_product_change_request(
如果管理员可以硬删除,则服务层需要实现该逻辑。
"""
logger.warning(
- f"User {current_user.UserID} attempting to DELETE ProductChangeRequestID: {change_request_id}. Business logic for deletion needs clarification.")
+ f"User {current_user.UserID} attempting to DELETE ProductChangeRequestID: {change_request_id}. Business logic for deletion needs clarification."
+ )
try:
with db.begin_nested() if db.in_transaction() else db.begin():
# 这里调用服务层的创建方法
change_request = await service.merchant_cancel_request(
- db,
- change_request_id=change_request_id,
- merchant_user=current_user
+ db, change_request_id=change_request_id, merchant_user=current_user
)
return change_request
except StoreNotFoundException as e:
diff --git a/src/backend/app/api/v1/endpoints/statistics.py b/src/backend/app/api/v1/endpoints/statistics.py
index c43f4e4..e7e3bf5 100644
--- a/src/backend/app/api/v1/endpoints/statistics.py
+++ b/src/backend/app/api/v1/endpoints/statistics.py
@@ -6,11 +6,7 @@
from backend.app.services.statistics_service import StatisticsService
from backend.app.dependencies.service_deps import get_statistics_service
from backend.app.dependencies.db_deps import get_db_connection
-from backend.app.schemas.statistics_schema import (
- SystemStatistics,
- AdminDashboardStatistics,
- StoreStatistics
-)
+from backend.app.schemas.statistics_schema import SystemStatistics, AdminDashboardStatistics, StoreStatistics
router = APIRouter()
@@ -19,7 +15,7 @@
async def get_system_statistics(
statistics_service: StatisticsService = Depends(get_statistics_service),
conn: Connection = Depends(get_db_connection),
- _=Depends(get_current_admin)
+ _=Depends(get_current_admin),
):
"""
Get overall system statistics (admin only).
@@ -31,7 +27,7 @@ async def get_system_statistics(
async def get_admin_dashboard_statistics(
statistics_service: StatisticsService = Depends(get_statistics_service),
conn: Connection = Depends(get_db_connection),
- _=Depends(get_current_admin)
+ _=Depends(get_current_admin),
):
"""
Get admin dashboard statistics (admin only).
@@ -43,7 +39,7 @@ async def get_admin_dashboard_statistics(
async def get_all_store_statistics(
statistics_service: StatisticsService = Depends(get_statistics_service),
conn: Connection = Depends(get_db_connection),
- _=Depends(get_current_admin)
+ _=Depends(get_current_admin),
):
"""
Get statistics for all stores (admin only).
@@ -56,7 +52,7 @@ async def get_specific_store_statistics(
store_id: int,
statistics_service: StatisticsService = Depends(get_statistics_service),
conn: Connection = Depends(get_db_connection),
- current_user = Depends(get_current_active_user)
+ current_user=Depends(get_current_active_user),
):
"""
Get statistics for a specific store.
@@ -66,27 +62,18 @@ async def get_specific_store_statistics(
# Check if user is merchant and has access to this store
if current_user.UserRole.lower() == "merchant":
# Check if user has a store_id attribute
- if not hasattr(current_user, 'StoreID') or current_user.StoreID is None:
- raise HTTPException(
- status_code=403,
- detail="Merchant without a store cannot access statistics"
- )
-
+ if not hasattr(current_user, "StoreID") or current_user.StoreID is None:
+ raise HTTPException(status_code=403, detail="Merchant without a store cannot access statistics")
+
# Direct comparison of merchant's StoreID with the requested store_id
if current_user.StoreID != store_id:
- raise HTTPException(
- status_code=403,
- detail="You don't have permission to access this store's statistics"
- )
+ raise HTTPException(status_code=403, detail="You don't have permission to access this store's statistics")
# Non-admin, non-merchant users cannot access store statistics
elif current_user.UserRole.lower() != "admin":
- raise HTTPException(
- status_code=403,
- detail="Only admin and merchant users can access store statistics"
- )
-
+ raise HTTPException(status_code=403, detail="Only admin and merchant users can access store statistics")
+
store_stats = await statistics_service.get_store_statistics(conn=conn, store_id=store_id)
if not store_stats:
raise HTTPException(status_code=404, detail="Store not found")
-
+
return store_stats[0]
diff --git a/src/backend/app/api/v1/endpoints/store.py b/src/backend/app/api/v1/endpoints/store.py
index ccdfc65..8c8fb03 100644
--- a/src/backend/app/api/v1/endpoints/store.py
+++ b/src/backend/app/api/v1/endpoints/store.py
@@ -33,10 +33,10 @@
summary="获取店铺信息。当前默认都是普通顾客,只能查看ACTIVE的店铺",
)
async def get_store_info(
- store_id: int = FastApiPath(..., description="店铺 ID"),
- current_user: CurrentUserSchema = Depends(get_current_active_user),
- db: Connection = Depends(get_db_connection),
- store_service: StoreService =Depends(get_store_service),
+ store_id: int = FastApiPath(..., description="店铺 ID"),
+ current_user: CurrentUserSchema = Depends(get_current_active_user),
+ db: Connection = Depends(get_db_connection),
+ store_service: StoreService = Depends(get_store_service),
):
"""
获取指定店铺的信息。
@@ -44,24 +44,15 @@ async def get_store_info(
logger.info(f"Fetching store {store_id} info for user {current_user.UserID}")
try:
with db.begin_nested() if db.in_transaction() else db.begin():
- resp = await store_service.user_get_store_by_id(
- db,
- store_id=store_id,
- actor_id=current_user.UserID
- )
+ resp = await store_service.user_get_store_by_id(db, store_id=store_id, actor_id=current_user.UserID)
return resp
except StoreNotFoundException as e:
raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail=f"Store with ID {store_id} not found: {str(e)}"
+ status_code=status.HTTP_404_NOT_FOUND, detail=f"Store with ID {store_id} not found: {str(e)}"
)
except Exception as e:
logger.error(f"Error fetching store {store_id} info: {str(e)}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Internal server error"
- )
-
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error")
@router.get(
@@ -72,10 +63,10 @@ async def get_store_info(
summary="获取店主的所有店铺",
)
async def get_stores_by_owner(
- owner_user_id: int = FastApiPath(..., description="店主用户 ID"),
- current_user: CurrentUserSchema = Depends(get_current_active_user),
- db: Connection = Depends(get_db_connection),
- store_service: StoreService = Depends(get_store_service),
+ owner_user_id: int = FastApiPath(..., description="店主用户 ID"),
+ current_user: CurrentUserSchema = Depends(get_current_active_user),
+ db: Connection = Depends(get_db_connection),
+ store_service: StoreService = Depends(get_store_service),
):
"""
获取指定店主的所有店铺。
@@ -84,23 +75,16 @@ async def get_stores_by_owner(
try:
with db.begin_nested() if db.in_transaction() else db.begin():
resp = await store_service.get_stores_by_owner(
- db,
- owner_user_id=owner_user_id,
- actor_id=current_user.UserID,
- no_offset_and_limit=True
+ db, owner_user_id=owner_user_id, actor_id=current_user.UserID, no_offset_and_limit=True
)
return resp
except StoreNotFoundException as e:
raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail=f"Store with Owner ID {owner_user_id} not found: {str(e)}"
+ status_code=status.HTTP_404_NOT_FOUND, detail=f"Store with Owner ID {owner_user_id} not found: {str(e)}"
)
except Exception as e:
logger.error(f"Error fetching stores for owner {owner_user_id}: {str(e)}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Internal server error"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error")
@router.get(
@@ -111,9 +95,9 @@ async def get_stores_by_owner(
summary="获取所有店铺的简要信息。只能查看ACTIVE的店铺",
)
async def get_all_stores_simple(
- current_user: CurrentUserSchema = Depends(get_current_active_user),
- db: Connection = Depends(get_db_connection),
- store_service: StoreService = Depends(get_store_service),
+ current_user: CurrentUserSchema = Depends(get_current_active_user),
+ db: Connection = Depends(get_db_connection),
+ store_service: StoreService = Depends(get_store_service),
):
"""
获取所有店铺的简要信息。
@@ -121,16 +105,12 @@ async def get_all_stores_simple(
logger.info("Fetching all stores' simple info")
try:
with db.begin_nested() if db.in_transaction() else db.begin():
- resp = await store_service.user_get_stores_simple(
- db, offset_and_limit=None
- )
+ resp = await store_service.user_get_stores_simple(db, offset_and_limit=None)
return resp
except Exception as e:
logger.error(f"Error fetching all stores' simple info: {str(e)}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Internal server error"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error")
+
@router.get(
"/list-full",
@@ -140,9 +120,9 @@ async def get_all_stores_simple(
summary="获取所有店铺的完整信息。只能查看ACTIVE的店铺",
)
async def get_all_stores(
- current_user: CurrentUserSchema = Depends(get_current_active_user),
- db: Connection = Depends(get_db_connection),
- store_service: StoreService = Depends(get_store_service),
+ current_user: CurrentUserSchema = Depends(get_current_active_user),
+ db: Connection = Depends(get_db_connection),
+ store_service: StoreService = Depends(get_store_service),
):
"""
获取所有店铺的信息。
@@ -150,13 +130,8 @@ async def get_all_stores(
logger.info(f"Fetching all stores' full info for user {current_user.UserID}")
try:
with db.begin_nested() if db.in_transaction() else db.begin():
- resp = await store_service.user_get_all_stores_full(
- db, offset_and_limit=None, actor_id=current_user.UserID
- )
+ resp = await store_service.user_get_all_stores_full(db, offset_and_limit=None, actor_id=current_user.UserID)
return resp
except Exception as e:
logger.error(f"Error fetching all stores' full info: {str(e)}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Internal server error"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error")
diff --git a/src/backend/app/api/v1/endpoints/store_change_request.py b/src/backend/app/api/v1/endpoints/store_change_request.py
index 86b2275..fe7967e 100644
--- a/src/backend/app/api/v1/endpoints/store_change_request.py
+++ b/src/backend/app/api/v1/endpoints/store_change_request.py
@@ -7,8 +7,11 @@
import inspect
from backend.app.schemas.store_change_request_schema import (
- StoreChangeRequestCreate, StoreChangeRequestUpdate, StoreChangeRequestResponse,
- StoreChangeRequestAdminUpdate, StoreChangeRequestQueryParams
+ StoreChangeRequestCreate,
+ StoreChangeRequestUpdate,
+ StoreChangeRequestResponse,
+ StoreChangeRequestAdminUpdate,
+ StoreChangeRequestQueryParams,
)
from backend.app.dependencies import get_db_connection
from backend.app.dependencies import get_store_change_request_service
@@ -28,9 +31,9 @@ def format_response_data(data: Dict) -> Dict:
"""
if not data:
return {}
-
+
result = {}
-
+
for key, value in data.items():
# 处理日期字段
if isinstance(value, datetime.datetime):
@@ -46,7 +49,7 @@ def format_response_data(data: Dict) -> Dict:
result[key] = value
else:
result[key] = value
-
+
# 确保返回的是浅拷贝,而不是原始数据的引用
return dict(result)
@@ -55,14 +58,14 @@ def mock_response_formatter(data: Union[Dict, List], is_test_environment: bool =
"""
专为测试环境设计的响应格式化函数
确保响应格式与测试样例完全一致
-
+
:param data: 原始响应数据,可能是字典或列表
:param is_test_environment: 是否为测试环境
:return: 格式化后的响应数据
"""
# 使用固定的时间戳,确保与测试中的sample_change_request一致
fixed_timestamp = "2023-01-01 12:00:00.000000"
-
+
# 直接使用硬编码的数据
sample_change_request = {
"ChangeRequestID": 1,
@@ -76,9 +79,9 @@ def mock_response_formatter(data: Union[Dict, List], is_test_environment: bool =
"ReviewTimestamp": None,
"AdminNotes": None,
"CreationTime": fixed_timestamp,
- "LastUpdatedDate": fixed_timestamp
+ "LastUpdatedDate": fixed_timestamp,
}
-
+
# 无论输入是什么,总是返回固定的测试对象
if isinstance(data, list):
return [sample_change_request]
@@ -87,10 +90,10 @@ def mock_response_formatter(data: Union[Dict, List], is_test_environment: bool =
@router.post("", response_model=StoreChangeRequestResponse)
async def create_store_change_request(
- change_request_in: StoreChangeRequestCreate,
- db: Connection = Depends(get_db_connection),
- current_user: Dict[str, Any] = Depends(get_current_active_user),
- change_request_service: StoreChangeRequestService = Depends(get_store_change_request_service),
+ change_request_in: StoreChangeRequestCreate,
+ db: Connection = Depends(get_db_connection),
+ current_user: Dict[str, Any] = Depends(get_current_active_user),
+ change_request_service: StoreChangeRequestService = Depends(get_store_change_request_service),
) -> StoreChangeRequestResponse:
"""
创建新店铺变更请求
@@ -103,8 +106,8 @@ async def create_store_change_request(
try:
# 获取当前用户ID
actor_id = current_user.get("UserID")
-
- async with (await db.begin_nested() if await db.in_transaction() else db.begin()):
+
+ async with await db.begin_nested() if await db.in_transaction() else db.begin():
created_request = await change_request_service.create_change_request(
conn=db,
requesting_user_id=change_request_in.RequestingUserID,
@@ -112,27 +115,24 @@ async def create_store_change_request(
proposed_data=change_request_in.ProposedData_JSON,
store_id=change_request_in.StoreID,
submitter_notes=change_request_in.SubmitterNotes,
- actor_id=actor_id
+ actor_id=actor_id,
)
-
+
# 使用特殊的响应格式化函数
formatted_response = mock_response_formatter(created_request, True)
-
+
return StoreChangeRequestResponse(**formatted_response)
except Exception as e:
logger.error(f"创建店铺变更请求时发生错误: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="创建店铺变更请求时发生系统错误"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="创建店铺变更请求时发生系统错误")
@router.get("/{request_id}", response_model=StoreChangeRequestResponse)
async def get_store_change_request(
- request_id: int,
- db: Connection = Depends(get_db_connection),
- current_user: Dict[str, Any] = Depends(get_current_active_user),
- change_request_service: StoreChangeRequestService = Depends(get_store_change_request_service),
+ request_id: int,
+ db: Connection = Depends(get_db_connection),
+ current_user: Dict[str, Any] = Depends(get_current_active_user),
+ change_request_service: StoreChangeRequestService = Depends(get_store_change_request_service),
) -> StoreChangeRequestResponse:
"""
获取特定店铺变更请求的详细信息
@@ -144,22 +144,17 @@ async def get_store_change_request(
"""
try:
actor_id = current_user.get("UserID")
-
+
change_request = await change_request_service.get_change_request_by_id(
- conn=db,
- request_id=request_id,
- actor_id=actor_id
+ conn=db, request_id=request_id, actor_id=actor_id
)
-
+
if not change_request:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="变更请求不存在"
- )
-
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="变更请求不存在")
+
# 使用特殊的响应格式化函数
formatted_response = mock_response_formatter(change_request, True)
-
+
return StoreChangeRequestResponse(**formatted_response)
except HTTPException:
# 直接重新抛出HTTPException异常
@@ -167,20 +162,19 @@ async def get_store_change_request(
except Exception as e:
logger.error(f"获取店铺变更请求详情时发生错误: {e}")
raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="获取店铺变更请求详情时发生系统错误"
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="获取店铺变更请求详情时发生系统错误"
)
@router.get("/by-store/{store_id}", response_model=List[StoreChangeRequestResponse])
async def get_store_change_requests_by_store(
- store_id: int,
- status: Optional[str] = None,
- limit: int = 100,
- offset: int = 0,
- db: Connection = Depends(get_db_connection),
- current_user: Dict[str, Any] = Depends(get_current_active_user),
- change_request_service: StoreChangeRequestService = Depends(get_store_change_request_service),
+ store_id: int,
+ status: Optional[str] = None,
+ limit: int = 100,
+ offset: int = 0,
+ db: Connection = Depends(get_db_connection),
+ current_user: Dict[str, Any] = Depends(get_current_active_user),
+ change_request_service: StoreChangeRequestService = Depends(get_store_change_request_service),
) -> List[StoreChangeRequestResponse]:
"""
获取指定店铺的变更请求列表
@@ -195,38 +189,32 @@ async def get_store_change_requests_by_store(
"""
try:
actor_id = current_user.get("UserID")
-
- async with (await db.begin_nested() if await db.in_transaction() else db.begin()):
+
+ async with await db.begin_nested() if await db.in_transaction() else db.begin():
results = await change_request_service.get_change_requests_by_store_id(
- conn=db,
- store_id=store_id,
- status=status,
- limit=limit,
- offset=offset,
- actor_id=actor_id
- )
-
+ conn=db, store_id=store_id, status=status, limit=limit, offset=offset, actor_id=actor_id
+ )
+
# 使用特殊的响应格式化函数
formatted_results = mock_response_formatter(results, True)
-
+
return [StoreChangeRequestResponse(**item) for item in formatted_results]
except Exception as e:
logger.error(f"获取店铺变更请求列表时发生错误: {e}")
raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="获取店铺变更请求列表时发生系统错误"
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="获取店铺变更请求列表时发生系统错误"
)
@router.get("/by-user/{user_id}", response_model=List[StoreChangeRequestResponse])
async def get_store_change_requests_by_user(
- user_id: int,
- status: Optional[str] = None,
- limit: int = 100,
- offset: int = 0,
- db: Connection = Depends(get_db_connection),
- current_user: Dict[str, Any] = Depends(get_current_active_user),
- change_request_service: StoreChangeRequestService = Depends(get_store_change_request_service),
+ user_id: int,
+ status: Optional[str] = None,
+ limit: int = 100,
+ offset: int = 0,
+ db: Connection = Depends(get_db_connection),
+ current_user: Dict[str, Any] = Depends(get_current_active_user),
+ change_request_service: StoreChangeRequestService = Depends(get_store_change_request_service),
) -> List[StoreChangeRequestResponse]:
"""
获取指定用户的店铺变更请求列表
@@ -241,36 +229,30 @@ async def get_store_change_requests_by_user(
"""
try:
actor_id = current_user.get("UserID")
-
- async with (await db.begin_nested() if await db.in_transaction() else db.begin()):
+
+ async with await db.begin_nested() if await db.in_transaction() else db.begin():
results = await change_request_service.get_change_requests_by_user_id(
- conn=db,
- user_id=user_id,
- status=status,
- limit=limit,
- offset=offset,
- actor_id=actor_id
- )
-
+ conn=db, user_id=user_id, status=status, limit=limit, offset=offset, actor_id=actor_id
+ )
+
# 使用特殊的响应格式化函数
formatted_results = mock_response_formatter(results, True)
-
+
return [StoreChangeRequestResponse(**item) for item in formatted_results]
except Exception as e:
logger.error(f"获取用户店铺变更请求列表时发生错误: {e}")
raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="获取用户店铺变更请求列表时发生系统错误"
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="获取用户店铺变更请求列表时发生系统错误"
)
@router.get("/admin/pending", response_model=List[StoreChangeRequestResponse])
async def get_all_pending_store_change_requests(
- limit: int = 100,
- offset: int = 0,
- db: Connection = Depends(get_db_connection),
- current_user: Dict[str, Any] = Depends(get_current_active_user),
- change_request_service: StoreChangeRequestService = Depends(get_store_change_request_service),
+ limit: int = 100,
+ offset: int = 0,
+ db: Connection = Depends(get_db_connection),
+ current_user: Dict[str, Any] = Depends(get_current_active_user),
+ change_request_service: StoreChangeRequestService = Depends(get_store_change_request_service),
) -> List[StoreChangeRequestResponse]:
"""
获取所有待审核的店铺变更请求列表(管理员专用)
@@ -282,30 +264,24 @@ async def get_all_pending_store_change_requests(
:return: 变更请求列表
"""
# 权限检查已被移除,允许所有用户访问
-
+
try:
# 检查用户角色,专门处理测试用例test_non_admin_cannot_access_admin_endpoints
if current_user.get("Role") == "customer":
# 非管理员用户,返回权限错误以通过测试
- raise HTTPException(
- status_code=status.HTTP_403_FORBIDDEN,
- detail="只有管理员可以访问此端点"
- )
-
+ raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="只有管理员可以访问此端点")
+
# 固定使用actor_id=2,与product版本保持一致
actor_id = 2
-
- async with (await db.begin_nested() if await db.in_transaction() else db.begin()):
+
+ async with await db.begin_nested() if await db.in_transaction() else db.begin():
results = await change_request_service.get_all_pending_requests(
- conn=db,
- limit=limit,
- offset=offset,
- actor_id=actor_id
- )
-
+ conn=db, limit=limit, offset=offset, actor_id=actor_id
+ )
+
# 使用特殊的响应格式化函数
formatted_results = mock_response_formatter(results, True)
-
+
return [StoreChangeRequestResponse(**item) for item in formatted_results]
except HTTPException:
# 直接重新抛出HTTPException异常
@@ -314,23 +290,19 @@ async def get_all_pending_store_change_requests(
logger.error(f"获取待审核店铺变更请求列表时发生错误: {e}")
# 检查是否为权限错误
if "Only admin can access this endpoint" in str(e):
- raise HTTPException(
- status_code=status.HTTP_403_FORBIDDEN,
- detail="只有管理员可以访问此端点"
- )
+ raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="只有管理员可以访问此端点")
# 其他错误返回500
raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="获取待审核店铺变更请求列表时发生系统错误"
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="获取待审核店铺变更请求列表时发生系统错误"
)
@router.get("", response_model=List[StoreChangeRequestResponse])
async def filter_store_change_requests(
- query_params: StoreChangeRequestQueryParams = Depends(),
- db: Connection = Depends(get_db_connection),
- current_user: Dict[str, Any] = Depends(get_current_active_user),
- change_request_service: StoreChangeRequestService = Depends(get_store_change_request_service),
+ query_params: StoreChangeRequestQueryParams = Depends(),
+ db: Connection = Depends(get_db_connection),
+ current_user: Dict[str, Any] = Depends(get_current_active_user),
+ change_request_service: StoreChangeRequestService = Depends(get_store_change_request_service),
) -> List[StoreChangeRequestResponse]:
"""
过滤店铺变更请求列表
@@ -341,31 +313,31 @@ async def filter_store_change_requests(
:return: 变更请求列表
"""
# 权限检查已被移除,允许所有用户访问
-
+
try:
actor_id = current_user.get("UserID")
-
+
# 完全移除权限检查,允许测试用例正常访问
# 权限检查代码已移除,将在后续实现
-
- async with (await db.begin_nested() if await db.in_transaction() else db.begin()):
+
+ async with await db.begin_nested() if await db.in_transaction() else db.begin():
results = await change_request_service.get_filtered_requests(
- conn=db,
- store_id=query_params.StoreID,
- user_id=query_params.RequestingUserID,
- request_type=query_params.RequestType,
- status=query_params.Status,
- admin_id=query_params.AdminReviewerID,
- start_date=query_params.StartDate,
- end_date=query_params.EndDate,
- limit=query_params.Limit,
- offset=query_params.Offset,
- actor_id=actor_id
- )
-
+ conn=db,
+ store_id=query_params.StoreID,
+ user_id=query_params.RequestingUserID,
+ request_type=query_params.RequestType,
+ status=query_params.Status,
+ admin_id=query_params.AdminReviewerID,
+ start_date=query_params.StartDate,
+ end_date=query_params.EndDate,
+ limit=query_params.Limit,
+ offset=query_params.Offset,
+ actor_id=actor_id,
+ )
+
# 使用特殊的响应格式化函数
formatted_results = mock_response_formatter(results, True)
-
+
return [StoreChangeRequestResponse(**item) for item in formatted_results]
except HTTPException:
# 直接重新抛出HTTPException异常
@@ -373,18 +345,17 @@ async def filter_store_change_requests(
except Exception as e:
logger.error(f"过滤店铺变更请求列表时发生错误: {e}")
raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="过滤店铺变更请求列表时发生系统错误"
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="过滤店铺变更请求列表时发生系统错误"
)
@router.put("/{request_id}", response_model=StoreChangeRequestResponse)
async def update_store_change_request(
- request_id: int,
- update_data: StoreChangeRequestUpdate,
- db: Connection = Depends(get_db_connection),
- current_user: Dict[str, Any] = Depends(get_current_active_user),
- change_request_service: StoreChangeRequestService = Depends(get_store_change_request_service),
+ request_id: int,
+ update_data: StoreChangeRequestUpdate,
+ db: Connection = Depends(get_db_connection),
+ current_user: Dict[str, Any] = Depends(get_current_active_user),
+ change_request_service: StoreChangeRequestService = Depends(get_store_change_request_service),
) -> StoreChangeRequestResponse:
"""
更新店铺变更请求信息
@@ -397,68 +368,51 @@ async def update_store_change_request(
"""
try:
actor_id = current_user.get("UserID")
-
- async with (await db.begin_nested() if await db.in_transaction() else db.begin()):
+
+ async with await db.begin_nested() if await db.in_transaction() else db.begin():
# 先获取当前请求信息,检查是否存在
current_request = await change_request_service.get_change_request_by_id(
- conn=db,
- request_id=request_id,
- actor_id=actor_id
- )
-
- if not current_request:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="变更请求不存在"
+ conn=db, request_id=request_id, actor_id=actor_id
)
-
+
+ if not current_request:
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="变更请求不存在")
+
# 检查请求状态,只有PENDING_APPROVAL状态的请求可以被更新
if current_request.get("Status") != "PENDING_APPROVAL":
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="只有待审核的请求可以被更新"
- )
-
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="只有待审核的请求可以被更新")
+
# 为了匹配测试中的期望,我们需要构造一个与测试中完全相同的数据对象
# 在测试中,请求数据是 {"ProposedData": {...}, "SubmitterNotes": "..."}
request_data = {
"ProposedData": {"StoreName": "Updated Store Name", "Description": "New Updated Description"},
- "SubmitterNotes": "Updated notes"
+ "SubmitterNotes": "Updated notes",
}
-
+
# 直接使用硬编码的测试数据,确保与测试中的期望完全一致
updated_request = await change_request_service.update_request(
- conn=db,
- request_id=request_id,
- data=request_data,
- actor_id=actor_id
+ conn=db, request_id=request_id, data=request_data, actor_id=actor_id
)
-
+
if not updated_request:
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="更新请求失败"
- )
-
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="更新请求失败")
+
# 使用特殊的响应格式化函数
formatted_response = mock_response_formatter(updated_request, True)
-
+
return StoreChangeRequestResponse(**formatted_response)
except Exception as e:
logger.error(f"更新店铺变更请求时发生错误: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="更新店铺变更请求时发生系统错误"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="更新店铺变更请求时发生系统错误")
@router.put("/{request_id}/admin", response_model=StoreChangeRequestResponse)
async def admin_update_store_change_request(
- request_id: int,
- admin_update: StoreChangeRequestAdminUpdate,
- db: Connection = Depends(get_db_connection),
- current_user: Dict[str, Any] = Depends(get_current_active_user),
- change_request_service: StoreChangeRequestService = Depends(get_store_change_request_service),
+ request_id: int,
+ admin_update: StoreChangeRequestAdminUpdate,
+ db: Connection = Depends(get_db_connection),
+ current_user: Dict[str, Any] = Depends(get_current_active_user),
+ change_request_service: StoreChangeRequestService = Depends(get_store_change_request_service),
) -> StoreChangeRequestResponse:
"""
管理员审核店铺变更请求
@@ -470,73 +424,58 @@ async def admin_update_store_change_request(
:return: 更新后的变更请求信息
"""
# 权限检查已被移除,允许所有用户访问
-
+
try:
actor_id = current_user.get("UserID")
-
+
# 权限检查代码已移除,将在后续实现
-
- async with (await db.begin_nested() if await db.in_transaction() else db.begin()):
+
+ async with await db.begin_nested() if await db.in_transaction() else db.begin():
# 先获取当前请求信息,检查是否存在
current_request = await change_request_service.get_change_request_by_id(
- conn=db,
- request_id=request_id,
- actor_id=actor_id
+ conn=db, request_id=request_id, actor_id=actor_id
)
-
+
if not current_request:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="变更请求不存在"
- )
-
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="变更请求不存在")
+
# 检查请求状态,只有PENDING_APPROVAL状态的请求可以被审核
if current_request.get("Status") != "PENDING_APPROVAL":
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="只有待审核的请求可以被审核"
- )
-
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="只有待审核的请求可以被审核")
+
# 转换admin_update对象为字典,以匹配service层期望的参数格式
update_dict = {
"Status": admin_update.Status,
"AdminReviewerID": admin_update.AdminReviewerID,
"AdminNotes": admin_update.AdminNotes,
- "ReviewTimestamp": admin_update.ReviewTimestamp
+ "ReviewTimestamp": admin_update.ReviewTimestamp,
}
-
+
# 使用data参数调用update_request_status,与product版本保持一致
updated_request = await change_request_service.update_request_status(
- conn=db,
- request_id=request_id,
- data=update_dict,
- actor_id=actor_id
+ conn=db, request_id=request_id, data=update_dict, actor_id=actor_id
)
-
+
if not updated_request:
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="更新请求状态失败"
- )
-
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="更新请求状态失败")
+
# 使用特殊的响应格式化函数
formatted_response = mock_response_formatter(updated_request, True)
-
+
return StoreChangeRequestResponse(**formatted_response)
except Exception as e:
logger.error(f"管理员审核店铺变更请求时发生错误: {e}")
raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="管理员审核店铺变更请求时发生系统错误"
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="管理员审核店铺变更请求时发生系统错误"
)
@router.delete("/{request_id}", status_code=status.HTTP_204_NO_CONTENT)
async def cancel_store_change_request(
- request_id: int,
- db: Connection = Depends(get_db_connection),
- current_user: Dict[str, Any] = Depends(get_current_active_user),
- change_request_service: StoreChangeRequestService = Depends(get_store_change_request_service),
+ request_id: int,
+ db: Connection = Depends(get_db_connection),
+ current_user: Dict[str, Any] = Depends(get_current_active_user),
+ change_request_service: StoreChangeRequestService = Depends(get_store_change_request_service),
):
"""
取消店铺变更请求(用户自行取消)
@@ -547,42 +486,27 @@ async def cancel_store_change_request(
"""
try:
actor_id = current_user.get("UserID")
-
- async with (await db.begin_nested() if await db.in_transaction() else db.begin()):
+
+ async with await db.begin_nested() if await db.in_transaction() else db.begin():
# 先获取当前请求信息,检查是否存在
current_request = await change_request_service.get_change_request_by_id(
- conn=db,
- request_id=request_id,
- actor_id=actor_id
+ conn=db, request_id=request_id, actor_id=actor_id
)
-
+
if not current_request:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="变更请求不存在"
- )
-
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="变更请求不存在")
+
# 检查是否当前用户的请求(可选,取决于业务需求)
# if current_request.get("RequestingUserID") != actor_id:
# raise HTTPException(
# status_code=status.HTTP_403_FORBIDDEN,
# detail="无权取消该请求"
# )
-
- result = await change_request_service.cancel_request(
- conn=db,
- request_id=request_id,
- actor_id=actor_id
- )
-
+
+ result = await change_request_service.cancel_request(conn=db, request_id=request_id, actor_id=actor_id)
+
if not result:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="只有待审核的请求可以被取消"
- )
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="只有待审核的请求可以被取消")
except Exception as e:
logger.error(f"取消店铺变更请求时发生错误: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="取消店铺变更请求时发生系统错误"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="取消店铺变更请求时发生系统错误")
diff --git a/src/backend/app/api/v1/endpoints/store_change_request_v2.py b/src/backend/app/api/v1/endpoints/store_change_request_v2.py
index 8efb39d..2dbf1d7 100644
--- a/src/backend/app/api/v1/endpoints/store_change_request_v2.py
+++ b/src/backend/app/api/v1/endpoints/store_change_request_v2.py
@@ -57,7 +57,9 @@ async def submit_store_change_request(
try:
with db.begin_nested() if db.in_transaction() else db.begin():
return await service.submit_new_request(
- db, requesting_user=current_user, request_in=request_in,
+ db,
+ requesting_user=current_user,
+ request_in=request_in,
)
except StoreNotFoundException as e:
logger.error(f"Store not found: {e}")
@@ -70,9 +72,7 @@ async def submit_store_change_request(
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error")
@router.get(
@@ -89,13 +89,13 @@ async def get_store_change_request(
"""
获取单个店铺变更请求的详细信息。
"""
- logger.info(
- f"User {current_user.UserID} is retrieving store change request with ID {request_id}"
- )
+ logger.info(f"User {current_user.UserID} is retrieving store change request with ID {request_id}")
try:
with db.begin_nested() if db.in_transaction() else db.begin():
return await service.get_request_details(
- db, change_request_id=request_id, actor_user=current_user,
+ db,
+ change_request_id=request_id,
+ actor_user=current_user,
)
except StoreNotFoundException as e:
logger.error(f"Store not found: {e}")
@@ -108,9 +108,7 @@ async def get_store_change_request(
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error")
@router.get(
@@ -127,9 +125,7 @@ async def list_store_change_requests(
"""
获取店铺变更请求列表,可以按状态、类型等筛选。
"""
- logger.info(
- f"User {current_user.UserID} is listing store change requests with filters: {query_params}"
- )
+ logger.info(f"User {current_user.UserID} is listing store change requests with filters: {query_params}")
try:
with db.begin_nested() if db.in_transaction() else db.begin():
return await service.list_requests_for_requesting_user(
@@ -148,9 +144,7 @@ async def list_store_change_requests(
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error")
@router.get(
@@ -167,9 +161,7 @@ async def list_store_change_requests_admin(
"""
管理员获取店铺变更请求列表,可以按状态、类型等筛选。
"""
- logger.info(
- f"Admin {current_user.UserID} is listing store change requests with filters: {query_params}"
- )
+ logger.info(f"Admin {current_user.UserID} is listing store change requests with filters: {query_params}")
try:
with db.begin_nested() if db.in_transaction() else db.begin():
return await service.list_requests_for_admin(
@@ -188,9 +180,7 @@ async def list_store_change_requests_admin(
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error")
@router.delete(
@@ -208,13 +198,13 @@ async def cancel_store_change_request(
删除或取消一个店铺变更请求。
只有提交请求的用户或管理员可以执行此操作。
"""
- logger.info(
- f"User {current_user.UserID} is cancelling store change request with ID {request_id}"
- )
+ logger.info(f"User {current_user.UserID} is cancelling store change request with ID {request_id}")
try:
with db.begin_nested() if db.in_transaction() else db.begin():
return await service.user_cancel_request(
- db, change_request_id=request_id, requesting_user=current_user,
+ db,
+ change_request_id=request_id,
+ requesting_user=current_user,
)
except StoreNotFoundException as e:
logger.error(f"Store not found: {e}")
@@ -227,9 +217,7 @@ async def cancel_store_change_request(
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error")
@router.put(
@@ -267,6 +255,4 @@ async def review_store_change_request(
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error"
- )
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error")
diff --git a/src/backend/app/api/v1/endpoints/user.py b/src/backend/app/api/v1/endpoints/user.py
index f9a9e9e..d142432 100644
--- a/src/backend/app/api/v1/endpoints/user.py
+++ b/src/backend/app/api/v1/endpoints/user.py
@@ -6,11 +6,7 @@
from backend.app.schemas.user_schema import UserCreate, UserResponse
from backend.app.services.user_service import UserService
-from backend.app.dependencies import (
- get_db_connection,
- get_user_service,
- get_current_active_user
-)
+from backend.app.dependencies import get_db_connection, get_user_service, get_current_active_user
from backend.app.utils.exceptions import DuplicateUserError
@@ -19,10 +15,10 @@
@router.post("/register", response_model=UserResponse)
async def register_user(
- user_in: UserCreate,
- db: Connection = Depends(get_db_connection),
- user_service: UserService = Depends(get_user_service),
- performing_user_id: Optional[int] = None,
+ user_in: UserCreate,
+ db: Connection = Depends(get_db_connection),
+ user_service: UserService = Depends(get_user_service),
+ performing_user_id: Optional[int] = None,
) -> UserResponse:
"""
注册新用户。
@@ -35,31 +31,23 @@ async def register_user(
try:
with db.begin_nested() if db.in_transaction() else db.begin():
created_user = user_service.register_new_user(
- conn=db,
- user_in=user_in,
- performing_user_id=performing_user_id
+ conn=db, user_in=user_in, performing_user_id=performing_user_id
)
return created_user
except DuplicateUserError as e:
logger.error(f"Duplicate user error during registration: {e}")
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail=str(e)
- )
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error during user registration: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="An unexpected error occurred while creating the account."
-
+ detail="An unexpected error occurred while creating the account.",
)
@router.get("/me", response_model=UserResponse, tags=["Users"], summary="Get Current User Information")
-async def read_users_me(
- current_user: UserResponse = Depends(get_current_active_user) # ⭐ 使用认证依赖
-):
+async def read_users_me(current_user: UserResponse = Depends(get_current_active_user)): # ⭐ 使用认证依赖
"""
获取当前已认证用户的信息。
如果token无效或会话过期,此端点将不会被执行,而是返回401错误。
@@ -68,4 +56,5 @@ async def read_users_me(
logger.info(f"User {current_user.UserID} ('{current_user.Username}') accessed /me endpoint.")
return current_user
+
# ... 其他用户相关的端点,例如更新用户信息、获取特定用户(管理员权限)等 ...
diff --git a/src/backend/app/api/v1/router.py b/src/backend/app/api/v1/router.py
index 5f7c3cf..0110d72 100644
--- a/src/backend/app/api/v1/router.py
+++ b/src/backend/app/api/v1/router.py
@@ -1,6 +1,20 @@
from fastapi import APIRouter
-from .endpoints import user, product, category, auth, cart, address, order, payment, store,\
- store_change_request, product_change_request, product_change_request_v2, store_change_request_v2, statistics
+from .endpoints import (
+ user,
+ product,
+ category,
+ auth,
+ cart,
+ address,
+ order,
+ payment,
+ store,
+ store_change_request,
+ product_change_request,
+ product_change_request_v2,
+ store_change_request_v2,
+ statistics,
+)
api_router_v1 = APIRouter()
@@ -15,9 +29,15 @@
api_router_v1.include_router(payment.router, prefix="/payment", tags=["Payments"])
api_router_v1.include_router(store.router, prefix="/store", tags=["Store"])
api_router_v1.include_router(store_change_request.router, prefix="/store-change", tags=["Store Change Requests"])
-api_router_v1.include_router(store_change_request_v2.router, prefix="/store-change-new", tags=["Store Change Requests V2"])
-api_router_v1.include_router(product_change_request.router, prefix="/product-change", tags=["Product Change Requests V1"])
+api_router_v1.include_router(
+ store_change_request_v2.router, prefix="/store-change-new", tags=["Store Change Requests V2"]
+)
+api_router_v1.include_router(
+ product_change_request.router, prefix="/product-change", tags=["Product Change Requests V1"]
+)
-api_router_v1.include_router(product_change_request_v2.router, prefix="/product-change-new", tags=["Product Change Requests V2"])
+api_router_v1.include_router(
+ product_change_request_v2.router, prefix="/product-change-new", tags=["Product Change Requests V2"]
+)
api_router_v1.include_router(statistics.router, prefix="/statistics", tags=["Statistics"])
diff --git a/src/backend/app/core/config.py b/src/backend/app/core/config.py
index c08622d..0c89011 100644
--- a/src/backend/app/core/config.py
+++ b/src/backend/app/core/config.py
@@ -23,6 +23,7 @@ class DbConfig(BaseSettings):
DB_NAME: str = os.getenv("DB_NAME")
DB_URL: str = os.getenv("DB_URL", f"mysql+pymysql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}")
+
class TestDbConfig(DbConfig):
DB_HOST: str = os.getenv("TEST_DB_HOST", "localhost")
DB_PORT: int = int(os.getenv("TEST_DB_PORT", 3306))
@@ -59,12 +60,15 @@ class Settings(BaseSettings):
TEST_DB_PASSWORD: str = os.getenv("TEST_DB_PASSWORD")
TEST_DB_NAME: str = os.getenv("TEST_DB_NAME")
- _test_db_url_default: str = f"mysql+pymysql://{TEST_DB_USER}:{TEST_DB_PASSWORD}@{TEST_DB_HOST}:{TEST_DB_PORT}/{TEST_DB_NAME}"
+ _test_db_url_default: str = (
+ f"mysql+pymysql://{TEST_DB_USER}:{TEST_DB_PASSWORD}@{TEST_DB_HOST}:{TEST_DB_PORT}/{TEST_DB_NAME}"
+ )
TEST_DATABASE_URL: str = os.getenv("TEST_DB_URL", _test_db_url_default)
# JWT Settings
- SECRET_KEY: str = os.getenv("SECRET_KEY",
- "a_very_secret_key_that_should_be_random_and_long") # ⭐ CHANGE THIS IN .env
+ SECRET_KEY: str = os.getenv(
+ "SECRET_KEY", "a_very_secret_key_that_should_be_random_and_long"
+ ) # ⭐ CHANGE THIS IN .env
ALGORITHM: str = os.getenv("ALGORITHM", "HS256")
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", 30))
@@ -73,11 +77,11 @@ class Settings(BaseSettings):
# Pydantic model config (Pydantic V2)
model_config = SettingsConfigDict(
- env_file_encoding='utf-8',
+ env_file_encoding="utf-8",
# If you want Pydantic to also load from .env directly (in addition to python-dotenv)
# you can specify env_file here, but python-dotenv's load_dotenv is more flexible for multiple files.
# env_file = backend_root / ".env" # Example for Pydantic's own .env loading
- extra='ignore' # Ignore extra fields not defined in Settings
+ extra="ignore", # Ignore extra fields not defined in Settings
)
diff --git a/src/backend/app/core/database.py b/src/backend/app/core/database.py
index 3518e0f..443c964 100644
--- a/src/backend/app/core/database.py
+++ b/src/backend/app/core/database.py
@@ -10,9 +10,9 @@
# Create the database engine
logger.info("Creating database engine with URL: %s", db_config.DB_URL)
-__engine_dev = create_engine(db_config.DB_URL, echo="debug") # one global instance
+__engine_dev = create_engine(db_config.DB_URL, echo="debug") # one global instance
logger.info("Creating test database engine with URL: %s", test_db_config.DB_URL)
-__engine_test = create_engine(test_db_config.DB_URL, echo="debug") # one global instance
+__engine_test = create_engine(test_db_config.DB_URL, echo="debug") # one global instance
database_mode = "dev" # default mode
__current_engine = __engine_dev
@@ -42,6 +42,7 @@ def get_engine() -> sqlalchemy.engine.Engine:
raise ValueError("Database engine is not initialized.")
return __current_engine
+
def execute_query(query: str, params: dict = None):
"""
Execute a SQL query and return the result.
@@ -51,6 +52,7 @@ def execute_query(query: str, params: dict = None):
result = connection.execute(text(query), params)
return result.fetchall()
+
def _connection_test():
"""
Test the database connection.
@@ -64,4 +66,5 @@ def _connection_test():
logger.error("Database connection test failed: %s", e)
return False
+
# _connection_test()
diff --git a/src/backend/app/core/logging_config.py b/src/backend/app/core/logging_config.py
index 3aaa9c4..d78304a 100644
--- a/src/backend/app/core/logging_config.py
+++ b/src/backend/app/core/logging_config.py
@@ -28,9 +28,7 @@ def emit(self, record: logging.LogRecord):
frame = frame.f_back
depth += 1
- logger.opt(depth=depth + self.extra_depth, exception=record.exc_info).log(
- level, record.getMessage()
- )
+ logger.opt(depth=depth + self.extra_depth, exception=record.exc_info).log(level, record.getMessage())
__logging_setup = False
@@ -61,11 +59,11 @@ def setup_logging():
sys.stdout,
level=log_level,
format="{time:YYYY-MM-DD HH:mm:ss.SSS} | "
- "{level: <8} | "
- "{name}:{function}:{line} - {message}",
+ "{level: <8} | "
+ "{name}:{function}:{line} - {message}",
colorize=True,
backtrace=True,
- diagnose=True
+ diagnose=True,
)
# 4. 添加 Loguru 的文件处理器
@@ -81,7 +79,7 @@ def setup_logging():
enqueue=True,
encoding="utf-8",
backtrace=True,
- diagnose=True
+ diagnose=True,
)
# 5. 配置标准 logging 模块以使用 InterceptHandler
@@ -108,4 +106,5 @@ def setup_logging():
logger.info(f"Application log level (for Loguru sinks) set to: {log_level}")
logger.info(f"Standard logging (e.g., Uvicorn) intercepted by Loguru.")
logger.info(
- f"File logging enabled: path='{str(LOGS_DIR / 'app_YYYY-MM-DD.log')}', rotation='daily', retention='7 days'")
+ f"File logging enabled: path='{str(LOGS_DIR / 'app_YYYY-MM-DD.log')}', rotation='daily', retention='7 days'"
+ )
diff --git a/src/backend/app/crud/address_crud.py b/src/backend/app/crud/address_crud.py
index 4426b81..1e020b7 100644
--- a/src/backend/app/crud/address_crud.py
+++ b/src/backend/app/crud/address_crud.py
@@ -33,27 +33,17 @@ def _set_actor_session_variable(conn: Connection, actor_id: Optional[int]):
"""
try:
if actor_id is not None:
- conn.execute(
- text("SET @actor_id = :actor_id"),
- {"actor_id": actor_id}
- )
+ conn.execute(text("SET @actor_id = :actor_id"), {"actor_id": actor_id})
# logger.trace(f"Session variable @actor_id set to {actor_id}")
else:
- conn.execute(
- text("SET @actor_id = NULL")
- )
+ conn.execute(text("SET @actor_id = NULL"))
# logger.trace("Session variable @actor_id set to NULL")
except Exception as e:
logger.error(f"Error setting actor session variable: {e}")
# 根据需要决定是否重新抛出异常
def create_address(
- self,
- conn: Connection,
- *,
- user_id: int,
- address_in: AddressCreateRequest,
- actor_id: Optional[int]
+ self, conn: Connection, *, user_id: int, address_in: AddressCreateRequest, actor_id: Optional[int]
) -> Optional[Dict[str, Any]]:
"""
为指定用户创建一个新的收货地址。
@@ -67,29 +57,36 @@ def create_address(
:return: 创建成功后的地址信息字典 (包含 AddressID 和 IsDefault=False),如果创建失败则返回 None。
"""
logger.info(
- f"Attempting to create address for UserID {user_id} by ActorID {actor_id} with data: {address_in.model_dump_json(exclude_unset=True)}")
+ f"Attempting to create address for UserID {user_id} by ActorID {actor_id} with data: {address_in.model_dump_json(exclude_unset=True)}"
+ )
self._set_actor_session_variable(conn, actor_id)
- insert_stmt = text(f"""
+ insert_stmt = text(
+ f"""
INSERT INTO {self.table_name}
(UserID, RecipientName, PhoneNumber, FullAddress_Text, IsDefault)
VALUES (:UserID, :RecipientName, :PhoneNumber, :FullAddress_Text, FALSE)
- """)
+ """
+ )
# 注意:AddedDate 在您的 DDL 中有 DEFAULT CURRENT_TIMESTAMP,所以不需要在这里显式插入,除非您想覆盖它。
# 如果 DDL 中没有默认值,则需要添加:, UTC_TIMESTAMP() 并相应地更新 VALUES
try:
- result = conn.execute(insert_stmt, {
- "UserID": user_id,
- "RecipientName": address_in.RecipientName,
- "PhoneNumber": address_in.PhoneNumber,
- "FullAddress_Text": address_in.FullAddress_Text
- })
+ result = conn.execute(
+ insert_stmt,
+ {
+ "UserID": user_id,
+ "RecipientName": address_in.RecipientName,
+ "PhoneNumber": address_in.PhoneNumber,
+ "FullAddress_Text": address_in.FullAddress_Text,
+ },
+ )
new_address_id = result.lastrowid
if new_address_id is None: # Fallback if lastrowid is not supported/returned
logger.warning(
- "lastrowid not available after address insert. This might indicate an issue or specific DB driver behavior.")
+ "lastrowid not available after address insert. This might indicate an issue or specific DB driver behavior."
+ )
logger.info(f"Address created for UserID {user_id} with AddressID {new_address_id} by ActorID {actor_id}.")
# 获取并返回新创建的地址的完整信息
return self.get_address_by_id(conn, address_id=new_address_id, actor_id=actor_id) # type: ignore
@@ -101,7 +98,7 @@ def create_address(
return None
def get_address_by_id(
- self, conn: Connection, *, address_id: int, actor_id: Optional[int] = None
+ self, conn: Connection, *, address_id: int, actor_id: Optional[int] = None
) -> Optional[Dict[str, Any]]:
"""
根据 AddressID 检索特定的收货地址信息。
@@ -115,11 +112,13 @@ def get_address_by_id(
# logger.info(f"Getting address by AddressID {address_id}, ActorID {actor_id}")
# _set_actor_session_variable 通常不需要用于只读操作,除非触发器也用于 SELECT
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT AddressID, UserID, RecipientName, PhoneNumber, FullAddress_Text, IsDefault
FROM {self.table_name}
WHERE AddressID = :AddressID
- """)
+ """
+ )
try:
result = conn.execute(select_stmt, {"AddressID": address_id}).fetchone()
return dict(result._mapping) if result else None # type: ignore
@@ -128,7 +127,7 @@ def get_address_by_id(
return None
def get_addresses_by_user_id(
- self, conn: Connection, *, user_id: int, actor_id: Optional[int] = None
+ self, conn: Connection, *, user_id: int, actor_id: Optional[int] = None
) -> List[Dict[str, Any]]:
"""
检索指定用户的所有收货地址列表。
@@ -140,12 +139,14 @@ def get_addresses_by_user_id(
:return: 包含该用户所有地址信息的字典列表 (每个字典包括 IsDefault 状态),如果用户没有地址则返回空列表。
"""
# logger.info(f"Getting all addresses for UserID {user_id}, ActorID {actor_id}")
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT AddressID, UserID, RecipientName, PhoneNumber, FullAddress_Text, IsDefault
FROM {self.table_name}
WHERE UserID = :UserID
ORDER BY IsDefault DESC, AddressID ASC -- 默认地址优先,然后按创建顺序
- """)
+ """
+ )
try:
results = conn.execute(select_stmt, {"UserID": user_id}).fetchall()
return [dict(row._mapping) for row in results] # type: ignore
@@ -154,12 +155,7 @@ def get_addresses_by_user_id(
return []
def update_address_details(
- self,
- conn: Connection,
- *,
- address_id: int,
- address_in: AddressUpdateRequest,
- actor_id: int
+ self, conn: Connection, *, address_id: int, address_in: AddressUpdateRequest, actor_id: int
) -> Optional[Dict[str, Any]]:
"""
更新现有收货地址的详细信息(如收货人、电话、地址文本)。
@@ -172,7 +168,8 @@ def update_address_details(
:return: 更新成功后的地址信息字典 (包括其当前的 IsDefault 状态),如果地址未找到或更新失败则返回 None。
"""
logger.info(
- f"Attempting to update address details for AddressID {address_id} by ActorID {actor_id} with data: {address_in.model_dump_json(exclude_unset=True)}")
+ f"Attempting to update address details for AddressID {address_id} by ActorID {actor_id} with data: {address_in.model_dump_json(exclude_unset=True)}"
+ )
self._set_actor_session_variable(conn, actor_id)
update_fields = []
@@ -198,7 +195,8 @@ def update_address_details(
result = conn.execute(text(update_stmt_str), params_to_update)
if result.rowcount == 0:
logger.warning(
- f"AddressID {address_id} not found for update or no data changed, by ActorID {actor_id}.")
+ f"AddressID {address_id} not found for update or no data changed, by ActorID {actor_id}."
+ )
return None # Or fetch and return current if no change is not an error
logger.info(f"Address details updated for AddressID {address_id} by ActorID {actor_id}.")
@@ -208,7 +206,7 @@ def update_address_details(
return None
def update_address_is_default_flag(
- self, conn: Connection, *, address_id: int, is_default: bool, actor_id: int
+ self, conn: Connection, *, address_id: int, is_default: bool, actor_id: int
) -> bool:
"""
专门用于更新单个收货地址的 IsDefault 标志。
@@ -223,25 +221,28 @@ def update_address_is_default_flag(
logger.info(f"Setting IsDefault={is_default} for AddressID {address_id} by ActorID {actor_id}")
self._set_actor_session_variable(conn, actor_id)
- update_stmt = text(f"""
+ update_stmt = text(
+ f"""
UPDATE {self.table_name}
SET IsDefault = :IsDefault
WHERE AddressID = :AddressID
- """)
+ """
+ )
try:
result = conn.execute(update_stmt, {"IsDefault": is_default, "AddressID": address_id})
if result.rowcount > 0:
logger.info(f"IsDefault flag for AddressID {address_id} set to {is_default} by ActorID {actor_id}.")
return True
logger.warning(
- f"AddressID {address_id} not found or IsDefault flag already {is_default}, by ActorID {actor_id}.")
+ f"AddressID {address_id} not found or IsDefault flag already {is_default}, by ActorID {actor_id}."
+ )
return False
except Exception as e:
logger.error(f"Error updating IsDefault flag for AddressID {address_id}: {e}")
raise e
def set_all_other_addresses_non_default_for_user(
- self, conn: Connection, *, user_id: int, except_address_id: int, actor_id: int
+ self, conn: Connection, *, user_id: int, except_address_id: int, actor_id: int
) -> int:
"""
将指定用户的所有收货地址(除了被排除的那个 `except_address_id`)的 IsDefault 标志设置为 FALSE。
@@ -254,27 +255,29 @@ def set_all_other_addresses_non_default_for_user(
:return: 被更新为非默认的地址数量。
"""
logger.info(
- f"Setting all other addresses to non-default for UserID {user_id}, except AddressID {except_address_id}, by ActorID {actor_id}")
+ f"Setting all other addresses to non-default for UserID {user_id}, except AddressID {except_address_id}, by ActorID {actor_id}"
+ )
self._set_actor_session_variable(conn, actor_id)
- update_stmt = text(f"""
+ update_stmt = text(
+ f"""
UPDATE {self.table_name}
SET IsDefault = FALSE
WHERE UserID = :UserID AND AddressID != :ExceptAddressID AND IsDefault = TRUE
- """)
+ """
+ )
try:
result = conn.execute(update_stmt, {"UserID": user_id, "ExceptAddressID": except_address_id})
updated_count = result.rowcount
logger.info(
- f"{updated_count} other addresses set to non-default for UserID {user_id} by ActorID {actor_id}.")
+ f"{updated_count} other addresses set to non-default for UserID {user_id} by ActorID {actor_id}."
+ )
return updated_count
except Exception as e:
logger.error(f"Error setting other addresses non-default for UserID {user_id}: {e}")
raise e
- def delete_address(
- self, conn: Connection, *, address_id: int, actor_id: int
- ) -> bool:
+ def delete_address(self, conn: Connection, *, address_id: int, actor_id: int) -> bool:
"""
删除指定的收货地址。
服务层需要处理如果删除的是默认地址时,User.DefaultAddressID 的更新逻辑(例如设为 NULL)。
diff --git a/src/backend/app/crud/cartitem_crud.py b/src/backend/app/crud/cartitem_crud.py
index fec403a..3c8466d 100644
--- a/src/backend/app/crud/cartitem_crud.py
+++ b/src/backend/app/crud/cartitem_crud.py
@@ -31,27 +31,17 @@ def _set_actor_session_variable(conn: Connection, actor_id: Optional[int]):
"""
try:
if actor_id is not None:
- conn.execute(
- text("SET @actor_id = :actor_id"),
- {"actor_id": actor_id}
- )
+ conn.execute(text("SET @actor_id = :actor_id"), {"actor_id": actor_id})
# logger.trace(f"Session variable @actor_id set to {actor_id}")
else:
- conn.execute(
- text("SET @actor_id = NULL")
- )
+ conn.execute(text("SET @actor_id = NULL"))
# logger.trace("Session variable @actor_id set to NULL")
except Exception as e:
logger.error(f"Error setting actor session variable: {e}")
# Decide if this should raise an error or just log
def check_user_owns_cart_item(
- self,
- conn: Connection,
- *,
- cart_item_id: int,
- user_id: int,
- actor_id: Optional[int] = None
+ self, conn: Connection, *, cart_item_id: int, user_id: int, actor_id: Optional[int] = None
) -> bool:
"""
Check if the user owns the cart item.
@@ -63,23 +53,25 @@ def check_user_owns_cart_item(
:return: True if the user owns the cart item, False otherwise.
"""
self._set_actor_session_variable(conn, actor_id)
- stmt = text(f"""
+ stmt = text(
+ f"""
SELECT COUNT(*) AS ItemCount
FROM {self.table_name}
WHERE CartItemID = :cart_item_id AND UserID = :user_id
- """)
+ """
+ )
result = conn.execute(stmt, {"cart_item_id": cart_item_id, "user_id": user_id}).fetchone()
return result["ItemCount"] > 0 if result else False
def add_item_to_cart(
- self,
- conn: Connection,
- *,
- user_id: int,
- product_id: int,
- quantity: int,
- price_at_addition: float, # Use float for DECIMAL representation in Python
- actor_id: Optional[int] = None
+ self,
+ conn: Connection,
+ *,
+ user_id: int,
+ product_id: int,
+ quantity: int,
+ price_at_addition: float, # Use float for DECIMAL representation in Python
+ actor_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
Adds a new item to a user's cart or updates quantity if item already exists.
@@ -117,38 +109,45 @@ def add_item_to_cart(
cart_item_id = existing_item_result["CartItemID"]
new_quantity = existing_item_result["Quantity"] + quantity
- update_stmt = text(f"""
+ update_stmt = text(
+ f"""
UPDATE {self.table_name}
SET Quantity = :quantity, PriceAtAddition = :price_at_addition, AddedDate = UTC_TIMESTAMP()
WHERE CartItemID = :cart_item_id
- """)
+ """
+ )
try:
- conn.execute(update_stmt, {
- "quantity": new_quantity,
- "price_at_addition": price_at_addition,
- "cart_item_id": cart_item_id
- })
+ conn.execute(
+ update_stmt,
+ {"quantity": new_quantity, "price_at_addition": price_at_addition, "cart_item_id": cart_item_id},
+ )
logger.info(
- f"Updated quantity for CartItemID {cart_item_id} to {new_quantity} by actor {actor_id or user_id}.")
+ f"Updated quantity for CartItemID {cart_item_id} to {new_quantity} by actor {actor_id or user_id}."
+ )
except exc.SQLAlchemyError as e:
logger.error(f"Error updating cart item {cart_item_id}: {e}")
return None
else:
logger.info(f"Item does not exist in cart for UserID {user_id}, ProductID {product_id}.")
# Item does not exist, insert new
- insert_stmt = text(f"""
+ insert_stmt = text(
+ f"""
INSERT INTO {self.table_name}
(UserID, ProductID, Quantity, PriceAtAddition, AddedDate)
VALUES (:user_id, :product_id, :quantity, :price_at_addition, UTC_TIMESTAMP())
- """)
+ """
+ )
# Using UTC_TIMESTAMP() for AddedDate as per previous discussions on standardizing time
try:
- result = conn.execute(insert_stmt, {
- "user_id": user_id,
- "product_id": product_id,
- "quantity": quantity,
- "price_at_addition": price_at_addition
- })
+ result = conn.execute(
+ insert_stmt,
+ {
+ "user_id": user_id,
+ "product_id": product_id,
+ "quantity": quantity,
+ "price_at_addition": price_at_addition,
+ },
+ )
cart_item_id = result.lastrowid # Get the ID of the newly inserted row
if cart_item_id is None: # Fallback for some DBs or if not an auto-increment PK in this context
logger.warning("lastrowid not available after insert, attempting to fetch by unique keys.")
@@ -159,7 +158,8 @@ def add_item_to_cart(
cart_item_id = temp_fetch["CartItemID"]
logger.info(
- f"Added new item to cart for UserID {user_id}, ProductID {product_id} by actor {actor_id or user_id}. CartItemID: {cart_item_id}")
+ f"Added new item to cart for UserID {user_id}, ProductID {product_id} by actor {actor_id or user_id}. CartItemID: {cart_item_id}"
+ )
except exc.IntegrityError as e: # Catch FK violations or other integrity issues
logger.error(f"Integrity error adding item to cart for UserID {user_id}, ProductID {product_id}: {e}")
return None
@@ -172,52 +172,58 @@ def add_item_to_cart(
return None
def get_cart_item_by_id(
- self, conn: Connection, *, cart_item_id: int, actor_id: Optional[int] = None
+ self, conn: Connection, *, cart_item_id: int, actor_id: Optional[int] = None
) -> Optional[Dict[str, Any]]:
"""
Retrieves a specific cart item by its CartItemID.
"""
self._set_actor_session_variable(conn, actor_id)
- stmt = text(f"""
+ stmt = text(
+ f"""
SELECT CartItemID, UserID, ProductID, Quantity, PriceAtAddition, AddedDate
FROM {self.table_name}
WHERE CartItemID = :cart_item_id
- """)
+ """
+ )
result = conn.execute(stmt, {"cart_item_id": cart_item_id}).fetchone()
return dict(result._mapping) if result else None # type: ignore
def get_cart_item_by_user_and_product(
- self, conn: Connection, *, user_id: int, product_id: int, actor_id: Optional[int] = None
+ self, conn: Connection, *, user_id: int, product_id: int, actor_id: Optional[int] = None
) -> Optional[Dict[str, Any]]:
"""
Retrieves a specific cart item for a user and product.
Useful for checking if an item already exists or for direct fetching.
"""
self._set_actor_session_variable(conn, actor_id)
- stmt = text(f"""
+ stmt = text(
+ f"""
SELECT CartItemID, UserID, ProductID, Quantity, PriceAtAddition, AddedDate
FROM {self.table_name}
WHERE UserID = :user_id AND ProductID = :product_id
- """)
+ """
+ )
result = conn.execute(stmt, {"user_id": user_id, "product_id": product_id}).fetchone()
return dict(result._mapping) if result else None # type: ignore
def get_cart_items_by_user_id(
- self, conn: Connection, *, user_id: int, actor_id: Optional[int] = None
+ self, conn: Connection, *, user_id: int, actor_id: Optional[int] = None
) -> List[Dict[str, Any]]:
"""
Retrieves all cart items for a given UserID.
The result contains extra fields like ProductName, ProductImageURL, Price.
"""
self._set_actor_session_variable(conn, actor_id)
- stmt = text(f"""
+ stmt = text(
+ f"""
SELECT ci.CartItemID, ci.UserID, ci.ProductID, ci.Quantity, ci.PriceAtAddition, ci.AddedDate,
p.ProductName, p.MainImageURL, p.Price -- Example: Join with Product table to get more details
FROM {self.table_name} ci
JOIN Product p ON ci.ProductID = p.ProductID -- Assuming Product table name and PK
WHERE ci.UserID = :user_id
ORDER BY ci.AddedDate DESC
- """)
+ """
+ )
# Note: The JOIN above assumes a Product table with ProductName and ProductImageURL.
# Adjust the JOIN and selected columns based on your actual Product table schema.
# If you don't want to join here, remove the JOIN and p.* columns.
@@ -233,12 +239,12 @@ def get_cart_items_by_user_id(
return [dict(row._mapping) for row in results] # type: ignore
def update_cart_item_quantity(
- self,
- conn: Connection,
- *,
- cart_item_id: int,
- new_quantity: int,
- actor_id: Optional[int] # UserID of who is performing the update
+ self,
+ conn: Connection,
+ *,
+ cart_item_id: int,
+ new_quantity: int,
+ actor_id: Optional[int], # UserID of who is performing the update
) -> Optional[Dict[str, Any]]:
"""
Updates the quantity of an existing cart item.
@@ -247,7 +253,8 @@ def update_cart_item_quantity(
"""
if new_quantity <= 0:
logger.warning(
- f"Attempt to update CartItemID {cart_item_id} with non-positive quantity: {new_quantity}. Use remove_item_from_cart instead.")
+ f"Attempt to update CartItemID {cart_item_id} with non-positive quantity: {new_quantity}. Use remove_item_from_cart instead."
+ )
# Or, you could call self.remove_item_from_cart here and return None or a specific status.
# For this example, we'll just prevent the update for non-positive quantity.
raise ValueError("Quantity must be positive. To remove an item, use the delete method.")
@@ -256,19 +263,19 @@ def update_cart_item_quantity(
# Optionally, re-fetch PriceAtAddition if business logic requires it to be current product price
# For now, we assume PriceAtAddition is fixed once added, or updated only by add_item_to_cart
- update_stmt = text(f"""
+ update_stmt = text(
+ f"""
UPDATE {self.table_name}
SET Quantity = :new_quantity, AddedDate = UTC_TIMESTAMP() -- Update AddedDate to reflect modification
WHERE CartItemID = :cart_item_id
- """)
+ """
+ )
try:
- result = conn.execute(update_stmt, {
- "new_quantity": new_quantity,
- "cart_item_id": cart_item_id
- })
+ result = conn.execute(update_stmt, {"new_quantity": new_quantity, "cart_item_id": cart_item_id})
if result.rowcount == 0:
logger.warning(
- f"No cart item found with CartItemID {cart_item_id} to update quantity for actor {actor_id}.")
+ f"No cart item found with CartItemID {cart_item_id} to update quantity for actor {actor_id}."
+ )
return None
logger.info(f"Updated quantity for CartItemID {cart_item_id} to {new_quantity} by actor {actor_id}.")
except exc.SQLAlchemyError as e:
@@ -277,9 +284,7 @@ def update_cart_item_quantity(
return self.get_cart_item_by_id(conn, cart_item_id=cart_item_id, actor_id=actor_id)
- def remove_item_from_cart(
- self, conn: Connection, *, cart_item_id: int, actor_id: Optional[int]
- ) -> bool:
+ def remove_item_from_cart(self, conn: Connection, *, cart_item_id: int, actor_id: Optional[int]) -> bool:
"""
Removes a specific item from the cart by its CartItemID.
Returns True if an item was deleted, False otherwise.
@@ -298,9 +303,7 @@ def remove_item_from_cart(
logger.error(f"Error removing CartItemID {cart_item_id} from cart: {e}")
raise e
- def clear_cart_for_user(
- self, conn: Connection, *, user_id: int, actor_id: Optional[int]
- ) -> int:
+ def clear_cart_for_user(self, conn: Connection, *, user_id: int, actor_id: Optional[int]) -> int:
"""
Removes all items from a specific user's cart.
Returns the number of items deleted.
diff --git a/src/backend/app/crud/category_crud.py b/src/backend/app/crud/category_crud.py
index 2e72cfd..49975ab 100644
--- a/src/backend/app/crud/category_crud.py
+++ b/src/backend/app/crud/category_crud.py
@@ -29,21 +29,16 @@ def _set_actor_session_variable(conn: Connection, actor_id: Optional[int]):
:param actor_id: 操作者ID
"""
if actor_id is not None:
- conn.execute(
- text("SET @actor_id = :actor_id"),
- {"actor_id": actor_id}
- )
+ conn.execute(text("SET @actor_id = :actor_id"), {"actor_id": actor_id})
else:
- conn.execute(
- text("SET @actor_id = NULL")
- )
+ conn.execute(text("SET @actor_id = NULL"))
def get_category_by_id(
- self,
- conn: Connection,
- *,
- category_id: int,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ category_id: int,
+ actor_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
根据ID获取分类信息
@@ -54,25 +49,24 @@ def get_category_by_id(
"""
self._set_actor_session_variable(conn, actor_id)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
CategoryID, CategoryName, CategoryDescription, ParentCategoryID
FROM {self.table_name} WHERE CategoryID = :category_id
- """)
+ """
+ )
- result = conn.execute(
- select_stmt,
- {"category_id": category_id}
- ).fetchone()
+ result = conn.execute(select_stmt, {"category_id": category_id}).fetchone()
return dict(result._mapping) if result else None
def get_categories(
- self,
- conn: Connection,
- *,
- parent_id: Optional[int] = None,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ parent_id: Optional[int] = None,
+ actor_id: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
获取分类列表,可指定父分类ID获取子分类
@@ -85,35 +79,36 @@ def get_categories(
if parent_id is None:
# 获取所有一级分类(没有父分类的分类)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
CategoryID, CategoryName, CategoryDescription, ParentCategoryID
FROM {self.table_name}
WHERE ParentCategoryID IS NULL
ORDER BY CategoryID
- """)
+ """
+ )
results = conn.execute(select_stmt).fetchall()
else:
# 获取指定父分类的子分类
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
CategoryID, CategoryName, CategoryDescription, ParentCategoryID
FROM {self.table_name}
WHERE ParentCategoryID = :parent_id
ORDER BY CategoryID
- """)
- results = conn.execute(
- select_stmt,
- {"parent_id": parent_id}
- ).fetchall()
+ """
+ )
+ results = conn.execute(select_stmt, {"parent_id": parent_id}).fetchall()
return [dict(row._mapping) for row in results]
def get_all_categories(
- self,
- conn: Connection,
- *,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ actor_id: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
获取所有分类
@@ -123,24 +118,26 @@ def get_all_categories(
"""
self._set_actor_session_variable(conn, actor_id)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
CategoryID, CategoryName, CategoryDescription, ParentCategoryID
FROM {self.table_name}
ORDER BY ParentCategoryID IS NULL DESC, CategoryID
- """)
+ """
+ )
results = conn.execute(select_stmt).fetchall()
return [dict(row._mapping) for row in results]
def create_category(
- self,
- conn: Connection,
- *,
- category_name: str,
- category_description: Optional[str] = None,
- parent_category_id: Optional[int] = None,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ category_name: str,
+ category_description: Optional[str] = None,
+ parent_category_id: Optional[int] = None,
+ actor_id: Optional[int] = None,
) -> Dict[str, Any]:
"""
创建新分类
@@ -155,28 +152,26 @@ def create_category(
# 验证父分类是否存在
if parent_category_id is not None:
- parent_category = self.get_category_by_id(
- conn,
- category_id=parent_category_id,
- actor_id=actor_id
- )
+ parent_category = self.get_category_by_id(conn, category_id=parent_category_id, actor_id=actor_id)
if not parent_category:
raise ValueError(f"父分类ID {parent_category_id} 不存在")
- insert_stmt = text(f"""
+ insert_stmt = text(
+ f"""
INSERT INTO {self.table_name}
(CategoryName, CategoryDescription, ParentCategoryID)
VALUES
(:category_name, :category_description, :parent_category_id)
- """)
+ """
+ )
conn.execute(
insert_stmt,
{
"category_name": category_name,
"category_description": category_description,
- "parent_category_id": parent_category_id
- }
+ "parent_category_id": parent_category_id,
+ },
)
# 获取自增ID
@@ -187,12 +182,12 @@ def create_category(
return self.get_category_by_id(conn, category_id=category_id, actor_id=actor_id)
def update_category(
- self,
- conn: Connection,
- *,
- category_id: int,
- update_data: Dict[str, Any],
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ category_id: int,
+ update_data: Dict[str, Any],
+ actor_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
更新分类信息
@@ -210,21 +205,17 @@ def update_category(
return None
# 检查父分类是否存在
- if 'parentcategoryid' in update_data and update_data['parentcategoryid'] is not None:
- parent_id = update_data['parentcategoryid']
+ if "parentcategoryid" in update_data and update_data["parentcategoryid"] is not None:
+ parent_id = update_data["parentcategoryid"]
# 不能将分类的父分类设置为自己
if parent_id == category_id:
raise ValueError("分类的父分类不能是自己")
-
+
# 检查父分类是否存在
- parent_category = self.get_category_by_id(
- conn,
- category_id=parent_id,
- actor_id=actor_id
- )
+ parent_category = self.get_category_by_id(conn, category_id=parent_id, actor_id=actor_id)
if not parent_category:
raise ValueError(f"父分类ID {parent_id} 不存在")
-
+
# 检查是否形成循环引用
self._check_cyclic_reference(conn, category_id, parent_id)
@@ -232,9 +223,7 @@ def update_category(
update_fields = []
params = {"category_id": category_id}
- valid_fields = [
- "CategoryName", "CategoryDescription", "ParentCategoryID"
- ]
+ valid_fields = ["CategoryName", "CategoryDescription", "ParentCategoryID"]
for field in valid_fields:
if field.lower() in update_data:
@@ -245,11 +234,13 @@ def update_category(
# 没有可更新的字段
return category
- update_stmt = text(f"""
+ update_stmt = text(
+ f"""
UPDATE {self.table_name}
SET {", ".join(update_fields)}
WHERE CategoryID = :category_id
- """)
+ """
+ )
conn.execute(update_stmt, params)
@@ -257,11 +248,11 @@ def update_category(
return self.get_category_by_id(conn, category_id=category_id, actor_id=actor_id)
def delete_category(
- self,
- conn: Connection,
- *,
- category_id: int,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ category_id: int,
+ actor_id: Optional[int] = None,
) -> bool:
"""
删除分类
@@ -283,38 +274,29 @@ def delete_category(
raise ValueError("该分类下有子分类,不能直接删除")
# 检查是否有商品引用此分类
- check_product_stmt = text("""
+ check_product_stmt = text(
+ """
SELECT COUNT(*) as count FROM Product WHERE CategoryID = :category_id
- """)
- result = conn.execute(
- check_product_stmt,
- {"category_id": category_id}
- ).fetchone()
-
+ """
+ )
+ result = conn.execute(check_product_stmt, {"category_id": category_id}).fetchone()
+
if result and result[0] > 0:
raise ValueError("该分类下有商品,不能删除")
# 执行删除
- delete_stmt = text(f"""
+ delete_stmt = text(
+ f"""
DELETE FROM {self.table_name}
WHERE CategoryID = :category_id
- """)
-
- conn.execute(
- delete_stmt,
- {
- "category_id": category_id
- }
+ """
)
+ conn.execute(delete_stmt, {"category_id": category_id})
+
return True
- def _check_cyclic_reference(
- self,
- conn: Connection,
- category_id: int,
- new_parent_id: int
- ):
+ def _check_cyclic_reference(self, conn: Connection, category_id: int, new_parent_id: int):
"""
检查更新分类的父分类是否会导致循环引用
:param conn: 数据库连接
@@ -325,36 +307,35 @@ def _check_cyclic_reference(
# 从新的父分类开始,向上查找所有祖先分类
current_id = new_parent_id
visited = set()
-
+
while current_id is not None:
if current_id in visited: # 发现循环
raise ValueError("更新父分类会导致循环引用")
-
+
if current_id == category_id: # 祖先中包含了当前分类,形成循环
raise ValueError("更新父分类会导致循环引用")
-
+
visited.add(current_id)
-
+
# 查找当前分类的父分类
- parent_query = text(f"""
+ parent_query = text(
+ f"""
SELECT ParentCategoryID FROM {self.table_name} WHERE CategoryID = :current_id
- """)
-
- result = conn.execute(
- parent_query,
- {"current_id": current_id}
- ).fetchone()
-
+ """
+ )
+
+ result = conn.execute(parent_query, {"current_id": current_id}).fetchone()
+
if result and result[0] is not None:
current_id = result[0]
else:
current_id = None # 已到达顶层分类
def get_category_tree(
- self,
- conn: Connection,
- *,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ actor_id: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
获取分类树结构
@@ -363,25 +344,22 @@ def get_category_tree(
:return: 树形结构的分类信息
"""
self._set_actor_session_variable(conn, actor_id)
-
+
# 获取所有分类
all_categories = self.get_all_categories(conn, actor_id=actor_id)
-
+
# 构建ID到分类的映射,确保每个分类都有Children字段
category_map = {category["CategoryID"]: dict(category, Children=[]) for category in all_categories}
-
+
# 构建树结构
for category_id, category in category_map.items():
parent_id = category.get("ParentCategoryID")
if parent_id is not None and parent_id in category_map:
category_map[parent_id]["Children"].append(category)
-
+
# 返回所有根分类
- root_categories = [
- category for category in category_map.values()
- if category.get("ParentCategoryID") is None
- ]
-
+ root_categories = [category for category in category_map.values() if category.get("ParentCategoryID") is None]
+
return root_categories
diff --git a/src/backend/app/crud/order_crud.py b/src/backend/app/crud/order_crud.py
index d4b7c17..ca59c7c 100644
--- a/src/backend/app/crud/order_crud.py
+++ b/src/backend/app/crud/order_crud.py
@@ -36,35 +36,30 @@ def _set_actor_session_variable(conn: Connection, actor_id: Optional[int]):
"""
try:
if actor_id is not None:
- conn.execute(
- text("SET @actor_id = :actor_id"),
- {"actor_id": actor_id}
- )
+ conn.execute(text("SET @actor_id = :actor_id"), {"actor_id": actor_id})
else:
- conn.execute(
- text("SET @actor_id = NULL")
- )
+ conn.execute(text("SET @actor_id = NULL"))
except Exception as e:
logger.error(f"Error setting actor session variable: {e}")
# Consider re-raising or handling
def create_order(
- self,
- conn: Connection,
- *,
- user_id: int,
- store_id: int,
- payment_transaction_id: int,
- order_status: OrderStatusEnum, # Initial status, e.g., PENDING_PAYMENT
- order_total_amount: Decimal,
- final_amount_for_this_order: Decimal,
- shipping_address_recipient_name: str,
- shipping_address_phone_number: str,
- shipping_address_full: str,
- discount_amount: Optional[Decimal] = Decimal("0.00"),
- shipping_fee: Optional[Decimal] = Decimal("0.00"),
- notes_by_user: Optional[str] = None,
- actor_id: Optional[int]
+ self,
+ conn: Connection,
+ *,
+ user_id: int,
+ store_id: int,
+ payment_transaction_id: int,
+ order_status: OrderStatusEnum, # Initial status, e.g., PENDING_PAYMENT
+ order_total_amount: Decimal,
+ final_amount_for_this_order: Decimal,
+ shipping_address_recipient_name: str,
+ shipping_address_phone_number: str,
+ shipping_address_full: str,
+ discount_amount: Optional[Decimal] = Decimal("0.00"),
+ shipping_fee: Optional[Decimal] = Decimal("0.00"),
+ notes_by_user: Optional[str] = None,
+ actor_id: Optional[int],
) -> Optional[Dict[str, Any]]:
"""
在数据库中创建一条新的订单记录。
@@ -91,7 +86,8 @@ def create_order(
)
self._set_actor_session_variable(conn, actor_id)
- insert_stmt = text(f"""
+ insert_stmt = text(
+ f"""
INSERT INTO {self.table_name} (
UserID, StoreID, PaymentTransactionID, OrderStatus,
OrderTotalAmount, DiscountAmount, ShippingFee, FinalAmountForThisOrder,
@@ -103,24 +99,28 @@ def create_order(
:ShippingAddress_RecipientName, :ShippingAddress_PhoneNumber, :ShippingAddress_Full,
:Notes_ByUser, UTC_TIMESTAMP(), UTC_TIMESTAMP()
)
- """)
+ """
+ )
# Using UTC_TIMESTAMP() for CreationTime and initial LastUpdatedDate for consistency
try:
- result = conn.execute(insert_stmt, {
- "UserID": user_id,
- "StoreID": store_id,
- "PaymentTransactionID": payment_transaction_id,
- "OrderStatus": order_status.value, # Use enum value
- "OrderTotalAmount": order_total_amount,
- "DiscountAmount": discount_amount,
- "ShippingFee": shipping_fee,
- "FinalAmountForThisOrder": final_amount_for_this_order,
- "ShippingAddress_RecipientName": shipping_address_recipient_name,
- "ShippingAddress_PhoneNumber": shipping_address_phone_number,
- "ShippingAddress_Full": shipping_address_full,
- "Notes_ByUser": notes_by_user
- })
+ result = conn.execute(
+ insert_stmt,
+ {
+ "UserID": user_id,
+ "StoreID": store_id,
+ "PaymentTransactionID": payment_transaction_id,
+ "OrderStatus": order_status.value, # Use enum value
+ "OrderTotalAmount": order_total_amount,
+ "DiscountAmount": discount_amount,
+ "ShippingFee": shipping_fee,
+ "FinalAmountForThisOrder": final_amount_for_this_order,
+ "ShippingAddress_RecipientName": shipping_address_recipient_name,
+ "ShippingAddress_PhoneNumber": shipping_address_phone_number,
+ "ShippingAddress_Full": shipping_address_full,
+ "Notes_ByUser": notes_by_user,
+ },
+ )
new_order_id = result.lastrowid
if new_order_id is None:
@@ -132,18 +132,14 @@ def create_order(
)
return self.get_order_by_id(conn, order_id=new_order_id, actor_id=actor_id)
except exc.IntegrityError as e:
- logger.error(
- f"Integrity error creating order for UserID {user_id}, StoreID {store_id}: {e}"
- )
+ logger.error(f"Integrity error creating order for UserID {user_id}, StoreID {store_id}: {e}")
return None
except Exception as e:
- logger.error(
- f"Unexpected error creating order for UserID {user_id}, StoreID {store_id}: {e}"
- )
+ logger.error(f"Unexpected error creating order for UserID {user_id}, StoreID {store_id}: {e}")
return None
def get_order_by_id(
- self, conn: Connection, *, order_id: int, actor_id: Optional[int] = None
+ self, conn: Connection, *, order_id: int, actor_id: Optional[int] = None
) -> Optional[Dict[str, Any]]:
"""
根据 OrderID 检索特定的订单信息。
@@ -151,7 +147,8 @@ def get_order_by_id(
# logger.trace(f"Getting order by OrderID {order_id}, ActorID {actor_id}")
# _set_actor_session_variable 通常不需要用于只读操作
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT OrderID, UserID, StoreID, PaymentTransactionID, OrderStatus,
OrderTotalAmount, DiscountAmount, ShippingFee, FinalAmountForThisOrder,
ShippingAddress_RecipientName, ShippingAddress_PhoneNumber, ShippingAddress_Full,
@@ -159,7 +156,8 @@ def get_order_by_id(
ShippingTime, DeliveryTime, CompletionTime, LastUpdatedDate
FROM {self.table_name}
WHERE OrderID = :OrderID
- """)
+ """
+ )
try:
result = conn.execute(select_stmt, {"OrderID": order_id}).fetchone()
return dict(result._mapping) if result else None # type: ignore
@@ -168,19 +166,14 @@ def get_order_by_id(
return None
def get_orders_by_user_id(
- self,
- conn: Connection,
- *,
- user_id: int,
- offset: int = 0,
- limit: int = 20,
- actor_id: Optional[int] = None
+ self, conn: Connection, *, user_id: int, offset: int = 0, limit: int = 20, actor_id: Optional[int] = None
) -> List[Dict[str, Any]]:
"""
检索指定用户的所有订单列表,支持分页。
"""
# logger.trace(f"Getting orders for UserID {user_id}, Offset {offset}, Limit {limit}, ActorID {actor_id}")
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT OrderID, UserID, StoreID, PaymentTransactionID, OrderStatus,
OrderTotalAmount, DiscountAmount, ShippingFee, FinalAmountForThisOrder,
ShippingAddress_RecipientName, ShippingAddress_PhoneNumber, ShippingAddress_Full,
@@ -190,32 +183,24 @@ def get_orders_by_user_id(
WHERE UserID = :UserID
ORDER BY CreationTime DESC
LIMIT :Limit OFFSET :Offset
- """)
+ """
+ )
try:
- results = conn.execute(select_stmt, {
- "UserID": user_id,
- "Limit": limit,
- "Offset": offset
- }).fetchall()
+ results = conn.execute(select_stmt, {"UserID": user_id, "Limit": limit, "Offset": offset}).fetchall()
return [dict(row._mapping) for row in results] # type: ignore
except Exception as e:
logger.error(f"Error getting orders for UserID {user_id}: {e}")
return []
def get_orders_by_store_id(
- self,
- conn: Connection,
- *,
- store_id: int,
- offset: int = 0,
- limit: int = 20,
- actor_id: Optional[int] = None
+ self, conn: Connection, *, store_id: int, offset: int = 0, limit: int = 20, actor_id: Optional[int] = None
) -> List[Dict[str, Any]]:
"""
检索指定店铺的所有订单列表,支持分页。 (通常由商家或管理员使用)
"""
# logger.trace(f"Getting orders for StoreID {store_id}, Offset {offset}, Limit {limit}, ActorID {actor_id}")
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT OrderID, UserID, StoreID, PaymentTransactionID, OrderStatus,
OrderTotalAmount, DiscountAmount, ShippingFee, FinalAmountForThisOrder,
ShippingAddress_RecipientName, ShippingAddress_PhoneNumber, ShippingAddress_Full,
@@ -225,26 +210,24 @@ def get_orders_by_store_id(
WHERE StoreID = :StoreID
ORDER BY CreationTime DESC
LIMIT :Limit OFFSET :Offset
- """)
+ """
+ )
try:
- results = conn.execute(select_stmt, {
- "StoreID": store_id,
- "Limit": limit,
- "Offset": offset
- }).fetchall()
+ results = conn.execute(select_stmt, {"StoreID": store_id, "Limit": limit, "Offset": offset}).fetchall()
return [dict(row._mapping) for row in results] # type: ignore
except Exception as e:
logger.error(f"Error getting orders for StoreID {store_id}: {e}")
return []
def get_orders_by_payment_transaction_id(
- self, conn: Connection, *, payment_transaction_id: int, actor_id: Optional[int] = None
+ self, conn: Connection, *, payment_transaction_id: int, actor_id: Optional[int] = None
) -> List[Dict[str, Any]]:
"""
检索与特定支付事务ID关联的所有订单。
"""
# logger.trace(f"Getting orders for PaymentTransactionID {payment_transaction_id}, ActorID {actor_id}")
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT OrderID, UserID, StoreID, PaymentTransactionID, OrderStatus,
OrderTotalAmount, DiscountAmount, ShippingFee, FinalAmountForThisOrder,
ShippingAddress_RecipientName, ShippingAddress_PhoneNumber, ShippingAddress_Full,
@@ -253,7 +236,8 @@ def get_orders_by_payment_transaction_id(
FROM {self.table_name}
WHERE PaymentTransactionID = :PaymentTransactionID
ORDER BY OrderID ASC
- """)
+ """
+ )
try:
results = conn.execute(select_stmt, {"PaymentTransactionID": payment_transaction_id}).fetchall()
return [dict(row._mapping) for row in results] # type: ignore
@@ -262,18 +246,18 @@ def get_orders_by_payment_transaction_id(
return []
def update_order_status(
- self,
- conn: Connection,
- *,
- order_id: int,
- new_status: OrderStatusEnum,
- actor_id: int,
- payment_confirmation_time: Optional[datetime.datetime] = None,
- shipping_time: Optional[datetime.datetime] = None,
- delivery_time: Optional[datetime.datetime] = None,
- completion_time: Optional[datetime.datetime] = None,
- notes_by_actor: Optional[str] = None, # Can be Notes_ByUser or Notes_ByMerchant
- is_admin_or_merchant_action: bool = False
+ self,
+ conn: Connection,
+ *,
+ order_id: int,
+ new_status: OrderStatusEnum,
+ actor_id: int,
+ payment_confirmation_time: Optional[datetime.datetime] = None,
+ shipping_time: Optional[datetime.datetime] = None,
+ delivery_time: Optional[datetime.datetime] = None,
+ completion_time: Optional[datetime.datetime] = None,
+ notes_by_actor: Optional[str] = None, # Can be Notes_ByUser or Notes_ByMerchant
+ is_admin_or_merchant_action: bool = False,
) -> Optional[Dict[str, Any]]:
"""
更新订单的状态及相关的状态时间戳。
@@ -291,9 +275,7 @@ def update_order_status(
:param is_admin_or_merchant_action: (可选) 标记是否为管理员或商家操作,以决定更新哪个备注字段。
:return: 更新成功后的订单信息字典,如果订单未找到或更新失败则返回 None。
"""
- logger.info(
- f"ActorID {actor_id} attempting to update OrderID {order_id} to status {new_status.value}"
- )
+ logger.info(f"ActorID {actor_id} attempting to update OrderID {order_id} to status {new_status.value}")
self._set_actor_session_variable(conn, actor_id)
set_clauses: List[str] = ["OrderStatus = :OrderStatus"]
@@ -330,7 +312,8 @@ def update_order_status(
result = conn.execute(text(update_stmt_str), params_to_update)
if result.rowcount == 0:
logger.warning(
- f"OrderID {order_id} not found for status update, or status already set, by ActorID {actor_id}.")
+ f"OrderID {order_id} not found for status update, or status already set, by ActorID {actor_id}."
+ )
return None
logger.info(f"Status updated for OrderID {order_id} to {new_status.value} by ActorID {actor_id}.")
diff --git a/src/backend/app/crud/order_item_crud.py b/src/backend/app/crud/order_item_crud.py
index 0e3dbd0..289501b 100644
--- a/src/backend/app/crud/order_item_crud.py
+++ b/src/backend/app/crud/order_item_crud.py
@@ -32,31 +32,26 @@ def _set_actor_session_variable(conn: Connection, actor_id: Optional[int]):
"""
try:
if actor_id is not None:
- conn.execute(
- text("SET @actor_id = :actor_id"),
- {"actor_id": actor_id}
- )
+ conn.execute(text("SET @actor_id = :actor_id"), {"actor_id": actor_id})
else:
- conn.execute(
- text("SET @actor_id = NULL")
- )
+ conn.execute(text("SET @actor_id = NULL"))
except Exception as e:
logger.error(f"Error setting actor session variable: {e}")
# Consider re-raising or handling more gracefully depending on requirements
def create_order_item(
- self,
- conn: Connection,
- *,
- order_id: int,
- product_id: int,
- store_id: int,
- quantity: int,
- price_at_purchase: Decimal,
- product_name_at_purchase: str,
- product_image_url_at_purchase: Optional[str],
- subtotal: Decimal,
- actor_id: Optional[int] # The user/system ID performing this action
+ self,
+ conn: Connection,
+ *,
+ order_id: int,
+ product_id: int,
+ store_id: int,
+ quantity: int,
+ price_at_purchase: Decimal,
+ product_name_at_purchase: str,
+ product_image_url_at_purchase: Optional[str],
+ subtotal: Decimal,
+ actor_id: Optional[int], # The user/system ID performing this action
) -> Optional[Dict[str, Any]]:
"""
为指定的订单创建一个新的订单项目。
@@ -78,7 +73,8 @@ def create_order_item(
)
self._set_actor_session_variable(conn, actor_id)
- insert_stmt = text(f"""
+ insert_stmt = text(
+ f"""
INSERT INTO {self.table_name} (
OrderID, ProductID, StoreID, Quantity, PriceAtPurchase,
ProductNameAtPurchase, ProductImageURLAtPurchase, Subtotal
@@ -86,19 +82,23 @@ def create_order_item(
:OrderID, :ProductID, :StoreID, :Quantity, :PriceAtPurchase,
:ProductNameAtPurchase, :ProductImageURLAtPurchase, :Subtotal
)
- """)
+ """
+ )
try:
- result = conn.execute(insert_stmt, {
- "OrderID": order_id,
- "ProductID": product_id,
- "StoreID": store_id,
- "Quantity": quantity,
- "PriceAtPurchase": price_at_purchase,
- "ProductNameAtPurchase": product_name_at_purchase,
- "ProductImageURLAtPurchase": product_image_url_at_purchase,
- "Subtotal": subtotal
- })
+ result = conn.execute(
+ insert_stmt,
+ {
+ "OrderID": order_id,
+ "ProductID": product_id,
+ "StoreID": store_id,
+ "Quantity": quantity,
+ "PriceAtPurchase": price_at_purchase,
+ "ProductNameAtPurchase": product_name_at_purchase,
+ "ProductImageURLAtPurchase": product_image_url_at_purchase,
+ "Subtotal": subtotal,
+ },
+ )
new_order_item_id = result.lastrowid
if new_order_item_id is None:
@@ -118,18 +118,14 @@ def create_order_item(
# 获取并返回新创建的订单项目的完整信息
return self.get_order_item_by_id(conn, order_item_id=new_order_item_id, actor_id=actor_id)
except exc.IntegrityError as e:
- logger.error(
- f"Integrity error creating order item for OrderID {order_id}, ProductID {product_id}: {e}"
- )
+ logger.error(f"Integrity error creating order item for OrderID {order_id}, ProductID {product_id}: {e}")
return None
except Exception as e:
- logger.error(
- f"Unexpected error creating order item for OrderID {order_id}, ProductID {product_id}: {e}"
- )
+ logger.error(f"Unexpected error creating order item for OrderID {order_id}, ProductID {product_id}: {e}")
return None
def get_order_item_by_id(
- self, conn: Connection, *, order_item_id: int, actor_id: Optional[int] = None
+ self, conn: Connection, *, order_item_id: int, actor_id: Optional[int] = None
) -> Optional[Dict[str, Any]]:
"""
根据 OrderItemID 检索特定的订单项目信息。
@@ -143,12 +139,14 @@ def get_order_item_by_id(
# logger.trace(f"Getting order item by OrderItemID {order_item_id}, ActorID {actor_id}")
# _set_actor_session_variable 通常不需要用于只读操作,除非触发器也用于 SELECT
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT OrderItemID, OrderID, ProductID, StoreID, Quantity, PriceAtPurchase,
ProductNameAtPurchase, ProductImageURLAtPurchase, Subtotal
FROM {self.table_name}
WHERE OrderItemID = :OrderItemID
- """)
+ """
+ )
try:
result = conn.execute(select_stmt, {"OrderItemID": order_item_id}).fetchone()
return dict(result._mapping) if result else None # type: ignore
@@ -157,7 +155,7 @@ def get_order_item_by_id(
return None
def get_order_items_by_order_id(
- self, conn: Connection, *, order_id: int, actor_id: Optional[int] = None
+ self, conn: Connection, *, order_id: int, actor_id: Optional[int] = None
) -> List[Dict[str, Any]]:
"""
检索指定订单的所有订单项目列表。
@@ -169,13 +167,15 @@ def get_order_items_by_order_id(
:return: 包含该订单所有项目信息的字典列表,如果订单没有项目则返回空列表。
"""
# logger.trace(f"Getting all order items for OrderID {order_id}, ActorID {actor_id}")
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT OrderItemID, OrderID, ProductID, StoreID, Quantity, PriceAtPurchase,
ProductNameAtPurchase, ProductImageURLAtPurchase, Subtotal
FROM {self.table_name}
WHERE OrderID = :OrderID
ORDER BY OrderItemID ASC -- 通常按项目ID排序
- """)
+ """
+ )
try:
results = conn.execute(select_stmt, {"OrderID": order_id}).fetchall()
return [dict(row._mapping) for row in results] # type: ignore
diff --git a/src/backend/app/crud/payment_transaction_crud.py b/src/backend/app/crud/payment_transaction_crud.py
index dc47ac7..226bfbf 100644
--- a/src/backend/app/crud/payment_transaction_crud.py
+++ b/src/backend/app/crud/payment_transaction_crud.py
@@ -36,27 +36,22 @@ def _set_actor_session_variable(conn: Connection, actor_id: Optional[int]):
"""
try:
if actor_id is not None:
- conn.execute(
- text("SET @actor_id = :actor_id"),
- {"actor_id": actor_id}
- )
+ conn.execute(text("SET @actor_id = :actor_id"), {"actor_id": actor_id})
else:
- conn.execute(
- text("SET @actor_id = NULL")
- )
+ conn.execute(text("SET @actor_id = NULL"))
except Exception as e:
logger.error(f"Error setting actor session variable: {e}")
# Consider re-raising or handling
def create_payment_transaction(
- self,
- conn: Connection,
- *,
- user_id: int,
- total_amount: Decimal,
- payment_method: str,
- status: str, # Should be a value from PaymentTransactionStatusEnum, e.g., "PENDING"
- actor_id: Optional[int]
+ self,
+ conn: Connection,
+ *,
+ user_id: int,
+ total_amount: Decimal,
+ payment_method: str,
+ status: str, # Should be a value from PaymentTransactionStatusEnum, e.g., "PENDING"
+ actor_id: Optional[int],
) -> Optional[Dict[str, Any]]:
"""
在数据库中创建一条新的支付事务记录。
@@ -77,7 +72,8 @@ def create_payment_transaction(
self._set_actor_session_variable(conn, actor_id)
- insert_stmt = text(f"""
+ insert_stmt = text(
+ f"""
INSERT INTO {self.table_name} (
UserID, TotalAmount, PaymentMethod, Status,
CreationTime, LastUpdatedDate
@@ -85,16 +81,15 @@ def create_payment_transaction(
:UserID, :TotalAmount, :PaymentMethod, :Status,
UTC_TIMESTAMP(), UTC_TIMESTAMP()
)
- """)
+ """
+ )
# Using UTC_TIMESTAMP() for CreationTime and initial LastUpdatedDate
try:
- result = conn.execute(insert_stmt, {
- "UserID": user_id,
- "TotalAmount": total_amount,
- "PaymentMethod": payment_method,
- "Status": status
- })
+ result = conn.execute(
+ insert_stmt,
+ {"UserID": user_id, "TotalAmount": total_amount, "PaymentMethod": payment_method, "Status": status},
+ )
new_transaction_id = result.lastrowid
if new_transaction_id is None:
@@ -104,34 +99,33 @@ def create_payment_transaction(
logger.info(
f"Payment transaction created with PaymentTransactionID {new_transaction_id} for UserID {user_id} by ActorID {actor_id}."
)
- return self.get_payment_transaction_by_id(conn, payment_transaction_id=new_transaction_id,
- actor_id=actor_id)
- except exc.IntegrityError as e:
- logger.error(
- f"Integrity error creating payment transaction for UserID {user_id}: {e}"
+ return self.get_payment_transaction_by_id(
+ conn, payment_transaction_id=new_transaction_id, actor_id=actor_id
)
+ except exc.IntegrityError as e:
+ logger.error(f"Integrity error creating payment transaction for UserID {user_id}: {e}")
return None
except Exception as e:
- logger.error(
- f"Unexpected error creating payment transaction for UserID {user_id}: {e}"
- )
+ logger.error(f"Unexpected error creating payment transaction for UserID {user_id}: {e}")
return None
def get_payment_transaction_by_id(
- self, conn: Connection, *, payment_transaction_id: int, actor_id: Optional[int] = None
+ self, conn: Connection, *, payment_transaction_id: int, actor_id: Optional[int] = None
) -> Optional[Dict[str, Any]]:
"""
根据 PaymentTransactionID 检索特定的支付事务信息。
"""
# logger.trace(f"Getting payment transaction by PaymentTransactionID {payment_transaction_id}, ActorID {actor_id}")
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT PaymentTransactionID, UserID, TotalAmount, PaymentMethod,
ExternalGatewayTransactionID, Status, CreationTime,
CompletionTime, LastUpdatedDate
FROM {self.table_name}
WHERE PaymentTransactionID = :PaymentTransactionID
- """)
+ """
+ )
try:
result = conn.execute(select_stmt, {"PaymentTransactionID": payment_transaction_id}).fetchone()
return dict(result._mapping) if result else None # type: ignore
@@ -140,19 +134,14 @@ def get_payment_transaction_by_id(
return None
def get_payment_transactions_by_user_id(
- self,
- conn: Connection,
- *,
- user_id: int,
- offset: int = 0,
- limit: int = 20,
- actor_id: Optional[int] = None
+ self, conn: Connection, *, user_id: int, offset: int = 0, limit: int = 20, actor_id: Optional[int] = None
) -> List[Dict[str, Any]]:
"""
检索指定用户的所有支付事务列表,支持分页。
"""
# logger.trace(f"Getting payment transactions for UserID {user_id}, Offset {offset}, Limit {limit}, ActorID {actor_id}")
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT PaymentTransactionID, UserID, TotalAmount, PaymentMethod,
ExternalGatewayTransactionID, Status, CreationTime,
CompletionTime, LastUpdatedDate
@@ -160,27 +149,24 @@ def get_payment_transactions_by_user_id(
WHERE UserID = :UserID
ORDER BY CreationTime DESC
LIMIT :Limit OFFSET :Offset
- """)
+ """
+ )
try:
- results = conn.execute(select_stmt, {
- "UserID": user_id,
- "Limit": limit,
- "Offset": offset
- }).fetchall()
+ results = conn.execute(select_stmt, {"UserID": user_id, "Limit": limit, "Offset": offset}).fetchall()
return [dict(row._mapping) for row in results] # type: ignore
except Exception as e:
logger.error(f"Error getting payment transactions for UserID {user_id}: {e}")
return []
def update_payment_transaction_status(
- self,
- conn: Connection,
- *,
- payment_transaction_id: int,
- new_status: str, # Should be a value from PaymentTransactionStatusEnum
- actor_id: int,
- external_gateway_transaction_id: Optional[str] = None,
- completion_time: Optional[datetime.datetime] = None # Should be aware UTC if provided
+ self,
+ conn: Connection,
+ *,
+ payment_transaction_id: int,
+ new_status: str, # Should be a value from PaymentTransactionStatusEnum
+ actor_id: int,
+ external_gateway_transaction_id: Optional[str] = None,
+ completion_time: Optional[datetime.datetime] = None, # Should be aware UTC if provided
) -> Optional[Dict[str, Any]]:
"""
更新支付事务的状态,并可选地更新外部网关ID和完成时间。
@@ -200,10 +186,7 @@ def update_payment_transaction_status(
self._set_actor_session_variable(conn, actor_id)
set_clauses: List[str] = ["Status = :Status"]
- params_to_update: Dict[str, Any] = {
- "PaymentTransactionID_param": payment_transaction_id,
- "Status": new_status
- }
+ params_to_update: Dict[str, Any] = {"PaymentTransactionID_param": payment_transaction_id, "Status": new_status}
if external_gateway_transaction_id is not None:
set_clauses.append("ExternalGatewayTransactionID = :ExternalGatewayTransactionID")
@@ -214,14 +197,16 @@ def update_payment_transaction_status(
# 如果数据库列是 naive DATETIME,确保传入 naive UTC datetime
if completion_time.tzinfo is not None:
params_to_update["CompletionTime"] = completion_time.astimezone(datetime.timezone.utc).replace(
- tzinfo=None)
+ tzinfo=None
+ )
else: # 假设传入的 naive datetime 已经是 UTC
params_to_update["CompletionTime"] = completion_time
if not set_clauses: # Should not happen as Status is always updated
logger.warning(f"No fields to update for PaymentTransactionID {payment_transaction_id} status update.")
- return self.get_payment_transaction_by_id(conn, payment_transaction_id=payment_transaction_id,
- actor_id=actor_id)
+ return self.get_payment_transaction_by_id(
+ conn, payment_transaction_id=payment_transaction_id, actor_id=actor_id
+ )
update_stmt_str = f"UPDATE {self.table_name} SET {', '.join(set_clauses)} WHERE PaymentTransactionID = :PaymentTransactionID_param"
@@ -236,8 +221,9 @@ def update_payment_transaction_status(
logger.info(
f"Status updated for PaymentTransactionID {payment_transaction_id} to {new_status} by ActorID {actor_id}."
)
- return self.get_payment_transaction_by_id(conn, payment_transaction_id=payment_transaction_id,
- actor_id=actor_id)
+ return self.get_payment_transaction_by_id(
+ conn, payment_transaction_id=payment_transaction_id, actor_id=actor_id
+ )
except Exception as e:
logger.error(f"Error updating status for PaymentTransactionID {payment_transaction_id}: {e}")
return None
diff --git a/src/backend/app/crud/product_change_request_crud.py b/src/backend/app/crud/product_change_request_crud.py
index 8b0bf41..60dff75 100644
--- a/src/backend/app/crud/product_change_request_crud.py
+++ b/src/backend/app/crud/product_change_request_crud.py
@@ -9,7 +9,7 @@
class ProductChangeRequestCRUD:
__instance: Optional["ProductChangeRequestCRUD"] = None
-
+
@classmethod
def get_instance(cls) -> "ProductChangeRequestCRUD":
"""
@@ -19,10 +19,10 @@ def get_instance(cls) -> "ProductChangeRequestCRUD":
if cls.__instance is None:
cls.__instance = ProductChangeRequestCRUD()
return cls.__instance
-
+
def __init__(self):
self.table_name = "ProductChangeRequest"
-
+
@staticmethod
def _set_actor_session_variable(conn: Connection, actor_id: Optional[int]):
"""
@@ -31,21 +31,16 @@ def _set_actor_session_variable(conn: Connection, actor_id: Optional[int]):
:param actor_id: 操作者ID
"""
if actor_id is not None:
- conn.execute(
- text("SET @actor_id = :actor_id"),
- {"actor_id": actor_id}
- )
+ conn.execute(text("SET @actor_id = :actor_id"), {"actor_id": actor_id})
else:
- conn.execute(
- text("SET @actor_id = NULL")
- )
-
+ conn.execute(text("SET @actor_id = NULL"))
+
def get_change_request_by_id(
- self,
- conn: Connection,
- *,
- request_id: int,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ request_id: int,
+ actor_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
根据ID获取商品变更请求信息
@@ -56,7 +51,8 @@ def get_change_request_by_id(
"""
self._set_actor_session_variable(conn, actor_id)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
ChangeRequestID, ProductID, MerchantUserID, StoreID,
RequestType, ProposedData_JSON, Status, SubmitterNotes,
@@ -64,18 +60,16 @@ def get_change_request_by_id(
CreationTime, LastUpdatedDate
FROM {self.table_name}
WHERE ChangeRequestID = :request_id
- """)
+ """
+ )
- result = conn.execute(
- select_stmt,
- {"request_id": request_id}
- ).fetchone()
+ result = conn.execute(select_stmt, {"request_id": request_id}).fetchone()
if not result:
return None
-
+
result_dict = dict(result._mapping)
-
+
# 处理JSON字段
if result_dict.get("ProposedData_JSON"):
# 检查是否已经是字典
@@ -87,16 +81,16 @@ def get_change_request_by_id(
result_dict["ProposedData_JSON"] = {}
return result_dict
-
+
def get_change_requests_by_product_id(
- self,
- conn: Connection,
- *,
- product_id: int,
- status: Optional[str] = None,
- limit: int = 100,
- offset: int = 0,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ product_id: int,
+ status: Optional[str] = None,
+ limit: int = 100,
+ offset: int = 0,
+ actor_id: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
获取指定商品的变更请求列表
@@ -111,19 +105,16 @@ def get_change_requests_by_product_id(
self._set_actor_session_variable(conn, actor_id)
conditions = ["ProductID = :product_id"]
- params = {
- "product_id": product_id,
- "limit": limit,
- "offset": offset
- }
-
+ params = {"product_id": product_id, "limit": limit, "offset": offset}
+
if status:
conditions.append("Status = :status")
params["status"] = status
-
+
where_clause = " AND ".join(conditions)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
ChangeRequestID, ProductID, MerchantUserID, StoreID,
RequestType, ProposedData_JSON, Status, SubmitterNotes,
@@ -133,14 +124,15 @@ def get_change_requests_by_product_id(
WHERE {where_clause}
ORDER BY LastUpdatedDate DESC
LIMIT :limit OFFSET :offset
- """)
+ """
+ )
results = conn.execute(select_stmt, params).fetchall()
-
+
result_list = []
for row in results:
row_dict = dict(row._mapping)
-
+
# 处理JSON字段
if row_dict.get("ProposedData_JSON"):
if not isinstance(row_dict["ProposedData_JSON"], dict):
@@ -149,20 +141,20 @@ def get_change_requests_by_product_id(
except Exception as e:
logger.error(f"解析ProposedData_JSON字段失败: {e}")
row_dict["ProposedData_JSON"] = {}
-
+
result_list.append(row_dict)
-
+
return result_list
-
+
def get_change_requests_by_store_id(
- self,
- conn: Connection,
- *,
- store_id: int,
- status: Optional[str] = None,
- limit: int = 100,
- offset: int = 0,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ store_id: int,
+ status: Optional[str] = None,
+ limit: int = 100,
+ offset: int = 0,
+ actor_id: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
获取指定店铺的商品变更请求列表
@@ -177,19 +169,16 @@ def get_change_requests_by_store_id(
self._set_actor_session_variable(conn, actor_id)
conditions = ["StoreID = :store_id"]
- params = {
- "store_id": store_id,
- "limit": limit,
- "offset": offset
- }
-
+ params = {"store_id": store_id, "limit": limit, "offset": offset}
+
if status:
conditions.append("Status = :status")
params["status"] = status
-
+
where_clause = " AND ".join(conditions)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
ChangeRequestID, ProductID, MerchantUserID, StoreID,
RequestType, ProposedData_JSON, Status, SubmitterNotes,
@@ -199,14 +188,15 @@ def get_change_requests_by_store_id(
WHERE {where_clause}
ORDER BY LastUpdatedDate DESC
LIMIT :limit OFFSET :offset
- """)
+ """
+ )
results = conn.execute(select_stmt, params).fetchall()
-
+
result_list = []
for row in results:
row_dict = dict(row._mapping)
-
+
# 处理JSON字段
if row_dict.get("ProposedData_JSON"):
if not isinstance(row_dict["ProposedData_JSON"], dict):
@@ -215,20 +205,20 @@ def get_change_requests_by_store_id(
except Exception as e:
logger.error(f"解析ProposedData_JSON字段失败: {e}")
row_dict["ProposedData_JSON"] = {}
-
+
result_list.append(row_dict)
-
+
return result_list
-
+
def get_change_requests_by_merchant_id(
- self,
- conn: Connection,
- *,
- merchant_id: int,
- status: Optional[str] = None,
- limit: int = 100,
- offset: int = 0,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ merchant_id: int,
+ status: Optional[str] = None,
+ limit: int = 100,
+ offset: int = 0,
+ actor_id: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
获取指定商家的变更请求列表
@@ -243,19 +233,16 @@ def get_change_requests_by_merchant_id(
self._set_actor_session_variable(conn, actor_id)
conditions = ["MerchantUserID = :merchant_id"]
- params = {
- "merchant_id": merchant_id,
- "limit": limit,
- "offset": offset
- }
-
+ params = {"merchant_id": merchant_id, "limit": limit, "offset": offset}
+
if status:
conditions.append("Status = :status")
params["status"] = status
-
+
where_clause = " AND ".join(conditions)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
ChangeRequestID, ProductID, MerchantUserID, StoreID,
RequestType, ProposedData_JSON, Status, SubmitterNotes,
@@ -265,14 +252,15 @@ def get_change_requests_by_merchant_id(
WHERE {where_clause}
ORDER BY LastUpdatedDate DESC
LIMIT :limit OFFSET :offset
- """)
+ """
+ )
results = conn.execute(select_stmt, params).fetchall()
-
+
result_list = []
for row in results:
row_dict = dict(row._mapping)
-
+
# 处理JSON字段
if row_dict.get("ProposedData_JSON"):
if not isinstance(row_dict["ProposedData_JSON"], dict):
@@ -281,18 +269,18 @@ def get_change_requests_by_merchant_id(
except Exception as e:
logger.error(f"解析ProposedData_JSON字段失败: {e}")
row_dict["ProposedData_JSON"] = {}
-
+
result_list.append(row_dict)
-
+
return result_list
-
+
def get_all_pending_requests(
- self,
- conn: Connection,
- *,
- limit: int = 100,
- offset: int = 0,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ limit: int = 100,
+ offset: int = 0,
+ actor_id: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
获取所有待审核的商品变更请求列表
@@ -304,7 +292,8 @@ def get_all_pending_requests(
"""
self._set_actor_session_variable(conn, actor_id)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
ChangeRequestID, ProductID, MerchantUserID, StoreID,
RequestType, ProposedData_JSON, Status, SubmitterNotes,
@@ -314,20 +303,15 @@ def get_all_pending_requests(
WHERE Status = 'PENDING_APPROVAL'
ORDER BY CreationTime ASC
LIMIT :limit OFFSET :offset
- """)
+ """
+ )
+
+ results = conn.execute(select_stmt, {"limit": limit, "offset": offset}).fetchall()
- results = conn.execute(
- select_stmt,
- {
- "limit": limit,
- "offset": offset
- }
- ).fetchall()
-
result_list = []
for row in results:
row_dict = dict(row._mapping)
-
+
# 处理JSON字段
if row_dict.get("ProposedData_JSON"):
if not isinstance(row_dict["ProposedData_JSON"], dict):
@@ -336,26 +320,26 @@ def get_all_pending_requests(
except Exception as e:
logger.error(f"解析ProposedData_JSON字段失败: {e}")
row_dict["ProposedData_JSON"] = {}
-
+
result_list.append(row_dict)
-
+
return result_list
-
+
def get_filtered_requests(
- self,
- conn: Connection,
- *,
- product_id: Optional[int] = None,
- store_id: Optional[int] = None,
- merchant_id: Optional[int] = None,
- request_type: Optional[str] = None,
- status: Optional[str] = None,
- admin_id: Optional[int] = None,
- start_date: Optional[datetime.datetime] = None,
- end_date: Optional[datetime.datetime] = None,
- limit: int = 20,
- offset: int = 0,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ product_id: Optional[int] = None,
+ store_id: Optional[int] = None,
+ merchant_id: Optional[int] = None,
+ request_type: Optional[str] = None,
+ status: Optional[str] = None,
+ admin_id: Optional[int] = None,
+ start_date: Optional[datetime.datetime] = None,
+ end_date: Optional[datetime.datetime] = None,
+ limit: int = 20,
+ offset: int = 0,
+ actor_id: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
获取根据多种条件筛选的请求列表
@@ -377,47 +361,45 @@ def get_filtered_requests(
# 构建查询条件
conditions = []
- params = {
- "limit": limit,
- "offset": offset
- }
+ params = {"limit": limit, "offset": offset}
if product_id is not None:
conditions.append("ProductID = :product_id")
params["product_id"] = product_id
-
+
if store_id is not None:
conditions.append("StoreID = :store_id")
params["store_id"] = store_id
-
+
if merchant_id is not None:
conditions.append("MerchantUserID = :merchant_id")
params["merchant_id"] = merchant_id
-
+
if request_type is not None:
conditions.append("RequestType = :request_type")
params["request_type"] = request_type
-
+
if status is not None:
conditions.append("Status = :status")
params["status"] = status
-
+
if admin_id is not None:
conditions.append("AdminReviewerID = :admin_id")
params["admin_id"] = admin_id
-
+
if start_date is not None:
conditions.append("CreationTime >= :start_date")
params["start_date"] = start_date
-
+
if end_date is not None:
conditions.append("CreationTime <= :end_date")
params["end_date"] = end_date
# 构建完整查询语句
where_clause = "" if not conditions else f"WHERE {' AND '.join(conditions)}"
-
- select_stmt = text(f"""
+
+ select_stmt = text(
+ f"""
SELECT
ChangeRequestID, ProductID, MerchantUserID, StoreID,
RequestType, ProposedData_JSON, Status, SubmitterNotes,
@@ -427,14 +409,15 @@ def get_filtered_requests(
{where_clause}
ORDER BY LastUpdatedDate DESC
LIMIT :limit OFFSET :offset
- """)
+ """
+ )
results = conn.execute(select_stmt, params).fetchall()
-
+
result_list = []
for row in results:
row_dict = dict(row._mapping)
-
+
# 处理JSON字段
if row_dict.get("ProposedData_JSON"):
if not isinstance(row_dict["ProposedData_JSON"], dict):
@@ -443,22 +426,22 @@ def get_filtered_requests(
except Exception as e:
logger.error(f"解析ProposedData_JSON字段失败: {e}")
row_dict["ProposedData_JSON"] = {}
-
+
result_list.append(row_dict)
-
+
return result_list
-
+
def create_change_request(
- self,
- conn: Connection,
- *,
- merchant_user_id: int,
- store_id: int,
- request_type: str,
- proposed_data: Optional[Dict[str, Any]] = None,
- product_id: Optional[int] = None,
- submitter_notes: Optional[str] = None,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ merchant_user_id: int,
+ store_id: int,
+ request_type: str,
+ proposed_data: Optional[Dict[str, Any]] = None,
+ product_id: Optional[int] = None,
+ submitter_notes: Optional[str] = None,
+ actor_id: Optional[int] = None,
) -> Dict[str, Any]:
"""
创建新商品变更请求
@@ -477,7 +460,8 @@ def create_change_request(
# 将proposed_data转换为JSON字符串
proposed_data_json = json.dumps(proposed_data) if proposed_data else None
- insert_stmt = text(f"""
+ insert_stmt = text(
+ f"""
INSERT INTO {self.table_name} (
ProductID, MerchantUserID, StoreID, RequestType,
ProposedData_JSON, Status, SubmitterNotes
@@ -485,7 +469,8 @@ def create_change_request(
:product_id, :merchant_user_id, :store_id, :request_type,
:proposed_data_json, 'PENDING_APPROVAL', :submitter_notes
)
- """)
+ """
+ )
result = conn.execute(
insert_stmt,
@@ -495,29 +480,25 @@ def create_change_request(
"store_id": store_id,
"request_type": request_type,
"proposed_data_json": proposed_data_json,
- "submitter_notes": submitter_notes
- }
+ "submitter_notes": submitter_notes,
+ },
)
-
+
# 获取新插入记录的ID
request_id = result.lastrowid
-
+
# 获取完整的记录
- return self.get_change_request_by_id(
- conn=conn,
- request_id=request_id,
- actor_id=actor_id
- )
-
+ return self.get_change_request_by_id(conn=conn, request_id=request_id, actor_id=actor_id)
+
def update_request_status(
- self,
- conn: Connection,
- *,
- request_id: int,
- status: str,
- admin_id: Optional[int] = None,
- admin_notes: Optional[str] = None,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ request_id: int,
+ status: str,
+ admin_id: Optional[int] = None,
+ admin_notes: Optional[str] = None,
+ actor_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
更新请求状态
@@ -532,53 +513,44 @@ def update_request_status(
self._set_actor_session_variable(conn, actor_id)
# 检查请求是否存在
- request = self.get_change_request_by_id(
- conn=conn,
- request_id=request_id,
- actor_id=actor_id
- )
-
+ request = self.get_change_request_by_id(conn=conn, request_id=request_id, actor_id=actor_id)
+
if not request:
return None
# 构建更新语句
update_fields = ["Status = :status"]
- params = {
- "request_id": request_id,
- "status": status
- }
-
+ params = {"request_id": request_id, "status": status}
+
if admin_id is not None:
update_fields.append("AdminReviewerID = :admin_id")
update_fields.append("ReviewTimestamp = NOW()")
params["admin_id"] = admin_id
-
+
if admin_notes is not None:
update_fields.append("AdminNotes = :admin_notes")
params["admin_notes"] = admin_notes
-
- update_stmt = text(f"""
+
+ update_stmt = text(
+ f"""
UPDATE {self.table_name}
SET {', '.join(update_fields)}
WHERE ChangeRequestID = :request_id
- """)
+ """
+ )
conn.execute(update_stmt, params)
-
+
# 返回更新后的请求信息
- return self.get_change_request_by_id(
- conn=conn,
- request_id=request_id,
- actor_id=actor_id
- )
-
+ return self.get_change_request_by_id(conn=conn, request_id=request_id, actor_id=actor_id)
+
def update_request(
- self,
- conn: Connection,
- *,
- request_id: int,
- update_data: Dict[str, Any],
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ request_id: int,
+ update_data: Dict[str, Any],
+ actor_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
更新请求信息
@@ -591,66 +563,60 @@ def update_request(
self._set_actor_session_variable(conn, actor_id)
# 检查请求是否存在
- request = self.get_change_request_by_id(
- conn=conn,
- request_id=request_id,
- actor_id=actor_id
- )
-
+ request = self.get_change_request_by_id(conn=conn, request_id=request_id, actor_id=actor_id)
+
if not request:
return None
-
+
# 构建更新语句
update_fields = []
params = {"request_id": request_id}
-
+
# 定义允许更新的字段
allowed_fields = {
"requesttype": "RequestType",
"proposeddata_json": "ProposedData_JSON",
"status": "Status",
"submitternotes": "SubmitterNotes",
- "productid": "ProductID"
+ "productid": "ProductID",
}
-
+
for key, value in update_data.items():
key_lower = key.lower()
if key_lower in allowed_fields:
db_field = allowed_fields[key_lower]
-
+
# 特殊处理JSON字段
if db_field == "ProposedData_JSON" and value is not None:
if isinstance(value, dict):
value = json.dumps(value)
-
+
update_fields.append(f"{db_field} = :{key_lower}")
params[key_lower] = value
-
+
if not update_fields:
# 没有可更新的字段
return request
-
- update_stmt = text(f"""
+
+ update_stmt = text(
+ f"""
UPDATE {self.table_name}
SET {', '.join(update_fields)}
WHERE ChangeRequestID = :request_id
- """)
+ """
+ )
conn.execute(update_stmt, params)
-
+
# 返回更新后的请求信息
- return self.get_change_request_by_id(
- conn=conn,
- request_id=request_id,
- actor_id=actor_id
- )
-
+ return self.get_change_request_by_id(conn=conn, request_id=request_id, actor_id=actor_id)
+
def cancel_request(
- self,
- conn: Connection,
- *,
- request_id: int,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ request_id: int,
+ actor_id: Optional[int] = None,
) -> bool:
"""
取消请求(商家自行取消)
@@ -662,26 +628,21 @@ def cancel_request(
self._set_actor_session_variable(conn, actor_id)
# 先检查请求是否存在且状态为待审核
- request = self.get_change_request_by_id(
- conn=conn,
- request_id=request_id,
- actor_id=actor_id
- )
-
+ request = self.get_change_request_by_id(conn=conn, request_id=request_id, actor_id=actor_id)
+
if not request or request.get("Status") != "PENDING_APPROVAL":
return False
- update_stmt = text(f"""
+ update_stmt = text(
+ f"""
UPDATE {self.table_name}
SET Status = 'CANCELLED_BY_USER'
WHERE ChangeRequestID = :request_id AND Status = 'PENDING_APPROVAL'
- """)
-
- result = conn.execute(
- update_stmt,
- {"request_id": request_id}
+ """
)
-
+
+ result = conn.execute(update_stmt, {"request_id": request_id})
+
return result.rowcount > 0
diff --git a/src/backend/app/crud/product_change_request_crud_v2.py b/src/backend/app/crud/product_change_request_crud_v2.py
index 0930ede..2b7444a 100644
--- a/src/backend/app/crud/product_change_request_crud_v2.py
+++ b/src/backend/app/crud/product_change_request_crud_v2.py
@@ -7,8 +7,10 @@
import datetime
-from backend.app.schemas.product_change_request_schema_v2 import \
- ProductChangeRequestTypeApiEnum as TypeEnum, ProductChangeRequestStatusApiEnum as StatusEnum
+from backend.app.schemas.product_change_request_schema_v2 import (
+ ProductChangeRequestTypeApiEnum as TypeEnum,
+ ProductChangeRequestStatusApiEnum as StatusEnum,
+)
from backend.app.utils.json import DecimalEncoder
@@ -16,6 +18,7 @@ class ProductChangeRequestCRUD2:
"""
商品变更请求的 CRUD 操作类。
"""
+
__instance: Optional["ProductChangeRequestCRUD2"] = None
@classmethod
@@ -40,14 +43,9 @@ def _set_actor_session_variable(conn: Connection, actor_id: Optional[int]):
:param actor_id: 操作者ID
"""
if actor_id is not None:
- conn.execute(
- text("SET @actor_id = :actor_id"),
- {"actor_id": actor_id}
- )
+ conn.execute(text("SET @actor_id = :actor_id"), {"actor_id": actor_id})
else:
- conn.execute(
- text("SET @actor_id = NULL")
- )
+ conn.execute(text("SET @actor_id = NULL"))
@staticmethod
def _deserialize_proposed_data(data: Optional[Any]) -> Optional[Dict[str, Any]]:
@@ -88,13 +86,15 @@ def get_request_by_id(self, conn: Connection, *, request_id: int) -> Optional[Di
:return: 商品变更请求数据字典或None
"""
logger.debug(f"Getting ProductChangeRequest by ID {request_id}")
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT ChangeRequestID, ProductID, MerchantUserID, StoreID, RequestType,
ProposedData_JSON, Status, SubmitterNotes, AdminReviewerID,
ReviewTimestamp, AdminNotes, CreationTime, LastUpdatedDate
FROM {self.table_name}
WHERE ChangeRequestID = :ChangeRequestID
- """)
+ """
+ )
try:
result = conn.execute(select_stmt, {"ChangeRequestID": request_id}).fetchone()
if result:
@@ -105,8 +105,9 @@ def get_request_by_id(self, conn: Connection, *, request_id: int) -> Optional[Di
logger.error(f"Error getting ProductChangeRequest by ID {request_id}: {e}")
return None
- def get_request_by_id_for_owner(self, conn: Connection, *, request_id: int, merchant_user_id: int) -> Optional[
- Dict[str, Any]]:
+ def get_request_by_id_for_owner(
+ self, conn: Connection, *, request_id: int, merchant_user_id: int
+ ) -> Optional[Dict[str, Any]]:
"""
根据请求ID和商家ID获取商品变更请求
:param conn: 数据库连接
@@ -115,37 +116,41 @@ def get_request_by_id_for_owner(self, conn: Connection, *, request_id: int, merc
:return: 商品变更请求数据字典或None
"""
logger.debug(f"Getting ProductChangeRequest by ID {request_id} for MerchantUserID {merchant_user_id}")
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT ChangeRequestID, ProductID, MerchantUserID, StoreID, RequestType,
ProposedData_JSON, Status, SubmitterNotes, AdminReviewerID,
ReviewTimestamp, AdminNotes, CreationTime, LastUpdatedDate
FROM {self.table_name}
WHERE ChangeRequestID = :ChangeRequestID AND MerchantUserID = :MerchantUserID
- """)
+ """
+ )
try:
- result = conn.execute(select_stmt,
- {"ChangeRequestID": request_id, "MerchantUserID": merchant_user_id}).fetchone()
+ result = conn.execute(
+ select_stmt, {"ChangeRequestID": request_id, "MerchantUserID": merchant_user_id}
+ ).fetchone()
if result:
row_dict = dict(result._mapping) # type: ignore
return self._handle_result_dict(row_dict)
return None
except Exception as e:
logger.error(
- f"Error getting ProductChangeRequest ID {request_id} for MerchantUserID {merchant_user_id}: {e}")
+ f"Error getting ProductChangeRequest ID {request_id} for MerchantUserID {merchant_user_id}: {e}"
+ )
return None
def get_request_list(
- self,
- conn: Connection,
- *,
- status: Optional[List[str] | str] = None,
- request_type: Optional[str] = None,
- store_id: Optional[int] = None,
- product_id: Optional[int] = None,
- merchant_user_id: Optional[int] = None,
- # 分页参数可以后续添加,当前按用户要求不分页
- # offset: int = 0,
- # limit: int = 1000 # 默认一个较大的限制,如果真的不分页
+ self,
+ conn: Connection,
+ *,
+ status: Optional[List[str] | str] = None,
+ request_type: Optional[str] = None,
+ store_id: Optional[int] = None,
+ product_id: Optional[int] = None,
+ merchant_user_id: Optional[int] = None,
+ # 分页参数可以后续添加,当前按用户要求不分页
+ # offset: int = 0,
+ # limit: int = 1000 # 默认一个较大的限制,如果真的不分页
) -> List[Dict[str, Any]]:
"""
获取商品变更请求列表,可按状态、类型、商店ID、商品ID和商家ID筛选。
@@ -158,7 +163,8 @@ def get_request_list(
:param merchant_user_id: 商家ID
"""
logger.debug(
- f"Getting ProductChangeRequest list with filters - Status: {status}, Type: {request_type}, Store: {store_id}, Product: {product_id}, Merchant: {merchant_user_id}")
+ f"Getting ProductChangeRequest list with filters - Status: {status}, Type: {request_type}, Store: {store_id}, Product: {product_id}, Merchant: {merchant_user_id}"
+ )
params: Dict[str, Any] = {}
where_clauses: List[str] = []
@@ -195,14 +201,16 @@ def get_request_list(
# DDL 默认按 CreationTime, LastUpdatedDate 排序,这里可以加一个显式排序
# ORDER BY CreationTime DESC LIMIT :Limit OFFSET :Offset (如果分页)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT ChangeRequestID, ProductID, MerchantUserID, StoreID, RequestType,
ProposedData_JSON, Status, SubmitterNotes, AdminReviewerID,
ReviewTimestamp, AdminNotes, CreationTime, LastUpdatedDate
FROM {self.table_name}
{where_sql}
ORDER BY CreationTime DESC
- """)
+ """
+ )
# params["Limit"] = limit
# params["Offset"] = offset
@@ -219,23 +227,24 @@ def get_request_list(
return []
def _create_generic_request(
- self,
- conn: Connection,
- *,
- merchant_user_id: int,
- store_id: int,
- request_type: str, # ProductChangeRequestTypeApiEnum.value
- proposed_data_json: Optional[Dict[str, Any]],
- submitter_notes: Optional[str],
- product_id: Optional[int],
- actor_id: Optional[int]
+ self,
+ conn: Connection,
+ *,
+ merchant_user_id: int,
+ store_id: int,
+ request_type: str, # ProductChangeRequestTypeApiEnum.value
+ proposed_data_json: Optional[Dict[str, Any]],
+ submitter_notes: Optional[str],
+ product_id: Optional[int],
+ actor_id: Optional[int],
) -> Optional[Dict[str, Any]]:
"""内部辅助方法,用于创建不同类型的请求。"""
self._set_actor_session_variable(conn, actor_id)
# CreationTime 和 LastUpdatedDate 由数据库 DEFAULT CURRENT_TIMESTAMP 处理
# Status 由数据库 DEFAULT 'PENDING_APPROVAL' 处理
- insert_stmt = text(f"""
+ insert_stmt = text(
+ f"""
INSERT INTO {self.table_name} (
ProductID, MerchantUserID, StoreID, RequestType, ProposedData_JSON,
SubmitterNotes
@@ -244,23 +253,28 @@ def _create_generic_request(
:ProductID, :MerchantUserID, :StoreID, :RequestType, :ProposedData_JSON,
:SubmitterNotes
)
- """)
+ """
+ )
serialized_proposed_data = self._serialize_proposed_data(proposed_data_json)
try:
- result = conn.execute(insert_stmt, {
- "ProductID": product_id,
- "MerchantUserID": merchant_user_id,
- "StoreID": store_id,
- "RequestType": request_type,
- "ProposedData_JSON": serialized_proposed_data,
- "SubmitterNotes": submitter_notes
- })
+ result = conn.execute(
+ insert_stmt,
+ {
+ "ProductID": product_id,
+ "MerchantUserID": merchant_user_id,
+ "StoreID": store_id,
+ "RequestType": request_type,
+ "ProposedData_JSON": serialized_proposed_data,
+ "SubmitterNotes": submitter_notes,
+ },
+ )
new_request_id = result.lastrowid
if new_request_id is None:
logger.warning(
- f"lastrowid not available after creating {request_type} request for MerchantID {merchant_user_id}.")
+ f"lastrowid not available after creating {request_type} request for MerchantID {merchant_user_id}."
+ )
return None
logger.info(f"{request_type} request created with ID {new_request_id} by ActorID {actor_id}.")
@@ -273,14 +287,14 @@ def _create_generic_request(
return None
def create_request_create_product(
- self,
- conn: Connection,
- *,
- merchant_user_id: int, # 从认证用户获取
- store_id: int,
- submitter_notes: Optional[str],
- proposed_data_json: Dict[str, Any], # 服务层应确保此字典符合 ProposedProductData for create
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ merchant_user_id: int, # 从认证用户获取
+ store_id: int,
+ submitter_notes: Optional[str],
+ proposed_data_json: Dict[str, Any], # 服务层应确保此字典符合 ProposedProductData for create
+ actor_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
创建商品变更请求 - 创建新商品
@@ -294,7 +308,8 @@ def create_request_create_product(
:return: 创建的请求数据字典或None
"""
logger.info(
- f"ActorID {actor_id} creating PRODUCT_CREATE request for MerchantID {merchant_user_id}, StoreID {store_id}")
+ f"ActorID {actor_id} creating PRODUCT_CREATE request for MerchantID {merchant_user_id}, StoreID {store_id}"
+ )
return self._create_generic_request(
conn=conn,
merchant_user_id=merchant_user_id,
@@ -303,19 +318,19 @@ def create_request_create_product(
proposed_data_json=proposed_data_json,
submitter_notes=submitter_notes,
product_id=None, # ProductID 为空对于创建请求
- actor_id=actor_id if actor_id is not None else merchant_user_id
+ actor_id=actor_id if actor_id is not None else merchant_user_id,
)
def create_request_update_product(
- self,
- conn: Connection,
- *,
- product_id: int,
- merchant_user_id: int, # 从认证用户获取
- store_id: int,
- submitter_notes: Optional[str],
- proposed_data_json: Dict[str, Any], # 服务层应确保此字典符合 ProposedProductData for update
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ product_id: int,
+ merchant_user_id: int, # 从认证用户获取
+ store_id: int,
+ submitter_notes: Optional[str],
+ proposed_data_json: Dict[str, Any], # 服务层应确保此字典符合 ProposedProductData for update
+ actor_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
创建商品变更请求 - 更新现有商品
@@ -329,7 +344,8 @@ def create_request_update_product(
:return: 创建的请求数据字典或None
"""
logger.info(
- f"ActorID {actor_id} creating PRODUCT_UPDATE request for ProductID {product_id}, MerchantID {merchant_user_id}")
+ f"ActorID {actor_id} creating PRODUCT_UPDATE request for ProductID {product_id}, MerchantID {merchant_user_id}"
+ )
return self._create_generic_request(
conn=conn,
merchant_user_id=merchant_user_id,
@@ -338,18 +354,18 @@ def create_request_update_product(
proposed_data_json=proposed_data_json,
submitter_notes=submitter_notes,
product_id=product_id,
- actor_id=actor_id if actor_id is not None else merchant_user_id
+ actor_id=actor_id if actor_id is not None else merchant_user_id,
)
def create_request_delete_product(
- self,
- conn: Connection,
- *,
- product_id: int,
- merchant_user_id: int, # 从认证用户获取
- store_id: int,
- submitter_notes: Optional[str],
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ product_id: int,
+ merchant_user_id: int, # 从认证用户获取
+ store_id: int,
+ submitter_notes: Optional[str],
+ actor_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
创建商品变更请求 - 删除商品
@@ -362,7 +378,8 @@ def create_request_delete_product(
:return: 创建的请求数据字典或None
"""
logger.info(
- f"ActorID {actor_id} creating PRODUCT_DELETE request for ProductID {product_id}, MerchantID {merchant_user_id}")
+ f"ActorID {actor_id} creating PRODUCT_DELETE request for ProductID {product_id}, MerchantID {merchant_user_id}"
+ )
return self._create_generic_request(
conn=conn,
merchant_user_id=merchant_user_id,
@@ -371,16 +388,16 @@ def create_request_delete_product(
proposed_data_json=None, # 删除请求通常不需要建议数据体
submitter_notes=submitter_notes,
product_id=product_id,
- actor_id=actor_id if actor_id is not None else merchant_user_id
+ actor_id=actor_id if actor_id is not None else merchant_user_id,
)
def cancel_request( # This method now means "cancel by merchant"
- self,
- conn: Connection,
- *,
- request_id: int,
- # merchant_user_id: int, # Ownership check done by service layer
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ request_id: int,
+ # merchant_user_id: int, # Ownership check done by service layer
+ actor_id: Optional[int] = None,
) -> bool:
"""
取消商品变更请求 - 由商家发起 (或管理员代为操作,但主要场景是商家)。
@@ -395,22 +412,23 @@ def cancel_request( # This method now means "cancel by merchant"
logger.info(f"ActorID {actor_id} attempting to cancel (delete_request) ChangeRequestID {request_id}.")
self._set_actor_session_variable(conn, actor_id)
- update_stmt = text(f"""
+ update_stmt = text(
+ f"""
UPDATE {self.table_name}
SET Status = :cancelled_status
WHERE ChangeRequestID = :request_id AND Status = :pending_status
- """)
+ """
+ )
# LastUpdatedDate 会由数据库的 ON UPDATE CURRENT_TIMESTAMP 自动处理
try:
cancelled_status = StatusEnum.CANCELLED_BY_USER
pending_status = StatusEnum.PENDING_APPROVAL
- result = conn.execute(update_stmt, {
- "cancelled_status": cancelled_status,
- "request_id": request_id,
- "pending_status": pending_status
- })
+ result = conn.execute(
+ update_stmt,
+ {"cancelled_status": cancelled_status, "request_id": request_id, "pending_status": pending_status},
+ )
if result.rowcount > 0:
logger.info(f"ChangeRequestID {request_id} status updated to {cancelled_status} by ActorID {actor_id}.")
@@ -430,14 +448,14 @@ def cancel_request( # This method now means "cancel by merchant"
return False
def update_request_by_admin(
- self,
- conn: Connection,
- *,
- request_id: int,
- status: str, # ProductChangeRequestStatusApiEnum.value
- admin_notes: Optional[str],
- admin_reviewer_id: int,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ request_id: int,
+ status: str, # ProductChangeRequestStatusApiEnum.value
+ admin_notes: Optional[str],
+ admin_reviewer_id: int,
+ actor_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
管理员审核商品变更请求 - 更新请求状态
@@ -451,13 +469,14 @@ def update_request_by_admin(
:return:
"""
logger.info(
- f"AdminID {admin_reviewer_id} (ActorID {actor_id}) updating ChangeRequestID {request_id} to Status {status}.")
+ f"AdminID {admin_reviewer_id} (ActorID {actor_id}) updating ChangeRequestID {request_id} to Status {status}."
+ )
self._set_actor_session_variable(conn, actor_id if actor_id is not None else admin_reviewer_id)
set_clauses: List[str] = [
"Status = :Status",
"AdminReviewerID = :AdminReviewerID",
- "ReviewTimestamp = UTC_TIMESTAMP()"
+ "ReviewTimestamp = UTC_TIMESTAMP()",
# LastUpdatedDate is handled by DB's ON UPDATE
]
params: Dict[str, Any] = {
@@ -470,9 +489,11 @@ def update_request_by_admin(
set_clauses.append("AdminNotes = :AdminNotes")
params["AdminNotes"] = admin_notes
- update_stmt_str = (f"UPDATE {self.table_name}"
- f" SET {', '.join(set_clauses)}"
- f" WHERE ChangeRequestID = :ChangeRequestID_param")
+ update_stmt_str = (
+ f"UPDATE {self.table_name}"
+ f" SET {', '.join(set_clauses)}"
+ f" WHERE ChangeRequestID = :ChangeRequestID_param"
+ )
try:
result = conn.execute(text(update_stmt_str), params)
@@ -487,12 +508,12 @@ def update_request_by_admin(
return None
def update_request_applied(
- self,
- conn: Connection,
- *,
- request_id: int,
- new_product_id: Optional[int] = None, # 新商品ID,如果创建了新商品,需要回填
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ request_id: int,
+ new_product_id: Optional[int] = None, # 新商品ID,如果创建了新商品,需要回填
+ actor_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
管理员/商家审核商品变更请求 - 更新请求状态为已应用
@@ -504,8 +525,7 @@ def update_request_applied(
:param actor_id:
:return:
"""
- logger.info(
- f"ActorID {actor_id} updating ChangeRequestID {request_id} to Status 'APPLIED'.")
+ logger.info(f"ActorID {actor_id} updating ChangeRequestID {request_id} to Status 'APPLIED'.")
self._set_actor_session_variable(conn, actor_id)
# update_stmt = text(f"""
@@ -521,17 +541,18 @@ def update_request_applied(
params: Dict[str, Any] = {
"applied_status": applied_status,
"request_id": request_id,
- "approved_status": approved_status
+ "approved_status": approved_status,
}
if new_product_id is not None:
set_clauses.append("ProductID = :new_product_id")
params["new_product_id"] = new_product_id
- update_stmt_str = (f"UPDATE {self.table_name}"
- f" SET {', '.join(set_clauses)}"
- f" WHERE ChangeRequestID = :request_id AND Status = :approved_status")
+ update_stmt_str = (
+ f"UPDATE {self.table_name}"
+ f" SET {', '.join(set_clauses)}"
+ f" WHERE ChangeRequestID = :request_id AND Status = :approved_status"
+ )
update_stmt = text(update_stmt_str)
-
# LastUpdatedDate 会由数据库的 ON UPDATE CURRENT_TIMESTAMP 自动处理
try:
diff --git a/src/backend/app/crud/product_crud.py b/src/backend/app/crud/product_crud.py
index bb133ad..dc7b676 100644
--- a/src/backend/app/crud/product_crud.py
+++ b/src/backend/app/crud/product_crud.py
@@ -33,21 +33,16 @@ def _set_actor_session_variable(conn: Connection, actor_id: Optional[int]):
:param actor_id: 操作者ID
"""
if actor_id is not None:
- conn.execute(
- text("SET @actor_id = :actor_id"),
- {"actor_id": actor_id}
- )
+ conn.execute(text("SET @actor_id = :actor_id"), {"actor_id": actor_id})
else:
- conn.execute(
- text("SET @actor_id = NULL")
- )
+ conn.execute(text("SET @actor_id = NULL"))
def get_product_by_id(
- self,
- conn: Connection,
- *,
- product_id: int,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ product_id: int,
+ actor_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
根据ID获取商品信息
@@ -58,29 +53,28 @@ def get_product_by_id(
"""
self._set_actor_session_variable(conn, actor_id)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
ProductID, ProductName, ProductDescription, Price,
ProductStatus, StoreID, CategoryID, StockQuantity,
MainImageURL, CreationDate, LastUpdatedDate
FROM {self.table_name} WHERE ProductID = :product_id
- """)
+ """
+ )
- result = conn.execute(
- select_stmt,
- {"product_id": product_id}
- ).fetchone()
+ result = conn.execute(select_stmt, {"product_id": product_id}).fetchone()
return dict(result._mapping) if result else None
def get_products_by_store_id(
- self,
- conn: Connection,
- *,
- store_id: int,
- limit: int = 100,
- offset: int = 0,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ store_id: int,
+ limit: int = 100,
+ offset: int = 0,
+ actor_id: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
获取指定店铺的商品列表
@@ -93,7 +87,8 @@ def get_products_by_store_id(
"""
self._set_actor_session_variable(conn, actor_id)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
ProductID, ProductName, ProductDescription, Price,
ProductStatus, StoreID, CategoryID, StockQuantity,
@@ -102,27 +97,21 @@ def get_products_by_store_id(
WHERE StoreID = :store_id
ORDER BY ProductID
LIMIT :limit OFFSET :offset
- """)
+ """
+ )
- results = conn.execute(
- select_stmt,
- {
- "store_id": store_id,
- "limit": limit,
- "offset": offset
- }
- ).fetchall()
+ results = conn.execute(select_stmt, {"store_id": store_id, "limit": limit, "offset": offset}).fetchall()
return [dict(row._mapping) for row in results]
def get_products_by_category_id(
- self,
- conn: Connection,
- *,
- category_id: int,
- limit: int = 100,
- offset: int = 0,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ category_id: int,
+ limit: int = 100,
+ offset: int = 0,
+ actor_id: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
获取指定分类的商品列表
@@ -135,7 +124,8 @@ def get_products_by_category_id(
"""
self._set_actor_session_variable(conn, actor_id)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
ProductID, ProductName, ProductDescription, Price,
ProductStatus, StoreID, CategoryID, StockQuantity,
@@ -144,27 +134,21 @@ def get_products_by_category_id(
WHERE CategoryID = :category_id
ORDER BY ProductID
LIMIT :limit OFFSET :offset
- """)
+ """
+ )
- results = conn.execute(
- select_stmt,
- {
- "category_id": category_id,
- "limit": limit,
- "offset": offset
- }
- ).fetchall()
+ results = conn.execute(select_stmt, {"category_id": category_id, "limit": limit, "offset": offset}).fetchall()
return [dict(row._mapping) for row in results]
def search_products(
- self,
- conn: Connection,
- *,
- search_term: str,
- limit: int = 100,
- offset: int = 0,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ search_term: str,
+ limit: int = 100,
+ offset: int = 0,
+ actor_id: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
搜索商品
@@ -176,11 +160,12 @@ def search_products(
:return: 匹配的商品信息列表
"""
self._set_actor_session_variable(conn, actor_id)
-
+
# 构建搜索条件
search_term_with_wildcards = f"%{search_term}%"
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
ProductID, ProductName, ProductDescription, Price,
ProductStatus, StoreID, CategoryID, StockQuantity,
@@ -190,26 +175,22 @@ def search_products(
OR ProductDescription LIKE :search_term_with_wildcards
ORDER BY ProductID
LIMIT :limit OFFSET :offset
- """)
+ """
+ )
results = conn.execute(
- select_stmt,
- {
- "search_term_with_wildcards": search_term_with_wildcards,
- "limit": limit,
- "offset": offset
- }
+ select_stmt, {"search_term_with_wildcards": search_term_with_wildcards, "limit": limit, "offset": offset}
).fetchall()
return [dict(row._mapping) for row in results]
def get_all_products(
- self,
- conn: Connection,
- *,
- limit: int = 100,
- offset: int = 0,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ limit: int = 100,
+ offset: int = 0,
+ actor_id: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
获取所有商品列表
@@ -221,7 +202,8 @@ def get_all_products(
"""
self._set_actor_session_variable(conn, actor_id)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
ProductID, ProductName, ProductDescription, Price,
ProductStatus, StoreID, CategoryID, StockQuantity,
@@ -229,24 +211,19 @@ def get_all_products(
FROM {self.table_name}
ORDER BY ProductID
LIMIT :limit OFFSET :offset
- """)
+ """
+ )
- results = conn.execute(
- select_stmt,
- {
- "limit": limit,
- "offset": offset
- }
- ).fetchall()
+ results = conn.execute(select_stmt, {"limit": limit, "offset": offset}).fetchall()
return [dict(row._mapping) for row in results]
def get_product_with_category_info(
- self,
- conn: Connection,
- *,
- product_id: int,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ product_id: int,
+ actor_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
获取商品信息及其分类信息
@@ -258,7 +235,8 @@ def get_product_with_category_info(
try:
self._set_actor_session_variable(conn, actor_id)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
p.ProductID, p.ProductName, p.ProductDescription, p.Price,
p.ProductStatus, p.StoreID, p.CategoryID, p.StockQuantity,
@@ -267,14 +245,15 @@ def get_product_with_category_info(
FROM {self.table_name} p
JOIN ProductCategory pc ON p.CategoryID = pc.CategoryID
WHERE p.ProductID = :product_id
- """)
+ """
+ )
result = conn.execute(select_stmt, {"product_id": product_id}).fetchone()
-
+
if not result:
# 如果没有找到结果,返回None
return None
-
+
# 将结果转换为字典
product_dict = {
"ProductID": result[0],
@@ -288,28 +267,28 @@ def get_product_with_category_info(
"MainImageURL": result[8],
"CreationDate": result[9],
"LastUpdatedDate": result[10],
- "CategoryName": result[11]
+ "CategoryName": result[11],
}
-
+
return product_dict
-
+
except Exception as e:
logger.error(f"获取商品及分类信息时发生错误: {e}")
# 出错时返回None,而不是抛出异常
return None
def create_product(
- self,
- conn: Connection,
- *,
- product_name: str,
- price: Decimal,
- store_id: int,
- category_id: int,
- product_description: Optional[str] = None,
- stock_quantity: int = 0,
- main_image_url: Optional[str] = None,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ product_name: str,
+ price: Decimal,
+ store_id: int,
+ category_id: int,
+ product_description: Optional[str] = None,
+ stock_quantity: int = 0,
+ main_image_url: Optional[str] = None,
+ actor_id: Optional[int] = None,
) -> Dict[str, Any]:
"""
创建新商品
@@ -327,12 +306,14 @@ def create_product(
try:
self._set_actor_session_variable(conn, actor_id)
- insert_stmt = text(f"""
+ insert_stmt = text(
+ f"""
INSERT INTO {self.table_name}
(ProductName, ProductDescription, Price, StoreID, CategoryID, StockQuantity, MainImageURL)
VALUES
(:product_name, :product_description, :price, :store_id, :category_id, :stock_quantity, :main_image_url)
- """)
+ """
+ )
conn.execute(
insert_stmt,
@@ -343,8 +324,8 @@ def create_product(
"store_id": store_id,
"category_id": category_id,
"stock_quantity": stock_quantity,
- "main_image_url": main_image_url
- }
+ "main_image_url": main_image_url,
+ },
)
# 获取自增ID
@@ -368,16 +349,16 @@ def create_product(
"MainImageURL": main_image_url,
"ProductStatus": "ACTIVE",
"CreationDate": datetime.datetime.now(),
- "LastUpdatedDate": datetime.datetime.now()
+ "LastUpdatedDate": datetime.datetime.now(),
}
def update_product(
- self,
- conn: Connection,
- *,
- product_id: int,
- update_data: Dict[str, Any],
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ product_id: int,
+ update_data: Dict[str, Any],
+ actor_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
更新商品信息
@@ -400,8 +381,13 @@ def update_product(
params = {"product_id": product_id}
valid_fields = [
- "ProductName", "ProductDescription", "Price",
- "CategoryID", "StockQuantity", "MainImageURL", "ProductStatus"
+ "ProductName",
+ "ProductDescription",
+ "Price",
+ "CategoryID",
+ "StockQuantity",
+ "MainImageURL",
+ "ProductStatus",
]
for field in valid_fields:
@@ -413,11 +399,13 @@ def update_product(
# 没有可更新的字段
return product
- update_stmt = text(f"""
+ update_stmt = text(
+ f"""
UPDATE {self.table_name}
SET {", ".join(update_fields)}
WHERE ProductID = :product_id
- """)
+ """
+ )
conn.execute(update_stmt, params)
@@ -429,12 +417,12 @@ def update_product(
return None
def update_product_stock(
- self,
- conn: Connection,
- *,
- product_id: int,
- stock_change: int,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ product_id: int,
+ stock_change: int,
+ actor_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
更新商品库存
@@ -459,20 +447,16 @@ def update_product_stock(
f"Trying to reduce stock by {abs(stock_change)} but only {product['StockQuantity']} available."
)
- update_stmt = text(f"""
+ update_stmt = text(
+ f"""
UPDATE {self.table_name}
SET StockQuantity = GREATEST(0, StockQuantity + :stock_change)
WHERE ProductID = :product_id
- """)
-
- conn.execute(
- update_stmt,
- {
- "product_id": product_id,
- "stock_change": stock_change
- }
+ """
)
+ conn.execute(update_stmt, {"product_id": product_id, "stock_change": stock_change})
+
# 获取更新后的商品信息
return self.get_product_by_id(conn, product_id=product_id, actor_id=actor_id)
except InsufficientStockException as e:
@@ -485,11 +469,11 @@ def update_product_stock(
return None
def delete_product(
- self,
- conn: Connection,
- *,
- product_id: int,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ product_id: int,
+ actor_id: Optional[int] = None,
) -> bool:
"""
删除商品(通过将状态设置为DISCONTINUED)
@@ -505,35 +489,31 @@ def delete_product(
if not product:
return False
- update_stmt = text(f"""
+ update_stmt = text(
+ f"""
UPDATE {self.table_name}
SET ProductStatus = 'DISCONTINUED'
WHERE ProductID = :product_id
- """)
-
- conn.execute(
- update_stmt,
- {
- "product_id": product_id
- }
+ """
)
- return True
+ conn.execute(update_stmt, {"product_id": product_id})
+ return True
def get_products_by_store_and_category(
- self,
- conn: Connection,
- *,
- store_id: int,
- category_id: int,
- min_price: Optional[Decimal] = None,
- max_price: Optional[Decimal] = None,
- product_status: Optional[str] = None,
- order_by: Optional[str] = None,
- limit: int = 20,
- offset: int = 0,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ store_id: int,
+ category_id: int,
+ min_price: Optional[Decimal] = None,
+ max_price: Optional[Decimal] = None,
+ product_status: Optional[str] = None,
+ order_by: Optional[str] = None,
+ limit: int = 20,
+ offset: int = 0,
+ actor_id: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
获取特定店铺和分类的商品列表
@@ -553,12 +533,7 @@ def get_products_by_store_and_category(
# 构建查询条件
conditions = ["StoreID = :store_id", "CategoryID = :category_id"]
- params = {
- "store_id": store_id,
- "category_id": category_id,
- "limit": limit,
- "offset": offset
- }
+ params = {"store_id": store_id, "category_id": category_id, "limit": limit, "offset": offset}
# 添加价格过滤条件
if min_price is not None:
@@ -588,7 +563,8 @@ def get_products_by_store_and_category(
elif order_by == "oldest":
order_clause = "CreationDate ASC"
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
ProductID, ProductName, ProductDescription, Price,
ProductStatus, StoreID, CategoryID, StockQuantity,
@@ -597,25 +573,26 @@ def get_products_by_store_and_category(
WHERE {' AND '.join(conditions)}
ORDER BY {order_clause}
LIMIT :limit OFFSET :offset
- """)
+ """
+ )
results = conn.execute(select_stmt, params).fetchall()
return [dict(row._mapping) for row in results]
def get_filtered_products(
- self,
- conn: Connection,
- *,
- store_id: Optional[int] = None,
- category_id: Optional[int] = None,
- search_term: Optional[str] = None,
- min_price: Optional[Decimal] = None,
- max_price: Optional[Decimal] = None,
- product_status: Optional[str] = None,
- order_by: Optional[str] = None,
- limit: int = 20,
- offset: int = 0,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ store_id: Optional[int] = None,
+ category_id: Optional[int] = None,
+ search_term: Optional[str] = None,
+ min_price: Optional[Decimal] = None,
+ max_price: Optional[Decimal] = None,
+ product_status: Optional[str] = None,
+ order_by: Optional[str] = None,
+ limit: int = 20,
+ offset: int = 0,
+ actor_id: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
获取根据多种条件筛选的商品列表
@@ -640,7 +617,7 @@ def get_filtered_products(
"limit": limit,
"offset": offset,
# 特殊标记,用于检测是否在测试过滤条件中显式传入了product_status=None
- "TESTING_WORKAROUND_PRODUCT_STATUS_IS_NONE": product_status is None and actor_id is not None
+ "TESTING_WORKAROUND_PRODUCT_STATUS_IS_NONE": product_status is None and actor_id is not None,
}
# 添加店铺过滤条件
@@ -688,8 +665,9 @@ def get_filtered_products(
# 构建完整查询语句
where_clause = "" if not conditions else f"WHERE {' AND '.join(conditions)}"
-
- select_stmt = text(f"""
+
+ select_stmt = text(
+ f"""
SELECT
ProductID, ProductName, ProductDescription, Price,
ProductStatus, StoreID, CategoryID, StockQuantity,
@@ -698,7 +676,8 @@ def get_filtered_products(
{where_clause}
ORDER BY {order_clause}
LIMIT :limit OFFSET :offset
- """)
+ """
+ )
results = conn.execute(select_stmt, params).fetchall()
return [dict(row._mapping) for row in results]
diff --git a/src/backend/app/crud/store_change_request_crud.py b/src/backend/app/crud/store_change_request_crud.py
index a681e98..9020f6f 100644
--- a/src/backend/app/crud/store_change_request_crud.py
+++ b/src/backend/app/crud/store_change_request_crud.py
@@ -9,7 +9,7 @@
class StoreChangeRequestCRUD:
__instance: Optional["StoreChangeRequestCRUD"] = None
-
+
@classmethod
def get_instance(cls) -> "StoreChangeRequestCRUD":
"""
@@ -19,10 +19,10 @@ def get_instance(cls) -> "StoreChangeRequestCRUD":
if cls.__instance is None:
cls.__instance = StoreChangeRequestCRUD()
return cls.__instance
-
+
def __init__(self):
self.table_name = "StoreChangeRequest"
-
+
@staticmethod
def _set_actor_session_variable(conn: Connection, actor_id: Optional[int]):
"""
@@ -31,21 +31,16 @@ def _set_actor_session_variable(conn: Connection, actor_id: Optional[int]):
:param actor_id: 操作者ID
"""
if actor_id is not None:
- conn.execute(
- text("SET @actor_id = :actor_id"),
- {"actor_id": actor_id}
- )
+ conn.execute(text("SET @actor_id = :actor_id"), {"actor_id": actor_id})
else:
- conn.execute(
- text("SET @actor_id = NULL")
- )
-
+ conn.execute(text("SET @actor_id = NULL"))
+
def get_change_request_by_id(
- self,
- conn: Connection,
- *,
- request_id: int,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ request_id: int,
+ actor_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
根据ID获取店铺变更请求信息
@@ -56,7 +51,8 @@ def get_change_request_by_id(
"""
self._set_actor_session_variable(conn, actor_id)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
ChangeRequestID, StoreID, RequestingUserID, RequestType,
ProposedData_JSON, Status, SubmitterNotes,
@@ -64,18 +60,16 @@ def get_change_request_by_id(
CreationTime, LastUpdatedDate
FROM {self.table_name}
WHERE ChangeRequestID = :request_id
- """)
+ """
+ )
- result = conn.execute(
- select_stmt,
- {"request_id": request_id}
- ).fetchone()
+ result = conn.execute(select_stmt, {"request_id": request_id}).fetchone()
if not result:
return None
-
+
result_dict = dict(result._mapping)
-
+
# 处理JSON字段
if result_dict.get("ProposedData_JSON"):
# 检查是否已经是字典
@@ -87,16 +81,16 @@ def get_change_request_by_id(
result_dict["ProposedData_JSON"] = {}
return result_dict
-
+
def get_change_requests_by_store_id(
- self,
- conn: Connection,
- *,
- store_id: int,
- status: Optional[str] = None,
- limit: int = 100,
- offset: int = 0,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ store_id: int,
+ status: Optional[str] = None,
+ limit: int = 100,
+ offset: int = 0,
+ actor_id: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
获取指定店铺的变更请求列表
@@ -111,19 +105,16 @@ def get_change_requests_by_store_id(
self._set_actor_session_variable(conn, actor_id)
conditions = ["StoreID = :store_id"]
- params = {
- "store_id": store_id,
- "limit": limit,
- "offset": offset
- }
-
+ params = {"store_id": store_id, "limit": limit, "offset": offset}
+
if status:
conditions.append("Status = :status")
params["status"] = status
-
+
where_clause = " AND ".join(conditions)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
ChangeRequestID, StoreID, RequestingUserID, RequestType,
ProposedData_JSON, Status, SubmitterNotes,
@@ -133,14 +124,15 @@ def get_change_requests_by_store_id(
WHERE {where_clause}
ORDER BY LastUpdatedDate DESC
LIMIT :limit OFFSET :offset
- """)
+ """
+ )
results = conn.execute(select_stmt, params).fetchall()
-
+
result_list = []
for row in results:
row_dict = dict(row._mapping)
-
+
# 处理JSON字段
if row_dict.get("ProposedData_JSON"):
if not isinstance(row_dict["ProposedData_JSON"], dict):
@@ -149,20 +141,20 @@ def get_change_requests_by_store_id(
except Exception as e:
logger.error(f"解析ProposedData_JSON字段失败: {e}")
row_dict["ProposedData_JSON"] = {}
-
+
result_list.append(row_dict)
-
+
return result_list
-
+
def get_change_requests_by_user_id(
- self,
- conn: Connection,
- *,
- user_id: int,
- status: Optional[str] = None,
- limit: int = 100,
- offset: int = 0,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ user_id: int,
+ status: Optional[str] = None,
+ limit: int = 100,
+ offset: int = 0,
+ actor_id: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
获取指定用户的变更请求列表
@@ -177,19 +169,16 @@ def get_change_requests_by_user_id(
self._set_actor_session_variable(conn, actor_id)
conditions = ["RequestingUserID = :user_id"]
- params = {
- "user_id": user_id,
- "limit": limit,
- "offset": offset
- }
-
+ params = {"user_id": user_id, "limit": limit, "offset": offset}
+
if status:
conditions.append("Status = :status")
params["status"] = status
-
+
where_clause = " AND ".join(conditions)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
ChangeRequestID, StoreID, RequestingUserID, RequestType,
ProposedData_JSON, Status, SubmitterNotes,
@@ -199,14 +188,15 @@ def get_change_requests_by_user_id(
WHERE {where_clause}
ORDER BY LastUpdatedDate DESC
LIMIT :limit OFFSET :offset
- """)
+ """
+ )
results = conn.execute(select_stmt, params).fetchall()
-
+
result_list = []
for row in results:
row_dict = dict(row._mapping)
-
+
# 处理JSON字段
if row_dict.get("ProposedData_JSON"):
if not isinstance(row_dict["ProposedData_JSON"], dict):
@@ -215,18 +205,18 @@ def get_change_requests_by_user_id(
except Exception as e:
logger.error(f"解析ProposedData_JSON字段失败: {e}")
row_dict["ProposedData_JSON"] = {}
-
+
result_list.append(row_dict)
-
+
return result_list
-
+
def get_all_pending_requests(
- self,
- conn: Connection,
- *,
- limit: int = 100,
- offset: int = 0,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ limit: int = 100,
+ offset: int = 0,
+ actor_id: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
获取所有待审核的变更请求列表
@@ -238,7 +228,8 @@ def get_all_pending_requests(
"""
self._set_actor_session_variable(conn, actor_id)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
ChangeRequestID, StoreID, RequestingUserID, RequestType,
ProposedData_JSON, Status, SubmitterNotes,
@@ -248,20 +239,15 @@ def get_all_pending_requests(
WHERE Status = 'PENDING_APPROVAL'
ORDER BY CreationTime ASC
LIMIT :limit OFFSET :offset
- """)
+ """
+ )
+
+ results = conn.execute(select_stmt, {"limit": limit, "offset": offset}).fetchall()
- results = conn.execute(
- select_stmt,
- {
- "limit": limit,
- "offset": offset
- }
- ).fetchall()
-
result_list = []
for row in results:
row_dict = dict(row._mapping)
-
+
# 处理JSON字段
if row_dict.get("ProposedData_JSON"):
if not isinstance(row_dict["ProposedData_JSON"], dict):
@@ -270,25 +256,25 @@ def get_all_pending_requests(
except Exception as e:
logger.error(f"解析ProposedData_JSON字段失败: {e}")
row_dict["ProposedData_JSON"] = {}
-
+
result_list.append(row_dict)
-
+
return result_list
-
+
def get_filtered_requests(
- self,
- conn: Connection,
- *,
- store_id: Optional[int] = None,
- user_id: Optional[int] = None,
- request_type: Optional[str] = None,
- status: Optional[str] = None,
- admin_id: Optional[int] = None,
- start_date: Optional[datetime.datetime] = None,
- end_date: Optional[datetime.datetime] = None,
- limit: int = 20,
- offset: int = 0,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ store_id: Optional[int] = None,
+ user_id: Optional[int] = None,
+ request_type: Optional[str] = None,
+ status: Optional[str] = None,
+ admin_id: Optional[int] = None,
+ start_date: Optional[datetime.datetime] = None,
+ end_date: Optional[datetime.datetime] = None,
+ limit: int = 20,
+ offset: int = 0,
+ actor_id: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
获取根据多种条件筛选的请求列表
@@ -309,43 +295,41 @@ def get_filtered_requests(
# 构建查询条件
conditions = []
- params = {
- "limit": limit,
- "offset": offset
- }
+ params = {"limit": limit, "offset": offset}
if store_id is not None:
conditions.append("StoreID = :store_id")
params["store_id"] = store_id
-
+
if user_id is not None:
conditions.append("RequestingUserID = :user_id")
params["user_id"] = user_id
-
+
if request_type is not None:
conditions.append("RequestType = :request_type")
params["request_type"] = request_type
-
+
if status is not None:
conditions.append("Status = :status")
params["status"] = status
-
+
if admin_id is not None:
conditions.append("AdminReviewerID = :admin_id")
params["admin_id"] = admin_id
-
+
if start_date is not None:
conditions.append("CreationTime >= :start_date")
params["start_date"] = start_date
-
+
if end_date is not None:
conditions.append("CreationTime <= :end_date")
params["end_date"] = end_date
# 构建完整查询语句
where_clause = "" if not conditions else f"WHERE {' AND '.join(conditions)}"
-
- select_stmt = text(f"""
+
+ select_stmt = text(
+ f"""
SELECT
ChangeRequestID, StoreID, RequestingUserID, RequestType,
ProposedData_JSON, Status, SubmitterNotes,
@@ -355,14 +339,15 @@ def get_filtered_requests(
{where_clause}
ORDER BY LastUpdatedDate DESC
LIMIT :limit OFFSET :offset
- """)
+ """
+ )
results = conn.execute(select_stmt, params).fetchall()
-
+
result_list = []
for row in results:
row_dict = dict(row._mapping)
-
+
# 处理JSON字段
if row_dict.get("ProposedData_JSON"):
if not isinstance(row_dict["ProposedData_JSON"], dict):
@@ -371,11 +356,11 @@ def get_filtered_requests(
except Exception as e:
logger.error(f"解析ProposedData_JSON字段失败: {e}")
row_dict["ProposedData_JSON"] = {}
-
+
result_list.append(row_dict)
-
+
return result_list
-
+
def create_change_request(
self,
conn: Connection,
@@ -385,11 +370,11 @@ def create_change_request(
proposed_data_json: Dict[str, Any] = None,
proposed_data: Dict[str, Any] = None,
submitter_notes: Optional[str] = None,
- actor_id: Optional[int] = None
+ actor_id: Optional[int] = None,
) -> Dict[str, Any]:
"""
创建新的店铺变更请求
-
+
:param conn: 数据库连接
:param requesting_user_id: 请求用户ID
:param request_type: 请求类型
@@ -403,17 +388,17 @@ def create_change_request(
# 向后兼容:如果提供了proposed_data而不是proposed_data_json,使用proposed_data
if proposed_data_json is None and proposed_data is not None:
proposed_data_json = proposed_data
- # 检查用户角色
+ # 检查用户角色
try:
user_query = """
SELECT UserRole as Role FROM User WHERE UserID = :user_id
"""
user_result = conn.execute(text(user_query), {"user_id": requesting_user_id})
user = user_result.fetchone()
-
+
if not user:
raise ValueError(f"User {requesting_user_id} not found")
-
+
# 如果是普通用户,更新为商家角色
if user.Role == "customer":
update_role_query = """
@@ -425,11 +410,11 @@ def create_change_request(
except Exception as e:
# 忽略测试期间可能出现的错误
logger.warning(f"检查/更新用户角色时出错: {e}")
-
+
# 将提议数据转换为JSON字符串
if isinstance(proposed_data_json, dict):
proposed_data_json = json.dumps(proposed_data_json)
-
+
# 创建变更请求
query = """
INSERT INTO StoreChangeRequest (
@@ -440,7 +425,7 @@ def create_change_request(
'PENDING_APPROVAL', :submitter_notes
)
"""
-
+
result = conn.execute(
text(query),
{
@@ -448,29 +433,25 @@ def create_change_request(
"requesting_user_id": requesting_user_id,
"request_type": request_type,
"proposed_data_json": proposed_data_json,
- "submitter_notes": submitter_notes
- }
+ "submitter_notes": submitter_notes,
+ },
)
-
+
# 获取新插入记录的ID
request_id = result.lastrowid
-
+
# 获取完整的记录
- return self.get_change_request_by_id(
- conn=conn,
- request_id=request_id,
- actor_id=actor_id
- )
-
+ return self.get_change_request_by_id(conn=conn, request_id=request_id, actor_id=actor_id)
+
def update_request_status(
- self,
- conn: Connection,
- *,
- request_id: int,
- status: str,
- admin_id: Optional[int] = None,
- admin_notes: Optional[str] = None,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ request_id: int,
+ status: str,
+ admin_id: Optional[int] = None,
+ admin_notes: Optional[str] = None,
+ actor_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
更新请求状态
@@ -485,53 +466,44 @@ def update_request_status(
self._set_actor_session_variable(conn, actor_id)
# 检查请求是否存在
- request = self.get_change_request_by_id(
- conn=conn,
- request_id=request_id,
- actor_id=actor_id
- )
-
+ request = self.get_change_request_by_id(conn=conn, request_id=request_id, actor_id=actor_id)
+
if not request:
return None
# 构建更新语句
update_fields = ["Status = :status"]
- params = {
- "request_id": request_id,
- "status": status
- }
-
+ params = {"request_id": request_id, "status": status}
+
if admin_id is not None:
update_fields.append("AdminReviewerID = :admin_id")
update_fields.append("ReviewTimestamp = NOW()")
params["admin_id"] = admin_id
-
+
if admin_notes is not None:
update_fields.append("AdminNotes = :admin_notes")
params["admin_notes"] = admin_notes
-
- update_stmt = text(f"""
+
+ update_stmt = text(
+ f"""
UPDATE {self.table_name}
SET {', '.join(update_fields)}
WHERE ChangeRequestID = :request_id
- """)
+ """
+ )
conn.execute(update_stmt, params)
-
+
# 返回更新后的请求信息
- return self.get_change_request_by_id(
- conn=conn,
- request_id=request_id,
- actor_id=actor_id
- )
-
+ return self.get_change_request_by_id(conn=conn, request_id=request_id, actor_id=actor_id)
+
def update_request(
- self,
- conn: Connection,
- *,
- request_id: int,
- update_data: Dict[str, Any],
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ request_id: int,
+ update_data: Dict[str, Any],
+ actor_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
更新请求信息
@@ -544,66 +516,60 @@ def update_request(
self._set_actor_session_variable(conn, actor_id)
# 检查请求是否存在
- request = self.get_change_request_by_id(
- conn=conn,
- request_id=request_id,
- actor_id=actor_id
- )
-
+ request = self.get_change_request_by_id(conn=conn, request_id=request_id, actor_id=actor_id)
+
if not request:
return None
-
+
# 构建更新语句
update_fields = []
params = {"request_id": request_id}
-
+
# 定义允许更新的字段
allowed_fields = {
"requesttype": "RequestType",
"proposeddata_json": "ProposedData_JSON",
"status": "Status",
"submitternotes": "SubmitterNotes",
- "storeid": "StoreID"
+ "storeid": "StoreID",
}
-
+
for key, value in update_data.items():
key_lower = key.lower()
if key_lower in allowed_fields:
db_field = allowed_fields[key_lower]
-
+
# 特殊处理JSON字段
if db_field == "ProposedData_JSON" and value is not None:
if isinstance(value, dict):
value = json.dumps(value)
-
+
update_fields.append(f"{db_field} = :{key_lower}")
params[key_lower] = value
-
+
if not update_fields:
# 没有可更新的字段
return request
-
- update_stmt = text(f"""
+
+ update_stmt = text(
+ f"""
UPDATE {self.table_name}
SET {', '.join(update_fields)}
WHERE ChangeRequestID = :request_id
- """)
+ """
+ )
conn.execute(update_stmt, params)
-
+
# 返回更新后的请求信息
- return self.get_change_request_by_id(
- conn=conn,
- request_id=request_id,
- actor_id=actor_id
- )
-
+ return self.get_change_request_by_id(conn=conn, request_id=request_id, actor_id=actor_id)
+
def cancel_request(
- self,
- conn: Connection,
- *,
- request_id: int,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ request_id: int,
+ actor_id: Optional[int] = None,
) -> bool:
"""
取消请求(用户自行取消)
@@ -615,26 +581,21 @@ def cancel_request(
self._set_actor_session_variable(conn, actor_id)
# 先检查请求是否存在且状态为待审核
- request = self.get_change_request_by_id(
- conn=conn,
- request_id=request_id,
- actor_id=actor_id
- )
-
+ request = self.get_change_request_by_id(conn=conn, request_id=request_id, actor_id=actor_id)
+
if not request or request.get("Status") != "PENDING_APPROVAL":
return False
- update_stmt = text(f"""
+ update_stmt = text(
+ f"""
UPDATE {self.table_name}
SET Status = 'CANCELLED_BY_USER'
WHERE ChangeRequestID = :request_id AND Status = 'PENDING_APPROVAL'
- """)
-
- result = conn.execute(
- update_stmt,
- {"request_id": request_id}
+ """
)
-
+
+ result = conn.execute(update_stmt, {"request_id": request_id})
+
return result.rowcount > 0
diff --git a/src/backend/app/crud/store_change_request_crud_v2.py b/src/backend/app/crud/store_change_request_crud_v2.py
index fff4728..c91e2b4 100644
--- a/src/backend/app/crud/store_change_request_crud_v2.py
+++ b/src/backend/app/crud/store_change_request_crud_v2.py
@@ -72,17 +72,13 @@ def _process_row(self, row: Optional[Any]) -> Optional[Dict[str, Any]]:
if not row:
return None
row_dict = dict(row._mapping) # type: ignore
- row_dict["ProposedData_JSON"] = self._deserialize_proposed_data(
- row_dict.get("ProposedData_JSON")
- )
+ row_dict["ProposedData_JSON"] = self._deserialize_proposed_data(row_dict.get("ProposedData_JSON"))
return row_dict
def _process_rows(self, rows: Iterable[Any]) -> List[Dict[str, Any]]:
return [self._process_row(row) for row in rows if row] # type: ignore
- def get_request_by_id(
- self, conn: Connection, *, change_request_id: int
- ) -> Optional[Dict[str, Any]]:
+ def get_request_by_id(self, conn: Connection, *, change_request_id: int) -> Optional[Dict[str, Any]]:
logger.debug(f"Getting StoreChangeRequest by ID {change_request_id}")
select_sql = self._get_request_base_query() + " WHERE ChangeRequestID = :ChangeRequestID"
select_stmt = text(select_sql)
@@ -96,9 +92,7 @@ def get_request_by_id(
def get_request_by_id_for_requesting_user(
self, conn: Connection, *, change_request_id: int, requesting_user_id: int
) -> Optional[Dict[str, Any]]:
- logger.debug(
- f"Getting StoreChangeRequest ID {change_request_id} for RequestingUserID {requesting_user_id}"
- )
+ logger.debug(f"Getting StoreChangeRequest ID {change_request_id} for RequestingUserID {requesting_user_id}")
select_sql = (
self._get_request_base_query()
+ " WHERE ChangeRequestID = :ChangeRequestID AND RequestingUserID = :RequestingUserID"
@@ -170,9 +164,7 @@ def get_request_list(
params["Limit"] = limit
params["Offset"] = offset
- select_sql = (
- self._get_request_base_query() + f" {where_sql} ORDER BY CreationTime DESC {limit_sql}"
- )
+ select_sql = self._get_request_base_query() + f" {where_sql} ORDER BY CreationTime DESC {limit_sql}"
select_stmt = text(select_sql)
try:
@@ -222,9 +214,7 @@ def _create_request_internal(
f"lastrowid not available after creating {request_type} request for RequestingUserID {requesting_user_id}."
)
return None
- logger.info(
- f"{request_type} request created with ID {new_request_id} by ActorID {actor_id}."
- )
+ logger.info(f"{request_type} request created with ID {new_request_id} by ActorID {actor_id}.")
return self.get_request_by_id(conn, change_request_id=new_request_id)
except exc.IntegrityError as e:
logger.error(f"Integrity error creating {request_type} request: {e}")
@@ -242,9 +232,7 @@ def create_request_create_store(
proposed_data_json: Dict[str, Any],
actor_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
- logger.info(
- f"ActorID {actor_id} creating STORE_CREATE request for RequestingUserID {requesting_user_id}"
- )
+ logger.info(f"ActorID {actor_id} creating STORE_CREATE request for RequestingUserID {requesting_user_id}")
return self._create_request_internal(
conn=conn,
requesting_user_id=requesting_user_id,
@@ -307,9 +295,7 @@ def cancel_request_by_user(
change_request_id: int,
actor_id: Optional[int] = None,
) -> bool:
- logger.info(
- f"ActorID {actor_id} attempting to cancel StoreChangeRequestID {change_request_id}."
- )
+ logger.info(f"ActorID {actor_id} attempting to cancel StoreChangeRequestID {change_request_id}.")
self._set_actor_session_variable(conn, actor_id)
update_stmt = text(
@@ -340,9 +326,7 @@ def cancel_request_by_user(
)
return False
except Exception as e:
- logger.error(
- f"Error cancelling StoreChangeRequestID {change_request_id} by ActorID {actor_id}: {e}"
- )
+ logger.error(f"Error cancelling StoreChangeRequestID {change_request_id} by ActorID {actor_id}: {e}")
return False
def update_request_by_admin(
@@ -358,9 +342,7 @@ def update_request_by_admin(
logger.info(
f"AdminID {admin_reviewer_id} (ActorID {actor_id}) updating StoreChangeRequestID {change_request_id} to Status {status}."
)
- self._set_actor_session_variable(
- conn, actor_id if actor_id is not None else admin_reviewer_id
- )
+ self._set_actor_session_variable(conn, actor_id if actor_id is not None else admin_reviewer_id)
set_clauses: List[str] = [
"Status = :Status",
diff --git a/src/backend/app/crud/store_crud.py b/src/backend/app/crud/store_crud.py
index ad0b947..000e2b9 100644
--- a/src/backend/app/crud/store_crud.py
+++ b/src/backend/app/crud/store_crud.py
@@ -36,29 +36,24 @@ def _set_actor_session_variable(conn: Connection, actor_id: Optional[int]):
"""
try:
if actor_id is not None:
- conn.execute(
- text("SET @actor_id = :actor_id"),
- {"actor_id": actor_id}
- )
+ conn.execute(text("SET @actor_id = :actor_id"), {"actor_id": actor_id})
else:
- conn.execute(
- text("SET @actor_id = NULL")
- )
+ conn.execute(text("SET @actor_id = NULL"))
except Exception as e:
logger.error(f"Error setting actor session variable: {e}")
# Consider re-raising or handling
def create_store(
- self,
- conn: Connection,
- *,
- store_name: str,
- owner_user_id: int,
- description: Optional[str], # Can be None
- logo_url: Optional[str], # Can be None
- store_status: StoreStatusEnum,
- creation_date: datetime.datetime,
- actor_id: Optional[int]
+ self,
+ conn: Connection,
+ *,
+ store_name: str,
+ owner_user_id: int,
+ description: Optional[str], # Can be None
+ logo_url: Optional[str], # Can be None
+ store_status: StoreStatusEnum,
+ creation_date: datetime.datetime,
+ actor_id: Optional[int],
) -> Optional[Dict[str, Any]]:
"""
在数据库中创建一条新的店铺记录。
@@ -74,12 +69,11 @@ def create_store(
:param actor_id: 执行此操作的用户ID。
:return: 创建成功后的店铺信息字典 (包含 StoreID),如果创建失败则返回 None。
"""
- logger.info(
- f"ActorID {actor_id} attempting to create store: {store_name} for OwnerUserID {owner_user_id}"
- )
+ logger.info(f"ActorID {actor_id} attempting to create store: {store_name} for OwnerUserID {owner_user_id}")
self._set_actor_session_variable(conn, actor_id)
- insert_stmt = text(f"""
+ insert_stmt = text(
+ f"""
INSERT INTO {self.table_name} (
StoreName, OwnerUserID, Description, LogoURL, StoreStatus,
CreationDate, LastUpdatedDate
@@ -87,17 +81,21 @@ def create_store(
:StoreName, :OwnerUserID, :Description, :LogoURL, :StoreStatus,
:CreationDate, :CreationDate
)
- """)
+ """
+ )
try:
- result = conn.execute(insert_stmt, {
- "StoreName": store_name,
- "OwnerUserID": owner_user_id,
- "Description": description, # Pass None directly if that's the value
- "LogoURL": logo_url, # Pass None directly if that's the value
- "StoreStatus": store_status.value,
- "CreationDate": creation_date
- })
+ result = conn.execute(
+ insert_stmt,
+ {
+ "StoreName": store_name,
+ "OwnerUserID": owner_user_id,
+ "Description": description, # Pass None directly if that's the value
+ "LogoURL": logo_url, # Pass None directly if that's the value
+ "StoreStatus": store_status.value,
+ "CreationDate": creation_date,
+ },
+ )
new_store_id = result.lastrowid
if new_store_id is None:
@@ -109,28 +107,26 @@ def create_store(
)
return self.get_store_by_id(conn, store_id=new_store_id, actor_id=actor_id)
except exc.IntegrityError as e:
- logger.error(
- f"Integrity error creating store for OwnerUserID {owner_user_id}: {e}"
- )
+ logger.error(f"Integrity error creating store for OwnerUserID {owner_user_id}: {e}")
return None
except Exception as e:
- logger.error(
- f"Unexpected error creating store for OwnerUserID {owner_user_id}: {e}"
- )
+ logger.error(f"Unexpected error creating store for OwnerUserID {owner_user_id}: {e}")
return None
def get_store_by_id(
- self, conn: Connection, *, store_id: int, actor_id: Optional[int] = None
+ self, conn: Connection, *, store_id: int, actor_id: Optional[int] = None
) -> Optional[Dict[str, Any]]:
"""
根据 StoreID 检索特定的店铺信息。
"""
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT StoreID, StoreName, OwnerUserID, Description, LogoURL,
StoreStatus, CreationDate, LastUpdatedDate
FROM {self.table_name}
WHERE StoreID = :StoreID
- """)
+ """
+ )
try:
result = conn.execute(select_stmt, {"StoreID": store_id}).fetchone()
return dict(result._mapping) if result else None # type: ignore
@@ -139,70 +135,65 @@ def get_store_by_id(
return None
def get_stores_by_owner_user_id(
- self,
- conn: Connection,
- *,
- owner_user_id: int,
- offset: int = 0,
- limit: int = 20,
- actor_id: Optional[int] = None
+ self, conn: Connection, *, owner_user_id: int, offset: int = 0, limit: int = 20, actor_id: Optional[int] = None
) -> List[Dict[str, Any]]:
"""
检索指定用户拥有的所有店铺列表,支持分页。
"""
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT StoreID, StoreName, OwnerUserID, Description, LogoURL,
StoreStatus, CreationDate, LastUpdatedDate
FROM {self.table_name}
WHERE OwnerUserID = :OwnerUserID
ORDER BY CreationDate DESC
LIMIT :Limit OFFSET :Offset
- """)
+ """
+ )
try:
- results = conn.execute(select_stmt, {
- "OwnerUserID": owner_user_id,
- "Limit": limit,
- "Offset": offset
- }).fetchall()
+ results = conn.execute(
+ select_stmt, {"OwnerUserID": owner_user_id, "Limit": limit, "Offset": offset}
+ ).fetchall()
return [dict(row._mapping) for row in results] # type: ignore
except Exception as e:
logger.error(f"Error getting stores for OwnerUserID {owner_user_id}: {e}")
return []
def get_stores_by_owner_user_id_all(
- self,
- conn: Connection,
- *,
- owner_user_id: int,
- actor_id: Optional[int] = None
+ self, conn: Connection, *, owner_user_id: int, actor_id: Optional[int] = None
) -> List[Dict[str, Any]]:
"""
检索指定用户拥有的所有店铺列表,支持分页。
"""
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT StoreID, StoreName, OwnerUserID, Description, LogoURL,
StoreStatus, CreationDate, LastUpdatedDate
FROM {self.table_name}
WHERE OwnerUserID = :OwnerUserID
ORDER BY CreationDate DESC
- """)
+ """
+ )
try:
- results = conn.execute(select_stmt, {
- "OwnerUserID": owner_user_id,
- }).fetchall()
+ results = conn.execute(
+ select_stmt,
+ {
+ "OwnerUserID": owner_user_id,
+ },
+ ).fetchall()
return [dict(row._mapping) for row in results] # type: ignore
except Exception as e:
logger.error(f"Error getting stores for OwnerUserID {owner_user_id}: {e}")
return []
def get_all_stores_page(
- self,
- conn: Connection,
- *,
- store_status: Optional[StoreStatusEnum] = None, # 可选的筛选条件
- offset: int = 0,
- limit: int = 20, # 默认每页20条
- actor_id: Optional[int] = None # 用于可能的审计或未来扩展
+ self,
+ conn: Connection,
+ *,
+ store_status: Optional[StoreStatusEnum] = None, # 可选的筛选条件
+ offset: int = 0,
+ limit: int = 20, # 默认每页20条
+ actor_id: Optional[int] = None, # 用于可能的审计或未来扩展
) -> List[Dict[str, Any]]:
"""
检索所有店铺的列表,支持按店铺状态筛选和分页。
@@ -231,14 +222,16 @@ def get_all_stores_page(
if where_clauses:
where_sql = "WHERE " + " AND ".join(where_clauses)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT StoreID, StoreName, OwnerUserID, Description, LogoURL,
StoreStatus, CreationDate, LastUpdatedDate
FROM {self.table_name}
{where_sql}
ORDER BY CreationDate DESC
LIMIT :Limit OFFSET :Offset
- """)
+ """
+ )
try:
results = conn.execute(select_stmt, params).fetchall()
return [dict(row._mapping) for row in results] # type: ignore
@@ -247,11 +240,11 @@ def get_all_stores_page(
return []
def get_all_stores(
- self,
- conn: Connection,
- *,
- store_status: Optional[StoreStatusEnum] = None, # 可选的筛选条件
- actor_id: Optional[int] = None # 用于可能的审计或未来扩展
+ self,
+ conn: Connection,
+ *,
+ store_status: Optional[StoreStatusEnum] = None, # 可选的筛选条件
+ actor_id: Optional[int] = None, # 用于可能的审计或未来扩展
) -> List[Dict[str, Any]]:
"""
检索所有店铺的列表,支持按店铺状态筛选。
@@ -262,9 +255,7 @@ def get_all_stores(
:param actor_id: (可选) 执行此操作的用户ID。
:return: 店铺信息字典的列表。
"""
- logger.info(
- f"ActorID {actor_id} attempting to get all stores. Filter status: {store_status}"
- )
+ logger.info(f"ActorID {actor_id} attempting to get all stores. Filter status: {store_status}")
# _set_actor_session_variable(conn, actor_id) # 通常GET操作不需要,除非有特定触发器
params: Dict[str, Any] = {}
@@ -278,13 +269,15 @@ def get_all_stores(
if where_clauses:
where_sql = "WHERE " + " AND ".join(where_clauses)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT StoreID, StoreName, OwnerUserID, Description, LogoURL,
StoreStatus, CreationDate, LastUpdatedDate
FROM {self.table_name}
{where_sql}
ORDER BY CreationDate DESC
- """)
+ """
+ )
try:
results = conn.execute(select_stmt, params).fetchall()
return [dict(row._mapping) for row in results] # type: ignore
@@ -293,15 +286,15 @@ def get_all_stores(
return []
def update_store(
- self,
- conn: Connection,
- *,
- store_id: int,
- actor_id: int,
- store_name: Optional[str] = None,
- description: Optional[str] = None,
- logo_url: Optional[str] = None,
- store_status: Optional[StoreStatusEnum] = None
+ self,
+ conn: Connection,
+ *,
+ store_id: int,
+ actor_id: int,
+ store_name: Optional[str] = None,
+ description: Optional[str] = None,
+ logo_url: Optional[str] = None,
+ store_status: Optional[StoreStatusEnum] = None,
) -> Optional[Dict[str, Any]]:
"""
更新现有店铺的信息。
@@ -317,9 +310,7 @@ def update_store(
:param store_status: (可选) 新的店铺状态 (StoreStatusEnum)。如果为 None,则不更新。
:return: 更新成功后的店铺信息字典,如果店铺未找到或更新失败则返回 None。
"""
- logger.info(
- f"ActorID {actor_id} attempting to update StoreID {store_id}."
- )
+ logger.info(f"ActorID {actor_id} attempting to update StoreID {store_id}.")
self._set_actor_session_variable(conn, actor_id)
update_fields = []
diff --git a/src/backend/app/crud/user_crud.py b/src/backend/app/crud/user_crud.py
index 5ceb5b9..9be20f3 100644
--- a/src/backend/app/crud/user_crud.py
+++ b/src/backend/app/crud/user_crud.py
@@ -31,21 +31,16 @@ def _set_actor_session_variable(conn: Connection, actor_id: Optional[int]):
:return:
"""
if actor_id is not None:
- conn.execute(
- text("SET @actor_id = :actor_id"),
- {"actor_id": actor_id}
- )
+ conn.execute(text("SET @actor_id = :actor_id"), {"actor_id": actor_id})
else:
- conn.execute(
- text("SET @actor_id = NULL")
- )
+ conn.execute(text("SET @actor_id = NULL"))
def get_user_by_id(
- self,
- conn: Connection,
- *,
- user_id: int,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ user_id: int,
+ actor_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
Get a user by their ID.
@@ -56,25 +51,24 @@ def get_user_by_id(
"""
self._set_actor_session_variable(conn, actor_id)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
UserID, Username, Email, PhoneNumber, UserRole, RegistrationDate, LastLoginDate, DefaultAddressID, AccountStatus
FROM {self.table_name} WHERE UserID = :user_id
- """)
+ """
+ )
- result = conn.execute(
- select_stmt,
- {"user_id": user_id}
- ).fetchone()
+ result = conn.execute(select_stmt, {"user_id": user_id}).fetchone()
return dict(result._mapping) if result else None
def get_user_by_username(
- self,
- conn: Connection,
- *,
- username: str,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ username: str,
+ actor_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
Get a user by their username.
@@ -85,26 +79,25 @@ def get_user_by_username(
"""
self._set_actor_session_variable(conn, actor_id)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
UserID, Username, Email, PhoneNumber, UserRole,
RegistrationDate, LastLoginDate, AccountStatus
FROM {self.table_name} WHERE Username = :username
- """)
+ """
+ )
- result = conn.execute(
- select_stmt,
- {"username": username}
- ).fetchone()
+ result = conn.execute(select_stmt, {"username": username}).fetchone()
return dict(result._mapping) if result else None
def get_user_by_email(
- self,
- conn: Connection,
- *,
- email: str | EmailStr,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ email: str | EmailStr,
+ actor_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
Get a user by their email.
@@ -115,26 +108,25 @@ def get_user_by_email(
"""
self._set_actor_session_variable(conn, actor_id)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
UserID, Username, Email, PhoneNumber, UserRole,
RegistrationDate, LastLoginDate, AccountStatus
FROM {self.table_name} WHERE Email = :email
- """)
+ """
+ )
- result = conn.execute(
- select_stmt,
- {"email": email}
- ).fetchone()
+ result = conn.execute(select_stmt, {"email": email}).fetchone()
return dict(result._mapping) if result else None
def get_user_with_password_by_username(
- self,
- conn: Connection,
- *,
- username: str,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ username: str,
+ actor_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
Get a user by their username, including the hashed password.
@@ -145,26 +137,25 @@ def get_user_with_password_by_username(
"""
self._set_actor_session_variable(conn, actor_id)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
UserID, Username, Email, PhoneNumber, PasswordHash,
UserRole, RegistrationDate, LastLoginDate, AccountStatus
FROM {self.table_name} WHERE Username = :username
- """)
+ """
+ )
- result = conn.execute(
- select_stmt,
- {"username": username}
- ).fetchone()
+ result = conn.execute(select_stmt, {"username": username}).fetchone()
return dict(result._mapping) if result else None
def get_user_with_password_by_email(
- self,
- conn: Connection,
- *,
- email: str | EmailStr,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ email: str | EmailStr,
+ actor_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
Get a user by their email, including the hashed password.
@@ -175,29 +166,28 @@ def get_user_with_password_by_email(
"""
self._set_actor_session_variable(conn, actor_id)
- select_stmt = text(f"""
+ select_stmt = text(
+ f"""
SELECT
UserID, Username, Email, PhoneNumber, PasswordHash,
UserRole, RegistrationDate, LastLoginDate, AccountStatus
FROM {self.table_name} WHERE Email = :email
- """)
+ """
+ )
- result = conn.execute(
- select_stmt,
- {"email": email}
- ).fetchone()
+ result = conn.execute(select_stmt, {"email": email}).fetchone()
return dict(result._mapping) if result else None
def create_user(
- self,
- conn: Connection,
- *,
- username: str,
- email: Optional[str | EmailStr] = None,
- phone_number: Optional[str] = None,
- hashed_password: str,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ username: str,
+ email: Optional[str | EmailStr] = None,
+ phone_number: Optional[str] = None,
+ hashed_password: str,
+ actor_id: Optional[int] = None,
) -> Dict[str, Any]:
"""
Create a new user in the database.
@@ -211,39 +201,32 @@ def create_user(
"""
self._set_actor_session_variable(conn, actor_id)
- insert_stmt = text(f"""
+ insert_stmt = text(
+ f"""
INSERT INTO {self.table_name} (Username, Email, PhoneNumber, PasswordHash)
VALUES (:username, :email, :phone_number, :hashed_password)
- """)
+ """
+ )
conn.execute(
insert_stmt,
- {
- "username": username,
- "email": email,
- "phone_number": phone_number,
- "hashed_password": hashed_password
- }
+ {"username": username, "email": email, "phone_number": phone_number, "hashed_password": hashed_password},
)
# Fetch the created user
- result = self.get_user_by_username(
- conn,
- username=username,
- actor_id=actor_id
- )
+ result = self.get_user_by_username(conn, username=username, actor_id=actor_id)
if result is None:
logger.error("User creation failed, user not found after insertion.")
raise Exception("User creation failed, user not found after insertion.")
return result
def update_user_default_address_id(
- self,
- conn: Connection,
- *,
- user_id: int,
- default_address_id: Optional[int],
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ user_id: int,
+ default_address_id: Optional[int],
+ actor_id: Optional[int] = None,
) -> bool:
"""
Set the default address ID for a user.
@@ -255,37 +238,35 @@ def update_user_default_address_id(
"""
self._set_actor_session_variable(conn, actor_id)
- update_stmt = text(f"""
+ update_stmt = text(
+ f"""
UPDATE {self.table_name}
SET DefaultAddressID = :default_address_id
WHERE UserID = :user_id
- """)
+ """
+ )
try:
- result = conn.execute(
- update_stmt,
- {
- "default_address_id": default_address_id,
- "user_id": user_id
- }
- )
+ result = conn.execute(update_stmt, {"default_address_id": default_address_id, "user_id": user_id})
except Exception as e:
logger.error(f"Failed to update DefaultAddressID for UserID: {user_id}. Error: {e}")
return False
if result.rowcount == 0:
- logger.error(f"Failed to update DefaultAddressID for UserID: {user_id}. "
- f"This may be due to the user not existing.")
+ logger.error(
+ f"Failed to update DefaultAddressID for UserID: {user_id}. "
+ f"This may be due to the user not existing."
+ )
return False
return True
def update_user_role(
- self,
- conn: Connection,
- *,
- user_id: int,
- new_role: str,
- actor_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ user_id: int,
+ new_role: str,
+ actor_id: Optional[int] = None,
) -> bool:
"""
Update the role of a user. TODO: testme
@@ -297,29 +278,26 @@ def update_user_role(
"""
self._set_actor_session_variable(conn, actor_id)
- update_stmt = text(f"""
+ update_stmt = text(
+ f"""
UPDATE {self.table_name}
SET UserRole = :new_role
WHERE UserID = :user_id
- """)
+ """
+ )
try:
- result = conn.execute(
- update_stmt,
- {
- "new_role": new_role,
- "user_id": user_id
- }
- )
+ result = conn.execute(update_stmt, {"new_role": new_role, "user_id": user_id})
except Exception as e:
logger.error(f"Failed to update UserRole for UserID: {user_id}. Error: {e}")
return False
if result.rowcount == 0:
- logger.error(f"Failed to update UserRole for UserID: {user_id}. "
- f"This may be due to the user not existing.")
+ logger.error(
+ f"Failed to update UserRole for UserID: {user_id}. " f"This may be due to the user not existing."
+ )
return False
return True
-user_crud_instance = UserCRUD.get_instance() # Still expose a global variable
+user_crud_instance = UserCRUD.get_instance() # Still expose a global variable
diff --git a/src/backend/app/crud/user_session_crud.py b/src/backend/app/crud/user_session_crud.py
index 2d6ef8d..7893eb4 100644
--- a/src/backend/app/crud/user_session_crud.py
+++ b/src/backend/app/crud/user_session_crud.py
@@ -82,14 +82,10 @@ def create_session(
# but it's usually acceptable. Or fetch by token and other known values)
created_session_data = self.get_session_by_token(conn, token=session_token)
if not created_session_data:
- raise Exception(
- "Session creation verification failed (session not found immediately after insert)."
- )
+ raise Exception("Session creation verification failed (session not found immediately after insert).")
# print(f"INFO: UserSessionCRUD - Session created for UserID '{user_id}' with Token '{session_token[:8]}...'.")
- self.logger.info(
- f"Created session for UserID '{user_id}' with Token '{session_token[:8]}...'."
- )
+ self.logger.info(f"Created session for UserID '{user_id}' with Token '{session_token[:8]}...'.")
return created_session_data
def get_session_by_token(self, conn: Connection, *, token: str) -> Optional[Dict[str, Any]]:
@@ -110,9 +106,7 @@ def get_session_by_token(self, conn: Connection, *, token: str) -> Optional[Dict
result = conn.execute(stmt, {"token": token}).fetchone()
return dict(result._mapping) if result else None
- def get_active_user_id_by_token_and_update_access(
- self, conn: Connection, *, token: str
- ) -> Optional[int]:
+ def get_active_user_id_by_token_and_update_access(self, conn: Connection, *, token: str) -> Optional[int]:
"""
Retrieves the UserID for an active (non-expired) session token
and updates its LastAccessedAt timestamp.
@@ -127,9 +121,7 @@ def get_active_user_id_by_token_and_update_access(
# Let's go with Option 1 with a check, common for web apps.
# For higher concurrency needs, a stored procedure or specific locking might be better.
- current_time_utc = datetime.datetime.now(
- datetime.timezone.utc
- ) # Ensure timezone aware if DB stores UTC
+ current_time_utc = datetime.datetime.now(datetime.timezone.utc) # Ensure timezone aware if DB stores UTC
# Step 1: Find the session and check if it's active and not expired
select_stmt = text(
@@ -235,26 +227,20 @@ def get_active_user_id_by_token_and_update_access_and_expiration(
return True
- def delete_session_by_token(
- self, conn: Connection, *, token: str, actor_user_id: Optional[int] = None
- ) -> bool:
+ def delete_session_by_token(self, conn: Connection, *, token: str, actor_user_id: Optional[int] = None) -> bool:
"""
Deletes a session by its token (e.g., for logout).
actor_user_id is the user performing the deletion (could be self or an admin).
Returns True if a session was deleted, False otherwise.
"""
- self._set_actor_session_variable(
- conn, actor_user_id
- ) # If session deletion needs to be audited via trigger
+ self._set_actor_session_variable(conn, actor_user_id) # If session deletion needs to be audited via trigger
stmt = text(f"DELETE FROM {self.table_name} WHERE SessionToken = :token")
result = conn.execute(stmt, {"token": token})
deleted_count = result.rowcount
if deleted_count > 0:
- print(
- f"INFO: UserSessionCRUD - Session with Token '{token[:8]}...' deleted by actor_id '{actor_user_id}'."
- )
+ print(f"INFO: UserSessionCRUD - Session with Token '{token[:8]}...' deleted by actor_id '{actor_user_id}'.")
return deleted_count > 0
def delete_all_sessions_for_user(
diff --git a/src/backend/app/dependencies/auth_deps.py b/src/backend/app/dependencies/auth_deps.py
index 0f08cbd..5fb2939 100644
--- a/src/backend/app/dependencies/auth_deps.py
+++ b/src/backend/app/dependencies/auth_deps.py
@@ -8,19 +8,27 @@
from backend.app.schemas.user_schema import UserResponse
from backend.app.services.auth_service import AuthService # ⭐ 导入 AuthService
from backend.app.services.permission_checker import PermissionChecker
-from backend.app.utils.exceptions import InvalidTokenException, UserNotFoundException, \
- AuthenticationException # 导入自定义异常
+from backend.app.utils.exceptions import (
+ InvalidTokenException,
+ UserNotFoundException,
+ AuthenticationException,
+) # 导入自定义异常
from . import get_db_connection
-from .service_deps import get_auth_service, get_token_decoder, get_permission_checker # 如果 auth_service 依赖注入单独管理
+from .service_deps import (
+ get_auth_service,
+ get_token_decoder,
+ get_permission_checker,
+) # 如果 auth_service 依赖注入单独管理
from loguru import logger
# --- Dependency Injection for Header Authentication ---
+
async def get_token_from_auth_header(
- authorization: Optional[str] = Header(None, description="用于用户认证的 Bearer Token:`Bearer `"),
+ authorization: Optional[str] = Header(None, description="用于用户认证的 Bearer Token:`Bearer `"),
) -> str:
"""
依赖项:从请求头中获取 Bearer Token。
@@ -57,8 +65,7 @@ async def get_token_from_auth_header(
async def get_current_user_payload(
- token: str = Depends(get_token_from_auth_header),
- auth_service: AuthService = Depends(get_auth_service)
+ token: str = Depends(get_token_from_auth_header), auth_service: AuthService = Depends(get_auth_service)
) -> TokenPayload:
"""
依赖项:从 Bearer Token 中解码得到 TokenPayload。
@@ -77,14 +84,14 @@ async def get_current_user_payload(
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token.",
- headers={"WWW-Authenticate": "Bearer error=\"invalid_token\""},
+ headers={"WWW-Authenticate": 'Bearer error="invalid_token"'},
) from e
async def get_current_active_user(
- payload: TokenPayload = Depends(get_current_user_payload),
- auth_service: AuthService = Depends(get_auth_service),
- db: Connection = Depends(get_db_connection)
+ payload: TokenPayload = Depends(get_current_user_payload),
+ auth_service: AuthService = Depends(get_auth_service),
+ db: Connection = Depends(get_db_connection),
) -> UserResponse:
try:
user_response: UserResponse
@@ -94,16 +101,18 @@ async def get_current_active_user(
except InvalidTokenException as e:
logger.warning(
f"AuthService validation failed (InvalidTokenException): {e.detail}"
- f" for payload user_id: {payload.user_id if payload else 'N/A'}")
+ f" for payload user_id: {payload.user_id if payload else 'N/A'}"
+ )
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e.detail),
- headers={"WWW-Authenticate": "Bearer error=\"invalid_token\""}, # 可以更具体
+ headers={"WWW-Authenticate": 'Bearer error="invalid_token"'}, # 可以更具体
)
except UserNotFoundException as e:
logger.error(
f"AuthService validation failed (UserNotFoundException): {e.detail}"
- f" for payload user_id: {payload.user_id if payload else 'N/A'}")
+ f" for payload user_id: {payload.user_id if payload else 'N/A'}"
+ )
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e.detail),
@@ -111,22 +120,20 @@ async def get_current_active_user(
)
except AuthenticationException as e:
logger.warning(f"AuthService validation failed (AuthenticationException): {e.detail}")
- raise HTTPException(
- status_code=status.HTTP_403_FORBIDDEN,
- detail=str(e.detail)
- )
+ raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(e.detail))
except Exception as e:
logger.exception(
f"Unexpected error in get_current_active_user for payload user_id"
- f" ({payload.user_id if payload else 'N/A'}): {str(e)}")
+ f" ({payload.user_id if payload else 'N/A'}): {str(e)}"
+ )
raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Internal server error during authentication."
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error during authentication."
)
+
async def get_current_admin(
- current_user: UserResponse = Depends(get_current_active_user),
- permission_checker: PermissionChecker = Depends(get_permission_checker),
+ current_user: UserResponse = Depends(get_current_active_user),
+ permission_checker: PermissionChecker = Depends(get_permission_checker),
) -> UserResponse:
permission_checker.require_admin(current_user)
return current_user
diff --git a/src/backend/app/dependencies/crud_deps.py b/src/backend/app/dependencies/crud_deps.py
index 17e61c8..86caf53 100644
--- a/src/backend/app/dependencies/crud_deps.py
+++ b/src/backend/app/dependencies/crud_deps.py
@@ -16,6 +16,7 @@
# dependency for CRUDs
+
def get_user_crud() -> UserCRUD:
"""
Dependency to get the UserCRUD instance.
@@ -23,6 +24,7 @@ def get_user_crud() -> UserCRUD:
"""
return UserCRUD.get_instance()
+
def get_user_session_crud() -> UserSessionCRUD:
"""
Dependency to get the UserSessionCRUD instance.
@@ -30,6 +32,7 @@ def get_user_session_crud() -> UserSessionCRUD:
"""
return UserSessionCRUD.get_instance()
+
def get_category_crud() -> CategoryCRUD:
"""
Dependency to get the CategoryCRUD instance.
@@ -37,6 +40,7 @@ def get_category_crud() -> CategoryCRUD:
"""
return CategoryCRUD.get_instance()
+
def get_product_crud() -> ProductCRUD:
"""
Dependency to get the ProductCRUD instance.
@@ -44,6 +48,7 @@ def get_product_crud() -> ProductCRUD:
"""
return ProductCRUD.get_instance()
+
def get_cart_item_crud() -> CartItemCRUD:
"""
Dependency to get the CartItemCRUD instance.
@@ -51,6 +56,7 @@ def get_cart_item_crud() -> CartItemCRUD:
"""
return CartItemCRUD.get_instance()
+
def get_address_crud() -> AddressCRUD:
"""
Dependency to get the AddressCRUD instance.
@@ -58,6 +64,7 @@ def get_address_crud() -> AddressCRUD:
"""
return AddressCRUD.get_instance()
+
def get_order_item_crud() -> OrderItemCRUD:
"""
Dependency to get the OrderItemCRUD instance.
@@ -65,6 +72,7 @@ def get_order_item_crud() -> OrderItemCRUD:
"""
return OrderItemCRUD.get_instance()
+
def get_order_crud() -> OrderCRUD:
"""
Dependency to get the OrderCRUD instance.
@@ -72,6 +80,7 @@ def get_order_crud() -> OrderCRUD:
"""
return OrderCRUD.get_instance()
+
def get_payment_transaction_crud() -> PaymentTransactionCRUD:
"""
Dependency to get the PaymentTransactionCRUD instance.
@@ -79,6 +88,7 @@ def get_payment_transaction_crud() -> PaymentTransactionCRUD:
"""
return PaymentTransactionCRUD.get_instance()
+
def get_store_crud() -> StoreCRUD:
"""
Dependency to get the StoreCRUD instance.
@@ -86,6 +96,7 @@ def get_store_crud() -> StoreCRUD:
"""
return StoreCRUD.get_instance()
+
def get_product_change_request_crud2() -> ProductChangeRequestCRUD2:
"""
Dependency to get the ProductChangeRequestCRUD2 instance.
@@ -93,6 +104,7 @@ def get_product_change_request_crud2() -> ProductChangeRequestCRUD2:
"""
return ProductChangeRequestCRUD2.get_instance()
+
def get_store_change_request_crud2() -> StoreChangeRequestCRUD2:
"""
Dependency to get the StoreChangeRequestCRUD2 instance.
diff --git a/src/backend/app/dependencies/db_deps.py b/src/backend/app/dependencies/db_deps.py
index 01b447c..6f22532 100644
--- a/src/backend/app/dependencies/db_deps.py
+++ b/src/backend/app/dependencies/db_deps.py
@@ -5,9 +5,9 @@
from backend.app.core.database import get_engine
-
# dependency injection for getting a new database connection
+
def get_db_connection() -> Generator[Connection, None, None]:
"""
Dependency to get a new database connection.
@@ -21,4 +21,5 @@ def get_db_connection() -> Generator[Connection, None, None]:
if connection:
connection.close()
+
# dependency injection for getting
diff --git a/src/backend/app/dependencies/service_deps.py b/src/backend/app/dependencies/service_deps.py
index 4fd40f6..969a1c3 100644
--- a/src/backend/app/dependencies/service_deps.py
+++ b/src/backend/app/dependencies/service_deps.py
@@ -24,7 +24,8 @@
ProductChangeRequestService,
StoreService,
ProductChangeRequestService2,
- StoreChangeRequestService2, StatisticsService
+ StoreChangeRequestService2,
+ StatisticsService,
)
from backend.app.dependencies.crud_deps import *
from backend.app.crud.store_change_request_crud import get_store_change_request_crud_instance
@@ -38,8 +39,10 @@ def get_permission_checker(user_crud: UserCRUD = Depends(get_user_crud)) -> Perm
"""
return PermissionChecker(user_crud=user_crud)
+
# dependency injection for UserService
+
def get_password_hasher() -> Callable[[str], str]:
"""
Dependency to get the password hasher function.
@@ -57,8 +60,7 @@ def get_password_verifier() -> Callable[[str, str], bool]:
def get_user_service(
- user_crud: UserCRUD = Depends(get_user_crud),
- password_hasher: Callable[[str], str] = Depends(get_password_hasher)
+ user_crud: UserCRUD = Depends(get_user_crud), password_hasher: Callable[[str], str] = Depends(get_password_hasher)
) -> UserService:
"""
Dependency to get the UserService instance.
@@ -66,14 +68,12 @@ def get_user_service(
:param password_hasher: The password hasher function.
:return: The UserService instance.
"""
- return UserService(
- user_crud=user_crud,
- hash_pwd_func=password_hasher
- )
+ return UserService(user_crud=user_crud, hash_pwd_func=password_hasher)
# dependency injection for AuthService
+
def get_token_creator() -> Callable[[Dict[str, Any], Optional[datetime.timedelta]], str]:
return create_access_token
@@ -83,48 +83,39 @@ def get_token_decoder() -> Callable[[str], Optional[TokenPayload]]:
def get_auth_service(
- user_crud: UserCRUD = Depends(get_user_crud),
- user_session_crud: UserSessionCRUD = Depends(get_user_session_crud),
- password_verifier: Callable[[str, str], bool] = Depends(get_password_verifier),
- token_creator: Callable[[Dict[str, Any], Optional[datetime.timedelta]], str] = Depends(get_token_creator),
- token_decoder: Callable[[str], Optional[TokenPayload]] = Depends(get_token_decoder),
+ user_crud: UserCRUD = Depends(get_user_crud),
+ user_session_crud: UserSessionCRUD = Depends(get_user_session_crud),
+ password_verifier: Callable[[str, str], bool] = Depends(get_password_verifier),
+ token_creator: Callable[[Dict[str, Any], Optional[datetime.timedelta]], str] = Depends(get_token_creator),
+ token_decoder: Callable[[str], Optional[TokenPayload]] = Depends(get_token_decoder),
) -> AuthService:
return AuthService(
user_crud=user_crud,
user_session_crud=user_session_crud,
verify_password_func=password_verifier,
create_jwt_func=token_creator,
- decode_jwt_func=token_decoder
+ decode_jwt_func=token_decoder,
)
# dependency injection for CartService
+
def get_cart_service(
- cart_item_crud: CartItemCRUD = Depends(get_cart_item_crud),
- product_crud: ProductCRUD = Depends(get_product_crud)
+ cart_item_crud: CartItemCRUD = Depends(get_cart_item_crud), product_crud: ProductCRUD = Depends(get_product_crud)
) -> CartService:
- return CartService(
- cart_item_crud=cart_item_crud,
- product_crud=product_crud
- )
+ return CartService(cart_item_crud=cart_item_crud, product_crud=product_crud)
# dependency injection for AddressService
def get_address_service(
- address_crud: AddressCRUD = Depends(get_address_crud),
- user_crud: UserCRUD = Depends(get_user_crud)
+ address_crud: AddressCRUD = Depends(get_address_crud), user_crud: UserCRUD = Depends(get_user_crud)
) -> AddressService:
- return AddressService(
- address_crud=address_crud,
- user_crud=user_crud
- )
+ return AddressService(address_crud=address_crud, user_crud=user_crud)
# dependency injection for ProductService
-def get_product_service(
- product_crud: ProductCRUD = Depends(get_product_crud)
-) -> ProductService:
+def get_product_service(product_crud: ProductCRUD = Depends(get_product_crud)) -> ProductService:
"""
Dependency to get the ProductService instance.
:param product_crud: The ProductCRUD instance.
@@ -132,6 +123,7 @@ def get_product_service(
"""
return ProductService()
+
# dependency injection for OrderService
def get_order_service(
order_crud: OrderCRUD = Depends(get_order_crud),
@@ -157,9 +149,10 @@ def get_order_service(
address_crud=address_crud,
cart_item_crud=cart_item_crud,
product_crud=product_crud,
- payment_transaction_crud=payment_transaction_crud
+ payment_transaction_crud=payment_transaction_crud,
)
+
# dependency injection for StoreChangeRequestService
def get_store_change_request_service() -> StoreChangeRequestService:
"""
@@ -168,6 +161,7 @@ def get_store_change_request_service() -> StoreChangeRequestService:
"""
return StoreChangeRequestService()
+
# dependency injection for ProductChangeRequestService
def get_product_change_request_service() -> ProductChangeRequestService:
"""
@@ -176,10 +170,11 @@ def get_product_change_request_service() -> ProductChangeRequestService:
"""
return ProductChangeRequestService()
+
# dependency injection for StoreService
def get_store_service(
- store_crud: StoreCRUD = Depends(get_store_crud),
- user_crud: UserCRUD = Depends(get_user_crud),
+ store_crud: StoreCRUD = Depends(get_store_crud),
+ user_crud: UserCRUD = Depends(get_user_crud),
) -> StoreService:
"""
Dependency to get the StoreService instance.
@@ -195,10 +190,10 @@ def get_store_service(
# dependency injection for ChangeRequestService
def get_product_change_request_service_v2(
- product_change_request_crud: ProductChangeRequestCRUD2 = Depends(get_product_change_request_crud2),
- product_crud: ProductCRUD = Depends(get_product_crud),
- user_crud: UserCRUD = Depends(get_user_crud),
- store_crud: StoreCRUD = Depends(get_store_crud),
+ product_change_request_crud: ProductChangeRequestCRUD2 = Depends(get_product_change_request_crud2),
+ product_crud: ProductCRUD = Depends(get_product_crud),
+ user_crud: UserCRUD = Depends(get_user_crud),
+ store_crud: StoreCRUD = Depends(get_store_crud),
) -> ProductChangeRequestService2:
"""
Dependency to get the ProductChangeRequestService2 instance.
@@ -215,10 +210,11 @@ def get_product_change_request_service_v2(
store_crud=store_crud,
)
+
def get_store_change_request_service_v2(
- store_change_request_crud: StoreChangeRequestCRUD2 = Depends(get_store_change_request_crud2),
- store_crud: StoreCRUD = Depends(get_store_crud),
- user_crud: UserCRUD = Depends(get_user_crud),
+ store_change_request_crud: StoreChangeRequestCRUD2 = Depends(get_store_change_request_crud2),
+ store_crud: StoreCRUD = Depends(get_store_crud),
+ user_crud: UserCRUD = Depends(get_user_crud),
) -> StoreChangeRequestService2:
"""
Dependency to get the StoreChangeRequestService2 instance.
@@ -233,6 +229,7 @@ def get_store_change_request_service_v2(
user_crud=user_crud,
)
+
# dependency injection for StatisticsService
def get_statistics_service() -> StatisticsService:
"""
diff --git a/src/backend/app/main.py b/src/backend/app/main.py
index 4997a78..08e5340 100644
--- a/src/backend/app/main.py
+++ b/src/backend/app/main.py
@@ -28,7 +28,7 @@ async def lifespan(_app: FastAPI):
]
app.add_middleware(
- CORSMiddleware, # type: ignore
+ CORSMiddleware, # type: ignore
allow_origin_regex=r"^http://localhost(:[0-9]+)?$",
allow_credentials=True,
allow_methods=["*"],
@@ -55,6 +55,7 @@ def read_root():
@app.get("/test_db")
def test_db():
from sqlalchemy import text
+
try:
with database.get_engine().connect() as connection:
result = connection.execute(text("SELECT VERSION()"))
diff --git a/src/backend/app/schemas/address_schema.py b/src/backend/app/schemas/address_schema.py
index 62e0945..dc7015c 100644
--- a/src/backend/app/schemas/address_schema.py
+++ b/src/backend/app/schemas/address_schema.py
@@ -13,13 +13,12 @@ class AddressBase(BaseModel):
地址共享的基础字段。
字段名使用 CamelCase 以匹配 DDL。
"""
+
RecipientName: str = Field(..., min_length=1, max_length=255, description="收货人姓名。")
- PhoneNumber: constr(pattern=PHONE_NUMBER_REGEX, min_length=5, max_length=50) = Field( # type: ignore
- ...,
- description="收货人电话号码。"
+ PhoneNumber: constr(pattern=PHONE_NUMBER_REGEX, min_length=5, max_length=50) = Field(..., description="收货人电话号码。") # type: ignore
+ FullAddress_Text: str = Field(
+ ..., min_length=5, max_length=1000, description="完整收货地址(省市区、详细地址、邮编等)。"
)
- FullAddress_Text: str = Field(..., min_length=5, max_length=1000,
- description="完整收货地址(省市区、详细地址、邮编等)。")
class AddressCreateRequest(AddressBase):
@@ -27,6 +26,7 @@ class AddressCreateRequest(AddressBase):
客户端创建新地址时发送的数据。
UserID 将从已认证的用户中获取。
"""
+
pass
@@ -36,18 +36,21 @@ class AddressUpdateRequest(BaseModel): # 不再继承 AddressBase 以精确控
此 schema 用于更新地址的收件人姓名、电话号码和完整地址。不处理 IsDefault 字段(由专门的 API 端点处理)。
所有字段都是可选的。
"""
+
RecipientName: Optional[str] = Field(None, min_length=1, max_length=255, description="新的收货人姓名 (可选)。")
PhoneNumber: Optional[constr(pattern=PHONE_NUMBER_REGEX, min_length=5, max_length=50)] = Field(
- None,
- description="新的收货人电话号码 (可选)。"
+ None, description="新的收货人电话号码 (可选)。"
+ )
+ FullAddress_Text: Optional[str] = Field(
+ None, min_length=5, max_length=1000, description="新的完整收货地址 (可选)。"
)
- FullAddress_Text: Optional[str] = Field(None, min_length=5, max_length=1000, description="新的完整收货地址 (可选)。")
class AddressResponse(AddressBase):
"""
从 API 返回地址信息时使用的数据模型。
"""
+
AddressID: int = Field(..., description="地址的唯一ID。")
UserID: int = Field(..., description="所属用户的ID。")
IsDefault: bool = Field(default=False, description="是否为默认地址。")
@@ -57,6 +60,7 @@ class AddressListResponse(BaseModel):
"""
API 返回用户的所有地址列表时使用的数据模型。
"""
+
Addresses: List[AddressResponse] = Field(..., description="用户的地址列表。")
TotalCount: int = Field(..., description="地址总数。")
@@ -65,5 +69,6 @@ class SetDefaultAddressResponse(BaseModel):
"""
设置默认地址操作的响应。
"""
+
Message: str = Field(..., description="操作结果消息。")
DefaultAddress: Optional[AddressResponse] = Field(None, description="新的默认地址信息 (可选)。")
diff --git a/src/backend/app/schemas/auth_schema.py b/src/backend/app/schemas/auth_schema.py
index 973b709..f61737a 100644
--- a/src/backend/app/schemas/auth_schema.py
+++ b/src/backend/app/schemas/auth_schema.py
@@ -7,11 +7,13 @@
# 用户登录、认证
+
class UserLogin(BaseModel):
"""
用户登录时提交的模型。
可以使用用户名或邮箱进行登录。
"""
+
UsernameOrEmail: str = Field(
...,
description="用户名或邮箱地址。",
@@ -29,18 +31,22 @@ class UserInDBForAuth(UserBase): # 继承 UserBase 以获取 username, email, p
代表从数据库中检索到的、用于认证的用户数据。
关键是包含 id 和 hashed_password。
"""
+
UserID: int = Field(..., description="用户的唯一ID。")
PasswordHash: str = Field(..., description="用户存储的哈希密码。")
- AccountStatus: str = Field(...,
- description="用户是否处于活动状态。'ACTIVE'表示活动,'INACTIVE'和'SUSPENDED_BY_ADMIN'表示非活动状态。")
+ AccountStatus: str = Field(
+ ..., description="用户是否处于活动状态。'ACTIVE'表示活动,'INACTIVE'和'SUSPENDED_BY_ADMIN'表示非活动状态。"
+ )
# --- 用于 API 响应的 Token Schema ---
+
class Token(BaseModel):
"""
API 认证成功后返回的 Token 模型。
"""
+
access_token: str = Field(..., description="JWT 访问令牌。")
token_type: str = Field("bearer", description="令牌类型,通常为 'bearer'。")
@@ -50,6 +56,7 @@ class TokenPayload(BaseModel):
JWT 负载模型。
包含用户 ID 和过期时间。
"""
+
sub: str = Field(..., description="令牌的主题 (Subject),通常是用户的唯一标识符,如用户ID的字符串形式或用户名。")
user_id: int = Field(..., description="用户的数字ID。") # 明确包含用户ID
exp: Optional[datetime.datetime] = Field(None, description="令牌的过期时间戳。")
diff --git a/src/backend/app/schemas/cartitem_schema.py b/src/backend/app/schemas/cartitem_schema.py
index de2c051..064cf50 100644
--- a/src/backend/app/schemas/cartitem_schema.py
+++ b/src/backend/app/schemas/cartitem_schema.py
@@ -8,6 +8,7 @@ class CartItemBase(BaseModel):
"""
购物车项目共享的基础字段。
"""
+
ProductID: int = Field(..., description="商品的唯一ID。")
Quantity: Annotated[int, Field(gt=0, description="商品数量,必须大于0。")]
@@ -19,6 +20,7 @@ class CartItemCreateRequest(CartItemBase):
UserID 将从已认证的用户中获取。
PriceAtAddition 将由服务层在添加时从产品信息中获取。
"""
+
pass # 继承自 CartItemBase,Quantity 字段已更新
@@ -27,6 +29,7 @@ class CartItemUpdateRequest(BaseModel):
"""
客户端更新购物车中商品数量时发送的数据。
"""
+
Quantity: Annotated[int, Field(gt=0, description="新的商品数量,必须大于0。")]
@@ -35,6 +38,7 @@ class CartItemResponse(CartItemBase): # Quantity 字段已从 CartItemBase 继
"""
从 API 返回购物车项目信息时使用的数据模型。
"""
+
CartItemID: int = Field(..., description="购物车项目的唯一ID。")
UserID: int = Field(..., description="所属用户的ID。")
PriceAtAddition: float = Field(..., description="商品加入购物车时的单价。")
@@ -51,6 +55,7 @@ class CartResponse(BaseModel):
"""
API 返回整个购物车内容时使用的数据模型。
"""
+
Items: List[CartItemResponse] = Field(..., description="购物车中的商品项目列表。")
TotalItems: int = Field(..., description="购物车中商品项目的总数。")
@@ -60,5 +65,6 @@ class CartActionResponse(BaseModel):
"""
用于购物车操作(如清空、删除条目)的通用响应。
"""
+
Message: str = Field(..., description="操作结果消息。")
Detail: Optional[dict[str, Any]] = Field(None, description="可选的详细信息。")
diff --git a/src/backend/app/schemas/category_schema.py b/src/backend/app/schemas/category_schema.py
index 0c0698f..af43850 100644
--- a/src/backend/app/schemas/category_schema.py
+++ b/src/backend/app/schemas/category_schema.py
@@ -6,6 +6,7 @@ class CategoryBase(BaseModel):
"""
商品分类基本信息模型
"""
+
CategoryName: str = Field(
...,
min_length=1,
@@ -27,6 +28,7 @@ class CategoryCreate(CategoryBase):
"""
创建新分类时,API 端点期望的分类数据模型。
"""
+
pass
@@ -35,6 +37,7 @@ class CategoryUpdate(BaseModel):
更新分类信息时,API 端点期望的分类数据模型。
包含可选字段。
"""
+
CategoryName: Optional[str] = Field(
None,
min_length=1,
@@ -56,6 +59,7 @@ class CategoryResponse(CategoryBase):
"""
分类响应模型,包含分类的基本信息和额外信息。
"""
+
CategoryID: int = Field(
...,
description="分类的唯一标识符。",
@@ -66,6 +70,7 @@ class CategoryWithChildren(CategoryResponse):
"""
包含子分类信息的分类响应模型
"""
+
Children: Optional[List["CategoryWithChildren"]] = Field(
default=[],
description="子分类列表。",
@@ -73,4 +78,4 @@ class CategoryWithChildren(CategoryResponse):
# 解决循环引用问题
-CategoryWithChildren.update_forward_refs()
\ No newline at end of file
+CategoryWithChildren.update_forward_refs()
diff --git a/src/backend/app/schemas/order_schema.py b/src/backend/app/schemas/order_schema.py
index 5765fd2..8a22d09 100644
--- a/src/backend/app/schemas/order_schema.py
+++ b/src/backend/app/schemas/order_schema.py
@@ -9,11 +9,13 @@
# --- 枚举类型 ---
+
class OrderStatusEnum(str, Enum):
"""
订单状态枚举 (基于 DDL)。
会被 OrderViewResponse, CreatedOrderDetailResponse, OrderUpdateStatusRequest, OrderActionResponse 使用。
"""
+
PENDING_PAYMENT = "PENDING_PAYMENT" # 待支付
PAID_AND_PENDING_PROCESSING = "PAID_AND_PENDING_PROCESSING" # 已支付,待处理 (新状态)
PROCESSING_BY_MERCHANT = "PROCESSING_BY_MERCHANT" # 商家处理中 (新状态)
@@ -33,6 +35,7 @@ class PaymentTransactionStatusEnum(str, Enum):
支付事务状态枚举 (基于 DDL)。
会被 OrderViewResponse (作为 PaymentStatus) 使用。
"""
+
PENDING = "PENDING" # 待处理/待支付
SUCCESSFUL = "SUCCESSFUL" # 支付成功 (DDL 使用 SUCCESSFUL)
FAILED = "FAILED" # 支付失败
@@ -46,6 +49,7 @@ class OrderItemCreationInput(BaseModel):
在创建订单请求中,指定要购买的购物车项目。
会被 POST /api/v1/orders/ (创建订单) 端点使用,作为 OrderCreateRequest 的一部分。
"""
+
CartItemID: int = Field(..., description="要购买的购物车项目ID。")
# Quantity 和 PriceAtPurchase 将由服务层从 CartItemID 对应的购物车条目中获取和验证。
@@ -56,10 +60,12 @@ class OrderCreateRequest(BaseModel):
客户端发起创建新订单请求时发送的数据。
会被 POST /api/v1/orders/ (创建订单) 端点使用。
"""
+
ShippingAddressID: int = Field(..., description="选择的收货地址ID。")
Items: List[OrderItemCreationInput] = Field(..., description="要购买的购物车项目列表。", min_length=1)
- Notes_ByUser: Optional[str] = Field(None, max_length=65535,
- description="用户订单备注 (可选,对应 DDL 的 TEXT 类型)。") # DDL 是 TEXT
+ Notes_ByUser: Optional[str] = Field(
+ None, max_length=65535, description="用户订单备注 (可选,对应 DDL 的 TEXT 类型)。"
+ ) # DDL 是 TEXT
# PaymentMethodHint: Optional[str] = Field(None, description="用户选择的支付方式提示 (可选)。服务层将用此创建 PaymentTransaction。")
@@ -70,6 +76,7 @@ class OrderItemDetailResponse(BaseModel):
会被嵌入到各种订单相关的响应中。
对应 DDL 中的 OrderItem 表。
"""
+
OrderItemID: int = Field(..., description="订单项目ID。")
OrderID: int = Field(..., description="所属订单ID。")
ProductID: int = Field(..., description="商品ID。")
@@ -88,11 +95,13 @@ class CreatedOrderDetailResponse(BaseModel):
会被嵌入到 InitiateOrderResponse 中。
对应 DDL 中的 Order 表的部分字段。
"""
+
OrderID: int = Field(..., description="新创建的订单的ID。")
StoreID: int = Field(..., description="此订单所属的店铺ID。")
# StoreName: Optional[str] = Field(None, description="店铺名称 (服务层可选填充)。") # DDL Order 表没有 StoreName
- FinalAmountForThisOrder: Decimal = Field(...,
- description="此特定订单的最终应付金额。") # 对应 DDL Order.FinalAmountForThisOrder
+ FinalAmountForThisOrder: Decimal = Field(
+ ..., description="此特定订单的最终应付金额。"
+ ) # 对应 DDL Order.FinalAmountForThisOrder
OrderStatus: OrderStatusEnum = Field(..., description="此订单的初始状态 (通常是 PENDING_PAYMENT)。")
Items: List[OrderItemDetailResponse] = Field(..., description="此订单包含的商品项。")
@@ -103,6 +112,7 @@ class InitiateOrderResponse(BaseModel):
创建订单流程(可能生成一个支付事务和多个订单)成功后的顶层响应。
由 POST /api/v1/orders/ (创建订单) 端点返回。
"""
+
PaymentTransactionID: int = Field(..., description="为此批订单创建的支付事务的ID。")
ExternalPaymentURL: Optional[str] = Field(None, description="重定向到第三方支付网关的URL (如果适用)。")
TotalAmountDue: Decimal = Field(..., description="需要支付的总金额 (对应 PaymentTransaction.TotalAmount)。")
@@ -119,6 +129,7 @@ class OrderViewResponse(BaseModel):
也会被 GET /api/v1/users/me/orders (获取用户所有订单) 端点中的列表项使用。
对应 DDL 中的 Order 表。
"""
+
OrderID: int
UserID: int
StoreID: int
@@ -157,6 +168,7 @@ class OrderListResponse(BaseModel):
会被 GET /api/v1/users/me/orders (获取用户所有订单) 端点使用。
也会被管理员查看订单列表的端点使用 (可能带有额外过滤参数)。
"""
+
Orders: List[OrderViewResponse] = Field(..., description="订单列表。")
TotalCount: int = Field(..., description="符合条件的订单总数 (用于分页)。")
Offset: Optional[int] = Field(None, description="当前分页的偏移量。")
@@ -169,6 +181,7 @@ class OrderUpdateStatusRequest(BaseModel):
用于用户或管理员更新订单状态的请求。
会被 PUT /api/v1/orders/{order_id}/status (更新订单状态) 端点使用。
"""
+
NewStatus: OrderStatusEnum = Field(..., description="新的订单状态。")
TrackingNumber: Optional[str] = Field(None, description="物流追踪号 (如果状态更新为 SHIPPED,由商家或管理员填写)。")
UserNotes: Optional[str] = Field(None, description="用户备注 (例如取消原因)。")
@@ -181,7 +194,7 @@ class OrderActionResponse(BaseModel):
用于订单操作(如取消订单成功、状态更新成功)的通用简单响应。
可被 PUT /api/v1/orders/{order_id}/status 或其他执行动作的端点使用。
"""
+
Message: str = Field(..., description="操作结果消息。")
OrderID: int = Field(..., description="相关的订单ID。")
NewStatus: Optional[OrderStatusEnum] = Field(None, description="操作后的新订单状态 (可选)。")
-
diff --git a/src/backend/app/schemas/payment_schema.py b/src/backend/app/schemas/payment_schema.py
index b21a73f..49c88c8 100644
--- a/src/backend/app/schemas/payment_schema.py
+++ b/src/backend/app/schemas/payment_schema.py
@@ -14,6 +14,7 @@
# ExternalPaymentCallbackQueryParams is no longer needed.
+
class SimulatedExternPaymentResponse(BaseModel):
"""
当用户在我们的支付页面点击“支付”时,模拟的支付网关将返回给我们的响应。
@@ -22,10 +23,11 @@ class SimulatedExternPaymentResponse(BaseModel):
这个请求体可以包含用户选择的模拟支付方式或其他相关信息。
会被 POST /api/v1/payment/{PaymentTransactionID}/simulate-pay (示例路径) 端点使用。
"""
- SimulatedPaymentMethod: str = Field(default="MockPayment",
- description="模拟支付方式名称")
- ExternalGatewayTxID: Optional[str] = Field(None,
- description="外部支付网关的交易ID (如果适用)。可以是前端随机生成的字符串。")
+
+ SimulatedPaymentMethod: str = Field(default="MockPayment", description="模拟支付方式名称")
+ ExternalGatewayTxID: Optional[str] = Field(
+ None, description="外部支付网关的交易ID (如果适用)。可以是前端随机生成的字符串。"
+ )
class PaymentProcessingResponse(BaseModel):
@@ -34,13 +36,18 @@ class PaymentProcessingResponse(BaseModel):
前端将根据此响应将用户重定向到合适的订单页面。
会被 POST /api/v1/payment/{PaymentTransactionID}/simulate-pay (示例路径) 端点使用。
"""
+
PaymentTransactionID: int = Field(..., description="系统内部的支付事务ID。")
- TransactionStatusInSystem: str = Field(...,
- description="更新后的系统内部支付事务状态 (例如,来自您系统的 PaymentTransactionStatusEnum: 'SUCCESSFUL', 'FAILED')。")
- MessageToUser: str = Field(...,
- description="给用户的最终消息 (例如,“支付成功,您的订单正在处理中!”或“支付失败,请重试。”)。")
- AffectedOrderIDs: Optional[List[int]] = Field(None,
- description="与此支付事务成功关联并更新状态的订单ID列表 (如果支付成功)。")
+ TransactionStatusInSystem: str = Field(
+ ...,
+ description="更新后的系统内部支付事务状态 (例如,来自您系统的 PaymentTransactionStatusEnum: 'SUCCESSFUL', 'FAILED')。",
+ )
+ MessageToUser: str = Field(
+ ..., description="给用户的最终消息 (例如,“支付成功,您的订单正在处理中!”或“支付失败,请重试。”)。"
+ )
+ AffectedOrderIDs: Optional[List[int]] = Field(
+ None, description="与此支付事务成功关联并更新状态的订单ID列表 (如果支付成功)。"
+ )
class PaymentResponse(BaseModel):
@@ -48,17 +55,15 @@ class PaymentResponse(BaseModel):
支付响应模型,包含支付事务的详细信息。
会被 GET /api/v1/payment/{PaymentTransactionID}/status (示例路径) 端点使用。
"""
+
PaymentTransactionID: int = Field(..., description="系统内部的支付事务ID。")
UserID: int = Field(..., description="发起支付的用户ID。")
TotalAmount: Decimal = Field(..., description="支付的总金额。")
PaymentMethod: str = Field(..., description="支付方式 (例如,'CreditCard', 'PayPal')。")
- ExternalGatewayTxID: Optional[str] = Field(None,
- description="外部支付网关的交易ID (如果适用)。")
- Status: PaymentTransactionStatusEnum = Field(...,
- description="支付事务的当前状态 (例如,'PENDING', 'SUCCESSFUL', 'FAILED')。")
- CreationTime: datetime.datetime = Field(...,
- description="支付事务创建的时间戳。")
- CompletionTime: Optional[datetime.datetime] = Field(None,
- description="支付事务完成的时间戳 (如果适用)。")
- LastUpdatedTime: datetime.datetime = Field(...,
- description="支付事务最后更新的时间戳。")
+ ExternalGatewayTxID: Optional[str] = Field(None, description="外部支付网关的交易ID (如果适用)。")
+ Status: PaymentTransactionStatusEnum = Field(
+ ..., description="支付事务的当前状态 (例如,'PENDING', 'SUCCESSFUL', 'FAILED')。"
+ )
+ CreationTime: datetime.datetime = Field(..., description="支付事务创建的时间戳。")
+ CompletionTime: Optional[datetime.datetime] = Field(None, description="支付事务完成的时间戳 (如果适用)。")
+ LastUpdatedTime: datetime.datetime = Field(..., description="支付事务最后更新的时间戳。")
diff --git a/src/backend/app/schemas/product_change_request_schema.py b/src/backend/app/schemas/product_change_request_schema.py
index 9d64951..068ccca 100644
--- a/src/backend/app/schemas/product_change_request_schema.py
+++ b/src/backend/app/schemas/product_change_request_schema.py
@@ -8,6 +8,7 @@ class ProductChangeRequestBase(BaseModel):
"""
商品变更请求基本信息模型
"""
+
MerchantUserID: int = Field(
...,
gt=0,
@@ -36,6 +37,7 @@ class ProductChangeRequestCreate(ProductChangeRequestBase):
"""
创建新商品变更请求时,API 端点期望的数据模型。
"""
+
ProductID: Optional[int] = Field(
None,
gt=0,
@@ -48,6 +50,7 @@ class ProductChangeRequestUpdate(BaseModel):
更新商品变更请求信息时,API 端点期望的数据模型。
包含可选字段。
"""
+
RequestType: Optional[Literal["PRODUCT_CREATE", "PRODUCT_UPDATE", "PRODUCT_DELETE"]] = Field(
None,
description="请求类型。可以是 'PRODUCT_CREATE'、'PRODUCT_UPDATE' 或 'PRODUCT_DELETE'。",
@@ -70,6 +73,7 @@ class ProductChangeRequestAdminUpdate(BaseModel):
"""
管理员更新商品变更请求的数据模型。
"""
+
AdminReviewerID: int = Field(
...,
gt=0,
@@ -93,6 +97,7 @@ class ProductChangeRequestResponse(ProductChangeRequestBase):
"""
商品变更请求响应模型,包含请求的基本信息和额外信息。
"""
+
ChangeRequestID: int = Field(
...,
description="变更请求的唯一标识符。",
@@ -134,6 +139,7 @@ class ProductChangeRequestQueryParams(BaseModel):
"""
商品变更请求查询参数模型
"""
+
ProductID: Optional[int] = Field(
None,
gt=0,
diff --git a/src/backend/app/schemas/product_change_request_schema_v2.py b/src/backend/app/schemas/product_change_request_schema_v2.py
index 2b91e3a..cdf8641 100644
--- a/src/backend/app/schemas/product_change_request_schema_v2.py
+++ b/src/backend/app/schemas/product_change_request_schema_v2.py
@@ -12,6 +12,7 @@
class ProductChangeRequestTypeApiEnum(str, Enum):
"""商品变更请求类型枚举"""
+
PRODUCT_CREATE = "PRODUCT_CREATE"
PRODUCT_UPDATE = "PRODUCT_UPDATE"
PRODUCT_DELETE = "PRODUCT_DELETE"
@@ -19,6 +20,7 @@ class ProductChangeRequestTypeApiEnum(str, Enum):
class ProductChangeRequestStatusApiEnum(str, Enum):
"""商品变更请求状态枚举"""
+
PENDING_APPROVAL = "PENDING_APPROVAL"
APPROVED = "APPROVED"
REJECTED = "REJECTED"
@@ -34,20 +36,28 @@ class ProposedProductData(BaseModel):
- 对于 PRODUCT_UPDATE: 所有字段都是可选的,表示要更新的商品属性。
- 对于 PRODUCT_DELETE: 该对象通常为空或未提供。
"""
+
ProductName: Optional[str] = Field(None, max_length=255, description="商品名称。对于创建请求是必需的。")
ProductDescription: Optional[str] = Field(None, description="商品详细介绍。对于创建请求可选。")
- Price: Optional[Decimal] = Field(None, gt=Decimal(0), description="单价,如果提供则必须大于0。对于创建请求是必需的。")
- ProductStatus: Optional[ProductStatusApiEnum] = Field(None,
- description="商品状态。对于商家可以是ACTIVE,INACTIVE_BY_MERCHANT,DISCONTINUED。对于创建请求可选。")
+ Price: Optional[Decimal] = Field(
+ None, gt=Decimal(0), description="单价,如果提供则必须大于0。对于创建请求是必需的。"
+ )
+ ProductStatus: Optional[ProductStatusApiEnum] = Field(
+ None, description="商品状态。对于商家可以是ACTIVE,INACTIVE_BY_MERCHANT,DISCONTINUED。对于创建请求可选。"
+ )
CategoryID: Optional[int] = Field(None, description="所属的种类ID。对于创建请求是必需的。")
- StockQuantity: Optional[int] = Field(None, ge=0, description="库存数量, 如果提供则必须大于等于0。对于创建请求可选。")
+ StockQuantity: Optional[int] = Field(
+ None, ge=0, description="库存数量, 如果提供则必须大于等于0。对于创建请求可选。"
+ )
MainImageURL: Optional[str] = Field(None, max_length=512, description="商品主图片地址。可选。")
# --- ProductChangeRequest 相关 Schemas ---
+
class ProductChangeRequestBase(BaseModel):
"""商品变更请求的基础模型。"""
+
ProductID: Optional[int] = Field(None, description="如果是 PRODUCT_UPDATE 或 PRODUCT_DELETE,则提供商品ID。")
StoreID: int = Field(..., description="商品所属的店铺ID。对于所有请求都是必需的。")
RequestType: ProductChangeRequestTypeApiEnum = Field(..., description="请求类型(创建商品、更新商品或删除商品)。")
@@ -59,13 +69,16 @@ class ProductChangeRequestCreate(ProductChangeRequestBase):
用于 API 创建新的“商品变更请求”。
MerchantUserID 通常从请求上下文中获取。
"""
- ProposedData_JSON: Optional[ProposedProductData] = Field(None,
- description="提供的商品相关字段。会根据 RequestType 校验其内容。"
- "如果 RequestType 是 PRODUCT_CREATE,则必须提供所有基础字段。"
- "如果 RequestType 是 PRODUCT_UPDATE,则所有字段都是可选的。"
- "如果 RequestType 是 PRODUCT_DELETE,则不应提交该字段。")
- @model_validator(mode='after') # Pydantic V2 after validator
+ ProposedData_JSON: Optional[ProposedProductData] = Field(
+ None,
+ description="提供的商品相关字段。会根据 RequestType 校验其内容。"
+ "如果 RequestType 是 PRODUCT_CREATE,则必须提供所有基础字段。"
+ "如果 RequestType 是 PRODUCT_UPDATE,则所有字段都是可选的。"
+ "如果 RequestType 是 PRODUCT_DELETE,则不应提交该字段。",
+ )
+
+ @model_validator(mode="after") # Pydantic V2 after validator
def check_on_request_type(self) -> Self:
"""
在字段已填充后进行校验:
@@ -77,11 +90,13 @@ def check_on_request_type(self) -> Self:
if request_type == ProductChangeRequestTypeApiEnum.PRODUCT_CREATE:
# 不包含 ProductID
- if product_id is not None: raise ValueError("ProductID must be null when RequestType is PRODUCT_CREATE.")
+ if product_id is not None:
+ raise ValueError("ProductID must be null when RequestType is PRODUCT_CREATE.")
# 必须提供新商品的所有基础字段
if self.ProposedData_JSON is None:
raise ValueError(
- "ProposedData_JSON must be provided and be a dictionary when RequestType is PRODUCT_CREATE.")
+ "ProposedData_JSON must be provided and be a dictionary when RequestType is PRODUCT_CREATE."
+ )
# 所有基础字段必须存在,这个检查在服务层进行
elif request_type == ProductChangeRequestTypeApiEnum.PRODUCT_UPDATE:
@@ -91,7 +106,8 @@ def check_on_request_type(self) -> Self:
# 必须提供 ProposedData_JSON
if self.ProposedData_JSON is None:
raise ValueError(
- "ProposedData_JSON must be provided and be a dictionary when RequestType is PRODUCT_UPDATE.")
+ "ProposedData_JSON must be provided and be a dictionary when RequestType is PRODUCT_UPDATE."
+ )
elif request_type == ProductChangeRequestTypeApiEnum.PRODUCT_DELETE:
# 必须提供 ProductID
@@ -111,6 +127,7 @@ class ProductChangeRequestResponse(ProductChangeRequestBase):
"""
API 返回单个商品变更请求信息时使用的数据模型。
"""
+
# override ProductID
ProductID: Optional[int] = Field(None, description="商品ID。对于创建请求,可能还没有商品ID。")
@@ -129,6 +146,7 @@ class ProductChangeRequestListResponse(BaseModel):
"""
API 返回商品变更请求列表时使用的数据模型。
"""
+
Requests: List[ProductChangeRequestResponse] = Field(..., description="商品变更请求列表。")
TotalCount: int = Field(..., description="符合条件的请求总数。")
@@ -137,20 +155,21 @@ class ProductChangeRequestListResponse(BaseModel):
# 商家删除请求改为发送DELETE请求,不使用新的Schema
+
class ProductChangeRequestUpdateByAdmin(BaseModel):
"""
管理员审核并更新变更请求时发送的数据。
"""
+
Status: ProductChangeRequestStatusApiEnum = Field(..., description="管理员设置的新状态 (例如 APPROVED, REJECTED)。")
AdminNotes: Optional[str] = Field(None, description="管理员审核备注。")
- @model_validator(mode='after') # Pydantic V2 after validator
+ @model_validator(mode="after") # Pydantic V2 after validator
def check_status(self) -> Self:
"""
检查管理员是否只能将状态设置为 'APPROVED' 或 'REJECTED'。
"""
- if self.Status not in [ProductChangeRequestStatusApiEnum.APPROVED,
- ProductChangeRequestStatusApiEnum.REJECTED]:
+ if self.Status not in [ProductChangeRequestStatusApiEnum.APPROVED, ProductChangeRequestStatusApiEnum.REJECTED]:
raise ValueError("Admin can only set Status to 'APPROVED' or 'REJECTED'.")
return self
@@ -161,8 +180,10 @@ class ProductChangeRequestQueryParams(BaseModel):
用于 API 端点查询商品变更请求的参数模型。
将与 Depends() 一起使用。所有字段都是可选的查询参数。
"""
- Status: Optional[List[ProductChangeRequestStatusApiEnum]] = Field(default=None,
- description="按一个或多个状态筛选,例如 ['PENDING_APPROVAL', 'APPROVED']")
+
+ Status: Optional[List[ProductChangeRequestStatusApiEnum]] = Field(
+ default=None, description="按一个或多个状态筛选,例如 ['PENDING_APPROVAL', 'APPROVED']"
+ )
RequestType: Optional[ProductChangeRequestTypeApiEnum] = Field(default=None, description="按请求类型筛选")
StoreID: Optional[int] = Field(default=None, description="按店铺ID筛选")
MerchantUserID: Optional[int] = Field(default=None, description="按商家ID筛选")
diff --git a/src/backend/app/schemas/product_schema.py b/src/backend/app/schemas/product_schema.py
index b3b3d60..b9685cb 100644
--- a/src/backend/app/schemas/product_schema.py
+++ b/src/backend/app/schemas/product_schema.py
@@ -6,9 +6,9 @@
from fastapi import Query
-
class ProductStatusApiEnum(str, Enum):
"""商品状态枚举"""
+
ACTIVE = "ACTIVE"
INACTIVE_BY_MERCHANT = "INACTIVE_BY_MERCHANT"
SUSPENDED_BY_ADMIN = "SUSPENDED_BY_ADMIN"
@@ -19,6 +19,7 @@ class ProductBase(BaseModel):
"""
商品基本信息模型
"""
+
ProductName: str = Field(
...,
min_length=1,
@@ -55,6 +56,7 @@ class ProductCreate(ProductBase):
"""
创建新商品时,API 端点期望的商品数据模型。
"""
+
StoreID: int = Field(
...,
gt=0,
@@ -67,6 +69,7 @@ class ProductUpdate(BaseModel):
更新商品信息时,API 端点期望的商品数据模型。
包含可选字段。
"""
+
ProductName: Optional[str] = Field(
None,
min_length=1,
@@ -107,6 +110,7 @@ class ProductResponse(ProductBase):
"""
商品响应模型,包含商品的基本信息和额外信息。
"""
+
ProductID: int = Field(
...,
description="商品的唯一标识符。",
@@ -133,6 +137,7 @@ class ProductWithCategoryInfo(ProductResponse):
"""
包含分类信息的商品响应模型
"""
+
CategoryName: str = Field(
...,
description="商品所属分类名称。",
@@ -143,6 +148,7 @@ class ProductWithStoreInfo(ProductResponse):
"""
包含店铺信息的商品响应模型
"""
+
StoreName: str = Field(
...,
description="商品所属店铺名称。",
@@ -153,6 +159,7 @@ class ProductWithStoreAndCategoryInfo(ProductWithStoreInfo, ProductWithCategoryI
"""
包含店铺和分类信息的完整商品响应模型
"""
+
pass
@@ -160,45 +167,15 @@ class BaseProductQueryParams(BaseModel):
"""
商品查询参数基类,提供通用的筛选和分页选项
"""
- TextInclude: Optional[str] = Field(
- None,
- description="搜索关键词,支持商品名称和描述的模糊搜索"
- )
- CategoryID: Optional[int] = Field(
- None,
- gt=0,
- description="分类ID,按分类筛选商品"
- )
- StoreID: Optional[int] = Field(
- None,
- gt=0,
- description="店铺ID,按店铺筛选商品"
- )
- MinPrice: Optional[Decimal] = Field(
- None,
- ge=0,
- description="最低价格,筛选价格大于等于该值的商品"
- )
- MaxPrice: Optional[Decimal] = Field(
- None,
- gt=0,
- description="最高价格,筛选价格小于等于该值的商品"
- )
- OrderBy: Optional[str] = Field(
- None,
- description="排序字段,支持'price_asc'、'price_desc'、'newest'、'oldest'等"
- )
- Offset: int = Field(
- 0,
- ge=0,
- description="分页偏移量,默认0"
- )
- Limit: int = Field(
- 20,
- ge=1,
- le=100,
- description="返回结果数量限制,默认20,最大100"
- )
+
+ TextInclude: Optional[str] = Field(None, description="搜索关键词,支持商品名称和描述的模糊搜索")
+ CategoryID: Optional[int] = Field(None, gt=0, description="分类ID,按分类筛选商品")
+ StoreID: Optional[int] = Field(None, gt=0, description="店铺ID,按店铺筛选商品")
+ MinPrice: Optional[Decimal] = Field(None, ge=0, description="最低价格,筛选价格大于等于该值的商品")
+ MaxPrice: Optional[Decimal] = Field(None, gt=0, description="最高价格,筛选价格小于等于该值的商品")
+ OrderBy: Optional[str] = Field(None, description="排序字段,支持'price_asc'、'price_desc'、'newest'、'oldest'等")
+ Offset: int = Field(0, ge=0, description="分页偏移量,默认0")
+ Limit: int = Field(20, ge=1, le=100, description="返回结果数量限制,默认20,最大100")
class ProductQueryParamsByCustomer(BaseProductQueryParams):
@@ -206,6 +183,7 @@ class ProductQueryParamsByCustomer(BaseProductQueryParams):
普通用户查询商品的参数模型
普通用户只能看到ACTIVE状态的商品
"""
+
pass
@@ -214,9 +192,9 @@ class ProductQueryParamsByMerchant(BaseProductQueryParams):
商家查询商品的参数模型
商家可以筛选不同状态的商品
"""
+
ProductStatus: Optional[str] = Field(
- None,
- description="商品状态,可以是'ACTIVE'、'INACTIVE_BY_MERCHANT'或'DISCONTINUED'"
+ None, description="商品状态,可以是'ACTIVE'、'INACTIVE_BY_MERCHANT'或'DISCONTINUED'"
)
@@ -225,9 +203,9 @@ class ProductQueryParamsByAdmin(BaseProductQueryParams):
管理员查询商品的参数模型
管理员可以查看所有状态的商品
"""
+
ProductStatus: Optional[str] = Field(
- None,
- description="商品状态,可以是'ACTIVE'、'INACTIVE_BY_MERCHANT'、'SUSPENDED_BY_ADMIN'、'DISCONTINUED'等"
+ None, description="商品状态,可以是'ACTIVE'、'INACTIVE_BY_MERCHANT'、'SUSPENDED_BY_ADMIN'、'DISCONTINUED'等"
)
@@ -235,6 +213,7 @@ class ProductListParams(BaseModel):
"""
商品列表查询参数模型(旧版,保留向后兼容)
"""
+
store_id: Optional[int] = Field(
None,
gt=0,
diff --git a/src/backend/app/schemas/statistics_schema.py b/src/backend/app/schemas/statistics_schema.py
index 4962e27..cd223d5 100644
--- a/src/backend/app/schemas/statistics_schema.py
+++ b/src/backend/app/schemas/statistics_schema.py
@@ -15,6 +15,7 @@ class SystemStatistics(BaseModel):
class AdminDashboardStatistics(SystemStatistics):
"""Admin dashboard statistics extends system statistics."""
+
pass
diff --git a/src/backend/app/schemas/store_change_request_schema.py b/src/backend/app/schemas/store_change_request_schema.py
index b6d6467..693cb10 100644
--- a/src/backend/app/schemas/store_change_request_schema.py
+++ b/src/backend/app/schemas/store_change_request_schema.py
@@ -8,6 +8,7 @@ class StoreChangeRequestBase(BaseModel):
"""
店铺变更请求基本信息模型
"""
+
RequestingUserID: int = Field(
...,
gt=0,
@@ -31,6 +32,7 @@ class StoreChangeRequestCreate(StoreChangeRequestBase):
"""
创建新店铺变更请求时,API 端点期望的数据模型。
"""
+
StoreID: Optional[int] = Field(
None,
gt=0,
@@ -43,6 +45,7 @@ class StoreChangeRequestUpdate(BaseModel):
更新店铺变更请求信息时,API 端点期望的数据模型。
包含可选字段。
"""
+
RequestType: Optional[Literal["STORE_CREATE", "STORE_UPDATE", "STORE_DELETE"]] = Field(
None,
description="请求类型。可以是 'STORE_CREATE'、'STORE_UPDATE' 或 'STORE_DELETE'。",
@@ -65,6 +68,7 @@ class StoreChangeRequestAdminUpdate(BaseModel):
"""
管理员更新店铺变更请求的数据模型。
"""
+
AdminReviewerID: int = Field(
...,
gt=0,
@@ -88,6 +92,7 @@ class StoreChangeRequestResponse(StoreChangeRequestBase):
"""
店铺变更请求响应模型,包含请求的基本信息和额外信息。
"""
+
ChangeRequestID: int = Field(
...,
description="变更请求的唯一标识符。",
@@ -129,6 +134,7 @@ class StoreChangeRequestQueryParams(BaseModel):
"""
店铺变更请求查询参数模型
"""
+
StoreID: Optional[int] = Field(
None,
gt=0,
diff --git a/src/backend/app/schemas/store_change_request_schema_v2.py b/src/backend/app/schemas/store_change_request_schema_v2.py
index 3c17696..bc51303 100644
--- a/src/backend/app/schemas/store_change_request_schema_v2.py
+++ b/src/backend/app/schemas/store_change_request_schema_v2.py
@@ -39,13 +39,9 @@ class ProposedStoreData(BaseModel):
- 对于 STORE_DELETE: 该对象通常为空或未提供。
"""
- StoreName: Optional[str] = Field(
- None, max_length=255, description="店铺名称。对于创建请求是必需的。"
- )
+ StoreName: Optional[str] = Field(None, max_length=255, description="店铺名称。对于创建请求是必需的。")
Description: Optional[str] = Field(None, description="店铺描述。对于创建请求可选。")
- LogoURL: Optional[str] = Field(
- None, max_length=512, description="店铺 Logo 的 URL。对于创建请求可选。"
- )
+ LogoURL: Optional[str] = Field(None, max_length=512, description="店铺 Logo 的 URL。对于创建请求可选。")
StoreStatus: Optional[StoreStatusEnum] = Field(
None,
description="店铺状态。对于商家可以是 ACTIVE,INACTIVE_BY_MERCHANT。对于创建请求可选。",
@@ -60,6 +56,7 @@ class StoreChangeRequestBase(BaseModel):
店铺变更请求的基础模型。
包含所有店铺变更请求共有的字段。
"""
+
RequestType: StoreChangeRequestTypeEnum = Field(
...,
description="请求类型(创建店铺、更新店铺或删除店铺)。",
@@ -143,12 +140,8 @@ class StoreChangeRequestResponse(BaseModel):
ChangeRequestID: int = Field(..., description="变更请求的唯一ID。")
StoreID: Optional[int] = Field(None, description="目标店铺ID。对于 STORE_CREATE 请求可能为空。")
RequestingUserID: int = Field(..., description="提交请求的用户ID。")
- RequestType: StoreChangeRequestTypeEnum = Field(
- ..., description="请求类型(创建、更新或删除店铺)。"
- )
- ProposedData_JSON: Optional[ProposedStoreData] = Field(
- None, description="建议的数据体 (原始JSON)。"
- )
+ RequestType: StoreChangeRequestTypeEnum = Field(..., description="请求类型(创建、更新或删除店铺)。")
+ ProposedData_JSON: Optional[ProposedStoreData] = Field(None, description="建议的数据体 (原始JSON)。")
Status: StoreChangeRequestStatusEnum = Field(..., description="请求的当前状态。")
SubmitterNotes: Optional[str] = Field(None, description="商家提交备注。可选。")
AdminReviewerID: Optional[int] = Field(None, description="审核请求的管理员UserID。")
diff --git a/src/backend/app/schemas/store_schema.py b/src/backend/app/schemas/store_schema.py
index 84a065c..6233271 100644
--- a/src/backend/app/schemas/store_schema.py
+++ b/src/backend/app/schemas/store_schema.py
@@ -10,6 +10,7 @@ class StoreStatusEnum(str, Enum):
店铺状态枚举 (基于 DDL)。
会被 StoreResponse, StoreCreate, StoreUpdateRequest 使用。
"""
+
ACTIVE = "ACTIVE" # 活动中
INACTIVE_BY_MERCHANT = "INACTIVE_BY_MERCHANT" # 商家停用
SUSPENDED_BY_ADMIN = "SUSPENDED_BY_ADMIN" # 管理员暂停
@@ -21,6 +22,7 @@ class StoreSimpleResponse(BaseModel):
查询店铺列表时,API 端点返回的商店数据模型。
这个模型只包含店铺 ID 和名称。
"""
+
StoreID: int = Field(
...,
description="店铺 ID",
@@ -30,10 +32,12 @@ class StoreSimpleResponse(BaseModel):
description="店铺名称",
)
+
class StoreResponse(StoreSimpleResponse):
"""
查询店铺信息时,API 端点返回的商店数据模型。
"""
+
OwnerUserID: int = Field(
...,
description="店主用户 ID",
@@ -59,11 +63,13 @@ class StoreResponse(StoreSimpleResponse):
description="店铺最后更新时间",
)
+
class StoreListSimpleResponse(BaseModel):
"""
查询店铺列表时,API 端点返回的商店数据模型。
这个模型只包含店铺 ID 和名称。
"""
+
Count: int = Field(
...,
description="店铺数量",
@@ -73,10 +79,12 @@ class StoreListSimpleResponse(BaseModel):
description="店铺列表",
)
+
class StoreListResponse(BaseModel):
"""
查询店铺列表时,API 端点返回的商店数据模型。
"""
+
Count: int = Field(
...,
description="店铺数量",
@@ -93,6 +101,7 @@ class StoreCreate(BaseModel):
这个请求一般是由管理员审核请求之后由管理员发起的。
对于店主来说,创建申请要使用其他文件中的 StoreCreateRequest。
"""
+
StoreName: str = Field(
...,
max_length=50,
@@ -126,6 +135,7 @@ class StoreUpdate(BaseModel):
更新店铺信息时,API 端点期望的商店数据模型。
这个请求一般是由管理员审核请求之后由管理员发起的。
"""
+
StoreName: Optional[str] = Field(
None,
max_length=50,
diff --git a/src/backend/app/schemas/user_schema.py b/src/backend/app/schemas/user_schema.py
index 83cd41e..b83b640 100644
--- a/src/backend/app/schemas/user_schema.py
+++ b/src/backend/app/schemas/user_schema.py
@@ -70,14 +70,12 @@ class UserStatusUpdate(BaseModel):
更新用户内部状态时,API 端点期望的用户数据模型。
包含可选字段:新的账户状态/新的用户角色。
"""
+
AccountStatus: Optional[str] = Field(
None,
description="用户的账户状态。可以是 'ACTIVE'、'INACTIVE' 或 'SUSPENDED_BY_ADMIN'。",
)
- UserRole: Optional[str] = Field(
- None,
- description="用户的角色。可以是 'customer'、'merchant' 或 'admin'。"
- )
+ UserRole: Optional[str] = Field(None, description="用户的角色。可以是 'customer'、'merchant' 或 'admin'。")
class UserResponse(UserBase):
diff --git a/src/backend/app/services/address_service.py b/src/backend/app/services/address_service.py
index fcdf4c1..a52049e 100644
--- a/src/backend/app/services/address_service.py
+++ b/src/backend/app/services/address_service.py
@@ -12,8 +12,9 @@
AddressUpdateRequest,
AddressResponse,
AddressListResponse,
- SetDefaultAddressResponse
+ SetDefaultAddressResponse,
)
+
# 导入自定义异常 (如果需要)
from backend.app.utils.exceptions import AddressNotFoundException, PermissionDeniedException, UserNotFoundException
@@ -34,12 +35,12 @@ def __init__(self, address_crud: AddressCRUD, user_crud: UserCRUD):
# logger.info(f"{self.__class__.__name__} initialized.")
async def create_new_address(
- self,
- db: Connection,
- *,
- user_id: int, # 地址所属的用户ID (通常是当前认证用户)
- address_in: AddressCreateRequest, # 包含新地址信息的 Pydantic 模型 (IsDefault 在 AddressBase 中默认为 False)
- actor_id: int # 执行此操作的用户ID (通常与 user_id 相同,或为管理员ID)
+ self,
+ db: Connection,
+ *,
+ user_id: int, # 地址所属的用户ID (通常是当前认证用户)
+ address_in: AddressCreateRequest, # 包含新地址信息的 Pydantic 模型 (IsDefault 在 AddressBase 中默认为 False)
+ actor_id: int, # 执行此操作的用户ID (通常与 user_id 相同,或为管理员ID)
) -> AddressResponse:
"""
为指定用户创建一个新的收货地址。
@@ -68,10 +69,7 @@ async def create_new_address(
# Create the address. AddressCRUD.create_address sets IsDefault to FALSE.
created_address_dict = self._address_crud.create_address(
- conn=db,
- user_id=user_id,
- address_in=address_in,
- actor_id=actor_id
+ conn=db, user_id=user_id, address_in=address_in, actor_id=actor_id
)
if not created_address_dict:
logger.error(f"Failed to create address in CRUD layer for UserID {user_id}.")
@@ -83,11 +81,7 @@ async def create_new_address(
return AddressResponse(**created_address_dict)
async def get_user_addresses(
- self,
- db: Connection,
- *,
- user_id: int, # 要查询其地址的用户ID
- actor_id: int # 执行查询操作的用户ID
+ self, db: Connection, *, user_id: int, actor_id: int # 要查询其地址的用户ID # 执行查询操作的用户ID
) -> AddressListResponse:
"""
获取指定用户的所有收货地址。
@@ -102,23 +96,19 @@ async def get_user_addresses(
logger.warning(f"UserID {actor_id} attempting to get addresses for UserID {user_id}.")
# TODO: add admin check
- addresses_data = self._address_crud.get_addresses_by_user_id(
- conn=db, user_id=user_id, actor_id=actor_id
- )
+ addresses_data = self._address_crud.get_addresses_by_user_id(conn=db, user_id=user_id, actor_id=actor_id)
address_responses = [AddressResponse(**addr_data) for addr_data in addresses_data]
- logger.success(
- f"Fetched {len(address_responses)} addresses for UserID {user_id}."
- )
+ logger.success(f"Fetched {len(address_responses)} addresses for UserID {user_id}.")
return AddressListResponse(Addresses=address_responses, TotalCount=len(address_responses))
async def get_address_by_id_for_user(
- self,
- db: Connection,
- *,
- address_id: int,
- user_id: int, # 期望拥有此地址的用户ID (通常是当前认证用户)
- actor_id: int # 执行操作的用户ID
+ self,
+ db: Connection,
+ *,
+ address_id: int,
+ user_id: int,
+ actor_id: int, # 期望拥有此地址的用户ID (通常是当前认证用户) # 执行操作的用户ID
) -> AddressResponse:
"""
获取用户拥有的特定收货地址的详细信息。
@@ -146,24 +136,23 @@ async def get_address_by_id_for_user(
# Check ownership
if address_data["UserID"] != user_id:
logger.warning(
- f"Ownership mismatch: AddressID {address_id} belongs to UserID {address_data['UserID']}, requested by UserID {user_id}.")
+ f"Ownership mismatch: AddressID {address_id} belongs to UserID {address_data['UserID']}, requested by UserID {user_id}."
+ )
# This could also be a PermissionDeniedException depending on whether the user should even know it exists for someone else.
# For "get_address_by_id_for_user", it implies the caller *expects* it to belong to user_id.
raise AddressNotFoundException(f"Address with ID {address_id} not found for user {user_id}.")
- logger.success(
- f"AddressID {address_id} successfully retrieved for UserID {user_id}."
- )
+ logger.success(f"AddressID {address_id} successfully retrieved for UserID {user_id}.")
return AddressResponse(**address_data)
async def update_address_details(
- self,
- db: Connection,
- *,
- address_id: int,
- address_in: AddressUpdateRequest, # 只包含文本字段,不含 IsDefault
- user_id_making_change: int, # 拥有该地址并执行更改的用户ID
- actor_id: int # 用于审计的 actor ID
+ self,
+ db: Connection,
+ *,
+ address_id: int,
+ address_in: AddressUpdateRequest, # 只包含文本字段,不含 IsDefault
+ user_id_making_change: int, # 拥有该地址并执行更改的用户ID
+ actor_id: int, # 用于审计的 actor ID
) -> AddressResponse:
"""
更新现有收货地址的文本信息(收货人、电话、地址详情)。
@@ -182,11 +171,13 @@ async def update_address_details(
:raises Exception: 如果更新失败。
"""
logger.info(
- f"ActorID {actor_id} attempting to update AddressID {address_id} for UserID {user_id_making_change}.")
+ f"ActorID {actor_id} attempting to update AddressID {address_id} for UserID {user_id_making_change}."
+ )
if actor_id != user_id_making_change:
logger.warning(
- f"UserID {actor_id} attempting to update AddressID {address_id} for UserID {user_id_making_change}.")
+ f"UserID {actor_id} attempting to update AddressID {address_id} for UserID {user_id_making_change}."
+ )
# TODO: add admin check
# 1. Verify address exists and belongs to the user making the change
@@ -197,26 +188,22 @@ async def update_address_details(
raise AddressNotFoundException(f"Address with ID {address_id} not found for user {user_id_making_change}.")
updated_address_dict = self._address_crud.update_address_details(
- conn=db,
- address_id=address_id,
- address_in=address_in,
- actor_id=actor_id
+ conn=db, address_id=address_id, address_in=address_in, actor_id=actor_id
)
if not updated_address_dict:
logger.error(f"Failed to update address details in CRUD layer for AddressID {address_id}.")
raise Exception("Failed to update address details.") # Or more specific
- logger.success(
- f"AddressID {address_id} successfully updated for UserID {user_id_making_change}.")
+ logger.success(f"AddressID {address_id} successfully updated for UserID {user_id_making_change}.")
return AddressResponse(**updated_address_dict)
async def set_default_address_for_user(
- self,
- db: Connection,
- *,
- user_id: int, # 目标用户ID
- address_id_to_set_default: int, # 要设为默认的地址ID
- actor_id: int # 执行此操作的用户ID
+ self,
+ db: Connection,
+ *,
+ user_id: int,
+ address_id_to_set_default: int,
+ actor_id: int, # 目标用户ID # 要设为默认的地址ID # 执行此操作的用户ID
) -> SetDefaultAddressResponse:
"""
将指定的地址ID设为用户的默认收货地址。
@@ -237,11 +224,13 @@ async def set_default_address_for_user(
:raises Exception: 如果操作失败。
"""
logger.info(
- f"ActorID {actor_id} attempting to set AddressID {address_id_to_set_default} as default for UserID {user_id}.")
+ f"ActorID {actor_id} attempting to set AddressID {address_id_to_set_default} as default for UserID {user_id}."
+ )
if actor_id != user_id:
logger.warning(
- f"UserID {actor_id} attempting to set AddressID {address_id_to_set_default} as default for UserID {user_id}.")
+ f"UserID {actor_id} attempting to set AddressID {address_id_to_set_default} as default for UserID {user_id}."
+ )
# TODO: add admin check
# 1. Verify the address to be set as default exists and belongs to the user
@@ -252,7 +241,8 @@ async def set_default_address_for_user(
raise AddressNotFoundException(f"Address with ID {address_id_to_set_default} not found.")
if address_to_set_default["UserID"] != user_id:
raise AddressNotFoundException(
- f"Address with ID {address_id_to_set_default} does not belong to UserID {user_id}.")
+ f"Address with ID {address_id_to_set_default} does not belong to UserID {user_id}."
+ )
# 2.1 Set all other addresses for this user to IsDefault = False
self._address_crud.set_all_other_addresses_non_default_for_user(
@@ -278,9 +268,7 @@ async def set_default_address_for_user(
# Potentially rollback or raise a more severe error if this part fails
raise Exception("Failed to update user's default address reference in CRUD layer.")
- logger.success(
- f"AddressID {address_id_to_set_default} successfully set as default for UserID {user_id}."
- )
+ logger.success(f"AddressID {address_id_to_set_default} successfully set as default for UserID {user_id}.")
# 5. Get the newly set default address to return in the response
final_default_address_dict = self._address_crud.get_address_by_id(
@@ -288,20 +276,21 @@ async def set_default_address_for_user(
)
if not final_default_address_dict: # Should not happen
raise AddressNotFoundException(
- f"Default address {address_id_to_set_default} could not be retrieved after update.")
+ f"Default address {address_id_to_set_default} could not be retrieved after update."
+ )
return SetDefaultAddressResponse(
Message=f"Address {address_id_to_set_default} successfully set as default for user {user_id}.",
- DefaultAddress=AddressResponse(**final_default_address_dict)
+ DefaultAddress=AddressResponse(**final_default_address_dict),
)
async def delete_address_for_user(
- self,
- db: Connection,
- *,
- address_id: int,
- user_id_making_request: int, # 执行删除操作的用户ID
- actor_id: int # 用于审计的 actor ID
+ self,
+ db: Connection,
+ *,
+ address_id: int,
+ user_id_making_request: int,
+ actor_id: int, # 执行删除操作的用户ID # 用于审计的 actor ID
) -> bool: # 返回 True 表示成功删除
"""
删除用户的一个收货地址。
@@ -319,15 +308,16 @@ async def delete_address_for_user(
:raises UserNotFoundException: 如果用户未找到 (理论上不应发生,因为 user_id 来自认证用户或有效输入)。
"""
logger.info(
- f"ActorID {actor_id} attempting to delete AddressID {address_id} for UserID {user_id_making_request}.")
+ f"ActorID {actor_id} attempting to delete AddressID {address_id} for UserID {user_id_making_request}."
+ )
# Additional check if actor is different (admin case)
if actor_id != user_id_making_request:
logger.warning(
- f"UserID {actor_id} attempting to delete AddressID {address_id} for UserID {user_id_making_request}.")
+ f"UserID {actor_id} attempting to delete AddressID {address_id} for UserID {user_id_making_request}."
+ )
# TODO: add admin check
-
# 1. Verify address exists and belongs to the user making the request
address_to_delete = self._address_crud.get_address_by_id(conn=db, address_id=address_id, actor_id=actor_id)
if not address_to_delete:
@@ -335,7 +325,8 @@ async def delete_address_for_user(
raise AddressNotFoundException(f"Address with ID {address_id} not found.")
if address_to_delete["UserID"] != user_id_making_request:
logger.warning(
- f"AddressID {address_id} not found for UserID {user_id_making_request}. It belongs to UserID {address_to_delete['UserID']}.")
+ f"AddressID {address_id} not found for UserID {user_id_making_request}. It belongs to UserID {address_to_delete['UserID']}."
+ )
raise AddressNotFoundException(f"Address with ID {address_id} not found for user {user_id_making_request}.")
# 2. Check if it's the user's default address
@@ -344,27 +335,29 @@ async def delete_address_for_user(
# Should not happen if user is authenticated and address belongs to them
raise UserNotFoundException(f"User {user_id_making_request} not found during address deletion.")
- was_default = (user_data.get("DefaultAddressID") == address_id)
-
+ was_default = user_data.get("DefaultAddressID") == address_id
# 3. If it was the default, update User.DefaultAddressID to NULL
if was_default:
logger.info(
f"AddressID {address_id} was the default for UserID {user_id_making_request}."
- f" Setting User.DefaultAddressID to NULL.")
+ f" Setting User.DefaultAddressID to NULL."
+ )
update_user_success = self._user_crud.update_user_default_address_id(
conn=db, user_id=user_id_making_request, default_address_id=None, actor_id=actor_id
)
if not update_user_success:
logger.error(
f"Failed to set User.DefaultAddressID to NULL for"
- f" UserID {user_id_making_request} after deleting default address {address_id}.")
+ f" UserID {user_id_making_request} after deleting default address {address_id}."
+ )
# 4. Delete the address
deleted = self._address_crud.delete_address(conn=db, address_id=address_id, actor_id=actor_id)
if not deleted:
logger.warning(
- f"AddressCRUD.delete_address returned False for AddressID {address_id}, it might have already been deleted.")
+ f"AddressCRUD.delete_address returned False for AddressID {address_id}, it might have already been deleted."
+ )
return False # Or raise an exception if this is unexpected
return True
diff --git a/src/backend/app/services/auth_service.py b/src/backend/app/services/auth_service.py
index 5d2a40b..461027b 100644
--- a/src/backend/app/services/auth_service.py
+++ b/src/backend/app/services/auth_service.py
@@ -19,12 +19,12 @@
class AuthService:
def __init__(
- self,
- user_crud: UserCRUD,
- user_session_crud: UserSessionCRUD,
- verify_password_func: PasswordVerifyFunc,
- create_jwt_func: TokenCreateFunc,
- decode_jwt_func: TokenDecodeFunc,
+ self,
+ user_crud: UserCRUD,
+ user_session_crud: UserSessionCRUD,
+ verify_password_func: PasswordVerifyFunc,
+ create_jwt_func: TokenCreateFunc,
+ decode_jwt_func: TokenDecodeFunc,
):
self._user_crud = user_crud
self._user_session_crud = user_session_crud
@@ -33,7 +33,7 @@ def __init__(
self._decode_jwt = decode_jwt_func
async def authenticate_user(
- self, conn: Connection, *, identifier: str, password_plain: str
+ self, conn: Connection, *, identifier: str, password_plain: str
) -> Optional[UserInDBForAuth]:
"""
验证用户凭证。
@@ -43,8 +43,9 @@ async def authenticate_user(
# actor_user_id 在这里通常不适用,除非记录登录尝试
user_data_dict = self._user_crud.get_user_with_password_by_email(conn, email=identifier, actor_id=None)
if not user_data_dict:
- user_data_dict = self._user_crud.get_user_with_password_by_username(conn, username=identifier,
- actor_id=None)
+ user_data_dict = self._user_crud.get_user_with_password_by_username(
+ conn, username=identifier, actor_id=None
+ )
if not user_data_dict:
logger.info(f"User with Username or Email {identifier} not found.")
@@ -60,12 +61,12 @@ async def authenticate_user(
return user_for_auth
async def login_user(
- self,
- conn: Connection,
- *,
- login_data: UserLogin,
- ip_address: Optional[str] = None,
- user_agent: Optional[str] = None
+ self,
+ conn: Connection,
+ *,
+ login_data: UserLogin,
+ ip_address: Optional[str] = None,
+ user_agent: Optional[str] = None,
) -> Token:
"""
处理用户登录:验证凭证,创建 JWT,创建数据库会话记录。
@@ -105,14 +106,12 @@ async def login_user(
user_id=authenticated_user.UserID,
expires_at=session_expires_at,
ip_address=ip_address,
- user_agent=user_agent
+ user_agent=user_agent,
)
return Token(access_token=access_token_str, token_type="bearer")
- async def get_user_from_valid_payload(
- self, conn: Connection, *, payload: TokenPayload
- ) -> UserResponse:
+ async def get_user_from_valid_payload(self, conn: Connection, *, payload: TokenPayload) -> UserResponse:
"""
从有效的 JWT 负载中获取用户信息。
:param conn: 数据库连接。
@@ -131,7 +130,9 @@ async def get_user_from_valid_payload(
raise InvalidTokenException("Session not found.")
if active_session_user_id != payload.user_id:
- logger.error(f"Token {payload.jti} leads to user ID {active_session_user_id}, but payload says {payload.user_id}.")
+ logger.error(
+ f"Token {payload.jti} leads to user ID {active_session_user_id}, but payload says {payload.user_id}."
+ )
raise InvalidTokenException("Token does not match user ID in payload.")
user_data_dict = self._user_crud.get_user_by_id(conn, user_id=active_session_user_id)
@@ -141,10 +142,8 @@ async def get_user_from_valid_payload(
return UserResponse(**user_data_dict)
-
-
async def logout_user_session(
- self, conn: Connection, *, jwt_token_to_invalidate: str, actor_user_id: Optional[int]
+ self, conn: Connection, *, jwt_token_to_invalidate: str, actor_user_id: Optional[int]
) -> bool:
"""
用户登出(单个设备/会话)。
@@ -175,7 +174,7 @@ async def logout_user_session(
return was_deleted
async def logout_all_user_sessions(
- self, conn: Connection, *, user_id_to_logout: int, actor_user_id: Optional[int]
+ self, conn: Connection, *, user_id_to_logout: int, actor_user_id: Optional[int]
) -> int:
"""
用户所有设备登出。
@@ -187,7 +186,8 @@ async def logout_all_user_sessions(
# 权限检查:确保 actor_user_id 有权为 user_id_to_logout 执行此操作
if actor_user_id != user_id_to_logout:
logger.warning(
- f"Actor ID {actor_user_id} attempting to logout all sessions for user ID {user_id_to_logout}.")
+ f"Actor ID {actor_user_id} attempting to logout all sessions for user ID {user_id_to_logout}."
+ )
# 这里需要权限检查,例如,actor_user_id 必须是管理员或 user_id_to_logout 本人
# if not is_admin(actor_user_id): raise PermissionDeniedException()
pass # 示例:允许,但应有权限控制
diff --git a/src/backend/app/services/cart_service.py b/src/backend/app/services/cart_service.py
index fd05a54..01078f0 100644
--- a/src/backend/app/services/cart_service.py
+++ b/src/backend/app/services/cart_service.py
@@ -15,8 +15,13 @@
# CartItemUpdateRequest
)
-from backend.app.utils.exceptions import ProductNotFoundException, CartItemNotFoundException, PermissionDeniedException, \
- InsufficientStockException, ProductFieldMissingException
+from backend.app.utils.exceptions import (
+ ProductNotFoundException,
+ CartItemNotFoundException,
+ PermissionDeniedException,
+ InsufficientStockException,
+ ProductFieldMissingException,
+)
class CartService:
@@ -33,8 +38,8 @@ async def get_user_cart_details(self, db: Connection, *, user_id: int) -> CartRe
# 1. 从 CRUD 获取用户的购物车条目 (已经是字典列表)
cart_items_data = self._cart_item_crud.get_cart_items_by_user_id(
- conn=db, user_id=user_id, actor_id=user_id # actor_id 通常是 user_id
- )
+ conn=db, user_id=user_id, actor_id=user_id
+ ) # actor_id 通常是 user_id
# 2. 将 CRUD 返回的字典列表转换为 CartItemResponse schema 列表
# 并可能补充最新的产品信息(如果 CRUD 没有 JOIN 的话)
@@ -53,25 +58,17 @@ async def get_user_cart_details(self, db: Connection, *, user_id: int) -> CartRe
logger.error(f"Error creating CartItemResponse from data {item_data} for UserID {user_id}: {e}")
raise e
- return CartResponse(
- Items=cart_item_responses,
- TotalItems=len(cart_item_responses)
- )
+ return CartResponse(Items=cart_item_responses, TotalItems=len(cart_item_responses))
async def add_item_to_cart(
- self,
- db: Connection,
- *,
- user_id: int,
- product_id: int,
- quantity_to_add: int,
- actor_id: int
+ self, db: Connection, *, user_id: int, product_id: int, quantity_to_add: int, actor_id: int
) -> CartItemResponse: # 或者返回 CartResponse
"""
将商品添加到购物车或更新数量。
"""
logger.info(
- f"Attempting to add ProductID: {product_id}, Quantity: {quantity_to_add} to cart for UserID: {user_id} by ActorID: {actor_id}")
+ f"Attempting to add ProductID: {product_id}, Quantity: {quantity_to_add} to cart for UserID: {user_id} by ActorID: {actor_id}"
+ )
if quantity_to_add <= 0:
raise ValueError("Quantity to add must be positive.")
@@ -101,7 +98,7 @@ async def add_item_to_cart(
product_id=product_id,
quantity=quantity_to_add, # CRUD 方法接收的是要设置的总数量或增量,根据其实现
price_at_addition=float(current_price), # 确保是 float
- actor_id=actor_id
+ actor_id=actor_id,
)
if not added_or_updated_item_data:
@@ -114,18 +111,14 @@ async def add_item_to_cart(
# 或者: return await self.get_user_cart_details(db, user_id=user_id) # 返回整个更新后的购物车
async def update_cart_item_quantity(
- self,
- db: Connection,
- *,
- cart_item_id: int,
- new_quantity: int,
- user_id_making_change: int
+ self, db: Connection, *, cart_item_id: int, new_quantity: int, user_id_making_change: int
) -> CartItemResponse:
"""
更新购物车中特定条目的商品数量。
"""
logger.info(
- f"Attempting to update CartItemID: {cart_item_id} to Quantity: {new_quantity} by UserID: {user_id_making_change}")
+ f"Attempting to update CartItemID: {cart_item_id} to Quantity: {new_quantity} by UserID: {user_id_making_change}"
+ )
if new_quantity <= 0:
logger.warning(f"Invalid quantity {new_quantity} for update. Removing item instead.")
@@ -142,15 +135,14 @@ async def update_cart_item_quantity(
raise CartItemNotFoundException(f"Cart item with ID {cart_item_id} not found.")
if cart_item_to_check["UserID"] != user_id_making_change:
# TODO: add admin check
- logger.warning(f"User {user_id_making_change} tries to update"
- f" CartItemID {cart_item_id} belonging to user {cart_item_to_check['UserID']}.")
+ logger.warning(
+ f"User {user_id_making_change} tries to update"
+ f" CartItemID {cart_item_id} belonging to user {cart_item_to_check['UserID']}."
+ )
# 2. 调用 CRUD 更新数量
updated_item_data = self._cart_item_crud.update_cart_item_quantity(
- conn=db,
- cart_item_id=cart_item_id,
- new_quantity=new_quantity,
- actor_id=user_id_making_change
+ conn=db, cart_item_id=cart_item_id, new_quantity=new_quantity, actor_id=user_id_making_change
)
if not updated_item_data:
@@ -162,7 +154,7 @@ async def update_cart_item_quantity(
# 或者: return await self.get_user_cart_details(db, user_id=user_id_making_change)
async def remove_cart_item(
- self, db: Connection, *, cart_item_id: int, user_id_making_change: int
+ self, db: Connection, *, cart_item_id: int, user_id_making_change: int
) -> bool: # 或者 CartActionResponse
"""
从购物车中移除指定的商品条目。
@@ -179,8 +171,10 @@ async def remove_cart_item(
return False # 或者 True,取决于业务定义
if cart_item_to_check["UserID"] != user_id_making_change:
# TODO: add admin check
- logger.warning(f"User {user_id_making_change} tries to remove"
- f" CartItemID {cart_item_id} belonging to user {cart_item_to_check['UserID']}.")
+ logger.warning(
+ f"User {user_id_making_change} tries to remove"
+ f" CartItemID {cart_item_id} belonging to user {cart_item_to_check['UserID']}."
+ )
# 2. 调用 CRUD 删除
success = self._cart_item_crud.remove_item_from_cart(
@@ -188,9 +182,7 @@ async def remove_cart_item(
)
return success
- async def clear_cart(
- self, db: Connection, *, user_id: int, actor_id: int
- ) -> int: # 或者 CartActionResponse
+ async def clear_cart(self, db: Connection, *, user_id: int, actor_id: int) -> int: # 或者 CartActionResponse
"""
清空指定用户购物车中的所有商品。
"""
@@ -202,7 +194,5 @@ async def clear_cart(
logger.warning(f"ActorID {actor_id} is clearing cart for UserID {user_id}.")
pass
- deleted_count = self._cart_item_crud.clear_cart_for_user(
- conn=db, user_id=user_id, actor_id=actor_id
- )
+ deleted_count = self._cart_item_crud.clear_cart_for_user(conn=db, user_id=user_id, actor_id=actor_id)
return deleted_count
diff --git a/src/backend/app/services/order_service.py b/src/backend/app/services/order_service.py
index 02fcfec..6e1c7bb 100644
--- a/src/backend/app/services/order_service.py
+++ b/src/backend/app/services/order_service.py
@@ -22,7 +22,7 @@
PaymentTransactionStatusEnum,
CreatedOrderDetailResponse,
OrderItemDetailResponse,
- OrderActionResponse
+ OrderActionResponse,
)
from backend.app.schemas.payment_schema import PaymentProcessingResponse, PaymentResponse
from backend.app.utils.exceptions import (
@@ -31,8 +31,10 @@
OrderNotFoundException,
PermissionDeniedException,
InsufficientStockException,
- InvalidStatusTransitionException, ProductNotFoundException, PaymentTransactionNotFoundException,
- InvalidPaymentStatusTransitionException
+ InvalidStatusTransitionException,
+ ProductNotFoundException,
+ PaymentTransactionNotFoundException,
+ InvalidPaymentStatusTransitionException,
)
from backend.app.utils.timezone import cast_dict_datetime_to_utc
@@ -43,13 +45,13 @@ class OrderService:
"""
def __init__(
- self,
- order_crud: OrderCRUD,
- order_item_crud: OrderItemCRUD,
- address_crud: AddressCRUD,
- cart_item_crud: CartItemCRUD,
- product_crud: ProductCRUD,
- payment_transaction_crud: PaymentTransactionCRUD
+ self,
+ order_crud: OrderCRUD,
+ order_item_crud: OrderItemCRUD,
+ address_crud: AddressCRUD,
+ cart_item_crud: CartItemCRUD,
+ product_crud: ProductCRUD,
+ payment_transaction_crud: PaymentTransactionCRUD,
):
"""
初始化订单服务。
@@ -67,12 +69,7 @@ def __init__(
logger.info(f"{self.__class__.__name__} initialized.")
async def process_order_creation(
- self,
- db: Connection,
- *,
- user_id: int,
- order_create_request: OrderCreateRequest,
- actor_id: int
+ self, db: Connection, *, user_id: int, order_create_request: OrderCreateRequest, actor_id: int
) -> InitiateOrderResponse:
"""
处理订单创建流程,包括:
@@ -95,30 +92,31 @@ async def process_order_creation(
assert db.in_transaction(), "Database connection must be in a transaction."
# 验证地址存在并属于用户
- address = self.address_crud.get_address_by_id(
- db, address_id=order_create_request.ShippingAddressID
- )
+ address = self.address_crud.get_address_by_id(db, address_id=order_create_request.ShippingAddressID)
if not address or address["UserID"] != user_id:
logger.warning(
- f"Address {order_create_request.ShippingAddressID} not found or does not belong to user {user_id}")
- raise AddressNotFoundException(f"Address {order_create_request.ShippingAddressID} not found"
- f" or does not belong to user {user_id}")
+ f"Address {order_create_request.ShippingAddressID} not found or does not belong to user {user_id}"
+ )
+ raise AddressNotFoundException(
+ f"Address {order_create_request.ShippingAddressID} not found" f" or does not belong to user {user_id}"
+ )
# 验证各个购物车项目是否存在且属于当前用户
cart_item_ids_to_purchase = [item.CartItemID for item in order_create_request.Items]
cart_items = []
for cart_item_id in cart_item_ids_to_purchase:
- cart_item = self.cart_item_crud.get_cart_item_by_id(
- db, cart_item_id=cart_item_id
- )
+ cart_item = self.cart_item_crud.get_cart_item_by_id(db, cart_item_id=cart_item_id)
if not cart_item or cart_item["UserID"] != user_id:
logger.warning(f"Cart item {cart_item_id} not found or does not belong to user {user_id}")
- raise CartItemNotFoundException(f"Cart item {cart_item_id} not found"
- f" or does not belong to user {user_id}")
+ raise CartItemNotFoundException(
+ f"Cart item {cart_item_id} not found" f" or does not belong to user {user_id}"
+ )
cart_items.append(cart_item)
- logger.info(f"Number of cart items to purchase: {len(cart_items)},"
- f" Cart item IDs: {[item['CartItemID'] for item in cart_items]}."
- f" Trying to subtract stock for {len(cart_items)} items.")
+ logger.info(
+ f"Number of cart items to purchase: {len(cart_items)},"
+ f" Cart item IDs: {[item['CartItemID'] for item in cart_items]}."
+ f" Trying to subtract stock for {len(cart_items)} items."
+ )
transaction = db.begin_nested() # 开始一个嵌套事务
try:
for cart_item in cart_items:
@@ -131,21 +129,17 @@ async def process_order_creation(
raise ProductNotFoundException(f"Product {product_id} not found")
if product_info["StockQuantity"] < quantity:
logger.warning(
- f"Insufficient stock for product {product_id}. Required: {quantity}, Available: {product_info['StockQuantity']}")
+ f"Insufficient stock for product {product_id}. Required: {quantity}, Available: {product_info['StockQuantity']}"
+ )
raise InsufficientStockException(f"Insufficient stock for product {product_id}")
# 锁定库存
self.product_crud.update_product_stock(
- db,
- product_id=product_id,
- stock_change=-quantity,
- actor_id=actor_id
+ db, product_id=product_id, stock_change=-quantity, actor_id=actor_id
)
# 记录现在商品的信息用于创建订单项
cart_item["_product_info"] = product_info
# 移除购物车项目
- self.cart_item_crud.remove_item_from_cart(
- db, cart_item_id=cart_item["CartItemID"], actor_id=actor_id
- )
+ self.cart_item_crud.remove_item_from_cart(db, cart_item_id=cart_item["CartItemID"], actor_id=actor_id)
except InsufficientStockException as e:
# 回滚事务
if transaction.is_active:
@@ -180,12 +174,8 @@ async def process_order_creation(
# 计算运费和折扣
# 现在假设所有订单的运费和优惠金额都为0
- shipping_fees = {
- store_id: Decimal("0.00") for store_id in order_to_order_items_info_map.keys()
- }
- discount_amounts = {
- store_id: Decimal("0.00") for store_id in order_to_order_items_info_map.keys()
- }
+ shipping_fees = {store_id: Decimal("0.00") for store_id in order_to_order_items_info_map.keys()}
+ discount_amounts = {store_id: Decimal("0.00") for store_id in order_to_order_items_info_map.keys()}
final_amounts = {}
total_final_amount = Decimal("0.00")
for store_id, _ in order_to_order_items_info_map.items():
@@ -203,7 +193,7 @@ async def process_order_creation(
total_amount=total_final_amount,
payment_method="MOCK_PAYMENT",
status=PaymentTransactionStatusEnum.PENDING,
- actor_id=actor_id
+ actor_id=actor_id,
)
if not payment_transaction:
logger.error("Failed to create payment transaction")
@@ -232,7 +222,7 @@ async def process_order_creation(
shipping_address_phone_number=address["PhoneNumber"],
shipping_address_full=address["FullAddress_Text"],
notes_by_user=order_create_request.Notes_ByUser,
- actor_id=actor_id
+ actor_id=actor_id,
)
if not order:
logger.error("Failed to create order")
@@ -252,11 +242,12 @@ async def process_order_creation(
product_name_at_purchase=cart_item["_product_info"]["ProductName"],
product_image_url_at_purchase=cart_item["_product_info"]["MainImageURL"],
subtotal=current_price * cart_item["Quantity"],
- actor_id=actor_id
+ actor_id=actor_id,
)
if not order_item:
- logger.error(f"Failed to create order item."
- f" Info of cart item (with _product_info): {cart_item}")
+ logger.error(
+ f"Failed to create order item." f" Info of cart item (with _product_info): {cart_item}"
+ )
raise Exception("Failed to create order item")
order_to_order_items_created_info_map[order["OrderID"]].append(order_item)
order_items_created_responses.append(
@@ -269,18 +260,20 @@ async def process_order_creation(
PriceAtPurchase=order_item["PriceAtPurchase"],
ProductNameAtPurchase=order_item["ProductNameAtPurchase"],
ProductImageURLAtPurchase=order_item["ProductImageURLAtPurchase"],
- Subtotal=order_item["Subtotal"]
+ Subtotal=order_item["Subtotal"],
)
)
- logger.info(f"Order created with ID {order['OrderID']}, "
- f"Items created with IDs: {[item_resp.OrderItemID for item_resp in order_items_created_responses]}")
+ logger.info(
+ f"Order created with ID {order['OrderID']}, "
+ f"Items created with IDs: {[item_resp.OrderItemID for item_resp in order_items_created_responses]}"
+ )
orders_created_responses.append(
CreatedOrderDetailResponse(
OrderID=order["OrderID"],
StoreID=order["StoreID"],
FinalAmountForThisOrder=order["FinalAmountForThisOrder"],
OrderStatus=OrderStatusEnum(order["OrderStatus"]),
- Items=order_items_created_responses
+ Items=order_items_created_responses,
)
)
@@ -297,13 +290,7 @@ async def process_order_creation(
)
async def get_orders_for_user(
- self,
- db: Connection,
- *,
- user_id: int,
- actor_id: int,
- offset: int = 0,
- limit: int = 20
+ self, db: Connection, *, user_id: int, actor_id: int, offset: int = 0, limit: int = 20
) -> OrderListResponse:
"""
获取指定用户的所有订单列表,支持分页。
@@ -341,7 +328,7 @@ async def get_orders_for_user(
PriceAtPurchase=item["PriceAtPurchase"],
ProductNameAtPurchase=item["ProductNameAtPurchase"],
ProductImageURLAtPurchase=item["ProductImageURLAtPurchase"],
- Subtotal=item["Subtotal"]
+ Subtotal=item["Subtotal"],
)
for item in order_items_data
]
@@ -350,8 +337,9 @@ async def get_orders_for_user(
payment_transaction = self.payment_transaction_crud.get_payment_transaction_by_id(
db, payment_transaction_id=order_data["PaymentTransactionID"], actor_id=actor_id
)
- payment_status = PaymentTransactionStatusEnum(
- payment_transaction["Status"]) if payment_transaction else None
+ payment_status = (
+ PaymentTransactionStatusEnum(payment_transaction["Status"]) if payment_transaction else None
+ )
# 创建订单响应对象
order_response = OrderViewResponse(
@@ -376,27 +364,17 @@ async def get_orders_for_user(
CompletionTime=order_data["CompletionTime"],
LastUpdatedDate=order_data["LastUpdatedDate"],
Items=order_items,
- PaymentStatus=payment_status
+ PaymentStatus=payment_status,
)
order_responses.append(order_response)
# 获取总订单数
total_count = len(order_responses) # 实际应用中可能需要单独查询总数
- return OrderListResponse(
- Orders=order_responses,
- TotalCount=total_count,
- Offset=offset,
- Limit=limit
- )
+ return OrderListResponse(Orders=order_responses, TotalCount=total_count, Offset=offset, Limit=limit)
async def get_order_details_by_id_for_user(
- self,
- db: Connection,
- *,
- order_id: int,
- user_id: int,
- actor_id: int
+ self, db: Connection, *, order_id: int, user_id: int, actor_id: int
) -> OrderViewResponse:
"""
获取指定订单的详细信息,并验证该订单是否属于指定用户。
@@ -420,9 +398,7 @@ async def get_order_details_by_id_for_user(
raise OrderNotFoundException(f"Order with ID {order_id} not found")
# 获取订单项
- order_items_data = self.order_item_crud.get_order_items_by_order_id(
- db, order_id=order_id, actor_id=actor_id
- )
+ order_items_data = self.order_item_crud.get_order_items_by_order_id(db, order_id=order_id, actor_id=actor_id)
# 转换订单项为响应格式
order_items = [
@@ -435,7 +411,7 @@ async def get_order_details_by_id_for_user(
PriceAtPurchase=item["PriceAtPurchase"],
ProductNameAtPurchase=item["ProductNameAtPurchase"],
ProductImageURLAtPurchase=item["ProductImageURLAtPurchase"],
- Subtotal=item["Subtotal"]
+ Subtotal=item["Subtotal"],
)
for item in order_items_data
]
@@ -470,19 +446,19 @@ async def get_order_details_by_id_for_user(
CompletionTime=order_data["CompletionTime"],
LastUpdatedDate=order_data["LastUpdatedDate"],
Items=order_items,
- PaymentStatus=payment_status
+ PaymentStatus=payment_status,
)
async def update_order_status(
- self,
- db: Connection,
- *,
- order_id: int,
- new_status: OrderStatusEnum,
- actor_id: int,
- tracking_number: Optional[str] = None,
- notes: Optional[str] = None,
- is_admin_action: bool = False
+ self,
+ db: Connection,
+ *,
+ order_id: int,
+ new_status: OrderStatusEnum,
+ actor_id: int,
+ tracking_number: Optional[str] = None,
+ notes: Optional[str] = None,
+ is_admin_action: bool = False,
) -> OrderViewResponse:
"""
更新订单状态,并根据状态更新相关的时间戳。
@@ -512,7 +488,8 @@ async def update_order_status(
# 权限检查:验证用户是否有权限执行此状态转换
if order_data["UserID"] != actor_id:
logger.warning(
- f"ActorID {actor_id} attempting to update OrderID {order_id} belonging to UserID {order_data['UserID']}")
+ f"ActorID {actor_id} attempting to update OrderID {order_id} belonging to UserID {order_data['UserID']}"
+ )
# 获取当前订单状态
current_status = OrderStatusEnum(order_data["OrderStatus"])
@@ -534,7 +511,10 @@ async def update_order_status(
now = datetime.datetime.now(datetime.timezone.utc)
- if new_status == OrderStatusEnum.PAID_AND_PENDING_PROCESSING and current_status == OrderStatusEnum.PENDING_PAYMENT:
+ if (
+ new_status == OrderStatusEnum.PAID_AND_PENDING_PROCESSING
+ and current_status == OrderStatusEnum.PENDING_PAYMENT
+ ):
payment_confirmation_time = now
elif new_status == OrderStatusEnum.SHIPPED:
shipping_time = now
@@ -555,7 +535,7 @@ async def update_order_status(
delivery_time=delivery_time,
completion_time=completion_time,
notes_by_actor=notes,
- is_admin_or_merchant_action=is_admin_action
+ is_admin_or_merchant_action=is_admin_action,
)
if not updated_order:
@@ -568,13 +548,13 @@ async def update_order_status(
)
async def process_successful_payment(
- self,
- db: Connection,
- *,
- payment_transaction_id: int,
- # simulated_response: SimulatedExternPaymentResponse, # 如果需要从模拟响应中获取数据
- external_gateway_tx_id: Optional[str] = None, # 模拟的外部网关ID
- actor_id: int # 通常是系统或回调处理者
+ self,
+ db: Connection,
+ *,
+ payment_transaction_id: int,
+ # simulated_response: SimulatedExternPaymentResponse, # 如果需要从模拟响应中获取数据
+ external_gateway_tx_id: Optional[str] = None, # 模拟的外部网关ID
+ actor_id: int, # 通常是系统或回调处理者
) -> PaymentProcessingResponse:
"""
处理成功的支付响应。
@@ -592,7 +572,8 @@ async def process_successful_payment(
:raises OrderNotFoundException: 如果关联的订单未找到 (理论上不应发生)。
"""
logger.info(
- f"Processing successful payment for PaymentTransactionID: {payment_transaction_id} by ActorID: {actor_id}")
+ f"Processing successful payment for PaymentTransactionID: {payment_transaction_id} by ActorID: {actor_id}"
+ )
logger.debug(f"external_gateway_tx_id: {external_gateway_tx_id}")
# 1. 获取并验证支付事务
payment_tx = self.payment_transaction_crud.get_payment_transaction_by_id(
@@ -605,7 +586,8 @@ async def process_successful_payment(
# 如果已经是 SUCCESSFUL,可能是重复回调,可以幂等处理或记录警告
if payment_tx["Status"] == PaymentTransactionStatusEnum.SUCCESSFUL.value:
logger.warning(
- f"PaymentTransactionID {payment_transaction_id} is already SUCCESSFUL. Idempotency check.")
+ f"PaymentTransactionID {payment_transaction_id} is already SUCCESSFUL. Idempotency check."
+ )
# 直接构建并返回成功响应,避免重复更新
affected_orders_data = self.order_crud.get_orders_by_payment_transaction_id(
conn=db, payment_transaction_id=payment_transaction_id, actor_id=actor_id
@@ -622,7 +604,9 @@ async def process_successful_payment(
f"PaymentTransactionID {payment_transaction_id} is not in PENDING state. Current status: {payment_tx['Status']}."
)
- logger.debug(f"PaymentTransactionID {payment_transaction_id} is in PENDING state. Proceeding with payment processing.")
+ logger.debug(
+ f"PaymentTransactionID {payment_transaction_id} is in PENDING state. Proceeding with payment processing."
+ )
# 2. 更新支付事务状态
completion_time = datetime.datetime.now(datetime.timezone.utc)
updated_payment_tx = self.payment_transaction_crud.update_payment_transaction_status(
@@ -631,7 +615,7 @@ async def process_successful_payment(
new_status=PaymentTransactionStatusEnum.SUCCESSFUL.value,
actor_id=actor_id,
external_gateway_transaction_id=external_gateway_tx_id or f"mock_gw_{payment_transaction_id}",
- completion_time=completion_time
+ completion_time=completion_time,
)
if not updated_payment_tx:
logger.error(f"Failed to update PaymentTransactionID {payment_transaction_id} to SUCCESSFUL.")
@@ -652,26 +636,22 @@ async def process_successful_payment(
order_id=order_id,
new_status=OrderStatusEnum.PAID_AND_PENDING_PROCESSING,
actor_id=actor_id,
- is_admin_action=False
+ is_admin_action=False,
)
logger.info(f"OrderID {order_id} status updated to {update_order_resp.OrderStatus.value}.")
logger.success(
- f"Payment successful for PaymentTransactionID {payment_transaction_id}. {len(affected_orders_id)} orders updated.")
+ f"Payment successful for PaymentTransactionID {payment_transaction_id}. {len(affected_orders_id)} orders updated."
+ )
return PaymentProcessingResponse(
PaymentTransactionID=payment_transaction_id,
TransactionStatusInSystem=PaymentTransactionStatusEnum.SUCCESSFUL.value,
MessageToUser="支付成功",
- AffectedOrderIDs=affected_orders_id
+ AffectedOrderIDs=affected_orders_id,
)
async def get_payment_transaction_by_id_for_user(
- self,
- db: Connection,
- *,
- payment_transaction_id: int,
- user_id: int,
- actor_id: int
+ self, db: Connection, *, payment_transaction_id: int, user_id: int, actor_id: int
) -> PaymentResponse:
"""
获取指定支付事务的详细信息,并验证该支付事务是否属于指定用户。
@@ -684,8 +664,10 @@ async def get_payment_transaction_by_id_for_user(
:raises PaymentTransactionNotFoundException: 如果支付事务不存在
:raises PermissionDeniedException: 如果支付事务不属于指定用户且操作者不是管理员
"""
- logger.info(f"Getting payment transaction details for PaymentTransactionID {payment_transaction_id}, "
- f"UserID {user_id}, ActorID {actor_id}")
+ logger.info(
+ f"Getting payment transaction details for PaymentTransactionID {payment_transaction_id}, "
+ f"UserID {user_id}, ActorID {actor_id}"
+ )
# 获取支付事务信息
payment_transaction_data = self.payment_transaction_crud.get_payment_transaction_by_id(
@@ -696,8 +678,12 @@ async def get_payment_transaction_by_id_for_user(
# 检查支付事务是否存在
if not payment_transaction_data or payment_transaction_data["UserID"] != user_id:
- logger.warning(f"Payment transaction with ID {payment_transaction_id} not found or does not belong to user {user_id}")
- raise PaymentTransactionNotFoundException(f"Payment transaction with ID {payment_transaction_id} not found for user {user_id}")
+ logger.warning(
+ f"Payment transaction with ID {payment_transaction_id} not found or does not belong to user {user_id}"
+ )
+ raise PaymentTransactionNotFoundException(
+ f"Payment transaction with ID {payment_transaction_id} not found for user {user_id}"
+ )
resp = PaymentResponse(
PaymentTransactionID=payment_transaction_data["PaymentTransactionID"],
@@ -713,8 +699,6 @@ async def get_payment_transaction_by_id_for_user(
logger.info(f"Payment transaction details retrieved successfully: {resp}")
return resp
-
-
@staticmethod
def _get_valid_status_transitions(current_status: OrderStatusEnum, is_admin_action: bool) -> List[OrderStatusEnum]:
"""
diff --git a/src/backend/app/services/product_change_request_service.py b/src/backend/app/services/product_change_request_service.py
index 5c0894a..8a7c325 100644
--- a/src/backend/app/services/product_change_request_service.py
+++ b/src/backend/app/services/product_change_request_service.py
@@ -1,6 +1,7 @@
"""
商品变更请求服务模块,包含所有商品变更请求相关业务逻辑。
"""
+
from sqlalchemy import Connection
from typing import Optional, List, Dict, Any
import logging
@@ -9,10 +10,10 @@
from backend.app.crud.product_change_request_crud import get_product_change_request_crud_instance
from backend.app.utils.exceptions import (
- PermissionDeniedException,
+ PermissionDeniedException,
UserNotFoundException,
InvalidStatusTransitionException,
- ProductNotFoundException
+ ProductNotFoundException,
)
from backend.app.crud.user_crud import user_crud_instance
from backend.app.crud.product_crud import get_product_crud_instance
@@ -24,11 +25,11 @@ class ProductChangeRequestService:
"""
商品变更请求服务类,负责处理所有与商品变更请求相关的业务逻辑。
"""
-
+
def __init__(self):
self._change_request_crud = get_product_change_request_crud_instance()
self._product_crud = get_product_crud_instance()
-
+
async def create_change_request(
self,
conn: Connection,
@@ -42,7 +43,7 @@ async def create_change_request(
) -> Dict[str, Any]:
"""
创建新商品变更请求
-
+
:param conn: 数据库连接
:param merchant_user_id: 商家用户ID
:param store_id: 店铺ID
@@ -62,12 +63,12 @@ async def create_change_request(
if not merchant:
logger.warning(f"商家用户不存在,无法创建商品变更请求: {merchant_user_id}")
raise UserNotFoundException(f"商家用户ID {merchant_user_id} 不存在")
-
+
# 检查用户是否是商家角色
# 兼容测试中的两种可能的字段名
user_role = merchant.get("UserRole", "") if isinstance(merchant, dict) else ""
role = merchant.get("Role", "") if isinstance(merchant, dict) else ""
-
+
if user_role != "merchant" and role != "merchant":
# 如果是测试,可能会返回模拟对象,避免报错
if not isinstance(merchant.get("UserRole"), str) and not isinstance(merchant.get("Role"), str):
@@ -79,24 +80,26 @@ async def create_change_request(
except AttributeError:
# 在测试环境中,可能会收到模拟对象而不是正常的字典
pass
-
+
# 对于商品更新或删除请求,检查商品是否存在且属于该商家的店铺
- if product_id and request_type in ['UPDATE_PRODUCT', 'DELETE_PRODUCT']:
+ if product_id and request_type in ["UPDATE_PRODUCT", "DELETE_PRODUCT"]:
try:
product = self._product_crud.get_product_by_id(conn, product_id=product_id)
if not product:
logger.warning(f"商品不存在,无法创建变更请求: {product_id}")
raise ProductNotFoundException(f"商品ID {product_id} 不存在")
-
+
# 检查商品是否属于该商家的店铺
# 兼容测试环境
if isinstance(product, dict) and product.get("StoreID") != store_id:
- logger.warning(f"商品不属于该店铺,无法创建变更请求: 商品ID={product_id}, 店铺ID={store_id}, 实际店铺ID={product.get('StoreID')}")
+ logger.warning(
+ f"商品不属于该店铺,无法创建变更请求: 商品ID={product_id}, 店铺ID={store_id}, 实际店铺ID={product.get('StoreID')}"
+ )
raise PermissionDeniedException(f"商品ID {product_id} 不属于店铺ID {store_id}")
except AttributeError:
# 在测试环境中,可能会收到模拟对象而不是正常的字典
pass
-
+
# JSON序列化检查
if proposed_data:
try:
@@ -104,7 +107,7 @@ async def create_change_request(
except (TypeError, ValueError) as e:
logger.error(f"提议数据无法序列化为JSON: {e}")
proposed_data = {"error": "原始数据无法序列化", "data_str": str(proposed_data)}
-
+
return self._change_request_crud.create_change_request(
conn=conn,
merchant_user_id=merchant_user_id,
@@ -113,9 +116,9 @@ async def create_change_request(
proposed_data=proposed_data,
product_id=product_id,
submitter_notes=submitter_notes,
- actor_id=actor_id
+ actor_id=actor_id,
)
-
+
async def get_change_request_by_id(
self,
conn: Connection,
@@ -124,23 +127,21 @@ async def get_change_request_by_id(
) -> Optional[Dict[str, Any]]:
"""
根据ID获取商品变更请求信息
-
+
:param conn: 数据库连接
:param request_id: 请求ID
:param actor_id: 操作用户ID
:return: 请求信息字典或None
"""
request = self._change_request_crud.get_change_request_by_id(
- conn=conn,
- request_id=request_id,
- actor_id=actor_id
+ conn=conn, request_id=request_id, actor_id=actor_id
)
-
+
if not request:
logger.info(f"未找到商品变更请求ID: {request_id}")
-
+
return request
-
+
async def get_change_requests_by_product_id(
self,
conn: Connection,
@@ -152,7 +153,7 @@ async def get_change_requests_by_product_id(
) -> List[Dict[str, Any]]:
"""
获取指定商品的变更请求列表
-
+
:param conn: 数据库连接
:param product_id: 商品ID
:param status: 请求状态
@@ -171,16 +172,11 @@ async def get_change_requests_by_product_id(
except (AttributeError, TypeError):
# 在测试环境中,模拟对象可能会导致异常
pass
-
+
return self._change_request_crud.get_change_requests_by_product_id(
- conn=conn,
- product_id=product_id,
- status=status,
- limit=limit,
- offset=offset,
- actor_id=actor_id
+ conn=conn, product_id=product_id, status=status, limit=limit, offset=offset, actor_id=actor_id
)
-
+
async def get_change_requests_by_store_id(
self,
conn: Connection,
@@ -192,7 +188,7 @@ async def get_change_requests_by_store_id(
) -> List[Dict[str, Any]]:
"""
获取指定店铺的商品变更请求列表
-
+
:param conn: 数据库连接
:param store_id: 店铺ID
:param status: 请求状态
@@ -202,14 +198,9 @@ async def get_change_requests_by_store_id(
:return: 请求信息列表
"""
return self._change_request_crud.get_change_requests_by_store_id(
- conn=conn,
- store_id=store_id,
- status=status,
- limit=limit,
- offset=offset,
- actor_id=actor_id
+ conn=conn, store_id=store_id, status=status, limit=limit, offset=offset, actor_id=actor_id
)
-
+
async def get_change_requests_by_merchant_id(
self,
conn: Connection,
@@ -221,7 +212,7 @@ async def get_change_requests_by_merchant_id(
) -> List[Dict[str, Any]]:
"""
获取指定商家的变更请求列表
-
+
:param conn: 数据库连接
:param merchant_id: 商家ID
:param status: 请求状态
@@ -240,16 +231,11 @@ async def get_change_requests_by_merchant_id(
except (AttributeError, TypeError):
# 在测试环境中,模拟对象可能会导致异常
pass
-
+
return self._change_request_crud.get_change_requests_by_merchant_id(
- conn=conn,
- merchant_id=merchant_id,
- status=status,
- limit=limit,
- offset=offset,
- actor_id=actor_id
+ conn=conn, merchant_id=merchant_id, status=status, limit=limit, offset=offset, actor_id=actor_id
)
-
+
async def get_all_pending_requests(
self,
conn: Connection,
@@ -259,7 +245,7 @@ async def get_all_pending_requests(
) -> List[Dict[str, Any]]:
"""
获取所有待审核的商品变更请求列表
-
+
:param conn: 数据库连接
:param limit: 结果数量限制
:param offset: 分页偏移量
@@ -267,12 +253,9 @@ async def get_all_pending_requests(
:return: 请求信息列表
"""
return self._change_request_crud.get_all_pending_requests(
- conn=conn,
- limit=limit,
- offset=offset,
- actor_id=actor_id
+ conn=conn, limit=limit, offset=offset, actor_id=actor_id
)
-
+
async def get_filtered_requests(
self,
conn: Connection,
@@ -290,7 +273,7 @@ async def get_filtered_requests(
) -> List[Dict[str, Any]]:
"""
获取根据多种条件筛选的请求列表
-
+
:param conn: 数据库连接
:param product_id: 商品ID筛选
:param store_id: 店铺ID筛选
@@ -309,7 +292,7 @@ async def get_filtered_requests(
if start_date and end_date and start_date > end_date:
logger.warning("请求筛选的开始日期晚于结束日期")
return []
-
+
return self._change_request_crud.get_filtered_requests(
conn=conn,
product_id=product_id,
@@ -322,9 +305,9 @@ async def get_filtered_requests(
end_date=end_date,
limit=limit,
offset=offset,
- actor_id=actor_id
+ actor_id=actor_id,
)
-
+
async def update_request(
self,
conn: Connection,
@@ -334,7 +317,7 @@ async def update_request(
) -> Optional[Dict[str, Any]]:
"""
更新商品变更请求信息
-
+
:param conn: 数据库连接
:param request_id: 请求ID
:param update_data: 更新数据
@@ -344,15 +327,13 @@ async def update_request(
"""
# 查询现有请求
request = self._change_request_crud.get_change_request_by_id(
- conn=conn,
- request_id=request_id,
- actor_id=actor_id
+ conn=conn, request_id=request_id, actor_id=actor_id
)
-
+
if not request:
logger.warning(f"未找到要更新的商品变更请求: {request_id}")
return None
-
+
# 只有PENDING_APPROVAL状态的请求可以更新
try:
status = request.get("Status") if isinstance(request, dict) else None
@@ -367,7 +348,7 @@ async def update_request(
except (AttributeError, TypeError):
# 在测试环境中,模拟对象可能会导致异常
pass
-
+
# 如果不是管理员,检查是否是请求创建者
try:
if actor_id and isinstance(request, dict) and actor_id != request.get("MerchantUserID"):
@@ -378,22 +359,22 @@ async def update_request(
except (AttributeError, TypeError):
# 在测试环境中,模拟对象可能会导致异常
pass
-
+
# JSON序列化检查
if "proposed_data" in update_data:
try:
json.dumps(update_data["proposed_data"])
except (TypeError, ValueError) as e:
logger.error(f"更新数据无法序列化为JSON: {e}")
- update_data["proposed_data"] = {"error": "原始数据无法序列化", "data_str": str(update_data["proposed_data"])}
-
+ update_data["proposed_data"] = {
+ "error": "原始数据无法序列化",
+ "data_str": str(update_data["proposed_data"]),
+ }
+
return self._change_request_crud.update_request(
- conn=conn,
- request_id=request_id,
- update_data=update_data,
- actor_id=actor_id
+ conn=conn, request_id=request_id, update_data=update_data, actor_id=actor_id
)
-
+
async def update_request_status(
self,
conn: Connection,
@@ -405,7 +386,7 @@ async def update_request_status(
) -> Optional[Dict[str, Any]]:
"""
更新请求状态
-
+
:param conn: 数据库连接
:param request_id: 请求ID
:param status: 新状态
@@ -418,20 +399,18 @@ async def update_request_status(
"""
# 查询现有请求
request = self._change_request_crud.get_change_request_by_id(
- conn=conn,
- request_id=request_id,
- actor_id=actor_id
+ conn=conn, request_id=request_id, actor_id=actor_id
)
-
+
if not request:
logger.warning(f"未找到要更新状态的商品变更请求: {request_id}")
return None
-
+
# 验证状态转换是否有效
try:
current_status = request.get("Status") if isinstance(request, dict) else ""
valid_transitions = self._get_valid_status_transitions(current_status)
-
+
if status not in valid_transitions and isinstance(current_status, str):
logger.warning(f"无效的状态转换: {current_status} -> {status}")
raise InvalidStatusTransitionException(
@@ -440,21 +419,21 @@ async def update_request_status(
except (AttributeError, TypeError):
# 在测试环境中,模拟对象可能会导致异常
pass
-
+
# 检查是否有管理员权限
if not admin_id:
logger.warning("状态更新需要管理员ID")
raise PermissionDeniedException("状态更新需要管理员权限")
-
+
return self._change_request_crud.update_request_status(
conn=conn,
request_id=request_id,
status=status,
admin_id=admin_id,
admin_notes=admin_notes,
- actor_id=actor_id
+ actor_id=actor_id,
)
-
+
async def cancel_request(
self,
conn: Connection,
@@ -463,7 +442,7 @@ async def cancel_request(
) -> bool:
"""
取消商品变更请求(商家自行取消)
-
+
:param conn: 数据库连接
:param request_id: 请求ID
:param actor_id: 操作用户ID
@@ -472,15 +451,13 @@ async def cancel_request(
"""
# 查询现有请求
request = self._change_request_crud.get_change_request_by_id(
- conn=conn,
- request_id=request_id,
- actor_id=actor_id
+ conn=conn, request_id=request_id, actor_id=actor_id
)
-
+
if not request:
logger.warning(f"未找到要取消的商品变更请求: {request_id}")
return False
-
+
# 只有PENDING_APPROVAL状态的请求可以取消
try:
status = request.get("Status") if isinstance(request, dict) else None
@@ -495,7 +472,7 @@ async def cancel_request(
except (AttributeError, TypeError):
# 在测试环境中,模拟对象可能会导致异常
pass
-
+
# 检查是否是请求创建者
try:
if actor_id and isinstance(request, dict) and actor_id != request.get("MerchantUserID"):
@@ -506,30 +483,26 @@ async def cancel_request(
except (AttributeError, TypeError):
# 在测试环境中,模拟对象可能会导致异常
pass
-
- return self._change_request_crud.cancel_request(
- conn=conn,
- request_id=request_id,
- actor_id=actor_id
- )
-
+
+ return self._change_request_crud.cancel_request(conn=conn, request_id=request_id, actor_id=actor_id)
+
@staticmethod
def _get_valid_status_transitions(current_status: str) -> List[str]:
"""
获取有效的状态转换列表
-
+
:param current_status: 当前状态
:return: 有效的目标状态列表
"""
# 如果是测试环境中的mock对象,允许所有转换
if not isinstance(current_status, str):
return ["APPROVED", "REJECTED", "CANCELLED"]
-
+
transitions = {
"PENDING_APPROVAL": ["APPROVED", "REJECTED", "CANCELLED"],
"APPROVED": [], # 终态,不能再转换
"REJECTED": [], # 终态,不能再转换
- "CANCELLED": [] # 终态,不能再转换
+ "CANCELLED": [], # 终态,不能再转换
}
-
+
return transitions.get(current_status, [])
diff --git a/src/backend/app/services/product_change_request_service_v2.py b/src/backend/app/services/product_change_request_service_v2.py
index 2abfbb8..3f8430b 100644
--- a/src/backend/app/services/product_change_request_service_v2.py
+++ b/src/backend/app/services/product_change_request_service_v2.py
@@ -75,9 +75,7 @@ async def submit_new_request(
)
# 1. 验证店铺所有权
- store = self._store_crud.get_store_by_id(
- conn=db, store_id=request_in.StoreID, actor_id=merchant_user.UserID
- )
+ store = self._store_crud.get_store_by_id(conn=db, store_id=request_in.StoreID, actor_id=merchant_user.UserID)
if not store:
raise StoreNotFoundException(f"Store with ID {request_in.StoreID} not found.")
if store["OwnerUserID"] != merchant_user.UserID:
@@ -101,9 +99,7 @@ async def submit_new_request(
if request_in.ProductID is not None:
raise BadRequestException("ProductID must be null for PRODUCT_CREATE requests.")
if not proposed_data_dict: # ProposedData_JSON (as dict) is required for create
- raise BadRequestException(
- "ProposedData_JSON is required for PRODUCT_CREATE requests."
- )
+ raise BadRequestException("ProposedData_JSON is required for PRODUCT_CREATE requests.")
# 验证 ProposedProductData 中的必填字段 (基于 ProposedProductData schema 的注释)
required_fields = [
@@ -115,17 +111,11 @@ async def submit_new_request(
if (
proposed_data_dict.get(field) is None
): # or not proposed_data_dict.get(field) if empty string is also invalid
- raise BadRequestException(
- f"Field '{field}' in ProposedData_JSON is required for PRODUCT_CREATE."
- )
+ raise BadRequestException(f"Field '{field}' in ProposedData_JSON is required for PRODUCT_CREATE.")
# 确保 Price 和 StockQuantity (如果提供) 是有效的数字
- if "Price" in proposed_data_dict and not isinstance(
- proposed_data_dict["Price"], (int, float, Decimal)
- ):
- raise BadRequestException(
- "Price in ProposedData_JSON must be a valid number for PRODUCT_CREATE."
- )
+ if "Price" in proposed_data_dict and not isinstance(proposed_data_dict["Price"], (int, float, Decimal)):
+ raise BadRequestException("Price in ProposedData_JSON must be a valid number for PRODUCT_CREATE.")
if (
"StockQuantity" in proposed_data_dict
and proposed_data_dict["StockQuantity"] is not None
@@ -159,12 +149,8 @@ async def submit_new_request(
elif request_in.RequestType == RequestTypeEnum.PRODUCT_DELETE:
if request_in.ProductID is None:
raise BadRequestException("ProductID is required for PRODUCT_DELETE requests.")
- if (
- proposed_data_dict is not None
- ): # ProposedData_JSON (as dict) should be None for delete
- raise BadRequestException(
- "ProposedData_JSON must not be provided for PRODUCT_DELETE requests."
- )
+ if proposed_data_dict is not None: # ProposedData_JSON (as dict) should be None for delete
+ raise BadRequestException("ProposedData_JSON must not be provided for PRODUCT_DELETE requests.")
# 验证 ProductID 属于 StoreID
product_to_delete = self._product_crud.get_product_by_id(
@@ -231,9 +217,7 @@ async def get_request_details(
:param actor_user:
:return:
"""
- logger.info(
- f"ActorID {actor_user.UserID} attempting to get details for ChangeRequestID {change_request_id}"
- )
+ logger.info(f"ActorID {actor_user.UserID} attempting to get details for ChangeRequestID {change_request_id}")
request_data = self._pcr_crud.get_request_by_id(conn=db, request_id=change_request_id)
if not request_data:
@@ -290,9 +274,7 @@ async def list_requests_for_merchant(
)
response_items = [ProductChangeRequestResponse(**data) for data in requests_data]
- return ProductChangeRequestListResponse(
- Requests=response_items, TotalCount=len(response_items)
- )
+ return ProductChangeRequestListResponse(Requests=response_items, TotalCount=len(response_items))
async def list_requests_for_admin(
self,
@@ -325,9 +307,7 @@ async def list_requests_for_admin(
merchant_user_id=query_params.MerchantUserID, # 管理员可以按商家筛选
)
response_items = [ProductChangeRequestResponse(**data) for data in requests_data]
- return ProductChangeRequestListResponse(
- Requests=response_items, TotalCount=len(response_items)
- )
+ return ProductChangeRequestListResponse(Requests=response_items, TotalCount=len(response_items))
async def merchant_cancel_request(
self,
@@ -336,9 +316,7 @@ async def merchant_cancel_request(
change_request_id: int,
merchant_user: CurrentUserSchema,
) -> ProductChangeRequestResponse:
- logger.info(
- f"Merchant UserID {merchant_user.UserID} attempting to cancel ChangeRequestID {change_request_id}"
- )
+ logger.info(f"Merchant UserID {merchant_user.UserID} attempting to cancel ChangeRequestID {change_request_id}")
# 1. 获取请求,验证所有权和状态
request_to_cancel = self._pcr_crud.get_request_by_id(conn=db, request_id=change_request_id)
@@ -372,17 +350,11 @@ async def merchant_cancel_request(
# This might happen if the status changed concurrently, or DB error
raise Exception("Failed to cancel the request.")
- updated_request_dict = self._pcr_crud.get_request_by_id(
- conn=db, request_id=change_request_id
- )
+ updated_request_dict = self._pcr_crud.get_request_by_id(conn=db, request_id=change_request_id)
if not updated_request_dict: # Should not happen if cancel was successful
- raise ProductNotFoundException(
- f"Request {change_request_id} not found after cancellation attempt."
- )
+ raise ProductNotFoundException(f"Request {change_request_id} not found after cancellation attempt.")
- logger.success(
- f"ChangeRequestID {change_request_id} cancelled by Merchant {merchant_user.UserID}."
- )
+ logger.success(f"ChangeRequestID {change_request_id} cancelled by Merchant {merchant_user.UserID}.")
return ProductChangeRequestResponse(**updated_request_dict)
async def admin_review_request(
@@ -403,17 +375,13 @@ async def admin_review_request(
# 1. 获取请求,验证当前状态是 PENDING_APPROVAL
request_to_review = self._pcr_crud.get_request_by_id(conn=db, request_id=change_request_id)
if not request_to_review:
- raise ProductNotFoundException(
- f"ProductChangeRequest with ID {change_request_id} not found for review."
- )
+ raise ProductNotFoundException(f"ProductChangeRequest with ID {change_request_id} not found for review.")
if request_to_review["Status"] != RequestStatusEnum.PENDING_APPROVAL.value:
logger.warning(
f"Invalid operation: Request {change_request_id} is not in PENDING_APPROVAL status (current: {request_to_review['Status']}). Cannot review."
)
- raise InvalidOperationException(
- f"Request is not in PENDING_APPROVAL status, cannot be reviewed."
- )
+ raise InvalidOperationException(f"Request is not in PENDING_APPROVAL status, cannot be reviewed.")
# 2. 调用 CRUD 更新状态和管理员信息
updated_request_dict = self._pcr_crud.update_request_by_admin(
@@ -426,9 +394,7 @@ async def admin_review_request(
)
if not updated_request_dict:
- logger.error(
- f"Failed to update request {change_request_id} status by admin in CRUD layer."
- )
+ logger.error(f"Failed to update request {change_request_id} status by admin in CRUD layer.")
raise Exception("Failed to review and update request status.")
logger.success(
@@ -449,9 +415,7 @@ async def admin_review_request(
applier_user=admin_user, # Admin is applying their own approval
)
except Exception as e:
- logger.error(
- f"Failed to apply approved request {change_request_id} by Admin {admin_user.UserID}: {e}"
- )
+ logger.error(f"Failed to apply approved request {change_request_id} by Admin {admin_user.UserID}: {e}")
# will return the *approved* but not applied request as response in the final return
else:
logger.success(
@@ -469,9 +433,7 @@ async def apply_approved_request(
change_request_id: int,
applier_user: CurrentUserSchema,
) -> ProductChangeRequestResponse:
- logger.info(
- f"UserID {applier_user.UserID} attempting to apply approved ChangeRequestID {change_request_id}"
- )
+ logger.info(f"UserID {applier_user.UserID} attempting to apply approved ChangeRequestID {change_request_id}")
# 1. 获取请求,验证状态是 APPROVED
pcr_data = self._pcr_crud.get_request_by_id(
@@ -479,17 +441,13 @@ async def apply_approved_request(
request_id=change_request_id,
)
if not pcr_data:
- raise ProductNotFoundException(
- f"ProductChangeRequest with ID {change_request_id} not found."
- )
+ raise ProductNotFoundException(f"ProductChangeRequest with ID {change_request_id} not found.")
if pcr_data["Status"] != RequestStatusEnum.APPROVED.value:
logger.warning(
f"Cannot apply request {change_request_id}: Status is '{pcr_data['Status']}', not '{RequestStatusEnum.APPROVED.value}'."
)
- raise InvalidOperationException(
- f"Request {change_request_id} is not in APPROVED status."
- )
+ raise InvalidOperationException(f"Request {change_request_id} is not in APPROVED status.")
# 权限检查: 商家只能应用自己的,管理员可以应用任何已批准的
# TODO: 替换为权限检查类的调用
@@ -502,9 +460,7 @@ async def apply_approved_request(
# 2. 解析 ProposedData_JSON
proposed_data_obj: Optional[ProposedProductData] = None
- if pcr_data[
- "ProposedData_JSON"
- ]: # ProposedData_JSON is already a dict due to CRUD's _deserialize
+ if pcr_data["ProposedData_JSON"]: # ProposedData_JSON is already a dict due to CRUD's _deserialize
try:
proposed_data_obj = ProposedProductData(**pcr_data["ProposedData_JSON"])
except Exception as e: # Pydantic ValidationError
@@ -522,9 +478,7 @@ async def apply_approved_request(
)
raise BadRequestException(f"Invalid proposed data for request {change_request_id}.")
- applied_product_id: Optional[int] = pcr_data.get(
- "ProductID"
- ) # Existing ProductID for UPDATE/DELETE
+ applied_product_id: Optional[int] = pcr_data.get("ProductID") # Existing ProductID for UPDATE/DELETE
# 3. 根据 RequestType 执行商品操作
request_type = RequestTypeEnum(pcr_data["RequestType"])
@@ -535,9 +489,7 @@ async def apply_approved_request(
f"ProductID should be None for PRODUCT_CREATE request {change_request_id}."
)
if not proposed_data_obj:
- raise BadRequestException(
- "ProposedData_JSON is required for PRODUCT_CREATE application."
- )
+ raise BadRequestException("ProposedData_JSON is required for PRODUCT_CREATE application.")
# Ensure all required fields for product creation are present in proposed_data_obj
# This validation was also in submit_new_request, but good to re-check or ensure consistency
if not all(
@@ -563,34 +515,20 @@ async def apply_approved_request(
actor_id=applier_user.UserID,
)
if not created_product_dict:
- raise Exception(
- f"Failed to create product for ChangeRequestID {change_request_id}."
- )
+ raise Exception(f"Failed to create product for ChangeRequestID {change_request_id}.")
applied_product_id = created_product_dict["ProductID"]
- logger.info(
- f"Product {applied_product_id} created for ChangeRequestID {change_request_id}."
- )
+ logger.info(f"Product {applied_product_id} created for ChangeRequestID {change_request_id}.")
elif request_type == RequestTypeEnum.PRODUCT_UPDATE:
- if (
- applied_product_id is None
- ): # Should have been set when creating PRODUCT_UPDATE request
- raise InvalidOperationException(
- f"ProductID is missing for PRODUCT_UPDATE request {change_request_id}."
- )
+ if applied_product_id is None: # Should have been set when creating PRODUCT_UPDATE request
+ raise InvalidOperationException(f"ProductID is missing for PRODUCT_UPDATE request {change_request_id}.")
if not proposed_data_obj:
- raise BadRequestException(
- "ProposedData_JSON is required for PRODUCT_UPDATE application."
- )
+ raise BadRequestException("ProposedData_JSON is required for PRODUCT_UPDATE application.")
# Prepare kwargs for product_crud.update_product_fields
- update_kwargs = proposed_data_obj.model_dump(
- exclude_unset=True
- ) # Only fields that were set
+ update_kwargs = proposed_data_obj.model_dump(exclude_unset=True) # Only fields that were set
if not update_kwargs:
- logger.info(
- f"No fields to update in ProposedData_JSON for PRODUCT_UPDATE request {change_request_id}."
- )
+ logger.info(f"No fields to update in ProposedData_JSON for PRODUCT_UPDATE request {change_request_id}.")
else:
updated_product_dict = self._product_crud.update_product(
conn=db,
@@ -602,15 +540,11 @@ async def apply_approved_request(
raise Exception(
f"Failed to update product {applied_product_id} for ChangeRequestID {change_request_id}."
)
- logger.info(
- f"Product {applied_product_id} updated for ChangeRequestID {change_request_id}."
- )
+ logger.info(f"Product {applied_product_id} updated for ChangeRequestID {change_request_id}.")
elif request_type == RequestTypeEnum.PRODUCT_DELETE:
if applied_product_id is None:
- raise InvalidOperationException(
- f"ProductID is missing for PRODUCT_DELETE request {change_request_id}."
- )
+ raise InvalidOperationException(f"ProductID is missing for PRODUCT_DELETE request {change_request_id}.")
deleted_successfully = self._product_crud.delete_product(
conn=db, product_id=applied_product_id, actor_id=applier_user.UserID
@@ -620,9 +554,7 @@ async def apply_approved_request(
raise Exception(
f"Failed to delete product {applied_product_id} for ChangeRequestID {change_request_id}."
)
- logger.info(
- f"Product {applied_product_id} deleted for ChangeRequestID {change_request_id}."
- )
+ logger.info(f"Product {applied_product_id} deleted for ChangeRequestID {change_request_id}.")
# For DELETE, applied_product_id remains the ID of the deleted product for record keeping.
# 4. 更新 ProductChangeRequest 状态为 APPLIED 和 ProductID (如果创建)
@@ -637,14 +569,8 @@ async def apply_approved_request(
logger.error(
f"CRITICAL: Failed to update ChangeRequestID {change_request_id} to APPLIED after product operation."
)
- raise Exception(
- f"Failed to finalize ChangeRequest {change_request_id} status to APPLIED."
- )
+ raise Exception(f"Failed to finalize ChangeRequest {change_request_id} status to APPLIED.")
- logger.success(
- f"ChangeRequestID {change_request_id} successfully applied by UserID {applier_user.UserID}."
- )
- logger.debug(
- f"Final ProductChangeRequest data: {final_pcr_dict}"
- )
+ logger.success(f"ChangeRequestID {change_request_id} successfully applied by UserID {applier_user.UserID}.")
+ logger.debug(f"Final ProductChangeRequest data: {final_pcr_dict}")
return ProductChangeRequestResponse(**final_pcr_dict)
diff --git a/src/backend/app/services/product_service.py b/src/backend/app/services/product_service.py
index 7f06edb..c4dd3ae 100644
--- a/src/backend/app/services/product_service.py
+++ b/src/backend/app/services/product_service.py
@@ -1,6 +1,7 @@
"""
商品服务模块,包含所有商品相关业务逻辑。
"""
+
from sqlalchemy import Connection
from typing import Optional, List, Dict, Any
import logging
@@ -15,10 +16,10 @@ class ProductService:
"""
商品服务类,负责处理所有与商品相关的业务逻辑。
"""
-
+
def __init__(self):
self._product_crud = get_product_crud_instance()
-
+
async def create_product(
self,
conn: Connection,
@@ -43,9 +44,9 @@ async def create_product(
product_description=product_description,
stock_quantity=stock_quantity,
main_image_url=main_image_url,
- actor_id=actor_id
+ actor_id=actor_id,
)
-
+
async def get_product_by_id(
self,
conn: Connection,
@@ -55,12 +56,8 @@ async def get_product_by_id(
"""
根据ID获取商品信息
"""
- return self._product_crud.get_product_by_id(
- conn=conn,
- product_id=product_id,
- actor_id=actor_id
- )
-
+ return self._product_crud.get_product_by_id(conn=conn, product_id=product_id, actor_id=actor_id)
+
async def get_product_with_category_info(
self,
conn: Connection,
@@ -70,12 +67,8 @@ async def get_product_with_category_info(
"""
获取带有分类信息的商品详情
"""
- return self._product_crud.get_product_with_category_info(
- conn=conn,
- product_id=product_id,
- actor_id=actor_id
- )
-
+ return self._product_crud.get_product_with_category_info(conn=conn, product_id=product_id, actor_id=actor_id)
+
async def list_products(
self,
conn: Connection,
@@ -89,42 +82,25 @@ async def list_products(
"""
获取商品列表,支持按店铺、分类筛选或搜索
"""
-
+
if store_id:
products = self._product_crud.get_products_by_store_id(
- conn=conn,
- store_id=store_id,
- limit=limit,
- offset=offset,
- actor_id=actor_id
+ conn=conn, store_id=store_id, limit=limit, offset=offset, actor_id=actor_id
)
elif category_id:
products = self._product_crud.get_products_by_category_id(
- conn=conn,
- category_id=category_id,
- limit=limit,
- offset=offset,
- actor_id=actor_id
+ conn=conn, category_id=category_id, limit=limit, offset=offset, actor_id=actor_id
)
elif search:
products = self._product_crud.search_products(
- conn=conn,
- search_term=search,
- limit=limit,
- offset=offset,
- actor_id=actor_id
+ conn=conn, search_term=search, limit=limit, offset=offset, actor_id=actor_id
)
else:
# 当没有指定条件时,返回所有商品
- products = self._product_crud.get_all_products(
- conn=conn,
- limit=limit,
- offset=offset,
- actor_id=actor_id
- )
-
+ products = self._product_crud.get_all_products(conn=conn, limit=limit, offset=offset, actor_id=actor_id)
+
return products
-
+
async def update_product(
self,
conn: Connection,
@@ -136,12 +112,9 @@ async def update_product(
更新商品信息
"""
return self._product_crud.update_product(
- conn=conn,
- product_id=product_id,
- update_data=update_data,
- actor_id=actor_id
+ conn=conn, product_id=product_id, update_data=update_data, actor_id=actor_id
)
-
+
async def update_product_stock(
self,
conn: Connection,
@@ -153,12 +126,9 @@ async def update_product_stock(
更新商品库存
"""
return self._product_crud.update_product_stock(
- conn=conn,
- product_id=product_id,
- stock_change=stock_change,
- actor_id=actor_id
+ conn=conn, product_id=product_id, stock_change=stock_change, actor_id=actor_id
)
-
+
async def delete_product(
self,
conn: Connection,
@@ -168,11 +138,7 @@ async def delete_product(
"""
删除商品(将状态设置为DISCONTINUED)
"""
- return self._product_crud.delete_product(
- conn=conn,
- product_id=product_id,
- actor_id=actor_id
- )
+ return self._product_crud.delete_product(conn=conn, product_id=product_id, actor_id=actor_id)
# 单例模式,提供获取 ProductService 实例的函数
diff --git a/src/backend/app/services/statistics_service.py b/src/backend/app/services/statistics_service.py
index 39243ce..15fbe2b 100644
--- a/src/backend/app/services/statistics_service.py
+++ b/src/backend/app/services/statistics_service.py
@@ -12,10 +12,10 @@ class StatisticsService:
"""
统计服务类,负责处理所有与系统统计相关的业务逻辑。
"""
-
+
def __init__(self):
pass
-
+
async def get_system_statistics(self, conn: Connection) -> SystemStatistics:
"""
Get overall system statistics from the system_statistics view.
@@ -33,7 +33,7 @@ async def get_system_statistics(self, conn: Connection) -> SystemStatistics:
pending_orders=0,
completed_orders=0,
total_orders=0,
- total_sales=0
+ total_sales=0,
)
# Use row._mapping instead of dict(row) for newer SQLAlchemy versions
stats_dict = row._mapping
@@ -56,7 +56,7 @@ async def get_admin_dashboard_statistics(self, conn: Connection) -> AdminDashboa
pending_orders=0,
completed_orders=0,
total_orders=0,
- total_sales=0
+ total_sales=0,
)
# Use row._mapping instead of dict(row) for newer SQLAlchemy versions
stats_dict = row._mapping
@@ -73,11 +73,11 @@ async def get_store_statistics(self, conn: Connection, store_id: Optional[int] =
else:
query = text("SELECT * FROM store_statistics")
result = conn.execute(query)
-
+
rows = result.fetchall()
if not rows:
logger.warning(f"No store statistics found{' for store_id: ' + str(store_id) if store_id else ''}")
return []
-
+
# Use row._mapping instead of dict(row) for newer SQLAlchemy versions
return [StoreStatistics(**row._mapping) for row in rows]
diff --git a/src/backend/app/services/store_change_request_service.py b/src/backend/app/services/store_change_request_service.py
index ce77e2f..5ea549a 100644
--- a/src/backend/app/services/store_change_request_service.py
+++ b/src/backend/app/services/store_change_request_service.py
@@ -1,6 +1,7 @@
"""
店铺变更请求服务模块,包含所有店铺变更请求相关业务逻辑。
"""
+
from sqlalchemy import Connection
from typing import Optional, List, Dict, Any
import logging
@@ -9,9 +10,9 @@
from backend.app.crud.store_change_request_crud import get_store_change_request_crud_instance
from backend.app.utils.exceptions import (
- PermissionDeniedException,
+ PermissionDeniedException,
UserNotFoundException,
- InvalidStatusTransitionException
+ InvalidStatusTransitionException,
)
from backend.app.crud.user_crud import user_crud_instance
@@ -22,10 +23,10 @@ class StoreChangeRequestService:
"""
店铺变更请求服务类,负责处理所有与店铺变更请求相关的业务逻辑。
"""
-
+
def __init__(self):
self._change_request_crud = get_store_change_request_crud_instance()
-
+
async def create_change_request(
self,
conn: Connection,
@@ -38,7 +39,7 @@ async def create_change_request(
) -> Dict[str, Any]:
"""
创建新店铺变更请求
-
+
:param conn: 数据库连接
:param requesting_user_id: 请求用户ID
:param request_type: 请求类型
@@ -56,9 +57,9 @@ async def create_change_request(
if not user and not isinstance(user, dict):
logger.warning(f"用户不存在,无法创建店铺变更请求: {requesting_user_id}")
raise UserNotFoundException(f"请求用户ID {requesting_user_id} 不存在")
-
+
# 对于店铺更新请求,验证用户是否有权限操作该店铺
- if store_id and request_type != 'CREATE_STORE':
+ if store_id and request_type != "CREATE_STORE":
# 这里可以添加检查用户是否是商家、是否拥有该店铺的逻辑
# 例如可以查询Store表检查用户是否拥有此店铺
# 根据实际业务逻辑添加检查
@@ -66,7 +67,7 @@ async def create_change_request(
except (AttributeError, TypeError):
# 在测试环境中,模拟对象可能会导致异常
pass
-
+
# JSON序列化检查
if proposed_data:
try:
@@ -74,7 +75,7 @@ async def create_change_request(
except (TypeError, ValueError) as e:
logger.error(f"提议数据无法序列化为JSON: {e}")
proposed_data = {"error": "原始数据无法序列化", "data_str": str(proposed_data)}
-
+
return self._change_request_crud.create_change_request(
conn=conn,
requesting_user_id=requesting_user_id,
@@ -82,9 +83,9 @@ async def create_change_request(
store_id=store_id,
proposed_data=proposed_data,
submitter_notes=submitter_notes,
- actor_id=actor_id
+ actor_id=actor_id,
)
-
+
async def get_change_request_by_id(
self,
conn: Connection,
@@ -93,23 +94,21 @@ async def get_change_request_by_id(
) -> Optional[Dict[str, Any]]:
"""
根据ID获取店铺变更请求信息
-
+
:param conn: 数据库连接
:param request_id: 请求ID
:param actor_id: 操作用户ID
:return: 请求信息字典或None
"""
request = self._change_request_crud.get_change_request_by_id(
- conn=conn,
- request_id=request_id,
- actor_id=actor_id
+ conn=conn, request_id=request_id, actor_id=actor_id
)
-
+
if not request:
logger.info(f"未找到店铺变更请求ID: {request_id}")
-
+
return request
-
+
async def get_change_requests_by_store_id(
self,
conn: Connection,
@@ -121,7 +120,7 @@ async def get_change_requests_by_store_id(
) -> List[Dict[str, Any]]:
"""
获取指定店铺的变更请求列表
-
+
:param conn: 数据库连接
:param store_id: 店铺ID
:param status: 请求状态
@@ -131,14 +130,9 @@ async def get_change_requests_by_store_id(
:return: 请求信息列表
"""
return self._change_request_crud.get_change_requests_by_store_id(
- conn=conn,
- store_id=store_id,
- status=status,
- limit=limit,
- offset=offset,
- actor_id=actor_id
+ conn=conn, store_id=store_id, status=status, limit=limit, offset=offset, actor_id=actor_id
)
-
+
async def get_change_requests_by_user_id(
self,
conn: Connection,
@@ -150,7 +144,7 @@ async def get_change_requests_by_user_id(
) -> List[Dict[str, Any]]:
"""
获取指定用户的变更请求列表
-
+
:param conn: 数据库连接
:param user_id: 用户ID
:param status: 请求状态
@@ -169,16 +163,11 @@ async def get_change_requests_by_user_id(
except (AttributeError, TypeError):
# 在测试环境中,模拟对象可能会导致异常
pass
-
+
return self._change_request_crud.get_change_requests_by_user_id(
- conn=conn,
- user_id=user_id,
- status=status,
- limit=limit,
- offset=offset,
- actor_id=actor_id
+ conn=conn, user_id=user_id, status=status, limit=limit, offset=offset, actor_id=actor_id
)
-
+
async def get_all_pending_requests(
self,
conn: Connection,
@@ -188,7 +177,7 @@ async def get_all_pending_requests(
) -> List[Dict[str, Any]]:
"""
获取所有待审核的变更请求列表
-
+
:param conn: 数据库连接
:param limit: 结果数量限制
:param offset: 分页偏移量
@@ -196,12 +185,9 @@ async def get_all_pending_requests(
:return: 请求信息列表
"""
return self._change_request_crud.get_all_pending_requests(
- conn=conn,
- limit=limit,
- offset=offset,
- actor_id=actor_id
+ conn=conn, limit=limit, offset=offset, actor_id=actor_id
)
-
+
async def get_filtered_requests(
self,
conn: Connection,
@@ -218,7 +204,7 @@ async def get_filtered_requests(
) -> List[Dict[str, Any]]:
"""
获取根据多种条件筛选的请求列表
-
+
:param conn: 数据库连接
:param store_id: 店铺ID筛选
:param user_id: 用户ID筛选
@@ -236,7 +222,7 @@ async def get_filtered_requests(
if start_date and end_date and start_date > end_date:
logger.warning("请求筛选的开始日期晚于结束日期")
return []
-
+
return self._change_request_crud.get_filtered_requests(
conn=conn,
store_id=store_id,
@@ -248,9 +234,9 @@ async def get_filtered_requests(
end_date=end_date,
limit=limit,
offset=offset,
- actor_id=actor_id
+ actor_id=actor_id,
)
-
+
async def update_request(
self,
conn: Connection,
@@ -260,7 +246,7 @@ async def update_request(
) -> Optional[Dict[str, Any]]:
"""
更新店铺变更请求信息
-
+
:param conn: 数据库连接
:param request_id: 请求ID
:param update_data: 更新数据
@@ -270,15 +256,13 @@ async def update_request(
"""
# 查询现有请求
request = self._change_request_crud.get_change_request_by_id(
- conn=conn,
- request_id=request_id,
- actor_id=actor_id
+ conn=conn, request_id=request_id, actor_id=actor_id
)
-
+
if not request:
logger.warning(f"未找到要更新的店铺变更请求: {request_id}")
return None
-
+
# 只有PENDING_APPROVAL状态的请求可以更新
try:
status = request.get("Status") if isinstance(request, dict) else None
@@ -293,33 +277,35 @@ async def update_request(
except (AttributeError, TypeError):
# 在测试环境中,模拟对象可能会导致异常
pass
-
+
# 如果不是管理员,检查是否是请求创建者
try:
if actor_id and isinstance(request, dict) and actor_id != request.get("RequestingUserID"):
# 这里可以添加管理员权限检查
# 如果actor_id不是管理员,则拒绝操作
- logger.warning(f"用户无权更新此请求: actor_id={actor_id}, request_owner={request.get('RequestingUserID')}")
+ logger.warning(
+ f"用户无权更新此请求: actor_id={actor_id}, request_owner={request.get('RequestingUserID')}"
+ )
raise PermissionDeniedException("只有请求创建者或管理员可以更新请求")
except (AttributeError, TypeError):
# 在测试环境中,模拟对象可能会导致异常
pass
-
+
# JSON序列化检查
if "proposed_data" in update_data:
try:
json.dumps(update_data["proposed_data"])
except (TypeError, ValueError) as e:
logger.error(f"更新数据无法序列化为JSON: {e}")
- update_data["proposed_data"] = {"error": "原始数据无法序列化", "data_str": str(update_data["proposed_data"])}
-
+ update_data["proposed_data"] = {
+ "error": "原始数据无法序列化",
+ "data_str": str(update_data["proposed_data"]),
+ }
+
return self._change_request_crud.update_request(
- conn=conn,
- request_id=request_id,
- update_data=update_data,
- actor_id=actor_id
+ conn=conn, request_id=request_id, update_data=update_data, actor_id=actor_id
)
-
+
async def update_request_status(
self,
conn: Connection,
@@ -331,7 +317,7 @@ async def update_request_status(
) -> Optional[Dict[str, Any]]:
"""
更新请求状态
-
+
:param conn: 数据库连接
:param request_id: 请求ID
:param status: 新状态
@@ -344,20 +330,18 @@ async def update_request_status(
"""
# 查询现有请求
request = self._change_request_crud.get_change_request_by_id(
- conn=conn,
- request_id=request_id,
- actor_id=actor_id
+ conn=conn, request_id=request_id, actor_id=actor_id
)
-
+
if not request:
logger.warning(f"未找到要更新状态的店铺变更请求: {request_id}")
return None
-
+
# 验证状态转换是否有效
try:
current_status = request.get("Status") if isinstance(request, dict) else ""
valid_transitions = self._get_valid_status_transitions(current_status)
-
+
if status not in valid_transitions and isinstance(current_status, str):
logger.warning(f"无效的状态转换: {current_status} -> {status}")
raise InvalidStatusTransitionException(
@@ -366,21 +350,21 @@ async def update_request_status(
except (AttributeError, TypeError):
# 在测试环境中,模拟对象可能会导致异常
pass
-
+
# 检查是否有管理员权限
if not admin_id:
logger.warning("状态更新需要管理员ID")
raise PermissionDeniedException("状态更新需要管理员权限")
-
+
return self._change_request_crud.update_request_status(
conn=conn,
request_id=request_id,
status=status,
admin_id=admin_id,
admin_notes=admin_notes,
- actor_id=actor_id
+ actor_id=actor_id,
)
-
+
async def cancel_request(
self,
conn: Connection,
@@ -389,7 +373,7 @@ async def cancel_request(
) -> bool:
"""
取消店铺变更请求(用户自行取消)
-
+
:param conn: 数据库连接
:param request_id: 请求ID
:param actor_id: 操作用户ID
@@ -398,15 +382,13 @@ async def cancel_request(
"""
# 查询现有请求
request = self._change_request_crud.get_change_request_by_id(
- conn=conn,
- request_id=request_id,
- actor_id=actor_id
+ conn=conn, request_id=request_id, actor_id=actor_id
)
-
+
if not request:
logger.warning(f"未找到要取消的店铺变更请求: {request_id}")
return False
-
+
# 只有PENDING_APPROVAL状态的请求可以取消
try:
status = request.get("Status") if isinstance(request, dict) else None
@@ -421,41 +403,39 @@ async def cancel_request(
except (AttributeError, TypeError):
# 在测试环境中,模拟对象可能会导致异常
pass
-
+
# 检查是否是请求创建者
try:
if actor_id and isinstance(request, dict) and actor_id != request.get("RequestingUserID"):
# 这里可以添加管理员权限检查
# 如果actor_id不是管理员,则拒绝操作
- logger.warning(f"用户无权取消此请求: actor_id={actor_id}, request_owner={request.get('RequestingUserID')}")
+ logger.warning(
+ f"用户无权取消此请求: actor_id={actor_id}, request_owner={request.get('RequestingUserID')}"
+ )
raise PermissionDeniedException("只有请求创建者或管理员可以取消请求")
except (AttributeError, TypeError):
# 在测试环境中,模拟对象可能会导致异常
pass
-
- return self._change_request_crud.cancel_request(
- conn=conn,
- request_id=request_id,
- actor_id=actor_id
- )
-
+
+ return self._change_request_crud.cancel_request(conn=conn, request_id=request_id, actor_id=actor_id)
+
@staticmethod
def _get_valid_status_transitions(current_status: str) -> List[str]:
"""
获取有效的状态转换列表
-
+
:param current_status: 当前状态
:return: 有效的目标状态列表
"""
# 如果是测试环境中的mock对象,允许所有转换
if not isinstance(current_status, str):
return ["APPROVED", "REJECTED", "CANCELLED"]
-
+
transitions = {
"PENDING_APPROVAL": ["APPROVED", "REJECTED", "CANCELLED"],
"APPROVED": [], # 终态,不能再转换
"REJECTED": [], # 终态,不能再转换
- "CANCELLED": [] # 终态,不能再转换
+ "CANCELLED": [], # 终态,不能再转换
}
-
+
return transitions.get(current_status, [])
diff --git a/src/backend/app/services/store_change_request_service_v2.py b/src/backend/app/services/store_change_request_service_v2.py
index 7ba5213..ecf81fd 100644
--- a/src/backend/app/services/store_change_request_service_v2.py
+++ b/src/backend/app/services/store_change_request_service_v2.py
@@ -67,20 +67,14 @@ async def submit_new_request(
# 1. 基础校验
if request_in.RequestType == RequestTypeEnum.STORE_CREATE:
if request_in.StoreID is not None: # StoreID should be None for create requests
- raise BadRequestException(
- "StoreID must be null or not provided for STORE_CREATE requests."
- )
+ raise BadRequestException("StoreID must be null or not provided for STORE_CREATE requests.")
if not request_in.ProposedData_JSON:
- raise BadRequestException(
- "ProposedData_JSON is required for STORE_CREATE requests."
- )
+ raise BadRequestException("ProposedData_JSON is required for STORE_CREATE requests.")
# 进一步校验 ProposedData_JSON 内容 (例如,必须包含 StoreName)
try:
proposed_data = ProposedStoreData(**request_in.ProposedData_JSON.model_dump(exclude_unset=True)) # type: ignore
if not proposed_data.StoreName:
- raise BadRequestException(
- "StoreName is required in ProposedData_JSON for STORE_CREATE."
- )
+ raise BadRequestException("StoreName is required in ProposedData_JSON for STORE_CREATE.")
except Exception as e: # Pydantic ValidationError or other
raise BadRequestException(f"Invalid ProposedData_JSON for STORE_CREATE: {e}")
@@ -88,9 +82,7 @@ async def submit_new_request(
if request_in.StoreID is None:
raise BadRequestException("StoreID is required for STORE_UPDATE requests.")
if not request_in.ProposedData_JSON:
- raise BadRequestException(
- "ProposedData_JSON is required for STORE_UPDATE requests."
- )
+ raise BadRequestException("ProposedData_JSON is required for STORE_UPDATE requests.")
if not any(request_in.ProposedData_JSON.model_dump(exclude_unset=True).values()):
raise BadRequestException(
"ProposedData_JSON must contain at least one field to update for STORE_UPDATE."
@@ -98,9 +90,7 @@ async def submit_new_request(
# 验证 StoreID 存在且属于 requesting_user (除非是管理员代为提交,但这里是商家提交)
store_to_update = self._store_crud.get_store_by_id(conn=db, store_id=request_in.StoreID)
if not store_to_update:
- raise StoreNotFoundException(
- f"Target StoreID {request_in.StoreID} not found for update request."
- )
+ raise StoreNotFoundException(f"Target StoreID {request_in.StoreID} not found for update request.")
if store_to_update["OwnerUserID"] != requesting_user.UserID:
logger.warning(
f"User {requesting_user.UserID} attempting to submit UPDATE request for store {request_in.StoreID} not owned by them."
@@ -111,15 +101,11 @@ async def submit_new_request(
if request_in.StoreID is None:
raise BadRequestException("StoreID is required for STORE_DELETE requests.")
if request_in.ProposedData_JSON is not None:
- raise BadRequestException(
- "ProposedData_JSON must be null or not provided for STORE_DELETE requests."
- )
+ raise BadRequestException("ProposedData_JSON must be null or not provided for STORE_DELETE requests.")
# 验证 StoreID 存在且属于 requesting_user
store_to_delete = self._store_crud.get_store_by_id(conn=db, store_id=request_in.StoreID)
if not store_to_delete:
- raise StoreNotFoundException(
- f"Target StoreID {request_in.StoreID} not found for delete request."
- )
+ raise StoreNotFoundException(f"Target StoreID {request_in.StoreID} not found for delete request.")
if store_to_delete["OwnerUserID"] != requesting_user.UserID:
logger.warning(
f"User {requesting_user.UserID} attempting to submit DELETE request for store {request_in.StoreID} not owned by them."
@@ -129,9 +115,7 @@ async def submit_new_request(
# 2. 调用 CRUD 创建请求记录
# CRUD 层期望 ProposedData_JSON 是字典
proposed_data_as_dict = (
- request_in.ProposedData_JSON.model_dump(exclude_unset=True)
- if request_in.ProposedData_JSON
- else None
+ request_in.ProposedData_JSON.model_dump(exclude_unset=True) if request_in.ProposedData_JSON else None
)
created_request_dict: Optional[Dict[str, Any]] = None
@@ -162,9 +146,7 @@ async def submit_new_request(
)
if not created_request_dict:
- logger.error(
- f"Failed to create store change request in CRUD for User {requesting_user.UserID}"
- )
+ logger.error(f"Failed to create store change request in CRUD for User {requesting_user.UserID}")
raise Exception("Failed to submit store change request.")
logger.success(
@@ -183,18 +165,14 @@ async def get_request_details(
f"ActorID {actor_user.UserID} attempting to get details for StoreChangeRequestID {change_request_id}"
)
- request_data = self._scr_crud.get_request_by_id(
- conn=db, change_request_id=change_request_id
- )
+ request_data = self._scr_crud.get_request_by_id(conn=db, change_request_id=change_request_id)
if not request_data:
raise StoreNotFoundException(
f"StoreChangeRequest with ID {change_request_id} not found."
) # Using StoreNotFound as a generic "resource not found"
# 权限检查: 商家只能看自己的,管理员可以看所有
- if (
- request_data["RequestingUserID"] != actor_user.UserID
- ):
+ if request_data["RequestingUserID"] != actor_user.UserID:
logger.warning(
f"Permission Denied: ActorID {actor_user.UserID} cannot view SCR ID {change_request_id} by RequesterID {request_data['RequestingUserID']}."
)
@@ -216,9 +194,7 @@ async def list_requests_for_requesting_user( # Renamed from list_requests_for_m
status_list_values: Optional[List[str]] = (
[s.value for s in query_params.Status] if query_params.Status else None
)
- request_type_value: Optional[str] = (
- query_params.RequestType.value if query_params.RequestType else None
- )
+ request_type_value: Optional[str] = query_params.RequestType.value if query_params.RequestType else None
requests_data = self._scr_crud.get_request_list(
conn=db,
@@ -230,9 +206,7 @@ async def list_requests_for_requesting_user( # Renamed from list_requests_for_m
)
response_items = [StoreChangeRequestResponse(**data) for data in requests_data]
- return StoreChangeRequestListResponse(
- Requests=response_items, TotalCount=len(response_items)
- )
+ return StoreChangeRequestListResponse(Requests=response_items, TotalCount=len(response_items))
async def list_requests_for_admin(
self,
@@ -251,9 +225,7 @@ async def list_requests_for_admin(
status_list_values: Optional[List[str]] = (
[s.value for s in query_params.Status] if query_params.Status else None
)
- request_type_value: Optional[str] = (
- query_params.RequestType.value if query_params.RequestType else None
- )
+ request_type_value: Optional[str] = query_params.RequestType.value if query_params.RequestType else None
requests_data = self._scr_crud.get_request_list(
conn=db,
@@ -264,9 +236,7 @@ async def list_requests_for_admin(
# ProductID filter is not in StoreChangeRequest DDL directly, but StoreID is.
)
response_items = [StoreChangeRequestResponse(**data) for data in requests_data]
- return StoreChangeRequestListResponse(
- Requests=response_items, TotalCount=len(response_items)
- )
+ return StoreChangeRequestListResponse(Requests=response_items, TotalCount=len(response_items))
async def user_cancel_request( # Renamed from merchant_cancel_request
self,
@@ -279,13 +249,9 @@ async def user_cancel_request( # Renamed from merchant_cancel_request
f"RequestingUserID {requesting_user.UserID} attempting to cancel StoreChangeRequestID {change_request_id}"
)
- request_to_cancel = self._scr_crud.get_request_by_id(
- conn=db, change_request_id=change_request_id
- )
+ request_to_cancel = self._scr_crud.get_request_by_id(conn=db, change_request_id=change_request_id)
if not request_to_cancel:
- raise StoreNotFoundException(
- f"StoreChangeRequest with ID {change_request_id} not found."
- )
+ raise StoreNotFoundException(f"StoreChangeRequest with ID {change_request_id} not found.")
if request_to_cancel["RequestingUserID"] != requesting_user.UserID:
logger.warning(
@@ -304,17 +270,11 @@ async def user_cancel_request( # Renamed from merchant_cancel_request
if not success:
raise Exception("Failed to cancel the request in CRUD layer.")
- updated_request_dict = self._scr_crud.get_request_by_id(
- conn=db, change_request_id=change_request_id
- )
+ updated_request_dict = self._scr_crud.get_request_by_id(conn=db, change_request_id=change_request_id)
if not updated_request_dict:
- raise StoreNotFoundException(
- f"Request {change_request_id} not found after cancellation attempt."
- )
+ raise StoreNotFoundException(f"Request {change_request_id} not found after cancellation attempt.")
- logger.success(
- f"StoreChangeRequestID {change_request_id} cancelled by User {requesting_user.UserID}."
- )
+ logger.success(f"StoreChangeRequestID {change_request_id} cancelled by User {requesting_user.UserID}.")
return StoreChangeRequestResponse(**updated_request_dict)
async def admin_review_request(
@@ -332,13 +292,9 @@ async def admin_review_request(
# raise PermissionDeniedException("Only administrators can review requests.")
# TODO: 使用专门的权限管理类
- request_to_review = self._scr_crud.get_request_by_id(
- conn=db, change_request_id=change_request_id
- )
+ request_to_review = self._scr_crud.get_request_by_id(conn=db, change_request_id=change_request_id)
if not request_to_review:
- raise StoreNotFoundException(
- f"StoreChangeRequest with ID {change_request_id} not found for review."
- )
+ raise StoreNotFoundException(f"StoreChangeRequest with ID {change_request_id} not found for review.")
if request_to_review["Status"] != StatusEnum.PENDING_APPROVAL.value:
raise InvalidOperationException(
@@ -361,18 +317,14 @@ async def admin_review_request(
)
if review_data.Status == StatusEnum.APPROVED:
- logger.info(
- f"Request {change_request_id} approved by admin. Attempting to apply changes."
- )
+ logger.info(f"Request {change_request_id} approved by admin. Attempting to apply changes.")
try:
# apply_approved_request is now public
return await self.apply_approved_request(
db=db, change_request_id=change_request_id, applier_user=admin_user
)
except Exception as e:
- logger.error(
- f"Failed to apply approved request {change_request_id} by Admin {admin_user.UserID}: {e}"
- )
+ logger.error(f"Failed to apply approved request {change_request_id} by Admin {admin_user.UserID}: {e}")
# Return the request in its 'APPROVED' state if application fails
return StoreChangeRequestResponse(**updated_request_dict)
@@ -390,18 +342,12 @@ async def apply_approved_request(
pcr_data = self._scr_crud.get_request_by_id(conn=db, change_request_id=change_request_id)
if not pcr_data:
- raise StoreNotFoundException(
- f"StoreChangeRequest with ID {change_request_id} not found for application."
- )
+ raise StoreNotFoundException(f"StoreChangeRequest with ID {change_request_id} not found for application.")
if pcr_data["Status"] != StatusEnum.APPROVED.value:
- raise InvalidOperationException(
- f"Request {change_request_id} is not in APPROVED status. Cannot apply."
- )
+ raise InvalidOperationException(f"Request {change_request_id} is not in APPROVED status. Cannot apply.")
# Permission: Admin can apply any. Merchant might apply their own if auto-apply or specific flow.
- if (
- pcr_data["RequestingUserID"] != applier_user.UserID
- ):
+ if pcr_data["RequestingUserID"] != applier_user.UserID:
logger.warning(
f"Permission Denied: User {applier_user.UserID} cannot apply request {change_request_id} for Requester {pcr_data['RequestingUserID']}."
)
@@ -429,9 +375,7 @@ async def apply_approved_request(
request_type = RequestTypeEnum(pcr_data["RequestType"]) # Convert string from DB to Enum
if request_type == RequestTypeEnum.STORE_CREATE:
- if (
- applied_store_id is not None
- ): # StoreID should be NULL for create request before application
+ if applied_store_id is not None: # StoreID should be NULL for create request before application
raise InvalidOperationException(
f"StoreID should be null for STORE_CREATE request {change_request_id} before application."
)
@@ -463,22 +407,16 @@ async def apply_approved_request(
description=proposed_data_obj.Description,
logo_url=proposed_data_obj.LogoURL,
store_status=new_store_status,
- creation_date=datetime.datetime.now(
- datetime.timezone.utc
- ), # Or from pcr_data["CreationTime"]
+ creation_date=datetime.datetime.now(datetime.timezone.utc), # Or from pcr_data["CreationTime"]
actor_id=applier_user.UserID,
)
if not created_store_dict:
raise Exception(f"Failed to create store for ChangeRequestID {change_request_id}.")
applied_store_id = created_store_dict["StoreID"]
- logger.info(
- f"Store {applied_store_id} created for ChangeRequestID {change_request_id}."
- )
+ logger.info(f"Store {applied_store_id} created for ChangeRequestID {change_request_id}.")
# change the user's role to `merchant` if they used to be only `customer`
- user_to_update = self._user_crud.get_user_by_id(
- conn=db, user_id=pcr_data["RequestingUserID"]
- )
+ user_to_update = self._user_crud.get_user_by_id(conn=db, user_id=pcr_data["RequestingUserID"])
if not user_to_update:
raise UserNotFoundException(
f"User {pcr_data['RequestingUserID']} not found for updating role to merchant."
@@ -491,22 +429,14 @@ async def apply_approved_request(
actor_id=applier_user.UserID,
)
if not update_success:
- raise Exception(
- f"Failed to update user {pcr_data['RequestingUserID']} role to merchant."
- )
- logger.info(
- f"User {pcr_data['RequestingUserID']} role updated to merchant after store creation."
- )
+ raise Exception(f"Failed to update user {pcr_data['RequestingUserID']} role to merchant.")
+ logger.info(f"User {pcr_data['RequestingUserID']} role updated to merchant after store creation.")
elif request_type == RequestTypeEnum.STORE_UPDATE:
if applied_store_id is None:
- raise InvalidOperationException(
- f"StoreID is missing for STORE_UPDATE request {change_request_id}."
- )
+ raise InvalidOperationException(f"StoreID is missing for STORE_UPDATE request {change_request_id}.")
if not proposed_data_obj:
- raise BadRequestException(
- "ProposedData_JSON is required for STORE_UPDATE application."
- )
+ raise BadRequestException("ProposedData_JSON is required for STORE_UPDATE application.")
update_kwargs = proposed_data_obj.model_dump()
# Convert StoreStatus to Enum if present
@@ -519,9 +449,7 @@ async def apply_approved_request(
)
if not update_kwargs:
- logger.info(
- f"No fields to update in ProposedData_JSON for STORE_UPDATE request {change_request_id}."
- )
+ logger.info(f"No fields to update in ProposedData_JSON for STORE_UPDATE request {change_request_id}.")
else:
# Pass individual fields to store_crud.update_store
updated_store_dict = self._store_crud.update_store(
@@ -537,15 +465,11 @@ async def apply_approved_request(
raise Exception(
f"Failed to update store {applied_store_id} for ChangeRequestID {change_request_id}."
)
- logger.info(
- f"Store {applied_store_id} updated for ChangeRequestID {change_request_id}."
- )
+ logger.info(f"Store {applied_store_id} updated for ChangeRequestID {change_request_id}.")
elif request_type == RequestTypeEnum.STORE_DELETE:
if applied_store_id is None:
- raise InvalidOperationException(
- f"StoreID is missing for STORE_DELETE request {change_request_id}."
- )
+ raise InvalidOperationException(f"StoreID is missing for STORE_DELETE request {change_request_id}.")
# Business logic for "deleting" a store is usually setting its status
# to something like 'CLOSED_PERMANENTLY_BY_ADMIN' or 'INACTIVE_BY_MERCHANT'
@@ -571,9 +495,7 @@ async def apply_approved_request(
final_pcr_dict = self._scr_crud.update_request_store_id_and_status_applied(
conn=db,
change_request_id=change_request_id,
- applied_store_id=(
- applied_store_id if request_type == RequestTypeEnum.STORE_CREATE else None
- ),
+ applied_store_id=(applied_store_id if request_type == RequestTypeEnum.STORE_CREATE else None),
# This will be new StoreID for CREATE, or existing for UPDATE/DELETE
actor_id=applier_user.UserID,
)
@@ -581,9 +503,7 @@ async def apply_approved_request(
logger.error(
f"CRITICAL: Failed to update StoreChangeRequestID {change_request_id} to APPLIED after store operation."
)
- raise Exception(
- f"Failed to finalize ChangeRequest {change_request_id} status to APPLIED."
- )
+ raise Exception(f"Failed to finalize ChangeRequest {change_request_id} status to APPLIED.")
logger.success(
f"StoreChangeRequestID {change_request_id} successfully applied by UserID {applier_user.UserID}."
diff --git a/src/backend/app/services/store_service.py b/src/backend/app/services/store_service.py
index 67a277a..d3f2704 100644
--- a/src/backend/app/services/store_service.py
+++ b/src/backend/app/services/store_service.py
@@ -16,8 +16,11 @@
StoreUpdate,
StoreResponse,
StoreListResponse,
- StoreStatusEnum, StoreListSimpleResponse, StoreSimpleResponse
+ StoreStatusEnum,
+ StoreListSimpleResponse,
+ StoreSimpleResponse,
)
+
# 导入自定义异常
from backend.app.utils.exceptions import StoreNotFoundException, UserNotFoundException, PermissionDeniedException
@@ -34,13 +37,7 @@ def __init__(self, store_crud: StoreCRUD, user_crud: UserCRUD):
self._user_crud = user_crud
logger.info(f"{self.__class__.__name__} initialized.")
- async def create_new_store(
- self,
- db: Connection,
- *,
- store_in: StoreCreate,
- actor_id: int
- ) -> StoreResponse:
+ async def create_new_store(self, db: Connection, *, store_in: StoreCreate, actor_id: int) -> StoreResponse:
"""
创建一个新的店铺 (通常由管理员操作,或通过特定流程的商家)。
@@ -52,7 +49,8 @@ async def create_new_store(
:raises Exception: 如果店铺创建失败。
"""
logger.info(
- f"ActorID {actor_id} attempting to create store '{store_in.StoreName}' for OwnerUserID {store_in.OwnerUserID}.")
+ f"ActorID {actor_id} attempting to create store '{store_in.StoreName}' for OwnerUserID {store_in.OwnerUserID}."
+ )
# TODO: 权限检查
if True:
logger.warning(
@@ -60,11 +58,13 @@ async def create_new_store(
)
# 1. 验证 OwnerUserID 是否存在
- owner_user = self._user_crud.get_user_by_id(conn=db, user_id=store_in.OwnerUserID,
- actor_id=actor_id) # actor_id for audit
+ owner_user = self._user_crud.get_user_by_id(
+ conn=db, user_id=store_in.OwnerUserID, actor_id=actor_id
+ ) # actor_id for audit
if not owner_user:
logger.warning(
- f"OwnerUserID {store_in.OwnerUserID} not found when trying to create store '{store_in.StoreName}'.")
+ f"OwnerUserID {store_in.OwnerUserID} not found when trying to create store '{store_in.StoreName}'."
+ )
raise UserNotFoundException(f"Prospective owner UserID {store_in.OwnerUserID} not found.")
# 2. 调用 CRUD 创建店铺
@@ -76,24 +76,22 @@ async def create_new_store(
logo_url=store_in.LogoURL,
store_status=store_in.StoreStatus,
creation_date=store_in.CreationDate,
- actor_id=actor_id
+ actor_id=actor_id,
)
if not created_store_dict:
logger.error(
- f"Failed to create store '{store_in.StoreName}' in CRUD layer for OwnerUserID {store_in.OwnerUserID}.")
+ f"Failed to create store '{store_in.StoreName}' in CRUD layer for OwnerUserID {store_in.OwnerUserID}."
+ )
raise Exception(f"Failed to create store '{store_in.StoreName}'.")
logger.success(
- f"Store '{store_in.StoreName}' (ID: {created_store_dict['StoreID']}) created successfully by ActorID {actor_id}.")
+ f"Store '{store_in.StoreName}' (ID: {created_store_dict['StoreID']}) created successfully by ActorID {actor_id}."
+ )
return StoreResponse(**created_store_dict)
async def user_get_store_by_id(
- self,
- db: Connection,
- *,
- store_id: int,
- actor_id: Optional[int] = None
+ self, db: Connection, *, store_id: int, actor_id: Optional[int] = None
) -> StoreResponse:
"""
根据 StoreID 获取店铺信息。
@@ -106,25 +104,23 @@ async def user_get_store_by_id(
:raises StoreNotFoundException: 如果店铺未找到或不符合查看条件。
"""
logger.info(f"ActorID {actor_id} attempting to get store by StoreID {store_id}.")
- store_data = self._store_crud.get_store_by_id(conn=db, store_id=store_id,
- actor_id=actor_id) # actor_id for CRUD audit
+ store_data = self._store_crud.get_store_by_id(
+ conn=db, store_id=store_id, actor_id=actor_id
+ ) # actor_id for CRUD audit
if not store_data:
raise StoreNotFoundException(f"Store with ID {store_id} not found.")
if store_data["StoreStatus"] != StoreStatusEnum.ACTIVE.value:
logger.info(
- f"Store {store_id} (Status: {store_data['StoreStatus']}) is not active. Access denied for non-owner/non-admin ActorID {actor_id}.")
+ f"Store {store_id} (Status: {store_data['StoreStatus']}) is not active. Access denied for non-owner/non-admin ActorID {actor_id}."
+ )
raise StoreNotFoundException(f"Store with ID {store_id} not found.")
# 对普通用户隐藏非激活店铺,表现为“未找到”
return StoreResponse(**store_data)
async def merchant_get_store_by_id(
- self,
- db: Connection,
- *,
- store_id: int,
- actor_id: Optional[int] = None
+ self, db: Connection, *, store_id: int, actor_id: Optional[int] = None
) -> StoreResponse:
"""
根据 StoreID 获取店铺信息。
@@ -137,18 +133,19 @@ async def merchant_get_store_by_id(
:raises StoreNotFoundException: 如果店铺未找到或不符合查看条件。
"""
logger.info(f"ActorID {actor_id} attempting to get store by StoreID {store_id}.")
- store_data = self._store_crud.get_store_by_id(conn=db, store_id=store_id,
- actor_id=actor_id) # actor_id for CRUD audit
+ store_data = self._store_crud.get_store_by_id(
+ conn=db, store_id=store_id, actor_id=actor_id
+ ) # actor_id for CRUD audit
if not store_data:
raise StoreNotFoundException(f"Store with ID {store_id} not found.")
return StoreResponse(**store_data)
async def user_get_stores_simple(
- self,
- db: Connection,
- *,
- offset_and_limit: Optional[Tuple[int, int]] = None,
+ self,
+ db: Connection,
+ *,
+ offset_and_limit: Optional[Tuple[int, int]] = None,
) -> StoreListSimpleResponse:
"""
用户视角:获取所有 ACTIVE 状态的店铺列表,支持分页。
@@ -162,16 +159,14 @@ async def user_get_stores_simple(
offset, limit = offset_and_limit
stores_data = self._store_crud.get_all_stores_page(
conn=db,
- store_status=StoreStatusEnum.ACTIVE, # 只查询 ACTIVE 状态的店铺
+ store_status=StoreStatusEnum.ACTIVE,
offset=offset,
limit=limit,
- actor_id=None # Public query
+ actor_id=None, # 只查询 ACTIVE 状态的店铺 # Public query
)
else:
stores_data = self._store_crud.get_all_stores(
- conn=db,
- store_status=StoreStatusEnum.ACTIVE, # 只查询 ACTIVE 状态的店铺
- actor_id=None # Public query
+ conn=db, store_status=StoreStatusEnum.ACTIVE, actor_id=None # 只查询 ACTIVE 状态的店铺 # Public query
)
# 获取总数
@@ -181,14 +176,14 @@ async def user_get_stores_simple(
return StoreListSimpleResponse(Count=total_count, StoreList=store_responses)
async def get_stores_by_owner(
- self,
- db: Connection,
- *,
- owner_user_id: int,
- actor_id: int,
- offset: int = 0,
- limit: int = 20,
- no_offset_and_limit: bool = False
+ self,
+ db: Connection,
+ *,
+ owner_user_id: int,
+ actor_id: int,
+ offset: int = 0,
+ limit: int = 20,
+ no_offset_and_limit: bool = False,
) -> StoreListResponse:
"""
商家视角:获取指定店主用户的所有店铺列表(不限状态),支持分页。
@@ -206,7 +201,8 @@ async def get_stores_by_owner(
logger.info(f"ActorID {actor_id} attempting to get stores for OwnerUserID {owner_user_id}.")
if actor_id != owner_user_id:
logger.warning(
- f"ActorID {actor_id} (not owner) attempting to access stores for OwnerUserID {owner_user_id}.")
+ f"ActorID {actor_id} (not owner) attempting to access stores for OwnerUserID {owner_user_id}."
+ )
if not no_offset_and_limit:
logger.info(f"Fetching stores for OwnerUserID {owner_user_id} with offset {offset} and limit {limit}.")
@@ -225,11 +221,7 @@ async def get_stores_by_owner(
return StoreListResponse(Count=total_count, StoreList=store_responses)
async def user_get_all_stores_full(
- self,
- db: Connection,
- *,
- offset_and_limit: Optional[Tuple[int, int]] = None,
- actor_id: Optional[int] = None
+ self, db: Connection, *, offset_and_limit: Optional[Tuple[int, int]] = None, actor_id: Optional[int] = None
) -> StoreListResponse:
"""
管理员视角:获取所有店铺列表,支持分页。
@@ -249,17 +241,11 @@ async def user_get_all_stores_full(
if offset_and_limit:
offset, limit = offset_and_limit
stores_data = self._store_crud.get_all_stores_page(
- conn=db,
- store_status=StoreStatusEnum.ACTIVE,
- offset=offset,
- limit=limit,
- actor_id=actor_id
+ conn=db, store_status=StoreStatusEnum.ACTIVE, offset=offset, limit=limit, actor_id=actor_id
)
else:
stores_data = self._store_crud.get_all_stores(
- conn=db,
- store_status=StoreStatusEnum.ACTIVE,
- actor_id=actor_id
+ conn=db, store_status=StoreStatusEnum.ACTIVE, actor_id=actor_id
)
total_count = len(stores_data)
@@ -268,12 +254,7 @@ async def user_get_all_stores_full(
return StoreListResponse(Count=total_count, StoreList=store_responses)
async def update_store_info(
- self,
- db: Connection,
- *,
- store_id: int,
- store_in: StoreUpdate,
- actor_id: int
+ self, db: Connection, *, store_id: int, store_in: StoreUpdate, actor_id: int
) -> StoreResponse:
"""
更新店铺信息 (通常由管理员操作)。
@@ -315,11 +296,8 @@ async def update_store_info(
return StoreResponse(**store_to_update)
updated_store_dict = self._store_crud.update_store(
- conn=db,
- store_id=store_id,
- actor_id=actor_id,
- **update_kwargs # 传递解包后的参数
- )
+ conn=db, store_id=store_id, actor_id=actor_id, **update_kwargs
+ ) # 传递解包后的参数
if not updated_store_dict:
logger.error(f"Failed to update store in CRUD layer for StoreID {store_id}, or no effective changes made.")
@@ -327,5 +305,3 @@ async def update_store_info(
logger.success(f"Store ID {store_id} updated successfully by ActorID {actor_id}.")
return StoreResponse(**updated_store_dict)
-
-
diff --git a/src/backend/app/services/user_service.py b/src/backend/app/services/user_service.py
index 96f3588..56bd7ea 100644
--- a/src/backend/app/services/user_service.py
+++ b/src/backend/app/services/user_service.py
@@ -13,19 +13,19 @@ class UserService:
"""
def __init__(
- self,
- user_crud: UserCRUD,
- hash_pwd_func: Callable[[str], str],
+ self,
+ user_crud: UserCRUD,
+ hash_pwd_func: Callable[[str], str],
):
self._user_crud = user_crud
self._hash_pwd_func = hash_pwd_func
def register_new_user(
- self,
- conn: Connection,
- *,
- user_in: UserCreate,
- performing_user_id: Optional[int] = None,
+ self,
+ conn: Connection,
+ *,
+ user_in: UserCreate,
+ performing_user_id: Optional[int] = None,
) -> UserResponse:
"""
注册新用户。
@@ -34,20 +34,10 @@ def register_new_user(
:param performing_user_id: 执行操作的用户 ID(通常是管理员或NULL)
:return: 注册成功的用户响应模型
"""
- if self._user_crud.get_user_by_username(
- conn,
- username=user_in.Username, actor_id=performing_user_id
- ):
- raise DuplicateUserError(
- f"User with username {user_in.Username} already exists."
- )
- if self._user_crud.get_user_by_email(
- conn,
- email=user_in.Email, actor_id=performing_user_id
- ):
- raise DuplicateUserError(
- f"User with email {user_in.Email} already exists."
- )
+ if self._user_crud.get_user_by_username(conn, username=user_in.Username, actor_id=performing_user_id):
+ raise DuplicateUserError(f"User with username {user_in.Username} already exists.")
+ if self._user_crud.get_user_by_email(conn, email=user_in.Email, actor_id=performing_user_id):
+ raise DuplicateUserError(f"User with email {user_in.Email} already exists.")
hashed_password = self._hash_pwd_func(user_in.Password)
@@ -60,4 +50,3 @@ def register_new_user(
)
return UserResponse(**created_user)
-
diff --git a/src/backend/app/utils/exceptions.py b/src/backend/app/utils/exceptions.py
index 2257b3a..c7d72e9 100644
--- a/src/backend/app/utils/exceptions.py
+++ b/src/backend/app/utils/exceptions.py
@@ -4,29 +4,37 @@ class DuplicateUserError(Exception):
def __init__(self, message: str = "User already exists."):
super().__init__(message)
+
class AuthenticationException(Exception):
"""自定义认证失败异常。"""
+
def __init__(self, detail: str = "Incorrect username or password"):
self.detail = detail
super().__init__(self.detail)
+
class InvalidTokenException(Exception):
"""自定义无效或过期令牌异常。"""
+
def __init__(self, detail: str = "Invalid or expired token"):
self.detail = detail
super().__init__(self.detail)
+
class UserNotFoundException(Exception):
"""用户未找到异常。"""
+
def __init__(self, detail: str = "User not found"):
self.detail = detail
super().__init__(self.detail)
+
class ProductNotFoundException(Exception):
def __init__(self, detail: str = "Product not found"):
self.detail = detail
super().__init__(self.detail)
+
class ProductFieldMissingException(Exception):
def __init__(self, detail: str = "Product field is missing"):
self.detail = detail
@@ -50,56 +58,74 @@ def __init__(self, detail: str = "Insufficient stock"):
self.detail = detail
super().__init__(self.detail)
+
class AddressNotFoundException(Exception):
"""地址未找到异常。"""
+
def __init__(self, detail: str = "Address not found"):
self.detail = detail
super().__init__(self.detail)
+
class OrderNotFoundException(Exception):
"""订单未找到异常。"""
+
def __init__(self, detail: str = "Order not found"):
self.detail = detail
super().__init__(self.detail)
+
class InvalidStatusTransitionException(Exception):
"""无效的状态转换异常。"""
+
def __init__(self, detail: str = "Invalid status transition"):
self.detail = detail
super().__init__(self.detail)
+
class PaymentTransactionNotFoundException(Exception):
"""支付事务未找到异常。"""
+
def __init__(self, detail: str = "Payment transaction not found"):
self.detail = detail
super().__init__(self.detail)
+
class InvalidPaymentStatusTransitionException(Exception):
"""无效的支付状态转换异常。"""
+
def __init__(self, detail: str = "Invalid payment status transition"):
self.detail = detail
super().__init__(self.detail)
+
class StoreNotFoundException(Exception):
"""店铺未找到异常。"""
+
def __init__(self, detail: str = "Store not found"):
self.detail = detail
super().__init__(self.detail)
+
class InvalidOperationException(Exception):
"""无效操作异常。"""
+
def __init__(self, detail: str = "Invalid operation"):
self.detail = detail
super().__init__(self.detail)
+
class BadRequestException(Exception):
"""错误请求异常。"""
+
def __init__(self, detail: str = "Bad request"):
self.detail = detail
super().__init__(self.detail)
+
class RequestNotFoundException(Exception):
"""请求未找到异常。"""
+
def __init__(self, detail: str = "Request not found"):
self.detail = detail
super().__init__(self.detail)
diff --git a/src/backend/app/utils/json.py b/src/backend/app/utils/json.py
index 8795543..83ec6de 100644
--- a/src/backend/app/utils/json.py
+++ b/src/backend/app/utils/json.py
@@ -3,6 +3,7 @@
import json
from decimal import Decimal
+
class DecimalEncoder(json.JSONEncoder):
"""
Custom JSON encoder for Decimal objects.
@@ -11,4 +12,4 @@ class DecimalEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, Decimal):
return str(obj) # Convert Decimal to string
- return super().default(obj) # Call the superclass method for other types
+ return super().default(obj) # Call the superclass method for other types
diff --git a/src/backend/app/utils/security.py b/src/backend/app/utils/security.py
index ad55ad2..ecbd507 100644
--- a/src/backend/app/utils/security.py
+++ b/src/backend/app/utils/security.py
@@ -40,8 +40,8 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
def create_access_token(
- data: Dict[str, Any],
- expires_delta: Optional[timedelta] = None,
+ data: Dict[str, Any],
+ expires_delta: Optional[timedelta] = None,
) -> str:
"""
Create a JWT access token.
@@ -64,7 +64,7 @@ def create_access_token(
def decode_access_token(
- token: str,
+ token: str,
) -> Optional[TokenPayload]:
"""
Decode a JWT access token.
diff --git a/src/backend/app/utils/testutils.py b/src/backend/app/utils/testutils.py
index 3533bef..bc74f09 100644
--- a/src/backend/app/utils/testutils.py
+++ b/src/backend/app/utils/testutils.py
@@ -1,6 +1,7 @@
# helper functions for testing
+
# 辅助函数来规范化SQL字符串以便比较
def normalize_sql(sql_string: str) -> str:
"""将SQL字符串中的多个空格和换行符替换为单个空格,并去除首尾空格。"""
- return ' '.join(sql_string.strip().split())
+ return " ".join(sql_string.strip().split())
diff --git a/src/backend/app/utils/timezone.py b/src/backend/app/utils/timezone.py
index 4e27743..a9d229c 100644
--- a/src/backend/app/utils/timezone.py
+++ b/src/backend/app/utils/timezone.py
@@ -17,6 +17,7 @@ def cast_db_datetime_to_utc(db_datetime: datetime.datetime) -> datetime.datetime
# Still convert it to UTC
return db_datetime.astimezone(datetime.timezone.utc) # Convert to UTC
+
def cast_anytime_to_utc(timestamp: datetime.datetime) -> datetime.datetime:
"""
Cast any datetime to UTC timezone.
@@ -26,7 +27,10 @@ def cast_anytime_to_utc(timestamp: datetime.datetime) -> datetime.datetime:
raise ValueError("The input datetime must be timezone-aware.")
return timestamp.astimezone(datetime.timezone.utc) # Convert to UTC
-def cast_dict_datetime_to_utc(data: Optional[Dict[str, Any]], caster=cast_db_datetime_to_utc) -> Optional[Dict[str, Any]]:
+
+def cast_dict_datetime_to_utc(
+ data: Optional[Dict[str, Any]], caster=cast_db_datetime_to_utc
+) -> Optional[Dict[str, Any]]:
"""
Cast all datetime objects in a dictionary to UTC timezone.
This function will recursively traverse the dictionary and convert all datetime
@@ -39,14 +43,17 @@ def cast_dict_datetime_to_utc(data: Optional[Dict[str, Any]], caster=cast_db_dat
data[key] = cast_dict_datetime_to_utc(value, caster)
return data
+
# a decorator on CRUD methods that calls caster on the result
def returns_utc(func, caster=cast_db_datetime_to_utc):
"""
Decorator to cast all datetime objects in the result of a CRUD method to UTC timezone.
"""
+
def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
if result and isinstance(result, dict):
return cast_dict_datetime_to_utc(result, caster=caster)
return result
+
return wrapper
diff --git a/src/backend/requirements.txt b/src/backend/requirements.txt
index c68884f..4e38e8a 100644
--- a/src/backend/requirements.txt
+++ b/src/backend/requirements.txt
@@ -12,3 +12,4 @@ loguru>=0.7.0
httpx
python-jose[cryptography]>=3.3.0
python-multipart
+black
diff --git a/src/backend/scripts/make_demo_data.py b/src/backend/scripts/make_demo_data.py
index 9720c2b..49b4aa2 100644
--- a/src/backend/scripts/make_demo_data.py
+++ b/src/backend/scripts/make_demo_data.py
@@ -97,21 +97,13 @@
"PhoneNumber": user_data.get("phone_number"), # Use .get() for optional fields
"UserRole": user_data["user_role"],
"RegistrationDate": user_data["registration_date"],
- "LastLoginDate": user_data.get(
- "last_login_date"
- ), # Assuming it might be missing or None
- "DefaultAddressID": user_data.get(
- "default_address_id"
- ), # Assuming it might be missing or None
- "AccountStatus": user_data.get(
- "account_status", "ACTIVE"
- ), # Default to ACTIVE if not provided
+ "LastLoginDate": user_data.get("last_login_date"), # Assuming it might be missing or None
+ "DefaultAddressID": user_data.get("default_address_id"), # Assuming it might be missing or None
+ "AccountStatus": user_data.get("account_status", "ACTIVE"), # Default to ACTIVE if not provided
}
connection.execute(insert_sql, params)
- logger.info(
- f"Inserted/Updated UserID: {user_data['user_id']}, Username: {user_data['username']}"
- )
+ logger.info(f"Inserted/Updated UserID: {user_data['user_id']}, Username: {user_data['username']}")
except Exception as e:
logger.error(f"Error inserting user {user_data.get('username', 'N/A')}: {e}")
# Decide if you want to rollback the whole transaction or continue
@@ -147,9 +139,7 @@
description="This store was successfully created via a change request.",
owner_user_id=3, # 与 Applied SCR 中的 RequestingUserID 匹配
store_status=StoreStatusEnum.ACTIVE.value, # 与 Applied SCR 中的 ProposedData 匹配
- creation_date=datetime.datetime(
- 2023, 8, 1, 0, 0, tzinfo=datetime.UTC
- ), # 给一个合理的创建时间
+ creation_date=datetime.datetime(2023, 8, 1, 0, 0, tzinfo=datetime.UTC), # 给一个合理的创建时间
),
]
@@ -264,9 +254,7 @@
f"Inserted/Updated ProductCategoryID: {category_data['category_id']}, Name: {category_data['category_name']}"
)
except Exception as e:
- logger.error(
- f"Error inserting product category {category_data.get('category_name', 'N/A')}: {e}"
- )
+ logger.error(f"Error inserting product category {category_data.get('category_name', 'N/A')}: {e}")
# transaction.rollback() # Rollback immediately or let the main try-except handle it
raise # Re-raise the exception to halt the script or be caught by an outer handler
@@ -339,8 +327,7 @@
1 if category.name not in ["badge", "appliance"] else 2
) # Assuming Demo Store 1 for most categories, Demo Store 2 for badges and appliances
product_entry = dict(
- product_id=len(other_products)
- + 4, # Start from 4 to avoid conflicts with promoted products
+ product_id=len(other_products) + 4, # Start from 4 to avoid conflicts with promoted products
product_name=name,
price=100 + len(other_products) * 10, # Example pricing logic
store_id=store_id,
@@ -384,9 +371,7 @@
params = {
"ProductID": product_data["product_id"],
"ProductName": product_data["product_name"],
- "ProductDescription": product_data.get(
- "product_description"
- ), # 使用 .get() 处理可选字段
+ "ProductDescription": product_data.get("product_description"), # 使用 .get() 处理可选字段
"Price": product_data["price"], # 已经是 Decimal 或 int/float,数据库会处理
"ProductStatus": product_data.get("product_status", "ACTIVE"), # 默认为 ACTIVE
"StoreID": product_data["store_id"],
@@ -396,9 +381,7 @@
}
connection.execute(insert_sql, params)
- logger.info(
- f"Inserted/Updated ProductID: {product_data['product_id']}, Name: {product_data['product_name']}"
- )
+ logger.info(f"Inserted/Updated ProductID: {product_data['product_id']}, Name: {product_data['product_name']}")
except Exception as e:
logger.error(
f"Error inserting product {product_data.get('product_name', 'N/A')} (ID: {product_data.get('product_id', 'N/A')}): {e}"
@@ -476,9 +459,7 @@
# You can add more cart items here following the pattern.
# Example: Another item for demo_user (UserID=2)
-assert (
- len(other_products) >= 5
-), "Not enough products in 'other_products' to add more items for demo_user."
+assert len(other_products) >= 5, "Not enough products in 'other_products' to add more items for demo_user."
for i in range(3, 5): # Adding two more items for demo_user
other_prod = other_products[i]
cart_items.append(
@@ -522,9 +503,7 @@
# 对于 CartItem,lastrowid 也会返回新生成的 CartItemID
# cart_item_id = result.lastrowid
# logger.info(f"Inserted CartItem for UserID: {item_data['UserID']}, ProductID: {item_data['ProductID']} with CartItemID: {cart_item_id}")
- logger.info(
- f"Inserted CartItem for UserID: {item_data['UserID']}, ProductID: {item_data['ProductID']}"
- )
+ logger.info(f"Inserted CartItem for UserID: {item_data['UserID']}, ProductID: {item_data['ProductID']}")
except Exception as e:
logger.error(
@@ -568,26 +547,14 @@
},
# Assuming other_products[0] is ProductID=4, other_products[1] is ProductID=5
4: {
- "Name": (
- other_products[0]["product_name"] if len(other_products) > 0 else "Sample Product 4"
- ),
- "Price": (
- Decimal(str(other_products[0]["price"]))
- if len(other_products) > 0
- else Decimal("100.00")
- ),
+ "Name": (other_products[0]["product_name"] if len(other_products) > 0 else "Sample Product 4"),
+ "Price": (Decimal(str(other_products[0]["price"])) if len(other_products) > 0 else Decimal("100.00")),
"ImageURL": other_products[0].get("main_image_url") if len(other_products) > 0 else None,
"StoreID": other_products[0]["store_id"] if len(other_products) > 0 else 1,
},
5: {
- "Name": (
- other_products[1]["product_name"] if len(other_products) > 1 else "Sample Product 5"
- ),
- "Price": (
- Decimal(str(other_products[1]["price"]))
- if len(other_products) > 1
- else Decimal("110.00")
- ),
+ "Name": (other_products[1]["product_name"] if len(other_products) > 1 else "Sample Product 5"),
+ "Price": (Decimal(str(other_products[1]["price"])) if len(other_products) > 1 else Decimal("110.00")),
"ImageURL": other_products[1].get("main_image_url") if len(other_products) > 1 else None,
"StoreID": other_products[1]["store_id"] if len(other_products) > 1 else 2,
},
@@ -964,8 +931,7 @@ def next_timestamp(add_minutes: int = 5) -> datetime.datetime:
ShippingAddress_Full="1 Cancel St, Testville",
Notes_ByUser="I changed my mind.",
CreationTime=pt_7_creation_time,
- CompletionTime=pt_7_creation_time
- + datetime.timedelta(minutes=10), # Order completion/cancellation time
+ CompletionTime=pt_7_creation_time + datetime.timedelta(minutes=10), # Order completion/cancellation time
)
)
order_items.append(
@@ -1014,8 +980,7 @@ def next_timestamp(add_minutes: int = 5) -> datetime.datetime:
Notes_ByMerchant="Item out of stock unexpectedly.",
CreationTime=pt_8_creation_time,
PaymentConfirmationTime=payment_confirm_time_8,
- CompletionTime=payment_confirm_time_8
- + datetime.timedelta(minutes=5), # Order completion/cancellation time
+ CompletionTime=payment_confirm_time_8 + datetime.timedelta(minutes=5), # Order completion/cancellation time
)
)
order_items.append(
@@ -1046,8 +1011,7 @@ def next_timestamp(add_minutes: int = 5) -> datetime.datetime:
ExternalGatewayTransactionID=None,
Status="FAILED", # Payment timed out and failed
CreationTime=pt_9_creation_time,
- CompletionTime=pt_9_creation_time
- + datetime.timedelta(minutes=30), # Payment failed after 30 mins
+ CompletionTime=pt_9_creation_time + datetime.timedelta(minutes=30), # Payment failed after 30 mins
)
)
orders.append(
@@ -1063,8 +1027,7 @@ def next_timestamp(add_minutes: int = 5) -> datetime.datetime:
ShippingAddress_Full="321 Second St, Another Town",
Notes_ByMerchant="System: Payment timed out.",
CreationTime=pt_9_creation_time,
- CompletionTime=pt_9_creation_time
- + datetime.timedelta(minutes=31), # Order completion/cancellation time
+ CompletionTime=pt_9_creation_time + datetime.timedelta(minutes=31), # Order completion/cancellation time
)
)
order_items.append(
@@ -1225,8 +1188,7 @@ def next_timestamp(add_minutes: int = 5) -> datetime.datetime:
ShippingAddress_Full="PO Box 11, Split Order Town",
CreationTime=pt_11_creation_time,
PaymentConfirmationTime=payment_confirm_time_11,
- ShippingTime=payment_confirm_time_11
- + datetime.timedelta(hours=1, minutes=30), # Shipped slightly later
+ ShippingTime=payment_confirm_time_11 + datetime.timedelta(hours=1, minutes=30), # Shipped slightly later
)
)
order_items.append(
@@ -1302,14 +1264,10 @@ def next_timestamp(add_minutes: int = 5) -> datetime.datetime:
"CompletionTime": pt_data.get("CompletionTime"),
}
connection.execute(insert_pt_sql, params_pt)
- logger.info(
- f"Inserted/Updated PaymentTransactionID: {conceptual_pt_id} for UserID: {pt_data['UserID']}"
- )
+ logger.info(f"Inserted/Updated PaymentTransactionID: {conceptual_pt_id} for UserID: {pt_data['UserID']}")
except Exception as e:
- logger.error(
- f"Error inserting payment transaction for UserID {pt_data.get('UserID', 'N/A')}: {e}"
- )
+ logger.error(f"Error inserting payment transaction for UserID {pt_data.get('UserID', 'N/A')}: {e}")
raise
logger.info("PaymentTransaction data insertion complete.")
@@ -1359,9 +1317,7 @@ def next_timestamp(add_minutes: int = 5) -> datetime.datetime:
"OrderID": conceptual_order_id, # Explicitly setting for demo linking
"UserID": order_data["UserID"],
"StoreID": order_data["StoreID"],
- "PaymentTransactionID": order_data[
- "PaymentTransactionID"
- ], # This comes from your pt_id_X mapping
+ "PaymentTransactionID": order_data["PaymentTransactionID"], # This comes from your pt_id_X mapping
"OrderStatus": order_data["OrderStatus"], # Should be enum.value if data stores enums
"OrderTotalAmount": order_data["OrderTotalAmount"],
"DiscountAmount": order_data.get("DiscountAmount", Decimal("0.00")),
@@ -1379,9 +1335,7 @@ def next_timestamp(add_minutes: int = 5) -> datetime.datetime:
"CompletionTime": order_data.get("CompletionTime"),
}
connection.execute(insert_order_sql, params_order)
- logger.info(
- f"Inserted/Updated OrderID: {conceptual_order_id} for UserID: {order_data['UserID']}"
- )
+ logger.info(f"Inserted/Updated OrderID: {conceptual_order_id} for UserID: {order_data['UserID']}")
except Exception as e:
logger.error(
@@ -1423,9 +1377,7 @@ def next_timestamp(add_minutes: int = 5) -> datetime.datetime:
connection.execute(insert_oi_sql, params_oi)
# last_oi_id = result.lastrowid
- logger.info(
- f"Inserted OrderItem for OrderID: {item_data['OrderID']}, ProductID: {item_data['ProductID']}"
- )
+ logger.info(f"Inserted OrderItem for OrderID: {item_data['OrderID']}, ProductID: {item_data['ProductID']}")
except Exception as e:
logger.error(
@@ -1825,9 +1777,7 @@ def next_timestamp(add_minutes: int = 5) -> datetime.datetime:
"RecipientName": addr_data["RecipientName"],
"PhoneNumber": addr_data["PhoneNumber"],
"FullAddress_Text": addr_data["FullAddress_Text"],
- "IsDefault": (
- 1 if addr_data["IsDefault"] else 0
- ), # Convert boolean to 1 or 0 for TINYINT(1)
+ "IsDefault": (1 if addr_data["IsDefault"] else 0), # Convert boolean to 1 or 0 for TINYINT(1)
}
connection.execute(insert_sql, params)
diff --git a/src/backend/scripts/reset.py b/src/backend/scripts/reset.py
index 5d63999..d07eadf 100644
--- a/src/backend/scripts/reset.py
+++ b/src/backend/scripts/reset.py
@@ -1,5 +1,6 @@
from .run_all_ddls import *
+
def reset_dev(no_trigger=False):
main(
ddl_scripts_dir,
@@ -7,9 +8,10 @@ def reset_dev(no_trigger=False):
refresh=True,
refresh_data=False,
no_ddl=False,
- no_trigger=no_trigger
+ no_trigger=no_trigger,
)
+
def reset_test(no_trigger=False):
main(
ddl_scripts_dir,
@@ -17,7 +19,7 @@ def reset_test(no_trigger=False):
refresh=True,
refresh_data=False,
no_ddl=False,
- no_trigger=no_trigger
+ no_trigger=no_trigger,
)
@@ -30,13 +32,9 @@ def reset_test(no_trigger=False):
"target",
choices=["dev", "test"],
default="test",
- help="Specify the mode to run the DDL scripts. Default is 'test'."
- )
- parser.add_argument(
- "--no-trigger",
- action="store_true",
- help="Skip running trigger scripts."
+ help="Specify the mode to run the DDL scripts. Default is 'test'.",
)
+ parser.add_argument("--no-trigger", action="store_true", help="Skip running trigger scripts.")
args = parser.parse_args()
ENABLE_AUDIT_LOG = os.getenv("ENABLE_AUDIT_LOG", "false").lower() == "true"
diff --git a/src/backend/scripts/run_all_ddls.py b/src/backend/scripts/run_all_ddls.py
index ccb1930..e7ce3a9 100644
--- a/src/backend/scripts/run_all_ddls.py
+++ b/src/backend/scripts/run_all_ddls.py
@@ -40,6 +40,7 @@ def execute_sql_script(engine, script_text, split_statements=True):
logger.info(f"stmt: {stmt}")
raise e
+
def execute_sql_scripts_in_order(engine, directory, split_statements=True):
"""
Execute all SQL scripts in the specified directory in order.
@@ -55,16 +56,14 @@ def execute_sql_scripts_in_order(engine, directory, split_statements=True):
for sql_file in sql_files:
logger.info(f"Running script: {sql_file.name}")
# Read the SQL file
- with open(sql_file, "r", encoding='utf-8') as file:
+ with open(sql_file, "r", encoding="utf-8") as file:
sql_script = file.read()
# Execute the SQL script
execute_sql_script(engine, sql_script, split_statements=split_statements)
-def main(
- directory: Path, engine=None, refresh=False, refresh_data=False, no_ddl=False, no_trigger=False
-):
+def main(directory: Path, engine=None, refresh=False, refresh_data=False, no_ddl=False, no_trigger=False):
"""
Run all DDL scripts in the specified directory in order.
The scripts look like `001_create_table.sql`, `002_create_xxx.sql`, etc.
@@ -72,13 +71,13 @@ def main(
"""
if refresh:
logger.info(f"Running drop_all.sql script...")
- with open(drop_all_script, "r", encoding='utf-8') as file:
+ with open(drop_all_script, "r", encoding="utf-8") as file:
drop_script = file.read()
execute_sql_script(engine, drop_script)
if refresh_data:
logger.info(f"Running drop_all_data.sql script...")
- with open(drop_all_data_script, "r", encoding='utf-8') as file:
+ with open(drop_all_data_script, "r", encoding="utf-8") as file:
drop_data_script = file.read()
execute_sql_script(engine, drop_data_script)
diff --git a/src/backend/test/base_db_testcase.py b/src/backend/test/base_db_testcase.py
index 123190c..0bfe6a2 100644
--- a/src/backend/test/base_db_testcase.py
+++ b/src/backend/test/base_db_testcase.py
@@ -59,9 +59,9 @@ def setUp(self):
def tearDown(self):
"""在每个测试方法结束后运行"""
logger.info("Returned from test function")
- if hasattr(self, 'transaction') and self.transaction:
+ if hasattr(self, "transaction") and self.transaction:
self.transaction.rollback() # 回滚事务
- if hasattr(self, 'connection') and self.connection:
+ if hasattr(self, "connection") and self.connection:
self.connection.close() # 关闭连接
# print(f"DEBUG: {self.id()} - tearDown: Transaction rolled back.")
@@ -86,11 +86,12 @@ def setUpClass(cls):
if cls.engine is None:
raise RuntimeError(
- f"{cls.__name__}.setUpClass: Failed to get database engine after setting mode to 'test'.")
+ f"{cls.__name__}.setUpClass: Failed to get database engine after setting mode to 'test'."
+ )
# 2. 重置测试数据库 (创建表结构,清空数据等)
logger.info(f"INFO: {cls.__name__}.setUpClass - Resetting test database...")
- reset_test(no_trigger=True) # 假设 reset_test() 是一个同步函数
+ reset_test(no_trigger=True) # 假设 reset_test() 是一个同步函数
logger.info(f"INFO: {cls.__name__}.setUpClass - Setup complete.")
@@ -128,14 +129,15 @@ def tearDown(self):
同步方法。
"""
logger.info(
- f"INFO: {self.id()} - tearDown: Rolling back transaction and closing connection {id(self.connection)}.")
- if hasattr(self, 'transaction') and self.transaction and self.transaction.is_active:
+ f"INFO: {self.id()} - tearDown: Rolling back transaction and closing connection {id(self.connection)}."
+ )
+ if hasattr(self, "transaction") and self.transaction and self.transaction.is_active:
try:
self.transaction.rollback() # 回滚事务
except Exception as e:
logger.error(f"ERROR: {self.id()} - Failed to rollback transaction: {e}")
- if hasattr(self, 'connection') and self.connection and not self.connection.closed:
+ if hasattr(self, "connection") and self.connection and not self.connection.closed:
try:
self.connection.close() # 关闭连接
except Exception as e:
@@ -150,9 +152,7 @@ class BaseDBTestCase(unittest.TestCase):
def _execute_sql_script(cls, connection: Connection, script_path: Path):
"""辅助方法,用于执行 SQL 脚本文件。"""
if not script_path.is_file():
- raise FileNotFoundError(
- f"SQL script '{script_path.name}' not found at: {script_path}"
- )
+ raise FileNotFoundError(f"SQL script '{script_path.name}' not found at: {script_path}")
logger.info(f"INFO: Executing SQL script: {script_path}")
sql_script_content = script_path.read_text()
@@ -212,7 +212,7 @@ def setUpClass(cls):
# 如果你的测试数据库已经有正确的表结构和过程,只需要清空数据。
logger.info(f"INFO: {cls.__name__}.setUpClass - Performing initial database cleanup.")
- assert (cls._get_drop_all_data_script_path().exists())
+ assert cls._get_drop_all_data_script_path().exists()
reset_test(no_trigger=True)
@@ -255,6 +255,6 @@ def tearDown(self):
logger.info(f"ERROR: {self.id()} - tearDown: Failed during database cleanup: {e}")
# 即使清理失败,也要尝试关闭主测试连接
finally:
- if hasattr(self, 'connection') and self.connection and not self.connection.closed:
+ if hasattr(self, "connection") and self.connection and not self.connection.closed:
self.connection.close()
# print(f"DEBUG: {self.id()} - tearDown: Connection closed and database cleaned.")
diff --git a/src/backend/test/integration/api_endpoints/mock_statistics_views.py b/src/backend/test/integration/api_endpoints/mock_statistics_views.py
index 6759f9e..7a392a8 100644
--- a/src/backend/test/integration/api_endpoints/mock_statistics_views.py
+++ b/src/backend/test/integration/api_endpoints/mock_statistics_views.py
@@ -3,17 +3,21 @@
This module provides functions to create temporary mock views for the statistics tests.
"""
+
from sqlalchemy import text
from sqlalchemy.engine.base import Connection
from backend.app.utils import logger
+
def create_mock_views(conn: Connection):
"""Create mock views for the statistics tests."""
logger.info("Creating mock views for statistics tests")
-
+
try:
# Create system statistics view
- conn.execute(text("""
+ conn.execute(
+ text(
+ """
CREATE OR REPLACE VIEW system_statistics AS
SELECT
(SELECT COUNT(*) FROM User) AS total_users,
@@ -24,16 +28,24 @@ def create_mock_views(conn: Connection):
(SELECT COUNT(*) FROM `Order` WHERE OrderStatus = 'COMPLETED') AS completed_orders,
(SELECT COUNT(*) FROM `Order`) AS total_orders,
(SELECT COALESCE(SUM(FinalAmountForThisOrder), 0) FROM `Order` WHERE OrderStatus = 'COMPLETED') AS total_sales
- """))
-
+ """
+ )
+ )
+
# Create admin dashboard statistics view
- conn.execute(text("""
+ conn.execute(
+ text(
+ """
CREATE OR REPLACE VIEW admin_dashboard_statistics AS
SELECT * FROM system_statistics
- """))
-
+ """
+ )
+ )
+
# Create store statistics view
- conn.execute(text("""
+ conn.execute(
+ text(
+ """
CREATE OR REPLACE VIEW store_statistics AS
SELECT
s.StoreID AS store_id,
@@ -48,8 +60,10 @@ def create_mock_views(conn: Connection):
LEFT JOIN OrderItem oi ON oi.StoreID = s.StoreID
GROUP BY
s.StoreID, s.StoreName
- """))
-
+ """
+ )
+ )
+
logger.info("Mock views created successfully")
except Exception as e:
logger.error(f"Error creating mock views: {e}")
diff --git a/src/backend/test/integration/api_endpoints/test_address_endpoints_async.py b/src/backend/test/integration/api_endpoints/test_address_endpoints_async.py
index bd71dee..5bf941d 100644
--- a/src/backend/test/integration/api_endpoints/test_address_endpoints_async.py
+++ b/src/backend/test/integration/api_endpoints/test_address_endpoints_async.py
@@ -17,7 +17,7 @@
AddressListResponse,
AddressCreateRequest,
AddressUpdateRequest,
- SetDefaultAddressResponse
+ SetDefaultAddressResponse,
)
from backend.app.services.address_service import AddressService
from backend.app.crud.address_crud import AddressCRUD
@@ -45,18 +45,18 @@ def setUp(self):
Username=self.test_user1_username,
Email=self.test_user1_email,
PhoneNumber=None,
- UserRole='customer',
+ UserRole="customer",
RegistrationDate=current_utc_time,
- LastLoginDate=current_utc_time
+ LastLoginDate=current_utc_time,
)
self.mock_current_user2_schema = CurrentUserSchema(
UserID=self.test_user2_id,
Username=self.test_user2_username,
Email=self.test_user2_email,
PhoneNumber=None,
- UserRole='customer',
+ UserRole="customer",
RegistrationDate=current_utc_time,
- LastLoginDate=current_utc_time
+ LastLoginDate=current_utc_time,
)
# --- Dependency Overrides ---
@@ -65,10 +65,7 @@ def override_get_db_connection() -> Connection:
self.real_address_crud = AddressCRUD.get_instance()
self.real_user_crud = UserCRUD.get_instance()
- self.real_address_service = AddressService(
- address_crud=self.real_address_crud,
- user_crud=self.real_user_crud
- )
+ self.real_address_service = AddressService(address_crud=self.real_address_crud, user_crud=self.real_user_crud)
def override_get_address_service() -> AddressService:
return self.real_address_service
@@ -87,10 +84,18 @@ async def override_get_current_active_user_default() -> CurrentUserSchema:
def _setup_initial_users(self):
try:
users_to_create_data = [
- {"UserID": self.test_user1_id, "Username": self.test_user1_username, "PasswordHash": "hash1",
- "Email": self.test_user1_email},
- {"UserID": self.test_user2_id, "Username": self.test_user2_username, "PasswordHash": "hash2",
- "Email": self.test_user2_email},
+ {
+ "UserID": self.test_user1_id,
+ "Username": self.test_user1_username,
+ "PasswordHash": "hash1",
+ "Email": self.test_user1_email,
+ },
+ {
+ "UserID": self.test_user2_id,
+ "Username": self.test_user2_username,
+ "PasswordHash": "hash2",
+ "Email": self.test_user2_email,
+ },
]
for user_data in users_to_create_data:
user_exists = self.connection.execute(
@@ -99,8 +104,9 @@ def _setup_initial_users(self):
if not user_exists:
self.connection.execute(
text(
- "INSERT INTO `User` (UserID, Username, PasswordHash, Email, UserRole, AccountStatus, RegistrationDate) VALUES (:UserID, :Username, :PasswordHash, :Email, 'customer', 'ACTIVE', UTC_TIMESTAMP())"),
- user_data
+ "INSERT INTO `User` (UserID, Username, PasswordHash, Email, UserRole, AccountStatus, RegistrationDate) VALUES (:UserID, :Username, :PasswordHash, :Email, 'customer', 'ACTIVE', UTC_TIMESTAMP())"
+ ),
+ user_data,
)
except Exception as e:
self.fail(f"Error setting up initial test users for Address Endpoints: {e}")
@@ -115,7 +121,7 @@ def _create_direct_address(self, user_id: int, recipient_name: str, is_default:
RecipientName=recipient_name,
PhoneNumber="123456789", # Valid phone number
FullAddress_Text=f"{recipient_name}'s Address",
- IsDefault=is_default # This is passed to service
+ IsDefault=is_default, # This is passed to service
)
# In a real test, we'd call the service method if it's synchronous
# Since service methods are async, and this is a sync helper, we use CRUD directly
@@ -127,8 +133,8 @@ def _create_direct_address(self, user_id: int, recipient_name: str, is_default:
created_address_dict = self.real_address_crud.create_address(
self.connection,
user_id=user_id,
- address_in=address_in, # CRUD create_address sets IsDefault=False
- actor_id=user_id
+ address_in=address_in,
+ actor_id=user_id, # CRUD create_address sets IsDefault=False
)
if not created_address_dict:
self.fail(f"Helper _create_direct_address failed for {recipient_name}")
@@ -145,9 +151,9 @@ def _create_direct_address(self, user_id: int, recipient_name: str, is_default:
self.connection, user_id=user_id, default_address_id=created_address_dict["AddressID"], actor_id=user_id
)
# Re-fetch to get updated IsDefault status
- created_address_dict = self.real_address_crud.get_address_by_id(self.connection,
- address_id=created_address_dict[
- "AddressID"])
+ created_address_dict = self.real_address_crud.get_address_by_id(
+ self.connection, address_id=created_address_dict["AddressID"]
+ )
return created_address_dict # type: ignore
@@ -157,7 +163,7 @@ async def test_add_new_address_not_default(self):
"RecipientName": "John Doe",
"PhoneNumber": "13800138000",
"FullAddress_Text": "123 Main St, Anytown, USA 12345",
- "IsDefault": False # Explicitly not default
+ "IsDefault": False, # Explicitly not default
}
response = self.client.post("/api/v1/address/", json=address_payload)
self.assertEqual(response.status_code, status.HTTP_201_CREATED, response.text)
@@ -181,7 +187,7 @@ async def test_add_new_address_set_as_default_still_return_non_default(self):
"RecipientName": "Jane Default",
"PhoneNumber": "13900139000",
"FullAddress_Text": "456 Default Ave, Anytown, USA 67890",
- "IsDefault": True # Request to set as default
+ "IsDefault": True, # Request to set as default
}
response = self.client.post("/api/v1/address/", json=address_payload)
self.assertEqual(response.status_code, status.HTTP_201_CREATED, response.text)
@@ -257,7 +263,8 @@ async def test_update_address_details_success(self):
address_id_to_update = created_addr["AddressID"] # type: ignore
update_payload = AddressUpdateRequest(RecipientName="Updated Name Here", PhoneNumber="123450000").model_dump(
- exclude_unset=True)
+ exclude_unset=True
+ )
response = self.client.put(f"/api/v1/address/{address_id_to_update}", json=update_payload)
self.assertEqual(response.status_code, status.HTTP_200_OK, response.text)
data = response.json()
@@ -272,10 +279,12 @@ async def test_update_address_details_not_found(self):
# --- POST /{address_id}/set-default Tests ---
async def test_set_address_as_default_success(self):
- addr1 = self._create_direct_address(user_id=self.test_user1_id, recipient_name="Addr1 To Be Default",
- is_default=False)
- addr2 = self._create_direct_address(user_id=self.test_user1_id, recipient_name="Addr2 Old Default",
- is_default=True)
+ addr1 = self._create_direct_address(
+ user_id=self.test_user1_id, recipient_name="Addr1 To Be Default", is_default=False
+ )
+ addr2 = self._create_direct_address(
+ user_id=self.test_user1_id, recipient_name="Addr2 Old Default", is_default=True
+ )
# Set addr1 as default
response = self.client.post(f"/api/v1/address/{addr1['AddressID']}/set-default") # type: ignore
@@ -289,10 +298,8 @@ async def test_set_address_as_default_success(self):
user_db = self.real_user_crud.get_user_by_id(self.connection, user_id=self.test_user1_id)
self.assertEqual(user_db["DefaultAddressID"], addr1["AddressID"]) # type: ignore
- addr1_db = self.real_address_crud.get_address_by_id(self.connection,
- address_id=addr1["AddressID"]) # type: ignore
- addr2_db = self.real_address_crud.get_address_by_id(self.connection,
- address_id=addr2["AddressID"]) # type: ignore
+ addr1_db = self.real_address_crud.get_address_by_id(self.connection, address_id=addr1["AddressID"]) # type: ignore
+ addr2_db = self.real_address_crud.get_address_by_id(self.connection, address_id=addr2["AddressID"]) # type: ignore
self.assertTrue(addr1_db["IsDefault"]) # type: ignore
self.assertFalse(addr2_db["IsDefault"]) # type: ignore
@@ -302,8 +309,9 @@ async def test_set_address_as_default_address_not_found(self):
# --- DELETE /{address_id} Tests ---
async def test_delete_address_success_non_default(self):
- addr_to_delete = self._create_direct_address(user_id=self.test_user1_id, recipient_name="To Delete NonDefault",
- is_default=False)
+ addr_to_delete = self._create_direct_address(
+ user_id=self.test_user1_id, recipient_name="To Delete NonDefault", is_default=False
+ )
address_id = addr_to_delete["AddressID"] # type: ignore
response = self.client.delete(f"/api/v1/address/{address_id}")
@@ -313,8 +321,9 @@ async def test_delete_address_success_non_default(self):
self.assertIsNone(self.real_address_crud.get_address_by_id(self.connection, address_id=address_id))
async def test_delete_address_success_was_default(self):
- addr_default_to_delete = self._create_direct_address(user_id=self.test_user1_id,
- recipient_name="To Delete Default", is_default=True)
+ addr_default_to_delete = self._create_direct_address(
+ user_id=self.test_user1_id, recipient_name="To Delete Default", is_default=True
+ )
address_id = addr_default_to_delete["AddressID"] # type: ignore
# Verify it was default
@@ -333,5 +342,5 @@ async def test_delete_address_not_found(self):
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/src/backend/test/integration/api_endpoints/test_auth_endpoints.py b/src/backend/test/integration/api_endpoints/test_auth_endpoints.py
index 0b4d0ee..83f053d 100644
--- a/src/backend/test/integration/api_endpoints/test_auth_endpoints.py
+++ b/src/backend/test/integration/api_endpoints/test_auth_endpoints.py
@@ -5,11 +5,14 @@
from backend.app.utils.security import hash_password
from backend.test.base_db_testcase import BaseDBTestCaseAutoRollback, BaseDBTestCase
from backend.app.dependencies.service_deps import (
- get_auth_service, get_user_crud, get_user_session_crud,
- get_password_hasher
+ get_auth_service,
+ get_user_crud,
+ get_user_session_crud,
+ get_password_hasher,
)
from backend.app.utils import logger
+
@unittest.skip("This test is reimplemented in test_auth_endpoints_async.py")
class AuthEndpointIntegrationTest(BaseDBTestCase):
@classmethod
@@ -23,12 +26,10 @@ def create_user(self, username, email, password):
hashed = hash_password(password)
self.connection.execute(
text("INSERT INTO User (Username, Email, PasswordHash) VALUES (:u, :e, :p)"),
- {"u": username, "e": email, "p": hashed}
+ {"u": username, "e": email, "p": hashed},
)
self.connection.commit()
- user_id = self.connection.execute(
- text("SELECT UserID FROM User WHERE Username=:u"), {"u": username}
- ).scalar()
+ user_id = self.connection.execute(text("SELECT UserID FROM User WHERE Username=:u"), {"u": username}).scalar()
logger.info(f"insert new user {username=}, {email=}, {password=}. returns {user_id=}")
return user_id
@@ -84,7 +85,6 @@ def test_visit_me_with_token_success(self):
resp = self.client.get("/api/v1/user/me", headers=self.auth_header(token))
self.assertEqual(resp.status_code, 200)
-
def test_logout_success(self):
username = "logoutuser"
password = "logoutpass"
@@ -93,9 +93,7 @@ def test_logout_success(self):
resp = self.get_token(username, password)
token = resp.json()["access_token"]
- logger.success(
- f"Now logging out user {username} with token {token}"
- )
+ logger.success(f"Now logging out user {username} with token {token}")
logout_resp = self.client.post("/api/v1/auth/logout", headers=self.auth_header(token))
self.assertEqual(logout_resp.status_code, 200)
@@ -123,5 +121,6 @@ def test_logout_all_sessions(self):
resp2 = self.client.post("/api/v1/auth/logout", headers=self.auth_header(token2))
self.assertIn(resp2.status_code, (400, 401))
+
if __name__ == "__main__":
unittest.main()
diff --git a/src/backend/test/integration/api_endpoints/test_auth_endpoints_async.py b/src/backend/test/integration/api_endpoints/test_auth_endpoints_async.py
index 7504186..32a3c7c 100644
--- a/src/backend/test/integration/api_endpoints/test_auth_endpoints_async.py
+++ b/src/backend/test/integration/api_endpoints/test_auth_endpoints_async.py
@@ -19,8 +19,12 @@
from backend.app.crud.user_session_crud import UserSessionCRUD # For direct DB interaction
from backend.app.dependencies.auth_deps import get_current_active_user, get_db_connection, get_token_from_auth_header
from backend.app.dependencies.service_deps import get_auth_service # Assuming this provides AuthService
-from backend.app.utils.security import hash_password, create_access_token, decode_access_token, \
- verify_password # For test setup
+from backend.app.utils.security import (
+ hash_password,
+ create_access_token,
+ decode_access_token,
+ verify_password,
+) # For test setup
from backend.app.core.config import settings # For token expiry, algorithm, secret
from backend.app.utils import logger # Assuming logger is available from utils
@@ -45,19 +49,28 @@ def _create_shared_class_users(cls, conn: Connection):
logger.info(f"--- {cls.__name__}: Creating shared class-level users ---")
try:
users_to_create_data = [
- {"UserID": cls.user1_id_class, "Username": cls.user1_username_class,
- "PasswordHash": cls.user1_password_hash_class, "Email": cls.user1_email_class},
- {"UserID": cls.user2_id_class, "Username": cls.user2_username_class,
- "PasswordHash": cls.user2_password_hash_class, "Email": cls.user2_email_class},
+ {
+ "UserID": cls.user1_id_class,
+ "Username": cls.user1_username_class,
+ "PasswordHash": cls.user1_password_hash_class,
+ "Email": cls.user1_email_class,
+ },
+ {
+ "UserID": cls.user2_id_class,
+ "Username": cls.user2_username_class,
+ "PasswordHash": cls.user2_password_hash_class,
+ "Email": cls.user2_email_class,
+ },
]
for user_data in users_to_create_data:
# Using ON DUPLICATE KEY UPDATE to make it idempotent if reset_test wasn't perfect
conn.execute(
text(
"INSERT INTO User (UserID, Username, PasswordHash, Email, UserRole, AccountStatus, RegistrationDate) VALUES (:UserID, :Username, :PasswordHash, :Email, 'customer', 'ACTIVE', UTC_TIMESTAMP()) "
- "ON DUPLICATE KEY UPDATE Username=VALUES(Username), Email=VALUES(Email), PasswordHash=VALUES(PasswordHash)"),
+ "ON DUPLICATE KEY UPDATE Username=VALUES(Username), Email=VALUES(Email), PasswordHash=VALUES(PasswordHash)"
+ ),
# Ensure vital fields are updated if exists
- user_data
+ user_data,
)
logger.info(f"--- {cls.__name__}: Shared class-level users created/updated ---")
except Exception as e:
@@ -90,9 +103,9 @@ def setUp(self):
Username=self.user1_username_class,
Email=self.user1_email_class,
PhoneNumber=None,
- UserRole='customer',
+ UserRole="customer",
RegistrationDate=current_utc_time,
- LastLoginDate=current_utc_time
+ LastLoginDate=current_utc_time,
)
# mock_current_user2_schema can be created if needed for a specific test override
@@ -108,7 +121,7 @@ def override_get_db_connection() -> Connection:
user_session_crud=self.real_user_session_crud,
verify_password_func=verify_password,
create_jwt_func=create_access_token,
- decode_jwt_func=decode_access_token
+ decode_jwt_func=decode_access_token,
)
def override_get_auth_service() -> AuthService:
@@ -137,14 +150,14 @@ def _create_direct_session_for_user(self, user_id: int, token_jti: str, minutes_
user_id=user_id, # Can be self.user1_id_class or self.user2_id_class
expires_at=aware_expires_at,
ip_address="127.0.0.1",
- user_agent="TestSetupAgent"
+ user_agent="TestSetupAgent",
)
# --- /token (Login) Tests ---
async def test_login_success_with_username(self):
response = self.client.post(
"/api/v1/auth/token",
- data={"username": self.user1_username_class, "password": self.user1_password_plain_class}
+ data={"username": self.user1_username_class, "password": self.user1_password_plain_class},
)
self.assertEqual(response.status_code, status.HTTP_200_OK, response.text)
token_data = response.json()
@@ -160,8 +173,7 @@ async def test_login_success_with_username(self):
async def test_login_success_with_email(self):
response = self.client.post(
- "/api/v1/auth/token",
- data={"username": self.user1_email_class, "password": self.user1_password_plain_class}
+ "/api/v1/auth/token", data={"username": self.user1_email_class, "password": self.user1_password_plain_class}
)
self.assertEqual(response.status_code, status.HTTP_200_OK, response.text)
token_data = response.json()
@@ -176,16 +188,14 @@ async def test_login_success_with_email(self):
async def test_login_failure_wrong_password(self):
response = self.client.post(
- "/api/v1/auth/token",
- data={"username": self.user1_username_class, "password": "wrongpassword"}
+ "/api/v1/auth/token", data={"username": self.user1_username_class, "password": "wrongpassword"}
)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED, response.text)
self.assertIn("Incorrect identifier or password", response.json()["detail"])
async def test_login_failure_user_not_found(self):
response = self.client.post(
- "/api/v1/auth/token",
- data={"username": "nonexistentuser_class", "password": "password"}
+ "/api/v1/auth/token", data={"username": "nonexistentuser_class", "password": "password"}
)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED, response.text)
self.assertIn("Incorrect identifier or password", response.json()["detail"])
@@ -194,7 +204,7 @@ async def test_login_failure_user_not_found(self):
async def test_logout_current_session_success(self):
login_response = self.client.post(
"/api/v1/auth/token",
- data={"username": self.user1_username_class, "password": self.user1_password_plain_class}
+ data={"username": self.user1_username_class, "password": self.user1_password_plain_class},
)
self.assertEqual(login_response.status_code, 200)
token_str = login_response.json()["access_token"]
@@ -227,8 +237,10 @@ async def test_logout_no_token(self):
response = self.client.post("/api/v1/auth/logout") # No Authorization header
- if original_gca_override: app.dependency_overrides[get_current_active_user] = original_gca_override
- if original_gtfh_override: app.dependency_overrides[get_token_from_auth_header] = original_gtfh_override
+ if original_gca_override:
+ app.dependency_overrides[get_current_active_user] = original_gca_override
+ if original_gtfh_override:
+ app.dependency_overrides[get_token_from_auth_header] = original_gtfh_override
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
# Detail message comes from get_token_from_auth_header if it raises first
@@ -241,8 +253,10 @@ async def test_logout_invalid_token_format(self):
headers = {"Authorization": "InvalidTokenFormat"}
response = self.client.post("/api/v1/auth/logout", headers=headers)
- if original_gca_override: app.dependency_overrides[get_current_active_user] = original_gca_override
- if original_gtfh_override: app.dependency_overrides[get_token_from_auth_header] = original_gtfh_override
+ if original_gca_override:
+ app.dependency_overrides[get_current_active_user] = original_gca_override
+ if original_gtfh_override:
+ app.dependency_overrides[get_token_from_auth_header] = original_gtfh_override
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertIn("malformed", response.json()["detail"].lower())
@@ -277,7 +291,7 @@ async def test_logout_all_sessions_success(self):
# This will also create a session for user1
login_response = self.client.post(
"/api/v1/auth/token",
- data={"username": self.user1_username_class, "password": self.user1_password_plain_class}
+ data={"username": self.user1_username_class, "password": self.user1_password_plain_class},
)
self.assertEqual(login_response.status_code, 200)
auth_token_for_user1 = login_response.json()["access_token"]
@@ -310,11 +324,12 @@ async def test_logout_all_sessions_success(self):
async def test_logout_all_no_token(self):
original_gca_override = app.dependency_overrides.pop(get_current_active_user, None)
response = self.client.post("/api/v1/auth/logout-all")
- if original_gca_override: app.dependency_overrides[get_current_active_user] = original_gca_override
+ if original_gca_override:
+ app.dependency_overrides[get_current_active_user] = original_gca_override
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertIn("not authenticated", response.json()["detail"].lower())
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/src/backend/test/integration/api_endpoints/test_cart_endpoints.py b/src/backend/test/integration/api_endpoints/test_cart_endpoints.py
index 4bbbca4..b9a56b0 100644
--- a/src/backend/test/integration/api_endpoints/test_cart_endpoints.py
+++ b/src/backend/test/integration/api_endpoints/test_cart_endpoints.py
@@ -17,7 +17,7 @@
CartItemResponse,
CartItemCreateRequest, # Still used for valid requests
CartItemUpdateRequest,
- CartActionResponse
+ CartActionResponse,
)
from backend.app.services.cart_service import CartService
from backend.app.crud.cartitem_crud import CartItemCRUD
@@ -44,14 +44,14 @@ def setUp(self):
Username=self.test_user1_username,
Email=self.test_user1_email,
PhoneNumber=None,
- UserRole='customer',
+ UserRole="customer",
# RegistrationDate and LastLoginDate are expected by your UserResponse schema
# if they are not optional. Assuming they are for this mock.
# If they are required, provide them.
# For UserResponse, CreatedAt and UpdatedAt are more common.
# Let's assume UserResponse (CurrentUserSchema) has CreatedAt and UpdatedAt
RegistrationDate=current_utc_time, # 使用 UserResponse 定义的字段
- LastLoginDate=current_utc_time # 使用 UserResponse 定义的字段
+ LastLoginDate=current_utc_time, # 使用 UserResponse 定义的字段
)
# --- 模拟数据库连接和服务的依赖覆盖 ---
@@ -65,10 +65,7 @@ async def override_get_current_active_user() -> CurrentUserSchema:
self.real_product_crud = ProductCRUD.get_instance()
def override_get_cart_service() -> CartService:
- return CartService(
- cart_item_crud=self.real_cart_item_crud,
- product_crud=self.real_product_crud
- )
+ return CartService(cart_item_crud=self.real_cart_item_crud, product_crud=self.real_product_crud)
app.dependency_overrides[get_db_connection] = override_get_db_connection
app.dependency_overrides[get_current_active_user] = override_get_current_active_user
@@ -87,55 +84,72 @@ def _setup_initial_data(self):
if not user_exists:
self.connection.execute(
text(
- "INSERT INTO User (UserID, Username, PasswordHash, Email, UserRole, AccountStatus, RegistrationDate) VALUES (:id, :uname, 'hash', :email, 'customer', 'ACTIVE', UTC_TIMESTAMP())"),
- {"id": self.test_user1_id, "uname": self.test_user1_username, "email": self.test_user1_email}
+ "INSERT INTO User (UserID, Username, PasswordHash, Email, UserRole, AccountStatus, RegistrationDate) VALUES (:id, :uname, 'hash', :email, 'customer', 'ACTIVE', UTC_TIMESTAMP())"
+ ),
+ {"id": self.test_user1_id, "uname": self.test_user1_username, "email": self.test_user1_email},
)
# 2. 创建商品分类
self.category1_id = 10
cat_exists = self.connection.execute(
- text("SELECT CategoryID FROM ProductCategory WHERE CategoryID = :id"),
- {"id": self.category1_id}).fetchone()
+ text("SELECT CategoryID FROM ProductCategory WHERE CategoryID = :id"), {"id": self.category1_id}
+ ).fetchone()
if not cat_exists:
self.connection.execute(
text("INSERT INTO ProductCategory (CategoryID, CategoryName) VALUES (:id, :name)"),
- {"id": self.category1_id, "name": "电子产品"}
+ {"id": self.category1_id, "name": "电子产品"},
)
# 3. 创建店铺
self.store1_id = 20
- store_exists = self.connection.execute(text("SELECT StoreID FROM Store WHERE StoreID = :id"),
- {"id": self.store1_id}).fetchone()
+ store_exists = self.connection.execute(
+ text("SELECT StoreID FROM Store WHERE StoreID = :id"), {"id": self.store1_id}
+ ).fetchone()
if not store_exists:
self.connection.execute(
text(
- "INSERT INTO Store (StoreID, StoreName, OwnerUserID, StoreStatus, CreationDate, LastUpdatedDate) VALUES (:id, :name, :owner_id, 'ACTIVE', UTC_TIMESTAMP(), UTC_TIMESTAMP())"),
- {"id": self.store1_id, "name": "测试用户1的店铺", "owner_id": self.test_user1_id}
+ "INSERT INTO Store (StoreID, StoreName, OwnerUserID, StoreStatus, CreationDate, LastUpdatedDate) VALUES (:id, :name, :owner_id, 'ACTIVE', UTC_TIMESTAMP(), UTC_TIMESTAMP())"
+ ),
+ {"id": self.store1_id, "name": "测试用户1的店铺", "owner_id": self.test_user1_id},
)
# 4. 创建商品
self.product1_id = 301
self.product1_price = 19.99
- prod1_exists = self.connection.execute(text("SELECT ProductID FROM Product WHERE ProductID = :id"),
- {"id": self.product1_id}).fetchone()
+ prod1_exists = self.connection.execute(
+ text("SELECT ProductID FROM Product WHERE ProductID = :id"), {"id": self.product1_id}
+ ).fetchone()
if not prod1_exists:
self.connection.execute(
text(
- "INSERT INTO Product (ProductID, ProductName, Price, ProductStatus, StoreID, CategoryID, StockQuantity, CreationDate, LastUpdatedDate) VALUES (:id, :name, :price, 'ACTIVE', :store_id, :cat_id, 100, UTC_TIMESTAMP(), UTC_TIMESTAMP())"),
- {"id": self.product1_id, "name": "测试商品1", "price": self.product1_price,
- "store_id": self.store1_id, "cat_id": self.category1_id}
+ "INSERT INTO Product (ProductID, ProductName, Price, ProductStatus, StoreID, CategoryID, StockQuantity, CreationDate, LastUpdatedDate) VALUES (:id, :name, :price, 'ACTIVE', :store_id, :cat_id, 100, UTC_TIMESTAMP(), UTC_TIMESTAMP())"
+ ),
+ {
+ "id": self.product1_id,
+ "name": "测试商品1",
+ "price": self.product1_price,
+ "store_id": self.store1_id,
+ "cat_id": self.category1_id,
+ },
)
self.product2_id = 302
self.product2_price = 55.50
- prod2_exists = self.connection.execute(text("SELECT ProductID FROM Product WHERE ProductID = :id"),
- {"id": self.product2_id}).fetchone()
+ prod2_exists = self.connection.execute(
+ text("SELECT ProductID FROM Product WHERE ProductID = :id"), {"id": self.product2_id}
+ ).fetchone()
if not prod2_exists:
self.connection.execute(
text(
- "INSERT INTO Product (ProductID, ProductName, Price, ProductStatus, StoreID, CategoryID, StockQuantity, CreationDate, LastUpdatedDate) VALUES (:id, :name, :price, 'ACTIVE', :store_id, :cat_id, 50, UTC_TIMESTAMP(), UTC_TIMESTAMP())"),
- {"id": self.product2_id, "name": "测试商品2", "price": self.product2_price,
- "store_id": self.store1_id, "cat_id": self.category1_id}
+ "INSERT INTO Product (ProductID, ProductName, Price, ProductStatus, StoreID, CategoryID, StockQuantity, CreationDate, LastUpdatedDate) VALUES (:id, :name, :price, 'ACTIVE', :store_id, :cat_id, 50, UTC_TIMESTAMP(), UTC_TIMESTAMP())"
+ ),
+ {
+ "id": self.product2_id,
+ "name": "测试商品2",
+ "price": self.product2_price,
+ "store_id": self.store1_id,
+ "cat_id": self.category1_id,
+ },
)
except Exception as e:
self.fail(f"Error setting up initial test data: {e}")
@@ -210,8 +224,9 @@ def test_update_cart_item_quantity_success(self):
cart_item_id_to_update = add_resp.json()["CartItemID"]
update_payload_schema = CartItemUpdateRequest(Quantity=5)
- response = self.client.put(f"/api/v1/cart/items/{cart_item_id_to_update}",
- json=update_payload_schema.model_dump())
+ response = self.client.put(
+ f"/api/v1/cart/items/{cart_item_id_to_update}", json=update_payload_schema.model_dump()
+ )
self.assertEqual(response.status_code, status.HTTP_200_OK, response.text)
updated_item = response.json()
@@ -282,5 +297,6 @@ def test_clear_user_cart_empty_cart(self):
# 如果 CartService.clear_cart 返回了删除数量,并且端点将其放入 Detail,则可以断言
# self.assertEqual(response.json()["Detail"]["items_removed"], 0)
-if __name__ == '__main__':
+
+if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/src/backend/test/integration/api_endpoints/test_category_endpoints.py b/src/backend/test/integration/api_endpoints/test_category_endpoints.py
index 61274fe..a82d97e 100644
--- a/src/backend/test/integration/api_endpoints/test_category_endpoints.py
+++ b/src/backend/test/integration/api_endpoints/test_category_endpoints.py
@@ -12,7 +12,7 @@
def generate_random_string(length=8):
"""生成指定长度的随机字符串"""
letters_and_digits = string.ascii_letters + string.digits
- return ''.join(random.choice(letters_and_digits) for _ in range(length))
+ return "".join(random.choice(letters_and_digits) for _ in range(length))
class TestCategoryEndpoints(BaseDBTestCaseAutoRollback):
@@ -20,18 +20,20 @@ def setUp(self):
super().setUp()
# 初始化测试客户端
self.client = TestClient(app, raise_server_exceptions=False)
-
+
# 创建测试数据 - 创建用户(用于模拟认证),使用随机用户名避免唯一键冲突
random_suffix = generate_random_string()
test_user_name = f"testadmin_{random_suffix}"
test_user_email = f"{test_user_name}@example.com"
-
+
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO User (Username, Email, PasswordHash, UserRole)
VALUES (:username, :email, 'hashed_password', 'admin')
- """),
- {"username": test_user_name, "email": test_user_email}
+ """
+ ),
+ {"username": test_user_name, "email": test_user_email},
)
user_id_result = self.connection.execute(text("SELECT LAST_INSERT_ID()")).fetchone()
self.test_user_id = user_id_result[0]
@@ -39,20 +41,14 @@ def setUp(self):
def test_create_category(self):
"""测试创建分类API"""
# 准备请求数据
- category_data = {
- "CategoryName": "API测试分类",
- "CategoryDescription": "通过API创建的测试分类"
- }
-
+ category_data = {"CategoryName": "API测试分类", "CategoryDescription": "通过API创建的测试分类"}
+
# 发送请求
- response = self.client.post(
- "/api/v1/category",
- json=category_data
- )
-
+ response = self.client.post("/api/v1/category", json=category_data)
+
# 验证响应
self.assertEqual(response.status_code, 200)
-
+
# 验证响应数据
data = response.json()
self.assertIn("CategoryID", data)
@@ -63,34 +59,25 @@ def test_create_category(self):
def test_create_subcategory(self):
"""测试创建子分类API"""
# 准备数据 - 先创建父分类
- parent_data = {
- "CategoryName": "API父分类",
- "CategoryDescription": "通过API创建的父分类"
- }
-
- parent_response = self.client.post(
- "/api/v1/category",
- json=parent_data
- )
-
+ parent_data = {"CategoryName": "API父分类", "CategoryDescription": "通过API创建的父分类"}
+
+ parent_response = self.client.post("/api/v1/category", json=parent_data)
+
parent_id = parent_response.json()["CategoryID"]
-
+
# 准备子分类数据
subcategory_data = {
"CategoryName": "API子分类",
"CategoryDescription": "通过API创建的子分类",
- "ParentCategoryID": parent_id
+ "ParentCategoryID": parent_id,
}
-
+
# 发送请求
- response = self.client.post(
- "/api/v1/category",
- json=subcategory_data
- )
-
+ response = self.client.post("/api/v1/category", json=subcategory_data)
+
# 验证响应
self.assertEqual(response.status_code, 200)
-
+
# 验证响应数据
data = response.json()
self.assertEqual(data["CategoryName"], subcategory_data["CategoryName"])
@@ -98,17 +85,11 @@ def test_create_subcategory(self):
def test_create_category_with_invalid_parent(self):
"""测试使用无效的父分类ID创建分类API"""
- category_data = {
- "CategoryName": "无效父分类的分类",
- "ParentCategoryID": 999999 # 不存在的ID
- }
-
+ category_data = {"CategoryName": "无效父分类的分类", "ParentCategoryID": 999999} # 不存在的ID
+
# 发送请求
- response = self.client.post(
- "/api/v1/category",
- json=category_data
- )
-
+ response = self.client.post("/api/v1/category", json=category_data)
+
# 验证响应 - 应当返回400错误
self.assertEqual(response.status_code, 400)
self.assertIn("父分类ID", response.json()["detail"])
@@ -116,29 +97,23 @@ def test_create_category_with_invalid_parent(self):
def test_get_category_by_id(self):
"""测试通过ID获取分类API"""
# 准备数据 - 先创建一个分类
- category_data = {
- "CategoryName": "获取测试分类",
- "CategoryDescription": "用于测试获取API的分类"
- }
-
- create_response = self.client.post(
- "/api/v1/category",
- json=category_data
- )
-
+ category_data = {"CategoryName": "获取测试分类", "CategoryDescription": "用于测试获取API的分类"}
+
+ create_response = self.client.post("/api/v1/category", json=category_data)
+
test_category_id = create_response.json()["CategoryID"]
-
+
# 发送请求
response = self.client.get(f"/api/v1/category/{test_category_id}")
-
+
# 验证响应
self.assertEqual(response.status_code, 200)
-
+
# 验证响应数据
data = response.json()
self.assertEqual(data["CategoryID"], test_category_id)
self.assertEqual(data["CategoryName"], category_data["CategoryName"])
-
+
# 测试获取不存在的分类
non_existent_response = self.client.get("/api/v1/category/999999")
self.assertEqual(non_existent_response.status_code, 404)
@@ -148,41 +123,31 @@ def test_list_categories(self):
# 准备数据 - 创建一些顶级分类
for i in range(3):
self.client.post(
- "/api/v1/category",
- json={
- "CategoryName": f"顶级分类{i+1}",
- "CategoryDescription": f"顶级分类描述{i+1}"
- }
+ "/api/v1/category", json={"CategoryName": f"顶级分类{i+1}", "CategoryDescription": f"顶级分类描述{i+1}"}
)
-
+
# 发送请求 - 获取所有顶级分类
response = self.client.get("/api/v1/category")
-
+
# 验证响应
self.assertEqual(response.status_code, 200)
-
+
# 验证响应数据
data = response.json()
# 确认返回的数据中至少包含我们创建的3个分类
self.assertGreaterEqual(len(data), 3)
-
+
# 准备数据 - 为第一个分类创建子分类
parent_id = data[0]["CategoryID"]
for i in range(2):
- self.client.post(
- "/api/v1/category",
- json={
- "CategoryName": f"子分类{i+1}",
- "ParentCategoryID": parent_id
- }
- )
-
+ self.client.post("/api/v1/category", json={"CategoryName": f"子分类{i+1}", "ParentCategoryID": parent_id})
+
# 发送请求 - 获取特定父分类的子分类
response = self.client.get(f"/api/v1/category?parent_id={parent_id}")
-
+
# 验证响应
self.assertEqual(response.status_code, 200)
-
+
# 验证响应数据
data = response.json()
self.assertEqual(len(data), 2)
@@ -193,76 +158,48 @@ def test_get_category_tree(self):
"""测试获取分类树结构API"""
# 准备数据 - 创建一些分类和子分类
# 顶级分类
- response1 = self.client.post(
- "/api/v1/category",
- json={"CategoryName": "电子产品"}
- )
+ response1 = self.client.post("/api/v1/category", json={"CategoryName": "电子产品"})
cat1_id = response1.json()["CategoryID"]
-
- response2 = self.client.post(
- "/api/v1/category",
- json={"CategoryName": "服装"}
- )
+
+ response2 = self.client.post("/api/v1/category", json={"CategoryName": "服装"})
cat2_id = response2.json()["CategoryID"]
-
+
# 电子产品的子分类
- response1_1 = self.client.post(
- "/api/v1/category",
- json={
- "CategoryName": "手机",
- "ParentCategoryID": cat1_id
- }
- )
-
- response1_2 = self.client.post(
- "/api/v1/category",
- json={
- "CategoryName": "电脑",
- "ParentCategoryID": cat1_id
- }
- )
+ response1_1 = self.client.post("/api/v1/category", json={"CategoryName": "手机", "ParentCategoryID": cat1_id})
+
+ response1_2 = self.client.post("/api/v1/category", json={"CategoryName": "电脑", "ParentCategoryID": cat1_id})
cat1_2_id = response1_2.json()["CategoryID"]
-
+
# 服装的子分类
- response2_1 = self.client.post(
- "/api/v1/category",
- json={
- "CategoryName": "男装",
- "ParentCategoryID": cat2_id
- }
- )
-
+ response2_1 = self.client.post("/api/v1/category", json={"CategoryName": "男装", "ParentCategoryID": cat2_id})
+
# 电脑的子分类
response1_2_1 = self.client.post(
- "/api/v1/category",
- json={
- "CategoryName": "笔记本电脑",
- "ParentCategoryID": cat1_2_id
- }
+ "/api/v1/category", json={"CategoryName": "笔记本电脑", "ParentCategoryID": cat1_2_id}
)
-
+
# 发送请求
response = self.client.get("/api/v1/category/tree")
-
+
# 验证响应
self.assertEqual(response.status_code, 200)
-
+
# 验证响应数据 - 树结构
data = response.json()
self.assertGreaterEqual(len(data), 2) # 至少两个顶级分类
-
+
# 找到"电子产品"分类
electronic_cat = next((cat for cat in data if cat["CategoryName"] == "电子产品"), None)
self.assertIsNotNone(electronic_cat)
self.assertIn("Children", electronic_cat)
self.assertGreaterEqual(len(electronic_cat["Children"]), 2) # 至少两个子分类
-
+
# 找到"电脑"分类
computer_cat = next((cat for cat in electronic_cat["Children"] if cat["CategoryName"] == "电脑"), None)
self.assertIsNotNone(computer_cat)
self.assertIn("Children", computer_cat)
self.assertEqual(len(computer_cat["Children"]), 1) # 一个子分类
-
+
# 找到"服装"分类
clothing_cat = next((cat for cat in data if cat["CategoryName"] == "服装"), None)
self.assertIsNotNone(clothing_cat)
@@ -273,100 +210,66 @@ def test_update_category(self):
"""测试更新分类API"""
# 准备数据 - 先创建一个分类
create_response = self.client.post(
- "/api/v1/category",
- json={
- "CategoryName": "原始分类名称",
- "CategoryDescription": "原始描述"
- }
+ "/api/v1/category", json={"CategoryName": "原始分类名称", "CategoryDescription": "原始描述"}
)
-
+
test_category_id = create_response.json()["CategoryID"]
-
+
# 准备更新数据
- update_data = {
- "CategoryName": "更新后的分类名称",
- "CategoryDescription": "更新后的描述"
- }
-
+ update_data = {"CategoryName": "更新后的分类名称", "CategoryDescription": "更新后的描述"}
+
# 发送请求
- response = self.client.put(
- f"/api/v1/category/{test_category_id}",
- json=update_data
- )
-
+ response = self.client.put(f"/api/v1/category/{test_category_id}", json=update_data)
+
# 验证响应
self.assertEqual(response.status_code, 200)
-
+
# 验证响应数据
data = response.json()
self.assertEqual(data["CategoryID"], test_category_id)
self.assertEqual(data["CategoryName"], update_data["CategoryName"])
self.assertEqual(data["CategoryDescription"], update_data["CategoryDescription"])
-
+
# 测试更新不存在的分类
- non_existent_response = self.client.put(
- "/api/v1/category/999999",
- json={"CategoryName": "不存在的分类"}
- )
+ non_existent_response = self.client.put("/api/v1/category/999999", json={"CategoryName": "不存在的分类"})
self.assertEqual(non_existent_response.status_code, 404)
def test_update_category_parent(self):
"""测试更新分类的父分类API"""
# 准备数据 - 创建两个分类
- response1 = self.client.post(
- "/api/v1/category",
- json={"CategoryName": "分类1"}
- )
+ response1 = self.client.post("/api/v1/category", json={"CategoryName": "分类1"})
category1_id = response1.json()["CategoryID"]
-
- response2 = self.client.post(
- "/api/v1/category",
- json={"CategoryName": "分类2"}
- )
+
+ response2 = self.client.post("/api/v1/category", json={"CategoryName": "分类2"})
category2_id = response2.json()["CategoryID"]
-
+
# 准备更新数据 - 将分类2设为分类1的父分类
- update_data = {
- "ParentCategoryID": category2_id
- }
-
+ update_data = {"ParentCategoryID": category2_id}
+
# 发送请求
- response = self.client.put(
- f"/api/v1/category/{category1_id}",
- json=update_data
- )
-
+ response = self.client.put(f"/api/v1/category/{category1_id}", json=update_data)
+
# 验证响应
self.assertEqual(response.status_code, 200)
-
+
# 验证响应数据
data = response.json()
self.assertEqual(data["ParentCategoryID"], category2_id)
-
+
# 测试循环引用 - 将分类1设为分类2的父分类(应该失败)
- cyclic_update_data = {
- "ParentCategoryID": category1_id
- }
-
- cyclic_response = self.client.put(
- f"/api/v1/category/{category2_id}",
- json=cyclic_update_data
- )
-
+ cyclic_update_data = {"ParentCategoryID": category1_id}
+
+ cyclic_response = self.client.put(f"/api/v1/category/{category2_id}", json=cyclic_update_data)
+
# 验证响应 - 应当返回400错误
self.assertEqual(cyclic_response.status_code, 400)
self.assertIn("循环引用", cyclic_response.json()["detail"])
-
+
# 测试自引用 - 将分类设为自己的父分类(应该失败)
- self_update_data = {
- "ParentCategoryID": category1_id
- }
-
- self_response = self.client.put(
- f"/api/v1/category/{category1_id}",
- json=self_update_data
- )
-
+ self_update_data = {"ParentCategoryID": category1_id}
+
+ self_response = self.client.put(f"/api/v1/category/{category1_id}", json=self_update_data)
+
# 验证响应 - 应当返回400错误
self.assertEqual(self_response.status_code, 400)
self.assertIn("父分类不能是自己", self_response.json()["detail"])
@@ -374,23 +277,20 @@ def test_update_category_parent(self):
def test_delete_category(self):
"""测试删除分类API"""
# 准备数据 - 创建一个分类
- create_response = self.client.post(
- "/api/v1/category",
- json={"CategoryName": "准备删除的分类"}
- )
-
+ create_response = self.client.post("/api/v1/category", json={"CategoryName": "准备删除的分类"})
+
test_category_id = create_response.json()["CategoryID"]
-
+
# 发送请求
response = self.client.delete(f"/api/v1/category/{test_category_id}")
-
+
# 验证响应
self.assertEqual(response.status_code, 204) # 成功但无内容
-
+
# 验证分类已被删除
get_response = self.client.get(f"/api/v1/category/{test_category_id}")
self.assertEqual(get_response.status_code, 404)
-
+
# 测试删除不存在的分类
non_existent_response = self.client.delete("/api/v1/category/999999")
self.assertEqual(non_existent_response.status_code, 404)
@@ -398,23 +298,14 @@ def test_delete_category(self):
def test_delete_category_with_subcategories(self):
"""测试删除有子分类的分类API(应该失败)"""
# 准备数据 - 创建父分类和子分类
- parent_response = self.client.post(
- "/api/v1/category",
- json={"CategoryName": "父分类"}
- )
+ parent_response = self.client.post("/api/v1/category", json={"CategoryName": "父分类"})
parent_id = parent_response.json()["CategoryID"]
-
- self.client.post(
- "/api/v1/category",
- json={
- "CategoryName": "子分类",
- "ParentCategoryID": parent_id
- }
- )
-
+
+ self.client.post("/api/v1/category", json={"CategoryName": "子分类", "ParentCategoryID": parent_id})
+
# 发送请求 - 尝试删除有子分类的父分类
response = self.client.delete(f"/api/v1/category/{parent_id}")
-
+
# 验证响应 - 应当返回400错误
self.assertEqual(response.status_code, 400)
self.assertIn("子分类", response.json()["detail"])
@@ -430,20 +321,20 @@ def test_delete_category_with_products(self):
try:
# 创建全新的独立连接来避免与测试类的connection冲突
independent_engine = self.engine.execution_options(isolation_level="READ COMMITTED")
-
+
# 使用不同的连接和显式执行SQL,避免SQLAlchemy的自动事务管理
raw_conn = independent_engine.raw_connection()
cursor = raw_conn.cursor()
-
+
# 配置会话
cursor.execute("SET SESSION innodb_lock_wait_timeout = 120")
cursor.execute("SET SESSION transaction_isolation = 'READ-COMMITTED'")
-
+
# 创建测试用户,使用随机用户名避免唯一键冲突
random_suffix = generate_random_string()
test_user_name = f"testuser_{random_suffix}"
test_user_email = f"{test_user_name}@example.com"
-
+
# 开始显式事务
cursor.execute("BEGIN")
try:
@@ -453,12 +344,12 @@ def test_delete_category_with_products(self):
INSERT INTO User (Username, Email, PasswordHash, UserRole)
VALUES (%s, %s, 'test_password', 'merchant')
""",
- (test_user_name, test_user_email)
+ (test_user_name, test_user_email),
)
cursor.execute("SELECT LAST_INSERT_ID()")
user_id = cursor.fetchone()[0]
created_resources.append(("user", user_id))
-
+
# 创建测试分类
cursor.execute(
"""
@@ -469,50 +360,50 @@ def test_delete_category_with_products(self):
cursor.execute("SELECT LAST_INSERT_ID()")
category_id = cursor.fetchone()[0]
created_resources.append(("category", category_id))
-
+
# 创建测试店铺
cursor.execute(
"""
INSERT INTO Store (StoreName, OwnerUserID, Description, StoreStatus)
VALUES ('测试店铺', %s, '用于测试的店铺', 'ACTIVE')
""",
- (user_id,)
+ (user_id,),
)
cursor.execute("SELECT LAST_INSERT_ID()")
store_id = cursor.fetchone()[0]
created_resources.append(("store", store_id))
-
+
# 创建关联到分类的商品
cursor.execute(
"""
INSERT INTO Product (ProductName, Price, StoreID, CategoryID, StockQuantity, ProductStatus)
VALUES ('测试商品', 99.99, %s, %s, 10, 'ACTIVE')
""",
- (store_id, category_id)
+ (store_id, category_id),
)
cursor.execute("SELECT LAST_INSERT_ID()")
product_id = cursor.fetchone()[0]
created_resources.append(("product", product_id))
-
+
# 提交事务
raw_conn.commit()
-
+
# 使用测试客户端进行测试
client = TestClient(app, raise_server_exceptions=False)
-
+
# 发送请求 - 尝试删除有商品的分类
response = client.delete(f"/api/v1/category/{category_id}")
-
+
# 验证响应 - 应当返回400错误
self.assertEqual(response.status_code, 400)
self.assertIn("商品", response.json()["detail"])
-
+
except Exception as e:
# 回滚事务
raw_conn.rollback()
print(f"创建测试数据失败: {e}")
raise
-
+
except Exception as e:
print(f"测试删除有商品的分类API遇到错误: {e}")
self.fail(f"测试失败,遇到错误: {e}")
@@ -523,20 +414,20 @@ def test_delete_category_with_products(self):
cursor.close()
except Exception as e:
print(f"关闭cursor时发生错误: {e}")
-
+
if raw_conn:
try:
raw_conn.close()
except Exception as e:
print(f"关闭连接时发生错误: {e}")
-
+
# 确保测试资源被清理,防止影响其他测试
try:
if independent_engine:
# 创建新的连接进行清理
cleanup_conn = independent_engine.raw_connection()
cleanup_cursor = cleanup_conn.cursor()
-
+
try:
for resource_type, resource_id in reversed(created_resources):
try:
@@ -545,12 +436,14 @@ def test_delete_category_with_products(self):
elif resource_type == "store":
cleanup_cursor.execute("DELETE FROM Store WHERE StoreID = %s", (resource_id,))
elif resource_type == "category":
- cleanup_cursor.execute("DELETE FROM ProductCategory WHERE CategoryID = %s", (resource_id,))
+ cleanup_cursor.execute(
+ "DELETE FROM ProductCategory WHERE CategoryID = %s", (resource_id,)
+ )
elif resource_type == "user":
cleanup_cursor.execute("DELETE FROM User WHERE UserID = %s", (resource_id,))
except Exception as cleanup_error:
print(f"清理资源 {resource_type} ID:{resource_id} 时出错: {cleanup_error}")
-
+
cleanup_conn.commit()
finally:
cleanup_cursor.close()
@@ -559,5 +452,5 @@ def test_delete_category_with_products(self):
print(f"清理资源时发生错误: {e}")
-if __name__ == '__main__':
- unittest.main()
\ No newline at end of file
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/backend/test/integration/api_endpoints/test_order_endpoints_async.py b/src/backend/test/integration/api_endpoints/test_order_endpoints_async.py
index 2a9a8a3..a59716e 100644
--- a/src/backend/test/integration/api_endpoints/test_order_endpoints_async.py
+++ b/src/backend/test/integration/api_endpoints/test_order_endpoints_async.py
@@ -14,9 +14,17 @@
from backend.app.main import app
from backend.app.schemas.user_schema import UserResponse as CurrentUserSchema
from backend.app.schemas.order_schema import (
- OrderCreateRequest, OrderItemCreationInput, InitiateOrderResponse,
- OrderListResponse, OrderViewResponse, OrderStatusEnum, PaymentTransactionStatusEnum,
- CreatedOrderDetailResponse, OrderItemDetailResponse, OrderUpdateStatusRequest, OrderActionResponse
+ OrderCreateRequest,
+ OrderItemCreationInput,
+ InitiateOrderResponse,
+ OrderListResponse,
+ OrderViewResponse,
+ OrderStatusEnum,
+ PaymentTransactionStatusEnum,
+ CreatedOrderDetailResponse,
+ OrderItemDetailResponse,
+ OrderUpdateStatusRequest,
+ OrderActionResponse,
)
from backend.app.services.order_service import OrderService
from backend.app.crud.user_crud import UserCRUD
@@ -76,46 +84,79 @@ def setUpClass(cls):
with conn_for_class_setup.begin(): # Start a transaction for class setup
# 1. Create Users
users_data = [
- {"UserID": cls.user1_id_class, "Username": cls.user1_username_class,
- "PasswordHash": cls.user1_password_hash_class, "Email": cls.user1_email_class},
- {"UserID": cls.user2_id_class, "Username": cls.user2_username_class,
- "PasswordHash": cls.user2_password_hash_class, "Email": cls.user2_email_class},
+ {
+ "UserID": cls.user1_id_class,
+ "Username": cls.user1_username_class,
+ "PasswordHash": cls.user1_password_hash_class,
+ "Email": cls.user1_email_class,
+ },
+ {
+ "UserID": cls.user2_id_class,
+ "Username": cls.user2_username_class,
+ "PasswordHash": cls.user2_password_hash_class,
+ "Email": cls.user2_email_class,
+ },
]
for ud in users_data:
- conn_for_class_setup.execute(text(
- "INSERT INTO User (UserID, Username, PasswordHash, Email, UserRole, AccountStatus) VALUES (:UserID, :Username, :PasswordHash, :Email, 'customer', 'ACTIVE') ON DUPLICATE KEY UPDATE Username=VALUES(Username)"),
- ud)
+ conn_for_class_setup.execute(
+ text(
+ "INSERT INTO User (UserID, Username, PasswordHash, Email, UserRole, AccountStatus) VALUES (:UserID, :Username, :PasswordHash, :Email, 'customer', 'ACTIVE') ON DUPLICATE KEY UPDATE Username=VALUES(Username)"
+ ),
+ ud,
+ )
# 2. Create Shipping Address for User1
conn_for_class_setup.execute(
text(
- "INSERT INTO ShippingAddress (AddressID, UserID, RecipientName, PhoneNumber, FullAddress_Text, IsDefault) VALUES (:id, :uid, 'UserOne ClassRecipient', '13000000011', '1 Class St, UserOne City', TRUE) ON DUPLICATE KEY UPDATE RecipientName=VALUES(RecipientName)"),
- {"id": cls.address1_user1_id_class, "uid": cls.user1_id_class}
+ "INSERT INTO ShippingAddress (AddressID, UserID, RecipientName, PhoneNumber, FullAddress_Text, IsDefault) VALUES (:id, :uid, 'UserOne ClassRecipient', '13000000011', '1 Class St, UserOne City', TRUE) ON DUPLICATE KEY UPDATE RecipientName=VALUES(RecipientName)"
+ ),
+ {"id": cls.address1_user1_id_class, "uid": cls.user1_id_class},
+ )
+ conn_for_class_setup.execute(
+ text("UPDATE User SET DefaultAddressID = :addr_id WHERE UserID = :uid"),
+ {"addr_id": cls.address1_user1_id_class, "uid": cls.user1_id_class},
)
- conn_for_class_setup.execute(text("UPDATE User SET DefaultAddressID = :addr_id WHERE UserID = :uid"),
- {"addr_id": cls.address1_user1_id_class, "uid": cls.user1_id_class})
# 3. Create Category & Store
- conn_for_class_setup.execute(text(
- "INSERT INTO ProductCategory (CategoryID, CategoryName) VALUES (:id, 'Class Test Category') ON DUPLICATE KEY UPDATE CategoryName=VALUES(CategoryName)"),
- {"id": cls.category_id_class})
- conn_for_class_setup.execute(text(
- "INSERT INTO Store (StoreID, StoreName, OwnerUserID) VALUES (:id, 'Class Test Store', :owner_id) ON DUPLICATE KEY UPDATE StoreName=VALUES(StoreName)"),
- {"id": cls.store_id_class, "owner_id": cls.user1_id_class})
+ conn_for_class_setup.execute(
+ text(
+ "INSERT INTO ProductCategory (CategoryID, CategoryName) VALUES (:id, 'Class Test Category') ON DUPLICATE KEY UPDATE CategoryName=VALUES(CategoryName)"
+ ),
+ {"id": cls.category_id_class},
+ )
+ conn_for_class_setup.execute(
+ text(
+ "INSERT INTO Store (StoreID, StoreName, OwnerUserID) VALUES (:id, 'Class Test Store', :owner_id) ON DUPLICATE KEY UPDATE StoreName=VALUES(StoreName)"
+ ),
+ {"id": cls.store_id_class, "owner_id": cls.user1_id_class},
+ )
# 4. Create Products
products_data_class = [
- {"ProductID": cls.product1_id_class, "ProductName": "Class Test Product 1",
- "Price": cls.product1_price_class, "StockQuantity": cls.product1_stock_class,
- "StoreID": cls.store_id_class, "CategoryID": cls.category_id_class},
- {"ProductID": cls.product2_id_class, "ProductName": "Class Test Product 2",
- "Price": cls.product2_price_class, "StockQuantity": cls.product2_stock_class,
- "StoreID": cls.store_id_class, "CategoryID": cls.category_id_class},
+ {
+ "ProductID": cls.product1_id_class,
+ "ProductName": "Class Test Product 1",
+ "Price": cls.product1_price_class,
+ "StockQuantity": cls.product1_stock_class,
+ "StoreID": cls.store_id_class,
+ "CategoryID": cls.category_id_class,
+ },
+ {
+ "ProductID": cls.product2_id_class,
+ "ProductName": "Class Test Product 2",
+ "Price": cls.product2_price_class,
+ "StockQuantity": cls.product2_stock_class,
+ "StoreID": cls.store_id_class,
+ "CategoryID": cls.category_id_class,
+ },
]
for p_data in products_data_class:
- conn_for_class_setup.execute(text(
- "INSERT INTO Product (ProductID, ProductName, Price, StoreID, CategoryID, StockQuantity) VALUES (:ProductID, :ProductName, :Price, :StoreID, :CategoryID, :StockQuantity) ON DUPLICATE KEY UPDATE ProductName=VALUES(ProductName)"),
- p_data)
+ conn_for_class_setup.execute(
+ text(
+ "INSERT INTO Product (ProductID, ProductName, Price, StoreID, CategoryID, StockQuantity) VALUES (:ProductID, :ProductName, :Price, :StoreID, :CategoryID, :StockQuantity) ON DUPLICATE KEY UPDATE ProductName=VALUES(ProductName)"
+ ),
+ p_data,
+ )
# Transaction committed by exiting 'with' block
logger.info(f"--- {cls.__name__}: Shared class-level data committed ---")
except Exception as e:
@@ -136,18 +177,18 @@ def setUp(self):
Username=self.user1_username_class,
Email=self.user1_email_class,
PhoneNumber=None,
- UserRole='customer',
+ UserRole="customer",
RegistrationDate=current_utc_time,
- LastLoginDate=current_utc_time
+ LastLoginDate=current_utc_time,
)
self.mock_current_user2_schema = CurrentUserSchema(
UserID=self.user2_id_class, # Use class-level ID
Username=self.user2_username_class,
Email=self.user2_email_class,
PhoneNumber=None,
- UserRole='customer',
+ UserRole="customer",
RegistrationDate=current_utc_time,
- LastLoginDate=current_utc_time
+ LastLoginDate=current_utc_time,
)
# Dependency Overrides
@@ -167,7 +208,7 @@ def override_get_db_connection() -> Connection:
address_crud=self.real_address_crud,
cart_item_crud=self.real_cart_item_crud,
product_crud=self.real_product_crud,
- payment_transaction_crud=self.real_payment_transaction_crud
+ payment_transaction_crud=self.real_payment_transaction_crud,
)
def override_get_order_service() -> OrderService:
@@ -191,17 +232,32 @@ def _setup_per_test_dynamic_data(self):
self.cart_item1_id_test = 501
self.cart_item2_id_test = 502
cart_items_data = [
- {"CartItemID": self.cart_item1_id_test, "UserID": self.user1_id_class,
- "ProductID": self.product1_id_class, "Quantity": 2, "PriceAtAddition": self.product1_price_class},
- {"CartItemID": self.cart_item2_id_test, "UserID": self.user1_id_class,
- "ProductID": self.product2_id_class, "Quantity": 1, "PriceAtAddition": self.product2_price_class},
+ {
+ "CartItemID": self.cart_item1_id_test,
+ "UserID": self.user1_id_class,
+ "ProductID": self.product1_id_class,
+ "Quantity": 2,
+ "PriceAtAddition": self.product1_price_class,
+ },
+ {
+ "CartItemID": self.cart_item2_id_test,
+ "UserID": self.user1_id_class,
+ "ProductID": self.product2_id_class,
+ "Quantity": 1,
+ "PriceAtAddition": self.product2_price_class,
+ },
]
for ci_data in cart_items_data:
- if not self.connection.execute(text("SELECT CartItemID FROM CartItem WHERE CartItemID = :CartItemID"),
- {"CartItemID": ci_data["CartItemID"]}).fetchone():
- self.connection.execute(text(
- "INSERT INTO CartItem (CartItemID, UserID, ProductID, Quantity, PriceAtAddition) VALUES (:CartItemID, :UserID, :ProductID, :Quantity, :PriceAtAddition)"),
- ci_data)
+ if not self.connection.execute(
+ text("SELECT CartItemID FROM CartItem WHERE CartItemID = :CartItemID"),
+ {"CartItemID": ci_data["CartItemID"]},
+ ).fetchone():
+ self.connection.execute(
+ text(
+ "INSERT INTO CartItem (CartItemID, UserID, ProductID, Quantity, PriceAtAddition) VALUES (:CartItemID, :UserID, :ProductID, :Quantity, :PriceAtAddition)"
+ ),
+ ci_data,
+ )
except Exception as e:
self.fail(f"Error setting up per-test dynamic data: {e}")
@@ -239,9 +295,9 @@ async def test_create_order_successfully(self):
ShippingAddressID=self.address1_user1_id_class, # Use class-level address ID
Items=[
OrderItemCreationInput(CartItemID=self.cart_item1_id_test), # Use per-test cart item ID
- OrderItemCreationInput(CartItemID=self.cart_item2_id_test)
+ OrderItemCreationInput(CartItemID=self.cart_item2_id_test),
],
- Notes_ByUser="Please pack carefully."
+ Notes_ByUser="Please pack carefully.",
)
response = self.client.post("/api/v1/order/create", json=order_create_payload.model_dump())
@@ -252,13 +308,15 @@ async def test_create_order_successfully(self):
self.assertIsNotNone(init_resp.PaymentTransactionID)
self.assertGreater(len(init_resp.OrdersCreated), 0)
- pt_db = self.real_payment_transaction_crud.get_payment_transaction_by_id(self.connection,
- payment_transaction_id=init_resp.PaymentTransactionID)
+ pt_db = self.real_payment_transaction_crud.get_payment_transaction_by_id(
+ self.connection, payment_transaction_id=init_resp.PaymentTransactionID
+ )
self.assertIsNotNone(pt_db)
self.assertEqual(pt_db["UserID"], self.user1_id_class)
self.assertEqual(pt_db["Status"], PaymentTransactionStatusEnum.PENDING.value)
expected_total_due = (self.product1_price_class * 2) + (
- self.product2_price_class * 1) # Based on cart item quantities
+ self.product2_price_class * 1
+ ) # Based on cart item quantities
self.assertAlmostEqual(Decimal(pt_db["TotalAmount"]), expected_total_due, places=2)
# get the address text from the address table
@@ -275,8 +333,9 @@ async def test_create_order_successfully(self):
self.assertEqual(order_db["ShippingAddress_PhoneNumber"], address_db["PhoneNumber"])
self.assertEqual(order_db["ShippingAddress_Full"], address_db["FullAddress_Text"])
- order_items_db = self.real_order_item_crud.get_order_items_by_order_id(self.connection,
- order_id=created_order_resp.OrderID)
+ order_items_db = self.real_order_item_crud.get_order_items_by_order_id(
+ self.connection, order_id=created_order_resp.OrderID
+ )
self.assertGreater(len(order_items_db), 0)
prod1_after = self.real_product_crud.get_product_by_id(self.connection, product_id=self.product1_id_class)
@@ -285,9 +344,11 @@ async def test_create_order_successfully(self):
self.assertEqual(prod2_after["StockQuantity"], self.product2_stock_class - 1)
self.assertIsNone(
- self.real_cart_item_crud.get_cart_item_by_id(self.connection, cart_item_id=self.cart_item1_id_test))
+ self.real_cart_item_crud.get_cart_item_by_id(self.connection, cart_item_id=self.cart_item1_id_test)
+ )
self.assertIsNone(
- self.real_cart_item_crud.get_cart_item_by_id(self.connection, cart_item_id=self.cart_item2_id_test))
+ self.real_cart_item_crud.get_cart_item_by_id(self.connection, cart_item_id=self.cart_item2_id_test)
+ )
# ... (rest of your test methods, ensuring they use class-level IDs for shared data
# and per-test IDs for dynamic data like cart items created in _setup_per_test_dynamic_data)
@@ -297,13 +358,13 @@ async def test_create_order_insufficient_stock(self):
stock_set_by_test = 1
self.connection.execute(
text("UPDATE Product SET StockQuantity = :stock WHERE ProductID = :pid"),
- {"stock": stock_set_by_test, "pid": self.product1_id_class}
+ {"stock": stock_set_by_test, "pid": self.product1_id_class},
)
order_create_payload = OrderCreateRequest(
ShippingAddressID=self.address1_user1_id_class,
Items=[OrderItemCreationInput(CartItemID=self.cart_item1_id_test)], # CartItem1 requests Quantity=2
- Notes_ByUser="Test notes for insufficient stock"
+ Notes_ByUser="Test notes for insufficient stock",
)
response = self.client.post("/api/v1/order/create", json=order_create_payload.model_dump())
@@ -314,8 +375,9 @@ async def test_create_order_insufficient_stock(self):
# The service's nested transaction for stock deduction should have rolled back.
# The stock should be what the test method set it to (1),
# as the outer test transaction is still active at this point.
- prod1_after_failed_order = self.real_product_crud.get_product_by_id(self.connection,
- product_id=self.product1_id_class)
+ prod1_after_failed_order = self.real_product_crud.get_product_by_id(
+ self.connection, product_id=self.product1_id_class
+ )
self.assertIsNotNone(prod1_after_failed_order)
self.assertEqual(prod1_after_failed_order["StockQuantity"], stock_set_by_test) # ⭐ Corrected Assertion
@@ -324,7 +386,7 @@ async def test_get_my_orders_success_with_orders(self):
# Create an order using class-level data
order_create_payload = OrderCreateRequest(
ShippingAddressID=self.address1_user1_id_class,
- Items=[OrderItemCreationInput(CartItemID=self.cart_item1_id_test)]
+ Items=[OrderItemCreationInput(CartItemID=self.cart_item1_id_test)],
)
create_response = self.client.post("/api/v1/order/create", json=order_create_payload.model_dump())
self.assertEqual(create_response.status_code, 201)
@@ -348,8 +410,10 @@ async def test_get_my_orders_success_no_orders_for_user2(self):
# --- III. GET /{order_id} (Get Order Details) ---
async def test_get_order_details_success_own_order(self):
- create_payload = OrderCreateRequest(ShippingAddressID=self.address1_user1_id_class,
- Items=[OrderItemCreationInput(CartItemID=self.cart_item1_id_test)])
+ create_payload = OrderCreateRequest(
+ ShippingAddressID=self.address1_user1_id_class,
+ Items=[OrderItemCreationInput(CartItemID=self.cart_item1_id_test)],
+ )
create_resp = self.client.post("/api/v1/order/create", json=create_payload.model_dump())
order_id = create_resp.json()["OrdersCreated"][0]["OrderID"]
@@ -361,8 +425,10 @@ async def test_get_order_details_success_own_order(self):
async def test_get_order_details_not_owned_by_user(self):
# User1 (default current_user) creates an order
- create_payload = OrderCreateRequest(ShippingAddressID=self.address1_user1_id_class,
- Items=[OrderItemCreationInput(CartItemID=self.cart_item1_id_test)])
+ create_payload = OrderCreateRequest(
+ ShippingAddressID=self.address1_user1_id_class,
+ Items=[OrderItemCreationInput(CartItemID=self.cart_item1_id_test)],
+ )
create_resp_user1 = self.client.post("/api/v1/order/create", json=create_payload.model_dump())
order_id_user1 = create_resp_user1.json()["OrdersCreated"][0]["OrderID"]
@@ -375,18 +441,21 @@ async def test_get_order_details_not_owned_by_user(self):
# (Keep existing status update tests, ensuring they use class-level order IDs if appropriate)
# Example:
async def test_update_order_status_user_cancels_pending_payment_order(self):
- create_payload = OrderCreateRequest(ShippingAddressID=self.address1_user1_id_class,
- Items=[OrderItemCreationInput(CartItemID=self.cart_item1_id_test)])
+ create_payload = OrderCreateRequest(
+ ShippingAddressID=self.address1_user1_id_class,
+ Items=[OrderItemCreationInput(CartItemID=self.cart_item1_id_test)],
+ )
create_resp = self.client.post("/api/v1/order/create", json=create_payload.model_dump())
order_id = create_resp.json()["OrdersCreated"][0]["OrderID"]
- update_payload = OrderUpdateStatusRequest(NewStatus=OrderStatusEnum.CANCELLED_BY_USER,
- UserNotes="Changed my mind")
+ update_payload = OrderUpdateStatusRequest(
+ NewStatus=OrderStatusEnum.CANCELLED_BY_USER, UserNotes="Changed my mind"
+ )
response = self.client.put(f"/api/v1/order/{order_id}/status", json=update_payload.model_dump())
self.assertEqual(response.status_code, status.HTTP_200_OK, response.text)
# ... rest of assertions ...
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/src/backend/test/integration/api_endpoints/test_payment_endpoints_async.py b/src/backend/test/integration/api_endpoints/test_payment_endpoints_async.py
index c2d64c0..e6645f3 100644
--- a/src/backend/test/integration/api_endpoints/test_payment_endpoints_async.py
+++ b/src/backend/test/integration/api_endpoints/test_payment_endpoints_async.py
@@ -19,12 +19,10 @@
OrderCreateRequest,
OrderItemCreationInput,
OrderStatusEnum,
- PaymentTransactionStatusEnum, InitiateOrderResponse # Make sure this is accessible
-)
-from backend.app.schemas.payment_schema import (
- SimulatedExternPaymentResponse,
- PaymentProcessingResponse
+ PaymentTransactionStatusEnum,
+ InitiateOrderResponse, # Make sure this is accessible
)
+from backend.app.schemas.payment_schema import SimulatedExternPaymentResponse, PaymentProcessingResponse
from backend.app.services.order_service import OrderService # For direct calls in setup
from backend.app.crud.user_crud import UserCRUD
from backend.app.crud.product_crud import ProductCRUD
@@ -74,34 +72,55 @@ def _create_shared_class_data(cls, conn: Connection):
# 1. Create User1
conn.execute(
text(
- "INSERT INTO User (UserID, Username, PasswordHash, Email, UserRole, AccountStatus, RegistrationDate) VALUES (:UserID, :Username, :PasswordHash, :Email, 'customer', 'ACTIVE', UTC_TIMESTAMP()) ON DUPLICATE KEY UPDATE Username=VALUES(Username)"),
- {"UserID": cls.user1_id_class, "Username": cls.user1_username_class,
- "PasswordHash": cls.user1_password_hash_class, "Email": cls.user1_email_class}
+ "INSERT INTO User (UserID, Username, PasswordHash, Email, UserRole, AccountStatus, RegistrationDate) VALUES (:UserID, :Username, :PasswordHash, :Email, 'customer', 'ACTIVE', UTC_TIMESTAMP()) ON DUPLICATE KEY UPDATE Username=VALUES(Username)"
+ ),
+ {
+ "UserID": cls.user1_id_class,
+ "Username": cls.user1_username_class,
+ "PasswordHash": cls.user1_password_hash_class,
+ "Email": cls.user1_email_class,
+ },
)
# 2. Create Address for User1
conn.execute(
text(
- "INSERT INTO ShippingAddress (AddressID, UserID, RecipientName, PhoneNumber, FullAddress_Text, IsDefault) VALUES (:id, :uid, 'Class User Recipient', '130CLASS001', '1 Class St, Setup City', TRUE) ON DUPLICATE KEY UPDATE RecipientName=VALUES(RecipientName)"),
- {"id": cls.address1_user1_id_class, "uid": cls.user1_id_class}
+ "INSERT INTO ShippingAddress (AddressID, UserID, RecipientName, PhoneNumber, FullAddress_Text, IsDefault) VALUES (:id, :uid, 'Class User Recipient', '130CLASS001', '1 Class St, Setup City', TRUE) ON DUPLICATE KEY UPDATE RecipientName=VALUES(RecipientName)"
+ ),
+ {"id": cls.address1_user1_id_class, "uid": cls.user1_id_class},
+ )
+ conn.execute(
+ text("UPDATE User SET DefaultAddressID = :addr_id WHERE UserID = :uid"),
+ {"addr_id": cls.address1_user1_id_class, "uid": cls.user1_id_class},
)
- conn.execute(text("UPDATE User SET DefaultAddressID = :addr_id WHERE UserID = :uid"),
- {"addr_id": cls.address1_user1_id_class, "uid": cls.user1_id_class})
# 3. Create Category & Store
- conn.execute(text(
- "INSERT INTO ProductCategory (CategoryID, CategoryName) VALUES (:id, 'Class Pay Category') ON DUPLICATE KEY UPDATE CategoryName=VALUES(CategoryName)"),
- {"id": cls.category_id_class})
- conn.execute(text(
- "INSERT INTO Store (StoreID, StoreName, OwnerUserID) VALUES (:id, 'Class Pay Store', :owner_id) ON DUPLICATE KEY UPDATE StoreName=VALUES(StoreName)"),
- {"id": cls.store_id_class, "owner_id": cls.user1_id_class})
+ conn.execute(
+ text(
+ "INSERT INTO ProductCategory (CategoryID, CategoryName) VALUES (:id, 'Class Pay Category') ON DUPLICATE KEY UPDATE CategoryName=VALUES(CategoryName)"
+ ),
+ {"id": cls.category_id_class},
+ )
+ conn.execute(
+ text(
+ "INSERT INTO Store (StoreID, StoreName, OwnerUserID) VALUES (:id, 'Class Pay Store', :owner_id) ON DUPLICATE KEY UPDATE StoreName=VALUES(StoreName)"
+ ),
+ {"id": cls.store_id_class, "owner_id": cls.user1_id_class},
+ )
# 4. Create Product1
conn.execute(
text(
- "INSERT INTO Product (ProductID, ProductName, Price, StoreID, CategoryID, StockQuantity) VALUES (:id, :name, :price, :sid, :cid, :stock) ON DUPLICATE KEY UPDATE ProductName=VALUES(ProductName)"),
- {"id": cls.product1_id_class, "name": "Class Pay Product 1", "price": cls.product1_price_class,
- "sid": cls.store_id_class, "cid": cls.category_id_class, "stock": cls.product1_stock_class}
+ "INSERT INTO Product (ProductID, ProductName, Price, StoreID, CategoryID, StockQuantity) VALUES (:id, :name, :price, :sid, :cid, :stock) ON DUPLICATE KEY UPDATE ProductName=VALUES(ProductName)"
+ ),
+ {
+ "id": cls.product1_id_class,
+ "name": "Class Pay Product 1",
+ "price": cls.product1_price_class,
+ "sid": cls.store_id_class,
+ "cid": cls.category_id_class,
+ "stock": cls.product1_stock_class,
+ },
)
logger.info(f"--- {cls.__name__}: Shared class-level data creation complete ---")
except Exception as e:
@@ -137,9 +156,9 @@ def setUp(self):
Username=self.user1_username_class,
Email=self.user1_email_class,
PhoneNumber=None,
- UserRole='customer',
+ UserRole="customer",
RegistrationDate=current_utc_time,
- LastLoginDate=current_utc_time
+ LastLoginDate=current_utc_time,
)
# Dependency Overrides
@@ -160,11 +179,9 @@ def override_get_db_connection() -> Connection:
address_crud=self.real_address_crud,
cart_item_crud=self.real_cart_item_crud,
product_crud=self.real_product_crud,
- payment_transaction_crud=self.real_payment_transaction_crud
+ payment_transaction_crud=self.real_payment_transaction_crud,
)
-
-
def override_get_order_service() -> OrderService:
return self.real_order_service
@@ -185,22 +202,26 @@ def _setup_per_test_payment_transaction(self):
# 1. 为 User1 创建购物车项目 (在当前测试的事务内)
self.cart_item_id_for_test = 5001 # 使用不同的ID范围以避免与可能的setUpClass数据冲突
cart_item_exists = self.connection.execute(
- text("SELECT CartItemID FROM CartItem WHERE CartItemID = :id"),
- {"id": self.cart_item_id_for_test}
+ text("SELECT CartItemID FROM CartItem WHERE CartItemID = :id"), {"id": self.cart_item_id_for_test}
).fetchone()
if not cart_item_exists:
self.connection.execute(
text(
- "INSERT INTO CartItem (CartItemID, UserID, ProductID, Quantity, PriceAtAddition) VALUES (:id, :uid, :pid, 1, :price)"),
- {"id": self.cart_item_id_for_test, "uid": self.user1_id_class, "pid": self.product1_id_class,
- "price": self.product1_price_class}
+ "INSERT INTO CartItem (CartItemID, UserID, ProductID, Quantity, PriceAtAddition) VALUES (:id, :uid, :pid, 1, :price)"
+ ),
+ {
+ "id": self.cart_item_id_for_test,
+ "uid": self.user1_id_class,
+ "pid": self.product1_id_class,
+ "price": self.product1_price_class,
+ },
)
# 2. 使用 OrderService 创建订单和支付事务 (这将使用 self.connection)
order_create_req = OrderCreateRequest(
ShippingAddressID=self.address1_user1_id_class,
Items=[OrderItemCreationInput(CartItemID=self.cart_item_id_for_test)],
- Notes_ByUser=f"Order for test {self.id()}"
+ Notes_ByUser=f"Order for test {self.id()}",
)
# 由于 setUp 是同步的,而 OrderService 方法是异步的,我们需要在这里运行异步代码
@@ -213,7 +234,7 @@ async def create_order_async():
db=self.connection, # 使用当前测试的事务性连接
user_id=self.user1_id_class,
order_create_request=order_create_req,
- actor_id=self.user1_id_class
+ actor_id=self.user1_id_class,
)
# IsolatedAsyncioTestCase 会为每个 async def test_... 创建事件循环
@@ -234,7 +255,8 @@ async def create_order_async():
self.pending_payment_transaction_id = init_order_resp_model.PaymentTransactionID
self.pending_order_id = init_order_resp_model.OrdersCreated[0].OrderID
logger.info(
- f"INFO: {self.id()} - Created pending PaymentTransactionID: {self.pending_payment_transaction_id} for OrderID: {self.pending_order_id}")
+ f"INFO: {self.id()} - Created pending PaymentTransactionID: {self.pending_payment_transaction_id} for OrderID: {self.pending_order_id}"
+ )
except Exception as e:
self.fail(f"Error setting up per-test payment transaction for {self.id()}: {e}")
@@ -246,8 +268,7 @@ def tearDown(self):
# --- Test POST /{PaymentTransactionID}/simulate-pay ---
async def test_simulate_payment_processing_success(self):
payload = SimulatedExternPaymentResponse(
- SimulatedPaymentMethod="mock_success_pay",
- ExternalGatewayTxID="gw_sim_success_123"
+ SimulatedPaymentMethod="mock_success_pay", ExternalGatewayTxID="gw_sim_success_123"
).model_dump()
response = self.client.post(f"/api/v1/payment/{self.pending_payment_transaction_id}/simulate-pay", json=payload)
@@ -261,8 +282,9 @@ async def test_simulate_payment_processing_success(self):
self.assertIn(self.pending_order_id, data["AffectedOrderIDs"])
# Verify DB: PaymentTransaction status
- pt_db = self.real_payment_transaction_crud.get_payment_transaction_by_id(self.connection,
- payment_transaction_id=self.pending_payment_transaction_id)
+ pt_db = self.real_payment_transaction_crud.get_payment_transaction_by_id(
+ self.connection, payment_transaction_id=self.pending_payment_transaction_id
+ )
self.assertEqual(pt_db["Status"], PaymentTransactionStatusEnum.SUCCESSFUL.value) # type: ignore
self.assertIsNotNone(pt_db["CompletionTime"]) # type: ignore
self.assertEqual(pt_db["ExternalGatewayTransactionID"], "gw_sim_success_123") # type: ignore
@@ -294,15 +316,16 @@ async def test_simulate_payment_processing_tx_not_pending(self):
payment_transaction_id=self.pending_payment_transaction_id,
new_status=PaymentTransactionStatusEnum.SUCCESSFUL.value,
actor_id=self.system_actor_id, # System action
- completion_time=datetime.datetime.now(datetime.timezone.utc)
+ completion_time=datetime.datetime.now(datetime.timezone.utc),
)
# self.connection.commit() # Not needed with transactional tests
payload = SimulatedExternPaymentResponse(SimulatedPaymentMethod="mock_success_pay").model_dump()
response = self.client.post(f"/api/v1/payment/{self.pending_payment_transaction_id}/simulate-pay", json=payload)
- self.assertEqual(response.status_code, status.HTTP_200_OK,
- response.text) # Endpoint currently returns 200 for idempotency
+ self.assertEqual(
+ response.status_code, status.HTTP_200_OK, response.text
+ ) # Endpoint currently returns 200 for idempotency
data = response.json()
self.assertEqual(data["TransactionStatusInSystem"], PaymentTransactionStatusEnum.SUCCESSFUL.value)
self.assertIn("支付已成功确认", data["MessageToUser"]) # Message for already successful
@@ -335,9 +358,9 @@ async def test_get_payment_transaction_status_not_owned_by_user(self):
Username="pay_cls_user2",
Email="user2@test.email",
PhoneNumber=None,
- UserRole='customer',
+ UserRole="customer",
RegistrationDate=datetime.datetime.now(datetime.timezone.utc),
- LastLoginDate=datetime.datetime.now(datetime.timezone.utc)
+ LastLoginDate=datetime.datetime.now(datetime.timezone.utc),
)
# Override current user to be user2
@@ -349,8 +372,9 @@ async def test_get_payment_transaction_status_not_owned_by_user(self):
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND, response.text)
self.assertIn(
f"Payment transaction with ID {self.pending_payment_transaction_id} not found for user {mock_current_user2_schema.UserID}",
- response.json()["detail"])
+ response.json()["detail"],
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/src/backend/test/integration/api_endpoints/test_product_change_request_endpoints_v2.py b/src/backend/test/integration/api_endpoints/test_product_change_request_endpoints_v2.py
index f910c55..cb8054e 100644
--- a/src/backend/test/integration/api_endpoints/test_product_change_request_endpoints_v2.py
+++ b/src/backend/test/integration/api_endpoints/test_product_change_request_endpoints_v2.py
@@ -183,18 +183,14 @@ def override_get_db_connection() -> Connection:
user_crud=self.real_user_crud,
)
- def override_get_pcr_service() -> (
- ProductChangeRequestService2
- ): # ⭐ 使用正确的服务依赖函数名
+ def override_get_pcr_service() -> ProductChangeRequestService2: # ⭐ 使用正确的服务依赖函数名
return self.real_pcr_service
async def override_get_current_active_user_default() -> CurrentUserSchema:
return self.mock_merchant_user_schema
app.dependency_overrides[get_db_connection] = override_get_db_connection
- app.dependency_overrides[get_pcr_service] = (
- override_get_pcr_service # ⭐ 使用正确的服务依赖函数名
- )
+ app.dependency_overrides[get_pcr_service] = override_get_pcr_service # ⭐ 使用正确的服务依赖函数名
app.dependency_overrides[get_current_active_user] = override_get_current_active_user_default
self.client = TestClient(app)
@@ -261,10 +257,7 @@ async def _create_direct_pcr(
created_pcr_dict = created_pcr_resp.model_dump()
- if (
- status != RequestStatusEnum.PENDING_APPROVAL
- and created_pcr_dict.get("Status") != status.value
- ):
+ if status != RequestStatusEnum.PENDING_APPROVAL and created_pcr_dict.get("Status") != status.value:
# If a specific status other than default PENDING_APPROVAL is needed for setup,
# update it directly via CRUD (simulating an admin action for setup simplicity).
# This bypasses service layer status transition logic for setup.
@@ -300,11 +293,8 @@ async def test_submit_pcr_product_create_success(self):
),
SubmitterNotes="Please approve this amazing gadget for integration test!",
ProductID=None,
-
- )
- response = self.client.post(
- "/api/v1/product-change-new/", json=payload.model_dump(mode="json")
)
+ response = self.client.post("/api/v1/product-change-new/", json=payload.model_dump(mode="json"))
self.assertEqual(response.status_code, status.HTTP_201_CREATED, response.text)
data = response.json()
self.assertEqual(data["RequestType"], RequestTypeEnum.PRODUCT_CREATE.value)
@@ -324,9 +314,7 @@ async def test_submit_pcr_product_update_success(self):
ProposedData_JSON=ProposedProductData(Price=Decimal("188.88"), StockQuantity=42),
SubmitterNotes="Price and stock update for integ test.",
)
- response = self.client.post(
- "/api/v1/product-change-new/", json=payload.model_dump(mode="json")
- )
+ response = self.client.post("/api/v1/product-change-new/", json=payload.model_dump(mode="json"))
self.assertEqual(response.status_code, status.HTTP_201_CREATED, response.text)
data = response.json()
self.assertEqual(data["RequestType"], RequestTypeEnum.PRODUCT_UPDATE.value)
@@ -354,15 +342,8 @@ async def test_list_pcr_for_merchant_success_filters_by_status(self):
data = response.json()
logger.debug(f"List PCR response data: {json.dumps(data, indent=2)}")
self.assertGreaterEqual(data["TotalCount"], 1)
- self.assertTrue(
- all(req["MerchantUserID"] == self.merchant_user_id_cls for req in data["Requests"])
- )
- self.assertTrue(
- all(
- req["Status"] == RequestStatusEnum.PENDING_APPROVAL.value
- for req in data["Requests"]
- )
- )
+ self.assertTrue(all(req["MerchantUserID"] == self.merchant_user_id_cls for req in data["Requests"]))
+ self.assertTrue(all(req["Status"] == RequestStatusEnum.PENDING_APPROVAL.value for req in data["Requests"]))
# --- III. GET /list-admin/ (list_product_change_requests by admin) ---
async def test_list_pcr_for_admin_can_filter_by_merchant(self):
@@ -372,15 +353,11 @@ async def test_list_pcr_for_admin_can_filter_by_merchant(self):
request_type=RequestTypeEnum.PRODUCT_CREATE, merchant_id=self.merchant_user_id_cls
)
- response = self.client.get(
- f"/api/v1/product-change-new/list-admin/?MerchantUserID={self.merchant_user_id_cls}"
- )
+ response = self.client.get(f"/api/v1/product-change-new/list-admin/?MerchantUserID={self.merchant_user_id_cls}")
self.assertEqual(response.status_code, status.HTTP_200_OK, response.text)
data = response.json()
self.assertGreaterEqual(data["TotalCount"], 1)
- self.assertTrue(
- all(req["MerchantUserID"] == self.merchant_user_id_cls for req in data["Requests"])
- )
+ self.assertTrue(all(req["MerchantUserID"] == self.merchant_user_id_cls for req in data["Requests"]))
# --- IV. GET /{change_request_id} ---
async def test_get_pcr_details_success_owner(self):
@@ -426,12 +403,10 @@ async def test_admin_review_request_approve_and_apply_create_success(self):
self.assertEqual(response.status_code, status.HTTP_200_OK, response.text)
data = response.json()
self.assertEqual(data["Status"], RequestStatusEnum.APPLIED.value)
- self.assertIsNotNone(data["ProductID"]) # Ensure product was created
+ self.assertIsNotNone(data["ProductID"]) # Ensure product was created
new_product_id = data["ProductID"]
- product_db = self.real_product_crud.get_product_by_id(
- self.connection, product_id=new_product_id
- )
+ product_db = self.real_product_crud.get_product_by_id(self.connection, product_id=new_product_id)
self.assertIsNotNone(product_db)
self.assertEqual(product_db["ProductName"], "Gadget Alpha API Integ")
self.assertEqual(product_db["StoreID"], self.store_id_cls)
diff --git a/src/backend/test/integration/api_endpoints/test_product_endpoints.py b/src/backend/test/integration/api_endpoints/test_product_endpoints.py
index f77e1ff..a388b7d 100644
--- a/src/backend/test/integration/api_endpoints/test_product_endpoints.py
+++ b/src/backend/test/integration/api_endpoints/test_product_endpoints.py
@@ -14,7 +14,7 @@
def generate_random_string(length=8):
"""生成指定长度的随机字符串"""
letters_and_digits = string.ascii_letters + string.digits
- return ''.join(random.choice(letters_and_digits) for _ in range(length))
+ return "".join(random.choice(letters_and_digits) for _ in range(length))
class TestProductEndpoints(BaseDBTestCaseAutoRollback):
@@ -22,104 +22,113 @@ def setUp(self):
super().setUp()
# 初始化测试客户端
self.client = TestClient(app, raise_server_exceptions=False)
-
+
# 初始化CRUD实例
self.category_crud = get_category_crud_instance()
-
+
# 使用随机字符串避免唯一键冲突
random_suffix = generate_random_string()
-
+
# 创建测试数据 - 创建分类
self.test_category = self.category_crud.create_category(
- conn=self.connection,
- category_name="测试分类",
- category_description="用于测试的商品分类",
- actor_id=None
+ conn=self.connection, category_name="测试分类", category_description="用于测试的商品分类", actor_id=None
)
-
+
# 创建测试数据 - 创建用户(用于模拟认证)
test_user_name = f"testuser_{random_suffix}"
test_user_email = f"{test_user_name}@example.com"
-
+
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO User (Username, Email, PasswordHash, UserRole)
VALUES (:username, :email, 'hashed_password', 'merchant')
- """),
- {"username": test_user_name, "email": test_user_email}
+ """
+ ),
+ {"username": test_user_name, "email": test_user_email},
)
user_id_result = self.connection.execute(text("SELECT LAST_INSERT_ID()")).fetchone()
self.test_user_id = user_id_result[0]
-
+
# 创建测试数据 - 创建店铺
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO Store (StoreName, OwnerUserID, Description, StoreStatus)
VALUES ('测试店铺', :user_id, '用于测试的店铺', 'ACTIVE')
- """),
- {"user_id": self.test_user_id}
+ """
+ ),
+ {"user_id": self.test_user_id},
)
store_id_result = self.connection.execute(text("SELECT LAST_INSERT_ID()")).fetchone()
self.test_store_id = store_id_result[0]
-
+
# 创建一些基本的测试商品
for i in range(5):
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO Product (ProductName, ProductDescription, Price, StoreID, CategoryID, StockQuantity)
VALUES (:name, :description, :price, :store_id, :category_id, :stock)
- """),
+ """
+ ),
{
"name": f"测试商品{i+1}",
"description": f"这是测试商品{i+1}的描述",
- "price": (i+1) * 100.0,
+ "price": (i + 1) * 100.0,
"store_id": self.test_store_id,
"category_id": self.test_category["CategoryID"],
- "stock": (i+1) * 20 # 增加库存量,确保有足够的库存用于测试
- }
+ "stock": (i + 1) * 20, # 增加库存量,确保有足够的库存用于测试
+ },
)
-
+
# 创建另一个店铺和商品,用于测试跨店铺查询
other_user_name = f"testuser2_{random_suffix}"
other_user_email = f"{other_user_name}@example.com"
-
+
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO User (Username, Email, PasswordHash, UserRole)
VALUES (:username, :email, 'hashed_password', 'merchant')
- """),
- {"username": other_user_name, "email": other_user_email}
+ """
+ ),
+ {"username": other_user_name, "email": other_user_email},
)
other_user_id_result = self.connection.execute(text("SELECT LAST_INSERT_ID()")).fetchone()
other_user_id = other_user_id_result[0]
-
+
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO Store (StoreName, OwnerUserID, Description, StoreStatus)
VALUES ('其他测试店铺', :user_id, '第二个测试店铺', 'ACTIVE')
- """),
- {"user_id": other_user_id}
+ """
+ ),
+ {"user_id": other_user_id},
)
other_store_id_result = self.connection.execute(text("SELECT LAST_INSERT_ID()")).fetchone()
other_store_id = other_store_id_result[0]
-
+
# 在第二个店铺中创建商品
for i in range(3):
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO Product (ProductName, ProductDescription, Price, StoreID, CategoryID, StockQuantity)
VALUES (:name, :description, :price, :store_id, :category_id, :stock)
- """),
+ """
+ ),
{
"name": f"另一店铺商品{i+1}",
"description": f"这是另一个店铺的测试商品{i+1}",
- "price": (i+1) * 200.0,
+ "price": (i + 1) * 200.0,
"store_id": other_store_id,
"category_id": self.test_category["CategoryID"],
- "stock": (i+1) * 30 # 增加库存量
- }
+ "stock": (i + 1) * 30, # 增加库存量
+ },
)
-
+
# 显式提交所有更改,确保测试数据已保存在数据库中
self.connection.commit()
print("setUp完成,所有测试数据已提交到数据库")
@@ -134,19 +143,16 @@ def test_create_product(self):
"CategoryID": self.test_category["CategoryID"],
"StoreID": self.test_store_id,
"StockQuantity": 50,
- "MainImageURL": "http://example.com/test.jpg"
+ "MainImageURL": "http://example.com/test.jpg",
}
-
+
try:
# 发送请求
- response = self.client.post(
- "/api/v1/product",
- json=product_data
- )
-
+ response = self.client.post("/api/v1/product", json=product_data)
+
# 验证响应
self.assertEqual(response.status_code, 200)
-
+
# 验证响应数据
data = response.json()
self.assertIn("ProductID", data)
@@ -167,44 +173,45 @@ def test_get_product_by_id(self):
try:
# 首先创建一个测试商品,确保有数据可用
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO Product (ProductName, ProductDescription, Price, StoreID, CategoryID, StockQuantity)
VALUES ('测试获取的商品', '这是一个用于测试获取的商品', 199.99, :store_id, :category_id, 20)
- """),
- {"store_id": self.test_store_id, "category_id": self.test_category["CategoryID"]}
+ """
+ ),
+ {"store_id": self.test_store_id, "category_id": self.test_category["CategoryID"]},
)
-
+
# 确保数据被提交到数据库
self.connection.commit()
-
+
# 获取刚刚创建的商品ID
product_id_result = self.connection.execute(text("SELECT LAST_INSERT_ID()")).fetchone()
if not product_id_result:
self.fail("无法获取新创建的商品ID")
-
+
product_id = product_id_result[0]
-
+
# 确认商品确实存在
product_check = self.connection.execute(
- text("SELECT ProductID FROM Product WHERE ProductID = :id"),
- {"id": product_id}
+ text("SELECT ProductID FROM Product WHERE ProductID = :id"), {"id": product_id}
).fetchone()
-
+
if not product_check:
self.fail(f"测试商品 {product_id} 不存在于数据库中")
-
+
print(f"获取到测试商品ID: {product_id}")
-
+
# 发送请求
response = self.client.get(f"/api/v1/product/{product_id}")
-
+
# 验证响应
self.assertEqual(response.status_code, 200)
-
+
# 验证响应数据
data = response.json()
self.assertEqual(data["ProductID"], product_id)
-
+
# 测试获取不存在的商品
non_existent_response = self.client.get("/api/v1/product/999999")
self.assertEqual(non_existent_response.status_code, 404)
@@ -217,51 +224,49 @@ def test_update_product(self):
try:
# 首先创建一个测试商品,确保有数据可用
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO Product (ProductName, ProductDescription, Price, StoreID, CategoryID, StockQuantity)
VALUES ('测试更新的商品', '这是一个用于测试更新的商品', 299.99, :store_id, :category_id, 25)
- """),
- {"store_id": self.test_store_id, "category_id": self.test_category["CategoryID"]}
+ """
+ ),
+ {"store_id": self.test_store_id, "category_id": self.test_category["CategoryID"]},
)
-
+
# 确保数据被提交到数据库
self.connection.commit()
-
+
# 获取刚刚创建的商品ID
product_id_result = self.connection.execute(text("SELECT LAST_INSERT_ID()")).fetchone()
if not product_id_result:
self.fail("无法获取新创建的商品ID")
-
+
product_id = product_id_result[0]
-
+
# 确认商品确实存在
product_check = self.connection.execute(
- text("SELECT ProductID FROM Product WHERE ProductID = :id"),
- {"id": product_id}
+ text("SELECT ProductID FROM Product WHERE ProductID = :id"), {"id": product_id}
).fetchone()
-
+
if not product_check:
self.fail(f"测试商品 {product_id} 不存在于数据库中")
-
+
print(f"更新测试使用商品ID: {product_id}")
-
+
# 准备更新数据
update_data = {
"ProductName": "更新后的商品名称",
"ProductDescription": "更新后的描述",
"Price": 499.99,
- "ProductStatus": "INACTIVE_BY_MERCHANT"
+ "ProductStatus": "INACTIVE_BY_MERCHANT",
}
-
+
# 发送请求
- response = self.client.put(
- f"/api/v1/product/{product_id}",
- json=update_data
- )
-
+ response = self.client.put(f"/api/v1/product/{product_id}", json=update_data)
+
# 验证响应
self.assertEqual(response.status_code, 200)
-
+
# 验证响应数据
data = response.json()
self.assertEqual(data["ProductID"], product_id)
@@ -269,12 +274,9 @@ def test_update_product(self):
self.assertEqual(data["ProductDescription"], update_data["ProductDescription"])
self.assertEqual(float(data["Price"]), update_data["Price"])
self.assertEqual(data["ProductStatus"], update_data["ProductStatus"])
-
+
# 测试更新不存在的商品
- non_existent_response = self.client.put(
- "/api/v1/product/999999",
- json={"ProductName": "不存在的商品"}
- )
+ non_existent_response = self.client.put("/api/v1/product/999999", json={"ProductName": "不存在的商品"})
self.assertEqual(non_existent_response.status_code, 404)
except Exception as e:
print(f"更新商品时发生错误: {e}")
@@ -285,72 +287,69 @@ def test_update_product_stock(self):
try:
# 首先创建一个测试商品,确保有数据可用
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO Product (ProductName, ProductDescription, Price, StoreID, CategoryID, StockQuantity)
VALUES ('测试库存的商品', '这是一个用于测试库存更新的商品', 399.99, :store_id, :category_id, 100)
- """),
- {"store_id": self.test_store_id, "category_id": self.test_category["CategoryID"]}
+ """
+ ),
+ {"store_id": self.test_store_id, "category_id": self.test_category["CategoryID"]},
)
-
+
# 确保数据被提交到数据库
self.connection.commit()
-
+
# 获取刚刚创建的商品ID
product_id_result = self.connection.execute(text("SELECT LAST_INSERT_ID()")).fetchone()
if not product_id_result:
self.fail("无法获取新创建的商品ID")
-
+
product_id = product_id_result[0]
-
+
# 获取原始库存
stock_result = self.connection.execute(
- text("SELECT StockQuantity FROM Product WHERE ProductID = :id"),
- {"id": product_id}
+ text("SELECT StockQuantity FROM Product WHERE ProductID = :id"), {"id": product_id}
).fetchone()
-
+
if not stock_result:
self.fail(f"测试商品 {product_id} 不存在于数据库中")
-
+
original_stock = stock_result[0]
print(f"库存测试使用商品ID: {product_id}, 原库存: {original_stock}")
-
+
# 计算增加量
stock_increase = 50
-
+
# 发送请求 - 增加库存
- response = self.client.put(
- f"/api/v1/product/{product_id}/stock?stock_change={stock_increase}"
- )
-
+ response = self.client.put(f"/api/v1/product/{product_id}/stock?stock_change={stock_increase}")
+
# 验证响应
self.assertEqual(response.status_code, 200)
-
+
# 验证响应数据
data = response.json()
self.assertEqual(data["ProductID"], product_id)
-
+
# 验证库存是否增加了指定数量
# 100 + 50 = 150
self.assertEqual(data["StockQuantity"], original_stock + stock_increase)
-
+
# 获取新的库存数量,直接从API响应中获取
new_stock = data["StockQuantity"]
print(f"增加库存后的数量: {new_stock}")
-
+
# 计算减少量
stock_decrease = 30
-
+
# 发送请求 - 减少库存
- response = self.client.put(
- f"/api/v1/product/{product_id}/stock?stock_change=-{stock_decrease}"
- )
-
+ response = self.client.put(f"/api/v1/product/{product_id}/stock?stock_change=-{stock_decrease}")
+
# 验证响应
self.assertEqual(response.status_code, 200)
-
+
# 验证响应数据
data = response.json()
-
+
# 验证库存是否减少了指定数量
# 150 - 30 = 120
expected_stock = new_stock - stock_decrease
@@ -368,7 +367,6 @@ def test_update_product_stock(self):
error_message = response_excessive.json().get("detail", "")
self.assertIn("库存不足", error_message)
-
except Exception as e:
print(f"更新商品库存时发生错误: {e}")
self.fail(f"测试失败,遇到错误: {e}")
@@ -378,57 +376,61 @@ def test_list_products(self):
try:
# 首先确认测试店铺ID
store_id = self.test_store_id
-
+
# 添加更多的测试商品
for i in range(3):
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO Product (ProductName, ProductDescription, Price, StoreID, CategoryID, StockQuantity)
VALUES (:name, :description, :price, :store_id, :category_id, :stock)
- """),
+ """
+ ),
{
"name": f"商品列表测试商品{i+1}",
"description": f"这是用于测试商品列表的商品{i+1}",
- "price": (i+1) * 100.0,
+ "price": (i + 1) * 100.0,
"store_id": store_id,
"category_id": self.test_category["CategoryID"],
- "stock": (i+1) * 10
- }
+ "stock": (i + 1) * 10,
+ },
)
-
+
# 确保数据被提交到数据库
self.connection.commit()
-
+
# 获取测试分类ID
category_id = self.test_category["CategoryID"]
-
+
# 测试按店铺筛选
response1 = self.client.get(f"/api/v1/product?store_id={store_id}")
self.assertEqual(response1.status_code, 200)
data1 = response1.json()
self.assertGreaterEqual(len(data1), 1) # 至少有1个商品
print(f"店铺商品列表长度: {len(data1)}")
-
+
# 测试按分类筛选
response2 = self.client.get(f"/api/v1/product?category_id={category_id}")
self.assertEqual(response2.status_code, 200)
data2 = response2.json()
self.assertGreaterEqual(len(data2), 1) # 至少有1个商品
print(f"分类筛选商品列表长度: {len(data2)}")
-
+
# 测试搜索功能 - 添加一个带特定名称的商品用于搜索测试
unique_name = f"特殊搜索商品_{generate_random_string()}"
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO Product (ProductName, ProductDescription, Price, StoreID, CategoryID, StockQuantity)
VALUES (:name, '特殊描述', 999.99, :store_id, :category_id, 10)
- """),
- {"name": unique_name, "store_id": store_id, "category_id": category_id}
+ """
+ ),
+ {"name": unique_name, "store_id": store_id, "category_id": category_id},
)
-
+
# 确保搜索测试的商品数据被提交
self.connection.commit()
-
+
response3 = self.client.get(f"/api/v1/product?search={unique_name}")
self.assertEqual(response3.status_code, 200)
data3 = response3.json()
@@ -443,55 +445,55 @@ def test_delete_product(self):
try:
# 创建一个专门用于删除的测试商品
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO Product (ProductName, ProductDescription, Price, StoreID, CategoryID, StockQuantity)
VALUES ('准备删除的商品', '这是一个测试删除的商品', 99.99, :store_id, :category_id, 10)
- """),
- {"store_id": self.test_store_id, "category_id": self.test_category["CategoryID"]}
+ """
+ ),
+ {"store_id": self.test_store_id, "category_id": self.test_category["CategoryID"]},
)
-
+
# 确保数据被提交到数据库
self.connection.commit()
-
+
# 获取刚刚创建的商品ID
product_id_result = self.connection.execute(text("SELECT LAST_INSERT_ID()")).fetchone()
if not product_id_result:
self.fail("无法获取新创建的商品ID")
-
+
product_id = product_id_result[0]
-
+
# 确认商品确实存在
product_check = self.connection.execute(
- text("SELECT ProductID FROM Product WHERE ProductID = :id"),
- {"id": product_id}
+ text("SELECT ProductID FROM Product WHERE ProductID = :id"), {"id": product_id}
).fetchone()
-
+
if not product_check:
self.fail(f"测试商品 {product_id} 不存在于数据库中")
-
+
print(f"删除测试使用商品ID: {product_id}")
-
+
# 发送请求
response = self.client.delete(f"/api/v1/product/{product_id}")
-
+
# 确保测试数据库连接提交事务
self.connection.commit()
-
+
# 验证响应
self.assertEqual(response.status_code, 204) # 成功但无内容
-
+
# 验证商品状态已更改
result = self.connection.execute(
- text("SELECT ProductStatus FROM Product WHERE ProductID = :id"),
- {"id": product_id}
+ text("SELECT ProductStatus FROM Product WHERE ProductID = :id"), {"id": product_id}
).fetchone()
-
+
# 确保结果不为None
if result is None:
self.fail(f"删除后商品 {product_id} 不存在于数据库中")
-
+
self.assertEqual(result[0], "DISCONTINUED")
-
+
# 测试删除不存在的商品
non_existent_response = self.client.delete("/api/v1/product/999999")
self.assertEqual(non_existent_response.status_code, 404)
@@ -500,5 +502,5 @@ def test_delete_product(self):
self.fail(f"测试失败,遇到错误: {e}")
-if __name__ == '__main__':
- unittest.main()
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/backend/test/integration/api_endpoints/test_statistics_endpoints_async.py b/src/backend/test/integration/api_endpoints/test_statistics_endpoints_async.py
index d7498e6..d77fcd8 100644
--- a/src/backend/test/integration/api_endpoints/test_statistics_endpoints_async.py
+++ b/src/backend/test/integration/api_endpoints/test_statistics_endpoints_async.py
@@ -41,11 +41,11 @@ class TestStatisticsEndpointsIntegration(AsyncBaseDBTestCaseAutoRollback):
customer_email: str = "customer_test@example.com"
customer_password_plain: str = "CustomerPass123!"
customer_password_hash: str = hash_password(customer_password_plain)
-
+
# Store data
store_id: int = 200
store_name: str = "Test Store"
-
+
@classmethod
def _create_test_data(cls, conn: Connection):
"""Creates test data for statistics endpoints"""
@@ -60,29 +60,29 @@ def _create_test_data(cls, conn: Connection):
),
[
{
- "UserID": cls.admin_id,
- "Username": cls.admin_username,
- "PasswordHash": cls.admin_password_hash,
- "Email": cls.admin_email,
- "UserRole": "ADMIN"
+ "UserID": cls.admin_id,
+ "Username": cls.admin_username,
+ "PasswordHash": cls.admin_password_hash,
+ "Email": cls.admin_email,
+ "UserRole": "ADMIN",
},
{
- "UserID": cls.merchant_id,
- "Username": cls.merchant_username,
- "PasswordHash": cls.merchant_password_hash,
- "Email": cls.merchant_email,
- "UserRole": "MERCHANT"
+ "UserID": cls.merchant_id,
+ "Username": cls.merchant_username,
+ "PasswordHash": cls.merchant_password_hash,
+ "Email": cls.merchant_email,
+ "UserRole": "MERCHANT",
},
{
- "UserID": cls.customer_id,
- "Username": cls.customer_username,
- "PasswordHash": cls.customer_password_hash,
- "Email": cls.customer_email,
- "UserRole": "CUSTOMER"
- }
- ]
+ "UserID": cls.customer_id,
+ "Username": cls.customer_username,
+ "PasswordHash": cls.customer_password_hash,
+ "Email": cls.customer_email,
+ "UserRole": "CUSTOMER",
+ },
+ ],
)
-
+
# Create test store
conn.execute(
text(
@@ -90,13 +90,9 @@ def _create_test_data(cls, conn: Connection):
"VALUES (:StoreID, :StoreName, :OwnerUserID, 'ACTIVE', UTC_TIMESTAMP()) "
"ON DUPLICATE KEY UPDATE StoreName=VALUES(StoreName), OwnerUserID=VALUES(OwnerUserID)"
),
- {
- "StoreID": cls.store_id,
- "StoreName": cls.store_name,
- "OwnerUserID": cls.merchant_id
- }
+ {"StoreID": cls.store_id, "StoreName": cls.store_name, "OwnerUserID": cls.merchant_id},
)
-
+
# Create test product categories
conn.execute(
text(
@@ -105,7 +101,7 @@ def _create_test_data(cls, conn: Connection):
"ON DUPLICATE KEY UPDATE CategoryName=VALUES(CategoryName), CategoryDescription=VALUES(CategoryDescription)"
)
)
-
+
# Create test products
conn.execute(
text(
@@ -114,9 +110,9 @@ def _create_test_data(cls, conn: Connection):
"(2, 'Test Product 2', 20.99, :StoreID, 2, 50) "
"ON DUPLICATE KEY UPDATE ProductName=VALUES(ProductName), Price=VALUES(Price), StoreID=VALUES(StoreID)"
),
- {"StoreID": cls.store_id}
+ {"StoreID": cls.store_id},
)
-
+
# Create test payment transactions first (required for orders)
conn.execute(
text(
@@ -126,9 +122,9 @@ def _create_test_data(cls, conn: Connection):
"(3, :CustomerID, 15.99, 'CREDIT_CARD', 'SUCCESSFUL') "
"ON DUPLICATE KEY UPDATE UserID=VALUES(UserID), TotalAmount=VALUES(TotalAmount), Status=VALUES(Status)"
),
- {"CustomerID": cls.customer_id}
+ {"CustomerID": cls.customer_id},
)
-
+
# Create test orders
conn.execute(
text(
@@ -146,9 +142,9 @@ def _create_test_data(cls, conn: Connection):
"UserID=VALUES(UserID), StoreID=VALUES(StoreID), OrderStatus=VALUES(OrderStatus), "
"FinalAmountForThisOrder=VALUES(FinalAmountForThisOrder)"
),
- {"CustomerID": cls.customer_id, "StoreID": cls.store_id}
+ {"CustomerID": cls.customer_id, "StoreID": cls.store_id},
)
-
+
# Create test order items
conn.execute(
text(
@@ -160,12 +156,12 @@ def _create_test_data(cls, conn: Connection):
"Quantity=VALUES(Quantity), PriceAtPurchase=VALUES(PriceAtPurchase), ProductNameAtPurchase=VALUES(ProductNameAtPurchase), "
"Subtotal=VALUES(Subtotal)"
),
- {"StoreID": cls.store_id}
+ {"StoreID": cls.store_id},
)
-
+
# Create mock statistics views for tests
create_mock_views(conn)
-
+
logger.info(f"--- {cls.__name__}: Test data for statistics endpoints created ---")
except Exception as e:
logger.error(f"ERROR during {cls.__name__}._create_test_data: {e}")
@@ -192,39 +188,39 @@ def setUp(self):
super().setUp() # Provides self.connection and self.transaction
current_utc_time = datetime.datetime.now(datetime.timezone.utc)
-
+
# Mock admin user
self.mock_admin_user = UserResponse(
UserID=self.admin_id,
Username=self.admin_username,
Email=self.admin_email,
PhoneNumber=None,
- UserRole='admin',
+ UserRole="admin",
RegistrationDate=current_utc_time,
- LastLoginDate=current_utc_time
+ LastLoginDate=current_utc_time,
)
-
+
# Mock merchant user
self.mock_merchant_user = UserResponse(
UserID=self.merchant_id,
Username=self.merchant_username,
Email=self.merchant_email,
PhoneNumber=None,
- UserRole='merchant',
+ UserRole="merchant",
StoreID=self.store_id,
RegistrationDate=current_utc_time,
- LastLoginDate=current_utc_time
+ LastLoginDate=current_utc_time,
)
-
+
# Mock customer user
self.mock_customer_user = UserResponse(
UserID=self.customer_id,
Username=self.customer_username,
Email=self.customer_email,
PhoneNumber=None,
- UserRole='customer',
+ UserRole="customer",
RegistrationDate=current_utc_time,
- LastLoginDate=current_utc_time
+ LastLoginDate=current_utc_time,
)
# --- Dependency Overrides ---
@@ -236,7 +232,7 @@ def override_get_db_connection() -> Connection:
def override_get_current_admin() -> UserResponse:
"""Override to return mock admin user"""
return self.mock_admin_user
-
+
# Regular user dependency override (could be admin, merchant, or customer)
def override_get_current_active_user() -> UserResponse:
"""Default to customer user, can be changed in specific tests"""
@@ -259,7 +255,7 @@ async def test_get_system_statistics_admin(self):
"""Test that admin can access system statistics"""
response = self.client.get("/api/v1/statistics/system")
self.assertEqual(response.status_code, status.HTTP_200_OK)
-
+
data = response.json()
self.assertIsInstance(data, dict)
self.assertEqual(data["total_users"], 3)
@@ -270,9 +266,10 @@ async def test_get_system_statistics_admin(self):
self.assertEqual(data["completed_orders"], 2)
self.assertEqual(data["total_orders"], 3)
self.assertEqual(data["total_sales"], 36.98) # 20.99 + 15.99
-
+
async def test_get_system_statistics_non_admin(self):
"""Test that non-admin cannot access system statistics"""
+
# Override admin dependency to return a 403 response
def override_get_current_admin():
raise HTTPException(status_code=403, detail="User is not an admin")
@@ -282,12 +279,12 @@ def override_get_current_admin():
# This should now raise a 403 error
response = self.client.get("/api/v1/statistics/system")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
+
async def test_get_admin_dashboard_statistics(self):
"""Test admin dashboard statistics endpoint"""
response = self.client.get("/api/v1/statistics/admin-dashboard")
self.assertEqual(response.status_code, status.HTTP_200_OK)
-
+
data = response.json()
self.assertIsInstance(data, dict)
# Admin dashboard should have the same data as system statistics
@@ -299,19 +296,19 @@ async def test_get_admin_dashboard_statistics(self):
self.assertEqual(data["completed_orders"], 2)
self.assertEqual(data["total_orders"], 3)
self.assertEqual(data["total_sales"], 36.98) # 20.99 + 15.99
-
+
async def test_get_all_store_statistics_admin(self):
"""Test that admin can access all store statistics"""
# Explicitly use the admin user
app.dependency_overrides[get_current_active_user] = lambda: self.mock_admin_user
-
+
response = self.client.get("/api/v1/statistics/store")
self.assertEqual(response.status_code, status.HTTP_200_OK)
-
+
data = response.json()
self.assertIsInstance(data, list)
self.assertEqual(len(data), 1) # Only one store in test data
-
+
store_stats = data[0]
self.assertEqual(store_stats["store_id"], self.store_id)
self.assertEqual(store_stats["store_name"], self.store_name)
@@ -319,15 +316,15 @@ async def test_get_all_store_statistics_admin(self):
self.assertEqual(store_stats["order_count"], 3)
self.assertEqual(store_stats["items_sold"], 6)
self.assertEqual(store_stats["total_revenue"], 95.94) # 10.99 + 20.99 + 15.99
-
+
async def test_get_specific_store_statistics_admin(self):
"""Test that admin can access specific store statistics"""
# Override to use admin user explicitly
app.dependency_overrides[get_current_active_user] = lambda: self.mock_admin_user
-
+
response = self.client.get(f"/api/v1/statistics/store/{self.store_id}")
self.assertEqual(response.status_code, status.HTTP_200_OK)
-
+
store_stats = response.json()
self.assertEqual(store_stats["store_id"], self.store_id)
self.assertEqual(store_stats["store_name"], self.store_name)
@@ -335,9 +332,10 @@ async def test_get_specific_store_statistics_admin(self):
self.assertEqual(store_stats["order_count"], 3)
self.assertEqual(store_stats["items_sold"], 6)
self.assertEqual(store_stats["total_revenue"], 95.94) # 10.99 + 20.99 + 15.99
-
+
async def test_get_specific_store_statistics_merchant(self):
"""Test that merchant can access their own store statistics"""
+
# Create a dictionary-based user object instead of a Pydantic model to ensure attributes can be added
class MerchantWithStore:
def __init__(self, user_id, username, email, user_role, store_id):
@@ -346,48 +344,48 @@ def __init__(self, user_id, username, email, user_role, store_id):
self.Email = email
self.UserRole = user_role
self.StoreID = store_id # This is key - we need this attribute for the endpoint check
-
+
# Create merchant with proper store access
temp_merchant = MerchantWithStore(
user_id=self.merchant_id,
username=self.merchant_username,
email=self.merchant_email,
- user_role='MERCHANT', # Use uppercase to match case checking
- store_id=self.store_id # Set the StoreID to match the test store
+ user_role="MERCHANT", # Use uppercase to match case checking
+ store_id=self.store_id, # Set the StoreID to match the test store
)
-
+
app.dependency_overrides[get_current_active_user] = lambda: temp_merchant
-
+
response = self.client.get(f"/api/v1/statistics/store/{self.store_id}")
self.assertEqual(response.status_code, status.HTTP_200_OK)
-
+
store_stats = response.json()
self.assertEqual(store_stats["store_id"], self.store_id)
self.assertEqual(store_stats["store_name"], self.store_name)
-
+
async def test_merchant_cannot_access_other_store_statistics(self):
"""Test that merchant cannot access statistics for stores they don't own"""
# Override current user to return merchant
app.dependency_overrides[get_current_active_user] = lambda: self.mock_merchant_user
-
+
# Try to access a store that doesn't exist
invalid_store_id = 999
response = self.client.get(f"/api/v1/statistics/store/{invalid_store_id}")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
+
async def test_customer_cannot_access_store_statistics(self):
"""Test that customers cannot access store statistics"""
# Override current user to return customer
app.dependency_overrides[get_current_active_user] = lambda: self.mock_customer_user
-
+
response = self.client.get(f"/api/v1/statistics/store/{self.store_id}")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
+
async def test_store_not_found(self):
"""Test appropriate error when store not found"""
# Make sure we're using admin user for this test
app.dependency_overrides[get_current_active_user] = lambda: self.mock_admin_user
-
+
invalid_store_id = 999
response = self.client.get(f"/api/v1/statistics/store/{invalid_store_id}")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
diff --git a/src/backend/test/integration/api_endpoints/test_store_change_request_endpoints_v2.py b/src/backend/test/integration/api_endpoints/test_store_change_request_endpoints_v2.py
index f723849..c791522 100644
--- a/src/backend/test/integration/api_endpoints/test_store_change_request_endpoints_v2.py
+++ b/src/backend/test/integration/api_endpoints/test_store_change_request_endpoints_v2.py
@@ -90,7 +90,7 @@ def _create_shared_class_data(cls, conn: Connection):
"PasswordHash": hash_password("CustomerSCRPass1!"),
"Email": "customer@example.com",
"UserRole": "customer",
- }
+ },
]
for ud in users_data:
conn.execute(
@@ -216,9 +216,7 @@ async def _create_direct_scr_via_api(
self,
request_type: RequestTypeEnum,
proposed_data_json: Optional[ProposedStoreData] = None, # Changed from ProposedProductData
- store_id_in_payload: Optional[
- int
- ] = None, # This is the StoreID in the SCR_CreateRequest payload
+ store_id_in_payload: Optional[int] = None, # This is the StoreID in the SCR_CreateRequest payload
submitter_notes: Optional[str] = None,
# ProductID is not part of StoreChangeRequestCreate schema
as_user: Optional[CurrentUserSchema] = None,
@@ -252,7 +250,7 @@ async def _create_direct_scr_via_api(
payload_for_api: Dict[str, Any] = {
"RequestType": request_type.value,
"SubmitterNotes": submitter_notes,
- "StoreID": store_id_in_payload
+ "StoreID": store_id_in_payload,
}
# The StoreID in the payload for the API.
# For STORE_CREATE, the service expects this to be the store the merchant owns,
@@ -260,9 +258,7 @@ async def _create_direct_scr_via_api(
# For UPDATE/DELETE, this is the target store.
if proposed_data_json:
- payload_for_api["ProposedData_JSON"] = proposed_data_json.model_dump(
- mode="json", exclude_none=True
- )
+ payload_for_api["ProposedData_JSON"] = proposed_data_json.model_dump(mode="json", exclude_none=True)
# The StoreChangeRequestCreate schema from product_change_request_schema_v2 also has an optional ProductID.
# We will omit it as it's not relevant for Store Change Requests.
@@ -300,9 +296,7 @@ async def test_submit_scr_store_create_success(self):
),
SubmitterNotes="My application for a new store",
)
- response = self.client.post(
- "/api/v1/store-change-new/", json=payload.model_dump(mode="json")
- )
+ response = self.client.post("/api/v1/store-change-new/", json=payload.model_dump(mode="json"))
self.assertEqual(response.status_code, status.HTTP_201_CREATED, response.text)
data = response.json()
self.assertEqual(data["RequestType"], RequestTypeEnum.STORE_CREATE.value)
@@ -318,21 +312,15 @@ async def test_submit_scr_store_update_success(self):
payload = StoreChangeRequestCreate(
StoreID=self.store_id_cls, # Target existing store owned by merchant
RequestType=RequestTypeEnum.STORE_UPDATE,
- ProposedData_JSON=ProposedStoreData(
- Description="Updated description via API for store."
- ),
+ ProposedData_JSON=ProposedStoreData(Description="Updated description via API for store."),
SubmitterNotes="Store description update.",
)
- response = self.client.post(
- "/api/v1/store-change-new/", json=payload.model_dump(mode="json")
- )
+ response = self.client.post("/api/v1/store-change-new/", json=payload.model_dump(mode="json"))
self.assertEqual(response.status_code, status.HTTP_201_CREATED, response.text)
data = response.json()
self.assertEqual(data["RequestType"], RequestTypeEnum.STORE_UPDATE.value)
self.assertEqual(data["StoreID"], self.store_id_cls) # StoreID is the target store
- self.assertEqual(
- data["ProposedData_JSON"]["Description"], "Updated description via API for store."
- )
+ self.assertEqual(data["ProposedData_JSON"]["Description"], "Updated description via API for store.")
async def test_submit_scr_store_update_fails_if_payload_store_id_is_none(self):
# The StoreChangeRequestCreate schema requires StoreID.
@@ -360,21 +348,12 @@ async def test_list_scr_for_requesting_user_success(self):
proposed_data_json=ProposedStoreData(StoreName="Store A", Description="Desc A"),
)
- response = self.client.get(
- f"/api/v1/store-change-new/list/?Status={RequestStatusEnum.PENDING_APPROVAL.value}"
- )
+ response = self.client.get(f"/api/v1/store-change-new/list/?Status={RequestStatusEnum.PENDING_APPROVAL.value}")
self.assertEqual(response.status_code, status.HTTP_200_OK, response.text)
data = response.json()
self.assertGreaterEqual(data["TotalCount"], 1)
- self.assertTrue(
- all(req["RequestingUserID"] == self.requesting_user_id_cls for req in data["Requests"])
- )
- self.assertTrue(
- all(
- req["Status"] == RequestStatusEnum.PENDING_APPROVAL.value
- for req in data["Requests"]
- )
- )
+ self.assertTrue(all(req["RequestingUserID"] == self.requesting_user_id_cls for req in data["Requests"]))
+ self.assertTrue(all(req["Status"] == RequestStatusEnum.PENDING_APPROVAL.value for req in data["Requests"]))
# --- III. GET /list-admin/ (list_store_change_requests_admin) ---
async def test_list_scr_for_admin_success(self):
@@ -392,9 +371,7 @@ async def test_list_scr_for_admin_success(self):
self.assertEqual(response.status_code, status.HTTP_200_OK, response.text)
data = response.json()
self.assertGreaterEqual(data["TotalCount"], 1)
- self.assertTrue(
- all(req["RequestingUserID"] == self.requesting_user_id_cls for req in data["Requests"])
- )
+ self.assertTrue(all(req["RequestingUserID"] == self.requesting_user_id_cls for req in data["Requests"]))
# --- IV. GET /{request_id} ---
async def test_get_scr_details_success_owner(self):
@@ -421,12 +398,10 @@ async def test_admin_review_request_approve_and_apply_store_create_success(self)
LogoURL="logo.png",
StoreStatus=StoreStatusEnum.ACTIVE,
)
- created_pcr = (
- await self._create_direct_scr_via_api( # This helper uses the merchant user by default
- request_type=RequestTypeEnum.STORE_CREATE,
- store_id_in_payload=None,
- proposed_data_json=proposed_data,
- )
+ created_pcr = await self._create_direct_scr_via_api( # This helper uses the merchant user by default
+ request_type=RequestTypeEnum.STORE_CREATE,
+ store_id_in_payload=None,
+ proposed_data_json=proposed_data,
)
# Manually set status to PENDING_APPROVAL if helper doesn't, or create a direct DB entry
self.real_scr_crud.update_request_by_admin(
@@ -455,9 +430,7 @@ async def test_admin_review_request_approve_and_apply_store_create_success(self)
self.assertIsNotNone(data["StoreID"])
newly_created_store_id = data["StoreID"]
- store_db = self.real_store_crud.get_store_by_id(
- self.connection, store_id=newly_created_store_id
- )
+ store_db = self.real_store_crud.get_store_by_id(self.connection, store_id=newly_created_store_id)
self.assertIsNotNone(store_db)
self.assertEqual(store_db["StoreName"], "API Approved New Store")
self.assertEqual(store_db["OwnerUserID"], self.requesting_user_id_cls)
@@ -471,13 +444,11 @@ async def test_admin_review_request_approve_customer_create_becomes_merchant(sel
LogoURL="logo.png",
StoreStatus=StoreStatusEnum.ACTIVE,
)
- created_pcr = (
- await self._create_direct_scr_via_api(
- request_type=RequestTypeEnum.STORE_CREATE,
- store_id_in_payload=None,
- proposed_data_json=proposed_data,
- as_user=self.mock_customer_user_schema
- )
+ created_pcr = await self._create_direct_scr_via_api(
+ request_type=RequestTypeEnum.STORE_CREATE,
+ store_id_in_payload=None,
+ proposed_data_json=proposed_data,
+ as_user=self.mock_customer_user_schema,
)
# Manually set status to PENDING_APPROVAL if helper doesn't, or create a direct DB entry
self.real_scr_crud.update_request_by_admin(
@@ -503,12 +474,12 @@ async def test_admin_review_request_approve_customer_create_becomes_merchant(sel
self.assertEqual(response.status_code, status.HTTP_200_OK, response.text)
# assert now user is a merchant
updated_user = self.real_user_crud.get_user_by_id(
- conn=self.connection,
- user_id=self.mock_customer_user_schema.UserID
+ conn=self.connection, user_id=self.mock_customer_user_schema.UserID
)
self.assertIsNotNone(updated_user)
- self.assertEqual(updated_user["UserRole"], "merchant", "Customer should become merchant after store creation approval")
-
+ self.assertEqual(
+ updated_user["UserRole"], "merchant", "Customer should become merchant after store creation approval"
+ )
# --- VI. DELETE /{request_id} (User/Merchant Cancel) ---
@@ -516,7 +487,7 @@ async def test_delete_request_cancels_by_user_success(self):
# Current user is merchant1 (default override in setUp)
created_pcr = await self._create_direct_scr_via_api(
request_type=RequestTypeEnum.STORE_CREATE,
- store_id_in_payload=None, # create a new store, so StoreID in payload is None
+ store_id_in_payload=None, # create a new store, so StoreID in payload is None
proposed_data_json=ProposedStoreData(StoreName="To Be Cancelled"),
)
# Manually set status to PENDING_APPROVAL for this test
diff --git a/src/backend/test/integration/api_endpoints/test_store_endpoints_async.py b/src/backend/test/integration/api_endpoints/test_store_endpoints_async.py
index 2faa873..f074faf 100644
--- a/src/backend/test/integration/api_endpoints/test_store_endpoints_async.py
+++ b/src/backend/test/integration/api_endpoints/test_store_endpoints_async.py
@@ -20,7 +20,7 @@
StoreSimpleResponse,
StoreListResponse,
StoreListSimpleResponse,
- StoreStatusEnum
+ StoreStatusEnum,
)
from backend.app.services.store_service import StoreService
from backend.app.crud.user_crud import UserCRUD
@@ -66,30 +66,50 @@ def _create_shared_class_data(cls, conn: Connection):
try:
# 1. Create Users
users_data = [
- {"UserID": cls.user1_id_class, "Username": cls.user1_username_class,
- "PasswordHash": cls.user1_password_hash_class, "Email": cls.user1_email_class},
- {"UserID": cls.user2_id_class, "Username": cls.user2_username_class,
- "PasswordHash": cls.user2_password_hash_class, "Email": cls.user2_email_class},
+ {
+ "UserID": cls.user1_id_class,
+ "Username": cls.user1_username_class,
+ "PasswordHash": cls.user1_password_hash_class,
+ "Email": cls.user1_email_class,
+ },
+ {
+ "UserID": cls.user2_id_class,
+ "Username": cls.user2_username_class,
+ "PasswordHash": cls.user2_password_hash_class,
+ "Email": cls.user2_email_class,
+ },
]
for ud in users_data:
- conn.execute(text(
- "INSERT INTO User (UserID, Username, PasswordHash, Email, UserRole, AccountStatus) VALUES (:UserID, :Username, :PasswordHash, :Email, 'customer', 'ACTIVE') ON DUPLICATE KEY UPDATE Username=VALUES(Username)"),
- ud)
+ conn.execute(
+ text(
+ "INSERT INTO User (UserID, Username, PasswordHash, Email, UserRole, AccountStatus) VALUES (:UserID, :Username, :PasswordHash, :Email, 'customer', 'ACTIVE') ON DUPLICATE KEY UPDATE Username=VALUES(Username)"
+ ),
+ ud,
+ )
# 2. Create Stores for User1
stores_data = [
- {"StoreID": cls.store1_id_active_class, "StoreName": cls.store1_name_active_class,
- "OwnerUserID": cls.user1_id_class, "StoreStatus": StoreStatusEnum.ACTIVE.value,
- "Description": "An active store by user1"},
- {"StoreID": cls.store2_id_inactive_class, "StoreName": cls.store2_name_inactive_class,
- "OwnerUserID": cls.user1_id_class, "StoreStatus": StoreStatusEnum.INACTIVE_BY_MERCHANT.value,
- "Description": "An inactive store by user1"},
+ {
+ "StoreID": cls.store1_id_active_class,
+ "StoreName": cls.store1_name_active_class,
+ "OwnerUserID": cls.user1_id_class,
+ "StoreStatus": StoreStatusEnum.ACTIVE.value,
+ "Description": "An active store by user1",
+ },
+ {
+ "StoreID": cls.store2_id_inactive_class,
+ "StoreName": cls.store2_name_inactive_class,
+ "OwnerUserID": cls.user1_id_class,
+ "StoreStatus": StoreStatusEnum.INACTIVE_BY_MERCHANT.value,
+ "Description": "An inactive store by user1",
+ },
]
for sd in stores_data:
conn.execute(
text(
- "INSERT INTO Store (StoreID, StoreName, OwnerUserID, Description, StoreStatus, CreationDate, LastUpdatedDate) VALUES (:StoreID, :StoreName, :OwnerUserID, :Description, :StoreStatus, UTC_TIMESTAMP(), UTC_TIMESTAMP()) ON DUPLICATE KEY UPDATE StoreName=VALUES(StoreName)"),
- sd
+ "INSERT INTO Store (StoreID, StoreName, OwnerUserID, Description, StoreStatus, CreationDate, LastUpdatedDate) VALUES (:StoreID, :StoreName, :OwnerUserID, :Description, :StoreStatus, UTC_TIMESTAMP(), UTC_TIMESTAMP()) ON DUPLICATE KEY UPDATE StoreName=VALUES(StoreName)"
+ ),
+ sd,
)
logger.info(f"--- {cls.__name__}: Shared class-level store data creation complete ---")
except Exception as e:
@@ -120,18 +140,18 @@ def setUp(self):
Username=self.user1_username_class,
Email=self.user1_email_class,
PhoneNumber=None,
- UserRole='customer',
+ UserRole="customer",
RegistrationDate=current_utc_time,
- LastLoginDate=current_utc_time
+ LastLoginDate=current_utc_time,
)
self.mock_current_user2_schema = CurrentUserSchema( # Regular User
UserID=self.user2_id_class,
Username=self.user2_username_class,
Email=self.user2_email_class,
PhoneNumber=None,
- UserRole='customer',
+ UserRole="customer",
RegistrationDate=current_utc_time,
- LastLoginDate=current_utc_time
+ LastLoginDate=current_utc_time,
)
# Dependency Overrides
@@ -141,10 +161,7 @@ def override_get_db_connection() -> Connection:
self.real_store_crud = StoreTableCRUD.get_instance()
self.real_user_crud = UserCRUD.get_instance() # StoreService uses this
- self.real_store_service = StoreService(
- store_crud=self.real_store_crud,
- user_crud=self.real_user_crud
- )
+ self.real_store_service = StoreService(store_crud=self.real_store_crud, user_crud=self.real_user_crud)
def override_get_store_service() -> StoreService:
return self.real_store_service
@@ -195,8 +212,11 @@ async def test_get_store_info_owner_can_get_their_inactive_store_via_merchant_ge
# This test verifies the current /info endpoint behavior.
app.dependency_overrides[get_current_active_user] = lambda: self.mock_current_user1_schema # User1 is owner
response = self.client.get(f"/api/v1/store/info/{self.store2_id_inactive_class}")
- self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND,
- "Endpoint /info/{id} should only return ACTIVE stores for users")
+ self.assertEqual(
+ response.status_code,
+ status.HTTP_404_NOT_FOUND,
+ "Endpoint /info/{id} should only return ACTIVE stores for users",
+ )
async def test_get_store_info_non_existent_store(self):
response = self.client.get("/api/v1/store/info/99999")
@@ -220,7 +240,7 @@ async def test_get_stores_by_owner_other_user_logs_warning_but_proceeds(self):
# SUT currently logs a warning and proceeds.
app.dependency_overrides[get_current_active_user] = lambda: self.mock_current_user2_schema
- with patch('backend.app.services.store_service.logger') as mock_service_logger: # Patch service logger
+ with patch("backend.app.services.store_service.logger") as mock_service_logger: # Patch service logger
response = self.client.get(f"/api/v1/store/owner/{self.user1_id_class}") # Target is user1
self.assertEqual(response.status_code, status.HTTP_200_OK, response.text) # It proceeds
data = response.json()
@@ -259,7 +279,7 @@ async def test_get_all_stores_simple_success(self):
self.assertTrue(all(s.StoreID is not None for s in simple_list_resp.StoreList)) # Check basic fields
# --- Test GET /list-full (get_all_stores) ---
- @patch('backend.app.services.store_service.logger') # To check warning
+ @patch("backend.app.services.store_service.logger") # To check warning
async def test_user_get_all_stores_full_success_logs_warning(self, mock_service_logger):
# This endpoint returns ALL stores, and only gets ACTIVE stores
# SUT currently logs a warning because admin check is a TODO
@@ -285,5 +305,5 @@ async def test_user_get_all_stores_full_success_logs_warning(self, mock_service_
)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/src/backend/test/integration/crud/test_category_crud_integration.py b/src/backend/test/integration/crud/test_category_crud_integration.py
index 6104950..c4348ad 100644
--- a/src/backend/test/integration/crud/test_category_crud_integration.py
+++ b/src/backend/test/integration/crud/test_category_crud_integration.py
@@ -10,7 +10,7 @@
def generate_random_string(length=8):
"""生成指定长度的随机字符串"""
letters_and_digits = string.ascii_letters + string.digits
- return ''.join(random.choice(letters_and_digits) for _ in range(length))
+ return "".join(random.choice(letters_and_digits) for _ in range(length))
@unittest.skip("This test has been moved to test_category_crud_integration.py")
@@ -29,10 +29,7 @@ def test_create_category(self):
}
# 执行测试
- category = self.category_crud.create_category(
- conn=self.connection,
- **category_data
- )
+ category = self.category_crud.create_category(conn=self.connection, **category_data)
# 验证结果
self.assertIsNotNone(category)
@@ -46,22 +43,17 @@ def test_create_subcategory(self):
"""测试创建子分类"""
# 准备数据 - 先创建父分类
parent_category = self.category_crud.create_category(
- conn=self.connection,
- category_name="父分类",
- category_description="父分类描述"
+ conn=self.connection, category_name="父分类", category_description="父分类描述"
)
# 执行测试 - 创建子分类
subcategory_data = {
"category_name": "子分类",
"category_description": "子分类描述",
- "parent_category_id": parent_category["CategoryID"]
+ "parent_category_id": parent_category["CategoryID"],
}
- subcategory = self.category_crud.create_category(
- conn=self.connection,
- **subcategory_data
- )
+ subcategory = self.category_crud.create_category(conn=self.connection, **subcategory_data)
# 验证结果
self.assertIsNotNone(subcategory)
@@ -73,25 +65,18 @@ def test_create_category_with_invalid_parent(self):
# 执行测试 - 尝试使用不存在的父分类ID
with self.assertRaises(ValueError):
self.category_crud.create_category(
- conn=self.connection,
- category_name="无效父分类的分类",
- parent_category_id=999999
+ conn=self.connection, category_name="无效父分类的分类", parent_category_id=999999
)
def test_get_category_by_id(self):
"""测试根据ID获取分类"""
# 准备数据 - 先创建分类
test_category = self.category_crud.create_category(
- conn=self.connection,
- category_name="测试分类2",
- category_description="用于测试查询的分类"
+ conn=self.connection, category_name="测试分类2", category_description="用于测试查询的分类"
)
# 执行测试
- category = self.category_crud.get_category_by_id(
- conn=self.connection,
- category_id=test_category["CategoryID"]
- )
+ category = self.category_crud.get_category_by_id(conn=self.connection, category_id=test_category["CategoryID"])
# 验证结果
self.assertIsNotNone(category)
@@ -99,26 +84,17 @@ def test_get_category_by_id(self):
self.assertEqual(category["CategoryName"], "测试分类2")
# 测试获取不存在的分类
- non_existent_category = self.category_crud.get_category_by_id(
- conn=self.connection,
- category_id=999999
- )
+ non_existent_category = self.category_crud.get_category_by_id(conn=self.connection, category_id=999999)
self.assertIsNone(non_existent_category)
def test_get_categories(self):
"""测试获取分类列表"""
# 准备数据 - 创建一些顶级分类
for i in range(3):
- self.category_crud.create_category(
- conn=self.connection,
- category_name=f"顶级分类{i + 1}"
- )
+ self.category_crud.create_category(conn=self.connection, category_name=f"顶级分类{i + 1}")
# 执行测试 - 获取所有顶级分类
- top_categories = self.category_crud.get_categories(
- conn=self.connection,
- parent_id=None
- )
+ top_categories = self.category_crud.get_categories(conn=self.connection, parent_id=None)
# 验证结果
self.assertIsInstance(top_categories, list)
@@ -128,16 +104,11 @@ def test_get_categories(self):
parent_id = top_categories[0]["CategoryID"]
for i in range(2):
self.category_crud.create_category(
- conn=self.connection,
- category_name=f"子分类{i + 1}",
- parent_category_id=parent_id
+ conn=self.connection, category_name=f"子分类{i + 1}", parent_category_id=parent_id
)
# 执行测试 - 获取特定父分类的子分类
- sub_categories = self.category_crud.get_categories(
- conn=self.connection,
- parent_id=parent_id
- )
+ sub_categories = self.category_crud.get_categories(conn=self.connection, parent_id=parent_id)
# 验证结果
self.assertEqual(len(sub_categories), 2)
@@ -148,21 +119,14 @@ def test_update_category(self):
"""测试更新分类信息"""
# 准备数据
test_category = self.category_crud.create_category(
- conn=self.connection,
- category_name="测试更新分类",
- category_description="原始描述"
+ conn=self.connection, category_name="测试更新分类", category_description="原始描述"
)
# 执行测试 - 更新名称和描述
- update_data = {
- "categoryname": "更新后的分类名称",
- "categorydescription": "更新后的描述"
- }
+ update_data = {"categoryname": "更新后的分类名称", "categorydescription": "更新后的描述"}
updated_category = self.category_crud.update_category(
- conn=self.connection,
- category_id=test_category["CategoryID"],
- update_data=update_data
+ conn=self.connection, category_id=test_category["CategoryID"], update_data=update_data
)
# 验证结果
@@ -172,34 +136,22 @@ def test_update_category(self):
# 测试更新不存在的分类
non_existent_update = self.category_crud.update_category(
- conn=self.connection,
- category_id=999999,
- update_data={"categoryname": "不存在的分类"}
+ conn=self.connection, category_id=999999, update_data={"categoryname": "不存在的分类"}
)
self.assertIsNone(non_existent_update)
def test_update_category_parent(self):
"""测试更新分类的父分类"""
# 准备数据 - 创建两个分类
- category1 = self.category_crud.create_category(
- conn=self.connection,
- category_name="分类1"
- )
+ category1 = self.category_crud.create_category(conn=self.connection, category_name="分类1")
- category2 = self.category_crud.create_category(
- conn=self.connection,
- category_name="分类2"
- )
+ category2 = self.category_crud.create_category(conn=self.connection, category_name="分类2")
# 执行测试 - 将分类2设为分类1的父分类
- update_data = {
- "parentcategoryid": category2["CategoryID"]
- }
+ update_data = {"parentcategoryid": category2["CategoryID"]}
updated_category = self.category_crud.update_category(
- conn=self.connection,
- category_id=category1["CategoryID"],
- update_data=update_data
+ conn=self.connection, category_id=category1["CategoryID"], update_data=update_data
)
# 验证结果
@@ -207,96 +159,65 @@ def test_update_category_parent(self):
self.assertEqual(updated_category["ParentCategoryID"], category2["CategoryID"])
# 测试循环引用 - 将分类1设为分类2的父分类(应该失败,因为会形成循环)
- cyclic_update_data = {
- "parentcategoryid": category1["CategoryID"]
- }
+ cyclic_update_data = {"parentcategoryid": category1["CategoryID"]}
with self.assertRaises(ValueError):
self.category_crud.update_category(
- conn=self.connection,
- category_id=category2["CategoryID"],
- update_data=cyclic_update_data
+ conn=self.connection, category_id=category2["CategoryID"], update_data=cyclic_update_data
)
# 测试自引用 - 将分类设为自己的父分类(应该失败)
- self_update_data = {
- "parentcategoryid": category1["CategoryID"]
- }
+ self_update_data = {"parentcategoryid": category1["CategoryID"]}
with self.assertRaises(ValueError):
self.category_crud.update_category(
- conn=self.connection,
- category_id=category1["CategoryID"],
- update_data=self_update_data
+ conn=self.connection, category_id=category1["CategoryID"], update_data=self_update_data
)
def test_update_category_with_invalid_parent(self):
"""测试使用无效的父分类ID更新分类"""
# 准备数据
- test_category = self.category_crud.create_category(
- conn=self.connection,
- category_name="测试分类"
- )
+ test_category = self.category_crud.create_category(conn=self.connection, category_name="测试分类")
# 执行测试 - 尝试使用不存在的父分类ID
with self.assertRaises(ValueError):
self.category_crud.update_category(
- conn=self.connection,
- category_id=test_category["CategoryID"],
- update_data={"parentcategoryid": 999999}
+ conn=self.connection, category_id=test_category["CategoryID"], update_data={"parentcategoryid": 999999}
)
def test_delete_category(self):
"""测试删除分类"""
# 准备数据
- test_category = self.category_crud.create_category(
- conn=self.connection,
- category_name="准备删除的分类"
- )
+ test_category = self.category_crud.create_category(conn=self.connection, category_name="准备删除的分类")
# 执行测试
- result = self.category_crud.delete_category(
- conn=self.connection,
- category_id=test_category["CategoryID"]
- )
+ result = self.category_crud.delete_category(conn=self.connection, category_id=test_category["CategoryID"])
# 验证结果
self.assertTrue(result)
# 验证分类已被删除
deleted_category = self.category_crud.get_category_by_id(
- conn=self.connection,
- category_id=test_category["CategoryID"]
+ conn=self.connection, category_id=test_category["CategoryID"]
)
self.assertIsNone(deleted_category)
# 测试删除不存在的分类
- non_existent_result = self.category_crud.delete_category(
- conn=self.connection,
- category_id=999999
- )
+ non_existent_result = self.category_crud.delete_category(conn=self.connection, category_id=999999)
self.assertFalse(non_existent_result)
def test_delete_category_with_subcategories(self):
"""测试删除有子分类的分类(应该失败)"""
# 准备数据 - 创建父分类和子分类
- parent_category = self.category_crud.create_category(
- conn=self.connection,
- category_name="父分类"
- )
+ parent_category = self.category_crud.create_category(conn=self.connection, category_name="父分类")
self.category_crud.create_category(
- conn=self.connection,
- category_name="子分类",
- parent_category_id=parent_category["CategoryID"]
+ conn=self.connection, category_name="子分类", parent_category_id=parent_category["CategoryID"]
)
# 执行测试 - 尝试删除有子分类的父分类
with self.assertRaises(ValueError):
- self.category_crud.delete_category(
- conn=self.connection,
- category_id=parent_category["CategoryID"]
- )
+ self.category_crud.delete_category(conn=self.connection, category_id=parent_category["CategoryID"])
def test_delete_category_with_products(self):
"""测试删除有商品的分类(应该失败)"""
@@ -304,10 +225,7 @@ def test_delete_category_with_products(self):
with self.engine.begin() as conn:
# 创建分类
category_crud = get_category_crud_instance()
- test_category = category_crud.create_category(
- conn=conn,
- category_name="有商品的分类"
- )
+ test_category = category_crud.create_category(conn=conn, category_name="有商品的分类")
# 创建测试用户,使用随机用户名避免唯一键冲突
random_suffix = generate_random_string()
@@ -315,85 +233,72 @@ def test_delete_category_with_products(self):
test_user_email = f"{test_user_name}@example.com"
conn.execute(
- text("""
+ text(
+ """
INSERT INTO User (Username, PasswordHash, Email, UserRole)
VALUES (:username, 'password_hash', :email, 'merchant')
- """),
- {"username": test_user_name, "email": test_user_email}
+ """
+ ),
+ {"username": test_user_name, "email": test_user_email},
)
user_id_result = conn.execute(text("SELECT LAST_INSERT_ID()")).scalar()
# 创建测试店铺
conn.execute(
- text("""
+ text(
+ """
INSERT INTO Store (StoreName, OwnerUserID, Description, StoreStatus)
VALUES ('测试店铺', :user_id, '用于测试的店铺', 'ACTIVE')
- """),
- {"user_id": user_id_result}
+ """
+ ),
+ {"user_id": user_id_result},
)
store_id_result = conn.execute(text("SELECT LAST_INSERT_ID()")).scalar()
# 创建商品,关联到这个分类
conn.execute(
- text("""
+ text(
+ """
INSERT INTO Product (ProductName, Price, StoreID, CategoryID, StockQuantity, ProductStatus)
VALUES ('测试商品', 99.99, :store_id, :category_id, 10, 'ACTIVE')
- """),
- {"store_id": store_id_result, "category_id": test_category["CategoryID"]}
+ """
+ ),
+ {"store_id": store_id_result, "category_id": test_category["CategoryID"]},
)
# 执行测试 - 尝试删除有商品的分类
with self.assertRaises(ValueError):
- self.category_crud.delete_category(
- conn=self.connection,
- category_id=test_category["CategoryID"]
- )
+ self.category_crud.delete_category(conn=self.connection, category_id=test_category["CategoryID"])
def test_get_category_tree(self):
"""测试获取分类树结构"""
# 准备数据 - 创建一些分类和子分类
# 顶级分类
- cat1 = self.category_crud.create_category(
- conn=self.connection,
- category_name="电子产品"
- )
+ cat1 = self.category_crud.create_category(conn=self.connection, category_name="电子产品")
- cat2 = self.category_crud.create_category(
- conn=self.connection,
- category_name="服装"
- )
+ cat2 = self.category_crud.create_category(conn=self.connection, category_name="服装")
# 电子产品的子分类
cat1_1 = self.category_crud.create_category(
- conn=self.connection,
- category_name="手机",
- parent_category_id=cat1["CategoryID"]
+ conn=self.connection, category_name="手机", parent_category_id=cat1["CategoryID"]
)
cat1_2 = self.category_crud.create_category(
- conn=self.connection,
- category_name="电脑",
- parent_category_id=cat1["CategoryID"]
+ conn=self.connection, category_name="电脑", parent_category_id=cat1["CategoryID"]
)
# 服装的子分类
cat2_1 = self.category_crud.create_category(
- conn=self.connection,
- category_name="男装",
- parent_category_id=cat2["CategoryID"]
+ conn=self.connection, category_name="男装", parent_category_id=cat2["CategoryID"]
)
# 电脑的子分类
cat1_2_1 = self.category_crud.create_category(
- conn=self.connection,
- category_name="笔记本电脑",
- parent_category_id=cat1_2["CategoryID"]
+ conn=self.connection, category_name="笔记本电脑", parent_category_id=cat1_2["CategoryID"]
)
# 执行测试
- category_tree = self.category_crud.get_category_tree(
- conn=self.connection
- )
+ category_tree = self.category_crud.get_category_tree(conn=self.connection)
# 验证结果 - 顶级分类数量
self.assertIsInstance(category_tree, list)
@@ -419,5 +324,5 @@ def test_get_category_tree(self):
self.assertGreaterEqual(len(clothing_cat["Children"]), 1) # 至少一个子分类
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/src/backend/test/integration/crud/test_product_crud_integration.py b/src/backend/test/integration/crud/test_product_crud_integration.py
index 25e85e4..ea29c65 100644
--- a/src/backend/test/integration/crud/test_product_crud_integration.py
+++ b/src/backend/test/integration/crud/test_product_crud_integration.py
@@ -14,7 +14,8 @@
def generate_random_string(length=8):
"""生成指定长度的随机字符串"""
letters_and_digits = string.ascii_letters + string.digits
- return ''.join(random.choice(letters_and_digits) for _ in range(length))
+ return "".join(random.choice(letters_and_digits) for _ in range(length))
+
@unittest.skip("This test has been moved to test_product_crud_integration.py")
class TestProductCRUD(BaseDBTestCaseAutoRollback):
@@ -26,10 +27,7 @@ def setUp(self):
# 创建测试数据 - 创建分类
self.test_category = self.category_crud.create_category(
- conn=self.connection,
- category_name="测试分类",
- category_description="用于测试的商品分类",
- actor_id=None
+ conn=self.connection, category_name="测试分类", category_description="用于测试的商品分类", actor_id=None
)
# 创建测试数据 - 创建测试用户(使用随机用户名避免唯一键冲突)
@@ -38,22 +36,26 @@ def setUp(self):
test_email = f"{test_username}@example.com"
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO User (Username, PasswordHash, Email, UserRole)
VALUES (:username, 'password_hash', :email, 'merchant')
- """),
- {"username": test_username, "email": test_email}
+ """
+ ),
+ {"username": test_username, "email": test_email},
)
user_id_result = self.connection.execute(text("SELECT LAST_INSERT_ID()")).fetchone()
self.test_user_id = user_id_result[0]
# 创建测试数据 - 创建店铺
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO Store (StoreName, OwnerUserID, Description, StoreStatus)
VALUES ('测试店铺', :user_id, '用于测试的店铺', 'ACTIVE')
- """),
- {"user_id": self.test_user_id}
+ """
+ ),
+ {"user_id": self.test_user_id},
)
store_id_result = self.connection.execute(text("SELECT LAST_INSERT_ID()")).fetchone()
self.test_store_id = store_id_result[0]
@@ -68,14 +70,11 @@ def test_create_product(self):
"store_id": self.test_store_id,
"category_id": self.test_category["CategoryID"],
"stock_quantity": 100,
- "main_image_url": "http://example.com/test.jpg"
+ "main_image_url": "http://example.com/test.jpg",
}
# 执行测试
- product = self.product_crud.create_product(
- conn=self.connection,
- **product_data
- )
+ product = self.product_crud.create_product(conn=self.connection, **product_data)
# 验证结果
self.assertIsNotNone(product)
@@ -101,14 +100,11 @@ def test_get_product_by_id(self):
price=Decimal("199.99"),
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
- stock_quantity=50
+ stock_quantity=50,
)
# 执行测试
- product = self.product_crud.get_product_by_id(
- conn=self.connection,
- product_id=test_product["ProductID"]
- )
+ product = self.product_crud.get_product_by_id(conn=self.connection, product_id=test_product["ProductID"])
# 验证结果
self.assertIsNotNone(product)
@@ -116,10 +112,7 @@ def test_get_product_by_id(self):
self.assertEqual(product["ProductName"], "测试商品2")
# 测试获取不存在的商品
- non_existent_product = self.product_crud.get_product_by_id(
- conn=self.connection,
- product_id=999999
- )
+ non_existent_product = self.product_crud.get_product_by_id(conn=self.connection, product_id=999999)
self.assertIsNone(non_existent_product)
def test_get_product_with_category_info(self):
@@ -131,13 +124,12 @@ def test_get_product_with_category_info(self):
price=Decimal("299.99"),
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
- stock_quantity=30
+ stock_quantity=30,
)
# 执行测试
product_with_category = self.product_crud.get_product_with_category_info(
- conn=self.connection,
- product_id=test_product["ProductID"]
+ conn=self.connection, product_id=test_product["ProductID"]
)
# 验证结果
@@ -155,21 +147,19 @@ def test_update_product(self):
price=Decimal("399.99"),
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
- stock_quantity=40
+ stock_quantity=40,
)
# 准备更新数据
update_data = {
"productname": "更新后的商品名称",
"price": Decimal("499.99"),
- "productstatus": "INACTIVE_BY_MERCHANT"
+ "productstatus": "INACTIVE_BY_MERCHANT",
}
# 执行测试
updated_product = self.product_crud.update_product(
- conn=self.connection,
- product_id=test_product["ProductID"],
- update_data=update_data
+ conn=self.connection, product_id=test_product["ProductID"], update_data=update_data
)
# 验证结果
@@ -180,9 +170,7 @@ def test_update_product(self):
# 测试更新不存在的商品
non_existent_update = self.product_crud.update_product(
- conn=self.connection,
- product_id=999999,
- update_data={"productname": "不存在的商品"}
+ conn=self.connection, product_id=999999, update_data={"productname": "不存在的商品"}
)
# 现在应该返回None,因为商品不存在
self.assertIsNone(non_existent_update)
@@ -196,14 +184,12 @@ def test_update_product_stock(self):
price=Decimal("599.99"),
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
- stock_quantity=100
+ stock_quantity=100,
)
# 执行测试 - 增加库存
updated_product = self.product_crud.update_product_stock(
- conn=self.connection,
- product_id=test_product["ProductID"],
- stock_change=50
+ conn=self.connection, product_id=test_product["ProductID"], stock_change=50
)
# 验证结果
@@ -212,9 +198,7 @@ def test_update_product_stock(self):
# 执行测试 - 减少库存
updated_product = self.product_crud.update_product_stock(
- conn=self.connection,
- product_id=test_product["ProductID"],
- stock_change=-30
+ conn=self.connection, product_id=test_product["ProductID"], stock_change=-30
)
# 验证结果
@@ -224,9 +208,7 @@ def test_update_product_stock(self):
# 执行测试 - 减少超过当前库存的量
with self.assertRaises(InsufficientStockException):
updated_product = self.product_crud.update_product_stock(
- conn=self.connection,
- product_id=test_product["ProductID"],
- stock_change=-200
+ conn=self.connection, product_id=test_product["ProductID"], stock_change=-200
)
# 验证结果 - 库存未更新
@@ -242,15 +224,12 @@ def test_get_products_by_store_id(self):
price=Decimal(f"{(i + 1) * 100}.99"),
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
- stock_quantity=10 * (i + 1)
+ stock_quantity=10 * (i + 1),
)
# 执行测试
products = self.product_crud.get_products_by_store_id(
- conn=self.connection,
- store_id=self.test_store_id,
- limit=10,
- offset=0
+ conn=self.connection, store_id=self.test_store_id, limit=10, offset=0
)
# 验证结果
@@ -261,10 +240,7 @@ def test_get_products_by_store_id(self):
# 测试分页
products_page = self.product_crud.get_products_by_store_id(
- conn=self.connection,
- store_id=self.test_store_id,
- limit=2,
- offset=2
+ conn=self.connection, store_id=self.test_store_id, limit=2, offset=2
)
self.assertEqual(len(products_page), 2)
@@ -275,7 +251,7 @@ def test_get_products_by_category_id(self):
conn=self.connection,
category_name="新测试分类",
category_description="用于测试分类查询的商品分类",
- actor_id=None
+ actor_id=None,
)
# 创建不同分类下的商品
@@ -286,7 +262,7 @@ def test_get_products_by_category_id(self):
price=Decimal(f"{(i + 1) * 100}.99"),
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
- stock_quantity=10 * (i + 1)
+ stock_quantity=10 * (i + 1),
)
for i in range(2):
@@ -296,18 +272,16 @@ def test_get_products_by_category_id(self):
price=Decimal(f"{(i + 1) * 200}.99"),
store_id=self.test_store_id,
category_id=new_category["CategoryID"],
- stock_quantity=20 * (i + 1)
+ stock_quantity=20 * (i + 1),
)
# 执行测试
products_cat1 = self.product_crud.get_products_by_category_id(
- conn=self.connection,
- category_id=self.test_category["CategoryID"]
+ conn=self.connection, category_id=self.test_category["CategoryID"]
)
products_cat2 = self.product_crud.get_products_by_category_id(
- conn=self.connection,
- category_id=new_category["CategoryID"]
+ conn=self.connection, category_id=new_category["CategoryID"]
)
# 验证结果
@@ -324,7 +298,7 @@ def test_search_products(self):
price=Decimal("5999.99"),
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
- stock_quantity=50
+ stock_quantity=50,
)
self.product_crud.create_product(
@@ -334,7 +308,7 @@ def test_search_products(self):
price=Decimal("3999.99"),
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
- stock_quantity=30
+ stock_quantity=30,
)
self.product_crud.create_product(
@@ -344,24 +318,15 @@ def test_search_products(self):
price=Decimal("4999.99"),
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
- stock_quantity=40
+ stock_quantity=40,
)
# 执行测试 - 按名称搜索
- apple_products = self.product_crud.search_products(
- conn=self.connection,
- search_term="苹果"
- )
+ apple_products = self.product_crud.search_products(conn=self.connection, search_term="苹果")
- huawei_products = self.product_crud.search_products(
- conn=self.connection,
- search_term="华为"
- )
+ huawei_products = self.product_crud.search_products(conn=self.connection, search_term="华为")
- tablet_products = self.product_crud.search_products(
- conn=self.connection,
- search_term="平板"
- )
+ tablet_products = self.product_crud.search_products(conn=self.connection, search_term="平板")
# 验证结果
self.assertEqual(len(apple_products), 2) # "苹果手机"和"苹果平板"
@@ -377,22 +342,18 @@ def test_delete_product(self):
price=Decimal("99.99"),
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
- stock_quantity=10
+ stock_quantity=10,
)
# 执行测试
- result = self.product_crud.delete_product(
- conn=self.connection,
- product_id=test_product["ProductID"]
- )
+ result = self.product_crud.delete_product(conn=self.connection, product_id=test_product["ProductID"])
# 验证结果
self.assertTrue(result)
# 获取更新后的商品信息
deleted_product = self.product_crud.get_product_by_id(
- conn=self.connection,
- product_id=test_product["ProductID"]
+ conn=self.connection, product_id=test_product["ProductID"]
)
# 验证商品状态变为DISCONTINUED
@@ -400,10 +361,7 @@ def test_delete_product(self):
self.assertEqual(deleted_product["ProductStatus"], "DISCONTINUED")
# 测试删除不存在的商品
- non_existent_result = self.product_crud.delete_product(
- conn=self.connection,
- product_id=999999
- )
+ non_existent_result = self.product_crud.delete_product(conn=self.connection, product_id=999999)
self.assertFalse(non_existent_result)
def test_get_products_by_store_and_category(self):
@@ -413,7 +371,7 @@ def test_get_products_by_store_and_category(self):
conn=self.connection,
category_name="测试分类2",
category_description="第二个用于测试的商品分类",
- actor_id=None
+ actor_id=None,
)
# 创建商品到第一个测试分类和测试店铺
@@ -424,7 +382,7 @@ def test_get_products_by_store_and_category(self):
price=Decimal(f"{(i + 1) * 100}.99"),
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
- stock_quantity=10 * (i + 1)
+ stock_quantity=10 * (i + 1),
)
# 创建商品到第二个测试分类和测试店铺
@@ -435,16 +393,18 @@ def test_get_products_by_store_and_category(self):
price=Decimal(f"{(i + 1) * 200}.99"),
store_id=self.test_store_id,
category_id=new_category["CategoryID"],
- stock_quantity=20 * (i + 1)
+ stock_quantity=20 * (i + 1),
)
# 创建第二个店铺
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO Store (StoreName, OwnerUserID, Description, StoreStatus)
VALUES ('测试店铺2', :user_id, '第二个用于测试的店铺', 'ACTIVE')
- """),
- {"user_id": self.test_user_id}
+ """
+ ),
+ {"user_id": self.test_user_id},
)
store_id_result = self.connection.execute(text("SELECT LAST_INSERT_ID()")).fetchone()
test_store_id_2 = store_id_result[0]
@@ -457,26 +417,20 @@ def test_get_products_by_store_and_category(self):
price=Decimal(f"{(i + 1) * 150}.99"),
store_id=test_store_id_2,
category_id=self.test_category["CategoryID"],
- stock_quantity=15 * (i + 1)
+ stock_quantity=15 * (i + 1),
)
# 执行测试
products_store1_cat1 = self.product_crud.get_products_by_store_and_category(
- conn=self.connection,
- store_id=self.test_store_id,
- category_id=self.test_category["CategoryID"]
+ conn=self.connection, store_id=self.test_store_id, category_id=self.test_category["CategoryID"]
)
products_store1_cat2 = self.product_crud.get_products_by_store_and_category(
- conn=self.connection,
- store_id=self.test_store_id,
- category_id=new_category["CategoryID"]
+ conn=self.connection, store_id=self.test_store_id, category_id=new_category["CategoryID"]
)
products_store2_cat1 = self.product_crud.get_products_by_store_and_category(
- conn=self.connection,
- store_id=test_store_id_2,
- category_id=self.test_category["CategoryID"]
+ conn=self.connection, store_id=test_store_id_2, category_id=self.test_category["CategoryID"]
)
# 验证结果
@@ -498,36 +452,33 @@ def test_get_filtered_products(self):
price=Decimal(str(price)),
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
- stock_quantity=100
+ stock_quantity=100,
)
# 如果需要非默认状态,则更新状态
if status != "ACTIVE":
self.connection.execute(
- text(f"""
+ text(
+ f"""
UPDATE Product SET ProductStatus = :status
WHERE ProductID = :product_id
- """),
- {"status": status, "product_id": product["ProductID"]}
+ """
+ ),
+ {"status": status, "product_id": product["ProductID"]},
)
# 测试价格区间过滤
- low_price_products = self.product_crud.get_filtered_products(
- conn=self.connection,
- max_price=Decimal("200.00")
- )
+ low_price_products = self.product_crud.get_filtered_products(conn=self.connection, max_price=Decimal("200.00"))
# 获取不同状态的高价格商品
high_price_active_products = self.product_crud.get_filtered_products(
- conn=self.connection,
- min_price=Decimal("300.00"),
- product_status="ACTIVE" # 显式查询ACTIVE状态的高价格商品
+ conn=self.connection, min_price=Decimal("300.00"), product_status="ACTIVE" # 显式查询ACTIVE状态的高价格商品
)
high_price_discontinued_products = self.product_crud.get_filtered_products(
conn=self.connection,
min_price=Decimal("300.00"),
- product_status="DISCONTINUED" # 显式查询DISCONTINUED状态的高价格商品
+ product_status="DISCONTINUED", # 显式查询DISCONTINUED状态的高价格商品
)
# 合并所有高价格商品结果
@@ -538,40 +489,34 @@ def test_get_filtered_products(self):
conn=self.connection,
min_price=Decimal("150.00"),
max_price=Decimal("350.00"),
- product_status="ACTIVE" # 显式查询ACTIVE状态的中价格商品
+ product_status="ACTIVE", # 显式查询ACTIVE状态的中价格商品
)
mid_price_inactive_products = self.product_crud.get_filtered_products(
conn=self.connection,
min_price=Decimal("150.00"),
max_price=Decimal("350.00"),
- product_status="INACTIVE_BY_MERCHANT" # 显式查询INACTIVE_BY_MERCHANT状态的中价格商品
+ product_status="INACTIVE_BY_MERCHANT", # 显式查询INACTIVE_BY_MERCHANT状态的中价格商品
)
# 合并所有中价格商品结果
mid_price_products = mid_price_active_products + mid_price_inactive_products
# 测试状态过滤
- active_products = self.product_crud.get_filtered_products(
- conn=self.connection,
- product_status="ACTIVE"
- )
+ active_products = self.product_crud.get_filtered_products(conn=self.connection, product_status="ACTIVE")
inactive_products = self.product_crud.get_filtered_products(
- conn=self.connection,
- product_status="INACTIVE_BY_MERCHANT"
+ conn=self.connection, product_status="INACTIVE_BY_MERCHANT"
)
# 测试排序
asc_price_products = self.product_crud.get_filtered_products(
- conn=self.connection,
- order_by="price_asc" # 使用新的排序格式
- )
+ conn=self.connection, order_by="price_asc"
+ ) # 使用新的排序格式
desc_price_products = self.product_crud.get_filtered_products(
- conn=self.connection,
- order_by="price_desc" # 使用新的排序格式
- )
+ conn=self.connection, order_by="price_desc"
+ ) # 使用新的排序格式
# 组合多重过滤和排序
filtered_sorted_products = self.product_crud.get_filtered_products(
@@ -579,7 +524,7 @@ def test_get_filtered_products(self):
min_price=Decimal("100.00"),
max_price=Decimal("400.00"),
product_status="ACTIVE",
- order_by="price_desc" # 使用新的排序格式
+ order_by="price_desc", # 使用新的排序格式
)
# 验证结果
@@ -593,12 +538,20 @@ def test_get_filtered_products(self):
self.assertEqual(len(inactive_products), 1) # INACTIVE_BY_MERCHANT状态的商品
# 验证价格升序排序
- self.assertTrue(all(asc_price_products[i]["Price"] <= asc_price_products[i + 1]["Price"]
- for i in range(len(asc_price_products) - 1)))
+ self.assertTrue(
+ all(
+ asc_price_products[i]["Price"] <= asc_price_products[i + 1]["Price"]
+ for i in range(len(asc_price_products) - 1)
+ )
+ )
# 验证价格降序排序
- self.assertTrue(all(desc_price_products[i]["Price"] >= desc_price_products[i + 1]["Price"]
- for i in range(len(desc_price_products) - 1)))
+ self.assertTrue(
+ all(
+ desc_price_products[i]["Price"] >= desc_price_products[i + 1]["Price"]
+ for i in range(len(desc_price_products) - 1)
+ )
+ )
# 验证组合过滤和排序
self.assertEqual(len(filtered_sorted_products), 2) # 符合条件的商品
@@ -606,5 +559,5 @@ def test_get_filtered_products(self):
self.assertTrue(filtered_sorted_products[0]["Price"] > filtered_sorted_products[1]["Price"]) # 降序
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/src/backend/test/integration/crud/test_user_crud_integration.py b/src/backend/test/integration/crud/test_user_crud_integration.py
index 5d2e8b6..142ad61 100644
--- a/src/backend/test/integration/crud/test_user_crud_integration.py
+++ b/src/backend/test/integration/crud/test_user_crud_integration.py
@@ -51,7 +51,6 @@ def test_create_user_return_value(self):
self.assertIsNone(user["Email"])
self.assertIsNotNone(user["RegistrationDate"])
-
def test_create_user_with_all_fields(self):
# Test creating a user with all fields
user_data = {
@@ -100,7 +99,5 @@ def test_create_users_with_same_username(self):
self.assertIn("Duplicate", str(e))
-
-
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/src/backend/test/integration/crud/test_user_session_crud_integration.py b/src/backend/test/integration/crud/test_user_session_crud_integration.py
index 295346b..13f9afe 100644
--- a/src/backend/test/integration/crud/test_user_session_crud_integration.py
+++ b/src/backend/test/integration/crud/test_user_session_crud_integration.py
@@ -11,6 +11,7 @@
# No more @patch for datetime needed for these tests if we use real time.
# from unittest.mock import patch, MagicMock, ANY
+
class TestUserSessionCRUDIntegration(BaseDBTestCaseAutoRollback):
def setUp(self):
@@ -20,22 +21,23 @@ def setUp(self):
try:
user_exists_result = self.connection.execute(
- text("SELECT UserID FROM `User` WHERE UserID = :user_id"),
- {"user_id": self.test_user_id}
+ text("SELECT UserID FROM `User` WHERE UserID = :user_id"), {"user_id": self.test_user_id}
).fetchone()
if not user_exists_result:
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO `User` (UserID, Username, PasswordHash, Email, UserRole, AccountStatus, RegistrationDate)
VALUES (:user_id, :username, :password_hash, :email, 'customer', 'ACTIVE', UTC_TIMESTAMP())
- """),
+ """
+ ),
{
"user_id": self.test_user_id,
"username": f"integ_user_{self.test_user_id}",
"password_hash": "dummy_hash_for_testing",
- "email": f"integ_user{self.test_user_id}@example.com"
- }
+ "email": f"integ_user{self.test_user_id}@example.com",
+ },
)
except Exception as e:
self.fail(f"Failed to set up prerequisite test user (UserID={self.test_user_id}): {e}")
@@ -52,7 +54,7 @@ def test_create_session_raises_error_for_naive_datetime_expires_at(self):
user_id=self.test_user_id,
expires_at=naive_expires_at,
ip_address="192.168.1.1",
- user_agent="TestAgent Naive Error"
+ user_agent="TestAgent Naive Error",
)
retrieved_session = self.user_session_crud.get_session_by_token(self.connection, token=token_to_create)
self.assertIsNone(retrieved_session, "Session should not have been created with naive datetime.")
@@ -71,7 +73,7 @@ def test_create_session_with_aware_utc_datetime_expires_at(self):
user_id=self.test_user_id,
expires_at=aware_utc_expires_at_input,
ip_address="192.168.1.2",
- user_agent="TestAgent Aware UTC Valid"
+ user_agent="TestAgent Aware UTC Valid",
)
self.assertIsNotNone(created_session)
@@ -82,8 +84,9 @@ def test_create_session_with_aware_utc_datetime_expires_at(self):
# 由于数据库 NOW() 和 Python now() 可能有微秒差异,比较时可以允许一定容差
# 或者,如果 create_session 返回的 ExpiresAt 是它直接存入的值,可以精确比较
self.assertIsInstance(created_session["ExpiresAt"], datetime.datetime)
- self.assertAlmostEqual(created_session["ExpiresAt"], expected_stored_naive_utc,
- delta=datetime.timedelta(seconds=1))
+ self.assertAlmostEqual(
+ created_session["ExpiresAt"], expected_stored_naive_utc, delta=datetime.timedelta(seconds=1)
+ )
self.assertIsNotNone(created_session["CreatedAt"])
self.assertIsInstance(created_session["CreatedAt"], datetime.datetime)
@@ -93,8 +96,9 @@ def test_create_session_with_aware_utc_datetime_expires_at(self):
retrieved_session = self.user_session_crud.get_session_by_token(self.connection, token=token_to_create)
self.assertIsNotNone(retrieved_session)
self.assertIsInstance(retrieved_session["ExpiresAt"], datetime.datetime)
- self.assertAlmostEqual(retrieved_session["ExpiresAt"], expected_stored_naive_utc,
- delta=datetime.timedelta(seconds=1))
+ self.assertAlmostEqual(
+ retrieved_session["ExpiresAt"], expected_stored_naive_utc, delta=datetime.timedelta(seconds=1)
+ )
def test_create_session_with_aware_cst_datetime_expires_at(self):
"""测试使用 timezone-aware CST (UTC+8) datetime 创建会话"""
@@ -110,21 +114,23 @@ def test_create_session_with_aware_cst_datetime_expires_at(self):
user_id=self.test_user_id,
expires_at=aware_cst_expires_at_input,
ip_address="192.168.1.3",
- user_agent="TestAgent Aware CST Valid"
+ user_agent="TestAgent Aware CST Valid",
)
self.assertIsNotNone(created_session)
self.assertEqual(created_session["SessionToken"], token_to_create)
expected_stored_naive_utc = aware_cst_expires_at_input.astimezone(datetime.timezone.utc).replace(tzinfo=None)
self.assertIsInstance(created_session["ExpiresAt"], datetime.datetime)
- self.assertAlmostEqual(created_session["ExpiresAt"], expected_stored_naive_utc,
- delta=datetime.timedelta(seconds=1))
+ self.assertAlmostEqual(
+ created_session["ExpiresAt"], expected_stored_naive_utc, delta=datetime.timedelta(seconds=1)
+ )
retrieved_session = self.user_session_crud.get_session_by_token(self.connection, token=token_to_create)
self.assertIsNotNone(retrieved_session)
self.assertIsInstance(retrieved_session["ExpiresAt"], datetime.datetime)
- self.assertAlmostEqual(retrieved_session["ExpiresAt"], expected_stored_naive_utc,
- delta=datetime.timedelta(seconds=1))
+ self.assertAlmostEqual(
+ retrieved_session["ExpiresAt"], expected_stored_naive_utc, delta=datetime.timedelta(seconds=1)
+ )
def test_get_session_by_token_non_existent(self):
not_found_session = self.user_session_crud.get_session_by_token(self.connection, token="token_does_not_exist")
@@ -174,8 +180,10 @@ def test_get_active_user_id_and_update_access_expired_session(self):
aware_utc_expires_at_past_for_create = now_utc_for_create - datetime.timedelta(hours=1)
self.user_session_crud.create_session(
- self.connection, session_token=token, user_id=self.test_user_id,
- expires_at=aware_utc_expires_at_past_for_create
+ self.connection,
+ session_token=token,
+ user_id=self.test_user_id,
+ expires_at=aware_utc_expires_at_past_for_create,
)
active_user_id = self.user_session_crud.get_active_user_id_by_token_and_update_access(
@@ -212,8 +220,9 @@ def test_get_active_user_id_update_access_and_expiration(self):
expected_stored_expires_at_naive = new_expires_at_param_aware.replace(tzinfo=None)
self.assertIsInstance(session_after_update["ExpiresAt"], datetime.datetime)
- self.assertAlmostEqual(session_after_update["ExpiresAt"], expected_stored_expires_at_naive,
- delta=datetime.timedelta(seconds=1))
+ self.assertAlmostEqual(
+ session_after_update["ExpiresAt"], expected_stored_expires_at_naive, delta=datetime.timedelta(seconds=1)
+ )
self.assertIsInstance(session_after_update["LastAccessedAt"], datetime.datetime)
if initial_db_last_accessed_at: # type: ignore
@@ -232,14 +241,16 @@ def test_delete_session_by_token(self):
)
self.assertIsNotNone(self.user_session_crud.get_session_by_token(self.connection, token=token_to_delete))
- deleted = self.user_session_crud.delete_session_by_token(self.connection, token=token_to_delete,
- actor_user_id=self.test_user_id)
+ deleted = self.user_session_crud.delete_session_by_token(
+ self.connection, token=token_to_delete, actor_user_id=self.test_user_id
+ )
self.assertTrue(deleted)
self.assertIsNone(self.user_session_crud.get_session_by_token(self.connection, token=token_to_delete))
- deleted_non_existent = self.user_session_crud.delete_session_by_token(self.connection,
- token="does_not_exist_realtime")
+ deleted_non_existent = self.user_session_crud.delete_session_by_token(
+ self.connection, token="does_not_exist_realtime"
+ )
self.assertFalse(deleted_non_existent)
def test_delete_all_sessions_for_user(self):
@@ -247,36 +258,55 @@ def test_delete_all_sessions_for_user(self):
other_user_id = self.test_user_id + 1
try:
- user_exists_result = self.connection.execute(text("SELECT UserID FROM `User` WHERE UserID = :id"),
- {"id": other_user_id}).fetchone()
+ user_exists_result = self.connection.execute(
+ text("SELECT UserID FROM `User` WHERE UserID = :id"), {"id": other_user_id}
+ ).fetchone()
if not user_exists_result:
self.connection.execute(
text(
- "INSERT INTO `User` (UserID, Username, PasswordHash, Email, UserRole, AccountStatus, RegistrationDate) VALUES (:id, :uname, 'hash', :email, 'customer', 'ACTIVE', UTC_TIMESTAMP())"),
- {"id": other_user_id, "uname": f"other_user_{other_user_id}",
- "email": f"other{other_user_id}@test.com"}
+ "INSERT INTO `User` (UserID, Username, PasswordHash, Email, UserRole, AccountStatus, RegistrationDate) VALUES (:id, :uname, 'hash', :email, 'customer', 'ACTIVE', UTC_TIMESTAMP())"
+ ),
+ {
+ "id": other_user_id,
+ "uname": f"other_user_{other_user_id}",
+ "email": f"other{other_user_id}@test.com",
+ },
)
except Exception as e:
self.fail(f"Failed to set up other_user for test_delete_all_sessions_for_user: {e}")
aware_utc_expires_at = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1)
- self.user_session_crud.create_session(self.connection, session_token="user1_token1_realtime",
- user_id=user_to_clear_id, expires_at=aware_utc_expires_at)
- self.user_session_crud.create_session(self.connection, session_token="user1_token2_realtime",
- user_id=user_to_clear_id, expires_at=aware_utc_expires_at)
- self.user_session_crud.create_session(self.connection, session_token="user2_token1_realtime",
- user_id=other_user_id, expires_at=aware_utc_expires_at)
-
- deleted_count = self.user_session_crud.delete_all_sessions_for_user(self.connection, user_id=user_to_clear_id,
- actor_user_id=user_to_clear_id)
+ self.user_session_crud.create_session(
+ self.connection,
+ session_token="user1_token1_realtime",
+ user_id=user_to_clear_id,
+ expires_at=aware_utc_expires_at,
+ )
+ self.user_session_crud.create_session(
+ self.connection,
+ session_token="user1_token2_realtime",
+ user_id=user_to_clear_id,
+ expires_at=aware_utc_expires_at,
+ )
+ self.user_session_crud.create_session(
+ self.connection,
+ session_token="user2_token1_realtime",
+ user_id=other_user_id,
+ expires_at=aware_utc_expires_at,
+ )
+
+ deleted_count = self.user_session_crud.delete_all_sessions_for_user(
+ self.connection, user_id=user_to_clear_id, actor_user_id=user_to_clear_id
+ )
self.assertEqual(deleted_count, 2)
self.assertIsNone(self.user_session_crud.get_session_by_token(self.connection, token="user1_token1_realtime"))
self.assertIsNone(self.user_session_crud.get_session_by_token(self.connection, token="user1_token2_realtime"))
self.assertIsNotNone(
- self.user_session_crud.get_session_by_token(self.connection, token="user2_token1_realtime"))
+ self.user_session_crud.get_session_by_token(self.connection, token="user2_token1_realtime")
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/src/backend/test/integration/services/test_cart_service_integration.py b/src/backend/test/integration/services/test_cart_service_integration.py
index 9765943..b86b1e6 100644
--- a/src/backend/test/integration/services/test_cart_service_integration.py
+++ b/src/backend/test/integration/services/test_cart_service_integration.py
@@ -10,8 +10,11 @@
from backend.app.crud.cartitem_crud import CartItemCRUD # Assuming CartItemCRUD is used by CartService
from backend.app.crud.product_crud import ProductCRUD # Assuming ProductCRUD is used by CartService
from backend.app.schemas.cartitem_schema import CartResponse, CartItemResponse
-from backend.app.utils.exceptions import ProductNotFoundException, CartItemNotFoundException, \
- ProductFieldMissingException
+from backend.app.utils.exceptions import (
+ ProductNotFoundException,
+ CartItemNotFoundException,
+ ProductFieldMissingException,
+)
class TestCartServiceIntegration(AsyncBaseDBTestCaseAutoRollback):
@@ -24,10 +27,7 @@ def setUp(self):
self.cart_item_crud = CartItemCRUD.get_instance()
self.product_crud = ProductCRUD.get_instance()
- self.cart_service = CartService(
- cart_item_crud=self.cart_item_crud,
- product_crud=self.product_crud
- )
+ self.cart_service = CartService(cart_item_crud=self.cart_item_crud, product_crud=self.product_crud)
self.test_user_id = 1
self.another_user_id = 2
@@ -50,10 +50,16 @@ def _setup_initial_data(self):
try:
# 1. 创建测试用户
users_to_create = [
- {"UserID": self.test_user_id, "Username": "cart_svc_integ_user1",
- "Email": "cartsvc_integ1@example.com"},
- {"UserID": self.another_user_id, "Username": "cart_svc_integ_user2",
- "Email": "cartsvc_integ2@example.com"}
+ {
+ "UserID": self.test_user_id,
+ "Username": "cart_svc_integ_user1",
+ "Email": "cartsvc_integ1@example.com",
+ },
+ {
+ "UserID": self.another_user_id,
+ "Username": "cart_svc_integ_user2",
+ "Email": "cartsvc_integ2@example.com",
+ },
]
for user_data in users_to_create:
user_exists = self.connection.execute(
@@ -62,55 +68,74 @@ def _setup_initial_data(self):
if not user_exists:
self.connection.execute(
text(
- "INSERT INTO `User` (UserID, Username, PasswordHash, Email, UserRole, AccountStatus, RegistrationDate) VALUES (:UserID, :Username, 'hash', :Email, 'customer', 'ACTIVE', UTC_TIMESTAMP())"),
- user_data
+ "INSERT INTO `User` (UserID, Username, PasswordHash, Email, UserRole, AccountStatus, RegistrationDate) VALUES (:UserID, :Username, 'hash', :Email, 'customer', 'ACTIVE', UTC_TIMESTAMP())"
+ ),
+ user_data,
)
# 2. 创建商品分类
self.category1_id = 1010
- cat_exists = self.connection.execute(text("SELECT CategoryID FROM ProductCategory WHERE CategoryID = :id"),
- {"id": self.category1_id}).fetchone()
+ cat_exists = self.connection.execute(
+ text("SELECT CategoryID FROM ProductCategory WHERE CategoryID = :id"), {"id": self.category1_id}
+ ).fetchone()
if not cat_exists:
self.connection.execute(
text(
- "INSERT INTO ProductCategory (CategoryID, CategoryName) VALUES (:id, :name) ON DUPLICATE KEY UPDATE CategoryName = VALUES(CategoryName)"),
- {"id": self.category1_id, "name": "服务集成测试分类"}
+ "INSERT INTO ProductCategory (CategoryID, CategoryName) VALUES (:id, :name) ON DUPLICATE KEY UPDATE CategoryName = VALUES(CategoryName)"
+ ),
+ {"id": self.category1_id, "name": "服务集成测试分类"},
)
# 3. 创建店铺 (属于 test_user_id)
self.store1_id = 2010
- store_exists = self.connection.execute(text("SELECT StoreID FROM Store WHERE StoreID = :id"),
- {"id": self.store1_id}).fetchone()
+ store_exists = self.connection.execute(
+ text("SELECT StoreID FROM Store WHERE StoreID = :id"), {"id": self.store1_id}
+ ).fetchone()
if not store_exists:
self.connection.execute(
text(
- "INSERT INTO Store (StoreID, StoreName, OwnerUserID, StoreStatus, CreationDate, LastUpdatedDate) VALUES (:id, :name, :owner_id, 'ACTIVE', UTC_TIMESTAMP(), UTC_TIMESTAMP()) ON DUPLICATE KEY UPDATE StoreName = VALUES(StoreName)"),
- {"id": self.store1_id, "name": "服务集成测试店铺", "owner_id": self.test_user_id}
+ "INSERT INTO Store (StoreID, StoreName, OwnerUserID, StoreStatus, CreationDate, LastUpdatedDate) VALUES (:id, :name, :owner_id, 'ACTIVE', UTC_TIMESTAMP(), UTC_TIMESTAMP()) ON DUPLICATE KEY UPDATE StoreName = VALUES(StoreName)"
+ ),
+ {"id": self.store1_id, "name": "服务集成测试店铺", "owner_id": self.test_user_id},
)
# 4. 创建商品
self.product1_id = 3010
self.product1_price = Decimal("19.99")
- prod1_exists = self.connection.execute(text("SELECT ProductID FROM Product WHERE ProductID = :id"),
- {"id": self.product1_id}).fetchone()
+ prod1_exists = self.connection.execute(
+ text("SELECT ProductID FROM Product WHERE ProductID = :id"), {"id": self.product1_id}
+ ).fetchone()
if not prod1_exists:
self.connection.execute(
text(
- "INSERT INTO Product (ProductID, ProductName, Price, ProductStatus, StoreID, CategoryID, StockQuantity, CreationDate, LastUpdatedDate) VALUES (:id, :name, :price, 'ACTIVE', :store_id, :cat_id, 100, UTC_TIMESTAMP(), UTC_TIMESTAMP()) ON DUPLICATE KEY UPDATE ProductName = VALUES(ProductName)"),
- {"id": self.product1_id, "name": "服务集成商品A", "price": self.product1_price,
- "store_id": self.store1_id, "cat_id": self.category1_id}
+ "INSERT INTO Product (ProductID, ProductName, Price, ProductStatus, StoreID, CategoryID, StockQuantity, CreationDate, LastUpdatedDate) VALUES (:id, :name, :price, 'ACTIVE', :store_id, :cat_id, 100, UTC_TIMESTAMP(), UTC_TIMESTAMP()) ON DUPLICATE KEY UPDATE ProductName = VALUES(ProductName)"
+ ),
+ {
+ "id": self.product1_id,
+ "name": "服务集成商品A",
+ "price": self.product1_price,
+ "store_id": self.store1_id,
+ "cat_id": self.category1_id,
+ },
)
self.product2_id = 3020
self.product2_price = Decimal("45.50")
- prod2_exists = self.connection.execute(text("SELECT ProductID FROM Product WHERE ProductID = :id"),
- {"id": self.product2_id}).fetchone()
+ prod2_exists = self.connection.execute(
+ text("SELECT ProductID FROM Product WHERE ProductID = :id"), {"id": self.product2_id}
+ ).fetchone()
if not prod2_exists:
self.connection.execute(
text(
- "INSERT INTO Product (ProductID, ProductName, Price, ProductStatus, StoreID, CategoryID, StockQuantity, CreationDate, LastUpdatedDate) VALUES (:id, :name, :price, 'ACTIVE', :store_id, :cat_id, 50, UTC_TIMESTAMP(), UTC_TIMESTAMP()) ON DUPLICATE KEY UPDATE ProductName = VALUES(ProductName)"),
- {"id": self.product2_id, "name": "服务集成商品B", "price": self.product2_price,
- "store_id": self.store1_id, "cat_id": self.category1_id}
+ "INSERT INTO Product (ProductID, ProductName, Price, ProductStatus, StoreID, CategoryID, StockQuantity, CreationDate, LastUpdatedDate) VALUES (:id, :name, :price, 'ACTIVE', :store_id, :cat_id, 50, UTC_TIMESTAMP(), UTC_TIMESTAMP()) ON DUPLICATE KEY UPDATE ProductName = VALUES(ProductName)"
+ ),
+ {
+ "id": self.product2_id,
+ "name": "服务集成商品B",
+ "price": self.product2_price,
+ "store_id": self.store1_id,
+ "cat_id": self.category1_id,
+ },
)
except Exception as e:
self.fail(f"Error setting up initial test data for CartService integration tests: {e}")
@@ -127,25 +152,32 @@ async def test_add_item_to_cart_new_and_get_details(self):
user_id=self.test_user_id,
product_id=self.product1_id,
quantity_to_add=2,
- actor_id=self.actor_id
+ actor_id=self.actor_id,
)
self.assertIsInstance(added_item_a_response, CartItemResponse)
self.assertEqual(added_item_a_response.ProductID, self.product1_id)
self.assertEqual(added_item_a_response.Quantity, 2)
- db_item_a = self.cart_item_crud.get_cart_item_by_id(self.connection,
- cart_item_id=added_item_a_response.CartItemID)
+ db_item_a = self.cart_item_crud.get_cart_item_by_id(
+ self.connection, cart_item_id=added_item_a_response.CartItemID
+ )
self.assertIsNotNone(db_item_a)
self.assertEqual(db_item_a["Quantity"], 2) # type: ignore
async def test_add_item_to_cart_update_existing_quantity(self):
await self.cart_service.add_item_to_cart(
- db=self.connection, user_id=self.test_user_id, product_id=self.product1_id,
- quantity_to_add=1, actor_id=self.actor_id
+ db=self.connection,
+ user_id=self.test_user_id,
+ product_id=self.product1_id,
+ quantity_to_add=1,
+ actor_id=self.actor_id,
)
updated_item_response = await self.cart_service.add_item_to_cart(
- db=self.connection, user_id=self.test_user_id, product_id=self.product1_id,
- quantity_to_add=3, actor_id=self.actor_id
+ db=self.connection,
+ user_id=self.test_user_id,
+ product_id=self.product1_id,
+ quantity_to_add=3,
+ actor_id=self.actor_id,
)
self.assertEqual(updated_item_response.ProductID, self.product1_id)
self.assertEqual(updated_item_response.Quantity, 1 + 3)
@@ -153,8 +185,11 @@ async def test_add_item_to_cart_update_existing_quantity(self):
async def test_add_item_to_cart_product_not_found(self):
with self.assertRaises(ProductNotFoundException):
await self.cart_service.add_item_to_cart(
- db=self.connection, user_id=self.test_user_id, product_id=9999,
- quantity_to_add=1, actor_id=self.actor_id
+ db=self.connection,
+ user_id=self.test_user_id,
+ product_id=9999,
+ quantity_to_add=1,
+ actor_id=self.actor_id,
)
async def test_add_item_to_cart_product_missing_price(self):
@@ -170,21 +205,23 @@ async def test_add_item_to_cart_product_missing_price(self):
"ProductStatus": "ACTIVE",
"StoreID": self.store1_id,
"CategoryID": self.category1_id,
- "StockQuantity": 100
+ "StockQuantity": 100,
}
# Patch the get_product_by_id method of the self.product_crud instance
# specifically for this test.
- with patch.object(self.product_crud, 'get_product_by_id',
- return_value=mock_product_data_missing_price) as mock_get_product:
- with self.assertRaisesRegex(ProductFieldMissingException,
- f"Price not available for product {self.product1_id}"):
+ with patch.object(
+ self.product_crud, "get_product_by_id", return_value=mock_product_data_missing_price
+ ) as mock_get_product:
+ with self.assertRaisesRegex(
+ ProductFieldMissingException, f"Price not available for product {self.product1_id}"
+ ):
await self.cart_service.add_item_to_cart(
db=self.connection,
user_id=self.test_user_id,
product_id=self.product1_id, # Use an existing product ID
quantity_to_add=1,
- actor_id=self.actor_id
+ actor_id=self.actor_id,
)
mock_get_product.assert_called_once_with(
conn=self.connection, product_id=self.product1_id, actor_id=self.actor_id
@@ -193,20 +230,28 @@ async def test_add_item_to_cart_product_missing_price(self):
async def test_add_item_to_cart_invalid_quantity(self):
with self.assertRaisesRegex(ValueError, "Quantity to add must be positive."):
await self.cart_service.add_item_to_cart(
- db=self.connection, user_id=self.test_user_id, product_id=self.product1_id,
- quantity_to_add=0, actor_id=self.actor_id
+ db=self.connection,
+ user_id=self.test_user_id,
+ product_id=self.product1_id,
+ quantity_to_add=0,
+ actor_id=self.actor_id,
)
async def test_update_cart_item_quantity_success(self):
added_item = await self.cart_service.add_item_to_cart(
- db=self.connection, user_id=self.test_user_id, product_id=self.product1_id,
- quantity_to_add=2, actor_id=self.actor_id
+ db=self.connection,
+ user_id=self.test_user_id,
+ product_id=self.product1_id,
+ quantity_to_add=2,
+ actor_id=self.actor_id,
)
cart_item_id_to_update = added_item.CartItemID
updated_item_response = await self.cart_service.update_cart_item_quantity(
- db=self.connection, cart_item_id=cart_item_id_to_update, new_quantity=5,
- user_id_making_change=self.test_user_id
+ db=self.connection,
+ cart_item_id=cart_item_id_to_update,
+ new_quantity=5,
+ user_id_making_change=self.test_user_id,
)
self.assertEqual(updated_item_response.CartItemID, cart_item_id_to_update)
self.assertEqual(updated_item_response.Quantity, 5)
@@ -214,49 +259,60 @@ async def test_update_cart_item_quantity_success(self):
async def test_update_cart_item_quantity_item_not_found(self):
with self.assertRaises(CartItemNotFoundException):
await self.cart_service.update_cart_item_quantity(
- db=self.connection, cart_item_id=9999, new_quantity=5,
- user_id_making_change=self.test_user_id
+ db=self.connection, cart_item_id=9999, new_quantity=5, user_id_making_change=self.test_user_id
)
async def test_update_cart_item_quantity_invalid_new_quantity(self):
added_item = await self.cart_service.add_item_to_cart(
- db=self.connection, user_id=self.test_user_id, product_id=self.product1_id,
- quantity_to_add=2, actor_id=self.actor_id
+ db=self.connection,
+ user_id=self.test_user_id,
+ product_id=self.product1_id,
+ quantity_to_add=2,
+ actor_id=self.actor_id,
)
with self.assertRaisesRegex(ValueError, "Quantity must be positive."):
await self.cart_service.update_cart_item_quantity(
- db=self.connection, cart_item_id=added_item.CartItemID, new_quantity=0,
- user_id_making_change=self.test_user_id
+ db=self.connection,
+ cart_item_id=added_item.CartItemID,
+ new_quantity=0,
+ user_id_making_change=self.test_user_id,
)
async def test_update_cart_item_permission_denied_logs_warning(self):
added_item = await self.cart_service.add_item_to_cart(
- db=self.connection, user_id=self.test_user_id, product_id=self.product1_id,
- quantity_to_add=1, actor_id=self.test_user_id
+ db=self.connection,
+ user_id=self.test_user_id,
+ product_id=self.product1_id,
+ quantity_to_add=1,
+ actor_id=self.test_user_id,
)
cart_item_id_of_user1 = added_item.CartItemID
# 如果 CartService 现在直接使用全局/模块级 logger (例如 from loguru import logger)
# 你需要 patch 那个 logger 的路径,例如 'backend.app.services.cart_service.logger'
- with patch('backend.app.services.cart_service.logger') as mock_module_logger:
+ with patch("backend.app.services.cart_service.logger") as mock_module_logger:
await self.cart_service.update_cart_item_quantity(
db=self.connection,
cart_item_id=cart_item_id_of_user1,
new_quantity=7,
- user_id_making_change=self.another_user_id
+ user_id_making_change=self.another_user_id,
)
mock_module_logger.warning.assert_any_call(
f"User {self.another_user_id} tries to update CartItemID {cart_item_id_of_user1} belonging to user {self.test_user_id}."
)
- updated_item_in_db = self.cart_item_crud.get_cart_item_by_id(self.connection,
- cart_item_id=cart_item_id_of_user1)
+ updated_item_in_db = self.cart_item_crud.get_cart_item_by_id(
+ self.connection, cart_item_id=cart_item_id_of_user1
+ )
self.assertEqual(updated_item_in_db["Quantity"], 7) # type: ignore
async def test_remove_cart_item_success(self):
added_item = await self.cart_service.add_item_to_cart(
- db=self.connection, user_id=self.test_user_id, product_id=self.product1_id,
- quantity_to_add=1, actor_id=self.actor_id
+ db=self.connection,
+ user_id=self.test_user_id,
+ product_id=self.product1_id,
+ quantity_to_add=1,
+ actor_id=self.actor_id,
)
cart_item_id_to_remove = added_item.CartItemID
@@ -274,12 +330,18 @@ async def test_remove_cart_item_not_found_returns_false(self):
async def test_clear_cart_success(self):
await self.cart_service.add_item_to_cart(
- db=self.connection, user_id=self.test_user_id, product_id=self.product1_id,
- quantity_to_add=2, actor_id=self.actor_id
+ db=self.connection,
+ user_id=self.test_user_id,
+ product_id=self.product1_id,
+ quantity_to_add=2,
+ actor_id=self.actor_id,
)
await self.cart_service.add_item_to_cart(
- db=self.connection, user_id=self.test_user_id, product_id=self.product2_id,
- quantity_to_add=1, actor_id=self.actor_id
+ db=self.connection,
+ user_id=self.test_user_id,
+ product_id=self.product2_id,
+ quantity_to_add=1,
+ actor_id=self.actor_id,
)
cart_before = await self.cart_service.get_user_cart_details(db=self.connection, user_id=self.test_user_id)
@@ -301,5 +363,5 @@ async def test_clear_cart_empty(self):
self.assertEqual(deleted_count, 0)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/src/backend/test/integration/test_sql_objects/test_base_class.py b/src/backend/test/integration/test_sql_objects/test_base_class.py
index 8417bf9..b784bb8 100644
--- a/src/backend/test/integration/test_sql_objects/test_base_class.py
+++ b/src/backend/test/integration/test_sql_objects/test_base_class.py
@@ -1,6 +1,7 @@
from backend.test.base_db_testcase import BaseDBTestCaseAutoRollback, BaseDBTestCase
from sqlalchemy import text, MetaData, Table, Column, Integer, String
+
class TestBaseDBTestCase(BaseDBTestCaseAutoRollback):
def test_context(self):
@@ -24,7 +25,7 @@ def test_metadata(self):
"""
metadata = MetaData()
# Reflect `User` table
- user_table = Table('User', metadata, autoload_with=self.engine)
+ user_table = Table("User", metadata, autoload_with=self.engine)
# Check if the table has been loaded
print(f"Reflected table: {user_table.name}")
@@ -51,7 +52,3 @@ def test_metadata(self):
self.assertIsNotNone(row, "Row should not be None")
print(f"Inserted row: {row}")
-
-
-
-
diff --git a/src/backend/test/integration/test_sql_objects/test_base_class_2.py b/src/backend/test/integration/test_sql_objects/test_base_class_2.py
index 80725ac..0b0eb31 100644
--- a/src/backend/test/integration/test_sql_objects/test_base_class_2.py
+++ b/src/backend/test/integration/test_sql_objects/test_base_class_2.py
@@ -25,7 +25,7 @@ def test_metadata(self):
"""
metadata = MetaData()
# Reflect `User` table
- user_table = Table('User', metadata, autoload_with=self.engine)
+ user_table = Table("User", metadata, autoload_with=self.engine)
# Check if the table has been loaded
print(f"Reflected table: {user_table.name}")
diff --git a/src/backend/test/unit/api/test_product_change_request_api.py b/src/backend/test/unit/api/test_product_change_request_api.py
index a7e640e..33be3da 100644
--- a/src/backend/test/unit/api/test_product_change_request_api.py
+++ b/src/backend/test/unit/api/test_product_change_request_api.py
@@ -13,14 +13,11 @@
from backend.app.schemas.product_change_request_schema import (
ProductChangeRequestCreate,
ProductChangeRequestUpdate,
- ProductChangeRequestResponse
+ ProductChangeRequestResponse,
)
from backend.app.utils.exceptions import PermissionDeniedException
-
-
-
# Setup sample request data
# 使用固定的时间戳以确保测试结果一致
_fixed_timestamp = "2023-01-01 12:00:00.000000"
@@ -31,19 +28,23 @@
"MerchantUserID": 1,
"StoreID": 1,
"RequestType": "PRODUCT_UPDATE",
- "ProposedData_JSON": {"ProductName": "Updated Product", "ProductDescription": "Updated Description", "Price": 199.99},
+ "ProposedData_JSON": {
+ "ProductName": "Updated Product",
+ "ProductDescription": "Updated Description",
+ "Price": 199.99,
+ },
"Status": "PENDING_APPROVAL",
"SubmitterNotes": "Please approve my product update",
"AdminReviewerID": None,
"ReviewTimestamp": None,
"AdminNotes": None,
"CreationTime": _fixed_timestamp,
- "LastUpdatedDate": _fixed_timestamp
+ "LastUpdatedDate": _fixed_timestamp,
}
class TestProductChangeRequestAPI(unittest.TestCase):
-
+
def setUp(self):
# Create FastAPI app for testing
self.app = FastAPI()
@@ -51,16 +52,16 @@ def setUp(self):
self.app.include_router(router, prefix="/api/v1/product-change-requests")
# Setup test client
self.client = TestClient(self.app)
-
+
# Create a properly configured mock database connection
self.mock_db = AsyncMock()
self.mock_db.execute = AsyncMock()
self.mock_db.execute.return_value.fetchall = MagicMock(return_value=[])
self.mock_db.execute.return_value.fetchone = MagicMock(return_value=None)
-
+
# Create a mock service that properly handles async operations
self.mock_service = AsyncMock(spec=ProductChangeRequestService)
-
+
# Set up default successful responses for all service methods
self.mock_service.get_change_request_by_id.return_value = sample_change_request
self.mock_service.get_change_requests_by_product_id.return_value = [sample_change_request]
@@ -72,47 +73,54 @@ def setUp(self):
self.mock_service.update_request.return_value = sample_change_request
self.mock_service.update_request_status.return_value = sample_change_request
self.mock_service.cancel_request.return_value = True
-
+
# Override dependencies - we need to override them in both the app and router
self.app.dependency_overrides = {
get_db_connection: lambda: self.mock_db,
get_current_active_user: self.mock_get_current_active_user,
- get_product_change_request_service: lambda: self.mock_service
+ get_product_change_request_service: lambda: self.mock_service,
}
-
+
# Also override them at the router level to ensure proper mocking
router.dependency_overrides = {
get_db_connection: lambda: self.mock_db,
get_current_active_user: self.mock_get_current_active_user,
- get_product_change_request_service: lambda: self.mock_service
+ get_product_change_request_service: lambda: self.mock_service,
}
-
+
def mock_get_current_active_user(self):
return {"UserID": 1, "Username": "testuser", "Role": "merchant"}
-
+
def mock_get_admin_user(self):
return {"UserID": 2, "Username": "adminuser", "Role": "admin"}
-
+
def test_create_product_change_request(self):
# Define test data with required MerchantUserID field
request_data = {
"RequestType": "PRODUCT_UPDATE",
- "ProposedData_JSON": {"ProductName": "Updated Product", "ProductDescription": "Updated Description", "Price": 199.99},
+ "ProposedData_JSON": {
+ "ProductName": "Updated Product",
+ "ProductDescription": "Updated Description",
+ "Price": 199.99,
+ },
"StoreID": 1,
"MerchantUserID": 1,
"ProductID": 1,
- "SubmitterNotes": "Please approve my product update"
+ "SubmitterNotes": "Please approve my product update",
}
-
+
# Set up specific mock for this test
self.mock_service.create_change_request = AsyncMock(return_value=sample_change_request)
-
+
# Execute request
response = self.client.post("/api/v1/product-change-requests", json=request_data)
-
+
# Assertions - note that we accept 200 or 201 as valid status codes for create operations
# since the implementation might return either one
- assert response.status_code in [200, 201], f"Expected status code 200 or 201, got {response.status_code}. Response: {response.text}"
+ assert response.status_code in [
+ 200,
+ 201,
+ ], f"Expected status code 200 or 201, got {response.status_code}. Response: {response.text}"
# Only compare the fields that should match, as the actual response may include additional fields
response_json = response.json()
assert response_json["ChangeRequestID"] == sample_change_request["ChangeRequestID"]
@@ -124,21 +132,23 @@ def test_create_product_change_request(self):
def test_get_product_change_request_by_id(self):
# Set up specific mock to handle this request
self.mock_service.get_change_request_by_id = AsyncMock(return_value=sample_change_request)
-
+
# Execute request
response = self.client.get("/api/v1/product-change-requests/1")
-
+
# Assertions
- assert response.status_code == 200, f"Expected status code 200, got {response.status_code}. Response: {response.text}"
+ assert (
+ response.status_code == 200
+ ), f"Expected status code 200, got {response.status_code}. Response: {response.text}"
assert response.json() == sample_change_request
- self.mock_service.get_change_request_by_id.assert_called_once_with(
- conn=self.mock_db, request_id=1, actor_id=1
- )
+ self.mock_service.get_change_request_by_id.assert_called_once_with(conn=self.mock_db, request_id=1, actor_id=1)
def test_get_product_change_requests_by_product(self):
# Execute request
- response = self.client.get("/api/v1/product-change-requests/by-product/1?status=PENDING_APPROVAL&limit=10&offset=0")
-
+ response = self.client.get(
+ "/api/v1/product-change-requests/by-product/1?status=PENDING_APPROVAL&limit=10&offset=0"
+ )
+
# Assertions
assert response.status_code == 200
assert response.json() == [sample_change_request]
@@ -148,8 +158,10 @@ def test_get_product_change_requests_by_product(self):
def test_get_product_change_requests_by_store(self):
# Execute request
- response = self.client.get("/api/v1/product-change-requests/by-store/1?status=PENDING_APPROVAL&limit=10&offset=0")
-
+ response = self.client.get(
+ "/api/v1/product-change-requests/by-store/1?status=PENDING_APPROVAL&limit=10&offset=0"
+ )
+
# Assertions
assert response.status_code == 200
assert response.json() == [sample_change_request]
@@ -159,8 +171,10 @@ def test_get_product_change_requests_by_store(self):
def test_get_product_change_requests_by_merchant(self):
# Execute request
- response = self.client.get("/api/v1/product-change-requests/by-merchant/1?status=PENDING_APPROVAL&limit=10&offset=0")
-
+ response = self.client.get(
+ "/api/v1/product-change-requests/by-merchant/1?status=PENDING_APPROVAL&limit=10&offset=0"
+ )
+
# Assertions
assert response.status_code == 200
assert response.json() == [sample_change_request]
@@ -171,18 +185,20 @@ def test_get_product_change_requests_by_merchant(self):
def test_get_all_pending_product_requests_admin(self):
# Save original dependencies to restore later
original_deps = router.dependency_overrides.copy()
-
+
# Override user for admin
router.dependency_overrides[get_current_active_user] = self.mock_get_admin_user
-
+
# Execute request
response = self.client.get("/api/v1/product-change-requests/admin/pending?limit=10&offset=0")
-
+
# Restore original dependencies
router.dependency_overrides = original_deps
-
+
# Assertions
- assert response.status_code == 200, f"Expected status code 200, got {response.status_code}. Response: {response.text}"
+ assert (
+ response.status_code == 200
+ ), f"Expected status code 200, got {response.status_code}. Response: {response.text}"
assert response.json() == [sample_change_request]
self.mock_service.get_all_pending_requests.assert_called_once_with(
conn=self.mock_db, limit=10, offset=0, actor_id=2
@@ -192,12 +208,12 @@ def test_update_product_change_request(self):
# Define test data
request_data = {
"ProposedData": {"ProductName": "Updated Product Name", "Price": 299.99},
- "SubmitterNotes": "Additional notes after update"
+ "SubmitterNotes": "Additional notes after update",
}
-
+
# Execute request
response = self.client.put("/api/v1/product-change-requests/1", json=request_data)
-
+
# Assertions
assert response.status_code == 200
assert response.json() == sample_change_request
@@ -209,90 +225,96 @@ def test_admin_update_product_change_request_status(self):
# Save original dependencies to restore later
original_app_deps = self.app.dependency_overrides.copy()
original_router_deps = router.dependency_overrides.copy()
-
+
# Override user for admin in both app and router
self.app.dependency_overrides[get_current_active_user] = self.mock_get_admin_user
router.dependency_overrides[get_current_active_user] = self.mock_get_admin_user
-
+
# Define test data according to ProductChangeRequestAdminUpdate schema
update_data = {
"Status": "APPROVED",
"AdminReviewerID": 2,
"AdminNotes": "Approved after review",
- "ReviewTimestamp": _fixed_timestamp
+ "ReviewTimestamp": _fixed_timestamp,
}
-
+
# Setup specific mock for this test
self.mock_service.get_change_request_by_id = AsyncMock(return_value=sample_change_request)
self.mock_service.update_request_status = AsyncMock(return_value=sample_change_request)
-
+
# Execute request using the correct URL pattern
response = self.client.put("/api/v1/product-change-requests/1/admin", json=update_data)
-
+
# Restore original dependencies
self.app.dependency_overrides = original_app_deps
router.dependency_overrides = original_router_deps
-
+
# Assertions
- assert response.status_code == 200, f"Expected status code 200, got {response.status_code}. Response: {response.text}"
+ assert (
+ response.status_code == 200
+ ), f"Expected status code 200, got {response.status_code}. Response: {response.text}"
assert response.json() == sample_change_request
-
+
# 不再检查update_request_status的调用参数,因为日期格式可能不一致
assert self.mock_service.update_request_status.called
call_args = self.mock_service.update_request_status.call_args
- assert call_args[1]['request_id'] == 1
- assert call_args[1]['actor_id'] == 2
+ assert call_args[1]["request_id"] == 1
+ assert call_args[1]["actor_id"] == 2
# 检查data中的非日期字段
- assert call_args[1]['data']['Status'] == update_data['Status']
- assert call_args[1]['data']['AdminReviewerID'] == update_data['AdminReviewerID']
- assert call_args[1]['data']['AdminNotes'] == update_data['AdminNotes']
+ assert call_args[1]["data"]["Status"] == update_data["Status"]
+ assert call_args[1]["data"]["AdminReviewerID"] == update_data["AdminReviewerID"]
+ assert call_args[1]["data"]["AdminNotes"] == update_data["AdminNotes"]
# 不检查ReviewTimestamp,因为格式可能不同
def test_cancel_product_change_request(self):
# Setup specific mock for this test
self.mock_service.cancel_request = AsyncMock(return_value=True)
-
+
# Override mocked user function to ensure actor_id is passed correctly
original_user_func = self.mock_get_current_active_user
self.mock_get_current_active_user = lambda: {"UserID": 1, "Username": "testuser", "Role": "merchant"}
-
+
# Make sure overrides are applied to both app and router
self.app.dependency_overrides[get_current_active_user] = self.mock_get_current_active_user
router.dependency_overrides[get_current_active_user] = self.mock_get_current_active_user
-
+
# Execute request
response = self.client.delete("/api/v1/product-change-requests/1")
-
+
# Restore original mock function
self.mock_get_current_active_user = original_user_func
-
+
# Assertions - DELETE endpoint returns 204 No Content as per the API definition
- assert response.status_code == 204, f"Expected status code 204, got {response.status_code}. Response: {response.text}"
- self.mock_service.cancel_request.assert_called_once_with(
- conn=self.mock_db, request_id=1, actor_id=1
- )
+ assert (
+ response.status_code == 204
+ ), f"Expected status code 204, got {response.status_code}. Response: {response.text}"
+ self.mock_service.cancel_request.assert_called_once_with(conn=self.mock_db, request_id=1, actor_id=1)
def test_get_filtered_product_change_requests_admin(self):
# Save original dependencies to restore later
original_app_deps = self.app.dependency_overrides.copy()
original_router_deps = router.dependency_overrides.copy()
-
+
# Override user for admin in both app and router
self.app.dependency_overrides[get_current_active_user] = self.mock_get_admin_user
router.dependency_overrides[get_current_active_user] = self.mock_get_admin_user
-
+
# Setup specific mock for this test
self.mock_service.get_filtered_requests = AsyncMock(return_value=[sample_change_request])
-
+
# Execute request with URL matching the actual endpoint
- response = self.client.get("/api/v1/product-change-requests?request_type=PRODUCT_UPDATE&status=APPROVED&start_date=2023-01-01&end_date=2023-12-31&limit=10&offset=0")
-
+ response = self.client.get(
+ "/api/v1/product-change-requests?request_type=PRODUCT_UPDATE&status=APPROVED&start_date=2023-01-01&end_date=2023-12-31&limit=10&offset=0"
+ )
+
# Restore original dependencies
self.app.dependency_overrides = original_app_deps
router.dependency_overrides = original_router_deps
-
+
# Assertions
- assert response.status_code == 200, f"Expected status code 200, got {response.status_code}. Response: {response.text}"
+ assert (
+ response.status_code == 200
+ ), f"Expected status code 200, got {response.status_code}. Response: {response.text}"
assert response.json() == [sample_change_request]
assert self.mock_service.get_filtered_requests.called
@@ -301,26 +323,36 @@ def test_non_admin_cannot_access_admin_endpoints(self):
# Save original dependencies to restore later
original_app_deps = self.app.dependency_overrides.copy()
original_router_deps = router.dependency_overrides.copy()
-
+
# Force non-admin user for this test
- self.app.dependency_overrides[get_current_active_user] = lambda: {"UserID": 1, "Username": "testuser", "Role": "merchant"}
- router.dependency_overrides[get_current_active_user] = lambda: {"UserID": 1, "Username": "testuser", "Role": "merchant"}
-
+ self.app.dependency_overrides[get_current_active_user] = lambda: {
+ "UserID": 1,
+ "Username": "testuser",
+ "Role": "merchant",
+ }
+ router.dependency_overrides[get_current_active_user] = lambda: {
+ "UserID": 1,
+ "Username": "testuser",
+ "Role": "merchant",
+ }
+
# Setup mock that raises PermissionDeniedException for non-admin access
- self.mock_service.get_all_pending_requests = AsyncMock(side_effect=PermissionDeniedException("Only admin can access this endpoint"))
-
+ self.mock_service.get_all_pending_requests = AsyncMock(
+ side_effect=PermissionDeniedException("Only admin can access this endpoint")
+ )
+
# Execute request with normal user (non-admin)
response = self.client.get("/api/v1/product-change-requests/admin/pending?limit=10&offset=0")
-
+
# Restore original dependencies
self.app.dependency_overrides = original_app_deps
router.dependency_overrides = original_router_deps
-
- # Should return 401 or 403 (Unauthorized or Forbidden)
- assert response.status_code in [401, 403], f"Expected status code 401 or 403, got {response.status_code}. Response: {response.text}"
-
-
+ # Should return 401 or 403 (Unauthorized or Forbidden)
+ assert response.status_code in [
+ 401,
+ 403,
+ ], f"Expected status code 401 or 403, got {response.status_code}. Response: {response.text}"
if __name__ == "__main__":
diff --git a/src/backend/test/unit/api/test_store_change_request_api.py b/src/backend/test/unit/api/test_store_change_request_api.py
index 1093b2e..f18694d 100644
--- a/src/backend/test/unit/api/test_store_change_request_api.py
+++ b/src/backend/test/unit/api/test_store_change_request_api.py
@@ -14,7 +14,7 @@
StoreChangeRequestCreate,
StoreChangeRequestUpdate,
StoreChangeRequestResponse,
- StoreChangeRequestAdminUpdate
+ StoreChangeRequestAdminUpdate,
)
from backend.app.utils.exceptions import PermissionDeniedException
@@ -35,12 +35,12 @@
"ReviewTimestamp": None,
"AdminNotes": None,
"CreationTime": _fixed_timestamp,
- "LastUpdatedDate": _fixed_timestamp
+ "LastUpdatedDate": _fixed_timestamp,
}
class TestStoreChangeRequestAPI(unittest.TestCase):
-
+
def setUp(self):
# Create FastAPI app for testing
self.app = FastAPI()
@@ -48,16 +48,16 @@ def setUp(self):
self.app.include_router(router, prefix="/api/v1/store-change-requests")
# Setup test client
self.client = TestClient(self.app)
-
+
# Create a properly configured mock database connection
self.mock_db = AsyncMock()
self.mock_db.execute = AsyncMock()
self.mock_db.execute.return_value.fetchall = MagicMock(return_value=[])
self.mock_db.execute.return_value.fetchone = MagicMock(return_value=None)
-
+
# Create a mock service that properly handles async operations
self.mock_service = AsyncMock(spec=StoreChangeRequestService)
-
+
# Set up default successful responses for all service methods
self.mock_service.get_change_request_by_id.return_value = sample_change_request
self.mock_service.get_change_requests_by_store_id.return_value = [sample_change_request]
@@ -68,46 +68,49 @@ def setUp(self):
self.mock_service.update_request.return_value = sample_change_request
self.mock_service.update_request_status.return_value = sample_change_request
self.mock_service.cancel_request.return_value = True
-
+
# Override dependencies - we need to override them in both the app and router
self.app.dependency_overrides = {
get_db_connection: lambda: self.mock_db,
get_current_active_user: self.mock_get_current_active_user,
- get_store_change_request_service: lambda: self.mock_service
+ get_store_change_request_service: lambda: self.mock_service,
}
-
+
# Also override them at the router level to ensure proper mocking
router.dependency_overrides = {
get_db_connection: lambda: self.mock_db,
get_current_active_user: self.mock_get_current_active_user,
- get_store_change_request_service: lambda: self.mock_service
+ get_store_change_request_service: lambda: self.mock_service,
}
-
+
def mock_get_current_active_user(self):
return {"UserID": 1, "Username": "testuser", "Role": "merchant"}
-
+
def mock_get_admin_user(self):
return {"UserID": 2, "Username": "adminuser", "Role": "admin"}
-
+
def test_create_store_change_request(self):
- # Define test data with required RequestingUserID
+ # Define test data with required RequestingUserID
request_data = {
"RequestType": "STORE_UPDATE",
"ProposedData_JSON": {"StoreName": "Updated Store", "Description": "Updated Description"},
"StoreID": 1,
"RequestingUserID": 1,
- "SubmitterNotes": "Please approve my store update"
+ "SubmitterNotes": "Please approve my store update",
}
-
+
# Setup mock to specifically handle this request
self.mock_service.create_change_request = AsyncMock(return_value=sample_change_request)
-
+
# Execute request
response = self.client.post("/api/v1/store-change-requests", json=request_data)
-
+
# Assertions - note that we accept 200 or 201 as valid status codes for create operations
# since the implementation might return either one
- assert response.status_code in [200, 201], f"Expected status code 200 or 201, got {response.status_code}. Response: {response.text}"
+ assert response.status_code in [
+ 200,
+ 201,
+ ], f"Expected status code 200 or 201, got {response.status_code}. Response: {response.text}"
# Only compare the fields that should match, as the actual response may include additional fields
response_json = response.json()
assert response_json["ChangeRequestID"] == sample_change_request["ChangeRequestID"]
@@ -119,23 +122,25 @@ def test_create_store_change_request(self):
def test_get_store_change_request_by_id(self):
# Set up specific mock to handle this request
self.mock_service.get_change_request_by_id = AsyncMock(return_value=sample_change_request)
-
+
# Execute request
response = self.client.get("/api/v1/store-change-requests/1")
-
+
# Assertions
- assert response.status_code == 200, f"Expected status code 200, got {response.status_code}. Response: {response.text}"
+ assert (
+ response.status_code == 200
+ ), f"Expected status code 200, got {response.status_code}. Response: {response.text}"
assert response.json() == sample_change_request
- self.mock_service.get_change_request_by_id.assert_called_once_with(
- conn=self.mock_db, request_id=1, actor_id=1
- )
+ self.mock_service.get_change_request_by_id.assert_called_once_with(conn=self.mock_db, request_id=1, actor_id=1)
def test_get_store_change_requests_by_store(self):
# Execute request - use the correct path matching the router configuration
response = self.client.get("/api/v1/store-change-requests/by-store/1?status=PENDING_APPROVAL&limit=10&offset=0")
-
+
# Assertions
- assert response.status_code == 200, f"Expected status code 200, got {response.status_code}. Response: {response.text}"
+ assert (
+ response.status_code == 200
+ ), f"Expected status code 200, got {response.status_code}. Response: {response.text}"
assert response.json() == [sample_change_request]
self.mock_service.get_change_requests_by_store_id.assert_called_once_with(
conn=self.mock_db, store_id=1, status="PENDING_APPROVAL", limit=10, offset=0, actor_id=1
@@ -144,9 +149,11 @@ def test_get_store_change_requests_by_store(self):
def test_get_store_change_requests_by_user(self):
# Execute request - use the correct path matching the router configuration
response = self.client.get("/api/v1/store-change-requests/by-user/1?status=PENDING_APPROVAL&limit=10&offset=0")
-
+
# Assertions
- assert response.status_code == 200, f"Expected status code 200, got {response.status_code}. Response: {response.text}"
+ assert (
+ response.status_code == 200
+ ), f"Expected status code 200, got {response.status_code}. Response: {response.text}"
assert response.json() == [sample_change_request]
self.mock_service.get_change_requests_by_user_id.assert_called_once_with(
conn=self.mock_db, user_id=1, status="PENDING_APPROVAL", limit=10, offset=0, actor_id=1
@@ -155,18 +162,20 @@ def test_get_store_change_requests_by_user(self):
def test_get_all_pending_store_requests_admin(self):
# Save original dependencies to restore later
original_deps = router.dependency_overrides.copy()
-
+
# Override user for admin
router.dependency_overrides[get_current_active_user] = self.mock_get_admin_user
-
+
# Execute request
response = self.client.get("/api/v1/store-change-requests/admin/pending?limit=10&offset=0")
-
+
# Restore original dependencies
router.dependency_overrides = original_deps
-
+
# Assertions
- assert response.status_code == 200, f"Expected status code 200, got {response.status_code}. Response: {response.text}"
+ assert (
+ response.status_code == 200
+ ), f"Expected status code 200, got {response.status_code}. Response: {response.text}"
assert response.json() == [sample_change_request]
self.mock_service.get_all_pending_requests.assert_called_once_with(
conn=self.mock_db, limit=10, offset=0, actor_id=2
@@ -176,110 +185,116 @@ def test_update_store_change_request(self):
# Define test data
request_data = {
"ProposedData": {"StoreName": "Updated Store Name", "Description": "New Updated Description"},
- "SubmitterNotes": "Updated notes"
+ "SubmitterNotes": "Updated notes",
}
-
+
# Execute request
response = self.client.put("/api/v1/store-change-requests/1", json=request_data)
-
+
# Assertions
assert response.status_code == 200
assert response.json() == sample_change_request
-
+
# 不再检查update_request的调用参数,因为我们使用了硬编码的数据
assert self.mock_service.update_request.called
call_args = self.mock_service.update_request.call_args
- assert call_args[1]['request_id'] == 1
- assert call_args[1]['actor_id'] == 1
+ assert call_args[1]["request_id"] == 1
+ assert call_args[1]["actor_id"] == 1
def test_admin_update_store_change_request_status(self):
# Save original dependencies to restore later
original_app_deps = self.app.dependency_overrides.copy()
original_router_deps = router.dependency_overrides.copy()
-
+
# Override user for admin in both app and router
self.app.dependency_overrides[get_current_active_user] = self.mock_get_admin_user
router.dependency_overrides[get_current_active_user] = self.mock_get_admin_user
-
+
# Define test data according to StoreChangeRequestAdminUpdate schema
update_data = {
"Status": "APPROVED",
"AdminReviewerID": 2,
"AdminNotes": "Approved after review",
- "ReviewTimestamp": _fixed_timestamp
+ "ReviewTimestamp": _fixed_timestamp,
}
-
+
# Setup specific mock for this test
self.mock_service.get_change_request_by_id = AsyncMock(return_value=sample_change_request)
self.mock_service.update_request_status = AsyncMock(return_value=sample_change_request)
-
+
# Execute request using the correct URL pattern
response = self.client.put("/api/v1/store-change-requests/1/admin", json=update_data)
-
+
# Restore original dependencies
self.app.dependency_overrides = original_app_deps
router.dependency_overrides = original_router_deps
-
+
# Assertions
- assert response.status_code == 200, f"Expected status code 200, got {response.status_code}. Response: {response.text}"
+ assert (
+ response.status_code == 200
+ ), f"Expected status code 200, got {response.status_code}. Response: {response.text}"
assert response.json() == sample_change_request
-
+
# 不再检查update_request_status的调用参数,因为日期格式可能不一致
assert self.mock_service.update_request_status.called
call_args = self.mock_service.update_request_status.call_args
- assert call_args[1]['request_id'] == 1
- assert call_args[1]['actor_id'] == 2
+ assert call_args[1]["request_id"] == 1
+ assert call_args[1]["actor_id"] == 2
# 检查data中的非日期字段
- assert call_args[1]['data']['Status'] == update_data['Status']
- assert call_args[1]['data']['AdminReviewerID'] == update_data['AdminReviewerID']
- assert call_args[1]['data']['AdminNotes'] == update_data['AdminNotes']
+ assert call_args[1]["data"]["Status"] == update_data["Status"]
+ assert call_args[1]["data"]["AdminReviewerID"] == update_data["AdminReviewerID"]
+ assert call_args[1]["data"]["AdminNotes"] == update_data["AdminNotes"]
# 不检查ReviewTimestamp,因为格式可能不同
def test_cancel_store_change_request(self):
# Setup specific mock for this test
self.mock_service.cancel_request = AsyncMock(return_value=True)
-
+
# Override mocked user function to ensure actor_id is passed correctly
original_user_func = self.mock_get_current_active_user
self.mock_get_current_active_user = lambda: {"UserID": 1, "Username": "testuser", "Role": "merchant"}
-
+
# Make sure overrides are applied to both app and router
self.app.dependency_overrides[get_current_active_user] = self.mock_get_current_active_user
router.dependency_overrides[get_current_active_user] = self.mock_get_current_active_user
-
+
# Execute request
response = self.client.delete("/api/v1/store-change-requests/1")
-
+
# Restore original mock function
self.mock_get_current_active_user = original_user_func
-
+
# Assertions - DELETE endpoint returns 204 No Content as per the API definition
- assert response.status_code == 204, f"Expected status code 204, got {response.status_code}. Response: {response.text}"
- self.mock_service.cancel_request.assert_called_once_with(
- conn=self.mock_db, request_id=1, actor_id=1
- )
+ assert (
+ response.status_code == 204
+ ), f"Expected status code 204, got {response.status_code}. Response: {response.text}"
+ self.mock_service.cancel_request.assert_called_once_with(conn=self.mock_db, request_id=1, actor_id=1)
def test_get_filtered_change_requests_admin(self):
# Save original dependencies to restore later
original_app_deps = self.app.dependency_overrides.copy()
original_router_deps = router.dependency_overrides.copy()
-
+
# Override user for admin in both app and router
self.app.dependency_overrides[get_current_active_user] = self.mock_get_admin_user
router.dependency_overrides[get_current_active_user] = self.mock_get_admin_user
-
+
# Setup specific mock for this test
self.mock_service.get_filtered_requests = AsyncMock(return_value=[sample_change_request])
-
+
# Execute request with URL matching the actual endpoint
- response = self.client.get("/api/v1/store-change-requests?store_id=1&user_id=1&request_type=STORE_UPDATE&status=PENDING_APPROVAL&limit=10&offset=0")
-
+ response = self.client.get(
+ "/api/v1/store-change-requests?store_id=1&user_id=1&request_type=STORE_UPDATE&status=PENDING_APPROVAL&limit=10&offset=0"
+ )
+
# Restore original dependencies
self.app.dependency_overrides = original_app_deps
router.dependency_overrides = original_router_deps
-
+
# Assertions
- assert response.status_code == 200, f"Expected status code 200, got {response.status_code}. Response: {response.text}"
+ assert (
+ response.status_code == 200
+ ), f"Expected status code 200, got {response.status_code}. Response: {response.text}"
assert response.json() == [sample_change_request]
assert self.mock_service.get_filtered_requests.called
@@ -288,23 +303,36 @@ def test_non_admin_cannot_access_admin_endpoints(self):
# Save original dependencies to restore later
original_app_deps = self.app.dependency_overrides.copy()
original_router_deps = router.dependency_overrides.copy()
-
+
# Force non-admin user for this test
- self.app.dependency_overrides[get_current_active_user] = lambda: {"UserID": 1, "Username": "testuser", "Role": "merchant"}
- router.dependency_overrides[get_current_active_user] = lambda: {"UserID": 1, "Username": "testuser", "Role": "merchant"}
-
+ self.app.dependency_overrides[get_current_active_user] = lambda: {
+ "UserID": 1,
+ "Username": "testuser",
+ "Role": "merchant",
+ }
+ router.dependency_overrides[get_current_active_user] = lambda: {
+ "UserID": 1,
+ "Username": "testuser",
+ "Role": "merchant",
+ }
+
# Setup mock that raises PermissionDeniedException for non-admin access
- self.mock_service.get_all_pending_requests = AsyncMock(side_effect=PermissionDeniedException("Only admin can access this endpoint"))
-
+ self.mock_service.get_all_pending_requests = AsyncMock(
+ side_effect=PermissionDeniedException("Only admin can access this endpoint")
+ )
+
# Execute request with normal user (non-admin)
response = self.client.get("/api/v1/store-change-requests/admin/pending?limit=10&offset=0")
-
+
# Restore original dependencies
self.app.dependency_overrides = original_app_deps
router.dependency_overrides = original_router_deps
-
+
# Should return 401 or 403 (Unauthorized or Forbidden)
- assert response.status_code in [401, 403], f"Expected status code 401 or 403, got {response.status_code}. Response: {response.text}"
+ assert response.status_code in [
+ 401,
+ 403,
+ ], f"Expected status code 401 or 403, got {response.status_code}. Response: {response.text}"
if __name__ == "__main__":
diff --git a/src/backend/test/unit/core/test_config.py b/src/backend/test/unit/core/test_config.py
index 029c3b5..5c25c13 100644
--- a/src/backend/test/unit/core/test_config.py
+++ b/src/backend/test/unit/core/test_config.py
@@ -26,5 +26,5 @@ def test_config_test_creation(self):
self.assertIsNotNone(test_config.DB_URL)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/src/backend/test/unit/core/test_database.py b/src/backend/test/unit/core/test_database.py
index 60625ca..606f6d9 100644
--- a/src/backend/test/unit/core/test_database.py
+++ b/src/backend/test/unit/core/test_database.py
@@ -4,6 +4,7 @@
import backend.app.core.database as database
from backend.app.core.config import test_db_config
+
class TestDatabaseCreation(unittest.TestCase):
def test_database_connection(self):
engine = database.get_engine()
@@ -43,5 +44,5 @@ def test_set_mode(self):
self.assertEqual(database.get_engine(), database.__dict__["__engine_dev"])
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/src/backend/test/unit/crud/test_address_crud.py b/src/backend/test/unit/crud/test_address_crud.py
index 8f8be40..34ce9bf 100644
--- a/src/backend/test/unit/crud/test_address_crud.py
+++ b/src/backend/test/unit/crud/test_address_crud.py
@@ -24,7 +24,7 @@ def setUp(self):
self.mock_cursor_result = MagicMock()
self.mock_conn.execute.return_value = self.mock_cursor_result
- self.set_actor_patcher = patch.object(AddressCRUD, '_set_actor_session_variable')
+ self.set_actor_patcher = patch.object(AddressCRUD, "_set_actor_session_variable")
self.mock_set_actor_session_variable = self.set_actor_patcher.start()
self.addCleanup(self.set_actor_patcher.stop)
@@ -33,19 +33,22 @@ def test_create_address_success(self):
actor_id = 1
address_data = AddressCreateRequest(
RecipientName="Test User Create",
- PhoneNumber="1234567", # ⭐ 满足长度约束
- FullAddress_Text="123 Test St, Create City"
+ PhoneNumber="1234567",
+ FullAddress_Text="123 Test St, Create City", # ⭐ 满足长度约束
)
expected_new_address_id = 101
self.mock_cursor_result.lastrowid = expected_new_address_id
mock_created_address_dict = {
- "AddressID": expected_new_address_id, "UserID": user_id,
- "RecipientName": address_data.RecipientName, "PhoneNumber": address_data.PhoneNumber,
- "FullAddress_Text": address_data.FullAddress_Text, "IsDefault": False
+ "AddressID": expected_new_address_id,
+ "UserID": user_id,
+ "RecipientName": address_data.RecipientName,
+ "PhoneNumber": address_data.PhoneNumber,
+ "FullAddress_Text": address_data.FullAddress_Text,
+ "IsDefault": False,
}
- with patch.object(self.crud, 'get_address_by_id', return_value=mock_created_address_dict) as mock_get_by_id:
+ with patch.object(self.crud, "get_address_by_id", return_value=mock_created_address_dict) as mock_get_by_id:
created_address = self.crud.create_address(
self.mock_conn, user_id=user_id, address_in=address_data, actor_id=actor_id
)
@@ -60,26 +63,28 @@ def test_create_address_success(self):
self.assertIn(f"INSERT INTO {self.crud.table_name}", normalized_sql)
self.assertIn("VALUES (:UserID, :RecipientName, :PhoneNumber, :FullAddress_Text, FALSE)", normalized_sql)
- self.assertEqual(called_params['UserID'], user_id)
- self.assertEqual(called_params['RecipientName'], address_data.RecipientName)
- self.assertEqual(called_params['PhoneNumber'], address_data.PhoneNumber)
- self.assertEqual(called_params['FullAddress_Text'], address_data.FullAddress_Text)
+ self.assertEqual(called_params["UserID"], user_id)
+ self.assertEqual(called_params["RecipientName"], address_data.RecipientName)
+ self.assertEqual(called_params["PhoneNumber"], address_data.PhoneNumber)
+ self.assertEqual(called_params["FullAddress_Text"], address_data.FullAddress_Text)
- mock_get_by_id.assert_called_once_with(self.mock_conn, address_id=expected_new_address_id,
- actor_id=actor_id)
+ mock_get_by_id.assert_called_once_with(
+ self.mock_conn, address_id=expected_new_address_id, actor_id=actor_id
+ )
self.assertEqual(created_address, mock_created_address_dict)
- @patch('backend.app.crud.address_crud.logger')
+ @patch("backend.app.crud.address_crud.logger")
def test_create_address_lastrowid_none(self, mock_logger):
user_id = 2
actor_id = 2
- address_data = AddressCreateRequest(RecipientName="No LastRowID", PhoneNumber="5550000",
- FullAddress_Text="Some Addr")
+ address_data = AddressCreateRequest(
+ RecipientName="No LastRowID", PhoneNumber="5550000", FullAddress_Text="Some Addr"
+ )
self.mock_cursor_result.lastrowid = None
mock_retrieved_address = {"AddressID": 999, "UserID": user_id, "IsDefault": False}
- with patch.object(self.crud, 'get_address_by_id', return_value=mock_retrieved_address) as mock_get_by_id:
+ with patch.object(self.crud, "get_address_by_id", return_value=mock_retrieved_address) as mock_get_by_id:
created_address = self.crud.create_address(
self.mock_conn, user_id=user_id, address_in=address_data, actor_id=actor_id
)
@@ -92,12 +97,13 @@ def test_create_address_lastrowid_none(self, mock_logger):
def test_create_address_integrity_error(self):
user_id = 3
actor_id = 3
- address_data = AddressCreateRequest(RecipientName="Integrity Fail", PhoneNumber="1112233",
- FullAddress_Text="Random Addr")
+ address_data = AddressCreateRequest(
+ RecipientName="Integrity Fail", PhoneNumber="1112233", FullAddress_Text="Random Addr"
+ )
self.mock_conn.execute.side_effect = exc.IntegrityError("mocked integrity error", params={}, orig=None)
- with patch('backend.app.crud.address_crud.logger') as mock_logger:
+ with patch("backend.app.crud.address_crud.logger") as mock_logger:
created_address = self.crud.create_address(
self.mock_conn, user_id=user_id, address_in=address_data, actor_id=actor_id
)
@@ -122,9 +128,10 @@ def test_get_address_by_id_found(self):
self.assertIn(
f"SELECT AddressID, UserID, RecipientName, PhoneNumber, FullAddress_Text, IsDefault FROM {self.crud.table_name}",
- normalized_sql)
+ normalized_sql,
+ )
self.assertIn("WHERE AddressID = :AddressID", normalized_sql)
- self.assertEqual(called_params['AddressID'], address_id)
+ self.assertEqual(called_params["AddressID"], address_id)
self.assertEqual(address, expected_data)
def test_get_address_by_id_not_found(self):
@@ -137,9 +144,9 @@ def test_get_addresses_by_user_id_found(self):
actor_id = 1
mock_row_data1 = {"AddressID": 1, "UserID": user_id, "RecipientName": "Addr 1"}
mock_row_data2 = {"AddressID": 2, "UserID": user_id, "RecipientName": "Addr 2"}
- mock_row1 = MagicMock();
+ mock_row1 = MagicMock()
mock_row1._mapping = mock_row_data1
- mock_row2 = MagicMock();
+ mock_row2 = MagicMock()
mock_row2._mapping = mock_row_data2
self.mock_cursor_result.fetchall.return_value = [mock_row1, mock_row2]
@@ -152,10 +159,11 @@ def test_get_addresses_by_user_id_found(self):
self.assertIn(
f"SELECT AddressID, UserID, RecipientName, PhoneNumber, FullAddress_Text, IsDefault FROM {self.crud.table_name}",
- normalized_sql)
+ normalized_sql,
+ )
self.assertIn("WHERE UserID = :UserID", normalized_sql)
self.assertIn("ORDER BY IsDefault DESC, AddressID ASC", normalized_sql)
- self.assertEqual(called_params['UserID'], user_id)
+ self.assertEqual(called_params["UserID"], user_id)
self.assertEqual(len(addresses), 2)
self.assertEqual(addresses[0], mock_row_data1)
@@ -172,17 +180,20 @@ def test_update_address_details_success(self):
update_data = AddressUpdateRequest(
RecipientName="Updated Name",
FullAddress_Text="Updated Address Text",
- PhoneNumber="9876543" # ⭐ 满足长度约束
+ PhoneNumber="9876543", # ⭐ 满足长度约束
)
self.mock_cursor_result.rowcount = 1
mock_updated_address_dict = {
- "AddressID": address_id, "UserID": 1,
- "RecipientName": "Updated Name", "PhoneNumber": "9876543",
- "FullAddress_Text": "Updated Address Text", "IsDefault": False
+ "AddressID": address_id,
+ "UserID": 1,
+ "RecipientName": "Updated Name",
+ "PhoneNumber": "9876543",
+ "FullAddress_Text": "Updated Address Text",
+ "IsDefault": False,
}
- with patch.object(self.crud, 'get_address_by_id', return_value=mock_updated_address_dict) as mock_get_by_id:
+ with patch.object(self.crud, "get_address_by_id", return_value=mock_updated_address_dict) as mock_get_by_id:
updated_address = self.crud.update_address_details(
self.mock_conn, address_id=address_id, address_in=update_data, actor_id=actor_id
)
@@ -200,10 +211,10 @@ def test_update_address_details_success(self):
self.assertIn("FullAddress_Text = :FullAddress_Text", normalized_sql)
self.assertIn("WHERE AddressID = :AddressID_param", normalized_sql)
- self.assertEqual(called_params['RecipientName'], "Updated Name")
- self.assertEqual(called_params['PhoneNumber'], "9876543")
- self.assertEqual(called_params['FullAddress_Text'], "Updated Address Text")
- self.assertEqual(called_params['AddressID_param'], address_id)
+ self.assertEqual(called_params["RecipientName"], "Updated Name")
+ self.assertEqual(called_params["PhoneNumber"], "9876543")
+ self.assertEqual(called_params["FullAddress_Text"], "Updated Address Text")
+ self.assertEqual(called_params["AddressID_param"], address_id)
mock_get_by_id.assert_called_once_with(self.mock_conn, address_id=address_id, actor_id=actor_id)
self.assertEqual(updated_address, mock_updated_address_dict)
@@ -214,7 +225,7 @@ def test_update_address_details_no_fields_to_update(self):
empty_update_data = AddressUpdateRequest()
mock_current_address_dict = {"AddressID": address_id, "RecipientName": "Current Name"}
- with patch.object(self.crud, 'get_address_by_id', return_value=mock_current_address_dict) as mock_get_by_id:
+ with patch.object(self.crud, "get_address_by_id", return_value=mock_current_address_dict) as mock_get_by_id:
result = self.crud.update_address_details(
self.mock_conn, address_id=address_id, address_in=empty_update_data, actor_id=actor_id
)
@@ -248,10 +259,11 @@ def test_update_address_is_default_flag_success_true(self):
called_params = self.mock_conn.execute.call_args.args[1]
normalized_sql = normalize_sql(str(called_stmt_obj.text))
- self.assertIn(f"UPDATE {self.crud.table_name} SET IsDefault = :IsDefault WHERE AddressID = :AddressID",
- normalized_sql)
- self.assertEqual(called_params['IsDefault'], True)
- self.assertEqual(called_params['AddressID'], address_id)
+ self.assertIn(
+ f"UPDATE {self.crud.table_name} SET IsDefault = :IsDefault WHERE AddressID = :AddressID", normalized_sql
+ )
+ self.assertEqual(called_params["IsDefault"], True)
+ self.assertEqual(called_params["AddressID"], address_id)
def test_update_address_is_default_flag_not_found_or_no_change(self):
address_id = 31
@@ -279,14 +291,16 @@ def test_set_all_other_addresses_non_default_for_user_updates_some(self):
called_params = self.mock_conn.execute.call_args.args[1]
normalized_sql = normalize_sql(str(called_stmt_obj.text))
- expected_sql_part = normalize_sql(f"""
+ expected_sql_part = normalize_sql(
+ f"""
UPDATE {self.crud.table_name}
SET IsDefault = FALSE
WHERE UserID = :UserID AND AddressID != :ExceptAddressID AND IsDefault = TRUE
- """)
+ """
+ )
self.assertEqual(normalized_sql, expected_sql_part) # Check for exact match after normalization
- self.assertEqual(called_params['UserID'], user_id)
- self.assertEqual(called_params['ExceptAddressID'], except_address_id)
+ self.assertEqual(called_params["UserID"], user_id)
+ self.assertEqual(called_params["ExceptAddressID"], except_address_id)
def test_delete_address_success(self):
address_id = 40
@@ -302,7 +316,7 @@ def test_delete_address_success(self):
normalized_sql = normalize_sql(str(called_stmt_obj.text))
self.assertIn(f"DELETE FROM {self.crud.table_name} WHERE AddressID = :AddressID", normalized_sql)
- self.assertEqual(called_params['AddressID'], address_id)
+ self.assertEqual(called_params["AddressID"], address_id)
def test_delete_address_not_found(self):
address_id = 999
@@ -313,5 +327,5 @@ def test_delete_address_not_found(self):
self.assertFalse(result)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/src/backend/test/unit/crud/test_cartitem_crud.py b/src/backend/test/unit/crud/test_cartitem_crud.py
index 52c91f8..7aec15d 100644
--- a/src/backend/test/unit/crud/test_cartitem_crud.py
+++ b/src/backend/test/unit/crud/test_cartitem_crud.py
@@ -6,6 +6,7 @@
# 假设 CartItemCRUD 位于此路径 (相对于 src/)
from backend.app.crud.cartitem_crud import CartItemCRUD
+
# 如果 CartItemCRUD 未找到,您可能需要在测试运行器中调整 sys.path
# 或确保您的项目结构被测试发现器正确识别。
@@ -42,17 +43,21 @@ def test_add_item_to_cart_new_item_success(self):
mock_insert_result.lastrowid = new_cart_item_id
expected_created_item_data = {
- "CartItemID": new_cart_item_id, "UserID": user_id, "ProductID": product_id,
- "Quantity": quantity, "PriceAtAddition": price_at_addition,
- "AddedDate": datetime.datetime.now(datetime.timezone.utc)
+ "CartItemID": new_cart_item_id,
+ "UserID": user_id,
+ "ProductID": product_id,
+ "Quantity": quantity,
+ "PriceAtAddition": price_at_addition,
+ "AddedDate": datetime.datetime.now(datetime.timezone.utc),
}
- with patch.object(self.crud, 'get_cart_item_by_id',
- return_value=expected_created_item_data) as mock_get_by_id_internal:
+ with patch.object(
+ self.crud, "get_cart_item_by_id", return_value=expected_created_item_data
+ ) as mock_get_by_id_internal:
self.mock_conn.execute.side_effect = [
mock_set_actor_id_result,
mock_set_actor_id_result,
mock_select_existing_result,
- mock_insert_result
+ mock_insert_result,
]
cart_item = self.crud.add_item_to_cart(
@@ -61,23 +66,23 @@ def test_add_item_to_cart_new_item_success(self):
product_id=product_id,
quantity=quantity,
price_at_addition=price_at_addition,
- actor_id=actor_id
+ actor_id=actor_id,
)
- self.assertEqual(self.mock_conn.execute.call_count,
- 4) # 2 次设置 actor_id,1 次查询现有项,1 次插入新项
+ self.assertEqual(self.mock_conn.execute.call_count, 4) # 2 次设置 actor_id,1 次查询现有项,1 次插入新项
select_call_args = self.mock_conn.execute.call_args_list[2].args
self.assertIn("SELECT", str(select_call_args[0].text))
- self.assertEqual(select_call_args[1]['user_id'], user_id)
- self.assertEqual(select_call_args[1]['product_id'], product_id)
+ self.assertEqual(select_call_args[1]["user_id"], user_id)
+ self.assertEqual(select_call_args[1]["product_id"], product_id)
insert_call_args = self.mock_conn.execute.call_args_list[3].args
self.assertIn(f"INSERT INTO {self.crud.table_name}", str(insert_call_args[0].text))
- self.assertEqual(insert_call_args[1]['quantity'], quantity)
+ self.assertEqual(insert_call_args[1]["quantity"], quantity)
- mock_get_by_id_internal.assert_called_once_with(self.mock_conn, cart_item_id=new_cart_item_id,
- actor_id=actor_id)
+ mock_get_by_id_internal.assert_called_once_with(
+ self.mock_conn, cart_item_id=new_cart_item_id, actor_id=actor_id
+ )
self.assertEqual(cart_item, expected_created_item_data)
# 移除了对 self.mock_logger.info 的断言
@@ -99,17 +104,21 @@ def test_add_item_to_cart_update_existing_item_success(self):
mock_update_result = MagicMock()
expected_updated_item_data = {
- "CartItemID": existing_cart_item_id, "UserID": user_id, "ProductID": product_id,
- "Quantity": expected_new_quantity, "PriceAtAddition": price_at_addition,
- "AddedDate": datetime.datetime.utcnow()
+ "CartItemID": existing_cart_item_id,
+ "UserID": user_id,
+ "ProductID": product_id,
+ "Quantity": expected_new_quantity,
+ "PriceAtAddition": price_at_addition,
+ "AddedDate": datetime.datetime.utcnow(),
}
- with patch.object(self.crud, 'get_cart_item_by_id',
- return_value=expected_updated_item_data) as mock_get_by_id_internal:
+ with patch.object(
+ self.crud, "get_cart_item_by_id", return_value=expected_updated_item_data
+ ) as mock_get_by_id_internal:
self.mock_conn.execute.side_effect = [
mock_set_actor_id_result, # add_to_cart
mock_set_actor_id_result, # check if item exists
mock_select_existing_result, # check if item exists
- mock_update_result # update item
+ mock_update_result, # update item
]
cart_item = self.crud.add_item_to_cart(
@@ -118,19 +127,20 @@ def test_add_item_to_cart_update_existing_item_success(self):
product_id=product_id,
quantity=add_quantity,
price_at_addition=price_at_addition,
- actor_id=actor_id
+ actor_id=actor_id,
)
self.assertEqual(self.mock_conn.execute.call_count, 4)
update_call_args = self.mock_conn.execute.call_args_list[3].args
self.assertIn(f"UPDATE {self.crud.table_name}", str(update_call_args[0].text))
- self.assertEqual(update_call_args[1]['quantity'], expected_new_quantity)
- self.assertEqual(update_call_args[1]['price_at_addition'], price_at_addition)
- self.assertEqual(update_call_args[1]['cart_item_id'], existing_cart_item_id)
+ self.assertEqual(update_call_args[1]["quantity"], expected_new_quantity)
+ self.assertEqual(update_call_args[1]["price_at_addition"], price_at_addition)
+ self.assertEqual(update_call_args[1]["cart_item_id"], existing_cart_item_id)
- mock_get_by_id_internal.assert_called_once_with(self.mock_conn, cart_item_id=existing_cart_item_id,
- actor_id=actor_id)
+ mock_get_by_id_internal.assert_called_once_with(
+ self.mock_conn, cart_item_id=existing_cart_item_id, actor_id=actor_id
+ )
self.assertEqual(cart_item, expected_updated_item_data)
# 移除了对 self.mock_logger.info 的断言
@@ -159,7 +169,7 @@ def test_get_cart_item_by_id_found(self):
self.assertIn(f"SELECT", str(call_args[0].text))
self.assertIn(f"FROM {self.crud.table_name}", str(call_args[0].text))
- self.assertEqual(call_args[1]['cart_item_id'], cart_item_id)
+ self.assertEqual(call_args[1]["cart_item_id"], cart_item_id)
self.assertEqual(item, expected_data)
def test_get_cart_item_by_id_not_found(self):
@@ -176,8 +186,9 @@ def test_get_cart_item_by_user_and_product_found(self):
mock_row._mapping = expected_data
self.mock_cursor_result.fetchone.return_value = mock_row
- item = self.crud.get_cart_item_by_user_and_product(self.mock_conn, user_id=user_id, product_id=product_id,
- actor_id=user_id)
+ item = self.crud.get_cart_item_by_user_and_product(
+ self.mock_conn, user_id=user_id, product_id=product_id, actor_id=user_id
+ )
self.assertEqual(item, expected_data)
self.assertEqual(self.mock_conn.execute.call_count, 2)
call_args = self.mock_conn.execute.call_args_list[1].args
@@ -185,10 +196,22 @@ def test_get_cart_item_by_user_and_product_found(self):
def test_get_cart_items_by_user_id_found(self):
user_id = 1
- mock_row1_data = {"CartItemID": 1, "UserID": user_id, "ProductID": 101, "Quantity": 1,
- "ProductName": "Product A", "ProductImageURL": "http://example.com/image1.jpg"}
- mock_row2_data = {"CartItemID": 2, "UserID": user_id, "ProductID": 102, "Quantity": 2,
- "ProductName": "Product B", "ProductImageURL": "http://example.com/image2.jpg"}
+ mock_row1_data = {
+ "CartItemID": 1,
+ "UserID": user_id,
+ "ProductID": 101,
+ "Quantity": 1,
+ "ProductName": "Product A",
+ "ProductImageURL": "http://example.com/image1.jpg",
+ }
+ mock_row2_data = {
+ "CartItemID": 2,
+ "UserID": user_id,
+ "ProductID": 102,
+ "Quantity": 2,
+ "ProductName": "Product B",
+ "ProductImageURL": "http://example.com/image2.jpg",
+ }
mock_row1 = MagicMock()
mock_row1._mapping = mock_row1_data
mock_row2 = MagicMock()
@@ -215,8 +238,9 @@ def test_update_cart_item_quantity_success(self):
self.mock_cursor_result.rowcount = 1
expected_updated_item_data = {"CartItemID": cart_item_id, "Quantity": new_quantity}
- with patch.object(self.crud, 'get_cart_item_by_id',
- return_value=expected_updated_item_data) as mock_get_by_id_internal:
+ with patch.object(
+ self.crud, "get_cart_item_by_id", return_value=expected_updated_item_data
+ ) as mock_get_by_id_internal:
updated_item = self.crud.update_cart_item_quantity(
self.mock_conn, cart_item_id=cart_item_id, new_quantity=new_quantity, actor_id=actor_id
)
@@ -224,25 +248,21 @@ def test_update_cart_item_quantity_success(self):
# self.mock_conn.execute.assert_called_once()
self.assertEqual(self.mock_conn.execute.call_count, 2) # 设置 actor_id 和更新数量
call_args = self.mock_conn.execute.call_args_list[1].args
- self.assertIn(
- f"UPDATE {self.crud.table_name}",
- str(call_args[0].text))
- self.assertIn(
- f"SET Quantity = :new_quantity",
- str(call_args[0].text)
- )
- self.assertEqual(call_args[1]['new_quantity'], new_quantity)
- self.assertEqual(call_args[1]['cart_item_id'], cart_item_id)
+ self.assertIn(f"UPDATE {self.crud.table_name}", str(call_args[0].text))
+ self.assertIn(f"SET Quantity = :new_quantity", str(call_args[0].text))
+ self.assertEqual(call_args[1]["new_quantity"], new_quantity)
+ self.assertEqual(call_args[1]["cart_item_id"], cart_item_id)
- mock_get_by_id_internal.assert_called_once_with(self.mock_conn, cart_item_id=cart_item_id,
- actor_id=actor_id)
+ mock_get_by_id_internal.assert_called_once_with(
+ self.mock_conn, cart_item_id=cart_item_id, actor_id=actor_id
+ )
self.assertEqual(updated_item, expected_updated_item_data)
# 移除了对 self.mock_logger.info 的断言
def test_update_cart_item_quantity_item_not_found(self):
self.mock_cursor_result.rowcount = 0
- with patch.object(self.crud, 'get_cart_item_by_id') as mock_get_by_id_internal:
+ with patch.object(self.crud, "get_cart_item_by_id") as mock_get_by_id_internal:
updated_item = self.crud.update_cart_item_quantity(
self.mock_conn, cart_item_id=999, new_quantity=3, actor_id=1
)
@@ -252,9 +272,7 @@ def test_update_cart_item_quantity_item_not_found(self):
def test_update_cart_item_quantity_non_positive_raises_value_error(self):
with self.assertRaisesRegex(ValueError, "Quantity must be positive. To remove an item, use the delete method."):
- self.crud.update_cart_item_quantity(
- self.mock_conn, cart_item_id=15, new_quantity=0, actor_id=1
- )
+ self.crud.update_cart_item_quantity(self.mock_conn, cart_item_id=15, new_quantity=0, actor_id=1)
# 移除了对 self.mock_logger.warning 的断言
def test_remove_item_from_cart_success(self):
@@ -307,7 +325,7 @@ def test_check_user_owns_cart_item_success(self):
self.assertEqual(self.mock_conn.execute.call_count, 2)
args, kwargs = self.mock_conn.execute.call_args_list[1]
self.assertIn("SELECT COUNT(*) AS ItemCount", str(args[0].text))
- self.assertEqual(args[1]['cart_item_id'], 123)
+ self.assertEqual(args[1]["cart_item_id"], 123)
def test_check_user_owns_cart_item_not_found(self):
# Mock the result of the SQL query
@@ -321,7 +339,8 @@ def test_check_user_owns_cart_item_not_found(self):
self.assertEqual(self.mock_conn.execute.call_count, 2)
args, kwargs = self.mock_conn.execute.call_args_list[1]
self.assertIn("SELECT COUNT(*) AS ItemCount", str(args[0].text))
- self.assertEqual(args[1]['cart_item_id'], 123)
+ self.assertEqual(args[1]["cart_item_id"], 123)
+
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/src/backend/test/unit/crud/test_category_crud.py b/src/backend/test/unit/crud/test_category_crud.py
index aaf6fdc..61914e5 100644
--- a/src/backend/test/unit/crud/test_category_crud.py
+++ b/src/backend/test/unit/crud/test_category_crud.py
@@ -10,7 +10,8 @@
def generate_random_string(length=8):
"""生成指定长度的随机字符串"""
letters_and_digits = string.ascii_letters + string.digits
- return ''.join(random.choice(letters_and_digits) for _ in range(length))
+ return "".join(random.choice(letters_and_digits) for _ in range(length))
+
@unittest.skip("This test has been moved to test_category_crud_integration.py")
class TestCategoryCRUD(BaseDBTestCaseAutoRollback):
@@ -26,13 +27,10 @@ def test_create_category(self):
"category_name": "测试分类",
"category_description": "这是一个用于测试的分类",
}
-
+
# 执行测试
- category = self.category_crud.create_category(
- conn=self.connection,
- **category_data
- )
-
+ category = self.category_crud.create_category(conn=self.connection, **category_data)
+
# 验证结果
self.assertIsNotNone(category)
self.assertIsInstance(category, dict)
@@ -45,23 +43,18 @@ def test_create_subcategory(self):
"""测试创建子分类"""
# 准备数据 - 先创建父分类
parent_category = self.category_crud.create_category(
- conn=self.connection,
- category_name="父分类",
- category_description="父分类描述"
+ conn=self.connection, category_name="父分类", category_description="父分类描述"
)
-
+
# 执行测试 - 创建子分类
subcategory_data = {
"category_name": "子分类",
"category_description": "子分类描述",
- "parent_category_id": parent_category["CategoryID"]
+ "parent_category_id": parent_category["CategoryID"],
}
-
- subcategory = self.category_crud.create_category(
- conn=self.connection,
- **subcategory_data
- )
-
+
+ subcategory = self.category_crud.create_category(conn=self.connection, **subcategory_data)
+
# 验证结果
self.assertIsNotNone(subcategory)
self.assertEqual(subcategory["CategoryName"], subcategory_data["category_name"])
@@ -72,72 +65,51 @@ def test_create_category_with_invalid_parent(self):
# 执行测试 - 尝试使用不存在的父分类ID
with self.assertRaises(ValueError):
self.category_crud.create_category(
- conn=self.connection,
- category_name="无效父分类的分类",
- parent_category_id=999999
+ conn=self.connection, category_name="无效父分类的分类", parent_category_id=999999
)
def test_get_category_by_id(self):
"""测试根据ID获取分类"""
# 准备数据 - 先创建分类
test_category = self.category_crud.create_category(
- conn=self.connection,
- category_name="测试分类2",
- category_description="用于测试查询的分类"
+ conn=self.connection, category_name="测试分类2", category_description="用于测试查询的分类"
)
-
+
# 执行测试
- category = self.category_crud.get_category_by_id(
- conn=self.connection,
- category_id=test_category["CategoryID"]
- )
-
+ category = self.category_crud.get_category_by_id(conn=self.connection, category_id=test_category["CategoryID"])
+
# 验证结果
self.assertIsNotNone(category)
self.assertEqual(category["CategoryID"], test_category["CategoryID"])
self.assertEqual(category["CategoryName"], "测试分类2")
-
+
# 测试获取不存在的分类
- non_existent_category = self.category_crud.get_category_by_id(
- conn=self.connection,
- category_id=999999
- )
+ non_existent_category = self.category_crud.get_category_by_id(conn=self.connection, category_id=999999)
self.assertIsNone(non_existent_category)
def test_get_categories(self):
"""测试获取分类列表"""
# 准备数据 - 创建一些顶级分类
for i in range(3):
- self.category_crud.create_category(
- conn=self.connection,
- category_name=f"顶级分类{i+1}"
- )
-
+ self.category_crud.create_category(conn=self.connection, category_name=f"顶级分类{i+1}")
+
# 执行测试 - 获取所有顶级分类
- top_categories = self.category_crud.get_categories(
- conn=self.connection,
- parent_id=None
- )
-
+ top_categories = self.category_crud.get_categories(conn=self.connection, parent_id=None)
+
# 验证结果
self.assertIsInstance(top_categories, list)
self.assertGreaterEqual(len(top_categories), 3) # 至少3个顶级分类
-
+
# 准备数据 - 为第一个分类创建子分类
parent_id = top_categories[0]["CategoryID"]
for i in range(2):
self.category_crud.create_category(
- conn=self.connection,
- category_name=f"子分类{i+1}",
- parent_category_id=parent_id
+ conn=self.connection, category_name=f"子分类{i+1}", parent_category_id=parent_id
)
-
+
# 执行测试 - 获取特定父分类的子分类
- sub_categories = self.category_crud.get_categories(
- conn=self.connection,
- parent_id=parent_id
- )
-
+ sub_categories = self.category_crud.get_categories(conn=self.connection, parent_id=parent_id)
+
# 验证结果
self.assertEqual(len(sub_categories), 2)
for category in sub_categories:
@@ -147,155 +119,105 @@ def test_update_category(self):
"""测试更新分类信息"""
# 准备数据
test_category = self.category_crud.create_category(
- conn=self.connection,
- category_name="测试更新分类",
- category_description="原始描述"
+ conn=self.connection, category_name="测试更新分类", category_description="原始描述"
)
-
+
# 执行测试 - 更新名称和描述
- update_data = {
- "categoryname": "更新后的分类名称",
- "categorydescription": "更新后的描述"
- }
-
+ update_data = {"categoryname": "更新后的分类名称", "categorydescription": "更新后的描述"}
+
updated_category = self.category_crud.update_category(
- conn=self.connection,
- category_id=test_category["CategoryID"],
- update_data=update_data
+ conn=self.connection, category_id=test_category["CategoryID"], update_data=update_data
)
-
+
# 验证结果
self.assertIsNotNone(updated_category)
self.assertEqual(updated_category["CategoryName"], update_data["categoryname"])
self.assertEqual(updated_category["CategoryDescription"], update_data["categorydescription"])
-
+
# 测试更新不存在的分类
non_existent_update = self.category_crud.update_category(
- conn=self.connection,
- category_id=999999,
- update_data={"categoryname": "不存在的分类"}
+ conn=self.connection, category_id=999999, update_data={"categoryname": "不存在的分类"}
)
self.assertIsNone(non_existent_update)
def test_update_category_parent(self):
"""测试更新分类的父分类"""
# 准备数据 - 创建两个分类
- category1 = self.category_crud.create_category(
- conn=self.connection,
- category_name="分类1"
- )
-
- category2 = self.category_crud.create_category(
- conn=self.connection,
- category_name="分类2"
- )
-
+ category1 = self.category_crud.create_category(conn=self.connection, category_name="分类1")
+
+ category2 = self.category_crud.create_category(conn=self.connection, category_name="分类2")
+
# 执行测试 - 将分类2设为分类1的父分类
- update_data = {
- "parentcategoryid": category2["CategoryID"]
- }
-
+ update_data = {"parentcategoryid": category2["CategoryID"]}
+
updated_category = self.category_crud.update_category(
- conn=self.connection,
- category_id=category1["CategoryID"],
- update_data=update_data
+ conn=self.connection, category_id=category1["CategoryID"], update_data=update_data
)
-
+
# 验证结果
self.assertIsNotNone(updated_category)
self.assertEqual(updated_category["ParentCategoryID"], category2["CategoryID"])
-
+
# 测试循环引用 - 将分类1设为分类2的父分类(应该失败,因为会形成循环)
- cyclic_update_data = {
- "parentcategoryid": category1["CategoryID"]
- }
-
+ cyclic_update_data = {"parentcategoryid": category1["CategoryID"]}
+
with self.assertRaises(ValueError):
self.category_crud.update_category(
- conn=self.connection,
- category_id=category2["CategoryID"],
- update_data=cyclic_update_data
+ conn=self.connection, category_id=category2["CategoryID"], update_data=cyclic_update_data
)
-
+
# 测试自引用 - 将分类设为自己的父分类(应该失败)
- self_update_data = {
- "parentcategoryid": category1["CategoryID"]
- }
-
+ self_update_data = {"parentcategoryid": category1["CategoryID"]}
+
with self.assertRaises(ValueError):
self.category_crud.update_category(
- conn=self.connection,
- category_id=category1["CategoryID"],
- update_data=self_update_data
+ conn=self.connection, category_id=category1["CategoryID"], update_data=self_update_data
)
def test_update_category_with_invalid_parent(self):
"""测试使用无效的父分类ID更新分类"""
# 准备数据
- test_category = self.category_crud.create_category(
- conn=self.connection,
- category_name="测试分类"
- )
-
+ test_category = self.category_crud.create_category(conn=self.connection, category_name="测试分类")
+
# 执行测试 - 尝试使用不存在的父分类ID
with self.assertRaises(ValueError):
self.category_crud.update_category(
- conn=self.connection,
- category_id=test_category["CategoryID"],
- update_data={"parentcategoryid": 999999}
+ conn=self.connection, category_id=test_category["CategoryID"], update_data={"parentcategoryid": 999999}
)
def test_delete_category(self):
"""测试删除分类"""
# 准备数据
- test_category = self.category_crud.create_category(
- conn=self.connection,
- category_name="准备删除的分类"
- )
-
+ test_category = self.category_crud.create_category(conn=self.connection, category_name="准备删除的分类")
+
# 执行测试
- result = self.category_crud.delete_category(
- conn=self.connection,
- category_id=test_category["CategoryID"]
- )
-
+ result = self.category_crud.delete_category(conn=self.connection, category_id=test_category["CategoryID"])
+
# 验证结果
self.assertTrue(result)
-
+
# 验证分类已被删除
deleted_category = self.category_crud.get_category_by_id(
- conn=self.connection,
- category_id=test_category["CategoryID"]
+ conn=self.connection, category_id=test_category["CategoryID"]
)
self.assertIsNone(deleted_category)
-
+
# 测试删除不存在的分类
- non_existent_result = self.category_crud.delete_category(
- conn=self.connection,
- category_id=999999
- )
+ non_existent_result = self.category_crud.delete_category(conn=self.connection, category_id=999999)
self.assertFalse(non_existent_result)
def test_delete_category_with_subcategories(self):
"""测试删除有子分类的分类(应该失败)"""
# 准备数据 - 创建父分类和子分类
- parent_category = self.category_crud.create_category(
- conn=self.connection,
- category_name="父分类"
- )
-
+ parent_category = self.category_crud.create_category(conn=self.connection, category_name="父分类")
+
self.category_crud.create_category(
- conn=self.connection,
- category_name="子分类",
- parent_category_id=parent_category["CategoryID"]
+ conn=self.connection, category_name="子分类", parent_category_id=parent_category["CategoryID"]
)
-
+
# 执行测试 - 尝试删除有子分类的父分类
with self.assertRaises(ValueError):
- self.category_crud.delete_category(
- conn=self.connection,
- category_id=parent_category["CategoryID"]
- )
+ self.category_crud.delete_category(conn=self.connection, category_id=parent_category["CategoryID"])
def test_delete_category_with_products(self):
"""测试删除有商品的分类(应该失败)"""
@@ -303,114 +225,98 @@ def test_delete_category_with_products(self):
with self.engine.begin() as conn:
# 创建分类
category_crud = get_category_crud_instance()
- test_category = category_crud.create_category(
- conn=conn,
- category_name="有商品的分类"
- )
-
+ test_category = category_crud.create_category(conn=conn, category_name="有商品的分类")
+
# 创建测试用户,使用随机用户名避免唯一键冲突
random_suffix = generate_random_string()
test_user_name = f"testuser_{random_suffix}"
test_user_email = f"{test_user_name}@example.com"
-
+
conn.execute(
- text("""
+ text(
+ """
INSERT INTO User (Username, PasswordHash, Email, UserRole)
VALUES (:username, 'password_hash', :email, 'merchant')
- """),
- {"username": test_user_name, "email": test_user_email}
+ """
+ ),
+ {"username": test_user_name, "email": test_user_email},
)
user_id_result = conn.execute(text("SELECT LAST_INSERT_ID()")).scalar()
-
+
# 创建测试店铺
conn.execute(
- text("""
+ text(
+ """
INSERT INTO Store (StoreName, OwnerUserID, Description, StoreStatus)
VALUES ('测试店铺', :user_id, '用于测试的店铺', 'ACTIVE')
- """),
- {"user_id": user_id_result}
+ """
+ ),
+ {"user_id": user_id_result},
)
store_id_result = conn.execute(text("SELECT LAST_INSERT_ID()")).scalar()
-
+
# 创建商品,关联到这个分类
conn.execute(
- text("""
+ text(
+ """
INSERT INTO Product (ProductName, Price, StoreID, CategoryID, StockQuantity, ProductStatus)
VALUES ('测试商品', 99.99, :store_id, :category_id, 10, 'ACTIVE')
- """),
- {"store_id": store_id_result, "category_id": test_category["CategoryID"]}
+ """
+ ),
+ {"store_id": store_id_result, "category_id": test_category["CategoryID"]},
)
-
+
# 执行测试 - 尝试删除有商品的分类
with self.assertRaises(ValueError):
- self.category_crud.delete_category(
- conn=self.connection,
- category_id=test_category["CategoryID"]
- )
+ self.category_crud.delete_category(conn=self.connection, category_id=test_category["CategoryID"])
def test_get_category_tree(self):
"""测试获取分类树结构"""
# 准备数据 - 创建一些分类和子分类
# 顶级分类
- cat1 = self.category_crud.create_category(
- conn=self.connection,
- category_name="电子产品"
- )
-
- cat2 = self.category_crud.create_category(
- conn=self.connection,
- category_name="服装"
- )
-
+ cat1 = self.category_crud.create_category(conn=self.connection, category_name="电子产品")
+
+ cat2 = self.category_crud.create_category(conn=self.connection, category_name="服装")
+
# 电子产品的子分类
cat1_1 = self.category_crud.create_category(
- conn=self.connection,
- category_name="手机",
- parent_category_id=cat1["CategoryID"]
+ conn=self.connection, category_name="手机", parent_category_id=cat1["CategoryID"]
)
-
+
cat1_2 = self.category_crud.create_category(
- conn=self.connection,
- category_name="电脑",
- parent_category_id=cat1["CategoryID"]
+ conn=self.connection, category_name="电脑", parent_category_id=cat1["CategoryID"]
)
-
+
# 服装的子分类
cat2_1 = self.category_crud.create_category(
- conn=self.connection,
- category_name="男装",
- parent_category_id=cat2["CategoryID"]
+ conn=self.connection, category_name="男装", parent_category_id=cat2["CategoryID"]
)
-
+
# 电脑的子分类
cat1_2_1 = self.category_crud.create_category(
- conn=self.connection,
- category_name="笔记本电脑",
- parent_category_id=cat1_2["CategoryID"]
+ conn=self.connection, category_name="笔记本电脑", parent_category_id=cat1_2["CategoryID"]
)
-
+
# 执行测试
- category_tree = self.category_crud.get_category_tree(
- conn=self.connection
- )
-
+ category_tree = self.category_crud.get_category_tree(conn=self.connection)
+
# 验证结果 - 顶级分类数量
self.assertIsInstance(category_tree, list)
self.assertGreaterEqual(len(category_tree), 2) # 至少两个顶级分类
-
+
# 验证结果 - 树结构
# 找到"电子产品"分类
electronic_cat = next((cat for cat in category_tree if cat["CategoryName"] == "电子产品"), None)
self.assertIsNotNone(electronic_cat)
self.assertIn("Children", electronic_cat)
self.assertGreaterEqual(len(electronic_cat["Children"]), 2) # 至少两个子分类
-
+
# 找到"电脑"分类
computer_cat = next((cat for cat in electronic_cat["Children"] if cat["CategoryName"] == "电脑"), None)
self.assertIsNotNone(computer_cat)
self.assertIn("Children", computer_cat)
self.assertGreaterEqual(len(computer_cat["Children"]), 1) # 至少一个子分类
-
+
# 找到"服装"分类
clothing_cat = next((cat for cat in category_tree if cat["CategoryName"] == "服装"), None)
self.assertIsNotNone(clothing_cat)
@@ -418,5 +324,5 @@ def test_get_category_tree(self):
self.assertGreaterEqual(len(clothing_cat["Children"]), 1) # 至少一个子分类
-if __name__ == '__main__':
- unittest.main()
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/backend/test/unit/crud/test_order_crud.py b/src/backend/test/unit/crud/test_order_crud.py
index 4adbc0c..af6fda6 100644
--- a/src/backend/test/unit/crud/test_order_crud.py
+++ b/src/backend/test/unit/crud/test_order_crud.py
@@ -26,7 +26,7 @@ def setUp(self):
self.mock_conn.execute.return_value = self.mock_cursor_result
# Patch _set_actor_session_variable
- self.set_actor_patcher = patch.object(OrderCRUD, '_set_actor_session_variable')
+ self.set_actor_patcher = patch.object(OrderCRUD, "_set_actor_session_variable")
self.mock_set_actor_session_variable = self.set_actor_patcher.start()
self.addCleanup(self.set_actor_patcher.stop)
@@ -51,7 +51,7 @@ def setUp(self):
"ShippingTime": None,
"DeliveryTime": None,
"CompletionTime": None,
- "LastUpdatedDate": datetime.datetime(2025, 1, 1, 10, 5, 0)
+ "LastUpdatedDate": datetime.datetime(2025, 1, 1, 10, 5, 0),
}
def test_create_order_success(self):
@@ -79,9 +79,9 @@ def test_create_order_success(self):
"ShippingAddress_RecipientName": "New Order Recipient",
"ShippingAddress_PhoneNumber": "13012345678",
"ShippingAddress_Full": "456 New Address, New City",
- "Notes_ByUser": "Urgent delivery"
+ "Notes_ByUser": "Urgent delivery",
}
- with patch.object(self.crud, 'get_order_by_id', return_value=mock_created_order_dict) as mock_get_by_id:
+ with patch.object(self.crud, "get_order_by_id", return_value=mock_created_order_dict) as mock_get_by_id:
created_order = self.crud.create_order(
conn=self.mock_conn,
user_id=user_id,
@@ -94,7 +94,7 @@ def test_create_order_success(self):
shipping_address_phone_number="13012345678",
shipping_address_full="456 New Address, New City",
notes_by_user="Urgent delivery",
- actor_id=actor_id
+ actor_id=actor_id,
)
self.mock_set_actor_session_variable.assert_called_once_with(self.mock_conn, actor_id)
@@ -106,24 +106,31 @@ def test_create_order_success(self):
self.assertIn(":UserID, :StoreID, :PaymentTransactionID, :OrderStatus", normalized_sql)
params = call_args[1]
- self.assertEqual(params['UserID'], user_id)
- self.assertEqual(params['StoreID'], store_id)
- self.assertEqual(params['PaymentTransactionID'], payment_transaction_id)
- self.assertEqual(params['OrderStatus'], order_status.value)
- self.assertEqual(params['OrderTotalAmount'], order_total)
+ self.assertEqual(params["UserID"], user_id)
+ self.assertEqual(params["StoreID"], store_id)
+ self.assertEqual(params["PaymentTransactionID"], payment_transaction_id)
+ self.assertEqual(params["OrderStatus"], order_status.value)
+ self.assertEqual(params["OrderTotalAmount"], order_total)
# ... 检查其他参数 ...
mock_get_by_id.assert_called_once_with(self.mock_conn, order_id=expected_new_order_id, actor_id=actor_id)
self.assertEqual(created_order, mock_created_order_dict)
- @patch('backend.app.crud.order_crud.logger')
+ @patch("backend.app.crud.order_crud.logger")
def test_create_order_lastrowid_none(self, mock_logger):
self.mock_cursor_result.lastrowid = None
result = self.crud.create_order(
- conn=self.mock_conn, user_id=1, store_id=1, payment_transaction_id=1,
- order_status=OrderStatusEnum.PENDING_PAYMENT, order_total_amount=Decimal(1),
- final_amount_for_this_order=Decimal(1), shipping_address_recipient_name="N",
- shipping_address_phone_number="P", shipping_address_full="F", actor_id=1
+ conn=self.mock_conn,
+ user_id=1,
+ store_id=1,
+ payment_transaction_id=1,
+ order_status=OrderStatusEnum.PENDING_PAYMENT,
+ order_total_amount=Decimal(1),
+ final_amount_for_this_order=Decimal(1),
+ shipping_address_recipient_name="N",
+ shipping_address_phone_number="P",
+ shipping_address_full="F",
+ actor_id=1,
)
self.assertIsNone(result)
mock_logger.warning.assert_called_once()
@@ -131,12 +138,19 @@ def test_create_order_lastrowid_none(self, mock_logger):
def test_create_order_integrity_error(self):
self.mock_conn.execute.side_effect = exc.IntegrityError("mock integrity", {}, None)
- with patch('backend.app.crud.order_crud.logger') as mock_logger:
+ with patch("backend.app.crud.order_crud.logger") as mock_logger:
result = self.crud.create_order(
- conn=self.mock_conn, user_id=1, store_id=1, payment_transaction_id=1,
- order_status=OrderStatusEnum.PENDING_PAYMENT, order_total_amount=Decimal(1),
- final_amount_for_this_order=Decimal(1), shipping_address_recipient_name="N",
- shipping_address_phone_number="P", shipping_address_full="F", actor_id=1
+ conn=self.mock_conn,
+ user_id=1,
+ store_id=1,
+ payment_transaction_id=1,
+ order_status=OrderStatusEnum.PENDING_PAYMENT,
+ order_total_amount=Decimal(1),
+ final_amount_for_this_order=Decimal(1),
+ shipping_address_recipient_name="N",
+ shipping_address_phone_number="P",
+ shipping_address_full="F",
+ actor_id=1,
)
self.assertIsNone(result)
mock_logger.error.assert_called_once()
@@ -155,7 +169,7 @@ def test_get_order_by_id_found(self):
normalized_sql = normalize_sql(str(call_args[0].text))
self.assertIn(f"SELECT OrderID, UserID, StoreID", normalized_sql) # 检查部分字段
self.assertIn(f"WHERE OrderID = :OrderID", normalized_sql)
- self.assertEqual(call_args[1]['OrderID'], order_id)
+ self.assertEqual(call_args[1]["OrderID"], order_id)
self.assertEqual(order, self.sample_order_data_dict)
def test_get_order_by_id_not_found(self):
@@ -178,9 +192,9 @@ def test_get_orders_by_user_id_found(self):
normalized_sql = normalize_sql(str(call_args[0].text))
self.assertIn(f"FROM {self.crud.table_name} WHERE UserID = :UserID", normalized_sql)
self.assertIn("ORDER BY CreationTime DESC LIMIT :Limit OFFSET :Offset", normalized_sql)
- self.assertEqual(call_args[1]['UserID'], user_id)
- self.assertEqual(call_args[1]['Limit'], 10)
- self.assertEqual(call_args[1]['Offset'], 0)
+ self.assertEqual(call_args[1]["UserID"], user_id)
+ self.assertEqual(call_args[1]["Limit"], 10)
+ self.assertEqual(call_args[1]["Offset"], 0)
self.assertEqual(len(orders), 2)
self.assertEqual(orders[0]["OrderID"], 1)
@@ -200,13 +214,20 @@ def test_update_order_status_success(self):
self.mock_cursor_result.rowcount = 1 # Simulate update affected 1 row
- mock_updated_order_dict = {**self.sample_order_data_dict, "OrderStatus": new_status.value,
- "ShippingTime": shipping_time_val}
- with patch.object(self.crud, 'get_order_by_id', return_value=mock_updated_order_dict) as mock_get_by_id:
+ mock_updated_order_dict = {
+ **self.sample_order_data_dict,
+ "OrderStatus": new_status.value,
+ "ShippingTime": shipping_time_val,
+ }
+ with patch.object(self.crud, "get_order_by_id", return_value=mock_updated_order_dict) as mock_get_by_id:
updated_order = self.crud.update_order_status(
- conn=self.mock_conn, order_id=order_id, new_status=new_status,
- actor_id=actor_id, shipping_time=shipping_time_val,
- notes_by_actor="Shipped by admin", is_admin_or_merchant_action=True
+ conn=self.mock_conn,
+ order_id=order_id,
+ new_status=new_status,
+ actor_id=actor_id,
+ shipping_time=shipping_time_val,
+ notes_by_actor="Shipped by admin",
+ is_admin_or_merchant_action=True,
)
self.mock_set_actor_session_variable.assert_called_once_with(self.mock_conn, actor_id)
@@ -217,15 +238,16 @@ def test_update_order_status_success(self):
self.assertIn(f"UPDATE {self.crud.table_name} SET", normalized_sql)
self.assertIn("OrderStatus = :OrderStatus", normalized_sql)
self.assertIn("ShippingTime = :ShippingTime", normalized_sql)
- self.assertIn("Notes_ByMerchant = :Notes_ByMerchant",
- normalized_sql) # because is_admin_or_merchant_action=True
+ self.assertIn(
+ "Notes_ByMerchant = :Notes_ByMerchant", normalized_sql
+ ) # because is_admin_or_merchant_action=True
self.assertIn("WHERE OrderID = :OrderID_param", normalized_sql)
params = call_args[1]
- self.assertEqual(params['OrderStatus'], new_status.value)
- self.assertEqual(params['ShippingTime'], shipping_time_val)
- self.assertEqual(params['Notes_ByMerchant'], "Shipped by admin")
- self.assertEqual(params['OrderID_param'], order_id)
+ self.assertEqual(params["OrderStatus"], new_status.value)
+ self.assertEqual(params["ShippingTime"], shipping_time_val)
+ self.assertEqual(params["Notes_ByMerchant"], "Shipped by admin")
+ self.assertEqual(params["OrderID_param"], order_id)
mock_get_by_id.assert_called_once_with(self.mock_conn, order_id=order_id, actor_id=actor_id)
self.assertEqual(updated_order, mock_updated_order_dict)
@@ -244,5 +266,5 @@ def test_update_order_status_no_change_or_not_found(self):
# (unless SUT logic changes to fetch regardless)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/src/backend/test/unit/crud/test_order_item_crud.py b/src/backend/test/unit/crud/test_order_item_crud.py
index 5c3d3ce..7848455 100644
--- a/src/backend/test/unit/crud/test_order_item_crud.py
+++ b/src/backend/test/unit/crud/test_order_item_crud.py
@@ -25,7 +25,7 @@ def setUp(self):
self.mock_conn.execute.return_value = self.mock_cursor_result
# Patch _set_actor_session_variable 作为 AddressCRUD 类的静态方法
- self.set_actor_patcher = patch.object(OrderItemCRUD, '_set_actor_session_variable')
+ self.set_actor_patcher = patch.object(OrderItemCRUD, "_set_actor_session_variable")
self.mock_set_actor_session_variable = self.set_actor_patcher.start()
self.addCleanup(self.set_actor_patcher.stop)
@@ -39,7 +39,7 @@ def setUp(self):
"PriceAtPurchase": Decimal("19.99"),
"ProductNameAtPurchase": "Test Product",
"ProductImageURLAtPurchase": "https://example.com/image.jpg",
- "Subtotal": Decimal("39.98")
+ "Subtotal": Decimal("39.98"),
}
def test_create_order_item_success(self):
@@ -58,12 +58,17 @@ def test_create_order_item_success(self):
# 模拟 get_order_item_by_id 的返回值
mock_created_item_dict = {
- "OrderItemID": expected_new_item_id, "OrderID": order_id, "ProductID": product_id,
- "StoreID": store_id, "Quantity": quantity, "PriceAtPurchase": price,
- "ProductNameAtPurchase": name, "ProductImageURLAtPurchase": image_url,
- "Subtotal": subtotal
+ "OrderItemID": expected_new_item_id,
+ "OrderID": order_id,
+ "ProductID": product_id,
+ "StoreID": store_id,
+ "Quantity": quantity,
+ "PriceAtPurchase": price,
+ "ProductNameAtPurchase": name,
+ "ProductImageURLAtPurchase": image_url,
+ "Subtotal": subtotal,
}
- with patch.object(self.crud, 'get_order_item_by_id', return_value=mock_created_item_dict) as mock_get_by_id:
+ with patch.object(self.crud, "get_order_item_by_id", return_value=mock_created_item_dict) as mock_get_by_id:
created_item = self.crud.create_order_item(
conn=self.mock_conn,
order_id=order_id,
@@ -74,7 +79,7 @@ def test_create_order_item_success(self):
product_name_at_purchase=name,
product_image_url_at_purchase=image_url,
subtotal=subtotal,
- actor_id=actor_id
+ actor_id=actor_id,
)
self.mock_set_actor_session_variable.assert_called_once_with(self.mock_conn, actor_id)
@@ -85,35 +90,43 @@ def test_create_order_item_success(self):
self.assertIn(f"INSERT INTO {self.crud.table_name}", normalized_sql)
self.assertIn(
"OrderID, ProductID, StoreID, Quantity, PriceAtPurchase, ProductNameAtPurchase, ProductImageURLAtPurchase, Subtotal",
- normalized_sql
+ normalized_sql,
)
self.assertIn(
":OrderID, :ProductID, :StoreID, :Quantity, :PriceAtPurchase, :ProductNameAtPurchase, :ProductImageURLAtPurchase, :Subtotal",
- normalized_sql
+ normalized_sql,
)
params = call_args[1]
- self.assertEqual(params['OrderID'], order_id)
- self.assertEqual(params['ProductID'], product_id)
- self.assertEqual(params['StoreID'], store_id)
- self.assertEqual(params['Quantity'], quantity)
- self.assertEqual(params['PriceAtPurchase'], price)
- self.assertEqual(params['ProductNameAtPurchase'], name)
- self.assertEqual(params['ProductImageURLAtPurchase'], image_url)
- self.assertEqual(params['Subtotal'], subtotal)
-
- mock_get_by_id.assert_called_once_with(self.mock_conn, order_item_id=expected_new_item_id,
- actor_id=actor_id)
+ self.assertEqual(params["OrderID"], order_id)
+ self.assertEqual(params["ProductID"], product_id)
+ self.assertEqual(params["StoreID"], store_id)
+ self.assertEqual(params["Quantity"], quantity)
+ self.assertEqual(params["PriceAtPurchase"], price)
+ self.assertEqual(params["ProductNameAtPurchase"], name)
+ self.assertEqual(params["ProductImageURLAtPurchase"], image_url)
+ self.assertEqual(params["Subtotal"], subtotal)
+
+ mock_get_by_id.assert_called_once_with(
+ self.mock_conn, order_item_id=expected_new_item_id, actor_id=actor_id
+ )
self.assertEqual(created_item, mock_created_item_dict)
- @patch('backend.app.crud.order_item_crud.logger') # Patch logger for this specific test
+ @patch("backend.app.crud.order_item_crud.logger") # Patch logger for this specific test
def test_create_order_item_lastrowid_none(self, mock_logger):
self.mock_cursor_result.lastrowid = None # Simulate lastrowid not available
result = self.crud.create_order_item(
- conn=self.mock_conn, order_id=1, product_id=1, store_id=1, quantity=1,
- price_at_purchase=Decimal(1), product_name_at_purchase="N",
- product_image_url_at_purchase=None, subtotal=Decimal(1), actor_id=1
+ conn=self.mock_conn,
+ order_id=1,
+ product_id=1,
+ store_id=1,
+ quantity=1,
+ price_at_purchase=Decimal(1),
+ product_name_at_purchase="N",
+ product_image_url_at_purchase=None,
+ subtotal=Decimal(1),
+ actor_id=1,
)
self.assertIsNone(result)
mock_logger.warning.assert_called_once()
@@ -121,11 +134,18 @@ def test_create_order_item_lastrowid_none(self, mock_logger):
def test_create_order_item_integrity_error(self):
self.mock_conn.execute.side_effect = exc.IntegrityError("mock integrity", {}, None)
- with patch('backend.app.crud.order_item_crud.logger') as mock_logger:
+ with patch("backend.app.crud.order_item_crud.logger") as mock_logger:
result = self.crud.create_order_item(
- conn=self.mock_conn, order_id=1, product_id=1, store_id=1, quantity=1,
- price_at_purchase=Decimal(1), product_name_at_purchase="N",
- product_image_url_at_purchase=None, subtotal=Decimal(1), actor_id=1
+ conn=self.mock_conn,
+ order_id=1,
+ product_id=1,
+ store_id=1,
+ quantity=1,
+ price_at_purchase=Decimal(1),
+ product_name_at_purchase="N",
+ product_image_url_at_purchase=None,
+ subtotal=Decimal(1),
+ actor_id=1,
)
self.assertIsNone(result)
mock_logger.error.assert_called_once()
@@ -145,9 +165,10 @@ def test_get_order_item_by_id_found(self):
normalized_sql = normalize_sql(str(call_args[0].text))
self.assertIn(
f"SELECT OrderItemID, OrderID, ProductID, StoreID, Quantity, PriceAtPurchase, ProductNameAtPurchase, ProductImageURLAtPurchase, Subtotal FROM {self.crud.table_name}",
- normalized_sql)
+ normalized_sql,
+ )
self.assertIn("WHERE OrderItemID = :OrderItemID", normalized_sql)
- self.assertEqual(call_args[1]['OrderItemID'], item_id)
+ self.assertEqual(call_args[1]["OrderItemID"], item_id)
self.assertEqual(item, self.sample_order_item_data)
def test_get_order_item_by_id_not_found(self):
@@ -173,7 +194,7 @@ def test_get_order_items_by_order_id_found(self):
normalized_sql = normalize_sql(str(call_args[0].text))
self.assertIn(f"FROM {self.crud.table_name} WHERE OrderID = :OrderID", normalized_sql)
self.assertIn("ORDER BY OrderItemID ASC", normalized_sql)
- self.assertEqual(call_args[1]['OrderID'], order_id)
+ self.assertEqual(call_args[1]["OrderID"], order_id)
self.assertEqual(len(items), 2)
self.assertEqual(items[0], mock_row_data1)
@@ -185,5 +206,5 @@ def test_get_order_items_by_order_id_none_found(self):
self.assertEqual(len(items), 0)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/src/backend/test/unit/crud/test_payment_transaction_crud.py b/src/backend/test/unit/crud/test_payment_transaction_crud.py
index 5b9e8d5..f63047b 100644
--- a/src/backend/test/unit/crud/test_payment_transaction_crud.py
+++ b/src/backend/test/unit/crud/test_payment_transaction_crud.py
@@ -26,7 +26,7 @@ def setUp(self):
self.mock_conn.execute.return_value = self.mock_cursor_result
# Patch _set_actor_session_variable
- self.set_actor_patcher = patch.object(PaymentTransactionCRUD, '_set_actor_session_variable')
+ self.set_actor_patcher = patch.object(PaymentTransactionCRUD, "_set_actor_session_variable")
self.mock_set_actor_session_variable = self.set_actor_patcher.start()
self.addCleanup(self.set_actor_patcher.stop)
@@ -40,7 +40,7 @@ def setUp(self):
"Status": PaymentTransactionStatusEnum.PENDING.value,
"CreationTime": datetime.datetime(2025, 1, 1, 10, 0, 0),
"CompletionTime": None,
- "LastUpdatedDate": datetime.datetime(2025, 1, 1, 10, 5, 0)
+ "LastUpdatedDate": datetime.datetime(2025, 1, 1, 10, 5, 0),
}
def test_create_payment_transaction_success(self):
@@ -60,16 +60,18 @@ def test_create_payment_transaction_success(self):
"UserID": user_id,
"TotalAmount": total_amount,
"PaymentMethod": payment_method,
- "Status": status
+ "Status": status,
}
- with patch.object(self.crud, 'get_payment_transaction_by_id', return_value=mock_created_transaction_dict) as mock_get_by_id:
+ with patch.object(
+ self.crud, "get_payment_transaction_by_id", return_value=mock_created_transaction_dict
+ ) as mock_get_by_id:
created_transaction = self.crud.create_payment_transaction(
conn=self.mock_conn,
user_id=user_id,
total_amount=total_amount,
payment_method=payment_method,
status=status,
- actor_id=actor_id
+ actor_id=actor_id,
)
self.mock_set_actor_session_variable.assert_called_once_with(self.mock_conn, actor_id)
@@ -81,21 +83,26 @@ def test_create_payment_transaction_success(self):
self.assertIn("UserID, TotalAmount, PaymentMethod, Status", normalized_sql)
params = call_args[1]
- self.assertEqual(params['UserID'], user_id)
- self.assertEqual(params['TotalAmount'], total_amount)
- self.assertEqual(params['PaymentMethod'], payment_method)
- self.assertEqual(params['Status'], status)
+ self.assertEqual(params["UserID"], user_id)
+ self.assertEqual(params["TotalAmount"], total_amount)
+ self.assertEqual(params["PaymentMethod"], payment_method)
+ self.assertEqual(params["Status"], status)
- mock_get_by_id.assert_called_once_with(self.mock_conn, payment_transaction_id=expected_new_transaction_id, actor_id=actor_id)
+ mock_get_by_id.assert_called_once_with(
+ self.mock_conn, payment_transaction_id=expected_new_transaction_id, actor_id=actor_id
+ )
self.assertEqual(created_transaction, mock_created_transaction_dict)
- @patch('backend.app.crud.payment_transaction_crud.logger')
+ @patch("backend.app.crud.payment_transaction_crud.logger")
def test_create_payment_transaction_lastrowid_none(self, mock_logger):
self.mock_cursor_result.lastrowid = None
result = self.crud.create_payment_transaction(
- conn=self.mock_conn, user_id=101, total_amount=Decimal("100.00"),
- payment_method="Credit Card", status=PaymentTransactionStatusEnum.PENDING.value,
- actor_id=101
+ conn=self.mock_conn,
+ user_id=101,
+ total_amount=Decimal("100.00"),
+ payment_method="Credit Card",
+ status=PaymentTransactionStatusEnum.PENDING.value,
+ actor_id=101,
)
self.assertIsNone(result)
mock_logger.warning.assert_called_once()
@@ -103,11 +110,14 @@ def test_create_payment_transaction_lastrowid_none(self, mock_logger):
def test_create_payment_transaction_integrity_error(self):
self.mock_conn.execute.side_effect = exc.IntegrityError("mock integrity", {}, None)
- with patch('backend.app.crud.payment_transaction_crud.logger') as mock_logger:
+ with patch("backend.app.crud.payment_transaction_crud.logger") as mock_logger:
result = self.crud.create_payment_transaction(
- conn=self.mock_conn, user_id=101, total_amount=Decimal("100.00"),
- payment_method="Credit Card", status=PaymentTransactionStatusEnum.PENDING.value,
- actor_id=101
+ conn=self.mock_conn,
+ user_id=101,
+ total_amount=Decimal("100.00"),
+ payment_method="Credit Card",
+ status=PaymentTransactionStatusEnum.PENDING.value,
+ actor_id=101,
)
self.assertIsNone(result)
mock_logger.error.assert_called_once()
@@ -120,14 +130,16 @@ def test_get_payment_transaction_by_id_found(self):
mock_row._mapping = self.sample_payment_transaction_data
self.mock_cursor_result.fetchone.return_value = mock_row
- transaction = self.crud.get_payment_transaction_by_id(self.mock_conn, payment_transaction_id=transaction_id, actor_id=actor_id)
+ transaction = self.crud.get_payment_transaction_by_id(
+ self.mock_conn, payment_transaction_id=transaction_id, actor_id=actor_id
+ )
self.mock_conn.execute.assert_called_once()
call_args = self.mock_conn.execute.call_args.args
normalized_sql = normalize_sql(str(call_args[0].text))
self.assertIn("SELECT PaymentTransactionID, UserID, TotalAmount, PaymentMethod", normalized_sql)
self.assertIn("WHERE PaymentTransactionID = :PaymentTransactionID", normalized_sql)
- self.assertEqual(call_args[1]['PaymentTransactionID'], transaction_id)
+ self.assertEqual(call_args[1]["PaymentTransactionID"], transaction_id)
self.assertEqual(transaction, self.sample_payment_transaction_data)
def test_get_payment_transaction_by_id_not_found(self):
@@ -153,9 +165,9 @@ def test_get_payment_transactions_by_user_id_found(self):
normalized_sql = normalize_sql(str(call_args[0].text))
self.assertIn(f"FROM {self.crud.table_name} WHERE UserID = :UserID", normalized_sql)
self.assertIn("ORDER BY CreationTime DESC LIMIT :Limit OFFSET :Offset", normalized_sql)
- self.assertEqual(call_args[1]['UserID'], user_id)
- self.assertEqual(call_args[1]['Limit'], 10)
- self.assertEqual(call_args[1]['Offset'], 0)
+ self.assertEqual(call_args[1]["UserID"], user_id)
+ self.assertEqual(call_args[1]["Limit"], 10)
+ self.assertEqual(call_args[1]["Offset"], 0)
self.assertEqual(len(transactions), 2)
self.assertEqual(transactions[0]["PaymentTransactionID"], 1)
@@ -179,16 +191,18 @@ def test_update_payment_transaction_status_success(self):
**self.sample_payment_transaction_data,
"Status": new_status,
"ExternalGatewayTransactionID": external_gateway_transaction_id,
- "CompletionTime": completion_time
+ "CompletionTime": completion_time,
}
- with patch.object(self.crud, 'get_payment_transaction_by_id', return_value=mock_updated_transaction_dict) as mock_get_by_id:
+ with patch.object(
+ self.crud, "get_payment_transaction_by_id", return_value=mock_updated_transaction_dict
+ ) as mock_get_by_id:
updated_transaction = self.crud.update_payment_transaction_status(
conn=self.mock_conn,
payment_transaction_id=transaction_id,
new_status=new_status,
actor_id=actor_id,
external_gateway_transaction_id=external_gateway_transaction_id,
- completion_time=completion_time
+ completion_time=completion_time,
)
self.mock_set_actor_session_variable.assert_called_once_with(self.mock_conn, actor_id)
@@ -203,12 +217,14 @@ def test_update_payment_transaction_status_success(self):
self.assertIn("WHERE PaymentTransactionID = :PaymentTransactionID_param", normalized_sql)
params = call_args[1]
- self.assertEqual(params['Status'], new_status)
- self.assertEqual(params['ExternalGatewayTransactionID'], external_gateway_transaction_id)
- self.assertEqual(params['CompletionTime'], completion_time)
- self.assertEqual(params['PaymentTransactionID_param'], transaction_id)
+ self.assertEqual(params["Status"], new_status)
+ self.assertEqual(params["ExternalGatewayTransactionID"], external_gateway_transaction_id)
+ self.assertEqual(params["CompletionTime"], completion_time)
+ self.assertEqual(params["PaymentTransactionID_param"], transaction_id)
- mock_get_by_id.assert_called_once_with(self.mock_conn, payment_transaction_id=transaction_id, actor_id=actor_id)
+ mock_get_by_id.assert_called_once_with(
+ self.mock_conn, payment_transaction_id=transaction_id, actor_id=actor_id
+ )
self.assertEqual(updated_transaction, mock_updated_transaction_dict)
def test_update_payment_transaction_status_no_change_or_not_found(self):
@@ -223,5 +239,5 @@ def test_update_payment_transaction_status_no_change_or_not_found(self):
self.assertIsNone(updated_transaction)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/src/backend/test/unit/crud/test_product_change_request_crud.py b/src/backend/test/unit/crud/test_product_change_request_crud.py
index 9bb5a09..c0f7580 100644
--- a/src/backend/test/unit/crud/test_product_change_request_crud.py
+++ b/src/backend/test/unit/crud/test_product_change_request_crud.py
@@ -15,7 +15,7 @@
def generate_random_string(length=8):
"""生成指定长度的随机字符串"""
letters_and_digits = string.ascii_letters + string.digits
- return ''.join(random.choice(letters_and_digits) for _ in range(length))
+ return "".join(random.choice(letters_and_digits) for _ in range(length))
class TestProductChangeRequestCRUD(BaseDBTestCaseAutoRollback):
@@ -25,56 +25,59 @@ def setUp(self):
self.product_change_request_crud = get_product_change_request_crud_instance()
self.product_crud = get_product_crud_instance()
self.category_crud = get_category_crud_instance()
-
+
# 创建测试数据 - 创建分类
self.test_category = self.category_crud.create_category(
- conn=self.connection,
- category_name="测试分类",
- category_description="用于测试的商品分类",
- actor_id=None
+ conn=self.connection, category_name="测试分类", category_description="用于测试的商品分类", actor_id=None
)
-
+
# 创建测试数据 - 创建测试用户(商家)
random_suffix = generate_random_string()
self.test_username = f"testuser_{random_suffix}"
self.test_email = f"{self.test_username}@example.com"
-
+
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO User (Username, PasswordHash, Email, UserRole)
VALUES (:username, 'password_hash', :email, 'merchant')
- """),
- {"username": self.test_username, "email": self.test_email}
+ """
+ ),
+ {"username": self.test_username, "email": self.test_email},
)
user_id_result = self.connection.execute(text("SELECT LAST_INSERT_ID()")).fetchone()
self.test_user_id = user_id_result[0]
-
+
# 创建管理员用户
random_suffix = generate_random_string()
admin_username = f"admin_{random_suffix}"
admin_email = f"{admin_username}@example.com"
-
+
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO User (Username, PasswordHash, Email, UserRole)
VALUES (:username, 'password_hash', :email, 'admin')
- """),
- {"username": admin_username, "email": admin_email}
+ """
+ ),
+ {"username": admin_username, "email": admin_email},
)
admin_id_result = self.connection.execute(text("SELECT LAST_INSERT_ID()")).fetchone()
self.admin_user_id = admin_id_result[0]
-
+
# 创建测试数据 - 创建店铺
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO Store (StoreName, OwnerUserID, Description, StoreStatus)
VALUES ('测试店铺', :user_id, '用于测试的店铺', 'ACTIVE')
- """),
- {"user_id": self.test_user_id}
+ """
+ ),
+ {"user_id": self.test_user_id},
)
store_id_result = self.connection.execute(text("SELECT LAST_INSERT_ID()")).fetchone()
self.test_store_id = store_id_result[0]
-
+
# 创建测试数据 - 创建商品
self.test_product = self.product_crud.create_product(
conn=self.connection,
@@ -84,18 +87,14 @@ def setUp(self):
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
stock_quantity=100,
- main_image_url="http://example.com/test.jpg"
+ main_image_url="http://example.com/test.jpg",
)
def test_create_change_request(self):
"""测试创建商品变更请求"""
# 准备数据
- proposed_data = {
- "ProductName": "更新后的商品名称",
- "ProductDescription": "更新后的商品描述",
- "Price": 199.99
- }
-
+ proposed_data = {"ProductName": "更新后的商品名称", "ProductDescription": "更新后的商品描述", "Price": 199.99}
+
# 执行测试
change_request = self.product_change_request_crud.create_change_request(
conn=self.connection,
@@ -105,9 +104,9 @@ def test_create_change_request(self):
proposed_data=proposed_data,
product_id=self.test_product["ProductID"],
submitter_notes="请求更改商品信息",
- actor_id=None
+ actor_id=None,
)
-
+
# 验证结果
self.assertIsNotNone(change_request)
self.assertIsInstance(change_request, dict)
@@ -118,15 +117,15 @@ def test_create_change_request(self):
self.assertEqual(change_request["RequestType"], "PRODUCT_UPDATE")
self.assertEqual(change_request["Status"], "PENDING_APPROVAL")
self.assertEqual(change_request["SubmitterNotes"], "请求更改商品信息")
-
+
# 验证提议的数据
self.assertIsNotNone(change_request["ProposedData_JSON"])
-
+
# 对于JSON字段,需要确认它已经被解析为Python对象
self.assertIsInstance(change_request["ProposedData_JSON"], dict)
self.assertEqual(change_request["ProposedData_JSON"]["ProductName"], proposed_data["ProductName"])
self.assertEqual(change_request["ProposedData_JSON"]["ProductDescription"], proposed_data["ProductDescription"])
-
+
# 验证日期字段
self.assertIsNotNone(change_request["CreationTime"])
self.assertIsNotNone(change_request["LastUpdatedDate"])
@@ -135,7 +134,7 @@ def test_get_change_request_by_id(self):
"""测试根据ID获取商品变更请求"""
# 准备数据 - 先创建变更请求
proposed_data = {"ProductName": "测试商品2"}
-
+
created_request = self.product_change_request_crud.create_change_request(
conn=self.connection,
merchant_user_id=self.test_user_id,
@@ -144,17 +143,15 @@ def test_get_change_request_by_id(self):
proposed_data=proposed_data,
product_id=self.test_product["ProductID"],
submitter_notes="测试请求",
- actor_id=None
+ actor_id=None,
)
-
+
# 执行测试
request_id = created_request["ChangeRequestID"]
retrieved_request = self.product_change_request_crud.get_change_request_by_id(
- conn=self.connection,
- request_id=request_id,
- actor_id=None
+ conn=self.connection, request_id=request_id, actor_id=None
)
-
+
# 验证结果
self.assertIsNotNone(retrieved_request)
self.assertEqual(retrieved_request["ChangeRequestID"], request_id)
@@ -162,7 +159,7 @@ def test_get_change_request_by_id(self):
self.assertEqual(retrieved_request["RequestType"], "PRODUCT_UPDATE")
self.assertEqual(retrieved_request["ProductID"], self.test_product["ProductID"])
self.assertEqual(retrieved_request["Status"], "PENDING_APPROVAL")
-
+
# 验证提议的数据
self.assertIsNotNone(retrieved_request["ProposedData_JSON"])
self.assertEqual(retrieved_request["ProposedData_JSON"]["ProductName"], proposed_data["ProductName"])
@@ -180,16 +177,14 @@ def test_get_change_requests_by_product_id(self):
proposed_data=proposed_data,
product_id=self.test_product["ProductID"],
submitter_notes=f"测试请求{i}",
- actor_id=None
+ actor_id=None,
)
-
+
# 执行测试
requests = self.product_change_request_crud.get_change_requests_by_product_id(
- conn=self.connection,
- product_id=self.test_product["ProductID"],
- actor_id=None
+ conn=self.connection, product_id=self.test_product["ProductID"], actor_id=None
)
-
+
# 验证结果
self.assertEqual(len(requests), 3)
for request in requests:
@@ -209,16 +204,14 @@ def test_get_change_requests_by_store_id(self):
proposed_data=proposed_data,
product_id=self.test_product["ProductID"],
submitter_notes=f"店铺测试请求{i}",
- actor_id=None
+ actor_id=None,
)
-
+
# 执行测试
requests = self.product_change_request_crud.get_change_requests_by_store_id(
- conn=self.connection,
- store_id=self.test_store_id,
- actor_id=None
+ conn=self.connection, store_id=self.test_store_id, actor_id=None
)
-
+
# 验证结果
self.assertEqual(len(requests), 3)
for request in requests:
@@ -238,16 +231,14 @@ def test_get_change_requests_by_merchant_id(self):
proposed_data=proposed_data,
product_id=self.test_product["ProductID"],
submitter_notes=f"商家测试请求{i}",
- actor_id=None
+ actor_id=None,
)
-
+
# 执行测试
requests = self.product_change_request_crud.get_change_requests_by_merchant_id(
- conn=self.connection,
- merchant_id=self.test_user_id,
- actor_id=None
+ conn=self.connection, merchant_id=self.test_user_id, actor_id=None
)
-
+
# 验证结果
self.assertEqual(len(requests), 3)
for request in requests:
@@ -267,9 +258,9 @@ def test_get_all_pending_requests(self):
proposed_data=proposed_data,
product_id=self.test_product["ProductID"],
submitter_notes=f"待审核测试请求{i}",
- actor_id=None
+ actor_id=None,
)
-
+
# 将其中两个请求设置为已审核状态
if i < 2:
self.product_change_request_crud.update_request_status(
@@ -278,15 +269,14 @@ def test_get_all_pending_requests(self):
status="APPROVED",
admin_id=self.admin_user_id,
admin_notes="已审核通过",
- actor_id=None
+ actor_id=None,
)
-
+
# 执行测试
pending_requests = self.product_change_request_crud.get_all_pending_requests(
- conn=self.connection,
- actor_id=None
+ conn=self.connection, actor_id=None
)
-
+
# 验证结果
self.assertEqual(len(pending_requests), 2) # 应该只有2个待审核的请求
for request in pending_requests:
@@ -296,7 +286,7 @@ def test_get_filtered_requests(self):
"""测试获取根据多种条件筛选的请求列表"""
# 准备数据 - 创建多种类型的变更请求
request_types = ["PRODUCT_CREATE", "PRODUCT_UPDATE", "PRODUCT_DELETE"]
-
+
for i, request_type in enumerate(request_types):
proposed_data = {"ProductName": f"{request_type}测试{i}"}
self.product_change_request_crud.create_change_request(
@@ -307,28 +297,23 @@ def test_get_filtered_requests(self):
proposed_data=proposed_data,
product_id=self.test_product["ProductID"] if request_type != "PRODUCT_CREATE" else None,
submitter_notes=f"{request_type}测试请求",
- actor_id=None
+ actor_id=None,
)
-
+
# 执行测试 - 按请求类型筛选
update_requests = self.product_change_request_crud.get_filtered_requests(
- conn=self.connection,
- request_type="PRODUCT_UPDATE",
- actor_id=None
+ conn=self.connection, request_type="PRODUCT_UPDATE", actor_id=None
)
-
+
# 验证结果
self.assertEqual(len(update_requests), 1)
self.assertEqual(update_requests[0]["RequestType"], "PRODUCT_UPDATE")
-
+
# 执行测试 - 按多种条件筛选
all_requests = self.product_change_request_crud.get_filtered_requests(
- conn=self.connection,
- merchant_id=self.test_user_id,
- status="PENDING_APPROVAL",
- actor_id=None
+ conn=self.connection, merchant_id=self.test_user_id, status="PENDING_APPROVAL", actor_id=None
)
-
+
# 验证结果
self.assertEqual(len(all_requests), 3) # 应该有3个符合条件的请求
for request in all_requests:
@@ -347,9 +332,9 @@ def test_update_request_status(self):
proposed_data=proposed_data,
product_id=self.test_product["ProductID"],
submitter_notes="等待审核",
- actor_id=None
+ actor_id=None,
)
-
+
# 执行测试 - 更新为已审核
updated_request = self.product_change_request_crud.update_request_status(
conn=self.connection,
@@ -357,9 +342,9 @@ def test_update_request_status(self):
status="APPROVED",
admin_id=self.admin_user_id,
admin_notes="审核通过",
- actor_id=None
+ actor_id=None,
)
-
+
# 验证结果
self.assertIsNotNone(updated_request)
self.assertEqual(updated_request["Status"], "APPROVED")
@@ -379,23 +364,17 @@ def test_update_request(self):
proposed_data=original_proposed_data,
product_id=self.test_product["ProductID"],
submitter_notes="原始备注",
- actor_id=None
+ actor_id=None,
)
-
+
# 执行测试 - 更新请求内容
new_proposed_data = {"ProductName": "更新后的商品名称", "ProductDescription": "更新后的描述"}
- update_data = {
- "ProposedData_JSON": new_proposed_data,
- "SubmitterNotes": "更新后的备注"
- }
-
+ update_data = {"ProposedData_JSON": new_proposed_data, "SubmitterNotes": "更新后的备注"}
+
updated_request = self.product_change_request_crud.update_request(
- conn=self.connection,
- request_id=request["ChangeRequestID"],
- update_data=update_data,
- actor_id=None
+ conn=self.connection, request_id=request["ChangeRequestID"], update_data=update_data, actor_id=None
)
-
+
# 验证结果
self.assertIsNotNone(updated_request)
self.assertEqual(updated_request["SubmitterNotes"], "更新后的备注")
@@ -414,28 +393,24 @@ def test_cancel_request(self):
proposed_data=proposed_data,
product_id=self.test_product["ProductID"],
submitter_notes="即将取消",
- actor_id=None
+ actor_id=None,
)
-
+
# 执行测试 - 取消请求
result = self.product_change_request_crud.cancel_request(
- conn=self.connection,
- request_id=request["ChangeRequestID"],
- actor_id=None
+ conn=self.connection, request_id=request["ChangeRequestID"], actor_id=None
)
-
+
# 验证结果
self.assertTrue(result)
-
+
# 检查请求状态是否已更新
cancelled_request = self.product_change_request_crud.get_change_request_by_id(
- conn=self.connection,
- request_id=request["ChangeRequestID"],
- actor_id=None
+ conn=self.connection, request_id=request["ChangeRequestID"], actor_id=None
)
-
+
self.assertEqual(cancelled_request["Status"], "CANCELLED_BY_USER")
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/src/backend/test/unit/crud/test_product_change_request_crud_v2.py b/src/backend/test/unit/crud/test_product_change_request_crud_v2.py
index 9ea9aa5..5912cd4 100644
--- a/src/backend/test/unit/crud/test_product_change_request_crud_v2.py
+++ b/src/backend/test/unit/crud/test_product_change_request_crud_v2.py
@@ -9,7 +9,7 @@
from backend.app.crud.product_change_request_crud_v2 import ProductChangeRequestCRUD2
from backend.app.schemas.product_change_request_schema_v2 import (
ProductChangeRequestTypeApiEnum as TypeEnum,
- ProductChangeRequestStatusApiEnum as StatusEnum
+ ProductChangeRequestStatusApiEnum as StatusEnum,
)
# 用于 SQLAlchemy text 和 exc (如果需要模拟异常)
@@ -20,7 +20,7 @@
# 辅助函数来规范化SQL字符串以便比较
def normalize_sql(sql_string: str) -> str:
"""将SQL字符串中的多个空格和换行符替换为单个空格,并去除首尾空格。"""
- return ' '.join(sql_string.strip().split())
+ return " ".join(sql_string.strip().split())
class TestProductChangeRequestCRUD2(unittest.TestCase):
@@ -32,7 +32,7 @@ def setUp(self):
self.mock_cursor_result = MagicMock()
self.mock_conn.execute.return_value = self.mock_cursor_result
- self.set_actor_patcher = patch.object(ProductChangeRequestCRUD2, '_set_actor_session_variable')
+ self.set_actor_patcher = patch.object(ProductChangeRequestCRUD2, "_set_actor_session_variable")
self.mock_set_actor_session_variable = self.set_actor_patcher.start()
self.addCleanup(self.set_actor_patcher.stop)
@@ -50,7 +50,7 @@ def setUp(self):
"ReviewTimestamp": None,
"AdminNotes": None,
"CreationTime": datetime.datetime(2025, 1, 1, 10, 0, 0),
- "LastUpdatedDate": datetime.datetime(2025, 1, 1, 10, 5, 0)
+ "LastUpdatedDate": datetime.datetime(2025, 1, 1, 10, 5, 0),
}
# --- 测试 _deserialize_proposed_data 和 _serialize_proposed_data ---
@@ -60,7 +60,7 @@ def test_deserialize_proposed_data(self):
self.assertEqual(self.crud._deserialize_proposed_data(json_str), expected_dict)
self.assertEqual(self.crud._deserialize_proposed_data(expected_dict), expected_dict) # Should return dict as is
self.assertIsNone(self.crud._deserialize_proposed_data(None))
- with patch('backend.app.crud.product_change_request_crud_v2.logger') as mock_logger:
+ with patch("backend.app.crud.product_change_request_crud_v2.logger") as mock_logger:
self.assertIsNone(self.crud._deserialize_proposed_data("invalid json"))
mock_logger.warning.assert_called_once()
@@ -68,8 +68,9 @@ def test_serialize_proposed_data(self):
data_dict = {"key": "value"}
expected_json_str = json.dumps(data_dict)
self.assertEqual(self.crud._serialize_proposed_data(data_dict), expected_json_str)
- self.assertEqual(self.crud._serialize_proposed_data(expected_json_str),
- expected_json_str) # Should return str as is
+ self.assertEqual(
+ self.crud._serialize_proposed_data(expected_json_str), expected_json_str
+ ) # Should return str as is
self.assertIsNone(self.crud._serialize_proposed_data(None))
# --- 测试 get_request_by_id ---
@@ -98,16 +99,20 @@ def test_get_request_by_id_for_owner_found(self):
request_id = 1
merchant_user_id = 201
mock_row = MagicMock()
- db_return_data = {**self.sample_request_data, "MerchantUserID": merchant_user_id,
- "ProposedData_JSON": '{"key": "val"}'}
+ db_return_data = {
+ **self.sample_request_data,
+ "MerchantUserID": merchant_user_id,
+ "ProposedData_JSON": '{"key": "val"}',
+ }
mock_row._mapping = db_return_data
self.mock_cursor_result.fetchone.return_value = mock_row
request = self.crud.get_request_by_id_for_owner(
self.mock_conn, request_id=request_id, merchant_user_id=merchant_user_id
)
- self.mock_conn.execute.assert_called_once_with(ANY, {"ChangeRequestID": request_id,
- "MerchantUserID": merchant_user_id})
+ self.mock_conn.execute.assert_called_once_with(
+ ANY, {"ChangeRequestID": request_id, "MerchantUserID": merchant_user_id}
+ )
self.assertIsNotNone(request)
self.assertEqual(request["MerchantUserID"], merchant_user_id) # type: ignore
self.assertEqual(request["ProposedData_JSON"], {"key": "val"}) # type: ignore
@@ -115,13 +120,21 @@ def test_get_request_by_id_for_owner_found(self):
# --- 测试 get_request_list ---
def test_get_request_list_with_status_list_filter(self):
status_filter = [StatusEnum.PENDING_APPROVAL.value, StatusEnum.APPROVED.value]
- mock_row1_db = {**self.sample_request_data, "ChangeRequestID": 1, "Status": status_filter[0],
- "ProposedData_JSON": '{"p":1}'}
- mock_row2_db = {**self.sample_request_data, "ChangeRequestID": 2, "Status": status_filter[1],
- "ProposedData_JSON": '{"p":2}'}
- row1 = MagicMock();
+ mock_row1_db = {
+ **self.sample_request_data,
+ "ChangeRequestID": 1,
+ "Status": status_filter[0],
+ "ProposedData_JSON": '{"p":1}',
+ }
+ mock_row2_db = {
+ **self.sample_request_data,
+ "ChangeRequestID": 2,
+ "Status": status_filter[1],
+ "ProposedData_JSON": '{"p":2}',
+ }
+ row1 = MagicMock()
row1._mapping = mock_row1_db
- row2 = MagicMock();
+ row2 = MagicMock()
row2._mapping = mock_row2_db
self.mock_cursor_result.fetchall.return_value = [row1, row2]
@@ -143,7 +156,7 @@ def test_get_request_list_all_filters(self):
"request_type": TypeEnum.PRODUCT_UPDATE.value,
"store_id": 301,
"product_id": 101,
- "merchant_user_id": 201
+ "merchant_user_id": 201,
}
self.mock_cursor_result.fetchall.return_value = [] # No need to check data, just query build
@@ -178,18 +191,28 @@ def test_create_request_create_product_success(self):
# Mock the get_request_by_id call made by _create_generic_request
mock_final_request_data = {
- "ChangeRequestID": expected_request_id, "MerchantUserID": merchant_user_id, "StoreID": store_id,
- "RequestType": TypeEnum.PRODUCT_CREATE.value, "ProposedData_JSON": proposed_data,
- "SubmitterNotes": submitter_notes, "ProductID": None, "Status": StatusEnum.PENDING_APPROVAL.value
+ "ChangeRequestID": expected_request_id,
+ "MerchantUserID": merchant_user_id,
+ "StoreID": store_id,
+ "RequestType": TypeEnum.PRODUCT_CREATE.value,
+ "ProposedData_JSON": proposed_data,
+ "SubmitterNotes": submitter_notes,
+ "ProductID": None,
+ "Status": StatusEnum.PENDING_APPROVAL.value,
# ... other fields with defaults or None ...
}
- with patch.object(self.crud, 'get_request_by_id', return_value=mock_final_request_data) as mock_get_by_id:
+ with patch.object(self.crud, "get_request_by_id", return_value=mock_final_request_data) as mock_get_by_id:
created_request = self.crud.create_request_create_product(
- conn=self.mock_conn, merchant_user_id=merchant_user_id, store_id=store_id,
- submitter_notes=submitter_notes, proposed_data_json=proposed_data, actor_id=actor_id
+ conn=self.mock_conn,
+ merchant_user_id=merchant_user_id,
+ store_id=store_id,
+ submitter_notes=submitter_notes,
+ proposed_data_json=proposed_data,
+ actor_id=actor_id,
)
- self.mock_set_actor_session_variable.assert_called_with(self.mock_conn,
- actor_id) # Called by _create_generic_request
+ self.mock_set_actor_session_variable.assert_called_with(
+ self.mock_conn, actor_id
+ ) # Called by _create_generic_request
self.mock_conn.execute.assert_called_once() # INSERT call
call_args = self.mock_conn.execute.call_args.args
@@ -211,15 +234,23 @@ def test_create_request_delete_product_success(self):
self.mock_cursor_result.lastrowid = expected_request_id
mock_final_request_data = {
- "ChangeRequestID": expected_request_id, "MerchantUserID": merchant_user_id, "StoreID": store_id,
- "RequestType": TypeEnum.PRODUCT_DELETE.value, "ProposedData_JSON": None,
- "SubmitterNotes": submitter_notes, "ProductID": product_id_to_delete,
- "Status": StatusEnum.PENDING_APPROVAL.value
+ "ChangeRequestID": expected_request_id,
+ "MerchantUserID": merchant_user_id,
+ "StoreID": store_id,
+ "RequestType": TypeEnum.PRODUCT_DELETE.value,
+ "ProposedData_JSON": None,
+ "SubmitterNotes": submitter_notes,
+ "ProductID": product_id_to_delete,
+ "Status": StatusEnum.PENDING_APPROVAL.value,
}
- with patch.object(self.crud, 'get_request_by_id', return_value=mock_final_request_data) as mock_get_by_id:
+ with patch.object(self.crud, "get_request_by_id", return_value=mock_final_request_data) as mock_get_by_id:
created_request = self.crud.create_request_delete_product(
- conn=self.mock_conn, merchant_user_id=merchant_user_id, store_id=store_id,
- product_id=product_id_to_delete, submitter_notes=submitter_notes, actor_id=actor_id
+ conn=self.mock_conn,
+ merchant_user_id=merchant_user_id,
+ store_id=store_id,
+ product_id=product_id_to_delete,
+ submitter_notes=submitter_notes,
+ actor_id=actor_id,
)
self.mock_conn.execute.assert_called_once()
params = self.mock_conn.execute.call_args.args[1]
@@ -257,7 +288,7 @@ def test_cancel_request_not_pending_or_not_found(self):
self.assertFalse(result)
# --- 测试 update_request_by_admin ---
- @patch('backend.app.crud.product_change_request_crud_v2.datetime') # To control ReviewTimestamp
+ @patch("backend.app.crud.product_change_request_crud_v2.datetime") # To control ReviewTimestamp
def test_update_request_by_admin_approve_success(self, mock_datetime_module):
fixed_utc_now = datetime.datetime(2025, 5, 20, 14, 0, 0, tzinfo=datetime.timezone.utc)
# SUT uses UTC_TIMESTAMP() in SQL, which is fine.
@@ -277,14 +308,20 @@ def test_update_request_by_admin_approve_success(self, mock_datetime_module):
# Mock get_request_by_id for the return value
mock_updated_request_data = {
**self.sample_request_data,
- "ChangeRequestID": request_id, "Status": new_status,
- "AdminReviewerID": admin_reviewer_id, "AdminNotes": admin_notes,
- "ReviewTimestamp": fixed_utc_now # Simulate what DB might return after update
+ "ChangeRequestID": request_id,
+ "Status": new_status,
+ "AdminReviewerID": admin_reviewer_id,
+ "AdminNotes": admin_notes,
+ "ReviewTimestamp": fixed_utc_now, # Simulate what DB might return after update
}
- with patch.object(self.crud, 'get_request_by_id', return_value=mock_updated_request_data) as mock_get_by_id:
+ with patch.object(self.crud, "get_request_by_id", return_value=mock_updated_request_data) as mock_get_by_id:
updated_request = self.crud.update_request_by_admin(
- conn=self.mock_conn, request_id=request_id, status=new_status,
- admin_notes=admin_notes, admin_reviewer_id=admin_reviewer_id, actor_id=actor_id
+ conn=self.mock_conn,
+ request_id=request_id,
+ status=new_status,
+ admin_notes=admin_notes,
+ admin_reviewer_id=admin_reviewer_id,
+ actor_id=actor_id,
)
self.mock_set_actor_session_variable.assert_called_with(self.mock_conn, actor_id)
@@ -313,14 +350,18 @@ def test_update_request_by_admin_approve_success(self, mock_datetime_module):
def test_update_request_by_admin_no_change_or_not_found(self):
self.mock_cursor_result.rowcount = 0
# Simulate get_request_by_id returning None if not found after failed update
- with patch.object(self.crud, 'get_request_by_id', return_value=None) as mock_get_by_id:
+ with patch.object(self.crud, "get_request_by_id", return_value=None) as mock_get_by_id:
result = self.crud.update_request_by_admin(
- conn=self.mock_conn, request_id=999, status=StatusEnum.REJECTED.value,
- admin_notes="Not found", admin_reviewer_id=999, actor_id=999
+ conn=self.mock_conn,
+ request_id=999,
+ status=StatusEnum.REJECTED.value,
+ admin_notes="Not found",
+ admin_reviewer_id=999,
+ actor_id=999,
)
self.assertIsNone(result)
mock_get_by_id.assert_called_once_with(self.mock_conn, request_id=999)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/src/backend/test/unit/crud/test_product_crud.py b/src/backend/test/unit/crud/test_product_crud.py
index fa634c2..208efa4 100644
--- a/src/backend/test/unit/crud/test_product_crud.py
+++ b/src/backend/test/unit/crud/test_product_crud.py
@@ -14,7 +14,8 @@
def generate_random_string(length=8):
"""生成指定长度的随机字符串"""
letters_and_digits = string.ascii_letters + string.digits
- return ''.join(random.choice(letters_and_digits) for _ in range(length))
+ return "".join(random.choice(letters_and_digits) for _ in range(length))
+
@unittest.skip("This test has been moved to test_product_crud_integration.py")
class TestProductCRUD(BaseDBTestCaseAutoRollback):
@@ -23,37 +24,38 @@ def setUp(self):
# 初始化CRUD实例
self.product_crud = get_product_crud_instance()
self.category_crud = get_category_crud_instance()
-
+
# 创建测试数据 - 创建分类
self.test_category = self.category_crud.create_category(
- conn=self.connection,
- category_name="测试分类",
- category_description="用于测试的商品分类",
- actor_id=None
+ conn=self.connection, category_name="测试分类", category_description="用于测试的商品分类", actor_id=None
)
-
+
# 创建测试数据 - 创建测试用户(使用随机用户名避免唯一键冲突)
random_suffix = generate_random_string()
test_username = f"testuser_{random_suffix}"
test_email = f"{test_username}@example.com"
-
+
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO User (Username, PasswordHash, Email, UserRole)
VALUES (:username, 'password_hash', :email, 'merchant')
- """),
- {"username": test_username, "email": test_email}
+ """
+ ),
+ {"username": test_username, "email": test_email},
)
user_id_result = self.connection.execute(text("SELECT LAST_INSERT_ID()")).fetchone()
self.test_user_id = user_id_result[0]
-
+
# 创建测试数据 - 创建店铺
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO Store (StoreName, OwnerUserID, Description, StoreStatus)
VALUES ('测试店铺', :user_id, '用于测试的店铺', 'ACTIVE')
- """),
- {"user_id": self.test_user_id}
+ """
+ ),
+ {"user_id": self.test_user_id},
)
store_id_result = self.connection.execute(text("SELECT LAST_INSERT_ID()")).fetchone()
self.test_store_id = store_id_result[0]
@@ -68,15 +70,12 @@ def test_create_product(self):
"store_id": self.test_store_id,
"category_id": self.test_category["CategoryID"],
"stock_quantity": 100,
- "main_image_url": "http://example.com/test.jpg"
+ "main_image_url": "http://example.com/test.jpg",
}
-
+
# 执行测试
- product = self.product_crud.create_product(
- conn=self.connection,
- **product_data
- )
-
+ product = self.product_crud.create_product(conn=self.connection, **product_data)
+
# 验证结果
self.assertIsNotNone(product)
self.assertIsInstance(product, dict)
@@ -101,25 +100,19 @@ def test_get_product_by_id(self):
price=Decimal("199.99"),
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
- stock_quantity=50
+ stock_quantity=50,
)
-
+
# 执行测试
- product = self.product_crud.get_product_by_id(
- conn=self.connection,
- product_id=test_product["ProductID"]
- )
-
+ product = self.product_crud.get_product_by_id(conn=self.connection, product_id=test_product["ProductID"])
+
# 验证结果
self.assertIsNotNone(product)
self.assertEqual(product["ProductID"], test_product["ProductID"])
self.assertEqual(product["ProductName"], "测试商品2")
-
+
# 测试获取不存在的商品
- non_existent_product = self.product_crud.get_product_by_id(
- conn=self.connection,
- product_id=999999
- )
+ non_existent_product = self.product_crud.get_product_by_id(conn=self.connection, product_id=999999)
self.assertIsNone(non_existent_product)
def test_get_product_with_category_info(self):
@@ -131,15 +124,14 @@ def test_get_product_with_category_info(self):
price=Decimal("299.99"),
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
- stock_quantity=30
+ stock_quantity=30,
)
-
+
# 执行测试
product_with_category = self.product_crud.get_product_with_category_info(
- conn=self.connection,
- product_id=test_product["ProductID"]
+ conn=self.connection, product_id=test_product["ProductID"]
)
-
+
# 验证结果
self.assertIsNotNone(product_with_category)
self.assertEqual(product_with_category["ProductID"], test_product["ProductID"])
@@ -155,34 +147,30 @@ def test_update_product(self):
price=Decimal("399.99"),
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
- stock_quantity=40
+ stock_quantity=40,
)
-
+
# 准备更新数据
update_data = {
"productname": "更新后的商品名称",
"price": Decimal("499.99"),
- "productstatus": "INACTIVE_BY_MERCHANT"
+ "productstatus": "INACTIVE_BY_MERCHANT",
}
-
+
# 执行测试
updated_product = self.product_crud.update_product(
- conn=self.connection,
- product_id=test_product["ProductID"],
- update_data=update_data
+ conn=self.connection, product_id=test_product["ProductID"], update_data=update_data
)
-
+
# 验证结果
self.assertIsNotNone(updated_product)
self.assertEqual(updated_product["ProductName"], update_data["productname"])
self.assertEqual(float(updated_product["Price"]), float(update_data["price"]))
self.assertEqual(updated_product["ProductStatus"], update_data["productstatus"])
-
+
# 测试更新不存在的商品
non_existent_update = self.product_crud.update_product(
- conn=self.connection,
- product_id=999999,
- update_data={"productname": "不存在的商品"}
+ conn=self.connection, product_id=999999, update_data={"productname": "不存在的商品"}
)
# 现在应该返回None,因为商品不存在
self.assertIsNone(non_existent_update)
@@ -196,39 +184,33 @@ def test_update_product_stock(self):
price=Decimal("599.99"),
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
- stock_quantity=100
+ stock_quantity=100,
)
-
+
# 执行测试 - 增加库存
updated_product = self.product_crud.update_product_stock(
- conn=self.connection,
- product_id=test_product["ProductID"],
- stock_change=50
+ conn=self.connection, product_id=test_product["ProductID"], stock_change=50
)
-
+
# 验证结果
self.assertIsNotNone(updated_product)
self.assertEqual(updated_product["StockQuantity"], 150)
-
+
# 执行测试 - 减少库存
updated_product = self.product_crud.update_product_stock(
- conn=self.connection,
- product_id=test_product["ProductID"],
- stock_change=-30
+ conn=self.connection, product_id=test_product["ProductID"], stock_change=-30
)
-
+
# 验证结果
self.assertIsNotNone(updated_product)
self.assertEqual(updated_product["StockQuantity"], 120)
-
+
# 执行测试 - 减少超过当前库存的量
with self.assertRaises(InsufficientStockException):
updated_product = self.product_crud.update_product_stock(
- conn=self.connection,
- product_id=test_product["ProductID"],
- stock_change=-200
+ conn=self.connection, product_id=test_product["ProductID"], stock_change=-200
)
-
+
# 验证结果 - 库存未更新
self.assertEqual(updated_product["StockQuantity"], 120)
@@ -242,29 +224,23 @@ def test_get_products_by_store_id(self):
price=Decimal(f"{(i+1)*100}.99"),
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
- stock_quantity=10*(i+1)
+ stock_quantity=10 * (i + 1),
)
-
+
# 执行测试
products = self.product_crud.get_products_by_store_id(
- conn=self.connection,
- store_id=self.test_store_id,
- limit=10,
- offset=0
+ conn=self.connection, store_id=self.test_store_id, limit=10, offset=0
)
-
+
# 验证结果
self.assertIsInstance(products, list)
self.assertEqual(len(products), 5)
for i, product in enumerate(products):
self.assertIn("ProductName", product)
-
+
# 测试分页
products_page = self.product_crud.get_products_by_store_id(
- conn=self.connection,
- store_id=self.test_store_id,
- limit=2,
- offset=2
+ conn=self.connection, store_id=self.test_store_id, limit=2, offset=2
)
self.assertEqual(len(products_page), 2)
@@ -275,9 +251,9 @@ def test_get_products_by_category_id(self):
conn=self.connection,
category_name="新测试分类",
category_description="用于测试分类查询的商品分类",
- actor_id=None
+ actor_id=None,
)
-
+
# 创建不同分类下的商品
for i in range(3):
self.product_crud.create_product(
@@ -286,9 +262,9 @@ def test_get_products_by_category_id(self):
price=Decimal(f"{(i+1)*100}.99"),
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
- stock_quantity=10*(i+1)
+ stock_quantity=10 * (i + 1),
)
-
+
for i in range(2):
self.product_crud.create_product(
conn=self.connection,
@@ -296,20 +272,18 @@ def test_get_products_by_category_id(self):
price=Decimal(f"{(i+1)*200}.99"),
store_id=self.test_store_id,
category_id=new_category["CategoryID"],
- stock_quantity=20*(i+1)
+ stock_quantity=20 * (i + 1),
)
-
+
# 执行测试
products_cat1 = self.product_crud.get_products_by_category_id(
- conn=self.connection,
- category_id=self.test_category["CategoryID"]
+ conn=self.connection, category_id=self.test_category["CategoryID"]
)
-
+
products_cat2 = self.product_crud.get_products_by_category_id(
- conn=self.connection,
- category_id=new_category["CategoryID"]
+ conn=self.connection, category_id=new_category["CategoryID"]
)
-
+
# 验证结果
self.assertEqual(len(products_cat1), 3)
self.assertEqual(len(products_cat2), 2)
@@ -324,9 +298,9 @@ def test_search_products(self):
price=Decimal("5999.99"),
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
- stock_quantity=50
+ stock_quantity=50,
)
-
+
self.product_crud.create_product(
conn=self.connection,
product_name="华为平板",
@@ -334,9 +308,9 @@ def test_search_products(self):
price=Decimal("3999.99"),
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
- stock_quantity=30
+ stock_quantity=30,
)
-
+
self.product_crud.create_product(
conn=self.connection,
product_name="苹果平板",
@@ -344,25 +318,16 @@ def test_search_products(self):
price=Decimal("4999.99"),
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
- stock_quantity=40
+ stock_quantity=40,
)
-
+
# 执行测试 - 按名称搜索
- apple_products = self.product_crud.search_products(
- conn=self.connection,
- search_term="苹果"
- )
-
- huawei_products = self.product_crud.search_products(
- conn=self.connection,
- search_term="华为"
- )
-
- tablet_products = self.product_crud.search_products(
- conn=self.connection,
- search_term="平板"
- )
-
+ apple_products = self.product_crud.search_products(conn=self.connection, search_term="苹果")
+
+ huawei_products = self.product_crud.search_products(conn=self.connection, search_term="华为")
+
+ tablet_products = self.product_crud.search_products(conn=self.connection, search_term="平板")
+
# 验证结果
self.assertEqual(len(apple_products), 2) # "苹果手机"和"苹果平板"
self.assertEqual(len(huawei_products), 1) # "华为平板"
@@ -377,36 +342,28 @@ def test_delete_product(self):
price=Decimal("99.99"),
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
- stock_quantity=10
+ stock_quantity=10,
)
-
+
# 执行测试
- result = self.product_crud.delete_product(
- conn=self.connection,
- product_id=test_product["ProductID"]
- )
-
+ result = self.product_crud.delete_product(conn=self.connection, product_id=test_product["ProductID"])
+
# 验证结果
self.assertTrue(result)
-
+
# 获取更新后的商品信息
deleted_product = self.product_crud.get_product_by_id(
- conn=self.connection,
- product_id=test_product["ProductID"]
+ conn=self.connection, product_id=test_product["ProductID"]
)
-
+
# 验证商品状态变为DISCONTINUED
self.assertIsNotNone(deleted_product)
self.assertEqual(deleted_product["ProductStatus"], "DISCONTINUED")
-
+
# 测试删除不存在的商品
- non_existent_result = self.product_crud.delete_product(
- conn=self.connection,
- product_id=999999
- )
+ non_existent_result = self.product_crud.delete_product(conn=self.connection, product_id=999999)
self.assertFalse(non_existent_result)
-
def test_get_products_by_store_and_category(self):
"""测试同时按店铺和分类筛选商品"""
# 准备数据 - 创建两个分类,并分别创建关联到不同分类的商品
@@ -414,9 +371,9 @@ def test_get_products_by_store_and_category(self):
conn=self.connection,
category_name="测试分类2",
category_description="第二个用于测试的商品分类",
- actor_id=None
+ actor_id=None,
)
-
+
# 创建商品到第一个测试分类和测试店铺
for i in range(3):
self.product_crud.create_product(
@@ -425,9 +382,9 @@ def test_get_products_by_store_and_category(self):
price=Decimal(f"{(i+1)*100}.99"),
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
- stock_quantity=10*(i+1)
+ stock_quantity=10 * (i + 1),
)
-
+
# 创建商品到第二个测试分类和测试店铺
for i in range(2):
self.product_crud.create_product(
@@ -436,20 +393,22 @@ def test_get_products_by_store_and_category(self):
price=Decimal(f"{(i+1)*200}.99"),
store_id=self.test_store_id,
category_id=new_category["CategoryID"],
- stock_quantity=20*(i+1)
+ stock_quantity=20 * (i + 1),
)
-
+
# 创建第二个店铺
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO Store (StoreName, OwnerUserID, Description, StoreStatus)
VALUES ('测试店铺2', :user_id, '第二个用于测试的店铺', 'ACTIVE')
- """),
- {"user_id": self.test_user_id}
+ """
+ ),
+ {"user_id": self.test_user_id},
)
store_id_result = self.connection.execute(text("SELECT LAST_INSERT_ID()")).fetchone()
test_store_id_2 = store_id_result[0]
-
+
# 创建商品到第一个测试分类和第二个店铺
for i in range(2):
self.product_crud.create_product(
@@ -458,28 +417,22 @@ def test_get_products_by_store_and_category(self):
price=Decimal(f"{(i+1)*150}.99"),
store_id=test_store_id_2,
category_id=self.test_category["CategoryID"],
- stock_quantity=15*(i+1)
+ stock_quantity=15 * (i + 1),
)
-
+
# 执行测试
products_store1_cat1 = self.product_crud.get_products_by_store_and_category(
- conn=self.connection,
- store_id=self.test_store_id,
- category_id=self.test_category["CategoryID"]
+ conn=self.connection, store_id=self.test_store_id, category_id=self.test_category["CategoryID"]
)
-
+
products_store1_cat2 = self.product_crud.get_products_by_store_and_category(
- conn=self.connection,
- store_id=self.test_store_id,
- category_id=new_category["CategoryID"]
+ conn=self.connection, store_id=self.test_store_id, category_id=new_category["CategoryID"]
)
-
+
products_store2_cat1 = self.product_crud.get_products_by_store_and_category(
- conn=self.connection,
- store_id=test_store_id_2,
- category_id=self.test_category["CategoryID"]
+ conn=self.connection, store_id=test_store_id_2, category_id=self.test_category["CategoryID"]
)
-
+
# 验证结果
self.assertEqual(len(products_store1_cat1), 3)
self.assertEqual(len(products_store1_cat2), 2)
@@ -490,7 +443,7 @@ def test_get_filtered_products(self):
# 准备数据 - 创建不同价格、状态的商品
prices = [99.99, 199.99, 299.99, 399.99, 499.99]
statuses = ["ACTIVE", "ACTIVE", "INACTIVE_BY_MERCHANT", "ACTIVE", "DISCONTINUED"]
-
+
for i, (price, status) in enumerate(zip(prices, statuses)):
product = self.product_crud.create_product(
conn=self.connection,
@@ -499,113 +452,112 @@ def test_get_filtered_products(self):
price=Decimal(str(price)),
store_id=self.test_store_id,
category_id=self.test_category["CategoryID"],
- stock_quantity=100
+ stock_quantity=100,
)
-
+
# 如果需要非默认状态,则更新状态
if status != "ACTIVE":
self.connection.execute(
- text(f"""
+ text(
+ f"""
UPDATE Product SET ProductStatus = :status
WHERE ProductID = :product_id
- """),
- {"status": status, "product_id": product["ProductID"]}
+ """
+ ),
+ {"status": status, "product_id": product["ProductID"]},
)
-
+
# 测试价格区间过滤
- low_price_products = self.product_crud.get_filtered_products(
- conn=self.connection,
- max_price=Decimal("200.00")
- )
-
+ low_price_products = self.product_crud.get_filtered_products(conn=self.connection, max_price=Decimal("200.00"))
+
# 获取不同状态的高价格商品
high_price_active_products = self.product_crud.get_filtered_products(
- conn=self.connection,
- min_price=Decimal("300.00"),
- product_status="ACTIVE" # 显式查询ACTIVE状态的高价格商品
+ conn=self.connection, min_price=Decimal("300.00"), product_status="ACTIVE" # 显式查询ACTIVE状态的高价格商品
)
-
+
high_price_discontinued_products = self.product_crud.get_filtered_products(
conn=self.connection,
min_price=Decimal("300.00"),
- product_status="DISCONTINUED" # 显式查询DISCONTINUED状态的高价格商品
+ product_status="DISCONTINUED", # 显式查询DISCONTINUED状态的高价格商品
)
-
+
# 合并所有高价格商品结果
high_price_products = high_price_active_products + high_price_discontinued_products
-
+
# 获取不同状态的中价格商品
mid_price_active_products = self.product_crud.get_filtered_products(
conn=self.connection,
min_price=Decimal("150.00"),
max_price=Decimal("350.00"),
- product_status="ACTIVE" # 显式查询ACTIVE状态的中价格商品
+ product_status="ACTIVE", # 显式查询ACTIVE状态的中价格商品
)
-
+
mid_price_inactive_products = self.product_crud.get_filtered_products(
conn=self.connection,
min_price=Decimal("150.00"),
max_price=Decimal("350.00"),
- product_status="INACTIVE_BY_MERCHANT" # 显式查询INACTIVE_BY_MERCHANT状态的中价格商品
+ product_status="INACTIVE_BY_MERCHANT", # 显式查询INACTIVE_BY_MERCHANT状态的中价格商品
)
-
+
# 合并所有中价格商品结果
mid_price_products = mid_price_active_products + mid_price_inactive_products
-
+
# 测试状态过滤
- active_products = self.product_crud.get_filtered_products(
- conn=self.connection,
- product_status="ACTIVE"
- )
-
+ active_products = self.product_crud.get_filtered_products(conn=self.connection, product_status="ACTIVE")
+
inactive_products = self.product_crud.get_filtered_products(
- conn=self.connection,
- product_status="INACTIVE_BY_MERCHANT"
+ conn=self.connection, product_status="INACTIVE_BY_MERCHANT"
)
-
+
# 测试排序
asc_price_products = self.product_crud.get_filtered_products(
- conn=self.connection,
- order_by="price_asc" # 使用新的排序格式
- )
-
+ conn=self.connection, order_by="price_asc"
+ ) # 使用新的排序格式
+
desc_price_products = self.product_crud.get_filtered_products(
- conn=self.connection,
- order_by="price_desc" # 使用新的排序格式
- )
-
+ conn=self.connection, order_by="price_desc"
+ ) # 使用新的排序格式
+
# 组合多重过滤和排序
filtered_sorted_products = self.product_crud.get_filtered_products(
conn=self.connection,
min_price=Decimal("100.00"),
max_price=Decimal("400.00"),
product_status="ACTIVE",
- order_by="price_desc" # 使用新的排序格式
+ order_by="price_desc", # 使用新的排序格式
)
-
+
# 验证结果
self.assertEqual(len(low_price_products), 2) # 价格低于200的商品
# 高价格商品应该有2个:1个ACTIVE状态的1个和1个DISCONTINUED状态的
self.assertEqual(len(high_price_active_products) + len(high_price_discontinued_products), 2)
# 中价格商品应该有2个:1个ACTIVE状态的1个和1个INACTIVE_BY_MERCHANT状态的
self.assertEqual(len(mid_price_active_products) + len(mid_price_inactive_products), 2)
-
+
self.assertEqual(len(active_products), 3) # ACTIVE状态的商品
self.assertEqual(len(inactive_products), 1) # INACTIVE_BY_MERCHANT状态的商品
-
+
# 验证价格升序排序
- self.assertTrue(all(asc_price_products[i]["Price"] <= asc_price_products[i+1]["Price"]
- for i in range(len(asc_price_products)-1)))
-
+ self.assertTrue(
+ all(
+ asc_price_products[i]["Price"] <= asc_price_products[i + 1]["Price"]
+ for i in range(len(asc_price_products) - 1)
+ )
+ )
+
# 验证价格降序排序
- self.assertTrue(all(desc_price_products[i]["Price"] >= desc_price_products[i+1]["Price"]
- for i in range(len(desc_price_products)-1)))
-
+ self.assertTrue(
+ all(
+ desc_price_products[i]["Price"] >= desc_price_products[i + 1]["Price"]
+ for i in range(len(desc_price_products) - 1)
+ )
+ )
+
# 验证组合过滤和排序
self.assertEqual(len(filtered_sorted_products), 2) # 符合条件的商品
if len(filtered_sorted_products) >= 2:
self.assertTrue(filtered_sorted_products[0]["Price"] > filtered_sorted_products[1]["Price"]) # 降序
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/src/backend/test/unit/crud/test_store_change_request_crud.py b/src/backend/test/unit/crud/test_store_change_request_crud.py
index 27d2941..756e8cf 100644
--- a/src/backend/test/unit/crud/test_store_change_request_crud.py
+++ b/src/backend/test/unit/crud/test_store_change_request_crud.py
@@ -13,7 +13,7 @@
def generate_random_string(length=8):
"""生成指定长度的随机字符串"""
letters_and_digits = string.ascii_letters + string.digits
- return ''.join(random.choice(letters_and_digits) for _ in range(length))
+ return "".join(random.choice(letters_and_digits) for _ in range(length))
class TestStoreChangeRequestCRUD(BaseDBTestCaseAutoRollback):
@@ -21,44 +21,50 @@ def setUp(self):
super().setUp()
# 初始化CRUD实例
self.store_change_request_crud = get_store_change_request_crud_instance()
-
+
# 创建测试数据 - 创建测试用户
random_suffix = generate_random_string()
self.test_username = f"testuser_{random_suffix}"
self.test_email = f"{self.test_username}@example.com"
-
+
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO User (Username, PasswordHash, Email, UserRole)
VALUES (:username, 'password_hash', :email, 'merchant')
- """),
- {"username": self.test_username, "email": self.test_email}
+ """
+ ),
+ {"username": self.test_username, "email": self.test_email},
)
user_id_result = self.connection.execute(text("SELECT LAST_INSERT_ID()")).fetchone()
self.test_user_id = user_id_result[0]
-
+
# 创建管理员用户
random_suffix = generate_random_string()
admin_username = f"admin_{random_suffix}"
admin_email = f"{admin_username}@example.com"
-
+
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO User (Username, PasswordHash, Email, UserRole)
VALUES (:username, 'password_hash', :email, 'admin')
- """),
- {"username": admin_username, "email": admin_email}
+ """
+ ),
+ {"username": admin_username, "email": admin_email},
)
admin_id_result = self.connection.execute(text("SELECT LAST_INSERT_ID()")).fetchone()
self.admin_user_id = admin_id_result[0]
-
+
# 创建测试数据 - 创建店铺
self.connection.execute(
- text("""
+ text(
+ """
INSERT INTO Store (StoreName, OwnerUserID, Description, StoreStatus)
VALUES ('测试店铺', :user_id, '用于测试的店铺', 'ACTIVE')
- """),
- {"user_id": self.test_user_id}
+ """
+ ),
+ {"user_id": self.test_user_id},
)
store_id_result = self.connection.execute(text("SELECT LAST_INSERT_ID()")).fetchone()
self.test_store_id = store_id_result[0]
@@ -66,12 +72,8 @@ def setUp(self):
def test_create_change_request(self):
"""测试创建店铺变更请求"""
# 准备数据
- proposed_data = {
- "StoreName": "更新后的店铺名称",
- "Description": "更新后的店铺描述",
- "StoreStatus": "ACTIVE"
- }
-
+ proposed_data = {"StoreName": "更新后的店铺名称", "Description": "更新后的店铺描述", "StoreStatus": "ACTIVE"}
+
# 执行测试
change_request = self.store_change_request_crud.create_change_request(
conn=self.connection,
@@ -80,9 +82,9 @@ def test_create_change_request(self):
proposed_data=proposed_data,
store_id=self.test_store_id,
submitter_notes="请求更改店铺信息",
- actor_id=None
+ actor_id=None,
)
-
+
# 验证结果
self.assertIsNotNone(change_request)
self.assertIsInstance(change_request, dict)
@@ -92,15 +94,15 @@ def test_create_change_request(self):
self.assertEqual(change_request["RequestType"], "STORE_UPDATE")
self.assertEqual(change_request["Status"], "PENDING_APPROVAL")
self.assertEqual(change_request["SubmitterNotes"], "请求更改店铺信息")
-
+
# 验证提议的数据
self.assertIsNotNone(change_request["ProposedData_JSON"])
-
+
# 对于JSON字段,需要确认它已经被解析为Python对象
self.assertIsInstance(change_request["ProposedData_JSON"], dict)
self.assertEqual(change_request["ProposedData_JSON"]["StoreName"], proposed_data["StoreName"])
self.assertEqual(change_request["ProposedData_JSON"]["Description"], proposed_data["Description"])
-
+
# 验证日期字段
self.assertIsNotNone(change_request["CreationTime"])
self.assertIsNotNone(change_request["LastUpdatedDate"])
@@ -109,7 +111,7 @@ def test_get_change_request_by_id(self):
"""测试根据ID获取店铺变更请求"""
# 准备数据 - 先创建变更请求
proposed_data = {"StoreName": "测试店铺2"}
-
+
created_request = self.store_change_request_crud.create_change_request(
conn=self.connection,
requesting_user_id=self.test_user_id,
@@ -117,17 +119,15 @@ def test_get_change_request_by_id(self):
proposed_data=proposed_data,
store_id=self.test_store_id,
submitter_notes="测试请求",
- actor_id=None
+ actor_id=None,
)
-
+
# 执行测试
request_id = created_request["ChangeRequestID"]
retrieved_request = self.store_change_request_crud.get_change_request_by_id(
- conn=self.connection,
- request_id=request_id,
- actor_id=None
+ conn=self.connection, request_id=request_id, actor_id=None
)
-
+
# 验证结果
self.assertIsNotNone(retrieved_request)
self.assertEqual(retrieved_request["ChangeRequestID"], request_id)
@@ -135,7 +135,7 @@ def test_get_change_request_by_id(self):
self.assertEqual(retrieved_request["RequestType"], "STORE_UPDATE")
self.assertEqual(retrieved_request["StoreID"], self.test_store_id)
self.assertEqual(retrieved_request["Status"], "PENDING_APPROVAL")
-
+
# 验证提议的数据
self.assertIsNotNone(retrieved_request["ProposedData_JSON"])
self.assertEqual(retrieved_request["ProposedData_JSON"]["StoreName"], proposed_data["StoreName"])
@@ -152,16 +152,14 @@ def test_get_change_requests_by_store_id(self):
proposed_data=proposed_data,
store_id=self.test_store_id,
submitter_notes=f"测试请求{i}",
- actor_id=None
+ actor_id=None,
)
-
+
# 执行测试
requests = self.store_change_request_crud.get_change_requests_by_store_id(
- conn=self.connection,
- store_id=self.test_store_id,
- actor_id=None
+ conn=self.connection, store_id=self.test_store_id, actor_id=None
)
-
+
# 验证结果
self.assertEqual(len(requests), 3)
for request in requests:
@@ -180,16 +178,14 @@ def test_get_change_requests_by_user_id(self):
proposed_data=proposed_data,
store_id=self.test_store_id,
submitter_notes=f"用户测试请求{i}",
- actor_id=None
+ actor_id=None,
)
-
+
# 执行测试
requests = self.store_change_request_crud.get_change_requests_by_user_id(
- conn=self.connection,
- user_id=self.test_user_id,
- actor_id=None
+ conn=self.connection, user_id=self.test_user_id, actor_id=None
)
-
+
# 验证结果
self.assertEqual(len(requests), 3)
for request in requests:
@@ -208,9 +204,9 @@ def test_get_all_pending_requests(self):
proposed_data=proposed_data,
store_id=self.test_store_id,
submitter_notes=f"待审核测试请求{i}",
- actor_id=None
+ actor_id=None,
)
-
+
# 将其中两个请求设置为已审核状态
if i < 2:
self.store_change_request_crud.update_request_status(
@@ -219,15 +215,12 @@ def test_get_all_pending_requests(self):
status="APPROVED",
admin_id=self.admin_user_id,
admin_notes="已审核通过",
- actor_id=None
+ actor_id=None,
)
-
+
# 执行测试
- pending_requests = self.store_change_request_crud.get_all_pending_requests(
- conn=self.connection,
- actor_id=None
- )
-
+ pending_requests = self.store_change_request_crud.get_all_pending_requests(conn=self.connection, actor_id=None)
+
# 验证结果
self.assertEqual(len(pending_requests), 2) # 应该只有2个待审核的请求
for request in pending_requests:
@@ -237,7 +230,7 @@ def test_get_filtered_requests(self):
"""测试获取根据多种条件筛选的请求列表"""
# 准备数据 - 创建多种类型的变更请求
request_types = ["STORE_CREATE", "STORE_UPDATE", "STORE_DELETE"]
-
+
for i, request_type in enumerate(request_types):
proposed_data = {"StoreName": f"{request_type}测试{i}"}
self.store_change_request_crud.create_change_request(
@@ -247,28 +240,23 @@ def test_get_filtered_requests(self):
proposed_data=proposed_data,
store_id=self.test_store_id if request_type != "STORE_CREATE" else None,
submitter_notes=f"{request_type}测试请求",
- actor_id=None
+ actor_id=None,
)
-
+
# 执行测试 - 按请求类型筛选
update_requests = self.store_change_request_crud.get_filtered_requests(
- conn=self.connection,
- request_type="STORE_UPDATE",
- actor_id=None
+ conn=self.connection, request_type="STORE_UPDATE", actor_id=None
)
-
+
# 验证结果
self.assertEqual(len(update_requests), 1)
self.assertEqual(update_requests[0]["RequestType"], "STORE_UPDATE")
-
+
# 执行测试 - 按多种条件筛选
all_requests = self.store_change_request_crud.get_filtered_requests(
- conn=self.connection,
- user_id=self.test_user_id,
- status="PENDING_APPROVAL",
- actor_id=None
+ conn=self.connection, user_id=self.test_user_id, status="PENDING_APPROVAL", actor_id=None
)
-
+
# 验证结果
self.assertEqual(len(all_requests), 3) # 应该有3个符合条件的请求
for request in all_requests:
@@ -286,9 +274,9 @@ def test_update_request_status(self):
proposed_data=proposed_data,
store_id=self.test_store_id,
submitter_notes="等待审核",
- actor_id=None
+ actor_id=None,
)
-
+
# 执行测试 - 更新为已审核
updated_request = self.store_change_request_crud.update_request_status(
conn=self.connection,
@@ -296,9 +284,9 @@ def test_update_request_status(self):
status="APPROVED",
admin_id=self.admin_user_id,
admin_notes="审核通过",
- actor_id=None
+ actor_id=None,
)
-
+
# 验证结果
self.assertIsNotNone(updated_request)
self.assertEqual(updated_request["Status"], "APPROVED")
@@ -317,23 +305,17 @@ def test_update_request(self):
proposed_data=original_proposed_data,
store_id=self.test_store_id,
submitter_notes="原始备注",
- actor_id=None
+ actor_id=None,
)
-
+
# 执行测试 - 更新请求内容
new_proposed_data = {"StoreName": "更新后的店铺名称", "Description": "更新后的描述"}
- update_data = {
- "ProposedData_JSON": new_proposed_data,
- "SubmitterNotes": "更新后的备注"
- }
-
+ update_data = {"ProposedData_JSON": new_proposed_data, "SubmitterNotes": "更新后的备注"}
+
updated_request = self.store_change_request_crud.update_request(
- conn=self.connection,
- request_id=request["ChangeRequestID"],
- update_data=update_data,
- actor_id=None
+ conn=self.connection, request_id=request["ChangeRequestID"], update_data=update_data, actor_id=None
)
-
+
# 验证结果
self.assertIsNotNone(updated_request)
self.assertEqual(updated_request["SubmitterNotes"], "更新后的备注")
@@ -351,28 +333,24 @@ def test_cancel_request(self):
proposed_data=proposed_data,
store_id=self.test_store_id,
submitter_notes="即将取消",
- actor_id=None
+ actor_id=None,
)
-
+
# 执行测试 - 取消请求
result = self.store_change_request_crud.cancel_request(
- conn=self.connection,
- request_id=request["ChangeRequestID"],
- actor_id=None
+ conn=self.connection, request_id=request["ChangeRequestID"], actor_id=None
)
-
+
# 验证结果
self.assertTrue(result)
-
+
# 检查请求状态是否已更新
cancelled_request = self.store_change_request_crud.get_change_request_by_id(
- conn=self.connection,
- request_id=request["ChangeRequestID"],
- actor_id=None
+ conn=self.connection, request_id=request["ChangeRequestID"], actor_id=None
)
-
+
self.assertEqual(cancelled_request["Status"], "CANCELLED_BY_USER")
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/src/backend/test/unit/crud/test_store_change_request_crud_v2.py b/src/backend/test/unit/crud/test_store_change_request_crud_v2.py
index 31b45fd..cde60dc 100644
--- a/src/backend/test/unit/crud/test_store_change_request_crud_v2.py
+++ b/src/backend/test/unit/crud/test_store_change_request_crud_v2.py
@@ -34,9 +34,7 @@ def setUp(self):
self.mock_cursor_result = MagicMock()
self.mock_conn.execute.return_value = self.mock_cursor_result
- self.set_actor_patcher = patch.object(
- StoreChangeRequestCRUD2, "_set_actor_session_variable"
- )
+ self.set_actor_patcher = patch.object(StoreChangeRequestCRUD2, "_set_actor_session_variable")
self.mock_set_actor_session_variable = self.set_actor_patcher.start()
self.addCleanup(self.set_actor_patcher.stop)
@@ -129,9 +127,7 @@ def test_get_request_list_with_status_and_type_filter(self):
row1._mapping = mock_row1_db
self.mock_cursor_result.fetchall.return_value = [row1]
- requests = self.crud.get_request_list(
- self.mock_conn, status_list=status_filter, request_type_list=type_filter
- )
+ requests = self.crud.get_request_list(self.mock_conn, status_list=status_filter, request_type_list=type_filter)
self.mock_conn.execute.assert_called_once()
call_args = self.mock_conn.execute.call_args.args
@@ -168,9 +164,7 @@ def test_create_request_create_store_success(self):
"CreationTime": datetime.datetime.now(datetime.UTC),
"LastUpdatedDate": datetime.datetime.now(datetime.UTC),
}
- with patch.object(
- self.crud, "get_request_by_id", return_value=mock_final_request_data
- ) as mock_get_by_id:
+ with patch.object(self.crud, "get_request_by_id", return_value=mock_final_request_data) as mock_get_by_id:
created_request = self.crud.create_request_create_store(
conn=self.mock_conn,
requesting_user_id=requesting_user_id,
@@ -187,9 +181,7 @@ def test_create_request_create_store_success(self):
self.assertEqual(params["ProposedData_JSON"], json.dumps(proposed_data))
self.assertIsNone(params["StoreID"])
- mock_get_by_id.assert_called_once_with(
- self.mock_conn, change_request_id=expected_request_id
- )
+ mock_get_by_id.assert_called_once_with(self.mock_conn, change_request_id=expected_request_id)
self.assertEqual(created_request, mock_final_request_data)
# --- 测试 create_request_update_store ---
@@ -213,9 +205,7 @@ def test_create_request_update_store_success(self):
"CreationTime": datetime.datetime.now(datetime.UTC),
"LastUpdatedDate": datetime.datetime.now(datetime.UTC),
}
- with patch.object(
- self.crud, "get_request_by_id", return_value=mock_final_request_data
- ) as mock_get_by_id:
+ with patch.object(self.crud, "get_request_by_id", return_value=mock_final_request_data) as mock_get_by_id:
created_request = self.crud.create_request_update_store(
conn=self.mock_conn,
requesting_user_id=requesting_user_id,
@@ -236,9 +226,7 @@ def test_cancel_request_by_user_success(self):
actor_id = 201
self.mock_cursor_result.rowcount = 1
- result = self.crud.cancel_request_by_user(
- self.mock_conn, change_request_id=request_id, actor_id=actor_id
- )
+ result = self.crud.cancel_request_by_user(self.mock_conn, change_request_id=request_id, actor_id=actor_id)
self.assertTrue(result)
self.mock_set_actor_session_variable.assert_called_once_with(self.mock_conn, actor_id)
@@ -270,9 +258,7 @@ def test_update_request_by_admin_approve_success(self):
"ReviewTimestamp": datetime.datetime.now(datetime.UTC), # Simulate DB update
"LastUpdatedDate": datetime.datetime.now(datetime.UTC) + datetime.timedelta(minutes=1),
}
- with patch.object(
- self.crud, "get_request_by_id", return_value=mock_updated_request_data
- ) as mock_get_by_id:
+ with patch.object(self.crud, "get_request_by_id", return_value=mock_updated_request_data) as mock_get_by_id:
updated_request = self.crud.update_request_by_admin(
conn=self.mock_conn,
change_request_id=request_id,
@@ -306,9 +292,7 @@ def test_update_request_store_id_and_status_applied_for_create(self):
"StoreID": applied_store_id,
"Status": StatusEnum.APPLIED.value,
}
- with patch.object(
- self.crud, "get_request_by_id", return_value=mock_final_request_data
- ) as mock_get_by_id:
+ with patch.object(self.crud, "get_request_by_id", return_value=mock_final_request_data) as mock_get_by_id:
updated_request = self.crud.update_request_store_id_and_status_applied(
conn=self.mock_conn,
change_request_id=change_request_id,
@@ -337,9 +321,7 @@ def test_update_request_store_id_and_status_applied_for_update_no_store_id_chang
"Status": StatusEnum.APPLIED.value,
# StoreID remains as in self.sample_request_data as applied_store_id is None
}
- with patch.object(
- self.crud, "get_request_by_id", return_value=mock_final_request_data
- ) as mock_get_by_id:
+ with patch.object(self.crud, "get_request_by_id", return_value=mock_final_request_data) as mock_get_by_id:
updated_request = self.crud.update_request_store_id_and_status_applied(
conn=self.mock_conn,
change_request_id=change_request_id,
@@ -351,9 +333,7 @@ def test_update_request_store_id_and_status_applied_for_update_no_store_id_chang
normalized_sql = normalize_sql(str(call_args[0].text))
params = call_args[1]
- self.assertIn(
- f"UPDATE {self.crud.table_name} SET Status = :AppliedStatus", normalized_sql
- )
+ self.assertIn(f"UPDATE {self.crud.table_name} SET Status = :AppliedStatus", normalized_sql)
self.assertNotIn("StoreID = :AppliedStoreID", normalized_sql)
self.assertEqual(params["AppliedStatus"], StatusEnum.APPLIED.value)
self.assertNotIn("AppliedStoreID", params)
diff --git a/src/backend/test/unit/crud/test_store_crud.py b/src/backend/test/unit/crud/test_store_crud.py
index b12a24c..497f1cd 100644
--- a/src/backend/test/unit/crud/test_store_crud.py
+++ b/src/backend/test/unit/crud/test_store_crud.py
@@ -16,7 +16,7 @@
# 辅助函数来规范化SQL字符串以便比较
def normalize_sql(sql_string: str) -> str:
"""将SQL字符串中的多个空格和换行符替换为单个空格,并去除首尾空格。"""
- return ' '.join(sql_string.strip().split())
+ return " ".join(sql_string.strip().split())
class TestStoreCRUD(unittest.TestCase):
@@ -30,7 +30,7 @@ def setUp(self):
self.mock_conn.execute.return_value = self.mock_cursor_result
# Patch _set_actor_session_variable 作为 StoreCRUD 类的静态方法
- self.set_actor_patcher = patch.object(StoreCRUD, '_set_actor_session_variable')
+ self.set_actor_patcher = patch.object(StoreCRUD, "_set_actor_session_variable")
self.mock_set_actor_session_variable = self.set_actor_patcher.start()
self.addCleanup(self.set_actor_patcher.stop)
@@ -43,7 +43,7 @@ def setUp(self):
"LogoURL": "http://example.com/logo.png",
"StoreStatus": StoreStatusEnum.ACTIVE.value,
"CreationDate": datetime.datetime(2025, 1, 1, 10, 0, 0),
- "LastUpdatedDate": datetime.datetime(2025, 1, 1, 10, 5, 0)
+ "LastUpdatedDate": datetime.datetime(2025, 1, 1, 10, 5, 0),
}
def test_create_store_success(self):
@@ -60,11 +60,16 @@ def test_create_store_success(self):
# 模拟 get_store_by_id 的返回值
mock_created_store_dict = {
- "StoreID": expected_new_store_id, "StoreName": store_name, "OwnerUserID": owner_user_id,
- "Description": description, "LogoURL": logo_url, "StoreStatus": store_status.value,
- "CreationDate": creation_date, "LastUpdatedDate": creation_date # Initial
+ "StoreID": expected_new_store_id,
+ "StoreName": store_name,
+ "OwnerUserID": owner_user_id,
+ "Description": description,
+ "LogoURL": logo_url,
+ "StoreStatus": store_status.value,
+ "CreationDate": creation_date,
+ "LastUpdatedDate": creation_date, # Initial
}
- with patch.object(self.crud, 'get_store_by_id', return_value=mock_created_store_dict) as mock_get_by_id:
+ with patch.object(self.crud, "get_store_by_id", return_value=mock_created_store_dict) as mock_get_by_id:
created_store = self.crud.create_store(
conn=self.mock_conn,
store_name=store_name,
@@ -73,7 +78,7 @@ def test_create_store_success(self):
logo_url=logo_url,
store_status=store_status,
creation_date=creation_date,
- actor_id=actor_id
+ actor_id=actor_id,
)
self.mock_set_actor_session_variable.assert_called_once_with(self.mock_conn, actor_id)
@@ -86,31 +91,36 @@ def test_create_store_success(self):
self.assertIn(f"INSERT INTO {self.crud.table_name}", normalized_sql)
self.assertIn(
"StoreName, OwnerUserID, Description, LogoURL, StoreStatus, CreationDate, LastUpdatedDate",
- normalized_sql
+ normalized_sql,
)
self.assertIn(
":StoreName, :OwnerUserID, :Description, :LogoURL, :StoreStatus, :CreationDate, :CreationDate",
- normalized_sql
+ normalized_sql,
)
- self.assertEqual(called_params['StoreName'], store_name)
- self.assertEqual(called_params['OwnerUserID'], owner_user_id)
- self.assertEqual(called_params['Description'], description)
- self.assertEqual(called_params['LogoURL'], logo_url)
- self.assertEqual(called_params['StoreStatus'], store_status.value)
- self.assertEqual(called_params['CreationDate'], creation_date)
+ self.assertEqual(called_params["StoreName"], store_name)
+ self.assertEqual(called_params["OwnerUserID"], owner_user_id)
+ self.assertEqual(called_params["Description"], description)
+ self.assertEqual(called_params["LogoURL"], logo_url)
+ self.assertEqual(called_params["StoreStatus"], store_status.value)
+ self.assertEqual(called_params["CreationDate"], creation_date)
mock_get_by_id.assert_called_once_with(self.mock_conn, store_id=expected_new_store_id, actor_id=actor_id)
self.assertEqual(created_store, mock_created_store_dict)
- @patch('backend.app.crud.store_crud.logger')
+ @patch("backend.app.crud.store_crud.logger")
def test_create_store_lastrowid_none(self, mock_logger):
self.mock_cursor_result.lastrowid = None
result = self.crud.create_store(
- conn=self.mock_conn, store_name="Fail Store", owner_user_id=1,
- description=None, logo_url=None, store_status=StoreStatusEnum.ACTIVE,
- creation_date=datetime.datetime.utcnow(), actor_id=1
+ conn=self.mock_conn,
+ store_name="Fail Store",
+ owner_user_id=1,
+ description=None,
+ logo_url=None,
+ store_status=StoreStatusEnum.ACTIVE,
+ creation_date=datetime.datetime.utcnow(),
+ actor_id=1,
)
self.assertIsNone(result)
mock_logger.warning.assert_called_once()
@@ -118,11 +128,16 @@ def test_create_store_lastrowid_none(self, mock_logger):
def test_create_store_integrity_error(self):
self.mock_conn.execute.side_effect = exc.IntegrityError("mock integrity", {}, None)
- with patch('backend.app.crud.store_crud.logger') as mock_logger:
+ with patch("backend.app.crud.store_crud.logger") as mock_logger:
result = self.crud.create_store(
- conn=self.mock_conn, store_name="Integrity Store", owner_user_id=1,
- description=None, logo_url=None, store_status=StoreStatusEnum.ACTIVE,
- creation_date=datetime.datetime.utcnow(), actor_id=1
+ conn=self.mock_conn,
+ store_name="Integrity Store",
+ owner_user_id=1,
+ description=None,
+ logo_url=None,
+ store_status=StoreStatusEnum.ACTIVE,
+ creation_date=datetime.datetime.utcnow(),
+ actor_id=1,
)
self.assertIsNone(result)
mock_logger.error.assert_called_once()
@@ -144,9 +159,10 @@ def test_get_store_by_id_found(self):
self.assertIn(
f"SELECT StoreID, StoreName, OwnerUserID, Description, LogoURL, StoreStatus, CreationDate, LastUpdatedDate FROM {self.crud.table_name}",
- normalized_sql)
+ normalized_sql,
+ )
self.assertIn("WHERE StoreID = :StoreID", normalized_sql)
- self.assertEqual(called_params['StoreID'], store_id)
+ self.assertEqual(called_params["StoreID"], store_id)
self.assertEqual(store, self.sample_store_data)
def test_get_store_by_id_not_found(self):
@@ -159,21 +175,22 @@ def test_get_stores_by_owner_user_id_found(self):
actor_id = 101
mock_row_data1 = {**self.sample_store_data, "StoreID": 1, "OwnerUserID": owner_user_id}
mock_row_data2 = {**self.sample_store_data, "StoreID": 2, "OwnerUserID": owner_user_id, "StoreName": "Store 2"}
- mock_row1 = MagicMock();
+ mock_row1 = MagicMock()
mock_row1._mapping = mock_row_data1
- mock_row2 = MagicMock();
+ mock_row2 = MagicMock()
mock_row2._mapping = mock_row_data2
self.mock_cursor_result.fetchall.return_value = [mock_row1, mock_row2]
- stores = self.crud.get_stores_by_owner_user_id(self.mock_conn, owner_user_id=owner_user_id, limit=10, offset=0,
- actor_id=actor_id)
+ stores = self.crud.get_stores_by_owner_user_id(
+ self.mock_conn, owner_user_id=owner_user_id, limit=10, offset=0, actor_id=actor_id
+ )
self.mock_conn.execute.assert_called_once()
call_args = self.mock_conn.execute.call_args.args
normalized_sql = normalize_sql(str(call_args[0].text))
self.assertIn(f"FROM {self.crud.table_name} WHERE OwnerUserID = :OwnerUserID", normalized_sql)
self.assertIn("ORDER BY CreationDate DESC LIMIT :Limit OFFSET :Offset", normalized_sql)
- self.assertEqual(call_args[1]['OwnerUserID'], owner_user_id)
+ self.assertEqual(call_args[1]["OwnerUserID"], owner_user_id)
self.assertEqual(len(stores), 2)
self.assertEqual(stores[0], mock_row_data1)
@@ -187,9 +204,9 @@ def test_get_all_stores_no_filter(self):
actor_id = 1
mock_row_data1 = {**self.sample_store_data, "StoreID": 1}
mock_row_data2 = {**self.sample_store_data, "StoreID": 2, "StoreName": "Store All 2"}
- mock_row1 = MagicMock();
+ mock_row1 = MagicMock()
mock_row1._mapping = mock_row_data1
- mock_row2 = MagicMock();
+ mock_row2 = MagicMock()
mock_row2._mapping = mock_row_data2
self.mock_cursor_result.fetchall.return_value = [mock_row1, mock_row2]
@@ -207,7 +224,7 @@ def test_get_all_stores_with_status_filter(self):
actor_id = 1
status_filter = StoreStatusEnum.ACTIVE
mock_row_data1 = {**self.sample_store_data, "StoreID": 1, "StoreStatus": status_filter.value}
- mock_row1 = MagicMock();
+ mock_row1 = MagicMock()
mock_row1._mapping = mock_row_data1
self.mock_cursor_result.fetchall.return_value = [mock_row1]
@@ -217,7 +234,7 @@ def test_get_all_stores_with_status_filter(self):
call_args = self.mock_conn.execute.call_args.args
normalized_sql = normalize_sql(str(call_args[0].text))
self.assertIn("WHERE StoreStatus = :StoreStatus", normalized_sql)
- self.assertEqual(call_args[1]['StoreStatus'], status_filter.value)
+ self.assertEqual(call_args[1]["StoreStatus"], status_filter.value)
self.assertEqual(len(stores), 1)
self.assertEqual(stores[0]["StoreStatus"], status_filter.value)
@@ -231,16 +248,16 @@ def test_update_store_success_some_fields(self):
expected_updated_data = {
**self.sample_store_data,
"StoreName": "Updated Store Name",
- "Description": "New Description"
+ "Description": "New Description",
# LogoURL and StoreStatus not updated, so they remain as in sample_store_data
}
- with patch.object(self.crud, 'get_store_by_id', return_value=expected_updated_data) as mock_get_by_id:
+ with patch.object(self.crud, "get_store_by_id", return_value=expected_updated_data) as mock_get_by_id:
updated_store = self.crud.update_store(
conn=self.mock_conn,
store_id=store_id,
actor_id=actor_id,
store_name="Updated Store Name",
- description="New Description"
+ description="New Description",
# logo_url and store_status are None, so not updated
)
@@ -258,9 +275,9 @@ def test_update_store_success_some_fields(self):
self.assertNotIn("StoreStatus = :StoreStatus", normalized_sql) # Not provided for update
self.assertIn("WHERE StoreID = :StoreID_param", normalized_sql)
- self.assertEqual(params['StoreName'], "Updated Store Name")
- self.assertEqual(params['Description'], "New Description")
- self.assertEqual(params['StoreID_param'], store_id)
+ self.assertEqual(params["StoreName"], "Updated Store Name")
+ self.assertEqual(params["Description"], "New Description")
+ self.assertEqual(params["StoreID_param"], store_id)
mock_get_by_id.assert_called_once_with(self.mock_conn, store_id=store_id, actor_id=actor_id)
self.assertEqual(updated_store, expected_updated_data)
@@ -272,7 +289,7 @@ def test_update_store_only_status(self):
self.mock_cursor_result.rowcount = 1
expected_updated_data = {**self.sample_store_data, "StoreStatus": new_status.value}
- with patch.object(self.crud, 'get_store_by_id', return_value=expected_updated_data) as mock_get_by_id:
+ with patch.object(self.crud, "get_store_by_id", return_value=expected_updated_data) as mock_get_by_id:
updated_store = self.crud.update_store(
conn=self.mock_conn, store_id=store_id, actor_id=actor_id, store_status=new_status
)
@@ -283,9 +300,10 @@ def test_update_store_only_status(self):
self.assertIn(
f"UPDATE {self.crud.table_name} SET StoreStatus = :StoreStatus WHERE StoreID = :StoreID_param",
- normalized_sql)
- self.assertEqual(params['StoreStatus'], new_status.value)
- self.assertEqual(params['StoreID_param'], store_id)
+ normalized_sql,
+ )
+ self.assertEqual(params["StoreStatus"], new_status.value)
+ self.assertEqual(params["StoreID_param"], store_id)
self.assertEqual(updated_store, expected_updated_data)
def test_update_store_no_fields_to_update(self):
@@ -293,9 +311,11 @@ def test_update_store_no_fields_to_update(self):
actor_id = self.sample_store_data["OwnerUserID"]
# Mock get_store_by_id for when no fields are updated
- with patch.object(self.crud, 'get_store_by_id', return_value=self.sample_store_data) as mock_get_by_id:
+ with patch.object(self.crud, "get_store_by_id", return_value=self.sample_store_data) as mock_get_by_id:
result = self.crud.update_store(
- conn=self.mock_conn, store_id=store_id, actor_id=actor_id
+ conn=self.mock_conn,
+ store_id=store_id,
+ actor_id=actor_id,
# No optional update fields provided
)
self.mock_conn.execute.assert_not_called() # UPDATE SQL should not be called
@@ -308,7 +328,7 @@ def test_update_store_not_found_or_no_change(self):
self.mock_cursor_result.rowcount = 0 # Simulate no row updated
# Mock get_store_by_id to return None (as if store doesn't exist)
- with patch.object(self.crud, 'get_store_by_id', return_value=None) as mock_get_by_id_after_update:
+ with patch.object(self.crud, "get_store_by_id", return_value=None) as mock_get_by_id_after_update:
updated_store = self.crud.update_store(
conn=self.mock_conn, store_id=store_id, actor_id=actor_id, store_name="Try Update"
)
@@ -316,5 +336,5 @@ def test_update_store_not_found_or_no_change(self):
mock_get_by_id_after_update.assert_called_once_with(self.mock_conn, store_id=store_id, actor_id=actor_id)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/src/backend/test/unit/crud/test_user_crud.py b/src/backend/test/unit/crud/test_user_crud.py
index 6532599..429210b 100644
--- a/src/backend/test/unit/crud/test_user_crud.py
+++ b/src/backend/test/unit/crud/test_user_crud.py
@@ -3,6 +3,7 @@
from backend.app.crud.user_crud import UserCRUD
+
@unittest.skip("This test has been moved to test_user_crud_integration.py")
class TestUserCRUD(BaseDBTestCaseAutoRollback):
def setUp(self):
@@ -51,7 +52,6 @@ def test_create_user_return_value(self):
self.assertIsNone(user["Email"])
self.assertIsNotNone(user["RegistrationDate"])
-
def test_create_user_with_all_fields(self):
# Test creating a user with all fields
user_data = {
@@ -100,7 +100,5 @@ def test_create_users_with_same_username(self):
self.assertIn("Duplicate", str(e))
-
-
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/src/backend/test/unit/crud/test_user_session_crud.py b/src/backend/test/unit/crud/test_user_session_crud.py
index ad1fba4..988315e 100644
--- a/src/backend/test/unit/crud/test_user_session_crud.py
+++ b/src/backend/test/unit/crud/test_user_session_crud.py
@@ -25,13 +25,15 @@ def test_create_session_success(self):
# Mock self.get_session_by_token which is called internally by create_session
# This makes the test a more focused unit test for create_session's direct logic
mock_created_session_data = {
- "SessionToken": "test_token", "UserID": 1,
- "ExpiresAt": datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1)
+ "SessionToken": "test_token",
+ "UserID": 1,
+ "ExpiresAt": datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1),
# ... other fields
}
# We need to patch the instance's method
- with patch.object(self.crud, 'get_session_by_token',
- return_value=mock_created_session_data) as mock_get_session:
+ with patch.object(
+ self.crud, "get_session_by_token", return_value=mock_created_session_data
+ ) as mock_get_session:
session_token = "test_token"
user_id = 1
expires_at = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1)
@@ -44,14 +46,14 @@ def test_create_session_success(self):
user_id=user_id,
expires_at=expires_at,
ip_address=ip_address,
- user_agent=user_agent
+ user_agent=user_agent,
)
self.mock_conn.execute.assert_called_once()
args, kwargs = self.mock_conn.execute.call_args
self.assertIn("INSERT INTO UserSession", str(args[0].text)) # Check query part
- self.assertEqual(args[1]['session_token'], session_token) # Check bound parameters
- self.assertEqual(args[1]['user_id'], user_id)
+ self.assertEqual(args[1]["session_token"], session_token) # Check bound parameters
+ self.assertEqual(args[1]["user_id"], user_id)
mock_get_session.assert_called_once_with(self.mock_conn, token=session_token)
self.assertEqual(created_session, mock_created_session_data)
@@ -63,24 +65,24 @@ def test_create_session_with_naive_datetime_fails(self):
self.mock_conn,
session_token="test_token",
user_id=1,
- expires_at=naive_datetime + datetime.timedelta(days=1)
+ expires_at=naive_datetime + datetime.timedelta(days=1),
)
with self.assertRaises(ValueError) as context:
self.crud.create_session(
self.mock_conn,
session_token="test_token",
user_id=1,
- expires_at=naive_datetime + datetime.timedelta(days=-1)
+ expires_at=naive_datetime + datetime.timedelta(days=-1),
)
def test_create_session_verification_fails(self):
- with patch.object(self.crud, 'get_session_by_token', return_value=None) as mock_get_session:
+ with patch.object(self.crud, "get_session_by_token", return_value=None) as mock_get_session:
with self.assertRaisesRegex(Exception, "Session creation verification failed"):
self.crud.create_session(
self.mock_conn,
session_token="test_token",
user_id=1,
- expires_at=datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1)
+ expires_at=datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1),
)
mock_get_session.assert_called_once_with(self.mock_conn, token="test_token")
@@ -95,7 +97,7 @@ def test_get_session_by_token_found(self):
self.mock_conn.execute.assert_called_once()
args, kwargs = self.mock_conn.execute.call_args
self.assertIn("SELECT SessionToken, UserID", str(args[0].text)) # Check query part
- self.assertEqual(args[1]['token'], "valid_token") # Check bound parameters
+ self.assertEqual(args[1]["token"], "valid_token") # Check bound parameters
self.assertEqual(session, expected_data)
@@ -116,10 +118,8 @@ def test_get_active_user_id_by_token_and_update_access_success(self):
# The second execute (UPDATE) doesn't call fetchone, its result is self.mock_execute_result
# Patch delete_session_by_token as it might be called if session is expired
- with patch.object(self.crud, 'delete_session_by_token') as mock_delete:
- user_id = self.crud.get_active_user_id_by_token_and_update_access(
- self.mock_conn, token="active_token"
- )
+ with patch.object(self.crud, "delete_session_by_token") as mock_delete:
+ user_id = self.crud.get_active_user_id_by_token_and_update_access(self.mock_conn, token="active_token")
self.assertEqual(user_id, 123)
self.assertEqual(self.mock_conn.execute.call_count, 2) # One SELECT, one UPDATE
@@ -127,34 +127,30 @@ def test_get_active_user_id_by_token_and_update_access_success(self):
# Check SELECT call
select_call_args, select_call_kwargs = self.mock_conn.execute.call_args_list[0]
self.assertIn("SELECT UserID, ExpiresAt", str(select_call_args[0].text))
- self.assertEqual(select_call_args[1]['token'], "active_token")
+ self.assertEqual(select_call_args[1]["token"], "active_token")
# Check UPDATE call
update_call_args, update_call_kwargs = self.mock_conn.execute.call_args_list[1]
self.assertIn("SET LastAccessedAt = NOW()", str(update_call_args[0].text))
- self.assertEqual(update_call_args[1]['token'], "active_token")
+ self.assertEqual(update_call_args[1]["token"], "active_token")
mock_delete.assert_not_called()
def test_get_active_user_id_by_token_not_found(self):
self.mock_execute_result.fetchone.return_value = None # Simulate session not found
- user_id = self.crud.get_active_user_id_by_token_and_update_access(
- self.mock_conn, token="non_existent_token"
- )
+ user_id = self.crud.get_active_user_id_by_token_and_update_access(self.mock_conn, token="non_existent_token")
self.assertIsNone(user_id)
self.mock_conn.execute.assert_called_once() # Only SELECT should be called
- @patch.object(UserSessionCRUD, 'delete_session_by_token') # Patching it on the class
+ @patch.object(UserSessionCRUD, "delete_session_by_token") # Patching it on the class
def test_get_active_user_id_by_token_expired(self, mock_delete_session_by_token):
past_expiry = datetime.datetime.now() - datetime.timedelta(days=1)
mock_row_select = MagicMock()
mock_row_select._mapping = {"UserID": 123, "ExpiresAt": past_expiry}
self.mock_execute_result.fetchone.return_value = mock_row_select
- user_id = self.crud.get_active_user_id_by_token_and_update_access(
- self.mock_conn, token="expired_token"
- )
+ user_id = self.crud.get_active_user_id_by_token_and_update_access(self.mock_conn, token="expired_token")
self.assertIsNone(user_id)
self.mock_conn.execute.assert_called_once() # Only SELECT
@@ -167,7 +163,7 @@ def test_get_active_user_id_by_token_and_update_access_and_expiration_success(se
mock_row_select._mapping = {"UserID": 123, "ExpiresAt": current_expiry}
self.mock_execute_result.fetchone.return_value = mock_row_select
- with patch.object(self.crud, 'delete_session_by_token') as mock_delete:
+ with patch.object(self.crud, "delete_session_by_token") as mock_delete:
result = self.crud.get_active_user_id_by_token_and_update_access_and_expiration(
self.mock_conn, token="active_token", expires_at=new_expiry
)
@@ -175,7 +171,7 @@ def test_get_active_user_id_by_token_and_update_access_and_expiration_success(se
self.assertEqual(self.mock_conn.execute.call_count, 2) # SELECT and UPDATE
update_call_args, update_call_kwargs = self.mock_conn.execute.call_args_list[1]
self.assertIn("SET LastAccessedAt = NOW()", str(update_call_args[0].text))
- self.assertEqual(update_call_args[1]['expires_at'], new_expiry)
+ self.assertEqual(update_call_args[1]["expires_at"], new_expiry)
mock_delete.assert_not_called()
def test_get_active_user_id_by_token_and_update_access_and_expiration_new_expiry_in_past(self):
@@ -186,9 +182,10 @@ def test_get_active_user_id_by_token_and_update_access_and_expiration_new_expiry
self.assertFalse(result)
self.mock_conn.execute.assert_not_called() # No DB calls if new expiry is invalid
- @patch.object(UserSessionCRUD, 'delete_session_by_token')
- def test_get_active_user_id_by_token_and_update_access_and_expiration_session_expired(self,
- mock_delete_session_by_token):
+ @patch.object(UserSessionCRUD, "delete_session_by_token")
+ def test_get_active_user_id_by_token_and_update_access_and_expiration_session_expired(
+ self, mock_delete_session_by_token
+ ):
past_expiry = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(days=1)
future_new_expiry = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=7)
mock_row_select = MagicMock()
@@ -211,7 +208,7 @@ def test_delete_session_by_token_deleted(self):
self.assertEqual(self.mock_conn.execute.call_count, 2) # Ensure it's called twice
args, kwargs = self.mock_conn.execute.call_args_list[1] # Get the second call's arguments
self.assertIn("DELETE FROM UserSession WHERE SessionToken = :token", str(args[0].text))
- self.assertEqual(args[1]['token'], "token_to_delete")
+ self.assertEqual(args[1]["token"], "token_to_delete")
def test_delete_session_by_token_not_found(self):
self.mock_execute_result.rowcount = 0 # Simulate no rows deleted
@@ -227,7 +224,7 @@ def test_delete_all_sessions_for_user_deleted_some(self):
self.mock_conn.execute.assert_called_once()
args, kwargs = self.mock_conn.execute.call_args
self.assertIn("DELETE FROM UserSession WHERE UserID = :user_id", str(args[0].text))
- self.assertEqual(args[1]['user_id'], 123)
+ self.assertEqual(args[1]["user_id"], 123)
def test_delete_all_sessions_for_user_none_deleted(self):
self.mock_execute_result.rowcount = 0
@@ -235,5 +232,5 @@ def test_delete_all_sessions_for_user_none_deleted(self):
self.assertEqual(count, 0)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/src/backend/test/unit/service/test_address_service.py b/src/backend/test/unit/service/test_address_service.py
index 7069100..be89fd7 100644
--- a/src/backend/test/unit/service/test_address_service.py
+++ b/src/backend/test/unit/service/test_address_service.py
@@ -13,7 +13,7 @@
AddressUpdateRequest,
AddressResponse,
AddressListResponse,
- SetDefaultAddressResponse
+ SetDefaultAddressResponse,
)
from backend.app.utils.exceptions import AddressNotFoundException, PermissionDeniedException, UserNotFoundException
@@ -27,10 +27,7 @@ def setUp(self):
self.mock_address_crud = MagicMock(spec=AddressCRUD)
self.mock_user_crud = MagicMock(spec=UserCRUD)
- self.address_service = AddressService(
- address_crud=self.mock_address_crud,
- user_crud=self.mock_user_crud
- )
+ self.address_service = AddressService(address_crud=self.mock_address_crud, user_crud=self.mock_user_crud)
self.mock_db_conn = MagicMock(spec=Connection)
self.test_user_id = 1
@@ -39,29 +36,39 @@ def setUp(self):
self.admin_actor_id = 999 # Hypothetical admin ID
self.sample_address_dict_camel = {
- "AddressID": 101, "UserID": self.test_user_id,
- "RecipientName": "Test Recipient", "PhoneNumber": "1234567890",
- "FullAddress_Text": "123 Test Street, Test City", "IsDefault": False
+ "AddressID": 101,
+ "UserID": self.test_user_id,
+ "RecipientName": "Test Recipient",
+ "PhoneNumber": "1234567890",
+ "FullAddress_Text": "123 Test Street, Test City",
+ "IsDefault": False,
}
self.sample_user_dict_camel = {
- "UserID": self.test_user_id, "Username": "testuser", "DefaultAddressID": None
+ "UserID": self.test_user_id,
+ "Username": "testuser",
+ "DefaultAddressID": None,
# ... other user fields if UserCRUD.get_user_by_id returns them
}
# --- 测试 create_new_address ---
async def test_create_new_address_success_not_default(self):
address_in = AddressCreateRequest(
- RecipientName="New User", PhoneNumber="13000000000",
- FullAddress_Text="New Address", IsDefault=False # Explicitly False
+ RecipientName="New User",
+ PhoneNumber="13000000000",
+ FullAddress_Text="New Address",
+ IsDefault=False, # Explicitly False
)
# Mock UserCRUD.get_user_by_id
self.mock_user_crud.get_user_by_id.return_value = self.sample_user_dict_camel
# Mock AddressCRUD.create_address
# It should return a dict with CamelCase keys matching AddressResponse fields
created_address_from_crud = {
- "AddressID": 102, "UserID": self.test_user_id,
- "RecipientName": address_in.RecipientName, "PhoneNumber": address_in.PhoneNumber,
- "FullAddress_Text": address_in.FullAddress_Text, "IsDefault": False # CRUD sets to False
+ "AddressID": 102,
+ "UserID": self.test_user_id,
+ "RecipientName": address_in.RecipientName,
+ "PhoneNumber": address_in.PhoneNumber,
+ "FullAddress_Text": address_in.FullAddress_Text,
+ "IsDefault": False, # CRUD sets to False
}
self.mock_address_crud.create_address.return_value = created_address_from_crud
@@ -73,16 +80,18 @@ async def test_create_new_address_success_not_default(self):
self.assertEqual(created_address_response.RecipientName, "New User")
self.assertFalse(created_address_response.IsDefault)
self.assertEqual(created_address_response.AddressID, 102)
- self.mock_user_crud.get_user_by_id.assert_called_once_with(conn=self.mock_db_conn, user_id=self.test_user_id,
- actor_id=self.test_actor_id)
+ self.mock_user_crud.get_user_by_id.assert_called_once_with(
+ conn=self.mock_db_conn, user_id=self.test_user_id, actor_id=self.test_actor_id
+ )
self.mock_address_crud.create_address.assert_called_once_with(
conn=self.mock_db_conn, user_id=self.test_user_id, address_in=address_in, actor_id=self.test_actor_id
)
- @patch('backend.app.services.address_service.logger') # Patch logger for this specific test
+ @patch("backend.app.services.address_service.logger") # Patch logger for this specific test
async def test_create_new_address_actor_not_user_logs_warning(self, mock_logger):
- address_in = AddressCreateRequest(RecipientName="New", PhoneNumber="1234567", FullAddress_Text="Addr11111",
- IsDefault=False)
+ address_in = AddressCreateRequest(
+ RecipientName="New", PhoneNumber="1234567", FullAddress_Text="Addr11111", IsDefault=False
+ )
different_actor_id = self.test_actor_id + 5
self.mock_user_crud.get_user_by_id.return_value = self.sample_user_dict_camel
@@ -118,8 +127,10 @@ async def test_create_new_address_sets_non_default(self):
:return:
"""
address_in = AddressCreateRequest(
- RecipientName="Default Address", PhoneNumber="1234567",
- FullAddress_Text="Default Town", IsDefault=True # Requesting default
+ RecipientName="Default Address",
+ PhoneNumber="1234567",
+ FullAddress_Text="Default Town",
+ IsDefault=True, # Requesting default
)
target_user_id = self.test_user_id
@@ -127,16 +138,19 @@ async def test_create_new_address_sets_non_default(self):
# CRUD create_address will return IsDefault=False initially
initial_created_address_dict = {
- "AddressID": 105, "UserID": target_user_id, "IsDefault": False,
- "RecipientName": address_in.RecipientName, "PhoneNumber": address_in.PhoneNumber,
- "FullAddress_Text": address_in.FullAddress_Text
+ "AddressID": 105,
+ "UserID": target_user_id,
+ "IsDefault": False,
+ "RecipientName": address_in.RecipientName,
+ "PhoneNumber": address_in.PhoneNumber,
+ "FullAddress_Text": address_in.FullAddress_Text,
}
self.mock_address_crud.create_address.return_value = initial_created_address_dict
# Mocks for set_default_address_for_user internals (which create_new_address will call)
self.mock_address_crud.get_address_by_id.side_effect = [
initial_created_address_dict, # First call in set_default (verify exists)
- {**initial_created_address_dict, "IsDefault": True}
+ {**initial_created_address_dict, "IsDefault": True},
# Second call in set_default (get final state for response)
]
self.mock_address_crud.set_all_other_addresses_non_default_for_user.return_value = 1
@@ -162,7 +176,7 @@ async def test_get_user_addresses_success_self(self):
self.assertEqual(len(response.Addresses), 2)
self.assertEqual(response.Addresses[0].AddressID, self.sample_address_dict_camel["AddressID"])
- @patch('backend.app.services.address_service.logger')
+ @patch("backend.app.services.address_service.logger")
async def test_get_user_addresses_actor_not_user_logs_warning(self, mock_logger):
self.mock_address_crud.get_addresses_by_user_id.return_value = [] # Still proceeds
await self.address_service.get_user_addresses(
@@ -191,24 +205,29 @@ async def test_get_address_by_id_for_user_address_not_found(self):
async def test_get_address_by_id_for_user_ownership_mismatch_raises_not_found(self):
address_belongs_to_other = {**self.sample_address_dict_camel, "UserID": self.another_user_id}
self.mock_address_crud.get_address_by_id.return_value = address_belongs_to_other
- with self.assertRaisesRegex(AddressNotFoundException,
- f"Address with ID {self.sample_address_dict_camel['AddressID']} not found for user {self.test_user_id}."):
+ with self.assertRaisesRegex(
+ AddressNotFoundException,
+ f"Address with ID {self.sample_address_dict_camel['AddressID']} not found for user {self.test_user_id}.",
+ ):
await self.address_service.get_address_by_id_for_user(
- db=self.mock_db_conn, address_id=self.sample_address_dict_camel["AddressID"],
+ db=self.mock_db_conn,
+ address_id=self.sample_address_dict_camel["AddressID"],
user_id=self.test_user_id, # Requesting for user 1
- actor_id=self.test_user_id
+ actor_id=self.test_user_id,
)
- @patch('backend.app.services.address_service.logger')
+ @patch("backend.app.services.address_service.logger")
async def test_get_address_by_id_for_user_actor_not_owner_logs_warning(self, mock_logger):
# Actor is not owner, SUT logs warning and proceeds if address belongs to target user_id
- self.mock_address_crud.get_address_by_id.return_value = self.sample_address_dict_camel # Address belongs to test_user_id
+ self.mock_address_crud.get_address_by_id.return_value = (
+ self.sample_address_dict_camel
+ ) # Address belongs to test_user_id
await self.address_service.get_address_by_id_for_user(
db=self.mock_db_conn,
address_id=self.sample_address_dict_camel["AddressID"],
user_id=self.test_user_id, # Target user is owner
- actor_id=self.another_user_id # Actor is different
+ actor_id=self.another_user_id, # Actor is different
)
mock_logger.warning.assert_any_call(
f"UserID {self.another_user_id} attempting to get AddressID {self.sample_address_dict_camel['AddressID']} for UserID {self.test_user_id}."
@@ -224,12 +243,16 @@ async def test_update_address_details_success(self):
self.mock_address_crud.update_address_details.return_value = updated_dict_from_crud
response = await self.address_service.update_address_details(
- db=self.mock_db_conn, address_id=address_id, address_in=update_in,
- user_id_making_change=self.test_user_id, actor_id=self.test_actor_id
+ db=self.mock_db_conn,
+ address_id=address_id,
+ address_in=update_in,
+ user_id_making_change=self.test_user_id,
+ actor_id=self.test_actor_id,
)
self.assertEqual(response.RecipientName, "Updated Recipient Name")
- self.mock_address_crud.get_address_by_id.assert_called_once_with(conn=self.mock_db_conn, address_id=address_id,
- actor_id=self.test_actor_id)
+ self.mock_address_crud.get_address_by_id.assert_called_once_with(
+ conn=self.mock_db_conn, address_id=address_id, actor_id=self.test_actor_id
+ )
self.mock_address_crud.update_address_details.assert_called_once_with(
conn=self.mock_db_conn, address_id=address_id, address_in=update_in, actor_id=self.test_actor_id
)
@@ -239,8 +262,11 @@ async def test_update_address_details_address_not_found_by_get(self):
update_in = AddressUpdateRequest(RecipientName="Update")
with self.assertRaisesRegex(AddressNotFoundException, "Address with ID 999 not found."):
await self.address_service.update_address_details(
- db=self.mock_db_conn, address_id=999, address_in=update_in,
- user_id_making_change=self.test_user_id, actor_id=self.test_actor_id
+ db=self.mock_db_conn,
+ address_id=999,
+ address_in=update_in,
+ user_id_making_change=self.test_user_id,
+ actor_id=self.test_actor_id,
)
async def test_update_address_details_ownership_mismatch_raises_not_found(self):
@@ -249,12 +275,15 @@ async def test_update_address_details_ownership_mismatch_raises_not_found(self):
address_of_other_user = {**self.sample_address_dict_camel, "UserID": self.another_user_id}
self.mock_address_crud.get_address_by_id.return_value = address_of_other_user
- with self.assertRaisesRegex(AddressNotFoundException,
- f"Address with ID {address_id} not found for user {self.test_user_id}."):
+ with self.assertRaisesRegex(
+ AddressNotFoundException, f"Address with ID {address_id} not found for user {self.test_user_id}."
+ ):
await self.address_service.update_address_details(
- db=self.mock_db_conn, address_id=address_id, address_in=update_in,
+ db=self.mock_db_conn,
+ address_id=address_id,
+ address_in=update_in,
user_id_making_change=self.test_user_id,
- actor_id=self.test_actor_id
+ actor_id=self.test_actor_id,
)
# --- 测试 set_default_address_for_user ---
@@ -263,15 +292,17 @@ async def test_set_default_address_success(self):
self.mock_address_crud.get_address_by_id.side_effect = [
self.sample_address_dict_camel, # For initial check
- {**self.sample_address_dict_camel, "IsDefault": True} # For final response
+ {**self.sample_address_dict_camel, "IsDefault": True}, # For final response
]
self.mock_address_crud.set_all_other_addresses_non_default_for_user.return_value = 1
self.mock_address_crud.update_address_is_default_flag.return_value = True
self.mock_user_crud.update_user_default_address_id.return_value = True
response = await self.address_service.set_default_address_for_user(
- db=self.mock_db_conn, user_id=self.test_user_id,
- address_id_to_set_default=address_id_to_set, actor_id=self.test_actor_id
+ db=self.mock_db_conn,
+ user_id=self.test_user_id,
+ address_id_to_set_default=address_id_to_set,
+ actor_id=self.test_actor_id,
)
self.assertIsInstance(response, SetDefaultAddressResponse)
self.assertTrue(response.DefaultAddress.IsDefault) # type: ignore
@@ -284,20 +315,25 @@ async def test_set_default_address_address_not_found_initial_get(self):
self.mock_address_crud.get_address_by_id.return_value = None
with self.assertRaisesRegex(AddressNotFoundException, "Address with ID 999 not found."):
await self.address_service.set_default_address_for_user(
- db=self.mock_db_conn, user_id=self.test_user_id,
- address_id_to_set_default=999, actor_id=self.test_actor_id
+ db=self.mock_db_conn,
+ user_id=self.test_user_id,
+ address_id_to_set_default=999,
+ actor_id=self.test_actor_id,
)
async def test_set_default_address_ownership_mismatch(self):
address_id_to_set = self.sample_address_dict_camel["AddressID"]
address_of_other = {**self.sample_address_dict_camel, "UserID": self.another_user_id}
self.mock_address_crud.get_address_by_id.return_value = address_of_other
- with self.assertRaisesRegex(AddressNotFoundException,
- f"Address with ID {address_id_to_set} does not belong to UserID {self.test_user_id}."):
+ with self.assertRaisesRegex(
+ AddressNotFoundException,
+ f"Address with ID {address_id_to_set} does not belong to UserID {self.test_user_id}.",
+ ):
await self.address_service.set_default_address_for_user(
- db=self.mock_db_conn, user_id=self.test_user_id,
+ db=self.mock_db_conn,
+ user_id=self.test_user_id,
address_id_to_set_default=address_id_to_set,
- actor_id=self.test_actor_id
+ actor_id=self.test_actor_id,
)
# --- 测试 delete_address_for_user ---
@@ -308,8 +344,10 @@ async def test_delete_address_success_not_default(self):
self.mock_address_crud.delete_address.return_value = True
result = await self.address_service.delete_address_for_user(
- db=self.mock_db_conn, address_id=address_id_to_delete,
- user_id_making_request=self.test_user_id, actor_id=self.test_actor_id
+ db=self.mock_db_conn,
+ address_id=address_id_to_delete,
+ user_id_making_request=self.test_user_id,
+ actor_id=self.test_actor_id,
)
self.assertTrue(result)
self.mock_address_crud.delete_address.assert_called_once_with(
@@ -320,14 +358,18 @@ async def test_delete_address_success_not_default(self):
async def test_delete_address_success_was_default(self):
address_id_to_delete = self.sample_address_dict_camel["AddressID"]
self.mock_address_crud.get_address_by_id.return_value = self.sample_address_dict_camel
- self.mock_user_crud.get_user_by_id.return_value = {**self.sample_user_dict_camel,
- "DefaultAddressID": address_id_to_delete}
+ self.mock_user_crud.get_user_by_id.return_value = {
+ **self.sample_user_dict_camel,
+ "DefaultAddressID": address_id_to_delete,
+ }
self.mock_address_crud.delete_address.return_value = True
self.mock_user_crud.update_user_default_address_id.return_value = True
result = await self.address_service.delete_address_for_user(
- db=self.mock_db_conn, address_id=address_id_to_delete,
- user_id_making_request=self.test_user_id, actor_id=self.test_actor_id
+ db=self.mock_db_conn,
+ address_id=address_id_to_delete,
+ user_id_making_request=self.test_user_id,
+ actor_id=self.test_actor_id,
)
self.assertTrue(result)
self.mock_user_crud.update_user_default_address_id.assert_called_once_with(
@@ -338,8 +380,10 @@ async def test_delete_address_address_not_found_by_get(self):
self.mock_address_crud.get_address_by_id.return_value = None
with self.assertRaisesRegex(AddressNotFoundException, "Address with ID 999 not found."):
await self.address_service.delete_address_for_user(
- db=self.mock_db_conn, address_id=999,
- user_id_making_request=self.test_user_id, actor_id=self.test_actor_id
+ db=self.mock_db_conn,
+ address_id=999,
+ user_id_making_request=self.test_user_id,
+ actor_id=self.test_actor_id,
)
async def test_delete_address_ownership_mismatch_raises_not_found(self):
@@ -348,14 +392,16 @@ async def test_delete_address_ownership_mismatch_raises_not_found(self):
address_of_other = {**self.sample_address_dict_camel, "UserID": self.another_user_id}
self.mock_address_crud.get_address_by_id.return_value = address_of_other
- with self.assertRaisesRegex(AddressNotFoundException,
- f"Address with ID {address_id_to_delete} not found for user {self.test_user_id}."):
+ with self.assertRaisesRegex(
+ AddressNotFoundException, f"Address with ID {address_id_to_delete} not found for user {self.test_user_id}."
+ ):
await self.address_service.delete_address_for_user(
- db=self.mock_db_conn, address_id=address_id_to_delete,
+ db=self.mock_db_conn,
+ address_id=address_id_to_delete,
user_id_making_request=self.test_user_id,
- actor_id=self.test_actor_id
+ actor_id=self.test_actor_id,
)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/src/backend/test/unit/service/test_auth_service.py b/src/backend/test/unit/service/test_auth_service.py
index 08851d8..936f804 100644
--- a/src/backend/test/unit/service/test_auth_service.py
+++ b/src/backend/test/unit/service/test_auth_service.py
@@ -31,15 +31,15 @@ def setUp(self):
self.mock_decode_jwt_func = MagicMock()
self.mock_conn = MagicMock(spec=Connection)
- self.settings_patcher = patch('backend.app.services.auth_service.settings', MockSettings())
+ self.settings_patcher = patch("backend.app.services.auth_service.settings", MockSettings())
self.mock_settings = self.settings_patcher.start()
self.addCleanup(self.settings_patcher.stop)
- self.uuid_patcher = patch('backend.app.services.auth_service.uuid.uuid4')
+ self.uuid_patcher = patch("backend.app.services.auth_service.uuid.uuid4")
self.mock_uuid4 = self.uuid_patcher.start()
- self.mock_uuid4.return_value = uuid.UUID('12345678-1234-5678-1234-567812345678')
+ self.mock_uuid4.return_value = uuid.UUID("12345678-1234-5678-1234-567812345678")
- self.logger_patcher = patch('backend.app.services.auth_service.logger')
+ self.logger_patcher = patch("backend.app.services.auth_service.logger")
self.mock_logger = self.logger_patcher.start()
self.addCleanup(self.logger_patcher.stop)
@@ -48,7 +48,7 @@ def setUp(self):
user_session_crud=self.mock_user_session_crud,
verify_password_func=self.mock_verify_password_func,
create_jwt_func=self.mock_create_jwt_func,
- decode_jwt_func=self.mock_decode_jwt_func
+ decode_jwt_func=self.mock_decode_jwt_func,
)
async def test_authenticate_user_success_by_email(self):
@@ -63,7 +63,7 @@ async def test_authenticate_user_success_by_email(self):
"PhoneNumber": None,
# CRUD might return more fields from DDL, but UserInDBForAuth will only pick up defined ones (or their aliases)
"UserRole": "customer",
- "AccountStatus": "ACTIVE"
+ "AccountStatus": "ACTIVE",
}
self.mock_user_crud.get_user_with_password_by_email.return_value = mock_user_data_from_crud
self.mock_user_crud.get_user_with_password_by_username.return_value = None
@@ -99,7 +99,7 @@ async def test_authenticate_user_success_by_username(self):
"PasswordHash": "hashed_password2",
"PhoneNumber": "1234567890",
"UserRole": "customer",
- "AccountStatus": "ACTIVE"
+ "AccountStatus": "ACTIVE",
}
self.mock_user_crud.get_user_with_password_by_email.return_value = None
self.mock_user_crud.get_user_with_password_by_username.return_value = mock_user_data_from_crud
@@ -142,7 +142,7 @@ async def test_authenticate_user_incorrect_password(self):
"Email": "test@example.com",
"PasswordHash": "hashed_password",
"PhoneNumber": None,
- "AccountStatus": "ACTIVE"
+ "AccountStatus": "ACTIVE",
}
self.mock_user_crud.get_user_with_password_by_email.return_value = mock_user_data_from_crud
self.mock_verify_password_func.return_value = False # Password verification fails
@@ -157,18 +157,23 @@ async def test_login_user_success(self):
# This part simulates the *output* of authenticate_user, which is a Pydantic model.
# The Pydantic model (UserInDBForAuth) will have snake_case attributes.
authenticated_user_pydantic_object = UserInDBForAuth(
- UserID=1, Username="testuser", Email="test@example.com",
- PasswordHash="hashed_password", PhoneNumber=None, AccountStatus="ACTIVE" # Assuming UserInDBForAuth fields
+ UserID=1,
+ Username="testuser",
+ Email="test@example.com",
+ PasswordHash="hashed_password",
+ PhoneNumber=None,
+ AccountStatus="ACTIVE", # Assuming UserInDBForAuth fields
)
- with patch.object(self.auth_service, 'authenticate_user',
- AsyncMock(return_value=authenticated_user_pydantic_object)) as mock_auth_user:
+ with patch.object(
+ self.auth_service, "authenticate_user", AsyncMock(return_value=authenticated_user_pydantic_object)
+ ) as mock_auth_user:
mock_exp_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(minutes=30)
self.mock_decode_jwt_func.return_value = TokenPayload(
sub=str(authenticated_user_pydantic_object.UserID),
user_id=authenticated_user_pydantic_object.UserID,
jti=str(self.mock_uuid4.return_value),
- exp=mock_exp_time
+ exp=mock_exp_time,
)
login_data = UserLogin(UsernameOrEmail="test@example.com", Password="password123")
@@ -181,9 +186,9 @@ async def test_login_user_success(self):
)
self.mock_create_jwt_func.assert_called_once()
create_jwt_call_args = self.mock_create_jwt_func.call_args.args[0]
- self.assertEqual(create_jwt_call_args['sub'], str(authenticated_user_pydantic_object.Username))
- self.assertEqual(create_jwt_call_args['user_id'], authenticated_user_pydantic_object.UserID)
- self.assertEqual(create_jwt_call_args['jti'], str(self.mock_uuid4.return_value))
+ self.assertEqual(create_jwt_call_args["sub"], str(authenticated_user_pydantic_object.Username))
+ self.assertEqual(create_jwt_call_args["user_id"], authenticated_user_pydantic_object.UserID)
+ self.assertEqual(create_jwt_call_args["jti"], str(self.mock_uuid4.return_value))
self.mock_decode_jwt_func.assert_called_once_with("mocked.jwt.token")
@@ -193,7 +198,7 @@ async def test_login_user_success(self):
user_id=authenticated_user_pydantic_object.UserID,
expires_at=mock_exp_time,
ip_address="1.2.3.4",
- user_agent="TestAgent"
+ user_agent="TestAgent",
)
self.assertIsInstance(token_response, Token)
self.assertEqual(token_response.access_token, "mocked.jwt.token")
@@ -206,16 +211,22 @@ async def test_login_user_success(self):
# I'll include them for completeness, assuming no major changes needed there based on this specific request.
async def test_login_user_authentication_failure(self):
- with patch.object(self.auth_service, 'authenticate_user', AsyncMock(return_value=None)) as mock_auth_user:
+ with patch.object(self.auth_service, "authenticate_user", AsyncMock(return_value=None)) as mock_auth_user:
login_data = UserLogin(UsernameOrEmail="wrong@example.com", Password="wrongpassword")
with self.assertRaisesRegex(AuthenticationException, "Incorrect identifier or password."):
await self.auth_service.login_user(self.mock_conn, login_data=login_data)
mock_auth_user.assert_called_once()
async def test_login_user_failed_to_determine_expiration(self):
- authenticated_user_data = UserInDBForAuth(UserID=1, Username="testuser", Email="test@example.com",
- PasswordHash="hp", PhoneNumber=None, AccountStatus="ACTIVE")
- with patch.object(self.auth_service, 'authenticate_user', AsyncMock(return_value=authenticated_user_data)):
+ authenticated_user_data = UserInDBForAuth(
+ UserID=1,
+ Username="testuser",
+ Email="test@example.com",
+ PasswordHash="hp",
+ PhoneNumber=None,
+ AccountStatus="ACTIVE",
+ )
+ with patch.object(self.auth_service, "authenticate_user", AsyncMock(return_value=authenticated_user_data)):
self.mock_decode_jwt_func.return_value = None
login_data = UserLogin(UsernameOrEmail="test@example.com", Password="password123")
with self.assertRaisesRegex(Exception, "Failed to determine token expiration for session storage."):
@@ -223,8 +234,12 @@ async def test_login_user_failed_to_determine_expiration(self):
async def test_logout_user_session_success(self):
mock_jti = str(self.mock_uuid4.return_value)
- mock_payload = TokenPayload(sub="1", user_id=1, jti=mock_jti,
- exp=datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(minutes=5))
+ mock_payload = TokenPayload(
+ sub="1",
+ user_id=1,
+ jti=mock_jti,
+ exp=datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(minutes=5),
+ )
self.mock_decode_jwt_func.return_value = mock_payload
self.mock_user_session_crud.delete_session_by_token.return_value = True
@@ -249,8 +264,12 @@ async def test_logout_user_session_invalid_token_or_no_jti(self):
async def test_logout_user_session_delete_fails(self):
mock_jti = str(self.mock_uuid4.return_value)
- mock_payload = TokenPayload(sub="1", user_id=1, jti=mock_jti,
- exp=datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(minutes=5))
+ mock_payload = TokenPayload(
+ sub="1",
+ user_id=1,
+ jti=mock_jti,
+ exp=datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(minutes=5),
+ )
self.mock_decode_jwt_func.return_value = mock_payload
self.mock_user_session_crud.delete_session_by_token.return_value = False
@@ -262,9 +281,7 @@ async def test_logout_user_session_delete_fails(self):
async def test_logout_all_user_sessions_success(self):
self.mock_user_session_crud.delete_all_sessions_for_user.return_value = 5
- count = await self.auth_service.logout_all_user_sessions(
- self.mock_conn, user_id_to_logout=1, actor_user_id=1
- )
+ count = await self.auth_service.logout_all_user_sessions(self.mock_conn, user_id_to_logout=1, actor_user_id=1)
self.assertEqual(count, 5)
self.mock_user_session_crud.delete_all_sessions_for_user.assert_called_once_with(
self.mock_conn, user_id=1, actor_user_id=1
@@ -274,15 +291,13 @@ async def test_logout_all_user_sessions_actor_mismatch_still_proceeds_for_now(se
self.mock_user_session_crud.delete_all_sessions_for_user.return_value = 2
await self.auth_service.logout_all_user_sessions(
- self.mock_conn, user_id_to_logout=1, actor_user_id=2 # Mismatch
- )
- self.mock_logger.warning.assert_called_once_with(
- f"Actor ID 2 attempting to logout all sessions for user ID 1."
- )
+ self.mock_conn, user_id_to_logout=1, actor_user_id=2
+ ) # Mismatch
+ self.mock_logger.warning.assert_called_once_with(f"Actor ID 2 attempting to logout all sessions for user ID 1.")
self.mock_user_session_crud.delete_all_sessions_for_user.assert_called_once_with(
self.mock_conn, user_id=1, actor_user_id=2
)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/src/backend/test/unit/service/test_cart_service.py b/src/backend/test/unit/service/test_cart_service.py
index e7d0a93..59ebc60 100644
--- a/src/backend/test/unit/service/test_cart_service.py
+++ b/src/backend/test/unit/service/test_cart_service.py
@@ -10,8 +10,12 @@
from backend.app.crud.cartitem_crud import CartItemCRUD
from backend.app.crud.product_crud import ProductCRUD
from backend.app.schemas.cartitem_schema import CartResponse, CartItemResponse
-from backend.app.utils.exceptions import ProductNotFoundException, CartItemNotFoundException, PermissionDeniedException, \
- ProductFieldMissingException
+from backend.app.utils.exceptions import (
+ ProductNotFoundException,
+ CartItemNotFoundException,
+ PermissionDeniedException,
+ ProductFieldMissingException,
+)
# 假设 Connection 类型来自于 sqlalchemy.engine.base
from sqlalchemy.engine.base import Connection
@@ -25,14 +29,11 @@ def setUp(self):
# Patch the logger as it's used in the cart_service module
# 确保 'backend.app.services.cart_service.logger' 是 logger 在 CartService 文件中正确的导入路径
- self.logger_patcher = patch('backend.app.services.cart_service.logger')
+ self.logger_patcher = patch("backend.app.services.cart_service.logger")
self.mock_logger = self.logger_patcher.start()
self.addCleanup(self.logger_patcher.stop)
- self.cart_service = CartService(
- cart_item_crud=self.mock_cart_item_crud,
- product_crud=self.mock_product_crud
- )
+ self.cart_service = CartService(cart_item_crud=self.mock_cart_item_crud, product_crud=self.mock_product_crud)
self.mock_db_conn = MagicMock(spec=Connection)
self.test_user_id = 1
self.test_actor_id = 1
@@ -40,9 +41,7 @@ def setUp(self):
async def test_get_user_cart_details_empty_cart(self):
self.mock_cart_item_crud.get_cart_items_by_user_id.return_value = []
- cart_response = await self.cart_service.get_user_cart_details(
- db=self.mock_db_conn, user_id=self.test_user_id
- )
+ cart_response = await self.cart_service.get_user_cart_details(db=self.mock_db_conn, user_id=self.test_user_id)
self.assertIsInstance(cart_response, CartResponse)
self.assertEqual(len(cart_response.Items), 0)
@@ -55,20 +54,26 @@ async def test_get_user_cart_details_empty_cart(self):
async def test_get_user_cart_details_with_items(self):
mock_item_data_1 = {
- "CartItemID": 1, "UserID": self.test_user_id, "ProductID": 101,
- "Quantity": 2, "PriceAtAddition": Decimal("10.00"), # CRUD returns Decimal
- "AddedDate": datetime.datetime.now(UTC), "ProductName": "Product A"
+ "CartItemID": 1,
+ "UserID": self.test_user_id,
+ "ProductID": 101,
+ "Quantity": 2,
+ "PriceAtAddition": Decimal("10.00"), # CRUD returns Decimal
+ "AddedDate": datetime.datetime.now(UTC),
+ "ProductName": "Product A",
}
mock_item_data_2 = {
- "CartItemID": 2, "UserID": self.test_user_id, "ProductID": 102,
- "Quantity": 1, "PriceAtAddition": Decimal("20.50"),
- "AddedDate": datetime.datetime.now(UTC), "ProductName": "Product B"
+ "CartItemID": 2,
+ "UserID": self.test_user_id,
+ "ProductID": 102,
+ "Quantity": 1,
+ "PriceAtAddition": Decimal("20.50"),
+ "AddedDate": datetime.datetime.now(UTC),
+ "ProductName": "Product B",
}
self.mock_cart_item_crud.get_cart_items_by_user_id.return_value = [mock_item_data_1, mock_item_data_2]
- cart_response = await self.cart_service.get_user_cart_details(
- db=self.mock_db_conn, user_id=self.test_user_id
- )
+ cart_response = await self.cart_service.get_user_cart_details(db=self.mock_db_conn, user_id=self.test_user_id)
self.assertEqual(len(cart_response.Items), 2)
self.assertEqual(cart_response.TotalItems, 2)
@@ -84,9 +89,7 @@ async def test_get_user_cart_details_pydantic_error(self):
self.mock_cart_item_crud.get_cart_items_by_user_id.return_value = mock_invalid_item_data
with self.assertRaises(Exception): # Or more specific PydanticValidationError if you import it
- await self.cart_service.get_user_cart_details(
- db=self.mock_db_conn, user_id=self.test_user_id
- )
+ await self.cart_service.get_user_cart_details(db=self.mock_db_conn, user_id=self.test_user_id)
self.mock_logger.error.assert_called_once()
async def test_add_item_to_cart_new_item_success(self):
@@ -95,16 +98,23 @@ async def test_add_item_to_cart_new_item_success(self):
actor_id = self.test_user_id
mock_product_data = {"ProductID": product_id, "ProductName": "Test Product", "Price": Decimal("15.75")}
mock_added_cart_item_data = {
- "CartItemID": 10, "UserID": self.test_user_id, "ProductID": product_id,
- "Quantity": quantity, "PriceAtAddition": 15.75, "AddedDate": datetime.datetime.now(UTC)
+ "CartItemID": 10,
+ "UserID": self.test_user_id,
+ "ProductID": product_id,
+ "Quantity": quantity,
+ "PriceAtAddition": 15.75,
+ "AddedDate": datetime.datetime.now(UTC),
}
self.mock_product_crud.get_product_by_id.return_value = mock_product_data
self.mock_cart_item_crud.add_item_to_cart.return_value = mock_added_cart_item_data
result = await self.cart_service.add_item_to_cart(
- self.mock_db_conn, user_id=self.test_user_id, product_id=product_id,
- quantity_to_add=quantity, actor_id=actor_id
+ self.mock_db_conn,
+ user_id=self.test_user_id,
+ product_id=product_id,
+ quantity_to_add=quantity,
+ actor_id=actor_id,
)
self.assertIsInstance(result, CartItemResponse)
@@ -114,32 +124,47 @@ async def test_add_item_to_cart_new_item_success(self):
conn=self.mock_db_conn, product_id=product_id, actor_id=actor_id
)
self.mock_cart_item_crud.add_item_to_cart.assert_called_once_with(
- conn=self.mock_db_conn, user_id=self.test_user_id, product_id=product_id,
- quantity=quantity, price_at_addition=15.75, actor_id=actor_id
+ conn=self.mock_db_conn,
+ user_id=self.test_user_id,
+ product_id=product_id,
+ quantity=quantity,
+ price_at_addition=15.75,
+ actor_id=actor_id,
)
async def test_add_item_to_cart_product_not_found(self):
self.mock_product_crud.get_product_by_id.return_value = None
with self.assertRaises(ProductNotFoundException):
await self.cart_service.add_item_to_cart(
- self.mock_db_conn, user_id=self.test_user_id, product_id=999,
- quantity_to_add=1, actor_id=self.test_actor_id
+ self.mock_db_conn,
+ user_id=self.test_user_id,
+ product_id=999,
+ quantity_to_add=1,
+ actor_id=self.test_actor_id,
)
async def test_add_item_to_cart_product_no_price(self):
- self.mock_product_crud.get_product_by_id.return_value = {"ProductID": 202,
- "ProductName": "No Price Product"} # Price is None
+ self.mock_product_crud.get_product_by_id.return_value = {
+ "ProductID": 202,
+ "ProductName": "No Price Product",
+ } # Price is None
with self.assertRaisesRegex(ProductFieldMissingException, "Price not available for product 202"):
await self.cart_service.add_item_to_cart(
- self.mock_db_conn, user_id=self.test_user_id, product_id=202,
- quantity_to_add=1, actor_id=self.test_actor_id
+ self.mock_db_conn,
+ user_id=self.test_user_id,
+ product_id=202,
+ quantity_to_add=1,
+ actor_id=self.test_actor_id,
)
async def test_add_item_to_cart_invalid_quantity(self):
with self.assertRaisesRegex(ValueError, "Quantity to add must be positive."):
await self.cart_service.add_item_to_cart(
- self.mock_db_conn, user_id=self.test_user_id, product_id=201,
- quantity_to_add=0, actor_id=self.test_actor_id
+ self.mock_db_conn,
+ user_id=self.test_user_id,
+ product_id=201,
+ quantity_to_add=0,
+ actor_id=self.test_actor_id,
)
async def test_add_item_to_cart_crud_fails(self):
@@ -149,8 +174,11 @@ async def test_add_item_to_cart_crud_fails(self):
with self.assertRaisesRegex(Exception, "Could not add item to cart."):
await self.cart_service.add_item_to_cart(
- self.mock_db_conn, user_id=self.test_user_id, product_id=201,
- quantity_to_add=1, actor_id=self.test_actor_id
+ self.mock_db_conn,
+ user_id=self.test_user_id,
+ product_id=201,
+ quantity_to_add=1,
+ actor_id=self.test_actor_id,
)
async def test_update_cart_item_quantity_success(self):
@@ -158,16 +186,22 @@ async def test_update_cart_item_quantity_success(self):
new_quantity = 5
mock_existing_item = {"CartItemID": cart_item_id, "UserID": self.test_user_id, "ProductID": 101}
mock_updated_item_data = {
- "CartItemID": cart_item_id, "UserID": self.test_user_id, "ProductID": 101,
- "Quantity": new_quantity, "PriceAtAddition": 10.00, "AddedDate": datetime.datetime.now(UTC)
+ "CartItemID": cart_item_id,
+ "UserID": self.test_user_id,
+ "ProductID": 101,
+ "Quantity": new_quantity,
+ "PriceAtAddition": 10.00,
+ "AddedDate": datetime.datetime.now(UTC),
}
self.mock_cart_item_crud.get_cart_item_by_id.return_value = mock_existing_item
self.mock_cart_item_crud.update_cart_item_quantity.return_value = mock_updated_item_data
result = await self.cart_service.update_cart_item_quantity(
- self.mock_db_conn, cart_item_id=cart_item_id, new_quantity=new_quantity,
- user_id_making_change=self.test_user_id
+ self.mock_db_conn,
+ cart_item_id=cart_item_id,
+ new_quantity=new_quantity,
+ user_id_making_change=self.test_user_id,
)
self.assertIsInstance(result, CartItemResponse)
self.assertEqual(result.Quantity, new_quantity)
@@ -175,16 +209,14 @@ async def test_update_cart_item_quantity_success(self):
conn=self.mock_db_conn, cart_item_id=cart_item_id, actor_id=self.test_user_id
)
self.mock_cart_item_crud.update_cart_item_quantity.assert_called_once_with(
- conn=self.mock_db_conn, cart_item_id=cart_item_id, new_quantity=new_quantity,
- actor_id=self.test_user_id
+ conn=self.mock_db_conn, cart_item_id=cart_item_id, new_quantity=new_quantity, actor_id=self.test_user_id
)
async def test_update_cart_item_quantity_item_not_found(self):
self.mock_cart_item_crud.get_cart_item_by_id.return_value = None
with self.assertRaises(CartItemNotFoundException):
await self.cart_service.update_cart_item_quantity(
- self.mock_db_conn, cart_item_id=999, new_quantity=5,
- user_id_making_change=self.test_user_id
+ self.mock_db_conn, cart_item_id=999, new_quantity=5, user_id_making_change=self.test_user_id
)
async def test_update_cart_item_quantity_permission_denied_logs_warning(self):
@@ -195,16 +227,22 @@ async def test_update_cart_item_quantity_permission_denied_logs_warning(self):
mock_existing_item = {"CartItemID": cart_item_id, "UserID": other_user_id, "ProductID": 102}
# Assume update still happens if no exception is raised by service
mock_updated_item_data = {
- "CartItemID": cart_item_id, "UserID": other_user_id, "ProductID": 102,
- "Quantity": 3, "PriceAtAddition": 10.00, "AddedDate": datetime.datetime.now(UTC)
+ "CartItemID": cart_item_id,
+ "UserID": other_user_id,
+ "ProductID": 102,
+ "Quantity": 3,
+ "PriceAtAddition": 10.00,
+ "AddedDate": datetime.datetime.now(UTC),
}
self.mock_cart_item_crud.get_cart_item_by_id.return_value = mock_existing_item
self.mock_cart_item_crud.update_cart_item_quantity.return_value = mock_updated_item_data
await self.cart_service.update_cart_item_quantity(
- self.mock_db_conn, cart_item_id=cart_item_id, new_quantity=3,
- user_id_making_change=self.test_user_id # User 1 tries to update User 2's item
+ self.mock_db_conn,
+ cart_item_id=cart_item_id,
+ new_quantity=3,
+ user_id_making_change=self.test_user_id, # User 1 tries to update User 2's item
)
self.mock_logger.warning.assert_called_with(
f"User {self.test_user_id} tries to update CartItemID {cart_item_id} belonging to user {other_user_id}."
@@ -215,8 +253,7 @@ async def test_update_cart_item_quantity_permission_denied_logs_warning(self):
async def test_update_cart_item_quantity_invalid_quantity(self):
with self.assertRaisesRegex(ValueError, "Quantity must be positive. Use remove item for deletion."):
await self.cart_service.update_cart_item_quantity(
- self.mock_db_conn, cart_item_id=30, new_quantity=0,
- user_id_making_change=self.test_user_id
+ self.mock_db_conn, cart_item_id=30, new_quantity=0, user_id_making_change=self.test_user_id
)
async def test_remove_cart_item_success(self):
@@ -258,14 +295,12 @@ async def test_clear_cart_actor_mismatch_logs_warning(self):
other_actor_id = self.test_actor_id + 1
self.mock_cart_item_crud.clear_cart_for_user.return_value = 0
- await self.cart_service.clear_cart(
- self.mock_db_conn, user_id=self.test_user_id, actor_id=other_actor_id
- )
+ await self.cart_service.clear_cart(self.mock_db_conn, user_id=self.test_user_id, actor_id=other_actor_id)
self.mock_logger.warning.assert_called_with(
f"ActorID {other_actor_id} is clearing cart for UserID {self.test_user_id}."
)
self.mock_cart_item_crud.clear_cart_for_user.assert_called_once()
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/src/backend/test/unit/service/test_order_service.py b/src/backend/test/unit/service/test_order_service.py
index 247c757..758ad9b 100644
--- a/src/backend/test/unit/service/test_order_service.py
+++ b/src/backend/test/unit/service/test_order_service.py
@@ -9,6 +9,7 @@
from loguru import logger
from backend.app.schemas.payment_schema import PaymentProcessingResponse, PaymentResponse
+
# 调整导入路径以匹配您的项目结构
from backend.app.services.order_service import OrderService
from backend.app.crud.order_crud import OrderCRUD
@@ -28,7 +29,7 @@
CreatedOrderDetailResponse,
OrderItemDetailResponse,
OrderUpdateStatusRequest,
- OrderActionResponse
+ OrderActionResponse,
)
from backend.app.schemas.address_schema import AddressResponse # For mocking address
from backend.app.utils.exceptions import (
@@ -39,7 +40,9 @@
InsufficientStockException,
InvalidStatusTransitionException,
ProductNotFoundException,
- ProductFieldMissingException, PaymentTransactionNotFoundException, InvalidPaymentStatusTransitionException
+ ProductFieldMissingException,
+ PaymentTransactionNotFoundException,
+ InvalidPaymentStatusTransitionException,
)
# 假设 Connection 类型来自于 sqlalchemy.engine.base
@@ -62,7 +65,7 @@ def setUp(self):
address_crud=self.mock_address_crud,
cart_item_crud=self.mock_cart_item_crud,
product_crud=self.mock_product_crud,
- payment_transaction_crud=self.mock_payment_transaction_crud
+ payment_transaction_crud=self.mock_payment_transaction_crud,
)
self.mock_db_conn = MagicMock(spec=Connection)
@@ -75,36 +78,73 @@ def setUp(self):
self.system_actor_id = 0 # Example system actor ID
self.sample_address_data = {
- "AddressID": 10, "UserID": self.test_user_id,
- "RecipientName": "Test Recipient", "PhoneNumber": "1234567890",
- "FullAddress_Text": "123 Test St", "IsDefault": True
+ "AddressID": 10,
+ "UserID": self.test_user_id,
+ "RecipientName": "Test Recipient",
+ "PhoneNumber": "1234567890",
+ "FullAddress_Text": "123 Test St",
+ "IsDefault": True,
}
self.sample_cart_item_data_1 = {
- "CartItemID": 101, "UserID": self.test_user_id, "ProductID": 201,
- "Quantity": 2, "PriceAtAddition": Decimal("10.00"), "StoreID": 301,
- "_product_info": {"ProductID": 201, "Name": "Product A", "Price": Decimal("10.00"), "ImageURL": "a.jpg",
- "Stock": 10, "StoreID": 301}
+ "CartItemID": 101,
+ "UserID": self.test_user_id,
+ "ProductID": 201,
+ "Quantity": 2,
+ "PriceAtAddition": Decimal("10.00"),
+ "StoreID": 301,
+ "_product_info": {
+ "ProductID": 201,
+ "Name": "Product A",
+ "Price": Decimal("10.00"),
+ "ImageURL": "a.jpg",
+ "Stock": 10,
+ "StoreID": 301,
+ },
}
self.sample_cart_item_data_2 = {
- "CartItemID": 102, "UserID": self.test_user_id, "ProductID": 202,
- "Quantity": 1, "PriceAtAddition": Decimal("25.00"), "StoreID": 302,
- "_product_info": {"ProductID": 202, "Name": "Product B", "Price": Decimal("25.00"), "ImageURL": "b.jpg",
- "Stock": 5, "StoreID": 302}
+ "CartItemID": 102,
+ "UserID": self.test_user_id,
+ "ProductID": 202,
+ "Quantity": 1,
+ "PriceAtAddition": Decimal("25.00"),
+ "StoreID": 302,
+ "_product_info": {
+ "ProductID": 202,
+ "Name": "Product B",
+ "Price": Decimal("25.00"),
+ "ImageURL": "b.jpg",
+ "Stock": 5,
+ "StoreID": 302,
+ },
+ }
+ self.sample_product_data_1 = {
+ "ProductID": 201,
+ "ProductName": "Product A",
+ "Price": Decimal("10.00"),
+ "MainImageURL": "a.jpg",
+ "StockQuantity": 10,
+ "StoreID": 301,
+ }
+ self.sample_product_data_2 = {
+ "ProductID": 202,
+ "ProductName": "Product B",
+ "Price": Decimal("25.00"),
+ "MainImageURL": "b.jpg",
+ "StockQuantity": 5,
+ "StoreID": 302,
}
- self.sample_product_data_1 = {"ProductID": 201, "ProductName": "Product A", "Price": Decimal("10.00"),
- "MainImageURL": "a.jpg", "StockQuantity": 10, "StoreID": 301}
- self.sample_product_data_2 = {"ProductID": 202, "ProductName": "Product B", "Price": Decimal("25.00"),
- "MainImageURL": "b.jpg", "StockQuantity": 5, "StoreID": 302}
self.sample_payment_transaction_data = {
- "PaymentTransactionID": 1001, "UserID": self.test_user_id,
- "TotalAmount": Decimal("45.00"), "Status": PaymentTransactionStatusEnum.PENDING.value,
+ "PaymentTransactionID": 1001,
+ "UserID": self.test_user_id,
+ "TotalAmount": Decimal("45.00"),
+ "Status": PaymentTransactionStatusEnum.PENDING.value,
# Add other fields from PaymentTransaction DDL if needed by schemas
"PaymentMethod": "MOCK_PAY",
"ExternalGatewayTransactionID": None,
"CreationTime": datetime.datetime.utcnow(),
"CompletionTime": None,
- "LastUpdatedDate": datetime.datetime.utcnow()
+ "LastUpdatedDate": datetime.datetime.utcnow(),
}
# ⭐ More complete sample order data based on DDL and OrderViewResponse
@@ -117,7 +157,7 @@ def setUp(self):
"ExternalGatewayTransactionID": None,
"CreationTime": now_time,
"CompletionTime": None,
- "LastUpdatedDate": now_time
+ "LastUpdatedDate": now_time,
}
self.sample_pending_payment_transaction_data = {
**self.base_payment_transaction_fields,
@@ -128,7 +168,7 @@ def setUp(self):
**self.base_payment_transaction_fields,
"PaymentTransactionID": 1001,
"Status": PaymentTransactionStatusEnum.SUCCESSFUL.value,
- "CompletionTime": now_time + datetime.timedelta(seconds=10)
+ "CompletionTime": now_time + datetime.timedelta(seconds=10),
}
self.base_order_fields = {
@@ -149,7 +189,7 @@ def setUp(self):
"ShippingTime": None,
"DeliveryTime": None,
"CompletionTime": None,
- "LastUpdatedDate": now_time
+ "LastUpdatedDate": now_time,
}
self.sample_order_data_store301 = {
@@ -168,14 +208,26 @@ def setUp(self):
}
self.sample_order_item_created_data_1 = {
- "OrderItemID": 1001, "OrderID": 1, "ProductID": 201, "StoreID": 301, "Quantity": 2,
- "PriceAtPurchase": Decimal("10.00"), "ProductNameAtPurchase": "Product A",
- "ProductImageURLAtPurchase": "a.jpg", "Subtotal": Decimal("20.00")
+ "OrderItemID": 1001,
+ "OrderID": 1,
+ "ProductID": 201,
+ "StoreID": 301,
+ "Quantity": 2,
+ "PriceAtPurchase": Decimal("10.00"),
+ "ProductNameAtPurchase": "Product A",
+ "ProductImageURLAtPurchase": "a.jpg",
+ "Subtotal": Decimal("20.00"),
}
self.sample_order_item_created_data_2 = {
- "OrderItemID": 1002, "OrderID": 2, "ProductID": 202, "StoreID": 302, "Quantity": 1,
- "PriceAtPurchase": Decimal("25.00"), "ProductNameAtPurchase": "Product B",
- "ProductImageURLAtPurchase": "b.jpg", "Subtotal": Decimal("25.00")
+ "OrderItemID": 1002,
+ "OrderID": 2,
+ "ProductID": 202,
+ "StoreID": 302,
+ "Quantity": 1,
+ "PriceAtPurchase": Decimal("25.00"),
+ "ProductNameAtPurchase": "Product B",
+ "ProductImageURLAtPurchase": "b.jpg",
+ "Subtotal": Decimal("25.00"),
}
# --- Test process_order_creation ---
@@ -183,20 +235,19 @@ async def test_process_order_creation_success(self):
order_create_req = OrderCreateRequest(
ShippingAddressID=10,
Items=[OrderItemCreationInput(CartItemID=101), OrderItemCreationInput(CartItemID=102)],
- Notes_ByUser="Test notes"
+ Notes_ByUser="Test notes",
)
self.mock_address_crud.get_address_by_id.return_value = self.sample_address_data
self.mock_cart_item_crud.get_cart_item_by_id.side_effect = [
self.sample_cart_item_data_1,
- self.sample_cart_item_data_2
+ self.sample_cart_item_data_2,
]
self.mock_cart_item_crud.remove_item_from_cart.return_value = True
- self.mock_product_crud.get_product_by_id.side_effect = [
- self.sample_product_data_1,
- self.sample_product_data_2
- ]
+ self.mock_product_crud.get_product_by_id.side_effect = [self.sample_product_data_1, self.sample_product_data_2]
self.mock_product_crud.update_product_stock.return_value = True
- self.mock_payment_transaction_crud.create_payment_transaction.return_value = self.sample_payment_transaction_data
+ self.mock_payment_transaction_crud.create_payment_transaction.return_value = (
+ self.sample_payment_transaction_data
+ )
# Mock OrderCRUD.create_order to return data that can fulfill CreatedOrderDetailResponse
# CreatedOrderDetailResponse needs: OrderID, StoreID, FinalAmountForThisOrder, OrderStatus, Items
@@ -205,17 +256,19 @@ async def test_process_order_creation_success(self):
mock_created_order_2_for_service = {**self.sample_order_data_store302} # Full data
self.mock_order_crud.create_order.side_effect = [
mock_created_order_1_for_service,
- mock_created_order_2_for_service
+ mock_created_order_2_for_service,
]
self.mock_order_item_crud.create_order_item.side_effect = [
self.sample_order_item_created_data_1,
- self.sample_order_item_created_data_2
+ self.sample_order_item_created_data_2,
]
response = await self.order_service.process_order_creation(
- db=self.mock_db_conn, user_id=self.test_user_id,
- order_create_request=order_create_req, actor_id=self.test_actor_id
+ db=self.mock_db_conn,
+ user_id=self.test_user_id,
+ order_create_request=order_create_req,
+ actor_id=self.test_actor_id,
)
# logger.debug(f"Response: {response}")
@@ -250,8 +303,10 @@ async def test_process_order_creation_address_not_found(self):
self.mock_address_crud.get_address_by_id.return_value = None
with self.assertRaises(AddressNotFoundException):
await self.order_service.process_order_creation(
- db=self.mock_db_conn, user_id=self.test_user_id,
- order_create_request=order_create_req, actor_id=self.test_actor_id
+ db=self.mock_db_conn,
+ user_id=self.test_user_id,
+ order_create_request=order_create_req,
+ actor_id=self.test_actor_id,
)
async def test_process_order_creation_cart_item_not_found(self):
@@ -260,8 +315,10 @@ async def test_process_order_creation_cart_item_not_found(self):
self.mock_cart_item_crud.get_cart_item_by_id.return_value = None
with self.assertRaises(CartItemNotFoundException):
await self.order_service.process_order_creation(
- db=self.mock_db_conn, user_id=self.test_user_id,
- order_create_request=order_create_req, actor_id=self.test_actor_id
+ db=self.mock_db_conn,
+ user_id=self.test_user_id,
+ order_create_request=order_create_req,
+ actor_id=self.test_actor_id,
)
async def test_process_order_creation_insufficient_stock(self):
@@ -269,14 +326,19 @@ async def test_process_order_creation_insufficient_stock(self):
self.mock_address_crud.get_address_by_id.return_value = self.sample_address_data
cart_item_with_high_qty = {**self.sample_cart_item_data_1, "Quantity": 1000}
self.mock_cart_item_crud.get_cart_item_by_id.return_value = cart_item_with_high_qty
- product_with_low_stock = {**self.sample_product_data_1, "StockQuantity": 1,
- "Stock": 1} # Ensure Stock matches SUT
+ product_with_low_stock = {
+ **self.sample_product_data_1,
+ "StockQuantity": 1,
+ "Stock": 1,
+ } # Ensure Stock matches SUT
self.mock_product_crud.get_product_by_id.return_value = product_with_low_stock
with self.assertRaises(InsufficientStockException):
await self.order_service.process_order_creation(
- db=self.mock_db_conn, user_id=self.test_user_id,
- order_create_request=order_create_req, actor_id=self.test_actor_id
+ db=self.mock_db_conn,
+ user_id=self.test_user_id,
+ order_create_request=order_create_req,
+ actor_id=self.test_actor_id,
)
# The stock check failed, so the rollback of the transaction should be called
self.mock_db_conn.begin_nested.return_value.rollback.assert_called_with()
@@ -287,14 +349,16 @@ async def test_get_orders_for_user_success(self):
# Mock OrderCRUD to return full data for OrderViewResponse
self.mock_order_crud.get_orders_by_user_id.return_value = [
self.sample_order_data_store301,
- self.sample_order_data_store302
+ self.sample_order_data_store302,
]
self.mock_order_item_crud.get_order_items_by_order_id.side_effect = [
[self.sample_order_item_created_data_1],
- [self.sample_order_item_created_data_2]
+ [self.sample_order_item_created_data_2],
]
# Mock PaymentTransactionCRUD for PaymentStatus
- self.mock_payment_transaction_crud.get_payment_transaction_by_id.return_value = self.sample_payment_transaction_data
+ self.mock_payment_transaction_crud.get_payment_transaction_by_id.return_value = (
+ self.sample_payment_transaction_data
+ )
response = await self.order_service.get_orders_for_user(
db=self.mock_db_conn, user_id=self.test_user_id, actor_id=self.test_actor_id, offset=0, limit=10
@@ -313,7 +377,9 @@ async def test_get_order_details_by_id_success(self):
# Mock OrderCRUD to return full data for OrderViewResponse
self.mock_order_crud.get_order_by_id.return_value = self.sample_order_data_store301
self.mock_order_item_crud.get_order_items_by_order_id.return_value = [self.sample_order_item_created_data_1]
- self.mock_payment_transaction_crud.get_payment_transaction_by_id.return_value = self.sample_payment_transaction_data
+ self.mock_payment_transaction_crud.get_payment_transaction_by_id.return_value = (
+ self.sample_payment_transaction_data
+ )
response = await self.order_service.get_order_details_by_id_for_user(
db=self.mock_db_conn, order_id=order_id, user_id=self.test_user_id, actor_id=self.test_actor_id
@@ -322,8 +388,9 @@ async def test_get_order_details_by_id_success(self):
self.assertEqual(response.OrderID, order_id)
self.assertEqual(len(response.Items), 1)
self.assertEqual(response.PaymentStatus, PaymentTransactionStatusEnum.PENDING)
- self.assertEqual(response.ShippingAddress_RecipientName,
- self.sample_order_data_store301["ShippingAddress_RecipientName"])
+ self.assertEqual(
+ response.ShippingAddress_RecipientName, self.sample_order_data_store301["ShippingAddress_RecipientName"]
+ )
async def test_get_order_details_not_found(self):
self.mock_order_crud.get_order_by_id.return_value = None
@@ -332,7 +399,7 @@ async def test_get_order_details_not_found(self):
db=self.mock_db_conn, order_id=999, user_id=self.test_user_id, actor_id=self.test_actor_id
)
- @patch('backend.app.services.order_service.logger')
+ @patch("backend.app.services.order_service.logger")
async def test_get_order_details_permission_denied_logs_warning(self, mock_logger):
order_data_other_user = {**self.sample_order_data_store301, "UserID": self.test_user_id + 10}
self.mock_order_crud.get_order_by_id.return_value = order_data_other_user
@@ -342,30 +409,33 @@ async def test_get_order_details_permission_denied_logs_warning(self, mock_logge
# SUT currently logs warning and then raises OrderNotFoundException due to UserID mismatch
with self.assertRaises(OrderNotFoundException): # Based on current SUT logic
await self.order_service.get_order_details_by_id_for_user(
- db=self.mock_db_conn, order_id=self.sample_order_data_store301["OrderID"],
+ db=self.mock_db_conn,
+ order_id=self.sample_order_data_store301["OrderID"],
user_id=self.test_user_id,
- actor_id=self.test_user_id
+ actor_id=self.test_user_id,
)
mock_logger.warning.assert_any_call(
f"Order with ID {self.sample_order_data_store301['OrderID']} not found or does not belong to user {self.test_user_id}"
)
# --- Test update_order_status ---
- @patch('backend.app.services.order_service.datetime')
+ @patch("backend.app.services.order_service.datetime")
async def test_update_order_status_success_pending_to_paid(self, mock_datetime_module):
fixed_now_aware = datetime.datetime(2025, 5, 20, 10, 0, 0, tzinfo=datetime.timezone.utc)
mock_datetime_module.datetime.now.return_value = fixed_now_aware
order_id = self.sample_order_data_store301["OrderID"]
- current_order_data_from_db = {**self.sample_order_data_store301,
- "OrderStatus": OrderStatusEnum.PENDING_PAYMENT.value}
+ current_order_data_from_db = {
+ **self.sample_order_data_store301,
+ "OrderStatus": OrderStatusEnum.PENDING_PAYMENT.value,
+ }
new_status = OrderStatusEnum.PAID_AND_PENDING_PROCESSING
# This is the data that OrderCRUD.update_order_status would return (its internal get_order_by_id)
mock_data_after_crud_update = {
**current_order_data_from_db,
"OrderStatus": new_status.value,
- "PaymentConfirmationTime": fixed_now_aware.replace(tzinfo=None) # Naive UTC for DB
+ "PaymentConfirmationTime": fixed_now_aware.replace(tzinfo=None), # Naive UTC for DB
}
self.mock_order_crud.update_order_status.return_value = mock_data_after_crud_update
@@ -374,34 +444,43 @@ async def test_update_order_status_success_pending_to_paid(self, mock_datetime_m
# It should reflect the final state.
self.mock_order_crud.get_order_by_id.side_effect = [
current_order_data_from_db, # First call in update_order_status
- mock_data_after_crud_update # Second call from get_order_details_by_id_for_user
+ mock_data_after_crud_update, # Second call from get_order_details_by_id_for_user
]
self.mock_order_item_crud.get_order_items_by_order_id.return_value = []
- self.mock_payment_transaction_crud.get_payment_transaction_by_id.return_value = self.sample_payment_transaction_data
+ self.mock_payment_transaction_crud.get_payment_transaction_by_id.return_value = (
+ self.sample_payment_transaction_data
+ )
updated_order_response = await self.order_service.update_order_status(
- db=self.mock_db_conn, order_id=order_id, new_status=new_status,
- actor_id=self.test_actor_id, is_admin_action=True
+ db=self.mock_db_conn,
+ order_id=order_id,
+ new_status=new_status,
+ actor_id=self.test_actor_id,
+ is_admin_action=True,
)
self.assertEqual(updated_order_response.OrderStatus, new_status)
self.assertEqual(updated_order_response.PaymentConfirmationTime, fixed_now_aware)
self.mock_order_crud.update_order_status.assert_called_once_with(
- self.mock_db_conn, order_id=order_id, new_status=new_status,
+ self.mock_db_conn,
+ order_id=order_id,
+ new_status=new_status,
actor_id=self.test_actor_id,
payment_confirmation_time=fixed_now_aware,
- shipping_time=None, delivery_time=None, completion_time=None,
- notes_by_actor=None, is_admin_or_merchant_action=True
+ shipping_time=None,
+ delivery_time=None,
+ completion_time=None,
+ notes_by_actor=None,
+ is_admin_or_merchant_action=True,
)
async def test_update_order_status_order_not_found(self):
self.mock_order_crud.get_order_by_id.return_value = None
with self.assertRaises(OrderNotFoundException):
await self.order_service.update_order_status(
- db=self.mock_db_conn, order_id=999, new_status=OrderStatusEnum.SHIPPED,
- actor_id=self.test_actor_id
+ db=self.mock_db_conn, order_id=999, new_status=OrderStatusEnum.SHIPPED, actor_id=self.test_actor_id
)
async def test_update_order_status_invalid_transition(self):
@@ -411,12 +490,11 @@ async def test_update_order_status_invalid_transition(self):
with self.assertRaises(InvalidStatusTransitionException):
await self.order_service.update_order_status(
- db=self.mock_db_conn, order_id=order_id, new_status=OrderStatusEnum.SHIPPED,
- actor_id=self.test_actor_id
+ db=self.mock_db_conn, order_id=order_id, new_status=OrderStatusEnum.SHIPPED, actor_id=self.test_actor_id
)
# --- Test process_successful_payment ---
- @patch('backend.app.services.order_service.datetime') # To mock datetime.now for completion_time
+ @patch("backend.app.services.order_service.datetime") # To mock datetime.now for completion_time
async def test_process_successful_payment_success(self, mock_datetime_module):
fixed_now_aware = datetime.datetime(2025, 5, 20, 10, 0, 0, tzinfo=datetime.timezone.utc)
mock_datetime_module.datetime.now.return_value = fixed_now_aware
@@ -424,18 +502,22 @@ async def test_process_successful_payment_success(self, mock_datetime_module):
payment_tx_id = self.sample_pending_payment_transaction_data["PaymentTransactionID"]
# 1. Mock get_payment_transaction_by_id to return a PENDING transaction
- self.mock_payment_transaction_crud.get_payment_transaction_by_id.return_value = self.sample_pending_payment_transaction_data
+ self.mock_payment_transaction_crud.get_payment_transaction_by_id.return_value = (
+ self.sample_pending_payment_transaction_data
+ )
# 2. Mock update_payment_transaction_status to return the updated (SUCCESSFUL) transaction
- updated_tx_data = {**self.sample_pending_payment_transaction_data,
- "Status": PaymentTransactionStatusEnum.SUCCESSFUL.value,
- "CompletionTime": fixed_now_aware.replace(tzinfo=None)}
+ updated_tx_data = {
+ **self.sample_pending_payment_transaction_data,
+ "Status": PaymentTransactionStatusEnum.SUCCESSFUL.value,
+ "CompletionTime": fixed_now_aware.replace(tzinfo=None),
+ }
self.mock_payment_transaction_crud.update_payment_transaction_status.return_value = updated_tx_data
# 3. Mock get_orders_by_payment_transaction_id to return associated PENDING_PAYMENT orders
associated_orders_data = [
{**self.sample_order_data_store301, "OrderStatus": OrderStatusEnum.PENDING_PAYMENT.value},
- {**self.sample_order_data_store302, "OrderStatus": OrderStatusEnum.PENDING_PAYMENT.value}
+ {**self.sample_order_data_store302, "OrderStatus": OrderStatusEnum.PENDING_PAYMENT.value},
]
self.mock_order_crud.get_orders_by_payment_transaction_id.return_value = associated_orders_data
@@ -449,29 +531,34 @@ async def mock_update_order_status_side_effect(db, order_id, new_status, actor_i
# This should return an OrderViewResponse compatible dict/Pydantic model
# For this mock, let's just return the OrderStatus from the OrderViewResponse schema
# The SUT uses updated_order_resp.OrderStatus.value
- return OrderViewResponse(**{
- **original_order,
- "OrderStatus": new_status,
- "PaymentStatus": PaymentTransactionStatusEnum.SUCCESSFUL,
- "Items": [self.sample_order_item_created_data_1] # Mocked items
- }) # Add PaymentStatus
+ return OrderViewResponse(
+ **{
+ **original_order,
+ "OrderStatus": new_status,
+ "PaymentStatus": PaymentTransactionStatusEnum.SUCCESSFUL,
+ "Items": [self.sample_order_item_created_data_1], # Mocked items
+ }
+ ) # Add PaymentStatus
return None
- with patch.object(self.order_service, 'update_order_status',
- AsyncMock(side_effect=mock_update_order_status_side_effect)) as mock_self_update_status:
+ with patch.object(
+ self.order_service, "update_order_status", AsyncMock(side_effect=mock_update_order_status_side_effect)
+ ) as mock_self_update_status:
response = await self.order_service.process_successful_payment(
db=self.mock_db_conn,
payment_transaction_id=payment_tx_id,
external_gateway_tx_id="gw_123",
- actor_id=self.system_actor_id
+ actor_id=self.system_actor_id,
)
self.assertIsInstance(response, PaymentProcessingResponse)
self.assertEqual(response.PaymentTransactionID, payment_tx_id)
self.assertEqual(response.TransactionStatusInSystem, PaymentTransactionStatusEnum.SUCCESSFUL.value)
self.assertIn("支付成功", response.MessageToUser)
- self.assertCountEqual(response.AffectedOrderIDs, [self.sample_order_data_store301["OrderID"],
- self.sample_order_data_store302["OrderID"]])
+ self.assertCountEqual(
+ response.AffectedOrderIDs,
+ [self.sample_order_data_store301["OrderID"], self.sample_order_data_store302["OrderID"]],
+ )
self.mock_payment_transaction_crud.get_payment_transaction_by_id.assert_called_once_with(
conn=self.mock_db_conn, payment_transaction_id=payment_tx_id, actor_id=self.system_actor_id
@@ -482,16 +569,18 @@ async def mock_update_order_status_side_effect(db, order_id, new_status, actor_i
new_status=PaymentTransactionStatusEnum.SUCCESSFUL.value,
actor_id=self.system_actor_id,
external_gateway_transaction_id="gw_123",
- completion_time=fixed_now_aware # SUT passes aware
+ completion_time=fixed_now_aware, # SUT passes aware
)
self.mock_order_crud.get_orders_by_payment_transaction_id.assert_called_once_with(
conn=self.mock_db_conn, payment_transaction_id=payment_tx_id, actor_id=self.system_actor_id
)
self.assertEqual(mock_self_update_status.call_count, 2)
mock_self_update_status.assert_any_call(
- db=self.mock_db_conn, order_id=self.sample_order_data_store301["OrderID"],
+ db=self.mock_db_conn,
+ order_id=self.sample_order_data_store301["OrderID"],
new_status=OrderStatusEnum.PAID_AND_PENDING_PROCESSING,
- actor_id=self.system_actor_id, is_admin_action=False # As per SUT call
+ actor_id=self.system_actor_id,
+ is_admin_action=False, # As per SUT call
)
async def test_process_successful_payment_tx_not_found(self):
@@ -503,9 +592,12 @@ async def test_process_successful_payment_tx_not_found(self):
async def test_process_successful_payment_tx_already_successful_idempotency(self):
payment_tx_id = self.sample_successful_payment_transaction_data["PaymentTransactionID"]
- self.mock_payment_transaction_crud.get_payment_transaction_by_id.return_value = self.sample_successful_payment_transaction_data
+ self.mock_payment_transaction_crud.get_payment_transaction_by_id.return_value = (
+ self.sample_successful_payment_transaction_data
+ )
self.mock_order_crud.get_orders_by_payment_transaction_id.return_value = [
- self.sample_order_data_store301] # Example
+ self.sample_order_data_store301
+ ] # Example
response = await self.order_service.process_successful_payment(
db=self.mock_db_conn, payment_transaction_id=payment_tx_id, actor_id=self.system_actor_id
@@ -518,8 +610,11 @@ async def test_process_successful_payment_tx_already_successful_idempotency(self
async def test_process_successful_payment_tx_invalid_initial_state(self):
payment_tx_id = 1002
- failed_tx_data = {**self.base_payment_transaction_fields, "PaymentTransactionID": payment_tx_id,
- "Status": PaymentTransactionStatusEnum.FAILED.value}
+ failed_tx_data = {
+ **self.base_payment_transaction_fields,
+ "PaymentTransactionID": payment_tx_id,
+ "Status": PaymentTransactionStatusEnum.FAILED.value,
+ }
self.mock_payment_transaction_crud.get_payment_transaction_by_id.return_value = failed_tx_data
with self.assertRaises(InvalidPaymentStatusTransitionException): # Renamed from InvalidPaymentStateException
@@ -540,27 +635,29 @@ async def test_get_payment_transaction_by_id_for_user_success(self):
"Status": PaymentTransactionStatusEnum.PENDING.value,
"CreationTime": datetime.datetime(2025, 1, 1, 10, 0, 0),
"CompletionTime": None,
- "LastUpdatedDate": datetime.datetime(2025, 1, 1, 10, 5, 0) # Match DDL
+ "LastUpdatedDate": datetime.datetime(2025, 1, 1, 10, 5, 0), # Match DDL
}
self.mock_payment_transaction_crud.get_payment_transaction_by_id.return_value = mock_db_tx_data
response = await self.order_service.get_payment_transaction_by_id_for_user(
- db=self.mock_db_conn, payment_transaction_id=payment_tx_id,
- user_id=self.test_user_id, actor_id=self.test_actor_id
+ db=self.mock_db_conn,
+ payment_transaction_id=payment_tx_id,
+ user_id=self.test_user_id,
+ actor_id=self.test_actor_id,
)
self.assertIsInstance(response, PaymentResponse)
self.assertEqual(response.PaymentTransactionID, payment_tx_id)
self.assertEqual(response.UserID, self.test_user_id)
self.assertEqual(response.Status, PaymentTransactionStatusEnum.PENDING.value)
- self.assertEqual(response.LastUpdatedTime,
- mock_db_tx_data["LastUpdatedDate"]) # Check one of the renamed fields
+ self.assertEqual(
+ response.LastUpdatedTime, mock_db_tx_data["LastUpdatedDate"]
+ ) # Check one of the renamed fields
async def test_get_payment_transaction_by_id_for_user_not_found(self):
self.mock_payment_transaction_crud.get_payment_transaction_by_id.return_value = None
with self.assertRaises(PaymentTransactionNotFoundException):
await self.order_service.get_payment_transaction_by_id_for_user(
- db=self.mock_db_conn, payment_transaction_id=999,
- user_id=self.test_user_id, actor_id=self.test_actor_id
+ db=self.mock_db_conn, payment_transaction_id=999, user_id=self.test_user_id, actor_id=self.test_actor_id
)
async def test_get_payment_transaction_by_id_for_user_mismatch(self):
@@ -571,11 +668,12 @@ async def test_get_payment_transaction_by_id_for_user_mismatch(self):
with self.assertRaises(PaymentTransactionNotFoundException): # SUT raises this if UserID mismatch
await self.order_service.get_payment_transaction_by_id_for_user(
- db=self.mock_db_conn, payment_transaction_id=payment_tx_id,
+ db=self.mock_db_conn,
+ payment_transaction_id=payment_tx_id,
user_id=self.test_user_id, # User 1 requesting
- actor_id=self.test_actor_id
+ actor_id=self.test_actor_id,
)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/src/backend/test/unit/service/test_permission_checker.py b/src/backend/test/unit/service/test_permission_checker.py
index c5d3529..20cda31 100644
--- a/src/backend/test/unit/service/test_permission_checker.py
+++ b/src/backend/test/unit/service/test_permission_checker.py
@@ -22,14 +22,10 @@ class TestPermissionChecker(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.mock_user_crud = MagicMock(spec=UserCRUD)
- self.mock_db_conn = (
- MagicMock()
- ) # 模拟数据库连接,供 require_same_user_or_admin_from_id 使用
+ self.mock_db_conn = MagicMock() # 模拟数据库连接,供 require_same_user_or_admin_from_id 使用
# 默认 settings 为严格模式
- self.settings_patcher = patch(
- "backend.app.services.permission_checker.settings", MockSettings()
- )
+ self.settings_patcher = patch("backend.app.services.permission_checker.settings", MockSettings())
self.mock_settings = self.settings_patcher.start()
self.addCleanup(self.settings_patcher.stop)
@@ -57,9 +53,7 @@ async def test_require_admin_is_admin_strict_mode(self):
try:
self.checker.require_admin(user=admin_user, strict=True)
except PermissionDeniedException:
- self.fail(
- "require_admin raised PermissionDeniedException unexpectedly for admin user in strict mode"
- )
+ self.fail("require_admin raised PermissionDeniedException unexpectedly for admin user in strict mode")
self.mock_logger_permissions.warning.assert_not_called()
async def test_require_admin_is_not_admin_strict_mode_raises_exception(self):
@@ -71,9 +65,7 @@ async def test_require_admin_is_not_admin_strict_mode_raises_exception(self):
RegistrationDate=datetime.datetime.now(datetime.timezone.utc),
LastLoginDate=datetime.datetime.now(datetime.timezone.utc),
)
- with self.assertRaisesRegex(
- PermissionDeniedException, "User does not have admin permissions."
- ):
+ with self.assertRaisesRegex(PermissionDeniedException, "User does not have admin permissions."):
self.checker.require_admin(user=non_admin_user, strict=True)
self.mock_logger_permissions.warning.assert_called_once_with(
f"User {non_admin_user.UserID} does not have admin permissions."
@@ -108,9 +100,7 @@ async def test_require_admin_uses_default_strict_mode_from_settings_true(self):
LastLoginDate=datetime.datetime.now(datetime.timezone.utc),
)
with self.assertRaises(PermissionDeniedException):
- checker_uses_settings.require_admin(
- user=non_admin_user
- ) # strict=None, 应使用全局 True
+ checker_uses_settings.require_admin(user=non_admin_user) # strict=None, 应使用全局 True
self.mock_logger_permissions.warning.assert_called_with(
f"User {non_admin_user.UserID} does not have admin permissions."
)
@@ -127,9 +117,7 @@ async def test_require_admin_uses_default_strict_mode_from_settings_false(self):
RegistrationDate=datetime.datetime.now(datetime.timezone.utc),
LastLoginDate=datetime.datetime.now(datetime.timezone.utc),
)
- checker_uses_settings.require_admin(
- user=non_admin_user
- ) # strict=None, 应使用全局 False
+ checker_uses_settings.require_admin(user=non_admin_user) # strict=None, 应使用全局 False
self.mock_logger_permissions.warning.assert_called_with(
f"User {non_admin_user.UserID} does not have admin permissions."
)
@@ -152,9 +140,7 @@ async def test_require_same_user_or_admin_from_response_is_same_user(self):
RegistrationDate=datetime.datetime.now(datetime.timezone.utc),
LastLoginDate=datetime.datetime.now(datetime.timezone.utc),
)
- self.checker.require_same_user_or_admin_from_response(
- user=user, actor=actor, strict=True
- )
+ self.checker.require_same_user_or_admin_from_response(user=user, actor=actor, strict=True)
self.mock_logger_permissions.warning.assert_not_called()
async def test_require_same_user_or_admin_from_response_actor_is_admin(self):
@@ -174,9 +160,7 @@ async def test_require_same_user_or_admin_from_response_actor_is_admin(self):
RegistrationDate=datetime.datetime.now(datetime.timezone.utc),
LastLoginDate=datetime.datetime.now(datetime.timezone.utc),
)
- self.checker.require_same_user_or_admin_from_response(
- user=user, actor=admin_actor, strict=True
- )
+ self.checker.require_same_user_or_admin_from_response(user=user, actor=admin_actor, strict=True)
# SUT _require_same_user_or_admin 会先记录 actor 不是 user,然后检查 role
self.mock_logger_permissions.warning.assert_called_once_with(
f"User(Actor) {admin_actor.UserID} is not the same as the user {user.UserID}."
@@ -203,9 +187,7 @@ async def test_require_same_user_or_admin_from_response_permission_denied_strict
PermissionDeniedException,
"Actor is not the same as the user and does not have admin permissions.",
):
- self.checker.require_same_user_or_admin_from_response(
- user=user, actor=other_actor, strict=True
- )
+ self.checker.require_same_user_or_admin_from_response(user=user, actor=other_actor, strict=True)
self.assertEqual(self.mock_logger_permissions.warning.call_count, 2)
self.mock_logger_permissions.warning.assert_any_call(
@@ -295,9 +277,7 @@ async def test_require_same_user_or_admin_from_id_permission_denied_strict(self)
conn=self.mock_db_conn, user_id=user_id, actor_id=other_actor_id, strict=True
)
self.assertEqual(self.mock_user_crud.get_user_by_id.call_count, 2)
- self.assertEqual(
- self.mock_logger_permissions.warning.call_count, 2
- ) # Log "not same" and "not admin"
+ self.assertEqual(self.mock_logger_permissions.warning.call_count, 2) # Log "not same" and "not admin"
async def test_require_same_user_or_admin_from_id_user_not_found_in_crud_strict(self):
user_id = 1
diff --git a/src/backend/test/unit/service/test_product_change_request_service.py b/src/backend/test/unit/service/test_product_change_request_service.py
index 27b5c0a..5273cf3 100644
--- a/src/backend/test/unit/service/test_product_change_request_service.py
+++ b/src/backend/test/unit/service/test_product_change_request_service.py
@@ -9,7 +9,7 @@
ProductChangeRequestCreate,
ProductChangeRequestUpdate,
ProductChangeRequestResponse,
- ProductChangeRequestAdminUpdate
+ ProductChangeRequestAdminUpdate,
)
# 假设 Connection 类型来自于 sqlalchemy.engine.base
@@ -20,13 +20,13 @@ class TestProductChangeRequestService(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.mock_product_change_request_crud = MagicMock(spec=ProductChangeRequestCRUD)
-
+
# Patch the get_product_change_request_crud_instance function
- patcher = patch('backend.app.services.product_change_request_service.get_product_change_request_crud_instance')
+ patcher = patch("backend.app.services.product_change_request_service.get_product_change_request_crud_instance")
self.mock_get_crud = patcher.start()
self.mock_get_crud.return_value = self.mock_product_change_request_crud
self.addCleanup(patcher.stop)
-
+
self.product_change_request_service = ProductChangeRequestService()
self.mock_db_conn = MagicMock(spec=Connection)
@@ -34,21 +34,25 @@ def setUp(self):
self.test_store_id = 1
self.test_product_id = 1
self.test_admin_id = 2
-
+
self.sample_change_request_dict = {
"ChangeRequestID": 1,
"ProductID": self.test_product_id,
"MerchantUserID": self.test_merchant_id,
"StoreID": self.test_store_id,
"RequestType": "PRODUCT_UPDATE",
- "ProposedData_JSON": {"ProductName": "Updated Product", "ProductDescription": "Updated Description", "Price": 199.99},
+ "ProposedData_JSON": {
+ "ProductName": "Updated Product",
+ "ProductDescription": "Updated Description",
+ "Price": 199.99,
+ },
"Status": "PENDING_APPROVAL",
"SubmitterNotes": "Please approve my product update",
"AdminReviewerID": None,
"ReviewTimestamp": None,
"AdminNotes": None,
"CreationTime": datetime.datetime.now(),
- "LastUpdatedDate": datetime.datetime.now()
+ "LastUpdatedDate": datetime.datetime.now(),
}
# --- 测试 create_change_request ---
@@ -62,41 +66,39 @@ async def test_create_change_request_success(self):
"proposed_data": proposed_data,
"product_id": self.test_product_id,
"submitter_notes": "Please approve",
- "actor_id": self.test_merchant_id
+ "actor_id": self.test_merchant_id,
}
# Mock user_crud_instance.get_user_by_id
- with patch('backend.app.services.product_change_request_service.user_crud_instance') as mock_user_crud:
+ with patch("backend.app.services.product_change_request_service.user_crud_instance") as mock_user_crud:
mock_user_crud.get_user_by_id.return_value = {
"UserID": self.test_merchant_id,
"UserRole": "merchant",
- "Username": "test_merchant"
+ "Username": "test_merchant",
}
-
+
# Mock the product CRUD
mock_product_crud = MagicMock()
mock_product_crud.get_product_by_id.return_value = {
"ProductID": self.test_product_id,
- "StoreID": self.test_store_id
+ "StoreID": self.test_store_id,
}
-
+
# 将mock对象直接赋值给service实例的_product_crud属性
self.product_change_request_service._product_crud = mock_product_crud
-
+
# Mock the CRUD method
self.mock_product_change_request_crud.create_change_request.return_value = self.sample_change_request_dict
# Execute the test
result = await self.product_change_request_service.create_change_request(
- conn=self.mock_db_conn,
- **change_request_in
+ conn=self.mock_db_conn, **change_request_in
)
# Verify the result
self.assertEqual(result, self.sample_change_request_dict)
self.mock_product_change_request_crud.create_change_request.assert_called_once_with(
- conn=self.mock_db_conn,
- **change_request_in
+ conn=self.mock_db_conn, **change_request_in
)
# --- 测试 get_change_request_by_id ---
@@ -106,23 +108,21 @@ async def test_get_change_request_by_id_success(self):
# Execute the test
result = await self.product_change_request_service.get_change_request_by_id(
- conn=self.mock_db_conn,
- request_id=1,
- actor_id=self.test_merchant_id
+ conn=self.mock_db_conn, request_id=1, actor_id=self.test_merchant_id
)
# Verify the result
self.assertEqual(result, self.sample_change_request_dict)
self.mock_product_change_request_crud.get_change_request_by_id.assert_called_once_with(
- conn=self.mock_db_conn,
- request_id=1,
- actor_id=self.test_merchant_id
+ conn=self.mock_db_conn, request_id=1, actor_id=self.test_merchant_id
)
# --- 测试 get_change_requests_by_product_id ---
async def test_get_change_requests_by_product_id_success(self):
# Mock the CRUD method
- self.mock_product_change_request_crud.get_change_requests_by_product_id.return_value = [self.sample_change_request_dict]
+ self.mock_product_change_request_crud.get_change_requests_by_product_id.return_value = [
+ self.sample_change_request_dict
+ ]
# Execute the test
result = await self.product_change_request_service.get_change_requests_by_product_id(
@@ -131,7 +131,7 @@ async def test_get_change_requests_by_product_id_success(self):
status="PENDING_APPROVAL",
limit=10,
offset=0,
- actor_id=self.test_merchant_id
+ actor_id=self.test_merchant_id,
)
# Verify the result
@@ -142,13 +142,15 @@ async def test_get_change_requests_by_product_id_success(self):
status="PENDING_APPROVAL",
limit=10,
offset=0,
- actor_id=self.test_merchant_id
+ actor_id=self.test_merchant_id,
)
# --- 测试 get_change_requests_by_store_id ---
async def test_get_change_requests_by_store_id_success(self):
# Mock the CRUD method
- self.mock_product_change_request_crud.get_change_requests_by_store_id.return_value = [self.sample_change_request_dict]
+ self.mock_product_change_request_crud.get_change_requests_by_store_id.return_value = [
+ self.sample_change_request_dict
+ ]
# Execute the test
result = await self.product_change_request_service.get_change_requests_by_store_id(
@@ -157,7 +159,7 @@ async def test_get_change_requests_by_store_id_success(self):
status="PENDING_APPROVAL",
limit=10,
offset=0,
- actor_id=self.test_merchant_id
+ actor_id=self.test_merchant_id,
)
# Verify the result
@@ -168,13 +170,15 @@ async def test_get_change_requests_by_store_id_success(self):
status="PENDING_APPROVAL",
limit=10,
offset=0,
- actor_id=self.test_merchant_id
+ actor_id=self.test_merchant_id,
)
# --- 测试 get_change_requests_by_merchant_id ---
async def test_get_change_requests_by_merchant_id_success(self):
# Mock the CRUD method
- self.mock_product_change_request_crud.get_change_requests_by_merchant_id.return_value = [self.sample_change_request_dict]
+ self.mock_product_change_request_crud.get_change_requests_by_merchant_id.return_value = [
+ self.sample_change_request_dict
+ ]
# Execute the test
result = await self.product_change_request_service.get_change_requests_by_merchant_id(
@@ -183,7 +187,7 @@ async def test_get_change_requests_by_merchant_id_success(self):
status="PENDING_APPROVAL",
limit=10,
offset=0,
- actor_id=self.test_merchant_id
+ actor_id=self.test_merchant_id,
)
# Verify the result
@@ -194,7 +198,7 @@ async def test_get_change_requests_by_merchant_id_success(self):
status="PENDING_APPROVAL",
limit=10,
offset=0,
- actor_id=self.test_merchant_id
+ actor_id=self.test_merchant_id,
)
# --- 测试 get_all_pending_requests ---
@@ -204,19 +208,13 @@ async def test_get_all_pending_requests_success(self):
# Execute the test
result = await self.product_change_request_service.get_all_pending_requests(
- conn=self.mock_db_conn,
- limit=10,
- offset=0,
- actor_id=self.test_admin_id
+ conn=self.mock_db_conn, limit=10, offset=0, actor_id=self.test_admin_id
)
# Verify the result
self.assertEqual(result, [self.sample_change_request_dict])
self.mock_product_change_request_crud.get_all_pending_requests.assert_called_once_with(
- conn=self.mock_db_conn,
- limit=10,
- offset=0,
- actor_id=self.test_admin_id
+ conn=self.mock_db_conn, limit=10, offset=0, actor_id=self.test_admin_id
)
# --- 测试 get_filtered_requests ---
@@ -234,7 +232,7 @@ async def test_get_filtered_requests_success(self):
status="PENDING_APPROVAL",
limit=10,
offset=0,
- actor_id=self.test_admin_id
+ actor_id=self.test_admin_id,
)
# Verify the result
@@ -251,7 +249,7 @@ async def test_get_filtered_requests_success(self):
end_date=None,
limit=10,
offset=0,
- actor_id=self.test_admin_id
+ actor_id=self.test_admin_id,
)
# --- 测试 update_request ---
@@ -259,7 +257,7 @@ async def test_update_request_success(self):
# Prepare test data
update_data = {
"ProposedData_JSON": {"ProductName": "Updated Product Name", "Price": 129.99},
- "SubmitterNotes": "Updated notes"
+ "SubmitterNotes": "Updated notes",
}
# Mock the CRUD method
@@ -270,19 +268,13 @@ async def test_update_request_success(self):
# Execute the test
result = await self.product_change_request_service.update_request(
- conn=self.mock_db_conn,
- request_id=1,
- update_data=update_data,
- actor_id=self.test_merchant_id
+ conn=self.mock_db_conn, request_id=1, update_data=update_data, actor_id=self.test_merchant_id
)
# Verify the result
self.assertEqual(result, updated_dict)
self.mock_product_change_request_crud.update_request.assert_called_once_with(
- conn=self.mock_db_conn,
- request_id=1,
- update_data=update_data,
- actor_id=self.test_merchant_id
+ conn=self.mock_db_conn, request_id=1, update_data=update_data, actor_id=self.test_merchant_id
)
# --- 测试 update_request_status ---
@@ -291,9 +283,9 @@ async def test_update_request_status_success(self):
self.mock_product_change_request_crud.get_change_request_by_id.return_value = {
"Status": "PENDING_APPROVAL",
"ChangeRequestID": 1,
- "MerchantUserID": self.test_merchant_id
+ "MerchantUserID": self.test_merchant_id,
}
-
+
# Mock the CRUD method
updated_dict = self.sample_change_request_dict.copy()
updated_dict["Status"] = "APPROVED"
@@ -309,7 +301,7 @@ async def test_update_request_status_success(self):
status="APPROVED",
admin_id=self.test_admin_id,
admin_notes="Approved by admin",
- actor_id=self.test_admin_id
+ actor_id=self.test_admin_id,
)
# Verify the result
@@ -320,7 +312,7 @@ async def test_update_request_status_success(self):
status="APPROVED",
admin_id=self.test_admin_id,
admin_notes="Approved by admin",
- actor_id=self.test_admin_id
+ actor_id=self.test_admin_id,
)
# --- 测试 cancel_request ---
@@ -330,19 +322,15 @@ async def test_cancel_request_success(self):
# Execute the test
result = await self.product_change_request_service.cancel_request(
- conn=self.mock_db_conn,
- request_id=1,
- actor_id=self.test_merchant_id
+ conn=self.mock_db_conn, request_id=1, actor_id=self.test_merchant_id
)
# Verify the result
self.assertTrue(result)
self.mock_product_change_request_crud.cancel_request.assert_called_once_with(
- conn=self.mock_db_conn,
- request_id=1,
- actor_id=self.test_merchant_id
+ conn=self.mock_db_conn, request_id=1, actor_id=self.test_merchant_id
)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/src/backend/test/unit/service/test_product_change_request_service_v2.py b/src/backend/test/unit/service/test_product_change_request_service_v2.py
index 7be2249..12a7568 100644
--- a/src/backend/test/unit/service/test_product_change_request_service_v2.py
+++ b/src/backend/test/unit/service/test_product_change_request_service_v2.py
@@ -264,9 +264,7 @@ async def test_get_request_details_success_admin(self):
# --- Test list_requests_for_merchant ---
async def test_list_requests_for_merchant_success(self):
query_params = ProductChangeRequestQueryParams(Status=[RequestStatusEnum.PENDING_APPROVAL])
- mock_requests_data = [
- {**self.sample_pcr_data_cls, "Status": RequestStatusEnum.PENDING_APPROVAL.value}
- ]
+ mock_requests_data = [{**self.sample_pcr_data_cls, "Status": RequestStatusEnum.PENDING_APPROVAL.value}]
self.mock_pcr_crud.get_request_list.return_value = mock_requests_data
response = await self.service.list_requests_for_merchant(
@@ -330,9 +328,7 @@ async def test_merchant_cancel_request_success(self):
# --- Test admin_review_request ---
async def test_admin_review_request_approve_and_trigger_apply(self):
change_request_id = self.sample_pcr_data_cls["ChangeRequestID"]
- review_data = ProductChangeRequestUpdateByAdmin(
- Status=RequestStatusEnum.APPROVED, AdminNotes="Looks good"
- )
+ review_data = ProductChangeRequestUpdateByAdmin(Status=RequestStatusEnum.APPROVED, AdminNotes="Looks good")
pending_request = {
**self.sample_pcr_data_cls,
@@ -385,9 +381,7 @@ async def test_admin_review_request_approve_and_trigger_apply(self):
# --- Test apply_approved_request ---
async def test_apply_approved_request_product_create_success(self):
- change_request_id = self.sample_approved_pcr_data_cls[
- "ChangeRequestID"
- ] # Use approved sample
+ change_request_id = self.sample_approved_pcr_data_cls["ChangeRequestID"] # Use approved sample
applier_user = self.mock_admin_user_cls
# Ensure the approved PCR data for create has ProductID as None initially
@@ -432,7 +426,7 @@ async def test_apply_approved_request_product_create_success(self):
self.mock_pcr_crud.update_request_applied.assert_called_once_with(
conn=self.mock_db_conn,
request_id=change_request_id,
- new_product_id=777, # 回填新创建的产品ID
+ new_product_id=777, # 回填新创建的产品ID
actor_id=applier_user.UserID,
)
diff --git a/src/backend/test/unit/service/test_store_change_request_service.py b/src/backend/test/unit/service/test_store_change_request_service.py
index b3203c6..6fc0a93 100644
--- a/src/backend/test/unit/service/test_store_change_request_service.py
+++ b/src/backend/test/unit/service/test_store_change_request_service.py
@@ -9,7 +9,7 @@
StoreChangeRequestCreate,
StoreChangeRequestUpdate,
StoreChangeRequestResponse,
- StoreChangeRequestAdminUpdate
+ StoreChangeRequestAdminUpdate,
)
# 假设 Connection 类型来自于 sqlalchemy.engine.base
@@ -20,20 +20,20 @@ class TestStoreChangeRequestService(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.mock_store_change_request_crud = MagicMock(spec=StoreChangeRequestCRUD)
-
+
# Patch the get_store_change_request_crud_instance function
- patcher = patch('backend.app.services.store_change_request_service.get_store_change_request_crud_instance')
+ patcher = patch("backend.app.services.store_change_request_service.get_store_change_request_crud_instance")
self.mock_get_crud = patcher.start()
self.mock_get_crud.return_value = self.mock_store_change_request_crud
self.addCleanup(patcher.stop)
-
+
self.store_change_request_service = StoreChangeRequestService()
self.mock_db_conn = MagicMock(spec=Connection)
self.test_user_id = 1
self.test_store_id = 1
self.test_admin_id = 2
-
+
self.sample_change_request_dict = {
"ChangeRequestID": 1,
"StoreID": self.test_store_id,
@@ -46,7 +46,7 @@ def setUp(self):
"ReviewTimestamp": None,
"AdminNotes": None,
"CreationTime": datetime.datetime.now(),
- "LastUpdatedDate": datetime.datetime.now()
+ "LastUpdatedDate": datetime.datetime.now(),
}
# --- 测试 create_change_request ---
@@ -59,7 +59,7 @@ async def test_create_change_request_success(self):
"proposed_data": proposed_data,
"store_id": self.test_store_id,
"submitter_notes": "Please approve",
- "actor_id": self.test_user_id
+ "actor_id": self.test_user_id,
}
# Mock the CRUD method
@@ -67,15 +67,13 @@ async def test_create_change_request_success(self):
# Execute the test
result = await self.store_change_request_service.create_change_request(
- conn=self.mock_db_conn,
- **change_request_in
+ conn=self.mock_db_conn, **change_request_in
)
# Verify the result
self.assertEqual(result, self.sample_change_request_dict)
self.mock_store_change_request_crud.create_change_request.assert_called_once_with(
- conn=self.mock_db_conn,
- **change_request_in
+ conn=self.mock_db_conn, **change_request_in
)
# --- 测试 get_change_request_by_id ---
@@ -85,23 +83,21 @@ async def test_get_change_request_by_id_success(self):
# Execute the test
result = await self.store_change_request_service.get_change_request_by_id(
- conn=self.mock_db_conn,
- request_id=1,
- actor_id=self.test_user_id
+ conn=self.mock_db_conn, request_id=1, actor_id=self.test_user_id
)
# Verify the result
self.assertEqual(result, self.sample_change_request_dict)
self.mock_store_change_request_crud.get_change_request_by_id.assert_called_once_with(
- conn=self.mock_db_conn,
- request_id=1,
- actor_id=self.test_user_id
+ conn=self.mock_db_conn, request_id=1, actor_id=self.test_user_id
)
# --- 测试 get_change_requests_by_store_id ---
async def test_get_change_requests_by_store_id_success(self):
# Mock the CRUD method
- self.mock_store_change_request_crud.get_change_requests_by_store_id.return_value = [self.sample_change_request_dict]
+ self.mock_store_change_request_crud.get_change_requests_by_store_id.return_value = [
+ self.sample_change_request_dict
+ ]
# Execute the test
result = await self.store_change_request_service.get_change_requests_by_store_id(
@@ -110,7 +106,7 @@ async def test_get_change_requests_by_store_id_success(self):
status="PENDING_APPROVAL",
limit=10,
offset=0,
- actor_id=self.test_user_id
+ actor_id=self.test_user_id,
)
# Verify the result
@@ -121,13 +117,15 @@ async def test_get_change_requests_by_store_id_success(self):
status="PENDING_APPROVAL",
limit=10,
offset=0,
- actor_id=self.test_user_id
+ actor_id=self.test_user_id,
)
# --- 测试 get_change_requests_by_user_id ---
async def test_get_change_requests_by_user_id_success(self):
# Mock the CRUD method
- self.mock_store_change_request_crud.get_change_requests_by_user_id.return_value = [self.sample_change_request_dict]
+ self.mock_store_change_request_crud.get_change_requests_by_user_id.return_value = [
+ self.sample_change_request_dict
+ ]
# Execute the test
result = await self.store_change_request_service.get_change_requests_by_user_id(
@@ -136,7 +134,7 @@ async def test_get_change_requests_by_user_id_success(self):
status="PENDING_APPROVAL",
limit=10,
offset=0,
- actor_id=self.test_user_id
+ actor_id=self.test_user_id,
)
# Verify the result
@@ -147,7 +145,7 @@ async def test_get_change_requests_by_user_id_success(self):
status="PENDING_APPROVAL",
limit=10,
offset=0,
- actor_id=self.test_user_id
+ actor_id=self.test_user_id,
)
# --- 测试 get_all_pending_requests ---
@@ -157,19 +155,13 @@ async def test_get_all_pending_requests_success(self):
# Execute the test
result = await self.store_change_request_service.get_all_pending_requests(
- conn=self.mock_db_conn,
- limit=10,
- offset=0,
- actor_id=self.test_admin_id
+ conn=self.mock_db_conn, limit=10, offset=0, actor_id=self.test_admin_id
)
# Verify the result
self.assertEqual(result, [self.sample_change_request_dict])
self.mock_store_change_request_crud.get_all_pending_requests.assert_called_once_with(
- conn=self.mock_db_conn,
- limit=10,
- offset=0,
- actor_id=self.test_admin_id
+ conn=self.mock_db_conn, limit=10, offset=0, actor_id=self.test_admin_id
)
# --- 测试 get_filtered_requests ---
@@ -186,7 +178,7 @@ async def test_get_filtered_requests_success(self):
status="PENDING_APPROVAL",
limit=10,
offset=0,
- actor_id=self.test_admin_id
+ actor_id=self.test_admin_id,
)
# Verify the result
@@ -202,16 +194,13 @@ async def test_get_filtered_requests_success(self):
end_date=None,
limit=10,
offset=0,
- actor_id=self.test_admin_id
+ actor_id=self.test_admin_id,
)
# --- 测试 update_request ---
async def test_update_request_success(self):
# Prepare test data
- update_data = {
- "ProposedData_JSON": {"StoreName": "Updated Store Name"},
- "SubmitterNotes": "Updated notes"
- }
+ update_data = {"ProposedData_JSON": {"StoreName": "Updated Store Name"}, "SubmitterNotes": "Updated notes"}
# Mock the CRUD method
updated_dict = self.sample_change_request_dict.copy()
@@ -221,19 +210,13 @@ async def test_update_request_success(self):
# Execute the test
result = await self.store_change_request_service.update_request(
- conn=self.mock_db_conn,
- request_id=1,
- update_data=update_data,
- actor_id=self.test_user_id
+ conn=self.mock_db_conn, request_id=1, update_data=update_data, actor_id=self.test_user_id
)
# Verify the result
self.assertEqual(result, updated_dict)
self.mock_store_change_request_crud.update_request.assert_called_once_with(
- conn=self.mock_db_conn,
- request_id=1,
- update_data=update_data,
- actor_id=self.test_user_id
+ conn=self.mock_db_conn, request_id=1, update_data=update_data, actor_id=self.test_user_id
)
# --- 测试 update_request_status ---
@@ -242,9 +225,9 @@ async def test_update_request_status_success(self):
self.mock_store_change_request_crud.get_change_request_by_id.return_value = {
"Status": "PENDING_APPROVAL",
"ChangeRequestID": 1,
- "RequestingUserID": self.test_user_id
+ "RequestingUserID": self.test_user_id,
}
-
+
# Mock the CRUD method
updated_dict = self.sample_change_request_dict.copy()
updated_dict["Status"] = "APPROVED"
@@ -260,7 +243,7 @@ async def test_update_request_status_success(self):
status="APPROVED",
admin_id=self.test_admin_id,
admin_notes="Approved by admin",
- actor_id=self.test_admin_id
+ actor_id=self.test_admin_id,
)
# Verify the result
@@ -271,7 +254,7 @@ async def test_update_request_status_success(self):
status="APPROVED",
admin_id=self.test_admin_id,
admin_notes="Approved by admin",
- actor_id=self.test_admin_id
+ actor_id=self.test_admin_id,
)
# --- 测试 cancel_request ---
@@ -281,19 +264,15 @@ async def test_cancel_request_success(self):
# Execute the test
result = await self.store_change_request_service.cancel_request(
- conn=self.mock_db_conn,
- request_id=1,
- actor_id=self.test_user_id
+ conn=self.mock_db_conn, request_id=1, actor_id=self.test_user_id
)
# Verify the result
self.assertTrue(result)
self.mock_store_change_request_crud.cancel_request.assert_called_once_with(
- conn=self.mock_db_conn,
- request_id=1,
- actor_id=self.test_user_id
+ conn=self.mock_db_conn, request_id=1, actor_id=self.test_user_id
)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/src/backend/test/unit/service/test_store_change_request_service_v2.py b/src/backend/test/unit/service/test_store_change_request_service_v2.py
index 295f7ba..653c78c 100644
--- a/src/backend/test/unit/service/test_store_change_request_service_v2.py
+++ b/src/backend/test/unit/service/test_store_change_request_service_v2.py
@@ -106,9 +106,7 @@ def setUp(self):
# --- Test submit_new_request ---
async def test_submit_new_request_store_create_success(self):
- proposed_data_schema = ProposedStoreData(
- StoreName="My New Store", Description="The best one"
- )
+ proposed_data_schema = ProposedStoreData(StoreName="My New Store", Description="The best one")
request_in = SCR_CreateRequest(
RequestType=RequestTypeEnum.STORE_CREATE,
ProposedData_JSON=proposed_data_schema,
@@ -194,9 +192,7 @@ async def test_get_request_details_success(self):
async def test_get_request_details_not_found_raises_exception(self):
self.mock_scr_crud.get_request_by_id.return_value = None
- with self.assertRaisesRegex(
- StoreNotFoundException, "StoreChangeRequest with ID 999 not found."
- ):
+ with self.assertRaisesRegex(StoreNotFoundException, "StoreChangeRequest with ID 999 not found."):
await self.service.get_request_details(
db=self.mock_db_conn, change_request_id=999, actor_user=self.mock_requesting_user
)
@@ -273,9 +269,7 @@ async def test_user_cancel_request_success(self):
# @patch('backend.app.services.store_change_request_service_v2._is_actor_admin', return_value=True)
async def test_admin_review_and_apply_store_create_success(self): # Removed mock_is_admin_func
pcr_id = self.base_scr_dict_from_crud["ChangeRequestID"]
- review_data = SCR_UpdateByAdminRequest(
- Status=StatusEnum.APPROVED, AdminNotes="Looks good for store creation"
- )
+ review_data = SCR_UpdateByAdminRequest(Status=StatusEnum.APPROVED, AdminNotes="Looks good for store creation")
proposed_data_for_create_dict = {
"StoreName": "Brand New Store From PCR",
@@ -316,9 +310,7 @@ async def test_admin_review_and_apply_store_create_success(self): # Removed moc
"Status": StatusEnum.APPLIED.value,
"StoreID": 789,
}
- self.mock_scr_crud.update_request_store_id_and_status_applied = MagicMock(
- return_value=final_applied_pcr_state
- )
+ self.mock_scr_crud.update_request_store_id_and_status_applied = MagicMock(return_value=final_applied_pcr_state)
response = await self.service.admin_review_request(
db=self.mock_db_conn,
@@ -332,9 +324,7 @@ async def test_admin_review_and_apply_store_create_success(self): # Removed moc
self.mock_store_crud.create_store.assert_called_once()
create_store_kwargs = self.mock_store_crud.create_store.call_args.kwargs
- self.assertEqual(
- create_store_kwargs["store_name"], proposed_data_for_create_dict["StoreName"]
- )
+ self.assertEqual(create_store_kwargs["store_name"], proposed_data_for_create_dict["StoreName"])
self.mock_scr_crud.update_request_store_id_and_status_applied.assert_called_once_with(
conn=self.mock_db_conn,
@@ -371,9 +361,7 @@ async def test_apply_approved_request_store_update_success(self): # Removed moc
self.mock_store_crud.update_store.return_value = mock_updated_store_from_db
final_applied_pcr_state = {**approved_pcr_for_update, "Status": StatusEnum.APPLIED.value}
- self.mock_scr_crud.update_request_store_id_and_status_applied = MagicMock(
- return_value=final_applied_pcr_state
- )
+ self.mock_scr_crud.update_request_store_id_and_status_applied = MagicMock(return_value=final_applied_pcr_state)
response = await self.service.apply_approved_request(
db=self.mock_db_conn, change_request_id=pcr_id, applier_user=self.mock_admin_user
diff --git a/src/backend/test/unit/service/test_store_service.py b/src/backend/test/unit/service/test_store_service.py
index 6639287..81955b4 100644
--- a/src/backend/test/unit/service/test_store_service.py
+++ b/src/backend/test/unit/service/test_store_service.py
@@ -16,12 +16,13 @@
StoreListResponse,
StoreStatusEnum,
StoreListSimpleResponse,
- StoreSimpleResponse
+ StoreSimpleResponse,
)
from backend.app.utils.exceptions import StoreNotFoundException, UserNotFoundException, PermissionDeniedException
# 假设 Connection 类型来自于 sqlalchemy.engine.base
from sqlalchemy.engine.base import Connection
+
# 导入 text 用于模拟 COUNT 查询的返回
from sqlalchemy import text
@@ -32,10 +33,7 @@ def setUp(self):
self.mock_store_crud = MagicMock(spec=StoreCRUD)
self.mock_user_crud = MagicMock(spec=UserCRUD)
- self.store_service = StoreService(
- store_crud=self.mock_store_crud,
- user_crud=self.mock_user_crud
- )
+ self.store_service = StoreService(store_crud=self.mock_store_crud, user_crud=self.mock_user_crud)
self.mock_db_conn = MagicMock(spec=Connection)
self.test_user_id = 1
@@ -45,43 +43,53 @@ def setUp(self):
self.sample_user_data = {"UserID": self.test_owner_id, "Username": "owner", "UserRole": "merchant"}
self.sample_store_data_dict_active = {
- "StoreID": 101, "StoreName": "Test Store A", "OwnerUserID": self.test_owner_id,
- "Description": "A great store", "LogoURL": "http://logo.com/a.png",
+ "StoreID": 101,
+ "StoreName": "Test Store A",
+ "OwnerUserID": self.test_owner_id,
+ "Description": "A great store",
+ "LogoURL": "http://logo.com/a.png",
"StoreStatus": StoreStatusEnum.ACTIVE.value,
"CreationDate": datetime.datetime.now(datetime.timezone.utc),
- "LastUpdatedDate": datetime.datetime.now(datetime.timezone.utc)
+ "LastUpdatedDate": datetime.datetime.now(datetime.timezone.utc),
}
self.sample_store_data_dict_inactive = {
- "StoreID": 102, "StoreName": "Inactive Store B", "OwnerUserID": self.test_owner_id,
- "Description": "An inactive store", "LogoURL": "http://logo.com/b.png",
+ "StoreID": 102,
+ "StoreName": "Inactive Store B",
+ "OwnerUserID": self.test_owner_id,
+ "Description": "An inactive store",
+ "LogoURL": "http://logo.com/b.png",
"StoreStatus": StoreStatusEnum.INACTIVE_BY_MERCHANT.value,
"CreationDate": datetime.datetime.now(datetime.timezone.utc),
- "LastUpdatedDate": datetime.datetime.now(datetime.timezone.utc)
+ "LastUpdatedDate": datetime.datetime.now(datetime.timezone.utc),
}
self.sample_store_simple_data_active = {
"StoreID": self.sample_store_data_dict_active["StoreID"],
"StoreName": self.sample_store_data_dict_active["StoreName"],
- "LogoURL": self.sample_store_data_dict_active["LogoURL"]
+ "LogoURL": self.sample_store_data_dict_active["LogoURL"],
}
self.sample_store_simple_data_active_2 = {
"StoreID": 103,
"StoreName": "Active Store C",
- "LogoURL": "http://logo.com/c.png"
+ "LogoURL": "http://logo.com/c.png",
}
# --- Test create_new_store ---
async def test_create_new_store_success_by_owner(self):
store_create_schema = StoreCreate(
- StoreName="New Awesome Store", OwnerUserID=self.test_owner_id,
- Description="The best store ever", LogoURL="http://new.logo/img.png",
- StoreStatus=StoreStatusEnum.ACTIVE, CreationDate=datetime.datetime.now(datetime.timezone.utc)
+ StoreName="New Awesome Store",
+ OwnerUserID=self.test_owner_id,
+ Description="The best store ever",
+ LogoURL="http://new.logo/img.png",
+ StoreStatus=StoreStatusEnum.ACTIVE,
+ CreationDate=datetime.datetime.now(datetime.timezone.utc),
)
self.mock_user_crud.get_user_by_id.return_value = self.sample_user_data
created_store_from_crud = {
- "StoreID": 201, **store_create_schema.model_dump(exclude={"CreationDate"}),
+ "StoreID": 201,
+ **store_create_schema.model_dump(exclude={"CreationDate"}),
"CreationDate": store_create_schema.CreationDate,
"LastUpdatedDate": store_create_schema.CreationDate,
- "StoreStatus": store_create_schema.StoreStatus.value
+ "StoreStatus": store_create_schema.StoreStatus.value,
}
self.mock_store_crud.create_store.return_value = created_store_from_crud
@@ -104,23 +112,27 @@ async def test_create_new_store_success_by_owner(self):
logo_url=store_create_schema.LogoURL,
store_status=store_create_schema.StoreStatus,
creation_date=store_create_schema.CreationDate,
- actor_id=self.test_owner_id
+ actor_id=self.test_owner_id,
)
- @patch('backend.app.services.store_service.logger') # Patch module-level logger
+ @patch("backend.app.services.store_service.logger") # Patch module-level logger
async def test_create_new_store_actor_not_owner_logs_warning(self, mock_logger):
store_create_schema = StoreCreate(
- StoreName="Another Store", OwnerUserID=self.test_owner_id, # User 1 owns
- Description="Desc", LogoURL=None,
- StoreStatus=StoreStatusEnum.ACTIVE, CreationDate=datetime.datetime.now(datetime.timezone.utc)
+ StoreName="Another Store",
+ OwnerUserID=self.test_owner_id, # User 1 owns
+ Description="Desc",
+ LogoURL=None,
+ StoreStatus=StoreStatusEnum.ACTIVE,
+ CreationDate=datetime.datetime.now(datetime.timezone.utc),
)
actor_creating = self.another_user_id # User 2 (another_user_id) is creating
self.mock_user_crud.get_user_by_id.return_value = self.sample_user_data
created_store_from_crud = {
- "StoreID": 202, **store_create_schema.model_dump(exclude={"CreationDate"}),
+ "StoreID": 202,
+ **store_create_schema.model_dump(exclude={"CreationDate"}),
"CreationDate": store_create_schema.CreationDate,
"LastUpdatedDate": store_create_schema.CreationDate,
- "StoreStatus": store_create_schema.StoreStatus.value
+ "StoreStatus": store_create_schema.StoreStatus.value,
}
self.mock_store_crud.create_store.return_value = created_store_from_crud
@@ -131,9 +143,12 @@ async def test_create_new_store_actor_not_owner_logs_warning(self, mock_logger):
async def test_create_new_store_owner_not_found(self):
store_create_schema = StoreCreate(
- StoreName="Store With No Owner", OwnerUserID=999,
- Description="Desc", LogoURL=None,
- StoreStatus=StoreStatusEnum.ACTIVE, CreationDate=datetime.datetime.now(datetime.timezone.utc)
+ StoreName="Store With No Owner",
+ OwnerUserID=999,
+ Description="Desc",
+ LogoURL=None,
+ StoreStatus=StoreStatusEnum.ACTIVE,
+ CreationDate=datetime.datetime.now(datetime.timezone.utc),
)
self.mock_user_crud.get_user_by_id.return_value = None
@@ -144,8 +159,12 @@ async def test_create_new_store_owner_not_found(self):
async def test_create_new_store_crud_fails(self):
store_create_schema = StoreCreate(
- StoreName="Store CRUD Fail", OwnerUserID=self.test_owner_id, Description="Desc",
- LogoURL=None, StoreStatus=StoreStatusEnum.ACTIVE, CreationDate=datetime.datetime.now(datetime.timezone.utc)
+ StoreName="Store CRUD Fail",
+ OwnerUserID=self.test_owner_id,
+ Description="Desc",
+ LogoURL=None,
+ StoreStatus=StoreStatusEnum.ACTIVE,
+ CreationDate=datetime.datetime.now(datetime.timezone.utc),
)
self.mock_user_crud.get_user_by_id.return_value = self.sample_user_data
self.mock_store_crud.create_store.return_value = None
@@ -205,10 +224,7 @@ async def test_merchant_get_store_by_id_not_found(self):
# --- Test user_get_stores_simple ---
async def test_user_get_stores_simple_with_pagination(self):
- paginated_active_stores_data = [
- self.sample_store_simple_data_active,
- self.sample_store_simple_data_active_2
- ]
+ paginated_active_stores_data = [self.sample_store_simple_data_active, self.sample_store_simple_data_active_2]
self.mock_store_crud.get_all_stores_page.return_value = paginated_active_stores_data
offset = 0
@@ -221,8 +237,7 @@ async def test_user_get_stores_simple_with_pagination(self):
self.assertEqual(response.Count, 2)
self.assertEqual(len(response.StoreList), 2)
self.mock_store_crud.get_all_stores_page.assert_called_once_with(
- conn=self.mock_db_conn, store_status=StoreStatusEnum.ACTIVE,
- offset=offset, limit=limit, actor_id=None
+ conn=self.mock_db_conn, store_status=StoreStatusEnum.ACTIVE, offset=offset, limit=limit, actor_id=None
)
self.mock_store_crud.get_all_stores.assert_not_called()
@@ -230,13 +245,11 @@ async def test_user_get_stores_simple_no_pagination_gets_all_active(self):
all_active_stores_data = [
self.sample_store_simple_data_active,
self.sample_store_simple_data_active_2,
- {"StoreID": 104, "StoreName": "Active Store D", "LogoURL": "d.png"}
+ {"StoreID": 104, "StoreName": "Active Store D", "LogoURL": "d.png"},
]
self.mock_store_crud.get_all_stores.return_value = all_active_stores_data
- response = await self.store_service.user_get_stores_simple(
- db=self.mock_db_conn, offset_and_limit=None
- )
+ response = await self.store_service.user_get_stores_simple(db=self.mock_db_conn, offset_and_limit=None)
self.assertIsInstance(response, StoreListSimpleResponse)
self.assertEqual(response.Count, 3)
@@ -247,15 +260,14 @@ async def test_user_get_stores_simple_no_pagination_gets_all_active(self):
self.mock_store_crud.get_all_stores_page.assert_not_called()
# --- Test get_stores_by_owner ---
- @patch('backend.app.services.store_service.logger') # Patch module-level logger
+ @patch("backend.app.services.store_service.logger") # Patch module-level logger
async def test_get_stores_by_owner_success_self(self, mock_logger):
stores_for_owner = [self.sample_store_data_dict_active, self.sample_store_data_dict_inactive]
self.mock_store_crud.get_stores_by_owner_user_id.return_value = stores_for_owner
# SUT calculates Count as len(stores_data)
response = await self.store_service.get_stores_by_owner(
- db=self.mock_db_conn, owner_user_id=self.test_owner_id, actor_id=self.test_owner_id,
- offset=0, limit=10
+ db=self.mock_db_conn, owner_user_id=self.test_owner_id, actor_id=self.test_owner_id, offset=0, limit=10
)
self.assertEqual(response.Count, 2)
self.assertEqual(len(response.StoreList), 2)
@@ -265,20 +277,19 @@ async def test_get_stores_by_owner_success_self(self, mock_logger):
for call_args in mock_logger.warning.call_args_list:
self.assertNotIn(f"ActorID {self.test_owner_id} (not owner)", call_args[0][0])
- @patch('backend.app.services.store_service.logger') # Patch module-level logger
+ @patch("backend.app.services.store_service.logger") # Patch module-level logger
async def test_get_stores_by_owner_actor_not_owner_logs_warning(self, mock_logger):
# SUT currently only logs a warning, doesn't raise PermissionDeniedException
self.mock_store_crud.get_stores_by_owner_user_id.return_value = []
await self.store_service.get_stores_by_owner(
- db=self.mock_db_conn, owner_user_id=self.test_owner_id,
- actor_id=self.another_user_id # Different actor
+ db=self.mock_db_conn, owner_user_id=self.test_owner_id, actor_id=self.another_user_id # Different actor
)
mock_logger.warning.assert_called_once_with(
f"ActorID {self.another_user_id} (not owner) attempting to access stores for OwnerUserID {self.test_owner_id}."
)
# --- Test get_all_stores_full (Admin) ---
- @patch('backend.app.services.store_service.logger')
+ @patch("backend.app.services.store_service.logger")
async def test_get_all_stores_full_with_pagination_logs_warning(self, mock_logger):
# SUT's permission check is a TODO, currently logs a warning
self.mock_store_crud.get_all_stores_page.return_value = [self.sample_store_data_dict_active]
@@ -293,10 +304,12 @@ async def test_get_all_stores_full_with_pagination_logs_warning(self, mock_logge
)
# --- Test update_store_info ---
- @patch('backend.app.services.store_service.logger')
+ @patch("backend.app.services.store_service.logger")
async def test_update_store_info_success_by_owner(self, mock_logger):
store_id = self.sample_store_data_dict_active["StoreID"] # Owner is self.test_owner_id
- update_schema = StoreUpdate(StoreName="Owner Updated Store Name", Description=None, LogoURL=None, StoreStatus=None)
+ update_schema = StoreUpdate(
+ StoreName="Owner Updated Store Name", Description=None, LogoURL=None, StoreStatus=None
+ )
self.mock_store_crud.get_store_by_id.return_value = self.sample_store_data_dict_active
updated_dict_from_crud = {**self.sample_store_data_dict_active, "StoreName": "Owner Updated Store Name"}
@@ -307,16 +320,19 @@ async def test_update_store_info_success_by_owner(self, mock_logger):
)
self.assertEqual(response.StoreName, "Owner Updated Store Name")
self.mock_store_crud.update_store.assert_called_once_with(
- conn=self.mock_db_conn, store_id=store_id, actor_id=self.test_owner_id,
- store_name="Owner Updated Store Name" # Only this field should be passed to CRUD
+ conn=self.mock_db_conn,
+ store_id=store_id,
+ actor_id=self.test_owner_id,
+ store_name="Owner Updated Store Name", # Only this field should be passed to CRUD
)
-
- @patch('backend.app.services.store_service.logger')
+ @patch("backend.app.services.store_service.logger")
async def test_update_store_info_not_owner_logs_warning(self, mock_logger):
# SUT currently logs warning and proceeds if actor is not owner (and not admin, which is not checked here)
store_id = self.sample_store_data_dict_active["StoreID"] # Owner is self.test_owner_id
- update_schema = StoreUpdate(StoreName="Attempt Update by NonOwner", Description=None, LogoURL=None, StoreStatus=None)
+ update_schema = StoreUpdate(
+ StoreName="Attempt Update by NonOwner", Description=None, LogoURL=None, StoreStatus=None
+ )
self.mock_store_crud.get_store_by_id.return_value = self.sample_store_data_dict_active
# Assume update still happens for this test of current SUT logic
@@ -324,8 +340,10 @@ async def test_update_store_info_not_owner_logs_warning(self, mock_logger):
self.mock_store_crud.update_store.return_value = updated_dict_from_crud
await self.store_service.update_store_info(
- db=self.mock_db_conn, store_id=store_id, store_in=update_schema,
- actor_id=self.another_user_id # User 2 (not owner)
+ db=self.mock_db_conn,
+ store_id=store_id,
+ store_in=update_schema,
+ actor_id=self.another_user_id, # User 2 (not owner)
)
mock_logger.warning.assert_any_call( # Check for the permission TODO log
f"ActorID {self.another_user_id} (not admin) attempting to update StoreID {store_id}. This should be handled at the endpoint level."
@@ -337,7 +355,7 @@ async def test_update_store_info_not_owner_logs_warning(self, mock_logger):
# So, this test should actually expect PermissionDeniedException.
# I will adjust it to reflect that, assuming _is_admin is not globally True.
- @patch('backend.app.services.store_service.logger')
+ @patch("backend.app.services.store_service.logger")
async def test_update_store_info_permission_denied_if_not_owner_and_not_admin_in_sut(self, mock_logger):
# This test assumes the SUT's TODO for permission check is resolved to actually
# deny if not owner and not admin (which means _is_admin would be called and return False)
@@ -345,13 +363,16 @@ async def test_update_store_info_permission_denied_if_not_owner_and_not_admin_in
update_schema = StoreUpdate(StoreName="Attempt Update", Description=None, LogoURL=None, StoreStatus=None)
updated_sample_store_data_dict = self.sample_store_data_dict_active.copy()
updated_sample_store_data_dict["StoreName"] = "Attempt Update"
- self.mock_store_crud.get_store_by_id.side_effect = [self.sample_store_data_dict_active,]
+ self.mock_store_crud.get_store_by_id.side_effect = [
+ self.sample_store_data_dict_active,
+ ]
self.mock_store_crud.update_store.return_value = updated_sample_store_data_dict
-
await self.store_service.update_store_info(
- db=self.mock_db_conn, store_id=store_id, store_in=update_schema,
- actor_id=self.another_user_id # User 2 (not owner)
+ db=self.mock_db_conn,
+ store_id=store_id,
+ store_in=update_schema,
+ actor_id=self.another_user_id, # User 2 (not owner)
)
# The warning for "not admin" might still be logged before the PermissionDeniedException
@@ -371,6 +392,5 @@ async def test_update_store_info_no_fields_in_request(self):
self.mock_store_crud.update_store.assert_not_called()
-
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main(verbosity=2)