Skip to content

Ensure onNext will be called even if publishing empty content and onC… #4290

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/next-release/bugfix-AWSSDKforJavav2-e70484b.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "bugfix",
"category": "AWS SDK for Java v2",
"contributor": "",
"description": "Sends final checksum chunk and trailer when only onComplete() is called by upstream (empty content)"
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
package software.amazon.awssdk.core.internal.async;

import static software.amazon.awssdk.core.HttpChecksumConstant.DEFAULT_ASYNC_CHUNK_SIZE;
import static software.amazon.awssdk.core.internal.util.ChunkContentUtils.calculateChecksumContentLength;
import static software.amazon.awssdk.core.internal.util.ChunkContentUtils.LAST_CHUNK_LEN;
import static software.amazon.awssdk.core.internal.util.ChunkContentUtils.calculateChecksumTrailerLength;
import static software.amazon.awssdk.core.internal.util.ChunkContentUtils.calculateChunkLength;
import static software.amazon.awssdk.core.internal.util.ChunkContentUtils.createChecksumTrailer;
import static software.amazon.awssdk.core.internal.util.ChunkContentUtils.createChunk;
Expand All @@ -28,11 +29,13 @@
import org.reactivestreams.Subscription;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.async.SdkPublisher;
import software.amazon.awssdk.core.checksums.Algorithm;
import software.amazon.awssdk.core.checksums.SdkChecksum;
import software.amazon.awssdk.core.exception.SdkException;
import software.amazon.awssdk.utils.BinaryUtils;
import software.amazon.awssdk.utils.Validate;
import software.amazon.awssdk.utils.async.DelegatingSubscriber;
import software.amazon.awssdk.utils.builder.SdkBuilder;

/**
Expand Down Expand Up @@ -129,13 +132,12 @@ public ChecksumCalculatingAsyncRequestBody.Builder trailerHeader(String trailerH

@Override
public Optional<Long> contentLength() {

if (wrapped.contentLength().isPresent() && algorithm != null) {
return Optional.of(calculateChunkLength(wrapped.contentLength().get())
+ calculateChecksumContentLength(algorithm, trailerHeader));
} else {
return wrapped.contentLength();
+ LAST_CHUNK_LEN
+ calculateChecksumTrailerLength(algorithm, trailerHeader));
}
return wrapped.contentLength();
}

@Override
Expand All @@ -149,12 +151,15 @@ public void subscribe(Subscriber<? super ByteBuffer> s) {
if (sdkChecksum != null) {
sdkChecksum.reset();
}

SynchronousChunkBuffer synchronousChunkBuffer = new SynchronousChunkBuffer(totalBytes);
wrapped.flatMapIterable(synchronousChunkBuffer::buffer)
alwaysInvokeOnNext(wrapped).flatMapIterable(synchronousChunkBuffer::buffer)
.subscribe(new ChecksumCalculatingSubscriber(s, sdkChecksum, trailerHeader, totalBytes));
}

private SdkPublisher<ByteBuffer> alwaysInvokeOnNext(SdkPublisher<ByteBuffer> source) {
return subscriber -> source.subscribe(new OnNextGuaranteedSubscriber(subscriber));
}

private static final class ChecksumCalculatingSubscriber implements Subscriber<ByteBuffer> {

private final Subscriber<? super ByteBuffer> wrapped;
Expand Down Expand Up @@ -243,4 +248,30 @@ private Iterable<ByteBuffer> buffer(ByteBuffer bytes) {
}
}

public static class OnNextGuaranteedSubscriber extends DelegatingSubscriber<ByteBuffer, ByteBuffer> {

private volatile boolean onNextInvoked;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be an AtomicBoolean instead of volatile and use compareAndSet like in the DelegatingSubscriber?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what is most correct in this case. If more than one thread can modify it, it seems like then we can miss a case where we're thinking onnext was never called but it was (reading before updating). Then Atomic would be better. @zoewangg does this count as an operation where atomicity is important?
Also have not been able to see a lot of benefits of volatile vs atomic...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

onNext, onComplete are guaranteed to be signaled in serial per Reactive Streams spec. https://github.com/reactive-streams/reactive-streams-jvm, so it's not possible to be invoked concurrently.


public OnNextGuaranteedSubscriber(Subscriber<? super ByteBuffer> subscriber) {
super(subscriber);
}

@Override
public void onNext(ByteBuffer t) {
if (!onNextInvoked) {
onNextInvoked = true;
}

subscriber.onNext(t);
}

@Override
public void onComplete() {
if (!onNextInvoked) {
subscriber.onNext(ByteBuffer.wrap(new byte[0]));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a note: request demand should be handled in FlatteningSubscriber as part of flatMapIterable

}
super.onComplete();
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import static software.amazon.awssdk.core.HttpChecksumConstant.DEFAULT_ASYNC_CHUNK_SIZE;
import static software.amazon.awssdk.core.HttpChecksumConstant.SIGNING_METHOD;
import static software.amazon.awssdk.core.internal.io.AwsChunkedEncodingInputStream.DEFAULT_CHUNK_SIZE;
import static software.amazon.awssdk.core.internal.io.AwsUnsignedChunkedEncodingInputStream.calculateStreamContentLength;
import static software.amazon.awssdk.core.internal.util.ChunkContentUtils.calculateChecksumContentLength;
import static software.amazon.awssdk.core.internal.util.ChunkContentUtils.calculateChecksumTrailerLength;
import static software.amazon.awssdk.core.internal.util.ChunkContentUtils.calculateStreamContentLength;
import static software.amazon.awssdk.core.internal.util.HttpChecksumResolver.getResolvedChecksumSpecs;
import static software.amazon.awssdk.http.Header.CONTENT_LENGTH;

Expand Down Expand Up @@ -179,7 +179,7 @@ private void addFlexibleChecksumInTrailer(SdkHttpFullRequest.Builder request, Re
}
}

long checksumContentLength = calculateChecksumContentLength(checksumSpecs.algorithm(), checksumSpecs.headerName());
long checksumContentLength = calculateChecksumTrailerLength(checksumSpecs.algorithm(), checksumSpecs.headerName());
long contentLen = checksumContentLength + calculateStreamContentLength(originalContentLength, chunkSize);

request.putHeader(HttpChecksumConstant.HEADER_FOR_TRAILER_REFERENCE, checksumSpecs.headerName())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.checksums.Algorithm;
import software.amazon.awssdk.core.checksums.SdkChecksum;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.internal.chunked.AwsChunkedEncodingConfig;
Expand All @@ -40,48 +39,6 @@ public static Builder builder() {
return new Builder();
}

/**
* Calculates the content length for a given Algorithm and header name.
*
* @param algorithm Algorithm used.
* @param headerName Header name.
* @return Content length of the trailer that will be appended at the end.
*/
public static long calculateChecksumContentLength(Algorithm algorithm, String headerName) {
return headerName.length()
+ HEADER_COLON_SEPARATOR.length()
+ algorithm.base64EncodedLength().longValue()
+ CRLF.length() + CRLF.length();
}

