From caeaccfcea4e4dc83bdecf8e1c20898a14ea5ab0 Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Thu, 7 May 2026 18:20:38 +0200 Subject: [PATCH 1/8] [Python] Bound the memory used for fnapi outbound data messages. Previously an unbounded queue was used for pending data outputs to be sent over the fnapi to the runner. If outputs were being generated faster than the runner was consuming them, this would lead to memory growth and possible OOMs. This PR introduces a byte-limited queue data structure that is used instead to limit the # of bytes in the queue. This was preferred to just using a queue with max number of elements because the size of elements can vary greatly. For batch pipelines they are likely large while for stremaing pipelines there may be more small outputs. --- .../apache_beam/runners/worker/data_plane.py | 15 +- .../apache_beam/utils/byte_limited_queue.py | 95 ++++++++++ .../utils/byte_limited_queue_test.py | 168 ++++++++++++++++++ 3 files changed, 276 insertions(+), 2 deletions(-) create mode 100644 sdks/python/apache_beam/utils/byte_limited_queue.py create mode 100644 sdks/python/apache_beam/utils/byte_limited_queue_test.py diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py index cbd28f8b0a3f..e46e66bd11d3 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane.py +++ b/sdks/python/apache_beam/runners/worker/data_plane.py @@ -49,6 +49,7 @@ from apache_beam.portability.api import beam_fn_api_pb2_grpc from apache_beam.runners.worker.channel_factory import GRPCChannelFactory from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor +from apache_beam.utils.byte_limited_queue import ByteLimitedQueue if TYPE_CHECKING: import apache_beam.coders.slow_stream @@ -455,10 +456,20 @@ class _GrpcDataChannel(DataChannel): def __init__(self, data_buffer_time_limit_ms=0): # type: (int) -> None + def _element_weight(element): + if isinstance(element, beam_fn_api_pb2.Elements.Data): + return len(element.data) + elif isinstance(element, beam_fn_api_pb2.Elements.Timers): + return len(element.timers) + return 0 + self._data_buffer_time_limit_ms = data_buffer_time_limit_ms - self._to_send = queue.Queue() # type: queue.Queue[DataOrTimers] + self._to_send = ByteLimitedQueue( + maxsize=10000, maxweight=100 << 20, + weighing_fn=_element_weight) # type: queue.Queue[DataOrTimers] self._received = collections.defaultdict( - lambda: queue.Queue(maxsize=5) + lambda: ByteLimitedQueue( + maxsize=5, maxweight=100 << 20, weighing_fn=_element_weight) ) # type: DefaultDict[str, queue.Queue[DataOrTimers]] # Keep a cache of completed instructions. Data for completed instructions diff --git a/sdks/python/apache_beam/utils/byte_limited_queue.py b/sdks/python/apache_beam/utils/byte_limited_queue.py new file mode 100644 index 000000000000..7e14f01c6ae0 --- /dev/null +++ b/sdks/python/apache_beam/utils/byte_limited_queue.py @@ -0,0 +1,95 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""A thread-safe queue that limits capacity by total byte size.""" + +import queue +import time +from typing import Any +from typing import Callable + + +class ByteLimitedQueue(queue.Queue): + """A queue.Queue that limits by both element count and total weight. + + A single element is allowed to exceed the maxweight to avoid deadlock. + """ + def __init__( + self, + weighing_fn, # type: Callable[[Any], int] + maxsize=0, # type: int + maxweight=0, # type: int + ): + # type: (...) -> None + + """Initializes a ByteLimitedQueue. + + Args: + weighing_fn: A Callable that accepts an item and returns its integer + weight. + maxsize: The maximum number of items allowed in the queue. If 0 or + negative, there is no limit on the number of elements. + maxweight: The maximum accumulated weight allowed in the queue. + """ + super().__init__(maxsize=0) + self.max_elements = maxsize + self.max_weight = maxweight + self.weighing_fn = weighing_fn + self._byte_size = 0 + + def _is_full(self, item_size): + if self._qsize() == 0: + return False + if self.max_elements > 0 and self._qsize() >= self.max_elements: + return True + if self.max_weight > 0 and self._byte_size + item_size > self.max_weight: + return True + return False + + def put(self, item, block=True, timeout=None): + item_size = max(1, self.weighing_fn(item)) + with self.not_full: + if not block: + if self._is_full(item_size): + raise queue.Full + elif timeout is None: + while self._is_full(item_size): + self.not_full.wait() + elif timeout < 0: + raise ValueError("'timeout' must be a non-negative number") + else: + endtime = time.time() + timeout + while self._is_full(item_size): + remaining = endtime - time.time() + if remaining <= 0.0: + raise queue.Full + self.not_full.wait(remaining) + + self._put((item, item_size)) + self._byte_size += item_size + self.unfinished_tasks += 1 + self.not_empty.notify() + + def _get(self): + item, item_weight = super()._get() + self._byte_size -= item_weight + return item + + def byte_size(self): + """Return the total byte weight of elements in the queue.""" + with self.mutex: + return self._byte_size diff --git a/sdks/python/apache_beam/utils/byte_limited_queue_test.py b/sdks/python/apache_beam/utils/byte_limited_queue_test.py new file mode 100644 index 000000000000..e6349f3af00b --- /dev/null +++ b/sdks/python/apache_beam/utils/byte_limited_queue_test.py @@ -0,0 +1,168 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Unit tests for byte-limited queue.""" + +import queue +import sys +import threading +import time +import unittest + +from apache_beam.utils.byte_limited_queue import ByteLimitedQueue + + +class FakeItem(object): + def __init__(self, size): + self._size = size + + def weight(self): + return self._size + + +class ByteLimitedQueueTest(unittest.TestCase): + def test_unbounded(self): + bq = ByteLimitedQueue(lambda x: x.weight()) + for i in range(200): + bq.put(FakeItem(i)) + # Add 1 since weight of zero is set to 1 + self.assertEqual(bq.byte_size(), sum(range(200)) + 1) + self.assertEqual(bq.qsize(), 200) + + def test_put_and_get(self): + bq = ByteLimitedQueue(lambda x: x.weight(), maxweight=200) + bq.put(FakeItem(50)) + bq.put(FakeItem(140)) + self.assertEqual(bq.byte_size(), 190) + self.assertEqual(bq.qsize(), 2) + # Putting another would exceed 200. + with self.assertRaises(queue.Full): + bq.put(FakeItem(20), block=False) + bq.put(FakeItem(10), block=False) + self.assertEqual(bq.byte_size(), 200) + self.assertEqual(bq.qsize(), 3) + + self.assertEqual(bq.get().weight(), 50) + self.assertEqual(bq.byte_size(), 150) + self.assertEqual(bq.qsize(), 2) + bq.put(FakeItem(20), block=False) + + def test_dual_limit(self): + # Queue limits: at most 2 items, OR at most 100 weight. + bq = ByteLimitedQueue(lambda x: x.weight(), maxsize=3, maxweight=100) + bq.put(FakeItem(30)) + bq.put(FakeItem(40)) + bq.put(FakeItem(20)) + self.assertEqual(bq.byte_size(), 90) + self.assertEqual(bq.qsize(), 3) + # Full on element count (size=2). + with self.assertRaises(queue.Full): + bq.put(FakeItem(10), block=False) + self.assertEqual(bq.get().weight(), 30) + self.assertEqual(bq.get().weight(), 40) + bq.put(FakeItem(10)) + # Full on byte count + with self.assertRaises(queue.Full): + bq.put(FakeItem(90), block=False) + self.assertEqual(bq.get().weight(), 20) + bq.put(FakeItem(90), block=False) + + @unittest.skipIf(sys.version_info < (3, 13), 'Queue.ShutDown added in 3.13.') + def test_multithreading(self): + bq = ByteLimitedQueue(lambda x: x.weight(), maxsize=0, maxweight=100) + received = [] + + def producer(): + for i in range(101): + bq.put(FakeItem(i)) + + def consumer(): + while True: + try: + received.append(bq.get().weight()) + except queue.ShutDown: + break + + t1 = threading.Thread(target=producer) + t2 = threading.Thread(target=producer) + t3 = threading.Thread(target=consumer) + + t1.start() + t2.start() + t3.start() + + t1.join() + t2.join() + bq.shutdown() + + t3.join() + + self.assertEqual(len(received), 202) + self.assertEqual(sum(received), 2 * sum(range(101))) + + def test_multithreading_timeout(self): + bq = ByteLimitedQueue(lambda x: x.weight(), maxsize=0, maxweight=10) + bq.put(FakeItem(10)) + + # The queue is completely full. A timeout put should raise queue.Full. + with self.assertRaises(queue.Full): + bq.put(FakeItem(5), timeout=0.01) + + def delayed_consumer(): + time.sleep(0.05) + bq.get() + + # Start a thread that will free up space after 50ms. + t = threading.Thread(target=delayed_consumer) + t.start() + + # The put should succeed once the consumer runs, use a high timeout to + # flakiness. + bq.put(FakeItem(5), timeout=60) + t.join() + + def test_negative_timeout(self): + bq = ByteLimitedQueue(lambda x: x.weight()) + # Putting an item with a negative timeout should raise ValueError. + with self.assertRaises(ValueError): + bq.put(FakeItem(5), timeout=-1) + + def test_single_element_override(self): + bq = ByteLimitedQueue(lambda x: x.weight(), maxweight=10) + # An item of size 50 exceeds maxweight 10, but should be admitted + # immediately without blocking since the queue is currently empty! + bq.put(FakeItem(50), block=False) + self.assertEqual(bq.qsize(), 1) + self.assertEqual(bq.byte_size(), 50) + + def test_inconsistent_weighing_fn(self): + # Return a different weight for the same item. + weights = [10, 5] + bq = ByteLimitedQueue(lambda x: weights.pop(0), maxweight=100) + + bq.put(1) + self.assertEqual(bq.byte_size(), 10) + + # Upon popping, the weighing function (if called) would have returned 5, + # but the stored weight prevents corruption and cleanly reduces the size to + # 0. + bq.get() + self.assertEqual(bq.byte_size(), 0) + + +if __name__ == '__main__': + unittest.main() From 5fa799e4faeb2ef5e849ddb3dc501419d3eda987 Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Fri, 8 May 2026 16:31:54 +0200 Subject: [PATCH 2/8] monotonic and not shutdown restriction --- sdks/python/apache_beam/utils/byte_limited_queue.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/utils/byte_limited_queue.py b/sdks/python/apache_beam/utils/byte_limited_queue.py index 7e14f01c6ae0..9bb4d469ae2a 100644 --- a/sdks/python/apache_beam/utils/byte_limited_queue.py +++ b/sdks/python/apache_beam/utils/byte_limited_queue.py @@ -27,6 +27,7 @@ class ByteLimitedQueue(queue.Queue): """A queue.Queue that limits by both element count and total weight. A single element is allowed to exceed the maxweight to avoid deadlock. + Note that shutdown is only supported after there are no more put calls. """ def __init__( self, @@ -72,9 +73,9 @@ def put(self, item, block=True, timeout=None): elif timeout < 0: raise ValueError("'timeout' must be a non-negative number") else: - endtime = time.time() + timeout + endtime = time.monotonic() + timeout while self._is_full(item_size): - remaining = endtime - time.time() + remaining = endtime - time.monotonic() if remaining <= 0.0: raise queue.Full self.not_full.wait(remaining) From 8cbf0acf911e985663e34447ee376f544d7f9c3c Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Mon, 11 May 2026 15:05:12 +0200 Subject: [PATCH 3/8] change to not subclass queue.Queue and to be fair --- .../apache_beam/runners/worker/data_plane.py | 70 +++---- .../apache_beam/utils/byte_limited_queue.py | 184 +++++++++++++---- .../utils/byte_limited_queue_test.py | 186 +++++++++++++----- 3 files changed, 315 insertions(+), 125 deletions(-) diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py index e46e66bd11d3..f5c4f0d19fac 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane.py +++ b/sdks/python/apache_beam/runners/worker/data_plane.py @@ -456,20 +456,15 @@ class _GrpcDataChannel(DataChannel): def __init__(self, data_buffer_time_limit_ms=0): # type: (int) -> None - def _element_weight(element): - if isinstance(element, beam_fn_api_pb2.Elements.Data): - return len(element.data) - elif isinstance(element, beam_fn_api_pb2.Elements.Timers): - return len(element.timers) - return 0 self._data_buffer_time_limit_ms = data_buffer_time_limit_ms self._to_send = ByteLimitedQueue( - maxsize=10000, maxweight=100 << 20, - weighing_fn=_element_weight) # type: queue.Queue[DataOrTimers] + maxsize=10000, maxbytes=100 << 20, + ) # type: queue.Queue[DataOrTimers] self._received = collections.defaultdict( lambda: ByteLimitedQueue( - maxsize=5, maxweight=100 << 20, weighing_fn=_element_weight) + maxsize=5, maxbytes=100 << 20, + ) ) # type: DefaultDict[str, queue.Queue[DataOrTimers]] # Keep a cache of completed instructions. Data for completed instructions @@ -596,21 +591,21 @@ def output_stream(self, instruction_id, transform_id): def add_to_send_queue(data): # type: (bytes) -> None if data: - self._to_send.put( - beam_fn_api_pb2.Elements.Data( - instruction_id=instruction_id, - transform_id=transform_id, - data=data)) + elem = beam_fn_api_pb2.Elements.Data( + instruction_id=instruction_id, + transform_id=transform_id, + data=data) + self._to_send.put(elem, self._get_element_bytes(elem)) def close_callback(data): # type: (bytes) -> None add_to_send_queue(data) # End of stream marker. - self._to_send.put( - beam_fn_api_pb2.Elements.Data( - instruction_id=instruction_id, - transform_id=transform_id, - is_last=True)) + elem = beam_fn_api_pb2.Elements.Data( + instruction_id=instruction_id, + transform_id=transform_id, + is_last=True) + self._to_send.put(elem, self._get_element_bytes(elem)) return ClosableOutputStream.create( close_callback, add_to_send_queue, self._data_buffer_time_limit_ms) @@ -625,23 +620,23 @@ def output_timer_stream( def add_to_send_queue(timer): # type: (bytes) -> None if timer: - self._to_send.put( - beam_fn_api_pb2.Elements.Timers( - instruction_id=instruction_id, - transform_id=transform_id, - timer_family_id=timer_family_id, - timers=timer, - is_last=False)) + elem = beam_fn_api_pb2.Elements.Timers( + instruction_id=instruction_id, + transform_id=transform_id, + timer_family_id=timer_family_id, + timers=timer, + is_last=False) + self._to_send.put(elem, self._get_element_bytes(elem)) def close_callback(timer): # type: (bytes) -> None add_to_send_queue(timer) - self._to_send.put( - beam_fn_api_pb2.Elements.Timers( - instruction_id=instruction_id, - transform_id=transform_id, - timer_family_id=timer_family_id, - is_last=True)) + elem = beam_fn_api_pb2.Elements.Timers( + instruction_id=instruction_id, + transform_id=transform_id, + timer_family_id=timer_family_id, + is_last=True) + self._to_send.put(elem, self._get_element_bytes(elem)) return ClosableOutputStream.create( close_callback, add_to_send_queue, self._data_buffer_time_limit_ms) @@ -676,6 +671,15 @@ def _write_outputs(self): raise ValueError('Unexpected output element type %s' % type(stream)) yield beam_fn_api_pb2.Elements(data=data_stream, timers=timer_stream) + def _get_element_bytes(self, element): + # type: (Union[beam_fn_api_pb2.Elements.Data, beam_fn_api_pb2.Elements.Timers]) -> int + if isinstance(element, beam_fn_api_pb2.Elements.Data): + return len(element.data) + elif isinstance(element, beam_fn_api_pb2.Elements.Timers): + return len(element.timers) + else: + return 0 + def _read_inputs(self, elements_iterator): # type: (Iterable[beam_fn_api_pb2.Elements]) -> None @@ -702,7 +706,7 @@ def _put_queue(instruction_id, element): next_discard_log_time = current_time + 10 return try: - input_queue.put(element, timeout=1) + input_queue.put(element, self._get_element_bytes(element), timeout=1) return except queue.Full: current_time = time.time() diff --git a/sdks/python/apache_beam/utils/byte_limited_queue.py b/sdks/python/apache_beam/utils/byte_limited_queue.py index 9bb4d469ae2a..9a1204b9500e 100644 --- a/sdks/python/apache_beam/utils/byte_limited_queue.py +++ b/sdks/python/apache_beam/utils/byte_limited_queue.py @@ -17,80 +17,180 @@ """A thread-safe queue that limits capacity by total byte size.""" +import collections import queue +import threading import time -from typing import Any -from typing import Callable -class ByteLimitedQueue(queue.Queue): - """A queue.Queue that limits by both element count and total weight. +class ByteLimitedQueue(object): + """A fair queue that limits by both element count and total weight. A single element is allowed to exceed the maxweight to avoid deadlock. - Note that shutdown is only supported after there are no more put calls. """ def __init__( self, - weighing_fn, # type: Callable[[Any], int] maxsize=0, # type: int - maxweight=0, # type: int + maxbytes=0, # type: int ): # type: (...) -> None """Initializes a ByteLimitedQueue. Args: - weighing_fn: A Callable that accepts an item and returns its integer - weight. maxsize: The maximum number of items allowed in the queue. If 0 or negative, there is no limit on the number of elements. - maxweight: The maximum accumulated weight allowed in the queue. + maxweight: The maximum accumulated weight allowed in the queue. If 0 or + negative, there is no limit on the total size of the elements. """ - super().__init__(maxsize=0) self.max_elements = maxsize - self.max_weight = maxweight - self.weighing_fn = weighing_fn + self.max_bytes = maxbytes self._byte_size = 0 + self._blocked_bytes = 0 + self._mutex = threading.Lock() + self._not_empty = threading.Condition(self._mutex) + self._waiting_writers = collections.deque() + self._queue = collections.deque() - def _is_full(self, item_size): - if self._qsize() == 0: - return False - if self.max_elements > 0 and self._qsize() >= self.max_elements: - return True - if self.max_weight > 0 and self._byte_size + item_size > self.max_weight: - return True - return False + def put(self, item, item_bytes, block=True, timeout=None): + """Put an item into the queue. + + If the queue is full, block until a free slot is available, unless `block` + is false or a timeout occurs. + + Args: + item: The item to put into the queue. + item_size: The size of the item. + block: If True, block until space is available. If False, raise queue.Full + immediately if the queue is full. + timeout: If block is True, wait for at most `timeout` seconds. If None, + block indefinitely. + + Raises: + ValueError: If timeout or item_size is negative. + queue.Full: If the queue is full and block is False or the timeout occurs. + """ + if timeout is not None and timeout < 0: + raise ValueError("'timeout' must be a non-negative number") + if item_bytes < 0: + raise ValueError("'item_bytes' must be a positive number") + + with self._mutex: + if not self._waiting_writers and not self._is_full_locked( + item_bytes + ): + self._queue.append((item, item_bytes)) + self._byte_size += item_bytes + self._not_empty.notify() + return - def put(self, item, block=True, timeout=None): - item_size = max(1, self.weighing_fn(item)) - with self.not_full: if not block: - if self._is_full(item_size): - raise queue.Full + raise queue.Full + + my_cond = threading.Condition(self._mutex) + endtime = time.monotonic() + timeout if timeout is not None else None + try: + self._blocked_bytes += item_bytes + self._waiting_writers.append(my_cond) + while True: + if timeout is None: + my_cond.wait() + else: + remaining = endtime - time.monotonic() + if remaining <= 0.0: + raise queue.Full + my_cond.wait(remaining) + + if self._waiting_writers[0] is my_cond and not self._is_full_locked( + item_bytes + ): + break + + self._queue.append((item, item_bytes)) + self._byte_size += item_bytes + self._not_empty.notify() + finally: + self._blocked_bytes -= item_bytes + if self._waiting_writers: + was_first = (self._waiting_writers[0] is my_cond) + if was_first: + self._waiting_writers.popleft() + else: + self._waiting_writers.remove(my_cond) + if was_first and self._waiting_writers: + self._waiting_writers[0].notify() + + def get(self, block=True, timeout=None): + """Remove and return an item from the queue. + + If the queue is empty, block until an item is available, unless `block` + is false or a timeout occurs. + + Args: + block: If True, block until an item is available. If False, raise + queue.Empty immediately if the queue is empty. + timeout: If block is True, wait for at most `timeout` seconds. If None, + block indefinitely. + + Returns: + The item removed from the queue. + + Raises: + ValueError: If timeout is negative. + queue.Empty: If the queue is empty and block is False or the timeout + occurs. + """ + if timeout is not None and timeout < 0: + raise ValueError("'timeout' must be a non-negative number") + + with self._not_empty: + if not block: + if not self._queue: + raise queue.Empty elif timeout is None: - while self._is_full(item_size): - self.not_full.wait() - elif timeout < 0: - raise ValueError("'timeout' must be a non-negative number") + while not self._queue: + self._not_empty.wait() else: endtime = time.monotonic() + timeout - while self._is_full(item_size): + while not self._queue: remaining = endtime - time.monotonic() if remaining <= 0.0: - raise queue.Full - self.not_full.wait(remaining) + raise queue.Empty + self._not_empty.wait(remaining) + + item, item_bytes = self._queue.popleft() + self._byte_size -= item_bytes + + if self._waiting_writers: + self._waiting_writers[0].notify() - self._put((item, item_size)) - self._byte_size += item_size - self.unfinished_tasks += 1 - self.not_empty.notify() + return item - def _get(self): - item, item_weight = super()._get() - self._byte_size -= item_weight - return item + def get_nowait(self): + """Remove and return an item from the queue without blocking.""" + return self.get(block=False) def byte_size(self): """Return the total byte weight of elements in the queue.""" - with self.mutex: + with self._mutex: return self._byte_size + + def blocked_byte_size(self): + """Return the total byte weight of elements in the queue that are blocked.""" + with self._mutex: + return self._blocked_bytes + + def qsize(self): + """Return the total number of elements in the queue.""" + with self._mutex: + return len(self._queue) + + def _is_full_locked(self, item_size): + # Always let in a single element, regardless of size. + if not self._queue: + return False + if self.max_elements > 0 and len(self._queue) >= self.max_elements: + return True + if self.max_bytes > 0 and self._byte_size + item_size > self.max_bytes: + return True + return False diff --git a/sdks/python/apache_beam/utils/byte_limited_queue_test.py b/sdks/python/apache_beam/utils/byte_limited_queue_test.py index e6349f3af00b..8c6804dfa328 100644 --- a/sdks/python/apache_beam/utils/byte_limited_queue_test.py +++ b/sdks/python/apache_beam/utils/byte_limited_queue_test.py @@ -18,7 +18,6 @@ """Unit tests for byte-limited queue.""" import queue -import sys import threading import time import unittest @@ -26,76 +25,68 @@ from apache_beam.utils.byte_limited_queue import ByteLimitedQueue -class FakeItem(object): - def __init__(self, size): - self._size = size - - def weight(self): - return self._size - - class ByteLimitedQueueTest(unittest.TestCase): def test_unbounded(self): - bq = ByteLimitedQueue(lambda x: x.weight()) - for i in range(200): - bq.put(FakeItem(i)) - # Add 1 since weight of zero is set to 1 - self.assertEqual(bq.byte_size(), sum(range(200)) + 1) - self.assertEqual(bq.qsize(), 200) + bq = ByteLimitedQueue() + for i in range(201): + bq.put(str(i), i) + self.assertEqual(bq.byte_size(), sum(range(201))) + self.assertEqual(bq.qsize(), 201) def test_put_and_get(self): - bq = ByteLimitedQueue(lambda x: x.weight(), maxweight=200) - bq.put(FakeItem(50)) - bq.put(FakeItem(140)) + bq = ByteLimitedQueue(maxbytes=200) + bq.put('50', 50) + bq.put('140', 140) self.assertEqual(bq.byte_size(), 190) self.assertEqual(bq.qsize(), 2) # Putting another would exceed 200. with self.assertRaises(queue.Full): - bq.put(FakeItem(20), block=False) - bq.put(FakeItem(10), block=False) + bq.put('20', 20, block=False) + bq.put('10', 10, block=False) self.assertEqual(bq.byte_size(), 200) self.assertEqual(bq.qsize(), 3) - self.assertEqual(bq.get().weight(), 50) + self.assertEqual(bq.get(), '50') self.assertEqual(bq.byte_size(), 150) self.assertEqual(bq.qsize(), 2) - bq.put(FakeItem(20), block=False) + bq.put('20', 20, block=False) def test_dual_limit(self): - # Queue limits: at most 2 items, OR at most 100 weight. - bq = ByteLimitedQueue(lambda x: x.weight(), maxsize=3, maxweight=100) - bq.put(FakeItem(30)) - bq.put(FakeItem(40)) - bq.put(FakeItem(20)) + # Queue limits: at most 3 items, OR at most 100 item bytes. + bq = ByteLimitedQueue(maxsize=3, maxbytes=100) + bq.put('30', 30) + bq.put('40', 40) + bq.put('20', 20) self.assertEqual(bq.byte_size(), 90) self.assertEqual(bq.qsize(), 3) - # Full on element count (size=2). + # Full on element count (size=3). with self.assertRaises(queue.Full): - bq.put(FakeItem(10), block=False) - self.assertEqual(bq.get().weight(), 30) - self.assertEqual(bq.get().weight(), 40) - bq.put(FakeItem(10)) + bq.put('10', 10, block=False) + self.assertEqual(bq.get(), '30') + self.assertEqual(bq.get(), '40') + bq.put('10', 10) # Full on byte count with self.assertRaises(queue.Full): - bq.put(FakeItem(90), block=False) - self.assertEqual(bq.get().weight(), 20) - bq.put(FakeItem(90), block=False) + bq.put('90', 90, block=False) + self.assertEqual(bq.get(), '20') + bq.put('90', 90, block=False) - @unittest.skipIf(sys.version_info < (3, 13), 'Queue.ShutDown added in 3.13.') def test_multithreading(self): - bq = ByteLimitedQueue(lambda x: x.weight(), maxsize=0, maxweight=100) + bq = ByteLimitedQueue(maxsize=0, maxbytes=100) received = [] def producer(): for i in range(101): - bq.put(FakeItem(i)) + bq.put(str(i), i) + + poison_pill = 'POISON' def consumer(): while True: - try: - received.append(bq.get().weight()) - except queue.ShutDown: + item = bq.get() + if item == poison_pill: break + received.append(int(item)) t1 = threading.Thread(target=producer) t2 = threading.Thread(target=producer) @@ -107,7 +98,7 @@ def consumer(): t1.join() t2.join() - bq.shutdown() + bq.put(poison_pill, 0) t3.join() @@ -115,12 +106,12 @@ def consumer(): self.assertEqual(sum(received), 2 * sum(range(101))) def test_multithreading_timeout(self): - bq = ByteLimitedQueue(lambda x: x.weight(), maxsize=0, maxweight=10) - bq.put(FakeItem(10)) + bq = ByteLimitedQueue(maxsize=0, maxbytes=10) + bq.put('10', 10) # The queue is completely full. A timeout put should raise queue.Full. with self.assertRaises(queue.Full): - bq.put(FakeItem(5), timeout=0.01) + bq.put('5', 5, timeout=0.01) def delayed_consumer(): time.sleep(0.05) @@ -136,16 +127,16 @@ def delayed_consumer(): t.join() def test_negative_timeout(self): - bq = ByteLimitedQueue(lambda x: x.weight()) + bq = ByteLimitedQueue() # Putting an item with a negative timeout should raise ValueError. with self.assertRaises(ValueError): - bq.put(FakeItem(5), timeout=-1) + bq.put('5', 5, timeout=-1) def test_single_element_override(self): - bq = ByteLimitedQueue(lambda x: x.weight(), maxweight=10) - # An item of size 50 exceeds maxweight 10, but should be admitted + bq = ByteLimitedQueue(maxbytes=10) + # An item of size 50 exceeds maxbytes 10, but should be admitted # immediately without blocking since the queue is currently empty! - bq.put(FakeItem(50), block=False) + bq.put('50', 50, block=False) self.assertEqual(bq.qsize(), 1) self.assertEqual(bq.byte_size(), 50) @@ -163,6 +154,101 @@ def test_inconsistent_weighing_fn(self): bq.get() self.assertEqual(bq.byte_size(), 0) + def test_fairness(self): + bq = ByteLimitedQueue(maxbytes=10) + # Put an initial item so that the queue is not empty, + # causing the subsequent large item to block. + bq.put('first', 2) + self.assertEqual(bq.blocked_byte_size(), 0) + + def producer(item, size): + bq.put(item, size) + + # Add an item in a background thread that should block due to exceeding + # the limit. + t1 = threading.Thread(target=producer, args=('too_large', 9)) + t1.start() + + # Wait until the background write is queued. + while bq.blocked_byte_size() < 1: + time.sleep(0.005) + self.assertEqual(bq.blocked_byte_size(), 9) + + # Add smaller items afterwards. + t2 = threading.Thread(target=producer, args=('small1', 1)) + t2.start() + + while bq.blocked_byte_size() < 10: + time.sleep(0.005) + self.assertEqual(bq.blocked_byte_size(), 10) + + t3 = threading.Thread(target=producer, args=('small2', 1)) + t3.start() + + while bq.blocked_byte_size() < 11: + time.sleep(0.005) + self.assertEqual(bq.blocked_byte_size(), 11) + + # Verify all items are received in order. + self.assertEqual(bq.get(), 'first') + t1.join() + t2.join() + self.assertEqual(bq.get(), 'too_large') + t3.join() + self.assertEqual(bq.get(), 'small1') + self.assertEqual(bq.get(), 'small2') + + def test_blocked_waiter_timeout_multiple(self): + bq = ByteLimitedQueue(maxbytes=10) + bq.put('initial', 5) + + status = [] + lock = threading.Lock() + + def producer(name, size, timeout_val): + try: + bq.put(name, size, timeout=timeout_val) + with lock: + status.append((name, 'success')) + except queue.Full: + with lock: + status.append((name, 'timeout')) + + threads = [] + threads.append(threading.Thread(target=producer, args=('t1', 8, 0.2))) + threads.append(threading.Thread(target=producer, args=('t2', 8, 60.0))) + threads.append(threading.Thread(target=producer, args=('t3', 3, 0.1))) + threads.append(threading.Thread(target=producer, args=('t4', 3, 60.0))) + threads.append(threading.Thread(target=producer, args=('t5', 3, 0.1))) + for t in threads: + t.start() + + # Wait for the short-timeout threads. + threads[4].join() + threads[2].join() + threads[0].join() + + # Now waiting writers should just be t1 and t3 + self.assertEqual(bq.blocked_byte_size(), 11) + + self.assertEqual(bq.get(), 'initial') + threads[1].join() + self.assertGreater(bq.blocked_byte_size(), 0) + + elem = bq.get() + self.assertTrue(elem == 't2' or elem == 't4') + threads[3].join() + self.assertEqual(bq.blocked_byte_size(), 0) + elem = bq.get() + self.assertTrue(elem == 't2' or elem == 't4') + + with lock: + self.assertIn(('t1', 'timeout'), status) + self.assertIn(('t2', 'success'), status) + self.assertIn(('t3', 'timeout'), status) + self.assertIn(('t4', 'success'), status) + self.assertIn(('t5', 'timeout'), status) + if __name__ == '__main__': unittest.main() From d9a512036f39ea581be02daed52daa19f514c203 Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Mon, 11 May 2026 15:27:39 +0200 Subject: [PATCH 4/8] fixups --- .../apache_beam/runners/worker/data_plane.py | 18 +++++------ .../apache_beam/utils/byte_limited_queue.py | 32 +++++++++---------- .../utils/byte_limited_queue_test.py | 2 +- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py index f5c4f0d19fac..0f9fac22a031 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane.py +++ b/sdks/python/apache_beam/runners/worker/data_plane.py @@ -459,13 +459,15 @@ def __init__(self, data_buffer_time_limit_ms=0): self._data_buffer_time_limit_ms = data_buffer_time_limit_ms self._to_send = ByteLimitedQueue( - maxsize=10000, maxbytes=100 << 20, - ) # type: queue.Queue[DataOrTimers] + maxsize=10000, + maxbytes=100 << 20 + ) # type: ByteLimitedQueue[DataOrTimers] self._received = collections.defaultdict( lambda: ByteLimitedQueue( - maxsize=5, maxbytes=100 << 20, + maxsize=5, + maxbytes=100 << 20, ) - ) # type: DefaultDict[str, queue.Queue[DataOrTimers]] + ) # type: DefaultDict[str, ByteLimitedQueue[DataOrTimers]] # Keep a cache of completed instructions. Data for completed instructions # must be discarded. See input_elements() and _clean_receiving_queue(). @@ -480,7 +482,7 @@ def __init__(self, data_buffer_time_limit_ms=0): def close(self): # type: () -> None - self._to_send.put(self._WRITES_FINISHED) + self._to_send.put(self._WRITES_FINISHED, 0) self._closed = True def wait(self, timeout=None): @@ -488,7 +490,7 @@ def wait(self, timeout=None): self._reads_finished.wait(timeout) def _receiving_queue(self, instruction_id): - # type: (str) -> Optional[queue.Queue[DataOrTimers]] + # type: (str) -> Optional[ByteLimitedQueue[DataOrTimers]] """ Gets or creates queue for a instruction_id. Or, returns None if the @@ -592,9 +594,7 @@ def add_to_send_queue(data): # type: (bytes) -> None if data: elem = beam_fn_api_pb2.Elements.Data( - instruction_id=instruction_id, - transform_id=transform_id, - data=data) + instruction_id=instruction_id, transform_id=transform_id, data=data) self._to_send.put(elem, self._get_element_bytes(elem)) def close_callback(data): diff --git a/sdks/python/apache_beam/utils/byte_limited_queue.py b/sdks/python/apache_beam/utils/byte_limited_queue.py index 9a1204b9500e..1a9ef40b9e3e 100644 --- a/sdks/python/apache_beam/utils/byte_limited_queue.py +++ b/sdks/python/apache_beam/utils/byte_limited_queue.py @@ -21,13 +21,16 @@ import queue import threading import time +import types class ByteLimitedQueue(object): - """A fair queue that limits by both element count and total weight. + """A fair queue that limits by both element count and total byte size. - A single element is allowed to exceed the maxweight to avoid deadlock. + A single element is allowed to exceed the maxbytes to avoid deadlock. """ + __class_getitem__ = classmethod(types.GenericAlias) + def __init__( self, maxsize=0, # type: int @@ -40,8 +43,8 @@ def __init__( Args: maxsize: The maximum number of items allowed in the queue. If 0 or negative, there is no limit on the number of elements. - maxweight: The maximum accumulated weight allowed in the queue. If 0 or - negative, there is no limit on the total size of the elements. + maxbytes: The maximum accumulated bytes allowed in the queue. If 0 or + negative, there is no limit on the total bytes of the elements. """ self.max_elements = maxsize self.max_bytes = maxbytes @@ -60,25 +63,23 @@ def put(self, item, item_bytes, block=True, timeout=None): Args: item: The item to put into the queue. - item_size: The size of the item. + item_bytes: The size of the item. block: If True, block until space is available. If False, raise queue.Full immediately if the queue is full. timeout: If block is True, wait for at most `timeout` seconds. If None, block indefinitely. Raises: - ValueError: If timeout or item_size is negative. + ValueError: If timeout or item_bytes is negative. queue.Full: If the queue is full and block is False or the timeout occurs. """ if timeout is not None and timeout < 0: raise ValueError("'timeout' must be a non-negative number") if item_bytes < 0: - raise ValueError("'item_bytes' must be a positive number") + raise ValueError("'item_bytes' must be a non-negative number") with self._mutex: - if not self._waiting_writers and not self._is_full_locked( - item_bytes - ): + if not self._waiting_writers and not self._is_full_locked(item_bytes): self._queue.append((item, item_bytes)) self._byte_size += item_bytes self._not_empty.notify() @@ -102,8 +103,7 @@ def put(self, item, item_bytes, block=True, timeout=None): my_cond.wait(remaining) if self._waiting_writers[0] is my_cond and not self._is_full_locked( - item_bytes - ): + item_bytes): break self._queue.append((item, item_bytes)) @@ -171,12 +171,12 @@ def get_nowait(self): return self.get(block=False) def byte_size(self): - """Return the total byte weight of elements in the queue.""" + """Return the total byte size of elements in the queue.""" with self._mutex: return self._byte_size def blocked_byte_size(self): - """Return the total byte weight of elements in the queue that are blocked.""" + """Return the total byte size of elements in the queue that are blocked.""" with self._mutex: return self._blocked_bytes @@ -185,12 +185,12 @@ def qsize(self): with self._mutex: return len(self._queue) - def _is_full_locked(self, item_size): + def _is_full_locked(self, item_bytes): # Always let in a single element, regardless of size. if not self._queue: return False if self.max_elements > 0 and len(self._queue) >= self.max_elements: return True - if self.max_bytes > 0 and self._byte_size + item_size > self.max_bytes: + if self.max_bytes > 0 and self._byte_size + item_bytes > self.max_bytes: return True return False diff --git a/sdks/python/apache_beam/utils/byte_limited_queue_test.py b/sdks/python/apache_beam/utils/byte_limited_queue_test.py index 8c6804dfa328..4fec5c8f4dba 100644 --- a/sdks/python/apache_beam/utils/byte_limited_queue_test.py +++ b/sdks/python/apache_beam/utils/byte_limited_queue_test.py @@ -123,7 +123,7 @@ def delayed_consumer(): # The put should succeed once the consumer runs, use a high timeout to # flakiness. - bq.put(FakeItem(5), timeout=60) + bq.put('item', timeout=60) t.join() def test_negative_timeout(self): From fc06a1e93ce3ce3044f0acba20c187fb5092bb76 Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Mon, 11 May 2026 16:04:23 +0200 Subject: [PATCH 5/8] add missing pxd file, fixup test --- .../apache_beam/runners/worker/data_plane.py | 6 ++-- .../apache_beam/utils/byte_limited_queue.pxd | 30 +++++++++++++++++++ .../utils/byte_limited_queue_test.py | 16 +--------- 3 files changed, 33 insertions(+), 19 deletions(-) create mode 100644 sdks/python/apache_beam/utils/byte_limited_queue.pxd diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py index 0f9fac22a031..79404ec69336 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane.py +++ b/sdks/python/apache_beam/runners/worker/data_plane.py @@ -460,13 +460,11 @@ def __init__(self, data_buffer_time_limit_ms=0): self._data_buffer_time_limit_ms = data_buffer_time_limit_ms self._to_send = ByteLimitedQueue( maxsize=10000, - maxbytes=100 << 20 - ) # type: ByteLimitedQueue[DataOrTimers] + maxbytes=100 << 20) # type: ByteLimitedQueue[DataOrTimers] self._received = collections.defaultdict( lambda: ByteLimitedQueue( maxsize=5, - maxbytes=100 << 20, - ) + maxbytes=100 << 20) ) # type: DefaultDict[str, ByteLimitedQueue[DataOrTimers]] # Keep a cache of completed instructions. Data for completed instructions diff --git a/sdks/python/apache_beam/utils/byte_limited_queue.pxd b/sdks/python/apache_beam/utils/byte_limited_queue.pxd new file mode 100644 index 000000000000..4854d824efde --- /dev/null +++ b/sdks/python/apache_beam/utils/byte_limited_queue.pxd @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# cython: overflowcheck=True + +cdef class ByteLimitedQueue(object): + cdef readonly int max_elements + cdef readonly int max_bytes + cdef readonly int _byte_size + cdef readonly object _mutex + cdef readonly object _not_empty + cdef readonly object _waiting_writers + cdef readonly object _queue + cdef readonly int _blocked_bytes + + cpdef bint _is_full_locked(self, int item_bytes) except -1 diff --git a/sdks/python/apache_beam/utils/byte_limited_queue_test.py b/sdks/python/apache_beam/utils/byte_limited_queue_test.py index 4fec5c8f4dba..bf0aa70750cc 100644 --- a/sdks/python/apache_beam/utils/byte_limited_queue_test.py +++ b/sdks/python/apache_beam/utils/byte_limited_queue_test.py @@ -123,7 +123,7 @@ def delayed_consumer(): # The put should succeed once the consumer runs, use a high timeout to # flakiness. - bq.put('item', timeout=60) + bq.put('item', 5, timeout=60) t.join() def test_negative_timeout(self): @@ -140,20 +140,6 @@ def test_single_element_override(self): self.assertEqual(bq.qsize(), 1) self.assertEqual(bq.byte_size(), 50) - def test_inconsistent_weighing_fn(self): - # Return a different weight for the same item. - weights = [10, 5] - bq = ByteLimitedQueue(lambda x: weights.pop(0), maxweight=100) - - bq.put(1) - self.assertEqual(bq.byte_size(), 10) - - # Upon popping, the weighing function (if called) would have returned 5, - # but the stored weight prevents corruption and cleanly reduces the size to - # 0. - bq.get() - self.assertEqual(bq.byte_size(), 0) - def test_fairness(self): bq = ByteLimitedQueue(maxbytes=10) # Put an initial item so that the queue is not empty, From 0080e2a447aef8b1b33447921d68c72f190d3408 Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Mon, 11 May 2026 17:14:28 +0200 Subject: [PATCH 6/8] use 64-bit for size in pxd --- sdks/python/apache_beam/runners/worker/data_plane.py | 4 +--- sdks/python/apache_beam/utils/byte_limited_queue.pxd | 10 +++++----- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py index 79404ec69336..c05abbc2666c 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane.py +++ b/sdks/python/apache_beam/runners/worker/data_plane.py @@ -462,9 +462,7 @@ def __init__(self, data_buffer_time_limit_ms=0): maxsize=10000, maxbytes=100 << 20) # type: ByteLimitedQueue[DataOrTimers] self._received = collections.defaultdict( - lambda: ByteLimitedQueue( - maxsize=5, - maxbytes=100 << 20) + lambda: ByteLimitedQueue(maxsize=5, maxbytes=100 << 20) ) # type: DefaultDict[str, ByteLimitedQueue[DataOrTimers]] # Keep a cache of completed instructions. Data for completed instructions diff --git a/sdks/python/apache_beam/utils/byte_limited_queue.pxd b/sdks/python/apache_beam/utils/byte_limited_queue.pxd index 4854d824efde..79ac6b2109bb 100644 --- a/sdks/python/apache_beam/utils/byte_limited_queue.pxd +++ b/sdks/python/apache_beam/utils/byte_limited_queue.pxd @@ -18,13 +18,13 @@ # cython: overflowcheck=True cdef class ByteLimitedQueue(object): - cdef readonly int max_elements - cdef readonly int max_bytes - cdef readonly int _byte_size + cdef readonly Py_ssize_t max_elements + cdef readonly Py_ssize_t max_bytes + cdef readonly Py_ssize_t _byte_size cdef readonly object _mutex cdef readonly object _not_empty cdef readonly object _waiting_writers cdef readonly object _queue - cdef readonly int _blocked_bytes + cdef readonly Py_ssize_t _blocked_bytes - cpdef bint _is_full_locked(self, int item_bytes) except -1 + cpdef bint _is_full_locked(self, Py_ssize_t item_bytes) except -1 From 53da6ffcfd81553e0a8aca69c2cf99eab739efeb Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Tue, 12 May 2026 13:47:08 +0200 Subject: [PATCH 7/8] address comments --- .../apache_beam/runners/worker/data_plane.py | 13 ++++---- .../apache_beam/utils/byte_limited_queue.pxd | 2 +- .../apache_beam/utils/byte_limited_queue.py | 21 ++++++------ .../utils/byte_limited_queue_test.py | 32 ++++++++++++++++++- sdks/python/setup.py | 1 + 5 files changed, 50 insertions(+), 19 deletions(-) diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py index c05abbc2666c..a5589ac33a1b 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane.py +++ b/sdks/python/apache_beam/runners/worker/data_plane.py @@ -591,7 +591,7 @@ def add_to_send_queue(data): if data: elem = beam_fn_api_pb2.Elements.Data( instruction_id=instruction_id, transform_id=transform_id, data=data) - self._to_send.put(elem, self._get_element_bytes(elem)) + self._to_send.put(elem, self._get_element_size_bytes(elem)) def close_callback(data): # type: (bytes) -> None @@ -601,7 +601,7 @@ def close_callback(data): instruction_id=instruction_id, transform_id=transform_id, is_last=True) - self._to_send.put(elem, self._get_element_bytes(elem)) + self._to_send.put(elem, self._get_element_size_bytes(elem)) return ClosableOutputStream.create( close_callback, add_to_send_queue, self._data_buffer_time_limit_ms) @@ -622,7 +622,7 @@ def add_to_send_queue(timer): timer_family_id=timer_family_id, timers=timer, is_last=False) - self._to_send.put(elem, self._get_element_bytes(elem)) + self._to_send.put(elem, self._get_element_size_bytes(elem)) def close_callback(timer): # type: (bytes) -> None @@ -632,7 +632,7 @@ def close_callback(timer): transform_id=transform_id, timer_family_id=timer_family_id, is_last=True) - self._to_send.put(elem, self._get_element_bytes(elem)) + self._to_send.put(elem, self._get_element_size_bytes(elem)) return ClosableOutputStream.create( close_callback, add_to_send_queue, self._data_buffer_time_limit_ms) @@ -667,7 +667,7 @@ def _write_outputs(self): raise ValueError('Unexpected output element type %s' % type(stream)) yield beam_fn_api_pb2.Elements(data=data_stream, timers=timer_stream) - def _get_element_bytes(self, element): + def _get_element_size_bytes(self, element): # type: (Union[beam_fn_api_pb2.Elements.Data, beam_fn_api_pb2.Elements.Timers]) -> int if isinstance(element, beam_fn_api_pb2.Elements.Data): return len(element.data) @@ -702,7 +702,8 @@ def _put_queue(instruction_id, element): next_discard_log_time = current_time + 10 return try: - input_queue.put(element, self._get_element_bytes(element), timeout=1) + input_queue.put( + element, self._get_element_size_bytes(element), timeout=1) return except queue.Full: current_time = time.time() diff --git a/sdks/python/apache_beam/utils/byte_limited_queue.pxd b/sdks/python/apache_beam/utils/byte_limited_queue.pxd index 79ac6b2109bb..0884b1e18c42 100644 --- a/sdks/python/apache_beam/utils/byte_limited_queue.pxd +++ b/sdks/python/apache_beam/utils/byte_limited_queue.pxd @@ -27,4 +27,4 @@ cdef class ByteLimitedQueue(object): cdef readonly object _queue cdef readonly Py_ssize_t _blocked_bytes - cpdef bint _is_full_locked(self, Py_ssize_t item_bytes) except -1 + cpdef bint _can_fit(self, Py_ssize_t item_bytes) except -1 diff --git a/sdks/python/apache_beam/utils/byte_limited_queue.py b/sdks/python/apache_beam/utils/byte_limited_queue.py index 1a9ef40b9e3e..2edeafbeb971 100644 --- a/sdks/python/apache_beam/utils/byte_limited_queue.py +++ b/sdks/python/apache_beam/utils/byte_limited_queue.py @@ -55,7 +55,7 @@ def __init__( self._waiting_writers = collections.deque() self._queue = collections.deque() - def put(self, item, item_bytes, block=True, timeout=None): + def put(self, item, item_bytes, *, block=True, timeout=None): """Put an item into the queue. If the queue is full, block until a free slot is available, unless `block` @@ -79,7 +79,7 @@ def put(self, item, item_bytes, block=True, timeout=None): raise ValueError("'item_bytes' must be a non-negative number") with self._mutex: - if not self._waiting_writers and not self._is_full_locked(item_bytes): + if not self._waiting_writers and self._can_fit(item_bytes): self._queue.append((item, item_bytes)) self._byte_size += item_bytes self._not_empty.notify() @@ -102,8 +102,7 @@ def put(self, item, item_bytes, block=True, timeout=None): raise queue.Full my_cond.wait(remaining) - if self._waiting_writers[0] is my_cond and not self._is_full_locked( - item_bytes): + if self._waiting_writers[0] is my_cond and self._can_fit(item_bytes): break self._queue.append((item, item_bytes)) @@ -120,7 +119,7 @@ def put(self, item, item_bytes, block=True, timeout=None): if was_first and self._waiting_writers: self._waiting_writers[0].notify() - def get(self, block=True, timeout=None): + def get(self, *, block=True, timeout=None): """Remove and return an item from the queue. If the queue is empty, block until an item is available, unless `block` @@ -143,7 +142,7 @@ def get(self, block=True, timeout=None): if timeout is not None and timeout < 0: raise ValueError("'timeout' must be a non-negative number") - with self._not_empty: + with self._mutex: if not block: if not self._queue: raise queue.Empty @@ -185,12 +184,12 @@ def qsize(self): with self._mutex: return len(self._queue) - def _is_full_locked(self, item_bytes): + def _can_fit(self, item_bytes): # Always let in a single element, regardless of size. if not self._queue: - return False - if self.max_elements > 0 and len(self._queue) >= self.max_elements: return True + if self.max_elements > 0 and len(self._queue) >= self.max_elements: + return False if self.max_bytes > 0 and self._byte_size + item_bytes > self.max_bytes: - return True - return False + return False + return True diff --git a/sdks/python/apache_beam/utils/byte_limited_queue_test.py b/sdks/python/apache_beam/utils/byte_limited_queue_test.py index bf0aa70750cc..27ccb2421844 100644 --- a/sdks/python/apache_beam/utils/byte_limited_queue_test.py +++ b/sdks/python/apache_beam/utils/byte_limited_queue_test.py @@ -105,7 +105,7 @@ def consumer(): self.assertEqual(len(received), 202) self.assertEqual(sum(received), 2 * sum(range(101))) - def test_multithreading_timeout(self): + def test_put_timeout(self): bq = ByteLimitedQueue(maxsize=0, maxbytes=10) bq.put('10', 10) @@ -126,11 +126,41 @@ def delayed_consumer(): bq.put('item', 5, timeout=60) t.join() + def test_get_timeout(self): + bq = ByteLimitedQueue(maxsize=0, maxbytes=100) + with self.assertRaises(queue.Empty): + bq.get(block=False) + with self.assertRaises(queue.Empty): + bq.get(timeout=0.0) + with self.assertRaises(queue.Empty): + bq.get(timeout=.01) + + bq.put('1', 1) + self.assertEqual('1', bq.get(timeout=0)) + + bq.put('2', 2) + self.assertEqual('2', bq.get(timeout=0.1)) + + def delayed_producer(): + time.sleep(0.05) + bq.put('3', 3) + + # Start a thread that will produce soon + t = threading.Thread(target=delayed_producer) + t.start() + + # The get should succeed once the produer runs, use a high timeout to + # flakiness. + self.assertEqual('3', bq.get(timeout=60)) + t.join() + def test_negative_timeout(self): bq = ByteLimitedQueue() # Putting an item with a negative timeout should raise ValueError. with self.assertRaises(ValueError): bq.put('5', 5, timeout=-1) + with self.assertRaises(ValueError): + bq.get(timeout=-1) def test_single_element_override(self): bq = ByteLimitedQueue(maxbytes=10) diff --git a/sdks/python/setup.py b/sdks/python/setup.py index b3fb98d8b0ef..45781a44c4b1 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -368,6 +368,7 @@ def get_portability_package_data(): 'apache_beam/runners/worker/operations.py', 'apache_beam/transforms/cy_combiners.py', 'apache_beam/transforms/stats.py', + 'apache_beam/utils/byte_limited_queue.py', 'apache_beam/utils/counters.py', 'apache_beam/utils/windowed_value.py', ]) From ef76e6ce2e0f1ee46b845be2aa509c75f62e03dd Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Tue, 12 May 2026 15:58:01 +0200 Subject: [PATCH 8/8] add condition caching --- sdks/python/apache_beam/utils/byte_limited_queue.pxd | 1 + sdks/python/apache_beam/utils/byte_limited_queue.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/utils/byte_limited_queue.pxd b/sdks/python/apache_beam/utils/byte_limited_queue.pxd index 0884b1e18c42..396185e8e101 100644 --- a/sdks/python/apache_beam/utils/byte_limited_queue.pxd +++ b/sdks/python/apache_beam/utils/byte_limited_queue.pxd @@ -24,6 +24,7 @@ cdef class ByteLimitedQueue(object): cdef readonly object _mutex cdef readonly object _not_empty cdef readonly object _waiting_writers + cdef readonly list _condition_pool cdef readonly object _queue cdef readonly Py_ssize_t _blocked_bytes diff --git a/sdks/python/apache_beam/utils/byte_limited_queue.py b/sdks/python/apache_beam/utils/byte_limited_queue.py index 2edeafbeb971..a6ff669800c2 100644 --- a/sdks/python/apache_beam/utils/byte_limited_queue.py +++ b/sdks/python/apache_beam/utils/byte_limited_queue.py @@ -48,11 +48,14 @@ def __init__( """ self.max_elements = maxsize self.max_bytes = maxbytes + self._byte_size = 0 self._blocked_bytes = 0 self._mutex = threading.Lock() self._not_empty = threading.Condition(self._mutex) + self._waiting_writers = collections.deque() + self._condition_pool = [] self._queue = collections.deque() def put(self, item, item_bytes, *, block=True, timeout=None): @@ -88,8 +91,13 @@ def put(self, item, item_bytes, *, block=True, timeout=None): if not block: raise queue.Full - my_cond = threading.Condition(self._mutex) + # Reuse or create a condition + my_cond = ( + self._condition_pool.pop() + if self._condition_pool else threading.Condition(self._mutex)) + endtime = time.monotonic() + timeout if timeout is not None else None + try: self._blocked_bytes += item_bytes self._waiting_writers.append(my_cond) @@ -116,6 +124,7 @@ def put(self, item, item_bytes, *, block=True, timeout=None): self._waiting_writers.popleft() else: self._waiting_writers.remove(my_cond) + self._condition_pool.append(my_cond) if was_first and self._waiting_writers: self._waiting_writers[0].notify()