diff --git a/sdks/python/apache_beam/transforms/async_dofn.py b/sdks/python/apache_beam/transforms/async_dofn.py index 5e1c6d219f4b..28568bd893c5 100644 --- a/sdks/python/apache_beam/transforms/async_dofn.py +++ b/sdks/python/apache_beam/transforms/async_dofn.py @@ -17,15 +17,21 @@ from __future__ import absolute_import +import asyncio +import inspect import logging import random +import threading import uuid +from collections.abc import AsyncIterable +from collections.abc import Iterable from concurrent.futures import ThreadPoolExecutor from math import floor from threading import RLock from time import sleep from time import time from types import GeneratorType +from typing import Optional import apache_beam as beam from apache_beam import TimeDomain @@ -60,6 +66,9 @@ class AsyncWrapper(beam.DoFn): [coders.FastPrimitivesCoder(), coders.FastPrimitivesCoder()])) # The below items are one per dofn (not instance) so are maps of UUID to # value. + _event_loop: Optional[asyncio.AbstractEventLoop] = None + _event_loop_thread: Optional[threading.Thread] = None + _loop_started: Optional[threading.Event] = None _processing_elements = {} _items_in_buffer = {} _pool = {} @@ -78,6 +87,7 @@ def __init__( timeout=1, max_wait_time=0.5, id_fn=None, + use_asyncio=False, ): """Wraps the sync_fn to create an asynchronous version. @@ -104,6 +114,10 @@ def __init__( schedule an item. Used in testing to ensure timeouts are met. id_fn: A function that returns a hashable object from an element. This will be used to track items instead of the element's default hash. + use_asyncio: If true, use asyncio and coroutines to process items. If + false, use ThreadPoolExecutor. Use asyncio when the work being done + is not CPU intensive and heavily waits on network or IO which can + benefit from higher parallelism. """ self._sync_fn = sync_fn self._uuid = uuid.uuid4().hex @@ -112,6 +126,7 @@ def __init__( self._max_wait_time = max_wait_time self._timer_frequency = callback_frequency self._id_fn = id_fn or (lambda x: x) + self._use_asyncio = use_asyncio if max_items_to_buffer is None: self._max_items_to_buffer = max(parallelism * 2, 10) else: @@ -126,11 +141,33 @@ def __init__( def initialize_pool(parallelism): return lambda: ThreadPoolExecutor(max_workers=parallelism) + @staticmethod + def _run_event_loop(): + """Sets up and runs the asyncio event loop in a background thread.""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + AsyncWrapper._event_loop = loop + AsyncWrapper._loop_started.set() + loop.run_forever() + loop.close() + @staticmethod def reset_state(): - for pool in AsyncWrapper._pool.values(): - pool.acquire(AsyncWrapper.initialize_pool(1)).shutdown( - wait=True, cancel_futures=True) + with AsyncWrapper._lock: + if AsyncWrapper._event_loop: + AsyncWrapper._event_loop.call_soon_threadsafe( + AsyncWrapper._event_loop.stop) + if AsyncWrapper._event_loop_thread: + AsyncWrapper._event_loop_thread.join() + + AsyncWrapper._event_loop = None + AsyncWrapper._event_loop_thread = None + if AsyncWrapper._loop_started is not None: + AsyncWrapper._loop_started.clear() + + for pool in AsyncWrapper._pool.values(): + pool.acquire(AsyncWrapper.initialize_pool(1)).shutdown( + wait=True, cancel_futures=True) with AsyncWrapper._lock: AsyncWrapper._pool = {} AsyncWrapper._processing_elements = {} @@ -140,6 +177,13 @@ def setup(self): """Forwards to the wrapped dofn's setup method.""" self._sync_fn.setup() with AsyncWrapper._lock: + if self._use_asyncio and AsyncWrapper._event_loop_thread is None: + AsyncWrapper._loop_started = threading.Event() + AsyncWrapper._event_loop_thread = threading.Thread( + target=AsyncWrapper._run_event_loop, daemon=True) + AsyncWrapper._event_loop_thread.start() + AsyncWrapper._loop_started.wait() + if not self._uuid in AsyncWrapper._pool: AsyncWrapper._pool[self._uuid] = Shared() AsyncWrapper._processing_elements[self._uuid] = {} @@ -187,9 +231,41 @@ def sync_fn_process(self, element, *args, **kwargs): to_return.append(x) for x in bundle_result: to_return.append(x) - return to_return + async def async_fn_process(self, element, *args, **kwargs): + """Makes the call to the wrapped dofn's start_bundle, process + and finish_bundle methods for asynchronous DoFns. + + Args: + element: The element to process. + *args: Any additional arguments to pass to the wrapped dofn's process + method. + **kwargs: Any additional keyword arguments to pass to the wrapped dofn's + process method. + + Returns: + A list of elements produced by the input element. + """ + async def _collect(result): + if result is None: + return [] + if inspect.isawaitable(result): + result = await result + if isinstance(result, AsyncIterable): + return [item async for item in result] + if isinstance(result, + (GeneratorType, Iterable)) and not isinstance(result, + (str, bytes)): + return list(result) + return [result] + + self._sync_fn.start_bundle() + process_result = await _collect( + self._sync_fn.process(element, *args, **kwargs)) + bundle_result = await _collect(self._sync_fn.finish_bundle()) + return process_result + bundle_result + def decrement_items_in_buffer(self, future): with AsyncWrapper._lock: AsyncWrapper._items_in_buffer[self._uuid] -= 1 @@ -214,10 +290,16 @@ def schedule_if_room(self, element, ignore_buffer=False, *args, **kwargs): logging.info('item %s already in processing elements', element) return True if self.accepting_items() or ignore_buffer: - result = AsyncWrapper._pool[self._uuid].acquire( - AsyncWrapper.initialize_pool(self._parallelism)).submit( - lambda: self.sync_fn_process(element, *args, **kwargs), - ) + if self._use_asyncio: + result = asyncio.run_coroutine_threadsafe( + self.async_fn_process(element, *args, **kwargs), + AsyncWrapper._event_loop, + ) + else: + result = AsyncWrapper._pool[self._uuid].acquire( + AsyncWrapper.initialize_pool(self._parallelism)).submit( + lambda: self.sync_fn_process(element, *args, **kwargs), + ) result.add_done_callback(self.decrement_items_in_buffer) AsyncWrapper._processing_elements[self._uuid][element_id] = ( element, result) diff --git a/sdks/python/apache_beam/transforms/async_dofn_test.py b/sdks/python/apache_beam/transforms/async_dofn_test.py index fe75de05ccd5..81c7b8e163ff 100644 --- a/sdks/python/apache_beam/transforms/async_dofn_test.py +++ b/sdks/python/apache_beam/transforms/async_dofn_test.py @@ -22,6 +22,8 @@ from concurrent.futures import ThreadPoolExecutor from threading import Lock +from parameterized import parameterized_class + import apache_beam as beam import apache_beam.transforms.async_dofn as async_lib @@ -62,7 +64,7 @@ class FakeBagState: def __init__(self, items): self.items = items # Normally SE would have a lock on the BT row protecting this from multiple - # updates. Here without SE we must lock ourselvs. + # updates. Here without SE we must lock ourselves. self.lock = Lock() def add(self, item): @@ -86,6 +88,14 @@ def set(self, time): self.time = time +@parameterized_class([ + { + "use_asyncio": True + }, + { + "use_asyncio": False + }, +]) class AsyncTest(unittest.TestCase): def setUp(self): super().setUp() @@ -132,7 +142,8 @@ def __eq__(self, other): return self.element_id == other.element_id dofn = BasicDofn() - async_dofn = async_lib.AsyncWrapper(dofn, id_fn=lambda x: x.element_id) + async_dofn = async_lib.AsyncWrapper( + dofn, id_fn=lambda x: x.element_id, use_asyncio=self.use_asyncio) async_dofn.setup() fake_bag_state = FakeBagState([]) fake_timer = FakeTimer(0) @@ -156,7 +167,7 @@ def __eq__(self, other): def test_basic(self): # Setup an async dofn and send a message in to process. dofn = BasicDofn() - async_dofn = async_lib.AsyncWrapper(dofn) + async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio) async_dofn.setup() fake_bag_state = FakeBagState([]) fake_timer = FakeTimer(0) @@ -181,9 +192,9 @@ def test_basic(self): self.assertEqual(fake_bag_state.items, []) def test_multi_key(self): - # Send in two messages with different keys.. + # Send in two messages with different keys. dofn = BasicDofn() - async_dofn = async_lib.AsyncWrapper(dofn) + async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio) async_dofn.setup() fake_bag_state_key1 = FakeBagState([]) fake_bag_state_key2 = FakeBagState([]) @@ -211,7 +222,7 @@ def test_multi_key(self): def test_long_item(self): # Test that everything still works with a long running time for the dofn. dofn = BasicDofn(sleep_time=5) - async_dofn = async_lib.AsyncWrapper(dofn) + async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio) async_dofn.setup() fake_bag_state = FakeBagState([]) fake_timer = FakeTimer(0) @@ -231,10 +242,10 @@ def test_long_item(self): self.assertEqual(fake_bag_state.items, []) def test_lost_item(self): - # Setup an element in the bag stat thats not in processing state. + # Setup an element in the bag state that's not in processing state. # The async dofn should reschedule this element. dofn = BasicDofn() - async_dofn = async_lib.AsyncWrapper(dofn) + async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio) async_dofn.setup() fake_timer = FakeTimer(0) msg = ('key1', 1) @@ -250,9 +261,9 @@ def test_lost_item(self): def test_cancelled_item(self): # Test that an item gets removed for processing and does not get output when # it is not present in the bag state. Either this item moved or a commit - # failed making the local state and bag stat inconsistent. + # failed making the local state and bag state inconsistent. dofn = BasicDofn() - async_dofn = async_lib.AsyncWrapper(dofn) + async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio) async_dofn.setup() msg = ('key1', 1) msg2 = ('key1', 2) @@ -272,7 +283,7 @@ def test_multi_element_dofn(self): # Test that async works when a dofn produces multiple elements in process # and finish_bundle. dofn = MultiElementDoFn() - async_dofn = async_lib.AsyncWrapper(dofn) + async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio) async_dofn.setup() fake_bag_state = FakeBagState([]) fake_timer = FakeTimer(0) @@ -289,7 +300,7 @@ def test_duplicates(self): # Test that async will produce a single output when a given input is sent # multiple times. dofn = BasicDofn(5) - async_dofn = async_lib.AsyncWrapper(dofn) + async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio) async_dofn.setup() fake_bag_state = FakeBagState([]) fake_timer = FakeTimer(0) @@ -310,7 +321,7 @@ def test_slow_duplicates(self): # Test that async will produce a single output when a given input is sent # multiple times. dofn = BasicDofn(5) - async_dofn = async_lib.AsyncWrapper(dofn) + async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio) async_dofn.setup() fake_bag_state = FakeBagState([]) fake_timer = FakeTimer(0) @@ -335,7 +346,7 @@ def test_slow_duplicates(self): def test_buffer_count(self): # Test that the buffer count is correctly incremented when adding items. dofn = BasicDofn(5) - async_dofn = async_lib.AsyncWrapper(dofn) + async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio) async_dofn.setup() msg = ('key1', 1) fake_timer = FakeTimer(0) @@ -353,7 +364,10 @@ def test_buffer_stops_accepting_items(self): # Test that the buffer stops accepting items when it is full. dofn = BasicDofn(5) async_dofn = async_lib.AsyncWrapper( - dofn, parallelism=1, max_items_to_buffer=5) + dofn, + parallelism=1, + max_items_to_buffer=5, + use_asyncio=self.use_asyncio) async_dofn.setup() fake_timer = FakeTimer(0) fake_bag_state = FakeBagState([]) @@ -391,7 +405,7 @@ def add_item(i): def test_buffer_with_cancellation(self): dofn = BasicDofn(3) - async_dofn = async_lib.AsyncWrapper(dofn) + async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio) async_dofn.setup() msg = ('key1', 1) msg2 = ('key1', 2) @@ -423,7 +437,8 @@ def test_load_correctness(self): # Test AsyncDofn over heavy load. dofn = BasicDofn(1) max_sleep = 10 - async_dofn = async_lib.AsyncWrapper(dofn, max_wait_time=max_sleep) + async_dofn = async_lib.AsyncWrapper( + dofn, max_wait_time=max_sleep, use_asyncio=self.use_asyncio) async_dofn.setup() bag_states = {} timers = {}