diff --git a/python/turboapi/openapi.py b/python/turboapi/openapi.py index 0fb39a0..bfdecfe 100644 --- a/python/turboapi/openapi.py +++ b/python/turboapi/openapi.py @@ -4,8 +4,66 @@ interactive API documentation at /docs (Swagger UI) and /redoc (ReDoc). """ +import copy import inspect -from typing import Any, Union, get_args, get_origin +import json +import re +import types +from typing import Annotated, Any, Union, get_args, get_origin + +from .datastructures import ( + Body, + Cookie, + File, + Form, + Header, + Query, + UploadFile, +) +from .datastructures import ( + Path as PathMarker, +) +from .security import Depends, SecurityBase, get_depends + +_BODY_METHODS = {"POST", "PUT", "PATCH", "DELETE"} +_PARAM_MARKER_TYPES = (Body, Cookie, File, Form, Header, PathMarker, Query) + +_VALIDATION_ERROR_SCHEMAS = { + "ValidationError": { + "title": "ValidationError", + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + "required": ["loc", "msg", "type"], + }, + "HTTPValidationError": { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + } + }, + }, +} + + +class _SchemaContext: + """Tracks component names generated for a single OpenAPI schema.""" + + def __init__(self, components: dict[str, Any]): + self.components = components + self.model_component_names: dict[type, str] = {} + self.component_models: dict[str, type] = {} def generate_openapi_schema(app) -> dict: @@ -25,8 +83,10 @@ def generate_openapi_schema(app) -> dict: "description": getattr(app, "description", ""), }, "paths": {}, - "components": {"schemas": {}}, + "components": {"schemas": copy.deepcopy(_VALIDATION_ERROR_SCHEMAS)}, } + components = schema["components"]["schemas"] + schema_context = _SchemaContext(components) routes = app.registry.get_routes() for route in routes: @@ -35,7 +95,7 @@ def generate_openapi_schema(app) -> dict: handler = route.handler # Generate operation - operation = _generate_operation(handler, route) + operation = _generate_operation(handler, route, schema_context) # Add to paths openapi_path = _convert_path(path) @@ -51,15 +111,20 @@ def _convert_path(path: str) -> str: return path -def _generate_operation(handler, route) -> dict: +def _generate_operation(handler, route, schema_context: _SchemaContext) -> dict: """Generate OpenAPI operation object from handler.""" + response_schema = {} + response_model = getattr(route, "response_model", None) + if response_model is not None: + response_schema = _type_to_schema(response_model, schema_context) + operation: dict[str, Any] = { "summary": _get_summary(handler), "operationId": f"{route.method.value.lower()}_{handler.__name__}", "responses": { "200": { "description": "Successful Response", - "content": {"application/json": {"schema": {}}}, + "content": {"application/json": {"schema": response_schema}}, }, "422": { "description": "Validation Error", @@ -72,63 +137,134 @@ def _generate_operation(handler, route) -> dict: }, } - # Extract parameters from signature sig = inspect.signature(handler) parameters = [] - request_body_props = {} - - import re - + body_entries: list[dict[str, Any]] = [] path_params = set(re.findall(r"\{([^}]+)\}", route.path)) + method = route.method.value.upper() for param_name, param in sig.parameters.items(): - annotation = param.annotation - param_schema = _type_to_schema(annotation) + if _is_dependency_parameter(param): + continue + + annotation, marker = _resolve_annotation_and_marker(param) + required = _is_required_param(param, marker) + body_required = _is_required_body_param(param, marker) + + if method in _BODY_METHODS and _is_unsupported_annotated_model_body( + param.annotation, marker + ): + body_entries.append( + { + "name": param_name, + "schema": {}, + "required": required, + "media_type": "application/json", + "direct": False, + } + ) + continue + + if isinstance(marker, Cookie): + # Runtime request handling does not bind Cookie() route parameters yet. + # Avoid advertising cookie params or registering component schemas for + # them until parsing support exists. + continue + + if isinstance(marker, Query): + # Runtime query parsing binds by Python parameter name and does not + # consume Query.alias, Query.default, or validation metadata. + param_schema = _schema_for_param( + annotation, None, param, schema_context, include_default=False + ) + elif isinstance(marker, Form): + # Runtime form parsing passes field values through as raw strings. + param_schema = _schema_for_param(str, marker, param, schema_context) + elif isinstance(marker, Body): + # Runtime body parsing does not unwrap Body.default values. + param_schema = _schema_for_param( + annotation, marker, param, schema_context, include_default=False + ) + else: + param_schema = _schema_for_param(annotation, marker, param, schema_context) - if param_name in path_params: + if param_name in path_params or isinstance(marker, PathMarker): + parameters.append(_build_parameter(param_name, "path", True, param_schema, marker)) + elif isinstance(marker, Query): parameters.append( + _build_parameter( + param_name, + "query", + True, + param_schema, + None, + ) + ) + elif isinstance(marker, Header): + parameters.append( + _build_parameter( + _parameter_alias(param_name, marker, location="header"), + "header", + required, + param_schema, + marker, + ) + ) + elif _is_form_or_file_param(annotation, marker): + media_type = getattr(marker, "media_type", None) or "multipart/form-data" + if _is_file_param(annotation, marker): + media_type = "multipart/form-data" + body_entries.append( + { + "name": _parameter_alias(param_name, marker), + "schema": param_schema, + "required": required, + "media_type": media_type, + "direct": False, + } + ) + elif isinstance(marker, Body): + body_entries.append( + { + "name": _parameter_alias(param_name, marker), + "schema": param_schema, + "required": body_required, + "media_type": marker.media_type, + # RequestBodyParser currently validates single model parameters + # against the whole JSON body and does not inspect Body.embed. + "direct": _is_model_class(annotation), + } + ) + elif method in _BODY_METHODS and _is_model_class(annotation): + body_entries.append( { "name": param_name, - "in": "path", - "required": True, "schema": param_schema, + "required": required, + "media_type": "application/json", + "direct": True, + } + ) + elif method in _BODY_METHODS: + body_entries.append( + { + "name": param_name, + "schema": param_schema, + "required": required, + "media_type": "application/json", + "direct": False, } ) - elif route.method.value.upper() in ("POST", "PUT", "PATCH"): - # Body parameter - request_body_props[param_name] = param_schema - if param.default is not inspect.Parameter.empty: - request_body_props[param_name]["default"] = param.default else: - # Query parameter - query_param = { - "name": param_name, - "in": "query", - "schema": param_schema, - } - if param.default is inspect.Parameter.empty: - query_param["required"] = True - else: - query_param["required"] = False - if param.default is not None: - query_param["schema"]["default"] = param.default - parameters.append(query_param) + parameters.append( + _build_parameter(param_name, "query", required, param_schema, marker) + ) if parameters: operation["parameters"] = parameters - if request_body_props: - operation["requestBody"] = { - "required": True, - "content": { - "application/json": { - "schema": { - "type": "object", - "properties": request_body_props, - } - } - }, - } + if body_entries: + operation["requestBody"] = _build_request_body(body_entries) # Add tags if hasattr(route, "tags") and route.tags: @@ -147,10 +283,16 @@ def _get_summary(handler) -> str: return name.replace("_", " ").title() -def _type_to_schema(annotation) -> dict: +def _type_to_schema(annotation, schema_context: _SchemaContext | None = None) -> dict: """Convert Python type annotation to OpenAPI schema.""" + annotation, _metadata = _unwrap_annotated(annotation) + if annotation is inspect.Parameter.empty or annotation is Any: return {} + if annotation is type(None): + return {"type": "null"} + if _is_model_class(annotation): + return _register_model_schema(annotation, schema_context) if annotation is str: return {"type": "string"} if annotation is int: @@ -165,37 +307,413 @@ def _type_to_schema(annotation) -> dict: return {"type": "object"} if annotation is bytes: return {"type": "string", "format": "binary"} + if _is_upload_file_type(annotation): + return {"type": "string", "format": "binary"} - # Handle typing generics origin = get_origin(annotation) - if origin is list: + if origin in (list, tuple, set, frozenset): args = get_args(annotation) - items_schema = _type_to_schema(args[0]) if args else {} + items_schema = _type_to_schema(args[0], schema_context) if args else {} return {"type": "array", "items": items_schema} if origin is dict: + args = get_args(annotation) + schema = {"type": "object"} + if len(args) == 2: + schema["additionalProperties"] = _type_to_schema(args[1], schema_context) + return schema + + if _is_union_type(annotation): + args = get_args(annotation) + schemas = [_type_to_schema(arg, schema_context) for arg in args] + if len(schemas) == 1: + return schemas[0] + return {"anyOf": schemas} + + if inspect.isclass(annotation): return {"type": "object"} - # Handle Optional[X] / Union[X, None] — get_origin returns Union, not type(None) - if origin is Union: + return {} + + +def _unwrap_annotated(annotation) -> tuple[Any, tuple[Any, ...]]: + """Return the underlying annotation plus all Annotated metadata.""" + metadata: list[Any] = [] + while get_origin(annotation) is Annotated: args = get_args(annotation) - non_none = [a for a in args if a is not type(None)] - if len(non_none) == 1: - inner = _type_to_schema(non_none[0]) - inner["nullable"] = True - return inner - return {"nullable": True} - # Handle bare NoneType annotation - if annotation is type(None): - return {"nullable": True} + if not args: + break + annotation = args[0] + metadata.extend(args[1:]) + return annotation, tuple(metadata) - # Try to get schema from Satya/Pydantic models + +def _resolve_annotation_and_marker(param: inspect.Parameter) -> tuple[Any, Any | None]: + annotation, _metadata = _unwrap_annotated(param.annotation) + + if isinstance(param.default, _PARAM_MARKER_TYPES): + return annotation, param.default + + return annotation, None + + +def _is_unsupported_annotated_model_body(annotation: Any, marker: Any | None) -> bool: + if marker is not None: + return False + + annotation, metadata = _unwrap_annotated(annotation) + return _is_model_class(annotation) and any(isinstance(item, Body) for item in metadata) + + +def _is_dependency_parameter(param: inspect.Parameter) -> bool: + if isinstance(param.default, (Depends, SecurityBase)): + return True + return get_depends(param) is not None + + +def _is_required_param(param: inspect.Parameter, marker: Any | None) -> bool: + default = _effective_default(param, marker) + return default is inspect.Parameter.empty or default is ... + + +def _is_required_body_param(param: inspect.Parameter, marker: Any | None) -> bool: + if isinstance(marker, Body): + # Runtime body parsing does not unwrap Body.default; missing values + # become the marker object, not the marker's default value. + return True + + return _is_required_param(param, marker) + + +def _effective_default(param: inspect.Parameter, marker: Any | None) -> Any: + if isinstance(param.default, _PARAM_MARKER_TYPES): + return param.default.default + if param.default is not inspect.Parameter.empty: + return param.default + if marker is not None and hasattr(marker, "default"): + return marker.default + return inspect.Parameter.empty + + +def _schema_for_param( + annotation, + marker: Any | None, + param: inspect.Parameter, + schema_context: _SchemaContext, + *, + include_default: bool = True, +) -> dict: + if _is_file_param(annotation, marker): + schema = {"type": "string", "format": "binary"} + else: + schema = _type_to_schema(annotation, schema_context) + + schema = dict(schema) + _apply_marker_metadata(schema, marker) + + if include_default: + default = _effective_default(param, marker) + if ( + default is not inspect.Parameter.empty + and default is not ... + and _is_jsonable(default) + ): + schema["default"] = default + + return schema + + +def _apply_marker_metadata(schema: dict[str, Any], marker: Any | None) -> None: + if marker is None: + return + + for attr, openapi_name in ( + ("title", "title"), + ("description", "description"), + ("min_length", "minLength"), + ("max_length", "maxLength"), + ("regex", "pattern"), + ("gt", "exclusiveMinimum"), + ("ge", "minimum"), + ("lt", "exclusiveMaximum"), + ("le", "maximum"), + ): + value = getattr(marker, attr, None) + if value is not None: + schema[openapi_name] = value + + +def _build_parameter( + name: str, location: str, required: bool, schema: dict[str, Any], marker: Any | None +) -> dict[str, Any]: + parameter = { + "name": name, + "in": location, + "required": True if location == "path" else required, + "schema": schema, + } + description = getattr(marker, "description", None) + if description: + parameter["description"] = description + return parameter + + +def _build_request_body(body_entries: list[dict[str, Any]]) -> dict[str, Any]: + if any(entry["media_type"] == "multipart/form-data" for entry in body_entries): + for entry in body_entries: + entry["media_type"] = "multipart/form-data" + + content: dict[str, Any] = {} + media_types = sorted({entry["media_type"] for entry in body_entries}) + for media_type in media_types: + entries = [entry for entry in body_entries if entry["media_type"] == media_type] + if len(entries) == 1 and entries[0]["direct"]: + body_schema = entries[0]["schema"] + else: + body_schema = { + "type": "object", + "properties": {entry["name"]: entry["schema"] for entry in entries}, + } + required = [entry["name"] for entry in entries if entry["required"]] + if required: + body_schema["required"] = required + content[media_type] = {"schema": body_schema} + + return {"required": any(entry["required"] for entry in body_entries), "content": content} + + +def _parameter_alias(param_name: str, marker: Any | None, *, location: str | None = None) -> str: + alias = getattr(marker, "alias", None) + if alias: + return alias + if location == "header" and getattr(marker, "convert_underscores", True): + return param_name.replace("_", "-") + return param_name + + +def _is_form_or_file_param(annotation, marker: Any | None) -> bool: + return isinstance(marker, (Form, File)) or _is_upload_file_type(annotation) + + +def _is_file_param(annotation, marker: Any | None) -> bool: + return isinstance(marker, File) or _is_upload_file_type(annotation) + + +def _is_upload_file_type(annotation) -> bool: try: - if hasattr(annotation, "__fields__") or hasattr(annotation, "model_fields"): - return {"$ref": f"#/components/schemas/{annotation.__name__}"} - except (TypeError, AttributeError): - pass + return inspect.isclass(annotation) and issubclass(annotation, UploadFile) + except TypeError: + return False - return {} + +def _is_union_type(annotation) -> bool: + origin = get_origin(annotation) + return origin is Union or origin is types.UnionType or isinstance(annotation, types.UnionType) + + +def _is_model_class(annotation) -> bool: + try: + return inspect.isclass(annotation) and ( + hasattr(annotation, "model_json_schema") + or hasattr(annotation, "schema") + or hasattr(annotation, "model_fields") + or hasattr(annotation, "__fields__") + ) + except TypeError: + return False + + +def _register_model_schema(model_class, schema_context: _SchemaContext | None) -> dict[str, str]: + if schema_context is None: + name = _component_base_name(model_class) + return {"$ref": f"#/components/schemas/{name}"} + + name = _component_name_for_model(model_class, schema_context) + if name not in schema_context.components: + schema_context.components[name] = {} + model_schema = _model_to_schema(model_class, schema_context) + schema_context.components[name].update(model_schema) + return {"$ref": f"#/components/schemas/{name}"} + + +def _component_name_for_model(model_class, schema_context: _SchemaContext) -> str: + existing_name = schema_context.model_component_names.get(model_class) + if existing_name is not None: + return existing_name + + base_name = _component_base_name(model_class) + existing_model = schema_context.component_models.get(base_name) + if base_name not in schema_context.components or existing_model is model_class: + name = base_name + else: + name = _unique_component_name(model_class, base_name, schema_context) + + schema_context.model_component_names[model_class] = name + schema_context.component_models[name] = model_class + return name + + +def _component_base_name(model_class) -> str: + return _sanitize_component_name(getattr(model_class, "__name__", "Model")) + + +def _unique_component_name(model_class, base_name: str, schema_context: _SchemaContext) -> str: + module = getattr(model_class, "__module__", "") + qualname = getattr(model_class, "__qualname__", base_name) + qualified_name = ".".join(part for part in (module, qualname) if part) + candidate = _sanitize_component_name(qualified_name) + if candidate == base_name: + candidate = f"{base_name}_2" + + reserved_names = set(schema_context.components) | set(schema_context.component_models) + return _next_available_component_name(candidate, reserved_names) + + +def _component_name_for_definition(name: str, reserved_names: set[str]) -> str: + base_name = _sanitize_component_name(name) + return _next_available_component_name(base_name, reserved_names) + + +def _next_available_component_name(base_name: str, reserved_names: set[str]) -> str: + if base_name not in reserved_names: + return base_name + + index = 2 + while True: + candidate = f"{base_name}_{index}" + if candidate not in reserved_names: + return candidate + index += 1 + + +def _sanitize_component_name(name: str) -> str: + return re.sub(r"[^a-zA-Z0-9.\-_]+", "_", name).strip("._-") or "Model" + + +def _model_to_schema(model_class, schema_context: _SchemaContext | None) -> dict[str, Any]: + schema: dict[str, Any] | None = None + + if hasattr(model_class, "model_json_schema"): + try: + schema = model_class.model_json_schema(ref_template="#/components/schemas/{model}") + except TypeError: + schema = model_class.model_json_schema() + except Exception: + schema = None + elif hasattr(model_class, "schema"): + try: + schema = model_class.schema(ref_template="#/components/schemas/{model}") + except TypeError: + schema = model_class.schema() + except Exception: + schema = None + + if not isinstance(schema, dict): + schema = _schema_from_annotations(model_class, schema_context) + + schema = copy.deepcopy(schema) + _move_defs_to_components(schema, schema_context) + _rewrite_component_refs(schema) + return schema + + +def _schema_from_annotations( + model_class, schema_context: _SchemaContext | None +) -> dict[str, Any]: + properties = {} + required = [] + annotations = getattr(model_class, "__annotations__", {}) + fields = getattr(model_class, "model_fields", getattr(model_class, "__fields__", {})) + + for field_name, annotation in annotations.items(): + properties[field_name] = _type_to_schema(annotation, schema_context) + if _model_field_is_required(model_class, field_name, fields): + required.append(field_name) + + schema: dict[str, Any] = { + "title": getattr(model_class, "__name__", "Model"), + "type": "object", + "properties": properties, + } + if required: + schema["required"] = required + return schema + + +def _model_field_is_required(model_class, field_name: str, fields: Any) -> bool: + if isinstance(fields, dict) and field_name in fields: + field = fields[field_name] + is_required = getattr(field, "is_required", None) + if callable(is_required): + return bool(is_required()) + if isinstance(is_required, bool): + return is_required + default = getattr(field, "default", inspect.Parameter.empty) + return default is inspect.Parameter.empty or default is ... + return not hasattr(model_class, field_name) + + +def _move_defs_to_components( + schema: dict[str, Any], schema_context: _SchemaContext | None +) -> None: + if schema_context is None: + return + + defs_groups = [] + for defs_key in ("$defs", "definitions"): + defs = schema.pop(defs_key, None) + if isinstance(defs, dict): + defs_groups.append((defs_key, defs)) + + if not defs_groups: + return + + component_names: dict[tuple[str, str], str] = {} + ref_rewrites: dict[str, str] = {} + reserved_names = set(schema_context.components) | set(schema_context.component_models) + + for defs_key, defs in defs_groups: + for name, value in defs.items(): + component_name = _component_name_for_definition(name, reserved_names) + component_names[(defs_key, name)] = component_name + reserved_names.add(component_name) + component_ref = f"#/components/schemas/{component_name}" + ref_rewrites[f"#/{defs_key}/{name}"] = component_ref + ref_rewrites[f"#/components/schemas/{name}"] = component_ref + + _rewrite_component_refs(schema, ref_rewrites) + for defs_key, defs in defs_groups: + for name, value in defs.items(): + component_name = component_names[(defs_key, name)] + component_schema = copy.deepcopy(value) + _rewrite_component_refs(component_schema, ref_rewrites) + schema_context.components.setdefault(component_name, component_schema) + + +def _rewrite_component_refs( + value: Any, ref_rewrites: dict[str, str] | None = None +) -> None: + if isinstance(value, dict): + ref = value.get("$ref") + if isinstance(ref, str): + if ref_rewrites and ref in ref_rewrites: + value["$ref"] = ref_rewrites[ref] + elif ref.startswith("#/$defs/"): + value["$ref"] = "#/components/schemas/" + ref.rsplit("/", 1)[-1] + elif ref.startswith("#/definitions/"): + value["$ref"] = "#/components/schemas/" + ref.rsplit("/", 1)[-1] + for item in value.values(): + _rewrite_component_refs(item, ref_rewrites) + elif isinstance(value, list): + for item in value: + _rewrite_component_refs(item, ref_rewrites) + + +def _is_jsonable(value: Any) -> bool: + try: + json.dumps(value) + return True + except (TypeError, ValueError): + return False # HTML templates for Swagger UI and ReDoc diff --git a/tests/test_fastapi_parity.py b/tests/test_fastapi_parity.py index d7e741a..448eced 100644 --- a/tests/test_fastapi_parity.py +++ b/tests/test_fastapi_parity.py @@ -7,6 +7,7 @@ import json import os import tempfile +from typing import Annotated import pytest from turboapi import ( @@ -533,6 +534,383 @@ def create_item(name: str, price: float): operation = schema["paths"]["/items"]["post"] assert "requestBody" in operation + def test_openapi_skips_dependencies_and_supports_forms_and_dhi_models(self): + from dhi import BaseModel + + class SearchRequest(BaseModel): + query: str + limit: int = 10 + + class SearchResponse(BaseModel): + count: int + + def get_session(): + return {"session": True} + + SessionDep = Annotated[dict, Depends(get_session)] + app = TurboAPI(title="OpenAPICompat") + + @app.post("/login") + def login(username: str = Form(), password: str = Form()): + return {"username": username} + + @app.post("/search", response_model=SearchResponse) + def search( + session: SessionDep, + request: SearchRequest, + include_archived: bool = Query(default=False, alias="archived"), + ): + return SearchResponse(count=request.limit) + + schema = app.openapi() + json.dumps(schema) + + login_body = schema["paths"]["/login"]["post"]["requestBody"] + form_schema = login_body["content"]["application/x-www-form-urlencoded"]["schema"] + assert form_schema["properties"]["username"] == {"type": "string"} + assert form_schema["properties"]["password"] == {"type": "string"} + assert form_schema["required"] == ["username", "password"] + + search_operation = schema["paths"]["/search"]["post"] + assert search_operation["requestBody"]["content"]["application/json"]["schema"] == { + "$ref": "#/components/schemas/SearchRequest" + } + assert search_operation["parameters"] == [ + { + "name": "include_archived", + "in": "query", + "required": True, + "schema": {"type": "boolean"}, + } + ] + assert search_operation["responses"]["200"]["content"]["application/json"]["schema"] == { + "$ref": "#/components/schemas/SearchResponse" + } + assert "SearchRequest" in schema["components"]["schemas"] + assert "SearchResponse" in schema["components"]["schemas"] + + def test_openapi_disambiguates_same_named_model_components(self): + from dhi import BaseModel + + FirstItem = type("Item", (BaseModel,), {"__annotations__": {"name": str}}) + SecondItem = type("Item", (BaseModel,), {"__annotations__": {"count": int}}) + + app = TurboAPI(title="OpenAPIModels") + + @app.post("/first", response_model=FirstItem) + def create_first(item: FirstItem): + return item + + @app.post("/second", response_model=SecondItem) + def create_second(item: SecondItem): + return item + + schema = app.openapi() + first_request_ref = schema["paths"]["/first"]["post"]["requestBody"]["content"][ + "application/json" + ]["schema"]["$ref"] + second_request_ref = schema["paths"]["/second"]["post"]["requestBody"]["content"][ + "application/json" + ]["schema"]["$ref"] + first_response_ref = schema["paths"]["/first"]["post"]["responses"]["200"]["content"][ + "application/json" + ]["schema"]["$ref"] + second_response_ref = schema["paths"]["/second"]["post"]["responses"]["200"]["content"][ + "application/json" + ]["schema"]["$ref"] + + assert first_request_ref == first_response_ref + assert second_request_ref == second_response_ref + assert first_request_ref != second_request_ref + + first_component = first_request_ref.rsplit("/", 1)[-1] + second_component = second_request_ref.rsplit("/", 1)[-1] + components = schema["components"]["schemas"] + assert components[first_component]["properties"] == {"name": {"type": "string"}} + assert components[second_component]["properties"] == {"count": {"type": "integer"}} + + def test_openapi_disambiguates_same_named_nested_model_definitions(self): + class FirstParent: + @classmethod + def model_json_schema(cls, ref_template=None): + return { + "title": "FirstParent", + "type": "object", + "properties": {"child": {"$ref": "#/components/schemas/Nested"}}, + "$defs": { + "Nested": { + "title": "Nested", + "type": "object", + "properties": {"name": {"type": "string"}}, + } + }, + } + + class SecondParent: + @classmethod + def model_json_schema(cls, ref_template=None): + return { + "title": "SecondParent", + "type": "object", + "properties": {"child": {"$ref": "#/components/schemas/Nested"}}, + "$defs": { + "Nested": { + "title": "Nested", + "type": "object", + "properties": {"count": {"type": "integer"}}, + } + }, + } + + app = TurboAPI(title="OpenAPINestedModels") + + @app.post("/first-nested", response_model=FirstParent) + def create_first_nested(item: FirstParent): + return item + + @app.post("/second-nested", response_model=SecondParent) + def create_second_nested(item: SecondParent): + return item + + schema = app.openapi() + components = schema["components"]["schemas"] + first_nested_ref = components["FirstParent"]["properties"]["child"]["$ref"] + second_nested_ref = components["SecondParent"]["properties"]["child"]["$ref"] + + assert first_nested_ref != second_nested_ref + first_nested_component = first_nested_ref.rsplit("/", 1)[-1] + second_nested_component = second_nested_ref.rsplit("/", 1)[-1] + assert components[first_nested_component]["properties"] == {"name": {"type": "string"}} + assert components[second_nested_component]["properties"] == { + "count": {"type": "integer"} + } + + def test_openapi_disambiguates_nested_definitions_after_ref_rewriting(self): + class FirstParent: + @classmethod + def model_json_schema(cls, ref_template=None): + return { + "title": "FirstParent", + "type": "object", + "properties": {"child": {"$ref": "#/components/schemas/Child"}}, + "$defs": { + "Child": { + "title": "Child", + "type": "object", + "properties": { + "inner": {"$ref": "#/components/schemas/Inner"} + }, + }, + "Inner": { + "title": "Inner", + "type": "object", + "properties": {"name": {"type": "string"}}, + }, + }, + } + + class SecondParent: + @classmethod + def model_json_schema(cls, ref_template=None): + return { + "title": "SecondParent", + "type": "object", + "properties": {"child": {"$ref": "#/components/schemas/Child"}}, + "$defs": { + "Child": { + "title": "Child", + "type": "object", + "properties": { + "inner": {"$ref": "#/components/schemas/Inner"} + }, + }, + "Inner": { + "title": "Inner", + "type": "object", + "properties": {"count": {"type": "integer"}}, + }, + }, + } + + app = TurboAPI(title="OpenAPINestedRefRewrite") + + @app.post("/first-child", response_model=FirstParent) + def create_first_child(item: FirstParent): + return item + + @app.post("/second-child", response_model=SecondParent) + def create_second_child(item: SecondParent): + return item + + schema = app.openapi() + components = schema["components"]["schemas"] + first_child_ref = components["FirstParent"]["properties"]["child"]["$ref"] + second_child_ref = components["SecondParent"]["properties"]["child"]["$ref"] + + assert first_child_ref != second_child_ref + first_child_component = first_child_ref.rsplit("/", 1)[-1] + second_child_component = second_child_ref.rsplit("/", 1)[-1] + first_inner_ref = components[first_child_component]["properties"]["inner"]["$ref"] + second_inner_ref = components[second_child_component]["properties"]["inner"]["$ref"] + + assert first_inner_ref != second_inner_ref + first_inner_component = first_inner_ref.rsplit("/", 1)[-1] + second_inner_component = second_inner_ref.rsplit("/", 1)[-1] + assert components[first_inner_component]["properties"] == {"name": {"type": "string"}} + assert components[second_inner_component]["properties"] == { + "count": {"type": "integer"} + } + + def test_openapi_does_not_document_unsupported_annotated_form_metadata(self): + app = TurboAPI(title="OpenAPIAnnotatedForm") + + @app.post("/annotated-form") + def annotated_form(username: Annotated[str, Form()]): + return {"username": username} + + schema = app.openapi() + content = schema["paths"]["/annotated-form"]["post"]["requestBody"]["content"] + + assert "application/x-www-form-urlencoded" not in content + assert content["application/json"]["schema"] == { + "type": "object", + "properties": {"username": {"type": "string"}}, + "required": ["username"], + } + + def test_openapi_does_not_document_unsupported_annotated_param_metadata(self): + app = TurboAPI(title="OpenAPIAnnotatedParams") + + @app.get("/annotated-query") + def annotated_query(q: Annotated[int, Query(default=10, alias="item-query")]): + return {"q": q} + + schema = app.openapi() + parameters = schema["paths"]["/annotated-query"]["get"]["parameters"] + + assert parameters == [ + { + "name": "q", + "in": "query", + "required": True, + "schema": {"type": "integer"}, + } + ] + + def test_openapi_does_not_document_unsupported_cookie_route_params(self): + from dhi import BaseModel + + class Session(BaseModel): + id: str + + app = TurboAPI(title="OpenAPICookieParams") + + @app.get("/cookie-route") + def cookie_route(session_id: str = Cookie()): + return {"session_id": session_id} + + @app.get("/cookie-model-route") + def cookie_model_route(session: Session = Cookie()): + return session + + schema = app.openapi() + cookie_operation = schema["paths"]["/cookie-route"]["get"] + cookie_model_operation = schema["paths"]["/cookie-model-route"]["get"] + + assert "parameters" not in cookie_operation + assert "parameters" not in cookie_model_operation + assert "Session" not in schema["components"]["schemas"] + + def test_openapi_body_embed_model_matches_current_runtime_binding(self): + from dhi import BaseModel + + class Item(BaseModel): + name: str + + app = TurboAPI(title="OpenAPIBodyEmbed") + + @app.post("/body-embed") + def body_embed(item: Item = Body(embed=True)): + return item + + schema = app.openapi() + body_schema = schema["paths"]["/body-embed"]["post"]["requestBody"]["content"][ + "application/json" + ]["schema"] + + assert body_schema == {"$ref": "#/components/schemas/Item"} + + def test_openapi_documents_form_fields_as_runtime_raw_strings(self): + app = TurboAPI(title="OpenAPIFormRawStrings") + + @app.post("/form-raw-strings") + def form_raw_strings(age: int = Form(), active: bool = Form()): + return {"age": age, "active": active} + + schema = app.openapi() + body_schema = schema["paths"]["/form-raw-strings"]["post"]["requestBody"]["content"][ + "application/x-www-form-urlencoded" + ]["schema"] + + assert body_schema == { + "type": "object", + "properties": { + "age": {"type": "string"}, + "active": {"type": "string"}, + }, + "required": ["age", "active"], + } + + def test_openapi_does_not_document_unsupported_body_marker_defaults(self): + app = TurboAPI(title="OpenAPIBodyMarkerDefaults") + + @app.post("/body-marker-defaults") + def body_marker_defaults(count: int = Body(default=10)): + return {"count": count} + + schema = app.openapi() + request_body = schema["paths"]["/body-marker-defaults"]["post"]["requestBody"] + body_schema = request_body["content"]["application/json"]["schema"] + + assert request_body["required"] is True + assert body_schema == { + "type": "object", + "properties": {"count": {"type": "integer"}}, + "required": ["count"], + } + + def test_openapi_does_not_document_unsupported_annotated_model_body(self): + from dhi import BaseModel + + class Item(BaseModel): + name: str + + app = TurboAPI(title="OpenAPIAnnotatedModelBody") + + @app.post("/annotated-model-body") + def annotated_model_body(item: Annotated[Item, Body()]): + return item + + @app.post("/annotated-model-body-embed") + def annotated_model_body_embed(item: Annotated[Item, Body(embed=True)]): + return item + + schema = app.openapi() + + for path in ("/annotated-model-body", "/annotated-model-body-embed"): + body_schema = schema["paths"][path]["post"]["requestBody"]["content"][ + "application/json" + ]["schema"] + + assert body_schema != {"$ref": "#/components/schemas/Item"} + assert body_schema == { + "type": "object", + "properties": {"item": {}}, + "required": ["item"], + } + + assert "Item" not in schema["components"]["schemas"] + def test_app_openapi_method(self): app = TurboAPI(title="AppOpenAPI")