Skip to content

Commit 7c081eb

Browse files
Googlercopybara-github
Googler
authored andcommitted
Remote: gRPC load balancing. (Part 4)
Implement DynamicConnectionPool which is built on top of SharedConnectionFactory. It creates connections on demands, applies rate limiting on the underying connection and uses Round-Robin algorithm to load balancing across multiple connections. PiperOrigin-RevId: 358116905
1 parent 6667ad7 commit 7c081eb

File tree

2 files changed

+327
-0
lines changed

2 files changed

+327
-0
lines changed
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
// Copyright 2021 The Bazel Authors. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
package com.google.devtools.build.lib.remote.grpc;
15+
16+
import com.google.devtools.build.lib.remote.grpc.SharedConnectionFactory.SharedConnection;
17+
import io.reactivex.rxjava3.core.Single;
18+
import java.io.IOException;
19+
import java.util.ArrayList;
20+
import java.util.concurrent.atomic.AtomicBoolean;
21+
import javax.annotation.concurrent.GuardedBy;
22+
23+
/**
24+
* A {@link ConnectionPool} that creates new connection with given {@link ConnectionFactory} on
25+
* demand and applies rate limiting w.r.t {@code maxConcurrencyPerConnection} for one underlying
26+
* connection. It also uses Round-Robin algorithm to load balancing between underlying connections.
27+
*
28+
* <p>Connections must be closed with {@link Connection#close()} in order to be reused later.
29+
*/
30+
public class DynamicConnectionPool implements ConnectionPool {
31+
private final ConnectionFactory connectionFactory;
32+
private final int maxConcurrencyPerConnection;
33+
private final AtomicBoolean closed = new AtomicBoolean(false);
34+
35+
@GuardedBy("this")
36+
private final ArrayList<SharedConnectionFactory> factories;
37+
38+
@GuardedBy("this")
39+
private int indexTicker = 0;
40+
41+
public DynamicConnectionPool(
42+
ConnectionFactory connectionFactory, int maxConcurrencyPerConnection) {
43+
this.connectionFactory = connectionFactory;
44+
this.maxConcurrencyPerConnection = maxConcurrencyPerConnection;
45+
this.factories = new ArrayList<>();
46+
}
47+
48+
@Override
49+
public void close() throws IOException {
50+
if (closed.compareAndSet(false, true)) {
51+
synchronized (this) {
52+
for (SharedConnectionFactory factory : factories) {
53+
factory.close();
54+
}
55+
factories.clear();
56+
}
57+
}
58+
}
59+
60+
/**
61+
* Performs a simple round robin on the list of {@link SharedConnectionFactory} and return one
62+
* having available connections at this moment.
63+
*
64+
* <p>If no factory has available connections, it will create a new {@link
65+
* SharedConnectionFactory}.
66+
*/
67+
private SharedConnectionFactory nextAvailableFactory() {
68+
if (closed.get()) {
69+
throw new IllegalStateException("closed");
70+
}
71+
72+
synchronized (this) {
73+
for (int times = 0; times < factories.size(); ++times) {
74+
int index = Math.abs(indexTicker % factories.size());
75+
indexTicker += 1;
76+
77+
SharedConnectionFactory factory = factories.get(index);
78+
if (factory.numAvailableConnections() > 0) {
79+
return factory;
80+
}
81+
}
82+
83+
SharedConnectionFactory factory =
84+
new SharedConnectionFactory(connectionFactory, maxConcurrencyPerConnection);
85+
factories.add(factory);
86+
return factory;
87+
}
88+
}
89+
90+
@Override
91+
public Single<SharedConnection> create() {
92+
return Single.defer(
93+
() -> {
94+
SharedConnectionFactory factory = nextAvailableFactory();
95+
return factory.create();
96+
});
97+
}
98+
}
Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
// Copyright 2021 The Bazel Authors. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
package com.google.devtools.build.lib.remote.grpc;
15+
16+
import static com.google.common.truth.Truth.assertThat;
17+
import static org.mockito.Mockito.mock;
18+
import static org.mockito.Mockito.times;
19+
import static org.mockito.Mockito.verify;
20+
import static org.mockito.Mockito.when;
21+
22+
import com.google.devtools.build.lib.remote.grpc.SharedConnectionFactory.SharedConnection;
23+
import io.reactivex.rxjava3.core.Single;
24+
import io.reactivex.rxjava3.observers.TestObserver;
25+
import io.reactivex.rxjava3.plugins.RxJavaPlugins;
26+
import java.io.IOException;
27+
import java.util.concurrent.Semaphore;
28+
import java.util.concurrent.atomic.AtomicBoolean;
29+
import java.util.concurrent.atomic.AtomicInteger;
30+
import java.util.concurrent.atomic.AtomicReference;
31+
import org.junit.After;
32+
import org.junit.Before;
33+
import org.junit.Rule;
34+
import org.junit.Test;
35+
import org.junit.runner.RunWith;
36+
import org.junit.runners.JUnit4;
37+
import org.mockito.Mock;
38+
import org.mockito.junit.MockitoJUnit;
39+
import org.mockito.junit.MockitoRule;
40+
41+
/** Tests for {@link DynamicConnectionPool}. */
42+
@RunWith(JUnit4.class)
43+
public class DynamicConnectionPoolTest {
44+
@Rule public final MockitoRule mockito = MockitoJUnit.rule();
45+
private final AtomicReference<Throwable> rxGlobalThrowable = new AtomicReference<>(null);
46+
47+
@Mock private Connection connection0;
48+
@Mock private Connection connection1;
49+
@Mock private ConnectionFactory connectionFactory;
50+
private final AtomicInteger connectionFactoryCreateTimes = new AtomicInteger(0);
51+
52+
@Before
53+
public void setUp() {
54+
RxJavaPlugins.setErrorHandler(rxGlobalThrowable::set);
55+
56+
when(connectionFactory.create())
57+
.thenAnswer(
58+
invocation -> {
59+
int times = connectionFactoryCreateTimes.getAndIncrement();
60+
if (times == 0) {
61+
return Single.just(connection0);
62+
} else {
63+
return Single.just(connection1);
64+
}
65+
});
66+
}
67+
68+
@After
69+
public void tearDown() throws Throwable {
70+
// Make sure rxjava didn't receive global errors
71+
Throwable t = rxGlobalThrowable.getAndSet(null);
72+
if (t != null) {
73+
throw t;
74+
}
75+
}
76+
77+
@Test
78+
public void create_smoke() {
79+
DynamicConnectionPool pool = new DynamicConnectionPool(connectionFactory, 1);
80+
81+
TestObserver<SharedConnection> observer = pool.create().test();
82+
83+
observer.assertValue(conn -> conn.getUnderlyingConnection() == connection0).assertComplete();
84+
assertThat(connectionFactoryCreateTimes.get()).isEqualTo(1);
85+
}
86+
87+
@Test
88+
public void create_exceedingMaxConcurrent_createNewConnection() {
89+
DynamicConnectionPool pool = new DynamicConnectionPool(connectionFactory, 1);
90+
91+
TestObserver<SharedConnection> observer0 = pool.create().test();
92+
TestObserver<SharedConnection> observer1 = pool.create().test();
93+
94+
observer0.assertValue(conn -> conn.getUnderlyingConnection() == connection0).assertComplete();
95+
observer1.assertValue(conn -> conn.getUnderlyingConnection() == connection1).assertComplete();
96+
assertThat(connectionFactoryCreateTimes.get()).isEqualTo(2);
97+
}
98+
99+
@Test
100+
public void create_pendingConnectionCreationAndExceedingMaxConcurrent_createNewConnection() {
101+
AtomicBoolean terminated = new AtomicBoolean(false);
102+
ConnectionFactory connectionFactory = mock(ConnectionFactory.class);
103+
when(connectionFactory.create())
104+
.thenAnswer(
105+
invocation -> {
106+
if (connectionFactoryCreateTimes.getAndIncrement() == 0) {
107+
return Single.create(
108+
emitter -> {
109+
Thread t =
110+
new Thread(
111+
() -> {
112+
try {
113+
Thread.sleep(Integer.MAX_VALUE);
114+
emitter.onSuccess(connection0);
115+
} catch (InterruptedException e) {
116+
emitter.onError(e);
117+
}
118+
terminated.set(true);
119+
});
120+
t.start();
121+
});
122+
} else {
123+
return Single.just(connection1);
124+
}
125+
});
126+
DynamicConnectionPool pool = new DynamicConnectionPool(connectionFactory, 1);
127+
128+
TestObserver<SharedConnection> observer0 = pool.create().test();
129+
TestObserver<SharedConnection> observer1 = pool.create().test();
130+
131+
assertThat(terminated.get()).isFalse();
132+
observer0.assertEmpty();
133+
observer1.assertValue(conn -> conn.getUnderlyingConnection() == connection1).assertComplete();
134+
assertThat(connectionFactoryCreateTimes.get()).isEqualTo(2);
135+
}
136+
137+
@Test
138+
public void create_belowMaxConcurrency_shareConnections() {
139+
DynamicConnectionPool pool = new DynamicConnectionPool(connectionFactory, 2);
140+
141+
TestObserver<SharedConnection> observer0 = pool.create().test();
142+
TestObserver<SharedConnection> observer1 = pool.create().test();
143+
144+
observer0.assertValue(conn -> conn.getUnderlyingConnection() == connection0).assertComplete();
145+
observer1.assertValue(conn -> conn.getUnderlyingConnection() == connection0).assertComplete();
146+
assertThat(connectionFactoryCreateTimes.get()).isEqualTo(1);
147+
}
148+
149+
@Test
150+
public void create_afterConnectionClosed_shareConnections() throws IOException {
151+
DynamicConnectionPool pool = new DynamicConnectionPool(connectionFactory, 1);
152+
TestObserver<SharedConnection> observer0 = pool.create().test();
153+
observer0.assertValue(conn -> conn.getUnderlyingConnection() == connection0).assertComplete();
154+
observer0.values().get(0).close();
155+
156+
TestObserver<SharedConnection> observer1 = pool.create().test();
157+
158+
observer1.assertValue(conn -> conn.getUnderlyingConnection() == connection0).assertComplete();
159+
assertThat(connectionFactoryCreateTimes.get()).isEqualTo(1);
160+
}
161+
162+
@Test
163+
public void closePool_noNewConnectionAllowed() throws IOException {
164+
DynamicConnectionPool pool = new DynamicConnectionPool(connectionFactory, 1);
165+
pool.close();
166+
167+
TestObserver<SharedConnection> observer = pool.create().test();
168+
169+
observer
170+
.assertError(IllegalStateException.class)
171+
.assertError(e -> e.getMessage().contains("closed"));
172+
}
173+
174+
@Test
175+
public void closePool_closeUnderlyingConnection() throws IOException {
176+
DynamicConnectionPool pool = new DynamicConnectionPool(connectionFactory, 1);
177+
TestObserver<SharedConnection> observer = pool.create().test();
178+
observer.assertComplete();
179+
180+
pool.close();
181+
182+
verify(connection0, times(1)).close();
183+
}
184+
185+
@Test
186+
public void closePool_pendingConnectionCreation_closedError()
187+
throws IOException, InterruptedException {
188+
AtomicBoolean canceled = new AtomicBoolean(false);
189+
AtomicBoolean finished = new AtomicBoolean(false);
190+
Semaphore terminated = new Semaphore(0);
191+
ConnectionFactory connectionFactory = mock(ConnectionFactory.class);
192+
when(connectionFactory.create())
193+
.thenAnswer(
194+
invocation ->
195+
Single.create(
196+
emitter -> {
197+
Thread t =
198+
new Thread(
199+
() -> {
200+
try {
201+
Thread.sleep(Integer.MAX_VALUE);
202+
finished.set(true);
203+
emitter.onSuccess(connection0);
204+
} catch (InterruptedException ignored) {
205+
/* no-op */
206+
}
207+
208+
terminated.release();
209+
});
210+
t.start();
211+
212+
emitter.setCancellable(t::interrupt);
213+
})
214+
.doOnDispose(() -> canceled.set(true)));
215+
DynamicConnectionPool pool = new DynamicConnectionPool(connectionFactory, 1);
216+
TestObserver<SharedConnection> observer = pool.create().test();
217+
observer.assertEmpty();
218+
219+
assertThat(canceled.get()).isFalse();
220+
pool.close();
221+
222+
terminated.acquire();
223+
observer
224+
.assertError(IllegalStateException.class)
225+
.assertError(e -> e.getMessage().contains("closed"));
226+
assertThat(canceled.get()).isTrue();
227+
assertThat(finished.get()).isFalse();
228+
}
229+
}

0 commit comments

Comments
 (0)