-
-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Add per-connection threadpool websocket callback executor. #3826
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
8a76f8c
1f0af4c
4f896d0
053212a
17b097a
abbb592
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,7 +5,7 @@ | |
| import concurrent.futures | ||
| import json | ||
| import queue | ||
| from typing import TYPE_CHECKING, Any, Callable, Dict | ||
| from typing import TYPE_CHECKING, Any, Callable, Dict, List | ||
| import sys | ||
| import mimetypes | ||
| import hashlib | ||
|
|
@@ -46,9 +46,9 @@ | |
| DashWebsocketCallback, | ||
| run_ws_sender, | ||
| run_callback_in_executor, | ||
| run_callback_on_loop, | ||
| make_callback_done_handler, | ||
| SHUTDOWN_SIGNAL, | ||
| DISCONNECTED, | ||
| shutdown_ws_connection, | ||
| ) | ||
| from ._utils import format_traceback_html | ||
|
|
||
|
|
@@ -263,8 +263,8 @@ def __init__(self, server: FastAPI): | |
| self.error_handling_mode = "ignore" | ||
| self.request_adapter = FastAPIRequestAdapter | ||
| self.response_adapter = FastAPIResponseAdapter | ||
| self._before_request_funcs = [] | ||
| self._after_request_func = None | ||
| self._before_request_funcs: List[Callable[[], Any]] = [] | ||
| self._after_request_func: Callable[[], Any] | None = None | ||
| self._enable_timing = False | ||
|
|
||
| def __call__(self, *args: Any, **kwargs: Any): | ||
|
|
@@ -685,7 +685,7 @@ def serve_websocket_callback(self, dash_app: "Dash"): | |
| Args: | ||
| dash_app: The Dash application instance | ||
| """ | ||
| # pylint: disable=too-many-statements,too-many-locals | ||
| # pylint: disable=too-many-statements,too-many-locals,too-many-branches | ||
| ws_path = dash_app.config.routes_pathname_prefix + "_dash-ws-callback" | ||
|
|
||
| # Get allowed origins from dash app config | ||
|
|
@@ -729,16 +729,28 @@ async def websocket_handler(websocket: WebSocket): | |
|
|
||
| await websocket.accept() | ||
|
|
||
| # The connection's event loop, used to dispatch async callbacks as | ||
| # tasks and to resolve get_props futures. | ||
| loop = asyncio.get_running_loop() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Another option (maybe in tandem) would be to add a watchdog thread outside of the pool that could keep an eye out for bottlenecked callbacks. |
||
|
|
||
| # Create janus queue for outbound messages (main loop context) | ||
| outbound_queue: janus.Queue[str] = janus.Queue() | ||
| # Track pending get_props requests with standard queue.Queue for responses | ||
| pending_get_props: Dict[str, queue.Queue] = {} | ||
| # Track pending get_props requests. Values are queue.Queue (threadpool / | ||
| # sync path) or asyncio.Future (event-loop / async path). | ||
| pending_get_props: Dict[str, queue.Queue | asyncio.Future] = {} | ||
| # Shutdown event to signal connection closure to worker threads | ||
| shutdown_event = threading.Event() | ||
| # Get thread pool executor | ||
| executor = self.get_callback_executor() | ||
| # Track pending callback futures | ||
| pending_callbacks: Dict[str, concurrent.futures.Future] = {} | ||
| # Sync callbacks run on a shared thread pool executor (async callbacks | ||
| # run directly on the event loop). A single bounded pool caps the total | ||
| # worker-thread count regardless of how many connections are open. | ||
| executor = self.get_callback_executor( | ||
| getattr(dash_app, "_websocket_max_workers", 4) | ||
| ) | ||
| # Track pending callbacks: concurrent.futures.Future (sync/threadpool) | ||
| # or asyncio.Task (async/event-loop). | ||
| pending_callbacks: Dict[ | ||
| str, concurrent.futures.Future | asyncio.Future | ||
| ] = {} | ||
|
|
||
| # Start sender task to drain outbound queue (sends pre-serialized text) | ||
| # pylint: disable=protected-access | ||
|
|
@@ -777,66 +789,78 @@ async def websocket_handler(websocket: WebSocket): | |
| dash_app._websocket_callbacks, | ||
| ) | ||
|
|
||
| # Create WebSocket callback instance | ||
| # Async callbacks (incl. persistent ones) run as tasks on | ||
| # the event loop; sync callbacks go to the threadpool. | ||
| # pylint: disable=protected-access | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you can still remove these stale pylint comments. |
||
| cb_spec = dash_app.callback_map.get(payload.get("output"), {}) | ||
| is_async = inspect.iscoroutinefunction(cb_spec.get("callback")) | ||
|
|
||
| # Create WebSocket callback instance. The loop is passed only | ||
| # for the async path so get_prop awaits instead of blocking. | ||
| ws_cb = DashWebsocketCallback( | ||
| pending_get_props, | ||
| renderer_id, | ||
| outbound_queue, | ||
| shutdown_event, | ||
| loop if is_async else None, | ||
| ) | ||
|
|
||
| # Submit callback to executor | ||
| future = run_callback_in_executor( | ||
| executor, | ||
| dash_app, | ||
| payload, | ||
| ws_cb, | ||
| FastAPIResponseAdapter(), | ||
| done_handler = make_callback_done_handler( | ||
| outbound_queue, | ||
| pending_callbacks, | ||
| request_id, | ||
| renderer_id, | ||
| shutdown_event, | ||
| ) | ||
|
|
||
| # Set up done callback to send response | ||
| future.add_done_callback( | ||
| make_callback_done_handler( | ||
| outbound_queue, | ||
| pending_callbacks, | ||
| request_id, | ||
| renderer_id, | ||
| shutdown_event, | ||
| if is_async: | ||
| task = asyncio.create_task( | ||
| run_callback_on_loop( | ||
| dash_app, | ||
| payload, | ||
| ws_cb, | ||
| FastAPIResponseAdapter(), | ||
| ) | ||
| ) | ||
| ) | ||
| pending_callbacks[request_id] = future | ||
| task.add_done_callback(done_handler) | ||
| pending_callbacks[request_id] = task | ||
| else: | ||
| # Submit callback to executor | ||
| future = run_callback_in_executor( | ||
| executor, | ||
| dash_app, | ||
| payload, | ||
| ws_cb, | ||
| FastAPIResponseAdapter(), | ||
| ) | ||
| # Set up done callback to send response | ||
| future.add_done_callback(done_handler) | ||
| pending_callbacks[request_id] = future | ||
|
|
||
| elif msg_type == "get_props_response": | ||
| # Put response in waiting queue (non-blocking) | ||
| # Resolve the waiting future (async path) or queue (thread | ||
| # path) for this request (non-blocking). | ||
| request_id = message.get("requestId") | ||
| response_queue = pending_get_props.get(request_id) | ||
| if response_queue is not None: | ||
| response_queue.put_nowait(message.get("payload")) | ||
| pending = pending_get_props.get(request_id) | ||
| if isinstance(pending, asyncio.Future): | ||
| if not pending.done(): | ||
| pending.set_result(message.get("payload")) | ||
| elif pending is not None: | ||
| pending.put_nowait(message.get("payload")) | ||
|
|
||
| elif msg_type == "heartbeat": | ||
| outbound_queue.sync_q.put_nowait('{"type": "heartbeat_ack"}') | ||
|
|
||
| except WebSocketDisconnect: | ||
| pass # Clean disconnect | ||
| finally: | ||
| # Signal shutdown to worker threads | ||
| shutdown_event.set() | ||
| # Unblock any threads waiting on get_prop responses | ||
| for response_queue in pending_get_props.values(): | ||
| response_queue.put_nowait(DISCONNECTED) | ||
| # Signal sender to shutdown and cancel it | ||
| outbound_queue.sync_q.put_nowait(SHUTDOWN_SIGNAL) | ||
| sender_task.cancel() | ||
| try: | ||
| await sender_task | ||
| except asyncio.CancelledError: | ||
| pass | ||
| # Close the janus queue | ||
| outbound_queue.close() | ||
| await outbound_queue.wait_closed() | ||
| # Cancel any pending futures | ||
| for f in pending_callbacks.values(): | ||
| f.cancel() | ||
| await shutdown_ws_connection( | ||
| shutdown_event, | ||
| pending_get_props, | ||
| pending_callbacks, | ||
| outbound_queue, | ||
| sender_task, | ||
| ) | ||
|
|
||
| self.server.add_api_websocket_route(ws_path, websocket_handler) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you think of enabling debug mode for the loop here (and in a similar place for Quart)? You could use the existing debug config variable to turn this on. This way, users would get warning when a callback is executing slowly. You could also add a way for users to configure
loop.slow_callback_duration.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea, maybe as a follow up.