From 0f5b471747e5fd76235af3027da3d6a93f4abade Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 27 Feb 2026 12:30:55 -0500 Subject: [PATCH 01/14] Initial BQ CDC --- .../io/gcp/bigquery_change_history.py | 1186 +++++++++++++++++ .../io/gcp/bigquery_change_history_it_test.py | 525 ++++++++ .../io/gcp/bigquery_change_history_test.py | 212 +++ 3 files changed, 1923 insertions(+) create mode 100644 sdks/python/apache_beam/io/gcp/bigquery_change_history.py create mode 100644 sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py create mode 100644 sdks/python/apache_beam/io/gcp/bigquery_change_history_test.py diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py new file mode 100644 index 000000000000..de8f17dbd92a --- /dev/null +++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py @@ -0,0 +1,1186 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Streaming source for BigQuery change history (APPENDS/CHANGES functions). + +This module provides ``ReadBigQueryChangeHistory``, a streaming PTransform +that continuously polls BigQuery APPENDS() or CHANGES() functions and emits +changed rows as an unbounded PCollection. 
+ +**Status: Experimental**: API may change without notice. + +Usage:: + + import apache_beam as beam + from apache_beam.io.gcp.bigquery_change_history import ReadBigQueryChangeHistory + + with beam.Pipeline(options=pipeline_options) as p: + changes = ( + p + | ReadBigQueryChangeHistory( + table='my-project:my_dataset.my_table', + change_function='APPENDS', + poll_interval_sec=60)) + +Architecture: + Poll: Polling SDF emits lightweight _QueryRange instructions. + Query: _ExecuteQueryFn runs the BQ query, writes to a temp table. + Read: SDF reads temp table via Storage Read API with dynamic splitting. + Cleanup: Stateful DoFn tracks stream completion, deletes temp tables. +""" + +import dataclasses +import datetime +import logging +import sys +import time +import uuid +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import apache_beam as beam +from apache_beam.io.gcp import bigquery_tools +from apache_beam.io.gcp.internal.clients import bigquery +from apache_beam.io.restriction_trackers import OffsetRange +from apache_beam.io.restriction_trackers import OffsetRestrictionTracker +from apache_beam.io.iobase import WatermarkEstimator +from apache_beam.io.watermark_estimators import ManualWatermarkEstimator +from apache_beam.metrics import Metrics +from apache_beam.transforms.core import WatermarkEstimatorProvider +from apache_beam.transforms.window import TimestampedValue +from apache_beam.utils.timestamp import MAX_TIMESTAMP +from apache_beam.utils.timestamp import Timestamp + +try: + from google.cloud import bigquery_storage_v1 as bq_storage +except ImportError: + bq_storage = None # type: ignore + +try: + import pyarrow +except ImportError: + pyarrow = None # type: ignore + +_LOGGER = logging.getLogger(__name__) + +__all__ = ['ReadBigQueryChangeHistory'] + +# Max time range for CHANGES() queries: 1 day in seconds. 
+_MAX_CHANGES_RANGE_SEC = 86400 + +# Side output tag for cleanup signals between the Read SDF and Cleanup DoFn. +_CLEANUP_TAG = 'cleanup' + +# Default number of Storage Read API streams to request. +# Matches ReadFromBigQuery's MIN_SPLIT_COUNT to enable parallelism. +# The server may return fewer streams if the table is small. +_DEFAULT_MAX_STREAMS = 10 + +# Default table expiration for auto-created temp datasets: 24 hours in ms. +# Tables created in the dataset auto-expire after this duration if not +# explicitly deleted, acting as a safety net for orphaned temp tables +# (e.g. pipeline crash before cleanup runs). +_DEFAULT_TABLE_EXPIRATION_MS = 24 * 60 * 60 * 1000 + +# ============================================================================= +# Helpers and data classes +# ============================================================================= + + +@dataclasses.dataclass +class _QueryResult: + """Bridges the Query step (query execution) to the Read SDF. + + After _ExecuteQueryFn runs a CHANGES/APPENDS query, it emits a _QueryResult + pointing to the temp table containing query results. The Read SDF reads + rows from that temp table via the Storage Read API. + + range_start/range_end define the time window this query covers. + The Read SDF uses range_start to set an initial watermark hold so the runner + doesn't advance the watermark past the data's timestamps. + """ + temp_table_ref: Optional[bigquery.TableReference] = None + range_start: float = 0.0 + range_end: float = 0.0 + + +@dataclasses.dataclass +class _PollConfig: + """Input element for the polling SDF. + + Only contains start_time, which _PollWatermarkEstimatorProvider uses + to initialize the watermark hold. All other config is passed via + _PollChangeHistoryFn.__init__. + """ + start_time: float + + +@dataclasses.dataclass +class _QueryRange: + """Lightweight instruction emitted by the polling SDF. + + Contains only the time range to query. Static config (table, project, + etc.) 
class _StreamRestriction:
  """Restriction that bundles BQ Storage stream names with an offset range.

  A bare OffsetRange(0, N) would only make sense on the worker that created
  the read session. This restriction instead carries the stream name strings
  themselves, so any split can be processed on any worker. Offset arithmetic
  is delegated to an embedded OffsetRange.
  """
  __slots__ = ('stream_names', 'range')

  def __init__(
      self, stream_names: Tuple[str, ...], start: int, stop: int) -> None:
    # Tuple of BQ stream name strings, indexed by offset.
    self.stream_names = stream_names
    self.range = OffsetRange(start, stop)

  @property
  def start(self) -> int:
    """First stream index covered by this restriction."""
    return self.range.start

  @property
  def stop(self) -> int:
    """Index one past the last stream covered by this restriction."""
    return self.range.stop

  def size(self) -> int:
    """Size of the embedded offset range."""
    return self.range.size()

  def __eq__(self, other: object) -> bool:
    if isinstance(other, _StreamRestriction):
      return (self.stream_names, self.range) == (
          other.stream_names, other.range)
    return False

  def __hash__(self) -> int:
    return hash((type(self), self.stream_names, self.range))

  def __repr__(self) -> str:
    details = (len(self.stream_names), self.start, self.stop)
    return '_StreamRestriction(streams=%d, start=%d, stop=%d)' % details


class _StreamRestrictionTracker(beam.io.iobase.RestrictionTracker):
  """Tracker for _StreamRestriction.

  All claim/split/progress logic is forwarded to an OffsetRestrictionTracker
  built on the restriction's offset range; results are re-wrapped together
  with the stream names.
  """
  def __init__(self, restriction: _StreamRestriction) -> None:
    self._stream_names = restriction.stream_names
    self._offset_tracker = OffsetRestrictionTracker(restriction.range)

  def _wrap(self, offsets: OffsetRange) -> _StreamRestriction:
    """Attach this tracker's stream names to a plain offset range."""
    return _StreamRestriction(self._stream_names, offsets.start, offsets.stop)

  def current_restriction(self) -> _StreamRestriction:
    return self._wrap(self._offset_tracker.current_restriction())

  def try_claim(self, position: int) -> bool:
    return self._offset_tracker.try_claim(position)

  def try_split(
      self, fraction_of_remainder: float
  ) -> Optional[Tuple[_StreamRestriction, _StreamRestriction]]:
    split = self._offset_tracker.try_split(fraction_of_remainder)
    if split is None:
      return None
    primary, residual = split
    return (self._wrap(primary), self._wrap(residual))

  def check_done(self) -> None:
    self._offset_tracker.check_done()

  def current_progress(self):
    return self._offset_tracker.current_progress()

  def is_bounded(self) -> bool:
    return True
class _NonSplittableOffsetTracker(OffsetRestrictionTracker):
  """OffsetRestrictionTracker that can checkpoint but never splits.

  defer_remainder() relies on a checkpoint, i.e. a split with fraction 0, so
  that case is passed through. Every other fraction is rejected, which keeps
  the polling SDF running as a single instance.
  """
  def try_split(
      self, fraction_of_remainder: float
  ) -> Optional[Tuple[OffsetRange, OffsetRange]]:
    if fraction_of_remainder != 0:
      return None
    return super().try_split(fraction_of_remainder)


class _PollWatermarkEstimator(WatermarkEstimator):
  """Watermark estimator carrying a watermark hold plus a poll cursor.

  current_watermark() reports the hold, which the polling SDF sets to the
  start_ts of the current poll, the earliest data timestamp that poll can
  emit. Holding there keeps downstream stages from treating data as late.

  The poll cursor (last_end_ts) records where the next poll should begin.
  It is stored apart from the watermark so the hold can stay back at
  start_ts while the cursor moves on to end_ts.

  The estimator state is the pair (watermark_hold, last_end_ts), so both
  values are checkpointed and survive an SDF re-dispatch.
  """
  def __init__(self, state: Tuple[Timestamp, float]) -> None:
    # state layout: (watermark_hold: Timestamp, last_end_ts: float)
    hold, cursor = state
    self._watermark_hold = hold
    self._last_end_ts = cursor

  def observe_timestamp(self, timestamp: Timestamp) -> None:
    # Output timestamps are ignored; the hold only moves via set_watermark().
    pass

  def current_watermark(self) -> Timestamp:
    return self._watermark_hold

  def get_estimator_state(self) -> Tuple[Timestamp, float]:
    return (self._watermark_hold, self._last_end_ts)

  def set_watermark(self, timestamp: Timestamp) -> None:
    """Move the hold to ``timestamp``.

    Raises:
      ValueError: if ``timestamp`` is not a Timestamp, or lies before the
        current hold.
    """
    if not isinstance(timestamp, Timestamp):
      raise ValueError('set_watermark expects a Timestamp as input')
    current = self._watermark_hold
    if current and current > timestamp:
      raise ValueError(
          'Watermark must be monotonically increasing. '
          'Provided %s < current %s' % (timestamp, current))
    self._watermark_hold = timestamp

  def advance_poll_cursor(self, end_ts: float) -> None:
    """Remember end_ts as the starting point of the next poll."""
    self._last_end_ts = end_ts

  def poll_cursor(self) -> float:
    """Return the start_ts to use for the next poll."""
    return self._last_end_ts
class _PollWatermarkEstimatorProvider(WatermarkEstimatorProvider):
  """Builds _PollWatermarkEstimator instances.

  The initial state places both the watermark hold and the poll cursor at
  the element's start_time, so the first poll queries from start_time.
  """
  def initial_estimator_state(
      self, element: _PollConfig,
      restriction: OffsetRange) -> Tuple[Timestamp, float]:
    begin = element.start_time
    return (Timestamp(begin), begin)

  def create_watermark_estimator(
      self, estimator_state: Tuple[Timestamp,
                                   float]) -> _PollWatermarkEstimator:
    return _PollWatermarkEstimator(estimator_state)
def _table_key(table_ref: 'bigquery.TableReference') -> str:
  """Convert a TableReference to a 'project.dataset.table' string."""
  return f'{table_ref.projectId}.{table_ref.datasetId}.{table_ref.tableId}'


def build_changes_query(
    table: str,
    start_ts: float,
    end_ts: float,
    change_function: str,
    change_type_column: str = 'change_type',
    change_timestamp_column: str = 'change_timestamp',
    columns: Optional[List[str]] = None,
    row_filter: Optional[str] = None) -> str:
  """Build a CHANGES() or APPENDS() SQL query.

  Args:
    table: Table name as 'project.dataset.table' or 'project:dataset.table'.
    start_ts: Start timestamp (float, seconds since epoch). Inclusive.
    end_ts: End timestamp (float, seconds since epoch). Exclusive.
    change_function: 'CHANGES' or 'APPENDS' (matched case-insensitively; the
      text is emitted exactly as given).
    change_type_column: Output column name for _CHANGE_TYPE pseudo-column.
    change_timestamp_column: Output column name for _CHANGE_TIMESTAMP
      pseudo-column.
    columns: Optional list of column names to select. If None, selects all
      columns. An empty list selects only the two pseudo-columns.
      Pseudo-columns are always appended regardless.
    row_filter: Optional SQL WHERE clause (without the WHERE keyword).
      Applied after the CHANGES/APPENDS function. It is inserted verbatim,
      so it must come from a trusted source.

  Returns:
    SQL string.

  Raises:
    ValueError: if change_function is neither CHANGES nor APPENDS, or if
      table contains a backtick.
  """
  # change_function and table are spliced into the SQL text, so refuse
  # anything that could alter the shape of the statement.
  if str(change_function).upper() not in ('CHANGES', 'APPENDS'):
    raise ValueError(
        "change_function must be 'CHANGES' or 'APPENDS', got %r" %
        (change_function, ))
  if '`' in table:
    raise ValueError('table must not contain a backtick: %r' % (table, ))

  # Normalize 'project:dataset.table' to 'project.dataset.table'
  table = table.replace(':', '.')

  def _timestamp_literal(ts: float) -> str:
    moment = datetime.datetime.fromtimestamp(ts, tz=datetime.timezone.utc)
    return moment.strftime('%Y-%m-%dT%H:%M:%S.%fZ')

  # Pseudo-columns (_CHANGE_TYPE, _CHANGE_TIMESTAMP) can't be written to
  # destination tables with their original names. Rename them so they can
  # be persisted to the temp table for Storage Read API reading.
  pseudo = (
      f"_CHANGE_TYPE AS {change_type_column}, "
      f"_CHANGE_TIMESTAMP AS {change_timestamp_column}")
  if columns is None:
    select = f"SELECT * EXCEPT(_CHANGE_TYPE, _CHANGE_TIMESTAMP), {pseudo}"
  elif columns:
    select = f"SELECT {', '.join(columns)}, {pseudo}"
  else:
    # An empty list used to render as "SELECT , ..." which is invalid SQL.
    select = f"SELECT {pseudo}"
  from_clause = (
      f"FROM {change_function}"
      f"(TABLE `{table}`, "
      f"TIMESTAMP '{_timestamp_literal(start_ts)}', "
      f"TIMESTAMP '{_timestamp_literal(end_ts)}')")
  where = f" WHERE {row_filter}" if row_filter else ""
  return f"{select} {from_clause}{where}"
def compute_ranges(
    start_ts: float,
    end_ts: float,
    change_function: str,
    max_range_sec: Optional[float] = None) -> List[Tuple[float, float]]:
  """Split [start_ts, end_ts) into query-safe chunks.

  CHANGES() has a max 1-day range. APPENDS() has no limit.

  Args:
    start_ts: Start timestamp (float, seconds since epoch).
    end_ts: End timestamp (float, seconds since epoch).
    change_function: 'CHANGES' or 'APPENDS'.
    max_range_sec: Widest chunk to produce for CHANGES, in seconds. Defaults
      to _MAX_CHANGES_RANGE_SEC. Ignored for any other change_function.

  Returns:
    List of (start, end) float tuples. Empty if end_ts <= start_ts.

  Raises:
    ValueError: if chunking is needed and max_range_sec is not positive.
  """
  if end_ts <= start_ts:
    return []

  if change_function != 'CHANGES':
    return [(start_ts, end_ts)]

  # Resolve the default lazily so callers that pass a width explicitly do not
  # depend on the module constant.
  if max_range_sec is None:
    max_range_sec = _MAX_CHANGES_RANGE_SEC
  if max_range_sec <= 0:
    raise ValueError(
        'max_range_sec must be positive, got %r' % (max_range_sec, ))

  # CHANGES: chunk into ranges no wider than max_range_sec.
  ranges = []
  current = start_ts
  while current < end_ts:
    chunk_end = min(current + max_range_sec, end_ts)
    if chunk_end <= current:
      # The step was lost to float rounding; close the range in one chunk
      # rather than loop forever.
      chunk_end = end_ts
    ranges.append((current, chunk_end))
    current = chunk_end
  return ranges
def _utc(ts: Union[float, Timestamp]) -> str:
  """Render epoch seconds (or a Timestamp) as a UTC string for log output.

  The result has the form ``YYYY-MM-DDTHH:MM:SS``.
  """
  seconds = ts.seconds() if isinstance(ts, Timestamp) else ts
  moment = datetime.datetime.fromtimestamp(seconds, tz=datetime.timezone.utc)
  return moment.strftime('%Y-%m-%dT%H:%M:%S')


# =============================================================================
# Poll: _PollChangeHistoryFn (Polling SDF)
# =============================================================================


