Pu Lehui <pulehui@xxxxxxxxxxxxxxx> writes: > From: Pu Lehui <pulehui@xxxxxxxxxx> > > Commit 6724a76cff85 ("riscv: ftrace: Reduce the detour code size to > half") optimizes the detour code size of kernel functions to half with > T0 register and the upcoming DYNAMIC_FTRACE_WITH_DIRECT_CALLS of riscv > is based on this optimization, we need to adapt riscv bpf trampoline > based on this. One thing to do is to reduce detour code size of bpf > programs, and the second is to deal with the return address after the > execution of bpf trampoline. Meanwhile, add more comments and rename > some variables to make more sense. The related tests have passed. > > This adaptation needs to be merged before the upcoming > DYNAMIC_FTRACE_WITH_DIRECT_CALLS of riscv, otherwise it will crash due > to a mismatch in the return address. So we target this modification to > bpf tree and add fixes tag for locating. Thank you for working on this! > Fixes: 6724a76cff85 ("riscv: ftrace: Reduce the detour code size to half") This is not a fix. Nothing is broken. The only constraint is that this patch must come before, or as part of, the ftrace series. > Signed-off-by: Pu Lehui <pulehui@xxxxxxxxxx> > --- > arch/riscv/net/bpf_jit_comp64.c | 110 ++++++++++++++------------------ > 1 file changed, 47 insertions(+), 63 deletions(-) > > diff --git a/arch/riscv/net/bpf_jit_comp64.c b/arch/riscv/net/bpf_jit_comp64.c > index c648864c8cd1..ffc9aa42f918 100644 > --- a/arch/riscv/net/bpf_jit_comp64.c > +++ b/arch/riscv/net/bpf_jit_comp64.c > @@ -241,7 +241,7 @@ static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx) > if (!is_tail_call) > emit_mv(RV_REG_A0, RV_REG_A5, ctx); > emit_jalr(RV_REG_ZERO, is_tail_call ? RV_REG_T3 : RV_REG_RA, > - is_tail_call ? 20 : 0, /* skip reserved nops and TCC init */ > + is_tail_call ? 12 : 0, /* skip reserved nops and TCC init */ Maybe be explicit, and use the "DETOUR_NINSNS" from below (and convert to bytes)? 
> ctx); > } > > @@ -618,32 +618,7 @@ static int add_exception_handler(const struct bpf_insn *insn, > return 0; > } > > -static int gen_call_or_nops(void *target, void *ip, u32 *insns) > -{ > - s64 rvoff; > - int i, ret; > - struct rv_jit_context ctx; > - > - ctx.ninsns = 0; > - ctx.insns = (u16 *)insns; > - > - if (!target) { > - for (i = 0; i < 4; i++) > - emit(rv_nop(), &ctx); > - return 0; > - } > - > - rvoff = (s64)(target - (ip + 4)); > - emit(rv_sd(RV_REG_SP, -8, RV_REG_RA), &ctx); > - ret = emit_jump_and_link(RV_REG_RA, rvoff, false, &ctx); > - if (ret) > - return ret; > - emit(rv_ld(RV_REG_RA, -8, RV_REG_SP), &ctx); > - > - return 0; > -} > - > -static int gen_jump_or_nops(void *target, void *ip, u32 *insns) > +static int gen_jump_or_nops(void *target, void *ip, u32 *insns, bool is_call) > { > s64 rvoff; > struct rv_jit_context ctx; > @@ -658,38 +633,38 @@ static int gen_jump_or_nops(void *target, void *ip, u32 *insns) > } > > rvoff = (s64)(target - ip); > - return emit_jump_and_link(RV_REG_ZERO, rvoff, false, &ctx); > + return emit_jump_and_link(is_call ? RV_REG_T0 : RV_REG_ZERO, > + rvoff, false, &ctx); Nit: Please use the full 100 char width. > } > > +#define DETOUR_NINSNS 2 Better name? Maybe call it something that refers to the patchable function entry? Also, to catch future breakage like this -- would it make sense to have a static_assert() combined with something tied to -fpatchable-function-entry= from arch/riscv/Makefile? > + > int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type poke_type, > void *old_addr, void *new_addr) > { > - u32 old_insns[4], new_insns[4]; > + u32 old_insns[DETOUR_NINSNS], new_insns[DETOUR_NINSNS]; > bool is_call = poke_type == BPF_MOD_CALL; > - int (*gen_insns)(void *target, void *ip, u32 *insns); > - int ninsns = is_call ? 4 : 2; > int ret; > > - if (!is_bpf_text_address((unsigned long)ip)) > + if (!is_kernel_text((unsigned long)ip) && > + !is_bpf_text_address((unsigned long)ip)) > return -ENOTSUPP; > > - gen_insns = is_call ? 
gen_call_or_nops : gen_jump_or_nops; > - > - ret = gen_insns(old_addr, ip, old_insns); > + ret = gen_jump_or_nops(old_addr, ip, old_insns, is_call); > if (ret) > return ret; > > - if (memcmp(ip, old_insns, ninsns * 4)) > + if (memcmp(ip, old_insns, DETOUR_NINSNS * 4)) > return -EFAULT; > > - ret = gen_insns(new_addr, ip, new_insns); > + ret = gen_jump_or_nops(new_addr, ip, new_insns, is_call); > if (ret) > return ret; > > cpus_read_lock(); > mutex_lock(&text_mutex); > - if (memcmp(ip, new_insns, ninsns * 4)) > - ret = patch_text(ip, new_insns, ninsns); > + if (memcmp(ip, new_insns, DETOUR_NINSNS * 4)) > + ret = patch_text(ip, new_insns, DETOUR_NINSNS); > mutex_unlock(&text_mutex); > cpus_read_unlock(); > > @@ -717,7 +692,7 @@ static void restore_args(int nregs, int args_off, struct rv_jit_context *ctx) > } > > static int invoke_bpf_prog(struct bpf_tramp_link *l, int args_off, int retval_off, > - int run_ctx_off, bool save_ret, struct rv_jit_context *ctx) > + int run_ctx_off, bool save_retval, struct rv_jit_context *ctx) Why the save_retval name change? This churn is not needed IMO (especially since you keep using the _ret name below). Please keep the old name. 
> { > int ret, branch_off; > struct bpf_prog *p = l->link.prog; > @@ -757,7 +732,7 @@ static int invoke_bpf_prog(struct bpf_tramp_link *l, int args_off, int retval_of > if (ret) > return ret; > > - if (save_ret) > + if (save_retval) > emit_sd(RV_REG_FP, -retval_off, regmap[BPF_REG_0], ctx); > > /* update branch with beqz */ > @@ -787,20 +762,19 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, > int i, ret, offset; > int *branches_off = NULL; > int stack_size = 0, nregs = m->nr_args; > - int retaddr_off, fp_off, retval_off, args_off; > - int nregs_off, ip_off, run_ctx_off, sreg_off; > + int fp_off, retval_off, args_off, nregs_off, ip_off, run_ctx_off, sreg_off; > struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY]; > struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT]; > struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN]; > void *orig_call = func_addr; > - bool save_ret; > + bool save_retval, traced_ret; > u32 insn; > > /* Generated trampoline stack layout: > * > * FP - 8 [ RA of parent func ] return address of parent > * function > - * FP - retaddr_off [ RA of traced func ] return address of traced > + * FP - 16 [ RA of traced func ] return address of > traced BPF code uses frame pointers. Shouldn't the trampoline frame look like a regular frame [1], i.e. start with return address followed by previous frame pointer? 
> * function > * FP - fp_off [ FP of parent func ] > * > @@ -833,17 +807,20 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, > if (nregs > 8) > return -ENOTSUPP; > > - /* room for parent function return address */ > + /* room for return address of parent function */ > stack_size += 8; > > - stack_size += 8; > - retaddr_off = stack_size; > + /* whether return to return address of traced function after bpf trampoline */ > + traced_ret = func_addr && !(flags & BPF_TRAMP_F_SKIP_FRAME); > + /* room for return address of traced function */ > + if (traced_ret) > + stack_size += 8; > > stack_size += 8; > fp_off = stack_size; > > - save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET); > - if (save_ret) { > + save_retval = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET); > + if (save_retval) { > stack_size += 8; > retval_off = stack_size; > } > @@ -869,7 +846,11 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, > > emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx); > > - emit_sd(RV_REG_SP, stack_size - retaddr_off, RV_REG_RA, ctx); > + /* store return address of parent function */ > + emit_sd(RV_REG_SP, stack_size - 8, RV_REG_RA, ctx); > + /* store return address of traced function */ > + if (traced_ret) > + emit_sd(RV_REG_SP, stack_size - 16, RV_REG_T0, ctx); > emit_sd(RV_REG_SP, stack_size - fp_off, RV_REG_FP, ctx); > > emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx); > @@ -890,7 +871,7 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, > > /* skip to actual body of traced function */ > if (flags & BPF_TRAMP_F_SKIP_FRAME) > - orig_call += 16; > + orig_call += 8; Use the define above so it's obvious what you're skipping. 
> > if (flags & BPF_TRAMP_F_CALL_ORIG) { > emit_imm(RV_REG_A0, (const s64)im, ctx); > @@ -962,22 +943,25 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, > if (flags & BPF_TRAMP_F_RESTORE_REGS) > restore_args(nregs, args_off, ctx); > > - if (save_ret) > + if (save_retval) > emit_ld(RV_REG_A0, -retval_off, RV_REG_FP, ctx); > > emit_ld(RV_REG_S1, -sreg_off, RV_REG_FP, ctx); > > - if (flags & BPF_TRAMP_F_SKIP_FRAME) > - /* return address of parent function */ > + if (traced_ret) { > + /* restore return address of parent function */ > emit_ld(RV_REG_RA, stack_size - 8, RV_REG_SP, ctx); > - else > - /* return address of traced function */ > - emit_ld(RV_REG_RA, stack_size - retaddr_off, RV_REG_SP, ctx); > + /* restore return address of traced function */ > + emit_ld(RV_REG_T0, stack_size - 16, RV_REG_SP, ctx); > + } else { > + /* restore return address of parent function */ > + emit_ld(RV_REG_T0, stack_size - 8, RV_REG_SP, ctx); > + } > > emit_ld(RV_REG_FP, stack_size - fp_off, RV_REG_SP, ctx); > emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx); > > - emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx); > + emit_jalr(RV_REG_ZERO, RV_REG_T0, 0, ctx); > > ret = ctx->ninsns; > out: > @@ -1664,7 +1648,7 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx, > > void bpf_jit_build_prologue(struct rv_jit_context *ctx) > { > - int i, stack_adjust = 0, store_offset, bpf_stack_adjust; > + int stack_adjust = 0, store_offset, bpf_stack_adjust; > > bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16); > if (bpf_stack_adjust) > @@ -1691,9 +1675,9 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx) > > store_offset = stack_adjust - 8; > > - /* reserve 4 nop insns */ > - for (i = 0; i < 4; i++) > - emit(rv_nop(), ctx); > + /* 2 nops reserved for auipc+jalr pair */ > + emit(rv_nop(), ctx); > + emit(rv_nop(), ctx); Use the define above, instead of hardcoding two nops. 
Thanks, Björn [1] https://github.com/riscv-non-isa/riscv-elf-psabi-doc/blob/master/riscv-cc.adoc#frame-pointer-convention