Skip to content

Commit 7accb26

Browse files
d4l3k authored and facebook-github-bot committed
gloo: use unbound buffers for barrier, broadcast
Summary: This makes the barrier and broadcast operations in Gloo use UnboundBuffer instead of Buffer. UnboundBuffer is the newer, preferred buffer implementation; it also makes the algorithms much easier to express and understand, and more efficient, since we no longer need to allocate additional memory. Tracking issue for Buffer removal: #432. Differential Revision: D74096980
1 parent 575bf30 commit 7accb26

File tree

3 files changed

+51
-119
lines changed

3 files changed

+51
-119
lines changed

gloo/barrier_all_to_all.h

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,43 +15,33 @@ namespace gloo {
1515
class BarrierAllToAll : public Barrier {
1616
public:
1717
explicit BarrierAllToAll(const std::shared_ptr<Context>& context)
18-
: Barrier(context) {
18+
: Barrier(context) {}
19+
20+
void run() {
1921
// Create send/recv buffers for every peer
2022
auto slot = this->context_->nextSlot();
23+
24+
auto buffer = this->context_->createUnboundBuffer(nullptr, 0);
25+
auto timeout = this->context_->getTimeout();
26+
2127
for (auto i = 0; i < this->contextSize_; i++) {
2228
// Skip self
2329
if (i == this->contextRank_) {
2430
continue;
2531
}
26-
27-
auto& pair = this->getPair(i);
28-
auto sdata = std::unique_ptr<int>(new int);
29-
auto sbuf = pair->createSendBuffer(slot, sdata.get(), sizeof(int));
30-
sendBuffersData_.push_back(std::move(sdata));
31-
sendBuffers_.push_back(std::move(sbuf));
32-
auto rdata = std::unique_ptr<int>(new int);
33-
auto rbuf = pair->createRecvBuffer(slot, rdata.get(), sizeof(int));
34-
recvBuffersData_.push_back(std::move(rdata));
35-
recvBuffers_.push_back(std::move(rbuf));
32+
buffer->send(i, slot);
33+
buffer->recv(i, slot);
3634
}
37-
}
3835

39-
void run() {
40-
// Notify peers
41-
for (auto& buffer : sendBuffers_) {
42-
buffer->send();
43-
}
44-
// Wait for notification from peers
45-
for (auto& buffer : recvBuffers_) {
46-
buffer->waitRecv();
36+
for (auto i = 0; i < this->contextSize_; i++) {
37+
// Skip self
38+
if (i == this->contextRank_) {
39+
continue;
40+
}
41+
buffer->waitSend(timeout);
42+
buffer->waitRecv(timeout);
4743
}
4844
}
49-
50-
protected:
51-
std::vector<std::unique_ptr<int>> sendBuffersData_;
52-
std::vector<std::unique_ptr<transport::Buffer>> sendBuffers_;
53-
std::vector<std::unique_ptr<int>> recvBuffersData_;
54-
std::vector<std::unique_ptr<transport::Buffer>> recvBuffers_;
5545
};
5646

5747
} // namespace gloo

gloo/barrier_all_to_one.h

Lines changed: 21 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -17,65 +17,42 @@ class BarrierAllToOne : public Barrier {
1717
explicit BarrierAllToOne(
1818
const std::shared_ptr<Context>& context,
1919
int rootRank = 0)
20-
: Barrier(context), rootRank_(rootRank) {
20+
: Barrier(context), rootRank_(rootRank) {}
21+
22+
void run() {
2123
auto slot = this->context_->nextSlot();
24+
auto timeout = this->context_->getTimeout();
25+
26+
auto buffer = this->context_->createUnboundBuffer(nullptr, 0);
27+
2228
if (this->contextRank_ == rootRank_) {
23-
// Create send/recv buffers for every peer
2429
for (int i = 0; i < this->contextSize_; i++) {
2530
// Skip self
2631
if (i == this->contextRank_) {
2732
continue;
2833
}
29-
30-
auto& pair = this->getPair(i);
31-
auto sdata = std::unique_ptr<int>(new int);
32-
auto sbuf = pair->createSendBuffer(slot, sdata.get(), sizeof(int));
33-
sendBuffersData_.push_back(std::move(sdata));
34-
sendBuffers_.push_back(std::move(sbuf));
35-
auto rdata = std::unique_ptr<int>(new int);
36-
auto rbuf = pair->createRecvBuffer(slot, rdata.get(), sizeof(int));
37-
recvBuffersData_.push_back(std::move(rdata));
38-
recvBuffers_.push_back(std::move(rbuf));
34+
buffer->recv(i, slot);
35+
buffer->waitRecv(timeout);
3936
}
40-
} else {
41-
// Create send/recv buffers to/from the root
42-
auto& pair = this->getPair(rootRank_);
43-
auto sdata = std::unique_ptr<int>(new int);
44-
auto sbuf = pair->createSendBuffer(slot, sdata.get(), sizeof(int));
45-
sendBuffersData_.push_back(std::move(sdata));
46-
sendBuffers_.push_back(std::move(sbuf));
47-
auto rdata = std::unique_ptr<int>(new int);
48-
auto rbuf = pair->createRecvBuffer(slot, rdata.get(), sizeof(int));
49-
recvBuffersData_.push_back(std::move(rdata));
50-
recvBuffers_.push_back(std::move(rbuf));
51-
}
52-
}
53-
54-
void run() {
55-
if (this->contextRank_ == rootRank_) {
56-
// Wait for message from all peers
57-
for (auto& b : recvBuffers_) {
58-
b->waitRecv();
59-
}
60-
// Notify all peers
61-
for (auto& b : sendBuffers_) {
62-
b->send();
37+
for (int i = 0; i < this->contextSize_; i++) {
38+
// Skip self
39+
if (i == this->contextRank_) {
40+
continue;
41+
}
42+
buffer->send(i, slot);
43+
buffer->waitSend(timeout);
6344
}
45+
6446
} else {
65-
// Send message to root
66-
sendBuffers_[0]->send();
67-
// Wait for acknowledgement from root
68-
recvBuffers_[0]->waitRecv();
47+
buffer->send(rootRank_, slot);
48+
buffer->waitSend(timeout);
49+
buffer->recv(rootRank_, slot);
50+
buffer->waitRecv(timeout);
6951
}
7052
}
7153

7254
protected:
7355
const int rootRank_;
74-
75-
std::vector<std::unique_ptr<int>> sendBuffersData_;
76-
std::vector<std::unique_ptr<transport::Buffer>> sendBuffers_;
77-
std::vector<std::unique_ptr<int>> recvBuffersData_;
78-
std::vector<std::unique_ptr<transport::Buffer>> recvBuffers_;
7956
};
8057

8158
} // namespace gloo

gloo/broadcast_one_to_all.h

Lines changed: 14 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -36,32 +36,6 @@ class BroadcastOneToAll : public Algorithm {
3636
GLOO_ENFORCE_LT(rootRank_, contextSize_);
3737
GLOO_ENFORCE_GE(rootPointerRank_, 0);
3838
GLOO_ENFORCE_LT(rootPointerRank_, ptrs_.size());
39-
40-
// Setup pairs/buffers for sender/receivers
41-
if (contextSize_ > 1) {
42-
auto ptr = ptrs_[rootPointerRank_];
43-
auto slot = context_->nextSlot();
44-
if (contextRank_ == rootRank_) {
45-
sender_.resize(contextSize_);
46-
for (auto i = 0; i < contextSize_; i++) {
47-
if (i == contextRank_) {
48-
continue;
49-
}
50-
51-
sender_[i] = make_unique<forSender>();
52-
auto& pair = context_->getPair(i);
53-
sender_[i]->clearToSendBuffer = pair->createRecvBuffer(
54-
slot, &sender_[i]->dummy, sizeof(sender_[i]->dummy));
55-
sender_[i]->sendBuffer = pair->createSendBuffer(slot, ptr, bytes_);
56-
}
57-
} else {
58-
receiver_ = make_unique<forReceiver>();
59-
auto& rootPair = context_->getPair(rootRank_);
60-
receiver_->clearToSendBuffer = rootPair->createSendBuffer(
61-
slot, &receiver_->dummy, sizeof(receiver_->dummy));
62-
receiver_->recvBuffer = rootPair->createRecvBuffer(slot, ptr, bytes_);
63-
}
64-
}
6539
}
6640

6741
void run() {
@@ -70,14 +44,21 @@ class BroadcastOneToAll : public Algorithm {
7044
return;
7145
}
7246

47+
auto clearToSendBuffer = context_->createUnboundBuffer(nullptr, 0);
48+
auto buffer =
49+
context_->createUnboundBuffer(ptrs_[rootPointerRank_], bytes_);
50+
auto slot = context_->nextSlot();
51+
auto timeout = context_->getTimeout();
52+
7353
if (contextRank_ == rootRank_) {
7454
// Fire off send operations after receiving clear to send
7555
for (auto i = 0; i < contextSize_; i++) {
7656
if (i == contextRank_) {
7757
continue;
7858
}
79-
sender_[i]->clearToSendBuffer->waitRecv();
80-
sender_[i]->sendBuffer->send();
59+
clearToSendBuffer->recv(i, slot);
60+
clearToSendBuffer->waitRecv(timeout);
61+
buffer->send(i, slot);
8162
}
8263

8364
// Broadcast locally while sends are happening
@@ -88,11 +69,13 @@ class BroadcastOneToAll : public Algorithm {
8869
if (i == contextRank_) {
8970
continue;
9071
}
91-
sender_[i]->sendBuffer->waitSend();
72+
buffer->waitSend(timeout);
9273
}
9374
} else {
94-
receiver_->clearToSendBuffer->send();
95-
receiver_->recvBuffer->waitRecv();
75+
clearToSendBuffer->send(rootRank_, slot);
76+
clearToSendBuffer->waitSend(timeout);
77+
buffer->recv(rootRank_, slot);
78+
buffer->waitRecv(timeout);
9679

9780
// Broadcast locally after receiving from root
9881
broadcastLocally();
@@ -116,24 +99,6 @@ class BroadcastOneToAll : public Algorithm {
11699
const size_t bytes_;
117100
const int rootRank_;
118101
const int rootPointerRank_;
119-
120-
// For the sender (root)
121-
struct forSender {
122-
int dummy;
123-
std::unique_ptr<transport::Buffer> clearToSendBuffer;
124-
std::unique_ptr<transport::Buffer> sendBuffer;
125-
};
126-
127-
std::vector<std::unique_ptr<forSender>> sender_;
128-
129-
// For all receivers
130-
struct forReceiver {
131-
int dummy;
132-
std::unique_ptr<transport::Buffer> clearToSendBuffer;
133-
std::unique_ptr<transport::Buffer> recvBuffer;
134-
};
135-
136-
std::unique_ptr<forReceiver> receiver_;
137102
};
138103

139104
} // namespace gloo

0 commit comments

Comments
 (0)