class _PollChangeHistoryFn(beam.DoFn, beam.transforms.core.RestrictionProvider):
  """SDF that emits _QueryRange instructions on a polling schedule.

  Poll timing is driven by defer_remainder(); the watermark is driven by
  _PollWatermarkEstimator. The watermark starts out held at start_time and
  is moved to the start_ts of each poll.

  One poll does the following:
    1. start_ts = poll cursor (the previous end_ts, or start_time at first)
    2. end_ts = now - buffer_sec
    3. Splits [start_ts, end_ts) into chunks and yields a _QueryRange each
    4. Moves the poll cursor to end_ts (the next poll's start)
    5. Moves the watermark to start_ts (earliest data in this poll)
    6. Defers until the next poll interval
  """
  def __init__(
      self,
      table: str,
      project: str,
      change_function: str,
      buffer_sec: float,
      start_time: float,
      stop_time: Union[float, Timestamp],
      poll_interval_sec: float,
      location: Optional[str] = None) -> None:
    # Source table and job placement.
    self._table = table
    self._project = project
    self._location = location
    # Which change-history function to query, and over which time span.
    self._change_function = change_function
    self._start_time = start_time
    self._stop_time = stop_time
    # Poll pacing.
    self._buffer_sec = buffer_sec
    self._poll_interval_sec = poll_interval_sec

  def initial_restriction(self, element: _PollConfig) -> OffsetRange:
    return OffsetRange(0, sys.maxsize)

  def create_tracker(
      self, restriction: OffsetRange) -> _NonSplittableOffsetTracker:
    if time.time() < self._stop_time:
      return _NonSplittableOffsetTracker(restriction)
    # stop_time has passed: hand back an empty range so try_claim() fails
    # right away and check_done() is satisfied.
    _LOGGER.info(
        '[Poll] create_tracker: stop_time reached, '
        'returning empty range to terminate SDF')
    empty = OffsetRange(restriction.start, restriction.start)
    return _NonSplittableOffsetTracker(empty)

  def restriction_size(
      self, element: _PollConfig, restriction: OffsetRange) -> int:
    return 1

  def split(self, element: _PollConfig,
            restriction: OffsetRange) -> Iterable[OffsetRange]:
    # Never pre-split: the restriction is passed through whole.
    yield restriction

  def truncate(self, element: _PollConfig, restriction: OffsetRange) -> None:
    return None

  def _next_poll_time(self, start_ts: float, now: float) -> Optional[Timestamp]:
    """Return the Timestamp to defer to, or None when a poll is due now."""
    earliest = start_ts + self._buffer_sec + self._poll_interval_sec
    return Timestamp.of(earliest) if now < earliest else None

  def _emit_query_ranges(
      self,
      start_ts: float,
      end_ts: float,
      now: float,
      watermark_estimator: _PollWatermarkEstimator) -> Iterable[_QueryRange]:
    """Yield one _QueryRange per chunk and update the estimator state."""
    stop = self._stop_time
    if stop != MAX_TIMESTAMP and now >= stop:
      _LOGGER.info('[Poll] Stop time reached')
      return

    chunks = compute_ranges(start_ts, end_ts, self._change_function)
    _LOGGER.info(
        '[Poll] %d chunks for [%s, %s)',
        len(chunks),
        _utc(start_ts),
        _utc(end_ts))
    Metrics.counter('BigQueryChangeHistory', 'polls').inc()

    # Cursor first, then the hold: the next poll resumes at end_ts while the
    # watermark stays back at the earliest timestamp this poll can emit.
    watermark_estimator.advance_poll_cursor(end_ts)
    watermark_estimator.set_watermark(Timestamp(start_ts))
    _LOGGER.info(
        '[Poll] Watermark=%s (start_ts), cursor=%s (end_ts)',
        _utc(start_ts),
        _utc(end_ts))

    for lower, upper in chunks:
      instruction = _QueryRange(chunk_start=lower, chunk_end=upper)
      yield TimestampedValue(instruction, Timestamp(start_ts))

  @beam.DoFn.unbounded_per_element()
  def process(
      self,
      _: _PollConfig,
      restriction_tracker=beam.DoFn.RestrictionParam(),
      watermark_estimator=beam.DoFn.WatermarkEstimatorParam(
          _PollWatermarkEstimatorProvider())
  ) -> Iterable[_QueryRange]:

    now = time.time()
    start_ts = watermark_estimator.poll_cursor()
    end_ts = now - self._buffer_sec

    # Too early for the next poll: reschedule without claiming anything.
    resume_at = self._next_poll_time(start_ts, now)
    if resume_at is not None:
      restriction_tracker.defer_remainder(resume_at)
      return

    _LOGGER.info(
        '[Poll] Polling: start_ts=%s, end_ts=%s, watermark=%s',
        _utc(start_ts),
        _utc(end_ts),
        _utc(watermark_estimator.current_watermark()))

    position = restriction_tracker.current_restriction().start
    if not restriction_tracker.try_claim(position):
      return

    # Schedule the following poll before emitting this poll's instructions.
    next_poll = Timestamp.of(now + self._poll_interval_sec)
    restriction_tracker.defer_remainder(next_poll)

    yield from self._emit_query_ranges(
        start_ts, end_ts, now, watermark_estimator)
class _ExecuteQueryFn(beam.DoFn):
  """Executes a BQ CHANGES/APPENDS query from a _QueryRange instruction.

  Each input _QueryRange is turned into one query job whose result is
  written to a freshly named temp table; a _QueryResult pointing at that
  table is emitted once the job has finished.
  """
  def __init__(
      self,
      table: str,
      project: str,
      change_function: str,
      temp_dataset: str,
      location: Optional[str],
      change_type_column: str = 'change_type',
      change_timestamp_column: str = 'change_timestamp',
      columns: Optional[List[str]] = None,
      row_filter: Optional[str] = None) -> None:
    self._table = table
    self._project = project
    self._change_function = change_function
    self._temp_dataset = temp_dataset
    self._location = location
    self._change_type_column = change_type_column
    self._change_timestamp_column = change_timestamp_column
    self._columns = columns
    self._row_filter = row_filter

  def setup(self) -> None:
    """Create the BigQuery wrapper, resolve location, ensure temp dataset."""
    self._bq_wrapper = bigquery_tools.BigQueryWrapper()
    if self._location is None:
      table_ref = bigquery_tools.parse_table_reference(
          self._table, project=self._project)
      self._location = self._bq_wrapper.get_table_location(
          table_ref.projectId, table_ref.datasetId, table_ref.tableId)
      _LOGGER.info(
          '[Query] Inferred location=%s from source table %s',
          self._location,
          self._table)
    self._get_or_create_temp_dataset()

  def _get_or_create_temp_dataset(self) -> None:
    """Create the temp dataset if it doesn't exist.

    Sets defaultTableExpirationMs so orphaned temp tables are automatically
    garbage-collected by BigQuery if the pipeline crashes before cleanup.

    Only a "not found" (HTTP 404) response from the lookup triggers creation;
    any other failure is re-raised instead of being mistaken for a missing
    dataset. An "already exists" (HTTP 409) response from the insert is
    tolerated because setup() may run on several workers at once.
    """
    try:
      self._bq_wrapper.client.datasets.Get(
          bigquery.BigqueryDatasetsGetRequest(
              projectId=self._project, datasetId=self._temp_dataset))
    except Exception as exn:  # pylint: disable=broad-except
      # Duck-typed on status_code so no extra import is needed here.
      if getattr(exn, 'status_code', None) != 404:
        raise
    else:
      _LOGGER.info(
          '[Query] Temp dataset %s.%s already exists',
          self._project,
          self._temp_dataset)
      return

    _LOGGER.info(
        '[Query] Creating temp dataset %s.%s with '
        '24h table expiration, location=%s',
        self._project,
        self._temp_dataset,
        self._location)
    dataset = bigquery.Dataset(
        datasetReference=bigquery.DatasetReference(
            projectId=self._project, datasetId=self._temp_dataset))
    if self._location is not None:
      dataset.location = self._location
    dataset.defaultTableExpirationMs = _DEFAULT_TABLE_EXPIRATION_MS
    try:
      self._bq_wrapper.client.datasets.Insert(
          bigquery.BigqueryDatasetsInsertRequest(
              projectId=self._project, dataset=dataset))
    except Exception as exn:  # pylint: disable=broad-except
      if getattr(exn, 'status_code', None) != 409:
        raise
      # Another worker won the race between our Get and Insert.
      _LOGGER.info(
          '[Query] Temp dataset %s.%s was created concurrently',
          self._project,
          self._temp_dataset)

  def process(self, qr: _QueryRange) -> Iterable[_QueryResult]:
    """Execute the BQ query described by a _QueryRange and yield _QueryResult.

    Args:
      qr: The time range to query.

    Yields:
      One _QueryResult naming the temp table that holds the query output.
    """
    sql = build_changes_query(
        self._table,
        qr.chunk_start,
        qr.chunk_end,
        self._change_function,
        self._change_type_column,
        self._change_timestamp_column,
        self._columns,
        self._row_filter)
    # Full-width random suffix: the job writes with WRITE_TRUNCATE, so a name
    # collision would silently overwrite a table that may still be read.
    temp_table_id = f'beam_ch_temp_{uuid.uuid4().hex}'
    job_id = f'beam_ch_{uuid.uuid4().hex[:12]}'

    _LOGGER.info(
        '[Query] job_id=%s, temp_table=%s.%s, range=[%s, %s)',
        job_id,
        self._temp_dataset,
        temp_table_id,
        _utc(qr.chunk_start),
        _utc(qr.chunk_end))

    temp_table_ref = bigquery.TableReference(
        projectId=self._project,
        datasetId=self._temp_dataset,
        tableId=temp_table_id)

    reference = bigquery.JobReference(
        jobId=job_id, projectId=self._project, location=self._location)

    request = bigquery.BigqueryJobsInsertRequest(
        projectId=self._project,
        job=bigquery.Job(
            configuration=bigquery.JobConfiguration(
                query=bigquery.JobConfigurationQuery(
                    query=sql,
                    useLegacySql=False,
                    destinationTable=temp_table_ref,
                    writeDisposition='WRITE_TRUNCATE',
                ),
            ),
            jobReference=reference))

    _LOGGER.info('[Query] Submitting BQ job %s...', job_id)
    response = self._bq_wrapper._start_job(request)
    _LOGGER.info('[Query] BQ job %s submitted, waiting...', job_id)
    self._bq_wrapper.wait_for_bq_job(
        response.jobReference, sleep_duration_sec=2)
    _LOGGER.info(
        '[Query] BQ job %s DONE. Results in %s.%s',
        job_id,
        self._temp_dataset,
        temp_table_id)
    Metrics.counter('BigQueryChangeHistory', 'queries').inc()

    yield _QueryResult(
        temp_table_ref=temp_table_ref,
        range_start=qr.chunk_start,
        range_end=qr.chunk_end)
class _CDCWatermarkEstimatorProvider(WatermarkEstimatorProvider):
  """WatermarkEstimatorProvider that initializes the hold from _QueryResult.

  Uses range_start from the element to set the initial watermark hold.
  This prevents the runner from advancing the watermark past the data's
  timestamps before any rows are emitted.
  """
  def initial_estimator_state(
      self, element: _QueryResult,
      restriction: _StreamRestriction) -> Timestamp:
    return Timestamp(element.range_start)

  def create_watermark_estimator(
      self, estimator_state: Timestamp) -> ManualWatermarkEstimator:
    return ManualWatermarkEstimator(estimator_state)


# =============================================================================
# Read: _ReadStorageStreamsSDF
# =============================================================================


class _ReadStorageStreamsSDF(beam.DoFn,
                             beam.transforms.core.RestrictionProvider):
  """SDF that reads a temp table via BigQuery Storage Read API.

  Note on SDF lifecycle: the runner decomposes this SDF into three internal
  wrapper DoFns, each a separately deserialized copy:
    - Stage A (PairWithRestriction): calls initial_restriction(): no setup()
    - Stage B (SplitAndSizeRestrictions): calls split(), restriction_size()
    - Stage C (ProcessSizedElements): calls setup(), then process()
  Because initial_restriction() runs on a different copy than process(),
  _ensure_client() lazily creates a gRPC client on whichever copy needs one.
  The _StreamRestriction carries stream names directly so no shared state
  is needed between copies.

  Each element is a _QueryResult pointing to a temp table.

  Watermark: Uses ManualWatermarkEstimator; the hold starts at the element's
  range_start and is moved to range_end once this restriction's loop ends.

  Emits:
    Main output: TimestampedValue(row_dict, event_timestamp)
    Side output (_CLEANUP_TAG): (table_key, streams_read, total_streams)
  """
  def __init__(
      self,
      batch_arrow_read: bool = True,
      change_timestamp_column: str = 'change_timestamp') -> None:
    self._batch_arrow_read = batch_arrow_read
    self._change_timestamp_column = change_timestamp_column
    self._storage_client = None

  def _ensure_client(self) -> None:
    """Lazily initialize the Storage client.

    Called from both setup() and initial_restriction() because the runner
    may invoke initial_restriction on the RestrictionProvider instance
    before setup() runs (or on a separately deserialized copy).

    Raises:
      ImportError: if google-cloud-bigquery-storage could not be imported.
    """
    if self._storage_client is None:
      if bq_storage is None:
        # The module-level import is optional; fail with a readable message
        # instead of an AttributeError on None.
        raise ImportError(
            'google-cloud-bigquery-storage is required to read BigQuery '
            'change history.')
      _LOGGER.info('[Read] creating BigQueryReadClient')
      self._storage_client = bq_storage.BigQueryReadClient()

  def setup(self) -> None:
    self._ensure_client()

  def initial_restriction(self, element: _QueryResult) -> _StreamRestriction:
    """Create ReadSession and return _StreamRestriction with stream names."""
    self._ensure_client()
    table_key = _table_key(element.temp_table_ref)
    session = self._create_read_session(element.temp_table_ref)
    stream_names = tuple(s.name for s in session.streams)
    _LOGGER.info(
        '[Read] initial_restriction for %s: %d streams',
        table_key,
        len(stream_names))
    return _StreamRestriction(stream_names, 0, len(stream_names))

  def create_tracker(
      self, restriction: _StreamRestriction) -> _StreamRestrictionTracker:
    return _StreamRestrictionTracker(restriction)

  def restriction_size(
      self, element: _QueryResult, restriction: _StreamRestriction) -> int:
    return restriction.size()

  def split(self, element: _QueryResult,
            restriction: _StreamRestriction) -> Iterable[_StreamRestriction]:
    """Yield one _StreamRestriction per stream for parallel distribution."""
    if restriction.size() <= 1:
      yield restriction
    else:
      for i in range(restriction.start, restriction.stop):
        yield _StreamRestriction(restriction.stream_names, i, i + 1)

  def is_bounded(self) -> bool:
    return True

  def _event_timestamp(self, row: Dict[str, Any]) -> Any:
    """Extract the event timestamp of a row from the change-timestamp column.

    datetime values are converted to Timestamp; any other non-None value is
    returned unchanged.

    Raises:
      ValueError: if the column is absent or None.
    """
    ts = row.get(self._change_timestamp_column)
    if ts is None:
      raise ValueError(
          'Row missing %r column. Row keys: %s' %
          (self._change_timestamp_column, list(row.keys())))
    if isinstance(ts, datetime.datetime):
      ts = Timestamp.from_utc_datetime(ts)
    return ts

  def process(
      self,
      element: _QueryResult,
      restriction_tracker=beam.DoFn.RestrictionParam(),
      watermark_estimator=beam.DoFn.WatermarkEstimatorParam(
          _CDCWatermarkEstimatorProvider())
  ) -> Iterable[Dict[str, Any]]:
    """Read the claimed streams, emit their rows, then signal cleanup.

    Args:
      element: Points at the temp table and the time window it covers.
      restriction_tracker: Tracker over the stream indexes of this split.
      watermark_estimator: Manual estimator seeded with range_start.

    Yields:
      TimestampedValue(row, event_timestamp) for each row, followed by at
      most one TaggedOutput(_CLEANUP_TAG, ...) cleanup signal.
    """
    self._ensure_client()
    table_key = _table_key(element.temp_table_ref)

    _LOGGER.info(
        '[Read] Processing %s, range=[%s, %s), '
        'initial watermark=%s',
        table_key,
        _utc(element.range_start),
        _utc(element.range_end),
        _utc(watermark_estimator.current_watermark()))

    restriction = restriction_tracker.current_restriction()
    stream_names = restriction.stream_names
    total_streams = len(stream_names)

    streams_read = 0
    total_rows = 0

    _LOGGER.info(
        '[Read] Reading streams [%d, %d) of %d total for %s',
        restriction.start,
        restriction.stop,
        total_streams,
        table_key)

    for i in range(restriction.start, restriction.stop):
      if not restriction_tracker.try_claim(i):
        _LOGGER.info(
            '[Read] try_claim(%d) FAILED for %s: '
            'runner split or checkpoint, breaking',
            i,
            table_key)
        break

      stream_name = stream_names[i]
      _LOGGER.info(
          '[Read] try_claim(%d) succeeded: reading stream %s', i, stream_name)

      stream_rows = 0
      for row in self._read_stream(stream_name):
        yield TimestampedValue(row, self._event_timestamp(row))
        stream_rows += 1
        total_rows += 1
        Metrics.counter('BigQueryChangeHistory', 'rows_emitted').inc()

      streams_read += 1
      _LOGGER.info(
          '[Read] Finished reading stream %d for %s: %d rows',
          i,
          table_key,
          stream_rows)
      Metrics.counter('BigQueryChangeHistory', 'streams_read').inc()

    # Advance watermark to range_end once the loop is over. The initial hold
    # was set to range_start by _CDCWatermarkEstimatorProvider.
    watermark_estimator.set_watermark(Timestamp(element.range_end))
    _LOGGER.info(
        '[Read] Watermark advanced to %s (range_end) for %s',
        _utc(element.range_end),
        table_key)

    # Emit cleanup signal. Every split that reads at least one stream
    # reports how many it read. A session with no streams at all (the query
    # returned no rows) is reported too; otherwise nothing would ever ask
    # for that temp table to be deleted and it would linger until the
    # dataset's table expiration.
    # NOTE(review): assumes the runner still invokes process() for an empty
    # restriction - confirm on the target runner.
    if streams_read > 0 or total_streams == 0:
      _LOGGER.info(
          '[Read] Emitting cleanup signal for %s: '
          'streams_read=%d, total_streams=%d, total_rows=%d',
          table_key,
          streams_read,
          total_streams,
          total_rows)
      yield beam.pvalue.TaggedOutput(
          _CLEANUP_TAG, (
              table_key,
              streams_read,
              total_streams,
          ))

  def _create_read_session(self, table_ref: bigquery.TableReference) -> Any:
    """Create a BigQuery Storage ReadSession for the given table.

    The session requests ARROW data with LZ4_FRAME buffer compression and at
    most _DEFAULT_MAX_STREAMS streams.
    """
    table_path = (
        f'projects/{table_ref.projectId}/'
        f'datasets/{table_ref.datasetId}/'
        f'tables/{table_ref.tableId}')

    requested_session = bq_storage.types.ReadSession()
    requested_session.table = table_path
    requested_session.data_format = bq_storage.types.DataFormat.ARROW
    requested_session.read_options \
      .arrow_serialization_options.buffer_compression = \
      bq_storage.types.ArrowSerializationOptions.CompressionCodec.LZ4_FRAME

    session = self._storage_client.create_read_session(
        parent=f'projects/{table_ref.projectId}',
        read_session=requested_session,
        max_stream_count=_DEFAULT_MAX_STREAMS)
    _LOGGER.info(
        '[Read] _create_read_session: table=%s, %d streams',
        table_path,
        len(session.streams))
    return session

  def _read_stream(self, stream_name: str) -> Iterable[Dict[str, Any]]:
    """Read all rows from a single Storage API stream as dicts.

    When batch_arrow_read is enabled, converts entire Arrow RecordBatches
    at once using to_pydict() instead of calling .as_py() on each cell
    individually. This is ~1.5x faster for large tables at the cost of ~2x
    peak memory per batch.
    """
    if self._batch_arrow_read:
      yield from self._read_stream_batch(stream_name)
    else:
      yield from self._read_stream_row_by_row(stream_name)

  def _read_stream_row_by_row(self,
                              stream_name: str) -> Iterable[Dict[str, Any]]:
    """Row-by-row Arrow conversion (lower memory than batch mode)."""
    t0 = time.time()
    row_count = 0
    for row in self._storage_client.read_rows(stream_name).rows():
      yield {name: value.as_py() for name, value in row.items()}
      row_count += 1
    # Throughput is logged only once the stream has been fully consumed.
    elapsed = time.time() - t0
    _LOGGER.info(
        '[Read] row_by_row: %d rows in %.2fs (%.0f rows/s)',
        row_count,
        elapsed,
        row_count / elapsed if elapsed > 0 else 0)

  def _read_stream_batch(self, stream_name: str) -> Iterable[Dict[str, Any]]:
    """Batch-convert Arrow RecordBatches for high throughput.

    Raises:
      ImportError: if pyarrow could not be imported.
    """
    if pyarrow is None:
      raise ImportError('pyarrow is required for batch_arrow_read=True.')
    schema = None
    row_count = 0
    t0 = time.time()
    for response in self._storage_client.read_rows(stream_name):
      # The serialized schema is parsed once, from the first response that
      # carries one.
      if schema is None and response.arrow_schema.serialized_schema:
        schema = pyarrow.ipc.read_schema(
            pyarrow.py_buffer(response.arrow_schema.serialized_schema))
      batch_bytes = response.arrow_record_batch.serialized_record_batch
      if batch_bytes and schema is not None:
        batch = pyarrow.ipc.read_record_batch(
            pyarrow.py_buffer(batch_bytes), schema)
        columns = batch.to_pydict()
        col_names = batch.schema.names
        for i in range(batch.num_rows):
          yield {name: columns[name][i] for name in col_names}
        row_count += batch.num_rows
    elapsed = time.time() - t0
    _LOGGER.info(
        '[Read] batch_read: %d rows in %.2fs (%.0f rows/s)',
        row_count,
        elapsed,
        row_count / elapsed if elapsed > 0 else 0)
# =============================================================================
# Cleanup: _CleanupTempTablesFn
# =============================================================================


class _CleanupTempTablesFn(beam.DoFn):
  """Stateful DoFn that deletes temp tables after all streams are read.

  Receives cleanup signals from the Read SDF as:
    (table_key, (streams_read_count, total_streams))

  Accumulates streams_read across all signals for the same table_key.
  When streams_read >= total_streams, deletes the temp table. The >=
  (rather than ==) guards against duplicate delivery in at-least-once runners.

  NOTE(review): the Read SDF's tagged output is a flat 3-tuple; presumably
  the enclosing transform reshapes it into this keyed form - confirm.
  """
  STREAMS_READ = beam.transforms.userstate.CombiningValueStateSpec(
      'streams_read', sum)

  def setup(self) -> None:
    _LOGGER.info('[Cleanup] setup: creating BigQueryWrapper')
    self._bq_wrapper = bigquery_tools.BigQueryWrapper()

  def process(
      self,
      element: Tuple[str, Tuple[int, int]],
      streams_read=beam.DoFn.StateParam(STREAMS_READ)
  ) -> None:
    """Add one signal to the per-table count and delete the table when done.

    Args:
      element: (table_key, (streams read by one split, total streams)).
      streams_read: Per-key combining state holding the running sum.
    """
    table_key = element[0]
    split_count, total_streams = element[1][0], element[1][1]

    _LOGGER.info(
        '[Cleanup] Received cleanup signal for %s: '
        'split_count=%d, total_streams=%d',
        table_key,
        split_count,
        total_streams)

    streams_read.add(split_count)
    current_read = streams_read.read()

    _LOGGER.info(
        '[Cleanup] State for %s: streams_read=%d/%d',
        table_key,
        current_read,
        total_streams)

    if current_read < total_streams:
      _LOGGER.info(
          '[Cleanup] Not yet complete for %s (%d/%d), '
          'waiting for more signals',
          table_key,
          current_read,
          total_streams)
      return

    # Split from the right: the key is '<project>.<dataset>.<table>' and the
    # project part may itself contain dots, which a plain split('.') would
    # turn into extra fields and silently skip the delete.
    parts = table_key.rsplit('.', 2)
    if len(parts) != 3:
      _LOGGER.warning(
          '[Cleanup] Cannot parse temp table key %r; skipping delete',
          table_key)
      streams_read.clear()
      return

    project, dataset, table = parts
    _LOGGER.info(
        '[Cleanup] All streams read: DELETING temp table %s', table_key)
    self._bq_wrapper._delete_table(project, dataset, table)
    _LOGGER.info('[Cleanup] Deleted temp table %s', table_key)
    Metrics.counter('BigQueryChangeHistory', 'temp_tables_deleted').inc()
    streams_read.clear()
