Re: [PATCH bpf-next v3 2/3] bpf, x64: Fix tailcall hierarchy

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 




On 2024/4/5 09:03, Alexei Starovoitov wrote:
>>   * Solution changes from percpu tail_call_cnt to tail_call_cnt at task_struct.
> 
> Please remind us what was wrong with per-cpu approach?

There are some cases that the per-cpu approach cannot handle properly.
Especially, on a non-SMP machine, if many bpf progs run with tail calls
simultaneously, the MAX_TAIL_CALL_CNT limit is unable to limit the
tail calls as expected.

> 
> Also notice we have pseudo per-cpu bpf insns now,
> so things might be easier today.

Great to hear that. With pseudo per-cpu bpf insns, it is possible to get
tcc_ptr from task_struct without a function call.

> 
> On Tue, Apr 2, 2024 at 8:27 AM Leon Hwang <hffilwlqm@xxxxxxxxx> wrote:
>>
>> From commit ebf7d1f508a73871 ("bpf, x64: rework pro/epilogue and tailcall
>> handling in JIT"), the tailcall on x64 works better than before.
> ...
>>
>> As a result, the previous tailcall way can be removed totally, including
>>
>> 1. "push rax" at prologue.
>> 2. load tail_call_cnt to rax before calling function.
>> 3. "pop rax" before jumping to tailcallee when tailcall.
>> 4. "push rax" and load tail_call_cnt to rax at trampoline.
> 
> Please trim it.
> It looks like you've been copy pasting it and it's no longer
> accurate.
> Short description of the problem will do.

Got it.

> 
>> Fixes: ebf7d1f508a7 ("bpf, x64: rework pro/epilogue and tailcall handling in JIT")
>> Fixes: e411901c0b77 ("bpf: allow for tailcalls in BPF subprograms for x64 JIT")
>> Signed-off-by: Leon Hwang <hffilwlqm@xxxxxxxxx>
>> ---
>>  arch/x86/net/bpf_jit_comp.c | 137 +++++++++++++++++++++---------------
>>  1 file changed, 81 insertions(+), 56 deletions(-)
>>
>> diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
>> index 3b639d6f2f54d..cd06e02e83b64 100644
>> --- a/arch/x86/net/bpf_jit_comp.c
>> +++ b/arch/x86/net/bpf_jit_comp.c
>> @@ -11,6 +11,7 @@
>>  #include <linux/bpf.h>
>>  #include <linux/memory.h>
>>  #include <linux/sort.h>
>> +#include <linux/sched.h>
>>  #include <asm/extable.h>
>>  #include <asm/ftrace.h>
>>  #include <asm/set_memory.h>
>> @@ -18,6 +19,8 @@
>>  #include <asm/text-patching.h>
>>  #include <asm/unwind.h>
>>  #include <asm/cfi.h>
>> +#include <asm/current.h>
>> +#include <asm/percpu.h>
>>
>>  static bool all_callee_regs_used[4] = {true, true, true, true};
>>
>> @@ -273,7 +276,7 @@ struct jit_context {
>>  /* Number of bytes emit_patch() needs to generate instructions */
>>  #define X86_PATCH_SIZE         5
>>  /* Number of bytes that will be skipped on tailcall */
>> -#define X86_TAIL_CALL_OFFSET   (11 + ENDBR_INSN_SIZE)
>> +#define X86_TAIL_CALL_OFFSET   (14 + ENDBR_INSN_SIZE)
>>
>>  static void push_r12(u8 **pprog)
>>  {
>> @@ -403,6 +406,9 @@ static void emit_cfi(u8 **pprog, u32 hash)
>>         *pprog = prog;
>>  }
>>
>> +static int emit_call(u8 **pprog, void *func, void *ip);
>> +static __used void bpf_tail_call_cnt_init(void);
>> +
>>  /*
>>   * Emit x86-64 prologue code for BPF program.
>>   * bpf_tail_call helper will skip the first X86_TAIL_CALL_OFFSET bytes
>> @@ -410,9 +416,9 @@ static void emit_cfi(u8 **pprog, u32 hash)
>>   */
>>  static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
>>                           bool tail_call_reachable, bool is_subprog,
>> -                         bool is_exception_cb)
>> +                         bool is_exception_cb, u8 *ip)
>>  {
>> -       u8 *prog = *pprog;
>> +       u8 *prog = *pprog, *start = *pprog;
>>
>>         emit_cfi(&prog, is_subprog ? cfi_bpf_subprog_hash : cfi_bpf_hash);
>>         /* BPF trampoline can be made to work without these nops,
>> @@ -421,13 +427,14 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
>>         emit_nops(&prog, X86_PATCH_SIZE);
>>         if (!ebpf_from_cbpf) {
>>                 if (tail_call_reachable && !is_subprog)
>> -                       /* When it's the entry of the whole tailcall context,
>> -                        * zeroing rax means initialising tail_call_cnt.
>> +                       /* Call bpf_tail_call_cnt_init to initilise
>> +                        * tail_call_cnt.
>>                          */
>> -                       EMIT2(0x31, 0xC0); /* xor eax, eax */
>> +                       emit_call(&prog, bpf_tail_call_cnt_init,
>> +                                 ip + (prog - start));
> 
> You're repeating the same bug we discussed before.
> There is nothing in bpf_tail_call_cnt_init() that
> prevents the compiler from scratching rdi,rsi,...
> bpf_tail_call_cnt_init() is a normal function from compiler pov
> and it's allowed to use those regs.
> Must have been lucky that CI is not showing crashes.

Oh, got it. In order to prevent the compiler from scratching
rdi,rsi,..., the asm clobbered register list in bpf_tail_call_cnt_init()
must be "rdi", "rsi", "rdx", "rcx", "r8". I learned this from the GCC doc[0].

/*
 * Reset the tail call counter stored in the current task.
 *
 * Takes no arguments and returns nothing. The whole body is a single asm
 * statement that starts from the address of pcpu_hot.current_task, adds
 * the this_cpu_off operand to it (presumably turning the per-CPU symbol
 * into this CPU's copy -- confirm against the percpu helpers), loads the
 * pointer stored there and writes 0 to its bpf_tail_call_cnt field.
 */
static __used void bpf_tail_call_cnt_init(void)
{
	/* In short:
	 * current->bpf_tail_call_cnt = 0;
	 */

	/* Operands: %0 = this_cpu_off (memory), %1 = &pcpu_hot.current_task
	 * (register), %2 = offsetof(struct task_struct, bpf_tail_call_cnt)
	 * (immediate).
	 *
	 * NOTE(review): %1 is declared as an input-only "r" operand, yet the
	 * first two instructions overwrite it. GCC's extended-asm rules say
	 * input operands must not be modified, so this probably wants a "+r"
	 * operand backed by a local variable -- confirm.
	 *
	 * NOTE(review): the final movl stores to memory, but there is neither
	 * an output operand nor a "memory" clobber describing that store --
	 * confirm whether one is required here.
	 */
	asm volatile (
	    /* %1 += this_cpu_off */
	    "addq " __percpu_arg(0) ", %1\n\t"
	    /* %1 = *%1, i.e. the pointer held in current_task */
	    "movq (%1), %1\n\t"
	    /* store 0 at offset %2 from that pointer */
	    "movl $0x0, %c2(%1)\n\t"
	    :
	    : "m" (__my_cpu_var(this_cpu_off)), "r" (&pcpu_hot.current_task),
	      "i" (offsetof(struct task_struct, bpf_tail_call_cnt))
	    /* NOTE(review): a clobber list declares registers that the asm
	     * itself modifies, which keeps the compiler from assigning them to
	     * the operands above. Whether that is enough to guarantee callers
	     * that this C function leaves the registers intact is not
	     * established here -- confirm against the GCC documentation.
	     */
	    : "rdi", "rsi", "rdx", "rcx", "r8" /* to avoid scratching these regs */
	);
}

[0]
https://gcc.gnu.org/onlinedocs/gcc/Extended-Asm.html#Clobbers-and-Scratch-Registers-1

> 
>>                 else
>>                         /* Keep the same instruction layout. */
>> -                       EMIT2(0x66, 0x90); /* nop2 */
>> +                       emit_nops(&prog, X86_PATCH_SIZE);
>>         }
>>         /* Exception callback receives FP as third parameter */
>>         if (is_exception_cb) {
>> @@ -452,8 +459,6 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
>>         /* sub rsp, rounded_stack_depth */
>>         if (stack_depth)
>>                 EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8));
>> -       if (tail_call_reachable)
>> -               EMIT1(0x50);         /* push rax */
>>         *pprog = prog;
>>  }
>>
>> @@ -589,13 +594,61 @@ static void emit_return(u8 **pprog, u8 *ip)
>>         *pprog = prog;
>>  }
>>
>> +static __used void bpf_tail_call_cnt_init(void)
>> +{
>> +       /* The following asm equals to
>> +        *
>> +        * u32 *tcc_ptr = &current->bpf_tail_call_cnt;
>> +        *
>> +        * *tcc_ptr = 0;
>> +        */
>> +
>> +       asm volatile (
>> +           "addq " __percpu_arg(0) ", %1\n\t"
>> +           "addq %2, %1\n\t"
>> +           "movq (%1), %1\n\t"
>> +           "addq %3, %1\n\t"
>> +           "movl $0, (%1)\n\t"
>> +           :
>> +           : "m" (this_cpu_off), "r" (&pcpu_hot),
>> +             "i" (offsetof(struct pcpu_hot, current_task)),
>> +             "i" (offsetof(struct task_struct, bpf_tail_call_cnt))
>> +       );
>> +}
>> +
>> +static __used u32 *bpf_tail_call_cnt_ptr(void)
>> +{
>> +       u32 *tcc_ptr;
>> +
>> +       /* The following asm equals to
>> +        *
>> +        * u32 *tcc_ptr = &current->bpf_tail_call_cnt;
>> +        *
>> +        * return tcc_ptr;
>> +        */
>> +
>> +       asm volatile (
>> +           "addq " __percpu_arg(1) ", %2\n\t"
>> +           "addq %3, %2\n\t"
>> +           "movq (%2), %2\n\t"
>> +           "addq %4, %2\n\t"
>> +           "movq %2, %0\n\t"
>> +           : "=r" (tcc_ptr)
>> +           : "m" (this_cpu_off), "r" (&pcpu_hot),
>> +             "i" (offsetof(struct pcpu_hot, current_task)),
>> +             "i" (offsetof(struct task_struct, bpf_tail_call_cnt))
>> +       );
>> +
>> +       return tcc_ptr;
>> +}
>> +
>>  /*
>>   * Generate the following code:
>>   *
>>   * ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ...
>>   *   if (index >= array->map.max_entries)
>>   *     goto out;
>> - *   if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT)
>> + *   if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT)
>>   *     goto out;
>>   *   prog = array->ptrs[index];
>>   *   if (prog == NULL)
>> @@ -608,7 +661,6 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog,
>>                                         u32 stack_depth, u8 *ip,
>>                                         struct jit_context *ctx)
>>  {
>> -       int tcc_off = -4 - round_up(stack_depth, 8);
>>         u8 *prog = *pprog, *start = *pprog;
>>         int offset;
>>
>> @@ -630,16 +682,16 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog,
>>         EMIT2(X86_JBE, offset);                   /* jbe out */
>>
>>         /*
>> -        * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT)
>> +        * if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT)
>>          *      goto out;
>>          */
>> -       EMIT2_off32(0x8B, 0x85, tcc_off);         /* mov eax, dword ptr [rbp - tcc_off] */
>> -       EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);     /* cmp eax, MAX_TAIL_CALL_CNT */
>> +       /* call bpf_tail_call_cnt_ptr */
>> +       emit_call(&prog, bpf_tail_call_cnt_ptr, ip + (prog - start));
> 
> same issue.

