diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py index b05fdd305a60..13a0f13e617b 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared.py +++ b/sdks/python/apache_beam/utils/multi_process_shared.py @@ -38,6 +38,8 @@ import fasteners +from apache_beam.utils import retry + # In some python versions, there is a bug where AutoProxy doesn't handle # the kwarg 'manager_owned'. We implement our own backup here to make sure # we avoid this problem. More info here: @@ -391,10 +393,21 @@ def _get_manager(self): manager = _SingletonRegistrar( address=(host, int(port)), authkey=AUTH_KEY) multiprocessing.current_process().authkey = AUTH_KEY - try: + + retryable_exceptions = (ConnectionError, EOFError) + + @retry.with_exponential_backoff( + num_retries=5, + initial_delay_secs=0.1, + retry_filter=lambda exn: isinstance( + exn, retryable_exceptions)) + def connect_manager(): manager.connect() + + try: + connect_manager() self._manager = manager - except ConnectionError: + except retryable_exceptions: # The server is no longer good, assume it died. os.unlink(address_file) diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py index 3c74903b8d99..18ed49c6fa17 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared_test.py +++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py @@ -23,6 +23,7 @@ import threading import unittest from typing import Any +from unittest.mock import patch from apache_beam.utils import multi_process_shared @@ -293,7 +294,8 @@ def setUp(self): 'mix1', 'mix2', 'test_process_exit', - 'thundering_herd_test']: + 'thundering_herd_test', + 'transient_test']: for ext in ['', '.address', '.address.error']: try: os.remove(os.path.join(tempdir, tag + ext)) @@ -461,6 +463,32 @@ def test_zombie_reaping_on_acquire(self): except Exception: pass + def test_transient_connection_error_recovery(self): + shared1 = multi_process_shared.MultiProcessShared( + Counter, tag='transient_test', always_proxy=True, spawn_process=True) + shared2 = multi_process_shared.MultiProcessShared( + Counter, tag='transient_test', always_proxy=True, spawn_process=True) + + counter1 = shared1.acquire() + + orig_connect = multi_process_shared._SingletonRegistrar.connect + connect_calls = [0] + + def side_effect_connect(self_mgr, *args, **kwargs): + connect_calls[0] += 1 + if connect_calls[0] == 1: + raise ConnectionError("Simulated transient connection failure") + return orig_connect(self_mgr, *args, **kwargs) + + with patch.object(multi_process_shared._SingletonRegistrar, + 'connect', + autospec=True, + side_effect=side_effect_connect): + counter2 = shared2.acquire() + + self.assertEqual(counter1.increment(), 1) + self.assertEqual(counter2.increment(), 2) + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO)