Skip to content

Commit fbdac74

Browse files
authored
gloo: async
Differential Revision: D69698406 Pull Request resolved: #418
1 parent 08c094b commit fbdac74

27 files changed

+259
-77
lines changed

gloo/test/allgather_test.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,11 @@ TEST_F(AllgatherNewTest, TestTimeout) {
167167
AllgatherOptions opts(context);
168168
opts.setInput(input.getPointer(), 1);
169169
opts.setOutput(output.getPointer(), context->size);
170+
171+
// Run one operation first so we're measuring the operation timeout not
172+
// connection timeout.
173+
allgather(opts);
174+
170175
opts.setTimeout(std::chrono::milliseconds(10));
171176
if (context->rank == 0) {
172177
try {

gloo/test/allgatherv_test.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,11 @@ TEST_F(AllgathervTest, TestTimeout) {
9595
std::vector<size_t> counts({1, 1});
9696
AllgathervOptions opts(context);
9797
opts.setOutput(output.getPointer(), counts);
98+
99+
// Run one operation first so we're measuring the operation timeout not
100+
// connection timeout.
101+
allgatherv(opts);
102+
98103
opts.setTimeout(std::chrono::milliseconds(10));
99104
if (context->rank == 0) {
100105
try {

gloo/test/allreduce_test.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,11 @@ TEST_F(AllreduceNewTest, TestTimeout) {
389389
AllreduceOptions opts(context);
390390
opts.setOutputs(outputs.getPointers(), 1);
391391
opts.setReduceFunction(getFunction<uint64_t>());
392+
393+
// Run one operation first so we're measuring the operation timeout not
394+
// connection timeout.
395+
allreduce(opts);
396+
392397
opts.setTimeout(std::chrono::milliseconds(10));
393398
if (context->rank == 0) {
394399
try {

gloo/test/barrier_test.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,18 @@ INSTANTIATE_TEST_CASE_P(
127127
TEST_F(BarrierNewTest, TestTimeout) {
128128
spawn(Transport::TCP, 2, [&](std::shared_ptr<Context> context) {
129129
BarrierOptions opts(context);
130+
131+
// Run barrier first so we're measuring the barrier timeout not connection
132+
// timeout.
133+
barrier(opts);
134+
130135
opts.setTimeout(std::chrono::milliseconds(10));
131136
if (context->rank == 0) {
132137
try {
133138
barrier(opts);
134139
FAIL() << "Expected exception to be thrown";
135140
} catch (::gloo::IoException& e) {
141+
std::cerr << e.what() << std::endl;
136142
ASSERT_NE(std::string(e.what()).find("Timed out"), std::string::npos);
137143
}
138144
}

gloo/test/base_test.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ const char* kDefaultDevice = "localhost";
1717
// Transports that instantiated algorithms can be tested against.
1818
const std::vector<Transport> kTransportsForClassAlgorithms = {
1919
Transport::TCP,
20+
Transport::TCP_LAZY,
2021
#if GLOO_HAVE_TRANSPORT_TCP_TLS
2122
Transport::TCP_TLS,
2223
#endif
@@ -27,6 +28,7 @@ const std::vector<Transport> kTransportsForClassAlgorithms = {
2728
// preferred over the instantiated style.
2829
const std::vector<Transport> kTransportsForFunctionAlgorithms = {
2930
Transport::TCP,
31+
Transport::TCP_LAZY,
3032
#if GLOO_HAVE_TRANSPORT_TCP_TLS
3133
Transport::TCP_TLS,
3234
#endif
@@ -37,6 +39,8 @@ std::shared_ptr<::gloo::transport::Device> createDevice(Transport transport) {
3739
#if GLOO_HAVE_TRANSPORT_TCP
3840
if (transport == Transport::TCP) {
3941
return ::gloo::transport::tcp::CreateDevice(kDefaultDevice);
42+
} else if (transport == Transport::TCP_LAZY) {
43+
return ::gloo::transport::tcp::CreateLazyDevice(kDefaultDevice);
4044
}
4145
#endif
4246
#if GLOO_HAVE_TRANSPORT_TCP_TLS

gloo/test/base_test.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class Barrier {
5959

6060
enum Transport {
6161
TCP,
62+
TCP_LAZY,
6263
#if GLOO_HAVE_TRANSPORT_TCP_TLS
6364
TCP_TLS,
6465
#endif

gloo/test/broadcast_test.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,11 @@ TEST_F(BroadcastTest, TestTimeout) {
182182
BroadcastOptions opts(context);
183183
opts.setOutput(output.getPointer(), 1);
184184
opts.setRoot(0);
185+
186+
// Run one operation first so we're measuring the operation timeout not
187+
// connection timeout.
188+
broadcast(opts);
189+
185190
opts.setTimeout(std::chrono::milliseconds(10));
186191
if (context->rank == 0) {
187192
try {

gloo/test/gather_test.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ TEST_F(GatherTest, TestTimeout) {
7676
opts.setInput(input.getPointer(), 1);
7777
opts.setOutput(output.getPointer(), context->size);
7878
opts.setRoot(0);
79+
80+
// Run one operation first so we're measuring the operation timeout not
81+
// connection timeout.
82+
gather(opts);
83+
7984
opts.setTimeout(std::chrono::milliseconds(10));
8085
if (context->rank == 0) {
8186
try {

gloo/test/gatherv_test.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ TEST_F(GathervTest, TestTimeout) {
106106
opts.setRoot(0);
107107
opts.setInput(input.getPointer(), 1);
108108
opts.setOutput(output.getPointer(), counts);
109+
110+
// Run one operation first so we're measuring the operation timeout not
111+
// connection timeout.
112+
gatherv(opts);
113+
109114
opts.setTimeout(std::chrono::milliseconds(10));
110115
if (context->rank == 0) {
111116
try {

gloo/test/multiproc_test.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,31 @@ class MultiProcWorker {
102102
auto device = createDevice(transport);
103103
context->setTimeout(std::chrono::milliseconds(kMultiProcTimeout));
104104
context->connectFullMesh(store_, device);
105+
106+
// Wait for all workers to be ready
107+
ringBarrier(context);
108+
105109
device.reset();
106110
sem_post(semaphore_);
107111
fn(std::move(context));
108112
}
109113

114+
void ringBarrier(std::shared_ptr<::gloo::rendezvous::Context>& context) {
115+
int sendScratch = 0;
116+
int recvScratch = 0;
117+
auto sendBuf =
118+
context->createUnboundBuffer(&sendScratch, sizeof(sendScratch));
119+
auto recvBuf =
120+
context->createUnboundBuffer(&recvScratch, sizeof(recvScratch));
121+
const auto leftRank = (context->size + context->rank - 1) % context->size;
122+
const auto rightRank = (context->rank + 1) % context->size;
123+
124+
sendBuf->send(leftRank, 0);
125+
recvBuf->recv(rightRank, 0);
126+
sendBuf->waitSend();
127+
recvBuf->waitRecv();
128+
}
129+
110130
protected:
111131
std::shared_ptr<::gloo::rendezvous::Store> store_;
112132
sem_t* semaphore_;

gloo/test/reduce_test.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,11 @@ TEST_F(ReduceTest, TestTimeout) {
101101
opts.setOutput(outputs.getPointer(), 1);
102102
opts.setRoot(0);
103103
opts.setReduceFunction(getFunction<uint64_t>());
104+
105+
// Run one operation first so we're measuring the operation timeout not
106+
// connection timeout.
107+
reduce(opts);
108+
104109
opts.setTimeout(std::chrono::milliseconds(10));
105110
if (context->rank == 0) {
106111
try {

gloo/test/scatter_test.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ TEST_F(ScatterTest, TestTimeout) {
7272
opts.setInputs(input.getPointers(), 1);
7373
opts.setOutput(output.getPointer(), 1);
7474
opts.setRoot(0);
75+
76+
// Run one operation first so we're measuring the operation timeout not
77+
// connection timeout.
78+
scatter(opts);
79+
7580
opts.setTimeout(std::chrono::milliseconds(10));
7681
if (context->rank == 0) {
7782
try {

gloo/test/tcp_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace transport {
88
namespace tcp {
99

1010
TEST(TcpTest, ConnectTimeout) {
11-
auto loop = std::make_shared<Loop>();
11+
Loop loop;
1212

1313
std::mutex m;
1414
std::condition_variable cv;
@@ -25,7 +25,7 @@ TEST(TcpTest, ConnectTimeout) {
2525
EXPECT_TRUE(e);
2626
EXPECT_TRUE(dynamic_cast<const TimeoutError*>(&e));
2727
};
28-
connectLoop(*loop, remote, 0, 5, timeout, std::move(fn));
28+
connectLoop(loop, remote, 0, 5, timeout, std::move(fn));
2929

3030
std::unique_lock<std::mutex> lock(m);
3131
cv.wait(lock, [&] { return done; });

gloo/transport/tcp/context.cc

Lines changed: 97 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,14 @@ namespace tcp {
2727
constexpr int kDefaultBatchSize = 128;
2828

2929
Context::Context(std::shared_ptr<Device> device, int rank, int size)
30-
: ::gloo::transport::Context(rank, size), device_(std::move(device)) {}
30+
: ::gloo::transport::Context(rank, size), device_(std::move(device)) {
31+
connecting_.resize(size);
32+
}
3133

3234
Context::~Context() {
35+
// We need to shutdown the loop thread prior to freeing the pairs.
36+
device_->shutdown();
37+
3338
// Pairs refer to device by raw pointer.
3439
// Ensure they are destructed before the device.
3540
pairs_.clear();
@@ -50,8 +55,6 @@ void Context::createAndConnectAllPairs(std::shared_ptr<IStore> store) {
5055
// it's not super straightforward so left for folks having more bandwidth
5156
// later on.
5257

53-
int localRank = 0;
54-
bool localRankSet = false;
5558
auto localHostName = getHostname();
5659
bool useRankAsSeqNum = useRankAsSeqNumber();
5760

@@ -61,7 +64,7 @@ void Context::createAndConnectAllPairs(std::shared_ptr<IStore> store) {
6164
std::vector<ssize_t> pairIdentifiers;
6265
for (int i = 0; i < size; i++) {
6366
const auto& pair = createPair(i, useRankAsSeqNum);
64-
if (!useRankAsSeqNum) {
67+
if (!useRankAsSeqNum && !device_->isLazyInit()) {
6568
// Need to preserve the order of the pair identifiers if we are not using
6669
// the rank as seq number
6770
pairIdentifiers.emplace_back(
@@ -88,67 +91,103 @@ void Context::createAndConnectAllPairs(std::shared_ptr<IStore> store) {
8891
localHostName, deviceAddress.bytes(), std::move(pairIdentifiers));
8992
store->set(std::to_string(rank), currentRankInfo.bytes());
9093

91-
std::vector<std::vector<char>> remoteRankInfos;
92-
int key = 0;
93-
if (isStoreExtendedApiEnabled() && store->has_v2_support()) {
94-
auto sizeRemaining = size;
95-
while (sizeRemaining > 0) {
96-
const auto batchKeys = std::min(kDefaultBatchSize, sizeRemaining);
97-
std::vector<std::string> keys(batchKeys);
98-
std::generate_n(
99-
keys.begin(), batchKeys, [&] { return std::to_string(key++); });
100-
const auto& batchRemoteInfos = store->multi_get(keys);
101-
remoteRankInfos.insert(
102-
remoteRankInfos.end(),
103-
batchRemoteInfos.begin(),
104-
batchRemoteInfos.end());
105-
sizeRemaining -= batchKeys;
94+
store_ = store;
95+
96+
if (!device_->isLazyInit()) {
97+
int localRank = 0;
98+
bool localRankSet = false;
99+
std::vector<std::vector<char>> remoteRankInfos;
100+
int key = 0;
101+
if (isStoreExtendedApiEnabled() && store->has_v2_support()) {
102+
auto sizeRemaining = size;
103+
while (sizeRemaining > 0) {
104+
const auto batchKeys = std::min(kDefaultBatchSize, sizeRemaining);
105+
std::vector<std::string> keys(batchKeys);
106+
std::generate_n(
107+
keys.begin(), batchKeys, [&] { return std::to_string(key++); });
108+
const auto& batchRemoteInfos = store->multi_get(keys);
109+
remoteRankInfos.insert(
110+
remoteRankInfos.end(),
111+
batchRemoteInfos.begin(),
112+
batchRemoteInfos.end());
113+
sizeRemaining -= batchKeys;
114+
}
115+
} else {
116+
std::generate_n(std::back_inserter(remoteRankInfos), size, [&] {
117+
const auto& keyStr = std::to_string(key++);
118+
return store->wait_get(keyStr, getTimeout());
119+
});
106120
}
107-
} else {
108-
std::generate_n(std::back_inserter(remoteRankInfos), size, [&] {
109-
const auto& keyStr = std::to_string(key++);
110-
return store->wait_get(keyStr, getTimeout());
111-
});
112-
}
113121

114-
// Connect every pair
115-
for (int i = 0; i < size; i++) {
116-
if (i == rank) {
117-
// at this point we have enumerated all the ranks located on this host
118-
// up to the current rank, so the current `localRank` number is
119-
// what we'll set to the pairs.
120-
localRankSet = true;
121-
// We are not going to connect self.
122-
continue;
123-
}
122+
// Connect every pair
123+
for (int i = 0; i < size; i++) {
124+
if (i == rank) {
125+
// at this point we have enumerated all the ranks located on this host
126+
// up to the current rank, so the current `localRank` number is
127+
// what we'll set to the pairs.
128+
localRankSet = true;
129+
// We are not going to connect self.
130+
continue;
131+
}
124132

125-
Rank remoteRankInfo(remoteRankInfos[i]);
133+
Rank remoteRankInfo(remoteRankInfos[i]);
126134

127-
if (!localRankSet && remoteRankInfo.hostname == localHostName) {
128-
++localRank;
129-
}
135+
if (!localRankSet && remoteRankInfo.hostname == localHostName) {
136+
++localRank;
137+
}
130138

131-
const auto& pair = getPair(i);
132-
auto remoteDeviceAddr = Address(remoteRankInfo.addressBytes).getSockaddr();
133-
auto remoteAddr = Address(
134-
remoteDeviceAddr,
135-
useRankAsSeqNum ? (sequence_number_t)rank
136-
: remoteRankInfo.pairIdentifiers[rank]);
137-
pair->connect(remoteAddr.bytes());
138-
}
139+
const auto& pair = pairs_[i];
140+
auto remoteDeviceAddr =
141+
Address(remoteRankInfo.addressBytes).getSockaddr();
142+
auto remoteAddr = Address(
143+
remoteDeviceAddr,
144+
useRankAsSeqNum ? (sequence_number_t)rank
145+
: remoteRankInfo.pairIdentifiers[rank]);
146+
pair->connect(remoteAddr.bytes());
147+
connecting_[i] = true;
148+
}
139149

140-
// Set the local rank info for all mesh pairs involving current rank
141-
for (int i = 0; i < size; i++) {
142-
if (i == rank) {
143-
continue;
150+
// Set the local rank info for all mesh pairs involving current rank
151+
for (int i = 0; i < size; i++) {
152+
if (i == rank) {
153+
continue;
154+
}
155+
const auto& pair = getPair(i);
156+
pair->setLocalRank(localRank);
144157
}
145-
const auto& pair = getPair(i);
146-
pair->setLocalRank(localRank);
147158
}
148159

149160
printConnectivityInfo();
150161
}
151162

163+
std::unique_ptr<transport::Pair>& Context::getPair(int rank) {
164+
auto& pair = pairs_[rank];
165+
166+
if (!store_) {
167+
// Manual context creation without store to bootstrap.
168+
return pair;
169+
}
170+
171+
// don't connect to self
172+
if (rank == this->rank) {
173+
return pair;
174+
}
175+
176+
if (!connecting_[rank]) {
177+
connecting_[rank] = true;
178+
179+
const auto& keyStr = std::to_string(rank);
180+
auto remoteRankInfoBytes = store_->wait_get(keyStr, getTimeout());
181+
182+
Rank remoteRankInfo(remoteRankInfoBytes);
183+
184+
auto remoteDeviceAddr = Address(remoteRankInfo.addressBytes).getSockaddr();
185+
auto remoteAddr = Address(remoteDeviceAddr, this->rank);
186+
pair->connect(remoteAddr.bytes());
187+
}
188+
return pair;
189+
}
190+
152191
std::unique_ptr<transport::Pair>& Context::createPair(int rank) {
153192
pairs_[rank] = std::unique_ptr<transport::Pair>(
154193
new tcp::Pair(this, device_.get(), rank, getTimeout(), false));
@@ -219,6 +258,11 @@ void Context::recvFromAny(
219258
size_t offset,
220259
size_t nbytes,
221260
std::vector<int> srcRanks) {
261+
// Ensure all connections are established.
262+
for (auto rank : srcRanks) {
263+
getPair(rank);
264+
}
265+
222266
for (;;) {
223267
// Find rank of pair we can attempt a recv from
224268
auto rank = recvFromAnyFindRank(buf, slot, offset, nbytes, srcRanks);

0 commit comments

Comments
 (0)