Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion examples/api_router_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,16 @@ async def delete_book(path: IdModel):
return JSONResponse({"id": path.id})


@api2.get("/")
@api2.get("")
async def get_api2():
return JSONResponse({"message": "Hello World2"})


@api2.websocket("")
async def api2_websocket():
return JSONResponse({"message": "Hello World2"})


api1.register_api(api2)

app.register_api(api1)
Expand Down
21 changes: 15 additions & 6 deletions star_openapi/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pydantic import BaseModel, ValidationError
from starlette.applications import Starlette
from starlette.responses import HTMLResponse, JSONResponse
from starlette.routing import Mount, Route, WebSocketRoute
from starlette.routing import Mount, Route

from .cli import cli
from .config import Config
Expand All @@ -27,7 +27,7 @@
Tag,
ValidationErrorModel,
)
from .router import APIRouter
from .router import APIRoute, APIRouter, APIWebSocketRoute
from .templates import openapi_html_string
from .types import ParametersTuple, ResponseDict
from .utils import (
Expand Down Expand Up @@ -272,16 +272,16 @@ def register_api(self, api: APIRouter):

# Register the APIRouter with the current instance
for route in api.routes:
if isinstance(route, Route):
path_with_prefix = api.url_prefix + route.path
if isinstance(route, APIRoute):
path_with_prefix = api.url_prefix + route.origin_path
self.router.add_route(
path=path_with_prefix,
endpoint=route.endpoint,
methods=route.methods,
name=route.name,
)
elif isinstance(route, WebSocketRoute):
path_with_prefix = api.url_prefix + route.path
elif isinstance(route, APIWebSocketRoute):
path_with_prefix = api.url_prefix + route.origin_path
self.router.add_websocket_route(path=path_with_prefix, endpoint=route.endpoint, name=route.name)

def _collect_openapi_info(
Expand Down Expand Up @@ -383,6 +383,7 @@ def get(

Args:
rule: The URL rule string.
name: The URL name string.
tags: Adds metadata to a single tag.
summary: A short summary of what the operation does.
description: A verbose explanation of the operation behavior.
Expand Down Expand Up @@ -444,6 +445,7 @@ def post(

Args:
rule: The URL rule string.
name: The URL name string.
tags: Adds metadata to a single tag.
summary: A short summary of what the operation does.
description: A verbose explanation of the operation behavior.
Expand All @@ -453,6 +455,7 @@ def post(
security: A declaration of which security mechanisms can be used for this operation.
servers: An alternative server array to service this operation.
openapi_extensions: Allows extensions to the OpenAPI Schema.
request_body: Advanced configuration in OpenAPI.
responses: API responses should be either a subclass of BaseModel, a dictionary, or None.
doc_ui: Declares this operation to be shown. Default to True.
"""
Expand Down Expand Up @@ -506,6 +509,7 @@ def put(

Args:
rule: The URL rule string.
name: The URL name string.
tags: Adds metadata to a single tag.
summary: A short summary of what the operation does.
description: A verbose explanation of the operation behavior.
Expand All @@ -515,6 +519,7 @@ def put(
security: A declaration of which security mechanisms can be used for this operation.
servers: An alternative server array to service this operation.
openapi_extensions: Allows extensions to the OpenAPI Schema.
request_body: Advanced configuration in OpenAPI.
responses: API responses should be either a subclass of BaseModel, a dictionary, or None.
doc_ui: Declares this operation to be shown. Default to True.
"""
Expand Down Expand Up @@ -568,6 +573,7 @@ def delete(

Args:
rule: The URL rule string.
name: The URL name string.
tags: Adds metadata to a single tag.
summary: A short summary of what the operation does.
description: A verbose explanation of the operation behavior.
Expand All @@ -577,6 +583,7 @@ def delete(
security: A declaration of which security mechanisms can be used for this operation.
servers: An alternative server array to service this operation.
openapi_extensions: Allows extensions to the OpenAPI Schema.
request_body: Advanced configuration in OpenAPI.
responses: API responses should be either a subclass of BaseModel, a dictionary, or None.
doc_ui: Declares this operation to be shown. Default to True.
"""
Expand Down Expand Up @@ -630,6 +637,7 @@ def patch(

Args:
rule: The URL rule string.
name: The URL name string.
tags: Adds metadata to a single tag.
summary: A short summary of what the operation does.
description: A verbose explanation of the operation behavior.
Expand All @@ -639,6 +647,7 @@ def patch(
security: A declaration of which security mechanisms can be used for this operation.
servers: An alternative server array to service this operation.
openapi_extensions: Allows extensions to the OpenAPI Schema.
request_body: Advanced configuration in OpenAPI.
responses: API responses should be either a subclass of BaseModel, a dictionary, or None.
doc_ui: Declares this operation to be shown. Default to True.
"""
Expand Down
122 changes: 107 additions & 15 deletions star_openapi/router.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from collections.abc import Callable
from collections.abc import Awaitable, Callable, Collection
from http import HTTPMethod
from types import FunctionType
from typing import Any
from typing import Any, Sequence

from starlette.routing import Route, Router, WebSocketRoute
from starlette.middleware import Middleware
from starlette.routing import Request, Response, Route, Router, WebSocketRoute
from starlette.websockets import WebSocket

from .endpoint import create_endpoint
from .models import ExternalDocumentation, RequestBody, Server, Tag
Expand All @@ -20,6 +22,50 @@
)


class APIRoute(Route):
def __init__(
self,
path: str,
origin_path: str,
endpoint: Callable[..., Any],
*,
methods: Collection[str] | None = None,
name: str | None = None,
include_in_schema: bool = True,
middleware: Sequence[Middleware] | None = None,
):
super().__init__(
path=path,
endpoint=endpoint,
methods=methods,
name=name,
include_in_schema=include_in_schema,
middleware=middleware,
)

self.origin_path = origin_path


class APIWebSocketRoute(WebSocketRoute):
def __init__(
self,
path: str,
origin_path: str,
endpoint: Callable[..., Any],
*,
name: str | None = None,
middleware: Sequence[Middleware] | None = None,
):
super().__init__(
path=path,
endpoint=endpoint,
name=name,
middleware=middleware,
)

self.origin_path = origin_path


class APIRouter(Router):
def __init__(
self,
Expand Down Expand Up @@ -79,17 +125,17 @@ def register_api(self, api: "APIRouter"):

# Register the APIRouter with the current instance
for route in api.routes:
if isinstance(route, Route):
path_with_prefix = api.url_prefix + route.path
self.add_route(
if isinstance(route, APIRoute):
path_with_prefix = api.url_prefix + route.origin_path
self._add_route(
path=path_with_prefix,
endpoint=route.endpoint,
methods=route.methods,
name=route.name,
)
elif isinstance(route, WebSocketRoute):
path_with_prefix = api.url_prefix + route.path
self.add_websocket_route(path=path_with_prefix, endpoint=route.endpoint, name=route.name)
elif isinstance(route, APIWebSocketRoute):
path_with_prefix = api.url_prefix + route.origin_path
self._add_websocket_route(path=path_with_prefix, endpoint=route.endpoint, name=route.name)

def _collect_openapi_info(
self,
Expand Down Expand Up @@ -173,6 +219,43 @@ def _collect_openapi_info(
else:
return parse_parameters(func, doc_ui=False)

def _add_route(
self,
path: str,
endpoint: Callable[[Request], Awaitable[Response] | Response],
methods: Collection[str] | None = None,
name: str | None = None,
include_in_schema: bool = True,
) -> None:
if not path.startswith("/"):
origin_path = path
path = "/" + path
else:
origin_path = path
route = APIRoute(
path=path,
origin_path=origin_path,
endpoint=endpoint,
methods=methods,
name=name,
include_in_schema=include_in_schema,
)
self.routes.append(route)

def _add_websocket_route(
self,
path: str,
endpoint: Callable[[WebSocket], Awaitable[None]],
name: str | None = None,
) -> None:
if not path.startswith("/"):
origin_path = path
path = "/" + path
else:
origin_path = path
route = APIWebSocketRoute(path=path, origin_path=origin_path, endpoint=endpoint, name=name)
self.routes.append(route)

def get(
self,
rule: str,
Expand All @@ -196,6 +279,7 @@ def get(

Args:
rule: The URL rule string.
name: The URL name string.
tags: Adds metadata to a single tag.
summary: A short summary of what the operation does.
description: A verbose explanation of the operation behavior.
Expand Down Expand Up @@ -227,7 +311,7 @@ def decorator(func) -> Callable:
method=HTTPMethod.GET,
)
endpoint = create_endpoint(func, header, cookie, path, query, form, body)
self.add_route(rule, endpoint, methods=["GET"], name=name, include_in_schema=False)
self._add_route(rule, endpoint, methods=["GET"], name=name, include_in_schema=False)

return func

Expand Down Expand Up @@ -257,6 +341,7 @@ def post(

Args:
rule: The URL rule string.
name: The URL name string.
tags: Adds metadata to a single tag.
summary: A short summary of what the operation does.
description: A verbose explanation of the operation behavior.
Expand All @@ -266,6 +351,7 @@ def post(
security: A declaration of which security mechanisms can be used for this operation.
servers: An alternative server array to service this operation.
openapi_extensions: Allows extensions to the OpenAPI Schema.
request_body: Advanced configuration in OpenAPI.
responses: API responses should be either a subclass of BaseModel, a dictionary, or None.
doc_ui: Declares this operation to be shown. Default to True.
"""
Expand All @@ -289,7 +375,7 @@ def decorator(func) -> Callable:
method=HTTPMethod.POST,
)
endpoint = create_endpoint(func, header, cookie, path, query, form, body)
self.add_route(rule, endpoint, methods=["POST"], name=name, include_in_schema=False)
self._add_route(rule, endpoint, methods=["POST"], name=name, include_in_schema=False)

return func

Expand Down Expand Up @@ -319,6 +405,7 @@ def put(

Args:
rule: The URL rule string.
name: The URL name string.
tags: Adds metadata to a single tag.
summary: A short summary of what the operation does.
description: A verbose explanation of the operation behavior.
Expand All @@ -328,6 +415,7 @@ def put(
security: A declaration of which security mechanisms can be used for this operation.
servers: An alternative server array to service this operation.
openapi_extensions: Allows extensions to the OpenAPI Schema.
request_body: Advanced configuration in OpenAPI.
responses: API responses should be either a subclass of BaseModel, a dictionary, or None.
doc_ui: Declares this operation to be shown. Default to True.
"""
Expand All @@ -351,7 +439,7 @@ def decorator(func) -> Callable:
method=HTTPMethod.PUT,
)
endpoint = create_endpoint(func, header, cookie, path, query, form, body)
self.add_route(rule, endpoint, methods=["PUT"], name=name, include_in_schema=False)
self._add_route(rule, endpoint, methods=["PUT"], name=name, include_in_schema=False)

return func

Expand Down Expand Up @@ -381,6 +469,7 @@ def delete(

Args:
rule: The URL rule string.
name: The URL name string.
tags: Adds metadata to a single tag.
summary: A short summary of what the operation does.
description: A verbose explanation of the operation behavior.
Expand All @@ -390,6 +479,7 @@ def delete(
security: A declaration of which security mechanisms can be used for this operation.
servers: An alternative server array to service this operation.
openapi_extensions: Allows extensions to the OpenAPI Schema.
request_body: Advanced configuration in OpenAPI.
responses: API responses should be either a subclass of BaseModel, a dictionary, or None.
doc_ui: Declares this operation to be shown. Default to True.
"""
Expand All @@ -413,7 +503,7 @@ def decorator(func) -> Callable:
method=HTTPMethod.DELETE,
)
endpoint = create_endpoint(func, header, cookie, path, query, form, body)
self.add_route(rule, endpoint, methods=["DELETE"], name=name, include_in_schema=False)
self._add_route(rule, endpoint, methods=["DELETE"], name=name, include_in_schema=False)

return func

Expand Down Expand Up @@ -443,6 +533,7 @@ def patch(

Args:
rule: The URL rule string.
name: The URL name string.
tags: Adds metadata to a single tag.
summary: A short summary of what the operation does.
description: A verbose explanation of the operation behavior.
Expand All @@ -452,6 +543,7 @@ def patch(
security: A declaration of which security mechanisms can be used for this operation.
servers: An alternative server array to service this operation.
openapi_extensions: Allows extensions to the OpenAPI Schema.
request_body: Advanced configuration in OpenAPI.
responses: API responses should be either a subclass of BaseModel, a dictionary, or None.
doc_ui: Declares this operation to be shown. Default to True.
"""
Expand All @@ -475,7 +567,7 @@ def decorator(func) -> Callable:
method=HTTPMethod.PATCH,
)
endpoint = create_endpoint(func, header, cookie, path, query, form, body)
self.add_route(rule, endpoint, methods=["PATCH"], name=name, include_in_schema=False)
self._add_route(rule, endpoint, methods=["PATCH"], name=name, include_in_schema=False)

return func

Expand All @@ -488,7 +580,7 @@ def websocket(
name: str | None = None,
):
def decorator(func) -> Callable:
self.add_websocket_route(
self._add_websocket_route(
rule,
func,
name=name,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_api_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def delete_book(path: IdModel):
return JSONResponse({"id": path.id})


@api2.get("/")
@api2.get("")
async def get_api2():
return JSONResponse({"message": "Hello World2"})

Expand Down
Loading