Re: [PATCH bpf-next 2/4] bpf, arm64: Fix tailcall infinite loop caused by freplace

On 30/8/24 15:37, Xu Kuohai wrote:
> On 8/27/2024 10:23 AM, Leon Hwang wrote:
>>

[...]

> 
> I think the complexity arises from having to decide whether
> to initialize or keep the tail counter value in the prologue.
> 
> To get rid of this complexity, a straightforward idea is to
> move the tail call counter initialization to the entry of
> bpf world, and in the bpf world, we only increase and check
> the tail call counter, never save/restore or set it. The
> "entry of the bpf world" here refers to mechanisms like
> bpf_prog_run, bpf dispatcher, or bpf trampoline that
> allows bpf prog to be invoked from C function.
> 
> Below is a rough POC diff for arm64 that could pass all
> of your tests. The tail call counter is held in callee-saved
> register x26, and is set to 0 by arch_run_bpf.
> 
> diff --git a/arch/arm64/net/bpf_jit_comp.c b/arch/arm64/net/bpf_jit_comp.c
> index 8aa32cb140b9..2c0f7daf1655 100644
> --- a/arch/arm64/net/bpf_jit_comp.c
> +++ b/arch/arm64/net/bpf_jit_comp.c
> @@ -26,7 +26,7 @@
> 
>  #define TMP_REG_1 (MAX_BPF_JIT_REG + 0)
>  #define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
> -#define TCCNT_PTR (MAX_BPF_JIT_REG + 2)
> +#define TCALL_CNT (MAX_BPF_JIT_REG + 2)
>  #define TMP_REG_3 (MAX_BPF_JIT_REG + 3)
>  #define ARENA_VM_START (MAX_BPF_JIT_REG + 5)
> 
> @@ -63,7 +63,7 @@ static const int bpf2a64[] = {
>      [TMP_REG_2] = A64_R(11),
>      [TMP_REG_3] = A64_R(12),
>      /* tail_call_cnt_ptr */
> -    [TCCNT_PTR] = A64_R(26),
> +    [TCALL_CNT] = A64_R(26), /* x26 holds the tail call counter */
>      /* temporary register for blinding constants */
>      [BPF_REG_AX] = A64_R(9),
>      /* callee saved register for kern_vm_start address */
> @@ -286,19 +286,6 @@ static bool is_lsi_offset(int offset, int scale)
>   *      // PROLOGUE_OFFSET
>   *    // save callee-saved registers
>   */
> -static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx)
> -{
> -    const bool is_main_prog = !bpf_is_subprog(ctx->prog);
> -    const u8 ptr = bpf2a64[TCCNT_PTR];
> -
> -    if (is_main_prog) {
> -        /* Initialize tail_call_cnt. */
> -        emit(A64_PUSH(A64_ZR, ptr, A64_SP), ctx);
> -        emit(A64_MOV(1, ptr, A64_SP), ctx);
> -    } else
> -        emit(A64_PUSH(ptr, ptr, A64_SP), ctx);
> -}
> -
>  static void find_used_callee_regs(struct jit_ctx *ctx)
>  {
>      int i;
> @@ -419,7 +406,7 @@ static void pop_callee_regs(struct jit_ctx *ctx)
>  #define POKE_OFFSET (BTI_INSNS + 1)
> 
>  /* Tail call offset to jump into */
> -#define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 4)
> +#define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 2)
> 
>  static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
>  {
> @@ -473,8 +460,6 @@ static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
>          emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
>          emit(A64_MOV(1, A64_FP, A64_SP), ctx);
> 
> -        prepare_bpf_tail_call_cnt(ctx);
> -
>          if (!ebpf_from_cbpf && is_main_prog) {
>              cur_offset = ctx->idx - idx0;
>              if (cur_offset != PROLOGUE_OFFSET) {
> @@ -499,7 +484,7 @@ static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
>           *
>           * 12 registers are on the stack
>           */
> -        emit(A64_SUB_I(1, A64_SP, A64_FP, 96), ctx);
> +        emit(A64_SUB_I(1, A64_SP, A64_FP, 80), ctx);
>      }
> 
>      if (ctx->fp_used)
> @@ -527,8 +512,7 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
> 
>      const u8 tmp = bpf2a64[TMP_REG_1];
>      const u8 prg = bpf2a64[TMP_REG_2];
> -    const u8 tcc = bpf2a64[TMP_REG_3];
> -    const u8 ptr = bpf2a64[TCCNT_PTR];
> +    const u8 tcc = bpf2a64[TCALL_CNT];
>      size_t off;
>      __le32 *branch1 = NULL;
>      __le32 *branch2 = NULL;
> @@ -546,16 +530,15 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
>      emit(A64_NOP, ctx);
> 
>      /*
> -     * if ((*tail_call_cnt_ptr) >= MAX_TAIL_CALL_CNT)
> +     * if (tail_call_cnt >= MAX_TAIL_CALL_CNT)
>       *     goto out;
>       */
>      emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
> -    emit(A64_LDR64I(tcc, ptr, 0), ctx);
>      emit(A64_CMP(1, tcc, tmp), ctx);
>      branch2 = ctx->image + ctx->idx;
>      emit(A64_NOP, ctx);
> 
> -    /* (*tail_call_cnt_ptr)++; */
> +    /* tail_call_cnt++; */
>      emit(A64_ADD_I(1, tcc, tcc, 1), ctx);
> 
>      /* prog = array->ptrs[index];
> @@ -570,9 +553,6 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
>      branch3 = ctx->image + ctx->idx;
>      emit(A64_NOP, ctx);
> 
> -    /* Update tail_call_cnt if the slot is populated. */
> -    emit(A64_STR64I(tcc, ptr, 0), ctx);
> -
>      /* restore SP */
>      if (ctx->stack_size)
>          emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
> @@ -793,6 +773,27 @@ asm (
>  "    .popsection\n"
>  );
> 
> +unsigned int arch_run_bpf(const void *ctx, const struct bpf_insn *insnsi, bpf_func_t bpf_func);
> +asm (
> +"    .pushsection .text, \"ax\", @progbits\n"
> +"    .global arch_run_bpf\n"
> +"    .type arch_run_bpf, %function\n"
> +"arch_run_bpf:\n"
> +#if IS_ENABLED(CONFIG_ARM64_BTI_KERNEL)
> +"    bti j\n"
> +#endif
> +"    stp x29, x30, [sp, #-16]!\n"
> +"    stp xzr, x26, [sp, #-16]!\n"
> +"    mov x26, #0\n"
> +"    blr x2\n"
> +"    ldp xzr, x26, [sp], #16\n"
> +"    ldp x29, x30, [sp], #16\n"
> +"    ret x30\n"
> +"    .size arch_run_bpf, . - arch_run_bpf\n"
> +"    .popsection\n"
> +);
> +EXPORT_SYMBOL_GPL(arch_run_bpf);
> +
>  /* build a plt initialized like this:
>   *
>   * plt:
> @@ -826,7 +827,6 @@ static void build_plt(struct jit_ctx *ctx)
>  static void build_epilogue(struct jit_ctx *ctx)
>  {
>      const u8 r0 = bpf2a64[BPF_REG_0];
> -    const u8 ptr = bpf2a64[TCCNT_PTR];
> 
>      /* We're done with BPF stack */
>      if (ctx->stack_size)
> @@ -834,8 +834,6 @@ static void build_epilogue(struct jit_ctx *ctx)
> 
>      pop_callee_regs(ctx);
> 
> -    emit(A64_POP(A64_ZR, ptr, A64_SP), ctx);
> -
>      /* Restore FP/LR registers */
>      emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
> 
> @@ -2066,6 +2064,8 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
>      bool save_ret;
>      __le32 **branches = NULL;
> 
> +    bool target_is_bpf = is_bpf_text_address((unsigned long)func_addr);
> +
>      /* trampoline stack layout:
>       *                  [ parent ip         ]
>       *                  [ FP                ]
> @@ -2133,6 +2133,11 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
>       */
>      emit_bti(A64_BTI_JC, ctx);
> 
> +    if (!target_is_bpf) {
> +        emit(A64_PUSH(A64_ZR, A64_R(26), A64_SP), ctx);
> +        emit(A64_MOVZ(1, A64_R(26), 0, 0), ctx);
> +    }
> +
>      /* frame for parent function */
>      emit(A64_PUSH(A64_FP, A64_R(9), A64_SP), ctx);
>      emit(A64_MOV(1, A64_FP, A64_SP), ctx);
> @@ -2226,6 +2231,8 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
>      /* pop frames  */
>      emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
>      emit(A64_POP(A64_FP, A64_R(9), A64_SP), ctx);
> +    if (!target_is_bpf)
> +        emit(A64_POP(A64_ZR, A64_R(26), A64_SP), ctx);
> 
>      if (flags & BPF_TRAMP_F_SKIP_FRAME) {
>          /* skip patched function, return to parent */
> diff --git a/include/linux/bpf.h b/include/linux/bpf.h
> index dc63083f76b7..8660d15dd50c 100644
> --- a/include/linux/bpf.h
> +++ b/include/linux/bpf.h
> @@ -1244,12 +1244,14 @@ struct bpf_dispatcher {
>  #define __bpfcall __nocfi
>  #endif
> 
> +unsigned int arch_run_bpf(const void *ctx, const struct bpf_insn *insnsi, bpf_func_t bpf_func);
> +
>  static __always_inline __bpfcall unsigned int bpf_dispatcher_nop_func(
>      const void *ctx,
>      const struct bpf_insn *insnsi,
>      bpf_func_t bpf_func)
>  {
> -    return bpf_func(ctx, insnsi);
> +    return arch_run_bpf(ctx, insnsi, bpf_func);
>  }
> 
>  /* the implementation of the opaque uapi struct bpf_dynptr */
> @@ -1317,7 +1319,7 @@ int arch_prepare_bpf_dispatcher(void *image, void *buf, s64 *funcs, int num_func
>  #else
>  #define __BPF_DISPATCHER_SC_INIT(name)
>  #define __BPF_DISPATCHER_SC(name)
> -#define __BPF_DISPATCHER_CALL(name)        bpf_func(ctx, insnsi)
> +#define __BPF_DISPATCHER_CALL(name)        arch_run_bpf(ctx, insnsi, bpf_func)
>  #define __BPF_DISPATCHER_UPDATE(_d, _new)
>  #endif
> 

This approach is really cool!

I wanted a similar approach on x86, but failed, because on x86 the prog
is entered via an indirect call, "callq *%rdx", aka "bpf_func(ctx,
insnsi)".

Let us imagine arch_run_bpf() on x86:

unsigned int __naked arch_run_bpf(const void *ctx,
				  const struct bpf_insn *insnsi,
				  bpf_func_t bpf_func)
{
	asm (
		/* set up a frame */
		"pushq %rbp\n\t"
		"movq %rsp, %rbp\n\t"
		/* zero a tail_call_cnt slot on the stack and pass its
		 * address in %rax, tcc_ptr style
		 */
		"xor %rax, %rax\n\t"
		"pushq %rax\n\t"
		"movq %rsp, %rax\n\t"
		/* here is the problem: an indirect call into the prog */
		"callq *%rdx\n\t"
		"leave\n\t"
		"ret\n\t"
	);
}
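
For reference, right before the "callq *%rdx" the stack would look like
this (my reading of the sketch above):

	/*
	 * high address
	 *  [ return address    ]
	 *  [ saved rbp         ] <- rbp
	 *  [ tail_call_cnt = 0 ] <- rsp, address also held in rax
	 * low address
	 */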

If we could change that "callq *%rdx" into a direct call, it would be a
really nice way to resolve this tailcall issue on x86 as well.

How should we introduce arch_run_bpf() for all JIT backends?
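
For JITs that have not been converted yet, perhaps a __weak default
that keeps today's behavior would be enough. A minimal sketch (the
__weak fallback is only my assumption, it is not part of the POC
above):

unsigned int __weak arch_run_bpf(const void *ctx,
				 const struct bpf_insn *insnsi,
				 bpf_func_t bpf_func)
{
	/* No arch-specific entry hook: enter the prog directly, exactly
	 * like the current bpf_dispatcher_nop_func() does.
	 */
	return bpf_func(ctx, insnsi);
}

A converted backend, like the arm64 assembly version above, would then
override this with a strong symbol.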

Thanks,
Leon




