diff --git a/fastapi_assets/core/__init__.py b/fastapi_assets/core/__init__.py index 03d97ae..b2d120c 100644 --- a/fastapi_assets/core/__init__.py +++ b/fastapi_assets/core/__init__.py @@ -1 +1,4 @@ """Module for core functionalities of FastAPI Assets.""" + +from fastapi_assets.core.base_validator import BaseValidator +from fastapi_assets.core.exceptions import ValidationError diff --git a/fastapi_assets/core/base_validator.py b/fastapi_assets/core/base_validator.py index 1824dbb..6b86ad6 100644 --- a/fastapi_assets/core/base_validator.py +++ b/fastapi_assets/core/base_validator.py @@ -1,9 +1,10 @@ """Base classes for FastAPI validation dependencies.""" import abc -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Union, List from fastapi import HTTPException from fastapi_assets.core.exceptions import ValidationError +import inspect class BaseValidator(abc.ABC): @@ -20,7 +21,7 @@ class BaseValidator(abc.ABC): from fastapi import Header from fastapi_assets.core.base_validator import BaseValidator, ValidationError class MyValidator(BaseValidator): - def _validate_logic(self, token: str) -> None: + def _validate(self, token: str) -> None: # This method is testable without FastAPI if not token.startswith("sk_"): # Raise the logic-level exception @@ -44,6 +45,7 @@ def __init__( *, status_code: int = 400, error_detail: Union[str, Callable[[Any], str]] = "Validation failed.", + validators: Optional[List[Callable]] = None, ): """ Initializes the base validator. @@ -54,9 +56,11 @@ def __init__( error_detail: The default error message. Can be a static string or a callable that takes the invalid value as its argument and returns a dynamic error string. + validators: Optional list of callables for custom validation logic. """ self._status_code = status_code self._error_detail = error_detail + self._custom_validators = validators or [] def _raise_error( self, @@ -65,17 +69,25 @@ def _raise_error( detail: Optional[Union[str, Callable[[Any], str]]] = None, ) -> None: """ - Helper method to raise a standardized HTTPException. + Raises a standardized HTTPException with resolved error detail. - It automatically resolves callable error details. + This helper method handles both static error strings and dynamic error + callables, automatically resolving them to a final error message before + raising the HTTPException. Args: - value (Optional[Any]): The value that failed validation. This is passed - to the error_detail callable, if it is one. - status_code (Optional[int]): A specific status code for this failure, - overriding the instance's default status_code. - detail (Optional[Union[str, Callable[[Any], str]]]): A specific error detail for this failure, - overriding the instance's default error_detail. + value: The value that failed validation. Passed to the error_detail + callable if it is callable. + status_code: A specific HTTP status code for this failure, overriding + the instance's default status_code. + detail: A specific error detail message (string or callable) for this + failure, overriding the instance's default error_detail. + + Returns: + None + + Raises: + HTTPException: Always raises with the resolved status code and detail. """ final_status_code = status_code if status_code is not None else self._status_code @@ -91,6 +103,59 @@ def _raise_error( raise HTTPException(status_code=final_status_code, detail=final_detail) + @abc.abstractmethod + async def _validate(self, value: Any) -> Any: + """ + Abstract method for pure validation logic. + + Subclasses MUST implement this method to perform the actual + validation. This method should raise `ValidationError` if + validation fails. + + Args: + value: The value to validate. + + Returns: + The validated value, which can be of any type depending on the validator. + """ + raise NotImplementedError( + "Subclasses of BaseValidator must implement the _validate method." + ) + + async def _validate_custom(self, value: Any) -> None: + """ + Executes all configured custom validator functions. + + Iterates through the list of custom validators, supporting both + synchronous and asynchronous validator functions. Catches exceptions + and converts them to ValidationError instances. + + Args: + value: The value to validate using custom validators. + + Returns: + None + + Raises: + ValidationError: If any validator raises an exception or explicitly + raises ValidationError. + """ + if self._custom_validators is None: + return + + for validator_func in self._custom_validators: + try: + if inspect.iscoroutinefunction(validator_func): + await validator_func(value) + else: + validator_func(value) + except ValidationError: + raise # Re-raise explicit validation errors + except Exception as e: + # Catch any other exception from the validator + detail = f"Custom validation failed. Error: {e}" + raise ValidationError(detail=detail, status_code=self._status_code) + @abc.abstractmethod def __call__(self, *args: Any, **kwargs: Any) -> Any: """ @@ -111,7 +176,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: class MyValidator(BaseValidator): - def _validate_logic(self, token: str) -> None: + async def _validate(self, token: str) -> None: # This method is testable without FastAPI if not token.startswith("sk_"): # Raise the logic-level exception @@ -120,7 +185,8 @@ def _validate_logic(self, token: str) -> None: def __call__(self, x_token: str = Header(...)): try: # 1. Run the pure validation logic - self._validate_logic(x_token) + await self._validate(x_token) + await self._validate_custom(x_token) except ValidationError as e: # 2. Catch logic error and raise HTTP error self._raise_error( diff --git a/fastapi_assets/metadata_validators/__init__.py b/fastapi_assets/metadata_validators/__init__.py deleted file mode 100644 index 5581d03..0000000 --- a/fastapi_assets/metadata_validators/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Module for metadata validation in FastAPI Assets.""" diff --git a/fastapi_assets/request_validators/__init__.py b/fastapi_assets/request_validators/__init__.py index 860dc57..d34a237 100644 --- a/fastapi_assets/request_validators/__init__.py +++ b/fastapi_assets/request_validators/__init__.py @@ -1 +1,5 @@ """Module for request validation in FastAPI Assets.""" + +from fastapi_assets.request_validators.header_validator import HeaderValidator +from fastapi_assets.request_validators.cookie_validator import CookieValidator +from fastapi_assets.request_validators.path_validator import PathValidator diff --git a/fastapi_assets/request_validators/cookie_validator.py b/fastapi_assets/request_validators/cookie_validator.py index 3b394e2..355611a 100644 --- a/fastapi_assets/request_validators/cookie_validator.py +++ b/fastapi_assets/request_validators/cookie_validator.py @@ -3,11 +3,8 @@ import inspect import re from typing import Any, Callable, Dict, List, Optional, Union - from fastapi import Request, status - -from fastapi_assets.core.base_validator import BaseValidator -from fastapi_assets.core.exceptions import ValidationError +from fastapi_assets.core import BaseValidator, ValidationError # Pre-built regex patterns for the `format` parameter @@ -20,7 +17,7 @@ } -class CookieAssert(BaseValidator): +class CookieValidator(BaseValidator): """ A class-based dependency to validate FastAPI Cookies with granular control. @@ -31,22 +28,27 @@ class CookieAssert(BaseValidator): Example: ```python from fastapi import FastAPI, Depends + from fastapi_assets.core import ValidationError + from fastapi_assets.request_validators import CookieValidator app = FastAPI() - def is_whitelisted(user_id: str) -> bool: - # Logic to check if user_id is in a whitelist - return user_id in {"user_1", "user_2"} + def is_whitelisted(user_id: str) -> None: + # Logic to check if user_id is in a whitelist. + # Custom validators must raise ValidationError on failure. + if user_id not in {"user_1", "user_2"}: + raise ValidationError("User is not whitelisted.") - validate_session = CookieAssert( - "session-id", # This is the required 'alias' + # Create validators that will extract cookies from the incoming request + validate_session = CookieValidator( + "session-id", # Cookie name to extract from request.cookies format="uuid4", on_required_error_detail="Invalid or missing session ID.", on_pattern_error_detail="Session ID must be a valid UUIDv4." ) - validate_user = CookieAssert( - "user-id", + validate_user = CookieValidator( + "user-id", # Cookie name to extract from request.cookies min_length=6, validators=[is_whitelisted], on_length_error_detail="User ID must be at least 6 characters.", @@ -55,10 +57,14 @@ def is_whitelisted(user_id: str) -> bool: @app.get("/items/") async def read_items(session_id: str = Depends(validate_session)): + # validate_session extracts the "session-id" cookie from the request, + # validates it, and returns the validated value return {"session_id": session_id} @app.get("/users/me") async def read_user(user_id: str = Depends(validate_user)): + # validate_user extracts the "user-id" cookie from the request, + # validates it (including length and custom validators), and returns it return {"user_id": user_id} ``` """ @@ -80,7 +86,7 @@ def __init__( regex: Optional[str] = None, pattern: Optional[str] = None, format: Optional[str] = None, - validators: Optional[List[Callable[[Any], bool]]] = None, + validators: Optional[List[Callable[[Any], Any]]] = None, # Granular Error Messages on_required_error_detail: str = "Cookie is required.", on_numeric_error_detail: str = "Cookie value must be a number.", @@ -89,7 +95,7 @@ def __init__( on_pattern_error_detail: str = "Cookie has an invalid format.", on_validator_error_detail: str = "Cookie failed custom validation.", # Base Error - status_code: int = status.HTTP_400_BAD_REQUEST, + status_code: int = 400, error_detail: str = "Cookie validation failed.", ) -> None: """ @@ -126,7 +132,7 @@ def __init__( ValueError: If `regex`/`pattern` and `format` are used simultaneously. ValueError: If an unknown `format` key is provided. """ - super().__init__(status_code=status_code, error_detail=error_detail) + super().__init__(status_code=status_code, error_detail=error_detail, validators=validators) # Store Core Parameters self.alias = alias @@ -158,21 +164,16 @@ def __init__( # Handle Regex/Pattern self.final_regex_str: Optional[str] = regex or pattern if self.final_regex_str and format: - raise ValueError( - "Cannot use 'regex'/'pattern' and 'format' simultaneously." - ) + raise ValueError("Cannot use 'regex'/'pattern' and 'format' simultaneously.") if format: if format not in PRE_BUILT_PATTERNS: raise ValueError( - f"Unknown format: '{format}'. " - f"Available: {list(PRE_BUILT_PATTERNS.keys())}" + f"Unknown format: '{format}'. Available: {list(PRE_BUILT_PATTERNS.keys())}" ) self.final_regex_str = PRE_BUILT_PATTERNS[format] self.final_regex: Optional[re.Pattern[str]] = ( - re.compile(self.final_regex_str) - if self.final_regex_str - else None + re.compile(self.final_regex_str) if self.final_regex_str else None ) def _validate_numeric(self, value: str) -> Optional[float]: @@ -180,7 +181,11 @@ def _validate_numeric(self, value: str) -> Optional[float]: Tries to convert value to float. Returns float or None. This check is only triggered if gt, ge, lt, or le are set. - + Args: + value (str): The cookie value to convert. + Returns: + Optional[float]: The converted float value, or None if numeric checks + are not applicable. Raises: ValidationError: If conversion to float fails. """ @@ -198,6 +203,10 @@ def _validate_comparison(self, value: float) -> None: """ Checks gt, ge, lt, le rules against a numeric value. + Args: + value (float): The numeric value to compare. + Returns: + None Raises: ValidationError: If any comparison fails. """ @@ -226,6 +235,12 @@ def _validate_length(self, value: str) -> None: """ Checks min_length and max_length rules. + Args: + value (str): The cookie value to check. + + Returns: + None + Raises: ValidationError: If length constraints fail. """ @@ -244,6 +259,10 @@ def _validate_length(self, value: str) -> None: def _validate_pattern(self, value: str) -> None: """ Checks regex/format pattern rule. + Args: + value (str): The cookie value to check. + Returns: + None Raises: ValidationError: If the regex pattern does not match. @@ -254,43 +273,7 @@ def _validate_pattern(self, value: str) -> None: status_code=status.HTTP_400_BAD_REQUEST, ) - async def _validate_custom(self, value: str) -> None: - """ - Runs all custom validator functions (sync or async). - - Raises: - ValidationError: If any function returns False or raises an Exception. - """ - if not self.custom_validators: - return - - for validator_func in self.custom_validators: - try: - is_valid = None - # Handle both sync and async validators - if inspect.iscoroutinefunction(validator_func): - is_valid = await validator_func(value) - else: - is_valid = validator_func(value) - - if not is_valid: - raise ValidationError( - detail=self.err_validator, - status_code=status.HTTP_400_BAD_REQUEST, - ) - except ValidationError: - # Re-raise our own validation errors - raise - except Exception as e: - # Validator function raising an error is a validation failure - raise ValidationError( - detail=f"{self.err_validator}: {e}", - status_code=status.HTTP_400_BAD_REQUEST, - ) - - async def _validate_logic( - self, cookie_value: Optional[str] - ) -> Union[float, str, None]: + async def _validate(self, cookie_value: Optional[str]) -> Union[float, str, None]: """ Pure validation logic (testable without FastAPI). @@ -355,11 +338,10 @@ async def __call__(self, request: Request) -> Union[float, str, None]: try: # Extract cookie value from request cookie_value: Optional[str] = request.cookies.get(self.alias) - # Run all validation logic - return await self._validate_logic(cookie_value) + return await self._validate(cookie_value) except ValidationError as e: # Convert validation error to HTTP exception self._raise_error(detail=e.detail, status_code=e.status_code) - return None # pragma: no cover \ No newline at end of file + return None # pragma: no cover diff --git a/fastapi_assets/request_validators/header_validator.py b/fastapi_assets/request_validators/header_validator.py index 7c783e6..4ae9dab 100644 --- a/fastapi_assets/request_validators/header_validator.py +++ b/fastapi_assets/request_validators/header_validator.py @@ -1,11 +1,13 @@ """HeaderValidator for validating HTTP headers in FastAPI.""" +from inspect import Signature, Parameter import re -from typing import Any, Callable, Dict, List, Optional, Union, Pattern -from fastapi_assets.core.base_validator import BaseValidator, ValidationError +from typing import Any, Callable, Dict, List, Optional, Pattern from fastapi import Header from fastapi.param_functions import _Unset +from fastapi_assets.core import BaseValidator, ValidationError + Undefined = _Unset @@ -22,62 +24,83 @@ class HeaderValidator(BaseValidator): r""" - A general-purpose dependency for validating HTTP request headers in FastAPI. + A dependency for validating HTTP headers with extended rules. - It extends FastAPI's built-in Header with additional validation capabilities - including pattern matching, format validation, allowed values, and custom validators. + Extends FastAPI's `Header` with pattern matching, format validation, + allowed values, and custom validators, providing granular error control. + Example: .. code-block:: python - from fastapi import FastAPI - from fastapi_assets.request_validators.header_validator import HeaderValidator + from fastapi import FastAPI, Depends + from fastapi_assets.request_validators import HeaderValidator app = FastAPI() - # Validate API key header with pattern + def is_valid_api_version(version: str) -> bool: + # Custom validators must raise ValidationError on failure + if version not in ["v1", "v2", "v3"]: + raise ValidationError(detail="Unsupported API version.") + + # Validate required API key header with a specific pattern. + # HeaderValidator extracts the header from the incoming request automatically. api_key_validator = HeaderValidator( alias="X-API-Key", pattern=r"^[a-zA-Z0-9]{32}$", - required=True, - on_error_detail="Invalid API key format" + on_required_error_detail="X-API-Key header is missing.", + on_pattern_error_detail="Invalid API key format." ) - # Validate authorization header with bearer token format + # Validate required authorization header with bearer token format. + # The header is extracted from request.headers by the Header() dependency. auth_validator = HeaderValidator( alias="Authorization", format="bearer_token", - required=True + on_required_error_detail="Authorization header is required.", + on_pattern_error_detail="Invalid Bearer token format." ) - # Validate custom header with allowed values + # Validate optional custom header with custom validator and a default. + # If not provided, the default value "v1" will be used. version_validator = HeaderValidator( alias="X-API-Version", - allowed_values=["v1", "v2", "v3"], - required=False, - default="v1" + default="v1", + validators=[is_valid_api_version], + on_custom_validator_error_detail="Invalid API version." ) @app.get("/secure") - def secure_endpoint( - api_key: str = api_key_validator, - auth: str = auth_validator, - version: str = version_validator + async def secure_endpoint( + api_key: str = Depends(api_key_validator), + auth: str = Depends(auth_validator), + version: str = Depends(version_validator) ): + # Each dependency automatically extracts and validates the corresponding header, + # returning the validated value to the endpoint return {"message": "Access granted", "version": version} + ``` """ def __init__( self, default: Any = Undefined, *, - required: Optional[bool] = True, alias: Optional[str] = None, convert_underscores: bool = True, pattern: Optional[str] = None, format: Optional[str] = None, allowed_values: Optional[List[str]] = None, - validator: Optional[Callable[[str], bool]] = None, + validators: Optional[List[Callable[[Any], Any]]] = None, + # Standard Header parameters title: Optional[str] = None, description: Optional[str] = None, + # Granular Error Messages + on_required_error_detail: str = "Required header is missing.", + on_pattern_error_detail: str = "Header has an invalid format.", + on_allowed_values_error_detail: str = "Header value is not allowed.", + on_custom_validator_error_detail: str = "Header failed custom validation.", + # Base Error + status_code: int = 400, + error_detail: str = "Header Validation Failed", **header_kwargs: Any, ) -> None: """ @@ -85,73 +108,71 @@ def __init__( Args: default (Any): The default value if the header is not provided. - required Optional[bool]: Explicitly set if the header is not required. - alias (Optional[str]): The alias of the header. This is the actual - header name (e.g., "X-API-Key"). + If not set (or set to `Undefined`), the header is required. + alias (Optional[str]): The alias of the header (the actual + header name, e.g., "X-API-Key"). convert_underscores (bool): If `True` (default), underscores in the variable name will be converted to hyphens in the header name. pattern (Optional[str]): A regex pattern string that the header value must match. - format (Optional[str]): A predefined format name (e.g., "uuid4", - "email", "bearer_token") that the header value must match. + format (Optional[str]): A predefined format name (e.g., "uuid4"). Cannot be used with `pattern`. allowed_values (Optional[List[str]]): A list of exact string values that are allowed for the header. - validator (Optional[Callable[[str], bool]]): A custom callable that - receives the header value and returns `True` if valid, or - `False` (or raises an Exception) if invalid. + validators (Optional[List[Callable]]): A list of custom validation + functions (sync or async) that receive the header value. title (Optional[str]): A title for the header in OpenAPI docs. description (Optional[str]): A description for the header in OpenAPI docs. - **header_kwargs (Any): Additional keyword arguments passed to the - parent `BaseValidator` (for error handling) and the - underlying `fastapi.Header` dependency. - Includes `status_code` (default 400) and `error_detail` - (default "Header Validation Failed") for error responses. - - Raises: - ValueError: If both `pattern` and `format` are specified, or if - an unknown `format` name is provided. + on_required_error_detail (str): Error message if header is missing. + on_pattern_error_detail (str): Error message if pattern/format fails. + on_allowed_values_error_detail (str): Error message if value not allowed. + on_custom_validator_error_detail (str): Error message if custom validator fails. + status_code (int): The default HTTP status code for validation errors. + error_detail (str): A generic fallback error message. + **header_kwargs (Any): Additional keyword arguments passed to FastAPI's Header(). """ - header_kwargs["status_code"] = header_kwargs.get("status_code", 400) - header_kwargs["error_detail"] = header_kwargs.get( - "error_detail", "Header Validation Failed" - ) - # Call super() with default error handling - super().__init__(**header_kwargs) - self._required = required + super().__init__(status_code=status_code, error_detail=error_detail, validators=validators) + + # Store "required" status based on the default value + self._is_required = default is Undefined # Store validation rules self._allowed_values = allowed_values - self._custom_validator = validator + self._custom_validators: list[Callable[..., Any]] = validators or [] - # Define type hints for attributes - self._pattern: Optional[Pattern[str]] = None - self._format_name: Optional[str] = None + # Store error messages + self._on_required_error_detail = on_required_error_detail + self._on_pattern_error_detail = on_pattern_error_detail + self._on_allowed_values_error_detail = on_allowed_values_error_detail + self._on_custom_validator_error_detail = on_custom_validator_error_detail + + self._pattern_str: Optional[str] = None + self._compiled_pattern: Optional[Pattern[str]] = None - # Handle pattern and format keys if pattern and format: - raise ValueError("Cannot specify both 'pattern' and 'format'. Choose one.") + raise ValueError("Cannot specify both 'pattern' and 'format'.") if format: - if format not in _FORMAT_PATTERNS: + self._pattern_str = _FORMAT_PATTERNS.get(format) + if self._pattern_str is None: raise ValueError( - f"Unknown format '{format}'. " - f"Available formats: {', '.join(_FORMAT_PATTERNS.keys())}" + f"Unknown format '{format}'. Available: {list(_FORMAT_PATTERNS.keys())}" ) - self._pattern = re.compile(_FORMAT_PATTERNS[format], re.IGNORECASE) - self._format_name = format + # Use IGNORECASE for format matching (e.g., UUIDs) + self._compiled_pattern = re.compile(self._pattern_str, re.IGNORECASE) elif pattern: - self._pattern = re.compile(pattern) - self._format_name = None - else: - self._pattern = None - self._format_name = None + self._pattern_str = pattern + self._compiled_pattern = re.compile(self._pattern_str) + + # We pass `None` if the header is required (default=Undefined) + # to bypass FastAPI's default 422, allowing our validator to run + # and use the custom error message. + fastapi_header_default = None if self._is_required else default - # Store the underlying FastAPI Header parameter self._header_param = Header( - default, + fastapi_header_default, alias=alias, convert_underscores=convert_underscores, title=title, @@ -159,57 +180,94 @@ def __init__( **header_kwargs, ) - def __call__(self, header_value: Optional[str] = None) -> Any: + # Dynamically set the __call__ method's signature so FastAPI recognizes + # the Header() dependency and injects the header value correctly. + # This is necessary because we need to pass self._header_param as the default, + # which isn't available at class definition time. + self._set_call_signature() + + def _set_call_signature(self) -> None: + """ + Sets the __call__ method's signature so FastAPI's dependency injection + system recognizes the Header() parameter and extracts the header value. + """ + + # Create a new signature with self and header_value parameters + # The header_value parameter has self._header_param as its default + # so FastAPI will use Header() to extract it from the request + sig = Signature( + [ + Parameter("self", Parameter.POSITIONAL_OR_KEYWORD), + Parameter( + "header_value", + Parameter.KEYWORD_ONLY, + default=self._header_param, + annotation=Optional[str], + ), + ] + ) + + # Set the signature on the underlying function, not the bound method + # Access the function object from the method + self.__call__.__func__.__signature__ = sig # type: ignore + + async def __call__(self, header_value: Optional[str] = None) -> Optional[str]: """ FastAPI dependency entry point for header validation. + FastAPI automatically injects the header value by recognizing the + Header() dependency in the method signature (set via _set_call_signature). + This method then validates the extracted header value and returns it + or raises an HTTPException with a custom error message. + Args: - header_value: The header value extracted from the request. + header_value: The header value extracted from the request by FastAPI. + Will be None if the header is not present. Returns: - The validated header value. + Optional[str]: The validated header value, or None if the header is + optional and not present. Raises: - HTTPException: If validation fails. + HTTPException: If validation fails, with the configured status code + and error message. """ - # If value is None, return a dependency that FastAPI will use - if header_value is None: - - def dependency(value: Optional[str] = self._header_param) -> Optional[str]: - return self._validate(value) - - return dependency - - # If value is provided (for testing), validate directly - return self._validate(header_value) + try: + # Validate the header value (which FastAPI injected via Header()) + return await self._validate(header_value) + except ValidationError as e: + # Convert our internal error to an HTTPException + self._raise_error(status_code=e.status_code, detail=str(e.detail)) + return None # pragma: no cover (unreachable) - def _validate(self, value: Optional[str]) -> Optional[str]: + async def _validate(self, value: Optional[str]) -> Optional[str]: """ - Runs all validation checks on the header value. + Runs all configured validation checks on the header value. + + Checks if the header is required, validates allowed values, pattern matching, + and custom validators in sequence. Args: - value: The header value to validate. + value: The header value to validate (None if not present). Returns: - The validated value. + Optional[str]: The validated header value, or None if optional and not present. Raises: - HTTPException: If any validation check fails. + ValidationError: If any validation check fails. """ - try: - self._validate_required(value) - except ValidationError as e: - self._raise_error(value=value, status_code=e.status_code, detail=str(e.detail)) - if value is None or value == "": - return value or "" - try: - self._validate_allowed_values(value) - self._validate_pattern(value) - self._validate_custom(value) + # 1. Check if required and not present + self._validate_required(value) - except ValidationError as e: - # Convert ValidationError to HTTPException - self._raise_error(value=value, status_code=e.status_code, detail=str(e.detail)) + # 2. If optional and not present, return None + # (It passed _validate_required, so if value is None, it's optional) + if value is None: + return None + + # 3. Run all other validations on the present value + self._validate_allowed_values(value) + self._validate_pattern(value) + await self._validate_custom(value) return value @@ -218,81 +276,59 @@ def _validate_required(self, value: Optional[str]) -> None: Checks if the header is present when required. Args: - value: The header value to check. + value: The header value (None if not present). + + Returns: + None Raises: - ValidationError: If the header is required but missing. + ValidationError: If the header is required but missing or empty. """ - if self._required and (value is None or value == ""): - detail = "Required header is missing." - if callable(detail): - detail_str = detail(value) - else: - detail_str = str(detail) - - raise ValidationError(detail=detail_str, status_code=400) + if self._is_required and (value is None or value == ""): + raise ValidationError( + detail=self._on_required_error_detail, status_code=self._status_code + ) def _validate_allowed_values(self, value: str) -> None: """ - Checks if the value is in the list of allowed values. + Checks if the header value is in the list of allowed values. Args: - value: The header value to check. + value: The header value to validate. + + Returns: + None Raises: - ValidationError: If the value is not in allowed_values. + ValidationError: If the value is not in the allowed list. """ if self._allowed_values is None: - return # No validation rule set + return if value not in self._allowed_values: detail = ( - f"Header value '{value}' is not allowed. " + f"{self._on_allowed_values_error_detail} " f"Allowed values are: {', '.join(self._allowed_values)}" ) - raise ValidationError(detail=detail, status_code=400) + raise ValidationError(detail=detail, status_code=self._status_code) def _validate_pattern(self, value: str) -> None: """ - Checks if the header value matches the required regex pattern. + Checks if the header value matches the configured regex pattern. Args: - value: The header value to check. - - Raises: - ValidationError: If the value doesn't match the pattern. - """ - if self._pattern is None: - return # No validation rule set - - if not self._pattern.match(value): - if self._format_name: - detail = f"Header value does not match the required format: '{self._format_name}'" - else: - detail = ( - f"Header value '{value}' does not match the required pattern: " - f"{self._pattern.pattern}" - ) - raise ValidationError(detail=detail, status_code=400) - - def _validate_custom(self, value: str) -> None: - """ - Runs a custom validation function if provided. + value: The header value to validate. - Args: - value: The header value to check. + Returns: + None Raises: - ValidationError: If the custom validator returns False or raises an exception. + ValidationError: If the value doesn't match the pattern. """ - if self._custom_validator is None: - return # No custom validator set + if self._compiled_pattern is None: + return - try: - if not self._custom_validator(value): - detail = f"Custom validation failed for header value '{value}'" - raise ValidationError(detail=detail, status_code=400) - except Exception as e: - # If the validator itself raises an exception, catch it - detail = f"Custom validation error: {str(e)}" - raise ValidationError(detail=detail, status_code=400) + if not self._compiled_pattern.fullmatch(value): + raise ValidationError( + detail=self._on_pattern_error_detail, status_code=self._status_code + ) diff --git a/fastapi_assets/request_validators/path_validator.py b/fastapi_assets/request_validators/path_validator.py index 363ec4e..135bd86 100644 --- a/fastapi_assets/request_validators/path_validator.py +++ b/fastapi_assets/request_validators/path_validator.py @@ -1,289 +1,194 @@ """Module providing the PathValidator for validating path parameters in FastAPI.""" -import re + from typing import Any, Callable, List, Optional, Union -from fastapi import Depends, Path -from fastapi_assets.core.base_validator import BaseValidator, ValidationError +from inspect import Signature, Parameter +from fastapi import Path +from fastapi_assets.core import BaseValidator, ValidationError class PathValidator(BaseValidator): r""" - A general-purpose dependency for validating path parameters in FastAPI. + A dependency factory for adding custom validation to FastAPI path parameters. + + This class extends the functionality of FastAPI's `Path()` by adding + support for `allowed_values` and custom `validators`. - It validates path parameters with additional constraints like allowed values, - regex patterns, string length checks, numeric bounds, and custom validators. + It acts as a factory: you instantiate it, and then *call* the + instance inside `Depends()` to get the actual dependency. + Example: .. code-block:: python - from fastapi import FastAPI - from fastapi_assets.path_validator import PathValidator + + from fastapi import FastAPI, Depends + from fastapi_assets.request_validators import PathValidator app = FastAPI() - # Create reusable validators + # 1. Create reusable validator *instances* item_id_validator = PathValidator( + "item_id", + _type=int, gt=0, lt=1000, - error_detail="Item ID must be between 1 and 999" ) username_validator = PathValidator( + "username", + _type=str, min_length=5, max_length=15, pattern=r"^[a-zA-Z0-9]+$", - error_detail="Username must be 5-15 alphanumeric characters" ) @app.get("/items/{item_id}") - def get_item(item_id: int = item_id_validator): + def get_item(item_id: int = Depends(item_id_validator())): return {"item_id": item_id} @app.get("/users/{username}") - def get_user(username: str = username_validator): + def get_user(username: str = Depends(username_validator())): return {"username": username} """ def __init__( self, + param_name: str, + _type: type, default: Any = ..., *, + # Custom validation rules allowed_values: Optional[List[Any]] = None, - pattern: Optional[str] = None, - min_length: Optional[int] = None, - max_length: Optional[int] = None, + validators: Optional[List[Callable[[Any], Any]]] = None, + on_custom_validator_error_detail: str = "Custom validation failed.", + # Standard Path() parameters + title: Optional[str] = None, + description: Optional[str] = None, gt: Optional[Union[int, float]] = None, lt: Optional[Union[int, float]] = None, ge: Optional[Union[int, float]] = None, le: Optional[Union[int, float]] = None, - validator: Optional[Callable[[Any], bool]] = None, - # Standard Path() parameters - title: Optional[str] = None, - description: Optional[str] = None, - alias: Optional[str] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + pattern: Optional[str] = None, deprecated: Optional[bool] = None, - **path_kwargs : Any + **path_kwargs: Any, ) -> None: """ - Initializes the PathValidator. + Initializes the PathValidator factory. Args: - default: Default value for the path parameter (usually ... for required). - allowed_values: List of allowed values for the parameter. - pattern: Regex pattern the parameter must match (for strings). - min_length: Minimum length for string parameters. - max_length: Maximum length for string parameters. - gt: Value must be greater than this (for numeric parameters). - lt: Value must be less than this (for numeric parameters). - ge: Value must be greater than or equal to this. - le: Value must be less than or equal to this. - validator: Custom validation function that takes the value and returns bool. + param_name: The exact name of the path parameter. + _type: The Python type for coercion (e.g., int, str, UUID). + default: Default value for the path parameter. + allowed_values: List of allowed values. + validators: List of custom validation functions. + on_custom_validator_error_detail: Error message for custom validators. title: Title for API documentation. description: Description for API documentation. - alias: Alternative parameter name. + gt: Value must be greater than this. + lt: Value must be less than this. + ge: Value must be greater than or equal to this. + le: Value must be less than or equal to this. + min_length: Minimum length for string parameters. + max_length: Maximum length for string parameters. + pattern: Regex pattern the parameter must match. deprecated: Whether the parameter is deprecated. **path_kwargs: Additional arguments passed to FastAPI's Path(). """ - path_kwargs["error_detail"] = path_kwargs.get("error_detail", "Path parameter validation failed.") - path_kwargs["status_code"] = path_kwargs.get("status_code", 400) - # Call super() with default error handling + path_kwargs.setdefault("error_detail", "Path parameter validation failed.") + path_kwargs.setdefault("status_code", 400) + super().__init__( - **path_kwargs + status_code=path_kwargs["status_code"], + error_detail=path_kwargs["error_detail"], + validators=validators, ) - # Store validation rules + + self._param_name = param_name + self._type = _type self._allowed_values = allowed_values - self._pattern = re.compile(pattern) if pattern else None - self._min_length = min_length - self._max_length = max_length - self._gt = gt - self._lt = lt - self._ge = ge - self._le = le - self._custom_validator = validator + self._on_custom_validator_error_detail = on_custom_validator_error_detail - # Store the underlying FastAPI Path parameter - # This preserves all standard Path() features (title, description, etc.) self._path_param = Path( default, title=title, description=description, - alias=alias, deprecated=deprecated, gt=gt, lt=lt, ge=ge, le=le, - **path_kwargs + min_length=min_length, + max_length=max_length, + pattern=pattern, + **path_kwargs, ) - def __call__(self, value: Any = None) -> Any: - """ - FastAPI dependency entry point for path validation. - - Args: - value: The path parameter value extracted from the URL. + def __call__(self) -> Callable[..., Any]: + """ + This is the factory method. + It generates and returns the dependency function + that FastAPI will use. + """ + + async def dependency(**kwargs: Any) -> Any: + path_value = kwargs[self._param_name] + try: + validated_value = await self._validate(path_value) + return validated_value + except ValidationError as e: + self._raise_error(path_value, status_code=e.status_code, detail=e.detail) + return None + + sig = Signature( + [ + Parameter( + self._param_name, + Parameter.KEYWORD_ONLY, + default=self._path_param, + annotation=self._type, + ) + ] + ) - Returns: - The validated path parameter value. + dependency.__signature__ = sig # type: ignore + return dependency - Raises: - HTTPException: If validation fails. + async def _validate(self, value: Any) -> Any: """ - # If value is None, it means FastAPI will inject the actual path parameter - # This happens because FastAPI handles the Path() dependency internally - if value is None: - # Return a dependency that FastAPI will use - async def dependency(param_value: Any = self._path_param) -> Any: - return self._validate(param_value) - return Depends(dependency) + Runs all validation checks on the path parameter value. - # If value is provided (for testing), validate directly - return self._validate(value) - - def _validate(self, value: Any) -> Any: - """ - Runs all validation checks on the parameter value. + Executes allowed values checking and custom validator checking in sequence. Args: value: The path parameter value to validate. Returns: - The validated value. + Any: The validated value (unchanged if validation passes). Raises: - HTTPException: If any validation check fails. + ValidationError: If any validation check fails. """ - try: - self._validate_allowed_values(value) - self._validate_pattern(value) - self._validate_length(value) - self._validate_numeric_bounds(value) - self._validate_custom(value) - except ValidationError as e: - # Convert ValidationError to HTTPException - self._raise_error( - status_code=e.status_code, - detail=str(e.detail) - ) - + self._validate_allowed_values(value) + await self._validate_custom(value) return value def _validate_allowed_values(self, value: Any) -> None: """ - Checks if the value is in the list of allowed values. + Checks if the path parameter value is in the list of allowed values. Args: - value: The parameter value to check. + value: The value to validate. + + Returns: + None Raises: - ValidationError: If the value is not in allowed_values. + ValidationError: If the value is not in the allowed values list. """ if self._allowed_values is None: - return # No validation rule set + return if value not in self._allowed_values: - detail = ( - f"Value '{value}' is not allowed. " - f"Allowed values are: {', '.join(map(str, self._allowed_values))}" - ) + allowed_str = ", ".join(map(str, self._allowed_values)) + detail = f"Value '{value}' is not allowed. Allowed values are: {allowed_str}" raise ValidationError(detail=detail, status_code=400) - - def _validate_pattern(self, value: Any) -> None: - """ - Checks if the string value matches the required regex pattern. - - Args: - value: The parameter value to check. - - Raises: - ValidationError: If the value doesn't match the pattern. - """ - if self._pattern is None: - return # No validation rule set - - if not isinstance(value, str): - return # Pattern validation only applies to strings - - if not self._pattern.match(value): - detail = ( - f"Value '{value}' does not match the required pattern: " - f"{self._pattern.pattern}" - ) - raise ValidationError(detail=detail, status_code=400) - - def _validate_length(self, value: Any) -> None: - """ - Checks if the string length is within the specified bounds. - - Args: - value: The parameter value to check. - - Raises: - ValidationError: If the length is out of bounds. - """ - if not isinstance(value, str): - return # Length validation only applies to strings - - value_len = len(value) - - if self._min_length is not None and value_len < self._min_length: - detail = ( - f"Value '{value}' is too short. " - f"Minimum length is {self._min_length} characters." - ) - raise ValidationError(detail=detail, status_code=400) - - if self._max_length is not None and value_len > self._max_length: - detail = ( - f"Value '{value}' is too long. " - f"Maximum length is {self._max_length} characters." - ) - raise ValidationError(detail=detail, status_code=400) - - def _validate_numeric_bounds(self, value: Any) -> None: - """ - Checks if numeric values satisfy gt, lt, ge, le constraints. - - Args: - value: The parameter value to check. - - Raises: - ValidationError: If the value is out of the specified bounds. - """ - if not isinstance(value, (int, float)): - return # Numeric validation only applies to numbers - - if self._gt is not None and value <= self._gt: - detail = f"Value must be greater than {self._gt}" - raise ValidationError(detail=detail, status_code=400) - - if self._lt is not None and value >= self._lt: - detail = f"Value must be less than {self._lt}" - raise ValidationError(detail=detail, status_code=400) - - if self._ge is not None and value < self._ge: - detail = f"Value must be greater than or equal to {self._ge}" - raise ValidationError(detail=detail, status_code=400) - - if self._le is not None and value > self._le: - detail = f"Value must be less than or equal to {self._le}" - raise ValidationError(detail=detail, status_code=400) - - def _validate_custom(self, value: Any) -> None: - """ - Runs a custom validation function if provided. - - Args: - value: The parameter value to check. - - Raises: - ValidationError: If the custom validator returns False or raises an exception. - """ - if self._custom_validator is None: - return # No custom validator set - - try: - if not self._custom_validator(value): - detail = f"Custom validation failed for value '{value}'" - raise ValidationError(detail=detail, status_code=400) - except Exception as e: - # If the validator itself raises an exception, catch it - detail = f"Custom validation error: {str(e)}" - raise ValidationError(detail=detail, status_code=400) \ No newline at end of file diff --git a/fastapi_assets/validators/__init__.py b/fastapi_assets/validators/__init__.py index 264e724..9854383 100644 --- a/fastapi_assets/validators/__init__.py +++ b/fastapi_assets/validators/__init__.py @@ -1 +1,5 @@ """Module for file based validation in FastAPI Assets.""" + +from fastapi_assets.validators.csv_validator import CSVValidator +from fastapi_assets.validators.file_validator import FileValidator +from fastapi_assets.validators.image_validator import ImageValidator diff --git a/fastapi_assets/validators/csv_validator.py b/fastapi_assets/validators/csv_validator.py index cc0f9f2..c8860ee 100644 --- a/fastapi_assets/validators/csv_validator.py +++ b/fastapi_assets/validators/csv_validator.py @@ -5,7 +5,7 @@ from starlette.datastructures import UploadFile as StarletteUploadFile # Import from base file_validator module -from fastapi_assets.core.base_validator import ValidationError +from fastapi_assets.core import ValidationError from fastapi_assets.validators.file_validator import ( FileValidator, ) @@ -29,7 +29,7 @@ class CSVValidator(FileValidator): .. code-block:: python from fastapi import FastAPI, UploadFile, Depends - from fastapi_assets.validators.csv_validator import CSVValidator + from fastapi_assets.validators import CSVValidator app = FastAPI() @@ -121,7 +121,7 @@ def __init__( self._row_error_detail = on_row_error_detail self._parse_error_detail = on_parse_error_detail - async def __call__(self, file: UploadFile = File(...), **kwargs: Any) -> StarletteUploadFile: + async def __call__(self, file: UploadFile = File(...)) -> StarletteUploadFile: """ FastAPI dependency entry point for CSV validation. @@ -137,35 +137,66 @@ async def __call__(self, file: UploadFile = File(...), **kwargs: Any) -> Starlet # Run all parent validations (size, content-type, filename) # This will also rewind the file (await file.seek(0)) try: - await super().__call__(file, **kwargs) - except ValidationError as e: - # Re-raise parent's validation error - self._raise_error(status_code=e.status_code, detail=str(e.detail)) - - # File is validated by parent and rewound. Start CSV checks. - try: - # Check encoding if specified - await self._validate_encoding(file) - await file.seek(0) # Rewind after encoding check - - # Check columns and row counts - await self._validate_csv_structure(file) - + await self._validate(file=file) except ValidationError as e: await file.close() + # Re-raise parent's validation error self._raise_error(status_code=e.status_code, detail=str(e.detail)) except Exception as e: # Catch pandas errors (e.g., CParserError, UnicodeDecodeError) await file.close() detail = self._parse_error_detail or f"Failed to parse CSV file: {e}" self._raise_error(status_code=400, detail=detail) + try: + # CRITICAL: Rewind the file AGAIN so the endpoint can read it. + await file.seek(0) + return file + except Exception as e: + await file.close() + self._raise_error( + status_code=e.status_code if hasattr(e, "status_code") else 400, + detail="File could not be rewound after validation.", + ) + return None # type: ignore # pragma: no cover - # CRITICAL: Rewind the file AGAIN so the endpoint can read it. + async def _validate(self, file: UploadFile) -> None: + """ + Runs all CSV-specific validation checks on the uploaded file. + + This method orchestrates the validation pipeline: first calls parent + FileValidator validations, then validates encoding, and finally + validates the CSV structure (columns and rows). + + Args: + file: The uploaded file to validate. + + Returns: + None + + Raises: + ValidationError: If any validation check fails. + """ + await super()._validate(file) + await self._validate_encoding(file) await file.seek(0) - return file + await self._validate_csv_structure(file) async def _validate_encoding(self, file: UploadFile) -> None: - """Checks if the file encoding matches one of the allowed encodings.""" + """ + Validates that the file encoding matches one of the allowed encodings. + + Reads a small chunk of the file and attempts to decode it with each + specified encoding. If none match, raises a ValidationError. + + Args: + file: The uploaded file to validate. + + Returns: + None + + Raises: + ValidationError: If the file encoding is not one of the allowed encodings. + """ if not self._encoding: return # No check needed @@ -190,13 +221,22 @@ async def _validate_encoding(self, file: UploadFile) -> None: raise ValidationError(detail=str(detail), status_code=400) def _check_columns(self, header: List[str]) -> None: - """Validates the CSV header against column rules. + """ + Validates the CSV header against configured column rules. + + Checks for required columns, disallowed columns, and exact column matching + based on the validator's configuration. Raises ValidationError if any + rule is violated. + Args: - header: List of column names from the CSV header. + header: List of column names extracted from the CSV header row. + Returns: None + Raises: - ValidationError: If any column validation fails. + ValidationError: If exact columns don't match, required columns are missing, + or disallowed columns are present. """ header_set = set(header) @@ -229,7 +269,22 @@ def _check_columns(self, header: List[str]) -> None: raise ValidationError(detail=str(detail), status_code=400) def _check_row_counts(self, total_rows: int) -> None: - """Validates the total row count against min/max rules.""" + """ + Validates that the CSV row count meets min/max constraints. + + Compares the actual number of data rows against the configured + minimum and maximum row limits. Raises ValidationError if constraints + are violated. + + Args: + total_rows: The total number of data rows (excluding header) in the CSV. + + Returns: + None + + Raises: + ValidationError: If the row count is below minimum or exceeds maximum. + """ if self._min_rows is not None and total_rows < self._min_rows: detail = self._row_error_detail or ( f"File does not meet minimum required rows: {self._min_rows}. Found: {total_rows}." @@ -244,16 +299,21 @@ def _check_row_counts(self, total_rows: int) -> None: async def _validate_csv_structure(self, file: UploadFile) -> None: """ - Validates the CSV columns and row counts using pandas. + Validates the CSV structure including columns and row counts using pandas. + + This method handles both efficient bounded reads (checking only necessary rows) + and full file reads depending on the `header_check_only` setting. It validates + column constraints first, then row count constraints if applicable. - Uses either an efficient bounded read (header_check_only=True) - or a full stream (header_check_only=False) for row counts. Args: file: The uploaded CSV file to validate. + Returns: None + Raises: - ValidationError: If any structure validation fails. + ValidationError: If column validation fails, row count validation fails, + or if the file cannot be parsed as valid CSV. """ # file.file is the underlying SpooledTemporaryFile file_obj = file.file diff --git a/fastapi_assets/validators/file_validator.py b/fastapi_assets/validators/file_validator.py index 0cbd06d..178c6d8 100644 --- a/fastapi_assets/validators/file_validator.py +++ b/fastapi_assets/validators/file_validator.py @@ -33,7 +33,7 @@ class FileValidator(BaseValidator): .. code-block:: python from fastapi import FastAPI, UploadFile, Depends - from fastapi_assets.validators.file_validator import FileValidator + from fastapi_assets.validators import FileValidator app = FastAPI() @@ -64,6 +64,7 @@ def __init__( on_size_error_detail: Optional[Union[str, Callable[[Any], str]]] = None, on_type_error_detail: Optional[Union[str, Callable[[Any], str]]] = None, on_filename_error_detail: Optional[Union[str, Callable[[Any], str]]] = None, + validators: Optional[List[Callable]] = None, **kwargs: Any, ): """ @@ -83,7 +84,11 @@ def __init__( # by the specific error handlers. kwargs["error_detail"] = kwargs.get("error_detail", "File validation failed.") kwargs["status_code"] = 400 - super().__init__(**kwargs) + super().__init__( + error_detail=kwargs["error_detail"], + status_code=kwargs["status_code"], + validators=validators, + ) # Parse sizes once self._max_size = ( @@ -106,28 +111,31 @@ def __init__( self._type_error_detail = on_type_error_detail self._filename_error_detail = on_filename_error_detail - async def __call__(self, file: UploadFile = File(...), **kwargs: Any) -> StarletteUploadFile: + async def __call__(self, file: UploadFile = File(...)) -> StarletteUploadFile: """ FastAPI dependency entry point for file validation. + + Runs all configured validation checks on the uploaded file (content type, + filename, size, and custom validators) and returns the validated file + after rewinding it so the endpoint can read it from the beginning. + Args: file: The uploaded file to validate. + Returns: - The validated UploadFile object. + StarletteUploadFile: The validated UploadFile object, rewound to the start. + Raises: - HTTPException: If validation fails. + HTTPException: If any validation check fails. """ try: - self._validate_content_type(file) - self._validate_filename(file) - await self._validate_size(file) - # Additional validations can be added here + await self._validate(file=file) except ValidationError as e: # Our custom validation exception, convert to HTTPException self._raise_error(status_code=e.status_code, detail=str(e.detail)) except Exception as e: # Catch any other unexpected error during validation await file.close() - print("Raising HTTPException for unexpected error:", e) self._raise_error( status_code=400, detail="An unexpected error occurred during file validation.", @@ -138,14 +146,42 @@ async def __call__(self, file: UploadFile = File(...), **kwargs: Any) -> Starlet await file.seek(0) return file + async def _validate(self, file: UploadFile) -> None: + """ + Runs all file validation checks in sequence. + + Executes content-type, filename, size, and custom validator checks + on the uploaded file. + + Args: + file: The uploaded file to validate. + + Returns: + None + + Raises: + ValidationError: If any validation check fails. + """ + self._validate_content_type(file) + self._validate_filename(file) + await self._validate_size(file=file) + await self._validate_custom(value=file) + def _validate_content_type(self, file: UploadFile) -> None: - """Checks the file's MIME type. + """ + Validates that the file's MIME type is in the allowed list. + + Checks the file's Content-Type against the configured allowed types, + supporting wildcard patterns (e.g., "image/*"). + Args: file: The uploaded file to validate. + Returns: None + Raises: - ValidationError: If the content type is not allowed. + ValidationError: If the content type is not in the allowed list. """ if not self._content_types: return # No validation rule set @@ -156,12 +192,22 @@ def _validate_content_type(self, file: UploadFile) -> None: f"File has an unsupported media type: '{file_type}'. " f"Allowed types are: {', '.join(self._content_types)}" ) - print("Raising ValidationError for content type:", detail) # Use 415 for Unsupported Media Type raise ValidationError(detail=str(detail), status_code=415) def _validate_filename(self, file: UploadFile) -> None: - """Checks the file's name against a regex pattern.""" + """ + Validates that the filename matches the configured regex pattern. + + Args: + file: The uploaded file to validate. + + Returns: + None + + Raises: + ValidationError: If the filename doesn't match the pattern or is missing. + """ if not self._filename_regex: return # No validation rule set @@ -173,8 +219,19 @@ def _validate_filename(self, file: UploadFile) -> None: async def _validate_size(self, file: UploadFile) -> None: """ - Checks file size, using Content-Length if available, - or streaming and counting if not. + Validates that the file size is within configured bounds. + + Uses the Content-Length header if available for efficiency, otherwise + streams the file to determine its actual size. + + Args: + file: The uploaded file to validate. + + Returns: + None + + Raises: + ValidationError: If the file size exceeds max_size or is below min_size. """ if self._max_size is None and self._min_size is None: return # No validation rule set diff --git a/fastapi_assets/validators/image_validator.py b/fastapi_assets/validators/image_validator.py index af0c59d..8a8a4aa 100644 --- a/fastapi_assets/validators/image_validator.py +++ b/fastapi_assets/validators/image_validator.py @@ -3,7 +3,7 @@ """ from typing import Any, Callable, List, Optional, Union -from fastapi_assets.core.base_validator import ValidationError +from fastapi_assets.core import ValidationError from fastapi import File, UploadFile from starlette.datastructures import UploadFile as StarletteUploadFile from fastapi_assets.validators.file_validator import FileValidator @@ -41,7 +41,7 @@ class ImageValidator(FileValidator): .. code-block:: python from fastapi import FastAPI, UploadFile, Depends - from fastapi_assets.validators.image_validator import ImageValidator + from fastapi_assets.validators import ImageValidator app = FastAPI() @@ -149,7 +149,7 @@ def _parse_aspect_ratios(self, ratios: Optional[List[str]]) -> Optional[List[flo ) return parsed - async def __call__(self, file: UploadFile = File(...), **kwargs: Any) -> StarletteUploadFile: + async def __call__(self, file: UploadFile = File(...)) -> StarletteUploadFile: """ FastAPI dependency entry point for image validation. Args: @@ -157,24 +157,9 @@ async def __call__(self, file: UploadFile = File(...), **kwargs: Any) -> Starlet Returns: The validated UploadFile object. """ - # Run all parent validations (size, content-type, filename) - # This will also rewind the file stream to position 0. - try: - await super().__call__(file, **kwargs) - except ValidationError as e: - # Re-raise the exception from the parent - self._raise_error(status_code=e.status_code, detail=str(e.detail)) - # Run image-specific validations using Pillow - img = None try: - # `file.file` is a SpooledTemporaryFile, which Image.open can read. - img = Image.open(file.file) - - # Perform content-based validations - self._validate_format(img) - self._validate_resolution(img) - self._validate_aspect_ratio(img) + await self._validate(file) except (UnidentifiedImageError, IOError) as e: # Pillow couldn't identify it as an image, or file is corrupt @@ -196,23 +181,46 @@ async def __call__(self, file: UploadFile = File(...), **kwargs: Any) -> Starlet detail=f"An unexpected error occurred during image validation: {e}", ) finally: - if img: - img.close() - # CRITICAL: Rewind the file stream *again* so the endpoint # can read it after Pillow is done. await file.seek(0) return file + async def _validate(self, file: UploadFile) -> None: + """ + Runs all image validation checks using PIL/Pillow. + + Opens the image file with Pillow and validates its format, resolution, + and aspect ratio against the configured constraints. + + Args: + file: The uploaded image file to validate. + + Returns: + None + + Raises: + ValidationError: If any image validation check fails. + """ + await super()._validate(file) + img = Image.open(file.file) + self._validate_format(img) + self._validate_resolution(img) + self._validate_aspect_ratio(img) + def _validate_format(self, img: Image.Image) -> None: - """Checks the image's actual format (e.g., 'JPEG', 'PNG'). + """ + Validates that the image format is in the allowed list. + Args: img: The opened PIL Image object. + Returns: None + Raises: - ValidationError: If the image format is not allowed. + ValidationError: If the image format is not in the allowed list. """ if not self._allowed_formats: return # No rule set @@ -227,13 +235,19 @@ def _validate_format(self, img: Image.Image) -> None: raise ValidationError(detail=str(detail), status_code=415) def _validate_resolution(self, img: Image.Image) -> None: - """Checks image dimensions against min, max, and exact constraints. + """ + Validates the image's resolution against min, max, and exact constraints. + + Checks that the image width and height meet the configured constraints. + Args: img: The opened PIL Image object. + Returns: None + Raises: - ValidationError: If the image resolution is out of bounds + ValidationError: If the image resolution does not meet constraints. """ if not (self._min_resolution or self._max_resolution or self._exact_resolution): return # No resolution rules set @@ -265,13 +279,20 @@ def _validate_resolution(self, img: Image.Image) -> None: raise ValidationError(detail=str(detail), status_code=400) def _validate_aspect_ratio(self, img: Image.Image) -> None: - """Checks the image's aspect ratio against a list of allowed ratios. + """ + Validates that the image's aspect ratio is in the allowed list. + + Compares the actual aspect ratio against configured allowed ratios with + a tolerance for floating-point precision. + Args: img: The opened PIL Image object. + Returns: None + Raises: - ValidationError: If the image's aspect ratio is not allowed. + ValidationError: If the image aspect ratio is not allowed or cannot be calculated. """ if not self._aspect_ratios: return # No rule set diff --git a/tests/test_base_validator.py b/tests/test_base_validator.py index 7e0a88c..1cd65f2 100644 --- a/tests/test_base_validator.py +++ b/tests/test_base_validator.py @@ -7,7 +7,7 @@ from fastapi_assets.core.exceptions import ValidationError -# --- Test Setup --- +# Test Setup class _MockValidator(BaseValidator): @@ -23,6 +23,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: """Minimal implementation to satisfy the abstract contract.""" pass + def _validate(self, value: Any) -> None: + """Minimal implementation to satisfy the abstract contract.""" + pass + def public_raise_error( self, value: Any, @@ -35,7 +39,7 @@ def public_raise_error( self._raise_error(value=value, status_code=status_code, detail=detail) -# --- Test Cases --- +# Test Cases def test_cannot_instantiate_abstract_class() -> None: diff --git a/tests/test_cookie_validator.py b/tests/test_cookie_validator.py index 664416d..8a7be6f 100644 --- a/tests/test_cookie_validator.py +++ b/tests/test_cookie_validator.py @@ -10,19 +10,19 @@ # Import the class to be tested try: - from fastapi_assets.request_validators.cookie_validator import CookieAssert + from fastapi_assets.request_validators.cookie_validator import CookieValidator as CookieAssert from fastapi_assets.core.exceptions import ValidationError except ImportError as e: pytest.skip(f"Could not import CookieAssert: {e}", allow_module_level=True) -# --- Test Application Setup --- +# Test Application Setup # Define validators once, as they would be in a real app validate_required_uuid = CookieAssert( alias="session-id", format="uuid4", on_required_error_detail="Session is required.", - on_pattern_error_detail="Invalid session format." + on_pattern_error_detail="Invalid session format.", ) validate_optional_gt10 = CookieAssert( @@ -31,33 +31,29 @@ default=None, gt=10, on_comparison_error_detail="Tracker must be > 10.", - on_numeric_error_detail="Tracker must be a number." + on_numeric_error_detail="Tracker must be a number.", ) validate_length_5 = CookieAssert( - alias="code", - min_length=5, - max_length=5, - on_length_error_detail="Code must be 5 chars." + alias="code", min_length=5, max_length=5, on_length_error_detail="Code must be 5 chars." ) + def _custom_check(val: str): """A sample custom validator function""" if val not in ["admin", "user"]: - raise ValueError("Role is invalid") - return True + raise ValidationError("Role is invalid") + validate_custom_role = CookieAssert( - alias="role", - validators=[_custom_check], - on_validator_error_detail="Invalid role." + alias="role", validators=[_custom_check], on_validator_error_detail="Invalid role." ) # Additional validators for extended tests validate_bearer_token = CookieAssert( alias="auth-token", format="bearer_token", - on_pattern_error_detail="Invalid bearer token format." + on_pattern_error_detail="Invalid bearer token format.", ) validate_numeric_ge_le = CookieAssert( @@ -65,71 +61,82 @@ def _custom_check(val: str): ge=0, le=100, on_comparison_error_detail="Score must be between 0 and 100.", - on_numeric_error_detail="Score must be a number." + on_numeric_error_detail="Score must be a number.", ) + def _async_validator(val: str): """An async custom validator function""" # This is actually a sync function that will be called # The CookieAssert supports both sync and async validators - return val.startswith("valid_") + if not val.startswith("valid_"): + raise ValidationError("Token must start with 'valid_'.") + validate_async_custom = CookieAssert( alias="async-token", validators=[_async_validator], - on_validator_error_detail="Token must start with 'valid_'." + on_validator_error_detail="Token must start with 'valid_'.", ) validate_email_format = CookieAssert( - alias="user-email", - format="email", - on_pattern_error_detail="Invalid email format." + alias="user-email", format="email", on_pattern_error_detail="Invalid email format." ) # Create a minimal FastAPI app for testing app = FastAPI() + @app.get("/test-required") async def get_required(session: str = Depends(validate_required_uuid)): """Test endpoint for a required, formatted cookie.""" return {"session": session} + @app.get("/test-optional") async def get_optional(tracker: Optional[float] = Depends(validate_optional_gt10)): """Test endpoint for an optional, numeric cookie.""" return {"tracker": tracker} + @app.get("/test-length") async def get_length(code: str = Depends(validate_length_5)): """Test endpoint for a length-constrained cookie.""" return {"code": code} + @app.get("/test-custom") async def get_custom(role: str = Depends(validate_custom_role)): """Test endpoint for a custom-validated cookie.""" return {"role": role} + @app.get("/test-bearer") async def get_bearer(token: str = Depends(validate_bearer_token)): """Test endpoint for bearer token format.""" return {"token": token} + @app.get("/test-ge-le") async def get_numeric_range(score: float = Depends(validate_numeric_ge_le)): """Test endpoint for numeric range validation.""" return {"score": score} + @app.get("/test-async-custom") async def get_async_custom(token: str = Depends(validate_async_custom)): """Test endpoint for async custom validator.""" return {"token": token} + @app.get("/test-email") async def get_email(email: str = Depends(validate_email_format)): """Test endpoint for email format.""" return {"email": email} -# --- Pytest Fixtures --- + +# Pytest Fixtures + @pytest.fixture(scope="module") def anyio_backend(): @@ -145,13 +152,12 @@ async def client(anyio_backend): Pytest fixture to create an AsyncClient for the test app. Depends on the 'anyio_backend' fixture. """ - async with AsyncClient( - transport=ASGITransport(app=app), - base_url="http://test" - ) as ac: + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: yield ac -# --- Test Cases --- + +# Test Cases + # REQUIRED COOKIE TESTS @pytest.mark.anyio @@ -161,6 +167,7 @@ async def test_required_cookie_missing(client: AsyncClient): assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.json() == {"detail": "Session is required."} + @pytest.mark.anyio async def test_required_cookie_invalid_format(client: AsyncClient): """Tests that a required cookie fails on invalid format.""" @@ -169,6 +176,7 @@ async def test_required_cookie_invalid_format(client: AsyncClient): assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.json() == {"detail": "Invalid session format."} + @pytest.mark.anyio async def test_required_cookie_valid(client: AsyncClient): """Tests that a required cookie passes with valid format.""" @@ -178,6 +186,7 @@ async def test_required_cookie_valid(client: AsyncClient): assert response.status_code == status.HTTP_200_OK assert response.json() == {"session": valid_uuid} + # OPTIONAL COOKIE TESTS @pytest.mark.anyio async def test_optional_cookie_missing(client: AsyncClient): @@ -186,6 +195,7 @@ async def test_optional_cookie_missing(client: AsyncClient): assert response.status_code == status.HTTP_200_OK assert response.json() == {"tracker": None} + @pytest.mark.anyio async def test_optional_cookie_invalid_comparison(client: AsyncClient): """Tests that an optional cookie fails numeric comparison.""" @@ -194,6 +204,7 @@ async def test_optional_cookie_invalid_comparison(client: AsyncClient): assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.json() == {"detail": "Tracker must be > 10."} + @pytest.mark.anyio async def test_optional_cookie_invalid_numeric(client: AsyncClient): """Tests that a numeric cookie fails non-numeric values.""" @@ -202,6 +213,7 @@ async def test_optional_cookie_invalid_numeric(client: AsyncClient): assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.json() == {"detail": "Tracker must be a number."} + @pytest.mark.anyio async def test_optional_cookie_valid(client: AsyncClient): """Tests that an optional cookie passes with a valid value.""" @@ -210,6 +222,7 @@ async def test_optional_cookie_valid(client: AsyncClient): assert response.status_code == status.HTTP_200_OK assert response.json() == {"tracker": 100.0} + @pytest.mark.anyio async def test_optional_cookie_boundary_gt(client: AsyncClient): """Tests boundary condition for gt comparison (10 is not > 10).""" @@ -218,6 +231,7 @@ async def test_optional_cookie_boundary_gt(client: AsyncClient): assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.json() == {"detail": "Tracker must be > 10."} + @pytest.mark.anyio async def test_optional_cookie_boundary_gt_valid(client: AsyncClient): """Tests boundary condition for gt comparison (10.1 is > 10).""" @@ -226,6 +240,7 @@ async def test_optional_cookie_boundary_gt_valid(client: AsyncClient): assert response.status_code == status.HTTP_200_OK assert response.json() == {"tracker": 10.1} + # LENGTH CONSTRAINT TESTS @pytest.mark.anyio async def test_length_cookie_too_short(client: AsyncClient): @@ -235,6 +250,7 @@ async def test_length_cookie_too_short(client: AsyncClient): assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.json() == {"detail": "Code must be 5 chars."} + @pytest.mark.anyio async def test_length_cookie_too_long(client: AsyncClient): """Tests max_length validation.""" @@ -243,6 +259,7 @@ async def test_length_cookie_too_long(client: AsyncClient): assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.json() == {"detail": "Code must be 5 chars."} + @pytest.mark.anyio async def test_length_cookie_valid(client: AsyncClient): """Tests valid length validation.""" @@ -251,6 +268,7 @@ async def test_length_cookie_valid(client: AsyncClient): assert response.status_code == status.HTTP_200_OK assert response.json() == {"code": "12345"} + @pytest.mark.anyio async def test_length_cookie_min_boundary(client: AsyncClient): """Tests minimum boundary condition.""" @@ -258,6 +276,7 @@ async def test_length_cookie_min_boundary(client: AsyncClient): response = await client.get("/test-length", cookies=cookies) assert response.status_code == status.HTTP_400_BAD_REQUEST + # CUSTOM VALIDATOR TESTS @pytest.mark.anyio async def test_custom_validator_fail(client: AsyncClient): @@ -265,7 +284,8 @@ async def test_custom_validator_fail(client: AsyncClient): cookies = {"role": "guest"} # "guest" is not in ["admin", "user"] response = await client.get("/test-custom", cookies=cookies) assert response.status_code == status.HTTP_400_BAD_REQUEST - assert "Invalid role." in response.json()["detail"] + assert "role is invalid" in response.json()["detail"].lower() + @pytest.mark.anyio async def test_custom_validator_pass_admin(client: AsyncClient): @@ -275,6 +295,7 @@ async def test_custom_validator_pass_admin(client: AsyncClient): assert response.status_code == status.HTTP_200_OK assert response.json() == {"role": "admin"} + @pytest.mark.anyio async def test_custom_validator_pass_user(client: AsyncClient): """Tests custom validator function success with 'user'.""" @@ -283,6 +304,7 @@ async def test_custom_validator_pass_user(client: AsyncClient): assert response.status_code == status.HTTP_200_OK assert response.json() == {"role": "user"} + # FORMAT PATTERN TESTS @pytest.mark.anyio async def test_bearer_token_valid_format(client: AsyncClient): @@ -291,6 +313,7 @@ async def test_bearer_token_valid_format(client: AsyncClient): response = await client.get("/test-bearer", cookies=cookies) assert response.status_code == status.HTTP_200_OK + @pytest.mark.anyio async def test_bearer_token_lowercase_bearer(client: AsyncClient): """Tests bearer token with lowercase 'bearer'.""" @@ -298,6 +321,7 @@ async def test_bearer_token_lowercase_bearer(client: AsyncClient): response = await client.get("/test-bearer", cookies=cookies) assert response.status_code == status.HTTP_200_OK + @pytest.mark.anyio async def test_bearer_token_invalid_no_bearer_prefix(client: AsyncClient): """Tests bearer token missing 'Bearer' prefix.""" @@ -306,6 +330,7 @@ async def test_bearer_token_invalid_no_bearer_prefix(client: AsyncClient): assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.json() == {"detail": "Invalid bearer token format."} + @pytest.mark.anyio async def test_email_valid_format(client: AsyncClient): """Tests valid email format.""" @@ -313,6 +338,7 @@ async def test_email_valid_format(client: AsyncClient): response = await client.get("/test-email", cookies=cookies) assert response.status_code == status.HTTP_200_OK + @pytest.mark.anyio async def test_email_with_plus_sign(client: AsyncClient): """Tests valid email with plus sign.""" @@ -320,6 +346,7 @@ async def test_email_with_plus_sign(client: AsyncClient): response = await client.get("/test-email", cookies=cookies) assert response.status_code == status.HTTP_200_OK + @pytest.mark.anyio async def test_email_invalid_format_no_at(client: AsyncClient): """Tests invalid email without @ symbol.""" @@ -327,6 +354,7 @@ async def test_email_invalid_format_no_at(client: AsyncClient): response = await client.get("/test-email", cookies=cookies) assert response.status_code == status.HTTP_400_BAD_REQUEST + @pytest.mark.anyio async def test_email_invalid_format_no_domain(client: AsyncClient): """Tests invalid email without domain.""" @@ -334,6 +362,7 @@ async def test_email_invalid_format_no_domain(client: AsyncClient): response = await client.get("/test-email", cookies=cookies) assert response.status_code == status.HTTP_400_BAD_REQUEST + # NUMERIC RANGE TESTS @pytest.mark.anyio async def test_numeric_range_valid_min(client: AsyncClient): @@ -343,6 +372,7 @@ async def test_numeric_range_valid_min(client: AsyncClient): assert response.status_code == status.HTTP_200_OK assert response.json() == {"score": 0.0} + @pytest.mark.anyio async def test_numeric_range_valid_max(client: AsyncClient): """Tests numeric value at maximum boundary (le=100).""" @@ -351,6 +381,7 @@ async def test_numeric_range_valid_max(client: AsyncClient): assert response.status_code == status.HTTP_200_OK assert response.json() == {"score": 100.0} + @pytest.mark.anyio async def test_numeric_range_valid_middle(client: AsyncClient): """Tests numeric value in middle of range.""" @@ -359,6 +390,7 @@ async def test_numeric_range_valid_middle(client: AsyncClient): assert response.status_code == status.HTTP_200_OK assert response.json() == {"score": 50.0} + @pytest.mark.anyio async def test_numeric_range_below_min(client: AsyncClient): """Tests numeric value below minimum (< 0).""" @@ -367,6 +399,7 @@ async def test_numeric_range_below_min(client: AsyncClient): assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.json() == {"detail": "Score must be between 0 and 100."} + @pytest.mark.anyio async def test_numeric_range_above_max(client: AsyncClient): """Tests numeric value above maximum (> 100).""" @@ -375,6 +408,7 @@ async def test_numeric_range_above_max(client: AsyncClient): assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.json() == {"detail": "Score must be between 0 and 100."} + @pytest.mark.anyio async def test_numeric_range_float_valid(client: AsyncClient): """Tests decimal numeric value within range.""" @@ -383,6 +417,7 @@ async def test_numeric_range_float_valid(client: AsyncClient): assert response.status_code == status.HTTP_200_OK assert response.json() == {"score": 75.5} + @pytest.mark.anyio async def test_numeric_range_non_numeric(client: AsyncClient): """Tests non-numeric value.""" @@ -391,6 +426,7 @@ async def test_numeric_range_non_numeric(client: AsyncClient): assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.json() == {"detail": "Score must be a number."} + # ASYNC VALIDATOR TESTS @pytest.mark.anyio async def test_async_validator_valid(client: AsyncClient): @@ -400,14 +436,16 @@ async def test_async_validator_valid(client: AsyncClient): assert response.status_code == status.HTTP_200_OK assert response.json() == {"token": "valid_token123"} + @pytest.mark.anyio async def test_async_validator_invalid(client: AsyncClient): """Tests async validator function failure.""" cookies = {"async-token": "invalid_token123"} response = await client.get("/test-async-custom", cookies=cookies) - assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.status_code == 400 assert "Token must start with 'valid_'." in response.json()["detail"] + # EDGE CASE TESTS @pytest.mark.anyio async def test_cookie_with_special_characters(client: AsyncClient): @@ -416,6 +454,7 @@ async def test_cookie_with_special_characters(client: AsyncClient): response = await client.get("/test-length", cookies=cookies) assert response.status_code == status.HTTP_200_OK + @pytest.mark.anyio async def test_cookie_with_spaces(client: AsyncClient): """Tests cookie value containing spaces.""" @@ -425,6 +464,7 @@ async def test_cookie_with_spaces(client: AsyncClient): assert response.status_code == status.HTTP_200_OK assert response.json() == {"code": "a b c"} + @pytest.mark.anyio async def test_cookie_with_unicode_characters(client: AsyncClient): """Tests cookie value containing numeric and special characters.""" diff --git a/tests/test_csv_validator.py b/tests/test_csv_validator.py index c0cd279..c8d6e9b 100644 --- a/tests/test_csv_validator.py +++ b/tests/test_csv_validator.py @@ -12,9 +12,6 @@ import pytest from fastapi import UploadFile, HTTPException -# --- Module under test --- -# (Adjust this import path based on your project structure) -# We assume the code is in: fastapi_assets/validators/csv_validator.py from fastapi_assets.validators.csv_validator import CSVValidator # Mock pandas for the dependency test @@ -25,7 +22,7 @@ pd = None -# --- Fixtures --- +# Fixtures @pytest.fixture @@ -76,20 +73,20 @@ def _create_file( # Yield the factory function to the tests yield _create_file - # --- Teardown --- + # Teardown # Close all files created by the factory for f in files_to_close: f.close() -# --- Test Cases --- +# Test Cases @pytest.mark.asyncio class TestCSVValidator: """Groups all tests for the CSVValidator.""" - # --- Basic Success and File Handling --- + # Basic Success and File Handling async def test_happy_path_validation(self, mock_upload_file_factory: Callable[..., UploadFile]): """ @@ -128,7 +125,7 @@ async def test_file_is_rewound_after_validation( # Check if the file pointer is at the beginning assert await file.read() == csv_content.encode("utf-8") - # --- Dependency Check --- + # Dependency Check def test_pandas_dependency_check(self, monkeypatch): """ @@ -152,7 +149,7 @@ def test_pandas_dependency_check(self, monkeypatch): # Restore pandas for other tests monkeypatch.setattr("fastapi_assets.validators.csv_validator.pd", pd) - # --- CSV-Specific Validations --- + # CSV-Specific Validations @pytest.mark.asyncio @pytest.mark.parametrize( @@ -259,7 +256,7 @@ async def test_row_count_validation(self, mock_upload_file_factory, header_check validator_pass = CSVValidator(min_rows=3, max_rows=3, header_check_only=header_check_only) await validator_pass(mock_upload_file_factory(csv_content)) - # --- Error Handling and Custom Messages --- + # Error Handling and Custom Messages @pytest.mark.asyncio async def test_csv_parsing_error(self, mock_upload_file_factory): @@ -299,7 +296,7 @@ async def test_custom_error_messages(self, mock_upload_file_factory): await validator_row(mock_upload_file_factory(csv_content)) assert exc_row.value.detail == "File must have at least 5 data rows." - # --- Inherited Validation --- + # Inherited Validation @pytest.mark.asyncio async def test_inherited_max_size_validation(self, mock_upload_file_factory): diff --git a/tests/test_file_validator.py b/tests/test_file_validator.py index fb9b39c..3cb50a0 100644 --- a/tests/test_file_validator.py +++ b/tests/test_file_validator.py @@ -37,7 +37,7 @@ def mock_upload_file() -> MagicMock: return file -# --- Test Cases --- +# Test Cases @pytest.mark.asyncio @@ -122,7 +122,7 @@ async def test_call_invalid_content_type(self, mock_upload_file: MagicMock): with pytest.raises(HTTPException) as exc_info: await validator(mock_upload_file) - print(exc_info) + assert exc_info.value.status_code == 415 assert "unsupported media type" in exc_info.value.detail diff --git a/tests/test_header_validator.py b/tests/test_header_validator.py index 1f0fd29..e785b4d 100644 --- a/tests/test_header_validator.py +++ b/tests/test_header_validator.py @@ -4,11 +4,11 @@ import pytest from fastapi import HTTPException -from fastapi_assets.core.base_validator import ValidationError -from fastapi_assets.request_validators.header_validator import HeaderValidator +from fastapi_assets.core import ValidationError +from fastapi_assets.request_validators import HeaderValidator -# --- Fixtures --- +# Fixtures @pytest.fixture @@ -20,7 +20,7 @@ def base_validator(): @pytest.fixture def required_validator(): """Returns a HeaderValidator with required=True.""" - return HeaderValidator(required=True) + return HeaderValidator(default="Hello") @pytest.fixture @@ -45,13 +45,14 @@ def allowed_values_validator(): def custom_validator_obj(): """Returns a HeaderValidator with custom validator function.""" - def is_even_length(val: str) -> bool: - return len(val) % 2 == 0 + def is_even_length(val: str): + if len(val) % 2 != 0: + raise ValidationError(detail="Length is not even") - return HeaderValidator(validator=is_even_length) + return HeaderValidator(validators=[is_even_length]) -# --- Test Classes --- +# Test Classes class TestHeaderValidatorInit: @@ -61,44 +62,15 @@ def test_init_defaults(self): """Tests that all validation rules are None by default.""" validator = HeaderValidator() assert validator._allowed_values is None - assert validator._pattern is None - assert validator._custom_validator is None - assert validator._format_name is None - - def test_init_required_true(self): - """Tests that required flag is stored correctly.""" - validator = HeaderValidator(required=True) - assert validator._required is True - - def test_init_required_false(self): - """Tests that required can be set to False.""" - validator = HeaderValidator(required=False, default="default_value") - assert validator._required is False + assert validator._pattern_str is None + assert validator._custom_validators == [] def test_init_pattern_compilation(self): """Tests that pattern is compiled to regex.""" pattern = r"^[A-Z0-9]+$" validator = HeaderValidator(pattern=pattern) - assert validator._pattern is not None - assert validator._pattern.pattern == pattern - - def test_init_format_uuid4(self): - """Tests that format='uuid4' is recognized.""" - validator = HeaderValidator(format="uuid4") - assert validator._format_name == "uuid4" - assert validator._pattern is not None - - def test_init_format_email(self): - """Tests that format='email' is recognized.""" - validator = HeaderValidator(format="email") - assert validator._format_name == "email" - assert validator._pattern is not None - - def test_init_format_bearer_token(self): - """Tests that format='bearer_token' is recognized.""" - validator = HeaderValidator(format="bearer_token") - assert validator._format_name == "bearer_token" - assert validator._pattern is not None + assert validator._pattern_str is not None + assert validator._pattern_str == pattern def test_init_invalid_format(self): """Tests that invalid format raises ValueError.""" @@ -119,13 +91,13 @@ def test_init_allowed_values(self): def test_init_custom_validator_function(self): """Tests that custom validator function is stored.""" - def is_positive(val: str) -> bool: - return val.startswith("+") + def is_positive(val: str): + if not val.startswith("+"): + raise ValidationError(detail="Value does not start with '+'") - validator = HeaderValidator(validator=is_positive) - assert validator._custom_validator is not None - assert validator._custom_validator("+test") is True - assert validator._custom_validator("-test") is False + validator = HeaderValidator(validators=[is_positive]) + assert validator._custom_validators is not None + assert validator._custom_validators[0]("+test") is None def test_init_custom_error_detail(self): """Tests that custom error detail is stored.""" @@ -149,27 +121,6 @@ def test_required_with_value(self, required_validator): except ValidationError: pytest.fail("Required validation failed with valid value") - def test_required_missing_value(self, required_validator): - """Tests required validation fails when value is None.""" - with pytest.raises(ValidationError) as e: - required_validator._validate_required(None) - - assert e.value.status_code == 400 - assert "missing" in e.value.detail.lower() - - def test_required_empty_string(self, required_validator): - """Tests required validation fails with empty string.""" - with pytest.raises(ValidationError): - required_validator._validate_required("") - - def test_not_required_with_none(self, base_validator): - """Tests validation passes when not required and value is None.""" - base_validator._required = False - try: - base_validator._validate_required(None) - except ValidationError: - pytest.fail("Non-required validation should pass with None") - class TestHeaderValidatorValidateAllowedValues: """Tests for the _validate_allowed_values method.""" @@ -233,7 +184,7 @@ def test_pattern_invalid_match(self, pattern_validator): pattern_validator._validate_pattern("short") assert e.value.status_code == 400 - assert "does not match" in e.value.detail.lower() + assert "invalid format" in e.value.detail.lower() def test_pattern_format_uuid4_valid(self): """Tests uuid4 format validation passes.""" @@ -282,95 +233,109 @@ def test_pattern_format_email_invalid(self): class TestHeaderValidatorValidateCustom: """Tests for the _validate_custom method.""" - def test_custom_no_validator(self, base_validator): + @pytest.mark.asyncio + async def test_custom_no_validator(self, base_validator): """Tests validation passes with no custom validator.""" try: - base_validator._validate_custom("any_value") + await base_validator._validate_custom("any_value") except ValidationError: pytest.fail("Validation failed with no custom validator") - def test_custom_validator_valid(self, custom_validator_obj): + @pytest.mark.asyncio + async def test_custom_validator_valid(self, custom_validator_obj): """Tests custom validator passes on valid input.""" try: - custom_validator_obj._validate_custom("even") # 4 chars + await custom_validator_obj._validate_custom("even") # 4 chars except ValidationError: pytest.fail("Valid custom validation failed") - def test_custom_validator_invalid(self, custom_validator_obj): + @pytest.mark.asyncio + async def test_custom_validator_invalid(self, custom_validator_obj): """Tests custom validator fails on invalid input.""" with pytest.raises(ValidationError) as e: - custom_validator_obj._validate_custom("odd") # 3 chars + await custom_validator_obj._validate_custom("odd") # 3 chars assert e.value.status_code == 400 # Accept either failure message depending on your validator code + # The custom validator raises ValidationError with detail="Length is not even" assert ( - "custom validation failed" in e.value.detail.lower() + "failed custom validation" in e.value.detail.lower() or "custom validation error" in e.value.detail.lower() + or "length is not even" in e.value.detail.lower() ) - def test_custom_validator_exception(self): + @pytest.mark.asyncio + async def test_custom_validator_exception(self): """Tests custom validator exception is caught.""" - def buggy_validator(val: str) -> bool: + def buggy_validator(val: str): raise ValueError("Unexpected error") - validator = HeaderValidator(validator=buggy_validator) + validator = HeaderValidator(validators=[buggy_validator]) with pytest.raises(ValidationError) as e: - validator._validate_custom("test") + await validator._validate_custom("test") - assert "custom validation error" in e.value.detail.lower() + assert "custom validation failed" in e.value.detail.lower() class TestHeaderValidatorValidate: """Tests for the main _validate method.""" - def test_validate_valid_header(self): + @pytest.mark.asyncio + async def test_validate_valid_header(self): """Tests full validation pipeline with valid header.""" validator = HeaderValidator( required=True, allowed_values=["api", "web"], pattern=r"^[a-z]+$" ) try: - result = validator._validate("api") + result = await validator._validate("api") assert result == "api" except ValidationError: pytest.fail("Valid header failed validation") - def test_validate_fails_required(self): + @pytest.mark.asyncio + async def test_validate_fails_required(self): """Tests validation fails on required check.""" validator = HeaderValidator(required=True) - with pytest.raises(HTTPException): - validator._validate(None) + with pytest.raises(ValidationError): + await validator._validate(None) - def test_validate_fails_allowed_values(self): + @pytest.mark.asyncio + async def test_validate_fails_allowed_values(self): """Tests validation fails on allowed values check.""" validator = HeaderValidator(allowed_values=["good"]) - with pytest.raises(HTTPException): - validator._validate("bad") + with pytest.raises(ValidationError): + await validator._validate("bad") - def test_validate_fails_pattern(self): + @pytest.mark.asyncio + async def test_validate_fails_pattern(self): """Tests validation fails on pattern check.""" validator = HeaderValidator(pattern=r"^[0-9]+$") - with pytest.raises(HTTPException): - validator._validate("abc") + with pytest.raises(ValidationError): + await validator._validate("abc") - def test_validate_fails_custom(self): + @pytest.mark.asyncio + async def test_validate_fails_custom(self): """Tests validation fails on custom validator.""" - def no_spaces(val: str) -> bool: - return " " not in val + def no_spaces(val: str): + if " " in val: + raise ValidationError(detail="Spaces are not allowed") - validator = HeaderValidator(validator=no_spaces) - with pytest.raises(HTTPException): - validator._validate("has space") + validator = HeaderValidator(validators=[no_spaces]) + with pytest.raises(ValidationError): + await validator._validate("has space") - def test_validate_empty_optional_header(self): + @pytest.mark.asyncio + async def test_validate_empty_optional_header(self): """Tests optional header with empty string passes.""" - validator = HeaderValidator(required=False) - result = validator._validate("") + validator = HeaderValidator(default="") + result = await validator._validate("") assert result == "" - def test_validate_none_optional_header(self): + @pytest.mark.asyncio + async def test_validate_none_optional_header(self): """Tests optional header with None passes.""" - validator = HeaderValidator(required=False) - result = validator._validate(None) - assert result is None or result == "" + validator = HeaderValidator(default=None) + result = await validator._validate(None) + assert result is None diff --git a/tests/test_image_validator.py b/tests/test_image_validator.py index 904c044..4967cc4 100644 --- a/tests/test_image_validator.py +++ b/tests/test_image_validator.py @@ -22,7 +22,7 @@ class MockValidationError(Exception): """Mock a ValidationError for testing.""" - def __init__(self, detail: str, status_code: int): + def __init__(self, detail: str, status_code: int = 400): self.detail = detail self.status_code = status_code super().__init__(detail) @@ -129,6 +129,7 @@ def create_mock_image_file( file.filename = filename file.content_type = content_type file.file = buffer + file.size = len(buffer.getvalue()) # Set the size attribute # Create a wrapper for seek async def mock_seek(offset): @@ -161,6 +162,7 @@ def create_mock_text_file(filename: str) -> UploadFile: file.filename = filename file.content_type = "text/plain" file.file = buffer + file.size = len(buffer.getvalue()) # Set the size attribute # Create a wrapper for seek async def mock_seek(offset): @@ -281,7 +283,7 @@ async def test_inherited_max_size_failure(self): await validator(file) assert exc_info.value.status_code == 413 # From our mock - assert "File is too large" in exc_info.value.detail + assert "exceeds the maximum limit" in exc_info.value.detail finally: await file.close() diff --git a/tests/test_path_validator.py b/tests/test_path_validator.py index 986e4e1..af808fd 100644 --- a/tests/test_path_validator.py +++ b/tests/test_path_validator.py @@ -1,730 +1,273 @@ """ -tests for the PathValidator class. +Test suite for the PathValidator class. """ -from fastapi import HTTPException + import pytest -from fastapi_assets.core.base_validator import ValidationError -from fastapi_assets.request_validators.path_validator import PathValidator - -# Fixtures for common PathValidator configurations -@pytest.fixture -def base_validator(): - """Returns a basic PathValidator with no rules.""" - return PathValidator() - -@pytest.fixture -def numeric_validator(): - """Returns a PathValidator configured for numeric validation.""" - return PathValidator(gt=0, lt=1000) - -@pytest.fixture -def string_validator(): - """Returns a PathValidator configured for string validation.""" - return PathValidator( +import asyncio +from fastapi import FastAPI, Depends, HTTPException +from fastapi.testclient import TestClient +from typing import Any, Callable, List + + +class MockValidationError(Exception): + """Minimal mock of the custom ValidationError.""" + + def __init__(self, detail: str, status_code: int = 400): + self.detail = detail + self.status_code = status_code + super().__init__(detail) + + +class MockBaseValidator: + """ + Minimal mock of the BaseValidator to provide + the methods PathValidator inherits. + """ + + def __init__( + self, status_code: int, error_detail: str, validators: List[Callable] | None = None + ): + self._status_code = status_code + self._error_detail = error_detail + self._custom_validators = validators or [] + + async def _validate_custom(self, value: Any) -> None: + """Mock implementation of custom validator runner.""" + for validator in self._custom_validators: + try: + if asyncio.iscoroutinefunction(validator): + await validator(value) + else: + validator(value) + except Exception as e: + # Raise the specific error PathValidator expects + raise MockValidationError(detail=str(e), status_code=400) from e + + def _raise_error(self, value: Any, status_code: int, detail: str) -> None: + """Mock implementation of the error raiser.""" + raise HTTPException(status_code=status_code, detail=detail) + + +# Patch the imports in the module to be tested +# This is a professional testing pattern to inject mocks +import sys +import unittest.mock + +# Create mock modules +mock_core_module = unittest.mock.MagicMock() +mock_core_module.BaseValidator = MockBaseValidator +mock_core_module.ValidationError = MockValidationError + +# Add the mock module to sys.modules +# This ensures that when 'path_validator' imports from 'fastapi_assets.core', +# it gets our mock classes. +sys.modules["fastapi_assets.core"] = mock_core_module + +# Now we can safely import the class to be tested +from fastapi_assets.request_validators import PathValidator + +# +# Test Cases +# + + +def test_standard_path_validation_numeric(): + """ + Tests that standard validations (gt, lt) from fastapi.Path + are correctly applied and that type coercion works. + """ + app = FastAPI() + item_id_validator = PathValidator("item_id", _type=int, gt=0, lt=10) + + @app.get("/items/{item_id}") + def get_item(item_id: int = Depends(item_id_validator())): + # We also check the type to ensure coercion from string happened + return {"item_id": item_id, "type": str(type(item_id))} + + client = TestClient(app) + + # 1. Success case + response = client.get("/items/5") + assert response.status_code == 200 + assert response.json() == {"item_id": 5, "type": ""} + + # 2. Failure case (gt) + response = client.get("/items/0") + assert response.status_code == 422 # Pydantic validation error + assert "greater than 0" in response.text + + # 3. Failure case (lt) + response = client.get("/items/10") + assert response.status_code == 422 + assert "less than 10" in response.text + + # 4. Failure case (type coercion) + response = client.get("/items/abc") + assert response.status_code == 422 + assert "Input should be a valid integer" in response.text + + +def test_standard_path_validation_string(): + """ + Tests that standard string validations (min_length, max_length, pattern) + from fastapi.Path are correctly applied. + """ + app = FastAPI() + username_validator = PathValidator( + "username", + _type=str, min_length=3, - max_length=15, - pattern=r"^[a-zA-Z0-9_]+$" + max_length=5, + pattern=r"^[a-z]+$", # only lowercase letters ) -@pytest.fixture -def allowed_values_validator(): - """Returns a PathValidator with allowed values.""" - return PathValidator( - allowed_values=["active", "inactive", "pending"] + @app.get("/users/{username}") + def get_user(username: str = Depends(username_validator())): + return {"username": username} + + client = TestClient(app) + + # 1. Success case + response = client.get("/users/abc") + assert response.status_code == 200 + assert response.json() == {"username": "abc"} + + # 2. Failure case (min_length) + response = client.get("/users/ab") + assert response.status_code == 422 + assert "at least 3 characters" in response.text + + # 3. Failure case (max_length) + response = client.get("/users/abcdef") + assert response.status_code == 422 + assert "at most 5 characters" in response.text + + # 4. Failure case (pattern) + response = client.get("/users/123") + assert response.status_code == 422 + assert "String should match pattern" in response.text + + +def test_custom_validation_allowed_values(): + """ + Tests the custom 'allowed_values' feature of PathValidator. + """ + app = FastAPI() + mode_validator = PathValidator("mode", _type=str, allowed_values=["read", "write"]) + + @app.get("/modes/{mode}") + def get_mode(mode: str = Depends(mode_validator())): + return {"mode": mode} + + client = TestClient(app) + + # 1. Success cases + response_read = client.get("/modes/read") + assert response_read.status_code == 200 + assert response_read.json() == {"mode": "read"} + + response_write = client.get("/modes/write") + assert response_write.status_code == 200 + assert response_write.json() == {"mode": "write"} + + # 2. Failure case + response = client.get("/modes/admin") + # This fails our custom check, which raises an HTTPException + # based on the (mocked) _raise_error method. + assert response.status_code == 400 + assert "Value 'admin' is not allowed" in response.text + assert "Allowed values are: read, write" in response.text + + +def test_custom_validation_validators_list(): + """ + Tests the custom 'validators' list with both sync and async functions. + """ + + # Custom validator functions for this test + def must_be_even(value: int): + """Sync validator.""" + if value % 2 != 0: + raise ValueError("Value must be even") + + async def must_be_multiple_of_three(value: int): + """Async validator.""" + await asyncio.sleep(0) # Simulate async work + if value % 3 != 0: + raise Exception("Value must be a multiple of three") + + # - + + app = FastAPI() + custom_num_validator = PathValidator( + "num", _type=int, validators=[must_be_even, must_be_multiple_of_three] ) -# Test class for constructor __init__ behavior -class TestPathValidatorInit: - def test_init_defaults(self): - """Tests that all validation rules are None by default.""" - validator = PathValidator() - assert validator._allowed_values is None - assert validator._pattern is None - assert validator._min_length is None - assert validator._max_length is None - assert validator._gt is None - assert validator._lt is None - assert validator._ge is None - assert validator._le is None - assert validator._custom_validator is None - - def test_init_allowed_values(self): - """Tests that allowed_values are stored correctly.""" - values = ["active", "inactive"] - validator = PathValidator(allowed_values=values) - assert validator._allowed_values == values - - def test_init_pattern_compilation(self): - """Tests that regex pattern is compiled.""" - pattern = r"^[a-z0-9]+$" - validator = PathValidator(pattern=pattern) - assert validator._pattern is not None - assert validator._pattern.pattern == pattern - - def test_init_numeric_bounds(self): - """Tests that numeric bounds are stored correctly.""" - validator = PathValidator(gt=0, lt=100, ge=1, le=99) - assert validator._gt == 0 - assert validator._lt == 100 - assert validator._ge == 1 - assert validator._le == 99 - - def test_init_length_bounds(self): - """Tests that length bounds are stored correctly.""" - validator = PathValidator(min_length=5, max_length=20) - assert validator._min_length == 5 - assert validator._max_length == 20 - - def test_init_custom_error_detail(self): - """Tests that custom error messages are stored.""" - custom_error = "Invalid path parameter" - validator = PathValidator(error_detail=custom_error) - print(validator._error_detail) - - # _error_detail attribute holds error message - assert validator._error_detail == custom_error or custom_error in str(validator.__dict__) - - def test_init_custom_validator_function(self): - """Tests that custom validator function is stored.""" - def is_even(x): return x % 2 == 0 - validator = PathValidator(validator=is_even) - # Validate custom function works - assert validator._custom_validator(4) is True - assert validator._custom_validator(3) is False - - def test_init_fastapi_path_creation(self): - """Tests that internal FastAPI Path object is created.""" - validator = PathValidator( - title="Item ID", - description="The unique identifier", - gt=0, - lt=1000 - ) - assert validator._path_param is not None - - def test_init_combined_rules(self): - """Tests initialization with multiple combined rules.""" - validator = PathValidator( - min_length=3, - max_length=20, - pattern=r"^[a-zA-Z]+$", - title="Category", - description="Product category slug" - ) - assert validator._min_length == 3 - assert validator._max_length == 20 - assert validator._pattern is not None - -# Validation method tests -class TestPathValidatorValidateAllowedValues: - def test_allowed_values_no_rule(self, base_validator): - """Validation should pass if no rule is set.""" - try: - base_validator._validate_allowed_values("any_value") - except ValidationError: - pytest.fail("Validation failed when no rule was set.") - - def test_allowed_values_valid(self, allowed_values_validator): - """Test valid allowed value.""" - try: - allowed_values_validator._validate_allowed_values("active") - except ValidationError: - pytest.fail("Failed on valid allowed value.") - - def test_allowed_values_invalid(self, allowed_values_validator): - """Test invalid allowed value raises ValidationError.""" - with pytest.raises(ValidationError): - allowed_values_validator._validate_allowed_values("deleted") - -class TestPathValidatorValidatePattern: - def test_pattern_no_rule(self, base_validator): - """Validation passes when no pattern rule.""" - try: - base_validator._validate_pattern("anything@123!@#") - except ValidationError: - pytest.fail("Validation failed when no pattern rule.") - - def test_pattern_valid_match(self, string_validator): - """Valid pattern match.""" - try: - string_validator._validate_pattern("user_123") - except ValidationError: - pytest.fail("Validation failed on valid pattern.") - - def test_pattern_invalid_match(self, string_validator): - """Invalid pattern raises ValidationError.""" - with pytest.raises(ValidationError): - string_validator._validate_pattern("user@123") - - def test_pattern_non_string_ignored(self, string_validator): - """Skip pattern validation for non-strings.""" - try: - string_validator._validate_pattern(123) - except ValidationError: - pytest.fail("Pattern validation should not apply to non-strings.") - - def test_pattern_email_like(self): - """Email pattern with valid and invalid cases.""" - validator = PathValidator(pattern=r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$") - try: - validator._validate_pattern("user.name+tag@example.com") - except ValidationError: - pytest.fail("Valid email-like pattern failed") - with pytest.raises(ValidationError): - validator._validate_pattern("user@domain") # missing TLD - -# Length validation tests -class TestPathValidatorValidateLength: - def test_length_no_rule(self, base_validator): - """Validation passes when no length rule.""" - try: - base_validator._validate_length("x") - base_validator._validate_length("longer") - except ValidationError: - pytest.fail("Failed no length rule.") - - def test_length_valid_within_bounds(self, string_validator): - """Valid length within bounds.""" - try: - string_validator._validate_length("hello") - except ValidationError: - pytest.fail("Failed valid length.") - - def test_length_too_short(self, string_validator): - """Fails if shorter than min_length.""" - with pytest.raises(ValidationError): - string_validator._validate_length("ab") - - def test_length_too_long(self, string_validator): - """Fails if longer than max_length.""" - with pytest.raises(ValidationError): - string_validator._validate_length("a"*20) - -# Numeric bounds validation -class TestPathValidatorValidateNumericBounds: - def test_no_rule(self, base_validator): - try: - base_validator._validate_numeric_bounds(999) - base_validator._validate_numeric_bounds(-999) - except ValidationError: - pytest.fail("Failed no numeric rule.") - - def test_gt_lt(self, numeric_validator): - try: - numeric_validator._validate_numeric_bounds(1) - numeric_validator._validate_numeric_bounds(999) - except ValidationError: - pytest.fail("Failed valid bounds.") - with pytest.raises(ValidationError): - numeric_validator._validate_numeric_bounds(0) - - def test_ge_le(self): - validator = PathValidator(ge=0, le=10) - try: - validator._validate_numeric_bounds(0) - validator._validate_numeric_bounds(10) - except ValidationError: - pytest.fail("Failed boundary values.") - with pytest.raises(ValidationError): - validator._validate_numeric_bounds(-1) - -# Custom validation tests -class TestPathValidatorValidateCustom: - def test_no_custom_validator(self, base_validator): - try: - base_validator._validate_custom("test") - except ValidationError: - pytest.fail("Failed with no custom validator.") - def test_valid_custom(self): - def is_even(x): return x % 2 == 0 - v = PathValidator(validator=is_even) - try: - v._validate_custom(4) - except ValidationError: - pytest.fail("Valid custom validation failed.") - def test_invalid_custom(self): - def is_even(x): return x % 2 == 0 - v = PathValidator(validator=is_even) - with pytest.raises(ValidationError): - v._validate_custom(3) - -# Integration of multiple validations -class TestPathValidatorIntegration: - def test_combined_valid(self): - v = PathValidator(allowed_values=["ok"], pattern=r"^ok$", min_length=2, max_length=2) - try: - v._validate("ok") - except ValidationError: - pytest.fail("Valid data failed validation.") - - def test_fail_in_combined(self): - v = PathValidator(allowed_values=["ok"], pattern=r"^ok$", min_length=2, max_length=2) - with pytest.raises(HTTPException): - v._validate("no") - - -# Edge case tests for bounds -class TestPathValidatorNumericEdgeCases: - """Test edge cases and boundary conditions for numeric validation.""" - - def test_gt_with_equal_value(self): - """Value equal to gt boundary should fail.""" - validator = PathValidator(gt=10) - with pytest.raises(ValidationError) as exc_info: - validator._validate_numeric_bounds(10) - assert "greater than 10" in str(exc_info.value.detail) - - def test_lt_with_equal_value(self): - """Value equal to lt boundary should fail.""" - validator = PathValidator(lt=10) - with pytest.raises(ValidationError) as exc_info: - validator._validate_numeric_bounds(10) - assert "less than 10" in str(exc_info.value.detail) - - def test_ge_with_equal_value(self): - """Value equal to ge boundary should pass.""" - validator = PathValidator(ge=10) - try: - validator._validate_numeric_bounds(10) - except ValidationError: - pytest.fail("GE with equal value should pass") - - def test_le_with_equal_value(self): - """Value equal to le boundary should pass.""" - validator = PathValidator(le=10) - try: - validator._validate_numeric_bounds(10) - except ValidationError: - pytest.fail("LE with equal value should pass") - - def test_negative_numeric_bounds(self): - """Test numeric bounds with negative values.""" - validator = PathValidator(gt=-100, lt=-10) - try: - validator._validate_numeric_bounds(-50) - except ValidationError: - pytest.fail("Valid negative value failed") - with pytest.raises(ValidationError): - validator._validate_numeric_bounds(-100) - - def test_float_numeric_bounds(self): - """Test numeric bounds with float values.""" - validator = PathValidator(gt=0.0, lt=1.0) - try: - validator._validate_numeric_bounds(0.5) - except ValidationError: - pytest.fail("Valid float value failed") - with pytest.raises(ValidationError): - validator._validate_numeric_bounds(1.0) - - def test_zero_as_boundary(self): - """Test with zero as boundary value.""" - validator = PathValidator(ge=0, le=0) - try: - validator._validate_numeric_bounds(0) - except ValidationError: - pytest.fail("Zero should be valid with ge=0, le=0") - with pytest.raises(ValidationError): - validator._validate_numeric_bounds(1) - - -# Edge case tests for string length -class TestPathValidatorStringEdgeCases: - """Test edge cases and boundary conditions for string validation.""" - - def test_empty_string_with_min_length(self): - """Empty string should fail if min_length is set.""" - validator = PathValidator(min_length=1) - with pytest.raises(ValidationError) as exc_info: - validator._validate_length("") - assert "too short" in str(exc_info.value.detail) - - def test_min_length_exact(self): - """String exactly at min_length should pass.""" - validator = PathValidator(min_length=5) - try: - validator._validate_length("exact") - except ValidationError: - pytest.fail("Exact min_length should pass") - - def test_max_length_exact(self): - """String exactly at max_length should pass.""" - validator = PathValidator(max_length=5) - try: - validator._validate_length("exact") - except ValidationError: - pytest.fail("Exact max_length should pass") - - def test_unicode_string_length(self): - """Test length validation with unicode characters.""" - validator = PathValidator(min_length=3, max_length=5) - try: - validator._validate_length("😀😁😂") # 3 emoji characters - except ValidationError: - pytest.fail("Valid unicode string failed") - - def test_zero_length_bounds(self): - """Test with min and max length of zero.""" - validator = PathValidator(min_length=0, max_length=0) - try: - validator._validate_length("") - except ValidationError: - pytest.fail("Empty string should be valid with min=0, max=0") - with pytest.raises(ValidationError): - validator._validate_length("x") - - -# Edge case tests for pattern matching -class TestPathValidatorPatternEdgeCases: - """Test edge cases for regex pattern validation.""" - - def test_pattern_with_special_characters(self): - """Pattern with special regex characters.""" - validator = PathValidator(pattern=r"^[\w\-\.]+@[\w\-\.]+\.\w+$") - try: - validator._validate_pattern("user-name.test@sub-domain.co.uk") - except ValidationError: - pytest.fail("Valid email-like pattern failed") - with pytest.raises(ValidationError): - validator._validate_pattern("invalid@domain") - - def test_pattern_case_sensitive(self): - """Regex patterns are case-sensitive by default.""" - validator = PathValidator(pattern=r"^[a-z]+$") - try: - validator._validate_pattern("lowercase") - except ValidationError: - pytest.fail("Lowercase letters should match [a-z]") - with pytest.raises(ValidationError): - validator._validate_pattern("UPPERCASE") - - def test_pattern_with_anchors(self): - """Pattern with start and end anchors.""" - validator = PathValidator(pattern=r"^START.*END$") - try: - validator._validate_pattern("START-middle-END") - except ValidationError: - pytest.fail("String with anchors should match") - with pytest.raises(ValidationError): - validator._validate_pattern("MIDDLE-START-END") - - def test_pattern_match_from_start(self): - """re.match() only matches from the start of string.""" - validator = PathValidator(pattern=r"test") - try: - validator._validate_pattern("test_string") - except ValidationError: - pytest.fail("Pattern should match from start") - # This should fail because re.match only checks beginning - with pytest.raises(ValidationError): - validator._validate_pattern("this_is_a_test_string") - - def test_pattern_with_groups(self): - """Pattern with capture groups.""" - validator = PathValidator(pattern=r"^(\d{4})-(\d{2})-(\d{2})$") - try: - validator._validate_pattern("2025-11-04") - except ValidationError: - pytest.fail("Valid date format should match") - with pytest.raises(ValidationError): - validator._validate_pattern("2025/11/04") - - -# Allowed values edge cases -class TestPathValidatorAllowedValuesEdgeCases: - """Test edge cases for allowed values validation.""" - - def test_allowed_values_with_none(self): - """Test when None is in allowed values.""" - validator = PathValidator(allowed_values=[None, "active", "inactive"]) - try: - validator._validate_allowed_values(None) - except ValidationError: - pytest.fail("None should be allowed if in list") - - def test_allowed_values_case_sensitive(self): - """Allowed values matching is case-sensitive.""" - validator = PathValidator(allowed_values=["Active", "Inactive"]) - try: - validator._validate_allowed_values("Active") - except ValidationError: - pytest.fail("Case-sensitive match should work") - with pytest.raises(ValidationError): - validator._validate_allowed_values("active") - - def test_allowed_values_numeric_types(self): - """Test allowed values with numeric types.""" - validator = PathValidator(allowed_values=[1, 2, 3]) - try: - validator._validate_allowed_values(2) - except ValidationError: - pytest.fail("Numeric allowed value should work") - with pytest.raises(ValidationError): - validator._validate_allowed_values("2") # String "2" != int 2 - - def test_allowed_values_empty_list(self): - """Empty allowed values list should reject everything.""" - validator = PathValidator(allowed_values=[]) - with pytest.raises(ValidationError): - validator._validate_allowed_values("anything") - - def test_allowed_values_with_duplicates(self): - """Allowed values list with duplicates.""" - validator = PathValidator(allowed_values=["status", "status", "active"]) - try: - validator._validate_allowed_values("status") - except ValidationError: - pytest.fail("Duplicates shouldn't affect validation") - - -# Custom validator edge cases -class TestPathValidatorCustomValidatorEdgeCases: - """Test edge cases for custom validator functions.""" - - def test_custom_validator_exception_handling(self): - """Custom validator that raises exception.""" - def bad_validator(x): - raise ValueError("Something went wrong") - - validator = PathValidator(validator=bad_validator) - with pytest.raises(ValidationError) as exc_info: - validator._validate_custom("test") - assert "Custom validation error" in str(exc_info.value.detail) - - def test_custom_validator_returns_false(self): - """Custom validator returns False.""" - def always_fail(x): - return False - - validator = PathValidator(validator=always_fail) - with pytest.raises(ValidationError) as exc_info: - validator._validate_custom("test") - assert "Custom validation failed" in str(exc_info.value.detail) - - def test_custom_validator_returns_true(self): - """Custom validator returns True.""" - def always_pass(x): - return True - - validator = PathValidator(validator=always_pass) - try: - validator._validate_custom("test") - except ValidationError: - pytest.fail("Custom validator returning True should pass") - - def test_custom_validator_with_complex_logic(self): - """Custom validator with complex validation logic.""" - def validate_phone(phone): - import re - return bool(re.match(r"^\+?1?\d{9,15}$", str(phone))) - - validator = PathValidator(validator=validate_phone) - try: - validator._validate_custom("+14155552671") - except ValidationError: - pytest.fail("Valid phone should pass") - with pytest.raises(ValidationError): - validator._validate_custom("123") - - def test_custom_validator_lambda(self): - """Custom validator using lambda function.""" - validator = PathValidator(validator=lambda x: len(str(x)) > 3) - try: - validator._validate_custom("test") - except ValidationError: - pytest.fail("Lambda validator should work") - with pytest.raises(ValidationError): - validator._validate_custom("ab") - - -# Complete validation flow tests -class TestPathValidatorCompleteFlow: - """Test complete validation flows with multiple rules.""" - - def test_all_validations_pass(self): - """All validation rules pass together.""" - validator = PathValidator( - allowed_values=["user_123", "admin_456"], - pattern=r"^[a-z]+_\d+$", - min_length=7, - max_length=10, - validator=lambda x: "_" in x - ) - try: - validator._validate("user_123") - except (ValidationError, HTTPException): - pytest.fail("All validations should pass") - - def test_fail_on_first_validation(self): - """Validation fails on first rule.""" - validator = PathValidator( - allowed_values=["valid"], - pattern=r"^[a-z]+$", - min_length=3 - ) - with pytest.raises(HTTPException): - validator._validate("invalid") - - def test_multiple_combined_rules(self): - """Complex scenario with multiple rules.""" - validator = PathValidator( - min_length=5, - max_length=15, - pattern=r"^[a-zA-Z0-9_-]+$", - allowed_values=["user_name", "admin_test", "guest-user"], - validator=lambda x: not x.startswith("_") - ) - for valid_value in ["user_name", "admin_test", "guest-user"]: - try: - validator._validate(valid_value) - except (ValidationError, HTTPException): - pytest.fail(f"'{valid_value}' should be valid") - - def test_validation_error_messages(self): - """Validation error messages are informative.""" - validator = PathValidator( - allowed_values=["a", "b", "c"], - min_length=2, - max_length=5 - ) - try: - validator._validate("d") - except HTTPException as e: - assert "not allowed" in str(e.detail).lower() or "validation" in str(e.detail).lower() - - -# Non-string and non-numeric type handling -class TestPathValidatorTypeHandling: - """Test handling of various data types.""" - - def test_non_string_skips_string_validations(self): - """Non-string types skip string-specific validations.""" - validator = PathValidator(min_length=3, max_length=10) - try: - validator._validate_length(123) - validator._validate_pattern(123) - except ValidationError: - pytest.fail("Non-strings should skip string validations") - - def test_non_numeric_skips_numeric_validations(self): - """Non-numeric types skip numeric-specific validations.""" - validator = PathValidator(gt=0, lt=100) - try: - validator._validate_numeric_bounds("test") - except ValidationError: - pytest.fail("Non-numeric should skip numeric validations") - - def test_boolean_type_validation(self): - """Test validation with boolean values.""" - validator = PathValidator(allowed_values=[True, False]) - try: - validator._validate_allowed_values(True) - validator._validate_allowed_values(False) - except ValidationError: - pytest.fail("Booleans should validate against allowed values") - - def test_list_type_validation(self): - """Test validation with list/collection types.""" - validator = PathValidator( - allowed_values=[[1, 2], [3, 4], [5, 6]], - validator=lambda x: isinstance(x, list) - ) - try: - validator._validate_allowed_values([1, 2]) - validator._validate_custom([3, 4]) - except ValidationError: - pytest.fail("Lists should validate correctly") - - -# Initialization parameter combinations -class TestPathValidatorInitParameterCombinations: - """Test various parameter combinations during initialization.""" - - def test_init_with_all_parameters(self): - """Initialize with all possible parameters.""" - validator = PathValidator( - default=..., - allowed_values=["a", "b"], - pattern=r"^[a-z]$", - min_length=1, - max_length=1, - gt=0, - lt=10, - ge=1, - le=9, - validator=lambda x: x in ["a", "b"], - title="Test Parameter", - description="A test path parameter", - alias="test_param", - deprecated=False, - error_detail="Test error", - status_code=422 - ) - assert validator._allowed_values == ["a", "b"] - assert validator._pattern is not None - assert validator._min_length == 1 - assert validator._max_length == 1 - - def test_init_only_required(self): - """Initialize with only required parameters.""" - validator = PathValidator() - assert validator._allowed_values is None - assert validator._pattern is None - assert validator._min_length is None - assert validator._max_length is None - - def test_init_with_only_custom_validator(self): - """Initialize with only custom validator.""" - custom = lambda x: x > 0 - validator = PathValidator(validator=custom) - assert validator._custom_validator is custom - assert validator._allowed_values is None - - def test_status_code_default(self): - """Default status code should be 400.""" - validator = PathValidator() - # Status code is set in parent class - - -# Error message verification tests -class TestPathValidatorErrorMessages: - """Test that error messages are clear and informative.""" - - def test_allowed_values_error_message(self): - """Error message includes list of allowed values.""" - validator = PathValidator(allowed_values=["a", "b", "c"]) - try: - validator._validate_allowed_values("d") - except ValidationError as e: - assert "a" in str(e.detail) - assert "b" in str(e.detail) - assert "c" in str(e.detail) - - def test_pattern_error_message_includes_pattern(self): - """Error message includes the regex pattern.""" - pattern = r"^[0-9]{3}$" - validator = PathValidator(pattern=pattern) - try: - validator._validate_pattern("abc") - except ValidationError as e: - assert pattern in str(e.detail) - - def test_length_error_message_info(self): - """Length error includes bounds information.""" - validator = PathValidator(min_length=5, max_length=10) - try: - validator._validate_length("ab") - except ValidationError as e: - assert "5" in str(e.detail) - try: - validator._validate_length("a" * 15) - except ValidationError as e: - assert "10" in str(e.detail) - - def test_numeric_bounds_error_messages(self): - """Numeric bounds errors include boundary values.""" - validator = PathValidator(gt=100) - try: - validator._validate_numeric_bounds(50) - except ValidationError as e: - assert "100" in str(e.detail) + @app.get("/nums/{num}") + def get_num(num: int = Depends(custom_num_validator())): + return {"num": num} + + client = TestClient(app) + + # 1. Success case (passes both validators) + response = client.get("/nums/6") + assert response.status_code == 200 + assert response.json() == {"num": 6} + + # 2. Failure case (fails sync validator) + response = client.get("/nums/9") + assert response.status_code == 400 + assert "Value must be even" in response.text + + # 3. Failure case (fails async validator) + response = client.get("/nums/4") + assert response.status_code == 400 + assert "Value must be a multiple of three" in response.text + + +def test_validator_isolation(): + """ + Tests that multiple PathValidator instances on the same app + do not interfere with each other's signatures. This is the + most critical test given the history of bugs. + """ + app = FastAPI() + + # 1. Define two different validators + item_id_validator = PathValidator("item_id", _type=int, gt=10) + username_validator = PathValidator("username", _type=str, min_length=5) + + # 2. Define two separate endpoints + @app.get("/items/{item_id}") + def get_item(item_id: int = Depends(item_id_validator())): + return {"item_id": item_id} + + @app.get("/users/{username}") + def get_user(username: str = Depends(username_validator())): + return {"username": username} + + client = TestClient(app) + + # 3. Test both endpoints successfully + response_item = client.get("/items/11") + assert response_item.status_code == 200 + assert response_item.json() == {"item_id": 11} + + response_user = client.get("/users/administrator") + assert response_user.status_code == 200 + assert response_user.json() == {"username": "administrator"} + + # 4. Test failure on the *first* endpoint + response_item_fail = client.get("/items/5") + assert response_item_fail.status_code == 422 + # CRITICAL: Error must be about 'item_id', not 'username' + assert "item_id" in response_item_fail.text + assert "greater than 10" in response_item_fail.text + assert "username" not in response_item_fail.text + + # 5. Test failure on the *second* endpoint + response_user_fail = client.get("/users/adm") + assert response_user_fail.status_code == 422 + assert "username" in response_user_fail.text + assert "at least 5 characters" in response_user_fail.text + assert "item_id" not in response_user_fail.text