From c03c3d59940b870fe36e5bf5ee7ea5742829b317 Mon Sep 17 00:00:00 2001 From: prabhnoor0212 Date: Wed, 1 Apr 2026 21:08:05 -0400 Subject: [PATCH 1/2] Fix BigQuery enrichment batch handling for duplicate keys --- .../enrichment_handlers/bigquery.py | 18 ++++--- .../enrichment_handlers/bigquery_test.py | 54 +++++++++++++++++++ 2 files changed, 65 insertions(+), 7 deletions(-) diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py index 115c5320767e..4fb8aae86f02 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py @@ -15,6 +15,7 @@ # limitations under the License. # import logging +from collections import defaultdict from collections.abc import Callable from collections.abc import Mapping from typing import Any @@ -189,7 +190,7 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs): if isinstance(request, list): values = [] responses = [] - requests_map: dict[Any, Any] = {} + requests_map: dict[Any, list[beam.Row]] = defaultdict(list) batch_size = len(request) raw_query = self.query_template if batch_size > 1: @@ -208,25 +209,28 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs): "Make sure the values passed in `fields` are the " "keys in the input `beam.Row`." + str(e)) values.extend(current_values) - requests_map[self.create_row_key(req)] = req + requests_map[self.create_row_key(req)].append(req) query = raw_query.format(*values) responses_dict = self._execute_query(query) - unmatched_requests = requests_map.copy() + unmatched_requests = { + key: list(reqs) for key, reqs in requests_map.items() + } if responses_dict: for response in responses_dict: response_row = beam.Row(**response) response_key = self.create_row_key(response_row) if response_key in unmatched_requests: - req = unmatched_requests.pop(response_key) - responses.append((req, response_row)) + for req in unmatched_requests.pop(response_key): + responses.append((req, response_row)) if unmatched_requests: if self.throw_exception_on_empty_results: raise ValueError(f"no matching row found for query: {query}") else: _LOGGER.warning('no matching row found for query: %s', query) - for req in unmatched_requests.values(): - responses.append((req, beam.Row())) + for reqs in unmatched_requests.values(): + for req in reqs: + responses.append((req, beam.Row())) return responses else: request_dict = request._asdict() diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py index 98ac6244910c..67837dbb1145 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py @@ -15,9 +15,12 @@ # limitations under the License. # import unittest +from unittest import mock from parameterized import parameterized +import apache_beam as beam + # pylint: disable=ungrouped-imports try: from apache_beam.transforms.enrichment_handlers.bigquery import BigQueryEnrichmentHandler @@ -65,6 +68,57 @@ def test_valid_params( max_batch_size=max_batch_size, ) + def test_batch_mode_fans_out_response_for_duplicate_keys(self): + handler = BigQueryEnrichmentHandler( + project=self.project, + table_name='project.dataset.table', + row_restriction_template="id='{}'", + fields=['id'], + min_batch_size=2, + max_batch_size=2, + ) + requests = [ + beam.Row(id='1', name='first'), + beam.Row(id='1', name='second') + ] + + with mock.patch.object( + handler, + '_execute_query', + return_value=[{'id': '1', 'value': 'enriched'}]): + responses = handler(requests) + + self.assertEqual( + responses, + [ + (requests[0], beam.Row(id='1', value='enriched')), + (requests[1], beam.Row(id='1', value='enriched')), + ], + ) + + def test_batch_mode_emits_empty_rows_for_all_unmatched_duplicate_keys(self): + handler = BigQueryEnrichmentHandler( + project=self.project, + table_name='project.dataset.table', + row_restriction_template="id='{}'", + fields=['id'], + min_batch_size=2, + max_batch_size=2, + throw_exception_on_empty_results=False, + ) + requests = [ + beam.Row(id='1', name='first'), + beam.Row(id='1', name='second') + ] + + with mock.patch.object(handler, '_execute_query', return_value=None): + responses = handler(requests) + + self.assertEqual( + responses, + [(requests[0], beam.Row()), (requests[1], beam.Row())], + ) + if __name__ == '__main__': unittest.main() From b5a686ef6b920297d26ab53c7079be36c8fce2d2 Mon Sep 17 00:00:00 2001 From: prabhnoor0212 Date: Wed, 1 Apr 2026 21:38:54 -0400 Subject: [PATCH 2/2] Apply yapf formatting for BigQuery enrichment changes --- .../transforms/enrichment_handlers/bigquery.py | 3 ++- .../enrichment_handlers/bigquery_test.py | 17 +++++------------ 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py index 4fb8aae86f02..2306d7d97a57 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py @@ -214,7 +214,8 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs): responses_dict = self._execute_query(query) unmatched_requests = { - key: list(reqs) for key, reqs in requests_map.items() + key: list(reqs) + for key, reqs in requests_map.items() } if responses_dict: for response in responses_dict: diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py index 67837dbb1145..98508baf6619 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py @@ -77,15 +77,11 @@ def test_batch_mode_fans_out_response_for_duplicate_keys(self): min_batch_size=2, max_batch_size=2, ) - requests = [ - beam.Row(id='1', name='first'), - beam.Row(id='1', name='second') - ] + requests = [beam.Row(id='1', name='first'), beam.Row(id='1', name='second')] - with mock.patch.object( - handler, - '_execute_query', - return_value=[{'id': '1', 'value': 'enriched'}]): + with mock.patch.object(handler, + '_execute_query', + return_value=[{'id': '1', 'value': 'enriched'}]): responses = handler(requests) self.assertEqual( @@ -106,10 +102,7 @@ def test_batch_mode_emits_empty_rows_for_all_unmatched_duplicate_keys(self): max_batch_size=2, throw_exception_on_empty_results=False, ) - requests = [ - beam.Row(id='1', name='first'), - beam.Row(id='1', name='second') - ] + requests = [beam.Row(id='1', name='first'), beam.Row(id='1', name='second')] with mock.patch.object(handler, '_execute_query', return_value=None): responses = handler(requests)