diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py index cbd28f8b0a3f..a5589ac33a1b 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane.py +++ b/sdks/python/apache_beam/runners/worker/data_plane.py @@ -49,6 +49,7 @@ from apache_beam.portability.api import beam_fn_api_pb2_grpc from apache_beam.runners.worker.channel_factory import GRPCChannelFactory from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor +from apache_beam.utils.byte_limited_queue import ByteLimitedQueue if TYPE_CHECKING: import apache_beam.coders.slow_stream @@ -455,11 +456,14 @@ class _GrpcDataChannel(DataChannel): def __init__(self, data_buffer_time_limit_ms=0): # type: (int) -> None + self._data_buffer_time_limit_ms = data_buffer_time_limit_ms - self._to_send = queue.Queue() # type: queue.Queue[DataOrTimers] + self._to_send = ByteLimitedQueue( + maxsize=10000, + maxbytes=100 << 20) # type: ByteLimitedQueue[DataOrTimers] self._received = collections.defaultdict( - lambda: queue.Queue(maxsize=5) - ) # type: DefaultDict[str, queue.Queue[DataOrTimers]] + lambda: ByteLimitedQueue(maxsize=5, maxbytes=100 << 20) + ) # type: DefaultDict[str, ByteLimitedQueue[DataOrTimers]] # Keep a cache of completed instructions. Data for completed instructions # must be discarded. See input_elements() and _clean_receiving_queue(). @@ -474,7 +478,7 @@ def __init__(self, data_buffer_time_limit_ms=0): def close(self): # type: () -> None - self._to_send.put(self._WRITES_FINISHED) + self._to_send.put(self._WRITES_FINISHED, 0) self._closed = True def wait(self, timeout=None): @@ -482,7 +486,7 @@ def wait(self, timeout=None): self._reads_finished.wait(timeout) def _receiving_queue(self, instruction_id): - # type: (str) -> Optional[queue.Queue[DataOrTimers]] + # type: (str) -> Optional[ByteLimitedQueue[DataOrTimers]] """ Gets or creates queue for a instruction_id. Or, returns None if the @@ -585,21 +589,19 @@ def output_stream(self, instruction_id, transform_id): def add_to_send_queue(data): # type: (bytes) -> None if data: - self._to_send.put( - beam_fn_api_pb2.Elements.Data( - instruction_id=instruction_id, - transform_id=transform_id, - data=data)) + elem = beam_fn_api_pb2.Elements.Data( + instruction_id=instruction_id, transform_id=transform_id, data=data) + self._to_send.put(elem, self._get_element_size_bytes(elem)) def close_callback(data): # type: (bytes) -> None add_to_send_queue(data) # End of stream marker. - self._to_send.put( - beam_fn_api_pb2.Elements.Data( - instruction_id=instruction_id, - transform_id=transform_id, - is_last=True)) + elem = beam_fn_api_pb2.Elements.Data( + instruction_id=instruction_id, + transform_id=transform_id, + is_last=True) + self._to_send.put(elem, self._get_element_size_bytes(elem)) return ClosableOutputStream.create( close_callback, add_to_send_queue, self._data_buffer_time_limit_ms) @@ -614,23 +616,23 @@ def output_timer_stream( def add_to_send_queue(timer): # type: (bytes) -> None if timer: - self._to_send.put( - beam_fn_api_pb2.Elements.Timers( - instruction_id=instruction_id, - transform_id=transform_id, - timer_family_id=timer_family_id, - timers=timer, - is_last=False)) + elem = beam_fn_api_pb2.Elements.Timers( + instruction_id=instruction_id, + transform_id=transform_id, + timer_family_id=timer_family_id, + timers=timer, + is_last=False) + self._to_send.put(elem, self._get_element_size_bytes(elem)) def close_callback(timer): # type: (bytes) -> None add_to_send_queue(timer) - self._to_send.put( - beam_fn_api_pb2.Elements.Timers( - instruction_id=instruction_id, - transform_id=transform_id, - timer_family_id=timer_family_id, - is_last=True)) + elem = beam_fn_api_pb2.Elements.Timers( + instruction_id=instruction_id, + transform_id=transform_id, + timer_family_id=timer_family_id, + is_last=True) + self._to_send.put(elem, self._get_element_size_bytes(elem)) return ClosableOutputStream.create( close_callback, add_to_send_queue, self._data_buffer_time_limit_ms) @@ -665,6 +667,15 @@ def _write_outputs(self): raise ValueError('Unexpected output element type %s' % type(stream)) yield beam_fn_api_pb2.Elements(data=data_stream, timers=timer_stream) + def _get_element_size_bytes(self, element): + # type: (Union[beam_fn_api_pb2.Elements.Data, beam_fn_api_pb2.Elements.Timers]) -> int + if isinstance(element, beam_fn_api_pb2.Elements.Data): + return len(element.data) + elif isinstance(element, beam_fn_api_pb2.Elements.Timers): + return len(element.timers) + else: + return 0 + def _read_inputs(self, elements_iterator): # type: (Iterable[beam_fn_api_pb2.Elements]) -> None @@ -691,7 +702,8 @@ def _put_queue(instruction_id, element): next_discard_log_time = current_time + 10 return try: - input_queue.put(element, timeout=1) + input_queue.put( + element, self._get_element_size_bytes(element), timeout=1) return except queue.Full: current_time = time.time() diff --git a/sdks/python/apache_beam/utils/byte_limited_queue.pxd b/sdks/python/apache_beam/utils/byte_limited_queue.pxd new file mode 100644 index 000000000000..396185e8e101 --- /dev/null +++ b/sdks/python/apache_beam/utils/byte_limited_queue.pxd @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# cython: overflowcheck=True + +cdef class ByteLimitedQueue(object): + cdef readonly Py_ssize_t max_elements + cdef readonly Py_ssize_t max_bytes + cdef readonly Py_ssize_t _byte_size + cdef readonly object _mutex + cdef readonly object _not_empty + cdef readonly object _waiting_writers + cdef readonly list _condition_pool + cdef readonly object _queue + cdef readonly Py_ssize_t _blocked_bytes + + cpdef bint _can_fit(self, Py_ssize_t item_bytes) except -1 diff --git a/sdks/python/apache_beam/utils/byte_limited_queue.py b/sdks/python/apache_beam/utils/byte_limited_queue.py new file mode 100644 index 000000000000..a6ff669800c2 --- /dev/null +++ b/sdks/python/apache_beam/utils/byte_limited_queue.py @@ -0,0 +1,204 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""A thread-safe queue that limits capacity by total byte size.""" + +import collections +import queue +import threading +import time +import types + + +class ByteLimitedQueue(object): + """A fair queue that limits by both element count and total byte size. + + A single element is allowed to exceed the maxbytes to avoid deadlock. + """ + __class_getitem__ = classmethod(types.GenericAlias) + + def __init__( + self, + maxsize=0, # type: int + maxbytes=0, # type: int + ): + # type: (...) -> None + + """Initializes a ByteLimitedQueue. + + Args: + maxsize: The maximum number of items allowed in the queue. If 0 or + negative, there is no limit on the number of elements. + maxbytes: The maximum accumulated bytes allowed in the queue. If 0 or + negative, there is no limit on the total bytes of the elements. + """ + self.max_elements = maxsize + self.max_bytes = maxbytes + + self._byte_size = 0 + self._blocked_bytes = 0 + self._mutex = threading.Lock() + self._not_empty = threading.Condition(self._mutex) + + self._waiting_writers = collections.deque() + self._condition_pool = [] + self._queue = collections.deque() + + def put(self, item, item_bytes, *, block=True, timeout=None): + """Put an item into the queue. + + If the queue is full, block until a free slot is available, unless `block` + is false or a timeout occurs. + + Args: + item: The item to put into the queue. + item_bytes: The size of the item. + block: If True, block until space is available. If False, raise queue.Full + immediately if the queue is full. + timeout: If block is True, wait for at most `timeout` seconds. If None, + block indefinitely. + + Raises: + ValueError: If timeout or item_bytes is negative. + queue.Full: If the queue is full and block is False or the timeout occurs. + """ + if timeout is not None and timeout < 0: + raise ValueError("'timeout' must be a non-negative number") + if item_bytes < 0: + raise ValueError("'item_bytes' must be a non-negative number") + + with self._mutex: + if not self._waiting_writers and self._can_fit(item_bytes): + self._queue.append((item, item_bytes)) + self._byte_size += item_bytes + self._not_empty.notify() + return + + if not block: + raise queue.Full + + # Reuse or create a condition + my_cond = ( + self._condition_pool.pop() + if self._condition_pool else threading.Condition(self._mutex)) + + endtime = time.monotonic() + timeout if timeout is not None else None + + try: + self._blocked_bytes += item_bytes + self._waiting_writers.append(my_cond) + while True: + if timeout is None: + my_cond.wait() + else: + remaining = endtime - time.monotonic() + if remaining <= 0.0: + raise queue.Full + my_cond.wait(remaining) + + if self._waiting_writers[0] is my_cond and self._can_fit(item_bytes): + break + + self._queue.append((item, item_bytes)) + self._byte_size += item_bytes + self._not_empty.notify() + finally: + self._blocked_bytes -= item_bytes + if self._waiting_writers: + was_first = (self._waiting_writers[0] is my_cond) + if was_first: + self._waiting_writers.popleft() + else: + self._waiting_writers.remove(my_cond) + self._condition_pool.append(my_cond) + if was_first and self._waiting_writers: + self._waiting_writers[0].notify() + + def get(self, *, block=True, timeout=None): + """Remove and return an item from the queue. + + If the queue is empty, block until an item is available, unless `block` + is false or a timeout occurs. + + Args: + block: If True, block until an item is available. If False, raise + queue.Empty immediately if the queue is empty. + timeout: If block is True, wait for at most `timeout` seconds. If None, + block indefinitely. + + Returns: + The item removed from the queue. + + Raises: + ValueError: If timeout is negative. + queue.Empty: If the queue is empty and block is False or the timeout + occurs. + """ + if timeout is not None and timeout < 0: + raise ValueError("'timeout' must be a non-negative number") + + with self._mutex: + if not block: + if not self._queue: + raise queue.Empty + elif timeout is None: + while not self._queue: + self._not_empty.wait() + else: + endtime = time.monotonic() + timeout + while not self._queue: + remaining = endtime - time.monotonic() + if remaining <= 0.0: + raise queue.Empty + self._not_empty.wait(remaining) + + item, item_bytes = self._queue.popleft() + self._byte_size -= item_bytes + + if self._waiting_writers: + self._waiting_writers[0].notify() + + return item + + def get_nowait(self): + """Remove and return an item from the queue without blocking.""" + return self.get(block=False) + + def byte_size(self): + """Return the total byte size of elements in the queue.""" + with self._mutex: + return self._byte_size + + def blocked_byte_size(self): + """Return the total byte size of elements in the queue that are blocked.""" + with self._mutex: + return self._blocked_bytes + + def qsize(self): + """Return the total number of elements in the queue.""" + with self._mutex: + return len(self._queue) + + def _can_fit(self, item_bytes): + # Always let in a single element, regardless of size. + if not self._queue: + return True + if self.max_elements > 0 and len(self._queue) >= self.max_elements: + return False + if self.max_bytes > 0 and self._byte_size + item_bytes > self.max_bytes: + return False + return True diff --git a/sdks/python/apache_beam/utils/byte_limited_queue_test.py b/sdks/python/apache_beam/utils/byte_limited_queue_test.py new file mode 100644 index 000000000000..27ccb2421844 --- /dev/null +++ b/sdks/python/apache_beam/utils/byte_limited_queue_test.py @@ -0,0 +1,270 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Unit tests for byte-limited queue.""" + +import queue +import threading +import time +import unittest + +from apache_beam.utils.byte_limited_queue import ByteLimitedQueue + + +class ByteLimitedQueueTest(unittest.TestCase): + def test_unbounded(self): + bq = ByteLimitedQueue() + for i in range(201): + bq.put(str(i), i) + self.assertEqual(bq.byte_size(), sum(range(201))) + self.assertEqual(bq.qsize(), 201) + + def test_put_and_get(self): + bq = ByteLimitedQueue(maxbytes=200) + bq.put('50', 50) + bq.put('140', 140) + self.assertEqual(bq.byte_size(), 190) + self.assertEqual(bq.qsize(), 2) + # Putting another would exceed 200. + with self.assertRaises(queue.Full): + bq.put('20', 20, block=False) + bq.put('10', 10, block=False) + self.assertEqual(bq.byte_size(), 200) + self.assertEqual(bq.qsize(), 3) + + self.assertEqual(bq.get(), '50') + self.assertEqual(bq.byte_size(), 150) + self.assertEqual(bq.qsize(), 2) + bq.put('20', 20, block=False) + + def test_dual_limit(self): + # Queue limits: at most 3 items, OR at most 100 item bytes. + bq = ByteLimitedQueue(maxsize=3, maxbytes=100) + bq.put('30', 30) + bq.put('40', 40) + bq.put('20', 20) + self.assertEqual(bq.byte_size(), 90) + self.assertEqual(bq.qsize(), 3) + # Full on element count (size=3). + with self.assertRaises(queue.Full): + bq.put('10', 10, block=False) + self.assertEqual(bq.get(), '30') + self.assertEqual(bq.get(), '40') + bq.put('10', 10) + # Full on byte count + with self.assertRaises(queue.Full): + bq.put('90', 90, block=False) + self.assertEqual(bq.get(), '20') + bq.put('90', 90, block=False) + + def test_multithreading(self): + bq = ByteLimitedQueue(maxsize=0, maxbytes=100) + received = [] + + def producer(): + for i in range(101): + bq.put(str(i), i) + + poison_pill = 'POISON' + + def consumer(): + while True: + item = bq.get() + if item == poison_pill: + break + received.append(int(item)) + + t1 = threading.Thread(target=producer) + t2 = threading.Thread(target=producer) + t3 = threading.Thread(target=consumer) + + t1.start() + t2.start() + t3.start() + + t1.join() + t2.join() + bq.put(poison_pill, 0) + + t3.join() + + self.assertEqual(len(received), 202) + self.assertEqual(sum(received), 2 * sum(range(101))) + + def test_put_timeout(self): + bq = ByteLimitedQueue(maxsize=0, maxbytes=10) + bq.put('10', 10) + + # The queue is completely full. A timeout put should raise queue.Full. + with self.assertRaises(queue.Full): + bq.put('5', 5, timeout=0.01) + + def delayed_consumer(): + time.sleep(0.05) + bq.get() + + # Start a thread that will free up space after 50ms. + t = threading.Thread(target=delayed_consumer) + t.start() + + # The put should succeed once the consumer runs, use a high timeout to + # flakiness. + bq.put('item', 5, timeout=60) + t.join() + + def test_get_timeout(self): + bq = ByteLimitedQueue(maxsize=0, maxbytes=100) + with self.assertRaises(queue.Empty): + bq.get(block=False) + with self.assertRaises(queue.Empty): + bq.get(timeout=0.0) + with self.assertRaises(queue.Empty): + bq.get(timeout=.01) + + bq.put('1', 1) + self.assertEqual('1', bq.get(timeout=0)) + + bq.put('2', 2) + self.assertEqual('2', bq.get(timeout=0.1)) + + def delayed_producer(): + time.sleep(0.05) + bq.put('3', 3) + + # Start a thread that will produce soon + t = threading.Thread(target=delayed_producer) + t.start() + + # The get should succeed once the produer runs, use a high timeout to + # flakiness. + self.assertEqual('3', bq.get(timeout=60)) + t.join() + + def test_negative_timeout(self): + bq = ByteLimitedQueue() + # Putting an item with a negative timeout should raise ValueError. + with self.assertRaises(ValueError): + bq.put('5', 5, timeout=-1) + with self.assertRaises(ValueError): + bq.get(timeout=-1) + + def test_single_element_override(self): + bq = ByteLimitedQueue(maxbytes=10) + # An item of size 50 exceeds maxbytes 10, but should be admitted + # immediately without blocking since the queue is currently empty! + bq.put('50', 50, block=False) + self.assertEqual(bq.qsize(), 1) + self.assertEqual(bq.byte_size(), 50) + + def test_fairness(self): + bq = ByteLimitedQueue(maxbytes=10) + # Put an initial item so that the queue is not empty, + # causing the subsequent large item to block. + bq.put('first', 2) + self.assertEqual(bq.blocked_byte_size(), 0) + + def producer(item, size): + bq.put(item, size) + + # Add an item in a background thread that should block due to exceeding + # the limit. + t1 = threading.Thread(target=producer, args=('too_large', 9)) + t1.start() + + # Wait until the background write is queued. + while bq.blocked_byte_size() < 1: + time.sleep(0.005) + self.assertEqual(bq.blocked_byte_size(), 9) + + # Add smaller items afterwards. + t2 = threading.Thread(target=producer, args=('small1', 1)) + t2.start() + + while bq.blocked_byte_size() < 10: + time.sleep(0.005) + self.assertEqual(bq.blocked_byte_size(), 10) + + t3 = threading.Thread(target=producer, args=('small2', 1)) + t3.start() + + while bq.blocked_byte_size() < 11: + time.sleep(0.005) + self.assertEqual(bq.blocked_byte_size(), 11) + + # Verify all items are received in order. + self.assertEqual(bq.get(), 'first') + t1.join() + t2.join() + self.assertEqual(bq.get(), 'too_large') + t3.join() + self.assertEqual(bq.get(), 'small1') + self.assertEqual(bq.get(), 'small2') + + def test_blocked_waiter_timeout_multiple(self): + bq = ByteLimitedQueue(maxbytes=10) + bq.put('initial', 5) + + status = [] + lock = threading.Lock() + + def producer(name, size, timeout_val): + try: + bq.put(name, size, timeout=timeout_val) + with lock: + status.append((name, 'success')) + except queue.Full: + with lock: + status.append((name, 'timeout')) + + threads = [] + threads.append(threading.Thread(target=producer, args=('t1', 8, 0.2))) + threads.append(threading.Thread(target=producer, args=('t2', 8, 60.0))) + threads.append(threading.Thread(target=producer, args=('t3', 3, 0.1))) + threads.append(threading.Thread(target=producer, args=('t4', 3, 60.0))) + threads.append(threading.Thread(target=producer, args=('t5', 3, 0.1))) + for t in threads: + t.start() + + # Wait for the short-timeout threads. + threads[4].join() + threads[2].join() + threads[0].join() + + # Now waiting writers should just be t1 and t3 + self.assertEqual(bq.blocked_byte_size(), 11) + + self.assertEqual(bq.get(), 'initial') + threads[1].join() + self.assertGreater(bq.blocked_byte_size(), 0) + + elem = bq.get() + self.assertTrue(elem == 't2' or elem == 't4') + threads[3].join() + self.assertEqual(bq.blocked_byte_size(), 0) + elem = bq.get() + self.assertTrue(elem == 't2' or elem == 't4') + + with lock: + self.assertIn(('t1', 'timeout'), status) + self.assertIn(('t2', 'success'), status) + self.assertIn(('t3', 'timeout'), status) + self.assertIn(('t4', 'success'), status) + self.assertIn(('t5', 'timeout'), status) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/setup.py b/sdks/python/setup.py index b3fb98d8b0ef..45781a44c4b1 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -368,6 +368,7 @@ def get_portability_package_data(): 'apache_beam/runners/worker/operations.py', 'apache_beam/transforms/cy_combiners.py', 'apache_beam/transforms/stats.py', + 'apache_beam/utils/byte_limited_queue.py', 'apache_beam/utils/counters.py', 'apache_beam/utils/windowed_value.py', ])