Skip to content

Commit 9ba706d

Browse files
d4l3k and facebook-github-bot
authored and committed
gloo: use shared Stores (#423)
Summary: X-link: pytorch/pytorch#150230 This modifies `connectFullMesh` to take in a shared_ptr<IStore> instead of a reference. This is an API breaking change but fairly easy to work around. To have backwards compatibility in PyTorch during the commit phase we add a new ifdef `GLOO_SHARED_STORE` which can provide backwards compatibility until we update the pinned Gloo version in pytorch OSS repo. This also adds a new `wait_get` method to `IStore` which will allow us to do a more efficient operation in PyTorch TCPStore. PyTorch's `Store::get` automatically waits so we want to make sure we can avoid waiting twice to reduce network traffic. This change will land simultaneously in PyTorch and Gloo repos. Reviewed By: fduwjj Differential Revision: D72084111
1 parent 9d6f6bd commit 9ba706d

File tree

13 files changed

+61
-43
lines changed

13 files changed

+61
-43
lines changed

gloo/benchmark/runner.cc

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,10 @@ void Runner::rendezvousRedis() {
187187
return;
188188
}
189189

190-
rendezvous::RedisStore redisStore(options_.redisHost, options_.redisPort);
191-
rendezvous::PrefixStore prefixStore(options_.prefix, redisStore);
190+
auto redisStore = std::make_shared<rendezvous::RedisStore>(
191+
options_.redisHost, options_.redisPort);
192+
auto prefixStore =
193+
std::make_shared<rendezvous::PrefixStore>(options_.prefix, redisStore);
192194
auto backingContext = std::make_shared<rendezvous::Context>(
193195
options_.contextRank, options_.contextSize);
194196
backingContext->connectFullMesh(prefixStore, transportDevices_.front());
@@ -221,14 +223,15 @@ void Runner::rendezvousFileSystem() {
221223
return;
222224
}
223225

224-
rendezvous::FileStore fileStore(options_.sharedPath);
225-
rendezvous::PrefixStore prefixStore(options_.prefix, fileStore);
226+
auto fileStore = std::make_shared<rendezvous::FileStore>(options_.sharedPath);
227+
auto prefixStore =
228+
std::make_shared<rendezvous::PrefixStore>(options_.prefix, fileStore);
226229
auto backingContext = std::make_shared<rendezvous::Context>(
227230
options_.contextRank, options_.contextSize);
228231
backingContext->connectFullMesh(prefixStore, transportDevices_.front());
229232
// After connectFullMesh is called, the rendezvous files will have been
230233
// generated so we need to fetch them from the FileStore
231-
keyFilePaths_ = fileStore.getAllKeyFilePaths();
234+
keyFilePaths_ = fileStore->getAllKeyFilePaths();
232235
contextFactory_ =
233236
std::make_shared<rendezvous::ContextFactory>(backingContext);
234237
}

gloo/common/store.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@
99
#pragma once
1010

1111
#include <chrono>
12+
#include <memory>
1213
#include <string>
1314
#include <vector>
1415

