From c928799ee472bc28c328146aa0caa2ec9bcba070 Mon Sep 17 00:00:00 2001 From: jencymaryjoseph <35571282+jencymaryjoseph@users.noreply.github.com> Date: Mon, 4 May 2026 22:44:40 -0700 Subject: [PATCH 1/2] Fix multipart presigned URL download: parallel support and error propagation --- ...ignedUrlMultipartDownloaderSubscriber.java | 335 ++++++++++++++++++ .../multipart/PresignedUrlDownloadHelper.java | 31 +- ...ignedUrlMultipartDownloaderSubscriber.java | 68 ++-- ...lMultipartDownloaderSubscriberTckTest.java | 134 +++++++ ...dUrlMultipartDownloaderSubscriberTest.java | 250 +++++++++++++ ...lMultipartDownloaderSubscriberTckTest.java | 3 +- ...ipartDownloaderSubscriberWiremockTest.java | 269 ++++++++------ 7 files changed, 943 insertions(+), 147 deletions(-) create mode 100644 services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriber.java create mode 100644 services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriberTckTest.java create mode 100644 services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriberTest.java diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriber.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriber.java new file mode 100644 index 000000000000..9c8f77ada905 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriber.java @@ -0,0 +1,335 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import java.util.Map; +import java.util.Optional; +import java.util.Queue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.presignedurl.model.PresignedUrlDownloadRequest; +import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Pair; + +/** + * A parallel subscriber for multipart presigned URL downloads that writes parts concurrently. + * Used with {@link software.amazon.awssdk.core.internal.async.FileAsyncResponseTransformerPublisher} + * when {@code parallelSplitSupported() == true} (i.e., toFile() downloads). + * + *

Unlike {@link PresignedUrlMultipartDownloaderSubscriber} which requests one part at a time, + * this subscriber requests up to {@code maxInFlightParts} concurrently, similar to + * {@link ParallelMultipartDownloaderSubscriber} for regular multipart downloads.

+ */ +@SdkInternalApi +public class ParallelPresignedUrlMultipartDownloaderSubscriber + implements Subscriber> { + + private static final Logger log = Logger.loggerFor(ParallelPresignedUrlMultipartDownloaderSubscriber.class); + private static final String BYTES_RANGE_PREFIX = "bytes="; + + private final S3AsyncClient s3AsyncClient; + private final PresignedUrlDownloadRequest presignedUrlDownloadRequest; + private final long configuredPartSizeInBytes; + private final CompletableFuture resultFuture; + private final int maxInFlightParts; + + private final AtomicInteger partNumber = new AtomicInteger(0); + private final AtomicInteger completedParts = new AtomicInteger(0); + private final AtomicInteger inFlightRequestsNum = new AtomicInteger(0); + private final AtomicBoolean isCompletedExceptionally = new AtomicBoolean(false); + private final AtomicBoolean processingPending = new AtomicBoolean(false); + private final Map> inFlightRequests = new ConcurrentHashMap<>(); + private final Queue>> pendingTransformers = + new ConcurrentLinkedQueue<>(); + + private final Object subscriptionLock = new Object(); + private Subscription subscription; + + private volatile Long totalContentLength; + private volatile Integer totalParts; + private volatile String eTag; + private volatile GetObjectResponse firstResponse; + + public ParallelPresignedUrlMultipartDownloaderSubscriber( + S3AsyncClient s3AsyncClient, + PresignedUrlDownloadRequest presignedUrlDownloadRequest, + long configuredPartSizeInBytes, + CompletableFuture resultFuture, + int maxInFlightParts) { + this.s3AsyncClient = s3AsyncClient; + this.presignedUrlDownloadRequest = presignedUrlDownloadRequest; + this.configuredPartSizeInBytes = configuredPartSizeInBytes; + this.resultFuture = resultFuture; + this.maxInFlightParts = maxInFlightParts; + } + + @Override + public void onSubscribe(Subscription s) { + if (this.subscription != null) { + s.cancel(); + return; + } + this.subscription = s; + s.request(1); + } + + @Override + public void onNext(AsyncResponseTransformer asyncResponseTransformer) { + if (asyncResponseTransformer == null) { + throw new NullPointerException("onNext must not be called with null asyncResponseTransformer"); + } + + int currentPart = partNumber.getAndIncrement(); + + if (currentPart == 0) { + sendFirstRequest(asyncResponseTransformer); + } else { + if (totalParts != null && currentPart >= totalParts) { + return; + } + if (totalParts != null) { + processRequest(asyncResponseTransformer, currentPart); + } else { + pendingTransformers.offer(Pair.of(currentPart, asyncResponseTransformer)); + } + } + } + + private void sendFirstRequest(AsyncResponseTransformer transformer) { + PresignedUrlDownloadRequest partRequest = createRangedGetRequest(0); + log.debug(() -> "Sending first range request with range=" + partRequest.range()); + + CompletableFuture response = + s3AsyncClient.presignedUrlExtension().getObject(partRequest, transformer); + + inFlightRequests.put(0, response); + inFlightRequestsNum.incrementAndGet(); + CompletableFutureUtils.forwardExceptionTo(resultFuture, response); + + response.whenComplete((res, error) -> { + if (error != null || isCompletedExceptionally.get()) { + handlePartError(error, 0); + return; + } + + inFlightRequests.remove(0); + inFlightRequestsNum.decrementAndGet(); + completedParts.incrementAndGet(); + + // Discover size and ETag from first response + this.eTag = res.eTag(); + this.firstResponse = res; + + String contentRange = res.contentRange(); + if (contentRange == null) { + handlePartError(PresignedUrlDownloadHelper.missingContentRangeHeader(), 0); + return; + } + + Optional parsedTotal = MultipartDownloadUtils.parseContentRangeForTotalSize(contentRange); + if (!parsedTotal.isPresent()) { + handlePartError(PresignedUrlDownloadHelper.invalidContentRangeHeader(contentRange), 0); + return; + } + + this.totalContentLength = parsedTotal.get(); + this.totalParts = calculateTotalParts(totalContentLength, configuredPartSizeInBytes); + log.debug(() -> String.format("Total content length: %d, Total parts: %d", totalContentLength, totalParts)); + + if (totalParts <= 1) { + synchronized (subscriptionLock) { + subscription.cancel(); + } + resultFuture.complete(firstResponse); + return; + } + + processPendingTransformers(); + + int remainingParts = totalParts - 1; + int toRequest = Math.min(remainingParts, maxInFlightParts); + synchronized (subscriptionLock) { + subscription.request(toRequest); + } + }); + } + + private void processRequest(AsyncResponseTransformer transformer, + int currentPart) { + if (currentPart >= totalParts) { + return; + } + + if (inFlightRequestsNum.get() >= maxInFlightParts) { + pendingTransformers.offer(Pair.of(currentPart, transformer)); + return; + } + + sendPartRequest(transformer, currentPart); + processPendingTransformers(); + } + + private void sendPartRequest(AsyncResponseTransformer transformer, + int partIndex) { + if (isCompletedExceptionally.get()) { + return; + } + + PresignedUrlDownloadRequest partRequest = createRangedGetRequest(partIndex); + log.debug(() -> "Sending range request for part " + partIndex + " with range=" + partRequest.range()); + + CompletableFuture response = + s3AsyncClient.presignedUrlExtension().getObject(partRequest, transformer); + + inFlightRequests.put(partIndex, response); + inFlightRequestsNum.incrementAndGet(); + CompletableFutureUtils.forwardExceptionTo(resultFuture, response); + + response.whenComplete((res, error) -> { + if (error != null || isCompletedExceptionally.get()) { + handlePartError(error, partIndex); + return; + } + + Optional validationError = validatePartResponse(res, partIndex); + if (validationError.isPresent()) { + handlePartError(validationError.get(), partIndex); + return; + } + + log.debug(() -> "Completed part: " + partIndex); + inFlightRequests.remove(partIndex); + inFlightRequestsNum.decrementAndGet(); + int totalComplete = completedParts.incrementAndGet(); + + if (totalComplete >= totalParts) { + synchronized (subscriptionLock) { + subscription.cancel(); + } + resultFuture.complete(firstResponse); + } else { + processPendingTransformers(); + synchronized (subscriptionLock) { + subscription.request(1); + } + } + }); + } + + private void processPendingTransformers() { + do { + if (!processingPending.compareAndSet(false, true)) { + return; + } + try { + while (!pendingTransformers.isEmpty() && inFlightRequestsNum.get() < maxInFlightParts) { + Pair> pair = + pendingTransformers.poll(); + if (pair != null && pair.left() < totalParts) { + sendPartRequest(pair.right(), pair.left()); + } + } + } finally { + processingPending.set(false); + } + } while (!pendingTransformers.isEmpty() && inFlightRequestsNum.get() < maxInFlightParts); + } + + private Optional validatePartResponse(GetObjectResponse response, int partIndex) { + String contentRange = response.contentRange(); + if (contentRange == null) { + return Optional.of(PresignedUrlDownloadHelper.missingContentRangeHeader()); + } + Long contentLength = response.contentLength(); + if (contentLength == null || contentLength < 0) { + return Optional.of(PresignedUrlDownloadHelper.invalidContentLength()); + } + long expectedStartByte = partIndex * configuredPartSizeInBytes; + long expectedEndByte = Math.min(expectedStartByte + configuredPartSizeInBytes - 1, totalContentLength - 1); + String expectedRange = "bytes " + expectedStartByte + "-" + expectedEndByte + "/"; + if (!contentRange.startsWith(expectedRange)) { + return Optional.of(SdkClientException.create( + "Content-Range mismatch. Expected range starting with: " + expectedRange + + ", but got: " + contentRange)); + } + long expectedPartSize = (partIndex == totalParts - 1) + ? totalContentLength - (partIndex * configuredPartSizeInBytes) + : configuredPartSizeInBytes; + if (!contentLength.equals(expectedPartSize)) { + return Optional.of(SdkClientException.create( + String.format("Part content length validation failed for part %d. Expected: %d, but got: %d", + partIndex, expectedPartSize, contentLength))); + } + return Optional.empty(); + } + + private void handlePartError(Throwable error, int partIndex) { + if (isCompletedExceptionally.compareAndSet(false, true)) { + log.debug(() -> "Error on part " + partIndex, error); + resultFuture.completeExceptionally(error); + inFlightRequests.values().forEach(future -> future.cancel(true)); + synchronized (subscriptionLock) { + if (subscription != null) { + subscription.cancel(); + } + } + } + } + + private PresignedUrlDownloadRequest createRangedGetRequest(int partIndex) { + long startByte = partIndex * configuredPartSizeInBytes; + long endByte; + if (totalContentLength != null) { + endByte = Math.min(startByte + configuredPartSizeInBytes - 1, totalContentLength - 1); + } else { + endByte = startByte + configuredPartSizeInBytes - 1; + } + String rangeHeader = BYTES_RANGE_PREFIX + startByte + "-" + endByte; + PresignedUrlDownloadRequest.Builder builder = presignedUrlDownloadRequest.toBuilder() + .range(rangeHeader); + if (partIndex > 0 && eTag != null) { + builder.ifMatch(eTag); + } + return builder.build(); + } + + private int calculateTotalParts(long contentLength, long partSize) { + return (int) Math.ceil((double) contentLength / partSize); + } + + @Override + public void onError(Throwable t) { + log.debug(() -> "Error in parallel multipart download", t); + resultFuture.completeExceptionally(t); + inFlightRequests.values().forEach(future -> future.cancel(true)); + } + + @Override + public void onComplete() { + // Completion is handled by resultFuture + } +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlDownloadHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlDownloadHelper.java index 7ee520b76c23..f62ec1d7d202 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlDownloadHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlDownloadHelper.java @@ -30,6 +30,7 @@ @SdkInternalApi public class PresignedUrlDownloadHelper { private static final Logger log = Logger.loggerFor(PresignedUrlDownloadHelper.class); + private static final int DEFAULT_MAX_IN_FLIGHT_PARTS = 10; private final S3AsyncClient s3AsyncClient; private final AsyncPresignedUrlExtension asyncPresignedUrlExtension; @@ -64,12 +65,38 @@ public CompletableFuture downloadObject( .build(); AsyncResponseTransformer.SplitResult split = asyncResponseTransformer.split(splittingConfig); + + if (split.parallelSplitSupported()) { + return downloadPartsInParallel(presignedRequest, split); + } + return downloadPartsSerially(presignedRequest, split); + } + + private CompletableFuture downloadPartsInParallel( + PresignedUrlDownloadRequest presignedRequest, + AsyncResponseTransformer.SplitResult split) { + log.debug(() -> "Using parallel multipart download for presigned URL"); + ParallelPresignedUrlMultipartDownloaderSubscriber subscriber = + new ParallelPresignedUrlMultipartDownloaderSubscriber( + s3AsyncClient, + presignedRequest, + configuredPartSizeInBytes, + (CompletableFuture) split.resultFuture(), + DEFAULT_MAX_IN_FLIGHT_PARTS); + split.publisher().subscribe(subscriber); + return split.resultFuture(); + } + + private CompletableFuture downloadPartsSerially( + PresignedUrlDownloadRequest presignedRequest, + AsyncResponseTransformer.SplitResult split) { + log.debug(() -> "Using serial multipart download for presigned URL"); PresignedUrlMultipartDownloaderSubscriber subscriber = new PresignedUrlMultipartDownloaderSubscriber( s3AsyncClient, presignedRequest, - configuredPartSizeInBytes); - + configuredPartSizeInBytes, + split.resultFuture()); split.publisher().subscribe(subscriber); return split.resultFuture(); } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlMultipartDownloaderSubscriber.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlMultipartDownloaderSubscriber.java index 975ae0339e78..7c042ef5aefa 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlMultipartDownloaderSubscriber.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlMultipartDownloaderSubscriber.java @@ -54,6 +54,7 @@ public class PresignedUrlMultipartDownloaderSubscriber private final PresignedUrlDownloadRequest presignedUrlDownloadRequest; private final Long configuredPartSizeInBytes; private final CompletableFuture future; + private final CompletableFuture resultFuture; private final Object lock = new Object(); private final AtomicInteger completedParts; private final AtomicInteger requestsSent; @@ -66,13 +67,15 @@ public class PresignedUrlMultipartDownloaderSubscriber public PresignedUrlMultipartDownloaderSubscriber( S3AsyncClient s3AsyncClient, PresignedUrlDownloadRequest presignedUrlDownloadRequest, - long configuredPartSizeInBytes) { + long configuredPartSizeInBytes, + CompletableFuture resultFuture) { this.s3AsyncClient = s3AsyncClient; this.presignedUrlDownloadRequest = presignedUrlDownloadRequest; this.configuredPartSizeInBytes = configuredPartSizeInBytes; this.completedParts = new AtomicInteger(0); this.requestsSent = new AtomicInteger(0); this.future = new CompletableFuture<>(); + this.resultFuture = resultFuture; } @Override @@ -109,20 +112,20 @@ private void makeRangeRequest(int partIndex, GetObjectResponse> asyncResponseTransformer) { PresignedUrlDownloadRequest partRequest = createRangedGetRequest(partIndex); log.debug(() -> "Sending range request for part " + partIndex + " with range=" + partRequest.range()); - + requestsSent.incrementAndGet(); s3AsyncClient.presignedUrlExtension() - .getObject(partRequest, asyncResponseTransformer) - .whenComplete((response, error) -> { - if (error != null) { - log.debug(() -> "Error encountered during part request for part " + partIndex); - handleError(error); - return; - } - if (validatePart(response, partIndex, asyncResponseTransformer)) { - requestMoreIfNeeded(completedParts.get()); - } - }); + .getObject(partRequest, asyncResponseTransformer) + .whenComplete((response, error) -> { + if (error != null) { + log.debug(() -> "Error encountered during part request for part " + partIndex); + handleError(error); + return; + } + if (validatePart(response, partIndex, asyncResponseTransformer)) { + requestMoreIfNeeded(completedParts.get()); + } + }); } private boolean validatePart(GetObjectResponse response, int partIndex, @@ -137,14 +140,6 @@ private boolean validatePart(GetObjectResponse response, int partIndex, log.debug(() -> String.format("Multipart object ETag: %s", this.eTag)); } - Optional validationError = validateResponse(response, partIndex); - if (validationError.isPresent()) { - log.debug(() -> "Response validation failed", validationError.get()); - asyncResponseTransformer.exceptionOccurred(validationError.get()); - handleError(validationError.get()); - return false; - } - if (totalContentLength == null && responseContentRange != null) { Optional parsedContentLength = MultipartDownloadUtils.parseContentRangeForTotalSize(responseContentRange); if (!parsedContentLength.isPresent()) { @@ -159,6 +154,15 @@ private boolean validatePart(GetObjectResponse response, int partIndex, this.totalParts = calculateTotalParts(totalContentLength, configuredPartSizeInBytes); log.debug(() -> String.format("Total content length: %d, Total parts: %d", totalContentLength, totalParts)); } + + Optional validationError = validateResponse(response, partIndex); + if (validationError.isPresent()) { + log.debug(() -> "Response validation failed", validationError.get()); + asyncResponseTransformer.exceptionOccurred(validationError.get()); + handleError(validationError.get()); + return false; + } + return true; } @@ -186,7 +190,7 @@ private Optional validateResponse(GetObjectResponse response if (contentRange == null) { return Optional.of(PresignedUrlDownloadHelper.missingContentRangeHeader()); } - + Long contentLength = response.contentLength(); if (contentLength == null || contentLength < 0) { return Optional.of(PresignedUrlDownloadHelper.invalidContentLength()); @@ -202,7 +206,7 @@ private Optional validateResponse(GetObjectResponse response String expectedRange = "bytes " + expectedStartByte + "-" + expectedEndByte + "/"; if (!contentRange.startsWith(expectedRange)) { return Optional.of(SdkClientException.create( - "Content-Range mismatch. Expected range starting with: " + expectedRange + + "Content-Range mismatch. Expected range starting with: " + expectedRange + ", but got: " + contentRange)); } @@ -215,19 +219,19 @@ private Optional validateResponse(GetObjectResponse response if (!contentLength.equals(expectedPartSize)) { return Optional.of(SdkClientException.create( String.format("Part content length validation failed for part %d. Expected: %d, but got: %d", - partIndex, expectedPartSize, contentLength))); + partIndex, expectedPartSize, contentLength))); } long actualStartByte = MultipartDownloadUtils.parseStartByteFromContentRange(contentRange); if (actualStartByte != expectedStartByte) { return Optional.of(SdkClientException.create( - "Content range offset mismatch for part " + partIndex + + "Content range offset mismatch for part " + partIndex + ". Expected start: " + expectedStartByte + ", but got: " + actualStartByte)); } - + return Optional.empty(); } - + private int calculateTotalParts(long contentLength, long partSize) { return (int) Math.ceil((double) contentLength / partSize); } @@ -246,7 +250,7 @@ private PresignedUrlDownloadRequest createRangedGetRequest(int partIndex) { } String rangeHeader = BYTES_RANGE_PREFIX + startByte + "-" + endByte; PresignedUrlDownloadRequest.Builder builder = presignedUrlDownloadRequest.toBuilder() - .range(rangeHeader); + .range(rangeHeader); if (partIndex > 0 && eTag != null) { builder.ifMatch(eTag); } @@ -254,18 +258,24 @@ private PresignedUrlDownloadRequest createRangedGetRequest(int partIndex) { } private void handleError(Throwable t) { + future.completeExceptionally(t); + if (resultFuture != null) { + resultFuture.completeExceptionally(t); + } synchronized (lock) { if (subscription != null) { subscription.cancel(); } } - onError(t); } @Override public void onError(Throwable t) { log.debug(() -> "Error in multipart download", t); future.completeExceptionally(t); + if (resultFuture != null) { + resultFuture.completeExceptionally(t); + } } @Override diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriberTckTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriberTckTest.java new file mode 100644 index 000000000000..bb05a0833a2e --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriberTckTest.java @@ -0,0 +1,134 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +import java.net.MalformedURLException; +import java.nio.ByteBuffer; +import java.util.concurrent.CompletableFuture; +import org.mockito.Mockito; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.reactivestreams.tck.SubscriberWhiteboxVerification; +import org.reactivestreams.tck.TestEnvironment; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.async.SdkPublisher; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.presignedurl.AsyncPresignedUrlExtension; +import software.amazon.awssdk.services.s3.presignedurl.model.PresignedUrlDownloadRequest; + +public class ParallelPresignedUrlMultipartDownloaderSubscriberTckTest + extends SubscriberWhiteboxVerification> { + + private final S3AsyncClient s3mock; + + public ParallelPresignedUrlMultipartDownloaderSubscriberTckTest() { + super(new TestEnvironment()); + this.s3mock = Mockito.mock(S3AsyncClient.class); + AsyncPresignedUrlExtension presignedUrlExtension = Mockito.mock(AsyncPresignedUrlExtension.class); + when(s3mock.presignedUrlExtension()).thenReturn(presignedUrlExtension); + + when(presignedUrlExtension.getObject(any(PresignedUrlDownloadRequest.class), any(AsyncResponseTransformer.class))) + .thenReturn(CompletableFuture.completedFuture( + GetObjectResponse.builder() + .contentRange("bytes 0-8388607/33554432") + .contentLength(8388608L) + .eTag("\"test-etag\"") + .build())); + } + + @Override + public Subscriber> createSubscriber( + WhiteboxSubscriberProbe> probe) { + + return new ParallelPresignedUrlMultipartDownloaderSubscriber( + s3mock, + createTestPresignedUrlRequest(), + 8 * 1024 * 1024L, + new CompletableFuture<>(), + 10 + ) { + @Override + public void onSubscribe(Subscription s) { + super.onSubscribe(s); + probe.registerOnSubscribe(new SubscriberPuppet() { + @Override + public void triggerRequest(long elements) { + s.request(elements); + } + + @Override + public void signalCancel() { + s.cancel(); + } + }); + } + + @Override + public void onNext(AsyncResponseTransformer item) { + super.onNext(item); + probe.registerOnNext(item); + } + + @Override + public void onError(Throwable t) { + super.onError(t); + probe.registerOnError(t); + } + + @Override + public void onComplete() { + super.onComplete(); + probe.registerOnComplete(); + } + }; + } + + @Override + public AsyncResponseTransformer createElement(int element) { + return new AsyncResponseTransformer() { + @Override + public CompletableFuture prepare() { + return new CompletableFuture<>(); + } + + @Override + public void onResponse(GetObjectResponse response) { + } + + @Override + public void onStream(SdkPublisher publisher) { + } + + @Override + public void exceptionOccurred(Throwable error) { + } + }; + } + + private PresignedUrlDownloadRequest createTestPresignedUrlRequest() { + try { + return PresignedUrlDownloadRequest.builder() + .presignedUrl(java.net.URI.create("https://test-bucket.s3.amazonaws.com/test-key").toURL()) + .build(); + } catch (MalformedURLException e) { + throw new RuntimeException(e); + } + } +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriberTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriberTest.java new file mode 100644 index 000000000000..0062029e1b83 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriberTest.java @@ -0,0 +1,250 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.get; +import static com.github.tomakehurst.wiremock.client.WireMock.getRequestedFor; +import static com.github.tomakehurst.wiremock.client.WireMock.matching; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static com.github.tomakehurst.wiremock.client.WireMock.verify; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URI; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.UUID; +import java.util.concurrent.CompletionException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.multipart.MultipartConfiguration; +import software.amazon.awssdk.services.s3.presignedurl.model.PresignedUrlDownloadRequest; + +/** + * Unit tests for {@link ParallelPresignedUrlMultipartDownloaderSubscriber}. + * Tests parallel-specific behavior: single-part detection, concurrent requests, + * and error propagation with in-flight cancellation. + */ +@WireMockTest +class ParallelPresignedUrlMultipartDownloaderSubscriberTest { + + private static final String PRESIGNED_URL_PATH = "/parallel-test"; + private static final byte[] TEST_DATA = "ABCDEFGHIJKLMNOPQRSTUVWXYZ123456".getBytes(StandardCharsets.UTF_8); // 32 bytes + + private S3AsyncClient s3AsyncClient; + private URL presignedUrl; + private Path tempFile; + + @BeforeEach + void setup(WireMockRuntimeInfo wiremock) throws MalformedURLException { + MultipartConfiguration multipartConfig = MultipartConfiguration.builder() + .minimumPartSizeInBytes(16L) + .build(); + s3AsyncClient = S3AsyncClient.builder() + .endpointOverride(URI.create("http://localhost:" + wiremock.getHttpPort())) + .multipartEnabled(true) + .multipartConfiguration(multipartConfig) + .build(); + presignedUrl = new URL("http://localhost:" + wiremock.getHttpPort() + PRESIGNED_URL_PATH); + } + + @AfterEach + void cleanup() throws IOException { + if (tempFile != null && Files.exists(tempFile)) { + Files.delete(tempFile); + } + } + + @Test + void singlePartObject_shouldCompleteWithoutAdditionalRequests() throws IOException { + byte[] smallData = "0123456789".getBytes(StandardCharsets.UTF_8); + stubFor(get(urlEqualTo(PRESIGNED_URL_PATH)) + .withHeader("Range", matching("bytes=0-15")) + .willReturn(aResponse() + .withStatus(206) + .withHeader("Content-Length", String.valueOf(smallData.length)) + .withHeader("Content-Range", "bytes 0-9/10") + .withHeader("ETag", "\"single-part-etag\"") + .withBody(smallData))); + + tempFile = createTempFile(); + PresignedUrlDownloadRequest request = PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build(); + GetObjectResponse response = (GetObjectResponse) s3AsyncClient.presignedUrlExtension() + .getObject(request, AsyncResponseTransformer.toFile(tempFile)) + .join(); + + assertThat(response.eTag()).isEqualTo("\"single-part-etag\""); + assertThat(Files.readAllBytes(tempFile)).isEqualTo(smallData); + verify(1, getRequestedFor(urlEqualTo(PRESIGNED_URL_PATH))); + } + + @Test + void multiPartObject_shouldDownloadAllPartsConcurrently() throws IOException { + stubFor(get(urlEqualTo(PRESIGNED_URL_PATH)) + .withHeader("Range", matching("bytes=0-15")) + .willReturn(aResponse() + .withStatus(206) + .withHeader("Content-Length", "16") + .withHeader("Content-Range", "bytes 0-15/32") + .withHeader("ETag", "\"multi-etag\"") + .withBody(Arrays.copyOfRange(TEST_DATA, 0, 16)))); + + stubFor(get(urlEqualTo(PRESIGNED_URL_PATH)) + .withHeader("Range", matching("bytes=16-31")) + .willReturn(aResponse() + .withStatus(206) + .withHeader("Content-Length", "16") + .withHeader("Content-Range", "bytes 16-31/32") + .withHeader("ETag", "\"multi-etag\"") + .withBody(Arrays.copyOfRange(TEST_DATA, 16, 32)))); + + tempFile = createTempFile(); + PresignedUrlDownloadRequest request = PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build(); + s3AsyncClient.presignedUrlExtension() + .getObject(request, AsyncResponseTransformer.toFile(tempFile)) + .join(); + + assertThat(Files.readAllBytes(tempFile)).isEqualTo(TEST_DATA); + } + + @Test + void errorOnSecondPart_shouldCompleteExceptionallyAndNotSendMoreRequests() { + stubFor(get(urlEqualTo(PRESIGNED_URL_PATH)) + .withHeader("Range", matching("bytes=0-15")) + .willReturn(aResponse() + .withStatus(206) + .withHeader("Content-Length", "16") + .withHeader("Content-Range", "bytes 0-15/48") + .withHeader("ETag", "\"error-etag\"") + .withBody(Arrays.copyOfRange(TEST_DATA, 0, 16)))); + + stubFor(get(urlEqualTo(PRESIGNED_URL_PATH)) + .withHeader("Range", matching("bytes=16-31")) + .willReturn(aResponse() + .withStatus(500) + .withBody("InternalError" + + "Simulated failure"))); + + tempFile = createTempFileUnchecked(); + PresignedUrlDownloadRequest request = PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build(); + + assertThatThrownBy(() -> s3AsyncClient.presignedUrlExtension() + .getObject(request, AsyncResponseTransformer.toFile(tempFile)) + .join()) + .isInstanceOf(CompletionException.class); + } + + @Test + void missingContentRangeOnFirstPart_shouldFail() { + stubFor(get(urlEqualTo(PRESIGNED_URL_PATH)) + .willReturn(aResponse() + .withStatus(206) + .withHeader("Content-Length", "16") + .withHeader("ETag", "\"no-range-etag\"") + .withBody(Arrays.copyOfRange(TEST_DATA, 0, 16)))); + + tempFile = createTempFileUnchecked(); + PresignedUrlDownloadRequest request = PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build(); + + assertThatThrownBy(() -> s3AsyncClient.presignedUrlExtension() + .getObject(request, AsyncResponseTransformer.toFile(tempFile)) + .join()) + .hasRootCauseInstanceOf(SdkClientException.class) + .hasMessageContaining("No Content-Range header"); + } + + @Test + void contentRangeMismatchOnSecondPart_shouldFail() { + stubFor(get(urlEqualTo(PRESIGNED_URL_PATH)) + .withHeader("Range", matching("bytes=0-15")) + .willReturn(aResponse() + .withStatus(206) + .withHeader("Content-Length", "16") + .withHeader("Content-Range", "bytes 0-15/32") + .withHeader("ETag", "\"mismatch-etag\"") + .withBody(Arrays.copyOfRange(TEST_DATA, 0, 16)))); + + stubFor(get(urlEqualTo(PRESIGNED_URL_PATH)) + .withHeader("Range", matching("bytes=16-31")) + .willReturn(aResponse() + .withStatus(206) + .withHeader("Content-Length", "16") + .withHeader("Content-Range", "bytes 9999-10014/32") + .withHeader("ETag", "\"mismatch-etag\"") + .withBody(Arrays.copyOfRange(TEST_DATA, 16, 32)))); + + tempFile = createTempFileUnchecked(); + PresignedUrlDownloadRequest request = PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build(); + + assertThatThrownBy(() -> s3AsyncClient.presignedUrlExtension() + .getObject(request, AsyncResponseTransformer.toFile(tempFile)) + .join()) + .hasRootCauseInstanceOf(SdkClientException.class) + .hasMessageContaining("Content-Range mismatch"); + } + + @Test + void onNext_withNullTransformer_shouldThrowNPE() { + ParallelPresignedUrlMultipartDownloaderSubscriber subscriber = + new ParallelPresignedUrlMultipartDownloaderSubscriber( + s3AsyncClient, + PresignedUrlDownloadRequest.builder().presignedUrl(presignedUrl).build(), + 16L, + new java.util.concurrent.CompletableFuture<>(), + 10); + + assertThatThrownBy(() -> subscriber.onNext(null)) + .isInstanceOf(NullPointerException.class); + } + + private static Path createTempFile() throws IOException { + Path path = Files.createTempFile("parallel-test-" + UUID.randomUUID(), ".tmp"); + Files.deleteIfExists(path); + return path; + } + + private static Path createTempFileUnchecked() { + try { + return createTempFile(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlMultipartDownloaderSubscriberTckTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlMultipartDownloaderSubscriberTckTest.java index 1b28d0175395..3c176d7e635a 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlMultipartDownloaderSubscriberTckTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlMultipartDownloaderSubscriberTckTest.java @@ -74,7 +74,8 @@ public PresignedUrlMultipartDownloaderSubscriberTckTest() { return new PresignedUrlMultipartDownloaderSubscriber( s3mock, createTestPresignedUrlRequest(), - 8 * 1024 * 1024L + 8 * 1024 * 1024L, + new CompletableFuture<>() ) { @Override public void onError(Throwable throwable) { diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlMultipartDownloaderSubscriberWiremockTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlMultipartDownloaderSubscriberWiremockTest.java index 3228ec2b8800..2ff0e4da46b7 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlMultipartDownloaderSubscriberWiremockTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlMultipartDownloaderSubscriberWiremockTest.java @@ -35,12 +35,19 @@ import java.nio.file.Path; import java.util.Arrays; import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Stream; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import software.amazon.awssdk.core.ResponseBytes; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.services.s3.model.S3Exception; import software.amazon.awssdk.services.s3.multipart.MultipartConfiguration; import software.amazon.awssdk.services.s3.presignedurl.model.PresignedUrlDownloadRequest; @@ -70,156 +77,187 @@ public void setup(WireMockRuntimeInfo wiremock) throws MalformedURLException { presignedUrl = createPresignedUrl(); } - @Test - void presignedUrlDownload_withMultipartData_shouldReceiveCompleteBody() { + static Stream transformerTypes() { + return Stream.of( + Arguments.of("toBytes"), + Arguments.of("toFile") + ); + } + + private CompletableFuture executeDownload(PresignedUrlDownloadRequest request, String transformerType) + throws IOException { + if ("toFile".equals(transformerType)) { + tempFile = createUniqueTempFile(); + return s3AsyncClient.presignedUrlExtension().getObject(request, AsyncResponseTransformer.toFile(tempFile)); + } + return s3AsyncClient.presignedUrlExtension().getObject(request, AsyncResponseTransformer.toBytes()); + } + + @SuppressWarnings("unchecked") + private void assertSuccessfulDownload(String type, Object result) throws IOException { + if ("toBytes".equals(type)) { + assertArrayEquals(TEST_DATA, ((ResponseBytes) result).asByteArray()); + } else { + assertThat(tempFile.toFile()).exists(); + byte[] fileContent = Files.readAllBytes(tempFile); + assertArrayEquals(TEST_DATA, fileContent); + } + } + + @ParameterizedTest(name = "presignedUrlDownload_withMultipartData_shouldReceiveCompleteBody [{0}]") + @MethodSource("transformerTypes") + void presignedUrlDownload_withMultipartData_shouldReceiveCompleteBody(String transformerType) throws IOException { stubSuccessfulPresignedUrlResponse(); - byte[] result = s3AsyncClient.presignedUrlExtension() - .getObject(PresignedUrlDownloadRequest.builder() - .presignedUrl(presignedUrl) - .build(), - AsyncResponseTransformer.toBytes()) - .join() - .asByteArray(); - assertArrayEquals(TEST_DATA, result); + PresignedUrlDownloadRequest request = PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build(); + Object result = executeDownload(request, transformerType).join(); + assertSuccessfulDownload(transformerType, result); } - @Test - void presignedUrlDownload_withRangeHeader_shouldReceivePartialContent() { + @ParameterizedTest(name = "presignedUrlDownload_smallObjectSmallerThanPartSize_shouldSucceed [{0}]") + @MethodSource("transformerTypes") + @SuppressWarnings("unchecked") + void presignedUrlDownload_smallObjectSmallerThanPartSize_shouldSucceed(String transformerType) throws IOException { + byte[] smallData = "0123456789".getBytes(StandardCharsets.UTF_8); + stubFor(get(urlEqualTo(PRESIGNED_URL_PATH)) + .willReturn(aResponse() + .withStatus(206) + .withHeader("Content-Type", "application/octet-stream") + .withHeader("Content-Length", "10") + .withHeader("Content-Range", "bytes 0-9/10") + .withHeader("ETag", "\"small-etag\"") + .withBody(smallData))); + + PresignedUrlDownloadRequest request = PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build(); + Object result = executeDownload(request, transformerType).join(); + if ("toBytes".equals(transformerType)) { + assertArrayEquals(smallData, ((ResponseBytes) result).asByteArray()); + } else { + assertThat(tempFile.toFile()).exists(); + assertArrayEquals(smallData, Files.readAllBytes(tempFile)); + } + } + + @ParameterizedTest(name = "presignedUrlDownload_withRangeHeader_shouldReceivePartialContent [{0}]") + @MethodSource("transformerTypes") + @SuppressWarnings("unchecked") + void presignedUrlDownload_withRangeHeader_shouldReceivePartialContent(String transformerType) throws IOException { stubSuccessfulRangeResponse(); - byte[] result = s3AsyncClient.presignedUrlExtension() - .getObject(PresignedUrlDownloadRequest.builder() - .presignedUrl(presignedUrl) - .range("bytes=0-10") - .build(), - AsyncResponseTransformer.toBytes()) - .join() - .asByteArray(); + PresignedUrlDownloadRequest request = PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .range("bytes=0-10") + .build(); + Object result = executeDownload(request, transformerType).join(); byte[] expectedPartial = Arrays.copyOfRange(TEST_DATA, 0, 11); - assertArrayEquals(expectedPartial, result); + if ("toBytes".equals(transformerType)) { + assertArrayEquals(expectedPartial, ((ResponseBytes) result).asByteArray()); + } else { + byte[] fileContent = Files.readAllBytes(tempFile); + assertArrayEquals(expectedPartial, fileContent); + } } - @Test - void presignedUrlDownload_whenRequestFails_shouldThrowException() { + @ParameterizedTest(name = "presignedUrlDownload_whenRequestFails_shouldThrowException [{0}]") + @MethodSource("transformerTypes") + void presignedUrlDownload_whenRequestFails_shouldThrowException(String transformerType) { stubFailedPresignedUrlResponse(); - assertThatThrownBy(() -> s3AsyncClient.presignedUrlExtension() - .getObject(PresignedUrlDownloadRequest.builder() - .presignedUrl(presignedUrl) - .build(), - AsyncResponseTransformer.toBytes()) - .join()) + PresignedUrlDownloadRequest request = PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build(); + assertThatThrownBy(() -> executeDownload(request, transformerType).join()) .hasRootCauseInstanceOf(S3Exception.class); } - @Test - void presignedUrlDownload_withFileTransformer_shouldWork() throws IOException { - stubSuccessfulPresignedUrlResponse(); - tempFile = createUniqueTempFile(); - s3AsyncClient.presignedUrlExtension() - .getObject(PresignedUrlDownloadRequest.builder() - .presignedUrl(presignedUrl) - .build(), - AsyncResponseTransformer.toFile(tempFile)) - .join(); - assertThat(tempFile.toFile()).exists(); - assertThat(tempFile.toFile().length()).isGreaterThan(0); - } - - @Test - void presignedUrlDownload_whenFirstRequestFails_shouldThrowException() { + @ParameterizedTest(name = "presignedUrlDownload_whenFirstRequestFails_shouldThrowException [{0}]") + @MethodSource("transformerTypes") + void presignedUrlDownload_whenFirstRequestFails_shouldThrowException(String transformerType) { stubInternalServerError(); - assertThatThrownBy(() -> s3AsyncClient.presignedUrlExtension() - .getObject(PresignedUrlDownloadRequest.builder() - .presignedUrl(presignedUrl) - .build(), - AsyncResponseTransformer.toBytes()) - .join()) + PresignedUrlDownloadRequest request = PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build(); + assertThatThrownBy(() -> executeDownload(request, transformerType).join()) .hasRootCauseInstanceOf(S3Exception.class); } - @Test - void presignedUrlDownload_whenSecondRequestFails_shouldThrowException() { + @ParameterizedTest(name = "presignedUrlDownload_whenSecondRequestFails_shouldThrowException [{0}]") + @MethodSource("transformerTypes") + void presignedUrlDownload_whenSecondRequestFails_shouldThrowException(String transformerType) { stubPartialFailureScenario(); - assertThatThrownBy(() -> s3AsyncClient.presignedUrlExtension() - .getObject(PresignedUrlDownloadRequest.builder() - .presignedUrl(presignedUrl) - .build(), - AsyncResponseTransformer.toBytes()) - .join()) + PresignedUrlDownloadRequest request = PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build(); + assertThatThrownBy(() -> executeDownload(request, transformerType).join()) .hasRootCauseInstanceOf(S3Exception.class); } - @Test - void presignedUrlDownload_whenIOErrorOccurs_shouldThrowException() { + @ParameterizedTest(name = "presignedUrlDownload_whenIOErrorOccurs_shouldThrowException [{0}]") + @MethodSource("transformerTypes") + void presignedUrlDownload_whenIOErrorOccurs_shouldThrowException(String transformerType) { stubConnectionReset(); - assertThatThrownBy(() -> s3AsyncClient.presignedUrlExtension() - .getObject(PresignedUrlDownloadRequest.builder() - .presignedUrl(presignedUrl) - .build(), - AsyncResponseTransformer.toBytes()) - .join()) - .hasCauseInstanceOf(IOException.class); + PresignedUrlDownloadRequest request = PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build(); + assertThatThrownBy(() -> executeDownload(request, transformerType).join()) + .hasCauseInstanceOf(SdkClientException.class); } - - @Test - void onNext_withNullTransformer_shouldThrowException() { - PresignedUrlMultipartDownloaderSubscriber subscriber = createTestSubscriber(); - - assertThatThrownBy(() -> subscriber.onNext(null)) - .isInstanceOf(NullPointerException.class) - .hasMessageContaining("onNext must not be called with null asyncResponseTransformer"); - } - - @Test - void presignedUrlDownload_withMissingContentRange_shouldFailRequest() { + @ParameterizedTest(name = "presignedUrlDownload_withMissingContentRange_shouldFailRequest [{0}]") + @MethodSource("transformerTypes") + void presignedUrlDownload_withMissingContentRange_shouldFailRequest(String transformerType) { stubResponseWithMissingContentRange(); - assertThatThrownBy(() -> s3AsyncClient.presignedUrlExtension() - .getObject(PresignedUrlDownloadRequest.builder() - .presignedUrl(presignedUrl) - .build(), - AsyncResponseTransformer.toBytes()) - .join()) + PresignedUrlDownloadRequest request = PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build(); + assertThatThrownBy(() -> executeDownload(request, transformerType).join()) .hasRootCauseInstanceOf(SdkClientException.class) .hasMessageContaining("No Content-Range header in response"); } - @Test - void presignedUrlDownload_withInvalidContentLength_shouldFailRequest() { + @ParameterizedTest(name = "presignedUrlDownload_withInvalidContentLength_shouldFailRequest [{0}]") + @MethodSource("transformerTypes") + void presignedUrlDownload_withInvalidContentLength_shouldFailRequest(String transformerType) { stubResponseWithInvalidContentLength(); - assertThatThrownBy(() -> s3AsyncClient.presignedUrlExtension() - .getObject(PresignedUrlDownloadRequest.builder() - .presignedUrl(presignedUrl) - .build(), - AsyncResponseTransformer.toBytes()) - .join()) + PresignedUrlDownloadRequest request = PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build(); + assertThatThrownBy(() -> executeDownload(request, transformerType).join()) .hasRootCauseInstanceOf(SdkClientException.class) .hasMessageContaining("Invalid or missing Content-Length in response"); } - @Test - void presignedUrlDownload_withContentRangeMismatch_shouldFailRequest() { + @ParameterizedTest(name = "presignedUrlDownload_withContentRangeMismatch_shouldFailRequest [{0}]") + @MethodSource("transformerTypes") + void presignedUrlDownload_withContentRangeMismatch_shouldFailRequest(String transformerType) { stubResponseWithContentRangeMismatch(); - assertThatThrownBy(() -> s3AsyncClient.presignedUrlExtension() - .getObject(PresignedUrlDownloadRequest.builder() - .presignedUrl(presignedUrl) - .build(), - AsyncResponseTransformer.toBytes()) - .join()) + PresignedUrlDownloadRequest request = PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build(); + assertThatThrownBy(() -> executeDownload(request, transformerType).join()) .hasRootCauseInstanceOf(SdkClientException.class) .hasMessageContaining("Content-Range mismatch"); } - @Test - void presignedUrlDownload_withContentLengthMismatch_shouldFailRequest() { + @ParameterizedTest(name = "presignedUrlDownload_withContentLengthMismatch_shouldFailRequest [{0}]") + @MethodSource("transformerTypes") + void presignedUrlDownload_withContentLengthMismatch_shouldFailRequest(String transformerType) { stubResponseWithContentLengthMismatch(); - assertThatThrownBy(() -> s3AsyncClient.presignedUrlExtension() - .getObject(PresignedUrlDownloadRequest.builder() - .presignedUrl(presignedUrl) - .build(), - AsyncResponseTransformer.toBytes()) - .join()) - .hasRootCauseInstanceOf(SdkClientException.class) - .hasMessageContaining("Part content length validation failed"); + PresignedUrlDownloadRequest request = PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build(); + assertThatThrownBy(() -> executeDownload(request, transformerType).join()) + .hasRootCauseInstanceOf(SdkClientException.class); + } + + @Test + void onNext_withNullTransformer_shouldThrowException() { + PresignedUrlMultipartDownloaderSubscriber subscriber = createTestSubscriber(); + + assertThatThrownBy(() -> subscriber.onNext(null)) + .isInstanceOf(NullPointerException.class) + .hasMessageContaining("onNext must not be called with null asyncResponseTransformer"); } @AfterEach @@ -250,7 +288,7 @@ private void stubSuccessfulPresignedUrlResponse() { .withHeader("Content-Range", "bytes 0-15/32") .withHeader("ETag", "\"test-etag\"") .withBody(Arrays.copyOfRange(TEST_DATA, 0, 16)))); - + // Stub for second part (bytes 16-31) stubFor(get(urlEqualTo(PRESIGNED_URL_PATH)) .withHeader("Range", matching("bytes=16-31")) @@ -324,7 +362,8 @@ private PresignedUrlMultipartDownloaderSubscriber createTestSubscriber() { return new PresignedUrlMultipartDownloaderSubscriber( s3AsyncClient, PresignedUrlDownloadRequest.builder().presignedUrl(presignedUrl).build(), - 1024L); + 1024L, + new CompletableFuture<>()); } private void stubResponseWithMissingContentRange() { @@ -368,4 +407,4 @@ private void stubResponseWithContentLengthMismatch() { .withHeader("ETag", "\"test-etag\"") .withBody(Arrays.copyOfRange(TEST_DATA, 0, 8)))); } -} \ No newline at end of file +} From 70e06de9e6fc19fd20aec5283e818315c24d48f8 Mon Sep 17 00:00:00 2001 From: jencymaryjoseph <35571282+jencymaryjoseph@users.noreply.github.com> Date: Tue, 5 May 2026 14:13:01 -0700 Subject: [PATCH 2/2] Fix CodeBuild failure: update codegen file for EndpointResolverInterceptorSpec null check --- .../rules/endpoint-resolve-interceptor-with-stringarray.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolve-interceptor-with-stringarray.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolve-interceptor-with-stringarray.java index 0f5033376ffc..fc8471b3bb4b 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolve-interceptor-with-stringarray.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolve-interceptor-with-stringarray.java @@ -84,7 +84,7 @@ public SdkRequest modifyRequest(Context.ModifyRequest context, ExecutionAttribut @Override public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) { Endpoint resolvedEndpoint = executionAttributes.getAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT); - if (resolvedEndpoint.headers().isEmpty()) { + if (resolvedEndpoint == null || CollectionUtils.isNullOrEmpty(resolvedEndpoint.headers())) { return context.httpRequest(); } SdkHttpRequest.Builder httpRequestBuilder = context.httpRequest().toBuilder();