diff --git a/CHANGELOG.md b/CHANGELOG.md index 7f0d7c9abc..65589064a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,9 @@ This project adheres to [Semantic Versioning](https://semver.org/). ## [UNRELEASED] +### Added +- [#3826](https://github.com/plotly/dash/pull/3826) WebSocket callback dispatch no longer lets long-lived callbacks limit the number of concurrent users. Async callbacks (including session-persistent ones) run directly on the connection event loop instead of occupying a worker thread, and synchronous callbacks run on a shared `ThreadPoolExecutor` whose size is configurable via the new `websocket_max_workers` argument to `Dash` (default `4`). A synchronous persistent (no-output) callback now warns at registration since it would tie up a worker thread. + ## Fixed - [#3822](https://github.com/plotly/dash/pull/3822) Fix `UnboundLocalError` for `user_callback_output` in async background callbacks (Celery and Diskcache managers) when the callback raises `PreventUpdate` or another exception before the variable is assigned. - [#3819](https://github.com/plotly/dash/pull/3819) Fix `RuntimeError: No active request in context` when a non-Dash path falls through to the FastAPI catch-all route. Fixes [#3812](https://github.com/plotly/dash/issues/3812). diff --git a/dash/_callback.py b/dash/_callback.py index cf34172a80..81a5345830 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -1,6 +1,7 @@ import collections import hashlib import inspect +import warnings from functools import wraps from typing import Callable, Optional, Any, List, Tuple, Union, Dict, TypeVar, cast @@ -880,6 +881,19 @@ async def async_add_context(*args, **kwargs): if inspect.iscoroutinefunction(func): callback_map[callback_id]["callback"] = async_add_context else: + # A persistent, no-output callback streams via set_props and typically + # runs for the life of the connection. When synchronous it occupies a + # WebSocket worker thread the whole time and can exhaust the pool, so + # warn that it should be async (async callbacks run on the event loop). + if _kwargs.get("persistent") and not has_output: + warnings.warn( + f"persistent=True callback '{callback_id}' is synchronous and " + "has no Output; it will occupy a WebSocket worker thread for the " + "life of the connection and can exhaust the pool. Define it with " + "'async def' so it runs on the event loop instead.", + RuntimeWarning, + stacklevel=2, + ) callback_map[callback_id]["callback"] = add_context return func diff --git a/dash/_callback_context.py b/dash/_callback_context.py index 1bcf235036..b4285105f5 100644 --- a/dash/_callback_context.py +++ b/dash/_callback_context.py @@ -1,4 +1,3 @@ -import asyncio import functools import warnings import json @@ -367,20 +366,13 @@ def set_props(component_id: typing.Union[str, dict], props: dict): """ ws = _get_from_context("dash_websocket", None) if ws is not None: - # Stream immediately via WebSocket + # Stream immediately via WebSocket. Queuing is synchronous and thread-safe + # (janus sync side), so we queue directly instead of scheduling a task. This + # avoids detached/orphaned tasks when the callback runs on the event loop and + # preserves ordering relative to the callback response. _id = stringify_id(component_id) - - async def _send_props(): - for prop_name, value in props.items(): - await ws.set_prop(_id, prop_name, value) - - # If we're in an async context, schedule the coroutine - try: - asyncio.get_running_loop() - asyncio.ensure_future(_send_props()) - except RuntimeError: - # No running event loop - run synchronously - asyncio.run(_send_props()) + for prop_name, value in props.items(): + ws.set_prop_sync(_id, prop_name, value) else: # Batch for response (existing behavior) callback_context.set_props(component_id, props) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 4caf45e172..6a23c4873b 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -5,7 +5,7 @@ import concurrent.futures import json import queue -from typing import TYPE_CHECKING, Any, Callable, Dict +from typing import TYPE_CHECKING, Any, Callable, Dict, List import sys import mimetypes import hashlib @@ -46,9 +46,9 @@ DashWebsocketCallback, run_ws_sender, run_callback_in_executor, + run_callback_on_loop, make_callback_done_handler, - SHUTDOWN_SIGNAL, - DISCONNECTED, + shutdown_ws_connection, ) from ._utils import format_traceback_html @@ -263,8 +263,8 @@ def __init__(self, server: FastAPI): self.error_handling_mode = "ignore" self.request_adapter = FastAPIRequestAdapter self.response_adapter = FastAPIResponseAdapter - self._before_request_funcs = [] - self._after_request_func = None + self._before_request_funcs: List[Callable[[], Any]] = [] + self._after_request_func: Callable[[], Any] | None = None self._enable_timing = False def __call__(self, *args: Any, **kwargs: Any): @@ -685,7 +685,7 @@ def serve_websocket_callback(self, dash_app: "Dash"): Args: dash_app: The Dash application instance """ - # pylint: disable=too-many-statements,too-many-locals + # pylint: disable=too-many-statements,too-many-locals,too-many-branches ws_path = dash_app.config.routes_pathname_prefix + "_dash-ws-callback" # Get allowed origins from dash app config @@ -729,16 +729,28 @@ async def websocket_handler(websocket: WebSocket): await websocket.accept() + # The connection's event loop, used to dispatch async callbacks as + # tasks and to resolve get_props futures. + loop = asyncio.get_running_loop() + # Create janus queue for outbound messages (main loop context) outbound_queue: janus.Queue[str] = janus.Queue() - # Track pending get_props requests with standard queue.Queue for responses - pending_get_props: Dict[str, queue.Queue] = {} + # Track pending get_props requests. Values are queue.Queue (threadpool / + # sync path) or asyncio.Future (event-loop / async path). + pending_get_props: Dict[str, queue.Queue | asyncio.Future] = {} # Shutdown event to signal connection closure to worker threads shutdown_event = threading.Event() - # Get thread pool executor - executor = self.get_callback_executor() - # Track pending callback futures - pending_callbacks: Dict[str, concurrent.futures.Future] = {} + # Sync callbacks run on a shared thread pool executor (async callbacks + # run directly on the event loop). A single bounded pool caps the total + # worker-thread count regardless of how many connections are open. + executor = self.get_callback_executor( + getattr(dash_app, "_websocket_max_workers", 4) + ) + # Track pending callbacks: concurrent.futures.Future (sync/threadpool) + # or asyncio.Task (async/event-loop). + pending_callbacks: Dict[ + str, concurrent.futures.Future | asyncio.Future + ] = {} # Start sender task to drain outbound queue (sends pre-serialized text) # pylint: disable=protected-access @@ -777,41 +789,64 @@ async def websocket_handler(websocket: WebSocket): dash_app._websocket_callbacks, ) - # Create WebSocket callback instance + # Async callbacks (incl. persistent ones) run as tasks on + # the event loop; sync callbacks go to the threadpool. + # pylint: disable=protected-access + cb_spec = dash_app.callback_map.get(payload.get("output"), {}) + is_async = inspect.iscoroutinefunction(cb_spec.get("callback")) + + # Create WebSocket callback instance. The loop is passed only + # for the async path so get_prop awaits instead of blocking. ws_cb = DashWebsocketCallback( pending_get_props, renderer_id, outbound_queue, shutdown_event, + loop if is_async else None, ) - # Submit callback to executor - future = run_callback_in_executor( - executor, - dash_app, - payload, - ws_cb, - FastAPIResponseAdapter(), + done_handler = make_callback_done_handler( + outbound_queue, + pending_callbacks, + request_id, + renderer_id, + shutdown_event, ) - # Set up done callback to send response - future.add_done_callback( - make_callback_done_handler( - outbound_queue, - pending_callbacks, - request_id, - renderer_id, - shutdown_event, + if is_async: + task = asyncio.create_task( + run_callback_on_loop( + dash_app, + payload, + ws_cb, + FastAPIResponseAdapter(), + ) ) - ) - pending_callbacks[request_id] = future + task.add_done_callback(done_handler) + pending_callbacks[request_id] = task + else: + # Submit callback to executor + future = run_callback_in_executor( + executor, + dash_app, + payload, + ws_cb, + FastAPIResponseAdapter(), + ) + # Set up done callback to send response + future.add_done_callback(done_handler) + pending_callbacks[request_id] = future elif msg_type == "get_props_response": - # Put response in waiting queue (non-blocking) + # Resolve the waiting future (async path) or queue (thread + # path) for this request (non-blocking). request_id = message.get("requestId") - response_queue = pending_get_props.get(request_id) - if response_queue is not None: - response_queue.put_nowait(message.get("payload")) + pending = pending_get_props.get(request_id) + if isinstance(pending, asyncio.Future): + if not pending.done(): + pending.set_result(message.get("payload")) + elif pending is not None: + pending.put_nowait(message.get("payload")) elif msg_type == "heartbeat": outbound_queue.sync_q.put_nowait('{"type": "heartbeat_ack"}') @@ -819,24 +854,13 @@ async def websocket_handler(websocket: WebSocket): except WebSocketDisconnect: pass # Clean disconnect finally: - # Signal shutdown to worker threads - shutdown_event.set() - # Unblock any threads waiting on get_prop responses - for response_queue in pending_get_props.values(): - response_queue.put_nowait(DISCONNECTED) - # Signal sender to shutdown and cancel it - outbound_queue.sync_q.put_nowait(SHUTDOWN_SIGNAL) - sender_task.cancel() - try: - await sender_task - except asyncio.CancelledError: - pass - # Close the janus queue - outbound_queue.close() - await outbound_queue.wait_closed() - # Cancel any pending futures - for f in pending_callbacks.values(): - f.cancel() + await shutdown_ws_connection( + shutdown_event, + pending_get_props, + pending_callbacks, + outbound_queue, + sender_task, + ) self.server.add_api_websocket_route(ws_path, websocket_handler) diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index a6d09d1e1c..dde98061d1 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -51,9 +51,9 @@ DashWebsocketCallback, run_ws_sender, run_callback_in_executor, + run_callback_on_loop, make_callback_done_handler, - SHUTDOWN_SIGNAL, - DISCONNECTED, + shutdown_ws_connection, ) from ._utils import format_traceback_html @@ -95,7 +95,7 @@ class QuartDashServer(BaseDashServer[Quart]): def __init__(self, server: Quart) -> None: super().__init__(server) self.server_type = "quart" - self.config = {} + self.config: Dict[str, Any] = {} self.error_handling_mode = "ignore" self.request_adapter = QuartRequestAdapter self.response_adapter = QuartResponseAdapter @@ -150,7 +150,7 @@ async def _wrap_errors(error): tb = self._get_traceback(secret, error) return Response(tb, status=500, content_type="text/html") - def register_timing_hooks(self, _first_run: bool): # type: ignore[name-defined] parity with Flask factory + def register_timing_hooks(self, _first_run: bool): # type: ignore[override] # parity with Flask factory @self.server.before_request async def _before_request(): # pragma: no cover - timing infra if quart_g is not None: @@ -382,7 +382,7 @@ def add_redirect_rule(self, app, fullname, path): ) # pylint: disable=unused-argument - def serve_callback(self, dash_app: Dash): # type: ignore[name-defined] Quart always async + def serve_callback(self, dash_app: Dash): # type: ignore[override] # Quart always async async def _dispatch(): adapter = QuartRequestAdapter() body = await adapter.get_json() @@ -545,6 +545,10 @@ async def websocket_handler(): # pylint: disable=too-many-branches await ws.accept() + # The connection's event loop, used to dispatch async callbacks as + # tasks and to resolve get_props futures. + loop = asyncio.get_running_loop() + # Track this connection for graceful shutdown try: ws_obj = ws._get_current_object() @@ -555,14 +559,22 @@ async def websocket_handler(): # pylint: disable=too-many-branches # Create janus queue for outbound messages (main loop context) outbound_queue: janus.Queue[str] = janus.Queue() - # Track pending get_props requests with standard queue.Queue for responses - pending_get_props: Dict[str, queue.Queue] = {} + # Track pending get_props requests. Values are queue.Queue (threadpool / + # sync path) or asyncio.Future (event-loop / async path). + pending_get_props: Dict[str, queue.Queue | asyncio.Future] = {} # Shutdown event to signal connection closure to worker threads connection_shutdown_event = threading.Event() - # Get thread pool executor - executor = self.get_callback_executor() - # Track pending callback futures - pending_callbacks: Dict[str, concurrent.futures.Future] = {} + # Sync callbacks run on a shared thread pool executor (async callbacks + # run directly on the event loop). A single bounded pool caps the total + # worker-thread count regardless of how many connections are open. + executor = self.get_callback_executor( + getattr(dash_app, "_websocket_max_workers", 4) + ) + # Track pending callbacks: concurrent.futures.Future (sync/threadpool) + # or asyncio.Task (async/event-loop). + pending_callbacks: Dict[ + str, concurrent.futures.Future | asyncio.Future + ] = {} # Start sender task to drain outbound queue (sends pre-serialized text) # pylint: disable=protected-access @@ -608,41 +620,64 @@ async def websocket_handler(): # pylint: disable=too-many-branches dash_app._websocket_callbacks, ) - # Create WebSocket callback instance + # Async callbacks (incl. persistent ones) run as tasks on + # the event loop; sync callbacks go to the threadpool. + # pylint: disable=protected-access + cb_spec = dash_app.callback_map.get(payload.get("output"), {}) + is_async = inspect.iscoroutinefunction(cb_spec.get("callback")) + + # Create WebSocket callback instance. The loop is passed only + # for the async path so get_prop awaits instead of blocking. ws_cb = DashWebsocketCallback( pending_get_props, renderer_id, outbound_queue, connection_shutdown_event, + loop if is_async else None, ) - # Submit callback to executor - future = run_callback_in_executor( - executor, - dash_app, - payload, - ws_cb, - QuartResponseAdapter(), + done_handler = make_callback_done_handler( + outbound_queue, + pending_callbacks, + request_id, + renderer_id, + connection_shutdown_event, ) - # Set up done callback to send response - future.add_done_callback( - make_callback_done_handler( - outbound_queue, - pending_callbacks, - request_id, - renderer_id, - connection_shutdown_event, + if is_async: + task = asyncio.create_task( + run_callback_on_loop( + dash_app, + payload, + ws_cb, + QuartResponseAdapter(), + ) ) - ) - pending_callbacks[request_id] = future + task.add_done_callback(done_handler) + pending_callbacks[request_id] = task + else: + # Submit callback to executor + future = run_callback_in_executor( + executor, + dash_app, + payload, + ws_cb, + QuartResponseAdapter(), + ) + # Set up done callback to send response + future.add_done_callback(done_handler) + pending_callbacks[request_id] = future elif msg_type == "get_props_response": - # Put response in waiting queue (non-blocking) + # Resolve the waiting future (async path) or queue (thread + # path) for this request (non-blocking). request_id = message.get("requestId") - response_queue = pending_get_props.get(request_id) - if response_queue is not None: - response_queue.put_nowait(message.get("payload")) + pending = pending_get_props.get(request_id) + if isinstance(pending, asyncio.Future): + if not pending.done(): + pending.set_result(message.get("payload")) + elif pending is not None: + pending.put_nowait(message.get("payload")) elif msg_type == "heartbeat": outbound_queue.sync_q.put_nowait('{"type": "heartbeat_ack"}') @@ -653,24 +688,13 @@ async def websocket_handler(): # pylint: disable=too-many-branches pass # Other exceptions treated as disconnect finally: self._active_websockets.discard(ws_obj) - # Signal shutdown to worker threads - connection_shutdown_event.set() - # Unblock any threads waiting on get_prop responses - for response_queue in pending_get_props.values(): - response_queue.put_nowait(DISCONNECTED) - # Signal sender to shutdown and cancel it - outbound_queue.sync_q.put_nowait(SHUTDOWN_SIGNAL) - sender_task.cancel() - try: - await sender_task - except asyncio.CancelledError: - pass - # Close the janus queue - outbound_queue.close() - await outbound_queue.wait_closed() - # Cancel any pending futures - for f in pending_callbacks.values(): - f.cancel() + await shutdown_ws_connection( + connection_shutdown_event, + pending_get_props, + pending_callbacks, + outbound_queue, + sender_task, + ) class QuartRequestAdapter(RequestAdapter): diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index 52443d4104..5443662dd2 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -194,7 +194,13 @@ def __init__(self, server: ServerType) -> None: def get_callback_executor( self, max_workers: int | None = None ) -> ThreadPoolExecutor: - """Get or create the thread pool executor for callback execution. + """Get or create the shared thread pool executor for sync callbacks. + + A single executor is shared across all WebSocket connections. Only + *sync* callbacks run here -- async callbacks (including session-persistent + ones) run directly on the connection event loop -- so worker threads are + released promptly and a fixed-size shared pool bounds the total thread + count regardless of how many connections are open. Args: max_workers: Maximum number of worker threads. If None, uses default. @@ -209,7 +215,7 @@ def get_callback_executor( return self._callback_executor def shutdown_executor(self, wait: bool = True) -> None: - """Shutdown the callback executor. + """Shutdown the shared callback executor. Args: wait: If True, wait for pending tasks to complete. diff --git a/dash/backends/ws.py b/dash/backends/ws.py index d784ea291a..6d0168ecda 100644 --- a/dash/backends/ws.py +++ b/dash/backends/ws.py @@ -65,24 +65,30 @@ async def long_running(n_clicks): def __init__( self, - pending_get_props: Dict[str, queue.Queue[Any]], + pending_get_props: Dict[str, Any], renderer_id: str, outbound_queue: janus.Queue[str], shutdown_event: "threading.Event", + loop: "asyncio.AbstractEventLoop | None" = None, ): """Initialize the WebSocket callback interface. Args: - pending_get_props: Dict to track pending get_props requests. - Values are queue.Queue instances for blocking response retrieval. + pending_get_props: Dict to track pending get_props requests. Values are + queue.Queue instances (blocking, thread path) or asyncio.Future + instances (awaitable, event-loop path), keyed by request_id. renderer_id: The renderer ID for routing messages back to the correct client outbound_queue: janus.Queue for thread-safe outbound messaging. shutdown_event: Event signaling the websocket connection has closed. + loop: The connection's event loop when the callback runs on the loop + (async dispatch). When set, get_prop uses an awaitable asyncio.Future + instead of blocking on a queue.Queue. None for the threadpool path. """ self._pending_get_props = pending_get_props self._renderer_id = renderer_id self._outbound_queue = outbound_queue self._shutdown_event = shutdown_event + self._loop = loop @property def is_shutdown(self) -> bool: @@ -93,7 +99,7 @@ def _get_outbound_queue(self) -> janus.Queue[str] | None: """Get the outbound queue.""" return self._outbound_queue - def _get_pending_get_props(self) -> Dict[str, queue.Queue[Any]] | None: + def _get_pending_get_props(self) -> Dict[str, Any] | None: """Get the pending_get_props dict.""" return self._pending_get_props @@ -109,10 +115,12 @@ def _queue_message(self, msg: dict) -> None: if outbound_queue is not None: outbound_queue.sync_q.put_nowait(cast(str, to_json(msg))) - async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: - """Send immediate prop update to the client via WebSocket. + def set_prop_sync(self, component_id: str, prop_name: str, value: Any) -> None: + """Queue an immediate prop update to the client (synchronous, non-blocking). - Queues the message for the sender coroutine to send. + Queuing is thread-safe (janus sync side) and never awaits, so this can be + called directly from the event loop or a worker thread without scheduling a + task. ``set_prop`` is the async wrapper kept for backward compatibility. Args: component_id: The component ID (string or stringified dict) @@ -126,12 +134,27 @@ async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: } self._queue_message(msg) + async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: + """Send immediate prop update to the client via WebSocket. + + Queues the message for the sender coroutine to send. + + Args: + component_id: The component ID (string or stringified dict) + prop_name: The property name to update + value: The new value to set + """ + self.set_prop_sync(component_id, prop_name, value) + async def get_prop( self, component_id: str, prop_name: str, timeout: float = 30.0 ) -> Any: """Request current prop value from the client. - Uses queue.Queue for blocking wait in worker thread. + On the event-loop path (``self._loop`` set, async callbacks) the wait uses an + awaitable ``asyncio.Future`` so the connection loop is never blocked. On the + threadpool path (``self._loop`` is None, sync callbacks) it blocks on a + ``queue.Queue`` in the worker thread. Args: component_id: The component ID (string or stringified dict) @@ -160,24 +183,57 @@ async def get_prop( "payload": {"componentId": component_id, "properties": [prop_name]}, } + if self._loop is not None: + result = await self._get_prop_async(request_id, msg, timeout) + else: + result = await self._get_prop_blocking(request_id, msg, timeout) + + if result == DISCONNECTED: + raise WebsocketDisconnected() + if result and prop_name in result: + return result[prop_name] + return None + + async def _get_prop_async(self, request_id: str, msg: dict, timeout: float) -> Any: + """Await a get_props response on the connection event loop.""" + pending_get_props = self._get_pending_get_props() + future: "asyncio.Future" = self._loop.create_future() # type: ignore[union-attr] + pending_get_props[request_id] = future # type: ignore[index] + + # Queue the outbound request via janus sync interface + self._queue_message(msg) + + try: + return await asyncio.wait_for(future, timeout=timeout) + except asyncio.TimeoutError as exc: + raise TimeoutError( + f"Timeout waiting for {msg['payload']['componentId']}." + f"{msg['payload']['properties'][0]}" + ) from exc + finally: + current_pending = self._get_pending_get_props() + if current_pending is not None: + current_pending.pop(request_id, None) + + async def _get_prop_blocking( + self, request_id: str, msg: dict, timeout: float + ) -> Any: + """Block on a get_props response in a worker thread (threadpool path).""" + pending_get_props = self._get_pending_get_props() # Use standard queue.Queue for response response_queue: queue.Queue = queue.Queue() - pending_get_props[request_id] = response_queue + pending_get_props[request_id] = response_queue # type: ignore[index] # Queue the outbound request via janus sync interface self._queue_message(msg) # Wait for response (blocking is OK in worker thread) try: - result = response_queue.get(timeout=timeout) - if result == DISCONNECTED: - raise WebsocketDisconnected() - if result and prop_name in result: - return result[prop_name] - return None + return response_queue.get(timeout=timeout) except queue.Empty as exc: raise TimeoutError( - f"Timeout waiting for {component_id}.{prop_name}" + f"Timeout waiting for {msg['payload']['componentId']}." + f"{msg['payload']['properties'][0]}" ) from exc finally: # Get fresh reference in case of reconnection @@ -323,14 +379,68 @@ async def _send_batched(send_text: Callable[[str], Any], messages: list) -> bool return False # WebSocketDisconnect, RuntimeError, etc. +async def shutdown_ws_connection( + shutdown_event: threading.Event, + pending_get_props: Dict[str, Any], + pending_callbacks: Dict[str, "concurrent.futures.Future | asyncio.Future"], + outbound_queue: janus.Queue[str], + sender_task: "asyncio.Task", +) -> None: + """Tear down a WebSocket connection's callback machinery. + + Shared by the FastAPI and Quart handlers so the ordering (which is + correctness-sensitive) stays in one place. Async callback tasks are cancelled + and awaited *before* the outbound queue is closed, so their cleanup can't touch + a closed queue. + + The callback executor is intentionally *not* shut down here: it is a single + pool shared across all connections, so it outlives any one connection. + + The dicts are snapshotted before iteration: threadpool done-handlers and the + blocking get_prop path pop from them on worker threads, so iterating the live + dicts here can race. + """ + # Signal shutdown to worker threads + shutdown_event.set() + # Unblock anything waiting on get_prop responses (futures or queues) + for pending in list(pending_get_props.values()): + if isinstance(pending, asyncio.Future): + if not pending.done(): + pending.set_result(DISCONNECTED) + else: + pending.put_nowait(DISCONNECTED) + callbacks = list(pending_callbacks.values()) + # Cancel running async callback tasks and let them unwind while the outbound + # queue is still open (their cleanup may touch it). + async_tasks = [f for f in callbacks if isinstance(f, asyncio.Future)] + for f in async_tasks: + f.cancel() + if async_tasks: + await asyncio.gather(*async_tasks, return_exceptions=True) + # Cancel any pending threadpool futures (running ones run to completion) + for cb in callbacks: + if not isinstance(cb, asyncio.Future): + cb.cancel() + # Signal sender to shutdown and cancel it + outbound_queue.sync_q.put_nowait(SHUTDOWN_SIGNAL) + sender_task.cancel() + try: + await sender_task + except asyncio.CancelledError: + pass + # Close the janus queue + outbound_queue.close() + await outbound_queue.wait_closed() + + def make_callback_done_handler( outbound_queue: janus.Queue[str], - pending_callbacks: Dict[str, concurrent.futures.Future], + pending_callbacks: Dict[str, "concurrent.futures.Future | asyncio.Future"], request_id: str, renderer_id: str, shutdown_event: threading.Event, -) -> Callable[[concurrent.futures.Future], None]: - """Create a done callback handler for executor futures. +) -> Callable[[Any], None]: + """Create a done callback handler for executor futures or event-loop tasks. This factory creates a callback that sends the result back through the WebSocket when an executor future completes. @@ -346,7 +456,7 @@ def make_callback_done_handler( A callback function suitable for Future.add_done_callback() """ - def on_done(f: concurrent.futures.Future) -> None: + def on_done(f: "concurrent.futures.Future | asyncio.Future") -> None: try: if shutdown_event.is_set(): return @@ -364,6 +474,11 @@ def on_done(f: concurrent.futures.Future) -> None: ), ) ) + except asyncio.CancelledError: + # Task cancelled (e.g. on disconnect). CancelledError is a + # BaseException, so it is not caught by the broad except below; + # return here and let the finally clause pop the pending entry. + return except Exception as e: # pylint: disable=broad-exception-caught if shutdown_event.is_set(): return @@ -391,6 +506,27 @@ def on_done(f: concurrent.futures.Future) -> None: return on_done +def _prepare_ws_partial( + dash_app: "dash.Dash", + payload: CallbackExecutionBody, + ws_callback: DashWebsocketCallback, + response_adapter: "ResponseAdapter", +) -> Callable[[], Any]: + """Build the callback context and return the partial ready to be invoked. + + Shared by the threadpool (sync) and event-loop (async) dispatch paths. + """ + cb_ctx = create_ws_context(payload, response_adapter, ws_callback) + # pylint: disable=protected-access + func = dash_app._prepare_callback(cb_ctx, payload) + args = dash_app._inputs_to_vals( # pylint: disable=protected-access + cb_ctx.inputs_list + cb_ctx.states_list + ) + return dash_app._execute_callback( # pylint: disable=protected-access + func, args, cb_ctx.outputs_list, cb_ctx + ) + + def run_callback_in_executor( executor: ThreadPoolExecutor, dash_app: "dash.Dash", @@ -398,10 +534,10 @@ def run_callback_in_executor( ws_callback: DashWebsocketCallback, response_adapter: "ResponseAdapter", ) -> concurrent.futures.Future: - """Submit callback to executor for thread pool execution. + """Submit a synchronous callback to the executor for thread pool execution. - This function creates a callback execution context and runs it - in a separate thread. Both sync and async callbacks are supported. + Used for sync callbacks. Async callbacks are dispatched on the event loop via + ``run_callback_on_loop`` instead, so they no longer occupy a worker thread. Args: executor: ThreadPoolExecutor to submit the task to @@ -416,21 +552,14 @@ def run_callback_in_executor( def execute() -> dict: try: - cb_ctx = create_ws_context(payload, response_adapter, ws_callback) - # pylint: disable=protected-access - func = dash_app._prepare_callback(cb_ctx, payload) - args = dash_app._inputs_to_vals( # pylint: disable=protected-access - cb_ctx.inputs_list + cb_ctx.states_list + partial_func = _prepare_ws_partial( + dash_app, payload, ws_callback, response_adapter ) ctx = copy_context() - partial_func = ( - dash_app._execute_callback( # pylint: disable=protected-access - func, args, cb_ctx.outputs_list, cb_ctx - ) - ) - # Run in new event loop (handles both sync and async callbacks) + # Run in new event loop (handles a callback that still returns a + # coroutine, e.g. when reached outside the async dispatch path) def run_callback(): result = partial_func() if inspect.iscoroutine(result): @@ -449,3 +578,46 @@ def run_callback(): return {"status": "error", "message": str(e)} return executor.submit(execute) + + +async def run_callback_on_loop( + dash_app: "dash.Dash", + payload: CallbackExecutionBody, + ws_callback: DashWebsocketCallback, + response_adapter: "ResponseAdapter", +) -> dict: + """Run an async callback as a task on the connection's event loop. + + This is the event-loop counterpart of ``run_callback_in_executor``: instead of + pinning a worker thread with ``asyncio.run``, the callback coroutine is awaited + directly. Persistent callbacks that await (subscriptions, sleeps) yield the loop, + so hundreds can coexist without exhausting the threadpool. + + No explicit ``copy_context`` is needed: the ``asyncio.Task`` wrapping this + coroutine copies the current context at creation, so each callback's + ``context_value`` mutations stay isolated from the handler and sibling tasks. + + Args: + dash_app: The Dash application instance + payload: The callback payload from WebSocket message + ws_callback: WebSocket callback instance for set_prop/get_prop + response_adapter: Response adapter for the backend + + Returns: + The response dict (same shape as ``run_callback_in_executor``). + """ + try: + partial_func = _prepare_ws_partial( + dash_app, payload, ws_callback, response_adapter + ) + result = partial_func() + response_data = await result if inspect.iscoroutine(result) else result + return {"status": "ok", "data": json.loads(response_data)} + + except PreventUpdate: + return {"status": "prevent_update"} + except WebsocketDisconnected: + return {"status": "prevent_update"} + except Exception as e: # pylint: disable=broad-exception-caught + traceback.print_exc() + return {"status": "error", "message": str(e)} diff --git a/dash/dash.py b/dash/dash.py index 6f60482c3a..36a12c6d73 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -490,6 +490,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches websocket_inactivity_timeout: Optional[int] = 300000, websocket_heartbeat_interval: Optional[int] = 30000, websocket_batch_delay: Optional[float] = 0.005, + websocket_max_workers: Optional[int] = 4, enable_mcp: Optional[bool] = None, mcp_path: Optional[str] = None, **obsolete, @@ -662,6 +663,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches self._websocket_inactivity_timeout = websocket_inactivity_timeout self._websocket_heartbeat_interval = websocket_heartbeat_interval self._websocket_batch_delay = websocket_batch_delay + self._websocket_max_workers = websocket_max_workers self.logger = logging.getLogger(__name__) diff --git a/tests/unit/test_websocket_executor.py b/tests/unit/test_websocket_executor.py new file mode 100644 index 0000000000..5ab957768c --- /dev/null +++ b/tests/unit/test_websocket_executor.py @@ -0,0 +1,63 @@ +"""Unit tests for the shared WebSocket callback thread pool. + +These verify that a single app-wide ``ThreadPoolExecutor`` is shared across all +WebSocket connections. Only *sync* callbacks run on it -- async (incl. +session-persistent) callbacks run directly on the event loop -- so a fixed-size +shared pool bounds the total worker-thread count regardless of how many +connections are open. The pool size is configurable via the +``websocket_max_workers`` argument to ``Dash``. +""" + +from concurrent.futures import ThreadPoolExecutor + +from dash import Dash + + +def test_websocket_max_workers_default(): + """websocket_max_workers defaults to 4.""" + app = Dash(__name__) + assert app._websocket_max_workers == 4 + + +def test_websocket_max_workers_custom(): + """websocket_max_workers is stored when provided.""" + app = Dash(__name__, websocket_max_workers=16) + assert app._websocket_max_workers == 16 + + +def test_get_callback_executor_is_shared(): + """Repeated calls return the same cached, app-wide executor.""" + backend = Dash(__name__).backend + + ex1 = backend.get_callback_executor(4) + ex2 = backend.get_callback_executor(4) + try: + assert isinstance(ex1, ThreadPoolExecutor) + # Same instance => total thread count is bounded across connections. + assert ex1 is ex2 + finally: + backend.shutdown_executor(wait=False) + + +def test_get_callback_executor_honors_max_workers(): + """max_workers is forwarded to the ThreadPoolExecutor.""" + backend = Dash(__name__).backend + + ex = backend.get_callback_executor(7) + try: + assert ex._max_workers == 7 + finally: + backend.shutdown_executor(wait=False) + + +def test_shutdown_executor_allows_recreation(): + """After shutdown the next get_callback_executor call creates a fresh pool.""" + backend = Dash(__name__).backend + + ex1 = backend.get_callback_executor(4) + backend.shutdown_executor(wait=False) + ex2 = backend.get_callback_executor(4) + try: + assert ex1 is not ex2 + finally: + backend.shutdown_executor(wait=False) diff --git a/tests/websocket/test_ws_threadpool.py b/tests/websocket/test_ws_threadpool.py new file mode 100644 index 0000000000..123590e403 --- /dev/null +++ b/tests/websocket/test_ws_threadpool.py @@ -0,0 +1,118 @@ +""" +WebSocket callback dispatch tests: async callbacks run on the event loop, sync +callbacks run on the shared threadpool. + +Tests: +- Many long-lived async (persistent-style) callbacks do not exhaust the worker + threadpool, so regular callbacks still respond (thread-exhaustion regression). +- A synchronous persistent (no-output) callback warns at registration. +""" + +import asyncio + +import pytest + +from dash import Dash, html, Input, Output, ctx, set_props +from dash.exceptions import PreventUpdate + + +def test_ws050_async_callbacks_do_not_exhaust_threadpool(dash_duo): + """Many long-lived async callbacks must not starve regular callbacks. + + On the old dispatch, every async callback ran via ``asyncio.run`` inside a + worker thread, so a long-lived (never-returning) async callback pinned one of + the ``max_workers=4`` threads for the whole connection. Five of them filled the + pool and wedged regular callbacks ("Loading…"). Async callbacks now run as tasks + on the connection event loop, so they cost ~nothing and the threadpool stays + free for sync callbacks. + """ + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + n_long = 6 # > default max_workers (4) + + app.layout = html.Div( + [ + html.Button("Start long tasks", id="start", n_clicks=0), + html.Button("Regular", id="reg-btn", n_clicks=0), + html.Div("idle", id="reg-out"), + *[html.Div("idle", id=f"long-{i}") for i in range(n_long)], + ] + ) + + def make_long_callback(i): + @app.callback( + Output(f"long-{i}", "children"), + Input("start", "n_clicks"), + prevent_initial_call=True, + ) + async def _long(n): + ws = ctx.websocket + set_props(f"long-{i}", {"children": "running"}) + # Long-lived: loops for ~12s, yielding the loop on every iteration. + for _ in range(60): + if ws and ws.is_shutdown: + raise PreventUpdate + await asyncio.sleep(0.2) + return "done" + + for i in range(n_long): + make_long_callback(i) + + # A regular synchronous callback that must keep responding while the long + # async callbacks are running. + @app.callback( + Output("reg-out", "children"), + Input("reg-btn", "n_clicks"), + prevent_initial_call=True, + ) + def regular(n): + return f"ok {n}" + + dash_duo.start_server(app) + + # Kick off all the long-lived async callbacks. + dash_duo.find_element("#start").click() + # They should all reach the "running" state (would not all start on dev with + # only 4 worker threads if they pinned threads). + for i in range(n_long): + dash_duo.wait_for_text_to_equal(f"#long-{i}", "running", timeout=10) + + # The regular callback must still respond promptly while the long tasks run. + dash_duo.find_element("#reg-btn").click() + dash_duo.wait_for_text_to_equal("#reg-out", "ok 1", timeout=5) + + assert dash_duo.get_logs() == [] + + +def test_ws051_sync_persistent_callback_warns(): + """A synchronous persistent (no-output) callback warns at registration. + + Registered on a local app (not the global registry) so it can't leak phantom + callbacks into later tests. + """ + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + with pytest.warns(RuntimeWarning, match="persistent=True"): + + @app.callback( + Input("trigger", "n_clicks"), + persistent=True, + websocket=True, + ) + def _sync_persistent(n): # pragma: no cover - never executed + set_props("out", {"children": "x"}) + + +def test_ws052_async_persistent_callback_does_not_warn(recwarn): + """An async persistent (no-output) callback must not warn.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + @app.callback( + Input("trigger2", "n_clicks"), + persistent=True, + websocket=True, + ) + async def _async_persistent(n): # pragma: no cover - never executed + set_props("out", {"children": "x"}) + + assert not [w for w in recwarn.list if issubclass(w.category, RuntimeWarning)]