Skip to content

Commit bbdb41e

Browse files
authored
Clear all migrating locked object when client crashed. (#1982)
Fixes #1977 Signed-off-by: vegetableysm <[email protected]>
1 parent 55c9058 commit bbdb41e

File tree

4 files changed

+102
-19
lines changed

4 files changed

+102
-19
lines changed

src/client/rpc_client.cc

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,6 @@ Status RPCClient::GetRemoteBlob(const ObjectID& id, const bool unsafe,
776776
std::vector<int> fd_sent;
777777

778778
std::string message_out;
779-
RDMABlobScopeGuard rdmaBlobScopeGuard;
780779
if (rdma_connected_) {
781780
WriteGetRemoteBuffersRequest(std::set<ObjectID>{id}, unsafe, false, true,
782781
message_out);
@@ -788,14 +787,17 @@ Status RPCClient::GetRemoteBlob(const ObjectID& id, const bool unsafe,
788787
json message_in;
789788
RETURN_ON_ERROR(doRead(message_in));
790789
RETURN_ON_ERROR(ReadGetBuffersReply(message_in, payloads, fd_sent));
791-
RETURN_ON_ASSERT(payloads.size() == 1, "Expects only one payload");
790+
791+
RDMABlobScopeGuard rdmaBlobScopeGuard;
792792
if (rdma_connected_) {
793-
std::unordered_set<ObjectID> ids{payloads[0].object_id};
793+
std::unordered_set<ObjectID> ids{id};
794794
std::function<void(std::unordered_set<ObjectID>)> func = std::bind(
795795
&RPCClient::doReleaseBlobsWithRDMARequest, this, std::placeholders::_1);
796796
rdmaBlobScopeGuard.set(func, ids);
797797
}
798798

799+
RETURN_ON_ASSERT(payloads.size() == 1, "Expects only one payload");
800+
799801
buffer = std::shared_ptr<RemoteBlob>(new RemoteBlob(
800802
payloads[0].object_id, remote_instance_id_, payloads[0].data_size));
801803
// read the actual payload
@@ -892,7 +894,6 @@ Status RPCClient::GetRemoteBlobs(
892894
std::unordered_set<ObjectID> id_set(ids.begin(), ids.end());
893895
std::vector<Payload> payloads;
894896
std::vector<int> fd_sent;
895-
RDMABlobScopeGuard rdmaBlobScopeGuard;
896897

897898
std::string message_out;
898899
if (rdma_connected_) {
@@ -905,16 +906,19 @@ Status RPCClient::GetRemoteBlobs(
905906
json message_in;
906907
RETURN_ON_ERROR(doRead(message_in));
907908
RETURN_ON_ERROR(ReadGetBuffersReply(message_in, payloads, fd_sent));
908-
RETURN_ON_ASSERT(payloads.size() == id_set.size(),
909-
"The result size doesn't match with the requested sizes: " +
910-
std::to_string(payloads.size()) + " vs. " +
911-
std::to_string(id_set.size()));
909+
910+
RDMABlobScopeGuard rdmaBlobScopeGuard;
912911
if (rdma_connected_) {
913912
std::function<void(std::unordered_set<ObjectID>)> func = std::bind(
914913
&RPCClient::doReleaseBlobsWithRDMARequest, this, std::placeholders::_1);
915914
rdmaBlobScopeGuard.set(func, id_set);
916915
}
917916

917+
RETURN_ON_ASSERT(payloads.size() == id_set.size(),
918+
"The result size doesn't match with the requested sizes: " +
919+
std::to_string(payloads.size()) + " vs. " +
920+
std::to_string(id_set.size()));
921+
918922
std::unordered_map<ObjectID, std::shared_ptr<RemoteBlob>> id_payload_map;
919923
if (rdma_connected_) {
920924
for (auto const& payload : payloads) {
@@ -982,6 +986,14 @@ Status RPCClient::GetRemoteBlobs(
982986
json message_in;
983987
RETURN_ON_ERROR(doRead(message_in));
984988
RETURN_ON_ERROR(ReadGetBuffersReply(message_in, payloads, fd_sent));
989+
990+
RDMABlobScopeGuard rdmaBlobScopeGuard;
991+
if (rdma_connected_) {
992+
std::function<void(std::unordered_set<ObjectID>)> func = std::bind(
993+
&RPCClient::doReleaseBlobsWithRDMARequest, this, std::placeholders::_1);
994+
rdmaBlobScopeGuard.set(func, id_set);
995+
}
996+
985997
RETURN_ON_ASSERT(payloads.size() == id_set.size(),
986998
"The result size doesn't match with the requested sizes: " +
987999
std::to_string(payloads.size()) + " vs. " +

src/server/async/rpc_server.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,9 @@ void RPCServer::doVineyardReleaseMemory(VineyardRecvContext* recv_context,
329329

330330
void RPCServer::doVineyardClose(VineyardRecvContext* recv_context) {
331331
VLOG(100) << "Receive close msg!";
332+
if (recv_context == nullptr) {
333+
return;
334+
}
332335
rdma_server_->CloseConnection(recv_context->rdma_conn_id);
333336

334337
std::lock_guard<std::recursive_mutex> scope_lock(this->rdma_mutex_);
@@ -369,6 +372,9 @@ void RPCServer::doRDMARecv() {
369372
VineyardRecvContext* recv_context =
370373
reinterpret_cast<VineyardRecvContext*>(context);
371374
doVineyardClose(recv_context);
375+
if (recv_context) {
376+
delete recv_context;
377+
}
372378
}
373379
VLOG(100) << "Get RX completion failed! Error:" << status.message();
374380
VLOG(100) << "Retry...";

src/server/async/socket_server.cc

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -786,10 +786,22 @@ bool SocketConnection::doGetRemoteBuffers(const json& root) {
786786

787787
TRY_READ_REQUEST(ReadGetRemoteBuffersRequest, root, ids, unsafe, compress,
788788
use_rdma);
789-
server_ptr_->LockTransmissionObjects(ids);
790-
RESPONSE_ON_ERROR(bulk_store_->GetUnsafe(ids, unsafe, objects));
791-
RESPONSE_ON_ERROR(bulk_store_->AddDependency(
792-
std::unordered_set<ObjectID>(ids.begin(), ids.end()), this->getConnId()));
789+
this->LockTransmissionObjects(ids);
790+
if (!bulk_store_->GetUnsafe(ids, unsafe, objects).ok()) {
791+
this->UnlockTransmissionObjects(ids);
792+
WriteErrorReply(Status::KeyError("Failed to get objects"), message_out);
793+
this->doWrite(message_out);
794+
return false;
795+
}
796+
if (!bulk_store_
797+
->AddDependency(std::unordered_set<ObjectID>(ids.begin(), ids.end()),
798+
this->getConnId())
799+
.ok()) {
800+
this->UnlockTransmissionObjects(ids);
801+
WriteErrorReply(Status::KeyError("Failed to add dependency"), message_out);
802+
this->doWrite(message_out);
803+
return false;
804+
}
793805
WriteGetBuffersReply(objects, {}, compress, message_out);
794806

795807
if (!use_rdma) {
@@ -802,7 +814,7 @@ bool SocketConnection::doGetRemoteBuffers(const json& root) {
802814
<< "Failed to send buffers to remote client: "
803815
<< status.ToString();
804816
}
805-
self->server_ptr_->UnlockTransmissionObjects(ids);
817+
self->UnlockTransmissionObjects(ids);
806818
return Status::OK();
807819
});
808820
return Status::OK();
@@ -1846,12 +1858,10 @@ bool SocketConnection::doReleaseBlobsWithRDMA(const json& root) {
18461858
std::vector<ObjectID> ids;
18471859
TRY_READ_REQUEST(ReadReleaseBlobsWithRDMARequest, root, ids);
18481860

1849-
boost::asio::post(server_ptr_->GetIOContext(), [self, ids]() {
1850-
self->server_ptr_->UnlockTransmissionObjects(ids);
1851-
std::string message_out;
1852-
WriteReleaseBlobsWithRDMAReply(message_out);
1853-
self->doWrite(message_out);
1854-
});
1861+
this->UnlockTransmissionObjects(ids);
1862+
std::string message_out;
1863+
WriteReleaseBlobsWithRDMAReply(message_out);
1864+
this->doWrite(message_out);
18551865

18561866
return false;
18571867
}
@@ -1884,6 +1894,7 @@ void SocketConnection::doWrite(std::string&& buf) {
18841894
}
18851895

18861896
void SocketConnection::doStop() {
1897+
this->ClearLockedObjects();
18871898
if (this->Stop()) {
18881899
// drop connection
18891900
socket_server_ptr_->RemoveConnection(conn_id_);
@@ -1928,6 +1939,50 @@ void SocketConnection::doAsyncWrite(std::string&& buf, callback_t<> callback,
19281939
});
19291940
}
19301941

1942+
void SocketConnection::LockTransmissionObjects(
1943+
const std::vector<ObjectID>& ids) {
1944+
{
1945+
std::lock_guard<std::mutex> lock(locked_objects_mutex_);
1946+
for (auto const& id : ids) {
1947+
if (locked_objects_.find(id) == locked_objects_.end()) {
1948+
locked_objects_[id] = 1;
1949+
} else {
1950+
++locked_objects_[id];
1951+
}
1952+
}
1953+
}
1954+
server_ptr_->LockTransmissionObjects(ids);
1955+
}
1956+
1957+
void SocketConnection::UnlockTransmissionObjects(
1958+
const std::vector<ObjectID>& ids) {
1959+
{
1960+
std::lock_guard<std::mutex> lock(locked_objects_mutex_);
1961+
for (auto const& id : ids) {
1962+
if (locked_objects_.find(id) != locked_objects_.end()) {
1963+
if (--locked_objects_[id] == 0) {
1964+
locked_objects_.erase(id);
1965+
}
1966+
}
1967+
}
1968+
}
1969+
server_ptr_->UnlockTransmissionObjects(ids);
1970+
}
1971+
1972+
void SocketConnection::ClearLockedObjects() {
1973+
std::vector<ObjectID> ids;
1974+
{
1975+
std::lock_guard<std::mutex> lock(locked_objects_mutex_);
1976+
for (auto const& kv : locked_objects_) {
1977+
for (int i = 0; i < kv.second; ++i) {
1978+
ids.push_back(kv.first);
1979+
}
1980+
}
1981+
locked_objects_.clear();
1982+
}
1983+
server_ptr_->UnlockTransmissionObjects(ids);
1984+
}
1985+
19311986
SocketServer::SocketServer(std::shared_ptr<VineyardServer> vs_ptr)
19321987
: vs_ptr_(vs_ptr), next_conn_id_(0) {}
19331988

src/server/async/socket_server.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ limitations under the License.
2424
#include <string>
2525
#include <unordered_map>
2626
#include <unordered_set>
27+
#include <vector>
2728

2829
#include "common/memory/payload.h"
2930
#include "common/util/asio.h" // IWYU pragma: keep
@@ -193,6 +194,12 @@ class SocketConnection : public std::enable_shared_from_this<SocketConnection> {
193194
this->server_ptr_ = session;
194195
}
195196

197+
void LockTransmissionObjects(const std::vector<ObjectID>& ids);
198+
199+
void UnlockTransmissionObjects(const std::vector<ObjectID>& ids);
200+
201+
void ClearLockedObjects();
202+
196203
// whether the connection has been correctly "registered"
197204
std::atomic_bool registered_;
198205

@@ -216,6 +223,9 @@ class SocketConnection : public std::enable_shared_from_this<SocketConnection> {
216223
size_t read_msg_header_;
217224
std::string read_msg_body_;
218225

226+
std::unordered_map<ObjectID, int> locked_objects_;
227+
std::mutex locked_objects_mutex_;
228+
219229
friend class IPCServer;
220230
friend class RPCServer;
221231
};

0 commit comments

Comments
 (0)