Skip to content

Commit 7246d8e

Browse files
jrfastabborkmann
authored andcommitted
bpf: helper to pop data from messages
This adds a BPF SK_MSG program helper so that we can pop data from a msg. We use this to pop metadata from a previous push data call. Signed-off-by: John Fastabend <[email protected]> Signed-off-by: Daniel Borkmann <[email protected]>
1 parent 17d95e4 commit 7246d8e

File tree

4 files changed

+209
-6
lines changed

4 files changed

+209
-6
lines changed

include/uapi/linux/bpf.h

+15-1
Original file line numberDiff line numberDiff line change
@@ -2268,6 +2268,19 @@ union bpf_attr {
22682268
*
22692269
* Return
22702270
* 0 on success, or a negative error in case of failure.
2271+
*
2272+
* int bpf_msg_pop_data(struct sk_msg_buff *msg, u32 start, u32 pop, u64 flags)
2273+
* Description
2274+
* Will remove *pop* bytes from a *msg* starting at byte *start*.
2275+
* This may result in **ENOMEM** errors under certain situations if
2276+
* an allocation and copy are required due to a full ring buffer.
2277+
* However, the helper will try to avoid doing the allocation
2278+
* if possible. Other errors can occur if input parameters are
2279+
* invalid either due to *start* byte not being a valid part of msg
2280+
* payload and/or *pop* value being too large.
2281+
*
2282+
* Return
2283+
* 0 on success, or a negative error in case of failure.
22712284
*/
22722285
#define __BPF_FUNC_MAPPER(FN) \
22732286
FN(unspec), \
@@ -2360,7 +2373,8 @@ union bpf_attr {
23602373
FN(map_push_elem), \
23612374
FN(map_pop_elem), \
23622375
FN(map_peek_elem), \
2363-
FN(msg_push_data),
2376+
FN(msg_push_data), \
2377+
FN(msg_pop_data),
23642378

23652379
/* integer value in 'imm' field of BPF_CALL instruction selects which helper
23662380
* function eBPF program intends to call

net/core/filter.c

+171
Original file line numberDiff line numberDiff line change
@@ -2425,6 +2425,174 @@ static const struct bpf_func_proto bpf_msg_push_data_proto = {
24252425
.arg4_type = ARG_ANYTHING,
24262426
};
24272427

2428+
static void sk_msg_shift_left(struct sk_msg *msg, int i)
2429+
{
2430+
int prev;
2431+
2432+
do {
2433+
prev = i;
2434+
sk_msg_iter_var_next(i);
2435+
msg->sg.data[prev] = msg->sg.data[i];
2436+
} while (i != msg->sg.end);
2437+
2438+
sk_msg_iter_prev(msg, end);
2439+
}
2440+
2441+
static void sk_msg_shift_right(struct sk_msg *msg, int i)
2442+
{
2443+
struct scatterlist tmp, sge;
2444+
2445+
sk_msg_iter_next(msg, end);
2446+
sge = sk_msg_elem_cpy(msg, i);
2447+
sk_msg_iter_var_next(i);
2448+
tmp = sk_msg_elem_cpy(msg, i);
2449+
2450+
while (i != msg->sg.end) {
2451+
msg->sg.data[i] = sge;
2452+
sk_msg_iter_var_next(i);
2453+
sge = tmp;
2454+
tmp = sk_msg_elem_cpy(msg, i);
2455+
}
2456+
}
2457+
2458+
/* bpf_msg_pop_data() - remove @len bytes from @msg starting at byte @start.
 *
 * @flags must be zero.
 *
 * Return: 0 on success (a zero @len is a no-op), -EINVAL if @flags is set or
 * the range [@start, @start + @len) does not lie inside the message, and
 * -ENOMEM if a replacement page is required but cannot be allocated.
 */
BPF_CALL_4(bpf_msg_pop_data, struct sk_msg *, msg, u32, start,
	   u32, len, u64, flags)
{
	u32 i = 0, l = 0, space, offset = 0;
	/* Widen before adding so that start + len cannot wrap in 32 bits. */
	u64 last = (u64)start + len;
	int pop;

	if (unlikely(flags))
		return -EINVAL;

	/* Nothing to remove; avoid splitting an element for no reason. */
	if (unlikely(len == 0))
		return 0;

	/* First find the starting scatterlist element. 'offset' is the byte
	 * position of element i inside the message and 'l' its length, so
	 * that if the search runs off the end, offset + l equals the number
	 * of bytes walked and the bounds check below rejects 'start'.
	 */
	i = msg->sg.start;
	do {
		offset += l;
		l = sk_msg_elem(msg, i)->length;

		if (start < offset + l)
			break;
		sk_msg_iter_var_next(i);
	} while (i != msg->sg.end);

	/* Bounds checks: start and pop must be inside message. Popping up
	 * to and including the final byte (last == size) is valid.
	 */
	if (start >= offset + l || last > msg->sg.size)
		return -EINVAL;

	space = MAX_MSG_FRAGS - sk_msg_elem_used(msg);

	pop = len;
	/* --------------| offset
	 * -| start      |-------- len -------|
	 *
	 *  |----- a ----|-------- pop -------|----- b ----|
	 *  |______________________________________________| length
	 *
	 *
	 * a:   region at front of scatter element to save
	 * b:   region at back of scatter element to save when length > a + pop
	 * pop: region to pop from element, same as input 'len' here, will be
	 *      decremented below per iteration.
	 *
	 * Two top-level cases to handle when start != offset: first b is
	 * non-zero, and second b is zero (or negative), corresponding to a
	 * pop that reaches the end of this element or spans more than one
	 * element.
	 *
	 * If b is non-zero AND there is no space, allocate a page and compact
	 * the a and b regions into it. If there is space, shift the ring to
	 * the right, freeing the next element in the ring to hold b, leaving
	 * a untouched except for the reduced length.
	 */
	if (start != offset) {
		struct scatterlist *nsge, *sge = sk_msg_elem(msg, i);
		/* 'a' is relative to this element, not to the whole message. */
		int a = start - offset;
		int b = sge->length - pop - a;

		sk_msg_iter_var_next(i);

		if (b > 0) {
			if (space) {
				sge->length = a;
				sk_msg_shift_right(msg, i);
				nsge = sk_msg_elem(msg, i);
				/* Both elements now reference the same page. */
				get_page(sg_page(sge));
				sg_set_page(nsge,
					    sg_page(sge),
					    b, sge->offset + pop + a);
			} else {
				struct page *page, *orig;
				u8 *to, *from;

				page = alloc_pages(__GFP_NOWARN |
						   __GFP_COMP   | GFP_ATOMIC,
						   get_order(a + b));
				if (unlikely(!page))
					return -ENOMEM;

				orig = sg_page(sge);
				from = sg_virt(sge);
				to = page_address(page);
				memcpy(to, from, a);
				memcpy(to + a, from + a + pop, b);
				sg_set_page(sge, page, a + b, 0);
				put_page(orig);
			}
			pop = 0;
		} else {
			/* Account for the popped tail before truncating the
			 * element; doing it the other way round would always
			 * subtract zero.
			 */
			pop -= (sge->length - a);
			sge->length = a;
		}
	}

	/* From above the current layout _must_ be as follows,
	 *
	 * -| offset
	 * -| start
	 *
	 *  |---- pop ---|---------------- b ------------|
	 *  |____________________________________________| length
	 *
	 * Offset and start of the current msg elem are equal because in the
	 * previous case we handled offset != start and either consumed the
	 * entire element and advanced to the next element OR pop == 0.
	 *
	 * Two cases to handle here: first, pop is less than the length,
	 * leaving some remainder b above. Simply adjust the element's layout
	 * in this case. Or pop >= length of the element so that b = 0. In
	 * this case remove the element and continue with the one that takes
	 * its place, decrementing pop.
	 */
	while (pop) {
		struct scatterlist *sge = sk_msg_elem(msg, i);

		if (pop < sge->length) {
			sge->length -= pop;
			sge->offset += pop;
			pop = 0;
		} else {
			pop -= sge->length;
			/* The element is going away; drop its page reference
			 * before its slot is overwritten.
			 */
			put_page(sg_page(sge));
			/* After the shift, slot i already holds the next
			 * element, so i must not be advanced here.
			 */
			sk_msg_shift_left(msg, i);
		}
	}

	sk_mem_uncharge(msg->sk, len - pop);
	msg->sg.size -= (len - pop);
	sk_msg_compute_data_pointers(msg);
	return 0;
}
2585+
2586+
static const struct bpf_func_proto bpf_msg_pop_data_proto = {
2587+
.func = bpf_msg_pop_data,
2588+
.gpl_only = false,
2589+
.ret_type = RET_INTEGER,
2590+
.arg1_type = ARG_PTR_TO_CTX,
2591+
.arg2_type = ARG_ANYTHING,
2592+
.arg3_type = ARG_ANYTHING,
2593+
.arg4_type = ARG_ANYTHING,
2594+
};
2595+
24282596
BPF_CALL_1(bpf_get_cgroup_classid, const struct sk_buff *, skb)
24292597
{
24302598
return task_get_classid(skb);
@@ -5098,6 +5266,7 @@ bool bpf_helper_changes_pkt_data(void *func)
50985266
func == bpf_xdp_adjust_meta ||
50995267
func == bpf_msg_pull_data ||
51005268
func == bpf_msg_push_data ||
5269+
func == bpf_msg_pop_data ||
51015270
func == bpf_xdp_adjust_tail ||
51025271
#if IS_ENABLED(CONFIG_IPV6_SEG6_BPF)
51035272
func == bpf_lwt_seg6_store_bytes ||
@@ -5394,6 +5563,8 @@ sk_msg_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
53945563
return &bpf_msg_pull_data_proto;
53955564
case BPF_FUNC_msg_push_data:
53965565
return &bpf_msg_push_data_proto;
5566+
case BPF_FUNC_msg_pop_data:
5567+
return &bpf_msg_pop_data_proto;
53975568
default:
53985569
return bpf_base_func_proto(func_id);
53995570
}

net/ipv4/tcp_bpf.c

+14-3
Original file line numberDiff line numberDiff line change
@@ -289,12 +289,23 @@ static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
289289
{
290290
bool cork = false, enospc = msg->sg.start == msg->sg.end;
291291
struct sock *sk_redir;
292-
u32 tosend;
292+
u32 tosend, delta = 0;
293293
int ret;
294294

295295
more_data:
296-
if (psock->eval == __SK_NONE)
296+
if (psock->eval == __SK_NONE) {
297+
/* Track delta in msg size to add/subtract it on SK_DROP from
298+
* returned to user copied size. This ensures user doesn't
299+
* get a positive return code with msg_pop_data and SK_DROP
300+
* verdict.
301+
*/
302+
delta = msg->sg.size;
297303
psock->eval = sk_psock_msg_verdict(sk, psock, msg);
304+
if (msg->sg.size < delta)
305+
delta -= msg->sg.size;
306+
else
307+
delta = 0;
308+
}
298309

299310
if (msg->cork_bytes &&
300311
msg->cork_bytes > msg->sg.size && !enospc) {
@@ -350,7 +361,7 @@ static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
350361
default:
351362
sk_msg_free_partial(sk, msg, tosend);
352363
sk_msg_apply_bytes(psock, tosend);
353-
*copied -= tosend;
364+
*copied -= (tosend + delta);
354365
return -EACCES;
355366
}
356367

net/tls/tls_sw.c

+9-2
Original file line numberDiff line numberDiff line change
@@ -687,15 +687,22 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
687687
struct sock *sk_redir;
688688
struct tls_rec *rec;
689689
int err = 0, send;
690+
u32 delta = 0;
690691
bool enospc;
691692

692693
psock = sk_psock_get(sk);
693694
if (!psock)
694695
return tls_push_record(sk, flags, record_type);
695696
more_data:
696697
enospc = sk_msg_full(msg);
697-
if (psock->eval == __SK_NONE)
698+
if (psock->eval == __SK_NONE) {
	/* Record how many bytes the verdict program removed from the msg
	 * so the drop path can subtract them from the copied count.
	 * delta is only non-zero when the message shrank; if it grew or
	 * stayed the same there is nothing to subtract.
	 */
	delta = msg->sg.size;
	psock->eval = sk_psock_msg_verdict(sk, psock, msg);
	if (msg->sg.size < delta)
		delta -= msg->sg.size;
	else
		delta = 0;
}
699706
if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
700707
!enospc && !full_record) {
701708
err = -ENOSPC;
@@ -743,7 +750,7 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
743750
msg->apply_bytes -= send;
744751
if (msg->sg.size == 0)
745752
tls_free_open_rec(sk);
746-
*copied -= send;
753+
*copied -= (send + delta);
747754
err = -EACCES;
748755
}
749756

0 commit comments

Comments
 (0)