Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ boolean commitWorkItem(
Windmill.WorkItemCommitRequest request,
Consumer<Windmill.CommitStatus> onDone);

boolean commitMultiKeyWorkItem(
String computation,
Windmill.MultiKeyWorkItemCommitRequest request,
Consumer<Windmill.CommitStatus> onDone);

/** Flushes any pending work items to the wire. */
void flush();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@
package org.apache.beam.runners.dataflow.worker.windmill.client.commits;

import com.google.auto.value.AutoValue;
import java.util.Optional;
import org.apache.beam.runners.dataflow.worker.streaming.ComputationState;
import org.apache.beam.runners.dataflow.worker.streaming.Work;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;

/** Value class for a queued commit. */
@Internal
Expand All @@ -32,20 +35,43 @@ public abstract class Commit {
public static Commit create(
WorkItemCommitRequest request, ComputationState computationState, Work work) {
Preconditions.checkArgument(request.getSerializedSize() > 0);
return new AutoValue_Commit(request, computationState, work);
return new AutoValue_Commit(
Optional.of(request), computationState, Optional.empty(), ImmutableList.of(work));
}

public static Commit createMultiKey(
Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest,
ComputationState computationState,
ImmutableList<Work> workBatch) {
Preconditions.checkArgument(!workBatch.isEmpty());
return new AutoValue_Commit(
Optional.empty(), computationState, Optional.of(multiKeyRequest), workBatch);
}

public final String computationId() {
return computationState().getComputationId();
}

public abstract WorkItemCommitRequest request();
public abstract Optional<WorkItemCommitRequest> singleKeyRequest();

public abstract ComputationState computationState();

public abstract Work work();
public abstract Optional<Windmill.MultiKeyWorkItemCommitRequest> multiKeyRequest();

public abstract ImmutableList<Work> workBatch();

public final boolean isFailed() {
for (Work w : workBatch()) {
if (w.isFailed()) {
return true;
}
}
return false;
}

public final int getSize() {
return request().getSerializedSize();
return multiKeyRequest()
.map(Windmill.MultiKeyWorkItemCommitRequest::getSerializedSize)
.orElseGet(() -> singleKeyRequest().get().getSerializedSize());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,26 +37,11 @@
@AutoValue
public abstract class CompleteCommit {

public static CompleteCommit create(Commit commit, CommitStatus commitStatus) {
return new AutoValue_CompleteCommit(
commit.computationId(),
ShardedKey.create(commit.request().getKey(), commit.request().getShardingKey()),
WorkId.builder()
.setWorkToken(commit.request().getWorkToken())
.setCacheToken(commit.request().getCacheToken())
.build(),
commitStatus);
}

public static CompleteCommit create(
String computationId, ShardedKey shardedKey, WorkId workId, CommitStatus status) {
return new AutoValue_CompleteCommit(computationId, shardedKey, workId, status);
}

public static CompleteCommit forFailedWork(Commit commit) {
return create(commit, CommitStatus.ABORTED);
}

public abstract String computationId();

public abstract ShardedKey shardedKey();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
*/
package org.apache.beam.runners.dataflow.worker.windmill.client.commits;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutorService;
Expand Down Expand Up @@ -112,15 +114,17 @@ private void commitLoop() {
}
while (commit != null) {
ComputationState computationState = commit.computationState();
commit.work().setState(Work.State.COMMITTING);
checkState(commit.workBatch().size() == 1);
commit.workBatch().get(0).setState(Work.State.COMMITTING);
Windmill.ComputationCommitWorkRequest.Builder computationRequestBuilder =
computationRequestMap.get(computationState);
if (computationRequestBuilder == null) {
computationRequestBuilder = commitRequestBuilder.addRequestsBuilder();
computationRequestBuilder.setComputationId(computationState.getComputationId());
computationRequestMap.put(computationState, computationRequestBuilder);
}
computationRequestBuilder.addRequests(commit.request());
checkState(commit.singleKeyRequest().isPresent());
computationRequestBuilder.addRequests(commit.singleKeyRequest().get());
// Send the request if we've exceeded the bytes or there is no more
// pending work. commitBytes is a long, so this cannot overflow.
commitBytes += commit.getSize();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue;
import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore;
import org.apache.beam.runners.dataflow.worker.streaming.Work;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus;
import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
import org.apache.beam.sdk.annotations.Internal;
Expand Down Expand Up @@ -100,7 +101,7 @@ public void start() {

@Override
public void commit(Commit commit) {
if (commit.work().isFailed()) {
if (commit.isFailed()) {
failCommit(commit);
} else {
commitQueue.put(commit);
Expand All @@ -113,8 +114,8 @@ public void commit(Commit commit) {
"Trying to queue commit on shutdown, failing commit=[computationId={}, shardingKey={},"
+ " workId={} ].",
commit.computationId(),
commit.work().getShardedKey(),
commit.work().id());
commit.workBatch().get(0).getShardedKey(),
commit.workBatch().get(0).id());
drainCommitQueue();
}
}
Expand Down Expand Up @@ -147,8 +148,12 @@ private void drainCommitQueue() {
}

private void failCommit(Commit commit) {
commit.work().setFailed();
onCommitComplete.accept(CompleteCommit.forFailedWork(commit));
for (Work w : commit.workBatch()) {
w.setFailed();
onCommitComplete.accept(
CompleteCommit.create(
commit.computationId(), w.getShardedKey(), w.id(), CommitStatus.ABORTED));
}
}

@Override
Expand All @@ -173,8 +178,8 @@ private void streamingCommitLoop() {
// take() blocks until a value is available in the commitQueue.
Preconditions.checkNotNull(initialCommit);

if (initialCommit.work().isFailed()) {
onCommitComplete.accept(CompleteCommit.forFailedWork(initialCommit));
if (initialCommit.isFailed()) {
failCommit(initialCommit);
initialCommit = null;
continue;
}
Expand Down Expand Up @@ -202,20 +207,43 @@ private void streamingCommitLoop() {
/** Adds the commit to the batch if it fits, returning true if it is consumed. */
private boolean tryAddToCommitBatch(Commit commit, CommitWorkStream.RequestBatcher batcher) {
Preconditions.checkNotNull(commit);
commit.work().setState(Work.State.COMMITTING);
for (Work w : commit.workBatch()) {
w.setState(Work.State.COMMITTING);
}
activeCommitBytes.addAndGet(commit.getSize());
boolean isCommitAccepted =
batcher.commitWorkItem(
commit.computationId(),
commit.request(),
commitStatus -> {
onCommitComplete.accept(CompleteCommit.create(commit, commitStatus));
activeCommitBytes.addAndGet(-commit.getSize());
});
boolean isCommitAccepted;
if (commit.multiKeyRequest().isPresent()) {
isCommitAccepted =
batcher.commitMultiKeyWorkItem(
commit.computationId(),
commit.multiKeyRequest().get(),
commitStatus -> {
for (Work w : commit.workBatch()) {
onCommitComplete.accept(
CompleteCommit.create(
commit.computationId(), w.getShardedKey(), w.id(), commitStatus));
}
activeCommitBytes.addAndGet(-commit.getSize());
});
} else {
isCommitAccepted =
batcher.commitWorkItem(
commit.computationId(),
commit.singleKeyRequest().get(),
commitStatus -> {
Work w = commit.workBatch().get(0);
onCommitComplete.accept(
CompleteCommit.create(
commit.computationId(), w.getShardedKey(), w.id(), commitStatus));
activeCommitBytes.addAndGet(-commit.getSize());
});
}

// Since the commit was not accepted, revert the changes made above.
if (!isCommitAccepted) {
commit.work().setState(Work.State.COMMIT_QUEUED);
for (Work w : commit.workBatch()) {
w.setState(Work.State.COMMIT_QUEUED);
}
activeCommitBytes.addAndGet(-commit.getSize());
}

Expand Down Expand Up @@ -246,8 +274,8 @@ private boolean tryAddToCommitBatch(Commit commit, CommitWorkStream.RequestBatch
}

// Drop commits for failed work. Such commits will be dropped by Windmill anyway.
if (commit.work().isFailed()) {
onCommitComplete.accept(CompleteCommit.forFailedWork(commit));
if (commit.isFailed()) {
failCommit(commit);
continue;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import java.util.function.Function;
import javax.annotation.Nullable;
import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitRequestChunk;
Expand Down Expand Up @@ -270,7 +271,7 @@ private void flushInternal(Map<Long, PendingRequest> requests)

if (requests.size() == 1) {
Map.Entry<Long, PendingRequest> elem = requests.entrySet().iterator().next();
if (elem.getValue().request().getSerializedSize()
if (elem.getValue().serializedCommit().size()
> AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) {
issueMultiChunkRequest(elem.getKey(), elem.getValue());
} else {
Expand All @@ -289,6 +290,7 @@ private void issueSingleRequest(long id, PendingRequest pendingRequest)
.setComputationId(pendingRequest.computationId())
.setRequestId(id)
.setShardingKey(pendingRequest.shardingKey())
.setCommitType(pendingRequest.commitType())
.setSerializedWorkItemCommit(pendingRequest.serializedCommit());
StreamingCommitWorkRequest chunk = requestBuilder.build();
synchronized (this) {
Expand Down Expand Up @@ -318,7 +320,8 @@ private void issueBatchedRequest(Map<Long, PendingRequest> requests)
chunkBuilder
.setRequestId(entry.getKey())
.setShardingKey(request.shardingKey())
.setSerializedWorkItemCommit(request.serializedCommit());
.setSerializedWorkItemCommit(request.serializedCommit())
.setCommitType(request.commitType());
}
StreamingCommitWorkRequest request = requestBuilder.build();
synchronized (this) {
Expand Down Expand Up @@ -360,7 +363,8 @@ private void issueMultiChunkRequest(long id, PendingRequest pendingRequest)
.setRequestId(id)
.setSerializedWorkItemCommit(chunk)
.setComputationId(pendingRequest.computationId())
.setShardingKey(pendingRequest.shardingKey());
.setShardingKey(pendingRequest.shardingKey())
.setCommitType(pendingRequest.commitType());
int remaining = serializedCommit.size() - end;
if (remaining > 0) {
chunkBuilder.setRemainingBytesForWorkItem(remaining);
Expand All @@ -378,34 +382,34 @@ private void issueMultiChunkRequest(long id, PendingRequest pendingRequest)

@AutoValue
abstract static class PendingRequest {

private static PendingRequest create(
String computationId, WorkItemCommitRequest request, Consumer<CommitStatus> onDone) {
return new AutoValue_GrpcCommitWorkStream_PendingRequest(computationId, request, onDone);
static PendingRequest create(
String computationId,
long shardingKey,
ByteString serializedCommit,
StreamingCommitRequestChunk.CommitType commitType,
Consumer<CommitStatus> onDone) {
return new AutoValue_GrpcCommitWorkStream_PendingRequest(
computationId, shardingKey, serializedCommit, commitType, onDone);
}

abstract String computationId();

abstract WorkItemCommitRequest request();
abstract long shardingKey();

abstract ByteString serializedCommit();

abstract StreamingCommitRequestChunk.CommitType commitType();

abstract Consumer<CommitStatus> onDone();

private long getBytes() {
return (long) request().getSerializedSize() + computationId().length();
}

private ByteString serializedCommit() {
return request().toByteString();
return (long) serializedCommit().size() + computationId().length();
}
Comment on lines +385 to 407
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Performance Optimization: Defer Serialization to the Stream Writer Thread

Currently, PendingRequest eagerly serializes the commit request to a ByteString on the committer thread (inside commitWorkItem and commitMultiKeyWorkItem). Eager serialization on the committer thread can become a performance bottleneck under high throughput.

By holding the MessageLite (or the vendored equivalent org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.MessageLite) directly in PendingRequest and using @AutoValue.Memoized to lazily serialize it, we can:

  1. Defer the serialization overhead to the stream writer thread (when serializedCommit() is actually called).
  2. Ensure that serialization is only performed once and cached, even if serializedCommit() is accessed multiple times.
    static PendingRequest create(
        String computationId,
        long shardingKey,
        com.google.protobuf.MessageLite request,
        StreamingCommitRequestChunk.CommitType commitType,
        Consumer<CommitStatus> onDone) {
      return new AutoValue_GrpcCommitWorkStream_PendingRequest(
          computationId, shardingKey, request, commitType, onDone);
    }

    abstract String computationId();

    abstract long shardingKey();

    abstract com.google.protobuf.MessageLite request();

    abstract StreamingCommitRequestChunk.CommitType commitType();

    abstract Consumer<CommitStatus> onDone();

    private long getBytes() {
      return (long) serializedCommit().size() + computationId().length();
    }

    @AutoValue.Memoized
    ByteString serializedCommit() {
      return request().toByteString();
    }


private void completeWithStatus(CommitStatus commitStatus) {
onDone().accept(commitStatus);
}

private long shardingKey() {
return request().getShardingKey();
}

private void abort() {
completeWithStatus(CommitStatus.ABORTED);
}
Expand Down Expand Up @@ -462,7 +466,34 @@ public boolean commitWorkItem(
return false;
}

PendingRequest request = PendingRequest.create(computation, commitRequest, onDone);
PendingRequest request =
PendingRequest.create(
computation,
commitRequest.getShardingKey(),
commitRequest.toByteString(),
StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_SINGLE_KEY,
onDone);
Comment on lines +469 to +475
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Pass the commitRequest directly to PendingRequest.create to support lazy serialization.

Suggested change
PendingRequest request =
PendingRequest.create(
computation,
commitRequest.getShardingKey(),
commitRequest.toByteString(),
StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_SINGLE_KEY,
onDone);
PendingRequest request =
PendingRequest.create(
computation,
commitRequest.getShardingKey(),
commitRequest,
StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_SINGLE_KEY,
onDone);

add(idGenerator.incrementAndGet(), request);
return true;
}

@Override
public boolean commitMultiKeyWorkItem(
String computation,
Windmill.MultiKeyWorkItemCommitRequest commitRequest,
Consumer<CommitStatus> onDone) {
if (!canAccept(commitRequest.getSerializedSize() + computation.length())) {
return false;
}
Preconditions.checkArgument(commitRequest.getRequestsCount() > 0);
PendingRequest request =
PendingRequest.create(
computation,
// Any key in the batch for routing
commitRequest.getRequests(0).getShardingKey(),
commitRequest.toByteString(),
StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_MULTI_KEY,
onDone);
Comment on lines +489 to +496
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Pass the commitRequest directly to PendingRequest.create to support lazy serialization.

Suggested change
PendingRequest request =
PendingRequest.create(
computation,
// Any key in the batch for routing
commitRequest.getRequests(0).getShardingKey(),
commitRequest.toByteString(),
StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_MULTI_KEY,
onDone);
PendingRequest request =
PendingRequest.create(
computation,
// Any key in the batch for routing
commitRequest.getRequests(0).getShardingKey(),
commitRequest,
StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_MULTI_KEY,
onDone);

add(idGenerator.incrementAndGet(), request);
return true;
}
Expand Down
Loading
Loading