diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index 5c2a1468..3df0d31b 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -383,7 +383,7 @@ async def listen(self, finish_event: asyncio.Event) -> None: # pragma: no cover if self.on_exit is not None: self.on_exit(self) - async def prefetcher( + async def prefetcher( # noqa: C901 self, queue: "asyncio.Queue[bytes | AckableMessage]", finish_event: asyncio.Event, @@ -396,48 +396,70 @@ async def prefetcher( """ fetched_tasks: int = 0 iterator = self.broker.listen() - current_message: asyncio.Task[bytes | AckableMessage] = asyncio.create_task( - iterator.__anext__(), # type: ignore - ) + current_message: asyncio.Task[bytes | AckableMessage] | None = None - while True: - if finish_event.is_set(): - break - try: - await self.sem_prefetch.acquire() - if ( - self.max_tasks_to_execute - and fetched_tasks >= self.max_tasks_to_execute - ): - logger.info("Max number of tasks executed.") + try: + while not finish_event.is_set(): + try: + await self.sem_prefetch.acquire() + if ( + self.max_tasks_to_execute + and fetched_tasks >= self.max_tasks_to_execute + ): + logger.info("Max number of tasks executed.") + break + if current_message is None: + current_message = asyncio.create_task( + iterator.__anext__(), # type: ignore + ) + # Here we wait for the message to be fetched, + # but we make it with timeout so it can be interrupted + done, _ = await asyncio.wait({current_message}, timeout=0.3) + # If the message is not fetched, we release the semaphore + # and continue the loop. So it will check if finished event was set. + if not done: + self.sem_prefetch.release() + continue + # We're done, so now we need to check + # whether task has returned an error. + message = current_message.result() + current_message = None + fetched_tasks += 1 + await queue.put(message) + # Custom hooks for OTel and any future instrumentations + for middleware in reversed(self.broker.middlewares): + if hasattr(middleware, "on_prefetch_queue_add"): + await maybe_awaitable( + middleware.on_prefetch_queue_add(), # type: ignore + ) + except (asyncio.CancelledError, StopAsyncIteration): break - # Here we wait for the message to be fetched, - # but we make it with timeout so it can be interrupted - done, _ = await asyncio.wait({current_message}, timeout=0.3) - # If the message is not fetched, we release the semaphore - # and continue the loop. So it will check if finished event was set. - if not done: - self.sem_prefetch.release() + except Exception: + logger.exception("Error while prefetching.") + # current_message set => fetch failed before enqueue, so we + # still own the permit and a (possibly broken) iterator. + # Otherwise it's queued and the runner owns the permit; + # releasing here would leak a prefetch slot. + if current_message is not None: + current_message = None + iterator = self.broker.listen() + self.sem_prefetch.release() continue - # We're done, so now we need to check - # whether task has returned an error. - message = current_message.result() - current_message = asyncio.create_task(iterator.__anext__()) # type: ignore - fetched_tasks += 1 - await queue.put(message) - # Custom hooks for OTel and any future instrumentations - for middleware in reversed(self.broker.middlewares): - if hasattr(middleware, "on_prefetch_queue_add"): - await maybe_awaitable( - middleware.on_prefetch_queue_add(), # type: ignore - ) - except (asyncio.CancelledError, StopAsyncIteration): - break - # We don't want to fetch new messages if we are shutting down. - logger.info("Stopping prefetching messages...") - current_message.cancel() - await queue.put(QUEUE_DONE) - self.sem_prefetch.release() + finally: + # We don't want to fetch new messages if we are shutting down. + logger.info("Stopping prefetching messages...") + # Short window to deliver, then forward or cancel. + if current_message is not None: + await asyncio.wait({current_message}, timeout=0.3) + if not current_message.done(): + current_message.cancel() + elif ( + not current_message.cancelled() + and current_message.exception() is None + ): + await queue.put(current_message.result()) + await queue.put(QUEUE_DONE) + self.sem_prefetch.release() async def runner( self, diff --git a/tests/receiver/test_receiver.py b/tests/receiver/test_receiver.py index eeb29c11..bed822d4 100644 --- a/tests/receiver/test_receiver.py +++ b/tests/receiver/test_receiver.py @@ -3,7 +3,7 @@ import random import time import unittest.mock -from collections.abc import Generator +from collections.abc import AsyncGenerator, Generator from concurrent.futures import ThreadPoolExecutor from functools import wraps from typing import Any, ClassVar @@ -600,3 +600,68 @@ async def test_no_semaphore_without_max_async_tasks() -> None: """Test that semaphore is None when max_async_tasks is not set.""" receiver = get_receiver(max_async_tasks=None) assert receiver.sem is None + + +async def test_prefetcher_does_not_pop_message_past_max_tasks() -> None: + """Test not pulling a message without the intention of running it.""" + broker = AsyncQueueBroker() + + @broker.task + async def noop() -> None: + return None + + for _ in range(6): + await noop.kiq() + + assert broker.queue.qsize() == 6 + + receiver = Receiver( + broker, + executor=ThreadPoolExecutor(max_workers=1), + max_async_tasks=1, + max_tasks_to_execute=5, + ) + + await receiver.listen(asyncio.Event()) + + assert broker.queue.qsize() == 1 + + +async def test_prefetcher_recovers_from_transient_listen_error() -> None: + """A transient error mid-prefetch must not kill the prefetcher.""" + + class FlakyBroker(AsyncQueueBroker): + def __init__(self) -> None: + super().__init__() + self.fail_once = True + + async def listen(self) -> AsyncGenerator[AckableMessage, None]: + while True: + data = await self.queue.get() + if self.fail_once: + self.fail_once = False + self.queue.task_done() + raise RuntimeError("transient broker hiccup") + yield AckableMessage(data=data, ack=self.queue.task_done) + + broker = FlakyBroker() + ran = 0 + + @broker.task + async def collector() -> None: + nonlocal ran + ran += 1 + + await collector.kiq() # consumed by the transient error + await collector.kiq() # prefetcher recovering + + receiver = Receiver( + broker, + executor=ThreadPoolExecutor(max_workers=1), + max_async_tasks=1, + max_tasks_to_execute=1, + ) + + await asyncio.wait_for(receiver.listen(asyncio.Event()), timeout=5) + + assert ran == 1