Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 51 additions & 7 deletions arroyo/processing/strategies/run_task.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import time
from typing import Callable, Generic, Optional, TypeVar, Union, cast

from arroyo.processing.strategies.abstract import ProcessingStrategy
from arroyo.processing.strategies.abstract import MessageRejected, ProcessingStrategy
from arroyo.processing.strategies.guard import StrategyGuard
from arroyo.types import FilteredPayload, Message, TStrategyPayload

Expand All @@ -24,12 +25,13 @@ def __new__(
cls,
function: Callable[[Message[TStrategyPayload]], TResult],
next_step: ProcessingStrategy[Union[FilteredPayload, TResult]],
better_backpressure: bool = False,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this works well it seems like it could be the default going forward.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think this should be the intent, just didn't want to take the risk now

) -> RunTask[TStrategyPayload, TResult]:
def build_self(
next_step: ProcessingStrategy[Union[FilteredPayload, TResult]]
) -> ProcessingStrategy[Union[FilteredPayload, TResult]]:
self = object.__new__(RunTask)
self.__init__(function, next_step) # type: ignore
self.__init__(function, next_step, better_backpressure) # type: ignore
return self

return cast(
Expand All @@ -40,25 +42,67 @@ def __init__(
self,
function: Callable[[Message[TStrategyPayload]], TResult],
next_step: ProcessingStrategy[Union[FilteredPayload, TResult]],
better_backpressure: bool = False,
) -> None:
self.__function = function
self.__next_step = next_step
self.__better_backpressure = better_backpressure
self.__message_carried_over: Optional[Message[TResult]] = None

def submit(
    self, message: Message[Union[FilteredPayload, TStrategyPayload]]
) -> None:
    """
    Run the task function on ``message`` and hand the result to the next step.

    With ``better_backpressure`` enabled, a result that the next step
    rejects is stored as the carried-over message instead of letting the
    rejection propagate, so the function is not executed again for that
    message. While a carried-over message is pending, new messages are
    refused before the function is run.

    :param message: the incoming message; its value is replaced with the
        function's return value before being forwarded.
    :raises MessageRejected: with ``better_backpressure`` enabled, when a
        previously transformed message is still pending; with it disabled,
        whenever the next step raises it.
    """
    backpressure = self.__better_backpressure

    # Refuse new work until the pending result has been accepted downstream.
    if backpressure and self.__message_carried_over is not None:
        raise MessageRejected(message)

    result = self.__function(cast(Message[TStrategyPayload], message))
    transformed: Message[TResult] = Message(message.value.replace(result))

    if not backpressure:
        # Legacy behaviour: a rejection from the next step propagates.
        self.__next_step.submit(transformed)
        return

    try:
        self.__next_step.submit(transformed)
    except MessageRejected:
        # Keep the already-computed result so it can be retried later.
        self.__message_carried_over = transformed

def poll(self) -> None:
    """
    Poll the next step, then retry any carried-over message.

    The retry only happens when ``better_backpressure`` is enabled and a
    transformed message is pending. If the next step rejects it again the
    message stays pending; otherwise the pending slot is cleared.
    """
    self.__next_step.poll()

    if not self.__better_backpressure or self.__message_carried_over is None:
        return

    try:
        self.__next_step.submit(self.__message_carried_over)
    except MessageRejected:
        # Still under backpressure; try again on the next poll.
        return

    self.__message_carried_over = None

def join(self, timeout: Optional[float] = None) -> None:
    """
    Wait for the next step to finish, flushing any carried-over message first.

    Without ``better_backpressure`` this simply joins the next step with the
    given timeout. With it enabled, a pending carried-over message is
    resubmitted (polling the next step before each attempt) until it is
    accepted or the deadline passes; the next step is then closed and joined
    with whatever time remains.

    :param timeout: maximum number of seconds to spend, or ``None`` for no
        limit.
    """
    if not self.__better_backpressure:
        self.__next_step.join(timeout=timeout)
        return

    deadline = time.time() + timeout if timeout is not None else None

    carried = self.__message_carried_over
    if carried is not None:
        while deadline is None or time.time() < deadline:
            # Give the next step a chance to make room before retrying.
            self.__next_step.poll()
            try:
                self.__next_step.submit(carried)
            except MessageRejected:
                continue
            self.__message_carried_over = None
            break

    remaining = max(deadline - time.time(), 0) if deadline is not None else None
    # In this mode close() does not close the next step, so it is closed
    # here, after the flush attempt.
    self.__next_step.close()
    self.__next_step.join(timeout=remaining)

def close(self) -> None:
    """
    Close the next step, unless ``better_backpressure`` is enabled.

    In backpressure mode the next step is left open here because ``join``
    closes it itself after attempting to flush the carried-over message.
    """
    if not self.__better_backpressure:
        self.__next_step.close()

def terminate(self) -> None:
    """Forward the terminate request to the next step."""
    self.__next_step.terminate()
75 changes: 71 additions & 4 deletions tests/processing/strategies/test_run_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,21 @@
from datetime import datetime
from unittest.mock import Mock, call

import pytest

from arroyo.processing.strategies.abstract import MessageRejected
from arroyo.processing.strategies.run_task import RunTask
from arroyo.types import BrokerValue, Message, Partition, Topic, Value
from tests.assertions import assert_changes


def test_run_task() -> None:
@pytest.mark.parametrize("better_backpressure", [False, True])
def test_run_task(better_backpressure: bool) -> None:
mock_func = Mock()
next_step = Mock()
now = datetime.now()

strategy = RunTask(mock_func, next_step)
strategy = RunTask(mock_func, next_step, better_backpressure=better_backpressure)
partition = Partition(Topic("topic"), 0)

strategy.submit(Message(Value(b"hello", {partition: 1}, now)))
Expand Down Expand Up @@ -42,14 +46,17 @@ def test_run_task() -> None:
assert next_step.submit.call_count == 2


def test_transform() -> None:
@pytest.mark.parametrize("better_backpressure", [False, True])
def test_transform(better_backpressure: bool) -> None:
next_step = Mock()
now = datetime.now()

def transform_function(value: Message[int]) -> int:
return value.payload * 2

transform_step = RunTask(transform_function, next_step)
transform_step = RunTask(
transform_function, next_step, better_backpressure=better_backpressure
)

original_message = Message(Value(1, {Partition(Topic("topic"), 0): 1}, now))

Expand Down Expand Up @@ -95,3 +102,63 @@ def transform_function(value: Message[int]) -> int:
)
)
)


