Skip to content

Commit 06f24f5

Browse files
Boris Saranafacebook-github-bot
authored andcommitted
Use multi_get for store that has extended API support. (#408)
Summary: Pull Request resolved: #408 The TCP store has API v2 support we can reduce the network overhead of Gloo rendezvous significantly by fetching a batch of key instead of doing them one by one. Initial testing shows ~15X improvement for 4k jobs. Gloo process group init: Baseline ( fbcode trunk): 2k job (https://fburl.com/mlhub/x1prxu89) : ~82sec (~1.4 min) 4k job (https://fburl.com/mlhub/v1djk4n5) : ~393 sec (~6.6min) 8k job (https://fburl.com/mlhub/cagqrs7m): (~55mins) With optimizations (D48130088 + D52083376): 2k job (https://fburl.com/mlhub/x0cskdag) : ~18 sec ( ~5x faster) 4k job (https://fburl.com/mlhub/xzmvkm4j) : ~ 25 sec (~15x faster) 8k job (https://fburl.com/mlhub/gdyeizv9) : ~ 85 sec (~35x faster) Reviewed By: xunnanxu Differential Revision: D52083376
1 parent 1ff67a5 commit 06f24f5

File tree

4 files changed

+57
-12
lines changed

4 files changed

+57
-12
lines changed

gloo/common/store.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,21 @@ class IStore {
2525
virtual void wait(
2626
const std::vector<std::string>& keys,
2727
const std::chrono::milliseconds& timeout) = 0;
28+
29+
// Extended 2.0 API support
30+
virtual bool has_v2_support() = 0;
31+
32+
virtual std::vector<std::vector<char>> multi_get(
33+
const std::vector<std::string>& keys) = 0;
34+
35+
virtual void multi_set(
36+
const std::vector<std::string>& keys,
37+
const std::vector<std::vector<char>>& values) = 0;
38+
39+
virtual void append(
40+
const std::string& key,
41+
const std::vector<char>& value) = 0;
42+
virtual int64_t add(const std::string& key, int64_t value) = 0;
2843
};
2944

3045
} // namespace gloo

gloo/common/utils.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,10 @@ bool useRankAsSeqNumber() {
3636
(std::string(res) == "True" || std::string(res) == "1");
3737
}
3838

39+
bool isStoreExtendedApiEnabled() {
40+
const auto& res = std::getenv("GLOO_ENABLE_STORE_V2_API");
41+
return res != nullptr &&
42+
(std::string(res) == "True" || std::string(res) == "1");
43+
}
44+
3945
} // namespace gloo

gloo/common/utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,6 @@ std::string getHostname();
1616

1717
bool useRankAsSeqNumber();
1818

19+
bool isStoreExtendedApiEnabled();
20+
1921
} // namespace gloo

gloo/transport/tcp/context.cc

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88

99
#include "gloo/transport/tcp/context.h"
1010

11+
#include <algorithm>
12+
#include <cstdint>
1113
#include <cstring>
1214
#include <iostream>
15+
#include <string>
1316

14-
#include "gloo/common/error.h"
1517
#include "gloo/common/logging.h"
1618
#include "gloo/common/utils.h"
1719
#include "gloo/transport/tcp/device.h"
@@ -22,6 +24,8 @@ namespace gloo {
2224
namespace transport {
2325
namespace tcp {
2426

27+
constexpr int kDefaultBatchSize = 128;
28+
2529
Context::Context(std::shared_ptr<Device> device, int rank, int size)
2630
: ::gloo::transport::Context(rank, size), device_(std::move(device)) {}
2731

@@ -78,12 +82,36 @@ void Context::createAndConnectAllPairs(IStore& store) {
7882
// which does not have the rank info hosted at a higher `Pair` level).
7983
// So better safe than sorry for now we try to minimize the changeset needed.
8084
const auto& currentRankPair = getPair(rank);
81-
auto deviceAddress = Address(
85+
const auto& deviceAddress = Address(
8286
static_cast<const Pair*>(currentRankPair.get())->address().getSockaddr());
8387
Rank currentRankInfo(
8488
localHostName, deviceAddress.bytes(), std::move(pairIdentifiers));
8589
store.set(std::to_string(rank), currentRankInfo.bytes());
8690

91+
std::vector<std::vector<char>> remoteRankInfos;
92+
int key = 0;
93+
if (isStoreExtendedApiEnabled() && store.has_v2_support()) {
94+
auto sizeRemaining = size;
95+
while (sizeRemaining > 0) {
96+
const auto batchKeys = std::min(kDefaultBatchSize, sizeRemaining);
97+
std::vector<std::string> keys(batchKeys);
98+
std::generate_n(
99+
keys.begin(), batchKeys, [&] { return std::to_string(key++); });
100+
const auto& batchRemoteInfos = store.multi_get(keys);
101+
remoteRankInfos.insert(
102+
remoteRankInfos.end(),
103+
batchRemoteInfos.begin(),
104+
batchRemoteInfos.end());
105+
sizeRemaining -= batchKeys;
106+
}
107+
} else {
108+
std::generate_n(std::back_inserter(remoteRankInfos), size, [&] {
109+
const auto& keyStr = std::to_string(key++);
110+
store.wait({keyStr.c_str()}, getTimeout());
111+
return store.get(keyStr);
112+
});
113+
}
114+
87115
// Connect every pair
88116
for (int i = 0; i < size; i++) {
89117
if (i == rank) {
@@ -95,24 +123,18 @@ void Context::createAndConnectAllPairs(IStore& store) {
95123
continue;
96124
}
97125

98-
// Wait for address of other side of this pair to become available
99-
std::ostringstream key;
100-
key << i;
101-
store.wait({key.str()}, getTimeout());
126+
Rank remoteRankInfo(remoteRankInfos[i]);
102127

103-
// Connect to other side of this pair
104-
std::vector<char> rankInfoBytes = store.get(key.str());
105-
Rank remoteRankInfo(rankInfoBytes);
106-
const auto& remoteHostname = remoteRankInfo.hostname;
107-
if (!localRankSet && remoteHostname == localHostName) {
128+
if (!localRankSet && remoteRankInfo.hostname == localHostName) {
108129
++localRank;
109130
}
110131

111132
const auto& pair = getPair(i);
112133
auto remoteDeviceAddr = Address(remoteRankInfo.addressBytes).getSockaddr();
113134
auto remoteAddr = Address(
114135
remoteDeviceAddr,
115-
useRankAsSeqNum ? (ssize_t)rank : remoteRankInfo.pairIdentifiers[rank]);
136+
useRankAsSeqNum ? (sequence_number_t)rank
137+
: remoteRankInfo.pairIdentifiers[rank]);
116138
pair->connect(remoteAddr.bytes());
117139
}
118140

0 commit comments

Comments
 (0)