On Sun, Sep 17, 2023 at 5:09 PM Xu Kuohai <xukuohai@xxxxxxxxxxxxxxx> wrote: > > From: Xu Kuohai <xukuohai@xxxxxxxxxx> > > Currently arm64 bpf trampoline supports up to 8 function arguments. > According to the statistics from commit > 473e3150e30a ("bpf, x86: allow function arguments up to 12 for TRACING"), > there are about 200 functions accept 9 to 12 arguments, so adding support > for up to 12 function arguments. Thank you Xu, this will be a nice addition! :) > Due to bpf only supports function arguments up to 16 bytes, according to > AAPCS64, starting from the first argument, each argument is first > attempted to be loaded to 1 or 2 smallest registers from x0-x7, if there > are no enough registers to hold the entire argument, then all remaining > arguments starting from this one are pushed to the stack for passing. If I read section 6.8.2 of the AAPCS64 correctly, there is a corner case which I believe isn't covered by this logic. void f(u128 a, u128 b, u128 c, u64 d, u128 e, u64 f) {} - a goes on x0 and x1 - b goes on x2 and x3 - c goes on x4 and x5 - d goes on x6 - e spills on the stack because it doesn't fit in the remaining regs - f goes on x7 Maybe it would be good to add something pathological like this to the selftests? Otherwise I only have minor nitpicks. 
> Signed-off-by: Xu Kuohai <xukuohai@xxxxxxxxxx> > --- > arch/arm64/net/bpf_jit_comp.c | 171 ++++++++++++++----- > tools/testing/selftests/bpf/DENYLIST.aarch64 | 2 - > 2 files changed, 131 insertions(+), 42 deletions(-) > > diff --git a/arch/arm64/net/bpf_jit_comp.c b/arch/arm64/net/bpf_jit_comp.c > index 7d4af64e3982..a0cf526b07ea 100644 > --- a/arch/arm64/net/bpf_jit_comp.c > +++ b/arch/arm64/net/bpf_jit_comp.c > @@ -1705,7 +1705,7 @@ bool bpf_jit_supports_subprog_tailcalls(void) > } > > static void invoke_bpf_prog(struct jit_ctx *ctx, struct bpf_tramp_link *l, > - int args_off, int retval_off, int run_ctx_off, > + int bargs_off, int retval_off, int run_ctx_off, > bool save_ret) > { > __le32 *branch; > @@ -1747,7 +1747,7 @@ static void invoke_bpf_prog(struct jit_ctx *ctx, struct bpf_tramp_link *l, > /* save return value to callee saved register x20 */ > emit(A64_MOV(1, A64_R(20), A64_R(0)), ctx); > > - emit(A64_ADD_I(1, A64_R(0), A64_SP, args_off), ctx); > + emit(A64_ADD_I(1, A64_R(0), A64_SP, bargs_off), ctx); > if (!p->jited) > emit_addr_mov_i64(A64_R(1), (const u64)p->insnsi, ctx); > > @@ -1772,7 +1772,7 @@ static void invoke_bpf_prog(struct jit_ctx *ctx, struct bpf_tramp_link *l, > } > > static void invoke_bpf_mod_ret(struct jit_ctx *ctx, struct bpf_tramp_links *tl, > - int args_off, int retval_off, int run_ctx_off, > + int bargs_off, int retval_off, int run_ctx_off, > __le32 **branches) > { > int i; > @@ -1782,7 +1782,7 @@ static void invoke_bpf_mod_ret(struct jit_ctx *ctx, struct bpf_tramp_links *tl, > */ > emit(A64_STR64I(A64_ZR, A64_SP, retval_off), ctx); > for (i = 0; i < tl->nr_links; i++) { > - invoke_bpf_prog(ctx, tl->links[i], args_off, retval_off, > + invoke_bpf_prog(ctx, tl->links[i], bargs_off, retval_off, > run_ctx_off, true); > /* if (*(u64 *)(sp + retval_off) != 0) > * goto do_fexit; > @@ -1796,23 +1796,111 @@ static void invoke_bpf_mod_ret(struct jit_ctx *ctx, struct bpf_tramp_links *tl, > } > } > > -static void save_args(struct jit_ctx *ctx, 
int args_off, int nregs) > +struct arg_aux { > + /* how many args are passed through registers, the rest args are the rest of the* args > + * passed through stack > + */ > + int args_in_reg; Maybe args_in_regs ? since args can go in multiple regs > + /* how many registers used for passing arguments */ are* used > + int regs_for_arg; And here regs_for_args ? Since It's the number of registers used for all args > + /* how many stack slots used for arguments, each slot is 8 bytes */ are* used > + int stack_slots_for_arg; And here stack_slots_for_args, for the same reason as above? > +}; > + > +static void calc_arg_aux(const struct btf_func_model *m, > + struct arg_aux *a) > { > int i; > + int nregs; > + int slots; > + int stack_slots; > + > + /* verifier ensures m->nr_args <= MAX_BPF_FUNC_ARGS */ > + for (i = 0, nregs = 0; i < m->nr_args; i++) { > + slots = (m->arg_size[i] + 7) / 8; > + if (nregs + slots <= 8) /* passed through register ? */ > + nregs += slots; > + else > + break; > + } > + > + a->args_in_reg = i; > + a->regs_for_arg = nregs; > > - for (i = 0; i < nregs; i++) { > - emit(A64_STR64I(i, A64_SP, args_off), ctx); > - args_off += 8; > + /* the rest arguments are passed through stack */ > + for (stack_slots = 0; i < m->nr_args; i++) > + stack_slots += (m->arg_size[i] + 7) / 8; > + > + a->stack_slots_for_arg = stack_slots; > +} > + > +static void clear_garbage(struct jit_ctx *ctx, int reg, int effective_bytes) > +{ > + if (effective_bytes) { > + int garbage_bits = 64 - 8 * effective_bytes; > +#ifdef CONFIG_CPU_BIG_ENDIAN > + /* garbage bits are at the right end */ > + emit(A64_LSR(1, reg, reg, garbage_bits), ctx); > + emit(A64_LSL(1, reg, reg, garbage_bits), ctx); > +#else > + /* garbage bits are at the left end */ > + emit(A64_LSL(1, reg, reg, garbage_bits), ctx); > + emit(A64_LSR(1, reg, reg, garbage_bits), ctx); > +#endif > } > } > > -static void restore_args(struct jit_ctx *ctx, int args_off, int nregs) > +static void save_args(struct jit_ctx *ctx, int 
bargs_off, int oargs_off, > + const struct btf_func_model *m, > + const struct arg_aux *a, > + bool for_call_origin) > { > int i; > + int reg; > + int doff; > + int soff; > + int slots; > + u8 tmp = bpf2a64[TMP_REG_1]; > + > + /* store argument registers to stack for call bpf, or restore argument to* call bpf or "for the bpf program" > + * registers from stack for the original function > + */ > + for (reg = 0; reg < a->regs_for_arg; reg++) { > + emit(for_call_origin ? > + A64_LDR64I(reg, A64_SP, bargs_off) : > + A64_STR64I(reg, A64_SP, bargs_off), > + ctx); > + bargs_off += 8; > + } > > - for (i = 0; i < nregs; i++) { > - emit(A64_LDR64I(i, A64_SP, args_off), ctx); > - args_off += 8; > + soff = 32; /* on stack arguments start from FP + 32 */ > + doff = (for_call_origin ? oargs_off : bargs_off); > + > + /* save on stack arguments */ > + for (i = a->args_in_reg; i < m->nr_args; i++) { > + slots = (m->arg_size[i] + 7) / 8; > + /* verifier ensures arg_size <= 16, so slots equals 1 or 2 */ > + while (slots-- > 0) { > + emit(A64_LDR64I(tmp, A64_FP, soff), ctx); > + /* if there is unused space in the last slot, clear > + * the garbage contained in the space. 
> + */ > + if (slots == 0 && !for_call_origin) > + clear_garbage(ctx, tmp, m->arg_size[i] % 8); > + emit(A64_STR64I(tmp, A64_SP, doff), ctx); > + soff += 8; > + doff += 8; > + } > + } > +} > + > +static void restore_args(struct jit_ctx *ctx, int bargs_off, int nregs) > +{ > + int reg; > + > + for (reg = 0; reg < nregs; reg++) { > + emit(A64_LDR64I(reg, A64_SP, bargs_off), ctx); > + bargs_off += 8; > } > } > > @@ -1829,17 +1917,21 @@ static void restore_args(struct jit_ctx *ctx, int args_off, int nregs) > */ > static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im, > struct bpf_tramp_links *tlinks, void *orig_call, > - int nregs, u32 flags) > + const struct btf_func_model *m, > + const struct arg_aux *a, > + u32 flags) > { > int i; > int stack_size; > int retaddr_off; > int regs_off; > int retval_off; > - int args_off; > + int bargs_off; > int nregs_off; > int ip_off; > int run_ctx_off; > + int oargs_off; > + int nregs; > struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY]; > struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT]; > struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN]; > @@ -1859,19 +1951,26 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im, > * > * SP + retval_off [ return value ] BPF_TRAMP_F_CALL_ORIG or > * BPF_TRAMP_F_RET_FENTRY_RET > - * > * [ arg reg N ] > * [ ... ] > - * SP + args_off [ arg reg 1 ] > + * SP + bargs_off [ arg reg 1 ] for bpf > * > * SP + nregs_off [ arg regs count ] > * > * SP + ip_off [ traced function ] BPF_TRAMP_F_IP_ARG flag > * > * SP + run_ctx_off [ bpf_tramp_run_ctx ] > + * > + * [ stack arg N ] > + * [ ... 
] > + * SP + oargs_off [ stack arg 1 ] for original func > */ > > stack_size = 0; > + oargs_off = stack_size; > + if (flags & BPF_TRAMP_F_CALL_ORIG) > + stack_size += 8 * a->stack_slots_for_arg; > + > run_ctx_off = stack_size; > /* room for bpf_tramp_run_ctx */ > stack_size += round_up(sizeof(struct bpf_tramp_run_ctx), 8); > @@ -1885,9 +1984,10 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im, > /* room for args count */ > stack_size += 8; > > - args_off = stack_size; > + bargs_off = stack_size; > /* room for args */ > - stack_size += nregs * 8; > + nregs = a->regs_for_arg + a->stack_slots_for_arg; Maybe this name no longer makes sense ?