On 30/8/24 15:37, Xu Kuohai wrote:
> On 8/27/2024 10:23 AM, Leon Hwang wrote:
>> [...]
>
> I think the complexity arises from having to decide whether
> to initialize or keep the tail counter value in the prologue.
>
> To get rid of this complexity, a straightforward idea is to
> move the tail call counter initialization to the entry of
> bpf world, and in the bpf world, we only increase and check
> the tail call counter, never save/restore or set it. The
> "entry of the bpf world" here refers to mechanisms like
> bpf_prog_run, bpf dispatcher, or bpf trampoline that
> allows bpf prog to be invoked from C function.
>
> Below is a rough POC diff for arm64 that could pass all
> of your tests. The tail call counter is held in callee-saved
> register x26, and is set to 0 by arch_run_bpf.
>
> diff --git a/arch/arm64/net/bpf_jit_comp.c b/arch/arm64/net/bpf_jit_comp.c
> index 8aa32cb140b9..2c0f7daf1655 100644
> --- a/arch/arm64/net/bpf_jit_comp.c
> +++ b/arch/arm64/net/bpf_jit_comp.c
> @@ -26,7 +26,7 @@
>
>  #define TMP_REG_1 (MAX_BPF_JIT_REG + 0)
>  #define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
> -#define TCCNT_PTR (MAX_BPF_JIT_REG + 2)
> +#define TCALL_CNT (MAX_BPF_JIT_REG + 2)
>  #define TMP_REG_3 (MAX_BPF_JIT_REG + 3)
>  #define ARENA_VM_START (MAX_BPF_JIT_REG + 5)
>
> @@ -63,7 +63,7 @@ static const int bpf2a64[] = {
>  	[TMP_REG_2] = A64_R(11),
>  	[TMP_REG_3] = A64_R(12),
>  	/* tail_call_cnt_ptr */
> -	[TCCNT_PTR] = A64_R(26),
> +	[TCALL_CNT] = A64_R(26), // x26 is used to hold tail call counter
>  	/* temporary register for blinding constants */
>  	[BPF_REG_AX] = A64_R(9),
>  	/* callee saved register for kern_vm_start address */
> @@ -286,19 +286,6 @@ static bool is_lsi_offset(int offset, int scale)
>   * // PROLOGUE_OFFSET
>   * // save callee-saved registers
>   */
> -static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx)
> -{
> -	const bool is_main_prog = !bpf_is_subprog(ctx->prog);
> -	const u8 ptr = bpf2a64[TCCNT_PTR];
> -
> -	if (is_main_prog) {
> -		/* Initialize tail_call_cnt. */
> -		emit(A64_PUSH(A64_ZR, ptr, A64_SP), ctx);
> -		emit(A64_MOV(1, ptr, A64_SP), ctx);
> -	} else
> -		emit(A64_PUSH(ptr, ptr, A64_SP), ctx);
> -}
> -
>  static void find_used_callee_regs(struct jit_ctx *ctx)
>  {
>  	int i;
> @@ -419,7 +406,7 @@ static void pop_callee_regs(struct jit_ctx *ctx)
>  #define POKE_OFFSET (BTI_INSNS + 1)
>
>  /* Tail call offset to jump into */
> -#define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 4)
> +#define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 2)
>
>  static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
>  {
> @@ -473,8 +460,6 @@ static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
>  	emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
>  	emit(A64_MOV(1, A64_FP, A64_SP), ctx);
>
> -	prepare_bpf_tail_call_cnt(ctx);
> -
>  	if (!ebpf_from_cbpf && is_main_prog) {
>  		cur_offset = ctx->idx - idx0;
>  		if (cur_offset != PROLOGUE_OFFSET) {
> @@ -499,7 +484,7 @@ static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
>  		 *
>  		 * 12 registers are on the stack
>  		 */
> -		emit(A64_SUB_I(1, A64_SP, A64_FP, 96), ctx);
> +		emit(A64_SUB_I(1, A64_SP, A64_FP, 80), ctx);
>  	}
>
>  	if (ctx->fp_used)
> @@ -527,8 +512,7 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
>
>  	const u8 tmp = bpf2a64[TMP_REG_1];
>  	const u8 prg = bpf2a64[TMP_REG_2];
> -	const u8 tcc = bpf2a64[TMP_REG_3];
> -	const u8 ptr = bpf2a64[TCCNT_PTR];
> +	const u8 tcc = bpf2a64[TCALL_CNT];
>  	size_t off;
>  	__le32 *branch1 = NULL;
>  	__le32 *branch2 = NULL;
> @@ -546,16 +530,15 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
>  	emit(A64_NOP, ctx);
>
>  	/*
> -	 * if ((*tail_call_cnt_ptr) >= MAX_TAIL_CALL_CNT)
> +	 * if (tail_call_cnt >= MAX_TAIL_CALL_CNT)
>  	 *     goto out;
>  	 */
>  	emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
> -	emit(A64_LDR64I(tcc, ptr, 0), ctx);
>  	emit(A64_CMP(1, tcc, tmp), ctx);
>  	branch2 = ctx->image + ctx->idx;
>  	emit(A64_NOP, ctx);
>
> -	/* (*tail_call_cnt_ptr)++; */
> +	/* tail_call_cnt++; */
>  	emit(A64_ADD_I(1, tcc, tcc, 1), ctx);
>
>  	/* prog = array->ptrs[index];
> @@ -570,9 +553,6 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
>  	branch3 = ctx->image + ctx->idx;
>  	emit(A64_NOP, ctx);
>
> -	/* Update tail_call_cnt if the slot is populated. */
> -	emit(A64_STR64I(tcc, ptr, 0), ctx);
> -
>  	/* restore SP */
>  	if (ctx->stack_size)
>  		emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
> @@ -793,6 +773,27 @@ asm (
>  "	.popsection\n"
>  );
>
> +unsigned int arch_run_bpf(const void *ctx, const struct bpf_insn *insnsi, bpf_func_t bpf_func);
> +asm (
> +"	.pushsection .text, \"ax\", @progbits\n"
> +"	.global arch_run_bpf\n"
> +"	.type arch_run_bpf, %function\n"
> +"arch_run_bpf:\n"
> +#if IS_ENABLED(CONFIG_ARM64_BTI_KERNEL)
> +"	bti j\n"
> +#endif
> +"	stp x29, x30, [sp, #-16]!\n"
> +"	stp xzr, x26, [sp, #-16]!\n"
> +"	mov x26, #0\n"
> +"	blr x2\n"
> +"	ldp xzr, x26, [sp], #16\n"
> +"	ldp x29, x30, [sp], #16\n"
> +"	ret x30\n"
> +"	.size arch_run_bpf, . - arch_run_bpf\n"
> +"	.popsection\n"
> +);
> +EXPORT_SYMBOL_GPL(arch_run_bpf);
> +
>  /* build a plt initialized like this:
>   *
>   * plt:
> @@ -826,7 +827,6 @@ static void build_plt(struct jit_ctx *ctx)
>  static void build_epilogue(struct jit_ctx *ctx)
>  {
>  	const u8 r0 = bpf2a64[BPF_REG_0];
> -	const u8 ptr = bpf2a64[TCCNT_PTR];
>
>  	/* We're done with BPF stack */
>  	if (ctx->stack_size)
> @@ -834,8 +834,6 @@ static void build_epilogue(struct jit_ctx *ctx)
>
>  	pop_callee_regs(ctx);
>
> -	emit(A64_POP(A64_ZR, ptr, A64_SP), ctx);
> -
>  	/* Restore FP/LR registers */
>  	emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
>
> @@ -2066,6 +2064,8 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
>  	bool save_ret;
>  	__le32 **branches = NULL;
>
> +	bool target_is_bpf = is_bpf_text_address((unsigned long)func_addr);
> +
>  	/* trampoline stack layout:
>  	 * [ parent ip ]
>  	 * [ FP ]
> @@ -2133,6 +2133,11 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
>  	 */
>  	emit_bti(A64_BTI_JC, ctx);
>
> +	if (!target_is_bpf) {
> +		emit(A64_PUSH(A64_ZR, A64_R(26), A64_SP), ctx);
> +		emit(A64_MOVZ(1, A64_R(26), 0, 0), ctx);
> +	}
> +
>  	/* frame for parent function */
>  	emit(A64_PUSH(A64_FP, A64_R(9), A64_SP), ctx);
>  	emit(A64_MOV(1, A64_FP, A64_SP), ctx);
> @@ -2226,6 +2231,8 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
>  	/* pop frames */
>  	emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
>  	emit(A64_POP(A64_FP, A64_R(9), A64_SP), ctx);
> +	if (!target_is_bpf)
> +		emit(A64_POP(A64_ZR, A64_R(26), A64_SP), ctx);
>
>  	if (flags & BPF_TRAMP_F_SKIP_FRAME) {
>  		/* skip patched function, return to parent */
> diff --git a/include/linux/bpf.h b/include/linux/bpf.h
> index dc63083f76b7..8660d15dd50c 100644
> --- a/include/linux/bpf.h
> +++ b/include/linux/bpf.h
> @@ -1244,12 +1244,14 @@ struct bpf_dispatcher {
>  #define __bpfcall __nocfi
>  #endif
>
> +unsigned int arch_run_bpf(const void *ctx, const struct bpf_insn *insnsi, bpf_func_t bpf_func);
> +
>  static __always_inline __bpfcall unsigned int bpf_dispatcher_nop_func(
>  	const void *ctx,
>  	const struct bpf_insn *insnsi,
>  	bpf_func_t bpf_func)
>  {
> -	return bpf_func(ctx, insnsi);
> +	return arch_run_bpf(ctx, insnsi, bpf_func);
>  }
>
>  /* the implementation of the opaque uapi struct bpf_dynptr */
> @@ -1317,7 +1319,7 @@ int arch_prepare_bpf_dispatcher(void *image, void *buf, s64 *funcs, int num_func
>  #else
>  #define __BPF_DISPATCHER_SC_INIT(name)
>  #define __BPF_DISPATCHER_SC(name)
> -#define __BPF_DISPATCHER_CALL(name)	bpf_func(ctx, insnsi)
> +#define __BPF_DISPATCHER_CALL(name)	arch_run_bpf(ctx, insnsi, bpf_func);
>  #define __BPF_DISPATCHER_UPDATE(_d, _new)
>  #endif
>

This approach is really cool! I wanted a similar approach on x86, but
I failed, because on x86 the prog is entered through an indirect call,
"callq *%rdx", aka "bpf_func(ctx, insnsi)".

Let us imagine arch_run_bpf() on x86:

unsigned int __naked arch_run_bpf(const void *ctx,
				  const struct bpf_insn *insnsi,
				  bpf_func_t bpf_func)
{
	asm (
		"pushq %rbp\n\t"
		"movq %rsp, %rbp\n\t"
		/* Push a zeroed tail_call_cnt slot onto the stack. */
		"xor %rax, %rax\n\t"
		"pushq %rax\n\t"
		/* rax = pointer to the tail_call_cnt slot. */
		"movq %rsp, %rax\n\t"
		/* Indirect call to the prog, i.e. bpf_func(ctx, insnsi). */
		"callq *%rdx\n\t"
		"leave\n\t"
		"ret\n\t"
	);
}

If we could change "callq *%rdx" into a direct call, it would be really
wonderful for resolving this tail call issue on x86.
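One part I can imagine is the generic plumbing: a __weak fallback that
keeps today's behavior, which each arch with a native entry stub, like
your arm64 POC, would override. This is an untested sketch, and the
placement in kernel/bpf/core.c is just my guess:

/* Untested sketch: default for backends without a native entry stub.
 * It simply preserves today's behavior (no counter handling); arm64,
 * and later x86, would override this symbol with an asm stub that
 * zeroes the tail call counter before entering the prog.
 */
unsigned int __weak arch_run_bpf(const void *ctx,
				 const struct bpf_insn *insnsi,
				 bpf_func_t bpf_func)
{
	return bpf_func(ctx, insnsi);
}

But even with such a fallback, the x86 override would still contain the
indirect "callq *%rdx" above, so the question remains: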
- arch_run_bpf\n" > +" .popsection\n" > +); > +EXPORT_SYMBOL_GPL(arch_run_bpf); > + > /* build a plt initialized like this: > * > * plt: > @@ -826,7 +827,6 @@ static void build_plt(struct jit_ctx *ctx) > static void build_epilogue(struct jit_ctx *ctx) > { > const u8 r0 = bpf2a64[BPF_REG_0]; > - const u8 ptr = bpf2a64[TCCNT_PTR]; > > /* We're done with BPF stack */ > if (ctx->stack_size) > @@ -834,8 +834,6 @@ static void build_epilogue(struct jit_ctx *ctx) > > pop_callee_regs(ctx); > > - emit(A64_POP(A64_ZR, ptr, A64_SP), ctx); > - > /* Restore FP/LR registers */ > emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx); > > @@ -2066,6 +2064,8 @@ static int prepare_trampoline(struct jit_ctx *ctx, > struct bpf_tramp_image *im, > bool save_ret; > __le32 **branches = NULL; > > + bool target_is_bpf = is_bpf_text_address((unsigned long)func_addr); > + > /* trampoline stack layout: > * [ parent ip ] > * [ FP ] > @@ -2133,6 +2133,11 @@ static int prepare_trampoline(struct jit_ctx > *ctx, struct bpf_tramp_image *im, > */ > emit_bti(A64_BTI_JC, ctx); > > + if (!target_is_bpf) { > + emit(A64_PUSH(A64_ZR, A64_R(26), A64_SP), ctx); > + emit(A64_MOVZ(1, A64_R(26), 0, 0), ctx); > + } > + > /* frame for parent function */ > emit(A64_PUSH(A64_FP, A64_R(9), A64_SP), ctx); > emit(A64_MOV(1, A64_FP, A64_SP), ctx); > @@ -2226,6 +2231,8 @@ static int prepare_trampoline(struct jit_ctx *ctx, > struct bpf_tramp_image *im, > /* pop frames */ > emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx); > emit(A64_POP(A64_FP, A64_R(9), A64_SP), ctx); > + if (!target_is_bpf) > + emit(A64_POP(A64_ZR, A64_R(26), A64_SP), ctx); > > if (flags & BPF_TRAMP_F_SKIP_FRAME) { > /* skip patched function, return to parent */ > diff --git a/include/linux/bpf.h b/include/linux/bpf.h > index dc63083f76b7..8660d15dd50c 100644 > --- a/include/linux/bpf.h > +++ b/include/linux/bpf.h > @@ -1244,12 +1244,14 @@ struct bpf_dispatcher { > #define __bpfcall __nocfi > #endif > > +unsigned int arch_run_bpf(const void *ctx, const struct bpf_insn > *insnsi, bpf_func_t bpf_func); > + > static __always_inline __bpfcall unsigned int bpf_dispatcher_nop_func( > const void *ctx, > const struct bpf_insn *insnsi, > bpf_func_t bpf_func) > { > - return bpf_func(ctx, insnsi); > + return arch_run_bpf(ctx, insnsi, bpf_func); > } > > /* the implementation of the opaque uapi struct bpf_dynptr */ > @@ -1317,7 +1319,7 @@ int arch_prepare_bpf_dispatcher(void *image, void > *buf, s64 *funcs, int num_func > #else > #define __BPF_DISPATCHER_SC_INIT(name) > #define __BPF_DISPATCHER_SC(name) > -#define __BPF_DISPATCHER_CALL(name) bpf_func(ctx, insnsi) > +#define __BPF_DISPATCHER_CALL(name) arch_run_bpf(ctx, insnsi, > bpf_func); > #define __BPF_DISPATCHER_UPDATE(_d, _new) > #endif > This approach is really cool! I want an alike approach on x86. But I failed. Because, on x86, it's an indirect call to "call *rdx", aka "bpf_func(ctx, insnsi)". Let us imagine the arch_run_bpf() on x86: unsigned int __naked arch_run_bpf(const void *ctx, const struct bpf_insn *insnsi, bpf_func_t bpf_func) { asm ( "pushq %rbp\n\t" "movq %rsp, %rbp\n\t" "xor %rax, %rax\n\t" "pushq %rax\n\t" "movq %rsp, %rax\n\t" "callq *%rdx\n\t" "leave\n\t" "ret\n\t" ); } If we can change "callq *%rdx" to a direct call, it'll be really wonderful to resolve this tailcall issue on x86. How to introduce arch_bpf_run() for all JIT backends? Thanks, Leon