Skip to content

Commit ac94b6e

Browse files
committed
netty: Add option to limit RST_STREAM rate
The behavior purposefully mirrors that of Netty's AbstractHttp2ConnectionHandlerBuilder.decoderEnforceMaxRstFramesPerWindow(). That API is not available to our code as we extend the Http2ConnectionHandler, but we want our API to be able to delegate to Netty's in the future if that ever becomes possible.
1 parent 4caf106 commit ac94b6e

File tree

7 files changed

+140
-2
lines changed

7 files changed

+140
-2
lines changed

netty/src/main/java/io/grpc/netty/NettyServer.java

+7
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ class NettyServer implements InternalServer, InternalWithLogId {
9999
private final long maxConnectionAgeGraceInNanos;
100100
private final boolean permitKeepAliveWithoutCalls;
101101
private final long permitKeepAliveTimeInNanos;
102+
private final int maxRstCount;
103+
private final long maxRstPeriodNanos;
102104
private final Attributes eagAttributes;
103105
private final ReferenceCounted sharedResourceReferenceCounter =
104106
new SharedResourceReferenceCounter();
@@ -127,6 +129,7 @@ class NettyServer implements InternalServer, InternalWithLogId {
127129
long maxConnectionIdleInNanos,
128130
long maxConnectionAgeInNanos, long maxConnectionAgeGraceInNanos,
129131
boolean permitKeepAliveWithoutCalls, long permitKeepAliveTimeInNanos,
132+
int maxRstCount, long maxRstPeriodNanos,
130133
Attributes eagAttributes, InternalChannelz channelz) {
131134
this.addresses = checkNotNull(addresses, "addresses");
132135
this.channelFactory = checkNotNull(channelFactory, "channelFactory");
@@ -156,6 +159,8 @@ class NettyServer implements InternalServer, InternalWithLogId {
156159
this.maxConnectionAgeGraceInNanos = maxConnectionAgeGraceInNanos;
157160
this.permitKeepAliveWithoutCalls = permitKeepAliveWithoutCalls;
158161
this.permitKeepAliveTimeInNanos = permitKeepAliveTimeInNanos;
162+
this.maxRstCount = maxRstCount;
163+
this.maxRstPeriodNanos = maxRstPeriodNanos;
159164
this.eagAttributes = checkNotNull(eagAttributes, "eagAttributes");
160165
this.channelz = Preconditions.checkNotNull(channelz);
161166
this.logId = InternalLogId.allocate(getClass(), addresses.isEmpty() ? "No address" :
@@ -257,6 +262,8 @@ public void initChannel(Channel ch) {
257262
maxConnectionAgeGraceInNanos,
258263
permitKeepAliveWithoutCalls,
259264
permitKeepAliveTimeInNanos,
265+
maxRstCount,
266+
maxRstPeriodNanos,
260267
eagAttributes);
261268
ServerTransportListener transportListener;
262269
// This is to order callbacks on the listener, not to guard access to channel.

netty/src/main/java/io/grpc/netty/NettyServerBuilder.java

+31-1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ public final class NettyServerBuilder extends ForwardingServerBuilder<NettyServe
7575
static final long MAX_CONNECTION_IDLE_NANOS_DISABLED = Long.MAX_VALUE;
7676
static final long MAX_CONNECTION_AGE_NANOS_DISABLED = Long.MAX_VALUE;
7777
static final long MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE = Long.MAX_VALUE;
78+
static final int MAX_RST_COUNT_DISABLED = 0;
7879

7980
private static final long MIN_KEEPALIVE_TIME_NANO = TimeUnit.MILLISECONDS.toNanos(1L);
8081
private static final long MIN_KEEPALIVE_TIMEOUT_NANO = TimeUnit.MICROSECONDS.toNanos(499L);
@@ -113,6 +114,8 @@ public final class NettyServerBuilder extends ForwardingServerBuilder<NettyServe
113114
private long maxConnectionAgeGraceInNanos = MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE;
114115
private boolean permitKeepAliveWithoutCalls;
115116
private long permitKeepAliveTimeInNanos = TimeUnit.MINUTES.toNanos(5);
117+
private int maxRstCount;
118+
private long maxRstPeriodNanos;
116119
private Attributes eagAttributes = Attributes.EMPTY;
117120

118121
/**
@@ -644,6 +647,33 @@ public NettyServerBuilder permitKeepAliveWithoutCalls(boolean permit) {
644647
return this;
645648
}
646649

650+
/**
651+
* Limits the rate of incoming RST_STREAM frames per connection to maxRstStream per
652+
* secondsPerWindow. When exceeded on a connection, the connection is closed. This can reduce the
653+
* impact of an attacker continually resetting RPCs before they complete, when combined with TLS
654+
* and {@link #maxConcurrentCallsPerConnection(int)}.
655+
*
656+
* <p>gRPC clients send RST_STREAM when they cancel RPCs, so some RST_STREAMs are normal and
657+
* setting this too low can cause errors for legimitate clients.
658+
*
659+
* <p>By default there is no limit.
660+
*
661+
* @param maxRstStream the positive limit of RST_STREAM frames per connection per period, or
662+
* {@code Integer.MAX_VALUE} for unlimited
663+
* @param secondsPerWindow the positive number of seconds per period
664+
*/
665+
@CanIgnoreReturnValue
666+
public NettyServerBuilder maxRstFramesPerWindow(int maxRstStream, int secondsPerWindow) {
667+
checkArgument(maxRstStream > 0, "maxRstStream must be positive");
668+
checkArgument(secondsPerWindow > 0, "secondsPerWindow must be positive");
669+
if (maxRstStream == Integer.MAX_VALUE) {
670+
maxRstStream = MAX_RST_COUNT_DISABLED;
671+
}
672+
this.maxRstCount = maxRstStream;
673+
this.maxRstPeriodNanos = TimeUnit.SECONDS.toNanos(secondsPerWindow);
674+
return this;
675+
}
676+
647677
/** Sets the EAG attributes available to protocol negotiators. Not for general use. */
648678
void eagAttributes(Attributes eagAttributes) {
649679
this.eagAttributes = checkNotNull(eagAttributes, "eagAttributes");
@@ -664,7 +694,7 @@ NettyServer buildTransportServers(
664694
keepAliveTimeInNanos, keepAliveTimeoutInNanos,
665695
maxConnectionIdleInNanos, maxConnectionAgeInNanos,
666696
maxConnectionAgeGraceInNanos, permitKeepAliveWithoutCalls, permitKeepAliveTimeInNanos,
667-
eagAttributes, this.serverImplBuilder.getChannelz());
697+
maxRstCount, maxRstPeriodNanos, eagAttributes, this.serverImplBuilder.getChannelz());
668698
}
669699

670700
@VisibleForTesting

netty/src/main/java/io/grpc/netty/NettyServerHandler.java

+40
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,13 @@ class NettyServerHandler extends AbstractNettyHandler {
125125
private final long keepAliveTimeoutInNanos;
126126
private final long maxConnectionAgeInNanos;
127127
private final long maxConnectionAgeGraceInNanos;
128+
private final int maxRstCount;
129+
private final long maxRstPeriodNanos;
128130
private final List<? extends ServerStreamTracer.Factory> streamTracerFactories;
129131
private final TransportTracer transportTracer;
130132
private final KeepAliveEnforcer keepAliveEnforcer;
131133
private final Attributes eagAttributes;
134+
private final Ticker ticker;
132135
/** Incomplete attributes produced by negotiator. */
133136
private Attributes negotiationAttributes;
134137
private InternalChannelz.Security securityInfo;
@@ -146,6 +149,9 @@ class NettyServerHandler extends AbstractNettyHandler {
146149
private ScheduledFuture<?> maxConnectionAgeMonitor;
147150
@CheckForNull
148151
private GracefulShutdown gracefulShutdown;
152+
private int rstCount;
153+
private long lastRstNanoTime;
154+
149155

150156
static NettyServerHandler newHandler(
151157
ServerTransportListener transportListener,
@@ -164,6 +170,8 @@ static NettyServerHandler newHandler(
164170
long maxConnectionAgeGraceInNanos,
165171
boolean permitKeepAliveWithoutCalls,
166172
long permitKeepAliveTimeInNanos,
173+
int maxRstCount,
174+
long maxRstPeriodNanos,
167175
Attributes eagAttributes) {
168176
Preconditions.checkArgument(maxHeaderListSize > 0, "maxHeaderListSize must be positive: %s",
169177
maxHeaderListSize);
@@ -192,6 +200,8 @@ static NettyServerHandler newHandler(
192200
maxConnectionAgeGraceInNanos,
193201
permitKeepAliveWithoutCalls,
194202
permitKeepAliveTimeInNanos,
203+
maxRstCount,
204+
maxRstPeriodNanos,
195205
eagAttributes,
196206
Ticker.systemTicker());
197207
}
@@ -215,6 +225,8 @@ static NettyServerHandler newHandler(
215225
long maxConnectionAgeGraceInNanos,
216226
boolean permitKeepAliveWithoutCalls,
217227
long permitKeepAliveTimeInNanos,
228+
int maxRstCount,
229+
long maxRstPeriodNanos,
218230
Attributes eagAttributes,
219231
Ticker ticker) {
220232
Preconditions.checkArgument(maxStreams > 0, "maxStreams must be positive: %s", maxStreams);
@@ -266,6 +278,8 @@ static NettyServerHandler newHandler(
266278
maxConnectionAgeInNanos, maxConnectionAgeGraceInNanos,
267279
keepAliveEnforcer,
268280
autoFlowControl,
281+
maxRstCount,
282+
maxRstPeriodNanos,
269283
eagAttributes, ticker);
270284
}
271285

@@ -286,6 +300,8 @@ private NettyServerHandler(
286300
long maxConnectionAgeGraceInNanos,
287301
final KeepAliveEnforcer keepAliveEnforcer,
288302
boolean autoFlowControl,
303+
int maxRstCount,
304+
long maxRstPeriodNanos,
289305
Attributes eagAttributes,
290306
Ticker ticker) {
291307
super(channelUnused, decoder, encoder, settings, new ServerChannelLogger(),
@@ -328,8 +344,12 @@ public void onStreamClosed(Http2Stream stream) {
328344
this.maxConnectionAgeInNanos = maxConnectionAgeInNanos;
329345
this.maxConnectionAgeGraceInNanos = maxConnectionAgeGraceInNanos;
330346
this.keepAliveEnforcer = checkNotNull(keepAliveEnforcer, "keepAliveEnforcer");
347+
this.maxRstCount = maxRstCount;
348+
this.maxRstPeriodNanos = maxRstPeriodNanos;
331349
this.eagAttributes = checkNotNull(eagAttributes, "eagAttributes");
350+
this.ticker = checkNotNull(ticker, "ticker");
332351

352+
this.lastRstNanoTime = ticker.read();
333353
streamKey = encoder.connection().newKey();
334354
this.transportListener = checkNotNull(transportListener, "transportListener");
335355
this.streamTracerFactories = checkNotNull(streamTracerFactories, "streamTracerFactories");
@@ -527,6 +547,26 @@ private void onDataRead(int streamId, ByteBuf data, int padding, boolean endOfSt
527547
}
528548

529549
private void onRstStreamRead(int streamId, long errorCode) throws Http2Exception {
550+
if (maxRstCount > 0) {
551+
long now = ticker.read();
552+
if (now - lastRstNanoTime > maxRstPeriodNanos) {
553+
lastRstNanoTime = now;
554+
rstCount = 1;
555+
} else {
556+
rstCount++;
557+
if (rstCount > maxRstCount) {
558+
throw new Http2Exception(Http2Error.ENHANCE_YOUR_CALM, "too_many_rststreams") {
559+
@SuppressWarnings("UnsynchronizedOverridesSynchronized") // No memory accesses
560+
@Override
561+
public Throwable fillInStackTrace() {
562+
// Avoid the CPU cycles, since the resets may be a CPU consumption attack
563+
return this;
564+
}
565+
};
566+
}
567+
}
568+
}
569+
530570
try {
531571
NettyServerStream.TransportState stream = serverStream(connection().stream(streamId));
532572
if (stream != null) {

netty/src/main/java/io/grpc/netty/NettyServerTransport.java

+8
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ class NettyServerTransport implements ServerTransport {
7777
private final long maxConnectionAgeGraceInNanos;
7878
private final boolean permitKeepAliveWithoutCalls;
7979
private final long permitKeepAliveTimeInNanos;
80+
private final int maxRstCount;
81+
private final long maxRstPeriodNanos;
8082
private final Attributes eagAttributes;
8183
private final List<? extends ServerStreamTracer.Factory> streamTracerFactories;
8284
private final TransportTracer transportTracer;
@@ -99,6 +101,8 @@ class NettyServerTransport implements ServerTransport {
99101
long maxConnectionAgeGraceInNanos,
100102
boolean permitKeepAliveWithoutCalls,
101103
long permitKeepAliveTimeInNanos,
104+
int maxRstCount,
105+
long maxRstPeriodNanos,
102106
Attributes eagAttributes) {
103107
this.channel = Preconditions.checkNotNull(channel, "channel");
104108
this.channelUnused = channelUnused;
@@ -118,6 +122,8 @@ class NettyServerTransport implements ServerTransport {
118122
this.maxConnectionAgeGraceInNanos = maxConnectionAgeGraceInNanos;
119123
this.permitKeepAliveWithoutCalls = permitKeepAliveWithoutCalls;
120124
this.permitKeepAliveTimeInNanos = permitKeepAliveTimeInNanos;
125+
this.maxRstCount = maxRstCount;
126+
this.maxRstPeriodNanos = maxRstPeriodNanos;
121127
this.eagAttributes = Preconditions.checkNotNull(eagAttributes, "eagAttributes");
122128
SocketAddress remote = channel.remoteAddress();
123129
this.logId = InternalLogId.allocate(getClass(), remote != null ? remote.toString() : null);
@@ -277,6 +283,8 @@ private NettyServerHandler createHandler(
277283
maxConnectionAgeGraceInNanos,
278284
permitKeepAliveWithoutCalls,
279285
permitKeepAliveTimeInNanos,
286+
maxRstCount,
287+
maxRstPeriodNanos,
280288
eagAttributes);
281289
}
282290
}

netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE;
2828
import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_AGE_NANOS_DISABLED;
2929
import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_IDLE_NANOS_DISABLED;
30+
import static io.grpc.netty.NettyServerBuilder.MAX_RST_COUNT_DISABLED;
3031
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE;
3132
import static org.junit.Assert.assertEquals;
3233
import static org.junit.Assert.assertFalse;
@@ -781,7 +782,7 @@ private void startServer(int maxStreamsPerConnection, int maxHeaderListSize) thr
781782
DEFAULT_SERVER_KEEPALIVE_TIME_NANOS, DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS,
782783
MAX_CONNECTION_IDLE_NANOS_DISABLED,
783784
MAX_CONNECTION_AGE_NANOS_DISABLED, MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE, true, 0,
784-
Attributes.EMPTY,
785+
MAX_RST_COUNT_DISABLED, 0, Attributes.EMPTY,
785786
channelz);
786787
server.start(serverListener);
787788
address = TestUtils.testServerAddress((InetSocketAddress) server.getListenSocketAddress());

netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java

+45
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE;
2424
import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_AGE_NANOS_DISABLED;
2525
import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_IDLE_NANOS_DISABLED;
26+
import static io.grpc.netty.NettyServerBuilder.MAX_RST_COUNT_DISABLED;
2627
import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC;
2728
import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER;
2829
import static io.grpc.netty.Utils.HTTP_METHOD;
@@ -33,6 +34,7 @@
3334
import static org.junit.Assert.assertFalse;
3435
import static org.junit.Assert.assertNull;
3536
import static org.junit.Assert.assertSame;
37+
import static org.junit.Assert.assertThrows;
3638
import static org.junit.Assert.assertTrue;
3739
import static org.mockito.AdditionalAnswers.delegatesTo;
3840
import static org.mockito.ArgumentMatchers.any;
@@ -85,6 +87,7 @@
8587
import io.netty.handler.codec.http2.Http2Stream;
8688
import io.netty.util.AsciiString;
8789
import java.io.InputStream;
90+
import java.nio.channels.ClosedChannelException;
8891
import java.util.Arrays;
8992
import java.util.LinkedList;
9093
import java.util.List;
@@ -143,6 +146,8 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
143146
private long maxConnectionAgeGraceInNanos = MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE;
144147
private long keepAliveTimeInNanos = DEFAULT_SERVER_KEEPALIVE_TIME_NANOS;
145148
private long keepAliveTimeoutInNanos = DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS;
149+
private int maxRstCount = MAX_RST_COUNT_DISABLED;
150+
private long maxRstPeriodNanos;
146151

147152
private class ServerTransportListenerImpl implements ServerTransportListener {
148153

@@ -1249,6 +1254,44 @@ public void maxConnectionAgeGrace_channelClosedAfterGracePeriod_withPingAck()
12491254
assertFalse(channel().isOpen());
12501255
}
12511256

1257+
@Test
1258+
public void maxRstCount_withinLimit_succeeds() throws Exception {
1259+
maxRstCount = 10;
1260+
maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100);
1261+
manualSetUp();
1262+
rapidReset(maxRstCount);
1263+
assertTrue(channel().isOpen());
1264+
}
1265+
1266+
@Test
1267+
public void maxRstCount_exceedsLimit_fails() throws Exception {
1268+
maxRstCount = 10;
1269+
maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100);
1270+
manualSetUp();
1271+
assertThrows(ClosedChannelException.class, () -> rapidReset(maxRstCount + 1));
1272+
assertFalse(channel().isOpen());
1273+
}
1274+
1275+
private void rapidReset(int burstSize) throws Exception {
1276+
Http2Headers headers = new DefaultHttp2Headers()
1277+
.method(HTTP_METHOD)
1278+
.set(CONTENT_TYPE_HEADER, new AsciiString("application/grpc", UTF_8))
1279+
.set(TE_HEADER, TE_TRAILERS)
1280+
.path(new AsciiString("/foo/bar"));
1281+
int streamId = 1;
1282+
long rpcTimeNanos = maxRstPeriodNanos / 2 / burstSize;
1283+
for (int period = 0; period < 3; period++) {
1284+
for (int i = 0; i < burstSize; i++) {
1285+
channelRead(headersFrame(streamId, headers));
1286+
channelRead(rstStreamFrame(streamId, (int) Http2Error.CANCEL.code()));
1287+
streamId += 2;
1288+
fakeClock().forwardNanos(rpcTimeNanos);
1289+
}
1290+
while (channel().readOutbound() != null) {}
1291+
fakeClock().forwardNanos(maxRstPeriodNanos - rpcTimeNanos * burstSize + 1);
1292+
}
1293+
}
1294+
12521295
private void createStream() throws Exception {
12531296
Http2Headers headers = new DefaultHttp2Headers()
12541297
.method(HTTP_METHOD)
@@ -1296,6 +1339,8 @@ protected NettyServerHandler newHandler() {
12961339
maxConnectionAgeGraceInNanos,
12971340
permitKeepAliveWithoutCalls,
12981341
permitKeepAliveTimeInNanos,
1342+
maxRstCount,
1343+
maxRstPeriodNanos,
12991344
Attributes.EMPTY,
13001345
fakeClock().getTicker());
13011346
}

netty/src/test/java/io/grpc/netty/NettyServerTest.java

+7
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ class NoHandlerProtocolNegotiator implements ProtocolNegotiator {
153153
1, 1, // ignore
154154
1, 1, // ignore
155155
true, 0, // ignore
156+
0, 0, // ignore
156157
Attributes.EMPTY,
157158
channelz);
158159
final SettableFuture<Void> serverShutdownCalled = SettableFuture.create();
@@ -203,6 +204,7 @@ public void multiPortStartStopGet() throws Exception {
203204
1, 1, // ignore
204205
1, 1, // ignore
205206
true, 0, // ignore
207+
0, 0, // ignore
206208
Attributes.EMPTY,
207209
channelz);
208210
final SettableFuture<Void> shutdownCompleted = SettableFuture.create();
@@ -276,6 +278,7 @@ public void multiPortConnections() throws Exception {
276278
1, 1, // ignore
277279
1, 1, // ignore
278280
true, 0, // ignore
281+
0, 0, // ignore
279282
Attributes.EMPTY,
280283
channelz);
281284
final SettableFuture<Void> shutdownCompleted = SettableFuture.create();
@@ -337,6 +340,7 @@ public void getPort_notStarted() {
337340
1, 1, // ignore
338341
1, 1, // ignore
339342
true, 0, // ignore
343+
0, 0, // ignore
340344
Attributes.EMPTY,
341345
channelz);
342346

@@ -411,6 +415,7 @@ class TestProtocolNegotiator implements ProtocolNegotiator {
411415
1, 1, // ignore
412416
1, 1, // ignore
413417
true, 0, // ignore
418+
0, 0, // ignore
414419
eagAttributes,
415420
channelz);
416421
ns.start(new ServerListener() {
@@ -458,6 +463,7 @@ public void channelzListenSocket() throws Exception {
458463
1, 1, // ignore
459464
1, 1, // ignore
460465
true, 0, // ignore
466+
0, 0, // ignore
461467
Attributes.EMPTY,
462468
channelz);
463469
final SettableFuture<Void> shutdownCompleted = SettableFuture.create();
@@ -600,6 +606,7 @@ private NettyServer getServer(List<SocketAddress> addr, EventLoopGroup ev) {
600606
1, 1, // ignore
601607
1, 1, // ignore
602608
true, 0, // ignore
609+
0, 0, // ignore
603610
Attributes.EMPTY,
604611
channelz);
605612
}

0 commit comments

Comments
 (0)