From 9b661d8be84d037c19aea8a6bae37c123c22d85e Mon Sep 17 00:00:00 2001 From: Zoe Wang <33073555+zoewangg@users.noreply.github.com> Date: Fri, 21 Jul 2023 09:45:20 -0700 Subject: [PATCH 1/2] Support uploading with unknown content length --- .../internal/async/SplittingPublisher.java | 39 ++- ...ltipartClientPutObjectIntegrationTest.java | 36 ++- .../s3/src/it/resources/log4j2.properties | 8 +- .../multipart/BaseMultipartUploadHelper.java | 113 ++++++++ .../multipart/GenericMultipartHelper.java | 17 +- .../multipart/MultipartS3AsyncClient.java | 4 +- .../multipart/UploadObjectHelper.java | 76 ++++++ ...> UploadWithKnownContentLengthHelper.java} | 69 ++--- .../UploadWithUnknownContentLengthHelper.java | 245 ++++++++++++++++++ .../multipart/MultipartUploadHelperTest.java | 6 +- 10 files changed, 535 insertions(+), 78 deletions(-) create mode 100644 services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/BaseMultipartUploadHelper.java create mode 100644 services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelper.java rename services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/{MultipartUploadHelper.java => UploadWithKnownContentLengthHelper.java} (78%) create mode 100644 services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java index 99cf1e7c3381..27aaadc5f7a1 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java @@ -120,23 +120,28 @@ public void onNext(ByteBuffer byteBuffer) { int amountRemainingInChunk = amountRemainingInChunk(); // If we have fulfilled this chunk, - // we should create a new DownstreamBody if needed + // complete the current body if (amountRemainingInChunk == 0) { - completeCurrentBody(); - - if (shouldCreateNewDownstreamRequestBody(byteBuffer)) { - int currentChunk = chunkNumber.incrementAndGet(); - long chunkSize = calculateChunkSize(totalDataRemaining()); - currentBody = initializeNextDownstreamBody(upstreamSize != null, chunkSize, currentChunk); - } + completeCurrentBodyAndCreateNewIfNeeded(byteBuffer); } amountRemainingInChunk = amountRemainingInChunk(); - if (amountRemainingInChunk >= byteBuffer.remaining()) { + + // If the current ByteBuffer < this chunk, send it as-is + if (amountRemainingInChunk > byteBuffer.remaining()) { + currentBody.send(byteBuffer.duplicate()); + break; + } + + // If the current ByteBuffer == this chunk, send it as-is and + // complete the current body + if (amountRemainingInChunk == byteBuffer.remaining()) { currentBody.send(byteBuffer.duplicate()); + completeCurrentBodyAndCreateNewIfNeeded(byteBuffer); break; } + // If the current ByteBuffer > this chunk, split this ByteBuffer ByteBuffer firstHalf = byteBuffer.duplicate(); int newLimit = firstHalf.position() + amountRemainingInChunk; firstHalf.limit(newLimit); @@ -147,6 +152,16 @@ public void onNext(ByteBuffer byteBuffer) { maybeRequestMoreUpstreamData(); } + private void completeCurrentBodyAndCreateNewIfNeeded(ByteBuffer byteBuffer) { + completeCurrentBody(); + + if (shouldCreateNewDownstreamRequestBody(byteBuffer)) { + int currentChunk = chunkNumber.incrementAndGet(); + long chunkSize = calculateChunkSize(totalDataRemaining()); + currentBody = initializeNextDownstreamBody(upstreamSize != null, chunkSize, currentChunk); + } + } + /** * If content length is known, we should create new DownstreamRequestBody if there's remaining data. @@ -161,6 +176,7 @@ private int amountRemainingInChunk() { } private void completeCurrentBody() { + log.debug(() -> "completeCurrentBody"); currentBody.complete(); if (upstreamSize == null) { sendCurrentBody(currentBody); @@ -181,6 +197,7 @@ public void onError(Throwable t) { } private void sendCurrentBody(AsyncRequestBody body) { + log.debug(() -> "sendCurrentBody"); downstreamPublisher.send(body).exceptionally(t -> { downstreamPublisher.error(t); return null; @@ -206,7 +223,7 @@ private void maybeRequestMoreUpstreamData() { } private boolean shouldRequestMoreData(long buffered) { - return buffered == 0 || buffered + byteBufferSizeHint < maxMemoryUsageInBytes; + return buffered == 0 || buffered + byteBufferSizeHint <= maxMemoryUsageInBytes; } private Long totalDataRemaining() { @@ -240,7 +257,7 @@ public Optional contentLength() { } public void send(ByteBuffer data) { - log.trace(() -> "Sending bytebuffer " + data); + log.trace(() -> String.format("Sending bytebuffer %s to chunk %d", data, chunkNumber)); int length = data.remaining(); transferredLength += length; addDataBuffered(length); diff --git a/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java index f791b4b3c26a..616cd794418b 100644 --- a/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java +++ b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java @@ -22,18 +22,23 @@ import java.io.ByteArrayInputStream; import java.io.File; +import java.nio.ByteBuffer; import java.nio.charset.Charset; import java.nio.file.Files; +import java.util.Optional; import java.util.UUID; import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; import org.apache.commons.lang3.RandomStringUtils; import org.assertj.core.api.Assertions; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; +import org.reactivestreams.Subscriber; import software.amazon.awssdk.core.ResponseInputStream; import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.internal.async.FileAsyncRequestBody; import software.amazon.awssdk.core.sync.ResponseTransformer; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3IntegrationTestBase; @@ -42,11 +47,12 @@ import software.amazon.awssdk.services.s3.utils.ChecksumUtils; import software.amazon.awssdk.testutils.RandomTempFile; +@Timeout(value = 30, unit = SECONDS) public class S3MultipartClientPutObjectIntegrationTest extends S3IntegrationTestBase { private static final String TEST_BUCKET = temporaryBucketName(S3MultipartClientPutObjectIntegrationTest.class); private static final String TEST_KEY = "testfile.dat"; - private static final int OBJ_SIZE = 19 * 1024 * 1024; + private static final int OBJ_SIZE = 31 * 1024 * 1024; private static File testFile; private static S3AsyncClient mpuS3Client; @@ -71,7 +77,6 @@ public static void teardown() throws Exception { } @Test - @Timeout(value = 20, unit = SECONDS) void putObject_fileRequestBody_objectSentCorrectly() throws Exception { AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath()); mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join(); @@ -99,4 +104,31 @@ void putObject_byteAsyncRequestBody_objectSentCorrectly() throws Exception { assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum); } + @Test + @Timeout(value = 30, unit = SECONDS) + void putObject_unknownContentLength_objectSentCorrectly() throws Exception { + AsyncRequestBody body = FileAsyncRequestBody.builder() + .path(testFile.toPath()) + //.chunkSizeInBytes(2 * 1024 * 1024) + .build(); + mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), new AsyncRequestBody() { + @Override + public Optional contentLength() { + return Optional.empty(); + } + + @Override + public void subscribe(Subscriber s) { + body.subscribe(s); + } + }).get(30, SECONDS); + + ResponseInputStream objContent = S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), + ResponseTransformer.toInputStream()); + + assertThat(objContent.response().contentLength()).isEqualTo(testFile.length()); + byte[] expectedSum = ChecksumUtils.computeCheckSum(Files.newInputStream(testFile.toPath())); + assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum); + } + } diff --git a/services/s3/src/it/resources/log4j2.properties b/services/s3/src/it/resources/log4j2.properties index ea24f17148e6..adf6424b8f62 100644 --- a/services/s3/src/it/resources/log4j2.properties +++ b/services/s3/src/it/resources/log4j2.properties @@ -25,11 +25,11 @@ rootLogger.appenderRef.stdout.ref = ConsoleAppender # Uncomment below to enable more specific logging # -#logger.sdk.name = software.amazon.awssdk -#logger.sdk.level = debug +logger.sdk.name = software.amazon.awssdk +logger.sdk.level = debug # -#logger.request.name = software.amazon.awssdk.request -#logger.request.level = debug +logger.request.name = software.amazon.awssdk.core.internal.async.SplittingPublisher +logger.request.level = trace # #logger.apache.name = org.apache.http.wire #logger.apache.level = debug diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/BaseMultipartUploadHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/BaseMultipartUploadHelper.java new file mode 100644 index 000000000000..93b6e0c88343 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/BaseMultipartUploadHelper.java @@ -0,0 +1,113 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + + +import static software.amazon.awssdk.services.s3.internal.multipart.SdkPojoConversionUtils.toAbortMultipartUploadRequest; + +import java.util.Collection; +import java.util.Queue; +import java.util.concurrent.CompletableFuture; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.s3.model.UploadPartResponse; +import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.Logger; + +/** + * A base class contains common logic used by {@link UploadWithUnknownContentLengthHelper} and {@link UploadWithKnownContentLengthHelper}. + */ +@SdkInternalApi +public abstract class BaseMultipartUploadHelper { + private static final Logger log = Logger.loggerFor(BaseMultipartUploadHelper.class); + + private final S3AsyncClient s3AsyncClient; + private final long partSizeInBytes; + private final GenericMultipartHelper genericMultipartHelper; + + private final long maxMemoryUsageInBytes; + private final long multipartUploadThresholdInBytes; + + public BaseMultipartUploadHelper(S3AsyncClient s3AsyncClient, + long partSizeInBytes, + long multipartUploadThresholdInBytes, + long maxMemoryUsageInBytes) { + this.s3AsyncClient = s3AsyncClient; + this.partSizeInBytes = partSizeInBytes; + this.genericMultipartHelper = new GenericMultipartHelper<>(s3AsyncClient, + SdkPojoConversionUtils::toAbortMultipartUploadRequest, + SdkPojoConversionUtils::toPutObjectResponse); + this.maxMemoryUsageInBytes = maxMemoryUsageInBytes; + this.multipartUploadThresholdInBytes = multipartUploadThresholdInBytes; + } + + CompletableFuture createMultipartUpload(PutObjectRequest putObjectRequest, CompletableFuture returnFuture) { + CreateMultipartUploadRequest request = SdkPojoConversionUtils.toCreateMultipartUploadRequest(putObjectRequest); + CompletableFuture createMultipartUploadFuture = + s3AsyncClient.createMultipartUpload(request); + + // Ensure cancellations are forwarded to the createMultipartUploadFuture future + CompletableFutureUtils.forwardExceptionTo(returnFuture, createMultipartUploadFuture); + return createMultipartUploadFuture; + } + + void completeMultipartUpload(CompletableFuture returnFuture, + String uploadId, + CompletedPart[] completedParts, + PutObjectRequest putObjectRequest, + Collection> futures) { + CompletableFutureUtils.allOfExceptionForwarded(futures.toArray(new CompletableFuture[0])) + .thenCompose(ignore -> genericMultipartHelper.completeMultipartUpload(putObjectRequest, + uploadId, + completedParts)) + .handle(genericMultipartHelper.handleExceptionOrResponse(putObjectRequest, returnFuture, + uploadId)) + .exceptionally(throwable -> { + genericMultipartHelper.handleException(returnFuture, () -> "Unexpected exception occurred", + throwable); + return null; + }); + } + + static void cancelingOtherOngoingRequests(Collection> futures, Throwable t) { + log.trace(() -> "cancelling other ongoing requests " + futures.size()); + futures.forEach(f -> f.completeExceptionally(t)); + } + + static CompletedPart convertUploadPartResponse(Queue completedParts, + Integer partNumber, + UploadPartResponse uploadPartResponse) { + CompletedPart completedPart = SdkPojoConversionUtils.toCompletedPart(uploadPartResponse, partNumber); + + completedParts.add(completedPart); + return completedPart; + } + + void uploadInOneChunk(PutObjectRequest putObjectRequest, + AsyncRequestBody asyncRequestBody, + CompletableFuture returnFuture) { + CompletableFuture putObjectResponseCompletableFuture = s3AsyncClient.putObject(putObjectRequest, + asyncRequestBody); + CompletableFutureUtils.forwardExceptionTo(returnFuture, putObjectResponseCompletableFuture); + CompletableFutureUtils.forwardResultTo(putObjectResponseCompletableFuture, returnFuture); + } +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java index 4ab4b22a0e79..b9dda956fd4e 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java @@ -15,6 +15,8 @@ package software.amazon.awssdk.services.s3.internal.multipart; +import java.util.Collection; +import java.util.Comparator; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.atomic.AtomicReferenceArray; @@ -79,13 +81,9 @@ public int determinePartCount(long contentLength, long partSize) { } public CompletableFuture completeMultipartUpload( - RequestT request, String uploadId, AtomicReferenceArray completedParts) { + RequestT request, String uploadId, CompletedPart[] parts) { log.debug(() -> String.format("Sending completeMultipartUploadRequest, uploadId: %s", uploadId)); - CompletedPart[] parts = - IntStream.range(0, completedParts.length()) - .mapToObj(completedParts::get) - .toArray(CompletedPart[]::new); CompleteMultipartUploadRequest completeMultipartUploadRequest = CompleteMultipartUploadRequest.builder() .bucket(request.getValueForField("Bucket", String.class).get()) @@ -99,6 +97,15 @@ public CompletableFuture completeMultipartUploa return s3AsyncClient.completeMultipartUpload(completeMultipartUploadRequest); } + public CompletableFuture completeMultipartUpload( + RequestT request, String uploadId, AtomicReferenceArray completedParts) { + CompletedPart[] parts = + IntStream.range(0, completedParts.length()) + .mapToObj(completedParts::get) + .toArray(CompletedPart[]::new); + return completeMultipartUpload(request, uploadId, parts); + } + public BiFunction handleExceptionOrResponse( RequestT request, CompletableFuture returnFuture, diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java index 869eb4048144..a4b3147254f9 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java @@ -34,13 +34,13 @@ public class MultipartS3AsyncClient extends DelegatingS3AsyncClient { private static final long DEFAULT_THRESHOLD = 8L * 1024 * 1024; private static final long DEFAULT_MAX_MEMORY = DEFAULT_PART_SIZE_IN_BYTES * 2; - private final MultipartUploadHelper mpuHelper; + private final UploadObjectHelper mpuHelper; private final CopyObjectHelper copyObjectHelper; public MultipartS3AsyncClient(S3AsyncClient delegate) { super(delegate); // TODO: pass a config object to the upload helper instead - mpuHelper = new MultipartUploadHelper(delegate, DEFAULT_PART_SIZE_IN_BYTES, DEFAULT_THRESHOLD, DEFAULT_MAX_MEMORY); + mpuHelper = new UploadObjectHelper(delegate, DEFAULT_PART_SIZE_IN_BYTES, DEFAULT_THRESHOLD, DEFAULT_MAX_MEMORY); copyObjectHelper = new CopyObjectHelper(delegate, DEFAULT_PART_SIZE_IN_BYTES, DEFAULT_THRESHOLD); } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelper.java new file mode 100644 index 000000000000..c3a9b4f84f76 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelper.java @@ -0,0 +1,76 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + + +import static software.amazon.awssdk.services.s3.internal.multipart.SdkPojoConversionUtils.toAbortMultipartUploadRequest; + +import java.util.concurrent.CompletableFuture; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.utils.Logger; + +/** + * An internal helper class that automatically uses multipart upload based on the size of the object. + */ +@SdkInternalApi +public final class UploadObjectHelper { + private static final Logger log = Logger.loggerFor(UploadObjectHelper.class); + + private final S3AsyncClient s3AsyncClient; + private final long partSizeInBytes; + private final GenericMultipartHelper genericMultipartHelper; + + private final long maxMemoryUsageInBytes; + private final long multipartUploadThresholdInBytes; + private final UploadWithKnownContentLengthHelper uploadWithKnownContentLength; + private final UploadWithUnknownContentLengthHelper uploadWithUnknownContentLength; + + public UploadObjectHelper(S3AsyncClient s3AsyncClient, + long partSizeInBytes, + long multipartUploadThresholdInBytes, + long maxMemoryUsageInBytes) { + this.s3AsyncClient = s3AsyncClient; + this.partSizeInBytes = partSizeInBytes; + this.genericMultipartHelper = new GenericMultipartHelper<>(s3AsyncClient, + SdkPojoConversionUtils::toAbortMultipartUploadRequest, + SdkPojoConversionUtils::toPutObjectResponse); + this.maxMemoryUsageInBytes = maxMemoryUsageInBytes; + this.multipartUploadThresholdInBytes = multipartUploadThresholdInBytes; + this.uploadWithKnownContentLength = new UploadWithKnownContentLengthHelper(s3AsyncClient, + partSizeInBytes, + multipartUploadThresholdInBytes, + maxMemoryUsageInBytes); + this.uploadWithUnknownContentLength = new UploadWithUnknownContentLengthHelper(s3AsyncClient, + partSizeInBytes, + multipartUploadThresholdInBytes, + maxMemoryUsageInBytes); + } + + public CompletableFuture uploadObject(PutObjectRequest putObjectRequest, + AsyncRequestBody asyncRequestBody) { + Long contentLength = asyncRequestBody.contentLength().orElseGet(putObjectRequest::contentLength); + + if (contentLength == null) { + return uploadWithUnknownContentLength.uploadObject(putObjectRequest, asyncRequestBody); + } else { + return uploadWithKnownContentLength.uploadObject(putObjectRequest, asyncRequestBody, contentLength.longValue()); + } + } +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java similarity index 78% rename from services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java rename to services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java index a3aea4a9bdf7..6a482c89e4dd 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java @@ -23,12 +23,12 @@ import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicReferenceArray; import java.util.function.Function; +import java.util.stream.IntStream; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.SplitAsyncRequestBodyResponse; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.CompletedPart; -import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; @@ -42,8 +42,8 @@ * An internal helper class that automatically uses multipart upload based on the size of the object. */ @SdkInternalApi -public final class MultipartUploadHelper { - private static final Logger log = Logger.loggerFor(MultipartUploadHelper.class); +public final class UploadWithKnownContentLengthHelper extends BaseMultipartUploadHelper { + private static final Logger log = Logger.loggerFor(UploadWithKnownContentLengthHelper.class); private final S3AsyncClient s3AsyncClient; private final long partSizeInBytes; @@ -52,10 +52,11 @@ public final class MultipartUploadHelper { private final long maxMemoryUsageInBytes; private final long multipartUploadThresholdInBytes; - public MultipartUploadHelper(S3AsyncClient s3AsyncClient, - long partSizeInBytes, - long multipartUploadThresholdInBytes, - long maxMemoryUsageInBytes) { + public UploadWithKnownContentLengthHelper(S3AsyncClient s3AsyncClient, + long partSizeInBytes, + long multipartUploadThresholdInBytes, + long maxMemoryUsageInBytes) { + super(s3AsyncClient, partSizeInBytes, multipartUploadThresholdInBytes, maxMemoryUsageInBytes); this.s3AsyncClient = s3AsyncClient; this.partSizeInBytes = partSizeInBytes; this.genericMultipartHelper = new GenericMultipartHelper<>(s3AsyncClient, @@ -66,15 +67,8 @@ public MultipartUploadHelper(S3AsyncClient s3AsyncClient, } public CompletableFuture uploadObject(PutObjectRequest putObjectRequest, - AsyncRequestBody asyncRequestBody) { - Long contentLength = asyncRequestBody.contentLength().orElseGet(putObjectRequest::contentLength); - - // TODO: support null content length. Need to determine whether to use single object or MPU based on the first - // AsyncRequestBody - if (contentLength == null) { - throw new IllegalArgumentException("Content-length is required"); - } - + AsyncRequestBody asyncRequestBody, + long contentLength) { CompletableFuture returnFuture = new CompletableFuture<>(); try { @@ -96,12 +90,8 @@ public CompletableFuture uploadObject(PutObjectRequest putObj private void uploadInParts(PutObjectRequest putObjectRequest, long contentLength, AsyncRequestBody asyncRequestBody, CompletableFuture returnFuture) { - CreateMultipartUploadRequest request = SdkPojoConversionUtils.toCreateMultipartUploadRequest(putObjectRequest); - CompletableFuture createMultipartUploadFuture = - s3AsyncClient.createMultipartUpload(request); - - // Ensure cancellations are forwarded to the createMultipartUploadFuture future - CompletableFutureUtils.forwardExceptionTo(returnFuture, createMultipartUploadFuture); + CompletableFuture createMultipartUploadFuture = createMultipartUpload(putObjectRequest, + returnFuture); createMultipartUploadFuture.whenComplete((createMultipartUploadResponse, throwable) -> { if (throwable != null) { @@ -145,32 +135,20 @@ private void doUploadInParts(Pair request, cancelingOtherOngoingRequests(futures, t); return; } - CompletableFutureUtils.allOfExceptionForwarded(futures.toArray(new CompletableFuture[0])) - .thenCompose(ignore -> genericMultipartHelper.completeMultipartUpload(putObjectRequest, - uploadId, - completedParts)) - .handle(genericMultipartHelper.handleExceptionOrResponse(putObjectRequest, returnFuture, - uploadId)) - .exceptionally(throwable -> { - genericMultipartHelper.handleException(returnFuture, () -> "Unexpected exception occurred", - throwable); - return null; - }); + CompletedPart[] parts = + IntStream.range(0, completedParts.length()) + .mapToObj(completedParts::get) + .toArray(CompletedPart[]::new); + completeMultipartUpload(returnFuture, uploadId, parts, putObjectRequest, futures); }); } - private static void cancelingOtherOngoingRequests(Collection> futures, Throwable t) { - log.trace(() -> "cancelling other ongoing requests " + futures.size()); - futures.forEach(f -> f.completeExceptionally(t)); - } - private CompletableFuture sendUploadPartRequests(MpuRequestContext mpuRequestContext, AtomicReferenceArray completedParts, CompletableFuture returnFuture, Collection> futures) { - AsyncRequestBody asyncRequestBody = mpuRequestContext.request.right(); SplitAsyncRequestBodyResponse result = asyncRequestBody.split(mpuRequestContext.partSize, maxMemoryUsageInBytes); @@ -221,18 +199,10 @@ private static CompletedPart convertUploadPartResponse(AtomicReferenceArray returnFuture) { - CompletableFuture putObjectResponseCompletableFuture = s3AsyncClient.putObject(putObjectRequest, - asyncRequestBody); - CompletableFutureUtils.forwardExceptionTo(returnFuture, putObjectResponseCompletableFuture); - CompletableFutureUtils.forwardResultTo(putObjectResponseCompletableFuture, returnFuture); - } - private static final class BodyToRequestConverter implements Function> { - private int partNumber = 1; + + private volatile int partNumber = 1; private final PutObjectRequest putObjectRequest; private final String uploadId; @@ -270,5 +240,4 @@ private MpuRequestContext(Pair request, this.uploadId = uploadId; } } - } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java new file mode 100644 index 000000000000..9d4ab5cf02ba --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java @@ -0,0 +1,245 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + + +import static software.amazon.awssdk.services.s3.internal.multipart.SdkPojoConversionUtils.toAbortMultipartUploadRequest; + +import java.util.Collection; +import java.util.Comparator; +import java.util.Queue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.SplitAsyncRequestBodyResponse; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.services.s3.model.UploadPartResponse; +import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Pair; + +/** + * An internal helper class that uploads streams with unknown content length. + */ +@SdkInternalApi +public final class UploadWithUnknownContentLengthHelper extends BaseMultipartUploadHelper { + private static final Logger log = Logger.loggerFor(UploadWithUnknownContentLengthHelper.class); + + private final S3AsyncClient s3AsyncClient; + private final long partSizeInBytes; + private final GenericMultipartHelper genericMultipartHelper; + + private final long maxMemoryUsageInBytes; + private final long multipartUploadThresholdInBytes; + + public UploadWithUnknownContentLengthHelper(S3AsyncClient s3AsyncClient, + long partSizeInBytes, + long multipartUploadThresholdInBytes, + long maxMemoryUsageInBytes) { + super(s3AsyncClient, partSizeInBytes, multipartUploadThresholdInBytes, maxMemoryUsageInBytes); + this.s3AsyncClient = s3AsyncClient; + this.partSizeInBytes = partSizeInBytes; + this.genericMultipartHelper = new GenericMultipartHelper<>(s3AsyncClient, + SdkPojoConversionUtils::toAbortMultipartUploadRequest, + SdkPojoConversionUtils::toPutObjectResponse); + this.maxMemoryUsageInBytes = maxMemoryUsageInBytes; + this.multipartUploadThresholdInBytes = multipartUploadThresholdInBytes; + } + + public CompletableFuture uploadObject(PutObjectRequest putObjectRequest, + AsyncRequestBody asyncRequestBody) { + CompletableFuture returnFuture = new CompletableFuture<>(); + Long contentLength = asyncRequestBody.contentLength().orElseGet(putObjectRequest::contentLength); + + SplitAsyncRequestBodyResponse splitAsyncRequestBodyResponse = asyncRequestBody.split(partSizeInBytes, maxMemoryUsageInBytes); + + splitAsyncRequestBodyResponse.asyncRequestBodyPublisher() + .subscribe( + new UnknownContentLengthAsyncRequestBodySubscriber(partSizeInBytes, + putObjectRequest, + returnFuture)); + + return returnFuture; + } + + private class UnknownContentLengthAsyncRequestBodySubscriber implements Subscriber { + /** + * Indicates whether this is the first async request body or not. + */ + private final AtomicBoolean isFirstAsyncRequestBody = new AtomicBoolean(true); + + /** + * Indicates whether CreateMultipartUpload has been initiated or not + */ + private final AtomicBoolean createMultipartUploadInitiated = new AtomicBoolean(false); + + /** + * Indicates whether CompleteMultipart has been initiated or not. + */ + private final AtomicBoolean completedMultipartInitiated = new AtomicBoolean(false); + + /** + * The number of AsyncRequestBody has been received but yet to be processed + */ + private final AtomicInteger asyncRequestBodyInFlight = new AtomicInteger(0); + + private AtomicInteger partNumber = new AtomicInteger(1); + + private final Queue completedParts = new ConcurrentLinkedQueue<>(); + private final Collection> futures = new ConcurrentLinkedQueue<>(); + + private final CompletableFuture uploadIdFuture = new CompletableFuture<>(); + + private final long maximumChunkSizeInByte; + private final PutObjectRequest putObjectRequest; + private final CompletableFuture returnFuture; + private Subscription subscription; + private AsyncRequestBody firstRequestBody; + + private String uploadId; + private volatile boolean isDone; + + public UnknownContentLengthAsyncRequestBodySubscriber(long maximumChunkSizeInByte, + PutObjectRequest putObjectRequest, + CompletableFuture returnFuture) { + this.maximumChunkSizeInByte = maximumChunkSizeInByte; + this.putObjectRequest = putObjectRequest; + this.returnFuture = returnFuture; + } + + @Override + public void onSubscribe(Subscription s) { + this.subscription = s; + s.request(1); + } + + @Override + public void onNext(AsyncRequestBody asyncRequestBody) { + log.debug(() -> "Received asyncRequestBody " + asyncRequestBody.contentLength()); + asyncRequestBodyInFlight.incrementAndGet(); + + if (isFirstAsyncRequestBody.compareAndSet(true, false)) { + log.debug(() -> "Received first async request body"); + // If this is the first AsyncRequestBody received, request another one because we don't know if there is more + firstRequestBody = asyncRequestBody; + subscription.request(1); + log.debug(() -> "requesting"); + return; + } + + // If there are more than 1 AsyncRequestBodies, then we know we need to upload this + // object using MPU + if (createMultipartUploadInitiated.compareAndSet(false, true)) { + log.debug(() -> "Starting the upload as multipart upload request"); + CompletableFuture createMultipartUploadFuture = + createMultipartUpload(putObjectRequest, returnFuture); + + createMultipartUploadFuture.whenComplete((createMultipartUploadResponse, throwable) -> { + if (throwable != null) { + genericMultipartHelper.handleException(returnFuture, () -> "Failed to initiate multipart upload", + throwable); + subscription.cancel(); + } else { + uploadId = createMultipartUploadResponse.uploadId(); + uploadIdFuture.complete(uploadId); + log.debug(() -> "Initiated a new multipart upload, uploadId: " + createMultipartUploadResponse.uploadId()); + + sendIndividualUploadPartRequest(uploadId, completedParts, futures, uploadPart(firstRequestBody)); + sendIndividualUploadPartRequest(uploadId, completedParts, futures, uploadPart(asyncRequestBody)); + } + }); + } else { + uploadIdFuture.whenComplete((r, t) -> sendIndividualUploadPartRequest(uploadId, completedParts, futures, + uploadPart(asyncRequestBody))); + } + } + + private Pair uploadPart(AsyncRequestBody asyncRequestBody) { + UploadPartRequest uploadRequest = + SdkPojoConversionUtils.toUploadPartRequest(putObjectRequest, + partNumber.getAndIncrement(), + uploadId); + return Pair.of(uploadRequest, asyncRequestBody); + } + + @Override + public void onError(Throwable t) { + genericMultipartHelper.handleException(returnFuture, () -> "Failed to send multipart upload requests", t); + returnFuture.completeExceptionally(t); + if (uploadId != null) { + genericMultipartHelper.cleanUpParts(uploadId, toAbortMultipartUploadRequest(putObjectRequest)); + } + cancelingOtherOngoingRequests(futures, t); + } + + @Override + public void onComplete() { + log.debug(() -> "Received onComplete()"); + // If CreateMultipartUpload has not been initiated at this point, we know this is a single object upload + if (createMultipartUploadInitiated.get() == false) { + log.debug(() -> "Starting the upload as a single object upload request"); + uploadInOneChunk(putObjectRequest, firstRequestBody, returnFuture); + } else { + isDone = true; + completeMultipartUploadIfFinish(asyncRequestBodyInFlight.get()); + } + } + + private void sendIndividualUploadPartRequest(String uploadId, + Queue completedParts, + Collection> futures, + Pair requestPair) { + UploadPartRequest uploadPartRequest = requestPair.left(); + Integer partNumber = uploadPartRequest.partNumber(); + log.debug(() -> "Sending uploadPartRequest: " + uploadPartRequest.partNumber() + " uploadId: " + uploadId + " " + + "contentLength " + requestPair.right().contentLength()); + + CompletableFuture uploadPartFuture = s3AsyncClient.uploadPart(uploadPartRequest, requestPair.right()); + + CompletableFuture convertFuture = + uploadPartFuture.thenApply(uploadPartResponse -> convertUploadPartResponse(completedParts, partNumber, + uploadPartResponse)) + .whenComplete((r, t) -> { + int numRequests = asyncRequestBodyInFlight.decrementAndGet(); + completeMultipartUploadIfFinish(numRequests); + }); + futures.add(convertFuture); + CompletableFutureUtils.forwardExceptionTo(convertFuture, uploadPartFuture); + synchronized (this) { + subscription.request(1); + } + } + + private void completeMultipartUploadIfFinish(int requestsInFlight) { + if (isDone && requestsInFlight == 0 && completedMultipartInitiated.compareAndSet(false, true)) { + CompletedPart[] parts = completedParts.stream() + .sorted(Comparator.comparingInt(CompletedPart::partNumber)) + .toArray(CompletedPart[]::new); + completeMultipartUpload(returnFuture, uploadId, parts, putObjectRequest, futures); + } + } + } +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelperTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelperTest.java index 1ea17d4ba967..28b48c07c68b 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelperTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelperTest.java @@ -28,9 +28,7 @@ import java.io.IOException; import java.util.List; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; @@ -70,7 +68,7 @@ public class MultipartUploadHelperTest { private static final String UPLOAD_ID = "1234"; private static RandomTempFile testFile; - private MultipartUploadHelper uploadHelper; + private UploadObjectHelper uploadHelper; private S3AsyncClient s3AsyncClient; @BeforeAll @@ -86,7 +84,7 @@ public static void afterAll() throws Exception { @BeforeEach public void beforeEach() { s3AsyncClient = Mockito.mock(S3AsyncClient.class); - uploadHelper = new MultipartUploadHelper(s3AsyncClient, PART_SIZE, THRESHOLD, PART_SIZE * 2); + uploadHelper = new UploadObjectHelper(s3AsyncClient, PART_SIZE, THRESHOLD, PART_SIZE * 2); } @ParameterizedTest From 45ced32c42ba8192cb58c16ecd034aef54f5f099 Mon Sep 17 00:00:00 2001 From: Zoe Wang <33073555+zoewangg@users.noreply.github.com> Date: Tue, 25 Jul 2023 11:41:13 -0700 Subject: [PATCH 2/2] Refactoring --- .../awssdk/core/async/AsyncRequestBody.java | 19 +- .../async/SplitAsyncRequestBodyResponse.java | 80 ------- .../internal/async/SplittingPublisher.java | 65 ++--- .../SplitAsyncRequestBodyResponseTest.java | 29 --- .../async/SplittingPublisherTest.java | 29 --- ...ltipartClientPutObjectIntegrationTest.java | 5 +- .../s3/src/it/resources/log4j2.properties | 8 +- .../multipart/GenericMultipartHelper.java | 3 +- ...Helper.java => MultipartUploadHelper.java} | 72 ++++-- .../multipart/UploadObjectHelper.java | 3 - .../UploadWithKnownContentLengthHelper.java | 226 +++++++++--------- .../UploadWithUnknownContentLengthHelper.java | 124 +++++----- ...rTest.java => UploadObjectHelperTest.java} | 181 +++++++++++++- 13 files changed, 435 insertions(+), 409 deletions(-) delete mode 100644 core/sdk-core/src/main/java/software/amazon/awssdk/core/async/SplitAsyncRequestBodyResponse.java delete mode 100644 core/sdk-core/src/test/java/software/amazon/awssdk/core/async/SplitAsyncRequestBodyResponseTest.java rename services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/{BaseMultipartUploadHelper.java => MultipartUploadHelper.java} (60%) rename services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/{MultipartUploadHelperTest.java => UploadObjectHelperTest.java} (58%) diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java index 3bd3d7136d47..3c6adb8fdbac 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java @@ -24,7 +24,6 @@ import java.nio.file.Path; import java.util.Arrays; import java.util.Optional; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; @@ -420,24 +419,20 @@ static AsyncRequestBody empty() { * @param maxMemoryUsageInBytes the max memory the SDK will use to buffer the content * @return SplitAsyncRequestBodyResult */ - default SplitAsyncRequestBodyResponse split(long chunkSizeInBytes, long maxMemoryUsageInBytes) { + default SdkPublisher split(long chunkSizeInBytes, long maxMemoryUsageInBytes) { Validate.isPositive(chunkSizeInBytes, "chunkSizeInBytes"); Validate.isPositive(maxMemoryUsageInBytes, "maxMemoryUsageInBytes"); - if (!this.contentLength().isPresent()) { + if (!contentLength().isPresent()) { Validate.isTrue(maxMemoryUsageInBytes >= chunkSizeInBytes, "maxMemoryUsageInBytes must be larger than or equal to " + "chunkSizeInBytes if the content length is unknown"); } - CompletableFuture future = new CompletableFuture<>(); - SplittingPublisher splittingPublisher = SplittingPublisher.builder() - .asyncRequestBody(this) - .chunkSizeInBytes(chunkSizeInBytes) - .maxMemoryUsageInBytes(maxMemoryUsageInBytes) - .resultFuture(future) - .build(); - - return SplitAsyncRequestBodyResponse.create(splittingPublisher, future); + return SplittingPublisher.builder() + .asyncRequestBody(this) + .chunkSizeInBytes(chunkSizeInBytes) + .maxMemoryUsageInBytes(maxMemoryUsageInBytes) + .build(); } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/SplitAsyncRequestBodyResponse.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/SplitAsyncRequestBodyResponse.java deleted file mode 100644 index 0035c87520ec..000000000000 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/SplitAsyncRequestBodyResponse.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.core.async; - - -import java.util.concurrent.CompletableFuture; -import software.amazon.awssdk.annotations.SdkPublicApi; -import software.amazon.awssdk.utils.Validate; - -/** - * Containing the result from {@link AsyncRequestBody#split(long, long)} - */ -@SdkPublicApi -public final class SplitAsyncRequestBodyResponse { - private final SdkPublisher asyncRequestBody; - private final CompletableFuture future; - - private SplitAsyncRequestBodyResponse(SdkPublisher asyncRequestBody, CompletableFuture future) { - this.asyncRequestBody = Validate.paramNotNull(asyncRequestBody, "asyncRequestBody"); - this.future = Validate.paramNotNull(future, "future"); - } - - public static SplitAsyncRequestBodyResponse create(SdkPublisher asyncRequestBody, - CompletableFuture future) { - return new SplitAsyncRequestBodyResponse(asyncRequestBody, future); - } - - /** - * Returns the converted {@link SdkPublisher} of {@link AsyncRequestBody}s. Each {@link AsyncRequestBody} publishes a specific - * portion of the original data. - */ - public SdkPublisher asyncRequestBodyPublisher() { - return asyncRequestBody; - } - - /** - * Returns {@link CompletableFuture} that will be notified when all data has been consumed or if an error occurs. - */ - public CompletableFuture future() { - return future; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - - SplitAsyncRequestBodyResponse that = (SplitAsyncRequestBodyResponse) o; - - if (!asyncRequestBody.equals(that.asyncRequestBody)) { - return false; - } - return future.equals(that.future); - } - - @Override - public int hashCode() { - int result = asyncRequestBody.hashCode(); - result = 31 * result + future.hashCode(); - return result; - } -} - diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java index 27aaadc5f7a1..e18f9944a09e 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java @@ -17,7 +17,6 @@ import java.nio.ByteBuffer; import java.util.Optional; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; @@ -45,24 +44,12 @@ public class SplittingPublisher implements SdkPublisher { private final SimplePublisher downstreamPublisher = new SimplePublisher<>(); private final long chunkSizeInBytes; private final long maxMemoryUsageInBytes; - private final CompletableFuture future; private SplittingPublisher(Builder builder) { this.upstreamPublisher = Validate.paramNotNull(builder.asyncRequestBody, "asyncRequestBody"); this.chunkSizeInBytes = Validate.isPositive(builder.chunkSizeInBytes, "chunkSizeInBytes"); this.splittingSubscriber = new SplittingSubscriber(upstreamPublisher.contentLength().orElse(null)); this.maxMemoryUsageInBytes = Validate.isPositive(builder.maxMemoryUsageInBytes, "maxMemoryUsageInBytes"); - this.future = builder.future; - - // We need to cancel upstream subscription if the future gets cancelled. - future.whenComplete((r, t) -> { - if (t != null) { - if (splittingSubscriber.upstreamSubscription != null) { - log.trace(() -> "Cancelling subscription because return future completed exceptionally ", t); - splittingSubscriber.upstreamSubscription.cancel(); - } - } - }); } public static Builder builder() { @@ -117,16 +104,20 @@ public void onNext(ByteBuffer byteBuffer) { byteBufferSizeHint = byteBuffer.remaining(); while (true) { + + if (!byteBuffer.hasRemaining()) { + break; + } + int amountRemainingInChunk = amountRemainingInChunk(); // If we have fulfilled this chunk, // complete the current body if (amountRemainingInChunk == 0) { completeCurrentBodyAndCreateNewIfNeeded(byteBuffer); + amountRemainingInChunk = amountRemainingInChunk(); } - amountRemainingInChunk = amountRemainingInChunk(); - // If the current ByteBuffer < this chunk, send it as-is if (amountRemainingInChunk > byteBuffer.remaining()) { currentBody.send(byteBuffer.duplicate()); @@ -154,21 +145,20 @@ public void onNext(ByteBuffer byteBuffer) { private void completeCurrentBodyAndCreateNewIfNeeded(ByteBuffer byteBuffer) { completeCurrentBody(); + int currentChunk = chunkNumber.incrementAndGet(); + boolean shouldCreateNewDownstreamRequestBody; + Long dataRemaining = totalDataRemaining(); - if (shouldCreateNewDownstreamRequestBody(byteBuffer)) { - int currentChunk = chunkNumber.incrementAndGet(); - long chunkSize = calculateChunkSize(totalDataRemaining()); - currentBody = initializeNextDownstreamBody(upstreamSize != null, chunkSize, currentChunk); + if (upstreamSize == null) { + shouldCreateNewDownstreamRequestBody = !upstreamComplete || byteBuffer.hasRemaining(); + } else { + shouldCreateNewDownstreamRequestBody = dataRemaining != null && dataRemaining > 0; } - } - - /** - * If content length is known, we should create new DownstreamRequestBody if there's remaining data. - * If content length is unknown, we should create new DownstreamRequestBody if upstream is not completed yet. - */ - private boolean shouldCreateNewDownstreamRequestBody(ByteBuffer byteBuffer) { - return !upstreamComplete || byteBuffer.remaining() > 0; + if (shouldCreateNewDownstreamRequestBody) { + long chunkSize = calculateChunkSize(dataRemaining); + currentBody = initializeNextDownstreamBody(upstreamSize != null, chunkSize, currentChunk); + } } private int amountRemainingInChunk() { @@ -176,7 +166,7 @@ private int amountRemainingInChunk() { } private void completeCurrentBody() { - log.debug(() -> "completeCurrentBody"); + log.debug(() -> "completeCurrentBody for chunk " + chunkNumber.get()); currentBody.complete(); if (upstreamSize == null) { sendCurrentBody(currentBody); @@ -188,16 +178,16 @@ public void onComplete() { upstreamComplete = true; log.trace(() -> "Received onComplete()"); completeCurrentBody(); - downstreamPublisher.complete().thenRun(() -> future.complete(null)); + downstreamPublisher.complete(); } @Override public void onError(Throwable t) { - currentBody.error(t); + log.trace(() -> "Received onError()", t); + downstreamPublisher.error(t); } private void sendCurrentBody(AsyncRequestBody body) { - log.debug(() -> "sendCurrentBody"); downstreamPublisher.send(body).exceptionally(t -> { downstreamPublisher.error(t); return null; @@ -300,7 +290,6 @@ public static final class Builder { private AsyncRequestBody asyncRequestBody; private Long chunkSizeInBytes; private Long maxMemoryUsageInBytes; - private CompletableFuture future; /** * Configures the asyncRequestBody to split @@ -339,18 +328,6 @@ public Builder maxMemoryUsageInBytes(long maxMemoryUsageInBytes) { return this; } - /** - * Sets the result future. The future will be completed when all request bodies - * have been sent. - * - * @param future The new future value. - * @return This object for method chaining. - */ - public Builder resultFuture(CompletableFuture future) { - this.future = future; - return this; - } - public SplittingPublisher build() { return new SplittingPublisher(this); } diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/SplitAsyncRequestBodyResponseTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/SplitAsyncRequestBodyResponseTest.java deleted file mode 100644 index 2d1e50bcd59d..000000000000 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/SplitAsyncRequestBodyResponseTest.java +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.core.async; - -import nl.jqno.equalsverifier.EqualsVerifier; -import org.junit.jupiter.api.Test; - -public class SplitAsyncRequestBodyResponseTest { - - @Test - void equalsHashcode() { - EqualsVerifier.forClass(SplitAsyncRequestBodyResponse.class) - .withNonnullFields("asyncRequestBody", "future") - .verify(); - } -} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java index 3ce8559eec32..368c403dbf88 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java @@ -31,10 +31,7 @@ import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterAll; @@ -45,7 +42,6 @@ import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import software.amazon.awssdk.core.async.AsyncRequestBody; -import software.amazon.awssdk.testutils.RandomTempFile; import software.amazon.awssdk.utils.BinaryUtils; public class SplittingPublisherTest { @@ -87,26 +83,6 @@ void differentChunkSize_byteArrayShouldSplitAsyncRequestBodyCorrectly(int chunkS verifySplitContent(AsyncRequestBody.fromBytes(CONTENT), chunkSize); } - - @Test - void cancelFuture_shouldCancelUpstream() throws IOException { - CompletableFuture future = new CompletableFuture<>(); - TestAsyncRequestBody asyncRequestBody = new TestAsyncRequestBody(); - SplittingPublisher splittingPublisher = SplittingPublisher.builder() - .resultFuture(future) - .asyncRequestBody(asyncRequestBody) - .chunkSizeInBytes(CHUNK_SIZE) - .maxMemoryUsageInBytes(10L) - .build(); - - OnlyRequestOnceSubscriber downstreamSubscriber = new OnlyRequestOnceSubscriber(); - splittingPublisher.subscribe(downstreamSubscriber); - - future.completeExceptionally(new RuntimeException("test")); - assertThat(asyncRequestBody.cancelled).isTrue(); - assertThat(downstreamSubscriber.asyncRequestBodies.size()).isEqualTo(1); - } - @Test void contentLengthNotPresent_shouldHandle() throws Exception { CompletableFuture future = new CompletableFuture<>(); @@ -117,7 +93,6 @@ public Optional contentLength() { } }; SplittingPublisher splittingPublisher = SplittingPublisher.builder() - .resultFuture(future) .asyncRequestBody(asyncRequestBody) .chunkSizeInBytes(CHUNK_SIZE) .maxMemoryUsageInBytes(10L) @@ -159,11 +134,8 @@ public Optional contentLength() { private static void verifySplitContent(AsyncRequestBody asyncRequestBody, int chunkSize) throws Exception { - CompletableFuture future = new CompletableFuture<>(); SplittingPublisher splittingPublisher = SplittingPublisher.builder() - .resultFuture(future) .asyncRequestBody(asyncRequestBody) - .resultFuture(future) .chunkSizeInBytes(chunkSize) .maxMemoryUsageInBytes((long) chunkSize * 4) .build(); @@ -194,7 +166,6 @@ private static void verifySplitContent(AsyncRequestBody asyncRequestBody, int ch assertThat(actualBytes).isEqualTo(expected); }; } - assertThat(future).isCompleted(); } private static class TestAsyncRequestBody implements AsyncRequestBody { diff --git a/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java index 616cd794418b..cb72906943b9 100644 --- a/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java +++ b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java @@ -52,7 +52,7 @@ public class S3MultipartClientPutObjectIntegrationTest extends S3IntegrationTest private static final String TEST_BUCKET = temporaryBucketName(S3MultipartClientPutObjectIntegrationTest.class); private static final String TEST_KEY = "testfile.dat"; - private static final int OBJ_SIZE = 31 * 1024 * 1024; + private static final int OBJ_SIZE = 19 * 1024 * 1024; private static File testFile; private static S3AsyncClient mpuS3Client; @@ -90,7 +90,6 @@ void putObject_fileRequestBody_objectSentCorrectly() throws Exception { } @Test - @Timeout(value = 30, unit = SECONDS) void putObject_byteAsyncRequestBody_objectSentCorrectly() throws Exception { byte[] bytes = RandomStringUtils.randomAscii(OBJ_SIZE).getBytes(Charset.defaultCharset()); AsyncRequestBody body = AsyncRequestBody.fromBytes(bytes); @@ -105,11 +104,9 @@ void putObject_byteAsyncRequestBody_objectSentCorrectly() throws Exception { } @Test - @Timeout(value = 30, unit = SECONDS) void putObject_unknownContentLength_objectSentCorrectly() throws Exception { AsyncRequestBody body = FileAsyncRequestBody.builder() .path(testFile.toPath()) - //.chunkSizeInBytes(2 * 1024 * 1024) .build(); mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), new AsyncRequestBody() { @Override diff --git a/services/s3/src/it/resources/log4j2.properties b/services/s3/src/it/resources/log4j2.properties index adf6424b8f62..ea24f17148e6 100644 --- a/services/s3/src/it/resources/log4j2.properties +++ b/services/s3/src/it/resources/log4j2.properties @@ -25,11 +25,11 @@ rootLogger.appenderRef.stdout.ref = ConsoleAppender # Uncomment below to enable more specific logging # -logger.sdk.name = software.amazon.awssdk -logger.sdk.level = debug +#logger.sdk.name = software.amazon.awssdk +#logger.sdk.level = debug # -logger.request.name = software.amazon.awssdk.core.internal.async.SplittingPublisher -logger.request.level = trace +#logger.request.name = software.amazon.awssdk.request +#logger.request.level = debug # #logger.apache.name = org.apache.http.wire #logger.apache.level = debug diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java index b9dda956fd4e..905c1bc928ea 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java @@ -15,8 +15,6 @@ package software.amazon.awssdk.services.s3.internal.multipart; -import java.util.Collection; -import java.util.Comparator; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.atomic.AtomicReferenceArray; @@ -126,6 +124,7 @@ public BiFunction handleExcept } public void cleanUpParts(String uploadId, AbortMultipartUploadRequest.Builder abortMultipartUploadRequest) { + log.debug(() -> "Aborting multipart upload: " + uploadId); s3AsyncClient.abortMultipartUpload(abortMultipartUploadRequest.uploadId(uploadId).build()) .exceptionally(throwable -> { log.warn(() -> String.format("Failed to abort previous multipart upload " diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/BaseMultipartUploadHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java similarity index 60% rename from services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/BaseMultipartUploadHelper.java rename to services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java index 93b6e0c88343..1228e577fcd1 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/BaseMultipartUploadHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java @@ -19,8 +19,8 @@ import static software.amazon.awssdk.services.s3.internal.multipart.SdkPojoConversionUtils.toAbortMultipartUploadRequest; import java.util.Collection; -import java.util.Queue; import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.services.s3.S3AsyncClient; @@ -29,16 +29,19 @@ import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; import software.amazon.awssdk.services.s3.model.UploadPartResponse; import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Pair; /** - * A base class contains common logic used by {@link UploadWithUnknownContentLengthHelper} and {@link UploadWithKnownContentLengthHelper}. + * A base class contains common logic used by {@link UploadWithUnknownContentLengthHelper} + * and {@link UploadWithKnownContentLengthHelper}. */ @SdkInternalApi -public abstract class BaseMultipartUploadHelper { - private static final Logger log = Logger.loggerFor(BaseMultipartUploadHelper.class); +public final class MultipartUploadHelper { + private static final Logger log = Logger.loggerFor(MultipartUploadHelper.class); private final S3AsyncClient s3AsyncClient; private final long partSizeInBytes; @@ -47,10 +50,10 @@ public abstract class BaseMultipartUploadHelper { private final long maxMemoryUsageInBytes; private final long multipartUploadThresholdInBytes; - public BaseMultipartUploadHelper(S3AsyncClient s3AsyncClient, - long partSizeInBytes, - long multipartUploadThresholdInBytes, - long maxMemoryUsageInBytes) { + public MultipartUploadHelper(S3AsyncClient s3AsyncClient, + long partSizeInBytes, + long multipartUploadThresholdInBytes, + long maxMemoryUsageInBytes) { this.s3AsyncClient = s3AsyncClient; this.partSizeInBytes = partSizeInBytes; this.genericMultipartHelper = new GenericMultipartHelper<>(s3AsyncClient, @@ -60,7 +63,8 @@ public BaseMultipartUploadHelper(S3AsyncClient s3AsyncClient, this.multipartUploadThresholdInBytes = multipartUploadThresholdInBytes; } - CompletableFuture createMultipartUpload(PutObjectRequest putObjectRequest, CompletableFuture returnFuture) { + CompletableFuture createMultipartUpload(PutObjectRequest putObjectRequest, + CompletableFuture returnFuture) { CreateMultipartUploadRequest request = SdkPojoConversionUtils.toCreateMultipartUploadRequest(putObjectRequest); CompletableFuture createMultipartUploadFuture = s3AsyncClient.createMultipartUpload(request); @@ -73,12 +77,10 @@ CompletableFuture createMultipartUpload(PutObject void completeMultipartUpload(CompletableFuture returnFuture, String uploadId, CompletedPart[] completedParts, - PutObjectRequest putObjectRequest, - Collection> futures) { - CompletableFutureUtils.allOfExceptionForwarded(futures.toArray(new CompletableFuture[0])) - .thenCompose(ignore -> genericMultipartHelper.completeMultipartUpload(putObjectRequest, - uploadId, - completedParts)) + PutObjectRequest putObjectRequest) { + genericMultipartHelper.completeMultipartUpload(putObjectRequest, + uploadId, + completedParts) .handle(genericMultipartHelper.handleExceptionOrResponse(putObjectRequest, returnFuture, uploadId)) .exceptionally(throwable -> { @@ -88,17 +90,49 @@ void completeMultipartUpload(CompletableFuture returnFuture, }); } + CompletableFuture sendIndividualUploadPartRequest(String uploadId, + Consumer completedPartsConsumer, + Collection> futures, + Pair requestPair) { + UploadPartRequest uploadPartRequest = requestPair.left(); + Integer partNumber = uploadPartRequest.partNumber(); + log.debug(() -> "Sending uploadPartRequest: " + uploadPartRequest.partNumber() + " uploadId: " + uploadId + " " + + "contentLength " + requestPair.right().contentLength()); + + CompletableFuture uploadPartFuture = s3AsyncClient.uploadPart(uploadPartRequest, + requestPair.right()); + + CompletableFuture convertFuture = + uploadPartFuture.thenApply(uploadPartResponse -> convertUploadPartResponse(completedPartsConsumer, partNumber, + uploadPartResponse)); + futures.add(convertFuture); + CompletableFutureUtils.forwardExceptionTo(convertFuture, uploadPartFuture); + return convertFuture; + } + + void failRequestsElegantly(Collection> futures, + Throwable t, + String uploadId, + CompletableFuture returnFuture, + PutObjectRequest putObjectRequest) { + genericMultipartHelper.handleException(returnFuture, () -> "Failed to send multipart upload requests", t); + if (uploadId != null) { + genericMultipartHelper.cleanUpParts(uploadId, toAbortMultipartUploadRequest(putObjectRequest)); + } + cancelingOtherOngoingRequests(futures, t); + } + static void cancelingOtherOngoingRequests(Collection> futures, Throwable t) { log.trace(() -> "cancelling other ongoing requests " + futures.size()); futures.forEach(f -> f.completeExceptionally(t)); } - static CompletedPart convertUploadPartResponse(Queue completedParts, - Integer partNumber, - UploadPartResponse uploadPartResponse) { + static CompletedPart convertUploadPartResponse(Consumer consumer, + Integer partNumber, + UploadPartResponse uploadPartResponse) { CompletedPart completedPart = SdkPojoConversionUtils.toCompletedPart(uploadPartResponse, partNumber); - completedParts.add(completedPart); + consumer.accept(completedPart); return completedPart; } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelper.java index c3a9b4f84f76..0700e8ade5f9 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelper.java @@ -15,9 +15,6 @@ package software.amazon.awssdk.services.s3.internal.multipart; - -import static software.amazon.awssdk.services.s3.internal.multipart.SdkPojoConversionUtils.toAbortMultipartUploadRequest; - import java.util.concurrent.CompletableFuture; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.async.AsyncRequestBody; diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java index 6a482c89e4dd..e8bef01ab81b 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java @@ -16,25 +16,24 @@ package software.amazon.awssdk.services.s3.internal.multipart; -import static software.amazon.awssdk.services.s3.internal.multipart.SdkPojoConversionUtils.toAbortMultipartUploadRequest; - import java.util.Collection; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReferenceArray; -import java.util.function.Function; +import java.util.function.Consumer; import java.util.stream.IntStream; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.async.AsyncRequestBody; -import software.amazon.awssdk.core.async.SplitAsyncRequestBodyResponse; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.CompletedPart; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.s3.model.UploadPartRequest; -import software.amazon.awssdk.services.s3.model.UploadPartResponse; -import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.Logger; import software.amazon.awssdk.utils.Pair; @@ -42,7 +41,7 @@ * An internal helper class that automatically uses multipart upload based on the size of the object. */ @SdkInternalApi -public final class UploadWithKnownContentLengthHelper extends BaseMultipartUploadHelper { +public final class UploadWithKnownContentLengthHelper { private static final Logger log = Logger.loggerFor(UploadWithKnownContentLengthHelper.class); private final S3AsyncClient s3AsyncClient; @@ -51,12 +50,12 @@ public final class UploadWithKnownContentLengthHelper extends BaseMultipartUploa private final long maxMemoryUsageInBytes; private final long multipartUploadThresholdInBytes; + private final MultipartUploadHelper multipartUploadHelper; public UploadWithKnownContentLengthHelper(S3AsyncClient s3AsyncClient, long partSizeInBytes, long multipartUploadThresholdInBytes, long maxMemoryUsageInBytes) { - super(s3AsyncClient, partSizeInBytes, multipartUploadThresholdInBytes, maxMemoryUsageInBytes); this.s3AsyncClient = s3AsyncClient; this.partSizeInBytes = partSizeInBytes; this.genericMultipartHelper = new GenericMultipartHelper<>(s3AsyncClient, @@ -64,6 +63,8 @@ public UploadWithKnownContentLengthHelper(S3AsyncClient s3AsyncClient, SdkPojoConversionUtils::toPutObjectResponse); this.maxMemoryUsageInBytes = maxMemoryUsageInBytes; this.multipartUploadThresholdInBytes = multipartUploadThresholdInBytes; + this.multipartUploadHelper = new MultipartUploadHelper(s3AsyncClient, partSizeInBytes, multipartUploadThresholdInBytes, + maxMemoryUsageInBytes); } public CompletableFuture uploadObject(PutObjectRequest putObjectRequest, @@ -77,7 +78,7 @@ public CompletableFuture uploadObject(PutObjectRequest putObj uploadInParts(putObjectRequest, contentLength, asyncRequestBody, returnFuture); } else { log.debug(() -> "Starting the upload as a single upload part request"); - uploadInOneChunk(putObjectRequest, asyncRequestBody, returnFuture); + multipartUploadHelper.uploadInOneChunk(putObjectRequest, asyncRequestBody, returnFuture); } } catch (Throwable throwable) { @@ -90,8 +91,8 @@ public CompletableFuture uploadObject(PutObjectRequest putObj private void uploadInParts(PutObjectRequest putObjectRequest, long contentLength, AsyncRequestBody asyncRequestBody, CompletableFuture returnFuture) { - CompletableFuture createMultipartUploadFuture = createMultipartUpload(putObjectRequest, - returnFuture); + CompletableFuture createMultipartUploadFuture = + multipartUploadHelper.createMultipartUpload(putObjectRequest, returnFuture); createMultipartUploadFuture.whenComplete((createMultipartUploadResponse, throwable) -> { if (throwable != null) { @@ -115,129 +116,136 @@ private void doUploadInParts(Pair request, log.debug(() -> String.format("Starting multipart upload with partCount: %d, optimalPartSize: %d", partCount, optimalPartSize)); - // The list of completed parts must be sorted - AtomicReferenceArray completedParts = new AtomicReferenceArray<>(partCount); - - PutObjectRequest putObjectRequest = request.left(); - - Collection> futures = new ConcurrentLinkedQueue<>(); - MpuRequestContext mpuRequestContext = new MpuRequestContext(request, contentLength, optimalPartSize, uploadId); - CompletableFuture requestsFuture = sendUploadPartRequests(mpuRequestContext, - completedParts, - returnFuture, - futures); - requestsFuture.whenComplete((r, t) -> { - if (t != null) { - genericMultipartHelper.handleException(returnFuture, () -> "Failed to send multipart upload requests", t); - genericMultipartHelper.cleanUpParts(uploadId, toAbortMultipartUploadRequest(putObjectRequest)); - cancelingOtherOngoingRequests(futures, t); - return; - } - CompletedPart[] parts = - IntStream.range(0, completedParts.length()) - .mapToObj(completedParts::get) - .toArray(CompletedPart[]::new); - completeMultipartUpload(returnFuture, uploadId, parts, putObjectRequest, futures); - }); + request.right() + .split(mpuRequestContext.partSize, maxMemoryUsageInBytes) + .subscribe(new KnownContentLengthAsyncRequestBodySubscriber(mpuRequestContext, + returnFuture)); } - private CompletableFuture sendUploadPartRequests(MpuRequestContext mpuRequestContext, - AtomicReferenceArray completedParts, - CompletableFuture returnFuture, - Collection> futures) { - + private static final class MpuRequestContext { + private final Pair request; + private final long contentLength; + private final long partSize; - AsyncRequestBody asyncRequestBody = mpuRequestContext.request.right(); + private final String uploadId; - SplitAsyncRequestBodyResponse result = asyncRequestBody.split(mpuRequestContext.partSize, maxMemoryUsageInBytes); + private MpuRequestContext(Pair request, + long contentLength, + long partSize, + String uploadId) { + this.request = request; + this.contentLength = contentLength; + this.partSize = partSize; + this.uploadId = uploadId; + } + } - CompletableFuture splittingPublisherFuture = result.future(); + private class KnownContentLengthAsyncRequestBodySubscriber implements Subscriber { - result.asyncRequestBodyPublisher() - .map(new BodyToRequestConverter(mpuRequestContext.request.left(), - mpuRequestContext.uploadId)) - .subscribe(pair -> sendIndividualUploadPartRequest(mpuRequestContext.uploadId, - completedParts, - futures, - pair, - splittingPublisherFuture)) - .exceptionally(throwable -> { - returnFuture.completeExceptionally(throwable); - return null; - }); - return splittingPublisherFuture; - } + /** + * The number of AsyncRequestBody has been received but yet to be processed + */ + private final AtomicInteger asyncRequestBodyInFlight = new AtomicInteger(0); - private void sendIndividualUploadPartRequest(String uploadId, - AtomicReferenceArray completedParts, - Collection> futures, - Pair requestPair, - CompletableFuture sendUploadPartRequestsFuture) { - UploadPartRequest uploadPartRequest = requestPair.left(); - Integer partNumber = uploadPartRequest.partNumber(); - log.debug(() -> "Sending uploadPartRequest: " + uploadPartRequest.partNumber() + " uploadId: " + uploadId + " " - + "contentLength " + requestPair.right().contentLength()); - - CompletableFuture uploadPartFuture = s3AsyncClient.uploadPart(uploadPartRequest, requestPair.right()); - - CompletableFuture convertFuture = - uploadPartFuture.thenApply(uploadPartResponse -> convertUploadPartResponse(completedParts, partNumber, - uploadPartResponse)); - futures.add(convertFuture); - CompletableFutureUtils.forwardExceptionTo(convertFuture, uploadPartFuture); - CompletableFutureUtils.forwardExceptionTo(uploadPartFuture, sendUploadPartRequestsFuture); - } + /** + * Indicates whether CompleteMultipart has been initiated or not. + */ + private final AtomicBoolean completedMultipartInitiated = new AtomicBoolean(false); - private static CompletedPart convertUploadPartResponse(AtomicReferenceArray completedParts, - Integer partNumber, - UploadPartResponse uploadPartResponse) { - CompletedPart completedPart = SdkPojoConversionUtils.toCompletedPart(uploadPartResponse, partNumber); + private final AtomicBoolean failureActionInitiated = new AtomicBoolean(false); - completedParts.set(partNumber - 1, completedPart); - return completedPart; - } + private final AtomicInteger partNumber = new AtomicInteger(1); - private static final class BodyToRequestConverter implements Function> { + private final AtomicReferenceArray completedParts; + private final String uploadId; + private final Collection> futures = new ConcurrentLinkedQueue<>(); - private volatile int partNumber = 1; private final PutObjectRequest putObjectRequest; - private final String uploadId; + private final CompletableFuture returnFuture; + private Subscription subscription; + + private volatile boolean isDone; + + KnownContentLengthAsyncRequestBodySubscriber(MpuRequestContext mpuRequestContext, + CompletableFuture returnFuture) { + long optimalPartSize = genericMultipartHelper.calculateOptimalPartSizeFor(mpuRequestContext.contentLength, + partSizeInBytes); + int partCount = genericMultipartHelper.determinePartCount(mpuRequestContext.contentLength, optimalPartSize); + this.putObjectRequest = mpuRequestContext.request.left(); + this.returnFuture = returnFuture; + this.completedParts = new AtomicReferenceArray<>(partCount); + this.uploadId = mpuRequestContext.uploadId; + } - BodyToRequestConverter(PutObjectRequest putObjectRequest, String uploadId) { - this.putObjectRequest = putObjectRequest; - this.uploadId = uploadId; + @Override + public void onSubscribe(Subscription s) { + if (this.subscription != null) { + log.warn(() -> "The subscriber has already been subscribed. Cancelling the incoming subscription"); + subscription.cancel(); + return; + } + this.subscription = s; + s.request(1); + returnFuture.whenComplete((r, t) -> { + if (t != null) { + s.cancel(); + multipartUploadHelper.cancelingOtherOngoingRequests(futures, t); + } + }); } @Override - public Pair apply(AsyncRequestBody asyncRequestBody) { - log.trace(() -> "Generating uploadPartRequest for partNumber " + partNumber); + public void onNext(AsyncRequestBody asyncRequestBody) { + log.trace(() -> "Received asyncRequestBody " + asyncRequestBody.contentLength()); + asyncRequestBodyInFlight.incrementAndGet(); UploadPartRequest uploadRequest = SdkPojoConversionUtils.toUploadPartRequest(putObjectRequest, - partNumber, + partNumber.getAndIncrement(), uploadId); - ++partNumber; - return Pair.of(uploadRequest, asyncRequestBody); + + Consumer completedPartConsumer = completedPart -> completedParts.set(completedPart.partNumber() - 1, + completedPart); + multipartUploadHelper.sendIndividualUploadPartRequest(uploadId, completedPartConsumer, futures, + Pair.of(uploadRequest, asyncRequestBody)) + .whenComplete((r, t) -> { + if (t != null) { + if (failureActionInitiated.compareAndSet(false, true)) { + multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, + putObjectRequest); + } + } else { + completeMultipartUploadIfFinish(asyncRequestBodyInFlight.decrementAndGet()); + } + }); + subscription.request(1); } - } - private static final class MpuRequestContext { - private final Pair request; - private final long contentLength; - private final long partSize; + @Override + public void onError(Throwable t) { + log.debug(() -> "Received onError ", t); + if (failureActionInitiated.compareAndSet(false, true)) { + multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest); + } + } - private final String uploadId; + @Override + public void onComplete() { + log.debug(() -> "Received onComplete()"); + isDone = true; + completeMultipartUploadIfFinish(asyncRequestBodyInFlight.get()); + } - private MpuRequestContext(Pair request, - long contentLength, - long partSize, - String uploadId) { - this.request = request; - this.contentLength = contentLength; - this.partSize = partSize; - this.uploadId = uploadId; + private void completeMultipartUploadIfFinish(int requestsInFlight) { + if (isDone && requestsInFlight == 0 && completedMultipartInitiated.compareAndSet(false, true)) { + CompletedPart[] parts = + IntStream.range(0, completedParts.length()) + .mapToObj(completedParts::get) + .toArray(CompletedPart[]::new); + multipartUploadHelper.completeMultipartUpload(returnFuture, uploadId, parts, putObjectRequest); + } } + } } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java index 9d4ab5cf02ba..d2034b4b4e94 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java @@ -16,8 +16,6 @@ package software.amazon.awssdk.services.s3.internal.multipart; -import static software.amazon.awssdk.services.s3.internal.multipart.SdkPojoConversionUtils.toAbortMultipartUploadRequest; - import java.util.Collection; import java.util.Comparator; import java.util.Queue; @@ -29,14 +27,13 @@ import org.reactivestreams.Subscription; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.async.AsyncRequestBody; -import software.amazon.awssdk.core.async.SplitAsyncRequestBodyResponse; +import software.amazon.awssdk.core.async.SdkPublisher; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.CompletedPart; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.s3.model.UploadPartRequest; -import software.amazon.awssdk.services.s3.model.UploadPartResponse; import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.Logger; import software.amazon.awssdk.utils.Pair; @@ -45,7 +42,7 @@ * An internal helper class that uploads streams with unknown content length. */ @SdkInternalApi -public final class UploadWithUnknownContentLengthHelper extends BaseMultipartUploadHelper { +public final class UploadWithUnknownContentLengthHelper { private static final Logger log = Logger.loggerFor(UploadWithUnknownContentLengthHelper.class); private final S3AsyncClient s3AsyncClient; @@ -55,11 +52,12 @@ public final class UploadWithUnknownContentLengthHelper extends BaseMultipartUpl private final long maxMemoryUsageInBytes; private final long multipartUploadThresholdInBytes; + private final MultipartUploadHelper multipartUploadHelper; + public UploadWithUnknownContentLengthHelper(S3AsyncClient s3AsyncClient, long partSizeInBytes, long multipartUploadThresholdInBytes, long maxMemoryUsageInBytes) { - super(s3AsyncClient, partSizeInBytes, multipartUploadThresholdInBytes, maxMemoryUsageInBytes); this.s3AsyncClient = s3AsyncClient; this.partSizeInBytes = partSizeInBytes; this.genericMultipartHelper = new GenericMultipartHelper<>(s3AsyncClient, @@ -67,21 +65,21 @@ public UploadWithUnknownContentLengthHelper(S3AsyncClient s3AsyncClient, SdkPojoConversionUtils::toPutObjectResponse); this.maxMemoryUsageInBytes = maxMemoryUsageInBytes; this.multipartUploadThresholdInBytes = multipartUploadThresholdInBytes; + this.multipartUploadHelper = new MultipartUploadHelper(s3AsyncClient, partSizeInBytes, multipartUploadThresholdInBytes, + maxMemoryUsageInBytes); } public CompletableFuture uploadObject(PutObjectRequest putObjectRequest, AsyncRequestBody asyncRequestBody) { CompletableFuture returnFuture = new CompletableFuture<>(); - Long contentLength = asyncRequestBody.contentLength().orElseGet(putObjectRequest::contentLength); - - SplitAsyncRequestBodyResponse splitAsyncRequestBodyResponse = asyncRequestBody.split(partSizeInBytes, maxMemoryUsageInBytes); - splitAsyncRequestBodyResponse.asyncRequestBodyPublisher() - .subscribe( - new UnknownContentLengthAsyncRequestBodySubscriber(partSizeInBytes, - putObjectRequest, - returnFuture)); + SdkPublisher splitAsyncRequestBodyResponse = + asyncRequestBody.split(partSizeInBytes, + maxMemoryUsageInBytes); + splitAsyncRequestBodyResponse.subscribe(new UnknownContentLengthAsyncRequestBodySubscriber(partSizeInBytes, + putObjectRequest, + returnFuture)); return returnFuture; } @@ -106,6 +104,8 @@ private class UnknownContentLengthAsyncRequestBodySubscriber implements Subscrib */ private final AtomicInteger asyncRequestBodyInFlight = new AtomicInteger(0); + private final AtomicBoolean failureActionInitiated = new AtomicBoolean(false); + private AtomicInteger partNumber = new AtomicInteger(1); private final Queue completedParts = new ConcurrentLinkedQueue<>(); @@ -122,9 +122,9 @@ private class UnknownContentLengthAsyncRequestBodySubscriber implements Subscrib private String uploadId; private volatile boolean isDone; - public UnknownContentLengthAsyncRequestBodySubscriber(long maximumChunkSizeInByte, - PutObjectRequest putObjectRequest, - CompletableFuture returnFuture) { + UnknownContentLengthAsyncRequestBodySubscriber(long maximumChunkSizeInByte, + PutObjectRequest putObjectRequest, + CompletableFuture returnFuture) { this.maximumChunkSizeInByte = maximumChunkSizeInByte; this.putObjectRequest = putObjectRequest; this.returnFuture = returnFuture; @@ -132,21 +132,31 @@ public UnknownContentLengthAsyncRequestBodySubscriber(long maximumChunkSizeInByt @Override public void onSubscribe(Subscription s) { + if (this.subscription != null) { + log.warn(() -> "The subscriber has already been subscribed. Cancelling the incoming subscription"); + subscription.cancel(); + return; + } this.subscription = s; s.request(1); + returnFuture.whenComplete((r, t) -> { + if (t != null) { + s.cancel(); + multipartUploadHelper.cancelingOtherOngoingRequests(futures, t); + } + }); } @Override public void onNext(AsyncRequestBody asyncRequestBody) { - log.debug(() -> "Received asyncRequestBody " + asyncRequestBody.contentLength()); + log.trace(() -> "Received asyncRequestBody " + asyncRequestBody.contentLength()); asyncRequestBodyInFlight.incrementAndGet(); if (isFirstAsyncRequestBody.compareAndSet(true, false)) { - log.debug(() -> "Received first async request body"); + log.trace(() -> "Received first async request body"); // If this is the first AsyncRequestBody received, request another one because we don't know if there is more firstRequestBody = asyncRequestBody; subscription.request(1); - log.debug(() -> "requesting"); return; } @@ -155,7 +165,7 @@ public void onNext(AsyncRequestBody asyncRequestBody) { if (createMultipartUploadInitiated.compareAndSet(false, true)) { log.debug(() -> "Starting the upload as multipart upload request"); CompletableFuture createMultipartUploadFuture = - createMultipartUpload(putObjectRequest, returnFuture); + multipartUploadHelper.createMultipartUpload(putObjectRequest, returnFuture); createMultipartUploadFuture.whenComplete((createMultipartUploadResponse, throwable) -> { if (throwable != null) { @@ -165,18 +175,37 @@ public void onNext(AsyncRequestBody asyncRequestBody) { } else { uploadId = createMultipartUploadResponse.uploadId(); uploadIdFuture.complete(uploadId); - log.debug(() -> "Initiated a new multipart upload, uploadId: " + createMultipartUploadResponse.uploadId()); + log.debug(() -> "Initiated a new multipart upload, uploadId: " + uploadId); - sendIndividualUploadPartRequest(uploadId, completedParts, futures, uploadPart(firstRequestBody)); - sendIndividualUploadPartRequest(uploadId, completedParts, futures, uploadPart(asyncRequestBody)); + sendUploadPartRequest(uploadId, firstRequestBody); + sendUploadPartRequest(uploadId, asyncRequestBody); } }); + CompletableFutureUtils.forwardExceptionTo(returnFuture, createMultipartUploadFuture); } else { - uploadIdFuture.whenComplete((r, t) -> sendIndividualUploadPartRequest(uploadId, completedParts, futures, - uploadPart(asyncRequestBody))); + uploadIdFuture.whenComplete((r, t) -> { + sendUploadPartRequest(uploadId, asyncRequestBody); + }); } } + private void sendUploadPartRequest(String uploadId, AsyncRequestBody asyncRequestBody) { + multipartUploadHelper.sendIndividualUploadPartRequest(uploadId, completedParts::add, futures, + uploadPart(asyncRequestBody)) + .whenComplete((r, t) -> { + if (t != null) { + if (failureActionInitiated.compareAndSet(false, true)) { + multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest); + } + } else { + completeMultipartUploadIfFinish(asyncRequestBodyInFlight.decrementAndGet()); + } + }); + synchronized (this) { + subscription.request(1); + }; + } + private Pair uploadPart(AsyncRequestBody asyncRequestBody) { UploadPartRequest uploadRequest = SdkPojoConversionUtils.toUploadPartRequest(putObjectRequest, @@ -187,12 +216,10 @@ private Pair uploadPart(AsyncRequestBody as @Override public void onError(Throwable t) { - genericMultipartHelper.handleException(returnFuture, () -> "Failed to send multipart upload requests", t); - returnFuture.completeExceptionally(t); - if (uploadId != null) { - genericMultipartHelper.cleanUpParts(uploadId, toAbortMultipartUploadRequest(putObjectRequest)); + log.debug(() -> "Received onError() ", t); + if (failureActionInitiated.compareAndSet(false, true)) { + multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest); } - cancelingOtherOngoingRequests(futures, t); } @Override @@ -201,44 +228,19 @@ public void onComplete() { // If CreateMultipartUpload has not been initiated at this point, we know this is a single object upload if (createMultipartUploadInitiated.get() == false) { log.debug(() -> "Starting the upload as a single object upload request"); - uploadInOneChunk(putObjectRequest, firstRequestBody, returnFuture); + multipartUploadHelper.uploadInOneChunk(putObjectRequest, firstRequestBody, returnFuture); } else { isDone = true; completeMultipartUploadIfFinish(asyncRequestBodyInFlight.get()); } } - private void sendIndividualUploadPartRequest(String uploadId, - Queue completedParts, - Collection> futures, - Pair requestPair) { - UploadPartRequest uploadPartRequest = requestPair.left(); - Integer partNumber = uploadPartRequest.partNumber(); - log.debug(() -> "Sending uploadPartRequest: " + uploadPartRequest.partNumber() + " uploadId: " + uploadId + " " - + "contentLength " + requestPair.right().contentLength()); - - CompletableFuture uploadPartFuture = s3AsyncClient.uploadPart(uploadPartRequest, requestPair.right()); - - CompletableFuture convertFuture = - uploadPartFuture.thenApply(uploadPartResponse -> convertUploadPartResponse(completedParts, partNumber, - uploadPartResponse)) - .whenComplete((r, t) -> { - int numRequests = asyncRequestBodyInFlight.decrementAndGet(); - completeMultipartUploadIfFinish(numRequests); - }); - futures.add(convertFuture); - CompletableFutureUtils.forwardExceptionTo(convertFuture, uploadPartFuture); - synchronized (this) { - subscription.request(1); - } - } - private void completeMultipartUploadIfFinish(int requestsInFlight) { if (isDone && requestsInFlight == 0 && completedMultipartInitiated.compareAndSet(false, true)) { - CompletedPart[] parts = completedParts.stream() - .sorted(Comparator.comparingInt(CompletedPart::partNumber)) - .toArray(CompletedPart[]::new); - completeMultipartUpload(returnFuture, uploadId, parts, putObjectRequest, futures); + CompletedPart[] parts = completedParts.stream() + .sorted(Comparator.comparingInt(CompletedPart::partNumber)) + .toArray(CompletedPart[]::new); + multipartUploadHelper.completeMultipartUpload(returnFuture, uploadId, parts, putObjectRequest); } } } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelperTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelperTest.java similarity index 58% rename from services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelperTest.java rename to services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelperTest.java index 28b48c07c68b..11d54a73fb72 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelperTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelperTest.java @@ -26,20 +26,28 @@ import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.stubSuccessfulCompleteMultipartCall; import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.List; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; +import java.util.stream.Stream; +import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import org.mockito.stubbing.OngoingStubbing; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.s3.S3AsyncClient; @@ -56,7 +64,7 @@ import software.amazon.awssdk.testutils.RandomTempFile; import software.amazon.awssdk.utils.CompletableFutureUtils; -public class MultipartUploadHelperTest { +public class UploadObjectHelperTest { private static final String BUCKET = "bucket"; private static final String KEY = "key"; @@ -81,6 +89,11 @@ public static void afterAll() throws Exception { testFile.delete(); } + public static Stream asyncRequestBody() { + return Stream.of(new UnknownContentLengthAsyncRequestBody(AsyncRequestBody.fromFile(testFile)), + AsyncRequestBody.fromFile(testFile)); + } + @BeforeEach public void beforeEach() { s3AsyncClient = Mockito.mock(S3AsyncClient.class); @@ -89,7 +102,7 @@ public void beforeEach() { @ParameterizedTest @ValueSource(longs = {THRESHOLD, PART_SIZE, THRESHOLD - 1, PART_SIZE - 1}) - public void uploadObject_doesNotExceedThresholdAndPartSize_shouldUploadInOneChunk(long contentLength) { + void uploadObject_contentLengthDoesNotExceedThresholdAndPartSize_shouldUploadInOneChunk(long contentLength) { PutObjectRequest putObjectRequest = putObjectRequest(contentLength); AsyncRequestBody asyncRequestBody = Mockito.mock(AsyncRequestBody.class); @@ -100,15 +113,31 @@ public void uploadObject_doesNotExceedThresholdAndPartSize_shouldUploadInOneChun Mockito.verify(s3AsyncClient).putObject(putObjectRequest, asyncRequestBody); } - @Test - public void uploadObject_contentLengthExceedThresholdAndPartSize_shouldUseMPU() { + @ParameterizedTest + @ValueSource(longs = {PART_SIZE, PART_SIZE - 1}) + void uploadObject_unKnownContentLengthDoesNotExceedPartSize_shouldUploadInOneChunk(long contentLength) { + PutObjectRequest putObjectRequest = putObjectRequest(contentLength); + AsyncRequestBody asyncRequestBody = + new UnknownContentLengthAsyncRequestBody(AsyncRequestBody.fromBytes(RandomStringUtils.randomAscii(Math.toIntExact(contentLength)) + .getBytes(StandardCharsets.UTF_8))); + + CompletableFuture completedFuture = + CompletableFuture.completedFuture(PutObjectResponse.builder().build()); + when(s3AsyncClient.putObject(putObjectRequest, asyncRequestBody)).thenReturn(completedFuture); + uploadHelper.uploadObject(putObjectRequest, asyncRequestBody).join(); + Mockito.verify(s3AsyncClient).putObject(putObjectRequest, asyncRequestBody); + } + + @ParameterizedTest + @MethodSource("asyncRequestBody") + void uploadObject_contentLengthExceedThresholdAndPartSize_shouldUseMPU(AsyncRequestBody asyncRequestBody) { PutObjectRequest putObjectRequest = putObjectRequest(null); MpuTestUtils.stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); stubSuccessfulUploadPartCalls(); stubSuccessfulCompleteMultipartCall(BUCKET, KEY, s3AsyncClient); - uploadHelper.uploadObject(putObjectRequest, AsyncRequestBody.fromFile(testFile)).join(); + uploadHelper.uploadObject(putObjectRequest, asyncRequestBody).join(); ArgumentCaptor requestArgumentCaptor = ArgumentCaptor.forClass(UploadPartRequest.class); ArgumentCaptor requestBodyArgumentCaptor = ArgumentCaptor.forClass(AsyncRequestBody.class); verify(s3AsyncClient, times(4)).uploadPart(requestArgumentCaptor.capture(), @@ -137,8 +166,9 @@ public void uploadObject_contentLengthExceedThresholdAndPartSize_shouldUseMPU() /** * The second part failed, it should cancel ongoing part(first part). */ - @Test - void mpu_onePartFailed_shouldFailOtherPartsAndAbort() { + @ParameterizedTest + @MethodSource("asyncRequestBody") + void mpu_onePartFailed_shouldFailOtherPartsAndAbort(AsyncRequestBody asyncRequestBody) { PutObjectRequest putObjectRequest = putObjectRequest(MPU_CONTENT_SIZE); MpuTestUtils.stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); @@ -155,7 +185,7 @@ void mpu_onePartFailed_shouldFailOtherPartsAndAbort() { .thenReturn(CompletableFuture.completedFuture(AbortMultipartUploadResponse.builder().build())); CompletableFuture future = uploadHelper.uploadObject(putObjectRequest, - AsyncRequestBody.fromFile(testFile)); + asyncRequestBody); assertThatThrownBy(future::join).hasMessageContaining("Failed to send multipart upload requests").hasRootCause(exception); @@ -170,12 +200,16 @@ void mpu_onePartFailed_shouldFailOtherPartsAndAbort() { ongoingRequest.get(1, TimeUnit.MILLISECONDS); fail("no exception thrown"); } catch (Exception e) { - assertThat(e.getCause()).hasMessageContaining("request failed"); + assertThat(e.getCause()).hasMessageContaining("Failed to send multipart upload requests").hasRootCause(exception); } } + /** + * This test is not parameterized because for unknown content length, the progress is nondeterministic. For example, we + * don't know if it has created multipart upload when we cancel the future. + */ @Test - void upload_cancelResponseFuture_shouldPropagate() { + void upload_knownContentLengthCancelResponseFuture_shouldCancelCreateMultipart() { PutObjectRequest putObjectRequest = putObjectRequest(null); CompletableFuture createMultipartFuture = new CompletableFuture<>(); @@ -192,7 +226,48 @@ void upload_cancelResponseFuture_shouldPropagate() { } @Test - public void uploadObject_completeMultipartFailed_shouldFailAndAbort() { + void upload_knownContentLengthCancelResponseFuture_shouldCancelUploadPart() { + PutObjectRequest putObjectRequest = putObjectRequest(null); + + CompletableFuture createMultipartFuture = new CompletableFuture<>(); + + MpuTestUtils.stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); + + CompletableFuture ongoingRequest = new CompletableFuture<>(); + + when(s3AsyncClient.uploadPart(any(UploadPartRequest.class), + any(AsyncRequestBody.class))).thenReturn(ongoingRequest); + + CompletableFuture future = + uploadHelper.uploadObject(putObjectRequest, AsyncRequestBody.fromFile(testFile)); + + future.cancel(true); + + assertThat(ongoingRequest).isCancelled(); + } + + @ParameterizedTest + @MethodSource("asyncRequestBody") + void uploadObject_createMultipartUploadFailed_shouldFail(AsyncRequestBody asyncRequestBody) { + PutObjectRequest putObjectRequest = putObjectRequest(null); + + SdkClientException exception = SdkClientException.create("CompleteMultipartUpload failed"); + + CompletableFuture createMultipartUploadFuture = + CompletableFutureUtils.failedFuture(exception); + + when(s3AsyncClient.createMultipartUpload(any(CreateMultipartUploadRequest.class))) + .thenReturn(createMultipartUploadFuture); + + CompletableFuture future = uploadHelper.uploadObject(putObjectRequest, + asyncRequestBody); + assertThatThrownBy(future::join).hasMessageContaining("Failed to initiate multipart upload") + .hasRootCause(exception); + } + + @ParameterizedTest + @MethodSource("asyncRequestBody") + void uploadObject_completeMultipartFailed_shouldFailAndAbort(AsyncRequestBody asyncRequestBody) { PutObjectRequest putObjectRequest = putObjectRequest(null); MpuTestUtils.stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); @@ -209,8 +284,31 @@ public void uploadObject_completeMultipartFailed_shouldFailAndAbort() { when(s3AsyncClient.abortMultipartUpload(any(AbortMultipartUploadRequest.class))) .thenReturn(CompletableFuture.completedFuture(AbortMultipartUploadResponse.builder().build())); - CompletableFuture future = uploadHelper.uploadObject(putObjectRequest, AsyncRequestBody.fromFile(testFile)); - assertThatThrownBy(future::join).hasMessageContaining("Failed to send multipart requests").hasRootCause(exception); + CompletableFuture future = uploadHelper.uploadObject(putObjectRequest, + asyncRequestBody); + assertThatThrownBy(future::join).hasMessageContaining("Failed to send multipart requests") + .hasRootCause(exception); + } + + @ParameterizedTest() + @ValueSource(booleans = {false, true}) + void uploadObject_requestBodyOnError_shouldFailAndAbort(boolean contentLengthKnown) { + PutObjectRequest putObjectRequest = putObjectRequest(null); + Exception exception = new RuntimeException("error"); + + Long contentLength = contentLengthKnown ? MPU_CONTENT_SIZE : null; + ErroneousAsyncRequestBody erroneousAsyncRequestBody = + new ErroneousAsyncRequestBody(contentLength, exception); + MpuTestUtils.stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); + stubSuccessfulUploadPartCalls(); + + when(s3AsyncClient.abortMultipartUpload(any(AbortMultipartUploadRequest.class))) + .thenReturn(CompletableFuture.completedFuture(AbortMultipartUploadResponse.builder().build())); + + CompletableFuture future = uploadHelper.uploadObject(putObjectRequest, + erroneousAsyncRequestBody); + assertThatThrownBy(future::join).hasMessageContaining("Failed to send multipart upload requests") + .hasRootCause(exception); } private static PutObjectRequest putObjectRequest(Long contentLength) { @@ -254,4 +352,61 @@ public CompletableFuture answer(InvocationOnMock invocationO }); } + private static class UnknownContentLengthAsyncRequestBody implements AsyncRequestBody { + private final AsyncRequestBody delegate; + private volatile boolean cancelled; + + public UnknownContentLengthAsyncRequestBody(AsyncRequestBody asyncRequestBody) { + this.delegate = asyncRequestBody; + } + + @Override + public Optional contentLength() { + return Optional.empty(); + } + + @Override + public void subscribe(Subscriber s) { + delegate.subscribe(s); + } + } + + private static class ErroneousAsyncRequestBody implements AsyncRequestBody { + private volatile boolean isDone; + private final Long contentLength; + private final Exception exception; + + private ErroneousAsyncRequestBody(Long contentLength, Exception exception) { + this.contentLength = contentLength; + this.exception = exception; + } + + @Override + public Optional contentLength() { + return Optional.ofNullable(contentLength); + } + + + @Override + public void subscribe(Subscriber s) { + s.onSubscribe(new Subscription() { + @Override + public void request(long n) { + if (isDone) { + return; + } + isDone = true; + s.onNext(ByteBuffer.wrap(RandomStringUtils.randomAscii(Math.toIntExact(PART_SIZE)).getBytes(StandardCharsets.UTF_8))); + s.onNext(ByteBuffer.wrap(RandomStringUtils.randomAscii(Math.toIntExact(PART_SIZE)).getBytes(StandardCharsets.UTF_8))); + s.onError(exception); + + } + + @Override + public void cancel() { + } + }); + + } + } }