Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions star_openapi/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pydantic import BaseModel
from starlette.requests import Request
from starlette.responses import Response
from starlette.websockets import WebSocket

from .request import _validate_request

Expand Down Expand Up @@ -37,3 +38,11 @@ async def endpoint(request: Request) -> Response:
return func(**kwargs)

return endpoint


def create_websocket_endpoint(func):
@wraps(func)
async def endpoint(websocket: WebSocket):
return await func(websocket)

return endpoint
10 changes: 4 additions & 6 deletions star_openapi/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from .cli import cli
from .config import Config
from .endpoint import create_endpoint
from .endpoint import create_endpoint, create_websocket_endpoint
from .models import (
OPENAPI3_REF_PREFIX,
Components,
Expand Down Expand Up @@ -684,11 +684,9 @@ def websocket(
name: str | None = None,
):
def decorator(func) -> Callable:
self.add_websocket_route(
rule,
func,
name=name,
)
endpoint = create_websocket_endpoint(func)
self.add_websocket_route(rule, endpoint, name=name)

return func

return decorator
10 changes: 4 additions & 6 deletions star_openapi/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from starlette.routing import Request, Response, Route, Router, WebSocketRoute
from starlette.websockets import WebSocket

from .endpoint import create_endpoint
from .endpoint import create_endpoint, create_websocket_endpoint
from .models import ExternalDocumentation, RequestBody, Server, Tag
from .types import ParametersTuple, ResponseDict
from .utils import (
Expand Down Expand Up @@ -580,11 +580,9 @@ def websocket(
name: str | None = None,
):
def decorator(func) -> Callable:
self._add_websocket_route(
rule,
func,
name=name,
)
endpoint = create_websocket_endpoint(func)
self._add_websocket_route(rule, endpoint, name=name)

return func

return decorator