/**
*
* @param originalContentLength Original Content length.
* @return Calculatec Chunk Length with the chunk encoding format.
*/
private static long calculateChunkLength(long originalContentLength) {
return Long.toHexString(originalContentLength).length()
+ CRLF.length()
+ originalContentLength
+ CRLF.length();
}

public static long calculateStreamContentLength(long originalLength, long defaultChunkSize) {
if (originalLength < 0 || defaultChunkSize == 0) {
throw new IllegalArgumentException(originalLength + ", " + defaultChunkSize + "Args <= 0 not expected");
}

long maxSizeChunks = originalLength / defaultChunkSize;
long remainingBytes = originalLength % defaultChunkSize;

long allChunks = maxSizeChunks * calculateChunkLength(defaultChunkSize);
long remainingInChunk = remainingBytes > 0 ? calculateChunkLength(remainingBytes) : 0;
// last byte is composed of a "0" and "\r\n"
long lastByteSize = 1 + (long) CRLF.length();

return allChunks + remainingInChunk + lastByteSize;
}

@Override
protected byte[] createFinalChunk(byte[] finalChunk) {
StringBuilder chunkHeader = new StringBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,35 +28,64 @@ public final class ChunkContentUtils {
public static final String ZERO_BYTE = "0";
public static final String CRLF = "\r\n";

public static final String LAST_CHUNK = ZERO_BYTE + CRLF;
public static final long LAST_CHUNK_LEN = LAST_CHUNK.length();

private ChunkContentUtils() {
}

/**
* The chunk format is: chunk-size CRLF chunk-data CRLF.
*
* @param originalContentLength Original Content length.
* @return Calculates Chunk Length.
* @return the length of this chunk
*/
public static long calculateChunkLength(long originalContentLength) {
if (originalContentLength == 0) {
return 0;
}
return Long.toHexString(originalContentLength).length()
+ CRLF.length()
+ originalContentLength
+ CRLF.length()
+ ZERO_BYTE.length() + CRLF.length();
+ CRLF.length()
+ originalContentLength
+ CRLF.length();
}

/**
* Calculates the content length for data that is divided into chunks.
*
* @param originalLength original content length.
* @param chunkSize chunk size
* @return Content length of the trailer that will be appended at the end.
*/
public static long calculateStreamContentLength(long originalLength, long chunkSize) {
if (originalLength < 0 || chunkSize == 0) {
throw new IllegalArgumentException(originalLength + ", " + chunkSize + "Args <= 0 not expected");
}

long maxSizeChunks = originalLength / chunkSize;
long remainingBytes = originalLength % chunkSize;

long allChunks = maxSizeChunks * calculateChunkLength(chunkSize);
long remainingInChunk = remainingBytes > 0 ? calculateChunkLength(remainingBytes) : 0;
// last byte is composed of a "0" and "\r\n"
long lastByteSize = 1 + (long) CRLF.length();

return allChunks + remainingInChunk + lastByteSize;
}

/**
* Calculates the content length for a given Algorithm and header name.
* Calculates the content length for a given algorithm and header name.
*
* @param algorithm Algorithm used.
* @param headerName Header name.
* @return Content length of the trailer that will be appended at the end.
*/
public static long calculateChecksumContentLength(Algorithm algorithm, String headerName) {
int checksumLength = algorithm.base64EncodedLength();

return (headerName.length()
+ HEADER_COLON_SEPARATOR.length()
+ checksumLength
+ CRLF.length() + CRLF.length());
public static long calculateChecksumTrailerLength(Algorithm algorithm, String headerName) {
return headerName.length()
+ HEADER_COLON_SEPARATOR.length()
+ algorithm.base64EncodedLength().longValue()
+ CRLF.length()
+ CRLF.length();
}

/**
Expand Down Expand Up @@ -86,17 +115,13 @@ public static ByteBuffer createChunk(ByteBuffer chunkData, boolean isLastByte) {
chunkHeader.append(CRLF);
try {
byte[] header = chunkHeader.toString().getBytes(StandardCharsets.UTF_8);
// Last byte does not need additional \r\n trailer
byte[] trailer = !isLastByte ? CRLF.getBytes(StandardCharsets.UTF_8)
: "".getBytes(StandardCharsets.UTF_8);
ByteBuffer chunkFormattedBuffer = ByteBuffer.allocate(header.length + chunkLength + trailer.length);
chunkFormattedBuffer.put(header)
.put(chunkData)
.put(trailer);
chunkFormattedBuffer.put(header).put(chunkData).put(trailer);
chunkFormattedBuffer.flip();
return chunkFormattedBuffer;
} catch (Exception e) {
// This is to warp BufferOverflowException,ReadOnlyBufferException to SdkClientException.
throw SdkClientException.builder()
.message("Unable to create chunked data. " + e.getMessage())
.cause(e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package software.amazon.awssdk.core.checksum;

import static org.assertj.core.api.Assertions.assertThat;
import static software.amazon.awssdk.core.internal.util.ChunkContentUtils.calculateChecksumTrailerLength;

import java.io.ByteArrayInputStream;
import java.io.IOException;
Expand All @@ -25,6 +26,7 @@
import software.amazon.awssdk.core.checksums.SdkChecksum;
import software.amazon.awssdk.core.internal.io.AwsChunkedEncodingInputStream;
import software.amazon.awssdk.core.internal.io.AwsUnsignedChunkedEncodingInputStream;
import software.amazon.awssdk.core.internal.util.ChunkContentUtils;

public class AwsChunkedEncodingInputStreamTest {

Expand Down Expand Up @@ -55,10 +57,9 @@ public void readAwsUnsignedChunkedEncodingInputStream() throws IOException {
public void lengthsOfCalculateByChecksumCalculatingInputStream(){

String initialString = "Hello world";
long calculateChunkLength = AwsUnsignedChunkedEncodingInputStream.calculateStreamContentLength(initialString.length(),
AwsChunkedEncodingInputStream.DEFAULT_CHUNK_SIZE);
long checksumContentLength = AwsUnsignedChunkedEncodingInputStream.calculateChecksumContentLength(
SHA256_ALGORITHM, SHA256_HEADER_NAME);
long calculateChunkLength = ChunkContentUtils.calculateStreamContentLength(initialString.length(),
AwsChunkedEncodingInputStream.DEFAULT_CHUNK_SIZE);
long checksumContentLength = calculateChecksumTrailerLength(SHA256_ALGORITHM, SHA256_HEADER_NAME);
assertThat(calculateChunkLength).isEqualTo(19);
assertThat(checksumContentLength).isEqualTo(71);
}
Expand Down
Loading