Skip to content

Commit 38a8de2

Browse files
d4l3kfacebook-github-bot
authored andcommitted
c10d/gloo: add ibverbs backend
Summary: This provides a new "UnboundBuffer" implementation for Gloo ibverbs backend so it can be used with PyTorch. This currently is passing basic tests such as `reduce_test` and `send_recv_test` but there are a number of failures. Putting this up for review so the follow up fixes are less of a mega PR and also so we can start doing some initial testing with this E2E with PyTorch. Known issues: * support recv from any is not supported * AllreduceBcubeBase2 is failing Differential Revision: D73291471
1 parent 4ecd9ce commit 38a8de2

File tree

16 files changed

+768
-115
lines changed

16 files changed

+768
-115
lines changed

gloo/allreduce_ring_chunked.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,14 +145,19 @@ class AllreduceRingChunked : public Algorithm {
145145
fn_->call(&ptrs_[0][offset], inbox_[chunkOffset & 1], length);
146146
}
147147

148+
GLOO_DEBUG("START sendNotification round=", round);
149+
148150
// Send notification to node on the left that
149151
// this node is ready for an inbox write.
150152
sendNotificationBuf_->send();
151153

154+
GLOO_DEBUG("START recvNotification round=", round);
152155
// Wait for notification from node on the right
153156
// to be sure this node can start an inbox write.
154157
recvNotificationBuf_->waitRecv();
155158

159+
GLOO_DEBUG("DONE recvNotification round=", round);
160+
156161
// Copy accumulated chunk
157162
copyChunkAtOffset(chunkOffset);
158163
}

gloo/common/logging.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,20 @@
1515
#include <limits>
1616
#include <vector>
1717

18+
#include "gloo/common/error.h"
1819
#include "gloo/common/string.h"
1920

