Skip to content

Commit 685dbed

Browse files
d4l3kfacebook-github-bot
authored and committed
c10d/gloo: add ibverbs backend (#437)
Summary: X-link: pytorch/pytorch#153015 This provides a new "UnboundBuffer" implementation for the Gloo ibverbs backend so it can be used with PyTorch. It currently passes basic tests such as `reduce_test` and `send_recv_test`, but there are a number of failures. Putting this up for review so the follow-up fixes are less of a mega PR, and also so we can start doing some initial E2E testing of this with PyTorch. Known issues: * recv from any is not supported * AllreduceBcubeBase2 is failing Reviewed By: fduwjj Differential Revision: D73291471
1 parent 4ecd9ce commit 685dbed

17 files changed

+744
-114
lines changed

gloo/allreduce_ring_chunked.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,14 +145,19 @@ class AllreduceRingChunked : public Algorithm {
145145
fn_->call(&ptrs_[0][offset], inbox_[chunkOffset & 1], length);
146146
}
147147

148+
GLOO_DEBUG("START sendNotification round=", round);
149+
148150
// Send notification to node on the left that
149151
// this node is ready for an inbox write.
150152
sendNotificationBuf_->send();
151153

154+
GLOO_DEBUG("START recvNotification round=", round);
152155
// Wait for notification from node on the right
153156
// to be sure this node can start an inbox write.
154157
recvNotificationBuf_->waitRecv();
155158

159+
GLOO_DEBUG("DONE recvNotification round=", round);
160+
156161
// Copy accumulated chunk
157162
copyChunkAtOffset(chunkOffset);
158163
}

gloo/common/logging.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,20 @@
1515
#include <limits>
1616
#include <vector>
1717

18+
#include "gloo/common/error.h"
1819
#include "gloo/common/string.h"
1920

