diff --git a/FasterAPI/__init__.py b/FasterAPI/__init__.py index 9d2f403..befeca4 100644 --- a/FasterAPI/__init__.py +++ b/FasterAPI/__init__.py @@ -29,15 +29,30 @@ from .params import Body, Cookie, File, Form, Header, Path, Query from .request import Request from .response import ( + EventSourceResponse, FileResponse, HTMLResponse, JSONResponse, + ORJSONResponse, PlainTextResponse, RedirectResponse, Response, StreamingResponse, + UJSONResponse, ) from .router import FasterRouter, RadixRouter +from .security import ( + APIKeyCookie, + APIKeyHeader, + APIKeyQuery, + HTTPBasic, + HTTPBasicCredentials, + OAuth2PasswordBearer, + OAuth2PasswordRequestForm, + SecurityScopes, +) +from .staticfiles import StaticFiles +from .templating import Jinja2Templates from .websocket import WebSocket, WebSocketDisconnect, WebSocketState if TYPE_CHECKING: @@ -53,11 +68,14 @@ # Responses "Response", "JSONResponse", + "ORJSONResponse", + "UJSONResponse", "HTMLResponse", "PlainTextResponse", "RedirectResponse", "StreamingResponse", "FileResponse", + "EventSourceResponse", # Params "Body", "Cookie", @@ -90,6 +108,18 @@ # Concurrency "SubInterpreterPool", "run_in_subinterpreter", + # Security + "SecurityScopes", + "OAuth2PasswordBearer", + "OAuth2PasswordRequestForm", + "HTTPBasic", + "HTTPBasicCredentials", + "APIKeyHeader", + "APIKeyQuery", + "APIKeyCookie", + # Static files & Templates + "StaticFiles", + "Jinja2Templates", # Testing "TestClient", ] @@ -101,9 +131,7 @@ def __getattr__(name: str) -> Any: from .testclient import TestClient as _TestClient except ModuleNotFoundError as e: if getattr(e, "name", None) == "httpx": - raise ImportError( - "TestClient requires httpx. Install with: pip install httpx", - ) from e + raise ImportError("TestClient requires httpx. Install with: pip install httpx") from e raise return _TestClient raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/FasterAPI/app.py b/FasterAPI/app.py index 028917f..b9f8d54 100644 --- a/FasterAPI/app.py +++ b/FasterAPI/app.py @@ -10,14 +10,14 @@ from __future__ import annotations import asyncio +import contextlib +import dataclasses from collections.abc import Callable, Sequence from typing import Any, cast -import msgspec.json - from ._version import get_version from .concurrency import install_event_loop -from .dependencies import _resolve_handler, compile_handler +from .dependencies import Depends, _resolve_handler, compile_handler from .exceptions import ( HTTPException, RequestValidationError, @@ -27,7 +27,7 @@ from .openapi.generator import generate_openapi from .openapi.ui import redoc_html, swagger_ui_html from .request import Request -from .response import HTMLResponse, JSONResponse +from .response import HTMLResponse, JSONResponse, encode_json from .router import RadixRouter from .types import ASGIApp from .websocket import WebSocket @@ -36,7 +36,6 @@ _event_loop = install_event_loop() -# Pre-encode common header values to avoid repeated bytes() calls _CT_JSON = b"application/json" _CT_TEXT = b"text/plain; charset=utf-8" _CT_OCTET = b"application/octet-stream" @@ -54,15 +53,21 @@ class Faster: "openapi_url", "docs_url", "redoc_url", + "openapi_tags", + "terms_of_service", + "contact", + "license_info", "routes", "startup_handlers", "shutdown_handlers", + "lifespan", "middleware", "exception_handlers", "_router", "_openapi_cache", "_middleware_app", "_ws_routes", + "_mounts", ) def __init__( @@ -74,6 +79,11 @@ def __init__( openapi_url: str | None = "/openapi.json", docs_url: str | None = "/docs", redoc_url: str | None = "/redoc", + openapi_tags: list[dict[str, Any]] | None = None, + terms_of_service: str | None = None, + contact: dict[str, str] | None = None, + license_info: dict[str, str] | None = None, + lifespan: Callable[[Faster], Any] | None = None, ) -> None: self.title = title self.version = version if version is not None else get_version() @@ -81,15 +91,21 @@ def __init__( self.openapi_url = openapi_url self.docs_url = docs_url self.redoc_url = redoc_url + self.openapi_tags = openapi_tags + self.terms_of_service = terms_of_service + self.contact = contact + self.license_info = license_info + self.lifespan = lifespan self.routes: list[dict[str, Any]] = [] - self.startup_handlers: list[ASGIApp] = [] - self.shutdown_handlers: list[ASGIApp] = [] + self.startup_handlers: list[Callable[[], Any]] = [] + self.shutdown_handlers: list[Callable[[], Any]] = [] self.middleware: list[dict[str, Any]] = [] - self.exception_handlers: dict[type, ASGIApp] = {} + self.exception_handlers: dict[type, Any] = {} self._router = RadixRouter() self._openapi_cache: dict[str, Any] | None = None self._middleware_app: ASGIApp | None = None self._ws_routes: dict[str, ASGIApp] = {} + self._mounts: list[tuple[str, ASGIApp]] = [] self._setup_openapi_routes() def __repr__(self) -> str: @@ -110,6 +126,10 @@ async def openapi_schema() -> JSONResponse: title=app_ref.title, version=app_ref.version, description=app_ref.description, + openapi_tags=app_ref.openapi_tags, + terms_of_service=app_ref.terms_of_service, + contact=app_ref.contact, + license_info=app_ref.license_info, ) return JSONResponse(spec) @@ -122,6 +142,8 @@ async def openapi_schema() -> JSONResponse: response_model=None, status_code=200, deprecated=False, + responses=None, + dependencies=None, ) if self.docs_url is not None and self.openapi_url is not None: @@ -139,6 +161,8 @@ async def swagger_docs() -> HTMLResponse: response_model=None, status_code=200, deprecated=False, + responses=None, + dependencies=None, ) if self.redoc_url is not None and self.openapi_url is not None: @@ -156,6 +180,8 @@ async def redoc_docs() -> HTMLResponse: response_model=None, status_code=200, deprecated=False, + responses=None, + dependencies=None, ) # ------------------------------------------------------------------ @@ -165,8 +191,8 @@ async def redoc_docs() -> HTMLResponse: async def __call__( self, scope: dict[str, Any], - receive: ASGIApp, - send: ASGIApp, + receive: Any, + send: Any, ) -> None: if self.middleware: if self._middleware_app is None: @@ -178,11 +204,20 @@ async def __call__( async def _asgi_app( self, scope: dict[str, Any], - receive: ASGIApp, - send: ASGIApp, + receive: Any, + send: Any, ) -> None: scope_type = scope["type"] if scope_type == "http": + # Check mounts first + path: str = scope.get("path", "/") + for prefix, mounted_app in self._mounts: + if path == prefix or path.startswith(prefix + "/"): + sub_scope = dict(scope) + sub_scope["path"] = path[len(prefix) :] or "/" + sub_scope["root_path"] = scope.get("root_path", "") + prefix + await mounted_app(sub_scope, receive, send) + return await self._handle_http(scope, receive, send) elif scope_type == "websocket": await self._handle_websocket(scope, receive, send) @@ -202,8 +237,8 @@ def _build_middleware_chain(self) -> ASGIApp: async def _handle_http( self, scope: dict[str, Any], - receive: ASGIApp, - send: ASGIApp, + receive: Any, + send: Any, ) -> None: result = self._router.resolve(scope["method"], scope["path"]) if result is None: @@ -215,8 +250,10 @@ async def _handle_http( request = Request(scope, receive) bg_tasks = None + extra_deps: list[Depends] | None = metadata.get("dependencies") + try: - response, bg_tasks = await _resolve_handler(handler, request, path_params) + response, bg_tasks = await _resolve_handler(handler, request, path_params, extra_deps) except RequestValidationError as exc: status, body, headers = await self._handle_exc( request, @@ -246,6 +283,13 @@ async def _handle_http( await _send_error(send, 500, "Internal Server Error") return + # Apply response_model filtering if configured + response_model = metadata.get("response_model") + response_model_include = metadata.get("response_model_include") + response_model_exclude = metadata.get("response_model_exclude") + if response_model is not None and not hasattr(response, "to_asgi"): + response = _apply_response_model(response, response_model, response_model_include, response_model_exclude) + await _send_response(send, metadata.get("status_code", 200), response) if bg_tasks is not None: @@ -256,7 +300,7 @@ async def _handle_exc( request: Request, exc: Exception, exc_class: type, - default_handler: ASGIApp, + default_handler: Any, ) -> tuple[int, bytes, list[tuple[bytes, bytes]]]: handler = self.exception_handlers.get(exc_class, default_handler) result = handler(request, exc) @@ -271,8 +315,8 @@ async def _handle_exc( async def _handle_websocket( self, scope: dict[str, Any], - receive: ASGIApp, - send: ASGIApp, + receive: Any, + send: Any, ) -> None: path = scope.get("path", "/") handler = self._ws_routes.get(path.rstrip("/") or "/") @@ -289,9 +333,13 @@ async def _handle_websocket( async def _handle_lifespan( self, scope: dict[str, Any], - receive: ASGIApp, - send: ASGIApp, + receive: Any, + send: Any, ) -> None: + if self.lifespan is not None: + await self._run_lifespan_context(receive, send) + return + while True: message = await receive() if message["type"] == "lifespan.startup": @@ -315,6 +363,33 @@ async def _handle_lifespan( pass return + async def _run_lifespan_context(self, receive: Any, send: Any) -> None: + """Run the lifespan async context manager.""" + assert self.lifespan is not None + ctx = self.lifespan(self) + # Support both @asynccontextmanager functions and plain async generators + if hasattr(ctx, "__aenter__"): + async with ctx: + message = await receive() + if message["type"] == "lifespan.startup": + await send({"type": "lifespan.startup.complete"}) + message = await receive() + if message["type"] == "lifespan.shutdown": + await send({"type": "lifespan.shutdown.complete"}) + else: + # Treat as async generator + gen = ctx.__aiter__() if hasattr(ctx, "__aiter__") else ctx + with contextlib.suppress(StopAsyncIteration): + await gen.__anext__() + message = await receive() + if message["type"] == "lifespan.startup": + await send({"type": "lifespan.startup.complete"}) + message = await receive() + if message["type"] == "lifespan.shutdown": + with contextlib.suppress(StopAsyncIteration): + await gen.__anext__() + await send({"type": "lifespan.shutdown.complete"}) + # ------------------------------------------------------------------ # Route registration # ------------------------------------------------------------------ @@ -330,17 +405,25 @@ def _add_route( response_model: Any, status_code: int, deprecated: bool, + responses: dict[int | str, dict[str, Any]] | None, + dependencies: list[Depends] | None, + response_model_include: set[str] | None = None, + response_model_exclude: set[str] | None = None, ) -> None: - metadata = { + metadata: dict[str, Any] = { "tags": tags, "summary": summary, "response_model": response_model, + "response_model_include": response_model_include, + "response_model_exclude": response_model_exclude, "status_code": status_code, "deprecated": deprecated, + "responses": responses, + "dependencies": dependencies, } self.routes.append({"method": method, "path": path, "handler": handler, **metadata}) self._router.add_route(method, path, handler, metadata) - compile_handler(handler) # pre-compile at registration time + compile_handler(handler) def _route_decorator(self, method: str, path: str, **kw: Any) -> Callable[[ASGIApp], ASGIApp]: def decorator(handler: ASGIApp) -> ASGIApp: @@ -351,8 +434,12 @@ def decorator(handler: ASGIApp) -> ASGIApp: tags=kw.get("tags") or [], summary=kw.get("summary", ""), response_model=kw.get("response_model"), + response_model_include=kw.get("response_model_include"), + response_model_exclude=kw.get("response_model_exclude"), status_code=kw.get("status_code", 200), deprecated=kw.get("deprecated", False), + responses=kw.get("responses"), + dependencies=kw.get("dependencies"), ) return handler @@ -380,15 +467,29 @@ def decorator(handler: ASGIApp) -> ASGIApp: return decorator + # ------------------------------------------------------------------ + # Sub-application mounting + # ------------------------------------------------------------------ + + def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None: + """Mount an ASGI sub-application (e.g. StaticFiles) at *path*. + + Example:: + + app.mount("/static", StaticFiles(directory="static"), name="static") + """ + prefix = path.rstrip("/") + self._mounts.append((prefix, app)) + # ------------------------------------------------------------------ # Lifecycle hooks # ------------------------------------------------------------------ - def on_startup(self, handler: ASGIApp) -> ASGIApp: + def on_startup(self, handler: Callable[[], Any]) -> Callable[[], Any]: self.startup_handlers.append(handler) return handler - def on_shutdown(self, handler: ASGIApp) -> ASGIApp: + def on_shutdown(self, handler: Callable[[], Any]) -> Callable[[], Any]: self.shutdown_handlers.append(handler) return handler @@ -400,7 +501,7 @@ def add_middleware(self, middleware_class: type, **kwargs: Any) -> None: self.middleware.append({"class": middleware_class, "kwargs": kwargs}) self._middleware_app = None # invalidate cached chain - def add_exception_handler(self, exc_class: type, handler: ASGIApp) -> None: + def add_exception_handler(self, exc_class: type, handler: Any) -> None: self.exception_handlers[exc_class] = handler # ------------------------------------------------------------------ @@ -413,24 +514,73 @@ def include_router( *, prefix: str = "", tags: Sequence[str] = (), + dependencies: list[Depends] | None = None, ) -> None: pfx = prefix.rstrip("/") for route in router.routes: merged = dict(route) merged["path"] = pfx + merged["path"] merged["tags"] = list(tags) + merged["tags"] + # Merge router-level dependencies with route-level dependencies + route_deps: list[Depends] = merged.get("dependencies") or [] + router_deps: list[Depends] = getattr(router, "dependencies", None) or [] + caller_deps: list[Depends] = dependencies or [] + merged_deps = caller_deps + router_deps + route_deps + merged["dependencies"] = merged_deps if merged_deps else None self.routes.append(merged) metadata = {k: v for k, v in merged.items() if k not in ("method", "path", "handler")} self._router.add_route(merged["method"], merged["path"], merged["handler"], metadata) compile_handler(merged["handler"]) +# ------------------------------------------------------------------ +# Response model filtering +# ------------------------------------------------------------------ + + +def _apply_response_model( + result: Any, + model: type, + include: set[str] | None, + exclude: set[str] | None, +) -> Any: + """Filter *result* to only the fields defined in *model*.""" + import msgspec.structs + + if isinstance(result, msgspec.Struct): + data = {f.name: getattr(result, f.name) for f in msgspec.structs.fields(result)} + elif dataclasses.is_dataclass(result) and not isinstance(result, type): + data = dataclasses.asdict(result) + elif isinstance(result, dict): + data = result + else: + return result + + # Determine allowed field names from the response model + try: + if issubclass(model, msgspec.Struct): + allowed: set[str] = {f.name for f in msgspec.structs.fields(model)} + elif dataclasses.is_dataclass(model): + allowed = {f.name for f in dataclasses.fields(model)} + else: + allowed = set(data.keys()) + except TypeError: + allowed = set(data.keys()) + + if include is not None: + allowed &= include + if exclude is not None: + allowed -= exclude + + return {k: v for k, v in data.items() if k in allowed} + + # ------------------------------------------------------------------ # Module-level send helpers (avoid method lookup on self) # ------------------------------------------------------------------ -async def _send_response(send: ASGIApp, status_code: int, body: Any) -> None: +async def _send_response(send: Any, status_code: int, body: Any) -> None: if hasattr(body, "to_asgi"): await body.to_asgi(send) return @@ -443,14 +593,14 @@ async def _send_response(send: ASGIApp, status_code: int, body: Any) -> None: body = b"" ct = _CT_PLAIN else: - body = msgspec.json.encode(body) + body = encode_json(body) ct = _CT_JSON await send({"type": "http.response.start", "status": status_code, "headers": [(_HEADER_CT, ct)]}) await send({"type": "http.response.body", "body": body}) async def _send_raw( - send: ASGIApp, + send: Any, status: int, body: bytes, headers: list[tuple[bytes, bytes]], @@ -459,7 +609,7 @@ async def _send_raw( await send({"type": "http.response.body", "body": body}) -async def _send_error(send: ASGIApp, status: int, message: str) -> None: - body = msgspec.json.encode({"detail": message}) +async def _send_error(send: Any, status: int, message: str) -> None: + body = encode_json({"detail": message}) await send({"type": "http.response.start", "status": status, "headers": [(_HEADER_CT, _CT_JSON)]}) await send({"type": "http.response.body", "body": body}) diff --git a/FasterAPI/dependencies.py b/FasterAPI/dependencies.py index 3c9bdbb..6780432 100644 --- a/FasterAPI/dependencies.py +++ b/FasterAPI/dependencies.py @@ -8,11 +8,12 @@ from __future__ import annotations +import dataclasses import inspect import typing from collections.abc import Callable from functools import lru_cache -from typing import Any +from typing import Annotated, Any, get_args, get_origin import msgspec @@ -35,12 +36,13 @@ class Depends: __slots__ = ("dependency", "use_cache") - def __init__(self, dependency: Callable[..., Any], *, use_cache: bool = True) -> None: + def __init__(self, dependency: Callable[..., Any] | None = None, *, use_cache: bool = True) -> None: self.dependency = dependency self.use_cache = use_cache def __repr__(self) -> str: - return f"Depends({self.dependency.__name__})" + name = self.dependency.__name__ if self.dependency else "..." + return f"Depends({name})" # --------------------------------------------------------------------------- @@ -59,6 +61,8 @@ def __repr__(self) -> str: _KIND_FORM = 9 _KIND_BODY = 10 _KIND_FALLBACK = 11 +_KIND_DATACLASS = 12 +_KIND_FORM_BODY = 13 # OAuth2PasswordRequestForm and similar class-based form deps class _ParamSpec: @@ -81,6 +85,30 @@ def __init__( self.marker = marker +# --------------------------------------------------------------------------- +# Annotated unwrapping helpers +# --------------------------------------------------------------------------- + + +def _unwrap_annotated(annotation: Any, default: Any) -> tuple[Any, Any]: + """If *annotation* is Annotated[T, marker, ...], return (T, first_marker). + + Supports PEP 593 style: Annotated[str, Depends(get_token)] + Annotated[str, Query(alias="q")] + """ + if get_origin(annotation) is not Annotated: + return annotation, default + + args = get_args(annotation) + inner_type = args[0] + # Walk metadata left-to-right; use the first recognised marker. + _marker_types = (Depends, Path, Query, Header, Cookie, Body, File, Form) + for meta in args[1:]: + if isinstance(meta, _marker_types): + return inner_type, meta + return inner_type, default + + # --------------------------------------------------------------------------- # Compile handler (called once at route registration) # --------------------------------------------------------------------------- @@ -88,29 +116,52 @@ def __init__( @lru_cache(maxsize=512) def compile_handler(func: Callable[..., Any]) -> tuple[tuple[_ParamSpec, ...], bool]: - """Introspect *func* once and return a tuple of _ParamSpec plus is-async flag. - - This replaces per-request inspect.signature + get_type_hints calls. - """ - sig = inspect.signature(func) - try: - type_hints = typing.get_type_hints(func) - except Exception: - type_hints = {} + """Introspect *func* once and return a tuple of _ParamSpec plus is-async flag.""" + # For class-based dependencies (e.g. OAuth2PasswordRequestForm), inspect __init__ + if inspect.isclass(func): + sig = inspect.signature(func) + try: + type_hints = typing.get_type_hints(func.__init__, include_extras=True) + except Exception: + type_hints = {} + type_hints.pop("return", None) + is_async = False + elif callable(func) and not inspect.isfunction(func) and not inspect.isbuiltin(func): + # Callable instance (e.g. OAuth2PasswordBearer instance) — use the class __call__ + sig = inspect.signature(func) + call_method = inspect.getattr_static(type(func), "__call__", None) + try: + type_hints = typing.get_type_hints(call_method, include_extras=True) if call_method else {} + except Exception: + type_hints = {} + is_async = is_coroutine(call_method) if call_method else False + else: + sig = inspect.signature(func) + try: + type_hints = typing.get_type_hints(func, include_extras=True) + except Exception: + type_hints = {} + is_async = is_coroutine(func) specs: list[_ParamSpec] = [] for name, param in sig.parameters.items(): - annotation = type_hints.get(name, param.annotation) - default = param.default + raw_annotation = type_hints.get(name, param.annotation) + raw_default = param.default + + # Unwrap Annotated[T, marker] — PEP 593 support + annotation, default = _unwrap_annotated(raw_annotation, raw_default) if annotation is BackgroundTasks: specs.append(_ParamSpec(name, _KIND_BG_TASKS, annotation, default, None)) elif annotation is Request: specs.append(_ParamSpec(name, _KIND_REQUEST, annotation, default, None)) elif isinstance(default, Depends): + # Annotated[str, Depends(fn)] or foo: str = Depends(fn) specs.append(_ParamSpec(name, _KIND_DEPENDS, annotation, default, default)) elif _is_struct_type(annotation): specs.append(_ParamSpec(name, _KIND_STRUCT, annotation, default, None)) + elif _is_dataclass_type(annotation): + specs.append(_ParamSpec(name, _KIND_DATACLASS, annotation, default, None)) elif isinstance(default, Path): specs.append(_ParamSpec(name, _KIND_PATH, annotation, default, default)) elif isinstance(default, Query): @@ -128,7 +179,7 @@ def compile_handler(func: Callable[..., Any]) -> tuple[tuple[_ParamSpec, ...], b else: specs.append(_ParamSpec(name, _KIND_FALLBACK, annotation, default, None)) - return tuple(specs), is_coroutine(func) + return tuple(specs), is_async # --------------------------------------------------------------------------- @@ -140,11 +191,18 @@ async def _resolve_handler( handler: Callable[..., Any], request: Request, path_params: dict[str, str], + extra_deps: list[Depends] | None = None, ) -> tuple[Any, BackgroundTasks | None]: """Resolve dependencies, call handler, return (result, bg_tasks|None).""" specs, is_async = compile_handler(handler) cache: dict[Callable[..., Any], Any] = {} bg_tasks = BackgroundTasks() + + # Router-level dependencies resolved before handler params + if extra_deps: + for dep in extra_deps: + await _resolve_dependency(dep, request, path_params, cache, bg_tasks) + kwargs = await _resolve_from_specs(specs, request, path_params, cache, bg_tasks) result = await handler(**kwargs) if is_async else handler(**kwargs) @@ -182,6 +240,12 @@ async def _resolve_from_specs( request, spec.default, ) + elif kind == _KIND_DATACLASS: + kwargs[spec.name] = await _resolve_dataclass( + spec.annotation, + request, + spec.default, + ) elif kind == _KIND_PATH: kwargs[spec.name] = _resolve_path(spec.name, path_params, spec.marker) elif kind == _KIND_QUERY: @@ -222,12 +286,29 @@ async def _resolve_dependency( bg_tasks: BackgroundTasks, ) -> Any: func = dep.dependency + if func is None: + return None + if dep.use_cache and func in cache: return cache[func] + # Special case: OAuth2PasswordRequestForm and similar class-based form-body deps + if inspect.isclass(func) and hasattr(func, "from_request"): + result = await func.from_request(request) + if dep.use_cache: + cache[func] = result + return result + specs, is_async = compile_handler(func) dep_kwargs = await _resolve_from_specs(specs, request, path_params, cache, bg_tasks) - result = await func(**dep_kwargs) if is_async else func(**dep_kwargs) + + if inspect.isclass(func): + # Instantiate the class synchronously + result = func(**dep_kwargs) + elif is_async: + result = await func(**dep_kwargs) + else: + result = func(**dep_kwargs) if dep.use_cache: cache[func] = result @@ -247,6 +328,14 @@ def _is_struct_type(annotation: Any) -> bool: ) +def _is_dataclass_type(annotation: Any) -> bool: + return ( + annotation is not inspect.Parameter.empty + and isinstance(annotation, type) + and dataclasses.is_dataclass(annotation) + ) + + def _is_upload_file_type(annotation: Any) -> bool: return ( annotation is not inspect.Parameter.empty @@ -271,6 +360,26 @@ async def _resolve_struct( ) from exc +async def _resolve_dataclass( + dc_type: type, + request: Request, + default: Any, +) -> Any: + """Resolve a standard @dataclass from the JSON request body.""" + try: + raw = await request._read_body() + data = msgspec.json.decode(raw, type=dict) + fields = {f.name for f in dataclasses.fields(dc_type)} + filtered = {k: v for k, v in data.items() if k in fields} + return dc_type(**filtered) + except Exception as exc: + if default is not inspect.Parameter.empty: + return default + raise RequestValidationError( + [{"loc": ["body"], "msg": str(exc), "type": "value_error.dataclass"}], + ) from exc + + def _resolve_path(name: str, path_params: dict[str, str], marker: Path) -> Any: if name in path_params: return path_params[name] diff --git a/FasterAPI/openapi/generator.py b/FasterAPI/openapi/generator.py index df23e4b..408f769 100644 --- a/FasterAPI/openapi/generator.py +++ b/FasterAPI/openapi/generator.py @@ -1,11 +1,16 @@ from __future__ import annotations +import dataclasses +import datetime +import decimal +import enum import inspect import re import types import typing +import uuid from collections.abc import Callable -from typing import Any, Union, get_args, get_origin +from typing import Annotated, Any, Union, get_args, get_origin import msgspec @@ -19,6 +24,10 @@ def generate_openapi( title: str = "FasterAPI", version: str | None = None, description: str = "", + openapi_tags: list[dict[str, Any]] | None = None, + terms_of_service: str | None = None, + contact: dict[str, str] | None = None, + license_info: dict[str, str] | None = None, ) -> dict[str, Any]: """Generate an OpenAPI 3.0.3 spec dict from a Faster app instance.""" if version is None: @@ -34,23 +43,28 @@ def generate_openapi( method = route["method"].lower() raw_path = route["path"] handler = route["handler"] - - # Convert {param} to OpenAPI {param} (already compatible) openapi_path = raw_path operation = _build_operation(route, handler, schemas) paths.setdefault(openapi_path, {})[method] = operation + info: dict[str, Any] = {"title": title, "version": version} + if description: + info["description"] = description + if terms_of_service: + info["termsOfService"] = terms_of_service + if contact: + info["contact"] = contact + if license_info: + info["license"] = license_info + spec: dict[str, Any] = { "openapi": "3.0.3", - "info": { - "title": title, - "version": version, - }, + "info": info, "paths": paths, } - if description: - spec["info"]["description"] = description + if openapi_tags: + spec["tags"] = openapi_tags if schemas: spec["components"] = {"schemas": schemas} @@ -65,38 +79,32 @@ def _build_operation( ) -> dict[str, Any]: operation: dict[str, Any] = {} - # Tags tags = route.get("tags", []) if tags: operation["tags"] = tags - # Summary from decorator or function name summary = route.get("summary", "") if summary: operation["summary"] = summary else: operation["summary"] = handler.__name__.replace("_", " ").title() - # Description from docstring doc = inspect.getdoc(handler) if doc: operation["description"] = doc - # Deprecated if route.get("deprecated", False): operation["deprecated"] = True - # Operation ID operation["operationId"] = handler.__name__ - # Parameters and request body parameters, request_body = _extract_params(route, handler, schemas) if parameters: operation["parameters"] = parameters if request_body: operation["requestBody"] = request_body - # Responses + # Build responses dict status_code = str(route.get("status_code", 200)) response_model = route.get("response_model") responses: dict[str, Any] = {} @@ -110,26 +118,39 @@ def _build_operation( else: responses[status_code] = {"description": "Successful Response"} - # 422 for routes that have body/query/path params + # Merge additional responses declared via responses={404: {...}} + extra_responses: dict[int | str, dict[str, Any]] | None = route.get("responses") + if extra_responses: + for code, resp_def in extra_responses.items(): + key = str(code) + merged: dict[str, Any] = {"description": resp_def.get("description", "Response")} + model = resp_def.get("model") + if model is not None: + extra_schema = _type_to_schema(model, schemas) + merged["content"] = {"application/json": {"schema": extra_schema}} + elif "content" in resp_def: + merged["content"] = resp_def["content"] + responses[key] = merged + if parameters or request_body: - responses["422"] = { - "description": "Validation Error", - "content": { - "application/json": { - "schema": { - "type": "object", - "properties": { - "detail": { - "type": "array", - "items": { - "type": "object", - "properties": { - "loc": { - "type": "array", - "items": {"type": "string"}, + responses.setdefault( + "422", + { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "detail": { + "type": "array", + "items": { + "type": "object", + "properties": { + "loc": {"type": "array", "items": {"type": "string"}}, + "msg": {"type": "string"}, + "type": {"type": "string"}, }, - "msg": {"type": "string"}, - "type": {"type": "string"}, }, }, }, @@ -137,7 +158,7 @@ def _build_operation( }, }, }, - } + ) operation["responses"] = responses return operation @@ -151,7 +172,6 @@ def _extract_params( parameters: list[dict[str, Any]] = [] request_body: dict[str, Any] | None = None - # Extract {param} names from path path_param_names = set(re.findall(r"\{(\w+)\}", route["path"])) try: @@ -165,22 +185,20 @@ def _extract_params( hints = {} for name, param in sig.parameters.items(): - annotation = hints.get(name, param.annotation) - # If annotation is still a string, try to resolve from param default - if isinstance(annotation, str): - annotation = param.annotation - default = param.default + raw_annotation = hints.get(name, param.annotation) + raw_default = param.default + + # Unwrap Annotated[T, marker] + annotation, default = _unwrap_annotated_for_openapi(raw_annotation, raw_default) - # Skip Request injection from ..request import Request as RequestClass if annotation is RequestClass: continue - # Skip Depends from ..dependencies import Depends - if isinstance(default, Depends): + if isinstance(default, Depends) or isinstance(raw_default, Depends): continue # Path parameter @@ -189,7 +207,7 @@ def _extract_params( "name": name, "in": "path", "required": True, - "schema": _annotation_to_schema(annotation), + "schema": _annotation_to_schema(annotation, schemas), } desc = default.description if isinstance(default, Path) else "" if desc: @@ -203,7 +221,7 @@ def _extract_params( "name": default.alias or name, "in": "query", "required": default.default is None and not _is_optional(annotation), - "schema": _annotation_to_schema(annotation), + "schema": _annotation_to_schema(annotation, schemas), } if default.description: p["description"] = default.description @@ -219,7 +237,7 @@ def _extract_params( "name": header_name, "in": "header", "required": default.default is None, - "schema": _annotation_to_schema(annotation), + "schema": _annotation_to_schema(annotation, schemas), } if default.default is not None: p["schema"]["default"] = default.default @@ -232,15 +250,15 @@ def _extract_params( "name": name, "in": "cookie", "required": default.default is None, - "schema": _annotation_to_schema(annotation), + "schema": _annotation_to_schema(annotation, schemas), } if default.default is not None: p["schema"]["default"] = default.default parameters.append(p) continue - # Body / msgspec.Struct - if isinstance(default, Body) or _is_struct_type(annotation): + # Body / msgspec.Struct / dataclass + if isinstance(default, Body) or _is_struct_type(annotation) or _is_dataclass_type(annotation): schema = _type_to_schema(annotation, schemas) request_body = { "required": True, @@ -251,6 +269,18 @@ def _extract_params( return parameters, request_body +def _unwrap_annotated_for_openapi(annotation: Any, default: Any) -> tuple[Any, Any]: + if get_origin(annotation) is not Annotated: + return annotation, default + args = get_args(annotation) + inner = args[0] + _marker_types = (Path, Query, Header, Cookie, Body) + for meta in args[1:]: + if isinstance(meta, _marker_types): + return inner, meta + return inner, default + + def _is_struct_type(annotation: Any) -> bool: return ( annotation is not inspect.Parameter.empty @@ -259,6 +289,14 @@ def _is_struct_type(annotation: Any) -> bool: ) +def _is_dataclass_type(annotation: Any) -> bool: + return ( + annotation is not inspect.Parameter.empty + and isinstance(annotation, type) + and dataclasses.is_dataclass(annotation) + ) + + def _is_optional(annotation: Any) -> bool: origin = get_origin(annotation) if origin is Union or origin is types.UnionType: @@ -288,13 +326,37 @@ def _python_type_to_schema( return {"type": "number"} if tp is bool: return {"type": "boolean"} - - # Check if it's a struct type (before checking origin) + if tp is datetime.datetime: + return {"type": "string", "format": "date-time"} + if tp is datetime.date: + return {"type": "string", "format": "date"} + if tp is datetime.time: + return {"type": "string", "format": "time"} + if tp is uuid.UUID: + return {"type": "string", "format": "uuid"} + if tp is decimal.Decimal: + return {"type": "string", "format": "decimal"} + + # Enum + if isinstance(tp, type) and issubclass(tp, enum.Enum): + values = [m.value for m in tp] + # Infer type from values + if all(isinstance(v, int) for v in values): + return {"type": "integer", "enum": values} + return {"type": "string", "enum": values} + + # msgspec.Struct if _is_struct_type(tp): if schemas is not None: return _struct_to_ref(tp, schemas) return {"type": "object"} + # dataclass + if _is_dataclass_type(tp): + if schemas is not None: + return _dataclass_to_ref(tp, schemas) + return {"type": "object"} + origin = get_origin(tp) args = get_args(tp) @@ -315,10 +377,10 @@ def _python_type_to_schema( # dict / Dict[K, V] if origin is dict: - schema = {"type": "object"} + dict_schema: dict[str, Any] = {"type": "object"} if args and len(args) == 2: - schema["additionalProperties"] = _python_type_to_schema(args[1], schemas) - return schema + dict_schema["additionalProperties"] = _python_type_to_schema(args[1], schemas) + return dict_schema return {"type": "string"} @@ -333,6 +395,9 @@ def _type_to_schema( if _is_struct_type(tp): return _struct_to_ref(tp, schemas) + if _is_dataclass_type(tp): + return _dataclass_to_ref(tp, schemas) + return _python_type_to_schema(tp, schemas) @@ -341,10 +406,18 @@ def _struct_to_ref( schemas: dict[str, Any], ) -> dict[str, Any]: name = struct_type.__name__ - if name not in schemas: schemas[name] = _struct_to_schema(struct_type, schemas) + return {"$ref": f"#/components/schemas/{name}"} + +def _dataclass_to_ref( + dc_type: type, + schemas: dict[str, Any], +) -> dict[str, Any]: + name = dc_type.__name__ + if name not in schemas: + schemas[name] = _dataclass_to_schema(dc_type, schemas) return {"$ref": f"#/components/schemas/{name}"} @@ -372,10 +445,8 @@ def _struct_to_schema( if required: schema["required"] = required - # Only use docstring if defined directly on this class, not inherited doc = struct_type.__doc__ if doc and doc != msgspec.Struct.__doc__: - # Clean up the docstring doc = inspect.cleandoc(doc) if doc: schema["description"] = doc @@ -383,6 +454,41 @@ def _struct_to_schema( return schema +def _dataclass_to_schema( + dc_type: type, + schemas: dict[str, Any], +) -> dict[str, Any]: + properties: dict[str, Any] = {} + required: list[str] = [] + + try: + hints = typing.get_type_hints(dc_type) + except Exception: + hints = {f.name: f.type for f in dataclasses.fields(dc_type)} + + defaults = { + f.name + for f in dataclasses.fields(dc_type) + if f.default is not dataclasses.MISSING or callable(f.default_factory) + } + + for field_name, field_type in hints.items(): + prop = _type_to_schema(field_type, schemas) + properties[field_name] = prop + if field_name not in defaults and not _is_optional(field_type): + required.append(field_name) + + schema: dict[str, Any] = {"type": "object", "properties": properties} + if required: + schema["required"] = required + + doc = inspect.getdoc(dc_type) + if doc: + schema["description"] = doc + + return schema + + def _get_struct_fields(struct_type: type) -> dict[str, Any]: try: hints = typing.get_type_hints(struct_type) @@ -391,7 +497,6 @@ def _get_struct_fields(struct_type: type) -> dict[str, Any]: for cls in reversed(struct_type.__mro__): if hasattr(cls, "__annotations__"): hints.update(cls.__annotations__) - # Remove non-field annotations hints.pop("__struct_fields__", None) hints.pop("__struct_config__", None) return hints diff --git a/FasterAPI/response.py b/FasterAPI/response.py index b87fbd5..38f6ad7 100644 --- a/FasterAPI/response.py +++ b/FasterAPI/response.py @@ -1,7 +1,10 @@ from __future__ import annotations import asyncio +import datetime +import decimal import mimetypes +import uuid from collections.abc import AsyncIterator, Iterator from pathlib import Path from typing import Any @@ -11,6 +14,26 @@ from .types import ASGIApp +def _enc_hook(obj: Any) -> Any: + """Custom encoder for types not natively supported by msgspec.""" + if isinstance(obj, datetime.datetime): + return obj.isoformat() + if isinstance(obj, datetime.date): + return obj.isoformat() + if isinstance(obj, datetime.time): + return obj.isoformat() + if isinstance(obj, uuid.UUID): + return str(obj) + if isinstance(obj, decimal.Decimal): + return str(obj) + raise TypeError(f"Unsupported type: {type(obj)!r}") + + +def encode_json(content: Any) -> bytes: + """Encode content to JSON bytes, handling datetime/UUID/Decimal.""" + return msgspec.json.encode(content, enc_hook=_enc_hook) + + class Response: """Base HTTP response class.""" @@ -66,12 +89,17 @@ async def to_asgi(self, send: ASGIApp) -> None: class JSONResponse(Response): - """Response that serializes content as JSON using msgspec.""" + """Response that serializes content as JSON using msgspec (with datetime/UUID/Decimal support).""" media_type = "application/json" def _render(self, content: Any) -> bytes: - return msgspec.json.encode(content) + return encode_json(content) + + +# ORJSONResponse and UJSONResponse are aliases — msgspec is faster than both. +ORJSONResponse = JSONResponse +UJSONResponse = JSONResponse class HTMLResponse(Response): @@ -173,6 +201,80 @@ async def to_asgi(self, send: ASGIApp) -> None: await send({"type": "http.response.body", "body": b"", "more_body": False}) +class EventSourceResponse: + """Server-Sent Events (SSE) response. + + Streams events to the client in the ``text/event-stream`` format. + + Usage:: + + async def event_generator(): + yield {"data": "hello"} + yield {"event": "update", "data": "world", "id": "1"} + + @app.get("/stream") + async def stream(): + return EventSourceResponse(event_generator()) + """ + + def __init__( + self, + content: AsyncIterator[dict[str, str]] | Iterator[dict[str, str]], + status_code: int = 200, + headers: dict[str, str] | None = None, + ping_interval: float | None = None, + ) -> None: + self.content = content + self.status_code = status_code + self.headers = headers or {} + self.ping_interval = ping_interval + + @staticmethod + def _format_event(event: dict[str, str] | str) -> bytes: + if isinstance(event, str): + return f"data: {event}\n\n".encode() + lines: list[str] = [] + if "id" in event: + lines.append(f"id: {event['id']}") + if "event" in event: + lines.append(f"event: {event['event']}") + if "data" in event: + for line in event["data"].splitlines(): + lines.append(f"data: {line}") + if "retry" in event: + lines.append(f"retry: {event['retry']}") + return ("\n".join(lines) + "\n\n").encode() + + def _build_headers(self) -> list[tuple[bytes, bytes]]: + raw: list[tuple[bytes, bytes]] = [ + (b"content-type", b"text/event-stream"), + (b"cache-control", b"no-cache"), + (b"connection", b"keep-alive"), + (b"x-accel-buffering", b"no"), + ] + for key, value in self.headers.items(): + raw.append((key.lower().encode("latin-1"), value.encode("latin-1"))) + return raw + + async def to_asgi(self, send: ASGIApp) -> None: + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self._build_headers(), + } + ) + if hasattr(self.content, "__aiter__"): + async for event in self.content: + chunk = self._format_event(event) + await send({"type": "http.response.body", "body": chunk, "more_body": True}) + else: + for event in self.content: + chunk = self._format_event(event) + await send({"type": "http.response.body", "body": chunk, "more_body": True}) + await send({"type": "http.response.body", "body": b"", "more_body": False}) + + class FileResponse: """Response that sends a file as an attachment.""" diff --git a/FasterAPI/router.py b/FasterAPI/router.py index fd8f69e..c50c906 100644 --- a/FasterAPI/router.py +++ b/FasterAPI/router.py @@ -101,7 +101,6 @@ def _walk( ) -> RadixNode | None: """Iterative-first tree walk with recursive fallback for param backtracking.""" n = len(segments) - # Fast iterative path for the common case (no backtracking needed) while idx < n: seg = segments[idx] child = node.children.get(seg) @@ -109,7 +108,6 @@ def _walk( node = child idx += 1 continue - # Try param child param_child = node.children.get("*") if param_child is not None: assert param_child.param_name is not None @@ -138,14 +136,20 @@ def _split(path: str) -> list[str]: class FasterRouter: - """API router for grouping routes with a common prefix and tags.""" + """API router for grouping routes with a common prefix, tags, and dependencies.""" - __slots__ = ("prefix", "tags", "routes") + __slots__ = ("prefix", "tags", "routes", "dependencies") - def __init__(self, prefix: str = "", tags: list[str] | None = None) -> None: + def __init__( + self, + prefix: str = "", + tags: list[str] | None = None, + dependencies: list[Any] | None = None, + ) -> None: self.prefix = prefix.rstrip("/") self.tags: list[str] = tags or [] self.routes: list[dict[str, Any]] = [] + self.dependencies: list[Any] = dependencies or [] def _add_route( self, @@ -158,6 +162,10 @@ def _add_route( response_model: Any, status_code: int, deprecated: bool, + responses: dict[int | str, dict[str, Any]] | None, + dependencies: list[Any] | None, + response_model_include: set[str] | None = None, + response_model_exclude: set[str] | None = None, ) -> None: full_path = self.prefix + path self.routes.append( @@ -168,8 +176,12 @@ def _add_route( "tags": self.tags + tags, "summary": summary, "response_model": response_model, + "response_model_include": response_model_include, + "response_model_exclude": response_model_exclude, "status_code": status_code, "deprecated": deprecated, + "responses": responses, + "dependencies": dependencies, } ) @@ -217,6 +229,10 @@ def _route_kw(kw: dict[str, Any]) -> dict[str, Any]: "tags": kw.get("tags") or [], "summary": kw.get("summary", ""), "response_model": kw.get("response_model"), + "response_model_include": kw.get("response_model_include"), + "response_model_exclude": kw.get("response_model_exclude"), "status_code": kw.get("status_code", 200), "deprecated": kw.get("deprecated", False), + "responses": kw.get("responses"), + "dependencies": kw.get("dependencies"), } diff --git a/FasterAPI/security.py b/FasterAPI/security.py new file mode 100644 index 0000000..d927e54 --- /dev/null +++ b/FasterAPI/security.py @@ -0,0 +1,249 @@ +"""Security utilities for FasterAPI — OAuth2, HTTP Basic, and API key authentication.""" + +from __future__ import annotations + +import base64 +import binascii +from typing import Any + +from .exceptions import HTTPException +from .request import Request + +__all__ = [ + "SecurityScopes", + "OAuth2PasswordBearer", + "OAuth2PasswordRequestForm", + "HTTPBasicCredentials", + "HTTPBasic", + "APIKeyHeader", + "APIKeyQuery", + "APIKeyCookie", +] + + +class SecurityScopes: + """Holds the list of OAuth2 security scopes required by a dependency tree.""" + + __slots__ = ("scopes", "scope_str") + + def __init__(self, scopes: list[str] | None = None) -> None: + self.scopes: list[str] = scopes or [] + self.scope_str: str = " ".join(self.scopes) + + def __repr__(self) -> str: + return f"SecurityScopes(scopes={self.scopes!r})" + + +class OAuth2PasswordBearer: + """Extracts a Bearer token from the Authorization header. + + Use as a dependency: + + oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/token") + + @app.get("/me") + async def me(token: str = Depends(oauth2_scheme)): + ... + """ + + __slots__ = ("tokenUrl", "scheme_name", "scopes", "auto_error") + + def __init__( + self, + tokenUrl: str, + *, + scheme_name: str | None = None, + scopes: dict[str, str] | None = None, + auto_error: bool = True, + ) -> None: + self.tokenUrl = tokenUrl + self.scheme_name = scheme_name or self.__class__.__name__ + self.scopes: dict[str, str] = scopes or {} + self.auto_error = auto_error + + async def __call__(self, request: Request) -> str | None: + authorization = request.headers.get("authorization", "") + if not authorization.startswith("Bearer "): + if self.auto_error: + raise HTTPException( + status_code=401, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) + return None + return authorization[7:] + + +class OAuth2PasswordRequestForm: + """Parses an OAuth2 password flow form submission. + + Use as a dependency: + + @app.post("/token") + async def login(form: OAuth2PasswordRequestForm = Depends()): + form.username, form.password, form.scopes + """ + + __slots__ = ("grant_type", "username", "password", "scopes", "client_id", "client_secret") + + def __init__( + self, + *, + grant_type: str | None = None, + username: str = "", + password: str = "", + scope: str = "", + client_id: str | None = None, + client_secret: str | None = None, + ) -> None: + self.grant_type = grant_type + self.username = username + self.password = password + self.scopes: list[str] = scope.split() if scope else [] + self.client_id = client_id + self.client_secret = client_secret + + @classmethod + async def from_request(cls, request: Request) -> OAuth2PasswordRequestForm: + """Parse form data from a request and return a populated instance.""" + form_data = await request.form() + return cls( + grant_type=str(form_data.get("grant_type")) if form_data.get("grant_type") is not None else None, + username=str(form_data.get("username", "")), + password=str(form_data.get("password", "")), + scope=str(form_data.get("scope", "")), + client_id=str(form_data.get("client_id")) if form_data.get("client_id") is not None else None, + client_secret=str(form_data.get("client_secret")) if form_data.get("client_secret") is not None else None, + ) + + +class HTTPBasicCredentials: + """Username and password extracted from an HTTP Basic Authorization header.""" + + __slots__ = ("username", "password") + + def __init__(self, username: str, password: str) -> None: + self.username = username + self.password = password + + def __repr__(self) -> str: + return f"HTTPBasicCredentials(username={self.username!r})" + + +class HTTPBasic: + """Extracts credentials from an HTTP Basic Authorization header. + + Use as a dependency: + + http_basic = HTTPBasic() + + @app.get("/protected") + async def protected(creds: HTTPBasicCredentials = Depends(http_basic)): + ... + """ + + __slots__ = ("scheme_name", "realm", "auto_error") + + def __init__( + self, + *, + scheme_name: str | None = None, + realm: str | None = None, + auto_error: bool = True, + ) -> None: + self.scheme_name = scheme_name or self.__class__.__name__ + self.realm = realm + self.auto_error = auto_error + + async def __call__(self, request: Request) -> HTTPBasicCredentials | None: + authorization = request.headers.get("authorization", "") + if not authorization.startswith("Basic "): + if self.auto_error: + www_auth = f'Basic realm="{self.realm}"' if self.realm else "Basic" + raise HTTPException( + status_code=401, + detail="Not authenticated", + headers={"WWW-Authenticate": www_auth}, + ) + return None + try: + decoded = base64.b64decode(authorization[6:]).decode("latin-1") + username, _, password = decoded.partition(":") + except (binascii.Error, UnicodeDecodeError) as exc: + if self.auto_error: + raise HTTPException(status_code=400, detail="Invalid authentication credentials") from exc + return None + return HTTPBasicCredentials(username=username, password=password) + + +class _APIKeyBase: + """Shared base for API key security schemes.""" + + __slots__ = ("name", "scheme_name", "auto_error") + + def __init__(self, name: str, *, scheme_name: str | None = None, auto_error: bool = True) -> None: + self.name = name + self.scheme_name = scheme_name or self.__class__.__name__ + self.auto_error = auto_error + + def _deny(self) -> None: + if self.auto_error: + raise HTTPException(status_code=403, detail="Not authenticated") + + async def __call__(self, request: Request) -> Any: + raise NotImplementedError + + +class APIKeyHeader(_APIKeyBase): + """API key extracted from an HTTP request header. + + api_key_header = APIKeyHeader(name="X-API-Key") + + @app.get("/secure") + async def secure(key: str = Depends(api_key_header)): + ... + """ + + async def __call__(self, request: Request) -> str | None: + key = request.headers.get(self.name.lower()) + if key is None: + self._deny() + return None + return key + + +class APIKeyQuery(_APIKeyBase): + """API key extracted from a query parameter. + + api_key_query = APIKeyQuery(name="api_key") + + @app.get("/secure") + async def secure(key: str = Depends(api_key_query)): + ... + """ + + async def __call__(self, request: Request) -> str | None: + raw = request.query_params.get(self.name) + if raw is None: + self._deny() + return None + key: str = str(raw) + return key + + +class APIKeyCookie(_APIKeyBase): + """API key extracted from a cookie. + + api_key_cookie = APIKeyCookie(name="session") + + @app.get("/secure") + async def secure(key: str = Depends(api_key_cookie)): + ... + """ + + async def __call__(self, request: Request) -> str | None: + key = request.cookies.get(self.name) + if key is None: + self._deny() + return None + return key diff --git a/FasterAPI/staticfiles.py b/FasterAPI/staticfiles.py new file mode 100644 index 0000000..a9a2ce7 --- /dev/null +++ b/FasterAPI/staticfiles.py @@ -0,0 +1,95 @@ +"""StaticFiles ASGI application for serving files from a directory.""" + +from __future__ import annotations + +import mimetypes +from pathlib import Path +from typing import Any + +__all__ = ["StaticFiles"] + +_CT_JSON = b"application/json" +_NOT_FOUND = b'{"detail":"Not Found"}' +_METHOD_NOT_ALLOWED = b'{"detail":"Method Not Allowed"}' + + +class StaticFiles: + """Serve static files from a local directory as an ASGI application. + + Usage:: + + app.mount("/static", StaticFiles(directory="static"), name="static") + + Requests to ``/static/logo.png`` will serve ``static/logo.png`` from disk. + Set ``html=True`` to serve ``index.html`` for directory requests. + """ + + def __init__(self, *, directory: str | Path, html: bool = False, check_dir: bool = True) -> None: + self.directory = Path(directory).resolve() + self.html = html + if check_dir and not self.directory.is_dir(): + raise RuntimeError(f"StaticFiles directory '{directory}' does not exist") + + async def __call__( + self, + scope: dict[str, Any], + receive: Any, + send: Any, + ) -> None: + if scope["type"] != "http": + return + if scope["method"] not in ("GET", "HEAD"): + await _send_error(send, 405, _METHOD_NOT_ALLOWED) + return + await self._handle(scope, send) + + async def _handle(self, scope: dict[str, Any], send: Any) -> None: + raw_path: str = scope.get("path", "/") + # Strip leading slash and normalize + rel = raw_path.lstrip("/") + file_path = (self.directory / rel).resolve() + + # Security: prevent path traversal + try: + file_path.relative_to(self.directory) + except ValueError: + await _send_error(send, 404, _NOT_FOUND) + return + + # Directory handling + if file_path.is_dir(): + if self.html: + file_path = file_path / "index.html" + else: + await _send_error(send, 404, _NOT_FOUND) + return + + if not file_path.is_file(): + await _send_error(send, 404, _NOT_FOUND) + return + + media_type, encoding = mimetypes.guess_type(str(file_path)) + if media_type is None: + media_type = "application/octet-stream" + + content = file_path.read_bytes() + headers: list[tuple[bytes, bytes]] = [ + (b"content-type", media_type.encode("latin-1")), + (b"content-length", str(len(content)).encode()), + ] + if encoding: + headers.append((b"content-encoding", encoding.encode("latin-1"))) + + await send({"type": "http.response.start", "status": 200, "headers": headers}) + await send({"type": "http.response.body", "body": content}) + + +async def _send_error(send: Any, status: int, body: bytes) -> None: + await send( + { + "type": "http.response.start", + "status": status, + "headers": [(b"content-type", _CT_JSON)], + } + ) + await send({"type": "http.response.body", "body": body}) diff --git a/FasterAPI/templating.py b/FasterAPI/templating.py new file mode 100644 index 0000000..fa179d4 --- /dev/null +++ b/FasterAPI/templating.py @@ -0,0 +1,55 @@ +"""Jinja2 template rendering for FasterAPI.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from .request import Request +from .response import HTMLResponse, Response + +__all__ = ["Jinja2Templates"] + + +class Jinja2Templates: + """Render Jinja2 templates as HTML responses. + + Usage:: + + templates = Jinja2Templates(directory="templates") + + @app.get("/hello/{name}") + async def hello(request: Request, name: str): + return templates.TemplateResponse(request, "hello.html", {"name": name}) + + Requires ``jinja2`` to be installed: ``pip install jinja2``. + """ + + def __init__(self, directory: str | Path) -> None: + try: + import jinja2 + except ImportError as exc: + raise ImportError("Jinja2Templates requires jinja2. Install with: pip install jinja2") from exc + + self.env = jinja2.Environment( + loader=jinja2.FileSystemLoader(str(directory)), + autoescape=jinja2.select_autoescape(["html", "xml"]), + ) + + def get_template(self, name: str) -> Any: + return self.env.get_template(name) + + def TemplateResponse( # noqa: N802 + self, + request: Request, + name: str, + context: dict[str, Any] | None = None, + status_code: int = 200, + headers: dict[str, str] | None = None, + media_type: str = "text/html", + ) -> Response: + ctx = dict(context) if context else {} + ctx.setdefault("request", request) + template = self.get_template(name) + content = template.render(ctx) + return HTMLResponse(content=content, status_code=status_code, headers=headers) diff --git a/FasterAPI/testclient.py b/FasterAPI/testclient.py index 7c61278..0a4b213 100644 --- a/FasterAPI/testclient.py +++ b/FasterAPI/testclient.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import threading from collections.abc import Generator from contextlib import contextmanager from typing import Any @@ -93,11 +94,50 @@ def __init__( ) def __enter__(self) -> TestClient: + self._lifespan_startup_done = threading.Event() + self._lifespan_shutdown_trigger = threading.Event() + self._lifespan_thread = threading.Thread(target=self._run_lifespan_thread, daemon=True) + self._lifespan_thread.start() + self._lifespan_startup_done.wait(timeout=5.0) return self def __exit__(self, *args: Any) -> None: + if hasattr(self, "_lifespan_shutdown_trigger"): + self._lifespan_shutdown_trigger.set() + self._lifespan_thread.join(timeout=5.0) self._run(self._client.aclose()) + def _run_lifespan_thread(self) -> None: + """Run the app's lifespan protocol in a dedicated thread.""" + startup_done = self._lifespan_startup_done + shutdown_trigger = self._lifespan_shutdown_trigger + + async def _run() -> None: + messages: list[dict[str, Any]] = [ + {"type": "lifespan.startup"}, + ] + idx = [0] + + async def receive() -> dict[str, Any]: + if idx[0] < len(messages): + msg = messages[idx[0]] + idx[0] += 1 + return msg + # Wait for shutdown trigger + await asyncio.get_event_loop().run_in_executor(None, shutdown_trigger.wait) + return {"type": "lifespan.shutdown"} + + async def send(msg: dict[str, Any]) -> None: + if msg.get("type") == "lifespan.startup.complete": + startup_done.set() + + try: + await self.app({"type": "lifespan"}, receive, send) + except Exception: + startup_done.set() + + asyncio.run(_run()) + def _run(self, coro: Any) -> Any: try: loop = asyncio.get_running_loop() diff --git a/pyproject.toml b/pyproject.toml index 8fedbca..c618460 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ dev = [ "mypy>=1.10.0", "ruff>=0.8.0", "tox>=4.0.0", + "jinja2>=3.0.0", ] benchmark = [ "httpx>=0.27.0", diff --git a/tests/test_coverage_gaps.py b/tests/test_coverage_gaps.py new file mode 100644 index 0000000..aee0287 --- /dev/null +++ b/tests/test_coverage_gaps.py @@ -0,0 +1,390 @@ +"""Tests to cover staticfiles.py, templating.py, and security edge cases.""" + +from __future__ import annotations + +import tempfile +from pathlib import Path + +import pytest +from FasterAPI import ( + APIKeyCookie, + APIKeyHeader, + APIKeyQuery, + Depends, + Faster, + HTTPBasic, + HTTPBasicCredentials, + OAuth2PasswordBearer, + OAuth2PasswordRequestForm, + Request, + SecurityScopes, + StaticFiles, +) +from FasterAPI.testclient import TestClient + +# --------------------------------------------------------------------------- +# StaticFiles +# --------------------------------------------------------------------------- + + +def _make_static_dir() -> tempfile.TemporaryDirectory: # type: ignore[type-arg] + td = tempfile.TemporaryDirectory() + root = Path(td.name) + (root / "hello.txt").write_text("Hello, world!") + (root / "style.css").write_text("body { color: red; }") + sub = root / "sub" + sub.mkdir() + (sub / "index.html").write_text("

