diff --git a/examples/api_router_demo.py b/examples/api_router_demo.py index 2dcead2..4536320 100644 --- a/examples/api_router_demo.py +++ b/examples/api_router_demo.py @@ -40,11 +40,16 @@ async def delete_book(path: IdModel): return JSONResponse({"id": path.id}) -@api2.get("/") +@api2.get("") async def get_api2(): return JSONResponse({"message": "Hello World2"}) +@api2.websocket("") +async def api2_websocket(): + return JSONResponse({"message": "Hello World2"}) + + api1.register_api(api2) app.register_api(api1) diff --git a/star_openapi/openapi.py b/star_openapi/openapi.py index 7994530..7f01d38 100644 --- a/star_openapi/openapi.py +++ b/star_openapi/openapi.py @@ -9,7 +9,7 @@ from pydantic import BaseModel, ValidationError from starlette.applications import Starlette from starlette.responses import HTMLResponse, JSONResponse -from starlette.routing import Mount, Route, WebSocketRoute +from starlette.routing import Mount, Route from .cli import cli from .config import Config @@ -27,7 +27,7 @@ Tag, ValidationErrorModel, ) -from .router import APIRouter +from .router import APIRoute, APIRouter, APIWebSocketRoute from .templates import openapi_html_string from .types import ParametersTuple, ResponseDict from .utils import ( @@ -272,16 +272,16 @@ def register_api(self, api: APIRouter): # Register the APIRouter with the current instance for route in api.routes: - if isinstance(route, Route): - path_with_prefix = api.url_prefix + route.path + if isinstance(route, APIRoute): + path_with_prefix = api.url_prefix + route.origin_path self.router.add_route( path=path_with_prefix, endpoint=route.endpoint, methods=route.methods, name=route.name, ) - elif isinstance(route, WebSocketRoute): - path_with_prefix = api.url_prefix + route.path + elif isinstance(route, APIWebSocketRoute): + path_with_prefix = api.url_prefix + route.origin_path self.router.add_websocket_route(path=path_with_prefix, endpoint=route.endpoint, name=route.name) def _collect_openapi_info( @@ -383,6 +383,7 @@ def get( Args: rule: The URL rule string. + name: The URL name string. tags: Adds metadata to a single tag. summary: A short summary of what the operation does. description: A verbose explanation of the operation behavior. @@ -444,6 +445,7 @@ def post( Args: rule: The URL rule string. + name: The URL name string. tags: Adds metadata to a single tag. summary: A short summary of what the operation does. description: A verbose explanation of the operation behavior. @@ -453,6 +455,7 @@ def post( security: A declaration of which security mechanisms can be used for this operation. servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. + request_body: Advanced configuration in OpenAPI. responses: API responses should be either a subclass of BaseModel, a dictionary, or None. doc_ui: Declares this operation to be shown. Default to True. """ @@ -506,6 +509,7 @@ def put( Args: rule: The URL rule string. + name: The URL name string. tags: Adds metadata to a single tag. summary: A short summary of what the operation does. description: A verbose explanation of the operation behavior. @@ -515,6 +519,7 @@ def put( security: A declaration of which security mechanisms can be used for this operation. servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. + request_body: Advanced configuration in OpenAPI. responses: API responses should be either a subclass of BaseModel, a dictionary, or None. doc_ui: Declares this operation to be shown. Default to True. """ @@ -568,6 +573,7 @@ def delete( Args: rule: The URL rule string. + name: The URL name string. tags: Adds metadata to a single tag. summary: A short summary of what the operation does. description: A verbose explanation of the operation behavior. @@ -577,6 +583,7 @@ def delete( security: A declaration of which security mechanisms can be used for this operation. servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. + request_body: Advanced configuration in OpenAPI. responses: API responses should be either a subclass of BaseModel, a dictionary, or None. doc_ui: Declares this operation to be shown. Default to True. """ @@ -630,6 +637,7 @@ def patch( Args: rule: The URL rule string. + name: The URL name string. tags: Adds metadata to a single tag. summary: A short summary of what the operation does. description: A verbose explanation of the operation behavior. @@ -639,6 +647,7 @@ def patch( security: A declaration of which security mechanisms can be used for this operation. servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. + request_body: Advanced configuration in OpenAPI. responses: API responses should be either a subclass of BaseModel, a dictionary, or None. doc_ui: Declares this operation to be shown. Default to True. """ diff --git a/star_openapi/router.py b/star_openapi/router.py index 64e2f3e..d33a065 100644 --- a/star_openapi/router.py +++ b/star_openapi/router.py @@ -1,9 +1,11 @@ -from collections.abc import Callable +from collections.abc import Awaitable, Callable, Collection from http import HTTPMethod from types import FunctionType -from typing import Any +from typing import Any, Sequence -from starlette.routing import Route, Router, WebSocketRoute +from starlette.middleware import Middleware +from starlette.routing import Request, Response, Route, Router, WebSocketRoute +from starlette.websockets import WebSocket from .endpoint import create_endpoint from .models import ExternalDocumentation, RequestBody, Server, Tag @@ -20,6 +22,50 @@ ) +class APIRoute(Route): + def __init__( + self, + path: str, + origin_path: str, + endpoint: Callable[..., Any], + *, + methods: Collection[str] | None = None, + name: str | None = None, + include_in_schema: bool = True, + middleware: Sequence[Middleware] | None = None, + ): + super().__init__( + path=path, + endpoint=endpoint, + methods=methods, + name=name, + include_in_schema=include_in_schema, + middleware=middleware, + ) + + self.origin_path = origin_path + + +class APIWebSocketRoute(WebSocketRoute): + def __init__( + self, + path: str, + origin_path: str, + endpoint: Callable[..., Any], + *, + name: str | None = None, + middleware: Sequence[Middleware] | None = None, + ): + super().__init__( + path=path, + endpoint=endpoint, + name=name, + middleware=middleware, + ) + + self.origin_path = origin_path + + class APIRouter(Router): def __init__( self, @@ -79,17 +125,17 @@ def register_api(self, api: "APIRouter"): # Register the APIRouter with the current instance for route in api.routes: - if isinstance(route, Route): - path_with_prefix = api.url_prefix + route.path - self.add_route( + if isinstance(route, APIRoute): + path_with_prefix = api.url_prefix + route.origin_path + self._add_route( path=path_with_prefix, endpoint=route.endpoint, methods=route.methods, name=route.name, ) - elif isinstance(route, WebSocketRoute): - path_with_prefix = api.url_prefix + route.path - self.add_websocket_route(path=path_with_prefix, endpoint=route.endpoint, name=route.name) + elif isinstance(route, APIWebSocketRoute): + path_with_prefix = api.url_prefix + route.origin_path + self._add_websocket_route(path=path_with_prefix, endpoint=route.endpoint, name=route.name) def _collect_openapi_info( self, @@ -173,6 +219,43 @@ def _collect_openapi_info( else: return parse_parameters(func, doc_ui=False) + def _add_route( + self, + path: str, + endpoint: Callable[[Request], Awaitable[Response] | Response], + methods: Collection[str] | None = None, + name: str | None = None, + include_in_schema: bool = True, + ) -> None: + if not path.startswith("/"): + origin_path = path + path = "/" + path + else: + origin_path = path + route = APIRoute( + path=path, + origin_path=origin_path, + endpoint=endpoint, + methods=methods, + name=name, + include_in_schema=include_in_schema, + ) + self.routes.append(route) + + def _add_websocket_route( + self, + path: str, + endpoint: Callable[[WebSocket], Awaitable[None]], + name: str | None = None, + ) -> None: + if not path.startswith("/"): + origin_path = path + path = "/" + path + else: + origin_path = path + route = APIWebSocketRoute(path=path, origin_path=origin_path, endpoint=endpoint, name=name) + self.routes.append(route) + def get( self, rule: str, @@ -196,6 +279,7 @@ def get( Args: rule: The URL rule string. + name: The URL name string. tags: Adds metadata to a single tag. summary: A short summary of what the operation does. description: A verbose explanation of the operation behavior. @@ -227,7 +311,7 @@ def decorator(func) -> Callable: method=HTTPMethod.GET, ) endpoint = create_endpoint(func, header, cookie, path, query, form, body) - self.add_route(rule, endpoint, methods=["GET"], name=name, include_in_schema=False) + self._add_route(rule, endpoint, methods=["GET"], name=name, include_in_schema=False) return func @@ -257,6 +341,7 @@ def post( Args: rule: The URL rule string. + name: The URL name string. tags: Adds metadata to a single tag. summary: A short summary of what the operation does. description: A verbose explanation of the operation behavior. @@ -266,6 +351,7 @@ def post( security: A declaration of which security mechanisms can be used for this operation. servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. + request_body: Advanced configuration in OpenAPI. responses: API responses should be either a subclass of BaseModel, a dictionary, or None. doc_ui: Declares this operation to be shown. Default to True. """ @@ -289,7 +375,7 @@ def decorator(func) -> Callable: method=HTTPMethod.POST, ) endpoint = create_endpoint(func, header, cookie, path, query, form, body) - self.add_route(rule, endpoint, methods=["POST"], name=name, include_in_schema=False) + self._add_route(rule, endpoint, methods=["POST"], name=name, include_in_schema=False) return func @@ -319,6 +405,7 @@ def put( Args: rule: The URL rule string. + name: The URL name string. tags: Adds metadata to a single tag. summary: A short summary of what the operation does. description: A verbose explanation of the operation behavior. @@ -328,6 +415,7 @@ def put( security: A declaration of which security mechanisms can be used for this operation. servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. + request_body: Advanced configuration in OpenAPI. responses: API responses should be either a subclass of BaseModel, a dictionary, or None. doc_ui: Declares this operation to be shown. Default to True. """ @@ -351,7 +439,7 @@ def decorator(func) -> Callable: method=HTTPMethod.PUT, ) endpoint = create_endpoint(func, header, cookie, path, query, form, body) - self.add_route(rule, endpoint, methods=["PUT"], name=name, include_in_schema=False) + self._add_route(rule, endpoint, methods=["PUT"], name=name, include_in_schema=False) return func @@ -381,6 +469,7 @@ def delete( Args: rule: The URL rule string. + name: The URL name string. tags: Adds metadata to a single tag. summary: A short summary of what the operation does. description: A verbose explanation of the operation behavior. @@ -390,6 +479,7 @@ def delete( security: A declaration of which security mechanisms can be used for this operation. servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. + request_body: Advanced configuration in OpenAPI. responses: API responses should be either a subclass of BaseModel, a dictionary, or None. doc_ui: Declares this operation to be shown. Default to True. """ @@ -413,7 +503,7 @@ def decorator(func) -> Callable: method=HTTPMethod.DELETE, ) endpoint = create_endpoint(func, header, cookie, path, query, form, body) - self.add_route(rule, endpoint, methods=["DELETE"], name=name, include_in_schema=False) + self._add_route(rule, endpoint, methods=["DELETE"], name=name, include_in_schema=False) return func @@ -443,6 +533,7 @@ def patch( Args: rule: The URL rule string. + name: The URL name string. tags: Adds metadata to a single tag. summary: A short summary of what the operation does. description: A verbose explanation of the operation behavior. @@ -452,6 +543,7 @@ def patch( security: A declaration of which security mechanisms can be used for this operation. servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. + request_body: Advanced configuration in OpenAPI. responses: API responses should be either a subclass of BaseModel, a dictionary, or None. doc_ui: Declares this operation to be shown. Default to True. """ @@ -475,7 +567,7 @@ def decorator(func) -> Callable: method=HTTPMethod.PATCH, ) endpoint = create_endpoint(func, header, cookie, path, query, form, body) - self.add_route(rule, endpoint, methods=["PATCH"], name=name, include_in_schema=False) + self._add_route(rule, endpoint, methods=["PATCH"], name=name, include_in_schema=False) return func @@ -488,7 +580,7 @@ def websocket( name: str | None = None, ): def decorator(func) -> Callable: - self.add_websocket_route( + self._add_websocket_route( rule, func, name=name, diff --git a/tests/test_api_router.py b/tests/test_api_router.py index 05c7e1f..1d4b79f 100644 --- a/tests/test_api_router.py +++ b/tests/test_api_router.py @@ -48,7 +48,7 @@ async def delete_book(path: IdModel): return JSONResponse({"id": path.id}) -@api2.get("/") +@api2.get("") async def get_api2(): return JSONResponse({"message": "Hello World2"}) diff --git a/tests/test_websocket.py b/tests/test_websocket.py index fa59a3f..004d3c5 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -6,7 +6,7 @@ app = OpenAPI() -api = APIRouter(url_prefix="/test") +api = APIRouter(url_prefix="/test/ws") client = TestClient(app) @@ -19,7 +19,7 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.close() -@api.websocket("/ws") +@api.websocket("") async def websocket_endpoint_with_api_router(websocket: WebSocket): await websocket.accept() data = await websocket.receive_text()