From 3d137746a842368c88c391a31bc1e1c185b5d4ec Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Sat, 9 May 2026 00:09:49 -0400 Subject: [PATCH 01/13] Add a test to reproduce hanging. --- .../apache_beam/transforms/async_dofn_test.py | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/sdks/python/apache_beam/transforms/async_dofn_test.py b/sdks/python/apache_beam/transforms/async_dofn_test.py index 81c7b8e163ff..ec276f91af13 100644 --- a/sdks/python/apache_beam/transforms/async_dofn_test.py +++ b/sdks/python/apache_beam/transforms/async_dofn_test.py @@ -16,6 +16,7 @@ # import logging +import multiprocessing import random import time import unittest @@ -487,6 +488,46 @@ def add_item(i): self.check_output(results[i], expected_outputs['key' + str(i)]) self.assertEqual(bag_states['key' + str(i)].items, []) + @staticmethod + def _run_reset_state_deadlock_scenario(use_asyncio): + if use_asyncio: + return + + dofn = BasicDofn(sleep_time=0.5) + async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=False) + async_dofn.setup() + fake_bag_state = FakeBagState([]) + fake_timer = FakeTimer(0) + + # Start processing an item. This starts a worker thread sleeping for 0.5s. + async_dofn.process(('key1', 1), to_process=fake_bag_state, timer=fake_timer) + time.sleep(0.05) + + # Attempt to call reset_state(). If the fix is NOT applied, this will deadlock + # forever because reset_state() holds the lock while calling shutdown(wait=True), + # blocking the future's done callback from acquiring the lock. + async_lib.AsyncWrapper.reset_state() + + def test_reset_state_hang_reproduction(self): + # Run the deadlock scenario in a separate process so that if it hangs, + # we can terminate it without causing the main pytest process to hang at exit. + if self.use_asyncio: + return + + p = multiprocessing.Process( + target=AsyncTest._run_reset_state_deadlock_scenario, + args=(self.use_asyncio,)) + p.start() + p.join(timeout=3.0) + + if p.is_alive(): + p.terminate() + p.join() + self.fail("reset_state() deadlocked/hung waiting for active threads to finish") + else: + self.assertEqual(p.exitcode, 0) + if __name__ == '__main__': unittest.main() + From ed0c2d29b8d7b7f614e6b6165f76d170e5c4a3e3 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Sat, 9 May 2026 00:33:18 -0400 Subject: [PATCH 02/13] Fix deadlock between shutdown in main thread and done callback in worker threads. --- sdks/python/apache_beam/transforms/async_dofn.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/transforms/async_dofn.py b/sdks/python/apache_beam/transforms/async_dofn.py index 28568bd893c5..a0248bd70c86 100644 --- a/sdks/python/apache_beam/transforms/async_dofn.py +++ b/sdks/python/apache_beam/transforms/async_dofn.py @@ -165,9 +165,14 @@ def reset_state(): if AsyncWrapper._loop_started is not None: AsyncWrapper._loop_started.clear() - for pool in AsyncWrapper._pool.values(): - pool.acquire(AsyncWrapper.initialize_pool(1)).shutdown( - wait=True, cancel_futures=True) + pools_to_shutdown = [ + pool.acquire(AsyncWrapper.initialize_pool(1)) + for pool in AsyncWrapper._pool.values() + ] + + for pool in pools_to_shutdown: + pool.shutdown(wait=True, cancel_futures=True) + with AsyncWrapper._lock: AsyncWrapper._pool = {} AsyncWrapper._processing_elements = {} @@ -268,7 +273,8 @@ async def _collect(result): def decrement_items_in_buffer(self, future): with AsyncWrapper._lock: - AsyncWrapper._items_in_buffer[self._uuid] -= 1 + if self._uuid in AsyncWrapper._items_in_buffer: + AsyncWrapper._items_in_buffer[self._uuid] -= 1 def schedule_if_room(self, element, ignore_buffer=False, *args, **kwargs): """Schedules an item to be processed asynchronously if there is room. From f737cc18a42c8873c3abff1e432736411ebdb8c2 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Sat, 9 May 2026 00:53:32 -0400 Subject: [PATCH 03/13] Address review comments. --- sdks/python/apache_beam/transforms/async_dofn.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/transforms/async_dofn.py b/sdks/python/apache_beam/transforms/async_dofn.py index a0248bd70c86..a6312340b234 100644 --- a/sdks/python/apache_beam/transforms/async_dofn.py +++ b/sdks/python/apache_beam/transforms/async_dofn.py @@ -165,10 +165,11 @@ def reset_state(): if AsyncWrapper._loop_started is not None: AsyncWrapper._loop_started.clear() - pools_to_shutdown = [ - pool.acquire(AsyncWrapper.initialize_pool(1)) - for pool in AsyncWrapper._pool.values() - ] + pools = list(AsyncWrapper._pool.values()) + + pools_to_shutdown = [ + pool.acquire(AsyncWrapper.initialize_pool(1)) for pool in pools + ] for pool in pools_to_shutdown: pool.shutdown(wait=True, cancel_futures=True) From d2fbefa99e9c9a3680fee44118eef19d6979a62b Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Sat, 9 May 2026 08:33:10 -0400 Subject: [PATCH 04/13] Fix format --- sdks/python/apache_beam/transforms/async_dofn_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/transforms/async_dofn_test.py b/sdks/python/apache_beam/transforms/async_dofn_test.py index ec276f91af13..18ecb133a8a8 100644 --- a/sdks/python/apache_beam/transforms/async_dofn_test.py +++ b/sdks/python/apache_beam/transforms/async_dofn_test.py @@ -516,14 +516,15 @@ def test_reset_state_hang_reproduction(self): p = multiprocessing.Process( target=AsyncTest._run_reset_state_deadlock_scenario, - args=(self.use_asyncio,)) + args=(self.use_asyncio, )) p.start() p.join(timeout=3.0) if p.is_alive(): p.terminate() p.join() - self.fail("reset_state() deadlocked/hung waiting for active threads to finish") + self.fail( + "reset_state() deadlocked/hung waiting for active threads to finish") else: self.assertEqual(p.exitcode, 0) From 95e3b2452768dd5636502621985ac2f2a2c79f53 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Sat, 9 May 2026 09:01:59 -0400 Subject: [PATCH 05/13] Modify the test to cover reset_state() hanging in asyncio mode. --- .../apache_beam/transforms/async_dofn_test.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/sdks/python/apache_beam/transforms/async_dofn_test.py b/sdks/python/apache_beam/transforms/async_dofn_test.py index 18ecb133a8a8..67615009629c 100644 --- a/sdks/python/apache_beam/transforms/async_dofn_test.py +++ b/sdks/python/apache_beam/transforms/async_dofn_test.py @@ -490,30 +490,24 @@ def add_item(i): @staticmethod def _run_reset_state_deadlock_scenario(use_asyncio): - if use_asyncio: - return - dofn = BasicDofn(sleep_time=0.5) - async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=False) + async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=use_asyncio) async_dofn.setup() fake_bag_state = FakeBagState([]) fake_timer = FakeTimer(0) - # Start processing an item. This starts a worker thread sleeping for 0.5s. + # Start processing an item. This starts a worker thread/coroutine sleeping for 0.5s. async_dofn.process(('key1', 1), to_process=fake_bag_state, timer=fake_timer) time.sleep(0.05) # Attempt to call reset_state(). If the fix is NOT applied, this will deadlock - # forever because reset_state() holds the lock while calling shutdown(wait=True), + # forever because reset_state() holds the lock while waiting for active tasks/threads, # blocking the future's done callback from acquiring the lock. async_lib.AsyncWrapper.reset_state() def test_reset_state_hang_reproduction(self): # Run the deadlock scenario in a separate process so that if it hangs, # we can terminate it without causing the main pytest process to hang at exit. - if self.use_asyncio: - return - p = multiprocessing.Process( target=AsyncTest._run_reset_state_deadlock_scenario, args=(self.use_asyncio, )) @@ -524,7 +518,7 @@ def test_reset_state_hang_reproduction(self): p.terminate() p.join() self.fail( - "reset_state() deadlocked/hung waiting for active threads to finish") + "reset_state() deadlocked/hung waiting for active threads/tasks to finish") else: self.assertEqual(p.exitcode, 0) From 63901b0db706ddb000f4bc85081478a1cf097125 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Sat, 9 May 2026 09:11:53 -0400 Subject: [PATCH 06/13] Fix the deadlock when asyncio is used. --- sdks/python/apache_beam/transforms/async_dofn.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/transforms/async_dofn.py b/sdks/python/apache_beam/transforms/async_dofn.py index a6312340b234..ad3d5bc66469 100644 --- a/sdks/python/apache_beam/transforms/async_dofn.py +++ b/sdks/python/apache_beam/transforms/async_dofn.py @@ -153,12 +153,13 @@ def _run_event_loop(): @staticmethod def reset_state(): + event_loop_thread_to_join = None with AsyncWrapper._lock: if AsyncWrapper._event_loop: AsyncWrapper._event_loop.call_soon_threadsafe( AsyncWrapper._event_loop.stop) if AsyncWrapper._event_loop_thread: - AsyncWrapper._event_loop_thread.join() + event_loop_thread_to_join = AsyncWrapper._event_loop_thread AsyncWrapper._event_loop = None AsyncWrapper._event_loop_thread = None @@ -167,6 +168,17 @@ def reset_state(): pools = list(AsyncWrapper._pool.values()) + # We must join the asyncio event loop thread outside of the lock block. + # If joined inside the lock, the waiting thread holds the lock while blocking, + # preventing active coroutines' done callbacks from acquiring the lock on the + # event loop thread, resulting in a deadlock. + if event_loop_thread_to_join: + event_loop_thread_to_join.join() + + # We must acquire and shut down the thread pools outside of the lock block. + # If shutdown(wait=True) is called inside the lock, the caller blocks holding + # the lock, preventing active worker threads from acquiring the lock to run + # their done callbacks, resulting in a deadlock. pools_to_shutdown = [ pool.acquire(AsyncWrapper.initialize_pool(1)) for pool in pools ] From 738c81b9946eaa0e52497541647a42978c205240 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Sat, 9 May 2026 09:14:48 -0400 Subject: [PATCH 07/13] Fix formatting. --- sdks/python/apache_beam/transforms/async_dofn_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/transforms/async_dofn_test.py b/sdks/python/apache_beam/transforms/async_dofn_test.py index 67615009629c..ddb574a6908c 100644 --- a/sdks/python/apache_beam/transforms/async_dofn_test.py +++ b/sdks/python/apache_beam/transforms/async_dofn_test.py @@ -518,11 +518,11 @@ def test_reset_state_hang_reproduction(self): p.terminate() p.join() self.fail( - "reset_state() deadlocked/hung waiting for active threads/tasks to finish") + "reset_state() deadlocked/hung waiting for active threads/tasks to finish" + ) else: self.assertEqual(p.exitcode, 0) if __name__ == '__main__': unittest.main() - From 937071297c273ac7d9a7623690e64995685f1f9e Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Sat, 9 May 2026 16:08:32 -0400 Subject: [PATCH 08/13] Increase timeout to reduce false-positives. --- sdks/python/apache_beam/transforms/async_dofn_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/transforms/async_dofn_test.py b/sdks/python/apache_beam/transforms/async_dofn_test.py index ddb574a6908c..a3117d2fea0b 100644 --- a/sdks/python/apache_beam/transforms/async_dofn_test.py +++ b/sdks/python/apache_beam/transforms/async_dofn_test.py @@ -512,7 +512,7 @@ def test_reset_state_hang_reproduction(self): target=AsyncTest._run_reset_state_deadlock_scenario, args=(self.use_asyncio, )) p.start() - p.join(timeout=3.0) + p.join(timeout=10.0) if p.is_alive(): p.terminate() From 11ffedc241c7d16ee8066e2c18ef3d60165ad90a Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Sat, 9 May 2026 17:05:43 -0400 Subject: [PATCH 09/13] Revise test function names and some comments. --- .../apache_beam/transforms/async_dofn_test.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/sdks/python/apache_beam/transforms/async_dofn_test.py b/sdks/python/apache_beam/transforms/async_dofn_test.py index a3117d2fea0b..39901d791fb9 100644 --- a/sdks/python/apache_beam/transforms/async_dofn_test.py +++ b/sdks/python/apache_beam/transforms/async_dofn_test.py @@ -489,7 +489,7 @@ def add_item(i): self.assertEqual(bag_states['key' + str(i)].items, []) @staticmethod - def _run_reset_state_deadlock_scenario(use_asyncio): + def _run_reset_state_concurrent_teardown(use_asyncio): dofn = BasicDofn(sleep_time=0.5) async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=use_asyncio) async_dofn.setup() @@ -500,16 +500,15 @@ def _run_reset_state_deadlock_scenario(use_asyncio): async_dofn.process(('key1', 1), to_process=fake_bag_state, timer=fake_timer) time.sleep(0.05) - # Attempt to call reset_state(). If the fix is NOT applied, this will deadlock - # forever because reset_state() holds the lock while waiting for active tasks/threads, - # blocking the future's done callback from acquiring the lock. + # Verify that calling reset_state() while background tasks are actively running + # completes cleanly without causing lock-ordering deadlocks. async_lib.AsyncWrapper.reset_state() - def test_reset_state_hang_reproduction(self): - # Run the deadlock scenario in a separate process so that if it hangs, - # we can terminate it without causing the main pytest process to hang at exit. + def test_reset_state_concurrent_teardown(self): + # Verify concurrent teardown safety in a separate process to prevent any potential + # regressions from freezing the main pytest process at exit. p = multiprocessing.Process( - target=AsyncTest._run_reset_state_deadlock_scenario, + target=AsyncTest._run_reset_state_concurrent_teardown, args=(self.use_asyncio, )) p.start() p.join(timeout=10.0) From 89f88ae5ce0ae019830a349be2d0e73f5a3b9d9f Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Sat, 9 May 2026 17:25:48 -0400 Subject: [PATCH 10/13] Fix dataframe warning on chained assignment --- sdks/python/apache_beam/dataframe/schemas.py | 7 ++----- .../typehints/pandas_type_compatibility.py | 14 ++++++++++---- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/sdks/python/apache_beam/dataframe/schemas.py b/sdks/python/apache_beam/dataframe/schemas.py index f849ab11e77c..67759a1b1b72 100644 --- a/sdks/python/apache_beam/dataframe/schemas.py +++ b/sdks/python/apache_beam/dataframe/schemas.py @@ -95,11 +95,8 @@ def generate_proxy(element_type: type) -> pd.DataFrame: else: fields = named_fields_from_element_type(element_type) proxy = pd.DataFrame(columns=[name for name, _ in fields]) - for name, typehint in fields: - dtype = dtype_from_typehint(typehint) - proxy[name] = proxy[name].astype(dtype) - - return proxy + dtypes = {name: dtype_from_typehint(typehint) for name, typehint in fields} + return proxy.astype(dtypes) def element_type_from_dataframe( diff --git a/sdks/python/apache_beam/typehints/pandas_type_compatibility.py b/sdks/python/apache_beam/typehints/pandas_type_compatibility.py index 45ae27baffe7..8158b4443e1a 100644 --- a/sdks/python/apache_beam/typehints/pandas_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/pandas_type_compatibility.py @@ -223,8 +223,11 @@ def _get_series(self, batch: pd.DataFrame): def produce_batch(self, elements): batch = pd.DataFrame.from_records(elements, columns=self._columns) - for column, typehint in self._element_type._fields: - batch[column] = batch[column].astype(dtype_from_typehint(typehint)) + dtypes = { + column: dtype_from_typehint(typehint) + for column, typehint in self._element_type._fields + } + batch = batch.astype(dtypes) return batch @@ -249,8 +252,11 @@ def produce_batch(self, elements): # Note from_records has an index= parameter batch = pd.DataFrame.from_records(elements, columns=self._columns) - for column, typehint in self._element_type._fields: - batch[column] = batch[column].astype(dtype_from_typehint(typehint)) + dtypes = { + column: dtype_from_typehint(typehint) + for column, typehint in self._element_type._fields + } + batch = batch.astype(dtypes) return batch.set_index(self._index_columns) From c0870fb2d94838affd61957b7052ec93e1c90353 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Sat, 9 May 2026 21:04:58 -0400 Subject: [PATCH 11/13] Fix a flaky test in ApproximateQuantilesTest --- sdks/python/apache_beam/transforms/stats_test.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/sdks/python/apache_beam/transforms/stats_test.py b/sdks/python/apache_beam/transforms/stats_test.py index bf634c003a07..b8cb43a598fe 100644 --- a/sdks/python/apache_beam/transforms/stats_test.py +++ b/sdks/python/apache_beam/transforms/stats_test.py @@ -264,6 +264,10 @@ def _approx_quantile_generator(size, num_of_quantiles, absoluteError): quantiles.append(size - 1) return quantiles + @staticmethod + def _sum_and_second(x): + return (sum(x), x[1]) + def test_quantiles_globaly(self): with TestPipeline() as p: pc = p | Create(list(range(101))) @@ -490,22 +494,27 @@ def test_batched_quantiles(self): 3, input_batched=True)) with_key = ( pc | 'Globally with key' >> beam.ApproximateQuantiles.Globally( - 3, key=sum, input_batched=True)) + 3, key=ApproximateQuantilesTest._sum_and_second, input_batched=True)) key_with_reversed = ( pc | 'Globally with key and reversed' >> beam.ApproximateQuantiles.Globally( - 3, key=sum, reverse=True, input_batched=True)) + 3, key=ApproximateQuantilesTest._sum_and_second, reverse=True, input_batched=True)) assert_that( globally, equal_to([[(0.0, 500), (49.9, 1), (99.9, 499)]]), label='checkGlobally') + # When key is present, both (72.5, 225) and (22.5, 275) produce the exact same + # sum (297.5). If we just use key=sum, tie-breaking is sensitive to bundle merging + # order and shared class-level jitter state, leading to flaky test failures. + # With the secondary key (defined in _sum_and_second), we can break ties + # deterministically. assert_that( with_key, equal_to([[(50.0, 0), (72.5, 225), (99.9, 499)]]), label='checkGloballyWithKey') assert_that( key_with_reversed, - equal_to([[(99.9, 499), (72.5, 225), (50.0, 0)]]), + equal_to([[(99.9, 499), (22.5, 275), (50.0, 0)]]), label='checkGloballyWithKeyAndReversed') def test_batched_weighted_quantiles(self): From 2fa7eb721bf4f89b61d3b24d7487f4bbcd526874 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Sat, 9 May 2026 21:21:41 -0400 Subject: [PATCH 12/13] Formatting --- sdks/python/apache_beam/transforms/stats_test.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/transforms/stats_test.py b/sdks/python/apache_beam/transforms/stats_test.py index b8cb43a598fe..b236c7e3d5ac 100644 --- a/sdks/python/apache_beam/transforms/stats_test.py +++ b/sdks/python/apache_beam/transforms/stats_test.py @@ -494,19 +494,24 @@ def test_batched_quantiles(self): 3, input_batched=True)) with_key = ( pc | 'Globally with key' >> beam.ApproximateQuantiles.Globally( - 3, key=ApproximateQuantilesTest._sum_and_second, input_batched=True)) + 3, + key=ApproximateQuantilesTest._sum_and_second, + input_batched=True)) key_with_reversed = ( pc | 'Globally with key and reversed' >> beam.ApproximateQuantiles.Globally( - 3, key=ApproximateQuantilesTest._sum_and_second, reverse=True, input_batched=True)) + 3, + key=ApproximateQuantilesTest._sum_and_second, + reverse=True, + input_batched=True)) assert_that( globally, equal_to([[(0.0, 500), (49.9, 1), (99.9, 499)]]), label='checkGlobally') - # When key is present, both (72.5, 225) and (22.5, 275) produce the exact same - # sum (297.5). If we just use key=sum, tie-breaking is sensitive to bundle merging + # When key is present, both (72.5, 225) and (22.5, 275) produce the exact same + # sum (297.5). If we just use key=sum, tie-breaking is sensitive to bundle merging # order and shared class-level jitter state, leading to flaky test failures. - # With the secondary key (defined in _sum_and_second), we can break ties + # With the secondary key (defined in _sum_and_second), we can break ties # deterministically. assert_that( with_key, From 5c0cb2bc949e94890fe7639420b01bd3d24918cb Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Sun, 10 May 2026 17:59:53 -0400 Subject: [PATCH 13/13] Propagate dynamic SDF split exceptions to prevent pipeline hang --- .../pkg/beam/runners/prism/internal/stage.go | 2 +- .../runners/portability/prism_runner_test.py | 92 +++++++++++++++++++ 2 files changed, 93 insertions(+), 1 deletion(-) diff --git a/sdks/go/pkg/beam/runners/prism/internal/stage.go b/sdks/go/pkg/beam/runners/prism/internal/stage.go index c4758984af83..9e5034b58c00 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/stage.go +++ b/sdks/go/pkg/beam/runners/prism/internal/stage.go @@ -240,7 +240,7 @@ progress: sr, err := b.Split(ctx, wk, 0.5 /* fraction of remainder */, nil /* allowed splits */) if err != nil { slog.Warn("SDK Error from split, aborting splits and failing bundle", "bundle", rb, "error", err.Error()) - if b.BundleErr != nil { + if b.BundleErr == nil { b.BundleErr = err } return b.BundleErr diff --git a/sdks/python/apache_beam/runners/portability/prism_runner_test.py b/sdks/python/apache_beam/runners/portability/prism_runner_test.py index a65f9a9960b4..4735950d77aa 100644 --- a/sdks/python/apache_beam/runners/portability/prism_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/prism_runner_test.py @@ -295,6 +295,98 @@ def test_after_count_trigger_streaming(self): ('B-3', {10, 15, 16}), ]))) + def test_sdf_split_exception(self): + from apache_beam.io.iobase import RestrictionTracker + + class SimpleTracker(RestrictionTracker): + def __init__(self, rest): + self._rest = rest + def current_restriction(self): + return self._rest + def try_claim(self, position): + return True + def check_done(self): + pass + def is_bounded(self): + return True + + class FailingSplitProvider(beam.RestrictionProvider): + def initial_restriction(self, element): + return (0, 10) + def create_tracker(self, restriction): + return SimpleTracker(restriction) + def restriction_size(self, element, restriction): + return 10 + def split_and_size(self, element, restriction): + raise RuntimeError("400 invalid split") + + class SplittableFn(beam.DoFn): + def process(self, element, restriction=beam.DoFn.RestrictionParam(FailingSplitProvider())): + yield element + + try: + with self.create_pipeline() as p: + _ = p | beam.Create([1]) | beam.ParDo(SplittableFn()) + except Exception as e: + print("\n[ACTUAL EXCEPTION RAISED IN STATIC SPLIT]:\n%s" % e) + self.assertRegex(str(e), "invalid split") + else: + self.fail("Exception not raised") + + def test_sdf_dynamic_split_exception(self): + from apache_beam.io.iobase import RestrictionTracker + from apache_beam.io.iobase import RestrictionProgress + import time + + class DynamicSplitTracker(RestrictionTracker): + def __init__(self, rest): + self._rest = rest + + def current_restriction(self): + return self._rest + + def current_progress(self): + return RestrictionProgress(fraction=0.5) + + def try_claim(self, position): + return True + + def check_done(self): + pass + + def is_bounded(self): + return True + + def try_split(self, fraction_of_remainder): + # Raised when the runner sends a dynamic runtime splitting request + raise RuntimeError("dynamic runtime split failed") + + class DynamicSplitProvider(beam.RestrictionProvider): + def initial_restriction(self, element): + return (0, 100) + + def create_tracker(self, restriction): + return DynamicSplitTracker(restriction) + + def restriction_size(self, element, restriction): + return 100 + + class SleepingSDF(beam.DoFn): + def process(self, element, restriction=beam.DoFn.RestrictionParam(DynamicSplitProvider())): + # Sleep enough to guarantee that Prism sends a dynamic split request due to slow progress + for i in range(10): + time.sleep(0.5) + yield element + i + + try: + with self.create_pipeline() as p: + _ = p | beam.Create([1]) | beam.ParDo(SleepingSDF()) + except Exception as e: + print("\n[ACTUAL EXCEPTION RAISED IN DYNAMIC SPLIT]:\n%s" % e) + self.assertRegex(str(e), "dynamic runtime split failed") + else: + self.fail("Exception not raised") + class PrismJobServerTest(unittest.TestCase): def setUp(self) -> None: