Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 18 additions & 28 deletions flask_openapi/blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,30 +100,20 @@ def register_api(self, api: "APIBlueprint") -> None:
# Register the nested APIBlueprint as a blueprint
self.register_blueprint(api)

def _add_url_rule(
self,
rule,
endpoint=None,
view_func=None,
provide_automatic_options=None,
**options,
) -> None:
self.add_url_rule(rule, endpoint, view_func, provide_automatic_options, **options)

def _collect_openapi_info(
self,
rule: str,
func: FunctionType,
*,
tags: list[Tag] | None = None,
tags: list[Tag | dict[str, Any]] | None = None,
summary: str | None = None,
description: str | None = None,
external_docs: ExternalDocumentation | None = None,
external_docs: ExternalDocumentation | dict[str, Any] | None = None,
operation_id: str | None = None,
responses: ResponseDict | None = None,
deprecated: bool | None = None,
security: list[dict[str, list[Any]]] | None = None,
servers: list[Server] | None = None,
servers: list[Server | dict[str, Any]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
doc_ui: bool = True,
method: str = HTTPMethod.GET,
Expand Down Expand Up @@ -203,14 +193,14 @@ def get(
self,
rule: str,
*,
tags: list[Tag] | None = None,
tags: list[Tag | dict[str, Any]] | None = None,
summary: str | None = None,
description: str | None = None,
external_docs: ExternalDocumentation | None = None,
external_docs: ExternalDocumentation | dict[str, Any] | None = None,
operation_id: str | None = None,
deprecated: bool | None = None,
security: list[dict[str, list[Any]]] | None = None,
servers: list[Server] | None = None,
servers: list[Server | dict[str, Any]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
responses: ResponseDict | None = None,
validate_response: bool | None = None,
Expand Down Expand Up @@ -278,14 +268,14 @@ def post(
self,
rule: str,
*,
tags: list[Tag] | None = None,
tags: list[Tag | dict[str, Any]] | None = None,
summary: str | None = None,
description: str | None = None,
external_docs: ExternalDocumentation | None = None,
external_docs: ExternalDocumentation | dict[str, Any] | None = None,
operation_id: str | None = None,
deprecated: bool | None = None,
security: list[dict[str, list[Any]]] | None = None,
servers: list[Server] | None = None,
servers: list[Server | dict[str, Any]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
responses: ResponseDict | None = None,
validate_response: bool | None = None,
Expand Down Expand Up @@ -353,14 +343,14 @@ def put(
self,
rule: str,
*,
tags: list[Tag] | None = None,
tags: list[Tag | dict[str, Any]] | None = None,
summary: str | None = None,
description: str | None = None,
external_docs: ExternalDocumentation | None = None,
external_docs: ExternalDocumentation | dict[str, Any] | None = None,
operation_id: str | None = None,
deprecated: bool | None = None,
security: list[dict[str, list[Any]]] | None = None,
servers: list[Server] | None = None,
servers: list[Server | dict[str, Any]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
responses: ResponseDict | None = None,
validate_response: bool | None = None,
Expand Down Expand Up @@ -428,14 +418,14 @@ def delete(
self,
rule: str,
*,
tags: list[Tag] | None = None,
tags: list[Tag | dict[str, Any]] | None = None,
summary: str | None = None,
description: str | None = None,
external_docs: ExternalDocumentation | None = None,
external_docs: ExternalDocumentation | dict[str, Any] | None = None,
operation_id: str | None = None,
deprecated: bool | None = None,
security: list[dict[str, list[Any]]] | None = None,
servers: list[Server] | None = None,
servers: list[Server | dict[str, Any]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
responses: ResponseDict | None = None,
validate_response: bool | None = None,
Expand Down Expand Up @@ -503,14 +493,14 @@ def patch(
self,
rule: str,
*,
tags: list[Tag] | None = None,
tags: list[Tag | dict[str, Any]] | None = None,
summary: str | None = None,
description: str | None = None,
external_docs: ExternalDocumentation | None = None,
external_docs: ExternalDocumentation | dict[str, Any] | None = None,
operation_id: str | None = None,
deprecated: bool | None = None,
security: list[dict[str, list[Any]]] | None = None,
servers: list[Server] | None = None,
servers: list[Server | dict[str, Any]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
responses: ResponseDict | None = None,
validate_response: bool | None = None,
Expand Down
14 changes: 4 additions & 10 deletions flask_openapi/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,16 @@ async def view_func(**kwargs) -> FlaskResponse:
signature = inspect.signature(view_class.__init__)
parameters = signature.parameters
if parameters.get("view_kwargs"):
view_object = view_class(view_kwargs=view_kwargs)
view_object = view_class(view_kwargs=view_kwargs) # pragma: no cover
else:
view_object = view_class()
response = await func(view_object, **func_kwargs)
else:
response = await func(**func_kwargs)

if hasattr(current_app, "validate_response"):
_validate_response = validate_response or current_app.validate_response
else:
_validate_response = validate_response
_validate_response = validate_response or current_app.validate_response # type: ignore

if _validate_response and responses:
if _validate_response and responses: # pragma: no cover
validate_response_callback = getattr(current_app, "validate_response_callback")
return validate_response_callback(response, responses)

Expand Down Expand Up @@ -90,10 +87,7 @@ def view_func(**kwargs) -> FlaskResponse:
else:
response = func(**func_kwargs)

if hasattr(current_app, "validate_response"):
_validate_response = validate_response or current_app.validate_response
else:
_validate_response = validate_response
_validate_response = validate_response or current_app.validate_response # type: ignore

if _validate_response and responses:
validate_response_callback = getattr(current_app, "validate_response_callback")
Expand Down
6 changes: 4 additions & 2 deletions flask_openapi/models/operation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

from pydantic import BaseModel

from .callback import Callback
Expand All @@ -18,7 +20,7 @@ class Operation(BaseModel):
tags: list[str] | None = None
summary: str | None = None
description: str | None = None
externalDocs: ExternalDocumentation | None = None
externalDocs: ExternalDocumentation | dict[str, Any] | None = None
operationId: str | None = None
parameters: list[Parameter] | None = None
requestBody: RequestBody | Reference | None = None
Expand All @@ -27,6 +29,6 @@ class Operation(BaseModel):

deprecated: bool | None = False
security: list[SecurityRequirement] | None = None
servers: list[Server] | None = None
servers: list[Server | dict[str, Any]] | None = None

model_config = {"extra": "allow"}
26 changes: 13 additions & 13 deletions flask_openapi/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def register_api_view(
self.components_schemas.update(**api_view.components_schemas)

# Register the APIView with the current instance
api_view.register(self, url_prefix=url_prefix, view_kwargs=view_kwargs)
api_view.register(self, view_kwargs=view_kwargs)

def _collect_openapi_info(
self,
Expand All @@ -358,12 +358,12 @@ def _collect_openapi_info(
tags: list[Tag | dict[str, Any]] | None = None,
summary: str | None = None,
description: str | None = None,
external_docs: ExternalDocumentation | None = None,
external_docs: ExternalDocumentation | dict[str, Any] | None = None,
operation_id: str | None = None,
responses: ResponseDict | None = None,
deprecated: bool | None = None,
security: list[dict[str, list[Any]]] | None = None,
servers: list[Server] | None = None,
servers: list[Server | dict[str, Any]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
doc_ui: bool = True,
method: str = HTTPMethod.GET,
Expand Down Expand Up @@ -443,11 +443,11 @@ def get(
tags: list[Tag | dict[str, Any]] | None = None,
summary: str | None = None,
description: str | None = None,
external_docs: ExternalDocumentation | None = None,
external_docs: ExternalDocumentation | dict[str, Any] | None = None,
operation_id: str | None = None,
deprecated: bool | None = None,
security: list[dict[str, list[Any]]] | None = None,
servers: list[Server] | None = None,
servers: list[Server | dict[str, Any]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
responses: ResponseDict | None = None,
validate_response: bool | None = None,
Expand Down Expand Up @@ -518,11 +518,11 @@ def post(
tags: list[Tag | dict[str, Any]] | None = None,
summary: str | None = None,
description: str | None = None,
external_docs: ExternalDocumentation | None = None,
external_docs: ExternalDocumentation | dict[str, Any] | None = None,
operation_id: str | None = None,
deprecated: bool | None = None,
security: list[dict[str, list[Any]]] | None = None,
servers: list[Server] | None = None,
servers: list[Server | dict[str, Any]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
responses: ResponseDict | None = None,
validate_response: bool | None = None,
Expand Down Expand Up @@ -593,11 +593,11 @@ def put(
tags: list[Tag | dict[str, Any]] | None = None,
summary: str | None = None,
description: str | None = None,
external_docs: ExternalDocumentation | None = None,
external_docs: ExternalDocumentation | dict[str, Any] | None = None,
operation_id: str | None = None,
deprecated: bool | None = None,
security: list[dict[str, list[Any]]] | None = None,
servers: list[Server] | None = None,
servers: list[Server | dict[str, Any]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
responses: ResponseDict | None = None,
validate_response: bool | None = None,
Expand Down Expand Up @@ -668,11 +668,11 @@ def delete(
tags: list[Tag | dict[str, Any]] | None = None,
summary: str | None = None,
description: str | None = None,
external_docs: ExternalDocumentation | None = None,
external_docs: ExternalDocumentation | dict[str, Any] | None = None,
operation_id: str | None = None,
deprecated: bool | None = None,
security: list[dict[str, list[Any]]] | None = None,
servers: list[Server] | None = None,
servers: list[Server | dict[str, Any]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
responses: ResponseDict | None = None,
validate_response: bool | None = None,
Expand Down Expand Up @@ -743,11 +743,11 @@ def patch(
tags: list[Tag | dict[str, Any]] | None = None,
summary: str | None = None,
description: str | None = None,
external_docs: ExternalDocumentation | None = None,
external_docs: ExternalDocumentation | dict[str, Any] | None = None,
operation_id: str | None = None,
deprecated: bool | None = None,
security: list[dict[str, list[Any]]] | None = None,
servers: list[Server] | None = None,
servers: list[Server | dict[str, Any]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
responses: ResponseDict | None = None,
validate_response: bool | None = None,
Expand Down
2 changes: 1 addition & 1 deletion flask_openapi/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def _validate_header(header: Type[BaseModel], func_kwargs: dict):
value = request_headers.get(key_alias_title)
else:
key = model_field_key
value = request_headers[key_title]
value = request_headers.get(key_title)
if value is not None:
header_dict[key] = value
if model_field_schema.get("type") == "null":
Expand Down
4 changes: 3 additions & 1 deletion flask_openapi/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from pydantic import BaseModel

_ResponseDictValue = Type[BaseModel] | dict[Any, Any] | None
from .models import Response

_ResponseDictValue = Type[BaseModel] | Response | dict[Any, Any] | None

ResponseDict = dict[str | int | HTTPStatus, _ResponseDictValue]

Expand Down
10 changes: 5 additions & 5 deletions flask_openapi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,8 @@ def get_responses(responses: ResponseStrKeyDict, components_schemas: dict, opera
elif isinstance(response, dict):
response["description"] = response.get("description", HTTP_STATUS.get(key, ""))
_responses[key] = Response(**response)
elif isinstance(response, Response):
_responses[key] = response
else:
# OpenAPI 3 support ^[a-zA-Z0-9\.\-_]+$ so we should normalize __name__
schema = get_model_schema(response, mode="serialization")
Expand Down Expand Up @@ -477,18 +479,16 @@ def make_validation_error_response(e: ValidationError) -> FlaskResponse:
return response


def run_validate_response(response: Any, responses: ResponseDict | None = None) -> Any:
def run_validate_response(response: Any, responses: ResponseDict) -> Any:
"""Validate response"""
if responses is None:
return response

if isinstance(response, tuple): # noqa
if isinstance(response, tuple):
_resp, status_code = response[:2]
elif isinstance(response, FlaskResponse):
if response.mimetype != "application/json":
# only application/json
return response
_resp, status_code = response.json, response.status_code # noqa
_resp, status_code = response.json, response.status_code
else:
_resp, status_code = response, 200

Expand Down
16 changes: 4 additions & 12 deletions flask_openapi/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,12 @@ def doc(
tags: list[Tag | dict[str, Any]] | None = None,
summary: str | None = None,
description: str | None = None,
external_docs: ExternalDocumentation | None = None,
external_docs: ExternalDocumentation | dict[str, Any] | None = None,
operation_id: str | None = None,
responses: ResponseDict | None = None,
deprecated: bool | None = None,
security: list[dict[str, list[Any]]] | None = None,
servers: list[Server] | None = None,
servers: list[Server | dict[str, Any]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
validate_response: bool | None = None,
doc_ui: bool = True,
Expand Down Expand Up @@ -188,15 +188,12 @@ def decorator(func):

return decorator

def register(
self, app: "OpenAPI", url_prefix: str | None = None, view_kwargs: dict[Any, Any] | None = None
) -> None:
def register(self, app: "OpenAPI", view_kwargs: dict[Any, Any] | None = None) -> None:
"""
Register the API views with the given OpenAPI app.

Args:
app: An instance of the OpenAPI app.
url_prefix: A path to prepend to all the APIView's urls
view_kwargs: Additional keyword arguments to pass to the API views.
"""
for rule, (cls, methods) in self.views.items():
Expand All @@ -214,14 +211,9 @@ def register(
body,
view_class=cls,
view_kwargs=view_kwargs,
responses=func.responses,
responses=getattr(func, "responses", None),
validate_response=_validate_response,
)

if url_prefix and self.url_prefix and url_prefix != self.url_prefix:
rule = url_prefix + rule.removeprefix(self.url_prefix)
elif url_prefix and not self.url_prefix:
rule = url_prefix.rstrip("/") + "/" + rule.lstrip("/")

options: dict[str, Any] = {"endpoint": cls.__name__ + "." + method.lower(), "methods": [method.upper()]}
app.add_url_rule(rule, view_func=view_func, **options)
10 changes: 10 additions & 0 deletions tests/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
SWAGGER_CONFIG = {
"docExpansion": "none",
"validatorUrl": None,
"tryItOutEnabled": True,
"filter": True,
"tagsSorter": "alpha",
"persistAuthorization": True,
}

JWT = [{"jwt": []}]
Loading