2021
namespace gloo {
2122

23+
#define GLOO_LOG_MSG(level, ...) \
24+
std::cerr << ::gloo::MakeString( \
25+
"[", __FILE__, ":", __LINE__, "] ", level, " ", __VA_ARGS__, "\n")
26+
27+
#define GLOO_INFO(...) GLOO_LOG_MSG("INFO", __VA_ARGS__)
28+
#define GLOO_ERROR(...) GLOO_LOG_MSG("ERROR", __VA_ARGS__)
29+
#define GLOO_WARN(...) GLOO_LOG_MSG("WARN", __VA_ARGS__)
30+
#define GLOO_DEBUG(...) // GLOO_LOG_MSG("DEBUG", __VA_ARGS__)
31+
2232
class EnforceNotMet : public std::exception {
2333
public:
2434
EnforceNotMet(
@@ -157,7 +167,4 @@ BINARY_COMP_HELPER(LessEquals, <=)
157167
#define GLOO_ENFORCE_GT(x, y, ...) \
158168
GLOO_ENFORCE_THAT_IMPL(Greater((x), (y)), #x " > " #y, __VA_ARGS__)
159169

160-
#define GLOO_ERROR(...) \
161-
std::cerr << "Gloo error: " << ::gloo::MakeString(__VA_ARGS__) << std::endl
162-
163170
} // namespace gloo

gloo/test/base_test.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,35 @@ const char* kDefaultDevice = "localhost";
1616

1717
// Transports that instantiated algorithms can be tested against.
1818
const std::vector<Transport> kTransportsForClassAlgorithms = {
19+
#if GLOO_HAVE_TRANSPORT_TCP
1920
Transport::TCP,
2021
Transport::TCP_LAZY,
22+
#endif
2123
#if GLOO_HAVE_TRANSPORT_TCP_TLS
2224
Transport::TCP_TLS,
2325
#endif
26+
#if GLOO_HAVE_TRANSPORT_IBVERBS
27+
Transport::IBVERBS,
28+
#endif
2429
};
2530

2631
// Transports that function algorithms can be tested against.
2732
// This is the new style of calling collectives and must be
2833
// preferred over the instantiated style.
2934
const std::vector<Transport> kTransportsForFunctionAlgorithms = {
35+
#if GLOO_HAVE_TRANSPORT_TCP
3036
Transport::TCP,
3137
Transport::TCP_LAZY,
38+
#endif
3239
#if GLOO_HAVE_TRANSPORT_TCP_TLS
3340
Transport::TCP_TLS,
3441
#endif
42+
#if GLOO_HAVE_TRANSPORT_UV
3543
Transport::UV,
44+
#endif
45+
#if GLOO_HAVE_TRANSPORT_IBVERBS
46+
Transport::IBVERBS,
47+
#endif
3648
};
3749

3850
std::shared_ptr<::gloo::transport::Device> createDevice(Transport transport) {
@@ -59,6 +71,17 @@ std::shared_ptr<::gloo::transport::Device> createDevice(Transport transport) {
5971
return ::gloo::transport::uv::CreateDevice(kDefaultDevice);
6072
#endif
6173
}
74+
#endif
75+
#if GLOO_HAVE_TRANSPORT_IBVERBS
76+
if (transport == Transport::IBVERBS) {
77+
gloo::transport::ibverbs::attr attr;
78+
attr.port = 1;
79+
try {
80+
return ::gloo::transport::ibverbs::CreateDevice(attr);
81+
} catch (const InvalidOperationException& e) {
82+
GLOO_INFO("IBVERBS not available: ", e.what());
83+
}
84+
}
6285
#endif
6386
return nullptr;
6487
}

gloo/test/base_test.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@
3333
#include "gloo/transport/uv/device.h"
3434
#endif
3535

36+
#if GLOO_HAVE_TRANSPORT_IBVERBS
37+
#include "gloo/transport/ibverbs/device.h"
38+
#endif
39+
3640
namespace gloo {
3741
namespace test {
3842

@@ -64,6 +68,7 @@ enum Transport {
6468
TCP_TLS,
6569
#endif
6670
UV,
71+
IBVERBS,
6772
};
6873

6974
extern const std::vector<Transport> kTransportsForClassAlgorithms;
@@ -117,6 +122,7 @@ class BaseTest : public ::testing::Test {
117122
// socket address.
118123
auto device = device_creator(transport);
119124
if (!device) {
125+
GTEST_SKIP() << "Skipping test: transport not available";
120126
return;
121127
}
122128
context->connectFullMesh(store, device);

gloo/test/send_recv_test.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
#include <array>
1313
#include <unordered_set>
1414

15-
#include "gloo/transport/tcp/unbound_buffer.h"
16-
1715
namespace gloo {
1816
namespace test {
1917
namespace {
@@ -515,7 +513,7 @@ INSTANTIATE_TEST_CASE_P(
515513
SendRecvDefault,
516514
SendRecvTest,
517515
::testing::Combine(
518-
::testing::Values(Transport::TCP, Transport::UV),
516+
::testing::ValuesIn(kTransportsForFunctionAlgorithms),
519517
::testing::Values(2, 3, 4, 5, 6, 7, 8),
520518
::testing::Values(1)));
521519

gloo/test/tcp_test.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include <gtest/gtest.h>
22

3+
#if GLOO_HAVE_TRANSPORT_TCP
4+
35
#include <gloo/transport/tcp/helpers.h>
46
#include <gloo/transport/tcp/loop.h>
57

@@ -34,3 +36,5 @@ TEST(TcpTest, ConnectTimeout) {
3436
} // namespace tcp
3537
} // namespace transport
3638
} // namespace gloo
39+
40+
#endif // GLOO_HAVE_TRANSPORT_TCP

gloo/transport/ibverbs/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ list(APPEND GLOO_TRANSPORT_SRCS
55
"${CMAKE_CURRENT_SOURCE_DIR}/device.cc"
66
"${CMAKE_CURRENT_SOURCE_DIR}/memory_region.cc"
77
"${CMAKE_CURRENT_SOURCE_DIR}/pair.cc"
8+
"${CMAKE_CURRENT_SOURCE_DIR}/unbound_buffer.cc"
89
)
910

1011
list(APPEND GLOO_TRANSPORT_HDRS
@@ -14,6 +15,7 @@ list(APPEND GLOO_TRANSPORT_HDRS
1415
"${CMAKE_CURRENT_SOURCE_DIR}/device.h"
1516
"${CMAKE_CURRENT_SOURCE_DIR}/memory_region.h"
1617
"${CMAKE_CURRENT_SOURCE_DIR}/pair.h"
18+
"${CMAKE_CURRENT_SOURCE_DIR}/unbound_buffer.h"
1719
)
1820

1921
set(GLOO_TRANSPORT_SRCS ${GLOO_TRANSPORT_SRCS} PARENT_SCOPE)

gloo/transport/ibverbs/buffer.cc

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ Buffer::Buffer(Pair* pair, int slot, void* ptr, size_t size)
3030
ex_(nullptr) {
3131
mr_ = ibv_reg_mr(
3232
pair_->dev_->pd_,
33-
ptr_,
34-
size_,
33+
size == 0 ? static_cast<void*>(&emptyBuf_) : ptr,
34+
size == 0 ? sizeof(emptyBuf_) : size,
3535
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
3636

3737
// Provide hint if the error is EFAULT and nv_peer_mem is not loaded
@@ -59,7 +59,12 @@ Buffer::Buffer(Pair* pair, int slot, void* ptr, size_t size)
5959
}
6060

6161
Buffer::~Buffer() {
62-
GLOO_ENFORCE_EQ(sendPending_, 0, "Destructing buffer expecting completions");
62+
std::lock_guard<std::mutex> lock(m_);
63+
if (sendPending_ > 0) {
64+
GLOO_WARN(
65+
"Destructing buffer with pending sends, sendPending_=", sendPending_);
66+
}
67+
6368
ibv_dereg_mr(mr_);
6469
}
6570

@@ -167,36 +172,40 @@ void Buffer::waitSend() {
167172
}
168173

169174
void Buffer::send(size_t offset, size_t length, size_t roffset) {
170-
// Can't assert on roffset, since we don't know the size of
171-
// the remote buffer. Refactor of initialization code needed
172-
// to support this.
173-
GLOO_ENFORCE_LE(offset + length, size_);
174-
175175
{
176176
std::unique_lock<std::mutex> lock(m_);
177+
178+
// Can't assert on roffset, since we don't know the size of
179+
// the remote buffer. Refactor of initialization code needed
180+
// to support this.
181+
GLOO_ENFORCE_LE(offset + length, size_);
182+
177183
checkErrorState();
178-
}
179184

180-
if (debug_) {
181-
std::cout << "[" << getpid() << "] ";
182-
std::cout << "send " << length << " bytes";
183-
std::cout << std::endl;
185+
if (debug_) {
186+
std::cout << "[" << getpid() << "] ";
187+
std::cout << "send " << length << " bytes";
188+
std::cout << std::endl;
189+
}
190+
191+
// Increment number of sends in flight
192+
sendPending_++;
184193
}
185194

186-
// Increment number of sends in flight
187-
sendPending_++;
195+
// Release lock before calling into the pair to avoid deadlock.
188196

189197
pair_->send(this, offset, length, roffset);
190198
}
191199

192-
void Buffer::handleCompletion(struct ibv_wc* wc) {
200+
void Buffer::handleCompletion(int rank, struct ibv_wc* wc) {
201+
std::unique_lock<std::mutex> lock(m_);
202+
193203
if (wc->opcode & IBV_WC_RECV) {
194204
if (debug_) {
195205
std::cout << "[" << getpid() << "] ";
196206
std::cout << "recv " << wc->byte_len << " bytes";
197207
std::cout << std::endl;
198208
}
199-
std::unique_lock<std::mutex> lock(m_);
200209
recvCompletions_++;
201210
recvCv_.notify_one();
202211
} else if (wc->opcode == IBV_WC_RDMA_WRITE) {
@@ -205,7 +214,6 @@ void Buffer::handleCompletion(struct ibv_wc* wc) {
205214
std::cout << "send complete";
206215
std::cout << std::endl;
207216
}
208-
std::unique_lock<std::mutex> lock(m_);
209217
sendCompletions_++;
210218
sendPending_--;
211219
sendCv_.notify_one();

gloo/transport/ibverbs/buffer.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ namespace gloo {
2424
namespace transport {
2525
namespace ibverbs {
2626

27-
class Buffer : public ::gloo::transport::Buffer {
27+
class Buffer : public ::gloo::transport::Buffer, public BufferHandler {
2828
public:
2929
virtual ~Buffer();
3030

@@ -33,18 +33,26 @@ class Buffer : public ::gloo::transport::Buffer {
3333
virtual void waitRecv() override;
3434
virtual void waitSend() override;
3535

36-
void handleCompletion(struct ibv_wc* wc);
36+
void handleCompletion(int rank, struct ibv_wc* wc) override;
3737

38-
void signalError(const std::exception_ptr& ex);
38+
void signalError(const std::exception_ptr& ex) override;
3939
void checkErrorState();
4040

41+
bool isPeristentHandler() override {
42+
return true;
43+
}
44+
4145
protected:
4246
// May only be constructed from helper function in pair.cc
4347
Buffer(Pair* pair, int slot, void* ptr, size_t size);
4448

4549
Pair* pair_;
4650

51+
// Empty buffer to use when a nullptr buffer is created.
52+
char emptyBuf_[1];
53+
4754
struct ibv_mr* mr_;
55+
std::unique_ptr<struct ibv_mr> peerMr_;
4856

4957
std::mutex m_;
5058
std::condition_variable recvCv_;

gloo/transport/ibverbs/context.cc

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "gloo/common/error.h"
1212
#include "gloo/transport/ibverbs/device.h"
1313
#include "gloo/transport/ibverbs/pair.h"
14+
#include "gloo/transport/ibverbs/unbound_buffer.h"
1415

1516
namespace gloo {
1617
namespace transport {
@@ -23,16 +24,26 @@ Context::~Context() {}
2324

2425
std::unique_ptr<transport::Pair>& Context::createPair(int rank) {
2526
pairs_[rank] = std::unique_ptr<transport::Pair>(
26-
new ibverbs::Pair(device_, getTimeout()));
27+
new ibverbs::Pair(rank, device_, getTimeout()));
2728
return pairs_[rank];
2829
}
2930

3031
std::unique_ptr<transport::UnboundBuffer> Context::createUnboundBuffer(
3132
void* ptr,
3233
size_t size) {
33-
GLOO_THROW_INVALID_OPERATION_EXCEPTION(
34-
"Unbound buffers not supported yet for ibverbs transport");
35-
return std::unique_ptr<transport::UnboundBuffer>();
34+
return std::make_unique<UnboundBuffer>(this->shared_from_this(), ptr, size);
35+
}
36+
37+
void Context::signalException(const std::string& msg) {
38+
// The `pairs_` vector is logically constant. After the context and
39+
// all of its pairs have been created it is not mutated until the
40+
// context is destructed. Therefore, we don't need to acquire this
41+
// context's instance lock before looping over `pairs_`.
42+
for (auto& pair : pairs_) {
43+
if (pair) {
44+
reinterpret_cast<ibverbs::Pair*>(pair.get())->signalIoFailure(msg);
45+
}
46+
}
3647
}
3748

3849
} // namespace ibverbs

gloo/transport/ibverbs/context.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,16 @@ class Context : public ::gloo::transport::Context,
3333
void* ptr,
3434
size_t size) override;
3535

36+
// Set exception on every pair in this context. This is called when
37+
// waiting for a send or recv operation on an unbound buffer times
38+
// out. All pairs should be signaled and closed in that event.
39+
void signalException(const std::string& msg);
40+
3641
protected:
3742
std::shared_ptr<Device> device_;
3843

3944
friend class Pair;
45+
friend class UnboundBuffer;
4046
};
4147

4248
} // namespace ibverbs

0 commit comments

Comments
 (0)