From 8a76f8c56582906ec4778bf0cb7c456b6d125286 Mon Sep 17 00:00:00 2001 From: philippe Date: Wed, 17 Jun 2026 16:26:33 -0400 Subject: [PATCH 1/4] Add per-connection threadpool websocket callback executor. --- CHANGELOG.md | 3 ++ dash/backends/_fastapi.py | 10 ++++-- dash/backends/_quart.py | 10 ++++-- dash/backends/base_server.py | 28 ++++++--------- dash/dash.py | 2 ++ tests/unit/test_websocket_executor.py | 51 +++++++++++++++++++++++++++ 6 files changed, 82 insertions(+), 22 deletions(-) create mode 100644 tests/unit/test_websocket_executor.py diff --git a/CHANGELOG.md b/CHANGELOG.md index e18cb41c78..50e6614b4c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,9 @@ This project adheres to [Semantic Versioning](https://semver.org/). ## [UNRELEASED] +### Added +- Per-connection WebSocket callback thread pools. Each WebSocket connection now gets its own `ThreadPoolExecutor` instead of sharing a single app-wide pool, so long-lived (session-persistent) callbacks on one connection no longer limit the number of concurrent users. The per-connection size is configurable via the new `websocket_max_workers` argument to `Dash` (default `4`). + ### Fixed - [#3805](https://github.com/plotly/dash/pull/3805) Fix FastAPI POST routes deadlock caused by middleware consuming request body. Fixes [#3801](https://github.com/plotly/dash/issues/3801). - [#3813](https://github.com/plotly/dash/pull/3813) Fix websockets using incorrect path when deployed behind a proxy diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 97dce1379a..61235b6695 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -731,8 +731,12 @@ async def websocket_handler(websocket: WebSocket): pending_get_props: Dict[str, queue.Queue] = {} # Shutdown event to signal connection closure to worker threads shutdown_event = threading.Event() - # Get thread pool executor - executor = self.get_callback_executor() + # Create a per-connection thread pool executor so that long-lived + # callbacks on one connection cannot starve worker threads for others. + # pylint: disable=protected-access + executor = self.create_callback_executor( + getattr(dash_app, "_websocket_max_workers", 4) + ) # Track pending callback futures pending_callbacks: Dict[str, concurrent.futures.Future] = {} @@ -833,6 +837,8 @@ async def websocket_handler(websocket: WebSocket): # Cancel any pending futures for f in pending_callbacks.values(): f.cancel() + # Shut down this connection's executor (don't block the event loop) + executor.shutdown(wait=False) self.server.add_api_websocket_route(ws_path, websocket_handler) diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index a6d09d1e1c..0cc8772a76 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -559,8 +559,12 @@ async def websocket_handler(): # pylint: disable=too-many-branches pending_get_props: Dict[str, queue.Queue] = {} # Shutdown event to signal connection closure to worker threads connection_shutdown_event = threading.Event() - # Get thread pool executor - executor = self.get_callback_executor() + # Create a per-connection thread pool executor so that long-lived + # callbacks on one connection cannot starve worker threads for others. + # pylint: disable=protected-access + executor = self.create_callback_executor( + getattr(dash_app, "_websocket_max_workers", 4) + ) # Track pending callback futures pending_callbacks: Dict[str, concurrent.futures.Future] = {} @@ -671,6 +675,8 @@ async def websocket_handler(): # pylint: disable=too-many-branches # Cancel any pending futures for f in pending_callbacks.values(): f.cancel() + # Shut down this connection's executor (don't block the event loop) + executor.shutdown(wait=False) class QuartRequestAdapter(RequestAdapter): diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index 52443d4104..ed06663c14 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -189,12 +189,16 @@ def __init__(self, server: ServerType) -> None: """ super().__init__() self.server = server - self._callback_executor: ThreadPoolExecutor | None = None - def get_callback_executor( + def create_callback_executor( self, max_workers: int | None = None ) -> ThreadPoolExecutor: - """Get or create the thread pool executor for callback execution. + """Create a new thread pool executor for callback execution. + + A fresh executor is created per WebSocket connection so that long-lived + (session-persistent) callbacks on one connection cannot exhaust worker + threads shared with other connections. The executor should be shut down + when its connection closes. Args: max_workers: Maximum number of worker threads. If None, uses default. @@ -202,21 +206,9 @@ def get_callback_executor( Returns: ThreadPoolExecutor instance for running callbacks. """ - if self._callback_executor is None: - self._callback_executor = ThreadPoolExecutor( - max_workers=max_workers, thread_name_prefix="dash-callback-" - ) - return self._callback_executor - - def shutdown_executor(self, wait: bool = True) -> None: - """Shutdown the callback executor. - - Args: - wait: If True, wait for pending tasks to complete. - """ - if self._callback_executor is not None: - self._callback_executor.shutdown(wait=wait) - self._callback_executor = None + return ThreadPoolExecutor( + max_workers=max_workers, thread_name_prefix="dash-callback-" + ) def __call__(self, *args, **kwargs) -> Any: """Make the server wrapper callable as a WSGI/ASGI application. diff --git a/dash/dash.py b/dash/dash.py index f547b95b56..05f7700b4c 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -490,6 +490,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches websocket_inactivity_timeout: Optional[int] = 300000, websocket_heartbeat_interval: Optional[int] = 30000, websocket_batch_delay: Optional[float] = 0.005, + websocket_max_workers: Optional[int] = 4, **obsolete, ): @@ -651,6 +652,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches self._websocket_inactivity_timeout = websocket_inactivity_timeout self._websocket_heartbeat_interval = websocket_heartbeat_interval self._websocket_batch_delay = websocket_batch_delay + self._websocket_max_workers = websocket_max_workers self.logger = logging.getLogger(__name__) diff --git a/tests/unit/test_websocket_executor.py b/tests/unit/test_websocket_executor.py new file mode 100644 index 0000000000..b693469d1d --- /dev/null +++ b/tests/unit/test_websocket_executor.py @@ -0,0 +1,51 @@ +"""Unit tests for the per-connection WebSocket callback thread pool. + +These verify that each WebSocket connection gets its own ThreadPoolExecutor +(rather than a single shared, app-wide pool), so that long-lived +(session-persistent) callbacks on one connection cannot exhaust worker threads +shared with other connections, and that the per-connection size is configurable +via the ``websocket_max_workers`` argument to ``Dash``. +""" + +from concurrent.futures import ThreadPoolExecutor + +from dash import Dash + + +def test_websocket_max_workers_default(): + """websocket_max_workers defaults to 4.""" + app = Dash(__name__) + assert app._websocket_max_workers == 4 + + +def test_websocket_max_workers_custom(): + """websocket_max_workers is stored when provided.""" + app = Dash(__name__, websocket_max_workers=16) + assert app._websocket_max_workers == 16 + + +def test_create_callback_executor_is_per_connection(): + """Each call returns a fresh executor, not a cached shared one.""" + backend = Dash(__name__).backend + + ex1 = backend.create_callback_executor(4) + ex2 = backend.create_callback_executor(4) + try: + assert isinstance(ex1, ThreadPoolExecutor) + assert isinstance(ex2, ThreadPoolExecutor) + # Distinct instances => one connection's pool can't starve another's. + assert ex1 is not ex2 + finally: + ex1.shutdown(wait=False) + ex2.shutdown(wait=False) + + +def test_create_callback_executor_honors_max_workers(): + """max_workers is forwarded to the ThreadPoolExecutor.""" + backend = Dash(__name__).backend + + ex = backend.create_callback_executor(7) + try: + assert ex._max_workers == 7 + finally: + ex.shutdown(wait=False) From 4f896d0214864c76e0bc09762e63230677651fa1 Mon Sep 17 00:00:00 2001 From: philippe Date: Mon, 29 Jun 2026 10:05:19 -0400 Subject: [PATCH 2/4] async callbacks run on the loop, non async get the threadpool --- dash/_callback.py | 14 ++ dash/_callback_context.py | 20 +-- dash/backends/_fastapi.py | 118 +++++++------ dash/backends/_quart.py | 116 +++++++------ dash/backends/ws.py | 240 ++++++++++++++++++++++---- tests/websocket/test_ws_threadpool.py | 118 +++++++++++++ 6 files changed, 479 insertions(+), 147 deletions(-) create mode 100644 tests/websocket/test_ws_threadpool.py diff --git a/dash/_callback.py b/dash/_callback.py index cf34172a80..81a5345830 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -1,6 +1,7 @@ import collections import hashlib import inspect +import warnings from functools import wraps from typing import Callable, Optional, Any, List, Tuple, Union, Dict, TypeVar, cast @@ -880,6 +881,19 @@ async def async_add_context(*args, **kwargs): if inspect.iscoroutinefunction(func): callback_map[callback_id]["callback"] = async_add_context else: + # A persistent, no-output callback streams via set_props and typically + # runs for the life of the connection. When synchronous it occupies a + # WebSocket worker thread the whole time and can exhaust the pool, so + # warn that it should be async (async callbacks run on the event loop). + if _kwargs.get("persistent") and not has_output: + warnings.warn( + f"persistent=True callback '{callback_id}' is synchronous and " + "has no Output; it will occupy a WebSocket worker thread for the " + "life of the connection and can exhaust the pool. Define it with " + "'async def' so it runs on the event loop instead.", + RuntimeWarning, + stacklevel=2, + ) callback_map[callback_id]["callback"] = add_context return func diff --git a/dash/_callback_context.py b/dash/_callback_context.py index 1bcf235036..b4285105f5 100644 --- a/dash/_callback_context.py +++ b/dash/_callback_context.py @@ -1,4 +1,3 @@ -import asyncio import functools import warnings import json @@ -367,20 +366,13 @@ def set_props(component_id: typing.Union[str, dict], props: dict): """ ws = _get_from_context("dash_websocket", None) if ws is not None: - # Stream immediately via WebSocket + # Stream immediately via WebSocket. Queuing is synchronous and thread-safe + # (janus sync side), so we queue directly instead of scheduling a task. This + # avoids detached/orphaned tasks when the callback runs on the event loop and + # preserves ordering relative to the callback response. _id = stringify_id(component_id) - - async def _send_props(): - for prop_name, value in props.items(): - await ws.set_prop(_id, prop_name, value) - - # If we're in an async context, schedule the coroutine - try: - asyncio.get_running_loop() - asyncio.ensure_future(_send_props()) - except RuntimeError: - # No running event loop - run synchronously - asyncio.run(_send_props()) + for prop_name, value in props.items(): + ws.set_prop_sync(_id, prop_name, value) else: # Batch for response (existing behavior) callback_context.set_props(component_id, props) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 61235b6695..53a93f1e7d 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -4,7 +4,6 @@ import asyncio import concurrent.futures import json -import queue from typing import TYPE_CHECKING, Any, Callable, Dict import sys import mimetypes @@ -46,9 +45,9 @@ DashWebsocketCallback, run_ws_sender, run_callback_in_executor, + run_callback_on_loop, make_callback_done_handler, - SHUTDOWN_SIGNAL, - DISCONNECTED, + shutdown_ws_connection, ) from ._utils import format_traceback_html @@ -681,7 +680,7 @@ def serve_websocket_callback(self, dash_app: "Dash"): Args: dash_app: The Dash application instance """ - # pylint: disable=too-many-statements,too-many-locals + # pylint: disable=too-many-statements,too-many-locals,too-many-branches ws_path = dash_app.config.routes_pathname_prefix + "_dash-ws-callback" # Get allowed origins from dash app config @@ -725,10 +724,15 @@ async def websocket_handler(websocket: WebSocket): await websocket.accept() + # The connection's event loop, used to dispatch async callbacks as + # tasks and to resolve get_props futures. + loop = asyncio.get_running_loop() + # Create janus queue for outbound messages (main loop context) outbound_queue: janus.Queue[str] = janus.Queue() - # Track pending get_props requests with standard queue.Queue for responses - pending_get_props: Dict[str, queue.Queue] = {} + # Track pending get_props requests. Values are queue.Queue (threadpool / + # sync path) or asyncio.Future (event-loop / async path). + pending_get_props: Dict[str, Any] = {} # Shutdown event to signal connection closure to worker threads shutdown_event = threading.Event() # Create a per-connection thread pool executor so that long-lived @@ -737,8 +741,11 @@ async def websocket_handler(websocket: WebSocket): executor = self.create_callback_executor( getattr(dash_app, "_websocket_max_workers", 4) ) - # Track pending callback futures - pending_callbacks: Dict[str, concurrent.futures.Future] = {} + # Track pending callbacks: concurrent.futures.Future (sync/threadpool) + # or asyncio.Task (async/event-loop). + pending_callbacks: Dict[ + str, concurrent.futures.Future | asyncio.Future + ] = {} # Start sender task to drain outbound queue (sends pre-serialized text) # pylint: disable=protected-access @@ -777,41 +784,64 @@ async def websocket_handler(websocket: WebSocket): dash_app._websocket_callbacks, ) - # Create WebSocket callback instance + # Async callbacks (incl. persistent ones) run as tasks on + # the event loop; sync callbacks go to the threadpool. + # pylint: disable=protected-access + cb_spec = dash_app.callback_map.get(payload.get("output"), {}) + is_async = inspect.iscoroutinefunction(cb_spec.get("callback")) + + # Create WebSocket callback instance. The loop is passed only + # for the async path so get_prop awaits instead of blocking. ws_cb = DashWebsocketCallback( pending_get_props, renderer_id, outbound_queue, shutdown_event, + loop if is_async else None, ) - # Submit callback to executor - future = run_callback_in_executor( - executor, - dash_app, - payload, - ws_cb, - FastAPIResponseAdapter(), + done_handler = make_callback_done_handler( + outbound_queue, + pending_callbacks, + request_id, + renderer_id, + shutdown_event, ) - # Set up done callback to send response - future.add_done_callback( - make_callback_done_handler( - outbound_queue, - pending_callbacks, - request_id, - renderer_id, - shutdown_event, + if is_async: + task = asyncio.create_task( + run_callback_on_loop( + dash_app, + payload, + ws_cb, + FastAPIResponseAdapter(), + ) ) - ) - pending_callbacks[request_id] = future + task.add_done_callback(done_handler) + pending_callbacks[request_id] = task + else: + # Submit callback to executor + future = run_callback_in_executor( + executor, + dash_app, + payload, + ws_cb, + FastAPIResponseAdapter(), + ) + # Set up done callback to send response + future.add_done_callback(done_handler) + pending_callbacks[request_id] = future elif msg_type == "get_props_response": - # Put response in waiting queue (non-blocking) + # Resolve the waiting future (async path) or queue (thread + # path) for this request (non-blocking). request_id = message.get("requestId") - response_queue = pending_get_props.get(request_id) - if response_queue is not None: - response_queue.put_nowait(message.get("payload")) + pending = pending_get_props.get(request_id) + if isinstance(pending, asyncio.Future): + if not pending.done(): + pending.set_result(message.get("payload")) + elif pending is not None: + pending.put_nowait(message.get("payload")) elif msg_type == "heartbeat": outbound_queue.sync_q.put_nowait('{"type": "heartbeat_ack"}') @@ -819,26 +849,14 @@ async def websocket_handler(websocket: WebSocket): except WebSocketDisconnect: pass # Clean disconnect finally: - # Signal shutdown to worker threads - shutdown_event.set() - # Unblock any threads waiting on get_prop responses - for response_queue in pending_get_props.values(): - response_queue.put_nowait(DISCONNECTED) - # Signal sender to shutdown and cancel it - outbound_queue.sync_q.put_nowait(SHUTDOWN_SIGNAL) - sender_task.cancel() - try: - await sender_task - except asyncio.CancelledError: - pass - # Close the janus queue - outbound_queue.close() - await outbound_queue.wait_closed() - # Cancel any pending futures - for f in pending_callbacks.values(): - f.cancel() - # Shut down this connection's executor (don't block the event loop) - executor.shutdown(wait=False) + await shutdown_ws_connection( + shutdown_event, + pending_get_props, + pending_callbacks, + outbound_queue, + sender_task, + executor, + ) self.server.add_api_websocket_route(ws_path, websocket_handler) diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index 0cc8772a76..467dc86353 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -8,7 +8,6 @@ import sys import asyncio import concurrent.futures -import queue import threading from urllib.parse import urlparse @@ -51,9 +50,9 @@ DashWebsocketCallback, run_ws_sender, run_callback_in_executor, + run_callback_on_loop, make_callback_done_handler, - SHUTDOWN_SIGNAL, - DISCONNECTED, + shutdown_ws_connection, ) from ._utils import format_traceback_html @@ -545,6 +544,10 @@ async def websocket_handler(): # pylint: disable=too-many-branches await ws.accept() + # The connection's event loop, used to dispatch async callbacks as + # tasks and to resolve get_props futures. + loop = asyncio.get_running_loop() + # Track this connection for graceful shutdown try: ws_obj = ws._get_current_object() @@ -555,8 +558,9 @@ async def websocket_handler(): # pylint: disable=too-many-branches # Create janus queue for outbound messages (main loop context) outbound_queue: janus.Queue[str] = janus.Queue() - # Track pending get_props requests with standard queue.Queue for responses - pending_get_props: Dict[str, queue.Queue] = {} + # Track pending get_props requests. Values are queue.Queue (threadpool / + # sync path) or asyncio.Future (event-loop / async path). + pending_get_props: Dict[str, Any] = {} # Shutdown event to signal connection closure to worker threads connection_shutdown_event = threading.Event() # Create a per-connection thread pool executor so that long-lived @@ -565,8 +569,11 @@ async def websocket_handler(): # pylint: disable=too-many-branches executor = self.create_callback_executor( getattr(dash_app, "_websocket_max_workers", 4) ) - # Track pending callback futures - pending_callbacks: Dict[str, concurrent.futures.Future] = {} + # Track pending callbacks: concurrent.futures.Future (sync/threadpool) + # or asyncio.Task (async/event-loop). + pending_callbacks: Dict[ + str, concurrent.futures.Future | asyncio.Future + ] = {} # Start sender task to drain outbound queue (sends pre-serialized text) # pylint: disable=protected-access @@ -612,41 +619,64 @@ async def websocket_handler(): # pylint: disable=too-many-branches dash_app._websocket_callbacks, ) - # Create WebSocket callback instance + # Async callbacks (incl. persistent ones) run as tasks on + # the event loop; sync callbacks go to the threadpool. + # pylint: disable=protected-access + cb_spec = dash_app.callback_map.get(payload.get("output"), {}) + is_async = inspect.iscoroutinefunction(cb_spec.get("callback")) + + # Create WebSocket callback instance. The loop is passed only + # for the async path so get_prop awaits instead of blocking. ws_cb = DashWebsocketCallback( pending_get_props, renderer_id, outbound_queue, connection_shutdown_event, + loop if is_async else None, ) - # Submit callback to executor - future = run_callback_in_executor( - executor, - dash_app, - payload, - ws_cb, - QuartResponseAdapter(), + done_handler = make_callback_done_handler( + outbound_queue, + pending_callbacks, + request_id, + renderer_id, + connection_shutdown_event, ) - # Set up done callback to send response - future.add_done_callback( - make_callback_done_handler( - outbound_queue, - pending_callbacks, - request_id, - renderer_id, - connection_shutdown_event, + if is_async: + task = asyncio.create_task( + run_callback_on_loop( + dash_app, + payload, + ws_cb, + QuartResponseAdapter(), + ) ) - ) - pending_callbacks[request_id] = future + task.add_done_callback(done_handler) + pending_callbacks[request_id] = task + else: + # Submit callback to executor + future = run_callback_in_executor( + executor, + dash_app, + payload, + ws_cb, + QuartResponseAdapter(), + ) + # Set up done callback to send response + future.add_done_callback(done_handler) + pending_callbacks[request_id] = future elif msg_type == "get_props_response": - # Put response in waiting queue (non-blocking) + # Resolve the waiting future (async path) or queue (thread + # path) for this request (non-blocking). request_id = message.get("requestId") - response_queue = pending_get_props.get(request_id) - if response_queue is not None: - response_queue.put_nowait(message.get("payload")) + pending = pending_get_props.get(request_id) + if isinstance(pending, asyncio.Future): + if not pending.done(): + pending.set_result(message.get("payload")) + elif pending is not None: + pending.put_nowait(message.get("payload")) elif msg_type == "heartbeat": outbound_queue.sync_q.put_nowait('{"type": "heartbeat_ack"}') @@ -657,26 +687,14 @@ async def websocket_handler(): # pylint: disable=too-many-branches pass # Other exceptions treated as disconnect finally: self._active_websockets.discard(ws_obj) - # Signal shutdown to worker threads - connection_shutdown_event.set() - # Unblock any threads waiting on get_prop responses - for response_queue in pending_get_props.values(): - response_queue.put_nowait(DISCONNECTED) - # Signal sender to shutdown and cancel it - outbound_queue.sync_q.put_nowait(SHUTDOWN_SIGNAL) - sender_task.cancel() - try: - await sender_task - except asyncio.CancelledError: - pass - # Close the janus queue - outbound_queue.close() - await outbound_queue.wait_closed() - # Cancel any pending futures - for f in pending_callbacks.values(): - f.cancel() - # Shut down this connection's executor (don't block the event loop) - executor.shutdown(wait=False) + await shutdown_ws_connection( + connection_shutdown_event, + pending_get_props, + pending_callbacks, + outbound_queue, + sender_task, + executor, + ) class QuartRequestAdapter(RequestAdapter): diff --git a/dash/backends/ws.py b/dash/backends/ws.py index d784ea291a..aa8a61974e 100644 --- a/dash/backends/ws.py +++ b/dash/backends/ws.py @@ -65,24 +65,30 @@ async def long_running(n_clicks): def __init__( self, - pending_get_props: Dict[str, queue.Queue[Any]], + pending_get_props: Dict[str, Any], renderer_id: str, outbound_queue: janus.Queue[str], shutdown_event: "threading.Event", + loop: "asyncio.AbstractEventLoop | None" = None, ): """Initialize the WebSocket callback interface. Args: - pending_get_props: Dict to track pending get_props requests. - Values are queue.Queue instances for blocking response retrieval. + pending_get_props: Dict to track pending get_props requests. Values are + queue.Queue instances (blocking, thread path) or asyncio.Future + instances (awaitable, event-loop path), keyed by request_id. renderer_id: The renderer ID for routing messages back to the correct client outbound_queue: janus.Queue for thread-safe outbound messaging. shutdown_event: Event signaling the websocket connection has closed. + loop: The connection's event loop when the callback runs on the loop + (async dispatch). When set, get_prop uses an awaitable asyncio.Future + instead of blocking on a queue.Queue. None for the threadpool path. """ self._pending_get_props = pending_get_props self._renderer_id = renderer_id self._outbound_queue = outbound_queue self._shutdown_event = shutdown_event + self._loop = loop @property def is_shutdown(self) -> bool: @@ -93,7 +99,7 @@ def _get_outbound_queue(self) -> janus.Queue[str] | None: """Get the outbound queue.""" return self._outbound_queue - def _get_pending_get_props(self) -> Dict[str, queue.Queue[Any]] | None: + def _get_pending_get_props(self) -> Dict[str, Any] | None: """Get the pending_get_props dict.""" return self._pending_get_props @@ -109,10 +115,12 @@ def _queue_message(self, msg: dict) -> None: if outbound_queue is not None: outbound_queue.sync_q.put_nowait(cast(str, to_json(msg))) - async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: - """Send immediate prop update to the client via WebSocket. + def set_prop_sync(self, component_id: str, prop_name: str, value: Any) -> None: + """Queue an immediate prop update to the client (synchronous, non-blocking). - Queues the message for the sender coroutine to send. + Queuing is thread-safe (janus sync side) and never awaits, so this can be + called directly from the event loop or a worker thread without scheduling a + task. ``set_prop`` is the async wrapper kept for backward compatibility. Args: component_id: The component ID (string or stringified dict) @@ -126,12 +134,27 @@ async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: } self._queue_message(msg) + async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: + """Send immediate prop update to the client via WebSocket. + + Queues the message for the sender coroutine to send. + + Args: + component_id: The component ID (string or stringified dict) + prop_name: The property name to update + value: The new value to set + """ + self.set_prop_sync(component_id, prop_name, value) + async def get_prop( self, component_id: str, prop_name: str, timeout: float = 30.0 ) -> Any: """Request current prop value from the client. - Uses queue.Queue for blocking wait in worker thread. + On the event-loop path (``self._loop`` set, async callbacks) the wait uses an + awaitable ``asyncio.Future`` so the connection loop is never blocked. On the + threadpool path (``self._loop`` is None, sync callbacks) it blocks on a + ``queue.Queue`` in the worker thread. Args: component_id: The component ID (string or stringified dict) @@ -160,24 +183,57 @@ async def get_prop( "payload": {"componentId": component_id, "properties": [prop_name]}, } + if self._loop is not None: + result = await self._get_prop_async(request_id, msg, timeout) + else: + result = await self._get_prop_blocking(request_id, msg, timeout) + + if result == DISCONNECTED: + raise WebsocketDisconnected() + if result and prop_name in result: + return result[prop_name] + return None + + async def _get_prop_async(self, request_id: str, msg: dict, timeout: float) -> Any: + """Await a get_props response on the connection event loop.""" + pending_get_props = self._get_pending_get_props() + future: "asyncio.Future" = self._loop.create_future() # type: ignore[union-attr] + pending_get_props[request_id] = future # type: ignore[index] + + # Queue the outbound request via janus sync interface + self._queue_message(msg) + + try: + return await asyncio.wait_for(future, timeout=timeout) + except asyncio.TimeoutError as exc: + raise TimeoutError( + f"Timeout waiting for {msg['payload']['componentId']}." + f"{msg['payload']['properties'][0]}" + ) from exc + finally: + current_pending = self._get_pending_get_props() + if current_pending is not None: + current_pending.pop(request_id, None) + + async def _get_prop_blocking( + self, request_id: str, msg: dict, timeout: float + ) -> Any: + """Block on a get_props response in a worker thread (threadpool path).""" + pending_get_props = self._get_pending_get_props() # Use standard queue.Queue for response response_queue: queue.Queue = queue.Queue() - pending_get_props[request_id] = response_queue + pending_get_props[request_id] = response_queue # type: ignore[index] # Queue the outbound request via janus sync interface self._queue_message(msg) # Wait for response (blocking is OK in worker thread) try: - result = response_queue.get(timeout=timeout) - if result == DISCONNECTED: - raise WebsocketDisconnected() - if result and prop_name in result: - return result[prop_name] - return None + return response_queue.get(timeout=timeout) except queue.Empty as exc: raise TimeoutError( - f"Timeout waiting for {component_id}.{prop_name}" + f"Timeout waiting for {msg['payload']['componentId']}." + f"{msg['payload']['properties'][0]}" ) from exc finally: # Get fresh reference in case of reconnection @@ -323,14 +379,68 @@ async def _send_batched(send_text: Callable[[str], Any], messages: list) -> bool return False # WebSocketDisconnect, RuntimeError, etc. +async def shutdown_ws_connection( + shutdown_event: threading.Event, + pending_get_props: Dict[str, Any], + pending_callbacks: Dict[str, "concurrent.futures.Future | asyncio.Future"], + outbound_queue: janus.Queue[str], + sender_task: "asyncio.Task", + executor: ThreadPoolExecutor, +) -> None: + """Tear down a WebSocket connection's callback machinery. + + Shared by the FastAPI and Quart handlers so the ordering (which is + correctness-sensitive) stays in one place. Async callback tasks are cancelled + and awaited *before* the outbound queue is closed, so their cleanup can't touch + a closed queue. + + The dicts are snapshotted before iteration: threadpool done-handlers and the + blocking get_prop path pop from them on worker threads, so iterating the live + dicts here can race. + """ + # Signal shutdown to worker threads + shutdown_event.set() + # Unblock anything waiting on get_prop responses (futures or queues) + for pending in list(pending_get_props.values()): + if isinstance(pending, asyncio.Future): + if not pending.done(): + pending.set_result(DISCONNECTED) + else: + pending.put_nowait(DISCONNECTED) + callbacks = list(pending_callbacks.values()) + # Cancel running async callback tasks and let them unwind while the outbound + # queue is still open (their cleanup may touch it). + async_tasks = [f for f in callbacks if isinstance(f, asyncio.Future)] + for f in async_tasks: + f.cancel() + if async_tasks: + await asyncio.gather(*async_tasks, return_exceptions=True) + # Cancel any pending threadpool futures (running ones run to completion) + for f in callbacks: + if not isinstance(f, asyncio.Future): + f.cancel() + # Signal sender to shutdown and cancel it + outbound_queue.sync_q.put_nowait(SHUTDOWN_SIGNAL) + sender_task.cancel() + try: + await sender_task + except asyncio.CancelledError: + pass + # Close the janus queue + outbound_queue.close() + await outbound_queue.wait_closed() + # Shut down this connection's executor (don't block the event loop) + executor.shutdown(wait=False) + + def make_callback_done_handler( outbound_queue: janus.Queue[str], - pending_callbacks: Dict[str, concurrent.futures.Future], + pending_callbacks: Dict[str, "concurrent.futures.Future | asyncio.Future"], request_id: str, renderer_id: str, shutdown_event: threading.Event, -) -> Callable[[concurrent.futures.Future], None]: - """Create a done callback handler for executor futures. +) -> Callable[[Any], None]: + """Create a done callback handler for executor futures or event-loop tasks. This factory creates a callback that sends the result back through the WebSocket when an executor future completes. @@ -346,7 +456,7 @@ def make_callback_done_handler( A callback function suitable for Future.add_done_callback() """ - def on_done(f: concurrent.futures.Future) -> None: + def on_done(f: "concurrent.futures.Future | asyncio.Future") -> None: try: if shutdown_event.is_set(): return @@ -364,6 +474,11 @@ def on_done(f: concurrent.futures.Future) -> None: ), ) ) + except asyncio.CancelledError: + # Task cancelled (e.g. on disconnect). CancelledError is a + # BaseException, so it is not caught by the broad except below; + # return here and let the finally clause pop the pending entry. + return except Exception as e: # pylint: disable=broad-exception-caught if shutdown_event.is_set(): return @@ -391,6 +506,27 @@ def on_done(f: concurrent.futures.Future) -> None: return on_done +def _prepare_ws_partial( + dash_app: "dash.Dash", + payload: CallbackExecutionBody, + ws_callback: DashWebsocketCallback, + response_adapter: "ResponseAdapter", +) -> Callable[[], Any]: + """Build the callback context and return the partial ready to be invoked. + + Shared by the threadpool (sync) and event-loop (async) dispatch paths. + """ + cb_ctx = create_ws_context(payload, response_adapter, ws_callback) + # pylint: disable=protected-access + func = dash_app._prepare_callback(cb_ctx, payload) + args = dash_app._inputs_to_vals( # pylint: disable=protected-access + cb_ctx.inputs_list + cb_ctx.states_list + ) + return dash_app._execute_callback( # pylint: disable=protected-access + func, args, cb_ctx.outputs_list, cb_ctx + ) + + def run_callback_in_executor( executor: ThreadPoolExecutor, dash_app: "dash.Dash", @@ -398,10 +534,10 @@ def run_callback_in_executor( ws_callback: DashWebsocketCallback, response_adapter: "ResponseAdapter", ) -> concurrent.futures.Future: - """Submit callback to executor for thread pool execution. + """Submit a synchronous callback to the executor for thread pool execution. - This function creates a callback execution context and runs it - in a separate thread. Both sync and async callbacks are supported. + Used for sync callbacks. Async callbacks are dispatched on the event loop via + ``run_callback_on_loop`` instead, so they no longer occupy a worker thread. Args: executor: ThreadPoolExecutor to submit the task to @@ -416,21 +552,14 @@ def run_callback_in_executor( def execute() -> dict: try: - cb_ctx = create_ws_context(payload, response_adapter, ws_callback) - # pylint: disable=protected-access - func = dash_app._prepare_callback(cb_ctx, payload) - args = dash_app._inputs_to_vals( # pylint: disable=protected-access - cb_ctx.inputs_list + cb_ctx.states_list + partial_func = _prepare_ws_partial( + dash_app, payload, ws_callback, response_adapter ) ctx = copy_context() - partial_func = ( - dash_app._execute_callback( # pylint: disable=protected-access - func, args, cb_ctx.outputs_list, cb_ctx - ) - ) - # Run in new event loop (handles both sync and async callbacks) + # Run in new event loop (handles a callback that still returns a + # coroutine, e.g. when reached outside the async dispatch path) def run_callback(): result = partial_func() if inspect.iscoroutine(result): @@ -449,3 +578,46 @@ def run_callback(): return {"status": "error", "message": str(e)} return executor.submit(execute) + + +async def run_callback_on_loop( + dash_app: "dash.Dash", + payload: CallbackExecutionBody, + ws_callback: DashWebsocketCallback, + response_adapter: "ResponseAdapter", +) -> dict: + """Run an async callback as a task on the connection's event loop. + + This is the event-loop counterpart of ``run_callback_in_executor``: instead of + pinning a worker thread with ``asyncio.run``, the callback coroutine is awaited + directly. Persistent callbacks that await (subscriptions, sleeps) yield the loop, + so hundreds can coexist without exhausting the threadpool. + + No explicit ``copy_context`` is needed: the ``asyncio.Task`` wrapping this + coroutine copies the current context at creation, so each callback's + ``context_value`` mutations stay isolated from the handler and sibling tasks. + + Args: + dash_app: The Dash application instance + payload: The callback payload from WebSocket message + ws_callback: WebSocket callback instance for set_prop/get_prop + response_adapter: Response adapter for the backend + + Returns: + The response dict (same shape as ``run_callback_in_executor``). + """ + try: + partial_func = _prepare_ws_partial( + dash_app, payload, ws_callback, response_adapter + ) + result = partial_func() + response_data = await result if inspect.iscoroutine(result) else result + return {"status": "ok", "data": json.loads(response_data)} + + except PreventUpdate: + return {"status": "prevent_update"} + except WebsocketDisconnected: + return {"status": "prevent_update"} + except Exception as e: # pylint: disable=broad-exception-caught + traceback.print_exc() + return {"status": "error", "message": str(e)} diff --git a/tests/websocket/test_ws_threadpool.py b/tests/websocket/test_ws_threadpool.py new file mode 100644 index 0000000000..9395412175 --- /dev/null +++ b/tests/websocket/test_ws_threadpool.py @@ -0,0 +1,118 @@ +""" +WebSocket callback dispatch tests: async callbacks run on the event loop, sync +callbacks run on the per-connection threadpool. + +Tests: +- Many long-lived async (persistent-style) callbacks do not exhaust the worker + threadpool, so regular callbacks still respond (thread-exhaustion regression). +- A synchronous persistent (no-output) callback warns at registration. +""" + +import asyncio + +import pytest + +from dash import Dash, html, Input, Output, ctx, set_props +from dash.exceptions import PreventUpdate + + +def test_ws050_async_callbacks_do_not_exhaust_threadpool(dash_duo): + """Many long-lived async callbacks must not starve regular callbacks. + + On the old dispatch, every async callback ran via ``asyncio.run`` inside a + worker thread, so a long-lived (never-returning) async callback pinned one of + the ``max_workers=4`` threads for the whole connection. Five of them filled the + pool and wedged regular callbacks ("Loading…"). Async callbacks now run as tasks + on the connection event loop, so they cost ~nothing and the threadpool stays + free for sync callbacks. + """ + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + n_long = 6 # > default max_workers (4) + + app.layout = html.Div( + [ + html.Button("Start long tasks", id="start", n_clicks=0), + html.Button("Regular", id="reg-btn", n_clicks=0), + html.Div("idle", id="reg-out"), + *[html.Div("idle", id=f"long-{i}") for i in range(n_long)], + ] + ) + + def make_long_callback(i): + @app.callback( + Output(f"long-{i}", "children"), + Input("start", "n_clicks"), + prevent_initial_call=True, + ) + async def _long(n): + ws = ctx.websocket + set_props(f"long-{i}", {"children": "running"}) + # Long-lived: loops for ~12s, yielding the loop on every iteration. + for _ in range(60): + if ws and ws.is_shutdown: + raise PreventUpdate + await asyncio.sleep(0.2) + return "done" + + for i in range(n_long): + make_long_callback(i) + + # A regular synchronous callback that must keep responding while the long + # async callbacks are running. + @app.callback( + Output("reg-out", "children"), + Input("reg-btn", "n_clicks"), + prevent_initial_call=True, + ) + def regular(n): + return f"ok {n}" + + dash_duo.start_server(app) + + # Kick off all the long-lived async callbacks. + dash_duo.find_element("#start").click() + # They should all reach the "running" state (would not all start on dev with + # only 4 worker threads if they pinned threads). + for i in range(n_long): + dash_duo.wait_for_text_to_equal(f"#long-{i}", "running", timeout=10) + + # The regular callback must still respond promptly while the long tasks run. + dash_duo.find_element("#reg-btn").click() + dash_duo.wait_for_text_to_equal("#reg-out", "ok 1", timeout=5) + + assert dash_duo.get_logs() == [] + + +def test_ws051_sync_persistent_callback_warns(): + """A synchronous persistent (no-output) callback warns at registration. + + Registered on a local app (not the global registry) so it can't leak phantom + callbacks into later tests. + """ + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + with pytest.warns(RuntimeWarning, match="persistent=True"): + + @app.callback( + Input("trigger", "n_clicks"), + persistent=True, + websocket=True, + ) + def _sync_persistent(n): # pragma: no cover - never executed + set_props("out", {"children": "x"}) + + +def test_ws052_async_persistent_callback_does_not_warn(recwarn): + """An async persistent (no-output) callback must not warn.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + @app.callback( + Input("trigger2", "n_clicks"), + persistent=True, + websocket=True, + ) + async def _async_persistent(n): # pragma: no cover - never executed + set_props("out", {"children": "x"}) + + assert not [w for w in recwarn.list if issubclass(w.category, RuntimeWarning)] From 053212a9ac6006e19c64dfab3b634f5d80133939 Mon Sep 17 00:00:00 2001 From: philippe Date: Mon, 29 Jun 2026 14:25:40 -0400 Subject: [PATCH 3/4] back to global websocket threadpool --- CHANGELOG.md | 2 +- dash/backends/_fastapi.py | 12 +++---- dash/backends/_quart.py | 12 +++---- dash/backends/base_server.py | 32 ++++++++++++----- dash/backends/ws.py | 6 ++-- tests/unit/test_websocket_executor.py | 50 +++++++++++++++++---------- tests/websocket/test_ws_threadpool.py | 2 +- 7 files changed, 71 insertions(+), 45 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ed6c502b7b..c4b0f0cfc6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ This project adheres to [Semantic Versioning](https://semver.org/). ## [UNRELEASED] ### Added -- Per-connection WebSocket callback thread pools. Each WebSocket connection now gets its own `ThreadPoolExecutor` instead of sharing a single app-wide pool, so long-lived (session-persistent) callbacks on one connection no longer limit the number of concurrent users. The per-connection size is configurable via the new `websocket_max_workers` argument to `Dash` (default `4`). +- [#3826](https://github.com/plotly/dash/pull/3826) WebSocket callback dispatch no longer lets long-lived callbacks limit the number of concurrent users. Async callbacks (including session-persistent ones) run directly on the connection event loop instead of occupying a worker thread, and synchronous callbacks run on a shared `ThreadPoolExecutor` whose size is configurable via the new `websocket_max_workers` argument to `Dash` (default `4`). A synchronous persistent (no-output) callback now warns at registration since it would tie up a worker thread. ## [4.3.0] - 2026-06-18 diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 53a93f1e7d..b69786cf12 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -4,6 +4,7 @@ import asyncio import concurrent.futures import json +import queue from typing import TYPE_CHECKING, Any, Callable, Dict import sys import mimetypes @@ -732,13 +733,13 @@ async def websocket_handler(websocket: WebSocket): outbound_queue: janus.Queue[str] = janus.Queue() # Track pending get_props requests. Values are queue.Queue (threadpool / # sync path) or asyncio.Future (event-loop / async path). - pending_get_props: Dict[str, Any] = {} + pending_get_props: Dict[str, queue.Queue | asyncio.Future] = {} # Shutdown event to signal connection closure to worker threads shutdown_event = threading.Event() - # Create a per-connection thread pool executor so that long-lived - # callbacks on one connection cannot starve worker threads for others. - # pylint: disable=protected-access - executor = self.create_callback_executor( + # Sync callbacks run on a shared thread pool executor (async callbacks + # run directly on the event loop). A single bounded pool caps the total + # worker-thread count regardless of how many connections are open. + executor = self.get_callback_executor( getattr(dash_app, "_websocket_max_workers", 4) ) # Track pending callbacks: concurrent.futures.Future (sync/threadpool) @@ -855,7 +856,6 @@ async def websocket_handler(websocket: WebSocket): pending_callbacks, outbound_queue, sender_task, - executor, ) self.server.add_api_websocket_route(ws_path, websocket_handler) diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index 467dc86353..e7bb1c3c2f 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -8,6 +8,7 @@ import sys import asyncio import concurrent.futures +import queue import threading from urllib.parse import urlparse @@ -560,13 +561,13 @@ async def websocket_handler(): # pylint: disable=too-many-branches outbound_queue: janus.Queue[str] = janus.Queue() # Track pending get_props requests. Values are queue.Queue (threadpool / # sync path) or asyncio.Future (event-loop / async path). - pending_get_props: Dict[str, Any] = {} + pending_get_props: Dict[str, queue.Queue | asyncio.Future] = {} # Shutdown event to signal connection closure to worker threads connection_shutdown_event = threading.Event() - # Create a per-connection thread pool executor so that long-lived - # callbacks on one connection cannot starve worker threads for others. - # pylint: disable=protected-access - executor = self.create_callback_executor( + # Sync callbacks run on a shared thread pool executor (async callbacks + # run directly on the event loop). A single bounded pool caps the total + # worker-thread count regardless of how many connections are open. + executor = self.get_callback_executor( getattr(dash_app, "_websocket_max_workers", 4) ) # Track pending callbacks: concurrent.futures.Future (sync/threadpool) @@ -693,7 +694,6 @@ async def websocket_handler(): # pylint: disable=too-many-branches pending_callbacks, outbound_queue, sender_task, - executor, ) diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index ed06663c14..5443662dd2 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -189,16 +189,18 @@ def __init__(self, server: ServerType) -> None: """ super().__init__() self.server = server + self._callback_executor: ThreadPoolExecutor | None = None - def create_callback_executor( + def get_callback_executor( self, max_workers: int | None = None ) -> ThreadPoolExecutor: - """Create a new thread pool executor for callback execution. + """Get or create the shared thread pool executor for sync callbacks. - A fresh executor is created per WebSocket connection so that long-lived - (session-persistent) callbacks on one connection cannot exhaust worker - threads shared with other connections. The executor should be shut down - when its connection closes. + A single executor is shared across all WebSocket connections. Only + *sync* callbacks run here -- async callbacks (including session-persistent + ones) run directly on the connection event loop -- so worker threads are + released promptly and a fixed-size shared pool bounds the total thread + count regardless of how many connections are open. Args: max_workers: Maximum number of worker threads. If None, uses default. @@ -206,9 +208,21 @@ def create_callback_executor( Returns: ThreadPoolExecutor instance for running callbacks. """ - return ThreadPoolExecutor( - max_workers=max_workers, thread_name_prefix="dash-callback-" - ) + if self._callback_executor is None: + self._callback_executor = ThreadPoolExecutor( + max_workers=max_workers, thread_name_prefix="dash-callback-" + ) + return self._callback_executor + + def shutdown_executor(self, wait: bool = True) -> None: + """Shutdown the shared callback executor. + + Args: + wait: If True, wait for pending tasks to complete. + """ + if self._callback_executor is not None: + self._callback_executor.shutdown(wait=wait) + self._callback_executor = None def __call__(self, *args, **kwargs) -> Any: """Make the server wrapper callable as a WSGI/ASGI application. diff --git a/dash/backends/ws.py b/dash/backends/ws.py index aa8a61974e..2db1289ec6 100644 --- a/dash/backends/ws.py +++ b/dash/backends/ws.py @@ -385,7 +385,6 @@ async def shutdown_ws_connection( pending_callbacks: Dict[str, "concurrent.futures.Future | asyncio.Future"], outbound_queue: janus.Queue[str], sender_task: "asyncio.Task", - executor: ThreadPoolExecutor, ) -> None: """Tear down a WebSocket connection's callback machinery. @@ -394,6 +393,9 @@ async def shutdown_ws_connection( and awaited *before* the outbound queue is closed, so their cleanup can't touch a closed queue. + The callback executor is intentionally *not* shut down here: it is a single + pool shared across all connections, so it outlives any one connection. + The dicts are snapshotted before iteration: threadpool done-handlers and the blocking get_prop path pop from them on worker threads, so iterating the live dicts here can race. @@ -429,8 +431,6 @@ async def shutdown_ws_connection( # Close the janus queue outbound_queue.close() await outbound_queue.wait_closed() - # Shut down this connection's executor (don't block the event loop) - executor.shutdown(wait=False) def make_callback_done_handler( diff --git a/tests/unit/test_websocket_executor.py b/tests/unit/test_websocket_executor.py index b693469d1d..5ab957768c 100644 --- a/tests/unit/test_websocket_executor.py +++ b/tests/unit/test_websocket_executor.py @@ -1,10 +1,11 @@ -"""Unit tests for the per-connection WebSocket callback thread pool. - -These verify that each WebSocket connection gets its own ThreadPoolExecutor -(rather than a single shared, app-wide pool), so that long-lived -(session-persistent) callbacks on one connection cannot exhaust worker threads -shared with other connections, and that the per-connection size is configurable -via the ``websocket_max_workers`` argument to ``Dash``. +"""Unit tests for the shared WebSocket callback thread pool. + +These verify that a single app-wide ``ThreadPoolExecutor`` is shared across all +WebSocket connections. Only *sync* callbacks run on it -- async (incl. +session-persistent) callbacks run directly on the event loop -- so a fixed-size +shared pool bounds the total worker-thread count regardless of how many +connections are open. The pool size is configurable via the +``websocket_max_workers`` argument to ``Dash``. """ from concurrent.futures import ThreadPoolExecutor @@ -24,28 +25,39 @@ def test_websocket_max_workers_custom(): assert app._websocket_max_workers == 16 -def test_create_callback_executor_is_per_connection(): - """Each call returns a fresh executor, not a cached shared one.""" +def test_get_callback_executor_is_shared(): + """Repeated calls return the same cached, app-wide executor.""" backend = Dash(__name__).backend - ex1 = backend.create_callback_executor(4) - ex2 = backend.create_callback_executor(4) + ex1 = backend.get_callback_executor(4) + ex2 = backend.get_callback_executor(4) try: assert isinstance(ex1, ThreadPoolExecutor) - assert isinstance(ex2, ThreadPoolExecutor) - # Distinct instances => one connection's pool can't starve another's. - assert ex1 is not ex2 + # Same instance => total thread count is bounded across connections. + assert ex1 is ex2 finally: - ex1.shutdown(wait=False) - ex2.shutdown(wait=False) + backend.shutdown_executor(wait=False) -def test_create_callback_executor_honors_max_workers(): +def test_get_callback_executor_honors_max_workers(): """max_workers is forwarded to the ThreadPoolExecutor.""" backend = Dash(__name__).backend - ex = backend.create_callback_executor(7) + ex = backend.get_callback_executor(7) try: assert ex._max_workers == 7 finally: - ex.shutdown(wait=False) + backend.shutdown_executor(wait=False) + + +def test_shutdown_executor_allows_recreation(): + """After shutdown the next get_callback_executor call creates a fresh pool.""" + backend = Dash(__name__).backend + + ex1 = backend.get_callback_executor(4) + backend.shutdown_executor(wait=False) + ex2 = backend.get_callback_executor(4) + try: + assert ex1 is not ex2 + finally: + backend.shutdown_executor(wait=False) diff --git a/tests/websocket/test_ws_threadpool.py b/tests/websocket/test_ws_threadpool.py index 9395412175..123590e403 100644 --- a/tests/websocket/test_ws_threadpool.py +++ b/tests/websocket/test_ws_threadpool.py @@ -1,6 +1,6 @@ """ WebSocket callback dispatch tests: async callbacks run on the event loop, sync -callbacks run on the per-connection threadpool. +callbacks run on the shared threadpool. Tests: - Many long-lived async (persistent-style) callbacks do not exhaust the worker From abbb592f87ed1d463d5b1d09f937cda1d708de24 Mon Sep 17 00:00:00 2001 From: philippe Date: Mon, 29 Jun 2026 16:13:33 -0400 Subject: [PATCH 4/4] typing fixes --- dash/backends/_fastapi.py | 6 +++--- dash/backends/_quart.py | 6 +++--- dash/backends/ws.py | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 22aa1d1e43..6a23c4873b 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -5,7 +5,7 @@ import concurrent.futures import json import queue -from typing import TYPE_CHECKING, Any, Callable, Dict +from typing import TYPE_CHECKING, Any, Callable, Dict, List import sys import mimetypes import hashlib @@ -263,8 +263,8 @@ def __init__(self, server: FastAPI): self.error_handling_mode = "ignore" self.request_adapter = FastAPIRequestAdapter self.response_adapter = FastAPIResponseAdapter - self._before_request_funcs = [] - self._after_request_func = None + self._before_request_funcs: List[Callable[[], Any]] = [] + self._after_request_func: Callable[[], Any] | None = None self._enable_timing = False def __call__(self, *args: Any, **kwargs: Any): diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index e7bb1c3c2f..dde98061d1 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -95,7 +95,7 @@ class QuartDashServer(BaseDashServer[Quart]): def __init__(self, server: Quart) -> None: super().__init__(server) self.server_type = "quart" - self.config = {} + self.config: Dict[str, Any] = {} self.error_handling_mode = "ignore" self.request_adapter = QuartRequestAdapter self.response_adapter = QuartResponseAdapter @@ -150,7 +150,7 @@ async def _wrap_errors(error): tb = self._get_traceback(secret, error) return Response(tb, status=500, content_type="text/html") - def register_timing_hooks(self, _first_run: bool): # type: ignore[name-defined] parity with Flask factory + def register_timing_hooks(self, _first_run: bool): # type: ignore[override] # parity with Flask factory @self.server.before_request async def _before_request(): # pragma: no cover - timing infra if quart_g is not None: @@ -382,7 +382,7 @@ def add_redirect_rule(self, app, fullname, path): ) # pylint: disable=unused-argument - def serve_callback(self, dash_app: Dash): # type: ignore[name-defined] Quart always async + def serve_callback(self, dash_app: Dash): # type: ignore[override] # Quart always async async def _dispatch(): adapter = QuartRequestAdapter() body = await adapter.get_json() diff --git a/dash/backends/ws.py b/dash/backends/ws.py index 2db1289ec6..6d0168ecda 100644 --- a/dash/backends/ws.py +++ b/dash/backends/ws.py @@ -418,9 +418,9 @@ async def shutdown_ws_connection( if async_tasks: await asyncio.gather(*async_tasks, return_exceptions=True) # Cancel any pending threadpool futures (running ones run to completion) - for f in callbacks: - if not isinstance(f, asyncio.Future): - f.cancel() + for cb in callbacks: + if not isinstance(cb, asyncio.Future): + cb.cancel() # Signal sender to shutdown and cancel it outbound_queue.sync_q.put_nowait(SHUTDOWN_SIGNAL) sender_task.cancel()