Skip to content

Commit 43ddab6

Browse files
authored
fix(rpc): Improve input validation and error handling (#13069)
* fix(rpc): Improve input validation and error handling The `rpc-server` was vulnerable to Denial of Service attacks via several RPC commands (`SET_TENSOR`, `GRAPH_COMPUTE`, etc.). Malformed messages could trigger failed assertions (e.g., invalid `ggml_type`) or out-of-bounds reads/writes leading to `GGML_ABORT` calls, crashing the server process. This PR introduces robust input validation and replaces `abort()` calls with graceful error handling: - **Type Validation:** `deserialize_tensor` now checks if the `tensor->type` is within the valid `GGML_TYPE_COUNT` range *before* calling `ggml_new_tensor_4d`. Returns `nullptr` on invalid type. - **Bounds Checks:** Replaced `GGML_ABORT` in `set_tensor`, `set_tensor_hash`, and `get_tensor` handlers with error logging and returning `false` when data/offset parameters are out of buffer bounds. - **Size Checks:** Added safe arithmetic checks (for overflow) in `graph_compute` when calculating required message sizes based on client-provided `n_nodes` and `n_tensors`. Returns early if the reported sizes conflict with the actual message size or would lead to overflow. - **Error Propagation:** - `create_node` now checks for `nullptr` return values from `deserialize_tensor` and its recursive calls, propagating `nullptr` upwards on failure. Uses `find` instead of `at` for safer map access. - `copy_tensor` now checks for `nullptr` from `deserialize_tensor` and sets the response status to failure if deserialization or bounds checks fail. - `graph_compute` now checks for `nullptr` return from `create_node` and returns failure status correctly. The final return value now reflects the actual computation status. These changes improve the RPC server's resilience against malformed client requests, preventing crashes and ensuring errors are handled more gracefully. Signed-off-by: Ville Vesilehto <[email protected]> * refactor(rpc): address pr comments removed comments and unnecessary returns Signed-off-by: Ville Vesilehto <[email protected]> * refactor(rpc): ambiguous nullptr from create_node rpc_server::create_node could previously return nullptr if the input ID was 0 (valid) or if an internal error (deserialization, recursion failure) occurred (invalid). This ambiguity made error handling difficult for the caller (`graph_compute`). This commit clarifies the meaning of nullptr: - `graph_compute` now checks if the input 'id' was non-zero when `create_node` returns nullptr, correctly identifying failures versus intentional null links. - `create_node` avoids recursive calls for zero IDs and propagates nullptr unambiguously on failure during recursion. Signed-off-by: Ville Vesilehto <[email protected]> * refactor(rpc): initial zero check in create_node The caller (`graph_compute`) already checks `id != 0` when handling a `nullptr` return from `create_node`, correctly distinguishing intentional null links from actual errors. This makes the initial `if (id == 0)` check redundant. Also removes the log message when a tensor ID is not found in the provided map which was added in this branch. Signed-off-by: Ville Vesilehto <[email protected]> * fix(rpc): Handle get_alloc_size failure in server Check the return value of `server.get_alloc_size` in the RPC server loop. If the call fails, return early to close the connection. Signed-off-by: Ville Vesilehto <[email protected]> * refactor(rpc): input size validation in graph_compute Removes detailed, step-by-step size calculations and overflow checks in favor of simpler direct comparisons, assuming 64-bit overflow is unlikely. Signed-off-by: Ville Vesilehto <[email protected]> * refactor(rpc): remove extra status code setting Removes the explicit setting of `response.result = GGML_STATUS_FAILED` when `create_node` returns `nullptr` within `graph_compute`. Primary signal is the `false` return value in case of failure. Signed-off-by: Ville Vesilehto <[email protected]> * refactor(rpc): remove redundant check for tensor->type Breaks CI on ubuntu-cpu-make. Tensor type is uint32_t, thus the check is not needed. Signed-off-by: Ville Vesilehto <[email protected]> --------- Signed-off-by: Ville Vesilehto <[email protected]>
1 parent 1831f53 commit 43ddab6

File tree

1 file changed

+68
-10
lines changed

1 file changed

+68
-10
lines changed

ggml/src/ggml-rpc/ggml-rpc.cpp

+68-10
Original file line numberDiff line numberDiff line change
@@ -982,8 +982,21 @@ bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
982982
}
983983

984984
ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
985+
// Validate tensor type before using it
986+
if (tensor->type >= GGML_TYPE_COUNT) {
987+
GGML_LOG_ERROR("[%s] invalid tensor type received: %u\n", __func__, tensor->type);
988+
return nullptr;
989+
}
990+
985991
ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
986992
tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
993+
994+
// ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than invalid type
995+
if (result == nullptr) {
996+
GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\\n", __func__, tensor->type);
997+
return nullptr;
998+
}
999+
9871000
for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
9881001
result->nb[i] = tensor->nb[i];
9891002
}
@@ -1043,7 +1056,9 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
10431056
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
10441057

