diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py
index a78905e72..fa6a568bc 100644
--- a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py
+++ b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py
@@ -45,6 +45,7 @@ def __init__(
         slice_logger: SliceLogger,
         message_repository: MessageRepository,
         partition_reader: PartitionReader,
+        max_concurrent_partition_generators: Optional[int] = None,
     ):
         """
         This class is responsible for handling items from a concurrent stream read process.
@@ -55,6 +56,12 @@ def __init__(
         :param slice_logger: SliceLogger instance
         :param message_repository: MessageRepository instance
         :param partition_reader: PartitionReader instance
+        :param max_concurrent_partition_generators: Maximum number of partition generators allowed
+            to run concurrently. None means no limit. When set, should be less than the number of
+            workers in multi-worker mode so at least one worker slot is always available for
+            partition reading, preventing thread pool starvation. In single-threaded mode
+            (num_workers=1) the value may equal num_workers; ConcurrentSource.create() handles
+            this distinction. ConcurrentSource.read() passes this value explicitly.
         """
         self._stream_name_to_instance = {s.name: s for s in stream_instances_to_read_from}
         self._record_counter = {}
@@ -62,8 +69,16 @@ def __init__(
         for stream in stream_instances_to_read_from:
             self._streams_to_running_partitions[stream.name] = set()
             self._record_counter[stream.name] = 0
+        if (
+            max_concurrent_partition_generators is not None
+            and max_concurrent_partition_generators < 1
+        ):
+            raise ValueError(
+                f"max_concurrent_partition_generators must be >= 1 or None, got {max_concurrent_partition_generators}"
+            )
         self._thread_pool_manager = thread_pool_manager
         self._partition_enqueuer = partition_enqueuer
+        self._max_concurrent_partition_generators = max_concurrent_partition_generators
         self._stream_instances_to_start_partition_generation = stream_instances_to_read_from
         self._streams_currently_generating_partitions: List[str] = []
         self._logger = logger
@@ -255,6 +270,20 @@ def start_next_partition_generator(self) -> Optional[AirbyteMessage]:
         if not self._stream_instances_to_start_partition_generation:
             return None
 
+        # Enforce the concurrent generator cap so at least one worker slot is always available
+        # for partition reading. Recovery is guaranteed: on_partition_generation_completed
+        # decrements the count before calling here, so the guard always passes there.
+        if (
+            self._max_concurrent_partition_generators is not None
+            and len(self._streams_currently_generating_partitions)
+            >= self._max_concurrent_partition_generators
+        ):
+            self._logger.debug(
+                f"Concurrent partition generator cap ({self._max_concurrent_partition_generators}) reached "
+                f"({len(self._streams_currently_generating_partitions)} active). Deferring next generator start."
+            )
+            return None
+
         # Remember initial queue size to avoid infinite loops if all streams are blocked
         max_attempts = len(self._stream_instances_to_start_partition_generation)
         attempts = 0
diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_source.py b/airbyte_cdk/sources/concurrent_source/concurrent_source.py
index de2d93523..474780bcc 100644
--- a/airbyte_cdk/sources/concurrent_source/concurrent_source.py
+++ b/airbyte_cdk/sources/concurrent_source/concurrent_source.py
@@ -47,6 +47,10 @@ def create(
         queue: Optional[Queue[QueueItem]] = None,
         timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
     ) -> "ConcurrentSource":
+        if initial_number_of_partitions_to_generate < 1:
+            raise ValueError(
+                f"initial_number_of_partitions_to_generate must be >= 1, got {initial_number_of_partitions_to_generate}"
+            )
         is_single_threaded = initial_number_of_partitions_to_generate == 1 and num_workers == 1
         too_many_generator = (
             not is_single_threaded and initial_number_of_partitions_to_generate >= num_workers
@@ -117,6 +121,7 @@ def read(
                 self._queue,
                 PartitionLogger(self._slice_logger, self._logger, self._message_repository),
             ),
+            max_concurrent_partition_generators=self._initial_number_partitions_to_generate,
         )
 
         # Enqueue initial partition generation tasks
diff --git a/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py b/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py
index 910111a05..856bd1156 100644
--- a/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py
+++ b/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py
@@ -851,6 +851,71 @@ def test_start_next_partition_generator(self):
             self._partition_enqueuer.generate_partitions, self._stream
         )
 
+    def test_invalid_max_concurrent_partition_generators_raises(self):
+        for invalid in (0, -1):
+            with self.assertRaises(ValueError):
+                ConcurrentReadProcessor(
+                    [self._stream],
+                    self._partition_enqueuer,
+                    self._thread_pool_manager,
+                    self._logger,
+                    self._slice_logger,
+                    self._message_repository,
+                    self._partition_reader,
+                    max_concurrent_partition_generators=invalid,
+                )
+
+    def test_start_next_partition_generator_respects_concurrent_limit(self):
+        stream_instances_to_read_from = [self._stream]
+        handler = ConcurrentReadProcessor(
+            stream_instances_to_read_from,
+            self._partition_enqueuer,
+            self._thread_pool_manager,
+            self._logger,
+            self._slice_logger,
+            self._message_repository,
+            self._partition_reader,
+            max_concurrent_partition_generators=1,
+        )
+        handler._streams_currently_generating_partitions.append(_STREAM_NAME)
+
+        status_message = handler.start_next_partition_generator()
+
+        assert status_message is None
+        assert (
+            handler._stream_instances_to_start_partition_generation == stream_instances_to_read_from
+        )
+        self._thread_pool_manager.submit.assert_not_called()
+
+    def test_start_next_partition_generator_starts_when_below_limit(self):
+        other_stream = Mock(spec=AbstractStream)
+        other_stream.name = "other_stream"
+        other_stream.block_simultaneous_read = ""
+        other_stream.as_airbyte_stream.return_value = AirbyteStream(
+            name="other_stream",
+            json_schema={},
+            supported_sync_modes=[SyncMode.full_refresh],
+        )
+        handler = ConcurrentReadProcessor(
+            [other_stream],
+            self._partition_enqueuer,
+            self._thread_pool_manager,
+            self._logger,
+            self._slice_logger,
+            self._message_repository,
+            self._partition_reader,
+            max_concurrent_partition_generators=2,
+        )
+        handler._streams_currently_generating_partitions.append(_STREAM_NAME)
+
+        status_message = handler.start_next_partition_generator()
+
+        assert status_message is not None
+        assert "other_stream" in handler._streams_currently_generating_partitions
+        self._thread_pool_manager.submit.assert_called_with(
+            self._partition_enqueuer.generate_partitions, other_stream
+        )
+
 class TestBlockSimultaneousRead(unittest.TestCase):
     """Tests for block_simultaneous_read functionality"""