2021
namespace gloo {
2122

23+
#define GLOO_LOG_MSG(level, ...) \
24+
std::cerr << ::gloo::MakeString( \
25+
"[", __FILE__, ":", __LINE__, "] ", level, " ", __VA_ARGS__, "\n")
26+
27+
#define GLOO_INFO(...) GLOO_LOG_MSG("INFO", __VA_ARGS__)
28+
#define GLOO_ERROR(...) GLOO_LOG_MSG("ERROR", __VA_ARGS__)
29+
#define GLOO_WARN(...) GLOO_LOG_MSG("WARN", __VA_ARGS__)
30+
#define GLOO_DEBUG(...) // GLOO_LOG_MSG("DEBUG", __VA_ARGS__)
31+
2232
class EnforceNotMet : public std::exception {
2333
public:
2434
EnforceNotMet(
@@ -157,7 +167,4 @@ BINARY_COMP_HELPER(LessEquals, <=)
157167
#define GLOO_ENFORCE_GT(x, y, ...) \
158168
GLOO_ENFORCE_THAT_IMPL(Greater((x), (y)), #x " > " #y, __VA_ARGS__)
159169

160-
#define GLOO_ERROR(...) \
161-
std::cerr << "Gloo error: " << ::gloo::MakeString(__VA_ARGS__) << std::endl
162-
163170
} // namespace gloo

gloo/test/base_test.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,35 @@ const char* kDefaultDevice = "localhost";
1616

1717
// Transports that instantiated algorithms can be tested against.
1818
const std::vector<Transport> kTransportsForClassAlgorithms = {
19+
#if GLOO_HAVE_TRANSPORT_TCP
1920
Transport::TCP,
2021
Transport::TCP_LAZY,
22+
#endif
2123
#if GLOO_HAVE_TRANSPORT_TCP_TLS
2224
Transport::TCP_TLS,
2325
#endif
26+
#if GLOO_HAVE_TRANSPORT_IBVERBS
27+
Transport::IBVERBS,
28+
#endif
2429
};
2530

2631
// Transports that function algorithms can be tested against.
2732
// This is the new style of calling collectives and must be
2833
// preferred over the instantiated style.
2934
const std::vector<Transport> kTransportsForFunctionAlgorithms = {
35+
#if GLOO_HAVE_TRANSPORT_TCP
3036
Transport::TCP,
3137
Transport::TCP_LAZY,
38+
#endif
3239
#if GLOO_HAVE_TRANSPORT_TCP_TLS
3340
Transport::TCP_TLS,
3441
#endif
42+
#if GLOO_HAVE_TRANSPORT_UV
3543
Transport::UV,
44+
#endif
45+
#if GLOO_HAVE_TRANSPORT_IBVERBS
46+
Transport::IBVERBS,
47+
#endif
3648
};
3749

3850
std::shared_ptr<::gloo::transport::Device> createDevice(Transport transport) {
@@ -59,6 +71,13 @@ std::shared_ptr<::gloo::transport::Device> createDevice(Transport transport) {
5971
return ::gloo::transport::uv::CreateDevice(kDefaultDevice);
6072
#endif
6173
}
74+
#endif
75+
#if GLOO_HAVE_TRANSPORT_IBVERBS
76+
if (transport == Transport::IBVERBS) {
77+
gloo::transport::ibverbs::attr attr;
78+
attr.port = 1;
79+
return ::gloo::transport::ibverbs::CreateDevice(attr);
80+
}
6281
#endif
6382
return nullptr;
6483
}

gloo/test/base_test.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@
3333
#include "gloo/transport/uv/device.h"
3434
#endif
3535

36+
#if GLOO_HAVE_TRANSPORT_IBVERBS
37+
#include "gloo/transport/ibverbs/device.h"
38+
#endif
39+
3640
namespace gloo {
3741
namespace test {
3842

@@ -64,6 +68,7 @@ enum Transport {
6468
TCP_TLS,
6569
#endif
6670
UV,
71+
IBVERBS,
6772
};
6873

6974
extern const std::vector<Transport> kTransportsForClassAlgorithms;

gloo/test/send_recv_test.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
#include <array>
1313
#include <unordered_set>
1414

15-
#include "gloo/transport/tcp/unbound_buffer.h"
16-
1715
namespace gloo {
1816
namespace test {
1917
namespace {
@@ -515,7 +513,7 @@ INSTANTIATE_TEST_CASE_P(
515513
SendRecvDefault,
516514
SendRecvTest,
517515
::testing::Combine(
518-
::testing::Values(Transport::TCP, Transport::UV),
516+
::testing::ValuesIn(kTransportsForFunctionAlgorithms),
519517
::testing::Values(2, 3, 4, 5, 6, 7, 8),
520518
::testing::Values(1)));
521519

gloo/test/tcp_test.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include <gtest/gtest.h>
22

3+
#if GLOO_HAVE_TRANSPORT_TCP
4+
35
#include <gloo/transport/tcp/helpers.h>
46
#include <gloo/transport/tcp/loop.h>
57

@@ -34,3 +36,5 @@ TEST(TcpTest, ConnectTimeout) {
3436
} // namespace tcp
3537
} // namespace transport
3638
} // namespace gloo
39+
40+
#endif // GLOO_HAVE_TRANSPORT_TCP

gloo/transport/ibverbs/buffer.cc

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ Buffer::Buffer(Pair* pair, int slot, void* ptr, size_t size)
3030
ex_(nullptr) {
3131
mr_ = ibv_reg_mr(
3232
pair_->dev_->pd_,
33-
ptr_,
34-
size_,
33+
size == 0 ? static_cast<void*>(&emptyBuf_) : ptr,
34+
size == 0 ? sizeof(emptyBuf_) : size,
3535
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
3636

3737
// Provide hint if the error is EFAULT and nv_peer_mem is not loaded
@@ -59,8 +59,28 @@ Buffer::Buffer(Pair* pair, int slot, void* ptr, size_t size)
5959
}
6060

6161
Buffer::~Buffer() {
62-
GLOO_ENFORCE_EQ(sendPending_, 0, "Destructing buffer expecting completions");
63-
ibv_dereg_mr(mr_);
62+
int sendPending = 0;
63+
{
64+
std::unique_lock<std::mutex> lock(m_);
65+
sendPending = sendPending_;
66+
}
67+
68+
if (sendPending > 0) {
69+
GLOO_WARN(
70+
"Destructing buffer with pending sends, sendPending_=", sendPending_);
71+
try {
72+
for (int i = 0; i < sendPending; i++) {
73+
waitSend();
74+
}
75+
} catch (const std::exception& ex) {
76+
GLOO_WARN("Exception while waiting for send completion: ", ex.what());
77+
}
78+
}
79+
80+
{
81+
std::unique_lock<std::mutex> lock(m_);
82+
ibv_dereg_mr(mr_);
83+
}
6484
}
6585

6686
// Wait for a receive operation to finish.
@@ -167,36 +187,40 @@ void Buffer::waitSend() {
167187
}
168188

169189
void Buffer::send(size_t offset, size_t length, size_t roffset) {
170-
// Can't assert on roffset, since we don't know the size of
171-
// the remote buffer. Refactor of initialization code needed
172-
// to support this.
173-
GLOO_ENFORCE_LE(offset + length, size_);
174-
175190
{
176191
std::unique_lock<std::mutex> lock(m_);
192+
193+
// Can't assert on roffset, since we don't know the size of
194+
// the remote buffer. Refactor of initialization code needed
195+
// to support this.
196+
GLOO_ENFORCE_LE(offset + length, size_);
197+
177198
checkErrorState();
178-
}
179199

180-
if (debug_) {
181-
std::cout << "[" << getpid() << "] ";
182-
std::cout << "send " << length << " bytes";
183-
std::cout << std::endl;
200+
if (debug_) {
201+
std::cout << "[" << getpid() << "] ";
202+
std::cout << "send " << length << " bytes";
203+
std::cout << std::endl;
204+
}
205+
206+
// Increment number of sends in flight
207+
sendPending_++;
184208
}
185209

186-
// Increment number of sends in flight
187-
sendPending_++;
210+
// Release lock before calling into the pair to avoid deadlock.
188211

189212
pair_->send(this, offset, length, roffset);
190213
}
191214

192-
void Buffer::handleCompletion(struct ibv_wc* wc) {
215+
void Buffer::handleCompletion(int rank, struct ibv_wc* wc) {
216+
std::unique_lock<std::mutex> lock(m_);
217+
193218
if (wc->opcode & IBV_WC_RECV) {
194219
if (debug_) {
195220
std::cout << "[" << getpid() << "] ";
196221
std::cout << "recv " << wc->byte_len << " bytes";
197222
std::cout << std::endl;
198223
}
199-
std::unique_lock<std::mutex> lock(m_);
200224
recvCompletions_++;
201225
recvCv_.notify_one();
202226
} else if (wc->opcode == IBV_WC_RDMA_WRITE) {
@@ -205,7 +229,6 @@ void Buffer::handleCompletion(struct ibv_wc* wc) {
205229
std::cout << "send complete";
206230
std::cout << std::endl;
207231
}
208-
std::unique_lock<std::mutex> lock(m_);
209232
sendCompletions_++;
210233
sendPending_--;
211234
sendCv_.notify_one();

gloo/transport/ibverbs/buffer.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ namespace gloo {
2424
namespace transport {
2525
namespace ibverbs {
2626

27-
class Buffer : public ::gloo::transport::Buffer {
27+
class Buffer : public ::gloo::transport::Buffer, public BufferHandler {
2828
public:
2929
virtual ~Buffer();
3030

@@ -33,18 +33,26 @@ class Buffer : public ::gloo::transport::Buffer {
3333
virtual void waitRecv() override;
3434
virtual void waitSend() override;
3535

36-
void handleCompletion(struct ibv_wc* wc);
36+
void handleCompletion(int rank, struct ibv_wc* wc) override;
3737

38-
void signalError(const std::exception_ptr& ex);
38+
void signalError(const std::exception_ptr& ex) override;
3939
void checkErrorState();
4040

41+
bool isPeristentHandler() override {
42+
return true;
43+
}
44+
4145
protected:
4246
// May only be constructed from helper function in pair.cc
4347
Buffer(Pair* pair, int slot, void* ptr, size_t size);
4448

4549
Pair* pair_;
4650

51+
// Empty buffer to use when a nullptr buffer is created.
52+
char emptyBuf_[1];
53+
4754
struct ibv_mr* mr_;
55+
std::unique_ptr<struct ibv_mr> peerMr_;
4856

4957
std::mutex m_;
5058
std::condition_variable recvCv_;

gloo/transport/ibverbs/context.cc

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "gloo/common/error.h"
1212
#include "gloo/transport/ibverbs/device.h"
1313
#include "gloo/transport/ibverbs/pair.h"
14+
#include "gloo/transport/ibverbs/unbound_buffer.h"
1415

1516
namespace gloo {
1617
namespace transport {
@@ -23,16 +24,26 @@ Context::~Context() {}
2324

2425
std::unique_ptr<transport::Pair>& Context::createPair(int rank) {
2526
pairs_[rank] = std::unique_ptr<transport::Pair>(
26-
new ibverbs::Pair(device_, getTimeout()));
27+
new ibverbs::Pair(rank, device_, getTimeout()));
2728
return pairs_[rank];
2829
}
2930

3031
std::unique_ptr<transport::UnboundBuffer> Context::createUnboundBuffer(
3132
void* ptr,
3233
size_t size) {
33-
GLOO_THROW_INVALID_OPERATION_EXCEPTION(
34-
"Unbound buffers not supported yet for ibverbs transport");
35-
return std::unique_ptr<transport::UnboundBuffer>();
34+
return std::make_unique<UnboundBuffer>(this->shared_from_this(), ptr, size);
35+
}
36+
37+
void Context::signalException(const std::string& msg) {
38+
// The `pairs_` vector is logically constant. After the context and
39+
// all of its pairs have been created it is not mutated until the
40+
// context is destructed. Therefore, we don't need to acquire this
41+
// context's instance lock before looping over `pairs_`.
42+
for (auto& pair : pairs_) {
43+
if (pair) {
44+
reinterpret_cast<ibverbs::Pair*>(pair.get())->signalIoFailure(msg);
45+
}
46+
}
3647
}
3748

3849
} // namespace ibverbs

gloo/transport/ibverbs/context.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,16 @@ class Context : public ::gloo::transport::Context,
3333
void* ptr,
3434
size_t size) override;
3535

36+
// Set exception on every pair in this context. This is called when
37+
// waiting for a send or recv operation on an unbound buffer times
38+
// out. All pairs should be signaled and closed in that event.
39+
void signalException(const std::string& msg);
40+
3641
protected:
3742
std::shared_ptr<Device> device_;
3843

3944
friend class Pair;
45+
friend class UnboundBuffer;
4046
};
4147

4248
} // namespace ibverbs

0 commit comments

Comments
 (0)