From 51546796aa39c55732e0f5de41d76f3e15273f9f Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Tue, 2 Jun 2026 05:35:05 +0000 Subject: [PATCH] [Dataflow Streaming][Multikey] Support MultiKey commits in windmill clients - Add MultiKeyWorkItemCommitRequest to windmill.proto. - Support MultiKey commits in Commit model and StreamingEngineWorkCommitter. - Update GrpcCommitWorkStream to batch and stream MultiKey commit requests. --- .../windmill/client/WindmillStream.java | 5 + .../windmill/client/commits/Commit.java | 34 ++- .../client/commits/CompleteCommit.java | 15 - .../StreamingApplianceWorkCommitter.java | 8 +- .../commits/StreamingEngineWorkCommitter.java | 66 +++-- .../client/grpc/GrpcCommitWorkStream.java | 67 +++-- .../dataflow/worker/FakeWindmillServer.java | 71 ++++- .../StreamingApplianceWorkCommitterTest.java | 9 +- .../StreamingEngineWorkCommitterTest.java | 270 ++++++++++++++++-- .../client/grpc/GrpcCommitWorkStreamTest.java | 70 +++++ .../windmill/src/main/proto/windmill.proto | 28 ++ 11 files changed, 558 insertions(+), 85 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java index 526b67890783..36001c151508 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java @@ -108,6 +108,11 @@ boolean commitWorkItem( Windmill.WorkItemCommitRequest request, Consumer onDone); + boolean commitMultiKeyWorkItem( + String computation, + Windmill.MultiKeyWorkItemCommitRequest request, + Consumer onDone); + /** Flushes any pending work items to the wire. 
*/ void flush(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java index b840d22a3434..e52a9846645f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java @@ -18,11 +18,14 @@ package org.apache.beam.runners.dataflow.worker.windmill.client.commits; import com.google.auto.value.AutoValue; +import java.util.Optional; import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; import org.apache.beam.runners.dataflow.worker.streaming.Work; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; /** Value class for a queued commit. 
*/ @Internal @@ -32,20 +35,43 @@ public abstract class Commit { public static Commit create( WorkItemCommitRequest request, ComputationState computationState, Work work) { Preconditions.checkArgument(request.getSerializedSize() > 0); - return new AutoValue_Commit(request, computationState, work); + return new AutoValue_Commit( + Optional.of(request), computationState, Optional.empty(), ImmutableList.of(work)); + } + + public static Commit createMultiKey( + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest, + ComputationState computationState, + ImmutableList workBatch) { + Preconditions.checkArgument(!workBatch.isEmpty()); + return new AutoValue_Commit( + Optional.empty(), computationState, Optional.of(multiKeyRequest), workBatch); } public final String computationId() { return computationState().getComputationId(); } - public abstract WorkItemCommitRequest request(); + public abstract Optional singleKeyRequest(); public abstract ComputationState computationState(); - public abstract Work work(); + public abstract Optional multiKeyRequest(); + + public abstract ImmutableList workBatch(); + + public final boolean isFailed() { + for (Work w : workBatch()) { + if (w.isFailed()) { + return true; + } + } + return false; + } public final int getSize() { - return request().getSerializedSize(); + return multiKeyRequest() + .map(Windmill.MultiKeyWorkItemCommitRequest::getSerializedSize) + .orElseGet(() -> singleKeyRequest().get().getSerializedSize()); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java index e33e853d3d76..6c0a5a98e2ab 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java +++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java @@ -37,26 +37,11 @@ @AutoValue public abstract class CompleteCommit { - public static CompleteCommit create(Commit commit, CommitStatus commitStatus) { - return new AutoValue_CompleteCommit( - commit.computationId(), - ShardedKey.create(commit.request().getKey(), commit.request().getShardingKey()), - WorkId.builder() - .setWorkToken(commit.request().getWorkToken()) - .setCacheToken(commit.request().getCacheToken()) - .build(), - commitStatus); - } - public static CompleteCommit create( String computationId, ShardedKey shardedKey, WorkId workId, CommitStatus status) { return new AutoValue_CompleteCommit(computationId, shardedKey, workId, status); } - public static CompleteCommit forFailedWork(Commit commit) { - return create(commit, CommitStatus.ABORTED); - } - public abstract String computationId(); public abstract ShardedKey shardedKey(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java index 20b95b0661d0..40e82c4ca368 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java @@ -17,6 +17,8 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client.commits; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; + import java.util.HashMap; import java.util.Map; import java.util.concurrent.ExecutorService; @@ -112,7 +114,8 
@@ private void commitLoop() { } while (commit != null) { ComputationState computationState = commit.computationState(); - commit.work().setState(Work.State.COMMITTING); + checkState(commit.workBatch().size() == 1); + commit.workBatch().get(0).setState(Work.State.COMMITTING); Windmill.ComputationCommitWorkRequest.Builder computationRequestBuilder = computationRequestMap.get(computationState); if (computationRequestBuilder == null) { @@ -120,7 +123,8 @@ private void commitLoop() { computationRequestBuilder.setComputationId(computationState.getComputationId()); computationRequestMap.put(computationState, computationRequestBuilder); } - computationRequestBuilder.addRequests(commit.request()); + checkState(commit.singleKeyRequest().isPresent()); + computationRequestBuilder.addRequests(commit.singleKeyRequest().get()); // Send the request if we've exceeded the bytes or there is no more // pending work. commitBytes is a long, so this cannot overflow. commitBytes += commit.getSize(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java index b68f53121b86..cb8e6d26d089 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java @@ -30,6 +30,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue; import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore; import org.apache.beam.runners.dataflow.worker.streaming.Work; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus; import 
org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; import org.apache.beam.sdk.annotations.Internal; @@ -100,7 +101,7 @@ public void start() { @Override public void commit(Commit commit) { - if (commit.work().isFailed()) { + if (commit.isFailed()) { failCommit(commit); } else { commitQueue.put(commit); @@ -113,8 +114,8 @@ public void commit(Commit commit) { "Trying to queue commit on shutdown, failing commit=[computationId={}, shardingKey={}," + " workId={} ].", commit.computationId(), - commit.work().getShardedKey(), - commit.work().id()); + commit.workBatch().get(0).getShardedKey(), + commit.workBatch().get(0).id()); drainCommitQueue(); } } @@ -147,8 +148,12 @@ private void drainCommitQueue() { } private void failCommit(Commit commit) { - commit.work().setFailed(); - onCommitComplete.accept(CompleteCommit.forFailedWork(commit)); + for (Work w : commit.workBatch()) { + w.setFailed(); + onCommitComplete.accept( + CompleteCommit.create( + commit.computationId(), w.getShardedKey(), w.id(), CommitStatus.ABORTED)); + } } @Override @@ -173,8 +178,8 @@ private void streamingCommitLoop() { // take() blocks until a value is available in the commitQueue. Preconditions.checkNotNull(initialCommit); - if (initialCommit.work().isFailed()) { - onCommitComplete.accept(CompleteCommit.forFailedWork(initialCommit)); + if (initialCommit.isFailed()) { + failCommit(initialCommit); initialCommit = null; continue; } @@ -202,20 +207,43 @@ private void streamingCommitLoop() { /** Adds the commit to the batch if it fits, returning true if it is consumed. 
*/ private boolean tryAddToCommitBatch(Commit commit, CommitWorkStream.RequestBatcher batcher) { Preconditions.checkNotNull(commit); - commit.work().setState(Work.State.COMMITTING); + for (Work w : commit.workBatch()) { + w.setState(Work.State.COMMITTING); + } activeCommitBytes.addAndGet(commit.getSize()); - boolean isCommitAccepted = - batcher.commitWorkItem( - commit.computationId(), - commit.request(), - commitStatus -> { - onCommitComplete.accept(CompleteCommit.create(commit, commitStatus)); - activeCommitBytes.addAndGet(-commit.getSize()); - }); + boolean isCommitAccepted; + if (commit.multiKeyRequest().isPresent()) { + isCommitAccepted = + batcher.commitMultiKeyWorkItem( + commit.computationId(), + commit.multiKeyRequest().get(), + commitStatus -> { + for (Work w : commit.workBatch()) { + onCommitComplete.accept( + CompleteCommit.create( + commit.computationId(), w.getShardedKey(), w.id(), commitStatus)); + } + activeCommitBytes.addAndGet(-commit.getSize()); + }); + } else { + isCommitAccepted = + batcher.commitWorkItem( + commit.computationId(), + commit.singleKeyRequest().get(), + commitStatus -> { + Work w = commit.workBatch().get(0); + onCommitComplete.accept( + CompleteCommit.create( + commit.computationId(), w.getShardedKey(), w.id(), commitStatus)); + activeCommitBytes.addAndGet(-commit.getSize()); + }); + } // Since the commit was not accepted, revert the changes made above. if (!isCommitAccepted) { - commit.work().setState(Work.State.COMMIT_QUEUED); + for (Work w : commit.workBatch()) { + w.setState(Work.State.COMMIT_QUEUED); + } activeCommitBytes.addAndGet(-commit.getSize()); } @@ -246,8 +274,8 @@ private boolean tryAddToCommitBatch(Commit commit, CommitWorkStream.RequestBatch } // Drop commits for failed work. Such commits will be dropped by Windmill anyway. 
- if (commit.work().isFailed()) { - onCommitComplete.accept(CompleteCommit.forFailedWork(commit)); + if (commit.isFailed()) { + failCommit(commit); continue; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java index d24676652fd8..afa736d7c3ad 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java @@ -35,6 +35,7 @@ import java.util.function.Function; import javax.annotation.Nullable; import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitRequestChunk; @@ -270,7 +271,7 @@ private void flushInternal(Map requests) if (requests.size() == 1) { Map.Entry elem = requests.entrySet().iterator().next(); - if (elem.getValue().request().getSerializedSize() + if (elem.getValue().serializedCommit().size() > AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { issueMultiChunkRequest(elem.getKey(), elem.getValue()); } else { @@ -289,6 +290,7 @@ private void issueSingleRequest(long id, PendingRequest pendingRequest) .setComputationId(pendingRequest.computationId()) .setRequestId(id) .setShardingKey(pendingRequest.shardingKey()) + .setCommitType(pendingRequest.commitType()) .setSerializedWorkItemCommit(pendingRequest.serializedCommit()); StreamingCommitWorkRequest chunk = 
requestBuilder.build(); synchronized (this) { @@ -318,7 +320,8 @@ private void issueBatchedRequest(Map requests) chunkBuilder .setRequestId(entry.getKey()) .setShardingKey(request.shardingKey()) - .setSerializedWorkItemCommit(request.serializedCommit()); + .setSerializedWorkItemCommit(request.serializedCommit()) + .setCommitType(request.commitType()); } StreamingCommitWorkRequest request = requestBuilder.build(); synchronized (this) { @@ -360,7 +363,8 @@ private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) .setRequestId(id) .setSerializedWorkItemCommit(chunk) .setComputationId(pendingRequest.computationId()) - .setShardingKey(pendingRequest.shardingKey()); + .setShardingKey(pendingRequest.shardingKey()) + .setCommitType(pendingRequest.commitType()); int remaining = serializedCommit.size() - end; if (remaining > 0) { chunkBuilder.setRemainingBytesForWorkItem(remaining); @@ -378,34 +382,34 @@ private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) @AutoValue abstract static class PendingRequest { - - private static PendingRequest create( - String computationId, WorkItemCommitRequest request, Consumer onDone) { - return new AutoValue_GrpcCommitWorkStream_PendingRequest(computationId, request, onDone); + static PendingRequest create( + String computationId, + long shardingKey, + ByteString serializedCommit, + StreamingCommitRequestChunk.CommitType commitType, + Consumer onDone) { + return new AutoValue_GrpcCommitWorkStream_PendingRequest( + computationId, shardingKey, serializedCommit, commitType, onDone); } abstract String computationId(); - abstract WorkItemCommitRequest request(); + abstract long shardingKey(); + + abstract ByteString serializedCommit(); + + abstract StreamingCommitRequestChunk.CommitType commitType(); abstract Consumer onDone(); private long getBytes() { - return (long) request().getSerializedSize() + computationId().length(); - } - - private ByteString serializedCommit() { - return 
request().toByteString(); + return (long) serializedCommit().size() + computationId().length(); } private void completeWithStatus(CommitStatus commitStatus) { onDone().accept(commitStatus); } - private long shardingKey() { - return request().getShardingKey(); - } - private void abort() { completeWithStatus(CommitStatus.ABORTED); } @@ -462,7 +466,34 @@ public boolean commitWorkItem( return false; } - PendingRequest request = PendingRequest.create(computation, commitRequest, onDone); + PendingRequest request = + PendingRequest.create( + computation, + commitRequest.getShardingKey(), + commitRequest.toByteString(), + StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_SINGLE_KEY, + onDone); + add(idGenerator.incrementAndGet(), request); + return true; + } + + @Override + public boolean commitMultiKeyWorkItem( + String computation, + Windmill.MultiKeyWorkItemCommitRequest commitRequest, + Consumer onDone) { + if (!canAccept(commitRequest.getSerializedSize() + computation.length())) { + return false; + } + Preconditions.checkArgument(commitRequest.getRequestsCount() > 0); + PendingRequest request = + PendingRequest.create( + computation, + // Any key in the batch for routing + commitRequest.getRequests(0).getShardingKey(), + commitRequest.toByteString(), + StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_MULTI_KEY, + onDone); add(idGenerator.incrementAndGet(), request); return true; } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java index 5be8ec0a6c72..eec77ccf435b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java @@ -29,7 +29,6 @@ import java.util.ArrayList; 
import java.util.Collection; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -37,6 +36,7 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; @@ -89,6 +89,8 @@ public final class FakeWindmillServer extends WindmillServerStub { private final Map streamingCommitsToOffer; // Keys are work tokens. private final Map commitsReceived; + private final List multiKeyCommitsReceived = + new CopyOnWriteArrayList<>(); private final ArrayList statsReceived; private final LinkedBlockingQueue exceptions; private final AtomicInteger expectedExceptionCount; @@ -118,7 +120,7 @@ public FakeWindmillServer( commitsToOffer = new ResponseQueue() .returnByDefault(CommitWorkResponse.getDefaultInstance()); - streamingCommitsToOffer = new HashMap<>(); + streamingCommitsToOffer = new ConcurrentHashMap<>(); commitsReceived = new ConcurrentHashMap<>(); exceptions = new LinkedBlockingQueue<>(); expectedExceptionCount = new AtomicInteger(); @@ -400,6 +402,7 @@ public void shutdown() {} public RequestBatcher batcher() { return new RequestBatcher() { final List requests = new ArrayList<>(); + final List multiKeyRequests = new ArrayList<>(); @Override public boolean commitWorkItem( @@ -423,6 +426,18 @@ public boolean commitWorkItem( return true; } + @Override + public boolean commitMultiKeyWorkItem( + String computation, + Windmill.MultiKeyWorkItemCommitRequest request, + Consumer onDone) { + LOG.debug("commitWorkStream::commitMultiKeyWorkItem: {}", request); + if (multiKeyRequests.size() > 5) return false; + multiKeyRequests.add(new MultiKeyRequestAndDone(request, onDone)); + flush(); + return true; + } + @Override public void flush() { for (RequestAndDone elem : requests) { @@ -445,6 +460,37 @@ 
public void flush() { .orElse(Windmill.CommitStatus.OK)); } requests.clear(); + + for (MultiKeyRequestAndDone elem : multiKeyRequests) { + if (dropStreamingCommits) { + for (WorkItemCommitRequest workRequest : elem.request.getRequestsList()) { + droppedStreamingCommits.put(workRequest.getWorkToken(), elem.onDone); + } + continue; + } + + multiKeyCommitsReceived.add(elem.request); + for (WorkItemCommitRequest workRequest : elem.request.getRequestsList()) { + commitsReceived.put(workRequest.getWorkToken(), workRequest); + } + + // Determine status for the batch. + // Default to OK, but if any of the works in the batch has an offered status, use it. + Windmill.CommitStatus status = Windmill.CommitStatus.OK; + for (WorkItemCommitRequest workRequest : elem.request.getRequestsList()) { + Windmill.CommitStatus offeredStatus = + streamingCommitsToOffer.remove( + WorkId.builder() + .setWorkToken(workRequest.getWorkToken()) + .setCacheToken(workRequest.getCacheToken()) + .build()); + if (offeredStatus != null) { + status = offeredStatus; + } + } + elem.onDone.accept(status); + } + multiKeyRequests.clear(); } class RequestAndDone { @@ -456,6 +502,18 @@ class RequestAndDone { this.onDone = onDone; } } + + class MultiKeyRequestAndDone { + final Consumer onDone; + final Windmill.MultiKeyWorkItemCommitRequest request; + + MultiKeyRequestAndDone( + Windmill.MultiKeyWorkItemCommitRequest request, + Consumer onDone) { + this.request = request; + this.onDone = onDone; + } + } }; } @@ -518,6 +576,15 @@ public Map waitForAndGetCommits(int numCommits) { public void clearCommitsReceived() { commitsRequested = 0; commitsReceived.clear(); + multiKeyCommitsReceived.clear(); + } + + public List getMultiKeyCommitsReceived() { + return multiKeyCommitsReceived; + } + + public void clearMultiKeyCommitsReceived() { + multiKeyCommitsReceived.clear(); } public ConcurrentHashMap> waitForDroppedCommits( diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java index 5c3132ae471d..3da740d53361 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java @@ -129,9 +129,9 @@ public void testCommit() { for (Commit commit : commits) { Windmill.WorkItemCommitRequest request = - committed.get(commit.work().getWorkItem().getWorkToken()); + committed.get(commit.workBatch().get(0).getWorkItem().getWorkToken()); assertNotNull(request); - assertThat(request).isEqualTo(commit.request()); + assertThat(request).isEqualTo(commit.singleKeyRequest().get()); } assertThat(completeCommits).hasSize(commits.size()); @@ -141,12 +141,13 @@ public void testCommit() { (CompleteCommit completeCommit, Commit commit) -> completeCommit.computationId().equals(commit.computationId()) && completeCommit.status() == Windmill.CommitStatus.OK - && completeCommit.workId().equals(commit.work().id()) + && completeCommit.workId().equals(commit.workBatch().get(0).id()) && completeCommit .shardedKey() .equals( ShardedKey.create( - commit.request().getKey(), commit.request().getShardingKey())), + commit.singleKeyRequest().get().getKey(), + commit.singleKeyRequest().get().getShardingKey())), "expected to equal")) .containsExactlyElementsIn(commits); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java index 01197622c24d..5e5fd9ce6420 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java @@ -53,6 +53,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.WorkId; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream; @@ -62,6 +63,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.joda.time.Duration; import org.joda.time.Instant; @@ -134,12 +136,10 @@ private static ComputationState createComputationState(String computationId) { null); } - private static CompleteCommit asCompleteCommit(Commit commit, Windmill.CommitStatus status) { - if (commit.work().isFailed()) { - return CompleteCommit.forFailedWork(commit); - } - - return CompleteCommit.create(commit, status); + private static CompleteCommit asCompleteCommit( + String computationId, Work work, 
Windmill.CommitStatus status) { + Windmill.CommitStatus finalStatus = work.isFailed() ? Windmill.CommitStatus.ABORTED : status; + return CompleteCommit.create(computationId, work.getShardedKey(), work.id(), finalStatus); } @Before @@ -186,10 +186,14 @@ public void testCommit_sendsCommitsToStreamingEngine() { waitForExpectedSetSize(completeCommits, 5); for (Commit commit : commits) { - WorkItemCommitRequest request = committed.get(commit.work().getWorkItem().getWorkToken()); + WorkItemCommitRequest request = + committed.get(commit.workBatch().get(0).getWorkItem().getWorkToken()); assertNotNull(request); - assertThat(request).isEqualTo(commit.request()); - assertThat(completeCommits).contains(asCompleteCommit(commit, Windmill.CommitStatus.OK)); + assertThat(request).isEqualTo(commit.singleKeyRequest().get()); + assertThat(completeCommits) + .contains( + asCompleteCommit( + commit.computationId(), commit.workBatch().get(0), Windmill.CommitStatus.OK)); } workCommitter.stop(); @@ -224,14 +228,24 @@ public void testCommit_handlesFailedCommits() { waitForExpectedSetSize(completeCommits, 10); for (Commit commit : commits) { - if (commit.work().isFailed()) { + if (commit.isFailed()) { assertThat(completeCommits) - .contains(asCompleteCommit(commit, Windmill.CommitStatus.ABORTED)); - assertThat(committed).doesNotContainKey(commit.work().getWorkItem().getWorkToken()); + .contains( + asCompleteCommit( + commit.computationId(), + commit.workBatch().get(0), + Windmill.CommitStatus.ABORTED)); + assertThat(committed) + .doesNotContainKey(commit.workBatch().get(0).getWorkItem().getWorkToken()); } else { - assertThat(completeCommits).contains(asCompleteCommit(commit, Windmill.CommitStatus.OK)); + assertThat(completeCommits) + .contains( + asCompleteCommit( + commit.computationId(), commit.workBatch().get(0), Windmill.CommitStatus.OK)); assertThat(committed) - .containsEntry(commit.work().getWorkItem().getWorkToken(), commit.request()); + .containsEntry( + 
commit.workBatch().get(0).getWorkItem().getWorkToken(), + commit.singleKeyRequest().get()); } } @@ -282,11 +296,16 @@ public void testCommit_handlesCompleteCommits_commitStatusNotOK() { waitForExpectedSetSize(completeCommits, commits.size()); for (Commit commit : commits) { - WorkItemCommitRequest request = committed.get(commit.work().getWorkItem().getWorkToken()); + WorkItemCommitRequest request = + committed.get(commit.workBatch().get(0).getWorkItem().getWorkToken()); assertNotNull(request); - assertThat(request).isEqualTo(commit.request()); + assertThat(request).isEqualTo(commit.singleKeyRequest().get()); assertThat(completeCommits) - .contains(asCompleteCommit(commit, expectedCommitStatus.get(commit.work().id()))); + .contains( + asCompleteCommit( + commit.computationId(), + commit.workBatch().get(0), + expectedCommitStatus.get(commit.workBatch().get(0).id()))); } workCommitter.stop(); @@ -313,6 +332,14 @@ public boolean commitWorkItem( return false; } + @Override + public boolean commitMultiKeyWorkItem( + String computation, + Windmill.MultiKeyWorkItemCommitRequest request, + Consumer onDone) { + return false; + } + @Override public void flush() {} }; @@ -370,7 +397,7 @@ public void shutdown() {} } for (Commit commit : commits) { - assertTrue(commit.work().isFailed()); + assertTrue(commit.isFailed()); } } @@ -409,10 +436,14 @@ public void testMultipleCommitSendersSingleStream() { waitForExpectedSetSize(completeCommits, commits.size()); for (Commit commit : commits) { - WorkItemCommitRequest request = committed.get(commit.work().getWorkItem().getWorkToken()); + WorkItemCommitRequest request = + committed.get(commit.workBatch().get(0).getWorkItem().getWorkToken()); assertNotNull(request); - assertThat(request).isEqualTo(commit.request()); - assertThat(completeCommits).contains(asCompleteCommit(commit, Windmill.CommitStatus.OK)); + assertThat(request).isEqualTo(commit.singleKeyRequest().get()); + assertThat(completeCommits) + .contains( + asCompleteCommit( + 
commit.computationId(), commit.workBatch().get(0), Windmill.CommitStatus.OK)); } workCommitter.stop(); @@ -474,4 +505,201 @@ public void testStop_drainsCommitQueue_concurrentCommit() waitForExpectedSetSize(completeCommits, sentCommits.intValue()); } + + @Test + public void testCommit_multiKeyCommitFailedWork() { + Set completeCommits = Collections.newSetFromMap(new ConcurrentHashMap<>()); + workCommitter = createWorkCommitter(completeCommits::add); + + Work workA = createMockWork(101L); + Work workB = createMockWork(102L); + Work workC = createMockWork(103L); + + // Mark non-primary key B as failed + workB.setFailed(); + + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder() + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workA.getWorkItem().getKey()) + .setShardingKey(workA.getWorkItem().getShardingKey()) + .setWorkToken(workA.getWorkItem().getWorkToken()) + .setCacheToken(workA.getWorkItem().getCacheToken()) + .build()) + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workB.getWorkItem().getKey()) + .setShardingKey(workB.getWorkItem().getShardingKey()) + .setWorkToken(workB.getWorkItem().getWorkToken()) + .setCacheToken(workB.getWorkItem().getCacheToken()) + .build()) + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workC.getWorkItem().getKey()) + .setShardingKey(workC.getWorkItem().getShardingKey()) + .setWorkToken(workC.getWorkItem().getWorkToken()) + .setCacheToken(workC.getWorkItem().getCacheToken()) + .build()) + .build(); + + Commit commit = + Commit.createMultiKey( + multiKeyRequest, + createComputationState("computationId"), + ImmutableList.of(workA, workB, workC)); + + workCommitter.start(); + workCommitter.commit(commit); + + // The entire batch must be aborted immediately without making network calls + waitForExpectedSetSize(completeCommits, 3); + + // Verify all three works are aborted individually + 
assertThat(completeCommits) + .containsExactly( + CompleteCommit.create( + "computationId", workA.getShardedKey(), workA.id(), CommitStatus.ABORTED), + CompleteCommit.create( + "computationId", workB.getShardedKey(), workB.id(), CommitStatus.ABORTED), + CompleteCommit.create( + "computationId", workC.getShardedKey(), workC.id(), CommitStatus.ABORTED)); + + workCommitter.stop(); + } + + @Test + public void testCommit_multiKeyCommitSuccess() { + Set<CompleteCommit> completeCommits = Collections.newSetFromMap(new ConcurrentHashMap<>()); + workCommitter = createWorkCommitter(completeCommits::add); + + Work workA = createMockWork(101L); + Work workB = createMockWork(102L); + Work workC = createMockWork(103L); + + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder() + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workA.getWorkItem().getKey()) + .setShardingKey(workA.getWorkItem().getShardingKey()) + .setWorkToken(workA.getWorkItem().getWorkToken()) + .setCacheToken(workA.getWorkItem().getCacheToken()) + .build()) + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workB.getWorkItem().getKey()) + .setShardingKey(workB.getWorkItem().getShardingKey()) + .setWorkToken(workB.getWorkItem().getWorkToken()) + .setCacheToken(workB.getWorkItem().getCacheToken()) + .build()) + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workC.getWorkItem().getKey()) + .setShardingKey(workC.getWorkItem().getShardingKey()) + .setWorkToken(workC.getWorkItem().getWorkToken()) + .setCacheToken(workC.getWorkItem().getCacheToken()) + .build()) + .build(); + + Commit commit = + Commit.createMultiKey( + multiKeyRequest, + createComputationState("computationId"), + ImmutableList.of(workA, workB, workC)); + + workCommitter.start(); + workCommitter.commit(commit); + + // Wait for the server to receive and process the commits + fakeWindmillServer.waitForAndGetCommits(3); +
waitForExpectedSetSize(completeCommits, 3); + + // Verify that FakeWindmillServer received all 3 work requests in multiKeyCommitsReceived + List<Windmill.MultiKeyWorkItemCommitRequest> multiKeyCommits = + fakeWindmillServer.getMultiKeyCommitsReceived(); + assertThat(multiKeyCommits).hasSize(1); + assertThat(multiKeyCommits.get(0)).isEqualTo(multiKeyRequest); + + // Verify all three works are completed successfully + assertThat(completeCommits) + .containsExactly( + CompleteCommit.create( + "computationId", workA.getShardedKey(), workA.id(), CommitStatus.OK), + CompleteCommit.create( + "computationId", workB.getShardedKey(), workB.id(), CommitStatus.OK), + CompleteCommit.create( + "computationId", workC.getShardedKey(), workC.id(), CommitStatus.OK)); + + workCommitter.stop(); + } + + @Test + public void testCommit_multiKeyCommitStatusNotOK() { + Set<CompleteCommit> completeCommits = Collections.newSetFromMap(new ConcurrentHashMap<>()); + workCommitter = createWorkCommitter(completeCommits::add); + + Work workA = createMockWork(101L); + Work workB = createMockWork(102L); + Work workC = createMockWork(103L); + + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder() + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workA.getWorkItem().getKey()) + .setShardingKey(workA.getWorkItem().getShardingKey()) + .setWorkToken(workA.getWorkItem().getWorkToken()) + .setCacheToken(workA.getWorkItem().getCacheToken()) + .build()) + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workB.getWorkItem().getKey()) + .setShardingKey(workB.getWorkItem().getShardingKey()) + .setWorkToken(workB.getWorkItem().getWorkToken()) + .setCacheToken(workB.getWorkItem().getCacheToken()) + .build()) + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workC.getWorkItem().getKey()) + .setShardingKey(workC.getWorkItem().getShardingKey()) + .setWorkToken(workC.getWorkItem().getWorkToken()) + .setCacheToken(workC.getWorkItem().getCacheToken()) +
.build()) + .build(); + + Commit commit = + Commit.createMultiKey( + multiKeyRequest, + createComputationState("computationId"), + ImmutableList.of(workA, workB, workC)); + + // Offer NOT_FOUND status for one of the works. + fakeWindmillServer.whenCommitWorkStreamCalled().put(workB.id(), CommitStatus.NOT_FOUND); + + workCommitter.start(); + workCommitter.commit(commit); + + // Wait for the server to receive and process the commits + fakeWindmillServer.waitForAndGetCommits(3); + waitForExpectedSetSize(completeCommits, 3); + + // Verify that FakeWindmillServer received the multi-key commit + List<Windmill.MultiKeyWorkItemCommitRequest> multiKeyCommits = + fakeWindmillServer.getMultiKeyCommitsReceived(); + assertThat(multiKeyCommits).hasSize(1); + assertThat(multiKeyCommits.get(0)).isEqualTo(multiKeyRequest); + + // Verify all three works in the multi-key commit are completed with NOT_FOUND status + assertThat(completeCommits) + .containsExactly( + CompleteCommit.create( + "computationId", workA.getShardedKey(), workA.id(), CommitStatus.NOT_FOUND), + CompleteCommit.create( + "computationId", workB.getShardedKey(), workB.id(), CommitStatus.NOT_FOUND), + CompleteCommit.create( + "computationId", workC.getShardedKey(), workC.id(), CommitStatus.NOT_FOUND)); + + workCommitter.stop(); + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java index e9fd55fa5668..b83890c1dbdd 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java @@ -1133,6 +1133,76 @@ public void
testCommitWorkItem_multiplePhysicalStreams_multipleHandovers_halfClo assertTrue(commitWorkStream.awaitTermination(10, TimeUnit.SECONDS)); } + @Test + public void testCommit_multiKeyCommit() throws Exception { + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + CompletableFuture<Windmill.CommitStatus> commitStatusFuture = new CompletableFuture<>(); + + // 1. Construct two individual WorkItemCommitRequests + long shardingKey1 = 101L; + long workToken1 = 201L; + long cacheToken1 = 301L; + long shardingKey2 = 102L; + long workToken2 = 202L; + long cacheToken2 = 302L; + Windmill.WorkItemCommitRequest request1 = + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setShardingKey(shardingKey1) + .setWorkToken(workToken1) + .setCacheToken(cacheToken1) + .build(); + Windmill.WorkItemCommitRequest request2 = + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(ByteString.copyFromUtf8("key2")) + .setShardingKey(shardingKey2) + .setWorkToken(workToken2) + .setCacheToken(cacheToken2) + .build(); + + // 2. Wrap them into a MultiKeyWorkItemCommitRequest + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder() + .addRequests(request1) + .addRequests(request2) + .build(); + + // 3. Commit the multi-key work item using the request batcher + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitMultiKeyWorkItem( + COMPUTATION_ID, multiKeyRequest, commitStatusFuture::complete)); + } + + // 4.
Receive and assert request properties on FakeWindmillGrpcService + Windmill.StreamingCommitWorkRequest request = streamInfo.requests.take(); + assertThat(request.getCommitChunkCount()).isEqualTo(1); + + Windmill.StreamingCommitRequestChunk chunk = request.getCommitChunk(0); + + // Assert that the commit type is correctly identified as COMMIT_TYPE_MULTI_KEY + assertThat(chunk.getCommitType()) + .isEqualTo(Windmill.StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_MULTI_KEY); + + // Assert that the routing sharding key is mapped to the first request's sharding key + assertThat(chunk.getShardingKey()).isEqualTo(request1.getShardingKey()); + + // Assert that the serialized payload matches the input multiKeyRequest + Windmill.MultiKeyWorkItemCommitRequest parsedRequest = + Windmill.MultiKeyWorkItemCommitRequest.parseFrom(chunk.getSerializedWorkItemCommit()); + assertThat(parsedRequest).isEqualTo(multiKeyRequest); + + // 5. Respond with the generated requestId to complete the commit + long requestId = chunk.getRequestId(); + streamInfo.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(requestId).build()); + + // 6. 
Verify callback completed successfully with CommitStatus.OK + assertThat(commitStatusFuture.get(10, TimeUnit.SECONDS)).isEqualTo(Windmill.CommitStatus.OK); + } + private FakeWindmillGrpcService.CommitStreamInfo waitForConnectionAndConsumeHeader() { try { FakeWindmillGrpcService.CommitStreamInfo info = fakeService.waitForConnectedCommitStream(); diff --git a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto index 1da7ef9be8bb..9abe23f58c89 100644 --- a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto +++ b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto @@ -421,6 +421,11 @@ message WatermarkHold { optional string state_family = 4; } +message Uint128Proto { + required fixed64 high = 1; + required fixed64 low = 2; +} + // Proto describing a hot key detected on a given WorkItem. message HotKeyInfo { // The age of the hot key measured from when it was first detected. @@ -671,9 +676,24 @@ message WorkItemCommitRequest { reserved 6, 23; } +message MultiKeyWorkItemCommitRequest { + optional Uint128Proto key_group = 7; + + repeated WorkItemCommitRequest requests = 1; + + repeated OutputMessageBundle output_messages = 2; + + repeated PubSubMessageBundle pubsub_messages = 3; + + repeated int64 finalize_ids = 4 [packed = true]; + + reserved 6; +} + message ComputationCommitWorkRequest { required string computation_id = 1; repeated WorkItemCommitRequest requests = 2; + repeated MultiKeyWorkItemCommitRequest multi_key_requests = 3; } message CommitWorkRequest { @@ -899,6 +919,14 @@ message StreamingCommitRequestChunk { // before handing off to the WindmillHost for processing.
optional int64 remaining_bytes_for_work_item = 4; optional bytes serialized_work_item_commit = 5; + + enum CommitType { + COMMIT_TYPE_UNSPECIFIED = 0; + COMMIT_TYPE_SINGLE_KEY = 1; + COMMIT_TYPE_MULTI_KEY = 2; + } + + optional CommitType commit_type = 7; } message StreamingCommitResponse {