> * Solution changes from percpu tail_call_cnt to tail_call_cnt at task_struct. Please remind us what was wrong with per-cpu approach? Also notice we have pseudo per-cpu bpf insns now, so things might be easier today. On Tue, Apr 2, 2024 at 8:27 AM Leon Hwang <hffilwlqm@xxxxxxxxx> wrote: > > From commit ebf7d1f508a73871 ("bpf, x64: rework pro/epilogue and tailcall > handling in JIT"), the tailcall on x64 works better than before. ... > > As a result, the previous tailcall way can be removed totally, including > > 1. "push rax" at prologue. > 2. load tail_call_cnt to rax before calling function. > 3. "pop rax" before jumping to tailcallee when tailcall. > 4. "push rax" and load tail_call_cnt to rax at trampoline. Please trim it. It looks like you've been copy pasting it and it's no longer accurate. Short description of the problem will do. > Fixes: ebf7d1f508a7 ("bpf, x64: rework pro/epilogue and tailcall handling in JIT") > Fixes: e411901c0b77 ("bpf: allow for tailcalls in BPF subprograms for x64 JIT") > Signed-off-by: Leon Hwang <hffilwlqm@xxxxxxxxx> > --- > arch/x86/net/bpf_jit_comp.c | 137 +++++++++++++++++++++--------------- > 1 file changed, 81 insertions(+), 56 deletions(-) > > diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c > index 3b639d6f2f54d..cd06e02e83b64 100644 > --- a/arch/x86/net/bpf_jit_comp.c > +++ b/arch/x86/net/bpf_jit_comp.c > @@ -11,6 +11,7 @@ > #include <linux/bpf.h> > #include <linux/memory.h> > #include <linux/sort.h> > +#include <linux/sched.h> > #include <asm/extable.h> > #include <asm/ftrace.h> > #include <asm/set_memory.h> > @@ -18,6 +19,8 @@ > #include <asm/text-patching.h> > #include <asm/unwind.h> > #include <asm/cfi.h> > +#include <asm/current.h> > +#include <asm/percpu.h> > > static bool all_callee_regs_used[4] = {true, true, true, true}; > > @@ -273,7 +276,7 @@ struct jit_context { > /* Number of bytes emit_patch() needs to generate instructions */ > #define X86_PATCH_SIZE 5 > /* Number of bytes that will be 
skipped on tailcall */ > -#define X86_TAIL_CALL_OFFSET (11 + ENDBR_INSN_SIZE) > +#define X86_TAIL_CALL_OFFSET (14 + ENDBR_INSN_SIZE) > > static void push_r12(u8 **pprog) > { > @@ -403,6 +406,9 @@ static void emit_cfi(u8 **pprog, u32 hash) > *pprog = prog; > } > > +static int emit_call(u8 **pprog, void *func, void *ip); > +static __used void bpf_tail_call_cnt_init(void); > + > /* > * Emit x86-64 prologue code for BPF program. > * bpf_tail_call helper will skip the first X86_TAIL_CALL_OFFSET bytes > @@ -410,9 +416,9 @@ static void emit_cfi(u8 **pprog, u32 hash) > */ > static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf, > bool tail_call_reachable, bool is_subprog, > - bool is_exception_cb) > + bool is_exception_cb, u8 *ip) > { > - u8 *prog = *pprog; > + u8 *prog = *pprog, *start = *pprog; > > emit_cfi(&prog, is_subprog ? cfi_bpf_subprog_hash : cfi_bpf_hash); > /* BPF trampoline can be made to work without these nops, > @@ -421,13 +427,14 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf, > emit_nops(&prog, X86_PATCH_SIZE); > if (!ebpf_from_cbpf) { > if (tail_call_reachable && !is_subprog) > - /* When it's the entry of the whole tailcall context, > - * zeroing rax means initialising tail_call_cnt. > + /* Call bpf_tail_call_cnt_init to initilise > + * tail_call_cnt. > */ > - EMIT2(0x31, 0xC0); /* xor eax, eax */ > + emit_call(&prog, bpf_tail_call_cnt_init, > + ip + (prog - start)); You're repeating the same bug we discussed before. There is nothing in bpf_tail_call_cnt_init() that prevents the compiler from clobbering rdi, rsi, ... (all caller-saved under the x86-64 calling convention). From the compiler's point of view, bpf_tail_call_cnt_init() is a normal function, so it is allowed to use those regs. You must have been lucky that CI is not showing crashes. > else > /* Keep the same instruction layout.
*/ > - EMIT2(0x66, 0x90); /* nop2 */ > + emit_nops(&prog, X86_PATCH_SIZE); > } > /* Exception callback receives FP as third parameter */ > if (is_exception_cb) { > @@ -452,8 +459,6 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf, > /* sub rsp, rounded_stack_depth */ > if (stack_depth) > EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8)); > - if (tail_call_reachable) > - EMIT1(0x50); /* push rax */ > *pprog = prog; > } > > @@ -589,13 +594,61 @@ static void emit_return(u8 **pprog, u8 *ip) > *pprog = prog; > } > > +static __used void bpf_tail_call_cnt_init(void) > +{ > + /* The following asm equals to > + * > + * u32 *tcc_ptr = &current->bpf_tail_call_cnt; > + * > + * *tcc_ptr = 0; > + */ > + > + asm volatile ( > + "addq " __percpu_arg(0) ", %1\n\t" > + "addq %2, %1\n\t" > + "movq (%1), %1\n\t" > + "addq %3, %1\n\t" > + "movl $0, (%1)\n\t" > + : > + : "m" (this_cpu_off), "r" (&pcpu_hot), > + "i" (offsetof(struct pcpu_hot, current_task)), > + "i" (offsetof(struct task_struct, bpf_tail_call_cnt)) > + ); > +} > + > +static __used u32 *bpf_tail_call_cnt_ptr(void) > +{ > + u32 *tcc_ptr; > + > + /* The following asm equals to > + * > + * u32 *tcc_ptr = &current->bpf_tail_call_cnt; > + * > + * return tcc_ptr; > + */ > + > + asm volatile ( > + "addq " __percpu_arg(1) ", %2\n\t" > + "addq %3, %2\n\t" > + "movq (%2), %2\n\t" > + "addq %4, %2\n\t" > + "movq %2, %0\n\t" > + : "=r" (tcc_ptr) > + : "m" (this_cpu_off), "r" (&pcpu_hot), > + "i" (offsetof(struct pcpu_hot, current_task)), > + "i" (offsetof(struct task_struct, bpf_tail_call_cnt)) > + ); > + > + return tcc_ptr; > +} > + > /* > * Generate the following code: > * > * ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ...
> * if (index >= array->map.max_entries) > * goto out; > - * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) > + * if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT) > * goto out; > * prog = array->ptrs[index]; > * if (prog == NULL) > @@ -608,7 +661,6 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, > u32 stack_depth, u8 *ip, > struct jit_context *ctx) > { > - int tcc_off = -4 - round_up(stack_depth, 8); > u8 *prog = *pprog, *start = *pprog; > int offset; > > @@ -630,16 +682,16 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, > EMIT2(X86_JBE, offset); /* jbe out */ > > /* > - * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) > + * if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT) > * goto out; > */ > - EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */ > - EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ > + /* call bpf_tail_call_cnt_ptr */ > + emit_call(&prog, bpf_tail_call_cnt_ptr, ip + (prog - start)); same issue. > + EMIT3(0x83, 0x38, MAX_TAIL_CALL_CNT); /* cmp dword ptr [rax], MAX_TAIL_CALL_CNT */ > > offset = ctx->tail_call_indirect_label - (prog + 2 - start); > EMIT2(X86_JAE, offset); /* jae out */ > - EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ > - EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */ > + EMIT2(0xFF, 0x00); /* inc dword ptr [rax] */ > > /* prog = array->ptrs[index]; */ > EMIT4_off32(0x48, 0x8B, 0x8C, 0xD6, /* mov rcx, [rsi + rdx * 8 + offsetof(...)] */ > @@ -663,7 +715,6 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, > pop_r12(&prog); > } > > - EMIT1(0x58); /* pop rax */ > if (stack_depth) > EMIT3_off32(0x48, 0x81, 0xC4, /* add rsp, sd */ > round_up(stack_depth, 8)); > @@ -691,21 +742,20 @@ static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog, > bool *callee_regs_used, u32 stack_depth, > struct jit_context *ctx) > { > - int tcc_off = -4 - round_up(stack_depth, 8); > u8 *prog = *pprog, *start = *pprog; > int offset; > > /* > - * if 
(tail_call_cnt++ >= MAX_TAIL_CALL_CNT) > + * if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT) > * goto out; > */ > - EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */ > - EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ > + /* call bpf_tail_call_cnt_ptr */ > + emit_call(&prog, bpf_tail_call_cnt_ptr, ip); and here as well. pw-bot: cr