Skip to content

Commit df5358b

Browse files
d4l3k authored and facebook-github-bot committed
gloo: async
Differential Revision: D69698406
1 parent cbe963b commit df5358b

File tree

16 files changed: +200 -88 lines changed

gloo/common/store.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,12 +9,13 @@
99
#pragma once
1010

1111
#include <chrono>
12+
#include <memory>
1213
#include <string>
1314
#include <vector>
1415

1516
namespace gloo {
1617

17-
class IStore {
18+
class IStore : public std::enable_shared_from_this<IStore> {
1819
public:
1920
virtual ~IStore() = default;
2021

gloo/common/utils.cc

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -31,9 +31,10 @@ std::string getHostname() {
3131
}
3232

3333
bool useRankAsSeqNumber() {
34-
const auto& res = getenv("GLOO_ENABLE_RANK_AS_SEQUENCE_NUMBER");
35-
return res != nullptr &&
36-
(std::string(res) == "True" || std::string(res) == "1");
34+
// const auto& res = getenv("GLOO_ENABLE_RANK_AS_SEQUENCE_NUMBER");
35+
// return res != nullptr && (std::string(res) == "True" || std::string(res) ==
36+
// "1");
37+
return true;
3738
}
3839

3940
bool isStoreExtendedApiEnabled() {

gloo/test/base_test.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -106,7 +106,7 @@ class BaseTest : public ::testing::Test {
106106
std::function<void(std::shared_ptr<Context>)> fn,
107107
int base = 2) {
108108
Barrier barrier(size);
109-
::gloo::rendezvous::HashStore store;
109+
auto store = std::make_shared<::gloo::rendezvous::HashStore>();
110110

111111
spawnThreads(size, [&](int rank) {
112112
auto context =
@@ -118,7 +118,7 @@ class BaseTest : public ::testing::Test {
118118
if (!device) {
119119
return;
120120
}
121-
context->connectFullMesh(store, device);
121+
context->connectFullMesh(*store, device);
122122

123123
try {
124124
fn(context);

gloo/transport/tcp/context.cc

Lines changed: 91 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -24,10 +24,10 @@ namespace gloo {
2424
namespace transport {
2525
namespace tcp {
2626

27-
constexpr int kDefaultBatchSize = 128;
28-
2927
Context::Context(std::shared_ptr<Device> device, int rank, int size)
30-
: ::gloo::transport::Context(rank, size), device_(std::move(device)) {}
28+
: ::gloo::transport::Context(rank, size), device_(std::move(device)) {
29+
connecting_.resize(size);
30+
}
3131

3232
Context::~Context() {
3333
// Pairs refer to device by raw pointer.
@@ -50,8 +50,6 @@ void Context::createAndConnectAllPairs(IStore& store) {
5050
// it's not super straightforward so left for folks having more bandwidth
5151
// later on.
5252

53-
int localRank = 0;
54-
bool localRankSet = false;
5553
auto localHostName = getHostname();
5654
bool useRankAsSeqNum = useRankAsSeqNumber();
5755

@@ -60,13 +58,15 @@ void Context::createAndConnectAllPairs(IStore& store) {
6058
// it's just to keep the later seq num matching logic simple
6159
std::vector<ssize_t> pairIdentifiers;
6260
for (int i = 0; i < size; i++) {
63-
const auto& pair = createPair(i, useRankAsSeqNum);
64-
if (!useRankAsSeqNum) {
65-
// Need to preserve the order of the pair identifiers if we are not using
66-
// the rank as seq number
67-
pairIdentifiers.emplace_back(
68-
static_cast<Pair*>(pair.get())->address().getSeq());
69-
}
61+
// const auto& pair =
62+
createPair(i, useRankAsSeqNum);
63+
// if (!useRankAsSeqNum) {
64+
// // Need to preserve the order of the pair identifiers if we are not
65+
// using
66+
// // the rank as seq number
67+
// pairIdentifiers.emplace_back(
68+
// static_cast<Pair*>(pair.get())->address().getSeq());
69+
// }
7070
}
7171

7272
// Obtain the pair object for this rank
@@ -88,68 +88,96 @@ void Context::createAndConnectAllPairs(IStore& store) {
8888
localHostName, deviceAddress.bytes(), std::move(pairIdentifiers));
8989
store.set(std::to_string(rank), currentRankInfo.bytes());
9090

91-
std::vector<std::vector<char>> remoteRankInfos;
92-
int key = 0;
93-
if (isStoreExtendedApiEnabled() && store.has_v2_support()) {
94-
auto sizeRemaining = size;
95-
while (sizeRemaining > 0) {
96-
const auto batchKeys = std::min(kDefaultBatchSize, sizeRemaining);
97-
std::vector<std::string> keys(batchKeys);
98-
std::generate_n(
99-
keys.begin(), batchKeys, [&] { return std::to_string(key++); });
100-
const auto& batchRemoteInfos = store.multi_get(keys);
101-
remoteRankInfos.insert(
102-
remoteRankInfos.end(),
103-
batchRemoteInfos.begin(),
104-
batchRemoteInfos.end());
105-
sizeRemaining -= batchKeys;
91+
store_ = store.shared_from_this();
92+
93+
// we don't use local rank so why even set it??
94+
95+
/*
96+
std::vector<std::vector<char>> remoteRankInfos;
97+
int key = 0;
98+
if (isStoreExtendedApiEnabled() && store.has_v2_support()) {
99+
auto sizeRemaining = size;
100+
while (sizeRemaining > 0) {
101+
const auto batchKeys = std::min(kDefaultBatchSize, sizeRemaining);
102+
std::vector<std::string> keys(batchKeys);
103+
std::generate_n(
104+
keys.begin(), batchKeys, [&] { return std::to_string(key++); });
105+
const auto& batchRemoteInfos = store.multi_get(keys);
106+
remoteRankInfos.insert(
107+
remoteRankInfos.end(),
108+
batchRemoteInfos.begin(),
109+
batchRemoteInfos.end());
110+
sizeRemaining -= batchKeys;
111+
}
112+
} else {
113+
std::generate_n(std::back_inserter(remoteRankInfos), size, [&] {
114+
const auto& keyStr = std::to_string(key++);
115+
store.wait({keyStr.c_str()}, getTimeout());
116+
return store.get(keyStr);
117+
});
106118
}
107-
} else {
108-
std::generate_n(std::back_inserter(remoteRankInfos), size, [&] {
109-
const auto& keyStr = std::to_string(key++);
110-
store.wait({keyStr.c_str()}, getTimeout());
111-
return store.get(keyStr);
112-
});
113-
}
114119
115-
// Connect every pair
116-
for (int i = 0; i < size; i++) {
117-
if (i == rank) {
118-
// at this point we have enumerated all the ranks located on this host
119-
// up to the current rank, so the current `localRank` number is
120-
// what we'll set to the pairs.
121-
localRankSet = true;
122-
// We are not going to connect self.
123-
continue;
124-
}
120+
// Connect every pair
121+
for (int i = 0; i < size; i++) {
122+
if (i == rank) {
123+
// at this point we have enumerated all the ranks located on this host
124+
// up to the current rank, so the current `localRank` number is
125+
// what we'll set to the pairs.
126+
localRankSet = true;
127+
// We are not going to connect self.
128+
continue;
129+
}
125130
126-
Rank remoteRankInfo(remoteRankInfos[i]);
131+
Rank remoteRankInfo(remoteRankInfos[i]);
127132
128-
if (!localRankSet && remoteRankInfo.hostname == localHostName) {
129-
++localRank;
130-
}
133+
if (!localRankSet && remoteRankInfo.hostname == localHostName) {
134+
++localRank;
135+
}
131136
132-
const auto& pair = getPair(i);
133-
auto remoteDeviceAddr = Address(remoteRankInfo.addressBytes).getSockaddr();
134-
auto remoteAddr = Address(
135-
remoteDeviceAddr,
136-
useRankAsSeqNum ? (sequence_number_t)rank
137-
: remoteRankInfo.pairIdentifiers[rank]);
138-
pair->connect(remoteAddr.bytes());
139-
}
137+
const auto& pair = getPair(i);
138+
auto remoteDeviceAddr =
139+
Address(remoteRankInfo.addressBytes).getSockaddr(); auto remoteAddr =
140+
Address( remoteDeviceAddr, useRankAsSeqNum ? (sequence_number_t)rank :
141+
remoteRankInfo.pairIdentifiers[rank]); pair->connect(remoteAddr.bytes());
142+
}
140143
141-
// Set the local rank info for all mesh pairs involving current rank
142-
for (int i = 0; i < size; i++) {
143-
if (i == rank) {
144-
continue;
144+
// Set the local rank info for all mesh pairs involving current rank
145+
for (int i = 0; i < size; i++) {
146+
if (i == rank) {
147+
continue;
148+
}
149+
const auto& pair = getPair(i);
150+
pair->setLocalRank(localRank);
145151
}
146-
const auto& pair = getPair(i);
147-
pair->setLocalRank(localRank);
148-
}
152+
*/
149153

150154
printConnectivityInfo();
151155
}
152156

157+
std::unique_ptr<transport::Pair>& Context::getPair(int rank) {
158+
auto& pair = pairs_[rank];
159+
160+
// don't connect to self
161+
if (rank == this->rank) {
162+
return pair;
163+
}
164+
165+
if (!connecting_[rank]) {
166+
connecting_[rank] = true;
167+
168+
const auto& keyStr = std::to_string(rank);
169+
store_->wait({keyStr.c_str()}, getTimeout());
170+
auto remoteRankInfoBytes = store_->get(keyStr);
171+
172+
Rank remoteRankInfo(remoteRankInfoBytes);
173+
174+
auto remoteDeviceAddr = Address(remoteRankInfo.addressBytes).getSockaddr();
175+
auto remoteAddr = Address(remoteDeviceAddr, this->rank);
176+
pair->connect(remoteAddr.bytes());
177+
}
178+
return pair;
179+
}
180+
153181
std::unique_ptr<transport::Pair>& Context::createPair(int rank) {
154182
pairs_[rank] = std::unique_ptr<transport::Pair>(
155183
new tcp::Pair(this, device_.get(), rank, getTimeout(), false));

gloo/transport/tcp/context.h

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -43,12 +43,16 @@ class Context : public ::gloo::transport::Context,
4343
int rank,
4444
bool useRankAsSeqNumber);
4545

46+
virtual std::unique_ptr<transport::Pair>& getPair(int rank) override;
47+
4648
std::unique_ptr<transport::UnboundBuffer> createUnboundBuffer(
4749
void* ptr,
4850
size_t size) override;
4951

5052
protected:
5153
std::shared_ptr<Device> device_;
54+
std::shared_ptr<IStore> store_{nullptr};
55+
std::vector<bool> connecting_;
5256

5357
using pendingRecvTuple = std::tuple<
5458
WeakNonOwningPtr<UnboundBuffer>,

gloo/transport/tcp/device.cc

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
#include <netinet/in.h>
1414
#include <string.h>
1515
#include <array>
16+
#include <iostream>
1617

1718
#include "gloo/common/error.h"
1819
#include "gloo/common/linux.h"
@@ -217,7 +218,9 @@ Device::Device(const struct attr& attr)
217218
interfaceSpeedMbps_(getInterfaceSpeedByName(interfaceName_)),
218219
pciBusID_(interfaceToBusID(interfaceName_)) {}
219220

220-
Device::~Device() {}
221+
Device::~Device() {
222+
loop_->shutdown();
223+
}
221224

222225
std::string Device::str() const {
223226
std::stringstream ss;
@@ -328,6 +331,10 @@ void Device::connectAsListener(
328331
listener_->waitForConnection(local.getSeq(), std::move(fn));
329332
}
330333

334+
void Device::cancelConnect(const Address& local) {
335+
listener_->cancelConnect(local.getSeq());
336+
}
337+
331338
// Connecting as initiator is active.
332339
//
333340
// The connect callback is fired when the connection to the other side

gloo/transport/tcp/device.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -101,6 +101,8 @@ class Device : public ::gloo::transport::Device,
101101
std::chrono::milliseconds timeout,
102102
connect_callback_t fn);
103103

104+
void cancelConnect(const Address& local);
105+
104106
void connectAsListener(
105107
const Address& local,
106108
std::chrono::milliseconds timeout,

gloo/transport/tcp/helpers.h

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -45,13 +45,19 @@ class ReadValueOperation final
4545
fn_(std::move(fn)) {}
4646

4747
void run() {
48+
auto loop = loop_.lock();
49+
if (!loop) {
50+
return;
51+
}
52+
4853
// Cannot initialize leak until after the object has been
4954
// constructed, because the std::make_shared initialization
5055
// doesn't run after construction of the underlying object.
5156
leak_ = this->shared_from_this();
57+
5258
// Register with loop only after we've leaked the shared_ptr,
5359
// because we unleak it when the event loop thread calls.
54-
loop_->registerDescriptor(socket_->fd(), EPOLLIN | EPOLLONESHOT, this);
60+
loop->registerDescriptor(socket_->fd(), EPOLLIN | EPOLLONESHOT, this);
5561
}
5662

5763
void handleEvents(int events) override {
@@ -80,7 +86,7 @@ class ReadValueOperation final
8086
}
8187

8288
private:
83-
std::shared_ptr<Loop> loop_;
89+
std::weak_ptr<Loop> loop_;
8490
std::shared_ptr<Socket> socket_;
8591
callback_t fn_;
8692
std::shared_ptr<ReadValueOperation<T>> leak_;

gloo/transport/tcp/listener.cc

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@ Listener::Listener(std::shared_ptr<Loop> loop, const attr& attr)
3535
}
3636

3737
Listener::~Listener() {
38+
*closed_ = true;
3839
if (listener_) {
3940
loop_->unregisterDescriptor(listener_->fd(), this);
4041
}
@@ -61,7 +62,7 @@ void Listener::handleEvents(int /* unused */) {
6162
read<sequence_number_t>(
6263
loop_,
6364
sock,
64-
[this](
65+
[this, closed = closed_](
6566
std::shared_ptr<Socket> socket,
6667
const Error& error,
6768
sequence_number_t&& seq) {
@@ -72,6 +73,10 @@ void Listener::handleEvents(int /* unused */) {
7273
return;
7374
}
7475

76+
if (*closed) {
77+
return;
78+
}
79+
7580
haveConnection(std::move(socket), seq);
7681
});
7782
}
@@ -108,6 +113,19 @@ void Listener::waitForConnection(sequence_number_t seq, connect_callback_t fn) {
108113
loop_->defer([fn, socket]() { fn(socket, Error::kSuccess); });
109114
}
110115

116+
void Listener::cancelConnect(sequence_number_t seq) {
117+
std::unique_lock<std::mutex> lock(mutex_);
118+
119+
// If we don't yet have a callback for this sequence number, do nothing.
120+
auto it = seqToCallback_.find(seq);
121+
if (it == seqToCallback_.end()) {
122+
return;
123+
}
124+
125+
// If we already have a callback for this sequence number, cancel it.
126+
seqToCallback_.erase(it);
127+
}
128+
111129
void Listener::haveConnection(
112130
std::shared_ptr<Socket> socket,
113131
sequence_number_t seq) {

gloo/transport/tcp/listener.h

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -49,10 +49,14 @@ class Listener final : public Handler {
4949
// even if the connection is already available.
5050
void waitForConnection(sequence_number_t seq, connect_callback_t fn);
5151

52+
void cancelConnect(sequence_number_t seq);
53+
5254
private:
5355
std::mutex mutex_;
5456
std::shared_ptr<Loop> loop_;
5557
std::shared_ptr<Socket> listener_;
58+
std::shared_ptr<std::atomic<bool>> closed_{
59+
std::make_shared<std::atomic<bool>>(false)};
5660

5761
// Address of this listener and the sequence number for the next
5862
// connection. Sequence numbers are written by a peer right after

0 commit comments

Comments
 (0)