Skip to content

Commit ae9b62a

Browse files
authored
Revert "gloo: async" (#425)
This reverts commit fbdac74. fbshipit-source-id: 4cfad1d7b70082f5ebbe90aa76e2ca88bb4565f5
1 parent fbdac74 commit ae9b62a

27 files changed

+77
-259
lines changed

gloo/test/allgather_test.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -167,11 +167,6 @@ TEST_F(AllgatherNewTest, TestTimeout) {
167167
AllgatherOptions opts(context);
168168
opts.setInput(input.getPointer(), 1);
169169
opts.setOutput(output.getPointer(), context->size);
170-
171-
// Run one operation first so we're measuring the operation timeout not
172-
// connection timeout.
173-
allgather(opts);
174-
175170
opts.setTimeout(std::chrono::milliseconds(10));
176171
if (context->rank == 0) {
177172
try {

gloo/test/allgatherv_test.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,6 @@ TEST_F(AllgathervTest, TestTimeout) {
9595
std::vector<size_t> counts({1, 1});
9696
AllgathervOptions opts(context);
9797
opts.setOutput(output.getPointer(), counts);
98-
99-
// Run one operation first so we're measuring the operation timeout not
100-
// connection timeout.
101-
allgatherv(opts);
102-
10398
opts.setTimeout(std::chrono::milliseconds(10));
10499
if (context->rank == 0) {
105100
try {

gloo/test/allreduce_test.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -389,11 +389,6 @@ TEST_F(AllreduceNewTest, TestTimeout) {
389389
AllreduceOptions opts(context);
390390
opts.setOutputs(outputs.getPointers(), 1);
391391
opts.setReduceFunction(getFunction<uint64_t>());
392-
393-
// Run one operation first so we're measuring the operation timeout not
394-
// connection timeout.
395-
allreduce(opts);
396-
397392
opts.setTimeout(std::chrono::milliseconds(10));
398393
if (context->rank == 0) {
399394
try {

gloo/test/barrier_test.cc

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,18 +127,12 @@ INSTANTIATE_TEST_CASE_P(
127127
TEST_F(BarrierNewTest, TestTimeout) {
128128
spawn(Transport::TCP, 2, [&](std::shared_ptr<Context> context) {
129129
BarrierOptions opts(context);
130-
131-
// Run barrier first so we're measuring the barrier timeout not connection
132-
// timeout.
133-
barrier(opts);
134-
135130
opts.setTimeout(std::chrono::milliseconds(10));
136131
if (context->rank == 0) {
137132
try {
138133
barrier(opts);
139134
FAIL() << "Expected exception to be thrown";
140135
} catch (::gloo::IoException& e) {
141-
std::cerr << e.what() << std::endl;
142136
ASSERT_NE(std::string(e.what()).find("Timed out"), std::string::npos);
143137
}
144138
}

gloo/test/base_test.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ const char* kDefaultDevice = "localhost";
1717
// Transports that instantiated algorithms can be tested against.
1818
const std::vector<Transport> kTransportsForClassAlgorithms = {
1919
Transport::TCP,
20-
Transport::TCP_LAZY,
2120
#if GLOO_HAVE_TRANSPORT_TCP_TLS
2221
Transport::TCP_TLS,
2322
#endif
@@ -28,7 +27,6 @@ const std::vector<Transport> kTransportsForClassAlgorithms = {
2827
// preferred over the instantiated style.
2928
const std::vector<Transport> kTransportsForFunctionAlgorithms = {
3029
Transport::TCP,
31-
Transport::TCP_LAZY,
3230
#if GLOO_HAVE_TRANSPORT_TCP_TLS
3331
Transport::TCP_TLS,
3432
#endif
@@ -39,8 +37,6 @@ std::shared_ptr<::gloo::transport::Device> createDevice(Transport transport) {
3937
#if GLOO_HAVE_TRANSPORT_TCP
4038
if (transport == Transport::TCP) {
4139
return ::gloo::transport::tcp::CreateDevice(kDefaultDevice);
42-
} else if (transport == Transport::TCP_LAZY) {
43-
return ::gloo::transport::tcp::CreateLazyDevice(kDefaultDevice);
4440
}
4541
#endif
4642
#if GLOO_HAVE_TRANSPORT_TCP_TLS

gloo/test/base_test.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ class Barrier {
5959

6060
enum Transport {
6161
TCP,
62-
TCP_LAZY,
6362
#if GLOO_HAVE_TRANSPORT_TCP_TLS
6463
TCP_TLS,
6564
#endif

gloo/test/broadcast_test.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,6 @@ TEST_F(BroadcastTest, TestTimeout) {
182182
BroadcastOptions opts(context);
183183
opts.setOutput(output.getPointer(), 1);
184184
opts.setRoot(0);
185-
186-
// Run one operation first so we're measuring the operation timeout not
187-
// connection timeout.
188-
broadcast(opts);
189-
190185
opts.setTimeout(std::chrono::milliseconds(10));
191186
if (context->rank == 0) {
192187
try {

gloo/test/gather_test.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,6 @@ TEST_F(GatherTest, TestTimeout) {
7676
opts.setInput(input.getPointer(), 1);
7777
opts.setOutput(output.getPointer(), context->size);
7878
opts.setRoot(0);
79-
80-
// Run one operation first so we're measuring the operation timeout not
81-
// connection timeout.
82-
gather(opts);
83-
8479
opts.setTimeout(std::chrono::milliseconds(10));
8580
if (context->rank == 0) {
8681
try {

gloo/test/gatherv_test.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,6 @@ TEST_F(GathervTest, TestTimeout) {
106106
opts.setRoot(0);
107107
opts.setInput(input.getPointer(), 1);
108108
opts.setOutput(output.getPointer(), counts);
109-
110-
// Run one operation first so we're measuring the operation timeout not
111-
// connection timeout.
112-
gatherv(opts);
113-
114109
opts.setTimeout(std::chrono::milliseconds(10));
115110
if (context->rank == 0) {
116111
try {

gloo/test/multiproc_test.h

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -102,31 +102,11 @@ class MultiProcWorker {
102102
auto device = createDevice(transport);
103103
context->setTimeout(std::chrono::milliseconds(kMultiProcTimeout));
104104
context->connectFullMesh(store_, device);
105-
106-
// Wait for all workers to be ready
107-
ringBarrier(context);
108-
109105
device.reset();
110106
sem_post(semaphore_);
111107
fn(std::move(context));
112108
}
113109

114-
void ringBarrier(std::shared_ptr<::gloo::rendezvous::Context>& context) {
115-
int sendScratch = 0;
116-
int recvScratch = 0;
117-
auto sendBuf =
118-
context->createUnboundBuffer(&sendScratch, sizeof(sendScratch));
119-
auto recvBuf =
120-
context->createUnboundBuffer(&recvScratch, sizeof(recvScratch));
121-
const auto leftRank = (context->size + context->rank - 1) % context->size;
122-
const auto rightRank = (context->rank + 1) % context->size;
123-
124-
sendBuf->send(leftRank, 0);
125-
recvBuf->recv(rightRank, 0);
126-
sendBuf->waitSend();
127-
recvBuf->waitRecv();
128-
}
129-
130110
protected:
131111
std::shared_ptr<::gloo::rendezvous::Store> store_;
132112
sem_t* semaphore_;

gloo/test/reduce_test.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,6 @@ TEST_F(ReduceTest, TestTimeout) {
101101
opts.setOutput(outputs.getPointer(), 1);
102102
opts.setRoot(0);
103103
opts.setReduceFunction(getFunction<uint64_t>());
104-
105-
// Run one operation first so we're measuring the operation timeout not
106-
// connection timeout.
107-
reduce(opts);
108-
109104
opts.setTimeout(std::chrono::milliseconds(10));
110105
if (context->rank == 0) {
111106
try {

gloo/test/scatter_test.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,6 @@ TEST_F(ScatterTest, TestTimeout) {
7272
opts.setInputs(input.getPointers(), 1);
7373
opts.setOutput(output.getPointer(), 1);
7474
opts.setRoot(0);
75-
76-
// Run one operation first so we're measuring the operation timeout not
77-
// connection timeout.
78-
scatter(opts);
79-
8075
opts.setTimeout(std::chrono::milliseconds(10));
8176
if (context->rank == 0) {
8277
try {

gloo/test/tcp_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace transport {
88
namespace tcp {
99

1010
TEST(TcpTest, ConnectTimeout) {
11-
Loop loop;
11+
auto loop = std::make_shared<Loop>();
1212

1313
std::mutex m;
1414
std::condition_variable cv;
@@ -25,7 +25,7 @@ TEST(TcpTest, ConnectTimeout) {
2525
EXPECT_TRUE(e);
2626
EXPECT_TRUE(dynamic_cast<const TimeoutError*>(&e));
2727
};
28-
connectLoop(loop, remote, 0, 5, timeout, std::move(fn));
28+
connectLoop(*loop, remote, 0, 5, timeout, std::move(fn));
2929

3030
std::unique_lock<std::mutex> lock(m);
3131
cv.wait(lock, [&] { return done; });

gloo/transport/tcp/context.cc

Lines changed: 53 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,9 @@ namespace tcp {
2727
constexpr int kDefaultBatchSize = 128;
2828

2929
Context::Context(std::shared_ptr<Device> device, int rank, int size)
30-
: ::gloo::transport::Context(rank, size), device_(std::move(device)) {
31-
connecting_.resize(size);
32-
}
30+
: ::gloo::transport::Context(rank, size), device_(std::move(device)) {}
3331

3432
Context::~Context() {
35-
// We need to shutdown the loop thread prior to freeing the pairs.
36-
device_->shutdown();
37-
3833
// Pairs refer to device by raw pointer.
3934
// Ensure they are destructed before the device.
4035
pairs_.clear();
@@ -55,6 +50,8 @@ void Context::createAndConnectAllPairs(std::shared_ptr<IStore> store) {
5550
// it's not super straightforward so left for folks having more bandwidth
5651
// later on.
5752

53+
int localRank = 0;
54+
bool localRankSet = false;
5855
auto localHostName = getHostname();
5956
bool useRankAsSeqNum = useRankAsSeqNumber();
6057

@@ -64,7 +61,7 @@ void Context::createAndConnectAllPairs(std::shared_ptr<IStore> store) {
6461
std::vector<ssize_t> pairIdentifiers;
6562
for (int i = 0; i < size; i++) {
6663
const auto& pair = createPair(i, useRankAsSeqNum);
67-
if (!useRankAsSeqNum && !device_->isLazyInit()) {
64+
if (!useRankAsSeqNum) {
6865
// Need to preserve the order of the pair identifiers if we are not using
6966
// the rank as seq number
7067
pairIdentifiers.emplace_back(
@@ -91,101 +88,65 @@ void Context::createAndConnectAllPairs(std::shared_ptr<IStore> store) {
9188
localHostName, deviceAddress.bytes(), std::move(pairIdentifiers));
9289
store->set(std::to_string(rank), currentRankInfo.bytes());
9390

94-
store_ = store;
95-
96-
if (!device_->isLazyInit()) {
97-
int localRank = 0;
98-
bool localRankSet = false;
99-
std::vector<std::vector<char>> remoteRankInfos;
100-
int key = 0;
101-
if (isStoreExtendedApiEnabled() && store->has_v2_support()) {
102-
auto sizeRemaining = size;
103-
while (sizeRemaining > 0) {
104-
const auto batchKeys = std::min(kDefaultBatchSize, sizeRemaining);
105-
std::vector<std::string> keys(batchKeys);
106-
std::generate_n(
107-
keys.begin(), batchKeys, [&] { return std::to_string(key++); });
108-
const auto& batchRemoteInfos = store->multi_get(keys);
109-
remoteRankInfos.insert(
110-
remoteRankInfos.end(),
111-
batchRemoteInfos.begin(),
112-
batchRemoteInfos.end());
113-
sizeRemaining -= batchKeys;
114-
}
115-
} else {
116-
std::generate_n(std::back_inserter(remoteRankInfos), size, [&] {
117-
const auto& keyStr = std::to_string(key++);
118-
return store->wait_get(keyStr, getTimeout());
119-
});
120-
}
121-
122-
// Connect every pair
123-
for (int i = 0; i < size; i++) {
124-
if (i == rank) {
125-
// at this point we have enumerated all the ranks located on this host
126-
// up to the current rank, so the current `localRank` number is
127-
// what we'll set to the pairs.
128-
localRankSet = true;
129-
// We are not going to connect self.
130-
continue;
131-
}
132-
133-
Rank remoteRankInfo(remoteRankInfos[i]);
134-
135-
if (!localRankSet && remoteRankInfo.hostname == localHostName) {
136-
++localRank;
137-
}
138-
139-
const auto& pair = pairs_[i];
140-
auto remoteDeviceAddr =
141-
Address(remoteRankInfo.addressBytes).getSockaddr();
142-
auto remoteAddr = Address(
143-
remoteDeviceAddr,
144-
useRankAsSeqNum ? (sequence_number_t)rank
145-
: remoteRankInfo.pairIdentifiers[rank]);
146-
pair->connect(remoteAddr.bytes());
147-
connecting_[i] = true;
91+
std::vector<std::vector<char>> remoteRankInfos;
92+
int key = 0;
93+
if (isStoreExtendedApiEnabled() && store->has_v2_support()) {
94+
auto sizeRemaining = size;
95+
while (sizeRemaining > 0) {
96+
const auto batchKeys = std::min(kDefaultBatchSize, sizeRemaining);
97+
std::vector<std::string> keys(batchKeys);
98+
std::generate_n(
99+
keys.begin(), batchKeys, [&] { return std::to_string(key++); });
100+
const auto& batchRemoteInfos = store->multi_get(keys);
101+
remoteRankInfos.insert(
102+
remoteRankInfos.end(),
103+
batchRemoteInfos.begin(),
104+
batchRemoteInfos.end());
105+
sizeRemaining -= batchKeys;
148106
}
107+
} else {
108+
std::generate_n(std::back_inserter(remoteRankInfos), size, [&] {
109+
const auto& keyStr = std::to_string(key++);
110+
return store->wait_get(keyStr, getTimeout());
111+
});
112+
}
149113

150-
// Set the local rank info for all mesh pairs involving current rank
151-
for (int i = 0; i < size; i++) {
152-
if (i == rank) {
153-
continue;
154-
}
155-
const auto& pair = getPair(i);
156-
pair->setLocalRank(localRank);
114+
// Connect every pair
115+
for (int i = 0; i < size; i++) {
116+
if (i == rank) {
117+
// at this point we have enumerated all the ranks located on this host
118+
// up to the current rank, so the current `localRank` number is
119+
// what we'll set to the pairs.
120+
localRankSet = true;
121+
// We are not going to connect self.
122+
continue;
157123
}
158-
}
159124

160-
printConnectivityInfo();
161-
}
125+
Rank remoteRankInfo(remoteRankInfos[i]);
162126

163-
std::unique_ptr<transport::Pair>& Context::getPair(int rank) {
164-
auto& pair = pairs_[rank];
127+
if (!localRankSet && remoteRankInfo.hostname == localHostName) {
128+
++localRank;
129+
}
165130

166-
if (!store_) {
167-
// Manual context creation without store to bootstrap.
168-
return pair;
131+
const auto& pair = getPair(i);
132+
auto remoteDeviceAddr = Address(remoteRankInfo.addressBytes).getSockaddr();
133+
auto remoteAddr = Address(
134+
remoteDeviceAddr,
135+
useRankAsSeqNum ? (sequence_number_t)rank
136+
: remoteRankInfo.pairIdentifiers[rank]);
137+
pair->connect(remoteAddr.bytes());
169138
}
170139

171-
// don't connect to self
172-
if (rank == this->rank) {
173-
return pair;
140+
// Set the local rank info for all mesh pairs involving current rank
141+
for (int i = 0; i < size; i++) {
142+
if (i == rank) {
143+
continue;
144+
}
145+
const auto& pair = getPair(i);
146+
pair->setLocalRank(localRank);
174147
}
175148

176-
if (!connecting_[rank]) {
177-
connecting_[rank] = true;
178-
179-
const auto& keyStr = std::to_string(rank);
180-
auto remoteRankInfoBytes = store_->wait_get(keyStr, getTimeout());
181-
182-
Rank remoteRankInfo(remoteRankInfoBytes);
183-
184-
auto remoteDeviceAddr = Address(remoteRankInfo.addressBytes).getSockaddr();
185-
auto remoteAddr = Address(remoteDeviceAddr, this->rank);
186-
pair->connect(remoteAddr.bytes());
187-
}
188-
return pair;
149+
printConnectivityInfo();
189150
}
190151

191152
std::unique_ptr<transport::Pair>& Context::createPair(int rank) {
@@ -258,11 +219,6 @@ void Context::recvFromAny(
258219
size_t offset,
259220
size_t nbytes,
260221
std::vector<int> srcRanks) {
261-
// Ensure all connections are established.
262-
for (auto rank : srcRanks) {
263-
getPair(rank);
264-
}
265-
266222
for (;;) {
267223
// Find rank of pair we can attempt a recv from
268224
auto rank = recvFromAnyFindRank(buf, slot, offset, nbytes, srcRanks);

0 commit comments

Comments
 (0)