Skip to content

Commit 6c58943

Browse files
d4l3kfacebook-github-bot
authored andcommitted
gloo: add connection retries (#413)
Summary: This adds connection retries to Gloo in order to try and mitigate issues related to connection timeouts. Reviewed By: c00w, fduwjj, XilunWu Differential Revision: D70345714
1 parent 5ca057d commit 6c58943

File tree

13 files changed

+300
-28
lines changed

13 files changed

+300
-28
lines changed

gloo/common/logging.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <climits>
1212
#include <exception>
1313
#include <functional>
14+
#include <iostream>
1415
#include <limits>
1516
#include <vector>
1617

@@ -156,4 +157,7 @@ BINARY_COMP_HELPER(LessEquals, <=)
156157
#define GLOO_ENFORCE_GT(x, y, ...) \
157158
GLOO_ENFORCE_THAT_IMPL(Greater((x), (y)), #x " > " #y, __VA_ARGS__)
158159

160+
#define GLOO_ERROR(...) \
161+
std::cerr << "Gloo error: " << ::gloo::MakeString(__VA_ARGS__) << std::endl
162+
159163
} // namespace gloo

gloo/common/utils.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,14 @@ bool isStoreExtendedApiEnabled() {
4242
(std::string(res) == "True" || std::string(res) == "1");
4343
}
4444

45+
bool disableConnectionRetries() {
46+
// use meyer singleton to only compute this exactly once.
47+
static bool disable = []() {
48+
const auto& res = std::getenv("GLOO_DISABLE_CONNECTION_RETRIES");
49+
return res != nullptr &&
50+
(std::string(res) == "True" || std::string(res) == "1");
51+
}();
52+
return disable;
53+
}
54+
4555
} // namespace gloo

gloo/common/utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,6 @@ bool useRankAsSeqNumber();
1818

1919
bool isStoreExtendedApiEnabled();
2020

21+
bool disableConnectionRetries();
22+
2123
} // namespace gloo

