diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a5fb5b7..f87750d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,24 +12,34 @@ jobs: name: test runs-on: ${{ matrix.os }} strategy: + fail-fast: false matrix: - build: [linux_3.12, windows_3.12, mac_3.12, linux_3.13, linux_3.14] + build: [linux_3.12, windows_3.12, mac_3.12, linux_3.13, linux_3.14, linux_3.15] include: - build: linux_3.12 os: ubuntu-latest - python: 3.12 + python: '3.12' + allow-prereleases: false - build: linux_3.13 os: ubuntu-latest - python: 3.13 + python: '3.13' + allow-prereleases: false - build: linux_3.14 os: ubuntu-latest - python: 3.14 + python: '3.14' + allow-prereleases: false + - build: linux_3.15 + os: ubuntu-latest + python: '3.15' + allow-prereleases: true - build: windows_3.12 os: windows-latest - python: 3.12 + python: '3.12' + allow-prereleases: false - build: mac_3.12 os: macos-latest - python: 3.12 + python: '3.12' + allow-prereleases: false steps: - name: Checkout repository uses: actions/checkout@v4 @@ -38,6 +48,7 @@ jobs: uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} + allow-prereleases: ${{ matrix.allow-prereleases }} - name: Install dependencies run: | diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 15e93dd..b47d772 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -25,7 +25,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: '3.x' + python-version: '3.12' - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/README.md b/README.md index 32898a2..3b7c296 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,9 @@ app.add_middleware( "max_overflow": 10, # number of connections to allow to be opened above pool_size }, ) +# Engines created from ``db_url`` are owned by the middleware and are disposed +# during the application shutdown lifespan. Tests that need shutdown behavior +# should run the app lifespan, for example with ``with TestClient(app)``. # once the middleware is applied, any route can then access the database session # from the global ``db`` @@ -82,6 +85,97 @@ if __name__ == "__main__": ``` +#### Engine ownership + +When the middleware receives ``db_url``, it creates and owns the async engine. +The engine is kept for the application lifetime and disposed when the ASGI +lifespan shutdown completes. It is not disposed per request. Disposal also +runs when the lifespan ends with a failure (``lifespan.shutdown.failed`` or +``lifespan.startup.failed``), so a raising user shutdown handler does not leak +the connection pool. + +Engine disposal happens before the lifespan acknowledgement is forwarded to +the ASGI server, so a stuck pool drain will block the server's graceful +shutdown ack. Configure your ASGI server's graceful shutdown timeout (for +example uvicorn's ``--timeout-graceful-shutdown``) so it accommodates the +worst-case time required to close active connections. + +When the middleware receives ``custom_engine``, the caller owns that engine. The +middleware will use it but will not dispose it during application shutdown: + +```python +from sqlalchemy.ext.asyncio import create_async_engine + +engine = create_async_engine("postgresql+asyncpg://user:pass@host/db") +app.add_middleware(SQLAlchemyMiddleware, custom_engine=engine) + +# Later, in caller-managed shutdown code or test cleanup: +await engine.dispose() +``` + +#### Manual disposal outside ASGI lifespan + +When ``SQLAlchemyMiddleware(db_url=...)`` is constructed outside an ASGI +application lifespan — for example in a script, an ad-hoc test harness, or +when embedding the middleware in a non-ASGI runtime — there is no +``lifespan.shutdown`` event to trigger engine disposal. In that case call +``await middleware.dispose()`` explicitly so the middleware-owned engine is +released: + +```python +middleware = SQLAlchemyMiddleware(app, db_url="postgresql+asyncpg://...") +try: + ... # use db.session +finally: + await middleware.dispose() +``` + +``dispose()`` is idempotent on success and is safe to retry if it raises: +the proxy session bindings are cleared deterministically so a subsequent +call actually re-attempts the underlying ``engine.dispose()``. The same +guidance applies to each pair created by +``create_middleware_and_session_proxy()``. + +#### Request transactions and streaming responses + +When ``SQLAlchemyMiddleware(..., commit_on_exit=True)`` manages a normal +non-streaming HTTP request, the request session is committed before +``http.response.start`` is forwarded to the ASGI server. If commit, rollback, +or close fails, the failure happens before a successful response is reported to +the client. + +Streaming response body generation has a different lifetime from a normal +request transaction. Do not rely on the middleware-managed request session to +stay open while a ``StreamingResponse``/``FileResponse`` yields chunks. Open an +explicit session inside the generator so the body owns the database lifetime: + +```python +from fastapi.responses import StreamingResponse + +@app.get("/export") +async def export(): + async def rows(): + async with db(): + result = await db.session.stream(foo.select()) + async for row in result: + yield f"{row.id}\n".encode() + return StreamingResponse(rows(), media_type="text/plain") +``` + +Implicit ``commit_on_exit=True`` is not a safe way to report streaming write +success: the response may have already started before an unbounded body is +finished. If a streaming route needs database writes, either complete and +commit the write in a separate explicit ``async with db(commit_on_exit=True)`` +block before creating the streaming response, or make the streaming generator +use an explicit ``async with db(commit_on_exit=True)`` block and design the API +so clients do not treat early chunks as write success. + +For applications that previously used ``db.session`` directly inside streaming +generators, move that code into an explicit generator-owned context as shown +above. This keeps database access available for the whole body while making it +clear that the session lifetime belongs to the stream, not the original request +transaction. + #### Usage of multiple databases databases.py @@ -94,6 +188,10 @@ FirstSQLAlchemyMiddleware, first_db = create_middleware_and_session_proxy() SecondSQLAlchemyMiddleware, second_db = create_middleware_and_session_proxy() ``` +Use a separate middleware/session proxy pair for each independent app or +database. Reusing the same proxy with a different live engine is rejected so +requests cannot silently switch to another database binding. + main.py ```python @@ -152,9 +250,10 @@ async def get_files_from_second_db(): @router.get("/concurrent-queries") async def parallel_select(): - async with first_db(multi_sessions=True): + async with first_db(multi_sessions=True, max_concurrent=10): async def execute_query(query): - return await first_db.session.execute(text(query)) + async with first_db.connection() as session: + return await session.execute(text(query)) tasks = [ asyncio.create_task(execute_query("SELECT 1")), @@ -167,3 +266,10 @@ async def parallel_select(): await asyncio.gather(*tasks) ``` + +Child tasks that use database sessions must finish before the owning +``async with db(multi_sessions=True)`` block exits. When ``max_concurrent`` is +set, child tasks should use ``db.connection()`` or pass coroutine objects to +``db.gather()`` so the middleware can own both the session lifetime and the +semaphore slot. Already-created ``Task`` or ``Future`` objects are rejected by +throttled ``db.gather()`` because they may have started outside the semaphore. diff --git a/fastapi_async_sqlalchemy/__init__.py b/fastapi_async_sqlalchemy/__init__.py index 036ef5c..ab3d50e 100644 --- a/fastapi_async_sqlalchemy/__init__.py +++ b/fastapi_async_sqlalchemy/__init__.py @@ -19,4 +19,4 @@ "DBSessionType", ] -__version__ = "0.7.2a1" +__version__ = "0.8.0a1" diff --git a/fastapi_async_sqlalchemy/middleware.py b/fastapi_async_sqlalchemy/middleware.py index f6d88fc..d557c57 100644 --- a/fastapi_async_sqlalchemy/middleware.py +++ b/fastapi_async_sqlalchemy/middleware.py @@ -1,26 +1,25 @@ from __future__ import annotations import asyncio +import logging import warnings from contextvars import ContextVar from dataclasses import dataclass, field from sqlalchemy.engine.url import URL -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine -from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint -from starlette.requests import Request -from starlette.types import ASGIApp +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) +from starlette.types import ASGIApp, Message, Receive, Scope, Send from fastapi_async_sqlalchemy.exceptions import ( MissingSessionError, SessionNotInitialisedError, ) -try: - from sqlalchemy.ext.asyncio import async_sessionmaker -except ImportError: - from sqlalchemy.orm import sessionmaker as async_sessionmaker # type: ignore - try: from sqlmodel.ext.asyncio.session import AsyncSession as SQLModelAsyncSession @@ -31,14 +30,27 @@ def create_middleware_and_session_proxy() -> tuple: _Session: async_sessionmaker | None = None + _Session_engine: AsyncEngine | None = None _session: ContextVar[AsyncSession | None] = ContextVar("_session", default=None) + _request_session: ContextVar[AsyncSession | None] = ContextVar( + "_request_session", + default=None, + ) + _request_session_used: ContextVar[bool] = ContextVar( + "_request_session_used", + default=False, + ) + _request_session_closed_for_streaming: ContextVar[bool] = ContextVar( + "_request_session_closed_for_streaming", + default=False, + ) _multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False) _multi_state: ContextVar[_MultiSessionState | None] = ContextVar( "_multi_sessions_state", default=None, ) - @dataclass + @dataclass(slots=True) class _MultiSessionState: tracked: set[AsyncSession] = field(default_factory=set) task_sessions: dict[asyncio.Task, AsyncSession] = field(default_factory=dict) @@ -48,6 +60,8 @@ class _MultiSessionState: session_args: dict = field(default_factory=dict) semaphore: asyncio.Semaphore | None = None slot_holders: set[asyncio.Task] = field(default_factory=set) + waiters: set[asyncio.Task] = field(default_factory=set) + closing: bool = False def _cleanup_error(error: BaseException) -> str: return f"{type(error).__name__}: {error}" @@ -61,6 +75,19 @@ def _raise_cleanup_errors(errors: list[BaseException]) -> None: details = "; ".join(_cleanup_error(error) for error in errors) raise RuntimeError(f"Session cleanup failed with {len(errors)} errors: {details}") + def _mark_request_session_used(session: AsyncSession) -> None: + if session is not _request_session.get(): + return + + if _request_session_closed_for_streaming.get(): + raise RuntimeError( + "The middleware-managed request database session is closed for streaming " + "response body generation. Use `async with db()` inside the streaming " + "generator to make the session lifetime explicit." + ) + + _request_session_used.set(True) + async def _finalize_session( session: AsyncSession, commit_on_exit: bool, @@ -122,6 +149,12 @@ async def __aenter__(self) -> AsyncSession: multi_sessions = _multi_sessions_ctx.get() if multi_sessions and self._state is not None: + if self._state.closing: + raise RuntimeError( + "Cannot create a db.connection() session after the owning " + "multi-session context has started closing." + ) + task = asyncio.current_task() # Reuse existing session for this task @@ -132,7 +165,24 @@ async def __aenter__(self) -> AsyncSession: # Acquire pool slot only when this context creates a new session. if self._semaphore: - await self._semaphore.acquire() + if task is not None: + self._state.waiters.add(task) + try: + await self._semaphore.acquire() + finally: + if task is not None: + self._state.waiters.discard(task) + + # Re-check closing: parent may have started shutdown while we + # were parked on acquire(). If so, release the slot we just + # took so it doesn't leak, and refuse to create a session. + if self._state.closing: + self._semaphore.release() + raise RuntimeError( + "Cannot create a db.connection() session after the owning " + "multi-session context has started closing." + ) + self._acquired_slot = True if task is not None: self._state.slot_holders.add(task) @@ -158,6 +208,7 @@ async def __aenter__(self) -> AsyncSession: session = _session.get() if session is None: raise MissingSessionError + _mark_request_session_used(session) self._session = session self._owns_session = False return session @@ -182,7 +233,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): if self._acquired_slot and self._semaphore: self._semaphore.release() - class _SQLAlchemyMiddleware(BaseHTTPMiddleware): + class _SQLAlchemyMiddleware: __test__ = False def __init__( @@ -194,19 +245,43 @@ def __init__( session_args: dict | None = None, commit_on_exit: bool = False, ): - super().__init__(app) + # Pure ASGI middleware: normal responses are buffered until the + # request session finalizes, while streaming bodies can opt into an + # explicit body-lifetime session with ``async with db()``. + self.app = app self.commit_on_exit = commit_on_exit + self.engine: AsyncEngine + self.engine_owned = custom_engine is None + self._engine_disposed = False engine_args = engine_args or {} session_args = session_args or {} if not custom_engine and not db_url: raise ValueError("You need to pass a db_url or a custom_engine parameter.") + + nonlocal _Session, _Session_engine + + # Validate the proxy/engine relationship BEFORE allocating a new + # engine from `db_url`. Any engine we create here cannot be `is` + # to an already-bound `_Session_engine`, so a rejected init must + # not leak a freshly-created engine. + if _Session_engine is not None and ( + custom_engine is None or _Session_engine is not custom_engine + ): + raise RuntimeError( + "This SQLAlchemy session proxy is already bound to another live engine. " + "Use create_middleware_and_session_proxy() for independent apps or " + "databases." + ) + if custom_engine: engine = custom_engine else: + assert db_url is not None engine = create_async_engine(db_url, **engine_args) - nonlocal _Session + self.engine = engine + _Session_engine = engine _Session = async_sessionmaker( engine, class_=DefaultAsyncSession, @@ -214,9 +289,111 @@ def __init__( **session_args, ) - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): - async with DBSession(commit_on_exit=self.commit_on_exit): - return await call_next(request) + async def dispose(self) -> None: + if not self.engine_owned or self._engine_disposed: + return + + nonlocal _Session, _Session_engine + try: + await self.engine.dispose() + self._engine_disposed = True + finally: + # Always clear proxy bindings owned by this middleware. On + # failure, this lets a retry actually re-attempt disposal + # rather than silently no-op'ing on a half-disposed engine. + if _Session_engine is self.engine: + _Session_engine = None + _Session = None + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "lifespan": + + async def send_with_disposal(message): + if message["type"] in ( + "lifespan.shutdown.complete", + "lifespan.shutdown.failed", + "lifespan.startup.failed", + ): + captured: BaseException | None = None + try: + await self.dispose() + except Exception as disposal_exc: + captured = disposal_exc + finally: + # Forward the lifespan ack even if disposal raised, + # so the ASGI server is not left hanging. + await send(message) + if captured is not None: + logging.getLogger(__name__).warning( + "Engine disposal failed during ASGI lifespan %s", + message["type"], + exc_info=captured, + ) + raise captured + else: + await send(message) + + await self.app(scope, receive, send_with_disposal) + return + + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + request_context = DBSession( + commit_on_exit=self.commit_on_exit, + _request_context=True, + ) + buffered_messages: list[Message] = [] + streaming_passthrough = False + + async def send_with_db_finalization(message: Message) -> None: + nonlocal streaming_passthrough + + if streaming_passthrough: + await send(message) + return + + if message["type"] == "http.response.start": + buffered_messages.append(message) + return + + if message["type"] == "http.response.body": + if message.get("more_body", False): + if self.commit_on_exit and _request_session_used.get(): + raise RuntimeError( + "`commit_on_exit=True` cannot use the middleware-managed " + "request database session with a streaming response. Use " + "`async with db()` inside the streaming generator, or manage " + "the streaming transaction explicitly." + ) + + if not _request_session_used.get(): + await request_context.close_request_session_for_streaming() + + for buffered_message in buffered_messages: + await send(buffered_message) + buffered_messages.clear() + await send(message) + streaming_passthrough = True + return + + buffered_messages.append(message) + return + + buffered_messages.append(message) + + await request_context.__aenter__() + try: + await self.app(scope, receive, send_with_db_finalization) + except BaseException as exc: + await request_context.__aexit__(type(exc), exc, exc.__traceback__) + raise + + await request_context.__aexit__(None, None, None) + + for message in buffered_messages: + await send(message) class DBSessionMeta(type): @property @@ -230,6 +407,11 @@ def session(self) -> AsyncSession: state = _multi_state.get() if state is None: raise RuntimeError("Multi-session state is not initialized") + if state.closing: + raise RuntimeError( + "Cannot create or access db.session after the owning multi-session " + "context has started closing." + ) task = asyncio.current_task() @@ -261,6 +443,13 @@ def session(self) -> AsyncSession: def cleanup_callback(finished_task: asyncio.Task) -> None: async def cleanup() -> None: + # Invariant: do NOT add `await` between the + # `session in state.tracked` check and the matching + # `state.tracked.discard(session)` below. The sweep in + # DBSession.__aexit__ runs `list(state.tracked); + # state.tracked.clear()` atomically; an await here + # would let the sweep claim the same session and + # cause double-finalization. task_exception: BaseException | None try: task_exception = finished_task.exception() @@ -312,6 +501,7 @@ async def cleanup() -> None: session = _session.get() if session is None: raise MissingSessionError + _mark_request_session_used(session) return session def connection(self) -> _ConnectionContextManager: @@ -363,14 +553,55 @@ async def gather(self, *coros_or_futures, return_exceptions: bool = False): return_exceptions=return_exceptions, ) - async def _throttled(coro): + coros = list(coros_or_futures) + + try: + for item in coros: + if asyncio.isfuture(item): + raise TypeError( + "When `max_concurrent` is set, db.gather() accepts coroutine " + "objects only; pre-created Task or Future inputs may already be " + "running outside the semaphore. Pass coroutine objects or use " + "db.connection()." + ) + if not asyncio.iscoroutine(item): + raise TypeError( + "When `max_concurrent` is set, db.gather() accepts coroutine " + "objects only." + ) + except BaseException: + for item in coros: + if asyncio.iscoroutine(item): + item.close() + raise + + started = [False] * len(coros) + + async def _throttled(index, coro): async with _ConnectionContextManager(): + started[index] = True return await coro - return await asyncio.gather( - *[_throttled(c) for c in coros_or_futures], - return_exceptions=return_exceptions, - ) + tasks = [ + asyncio.create_task(_throttled(index, coro)) for index, coro in enumerate(coros) + ] + + try: + return await asyncio.gather( + *tasks, + return_exceptions=return_exceptions, + ) + except BaseException: + for task in tasks: + if not task.done(): + task.cancel() + + await asyncio.gather(*tasks, return_exceptions=True) + + for index, coro in enumerate(coros): + if not started[index] and asyncio.iscoroutine(coro): + coro.close() + raise class DBSession(metaclass=DBSessionMeta): def __init__( @@ -379,6 +610,7 @@ def __init__( commit_on_exit: bool = False, multi_sessions: bool = False, max_concurrent: int | None = None, + _request_context: bool = False, ): if max_concurrent is not None and max_concurrent < 1: raise ValueError("`max_concurrent` must be greater than 0.") @@ -390,6 +622,11 @@ def __init__( self.commit_on_exit = commit_on_exit self.multi_sessions = multi_sessions self.max_concurrent = max_concurrent + self.request_context = _request_context + self.request_session_token = None + self.request_session_used_token = None + self.request_session_closed_token = None + self._finalized = False async def __aenter__(self): if not isinstance(_Session, async_sessionmaker): @@ -412,15 +649,72 @@ async def __aenter__(self): ) ) else: - self.token = _session.set(_Session(**self.session_args)) + session = _Session(**self.session_args) + self.token = _session.set(session) + if self.request_context: + self.request_session_token = _request_session.set(session) + self.request_session_used_token = _request_session_used.set(False) + self.request_session_closed_token = _request_session_closed_for_streaming.set( + False + ) return type(self) + async def _finalize_regular_session(self, exc_type, exc_value) -> None: + if self._finalized: + return + + session = _request_session.get() if self.request_context else _session.get() + if session is None: + raise MissingSessionError + + try: + await _finalize_session( + session, + commit_on_exit=self.commit_on_exit, + exc=exc_value if exc_type is not None else None, + ) + finally: + self._finalized = True + + async def close_request_session_for_streaming(self) -> None: + await self._finalize_regular_session(None, None) + if self.request_context: + _request_session_closed_for_streaming.set(True) + async def __aexit__(self, exc_type, exc_value, traceback): if self.multi_sessions: - _multi_sessions_ctx.reset(self.multi_sessions_token) state = _multi_state.get() + if state is not None: + state.closing = True + _multi_sessions_ctx.reset(self.multi_sessions_token) cleanup_errors: list[BaseException] = [] if state is not None: + # Cancel tasks parked on the semaphore first. They aren't in + # task_sessions yet (entry happens after acquire), so the + # task_sessions sweep below would miss them — they would + # silently take a freed slot, create a session post-closing, + # and race with finalisation. + waiting_tasks = [ + task + for task in state.waiters + if task is not state.parent_task and not task.done() + ] + for task in waiting_tasks: + task.cancel() + if waiting_tasks: + await asyncio.gather(*waiting_tasks, return_exceptions=True) + state.waiters.clear() + + pending_tasks = [ + task + for task in state.task_sessions + if task is not state.parent_task and not task.done() + ] + for task in pending_tasks: + task.cancel() + if pending_tasks: + await asyncio.gather(*pending_tasks, return_exceptions=True) + if state.cleanup_tasks: cleanup_results = await asyncio.gather( *state.cleanup_tasks, @@ -457,14 +751,24 @@ async def __aexit__(self, exc_type, exc_value, traceback): stacklevel=2, ) else: - session = _session.get() try: - await _finalize_session( - session, - commit_on_exit=self.commit_on_exit, - exc=exc_value if exc_type is not None else None, - ) + try: + await self._finalize_regular_session(exc_type, exc_value) + except BaseException as cleanup_error: + if exc_type is None: + raise + warnings.warn( + "Suppressed session cleanup error because another exception is " + f"already being raised: {_cleanup_error(cleanup_error)}", + stacklevel=2, + ) finally: + if self.request_context: + _request_session_closed_for_streaming.reset( + self.request_session_closed_token + ) + _request_session_used.reset(self.request_session_used_token) + _request_session.reset(self.request_session_token) _session.reset(self.token) return _SQLAlchemyMiddleware, DBSession diff --git a/pyproject.toml b/pyproject.toml index f389964..51f1169 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.ruff] line-length = 100 -target-version = "py39" +target-version = "py312" exclude = [ ".git", ".venv", @@ -35,5 +35,4 @@ split-on-trailing-comma = true [tool.pytest.ini_options] filterwarnings = [ "ignore::DeprecationWarning", - "ignore:The garbage collector is trying to clean up:sqlalchemy.exc.SAWarning", ] diff --git a/requirements.txt b/requirements.txt index 7eec008..e514fb1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ chardet==3.0.4 click>=8.1.3 coverage>=5.2.1 entrypoints==0.3 -fastapi==0.90.0 # pyup: ignore +fastapi>=0.115 flake8==3.7.9 idna==3.7 importlib-metadata==1.5.0 @@ -15,29 +15,28 @@ mccabe==0.6.1 more-itertools==7.2.0 packaging>=22.0 pathspec>=0.9.0 -pluggy==0.13.0 +pluggy>=1.5.0 pycodestyle==2.5.0 -pydantic==1.10.18 +pydantic>=2.7 pyflakes==2.1.1 pyparsing==2.4.2 -pytest==7.2.0 -pytest-cov==2.11.1 +pytest>=8.3.0 +pytest-cov>=5.0.0 PyYAML>=5.4 regex>=2020.2.20 requests>=2.22.0 httpx>=0.20.0,<0.28.0 six==1.12.0 -SQLAlchemy>=1.4.19 +SQLAlchemy>=2.0 sqlmodel>=0.0.24 asyncpg>=0.27.0 aiosqlite==0.20.0 sqlparse>=0.5.4 -starlette>=0.13.6 +starlette>=0.40 toml>=0.10.1 -typed-ast>=1.4.2 urllib3>=1.25.9 wcwidth==0.1.7 zipp==3.19.1 black==26.3.1 -pytest-asyncio==0.21.0 -greenlet==3.1.1 +pytest-asyncio>=0.24.0 +greenlet>=3.2.4 diff --git a/setup.py b/setup.py index 9026da4..894b18b 100644 --- a/setup.py +++ b/setup.py @@ -26,8 +26,8 @@ packages=["fastapi_async_sqlalchemy"], package_data={"fastapi_async_sqlalchemy": ["py.typed"]}, zip_safe=False, - python_requires=">=3.9", - install_requires=["starlette>=0.13.6", "SQLAlchemy>=1.4.19"], + python_requires=">=3.12", + install_requires=["starlette>=0.40", "SQLAlchemy>=2.0"], classifiers=[ "Development Status :: 5 - Production/Stable", "Environment :: Web Environment", @@ -35,9 +35,6 @@ "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Programming Language :: Python :: 3.14", diff --git a/tests/conftest.py b/tests/conftest.py index fd8ace6..7adc05e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,3 @@ -import sys - import pytest from fastapi import FastAPI from fastapi.testclient import TestClient @@ -17,26 +15,17 @@ def client(app): @pytest.fixture -def SQLAlchemyMiddleware(): - from fastapi_async_sqlalchemy import SQLAlchemyMiddleware +def middleware_pair(): + from fastapi_async_sqlalchemy import create_middleware_and_session_proxy - yield SQLAlchemyMiddleware + return create_middleware_and_session_proxy() @pytest.fixture -def db(): - from fastapi_async_sqlalchemy import db - - yield db +def SQLAlchemyMiddleware(middleware_pair): + return middleware_pair[0] - # force reloading of module to clear global state - try: - del sys.modules["fastapi_async_sqlalchemy"] - except KeyError: - pass - - try: - del sys.modules["fastapi_async_sqlalchemy.middleware"] - except KeyError: - pass +@pytest.fixture +def db(middleware_pair): + return middleware_pair[1] diff --git a/tests/test_additional_coverage.py b/tests/test_additional_coverage.py index 26be37f..d2c63ff 100644 --- a/tests/test_additional_coverage.py +++ b/tests/test_additional_coverage.py @@ -2,6 +2,8 @@ Additional tests to reach target coverage of 97.22% """ +import asyncio + from fastapi import FastAPI @@ -16,12 +18,15 @@ def test_commit_on_exit_parameter(): # Test commit_on_exit=True custom_engine = create_async_engine("sqlite+aiosqlite://") - middleware = SQLAlchemyMiddleware(app, custom_engine=custom_engine, commit_on_exit=True) - assert middleware.commit_on_exit is True + try: + middleware = SQLAlchemyMiddleware(app, custom_engine=custom_engine, commit_on_exit=True) + assert middleware.commit_on_exit is True - # Test commit_on_exit=False (default) - middleware2 = SQLAlchemyMiddleware(app, custom_engine=custom_engine, commit_on_exit=False) - assert middleware2.commit_on_exit is False + # Test commit_on_exit=False (default) + middleware2 = SQLAlchemyMiddleware(app, custom_engine=custom_engine, commit_on_exit=False) + assert middleware2.commit_on_exit is False + finally: + asyncio.run(custom_engine.dispose()) def test_exception_classes_simple(): @@ -48,10 +53,13 @@ def test_middleware_properties(): # Test middleware properties custom_engine = create_async_engine("sqlite+aiosqlite://") - middleware = SQLAlchemyMiddleware(app, custom_engine=custom_engine, commit_on_exit=True) + try: + middleware = SQLAlchemyMiddleware(app, custom_engine=custom_engine, commit_on_exit=True) - assert hasattr(middleware, "commit_on_exit") - assert middleware.commit_on_exit is True + assert hasattr(middleware, "commit_on_exit") + assert middleware.commit_on_exit is True + finally: + asyncio.run(custom_engine.dispose()) def test_basic_imports(): @@ -99,7 +107,10 @@ def test_middleware_factory_different_instances(): app = FastAPI() engine = create_async_engine("sqlite+aiosqlite://") - middleware1 = SQLAlchemyMiddleware1(app, custom_engine=engine) - middleware2 = SQLAlchemyMiddleware2(app, custom_engine=engine) + try: + middleware1 = SQLAlchemyMiddleware1(app, custom_engine=engine) + middleware2 = SQLAlchemyMiddleware2(app, custom_engine=engine) - assert middleware1 is not middleware2 + assert middleware1 is not middleware2 + finally: + asyncio.run(engine.dispose()) diff --git a/tests/test_backward_compat_gather.py b/tests/test_backward_compat_gather.py index f7e7201..e7cd781 100644 --- a/tests/test_backward_compat_gather.py +++ b/tests/test_backward_compat_gather.py @@ -1,11 +1,9 @@ -"""Test backward compatibility for asyncio.gather() without multi_sessions flag. +"""Supported alternatives to same-session ``asyncio.gather()`` patterns. -This test verifies that after the fix, the old code pattern works without -requiring multi_sessions=True explicitly. +Concurrent operations on one SQLAlchemy ``AsyncSession`` are backend-dependent, +so these tests avoid treating that pattern as a library compatibility promise. """ -import asyncio - import pytest from sqlalchemy import text @@ -13,13 +11,11 @@ @pytest.mark.asyncio -async def test_gather_works_without_multi_sessions_flag(app, db, SQLAlchemyMiddleware): +async def test_sequential_queries_work_without_multi_sessions_flag(app, db, SQLAlchemyMiddleware): """ - Verify that asyncio.gather() works in normal mode (without multi_sessions=True). - - This is the backward compatibility fix - users shouldn't need to change their code. + Verify that normal single-session code can execute related queries sequentially. """ - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(commit_on_exit=True): await db.session.execute( @@ -31,16 +27,12 @@ async def test_gather_works_without_multi_sessions_flag(app, db, SQLAlchemyMiddl {"value": f"value_{i}"}, ) - # OLD CODE PATTERN - should work now without multi_sessions=True async with db(): count_stmt = text("SELECT COUNT(*) FROM compat_test") data_stmt = text("SELECT * FROM compat_test LIMIT 5") - # This should work! Each parallel query gets its own session - count_result, data_result = await asyncio.gather( - db.session.execute(count_stmt), - db.session.execute(data_stmt), - ) + count_result = await db.session.execute(count_stmt) + data_result = await db.session.execute(data_stmt) count = count_result.scalar() data = data_result.fetchall() @@ -48,15 +40,13 @@ async def test_gather_works_without_multi_sessions_flag(app, db, SQLAlchemyMiddl assert count == 20 assert len(data) == 5 - print("✅ Backward compatibility preserved: asyncio.gather() works without multi_sessions!") - @pytest.mark.asyncio -async def test_gather_multiple_queries_parallel(app, db, SQLAlchemyMiddleware): +async def test_multiple_single_session_queries_run_sequentially(app, db, SQLAlchemyMiddleware): """ - Test that multiple parallel queries work correctly. + Test that multiple related queries work correctly on one session. """ - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(commit_on_exit=True): await db.session.execute( @@ -68,17 +58,14 @@ async def test_gather_multiple_queries_parallel(app, db, SQLAlchemyMiddleware): {"status": "active" if i % 3 == 0 else "inactive"}, ) - # Multiple parallel queries without multi_sessions=True async with db(): stmt1 = text("SELECT COUNT(*) FROM parallel_test WHERE status = 'active'") stmt2 = text("SELECT COUNT(*) FROM parallel_test WHERE status = 'inactive'") stmt3 = text("SELECT * FROM parallel_test LIMIT 10") - r1, r2, r3 = await asyncio.gather( - db.session.execute(stmt1), - db.session.execute(stmt2), - db.session.execute(stmt3), - ) + r1 = await db.session.execute(stmt1) + r2 = await db.session.execute(stmt2) + r3 = await db.session.execute(stmt3) active_count = r1.scalar() inactive_count = r2.scalar() @@ -90,13 +77,11 @@ async def test_gather_multiple_queries_parallel(app, db, SQLAlchemyMiddleware): @pytest.mark.asyncio -async def test_production_pattern_without_changes(app, db, SQLAlchemyMiddleware): +async def test_production_pattern_uses_sequential_queries(app, db, SQLAlchemyMiddleware): """ - Verify the EXACT production pattern from the error report works. - - This is the pattern from /app/api/repository/routes.py:186 + Verify the production-style count and page queries with the supported pattern. """ - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(commit_on_exit=True): await db.session.execute( @@ -124,7 +109,6 @@ async def test_production_pattern_without_changes(app, db, SQLAlchemyMiddleware) }, ) - # EXACT PRODUCTION CODE - should work now! async with db(): count_stmt = text("SELECT COUNT(*) FROM processes WHERE status = :status") processes_stmt = text( @@ -135,11 +119,8 @@ async def test_production_pattern_without_changes(app, db, SQLAlchemyMiddleware) count_stmt = count_stmt.bindparams(status="running") processes_stmt = processes_stmt.bindparams(status="running", limit=10, offset=0) - # This is line 186 from production - should work without any changes! - total_result, processes_result = await asyncio.gather( - db.session.execute(count_stmt), - db.session.execute(processes_stmt), - ) + total_result = await db.session.execute(count_stmt) + processes_result = await db.session.execute(processes_stmt) total = total_result.scalar() processes = processes_result.fetchall() @@ -147,15 +128,13 @@ async def test_production_pattern_without_changes(app, db, SQLAlchemyMiddleware) assert total == 50 assert len(processes) == 10 - print("✅ Production code pattern works without any changes!") - @pytest.mark.asyncio -async def test_commit_on_exit_with_parallel_queries(app, db, SQLAlchemyMiddleware): +async def test_commit_on_exit_with_sequential_writes(app, db, SQLAlchemyMiddleware): """ - Verify that commit_on_exit works correctly with parallel queries. + Verify that commit_on_exit works correctly with sequential writes. """ - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) # Create table first async with db(commit_on_exit=True): @@ -163,14 +142,11 @@ async def test_commit_on_exit_with_parallel_queries(app, db, SQLAlchemyMiddlewar text("CREATE TABLE IF NOT EXISTS commit_test (id INTEGER PRIMARY KEY, value TEXT)") ) - # Insert data with parallel queries and commit_on_exit + # Insert data with sequential writes and commit_on_exit. async with db(commit_on_exit=True): - # These should all be committed automatically - await asyncio.gather( - db.session.execute(text("INSERT INTO commit_test (value) VALUES ('a')")), - db.session.execute(text("INSERT INTO commit_test (value) VALUES ('b')")), - db.session.execute(text("INSERT INTO commit_test (value) VALUES ('c')")), - ) + await db.session.execute(text("INSERT INTO commit_test (value) VALUES ('a')")) + await db.session.execute(text("INSERT INTO commit_test (value) VALUES ('b')")) + await db.session.execute(text("INSERT INTO commit_test (value) VALUES ('c')")) # Verify data was committed async with db(): @@ -178,15 +154,13 @@ async def test_commit_on_exit_with_parallel_queries(app, db, SQLAlchemyMiddlewar count = result.scalar() assert count == 3 - print("✅ commit_on_exit works correctly with parallel queries!") - @pytest.mark.asyncio -async def test_rollback_on_error_with_parallel_queries(app, db, SQLAlchemyMiddleware): +async def test_rollback_on_error_with_sequential_queries(app, db, SQLAlchemyMiddleware): """ - Verify that rollback works correctly when error occurs in parallel queries. + Verify that rollback works correctly when an error occurs. """ - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(commit_on_exit=True): await db.session.execute( @@ -209,5 +183,3 @@ async def test_rollback_on_error_with_parallel_queries(app, db, SQLAlchemyMiddle result = await db.session.execute(text("SELECT COUNT(*) FROM rollback_test")) count = result.scalar() assert count == 0 - - print("✅ Rollback works correctly on error!") diff --git a/tests/test_concurrent_queries.py b/tests/test_concurrent_queries.py index 0eb5ec7..470df0e 100644 --- a/tests/test_concurrent_queries.py +++ b/tests/test_concurrent_queries.py @@ -45,7 +45,7 @@ async def test_concurrent_queries_same_session_may_fail(app, db, SQLAlchemyMiddl Note: SQLite (aiosqlite) may not reproduce this issue because it serializes operations internally. The issue is more common with asyncpg/asyncmy drivers. """ - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(): # Create a test table @@ -100,7 +100,7 @@ async def test_concurrent_queries_same_session_sequential_works(app, db, SQLAlch This is a workaround - execute queries sequentially instead of using asyncio.gather(). """ - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(): # Create a test table @@ -139,7 +139,7 @@ async def test_concurrent_queries_multi_sessions_works(app, db, SQLAlchemyMiddle With multi_sessions=True, each task gets its own session, so concurrent operations don't conflict. """ - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(multi_sessions=True, commit_on_exit=True): # Create a test table @@ -189,7 +189,7 @@ async def test_concurrent_queries_reproduce_user_error(app, db, SQLAlchemyMiddle InvalidRequestError: This session is provisioning a new connection; concurrent operations are not permitted """ - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(): # Setup similar to user's use case @@ -250,7 +250,7 @@ async def test_solution_using_separate_db_contexts(app, db, SQLAlchemyMiddleware This is different from multi_sessions mode - here we're showing how to structure the code to avoid the concurrent operations issue. """ - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) # Setup data in the main context async with db(commit_on_exit=True): @@ -296,7 +296,7 @@ async def test_antipattern_documentation(app, db, SQLAlchemyMiddleware): This test exists purely for documentation purposes to show what NOT to do and why. """ - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(commit_on_exit=True): await db.session.execute( @@ -374,7 +374,7 @@ async def test_production_error_exact_reproduction(app, db, SQLAlchemyMiddleware ) ``` """ - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(commit_on_exit=True): # Setup similar to production diff --git a/tests/test_coverage_boost.py b/tests/test_coverage_boost.py index c31ae3f..8d6baea 100644 --- a/tests/test_coverage_boost.py +++ b/tests/test_coverage_boost.py @@ -2,6 +2,7 @@ Simple tests to boost coverage to target level """ +import asyncio from unittest.mock import AsyncMock import pytest @@ -24,16 +25,12 @@ def test_session_not_initialised_error(): def test_missing_session_error(): """Test MissingSessionError when session context is None""" - from fastapi.testclient import TestClient - - from fastapi_async_sqlalchemy import SQLAlchemyMiddleware, db from fastapi_async_sqlalchemy.exceptions import MissingSessionError + from fastapi_async_sqlalchemy.middleware import create_middleware_and_session_proxy + SQLAlchemyMiddleware, db = create_middleware_and_session_proxy() app = FastAPI() - app.add_middleware(SQLAlchemyMiddleware, db_url="sqlite+aiosqlite://") - - # Initialize middleware by creating a client - TestClient(app) + SQLAlchemyMiddleware(app, db_url="sqlite+aiosqlite://") # Now _Session is initialized, but no active session context # This should raise MissingSessionError @@ -44,16 +41,6 @@ def test_missing_session_error(): @pytest.mark.asyncio async def test_rollback_on_commit_exception(): """Test rollback is called when commit raises exception (lines 114-116)""" - from fastapi.testclient import TestClient - - from fastapi_async_sqlalchemy import SQLAlchemyMiddleware - - app = FastAPI() - app.add_middleware(SQLAlchemyMiddleware, db_url="sqlite+aiosqlite://") - - # Initialize middleware - TestClient(app) - # Create mock session that fails on commit mock_session = AsyncMock() mock_session.commit.side_effect = SQLAlchemyError("Commit failed!") @@ -138,5 +125,8 @@ def test_skipped_tests_make_coverage(): app = FastAPI() custom_engine = create_async_engine("sqlite+aiosqlite://") - middleware = SQLAlchemyMiddleware(app, custom_engine=custom_engine) - assert middleware.commit_on_exit is False # Default value + try: + middleware = SQLAlchemyMiddleware(app, custom_engine=custom_engine) + assert middleware.commit_on_exit is False # Default value + finally: + asyncio.run(custom_engine.dispose()) diff --git a/tests/test_coverage_improvements.py b/tests/test_coverage_improvements.py index 9dfcfcc..b94f755 100644 --- a/tests/test_coverage_improvements.py +++ b/tests/test_coverage_improvements.py @@ -33,8 +33,8 @@ async def child_task(): return {"result": result} - client = TestClient(app) - response = client.get("/test_closed_loop") + with TestClient(app) as client: + response = client.get("/test_closed_loop") assert response.status_code == 200 @@ -60,8 +60,8 @@ async def child_task(): await asyncio.sleep(0.1) return {"result": result} - client = TestClient(app) - response = client.get("/test_runtime_error") + with TestClient(app) as client: + response = client.get("/test_runtime_error") assert response.status_code == 200 @@ -86,8 +86,8 @@ async def child_task(n): return {"results": results} - client = TestClient(app) - response = client.get("/test_multiple_cleanup") + with TestClient(app) as client: + response = client.get("/test_multiple_cleanup") assert response.status_code == 200 assert len(response.json()["results"]) == 5 @@ -138,8 +138,8 @@ async def test_task_context(): return {"success": True} - client = TestClient(app) - response = client.get("/test_task_context") + with TestClient(app) as client: + response = client.get("/test_task_context") assert response.status_code == 200 @@ -162,8 +162,8 @@ async def quick_task(n): return {"done": True} - client = TestClient(app) - response = client.get("/test_loop_edge") + with TestClient(app) as client: + response = client.get("/test_loop_edge") assert response.status_code == 200 @@ -191,8 +191,8 @@ async def test_none_task(): return {"success": True, "has_session": True} return {"error": "Session is None"} - client = TestClient(app) - response = client.get("/test_none_task") + with TestClient(app) as client: + response = client.get("/test_none_task") assert response.status_code == 200 assert response.json()["success"] is True assert response.json()["has_session"] is True @@ -241,9 +241,9 @@ def mock_get_running_loop_closed(): return {"done": True} - client = TestClient(app) - with pytest.warns(UserWarning, match="No running event loop during cleanup"): - response = client.get("/test_mock_closed") + with TestClient(app) as client: + with pytest.warns(UserWarning, match="No running event loop during cleanup"): + response = client.get("/test_mock_closed") assert response.status_code == 200 @@ -281,7 +281,7 @@ def mock_get_running_loop_error(): return {"done": True} - client = TestClient(app) - with pytest.warns(UserWarning, match="No running event loop during cleanup"): - response = client.get("/test_runtime_error") + with TestClient(app) as client: + with pytest.warns(UserWarning, match="No running event loop during cleanup"): + response = client.get("/test_runtime_error") assert response.status_code == 200 diff --git a/tests/test_custom_engine_branch.py b/tests/test_custom_engine_branch.py index fee1cf7..12d092e 100644 --- a/tests/test_custom_engine_branch.py +++ b/tests/test_custom_engine_branch.py @@ -2,6 +2,8 @@ Targeted test to ensure custom_engine branch (line 61) is executed """ +import asyncio + import pytest from fastapi import FastAPI from fastapi.testclient import TestClient @@ -37,12 +39,14 @@ async def test_endpoint(): value = result.scalar() return {"value": value} - # Test the endpoint - client = TestClient(app) - response = client.get("/test") + try: + with TestClient(app) as client: + response = client.get("/test") - assert response.status_code == 200 - assert response.json()["value"] == 42 + assert response.status_code == 200 + assert response.json()["value"] == 42 + finally: + await custom_engine.dispose() def test_custom_engine_without_db_url(): @@ -57,14 +61,17 @@ def test_custom_engine_without_db_url(): # Create custom engine custom_engine = create_async_engine("sqlite+aiosqlite:///:memory:") - # Initialize middleware with ONLY custom_engine (no db_url) - # This should take the else branch at line 61 - middleware = SQLAlchemyMiddleware( - app, custom_engine=custom_engine, engine_args={}, session_args={} - ) + try: + # Initialize middleware with ONLY custom_engine (no db_url) + # This should take the else branch at line 61 + middleware = SQLAlchemyMiddleware( + app, custom_engine=custom_engine, engine_args={}, session_args={} + ) - assert middleware is not None - assert middleware.commit_on_exit is False + assert middleware is not None + assert middleware.commit_on_exit is False + finally: + asyncio.run(custom_engine.dispose()) def test_custom_engine_with_session_args(): @@ -77,13 +84,19 @@ def test_custom_engine_with_session_args(): custom_engine = create_async_engine("sqlite+aiosqlite://") - # Use custom engine with session args - middleware = SQLAlchemyMiddleware( - app, custom_engine=custom_engine, session_args={"autoflush": False}, commit_on_exit=True - ) + try: + # Use custom engine with session args + middleware = SQLAlchemyMiddleware( + app, + custom_engine=custom_engine, + session_args={"autoflush": False}, + commit_on_exit=True, + ) - assert middleware is not None - assert middleware.commit_on_exit is True + assert middleware is not None + assert middleware.commit_on_exit is True + finally: + asyncio.run(custom_engine.dispose()) def test_custom_engine_multiple_instances(): @@ -99,9 +112,13 @@ def test_custom_engine_multiple_instances(): engine1 = create_async_engine("sqlite+aiosqlite:///:memory:") engine2 = create_async_engine("sqlite+aiosqlite://") - # Create two middleware instances - middleware1 = SQLAlchemyMiddleware1(app, custom_engine=engine1) - middleware2 = SQLAlchemyMiddleware2(app, custom_engine=engine2) + try: + # Create two middleware instances + middleware1 = SQLAlchemyMiddleware1(app, custom_engine=engine1) + middleware2 = SQLAlchemyMiddleware2(app, custom_engine=engine2) - assert middleware1 is not None - assert middleware2 is not None + assert middleware1 is not None + assert middleware2 is not None + finally: + asyncio.run(engine1.dispose()) + asyncio.run(engine2.dispose()) diff --git a/tests/test_edge_cases_coverage.py b/tests/test_edge_cases_coverage.py index 5442038..4a9155b 100644 --- a/tests/test_edge_cases_coverage.py +++ b/tests/test_edge_cases_coverage.py @@ -33,8 +33,8 @@ async def test_exception_rollback(): return {"status": "rolled_back"} - client = TestClient(app) - response = client.get("/test_exception_rollback") + with TestClient(app) as client: + response = client.get("/test_exception_rollback") assert response.status_code == 200 @@ -59,8 +59,8 @@ async def failing_commit(): return {"status": "handled"} - client = TestClient(app, raise_server_exceptions=False) - response = client.get("/test_commit_failure_warning") + with TestClient(app, raise_server_exceptions=False) as client: + response = client.get("/test_commit_failure_warning") assert response.status_code == 500 @@ -89,8 +89,8 @@ async def failing_rollback(): return {"status": "handled"} - client = TestClient(app, raise_server_exceptions=False) - response = client.get("/test_rollback_failure") + with TestClient(app, raise_server_exceptions=False) as client: + response = client.get("/test_rollback_failure") assert response.status_code == 500 @@ -118,8 +118,8 @@ async def failing_close(): return {"status": "handled"} - client = TestClient(app, raise_server_exceptions=False) - response = client.get("/test_close_failure") + with TestClient(app, raise_server_exceptions=False) as client: + response = client.get("/test_close_failure") assert response.status_code == 500 @@ -154,8 +154,8 @@ async def tracking_rollback(): return {"status": "handled", "rollback_called": rollback_called} - client = TestClient(app) - response = client.get("/test_commit_exception") + with TestClient(app) as client: + response = client.get("/test_commit_exception") # The exception should propagate assert response.status_code == 500 or response.status_code == 200 @@ -173,8 +173,8 @@ async def test_session_created_without_tracking_warning(): app = FastAPI() app.add_middleware(SQLAlchemyMiddleware_local, db_url="sqlite+aiosqlite:///:memory:") - # Initialize middleware - TestClient(app) + with TestClient(app): + pass # This test verifies the warning path exists # In normal usage, the tracking set is always created in __aenter__ @@ -190,11 +190,18 @@ def test_custom_engine_branch(): # Create custom engine custom_engine = create_async_engine("sqlite+aiosqlite:///:memory:") - # This should use the else branch on line 61 - middleware = SQLAlchemyMiddleware_local(app, custom_engine=custom_engine, commit_on_exit=False) + try: + # This should use the else branch on line 61 + middleware = SQLAlchemyMiddleware_local( + app, + custom_engine=custom_engine, + commit_on_exit=False, + ) - assert middleware is not None - assert middleware.commit_on_exit is False + assert middleware is not None + assert middleware.commit_on_exit is False + finally: + asyncio.run(custom_engine.dispose()) @pytest.mark.asyncio @@ -257,8 +264,8 @@ async def run_query(value: int): return {"session_count": len(set(sessions))} - client = TestClient(app) - response = client.get("/test_comprehensive") + with TestClient(app) as client: + response = client.get("/test_comprehensive") assert response.status_code == 200 assert response.json()["session_count"] == 3 @@ -278,8 +285,8 @@ async def test_no_sessions(): return {"status": "ok"} - client = TestClient(app) - response = client.get("/test_no_sessions") + with TestClient(app) as client: + response = client.get("/test_no_sessions") assert response.status_code == 200 @@ -299,6 +306,6 @@ async def test_single_exception(): return {"status": "exception_handled"} - client = TestClient(app) - response = client.get("/test_single_exception") + with TestClient(app) as client: + response = client.get("/test_single_exception") assert response.status_code == 200 diff --git a/tests/test_full_coverage.py b/tests/test_full_coverage.py new file mode 100644 index 0000000..f323be4 --- /dev/null +++ b/tests/test_full_coverage.py @@ -0,0 +1,227 @@ +"""Targeted tests covering the last middleware branches that aren't reachable +via normal end-to-end flows — race windows, defensive error paths, direct +contextvar manipulation, and the SQLModel-not-installed import fallback.""" + +import asyncio +import importlib.util +import sys + +import pytest +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession + +from fastapi_async_sqlalchemy import create_middleware_and_session_proxy +from fastapi_async_sqlalchemy.exceptions import MissingSessionError + +DB_URL = "sqlite+aiosqlite://" + + +def _get_closure_var(db_obj, var_name: str): + """Search closures across every callable attribute of the proxy class and + its metaclass to find a free variable by name. The proxy factory captures + different vars in different methods.""" + seen: set[int] = set() + + def _candidates(): + for src in (db_obj, type(db_obj)): + for value in src.__dict__.values(): + fn = value.fget if isinstance(value, property) else value + if callable(fn) and id(fn) not in seen: + seen.add(id(fn)) + yield fn + + for func in _candidates(): + code = getattr(func, "__code__", None) + closure = getattr(func, "__closure__", None) + if code is None or closure is None: + continue + if var_name in code.co_freevars: + return closure[code.co_freevars.index(var_name)].cell_contents + raise KeyError(var_name) + + +@pytest.mark.asyncio +async def test_request_session_access_after_streaming_close_raises(): + """db.session access after the request session is marked closed-for-streaming + raises a clear RuntimeError (middleware.py line 83).""" + Middleware, _db = create_middleware_and_session_proxy() + Middleware(app=None, db_url=DB_URL) + + request_session_var = _get_closure_var(_db, "_request_session") + closed_var = _get_closure_var(_db, "_request_session_closed_for_streaming") + session_var = _get_closure_var(_db, "_session") + Session = _get_closure_var(_db, "_Session") + + session = Session() + session_token = session_var.set(session) + request_token = request_session_var.set(session) + closed_token = closed_var.set(True) + try: + with pytest.raises(RuntimeError, match="closed for streaming"): + _ = _db.session + finally: + closed_var.reset(closed_token) + request_session_var.reset(request_token) + session_var.reset(session_token) + await session.close() + + +@pytest.mark.asyncio +async def test_connection_releases_slot_when_parent_closes_during_acquire(): + """If the parent multi-session context starts closing while a waiter is + parked on semaphore.acquire(), the waiter must release the slot it just + obtained and raise (middleware.py lines 180-181).""" + Middleware, _db = create_middleware_and_session_proxy() + Middleware(app=None, db_url=DB_URL) + + holder_started = asyncio.Event() + holder_release = asyncio.Event() + + async def holder(): + async with _db.connection() as session: + holder_started.set() + await session.execute(text("SELECT 1")) + await holder_release.wait() + + multi_state_var = _get_closure_var(_db, "_multi_state") + + async with _db(multi_sessions=True, max_concurrent=1): + holder_task = asyncio.create_task(holder()) + await holder_started.wait() + + async def waiter(): + async with _db.connection(): + pass + + waiter_task = asyncio.create_task(waiter()) + # Let the waiter park on acquire(). + await asyncio.sleep(0) + await asyncio.sleep(0) + + # Manually flip closing so the post-acquire re-check (line 179) raises. + state = multi_state_var.get() + state.closing = True + + # Release the holder so the waiter wakes up, finds closing=True, + # releases the slot it just took, and raises. + holder_release.set() + await holder_task + + with pytest.raises(RuntimeError, match="started closing"): + await waiter_task + + state.closing = False + + +@pytest.mark.asyncio +async def test_finalize_regular_session_raises_when_session_missing(): + """If the bound session is reset to None before __aexit__, finalize must + raise MissingSessionError (middleware.py line 668).""" + Middleware, db_obj = create_middleware_and_session_proxy() + Middleware(app=None, db_url=DB_URL) + + session_var = _get_closure_var(db_obj, "_session") + + ctx = db_obj() + await ctx.__aenter__() + + # Yank the session out from under the context — exercises the defensive + # `if session is None: raise MissingSessionError` path. + real_session = session_var.get() + drop_token = session_var.set(None) + try: + with pytest.raises(MissingSessionError): + await ctx.__aexit__(None, None, None) + finally: + session_var.reset(drop_token) + await real_session.close() + + +@pytest.mark.asyncio +async def test_waiters_cancelled_when_context_exits_with_holder_active(): + """Tasks parked on db.connection()'s semaphore at the moment the parent + multi-session context exits must be cancelled by __aexit__ via the + `state.waiters` sweep (middleware.py lines 703 & 705).""" + Middleware, _db = create_middleware_and_session_proxy() + Middleware(app=None, db_url=DB_URL) + + holder_started = asyncio.Event() + holder_release = asyncio.Event() + waiter_started = asyncio.Event() + + async def holder(): + async with _db.connection() as session: + holder_started.set() + await session.execute(text("SELECT 1")) + await holder_release.wait() + + async def waiter(): + waiter_started.set() + async with _db.connection(): + pass + + holder_task: asyncio.Task | None = None + waiter_task: asyncio.Task | None = None + try: + async with _db(multi_sessions=True, max_concurrent=1): + holder_task = asyncio.create_task(holder()) + await holder_started.wait() + + waiter_task = asyncio.create_task(waiter()) + await waiter_started.wait() + # Park the waiter on acquire(). + await asyncio.sleep(0) + await asyncio.sleep(0) + + # Exit the multi-session context with the holder still active and + # the waiter still parked — __aexit__ must walk state.waiters and + # cancel them. + finally: + holder_release.set() + if holder_task is not None: + await asyncio.gather(holder_task, return_exceptions=True) + if waiter_task is not None: + await asyncio.gather(waiter_task, return_exceptions=True) + + assert waiter_task is not None + assert waiter_task.cancelled() + + +def test_middleware_falls_back_to_sqlalchemy_session_when_sqlmodel_missing(): + """Re-execute middleware.py source in a fresh namespace with sqlmodel + blocked, exercising the ImportError fallback path (middleware.py line 28).""" + middleware_path = sys.modules["fastapi_async_sqlalchemy.middleware"].__file__ + assert middleware_path is not None + + # Block sqlmodel and its parent packages — sys.modules[name] = None makes + # `import name` raise ModuleNotFoundError without affecting installed pkgs. + saved_modules = {} + for name in list(sys.modules): + if name == "sqlmodel" or name.startswith("sqlmodel."): + saved_modules[name] = sys.modules.pop(name) + + sys.modules["sqlmodel"] = None # type: ignore[assignment] + sys.modules["sqlmodel.ext"] = None # type: ignore[assignment] + sys.modules["sqlmodel.ext.asyncio"] = None # type: ignore[assignment] + sys.modules["sqlmodel.ext.asyncio.session"] = None # type: ignore[assignment] + + clone_name = "_fasq_sqlmodel_missing_clone" + try: + spec = importlib.util.spec_from_file_location(clone_name, middleware_path) + assert spec is not None and spec.loader is not None + clone = importlib.util.module_from_spec(spec) + # Register in sys.modules before exec so dataclass annotation resolution + # (sys.modules[cls.__module__]) finds the live module during class body. + sys.modules[clone_name] = clone + spec.loader.exec_module(clone) + + # Without sqlmodel, the fallback assigns AsyncSession directly. + assert clone.DefaultAsyncSession is AsyncSession + finally: + sys.modules.pop(clone_name, None) + sys.modules.pop("sqlmodel", None) + sys.modules.pop("sqlmodel.ext", None) + sys.modules.pop("sqlmodel.ext.asyncio", None) + sys.modules.pop("sqlmodel.ext.asyncio.session", None) + for name, mod in saved_modules.items(): + sys.modules[name] = mod diff --git a/tests/test_import_fallback_simulation.py b/tests/test_import_fallback_simulation.py index 49519a0..848fe62 100644 --- a/tests/test_import_fallback_simulation.py +++ b/tests/test_import_fallback_simulation.py @@ -4,6 +4,8 @@ which only execute in specific import scenarios """ +import asyncio + import pytest @@ -87,12 +89,15 @@ def test_custom_engine_else_branch_execution(): "sqlite+aiosqlite:///:memory:", echo=False, pool_pre_ping=True ) - # Initialize middleware with custom_engine - # This should execute line 61: engine = custom_engine - middleware = SQLAlchemyMiddleware_local(app, custom_engine=custom_engine) + try: + # Initialize middleware with custom_engine + # This should execute line 61: engine = custom_engine + middleware = SQLAlchemyMiddleware_local(app, custom_engine=custom_engine) - # Verify middleware was created - assert middleware is not None + # Verify middleware was created + assert middleware is not None + finally: + asyncio.run(custom_engine.dispose()) def test_session_tracking_warning_scenario(): @@ -153,8 +158,8 @@ async def test_verify_all_middleware_branches_tested(): # Test 1: db_url path (line 59: engine = create_async_engine) app1 = FastAPI() app1.add_middleware(SQLAlchemyMiddleware, db_url="sqlite+aiosqlite://") - client1 = TestClient(app1) - assert client1 is not None + with TestClient(app1) as client1: + assert client1 is not None # Test 2: custom_engine path (line 61: engine = custom_engine) from fastapi_async_sqlalchemy.middleware import create_middleware_and_session_proxy @@ -163,8 +168,11 @@ async def test_verify_all_middleware_branches_tested(): app2 = FastAPI() custom_engine = create_async_engine("sqlite+aiosqlite://") app2.add_middleware(SQLAlchemyMiddleware2, custom_engine=custom_engine) - client2 = TestClient(app2) - assert client2 is not None + try: + with TestClient(app2) as client2: + assert client2 is not None + finally: + await custom_engine.dispose() def test_coverage_report_explanation(): diff --git a/tests/test_import_fallbacks.py b/tests/test_import_fallbacks.py index f12b6e4..7bb1838 100644 --- a/tests/test_import_fallbacks.py +++ b/tests/test_import_fallbacks.py @@ -5,30 +5,6 @@ """ -def test_async_sessionmaker_fallback(): - """Test fallback when async_sessionmaker is not available in SQLAlchemy.""" - # Note: This test verifies the fallback structure exists. - # Actually testing both import paths would require manipulating imports - # before module load, which is complex and fragile. - - # Verify the code has proper fallback structure - import inspect - - import fastapi_async_sqlalchemy.middleware as mod - - source = inspect.getsource(mod) - assert "try:" in source - assert "from sqlalchemy.ext.asyncio import async_sessionmaker" in source - assert "except ImportError:" in source - assert "from sqlalchemy.orm import sessionmaker as async_sessionmaker" in source - - # Verify async_sessionmaker is importable (whichever path was taken) - # This exercises one of the two code paths - assert hasattr(mod, "async_sessionmaker") or "async_sessionmaker" in dir( - mod.create_middleware_and_session_proxy.__code__.co_freevars - ) - - def test_sqlmodel_not_installed_fallback(): """Test fallback when SQLModel is not installed.""" import inspect diff --git a/tests/test_maximum_coverage.py b/tests/test_maximum_coverage.py index 7977be8..8391d27 100644 --- a/tests/test_maximum_coverage.py +++ b/tests/test_maximum_coverage.py @@ -48,8 +48,8 @@ async def tracking_rollback(): return {"session_id": id(session)} - client = TestClient(app, raise_server_exceptions=False) - response = client.get("/test_commit_failure") + with TestClient(app, raise_server_exceptions=False) as client: + response = client.get("/test_commit_failure") assert response.status_code == 500 @@ -82,8 +82,8 @@ async def tracking_commit(): return {"status": "ok"} - client = TestClient(app) - response = client.get("/test_commit_success") + with TestClient(app) as client: + response = client.get("/test_commit_success") assert response.status_code == 200 # Give cleanup time to run @@ -114,8 +114,8 @@ async def execute_with_session(value: int): results = await asyncio.gather(*tasks) return {"results": results, "session_count": len(set(session_ids))} - client = TestClient(app) - response = client.get("/test_multi_cleanup") + with TestClient(app) as client: + response = client.get("/test_multi_cleanup") assert response.status_code == 200 @@ -173,9 +173,12 @@ def test_custom_engine_path(): app = FastAPI() custom_engine = create_async_engine("sqlite+aiosqlite:///:memory:") - # Initialize with custom engine - middleware = SQLAlchemyMiddleware_local(app, custom_engine=custom_engine) - assert middleware.commit_on_exit is False + try: + # Initialize with custom engine + middleware = SQLAlchemyMiddleware_local(app, custom_engine=custom_engine) + assert middleware.commit_on_exit is False + finally: + asyncio.run(custom_engine.dispose()) # Verify it doesn't require db_url # This covers the else branch on line 61 @@ -188,10 +191,7 @@ async def test_session_outside_middleware_context(): SQLAlchemyMiddleware_local, db_local = create_middleware_and_session_proxy() app = FastAPI() - app.add_middleware(SQLAlchemyMiddleware_local, db_url="sqlite+aiosqlite://") - - # Initialize the middleware - TestClient(app) + SQLAlchemyMiddleware_local(app, db_url="sqlite+aiosqlite://") # Try to access session outside of request context with pytest.raises(MissingSessionError): @@ -216,8 +216,8 @@ async def test_context_vars(): return {"status": "ok"} - client = TestClient(app) - response = client.get("/test_context_vars") + with TestClient(app) as client: + response = client.get("/test_context_vars") assert response.status_code == 200 @@ -240,8 +240,8 @@ async def test_rollback(): return {"status": "rolled_back"} - client = TestClient(app) - response = client.get("/test_rollback") + with TestClient(app) as client: + response = client.get("/test_rollback") assert response.status_code == 200 @@ -260,8 +260,8 @@ async def test_commit(): return {"status": "committed"} - client = TestClient(app) - response = client.get("/test_commit") + with TestClient(app) as client: + response = client.get("/test_commit") assert response.status_code == 200 @@ -275,10 +275,9 @@ def test_middleware_commit_on_exit_parameter(): middleware = SQLAlchemyMiddleware_local(app, db_url="sqlite+aiosqlite://", commit_on_exit=True) assert middleware.commit_on_exit is True - # Test with commit_on_exit=False - middleware2 = SQLAlchemyMiddleware_local( - app, db_url="sqlite+aiosqlite://", commit_on_exit=False - ) + # Test with commit_on_exit=False on a separate proxy to avoid rebinding the singleton. + SecondMiddleware, _ = create_middleware_and_session_proxy() + middleware2 = SecondMiddleware(app, db_url="sqlite+aiosqlite://", commit_on_exit=False) assert middleware2.commit_on_exit is False @@ -329,8 +328,8 @@ async def test_token_reset(): # Verify by trying to access session (should raise MissingSessionError) return {"status": "ok"} - client = TestClient(app) - response = client.get("/test_token_reset") + with TestClient(app) as client: + response = client.get("/test_token_reset") assert response.status_code == 200 @@ -351,8 +350,8 @@ async def test_session_args(): return {"value": value} - client = TestClient(app) - response = client.get("/test_session_args") + with TestClient(app) as client: + response = client.get("/test_session_args") assert response.status_code == 200 assert response.json()["value"] == 42 @@ -372,8 +371,8 @@ async def test_no_commit(): return {"status": "no_commit"} - client = TestClient(app) - response = client.get("/test_no_commit") + with TestClient(app) as client: + response = client.get("/test_no_commit") assert response.status_code == 200 @@ -398,8 +397,8 @@ async def task_function(): return {"result": result} - client = TestClient(app) - response = client.get("/test_callback") + with TestClient(app) as client: + response = client.get("/test_callback") assert response.status_code == 200 # Give cleanup time to execute diff --git a/tests/test_multi_session_fixes.py b/tests/test_multi_session_fixes.py index 15a2d6a..8d57f66 100644 --- a/tests/test_multi_session_fixes.py +++ b/tests/test_multi_session_fixes.py @@ -32,6 +32,7 @@ def _get_ctx_var(_db, var_name: str): for name, cell in zip( session_prop.fget.__code__.co_freevars, session_prop.fget.__closure__, + strict=False, ) } return closure[var_name] @@ -121,32 +122,59 @@ async def work(): @pytest.mark.asyncio -async def test_single_session_aexit_aggregates_rollback_and_close_errors(): +async def test_single_session_aexit_raises_aggregated_cleanup_errors_without_original_exception(): """ Regression: non-multi __aexit__ must use _finalize_session which aggregates - errors from both rollback() and close() when both fail. + cleanup errors when there is no original exception already propagating. Before fix (manual path): - - rollback() raises → propagates to finally block - - close() raises in finally → close error REPLACES rollback error - - caller sees only RuntimeError("close error"), rollback error is lost + - commit() raises -> propagates to finally block + - close() raises in finally -> close error REPLACES commit error + - caller sees only RuntimeError("close error"), commit error is lost After fix (_finalize_session): - - both errors collected and raised as: - RuntimeError("Session cleanup failed with 2 errors: ...") + - commit, rollback, and close errors are collected and raised as: + RuntimeError("Session cleanup failed with 3 errors: ...") """ _db = _make_middleware_and_db() - with pytest.raises(RuntimeError, match="Session cleanup failed with 2 errors"): - async with _db(): + with pytest.raises(RuntimeError, match="Session cleanup failed with 3 errors"): + async with _db(commit_on_exit=True): session = _db.session + async def failing_commit(): + raise RuntimeError("commit error") + async def failing_rollback(): raise RuntimeError("rollback error") async def failing_close(): raise RuntimeError("close error") + session.commit = failing_commit session.rollback = failing_rollback session.close = failing_close - raise ValueError("trigger rollback path in __aexit__") + + +@pytest.mark.asyncio +async def test_single_session_aexit_warns_cleanup_errors_without_replacing_original(): + """ + When the block body is already raising, cleanup errors must be visible as a + warning without replacing the original exception. + """ + _db = _make_middleware_and_db() + + with pytest.warns(UserWarning, match="Session cleanup failed with 2 errors"): + with pytest.raises(ValueError, match="trigger rollback path"): + async with _db(): + session = _db.session + + async def failing_rollback(): + raise RuntimeError("rollback error") + + async def failing_close(): + raise RuntimeError("close error") + + session.rollback = failing_rollback + session.close = failing_close + raise ValueError("trigger rollback path in __aexit__") diff --git a/tests/test_pool_throttling.py b/tests/test_pool_throttling.py index 2b874af..0c94979 100644 --- a/tests/test_pool_throttling.py +++ b/tests/test_pool_throttling.py @@ -55,7 +55,9 @@ def _get_session_closure_var(db_obj, var_name: str): session_prop = type(db_obj).__dict__["session"] closure = { name: cell.cell_contents - for name, cell in zip(session_prop.fget.__code__.co_freevars, session_prop.fget.__closure__) + for name, cell in zip( + session_prop.fget.__code__.co_freevars, session_prop.fget.__closure__, strict=False + ) } return closure[var_name] @@ -445,7 +447,7 @@ async def child(n): @pytest.mark.asyncio async def test_connection_ctx_non_multi_sessions(app, db, SQLAlchemyMiddleware): """db.connection() works in regular (non-multi_sessions) mode too.""" - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(): async with db.connection() as session: @@ -623,3 +625,98 @@ async def test_max_concurrent_must_be_positive(): with pytest.raises(ValueError, match="`max_concurrent` must be greater than 0"): async with _db(multi_sessions=True, max_concurrent=0): pass + + +@pytest.mark.asyncio +async def test_gather_rejects_pre_created_task_when_throttled(): + """db.gather() rejects pre-created Tasks when max_concurrent is set.""" + _db = _make_middleware_and_db() + + async def work(): + return 1 + + coro = work() + + async with _db(multi_sessions=True, max_concurrent=1): + task = asyncio.create_task(work()) + try: + with pytest.raises(TypeError, match="coroutine objects only"): + await _db.gather(task, coro) + finally: + task.cancel() + await asyncio.gather(task, return_exceptions=True) + + # Coroutine passed alongside the bad input must have been closed — + # awaiting a closed coroutine raises RuntimeError. + with pytest.raises(RuntimeError, match="cannot reuse already awaited coroutine"): + await coro + + +@pytest.mark.asyncio +async def test_gather_rejects_non_coroutine_when_throttled(): + """db.gather() rejects plain non-coroutine values when max_concurrent is set.""" + _db = _make_middleware_and_db() + + async def work(): + return 1 + + coro = work() + + async with _db(multi_sessions=True, max_concurrent=1): + with pytest.raises(TypeError, match="coroutine objects only"): + await _db.gather(coro, "not a coroutine") + + with pytest.raises(RuntimeError, match="cannot reuse already awaited coroutine"): + await coro + + +@pytest.mark.asyncio +async def test_middleware_buffers_unknown_response_message_types(): + """ASGI messages other than http.response.start/body must be buffered + alongside the response (e.g. http.response.trailers extension).""" + from fastapi_async_sqlalchemy import create_middleware_and_session_proxy + + Middleware, _db = create_middleware_and_session_proxy() + + async def downstream_app(scope, receive, send): + await send({"type": "http.response.start", "status": 200, "headers": []}) + await send({"type": "http.response.trailers", "trailers": []}) + await send({"type": "http.response.body", "body": b"ok"}) + + middleware = Middleware(app=downstream_app, db_url=db_url) + + sent = [] + + async def receive(): + return {"type": "http.request"} + + async def send(message): + sent.append(message["type"]) + + await middleware({"type": "http", "method": "GET", "path": "/"}, receive, send) + + assert "http.response.trailers" in sent + + +@pytest.mark.asyncio +async def test_middleware_passes_through_non_http_lifespan_scope(): + """Scopes other than http/lifespan (e.g. websocket) must bypass session setup.""" + from fastapi_async_sqlalchemy import create_middleware_and_session_proxy + + Middleware, _db = create_middleware_and_session_proxy() + + seen = {} + + async def downstream_app(scope, receive, send): + seen["scope"] = scope + + middleware = Middleware(app=downstream_app, db_url=db_url) + + async def receive(): + return {"type": "websocket.connect"} + + async def send(_message): + pass + + await middleware({"type": "websocket"}, receive, send) + assert seen["scope"]["type"] == "websocket" diff --git a/tests/test_resource_lifecycle.py b/tests/test_resource_lifecycle.py new file mode 100644 index 0000000..91880bf --- /dev/null +++ b/tests/test_resource_lifecycle.py @@ -0,0 +1,675 @@ +import asyncio +import inspect +import warnings + +import pytest +from fastapi import FastAPI +from fastapi.responses import StreamingResponse +from fastapi.testclient import TestClient +from sqlalchemy import text +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine + +DB_URL = "sqlite+aiosqlite://" + + +def _make_middleware_and_db(): + from fastapi_async_sqlalchemy import create_middleware_and_session_proxy + + return create_middleware_and_session_proxy() + + +def _get_ctx_var(_db, var_name: str): + session_prop = type(_db).__dict__["session"] + closure = { + name: cell.cell_contents + for name, cell in zip( + session_prop.fget.__code__.co_freevars, + session_prop.fget.__closure__, + strict=True, + ) + } + return closure[var_name] + + +def _capture_http_messages(app, messages): + async def wrapped(scope, receive, send): + async def capture_send(message): + if message["type"].startswith("http.response"): + messages.append(message.copy()) + await send(message) + + await app(scope, receive, capture_send) + + return wrapped + + +def test_commit_failure_prevents_successful_response_start(): + SQLAlchemyMiddleware, db = _make_middleware_and_db() + app = FastAPI() + app.add_middleware(SQLAlchemyMiddleware, db_url=DB_URL, commit_on_exit=True) + messages = [] + + @app.get("/commit-fails") + async def commit_fails(): + session = db.session + + async def failing_commit(): + raise SQLAlchemyError("commit failed before response start") + + session.commit = failing_commit + await session.execute(text("SELECT 1")) + return {"ok": True} + + with TestClient( + _capture_http_messages(app, messages), + raise_server_exceptions=False, + ) as client: + response = client.get("/commit-fails") + + assert response.status_code == 500 + assert not any( + message["type"] == "http.response.start" and message["status"] == 200 + for message in messages + ) + + +def test_successful_commit_precedes_normal_response_start(): + SQLAlchemyMiddleware, db = _make_middleware_and_db() + app = FastAPI() + app.add_middleware(SQLAlchemyMiddleware, db_url=DB_URL, commit_on_exit=True) + events = [] + + async def recording_app(scope, receive, send): + async def recording_send(message): + if message["type"] == "http.response.start": + events.append("response_start") + await send(message) + + await app(scope, receive, recording_send) + + @app.get("/commits-first") + async def commits_first(): + session = db.session + original_commit = session.commit + + async def tracking_commit(): + events.append("commit") + await original_commit() + + session.commit = tracking_commit + await session.execute(text("SELECT 1")) + return {"ok": True} + + with TestClient(recording_app) as client: + response = client.get("/commits-first") + + assert response.status_code == 200 + assert events == ["commit", "response_start"] + + +def test_explicit_streaming_read_owns_and_closes_body_session(): + SQLAlchemyMiddleware, db = _make_middleware_and_db() + app = FastAPI() + app.add_middleware(SQLAlchemyMiddleware, db_url=DB_URL, commit_on_exit=True) + closed = [] + + @app.get("/stream") + async def stream(): + async def body(): + async with db(): + session = db.session + original_close = session.close + + async def tracking_close(): + closed.append("closed") + await original_close() + + session.close = tracking_close + + for value in range(3): + result = await db.session.execute(text(f"SELECT {value}")) + yield f"{result.scalar()}\n".encode() + + return StreamingResponse(body(), media_type="text/plain") + + with TestClient(app) as client: + response = client.get("/stream") + + assert response.status_code == 200 + assert response.text == "0\n1\n2\n" + assert closed == ["closed"] + + +def test_implicit_commit_on_exit_streaming_write_is_not_reported_successful(): + SQLAlchemyMiddleware, db = _make_middleware_and_db() + app = FastAPI() + app.add_middleware(SQLAlchemyMiddleware, db_url=DB_URL, commit_on_exit=True) + messages = [] + + @app.get("/stream-write") + async def stream_write(): + async def body(): + await db.session.execute(text("CREATE TABLE unsafe_stream (value INTEGER)")) + await db.session.execute(text("INSERT INTO unsafe_stream VALUES (1)")) + yield b"created\n" + + return StreamingResponse(body(), media_type="text/plain") + + with TestClient( + _capture_http_messages(app, messages), + raise_server_exceptions=False, + ) as client: + response = client.get("/stream-write") + + assert response.status_code == 500 + assert not any( + message["type"] == "http.response.start" and message["status"] == 200 + for message in messages + ) + + +def test_non_database_streaming_closes_request_session_before_response_start(monkeypatch): + import fastapi_async_sqlalchemy.middleware as middleware_module + + SQLAlchemyMiddleware, _db = _make_middleware_and_db() + app = FastAPI() + app.add_middleware(SQLAlchemyMiddleware, db_url=DB_URL, commit_on_exit=True) + events = [] + original_close = middleware_module.DefaultAsyncSession.close + + async def tracking_close(self, *args, **kwargs): + events.append("request_session_close") + await original_close(self, *args, **kwargs) + + monkeypatch.setattr(middleware_module.DefaultAsyncSession, "close", tracking_close) + + async def recording_app(scope, receive, send): + async def recording_send(message): + if message["type"] == "http.response.start": + events.append("response_start") + await send(message) + + await app(scope, receive, recording_send) + + @app.get("/plain-stream") + async def plain_stream(): + async def body(): + yield b"one\n" + yield b"two\n" + + return StreamingResponse(body(), media_type="text/plain") + + with TestClient(recording_app) as client: + response = client.get("/plain-stream") + + assert response.status_code == 200 + assert response.text == "one\ntwo\n" + assert events.index("request_session_close") < events.index("response_start") + + +def test_owned_db_url_engine_is_disposed_on_lifespan_shutdown(monkeypatch): + import fastapi_async_sqlalchemy.middleware as middleware_module + + SQLAlchemyMiddleware, db = _make_middleware_and_db() + app = FastAPI() + engine = create_async_engine(DB_URL) + dispose_calls = [] + original_dispose = AsyncEngine.dispose + + async def dispose_spy(self, *args, **kwargs): + dispose_calls.append(self) + return await original_dispose(self, *args, **kwargs) + + monkeypatch.setattr(AsyncEngine, "dispose", dispose_spy) + monkeypatch.setattr(middleware_module, "create_async_engine", lambda *_, **__: engine) + + app.add_middleware(SQLAlchemyMiddleware, db_url=DB_URL) + + @app.get("/") + async def get_value(): + result = await db.session.execute(text("SELECT 1")) + return {"value": result.scalar()} + + with TestClient(app) as client: + assert client.get("/").json() == {"value": 1} + + assert dispose_calls == [engine] + + +def test_owned_engine_is_disposed_on_shutdown_failed(monkeypatch): + from contextlib import asynccontextmanager + + import fastapi_async_sqlalchemy.middleware as middleware_module + + SQLAlchemyMiddleware, _db = _make_middleware_and_db() + engine = create_async_engine(DB_URL) + dispose_calls = [] + original_dispose = AsyncEngine.dispose + + async def dispose_spy(self, *args, **kwargs): + dispose_calls.append(self) + return await original_dispose(self, *args, **kwargs) + + monkeypatch.setattr(AsyncEngine, "dispose", dispose_spy) + monkeypatch.setattr(middleware_module, "create_async_engine", lambda *_, **__: engine) + + @asynccontextmanager + async def failing_lifespan(_app): + yield + raise RuntimeError("user shutdown blew up") + + app = FastAPI(lifespan=failing_lifespan) + app.add_middleware(SQLAlchemyMiddleware, db_url=DB_URL) + + with pytest.raises(RuntimeError, match="user shutdown blew up"): + with TestClient(app): + pass + + assert dispose_calls == [engine] + + +def test_owned_engine_is_disposed_on_startup_failed(monkeypatch): + from contextlib import asynccontextmanager + + import fastapi_async_sqlalchemy.middleware as middleware_module + + SQLAlchemyMiddleware, _db = _make_middleware_and_db() + engine = create_async_engine(DB_URL) + dispose_calls = [] + original_dispose = AsyncEngine.dispose + + async def dispose_spy(self, *args, **kwargs): + dispose_calls.append(self) + return await original_dispose(self, *args, **kwargs) + + monkeypatch.setattr(AsyncEngine, "dispose", dispose_spy) + monkeypatch.setattr(middleware_module, "create_async_engine", lambda *_, **__: engine) + + @asynccontextmanager + async def failing_startup(_app): + raise RuntimeError("startup blew up") + yield # unreachable, but required so this is a valid asynccontextmanager + + app = FastAPI(lifespan=failing_startup) + app.add_middleware(SQLAlchemyMiddleware, db_url=DB_URL) + + with pytest.raises(RuntimeError, match="startup blew up"): + with TestClient(app): + pass + + assert dispose_calls == [engine] + + +@pytest.mark.asyncio +async def test_custom_engine_is_not_disposed_on_shutdown_failed(monkeypatch): + from contextlib import asynccontextmanager + + SQLAlchemyMiddleware, _db = _make_middleware_and_db() + engine = create_async_engine(DB_URL) + dispose_calls = [] + original_dispose = AsyncEngine.dispose + + async def dispose_spy(self, *args, **kwargs): + dispose_calls.append(self) + return await original_dispose(self, *args, **kwargs) + + monkeypatch.setattr(AsyncEngine, "dispose", dispose_spy) + + @asynccontextmanager + async def failing_lifespan(_app): + yield + raise RuntimeError("user shutdown blew up") + + app = FastAPI(lifespan=failing_lifespan) + app.add_middleware(SQLAlchemyMiddleware, custom_engine=engine) + + try: + with pytest.raises(RuntimeError, match="user shutdown blew up"): + with TestClient(app): + pass + + assert engine not in dispose_calls + finally: + await original_dispose(engine) + + +@pytest.mark.asyncio +async def test_owned_engine_disposal_is_idempotent(monkeypatch): + import fastapi_async_sqlalchemy.middleware as middleware_module + + SQLAlchemyMiddleware, _db = _make_middleware_and_db() + app = FastAPI() + engine = create_async_engine(DB_URL) + dispose_calls = [] + original_dispose = AsyncEngine.dispose + + async def dispose_spy(self, *args, **kwargs): + dispose_calls.append(self) + return await original_dispose(self, *args, **kwargs) + + monkeypatch.setattr(AsyncEngine, "dispose", dispose_spy) + monkeypatch.setattr(middleware_module, "create_async_engine", lambda *_, **__: engine) + + middleware = SQLAlchemyMiddleware(app, db_url=DB_URL) + + await middleware.dispose() + await middleware.dispose() + + assert dispose_calls == [engine] + + +@pytest.mark.asyncio +async def test_custom_engine_is_not_disposed_on_lifespan_shutdown(monkeypatch): + SQLAlchemyMiddleware, db = _make_middleware_and_db() + app = FastAPI() + engine = create_async_engine(DB_URL) + dispose_calls = [] + original_dispose = AsyncEngine.dispose + + async def dispose_spy(self, *args, **kwargs): + dispose_calls.append(self) + return await original_dispose(self, *args, **kwargs) + + monkeypatch.setattr(AsyncEngine, "dispose", dispose_spy) + + app.add_middleware(SQLAlchemyMiddleware, custom_engine=engine) + + @app.get("/") + async def get_value(): + result = await db.session.execute(text("SELECT 2")) + return {"value": result.scalar()} + + try: + with TestClient(app) as client: + assert client.get("/").json() == {"value": 2} + + assert engine not in dispose_calls + finally: + await original_dispose(engine) + + +@pytest.mark.asyncio +async def test_same_proxy_rejects_reinitialization_with_different_live_engine(): + SQLAlchemyMiddleware, _db = _make_middleware_and_db() + app = FastAPI() + engine_one = create_async_engine(DB_URL) + engine_two = create_async_engine(DB_URL) + + try: + SQLAlchemyMiddleware(app, custom_engine=engine_one) + + with pytest.raises(RuntimeError, match="create_middleware_and_session_proxy"): + SQLAlchemyMiddleware(app, custom_engine=engine_two) + finally: + await engine_one.dispose() + await engine_two.dispose() + + +@pytest.mark.asyncio +async def test_independent_proxies_keep_separate_engine_bindings(): + FirstMiddleware, first_db = _make_middleware_and_db() + SecondMiddleware, second_db = _make_middleware_and_db() + app = FastAPI() + first_engine = create_async_engine(DB_URL) + second_engine = create_async_engine(DB_URL) + + try: + FirstMiddleware(app, custom_engine=first_engine) + SecondMiddleware(app, custom_engine=second_engine) + + async with first_db(commit_on_exit=True): + await first_db.session.execute(text("CREATE TABLE marker (value TEXT)")) + await first_db.session.execute(text("INSERT INTO marker VALUES ('first')")) + + async with second_db(commit_on_exit=True): + await second_db.session.execute(text("CREATE TABLE marker (value TEXT)")) + await second_db.session.execute(text("INSERT INTO marker VALUES ('second')")) + + async with first_db(): + first_value = ( + await first_db.session.execute(text("SELECT value FROM marker")) + ).scalar_one() + + async with second_db(): + second_value = ( + await second_db.session.execute(text("SELECT value FROM marker")) + ).scalar_one() + + assert first_value == "first" + assert second_value == "second" + finally: + await first_engine.dispose() + await second_engine.dispose() + + +@pytest.mark.asyncio +async def test_throttled_gather_fail_fast_closes_unopened_coroutines(): + # Scheduling-order assumption: with max_concurrent=1, only `first` can hold + # the semaphore slot at any time. When it raises, fail-fast cancels the + # `second`/`third` tasks before they ever enter their throttled wrapper — + # so their underlying coroutines must remain unstarted (CORO_CREATED) and + # be `coro.close()`'d explicitly by db.gather() rather than awaited. If + # this scheduling assumption ever breaks (e.g. a future scheduler runs + # `_throttled` for second/third synchronously), the assertion must still + # hold — it is the contract — but the path that satisfies it changes. + SQLAlchemyMiddleware, db = _make_middleware_and_db() + SQLAlchemyMiddleware(FastAPI(), db_url=DB_URL) + + async def fail_fast(): + raise RuntimeError("boom") + + async def never_opened(): + await asyncio.sleep(0) + return "not reached" + + first = fail_fast() + second = never_opened() + third = never_opened() + + with pytest.raises(RuntimeError, match="boom"): + async with db(multi_sessions=True, max_concurrent=1): + await db.gather(first, second, third) + + assert inspect.getcoroutinestate(second) == inspect.CORO_CLOSED + assert inspect.getcoroutinestate(third) == inspect.CORO_CLOSED + + +@pytest.mark.asyncio +async def test_throttled_gather_rejects_precreated_task_and_future(): + SQLAlchemyMiddleware, db = _make_middleware_and_db() + SQLAlchemyMiddleware(FastAPI(), db_url=DB_URL) + loop = asyncio.get_running_loop() + task = loop.create_task(asyncio.sleep(10)) + future = loop.create_future() + + try: + async with db(multi_sessions=True, max_concurrent=1): + with pytest.raises(TypeError, match="coroutine objects"): + await db.gather(task) + + with pytest.raises(TypeError, match="coroutine objects"): + await db.gather(future) + finally: + task.cancel() + future.cancel() + await asyncio.gather(task, future, return_exceptions=True) + + +@pytest.mark.asyncio +async def test_child_session_creation_is_rejected_after_multi_state_starts_closing(): + SQLAlchemyMiddleware, db = _make_middleware_and_db() + SQLAlchemyMiddleware(FastAPI(), db_url=DB_URL) + multi_state_var = _get_ctx_var(db, "_multi_state") + + async with db(multi_sessions=True): + state = multi_state_var.get() + assert state is not None + state.closing = True + + async def use_db_after_cleanup_started(): + return db.session + + with pytest.raises(RuntimeError, match="closing"): + await asyncio.create_task(use_db_after_cleanup_started()) + + +@pytest.mark.asyncio +async def test_rejected_db_url_reinit_does_not_allocate_engine(monkeypatch): + """A rejected `db_url` re-init must not allocate a fresh engine.""" + import fastapi_async_sqlalchemy.middleware as middleware_module + + SQLAlchemyMiddleware, _db = _make_middleware_and_db() + app = FastAPI() + first_engine = create_async_engine(DB_URL) + + try: + SQLAlchemyMiddleware(app, custom_engine=first_engine) + + create_calls = [] + + def fail_if_called(*args, **kwargs): + create_calls.append((args, kwargs)) + raise AssertionError( + "create_async_engine must not be called when proxy re-init is rejected" + ) + + monkeypatch.setattr(middleware_module, "create_async_engine", fail_if_called) + + with pytest.raises(RuntimeError, match="create_middleware_and_session_proxy"): + SQLAlchemyMiddleware(app, db_url=DB_URL) + + assert create_calls == [] + finally: + await first_engine.dispose() + + +@pytest.mark.asyncio +async def test_dispose_failure_clears_bindings_and_allows_retry(monkeypatch): + """If engine.dispose() raises, bindings clear and a retry actually re-runs it.""" + import fastapi_async_sqlalchemy.middleware as middleware_module + + SQLAlchemyMiddleware, db = _make_middleware_and_db() + engine = create_async_engine(DB_URL) + monkeypatch.setattr(middleware_module, "create_async_engine", lambda *_, **__: engine) + + middleware = SQLAlchemyMiddleware(FastAPI(), db_url=DB_URL) + + # _Session_engine / _Session are nonlocals in the proxy factory closure; + # the middleware's `dispose` method closes over both, so we can inspect + # the same cells through it. + closure = dict( + zip( + middleware.dispose.__func__.__code__.co_freevars, + middleware.dispose.__func__.__closure__, + strict=True, + ) + ) + assert closure["_Session_engine"].cell_contents is engine + assert closure["_Session"].cell_contents is not None + + real_dispose = AsyncEngine.dispose + call_count = {"n": 0} + + async def flaky_dispose(self, *args, **kwargs): + if self is engine: + call_count["n"] += 1 + if call_count["n"] == 1: + raise RuntimeError("dispose boom") + return await real_dispose(self, *args, **kwargs) + + monkeypatch.setattr(AsyncEngine, "dispose", flaky_dispose) + + with pytest.raises(RuntimeError, match="dispose boom"): + await middleware.dispose() + + # Bindings cleared by the finally block. + assert closure["_Session_engine"].cell_contents is None + assert closure["_Session"].cell_contents is None + + # _engine_disposed must NOT be set on failure, so a retry actually re-runs. + assert middleware._engine_disposed is False + + # Retry: now succeeds and underlying dispose is invoked again. + await middleware.dispose() + assert call_count["n"] == 2 + assert middleware._engine_disposed is True + + +@pytest.mark.asyncio +async def test_lifespan_ack_is_forwarded_when_dispose_raises(monkeypatch): + """ASGI lifespan ack must be forwarded even if engine disposal raises.""" + import fastapi_async_sqlalchemy.middleware as middleware_module + + SQLAlchemyMiddleware, _db = _make_middleware_and_db() + engine = create_async_engine(DB_URL) + monkeypatch.setattr(middleware_module, "create_async_engine", lambda *_, **__: engine) + + sent_messages = [] + shutdown_sent = [] + + async def receive(): + # First call: initiate shutdown. Subsequent calls would park, but the + # inner_app acks immediately so this branch is unreachable in practice. + if not shutdown_sent: + shutdown_sent.append(True) + return {"type": "lifespan.shutdown"} + await asyncio.Event().wait() # pragma: no cover + + async def send(message): + sent_messages.append(message) + + async def inner_app(_scope, receive_, send_): + # Minimal lifespan handler: receive shutdown, send shutdown.complete. + msg = await receive_() + assert msg["type"] == "lifespan.shutdown" + await send_({"type": "lifespan.shutdown.complete"}) + + middleware = SQLAlchemyMiddleware(inner_app, db_url=DB_URL) + + async def failing_dispose(): + raise RuntimeError("dispose fail") + + monkeypatch.setattr(middleware, "dispose", failing_dispose) + + with pytest.raises(RuntimeError, match="dispose fail"): + await middleware({"type": "lifespan"}, receive, send) + + assert sent_messages == [{"type": "lifespan.shutdown.complete"}] + + # Cleanup: real engine still needs disposal. + await engine.dispose() + + +@pytest.mark.asyncio +async def test_throttled_gather_cancellation_emits_no_resource_warning(): + """db.gather() fail-fast must not emit `coroutine ... was never awaited`.""" + SQLAlchemyMiddleware, db = _make_middleware_and_db() + SQLAlchemyMiddleware(FastAPI(), db_url=DB_URL) + + async def fail_fast(): + raise RuntimeError("boom") + + async def never_opened(): + await asyncio.sleep(0) + return "not reached" + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + + first = fail_fast() + second = never_opened() + third = never_opened() + + with pytest.raises(RuntimeError, match="boom"): + async with db(multi_sessions=True, max_concurrent=1): + await db.gather(first, second, third) + + leak_warnings = [ + w + for w in caught + if issubclass(w.category, ResourceWarning) and "was never awaited" in str(w.message) + ] + assert leak_warnings == [], ( + f"Unexpected ResourceWarning(s): {[str(w.message) for w in leak_warnings]}" + ) diff --git a/tests/test_session.py b/tests/test_session.py index 1abe5ce..121951d 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,10 +1,10 @@ import asyncio import pytest +from fastapi.testclient import TestClient from sqlalchemy import text from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine -from starlette.middleware.base import BaseHTTPMiddleware from fastapi_async_sqlalchemy.exceptions import ( MissingSessionError, @@ -15,9 +15,11 @@ @pytest.mark.asyncio -async def test_init(app, SQLAlchemyMiddleware): +async def test_init(app, db, SQLAlchemyMiddleware): mw = SQLAlchemyMiddleware(app, db_url=db_url) - assert isinstance(mw, BaseHTTPMiddleware) + # Pure ASGI middleware: must be callable with (scope, receive, send). + assert callable(mw) + assert mw.app is app @pytest.mark.asyncio @@ -31,7 +33,10 @@ async def test_init_required_args(app, SQLAlchemyMiddleware): @pytest.mark.asyncio async def test_init_required_args_custom_engine(app, db, SQLAlchemyMiddleware): custom_engine = create_async_engine(db_url) - SQLAlchemyMiddleware(app, custom_engine=custom_engine) + try: + SQLAlchemyMiddleware(app, custom_engine=custom_engine) + finally: + await custom_engine.dispose() @pytest.mark.asyncio @@ -60,14 +65,15 @@ async def test_init_incorrect_optional_args(app, SQLAlchemyMiddleware): @pytest.mark.asyncio -async def test_inside_route(app, client, db, SQLAlchemyMiddleware): +async def test_inside_route(app, db, SQLAlchemyMiddleware): app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) @app.get("/") def test_get(): assert isinstance(db.session, AsyncSession) - client.get("/") + with TestClient(app) as client: + client.get("/") @pytest.mark.asyncio @@ -82,7 +88,7 @@ def test_get(): @pytest.mark.asyncio async def test_outside_of_route(app, db, SQLAlchemyMiddleware): - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(): assert isinstance(db.session, AsyncSession) @@ -100,7 +106,7 @@ async def test_outside_of_route_without_middleware_fails(db): @pytest.mark.asyncio async def test_outside_of_route_without_context_fails(app, db, SQLAlchemyMiddleware): - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) with pytest.raises(MissingSessionError): _ = db.session @@ -108,7 +114,7 @@ async def test_outside_of_route_without_context_fails(app, db, SQLAlchemyMiddlew @pytest.mark.asyncio async def test_init_session(app, db, SQLAlchemyMiddleware): - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(): assert isinstance(db.session, AsyncSession) @@ -116,7 +122,7 @@ async def test_init_session(app, db, SQLAlchemyMiddleware): @pytest.mark.asyncio async def test_db_session_commit_fail(app, db, SQLAlchemyMiddleware): - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url, commit_on_exit=True) + SQLAlchemyMiddleware(app, db_url=db_url, commit_on_exit=True) with pytest.raises(IntegrityError): async with db(): @@ -132,7 +138,7 @@ async def test_rollback(app, db, SQLAlchemyMiddleware): # pytest-cov shows that the line in db.__exit__() rolling back the db session # when there is an Exception is run correctly. However, it would be much better # if we could demonstrate somehow that db.session.rollback() was called e.g. once - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) with pytest.raises(RuntimeError): async with db(): @@ -144,7 +150,7 @@ async def test_rollback(app, db, SQLAlchemyMiddleware): @pytest.mark.parametrize("commit_on_exit", [True, False]) @pytest.mark.asyncio async def test_db_context_session_args(app, db, SQLAlchemyMiddleware, commit_on_exit): - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url, commit_on_exit=commit_on_exit) + SQLAlchemyMiddleware(app, db_url=db_url, commit_on_exit=commit_on_exit) session_args = {} @@ -158,7 +164,7 @@ async def test_db_context_session_args(app, db, SQLAlchemyMiddleware, commit_on_ @pytest.mark.asyncio async def test_multi_sessions(app, db, SQLAlchemyMiddleware): - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(multi_sessions=True): @@ -180,7 +186,7 @@ async def execute_query(query): @pytest.mark.asyncio async def test_concurrent_inserts(app, db, SQLAlchemyMiddleware): - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(multi_sessions=True, commit_on_exit=True): await db.session.execute( diff --git a/tests/test_single_session_no_gather.py b/tests/test_single_session_no_gather.py index 4b1cba7..9c251ad 100644 --- a/tests/test_single_session_no_gather.py +++ b/tests/test_single_session_no_gather.py @@ -19,7 +19,7 @@ @pytest.mark.asyncio async def test_single_session_sequential_queries(app, db, SQLAlchemyMiddleware): """Sequential queries should work with single session.""" - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(): result1 = await db.session.execute(text("SELECT 1")) @@ -31,7 +31,7 @@ async def test_single_session_sequential_queries(app, db, SQLAlchemyMiddleware): @pytest.mark.asyncio async def test_single_session_same_instance(app, db, SQLAlchemyMiddleware): """Same session instance should be returned within context.""" - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(): session1 = db.session @@ -43,7 +43,7 @@ async def test_single_session_same_instance(app, db, SQLAlchemyMiddleware): @pytest.mark.asyncio async def test_single_session_gather_fails(app, db, SQLAlchemyMiddleware): """asyncio.gather() without multi_sessions=True should raise an error.""" - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) with pytest.raises((InvalidRequestError, IllegalStateChangeError)): async with db(): @@ -56,13 +56,13 @@ async def test_single_session_gather_fails(app, db, SQLAlchemyMiddleware): @pytest.mark.asyncio async def test_multi_sessions_gather_with_tasks(app, db, SQLAlchemyMiddleware): """asyncio.gather() with multi_sessions=True and create_task should work.""" - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(multi_sessions=True): async def query(n): res = await db.session.execute(text(f"SELECT {n}")) - return res + return res.scalar() tasks = [ asyncio.create_task(query(1)), @@ -70,12 +70,11 @@ async def query(n): ] results = await asyncio.gather(*tasks) - assert results[0].scalar() == 1 - assert results[1].scalar() == 2 + assert results == [1, 2] @pytest.mark.asyncio -async def test_single_session_in_route(app, client, db, SQLAlchemyMiddleware): +async def test_single_session_in_route(app, db, SQLAlchemyMiddleware): """Single session should work in route handler.""" app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) @@ -84,13 +83,16 @@ async def test_route(): result = await db.session.execute(text("SELECT 42")) return {"value": result.scalar()} - response = client.get("/test") + from fastapi.testclient import TestClient + + with TestClient(app) as client: + response = client.get("/test") assert response.status_code == 200 assert response.json() == {"value": 42} @pytest.mark.asyncio -async def test_single_session_multiple_sequential_in_route(app, client, db, SQLAlchemyMiddleware): +async def test_single_session_multiple_sequential_in_route(app, db, SQLAlchemyMiddleware): """Multiple sequential queries in route should work.""" app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) @@ -101,6 +103,9 @@ async def test_route(): r3 = await db.session.execute(text("SELECT 3")) return {"values": [r1.scalar(), r2.scalar(), r3.scalar()]} - response = client.get("/test-sequential") + from fastapi.testclient import TestClient + + with TestClient(app) as client: + response = client.get("/test-sequential") assert response.status_code == 200 assert response.json() == {"values": [1, 2, 3]} diff --git a/tests/test_sqlmodel.py b/tests/test_sqlmodel.py index 4a0ba6b..fc54300 100644 --- a/tests/test_sqlmodel.py +++ b/tests/test_sqlmodel.py @@ -1,6 +1,5 @@ -from typing import Optional - import pytest +from fastapi.testclient import TestClient from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession @@ -26,17 +25,17 @@ class Hero(SQLModel, table=True): # type: ignore __tablename__ = "test_hero" - id: Optional[int] = Field(default=None, primary_key=True) + id: int | None = Field(default=None, primary_key=True) name: str = Field(index=True) secret_name: str - age: Optional[int] = Field(default=None, index=True) + age: int | None = Field(default=None, index=True) @pytest.mark.skipif(not SQLMODEL_AVAILABLE, reason="SQLModel not available") @pytest.mark.asyncio async def test_sqlmodel_session_type(app, db, SQLAlchemyMiddleware): """Test that SQLModel's AsyncSession is used when SQLModel is available""" - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(): # Should be SQLModel's AsyncSession, not regular SQLAlchemy AsyncSession @@ -48,7 +47,7 @@ async def test_sqlmodel_session_type(app, db, SQLAlchemyMiddleware): @pytest.mark.asyncio async def test_sqlmodel_exec_method_exists(app, db, SQLAlchemyMiddleware): """Test that the .exec() method is available on the session""" - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(): # Test that exec method exists @@ -60,7 +59,7 @@ async def test_sqlmodel_exec_method_exists(app, db, SQLAlchemyMiddleware): @pytest.mark.asyncio async def test_sqlmodel_exec_method_basic_query(app, db, SQLAlchemyMiddleware): """Test that the .exec() method works with basic SQLModel queries""" - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(): # Create tables using the session's bind engine @@ -79,7 +78,7 @@ async def test_sqlmodel_exec_method_basic_query(app, db, SQLAlchemyMiddleware): @pytest.mark.asyncio async def test_sqlmodel_exec_crud_operations(app, db, SQLAlchemyMiddleware): """Test CRUD operations using SQLModel with .exec() method""" - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(commit_on_exit=True): # Create tables using the session's bind engine @@ -110,7 +109,7 @@ async def test_sqlmodel_exec_crud_operations(app, db, SQLAlchemyMiddleware): @pytest.mark.asyncio async def test_sqlmodel_exec_with_where_clause(app, db, SQLAlchemyMiddleware): """Test .exec() method with WHERE clauses""" - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(commit_on_exit=True): # Create tables using the session's bind engine @@ -143,7 +142,7 @@ async def test_sqlmodel_exec_with_where_clause(app, db, SQLAlchemyMiddleware): @pytest.mark.asyncio async def test_sqlmodel_exec_returns_sqlmodel_objects(app, db, SQLAlchemyMiddleware): """Test that .exec() returns actual SQLModel objects, not Row objects""" - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(commit_on_exit=True): # Create tables using the session's bind engine @@ -173,7 +172,7 @@ async def test_sqlmodel_exec_returns_sqlmodel_objects(app, db, SQLAlchemyMiddlew @pytest.mark.asyncio async def test_backward_compatibility_with_regular_execute(app, db, SQLAlchemyMiddleware): """Test that regular SQLAlchemy .execute() method still works for backward compatibility""" - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(): # Test regular execute with text query @@ -186,7 +185,7 @@ async def test_backward_compatibility_with_regular_execute(app, db, SQLAlchemyMi @pytest.mark.asyncio async def test_session_type_without_sqlmodel(app, db, SQLAlchemyMiddleware): """Test that when SQLModel is not available, regular AsyncSession is used""" - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(): # Should still be an AsyncSession (either SQLModel or regular) @@ -201,7 +200,7 @@ async def test_session_type_without_sqlmodel(app, db, SQLAlchemyMiddleware): @pytest.mark.skipif(not SQLMODEL_AVAILABLE, reason="SQLModel not available") @pytest.mark.asyncio -async def test_sqlmodel_exec_in_route(app, client, db, SQLAlchemyMiddleware): +async def test_sqlmodel_exec_in_route(app, db, SQLAlchemyMiddleware): """Test SQLModel .exec() method works inside FastAPI routes""" app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) @@ -227,7 +226,8 @@ async def test_route(): "name": found_hero.name if found_hero else None, } - response = client.get("/test-sqlmodel") + with TestClient(app) as client: + response = client.get("/test-sqlmodel") data = response.json() assert data["found"] is True assert data["is_sqlmodel"] is True @@ -238,7 +238,7 @@ async def test_route(): @pytest.mark.asyncio async def test_sqlmodel_exec_multi_sessions(app, db, SQLAlchemyMiddleware): """Test SQLModel .exec() method works with multi_sessions=True""" - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(multi_sessions=True): async with db.session.bind.begin() as conn: @@ -264,7 +264,7 @@ async def test_sqlmodel_exec_multi_sessions(app, db, SQLAlchemyMiddleware): @pytest.mark.asyncio async def test_sqlmodel_session_has_both_exec_and_execute(app, db, SQLAlchemyMiddleware): """Test that SQLModel session has both .exec() and .execute() methods""" - app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + SQLAlchemyMiddleware(app, db_url=db_url) async with db(): # Should have both methods diff --git a/tests/test_streaming_and_waiter_shutdown.py b/tests/test_streaming_and_waiter_shutdown.py new file mode 100644 index 0000000..e613eee --- /dev/null +++ b/tests/test_streaming_and_waiter_shutdown.py @@ -0,0 +1,90 @@ +"""Regression tests for two specific bugs: + +1. Streaming response generators that need database access must own an + explicit ``async with db()`` body-lifetime session. + +2. ``DBSession.__aexit__`` cancelled tasks already in ``state.task_sessions`` + but ignored tasks parked on the semaphore inside ``db.connection()``. + A waiter could acquire the slot freed by a cancelled task, create a + session post-closing, and run queries against state being torn down. +""" + +from __future__ import annotations + +import asyncio + +import pytest +from fastapi.responses import StreamingResponse +from fastapi.testclient import TestClient +from sqlalchemy import text + +db_url = "sqlite+aiosqlite://" + + +@pytest.mark.asyncio +async def test_streaming_response_keeps_session_open(app, db, SQLAlchemyMiddleware): + """An explicit streaming DB session must outlive a StreamingResponse body.""" + app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + + @app.get("/stream") + async def stream(): + async def body(): + async with db(): + for i in range(3): + result = await db.session.execute(text(f"SELECT {i}")) + yield f"{result.scalar()}\n".encode() + + return StreamingResponse(body(), media_type="text/plain") + + with TestClient(app) as client: + resp = client.get("/stream") + assert resp.status_code == 200 + assert resp.text == "0\n1\n2\n" + + +@pytest.mark.asyncio +async def test_semaphore_waiters_cancelled_on_shutdown(app, db, SQLAlchemyMiddleware): + """Tasks parked on db.connection()'s semaphore must be cancelled when + the owning multi-session context exits, and must not create a session + after closing has begun.""" + SQLAlchemyMiddleware(app, db_url=db_url) + + holder_started = asyncio.Event() + holder_release = asyncio.Event() + waiter_started = asyncio.Event() + waiter_ran_query = False + + async def holder(): + async with db.connection() as session: + holder_started.set() + await session.execute(text("SELECT 1")) + await holder_release.wait() + + async def waiter(): + nonlocal waiter_ran_query + waiter_started.set() + async with db.connection() as session: + # Should never get here once shutdown begins. + await session.execute(text("SELECT 2")) + waiter_ran_query = True + + async with db(multi_sessions=True, max_concurrent=1): + holder_task = asyncio.create_task(holder()) + await holder_started.wait() + + waiter_task = asyncio.create_task(waiter()) + # Give the waiter a chance to enter __aenter__ and park on acquire(). + await waiter_started.wait() + await asyncio.sleep(0) + await asyncio.sleep(0) + + # Let the holder finish so its slot is released. + holder_release.set() + await holder_task + + # __aexit__ has now run; the waiter must not have created a session. + assert waiter_ran_query is False + assert waiter_task.done() + # On a clean context exit, __aexit__ cancels tasks parked on the + # semaphore deterministically — the waiter must be reported as cancelled. + assert waiter_task.cancelled() diff --git a/tests/test_type_hints_compatibility.py b/tests/test_type_hints_compatibility.py index 8a07a1a..92f1f63 100644 --- a/tests/test_type_hints_compatibility.py +++ b/tests/test_type_hints_compatibility.py @@ -82,7 +82,7 @@ def get_db(self) -> DBSessionMeta: def test_type_checking_with_callable(): """Test that DBSessionMeta works with callable type checking""" - from typing import Callable + from collections.abc import Callable def factory() -> Callable[[], DBSessionMeta]: # type: ignore[valid-type] def get_session() -> DBSessionMeta: # type: ignore[valid-type] @@ -146,9 +146,8 @@ def func2(session_proxy: DBSessionMeta) -> bool: return session_proxy is db # Pattern 3: Optional type - from typing import Optional - def func3() -> Optional[DBSessionMeta]: + def func3() -> DBSessionMeta | None: return db # Verify all patterns work