|
2 | 2 | import logging |
3 | 3 | import os |
4 | 4 | import sys |
| 5 | +from collections.abc import AsyncIterator, Callable, Coroutine |
5 | 6 | from contextlib import asynccontextmanager |
6 | 7 | from pathlib import Path |
7 | | -from typing import Literal, TextIO |
| 8 | +from typing import Any, Literal, TextIO |
8 | 9 |
|
9 | 10 | import anyio |
10 | 11 | import anyio.lowlevel |
@@ -103,6 +104,89 @@ class StdioServerParameters(BaseModel): |
103 | 104 | """ |
104 | 105 |
|
105 | 106 |
|
| 107 | +@asynccontextmanager |
| 108 | +async def _asyncio_background_tasks( |
| 109 | + stdout_reader: Callable[[], Coroutine[Any, Any, None]], |
| 110 | + stdin_writer: Callable[[], Coroutine[Any, Any, None]], |
| 111 | + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception], |
| 112 | + write_stream: MemoryObjectSendStream[SessionMessage], |
| 113 | +) -> AsyncIterator[None]: |
| 114 | + """Spawn the stdio reader/writer as top-level asyncio tasks (see #577). |
| 115 | +
|
| 116 | + The tasks are detached from the caller's cancel-scope stack, which |
| 117 | + is what lets callers clean up multiple transports in arbitrary |
| 118 | + order without tripping anyio's LIFO cancel-scope check. |
| 119 | +
|
| 120 | + If a background task crashes while the caller is still inside the |
| 121 | + yield, the memory streams are closed via ``add_done_callback`` so |
| 122 | + in-flight reads wake up with ``ClosedResourceError`` instead of |
| 123 | + hanging forever. Any non-cancellation, non-closed-resource |
| 124 | + exception from the tasks is re-raised on exit so crashes do not |
| 125 | + go unnoticed — matching the exception propagation an anyio task |
| 126 | + group would have given. |
| 127 | + """ |
| 128 | + |
| 129 | + def _on_done(task: asyncio.Task[None]) -> None: |
| 130 | + if task.cancelled(): |
| 131 | + return |
| 132 | + exc = task.exception() |
| 133 | + if exc is None: |
| 134 | + return |
| 135 | + logger.debug( |
| 136 | + "stdio_client background task raised %s — closing streams to wake up caller", |
| 137 | + type(exc).__name__, |
| 138 | + exc_info=exc, |
| 139 | + ) |
| 140 | + for stream in (read_stream_writer, write_stream): |
| 141 | + try: |
| 142 | + stream.close() |
| 143 | + except Exception: # pragma: no cover |
| 144 | + pass |
| 145 | + |
| 146 | + stdout_task: asyncio.Task[None] = asyncio.create_task(stdout_reader()) |
| 147 | + stdin_task: asyncio.Task[None] = asyncio.create_task(stdin_writer()) |
| 148 | + stdout_task.add_done_callback(_on_done) |
| 149 | + stdin_task.add_done_callback(_on_done) |
| 150 | + tasks = (stdout_task, stdin_task) |
| 151 | + try: |
| 152 | + yield |
| 153 | + finally: |
| 154 | + for task in tasks: |
| 155 | + if not task.done(): |
| 156 | + task.cancel() |
| 157 | + pending_exc: BaseException | None = None |
| 158 | + for task in tasks: |
| 159 | + try: |
| 160 | + await task |
| 161 | + except asyncio.CancelledError: |
| 162 | + pass |
| 163 | + except anyio.ClosedResourceError: |
| 164 | + pass |
| 165 | + except BaseException as exc: # noqa: BLE001 |
| 166 | + if pending_exc is None: |
| 167 | + pending_exc = exc |
| 168 | + if pending_exc is not None: |
| 169 | + raise pending_exc |
| 170 | + |
| 171 | + |
| 172 | +@asynccontextmanager |
| 173 | +async def _anyio_task_group_background( |
| 174 | + stdout_reader: Callable[[], Coroutine[Any, Any, None]], |
| 175 | + stdin_writer: Callable[[], Coroutine[Any, Any, None]], |
| 176 | +) -> AsyncIterator[None]: |
| 177 | + """Structured-concurrency fallback for backends other than asyncio. |
| 178 | +
|
| 179 | + Trio forbids orphan tasks by design, so the historical task-group |
| 180 | + pattern is retained here. Callers on trio must clean up multiple |
| 181 | + transports in LIFO order; cross-task cleanup (#577) cannot be |
| 182 | + fixed on that backend without violating its concurrency model. |
| 183 | + """ |
| 184 | + async with anyio.create_task_group() as tg: |
| 185 | + tg.start_soon(stdout_reader) |
| 186 | + tg.start_soon(stdin_writer) |
| 187 | + yield |
| 188 | + |
| 189 | + |
106 | 190 | @asynccontextmanager |
107 | 191 | async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stderr): |
108 | 192 | """Client transport for stdio: this will connect to a server by spawning a |
@@ -215,75 +299,15 @@ async def _cleanup_process_and_streams() -> None: |
215 | 299 | # design, and cross-task cleanup is fundamentally incompatible with |
216 | 300 | # that model, so callers on trio still have to clean up LIFO. |
217 | 301 | if sniffio.current_async_library() == "asyncio": |
218 | | - |
219 | | - def _on_background_task_done(task: asyncio.Task[None]) -> None: |
220 | | - """ |
221 | | - If a background reader/writer crashes while the caller is |
222 | | - still using the transport, close the memory streams so that |
223 | | - any in-flight user read wakes up with ``ClosedResourceError`` |
224 | | - instead of hanging forever. An anyio task group would have |
225 | | - produced the same effect via scope cancellation — this |
226 | | - restores that observability on the asyncio path. |
227 | | - """ |
228 | | - if task.cancelled(): |
229 | | - return |
230 | | - exc = task.exception() |
231 | | - if exc is None: |
232 | | - return |
233 | | - logger.debug( |
234 | | - "stdio_client background task raised %s — closing streams to wake up caller", |
235 | | - type(exc).__name__, |
236 | | - exc_info=exc, |
237 | | - ) |
238 | | - for stream in (read_stream_writer, write_stream): |
239 | | - try: |
240 | | - stream.close() |
241 | | - except Exception: # pragma: no cover |
242 | | - pass |
243 | | - |
244 | | - async with process: |
245 | | - stdout_task: asyncio.Task[None] = asyncio.create_task(stdout_reader()) |
246 | | - stdin_task: asyncio.Task[None] = asyncio.create_task(stdin_writer()) |
247 | | - stdout_task.add_done_callback(_on_background_task_done) |
248 | | - stdin_task.add_done_callback(_on_background_task_done) |
249 | | - background_tasks = (stdout_task, stdin_task) |
250 | | - try: |
251 | | - yield read_stream, write_stream |
252 | | - finally: |
253 | | - try: |
254 | | - await _cleanup_process_and_streams() |
255 | | - finally: |
256 | | - for task in background_tasks: |
257 | | - if not task.done(): |
258 | | - task.cancel() |
259 | | - # Collect results; swallow CancelledError (expected for |
260 | | - # teardown) and anyio.ClosedResourceError (surfaced when |
261 | | - # we closed the streams out from under the reader/writer |
262 | | - # during cleanup). Re-raise anything else so a crash in |
263 | | - # the background does not go unnoticed — matching the |
264 | | - # exception propagation we'd get from an anyio task |
265 | | - # group on the trio path. |
266 | | - pending_exc: BaseException | None = None |
267 | | - for task in background_tasks: |
268 | | - try: |
269 | | - await task |
270 | | - except asyncio.CancelledError: |
271 | | - pass |
272 | | - except anyio.ClosedResourceError: |
273 | | - pass |
274 | | - except BaseException as exc: # noqa: BLE001 |
275 | | - if pending_exc is None: |
276 | | - pending_exc = exc |
277 | | - if pending_exc is not None: |
278 | | - raise pending_exc |
| 302 | + bg_cm = _asyncio_background_tasks(stdout_reader, stdin_writer, read_stream_writer, write_stream) |
279 | 303 | else: |
280 | | - async with anyio.create_task_group() as tg, process: |
281 | | - tg.start_soon(stdout_reader) |
282 | | - tg.start_soon(stdin_writer) |
283 | | - try: |
284 | | - yield read_stream, write_stream |
285 | | - finally: |
286 | | - await _cleanup_process_and_streams() |
| 304 | + bg_cm = _anyio_task_group_background(stdout_reader, stdin_writer) |
| 305 | + |
| 306 | + async with bg_cm, process: |
| 307 | + try: |
| 308 | + yield read_stream, write_stream |
| 309 | + finally: |
| 310 | + await _cleanup_process_and_streams() |
287 | 311 |
|
288 | 312 |
|
289 | 313 | def _get_executable_command(command: str) -> str: |
|
0 commit comments