8
8
9
9
#include " gloo/transport/tcp/context.h"
10
10
11
+ #include < algorithm>
12
+ #include < cstdint>
11
13
#include < cstring>
12
14
#include < iostream>
15
+ #include < string>
13
16
14
- #include " gloo/common/error.h"
15
17
#include " gloo/common/logging.h"
16
18
#include " gloo/common/utils.h"
17
19
#include " gloo/transport/tcp/device.h"
@@ -22,6 +24,8 @@ namespace gloo {
22
24
namespace transport {
23
25
namespace tcp {
24
26
27
// Upper bound on the number of keys requested per `store.multi_get` call when
// the remote rank infos are fetched in batches.
+ constexpr int kDefaultBatchSize = 128 ;
28
+
25
29
// Constructs the TCP context: forwards `rank` and `size` to the base
// ::gloo::transport::Context and moves the shared `device` handle into
// `device_`.
Context::Context (std::shared_ptr<Device> device, int rank, int size)
26
30
: ::gloo::transport::Context(rank, size), device_(std::move(device)) {}
27
31
@@ -78,12 +82,36 @@ void Context::createAndConnectAllPairs(IStore& store) {
78
82
// which does not have the rank info hosted at a higher `Pair` level).
79
83
// So, better safe than sorry: for now we try to minimize the changeset needed.
80
84
const auto & currentRankPair = getPair (rank);
81
- auto deviceAddress = Address (
85
+ const auto & deviceAddress = Address (
82
86
static_cast <const Pair*>(currentRankPair.get ())->address ().getSockaddr ());
83
87
Rank currentRankInfo (
84
88
localHostName, deviceAddress.bytes (), std::move (pairIdentifiers));
85
89
store.set (std::to_string (rank), currentRankInfo.bytes ());
86
90
91
+ std::vector<std::vector<char >> remoteRankInfos;
92
+ int key = 0 ;
93
+ if (isStoreExtendedApiEnabled () && store.has_v2_support ()) {
94
+ auto sizeRemaining = size;
95
+ while (sizeRemaining > 0 ) {
96
+ const auto batchKeys = std::min (kDefaultBatchSize , sizeRemaining);
97
+ std::vector<std::string> keys (batchKeys);
98
+ std::generate_n (
99
+ keys.begin (), batchKeys, [&] { return std::to_string (key++); });
100
+ const auto & batchRemoteInfos = store.multi_get (keys);
101
+ remoteRankInfos.insert (
102
+ remoteRankInfos.end (),
103
+ batchRemoteInfos.begin (),
104
+ batchRemoteInfos.end ());
105
+ sizeRemaining -= batchKeys;
106
+ }
107
+ } else {
108
+ std::generate_n (std::back_inserter (remoteRankInfos), size, [&] {
109
+ const auto & keyStr = std::to_string (key++);
110
+ store.wait ({keyStr.c_str ()}, getTimeout ());
111
+ return store.get (keyStr);
112
+ });
113
+ }
114
+
87
115
// Connect every pair
88
116
for (int i = 0 ; i < size; i++) {
89
117
if (i == rank) {
@@ -95,24 +123,18 @@ void Context::createAndConnectAllPairs(IStore& store) {
95
123
continue ;
96
124
}
97
125
98
- // Wait for address of other side of this pair to become available
99
- std::ostringstream key;
100
- key << i;
101
- store.wait ({key.str ()}, getTimeout ());
126
+ Rank remoteRankInfo (remoteRankInfos[i]);
102
127
103
- // Connect to other side of this pair
104
- std::vector<char > rankInfoBytes = store.get (key.str ());
105
- Rank remoteRankInfo (rankInfoBytes);
106
- const auto & remoteHostname = remoteRankInfo.hostname ;
107
- if (!localRankSet && remoteHostname == localHostName) {
128
+ if (!localRankSet && remoteRankInfo.hostname == localHostName) {
108
129
++localRank;
109
130
}
110
131
111
132
const auto & pair = getPair (i);
112
133
auto remoteDeviceAddr = Address (remoteRankInfo.addressBytes ).getSockaddr ();
113
134
auto remoteAddr = Address (
114
135
remoteDeviceAddr,
115
- useRankAsSeqNum ? (ssize_t )rank : remoteRankInfo.pairIdentifiers [rank]);
136
+ useRankAsSeqNum ? (sequence_number_t )rank
137
+ : remoteRankInfo.pairIdentifiers [rank]);
116
138
pair->connect (remoteAddr.bytes ());
117
139
}
118
140
0 commit comments