|
| 1 | +// Copyright (c) 2023-present The Bitcoin Core developers |
| 2 | +// Distributed under the MIT software license, see the accompanying |
| 3 | +// file COPYING or http://www.opensource.org/licenses/mit-license.php. |
| 4 | + |
| 5 | +#include <common/sv2_connman.h> |
| 6 | +#include <common/sv2_messages.h> |
| 7 | +#include <logging.h> |
| 8 | +#include <sync.h> |
| 9 | +#include <util/thread.h> |
| 10 | + |
| 11 | +using node::Sv2MsgType; |
| 12 | + |
| 13 | +Sv2Connman::~Sv2Connman() |
| 14 | +{ |
| 15 | + AssertLockNotHeld(m_clients_mutex); |
| 16 | + |
| 17 | + { |
| 18 | + LOCK(m_clients_mutex); |
| 19 | + for (const auto& client : m_sv2_clients) { |
| 20 | + LogTrace(BCLog::SV2, "Disconnecting client id=%zu\n", |
| 21 | + client->m_id); |
| 22 | + client->m_disconnect_flag = true; |
| 23 | + } |
| 24 | + DisconnectFlagged(); |
| 25 | + } |
| 26 | + |
| 27 | + Interrupt(); |
| 28 | + StopThreads(); |
| 29 | +} |
| 30 | + |
| 31 | +bool Sv2Connman::Start(Sv2EventsInterface* msgproc, std::string host, uint16_t port) |
| 32 | +{ |
| 33 | + m_msgproc = msgproc; |
| 34 | + |
| 35 | + try { |
| 36 | + auto sock = BindListenPort(host, port); |
| 37 | + m_listening_socket = std::move(sock); |
| 38 | + } catch (const std::runtime_error& e) { |
| 39 | + LogPrintLevel(BCLog::SV2, BCLog::Level::Error, "Template Provider failed to bind to port %d: %s\n", port, e.what()); |
| 40 | + return false; |
| 41 | + } |
| 42 | + |
| 43 | + m_thread_sv2_handler = std::thread(&util::TraceThread, "sv2connman", [this] { ThreadSv2Handler(); }); |
| 44 | + return true; |
| 45 | +} |
| 46 | + |
| 47 | +std::shared_ptr<Sock> Sv2Connman::BindListenPort(std::string host, uint16_t port) const |
| 48 | +{ |
| 49 | + const CService addr_bind = LookupNumeric(host, port); |
| 50 | + |
| 51 | + auto sock = CreateSock(addr_bind.GetSAFamily(), SOCK_STREAM, IPPROTO_TCP); |
| 52 | + if (!sock) { |
| 53 | + throw std::runtime_error("Sv2 Template Provider cannot create socket"); |
| 54 | + } |
| 55 | + |
| 56 | + struct sockaddr_storage sockaddr; |
| 57 | + socklen_t len = sizeof(sockaddr); |
| 58 | + |
| 59 | + if (!addr_bind.GetSockAddr(reinterpret_cast<struct sockaddr*>(&sockaddr), &len)) { |
| 60 | + throw std::runtime_error("Sv2 Template Provider failed to get socket address"); |
| 61 | + } |
| 62 | + |
| 63 | + if (sock->Bind(reinterpret_cast<struct sockaddr*>(&sockaddr), len) == SOCKET_ERROR) { |
| 64 | + const int nErr = WSAGetLastError(); |
| 65 | + if (nErr == WSAEADDRINUSE) { |
| 66 | + throw std::runtime_error(strprintf("Unable to bind to %d on this computer. Another Stratum v2 process is probably already running.\n", port)); |
| 67 | + } |
| 68 | + |
| 69 | + throw std::runtime_error(strprintf("Unable to bind to %d on this computer (bind returned error %s )\n", port, NetworkErrorString(nErr))); |
| 70 | + } |
| 71 | + |
| 72 | + constexpr int max_pending_conns{4096}; |
| 73 | + if (sock->Listen(max_pending_conns) == SOCKET_ERROR) { |
| 74 | + throw std::runtime_error("Sv2 listening socket has an error listening"); |
| 75 | + } |
| 76 | + |
| 77 | + LogPrintLevel(BCLog::SV2, BCLog::Level::Info, "%s listening on %s:%d\n", SV2_PROTOCOL_NAMES.at(m_subprotocol), host, port); |
| 78 | + |
| 79 | + return sock; |
| 80 | +} |
| 81 | + |
| 82 | + |
| 83 | +void Sv2Connman::DisconnectFlagged() |
| 84 | +{ |
| 85 | + AssertLockHeld(m_clients_mutex); |
| 86 | + |
| 87 | + // Remove clients that are flagged for disconnection. |
| 88 | + m_sv2_clients.erase( |
| 89 | + std::remove_if(m_sv2_clients.begin(), m_sv2_clients.end(), [](const auto &client) { |
| 90 | + return client->m_disconnect_flag; |
| 91 | + }), m_sv2_clients.end()); |
| 92 | +} |
| 93 | + |
| 94 | +void Sv2Connman::ThreadSv2Handler() EXCLUSIVE_LOCKS_REQUIRED(!m_clients_mutex) |
| 95 | +{ |
| 96 | + AssertLockNotHeld(m_clients_mutex); |
| 97 | + |
| 98 | + while (!m_flag_interrupt_sv2) { |
| 99 | + { |
| 100 | + LOCK(m_clients_mutex); |
| 101 | + DisconnectFlagged(); |
| 102 | + } |
| 103 | + |
| 104 | + // Poll/Select the sockets that need handling. |
| 105 | + Sock::EventsPerSock events_per_sock = WITH_LOCK(m_clients_mutex, return GenerateWaitSockets(m_listening_socket, m_sv2_clients)); |
| 106 | + |
| 107 | + constexpr auto timeout = std::chrono::milliseconds(50); |
| 108 | + if (!events_per_sock.begin()->first->WaitMany(timeout, events_per_sock)) { |
| 109 | + continue; |
| 110 | + } |
| 111 | + |
| 112 | + // Accept any new connections for sv2 clients. |
| 113 | + const auto listening_sock = events_per_sock.find(m_listening_socket); |
| 114 | + if (listening_sock != events_per_sock.end() && listening_sock->second.occurred & Sock::RECV) { |
| 115 | + struct sockaddr_storage sockaddr; |
| 116 | + socklen_t sockaddr_len = sizeof(sockaddr); |
| 117 | + |
| 118 | + auto sock = m_listening_socket->Accept(reinterpret_cast<struct sockaddr*>(&sockaddr), &sockaddr_len); |
| 119 | + if (sock) { |
| 120 | + Assume(m_certificate); |
| 121 | + LOCK(m_clients_mutex); |
| 122 | + std::unique_ptr transport = std::make_unique<Sv2Transport>(m_static_key, m_certificate.value()); |
| 123 | + size_t id{m_sv2_clients.size() + 1}; |
| 124 | + auto client = std::make_unique<Sv2Client>(id, std::move(sock), std::move(transport)); |
| 125 | + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "New client id=%zu connected\n", client->m_id); |
| 126 | + m_sv2_clients.emplace_back(std::move(client)); |
| 127 | + } |
| 128 | + } |
| 129 | + |
| 130 | + LOCK(m_clients_mutex); |
| 131 | + // Process messages from and for connected sv2_clients. |
| 132 | + for (auto& client : m_sv2_clients) { |
| 133 | + bool has_received_data = false; |
| 134 | + bool has_error_occurred = false; |
| 135 | + |
| 136 | + const auto socket_it = events_per_sock.find(client->m_sock); |
| 137 | + if (socket_it != events_per_sock.end()) { |
| 138 | + has_received_data = socket_it->second.occurred & Sock::RECV; |
| 139 | + has_error_occurred = socket_it->second.occurred & Sock::ERR; |
| 140 | + } |
| 141 | + |
| 142 | + if (has_error_occurred) { |
| 143 | + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Socket receive error, disconnecting client id=%zu\n", |
| 144 | + client->m_id); |
| 145 | + client->m_disconnect_flag = true; |
| 146 | + continue; |
| 147 | + } |
| 148 | + |
| 149 | + // Process message queue and any outbound bytes still held by the transport |
| 150 | + auto it = client->m_send_messages.begin(); |
| 151 | + std::optional<bool> expected_more; |
| 152 | + while(true) { |
| 153 | + if (it != client->m_send_messages.end()) { |
| 154 | + // If possible, move one message from the send queue to the transport. |
| 155 | + // This fails when there is an existing message still being sent, |
| 156 | + // or when the handshake has not yet completed. |
| 157 | + // |
| 158 | + // Wrap Sv2NetMsg inside CSerializedNetMsg for transport |
| 159 | + CSerializedNetMsg net_msg{*it}; |
| 160 | + if (client->m_transport->SetMessageToSend(net_msg)) { |
| 161 | + ++it; |
| 162 | + } |
| 163 | + } |
| 164 | + |
| 165 | + const auto& [data, more, _m_message_type] = client->m_transport->GetBytesToSend(/*have_next_message=*/it != client->m_send_messages.end()); |
| 166 | + size_t total_sent = 0; |
| 167 | + |
| 168 | + // We rely on the 'more' value returned by GetBytesToSend to correctly predict whether more |
| 169 | + // bytes are still to be sent, to correctly set the MSG_MORE flag. As a sanity check, |
| 170 | + // verify that the previously returned 'more' was correct. |
| 171 | + if (expected_more.has_value()) Assume(!data.empty() == *expected_more); |
| 172 | + expected_more = more; |
| 173 | + ssize_t sent = 0; |
| 174 | + |
| 175 | + if (!data.empty()) { |
| 176 | + int flags = MSG_NOSIGNAL | MSG_DONTWAIT; |
| 177 | +#ifdef MSG_MORE |
| 178 | + if (more) { |
| 179 | + flags |= MSG_MORE; |
| 180 | + } |
| 181 | +#endif |
| 182 | + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Send %d bytes to client id=%zu\n", |
| 183 | + data.size() - total_sent, client->m_id); |
| 184 | + sent = client->m_sock->Send(data.data() + total_sent, data.size() - total_sent, flags); |
| 185 | + } |
| 186 | + if (sent > 0) { |
| 187 | + // Notify transport that bytes have been processed. |
| 188 | + client->m_transport->MarkBytesSent(sent); |
| 189 | + if ((size_t)sent != data.size()) { |
| 190 | + // could not send full message; stop sending more |
| 191 | + break; |
| 192 | + } |
| 193 | + } else { |
| 194 | + if (sent < 0) { |
| 195 | + // error |
| 196 | + int nErr = WSAGetLastError(); |
| 197 | + if (nErr != WSAEWOULDBLOCK && nErr != WSAEMSGSIZE && nErr != WSAEINTR && nErr != WSAEINPROGRESS) { |
| 198 | + LogPrintLevel(BCLog::SV2, BCLog::Level::Debug, "Socket send error for client id=%zu: %s\n", |
| 199 | + client->m_id, NetworkErrorString(nErr)); |
| 200 | + client->m_disconnect_flag = true; |
| 201 | + } |
| 202 | + } |
| 203 | + break; |
| 204 | + } |
| 205 | + } |
| 206 | + // Clear messages that have been handed to transport from the queue |
| 207 | + client->m_send_messages.erase(client->m_send_messages.begin(), it); |
| 208 | + |
| 209 | + // Stop processing this client if something went wrong during sending |
| 210 | + if (client->m_disconnect_flag) break; |
| 211 | + |
| 212 | + if (has_received_data) { |
| 213 | + uint8_t bytes_received_buf[0x10000]; |
| 214 | + |
| 215 | + const auto num_bytes_received = client->m_sock->Recv(bytes_received_buf, sizeof(bytes_received_buf), MSG_DONTWAIT); |
| 216 | + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Num bytes received from client id=%zu: %d\n", |
| 217 | + client->m_id, num_bytes_received); |
| 218 | + |
| 219 | + if (num_bytes_received <= 0) { |
| 220 | + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Disconnecting client id=%zu\n", |
| 221 | + client->m_id); |
| 222 | + client->m_disconnect_flag = true; |
| 223 | + break; |
| 224 | + } |
| 225 | + |
| 226 | + try |
| 227 | + { |
| 228 | + auto msg_ = Span(bytes_received_buf, num_bytes_received); |
| 229 | + Span<const uint8_t> msg(reinterpret_cast<const uint8_t*>(msg_.data()), msg_.size()); |
| 230 | + while (msg.size() > 0) { |
| 231 | + // absorb network data |
| 232 | + if (!client->m_transport->ReceivedBytes(msg)) { |
| 233 | + // Serious transport problem |
| 234 | + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Transport problem, disconnecting client id=%zu\n", |
| 235 | + client->m_id); |
| 236 | + client->m_disconnect_flag = true; |
| 237 | + break; |
| 238 | + } |
| 239 | + |
| 240 | + if (client->m_transport->ReceivedMessageComplete()) { |
| 241 | + bool dummy_reject_message = false; |
| 242 | + Sv2NetMsg msg = client->m_transport->GetReceivedMessage(std::chrono::milliseconds(0), dummy_reject_message); |
| 243 | + ProcessSv2Message(msg, *client.get()); |
| 244 | + } |
| 245 | + } |
| 246 | + } catch (const std::exception& e) { |
| 247 | + LogPrintLevel(BCLog::SV2, BCLog::Level::Error, "Received error when processing client id=%zu message: %s\n", client->m_id, e.what()); |
| 248 | + client->m_disconnect_flag = true; |
| 249 | + } |
| 250 | + } |
| 251 | + } |
| 252 | + } |
| 253 | +} |
| 254 | + |
| 255 | +Sock::EventsPerSock Sv2Connman::GenerateWaitSockets(const std::shared_ptr<Sock>& listen_socket, const Clients& sv2_clients) const |
| 256 | +{ |
| 257 | + Sock::EventsPerSock events_per_sock; |
| 258 | + events_per_sock.emplace(listen_socket, Sock::Events(Sock::RECV)); |
| 259 | + |
| 260 | + for (const auto& client : sv2_clients) { |
| 261 | + if (!client->m_disconnect_flag && client->m_sock) { |
| 262 | + events_per_sock.emplace(client->m_sock, Sock::Events{Sock::RECV | Sock::ERR}); |
| 263 | + } |
| 264 | + } |
| 265 | + |
| 266 | + return events_per_sock; |
| 267 | +} |
| 268 | + |
| 269 | +void Sv2Connman::Interrupt() |
| 270 | +{ |
| 271 | + m_flag_interrupt_sv2 = true; |
| 272 | +} |
| 273 | + |
| 274 | +void Sv2Connman::StopThreads() |
| 275 | +{ |
| 276 | + if (m_thread_sv2_handler.joinable()) { |
| 277 | + m_thread_sv2_handler.join(); |
| 278 | + } |
| 279 | +} |
| 280 | + |
| 281 | +void Sv2Connman::ProcessSv2Message(const Sv2NetMsg& sv2_net_msg, Sv2Client& client) |
| 282 | +{ |
| 283 | + uint8_t msg_type[1] = {uint8_t(sv2_net_msg.m_msg_type)}; |
| 284 | + LogPrintLevel(BCLog::SV2, BCLog::Level::Debug, "Received 0x%s %s from client id=%zu\n", |
| 285 | + // After clang-17: |
| 286 | + // std::format("{:x}", uint8_t(sv2_net_msg.m_msg_type)), |
| 287 | + HexStr(msg_type), |
| 288 | + node::SV2_MSG_NAMES.at(sv2_net_msg.m_msg_type), client.m_id); |
| 289 | + |
| 290 | + DataStream ss (sv2_net_msg.m_msg); |
| 291 | + |
| 292 | + switch (sv2_net_msg.m_msg_type) |
| 293 | + { |
| 294 | + case Sv2MsgType::SETUP_CONNECTION: |
| 295 | + { |
| 296 | + if (client.m_setup_connection_confirmed) { |
| 297 | + LogPrintLevel(BCLog::SV2, BCLog::Level::Error, "Client client id=%zu connection has already been confirmed\n", |
| 298 | + client.m_id); |
| 299 | + return; |
| 300 | + } |
| 301 | + |
| 302 | + node::Sv2SetupConnectionMsg setup_conn; |
| 303 | + try { |
| 304 | + ss >> setup_conn; |
| 305 | + } catch (const std::exception& e) { |
| 306 | + LogPrintLevel(BCLog::SV2, BCLog::Level::Error, "Received invalid SetupConnection message from client id=%zu: %s\n", |
| 307 | + client.m_id, e.what()); |
| 308 | + client.m_disconnect_flag = true; |
| 309 | + return; |
| 310 | + } |
| 311 | + |
| 312 | + // Disconnect a client that connects on the wrong subprotocol. |
| 313 | + if (setup_conn.m_protocol != m_subprotocol) { |
| 314 | + node::Sv2SetupConnectionErrorMsg setup_conn_err{setup_conn.m_flags, std::string{"unsupported-protocol"}}; |
| 315 | + |
| 316 | + LogPrintLevel(BCLog::SV2, BCLog::Level::Debug, "Send 0x02 SetupConnectionError to client id=%zu\n", |
| 317 | + client.m_id); |
| 318 | + client.m_send_messages.emplace_back(setup_conn_err); |
| 319 | + |
| 320 | + client.m_disconnect_flag = true; |
| 321 | + return; |
| 322 | + } |
| 323 | + |
| 324 | + // Disconnect a client if they are not running a compatible protocol version. |
| 325 | + if ((m_protocol_version < setup_conn.m_min_version) || (m_protocol_version > setup_conn.m_max_version)) { |
| 326 | + node::Sv2SetupConnectionErrorMsg setup_conn_err{setup_conn.m_flags, std::string{"protocol-version-mismatch"}}; |
| 327 | + LogPrintLevel(BCLog::SV2, BCLog::Level::Debug, "Send 0x02 SetupConnection.Error to client id=%zu\n", |
| 328 | + client.m_id); |
| 329 | + client.m_send_messages.emplace_back(setup_conn_err); |
| 330 | + |
| 331 | + LogPrintLevel(BCLog::SV2, BCLog::Level::Error, "Received a connection from client id=%zu with incompatible protocol_versions: min_version: %d, max_version: %d\n", |
| 332 | + client.m_id, setup_conn.m_min_version, setup_conn.m_max_version); |
| 333 | + client.m_disconnect_flag = true; |
| 334 | + return; |
| 335 | + } |
| 336 | + |
| 337 | + LogPrintLevel(BCLog::SV2, BCLog::Level::Debug, "Send 0x01 SetupConnection.Success to client id=%zu\n", |
| 338 | + client.m_id); |
| 339 | + node::Sv2SetupConnectionSuccessMsg setup_success{m_protocol_version, m_optional_features}; |
| 340 | + client.m_send_messages.emplace_back(setup_success); |
| 341 | + |
| 342 | + client.m_setup_connection_confirmed = true; |
| 343 | + |
| 344 | + break; |
| 345 | + } |
| 346 | + case Sv2MsgType::COINBASE_OUTPUT_DATA_SIZE: |
| 347 | + { |
| 348 | + if (!client.m_setup_connection_confirmed) { |
| 349 | + client.m_disconnect_flag = true; |
| 350 | + return; |
| 351 | + } |
| 352 | + |
| 353 | + node::Sv2CoinbaseOutputDataSizeMsg coinbase_output_data_size; |
| 354 | + try { |
| 355 | + ss >> coinbase_output_data_size; |
| 356 | + client.m_coinbase_output_data_size_recv = true; |
| 357 | + } catch (const std::exception& e) { |
| 358 | + LogPrintLevel(BCLog::SV2, BCLog::Level::Error, "Received invalid CoinbaseOutputDataSize message from client id=%zu: %s\n", |
| 359 | + client.m_id, e.what()); |
| 360 | + client.m_disconnect_flag = true; |
| 361 | + return; |
| 362 | + } |
| 363 | + |
| 364 | + uint32_t max_additional_size = coinbase_output_data_size.m_coinbase_output_max_additional_size; |
| 365 | + LogPrintLevel(BCLog::SV2, BCLog::Level::Debug, "coinbase_output_max_additional_size=%d bytes\n", max_additional_size); |
| 366 | + |
| 367 | + if (max_additional_size > MAX_BLOCK_WEIGHT) { |
| 368 | + LogPrintLevel(BCLog::SV2, BCLog::Level::Error, "Received impossible CoinbaseOutputDataSize from client id=%zu: %d\n", |
| 369 | + client.m_id, max_additional_size); |
| 370 | + client.m_disconnect_flag = true; |
| 371 | + return; |
| 372 | + } |
| 373 | + |
| 374 | + client.m_coinbase_tx_outputs_size = coinbase_output_data_size.m_coinbase_output_max_additional_size; |
| 375 | + |
| 376 | + break; |
| 377 | + } |
| 378 | + default: { |
| 379 | + uint8_t msg_type[1]{uint8_t(sv2_net_msg.m_msg_type)}; |
| 380 | + LogPrintLevel(BCLog::SV2, BCLog::Level::Warning, "Received unknown message type 0x%s from client id=%zu\n", |
| 381 | + HexStr(msg_type), client.m_id); |
| 382 | + break; |
| 383 | + } |
| 384 | + } |
| 385 | + |
| 386 | + m_msgproc->ReceivedMessage(client, sv2_net_msg.m_msg_type); |
| 387 | +} |
0 commit comments