Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions fastapi_assets/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
"""Module for core functionalities of FastAPI Assets."""

from fastapi_assets.core.base_validator import BaseValidator
from fastapi_assets.core.exceptions import ValidationError
90 changes: 78 additions & 12 deletions fastapi_assets/core/base_validator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Base classes for FastAPI validation dependencies."""

import abc
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Union, List
from fastapi import HTTPException
from fastapi_assets.core.exceptions import ValidationError
import inspect


class BaseValidator(abc.ABC):
Expand All @@ -20,7 +21,7 @@ class BaseValidator(abc.ABC):
from fastapi import Header
from fastapi_assets.core.base_validator import BaseValidator, ValidationError
class MyValidator(BaseValidator):
def _validate_logic(self, token: str) -> None:
def _validate(self, token: str) -> None:
# This method is testable without FastAPI
if not token.startswith("sk_"):
# Raise the logic-level exception
Expand All @@ -44,6 +45,7 @@ def __init__(
*,
status_code: int = 400,
error_detail: Union[str, Callable[[Any], str]] = "Validation failed.",
validators: Optional[List[Callable]] = None,
):
"""
Initializes the base validator.
Expand All @@ -54,9 +56,11 @@ def __init__(
error_detail: The default error message. Can be a static
string or a callable that takes the invalid value as its
argument and returns a dynamic error string.
validators: Optional list of callables for custom validation logic.
"""
self._status_code = status_code
self._error_detail = error_detail
self._custom_validators = validators or []

def _raise_error(
self,
Expand All @@ -65,17 +69,25 @@ def _raise_error(
detail: Optional[Union[str, Callable[[Any], str]]] = None,
) -> None:
"""
Helper method to raise a standardized HTTPException.
Raises a standardized HTTPException with resolved error detail.

It automatically resolves callable error details.
This helper method handles both static error strings and dynamic error
callables, automatically resolving them to a final error message before
raising the HTTPException.

Args:
value (Optional[Any]): The value that failed validation. This is passed
to the error_detail callable, if it is one.
status_code (Optional[int]): A specific status code for this failure,
overriding the instance's default status_code.
detail (Optional[Union[str, Callable[[Any], str]]]): A specific error detail for this failure,
overriding the instance's default error_detail.
value: The value that failed validation. Passed to the error_detail
callable if it is callable.
status_code: A specific HTTP status code for this failure, overriding
the instance's default status_code.
detail: A specific error detail message (string or callable) for this
failure, overriding the instance's default error_detail.

Returns:
None

Raises:
HTTPException: Always raises with the resolved status code and detail.
"""
final_status_code = status_code if status_code is not None else self._status_code

Expand All @@ -91,6 +103,59 @@ def _raise_error(

raise HTTPException(status_code=final_status_code, detail=final_detail)

@abc.abstractmethod
async def _validate(self, value: Any) -> Any:
"""
Abstract method for pure validation logic.

Subclasses MUST implement this method to perform the actual
validation. This method should raise `ValidationError` if
validation fails.

Args:
value: The value to validate.

Returns:
The validated value, which can be of any type depending on the validator.
"""
raise NotImplementedError(
"Subclasses of BaseValidator must implement the _validate method."
)

async def _validate_custom(self, value: Any) -> None:
"""
Executes all configured custom validator functions.

Iterates through the list of custom validators, supporting both
synchronous and asynchronous validator functions. Catches exceptions
and converts them to ValidationError instances.

Args:
value: The value to validate using custom validators.

Returns:
None

Raises:
ValidationError: If any validator raises an exception or explicitly
raises ValidationError.
"""
if self._custom_validators is None:
return

for validator_func in self._custom_validators:
try:
if inspect.iscoroutinefunction(validator_func):
await validator_func(value)
else:
validator_func(value)
except ValidationError:
raise # Re-raise explicit validation errors
except Exception as e:
# Catch any other exception from the validator
detail = f"Custom validation failed. Error: {e}"
raise ValidationError(detail=detail, status_code=self._status_code)

@abc.abstractmethod
def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""
Expand All @@ -111,7 +176,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:

class MyValidator(BaseValidator):

def _validate_logic(self, token: str) -> None:
async def _validate(self, token: str) -> None:
# This method is testable without FastAPI
if not token.startswith("sk_"):
# Raise the logic-level exception
Expand All @@ -120,7 +185,8 @@ def _validate_logic(self, token: str) -> None:
def __call__(self, x_token: str = Header(...)):
try:
# 1. Run the pure validation logic
self._validate_logic(x_token)
await self._validate(x_token)
await self._validate_custom(x_token)
except ValidationError as e:
# 2. Catch logic error and raise HTTP error
self._raise_error(
Expand Down
1 change: 0 additions & 1 deletion fastapi_assets/metadata_validators/__init__.py

This file was deleted.

4 changes: 4 additions & 0 deletions fastapi_assets/request_validators/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
"""Module for request validation in FastAPI Assets."""

from fastapi_assets.request_validators.header_validator import HeaderValidator
from fastapi_assets.request_validators.cookie_validator import CookieValidator
from fastapi_assets.request_validators.path_validator import PathValidator
110 changes: 46 additions & 64 deletions fastapi_assets/request_validators/cookie_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,8 @@
import inspect
import re
from typing import Any, Callable, Dict, List, Optional, Union

from fastapi import Request, status

from fastapi_assets.core.base_validator import BaseValidator
from fastapi_assets.core.exceptions import ValidationError
from fastapi_assets.core import BaseValidator, ValidationError


# Pre-built regex patterns for the `format` parameter
Expand All @@ -20,7 +17,7 @@
}


class CookieAssert(BaseValidator):
class CookieValidator(BaseValidator):
"""
A class-based dependency to validate FastAPI Cookies with granular control.

Expand All @@ -31,22 +28,27 @@ class CookieAssert(BaseValidator):
Example:
```python
from fastapi import FastAPI, Depends
from fastapi_assets.core import ValidationError
from fastapi_assets.request_validators import CookieValidator

app = FastAPI()

def is_whitelisted(user_id: str) -> bool:
# Logic to check if user_id is in a whitelist
return user_id in {"user_1", "user_2"}
def is_whitelisted(user_id: str) -> None:
# Logic to check if user_id is in a whitelist.
# Custom validators must raise ValidationError on failure.
if user_id not in {"user_1", "user_2"}:
raise ValidationError("User is not whitelisted.")

validate_session = CookieAssert(
"session-id", # This is the required 'alias'
# Create validators that will extract cookies from the incoming request
validate_session = CookieValidator(
"session-id", # Cookie name to extract from request.cookies
format="uuid4",
on_required_error_detail="Invalid or missing session ID.",
on_pattern_error_detail="Session ID must be a valid UUIDv4."
)

validate_user = CookieAssert(
"user-id",
validate_user = CookieValidator(
"user-id", # Cookie name to extract from request.cookies
min_length=6,
validators=[is_whitelisted],
on_length_error_detail="User ID must be at least 6 characters.",
Expand All @@ -55,10 +57,14 @@ def is_whitelisted(user_id: str) -> bool:

@app.get("/items/")
async def read_items(session_id: str = Depends(validate_session)):
# validate_session extracts the "session-id" cookie from the request,
# validates it, and returns the validated value
return {"session_id": session_id}

@app.get("/users/me")
async def read_user(user_id: str = Depends(validate_user)):
# validate_user extracts the "user-id" cookie from the request,
# validates it (including length and custom validators), and returns it
return {"user_id": user_id}
```
"""
Expand All @@ -80,7 +86,7 @@ def __init__(
regex: Optional[str] = None,
pattern: Optional[str] = None,
format: Optional[str] = None,
validators: Optional[List[Callable[[Any], bool]]] = None,
validators: Optional[List[Callable[[Any], Any]]] = None,
# Granular Error Messages
on_required_error_detail: str = "Cookie is required.",
on_numeric_error_detail: str = "Cookie value must be a number.",
Expand All @@ -89,7 +95,7 @@ def __init__(
on_pattern_error_detail: str = "Cookie has an invalid format.",
on_validator_error_detail: str = "Cookie failed custom validation.",
# Base Error
status_code: int = status.HTTP_400_BAD_REQUEST,
status_code: int = 400,
error_detail: str = "Cookie validation failed.",
) -> None:
"""
Expand Down Expand Up @@ -126,7 +132,7 @@ def __init__(
ValueError: If `regex`/`pattern` and `format` are used simultaneously.
ValueError: If an unknown `format` key is provided.
"""
super().__init__(status_code=status_code, error_detail=error_detail)
super().__init__(status_code=status_code, error_detail=error_detail, validators=validators)

# Store Core Parameters
self.alias = alias
Expand Down Expand Up @@ -158,29 +164,28 @@ def __init__(
# Handle Regex/Pattern
self.final_regex_str: Optional[str] = regex or pattern
if self.final_regex_str and format:
raise ValueError(
"Cannot use 'regex'/'pattern' and 'format' simultaneously."
)
raise ValueError("Cannot use 'regex'/'pattern' and 'format' simultaneously.")
if format:
if format not in PRE_BUILT_PATTERNS:
raise ValueError(
f"Unknown format: '{format}'. "
f"Available: {list(PRE_BUILT_PATTERNS.keys())}"
f"Unknown format: '{format}'. Available: {list(PRE_BUILT_PATTERNS.keys())}"
)
self.final_regex_str = PRE_BUILT_PATTERNS[format]

self.final_regex: Optional[re.Pattern[str]] = (
re.compile(self.final_regex_str)
if self.final_regex_str
else None
re.compile(self.final_regex_str) if self.final_regex_str else None
)

def _validate_numeric(self, value: str) -> Optional[float]:
"""
Tries to convert value to float. Returns float or None.

This check is only triggered if gt, ge, lt, or le are set.

Args:
value (str): The cookie value to convert.
Returns:
Optional[float]: The converted float value, or None if numeric checks
are not applicable.
Raises:
ValidationError: If conversion to float fails.
"""
Expand All @@ -198,6 +203,10 @@ def _validate_comparison(self, value: float) -> None:
"""
Checks gt, ge, lt, le rules against a numeric value.

Args:
value (float): The numeric value to compare.
Returns:
None
Raises:
ValidationError: If any comparison fails.
"""
Expand Down Expand Up @@ -226,6 +235,12 @@ def _validate_length(self, value: str) -> None:
"""
Checks min_length and max_length rules.

Args:
value (str): The cookie value to check.

Returns:
None

Raises:
ValidationError: If length constraints fail.
"""
Expand All @@ -244,6 +259,10 @@ def _validate_length(self, value: str) -> None:
def _validate_pattern(self, value: str) -> None:
"""
Checks regex/format pattern rule.
Args:
value (str): The cookie value to check.
Returns:
None

Raises:
ValidationError: If the regex pattern does not match.
Expand All @@ -254,43 +273,7 @@ def _validate_pattern(self, value: str) -> None:
status_code=status.HTTP_400_BAD_REQUEST,
)

async def _validate_custom(self, value: str) -> None:
"""
Runs all custom validator functions (sync or async).

Raises:
ValidationError: If any function returns False or raises an Exception.
"""
if not self.custom_validators:
return

for validator_func in self.custom_validators:
try:
is_valid = None
# Handle both sync and async validators
if inspect.iscoroutinefunction(validator_func):
is_valid = await validator_func(value)
else:
is_valid = validator_func(value)

if not is_valid:
raise ValidationError(
detail=self.err_validator,
status_code=status.HTTP_400_BAD_REQUEST,
)
except ValidationError:
# Re-raise our own validation errors
raise
except Exception as e:
# Validator function raising an error is a validation failure
raise ValidationError(
detail=f"{self.err_validator}: {e}",
status_code=status.HTTP_400_BAD_REQUEST,
)

async def _validate_logic(
self, cookie_value: Optional[str]
) -> Union[float, str, None]:
async def _validate(self, cookie_value: Optional[str]) -> Union[float, str, None]:
"""
Pure validation logic (testable without FastAPI).

Expand Down Expand Up @@ -355,11 +338,10 @@ async def __call__(self, request: Request) -> Union[float, str, None]:
try:
# Extract cookie value from request
cookie_value: Optional[str] = request.cookies.get(self.alias)

# Run all validation logic
return await self._validate_logic(cookie_value)
return await self._validate(cookie_value)

except ValidationError as e:
# Convert validation error to HTTP exception
self._raise_error(detail=e.detail, status_code=e.status_code)
return None # pragma: no cover
return None # pragma: no cover
Loading