diff --git a/star_openapi/endpoint.py b/star_openapi/endpoint.py index cd41407..651c164 100644 --- a/star_openapi/endpoint.py +++ b/star_openapi/endpoint.py @@ -5,6 +5,7 @@ from pydantic import BaseModel from starlette.requests import Request from starlette.responses import Response +from starlette.websockets import WebSocket from .request import _validate_request @@ -37,3 +38,11 @@ async def endpoint(request: Request) -> Response: return func(**kwargs) return endpoint + + +def create_websocket_endpoint(func): + @wraps(func) + async def endpoint(websocket: WebSocket): + return await func(websocket) + + return endpoint diff --git a/star_openapi/openapi.py b/star_openapi/openapi.py index 7f01d38..5af1142 100644 --- a/star_openapi/openapi.py +++ b/star_openapi/openapi.py @@ -13,7 +13,7 @@ from .cli import cli from .config import Config -from .endpoint import create_endpoint +from .endpoint import create_endpoint, create_websocket_endpoint from .models import ( OPENAPI3_REF_PREFIX, Components, @@ -684,11 +684,9 @@ def websocket( name: str | None = None, ): def decorator(func) -> Callable: - self.add_websocket_route( - rule, - func, - name=name, - ) + endpoint = create_websocket_endpoint(func) + self.add_websocket_route(rule, endpoint, name=name) + return func return decorator diff --git a/star_openapi/router.py b/star_openapi/router.py index d33a065..056e3db 100644 --- a/star_openapi/router.py +++ b/star_openapi/router.py @@ -7,7 +7,7 @@ from starlette.routing import Request, Response, Route, Router, WebSocketRoute from starlette.websockets import WebSocket -from .endpoint import create_endpoint +from .endpoint import create_endpoint, create_websocket_endpoint from .models import ExternalDocumentation, RequestBody, Server, Tag from .types import ParametersTuple, ResponseDict from .utils import ( @@ -580,11 +580,9 @@ def websocket( name: str | None = None, ): def decorator(func) -> Callable: - self._add_websocket_route( - rule, - func, - name=name, - ) + endpoint = create_websocket_endpoint(func) + self._add_websocket_route(rule, endpoint, name=name) + return func return decorator