From bc80de0bf99b596797a907835511319cf0241eea Mon Sep 17 00:00:00 2001 From: Zoe Wang <33073555+zoewangg@users.noreply.github.com> Date: Fri, 16 Jun 2023 15:14:56 -0700 Subject: [PATCH 01/13] Implement multipart upload in Java-based S3 async client (#4052) * Implement multipart upload in Java-based S3 async client Co-authored-by: Matthew Miller --- .../internal/async/SplittingPublisher.java | 298 ++++++++++++++++++ .../async/SplittingPublisherTest.java | 215 +++++++++++++ ...ltipartClientPutObjectIntegrationTest.java | 78 +++++ .../s3/internal/crt/CopyObjectHelper.java | 101 ++---- ...Utils.java => RequestConversionUtils.java} | 108 ++++++- .../crt/UploadPartCopyRequestIterable.java | 8 +- .../multipart/GenericMultipartHelper.java | 134 ++++++++ .../multipart/MultipartS3AsyncClient.java | 47 +++ .../multipart/MultipartUploadHelper.java | 274 ++++++++++++++++ .../s3/internal/crt/CopyObjectHelperTest.java | 4 +- .../crt/CopyRequestConversionUtilsTest.java | 19 +- .../s3/internal/multipart/MpuTestUtils.java | 65 ++++ .../multipart/MultipartUploadHelperTest.java | 250 +++++++++++++++ .../awssdk/utils/async/SimplePublisher.java | 2 +- 14 files changed, 1496 insertions(+), 107 deletions(-) create mode 100644 core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java create mode 100644 core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java create mode 100644 services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java rename services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/{CopyRequestConversionUtils.java => RequestConversionUtils.java} (61%) create mode 100644 services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java create mode 100644 services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java create mode 100644 services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java create mode 100644 services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuTestUtils.java create mode 100644 services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelperTest.java diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java new file mode 100644 index 000000000000..095d69ac5e7d --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java @@ -0,0 +1,298 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + +import java.nio.ByteBuffer; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.SdkPublisher; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; +import software.amazon.awssdk.utils.async.SimplePublisher; + +/** + * Splits an {@link SdkPublisher} to multiple smaller {@link AsyncRequestBody}s, each of which publishes a specific portion of the + * original data. + * // TODO: create a default method in AsyncRequestBody for this + * // TODO: fix the case where content length is null + */ +@SdkInternalApi +public class SplittingPublisher implements SdkPublisher { + private static final Logger log = Logger.loggerFor(SplittingPublisher.class); + private final AsyncRequestBody upstreamPublisher; + private final SplittingSubscriber splittingSubscriber; + private final SimplePublisher downstreamPublisher = new SimplePublisher<>(); + private final long chunkSizeInBytes; + private final long maxMemoryUsageInBytes; + private final CompletableFuture future; + + private SplittingPublisher(Builder builder) { + this.upstreamPublisher = Validate.paramNotNull(builder.asyncRequestBody, "asyncRequestBody"); + this.chunkSizeInBytes = Validate.paramNotNull(builder.chunkSizeInBytes, "chunkSizeInBytes"); + this.splittingSubscriber = new SplittingSubscriber(upstreamPublisher.contentLength().orElse(null)); + this.maxMemoryUsageInBytes = builder.maxMemoryUsageInBytes == null ? Long.MAX_VALUE : builder.maxMemoryUsageInBytes; + this.future = builder.future; + + // We need to cancel upstream subscription if the future gets cancelled. + future.whenComplete((r, t) -> { + if (t != null) { + if (splittingSubscriber.upstreamSubscription != null) { + log.trace(() -> "Cancelling subscription because return future completed exceptionally ", t); + splittingSubscriber.upstreamSubscription.cancel(); + } + } + }); + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public void subscribe(Subscriber downstreamSubscriber) { + downstreamPublisher.subscribe(downstreamSubscriber); + upstreamPublisher.subscribe(splittingSubscriber); + } + + private class SplittingSubscriber implements Subscriber { + private Subscription upstreamSubscription; + private final Long upstreamSize; + private final AtomicInteger chunkNumber = new AtomicInteger(0); + private volatile DownstreamBody currentBody; + private final AtomicBoolean hasOpenUpstreamDemand = new AtomicBoolean(false); + private final AtomicLong dataBuffered = new AtomicLong(0); + + /** + * A hint to determine whether we will exceed maxMemoryUsage by the next OnNext call. + */ + private int byteBufferSizeHint; + + SplittingSubscriber(Long upstreamSize) { + this.upstreamSize = upstreamSize; + } + + @Override + public void onSubscribe(Subscription s) { + this.upstreamSubscription = s; + this.currentBody = new DownstreamBody(calculateChunkSize(), chunkNumber.get()); + sendCurrentBody(); + // We need to request subscription *after* we set currentBody because onNext could be invoked right away. + upstreamSubscription.request(1); + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + hasOpenUpstreamDemand.set(false); + byteBufferSizeHint = byteBuffer.remaining(); + + while (true) { + int amountRemainingInPart = amountRemainingInPart(); + int finalAmountRemainingInPart = amountRemainingInPart; + if (amountRemainingInPart == 0) { + currentBody.complete(); + int currentChunk = chunkNumber.incrementAndGet(); + Long partSize = calculateChunkSize(); + currentBody = new DownstreamBody(partSize, currentChunk); + sendCurrentBody(); + } + + amountRemainingInPart = amountRemainingInPart(); + if (amountRemainingInPart >= byteBuffer.remaining()) { + currentBody.send(byteBuffer.duplicate()); + break; + } + + ByteBuffer firstHalf = byteBuffer.duplicate(); + int newLimit = firstHalf.position() + amountRemainingInPart; + firstHalf.limit(newLimit); + byteBuffer.position(newLimit); + currentBody.send(firstHalf); + } + + maybeRequestMoreUpstreamData(); + } + + private int amountRemainingInPart() { + return Math.toIntExact(currentBody.totalLength - currentBody.transferredLength); + } + + @Override + public void onComplete() { + log.trace(() -> "Received onComplete()"); + downstreamPublisher.complete().thenRun(() -> future.complete(null)); + currentBody.complete(); + } + + @Override + public void onError(Throwable t) { + currentBody.error(t); + } + + private void sendCurrentBody() { + downstreamPublisher.send(currentBody).exceptionally(t -> { + downstreamPublisher.error(t); + return null; + }); + } + + private Long calculateChunkSize() { + Long dataRemaining = dataRemaining(); + if (dataRemaining == null) { + return null; + } + + return Math.min(chunkSizeInBytes, dataRemaining); + } + + private void maybeRequestMoreUpstreamData() { + long buffered = dataBuffered.get(); + if (shouldRequestMoreData(buffered) && + hasOpenUpstreamDemand.compareAndSet(false, true)) { + log.trace(() -> "Requesting more data, current data buffered: " + buffered); + upstreamSubscription.request(1); + } + } + + private boolean shouldRequestMoreData(long buffered) { + return buffered == 0 || buffered + byteBufferSizeHint < maxMemoryUsageInBytes; + } + + private Long dataRemaining() { + if (upstreamSize == null) { + return null; + } + return upstreamSize - (chunkNumber.get() * chunkSizeInBytes); + } + + private class DownstreamBody implements AsyncRequestBody { + private final Long totalLength; + private final SimplePublisher delegate = new SimplePublisher<>(); + private final int chunkNumber; + private volatile long transferredLength = 0; + + private DownstreamBody(Long totalLength, int chunkNumber) { + this.totalLength = totalLength; + this.chunkNumber = chunkNumber; + } + + @Override + public Optional contentLength() { + return Optional.ofNullable(totalLength); + } + + public void send(ByteBuffer data) { + log.trace(() -> "Sending bytebuffer " + data); + int length = data.remaining(); + transferredLength += length; + addDataBuffered(length); + delegate.send(data).whenComplete((r, t) -> { + addDataBuffered(-length); + if (t != null) { + error(t); + } + }); + } + + public void complete() { + log.debug(() -> "Received complete() for chunk number: " + chunkNumber); + delegate.complete(); + } + + public void error(Throwable error) { + delegate.error(error); + } + + @Override + public void subscribe(Subscriber s) { + delegate.subscribe(s); + } + + private void addDataBuffered(int length) { + dataBuffered.addAndGet(length); + if (length < 0) { + maybeRequestMoreUpstreamData(); + } + } + } + } + + public static final class Builder { + private AsyncRequestBody asyncRequestBody; + private Long chunkSizeInBytes; + private Long maxMemoryUsageInBytes; + private CompletableFuture future; + + /** + * Configures the asyncRequestBody to split + * + * @param asyncRequestBody The new asyncRequestBody value. + * @return This object for method chaining. + */ + public Builder asyncRequestBody(AsyncRequestBody asyncRequestBody) { + this.asyncRequestBody = asyncRequestBody; + return this; + } + + /** + * Configures the size of the chunk for each {@link AsyncRequestBody} to publish + * + * @param chunkSizeInBytes The new chunkSizeInBytes value. + * @return This object for method chaining. + */ + public Builder chunkSizeInBytes(Long chunkSizeInBytes) { + this.chunkSizeInBytes = chunkSizeInBytes; + return this; + } + + /** + * Sets the maximum memory usage in bytes. By default, it uses unlimited memory. + * + * @param maxMemoryUsageInBytes The new maxMemoryUsageInBytes value. + * @return This object for method chaining. + */ + // TODO: max memory usage might not be the best name, since we may technically go a little above this limit when we add + // on a new byte buffer. But we don't know for sure what the size of a buffer we request will be (we do use the size + // for the last byte buffer as a hint), so I don't think we can have a truly accurate max. Maybe we call it minimum + // buffer size instead? + public Builder maxMemoryUsageInBytes(Long maxMemoryUsageInBytes) { + this.maxMemoryUsageInBytes = maxMemoryUsageInBytes; + return this; + } + + /** + * Sets the result future. The future will be completed when all request bodies + * have been sent. + * + * @param future The new future value. + * @return This object for method chaining. + */ + public Builder resultFuture(CompletableFuture future) { + this.future = future; + return this; + } + + public SplittingPublisher build() { + return new SplittingPublisher(this); + } + } +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java new file mode 100644 index 000000000000..df318190b92d --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java @@ -0,0 +1,215 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static software.amazon.awssdk.utils.FunctionalUtils.invokeSafely; + +import java.io.ByteArrayOutputStream; +import java.io.FileInputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.testutils.RandomTempFile; +import software.amazon.awssdk.utils.BinaryUtils; + +public class SplittingPublisherTest { + private static final int CHUNK_SIZE = 5; + + private static final int CONTENT_SIZE = 101; + + private static final int NUM_OF_CHUNK = (int) Math.ceil(CONTENT_SIZE / (double) CHUNK_SIZE); + + private static RandomTempFile testFile; + + @BeforeAll + public static void beforeAll() throws IOException { + testFile = new RandomTempFile("testfile.dat", CONTENT_SIZE); + } + + @AfterAll + public static void afterAll() throws Exception { + testFile.delete(); + } + + @ParameterizedTest + @ValueSource(ints = {CHUNK_SIZE, CHUNK_SIZE * 2 - 1, CHUNK_SIZE * 2}) + void differentChunkSize_shouldSplitAsyncRequestBodyCorrectly(int upstreamByteBufferSize) throws Exception { + CompletableFuture future = new CompletableFuture<>(); + SplittingPublisher splittingPublisher = SplittingPublisher.builder() + .resultFuture(future) + .asyncRequestBody(FileAsyncRequestBody.builder() + .path(testFile.toPath()) + .chunkSizeInBytes(upstreamByteBufferSize) + .build()) + + .resultFuture(future) + .chunkSizeInBytes((long) CHUNK_SIZE) + .maxMemoryUsageInBytes((long) CHUNK_SIZE * 4) + .build(); + + List> futures = new ArrayList<>(); + + splittingPublisher.subscribe(requestBody -> { + CompletableFuture baosFuture = new CompletableFuture<>(); + BaosSubscriber subscriber = new BaosSubscriber(baosFuture); + futures.add(baosFuture); + requestBody.subscribe(subscriber); + }).get(5, TimeUnit.SECONDS); + + assertThat(futures.size()).isEqualTo(NUM_OF_CHUNK); + + for (int i = 0; i < futures.size(); i++) { + try (FileInputStream fileInputStream = new FileInputStream(testFile)) { + byte[] expected; + if (i == futures.size() - 1) { + expected = new byte[1]; + } else { + expected = new byte[5]; + } + fileInputStream.skip(i * 5); + fileInputStream.read(expected); + byte[] actualBytes = futures.get(i).join(); + assertThat(actualBytes).isEqualTo(expected); + }; + } + assertThat(future).isCompleted(); + } + + + @Test + void cancelFuture_shouldCancelUpstream() throws IOException { + CompletableFuture future = new CompletableFuture<>(); + TestAsyncRequestBody asyncRequestBody = new TestAsyncRequestBody(); + SplittingPublisher splittingPublisher = SplittingPublisher.builder() + .resultFuture(future) + .asyncRequestBody(asyncRequestBody) + .chunkSizeInBytes((long) CHUNK_SIZE) + .maxMemoryUsageInBytes(10L) + .build(); + + OnlyRequestOnceSubscriber downstreamSubscriber = new OnlyRequestOnceSubscriber(); + splittingPublisher.subscribe(downstreamSubscriber); + + future.completeExceptionally(new RuntimeException("test")); + assertThat(asyncRequestBody.cancelled).isTrue(); + assertThat(downstreamSubscriber.asyncRequestBodies.size()).isEqualTo(1); + } + + private static final class TestAsyncRequestBody implements AsyncRequestBody { + private static final byte[] CONTENT = RandomStringUtils.random(200).getBytes(Charset.defaultCharset()); + private boolean cancelled; + + @Override + public Optional contentLength() { + return Optional.of((long) CONTENT.length); + } + + @Override + public void subscribe(Subscriber s) { + s.onSubscribe(new Subscription() { + @Override + public void request(long n) { + s.onNext(ByteBuffer.wrap(CONTENT)); + s.onComplete(); + } + + @Override + public void cancel() { + cancelled = true; + } + }); + + } + } + + private static final class OnlyRequestOnceSubscriber implements Subscriber { + private List asyncRequestBodies = new ArrayList<>(); + + @Override + public void onSubscribe(Subscription s) { + s.request(1); + } + + @Override + public void onNext(AsyncRequestBody requestBody) { + asyncRequestBodies.add(requestBody); + } + + @Override + public void onError(Throwable t) { + + } + + @Override + public void onComplete() { + + } + } + + private static final class BaosSubscriber implements Subscriber { + private final CompletableFuture resultFuture; + + private ByteArrayOutputStream baos = new ByteArrayOutputStream(); + + private Subscription subscription; + + BaosSubscriber(CompletableFuture resultFuture) { + this.resultFuture = resultFuture; + } + + @Override + public void onSubscribe(Subscription s) { + if (this.subscription != null) { + s.cancel(); + return; + } + this.subscription = s; + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + invokeSafely(() -> baos.write(BinaryUtils.copyBytesFrom(byteBuffer))); + subscription.request(1); + } + + @Override + public void onError(Throwable throwable) { + baos = null; + resultFuture.completeExceptionally(throwable); + } + + @Override + public void onComplete() { + resultFuture.complete(baos.toByteArray()); + } + } +} diff --git a/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java new file mode 100644 index 000000000000..4174b87883dc --- /dev/null +++ b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java @@ -0,0 +1,78 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.multipart; + +import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.assertj.core.api.Assertions.assertThat; +import static software.amazon.awssdk.testutils.service.S3BucketUtils.temporaryBucketName; + +import java.nio.file.Files; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.sync.ResponseTransformer; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.S3IntegrationTestBase; +import software.amazon.awssdk.services.s3.internal.multipart.MultipartS3AsyncClient; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.utils.ChecksumUtils; +import software.amazon.awssdk.testutils.RandomTempFile; + +public class S3MultipartClientPutObjectIntegrationTest extends S3IntegrationTestBase { + + private static final String TEST_BUCKET = temporaryBucketName(S3MultipartClientPutObjectIntegrationTest.class); + private static final String TEST_KEY = "testfile.dat"; + private static final int OBJ_SIZE = 19 * 1024 * 1024; + + private static RandomTempFile testFile; + private static S3AsyncClient mpuS3Client; + + @BeforeAll + public static void setup() throws Exception { + S3IntegrationTestBase.setUp(); + S3IntegrationTestBase.createBucket(TEST_BUCKET); + + testFile = new RandomTempFile(TEST_KEY, OBJ_SIZE); + mpuS3Client = new MultipartS3AsyncClient(s3Async); + } + + @AfterAll + public static void teardown() throws Exception { + mpuS3Client.close(); + testFile.delete(); + deleteBucketAndAllContents(TEST_BUCKET); + } + + @Test + @Timeout(value = 20, unit = SECONDS) + void putObject_fileRequestBody_objectSentCorrectly() throws Exception { + AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath()); + mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join(); + + ResponseInputStream objContent = S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), + ResponseTransformer.toInputStream()); + + assertThat(objContent.response().contentLength()).isEqualTo(testFile.length()); + byte[] expectedSum = ChecksumUtils.computeCheckSum(Files.newInputStream(testFile.toPath())); + assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum); + } + +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelper.java index e3e125c9d084..414262b7bffa 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelper.java @@ -19,15 +19,11 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; import java.util.concurrent.atomic.AtomicReferenceArray; -import java.util.function.BiFunction; -import java.util.function.Supplier; import java.util.stream.IntStream; import software.amazon.awssdk.annotations.SdkInternalApi; -import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.s3.S3AsyncClient; -import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; +import software.amazon.awssdk.services.s3.internal.multipart.GenericMultipartHelper; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload; @@ -50,17 +46,16 @@ public final class CopyObjectHelper { private static final Logger log = Logger.loggerFor(S3AsyncClient.class); - /** - * The max number of parts on S3 side is 10,000 - */ - private static final long MAX_UPLOAD_PARTS = 10_000; - private final S3AsyncClient s3AsyncClient; private final long partSizeInBytes; + private final GenericMultipartHelper genericMultipartHelper; public CopyObjectHelper(S3AsyncClient s3AsyncClient, long partSizeInBytes) { this.s3AsyncClient = s3AsyncClient; this.partSizeInBytes = partSizeInBytes; + this.genericMultipartHelper = new GenericMultipartHelper<>(s3AsyncClient, + RequestConversionUtils::toAbortMultipartUploadRequest, + RequestConversionUtils::toCopyObjectResponse); } public CompletableFuture copyObject(CopyObjectRequest copyObjectRequest) { @@ -69,14 +64,15 @@ public CompletableFuture copyObject(CopyObjectRequest copyOb try { CompletableFuture headFuture = - s3AsyncClient.headObject(CopyRequestConversionUtils.toHeadObjectRequest(copyObjectRequest)); + s3AsyncClient.headObject(RequestConversionUtils.toHeadObjectRequest(copyObjectRequest)); // Ensure cancellations are forwarded to the head future CompletableFutureUtils.forwardExceptionTo(returnFuture, headFuture); headFuture.whenComplete((headObjectResponse, throwable) -> { if (throwable != null) { - handleException(returnFuture, () -> "Failed to retrieve metadata from the source object", throwable); + genericMultipartHelper.handleException(returnFuture, () -> "Failed to retrieve metadata from the source " + + "object", throwable); } else { doCopyObject(copyObjectRequest, returnFuture, headObjectResponse); } @@ -105,7 +101,7 @@ private void copyInParts(CopyObjectRequest copyObjectRequest, Long contentLength, CompletableFuture returnFuture) { - CreateMultipartUploadRequest request = CopyRequestConversionUtils.toCreateMultipartUploadRequest(copyObjectRequest); + CreateMultipartUploadRequest request = RequestConversionUtils.toCreateMultipartUploadRequest(copyObjectRequest); CompletableFuture createMultipartUploadFuture = s3AsyncClient.createMultipartUpload(request); @@ -114,7 +110,7 @@ private void copyInParts(CopyObjectRequest copyObjectRequest, createMultipartUploadFuture.whenComplete((createMultipartUploadResponse, throwable) -> { if (throwable != null) { - handleException(returnFuture, () -> "Failed to initiate multipart upload", throwable); + genericMultipartHelper.handleException(returnFuture, () -> "Failed to initiate multipart upload", throwable); } else { log.debug(() -> "Initiated new multipart upload, uploadId: " + createMultipartUploadResponse.uploadId()); doCopyInParts(copyObjectRequest, contentLength, returnFuture, createMultipartUploadResponse.uploadId()); @@ -122,17 +118,14 @@ private void copyInParts(CopyObjectRequest copyObjectRequest, }); } - private int determinePartCount(long contentLength, long partSize) { - return (int) Math.ceil(contentLength / (double) partSize); - } - private void doCopyInParts(CopyObjectRequest copyObjectRequest, Long contentLength, CompletableFuture returnFuture, String uploadId) { - long optimalPartSize = calculateOptimalPartSizeForCopy(contentLength); - int partCount = determinePartCount(contentLength, optimalPartSize); + long optimalPartSize = genericMultipartHelper.calculateOptimalPartSizeFor(contentLength, partSizeInBytes); + + int partCount = genericMultipartHelper.determinePartCount(contentLength, optimalPartSize); log.debug(() -> String.format("Starting multipart copy with partCount: %s, optimalPartSize: %s", partCount, optimalPartSize)); @@ -147,32 +140,15 @@ private void doCopyInParts(CopyObjectRequest copyObjectRequest, optimalPartSize); CompletableFutureUtils.allOfExceptionForwarded(futures.toArray(new CompletableFuture[0])) .thenCompose(ignore -> completeMultipartUpload(copyObjectRequest, uploadId, completedParts)) - .handle(handleExceptionOrResponse(copyObjectRequest, returnFuture, uploadId)) + .handle(genericMultipartHelper.handleExceptionOrResponse(copyObjectRequest, returnFuture, + uploadId)) .exceptionally(throwable -> { - handleException(returnFuture, () -> "Unexpected exception occurred", throwable); + genericMultipartHelper.handleException(returnFuture, () -> "Unexpected exception occurred", + throwable); return null; }); } - private BiFunction handleExceptionOrResponse( - CopyObjectRequest copyObjectRequest, - CompletableFuture returnFuture, - String uploadId) { - - return (completeMultipartUploadResponse, throwable) -> { - if (throwable != null) { - cleanUpParts(copyObjectRequest, uploadId); - handleException(returnFuture, () -> "Failed to send multipart copy requests.", - throwable); - } else { - returnFuture.complete(CopyRequestConversionUtils.toCopyObjectResponse( - completeMultipartUploadResponse)); - } - - return null; - }; - } - private CompletableFuture completeMultipartUpload( CopyObjectRequest copyObjectRequest, String uploadId, AtomicReferenceArray completedParts) { log.debug(() -> String.format("Sending completeMultipartUploadRequest, uploadId: %s", @@ -194,35 +170,6 @@ private CompletableFuture completeMultipartUplo return s3AsyncClient.completeMultipartUpload(completeMultipartUploadRequest); } - private void cleanUpParts(CopyObjectRequest copyObjectRequest, String uploadId) { - AbortMultipartUploadRequest abortMultipartUploadRequest = - CopyRequestConversionUtils.toAbortMultipartUploadRequest(copyObjectRequest, uploadId); - s3AsyncClient.abortMultipartUpload(abortMultipartUploadRequest) - .exceptionally(throwable -> { - log.warn(() -> String.format("Failed to abort previous multipart upload " - + "(id: %s)" - + ". You may need to call " - + "S3AsyncClient#abortMultiPartUpload to " - + "free all storage consumed by" - + " all parts. ", - uploadId), throwable); - return null; - }); - } - - private static void handleException(CompletableFuture returnFuture, - Supplier message, - Throwable throwable) { - Throwable cause = throwable instanceof CompletionException ? throwable.getCause() : throwable; - - if (cause instanceof Error) { - returnFuture.completeExceptionally(cause); - } else { - SdkClientException exception = SdkClientException.create(message.get(), cause); - returnFuture.completeExceptionally(exception); - } - } - private List> sendUploadPartCopyRequests(CopyObjectRequest copyObjectRequest, long contentLength, String uploadId, @@ -265,23 +212,13 @@ private static CompletedPart convertUploadPartCopyResponse(AtomicReferenceArray< UploadPartCopyResponse uploadPartCopyResponse) { CopyPartResult copyPartResult = uploadPartCopyResponse.copyPartResult(); CompletedPart completedPart = - CopyRequestConversionUtils.toCompletedPart(copyPartResult, - partNumber); + RequestConversionUtils.toCompletedPart(copyPartResult, + partNumber); completedParts.set(partNumber - 1, completedPart); return completedPart; } - /** - * Calculates the optimal part size of each part request if the copy operation is carried out as multipart copy. - */ - private long calculateOptimalPartSizeForCopy(long contentLengthOfSource) { - double optimalPartSize = contentLengthOfSource / (double) MAX_UPLOAD_PARTS; - - optimalPartSize = Math.ceil(optimalPartSize); - return (long) Math.max(optimalPartSize, partSizeInBytes); - } - private void copyInOneChunk(CopyObjectRequest copyObjectRequest, CompletableFuture returnFuture) { CompletableFuture copyObjectFuture = diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/CopyRequestConversionUtils.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/RequestConversionUtils.java similarity index 61% rename from services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/CopyRequestConversionUtils.java rename to services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/RequestConversionUtils.java index 2a464b10f499..f4a3aaf60d4a 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/CopyRequestConversionUtils.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/RequestConversionUtils.java @@ -24,15 +24,47 @@ import software.amazon.awssdk.services.s3.model.CopyPartResult; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.HeadObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.s3.model.UploadPartCopyRequest; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.services.s3.model.UploadPartResponse; /** - * Request conversion utility method for POJO classes associated with {@link S3CrtAsyncClient#copyObject(CopyObjectRequest)} + * Request conversion utility method for POJO classes associated with multipart feature. */ +//TODO: iterate over SDK fields to get the data @SdkInternalApi -public final class CopyRequestConversionUtils { +public final class RequestConversionUtils { - private CopyRequestConversionUtils() { + private RequestConversionUtils() { + } + + public static CreateMultipartUploadRequest toCreateMultipartUploadRequest(PutObjectRequest putObjectRequest) { + + return CreateMultipartUploadRequest.builder() + .bucket(putObjectRequest.bucket()) + .key(putObjectRequest.key()) + .sseCustomerAlgorithm(putObjectRequest.sseCustomerAlgorithm()) + .sseCustomerKey(putObjectRequest.sseCustomerKey()) + .sseCustomerKeyMD5(putObjectRequest.sseCustomerKeyMD5()) + .requestPayer(putObjectRequest.requestPayer()) + .acl(putObjectRequest.acl()) + .cacheControl(putObjectRequest.cacheControl()) + .metadata(putObjectRequest.metadata()) + .contentDisposition(putObjectRequest.contentDisposition()) + .contentEncoding(putObjectRequest.contentEncoding()) + .contentType(putObjectRequest.contentType()) + .contentLanguage(putObjectRequest.contentLanguage()) + .grantFullControl(putObjectRequest.grantFullControl()) + .expires(putObjectRequest.expires()) + .grantRead(putObjectRequest.grantRead()) + .grantFullControl(putObjectRequest.grantFullControl()) + .grantReadACP(putObjectRequest.grantReadACP()) + .grantWriteACP(putObjectRequest.grantWriteACP()) + //TODO filter out headers + //.overrideConfiguration(putObjectRequest.overrideConfiguration()) + .build(); } public static HeadObjectRequest toHeadObjectRequest(CopyObjectRequest copyObjectRequest) { @@ -63,6 +95,18 @@ public static CompletedPart toCompletedPart(CopyPartResult copyPartResult, int p .build(); } + public static CompletedPart toCompletedPart(UploadPartResponse partResponse, int partNumber) { + return CompletedPart.builder() + .partNumber(partNumber) + .eTag(partResponse.eTag()) + .checksumCRC32C(partResponse.checksumCRC32C()) + .checksumCRC32(partResponse.checksumCRC32()) + .checksumSHA1(partResponse.checksumSHA1()) + .checksumSHA256(partResponse.checksumSHA256()) + .eTag(partResponse.eTag()) + .build(); + } + public static CreateMultipartUploadRequest toCreateMultipartUploadRequest(CopyObjectRequest copyObjectRequest) { return CreateMultipartUploadRequest.builder() .bucket(copyObjectRequest.destinationBucket()) @@ -124,15 +168,20 @@ public static CopyObjectResponse toCopyObjectResponse(CompleteMultipartUploadRes return builder.build(); } - public static AbortMultipartUploadRequest toAbortMultipartUploadRequest(CopyObjectRequest copyObjectRequest, - String uploadId) { + public static AbortMultipartUploadRequest.Builder toAbortMultipartUploadRequest(CopyObjectRequest copyObjectRequest) { return AbortMultipartUploadRequest.builder() - .uploadId(uploadId) .bucket(copyObjectRequest.destinationBucket()) .key(copyObjectRequest.destinationKey()) .requestPayer(copyObjectRequest.requestPayerAsString()) - .expectedBucketOwner(copyObjectRequest.expectedBucketOwner()) - .build(); + .expectedBucketOwner(copyObjectRequest.expectedBucketOwner()); + } + + public static AbortMultipartUploadRequest.Builder toAbortMultipartUploadRequest(PutObjectRequest putObjectRequest) { + return AbortMultipartUploadRequest.builder() + .bucket(putObjectRequest.bucket()) + .key(putObjectRequest.key()) + .requestPayer(putObjectRequest.requestPayerAsString()) + .expectedBucketOwner(putObjectRequest.expectedBucketOwner()); } public static UploadPartCopyRequest toUploadPartCopyRequest(CopyObjectRequest copyObjectRequest, @@ -165,4 +214,47 @@ public static UploadPartCopyRequest toUploadPartCopyRequest(CopyObjectRequest co .build(); } + public static UploadPartRequest toUploadPartRequest(PutObjectRequest putObjectRequest, int partNumber, String uploadId) { + return UploadPartRequest.builder() + .bucket(putObjectRequest.bucket()) + .key(putObjectRequest.key()) + .uploadId(uploadId) + .partNumber(partNumber) + .sseCustomerAlgorithm(putObjectRequest.sseCustomerAlgorithm()) + .sseCustomerKeyMD5(putObjectRequest.sseCustomerKeyMD5()) + .sseCustomerKey(putObjectRequest.sseCustomerKey()) + .expectedBucketOwner(putObjectRequest.expectedBucketOwner()) + .requestPayer(putObjectRequest.requestPayerAsString()) + .sseCustomerKey(putObjectRequest.sseCustomerKey()) + .sseCustomerAlgorithm(putObjectRequest.sseCustomerAlgorithm()) + .sseCustomerKeyMD5(putObjectRequest.sseCustomerKeyMD5()) + .build(); + } + + public static PutObjectResponse toPutObjectResponse(CompleteMultipartUploadResponse response) { + PutObjectResponse.Builder builder = PutObjectResponse.builder() + .versionId(response.versionId()) + .checksumCRC32(response.checksumCRC32()) + .checksumSHA1(response.checksumSHA1()) + .checksumSHA256(response.checksumSHA256()) + .checksumCRC32C(response.checksumCRC32C()) + .eTag(response.eTag()) + .expiration(response.expiration()) + .bucketKeyEnabled(response.bucketKeyEnabled()) + .serverSideEncryption(response.serverSideEncryption()) + .ssekmsKeyId(response.ssekmsKeyId()) + .serverSideEncryption(response.serverSideEncryptionAsString()) + .requestCharged(response.requestChargedAsString()); + + // TODO: check why we have to do null check + if (response.responseMetadata() != null) { + builder.responseMetadata(response.responseMetadata()); + } + + if (response.sdkHttpResponse() != null) { + builder.sdkHttpResponse(response.sdkHttpResponse()); + } + + return builder.build(); + } } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/UploadPartCopyRequestIterable.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/UploadPartCopyRequestIterable.java index 84d3c6ac5305..f929bc3fc8f4 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/UploadPartCopyRequestIterable.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/UploadPartCopyRequestIterable.java @@ -65,10 +65,10 @@ public UploadPartCopyRequest next() { long partSize = Math.min(optimalPartSize, remainingBytes); String range = range(partSize); UploadPartCopyRequest uploadPartCopyRequest = - CopyRequestConversionUtils.toUploadPartCopyRequest(copyObjectRequest, - partNumber, - uploadId, - range); + RequestConversionUtils.toUploadPartCopyRequest(copyObjectRequest, + partNumber, + uploadId, + range); partNumber++; offset += partSize; remainingBytes -= partSize; diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java new file mode 100644 index 000000000000..4ab4b22a0e79 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java @@ -0,0 +1,134 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.atomic.AtomicReferenceArray; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.IntStream; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.S3Request; +import software.amazon.awssdk.services.s3.model.S3Response; +import software.amazon.awssdk.utils.Logger; + +@SdkInternalApi +public final class GenericMultipartHelper { + private static final Logger log = Logger.loggerFor(GenericMultipartHelper.class); + /** + * The max number of parts on S3 side is 10,000 + */ + private static final long MAX_UPLOAD_PARTS = 10_000; + + private final S3AsyncClient s3AsyncClient; + private final Function abortMultipartUploadRequestConverter; + private final Function responseConverter; + + public GenericMultipartHelper(S3AsyncClient s3AsyncClient, + Function abortMultipartUploadRequestConverter, + Function responseConverter) { + this.s3AsyncClient = s3AsyncClient; + this.abortMultipartUploadRequestConverter = abortMultipartUploadRequestConverter; + this.responseConverter = responseConverter; + } + + public void handleException(CompletableFuture returnFuture, + Supplier message, + Throwable throwable) { + Throwable cause = throwable instanceof CompletionException ? throwable.getCause() : throwable; + + if (cause instanceof Error) { + returnFuture.completeExceptionally(cause); + } else { + SdkClientException exception = SdkClientException.create(message.get(), cause); + returnFuture.completeExceptionally(exception); + } + } + + public long calculateOptimalPartSizeFor(long contentLengthOfSource, long partSizeInBytes) { + double optimalPartSize = contentLengthOfSource / (double) MAX_UPLOAD_PARTS; + + optimalPartSize = Math.ceil(optimalPartSize); + return (long) Math.max(optimalPartSize, partSizeInBytes); + } + + public int determinePartCount(long contentLength, long partSize) { + return (int) Math.ceil(contentLength / (double) partSize); + } + + public CompletableFuture completeMultipartUpload( + RequestT request, String uploadId, AtomicReferenceArray completedParts) { + log.debug(() -> String.format("Sending completeMultipartUploadRequest, uploadId: %s", + uploadId)); + CompletedPart[] parts = + IntStream.range(0, completedParts.length()) + .mapToObj(completedParts::get) + .toArray(CompletedPart[]::new); + CompleteMultipartUploadRequest completeMultipartUploadRequest = + CompleteMultipartUploadRequest.builder() + .bucket(request.getValueForField("Bucket", String.class).get()) + .key(request.getValueForField("Key", String.class).get()) + .uploadId(uploadId) + .multipartUpload(CompletedMultipartUpload.builder() + .parts(parts) + .build()) + .build(); + + return s3AsyncClient.completeMultipartUpload(completeMultipartUploadRequest); + } + + public BiFunction handleExceptionOrResponse( + RequestT request, + CompletableFuture returnFuture, + String uploadId) { + + return (completeMultipartUploadResponse, throwable) -> { + if (throwable != null) { + cleanUpParts(uploadId, abortMultipartUploadRequestConverter.apply(request)); + handleException(returnFuture, () -> "Failed to send multipart requests", + throwable); + } else { + returnFuture.complete(responseConverter.apply( + completeMultipartUploadResponse)); + } + + return null; + }; + } + + public void cleanUpParts(String uploadId, AbortMultipartUploadRequest.Builder abortMultipartUploadRequest) { + s3AsyncClient.abortMultipartUpload(abortMultipartUploadRequest.uploadId(uploadId).build()) + .exceptionally(throwable -> { + log.warn(() -> String.format("Failed to abort previous multipart upload " + + "(id: %s)" + + ". You may need to call " + + "S3AsyncClient#abortMultiPartUpload to " + + "free all storage consumed by" + + " all parts. ", + uploadId), throwable); + return null; + }); + } +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java new file mode 100644 index 000000000000..f2895d65fcd2 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java @@ -0,0 +1,47 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + + +import java.util.concurrent.CompletableFuture; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.services.s3.DelegatingS3AsyncClient; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; + +// This is just a temporary class for testing +//TODO: change this +@SdkInternalApi +public class MultipartS3AsyncClient extends DelegatingS3AsyncClient { + private static final long DEFAULT_PART_SIZE_IN_BYTES = 8L * 1024 * 1024; + private static final long DEFAULT_THRESHOLD = 8L * 1024 * 1024; + + private static final long DEFAULT_MAX_MEMORY = DEFAULT_PART_SIZE_IN_BYTES * 2; + private final MultipartUploadHelper mpuHelper; + + public MultipartS3AsyncClient(S3AsyncClient delegate) { + super(delegate); + // TODO: pass a config object to the upload helper instead + mpuHelper = new MultipartUploadHelper(delegate, DEFAULT_PART_SIZE_IN_BYTES, DEFAULT_THRESHOLD, DEFAULT_MAX_MEMORY); + } + + @Override + public CompletableFuture putObject(PutObjectRequest putObjectRequest, AsyncRequestBody requestBody) { + return mpuHelper.uploadObject(putObjectRequest, requestBody); + } +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java new file mode 100644 index 000000000000..d043d88936c6 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java @@ -0,0 +1,274 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + + +import static software.amazon.awssdk.services.s3.internal.crt.RequestConversionUtils.toAbortMultipartUploadRequest; + +import java.util.Collection; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicReferenceArray; +import java.util.function.Function; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.internal.async.SplittingPublisher; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.internal.crt.RequestConversionUtils; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.services.s3.model.UploadPartResponse; +import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Pair; + +/** + * An internal helper class that automatically uses multipart upload based on the size of the object. + */ +@SdkInternalApi +public final class MultipartUploadHelper { + private static final Logger log = Logger.loggerFor(MultipartUploadHelper.class); + + private final S3AsyncClient s3AsyncClient; + private final long partSizeInBytes; + private final GenericMultipartHelper genericMultipartHelper; + + private final long maxMemoryUsageInBytes; + private final long multipartUploadThresholdInBytes; + + public MultipartUploadHelper(S3AsyncClient s3AsyncClient, + long partSizeInBytes, + long multipartUploadThresholdInBytes, + long maxMemoryUsageInBytes) { + this.s3AsyncClient = s3AsyncClient; + this.partSizeInBytes = partSizeInBytes; + this.genericMultipartHelper = new GenericMultipartHelper<>(s3AsyncClient, + RequestConversionUtils::toAbortMultipartUploadRequest, + RequestConversionUtils::toPutObjectResponse); + this.maxMemoryUsageInBytes = maxMemoryUsageInBytes; + this.multipartUploadThresholdInBytes = multipartUploadThresholdInBytes; + } + + public CompletableFuture uploadObject(PutObjectRequest putObjectRequest, + AsyncRequestBody asyncRequestBody) { + Long contentLength = asyncRequestBody.contentLength().orElseGet(putObjectRequest::contentLength); + + // TODO: support null content length. Should be trivial to support it now + if (contentLength == null) { + throw new IllegalArgumentException("Content-length is required"); + } + + CompletableFuture returnFuture = new CompletableFuture<>(); + + try { + if (contentLength > multipartUploadThresholdInBytes && contentLength > partSizeInBytes) { + log.debug(() -> "Starting the upload as multipart upload request"); + uploadInParts(putObjectRequest, contentLength, asyncRequestBody, returnFuture); + } else { + log.debug(() -> "Starting the upload as a single upload part request"); + uploadInOneChunk(putObjectRequest, asyncRequestBody, returnFuture); + } + + } catch (Throwable throwable) { + returnFuture.completeExceptionally(throwable); + } + + return returnFuture; + } + + private void uploadInParts(PutObjectRequest putObjectRequest, long contentLength, AsyncRequestBody asyncRequestBody, + CompletableFuture returnFuture) { + + CreateMultipartUploadRequest request = RequestConversionUtils.toCreateMultipartUploadRequest(putObjectRequest); + CompletableFuture createMultipartUploadFuture = + s3AsyncClient.createMultipartUpload(request); + + // Ensure cancellations are forwarded to the createMultipartUploadFuture future + CompletableFutureUtils.forwardExceptionTo(returnFuture, createMultipartUploadFuture); + + createMultipartUploadFuture.whenComplete((createMultipartUploadResponse, throwable) -> { + if (throwable != null) { + genericMultipartHelper.handleException(returnFuture, () -> "Failed to initiate multipart upload", throwable); + } else { + log.debug(() -> "Initiated a new multipart upload, uploadId: " + createMultipartUploadResponse.uploadId()); + doUploadInParts(Pair.of(putObjectRequest, asyncRequestBody), contentLength, returnFuture, + createMultipartUploadResponse.uploadId()); + } + }); + } + + private void doUploadInParts(Pair request, + long contentLength, + CompletableFuture returnFuture, + String uploadId) { + + long optimalPartSize = genericMultipartHelper.calculateOptimalPartSizeFor(contentLength, partSizeInBytes); + int partCount = genericMultipartHelper.determinePartCount(contentLength, optimalPartSize); + + log.debug(() -> String.format("Starting multipart upload with partCount: %d, optimalPartSize: %d", partCount, + optimalPartSize)); + + // The list of completed parts must be sorted + AtomicReferenceArray completedParts = new AtomicReferenceArray<>(partCount); + + PutObjectRequest putObjectRequest = request.left(); + + Collection> futures = new ConcurrentLinkedQueue<>(); + + MpuRequestContext mpuRequestContext = new MpuRequestContext(request, contentLength, optimalPartSize, uploadId); + + CompletableFuture requestsFuture = sendUploadPartRequests(mpuRequestContext, + completedParts, + returnFuture, + futures); + requestsFuture.whenComplete((r, t) -> { + if (t != null) { + genericMultipartHelper.handleException(returnFuture, () -> "Failed to send multipart upload requests", t); + genericMultipartHelper.cleanUpParts(uploadId, toAbortMultipartUploadRequest(putObjectRequest)); + cancelingOtherOngoingRequests(futures, t); + return; + } + CompletableFutureUtils.allOfExceptionForwarded(futures.toArray(new CompletableFuture[0])) + .thenCompose(ignore -> genericMultipartHelper.completeMultipartUpload(putObjectRequest, + uploadId, + completedParts)) + .handle(genericMultipartHelper.handleExceptionOrResponse(putObjectRequest, returnFuture, + uploadId)) + .exceptionally(throwable -> { + genericMultipartHelper.handleException(returnFuture, () -> "Unexpected exception occurred", + throwable); + return null; + }); + }); + } + + private static void cancelingOtherOngoingRequests(Collection> futures, Throwable t) { + log.trace(() -> "cancelling other ongoing requests " + futures.size()); + futures.forEach(f -> f.completeExceptionally(t)); + } + + private CompletableFuture sendUploadPartRequests(MpuRequestContext mpuRequestContext, + AtomicReferenceArray completedParts, + CompletableFuture returnFuture, + Collection> futures) { + + CompletableFuture splittingPublisherFuture = new CompletableFuture<>(); + + AsyncRequestBody asyncRequestBody = mpuRequestContext.request.right(); + SplittingPublisher splittingPublisher = SplittingPublisher.builder() + .asyncRequestBody(asyncRequestBody) + .chunkSizeInBytes(mpuRequestContext.partSize) + .maxMemoryUsageInBytes(maxMemoryUsageInBytes) + .resultFuture(splittingPublisherFuture) + .build(); + + splittingPublisher.map(new BodyToRequestConverter(mpuRequestContext.request.left(), mpuRequestContext.uploadId)) + .subscribe(pair -> sendIndividualUploadPartRequest(mpuRequestContext.uploadId, + completedParts, + futures, + pair, + splittingPublisherFuture)) + .exceptionally(throwable -> { + returnFuture.completeExceptionally(throwable); + return null; + }); + return splittingPublisherFuture; + } + + private void sendIndividualUploadPartRequest(String uploadId, + AtomicReferenceArray completedParts, + Collection> futures, + Pair requestPair, + CompletableFuture sendUploadPartRequestsFuture) { + UploadPartRequest uploadPartRequest = requestPair.left(); + Integer partNumber = uploadPartRequest.partNumber(); + log.debug(() -> "Sending uploadPartRequest: " + uploadPartRequest.partNumber() + " uploadId: " + uploadId + " " + + "contentLength " + requestPair.right().contentLength()); + + CompletableFuture uploadPartFuture = s3AsyncClient.uploadPart(uploadPartRequest, requestPair.right()); + + CompletableFuture convertFuture = + uploadPartFuture.thenApply(uploadPartResponse -> convertUploadPartResponse(completedParts, partNumber, + uploadPartResponse)); + futures.add(convertFuture); + CompletableFutureUtils.forwardExceptionTo(convertFuture, uploadPartFuture); + CompletableFutureUtils.forwardExceptionTo(uploadPartFuture, sendUploadPartRequestsFuture); + } + + private static CompletedPart convertUploadPartResponse(AtomicReferenceArray completedParts, + Integer partNumber, + UploadPartResponse uploadPartResponse) { + CompletedPart completedPart = RequestConversionUtils.toCompletedPart(uploadPartResponse, partNumber); + + completedParts.set(partNumber - 1, completedPart); + return completedPart; + } + + private void uploadInOneChunk(PutObjectRequest putObjectRequest, + AsyncRequestBody asyncRequestBody, + CompletableFuture returnFuture) { + CompletableFuture putObjectResponseCompletableFuture = s3AsyncClient.putObject(putObjectRequest, + asyncRequestBody); + CompletableFutureUtils.forwardExceptionTo(returnFuture, putObjectResponseCompletableFuture); + CompletableFutureUtils.forwardResultTo(putObjectResponseCompletableFuture, returnFuture); + } + + private static final class BodyToRequestConverter implements Function> { + private int partNumber = 1; + private final PutObjectRequest putObjectRequest; + private final String uploadId; + + BodyToRequestConverter(PutObjectRequest putObjectRequest, String uploadId) { + this.putObjectRequest = putObjectRequest; + this.uploadId = uploadId; + } + + @Override + public Pair apply(AsyncRequestBody asyncRequestBody) { + log.trace(() -> "Generating uploadPartRequest for partNumber " + partNumber); + UploadPartRequest uploadRequest = + RequestConversionUtils.toUploadPartRequest(putObjectRequest, + partNumber, + uploadId); + ++partNumber; + return Pair.of(uploadRequest, asyncRequestBody); + } + } + + private static final class MpuRequestContext { + private final Pair request; + private final long contentLength; + private final long partSize; + + private final String uploadId; + + private MpuRequestContext(Pair request, + long contentLength, + long partSize, + String uploadId) { + this.request = request; + this.contentLength = contentLength; + this.partSize = partSize; + this.uploadId = uploadId; + } + } + +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelperTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelperTest.java index d3593570a6e6..ec78d7b15eb6 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelperTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelperTest.java @@ -175,7 +175,7 @@ void multiPartCopy_onePartFailed_shouldFailOtherPartsAndAbort() { CompletableFuture future = copyHelper.copyObject(copyObjectRequest); - assertThatThrownBy(future::join).hasMessageContaining("Failed to send multipart copy requests").hasRootCause(exception); + assertThatThrownBy(future::join).hasMessageContaining("Failed to send multipart requests").hasRootCause(exception); verify(s3AsyncClient, never()).completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); @@ -213,7 +213,7 @@ void multiPartCopy_completeMultipartFailed_shouldFailAndAbort() { CompletableFuture future = copyHelper.copyObject(copyObjectRequest); - assertThatThrownBy(future::join).hasMessageContaining("Failed to send multipart copy requests").hasRootCause(exception); + assertThatThrownBy(future::join).hasMessageContaining("Failed to send multipart requests").hasRootCause(exception); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(AbortMultipartUploadRequest.class); verify(s3AsyncClient).abortMultipartUpload(argumentCaptor.capture()); diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyRequestConversionUtilsTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyRequestConversionUtilsTest.java index 94071ad115fd..104d5f6e045f 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyRequestConversionUtilsTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyRequestConversionUtilsTest.java @@ -48,14 +48,14 @@ import software.amazon.awssdk.utils.Logger; class CopyRequestConversionUtilsTest { - private static final Logger log = Logger.loggerFor(CopyRequestConversionUtils.class); + private static final Logger log = Logger.loggerFor(RequestConversionUtils.class); private static final Random RNG = new Random(); @Test void toHeadObject_shouldCopyProperties() { CopyObjectRequest randomCopyObject = randomCopyObjectRequest(); - HeadObjectRequest convertedToHeadObject = CopyRequestConversionUtils.toHeadObjectRequest(randomCopyObject); + HeadObjectRequest convertedToHeadObject = RequestConversionUtils.toHeadObjectRequest(randomCopyObject); Set fieldsToIgnore = new HashSet<>(Arrays.asList("ExpectedBucketOwner", "RequestPayer", "Bucket", @@ -74,7 +74,7 @@ void toCompletedPart_shouldCopyProperties() { setFieldsToRandomValues(fromObject.sdkFields(), fromObject); CopyPartResult result = fromObject.build(); - CompletedPart convertedCompletedPart = CopyRequestConversionUtils.toCompletedPart(result, 1); + CompletedPart convertedCompletedPart = RequestConversionUtils.toCompletedPart(result, 1); verifyFieldsAreCopied(result, convertedCompletedPart, new HashSet<>(), CopyPartResult.builder().sdkFields(), CompletedPart.builder().sdkFields()); @@ -84,7 +84,7 @@ void toCompletedPart_shouldCopyProperties() { @Test void toCreateMultipartUploadRequest_shouldCopyProperties() { CopyObjectRequest randomCopyObject = randomCopyObjectRequest(); - CreateMultipartUploadRequest convertedRequest = CopyRequestConversionUtils.toCreateMultipartUploadRequest(randomCopyObject); + CreateMultipartUploadRequest convertedRequest = RequestConversionUtils.toCreateMultipartUploadRequest(randomCopyObject); Set fieldsToIgnore = new HashSet<>(); verifyFieldsAreCopied(randomCopyObject, convertedRequest, fieldsToIgnore, CopyObjectRequest.builder().sdkFields(), @@ -100,7 +100,7 @@ void toCopyObjectResponse_shouldCopyProperties() { responseBuilder.responseMetadata(s3ResponseMetadata).sdkHttpResponse(sdkHttpFullResponse); CompleteMultipartUploadResponse result = responseBuilder.build(); - CopyObjectResponse convertedRequest = CopyRequestConversionUtils.toCopyObjectResponse(result); + CopyObjectResponse convertedRequest = RequestConversionUtils.toCopyObjectResponse(result); Set fieldsToIgnore = new HashSet<>(); verifyFieldsAreCopied(result, convertedRequest, fieldsToIgnore, CompleteMultipartUploadResponse.builder().sdkFields(), @@ -113,21 +113,20 @@ void toCopyObjectResponse_shouldCopyProperties() { @Test void toAbortMultipartUploadRequest_shouldCopyProperties() { CopyObjectRequest randomCopyObject = randomCopyObjectRequest(); - AbortMultipartUploadRequest convertedRequest = CopyRequestConversionUtils.toAbortMultipartUploadRequest(randomCopyObject, - "id"); + AbortMultipartUploadRequest convertedRequest = RequestConversionUtils.toAbortMultipartUploadRequest(randomCopyObject).build(); Set fieldsToIgnore = new HashSet<>(); verifyFieldsAreCopied(randomCopyObject, convertedRequest, fieldsToIgnore, CopyObjectRequest.builder().sdkFields(), AbortMultipartUploadRequest.builder().sdkFields()); - assertThat(convertedRequest.uploadId()).isEqualTo("id"); + //assertThat(convertedRequest.uploadId()).isEqualTo("id"); } @Test void toUploadPartCopyRequest_shouldCopyProperties() { CopyObjectRequest randomCopyObject = randomCopyObjectRequest(); - UploadPartCopyRequest convertedObject = CopyRequestConversionUtils.toUploadPartCopyRequest(randomCopyObject, 1, "id", - "bytes=0-1024"); + UploadPartCopyRequest convertedObject = RequestConversionUtils.toUploadPartCopyRequest(randomCopyObject, 1, "id", + "bytes=0-1024"); Set fieldsToIgnore = new HashSet<>(Collections.singletonList("CopySource")); verifyFieldsAreCopied(randomCopyObject, convertedObject, fieldsToIgnore, CopyObjectRequest.builder().sdkFields(), diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuTestUtils.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuTestUtils.java new file mode 100644 index 000000000000..435d5b406189 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuTestUtils.java @@ -0,0 +1,65 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +import java.util.concurrent.CompletableFuture; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.HeadObjectRequest; +import software.amazon.awssdk.services.s3.model.HeadObjectResponse; + +public final class MpuTestUtils { + + private MpuTestUtils() { + } + + public static void stubSuccessfulHeadObjectCall(long contentLength, S3AsyncClient s3AsyncClient) { + CompletableFuture headFuture = + CompletableFuture.completedFuture(HeadObjectResponse.builder() + .contentLength(contentLength) + .build()); + + when(s3AsyncClient.headObject(any(HeadObjectRequest.class))) + .thenReturn(headFuture); + } + + public static void stubSuccessfulCreateMultipartCall(String mpuId, S3AsyncClient s3AsyncClient) { + CompletableFuture createMultipartUploadFuture = + CompletableFuture.completedFuture(CreateMultipartUploadResponse.builder() + .uploadId(mpuId) + .build()); + + when(s3AsyncClient.createMultipartUpload(any(CreateMultipartUploadRequest.class))) + .thenReturn(createMultipartUploadFuture); + } + + public static void stubSuccessfulCompleteMultipartCall(String bucket, String key, S3AsyncClient s3AsyncClient) { + CompletableFuture completeMultipartUploadFuture = + CompletableFuture.completedFuture(CompleteMultipartUploadResponse.builder() + .bucket(bucket) + .key(key) + .build()); + + when(s3AsyncClient.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))) + .thenReturn(completeMultipartUploadFuture); + } +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelperTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelperTest.java new file mode 100644 index 000000000000..0db53c246e03 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelperTest.java @@ -0,0 +1,250 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.stubSuccessfulCompleteMultipartCall; + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.mockito.stubbing.OngoingStubbing; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.AbortMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.services.s3.model.UploadPartResponse; +import software.amazon.awssdk.testutils.RandomTempFile; +import software.amazon.awssdk.utils.CompletableFutureUtils; + +public class MultipartUploadHelperTest { + + private static final String BUCKET = "bucket"; + private static final String KEY = "key"; + private static final long PART_SIZE = 8 * 1024; + + // Should contain four parts: [8KB, 8KB, 8KB, 1KB] + private static final long MPU_CONTENT_SIZE = 25 * 1024; + private static final long THRESHOLD = 10 * 1024; + private static final String UPLOAD_ID = "1234"; + + private static RandomTempFile testFile; + private MultipartUploadHelper uploadHelper; + private S3AsyncClient s3AsyncClient; + + @BeforeAll + public static void beforeAll() throws IOException { + testFile = new RandomTempFile("testfile.dat", MPU_CONTENT_SIZE); + } + + @AfterAll + public static void afterAll() throws Exception { + testFile.delete(); + } + + @BeforeEach + public void beforeEach() { + s3AsyncClient = Mockito.mock(S3AsyncClient.class); + uploadHelper = new MultipartUploadHelper(s3AsyncClient, PART_SIZE, THRESHOLD, PART_SIZE * 2); + } + + @ParameterizedTest + @ValueSource(longs = {THRESHOLD, PART_SIZE, THRESHOLD - 1, PART_SIZE - 1}) + public void uploadObject_doesNotExceedThresholdAndPartSize_shouldUploadInOneChunk(long contentLength) { + PutObjectRequest putObjectRequest = putObjectRequest(contentLength); + AsyncRequestBody asyncRequestBody = Mockito.mock(AsyncRequestBody.class); + + CompletableFuture completedFuture = + CompletableFuture.completedFuture(PutObjectResponse.builder().build()); + when(s3AsyncClient.putObject(putObjectRequest, asyncRequestBody)).thenReturn(completedFuture); + uploadHelper.uploadObject(putObjectRequest, asyncRequestBody).join(); + Mockito.verify(s3AsyncClient).putObject(putObjectRequest, asyncRequestBody); + } + + @Test + public void uploadObject_contentLengthExceedThresholdAndPartSize_shouldUseMPU() { + PutObjectRequest putObjectRequest = putObjectRequest(null); + + MpuTestUtils.stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); + stubSuccessfulUploadPartCalls(); + stubSuccessfulCompleteMultipartCall(BUCKET, KEY, s3AsyncClient); + + uploadHelper.uploadObject(putObjectRequest, AsyncRequestBody.fromFile(testFile)).join(); + ArgumentCaptor requestArgumentCaptor = ArgumentCaptor.forClass(UploadPartRequest.class); + ArgumentCaptor requestBodyArgumentCaptor = ArgumentCaptor.forClass(AsyncRequestBody.class); + verify(s3AsyncClient, times(4)).uploadPart(requestArgumentCaptor.capture(), + requestBodyArgumentCaptor.capture()); + + List actualRequests = requestArgumentCaptor.getAllValues(); + List actualRequestBodies = requestBodyArgumentCaptor.getAllValues(); + assertThat(actualRequestBodies).hasSize(4); + assertThat(actualRequests).hasSize(4); + + for (int i = 0; i < actualRequests.size(); i++) { + UploadPartRequest request = actualRequests.get(i); + AsyncRequestBody requestBody = actualRequestBodies.get(i); + assertThat(request.partNumber()).isEqualTo( i + 1); + assertThat(request.bucket()).isEqualTo(BUCKET); + assertThat(request.key()).isEqualTo(KEY); + + if (i == actualRequests.size() - 1) { + assertThat(requestBody.contentLength()).hasValue(1024L); + } else{ + assertThat(requestBody.contentLength()).hasValue(PART_SIZE); + } + } + } + + /** + * The second part failed, it should cancel ongoing part(first part). + */ + @Test + void mpu_onePartFailed_shouldFailOtherPartsAndAbort() { + PutObjectRequest putObjectRequest = putObjectRequest(MPU_CONTENT_SIZE); + + MpuTestUtils.stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); + CompletableFuture ongoingRequest = new CompletableFuture<>(); + + SdkClientException exception = SdkClientException.create("request failed"); + + OngoingStubbing> ongoingStubbing = + when(s3AsyncClient.uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class))).thenReturn(ongoingRequest); + + stubFailedUploadPartCalls(ongoingStubbing, exception); + + when(s3AsyncClient.abortMultipartUpload(any(AbortMultipartUploadRequest.class))) + .thenReturn(CompletableFuture.completedFuture(AbortMultipartUploadResponse.builder().build())); + + CompletableFuture future = uploadHelper.uploadObject(putObjectRequest, + AsyncRequestBody.fromFile(testFile)); + + assertThatThrownBy(future::join).hasMessageContaining("Failed to send multipart upload requests").hasRootCause(exception); + + verify(s3AsyncClient, never()).completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(AbortMultipartUploadRequest.class); + verify(s3AsyncClient).abortMultipartUpload(argumentCaptor.capture()); + AbortMultipartUploadRequest actualRequest = argumentCaptor.getValue(); + assertThat(actualRequest.uploadId()).isEqualTo(UPLOAD_ID); + + assertThat(ongoingRequest).isCompletedExceptionally(); + } + + @Test + void upload_cancelResponseFuture_shouldPropagate() { + PutObjectRequest putObjectRequest = putObjectRequest(null); + + CompletableFuture createMultipartFuture = new CompletableFuture<>(); + + when(s3AsyncClient.createMultipartUpload(any(CreateMultipartUploadRequest.class))) + .thenReturn(createMultipartFuture); + + CompletableFuture future = + uploadHelper.uploadObject(putObjectRequest, AsyncRequestBody.fromFile(testFile)); + + future.cancel(true); + + assertThat(createMultipartFuture).isCancelled(); + } + + @Test + public void uploadObject_completeMultipartFailed_shouldFailAndAbort() { + PutObjectRequest putObjectRequest = putObjectRequest(null); + + MpuTestUtils.stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); + stubSuccessfulUploadPartCalls(); + + SdkClientException exception = SdkClientException.create("CompleteMultipartUpload failed"); + + CompletableFuture completeMultipartUploadFuture = + CompletableFutureUtils.failedFuture(exception); + + when(s3AsyncClient.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))) + .thenReturn(completeMultipartUploadFuture); + + when(s3AsyncClient.abortMultipartUpload(any(AbortMultipartUploadRequest.class))) + .thenReturn(CompletableFuture.completedFuture(AbortMultipartUploadResponse.builder().build())); + + CompletableFuture future = uploadHelper.uploadObject(putObjectRequest, AsyncRequestBody.fromFile(testFile)); + assertThatThrownBy(future::join).hasMessageContaining("Failed to send multipart requests").hasRootCause(exception); + } + + private static PutObjectRequest putObjectRequest(Long contentLength) { + return PutObjectRequest.builder() + .bucket(BUCKET) + .key(KEY) + .contentLength(contentLength) + .build(); + } + + private void stubSuccessfulUploadPartCalls() { + when(s3AsyncClient.uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class))) + .thenAnswer(new Answer>() { + int numberOfCalls = 0; + + @Override + public CompletableFuture answer(InvocationOnMock invocationOnMock) { + AsyncRequestBody AsyncRequestBody = invocationOnMock.getArgument(1); + // Draining the request body + AsyncRequestBody.subscribe(b -> {}); + + numberOfCalls++; + return CompletableFuture.completedFuture(UploadPartResponse.builder() + .checksumCRC32("crc" + numberOfCalls) + .build()); + } + }); + } + + private OngoingStubbing> stubFailedUploadPartCalls(OngoingStubbing> stubbing, Exception exception) { + return stubbing.thenAnswer(new Answer>() { + + @Override + public CompletableFuture answer(InvocationOnMock invocationOnMock) { + AsyncRequestBody AsyncRequestBody = invocationOnMock.getArgument(1); + // Draining the request body + AsyncRequestBody.subscribe(b -> {}); + + return CompletableFutureUtils.failedFuture(exception); + } + }); + } + +} diff --git a/utils/src/main/java/software/amazon/awssdk/utils/async/SimplePublisher.java b/utils/src/main/java/software/amazon/awssdk/utils/async/SimplePublisher.java index 15bba8a0aaf1..11d029ee96c2 100644 --- a/utils/src/main/java/software/amazon/awssdk/utils/async/SimplePublisher.java +++ b/utils/src/main/java/software/amazon/awssdk/utils/async/SimplePublisher.java @@ -382,7 +382,7 @@ public void request(long n) { @Override public void cancel() { - log.trace(() -> "Received cancel()"); + log.trace(() -> "Received cancel() from " + subscriber); // Create exception here instead of in supplier to preserve a more-useful stack trace. highPriorityQueue.add(new CancelQueueEntry<>()); From 910b30f01a0845f3e17cc79f091020872f9456a4 Mon Sep 17 00:00:00 2001 From: Zoe Wang <33073555+zoewangg@users.noreply.github.com> Date: Wed, 12 Jul 2023 09:57:10 -0700 Subject: [PATCH 02/13] Iterate SdkFields to convert requests (#4177) * Iterate SdkFields to convert requests * Fix flaky test * Rename convertion utils class --- .../s3/internal/crt/CopyObjectHelper.java | 11 +- .../internal/crt/RequestConversionUtils.java | 260 ------------------ .../crt/UploadPartCopyRequestIterable.java | 3 +- .../multipart/MultipartUploadHelper.java | 13 +- .../multipart/SdkPojoConversionUtils.java | 185 +++++++++++++ .../multipart/MultipartUploadHelperTest.java | 11 +- .../SdkPojoConversionUtilsTest.java} | 116 ++++++-- 7 files changed, 309 insertions(+), 290 deletions(-) delete mode 100644 services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/RequestConversionUtils.java create mode 100644 services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtils.java rename services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/{crt/CopyRequestConversionUtilsTest.java => multipart/SdkPojoConversionUtilsTest.java} (63%) diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelper.java index 414262b7bffa..9070eb7192c5 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelper.java @@ -24,6 +24,7 @@ import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.internal.multipart.GenericMultipartHelper; +import software.amazon.awssdk.services.s3.internal.multipart.SdkPojoConversionUtils; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload; @@ -54,8 +55,8 @@ public CopyObjectHelper(S3AsyncClient s3AsyncClient, long partSizeInBytes) { this.s3AsyncClient = s3AsyncClient; this.partSizeInBytes = partSizeInBytes; this.genericMultipartHelper = new GenericMultipartHelper<>(s3AsyncClient, - RequestConversionUtils::toAbortMultipartUploadRequest, - RequestConversionUtils::toCopyObjectResponse); + SdkPojoConversionUtils::toAbortMultipartUploadRequest, + SdkPojoConversionUtils::toCopyObjectResponse); } public CompletableFuture copyObject(CopyObjectRequest copyObjectRequest) { @@ -64,7 +65,7 @@ public CompletableFuture copyObject(CopyObjectRequest copyOb try { CompletableFuture headFuture = - s3AsyncClient.headObject(RequestConversionUtils.toHeadObjectRequest(copyObjectRequest)); + s3AsyncClient.headObject(SdkPojoConversionUtils.toHeadObjectRequest(copyObjectRequest)); // Ensure cancellations are forwarded to the head future CompletableFutureUtils.forwardExceptionTo(returnFuture, headFuture); @@ -101,7 +102,7 @@ private void copyInParts(CopyObjectRequest copyObjectRequest, Long contentLength, CompletableFuture returnFuture) { - CreateMultipartUploadRequest request = RequestConversionUtils.toCreateMultipartUploadRequest(copyObjectRequest); + CreateMultipartUploadRequest request = SdkPojoConversionUtils.toCreateMultipartUploadRequest(copyObjectRequest); CompletableFuture createMultipartUploadFuture = s3AsyncClient.createMultipartUpload(request); @@ -212,7 +213,7 @@ private static CompletedPart convertUploadPartCopyResponse(AtomicReferenceArray< UploadPartCopyResponse uploadPartCopyResponse) { CopyPartResult copyPartResult = uploadPartCopyResponse.copyPartResult(); CompletedPart completedPart = - RequestConversionUtils.toCompletedPart(copyPartResult, + SdkPojoConversionUtils.toCompletedPart(copyPartResult, partNumber); completedParts.set(partNumber - 1, completedPart); diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/RequestConversionUtils.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/RequestConversionUtils.java deleted file mode 100644 index f4a3aaf60d4a..000000000000 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/RequestConversionUtils.java +++ /dev/null @@ -1,260 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.services.s3.internal.crt; - -import software.amazon.awssdk.annotations.SdkInternalApi; -import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; -import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; -import software.amazon.awssdk.services.s3.model.CompletedPart; -import software.amazon.awssdk.services.s3.model.CopyObjectRequest; -import software.amazon.awssdk.services.s3.model.CopyObjectResponse; -import software.amazon.awssdk.services.s3.model.CopyPartResult; -import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; -import software.amazon.awssdk.services.s3.model.HeadObjectRequest; -import software.amazon.awssdk.services.s3.model.PutObjectRequest; -import software.amazon.awssdk.services.s3.model.PutObjectResponse; -import software.amazon.awssdk.services.s3.model.UploadPartCopyRequest; -import software.amazon.awssdk.services.s3.model.UploadPartRequest; -import software.amazon.awssdk.services.s3.model.UploadPartResponse; - -/** - * Request conversion utility method for POJO classes associated with multipart feature. - */ -//TODO: iterate over SDK fields to get the data -@SdkInternalApi -public final class RequestConversionUtils { - - private RequestConversionUtils() { - } - - public static CreateMultipartUploadRequest toCreateMultipartUploadRequest(PutObjectRequest putObjectRequest) { - - return CreateMultipartUploadRequest.builder() - .bucket(putObjectRequest.bucket()) - .key(putObjectRequest.key()) - .sseCustomerAlgorithm(putObjectRequest.sseCustomerAlgorithm()) - .sseCustomerKey(putObjectRequest.sseCustomerKey()) - .sseCustomerKeyMD5(putObjectRequest.sseCustomerKeyMD5()) - .requestPayer(putObjectRequest.requestPayer()) - .acl(putObjectRequest.acl()) - .cacheControl(putObjectRequest.cacheControl()) - .metadata(putObjectRequest.metadata()) - .contentDisposition(putObjectRequest.contentDisposition()) - .contentEncoding(putObjectRequest.contentEncoding()) - .contentType(putObjectRequest.contentType()) - .contentLanguage(putObjectRequest.contentLanguage()) - .grantFullControl(putObjectRequest.grantFullControl()) - .expires(putObjectRequest.expires()) - .grantRead(putObjectRequest.grantRead()) - .grantFullControl(putObjectRequest.grantFullControl()) - .grantReadACP(putObjectRequest.grantReadACP()) - .grantWriteACP(putObjectRequest.grantWriteACP()) - //TODO filter out headers - //.overrideConfiguration(putObjectRequest.overrideConfiguration()) - .build(); - } - - public static HeadObjectRequest toHeadObjectRequest(CopyObjectRequest copyObjectRequest) { - return HeadObjectRequest.builder() - .bucket(copyObjectRequest.sourceBucket()) - .key(copyObjectRequest.sourceKey()) - .versionId(copyObjectRequest.sourceVersionId()) - .ifMatch(copyObjectRequest.copySourceIfMatch()) - .ifModifiedSince(copyObjectRequest.copySourceIfModifiedSince()) - .ifNoneMatch(copyObjectRequest.copySourceIfNoneMatch()) - .ifUnmodifiedSince(copyObjectRequest.copySourceIfUnmodifiedSince()) - .expectedBucketOwner(copyObjectRequest.expectedSourceBucketOwner()) - .sseCustomerAlgorithm(copyObjectRequest.copySourceSSECustomerAlgorithm()) - .sseCustomerKey(copyObjectRequest.copySourceSSECustomerKey()) - .sseCustomerKeyMD5(copyObjectRequest.copySourceSSECustomerKeyMD5()) - .build(); - } - - public static CompletedPart toCompletedPart(CopyPartResult copyPartResult, int partNumber) { - return CompletedPart.builder() - .partNumber(partNumber) - .eTag(copyPartResult.eTag()) - .checksumCRC32C(copyPartResult.checksumCRC32C()) - .checksumCRC32(copyPartResult.checksumCRC32()) - .checksumSHA1(copyPartResult.checksumSHA1()) - .checksumSHA256(copyPartResult.checksumSHA256()) - .eTag(copyPartResult.eTag()) - .build(); - } - - public static CompletedPart toCompletedPart(UploadPartResponse partResponse, int partNumber) { - return CompletedPart.builder() - .partNumber(partNumber) - .eTag(partResponse.eTag()) - .checksumCRC32C(partResponse.checksumCRC32C()) - .checksumCRC32(partResponse.checksumCRC32()) - .checksumSHA1(partResponse.checksumSHA1()) - .checksumSHA256(partResponse.checksumSHA256()) - .eTag(partResponse.eTag()) - .build(); - } - - public static CreateMultipartUploadRequest toCreateMultipartUploadRequest(CopyObjectRequest copyObjectRequest) { - return CreateMultipartUploadRequest.builder() - .bucket(copyObjectRequest.destinationBucket()) - .contentEncoding(copyObjectRequest.contentEncoding()) - .checksumAlgorithm(copyObjectRequest.checksumAlgorithmAsString()) - .tagging(copyObjectRequest.tagging()) - .contentType(copyObjectRequest.contentType()) - .contentLanguage(copyObjectRequest.contentLanguage()) - .contentDisposition(copyObjectRequest.contentDisposition()) - .cacheControl(copyObjectRequest.cacheControl()) - .expires(copyObjectRequest.expires()) - .key(copyObjectRequest.destinationKey()) - .websiteRedirectLocation(copyObjectRequest.websiteRedirectLocation()) - .expectedBucketOwner(copyObjectRequest.expectedBucketOwner()) - .requestPayer(copyObjectRequest.requestPayerAsString()) - .acl(copyObjectRequest.aclAsString()) - .grantRead(copyObjectRequest.grantRead()) - .grantReadACP(copyObjectRequest.grantReadACP()) - .grantWriteACP(copyObjectRequest.grantWriteACP()) - .grantFullControl(copyObjectRequest.grantFullControl()) - .storageClass(copyObjectRequest.storageClassAsString()) - .ssekmsKeyId(copyObjectRequest.ssekmsKeyId()) - .sseCustomerKey(copyObjectRequest.sseCustomerKey()) - .sseCustomerAlgorithm(copyObjectRequest.sseCustomerAlgorithm()) - .sseCustomerKeyMD5(copyObjectRequest.sseCustomerKeyMD5()) - .ssekmsEncryptionContext(copyObjectRequest.ssekmsEncryptionContext()) - .serverSideEncryption(copyObjectRequest.serverSideEncryptionAsString()) - .bucketKeyEnabled(copyObjectRequest.bucketKeyEnabled()) - .objectLockMode(copyObjectRequest.objectLockModeAsString()) - .objectLockLegalHoldStatus(copyObjectRequest.objectLockLegalHoldStatusAsString()) - .objectLockRetainUntilDate(copyObjectRequest.objectLockRetainUntilDate()) - .metadata(copyObjectRequest.metadata()) - .build(); - } - - public static CopyObjectResponse toCopyObjectResponse(CompleteMultipartUploadResponse response) { - CopyObjectResponse.Builder builder = CopyObjectResponse.builder() - .versionId(response.versionId()) - .copyObjectResult(b -> b.checksumCRC32(response.checksumCRC32()) - .checksumSHA1(response.checksumSHA1()) - .checksumSHA256(response.checksumSHA256()) - .checksumCRC32C(response.checksumCRC32C()) - .eTag(response.eTag()) - .build()) - .expiration(response.expiration()) - .bucketKeyEnabled(response.bucketKeyEnabled()) - .serverSideEncryption(response.serverSideEncryption()) - .ssekmsKeyId(response.ssekmsKeyId()) - .serverSideEncryption(response.serverSideEncryptionAsString()) - .requestCharged(response.requestChargedAsString()); - if (response.responseMetadata() != null) { - builder.responseMetadata(response.responseMetadata()); - } - - if (response.sdkHttpResponse() != null) { - builder.sdkHttpResponse(response.sdkHttpResponse()); - } - - return builder.build(); - } - - public static AbortMultipartUploadRequest.Builder toAbortMultipartUploadRequest(CopyObjectRequest copyObjectRequest) { - return AbortMultipartUploadRequest.builder() - .bucket(copyObjectRequest.destinationBucket()) - .key(copyObjectRequest.destinationKey()) - .requestPayer(copyObjectRequest.requestPayerAsString()) - .expectedBucketOwner(copyObjectRequest.expectedBucketOwner()); - } - - public static AbortMultipartUploadRequest.Builder toAbortMultipartUploadRequest(PutObjectRequest putObjectRequest) { - return AbortMultipartUploadRequest.builder() - .bucket(putObjectRequest.bucket()) - .key(putObjectRequest.key()) - .requestPayer(putObjectRequest.requestPayerAsString()) - .expectedBucketOwner(putObjectRequest.expectedBucketOwner()); - } - - public static UploadPartCopyRequest toUploadPartCopyRequest(CopyObjectRequest copyObjectRequest, - int partNumber, - String uploadId, - String range) { - - return UploadPartCopyRequest.builder() - .sourceBucket(copyObjectRequest.sourceBucket()) - .sourceKey(copyObjectRequest.sourceKey()) - .sourceVersionId(copyObjectRequest.sourceVersionId()) - .uploadId(uploadId) - .partNumber(partNumber) - .destinationBucket(copyObjectRequest.destinationBucket()) - .destinationKey(copyObjectRequest.destinationKey()) - .copySourceIfMatch(copyObjectRequest.copySourceIfMatch()) - .copySourceIfNoneMatch(copyObjectRequest.copySourceIfNoneMatch()) - .copySourceIfUnmodifiedSince(copyObjectRequest.copySourceIfUnmodifiedSince()) - .copySourceRange(range) - .copySourceSSECustomerAlgorithm(copyObjectRequest.copySourceSSECustomerAlgorithm()) - .copySourceSSECustomerKeyMD5(copyObjectRequest.copySourceSSECustomerKeyMD5()) - .copySourceSSECustomerKey(copyObjectRequest.copySourceSSECustomerKey()) - .copySourceIfModifiedSince(copyObjectRequest.copySourceIfModifiedSince()) - .expectedBucketOwner(copyObjectRequest.expectedBucketOwner()) - .expectedSourceBucketOwner(copyObjectRequest.expectedSourceBucketOwner()) - .requestPayer(copyObjectRequest.requestPayerAsString()) - .sseCustomerKey(copyObjectRequest.sseCustomerKey()) - .sseCustomerAlgorithm(copyObjectRequest.sseCustomerAlgorithm()) - .sseCustomerKeyMD5(copyObjectRequest.sseCustomerKeyMD5()) - .build(); - } - - public static UploadPartRequest toUploadPartRequest(PutObjectRequest putObjectRequest, int partNumber, String uploadId) { - return UploadPartRequest.builder() - .bucket(putObjectRequest.bucket()) - .key(putObjectRequest.key()) - .uploadId(uploadId) - .partNumber(partNumber) - .sseCustomerAlgorithm(putObjectRequest.sseCustomerAlgorithm()) - .sseCustomerKeyMD5(putObjectRequest.sseCustomerKeyMD5()) - .sseCustomerKey(putObjectRequest.sseCustomerKey()) - .expectedBucketOwner(putObjectRequest.expectedBucketOwner()) - .requestPayer(putObjectRequest.requestPayerAsString()) - .sseCustomerKey(putObjectRequest.sseCustomerKey()) - .sseCustomerAlgorithm(putObjectRequest.sseCustomerAlgorithm()) - .sseCustomerKeyMD5(putObjectRequest.sseCustomerKeyMD5()) - .build(); - } - - public static PutObjectResponse toPutObjectResponse(CompleteMultipartUploadResponse response) { - PutObjectResponse.Builder builder = PutObjectResponse.builder() - .versionId(response.versionId()) - .checksumCRC32(response.checksumCRC32()) - .checksumSHA1(response.checksumSHA1()) - .checksumSHA256(response.checksumSHA256()) - .checksumCRC32C(response.checksumCRC32C()) - .eTag(response.eTag()) - .expiration(response.expiration()) - .bucketKeyEnabled(response.bucketKeyEnabled()) - .serverSideEncryption(response.serverSideEncryption()) - .ssekmsKeyId(response.ssekmsKeyId()) - .serverSideEncryption(response.serverSideEncryptionAsString()) - .requestCharged(response.requestChargedAsString()); - - // TODO: check why we have to do null check - if (response.responseMetadata() != null) { - builder.responseMetadata(response.responseMetadata()); - } - - if (response.sdkHttpResponse() != null) { - builder.sdkHttpResponse(response.sdkHttpResponse()); - } - - return builder.build(); - } -} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/UploadPartCopyRequestIterable.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/UploadPartCopyRequestIterable.java index f929bc3fc8f4..da8eea8fc64a 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/UploadPartCopyRequestIterable.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/UploadPartCopyRequestIterable.java @@ -19,6 +19,7 @@ import java.util.NoSuchElementException; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.pagination.sync.SdkIterable; +import software.amazon.awssdk.services.s3.internal.multipart.SdkPojoConversionUtils; import software.amazon.awssdk.services.s3.model.CopyObjectRequest; import software.amazon.awssdk.services.s3.model.UploadPartCopyRequest; @@ -65,7 +66,7 @@ public UploadPartCopyRequest next() { long partSize = Math.min(optimalPartSize, remainingBytes); String range = range(partSize); UploadPartCopyRequest uploadPartCopyRequest = - RequestConversionUtils.toUploadPartCopyRequest(copyObjectRequest, + SdkPojoConversionUtils.toUploadPartCopyRequest(copyObjectRequest, partNumber, uploadId, range); diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java index d043d88936c6..0f9be5070861 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java @@ -16,7 +16,7 @@ package software.amazon.awssdk.services.s3.internal.multipart; -import static software.amazon.awssdk.services.s3.internal.crt.RequestConversionUtils.toAbortMultipartUploadRequest; +import static software.amazon.awssdk.services.s3.internal.multipart.SdkPojoConversionUtils.toAbortMultipartUploadRequest; import java.util.Collection; import java.util.concurrent.CompletableFuture; @@ -27,7 +27,6 @@ import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.internal.async.SplittingPublisher; import software.amazon.awssdk.services.s3.S3AsyncClient; -import software.amazon.awssdk.services.s3.internal.crt.RequestConversionUtils; import software.amazon.awssdk.services.s3.model.CompletedPart; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; @@ -60,8 +59,8 @@ public MultipartUploadHelper(S3AsyncClient s3AsyncClient, this.s3AsyncClient = s3AsyncClient; this.partSizeInBytes = partSizeInBytes; this.genericMultipartHelper = new GenericMultipartHelper<>(s3AsyncClient, - RequestConversionUtils::toAbortMultipartUploadRequest, - RequestConversionUtils::toPutObjectResponse); + SdkPojoConversionUtils::toAbortMultipartUploadRequest, + SdkPojoConversionUtils::toPutObjectResponse); this.maxMemoryUsageInBytes = maxMemoryUsageInBytes; this.multipartUploadThresholdInBytes = multipartUploadThresholdInBytes; } @@ -96,7 +95,7 @@ public CompletableFuture uploadObject(PutObjectRequest putObj private void uploadInParts(PutObjectRequest putObjectRequest, long contentLength, AsyncRequestBody asyncRequestBody, CompletableFuture returnFuture) { - CreateMultipartUploadRequest request = RequestConversionUtils.toCreateMultipartUploadRequest(putObjectRequest); + CreateMultipartUploadRequest request = SdkPojoConversionUtils.toCreateMultipartUploadRequest(putObjectRequest); CompletableFuture createMultipartUploadFuture = s3AsyncClient.createMultipartUpload(request); @@ -215,7 +214,7 @@ private void sendIndividualUploadPartRequest(String uploadId, private static CompletedPart convertUploadPartResponse(AtomicReferenceArray completedParts, Integer partNumber, UploadPartResponse uploadPartResponse) { - CompletedPart completedPart = RequestConversionUtils.toCompletedPart(uploadPartResponse, partNumber); + CompletedPart completedPart = SdkPojoConversionUtils.toCompletedPart(uploadPartResponse, partNumber); completedParts.set(partNumber - 1, completedPart); return completedPart; @@ -245,7 +244,7 @@ private static final class BodyToRequestConverter implements Function apply(AsyncRequestBody asyncRequestBody) { log.trace(() -> "Generating uploadPartRequest for partNumber " + partNumber); UploadPartRequest uploadRequest = - RequestConversionUtils.toUploadPartRequest(putObjectRequest, + SdkPojoConversionUtils.toUploadPartRequest(putObjectRequest, partNumber, uploadId); ++partNumber; diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtils.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtils.java new file mode 100644 index 000000000000..70512084150b --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtils.java @@ -0,0 +1,185 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.SdkField; +import software.amazon.awssdk.core.SdkPojo; +import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.CopyObjectRequest; +import software.amazon.awssdk.services.s3.model.CopyObjectResponse; +import software.amazon.awssdk.services.s3.model.CopyObjectResult; +import software.amazon.awssdk.services.s3.model.CopyPartResult; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.HeadObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.s3.model.UploadPartCopyRequest; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.services.s3.model.UploadPartResponse; + +/** + * Request conversion utility method for POJO classes associated with multipart feature. + */ +@SdkInternalApi +public final class SdkPojoConversionUtils { + + private static final HashSet PUT_OBJECT_REQUEST_TO_UPLOAD_PART_FIELDS_TO_IGNORE = + new HashSet<>(Arrays.asList("ChecksumSHA1", "ChecksumSHA256", "ContentMD5", "ChecksumCRC32C", "ChecksumCRC32")); + + private SdkPojoConversionUtils() { + } + + public static UploadPartRequest toUploadPartRequest(PutObjectRequest putObjectRequest, int partNumber, String uploadId) { + + UploadPartRequest.Builder builder = UploadPartRequest.builder(); + + setSdkFields(builder, putObjectRequest, PUT_OBJECT_REQUEST_TO_UPLOAD_PART_FIELDS_TO_IGNORE); + + return builder.uploadId(uploadId).partNumber(partNumber).build(); + } + + public static CreateMultipartUploadRequest toCreateMultipartUploadRequest(PutObjectRequest putObjectRequest) { + + CreateMultipartUploadRequest.Builder builder = CreateMultipartUploadRequest.builder(); + setSdkFields(builder, putObjectRequest); + return builder.build(); + } + + public static HeadObjectRequest toHeadObjectRequest(CopyObjectRequest copyObjectRequest) { + HeadObjectRequest.Builder builder = HeadObjectRequest.builder(); + setSdkFields(builder, copyObjectRequest); + return builder.build(); + } + + public static CompletedPart toCompletedPart(CopyPartResult copyPartResult, int partNumber) { + CompletedPart.Builder builder = CompletedPart.builder(); + + setSdkFields(builder, copyPartResult); + return builder.partNumber(partNumber).build(); + } + + public static CompletedPart toCompletedPart(UploadPartResponse partResponse, int partNumber) { + CompletedPart.Builder builder = CompletedPart.builder(); + setSdkFields(builder, partResponse); + return builder.partNumber(partNumber).build(); + } + + private static void setSdkFields(SdkPojo targetBuilder, SdkPojo sourceObject) { + setSdkFields(targetBuilder, sourceObject, new HashSet<>()); + } + + private static void setSdkFields(SdkPojo targetBuilder, SdkPojo sourceObject, Set fieldsToIgnore) { + Map sourceFields = retrieveSdkFields(sourceObject, sourceObject.sdkFields()); + List> targetSdkFields = targetBuilder.sdkFields(); + + for (SdkField field : targetSdkFields) { + if (fieldsToIgnore.contains(field.memberName())) { + continue; + } + field.set(targetBuilder, sourceFields.getOrDefault(field.memberName(), null)); + } + } + + public static CreateMultipartUploadRequest toCreateMultipartUploadRequest(CopyObjectRequest copyObjectRequest) { + CreateMultipartUploadRequest.Builder builder = CreateMultipartUploadRequest.builder(); + + setSdkFields(builder, copyObjectRequest); + return builder.build(); + } + + public static CopyObjectResponse toCopyObjectResponse(CompleteMultipartUploadResponse response) { + CopyObjectResponse.Builder builder = CopyObjectResponse.builder(); + + setSdkFields(builder, response); + + if (response.responseMetadata() != null) { + builder.responseMetadata(response.responseMetadata()); + } + + if (response.sdkHttpResponse() != null) { + builder.sdkHttpResponse(response.sdkHttpResponse()); + } + + return builder.copyObjectResult(toCopyObjectResult(response)) + .build(); + } + + private static CopyObjectResult toCopyObjectResult(CompleteMultipartUploadResponse response) { + CopyObjectResult.Builder builder = CopyObjectResult.builder(); + + setSdkFields(builder, response); + return builder.build(); + } + + public static AbortMultipartUploadRequest.Builder toAbortMultipartUploadRequest(CopyObjectRequest copyObjectRequest) { + AbortMultipartUploadRequest.Builder builder = AbortMultipartUploadRequest.builder(); + setSdkFields(builder, copyObjectRequest); + return builder; + } + + public static AbortMultipartUploadRequest.Builder toAbortMultipartUploadRequest(PutObjectRequest putObjectRequest) { + AbortMultipartUploadRequest.Builder builder = AbortMultipartUploadRequest.builder(); + setSdkFields(builder, putObjectRequest); + return builder; + } + + public static UploadPartCopyRequest toUploadPartCopyRequest(CopyObjectRequest copyObjectRequest, + int partNumber, + String uploadId, + String range) { + UploadPartCopyRequest.Builder builder = UploadPartCopyRequest.builder(); + setSdkFields(builder, copyObjectRequest); + return builder.copySourceRange(range) + .partNumber(partNumber) + .uploadId(uploadId) + .build(); + } + + public static PutObjectResponse toPutObjectResponse(CompleteMultipartUploadResponse response) { + + PutObjectResponse.Builder builder = PutObjectResponse.builder(); + + setSdkFields(builder, response); + + // TODO: check why we have to do null check + if (response.responseMetadata() != null) { + builder.responseMetadata(response.responseMetadata()); + } + + if (response.sdkHttpResponse() != null) { + builder.sdkHttpResponse(response.sdkHttpResponse()); + } + + return builder.build(); + } + + private static Map retrieveSdkFields(SdkPojo sourceObject, List> sdkFields) { + return sdkFields.stream().collect( + HashMap::new, + (map, field) -> map.put(field.memberName(), + field.getValueOrDefault(sourceObject)), + Map::putAll); + } +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelperTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelperTest.java index 0db53c246e03..1ea17d4ba967 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelperTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelperTest.java @@ -17,6 +17,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; @@ -27,6 +28,9 @@ import java.io.IOException; import java.util.List; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; @@ -164,7 +168,12 @@ void mpu_onePartFailed_shouldFailOtherPartsAndAbort() { AbortMultipartUploadRequest actualRequest = argumentCaptor.getValue(); assertThat(actualRequest.uploadId()).isEqualTo(UPLOAD_ID); - assertThat(ongoingRequest).isCompletedExceptionally(); + try { + ongoingRequest.get(1, TimeUnit.MILLISECONDS); + fail("no exception thrown"); + } catch (Exception e) { + assertThat(e.getCause()).hasMessageContaining("request failed"); + } } @Test diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyRequestConversionUtilsTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtilsTest.java similarity index 63% rename from services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyRequestConversionUtilsTest.java rename to services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtilsTest.java index 104d5f6e045f..4d5a333a51dd 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyRequestConversionUtilsTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtilsTest.java @@ -13,7 +13,7 @@ * permissions and limitations under the License. */ -package software.amazon.awssdk.services.s3.internal.crt; +package software.amazon.awssdk.services.s3.internal.multipart; import static org.assertj.core.api.Assertions.assertThat; @@ -35,6 +35,7 @@ import software.amazon.awssdk.core.SdkField; import software.amazon.awssdk.core.SdkPojo; import software.amazon.awssdk.http.SdkHttpFullResponse; +import software.amazon.awssdk.services.s3.internal.multipart.SdkPojoConversionUtils; import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.CompletedPart; @@ -43,19 +44,23 @@ import software.amazon.awssdk.services.s3.model.CopyPartResult; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.HeadObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.s3.model.S3ResponseMetadata; import software.amazon.awssdk.services.s3.model.UploadPartCopyRequest; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.services.s3.model.UploadPartResponse; import software.amazon.awssdk.utils.Logger; -class CopyRequestConversionUtilsTest { - private static final Logger log = Logger.loggerFor(RequestConversionUtils.class); +class SdkPojoConversionUtilsTest { + private static final Logger log = Logger.loggerFor(SdkPojoConversionUtils.class); private static final Random RNG = new Random(); @Test void toHeadObject_shouldCopyProperties() { CopyObjectRequest randomCopyObject = randomCopyObjectRequest(); - HeadObjectRequest convertedToHeadObject = RequestConversionUtils.toHeadObjectRequest(randomCopyObject); + HeadObjectRequest convertedToHeadObject = SdkPojoConversionUtils.toHeadObjectRequest(randomCopyObject); Set fieldsToIgnore = new HashSet<>(Arrays.asList("ExpectedBucketOwner", "RequestPayer", "Bucket", @@ -69,12 +74,12 @@ void toHeadObject_shouldCopyProperties() { } @Test - void toCompletedPart_shouldCopyProperties() { + void toCompletedPart_copy_shouldCopyProperties() { CopyPartResult.Builder fromObject = CopyPartResult.builder(); setFieldsToRandomValues(fromObject.sdkFields(), fromObject); CopyPartResult result = fromObject.build(); - CompletedPart convertedCompletedPart = RequestConversionUtils.toCompletedPart(result, 1); + CompletedPart convertedCompletedPart = SdkPojoConversionUtils.toCompletedPart(result, 1); verifyFieldsAreCopied(result, convertedCompletedPart, new HashSet<>(), CopyPartResult.builder().sdkFields(), CompletedPart.builder().sdkFields()); @@ -82,9 +87,9 @@ void toCompletedPart_shouldCopyProperties() { } @Test - void toCreateMultipartUploadRequest_shouldCopyProperties() { + void toCreateMultipartUploadRequest_copyObject_shouldCopyProperties() { CopyObjectRequest randomCopyObject = randomCopyObjectRequest(); - CreateMultipartUploadRequest convertedRequest = RequestConversionUtils.toCreateMultipartUploadRequest(randomCopyObject); + CreateMultipartUploadRequest convertedRequest = SdkPojoConversionUtils.toCreateMultipartUploadRequest(randomCopyObject); Set fieldsToIgnore = new HashSet<>(); verifyFieldsAreCopied(randomCopyObject, convertedRequest, fieldsToIgnore, CopyObjectRequest.builder().sdkFields(), @@ -100,7 +105,7 @@ void toCopyObjectResponse_shouldCopyProperties() { responseBuilder.responseMetadata(s3ResponseMetadata).sdkHttpResponse(sdkHttpFullResponse); CompleteMultipartUploadResponse result = responseBuilder.build(); - CopyObjectResponse convertedRequest = RequestConversionUtils.toCopyObjectResponse(result); + CopyObjectResponse convertedRequest = SdkPojoConversionUtils.toCopyObjectResponse(result); Set fieldsToIgnore = new HashSet<>(); verifyFieldsAreCopied(result, convertedRequest, fieldsToIgnore, CompleteMultipartUploadResponse.builder().sdkFields(), @@ -111,21 +116,29 @@ void toCopyObjectResponse_shouldCopyProperties() { } @Test - void toAbortMultipartUploadRequest_shouldCopyProperties() { + void toAbortMultipartUploadRequest_copyObject_shouldCopyProperties() { CopyObjectRequest randomCopyObject = randomCopyObjectRequest(); - AbortMultipartUploadRequest convertedRequest = RequestConversionUtils.toAbortMultipartUploadRequest(randomCopyObject).build(); + AbortMultipartUploadRequest convertedRequest = SdkPojoConversionUtils.toAbortMultipartUploadRequest(randomCopyObject).build(); Set fieldsToIgnore = new HashSet<>(); verifyFieldsAreCopied(randomCopyObject, convertedRequest, fieldsToIgnore, CopyObjectRequest.builder().sdkFields(), AbortMultipartUploadRequest.builder().sdkFields()); + } - //assertThat(convertedRequest.uploadId()).isEqualTo("id"); + @Test + void toAbortMultipartUploadRequest_putObject_shouldCopyProperties() { + PutObjectRequest randomCopyObject = randomPutObjectRequest(); + AbortMultipartUploadRequest convertedRequest = SdkPojoConversionUtils.toAbortMultipartUploadRequest(randomCopyObject).build(); + Set fieldsToIgnore = new HashSet<>(); + verifyFieldsAreCopied(randomCopyObject, convertedRequest, fieldsToIgnore, + PutObjectRequest.builder().sdkFields(), + AbortMultipartUploadRequest.builder().sdkFields()); } @Test void toUploadPartCopyRequest_shouldCopyProperties() { CopyObjectRequest randomCopyObject = randomCopyObjectRequest(); - UploadPartCopyRequest convertedObject = RequestConversionUtils.toUploadPartCopyRequest(randomCopyObject, 1, "id", + UploadPartCopyRequest convertedObject = SdkPojoConversionUtils.toUploadPartCopyRequest(randomCopyObject, 1, "id", "bytes=0-1024"); Set fieldsToIgnore = new HashSet<>(Collections.singletonList("CopySource")); verifyFieldsAreCopied(randomCopyObject, convertedObject, fieldsToIgnore, @@ -133,6 +146,61 @@ void toUploadPartCopyRequest_shouldCopyProperties() { UploadPartCopyRequest.builder().sdkFields()); } + @Test + void toUploadPartRequest_shouldCopyProperties() { + PutObjectRequest randomObject = randomPutObjectRequest(); + UploadPartRequest convertedObject = SdkPojoConversionUtils.toUploadPartRequest(randomObject, 1, "id"); + Set fieldsToIgnore = new HashSet<>(Arrays.asList("ChecksumCRC32", "ChecksumSHA256", "ContentMD5", "ChecksumSHA1", + "ChecksumCRC32C")); + verifyFieldsAreCopied(randomObject, convertedObject, fieldsToIgnore, + PutObjectRequest.builder().sdkFields(), + UploadPartRequest.builder().sdkFields()); + assertThat(convertedObject.partNumber()).isEqualTo(1); + assertThat(convertedObject.uploadId()).isEqualTo("id"); + } + + @Test + void toPutObjectResponse_shouldCopyProperties() { + CompleteMultipartUploadResponse.Builder builder = CompleteMultipartUploadResponse.builder(); + populateFields(builder); + S3ResponseMetadata s3ResponseMetadata = S3ResponseMetadata.create(DefaultAwsResponseMetadata.create(new HashMap<>())); + SdkHttpFullResponse sdkHttpFullResponse = SdkHttpFullResponse.builder().statusCode(200).build(); + builder.responseMetadata(s3ResponseMetadata).sdkHttpResponse(sdkHttpFullResponse); + CompleteMultipartUploadResponse randomObject = builder.build(); + PutObjectResponse convertedObject = SdkPojoConversionUtils.toPutObjectResponse(randomObject); + Set fieldsToIgnore = new HashSet<>(); + verifyFieldsAreCopied(randomObject, convertedObject, fieldsToIgnore, + CompleteMultipartUploadResponse.builder().sdkFields(), + PutObjectResponse.builder().sdkFields()); + + assertThat(convertedObject.sdkHttpResponse()).isEqualTo(sdkHttpFullResponse); + assertThat(convertedObject.responseMetadata()).isEqualTo(s3ResponseMetadata); + } + + @Test + void toCreateMultipartUploadRequest_putObjectRequest_shouldCopyProperties() { + PutObjectRequest randomObject = randomPutObjectRequest(); + CreateMultipartUploadRequest convertedObject = SdkPojoConversionUtils.toCreateMultipartUploadRequest(randomObject); + Set fieldsToIgnore = new HashSet<>(); + System.out.println(convertedObject); + verifyFieldsAreCopied(randomObject, convertedObject, fieldsToIgnore, + PutObjectRequest.builder().sdkFields(), + CreateMultipartUploadRequest.builder().sdkFields()); + } + + @Test + void toCompletedPart_putObject_shouldCopyProperties() { + UploadPartResponse.Builder fromObject = UploadPartResponse.builder(); + setFieldsToRandomValues(fromObject.sdkFields(), fromObject); + UploadPartResponse result = fromObject.build(); + + CompletedPart convertedCompletedPart = SdkPojoConversionUtils.toCompletedPart(result, 1); + verifyFieldsAreCopied(result, convertedCompletedPart, new HashSet<>(), + UploadPartResponse.builder().sdkFields(), + CompletedPart.builder().sdkFields()); + assertThat(convertedCompletedPart.partNumber()).isEqualTo(1); + } + private static void verifyFieldsAreCopied(SdkPojo requestConvertedFrom, SdkPojo requestConvertedTo, Set fieldsToIgnore, @@ -147,7 +215,7 @@ private static void verifyFieldsAreCopied(SdkPojo requestConvertedFrom, SdkField toField = toObjectEntry.getValue(); if (fieldsToIgnore.contains(toField.memberName())) { - log.info(() -> "Ignoring fields: " + toField.locationName()); + log.info(() -> "Ignoring fields: " + toField.memberName()); continue; } @@ -155,7 +223,7 @@ private static void verifyFieldsAreCopied(SdkPojo requestConvertedFrom, if (fromField == null) { log.info(() -> String.format("Ignoring field [%s] because the object to convert from does not have such field ", - toField.locationName())); + toField.memberName())); continue; } @@ -175,6 +243,16 @@ private CopyObjectRequest randomCopyObjectRequest() { return builder.build(); } + private PutObjectRequest randomPutObjectRequest() { + PutObjectRequest.Builder builder = PutObjectRequest.builder(); + setFieldsToRandomValues(builder.sdkFields(), builder); + return builder.build(); + } + + private void populateFields(SdkPojo pojo) { + setFieldsToRandomValues(pojo.sdkFields(), pojo); + } + private void setFieldsToRandomValues(Collection> fields, Object builder) { for (SdkField f : fields) { setFieldToRandomValue(f, builder); @@ -193,6 +271,8 @@ private static void setFieldToRandomValue(SdkField sdkField, Object obj) { sdkField.set(obj, new HashMap<>()); } else if (targetClass.equals(Boolean.class)) { sdkField.set(obj, true); + } else if (targetClass.equals(Long.class)) { + sdkField.set(obj, randomLong()); } else { throw new IllegalArgumentException("Unknown SdkField type: " + targetClass + " name: " + sdkField.memberName()); } @@ -201,7 +281,7 @@ private static void setFieldToRandomValue(SdkField sdkField, Object obj) { private static Map> sdkFieldMap(Collection> sdkFields) { Map> map = new HashMap<>(sdkFields.size()); for (SdkField f : sdkFields) { - String locName = f.locationName(); + String locName = f.memberName(); if (map.put(locName, f) != null) { throw new IllegalArgumentException("Multiple SdkFields map to same location name"); } @@ -216,4 +296,8 @@ private static Instant randomInstant() { private static Integer randomInteger() { return RNG.nextInt(); } + + private static long randomLong() { + return RNG.nextLong(); + } } From d99890895d2c3f206a8220fad4c87768e2824c1c Mon Sep 17 00:00:00 2001 From: Zoe Wang <33073555+zoewangg@users.noreply.github.com> Date: Wed, 12 Jul 2023 12:04:50 -0700 Subject: [PATCH 03/13] Fix null content length in SplittingPublisher (#4173) --- .../internal/async/SplittingPublisher.java | 103 +++++++++++++----- .../async/SplittingPublisherTest.java | 65 ++++++++++- .../multipart/MultipartUploadHelper.java | 3 +- 3 files changed, 138 insertions(+), 33 deletions(-) diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java index 095d69ac5e7d..8152e13980a6 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java @@ -33,8 +33,11 @@ /** * Splits an {@link SdkPublisher} to multiple smaller {@link AsyncRequestBody}s, each of which publishes a specific portion of the * original data. + * + *

If content length is known, each {@link AsyncRequestBody} is sent to the subscriber right after it's initialized. + * Otherwise, it is sent after the entire content for that chunk is buffered. This is required to get content length. + * * // TODO: create a default method in AsyncRequestBody for this - * // TODO: fix the case where content length is null */ @SdkInternalApi public class SplittingPublisher implements SdkPublisher { @@ -86,6 +89,7 @@ private class SplittingSubscriber implements Subscriber { * A hint to determine whether we will exceed maxMemoryUsage by the next OnNext call. */ private int byteBufferSizeHint; + private volatile boolean upstreamComplete; SplittingSubscriber(Long upstreamSize) { this.upstreamSize = upstreamSize; @@ -94,36 +98,49 @@ private class SplittingSubscriber implements Subscriber { @Override public void onSubscribe(Subscription s) { this.upstreamSubscription = s; - this.currentBody = new DownstreamBody(calculateChunkSize(), chunkNumber.get()); - sendCurrentBody(); + this.currentBody = + initializeNextDownstreamBody(upstreamSize != null, calculateChunkSize(upstreamSize), + chunkNumber.get()); // We need to request subscription *after* we set currentBody because onNext could be invoked right away. upstreamSubscription.request(1); } + private DownstreamBody initializeNextDownstreamBody(boolean contentLengthKnown, long chunkSize, int chunkNumber) { + DownstreamBody body = new DownstreamBody(contentLengthKnown, chunkSize, chunkNumber); + if (contentLengthKnown) { + sendCurrentBody(body); + } + return body; + } + @Override public void onNext(ByteBuffer byteBuffer) { hasOpenUpstreamDemand.set(false); byteBufferSizeHint = byteBuffer.remaining(); while (true) { - int amountRemainingInPart = amountRemainingInPart(); - int finalAmountRemainingInPart = amountRemainingInPart; - if (amountRemainingInPart == 0) { - currentBody.complete(); - int currentChunk = chunkNumber.incrementAndGet(); - Long partSize = calculateChunkSize(); - currentBody = new DownstreamBody(partSize, currentChunk); - sendCurrentBody(); + int amountRemainingInChunk = amountRemainingInChunk(); + + // If we have fulfilled this chunk, + // we should create a new DownstreamBody if needed + if (amountRemainingInChunk == 0) { + completeCurrentBody(); + + if (shouldCreateNewDownstreamRequestBody(byteBuffer)) { + int currentChunk = chunkNumber.incrementAndGet(); + long chunkSize = calculateChunkSize(totalDataRemaining()); + currentBody = initializeNextDownstreamBody(upstreamSize != null, chunkSize, currentChunk); + } } - amountRemainingInPart = amountRemainingInPart(); - if (amountRemainingInPart >= byteBuffer.remaining()) { + amountRemainingInChunk = amountRemainingInChunk(); + if (amountRemainingInChunk >= byteBuffer.remaining()) { currentBody.send(byteBuffer.duplicate()); break; } ByteBuffer firstHalf = byteBuffer.duplicate(); - int newLimit = firstHalf.position() + amountRemainingInPart; + int newLimit = firstHalf.position() + amountRemainingInChunk; firstHalf.limit(newLimit); byteBuffer.position(newLimit); currentBody.send(firstHalf); @@ -132,15 +149,32 @@ public void onNext(ByteBuffer byteBuffer) { maybeRequestMoreUpstreamData(); } - private int amountRemainingInPart() { - return Math.toIntExact(currentBody.totalLength - currentBody.transferredLength); + + /** + * If content length is known, we should create new DownstreamRequestBody if there's remaining data. + * If content length is unknown, we should create new DownstreamRequestBody if upstream is not completed yet. + */ + private boolean shouldCreateNewDownstreamRequestBody(ByteBuffer byteBuffer) { + return !upstreamComplete || byteBuffer.remaining() > 0; + } + + private int amountRemainingInChunk() { + return Math.toIntExact(currentBody.maxLength - currentBody.transferredLength); + } + + private void completeCurrentBody() { + currentBody.complete(); + if (upstreamSize == null) { + sendCurrentBody(currentBody); + } } @Override public void onComplete() { + upstreamComplete = true; log.trace(() -> "Received onComplete()"); + completeCurrentBody(); downstreamPublisher.complete().thenRun(() -> future.complete(null)); - currentBody.complete(); } @Override @@ -148,17 +182,17 @@ public void onError(Throwable t) { currentBody.error(t); } - private void sendCurrentBody() { - downstreamPublisher.send(currentBody).exceptionally(t -> { + private void sendCurrentBody(AsyncRequestBody body) { + downstreamPublisher.send(body).exceptionally(t -> { downstreamPublisher.error(t); return null; }); } - private Long calculateChunkSize() { - Long dataRemaining = dataRemaining(); + private long calculateChunkSize(Long dataRemaining) { + // Use default chunk size if the content length is unknown if (dataRemaining == null) { - return null; + return chunkSizeInBytes; } return Math.min(chunkSizeInBytes, dataRemaining); @@ -177,27 +211,34 @@ private boolean shouldRequestMoreData(long buffered) { return buffered == 0 || buffered + byteBufferSizeHint < maxMemoryUsageInBytes; } - private Long dataRemaining() { + private Long totalDataRemaining() { if (upstreamSize == null) { return null; } return upstreamSize - (chunkNumber.get() * chunkSizeInBytes); } - private class DownstreamBody implements AsyncRequestBody { + private final class DownstreamBody implements AsyncRequestBody { + + /** + * The maximum length of the content this AsyncRequestBody can hold. + * If the upstream content length is known, this is the same as totalLength + */ + private final long maxLength; private final Long totalLength; private final SimplePublisher delegate = new SimplePublisher<>(); private final int chunkNumber; private volatile long transferredLength = 0; - private DownstreamBody(Long totalLength, int chunkNumber) { - this.totalLength = totalLength; + private DownstreamBody(boolean contentLengthKnown, long maxLength, int chunkNumber) { + this.totalLength = contentLengthKnown ? maxLength : null; + this.maxLength = maxLength; this.chunkNumber = chunkNumber; } @Override public Optional contentLength() { - return Optional.ofNullable(totalLength); + return totalLength != null ? Optional.of(totalLength) : Optional.of(transferredLength); } public void send(ByteBuffer data) { @@ -214,8 +255,12 @@ public void send(ByteBuffer data) { } public void complete() { - log.debug(() -> "Received complete() for chunk number: " + chunkNumber); - delegate.complete(); + log.debug(() -> "Received complete() for chunk number: " + chunkNumber + " length " + transferredLength); + delegate.complete().whenComplete((r, t) -> { + if (t != null) { + error(t); + } + }); } public void error(Throwable error) { diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java index df318190b92d..45938ea684c8 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java @@ -18,6 +18,7 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static software.amazon.awssdk.utils.FunctionalUtils.invokeSafely; +import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.FileInputStream; import java.io.IOException; @@ -28,6 +29,7 @@ import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; @@ -44,6 +46,8 @@ public class SplittingPublisherTest { private static final int CHUNK_SIZE = 5; private static final int CONTENT_SIZE = 101; + private static final byte[] CONTENT = + RandomStringUtils.randomAscii(CONTENT_SIZE).getBytes(Charset.defaultCharset()); private static final int NUM_OF_CHUNK = (int) Math.ceil(CONTENT_SIZE / (double) CHUNK_SIZE); @@ -123,9 +127,59 @@ void cancelFuture_shouldCancelUpstream() throws IOException { assertThat(downstreamSubscriber.asyncRequestBodies.size()).isEqualTo(1); } - private static final class TestAsyncRequestBody implements AsyncRequestBody { - private static final byte[] CONTENT = RandomStringUtils.random(200).getBytes(Charset.defaultCharset()); - private boolean cancelled; + @Test + void contentLengthNotPresent_shouldHandle() throws Exception { + CompletableFuture future = new CompletableFuture<>(); + TestAsyncRequestBody asyncRequestBody = new TestAsyncRequestBody() { + @Override + public Optional contentLength() { + return Optional.empty(); + } + }; + SplittingPublisher splittingPublisher = SplittingPublisher.builder() + .resultFuture(future) + .asyncRequestBody(asyncRequestBody) + .chunkSizeInBytes((long) CHUNK_SIZE) + .maxMemoryUsageInBytes(10L) + .build(); + + + List> futures = new ArrayList<>(); + AtomicInteger index = new AtomicInteger(0); + + splittingPublisher.subscribe(requestBody -> { + CompletableFuture baosFuture = new CompletableFuture<>(); + BaosSubscriber subscriber = new BaosSubscriber(baosFuture); + futures.add(baosFuture); + requestBody.subscribe(subscriber); + if (index.incrementAndGet() == NUM_OF_CHUNK) { + assertThat(requestBody.contentLength()).hasValue(1L); + } else { + assertThat(requestBody.contentLength()).hasValue((long) CHUNK_SIZE); + } + }).get(5, TimeUnit.SECONDS); + assertThat(futures.size()).isEqualTo(NUM_OF_CHUNK); + + for (int i = 0; i < futures.size(); i++) { + try (ByteArrayInputStream inputStream = new ByteArrayInputStream(CONTENT)) { + byte[] expected; + if (i == futures.size() - 1) { + expected = new byte[1]; + } else { + expected = new byte[CHUNK_SIZE]; + } + inputStream.skip(i * CHUNK_SIZE); + inputStream.read(expected); + byte[] actualBytes = futures.get(i).join(); + assertThat(actualBytes).isEqualTo(expected); + }; + } + + } + + private static class TestAsyncRequestBody implements AsyncRequestBody { + private volatile boolean cancelled; + private volatile boolean isDone; @Override public Optional contentLength() { @@ -137,8 +191,13 @@ public void subscribe(Subscriber s) { s.onSubscribe(new Subscription() { @Override public void request(long n) { + if (isDone) { + return; + } + isDone = true; s.onNext(ByteBuffer.wrap(CONTENT)); s.onComplete(); + } @Override diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java index 0f9be5070861..0dd017d0fd17 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java @@ -69,7 +69,8 @@ public CompletableFuture uploadObject(PutObjectRequest putObj AsyncRequestBody asyncRequestBody) { Long contentLength = asyncRequestBody.contentLength().orElseGet(putObjectRequest::contentLength); - // TODO: support null content length. Should be trivial to support it now + // TODO: support null content length. Need to determine whether to use single object or MPU based on the first + // AsyncRequestBody if (contentLength == null) { throw new IllegalArgumentException("Content-length is required"); } From afe5f58317f2208b480e2889ce9fde1f299a4423 Mon Sep 17 00:00:00 2001 From: Zoe Wang <33073555+zoewangg@users.noreply.github.com> Date: Mon, 17 Jul 2023 15:27:03 -0700 Subject: [PATCH 04/13] Implement multipart copy in Java-based S3 async client (#4189) --- ...S3ClientMultiPartCopyIntegrationTest.java} | 59 ++++++++++++------- .../internal/crt/DefaultS3CrtAsyncClient.java | 5 +- .../{crt => multipart}/CopyObjectHelper.java | 9 ++- .../multipart/MultipartS3AsyncClient.java | 14 +++++ .../multipart/SdkPojoConversionUtils.java | 27 ++++++++- .../s3/internal/crt/CopyObjectHelperTest.java | 25 +++++++- 6 files changed, 110 insertions(+), 29 deletions(-) rename services/s3/src/it/java/software/amazon/awssdk/services/s3/{crt/S3CrtClientCopyIntegrationTest.java => multipart/S3ClientMultiPartCopyIntegrationTest.java} (77%) rename services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/{crt => multipart}/CopyObjectHelper.java (97%) diff --git a/services/s3/src/it/java/software/amazon/awssdk/services/s3/crt/S3CrtClientCopyIntegrationTest.java b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3ClientMultiPartCopyIntegrationTest.java similarity index 77% rename from services/s3/src/it/java/software/amazon/awssdk/services/s3/crt/S3CrtClientCopyIntegrationTest.java rename to services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3ClientMultiPartCopyIntegrationTest.java index d0f92bb5b29a..46cc06e3b415 100644 --- a/services/s3/src/it/java/software/amazon/awssdk/services/s3/crt/S3CrtClientCopyIntegrationTest.java +++ b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3ClientMultiPartCopyIntegrationTest.java @@ -13,7 +13,7 @@ * permissions and limitations under the License. */ -package software.amazon.awssdk.services.s3.crt; +package software.amazon.awssdk.services.s3.multipart; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Fail.fail; @@ -24,26 +24,32 @@ import java.nio.ByteBuffer; import java.security.SecureRandom; import java.util.Base64; -import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.stream.Stream; import javax.crypto.KeyGenerator; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import software.amazon.awssdk.core.ResponseBytes; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.sync.ResponseTransformer; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3IntegrationTestBase; import software.amazon.awssdk.services.s3.internal.crt.S3CrtAsyncClient; +import software.amazon.awssdk.services.s3.internal.multipart.MultipartS3AsyncClient; import software.amazon.awssdk.services.s3.model.CopyObjectResponse; import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.services.s3.model.MetadataDirective; import software.amazon.awssdk.utils.Md5Utils; -public class S3CrtClientCopyIntegrationTest extends S3IntegrationTestBase { - private static final String BUCKET = temporaryBucketName(S3CrtClientCopyIntegrationTest.class); +@Timeout(value = 3, unit = TimeUnit.MINUTES) +public class S3ClientMultiPartCopyIntegrationTest extends S3IntegrationTestBase { + private static final String BUCKET = temporaryBucketName(S3ClientMultiPartCopyIntegrationTest.class); private static final String ORIGINAL_OBJ = "test_file.dat"; private static final String COPIED_OBJ = "test_file_copy.dat"; private static final String ORIGINAL_OBJ_SPECIAL_CHARACTER = "original-special-chars-@$%"; @@ -51,6 +57,7 @@ public class S3CrtClientCopyIntegrationTest extends S3IntegrationTestBase { private static final long OBJ_SIZE = ThreadLocalRandom.current().nextLong(8 * 1024 * 1024, 16 * 1024 * 1024 + 1); private static final long SMALL_OBJ_SIZE = 1024 * 1024; private static S3AsyncClient s3CrtAsyncClient; + private static S3AsyncClient s3MpuClient; @BeforeAll public static void setUp() throws Exception { S3IntegrationTestBase.setUp(); @@ -59,40 +66,50 @@ public static void setUp() throws Exception { .credentialsProvider(CREDENTIALS_PROVIDER_CHAIN) .region(DEFAULT_REGION) .build(); + s3MpuClient = new MultipartS3AsyncClient(s3Async); } @AfterAll public static void teardown() throws Exception { s3CrtAsyncClient.close(); + s3MpuClient.close(); deleteBucketAndAllContents(BUCKET); } - @Test - void copy_singlePart_hasSameContent() { + public static Stream s3AsyncClient() { + return Stream.of(s3MpuClient, s3CrtAsyncClient); + } + + @ParameterizedTest(autoCloseArguments = false) + @MethodSource("s3AsyncClient") + void copy_singlePart_hasSameContent(S3AsyncClient s3AsyncClient) { byte[] originalContent = randomBytes(SMALL_OBJ_SIZE); createOriginalObject(originalContent, ORIGINAL_OBJ); - copyObject(ORIGINAL_OBJ, COPIED_OBJ); + copyObject(ORIGINAL_OBJ, COPIED_OBJ, s3AsyncClient); validateCopiedObject(originalContent, ORIGINAL_OBJ); } - @Test - void copy_copiedObject_hasSameContent() { + @ParameterizedTest(autoCloseArguments = false) + @MethodSource("s3AsyncClient") + void copy_copiedObject_hasSameContent(S3AsyncClient s3AsyncClient) { byte[] originalContent = randomBytes(OBJ_SIZE); createOriginalObject(originalContent, ORIGINAL_OBJ); - copyObject(ORIGINAL_OBJ, COPIED_OBJ); + copyObject(ORIGINAL_OBJ, COPIED_OBJ, s3AsyncClient); validateCopiedObject(originalContent, ORIGINAL_OBJ); } - @Test - void copy_specialCharacters_hasSameContent() { + @ParameterizedTest(autoCloseArguments = false) + @MethodSource("s3AsyncClient") + void copy_specialCharacters_hasSameContent(S3AsyncClient s3AsyncClient) { byte[] originalContent = randomBytes(OBJ_SIZE); createOriginalObject(originalContent, ORIGINAL_OBJ_SPECIAL_CHARACTER); - copyObject(ORIGINAL_OBJ_SPECIAL_CHARACTER, COPIED_OBJ_SPECIAL_CHARACTER); + copyObject(ORIGINAL_OBJ_SPECIAL_CHARACTER, COPIED_OBJ_SPECIAL_CHARACTER, s3AsyncClient); validateCopiedObject(originalContent, COPIED_OBJ_SPECIAL_CHARACTER); } - @Test - void copy_ssecServerSideEncryption_shouldSucceed() { + @ParameterizedTest(autoCloseArguments = false) + @MethodSource("s3AsyncClient") + void copy_ssecServerSideEncryption_shouldSucceed(S3AsyncClient s3AsyncClient) { byte[] originalContent = randomBytes(OBJ_SIZE); byte[] secretKey = generateSecretKey(); String b64Key = Base64.getEncoder().encodeToString(secretKey); @@ -102,8 +119,8 @@ void copy_ssecServerSideEncryption_shouldSucceed() { String newB64Key = Base64.getEncoder().encodeToString(newSecretKey); String newB64KeyMd5 = Md5Utils.md5AsBase64(newSecretKey); - // Java S3 client is used because CRT S3 client putObject fails with SSE-C - // TODO: change back to S3CrtClient once the issue is fixed in CRT + // MPU S3 client gets stuck + // TODO: change back to s3AsyncClient once the issue is fixed in MPU S3 client s3Async.putObject(r -> r.bucket(BUCKET) .key(ORIGINAL_OBJ) .sseCustomerKey(b64Key) @@ -111,7 +128,7 @@ void copy_ssecServerSideEncryption_shouldSucceed() { .sseCustomerKeyMD5(b64KeyMd5), AsyncRequestBody.fromBytes(originalContent)).join(); - CompletableFuture future = s3CrtAsyncClient.copyObject(c -> c + CompletableFuture future = s3AsyncClient.copyObject(c -> c .sourceBucket(BUCKET) .sourceKey(ORIGINAL_OBJ) .metadataDirective(MetadataDirective.REPLACE) @@ -147,8 +164,8 @@ private void createOriginalObject(byte[] originalContent, String originalKey) { AsyncRequestBody.fromBytes(originalContent)).join(); } - private void copyObject(String original, String destination) { - CompletableFuture future = s3CrtAsyncClient.copyObject(c -> c + private void copyObject(String original, String destination, S3AsyncClient s3AsyncClient) { + CompletableFuture future = s3AsyncClient.copyObject(c -> c .sourceBucket(BUCKET) .sourceKey(original) .destinationBucket(BUCKET) diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/DefaultS3CrtAsyncClient.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/DefaultS3CrtAsyncClient.java index 860ac509932e..21c85520db9b 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/DefaultS3CrtAsyncClient.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/DefaultS3CrtAsyncClient.java @@ -51,6 +51,7 @@ import software.amazon.awssdk.services.s3.S3CrtAsyncClientBuilder; import software.amazon.awssdk.services.s3.crt.S3CrtHttpConfiguration; import software.amazon.awssdk.services.s3.crt.S3CrtRetryConfiguration; +import software.amazon.awssdk.services.s3.internal.multipart.CopyObjectHelper; import software.amazon.awssdk.services.s3.model.CopyObjectRequest; import software.amazon.awssdk.services.s3.model.CopyObjectResponse; import software.amazon.awssdk.services.s3.model.GetObjectRequest; @@ -67,7 +68,9 @@ private DefaultS3CrtAsyncClient(DefaultS3CrtClientBuilder builder) { super(initializeS3AsyncClient(builder)); long partSizeInBytes = builder.minimalPartSizeInBytes == null ? DEFAULT_PART_SIZE_IN_BYTES : builder.minimalPartSizeInBytes; - this.copyObjectHelper = new CopyObjectHelper((S3AsyncClient) delegate(), partSizeInBytes); + this.copyObjectHelper = new CopyObjectHelper((S3AsyncClient) delegate(), + partSizeInBytes, + partSizeInBytes); } @Override diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/CopyObjectHelper.java similarity index 97% rename from services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelper.java rename to services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/CopyObjectHelper.java index 9070eb7192c5..31b947bb89c5 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/CopyObjectHelper.java @@ -13,7 +13,7 @@ * permissions and limitations under the License. */ -package software.amazon.awssdk.services.s3.internal.crt; +package software.amazon.awssdk.services.s3.internal.multipart; import java.util.ArrayList; @@ -23,6 +23,7 @@ import java.util.stream.IntStream; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.internal.crt.UploadPartCopyRequestIterable; import software.amazon.awssdk.services.s3.internal.multipart.GenericMultipartHelper; import software.amazon.awssdk.services.s3.internal.multipart.SdkPojoConversionUtils; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; @@ -50,13 +51,15 @@ public final class CopyObjectHelper { private final S3AsyncClient s3AsyncClient; private final long partSizeInBytes; private final GenericMultipartHelper genericMultipartHelper; + private final long uploadThreshold; - public CopyObjectHelper(S3AsyncClient s3AsyncClient, long partSizeInBytes) { + public CopyObjectHelper(S3AsyncClient s3AsyncClient, long partSizeInBytes, long uploadThreshold) { this.s3AsyncClient = s3AsyncClient; this.partSizeInBytes = partSizeInBytes; this.genericMultipartHelper = new GenericMultipartHelper<>(s3AsyncClient, SdkPojoConversionUtils::toAbortMultipartUploadRequest, SdkPojoConversionUtils::toCopyObjectResponse); + this.uploadThreshold = uploadThreshold; } public CompletableFuture copyObject(CopyObjectRequest copyObjectRequest) { @@ -89,7 +92,7 @@ private void doCopyObject(CopyObjectRequest copyObjectRequest, CompletableFuture HeadObjectResponse headObjectResponse) { Long contentLength = headObjectResponse.contentLength(); - if (contentLength <= partSizeInBytes) { + if (contentLength <= partSizeInBytes || contentLength <= uploadThreshold) { log.debug(() -> "Starting the copy as a single copy part request"); copyInOneChunk(copyObjectRequest, returnFuture); } else { diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java index f2895d65fcd2..869eb4048144 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java @@ -21,6 +21,8 @@ import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.services.s3.DelegatingS3AsyncClient; import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.CopyObjectRequest; +import software.amazon.awssdk.services.s3.model.CopyObjectResponse; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; @@ -33,15 +35,27 @@ public class MultipartS3AsyncClient extends DelegatingS3AsyncClient { private static final long DEFAULT_MAX_MEMORY = DEFAULT_PART_SIZE_IN_BYTES * 2; private final MultipartUploadHelper mpuHelper; + private final CopyObjectHelper copyObjectHelper; public MultipartS3AsyncClient(S3AsyncClient delegate) { super(delegate); // TODO: pass a config object to the upload helper instead mpuHelper = new MultipartUploadHelper(delegate, DEFAULT_PART_SIZE_IN_BYTES, DEFAULT_THRESHOLD, DEFAULT_MAX_MEMORY); + copyObjectHelper = new CopyObjectHelper(delegate, DEFAULT_PART_SIZE_IN_BYTES, DEFAULT_THRESHOLD); } @Override public CompletableFuture putObject(PutObjectRequest putObjectRequest, AsyncRequestBody requestBody) { return mpuHelper.uploadObject(putObjectRequest, requestBody); } + + @Override + public CompletableFuture copyObject(CopyObjectRequest copyObjectRequest) { + return copyObjectHelper.copyObject(copyObjectRequest); + } + + @Override + public void close() { + delegate().close(); + } } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtils.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtils.java index 70512084150b..a99c16670d17 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtils.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtils.java @@ -38,12 +38,14 @@ import software.amazon.awssdk.services.s3.model.UploadPartCopyRequest; import software.amazon.awssdk.services.s3.model.UploadPartRequest; import software.amazon.awssdk.services.s3.model.UploadPartResponse; +import software.amazon.awssdk.utils.Logger; /** * Request conversion utility method for POJO classes associated with multipart feature. */ @SdkInternalApi public final class SdkPojoConversionUtils { + private static final Logger log = Logger.loggerFor(SdkPojoConversionUtils.class); private static final HashSet PUT_OBJECT_REQUEST_TO_UPLOAD_PART_FIELDS_TO_IGNORE = new HashSet<>(Arrays.asList("ChecksumSHA1", "ChecksumSHA256", "ContentMD5", "ChecksumCRC32C", "ChecksumCRC32")); @@ -68,9 +70,22 @@ public static CreateMultipartUploadRequest toCreateMultipartUploadRequest(PutObj } public static HeadObjectRequest toHeadObjectRequest(CopyObjectRequest copyObjectRequest) { - HeadObjectRequest.Builder builder = HeadObjectRequest.builder(); - setSdkFields(builder, copyObjectRequest); - return builder.build(); + + // We can't set SdkFields directly because the fields in CopyObjectRequest do not match 100% with the ones in + // HeadObjectRequest + return HeadObjectRequest.builder() + .bucket(copyObjectRequest.sourceBucket()) + .key(copyObjectRequest.sourceKey()) + .versionId(copyObjectRequest.sourceVersionId()) + .ifMatch(copyObjectRequest.copySourceIfMatch()) + .ifModifiedSince(copyObjectRequest.copySourceIfModifiedSince()) + .ifNoneMatch(copyObjectRequest.copySourceIfNoneMatch()) + .ifUnmodifiedSince(copyObjectRequest.copySourceIfUnmodifiedSince()) + .expectedBucketOwner(copyObjectRequest.expectedSourceBucketOwner()) + .sseCustomerAlgorithm(copyObjectRequest.copySourceSSECustomerAlgorithm()) + .sseCustomerKey(copyObjectRequest.copySourceSSECustomerKey()) + .sseCustomerKeyMD5(copyObjectRequest.copySourceSSECustomerKeyMD5()) + .build(); } public static CompletedPart toCompletedPart(CopyPartResult copyPartResult, int partNumber) { @@ -106,6 +121,8 @@ public static CreateMultipartUploadRequest toCreateMultipartUploadRequest(CopyOb CreateMultipartUploadRequest.Builder builder = CreateMultipartUploadRequest.builder(); setSdkFields(builder, copyObjectRequest); + builder.bucket(copyObjectRequest.destinationBucket()); + builder.key(copyObjectRequest.destinationKey()); return builder.build(); } @@ -136,6 +153,8 @@ private static CopyObjectResult toCopyObjectResult(CompleteMultipartUploadRespon public static AbortMultipartUploadRequest.Builder toAbortMultipartUploadRequest(CopyObjectRequest copyObjectRequest) { AbortMultipartUploadRequest.Builder builder = AbortMultipartUploadRequest.builder(); setSdkFields(builder, copyObjectRequest); + builder.bucket(copyObjectRequest.destinationBucket()); + builder.key(copyObjectRequest.destinationKey()); return builder; } @@ -154,6 +173,8 @@ public static UploadPartCopyRequest toUploadPartCopyRequest(CopyObjectRequest co return builder.copySourceRange(range) .partNumber(partNumber) .uploadId(uploadId) + .bucket(copyObjectRequest.destinationBucket()) + .key(copyObjectRequest.destinationKey()) .build(); } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelperTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelperTest.java index ec78d7b15eb6..acca503a352f 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelperTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelperTest.java @@ -33,6 +33,7 @@ import org.mockito.stubbing.Answer; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.internal.multipart.CopyObjectHelper; import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.AbortMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; @@ -59,10 +60,13 @@ class CopyObjectHelperTest { private S3AsyncClient s3AsyncClient; private CopyObjectHelper copyHelper; + private static final long PART_SIZE = 1024L; + private static final long UPLOAD_THRESHOLD = 2048L; + @BeforeEach public void setUp() { s3AsyncClient = Mockito.mock(S3AsyncClient.class); - copyHelper = new CopyObjectHelper(s3AsyncClient, 1024L); + copyHelper = new CopyObjectHelper(s3AsyncClient, PART_SIZE, UPLOAD_THRESHOLD); } @Test @@ -114,6 +118,25 @@ void singlePartCopy_happyCase_shouldSucceed() { assertThat(future.join()).isEqualTo(expectedResponse); } + @Test + void copy_doesNotExceedThreshold_shouldUseSingleObjectCopy() { + + CopyObjectRequest copyObjectRequest = copyObjectRequest(); + + stubSuccessfulHeadObjectCall(2000L); + + CopyObjectResponse expectedResponse = CopyObjectResponse.builder().build(); + CompletableFuture copyFuture = + CompletableFuture.completedFuture(expectedResponse); + + when(s3AsyncClient.copyObject(copyObjectRequest)).thenReturn(copyFuture); + + CompletableFuture future = + copyHelper.copyObject(copyObjectRequest); + + assertThat(future.join()).isEqualTo(expectedResponse); + } + @Test void multiPartCopy_fourPartsHappyCase_shouldSucceed() { CopyObjectRequest copyObjectRequest = copyObjectRequest(); From 35e6e4e486e908200ce0b5bd3d64c02a98d3e0a4 Mon Sep 17 00:00:00 2001 From: Zoe Wang <33073555+zoewangg@users.noreply.github.com> Date: Tue, 18 Jul 2023 09:29:16 -0700 Subject: [PATCH 05/13] Create split method in AsyncRequestBody to return SplittingPublisher (#4188) * Create split method in AsyncRequestBody to return SplittingPublisher * Fix Javadoc and build --- .../awssdk/core/async/AsyncRequestBody.java | 41 ++++++++++ .../async/SplitAsyncRequestBodyResponse.java | 80 +++++++++++++++++++ .../internal/async/SplittingPublisher.java | 16 ++-- .../core/async/AsyncRequestBodyTest.java | 32 +++++--- .../SplitAsyncRequestBodyResponseTest.java | 29 +++++++ .../multipart/MultipartUploadHelper.java | 38 ++++----- 6 files changed, 198 insertions(+), 38 deletions(-) create mode 100644 core/sdk-core/src/main/java/software/amazon/awssdk/core/async/SplitAsyncRequestBodyResponse.java create mode 100644 core/sdk-core/src/test/java/software/amazon/awssdk/core/async/SplitAsyncRequestBodyResponseTest.java diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java index 7a1738f51d97..cad4236d241a 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java @@ -23,6 +23,7 @@ import java.nio.charset.StandardCharsets; import java.nio.file.Path; import java.util.Optional; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; @@ -30,8 +31,10 @@ import software.amazon.awssdk.core.internal.async.ByteArrayAsyncRequestBody; import software.amazon.awssdk.core.internal.async.FileAsyncRequestBody; import software.amazon.awssdk.core.internal.async.InputStreamWithExecutorAsyncRequestBody; +import software.amazon.awssdk.core.internal.async.SplittingPublisher; import software.amazon.awssdk.core.internal.util.Mimetype; import software.amazon.awssdk.utils.BinaryUtils; +import software.amazon.awssdk.utils.Validate; /** * Interface to allow non-blocking streaming of request content. This follows the reactive streams pattern where @@ -246,4 +249,42 @@ static BlockingOutputStreamAsyncRequestBody forBlockingOutputStream(Long content static AsyncRequestBody empty() { return fromBytes(new byte[0]); } + + + /** + * Converts this {@link AsyncRequestBody} to a publisher of {@link AsyncRequestBody}s, each of which publishes a specific + * portion of the original data, based on the configured {code chunkSizeInBytes}. + * + *

+ * If content length of this {@link AsyncRequestBody} is present, each divided {@link AsyncRequestBody} is delivered to the + * subscriber right after it's initialized. + *

+ * // TODO: API Surface Area review: should we make this behavior configurable? + * If content length is null, it is sent after the entire content for that chunk is buffered. + * In this case, the configured {@code maxMemoryUsageInBytes} must be larger than or equal to {@code chunkSizeInBytes}. + * + * @param chunkSizeInBytes the size for each divided chunk. The last chunk may be smaller than the configured size. + * @param maxMemoryUsageInBytes the max memory the SDK will use to buffer the content + * @return SplitAsyncRequestBodyResult + */ + default SplitAsyncRequestBodyResponse split(long chunkSizeInBytes, long maxMemoryUsageInBytes) { + Validate.isPositive(chunkSizeInBytes, "chunkSizeInBytes"); + Validate.isPositive(maxMemoryUsageInBytes, "maxMemoryUsageInBytes"); + + if (!this.contentLength().isPresent()) { + Validate.isTrue(maxMemoryUsageInBytes >= chunkSizeInBytes, + "maxMemoryUsageInBytes must be larger than or equal to " + + "chunkSizeInBytes if the content length is unknown"); + } + + CompletableFuture future = new CompletableFuture<>(); + SplittingPublisher splittingPublisher = SplittingPublisher.builder() + .asyncRequestBody(this) + .chunkSizeInBytes(chunkSizeInBytes) + .maxMemoryUsageInBytes(maxMemoryUsageInBytes) + .resultFuture(future) + .build(); + + return SplitAsyncRequestBodyResponse.create(splittingPublisher, future); + } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/SplitAsyncRequestBodyResponse.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/SplitAsyncRequestBodyResponse.java new file mode 100644 index 000000000000..0035c87520ec --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/SplitAsyncRequestBodyResponse.java @@ -0,0 +1,80 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.async; + + +import java.util.concurrent.CompletableFuture; +import software.amazon.awssdk.annotations.SdkPublicApi; +import software.amazon.awssdk.utils.Validate; + +/** + * Containing the result from {@link AsyncRequestBody#split(long, long)} + */ +@SdkPublicApi +public final class SplitAsyncRequestBodyResponse { + private final SdkPublisher asyncRequestBody; + private final CompletableFuture future; + + private SplitAsyncRequestBodyResponse(SdkPublisher asyncRequestBody, CompletableFuture future) { + this.asyncRequestBody = Validate.paramNotNull(asyncRequestBody, "asyncRequestBody"); + this.future = Validate.paramNotNull(future, "future"); + } + + public static SplitAsyncRequestBodyResponse create(SdkPublisher asyncRequestBody, + CompletableFuture future) { + return new SplitAsyncRequestBodyResponse(asyncRequestBody, future); + } + + /** + * Returns the converted {@link SdkPublisher} of {@link AsyncRequestBody}s. Each {@link AsyncRequestBody} publishes a specific + * portion of the original data. + */ + public SdkPublisher asyncRequestBodyPublisher() { + return asyncRequestBody; + } + + /** + * Returns {@link CompletableFuture} that will be notified when all data has been consumed or if an error occurs. + */ + public CompletableFuture future() { + return future; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + SplitAsyncRequestBodyResponse that = (SplitAsyncRequestBodyResponse) o; + + if (!asyncRequestBody.equals(that.asyncRequestBody)) { + return false; + } + return future.equals(that.future); + } + + @Override + public int hashCode() { + int result = asyncRequestBody.hashCode(); + result = 31 * result + future.hashCode(); + return result; + } +} + diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java index 8152e13980a6..99cf1e7c3381 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java @@ -31,13 +31,11 @@ import software.amazon.awssdk.utils.async.SimplePublisher; /** - * Splits an {@link SdkPublisher} to multiple smaller {@link AsyncRequestBody}s, each of which publishes a specific portion of the - * original data. + * Splits an {@link AsyncRequestBody} to multiple smaller {@link AsyncRequestBody}s, each of which publishes a specific portion of + * the original data. * *

If content length is known, each {@link AsyncRequestBody} is sent to the subscriber right after it's initialized. * Otherwise, it is sent after the entire content for that chunk is buffered. This is required to get content length. - * - * // TODO: create a default method in AsyncRequestBody for this */ @SdkInternalApi public class SplittingPublisher implements SdkPublisher { @@ -51,9 +49,9 @@ public class SplittingPublisher implements SdkPublisher { private SplittingPublisher(Builder builder) { this.upstreamPublisher = Validate.paramNotNull(builder.asyncRequestBody, "asyncRequestBody"); - this.chunkSizeInBytes = Validate.paramNotNull(builder.chunkSizeInBytes, "chunkSizeInBytes"); + this.chunkSizeInBytes = Validate.isPositive(builder.chunkSizeInBytes, "chunkSizeInBytes"); this.splittingSubscriber = new SplittingSubscriber(upstreamPublisher.contentLength().orElse(null)); - this.maxMemoryUsageInBytes = builder.maxMemoryUsageInBytes == null ? Long.MAX_VALUE : builder.maxMemoryUsageInBytes; + this.maxMemoryUsageInBytes = Validate.isPositive(builder.maxMemoryUsageInBytes, "maxMemoryUsageInBytes"); this.future = builder.future; // We need to cancel upstream subscription if the future gets cancelled. @@ -304,13 +302,13 @@ public Builder asyncRequestBody(AsyncRequestBody asyncRequestBody) { * @param chunkSizeInBytes The new chunkSizeInBytes value. * @return This object for method chaining. */ - public Builder chunkSizeInBytes(Long chunkSizeInBytes) { + public Builder chunkSizeInBytes(long chunkSizeInBytes) { this.chunkSizeInBytes = chunkSizeInBytes; return this; } /** - * Sets the maximum memory usage in bytes. By default, it uses unlimited memory. + * Sets the maximum memory usage in bytes. * * @param maxMemoryUsageInBytes The new maxMemoryUsageInBytes value. * @return This object for method chaining. @@ -319,7 +317,7 @@ public Builder chunkSizeInBytes(Long chunkSizeInBytes) { // on a new byte buffer. But we don't know for sure what the size of a buffer we request will be (we do use the size // for the last byte buffer as a hint), so I don't think we can have a truly accurate max. Maybe we call it minimum // buffer size instead? - public Builder maxMemoryUsageInBytes(Long maxMemoryUsageInBytes) { + public Builder maxMemoryUsageInBytes(long maxMemoryUsageInBytes) { this.maxMemoryUsageInBytes = maxMemoryUsageInBytes; return this; } diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyTest.java index e0252c9ba6d2..2dd4cb029ea0 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyTest.java @@ -15,31 +15,23 @@ package software.amazon.awssdk.core.async; -import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import com.google.common.jimfs.Configuration; import com.google.common.jimfs.Jimfs; import io.reactivex.Flowable; -import java.io.File; -import java.io.FileWriter; import java.io.IOException; -import java.io.InputStream; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.nio.file.FileSystem; import java.nio.file.Files; import java.nio.file.Path; -import java.time.Instant; -import java.util.Collections; import java.util.List; -import java.util.concurrent.Callable; import java.util.concurrent.CountDownLatch; import java.util.stream.Collectors; import org.assertj.core.util.Lists; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.TemporaryFolder; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.reactivestreams.Publisher; @@ -47,7 +39,6 @@ import software.amazon.awssdk.core.internal.util.Mimetype; import software.amazon.awssdk.http.async.SimpleSubscriber; import software.amazon.awssdk.utils.BinaryUtils; -import software.amazon.awssdk.utils.StringInputStream; @RunWith(Parameterized.class) public class AsyncRequestBodyTest { @@ -177,4 +168,25 @@ public void fromBytes_byteArrayNotNull_createsCopy() { ByteBuffer publishedBb = Flowable.fromPublisher(body).toList().blockingGet().get(0); assertThat(BinaryUtils.copyAllBytesFrom(publishedBb)).isEqualTo(original); } + + @Test + public void split_nonPositiveInput_shouldThrowException() { + AsyncRequestBody body = AsyncRequestBody.fromString("test"); + assertThatThrownBy(() -> body.split(0, 4)).hasMessageContaining("must be positive"); + assertThatThrownBy(() -> body.split(-1, 4)).hasMessageContaining("must be positive"); + assertThatThrownBy(() -> body.split(5, 0)).hasMessageContaining("must be positive"); + assertThatThrownBy(() -> body.split(5, -1)).hasMessageContaining("must be positive"); + } + + @Test + public void split_contentUnknownMaxMemorySmallerThanChunkSize_shouldThrowException() { + AsyncRequestBody body = AsyncRequestBody.fromPublisher(new Publisher() { + @Override + public void subscribe(Subscriber s) { + + } + }); + assertThatThrownBy(() -> body.split(10, 4)) + .hasMessageContaining("must be larger than or equal"); + } } diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/SplitAsyncRequestBodyResponseTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/SplitAsyncRequestBodyResponseTest.java new file mode 100644 index 000000000000..2d1e50bcd59d --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/SplitAsyncRequestBodyResponseTest.java @@ -0,0 +1,29 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.async; + +import nl.jqno.equalsverifier.EqualsVerifier; +import org.junit.jupiter.api.Test; + +public class SplitAsyncRequestBodyResponseTest { + + @Test + void equalsHashcode() { + EqualsVerifier.forClass(SplitAsyncRequestBodyResponse.class) + .withNonnullFields("asyncRequestBody", "future") + .verify(); + } +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java index 0dd017d0fd17..a3aea4a9bdf7 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java @@ -25,7 +25,7 @@ import java.util.function.Function; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.async.AsyncRequestBody; -import software.amazon.awssdk.core.internal.async.SplittingPublisher; +import software.amazon.awssdk.core.async.SplitAsyncRequestBodyResponse; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.CompletedPart; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; @@ -169,26 +169,26 @@ private CompletableFuture sendUploadPartRequests(MpuRequestContext mpuRequ CompletableFuture returnFuture, Collection> futures) { - CompletableFuture splittingPublisherFuture = new CompletableFuture<>(); + AsyncRequestBody asyncRequestBody = mpuRequestContext.request.right(); - SplittingPublisher splittingPublisher = SplittingPublisher.builder() - .asyncRequestBody(asyncRequestBody) - .chunkSizeInBytes(mpuRequestContext.partSize) - .maxMemoryUsageInBytes(maxMemoryUsageInBytes) - .resultFuture(splittingPublisherFuture) - .build(); - - splittingPublisher.map(new BodyToRequestConverter(mpuRequestContext.request.left(), mpuRequestContext.uploadId)) - .subscribe(pair -> sendIndividualUploadPartRequest(mpuRequestContext.uploadId, - completedParts, - futures, - pair, - splittingPublisherFuture)) - .exceptionally(throwable -> { - returnFuture.completeExceptionally(throwable); - return null; - }); + + SplitAsyncRequestBodyResponse result = asyncRequestBody.split(mpuRequestContext.partSize, maxMemoryUsageInBytes); + + CompletableFuture splittingPublisherFuture = result.future(); + + result.asyncRequestBodyPublisher() + .map(new BodyToRequestConverter(mpuRequestContext.request.left(), + mpuRequestContext.uploadId)) + .subscribe(pair -> sendIndividualUploadPartRequest(mpuRequestContext.uploadId, + completedParts, + futures, + pair, + splittingPublisherFuture)) + .exceptionally(throwable -> { + returnFuture.completeExceptionally(throwable); + return null; + }); return splittingPublisherFuture; } From e0b4bfcb7e0269627f9494d0b864a679cb07db62 Mon Sep 17 00:00:00 2001 From: Zoe Wang <33073555+zoewangg@users.noreply.github.com> Date: Wed, 19 Jul 2023 16:05:01 -0700 Subject: [PATCH 06/13] Add more tests with ByteArrayAsyncRequestBody (#4214) --- .../async/SplittingPublisherTest.java | 104 +++++++++++------- .../crt/S3CrossRegionCrtIntegrationTest.java | 2 +- .../S3ClientMultiPartCopyIntegrationTest.java | 14 +-- ...ltipartClientPutObjectIntegrationTest.java | 28 ++++- 4 files changed, 95 insertions(+), 53 deletions(-) diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java index 45938ea684c8..3ce8559eec32 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java @@ -20,15 +20,21 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; +import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.charset.Charset; +import java.nio.file.Files; import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterAll; @@ -51,11 +57,12 @@ public class SplittingPublisherTest { private static final int NUM_OF_CHUNK = (int) Math.ceil(CONTENT_SIZE / (double) CHUNK_SIZE); - private static RandomTempFile testFile; + private static File testFile; @BeforeAll public static void beforeAll() throws IOException { - testFile = new RandomTempFile("testfile.dat", CONTENT_SIZE); + testFile = File.createTempFile("SplittingPublisherTest", UUID.randomUUID().toString()); + Files.write(testFile.toPath(), CONTENT); } @AfterAll @@ -65,46 +72,19 @@ public static void afterAll() throws Exception { @ParameterizedTest @ValueSource(ints = {CHUNK_SIZE, CHUNK_SIZE * 2 - 1, CHUNK_SIZE * 2}) - void differentChunkSize_shouldSplitAsyncRequestBodyCorrectly(int upstreamByteBufferSize) throws Exception { - CompletableFuture future = new CompletableFuture<>(); - SplittingPublisher splittingPublisher = SplittingPublisher.builder() - .resultFuture(future) - .asyncRequestBody(FileAsyncRequestBody.builder() - .path(testFile.toPath()) - .chunkSizeInBytes(upstreamByteBufferSize) - .build()) - - .resultFuture(future) - .chunkSizeInBytes((long) CHUNK_SIZE) - .maxMemoryUsageInBytes((long) CHUNK_SIZE * 4) - .build(); - - List> futures = new ArrayList<>(); + void differentChunkSize_shouldSplitAsyncRequestBodyCorrectly(int chunkSize) throws Exception { - splittingPublisher.subscribe(requestBody -> { - CompletableFuture baosFuture = new CompletableFuture<>(); - BaosSubscriber subscriber = new BaosSubscriber(baosFuture); - futures.add(baosFuture); - requestBody.subscribe(subscriber); - }).get(5, TimeUnit.SECONDS); - - assertThat(futures.size()).isEqualTo(NUM_OF_CHUNK); + FileAsyncRequestBody fileAsyncRequestBody = FileAsyncRequestBody.builder() + .path(testFile.toPath()) + .chunkSizeInBytes(chunkSize) + .build(); + verifySplitContent(fileAsyncRequestBody, chunkSize); + } - for (int i = 0; i < futures.size(); i++) { - try (FileInputStream fileInputStream = new FileInputStream(testFile)) { - byte[] expected; - if (i == futures.size() - 1) { - expected = new byte[1]; - } else { - expected = new byte[5]; - } - fileInputStream.skip(i * 5); - fileInputStream.read(expected); - byte[] actualBytes = futures.get(i).join(); - assertThat(actualBytes).isEqualTo(expected); - }; - } - assertThat(future).isCompleted(); + @ParameterizedTest + @ValueSource(ints = {CHUNK_SIZE, CHUNK_SIZE * 2 - 1, CHUNK_SIZE * 2}) + void differentChunkSize_byteArrayShouldSplitAsyncRequestBodyCorrectly(int chunkSize) throws Exception { + verifySplitContent(AsyncRequestBody.fromBytes(CONTENT), chunkSize); } @@ -115,7 +95,7 @@ void cancelFuture_shouldCancelUpstream() throws IOException { SplittingPublisher splittingPublisher = SplittingPublisher.builder() .resultFuture(future) .asyncRequestBody(asyncRequestBody) - .chunkSizeInBytes((long) CHUNK_SIZE) + .chunkSizeInBytes(CHUNK_SIZE) .maxMemoryUsageInBytes(10L) .build(); @@ -139,7 +119,7 @@ public Optional contentLength() { SplittingPublisher splittingPublisher = SplittingPublisher.builder() .resultFuture(future) .asyncRequestBody(asyncRequestBody) - .chunkSizeInBytes((long) CHUNK_SIZE) + .chunkSizeInBytes(CHUNK_SIZE) .maxMemoryUsageInBytes(10L) .build(); @@ -177,6 +157,46 @@ public Optional contentLength() { } + + private static void verifySplitContent(AsyncRequestBody asyncRequestBody, int chunkSize) throws Exception { + CompletableFuture future = new CompletableFuture<>(); + SplittingPublisher splittingPublisher = SplittingPublisher.builder() + .resultFuture(future) + .asyncRequestBody(asyncRequestBody) + .resultFuture(future) + .chunkSizeInBytes(chunkSize) + .maxMemoryUsageInBytes((long) chunkSize * 4) + .build(); + + List> futures = new ArrayList<>(); + + splittingPublisher.subscribe(requestBody -> { + CompletableFuture baosFuture = new CompletableFuture<>(); + BaosSubscriber subscriber = new BaosSubscriber(baosFuture); + futures.add(baosFuture); + requestBody.subscribe(subscriber); + }).get(5, TimeUnit.SECONDS); + + assertThat(futures.size()).isEqualTo((int) Math.ceil(CONTENT_SIZE / (double) chunkSize)); + + for (int i = 0; i < futures.size(); i++) { + try (FileInputStream fileInputStream = new FileInputStream(testFile)) { + byte[] expected; + if (i == futures.size() - 1) { + int lastChunk = CONTENT_SIZE % chunkSize == 0 ? chunkSize : (CONTENT_SIZE % chunkSize); + expected = new byte[lastChunk]; + } else { + expected = new byte[chunkSize]; + } + fileInputStream.skip(i * chunkSize); + fileInputStream.read(expected); + byte[] actualBytes = futures.get(i).join(); + assertThat(actualBytes).isEqualTo(expected); + }; + } + assertThat(future).isCompleted(); + } + private static class TestAsyncRequestBody implements AsyncRequestBody { private volatile boolean cancelled; private volatile boolean isDone; diff --git a/services/s3/src/it/java/software/amazon/awssdk/services/s3/crt/S3CrossRegionCrtIntegrationTest.java b/services/s3/src/it/java/software/amazon/awssdk/services/s3/crt/S3CrossRegionCrtIntegrationTest.java index 953c6e4b4f4b..72c6fce095ce 100644 --- a/services/s3/src/it/java/software/amazon/awssdk/services/s3/crt/S3CrossRegionCrtIntegrationTest.java +++ b/services/s3/src/it/java/software/amazon/awssdk/services/s3/crt/S3CrossRegionCrtIntegrationTest.java @@ -17,7 +17,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static software.amazon.awssdk.services.s3.crt.S3CrtClientCopyIntegrationTest.randomBytes; +import static software.amazon.awssdk.services.s3.multipart.S3ClientMultiPartCopyIntegrationTest.randomBytes; import static software.amazon.awssdk.services.s3.utils.ChecksumUtils.computeCheckSum; import static software.amazon.awssdk.testutils.service.S3BucketUtils.temporaryBucketName; diff --git a/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3ClientMultiPartCopyIntegrationTest.java b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3ClientMultiPartCopyIntegrationTest.java index ad52804b963d..6db434526fb9 100644 --- a/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3ClientMultiPartCopyIntegrationTest.java +++ b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3ClientMultiPartCopyIntegrationTest.java @@ -119,14 +119,12 @@ void copy_ssecServerSideEncryption_shouldSucceed(S3AsyncClient s3AsyncClient) { String newB64Key = Base64.getEncoder().encodeToString(newSecretKey); String newB64KeyMd5 = Md5Utils.md5AsBase64(newSecretKey); - // MPU S3 client gets stuck - // TODO: change back to s3AsyncClient once the issue is fixed in MPU S3 client - s3Async.putObject(r -> r.bucket(BUCKET) - .key(ORIGINAL_OBJ) - .sseCustomerKey(b64Key) - .sseCustomerAlgorithm(AES256.name()) - .sseCustomerKeyMD5(b64KeyMd5), - AsyncRequestBody.fromBytes(originalContent)).join(); + s3AsyncClient.putObject(r -> r.bucket(BUCKET) + .key(ORIGINAL_OBJ) + .sseCustomerKey(b64Key) + .sseCustomerAlgorithm(AES256.name()) + .sseCustomerKeyMD5(b64KeyMd5), + AsyncRequestBody.fromBytes(originalContent)).join(); CompletableFuture future = s3AsyncClient.copyObject(c -> c .sourceBucket(BUCKET) diff --git a/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java index 4174b87883dc..f791b4b3c26a 100644 --- a/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java +++ b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java @@ -20,7 +20,13 @@ import static org.assertj.core.api.Assertions.assertThat; import static software.amazon.awssdk.testutils.service.S3BucketUtils.temporaryBucketName; +import java.io.ByteArrayInputStream; +import java.io.File; +import java.nio.charset.Charset; import java.nio.file.Files; +import java.util.UUID; +import java.util.concurrent.ThreadLocalRandom; +import org.apache.commons.lang3.RandomStringUtils; import org.assertj.core.api.Assertions; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; @@ -42,15 +48,18 @@ public class S3MultipartClientPutObjectIntegrationTest extends S3IntegrationTest private static final String TEST_KEY = "testfile.dat"; private static final int OBJ_SIZE = 19 * 1024 * 1024; - private static RandomTempFile testFile; + private static File testFile; private static S3AsyncClient mpuS3Client; @BeforeAll public static void setup() throws Exception { S3IntegrationTestBase.setUp(); S3IntegrationTestBase.createBucket(TEST_BUCKET); + byte[] CONTENT = + RandomStringUtils.randomAscii(OBJ_SIZE).getBytes(Charset.defaultCharset()); - testFile = new RandomTempFile(TEST_KEY, OBJ_SIZE); + testFile = File.createTempFile("SplittingPublisherTest", UUID.randomUUID().toString()); + Files.write(testFile.toPath(), CONTENT); mpuS3Client = new MultipartS3AsyncClient(s3Async); } @@ -75,4 +84,19 @@ void putObject_fileRequestBody_objectSentCorrectly() throws Exception { assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum); } + @Test + @Timeout(value = 30, unit = SECONDS) + void putObject_byteAsyncRequestBody_objectSentCorrectly() throws Exception { + byte[] bytes = RandomStringUtils.randomAscii(OBJ_SIZE).getBytes(Charset.defaultCharset()); + AsyncRequestBody body = AsyncRequestBody.fromBytes(bytes); + mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join(); + + ResponseInputStream objContent = S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), + ResponseTransformer.toInputStream()); + + assertThat(objContent.response().contentLength()).isEqualTo(OBJ_SIZE); + byte[] expectedSum = ChecksumUtils.computeCheckSum(new ByteArrayInputStream(bytes)); + assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum); + } + } From 11c63629a315e7a49e7e77d9eddf7cb4a96f4ea4 Mon Sep 17 00:00:00 2001 From: Zoe Wang <33073555+zoewangg@users.noreply.github.com> Date: Fri, 21 Jul 2023 09:03:54 -0700 Subject: [PATCH 07/13] Handle null response metadata (#4215) * Handle null response metadata * Fix build --- .../awssdk/awscore/AwsResponseMetadata.java | 3 ++- .../multipart/SdkPojoConversionUtils.java | 19 ++++--------------- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/core/aws-core/src/main/java/software/amazon/awssdk/awscore/AwsResponseMetadata.java b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/AwsResponseMetadata.java index f9e326f62317..2bbdec695da8 100644 --- a/core/aws-core/src/main/java/software/amazon/awssdk/awscore/AwsResponseMetadata.java +++ b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/AwsResponseMetadata.java @@ -18,6 +18,7 @@ import static software.amazon.awssdk.awscore.util.AwsHeader.AWS_REQUEST_ID; import java.util.Collections; +import java.util.HashMap; import java.util.Map; import java.util.Objects; import java.util.Optional; @@ -48,7 +49,7 @@ protected AwsResponseMetadata(Map metadata) { } protected AwsResponseMetadata(AwsResponseMetadata responseMetadata) { - this(responseMetadata.metadata); + this(responseMetadata == null ? new HashMap<>() : responseMetadata.metadata); } /** diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtils.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtils.java index a99c16670d17..25fde18cadaf 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtils.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtils.java @@ -131,13 +131,8 @@ public static CopyObjectResponse toCopyObjectResponse(CompleteMultipartUploadRes setSdkFields(builder, response); - if (response.responseMetadata() != null) { - builder.responseMetadata(response.responseMetadata()); - } - - if (response.sdkHttpResponse() != null) { - builder.sdkHttpResponse(response.sdkHttpResponse()); - } + builder.responseMetadata(response.responseMetadata()); + builder.sdkHttpResponse(response.sdkHttpResponse()); return builder.copyObjectResult(toCopyObjectResult(response)) .build(); @@ -184,14 +179,8 @@ public static PutObjectResponse toPutObjectResponse(CompleteMultipartUploadRespo setSdkFields(builder, response); - // TODO: check why we have to do null check - if (response.responseMetadata() != null) { - builder.responseMetadata(response.responseMetadata()); - } - - if (response.sdkHttpResponse() != null) { - builder.sdkHttpResponse(response.sdkHttpResponse()); - } + builder.responseMetadata(response.responseMetadata()); + builder.sdkHttpResponse(response.sdkHttpResponse()); return builder.build(); } From 3e6e70f089111d290a9d92d8ad6b72d083e2cec1 Mon Sep 17 00:00:00 2001 From: Zoe Wang <33073555+zoewangg@users.noreply.github.com> Date: Thu, 27 Jul 2023 11:22:37 -0700 Subject: [PATCH 08/13] Support streaming with unknown content length (#4226) * Support uploading with unknown content length * Refactoring --- .../awssdk/core/async/AsyncRequestBody.java | 19 +- .../async/SplitAsyncRequestBodyResponse.java | 80 ------ .../internal/async/SplittingPublisher.java | 84 +++--- .../SplitAsyncRequestBodyResponseTest.java | 29 -- .../async/SplittingPublisherTest.java | 29 -- ...ltipartClientPutObjectIntegrationTest.java | 33 ++- .../multipart/GenericMultipartHelper.java | 16 +- .../multipart/MultipartS3AsyncClient.java | 4 +- .../multipart/MultipartUploadHelper.java | 231 ++++------------ .../multipart/UploadObjectHelper.java | 73 +++++ .../UploadWithKnownContentLengthHelper.java | 251 ++++++++++++++++++ .../UploadWithUnknownContentLengthHelper.java | 247 +++++++++++++++++ ...rTest.java => UploadObjectHelperTest.java} | 187 +++++++++++-- 13 files changed, 883 insertions(+), 400 deletions(-) delete mode 100644 core/sdk-core/src/main/java/software/amazon/awssdk/core/async/SplitAsyncRequestBodyResponse.java delete mode 100644 core/sdk-core/src/test/java/software/amazon/awssdk/core/async/SplitAsyncRequestBodyResponseTest.java create mode 100644 services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelper.java create mode 100644 services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java create mode 100644 services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java rename services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/{MultipartUploadHelperTest.java => UploadObjectHelperTest.java} (58%) diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java index 3bd3d7136d47..3c6adb8fdbac 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java @@ -24,7 +24,6 @@ import java.nio.file.Path; import java.util.Arrays; import java.util.Optional; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; @@ -420,24 +419,20 @@ static AsyncRequestBody empty() { * @param maxMemoryUsageInBytes the max memory the SDK will use to buffer the content * @return SplitAsyncRequestBodyResult */ - default SplitAsyncRequestBodyResponse split(long chunkSizeInBytes, long maxMemoryUsageInBytes) { + default SdkPublisher split(long chunkSizeInBytes, long maxMemoryUsageInBytes) { Validate.isPositive(chunkSizeInBytes, "chunkSizeInBytes"); Validate.isPositive(maxMemoryUsageInBytes, "maxMemoryUsageInBytes"); - if (!this.contentLength().isPresent()) { + if (!contentLength().isPresent()) { Validate.isTrue(maxMemoryUsageInBytes >= chunkSizeInBytes, "maxMemoryUsageInBytes must be larger than or equal to " + "chunkSizeInBytes if the content length is unknown"); } - CompletableFuture future = new CompletableFuture<>(); - SplittingPublisher splittingPublisher = SplittingPublisher.builder() - .asyncRequestBody(this) - .chunkSizeInBytes(chunkSizeInBytes) - .maxMemoryUsageInBytes(maxMemoryUsageInBytes) - .resultFuture(future) - .build(); - - return SplitAsyncRequestBodyResponse.create(splittingPublisher, future); + return SplittingPublisher.builder() + .asyncRequestBody(this) + .chunkSizeInBytes(chunkSizeInBytes) + .maxMemoryUsageInBytes(maxMemoryUsageInBytes) + .build(); } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/SplitAsyncRequestBodyResponse.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/SplitAsyncRequestBodyResponse.java deleted file mode 100644 index 0035c87520ec..000000000000 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/SplitAsyncRequestBodyResponse.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.core.async; - - -import java.util.concurrent.CompletableFuture; -import software.amazon.awssdk.annotations.SdkPublicApi; -import software.amazon.awssdk.utils.Validate; - -/** - * Containing the result from {@link AsyncRequestBody#split(long, long)} - */ -@SdkPublicApi -public final class SplitAsyncRequestBodyResponse { - private final SdkPublisher asyncRequestBody; - private final CompletableFuture future; - - private SplitAsyncRequestBodyResponse(SdkPublisher asyncRequestBody, CompletableFuture future) { - this.asyncRequestBody = Validate.paramNotNull(asyncRequestBody, "asyncRequestBody"); - this.future = Validate.paramNotNull(future, "future"); - } - - public static SplitAsyncRequestBodyResponse create(SdkPublisher asyncRequestBody, - CompletableFuture future) { - return new SplitAsyncRequestBodyResponse(asyncRequestBody, future); - } - - /** - * Returns the converted {@link SdkPublisher} of {@link AsyncRequestBody}s. Each {@link AsyncRequestBody} publishes a specific - * portion of the original data. - */ - public SdkPublisher asyncRequestBodyPublisher() { - return asyncRequestBody; - } - - /** - * Returns {@link CompletableFuture} that will be notified when all data has been consumed or if an error occurs. - */ - public CompletableFuture future() { - return future; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - - SplitAsyncRequestBodyResponse that = (SplitAsyncRequestBodyResponse) o; - - if (!asyncRequestBody.equals(that.asyncRequestBody)) { - return false; - } - return future.equals(that.future); - } - - @Override - public int hashCode() { - int result = asyncRequestBody.hashCode(); - result = 31 * result + future.hashCode(); - return result; - } -} - diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java index 99cf1e7c3381..e18f9944a09e 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java @@ -17,7 +17,6 @@ import java.nio.ByteBuffer; import java.util.Optional; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; @@ -45,24 +44,12 @@ public class SplittingPublisher implements SdkPublisher { private final SimplePublisher downstreamPublisher = new SimplePublisher<>(); private final long chunkSizeInBytes; private final long maxMemoryUsageInBytes; - private final CompletableFuture future; private SplittingPublisher(Builder builder) { this.upstreamPublisher = Validate.paramNotNull(builder.asyncRequestBody, "asyncRequestBody"); this.chunkSizeInBytes = Validate.isPositive(builder.chunkSizeInBytes, "chunkSizeInBytes"); this.splittingSubscriber = new SplittingSubscriber(upstreamPublisher.contentLength().orElse(null)); this.maxMemoryUsageInBytes = Validate.isPositive(builder.maxMemoryUsageInBytes, "maxMemoryUsageInBytes"); - this.future = builder.future; - - // We need to cancel upstream subscription if the future gets cancelled. - future.whenComplete((r, t) -> { - if (t != null) { - if (splittingSubscriber.upstreamSubscription != null) { - log.trace(() -> "Cancelling subscription because return future completed exceptionally ", t); - splittingSubscriber.upstreamSubscription.cancel(); - } - } - }); } public static Builder builder() { @@ -117,26 +104,35 @@ public void onNext(ByteBuffer byteBuffer) { byteBufferSizeHint = byteBuffer.remaining(); while (true) { + + if (!byteBuffer.hasRemaining()) { + break; + } + int amountRemainingInChunk = amountRemainingInChunk(); // If we have fulfilled this chunk, - // we should create a new DownstreamBody if needed + // complete the current body if (amountRemainingInChunk == 0) { - completeCurrentBody(); + completeCurrentBodyAndCreateNewIfNeeded(byteBuffer); + amountRemainingInChunk = amountRemainingInChunk(); + } - if (shouldCreateNewDownstreamRequestBody(byteBuffer)) { - int currentChunk = chunkNumber.incrementAndGet(); - long chunkSize = calculateChunkSize(totalDataRemaining()); - currentBody = initializeNextDownstreamBody(upstreamSize != null, chunkSize, currentChunk); - } + // If the current ByteBuffer < this chunk, send it as-is + if (amountRemainingInChunk > byteBuffer.remaining()) { + currentBody.send(byteBuffer.duplicate()); + break; } - amountRemainingInChunk = amountRemainingInChunk(); - if (amountRemainingInChunk >= byteBuffer.remaining()) { + // If the current ByteBuffer == this chunk, send it as-is and + // complete the current body + if (amountRemainingInChunk == byteBuffer.remaining()) { currentBody.send(byteBuffer.duplicate()); + completeCurrentBodyAndCreateNewIfNeeded(byteBuffer); break; } + // If the current ByteBuffer > this chunk, split this ByteBuffer ByteBuffer firstHalf = byteBuffer.duplicate(); int newLimit = firstHalf.position() + amountRemainingInChunk; firstHalf.limit(newLimit); @@ -147,13 +143,22 @@ public void onNext(ByteBuffer byteBuffer) { maybeRequestMoreUpstreamData(); } + private void completeCurrentBodyAndCreateNewIfNeeded(ByteBuffer byteBuffer) { + completeCurrentBody(); + int currentChunk = chunkNumber.incrementAndGet(); + boolean shouldCreateNewDownstreamRequestBody; + Long dataRemaining = totalDataRemaining(); - /** - * If content length is known, we should create new DownstreamRequestBody if there's remaining data. - * If content length is unknown, we should create new DownstreamRequestBody if upstream is not completed yet. - */ - private boolean shouldCreateNewDownstreamRequestBody(ByteBuffer byteBuffer) { - return !upstreamComplete || byteBuffer.remaining() > 0; + if (upstreamSize == null) { + shouldCreateNewDownstreamRequestBody = !upstreamComplete || byteBuffer.hasRemaining(); + } else { + shouldCreateNewDownstreamRequestBody = dataRemaining != null && dataRemaining > 0; + } + + if (shouldCreateNewDownstreamRequestBody) { + long chunkSize = calculateChunkSize(dataRemaining); + currentBody = initializeNextDownstreamBody(upstreamSize != null, chunkSize, currentChunk); + } } private int amountRemainingInChunk() { @@ -161,6 +166,7 @@ private int amountRemainingInChunk() { } private void completeCurrentBody() { + log.debug(() -> "completeCurrentBody for chunk " + chunkNumber.get()); currentBody.complete(); if (upstreamSize == null) { sendCurrentBody(currentBody); @@ -172,12 +178,13 @@ public void onComplete() { upstreamComplete = true; log.trace(() -> "Received onComplete()"); completeCurrentBody(); - downstreamPublisher.complete().thenRun(() -> future.complete(null)); + downstreamPublisher.complete(); } @Override public void onError(Throwable t) { - currentBody.error(t); + log.trace(() -> "Received onError()", t); + downstreamPublisher.error(t); } private void sendCurrentBody(AsyncRequestBody body) { @@ -206,7 +213,7 @@ private void maybeRequestMoreUpstreamData() { } private boolean shouldRequestMoreData(long buffered) { - return buffered == 0 || buffered + byteBufferSizeHint < maxMemoryUsageInBytes; + return buffered == 0 || buffered + byteBufferSizeHint <= maxMemoryUsageInBytes; } private Long totalDataRemaining() { @@ -240,7 +247,7 @@ public Optional contentLength() { } public void send(ByteBuffer data) { - log.trace(() -> "Sending bytebuffer " + data); + log.trace(() -> String.format("Sending bytebuffer %s to chunk %d", data, chunkNumber)); int length = data.remaining(); transferredLength += length; addDataBuffered(length); @@ -283,7 +290,6 @@ public static final class Builder { private AsyncRequestBody asyncRequestBody; private Long chunkSizeInBytes; private Long maxMemoryUsageInBytes; - private CompletableFuture future; /** * Configures the asyncRequestBody to split @@ -322,18 +328,6 @@ public Builder maxMemoryUsageInBytes(long maxMemoryUsageInBytes) { return this; } - /** - * Sets the result future. The future will be completed when all request bodies - * have been sent. - * - * @param future The new future value. - * @return This object for method chaining. - */ - public Builder resultFuture(CompletableFuture future) { - this.future = future; - return this; - } - public SplittingPublisher build() { return new SplittingPublisher(this); } diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/SplitAsyncRequestBodyResponseTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/SplitAsyncRequestBodyResponseTest.java deleted file mode 100644 index 2d1e50bcd59d..000000000000 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/SplitAsyncRequestBodyResponseTest.java +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.core.async; - -import nl.jqno.equalsverifier.EqualsVerifier; -import org.junit.jupiter.api.Test; - -public class SplitAsyncRequestBodyResponseTest { - - @Test - void equalsHashcode() { - EqualsVerifier.forClass(SplitAsyncRequestBodyResponse.class) - .withNonnullFields("asyncRequestBody", "future") - .verify(); - } -} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java index 3ce8559eec32..368c403dbf88 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java @@ -31,10 +31,7 @@ import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterAll; @@ -45,7 +42,6 @@ import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import software.amazon.awssdk.core.async.AsyncRequestBody; -import software.amazon.awssdk.testutils.RandomTempFile; import software.amazon.awssdk.utils.BinaryUtils; public class SplittingPublisherTest { @@ -87,26 +83,6 @@ void differentChunkSize_byteArrayShouldSplitAsyncRequestBodyCorrectly(int chunkS verifySplitContent(AsyncRequestBody.fromBytes(CONTENT), chunkSize); } - - @Test - void cancelFuture_shouldCancelUpstream() throws IOException { - CompletableFuture future = new CompletableFuture<>(); - TestAsyncRequestBody asyncRequestBody = new TestAsyncRequestBody(); - SplittingPublisher splittingPublisher = SplittingPublisher.builder() - .resultFuture(future) - .asyncRequestBody(asyncRequestBody) - .chunkSizeInBytes(CHUNK_SIZE) - .maxMemoryUsageInBytes(10L) - .build(); - - OnlyRequestOnceSubscriber downstreamSubscriber = new OnlyRequestOnceSubscriber(); - splittingPublisher.subscribe(downstreamSubscriber); - - future.completeExceptionally(new RuntimeException("test")); - assertThat(asyncRequestBody.cancelled).isTrue(); - assertThat(downstreamSubscriber.asyncRequestBodies.size()).isEqualTo(1); - } - @Test void contentLengthNotPresent_shouldHandle() throws Exception { CompletableFuture future = new CompletableFuture<>(); @@ -117,7 +93,6 @@ public Optional contentLength() { } }; SplittingPublisher splittingPublisher = SplittingPublisher.builder() - .resultFuture(future) .asyncRequestBody(asyncRequestBody) .chunkSizeInBytes(CHUNK_SIZE) .maxMemoryUsageInBytes(10L) @@ -159,11 +134,8 @@ public Optional contentLength() { private static void verifySplitContent(AsyncRequestBody asyncRequestBody, int chunkSize) throws Exception { - CompletableFuture future = new CompletableFuture<>(); SplittingPublisher splittingPublisher = SplittingPublisher.builder() - .resultFuture(future) .asyncRequestBody(asyncRequestBody) - .resultFuture(future) .chunkSizeInBytes(chunkSize) .maxMemoryUsageInBytes((long) chunkSize * 4) .build(); @@ -194,7 +166,6 @@ private static void verifySplitContent(AsyncRequestBody asyncRequestBody, int ch assertThat(actualBytes).isEqualTo(expected); }; } - assertThat(future).isCompleted(); } private static class TestAsyncRequestBody implements AsyncRequestBody { diff --git a/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java index f791b4b3c26a..cb72906943b9 100644 --- a/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java +++ b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java @@ -22,18 +22,23 @@ import java.io.ByteArrayInputStream; import java.io.File; +import java.nio.ByteBuffer; import java.nio.charset.Charset; import java.nio.file.Files; +import java.util.Optional; import java.util.UUID; import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; import org.apache.commons.lang3.RandomStringUtils; import org.assertj.core.api.Assertions; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; +import org.reactivestreams.Subscriber; import software.amazon.awssdk.core.ResponseInputStream; import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.internal.async.FileAsyncRequestBody; import software.amazon.awssdk.core.sync.ResponseTransformer; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3IntegrationTestBase; @@ -42,6 +47,7 @@ import software.amazon.awssdk.services.s3.utils.ChecksumUtils; import software.amazon.awssdk.testutils.RandomTempFile; +@Timeout(value = 30, unit = SECONDS) public class S3MultipartClientPutObjectIntegrationTest extends S3IntegrationTestBase { private static final String TEST_BUCKET = temporaryBucketName(S3MultipartClientPutObjectIntegrationTest.class); @@ -71,7 +77,6 @@ public static void teardown() throws Exception { } @Test - @Timeout(value = 20, unit = SECONDS) void putObject_fileRequestBody_objectSentCorrectly() throws Exception { AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath()); mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join(); @@ -85,7 +90,6 @@ void putObject_fileRequestBody_objectSentCorrectly() throws Exception { } @Test - @Timeout(value = 30, unit = SECONDS) void putObject_byteAsyncRequestBody_objectSentCorrectly() throws Exception { byte[] bytes = RandomStringUtils.randomAscii(OBJ_SIZE).getBytes(Charset.defaultCharset()); AsyncRequestBody body = AsyncRequestBody.fromBytes(bytes); @@ -99,4 +103,29 @@ void putObject_byteAsyncRequestBody_objectSentCorrectly() throws Exception { assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum); } + @Test + void putObject_unknownContentLength_objectSentCorrectly() throws Exception { + AsyncRequestBody body = FileAsyncRequestBody.builder() + .path(testFile.toPath()) + .build(); + mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), new AsyncRequestBody() { + @Override + public Optional contentLength() { + return Optional.empty(); + } + + @Override + public void subscribe(Subscriber s) { + body.subscribe(s); + } + }).get(30, SECONDS); + + ResponseInputStream objContent = S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), + ResponseTransformer.toInputStream()); + + assertThat(objContent.response().contentLength()).isEqualTo(testFile.length()); + byte[] expectedSum = ChecksumUtils.computeCheckSum(Files.newInputStream(testFile.toPath())); + assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum); + } + } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java index 4ab4b22a0e79..905c1bc928ea 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java @@ -79,13 +79,9 @@ public int determinePartCount(long contentLength, long partSize) { } public CompletableFuture completeMultipartUpload( - RequestT request, String uploadId, AtomicReferenceArray completedParts) { + RequestT request, String uploadId, CompletedPart[] parts) { log.debug(() -> String.format("Sending completeMultipartUploadRequest, uploadId: %s", uploadId)); - CompletedPart[] parts = - IntStream.range(0, completedParts.length()) - .mapToObj(completedParts::get) - .toArray(CompletedPart[]::new); CompleteMultipartUploadRequest completeMultipartUploadRequest = CompleteMultipartUploadRequest.builder() .bucket(request.getValueForField("Bucket", String.class).get()) @@ -99,6 +95,15 @@ public CompletableFuture completeMultipartUploa return s3AsyncClient.completeMultipartUpload(completeMultipartUploadRequest); } + public CompletableFuture completeMultipartUpload( + RequestT request, String uploadId, AtomicReferenceArray completedParts) { + CompletedPart[] parts = + IntStream.range(0, completedParts.length()) + .mapToObj(completedParts::get) + .toArray(CompletedPart[]::new); + return completeMultipartUpload(request, uploadId, parts); + } + public BiFunction handleExceptionOrResponse( RequestT request, CompletableFuture returnFuture, @@ -119,6 +124,7 @@ public BiFunction handleExcept } public void cleanUpParts(String uploadId, AbortMultipartUploadRequest.Builder abortMultipartUploadRequest) { + log.debug(() -> "Aborting multipart upload: " + uploadId); s3AsyncClient.abortMultipartUpload(abortMultipartUploadRequest.uploadId(uploadId).build()) .exceptionally(throwable -> { log.warn(() -> String.format("Failed to abort previous multipart upload " diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java index 869eb4048144..a4b3147254f9 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java @@ -34,13 +34,13 @@ public class MultipartS3AsyncClient extends DelegatingS3AsyncClient { private static final long DEFAULT_THRESHOLD = 8L * 1024 * 1024; private static final long DEFAULT_MAX_MEMORY = DEFAULT_PART_SIZE_IN_BYTES * 2; - private final MultipartUploadHelper mpuHelper; + private final UploadObjectHelper mpuHelper; private final CopyObjectHelper copyObjectHelper; public MultipartS3AsyncClient(S3AsyncClient delegate) { super(delegate); // TODO: pass a config object to the upload helper instead - mpuHelper = new MultipartUploadHelper(delegate, DEFAULT_PART_SIZE_IN_BYTES, DEFAULT_THRESHOLD, DEFAULT_MAX_MEMORY); + mpuHelper = new UploadObjectHelper(delegate, DEFAULT_PART_SIZE_IN_BYTES, DEFAULT_THRESHOLD, DEFAULT_MAX_MEMORY); copyObjectHelper = new CopyObjectHelper(delegate, DEFAULT_PART_SIZE_IN_BYTES, DEFAULT_THRESHOLD); } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java index a3aea4a9bdf7..1228e577fcd1 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java @@ -20,12 +20,9 @@ import java.util.Collection; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.atomic.AtomicReferenceArray; -import java.util.function.Function; +import java.util.function.Consumer; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.async.AsyncRequestBody; -import software.amazon.awssdk.core.async.SplitAsyncRequestBodyResponse; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.CompletedPart; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; @@ -39,7 +36,8 @@ import software.amazon.awssdk.utils.Pair; /** - * An internal helper class that automatically uses multipart upload based on the size of the object. + * A base class contains common logic used by {@link UploadWithUnknownContentLengthHelper} + * and {@link UploadWithKnownContentLengthHelper}. */ @SdkInternalApi public final class MultipartUploadHelper { @@ -65,210 +63,85 @@ public MultipartUploadHelper(S3AsyncClient s3AsyncClient, this.multipartUploadThresholdInBytes = multipartUploadThresholdInBytes; } - public CompletableFuture uploadObject(PutObjectRequest putObjectRequest, - AsyncRequestBody asyncRequestBody) { - Long contentLength = asyncRequestBody.contentLength().orElseGet(putObjectRequest::contentLength); - - // TODO: support null content length. Need to determine whether to use single object or MPU based on the first - // AsyncRequestBody - if (contentLength == null) { - throw new IllegalArgumentException("Content-length is required"); - } - - CompletableFuture returnFuture = new CompletableFuture<>(); - - try { - if (contentLength > multipartUploadThresholdInBytes && contentLength > partSizeInBytes) { - log.debug(() -> "Starting the upload as multipart upload request"); - uploadInParts(putObjectRequest, contentLength, asyncRequestBody, returnFuture); - } else { - log.debug(() -> "Starting the upload as a single upload part request"); - uploadInOneChunk(putObjectRequest, asyncRequestBody, returnFuture); - } - - } catch (Throwable throwable) { - returnFuture.completeExceptionally(throwable); - } - - return returnFuture; - } - - private void uploadInParts(PutObjectRequest putObjectRequest, long contentLength, AsyncRequestBody asyncRequestBody, - CompletableFuture returnFuture) { - + CompletableFuture createMultipartUpload(PutObjectRequest putObjectRequest, + CompletableFuture returnFuture) { CreateMultipartUploadRequest request = SdkPojoConversionUtils.toCreateMultipartUploadRequest(putObjectRequest); CompletableFuture createMultipartUploadFuture = s3AsyncClient.createMultipartUpload(request); // Ensure cancellations are forwarded to the createMultipartUploadFuture future CompletableFutureUtils.forwardExceptionTo(returnFuture, createMultipartUploadFuture); - - createMultipartUploadFuture.whenComplete((createMultipartUploadResponse, throwable) -> { - if (throwable != null) { - genericMultipartHelper.handleException(returnFuture, () -> "Failed to initiate multipart upload", throwable); - } else { - log.debug(() -> "Initiated a new multipart upload, uploadId: " + createMultipartUploadResponse.uploadId()); - doUploadInParts(Pair.of(putObjectRequest, asyncRequestBody), contentLength, returnFuture, - createMultipartUploadResponse.uploadId()); - } - }); + return createMultipartUploadFuture; } - private void doUploadInParts(Pair request, - long contentLength, - CompletableFuture returnFuture, - String uploadId) { - - long optimalPartSize = genericMultipartHelper.calculateOptimalPartSizeFor(contentLength, partSizeInBytes); - int partCount = genericMultipartHelper.determinePartCount(contentLength, optimalPartSize); - - log.debug(() -> String.format("Starting multipart upload with partCount: %d, optimalPartSize: %d", partCount, - optimalPartSize)); - - // The list of completed parts must be sorted - AtomicReferenceArray completedParts = new AtomicReferenceArray<>(partCount); - - PutObjectRequest putObjectRequest = request.left(); - - Collection> futures = new ConcurrentLinkedQueue<>(); - - MpuRequestContext mpuRequestContext = new MpuRequestContext(request, contentLength, optimalPartSize, uploadId); - - CompletableFuture requestsFuture = sendUploadPartRequests(mpuRequestContext, - completedParts, - returnFuture, - futures); - requestsFuture.whenComplete((r, t) -> { - if (t != null) { - genericMultipartHelper.handleException(returnFuture, () -> "Failed to send multipart upload requests", t); - genericMultipartHelper.cleanUpParts(uploadId, toAbortMultipartUploadRequest(putObjectRequest)); - cancelingOtherOngoingRequests(futures, t); - return; - } - CompletableFutureUtils.allOfExceptionForwarded(futures.toArray(new CompletableFuture[0])) - .thenCompose(ignore -> genericMultipartHelper.completeMultipartUpload(putObjectRequest, - uploadId, - completedParts)) - .handle(genericMultipartHelper.handleExceptionOrResponse(putObjectRequest, returnFuture, - uploadId)) - .exceptionally(throwable -> { - genericMultipartHelper.handleException(returnFuture, () -> "Unexpected exception occurred", - throwable); - return null; - }); - }); - } - - private static void cancelingOtherOngoingRequests(Collection> futures, Throwable t) { - log.trace(() -> "cancelling other ongoing requests " + futures.size()); - futures.forEach(f -> f.completeExceptionally(t)); + void completeMultipartUpload(CompletableFuture returnFuture, + String uploadId, + CompletedPart[] completedParts, + PutObjectRequest putObjectRequest) { + genericMultipartHelper.completeMultipartUpload(putObjectRequest, + uploadId, + completedParts) + .handle(genericMultipartHelper.handleExceptionOrResponse(putObjectRequest, returnFuture, + uploadId)) + .exceptionally(throwable -> { + genericMultipartHelper.handleException(returnFuture, () -> "Unexpected exception occurred", + throwable); + return null; + }); } - private CompletableFuture sendUploadPartRequests(MpuRequestContext mpuRequestContext, - AtomicReferenceArray completedParts, - CompletableFuture returnFuture, - Collection> futures) { - - - - AsyncRequestBody asyncRequestBody = mpuRequestContext.request.right(); - - SplitAsyncRequestBodyResponse result = asyncRequestBody.split(mpuRequestContext.partSize, maxMemoryUsageInBytes); - - CompletableFuture splittingPublisherFuture = result.future(); - - result.asyncRequestBodyPublisher() - .map(new BodyToRequestConverter(mpuRequestContext.request.left(), - mpuRequestContext.uploadId)) - .subscribe(pair -> sendIndividualUploadPartRequest(mpuRequestContext.uploadId, - completedParts, - futures, - pair, - splittingPublisherFuture)) - .exceptionally(throwable -> { - returnFuture.completeExceptionally(throwable); - return null; - }); - return splittingPublisherFuture; - } - - private void sendIndividualUploadPartRequest(String uploadId, - AtomicReferenceArray completedParts, - Collection> futures, - Pair requestPair, - CompletableFuture sendUploadPartRequestsFuture) { + CompletableFuture sendIndividualUploadPartRequest(String uploadId, + Consumer completedPartsConsumer, + Collection> futures, + Pair requestPair) { UploadPartRequest uploadPartRequest = requestPair.left(); Integer partNumber = uploadPartRequest.partNumber(); log.debug(() -> "Sending uploadPartRequest: " + uploadPartRequest.partNumber() + " uploadId: " + uploadId + " " + "contentLength " + requestPair.right().contentLength()); - CompletableFuture uploadPartFuture = s3AsyncClient.uploadPart(uploadPartRequest, requestPair.right()); + CompletableFuture uploadPartFuture = s3AsyncClient.uploadPart(uploadPartRequest, + requestPair.right()); CompletableFuture convertFuture = - uploadPartFuture.thenApply(uploadPartResponse -> convertUploadPartResponse(completedParts, partNumber, + uploadPartFuture.thenApply(uploadPartResponse -> convertUploadPartResponse(completedPartsConsumer, partNumber, uploadPartResponse)); futures.add(convertFuture); CompletableFutureUtils.forwardExceptionTo(convertFuture, uploadPartFuture); - CompletableFutureUtils.forwardExceptionTo(uploadPartFuture, sendUploadPartRequestsFuture); + return convertFuture; + } + + void failRequestsElegantly(Collection> futures, + Throwable t, + String uploadId, + CompletableFuture returnFuture, + PutObjectRequest putObjectRequest) { + genericMultipartHelper.handleException(returnFuture, () -> "Failed to send multipart upload requests", t); + if (uploadId != null) { + genericMultipartHelper.cleanUpParts(uploadId, toAbortMultipartUploadRequest(putObjectRequest)); + } + cancelingOtherOngoingRequests(futures, t); + } + + static void cancelingOtherOngoingRequests(Collection> futures, Throwable t) { + log.trace(() -> "cancelling other ongoing requests " + futures.size()); + futures.forEach(f -> f.completeExceptionally(t)); } - private static CompletedPart convertUploadPartResponse(AtomicReferenceArray completedParts, - Integer partNumber, - UploadPartResponse uploadPartResponse) { + static CompletedPart convertUploadPartResponse(Consumer consumer, + Integer partNumber, + UploadPartResponse uploadPartResponse) { CompletedPart completedPart = SdkPojoConversionUtils.toCompletedPart(uploadPartResponse, partNumber); - completedParts.set(partNumber - 1, completedPart); + consumer.accept(completedPart); return completedPart; } - private void uploadInOneChunk(PutObjectRequest putObjectRequest, - AsyncRequestBody asyncRequestBody, - CompletableFuture returnFuture) { + void uploadInOneChunk(PutObjectRequest putObjectRequest, + AsyncRequestBody asyncRequestBody, + CompletableFuture returnFuture) { CompletableFuture putObjectResponseCompletableFuture = s3AsyncClient.putObject(putObjectRequest, asyncRequestBody); CompletableFutureUtils.forwardExceptionTo(returnFuture, putObjectResponseCompletableFuture); CompletableFutureUtils.forwardResultTo(putObjectResponseCompletableFuture, returnFuture); } - - private static final class BodyToRequestConverter implements Function> { - private int partNumber = 1; - private final PutObjectRequest putObjectRequest; - private final String uploadId; - - BodyToRequestConverter(PutObjectRequest putObjectRequest, String uploadId) { - this.putObjectRequest = putObjectRequest; - this.uploadId = uploadId; - } - - @Override - public Pair apply(AsyncRequestBody asyncRequestBody) { - log.trace(() -> "Generating uploadPartRequest for partNumber " + partNumber); - UploadPartRequest uploadRequest = - SdkPojoConversionUtils.toUploadPartRequest(putObjectRequest, - partNumber, - uploadId); - ++partNumber; - return Pair.of(uploadRequest, asyncRequestBody); - } - } - - private static final class MpuRequestContext { - private final Pair request; - private final long contentLength; - private final long partSize; - - private final String uploadId; - - private MpuRequestContext(Pair request, - long contentLength, - long partSize, - String uploadId) { - this.request = request; - this.contentLength = contentLength; - this.partSize = partSize; - this.uploadId = uploadId; - } - } - } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelper.java new file mode 100644 index 000000000000..0700e8ade5f9 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelper.java @@ -0,0 +1,73 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import java.util.concurrent.CompletableFuture; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.utils.Logger; + +/** + * An internal helper class that automatically uses multipart upload based on the size of the object. + */ +@SdkInternalApi +public final class UploadObjectHelper { + private static final Logger log = Logger.loggerFor(UploadObjectHelper.class); + + private final S3AsyncClient s3AsyncClient; + private final long partSizeInBytes; + private final GenericMultipartHelper genericMultipartHelper; + + private final long maxMemoryUsageInBytes; + private final long multipartUploadThresholdInBytes; + private final UploadWithKnownContentLengthHelper uploadWithKnownContentLength; + private final UploadWithUnknownContentLengthHelper uploadWithUnknownContentLength; + + public UploadObjectHelper(S3AsyncClient s3AsyncClient, + long partSizeInBytes, + long multipartUploadThresholdInBytes, + long maxMemoryUsageInBytes) { + this.s3AsyncClient = s3AsyncClient; + this.partSizeInBytes = partSizeInBytes; + this.genericMultipartHelper = new GenericMultipartHelper<>(s3AsyncClient, + SdkPojoConversionUtils::toAbortMultipartUploadRequest, + SdkPojoConversionUtils::toPutObjectResponse); + this.maxMemoryUsageInBytes = maxMemoryUsageInBytes; + this.multipartUploadThresholdInBytes = multipartUploadThresholdInBytes; + this.uploadWithKnownContentLength = new UploadWithKnownContentLengthHelper(s3AsyncClient, + partSizeInBytes, + multipartUploadThresholdInBytes, + maxMemoryUsageInBytes); + this.uploadWithUnknownContentLength = new UploadWithUnknownContentLengthHelper(s3AsyncClient, + partSizeInBytes, + multipartUploadThresholdInBytes, + maxMemoryUsageInBytes); + } + + public CompletableFuture uploadObject(PutObjectRequest putObjectRequest, + AsyncRequestBody asyncRequestBody) { + Long contentLength = asyncRequestBody.contentLength().orElseGet(putObjectRequest::contentLength); + + if (contentLength == null) { + return uploadWithUnknownContentLength.uploadObject(putObjectRequest, asyncRequestBody); + } else { + return uploadWithKnownContentLength.uploadObject(putObjectRequest, asyncRequestBody, contentLength.longValue()); + } + } +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java new file mode 100644 index 000000000000..e8bef01ab81b --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java @@ -0,0 +1,251 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + + +import java.util.Collection; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReferenceArray; +import java.util.function.Consumer; +import java.util.stream.IntStream; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Pair; + +/** + * An internal helper class that automatically uses multipart upload based on the size of the object. + */ +@SdkInternalApi +public final class UploadWithKnownContentLengthHelper { + private static final Logger log = Logger.loggerFor(UploadWithKnownContentLengthHelper.class); + + private final S3AsyncClient s3AsyncClient; + private final long partSizeInBytes; + private final GenericMultipartHelper genericMultipartHelper; + + private final long maxMemoryUsageInBytes; + private final long multipartUploadThresholdInBytes; + private final MultipartUploadHelper multipartUploadHelper; + + public UploadWithKnownContentLengthHelper(S3AsyncClient s3AsyncClient, + long partSizeInBytes, + long multipartUploadThresholdInBytes, + long maxMemoryUsageInBytes) { + this.s3AsyncClient = s3AsyncClient; + this.partSizeInBytes = partSizeInBytes; + this.genericMultipartHelper = new GenericMultipartHelper<>(s3AsyncClient, + SdkPojoConversionUtils::toAbortMultipartUploadRequest, + SdkPojoConversionUtils::toPutObjectResponse); + this.maxMemoryUsageInBytes = maxMemoryUsageInBytes; + this.multipartUploadThresholdInBytes = multipartUploadThresholdInBytes; + this.multipartUploadHelper = new MultipartUploadHelper(s3AsyncClient, partSizeInBytes, multipartUploadThresholdInBytes, + maxMemoryUsageInBytes); + } + + public CompletableFuture uploadObject(PutObjectRequest putObjectRequest, + AsyncRequestBody asyncRequestBody, + long contentLength) { + CompletableFuture returnFuture = new CompletableFuture<>(); + + try { + if (contentLength > multipartUploadThresholdInBytes && contentLength > partSizeInBytes) { + log.debug(() -> "Starting the upload as multipart upload request"); + uploadInParts(putObjectRequest, contentLength, asyncRequestBody, returnFuture); + } else { + log.debug(() -> "Starting the upload as a single upload part request"); + multipartUploadHelper.uploadInOneChunk(putObjectRequest, asyncRequestBody, returnFuture); + } + + } catch (Throwable throwable) { + returnFuture.completeExceptionally(throwable); + } + + return returnFuture; + } + + private void uploadInParts(PutObjectRequest putObjectRequest, long contentLength, AsyncRequestBody asyncRequestBody, + CompletableFuture returnFuture) { + + CompletableFuture createMultipartUploadFuture = + multipartUploadHelper.createMultipartUpload(putObjectRequest, returnFuture); + + createMultipartUploadFuture.whenComplete((createMultipartUploadResponse, throwable) -> { + if (throwable != null) { + genericMultipartHelper.handleException(returnFuture, () -> "Failed to initiate multipart upload", throwable); + } else { + log.debug(() -> "Initiated a new multipart upload, uploadId: " + createMultipartUploadResponse.uploadId()); + doUploadInParts(Pair.of(putObjectRequest, asyncRequestBody), contentLength, returnFuture, + createMultipartUploadResponse.uploadId()); + } + }); + } + + private void doUploadInParts(Pair request, + long contentLength, + CompletableFuture returnFuture, + String uploadId) { + + long optimalPartSize = genericMultipartHelper.calculateOptimalPartSizeFor(contentLength, partSizeInBytes); + int partCount = genericMultipartHelper.determinePartCount(contentLength, optimalPartSize); + + log.debug(() -> String.format("Starting multipart upload with partCount: %d, optimalPartSize: %d", partCount, + optimalPartSize)); + + MpuRequestContext mpuRequestContext = new MpuRequestContext(request, contentLength, optimalPartSize, uploadId); + + request.right() + .split(mpuRequestContext.partSize, maxMemoryUsageInBytes) + .subscribe(new KnownContentLengthAsyncRequestBodySubscriber(mpuRequestContext, + returnFuture)); + } + + private static final class MpuRequestContext { + private final Pair request; + private final long contentLength; + private final long partSize; + + private final String uploadId; + + private MpuRequestContext(Pair request, + long contentLength, + long partSize, + String uploadId) { + this.request = request; + this.contentLength = contentLength; + this.partSize = partSize; + this.uploadId = uploadId; + } + } + + private class KnownContentLengthAsyncRequestBodySubscriber implements Subscriber { + + /** + * The number of AsyncRequestBody has been received but yet to be processed + */ + private final AtomicInteger asyncRequestBodyInFlight = new AtomicInteger(0); + + /** + * Indicates whether CompleteMultipart has been initiated or not. + */ + private final AtomicBoolean completedMultipartInitiated = new AtomicBoolean(false); + + private final AtomicBoolean failureActionInitiated = new AtomicBoolean(false); + + private final AtomicInteger partNumber = new AtomicInteger(1); + + private final AtomicReferenceArray completedParts; + private final String uploadId; + private final Collection> futures = new ConcurrentLinkedQueue<>(); + + private final PutObjectRequest putObjectRequest; + private final CompletableFuture returnFuture; + private Subscription subscription; + + private volatile boolean isDone; + + KnownContentLengthAsyncRequestBodySubscriber(MpuRequestContext mpuRequestContext, + CompletableFuture returnFuture) { + long optimalPartSize = genericMultipartHelper.calculateOptimalPartSizeFor(mpuRequestContext.contentLength, + partSizeInBytes); + int partCount = genericMultipartHelper.determinePartCount(mpuRequestContext.contentLength, optimalPartSize); + this.putObjectRequest = mpuRequestContext.request.left(); + this.returnFuture = returnFuture; + this.completedParts = new AtomicReferenceArray<>(partCount); + this.uploadId = mpuRequestContext.uploadId; + } + + @Override + public void onSubscribe(Subscription s) { + if (this.subscription != null) { + log.warn(() -> "The subscriber has already been subscribed. Cancelling the incoming subscription"); + subscription.cancel(); + return; + } + this.subscription = s; + s.request(1); + returnFuture.whenComplete((r, t) -> { + if (t != null) { + s.cancel(); + multipartUploadHelper.cancelingOtherOngoingRequests(futures, t); + } + }); + } + + @Override + public void onNext(AsyncRequestBody asyncRequestBody) { + log.trace(() -> "Received asyncRequestBody " + asyncRequestBody.contentLength()); + asyncRequestBodyInFlight.incrementAndGet(); + UploadPartRequest uploadRequest = + SdkPojoConversionUtils.toUploadPartRequest(putObjectRequest, + partNumber.getAndIncrement(), + uploadId); + + Consumer completedPartConsumer = completedPart -> completedParts.set(completedPart.partNumber() - 1, + completedPart); + multipartUploadHelper.sendIndividualUploadPartRequest(uploadId, completedPartConsumer, futures, + Pair.of(uploadRequest, asyncRequestBody)) + .whenComplete((r, t) -> { + if (t != null) { + if (failureActionInitiated.compareAndSet(false, true)) { + multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, + putObjectRequest); + } + } else { + completeMultipartUploadIfFinish(asyncRequestBodyInFlight.decrementAndGet()); + } + }); + subscription.request(1); + } + + @Override + public void onError(Throwable t) { + log.debug(() -> "Received onError ", t); + if (failureActionInitiated.compareAndSet(false, true)) { + multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest); + } + } + + @Override + public void onComplete() { + log.debug(() -> "Received onComplete()"); + isDone = true; + completeMultipartUploadIfFinish(asyncRequestBodyInFlight.get()); + } + + private void completeMultipartUploadIfFinish(int requestsInFlight) { + if (isDone && requestsInFlight == 0 && completedMultipartInitiated.compareAndSet(false, true)) { + CompletedPart[] parts = + IntStream.range(0, completedParts.length()) + .mapToObj(completedParts::get) + .toArray(CompletedPart[]::new); + multipartUploadHelper.completeMultipartUpload(returnFuture, uploadId, parts, putObjectRequest); + } + } + + } +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java new file mode 100644 index 000000000000..d2034b4b4e94 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java @@ -0,0 +1,247 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + + +import java.util.Collection; +import java.util.Comparator; +import java.util.Queue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.SdkPublisher; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Pair; + +/** + * An internal helper class that uploads streams with unknown content length. + */ +@SdkInternalApi +public final class UploadWithUnknownContentLengthHelper { + private static final Logger log = Logger.loggerFor(UploadWithUnknownContentLengthHelper.class); + + private final S3AsyncClient s3AsyncClient; + private final long partSizeInBytes; + private final GenericMultipartHelper genericMultipartHelper; + + private final long maxMemoryUsageInBytes; + private final long multipartUploadThresholdInBytes; + + private final MultipartUploadHelper multipartUploadHelper; + + public UploadWithUnknownContentLengthHelper(S3AsyncClient s3AsyncClient, + long partSizeInBytes, + long multipartUploadThresholdInBytes, + long maxMemoryUsageInBytes) { + this.s3AsyncClient = s3AsyncClient; + this.partSizeInBytes = partSizeInBytes; + this.genericMultipartHelper = new GenericMultipartHelper<>(s3AsyncClient, + SdkPojoConversionUtils::toAbortMultipartUploadRequest, + SdkPojoConversionUtils::toPutObjectResponse); + this.maxMemoryUsageInBytes = maxMemoryUsageInBytes; + this.multipartUploadThresholdInBytes = multipartUploadThresholdInBytes; + this.multipartUploadHelper = new MultipartUploadHelper(s3AsyncClient, partSizeInBytes, multipartUploadThresholdInBytes, + maxMemoryUsageInBytes); + } + + public CompletableFuture uploadObject(PutObjectRequest putObjectRequest, + AsyncRequestBody asyncRequestBody) { + CompletableFuture returnFuture = new CompletableFuture<>(); + + SdkPublisher splitAsyncRequestBodyResponse = + asyncRequestBody.split(partSizeInBytes, + maxMemoryUsageInBytes); + + splitAsyncRequestBodyResponse.subscribe(new UnknownContentLengthAsyncRequestBodySubscriber(partSizeInBytes, + putObjectRequest, + returnFuture)); + return returnFuture; + } + + private class UnknownContentLengthAsyncRequestBodySubscriber implements Subscriber { + /** + * Indicates whether this is the first async request body or not. + */ + private final AtomicBoolean isFirstAsyncRequestBody = new AtomicBoolean(true); + + /** + * Indicates whether CreateMultipartUpload has been initiated or not + */ + private final AtomicBoolean createMultipartUploadInitiated = new AtomicBoolean(false); + + /** + * Indicates whether CompleteMultipart has been initiated or not. + */ + private final AtomicBoolean completedMultipartInitiated = new AtomicBoolean(false); + + /** + * The number of AsyncRequestBody has been received but yet to be processed + */ + private final AtomicInteger asyncRequestBodyInFlight = new AtomicInteger(0); + + private final AtomicBoolean failureActionInitiated = new AtomicBoolean(false); + + private AtomicInteger partNumber = new AtomicInteger(1); + + private final Queue completedParts = new ConcurrentLinkedQueue<>(); + private final Collection> futures = new ConcurrentLinkedQueue<>(); + + private final CompletableFuture uploadIdFuture = new CompletableFuture<>(); + + private final long maximumChunkSizeInByte; + private final PutObjectRequest putObjectRequest; + private final CompletableFuture returnFuture; + private Subscription subscription; + private AsyncRequestBody firstRequestBody; + + private String uploadId; + private volatile boolean isDone; + + UnknownContentLengthAsyncRequestBodySubscriber(long maximumChunkSizeInByte, + PutObjectRequest putObjectRequest, + CompletableFuture returnFuture) { + this.maximumChunkSizeInByte = maximumChunkSizeInByte; + this.putObjectRequest = putObjectRequest; + this.returnFuture = returnFuture; + } + + @Override + public void onSubscribe(Subscription s) { + if (this.subscription != null) { + log.warn(() -> "The subscriber has already been subscribed. Cancelling the incoming subscription"); + subscription.cancel(); + return; + } + this.subscription = s; + s.request(1); + returnFuture.whenComplete((r, t) -> { + if (t != null) { + s.cancel(); + multipartUploadHelper.cancelingOtherOngoingRequests(futures, t); + } + }); + } + + @Override + public void onNext(AsyncRequestBody asyncRequestBody) { + log.trace(() -> "Received asyncRequestBody " + asyncRequestBody.contentLength()); + asyncRequestBodyInFlight.incrementAndGet(); + + if (isFirstAsyncRequestBody.compareAndSet(true, false)) { + log.trace(() -> "Received first async request body"); + // If this is the first AsyncRequestBody received, request another one because we don't know if there is more + firstRequestBody = asyncRequestBody; + subscription.request(1); + return; + } + + // If there are more than 1 AsyncRequestBodies, then we know we need to upload this + // object using MPU + if (createMultipartUploadInitiated.compareAndSet(false, true)) { + log.debug(() -> "Starting the upload as multipart upload request"); + CompletableFuture createMultipartUploadFuture = + multipartUploadHelper.createMultipartUpload(putObjectRequest, returnFuture); + + createMultipartUploadFuture.whenComplete((createMultipartUploadResponse, throwable) -> { + if (throwable != null) { + genericMultipartHelper.handleException(returnFuture, () -> "Failed to initiate multipart upload", + throwable); + subscription.cancel(); + } else { + uploadId = createMultipartUploadResponse.uploadId(); + uploadIdFuture.complete(uploadId); + log.debug(() -> "Initiated a new multipart upload, uploadId: " + uploadId); + + sendUploadPartRequest(uploadId, firstRequestBody); + sendUploadPartRequest(uploadId, asyncRequestBody); + } + }); + CompletableFutureUtils.forwardExceptionTo(returnFuture, createMultipartUploadFuture); + } else { + uploadIdFuture.whenComplete((r, t) -> { + sendUploadPartRequest(uploadId, asyncRequestBody); + }); + } + } + + private void sendUploadPartRequest(String uploadId, AsyncRequestBody asyncRequestBody) { + multipartUploadHelper.sendIndividualUploadPartRequest(uploadId, completedParts::add, futures, + uploadPart(asyncRequestBody)) + .whenComplete((r, t) -> { + if (t != null) { + if (failureActionInitiated.compareAndSet(false, true)) { + multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest); + } + } else { + completeMultipartUploadIfFinish(asyncRequestBodyInFlight.decrementAndGet()); + } + }); + synchronized (this) { + subscription.request(1); + }; + } + + private Pair uploadPart(AsyncRequestBody asyncRequestBody) { + UploadPartRequest uploadRequest = + SdkPojoConversionUtils.toUploadPartRequest(putObjectRequest, + partNumber.getAndIncrement(), + uploadId); + return Pair.of(uploadRequest, asyncRequestBody); + } + + @Override + public void onError(Throwable t) { + log.debug(() -> "Received onError() ", t); + if (failureActionInitiated.compareAndSet(false, true)) { + multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest); + } + } + + @Override + public void onComplete() { + log.debug(() -> "Received onComplete()"); + // If CreateMultipartUpload has not been initiated at this point, we know this is a single object upload + if (createMultipartUploadInitiated.get() == false) { + log.debug(() -> "Starting the upload as a single object upload request"); + multipartUploadHelper.uploadInOneChunk(putObjectRequest, firstRequestBody, returnFuture); + } else { + isDone = true; + completeMultipartUploadIfFinish(asyncRequestBodyInFlight.get()); + } + } + + private void completeMultipartUploadIfFinish(int requestsInFlight) { + if (isDone && requestsInFlight == 0 && completedMultipartInitiated.compareAndSet(false, true)) { + CompletedPart[] parts = completedParts.stream() + .sorted(Comparator.comparingInt(CompletedPart::partNumber)) + .toArray(CompletedPart[]::new); + multipartUploadHelper.completeMultipartUpload(returnFuture, uploadId, parts, putObjectRequest); + } + } + } +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelperTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelperTest.java similarity index 58% rename from services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelperTest.java rename to services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelperTest.java index 1ea17d4ba967..11d54a73fb72 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelperTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelperTest.java @@ -26,22 +26,28 @@ import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.stubSuccessfulCompleteMultipartCall; import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.List; +import java.util.Optional; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; +import java.util.stream.Stream; +import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import org.mockito.stubbing.OngoingStubbing; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.s3.S3AsyncClient; @@ -58,7 +64,7 @@ import software.amazon.awssdk.testutils.RandomTempFile; import software.amazon.awssdk.utils.CompletableFutureUtils; -public class MultipartUploadHelperTest { +public class UploadObjectHelperTest { private static final String BUCKET = "bucket"; private static final String KEY = "key"; @@ -70,7 +76,7 @@ public class MultipartUploadHelperTest { private static final String UPLOAD_ID = "1234"; private static RandomTempFile testFile; - private MultipartUploadHelper uploadHelper; + private UploadObjectHelper uploadHelper; private S3AsyncClient s3AsyncClient; @BeforeAll @@ -83,15 +89,20 @@ public static void afterAll() throws Exception { testFile.delete(); } + public static Stream asyncRequestBody() { + return Stream.of(new UnknownContentLengthAsyncRequestBody(AsyncRequestBody.fromFile(testFile)), + AsyncRequestBody.fromFile(testFile)); + } + @BeforeEach public void beforeEach() { s3AsyncClient = Mockito.mock(S3AsyncClient.class); - uploadHelper = new MultipartUploadHelper(s3AsyncClient, PART_SIZE, THRESHOLD, PART_SIZE * 2); + uploadHelper = new UploadObjectHelper(s3AsyncClient, PART_SIZE, THRESHOLD, PART_SIZE * 2); } @ParameterizedTest @ValueSource(longs = {THRESHOLD, PART_SIZE, THRESHOLD - 1, PART_SIZE - 1}) - public void uploadObject_doesNotExceedThresholdAndPartSize_shouldUploadInOneChunk(long contentLength) { + void uploadObject_contentLengthDoesNotExceedThresholdAndPartSize_shouldUploadInOneChunk(long contentLength) { PutObjectRequest putObjectRequest = putObjectRequest(contentLength); AsyncRequestBody asyncRequestBody = Mockito.mock(AsyncRequestBody.class); @@ -102,15 +113,31 @@ public void uploadObject_doesNotExceedThresholdAndPartSize_shouldUploadInOneChun Mockito.verify(s3AsyncClient).putObject(putObjectRequest, asyncRequestBody); } - @Test - public void uploadObject_contentLengthExceedThresholdAndPartSize_shouldUseMPU() { + @ParameterizedTest + @ValueSource(longs = {PART_SIZE, PART_SIZE - 1}) + void uploadObject_unKnownContentLengthDoesNotExceedPartSize_shouldUploadInOneChunk(long contentLength) { + PutObjectRequest putObjectRequest = putObjectRequest(contentLength); + AsyncRequestBody asyncRequestBody = + new UnknownContentLengthAsyncRequestBody(AsyncRequestBody.fromBytes(RandomStringUtils.randomAscii(Math.toIntExact(contentLength)) + .getBytes(StandardCharsets.UTF_8))); + + CompletableFuture completedFuture = + CompletableFuture.completedFuture(PutObjectResponse.builder().build()); + when(s3AsyncClient.putObject(putObjectRequest, asyncRequestBody)).thenReturn(completedFuture); + uploadHelper.uploadObject(putObjectRequest, asyncRequestBody).join(); + Mockito.verify(s3AsyncClient).putObject(putObjectRequest, asyncRequestBody); + } + + @ParameterizedTest + @MethodSource("asyncRequestBody") + void uploadObject_contentLengthExceedThresholdAndPartSize_shouldUseMPU(AsyncRequestBody asyncRequestBody) { PutObjectRequest putObjectRequest = putObjectRequest(null); MpuTestUtils.stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); stubSuccessfulUploadPartCalls(); stubSuccessfulCompleteMultipartCall(BUCKET, KEY, s3AsyncClient); - uploadHelper.uploadObject(putObjectRequest, AsyncRequestBody.fromFile(testFile)).join(); + uploadHelper.uploadObject(putObjectRequest, asyncRequestBody).join(); ArgumentCaptor requestArgumentCaptor = ArgumentCaptor.forClass(UploadPartRequest.class); ArgumentCaptor requestBodyArgumentCaptor = ArgumentCaptor.forClass(AsyncRequestBody.class); verify(s3AsyncClient, times(4)).uploadPart(requestArgumentCaptor.capture(), @@ -139,8 +166,9 @@ public void uploadObject_contentLengthExceedThresholdAndPartSize_shouldUseMPU() /** * The second part failed, it should cancel ongoing part(first part). */ - @Test - void mpu_onePartFailed_shouldFailOtherPartsAndAbort() { + @ParameterizedTest + @MethodSource("asyncRequestBody") + void mpu_onePartFailed_shouldFailOtherPartsAndAbort(AsyncRequestBody asyncRequestBody) { PutObjectRequest putObjectRequest = putObjectRequest(MPU_CONTENT_SIZE); MpuTestUtils.stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); @@ -157,7 +185,7 @@ void mpu_onePartFailed_shouldFailOtherPartsAndAbort() { .thenReturn(CompletableFuture.completedFuture(AbortMultipartUploadResponse.builder().build())); CompletableFuture future = uploadHelper.uploadObject(putObjectRequest, - AsyncRequestBody.fromFile(testFile)); + asyncRequestBody); assertThatThrownBy(future::join).hasMessageContaining("Failed to send multipart upload requests").hasRootCause(exception); @@ -172,12 +200,16 @@ void mpu_onePartFailed_shouldFailOtherPartsAndAbort() { ongoingRequest.get(1, TimeUnit.MILLISECONDS); fail("no exception thrown"); } catch (Exception e) { - assertThat(e.getCause()).hasMessageContaining("request failed"); + assertThat(e.getCause()).hasMessageContaining("Failed to send multipart upload requests").hasRootCause(exception); } } + /** + * This test is not parameterized because for unknown content length, the progress is nondeterministic. For example, we + * don't know if it has created multipart upload when we cancel the future. + */ @Test - void upload_cancelResponseFuture_shouldPropagate() { + void upload_knownContentLengthCancelResponseFuture_shouldCancelCreateMultipart() { PutObjectRequest putObjectRequest = putObjectRequest(null); CompletableFuture createMultipartFuture = new CompletableFuture<>(); @@ -194,7 +226,48 @@ void upload_cancelResponseFuture_shouldPropagate() { } @Test - public void uploadObject_completeMultipartFailed_shouldFailAndAbort() { + void upload_knownContentLengthCancelResponseFuture_shouldCancelUploadPart() { + PutObjectRequest putObjectRequest = putObjectRequest(null); + + CompletableFuture createMultipartFuture = new CompletableFuture<>(); + + MpuTestUtils.stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); + + CompletableFuture ongoingRequest = new CompletableFuture<>(); + + when(s3AsyncClient.uploadPart(any(UploadPartRequest.class), + any(AsyncRequestBody.class))).thenReturn(ongoingRequest); + + CompletableFuture future = + uploadHelper.uploadObject(putObjectRequest, AsyncRequestBody.fromFile(testFile)); + + future.cancel(true); + + assertThat(ongoingRequest).isCancelled(); + } + + @ParameterizedTest + @MethodSource("asyncRequestBody") + void uploadObject_createMultipartUploadFailed_shouldFail(AsyncRequestBody asyncRequestBody) { + PutObjectRequest putObjectRequest = putObjectRequest(null); + + SdkClientException exception = SdkClientException.create("CompleteMultipartUpload failed"); + + CompletableFuture createMultipartUploadFuture = + CompletableFutureUtils.failedFuture(exception); + + when(s3AsyncClient.createMultipartUpload(any(CreateMultipartUploadRequest.class))) + .thenReturn(createMultipartUploadFuture); + + CompletableFuture future = uploadHelper.uploadObject(putObjectRequest, + asyncRequestBody); + assertThatThrownBy(future::join).hasMessageContaining("Failed to initiate multipart upload") + .hasRootCause(exception); + } + + @ParameterizedTest + @MethodSource("asyncRequestBody") + void uploadObject_completeMultipartFailed_shouldFailAndAbort(AsyncRequestBody asyncRequestBody) { PutObjectRequest putObjectRequest = putObjectRequest(null); MpuTestUtils.stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); @@ -211,8 +284,31 @@ public void uploadObject_completeMultipartFailed_shouldFailAndAbort() { when(s3AsyncClient.abortMultipartUpload(any(AbortMultipartUploadRequest.class))) .thenReturn(CompletableFuture.completedFuture(AbortMultipartUploadResponse.builder().build())); - CompletableFuture future = uploadHelper.uploadObject(putObjectRequest, AsyncRequestBody.fromFile(testFile)); - assertThatThrownBy(future::join).hasMessageContaining("Failed to send multipart requests").hasRootCause(exception); + CompletableFuture future = uploadHelper.uploadObject(putObjectRequest, + asyncRequestBody); + assertThatThrownBy(future::join).hasMessageContaining("Failed to send multipart requests") + .hasRootCause(exception); + } + + @ParameterizedTest() + @ValueSource(booleans = {false, true}) + void uploadObject_requestBodyOnError_shouldFailAndAbort(boolean contentLengthKnown) { + PutObjectRequest putObjectRequest = putObjectRequest(null); + Exception exception = new RuntimeException("error"); + + Long contentLength = contentLengthKnown ? MPU_CONTENT_SIZE : null; + ErroneousAsyncRequestBody erroneousAsyncRequestBody = + new ErroneousAsyncRequestBody(contentLength, exception); + MpuTestUtils.stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); + stubSuccessfulUploadPartCalls(); + + when(s3AsyncClient.abortMultipartUpload(any(AbortMultipartUploadRequest.class))) + .thenReturn(CompletableFuture.completedFuture(AbortMultipartUploadResponse.builder().build())); + + CompletableFuture future = uploadHelper.uploadObject(putObjectRequest, + erroneousAsyncRequestBody); + assertThatThrownBy(future::join).hasMessageContaining("Failed to send multipart upload requests") + .hasRootCause(exception); } private static PutObjectRequest putObjectRequest(Long contentLength) { @@ -256,4 +352,61 @@ public CompletableFuture answer(InvocationOnMock invocationO }); } + private static class UnknownContentLengthAsyncRequestBody implements AsyncRequestBody { + private final AsyncRequestBody delegate; + private volatile boolean cancelled; + + public UnknownContentLengthAsyncRequestBody(AsyncRequestBody asyncRequestBody) { + this.delegate = asyncRequestBody; + } + + @Override + public Optional contentLength() { + return Optional.empty(); + } + + @Override + public void subscribe(Subscriber s) { + delegate.subscribe(s); + } + } + + private static class ErroneousAsyncRequestBody implements AsyncRequestBody { + private volatile boolean isDone; + private final Long contentLength; + private final Exception exception; + + private ErroneousAsyncRequestBody(Long contentLength, Exception exception) { + this.contentLength = contentLength; + this.exception = exception; + } + + @Override + public Optional contentLength() { + return Optional.ofNullable(contentLength); + } + + + @Override + public void subscribe(Subscriber s) { + s.onSubscribe(new Subscription() { + @Override + public void request(long n) { + if (isDone) { + return; + } + isDone = true; + s.onNext(ByteBuffer.wrap(RandomStringUtils.randomAscii(Math.toIntExact(PART_SIZE)).getBytes(StandardCharsets.UTF_8))); + s.onNext(ByteBuffer.wrap(RandomStringUtils.randomAscii(Math.toIntExact(PART_SIZE)).getBytes(StandardCharsets.UTF_8))); + s.onError(exception); + + } + + @Override + public void cancel() { + } + }); + + } + } } From d255b1f584ed7898f89703a9102423788db773bd Mon Sep 17 00:00:00 2001 From: Zoe Wang <33073555+zoewangg@users.noreply.github.com> Date: Fri, 28 Jul 2023 14:59:34 -0700 Subject: [PATCH 09/13] Create a configuration class for SdkPublisher#split (#4236) --- .../awssdk/core/async/AsyncRequestBody.java | 35 +++-- .../AsyncRequestBodySplitConfiguration.java | 141 ++++++++++++++++++ .../internal/async/SplittingPublisher.java | 46 ++---- .../AsyncRequestBodyConfigurationTest.java | 58 +++++++ .../core/async/AsyncRequestBodyTest.java | 21 --- .../async/SplittingPublisherTest.java | 22 ++- .../UploadWithKnownContentLengthHelper.java | 3 +- .../UploadWithUnknownContentLengthHelper.java | 4 +- 8 files changed, 256 insertions(+), 74 deletions(-) create mode 100644 core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBodySplitConfiguration.java create mode 100644 core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyConfigurationTest.java diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java index 3c6adb8fdbac..4c7d70ab7553 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java @@ -25,6 +25,7 @@ import java.util.Arrays; import java.util.Optional; import java.util.concurrent.ExecutorService; +import java.util.function.Consumer; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import software.amazon.awssdk.annotations.SdkPublicApi; @@ -405,34 +406,36 @@ static AsyncRequestBody empty() { /** * Converts this {@link AsyncRequestBody} to a publisher of {@link AsyncRequestBody}s, each of which publishes a specific - * portion of the original data, based on the configured {code chunkSizeInBytes}. + * portion of the original data, based on the provided {@link AsyncRequestBodySplitConfiguration}. The default chunk size + * is 2MB and the default buffer size is 8MB. * *

* If content length of this {@link AsyncRequestBody} is present, each divided {@link AsyncRequestBody} is delivered to the * subscriber right after it's initialized. *

- * // TODO: API Surface Area review: should we make this behavior configurable? * If content length is null, it is sent after the entire content for that chunk is buffered. * In this case, the configured {@code maxMemoryUsageInBytes} must be larger than or equal to {@code chunkSizeInBytes}. * - * @param chunkSizeInBytes the size for each divided chunk. The last chunk may be smaller than the configured size. - * @param maxMemoryUsageInBytes the max memory the SDK will use to buffer the content - * @return SplitAsyncRequestBodyResult + * @see AsyncRequestBodySplitConfiguration */ - default SdkPublisher split(long chunkSizeInBytes, long maxMemoryUsageInBytes) { - Validate.isPositive(chunkSizeInBytes, "chunkSizeInBytes"); - Validate.isPositive(maxMemoryUsageInBytes, "maxMemoryUsageInBytes"); - - if (!contentLength().isPresent()) { - Validate.isTrue(maxMemoryUsageInBytes >= chunkSizeInBytes, - "maxMemoryUsageInBytes must be larger than or equal to " + - "chunkSizeInBytes if the content length is unknown"); - } + default SdkPublisher split(AsyncRequestBodySplitConfiguration splitConfiguration) { + Validate.notNull(splitConfiguration, "splitConfiguration"); return SplittingPublisher.builder() .asyncRequestBody(this) - .chunkSizeInBytes(chunkSizeInBytes) - .maxMemoryUsageInBytes(maxMemoryUsageInBytes) + .chunkSizeInBytes(splitConfiguration.chunkSizeInBytes()) + .bufferSizeInBytes(splitConfiguration.bufferSizeInBytes()) .build(); } + + /** + * This is a convenience method that passes an instance of the {@link AsyncRequestBodySplitConfiguration} builder, + * avoiding the need to create one manually via {@link AsyncRequestBodySplitConfiguration#builder()}. + * + * @see #split(AsyncRequestBodySplitConfiguration) + */ + default SdkPublisher split(Consumer splitConfiguration) { + Validate.notNull(splitConfiguration, "splitConfiguration"); + return split(AsyncRequestBodySplitConfiguration.builder().applyMutation(splitConfiguration).build()); + } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBodySplitConfiguration.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBodySplitConfiguration.java new file mode 100644 index 000000000000..fe51f33b4ff3 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBodySplitConfiguration.java @@ -0,0 +1,141 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.async; + +import java.util.Objects; +import software.amazon.awssdk.annotations.SdkPublicApi; +import software.amazon.awssdk.utils.Validate; +import software.amazon.awssdk.utils.builder.CopyableBuilder; +import software.amazon.awssdk.utils.builder.ToCopyableBuilder; + +/** + * Configuration options for {@link AsyncRequestBody#split} to configure how the SDK + * should split an {@link SdkPublisher}. + */ +@SdkPublicApi +public final class AsyncRequestBodySplitConfiguration implements ToCopyableBuilder { + private final Long chunkSizeInBytes; + private final Long bufferSizeInBytes; + + private AsyncRequestBodySplitConfiguration(DefaultBuilder builder) { + this.chunkSizeInBytes = Validate.isPositiveOrNull(builder.chunkSizeInBytes, "chunkSizeInBytes"); + this.bufferSizeInBytes = Validate.isPositiveOrNull(builder.bufferSizeInBytes, "bufferSizeInBytes"); + } + + /** + * The configured chunk size for each divided {@link AsyncRequestBody}. + */ + public Long chunkSizeInBytes() { + return chunkSizeInBytes; + } + + /** + * The configured maximum buffer size the SDK will use to buffer the content from the source {@link SdkPublisher}. + */ + public Long bufferSizeInBytes() { + return bufferSizeInBytes; + } + + /** + * Create a {@link Builder}, used to create a {@link AsyncRequestBodySplitConfiguration}. + */ + public static Builder builder() { + return new DefaultBuilder(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + AsyncRequestBodySplitConfiguration that = (AsyncRequestBodySplitConfiguration) o; + + if (!Objects.equals(chunkSizeInBytes, that.chunkSizeInBytes)) { + return false; + } + return Objects.equals(bufferSizeInBytes, that.bufferSizeInBytes); + } + + @Override + public int hashCode() { + int result = chunkSizeInBytes != null ? chunkSizeInBytes.hashCode() : 0; + result = 31 * result + (bufferSizeInBytes != null ? bufferSizeInBytes.hashCode() : 0); + return result; + } + + @Override + public AsyncRequestBodySplitConfiguration.Builder toBuilder() { + return new DefaultBuilder(this); + } + + public interface Builder extends CopyableBuilder { + + /** + * Configures the size for each divided chunk. The last chunk may be smaller than the configured size. The default value + * is 2MB. + * + * @param chunkSizeInBytes the chunk size in bytes + * @return This object for method chaining. + */ + Builder chunkSizeInBytes(Long chunkSizeInBytes); + + /** + * The maximum buffer size the SDK will use to buffer the content from the source {@link SdkPublisher}. The default value + * is 8MB. + * + * @param bufferSizeInBytes the buffer size in bytes + * @return This object for method chaining. + */ + Builder bufferSizeInBytes(Long bufferSizeInBytes); + } + + private static final class DefaultBuilder implements Builder { + private Long chunkSizeInBytes; + private Long bufferSizeInBytes; + + private DefaultBuilder(AsyncRequestBodySplitConfiguration asyncRequestBodySplitConfiguration) { + this.chunkSizeInBytes = asyncRequestBodySplitConfiguration.chunkSizeInBytes; + this.bufferSizeInBytes = asyncRequestBodySplitConfiguration.bufferSizeInBytes; + } + + private DefaultBuilder() { + + } + + @Override + public Builder chunkSizeInBytes(Long chunkSizeInBytes) { + this.chunkSizeInBytes = chunkSizeInBytes; + return this; + } + + @Override + public Builder bufferSizeInBytes(Long bufferSizeInBytes) { + this.bufferSizeInBytes = bufferSizeInBytes; + return this; + } + + @Override + public AsyncRequestBodySplitConfiguration build() { + return new AsyncRequestBodySplitConfiguration(this); + } + } +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java index e18f9944a09e..43f2e10ff192 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java @@ -39,17 +39,25 @@ @SdkInternalApi public class SplittingPublisher implements SdkPublisher { private static final Logger log = Logger.loggerFor(SplittingPublisher.class); + private static final long DEFAULT_CHUNK_SIZE = 2 * 1024 * 1024L; + private static final long DEFAULT_BUFFER_SIZE = DEFAULT_CHUNK_SIZE * 4; private final AsyncRequestBody upstreamPublisher; private final SplittingSubscriber splittingSubscriber; private final SimplePublisher downstreamPublisher = new SimplePublisher<>(); private final long chunkSizeInBytes; - private final long maxMemoryUsageInBytes; + private final long bufferSizeInBytes; private SplittingPublisher(Builder builder) { this.upstreamPublisher = Validate.paramNotNull(builder.asyncRequestBody, "asyncRequestBody"); - this.chunkSizeInBytes = Validate.isPositive(builder.chunkSizeInBytes, "chunkSizeInBytes"); + this.chunkSizeInBytes = builder.chunkSizeInBytes == null ? DEFAULT_CHUNK_SIZE : builder.chunkSizeInBytes; + this.bufferSizeInBytes = builder.bufferSizeInBytes == null ? DEFAULT_BUFFER_SIZE : builder.bufferSizeInBytes; this.splittingSubscriber = new SplittingSubscriber(upstreamPublisher.contentLength().orElse(null)); - this.maxMemoryUsageInBytes = Validate.isPositive(builder.maxMemoryUsageInBytes, "maxMemoryUsageInBytes"); + + if (!upstreamPublisher.contentLength().isPresent()) { + Validate.isTrue(bufferSizeInBytes >= chunkSizeInBytes, + "bufferSizeInBytes must be larger than or equal to " + + "chunkSizeInBytes if the content length is unknown"); + } } public static Builder builder() { @@ -213,7 +221,7 @@ private void maybeRequestMoreUpstreamData() { } private boolean shouldRequestMoreData(long buffered) { - return buffered == 0 || buffered + byteBufferSizeHint <= maxMemoryUsageInBytes; + return buffered == 0 || buffered + byteBufferSizeHint <= bufferSizeInBytes; } private Long totalDataRemaining() { @@ -289,42 +297,20 @@ private void addDataBuffered(int length) { public static final class Builder { private AsyncRequestBody asyncRequestBody; private Long chunkSizeInBytes; - private Long maxMemoryUsageInBytes; + private Long bufferSizeInBytes; - /** - * Configures the asyncRequestBody to split - * - * @param asyncRequestBody The new asyncRequestBody value. - * @return This object for method chaining. - */ public Builder asyncRequestBody(AsyncRequestBody asyncRequestBody) { this.asyncRequestBody = asyncRequestBody; return this; } - /** - * Configures the size of the chunk for each {@link AsyncRequestBody} to publish - * - * @param chunkSizeInBytes The new chunkSizeInBytes value. - * @return This object for method chaining. - */ - public Builder chunkSizeInBytes(long chunkSizeInBytes) { + public Builder chunkSizeInBytes(Long chunkSizeInBytes) { this.chunkSizeInBytes = chunkSizeInBytes; return this; } - /** - * Sets the maximum memory usage in bytes. - * - * @param maxMemoryUsageInBytes The new maxMemoryUsageInBytes value. - * @return This object for method chaining. - */ - // TODO: max memory usage might not be the best name, since we may technically go a little above this limit when we add - // on a new byte buffer. But we don't know for sure what the size of a buffer we request will be (we do use the size - // for the last byte buffer as a hint), so I don't think we can have a truly accurate max. Maybe we call it minimum - // buffer size instead? - public Builder maxMemoryUsageInBytes(long maxMemoryUsageInBytes) { - this.maxMemoryUsageInBytes = maxMemoryUsageInBytes; + public Builder bufferSizeInBytes(Long bufferSizeInBytes) { + this.bufferSizeInBytes = bufferSizeInBytes; return this; } diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyConfigurationTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyConfigurationTest.java new file mode 100644 index 000000000000..8b8f78f2b5e9 --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyConfigurationTest.java @@ -0,0 +1,58 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.async; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import nl.jqno.equalsverifier.EqualsVerifier; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +public class AsyncRequestBodyConfigurationTest { + + @Test + void equalsHashCode() { + EqualsVerifier.forClass(AsyncRequestBodySplitConfiguration.class) + .verify(); + } + + @ParameterizedTest + @ValueSource(longs = {0, -1}) + void nonPositiveValue_shouldThrowException(long size) { + assertThatThrownBy(() -> + AsyncRequestBodySplitConfiguration.builder() + .chunkSizeInBytes(size) + .build()) + .hasMessageContaining("must be positive"); + assertThatThrownBy(() -> + AsyncRequestBodySplitConfiguration.builder() + .bufferSizeInBytes(size) + .build()) + .hasMessageContaining("must be positive"); + } + + @Test + void toBuilder_shouldCopyAllFields() { + AsyncRequestBodySplitConfiguration config = AsyncRequestBodySplitConfiguration.builder() + .bufferSizeInBytes(1L) + .chunkSizeInBytes(2L) + .build(); + + assertThat(config.toBuilder().build()).isEqualTo(config); + } +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyTest.java index 4d4bb42e06e0..cdd87822e3d4 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyTest.java @@ -356,25 +356,4 @@ void publisherConstructorHasCorrectContentType() { AsyncRequestBody requestBody = AsyncRequestBody.fromPublisher(bodyPublisher); assertEquals(Mimetype.MIMETYPE_OCTET_STREAM, requestBody.contentType()); } - - @Test - public void split_nonPositiveInput_shouldThrowException() { - AsyncRequestBody body = AsyncRequestBody.fromString("test"); - assertThatThrownBy(() -> body.split(0, 4)).hasMessageContaining("must be positive"); - assertThatThrownBy(() -> body.split(-1, 4)).hasMessageContaining("must be positive"); - assertThatThrownBy(() -> body.split(5, 0)).hasMessageContaining("must be positive"); - assertThatThrownBy(() -> body.split(5, -1)).hasMessageContaining("must be positive"); - } - - @Test - public void split_contentUnknownMaxMemorySmallerThanChunkSize_shouldThrowException() { - AsyncRequestBody body = AsyncRequestBody.fromPublisher(new Publisher() { - @Override - public void subscribe(Subscriber s) { - - } - }); - assertThatThrownBy(() -> body.split(10, 4)) - .hasMessageContaining("must be larger than or equal"); - } } diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java index 368c403dbf88..0966ea6eb76f 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java @@ -15,6 +15,7 @@ package software.amazon.awssdk.core.internal.async; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static software.amazon.awssdk.utils.FunctionalUtils.invokeSafely; @@ -39,6 +40,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import software.amazon.awssdk.core.async.AsyncRequestBody; @@ -66,6 +68,18 @@ public static void afterAll() throws Exception { testFile.delete(); } + @Test + public void split_contentUnknownMaxMemorySmallerThanChunkSize_shouldThrowException() { + AsyncRequestBody body = AsyncRequestBody.fromPublisher(s -> { + }); + assertThatThrownBy(() -> SplittingPublisher.builder() + .asyncRequestBody(body) + .chunkSizeInBytes(10L) + .bufferSizeInBytes(5L) + .build()) + .hasMessageContaining("must be larger than or equal"); + } + @ParameterizedTest @ValueSource(ints = {CHUNK_SIZE, CHUNK_SIZE * 2 - 1, CHUNK_SIZE * 2}) void differentChunkSize_shouldSplitAsyncRequestBodyCorrectly(int chunkSize) throws Exception { @@ -94,8 +108,8 @@ public Optional contentLength() { }; SplittingPublisher splittingPublisher = SplittingPublisher.builder() .asyncRequestBody(asyncRequestBody) - .chunkSizeInBytes(CHUNK_SIZE) - .maxMemoryUsageInBytes(10L) + .chunkSizeInBytes((long) CHUNK_SIZE) + .bufferSizeInBytes(10L) .build(); @@ -136,8 +150,8 @@ public Optional contentLength() { private static void verifySplitContent(AsyncRequestBody asyncRequestBody, int chunkSize) throws Exception { SplittingPublisher splittingPublisher = SplittingPublisher.builder() .asyncRequestBody(asyncRequestBody) - .chunkSizeInBytes(chunkSize) - .maxMemoryUsageInBytes((long) chunkSize * 4) + .chunkSizeInBytes((long) chunkSize) + .bufferSizeInBytes((long) chunkSize * 4) .build(); List> futures = new ArrayList<>(); diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java index e8bef01ab81b..5e1a41da4d86 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java @@ -119,7 +119,8 @@ private void doUploadInParts(Pair request, MpuRequestContext mpuRequestContext = new MpuRequestContext(request, contentLength, optimalPartSize, uploadId); request.right() - .split(mpuRequestContext.partSize, maxMemoryUsageInBytes) + .split(b -> b.chunkSizeInBytes(mpuRequestContext.partSize) + .bufferSizeInBytes(maxMemoryUsageInBytes)) .subscribe(new KnownContentLengthAsyncRequestBodySubscriber(mpuRequestContext, returnFuture)); } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java index d2034b4b4e94..fa8be1e0c6f3 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java @@ -74,8 +74,8 @@ public CompletableFuture uploadObject(PutObjectRequest putObj CompletableFuture returnFuture = new CompletableFuture<>(); SdkPublisher splitAsyncRequestBodyResponse = - asyncRequestBody.split(partSizeInBytes, - maxMemoryUsageInBytes); + asyncRequestBody.split(b -> b.chunkSizeInBytes(partSizeInBytes) + .bufferSizeInBytes(maxMemoryUsageInBytes)); splitAsyncRequestBodyResponse.subscribe(new UnknownContentLengthAsyncRequestBodySubscriber(partSizeInBytes, putObjectRequest, From 65cc9dfaf126e6cfb7b9bf9bc2faba520c6b9a90 Mon Sep 17 00:00:00 2001 From: Olivier L Applin Date: Tue, 1 Aug 2023 13:46:32 -0400 Subject: [PATCH 10/13] S3 Multipart API implementation (#4235) * Multipart API fix merge conflicts * getObject(...) throw UnsupportedOperationException * Use user agent for all requests in MultipartS3Client * MultipartS3AsyncClient javadoc + API_NAME private * use `maximumMemoryUsageInBytes` * fix problem with UserAgent, cleanup * move contextParam keys to S3AsyncClientDecorator * javadoc * more javadoc * Use 4x part size as default apiCallBufferSize --- .../customization/CustomizationConfig.java | 13 ++ .../customization/MultipartCustomization.java | 64 ++++++ .../amazon/awssdk/codegen/poet/ClassSpec.java | 2 +- .../poet/builder/AsyncClientBuilderClass.java | 64 ++++-- .../builder/AsyncClientBuilderInterface.java | 77 ++++++- .../services/s3/S3IntegrationTestBase.java | 2 +- .../S3ClientMultiPartCopyIntegrationTest.java | 14 +- ...ltipartClientPutObjectIntegrationTest.java | 31 +-- .../client/S3AsyncClientDecorator.java | 24 ++- .../internal/multipart/CopyObjectHelper.java | 10 +- .../multipart/GenericMultipartHelper.java | 4 +- .../multipart/MultipartS3AsyncClient.java | 65 +++++- .../multipart/MultipartUploadHelper.java | 4 +- .../UploadWithKnownContentLengthHelper.java | 4 + .../s3/multipart/MultipartConfiguration.java | 199 ++++++++++++++++++ .../codegen-resources/customization.config | 7 + .../MultipartClientUserAgentTest.java | 82 ++++++++ .../S3MultipartClientBuilderTest.java | 63 ++++++ 18 files changed, 668 insertions(+), 61 deletions(-) create mode 100644 codegen/src/main/java/software/amazon/awssdk/codegen/model/config/customization/MultipartCustomization.java create mode 100644 services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/MultipartConfiguration.java create mode 100644 services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartClientUserAgentTest.java create mode 100644 services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientBuilderTest.java diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/model/config/customization/CustomizationConfig.java b/codegen/src/main/java/software/amazon/awssdk/codegen/model/config/customization/CustomizationConfig.java index 596d44bcf14b..0bef67df7867 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/model/config/customization/CustomizationConfig.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/model/config/customization/CustomizationConfig.java @@ -227,6 +227,11 @@ public class CustomizationConfig { */ private String asyncClientDecorator; + /** + * Only for s3. A set of customization to related to multipart operations. + */ + private MultipartCustomization multipartCustomization; + /** * Whether to skip generating endpoint tests from endpoint-tests.json */ @@ -665,4 +670,12 @@ public Map getCustomClientContextParams() { public void setCustomClientContextParams(Map customClientContextParams) { this.customClientContextParams = customClientContextParams; } + + public MultipartCustomization getMultipartCustomization() { + return this.multipartCustomization; + } + + public void setMultipartCustomization(MultipartCustomization multipartCustomization) { + this.multipartCustomization = multipartCustomization; + } } diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/model/config/customization/MultipartCustomization.java b/codegen/src/main/java/software/amazon/awssdk/codegen/model/config/customization/MultipartCustomization.java new file mode 100644 index 000000000000..94264a9e5ec6 --- /dev/null +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/model/config/customization/MultipartCustomization.java @@ -0,0 +1,64 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.codegen.model.config.customization; + +public class MultipartCustomization { + private String multipartConfigurationClass; + private String multipartConfigMethodDoc; + private String multipartEnableMethodDoc; + private String contextParamEnabledKey; + private String contextParamConfigKey; + + public String getMultipartConfigurationClass() { + return multipartConfigurationClass; + } + + public void setMultipartConfigurationClass(String multipartConfigurationClass) { + this.multipartConfigurationClass = multipartConfigurationClass; + } + + public String getMultipartConfigMethodDoc() { + return multipartConfigMethodDoc; + } + + public void setMultipartConfigMethodDoc(String multipartMethodDoc) { + this.multipartConfigMethodDoc = multipartMethodDoc; + } + + public String getMultipartEnableMethodDoc() { + return multipartEnableMethodDoc; + } + + public void setMultipartEnableMethodDoc(String multipartEnableMethodDoc) { + this.multipartEnableMethodDoc = multipartEnableMethodDoc; + } + + public String getContextParamEnabledKey() { + return contextParamEnabledKey; + } + + public void setContextParamEnabledKey(String contextParamEnabledKey) { + this.contextParamEnabledKey = contextParamEnabledKey; + } + + public String getContextParamConfigKey() { + return contextParamConfigKey; + } + + public void setContextParamConfigKey(String contextParamConfigKey) { + this.contextParamConfigKey = contextParamConfigKey; + } +} diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/ClassSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/ClassSpec.java index a8265f0dc7f1..59a719fb2c7d 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/ClassSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/ClassSpec.java @@ -20,7 +20,7 @@ import java.util.Collections; /** - * Represents the a Poet generated class + * Represents a Poet generated class */ public interface ClassSpec { diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/AsyncClientBuilderClass.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/AsyncClientBuilderClass.java index 509a30c6c8d7..3ff2b99ec98e 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/AsyncClientBuilderClass.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/AsyncClientBuilderClass.java @@ -17,6 +17,7 @@ import com.squareup.javapoet.ClassName; import com.squareup.javapoet.MethodSpec; +import com.squareup.javapoet.ParameterSpec; import com.squareup.javapoet.ParameterizedTypeName; import com.squareup.javapoet.TypeSpec; import java.net.URI; @@ -24,6 +25,7 @@ import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.auth.token.credentials.SdkTokenProvider; import software.amazon.awssdk.awscore.client.config.AwsClientOption; +import software.amazon.awssdk.codegen.model.config.customization.MultipartCustomization; import software.amazon.awssdk.codegen.model.intermediate.IntermediateModel; import software.amazon.awssdk.codegen.poet.ClassSpec; import software.amazon.awssdk.codegen.poet.PoetExtension; @@ -59,12 +61,12 @@ public AsyncClientBuilderClass(IntermediateModel model) { @Override public TypeSpec poetSpec() { TypeSpec.Builder builder = - PoetUtils.createClassBuilder(builderClassName) - .addAnnotation(SdkInternalApi.class) - .addModifiers(Modifier.FINAL) - .superclass(ParameterizedTypeName.get(builderBaseClassName, builderInterfaceName, clientInterfaceName)) - .addSuperinterface(builderInterfaceName) - .addJavadoc("Internal implementation of {@link $T}.", builderInterfaceName); + PoetUtils.createClassBuilder(builderClassName) + .addAnnotation(SdkInternalApi.class) + .addModifiers(Modifier.FINAL) + .superclass(ParameterizedTypeName.get(builderBaseClassName, builderInterfaceName, clientInterfaceName)) + .addSuperinterface(builderInterfaceName) + .addJavadoc("Internal implementation of {@link $T}.", builderInterfaceName); if (model.getEndpointOperation().isPresent()) { builder.addMethod(endpointDiscoveryEnabled()); @@ -80,6 +82,12 @@ public TypeSpec poetSpec() { builder.addMethod(bearerTokenProviderMethod()); } + MultipartCustomization multipartCustomization = model.getCustomizationConfig().getMultipartCustomization(); + if (multipartCustomization != null) { + builder.addMethod(multipartEnabledMethod(multipartCustomization)); + builder.addMethod(multipartConfigMethods(multipartCustomization)); + } + builder.addMethod(buildClientMethod()); builder.addMethod(initializeServiceClientConfigMethod()); @@ -124,15 +132,15 @@ private MethodSpec endpointProviderMethod() { private MethodSpec buildClientMethod() { MethodSpec.Builder builder = MethodSpec.methodBuilder("buildClient") - .addAnnotation(Override.class) - .addModifiers(Modifier.PROTECTED, Modifier.FINAL) - .returns(clientInterfaceName) - .addStatement("$T clientConfiguration = super.asyncClientConfiguration()", - SdkClientConfiguration.class).addStatement("this.validateClientOptions" - + "(clientConfiguration)") - .addStatement("$T serviceClientConfiguration = initializeServiceClientConfig" - + "(clientConfiguration)", - serviceConfigClassName); + .addAnnotation(Override.class) + .addModifiers(Modifier.PROTECTED, Modifier.FINAL) + .returns(clientInterfaceName) + .addStatement("$T clientConfiguration = super.asyncClientConfiguration()", + SdkClientConfiguration.class) + .addStatement("this.validateClientOptions(clientConfiguration)") + .addStatement("$T serviceClientConfiguration = initializeServiceClientConfig" + + "(clientConfiguration)", + serviceConfigClassName); builder.addStatement("$1T client = new $2T(serviceClientConfiguration, clientConfiguration)", clientInterfaceName, clientClassName); @@ -156,6 +164,32 @@ private MethodSpec bearerTokenProviderMethod() { .build(); } + private MethodSpec multipartEnabledMethod(MultipartCustomization multipartCustomization) { + return MethodSpec.methodBuilder("multipartEnabled") + .addAnnotation(Override.class) + .addModifiers(Modifier.PUBLIC) + .returns(builderInterfaceName) + .addParameter(Boolean.class, "enabled") + .addStatement("clientContextParams.put($N, enabled)", + multipartCustomization.getContextParamEnabledKey()) + .addStatement("return this") + .build(); + } + + private MethodSpec multipartConfigMethods(MultipartCustomization multipartCustomization) { + ClassName mulitpartConfigClassName = + PoetUtils.classNameFromFqcn(multipartCustomization.getMultipartConfigurationClass()); + return MethodSpec.methodBuilder("multipartConfiguration") + .addAnnotation(Override.class) + .addModifiers(Modifier.PUBLIC) + .addParameter(ParameterSpec.builder(mulitpartConfigClassName, "multipartConfig").build()) + .returns(builderInterfaceName) + .addStatement("clientContextParams.put($N, multipartConfig)", + multipartCustomization.getContextParamConfigKey()) + .addStatement("return this") + .build(); + } + private MethodSpec initializeServiceClientConfigMethod() { return MethodSpec.methodBuilder("initializeServiceClientConfig").addModifiers(Modifier.PRIVATE) .addParameter(SdkClientConfiguration.class, "clientConfig") diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/AsyncClientBuilderInterface.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/AsyncClientBuilderInterface.java index 5348972b5df9..df62f97ae7c0 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/AsyncClientBuilderInterface.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/AsyncClientBuilderInterface.java @@ -17,34 +17,97 @@ import com.squareup.javapoet.ClassName; import com.squareup.javapoet.CodeBlock; +import com.squareup.javapoet.MethodSpec; +import com.squareup.javapoet.ParameterSpec; import com.squareup.javapoet.ParameterizedTypeName; import com.squareup.javapoet.TypeSpec; +import java.util.function.Consumer; +import javax.lang.model.element.Modifier; import software.amazon.awssdk.awscore.client.builder.AwsAsyncClientBuilder; +import software.amazon.awssdk.codegen.model.config.customization.MultipartCustomization; import software.amazon.awssdk.codegen.model.intermediate.IntermediateModel; import software.amazon.awssdk.codegen.poet.ClassSpec; import software.amazon.awssdk.codegen.poet.PoetUtils; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; public class AsyncClientBuilderInterface implements ClassSpec { + private static final Logger log = Logger.loggerFor(AsyncClientBuilderInterface.class); + private final ClassName builderInterfaceName; private final ClassName clientInterfaceName; private final ClassName baseBuilderInterfaceName; + private final IntermediateModel model; public AsyncClientBuilderInterface(IntermediateModel model) { String basePackage = model.getMetadata().getFullClientPackageName(); this.clientInterfaceName = ClassName.get(basePackage, model.getMetadata().getAsyncInterface()); this.builderInterfaceName = ClassName.get(basePackage, model.getMetadata().getAsyncBuilderInterface()); this.baseBuilderInterfaceName = ClassName.get(basePackage, model.getMetadata().getBaseBuilderInterface()); + this.model = model; } @Override public TypeSpec poetSpec() { - return PoetUtils.createInterfaceBuilder(builderInterfaceName) - .addSuperinterface(ParameterizedTypeName.get(ClassName.get(AwsAsyncClientBuilder.class), - builderInterfaceName, clientInterfaceName)) - .addSuperinterface(ParameterizedTypeName.get(baseBuilderInterfaceName, - builderInterfaceName, clientInterfaceName)) - .addJavadoc(getJavadoc()) - .build(); + TypeSpec.Builder builder = PoetUtils + .createInterfaceBuilder(builderInterfaceName) + .addSuperinterface(ParameterizedTypeName.get(ClassName.get(AwsAsyncClientBuilder.class), + builderInterfaceName, clientInterfaceName)) + .addSuperinterface(ParameterizedTypeName.get(baseBuilderInterfaceName, + builderInterfaceName, clientInterfaceName)) + .addJavadoc(getJavadoc()); + + MultipartCustomization multipartCustomization = model.getCustomizationConfig().getMultipartCustomization(); + if (multipartCustomization != null) { + includeMultipartMethod(builder, multipartCustomization); + } + return builder.build(); + } + + private void includeMultipartMethod(TypeSpec.Builder builder, MultipartCustomization multipartCustomization) { + log.debug(() -> String.format("Adding multipart config methods to builder interface for service '%s'", + model.getMetadata().getServiceId())); + + // .multipartEnabled(Boolean) + builder.addMethod( + MethodSpec.methodBuilder("multipartEnabled") + .addModifiers(Modifier.DEFAULT, Modifier.PUBLIC) + .returns(builderInterfaceName) + .addParameter(Boolean.class, "enabled") + .addCode("throw new $T();", UnsupportedOperationException.class) + .addJavadoc(CodeBlock.of(multipartCustomization.getMultipartEnableMethodDoc())) + .build()); + + // .multipartConfiguration(MultipartConfiguration) + String multiPartConfigMethodName = "multipartConfiguration"; + String multipartConfigClass = Validate.notNull(multipartCustomization.getMultipartConfigurationClass(), + "'multipartConfigurationClass' must be defined"); + ClassName mulitpartConfigClassName = PoetUtils.classNameFromFqcn(multipartConfigClass); + builder.addMethod( + MethodSpec.methodBuilder(multiPartConfigMethodName) + .addModifiers(Modifier.DEFAULT, Modifier.PUBLIC) + .returns(builderInterfaceName) + .addParameter(ParameterSpec.builder(mulitpartConfigClassName, "multipartConfiguration").build()) + .addCode("throw new $T();", UnsupportedOperationException.class) + .addJavadoc(CodeBlock.of(multipartCustomization.getMultipartConfigMethodDoc())) + .build()); + + // .multipartConfiguration(Consumer) + ClassName mulitpartConfigBuilderClassName = PoetUtils.classNameFromFqcn(multipartConfigClass + ".Builder"); + ParameterizedTypeName consumerBuilderType = ParameterizedTypeName.get(ClassName.get(Consumer.class), + mulitpartConfigBuilderClassName); + builder.addMethod( + MethodSpec.methodBuilder(multiPartConfigMethodName) + .addModifiers(Modifier.DEFAULT, Modifier.PUBLIC) + .returns(builderInterfaceName) + .addParameter(ParameterSpec.builder(consumerBuilderType, "multipartConfiguration").build()) + .addStatement("$T builder = $T.builder()", + mulitpartConfigBuilderClassName, + mulitpartConfigClassName) + .addStatement("multipartConfiguration.accept(builder)") + .addStatement("return multipartConfiguration(builder.build())") + .addJavadoc(CodeBlock.of(multipartCustomization.getMultipartConfigMethodDoc())) + .build()); } @Override diff --git a/services/s3/src/it/java/software/amazon/awssdk/services/s3/S3IntegrationTestBase.java b/services/s3/src/it/java/software/amazon/awssdk/services/s3/S3IntegrationTestBase.java index 63dcf2ddc88f..03cf42afe5df 100644 --- a/services/s3/src/it/java/software/amazon/awssdk/services/s3/S3IntegrationTestBase.java +++ b/services/s3/src/it/java/software/amazon/awssdk/services/s3/S3IntegrationTestBase.java @@ -117,7 +117,7 @@ protected static void deleteBucketAndAllContents(String bucketName) { S3TestUtils.deleteBucketAndAllContents(s3, bucketName); } - private static class UserAgentVerifyingExecutionInterceptor implements ExecutionInterceptor { + protected static class UserAgentVerifyingExecutionInterceptor implements ExecutionInterceptor { private final String clientName; private final ClientType clientType; diff --git a/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3ClientMultiPartCopyIntegrationTest.java b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3ClientMultiPartCopyIntegrationTest.java index 6db434526fb9..fc4f31b76b1a 100644 --- a/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3ClientMultiPartCopyIntegrationTest.java +++ b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3ClientMultiPartCopyIntegrationTest.java @@ -31,17 +31,16 @@ import javax.crypto.KeyGenerator; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Timeout; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import software.amazon.awssdk.core.ClientType; import software.amazon.awssdk.core.ResponseBytes; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.sync.ResponseTransformer; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3IntegrationTestBase; import software.amazon.awssdk.services.s3.internal.crt.S3CrtAsyncClient; -import software.amazon.awssdk.services.s3.internal.multipart.MultipartS3AsyncClient; import software.amazon.awssdk.services.s3.model.CopyObjectResponse; import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.services.s3.model.MetadataDirective; @@ -58,6 +57,7 @@ public class S3ClientMultiPartCopyIntegrationTest extends S3IntegrationTestBase private static final long SMALL_OBJ_SIZE = 1024 * 1024; private static S3AsyncClient s3CrtAsyncClient; private static S3AsyncClient s3MpuClient; + @BeforeAll public static void setUp() throws Exception { S3IntegrationTestBase.setUp(); @@ -66,7 +66,13 @@ public static void setUp() throws Exception { .credentialsProvider(CREDENTIALS_PROVIDER_CHAIN) .region(DEFAULT_REGION) .build(); - s3MpuClient = new MultipartS3AsyncClient(s3Async); + s3MpuClient = S3AsyncClient.builder() + .region(DEFAULT_REGION) + .credentialsProvider(CREDENTIALS_PROVIDER_CHAIN) + .overrideConfiguration(o -> o.addExecutionInterceptor( + new UserAgentVerifyingExecutionInterceptor("NettyNio", ClientType.ASYNC))) + .multipartEnabled(true) + .build(); } @AfterAll @@ -158,7 +164,7 @@ private static byte[] generateSecretKey() { private void createOriginalObject(byte[] originalContent, String originalKey) { s3CrtAsyncClient.putObject(r -> r.bucket(BUCKET) - .key(originalKey), + .key(originalKey), AsyncRequestBody.fromBytes(originalContent)).join(); } diff --git a/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java index cb72906943b9..fa31b5453e5e 100644 --- a/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java +++ b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java @@ -15,7 +15,6 @@ package software.amazon.awssdk.services.s3.multipart; -import static java.util.concurrent.TimeUnit.MINUTES; import static java.util.concurrent.TimeUnit.SECONDS; import static org.assertj.core.api.Assertions.assertThat; import static software.amazon.awssdk.testutils.service.S3BucketUtils.temporaryBucketName; @@ -27,25 +26,21 @@ import java.nio.file.Files; import java.util.Optional; import java.util.UUID; -import java.util.concurrent.ThreadLocalRandom; -import java.util.concurrent.TimeUnit; import org.apache.commons.lang3.RandomStringUtils; -import org.assertj.core.api.Assertions; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.reactivestreams.Subscriber; +import software.amazon.awssdk.core.ClientType; import software.amazon.awssdk.core.ResponseInputStream; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.internal.async.FileAsyncRequestBody; import software.amazon.awssdk.core.sync.ResponseTransformer; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3IntegrationTestBase; -import software.amazon.awssdk.services.s3.internal.multipart.MultipartS3AsyncClient; import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.services.s3.utils.ChecksumUtils; -import software.amazon.awssdk.testutils.RandomTempFile; @Timeout(value = 30, unit = SECONDS) public class S3MultipartClientPutObjectIntegrationTest extends S3IntegrationTestBase { @@ -66,7 +61,14 @@ public static void setup() throws Exception { testFile = File.createTempFile("SplittingPublisherTest", UUID.randomUUID().toString()); Files.write(testFile.toPath(), CONTENT); - mpuS3Client = new MultipartS3AsyncClient(s3Async); + mpuS3Client = S3AsyncClient + .builder() + .region(DEFAULT_REGION) + .credentialsProvider(CREDENTIALS_PROVIDER_CHAIN) + .overrideConfiguration(o -> o.addExecutionInterceptor( + new UserAgentVerifyingExecutionInterceptor("NettyNio", ClientType.ASYNC))) + .multipartEnabled(true) + .build(); } @AfterAll @@ -81,8 +83,9 @@ void putObject_fileRequestBody_objectSentCorrectly() throws Exception { AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath()); mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join(); - ResponseInputStream objContent = S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), - ResponseTransformer.toInputStream()); + ResponseInputStream objContent = + S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), + ResponseTransformer.toInputStream()); assertThat(objContent.response().contentLength()).isEqualTo(testFile.length()); byte[] expectedSum = ChecksumUtils.computeCheckSum(Files.newInputStream(testFile.toPath())); @@ -95,8 +98,9 @@ void putObject_byteAsyncRequestBody_objectSentCorrectly() throws Exception { AsyncRequestBody body = AsyncRequestBody.fromBytes(bytes); mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join(); - ResponseInputStream objContent = S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), - ResponseTransformer.toInputStream()); + ResponseInputStream objContent = + S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), + ResponseTransformer.toInputStream()); assertThat(objContent.response().contentLength()).isEqualTo(OBJ_SIZE); byte[] expectedSum = ChecksumUtils.computeCheckSum(new ByteArrayInputStream(bytes)); @@ -120,8 +124,9 @@ public void subscribe(Subscriber s) { } }).get(30, SECONDS); - ResponseInputStream objContent = S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), - ResponseTransformer.toInputStream()); + ResponseInputStream objContent = + S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), + ResponseTransformer.toInputStream()); assertThat(objContent.response().contentLength()).isEqualTo(testFile.length()); byte[] expectedSum = ChecksumUtils.computeCheckSum(Files.newInputStream(testFile.toPath())); diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/client/S3AsyncClientDecorator.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/client/S3AsyncClientDecorator.java index 2dbb61091da2..b751cb29c1b0 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/client/S3AsyncClientDecorator.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/client/S3AsyncClientDecorator.java @@ -23,11 +23,17 @@ import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.endpoints.S3ClientContextParams; import software.amazon.awssdk.services.s3.internal.crossregion.S3CrossRegionAsyncClient; +import software.amazon.awssdk.services.s3.internal.multipart.MultipartS3AsyncClient; +import software.amazon.awssdk.services.s3.multipart.MultipartConfiguration; import software.amazon.awssdk.utils.AttributeMap; import software.amazon.awssdk.utils.ConditionalDecorator; @SdkInternalApi public class S3AsyncClientDecorator { + public static final AttributeMap.Key MULTIPART_CONFIGURATION_KEY = + new AttributeMap.Key(MultipartConfiguration.class){}; + public static final AttributeMap.Key MULTIPART_ENABLED_KEY = + new AttributeMap.Key(Boolean.class){}; public S3AsyncClientDecorator() { } @@ -36,14 +42,26 @@ public S3AsyncClient decorate(S3AsyncClient base, SdkClientConfiguration clientConfiguration, AttributeMap clientContextParams) { List> decorators = new ArrayList<>(); - decorators.add(ConditionalDecorator.create(isCrossRegionEnabledAsync(clientContextParams), - S3CrossRegionAsyncClient::new)); + decorators.add(ConditionalDecorator.create( + isCrossRegionEnabledAsync(clientContextParams), + S3CrossRegionAsyncClient::new)); + + decorators.add(ConditionalDecorator.create( + isMultipartEnable(clientContextParams), + client -> { + MultipartConfiguration multipartConfiguration = clientContextParams.get(MULTIPART_CONFIGURATION_KEY); + return MultipartS3AsyncClient.create(client, multipartConfiguration); + })); return ConditionalDecorator.decorate(base, decorators); } private Predicate isCrossRegionEnabledAsync(AttributeMap clientContextParams) { Boolean crossRegionEnabled = clientContextParams.get(S3ClientContextParams.CROSS_REGION_ACCESS_ENABLED); - return client -> crossRegionEnabled != null && crossRegionEnabled.booleanValue(); + return client -> crossRegionEnabled != null && crossRegionEnabled.booleanValue(); } + private Predicate isMultipartEnable(AttributeMap clientContextParams) { + Boolean multipartEnabled = clientContextParams.get(MULTIPART_ENABLED_KEY); + return client -> multipartEnabled != null && multipartEnabled.booleanValue(); + } } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/CopyObjectHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/CopyObjectHelper.java index 31b947bb89c5..16294ff8f065 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/CopyObjectHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/CopyObjectHelper.java @@ -24,8 +24,6 @@ import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.internal.crt.UploadPartCopyRequestIterable; -import software.amazon.awssdk.services.s3.internal.multipart.GenericMultipartHelper; -import software.amazon.awssdk.services.s3.internal.multipart.SdkPojoConversionUtils; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload; @@ -130,6 +128,10 @@ private void doCopyInParts(CopyObjectRequest copyObjectRequest, long optimalPartSize = genericMultipartHelper.calculateOptimalPartSizeFor(contentLength, partSizeInBytes); int partCount = genericMultipartHelper.determinePartCount(contentLength, optimalPartSize); + if (optimalPartSize > partSizeInBytes) { + log.debug(() -> String.format("Configured partSize is %d, but using %d to prevent reaching maximum number of parts " + + "allowed", partSizeInBytes, optimalPartSize)); + } log.debug(() -> String.format("Starting multipart copy with partCount: %s, optimalPartSize: %s", partCount, optimalPartSize)); @@ -170,7 +172,6 @@ private CompletableFuture completeMultipartUplo .parts(parts) .build()) .build(); - return s3AsyncClient.completeMultipartUpload(completeMultipartUploadRequest); } @@ -201,7 +202,8 @@ private void sendIndividualUploadPartCopy(String uploadId, log.debug(() -> "Sending uploadPartCopyRequest with range: " + uploadPartCopyRequest.copySourceRange() + " uploadId: " + uploadId); - CompletableFuture uploadPartCopyFuture = s3AsyncClient.uploadPartCopy(uploadPartCopyRequest); + CompletableFuture uploadPartCopyFuture = + s3AsyncClient.uploadPartCopy(uploadPartCopyRequest); CompletableFuture convertFuture = uploadPartCopyFuture.thenApply(uploadPartCopyResponse -> diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java index 905c1bc928ea..38e76394958e 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java @@ -91,7 +91,6 @@ public CompletableFuture completeMultipartUploa .parts(parts) .build()) .build(); - return s3AsyncClient.completeMultipartUpload(completeMultipartUploadRequest); } @@ -125,7 +124,8 @@ public BiFunction handleExcept public void cleanUpParts(String uploadId, AbortMultipartUploadRequest.Builder abortMultipartUploadRequest) { log.debug(() -> "Aborting multipart upload: " + uploadId); - s3AsyncClient.abortMultipartUpload(abortMultipartUploadRequest.uploadId(uploadId).build()) + AbortMultipartUploadRequest request = abortMultipartUploadRequest.uploadId(uploadId).build(); + s3AsyncClient.abortMultipartUpload(request) .exceptionally(throwable -> { log.warn(() -> String.format("Failed to abort previous multipart upload " + "(id: %s)" diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java index a4b3147254f9..65b26ddec971 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java @@ -17,31 +17,59 @@ import java.util.concurrent.CompletableFuture; +import java.util.function.Function; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.ApiName; import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.services.s3.DelegatingS3AsyncClient; import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.internal.UserAgentUtils; import software.amazon.awssdk.services.s3.model.CopyObjectRequest; import software.amazon.awssdk.services.s3.model.CopyObjectResponse; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.s3.model.S3Request; +import software.amazon.awssdk.services.s3.multipart.MultipartConfiguration; +import software.amazon.awssdk.utils.Validate; -// This is just a temporary class for testing -//TODO: change this +/** + * An {@link S3AsyncClient} that automatically converts put, copy requests to their respective multipart call. Note: get is not + * yet supported. + * + * @see MultipartConfiguration + */ @SdkInternalApi -public class MultipartS3AsyncClient extends DelegatingS3AsyncClient { - private static final long DEFAULT_PART_SIZE_IN_BYTES = 8L * 1024 * 1024; +public final class MultipartS3AsyncClient extends DelegatingS3AsyncClient { + + private static final ApiName USER_AGENT_API_NAME = ApiName.builder().name("hll").version("s3Multipart").build(); + + private static final long DEFAULT_MIN_PART_SIZE = 8L * 1024 * 1024; private static final long DEFAULT_THRESHOLD = 8L * 1024 * 1024; + private static final long DEFAULT_API_CALL_BUFFER_SIZE = DEFAULT_MIN_PART_SIZE * 4; - private static final long DEFAULT_MAX_MEMORY = DEFAULT_PART_SIZE_IN_BYTES * 2; private final UploadObjectHelper mpuHelper; private final CopyObjectHelper copyObjectHelper; - public MultipartS3AsyncClient(S3AsyncClient delegate) { + private MultipartS3AsyncClient(S3AsyncClient delegate, MultipartConfiguration multipartConfiguration) { super(delegate); - // TODO: pass a config object to the upload helper instead - mpuHelper = new UploadObjectHelper(delegate, DEFAULT_PART_SIZE_IN_BYTES, DEFAULT_THRESHOLD, DEFAULT_MAX_MEMORY); - copyObjectHelper = new CopyObjectHelper(delegate, DEFAULT_PART_SIZE_IN_BYTES, DEFAULT_THRESHOLD); + MultipartConfiguration validConfiguration = Validate.getOrDefault(multipartConfiguration, + MultipartConfiguration.builder()::build); + long minPartSizeInBytes = Validate.getOrDefault(validConfiguration.minimumPartSizeInBytes(), + () -> DEFAULT_MIN_PART_SIZE); + long threshold = Validate.getOrDefault(validConfiguration.thresholdInBytes(), + () -> DEFAULT_THRESHOLD); + long apiCallBufferSizeInBytes = Validate.getOrDefault(validConfiguration.apiCallBufferSizeInBytes(), + () -> computeApiCallBufferSize(validConfiguration)); + mpuHelper = new UploadObjectHelper(delegate, minPartSizeInBytes, threshold, apiCallBufferSizeInBytes); + copyObjectHelper = new CopyObjectHelper(delegate, minPartSizeInBytes, threshold); + } + + private long computeApiCallBufferSize(MultipartConfiguration multipartConfiguration) { + return multipartConfiguration.minimumPartSizeInBytes() != null ? multipartConfiguration.minimumPartSizeInBytes() * 4 + : DEFAULT_API_CALL_BUFFER_SIZE; } @Override @@ -54,8 +82,27 @@ public CompletableFuture copyObject(CopyObjectRequest copyOb return copyObjectHelper.copyObject(copyObjectRequest); } + @Override + public CompletableFuture getObject( + GetObjectRequest getObjectRequest, AsyncResponseTransformer asyncResponseTransformer) { + throw new UnsupportedOperationException( + "Multipart download is not yet supported. Instead use the CRT based S3 client for multipart download."); + } + @Override public void close() { delegate().close(); } + + public static MultipartS3AsyncClient create(S3AsyncClient client, MultipartConfiguration multipartConfiguration) { + S3AsyncClient clientWithUserAgent = new DelegatingS3AsyncClient(client) { + @Override + protected CompletableFuture invokeOperation(T request, Function> operation) { + T requestWithUserAgent = UserAgentUtils.applyUserAgentInfo(request, c -> c.addApiName(USER_AGENT_API_NAME)); + return operation.apply(requestWithUserAgent); + } + }; + return new MultipartS3AsyncClient(clientWithUserAgent, multipartConfiguration); + } } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java index 1228e577fcd1..9754d284f5b9 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java @@ -36,8 +36,8 @@ import software.amazon.awssdk.utils.Pair; /** - * A base class contains common logic used by {@link UploadWithUnknownContentLengthHelper} - * and {@link UploadWithKnownContentLengthHelper}. + * A base class contains common logic used by {@link UploadWithUnknownContentLengthHelper} and + * {@link UploadWithKnownContentLengthHelper}. */ @SdkInternalApi public final class MultipartUploadHelper { diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java index 5e1a41da4d86..f7d199ac3aa6 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java @@ -112,6 +112,10 @@ private void doUploadInParts(Pair request, long optimalPartSize = genericMultipartHelper.calculateOptimalPartSizeFor(contentLength, partSizeInBytes); int partCount = genericMultipartHelper.determinePartCount(contentLength, optimalPartSize); + if (optimalPartSize > partSizeInBytes) { + log.debug(() -> String.format("Configured partSize is %d, but using %d to prevent reaching maximum number of parts " + + "allowed", partSizeInBytes, optimalPartSize)); + } log.debug(() -> String.format("Starting multipart upload with partCount: %d, optimalPartSize: %d", partCount, optimalPartSize)); diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/MultipartConfiguration.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/MultipartConfiguration.java new file mode 100644 index 000000000000..28e418974db8 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/MultipartConfiguration.java @@ -0,0 +1,199 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.multipart; + +import java.util.function.Consumer; +import software.amazon.awssdk.annotations.SdkPublicApi; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.S3AsyncClientBuilder; +import software.amazon.awssdk.services.s3.model.CopyObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.utils.builder.CopyableBuilder; +import software.amazon.awssdk.utils.builder.ToCopyableBuilder; + +/** + * Class that hold configuration properties related to multipart operation for a {@link S3AsyncClient}. Passing this class to the + * {@link S3AsyncClientBuilder#multipartConfiguration(MultipartConfiguration)} will enable automatic conversion of + * {@link S3AsyncClient#putObject(Consumer, AsyncRequestBody)}, {@link S3AsyncClient#copyObject(CopyObjectRequest)} to their + * respective multipart operation. + *

+ * Note: The multipart operation for {@link S3AsyncClient#getObject(GetObjectRequest, AsyncResponseTransformer)} is + * temporarily disabled and will result in throwing a {@link UnsupportedOperationException} if called when configured for + * multipart operation. + */ +@SdkPublicApi +public final class MultipartConfiguration implements ToCopyableBuilder { + + private final Long thresholdInBytes; + private final Long minimumPartSizeInBytes; + private final Long apiCallBufferSizeInBytes; + + private MultipartConfiguration(DefaultMultipartConfigBuilder builder) { + this.thresholdInBytes = builder.thresholdInBytes; + this.minimumPartSizeInBytes = builder.minimumPartSizeInBytes; + this.apiCallBufferSizeInBytes = builder.apiCallBufferSizeInBytes; + } + + public static Builder builder() { + return new DefaultMultipartConfigBuilder(); + } + + @Override + public Builder toBuilder() { + return builder() + .apiCallBufferSizeInBytes(apiCallBufferSizeInBytes) + .minimumPartSizeInBytes(minimumPartSizeInBytes) + .thresholdInBytes(thresholdInBytes); + } + + /** + * Indicates the value of the configured threshold, in bytes. Any request whose size is less than the configured value will + * not use multipart operation + * @return the value of the configured threshold. + */ + public Long thresholdInBytes() { + return this.thresholdInBytes; + } + + /** + * Indicated the size, in bytes, of each individual part of the part requests. The actual part size used might be bigger to + * conforms to the maximum number of parts allowed per multipart requests. + * @return the value of the configured part size. + */ + public Long minimumPartSizeInBytes() { + return this.minimumPartSizeInBytes; + } + + /** + * The maximum memory, in bytes, that the SDK will use to buffer requests content into memory. + * @return the value of the configured maximum memory usage. + */ + public Long apiCallBufferSizeInBytes() { + return this.apiCallBufferSizeInBytes; + } + + /** + * Builder for a {@link MultipartConfiguration}. + */ + public interface Builder extends CopyableBuilder { + + /** + * Configures the minimum number of bytes of the body of the request required for requests to be converted to their + * multipart equivalent. Only taken into account when converting {@code putObject} and {@code copyObject} requests. + * Any request whose size is less than the configured value will not use multipart operation, + * even if multipart is enabled via {@link S3AsyncClientBuilder#multipartEnabled(Boolean)}. + *

+ * + * Default value: 8 Mib + * + * @param thresholdInBytes the value of the threshold to set. + * @return an instance of this builder. + */ + Builder thresholdInBytes(Long thresholdInBytes); + + /** + * Indicates the value of the configured threshold. + * @return the value of the threshold. + */ + Long thresholdInBytes(); + + /** + * Configures the part size, in bytes, to be used in each individual part requests. + * Only used for putObject and copyObject operations. + *

+ * When uploading large payload, the size of the payload of each individual part requests might actually be + * bigger than + * the configured value since there is a limit to the maximum number of parts possible per multipart request. If the + * configured part size would lead to a number of parts higher than the maximum allowed, a larger part size will be + * calculated instead to allow fewer part to be uploaded, to avoid the limit imposed on the maximum number of parts. + *

+ * In the case where the {@code minimumPartSizeInBytes} is set to a value higher than the {@code thresholdInBytes}, when + * the client receive a request with a size smaller than a single part multipart operation will NOT be performed + * even if the size of the request is larger than the threshold. + *

+ * Default value: 8 Mib + * + * @param minimumPartSizeInBytes the value of the part size to set + * @return an instance of this builder. + */ + Builder minimumPartSizeInBytes(Long minimumPartSizeInBytes); + + /** + * Indicated the value of the part configured size. + * @return the value of the part size + */ + Long minimumPartSizeInBytes(); + + /** + * Configures the maximum amount of memory, in bytes, the SDK will use to buffer content of requests in memory. + * Increasing this value may lead to better performance at the cost of using more memory. + *

+ * Default value: If not specified, the SDK will use the equivalent of four parts worth of memory, so 32 Mib by default. + * + * @param apiCallBufferSizeInBytes the value of the maximum memory usage. + * @return an instance of this builder. + */ + Builder apiCallBufferSizeInBytes(Long apiCallBufferSizeInBytes); + + /** + * Indicates the value of the maximum memory usage that the SDK will use. + * @return the value of the maximum memory usage. + */ + Long apiCallBufferSizeInBytes(); + } + + private static class DefaultMultipartConfigBuilder implements Builder { + private Long thresholdInBytes; + private Long minimumPartSizeInBytes; + private Long apiCallBufferSizeInBytes; + + public Builder thresholdInBytes(Long thresholdInBytes) { + this.thresholdInBytes = thresholdInBytes; + return this; + } + + public Long thresholdInBytes() { + return this.thresholdInBytes; + } + + public Builder minimumPartSizeInBytes(Long minimumPartSizeInBytes) { + this.minimumPartSizeInBytes = minimumPartSizeInBytes; + return this; + } + + public Long minimumPartSizeInBytes() { + return this.minimumPartSizeInBytes; + } + + @Override + public Builder apiCallBufferSizeInBytes(Long maximumMemoryUsageInBytes) { + this.apiCallBufferSizeInBytes = maximumMemoryUsageInBytes; + return this; + } + + @Override + public Long apiCallBufferSizeInBytes() { + return apiCallBufferSizeInBytes; + } + + @Override + public MultipartConfiguration build() { + return new MultipartConfiguration(this); + } + } +} diff --git a/services/s3/src/main/resources/codegen-resources/customization.config b/services/s3/src/main/resources/codegen-resources/customization.config index 1a1efb76c5f4..ccddba62880c 100644 --- a/services/s3/src/main/resources/codegen-resources/customization.config +++ b/services/s3/src/main/resources/codegen-resources/customization.config @@ -236,6 +236,13 @@ "syncClientDecorator": "software.amazon.awssdk.services.s3.internal.client.S3SyncClientDecorator", "asyncClientDecorator": "software.amazon.awssdk.services.s3.internal.client.S3AsyncClientDecorator", "useGlobalEndpoint": true, + "multipartCustomization": { + "multipartConfigurationClass": "software.amazon.awssdk.services.s3.multipart.MultipartConfiguration", + "multipartConfigMethodDoc": "Configuration for multipart operation of this client.", + "multipartEnableMethodDoc": "Enables automatic conversion of put and copy method to their equivalent multipart operation.", + "contextParamEnabledKey": "S3AsyncClientDecorator.MULTIPART_ENABLED_KEY", + "contextParamConfigKey": "S3AsyncClientDecorator.MULTIPART_CONFIGURATION_KEY" + }, "interceptors": [ "software.amazon.awssdk.services.s3.internal.handlers.PutObjectInterceptor", "software.amazon.awssdk.services.s3.internal.handlers.CreateBucketInterceptor", diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartClientUserAgentTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartClientUserAgentTest.java new file mode 100644 index 000000000000..0f41c7c78e74 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartClientUserAgentTest.java @@ -0,0 +1,82 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.net.URI; +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.ApiName; +import software.amazon.awssdk.core.interceptor.Context; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; +import software.amazon.awssdk.http.HttpExecuteResponse; +import software.amazon.awssdk.http.SdkHttpResponse; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.testutils.service.http.MockAsyncHttpClient; + +class MultipartClientUserAgentTest { + private MockAsyncHttpClient mockAsyncHttpClient; + private UserAgentInterceptor userAgentInterceptor; + private S3AsyncClient s3Client; + + @BeforeEach + void init() { + this.mockAsyncHttpClient = new MockAsyncHttpClient(); + this.userAgentInterceptor = new UserAgentInterceptor(); + s3Client = S3AsyncClient.builder() + .httpClient(mockAsyncHttpClient) + .endpointOverride(URI.create("http://localhost")) + .overrideConfiguration(c -> c.addExecutionInterceptor(userAgentInterceptor)) + .multipartConfiguration(c -> c.minimumPartSizeInBytes(512L).thresholdInBytes(512L)) + .multipartEnabled(true) + .region(Region.US_EAST_1) + .build(); + } + + @AfterEach + void reset() { + this.mockAsyncHttpClient.reset(); + } + + @Test + void validateUserAgent_nonMultipartMethod() throws Exception { + HttpExecuteResponse response = HttpExecuteResponse.builder() + .response(SdkHttpResponse.builder().statusCode(200).build()) + .build(); + mockAsyncHttpClient.stubResponses(response); + + s3Client.headObject(req -> req.key("mock").bucket("mock")).get(); + + assertThat(userAgentInterceptor.apiNames) + .anyMatch(api -> "hll".equals(api.name()) && "s3Multipart".equals(api.version())); + } + + private static final class UserAgentInterceptor implements ExecutionInterceptor { + private final List apiNames = new ArrayList<>(); + + @Override + public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) { + context.request().overrideConfiguration().ifPresent(c -> apiNames.addAll(c.apiNames())); + } + } + +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientBuilderTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientBuilderTest.java new file mode 100644 index 000000000000..510d441c4caa --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientBuilderTest.java @@ -0,0 +1,63 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.multipart.MultipartConfiguration; + +class S3MultipartClientBuilderTest { + + @Test + void multipartEnabledWithConfig_shouldBuildMultipartClient() { + S3AsyncClient client = S3AsyncClient.builder() + .multipartEnabled(true) + .multipartConfiguration(MultipartConfiguration.builder().build()) + .region(Region.US_EAST_1) + .build(); + assertThat(client).isInstanceOf(MultipartS3AsyncClient.class); + } + + @Test + void multipartEnabledWithoutConfig_shouldBuildMultipartClient() { + S3AsyncClient client = S3AsyncClient.builder() + .multipartEnabled(true) + .region(Region.US_EAST_1) + .build(); + assertThat(client).isInstanceOf(MultipartS3AsyncClient.class); + } + + @Test + void multipartDisabledWithConfig_shouldNotBuildMultipartClient() { + S3AsyncClient client = S3AsyncClient.builder() + .multipartEnabled(false) + .multipartConfiguration(b -> b.apiCallBufferSizeInBytes(1024L)) + .region(Region.US_EAST_1) + .build(); + assertThat(client).isNotInstanceOf(MultipartS3AsyncClient.class); + } + + @Test + void noMultipart_shouldNotBeMultipartClient() { + S3AsyncClient client = S3AsyncClient.builder() + .region(Region.US_EAST_1) + .build(); + assertThat(client).isNotInstanceOf(MultipartS3AsyncClient.class); + } +} From 28c126d1aa714de7af6dc1e268185415c41d73b6 Mon Sep 17 00:00:00 2001 From: Zoe Wang <33073555+zoewangg@users.noreply.github.com> Date: Wed, 2 Aug 2023 21:06:49 -0700 Subject: [PATCH 11/13] Fix test --- .../services/s3/internal/crt/CopyObjectHelperTest.java | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelperTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelperTest.java index f8dd7335494f..bd5c34f91048 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelperTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelperTest.java @@ -24,7 +24,6 @@ import static org.mockito.Mockito.when; import java.util.List; -import java.util.Random; import java.util.concurrent.CompletableFuture; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -49,10 +48,7 @@ import software.amazon.awssdk.services.s3.model.NoSuchBucketException; import software.amazon.awssdk.services.s3.model.UploadPartCopyRequest; import software.amazon.awssdk.services.s3.model.UploadPartCopyResponse; -import software.amazon.awssdk.utils.BinaryUtils; import software.amazon.awssdk.utils.CompletableFutureUtils; -import software.amazon.awssdk.utils.Md5Utils; - class CopyObjectHelperTest { private static final String SOURCE_BUCKET = "source"; @@ -66,7 +62,7 @@ class CopyObjectHelperTest { private CopyObjectHelper copyHelper; private static final long PART_SIZE = 1024L; - private static final long UPLOAD_THRESHOLD = 2048L; + private static final long UPLOAD_THRESHOLD = PART_SIZE * 2; @BeforeEach public void setUp() { @@ -284,6 +280,7 @@ void multiPartCopy_contentSizeExceeds10000Parts_shouldAdjustPartSize() { } } + @Test public void multiPartCopy_sseCHeadersSetInOriginalRequest_includedInCompleteMultipart() { String customerAlgorithm = "algorithm"; @@ -294,7 +291,7 @@ public void multiPartCopy_sseCHeadersSetInOriginalRequest_includedInCompleteMult .sseCustomerKey(customerKey) .sseCustomerKeyMD5(customerKeyMd5)); - stubSuccessfulHeadObjectCall(2 * PART_SIZE_BYTES); + stubSuccessfulHeadObjectCall(3 * PART_SIZE_BYTES); stubSuccessfulCreateMulipartCall(); stubSuccessfulUploadPartCopyCalls(); stubSuccessfulCompleteMultipartCall(); From 85a1fd716d4a053c24c46bdb0f0000f67106612a Mon Sep 17 00:00:00 2001 From: Olivier L Applin Date: Thu, 3 Aug 2023 09:04:55 -0400 Subject: [PATCH 12/13] Guard against re-subscription in SplittingPublisher (#4253) * guard against re-subscription in SplittingPublisher * fix checkstyle * Error msg --- .../internal/async/SplittingPublisher.java | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java index 43f2e10ff192..c56d1b6437d9 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java @@ -25,6 +25,8 @@ import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.SdkPublisher; +import software.amazon.awssdk.core.exception.NonRetryableException; +import software.amazon.awssdk.core.internal.util.NoopSubscription; import software.amazon.awssdk.utils.Logger; import software.amazon.awssdk.utils.Validate; import software.amazon.awssdk.utils.async.SimplePublisher; @@ -48,8 +50,8 @@ public class SplittingPublisher implements SdkPublisher { private final long bufferSizeInBytes; private SplittingPublisher(Builder builder) { - this.upstreamPublisher = Validate.paramNotNull(builder.asyncRequestBody, "asyncRequestBody"); - this.chunkSizeInBytes = builder.chunkSizeInBytes == null ? DEFAULT_CHUNK_SIZE : builder.chunkSizeInBytes; + this.upstreamPublisher = Validate.paramNotNull(builder.asyncRequestBody, "asyncRequestBody"); + this.chunkSizeInBytes = builder.chunkSizeInBytes == null ? DEFAULT_CHUNK_SIZE : builder.chunkSizeInBytes; this.bufferSizeInBytes = builder.bufferSizeInBytes == null ? DEFAULT_BUFFER_SIZE : builder.bufferSizeInBytes; this.splittingSubscriber = new SplittingSubscriber(upstreamPublisher.contentLength().orElse(null)); @@ -234,13 +236,14 @@ private Long totalDataRemaining() { private final class DownstreamBody implements AsyncRequestBody { /** - * The maximum length of the content this AsyncRequestBody can hold. - * If the upstream content length is known, this is the same as totalLength + * The maximum length of the content this AsyncRequestBody can hold. If the upstream content length is known, this is + * the same as totalLength */ private final long maxLength; private final Long totalLength; private final SimplePublisher delegate = new SimplePublisher<>(); private final int chunkNumber; + private final AtomicBoolean subscribeCalled = new AtomicBoolean(false); private volatile long transferredLength = 0; private DownstreamBody(boolean contentLengthKnown, long maxLength, int chunkNumber) { @@ -282,7 +285,14 @@ public void error(Throwable error) { @Override public void subscribe(Subscriber s) { - delegate.subscribe(s); + if (subscribeCalled.compareAndSet(false, true)) { + delegate.subscribe(s); + } else { + s.onSubscribe(new NoopSubscription(s)); + s.onError(NonRetryableException.create( + "A retry was attempted, but AsyncRequestBody.split does not " + + "support retries.")); + } } private void addDataBuffered(int length) { @@ -293,7 +303,7 @@ private void addDataBuffered(int length) { } } } - + public static final class Builder { private AsyncRequestBody asyncRequestBody; private Long chunkSizeInBytes; From 466e6876b9fdc1e3eace44ca811cbda6ee95b2f2 Mon Sep 17 00:00:00 2001 From: Zoe Wang <33073555+zoewangg@users.noreply.github.com> Date: Thu, 3 Aug 2023 11:32:06 -0700 Subject: [PATCH 13/13] Fix a race condition where the third upload part request was sent before the second one (#4260) --- .../multipart/UploadWithUnknownContentLengthHelper.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java index fa8be1e0c6f3..0c8c3c70b516 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java @@ -174,11 +174,13 @@ public void onNext(AsyncRequestBody asyncRequestBody) { subscription.cancel(); } else { uploadId = createMultipartUploadResponse.uploadId(); - uploadIdFuture.complete(uploadId); log.debug(() -> "Initiated a new multipart upload, uploadId: " + uploadId); sendUploadPartRequest(uploadId, firstRequestBody); sendUploadPartRequest(uploadId, asyncRequestBody); + + // We need to complete the uploadIdFuture *after* the first two requests have been sent + uploadIdFuture.complete(uploadId); } }); CompletableFutureUtils.forwardExceptionTo(returnFuture, createMultipartUploadFuture);