From 5426a03d634b8eb0b641b288228acfddc95f711a Mon Sep 17 00:00:00 2001 From: claudevdm Date: Thu, 2 Apr 2026 11:39:09 -0400 Subject: [PATCH] BQ Anomaly detection: Reshuffle fanout in cdc --- .../python/BigQueryAnomalyDetection.java | 15 +--- .../src/bqmonitor/cdc.py | 80 ++++++++----------- .../src/bqmonitor/pipeline.py | 12 +-- 3 files changed, 38 insertions(+), 69 deletions(-) diff --git a/python/src/main/java/com/google/cloud/teleport/templates/python/BigQueryAnomalyDetection.java b/python/src/main/java/com/google/cloud/teleport/templates/python/BigQueryAnomalyDetection.java index c5ea489127..5f1285b13a 100644 --- a/python/src/main/java/com/google/cloud/teleport/templates/python/BigQueryAnomalyDetection.java +++ b/python/src/main/java/com/google/cloud/teleport/templates/python/BigQueryAnomalyDetection.java @@ -147,19 +147,8 @@ public interface BigQueryAnomalyDetection { regexes = {"^[a-zA-Z0-9_-]+:[a-zA-Z0-9_]+\\.[a-zA-Z0-9_]+$"}) String getSinkTable(); - @TemplateParameter.Integer( - order = 13, - optional = true, - name = "decompress_shards", - description = "Decompress Shards", - helpText = - "Number of shards for CDC Arrow batch decompression fan-out. " - + "Spreads decompression CPU across workers. " - + "0 disables fan-out (decode inline). Default: 400.") - Integer getDecompressShards(); - @TemplateParameter.Text( - order = 14, + order = 13, optional = true, name = "fanout_strategy", description = "Fanout Strategy", @@ -170,7 +159,7 @@ public interface BigQueryAnomalyDetection { String getFanoutStrategy(); @TemplateParameter.Integer( - order = 15, + order = 14, optional = true, name = "fanout", description = "Fanout Shards", diff --git a/python/src/main/python/bigquery-anomaly-detection/src/bqmonitor/cdc.py b/python/src/main/python/bigquery-anomaly-detection/src/bqmonitor/cdc.py index 0bb98127a8..4cdd4e8bc4 100644 --- a/python/src/main/python/bigquery-anomaly-detection/src/bqmonitor/cdc.py +++ b/python/src/main/python/bigquery-anomaly-detection/src/bqmonitor/cdc.py @@ -45,7 +45,6 @@ import dataclasses import datetime import logging -import random import sys import time import uuid @@ -65,8 +64,6 @@ from apache_beam.io.watermark_estimators import ManualWatermarkEstimator from apache_beam.metrics import Metrics from apache_beam.transforms.core import WatermarkEstimatorProvider -from apache_beam.transforms import trigger as beam_trigger -from apache_beam.transforms.window import GlobalWindows from apache_beam.transforms.window import TimestampedValue from apache_beam.utils import retry from apache_beam.utils.timestamp import MAX_TIMESTAMP @@ -784,10 +781,14 @@ def _split_all_streams(self, stream_names: Tuple[str, ...], rounds of doubling. """ result = list(stream_names) + no_split = set() for round_num in range(1, max_split_rounds + 1): new_result = [] made_progress = False for name in result: + if name in no_split: + new_result.append(name) + continue response = self._storage_client.split_read_stream( request=bq_storage.types.SplitReadStreamRequest( name=name, fraction=0.5)) @@ -798,6 +799,7 @@ def _split_all_streams(self, stream_names: Tuple[str, ...], made_progress = True else: new_result.append(name) + no_split.add(name) result = new_result _LOGGER.info( '[Read] _split_all_streams round %d/%d: %d streams ' @@ -1063,7 +1065,7 @@ def _read_stream_raw( class _DecompressArrowBatchesFn(beam.DoFn): """Decompress and convert raw Arrow batches to timestamped row dicts. - Receives GBK output: (shard_key, Iterable[(schema_bytes, batch_bytes)]) + Receives individual (schema_bytes, batch_bytes) tuples after Reshuffle and converts each batch to individual row dicts with event timestamps extracted from the change_timestamp column. """ @@ -1072,25 +1074,24 @@ def __init__(self, change_timestamp_column: str = 'change_timestamp') -> None: def process( self, - element: Tuple[int, Iterable[Tuple[bytes, bytes]]] + element: Tuple[bytes, bytes] ) -> Iterable[Dict[str, Any]]: - _, batches = element - for schema_bytes, batch_bytes in batches: - schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes)) - batch = pyarrow.ipc.read_record_batch( - pyarrow.py_buffer(batch_bytes), schema) - - rows = batch.to_pylist() - for row in rows: - ts = row.get(self._change_timestamp_column) - if ts is None: - raise ValueError( - 'Row missing %r column. Row keys: %s' % - (self._change_timestamp_column, list(row.keys()))) - if isinstance(ts, datetime.datetime): - ts = Timestamp.from_utc_datetime(ts) - yield TimestampedValue(row, ts) - Metrics.counter('BigQueryChangeHistory', 'rows_emitted').inc(len(rows)) + schema_bytes, batch_bytes = element + schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes)) + batch = pyarrow.ipc.read_record_batch( + pyarrow.py_buffer(batch_bytes), schema) + + rows = batch.to_pylist() + for row in rows: + ts = row.get(self._change_timestamp_column) + if ts is None: + raise ValueError( + 'Row missing %r column. Row keys: %s' % + (self._change_timestamp_column, list(row.keys()))) + if isinstance(ts, datetime.datetime): + ts = Timestamp.from_utc_datetime(ts) + yield TimestampedValue(row, ts) + Metrics.counter('BigQueryChangeHistory', 'rows_emitted').inc(len(rows)) # ============================================================================= @@ -1215,12 +1216,12 @@ class ReadBigQueryChangeHistory(beam.PTransform): 1 (one round of splitting). Set 0 to disable splitting entirely. Set higher for very large tables where more parallelism is needed. - decompress_shards: If set to a positive integer, the Read SDF - emits raw compressed Arrow batches instead of decoded rows. - The batches are reshuffled for fan-out and then decoded in a - separate DoFn. This spreads decompression and Arrow-to-Python - conversion CPU across more workers. If None (default), rows - are decoded inline within the Read SDF. + reshuffle_decompress: If True (default), the Read SDF emits raw + compressed Arrow batches instead of decoded rows. The batches + are reshuffled for fan-out and then decoded in a separate DoFn. + This spreads decompression and Arrow-to-Python conversion CPU + across more workers. Set to False to decode rows inline within + the Read SDF. """ def __init__( self, @@ -1239,7 +1240,7 @@ def __init__( row_filter: Optional[str] = None, batch_arrow_read: bool = True, max_split_rounds: int = 1, - decompress_shards: Optional[int] = None) -> None: + reshuffle_decompress: bool = True) -> None: super().__init__() if bq_storage is None: raise ImportError( @@ -1274,7 +1275,7 @@ def __init__( self._row_filter = row_filter self._batch_arrow_read = batch_arrow_read self._max_split_rounds = max_split_rounds - self._decompress_shards = decompress_shards + self._reshuffle_decompress = reshuffle_decompress def expand(self, pbegin: beam.pvalue.PBegin) -> beam.PCollection: project = self._project @@ -1354,7 +1355,7 @@ def expand(self, pbegin: beam.pvalue.PBegin) -> beam.PCollection: row_filter=self._row_filter)) | 'CommitQueryResults' >> beam.Reshuffle()) - emit_raw = self._decompress_shards is not None + emit_raw = self._reshuffle_decompress read_sdf = beam.ParDo( _ReadStorageStreamsSDF( @@ -1377,22 +1378,11 @@ def expand(self, pbegin: beam.pvalue.PBegin) -> beam.PCollection: | 'CleanupTempTables' >> beam.ParDo(_CleanupTempTablesFn())) if emit_raw: - # Fan out raw Arrow batches across decompress_shards workers - # via GBK, then decompress and convert to timestamped row dicts. - # Uses a discarding trigger so GBK fires per-element without - # waiting for the GlobalWindow to close. - num_shards = self._decompress_shards + # Reshuffle raw Arrow batches for fan-out, then decompress and + # convert to timestamped row dicts in a separate DoFn. rows = ( read_outputs['rows'] - | 'ShardBatches' >> beam.WithKeys( - lambda _, n=num_shards: random.randint(0, n - 1)) - | 'WindowForGBK' >> beam.WindowInto( - GlobalWindows(), - trigger=beam_trigger.Repeatedly( - beam_trigger.AfterCount(1)), - accumulation_mode=( - beam_trigger.AccumulationMode.DISCARDING)) - | 'GroupByShardKey' >> beam.GroupByKey() + | 'ReshuffleForFanout' >> beam.Reshuffle() | 'DecompressBatches' >> beam.ParDo( _DecompressArrowBatchesFn( change_timestamp_column=( diff --git a/python/src/main/python/bigquery-anomaly-detection/src/bqmonitor/pipeline.py b/python/src/main/python/bigquery-anomaly-detection/src/bqmonitor/pipeline.py index c81c79745c..3383a50b06 100644 --- a/python/src/main/python/bigquery-anomaly-detection/src/bqmonitor/pipeline.py +++ b/python/src/main/python/bigquery-anomaly-detection/src/bqmonitor/pipeline.py @@ -533,13 +533,6 @@ def _add_argparse_args(cls, parser): help='BigQuery table to write all anomaly detection results to. ' 'Format: project:dataset.table. If unset, results are not written ' 'to BigQuery.') - parser.add_argument( - '--decompress_shards', - type=int, - default=1200, - help='Number of shards for CDC Arrow batch decompression fan-out. ' - 'Spreads decompression CPU across workers. ' - '0 disables fan-out (decode inline). Default: 1200.') parser.add_argument( '--fanout_strategy', default='sharded', @@ -996,10 +989,7 @@ def build_pipeline(pipeline, options, metric_spec, detector): buffer_sec=options.buffer_sec, columns=columns, change_type_column=change_type_col, - change_timestamp_column=change_ts_col, - decompress_shards=( - options.decompress_shards if options.decompress_shards > 0 - else None)) + change_timestamp_column=change_ts_col) if stop_time is not None: cdc_kwargs['stop_time'] = stop_time if options.temp_dataset: