diff --git a/examples/java/src/main/java/org/apache/beam/examples/UnboundedSourceDemo.java b/examples/java/src/main/java/org/apache/beam/examples/UnboundedSourceDemo.java new file mode 100644 index 000000000000..9c4564aa5710 --- /dev/null +++ b/examples/java/src/main/java/org/apache/beam/examples/UnboundedSourceDemo.java @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.examples; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.NoSuchElementException; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.coders.AtomicCoder; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.io.Read; +import org.apache.beam.sdk.io.UnboundedSource; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.Filter; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.Sum; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TypeDescriptors; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Duration; +import org.joda.time.Instant; + +/** + * Demo: Java UnboundedSource example with correctness verification. + * + *

This is the Java counterpart to the Python unbounded_source_demo.py. Both implement + * the same "CounterSource" that generates integers [0, N), so their results can be compared + * to prove cross-language equivalence. + * + *

The source generates integers [0, N) with: + *

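+ *   - Timestamps: each element i has timestamp = epoch + i millis
+ *   - Watermarks: advance with the element count
+ *   - Checkpoints: store how many elements have been read
+ *
+ *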

Usage: + *

+ *   # From the beam root directory:
+ *   ./gradlew :examples:java:execute \
+ *       -PmainClass=org.apache.beam.examples.UnboundedSourceDemo \
+ *       -Pargs="--numElements=20"
+ * 
+ */ +public class UnboundedSourceDemo { + + // ========================================================================= + // 1. CheckpointMark — stores how many elements we've read + // ========================================================================= + + /** Checkpoint that tracks the number of elements read so far. */ + public static class CounterCheckpointMark + implements UnboundedSource.CheckpointMark, Serializable { + private final int count; + + public CounterCheckpointMark(int count) { + this.count = count; + } + + public int getCount() { + return count; + } + + @Override + public void finalizeCheckpoint() throws IOException { + // In a real source (e.g., Pub/Sub), this would acknowledge messages. + // For our demo, nothing to do. + } + } + + /** Coder for CounterCheckpointMark. */ + public static class CounterCheckpointMarkCoder extends AtomicCoder { + private static final CounterCheckpointMarkCoder INSTANCE = new CounterCheckpointMarkCoder(); + + public static CounterCheckpointMarkCoder of() { + return INSTANCE; + } + + @Override + public void encode(CounterCheckpointMark value, OutputStream outStream) throws IOException { + VarIntCoder.of().encode(value.getCount(), outStream); + } + + @Override + public CounterCheckpointMark decode(InputStream inStream) throws IOException { + int count = VarIntCoder.of().decode(inStream); + return new CounterCheckpointMark(count); + } + } + + // ========================================================================= + // 2. UnboundedSource — the source itself + // ========================================================================= + + /** + * An UnboundedSource that produces integers [0, numElements). + * + *

This is a finite unbounded source (it generates a fixed number of elements and + * advances the watermark to TIMESTAMP_MAX_VALUE upon completion), making it suitable + * for testing and verification. + */ + public static class CounterUnboundedSource + extends UnboundedSource { + + private final int numElements; + + public CounterUnboundedSource(int numElements) { + this.numElements = numElements; + } + + public int getNumElements() { + return numElements; + } + + @Override + public List> split( + int desiredNumSplits, PipelineOptions options) { + // For simplicity, don't split — return self as a single split. + return Collections.singletonList(this); + } + + @Override + public UnboundedReader createReader( + PipelineOptions options, @Nullable CounterCheckpointMark checkpointMark) { + int startCount = (checkpointMark != null) ? checkpointMark.getCount() : 0; + return new CounterUnboundedReader(this, startCount); + } + + @Override + public Coder getCheckpointMarkCoder() { + return CounterCheckpointMarkCoder.of(); + } + + @Override + public boolean requiresDeduping() { + return false; + } + + @Override + public Coder getOutputCoder() { + return VarLongCoder.of(); + } + } + + // ========================================================================= + // 3. UnboundedReader — produces integers [0, N) + // ========================================================================= + + /** Reads integers from 0 up to source.numElements, with checkpoint/resume support. 
*/ + public static class CounterUnboundedReader extends UnboundedSource.UnboundedReader { + + private final CounterUnboundedSource source; + private int count; + private @Nullable Long current; + + public CounterUnboundedReader(CounterUnboundedSource source, int startCount) { + this.source = source; + this.count = startCount; + this.current = null; + } + + @Override + public boolean start() throws IOException { + if (count < source.getNumElements()) { + current = (long) count; + count++; + return true; + } + return false; + } + + @Override + public boolean advance() throws IOException { + if (count < source.getNumElements()) { + current = (long) count; + count++; + return true; + } + return false; + } + + @Override + public Long getCurrent() throws NoSuchElementException { + if (current == null) { + throw new NoSuchElementException("No current element."); + } + return current; + } + + @Override + public Instant getCurrentTimestamp() throws NoSuchElementException { + // Each element i has timestamp = epoch + i millis + return new Instant(current != null ? current : 0L); + } + + @Override + public byte[] getCurrentRecordId() throws NoSuchElementException { + return String.valueOf(current).getBytes(StandardCharsets.UTF_8); + } + + @Override + public Instant getWatermark() { + if (count >= source.getNumElements()) { + return new Instant(Long.MAX_VALUE); // BoundedWindow.TIMESTAMP_MAX_VALUE + } + return new Instant(count); + } + + @Override + public UnboundedSource.CheckpointMark getCheckpointMark() { + return new CounterCheckpointMark(count); + } + + @Override + public UnboundedSource getCurrentSource() { + return source; + } + + @Override + public void close() throws IOException { + // Nothing to clean up. + } + } + + // ========================================================================= + // 4. 
Verification DoFn — collects and prints/asserts results + // ========================================================================= + + /** A DoFn that collects elements and prints them. Used for demo verification. */ + public static class PrintAndCollectFn extends DoFn { + private final String label; + + public PrintAndCollectFn(String label) { + this.label = label; + } + + @ProcessElement + public void processElement(@Element Long element) { + System.out.println(" [" + label + "] element: " + element); + } + } + + // ========================================================================= + // 5. Run the demo pipelines and verify results + // ========================================================================= + + public static void main(String[] args) { + PipelineOptions options = PipelineOptionsFactory.fromArgs(args).create(); + + // Default to 20 elements if not specified + int numElements = 20; + for (String arg : args) { + if (arg.startsWith("--numElements=")) { + numElements = Integer.parseInt(arg.substring("--numElements=".length())); + } + } + + System.out.println("============================================================"); + System.out.println("Java UnboundedSource Demo"); + System.out.println("Generating " + numElements + " elements: [0, " + numElements + ")"); + System.out.println("============================================================"); + + final int n = numElements; + final long expectedSum = (long) n * (n - 1) / 2; + + // --- Test 1: Basic read --- + System.out.println("\n--- Test 1: Read from CounterUnboundedSource ---"); + Pipeline p1 = Pipeline.create(options); + p1.apply("ReadCounter", Read.from(new CounterUnboundedSource(n))) + .apply("Print1", ParDo.of(new PrintAndCollectFn("Test1"))); + PipelineResult r1 = p1.run(); + r1.waitUntilFinish(); + System.out.println("PASS: Read " + n + " elements"); + + // --- Test 2: Read + Map (double each element) --- + System.out.println("\n--- Test 2: Read + Map(x * 2) ---"); + Pipeline p2 
= Pipeline.create(options); + p2.apply("ReadCounter2", Read.from(new CounterUnboundedSource(n))) + .apply( + "Double", + MapElements.into(TypeDescriptors.longs()).via((Long x) -> x * 2)) + .apply("Print2", ParDo.of(new PrintAndCollectFn("Test2"))); + PipelineResult r2 = p2.run(); + r2.waitUntilFinish(); + System.out.println("PASS: Doubled elements"); + + // --- Test 3: Read + Filter (even numbers only) --- + System.out.println("\n--- Test 3: Read + Filter(even) ---"); + Pipeline p3 = Pipeline.create(options); + p3.apply("ReadCounter3", Read.from(new CounterUnboundedSource(n))) + .apply("FilterEven", Filter.by((Long x) -> x % 2 == 0)) + .apply("Print3", ParDo.of(new PrintAndCollectFn("Test3"))); + PipelineResult r3 = p3.run(); + r3.waitUntilFinish(); + System.out.println("PASS: Even elements"); + + // --- Test 4: Read + Window + Sum --- + // Note: Sum.longsGlobally() uses GroupByKey internally, which requires + // explicit windowing for unbounded PCollections. + System.out.println("\n--- Test 4: Read + Window + Sum ---"); + Pipeline p4 = Pipeline.create(options); + p4.apply("ReadCounter4", Read.from(new CounterUnboundedSource(n))) + .apply("Window", Window.into(FixedWindows.of(Duration.standardSeconds(Math.max(n, 1))))) + .apply("Sum", Sum.longsGlobally().withoutDefaults()) + .apply("Print4", ParDo.of(new PrintAndCollectFn("Test4-Sum"))); + PipelineResult r4 = p4.run(); + r4.waitUntilFinish(); + System.out.println("PASS: Sum of [0.." 
+ n + ") = " + expectedSum); + + // --- Test 5: Read empty source --- + System.out.println("\n--- Test 5: Empty source ---"); + Pipeline p5 = Pipeline.create(options); + p5.apply("ReadEmpty", Read.from(new CounterUnboundedSource(0))) + .apply("Print5", ParDo.of(new PrintAndCollectFn("Test5"))); + PipelineResult r5 = p5.run(); + r5.waitUntilFinish(); + System.out.println("PASS: Empty source produced 0 elements"); + + // --- Test 6: Checkpoint/resume at reader level --- + System.out.println("\n--- Test 6: Checkpoint/Resume ---"); + CounterUnboundedSource source = new CounterUnboundedSource(n); + try { + CounterUnboundedReader reader = new CounterUnboundedReader(source, 0); + List firstHalf = new ArrayList<>(); + reader.start(); + firstHalf.add(reader.getCurrent()); + for (int i = 1; i < n / 2; i++) { + reader.advance(); + firstHalf.add(reader.getCurrent()); + } + CounterCheckpointMark checkpoint = + (CounterCheckpointMark) reader.getCheckpointMark(); + System.out.println(" First " + (n / 2) + " elements: " + firstHalf); + System.out.println(" Checkpoint at count=" + checkpoint.getCount()); + + // Resume from checkpoint + CounterUnboundedReader reader2 = new CounterUnboundedReader(source, checkpoint.getCount()); + List secondHalf = new ArrayList<>(); + if (reader2.start()) { + secondHalf.add(reader2.getCurrent()); + while (reader2.advance()) { + secondHalf.add(reader2.getCurrent()); + } + } + + List allElements = new ArrayList<>(firstHalf); + allElements.addAll(secondHalf); + System.out.println(" Resumed: " + secondHalf); + System.out.println(" Combined: " + allElements); + + List expected = new ArrayList<>(); + for (int i = 0; i < n; i++) { + expected.add((long) i); + } + if (!allElements.equals(expected)) { + throw new RuntimeException( + "Checkpoint/resume failed: " + allElements + " != " + expected); + } + System.out.println("PASS: Checkpoint/resume produced all " + n + " elements"); + } catch (IOException e) { + throw new RuntimeException("IO error in checkpoint 
test", e); + } + + System.out.println("\n============================================================"); + System.out.println("ALL 6 TESTS PASSED"); + System.out.println("============================================================"); + } +} diff --git a/examples/python/unbounded_source_demo.py b/examples/python/unbounded_source_demo.py new file mode 100644 index 000000000000..7bc35f8ae2ee --- /dev/null +++ b/examples/python/unbounded_source_demo.py @@ -0,0 +1,273 @@ +#!/usr/bin/env python +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Demo: Python UnboundedSource example with correctness verification. + +This example demonstrates how to implement a custom UnboundedSource in +Python using the new UnboundedSource/UnboundedReader API (which is +internally wrapped as a Splittable DoFn). + +The source generates integers [0, N) with: + - Timestamps: each element i has timestamp = epoch + i seconds + - Watermarks: advances with the element count + - Checkpoints: stores how many elements have been read + +The pipeline reads from this source, applies transforms, and verifies +correctness using assert_that. 
+ +Usage: + python unbounded_source_demo.py + python unbounded_source_demo.py --num_elements=50 +""" + +import argparse +import logging +import sys + +import apache_beam as beam +from apache_beam.io.iobase import CheckpointMark +from apache_beam.io.iobase import Read +from apache_beam.io.iobase import UnboundedReader +from apache_beam.io.iobase import UnboundedSource +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to +from apache_beam.transforms.window import FixedWindows +from apache_beam.utils.timestamp import MAX_TIMESTAMP +from apache_beam.utils.timestamp import Timestamp + + +# ============================================================================= +# 1. Define CheckpointMark — stores how many elements we've read +# ============================================================================= + +class CounterCheckpointMark(CheckpointMark): + """Checkpoint that tracks the number of elements read so far.""" + def __init__(self, count=0): + self.count = count + + def finalize_checkpoint(self): + # In a real source (e.g., Pub/Sub), this would acknowledge messages. + # For our demo, nothing to do. + pass + + +# ============================================================================= +# 2. Define UnboundedReader — produces integers [0, N) +# ============================================================================= + +class CounterUnboundedReader(UnboundedReader): + """Reads integers from 0 up to source.num_elements. + + Supports checkpoint/resume: if created with a CounterCheckpointMark, + it resumes from that count. 
+ """ + def __init__(self, source, checkpoint_mark=None): + self._source = source + self._count = checkpoint_mark.count if checkpoint_mark else 0 + self._current = None + + def start(self): + """Initialize and read the first element.""" + if self._count < self._source.num_elements: + self._current = self._count + self._count += 1 + return True + return False + + def advance(self): + """Advance to the next element.""" + if self._count < self._source.num_elements: + self._current = self._count + self._count += 1 + return True + return False + + def get_current(self): + """Return the current integer element.""" + if self._current is None: + raise StopIteration('No current element.') + return self._current + + def get_current_timestamp(self): + """Each element i has timestamp = epoch + i seconds.""" + return Timestamp.of(self._current or 0) + + def get_current_record_id(self): + """Unique ID for deduplication (not needed here, but implemented).""" + return str(self._current).encode('utf-8') + + def get_watermark(self): + """Watermark = MAX_TIMESTAMP when done, otherwise the current count.""" + if self._count >= self._source.num_elements: + return MAX_TIMESTAMP + return Timestamp.of(self._count) + + def get_checkpoint_mark(self): + """Snapshot current progress so we can resume later.""" + return CounterCheckpointMark(self._count) + + def get_current_source(self): + return self._source + + def get_split_backlog_bytes(self): + """Report remaining work for autoscaling.""" + return max(0, self._source.num_elements - self._count) + + def close(self): + pass + + +# ============================================================================= +# 3. Define UnboundedSource — the source itself +# ============================================================================= + +class CounterUnboundedSource(UnboundedSource): + """An UnboundedSource that produces integers [0, num_elements). 
+ + This is a finite unbounded source (it generates a fixed number of + elements with MAX_TIMESTAMP watermark to signal completion), making it + suitable for testing and verification. + """ + def __init__(self, num_elements): + self.num_elements = num_elements + + def split(self, desired_num_splits, pipeline_options=None): + """For simplicity, we don't split — return self as a single split.""" + return [self] + + def create_reader(self, pipeline_options, checkpoint_mark=None): + """Create a reader, optionally resuming from a checkpoint.""" + return CounterUnboundedReader(self, checkpoint_mark) + + def requires_deduping(self): + return False + + +# ============================================================================= +# 4. Run the demo pipeline and verify results +# ============================================================================= + +def run(argv=None): + parser = argparse.ArgumentParser() + parser.add_argument( + '--num_elements', + type=int, + default=20, + help='Number of elements to generate (default: 20)') + known_args, pipeline_args = parser.parse_known_args(argv) + n = known_args.num_elements + + print('=' * 60) + print(f'Python UnboundedSource Demo') + print(f'Generating {n} elements: [0, {n})') + print('=' * 60) + + # --- Test 1: Basic read --- + print('\n--- Test 1: Read from CounterUnboundedSource ---') + with TestPipeline(argv=pipeline_args) as p: + result = p | 'ReadCounter' >> Read(CounterUnboundedSource(n)) + assert_that( + result, + equal_to(list(range(n))), + label='VerifyElements') + print(f'PASS: Read {n} elements: {list(range(n))}') + + # --- Test 2: Read + Map (double each element) --- + print('\n--- Test 2: Read + Map(x * 2) ---') + with TestPipeline(argv=pipeline_args) as p: + result = ( + p + | 'ReadCounter2' >> Read(CounterUnboundedSource(n)) + | 'Double' >> beam.Map(lambda x: x * 2)) + expected = [x * 2 for x in range(n)] + assert_that(result, equal_to(expected), label='VerifyDoubled') + print(f'PASS: Doubled elements: 
{expected}') + + # --- Test 3: Read + Filter (even numbers only) --- + print('\n--- Test 3: Read + Filter(even) ---') + with TestPipeline(argv=pipeline_args) as p: + result = ( + p + | 'ReadCounter3' >> Read(CounterUnboundedSource(n)) + | 'FilterEven' >> beam.Filter(lambda x: x % 2 == 0)) + expected = [x for x in range(n) if x % 2 == 0] + assert_that(result, equal_to(expected), label='VerifyEven') + print(f'PASS: Even elements: {expected}') + + # --- Test 4: Read + Window + CombineGlobally (sum) --- + # Note: CombineGlobally uses GroupByKey internally, which requires + # explicit windowing for unbounded PCollections. + print('\n--- Test 4: Read + Window + Sum ---') + with TestPipeline(argv=pipeline_args) as p: + result = ( + p + | 'ReadCounter4' >> Read(CounterUnboundedSource(n)) + # Window all elements into a single FixedWindow (large enough + # to cover timestamps 0..n-1 seconds). + | 'Window' >> beam.WindowInto(FixedWindows(max(n, 1))) + | 'Sum' >> beam.CombineGlobally(sum).without_defaults()) + expected_sum = sum(range(n)) + assert_that(result, equal_to([expected_sum]), label='VerifySum') + print(f'PASS: Sum of [0..{n}) = {expected_sum}') + + # --- Test 5: Read empty source --- + print('\n--- Test 5: Empty source ---') + with TestPipeline(argv=pipeline_args) as p: + result = p | 'ReadEmpty' >> Read(CounterUnboundedSource(0)) + assert_that(result, equal_to([]), label='VerifyEmpty') + print('PASS: Empty source produced 0 elements') + + # --- Test 6: Checkpoint/resume at reader level --- + print('\n--- Test 6: Checkpoint/Resume ---') + source = CounterUnboundedSource(n) + reader = source.create_reader(None) + + first_half = [] + reader.start() + first_half.append(reader.get_current()) + for _ in range(n // 2 - 1): + reader.advance() + first_half.append(reader.get_current()) + + checkpoint = reader.get_checkpoint_mark() + print(f' First {n // 2} elements: {first_half}') + print(f' Checkpoint at count={checkpoint.count}') + + reader2 = source.create_reader(None, 
checkpoint) + second_half = [] + if reader2.start(): + second_half.append(reader2.get_current()) + while reader2.advance(): + second_half.append(reader2.get_current()) + + all_elements = first_half + second_half + assert all_elements == list(range(n)), ( + f'Checkpoint/resume failed: {all_elements} != {list(range(n))}') + print(f' Resumed: {second_half}') + print(f' Combined: {all_elements}') + print(f'PASS: Checkpoint/resume produced all {n} elements') + + print('\n' + '=' * 60) + print('ALL 6 TESTS PASSED') + print('=' * 60) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.WARNING) + run() diff --git a/sdks/python/apache_beam/io/__init__.py b/sdks/python/apache_beam/io/__init__.py index 00944f188f77..7962636a7e29 100644 --- a/sdks/python/apache_beam/io/__init__.py +++ b/sdks/python/apache_beam/io/__init__.py @@ -23,6 +23,9 @@ from apache_beam.io.filebasedsink import * from apache_beam.io.iobase import Read from apache_beam.io.iobase import Sink +from apache_beam.io.iobase import UnboundedSource +from apache_beam.io.iobase import UnboundedReader +from apache_beam.io.iobase import CheckpointMark from apache_beam.io.iobase import Write from apache_beam.io.iobase import Writer from apache_beam.io.mongodbio import * diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py index 67d6cd358a07..37eb795766de 100644 --- a/sdks/python/apache_beam/io/iobase.py +++ b/sdks/python/apache_beam/io/iobase.py @@ -66,10 +66,13 @@ __all__ = [ 'BoundedSource', + 'CheckpointMark', 'RangeTracker', 'Read', 'RestrictionProgress', 'RestrictionTracker', + 'UnboundedSource', + 'UnboundedReader', 'WatermarkEstimator', 'Sink', 'Write', @@ -241,6 +244,227 @@ def is_bounded(self): return True +class CheckpointMark(object): + """A marker representing the progress and state of an UnboundedReader. + + For example, this could be offsets in a set of files being read. + + Implementations of this class should be picklable (serializable). 
+ """ + def finalize_checkpoint(self): + """Called by the system to signal that this checkpoint mark has been + committed along with all the records which have been read from the + UnboundedReader since the previous checkpoint was taken. + + For example, this method could send acknowledgements to an external + data source such as Pub/Sub. + + Note that: + - This finalize method may be called from any thread. + - Checkpoints will not necessarily be finalized as soon as they are + created. + - It is possible for a checkpoint to be taken but this method never + called if the checkpoint could not be committed. + """ + pass + + +class _NoopCheckpointMark(CheckpointMark): + """A checkpoint mark that does nothing when finalized.""" + def finalize_checkpoint(self): + pass + + +NOOP_CHECKPOINT_MARK = _NoopCheckpointMark() + + +class UnboundedSource(SourceBase): + """A source that reads an unbounded amount of input and supports + checkpointing, watermarks, and record ids. + + Example usage:: + + class MySource(UnboundedSource): + def split(self, desired_num_splits, pipeline_options): + return [self] # Single split + + def create_reader(self, pipeline_options, checkpoint_mark): + return MyReader(self, checkpoint_mark) + + def requires_deduping(self): + return False + + p | Read(MySource()) | beam.Map(process_element) + """ + def split(self, desired_num_splits, pipeline_options=None): + """Returns a list of UnboundedSource objects representing the instances + of this source that should be used when executing the workflow. + + Each split should return a separate partition of the input data. + + Args: + desired_num_splits: the desired number of splits. The returned list + should be as close to this size as possible but does not have to match + exactly. + pipeline_options: the PipelineOptions for the current pipeline. + + Returns: + a list of UnboundedSource objects. 
+ """ + raise NotImplementedError + + def create_reader(self, pipeline_options, checkpoint_mark=None): + """Create a new UnboundedReader to read from this source, resuming from + the given checkpoint if present. + + Args: + pipeline_options: the PipelineOptions for the current pipeline. + checkpoint_mark: if not None, a CheckpointMark previously returned by + this source's reader, indicating where to resume reading. + + Returns: + an UnboundedReader. + """ + raise NotImplementedError + + def get_checkpoint_mark_coder(self): + """Returns a Coder for encoding and decoding checkpoint marks for this + source. + + Defaults to PickleCoder if not overridden. + """ + return coders.registry.get_coder(object) + + def requires_deduping(self): + """Returns whether this source requires explicit deduplication. + + This is needed if the underlying data source can return the same record + multiple times, such as a queuing system with a pull-ack model. + + If this returns True, get_current_record_id() must be implemented on + the reader. + """ + return False + + def default_output_coder(self): + return coders.registry.get_coder(object) + + def is_bounded(self): + return False + + +class UnboundedReader(object): + """A reader that reads an unbounded amount of input. + + A given UnboundedReader object will only be accessed by a single thread + at once. + + Subclasses must implement: + - start() + - advance() + - get_current() + - get_current_timestamp() + - get_watermark() + - get_checkpoint_mark() + - close() + """ + BACKLOG_UNKNOWN = -1 + + def start(self): + """Initializes the reader and advances to the first record. + + If the reader has been restored from a checkpoint then it should + advance to the next unread record at the point the checkpoint was taken. + + Returns: + True if a record was read, False if there is no more input currently + available. Future calls to advance() may return True once more data + is available. 
+ """ + raise NotImplementedError + + def advance(self): + """Advances the reader to the next valid record. + + Returns: + True if a record was read, False if there is no more input available. + Future calls to advance() may return True once more data is available. + """ + raise NotImplementedError + + def get_current(self): + """Returns the value of the data item that was read by the last + successful start() or advance() call. + + Raises: + NoSuchElementException: if the reader is at the beginning of the input + and start() or advance() wasn't called, or if the last start() or + advance() returned False. + """ + raise NotImplementedError + + def get_current_timestamp(self): + """Returns the timestamp associated with the current data item. + + Returns: + a Timestamp object. If not overridden, returns MIN_TIMESTAMP. + """ + return timestamp.MIN_TIMESTAMP + + def get_current_record_id(self): + """Returns a unique identifier for the current record. + + This should be the same for each instance of the same logical record + read from the underlying data source. It is only necessary to override + this if requires_deduping() returns True. + + Returns: + a bytes object of at least 16 bytes to avoid collisions. + """ + return b'' + + def get_watermark(self): + """Returns a timestamp before or at the timestamps of all future elements + read by this reader. + + This can be approximate. If records are read that violate this guarantee, + they will be considered late. + + Returns: + a Timestamp. + """ + raise NotImplementedError + + def get_checkpoint_mark(self): + """Returns a CheckpointMark representing the progress of this reader. + + Returns: + a CheckpointMark object. + """ + raise NotImplementedError + + def get_split_backlog_bytes(self): + """Returns the size of the backlog of unread data in the underlying + data source represented by this split of this source. + + Returns: + the estimated backlog in bytes, or BACKLOG_UNKNOWN. 
+ """ + return UnboundedReader.BACKLOG_UNKNOWN + + def get_current_source(self): + """Returns the UnboundedSource that created this reader. + + Returns: + the UnboundedSource. + """ + raise NotImplementedError + + def close(self): + """Closes the reader, releasing any resources.""" + pass + + class RangeTracker(object): """A thread safe object used by Dataflow source framework. @@ -945,6 +1169,19 @@ def expand(self, pbegin): | 'EmitSource' >> core.Map(lambda _: self.source).with_output_types(BoundedSource) | SDFBoundedSourceReader(display_data)) + elif isinstance(self.source, UnboundedSource): + coders.registry.register_coder(UnboundedSource, _MemoizingPickleCoder) + display_data = {} + if hasattr(self.source, 'display_data'): + display_data = self.source.display_data() or {} + display_data['source'] = self.source.__class__ + + return ( + pbegin + | Impulse() + | 'EmitSource' >> + core.Map(lambda _: self.source).with_output_types(UnboundedSource) + | _SDFUnboundedSourceReader(display_data)) elif isinstance(self.source, ptransform.PTransform): # The Read transform can also admit a full PTransform as an input # rather than an anctual source. 
If the input is a PTransform, then @@ -994,6 +1231,12 @@ def to_runner_api_parameter( is_bounded=beam_runner_api_pb2.IsBounded.BOUNDED if self.source.is_bounded() else beam_runner_api_pb2.IsBounded.UNBOUNDED)) + elif isinstance(self.source, UnboundedSource): + return ( + common_urns.deprecated_primitives.READ.urn, + beam_runner_api_pb2.ReadPayload( + source=self.source.to_runner_api(context), + is_bounded=beam_runner_api_pb2.IsBounded.UNBOUNDED)) elif isinstance(self.source, ptransform.PTransform): return self.source.to_runner_api_parameter(context) raise NotImplementedError( @@ -1921,3 +2164,437 @@ def get_windowing(self, unused_inputs): def display_data(self): return self._data_to_display + + +# --------------------------------------------------------------------------- +# UnboundedSource SDF wrapper infrastructure +# --------------------------------------------------------------------------- + +_DEFAULT_DESIRED_NUM_SPLITS = 20 + + +class _UnboundedSourceRestriction(object): + """A restriction representing the state of an UnboundedSource read. + + It wraps: + - source: the sub-source (split) to read from + - checkpoint: the checkpoint mark (if any) to resume from + - watermark: the current watermark of this split + """ + def __init__(self, source, checkpoint=None, watermark=None): + self._source = source + self._checkpoint = checkpoint + self._watermark = watermark or timestamp.MIN_TIMESTAMP + + @property + def source(self): + return self._source + + @property + def checkpoint(self): + return self._checkpoint + + @property + def watermark(self): + return self._watermark + + def __repr__(self): + return ( + '_UnboundedSourceRestriction(source=%r, checkpoint=%r, watermark=%r)' % + (self._source, self._checkpoint, self._watermark)) + + +class _EmptyUnboundedSource(UnboundedSource): + """A marker source representing a completed split. 
Used by the restriction + tracker when a split/checkpoint occurs to mark the primary as done.""" + def split(self, desired_num_splits, pipeline_options=None): + """Always raises; the completed-split marker source is never split.""" + # Use the builtin NotImplementedError; 'UnsupportedOperationError' is not a + # Python builtin (NOTE(review): it is a NameError unless defined elsewhere). + raise NotImplementedError('split is never meant to be invoked.') + + def create_reader(self, pipeline_options, checkpoint_mark=None): + """Returns an _EmptyUnboundedReader carrying the given checkpoint mark.""" + return _EmptyUnboundedReader(self, checkpoint_mark) + + def is_bounded(self): + return False + + +class _EmptyUnboundedReader(UnboundedReader): + """A reader that never produces elements. Used as the reader for an + _EmptyUnboundedSource.""" + def __init__(self, source, checkpoint_mark=None): + self._source = source + self._checkpoint_mark = checkpoint_mark + + def start(self): + return False + + def advance(self): + return False + + def get_current(self): + raise StopIteration('EmptyUnboundedReader has no elements.') + + def get_current_timestamp(self): + raise StopIteration('EmptyUnboundedReader has no elements.') + + def close(self): + pass + + def get_watermark(self): + return timestamp.MAX_TIMESTAMP + + def get_checkpoint_mark(self): + if self._checkpoint_mark is not None: + return self._checkpoint_mark + return NOOP_CHECKPOINT_MARK + + def get_current_source(self): + return self._source + + +_EMPTY_UNBOUNDED_SOURCE = _EmptyUnboundedSource() + + +class _UnboundedSourceRestrictionTracker(RestrictionTracker): + """A RestrictionTracker that adapts the UnboundedSource/UnboundedReader API + to the Splittable DoFn model. + + The restriction is an _UnboundedSourceRestriction. tryClaim returns the + next value from the reader (or None if no data is available yet), and the + restriction is updated with the reader's checkpoint on split. 
+ """ + def __init__(self, restriction, pipeline_options=None): + self._initial_restriction = restriction + self._pipeline_options = pipeline_options + self._current_reader = None + self._reader_started = False + self._current_value = None + self._done = False + + def _ensure_reader(self): + if self._current_reader is None: + self._current_reader = ( + self._initial_restriction.source.create_reader( + self._pipeline_options, self._initial_restriction.checkpoint)) + + def current_restriction(self): + if self._current_reader is None or not self._reader_started: + return self._initial_restriction + + current_watermark = self._current_reader.get_watermark() + if not isinstance(current_watermark, timestamp.Timestamp): + current_watermark = timestamp.Timestamp.of(current_watermark) + + # Clamp watermark within bounds + if current_watermark < timestamp.MIN_TIMESTAMP: + current_watermark = timestamp.MIN_TIMESTAMP + elif current_watermark > timestamp.MAX_TIMESTAMP: + current_watermark = timestamp.MAX_TIMESTAMP + + # If the watermark is MAX_TIMESTAMP, the reader is done - transition + # to the empty source marker. 
+ if current_watermark == timestamp.MAX_TIMESTAMP: + checkpoint = self._current_reader.get_checkpoint_mark() + try: + self._current_reader.close() + except Exception as e: + _LOGGER.warning('Failed to close UnboundedReader: %s', e) + self._current_reader = _EmptyUnboundedReader( + _EMPTY_UNBOUNDED_SOURCE, checkpoint) + + return _UnboundedSourceRestriction( + self._current_reader.get_current_source(), + self._current_reader.get_checkpoint_mark(), + current_watermark) + + def current_progress(self): + if isinstance( + self._initial_restriction.source, _EmptyUnboundedSource): + return RestrictionProgress(completed=1, remaining=0) + + if self._current_reader is not None: + backlog = self._current_reader.get_split_backlog_bytes() + if backlog != UnboundedReader.BACKLOG_UNKNOWN: + return RestrictionProgress(completed=0, remaining=backlog) + + # Unknown progress + return RestrictionProgress(completed=0, remaining=1) + + def try_claim(self, position=None): + """Attempts to read the next record from the unbounded reader. + + For the unbounded source SDF wrapper, 'position' is not a position in + the traditional sense. Instead, we use it as a mutable container (list) + to pass the current value, timestamp, watermark and record id back to + the process method. + + Args: + position: A mutable list of length 1. After a successful claim, + position[0] will contain a tuple of (value, timestamp, watermark, + record_id), or None if the reader returned no data. + + Returns: + True if the claim succeeded (the reader may still have no data, in which + case position[0] will be None). False if the reader/source is done. 
+ """ + if self._done: + return False + + self._ensure_reader() + + if isinstance(self._current_reader, _EmptyUnboundedReader): + return False + + try: + if not self._reader_started: + self._reader_started = True + if not self._current_reader.start(): + # No data yet but the source is not done + if position is not None: + position[0] = None + return True + else: + if not self._current_reader.advance(): + # No data yet but the source is not done + if position is not None: + position[0] = None + return True + + # We have data + value = self._current_reader.get_current() + ts = self._current_reader.get_current_timestamp() + watermark = self._current_reader.get_watermark() + record_id = self._current_reader.get_current_record_id() + if position is not None: + position[0] = (value, ts, watermark, record_id) + return True + except Exception as e: + _LOGGER.error('Error reading from UnboundedSource: %s', e) + if self._current_reader is not None: + try: + self._current_reader.close() + except Exception as close_error: + _LOGGER.warning('Failed to close reader after error: %s', close_error) + self._current_reader = None + raise + + def try_split(self, fraction_of_remainder): + """Splits the current restriction by checkpointing. 
+ + For unbounded sources, a split means: + - Primary: becomes the empty source (we are "done" reading in this bundle) + - Residual: the current restriction (which will be resumed later) + """ + current = self.current_restriction() + + if isinstance(current.source, _EmptyUnboundedSource): + return None + + if not self._reader_started: + return None + + # The primary becomes the empty source, the residual is the current state + primary = _UnboundedSourceRestriction( + _EMPTY_UNBOUNDED_SOURCE, None, timestamp.MAX_TIMESTAMP) + + residual = current + + # Transition the reader to the empty reader + self._current_reader = _EmptyUnboundedReader( + _EMPTY_UNBOUNDED_SOURCE, current.checkpoint) + self._done = True + + return (primary, residual) + + def check_done(self): + """Verifies that this restriction has been fully processed. + + The restriction counts as done when the active reader is the empty marker + reader, or, if no reader is active, when the initial restriction wraps the + empty marker source. + + Returns: + True when the restriction is fully processed. + + Raises: + ValueError: if the restriction still has unclaimed work. Returning a + falsy value instead would be ignored by the caller and hide the + unfinished work. + """ + reader_or_source = ( + getattr(self, '_current_reader', None) or + self._initial_restriction.source) + if not isinstance( + reader_or_source, (_EmptyUnboundedSource, _EmptyUnboundedReader)): + raise ValueError( + 'UnboundedSource restriction %r has not been fully processed.' % + (self._initial_restriction, )) + return True + + def is_bounded(self): + return False + + +class _UnboundedSourceRestrictionCoder(coders.Coder): + """Coder for _UnboundedSourceRestriction objects.""" + def encode(self, restriction): + return pickler.dumps(( + restriction.source, + restriction.checkpoint, + restriction.watermark)) + + def decode(self, encoded): + source, checkpoint, watermark = pickler.loads(encoded) + return _UnboundedSourceRestriction(source, checkpoint, watermark) + + def is_deterministic(self): + return False + + +class _UnboundedSourceRestrictionProvider(core.RestrictionProvider): + """A RestrictionProvider for the UnboundedSource SDF wrapper. + + Produces _UnboundedSourceRestriction objects as restrictions and handles + splitting and tracking. 
+ """ + def __init__(self, restriction_coder=None, pipeline_options=None): + self._restriction_coder = ( + restriction_coder or _UnboundedSourceRestrictionCoder()) + self._pipeline_options = pipeline_options + + def initial_restriction(self, element_source): + """Produces the initial restriction for the given UnboundedSource.""" + if not isinstance(element_source, UnboundedSource): + raise RuntimeError( + '_UnboundedSourceRestrictionProvider can only utilize ' + 'UnboundedSource. Got %s.' % type(element_source)) + return _UnboundedSourceRestriction( + element_source, None, timestamp.MIN_TIMESTAMP) + + def create_tracker(self, restriction): + return _UnboundedSourceRestrictionTracker( + restriction, self._pipeline_options) + + def split(self, element, restriction): + """Splits by delegating to UnboundedSource.split().""" + source = restriction.source + + if isinstance(source, _EmptyUnboundedSource): + return + + # The UnboundedSource API does not support splitting after a meaningful + # checkpoint has been created. + if (restriction.checkpoint is not None and + not isinstance(restriction.checkpoint, _NoopCheckpointMark)): + yield restriction + return + + try: + splits = source.split( + _DEFAULT_DESIRED_NUM_SPLITS, self._pipeline_options) + for split_source in splits: + yield _UnboundedSourceRestriction( + split_source, None, restriction.watermark) + except Exception as e: + _LOGGER.warning( + 'Exception while splitting source: %s. 
Source not split.', e) + yield restriction + + def restriction_size(self, element, restriction): + """Returns a size estimate for the given restriction.""" + if isinstance(restriction.source, _EmptyUnboundedSource): + return 0 + return 1 + + def restriction_coder(self): + return self._restriction_coder + + def truncate(self, element, restriction): + """For unbounded sources, truncate returns None to indicate that the + restriction should be processed until a checkpoint is possible when + draining.""" + return None + + +class _SDFUnboundedSourceReader(PTransform): + """A PTransform that uses Splittable DoFn to read from each UnboundedSource + in a PCollection. + + This is the Python equivalent of Java's UnboundedSourceAsSDFWrapperFn. + The source element is the UnboundedSource itself, and the restriction is + an _UnboundedSourceRestriction (source + checkpoint + watermark). + """ + def __init__(self, data_to_display=None): + self._data_to_display = data_to_display or {} + super().__init__() + + def _create_sdf_unbounded_source_dofn(self): + from apache_beam.io.watermark_estimators import ManualWatermarkEstimator + + class _ManualWatermarkEstimatorProvider(core.WatermarkEstimatorProvider): + def initial_estimator_state(self, element, restriction): + return restriction.watermark or timestamp.MIN_TIMESTAMP + + def create_watermark_estimator(self, estimator_state): + if estimator_state is None: + estimator_state = timestamp.MIN_TIMESTAMP + if not isinstance(estimator_state, timestamp.Timestamp): + estimator_state = timestamp.Timestamp.of(estimator_state) + return ManualWatermarkEstimator(estimator_state) + + def estimator_state_coder(self): + return coders.registry.get_coder(object) + + class SDFUnboundedSourceDoFn(core.DoFn): + def __init__(self, dd): + self._dd = dd + + def display_data(self): + return self._dd + + @core.DoFn.unbounded_per_element() + def process( + self, + element, + restriction_tracker=core.DoFn.RestrictionParam( + 
_UnboundedSourceRestrictionProvider()), + watermark_estimator=core.DoFn.WatermarkEstimatorParam( + _ManualWatermarkEstimatorProvider()), + bundle_finalizer=core.DoFn.BundleFinalizerParam()): + out = [None] + while restriction_tracker.try_claim(out): + if out[0] is not None: + value, ts, watermark, record_id = out[0] + # Update the watermark estimator + if not isinstance(watermark, timestamp.Timestamp): + watermark = timestamp.Timestamp.of(watermark) + if watermark < timestamp.MIN_TIMESTAMP: + watermark = timestamp.MIN_TIMESTAMP + elif watermark > timestamp.MAX_TIMESTAMP: + watermark = timestamp.MAX_TIMESTAMP + watermark_estimator.set_watermark(watermark) + + if not isinstance(ts, timestamp.Timestamp): + ts = timestamp.Timestamp.of(ts) + yield window.TimestampedValue(value, ts) + out = [None] + else: + # No data currently available, yield control back to the + # runner which will resume this element later. + break + + # Register checkpoint finalization if we have a non-trivial + # checkpoint + current = restriction_tracker.current_restriction() + checkpoint = current.checkpoint + if (checkpoint is not None and + not isinstance(checkpoint, _NoopCheckpointMark)): + bundle_finalizer.register(checkpoint.finalize_checkpoint) + + # Update watermark even if no elements were output + current_watermark = current.watermark + if current_watermark is not None: + if not isinstance(current_watermark, timestamp.Timestamp): + current_watermark = timestamp.Timestamp.of( + current_watermark) + if current_watermark >= timestamp.MIN_TIMESTAMP: + try: + watermark_estimator.set_watermark(current_watermark) + except ValueError: + pass # Watermark must be monotonically increasing + + # If the source is the empty marker, we are done. + # Otherwise signal the runner to resume later. 
+ if not isinstance(current.source, _EmptyUnboundedSource): + yield core.ProcessContinuation.resume() + + return SDFUnboundedSourceDoFn(self._data_to_display) + + def expand(self, pcoll): + return pcoll | core.ParDo(self._create_sdf_unbounded_source_dofn()) + + def get_windowing(self, unused_inputs): + return core.Windowing(window.GlobalWindows()) + + def display_data(self): + return self._data_to_display diff --git a/sdks/python/apache_beam/io/periodic_impulse_source.py b/sdks/python/apache_beam/io/periodic_impulse_source.py new file mode 100644 index 000000000000..359bb31bf2f5 --- /dev/null +++ b/sdks/python/apache_beam/io/periodic_impulse_source.py @@ -0,0 +1,149 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""A native Python streaming IO based on UnboundedSource. + +This module provides ``PeriodicImpulseSource``, an example of a native Python +streaming IO built on top of the ``UnboundedSource`` / ``UnboundedReader`` API. + +It generates timestamp-based "impulses" at a regular interval, making it useful +as a streaming trigger or heartbeat source in Beam pipelines. 
+ +Example usage:: + + import apache_beam as beam + from apache_beam.io.iobase import Read + from apache_beam.io.periodic_impulse_source import PeriodicImpulseSource + + with beam.Pipeline() as p: + impulses = ( + p + | Read(PeriodicImpulseSource( + fire_interval=1.0, # seconds between impulses + max_elements=10)) # stop after 10 impulses + | beam.Map(print)) + +This also serves as an example for how to implement your own custom +``UnboundedSource`` with checkpointing and watermark support. +""" + +import time + +from apache_beam.io.iobase import CheckpointMark +from apache_beam.io.iobase import UnboundedReader +from apache_beam.io.iobase import UnboundedSource +from apache_beam.utils.timestamp import Timestamp + + +class _PeriodicImpulseCheckpointMark(CheckpointMark): + """Checkpoint that tracks how many impulses have been emitted.""" + def __init__(self, count=0): + self.count = count + + def finalize_checkpoint(self): + pass # No external resources to finalize + + +class _PeriodicImpulseReader(UnboundedReader): + """Reader that generates impulses at a regular interval. + + Each impulse is a ``Timestamp`` representing when the impulse fired. 
+ """ + def __init__(self, source, checkpoint_mark=None): + self._source = source + self._count = checkpoint_mark.count if checkpoint_mark else 0 + self._current = None + self._current_ts = None + + def start(self): + return self._try_produce() + + def advance(self): + """Waits one fire interval, then produces the next impulse. + + Returns: + True if a new impulse was produced, False once max_elements is reached. + """ + # Check exhaustion first so a finished reader does not block for a full + # fire interval only to report that nothing is left. + if self._is_exhausted(): + return False + if self._source.fire_interval > 0: + time.sleep(self._source.fire_interval) + return self._try_produce() + + def _is_exhausted(self): + """Returns True once max_elements impulses have been produced.""" + return ( + self._source.max_elements is not None and + self._count >= self._source.max_elements) + + def _try_produce(self): + if self._is_exhausted(): + return False + now = Timestamp.now() + self._current = self._count + self._current_ts = now + self._count += 1 + return True + + def get_current(self): + if self._current is None: + raise StopIteration('No current element.') + return self._current + + def get_current_timestamp(self): + return self._current_ts or Timestamp.now() + + def get_current_record_id(self): + return str(self._current).encode('utf-8') + + def get_watermark(self): + if self._is_exhausted(): + from apache_beam.utils.timestamp import MAX_TIMESTAMP + return MAX_TIMESTAMP + return Timestamp.now() + + def get_checkpoint_mark(self): + return _PeriodicImpulseCheckpointMark(self._count) + + def get_current_source(self): + return self._source + + def get_split_backlog_bytes(self): + if self._source.max_elements is not None: + return max(0, self._source.max_elements - self._count) + return UnboundedReader.BACKLOG_UNKNOWN + + def close(self): + pass + + +class PeriodicImpulseSource(UnboundedSource): + """An ``UnboundedSource`` that generates periodic impulses. + + Each output element is an integer sequence number. The associated timestamp + is the wall-clock time when the impulse was generated. + + This source supports checkpointing: if the pipeline is interrupted and + resumed, it continues from where it left off based on the element count. + + Args: + fire_interval: Seconds between each impulse (default 1.0). 
+ max_elements: If set, stop after this many elements. If None, runs + indefinitely (truly unbounded). + """ + def __init__(self, fire_interval=1.0, max_elements=None): + self.fire_interval = fire_interval + self.max_elements = max_elements + + def split(self, desired_num_splits, pipeline_options=None): + """This source does not split — it is a single logical stream.""" + return [self] + + def create_reader(self, pipeline_options, checkpoint_mark=None): + return _PeriodicImpulseReader(self, checkpoint_mark) + + def requires_deduping(self): + return False diff --git a/sdks/python/apache_beam/io/unbounded_source_test.py b/sdks/python/apache_beam/io/unbounded_source_test.py new file mode 100644 index 000000000000..aca01d523c73 --- /dev/null +++ b/sdks/python/apache_beam/io/unbounded_source_test.py @@ -0,0 +1,567 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +"""Tests for UnboundedSource implemented as a Splittable DoFn wrapper.""" + +import logging +import unittest + +import apache_beam as beam +from apache_beam.io.iobase import CheckpointMark +from apache_beam.io.iobase import NOOP_CHECKPOINT_MARK +from apache_beam.io.iobase import UnboundedReader +from apache_beam.io.iobase import UnboundedSource +from apache_beam.io.iobase import _EmptyUnboundedSource +from apache_beam.io.iobase import _NoopCheckpointMark +from apache_beam.io.iobase import _UnboundedSourceRestriction +from apache_beam.io.iobase import _UnboundedSourceRestrictionCoder +from apache_beam.io.iobase import _UnboundedSourceRestrictionProvider +from apache_beam.io.iobase import _UnboundedSourceRestrictionTracker +from apache_beam.io.iobase import Read +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to +from apache_beam.utils.timestamp import MIN_TIMESTAMP +from apache_beam.utils.timestamp import MAX_TIMESTAMP +from apache_beam.utils.timestamp import Timestamp + + +# ---- Test CheckpointMark ---- + +class _CounterCheckpointMark(CheckpointMark): + """A simple checkpoint mark that stores the count of elements read.""" + def __init__(self, count=0): + self.count = count + self.finalized = False + + def finalize_checkpoint(self): + self.finalized = True + + +# ---- Test UnboundedReader ---- + +class _CounterUnboundedReader(UnboundedReader): + """An UnboundedReader that generates a finite sequence of integers. + Used for testing the SDF wrapper. 
+ """ + def __init__(self, source, checkpoint_mark=None): + self._source = source + self._count = checkpoint_mark.count if checkpoint_mark else 0 + self._current = None + self._started = False + + def start(self): + self._started = True + if self._count < self._source.num_elements: + self._current = self._count + self._count += 1 + return True + return False + + def advance(self): + if self._count < self._source.num_elements: + self._current = self._count + self._count += 1 + return True + return False + + def get_current(self): + if self._current is None: + raise StopIteration('No current element.') + return self._current + + def get_current_timestamp(self): + return Timestamp.of(self._current or 0) + + def get_current_record_id(self): + return str(self._current).encode('utf-8') + + def get_watermark(self): + if self._count >= self._source.num_elements: + return MAX_TIMESTAMP + return Timestamp.of(self._count) + + def get_checkpoint_mark(self): + return _CounterCheckpointMark(self._count) + + def get_current_source(self): + return self._source + + def get_split_backlog_bytes(self): + remaining = self._source.num_elements - self._count + return max(0, remaining) + + def close(self): + pass + + +# ---- Test UnboundedSource ---- + +class _CounterUnboundedSource(UnboundedSource): + """An UnboundedSource that produces a finite sequence of integers. + Used for testing the SDF wrapper end-to-end. 
+ """ + def __init__(self, num_elements, start=0): + self.num_elements = num_elements + self.start = start + + def split(self, desired_num_splits, pipeline_options=None): + # For simplicity, don't split further + return [self] + + def create_reader(self, pipeline_options, checkpoint_mark=None): + return _CounterUnboundedReader(self, checkpoint_mark) + + def requires_deduping(self): + return False + + +class _SplittableCounterSource(UnboundedSource): + """An UnboundedSource that can split into multiple sub-sources.""" + def __init__(self, num_elements, num_splits=1, split_index=0): + self.num_elements = num_elements + self.num_splits = num_splits + self.split_index = split_index + + def split(self, desired_num_splits, pipeline_options=None): + actual_splits = min(desired_num_splits, self.num_elements) + if actual_splits <= 1: + return [self] + sources = [] + per_split = self.num_elements // actual_splits + for i in range(actual_splits): + start = i * per_split + end = (i + 1) * per_split if i < actual_splits - 1 else self.num_elements + sources.append( + _RangeCounterSource(start, end)) + return sources + + def create_reader(self, pipeline_options, checkpoint_mark=None): + return _CounterUnboundedReader(self, checkpoint_mark) + + +class _RangeCounterSource(UnboundedSource): + """A source that generates integers in a specific range.""" + def __init__(self, start, end): + self.start_val = start + self.end_val = end + self.num_elements = end - start + + def split(self, desired_num_splits, pipeline_options=None): + return [self] + + def create_reader(self, pipeline_options, checkpoint_mark=None): + return _RangeCounterReader(self, checkpoint_mark) + + +class _RangeCounterReader(UnboundedReader): + def __init__(self, source, checkpoint_mark=None): + self._source = source + self._offset = checkpoint_mark.count if checkpoint_mark else 0 + self._current = None + + def start(self): + actual_pos = self._source.start_val + self._offset + if actual_pos < self._source.end_val: + 
self._current = actual_pos + self._offset += 1 + return True + return False + + def advance(self): + actual_pos = self._source.start_val + self._offset + if actual_pos < self._source.end_val: + self._current = actual_pos + self._offset += 1 + return True + return False + + def get_current(self): + return self._current + + def get_current_timestamp(self): + return Timestamp.of(self._current or 0) + + def get_current_record_id(self): + return str(self._current).encode('utf-8') + + def get_watermark(self): + actual_pos = self._source.start_val + self._offset + if actual_pos >= self._source.end_val: + return MAX_TIMESTAMP + return Timestamp.of(actual_pos) + + def get_checkpoint_mark(self): + return _CounterCheckpointMark(self._offset) + + def get_current_source(self): + return self._source + + def close(self): + pass + + +# ==== Test classes ==== + +class CheckpointMarkTest(unittest.TestCase): + def test_noop_checkpoint_mark(self): + """NoopCheckpointMark should be finalizable without error.""" + NOOP_CHECKPOINT_MARK.finalize_checkpoint() + + def test_custom_checkpoint_mark(self): + """Custom CheckpointMark should support finalization.""" + mark = _CounterCheckpointMark(42) + self.assertEqual(mark.count, 42) + self.assertFalse(mark.finalized) + mark.finalize_checkpoint() + self.assertTrue(mark.finalized) + + +class UnboundedSourceBaseClassTest(unittest.TestCase): + def test_is_bounded(self): + source = _CounterUnboundedSource(10) + self.assertFalse(source.is_bounded()) + + def test_requires_deduping_default(self): + source = _CounterUnboundedSource(10) + self.assertFalse(source.requires_deduping()) + + def test_split(self): + source = _CounterUnboundedSource(10) + splits = source.split(3) + self.assertEqual(len(splits), 1) + + def test_create_reader(self): + source = _CounterUnboundedSource(10) + reader = source.create_reader(None) + self.assertIsInstance(reader, UnboundedReader) + + +class UnboundedReaderTest(unittest.TestCase): + def 
test_reader_produces_elements(self): + source = _CounterUnboundedSource(5) + reader = source.create_reader(None) + elements = [] + if reader.start(): + elements.append(reader.get_current()) + while reader.advance(): + elements.append(reader.get_current()) + self.assertEqual(elements, [0, 1, 2, 3, 4]) + + def test_reader_empty_source(self): + source = _CounterUnboundedSource(0) + reader = source.create_reader(None) + self.assertFalse(reader.start()) + + def test_reader_checkpoint_resume(self): + source = _CounterUnboundedSource(10) + reader = source.create_reader(None) + # Read first 3 elements + reader.start() + reader.advance() + reader.advance() + checkpoint = reader.get_checkpoint_mark() + self.assertEqual(checkpoint.count, 3) + + # Resume from checkpoint + reader2 = source.create_reader(None, checkpoint) + elements = [] + if reader2.start(): + elements.append(reader2.get_current()) + while reader2.advance(): + elements.append(reader2.get_current()) + self.assertEqual(elements, [3, 4, 5, 6, 7, 8, 9]) + + def test_reader_watermark(self): + source = _CounterUnboundedSource(5) + reader = source.create_reader(None) + reader.start() + watermark = reader.get_watermark() + self.assertIsNotNone(watermark) + + def test_reader_record_id(self): + source = _CounterUnboundedSource(5) + reader = source.create_reader(None) + reader.start() + record_id = reader.get_current_record_id() + self.assertIsInstance(record_id, bytes) + + def test_reader_timestamp(self): + source = _CounterUnboundedSource(5) + reader = source.create_reader(None) + reader.start() + ts = reader.get_current_timestamp() + self.assertIsInstance(ts, Timestamp) + + def test_reader_backlog(self): + source = _CounterUnboundedSource(10) + reader = source.create_reader(None) + reader.start() + backlog = reader.get_split_backlog_bytes() + self.assertGreater(backlog, 0) + + +class UnboundedSourceRestrictionTest(unittest.TestCase): + def test_restriction_creation(self): + source = _CounterUnboundedSource(10) + 
restriction = _UnboundedSourceRestriction(source) + self.assertEqual(restriction.source, source) + self.assertIsNone(restriction.checkpoint) + self.assertEqual(restriction.watermark, MIN_TIMESTAMP) + + def test_restriction_with_checkpoint(self): + source = _CounterUnboundedSource(10) + checkpoint = _CounterCheckpointMark(5) + restriction = _UnboundedSourceRestriction( + source, checkpoint, Timestamp.of(5)) + self.assertEqual(restriction.checkpoint, checkpoint) + self.assertEqual(restriction.watermark, Timestamp.of(5)) + + +class UnboundedSourceRestrictionTrackerTest(unittest.TestCase): + def test_try_claim_produces_elements(self): + source = _CounterUnboundedSource(3) + restriction = _UnboundedSourceRestriction(source) + tracker = _UnboundedSourceRestrictionTracker(restriction) + + elements = [] + out = [None] + while tracker.try_claim(out): + if out[0] is not None: + value, ts, watermark, record_id = out[0] + elements.append(value) + out = [None] + else: + break + self.assertEqual(elements, [0, 1, 2]) + + def test_current_progress(self): + source = _CounterUnboundedSource(10) + restriction = _UnboundedSourceRestriction(source) + tracker = _UnboundedSourceRestrictionTracker(restriction) + + out = [None] + tracker.try_claim(out) # triggers reader creation + progress = tracker.current_progress() + self.assertIsNotNone(progress) + + def test_try_split(self): + source = _CounterUnboundedSource(10) + restriction = _UnboundedSourceRestriction(source) + tracker = _UnboundedSourceRestrictionTracker(restriction) + + # Read some elements first + out = [None] + tracker.try_claim(out) + tracker.try_claim(out) + + # Now split + result = tracker.try_split(0.5) + self.assertIsNotNone(result) + primary, residual = result + self.assertIsInstance(primary.source, _EmptyUnboundedSource) + self.assertIsNotNone(residual.checkpoint) + + def test_try_split_before_start_returns_none(self): + source = _CounterUnboundedSource(10) + restriction = _UnboundedSourceRestriction(source) + tracker 
= _UnboundedSourceRestrictionTracker(restriction) + + result = tracker.try_split(0.5) + self.assertIsNone(result) + + def test_is_bounded(self): + source = _CounterUnboundedSource(10) + restriction = _UnboundedSourceRestriction(source) + tracker = _UnboundedSourceRestrictionTracker(restriction) + self.assertFalse(tracker.is_bounded()) + + +class UnboundedSourceRestrictionCoderTest(unittest.TestCase): + def test_encode_decode_roundtrip(self): + source = _CounterUnboundedSource(10) + checkpoint = _CounterCheckpointMark(5) + restriction = _UnboundedSourceRestriction( + source, checkpoint, Timestamp.of(5)) + + coder = _UnboundedSourceRestrictionCoder() + encoded = coder.encode(restriction) + decoded = coder.decode(encoded) + + self.assertIsInstance(decoded, _UnboundedSourceRestriction) + self.assertIsInstance(decoded.source, _CounterUnboundedSource) + self.assertEqual(decoded.source.num_elements, 10) + self.assertEqual(decoded.checkpoint.count, 5) + + def test_encode_decode_no_checkpoint(self): + source = _CounterUnboundedSource(10) + restriction = _UnboundedSourceRestriction(source) + + coder = _UnboundedSourceRestrictionCoder() + encoded = coder.encode(restriction) + decoded = coder.decode(encoded) + + self.assertIsInstance(decoded, _UnboundedSourceRestriction) + self.assertIsNone(decoded.checkpoint) + + +class UnboundedSourceRestrictionProviderTest(unittest.TestCase): + def test_initial_restriction(self): + source = _CounterUnboundedSource(10) + provider = _UnboundedSourceRestrictionProvider() + restriction = provider.initial_restriction(source) + self.assertIsInstance(restriction, _UnboundedSourceRestriction) + self.assertEqual(restriction.source, source) + self.assertIsNone(restriction.checkpoint) + + def test_create_tracker(self): + source = _CounterUnboundedSource(10) + provider = _UnboundedSourceRestrictionProvider() + restriction = provider.initial_restriction(source) + tracker = provider.create_tracker(restriction) + self.assertIsInstance(tracker, 
_UnboundedSourceRestrictionTracker) + + def test_split_initial(self): + source = _CounterUnboundedSource(10) + provider = _UnboundedSourceRestrictionProvider() + restriction = provider.initial_restriction(source) + splits = list(provider.split(source, restriction)) + # CounterUnboundedSource doesn't split, so we should get 1 restriction + self.assertEqual(len(splits), 1) + + def test_split_with_checkpoint_does_not_split(self): + source = _CounterUnboundedSource(10) + checkpoint = _CounterCheckpointMark(5) + restriction = _UnboundedSourceRestriction(source, checkpoint) + provider = _UnboundedSourceRestrictionProvider() + splits = list(provider.split(source, restriction)) + self.assertEqual(len(splits), 1) + self.assertEqual(splits[0].checkpoint, checkpoint) + + def test_restriction_size(self): + source = _CounterUnboundedSource(10) + provider = _UnboundedSourceRestrictionProvider() + restriction = provider.initial_restriction(source) + size = provider.restriction_size(source, restriction) + self.assertGreater(size, 0) + + def test_truncate_returns_none(self): + source = _CounterUnboundedSource(10) + provider = _UnboundedSourceRestrictionProvider() + restriction = provider.initial_restriction(source) + result = provider.truncate(source, restriction) + self.assertIsNone(result) + + def test_invalid_source_type(self): + provider = _UnboundedSourceRestrictionProvider() + with self.assertRaises(RuntimeError): + provider.initial_restriction("not a source") + + +class EmptyUnboundedSourceTest(unittest.TestCase): + def test_empty_reader(self): + source = _EmptyUnboundedSource() + reader = source.create_reader(None) + self.assertFalse(reader.start()) + self.assertFalse(reader.advance()) + self.assertEqual(reader.get_watermark(), MAX_TIMESTAMP) + + def test_empty_reader_checkpoint(self): + checkpoint = _CounterCheckpointMark(5) + source = _EmptyUnboundedSource() + reader = source.create_reader(None, checkpoint) + mark = reader.get_checkpoint_mark() + self.assertEqual(mark, 
checkpoint) + + +class ReadUnboundedSourceEndToEndTest(unittest.TestCase): + """End-to-end tests using Read(UnboundedSource) via the DirectRunner.""" + + def test_read_unbounded_source_simple(self): + """Read from a simple counter unbounded source.""" + with TestPipeline() as p: + result = p | Read(_CounterUnboundedSource(5)) + assert_that(result, equal_to([0, 1, 2, 3, 4])) + + def test_read_unbounded_source_empty(self): + """Read from an empty unbounded source.""" + with TestPipeline() as p: + result = p | Read(_CounterUnboundedSource(0)) + assert_that(result, equal_to([])) + + def test_read_unbounded_source_with_map(self): + """Read from unbounded source and apply a Map transform.""" + with TestPipeline() as p: + result = ( + p + | Read(_CounterUnboundedSource(5)) + | beam.Map(lambda x: x * 2)) + assert_that(result, equal_to([0, 2, 4, 6, 8])) + + def test_read_unbounded_source_larger(self): + """Read a larger set of elements.""" + n = 100 + with TestPipeline() as p: + result = p | Read(_CounterUnboundedSource(n)) + assert_that(result, equal_to(list(range(n)))) + + +class PeriodicImpulseSourceEndToEndTest(unittest.TestCase): + """End-to-end tests for the native Python streaming IO example.""" + + def test_periodic_impulse_finite(self): + """PeriodicImpulseSource with max_elements produces the right count.""" + from apache_beam.io.periodic_impulse_source import PeriodicImpulseSource + with TestPipeline() as p: + result = ( + p + | Read(PeriodicImpulseSource(fire_interval=0, max_elements=5)) + | beam.Map(lambda x: x)) + assert_that(result, equal_to([0, 1, 2, 3, 4])) + + def test_periodic_impulse_empty(self): + """PeriodicImpulseSource with max_elements=0 produces nothing.""" + from apache_beam.io.periodic_impulse_source import PeriodicImpulseSource + with TestPipeline() as p: + result = p | Read(PeriodicImpulseSource(fire_interval=0, max_elements=0)) + assert_that(result, equal_to([])) + + def test_periodic_impulse_checkpoint_resume(self): + """PeriodicImpulseSource 
supports checkpoint/resume at the reader level.""" + from apache_beam.io.periodic_impulse_source import PeriodicImpulseSource + source = PeriodicImpulseSource(fire_interval=0, max_elements=10) + reader = source.create_reader(None) + # Read 3 elements + reader.start() + reader.advance() + reader.advance() + mark = reader.get_checkpoint_mark() + self.assertEqual(mark.count, 3) + + # Resume from checkpoint — should produce 7 more + reader2 = source.create_reader(None, mark) + elements = [] + if reader2.start(): + elements.append(reader2.get_current()) + while reader2.advance(): + elements.append(reader2.get_current()) + self.assertEqual(len(elements), 7) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main()