def test_backpressure_function_called_once() -> None:
    """
    With better_backpressure=True, the function should only be called once
    per message even when next_step raises MessageRejected.
    """
    call_count = 0

    def counting_function(msg: Message[bytes]) -> bytes:
        nonlocal call_count
        call_count += 1
        return msg.payload

    next_step = Mock()
    # First call rejects, second accepts
    next_step.submit.side_effect = [MessageRejected(Mock()), None]

    strategy = RunTask(counting_function, next_step, better_backpressure=True)
    partition = Partition(Topic("topic"), 0)
    now = datetime.now()

    msg = Message(Value(b"hello", {partition: 1}, now))
    strategy.submit(msg)

    assert call_count == 1

    # While the rejected result is pending, new input must be refused
    # without running the function or touching next_step again.
    with pytest.raises(MessageRejected):
        strategy.submit(msg)

    assert call_count == 1
    assert next_step.submit.call_count == 1

    # poll() should retry and succeed
    strategy.poll()

    assert call_count == 1
    assert next_step.submit.call_count == 2

    # The retry must reuse the already-transformed message object.
    first_call, retry_call = next_step.submit.call_args_list
    assert first_call[0][0] is retry_call[0][0]


def test_backpressure_join_flushes_message() -> None:
    """
    With better_backpressure=True, join() should flush carried-over messages.
    """

    def identity(msg: Message[bytes]) -> bytes:
        return msg.payload

    next_step = Mock()
    # First submit rejects, then accepts during join
    next_step.submit.side_effect = [MessageRejected(Mock()), None]

    strategy = RunTask(identity, next_step, better_backpressure=True)
    partition = Partition(Topic("topic"), 0)
    now = datetime.now()

    msg = Message(Value(b"hello", {partition: 1}, now))
    strategy.submit(msg)

    assert next_step.submit.call_count == 1

    # join() should flush the carried-over message
    strategy.join(timeout=1.0)

    assert next_step.submit.call_count == 2
    # join() polls the next step before the retry and, in backpressure mode,
    # is the place where the next step gets closed.
    assert next_step.poll.call_count == 1
    assert next_step.close.call_count == 1
    assert next_step.join.call_count == 1
Loading