Skip to content

THRIFT-5862: Validate the message size at the endpoint transport only #3127

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public TDeserializer() throws TTransportException {
* @throws TTransportException if there an error initializing the underlying transport.
*/
public TDeserializer(TProtocolFactory protocolFactory) throws TTransportException {
trans_ = new TMemoryInputTransport(new TConfiguration());
trans_ = new TMemoryInputTransport();
protocol_ = protocolFactory.getProtocol(trans_);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ public void writeBinary(ByteBuffer bin) throws TException {
/** Reading methods. */
@Override
public TMessage readMessageBegin() throws TException {
trans_.readMessageBegin();
int size = readI32();
if (size < 0) {
int version = size & VERSION_MASK;
Expand All @@ -286,7 +287,9 @@ public TMessage readMessageBegin() throws TException {
}

@Override
public void readMessageEnd() throws TException {}
public void readMessageEnd() throws TException {
trans_.readMessageEnd();
}

@Override
public TStruct readStructBegin() throws TException {
Expand All @@ -310,7 +313,6 @@ public void readFieldEnd() throws TException {}
public TMap readMapBegin() throws TException {
TMap map = new TMap(readByte(), readByte(), readI32());

checkReadBytesAvailable(map);
checkContainerReadLength(map.size);
return map;
}
Expand All @@ -322,7 +324,6 @@ public void readMapEnd() throws TException {}
public TList readListBegin() throws TException {
TList list = new TList(readByte(), readI32());

checkReadBytesAvailable(list);
checkContainerReadLength(list.size);
return list;
}
Expand All @@ -334,7 +335,6 @@ public void readListEnd() throws TException {}
public TSet readSetBegin() throws TException {
TSet set = new TSet(readByte(), readI32());

checkReadBytesAvailable(set);
checkContainerReadLength(set.size);
return set;
}
Expand Down Expand Up @@ -497,8 +497,6 @@ private void checkStringReadLength(int length) throws TException {
throw new TProtocolException(TProtocolException.NEGATIVE_SIZE, "Negative length: " + length);
}

getTransport().checkReadBytesAvailable(length);

if (stringLengthLimit_ != NO_LENGTH_LIMIT && length > stringLengthLimit_) {
throw new TProtocolException(
TProtocolException.SIZE_LIMIT, "Length exceeded max allowed: " + length);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,7 @@ public TMessage readMessageBegin() throws TException {
byte type = (byte) ((versionAndType >> TYPE_SHIFT_AMOUNT) & TYPE_BITS);
int seqid = readVarint32();
String messageName = readString();
trans_.readMessageBegin();
return new TMessage(messageName, type, seqid);
}

Expand Down Expand Up @@ -575,7 +576,6 @@ public TMap readMapBegin() throws TException {
getTType((byte) (keyAndValueType >> 4)),
getTType((byte) (keyAndValueType & 0xf)),
size);
checkReadBytesAvailable(map);
return map;
}

Expand All @@ -593,7 +593,6 @@ public TList readListBegin() throws TException {
}
checkContainerReadLength(size);
TList list = new TList(getTType(size_and_type), size);
checkReadBytesAvailable(list);
return list;
}

Expand Down Expand Up @@ -697,7 +696,6 @@ public ByteBuffer readBinary() throws TException {
if (length == 0) {
return EMPTY_BUFFER;
}
getTransport().checkReadBytesAvailable(length);
if (trans_.getBytesRemainingInBuffer() >= length) {
ByteBuffer bb = ByteBuffer.wrap(trans_.getBuffer(), trans_.getBufferPosition(), length);
trans_.consumeBuffer(length);
Expand All @@ -723,8 +721,6 @@ private void checkStringReadLength(int length) throws TException {
throw new TProtocolException(TProtocolException.NEGATIVE_SIZE, "Negative length: " + length);
}

getTransport().checkReadBytesAvailable(length);

if (stringLengthLimit_ != NO_LENGTH_LIMIT && length > stringLengthLimit_) {
throw new TProtocolException(
TProtocolException.SIZE_LIMIT, "Length exceeded max allowed: " + length);
Expand All @@ -746,7 +742,9 @@ private void checkContainerReadLength(int length) throws TProtocolException {
// encoding.
//
@Override
public void readMessageEnd() throws TException {}
public void readMessageEnd() throws TException {
trans_.readMessageEnd();
}

@Override
public void readFieldEnd() throws TException {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -832,12 +832,14 @@ public TMessage readMessageBegin() throws TException {
String name = readJSONString(false).toString(StandardCharsets.UTF_8);
byte type = (byte) readJSONInteger();
int seqid = (int) readJSONInteger();
trans_.readMessageBegin();
return new TMessage(name, type, seqid);
}

@Override
public void readMessageEnd() throws TException {
readJSONArrayEnd();
trans_.readMessageEnd();
}

@Override
Expand Down Expand Up @@ -880,7 +882,6 @@ public TMap readMapBegin() throws TException {
readJSONObjectStart();
TMap map = new TMap(keyType, valueType, size);

checkReadBytesAvailable(map);
return map;
}

Expand All @@ -897,7 +898,6 @@ public TList readListBegin() throws TException {
int size = (int) readJSONInteger();
TList list = new TList(elemType, size);

checkReadBytesAvailable(list);
return list;
}

Expand All @@ -913,7 +913,6 @@ public TSet readSetBegin() throws TException {
int size = (int) readJSONInteger();
TSet set = new TSet(elemType, size);

checkReadBytesAvailable(set);
return set;
}

Expand Down
15 changes: 0 additions & 15 deletions lib/java/src/main/java/org/apache/thrift/protocol/TProtocol.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,21 +52,6 @@ public TTransport getTransport() {
return trans_;
}

protected void checkReadBytesAvailable(TMap map) throws TException {
long elemSize = getMinSerializedSize(map.keyType) + getMinSerializedSize(map.valueType);
trans_.checkReadBytesAvailable(map.size * elemSize);
}

protected void checkReadBytesAvailable(TList list) throws TException {
long size = list.getSize();
trans_.checkReadBytesAvailable(size * getMinSerializedSize(list.elemType));
}

protected void checkReadBytesAvailable(TSet set) throws TException {
long size = set.getSize();
trans_.checkReadBytesAvailable(size * getMinSerializedSize(set.elemType));
}

/**
* Return min serialized size in bytes
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,13 @@ public TMap readMapBegin(byte keyType, byte valTyep) throws TException {
int size = super.readI32();
TMap map = new TMap(keyType, valTyep, size);

checkReadBytesAvailable(map);
return map;
}

public TList readListBegin(byte type) throws TException {
int size = super.readI32();
TList list = new TList(type, size);

checkReadBytesAvailable(list);
return list;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,15 @@
*/
package org.apache.thrift.transport;

import org.apache.thrift.TConfiguration;

/** TTransport for reading from an AutoExpandingBuffer. */
public class AutoExpandingBufferReadTransport extends TEndpointTransport {
public class AutoExpandingBufferReadTransport extends TTransport {

private final AutoExpandingBuffer buf;

private int pos = 0;
private int limit = 0;

public AutoExpandingBufferReadTransport(TConfiguration config, int initialCapacity)
throws TTransportException {
super(config);
public AutoExpandingBufferReadTransport(int initialCapacity) {
this.buf = new AutoExpandingBuffer(initialCapacity);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@
*/
package org.apache.thrift.transport;

import org.apache.thrift.TConfiguration;

/** TTransport for writing to an AutoExpandingBuffer. */
public final class AutoExpandingBufferWriteTransport extends TEndpointTransport {
public final class AutoExpandingBufferWriteTransport extends TTransport {

private final AutoExpandingBuffer buf;
private int pos;
Expand All @@ -30,7 +28,6 @@ public final class AutoExpandingBufferWriteTransport extends TEndpointTransport
/**
* Constructor.
*
* @param config the configuration to use. Currently used for defining the maximum message size.
* @param initialCapacity the initial capacity of the buffer
* @param frontReserve space, if any, to reserve at the beginning such that the first write is
* after this reserve. This allows framed transport to reserve space for the frame buffer
Expand All @@ -39,9 +36,8 @@ public final class AutoExpandingBufferWriteTransport extends TEndpointTransport
* @throws IllegalArgumentException if frontReserve is less than zero
* @throws IllegalArgumentException if frontReserve is greater than initialCapacity
*/
public AutoExpandingBufferWriteTransport(
TConfiguration config, int initialCapacity, int frontReserve) throws TTransportException {
super(config);
public AutoExpandingBufferWriteTransport(int initialCapacity, int frontReserve)
throws TTransportException {
if (initialCapacity < 1) {
throw new IllegalArgumentException("initialCapacity");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,18 @@
import java.nio.BufferOverflowException;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import org.apache.thrift.TConfiguration;

/** ByteBuffer-backed implementation of TTransport. */
public final class TByteBuffer extends TEndpointTransport {
public final class TByteBuffer extends TTransport {
private final ByteBuffer byteBuffer;

/**
* Creates a new TByteBuffer wrapping a given NIO ByteBuffer and custom TConfiguration.
*
* @param configuration the custom TConfiguration.
* @param byteBuffer the NIO ByteBuffer to wrap.
* @throws TTransportException on error.
*/
public TByteBuffer(TConfiguration configuration, ByteBuffer byteBuffer)
throws TTransportException {
super(configuration);
this.byteBuffer = byteBuffer;
updateKnownMessageSize(byteBuffer.capacity());
}

/**
* Creates a new TByteBuffer wrapping a given NIO ByteBuffer.
*
* @param byteBuffer the NIO ByteBuffer to wrap.
* @throws TTransportException on error.
*/
public TByteBuffer(ByteBuffer byteBuffer) throws TTransportException {
this(new TConfiguration(), byteBuffer);
public TByteBuffer(ByteBuffer byteBuffer) {
this.byteBuffer = byteBuffer;
}

@Override
Expand All @@ -47,8 +31,6 @@ public void close() {}
@Override
public int read(byte[] buf, int off, int len) throws TTransportException {
//
checkReadBytesAvailable(len);

final int n = Math.min(byteBuffer.remaining(), len);
if (n > 0) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,82 +35,32 @@ public void setMaxFrameSize(int maxFrameSize) {
getConfiguration().setMaxFrameSize(maxFrameSize);
}

protected long knownMessageSize;
protected long remainingMessageSize;
private long consumedMessage;

private TConfiguration _configuration;
private final TConfiguration _configuration;

public TConfiguration getConfiguration() {
return _configuration;
}

public TEndpointTransport(TConfiguration config) throws TTransportException {
public TEndpointTransport(TConfiguration config) {
_configuration = Objects.isNull(config) ? new TConfiguration() : config;

resetConsumedMessageSize(-1);
}

/**
* Resets RemainingMessageSize to the configured maximum
*
* @param newSize
*/
protected void resetConsumedMessageSize(long newSize) throws TTransportException {
// full reset
if (newSize < 0) {
knownMessageSize = getMaxMessageSize();
remainingMessageSize = getMaxMessageSize();
return;
}

// update only: message size can shrink, but not grow
if (newSize > knownMessageSize)
throw new TTransportException(
TTransportException.MESSAGE_SIZE_LIMIT,
"Message size exceeds limit: " + getMaxMessageSize());

knownMessageSize = newSize;
remainingMessageSize = newSize;
@Override
public void readMessageBegin() {
consumedMessage = 0;
}

/**
* Updates RemainingMessageSize to reflect then known real message size (e.g. framed transport).
* Will throw if we already consumed too many bytes or if the new size is larger than allowed.
*
* @param size
*/
public void updateKnownMessageSize(long size) throws TTransportException {
long consumed = knownMessageSize - remainingMessageSize;
resetConsumedMessageSize(size == 0 ? -1 : size);
countConsumedMessageBytes(consumed);
}

/**
* Throws if there are not enough bytes in the input stream to satisfy a read of numBytes bytes of
* data
*
* @param numBytes
*/
public void checkReadBytesAvailable(long numBytes) throws TTransportException {
if (remainingMessageSize < numBytes)
public final void consumeReadMessageBytes(int size) throws TTransportException {
consumedMessage += size;
if (consumedMessage > getMaxMessageSize())
throw new TTransportException(
TTransportException.MESSAGE_SIZE_LIMIT,
"Message size exceeds limit: " + getMaxMessageSize());
}

/**
* Consumes numBytes from the RemainingMessageSize.
*
* @param numBytes
*/
protected void countConsumedMessageBytes(long numBytes) throws TTransportException {
if (remainingMessageSize >= numBytes) {
remainingMessageSize -= numBytes;
} else {
remainingMessageSize = 0;
throw new TTransportException(
TTransportException.MESSAGE_SIZE_LIMIT,
"Message size exceeds limit: " + getMaxMessageSize());
}
public long getConsumedMessage() {
return consumedMessage;
}
}
Loading
Loading