diff --git a/arroyo/processing/strategies/run_task.py b/arroyo/processing/strategies/run_task.py index 29c4d931..11627d26 100644 --- a/arroyo/processing/strategies/run_task.py +++ b/arroyo/processing/strategies/run_task.py @@ -1,8 +1,9 @@ from __future__ import annotations +import time from typing import Callable, Generic, Optional, TypeVar, Union, cast -from arroyo.processing.strategies.abstract import ProcessingStrategy +from arroyo.processing.strategies.abstract import MessageRejected, ProcessingStrategy from arroyo.processing.strategies.guard import StrategyGuard from arroyo.types import FilteredPayload, Message, TStrategyPayload @@ -24,12 +25,13 @@ def __new__( cls, function: Callable[[Message[TStrategyPayload]], TResult], next_step: ProcessingStrategy[Union[FilteredPayload, TResult]], + better_backpressure: bool = False, ) -> RunTask[TStrategyPayload, TResult]: def build_self( next_step: ProcessingStrategy[Union[FilteredPayload, TResult]] ) -> ProcessingStrategy[Union[FilteredPayload, TResult]]: self = object.__new__(RunTask) - self.__init__(function, next_step) # type: ignore + self.__init__(function, next_step, better_backpressure) # type: ignore return self return cast( @@ -40,25 +42,67 @@ def __init__( self, function: Callable[[Message[TStrategyPayload]], TResult], next_step: ProcessingStrategy[Union[FilteredPayload, TResult]], + better_backpressure: bool = False, ) -> None: self.__function = function self.__next_step = next_step + self.__better_backpressure = better_backpressure + self.__message_carried_over: Optional[Message[TResult]] = None def submit( self, message: Message[Union[FilteredPayload, TStrategyPayload]] ) -> None: - result = self.__function(cast(Message[TStrategyPayload], message)) - value = message.value.replace(result) - self.__next_step.submit(Message(value)) + if self.__better_backpressure: + if self.__message_carried_over is not None: + raise MessageRejected(message) + + result = self.__function(cast(Message[TStrategyPayload], message)) + value = message.value.replace(result) + transformed: Message[TResult] = Message(value) + + try: + self.__next_step.submit(transformed) + except MessageRejected: + self.__message_carried_over = transformed + else: + result = self.__function(cast(Message[TStrategyPayload], message)) + value = message.value.replace(result) + self.__next_step.submit(Message(value)) def poll(self) -> None: self.__next_step.poll() + if self.__better_backpressure and self.__message_carried_over is not None: + try: + self.__next_step.submit(self.__message_carried_over) + self.__message_carried_over = None + except MessageRejected: + pass + def join(self, timeout: Optional[float] = None) -> None: - self.__next_step.join(timeout=timeout) + deadline = time.time() + timeout if timeout is not None else None + + if self.__better_backpressure: + msg = self.__message_carried_over + if msg is not None: + while deadline is None or time.time() < deadline: + self.__next_step.poll() + try: + self.__next_step.submit(msg) + self.__message_carried_over = None + break + except MessageRejected: + pass + + remaining = max(deadline - time.time(), 0) if deadline is not None else None + self.__next_step.close() + self.__next_step.join(timeout=remaining) + else: + self.__next_step.join(timeout=timeout) def close(self) -> None: - self.__next_step.close() + if not self.__better_backpressure: + self.__next_step.close() def terminate(self) -> None: self.__next_step.terminate() diff --git a/tests/processing/strategies/test_run_task.py b/tests/processing/strategies/test_run_task.py index 5ad3e76f..e371351c 100644 --- a/tests/processing/strategies/test_run_task.py +++ b/tests/processing/strategies/test_run_task.py @@ -2,17 +2,21 @@ from datetime import datetime from unittest.mock import Mock, call +import pytest + +from arroyo.processing.strategies.abstract import MessageRejected from arroyo.processing.strategies.run_task import RunTask from arroyo.types import BrokerValue, Message, Partition, Topic, Value from tests.assertions import assert_changes -def test_run_task() -> None: +@pytest.mark.parametrize("better_backpressure", [False, True]) +def test_run_task(better_backpressure: bool) -> None: mock_func = Mock() next_step = Mock() now = datetime.now() - strategy = RunTask(mock_func, next_step) + strategy = RunTask(mock_func, next_step, better_backpressure=better_backpressure) partition = Partition(Topic("topic"), 0) strategy.submit(Message(Value(b"hello", {partition: 1}, now))) @@ -42,14 +46,17 @@ def test_run_task() -> None: assert next_step.submit.call_count == 2 -def test_transform() -> None: +@pytest.mark.parametrize("better_backpressure", [False, True]) +def test_transform(better_backpressure: bool) -> None: next_step = Mock() now = datetime.now() def transform_function(value: Message[int]) -> int: return value.payload * 2 - transform_step = RunTask(transform_function, next_step) + transform_step = RunTask( + transform_function, next_step, better_backpressure=better_backpressure + ) original_message = Message(Value(1, {Partition(Topic("topic"), 0): 1}, now)) @@ -95,3 +102,63 @@ def transform_function(value: Message[int]) -> int: ) ) ) + + +def test_backpressure_function_called_once() -> None: + """ + With better_backpressure=True, the function should only be called once + per message even when next_step raises MessageRejected. + """ + call_count = 0 + + def counting_function(msg: Message[bytes]) -> bytes: + nonlocal call_count + call_count += 1 + return msg.payload + + next_step = Mock() + # First call rejects, second accepts + next_step.submit.side_effect = [MessageRejected(Mock()), None] + + strategy = RunTask(counting_function, next_step, better_backpressure=True) + partition = Partition(Topic("topic"), 0) + now = datetime.now() + + msg = Message(Value(b"hello", {partition: 1}, now)) + strategy.submit(msg) + + assert call_count == 1 + + # poll() should retry and succeed + strategy.poll() + + assert call_count == 1 + assert next_step.submit.call_count == 2 + + +def test_backpressure_join_flushes_message() -> None: + """ + With better_backpressure=True, join() should flush carried-over messages. + """ + + def identity(msg: Message[bytes]) -> bytes: + return msg.payload + + next_step = Mock() + # First submit rejects, then accepts during join + next_step.submit.side_effect = [MessageRejected(Mock()), None] + + strategy = RunTask(identity, next_step, better_backpressure=True) + partition = Partition(Topic("topic"), 0) + now = datetime.now() + + msg = Message(Value(b"hello", {partition: 1}, now)) + strategy.submit(msg) + + assert next_step.submit.call_count == 1 + + # join() should flush the carried-over message + strategy.join(timeout=1.0) + + assert next_step.submit.call_count == 2 + assert next_step.join.call_count == 1