Sub Index

") + (root / "index.html").write_text("

Root Index

") + return td + + +def test_staticfiles_serve_text_file(): + with _make_static_dir() as td: + app = Faster() + app.mount("/static", StaticFiles(directory=td)) + + client = TestClient(app) + resp = client.get("/static/hello.txt") + assert resp.status_code == 200 + assert resp.text == "Hello, world!" + assert "text/plain" in resp.headers["content-type"] + + +def test_staticfiles_serve_css_file(): + with _make_static_dir() as td: + app = Faster() + app.mount("/static", StaticFiles(directory=td)) + + client = TestClient(app) + resp = client.get("/static/style.css") + assert resp.status_code == 200 + assert "text/css" in resp.headers["content-type"] + + +def test_staticfiles_not_found(): + with _make_static_dir() as td: + app = Faster() + app.mount("/static", StaticFiles(directory=td)) + + client = TestClient(app) + resp = client.get("/static/missing.txt") + assert resp.status_code == 404 + + +def test_staticfiles_method_not_allowed(): + with _make_static_dir() as td: + app = Faster() + app.mount("/static", StaticFiles(directory=td)) + + client = TestClient(app) + resp = client.post("/static/hello.txt") + assert resp.status_code == 405 + + +def test_staticfiles_directory_no_html(): + with _make_static_dir() as td: + app = Faster() + app.mount("/static", StaticFiles(directory=td, html=False)) + + client = TestClient(app) + resp = client.get("/static/sub/") + assert resp.status_code == 404 + + +def test_staticfiles_directory_with_html(): + with _make_static_dir() as td: + app = Faster() + app.mount("/static", StaticFiles(directory=td, html=True)) + + client = TestClient(app) + resp = client.get("/static/sub/") + assert resp.status_code == 200 + assert "Sub Index" in resp.text + + +def test_staticfiles_root_index_html(): + with _make_static_dir() as td: + app = Faster() + app.mount("/static", StaticFiles(directory=td, html=True)) + + client = TestClient(app) + resp = client.get("/static/") + assert resp.status_code == 200 + assert "Root Index" in resp.text + + +def test_staticfiles_path_traversal_blocked(): + with _make_static_dir() as td: + app = Faster() + app.mount("/static", StaticFiles(directory=td)) + + client = TestClient(app) + resp = client.get("/static/../../../etc/passwd") + assert resp.status_code == 404 + + +def test_staticfiles_missing_directory_raises(): + with pytest.raises(RuntimeError, match="does not exist"): + StaticFiles(directory="/nonexistent/path/that/does/not/exist") + + +def test_staticfiles_check_dir_false(): + # Should not raise even if directory doesn't exist when check_dir=False + sf = StaticFiles(directory="/nonexistent", check_dir=False) + assert sf.directory == Path("/nonexistent") + + +def test_staticfiles_head_method(): + with _make_static_dir() as td: + app = Faster() + app.mount("/static", StaticFiles(directory=td)) + + client = TestClient(app) + resp = client.head("/static/hello.txt") + assert resp.status_code == 200 + + +# --------------------------------------------------------------------------- +# Jinja2Templates +# --------------------------------------------------------------------------- + + +def _make_template_dir() -> tempfile.TemporaryDirectory: # type: ignore[type-arg] + td = tempfile.TemporaryDirectory() + (Path(td.name) / "hello.html").write_text("

