diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index 4d070da995b3..4d04c97a1eca 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -178,6 +178,9 @@ public final class StreamingDataflowWorker { // Experiment make the monitor within BoundedQueueExecutor fair public static final String BOUNDED_QUEUE_EXECUTOR_USE_FAIR_MONITOR_EXPERIMENT = "windmill_bounded_queue_executor_use_fair_monitor"; + // Don't use. Experiment guarding multi key bundles. The feature is work in progress and + // incomplete. + private static final String UNSTABLE_ENABLE_MULTI_KEY_BUNDLE = "unstable_enable_multi_key_bundle"; private final WindmillStateCache stateCache; private AtomicReference statusPages = new AtomicReference<>(); @@ -1017,6 +1020,8 @@ private static JobHeader createJobHeader(DataflowWorkerHarnessOptions options, l private static BoundedQueueExecutor createWorkUnitExecutor(DataflowWorkerHarnessOptions options) { boolean useFairMonitor = DataflowRunner.hasExperiment(options, BOUNDED_QUEUE_EXECUTOR_USE_FAIR_MONITOR_EXPERIMENT); + boolean useKeyGroupWorkQueue = + DataflowRunner.hasExperiment(options, UNSTABLE_ENABLE_MULTI_KEY_BUNDLE); return new BoundedQueueExecutor( chooseMaxThreads(options), THREAD_EXPIRATION_TIME_SEC, @@ -1024,7 +1029,8 @@ private static BoundedQueueExecutor createWorkUnitExecutor(DataflowWorkerHarness chooseMaxBundlesOutstanding(options), chooseMaxBytesOutstanding(options), new ThreadFactoryBuilder().setNameFormat("DataflowWorkUnits-%d").setDaemon(true).build(), - useFairMonitor); + useFairMonitor, + useKeyGroupWorkQueue); } public 
static void main(String[] args) throws Exception { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java index ecaa673f5570..432227c9253f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.dataflow.worker.streaming; import java.util.Objects; +import java.util.Optional; import java.util.function.BiConsumer; import org.apache.beam.runners.dataflow.worker.util.ExceptionUtils; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; @@ -62,11 +63,11 @@ public void run(BoundedQueueExecutorWorkHandle handle) { } } - public final WorkId id() { + public WorkId id() { return work().id(); } - public final Windmill.WorkItem getWorkItem() { + public Windmill.WorkItem getWorkItem() { return work().getWorkItem(); } @@ -74,4 +75,12 @@ public final Windmill.WorkItem getWorkItem() { public String toString() { return "ExecutableWork{" + id() + "}"; } + + public String getComputationId() { + return work().getComputationId(); + } + + public Optional getKeyGroup() { + return work().getKeyGroup(); + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java index cb01e1e508ce..9dbaf1bb519c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java +++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java @@ -25,6 +25,7 @@ import java.util.IntSummaryStatistics; import java.util.Map; import java.util.Map.Entry; +import java.util.Objects; import java.util.Optional; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; @@ -52,6 +53,7 @@ import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Duration; import org.joda.time.Instant; @@ -74,6 +76,7 @@ public final class Work implements RefreshableWork { private final Instant startTime; private final Map totalDurationPerState; private final WorkId id; + private final Optional keyGroup; private final String latencyTrackingId; private final long serializedWorkItemSize; private volatile TimedState currentState; @@ -101,6 +104,11 @@ private Work( // keyUniverse inside EnumMap every time. this.totalDurationPerState = new EnumMap<>(EMPTY_ENUM_MAP); this.id = WorkId.of(workItem); + this.keyGroup = + workItem.hasKeyGroup() + ? 
Optional.of( + KeyGroup.create(workItem.getKeyGroup().getHigh(), workItem.getKeyGroup().getLow())) + : Optional.empty(); this.latencyTrackingId = Long.toHexString(workItem.getShardingKey()) + '-' @@ -383,6 +391,14 @@ private boolean isCommitPending() { abstract Instant startTime(); } + public String getComputationId() { + return processingContext.computationId(); + } + + public Optional getKeyGroup() { + return keyGroup; + } + @AutoValue public abstract static class ProcessingContext { @@ -416,4 +432,48 @@ private Optional fetchKeyedState(KeyedGetDataRequest reque return Optional.ofNullable(getDataClient().getStateData(computationId(), request)); } } + + public static final class KeyGroup { + private final long high; + private final long low; + + private KeyGroup(long high, long low) { + this.high = high; + this.low = low; + } + + public static KeyGroup create(long high, long low) { + return new KeyGroup(high, low); + } + + public long high() { + return high; + } + + public long low() { + return low; + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (!(o instanceof KeyGroup)) { + return false; + } + KeyGroup other = (KeyGroup) o; + return high == other.high && low == other.low; + } + + @Override + public int hashCode() { + return Objects.hash(high, low); + } + + @Override + public String toString() { + return "KeyGroup{" + "high=" + high + ", low=" + low + '}'; + } + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java index c6fd96e0a4cb..57046147e204 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java +++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java @@ -20,6 +20,7 @@ import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import java.util.Optional; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadFactory; @@ -29,6 +30,7 @@ import javax.annotation.concurrent.GuardedBy; import org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle; import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; +import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Monitor; @@ -85,7 +87,8 @@ public BoundedQueueExecutor( int maximumElementsOutstanding, long maximumBytesOutstanding, ThreadFactory threadFactory, - boolean useFairMonitor) { + boolean useFairMonitor, + boolean useKeyGroupWorkQueue) { this.maximumPoolSize = initialMaximumPoolSize; monitor = new Monitor(useFairMonitor); executor = @@ -94,7 +97,7 @@ public BoundedQueueExecutor( initialMaximumPoolSize, keepAliveTime, unit, - new LinkedBlockingQueue<>(), + useKeyGroupWorkQueue ? 
new KeyGroupWorkQueue() : new LinkedBlockingQueue<>(), threadFactory) { @Override protected void beforeExecute(Thread t, Runnable r) { @@ -313,7 +316,7 @@ public synchronized void close() { } } - private static final class QueuedWork implements Runnable { + static final class QueuedWork implements Runnable { private final ExecutableWork work; private final BoundedQueueExecutorWorkHandleImpl handle; @@ -378,6 +381,23 @@ BoundedQueueExecutorWorkHandleImpl createBudgetHandle(int elements, long bytes) return new BoundedQueueExecutorWorkHandleImpl(elements, bytes); } + /** Poll work for a specific computationId and keyGroup. */ + public Optional pollWork( + String computationId, Work.KeyGroup keyGroup, BoundedQueueExecutorWorkHandle handle) { + checkArgument(handle instanceof BoundedQueueExecutorWorkHandleImpl); + BoundedQueueExecutorWorkHandleImpl internalHandle = (BoundedQueueExecutorWorkHandleImpl) handle; + if (!(executor.getQueue() instanceof KeyGroupWorkQueue)) { + return Optional.empty(); + } + QueuedWork queuedWork = + ((KeyGroupWorkQueue) executor.getQueue()).pollWork(computationId, keyGroup); + if (queuedWork == null) { + return Optional.empty(); + } + internalHandle.merge(queuedWork.getHandle()); + return Optional.of(queuedWork.getWork()); + } + private void decrementCounters(int elements, long bytes) { // All threads queue decrements and one thread grabs the monitor and updates // counters. 
We do this to reduce contention on monitor which is locked by diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueue.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueue.java new file mode 100644 index 000000000000..50494edf70bf --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueue.java @@ -0,0 +1,463 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.util; + +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; + +import java.util.AbstractQueue; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; +import org.apache.beam.runners.dataflow.worker.streaming.Work; +import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor.QueuedWork; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * A custom, thread-safe doubly-linked BlockingQueue. In addition to global FIFO ordering, the queue + * supports polling work by computation and key group in FIFO order. + */ +class KeyGroupWorkQueue extends AbstractQueue implements BlockingQueue { + + static class Node { + final @Nullable Runnable task; + final @Nullable String computationId; + final Work.@Nullable KeyGroup keyGroup; + + // prevNode, nextNode are used for the global order across all queued Runnables + @Nullable Node prevNode; + @Nullable Node nextNode; + + // prevKeyGroupNode and nextKeyGroupNode are used for the keyGroup level lists linking + // QueuedWork with the same keyGroup + @Nullable Node prevKeyGroupNode; + @Nullable Node nextKeyGroupNode; + + Node(@Nullable Runnable task) { + this.task = task; + if (task instanceof QueuedWork) { + this.computationId = ((QueuedWork) task).getWork().getComputationId(); + this.keyGroup = ((QueuedWork) task).getWork().getKeyGroup().orElse(null); + } else { + this.computationId = null; + this.keyGroup = null; + } + } + } + + /** Doubly-linked list implementing the key group level queue. */ + private static class KeyGroupWorkList { + final Node head; + final Node tail; + + KeyGroupWorkList() { + 
head = new Node(null); + tail = new Node(null); + head.nextKeyGroupNode = tail; + tail.prevKeyGroupNode = head; + } + + boolean isEmpty() { + return head.nextKeyGroupNode == tail; + } + + void append(Node node) { + @Nullable Node last = tail.prevKeyGroupNode; + if (last == null) { + throw new NullPointerException("tail.prevComp is null"); + } + node.prevKeyGroupNode = last; + node.nextKeyGroupNode = tail; + last.nextKeyGroupNode = node; + tail.prevKeyGroupNode = node; + } + + void remove(Node node) { + @Nullable Node prev = node.prevKeyGroupNode; + @Nullable Node next = node.nextKeyGroupNode; + if (prev != null && next != null) { + prev.nextKeyGroupNode = next; + next.prevKeyGroupNode = prev; + node.prevKeyGroupNode = null; + node.nextKeyGroupNode = null; + } + } + } + + private final ReentrantLock lock = new ReentrantLock(); + private final Condition notEmpty = lock.newCondition(); + + // Sentinels for the global list + private final Node globalHead = new Node(null); + private final Node globalTail = new Node(null); + + private final Map keyGroupQueueMap = new HashMap<>(); + + private int size = 0; + + public KeyGroupWorkQueue() { + globalHead.nextNode = globalTail; + globalTail.prevNode = globalHead; + } + + private void unlinkNode(Node node) { + // 1. Unlink from global list + Node prevG = node.prevNode; + Node nextG = node.nextNode; + if (prevG != null && nextG != null) { + prevG.nextNode = nextG; + nextG.prevNode = prevG; + } + node.prevNode = null; + node.nextNode = null; + + // 2. 
Unlink from key group list + if (node.computationId != null) { + QueueKey key = QueueKey.create(node.computationId, node.keyGroup); + KeyGroupWorkList keyGroupQueue = keyGroupQueueMap.get(key); + if (keyGroupQueue != null) { + keyGroupQueue.remove(node); + if (keyGroupQueue.isEmpty()) { + keyGroupQueueMap.remove(key); + } + } + } + --size; + } + + private @Nullable Node removeFirstGlobal() { + @Nullable Node first = globalHead.nextNode; + if (first == null || first == globalTail) { + return null; + } + unlinkNode(first); + return first; + } + + /** + * Removes and returns the first QueuedWork for the given computationId and keyGroup, in FIFO + * order. Returns null if there are no matches. + */ + public @Nullable QueuedWork pollWork(String computationId, Work.KeyGroup keyGroup) { + if (computationId == null || keyGroup == null) { + return null; + } + lock.lock(); + try { + QueueKey key = QueueKey.create(computationId, keyGroup); + KeyGroupWorkList keyGroupWorkList = keyGroupQueueMap.get(key); + if (keyGroupWorkList == null || keyGroupWorkList.isEmpty()) { + return null; + } + + // Retrieve the first pending task for this computation and keyGroup in O(1) + @Nullable Node firstNode = keyGroupWorkList.head.nextKeyGroupNode; + if (firstNode == null || firstNode == keyGroupWorkList.tail) { + return null; + } + unlinkNode(firstNode); + + Runnable task = firstNode.task; + if (task == null) { + return null; + } + return (QueuedWork) task; + } finally { + lock.unlock(); + } + } + + @Override + public boolean offer(Runnable runnable) { + if (runnable == null) throw new NullPointerException(); + lock.lock(); + try { + Node node = new Node(runnable); + + // Append to global list tail + @Nullable Node lastG = globalTail.prevNode; + if (lastG == null) { + throw new NullPointerException("globalTail.prevNode is null"); + } + node.prevNode = lastG; + node.nextNode = globalTail; + lastG.nextNode = node; + globalTail.prevNode = node; + + // Append to key group list if applicable + if 
(node.computationId != null) { + QueueKey key = QueueKey.create(node.computationId, node.keyGroup); + KeyGroupWorkList keyGroupWorkList = + keyGroupQueueMap.computeIfAbsent(key, k -> new KeyGroupWorkList()); + keyGroupWorkList.append(node); + } + + ++size; + notEmpty.signal(); + return true; + } finally { + lock.unlock(); + } + } + + @Override + public void put(Runnable e) throws InterruptedException { + offer(e); // Unbounded queue + } + + @Override + public boolean offer(Runnable e, long timeout, TimeUnit unit) throws InterruptedException { + return offer(e); // Unbounded queue + } + + @Override + public @Nullable Runnable poll() { + lock.lock(); + try { + @Nullable Node node = removeFirstGlobal(); + return (node != null) ? node.task : null; + } finally { + lock.unlock(); + } + } + + @Override + public Runnable take() throws InterruptedException { + lock.lockInterruptibly(); + try { + while (size == 0) { + notEmpty.await(); + } + @Nullable Node node = removeFirstGlobal(); + checkStateNotNull(node, "Queue is empty but size was " + size); + Runnable task = node.task; + checkStateNotNull(task, "Encountered null task in queue"); + return task; + } finally { + lock.unlock(); + } + } + + @Override + public @Nullable Runnable poll(long timeout, TimeUnit unit) throws InterruptedException { + long nanos = unit.toNanos(timeout); + lock.lockInterruptibly(); + try { + while (size == 0) { + if (nanos <= 0) { + return null; + } + nanos = notEmpty.awaitNanos(nanos); + } + @Nullable Node node = removeFirstGlobal(); + return (node != null) ? 
node.task : null; + } finally { + lock.unlock(); + } + } + + @Override + public @Nullable Runnable peek() { + lock.lock(); + try { + @Nullable Node first = globalHead.nextNode; + if (first == null || first == globalTail) { + return null; + } + return first.task; + } finally { + lock.unlock(); + } + } + + @Override + public int size() { + lock.lock(); + try { + return size; + } finally { + lock.unlock(); + } + } + + @Override + public boolean isEmpty() { + lock.lock(); + try { + return size == 0; + } finally { + lock.unlock(); + } + } + + @Override + public boolean remove(Object o) { + if (o == null) return false; + lock.lock(); + try { + // Walk the global queue in O(N) to find and unlink the node + @Nullable Node curr = globalHead.nextNode; + while (curr != null && curr != globalTail) { + if (o.equals(curr.task)) { + unlinkNode(curr); + return true; + } + curr = curr.nextNode; + } + return false; + } finally { + lock.unlock(); + } + } + + @Override + public boolean contains(Object o) { + if (o == null) return false; + lock.lock(); + try { + @Nullable Node curr = globalHead.nextNode; + while (curr != null && curr != globalTail) { + if (o.equals(curr.task)) { + return true; + } + curr = curr.nextNode; + } + return false; + } finally { + lock.unlock(); + } + } + + @Override + public int remainingCapacity() { + return Integer.MAX_VALUE; + } + + @Override + public int drainTo(Collection c) { + return drainTo(c, Integer.MAX_VALUE); + } + + @Override + public int drainTo(Collection c, int maxElements) { + if (c == null) throw new NullPointerException(); + if (c == this) throw new IllegalArgumentException(); + if (maxElements <= 0) return 0; + lock.lock(); + try { + int added = 0; + @Nullable Node curr = globalHead.nextNode; + while (curr != null && curr != globalTail && added < maxElements) { + @Nullable Node next = curr.nextNode; + unlinkNode(curr); + Runnable task = curr.task; + if (task != null) { + c.add(task); + ++added; + } + curr = next; + } + return added; + } 
finally { + lock.unlock(); + } + } + + @Override + public void clear() { + lock.lock(); + try { + @Nullable Node curr = globalHead.nextNode; + while (curr != null && curr != globalTail) { + @Nullable Node next = curr.nextNode; + unlinkNode(curr); + curr = next; + } + } finally { + lock.unlock(); + } + } + + @Override + public Iterator iterator() { + lock.lock(); + try { + List snapshot = new ArrayList<>(size); + @Nullable Node curr = globalHead.nextNode; + while (curr != null && curr != globalTail) { + if (curr.task != null) { + snapshot.add(curr.task); + } + curr = curr.nextNode; + } + return Collections.unmodifiableList(snapshot).iterator(); + } finally { + lock.unlock(); + } + } + + static final class QueueKey { + private final String computationId; + private final Work.@Nullable KeyGroup keyGroup; + + private QueueKey(String computationId, Work.@Nullable KeyGroup keyGroup) { + this.computationId = Objects.requireNonNull(computationId); + this.keyGroup = keyGroup; + } + + public static QueueKey create(String computationId, Work.@Nullable KeyGroup keyGroup) { + return new QueueKey(computationId, keyGroup); + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (!(o instanceof QueueKey)) { + return false; + } + QueueKey other = (QueueKey) o; + return computationId.equals(other.computationId) && Objects.equals(keyGroup, other.keyGroup); + } + + @Override + public int hashCode() { + return Objects.hash(computationId, keyGroup); + } + + @Override + public String toString() { + return "QueueKey{" + + "computationId='" + + computationId + + '\'' + + ", keyGroup=" + + keyGroup + + '}'; + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index d58f20076994..5bcdffcc2564 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -3036,7 +3036,8 @@ public void testMaxThreadMetric() throws Exception { .setNameFormat("DataflowWorkUnits-%d") .setDaemon(true) .build(), - /*useFairMonitor=*/ false); + /*useFairMonitor=*/ false, + /*useKeyGroupWorkQueue=*/ false); ComputationState computationState = new ComputationState( @@ -3097,7 +3098,8 @@ public void testActiveThreadMetric() throws Exception { .setNameFormat("DataflowWorkUnits-%d") .setDaemon(true) .build(), - /*useFairMonitor=*/ false); + /*useFairMonitor=*/ false, + /*useKeyGroupWorkQueue=*/ false); ComputationState computationState = new ComputationState( @@ -3167,7 +3169,8 @@ public void testOutstandingBytesMetric() throws Exception { .setNameFormat("DataflowWorkUnits-%d") .setDaemon(true) .build(), - /*useFairMonitor=*/ false); + /*useFairMonitor=*/ false, + /*useKeyGroupWorkQueue=*/ false); ComputationState computationState = new ComputationState( @@ -3241,7 +3244,8 @@ public void testOutstandingBundlesMetric() throws Exception { .setNameFormat("DataflowWorkUnits-%d") .setDaemon(true) .build(), - /*useFairMonitor=*/ false); + /*useFairMonitor=*/ false, + /*useKeyGroupWorkQueue=*/ false); ComputationState computationState = new ComputationState( diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java index 55fe82c7163c..aa9ce32b3c5c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java +++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java @@ -33,6 +33,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor.BoundedQueueExecutorWorkHandleImpl; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; @@ -66,13 +67,30 @@ public static Collection useFairMonitor() { @Rule public transient Timeout globalTimeout = Timeout.seconds(300); private BoundedQueueExecutor executor; + private static final Work.KeyGroup DEFAULT_KEY_GROUP = Work.KeyGroup.create(1, 2); + private static ExecutableWork createWork(Consumer executeWorkFn) { + return createWorkWithCompId("computationId", executeWorkFn); + } + + private static ExecutableWork createWorkWithCompId( + String computationId, Consumer executeWorkFn) { + return createWorkWithCompIdAndKeyGroup(computationId, DEFAULT_KEY_GROUP, executeWorkFn); + } + + private static ExecutableWork createWorkWithCompIdAndKeyGroup( + String computationId, Work.KeyGroup keyGroup, Consumer executeWorkFn) { WorkItem workItem = WorkItem.newBuilder() .setKey(ByteString.EMPTY) .setShardingKey(1) .setWorkToken(33) .setCacheToken(1) + .setKeyGroup( + Windmill.Uint128Proto.newBuilder() + .setHigh(keyGroup.high()) + .setLow(keyGroup.low()) + .build()) .build(); return ExecutableWork.create( Work.create( @@ -80,10 +98,7 @@ private static ExecutableWork createWork(Consumer executeWorkFn) { workItem.getSerializedSize(), Watermarks.builder().setInputDataWatermark(Instant.now()).build(), Work.createProcessingContext( - "computationId", - new FakeGetDataClient(), 
- ignored -> {}, - mock(HeartbeatSender.class)), + computationId, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.class)), false, Instant::now), (work, handle) -> { @@ -116,7 +131,8 @@ public void setUp() { .setNameFormat("DataflowWorkUnits-%d") .setDaemon(true) .build(), - useFairMonitor); + useFairMonitor, + /*useKeyGroupWorkQueue=*/ false); } @Test @@ -413,4 +429,129 @@ public void testRenderSummaryHtml() { + "Work Queue Bytes: 0/10000000
/n"; assertEquals(expectedSummaryHtml, executor.summaryHtml()); } + + @Test + public void testPollWork() throws Exception { + // Create separate BoundedQueueExecutor with 1 thread so we can block it easily + BoundedQueueExecutor testExecutor = + new BoundedQueueExecutor( + 1, + 60, + TimeUnit.SECONDS, + 100, + 10000000, + new ThreadFactoryBuilder().setNameFormat("testStealing-%d").setDaemon(true).build(), + useFairMonitor, + /*useKeyGroupWorkQueue=*/ true); + + // 1. Create blocker task to occupy the worker thread + CountDownLatch blockerStart = new CountDownLatch(1); + CountDownLatch blockerStop = new CountDownLatch(1); + ExecutableWork blockerWork = + createWorkWithCompIdAndKeyGroup( + "blockerComp", + DEFAULT_KEY_GROUP, + ignored -> { + blockerStart.countDown(); + try { + blockerStop.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + + testExecutor.execute(blockerWork, 0); + blockerStart.await(); + + // 2. Create two distinct key groups + Work.KeyGroup keyGroup1 = Work.KeyGroup.create(1, 1); + Work.KeyGroup keyGroup2 = Work.KeyGroup.create(1, 2); + + // Create executable tasks + CountDownLatch targetStart = new CountDownLatch(1); + ExecutableWork work1 = createWorkWithCompIdAndKeyGroup("compA", keyGroup1, ignored -> {}); + ExecutableWork work2 = + createWorkWithCompIdAndKeyGroup( + "compA", + keyGroup2, + ignored -> { + targetStart.countDown(); + }); + + // Enqueue tasks (they will wait in the queue because the thread is blocked) + testExecutor.execute(work1, 100); + testExecutor.execute(work2, 150); + + // Total outstanding elements must be 3 (blocker + work1 + work2) + assertEquals(3, testExecutor.elementsOutstanding()); + + // Steal work2 using pollWork with compA and keyGroup2 + try (BoundedQueueExecutorWorkHandleImpl stealHandle = testExecutor.createBudgetHandle(0, 0L)) { + java.util.Optional stolen = + testExecutor.pollWork("compA", keyGroup2, stealHandle); + assertTrue(stolen.isPresent()); + assertEquals(work2, 
stolen.get()); + + // Run the stolen task + stolen.get().run(stealHandle); + targetStart.await(); + } + + // Steal work1 using pollWork with compA and keyGroup1 + try (BoundedQueueExecutorWorkHandleImpl stealHandle = testExecutor.createBudgetHandle(0, 0L)) { + java.util.Optional stolen = + testExecutor.pollWork("compA", keyGroup1, stealHandle); + assertTrue(stolen.isPresent()); + assertEquals(work1, stolen.get()); + } + + // Unblock the blocker and shut down + blockerStop.countDown(); + testExecutor.shutdown(); + } + + @Test + public void testPollWorkWithLinkedBlockingQueue() throws Exception { + BoundedQueueExecutor testExecutor = + new BoundedQueueExecutor( + 1, + 60, + TimeUnit.SECONDS, + 100, + 10000000, + new ThreadFactoryBuilder().setNameFormat("testLinkedQueue-%d").setDaemon(true).build(), + useFairMonitor, + /* useKeyGroupWorkQueue= */ false); + + CountDownLatch blockerStart = new CountDownLatch(1); + CountDownLatch blockerStop = new CountDownLatch(1); + ExecutableWork blockerWork = + createWorkWithCompIdAndKeyGroup( + "blockerComp", + DEFAULT_KEY_GROUP, + ignored -> { + blockerStart.countDown(); + try { + blockerStop.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + + testExecutor.execute(blockerWork, 0); + blockerStart.await(); + + Work.KeyGroup keyGroup = Work.KeyGroup.create(1, 1); + ExecutableWork work = createWorkWithCompIdAndKeyGroup("compA", keyGroup, ignored -> {}); + testExecutor.execute(work, 100); + + try (BoundedQueueExecutorWorkHandleImpl stealHandle = testExecutor.createBudgetHandle(0, 0L)) { + java.util.Optional stolen = + testExecutor.pollWork("compA", keyGroup, stealHandle); + assertFalse(stolen.isPresent()); + } + + blockerStop.countDown(); + testExecutor.shutdown(); + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueueTest.java 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueueTest.java new file mode 100644 index 000000000000..45ebab4bca28 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueueTest.java @@ -0,0 +1,461 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.util; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; +import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; +import org.apache.beam.runners.dataflow.worker.streaming.Work; +import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor.QueuedWork; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class KeyGroupWorkQueueTest { + + private BoundedQueueExecutor executor; + + @Before + public void setUp() { + executor = + new BoundedQueueExecutor( + 2, + 60, + TimeUnit.SECONDS, + 100, + 10000000, + new 
ThreadFactoryBuilder().setNameFormat("Test-%d").setDaemon(true).build(), + false, + /*useKeyGroupWorkQueue=*/ true); + } + + private static final Work.KeyGroup DEFAULT_KEY_GROUP = Work.KeyGroup.create(1, 2); + + private QueuedWork createQueuedWork(String computationId, long workBytes) { + return createQueuedWork(computationId, DEFAULT_KEY_GROUP, workBytes); + } + + private QueuedWork createQueuedWork( + String computationId, Work.KeyGroup keyGroup, long workBytes) { + WorkItem workItem = + WorkItem.newBuilder() + .setKey(ByteString.EMPTY) + .setShardingKey(1) + .setWorkToken(33) + .setCacheToken(1) + .setKeyGroup( + org.apache.beam.runners.dataflow.worker.windmill.Windmill.Uint128Proto.newBuilder() + .setHigh(keyGroup.high()) + .setLow(keyGroup.low()) + .build()) + .build(); + ExecutableWork work = + ExecutableWork.create( + Work.create( + workItem, + workItem.getSerializedSize(), + Watermarks.builder().setInputDataWatermark(Instant.now()).build(), + Work.createProcessingContext( + computationId, + new FakeGetDataClient(), + ignored -> {}, + mock(HeartbeatSender.class)), + false, + Instant::now), + (w, h) -> {}); + return new QueuedWork(work, executor.createBudgetHandle(1, workBytes)); + } + + private static class MockRunnable implements Runnable { + final String id; + + MockRunnable(String id) { + this.id = id; + } + + @Override + public void run() {} + + @Override + public String toString() { + return "MockRunnable(" + id + ")"; + } + } + + @Test + public void testBasicOfferAndPoll() { + KeyGroupWorkQueue queue = new KeyGroupWorkQueue(); + assertTrue(queue.isEmpty()); + assertEquals(0, queue.size()); + + MockRunnable task1 = new MockRunnable("1"); + MockRunnable task2 = new MockRunnable("2"); + + assertTrue(queue.offer(task1)); + assertTrue(queue.offer(task2)); + assertEquals(2, queue.size()); + + assertEquals(task1, queue.poll()); + assertEquals(task2, queue.poll()); + assertNull(queue.poll()); + assertTrue(queue.isEmpty()); + } + + @Test + public void 
testRemove() { + KeyGroupWorkQueue queue = new KeyGroupWorkQueue(); + MockRunnable task1 = new MockRunnable("1"); + MockRunnable task2 = new MockRunnable("2"); + + queue.offer(task1); + queue.offer(task2); + + assertTrue(queue.remove(task1)); + assertEquals(1, queue.size()); + assertEquals(task2, queue.poll()); + assertFalse(queue.remove(task1)); // Already gone + } + + @Test + public void testDrainTo() { + KeyGroupWorkQueue queue = new KeyGroupWorkQueue(); + MockRunnable task1 = new MockRunnable("1"); + MockRunnable task2 = new MockRunnable("2"); + queue.offer(task1); + queue.offer(task2); + + List drained = new ArrayList<>(); + assertEquals(2, queue.drainTo(drained)); + assertEquals(2, drained.size()); + assertEquals(task1, drained.get(0)); + assertEquals(task2, drained.get(1)); + assertTrue(queue.isEmpty()); + } + + @Test + public void testIteratorSafeTraversalAndImmutable() { + KeyGroupWorkQueue queue = new KeyGroupWorkQueue(); + MockRunnable task1 = new MockRunnable("1"); + MockRunnable task2 = new MockRunnable("2"); + queue.offer(task1); + queue.offer(task2); + + Iterator it = queue.iterator(); + assertTrue(it.hasNext()); + assertEquals(task1, it.next()); + assertTrue(it.hasNext()); + assertEquals(task2, it.next()); + assertFalse(it.hasNext()); + + // Assert that mutating the iterator throws UnsupportedOperationException + it = queue.iterator(); + assertTrue(it.hasNext()); + it.next(); + try { + it.remove(); + fail("Iterator must be immutable"); + } catch (UnsupportedOperationException e) { + // Expected + } + } + + @Test + public void testPollWorkTargeted() { + KeyGroupWorkQueue queue = new KeyGroupWorkQueue(); + + QueuedWork workA1 = createQueuedWork("compA", 100); + QueuedWork workB1 = createQueuedWork("compB", 200); + QueuedWork workA2 = createQueuedWork("compA", 150); + + queue.offer(workA1); + queue.offer(workB1); + queue.offer(workA2); + + assertEquals(3, queue.size()); + + // Targeted poll A + QueuedWork polledA1 = queue.pollWork("compA", 
DEFAULT_KEY_GROUP); + assertNotNull(polledA1); + assertEquals("compA", polledA1.getWork().getComputationId()); + assertEquals(100, polledA1.getHandle().bytes()); + + // Verify size decremented + assertEquals(2, queue.size()); + + // Poll next should be B1 (since A1 was stolen, B1 is now first global) + assertEquals(workB1, queue.poll()); + assertEquals(1, queue.size()); + + // Last should be A2 + assertEquals(workA2, queue.poll()); + assertTrue(queue.isEmpty()); + } + + @Test + public void testMemoryPruningLeavesZeroLeaks() { + KeyGroupWorkQueue queue = new KeyGroupWorkQueue(); + QueuedWork workA1 = createQueuedWork("compA", 100); + queue.offer(workA1); + + // Steal A1 + QueuedWork polled = queue.pollWork("compA", DEFAULT_KEY_GROUP); + assertNotNull(polled); + assertTrue(queue.isEmpty()); + + // Offering another work with different computation ID + QueuedWork workB1 = createQueuedWork("compB", 200); + queue.offer(workB1); + assertEquals(1, queue.size()); + + // Steal B1 + QueuedWork polledB = queue.pollWork("compB", DEFAULT_KEY_GROUP); + assertNotNull(polledB); + assertTrue(queue.isEmpty()); + } + + @Test + public void testConcurrentStress() throws InterruptedException, ExecutionException { + final KeyGroupWorkQueue queue = new KeyGroupWorkQueue(); + final int producerThreads = 4; + final int consumerThreads = 4; + final int tasksPerProducer = 1000; + final int totalTasks = producerThreads * tasksPerProducer; + + ExecutorService executorService = + Executors.newFixedThreadPool(producerThreads + consumerThreads); + final CountDownLatch startLatch = new CountDownLatch(1); + final CountDownLatch doneLatch = new CountDownLatch(producerThreads + consumerThreads); + final AtomicInteger consumedCount = new AtomicInteger(0); + List> futures = new ArrayList<>(); + + // Start producers + for (int i = 0; i < producerThreads; i++) { + futures.add( + executorService.submit( + () -> { + try { + startLatch.await(); + for (int j = 0; j < tasksPerProducer; j++) { + String compId = 
"comp-" + (j % 5); + queue.offer(createQueuedWork(compId, 10)); + } + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + doneLatch.countDown(); + } + })); + } + + // Start consumers (mix of poll and pollWork) + for (int i = 0; i < consumerThreads; i++) { + final int consumerId = i; + futures.add( + executorService.submit( + () -> { + try { + startLatch.await(); + while (consumedCount.get() < totalTasks) { + Runnable task; + if (consumerId % 2 == 0) { + // Targeted poll + String compId = "comp-" + (consumedCount.get() % 5); + task = queue.pollWork(compId, DEFAULT_KEY_GROUP); + } else { + // Global poll + task = queue.poll(); + } + if (task != null) { + consumedCount.incrementAndGet(); + } + } + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + doneLatch.countDown(); + } + })); + } + + startLatch.countDown(); + assertTrue(doneLatch.await(10, TimeUnit.SECONDS)); + + // Check for exceptions in threads + for (Future future : futures) { + future.get(); + } + + executorService.shutdown(); + assertTrue(executorService.awaitTermination(5, TimeUnit.SECONDS)); + + assertEquals(0, queue.size()); + assertTrue(queue.isEmpty()); + } + + @Test + public void testTakeBlocksAndWakesUp() throws InterruptedException { + final KeyGroupWorkQueue queue = new KeyGroupWorkQueue(); + final MockRunnable task = new MockRunnable("take-task"); + final AtomicReference result = new AtomicReference<>(); + final CountDownLatch started = new CountDownLatch(1); + final CountDownLatch finished = new CountDownLatch(1); + + Thread t = + new Thread( + () -> { + started.countDown(); + try { + result.set(queue.take()); + } catch (InterruptedException e) { + // Ignore + } finally { + finished.countDown(); + } + }); + t.setDaemon(true); + t.start(); + + assertTrue(started.await(2, TimeUnit.SECONDS)); + // Give thread a moment to enter await() + Thread.sleep(100); + assertEquals(Thread.State.WAITING, t.getState()); + + queue.offer(task); + + 
assertTrue(finished.await(2, TimeUnit.SECONDS)); + assertEquals(task, result.get()); + } + + @Test + public void testPollWithTimeout() throws InterruptedException { + final KeyGroupWorkQueue queue = new KeyGroupWorkQueue(); + final MockRunnable task = new MockRunnable("poll-task"); + final AtomicReference result = new AtomicReference<>(); + final CountDownLatch started = new CountDownLatch(1); + final CountDownLatch finished = new CountDownLatch(1); + + // 1. Verify timeout returns null + Thread t1 = + new Thread( + () -> { + started.countDown(); + try { + result.set(queue.poll(500, TimeUnit.MILLISECONDS)); + } catch (InterruptedException e) { + // Ignore + } finally { + finished.countDown(); + } + }); + t1.setDaemon(true); + t1.start(); + + assertTrue(started.await(2, TimeUnit.SECONDS)); + Thread.sleep(100); + assertEquals(Thread.State.TIMED_WAITING, t1.getState()); + + assertTrue(finished.await(2, TimeUnit.SECONDS)); + assertNull(result.get()); + + // 2. Verify timed poll receives task offered concurrently + final CountDownLatch started2 = new CountDownLatch(1); + final CountDownLatch finished2 = new CountDownLatch(1); + final AtomicReference result2 = new AtomicReference<>(); + + Thread t2 = + new Thread( + () -> { + started2.countDown(); + try { + result2.set(queue.poll(2, TimeUnit.SECONDS)); + } catch (InterruptedException e) { + // Ignore + } finally { + finished2.countDown(); + } + }); + t2.setDaemon(true); + t2.start(); + + assertTrue(started2.await(2, TimeUnit.SECONDS)); + Thread.sleep(100); + assertEquals(Thread.State.TIMED_WAITING, t2.getState()); + + queue.offer(task); + + assertTrue(finished2.await(2, TimeUnit.SECONDS)); + assertEquals(task, result2.get()); + } + + @Test + public void testPollWorkWithKeyGroup() { + KeyGroupWorkQueue queue = new KeyGroupWorkQueue(); + + Work.KeyGroup keyGroup1 = Work.KeyGroup.create(1, 1); + Work.KeyGroup keyGroup2 = Work.KeyGroup.create(1, 2); + + QueuedWork workA1 = createQueuedWork("compA", keyGroup1, 100); + 
QueuedWork workA2 = createQueuedWork("compA", keyGroup2, 150); + + queue.offer(workA1); + queue.offer(workA2); + + assertEquals(2, queue.size()); + + // Poll with keyGroup2 first - should return workA2 + QueuedWork polledA2 = queue.pollWork("compA", keyGroup2); + assertNotNull(polledA2); + assertEquals(workA2, polledA2); + assertEquals(1, queue.size()); + + // Poll with keyGroup2 again - should return null + assertNull(queue.pollWork("compA", keyGroup2)); + + // Poll with keyGroup1 - should return workA1 + QueuedWork polledA1 = queue.pollWork("compA", keyGroup1); + assertNotNull(polledA1); + assertEquals(workA1, polledA1); + assertTrue(queue.isEmpty()); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingCommitFinalizerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingCommitFinalizerTest.java index 07b4b14fd115..ef0d8e434858 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingCommitFinalizerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingCommitFinalizerTest.java @@ -62,7 +62,8 @@ public void setUp() { .setNameFormat("FinalizationCallback-%d") .setDaemon(true) .build(), - /*useFairMonitor=*/ false); + /*useFairMonitor=*/ false, + /*useKeyGroupWorkQueue=*/ false); cleanupExecutor = Executors.newScheduledThreadPool( diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java index 51bd4816b031..0610ed44c27f 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java @@ -64,7 +64,8 @@ private static WorkFailureProcessor createWorkFailureProcessor( .setNameFormat("DataflowWorkUnits-%d") .setDaemon(true) .build(), - /*useFairMonitor=*/ false); + /*useFairMonitor=*/ false, + /*useKeyGroupWorkQueue=*/ false); return WorkFailureProcessor.forTesting(workExecutor, failureTracker, Optional::empty, clock, 0); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java index 88a82c6f76b6..f32282056e4f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java @@ -80,7 +80,8 @@ private static BoundedQueueExecutor workExecutor() { 1, 10000000, new ThreadFactoryBuilder().setNameFormat("DataflowWorkUnits-%d").setDaemon(true).build(), - /*useFairMonitor=*/ false); + /*useFairMonitor=*/ false, + /*useKeyGroupWorkQueue=*/ false); } private static ComputationState createComputationState(int computationIdSuffix) { diff --git a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto index 1da7ef9be8bb..aaa09c105fc3 100644 --- a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto +++ 
b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto @@ -421,6 +421,11 @@ message WatermarkHold { optional string state_family = 4; } +message Uint128Proto { + required fixed64 high = 1; + required fixed64 low = 2; +} + // Proto describing a hot key detected on a given WorkItem. message HotKeyInfo { // The age of the hot key measured from when it was first detected. @@ -448,6 +453,8 @@ message WorkItem { // present, this field includes metadata associated with any hot key. optional HotKeyInfo hot_key_info = 11; + optional Uint128Proto key_group = 18; + reserved 12, 13, 14, 15, 16; }