Skip to content

Commit d2d624d

Browse files
authored
fix: Make sure outstanding RPCs count in ChannelPool can not go negative (#2185)
Add two flags wasClosed and wasReleased to ReleasingClientCall to check various scenarios. The combination of these two flags can make sure the count of outstanding RPCs can never go negative, and help us identify what exactly goes wrong next time it happens.
1 parent 860ae76 commit d2d624d

File tree

2 files changed

+125
-16
lines changed

2 files changed

+125
-16
lines changed

gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java

+42-6
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
* <p>Package-private for internal use.
6969
*/
7070
class ChannelPool extends ManagedChannel {
71-
private static final Logger LOG = Logger.getLogger(ChannelPool.class.getName());
71+
@VisibleForTesting static final Logger LOG = Logger.getLogger(ChannelPool.class.getName());
7272
private static final Duration REFRESH_PERIOD = Duration.ofMinutes(50);
7373

7474
private final ChannelPoolSettings settings;
@@ -421,9 +421,25 @@ private Entry getEntry(int affinity) {
421421
}
422422

423423
/** Bundles a gRPC {@link ManagedChannel} with some usage accounting. */
424-
private static class Entry {
424+
static class Entry {
425425
private final ManagedChannel channel;
426-
private final AtomicInteger outstandingRpcs = new AtomicInteger(0);
426+
427+
/**
428+
* The primary purpose of keeping a count for outstanding RPCs is to track when a channel is
429+
* safe to close. In grpc, initialization & starting of rpcs is split between 2 methods:
430+
* Channel#newCall() and ClientCall#start. gRPC already has a mechanism to safely close channels
431+
* that have rpcs that have been started. However, it does not protect calls that have been
432+
* created but not started. In the sequence: Channel#newCall() Channel#shutdown()
433+
* ClientCall#Start(), gRpc will error out the call telling the caller that the channel is
434+
* shutdown.
435+
*
436+
* <p>Hence, the increment of outstanding RPCs has to happen when the ClientCall is initialized,
437+
* as part of Channel#newCall(), not after the ClientCall is started. The decrement of
438+
* outstanding RPCs has to happen when the ClientCall is closed or the ClientCall failed to
439+
* start.
440+
*/
441+
@VisibleForTesting final AtomicInteger outstandingRpcs = new AtomicInteger(0);
442+
427443
private final AtomicInteger maxOutstanding = new AtomicInteger();
428444

429445
// Flag that the channel should be closed once all of the outstanding RPC complete.
@@ -470,7 +486,7 @@ private boolean retain() {
470486
private void release() {
471487
int newCount = outstandingRpcs.decrementAndGet();
472488
if (newCount < 0) {
473-
throw new IllegalStateException("Bug: reference count is negative!: " + newCount);
489+
LOG.log(Level.WARNING, "Bug! Reference count is negative (" + newCount + ")!");
474490
}
475491

476492
// Must check outstandingRpcs after shutdownRequested (in reverse order of retain()) to ensure
@@ -526,6 +542,8 @@ public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
526542
static class ReleasingClientCall<ReqT, RespT> extends SimpleForwardingClientCall<ReqT, RespT> {
527543
@Nullable private CancellationException cancellationException;
528544
final Entry entry;
545+
private final AtomicBoolean wasClosed = new AtomicBoolean();
546+
private final AtomicBoolean wasReleased = new AtomicBoolean();
529547

530548
public ReleasingClientCall(ClientCall<ReqT, RespT> delegate, Entry entry) {
531549
super(delegate);
@@ -542,17 +560,35 @@ public void start(Listener<RespT> responseListener, Metadata headers) {
542560
new SimpleForwardingClientCallListener<RespT>(responseListener) {
543561
@Override
544562
public void onClose(Status status, Metadata trailers) {
563+
if (!wasClosed.compareAndSet(false, true)) {
564+
LOG.log(
565+
Level.WARNING,
566+
"Call is being closed more than once. Please make sure that onClose() is not being manually called.");
567+
return;
568+
}
545569
try {
546570
super.onClose(status, trailers);
547571
} finally {
548-
entry.release();
572+
if (wasReleased.compareAndSet(false, true)) {
573+
entry.release();
574+
} else {
575+
LOG.log(
576+
Level.WARNING,
577+
"Entry was released before the call is closed. This may be due to an exception on start of the call.");
578+
}
549579
}
550580
}
551581
},
552582
headers);
553583
} catch (Exception e) {
554584
// In case start failed, make sure to release
555-
entry.release();
585+
if (wasReleased.compareAndSet(false, true)) {
586+
entry.release();
587+
} else {
588+
LOG.log(
589+
Level.WARNING,
590+
"The entry is already released. This indicates that onClose() has already been called previously");
591+
}
556592
throw e;
557593
}
558594
}

gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java

+83-10
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,20 @@
2929
*/
3030
package com.google.api.gax.grpc;
3131

32+
import static com.google.api.gax.grpc.testing.FakeServiceGrpc.METHOD_RECOGNIZE;
3233
import static com.google.api.gax.grpc.testing.FakeServiceGrpc.METHOD_SERVER_STREAMING_RECOGNIZE;
3334
import static com.google.common.truth.Truth.assertThat;
3435

36+
import com.google.api.core.ApiFuture;
3537
import com.google.api.gax.grpc.testing.FakeChannelFactory;
3638
import com.google.api.gax.grpc.testing.FakeMethodDescriptor;
37-
import com.google.api.gax.grpc.testing.FakeServiceGrpc;
3839
import com.google.api.gax.rpc.ClientContext;
3940
import com.google.api.gax.rpc.ResponseObserver;
4041
import com.google.api.gax.rpc.ServerStreamingCallSettings;
4142
import com.google.api.gax.rpc.ServerStreamingCallable;
4243
import com.google.api.gax.rpc.StreamController;
44+
import com.google.api.gax.rpc.UnaryCallSettings;
45+
import com.google.api.gax.rpc.UnaryCallable;
4346
import com.google.common.base.Preconditions;
4447
import com.google.common.collect.ImmutableList;
4548
import com.google.common.collect.Lists;
@@ -63,6 +66,9 @@
6366
import java.util.concurrent.ScheduledFuture;
6467
import java.util.concurrent.TimeUnit;
6568
import java.util.concurrent.atomic.AtomicInteger;
69+
import java.util.logging.Handler;
70+
import java.util.logging.LogRecord;
71+
import java.util.stream.Collectors;
6672
import org.junit.After;
6773
import org.junit.Assert;
6874
import org.junit.Test;
@@ -117,7 +123,7 @@ public void testRoundRobin() throws IOException {
117123

118124
private void verifyTargetChannel(
119125
ChannelPool pool, List<ManagedChannel> channels, ManagedChannel targetChannel) {
120-
MethodDescriptor<Color, Money> methodDescriptor = FakeServiceGrpc.METHOD_RECOGNIZE;
126+
MethodDescriptor<Color, Money> methodDescriptor = METHOD_RECOGNIZE;
121127
CallOptions callOptions = CallOptions.DEFAULT;
122128
@SuppressWarnings("unchecked")
123129
ClientCall<Color, Money> expectedClientCall = Mockito.mock(ClientCall.class);
@@ -143,7 +149,7 @@ public void ensureEvenDistribution() throws InterruptedException, IOException {
143149
final ManagedChannel[] channels = new ManagedChannel[numChannels];
144150
final AtomicInteger[] counts = new AtomicInteger[numChannels];
145151

146-
final MethodDescriptor<Color, Money> methodDescriptor = FakeServiceGrpc.METHOD_RECOGNIZE;
152+
final MethodDescriptor<Color, Money> methodDescriptor = METHOD_RECOGNIZE;
147153
final CallOptions callOptions = CallOptions.DEFAULT;
148154
@SuppressWarnings("unchecked")
149155
final ClientCall<Color, Money> clientCall = Mockito.mock(ClientCall.class);
@@ -472,23 +478,21 @@ public void channelCountShouldNotChangeWhenOutstandingRpcsAreWithinLimits() thro
472478
// Start the minimum number of
473479
for (int i = 0; i < 2; i++) {
474480
ClientCalls.futureUnaryCall(
475-
pool.newCall(FakeServiceGrpc.METHOD_RECOGNIZE, CallOptions.DEFAULT),
476-
Color.getDefaultInstance());
481+
pool.newCall(METHOD_RECOGNIZE, CallOptions.DEFAULT), Color.getDefaultInstance());
477482
}
478483
pool.resize();
479484
assertThat(pool.entries.get()).hasSize(2);
480485

481486
// Add enough RPCs to be just at the brink of expansion
482487
for (int i = startedCalls.size(); i < 4; i++) {
483488
ClientCalls.futureUnaryCall(
484-
pool.newCall(FakeServiceGrpc.METHOD_RECOGNIZE, CallOptions.DEFAULT),
485-
Color.getDefaultInstance());
489+
pool.newCall(METHOD_RECOGNIZE, CallOptions.DEFAULT), Color.getDefaultInstance());
486490
}
487491
pool.resize();
488492
assertThat(pool.entries.get()).hasSize(2);
489493

490494
// Add another RPC to push expansion
491-
pool.newCall(FakeServiceGrpc.METHOD_RECOGNIZE, CallOptions.DEFAULT);
495+
pool.newCall(METHOD_RECOGNIZE, CallOptions.DEFAULT);
492496
pool.resize();
493497
assertThat(pool.entries.get()).hasSize(4); // += ChannelPool::MAX_RESIZE_DELTA
494498
assertThat(startedCalls).hasSize(5);
@@ -593,8 +597,7 @@ public void removedActiveChannelsAreShutdown() throws Exception {
593597
// Start 2 RPCs
594598
for (int i = 0; i < 2; i++) {
595599
ClientCalls.futureUnaryCall(
596-
pool.newCall(FakeServiceGrpc.METHOD_RECOGNIZE, CallOptions.DEFAULT),
597-
Color.getDefaultInstance());
600+
pool.newCall(METHOD_RECOGNIZE, CallOptions.DEFAULT), Color.getDefaultInstance());
598601
}
599602
// Complete the first one
600603
@SuppressWarnings("unchecked")
@@ -663,4 +666,74 @@ public void onComplete() {}
663666
assertThat(e.getCause()).isInstanceOf(CancellationException.class);
664667
assertThat(e.getMessage()).isEqualTo("Call is already cancelled");
665668
}
669+
670+
@Test
671+
public void testDoubleRelease() throws Exception {
672+
FakeLogHandler logHandler = new FakeLogHandler();
673+
ChannelPool.LOG.addHandler(logHandler);
674+
675+
try {
676+
// Create a fake channel pool thats backed by mock channels that simply record invocations
677+
ClientCall mockClientCall = Mockito.mock(ClientCall.class);
678+
ManagedChannel fakeChannel = Mockito.mock(ManagedChannel.class);
679+
Mockito.when(fakeChannel.newCall(Mockito.any(), Mockito.any())).thenReturn(mockClientCall);
680+
ChannelPoolSettings channelPoolSettings = ChannelPoolSettings.staticallySized(1);
681+
ChannelFactory factory = new FakeChannelFactory(ImmutableList.of(fakeChannel));
682+
683+
pool = ChannelPool.create(channelPoolSettings, factory);
684+
685+
// Construct a fake callable to use the channel pool
686+
ClientContext context =
687+
ClientContext.newBuilder()
688+
.setTransportChannel(GrpcTransportChannel.create(pool))
689+
.setDefaultCallContext(GrpcCallContext.of(pool, CallOptions.DEFAULT))
690+
.build();
691+
692+
UnaryCallSettings<Color, Money> settings =
693+
UnaryCallSettings.<Color, Money>newUnaryCallSettingsBuilder().build();
694+
UnaryCallable<Color, Money> callable =
695+
GrpcCallableFactory.createUnaryCallable(
696+
GrpcCallSettings.create(METHOD_RECOGNIZE), settings, context);
697+
698+
// Start the RPC
699+
ApiFuture<Money> rpcFuture =
700+
callable.futureCall(Color.getDefaultInstance(), context.getDefaultCallContext());
701+
702+
// Get the server side listener and intentionally close it twice
703+
ArgumentCaptor<ClientCall.Listener<?>> clientCallListenerCaptor =
704+
ArgumentCaptor.forClass(ClientCall.Listener.class);
705+
Mockito.verify(mockClientCall).start(clientCallListenerCaptor.capture(), Mockito.any());
706+
clientCallListenerCaptor.getValue().onClose(Status.INTERNAL, new Metadata());
707+
clientCallListenerCaptor.getValue().onClose(Status.UNKNOWN, new Metadata());
708+
709+
// Ensure that the channel pool properly logged the double call and kept the refCount correct
710+
assertThat(logHandler.getAllMessages())
711+
.contains(
712+
"Call is being closed more than once. Please make sure that onClose() is not being manually called.");
713+
assertThat(pool.entries.get()).hasSize(1);
714+
ChannelPool.Entry entry = pool.entries.get().get(0);
715+
assertThat(entry.outstandingRpcs.get()).isEqualTo(0);
716+
} finally {
717+
ChannelPool.LOG.removeHandler(logHandler);
718+
}
719+
}
720+
721+
private static class FakeLogHandler extends Handler {
722+
List<LogRecord> records = new ArrayList<>();
723+
724+
@Override
725+
public void publish(LogRecord record) {
726+
records.add(record);
727+
}
728+
729+
@Override
730+
public void flush() {}
731+
732+
@Override
733+
public void close() throws SecurityException {}
734+
735+
List<String> getAllMessages() {
736+
return records.stream().map(LogRecord::getMessage).collect(Collectors.toList());
737+
}
738+
}
666739
}

0 commit comments

Comments
 (0)