Hello {{ name }}!

") + (Path(td.name) / "simple.html").write_text("

Simple

") + return td + + +def test_jinja2_template_response(): + try: + import jinja2 # noqa: F401 + except ImportError: + pytest.skip("jinja2 not installed") + + from FasterAPI import Jinja2Templates + + with _make_template_dir() as td: + templates = Jinja2Templates(directory=td) + app = Faster() + + @app.get("/hello/{name}") + async def hello(request: Request, name: str): + return templates.TemplateResponse(request, "hello.html", {"name": name}) + + client = TestClient(app) + resp = client.get("/hello/World") + assert resp.status_code == 200 + assert "

Hello World!

" in resp.text + + +def test_jinja2_template_response_status_code(): + try: + import jinja2 # noqa: F401 + except ImportError: + pytest.skip("jinja2 not installed") + + from FasterAPI import Jinja2Templates + + with _make_template_dir() as td: + templates = Jinja2Templates(directory=td) + app = Faster() + + @app.get("/simple") + async def simple(request: Request): + return templates.TemplateResponse(request, "simple.html", status_code=201) + + client = TestClient(app) + resp = client.get("/simple") + assert resp.status_code == 201 + assert "

Simple

" in resp.text + + +def test_jinja2_get_template(): + try: + import jinja2 # noqa: F401 + except ImportError: + pytest.skip("jinja2 not installed") + + from FasterAPI import Jinja2Templates + + with _make_template_dir() as td: + templates = Jinja2Templates(directory=td) + tmpl = templates.get_template("hello.html") + rendered = tmpl.render({"name": "Test"}) + assert "Hello Test" in rendered + + +def test_jinja2_missing_import(): + import sys + + jinja2_backup = sys.modules.pop("jinja2", None) + try: + from FasterAPI.templating import Jinja2Templates as J2T + + with _make_template_dir() as td: + t = J2T(directory=td) + assert t.env is not None + finally: + if jinja2_backup is not None: + sys.modules["jinja2"] = jinja2_backup + + +# --------------------------------------------------------------------------- +# Security edge cases (missing coverage) +# --------------------------------------------------------------------------- + + +def test_security_scopes_repr(): + scopes = SecurityScopes(["read", "write"]) + assert "read" in repr(scopes) + + +def test_oauth2_form_instantiation_with_scope(): + form = OAuth2PasswordRequestForm( + grant_type="password", + username="user", + password="pass", + scope="read write", + client_id="client", + client_secret="secret", + ) + assert form.username == "user" + assert form.scopes == ["read", "write"] + assert form.client_id == "client" + + +def test_oauth2_form_empty_scope(): + form = OAuth2PasswordRequestForm(username="u", password="p") + assert form.scopes == [] + + +def test_oauth2_form_from_request(): + app = Faster() + + @app.post("/token") + async def token(form: OAuth2PasswordRequestForm = Depends(OAuth2PasswordRequestForm)): + return {"username": form.username, "scopes": form.scopes} + + client = TestClient(app) + resp = client.post( + "/token", + data={"username": "alice", "password": "secret", "scope": "read write"}, + ) + assert resp.status_code == 200 + assert resp.json()["username"] == "alice" + assert "read" in resp.json()["scopes"] + + +def test_http_basic_credentials_repr(): + creds = HTTPBasicCredentials(username="alice", password="secret") + assert "alice" in repr(creds) + + +def test_http_basic_invalid_base64(): + http_basic = HTTPBasic() + app = Faster() + + @app.get("/protected") + async def protected(creds: HTTPBasicCredentials = Depends(http_basic)): + return {"username": creds.username} + + client = TestClient(app) + resp = client.get("/protected", headers={"Authorization": "Basic not-valid-base64!!!"}) + assert resp.status_code in (400, 401) + + +def test_api_key_header_missing_auto_error_false(): + api_key = APIKeyHeader(name="X-API-Key", auto_error=False) + app = Faster() + + @app.get("/secure") + async def secure(key: str | None = Depends(api_key)): + return {"key": key} + + client = TestClient(app) + resp = client.get("/secure") + assert resp.status_code == 200 + assert resp.json()["key"] is None + + +def test_api_key_query_missing_auto_error_false(): + api_key = APIKeyQuery(name="api_key", auto_error=False) + app = Faster() + + @app.get("/secure") + async def secure(key: str | None = Depends(api_key)): + return {"key": key} + + client = TestClient(app) + resp = client.get("/secure") + assert resp.status_code == 200 + assert resp.json()["key"] is None + + +def test_api_key_cookie_missing_auto_error_false(): + api_key = APIKeyCookie(name="session", auto_error=False) + app = Faster() + + @app.get("/secure") + async def secure(key: str | None = Depends(api_key)): + return {"key": key} + + client = TestClient(app) + resp = client.get("/secure") + assert resp.status_code == 200 + assert resp.json()["key"] is None + + +def test_api_key_header_missing_auto_error_true(): + api_key = APIKeyHeader(name="X-API-Key", auto_error=True) + app = Faster() + + @app.get("/secure") + async def secure(key: str = Depends(api_key)): + return {"key": key} + + client = TestClient(app) + resp = client.get("/secure") + assert resp.status_code == 403 + + +def test_api_key_query_missing_auto_error_true(): + api_key = APIKeyQuery(name="api_key", auto_error=True) + app = Faster() + + @app.get("/secure") + async def secure(key: str = Depends(api_key)): + return {"key": key} + + client = TestClient(app) + resp = client.get("/secure") + assert resp.status_code == 403 + + +def test_api_key_cookie_missing_auto_error_true(): + api_key = APIKeyCookie(name="session", auto_error=True) + app = Faster() + + @app.get("/secure") + async def secure(key: str = Depends(api_key)): + return {"key": key} + + client = TestClient(app) + resp = client.get("/secure") + assert resp.status_code == 403 + + +def test_oauth2_bearer_no_auto_error_returns_none(): + oauth2 = OAuth2PasswordBearer(tokenUrl="/token", auto_error=False) + app = Faster() + + @app.get("/optional") + async def optional(token: str | None = Depends(oauth2)): + return {"authenticated": token is not None} + + client = TestClient(app) + resp = client.get("/optional") + assert resp.status_code == 200 + assert resp.json()["authenticated"] is False diff --git a/tests/test_new_features.py b/tests/test_new_features.py new file mode 100644 index 0000000..fd6c637 --- /dev/null +++ b/tests/test_new_features.py @@ -0,0 +1,610 @@ +"""Tests for the newly added features: +- Security utilities (OAuth2, HTTPBasic, APIKey*) +- Lifespan context manager +- response_model / response_model_include / response_model_exclude +- Annotated[T, Depends(...)] — PEP 593 style +- Sub-application mounting (mount()) +- Server-Sent Events (EventSourceResponse) +- ORJSONResponse / UJSONResponse aliases +- datetime / UUID / Decimal serialization +- Multiple response declarations (responses={...}) +- APIRouter dependencies +- Dataclass support +- Enum path parameters in OpenAPI +- openapi_tags / terms_of_service / contact / license_info +""" + +from __future__ import annotations + +import dataclasses +import datetime +import decimal +import enum +import uuid +from contextlib import asynccontextmanager +from typing import Annotated + +import msgspec +from FasterAPI import ( + APIKeyCookie, + APIKeyHeader, + APIKeyQuery, + Depends, + EventSourceResponse, + Faster, + FasterRouter, + HTTPBasic, + HTTPBasicCredentials, + JSONResponse, + OAuth2PasswordBearer, + ORJSONResponse, + Request, + SecurityScopes, + UJSONResponse, +) +from FasterAPI.testclient import TestClient + +# --------------------------------------------------------------------------- +# Security — OAuth2PasswordBearer +# --------------------------------------------------------------------------- + + +def test_oauth2_password_bearer_valid(): + oauth2 = OAuth2PasswordBearer(tokenUrl="/token") + app = Faster() + + @app.get("/me") + async def me(token: str = Depends(oauth2)): + return {"token": token} + + client = TestClient(app) + resp = client.get("/me", headers={"Authorization": "Bearer mytoken123"}) + assert resp.status_code == 200 + assert resp.json() == {"token": "mytoken123"} + + +def test_oauth2_password_bearer_missing(): + oauth2 = OAuth2PasswordBearer(tokenUrl="/token") + app = Faster() + + @app.get("/me") + async def me(token: str = Depends(oauth2)): + return {"token": token} + + client = TestClient(app) + resp = client.get("/me") + assert resp.status_code == 401 + + +def test_oauth2_password_bearer_no_auto_error(): + oauth2 = OAuth2PasswordBearer(tokenUrl="/token", auto_error=False) + app = Faster() + + @app.get("/me") + async def me(token: str | None = Depends(oauth2)): + return {"token": token} + + client = TestClient(app) + resp = client.get("/me") + assert resp.status_code == 200 + assert resp.json() == {"token": None} + + +# --------------------------------------------------------------------------- +# Security — HTTPBasic +# --------------------------------------------------------------------------- + + +def test_http_basic_valid(): + import base64 + + http_basic = HTTPBasic() + app = Faster() + + @app.get("/protected") + async def protected(creds: HTTPBasicCredentials = Depends(http_basic)): + return {"username": creds.username} + + client = TestClient(app) + encoded = base64.b64encode(b"alice:secret").decode() + resp = client.get("/protected", headers={"Authorization": f"Basic {encoded}"}) + assert resp.status_code == 200 + assert resp.json() == {"username": "alice"} + + +def test_http_basic_missing(): + http_basic = HTTPBasic() + app = Faster() + + @app.get("/protected") + async def protected(creds: HTTPBasicCredentials = Depends(http_basic)): + return {"username": creds.username} + + client = TestClient(app) + resp = client.get("/protected") + assert resp.status_code == 401 + + +# --------------------------------------------------------------------------- +# Security — API Key variants +# --------------------------------------------------------------------------- + + +def test_api_key_header(): + api_key = APIKeyHeader(name="X-API-Key") + app = Faster() + + @app.get("/secure") + async def secure(key: str = Depends(api_key)): + return {"key": key} + + client = TestClient(app) + resp = client.get("/secure", headers={"X-API-Key": "secret123"}) + assert resp.status_code == 200 + assert resp.json() == {"key": "secret123"} + + +def test_api_key_query(): + api_key = APIKeyQuery(name="api_key") + app = Faster() + + @app.get("/secure") + async def secure(key: str = Depends(api_key)): + return {"key": key} + + client = TestClient(app) + resp = client.get("/secure?api_key=mykey") + assert resp.status_code == 200 + assert resp.json() == {"key": "mykey"} + + +def test_api_key_cookie(): + api_key = APIKeyCookie(name="session") + app = Faster() + + @app.get("/secure") + async def secure(key: str = Depends(api_key)): + return {"key": key} + + client = TestClient(app) + resp = client.get("/secure", cookies={"session": "cookietoken"}) + assert resp.status_code == 200 + assert resp.json() == {"key": "cookietoken"} + + +# --------------------------------------------------------------------------- +# SecurityScopes +# --------------------------------------------------------------------------- + + +def test_security_scopes(): + scopes = SecurityScopes(["read:users", "write:users"]) + assert scopes.scopes == ["read:users", "write:users"] + assert scopes.scope_str == "read:users write:users" + + +# --------------------------------------------------------------------------- +# Lifespan context manager +# --------------------------------------------------------------------------- + + +def test_lifespan_context_manager(): + state: list[str] = [] + + @asynccontextmanager + async def lifespan(app: Faster): + state.append("startup") + yield + state.append("shutdown") + + app = Faster(lifespan=lifespan) + + @app.get("/") + async def root(): + return {"state": state} + + with TestClient(app): + pass + + assert state == ["startup", "shutdown"] + + +def test_lifespan_and_route(): + db: dict[str, str] = {} + + @asynccontextmanager + async def lifespan(app: Faster): + db["initialized"] = "true" + yield + db.clear() + + app = Faster(lifespan=lifespan) + + @app.get("/db") + async def check(): + return {"initialized": db.get("initialized")} + + with TestClient(app) as client: + resp = client.get("/db") + assert resp.json() == {"initialized": "true"} + + +# --------------------------------------------------------------------------- +# response_model filtering +# --------------------------------------------------------------------------- + + +class UserFull(msgspec.Struct): + id: int + name: str + password: str + + +class UserPublic(msgspec.Struct): + id: int + name: str + + +def test_response_model_filters_fields(): + app = Faster() + + @app.get("/user", response_model=UserPublic) + async def get_user(): + return UserFull(id=1, name="Alice", password="secret") + + client = TestClient(app) + resp = client.get("/user") + assert resp.status_code == 200 + data = resp.json() + assert "id" in data + assert "name" in data + assert "password" not in data + + +def test_response_model_include(): + app = Faster() + + @app.get("/user", response_model=UserFull, response_model_include={"id", "name"}) + async def get_user(): + return UserFull(id=1, name="Alice", password="secret") + + client = TestClient(app) + resp = client.get("/user") + data = resp.json() + assert "password" not in data + assert data["name"] == "Alice" + + +def test_response_model_exclude(): + app = Faster() + + @app.get("/user", response_model=UserFull, response_model_exclude={"password"}) + async def get_user(): + return UserFull(id=1, name="Alice", password="secret") + + client = TestClient(app) + resp = client.get("/user") + data = resp.json() + assert "password" not in data + assert data["id"] == 1 + + +# --------------------------------------------------------------------------- +# Annotated PEP 593 style dependencies +# --------------------------------------------------------------------------- + + +def get_token_dep(request: Request) -> str: + return request.headers.get("x-token", "none") + + +def test_annotated_depends(): + app = Faster() + + @app.get("/annotated") + async def handler(token: Annotated[str, Depends(get_token_dep)]): + return {"token": token} + + client = TestClient(app) + resp = client.get("/annotated", headers={"X-Token": "hello"}) + assert resp.status_code == 200 + assert resp.json() == {"token": "hello"} + + +# --------------------------------------------------------------------------- +# Sub-application mounting +# --------------------------------------------------------------------------- + + +def test_mount_sub_app(): + sub = Faster() + + @sub.get("/hello") + async def sub_hello(): + return {"from": "sub"} + + app = Faster() + app.mount("/sub", sub) + + client = TestClient(app) + resp = client.get("/sub/hello") + assert resp.status_code == 200 + assert resp.json() == {"from": "sub"} + + +# --------------------------------------------------------------------------- +# Server-Sent Events +# --------------------------------------------------------------------------- + + +def test_event_source_response_sync(): + def generator(): + yield {"data": "hello"} + yield {"event": "update", "data": "world", "id": "1"} + + app = Faster() + + @app.get("/stream") + async def stream(): + return EventSourceResponse(generator()) + + client = TestClient(app) + resp = client.get("/stream") + assert resp.status_code == 200 + assert "text/event-stream" in resp.headers["content-type"] + body = resp.text + assert "data: hello" in body + assert "event: update" in body + + +def test_event_source_response_format(): + from FasterAPI.response import EventSourceResponse as ESR + + sse = ESR.__new__(ESR) + chunk = sse._format_event({"event": "msg", "data": "hi", "id": "42"}) + text = chunk.decode() + assert "event: msg" in text + assert "data: hi" in text + assert "id: 42" in text + + +# --------------------------------------------------------------------------- +# ORJSONResponse / UJSONResponse aliases +# --------------------------------------------------------------------------- + + +def test_orjson_response_alias(): + assert ORJSONResponse is JSONResponse + + +def test_ujson_response_alias(): + assert UJSONResponse is JSONResponse + + +# --------------------------------------------------------------------------- +# datetime / UUID / Decimal serialization +# --------------------------------------------------------------------------- + + +def test_datetime_serialization(): + app = Faster() + now = datetime.datetime(2024, 1, 15, 10, 30, 0) + + @app.get("/dt") + async def get_dt(): + return {"dt": now} + + client = TestClient(app) + resp = client.get("/dt") + assert resp.status_code == 200 + assert "2024-01-15" in resp.json()["dt"] + + +def test_uuid_serialization(): + app = Faster() + uid = uuid.UUID("12345678-1234-5678-1234-567812345678") + + @app.get("/uid") + async def get_uid(): + return {"uid": uid} + + client = TestClient(app) + resp = client.get("/uid") + assert resp.status_code == 200 + assert resp.json()["uid"] == str(uid) + + +def test_decimal_serialization(): + app = Faster() + + @app.get("/dec") + async def get_dec(): + return {"value": decimal.Decimal("3.14159")} + + client = TestClient(app) + resp = client.get("/dec") + assert resp.status_code == 200 + assert resp.json()["value"] == "3.14159" + + +# --------------------------------------------------------------------------- +# Multiple response declarations (responses={...}) +# --------------------------------------------------------------------------- + + +def test_multiple_responses_in_openapi(): + app = Faster() + + @app.get( + "/items/{id}", + responses={404: {"description": "Item not found"}, 403: {"description": "Forbidden"}}, + ) + async def get_item(id: int): + return {"id": id} + + from FasterAPI.openapi.generator import generate_openapi + + spec = generate_openapi(app, title="Test", version="0.1") + path = spec["paths"]["/items/{id}"]["get"] + assert "404" in path["responses"] + assert path["responses"]["404"]["description"] == "Item not found" + assert "403" in path["responses"] + + +# --------------------------------------------------------------------------- +# Router-level dependencies +# --------------------------------------------------------------------------- + + +def test_router_level_dependencies(): + called: list[str] = [] + + async def router_dep(): + called.append("router_dep") + + router = FasterRouter(prefix="/api", dependencies=[Depends(router_dep)]) + + @router.get("/hello") + async def hello(): + return {"msg": "hi"} + + app = Faster() + app.include_router(router) + + client = TestClient(app) + resp = client.get("/api/hello") + assert resp.status_code == 200 + assert "router_dep" in called + + +def test_include_router_with_dependencies(): + called: list[str] = [] + + async def extra_dep(): + called.append("extra") + + router = FasterRouter(prefix="/r") + + @router.get("/x") + async def x(): + return {"x": 1} + + app = Faster() + app.include_router(router, dependencies=[Depends(extra_dep)]) + + client = TestClient(app) + resp = client.get("/r/x") + assert resp.status_code == 200 + assert "extra" in called + + +# --------------------------------------------------------------------------- +# Dataclass support +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class Item: + name: str + price: float + in_stock: bool = True + + +def test_dataclass_request_body(): + app = Faster() + + @app.post("/items") + async def create_item(item: Item): + return {"name": item.name, "price": item.price} + + client = TestClient(app) + resp = client.post("/items", json={"name": "Widget", "price": 9.99}) + assert resp.status_code == 200 + assert resp.json() == {"name": "Widget", "price": 9.99} + + +def test_dataclass_openapi_schema(): + from FasterAPI.openapi.generator import generate_openapi + + app = Faster() + + @app.post("/items") + async def create_item(item: Item): + return item + + spec = generate_openapi(app, title="T", version="0") + assert "Item" in spec.get("components", {}).get("schemas", {}) + + +# --------------------------------------------------------------------------- +# Enum path parameters in OpenAPI +# --------------------------------------------------------------------------- + + +class Color(enum.Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + + +def test_enum_path_param_openapi(): + from FasterAPI.openapi.generator import generate_openapi + + app = Faster() + + @app.get("/items/{color}") + async def get_by_color(color: Color): + return {"color": color.value} + + spec = generate_openapi(app, title="T", version="0") + params = spec["paths"]["/items/{color}"]["get"]["parameters"] + color_param = next(p for p in params if p["name"] == "color") + assert "enum" in color_param["schema"] + assert "red" in color_param["schema"]["enum"] + + +# --------------------------------------------------------------------------- +# openapi_tags / terms_of_service / contact / license_info +# --------------------------------------------------------------------------- + + +def test_openapi_metadata(): + from FasterAPI.openapi.generator import generate_openapi + + app = Faster( + openapi_tags=[{"name": "users", "description": "User operations"}], + terms_of_service="https://example.com/tos", + contact={"name": "Support", "email": "support@example.com"}, + license_info={"name": "MIT", "url": "https://opensource.org/licenses/MIT"}, + ) + + @app.get("/users", tags=["users"]) + async def list_users(): + return [] + + spec = generate_openapi( + app, + title="MyAPI", + version="1.0", + openapi_tags=app.openapi_tags, + terms_of_service=app.terms_of_service, + contact=app.contact, + license_info=app.license_info, + ) + + assert spec["info"]["termsOfService"] == "https://example.com/tos" + assert spec["info"]["contact"]["name"] == "Support" + assert spec["info"]["license"]["name"] == "MIT" + assert any(t["name"] == "users" for t in spec.get("tags", [])) + + +def test_openapi_tags_via_app_route(): + app = Faster() + + @app.get("/ping") + async def ping(): + return "pong" + + client = TestClient(app) + resp = client.get("/openapi.json") + assert resp.status_code == 200