On Wed, Oct 11, 2023 at 11:27:23PM +0800, Leon Hwang wrote: > From commit ebf7d1f508a73871 ("bpf, x64: rework pro/epilogue and tailcall > handling in JIT"), the tailcall on x64 works better than before. > > From commit e411901c0b775a3a ("bpf: allow for tailcalls in BPF subprograms > for x64 JIT"), tailcall is able to run in BPF subprograms on x64. > > How about: > > 1. More than 1 subprograms are called in a bpf program. > 2. The tailcalls in the subprograms call the bpf program. > > Because of missing tail_call_cnt back-propagation, a tailcall hierarchy > comes up. And MAX_TAIL_CALL_CNT limit does not work for this case. > > As we know, in tail call context, the tail_call_cnt propagates by stack > and rax register between BPF subprograms. So, propagating tail_call_cnt > pointer by stack and rax register makes tail_call_cnt as like a global > variable, in order to make MAX_TAIL_CALL_CNT limit works for tailcall > hierarchy cases. > > Before jumping to other bpf prog, load tail_call_cnt from the pointer > and then compare with MAX_TAIL_CALL_CNT. Finally, increment > tail_call_cnt by its pointer. > > But, where does tail_call_cnt store? > > It stores on the stack of bpf prog's caller, like > > | STACK | > | | > | rip | > +->| tcc | > | | rip | > | | rbp | > | +---------+ RBP > | | | > | | | > | | | > +--| tcc_ptr | > | rbx | > +---------+ RSP > > And tcc_ptr is unnecessary to be popped from stack at the epilogue of bpf > prog, like the way of commit d207929d97ea028f ("bpf, x64: Drop "pop %rcx" > instruction on BPF JIT epilogue"). > > Why not back-propagate tail_call_cnt? > > It's because it's vulnerable to back-propagate it. It's unable to work > well with the following case. > > int prog1(); > int prog2(); > > prog1 is tail caller, and prog2 is tail callee. If we do back-propagate > tail_call_cnt at the epilogue of prog2, can prog2 run standalone at the > same time? The answer is NO. Otherwise, there will be a register to be > polluted, which will make kernel crash. 
> > Fixes: ebf7d1f508a7 ("bpf, x64: rework pro/epilogue and tailcall handling in JIT") > Fixes: e411901c0b77 ("bpf: allow for tailcalls in BPF subprograms for x64 JIT") > Signed-off-by: Leon Hwang <hffilwlqm@xxxxxxxxx> > --- > arch/x86/net/bpf_jit_comp.c | 40 ++++++++++++++++++++++--------------- > 1 file changed, 24 insertions(+), 16 deletions(-) > > diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c > index c2a0465d37da4..36631129cc800 100644 > --- a/arch/x86/net/bpf_jit_comp.c > +++ b/arch/x86/net/bpf_jit_comp.c > @@ -256,7 +256,7 @@ struct jit_context { > /* Number of bytes emit_patch() needs to generate instructions */ > #define X86_PATCH_SIZE 5 > /* Number of bytes that will be skipped on tailcall */ > -#define X86_TAIL_CALL_OFFSET (11 + ENDBR_INSN_SIZE) > +#define X86_TAIL_CALL_OFFSET (22 + ENDBR_INSN_SIZE) > > static void push_r12(u8 **pprog) > { > @@ -340,14 +340,21 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf, > EMIT_ENDBR(); > emit_nops(&prog, X86_PATCH_SIZE); > if (!ebpf_from_cbpf) { > - if (tail_call_reachable && !is_subprog) > + if (tail_call_reachable && !is_subprog) { > /* When it's the entry of the whole tailcall context, > * zeroing rax means initialising tail_call_cnt. > */ > - EMIT2(0x31, 0xC0); /* xor eax, eax */ > - else > - /* Keep the same instruction layout. */ > - EMIT2(0x66, 0x90); /* nop2 */ > + EMIT2(0x31, 0xC0); /* xor eax, eax */ > + EMIT1(0x50); /* push rax */ > + /* Make rax as ptr that points to tail_call_cnt. */ > + EMIT3(0x48, 0x89, 0xE0); /* mov rax, rsp */ > + EMIT1_off32(0xE8, 2); /* call main prog */ > + EMIT1(0x59); /* pop rcx, get rid of tail_call_cnt */ > + EMIT1(0xC3); /* ret */ > + } else { > + /* Keep the same instruction size. */ > + emit_nops(&prog, 13); > + } At first sight it seemed to me too invasive but after trying out few other approaches in the end it is elegant. 
I wanted to avoid the somewhat puzzling call insn in the prologue by using the following prologue layout (this is based on the entry prog from tailcall_bpf2bpf3.c, which was the first one to break): ffffffffc0012cb4: 0f 1f 44 00 00 nopl 0x0(%rax,%rax,1) ffffffffc0012cb9: 55 push %rbp ffffffffc0012cba: 48 89 e5 mov %rsp,%rbp ffffffffc0012cbd: 48 83 ec 10 sub $0x10,%rsp ffffffffc0012cc1: 48 89 65 f8 mov %rsp,-0x8(%rbp) ffffffffc0012cc5: 48 c7 04 24 00 00 00 movq $0x0,(%rsp) ffffffffc0012ccc: 00 ffffffffc0012ccd: 48 8b 45 f8 mov -0x8(%rbp),%rax ffffffffc0012cd1: 50 push %rax ffffffffc0012cd2: 48 81 ec 80 00 00 00 sub $0x80,%rsp So we would have 16 hidden bytes on the stack at the *beginning* of the entry stack frame. The first thing right after rbp would be the tcc pointer, so referring to it wouldn't require us to take the stack depth into account. But then, if we follow with the rest of the insns: ffffffffc0012cd9: 31 f6 xor %esi,%esi ffffffffc0012cdb: 48 89 75 f8 mov %rsi,-0x8(%rbp) // BUG, overwrite of tcc ptr ffffffffc0012cdf: 48 89 75 f0 mov %rsi,-0x10(%rbp) ffffffffc0012ce3: 48 89 75 e8 mov %rsi,-0x18(%rbp) ffffffffc0012ce7: 48 89 75 e0 mov %rsi,-0x20(%rbp) ffffffffc0012ceb: 48 89 75 d8 mov %rsi,-0x28(%rbp) ffffffffc0012cef: 48 89 75 d0 mov %rsi,-0x30(%rbp) ffffffffc0012cf3: 48 89 75 c8 mov %rsi,-0x38(%rbp) ffffffffc0012cf7: 48 89 75 c0 mov %rsi,-0x40(%rbp) ffffffffc0012cfb: 48 89 75 b8 mov %rsi,-0x48(%rbp) ffffffffc0012cff: 48 89 75 b0 mov %rsi,-0x50(%rbp) ffffffffc0012d03: 48 89 75 a8 mov %rsi,-0x58(%rbp) ffffffffc0012d07: 48 89 75 a0 mov %rsi,-0x60(%rbp) ffffffffc0012d0b: 48 89 75 98 mov %rsi,-0x68(%rbp) ffffffffc0012d0f: 48 89 75 90 mov %rsi,-0x70(%rbp) ffffffffc0012d13: 48 89 75 88 mov %rsi,-0x78(%rbp) ffffffffc0012d17: 48 89 75 80 mov %rsi,-0x80(%rbp) ffffffffc0012d1b: 48 0f b6 75 ff movzbq -0x1(%rbp),%rsi ffffffffc0012d20: 40 88 75 ff mov %sil,-0x1(%rbp) ffffffffc0012d24: 48 8b 85 f8 ff ff ff mov -0x8(%rbp),%rax ffffffffc0012d2b: e8 30 00 00 00 call 0xffffffffc0012d60 ffffffffc0012d30: c9 leave 
ffffffffc0012d31: c3 ret So even though the intent would be more obvious when looking at the prologue, this would require us to patch the instructions that make use of R10/stack, which in the end would be way more invasive. After all, for x86 JIT code: Reviewed-by: Maciej Fijalkowski <maciej.fijalkowski@xxxxxxxxx> but a better commit message is a must here. Thanks! > } > /* Exception callback receives FP as third parameter */ > if (is_exception_cb) { > @@ -373,6 +380,7 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf, > if (stack_depth) > EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8)); > if (tail_call_reachable) > + /* Here, rax is tail_call_cnt_ptr. */ > EMIT1(0x50); /* push rax */ > *pprog = prog; > } > @@ -528,7 +536,7 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, > u32 stack_depth, u8 *ip, > struct jit_context *ctx) > { > - int tcc_off = -4 - round_up(stack_depth, 8); > + int tcc_ptr_off = -8 - round_up(stack_depth, 8); > u8 *prog = *pprog, *start = *pprog; > int offset; > > @@ -553,13 +561,12 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, > * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) > * goto out; > */ > - EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */ > - EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ > + EMIT3_off32(0x48, 0x8B, 0x85, tcc_ptr_off); /* mov rax, qword ptr [rbp - tcc_ptr_off] */ > + EMIT3(0x83, 0x38, MAX_TAIL_CALL_CNT); /* cmp dword ptr [rax], MAX_TAIL_CALL_CNT */ > > offset = ctx->tail_call_indirect_label - (prog + 2 - start); > EMIT2(X86_JAE, offset); /* jae out */ > - EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ > - EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */ > + EMIT3(0x83, 0x00, 0x01); /* add dword ptr [rax], 1 */ > > /* prog = array->ptrs[index]; */ > EMIT4_off32(0x48, 0x8B, 0x8C, 0xD6, /* mov rcx, [rsi + rdx * 8 + offsetof(...)] */ > @@ -581,6 
+588,7 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, > pop_callee_regs(&prog, callee_regs_used); > } > > + /* pop tail_call_cnt_ptr */ > EMIT1(0x58); /* pop rax */ > if (stack_depth) > EMIT3_off32(0x48, 0x81, 0xC4, /* add rsp, sd */ > @@ -609,7 +617,7 @@ static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog, > bool *callee_regs_used, u32 stack_depth, > struct jit_context *ctx) > { > - int tcc_off = -4 - round_up(stack_depth, 8); > + int tcc_ptr_off = -8 - round_up(stack_depth, 8); > u8 *prog = *pprog, *start = *pprog; > int offset; > > @@ -617,13 +625,12 @@ static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog, > * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) > * goto out; > */ > - EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */ > - EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ > + EMIT3_off32(0x48, 0x8B, 0x85, tcc_ptr_off); /* mov rax, qword ptr [rbp - tcc_ptr_off] */ > + EMIT3(0x83, 0x38, MAX_TAIL_CALL_CNT); /* cmp dword ptr [rax], MAX_TAIL_CALL_CNT */ > > offset = ctx->tail_call_direct_label - (prog + 2 - start); > EMIT2(X86_JAE, offset); /* jae out */ > - EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ > - EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */ > + EMIT3(0x83, 0x00, 0x01); /* add dword ptr [rax], 1 */ > > poke->tailcall_bypass = ip + (prog - start); > poke->adj_off = X86_TAIL_CALL_OFFSET; > @@ -640,6 +647,7 @@ static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog, > pop_callee_regs(&prog, callee_regs_used); > } > > + /* pop tail_call_cnt_ptr */ > EMIT1(0x58); /* pop rax */ > if (stack_depth) > EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8)); > -- > 2.41.0 > >