On Fri, Sep 11, 2020 at 08:59:27PM +0200, Maciej Fijalkowski wrote:
> On Thu, Sep 03, 2020 at 12:51:14PM -0700, Alexei Starovoitov wrote:
> > On Wed, Sep 02, 2020 at 10:08:15PM +0200, Maciej Fijalkowski wrote:
> > > diff --git a/tools/testing/selftests/bpf/progs/tailcall6.c b/tools/testing/selftests/bpf/progs/tailcall6.c
> > > new file mode 100644
> > > index 000000000000..e72ca5869b58
> > > --- /dev/null
> > > +++ b/tools/testing/selftests/bpf/progs/tailcall6.c
> > > @@ -0,0 +1,38 @@
> > > +// SPDX-License-Identifier: GPL-2.0
> > > +#include <linux/bpf.h>
> > > +#include <bpf/bpf_helpers.h>
> > > +
> > > +struct {
> > > +	__uint(type, BPF_MAP_TYPE_PROG_ARRAY);
> > > +	__uint(max_entries, 2);
> > > +	__uint(key_size, sizeof(__u32));
> > > +	__uint(value_size, sizeof(__u32));
> > > +} jmp_table SEC(".maps");
> > > +
> > > +#define TAIL_FUNC(x)				\
> > > +	SEC("classifier/" #x)			\
> > > +	int bpf_func_##x(struct __sk_buff *skb)	\
> > > +	{					\
> > > +		return x;			\
> > > +	}
> > > +TAIL_FUNC(0)
> > > +TAIL_FUNC(1)
> > > +
> > > +static __attribute__ ((noinline))
> > > +int subprog_tail(struct __sk_buff *skb)
> > > +{
> > > +	bpf_tail_call(skb, &jmp_table, 0);
> > > +
> > > +	return skb->len * 2;
> > > +}
> > > +
> > > +SEC("classifier")
> > > +int entry(struct __sk_buff *skb)
> > > +{
> > > +	bpf_tail_call(skb, &jmp_table, 1);
> > > +
> > > +	return subprog_tail(skb);
> > > +}
> >
> > Could you add a few more tests to exercise the new feature more thoroughly?
> > Something like tailcall3.c that checks the 32 limit, but doing the tail_call
> > from a subprog. And another test that consumes a non-trivial amount of stack
> > in each function. Adding 'volatile char arr[128] = {};' would do the trick.
>
> Yet another prolonged silence from my side, but not without a reason -
> this request opened up a Pandora's box.

Great catch, and thanks to our development practices! As a community we
should remember this lesson and request selftests more often than not.

> First thing that came out when I added the global variable to act as a
> counter in the tailcall3-like subprog-based test was the fact that when
> the patching happens, we need to update the index of the tailcall insn that
> we store within the poke descriptor. Due to the patching and the insn index
> not being adjusted, the poke descriptor was not propagated to the subprogram
> and the JIT started to fail.
>
> It's a rather obvious change, so I won't post it here to reduce the chaos
> in this response, but I simply taught bpf_patch_insn_data() to go over the
> poke descriptors and update the insn_idx by the given len. Will include it
> in the next revision.

+1
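For the record, I'd expect that adjustment to look roughly like the below
(untested sketch; the function name is made up, and I'm assuming the
descriptor stores the tail call insn index in an insn_idx field as in your
series, with one insn at 'off' being replaced by 'len' insns the way the
bpf_patch_insn_data() callers do it today):

	static void adjust_poke_descs(struct bpf_prog *prog, u32 off, u32 len)
	{
		struct bpf_jit_poke_descriptor *tab = prog->aux->poke_tab;
		u32 i;

		if (len == 1)
			return;

		for (i = 0; i < prog->aux->size_poke_tab; i++) {
			/* descriptors before the patched insn keep their index */
			if (tab[i].insn_idx <= off)
				continue;
			/* one insn at 'off' became 'len' insns */
			tab[i].insn_idx += len - 1;
		}
	}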
> Now onto serious stuff that I would like to discuss. Turns out that for
> tailcall3-like selftest:
>
> // SPDX-License-Identifier: GPL-2.0
> #include <linux/bpf.h>
> #include <bpf/bpf_helpers.h>
>
> struct {
> 	__uint(type, BPF_MAP_TYPE_PROG_ARRAY);
> 	__uint(max_entries, 1);
> 	__uint(key_size, sizeof(__u32));
> 	__uint(value_size, sizeof(__u32));
> } jmp_table SEC(".maps");
>
> static __attribute__ ((noinline))
> int subprog_tail(struct __sk_buff *skb)
> {
> 	bpf_tail_call(skb, &jmp_table, 0);
> 	return 1;
> }
>
> SEC("classifier/0")
> int bpf_func_0(struct __sk_buff *skb)
> {
> 	return subprog_tail(skb);
> }
>
> SEC("classifier")
> int entry(struct __sk_buff *skb)
> {
> 	bpf_tail_call(skb, &jmp_table, 0);
>
> 	return 0;
> }
>
> char __license[] SEC("license") = "GPL";
> int _version SEC("version") = 1;
>
> following asm was generated:
>
> entry:
> ffffffffa0ca0c40 <load4+0xca0c40>:
> ffffffffa0ca0c40:	0f 1f 44 00 00       	nopl   0x0(%rax,%rax,1)
> ffffffffa0ca0c45:	31 c0                	xor    %eax,%eax
> ffffffffa0ca0c47:	55                   	push   %rbp
> ffffffffa0ca0c48:	48 89 e5             	mov    %rsp,%rbp
> ffffffffa0ca0c4b:	48 81 ec 00 00 00 00 	sub    $0x0,%rsp
> ffffffffa0ca0c52:	50                   	push   %rax
> ffffffffa0ca0c53:	48 be 00 6c b1 c1 81 	movabs $0xffff8881c1b16c00,%rsi
> ffffffffa0ca0c5a:	88 ff ff
> ffffffffa0ca0c5d:	31 d2                	xor    %edx,%edx
> ffffffffa0ca0c5f:	8b 85 fc ff ff ff    	mov    -0x4(%rbp),%eax
> ffffffffa0ca0c65:	83 f8 20             	cmp    $0x20,%eax
> ffffffffa0ca0c68:	77 1b                	ja     0xffffffffa0ca0c85
> ffffffffa0ca0c6a:	83 c0 01             	add    $0x1,%eax
> ffffffffa0ca0c6d:	89 85 fc ff ff ff    	mov    %eax,-0x4(%rbp)
> ffffffffa0ca0c73:	0f 1f 44 00 00       	nopl   0x0(%rax,%rax,1)
> ffffffffa0ca0c78:	58                   	pop    %rax
> ffffffffa0ca0c79:	48 81 c4 00 00 00 00 	add    $0x0,%rsp
> ffffffffa0ca0c80:	e9 d2 b6 ff ff       	jmpq   0xffffffffa0c9c357
> ffffffffa0ca0c85:	31 c0                	xor    %eax,%eax
> ffffffffa0ca0c87:	59                   	pop    %rcx
> ffffffffa0ca0c88:	c9                   	leaveq
> ffffffffa0ca0c89:	c3                   	retq
>
> func0:
> ffffffffa0c9c34c <load4+0xc9c34c>:
> ffffffffa0c9c34c:	0f 1f 44 00 00       	nopl   0x0(%rax,%rax,1)
> ffffffffa0c9c351:	66 90                	xchg   %ax,%ax
> ffffffffa0c9c353:	55                   	push   %rbp
> ffffffffa0c9c354:	48 89 e5             	mov    %rsp,%rbp
> ffffffffa0c9c357:	48 81 ec 00 00 00 00 	sub    $0x0,%rsp
> ffffffffa0c9c35e:	e8 b1 20 00 00       	callq  0xffffffffa0c9e414
> ffffffffa0c9c363:	b8 01 00 00 00       	mov    $0x1,%eax
> ffffffffa0c9c368:	c9                   	leaveq
> ffffffffa0c9c369:	c3                   	retq
>
> subprog_tail:
> ffffffffa0c9e414:	0f 1f 44 00 00       	nopl   0x0(%rax,%rax,1)
> ffffffffa0c9e419:	31 c0                	xor    %eax,%eax
> ffffffffa0c9e41b:	55                   	push   %rbp
> ffffffffa0c9e41c:	48 89 e5             	mov    %rsp,%rbp
> ffffffffa0c9e41f:	48 81 ec 00 00 00 00 	sub    $0x0,%rsp
> ffffffffa0c9e426:	50                   	push   %rax
> ffffffffa0c9e427:	48 be 00 6c b1 c1 81 	movabs $0xffff8881c1b16c00,%rsi
> ffffffffa0c9e42e:	88 ff ff
> ffffffffa0c9e431:	31 d2                	xor    %edx,%edx
> ffffffffa0c9e433:	8b 85 fc ff ff ff    	mov    -0x4(%rbp),%eax
> ffffffffa0c9e439:	83 f8 20             	cmp    $0x20,%eax
> ffffffffa0c9e43c:	77 1b                	ja     0xffffffffa0c9e459
> ffffffffa0c9e43e:	83 c0 01             	add    $0x1,%eax
> ffffffffa0c9e441:	89 85 fc ff ff ff    	mov    %eax,-0x4(%rbp)
> ffffffffa0c9e447:	0f 1f 44 00 00       	nopl   0x0(%rax,%rax,1)
> ffffffffa0c9e44c:	58                   	pop    %rax
> ffffffffa0c9e44d:	48 81 c4 00 00 00 00 	add    $0x0,%rsp
> ffffffffa0c9e454:	e9 fe de ff ff       	jmpq   0xffffffffa0c9c357
> ffffffffa0c9e459:	59                   	pop    %rcx
> ffffffffa0c9e45a:	c9                   	leaveq
> ffffffffa0c9e45b:	c3                   	retq
>
> So this flow was doing:
> entry -> set tailcall counter to 0, bump it by 1, tailcall to func0
> func0 -> call subprog_tail
>          (we are NOT skipping the first 11 bytes of prologue and this
>          subprogram has a tailcall, therefore we clear the counter...)
> subprog -> do the same thing as entry
>
> and then loop forever.
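(Aside, for anyone trying to reproduce this: the user space side that wires
up jmp_table and triggers the loop is roughly the below - untested sketch,
the .o file name is made up and error handling is dropped:)

	#include <bpf/bpf.h>
	#include <bpf/libbpf.h>

	int main(void)
	{
		__u32 key = 0, retval, duration;
		struct bpf_program *prog;
		struct bpf_object *obj;
		int map_fd, main_fd, prog_fd;
		char buf[128] = {};

		/* hypothetical object built from the program above */
		if (bpf_prog_load("tailcall_bpf2bpf.o", BPF_PROG_TYPE_SCHED_CLS,
				  &obj, &prog_fd))
			return 1;

		prog = bpf_object__find_program_by_title(obj, "classifier");
		main_fd = bpf_program__fd(prog);

		prog = bpf_object__find_program_by_title(obj, "classifier/0");
		prog_fd = bpf_program__fd(prog);

		map_fd = bpf_object__find_map_fd_by_name(obj, "jmp_table");
		/* both tail calls use index 0, so they land in bpf_func_0 */
		bpf_map_update_elem(map_fd, &key, &prog_fd, BPF_ANY);

		/* with the counter being cleared in subprog_tail's prologue
		 * this never hits the 32 tail call limit, i.e. it loops forever
		 */
		bpf_prog_test_run(main_fd, 1, buf, sizeof(buf), NULL, NULL,
				  &retval, &duration);
		return 0;
	}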
> This shows that in our current design there is a gap: we are not preserving
> the tailcall counter when bpf2bpf gets mixed with tailcalls.
>
> To address this, the idea is to go through the call chain of bpf2bpf progs
> and look for a tailcall presence throughout the whole chain. If we see a
> single tail call, then each node in this call chain needs to be marked as
> a subprog that can reach the tailcall. We would later feed the JIT with
> this info and:
> - set eax to 0 only when the tailcall is reachable and this is the
>   entry prog
> - if the tailcall is reachable but there is no tailcall in the insns of the
>   currently JITed prog, then push rax anyway, so that it will be possible
>   to propagate it further down the call chain
> - finally, if the tailcall is reachable, then we need to precede the 'call'
>   insn with mov rax, [rsp]

may be 'mov rax, [rbp - 4]' for consistency with the other places where
it's read/written ?

> This way jumping to a subprog results in the tailcall counter sitting in
> rax and we will not clear it since it is a subprog.
>
> I think we can easily mark such progs after we reach the end of the insns
> of the current subprog in check_max_stack_depth().
>
> I'd like to share the dirty diff that I currently have so that it's easier
> to review this approach rather than finding the diff between revisions.
> It also includes the concern Alexei had on 5/7 (hopefully, if I understood
> it right):
>
> ------------------------------------------------------
>
> diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
> index 58b848029e2f..ed03de3ba27b 100644
> --- a/arch/x86/net/bpf_jit_comp.c
> +++ b/arch/x86/net/bpf_jit_comp.c
> @@ -262,7 +262,7 @@ static void pop_callee_regs(u8 **pprog, bool *callee_regs_used)
>   * while jumping to another program
>   */
>  static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
> -			  bool tail_call)
> +			  bool tail_call, bool is_subprog, bool tcr)
>  {
>  	u8 *prog = *pprog;
>  	int cnt = X86_PATCH_SIZE;
> @@ -273,7 +273,7 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
>  	memcpy(prog, ideal_nops[NOP_ATOMIC5], cnt);
>  	prog += cnt;
>  	if (!ebpf_from_cbpf) {
> -		if (tail_call)
> +		if ((tcr || tail_call) && !is_subprog)

please spell it out as 'tail_call_reachable'.
Also probably the 'tail_call' argument is no longer needed.

>  			EMIT2(0x31, 0xC0); /* xor eax, eax */
>  		else
>  			EMIT2(0x66, 0x90); /* nop2 */
> @@ -282,7 +282,7 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
>  	EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */
>  	/* sub rsp, rounded_stack_depth */
>  	EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8));
> -	if (!ebpf_from_cbpf && tail_call)
> +	if ((!ebpf_from_cbpf && tail_call) || tcr)

May be do 'if (tail_call_reachable)' ?
cbpf doesn't have tail_calls, so ebpf_from_cbpf is unnecessary.

>  		EMIT1(0x50);         /* push rax */
>  	*pprog = prog;
>  }
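Putting those two comments together, the prologue I have in mind is roughly
the below (untested sketch; 'tail_call_reachable' and 'is_subprog' would come
from prog->aux as in your diff, and 'tail_call_seen' is dropped entirely):

	static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
				  bool tail_call_reachable, bool is_subprog)
	{
		u8 *prog = *pprog;
		int cnt = X86_PATCH_SIZE;

		/* keep the 5-byte nop that can be patched later */
		memcpy(prog, ideal_nops[NOP_ATOMIC5], cnt);
		prog += cnt;
		if (!ebpf_from_cbpf) {
			if (tail_call_reachable && !is_subprog)
				/* only the entry prog zeroes the counter */
				EMIT2(0x31, 0xC0); /* xor eax, eax */
			else
				EMIT2(0x66, 0x90); /* nop2 */
		}
		EMIT1(0x55);             /* push rbp */
		EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */
		/* sub rsp, rounded_stack_depth */
		EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8));
		if (tail_call_reachable)
			/* every prog in the chain keeps the counter on stack */
			EMIT1(0x50); /* push rax */
		*pprog = prog;
	}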
> @@ -793,7 +793,9 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
>  			 &tail_call_seen);
>  
>  	emit_prologue(&prog, bpf_prog->aux->stack_depth,
> -		      bpf_prog_was_classic(bpf_prog), tail_call_seen);
> +		      bpf_prog_was_classic(bpf_prog), tail_call_seen,
> +		      bpf_prog->aux->is_subprog,
> +		      bpf_prog->aux->tail_call_reachable);
>  	push_callee_regs(&prog, callee_regs_used);
>  	addrs[0] = prog - temp;
>  
> @@ -1232,8 +1234,14 @@ xadd:			if (is_imm8(insn->off))
>  			/* call */
>  		case BPF_JMP | BPF_CALL:
>  			func = (u8 *) __bpf_call_base + imm32;
> -			if (!imm32 || emit_call(&prog, func, image + addrs[i - 1]))
> -				return -EINVAL;
> +			if (bpf_prog->aux->tail_call_reachable) {
> +				EMIT4(0x48, 0x8B, 0x04, 0x24); // mov rax, [rsp]

[rbp-4] ?

> +				if (!imm32 || emit_call(&prog, func, image + addrs[i - 1] + 4))
> +					return -EINVAL;
> +			} else {
> +				if (!imm32 || emit_call(&prog, func, image + addrs[i - 1]))
> +					return -EINVAL;
> +			}
>  			break;
>  
>  		case BPF_JMP | BPF_TAIL_CALL:
> @@ -1429,7 +1437,9 @@ xadd:			if (is_imm8(insn->off))
>  	/* Update cleanup_addr */
>  	ctx->cleanup_addr = proglen;
>  	pop_callee_regs(&prog, callee_regs_used);
> -	if (!bpf_prog_was_classic(bpf_prog) && tail_call_seen)
> +	if ((!bpf_prog_was_classic(bpf_prog) && tail_call_seen) ||
> +	    bpf_prog->aux->tail_call_reachable)

bpf_prog_was_classic() check is redundant?
Just 'if (bpf_prog->aux->tail_call_reachable)' should be enough?

> +
>  		EMIT1(0x59); /* pop rcx, get rid of tail_call_cnt */
>  	EMIT1(0xC9); /* leave */
>  	EMIT1(0xC3); /* ret */
> diff --git a/include/linux/bpf.h b/include/linux/bpf.h
> index 7910b87e4ea2..d41e08fbb85f 100644
> --- a/include/linux/bpf.h
> +++ b/include/linux/bpf.h
> @@ -740,6 +740,8 @@ struct bpf_prog_aux {
>  	bool attach_btf_trace; /* true if attaching to BTF-enabled raw tp */
>  	bool func_proto_unreliable;
>  	bool sleepable;
> +	bool is_subprog;

why? aux->func_idx != 0 would do the same.

> +	bool tail_call_reachable;
>  	enum bpf_tramp_prog_type trampoline_prog_type;
>  	struct bpf_trampoline *trampoline;
>  	struct hlist_node tramp_hlist;
> diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
> index 5026b75db972..fbc964526ba3 100644
> --- a/include/linux/bpf_verifier.h
> +++ b/include/linux/bpf_verifier.h
> @@ -359,6 +359,7 @@ struct bpf_subprog_info {
>  	u32 linfo_idx; /* The idx to the main_prog->aux->linfo */
>  	u16 stack_depth; /* max. stack depth used by this function */
>  	bool has_tail_call;
> +	bool tail_call_reachable;
>  };
>  
>  /* single container for all structs
> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> index deb6bf3d9f5d..3a7ebcdf076e 100644
> --- a/kernel/bpf/verifier.c
> +++ b/kernel/bpf/verifier.c
> @@ -1490,12 +1490,13 @@ static int check_subprogs(struct bpf_verifier_env *env)
>  	for (i = 0; i < insn_cnt; i++) {
>  		u8 code = insn[i].code;
>  
> -		if (insn[i].imm == BPF_FUNC_tail_call)
> -			subprog[cur_subprog].has_tail_call = true;
>  		if (BPF_CLASS(code) != BPF_JMP && BPF_CLASS(code) != BPF_JMP32)
>  			goto next;
>  		if (BPF_OP(code) == BPF_EXIT || BPF_OP(code) == BPF_CALL)
>  			goto next;
> +		if ((code == (BPF_JMP | BPF_CALL)) &&
> +		    insn[i].imm == BPF_FUNC_tail_call)

&& insn->src_reg != BPF_PSEUDO_CALL is still missing.
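i.e. something like this (sketch; and unless I'm misreading the diff, the
check probably wants to stay before the two 'goto next' above, otherwise a
BPF_CALL insn never reaches it):

	if (code == (BPF_JMP | BPF_CALL) &&
	    insn[i].imm == BPF_FUNC_tail_call &&
	    insn[i].src_reg != BPF_PSEUDO_CALL)
		subprog[cur_subprog].has_tail_call = true;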
> +			subprog[cur_subprog].has_tail_call = true;
>  		off = i + insn[i].off + 1;
>  		if (off < subprog_start || off >= subprog_end) {
>  			verbose(env, "jump out of range from insn %d to %d\n", i, off);
> @@ -2983,6 +2984,8 @@ static int check_max_stack_depth(struct bpf_verifier_env *env)
>  	struct bpf_insn *insn = env->prog->insnsi;
>  	int ret_insn[MAX_CALL_FRAMES];
>  	int ret_prog[MAX_CALL_FRAMES];
> +	bool tcr;

how does it work? Shouldn't it be inited to = false; ?

> +	int j;
>  
>  process_func:
>  #if defined(CONFIG_X86_64) && defined(CONFIG_BPF_JIT_ALWAYS_ON)
> @@ -3039,6 +3042,10 @@ static int check_max_stack_depth(struct bpf_verifier_env *env)
>  				  i);
>  			return -EFAULT;
>  		}
> +
> +		if (!tcr && subprog[idx].has_tail_call)
> +			tcr = true;
> +
>  		frame++;
>  		if (frame >= MAX_CALL_FRAMES) {
>  			verbose(env, "the call stack of %d frames is too deep !\n",
> @@ -3047,11 +3054,24 @@ static int check_max_stack_depth(struct bpf_verifier_env *env)
>  		}
>  		goto process_func;
>  	}
> +	/* this means we are at the end of the call chain; if throughout this

In my mind 'end of the call chain' means 'leaf function', so the comment is
a bit misleading to me. Here we're at the end of a subprog; it's not
necessarily the leaf function.

> +	 * whole call chain tailcall has been detected, then each of the
> +	 * subprogs (or their frames) that are currently present on stack need
> +	 * to be marked as tail call reachable subprogs;
> +	 * this info will be utilized by JIT so that we will be preserving the
> +	 * tail call counter throughout bpf2bpf calls combined with tailcalls
> +	 */
> +	if (!tcr)
> +		goto skip;
> +	for (j = 0; j < frame; j++)
> +		subprog[ret_prog[j]].tail_call_reachable = true;
> +skip:

please avoid goto. The extra indent isn't that bad:

	if (tail_call_reachable)
		for (j = 0; j < frame; j++)
			subprog[ret_prog[j]].tail_call_reachable = true;

>  	/* end of for() loop means the last insn of the 'subprog'
>  	 * was reached. Doesn't matter whether it was JA or EXIT
>  	 */
>  	if (frame == 0)
>  		return 0;
> +

no need.

>  	depth -= round_up(max_t(u32, subprog[idx].stack_depth, 1), 32);
>  	frame--;
>  	i = ret_insn[frame];
>
> ------------------------------------------------------
>
> Having this in place preserves the tailcall counter when we mix bpf2bpf
> and tailcalls. I will attach the selftest that has the following call chain:
>
> entry -> entry_subprog -> tailcall0 -> bpf_func0 -> subprog0 ->
> -> tailcall1 -> bpf_func1 -> subprog1 -> tailcall2 -> bpf_func2 ->
> subprog2 [here bump global counter] --------^
>
> We go through the first two tailcalls and start counting from subprog2,
> where the loop begins. At the end of the test I see that the global counter
> gets the value of 31, which is correct.

sounds great.

> For the test that uses a lot of stack across subprogs - I suppose we should
> use up to 256 in total, right? Otherwise we wouldn't even load the prog, so
> the test won't even run.

yep. makes sense to me.

> Kudos to Bjorn for brainstorming on this!

Indeed. It's a pretty cool problem and I think you've come up with a good
solution. Since this tail_call_cnt will now be passed from subprog to
subprog via the "interesting" rax calling convention, we can eventually
retrofit it to count used-stack-so-far. That would be for the case we
discussed earlier (counting stack vs counting calls).
For now the approach you're proposing is good.
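Re the stack-heavy test, in case it helps, here is roughly what I had in mind
(untested sketch; section/map names mirror your example above, and the array
sizes are only an illustration of staying within the 256 bytes you mention):

	// SPDX-License-Identifier: GPL-2.0
	#include <linux/bpf.h>
	#include <bpf/bpf_helpers.h>

	struct {
		__uint(type, BPF_MAP_TYPE_PROG_ARRAY);
		__uint(max_entries, 1);
		__uint(key_size, sizeof(__u32));
		__uint(value_size, sizeof(__u32));
	} jmp_table SEC(".maps");

	int count = 0;

	static __attribute__ ((noinline))
	int subprog_tail(struct __sk_buff *skb)
	{
		/* burn some stack in the subprog that does the tail call */
		volatile char arr[64] = {};

		bpf_tail_call(skb, &jmp_table, 0);

		return skb->len;
	}

	SEC("classifier/0")
	int bpf_func_0(struct __sk_buff *skb)
	{
		volatile char arr[128] = {};

		count++;

		return subprog_tail(skb);
	}

	SEC("classifier")
	int entry(struct __sk_buff *skb)
	{
		/* the whole chain should stay within the 256 bytes of stack
		 * discussed above, otherwise the prog won't load
		 */
		volatile char arr[64] = {};

		bpf_tail_call(skb, &jmp_table, 0);

		return 0;
	}

	char __license[] SEC("license") = "GPL";

The user space part would then check that 'count' stops at the tail call
limit instead of looping, same as in the tailcall3 test.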