gloo/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ if(${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
2424
"${CMAKE_CURRENT_SOURCE_DIR}/linux_test.cc"
2525
"${CMAKE_CURRENT_SOURCE_DIR}/multiproc_test.cc"
2626
"${CMAKE_CURRENT_SOURCE_DIR}/transport_test.cc"
27+
"${CMAKE_CURRENT_SOURCE_DIR}/tcp_test.cc"
2728
)
2829
list(APPEND GLOO_TEST_LIBRARIES rt)
2930
endif()

gloo/test/tcp_test.cc

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#include <gtest/gtest.h>
2+
3+
#include <gloo/transport/tcp/helpers.h>
4+
#include <gloo/transport/tcp/loop.h>
5+
6+
namespace gloo {
7+
namespace transport {
8+
namespace tcp {
9+
10+
TEST(TcpTest, ConnectTimeout) {
11+
auto loop = std::make_shared<Loop>();
12+
13+
std::mutex m;
14+
std::condition_variable cv;
15+
bool done = false;
16+
17+
// Use bad address
18+
auto remote = Address("::1", 10);
19+
auto timeout = std::chrono::milliseconds(100);
20+
auto fn = [&](std::shared_ptr<Socket>, const Error& e) {
21+
std::lock_guard<std::mutex> lock(m);
22+
done = true;
23+
cv.notify_all();
24+
25+
EXPECT_TRUE(e);
26+
EXPECT_TRUE(dynamic_cast<const TimeoutError*>(&e));
27+
};
28+
connectLoop(loop, remote, timeout, std::move(fn));
29+
30+
std::unique_lock<std::mutex> lock(m);
31+
cv.wait(lock, [&] { return done; });
32+
}
33+
34+
} // namespace tcp
35+
} // namespace transport
36+
} // namespace gloo

gloo/transport/tcp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ else()
77
"${CMAKE_CURRENT_SOURCE_DIR}/context.cc"
88
"${CMAKE_CURRENT_SOURCE_DIR}/device.cc"
99
"${CMAKE_CURRENT_SOURCE_DIR}/error.cc"
10+
"${CMAKE_CURRENT_SOURCE_DIR}/helpers.cc"
1011
"${CMAKE_CURRENT_SOURCE_DIR}/listener.cc"
1112
"${CMAKE_CURRENT_SOURCE_DIR}/loop.cc"
1213
"${CMAKE_CURRENT_SOURCE_DIR}/pair.cc"

gloo/transport/tcp/address.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,29 @@ Address::Address(const struct sockaddr* addr, size_t addrlen) {
2828
memcpy(&impl_.ss, addr, addrlen);
2929
}
3030

31+
Address::Address(const std::string& ip, uint16_t port, sequence_number_t seq) {
32+
if (ip.empty()) {
33+
throw std::invalid_argument("Invalid IP address");
34+
}
35+
sockaddr_in* addr4 = reinterpret_cast<sockaddr_in*>(&impl_.ss);
36+
sockaddr_in6* addr6 = reinterpret_cast<sockaddr_in6*>(&impl_.ss);
37+
// Check if the IP address is an IPv4 or IPv6 address
38+
if (inet_pton(AF_INET, ip.c_str(), &addr4->sin_addr) == 1) {
39+
// IPv4 address
40+
addr4->sin_family = AF_INET;
41+
addr4->sin_port = htons(port);
42+
} else if (inet_pton(AF_INET6, ip.c_str(), &addr6->sin6_addr) == 1) {
43+
// IPv6 address
44+
addr6->sin6_family = AF_INET6;
45+
addr6->sin6_port = htons(port);
46+
} else {
47+
throw std::invalid_argument("Invalid IP address");
48+
}
49+
50+
// Store sequence number
51+
impl_.seq = seq;
52+
}
53+
3154
Address& Address::operator=(Address&& other) {
3255
std::lock_guard<std::mutex> lock(m_);
3356
impl_.ss = std::move(other.impl_.ss);

gloo/transport/tcp/address.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@
88

99
#pragma once
1010

11-
#include <sys/socket.h>
12-
#include <unistd.h>
1311
#include <mutex>
1412

13+
#ifdef _WIN32
14+
#include "gloo/common/win.h" // @manual
15+
#else
16+
#include <sys/socket.h>
17+
#endif
18+
1519
#include "gloo/transport/address.h"
1620

1721
namespace gloo {
@@ -32,6 +36,11 @@ class Address : public ::gloo::transport::Address {
3236

3337
explicit Address(const std::vector<char>&);
3438

39+
explicit Address(
40+
const std::string& ip,
41+
uint16_t port,
42+
sequence_number_t seq = -1);
43+
3544
Address(const Address& other);
3645

3746
Address& operator=(Address&& other);

gloo/transport/tcp/device.cc

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "gloo/common/error.h"
1818
#include "gloo/common/linux.h"
1919
#include "gloo/common/logging.h"
20+
#include "gloo/common/utils.h"
2021
#include "gloo/transport/tcp/context.h"
2122
#include "gloo/transport/tcp/helpers.h"
2223
#include "gloo/transport/tcp/pair.h"
@@ -334,20 +335,39 @@ void Device::connectAsListener(
334335
//
335336
void Device::connectAsInitiator(
336337
const Address& remote,
337-
std::chrono::milliseconds /* unused */,
338+
std::chrono::milliseconds timeout,
338339
connect_callback_t fn) {
339-
const auto& sockaddr = remote.getSockaddr();
340-
341-
// Create new socket to connect to peer.
342-
auto socket = Socket::createForFamily(sockaddr.ss_family);
343-
socket->reuseAddr(true);
344-
socket->noDelay(true);
345-
socket->connect(sockaddr);
346-
347-
// Write sequence number for peer to new socket.
348-
// TODO(pietern): Use timeout.
349-
write<sequence_number_t>(
350-
loop_, std::move(socket), remote.getSeq(), std::move(fn));
340+
auto writeSeq = [loop = loop_, seq = remote.getSeq()](
341+
std::shared_ptr<Socket> socket, connect_callback_t fn) {
342+
// Write sequence number for peer to new socket.
343+
write<sequence_number_t>(loop, std::move(socket), seq, std::move(fn));
344+
};
345+
346+
if (disableConnectionRetries()) {
347+
const auto& sockaddr = remote.getSockaddr();
348+
349+
// Create new socket to connect to peer.
350+
auto socket = Socket::createForFamily(sockaddr.ss_family);
351+
socket->reuseAddr(true);
352+
socket->noDelay(true);
353+
socket->connect(sockaddr);
354+
355+
writeSeq(std::move(socket), std::move(fn));
356+
} else {
357+
connectLoop(
358+
loop_,
359+
remote,
360+
timeout,
361+
[loop = loop_, fn = std::move(fn), writeSeq = std::move(writeSeq)](
362+
std::shared_ptr<Socket> socket, const Error& error) {
363+
if (error) {
364+
fn(socket, error);
365+
return;
366+
}
367+
368+
writeSeq(std::move(socket), std::move(fn));
369+
});
370+
}
351371
}
352372

353373
} // namespace tcp

gloo/transport/tcp/error.cc

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,24 +23,32 @@ std::string Error::what() const {
2323

2424
std::string SystemError::what() const {
2525
std::ostringstream ss;
26-
ss << syscall_ << ": " << strerror(error_);
26+
ss << syscall_ << ": " << strerror(error_) << ", remote=" << remote_.str();
2727
return ss.str();
2828
}
2929

3030
std::string ShortReadError::what() const {
3131
std::ostringstream ss;
3232
ss << "short read: got " << actual_ << " bytes while expecting to read "
33-
<< expected_ << " bytes";
33+
<< expected_ << " bytes, remote=" << remote_.str();
3434
return ss.str();
3535
}
3636

3737
std::string ShortWriteError::what() const {
3838
std::ostringstream ss;
3939
ss << "short write: wrote " << actual_ << " bytes while expecting to write "
40-
<< expected_ << " bytes";
40+
<< expected_ << " bytes, remote=" << remote_.str();
4141
return ss.str();
4242
}
4343

44+
std::string TimeoutError::what() const {
45+
return msg_;
46+
}
47+
48+
std::string LoopError::what() const {
49+
return msg_;
50+
}
51+
4452
} // namespace tcp
4553
} // namespace transport
4654
} // namespace gloo

gloo/transport/tcp/error.h

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#pragma once
1010

11+
#include <gloo/transport/tcp/address.h>
1112
#include <string>
1213

1314
namespace gloo {
@@ -52,38 +53,70 @@ class Error {
5253

5354
class SystemError : public Error {
5455
public:
55-
explicit SystemError(const char* syscall, int error)
56-
: Error(true), syscall_(syscall), error_(error) {}
56+
explicit SystemError(const char* syscall, int error, Address remote)
57+
: Error(true),
58+
syscall_(syscall),
59+
error_(error),
60+
remote_(std::move(remote)) {}
5761

5862
std::string what() const override;
5963

6064
private:
6165
const char* syscall_;
6266
const int error_;
67+
const Address remote_;
6368
};
6469

6570
class ShortReadError : public Error {
6671
public:
67-
ShortReadError(ssize_t expected, ssize_t actual)
68-
: Error(true), expected_(expected), actual_(actual) {}
72+
ShortReadError(ssize_t expected, ssize_t actual, Address remote)
73+
: Error(true),
74+
expected_(expected),
75+
actual_(actual),
76+
remote_(std::move(remote)) {}
6977

7078
std::string what() const override;
7179

7280
private:
7381
const ssize_t expected_;
7482
const ssize_t actual_;
83+
const Address remote_;
7584
};
7685

7786
class ShortWriteError : public Error {
7887
public:
79-
ShortWriteError(ssize_t expected, ssize_t actual)
80-
: Error(true), expected_(expected), actual_(actual) {}
88+
ShortWriteError(ssize_t expected, ssize_t actual, Address remote)
89+
: Error(true),
90+
expected_(expected),
91+
actual_(actual),
92+
remote_(std::move(remote)) {}
8193

8294
std::string what() const override;
8395

8496
private:
8597
const ssize_t expected_;
8698
const ssize_t actual_;
99+
const Address remote_;
100+
};
101+
102+
class TimeoutError : public Error {
103+
public:
104+
explicit TimeoutError(std::string msg) : Error(true), msg_(std::move(msg)) {}
105+
106+
std::string what() const override;
107+
108+
private:
109+
const std::string msg_;
110+
};
111+
112+
class LoopError : public Error {
113+
public:
114+
explicit LoopError(std::string msg) : Error(true), msg_(std::move(msg)) {}
115+
116+
std::string what() const override;
117+
118+
private:
119+
const std::string msg_;
87120
};
88121

89122
} // namespace tcp

gloo/transport/tcp/helpers.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#include <gloo/transport/tcp/helpers.h>
2+
3+
namespace gloo {
4+
namespace transport {
5+
namespace tcp {
6+
7+
void connectLoop(
8+
std::shared_ptr<Loop> loop,
9+
const Address& remote,
10+
std::chrono::milliseconds timeout,
11+
typename ConnectOperation::callback_t fn) {
12+
auto x = std::make_shared<ConnectOperation>(
13+
std::move(loop), remote, timeout, std::move(fn));
14+
x->run();
15+
}
16+
17+
} // namespace tcp
18+
} // namespace transport
19+
} // namespace gloo

0 commit comments

Comments
 (0)