I will rewrite it to emit_bpf_tail_call_cnt_ptr(), which will use pseudo
per-cpu bpf insns to get tcc_ptr from task_struct.

/*
 * Emit x86-64 instructions that leave &current->bpf_tail_call_cnt in rax.
 *
 * @pprog: in/out pointer to the emit cursor. It is read into the local
 *         'prog' on entry and written back from 'prog' on exit. The EMIT*
 *         macros presumably write through and advance 'prog' -- confirm
 *         against their definitions.
 *
 * Unlike a call to a C helper, the sequence is emitted inline and, going by
 * the per-instruction comments below, uses only rax.
 */
static void emit_bpf_tail_call_cnt_ptr(u8 **pprog)
{
	u8 *prog = *pprog;

	/* In short:
	 * return &current->bpf_tail_call_cnt;
	 */

	/* NOTE(review): the symbol address is truncated to 32 bits by the
	 * (u32) cast and emitted as an imm32. This is only correct if the
	 * address of pcpu_hot.current_task is representable that way -- confirm.
	 */
	/* mov rax, &pcpu_hot.current_task */
	EMIT3_off32(0x48, 0xC7, 0xC0, ((u32)(unsigned
long)&pcpu_hot.current_task));

	/* The per-CPU offset is added only on CONFIG_SMP builds; otherwise the
	 * address loaded above is used as-is.
	 */
#ifdef CONFIG_SMP
	/* add rax, gs:[&this_cpu_off] */
	/* 0x65 is emitted on its own ahead of the add; per the comment above
	 * it supplies the gs: segment override.
	 * NOTE(review): &this_cpu_off is likewise truncated to 32 bits for the
	 * displacement -- confirm it fits.
	 */
	EMIT1(0x65);
	EMIT4_off32(0x48, 0x03, 0x04, 0x25, ((u32)(unsigned long)&this_cpu_off));
#endif

	/* Dereference: rax now holds the pointer stored in current_task. */
	/* mov rax, qword ptr [rax] */
	EMIT3(0x48, 0x8B, 0x00);
	/* Advance to the counter field inside the task structure. */
	/* add rax, offsetof(struct task_struct, bpf_tail_call_cnt) */
	EMIT2_off32(0x48, 0x05, ((u32)offsetof(struct task_struct,
bpf_tail_call_cnt)));

	/* Publish the advanced cursor back to the caller. */
	*pprog = prog;
}

