diff --git a/fastapi_assets/request_validators/__init__.py b/fastapi_assets/request_validators/__init__.py index d34a237..a6ccd5e 100644 --- a/fastapi_assets/request_validators/__init__.py +++ b/fastapi_assets/request_validators/__init__.py @@ -3,3 +3,4 @@ from fastapi_assets.request_validators.header_validator import HeaderValidator from fastapi_assets.request_validators.cookie_validator import CookieValidator from fastapi_assets.request_validators.path_validator import PathValidator +from fastapi_assets.request_validators.query_validator import QueryValidator diff --git a/fastapi_assets/request_validators/query_validator.py b/fastapi_assets/request_validators/query_validator.py new file mode 100644 index 0000000..1d0fd47 --- /dev/null +++ b/fastapi_assets/request_validators/query_validator.py @@ -0,0 +1,200 @@ +"""Module providing the QueryValidator for validating query parameters in FastAPI.""" + +from typing import Any, Callable, List, Optional, Union +from inspect import Signature, Parameter +from fastapi import Query +from fastapi_assets.core import BaseValidator, ValidationError + + +class QueryValidator(BaseValidator): + r""" + A dependency factory for adding custom validation to FastAPI query parameters. + + This class extends the functionality of FastAPI's `Query()` by adding + support for `allowed_values` and custom `validators`. + + It acts as a factory: you instantiate it, and then *call* the + instance inside `Depends()` to get the actual dependency. + + Example: + .. code-block:: python + + from fastapi import FastAPI, Depends + from fastapi_assets.request_validators import QueryValidator + + app = FastAPI() + + # 1. Create reusable validator *instances* + page_validator = QueryValidator( + "page", + _type=int, + default=1, + ge=1, + le=100, + ) + + status_validator = QueryValidator( + "status", + _type=str, + allowed_values=["active", "inactive", "pending"], + ) + + sort_validator = QueryValidator( + "sort", + _type=str, + default="name", + pattern=r"^[a-zA-Z_]+$", + ) + + @app.get("/items/") + def list_items( + page: int = Depends(page_validator()), + status: str = Depends(status_validator()), + sort: str = Depends(sort_validator()), + ): + return {"page": page, "status": status, "sort": sort} + """ + + def __init__( + self, + param_name: str, + _type: type, + default: Any = ..., + *, + # Custom validation rules + allowed_values: Optional[List[Any]] = None, + validators: Optional[List[Callable[[Any], Any]]] = None, + on_custom_validator_error_detail: str = "Custom validation failed.", + # Standard Query() parameters + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[Union[int, float]] = None, + lt: Optional[Union[int, float]] = None, + ge: Optional[Union[int, float]] = None, + le: Optional[Union[int, float]] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + pattern: Optional[str] = None, + deprecated: Optional[bool] = None, + **query_kwargs: Any, + ) -> None: + """ + Initializes the QueryValidator factory. + + Args: + param_name: The exact name of the query parameter. + _type: The Python type for coercion (e.g., int, str, UUID). + default: Default value for the query parameter. + allowed_values: List of allowed values. + validators: List of custom validation functions. + on_custom_validator_error_detail: Error message for custom validators. + title: Title for API documentation. + description: Description for API documentation. + gt: Value must be greater than this. + lt: Value must be less than this. + ge: Value must be greater than or equal to this. + le: Value must be less than or equal to this. + min_length: Minimum length for string parameters. + max_length: Maximum length for string parameters. + pattern: Regex pattern the parameter must match. + deprecated: Whether the parameter is deprecated. + **query_kwargs: Additional arguments passed to FastAPI's Query(). + """ + query_kwargs.setdefault("error_detail", "Query parameter validation failed.") + query_kwargs.setdefault("status_code", 400) + + super().__init__( + status_code=query_kwargs["status_code"], + error_detail=query_kwargs["error_detail"], + validators=validators, + ) + + self._param_name = param_name + self._type = _type + self._allowed_values = allowed_values + self._on_custom_validator_error_detail = on_custom_validator_error_detail + + self._query_param = Query( + default, + title=title, + description=description, + deprecated=deprecated, + gt=gt, + lt=lt, + ge=ge, + le=le, + min_length=min_length, + max_length=max_length, + pattern=pattern, + **query_kwargs, + ) + + def __call__(self) -> Callable[..., Any]: + """ + This is the factory method. + It generates and returns the dependency function + that FastAPI will use. + """ + + async def dependency(**kwargs: Any) -> Any: + query_value = kwargs[self._param_name] + try: + validated_value = await self._validate(query_value) + return validated_value + except ValidationError as e: + self._raise_error(query_value, status_code=e.status_code, detail=e.detail) + return None + + sig = Signature( + [ + Parameter( + self._param_name, + Parameter.KEYWORD_ONLY, + default=self._query_param, + annotation=self._type, + ) + ] + ) + + dependency.__signature__ = sig # type: ignore + return dependency + + async def _validate(self, value: Any) -> Any: + """ + Runs all validation checks on the query parameter value. + + Executes allowed values checking and custom validator checking in sequence. + + Args: + value: The query parameter value to validate. + + Returns: + Any: The validated value (unchanged if validation passes). + + Raises: + ValidationError: If any validation check fails. + """ + self._validate_allowed_values(value) + await self._validate_custom(value) + return value + + def _validate_allowed_values(self, value: Any) -> None: + """ + Checks if the query parameter value is in the list of allowed values. + + Args: + value: The value to validate. + + Returns: + None + + Raises: + ValidationError: If the value is not in the allowed values list. + """ + if self._allowed_values is None: + return + + if value not in self._allowed_values: + allowed_str = ", ".join(map(str, self._allowed_values)) + detail = f"Value '{value}' is not allowed. Allowed values are: {allowed_str}" + raise ValidationError(detail=detail, status_code=400) diff --git a/tests/test_query_validator.py b/tests/test_query_validator.py new file mode 100644 index 0000000..85d1074 --- /dev/null +++ b/tests/test_query_validator.py @@ -0,0 +1,244 @@ +"""Tests for the QueryValidator class.""" + +import pytest +import asyncio +from typing import Any, Callable, List, Optional, Union +from inspect import Signature, Parameter +from fastapi import FastAPI, Depends, Query, HTTPException +from fastapi.testclient import TestClient +from fastapi_assets.request_validators import QueryValidator +from fastapi_assets.core import ValidationError + + +def get_app_and_client(validator_instance: QueryValidator) -> tuple[FastAPI, TestClient]: + """Helper function to create a test app for a given validator.""" + app = FastAPI() + + @app.get("/validate/") + def validate_endpoint( + # The key is to call the instance to get the dependency function + param: Any = Depends(validator_instance()), + ): + return {"validated_param": param} + + client = TestClient(app) + return app, client + + +def test_standard_query_validation_ge(): + """ + Tests that standard Query params (like 'ge') are enforced by FastAPI + before our custom validation runs. This should result in a 422 error. + """ + page_validator = QueryValidator("page", _type=int, default=1, ge=1) + app, client = get_app_and_client(page_validator) + + # Test valid case + response = client.get("/validate/?page=5") + assert response.status_code == 200 + assert response.json() == {"validated_param": 5} + + # Test invalid case (FastAPI's built-in 'ge' validation) + response = client.get("/validate/?page=0") + assert response.status_code == 422 # Unprocessable Entity + assert "greater than or equal to 1" in response.text + + +def test_standard_query_validation_type_error(): + """ + Tests that FastAPI's type coercion and validation fail first. + """ + page_validator = QueryValidator("page", _type=int, ge=1) + app, client = get_app_and_client(page_validator) + + response = client.get("/validate/?page=not-an-integer") + assert response.status_code == 422 + detail = response.json()["detail"] + # Check if detail is a list (modern Pydantic v2) or a string (older format) + if isinstance(detail, list): + assert any("integer" in str(error.get("msg", "")).lower() for error in detail) + else: + assert "integer" in str(detail).lower() + + +def test_required_parameter_missing(): + """ + Tests that a parameter without a default is correctly marked as required. + """ + # Note: `default=...` is the default, making it required + token_validator = QueryValidator("token", _type=str) + app, client = get_app_and_client(token_validator) + + # Test missing required parameter + response = client.get("/validate/") + assert response.status_code == 422 + assert "Field required" in response.text + + # Test providing the parameter + response = client.get("/validate/?token=abc") + assert response.status_code == 200 + assert response.json() == {"validated_param": "abc"} + + +def test_default_value_is_used(): + """ + Tests that the default value is used when the parameter is omitted. + """ + page_validator = QueryValidator("page", _type=int, default=1, ge=1) + app, client = get_app_and_client(page_validator) + + response = client.get("/validate/") + assert response.status_code == 200 + assert response.json() == {"validated_param": 1} + + +def test_allowed_values_success(): + """ + Tests that a value in the 'allowed_values' list passes validation. + """ + status_validator = QueryValidator("status", _type=str, allowed_values=["active", "pending"]) + app, client = get_app_and_client(status_validator) + + response_active = client.get("/validate/?status=active") + assert response_active.status_code == 200 + assert response_active.json() == {"validated_param": "active"} + + response_pending = client.get("/validate/?status=pending") + assert response_pending.status_code == 200 + assert response_pending.json() == {"validated_param": "pending"} + + +def test_allowed_values_failure(): + """ + Tests that a value NOT in the 'allowed_values' list fails with a 400. + """ + status_validator = QueryValidator("status", _type=str, allowed_values=["active", "pending"]) + app, client = get_app_and_client(status_validator) + + response = client.get("/validate/?status=archived") + assert response.status_code == 400 # Bad Request + detail = response.json()["detail"] + assert "Value 'archived' is not allowed" in detail + assert "Allowed values are: active, pending" in detail + + +def test_custom_sync_validator_success(): + """ + Tests a passing synchronous custom validator. + """ + + def is_even(v: int): + if not v % 2 == 0: + raise ValidationError("Not Even") + + num_validator = QueryValidator("num", _type=int, validators=[is_even]) + app, client = get_app_and_client(num_validator) + + response = client.get("/validate/?num=10") + assert response.status_code == 200 + assert response.json() == {"validated_param": 10} + + +def test_custom_sync_validator_failure_with_validation_error(): + """ + Tests a failing synchronous custom validator that raises ValidationError. + """ + + def must_be_even(v: int): + if v % 2 != 0: + raise ValidationError(detail="Value must be even.", status_code=400) + + num_validator = QueryValidator("num", _type=int, validators=[must_be_even]) + app, client = get_app_and_client(num_validator) + + response = client.get("/validate/?num=7") + assert response.status_code == 400 + assert "Value must be even." in response.json()["detail"] + + +@pytest.mark.asyncio +async def test_custom_async_validator_success(): + """ + Tests a passing asynchronous custom validator. + """ + + async def async_check_pass(v: str): + await asyncio.sleep(0) + return v == "valid" + + key_validator = QueryValidator("key", _type=str, validators=[async_check_pass]) + app, client = get_app_and_client(key_validator) + + response = client.get("/validate/?key=valid") + assert response.status_code == 200 + assert response.json() == {"validated_param": "valid"} + + +@pytest.mark.asyncio +async def test_custom_async_validator_failure_with_validation_error(): + """ + Tests a failing asynchronous custom validator that raises ValidationError. + """ + + async def async_check_fail(v: str): + await asyncio.sleep(0) + if v != "valid": + raise ValidationError(detail="Key is not valid.", status_code=400) + + key_validator = QueryValidator("key", _type=str, validators=[async_check_fail]) + app, client = get_app_and_client(key_validator) + + response = client.get("/validate/?key=invalid") + assert response.status_code == 400 + assert "Key is not valid." in response.json()["detail"] + + +def test_custom_validator_failure_silent(): + """ + Tests a validator that fails by returning 'False' and checks that + 'on_custom_validator_error_detail' is used. + """ + + def silent_fail(v: str): + if not v == "must-be-this": + raise ValidationError("Value did not match required string.") + + error_msg = "Value did not match required string." + key_validator = QueryValidator( + "key", _type=str, validators=[silent_fail], on_custom_validator_error_detail=error_msg + ) + app, client = get_app_and_client(key_validator) + + response = client.get("/validate/?key=wrong-string") + assert response.status_code == 400 + assert error_msg in response.json()["detail"] + + +def test_validation_order(): + """ + Tests that 'allowed_values' check runs before 'validators'. + """ + + def should_not_be_called(v: str): + """This validator should fail, but it shouldn't even be reached.""" + if v == "beta": + raise ValidationError(detail="Custom validator was called.", status_code=400) + return + + validator = QueryValidator( + "version", + _type=str, + allowed_values=["alpha", "gamma"], # "beta" is not allowed + validators=[should_not_be_called], + ) + app, client = get_app_and_client(validator) + + # This request should fail at the 'allowed_values' check + response = client.get("/validate/?version=beta") + + # It should be a 400 Bad Request + assert response.status_code == 400 + + # The error detail should be from _validate_allowed_values, NOT the custom validator + assert "Value 'beta' is not allowed" in response.json()["detail"] + assert "Custom validator was called" not in response.json()["detail"]