+ + Receives cleanup signals from the Read SDF as: + (table_key, (streams_read_count, total_streams)) + + Accumulates streams_read across all signals for the same table_key. + When streams_read >= total_streams, deletes the temp table. The >= + (rather than ==) guards against duplicate delivery in at-least-once runners. + """ + STREAMS_READ = beam.transforms.userstate.CombiningValueStateSpec( + 'streams_read', sum) + + def setup(self) -> None: + _LOGGER.info('[Cleanup] setup: creating BigQueryWrapper') + self._bq_wrapper = bigquery_tools.BigQueryWrapper() + + def process( + self, + element: Tuple[str, Tuple[int, int]], + streams_read=beam.DoFn.StateParam(STREAMS_READ) + ) -> None: + table_key = element[0] + split_count = element[1][0] + total_streams = element[1][1] + + _LOGGER.info( + '[Cleanup] Received cleanup signal for %s: ' + 'split_count=%d, total_streams=%d', + table_key, + split_count, + total_streams) + + streams_read.add(split_count) + current_read = streams_read.read() + + _LOGGER.info( + '[Cleanup] State for %s: streams_read=%d/%d', + table_key, + current_read, + total_streams) + + if current_read >= total_streams: + parts = table_key.split('.') + if len(parts) == 3: + project, dataset, table = parts + _LOGGER.info( + '[Cleanup] All streams read: DELETING temp table %s', table_key) + self._bq_wrapper._delete_table(project, dataset, table) + _LOGGER.info('[Cleanup] Deleted temp table %s', table_key) + Metrics.counter('BigQueryChangeHistory', 'temp_tables_deleted').inc() + streams_read.clear() + else: + _LOGGER.info( + '[Cleanup] Not yet complete for %s (%d/%d), ' + 'waiting for more signals', + table_key, + current_read, + total_streams) + + +# ============================================================================= +# Public API: ReadBigQueryChangeHistory +# ============================================================================= + + +class ReadBigQueryChangeHistory(beam.PTransform): + """Streaming source for BigQuery change history. 
+ + Continuously polls BigQuery APPENDS() or CHANGES() functions and emits + changed rows as an unbounded PCollection of dicts. + + Args: + table: BigQuery table to read changes from. + Format: 'project:dataset.table' or 'project.dataset.table'. + poll_interval_sec: Seconds between polls. Default 60. + start_time: Start reading from this timestamp (float, epoch seconds). + Default: current time when pipeline starts. + stop_time: Stop polling at this timestamp. Default: run forever. + change_function: 'CHANGES' or 'APPENDS'. Default 'APPENDS'. + buffer_sec: Safety buffer in seconds behind now(). Default 15. + project: GCP project ID. Default: from pipeline options. + temp_dataset: Dataset for temp tables. If None (default), a + per-pipeline dataset is auto-created with a 24-hour table + expiration as a safety net for orphaned tables. Set this to + use an existing dataset (e.g. if your service account lacks + bigquery.datasets.create permission). + location: BigQuery geographic location for query jobs and temp + dataset (e.g. 'US', 'us-central1'). If None (default), inferred + from the source table. + change_type_column: Output column name for the _CHANGE_TYPE + pseudo-column. Default 'change_type'. Change this if your source + table already has a column named 'change_type'. + change_timestamp_column: Output column name for the + _CHANGE_TIMESTAMP pseudo-column. Default 'change_timestamp'. + Change this if your source table already has a column named + 'change_timestamp'. This column is also used internally to + extract event timestamps for watermark tracking. + columns: Optional list of column names to select from the source + table. If None (default), all columns are selected. The + pseudo-columns (change_type, change_timestamp) are always + included regardless of this setting. + row_filter: Optional SQL boolean expression used as a WHERE clause + on the CHANGES/APPENDS query. Do not include the WHERE keyword. + Example: ``'status = "active" AND region = "US"'``. 
+ batch_arrow_read: If True (default), convert Arrow RecordBatches in + bulk using to_pydict() instead of per-cell .as_py() calls. + This is 1.5x faster for large tables at the cost of ~2x peak + memory per RecordBatch. Set to False for minimal memory usage. + """ + def __init__( + self, + table: str, + poll_interval_sec: float = 60, + start_time: Optional[float] = None, + stop_time: Optional[float] = None, + change_function: str = 'APPENDS', + buffer_sec: float = 15, + project: Optional[str] = None, + temp_dataset: Optional[str] = None, + location: Optional[str] = None, + change_type_column: str = 'change_type', + change_timestamp_column: str = 'change_timestamp', + columns: Optional[List[str]] = None, + row_filter: Optional[str] = None, + batch_arrow_read: bool = True) -> None: + super().__init__() + if bq_storage is None: + raise ImportError( + 'google-cloud-bigquery-storage is required for ' + 'ReadBigQueryChangeHistory. Install it with: ' + 'pip install google-cloud-bigquery-storage') + if pyarrow is None: + raise ImportError( + 'pyarrow is required for ReadBigQueryChangeHistory. 
' + 'Install it with: pip install pyarrow') + if change_function not in ('CHANGES', 'APPENDS'): + raise ValueError( + f"change_function must be 'CHANGES' or 'APPENDS', " + f"got '{change_function}'") + if poll_interval_sec <= 0: + raise ValueError( + f'poll_interval_sec must be positive, got {poll_interval_sec}') + + self._table = table + self._poll_interval_sec = poll_interval_sec + self._start_time = start_time + self._stop_time = stop_time + self._change_function = change_function + self._buffer_sec = buffer_sec + self._project = project + self._temp_dataset = temp_dataset + self._location = location + self._change_type_column = change_type_column + self._change_timestamp_column = change_timestamp_column + self._columns = columns + self._row_filter = row_filter + self._batch_arrow_read = batch_arrow_read + + def expand(self, pbegin: beam.pvalue.PBegin) -> beam.PCollection: + project = self._project + if project is None: + project = pbegin.pipeline.options.view_as( + beam.options.pipeline_options.GoogleCloudOptions).project + + if project is None: + raise ValueError( + 'project must be specified either in ReadBigQueryChangeHistory ' + 'or in pipeline options (--project)') + + start_time = self._start_time or time.time() + stop_time = self._stop_time or MAX_TIMESTAMP + + temp_dataset = self._temp_dataset + if temp_dataset is None: + temp_dataset = f'beam_ch_temp_{uuid.uuid4().hex[:12]}' + + _LOGGER.info( + '[ReadBigQueryChangeHistory] expand: table=%s, project=%s, ' + 'change_function=%s, poll_interval=%d sec, buffer=%d sec, ' + 'temp_dataset=%s, start_time=%s, stop_time=%s', + self._table, + project, + self._change_function, + self._poll_interval_sec, + self._buffer_sec, + temp_dataset, + _utc(start_time), + _utc(stop_time) if stop_time != MAX_TIMESTAMP else 'INF') + + # Custom polling SDF emits lightweight _QueryRange instructions. + # The SDF uses defer_remainder() for poll timing and + # _PollWatermarkEstimator to hold the watermark at data timestamps. 
+ # On the first invocation it handles the full historical range + # [start_time, now - buffer_sec) in a single poll. + config = _PollConfig(start_time=start_time) + + query_ranges = ( + pbegin + | 'CreatePollConfig' >> beam.Create([config]) + | 'PollChangeHistory' >> beam.ParDo( + _PollChangeHistoryFn( + table=self._table, + project=project, + change_function=self._change_function, + buffer_sec=self._buffer_sec, + start_time=start_time, + stop_time=stop_time, + poll_interval_sec=self._poll_interval_sec))) + + # CommitQueryResults: Reshuffle commits _QueryResult (temp table ref) + # so that if the Read SDF retries, it re-reads the existing temp table + # instead of re-running the BQ query. + # Possible edge-case is that if ReadStorageStreams doesn't read the temp + # table within 24 hours (table expiration) it can end up in a bad state by + # trying to query a non-existing table. + query_results = ( + query_ranges + | 'CommitQueryRanges' >> beam.Reshuffle() + | 'ExecuteQueries' >> beam.ParDo( + _ExecuteQueryFn( + table=self._table, + project=project, + change_function=self._change_function, + temp_dataset=temp_dataset, + location=self._location, + change_type_column=self._change_type_column, + change_timestamp_column=self._change_timestamp_column, + columns=self._columns, + row_filter=self._row_filter)) + | 'CommitQueryResults' >> beam.Reshuffle()) + + read_outputs = ( + query_results + | 'ReadStorageStreams' >> beam.ParDo( + _ReadStorageStreamsSDF( + batch_arrow_read=self._batch_arrow_read, + change_timestamp_column=self._change_timestamp_column)). 
+ with_outputs(_CLEANUP_TAG, main='rows')) + + _ = ( + read_outputs[_CLEANUP_TAG] + | 'KeyByTable' >> + beam.Map(lambda x: (x[0], (x[1], x[2]))).with_output_types( + beam.typehints.Tuple[str, beam.typehints.Tuple[int, int]]) + | 'CleanupTempTables' >> beam.ParDo(_CleanupTempTablesFn())) + + return read_outputs['rows'] diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py new file mode 100644 index 000000000000..0c93a8347149 --- /dev/null +++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py @@ -0,0 +1,525 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Integration tests for BigQuery change history streaming source. 
+""" + +import logging +import secrets +import time +import unittest +import uuid + +import apache_beam as beam +from apache_beam.io.gcp.bigquery_change_history import ReadBigQueryChangeHistory +from apache_beam.io.gcp.bigquery_change_history import _CleanupTempTablesFn +from apache_beam.io.gcp.bigquery_change_history import _ExecuteQueryFn +from apache_beam.io.gcp.bigquery_change_history import _PollChangeHistoryFn +from apache_beam.io.gcp.bigquery_change_history import _PollConfig +from apache_beam.io.gcp.bigquery_change_history import _QueryRange +from apache_beam.io.gcp.bigquery_change_history import _QueryResult +from apache_beam.io.gcp.bigquery_change_history import _ReadStorageStreamsSDF +from apache_beam.io.gcp.bigquery_change_history import _table_key +from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper +from apache_beam.io.gcp.internal.clients import bigquery +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to + +_LOGGER = logging.getLogger(__name__) + + +class BigQueryChangeHistoryIntegrationBase(unittest.TestCase): + """Base class for integration tests against real BigQuery. + + Gets project from pipeline options (--project). + Creates two unique temp datasets per test class: + - dataset: for change-history-enabled source tables + - temp_dataset: for pipeline temp tables (query results, etc.) + Both are deleted with all contents in tearDownClass. 
+ """ + @classmethod + def setUpClass(cls): + cls.test_pipeline = TestPipeline(is_integration_test=True) + cls.project = cls.test_pipeline.get_option('project') + cls.args = cls.test_pipeline.get_full_options_as_args() + cls.bq_wrapper = BigQueryWrapper() + suffix = secrets.token_hex(4) + cls.dataset = f'beam_ch_src_{suffix}' + cls.temp_dataset = f'beam_ch_tmp_{suffix}' + cls.bq_wrapper.get_or_create_dataset(cls.project, cls.dataset) + ds = cls.bq_wrapper.client.datasets.Get( + bigquery.BigqueryDatasetsGetRequest( + projectId=cls.project, datasetId=cls.dataset)) + cls.location = ds.location + cls.bq_wrapper.get_or_create_dataset( + cls.project, cls.temp_dataset, location=cls.location) + _LOGGER.info( + 'Created datasets: source=%s, temp=%s (location=%s)', + cls.dataset, + cls.temp_dataset, + cls.location) + + @classmethod + def tearDownClass(cls): + for dataset in (cls.dataset, cls.temp_dataset): + try: + cls.bq_wrapper.client.datasets.Delete( + bigquery.BigqueryDatasetsDeleteRequest( + projectId=cls.project, datasetId=dataset, deleteContents=True)) + _LOGGER.info('Deleted dataset %s', dataset) + except Exception as e: + _LOGGER.warning('Failed to clean up dataset %s: %s', dataset, e) + + @classmethod + def _create_temp_table_with_data(cls, table_id, rows, schema=None): + """Create a table in the temp dataset and insert rows via streaming.""" + if schema is None: + schema = [('id', 'INTEGER'), ('name', 'STRING'), ('value', 'FLOAT')] + table_schema = bigquery.TableSchema() + for field_name, field_type in schema: + field = bigquery.TableFieldSchema() + field.name = field_name + field.type = field_type + table_schema.fields.append(field) + + table = bigquery.Table( + tableReference=bigquery.TableReference( + projectId=cls.project, datasetId=cls.temp_dataset, + tableId=table_id), + schema=table_schema) + request = bigquery.BigqueryTablesInsertRequest( + projectId=cls.project, datasetId=cls.temp_dataset, table=table) + cls.bq_wrapper.client.tables.Insert(request) + + # 
Wait for table to be visible + cls.bq_wrapper.get_table(cls.project, cls.temp_dataset, table_id) + + if rows: + cls.bq_wrapper.insert_rows(cls.project, cls.temp_dataset, table_id, rows) + # Give streaming buffer time to flush + time.sleep(5) + + return bigquery.TableReference( + projectId=cls.project, datasetId=cls.temp_dataset, tableId=table_id) + + @classmethod + def _create_change_history_table(cls, table_id, rows=None): + """Create a table with enable_change_history via DDL.""" + ddl = ( + f'CREATE TABLE IF NOT EXISTS ' + f'`{cls.project}.{cls.dataset}.{table_id}` ' + f'(id INT64, name STRING, value FLOAT64) ' + f'OPTIONS (enable_change_history = true)') + + job_id = f'beam_ch_ddl_{uuid.uuid4().hex[:8]}' + reference = bigquery.JobReference(jobId=job_id, projectId=cls.project) + request = bigquery.BigqueryJobsInsertRequest( + projectId=cls.project, + job=bigquery.Job( + configuration=bigquery.JobConfiguration( + query=bigquery.JobConfigurationQuery( + query=ddl, useLegacySql=False)), + jobReference=reference)) + response = cls.bq_wrapper._start_job(request) + cls.bq_wrapper.wait_for_bq_job(response.jobReference, sleep_duration_sec=2) + + # Wait for table to be visible + cls.bq_wrapper.get_table(cls.project, cls.dataset, table_id) + + if rows: + cls.bq_wrapper.insert_rows(cls.project, cls.dataset, table_id, rows) + time.sleep(5) + + return bigquery.TableReference( + projectId=cls.project, datasetId=cls.dataset, tableId=table_id) + + +class CleanupTempTablesFnTest(BigQueryChangeHistoryIntegrationBase): + """Integration tests for _CleanupTempTablesFn against real BigQuery.""" + def test_single_complete_signal_deletes_table(self): + """A single signal with streams_read == total deletes the temp table.""" + table_id = f'cleanup_test_{secrets.token_hex(4)}' + table_ref = self._create_temp_table_with_data( + table_id, [{ + 'id': 1, 'name': 'a', 'value': 1.0 + }]) + table_key = _table_key(table_ref) + + # Feed cleanup signal: all 5 streams read out of 5 + with 
beam.Pipeline(argv=self.args) as p: + _ = ( + p + | beam.Create([(table_key, (5, 5))]) + | beam.ParDo(_CleanupTempTablesFn())) + + # Verify table was deleted + time.sleep(2) + with self.assertRaises(Exception): + self.bq_wrapper.get_table(self.project, self.temp_dataset, table_id) + + def test_partial_signals_then_complete(self): + """Partial signals don't delete; final signal triggers cleanup.""" + table_id = f'cleanup_partial_{secrets.token_hex(4)}' + table_ref = self._create_temp_table_with_data( + table_id, [{ + 'id': 1, 'name': 'a', 'value': 1.0 + }]) + table_key = _table_key(table_ref) + + # Feed two partial signals: 3/10 + 7/10 = 10/10 + with beam.Pipeline(argv=self.args) as p: + _ = ( + p + | beam.Create([ + (table_key, (3, 10)), + (table_key, (7, 10)), + ]) + | beam.ParDo(_CleanupTempTablesFn())) + + time.sleep(2) + with self.assertRaises(Exception): + self.bq_wrapper.get_table(self.project, self.temp_dataset, table_id) + + +class ReadStorageStreamsSDFTest(BigQueryChangeHistoryIntegrationBase): + """Integration tests for _ReadStorageStreamsSDF against real BigQuery. + + Tables must include change_timestamp (TIMESTAMP) and change_type (STRING) + columns to match the schema that _ExecuteQueryFn produces in the real + pipeline. The Read SDF extracts event timestamps from change_timestamp. 
+ """ + _READ_SCHEMA = [ + ('id', 'INTEGER'), + ('name', 'STRING'), + ('value', 'FLOAT'), + ('change_timestamp', 'TIMESTAMP'), + ('change_type', 'STRING'), + ] + + def test_reads_rows_from_temp_table(self): + """SDF reads rows from a real temp table via Storage Read API.""" + table_id = f'sdf_test_{secrets.token_hex(4)}' + now = time.time() + rows = [ + { + 'id': 1, + 'name': 'alice', + 'value': 10.0, + 'change_timestamp': now, + 'change_type': 'INSERT' + }, + { + 'id': 2, + 'name': 'bob', + 'value': 20.0, + 'change_timestamp': now, + 'change_type': 'INSERT' + }, + { + 'id': 3, + 'name': 'charlie', + 'value': 30.0, + 'change_timestamp': now, + 'change_type': 'INSERT' + }, + ] + table_ref = self._create_temp_table_with_data( + table_id, rows, schema=self._READ_SCHEMA) + + query_result = _QueryResult( + temp_table_ref=table_ref, range_start=now - 60, range_end=now + 60) + + with beam.Pipeline(argv=self.args) as p: + outputs = ( + p + | beam.Create([query_result]) + | beam.ParDo(_ReadStorageStreamsSDF()).with_outputs( + 'cleanup', main='rows')) + + # Check that we get 3 rows + row_count = ( + outputs['rows'] + | 'CountRows' >> beam.combiners.Count.Globally()) + assert_that(row_count, equal_to([3]), label='CheckRowCount') + + def test_cleanup_signal_emitted(self): + """SDF emits cleanup signal with correct counts.""" + table_id = f'sdf_cleanup_{secrets.token_hex(4)}' + now = time.time() + rows = [{ + 'id': 1, + 'name': 'a', + 'value': 1.0, + 'change_timestamp': now, + 'change_type': 'INSERT' + }] + table_ref = self._create_temp_table_with_data( + table_id, rows, schema=self._READ_SCHEMA) + + query_result = _QueryResult( + temp_table_ref=table_ref, range_start=now - 60, range_end=now + 60) + + with beam.Pipeline(argv=self.args) as p: + outputs = ( + p + | beam.Create([query_result]) + | beam.ParDo(_ReadStorageStreamsSDF()).with_outputs( + 'cleanup', main='rows')) + + # Verify cleanup signal + cleanup_table_keys = ( + outputs['cleanup'] + | 'ExtractKey' >> 
beam.Map(lambda x: x[0])) + assert_that( + cleanup_table_keys, + equal_to([_table_key(table_ref)]), + label='CheckCleanupKey') + + def test_empty_table(self): + """Empty table produces 0 rows and cleanup signal.""" + table_id = f'sdf_empty_{secrets.token_hex(4)}' + now = time.time() + table_ref = self._create_temp_table_with_data( + table_id, [], schema=self._READ_SCHEMA) + + query_result = _QueryResult( + temp_table_ref=table_ref, range_start=now - 60, range_end=now + 60) + + with beam.Pipeline(argv=self.args) as p: + outputs = ( + p + | beam.Create([query_result]) + | beam.ParDo(_ReadStorageStreamsSDF()).with_outputs( + 'cleanup', main='rows')) + + row_count = ( + outputs['rows'] + | 'CountRows' >> beam.combiners.Count.Globally()) + assert_that(row_count, equal_to([0]), label='CheckZeroRows') + + +class PollChangeHistoryFnTest(BigQueryChangeHistoryIntegrationBase): + """Integration test for _PollChangeHistoryFn in isolation.""" + def test_poll_emits_query_ranges(self): + """Poll SDF emits _QueryRange elements with valid time ranges.""" + table_str = f'{self.project}:{self.dataset}.nonexistent' + start_time = time.time() - 120 + + config = _PollConfig(start_time=start_time) + + poll_sdf = _PollChangeHistoryFn( + table=table_str, + project=self.project, + change_function='APPENDS', + buffer_sec=0, + start_time=start_time, + stop_time=time.time() + 5, + poll_interval_sec=60) + + with beam.Pipeline(argv=self.args) as p: + ranges = (p | beam.Create([config]) | beam.ParDo(poll_sdf)) + + # assert_that works directly on unbounded PCollections (no GBK). 
+ def check_ranges(actual): + assert len(actual) >= 1, f'Expected >= 1 range, got {len(actual)}' + for r in actual: + assert r.chunk_start < r.chunk_end, ( + f'Invalid range: {r.chunk_start} >= {r.chunk_end}') + + assert_that(ranges, check_ranges) + + +class ExecuteQueryFnTest(BigQueryChangeHistoryIntegrationBase): + """Integration test for _ExecuteQueryFn in isolation.""" + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.test_table_id = f'exec_test_{secrets.token_hex(4)}' + cls.test_table_ref = cls._create_change_history_table( + cls.test_table_id, + rows=[ + { + 'id': 1, 'name': 'row1', 'value': 1.0 + }, + { + 'id': 2, 'name': 'row2', 'value': 2.0 + }, + ]) + cls.insert_time = time.time() + time.sleep(10) + + def test_execute_query_produces_query_result(self): + """ExecuteQueryFn creates a temp table from a _QueryRange.""" + table_str = f'{self.project}:{self.dataset}.{self.test_table_id}' + start_time = self.insert_time - 120 + + query_range = _QueryRange(chunk_start=start_time, chunk_end=time.time()) + + with beam.Pipeline(argv=self.args) as p: + results = ( + p + | beam.Create([query_range]) + | beam.ParDo( + _ExecuteQueryFn( + table=table_str, + project=self.project, + change_function='APPENDS', + temp_dataset=self.temp_dataset, + location=self.location))) + + result_count = results | beam.combiners.Count.Globally() + assert_that(result_count, equal_to([1]), label='CheckOneResult') + + +class EndToEndTest(BigQueryChangeHistoryIntegrationBase): + """End-to-end test using the public ReadBigQueryChangeHistory API. + + Creates a change-history-enabled table, inserts rows, then runs the + full pipeline via the public PTransform and verifies rows come through. 
+ """ + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.test_table_id = f'e2e_test_{secrets.token_hex(4)}' + cls.test_table_ref = cls._create_change_history_table( + cls.test_table_id, + rows=[ + { + 'id': 1, 'name': 'alice', 'value': 10.0 + }, + { + 'id': 2, 'name': 'bob', 'value': 20.0 + }, + { + 'id': 3, 'name': 'charlie', 'value': 30.0 + }, + ]) + cls.insert_time = time.time() + # Wait for streaming buffer + change history propagation + _LOGGER.info('Waiting for streaming buffer to flush...') + time.sleep(15) + + def test_public_api_reads_inserted_rows(self): + """ReadBigQueryChangeHistory PTransform with polling SDF.""" + table_str = f'{self.project}:{self.dataset}.{self.test_table_id}' + start_time = self.insert_time - 120 # 2 min before insert + stop_time = time.time() + 5 # Short run for test + + with beam.Pipeline(argv=self.args) as p: + rows = ( + p + | ReadBigQueryChangeHistory( + table=table_str, + poll_interval_sec=60, + start_time=start_time, + stop_time=stop_time, + change_function='APPENDS', + buffer_sec=0, + project=self.project, + temp_dataset=self.temp_dataset)) + + def check_rows(actual): + assert len(actual) == 3, f'Expected 3 rows, got {len(actual)}' + got = sorted([{ + k: v + for k, v in row.items() if k != 'change_timestamp' + } for row in actual], + key=lambda r: r['id']) + expected = [ + { + 'id': 1, + 'name': 'alice', + 'value': 10.0, + 'change_type': 'INSERT' + }, + { + 'id': 2, 'name': 'bob', 'value': 20.0, 'change_type': 'INSERT' + }, + { + 'id': 3, + 'name': 'charlie', + 'value': 30.0, + 'change_type': 'INSERT' + }, + ] + assert got == expected, ( + f'Row mismatch:\n got: {got}\n expected: {expected}') + + assert_that(rows, check_rows) + + +def insert_test_rows(project, dataset, table, n, bq_wrapper=None): + """Insert n test rows into a BigQuery table. + + Args: + project: GCP project ID. + dataset: BigQuery dataset ID. + table: BigQuery table ID. + n: Number of rows to insert. 
+ bq_wrapper: Optional BigQueryWrapper instance (creates one if None). + + Returns: + List of inserted row dicts. + """ + if bq_wrapper is None: + bq_wrapper = BigQueryWrapper() + rows = [{'id': i, 'name': f'row_{i}', 'value': float(i)} for i in range(n)] + bq_wrapper.insert_rows(project, dataset, table, rows) + return rows + + +def create_change_history_table(project, dataset, table_id, bq_wrapper=None): + """Create a table with enable_change_history via DDL. + + Args: + project: GCP project ID. + dataset: BigQuery dataset ID. + table_id: Table name to create. + bq_wrapper: Optional BigQueryWrapper instance. + + Returns: + bigquery.TableReference for the created table. + """ + if bq_wrapper is None: + bq_wrapper = BigQueryWrapper() + + ddl = ( + f'CREATE TABLE IF NOT EXISTS ' + f'`{project}.{dataset}.{table_id}` ' + f'(id INT64, name STRING, value FLOAT64) ' + f'OPTIONS (enable_change_history = true)') + + job_id = f'beam_ch_ddl_{uuid.uuid4().hex[:8]}' + reference = bigquery.JobReference(jobId=job_id, projectId=project) + request = bigquery.BigqueryJobsInsertRequest( + projectId=project, + job=bigquery.Job( + configuration=bigquery.JobConfiguration( + query=bigquery.JobConfigurationQuery( + query=ddl, useLegacySql=False)), + jobReference=reference)) + response = bq_wrapper._start_job(request) + bq_wrapper.wait_for_bq_job(response.jobReference, sleep_duration_sec=2) + + return bigquery.TableReference( + projectId=project, datasetId=dataset, tableId=table_id) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + unittest.main() diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history_test.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history_test.py new file mode 100644 index 000000000000..f5484946ced7 --- /dev/null +++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history_test.py @@ -0,0 +1,212 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Unit tests for BigQuery change history streaming source (no GCP required). + +Tests: + - build_changes_query format + - compute_ranges chunking + - _table_key conversion + - ReadBigQueryChangeHistory validation +""" + +import datetime +import unittest + +from apache_beam.io.gcp.bigquery_change_history import ReadBigQueryChangeHistory +from apache_beam.io.gcp.bigquery_change_history import _table_key +from apache_beam.io.gcp.bigquery_change_history import build_changes_query +from apache_beam.io.gcp.bigquery_change_history import compute_ranges +from apache_beam.io.gcp.internal.clients import bigquery + + +class BuildChangesQueryTest(unittest.TestCase): + """Tests for build_changes_query().""" + def test_appends_query_format(self): + # Use UTC-aware datetimes to avoid timezone offset issues + ts_start = datetime.datetime( + 2025, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc).timestamp() + ts_end = datetime.datetime( + 2025, 1, 1, 1, 0, 0, tzinfo=datetime.timezone.utc).timestamp() + sql = build_changes_query( + 'myproject.mydataset.mytable', ts_start, ts_end, 'APPENDS') + self.assertIn('APPENDS', sql) + self.assertIn('TABLE `myproject.mydataset.mytable`', sql) + self.assertIn('2025-01-01T00:00:00', sql) + self.assertIn('2025-01-01T01:00:00', sql) + + def test_changes_query_format(self): + 
ts_start = datetime.datetime( + 2025, 6, 15, 12, 0, 0, tzinfo=datetime.timezone.utc).timestamp() + ts_end = datetime.datetime( + 2025, 6, 15, 18, 0, 0, tzinfo=datetime.timezone.utc).timestamp() + sql = build_changes_query('proj.ds.tbl', ts_start, ts_end, 'CHANGES') + self.assertIn('CHANGES', sql) + self.assertIn('TABLE `proj.ds.tbl`', sql) + + def test_columns_select(self): + ts_start = datetime.datetime( + 2025, 1, 1, tzinfo=datetime.timezone.utc).timestamp() + ts_end = datetime.datetime( + 2025, 1, 2, tzinfo=datetime.timezone.utc).timestamp() + sql = build_changes_query( + 'proj.ds.tbl', ts_start, ts_end, 'APPENDS', columns=['col_a', 'col_b']) + self.assertIn('SELECT col_a, col_b, _CHANGE_TYPE AS', sql) + self.assertNotIn('EXCEPT', sql) + + def test_columns_none_selects_all(self): + ts_start = datetime.datetime( + 2025, 1, 1, tzinfo=datetime.timezone.utc).timestamp() + ts_end = datetime.datetime( + 2025, 1, 2, tzinfo=datetime.timezone.utc).timestamp() + sql = build_changes_query( + 'proj.ds.tbl', ts_start, ts_end, 'APPENDS', columns=None) + self.assertIn('SELECT * EXCEPT', sql) + + def test_row_filter(self): + ts_start = datetime.datetime( + 2025, 1, 1, tzinfo=datetime.timezone.utc).timestamp() + ts_end = datetime.datetime( + 2025, 1, 2, tzinfo=datetime.timezone.utc).timestamp() + sql = build_changes_query( + 'proj.ds.tbl', + ts_start, + ts_end, + 'APPENDS', + row_filter='status = "active"') + self.assertIn('WHERE status = "active"', sql) + + def test_no_row_filter(self): + ts_start = datetime.datetime( + 2025, 1, 1, tzinfo=datetime.timezone.utc).timestamp() + ts_end = datetime.datetime( + 2025, 1, 2, tzinfo=datetime.timezone.utc).timestamp() + sql = build_changes_query( + 'proj.ds.tbl', ts_start, ts_end, 'APPENDS', row_filter=None) + self.assertNotIn('WHERE', sql) + + def test_columns_and_row_filter(self): + ts_start = datetime.datetime( + 2025, 1, 1, tzinfo=datetime.timezone.utc).timestamp() + ts_end = datetime.datetime( + 2025, 1, 2, 
tzinfo=datetime.timezone.utc).timestamp() + sql = build_changes_query( + 'proj.ds.tbl', + ts_start, + ts_end, + 'CHANGES', + columns=['id', 'name'], + row_filter='id > 100') + self.assertIn('SELECT id, name, _CHANGE_TYPE AS', sql) + self.assertNotIn('EXCEPT', sql) + self.assertIn('WHERE id > 100', sql) + + def test_colon_normalized_to_dot(self): + ts_start = datetime.datetime( + 2025, 1, 1, tzinfo=datetime.timezone.utc).timestamp() + ts_end = datetime.datetime( + 2025, 1, 2, tzinfo=datetime.timezone.utc).timestamp() + sql = build_changes_query( + 'myproject:mydataset.mytable', ts_start, ts_end, 'APPENDS') + self.assertIn('TABLE `myproject.mydataset.mytable`', sql) + # Verify colon in table ref is normalized (timestamps contain colons) + table_part = sql.split('TABLE')[1].split(',')[0] + self.assertNotIn(':', table_part) + + +class ComputeRangesTest(unittest.TestCase): + """Tests for compute_ranges().""" + def test_appends_single_range(self): + """APPENDS has no chunking — returns single range even for multi-day.""" + start = 0.0 + end = 86400.0 * 5 # 5 days + ranges = compute_ranges(start, end, 'APPENDS') + self.assertEqual(len(ranges), 1) + self.assertEqual(ranges[0], (start, end)) + + def test_changes_single_day(self): + """CHANGES within 1 day: single range.""" + start = 0.0 + end = 86400.0 # exactly 1 day + ranges = compute_ranges(start, end, 'CHANGES') + self.assertEqual(len(ranges), 1) + self.assertEqual(ranges[0], (start, end)) + + def test_changes_multi_day(self): + """CHANGES spanning 3 days: should chunk into 3 ranges.""" + start = 0.0 + end = 86400.0 * 3 # 3 days + ranges = compute_ranges(start, end, 'CHANGES') + self.assertEqual(len(ranges), 3) + # Verify no gaps + for i in range(len(ranges) - 1): + self.assertEqual(ranges[i][1], ranges[i + 1][0]) + self.assertEqual(ranges[0][0], start) + self.assertEqual(ranges[-1][1], end) + + def test_changes_partial_day(self): + """CHANGES spanning 1.5 days: should chunk into 2 ranges.""" + start = 0.0 + end = 
86400.0 * 1.5 + ranges = compute_ranges(start, end, 'CHANGES') + self.assertEqual(len(ranges), 2) + self.assertEqual(ranges[0], (0.0, 86400.0)) + self.assertEqual(ranges[1], (86400.0, end)) + + def test_zero_range(self): + """end <= start: empty list.""" + self.assertEqual(compute_ranges(100.0, 100.0, 'CHANGES'), []) + self.assertEqual(compute_ranges(100.0, 50.0, 'CHANGES'), []) + self.assertEqual(compute_ranges(100.0, 100.0, 'APPENDS'), []) + + def test_exact_day_boundary(self): + """Exactly 2 days: should produce 2 chunks.""" + start = 0.0 + end = 86400.0 * 2 + ranges = compute_ranges(start, end, 'CHANGES') + self.assertEqual(len(ranges), 2) + + +class TableKeyTest(unittest.TestCase): + """Tests for _table_key().""" + def test_conversion(self): + ref = bigquery.TableReference( + projectId='proj', datasetId='ds', tableId='tbl') + self.assertEqual(_table_key(ref), 'proj.ds.tbl') + + +class ValidationTest(unittest.TestCase): + """Tests for ReadBigQueryChangeHistory validation.""" + def test_invalid_change_function(self): + with self.assertRaises(ValueError): + ReadBigQueryChangeHistory(table='p:d.t', change_function='INVALID') + + def test_invalid_poll_interval(self): + with self.assertRaises(ValueError): + ReadBigQueryChangeHistory(table='p:d.t', poll_interval_sec=0) + with self.assertRaises(ValueError): + ReadBigQueryChangeHistory(table='p:d.t', poll_interval_sec=-1) + + def test_default_buffer(self): + t = ReadBigQueryChangeHistory(table='p:d.t', change_function='CHANGES') + self.assertEqual(t._buffer_sec, 15) + t = ReadBigQueryChangeHistory(table='p:d.t', change_function='APPENDS') + self.assertEqual(t._buffer_sec, 15) + + +if __name__ == '__main__': + unittest.main() From c5a5aca90943f63ffbb6914f25f7002b6a91187c Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 27 Feb 2026 12:59:28 -0500 Subject: [PATCH 02/14] comments and race --- .../io/gcp/bigquery_change_history.py | 38 ++++++------ .../io/gcp/bigquery_change_history_it_test.py | 61 +------------------ 2 
files changed, 21 insertions(+), 78 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py index de8f17dbd92a..e848d9e27af6 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py @@ -70,6 +70,11 @@ from apache_beam.utils.timestamp import MAX_TIMESTAMP from apache_beam.utils.timestamp import Timestamp +try: + from apitools.base.py.exceptions import HttpError +except ImportError: + HttpError = None # type: ignore + try: from google.cloud import bigquery_storage_v1 as bq_storage except ImportError: @@ -101,10 +106,6 @@ # (e.g. pipeline crash before cleanup runs). _DEFAULT_TABLE_EXPIRATION_MS = 24 * 60 * 60 * 1000 -# ============================================================================= -# Helpers and data classes -# ============================================================================= - @dataclasses.dataclass class _QueryResult: @@ -118,9 +119,9 @@ class _QueryResult: The Read SDF uses range_start to set an initial watermark hold so the runner doesn't advance the watermark past the data's timestamps. """ - temp_table_ref: Optional[bigquery.TableReference] = None - range_start: float = 0.0 - range_end: float = 0.0 + temp_table_ref: bigquery.TableReference + range_start: float + range_end: float @dataclasses.dataclass @@ -436,9 +437,10 @@ def initial_restriction(self, element: _PollConfig) -> OffsetRange: def create_tracker( self, restriction: OffsetRange) -> _NonSplittableOffsetTracker: - # When stop_time has passed, return an empty-range tracker so - # try_claim() fails immediately and check_done() passes (empty range). - if time.time() >= self._stop_time: + # Guarantee at least one poll cycle: restriction.start == 0 on the first + # invocation (from initial_restriction). After the first try_claim(0) + + # defer_remainder, subsequent invocations arrive with start >= 1. 
+ if restriction.start > 0 and time.time() >= self._stop_time: _LOGGER.info( '[Poll] create_tracker: stop_time reached, ' 'returning empty range to terminate SDF') @@ -471,10 +473,6 @@ def _emit_query_ranges( now: float, watermark_estimator: _PollWatermarkEstimator) -> Iterable[_QueryRange]: """Compute and yield _QueryRange elements, advancing estimator state.""" - if self._stop_time != MAX_TIMESTAMP and now >= self._stop_time: - _LOGGER.info('[Poll] Stop time reached') - return - ranges = compute_ranges(start_ts, end_ts, self._change_function) _LOGGER.info( '[Poll] %d chunks for [%s, %s)', @@ -506,7 +504,7 @@ def process( now = time.time() start_ts = watermark_estimator.poll_cursor() - end_ts = now - self._buffer_sec + end_ts = min(now - self._buffer_sec, self._stop_time) defer_to = self._next_poll_time(start_ts, now) if defer_to is not None: @@ -581,7 +579,9 @@ def _get_or_create_temp_dataset(self) -> None: '[Query] Temp dataset %s.%s already exists', self._project, self._temp_dataset) - except Exception: + except HttpError as e: + if e.status_code != 404: + raise _LOGGER.info( '[Query] Creating temp dataset %s.%s with ' '24h table expiration, location=%s', @@ -863,9 +863,9 @@ def _create_read_session(self, table_ref: bigquery.TableReference) -> Any: requested_session = bq_storage.types.ReadSession() requested_session.table = table_path requested_session.data_format = bq_storage.types.DataFormat.ARROW - requested_session.read_options \ - .arrow_serialization_options.buffer_compression = \ - bq_storage.types.ArrowSerializationOptions.CompressionCodec.LZ4_FRAME + read_options = requested_session.read_options + read_options.arrow_serialization_options.buffer_compression = ( + bq_storage.types.ArrowSerializationOptions.CompressionCodec.LZ4_FRAME) session = self._storage_client.create_read_session( parent=f'projects/{table_ref.projectId}', diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py 
b/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py index 0c93a8347149..238fecbc3ebb 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py @@ -418,14 +418,14 @@ def test_public_api_reads_inserted_rows(self): """ReadBigQueryChangeHistory PTransform with polling SDF.""" table_str = f'{self.project}:{self.dataset}.{self.test_table_id}' start_time = self.insert_time - 120 # 2 min before insert - stop_time = time.time() + 5 # Short run for test + stop_time = time.time() with beam.Pipeline(argv=self.args) as p: rows = ( p | ReadBigQueryChangeHistory( table=table_str, - poll_interval_sec=60, + poll_interval_sec=10, start_time=start_time, stop_time=stop_time, change_function='APPENDS', @@ -463,63 +463,6 @@ def check_rows(actual): assert_that(rows, check_rows) -def insert_test_rows(project, dataset, table, n, bq_wrapper=None): - """Insert n test rows into a BigQuery table. - - Args: - project: GCP project ID. - dataset: BigQuery dataset ID. - table: BigQuery table ID. - n: Number of rows to insert. - bq_wrapper: Optional BigQueryWrapper instance (creates one if None). - - Returns: - List of inserted row dicts. - """ - if bq_wrapper is None: - bq_wrapper = BigQueryWrapper() - rows = [{'id': i, 'name': f'row_{i}', 'value': float(i)} for i in range(n)] - bq_wrapper.insert_rows(project, dataset, table, rows) - return rows - - -def create_change_history_table(project, dataset, table_id, bq_wrapper=None): - """Create a table with enable_change_history via DDL. - - Args: - project: GCP project ID. - dataset: BigQuery dataset ID. - table_id: Table name to create. - bq_wrapper: Optional BigQueryWrapper instance. - - Returns: - bigquery.TableReference for the created table. 
- """ - if bq_wrapper is None: - bq_wrapper = BigQueryWrapper() - - ddl = ( - f'CREATE TABLE IF NOT EXISTS ' - f'`{project}.{dataset}.{table_id}` ' - f'(id INT64, name STRING, value FLOAT64) ' - f'OPTIONS (enable_change_history = true)') - - job_id = f'beam_ch_ddl_{uuid.uuid4().hex[:8]}' - reference = bigquery.JobReference(jobId=job_id, projectId=project) - request = bigquery.BigqueryJobsInsertRequest( - projectId=project, - job=bigquery.Job( - configuration=bigquery.JobConfiguration( - query=bigquery.JobConfigurationQuery( - query=ddl, useLegacySql=False)), - jobReference=reference)) - response = bq_wrapper._start_job(request) - bq_wrapper.wait_for_bq_job(response.jobReference, sleep_duration_sec=2) - - return bigquery.TableReference( - projectId=project, datasetId=dataset, tableId=table_id) - - if __name__ == '__main__': logging.basicConfig(level=logging.INFO) unittest.main() From f7c3fa3d410dd86e5357de0122089c15360fab63 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 27 Feb 2026 13:20:05 -0500 Subject: [PATCH 03/14] typehint --- sdks/python/apache_beam/io/gcp/bigquery_change_history.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py index e848d9e27af6..44866158b469 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py @@ -60,9 +60,9 @@ import apache_beam as beam from apache_beam.io.gcp import bigquery_tools from apache_beam.io.gcp.internal.clients import bigquery +from apache_beam.io.iobase import WatermarkEstimator from apache_beam.io.restriction_trackers import OffsetRange from apache_beam.io.restriction_trackers import OffsetRestrictionTracker -from apache_beam.io.iobase import WatermarkEstimator from apache_beam.io.watermark_estimators import ManualWatermarkEstimator from apache_beam.metrics import Metrics from apache_beam.transforms.core 
import WatermarkEstimatorProvider @@ -119,7 +119,7 @@ class _QueryResult: The Read SDF uses range_start to set an initial watermark hold so the runner doesn't advance the watermark past the data's timestamps. """ - temp_table_ref: bigquery.TableReference + temp_table_ref: 'bigquery.TableReference' range_start: float range_end: float From fcc2d62b435cbfe7b76065a3fcfc44ae994f3537 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 27 Feb 2026 13:54:00 -0500 Subject: [PATCH 04/14] lint and checks --- .../apache_beam/io/gcp/bigquery_change_history.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py index 44866158b469..7be82ad60c1d 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py @@ -298,7 +298,7 @@ def create_watermark_estimator( return _PollWatermarkEstimator(estimator_state) -def _table_key(table_ref: bigquery.TableReference) -> str: +def _table_key(table_ref: 'bigquery.TableReference') -> str: """Convert a TableReference to a 'project.dataset.table' string.""" return f'{table_ref.projectId}.{table_ref.datasetId}.{table_ref.tableId}' @@ -853,7 +853,7 @@ def process( total_streams, )) - def _create_read_session(self, table_ref: bigquery.TableReference) -> Any: + def _create_read_session(self, table_ref: 'bigquery.TableReference') -> Any: """Create a BigQuery Storage ReadSession for the given table.""" table_path = ( f'projects/{table_ref.projectId}/' @@ -1076,10 +1076,11 @@ def __init__( raise ValueError( f"change_function must be 'CHANGES' or 'APPENDS', " f"got '{change_function}'") - if poll_interval_sec <= 0: + if poll_interval_sec <= 15: raise ValueError( - f'poll_interval_sec must be positive, got {poll_interval_sec}') - + f'poll_interval_sec must be >= 15, got {poll_interval_sec}') + if buffer_sec < 0: + raise ValueError(f'buffer_sec 
must be >= 10, got {buffer_sec}') self._table = table self._poll_interval_sec = poll_interval_sec self._start_time = start_time From 7feb0d3ef4ec2546cd3470be6c21e85fbdfce895 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 27 Feb 2026 14:01:31 -0500 Subject: [PATCH 05/14] fix test. --- sdks/python/apache_beam/io/gcp/bigquery_change_history.py | 2 +- .../apache_beam/io/gcp/bigquery_change_history_it_test.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py index 7be82ad60c1d..b6c13cef6431 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py @@ -1076,7 +1076,7 @@ def __init__( raise ValueError( f"change_function must be 'CHANGES' or 'APPENDS', " f"got '{change_function}'") - if poll_interval_sec <= 15: + if poll_interval_sec < 15: raise ValueError( f'poll_interval_sec must be >= 15, got {poll_interval_sec}') if buffer_sec < 0: diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py index 238fecbc3ebb..e99f03564ffe 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py @@ -418,18 +418,18 @@ def test_public_api_reads_inserted_rows(self): """ReadBigQueryChangeHistory PTransform with polling SDF.""" table_str = f'{self.project}:{self.dataset}.{self.test_table_id}' start_time = self.insert_time - 120 # 2 min before insert - stop_time = time.time() + stop_time = time.time() + 15 with beam.Pipeline(argv=self.args) as p: rows = ( p | ReadBigQueryChangeHistory( table=table_str, - poll_interval_sec=10, + poll_interval_sec=15, start_time=start_time, stop_time=stop_time, change_function='APPENDS', - buffer_sec=0, + buffer_sec=10, project=self.project, 
temp_dataset=self.temp_dataset)) From e918a363a068d8e14d8f8e60f6cd3c47bcb3eb52 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 27 Feb 2026 16:19:04 -0500 Subject: [PATCH 06/14] skiptests --- .../apache_beam/io/gcp/bigquery_change_history_test.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history_test.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history_test.py index f5484946ced7..e4fe77a40249 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_change_history_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history_test.py @@ -33,6 +33,12 @@ from apache_beam.io.gcp.bigquery_change_history import compute_ranges from apache_beam.io.gcp.internal.clients import bigquery +# Protect against environments where apitools is not available. +try: + from apitools.base.py.exceptions import HttpError +except ImportError: + HttpError = None # type: ignore + class BuildChangesQueryTest(unittest.TestCase): """Tests for build_changes_query().""" @@ -181,6 +187,7 @@ def test_exact_day_boundary(self): self.assertEqual(len(ranges), 2) +@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class TableKeyTest(unittest.TestCase): """Tests for _table_key().""" def test_conversion(self): @@ -189,6 +196,7 @@ def test_conversion(self): self.assertEqual(_table_key(ref), 'proj.ds.tbl') +@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class ValidationTest(unittest.TestCase): """Tests for ReadBigQueryChangeHistory validation.""" def test_invalid_change_function(self): From a88de6ef1e7c06df76ec64e09907f5342a2b0728 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 5 Mar 2026 09:09:22 -0500 Subject: [PATCH 07/14] use bigquery current_timestamp for query end_ts --- .../io/gcp/bigquery_change_history.py | 94 ++++++++++--------- .../io/gcp/bigquery_change_history_it_test.py | 6 +- .../apache_beam/io/gcp/bigquery_tools.py | 10 +- 3 files changed, 61 insertions(+), 49 
deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py index b6c13cef6431..1b9cff1abd11 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py @@ -67,14 +67,10 @@ from apache_beam.metrics import Metrics from apache_beam.transforms.core import WatermarkEstimatorProvider from apache_beam.transforms.window import TimestampedValue +from apache_beam.utils import retry from apache_beam.utils.timestamp import MAX_TIMESTAMP from apache_beam.utils.timestamp import Timestamp -try: - from apitools.base.py.exceptions import HttpError -except ImportError: - HttpError = None # type: ignore - try: from google.cloud import bigquery_storage_v1 as bq_storage except ImportError: @@ -432,6 +428,36 @@ def __init__( self._poll_interval_sec = poll_interval_sec self._location = location + def setup(self) -> None: + self._bq_wrapper = bigquery_tools.BigQueryWrapper() + if self._location is None: + table_ref = bigquery_tools.parse_table_reference( + self._table, project=self._project) + self._location = self._bq_wrapper.get_table_location( + table_ref.projectId, table_ref.datasetId, table_ref.tableId) + _LOGGER.info( + '[Poll] Inferred location=%s from source table %s', + self._location, + self._table) + + @retry.with_exponential_backoff( + num_retries=3, + retry_filter=retry.retry_on_server_errors_and_timeout_filter) + def _get_bq_timestamp(self) -> float: + """Query BigQuery for the current server timestamp. + + Uses BQ's CURRENT_TIMESTAMP instead of the local clock to avoid + data loss from clock skew between the worker VM and BigQuery. 
+ """ + request = bigquery.BigqueryJobsQueryRequest( + projectId=self._project, + queryRequest=bigquery.QueryRequest( + query='SELECT UNIX_MICROS(CURRENT_TIMESTAMP()) AS ts', + useLegacySql=False, + location=self._location)) + response = self._bq_wrapper.client.jobs.Query(request) + return int(response.rows[0].f[0].v.string_value) / 1e6 + def initial_restriction(self, element: _PollConfig) -> OffsetRange: return OffsetRange(0, sys.maxsize) @@ -470,7 +496,6 @@ def _emit_query_ranges( self, start_ts: float, end_ts: float, - now: float, watermark_estimator: _PollWatermarkEstimator) -> Iterable[_QueryRange]: """Compute and yield _QueryRange elements, advancing estimator state.""" ranges = compute_ranges(start_ts, end_ts, self._change_function) @@ -504,18 +529,24 @@ def process( now = time.time() start_ts = watermark_estimator.poll_cursor() - end_ts = min(now - self._buffer_sec, self._stop_time) defer_to = self._next_poll_time(start_ts, now) if defer_to is not None: restriction_tracker.defer_remainder(defer_to) return + # Use BQ server time instead of local clock to avoid data loss + # from clock skew between the worker VM and BigQuery. 
+ bq_now = self._get_bq_timestamp() + end_ts = min(bq_now - self._buffer_sec, self._stop_time) + _LOGGER.info( - '[Poll] Polling: start_ts=%s, end_ts=%s, watermark=%s', + '[Poll] Polling: start_ts=%s, end_ts=%s, watermark=%s, ' + 'clock_skew=%.3fs', _utc(start_ts), _utc(end_ts), - _utc(watermark_estimator.current_watermark())) + _utc(watermark_estimator.current_watermark()), + bq_now - now) current_index = restriction_tracker.current_restriction().start @@ -524,8 +555,7 @@ def process( restriction_tracker.defer_remainder( Timestamp.of(now + self._poll_interval_sec)) - yield from self._emit_query_ranges( - start_ts, end_ts, now, watermark_estimator) + yield from self._emit_query_ranges(start_ts, end_ts, watermark_estimator) class _ExecuteQueryFn(beam.DoFn): @@ -563,40 +593,11 @@ def setup(self) -> None: '[Query] Inferred location=%s from source table %s', self._location, self._table) - self._get_or_create_temp_dataset() - - def _get_or_create_temp_dataset(self) -> None: - """Create the temp dataset if it doesn't exist. - - Sets defaultTableExpirationMs so orphaned temp tables are automatically - garbage-collected by BigQuery if the pipeline crashes before cleanup. 
- """ - try: - self._bq_wrapper.client.datasets.Get( - bigquery.BigqueryDatasetsGetRequest( - projectId=self._project, datasetId=self._temp_dataset)) - _LOGGER.info( - '[Query] Temp dataset %s.%s already exists', - self._project, - self._temp_dataset) - except HttpError as e: - if e.status_code != 404: - raise - _LOGGER.info( - '[Query] Creating temp dataset %s.%s with ' - '24h table expiration, location=%s', - self._project, - self._temp_dataset, - self._location) - dataset = bigquery.Dataset( - datasetReference=bigquery.DatasetReference( - projectId=self._project, datasetId=self._temp_dataset)) - if self._location is not None: - dataset.location = self._location - dataset.defaultTableExpirationMs = _DEFAULT_TABLE_EXPIRATION_MS - self._bq_wrapper.client.datasets.Insert( - bigquery.BigqueryDatasetsInsertRequest( - projectId=self._project, dataset=dataset)) + self._bq_wrapper.get_or_create_dataset( + self._project, + self._temp_dataset, + location=self._location, + default_table_expiration_ms=_DEFAULT_TABLE_EXPIRATION_MS) def process(self, qr: _QueryRange) -> Iterable[_QueryResult]: """Execute the BQ query described by a _QueryRange and yield _QueryResult. 
@@ -1145,7 +1146,8 @@ def expand(self, pbegin: beam.pvalue.PBegin) -> beam.PCollection: buffer_sec=self._buffer_sec, start_time=start_time, stop_time=stop_time, - poll_interval_sec=self._poll_interval_sec))) + poll_interval_sec=self._poll_interval_sec, + location=self._location))) # CommitQueryResults: Reshuffle commits _QueryResult (temp table ref) # so that if the Read SDF retries, it re-reads the existing temp table diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py index e99f03564ffe..0834e09f7ffc 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py @@ -329,7 +329,8 @@ def test_poll_emits_query_ranges(self): buffer_sec=0, start_time=start_time, stop_time=time.time() + 5, - poll_interval_sec=60) + poll_interval_sec=30, + location=self.location) with beam.Pipeline(argv=self.args) as p: ranges = (p | beam.Create([config]) | beam.ParDo(poll_sdf)) @@ -431,7 +432,8 @@ def test_public_api_reads_inserted_rows(self): change_function='APPENDS', buffer_sec=10, project=self.project, - temp_dataset=self.temp_dataset)) + temp_dataset=self.temp_dataset, + location=self.location)) def check_rows(actual): assert len(actual) == 3, f'Expected 3 rows, got {len(actual)}' diff --git a/sdks/python/apache_beam/io/gcp/bigquery_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_tools.py index ddab941f9278..1a7a07706a39 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_tools.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_tools.py @@ -842,7 +842,13 @@ def _create_table( num_retries=MAX_RETRIES, retry_filter=retry.retry_on_server_errors_and_timeout_filter) def get_or_create_dataset( - self, project_id, dataset_id, location=None, labels=None, kms_key=None): + self, + project_id, + dataset_id, + location=None, + labels=None, + kms_key=None, + default_table_expiration_ms=None): # Check if dataset 
already exists otherwise create it try: dataset = self.client.datasets.Get( @@ -868,6 +874,8 @@ def get_or_create_dataset( if kms_key is not None: dataset.defaultEncryptionConfiguration = ( _build_dataset_encryption_config(kms_key)) + if default_table_expiration_ms is not None: + dataset.defaultTableExpirationMs = default_table_expiration_ms request = bigquery.BigqueryDatasetsInsertRequest( projectId=project_id, dataset=dataset) response = self.client.datasets.Insert(request) From 127699c4289b01e33ff909e8ef85bc7251b60808 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 5 Mar 2026 09:26:05 -0500 Subject: [PATCH 08/14] advance poll cursor monotonically --- .../python/apache_beam/io/gcp/bigquery_change_history.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py index 1b9cff1abd11..1d94a688bced 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py @@ -269,8 +269,13 @@ def set_watermark(self, timestamp: Timestamp) -> None: self._watermark_hold = timestamp def advance_poll_cursor(self, end_ts: float) -> None: - """Record end_ts so the next poll starts from here.""" - self._last_end_ts = end_ts + """Record end_ts so the next poll starts from here. + + Only advances forward: if end_ts is earlier than the current cursor + (e.g. BQ clock regression), the cursor stays put so the next poll + doesn't re-query an already-covered range. 
+ """ + self._last_end_ts = max(self._last_end_ts, end_ts) def poll_cursor(self) -> float: """Return the start_ts for the next poll.""" From 60cd4342d088f18043463c06ae991cfc2f456dec Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 6 Mar 2026 10:47:21 -0500 Subject: [PATCH 09/14] use beam timestamps / remove float usage --- .../io/gcp/bigquery_change_history.py | 187 +++++++++--------- .../io/gcp/bigquery_change_history_it_test.py | 32 +-- .../io/gcp/bigquery_change_history_test.py | 122 ++++++------ 3 files changed, 178 insertions(+), 163 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py index 1d94a688bced..722d07e4e41f 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py @@ -55,7 +55,6 @@ from typing import List from typing import Optional from typing import Tuple -from typing import Union import apache_beam as beam from apache_beam.io.gcp import bigquery_tools @@ -68,6 +67,7 @@ from apache_beam.transforms.core import WatermarkEstimatorProvider from apache_beam.transforms.window import TimestampedValue from apache_beam.utils import retry +from apache_beam.utils.timestamp import Duration from apache_beam.utils.timestamp import MAX_TIMESTAMP from apache_beam.utils.timestamp import Timestamp @@ -85,8 +85,8 @@ __all__ = ['ReadBigQueryChangeHistory'] -# Max time range for CHANGES() queries: 1 day in seconds. -_MAX_CHANGES_RANGE_SEC = 86400 +# Max time range for CHANGES() queries: 1 day. +_MAX_CHANGES_RANGE = Duration(seconds=86400) # Side output tag for cleanup signals between the Read SDF and Cleanup DoFn. _CLEANUP_TAG = 'cleanup' @@ -111,36 +111,38 @@ class _QueryResult: pointing to the temp table containing query results. The Read SDF reads rows from that temp table via the Storage Read API. - range_start/range_end define the time window this query covers. 
- The Read SDF uses range_start to set an initial watermark hold so the runner - doesn't advance the watermark past the data's timestamps. + range_start/range_end define the time window this query covers as Beam + Timestamps (int microseconds internally). The Read SDF uses range_start + to set an initial watermark hold so the runner doesn't advance the + watermark past the data's timestamps. """ temp_table_ref: 'bigquery.TableReference' - range_start: float - range_end: float + range_start: Timestamp + range_end: Timestamp @dataclasses.dataclass class _PollConfig: """Input element for the polling SDF. - Only contains start_time, which _PollWatermarkEstimatorProvider uses - to initialize the watermark hold. All other config is passed via - _PollChangeHistoryFn.__init__. + Only contains start_time (Beam Timestamp), which + _PollWatermarkEstimatorProvider uses to initialize the watermark hold. + All other config is passed via _PollChangeHistoryFn.__init__. """ - start_time: float + start_time: Timestamp @dataclasses.dataclass class _QueryRange: """Lightweight instruction emitted by the polling SDF. - Contains only the time range to query. Static config (table, project, - etc.) is held by _ExecuteQueryFn which receives these after a Reshuffle - commit boundary, preventing duplicate queries on SDF re-dispatch. + Contains only the time range to query as Beam Timestamps (int microseconds + internally). Static config (table, project, etc.) is held by + _ExecuteQueryFn which receives these after a Reshuffle commit boundary, + preventing duplicate queries on SDF re-dispatch. """ - chunk_start: float - chunk_end: float + chunk_start: Timestamp + chunk_end: Timestamp class _StreamRestriction: @@ -239,16 +241,17 @@ class _PollWatermarkEstimator(WatermarkEstimator): the earliest data timestamp emitted by the current poll. This prevents downstream stages from seeing data as late. - The poll cursor (last_end_ts) tracks where the next poll should start. 
+ The poll cursor (last_end) tracks where the next poll should start. This is separate from the watermark so we can hold the watermark back at start_ts while still advancing the poll cursor to end_ts. - State is checkpointed as (watermark_hold, last_end_ts) so + All timestamps are Beam Timestamps (int microseconds internally). + + State is checkpointed as (watermark_hold, last_end) so both values survive SDF re-dispatch. """ - def __init__(self, state: Tuple[Timestamp, float]) -> None: - # state is (watermark_hold: Timestamp, last_end_ts: float) - self._watermark_hold, self._last_end_ts = state + def __init__(self, state: Tuple[Timestamp, Timestamp]) -> None: + self._watermark_hold, self._last_end = state def observe_timestamp(self, timestamp: Timestamp) -> None: pass @@ -256,8 +259,8 @@ def observe_timestamp(self, timestamp: Timestamp) -> None: def current_watermark(self) -> Timestamp: return self._watermark_hold - def get_estimator_state(self) -> Tuple[Timestamp, float]: - return (self._watermark_hold, self._last_end_ts) + def get_estimator_state(self) -> Tuple[Timestamp, Timestamp]: + return (self._watermark_hold, self._last_end) def set_watermark(self, timestamp: Timestamp) -> None: if not isinstance(timestamp, Timestamp): @@ -268,18 +271,18 @@ def set_watermark(self, timestamp: Timestamp) -> None: 'Provided %s < current %s' % (timestamp, self._watermark_hold)) self._watermark_hold = timestamp - def advance_poll_cursor(self, end_ts: float) -> None: - """Record end_ts so the next poll starts from here. + def advance_poll_cursor(self, end: Timestamp) -> None: + """Record end so the next poll starts from here. - Only advances forward: if end_ts is earlier than the current cursor + Only advances forward: if end is earlier than the current cursor (e.g. BQ clock regression), the cursor stays put so the next poll doesn't re-query an already-covered range. 
""" - self._last_end_ts = max(self._last_end_ts, end_ts) + self._last_end = max(self._last_end, end) - def poll_cursor(self) -> float: - """Return the start_ts for the next poll.""" - return self._last_end_ts + def poll_cursor(self) -> Timestamp: + """Return the start Timestamp for the next poll.""" + return self._last_end class _PollWatermarkEstimatorProvider(WatermarkEstimatorProvider): @@ -290,12 +293,12 @@ class _PollWatermarkEstimatorProvider(WatermarkEstimatorProvider): """ def initial_estimator_state( self, element: _PollConfig, - restriction: OffsetRange) -> Tuple[Timestamp, float]: - return (Timestamp(element.start_time), element.start_time) + restriction: OffsetRange) -> Tuple[Timestamp, Timestamp]: + return (element.start_time, element.start_time) def create_watermark_estimator( self, estimator_state: Tuple[Timestamp, - float]) -> _PollWatermarkEstimator: + Timestamp]) -> _PollWatermarkEstimator: return _PollWatermarkEstimator(estimator_state) @@ -306,8 +309,8 @@ def _table_key(table_ref: 'bigquery.TableReference') -> str: def build_changes_query( table: str, - start_ts: float, - end_ts: float, + start: Timestamp, + end: Timestamp, change_function: str, change_type_column: str = 'change_type', change_timestamp_column: str = 'change_timestamp', @@ -317,8 +320,8 @@ def build_changes_query( Args: table: Table name as 'project.dataset.table' or 'project:dataset.table'. - start_ts: Start timestamp (float, seconds since epoch). Inclusive. - end_ts: End timestamp (float, seconds since epoch). Exclusive. + start: Start timestamp (Beam Timestamp). Inclusive. + end: End timestamp (Beam Timestamp). Exclusive. change_function: 'CHANGES' or 'APPENDS'. change_type_column: Output column name for _CHANGE_TYPE pseudo-column. 
change_timestamp_column: Output column name for _CHANGE_TIMESTAMP @@ -333,10 +336,8 @@ def build_changes_query( """ # Normalize 'project:dataset.table' to 'project.dataset.table' table = table.replace(':', '.') - start_iso = datetime.datetime.fromtimestamp( - start_ts, tz=datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%S.%fZ') - end_iso = datetime.datetime.fromtimestamp( - end_ts, tz=datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%S.%fZ') + start_iso = start.to_rfc3339() + end_iso = end.to_rfc3339() # Pseudo-columns (_CHANGE_TYPE, _CHANGE_TIMESTAMP) can't be written to # destination tables with their original names. Rename them so they can # be persisted to the temp table for Storage Read API reading. @@ -356,42 +357,39 @@ def build_changes_query( return f"{select} {from_clause}{where}" -def compute_ranges(start_ts: float, end_ts: float, - change_function: str) -> List[Tuple[float, float]]: - """Split [start_ts, end_ts) into query-safe chunks. +def compute_ranges(start: Timestamp, end: Timestamp, + change_function: str) -> List[Tuple[Timestamp, Timestamp]]: + """Split [start, end) into query-safe chunks. CHANGES() has a max 1-day range. APPENDS() has no limit. Args: - start_ts: Start timestamp (float, seconds since epoch). - end_ts: End timestamp (float, seconds since epoch). + start: Start Timestamp. Inclusive. + end: End Timestamp. Exclusive. change_function: 'CHANGES' or 'APPENDS'. Returns: - List of (start, end) float tuples. Empty if end_ts <= start_ts. + List of (start, end) Timestamp tuples. Empty if end <= start. 
""" - if end_ts <= start_ts: + if end <= start: return [] if change_function != 'CHANGES': - return [(start_ts, end_ts)] + return [(start, end)] # CHANGES: chunk into <=1-day ranges ranges = [] - current = start_ts - while current < end_ts: - chunk_end = min(current + _MAX_CHANGES_RANGE_SEC, end_ts) + current = start + while current < end: + chunk_end = min(current + _MAX_CHANGES_RANGE, end) ranges.append((current, chunk_end)) current = chunk_end return ranges -def _utc(ts: Union[float, Timestamp]) -> str: - """Format an epoch-seconds float or Timestamp as a UTC string.""" - if isinstance(ts, Timestamp): - ts = ts.seconds() - return datetime.datetime.fromtimestamp( - ts, tz=datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%S') +def _utc(ts: Timestamp) -> str: + """Format a Beam Timestamp as a concise UTC string for logging.""" + return ts.to_utc_datetime(has_tz=True).strftime('%Y-%m-%dT%H:%M:%S.%f') # ============================================================================= @@ -403,12 +401,15 @@ class _PollChangeHistoryFn(beam.DoFn, beam.transforms.core.RestrictionProvider): """SDF that periodically emits _QueryRange instructions. Uses defer_remainder() for poll timing and _PollWatermarkEstimator to - control the watermark. The watermark is initially held at start_time , then + control the watermark. The watermark is initially held at start_time, then advanced to start_ts of each poll. + All timestamps are Beam Timestamps (int microseconds internally). + Durations (buffer, poll_interval) are Beam Durations. + Derives start_ts from the poll cursor. On each poll: 1. start_ts = poll cursor (last end_ts, or start_time on first poll) - 2. end_ts = now - buffer_sec + 2. end_ts = bq_now - buffer 3. Computes query chunks, yields _QueryRange per chunk 4. Advances poll cursor to end_ts (for next poll's start) 5. 
Advances watermark to start_ts (earliest data in this poll) @@ -419,18 +420,18 @@ def __init__( table: str, project: str, change_function: str, - buffer_sec: float, - start_time: float, - stop_time: Union[float, Timestamp], - poll_interval_sec: float, + buffer: Duration, + start_time: Timestamp, + stop_time: Timestamp, + poll_interval: Duration, location: Optional[str] = None) -> None: self._table = table self._project = project self._change_function = change_function - self._buffer_sec = buffer_sec + self._buffer = buffer self._start_time = start_time self._stop_time = stop_time - self._poll_interval_sec = poll_interval_sec + self._poll_interval = poll_interval self._location = location def setup(self) -> None: @@ -448,9 +449,10 @@ def setup(self) -> None: @retry.with_exponential_backoff( num_retries=3, retry_filter=retry.retry_on_server_errors_and_timeout_filter) - def _get_bq_timestamp(self) -> float: + def _get_bq_timestamp(self) -> Timestamp: """Query BigQuery for the current server timestamp. + Returns a Beam Timestamp created from integer microseconds. Uses BQ's CURRENT_TIMESTAMP instead of the local clock to avoid data loss from clock skew between the worker VM and BigQuery. """ @@ -461,7 +463,7 @@ def _get_bq_timestamp(self) -> float: useLegacySql=False, location=self._location)) response = self._bq_wrapper.client.jobs.Query(request) - return int(response.rows[0].f[0].v.string_value) / 1e6 + return Timestamp(micros=int(response.rows[0].f[0].v.string_value)) def initial_restriction(self, element: _PollConfig) -> OffsetRange: return OffsetRange(0, sys.maxsize) @@ -471,7 +473,7 @@ def create_tracker( # Guarantee at least one poll cycle: restriction.start == 0 on the first # invocation (from initial_restriction). After the first try_claim(0) + # defer_remainder, subsequent invocations arrive with start >= 1. 
- if restriction.start > 0 and time.time() >= self._stop_time: + if restriction.start > 0 and time.time() >= float(self._stop_time): _LOGGER.info( '[Poll] create_tracker: stop_time reached, ' 'returning empty range to terminate SDF') @@ -490,17 +492,18 @@ def split(self, element: _PollConfig, def truncate(self, element: _PollConfig, restriction: OffsetRange) -> None: return None - def _next_poll_time(self, start_ts: float, now: float) -> Optional[Timestamp]: + def _next_poll_time(self, start_ts: Timestamp, + now: float) -> Optional[Timestamp]: """Return a Timestamp to defer to, or None if we should poll now.""" - earliest = start_ts + self._buffer_sec + self._poll_interval_sec - if now < earliest: - return Timestamp.of(earliest) + earliest = start_ts + self._buffer + self._poll_interval + if now < float(earliest): + return earliest return None def _emit_query_ranges( self, - start_ts: float, - end_ts: float, + start_ts: Timestamp, + end_ts: Timestamp, watermark_estimator: _PollWatermarkEstimator) -> Iterable[_QueryRange]: """Compute and yield _QueryRange elements, advancing estimator state.""" ranges = compute_ranges(start_ts, end_ts, self._change_function) @@ -512,7 +515,7 @@ def _emit_query_ranges( Metrics.counter('BigQueryChangeHistory', 'polls').inc() watermark_estimator.advance_poll_cursor(end_ts) - watermark_estimator.set_watermark(Timestamp(start_ts)) + watermark_estimator.set_watermark(start_ts) _LOGGER.info( '[Poll] Watermark=%s (start_ts), cursor=%s (end_ts)', _utc(start_ts), @@ -520,8 +523,7 @@ def _emit_query_ranges( for chunk_start, chunk_end in ranges: yield TimestampedValue( - _QueryRange(chunk_start=chunk_start, chunk_end=chunk_end), - Timestamp(start_ts)) + _QueryRange(chunk_start=chunk_start, chunk_end=chunk_end), start_ts) @beam.DoFn.unbounded_per_element() def process( @@ -543,22 +545,21 @@ def process( # Use BQ server time instead of local clock to avoid data loss # from clock skew between the worker VM and BigQuery. 
bq_now = self._get_bq_timestamp() - end_ts = min(bq_now - self._buffer_sec, self._stop_time) + end_ts = min(bq_now - self._buffer, self._stop_time) _LOGGER.info( - '[Poll] Polling: start_ts=%s, end_ts=%s, watermark=%s, ' + '[Poll] Polling: start=%s, end=%s, watermark=%s, ' 'clock_skew=%.3fs', _utc(start_ts), _utc(end_ts), _utc(watermark_estimator.current_watermark()), - bq_now - now) + float(bq_now) - now) current_index = restriction_tracker.current_restriction().start if not restriction_tracker.try_claim(current_index): return - restriction_tracker.defer_remainder( - Timestamp.of(now + self._poll_interval_sec)) + restriction_tracker.defer_remainder(Timestamp.of(now) + self._poll_interval) yield from self._emit_query_ranges(start_ts, end_ts, watermark_estimator) @@ -677,7 +678,7 @@ class _CDCWatermarkEstimatorProvider(WatermarkEstimatorProvider): def initial_estimator_state( self, element: _QueryResult, restriction: _StreamRestriction) -> Timestamp: - return Timestamp(element.range_start) + return element.range_start def create_watermark_estimator( self, estimator_state: Timestamp) -> ManualWatermarkEstimator: @@ -836,7 +837,7 @@ def process( # Advance watermark to range_end after reading all streams. The # initial hold was set to range_start by _CDCWatermarkEstimatorProvider. 
- watermark_estimator.set_watermark(Timestamp(element.range_end)) + watermark_estimator.set_watermark(element.range_end) _LOGGER.info( '[Read] Watermark advanced to %s (range_end) for %s', _utc(element.range_end), @@ -1113,8 +1114,12 @@ def expand(self, pbegin: beam.pvalue.PBegin) -> beam.PCollection: 'project must be specified either in ReadBigQueryChangeHistory ' 'or in pipeline options (--project)') - start_time = self._start_time or time.time() - stop_time = self._stop_time or MAX_TIMESTAMP + start_time = Timestamp(self._start_time or time.time()) + stop_time = ( + Timestamp(self._stop_time) + if self._stop_time is not None else MAX_TIMESTAMP) + buffer = Duration(seconds=self._buffer_sec) + poll_interval = Duration(seconds=self._poll_interval_sec) temp_dataset = self._temp_dataset if temp_dataset is None: @@ -1137,7 +1142,7 @@ def expand(self, pbegin: beam.pvalue.PBegin) -> beam.PCollection: # The SDF uses defer_remainder() for poll timing and # _PollWatermarkEstimator to hold the watermark at data timestamps. # On the first invocation it handles the full historical range - # [start_time, now - buffer_sec) in a single poll. + # [start_time, now - buffer) in a single poll. 
config = _PollConfig(start_time=start_time) query_ranges = ( @@ -1148,10 +1153,10 @@ def expand(self, pbegin: beam.pvalue.PBegin) -> beam.PCollection: table=self._table, project=project, change_function=self._change_function, - buffer_sec=self._buffer_sec, + buffer=buffer, start_time=start_time, stop_time=stop_time, - poll_interval_sec=self._poll_interval_sec, + poll_interval=poll_interval, location=self._location))) # CommitQueryResults: Reshuffle commits _QueryResult (temp table ref) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py index 0834e09f7ffc..5c24c325835d 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py @@ -39,6 +39,8 @@ from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to +from apache_beam.utils.timestamp import Duration +from apache_beam.utils.timestamp import Timestamp _LOGGER = logging.getLogger(__name__) @@ -215,6 +217,7 @@ def test_reads_rows_from_temp_table(self): """SDF reads rows from a real temp table via Storage Read API.""" table_id = f'sdf_test_{secrets.token_hex(4)}' now = time.time() + now_ts = Timestamp(now) rows = [ { 'id': 1, @@ -242,7 +245,9 @@ def test_reads_rows_from_temp_table(self): table_id, rows, schema=self._READ_SCHEMA) query_result = _QueryResult( - temp_table_ref=table_ref, range_start=now - 60, range_end=now + 60) + temp_table_ref=table_ref, + range_start=now_ts - 60, + range_end=now_ts + 60) with beam.Pipeline(argv=self.args) as p: outputs = ( @@ -261,6 +266,7 @@ def test_cleanup_signal_emitted(self): """SDF emits cleanup signal with correct counts.""" table_id = f'sdf_cleanup_{secrets.token_hex(4)}' now = time.time() + now_ts = Timestamp(now) rows = [{ 'id': 1, 'name': 'a', @@ -272,7 +278,9 @@ def 
test_cleanup_signal_emitted(self): table_id, rows, schema=self._READ_SCHEMA) query_result = _QueryResult( - temp_table_ref=table_ref, range_start=now - 60, range_end=now + 60) + temp_table_ref=table_ref, + range_start=now_ts - 60, + range_end=now_ts + 60) with beam.Pipeline(argv=self.args) as p: outputs = ( @@ -293,12 +301,14 @@ def test_cleanup_signal_emitted(self): def test_empty_table(self): """Empty table produces 0 rows and cleanup signal.""" table_id = f'sdf_empty_{secrets.token_hex(4)}' - now = time.time() + now_ts = Timestamp(time.time()) table_ref = self._create_temp_table_with_data( table_id, [], schema=self._READ_SCHEMA) query_result = _QueryResult( - temp_table_ref=table_ref, range_start=now - 60, range_end=now + 60) + temp_table_ref=table_ref, + range_start=now_ts - 60, + range_end=now_ts + 60) with beam.Pipeline(argv=self.args) as p: outputs = ( @@ -318,7 +328,7 @@ class PollChangeHistoryFnTest(BigQueryChangeHistoryIntegrationBase): def test_poll_emits_query_ranges(self): """Poll SDF emits _QueryRange elements with valid time ranges.""" table_str = f'{self.project}:{self.dataset}.nonexistent' - start_time = time.time() - 120 + start_time = Timestamp(time.time() - 120) config = _PollConfig(start_time=start_time) @@ -326,10 +336,10 @@ def test_poll_emits_query_ranges(self): table=table_str, project=self.project, change_function='APPENDS', - buffer_sec=0, + buffer=Duration(seconds=0), start_time=start_time, - stop_time=time.time() + 5, - poll_interval_sec=30, + stop_time=Timestamp(time.time() + 5), + poll_interval=Duration(seconds=30), location=self.location) with beam.Pipeline(argv=self.args) as p: @@ -367,9 +377,9 @@ def setUpClass(cls): def test_execute_query_produces_query_result(self): """ExecuteQueryFn creates a temp table from a _QueryRange.""" table_str = f'{self.project}:{self.dataset}.{self.test_table_id}' - start_time = self.insert_time - 120 - - query_range = _QueryRange(chunk_start=start_time, chunk_end=time.time()) + query_range = 
_QueryRange( + chunk_start=Timestamp(self.insert_time - 120), + chunk_end=Timestamp(time.time())) with beam.Pipeline(argv=self.args) as p: results = ( diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history_test.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history_test.py index e4fe77a40249..58d42fa868e5 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_change_history_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history_test.py @@ -32,6 +32,8 @@ from apache_beam.io.gcp.bigquery_change_history import build_changes_query from apache_beam.io.gcp.bigquery_change_history import compute_ranges from apache_beam.io.gcp.internal.clients import bigquery +from apache_beam.utils.timestamp import Duration +from apache_beam.utils.timestamp import Timestamp # Protect against environments where apitools is not available. try: @@ -39,81 +41,70 @@ except ImportError: HttpError = None # type: ignore +_DAY = Duration(seconds=86400) + + +def _ts(*args, **kwargs) -> Timestamp: + """Create a UTC datetime and return a Beam Timestamp.""" + dt = datetime.datetime(*args, tzinfo=datetime.timezone.utc, **kwargs) + return Timestamp(dt.timestamp()) + class BuildChangesQueryTest(unittest.TestCase): """Tests for build_changes_query().""" def test_appends_query_format(self): - # Use UTC-aware datetimes to avoid timezone offset issues - ts_start = datetime.datetime( - 2025, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc).timestamp() - ts_end = datetime.datetime( - 2025, 1, 1, 1, 0, 0, tzinfo=datetime.timezone.utc).timestamp() + start = _ts(2025, 1, 1, 0, 0, 0) + end = _ts(2025, 1, 1, 1, 0, 0) sql = build_changes_query( - 'myproject.mydataset.mytable', ts_start, ts_end, 'APPENDS') + 'myproject.mydataset.mytable', start, end, 'APPENDS') self.assertIn('APPENDS', sql) self.assertIn('TABLE `myproject.mydataset.mytable`', sql) self.assertIn('2025-01-01T00:00:00', sql) self.assertIn('2025-01-01T01:00:00', sql) def test_changes_query_format(self): - ts_start = 
datetime.datetime( - 2025, 6, 15, 12, 0, 0, tzinfo=datetime.timezone.utc).timestamp() - ts_end = datetime.datetime( - 2025, 6, 15, 18, 0, 0, tzinfo=datetime.timezone.utc).timestamp() - sql = build_changes_query('proj.ds.tbl', ts_start, ts_end, 'CHANGES') + start = _ts(2025, 6, 15, 12, 0, 0) + end = _ts(2025, 6, 15, 18, 0, 0) + sql = build_changes_query('proj.ds.tbl', start, end, 'CHANGES') self.assertIn('CHANGES', sql) self.assertIn('TABLE `proj.ds.tbl`', sql) def test_columns_select(self): - ts_start = datetime.datetime( - 2025, 1, 1, tzinfo=datetime.timezone.utc).timestamp() - ts_end = datetime.datetime( - 2025, 1, 2, tzinfo=datetime.timezone.utc).timestamp() + start = _ts(2025, 1, 1) + end = _ts(2025, 1, 2) sql = build_changes_query( - 'proj.ds.tbl', ts_start, ts_end, 'APPENDS', columns=['col_a', 'col_b']) + 'proj.ds.tbl', start, end, 'APPENDS', columns=['col_a', 'col_b']) self.assertIn('SELECT col_a, col_b, _CHANGE_TYPE AS', sql) self.assertNotIn('EXCEPT', sql) def test_columns_none_selects_all(self): - ts_start = datetime.datetime( - 2025, 1, 1, tzinfo=datetime.timezone.utc).timestamp() - ts_end = datetime.datetime( - 2025, 1, 2, tzinfo=datetime.timezone.utc).timestamp() + start = _ts(2025, 1, 1) + end = _ts(2025, 1, 2) sql = build_changes_query( - 'proj.ds.tbl', ts_start, ts_end, 'APPENDS', columns=None) + 'proj.ds.tbl', start, end, 'APPENDS', columns=None) self.assertIn('SELECT * EXCEPT', sql) def test_row_filter(self): - ts_start = datetime.datetime( - 2025, 1, 1, tzinfo=datetime.timezone.utc).timestamp() - ts_end = datetime.datetime( - 2025, 1, 2, tzinfo=datetime.timezone.utc).timestamp() + start = _ts(2025, 1, 1) + end = _ts(2025, 1, 2) sql = build_changes_query( - 'proj.ds.tbl', - ts_start, - ts_end, - 'APPENDS', - row_filter='status = "active"') + 'proj.ds.tbl', start, end, 'APPENDS', row_filter='status = "active"') self.assertIn('WHERE status = "active"', sql) def test_no_row_filter(self): - ts_start = datetime.datetime( - 2025, 1, 1, 
tzinfo=datetime.timezone.utc).timestamp() - ts_end = datetime.datetime( - 2025, 1, 2, tzinfo=datetime.timezone.utc).timestamp() + start = _ts(2025, 1, 1) + end = _ts(2025, 1, 2) sql = build_changes_query( - 'proj.ds.tbl', ts_start, ts_end, 'APPENDS', row_filter=None) + 'proj.ds.tbl', start, end, 'APPENDS', row_filter=None) self.assertNotIn('WHERE', sql) def test_columns_and_row_filter(self): - ts_start = datetime.datetime( - 2025, 1, 1, tzinfo=datetime.timezone.utc).timestamp() - ts_end = datetime.datetime( - 2025, 1, 2, tzinfo=datetime.timezone.utc).timestamp() + start = _ts(2025, 1, 1) + end = _ts(2025, 1, 2) sql = build_changes_query( 'proj.ds.tbl', - ts_start, - ts_end, + start, + end, 'CHANGES', columns=['id', 'name'], row_filter='id > 100') @@ -122,40 +113,46 @@ def test_columns_and_row_filter(self): self.assertIn('WHERE id > 100', sql) def test_colon_normalized_to_dot(self): - ts_start = datetime.datetime( - 2025, 1, 1, tzinfo=datetime.timezone.utc).timestamp() - ts_end = datetime.datetime( - 2025, 1, 2, tzinfo=datetime.timezone.utc).timestamp() + start = _ts(2025, 1, 1) + end = _ts(2025, 1, 2) sql = build_changes_query( - 'myproject:mydataset.mytable', ts_start, ts_end, 'APPENDS') + 'myproject:mydataset.mytable', start, end, 'APPENDS') self.assertIn('TABLE `myproject.mydataset.mytable`', sql) # Verify colon in table ref is normalized (timestamps contain colons) table_part = sql.split('TABLE')[1].split(',')[0] self.assertNotIn(':', table_part) + def test_microsecond_precision(self): + """Verify sub-second precision is preserved in ISO output.""" + # 2025-01-01T00:00:00.123456Z + start = Timestamp(micros=_ts(2025, 1, 1).micros + 123456) + end = start + Duration(seconds=1) + sql = build_changes_query('proj.ds.tbl', start, end, 'APPENDS') + self.assertIn('2025-01-01T00:00:00.123456Z', sql) + class ComputeRangesTest(unittest.TestCase): """Tests for compute_ranges().""" def test_appends_single_range(self): """APPENDS has no chunking — returns single range even 
for multi-day.""" - start = 0.0 - end = 86400.0 * 5 # 5 days + start = Timestamp(0) + end = Timestamp(0) + _DAY * Duration(seconds=5) ranges = compute_ranges(start, end, 'APPENDS') self.assertEqual(len(ranges), 1) self.assertEqual(ranges[0], (start, end)) def test_changes_single_day(self): """CHANGES within 1 day: single range.""" - start = 0.0 - end = 86400.0 # exactly 1 day + start = Timestamp(0) + end = Timestamp(0) + _DAY ranges = compute_ranges(start, end, 'CHANGES') self.assertEqual(len(ranges), 1) self.assertEqual(ranges[0], (start, end)) def test_changes_multi_day(self): """CHANGES spanning 3 days: should chunk into 3 ranges.""" - start = 0.0 - end = 86400.0 * 3 # 3 days + start = Timestamp(0) + end = Timestamp(0) + _DAY * Duration(seconds=3) ranges = compute_ranges(start, end, 'CHANGES') self.assertEqual(len(ranges), 3) # Verify no gaps @@ -166,23 +163,26 @@ def test_changes_multi_day(self): def test_changes_partial_day(self): """CHANGES spanning 1.5 days: should chunk into 2 ranges.""" - start = 0.0 - end = 86400.0 * 1.5 + start = Timestamp(0) + one_day = Timestamp(0) + _DAY + end = Timestamp(micros=_DAY.micros + _DAY.micros // 2) ranges = compute_ranges(start, end, 'CHANGES') self.assertEqual(len(ranges), 2) - self.assertEqual(ranges[0], (0.0, 86400.0)) - self.assertEqual(ranges[1], (86400.0, end)) + self.assertEqual(ranges[0], (start, one_day)) + self.assertEqual(ranges[1], (one_day, end)) def test_zero_range(self): """end <= start: empty list.""" - self.assertEqual(compute_ranges(100.0, 100.0, 'CHANGES'), []) - self.assertEqual(compute_ranges(100.0, 50.0, 'CHANGES'), []) - self.assertEqual(compute_ranges(100.0, 100.0, 'APPENDS'), []) + t100 = Timestamp(micros=100) + t50 = Timestamp(micros=50) + self.assertEqual(compute_ranges(t100, t100, 'CHANGES'), []) + self.assertEqual(compute_ranges(t100, t50, 'CHANGES'), []) + self.assertEqual(compute_ranges(t100, t100, 'APPENDS'), []) def test_exact_day_boundary(self): """Exactly 2 days: should produce 2 
chunks.""" - start = 0.0 - end = 86400.0 * 2 + start = Timestamp(0) + end = Timestamp(0) + _DAY * Duration(seconds=2) ranges = compute_ranges(start, end, 'CHANGES') self.assertEqual(len(ranges), 2) From 0e902f1f02266d4a1bc1b86ccd2b2cd5515634b2 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 6 Mar 2026 11:29:49 -0500 Subject: [PATCH 10/14] lint --- sdks/python/apache_beam/io/gcp/bigquery_change_history.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py index 722d07e4e41f..5a6779e889bd 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py @@ -67,8 +67,8 @@ from apache_beam.transforms.core import WatermarkEstimatorProvider from apache_beam.transforms.window import TimestampedValue from apache_beam.utils import retry -from apache_beam.utils.timestamp import Duration from apache_beam.utils.timestamp import MAX_TIMESTAMP +from apache_beam.utils.timestamp import Duration from apache_beam.utils.timestamp import Timestamp try: From c784edd76718a211ce43fe049b9e3ce298596b9b Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 9 Mar 2026 21:57:40 -0400 Subject: [PATCH 11/14] comments. 
--- .../io/gcp/bigquery_change_history.py | 33 ++++++------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py index 5a6779e889bd..30372203c7a2 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py @@ -302,11 +302,6 @@ def create_watermark_estimator( return _PollWatermarkEstimator(estimator_state) -def _table_key(table_ref: 'bigquery.TableReference') -> str: - """Convert a TableReference to a 'project.dataset.table' string.""" - return f'{table_ref.projectId}.{table_ref.datasetId}.{table_ref.tableId}' - - def build_changes_query( table: str, start: Timestamp, @@ -711,7 +706,7 @@ class _ReadStorageStreamsSDF(beam.DoFn, Emits: Main output: TimestampedValue(row_dict, event_timestamp) - Side output (_CLEANUP_TAG): (table_key, streams_read, total_streams) + Side output (_CLEANUP_TAG): (table_key, (streams_read, total_streams)) """ def __init__( self, @@ -738,7 +733,7 @@ def setup(self) -> None: def initial_restriction(self, element: _QueryResult) -> _StreamRestriction: """Create ReadSession and return _StreamRestriction with stream names.""" self._ensure_client() - table_key = _table_key(element.temp_table_ref) + table_key = bigquery_tools.get_hashable_destination(element.temp_table_ref) session = self._create_read_session(element.temp_table_ref) stream_names = tuple(s.name for s in session.streams) _LOGGER.info( @@ -775,7 +770,7 @@ def process( _CDCWatermarkEstimatorProvider()) ) -> Iterable[Dict[str, Any]]: self._ensure_client() - table_key = _table_key(element.temp_table_ref) + table_key = bigquery_tools.get_hashable_destination(element.temp_table_ref) _LOGGER.info( '[Read] Processing %s, range=[%s, %s), ' @@ -825,7 +820,7 @@ def process( yield TimestampedValue(row, ts) stream_rows += 1 total_rows += 1 - Metrics.counter('BigQueryChangeHistory', 
'rows_emitted').inc() + Metrics.counter('BigQueryChangeHistory', 'rows_emitted').inc(total_rows) streams_read += 1 _LOGGER.info( @@ -854,11 +849,7 @@ def process( total_streams, total_rows) yield beam.pvalue.TaggedOutput( - _CLEANUP_TAG, ( - table_key, - streams_read, - total_streams, - )) + _CLEANUP_TAG, (table_key, (streams_read, total_streams))) def _create_read_session(self, table_ref: 'bigquery.TableReference') -> Any: """Create a BigQuery Storage ReadSession for the given table.""" @@ -925,10 +916,7 @@ def _read_stream_batch(self, stream_name: str) -> Iterable[Dict[str, Any]]: if batch_bytes and schema is not None: batch = pyarrow.ipc.read_record_batch( pyarrow.py_buffer(batch_bytes), schema) - columns = batch.to_pydict() - col_names = batch.schema.names - for i in range(batch.num_rows): - yield {name: columns[name][i] for name in col_names} + yield from batch.to_pylist() row_count += batch.num_rows elapsed = time.time() - t0 _LOGGER.info( @@ -1023,7 +1011,9 @@ class ReadBigQueryChangeHistory(beam.PTransform): Default: current time when pipeline starts. stop_time: Stop polling at this timestamp. Default: run forever. change_function: 'CHANGES' or 'APPENDS'. Default 'APPENDS'. - buffer_sec: Safety buffer in seconds behind now(). Default 15. + buffer_sec: Safety buffer in seconds behind now(). Default 15. BQ does not + fail or wait if the query end_ts is less than BQ's CURRENT_TIMESTAMP. + This is an extra guardrail to protect against silent data. project: GCP project ID. Default: from pipeline options. temp_dataset: Dataset for temp tables. 
If None (default), a per-pipeline dataset is auto-created with a 24-hour table @@ -1087,7 +1077,7 @@ def __init__( raise ValueError( f'poll_interval_sec must be >= 15, got {poll_interval_sec}') if buffer_sec < 0: - raise ValueError(f'buffer_sec must be >= 10, got {buffer_sec}') + raise ValueError(f'buffer_sec must be >= 0, got {buffer_sec}') self._table = table self._poll_interval_sec = poll_interval_sec self._start_time = start_time @@ -1191,9 +1181,6 @@ def expand(self, pbegin: beam.pvalue.PBegin) -> beam.PCollection: _ = ( read_outputs[_CLEANUP_TAG] - | 'KeyByTable' >> - beam.Map(lambda x: (x[0], (x[1], x[2]))).with_output_types( - beam.typehints.Tuple[str, beam.typehints.Tuple[int, int]]) | 'CleanupTempTables' >> beam.ParDo(_CleanupTempTablesFn())) return read_outputs['rows'] From 029b78c7d5fdc56d162cc5b242aee712baf1029a Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 9 Mar 2026 22:10:49 -0400 Subject: [PATCH 12/14] comments --- .../apache_beam/io/gcp/bigquery_change_history.py | 15 +++++++-------- .../io/gcp/bigquery_change_history_it_test.py | 8 ++++---- .../io/gcp/bigquery_change_history_test.py | 12 ------------ 3 files changed, 11 insertions(+), 24 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py index 30372203c7a2..1dc74a622084 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py @@ -974,14 +974,13 @@ def process( total_streams) if current_read >= total_streams: - parts = table_key.split('.') - if len(parts) == 3: - project, dataset, table = parts - _LOGGER.info( - '[Cleanup] All streams read: DELETING temp table %s', table_key) - self._bq_wrapper._delete_table(project, dataset, table) - _LOGGER.info('[Cleanup] Deleted temp table %s', table_key) - Metrics.counter('BigQueryChangeHistory', 'temp_tables_deleted').inc() + parsed = bigquery_tools.parse_table_reference(table_key) + 
_LOGGER.info( + '[Cleanup] All streams read: DELETING temp table %s', table_key) + self._bq_wrapper._delete_table( + parsed.projectId, parsed.datasetId, parsed.tableId) + _LOGGER.info('[Cleanup] Deleted temp table %s', table_key) + Metrics.counter('BigQueryChangeHistory', 'temp_tables_deleted').inc() streams_read.clear() else: _LOGGER.info( diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py index 5c24c325835d..385fc13de11a 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py @@ -25,6 +25,7 @@ import uuid import apache_beam as beam +from apache_beam.io.gcp import bigquery_tools from apache_beam.io.gcp.bigquery_change_history import ReadBigQueryChangeHistory from apache_beam.io.gcp.bigquery_change_history import _CleanupTempTablesFn from apache_beam.io.gcp.bigquery_change_history import _ExecuteQueryFn @@ -33,7 +34,6 @@ from apache_beam.io.gcp.bigquery_change_history import _QueryRange from apache_beam.io.gcp.bigquery_change_history import _QueryResult from apache_beam.io.gcp.bigquery_change_history import _ReadStorageStreamsSDF -from apache_beam.io.gcp.bigquery_change_history import _table_key from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper from apache_beam.io.gcp.internal.clients import bigquery from apache_beam.testing.test_pipeline import TestPipeline @@ -160,7 +160,7 @@ def test_single_complete_signal_deletes_table(self): table_id, [{ 'id': 1, 'name': 'a', 'value': 1.0 }]) - table_key = _table_key(table_ref) + table_key = bigquery_tools.get_hashable_destination(table_ref) # Feed cleanup signal: all 5 streams read out of 5 with beam.Pipeline(argv=self.args) as p: @@ -181,7 +181,7 @@ def test_partial_signals_then_complete(self): table_id, [{ 'id': 1, 'name': 'a', 'value': 1.0 }]) - table_key = _table_key(table_ref) + table_key = 
bigquery_tools.get_hashable_destination(table_ref) # Feed two partial signals: 3/10 + 7/10 = 10/10 with beam.Pipeline(argv=self.args) as p: @@ -295,7 +295,7 @@ def test_cleanup_signal_emitted(self): | 'ExtractKey' >> beam.Map(lambda x: x[0])) assert_that( cleanup_table_keys, - equal_to([_table_key(table_ref)]), + equal_to([bigquery_tools.get_hashable_destination(table_ref)]), label='CheckCleanupKey') def test_empty_table(self): diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history_test.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history_test.py index 58d42fa868e5..7015ddf76593 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_change_history_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history_test.py @@ -20,7 +20,6 @@ Tests: - build_changes_query format - compute_ranges chunking - - _table_key conversion - ReadBigQueryChangeHistory validation """ @@ -28,10 +27,8 @@ import unittest from apache_beam.io.gcp.bigquery_change_history import ReadBigQueryChangeHistory -from apache_beam.io.gcp.bigquery_change_history import _table_key from apache_beam.io.gcp.bigquery_change_history import build_changes_query from apache_beam.io.gcp.bigquery_change_history import compute_ranges -from apache_beam.io.gcp.internal.clients import bigquery from apache_beam.utils.timestamp import Duration from apache_beam.utils.timestamp import Timestamp @@ -187,15 +184,6 @@ def test_exact_day_boundary(self): self.assertEqual(len(ranges), 2) -@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') -class TableKeyTest(unittest.TestCase): - """Tests for _table_key().""" - def test_conversion(self): - ref = bigquery.TableReference( - projectId='proj', datasetId='ds', tableId='tbl') - self.assertEqual(_table_key(ref), 'proj.ds.tbl') - - @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class ValidationTest(unittest.TestCase): """Tests for ReadBigQueryChangeHistory validation.""" From 
645959b77bd6d07bdff9ade381263a964f5263a8 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 10 Mar 2026 09:47:51 -0400 Subject: [PATCH 13/14] comments. --- .../io/gcp/bigquery_change_history.py | 4 +- .../io/gcp/bigquery_change_history_it_test.py | 93 +++++++++++++++++++ 2 files changed, 95 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py index 1dc74a622084..ba90c4a8963d 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py @@ -1010,7 +1010,7 @@ class ReadBigQueryChangeHistory(beam.PTransform): Default: current time when pipeline starts. stop_time: Stop polling at this timestamp. Default: run forever. change_function: 'CHANGES' or 'APPENDS'. Default 'APPENDS'. - buffer_sec: Safety buffer in seconds behind now(). Default 15. BQ does not + buffer_sec: Safety buffer in seconds behind now(). Default 10. BQ does not fail or wait if the query end_ts is less than BQ's CURRENT_TIMESTAMP. This is an extra guardrail to protect against silent data. project: GCP project ID. Default: from pipeline options. 
@@ -1049,7 +1049,7 @@ def __init__( start_time: Optional[float] = None, stop_time: Optional[float] = None, change_function: str = 'APPENDS', - buffer_sec: float = 15, + buffer_sec: float = 10, project: Optional[str] = None, temp_dataset: Optional[str] = None, location: Optional[str] = None, diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py index 385fc13de11a..0af5a80fd434 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py @@ -150,6 +150,21 @@ def _create_change_history_table(cls, table_id, rows=None): return bigquery.TableReference( projectId=cls.project, datasetId=cls.dataset, tableId=table_id) + @classmethod + def _run_dml(cls, sql): + """Run a DML statement (INSERT/UPDATE/DELETE) and wait for completion.""" + job_id = f'beam_ch_dml_{uuid.uuid4().hex[:8]}' + reference = bigquery.JobReference(jobId=job_id, projectId=cls.project) + request = bigquery.BigqueryJobsInsertRequest( + projectId=cls.project, + job=bigquery.Job( + configuration=bigquery.JobConfiguration( + query=bigquery.JobConfigurationQuery( + query=sql, useLegacySql=False)), + jobReference=reference)) + response = cls.bq_wrapper._start_job(request) + cls.bq_wrapper.wait_for_bq_job(response.jobReference, sleep_duration_sec=2) + class CleanupTempTablesFnTest(BigQueryChangeHistoryIntegrationBase): """Integration tests for _CleanupTempTablesFn against real BigQuery.""" @@ -397,6 +412,84 @@ def test_execute_query_produces_query_result(self): assert_that(result_count, equal_to([1]), label='CheckOneResult') +class ChangesEndToEndTest(BigQueryChangeHistoryIntegrationBase): + """End-to-end test using CHANGES function to capture all mutation types. + + Creates a change-history-enabled table, performs INSERT, UPDATE, and + DELETE operations via DML, then reads back via CHANGES to verify all + change types appear. 
+ """ + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.test_table_id = f'e2e_changes_{secrets.token_hex(4)}' + fq_table = f'{cls.project}.{cls.dataset}.{cls.test_table_id}' + + # Create a change-history-enabled table and insert initial rows via DML. + # DML inserts (not streaming inserts) are immediately visible and avoid + # streaming-buffer flush delays. + cls._create_change_history_table(cls.test_table_id) + cls.dml_start_time = time.time() + cls._run_dml( + f"INSERT INTO `{fq_table}` (id, name, value) " + f"VALUES (1, 'alice', 10.0), (2, 'bob', 20.0), (3, 'charlie', 30.0)") + cls._run_dml(f"UPDATE `{fq_table}` SET value = 25.0 WHERE id = 2") + cls._run_dml(f"DELETE FROM `{fq_table}` WHERE id = 3") + + _LOGGER.info('Waiting for change history propagation...') + time.sleep(15) + + def test_changes_captures_insert_update_delete(self): + """ReadBigQueryChangeHistory with CHANGES sees all mutation types.""" + table_str = f'{self.project}:{self.dataset}.{self.test_table_id}' + start_time = self.dml_start_time - 120 + stop_time = time.time() + 15 + + with beam.Pipeline(argv=self.args) as p: + rows = ( + p + | ReadBigQueryChangeHistory( + table=table_str, + poll_interval_sec=15, + start_time=start_time, + stop_time=stop_time, + change_function='CHANGES', + buffer_sec=10, + project=self.project, + temp_dataset=self.temp_dataset, + location=self.location)) + + def check_rows(actual): + by_type = {} + for row in actual: + ct = row['change_type'] + by_type.setdefault(ct, []).append(row) + + # BQ CHANGES returns: + # INSERT: 3 (original rows) + # UPDATE: 1 (bob with new value=25.0) + # DELETE: 2 (bob's pre-update row + charlie's explicit delete) + inserts = sorted(by_type.get('INSERT', []), key=lambda r: r['id']) + assert len(inserts) == 3, ( + f'Expected 3 INSERTs, got {len(inserts)}: {inserts}') + + updates = by_type.get('UPDATE', []) + assert len(updates) == 1, ( + f'Expected 1 UPDATE, got {len(updates)}: {updates}') + assert updates[0]['id'] == 2 and 
updates[0]['value'] == 25.0, ( + f'Unexpected UPDATE row: {updates[0]}') + + deletes = sorted(by_type.get('DELETE', []), key=lambda r: r['id']) + assert len(deletes) == 2, ( + f'Expected 2 DELETEs, got {len(deletes)}: {deletes}') + delete_ids = {r['id'] for r in deletes} + assert delete_ids == { + 2, 3 + }, (f'Expected DELETE ids {{2, 3}}, got {delete_ids}') + + assert_that(rows, check_rows) + + class EndToEndTest(BigQueryChangeHistoryIntegrationBase): """End-to-end test using the public ReadBigQueryChangeHistory API. From 19dde53675ffd169a5191cf28d01520407db9f80 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 10 Mar 2026 10:53:14 -0400 Subject: [PATCH 14/14] fix test --- .../python/apache_beam/io/gcp/bigquery_change_history_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history_test.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history_test.py index 7015ddf76593..04cc84e6ef9e 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_change_history_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history_test.py @@ -199,9 +199,9 @@ def test_invalid_poll_interval(self): def test_default_buffer(self): t = ReadBigQueryChangeHistory(table='p:d.t', change_function='CHANGES') - self.assertEqual(t._buffer_sec, 15) + self.assertEqual(t._buffer_sec, 10) t = ReadBigQueryChangeHistory(table='p:d.t', change_function='APPENDS') - self.assertEqual(t._buffer_sec, 15) + self.assertEqual(t._buffer_sec, 10) if __name__ == '__main__':