Skip to content

Commit 522dd5c

Browse files
committed
Perform platform-specific initialization in socket service
Mainly, Windows needs WSAStartup before creating sockets.
1 parent 951d318 commit 522dd5c

10 files changed

+71
-47
lines changed

fly/net/socket/detail/nix/socket_operations.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,16 @@ namespace {
8989

9090
} // namespace
9191

92+
//==================================================================================================
93+
void initialize()
94+
{
95+
}
96+
97+
//==================================================================================================
98+
void deinitialize()
99+
{
100+
}
101+
92102
//==================================================================================================
93103
fly::net::socket_type invalid_socket()
94104
{

fly/net/socket/detail/socket_operations.hpp

+10
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,16 @@
99

1010
namespace fly::net::detail {
1111

12+
/**
13+
* Perform any platform-specific actions needed to initialize network services.
14+
*/
15+
void initialize();
16+
17+
/**
18+
* Perform any platform-specific actions needed to deinitialize network services.
19+
*/
20+
void deinitialize();
21+
1222
/**
1323
* @return Invalid socket handle for the target system.
1424
*/

fly/net/socket/detail/win/socket_operations.cpp

+27
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <ws2ipdef.h>
1313
// clang-format on
1414

15+
#include <atomic>
1516
#include <limits>
1617
#include <type_traits>
1718

@@ -87,8 +88,34 @@ namespace {
8788
return reinterpret_cast<const sockaddr *>(&address);
8889
}
8990

91+
std::atomic_uint64_t s_initialized_services_count {0};
92+
9093
} // namespace
9194

95+
//==================================================================================================
96+
void initialize()
97+
{
98+
if (s_initialized_services_count.fetch_add(1) == 0)
99+
{
100+
WORD version = MAKEWORD(2, 2);
101+
WSADATA wsadata;
102+
103+
if (WSAStartup(version, &wsadata) != 0)
104+
{
105+
deinitialize();
106+
}
107+
}
108+
}
109+
110+
//==================================================================================================
111+
void deinitialize()
112+
{
113+
if (s_initialized_services_count.fetch_sub(1) == 1)
114+
{
115+
WSACleanup();
116+
}
117+
}
118+
92119
//==================================================================================================
93120
fly::net::socket_type invalid_socket()
94121
{

fly/net/socket/socket_service.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ SocketService::SocketService(const std::shared_ptr<fly::SequencedTaskRunner> &ta
3030
:
3131
m_task_runner(task_runner)
3232
{
33+
fly::net::detail::initialize();
34+
}
35+
36+
//==================================================================================================
37+
SocketService::~SocketService() noexcept
38+
{
39+
fly::net::detail::deinitialize();
3340
}
3441

3542
//==================================================================================================

fly/net/socket/socket_service.hpp

+5
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ namespace fly::net {
2222
class SocketService : public std::enable_shared_from_this<SocketService>
2323
{
2424
public:
25+
/**
26+
* Destructor. Deinitialize the socket service.
27+
*/
28+
~SocketService() noexcept;
29+
2530
/**
2631
* Create a socket service.
2732
*

test/net/listen_socket.cpp

+1-7
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@ CATCH_TEMPLATE_TEST_CASE("ListenSocket", "[net]", fly::net::IPv4Address, fly::ne
3333
using EndpointType = fly::net::Endpoint<IPAddressType>;
3434
using ListenSocket = fly::net::ListenSocket<EndpointType>;
3535

36-
#if defined(FLY_WINDOWS)
37-
fly::test::ScopedWindowsSocketAPI::create();
38-
#endif
36+
fly::test::ScopedSocketServiceSetup::create();
3937

4038
constexpr const auto in_addr_any = IPAddressType::in_addr_any();
4139
constexpr const auto in_addr_loopback = IPAddressType::in_addr_loopback();
@@ -299,10 +297,6 @@ CATCH_TEMPLATE_TEST_CASE("AsyncListenSocket", "[net]", fly::net::IPv4Address, fl
299297
using ListenSocket = fly::net::ListenSocket<EndpointType>;
300298
using TcpSocket = fly::net::TcpSocket<EndpointType>;
301299

302-
#if defined(FLY_WINDOWS)
303-
fly::test::ScopedWindowsSocketAPI::create();
304-
#endif
305-
306300
auto task_runner = fly::test::task_manager()->create_task_runner<fly::SequencedTaskRunner>();
307301
auto socket_service = fly::net::SocketService::create(task_runner);
308302
fly::test::Signal signal;

test/net/socket_service.cpp

-4
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,6 @@ CATCH_TEMPLATE_TEST_CASE("SocketService", "[net]", fly::net::IPv4Address, fly::n
3535
using EndpointType = fly::net::Endpoint<IPAddressType>;
3636
using UdpSocket = fly::net::UdpSocket<EndpointType>;
3737

38-
#if defined(FLY_WINDOWS)
39-
fly::test::ScopedWindowsSocketAPI::create();
40-
#endif
41-
4238
auto task_runner = fly::test::task_manager()->create_task_runner<fly::SequencedTaskRunner>();
4339
auto socket_service = fly::net::SocketService::create(task_runner);
4440
fly::test::Signal signal;

test/net/socket_util.hpp

+9-22
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include "fly/fly.hpp"
4+
#include "fly/net/socket/detail/socket_operations.hpp"
45
#include "fly/net/socket/socket_types.hpp"
56
#include "fly/types/concurrency/concurrent_queue.hpp"
67

@@ -9,10 +10,6 @@
910
#include <chrono>
1011
#include <future>
1112

12-
#if defined(FLY_WINDOWS)
13-
# include <Windows.h>
14-
#endif
15-
1613
namespace fly::test {
1714

1815
/**
@@ -68,39 +65,29 @@ class Signal
6865
fly::ConcurrentQueue<int> m_signal;
6966
};
7067

71-
#if defined(FLY_WINDOWS)
72-
7368
/**
74-
* On Windows, WSAStartup must be invoked before any sockets may be created. Until a new socket
75-
* service is created in fly::net to ensure that, this class may be used by unit tests.
69+
* Perform platform-specific socket service initialization for tests that do not need to use the
70+
* socket service itself.
7671
*/
77-
class ScopedWindowsSocketAPI
72+
class ScopedSocketServiceSetup
7873
{
7974
public:
8075
static inline void create()
8176
{
82-
static ScopedWindowsSocketAPI s_instance;
77+
static ScopedSocketServiceSetup s_instance;
8378
FLY_UNUSED(s_instance);
8479
}
8580

8681
private:
87-
inline ScopedWindowsSocketAPI()
82+
inline ScopedSocketServiceSetup()
8883
{
89-
WORD version = MAKEWORD(2, 2);
90-
WSADATA wsadata;
91-
92-
if (WSAStartup(version, &wsadata) != 0)
93-
{
94-
WSACleanup();
95-
}
84+
fly::net::detail::initialize();
9685
}
9786

98-
inline ~ScopedWindowsSocketAPI()
87+
inline ~ScopedSocketServiceSetup()
9988
{
100-
WSACleanup();
89+
fly::net::detail::deinitialize();
10190
}
10291
};
10392

104-
#endif
105-
10693
} // namespace fly::test

test/net/tcp_socket.cpp

+1-7
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,7 @@ CATCH_TEMPLATE_TEST_CASE("TcpSocket", "[net]", fly::net::IPv4Address, fly::net::
3737
using ListenSocket = fly::net::ListenSocket<EndpointType>;
3838
using TcpSocket = fly::net::TcpSocket<EndpointType>;
3939

40-
#if defined(FLY_WINDOWS)
41-
fly::test::ScopedWindowsSocketAPI::create();
42-
#endif
40+
fly::test::ScopedSocketServiceSetup::create();
4341

4442
const std::string message(fly::String::generate_random_string(1 << 10));
4543
constexpr const auto in_addr_loopback = IPAddressType::in_addr_loopback();
@@ -323,10 +321,6 @@ CATCH_TEMPLATE_TEST_CASE("AsyncTcpSocket", "[net]", fly::net::IPv4Address, fly::
323321
using ListenSocket = fly::net::ListenSocket<EndpointType>;
324322
using TcpSocket = fly::net::TcpSocket<EndpointType>;
325323

326-
#if defined(FLY_WINDOWS)
327-
fly::test::ScopedWindowsSocketAPI::create();
328-
#endif
329-
330324
auto task_runner = fly::test::task_manager()->create_task_runner<fly::SequencedTaskRunner>();
331325
auto socket_service = fly::net::SocketService::create(task_runner);
332326

test/net/udp_socket.cpp

+1-7
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@ CATCH_TEMPLATE_TEST_CASE("UdpSocket", "[net]", fly::net::IPv4Address, fly::net::
3535
using EndpointType = fly::net::Endpoint<IPAddressType>;
3636
using UdpSocket = fly::net::UdpSocket<EndpointType>;
3737

38-
#if defined(FLY_WINDOWS)
39-
fly::test::ScopedWindowsSocketAPI::create();
40-
#endif
38+
fly::test::ScopedSocketServiceSetup::create();
4139

4240
const std::string message(fly::String::generate_random_string(1 << 10));
4341
constexpr const auto in_addr_any = IPAddressType::in_addr_any();
@@ -259,10 +257,6 @@ CATCH_TEMPLATE_TEST_CASE("AsyncUdpSocket", "[net]", fly::net::IPv4Address, fly::
259257
using EndpointType = fly::net::Endpoint<IPAddressType>;
260258
using UdpSocket = fly::net::UdpSocket<EndpointType>;
261259

262-
#if defined(FLY_WINDOWS)
263-
fly::test::ScopedWindowsSocketAPI::create();
264-
#endif
265-
266260
auto task_runner = fly::test::task_manager()->create_task_runner<fly::SequencedTaskRunner>();
267261
auto socket_service = fly::net::SocketService::create(task_runner);
268262

0 commit comments

Comments
 (0)