10451058
if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
1046-
GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
1059+
GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu) out of buffer bounds [0x%zx, 0x%zx)\n",
1060+
__func__, in_tensor->data, offset, size, p0, p1);
1061+
return false;
10471062
}
10481063
}
10491064

@@ -1118,7 +1133,9 @@ bool rpc_server::set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set
11181133
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
11191134

11201135
if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
1121-
GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
1136+
GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu, hash=0x%" PRIx64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
1137+
__func__, in_tensor->data, offset, size, *hash, p0, p1);
1138+
return false;
11221139
}
11231140
}
11241141
ggml_backend_tensor_set(tensor, cached_file.data(), offset, size);
@@ -1183,7 +1200,9 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<
11831200
if (request.tensor.data + request.offset < p0 ||
11841201
request.tensor.data + request.offset >= p1 ||
11851202
request.size > (p1 - request.tensor.data - request.offset)) {
1186-
GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
1203+
GGML_LOG_ERROR("[%s] requested tensor region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%" PRIu64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
1204+
__func__, request.tensor.data, request.offset, request.size, p0, p1);
1205+
return false;
11871206
}
11881207
}
11891208

@@ -1237,22 +1256,50 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
12371256
struct ggml_context * ctx,
12381257
const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
12391258
std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
1240-
if (id == 0) {
1241-
return nullptr;
1242-
}
12431259
if (tensor_map.find(id) != tensor_map.end()) {
12441260
return tensor_map[id];
12451261
}
1246-
const rpc_tensor * tensor = tensor_ptrs.at(id);
1262+
// Safely find the tensor pointer
1263+
auto it_ptr = tensor_ptrs.find(id);
1264+
if (it_ptr == tensor_ptrs.end()) {
1265+
return nullptr;
1266+
}
1267+
const rpc_tensor * tensor = it_ptr->second;
1268+
12471269
struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
12481270
if (result == nullptr) {
12491271
return nullptr;
12501272
}
12511273
tensor_map[id] = result;
12521274
for (int i = 0; i < GGML_MAX_SRC; i++) {
1253-
result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
1275+
// Check if the source ID is 0 before calling create_node recursively
1276+
if (tensor->src[i] == 0) {
1277+
result->src[i] = nullptr;
1278+
} else {
1279+
result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
1280+
// If the recursive call failed for a non-zero ID, propagate the error
1281+
if (result->src[i] == nullptr) {
1282+
GGML_LOG_ERROR("[%s] failed to create source node %d (src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
1283+
__func__, i, tensor->src[i], id);
1284+
// Must return nullptr to signal failure up the call stack
1285+
return nullptr;
1286+
}
1287+
}
1288+
}
1289+
1290+
// Handle view_src similarly
1291+
if (tensor->view_src == 0) {
1292+
result->view_src = nullptr;
1293+
} else {
1294+
result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
1295+
// If the recursive call failed for a non-zero ID, propagate the error
1296+
if (result->view_src == nullptr) {
1297+
GGML_LOG_ERROR("[%s] failed to create view_src node (view_src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
1298+
__func__, tensor->view_src, id);
1299+
// Must return nullptr to signal failure up the call stack
1300+
return nullptr;
1301+
}
12541302
}
1255-
result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
12561303
result->view_offs = tensor->view_offs;
12571304
return result;
12581305
}
@@ -1278,6 +1325,7 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
12781325
GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
12791326

12801327
size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
1328+
12811329
struct ggml_init_params params = {
12821330
/*.mem_size =*/ buf_size,
12831331
/*.mem_buffer =*/ NULL,
@@ -1297,6 +1345,14 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
12971345
int64_t id;
12981346
memcpy(&id, &nodes[i], sizeof(id));
12991347
graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
1348+
1349+
// Check if create_node failed for a *non-zero* ID.
1350+
// If id was 0, create_node returning nullptr is expected.
1351+
// If id was non-zero and create_node returned nullptr, it indicates a deserialization error.
1352+
if (graph->nodes[i] == nullptr && id != 0) {
1353+
GGML_LOG_ERROR("[%s] failed to create graph node %d (id=%" PRId64 ")\n", __func__, i, id);
1354+
return false;
1355+
}
13001356
}
13011357
ggml_status status = ggml_backend_graph_compute(backend, graph);
13021358
response.result = status;
@@ -1361,7 +1417,9 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
13611417
return;
13621418
}
13631419
rpc_msg_get_alloc_size_rsp response;
1364-
server.get_alloc_size(request, response);
1420+
if (!server.get_alloc_size(request, response)) {
1421+
return;
1422+
}
13651423
if (!send_msg(sockfd, &response, sizeof(response))) {
13661424
return;
13671425
}

0 commit comments

Comments
 (0)