> 
>> +       EMIT3(0x83, 0x38, MAX_TAIL_CALL_CNT);     /* cmp dword ptr [rax], MAX_TAIL_CALL_CNT */
>>
>>         offset = ctx->tail_call_indirect_label - (prog + 2 - start);
>>         EMIT2(X86_JAE, offset);                   /* jae out */
>> -       EMIT3(0x83, 0xC0, 0x01);                  /* add eax, 1 */
>> -       EMIT2_off32(0x89, 0x85, tcc_off);         /* mov dword ptr [rbp - tcc_off], eax */
>> +       EMIT2(0xFF, 0x00);                        /* inc dword ptr [rax] */
>>
>>         /* prog = array->ptrs[index]; */
>>         EMIT4_off32(0x48, 0x8B, 0x8C, 0xD6,       /* mov rcx, [rsi + rdx * 8 + offsetof(...)] */
>> @@ -663,7 +715,6 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog,
>>                         pop_r12(&prog);
>>         }
>>
>> -       EMIT1(0x58);                              /* pop rax */
>>         if (stack_depth)
>>                 EMIT3_off32(0x48, 0x81, 0xC4,     /* add rsp, sd */
>>                             round_up(stack_depth, 8));
>> @@ -691,21 +742,20 @@ static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog,
>>                                       bool *callee_regs_used, u32 stack_depth,
>>                                       struct jit_context *ctx)
>>  {
>> -       int tcc_off = -4 - round_up(stack_depth, 8);
>>         u8 *prog = *pprog, *start = *pprog;
>>         int offset;
>>
>>         /*
>> -        * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT)
>> +        * if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT)
>>          *      goto out;
>>          */
>> -       EMIT2_off32(0x8B, 0x85, tcc_off);             /* mov eax, dword ptr [rbp - tcc_off] */
>> -       EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);         /* cmp eax, MAX_TAIL_CALL_CNT */
>> +       /* call bpf_tail_call_cnt_ptr */
>> +       emit_call(&prog, bpf_tail_call_cnt_ptr, ip);
> 
> and here as well.

Replace with emit_bpf_tail_call_cnt_ptr(), too.

> 
> pw-bot: cr

I would like to send the next version with these updates.

Thanks,
Leon




[Index of Archives]     [Linux Samsung SoC]     [Linux Rockchip SoC]     [Linux Actions SoC]     [Linux for Synopsys ARC Processors]     [Linux NFS]     [Linux NILFS]     [Linux USB Devel]     [Video for Linux]     [Linux Audio Users]     [Yosemite News]     [Linux Kernel]     [Linux SCSI]


  Powered by Linux