16+
#define GLOO_SHARED_STORE
17+
1518
namespace gloo {
1619

1720
class IStore {
@@ -22,6 +25,13 @@ class IStore {
2225

2326
virtual std::vector<char> get(const std::string& key) = 0;
2427

28+
virtual std::vector<char> wait_get(
29+
const std::string& key,
30+
const std::chrono::milliseconds& timeout) {
31+
wait({key}, timeout);
32+
return get(key);
33+
}
34+
2535
virtual void wait(
2636
const std::vector<std::string>& keys,
2737
const std::chrono::milliseconds& timeout) = 0;

gloo/examples/example1.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,15 @@ int main(void) {
8282
// Below, we instantiate rendezvous using the filesystem, given that
8383
// this example uses multiple processes on a single machine.
8484
//
85-
auto fileStore = gloo::rendezvous::FileStore("/tmp");
85+
auto fileStore = std::make_shared<gloo::rendezvous::FileStore>("/tmp");
8686

8787
// To be able to reuse the same store over and over again and not have
8888
// interference between runs, we scope it to a unique prefix with the
8989
// PrefixStore. This wraps another store and prefixes every key before
9090
// forwarding the call to the underlying store.
9191
std::string prefix = getenv("PREFIX");
92-
auto prefixStore = gloo::rendezvous::PrefixStore(prefix, fileStore);
92+
auto prefixStore =
93+
std::make_shared<gloo::rendezvous::PrefixStore>(prefix, fileStore);
9394

9495
// Using this store, we can now create a Gloo context. The context
9596
// holds a reference to every communication pair involving this

gloo/rendezvous/context.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <memory>
1212

1313
#include "gloo/common/logging.h"
14+
#include "gloo/rendezvous/store.h"
1415
#include "gloo/transport/address.h"
1516

1617
namespace gloo {
@@ -22,12 +23,12 @@ Context::Context(int rank, int size, int base)
2223
Context::~Context() {}
2324

2425
void Context::connectFullMesh(
25-
rendezvous::Store& store,
26+
std::shared_ptr<rendezvous::Store> store,
2627
std::shared_ptr<transport::Device>& dev) {
2728
auto transportContext = dev->createContext(rank, size);
2829
transportContext->setTimeout(getTimeout());
2930

30-
transportContext->createAndConnectAllPairs(store);
31+
transportContext->createAndConnectAllPairs(std::move(store));
3132

3233
device_ = dev;
3334
transportContext_ = std::move(transportContext);

gloo/rendezvous/context.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ class Context : public ::gloo::Context {
2929
Context(int rank, int size, int base = 2);
3030
virtual ~Context();
3131

32-
void connectFullMesh(Store& store, std::shared_ptr<transport::Device>& dev);
32+
void connectFullMesh(
33+
std::shared_ptr<Store> store,
34+
std::shared_ptr<transport::Device>& dev);
3335

3436
protected:
3537
friend class ContextFactory;

gloo/rendezvous/prefix_store.cc

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
namespace gloo {
1414
namespace rendezvous {
1515

16-
PrefixStore::PrefixStore(const std::string& prefix, Store& store)
17-
: prefix_(prefix), store_(store) {}
16+
PrefixStore::PrefixStore(
17+
const std::string& prefix,
18+
std::shared_ptr<Store> store)
19+
: prefix_(prefix), store_(std::move(store)) {}
1820

1921
std::string PrefixStore::joinKey(const std::string& key) {
2022
std::stringstream ss;
@@ -23,11 +25,11 @@ std::string PrefixStore::joinKey(const std::string& key) {
2325
}
2426

2527
void PrefixStore::set(const std::string& key, const std::vector<char>& data) {
26-
store_.set(joinKey(key), data);
28+
store_->set(joinKey(key), data);
2729
}
2830

2931
std::vector<char> PrefixStore::get(const std::string& key) {
30-
return store_.get(joinKey(key));
32+
return store_->get(joinKey(key));
3133
}
3234

3335
void PrefixStore::wait(
@@ -38,56 +40,56 @@ void PrefixStore::wait(
3840
for (const auto& key : keys) {
3941
joinedKeys.push_back(joinKey(key));
4042
}
41-
store_.wait(joinedKeys, timeout);
43+
store_->wait(joinedKeys, timeout);
4244
}
4345

4446
bool PrefixStore::has_v2_support() {
45-
return store_.has_v2_support();
47+
return store_->has_v2_support();
4648
}
4749

4850
std::vector<std::vector<char>> PrefixStore::multi_get(
4951
const std::vector<std::string>& keys) {
50-
if (!store_.has_v2_support()) {
52+
if (!store_->has_v2_support()) {
5153
GLOO_THROW_INVALID_OPERATION_EXCEPTION(
5254
"underlying store doesn't support multi_get");
5355
}
5456
std::vector<std::string> prefixed_keys;
5557
for (auto& key : keys) {
5658
prefixed_keys.push_back(joinKey(key));
5759
}
58-
return store_.multi_get(prefixed_keys);
60+
return store_->multi_get(prefixed_keys);
5961
}
6062

6163
void PrefixStore::multi_set(
6264
const std::vector<std::string>& keys,
6365
const std::vector<std::vector<char>>& values) {
64-
if (!store_.has_v2_support()) {
66+
if (!store_->has_v2_support()) {
6567
GLOO_THROW_INVALID_OPERATION_EXCEPTION(
6668
"underlying store doesn't support multi_set");
6769
}
6870
std::vector<std::string> prefixed_keys;
6971
for (auto& key : keys) {
7072
prefixed_keys.push_back(joinKey(key));
7173
}
72-
return store_.multi_set(prefixed_keys, values);
74+
return store_->multi_set(prefixed_keys, values);
7375
}
7476

7577
void PrefixStore::append(
7678
const std::string& key,
7779
const std::vector<char>& data) {
78-
if (!store_.has_v2_support()) {
80+
if (!store_->has_v2_support()) {
7981
GLOO_THROW_INVALID_OPERATION_EXCEPTION(
8082
"underlying store doesn't support append");
8183
}
82-
store_.append(joinKey(key), data);
84+
store_->append(joinKey(key), data);
8385
}
8486

8587
int64_t PrefixStore::add(const std::string& key, int64_t value) {
86-
if (!store_.has_v2_support()) {
88+
if (!store_->has_v2_support()) {
8789
GLOO_THROW_INVALID_OPERATION_EXCEPTION(
8890
"underlying store doesn't support append");
8991
}
90-
return store_.add(joinKey(key), value);
92+
return store_->add(joinKey(key), value);
9193
}
9294

9395
} // namespace rendezvous

gloo/rendezvous/prefix_store.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ namespace rendezvous {
1717

1818
class PrefixStore : public Store {
1919
public:
20-
PrefixStore(const std::string& prefix, Store& store);
20+
PrefixStore(const std::string& prefix, std::shared_ptr<Store> store);
2121

2222
virtual ~PrefixStore() {}
2323

@@ -46,7 +46,7 @@ class PrefixStore : public Store {
4646

4747
protected:
4848
const std::string prefix_;
49-
Store& store_;
49+
std::shared_ptr<Store> store_;
5050

5151
std::string joinKey(const std::string& key);
5252
};

gloo/test/base_test.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ class BaseTest : public ::testing::Test {
106106
std::function<void(std::shared_ptr<Context>)> fn,
107107
int base = 2) {
108108
Barrier barrier(size);
109-
::gloo::rendezvous::HashStore store;
109+
auto store = std::make_shared<::gloo::rendezvous::HashStore>();
110110

111111
spawnThreads(size, [&](int rank) {
112112
auto context =

gloo/test/multiproc_test.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,14 +101,14 @@ class MultiProcWorker {
101101
auto context = std::make_shared<::gloo::rendezvous::Context>(rank, size);
102102
auto device = createDevice(transport);
103103
context->setTimeout(std::chrono::milliseconds(kMultiProcTimeout));
104-
context->connectFullMesh(*store_, device);
104+
context->connectFullMesh(store_, device);
105105
device.reset();
106106
sem_post(semaphore_);
107107
fn(std::move(context));
108108
}
109109

110110
protected:
111-
std::unique_ptr<::gloo::rendezvous::Store> store_;
111+
std::shared_ptr<::gloo::rendezvous::Store> store_;
112112
sem_t* semaphore_;
113113
};
114114

gloo/transport/context.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ std::unique_ptr<transport::Pair>& Context::getPair(int rank_2) {
2323
return pairs_.at(rank_2);
2424
}
2525

26-
void Context::createAndConnectAllPairs(IStore& store) {
26+
void Context::createAndConnectAllPairs(std::shared_ptr<IStore> store) {
2727
// this is the default un-optimized version of the rendezvous protocol
2828
// where each rank would write N pairs to the store
2929
// and then for each remote peer load the N addresses
@@ -40,15 +40,15 @@ void Context::createAndConnectAllPairs(IStore& store) {
4040
// hostname mapping to compute local ranks.
4141
std::string localKey("rank_" + std::to_string(rank));
4242
const std::vector<char> value(localHostName.begin(), localHostName.end());
43-
store.set(localKey, value);
43+
store->set(localKey, value);
4444

4545
for (int i = 0; i < size; i++) {
4646
if (i == rank) {
4747
break;
4848
}
4949

5050
std::string key("rank_" + std::to_string(i));
51-
auto val = store.get(key);
51+
auto val = store->get(key);
5252
auto hostName = std::string((const char*)val.data(), val.size());
5353

5454
if (hostName == localHostName) {
@@ -68,7 +68,7 @@ void Context::createAndConnectAllPairs(IStore& store) {
6868
allBytes.insert(allBytes.end(), addrBytes.begin(), addrBytes.end());
6969
}
7070

71-
store.set(std::to_string(rank), allBytes);
71+
store->set(std::to_string(rank), allBytes);
7272

7373
// Connect every pair
7474
for (int i = 0; i < size; i++) {
@@ -79,10 +79,10 @@ void Context::createAndConnectAllPairs(IStore& store) {
7979
// Wait for address of other side of this pair to become available
8080
std::ostringstream key;
8181
key << i;
82-
store.wait({key.str()}, getTimeout());
82+
store->wait({key.str()}, getTimeout());
8383

8484
// Connect to other side of this pair
85-
auto allAddrs = store.get(key.str());
85+
auto allAddrs = store->get(key.str());
8686
auto addr = extractAddress(allAddrs, i);
8787
getPair(i)->connect(addr);
8888
}

0 commit comments

Comments
 (0)