On Fri, 9 Feb 2024 at 05:06, Alexei Starovoitov <alexei.starovoitov@xxxxxxxxx> wrote:
>
> From: Alexei Starovoitov <ast@xxxxxxxxxx>
>
> Add support for [LDX | STX | ST], PROBE_MEM32, [B | H | W | DW] instructions.
> They are similar to PROBE_MEM instructions with the following differences:
> - PROBE_MEM has to check that the address is in the kernel range with
>   src_reg + insn->off >= TASK_SIZE_MAX + PAGE_SIZE check
> - PROBE_MEM doesn't support store
> - PROBE_MEM32 relies on the verifier to clear upper 32-bit in the register
> - PROBE_MEM32 adds 64-bit kern_vm_start address (which is stored in %r12 in the prologue)
>   Due to bpf_arena constructions such %r12 + %reg + off16 access is guaranteed
>   to be within arena virtual range, so no address check at run-time.
> - PROBE_MEM32 allows STX and ST. If they fault the store is a nop.
>   When LDX faults the destination register is zeroed.
>
> Signed-off-by: Alexei Starovoitov <ast@xxxxxxxxxx>
> ---

Just a potential issue with tail calls, but otherwise lgtm, so:

Acked-by: Kumar Kartikeya Dwivedi <memxor@xxxxxxxxx>

>  arch/x86/net/bpf_jit_comp.c | 183 +++++++++++++++++++++++++++++++++++-
>  include/linux/bpf.h         |   1 +
>  include/linux/filter.h      |   3 +
>  3 files changed, 186 insertions(+), 1 deletion(-)
>
> diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
> index e1390d1e331b..883b7f604b9a 100644
> --- a/arch/x86/net/bpf_jit_comp.c
> +++ b/arch/x86/net/bpf_jit_comp.c
> @@ -113,6 +113,7 @@ static int bpf_size_to_x86_bytes(int bpf_size)
>  /* Pick a register outside of BPF range for JIT internal work */
>  #define AUX_REG (MAX_BPF_JIT_REG + 1)
>  #define X86_REG_R9 (MAX_BPF_JIT_REG + 2)
> +#define X86_REG_R12 (MAX_BPF_JIT_REG + 3)
>
> [...]
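A minimal userspace sketch of the PROBE_MEM32 address computation described in the commit message (%r12 + %reg + off16, with the upper 32 bits of the register already cleared by the verifier). This is a toy model, not JIT output: `probe_mem32_addr` and `in_arena_span` are hypothetical names, and the span check assumes the arena reserves enough guard space on both sides of its 4 GiB window to absorb any signed 16-bit offset; the exact guard size is not stated in this patch.

```c
#include <assert.h>
#include <stdint.h>

/*
 * Toy model (not kernel code) of the PROBE_MEM32 effective address:
 * kern_vm_start (kept in %r12) + 32-bit register value + signed 16-bit off.
 * The cast to uint32_t models the verifier having cleared the upper 32 bits.
 */
static uint64_t probe_mem32_addr(uint64_t kern_vm_start, uint64_t reg, int16_t off)
{
	uint64_t lo32 = (uint32_t)reg;

	/* A negative off converts to uint64_t modulo 2^64, so adding it subtracts. */
	return kern_vm_start + lo32 + (uint64_t)(int64_t)off;
}

/*
 * Worst-case span any such access can touch. Assumption: the arena keeps
 * guard space covering the full off16 range below and above its 4 GiB window.
 */
static int in_arena_span(uint64_t kern_vm_start, uint64_t addr)
{
	return addr >= kern_vm_start - 32768 &&
	       addr <= kern_vm_start + UINT32_MAX + 32767ULL;
}
```

Because the register contributes at most 2^32 - 1 and off16 is bounded, the worst-case address range is known at JIT time, which matches the commit message's point that no run-time address check is emitted; faults inside that range are left to the exception table.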
> +	arena_vm_start = bpf_arena_get_kern_vm_start(bpf_prog->aux->arena);
> +
>  	detect_reg_usage(insn, insn_cnt, callee_regs_used,
>  			 &tail_call_seen);
>
> @@ -1172,8 +1300,13 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image
>  		push_r12(&prog);
>  		push_callee_regs(&prog, all_callee_regs_used);
>  	} else {
> +		if (arena_vm_start)
> +			push_r12(&prog);

I believe that, since r12 is pushed on entry whenever arena_vm_start is set,
we need a matching pop_r12 in emit_bpf_tail_call_indirect and
emit_bpf_tail_call_direct before the tail call, unless I'm missing something.
Otherwise r12 may not be restored after prog (push r12, set it to
arena_vm_start) -> tail call -> exit (r12 is never popped back off the stack).

>  		push_callee_regs(&prog, callee_regs_used);
>  	}
> +	if (arena_vm_start)
> +		emit_mov_imm64(&prog, X86_REG_R12,
> +			       arena_vm_start >> 32, (u32) arena_vm_start);
>
>  	ilen = prog - temp;
>  	if (rw_image)
> @@ -1564,6 +1697,52 @@ st: if (is_imm8(insn->off))
>  			emit_stx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn->off);
>  			break;
>
> +		case BPF_ST | BPF_PROBE_MEM32 | BPF_B:
> +		case BPF_ST | BPF_PROBE_MEM32 | BPF_H:
> +		case BPF_ST | BPF_PROBE_MEM32 | BPF_W:
> +		case BPF_ST | BPF_PROBE_MEM32 | BPF_DW:
> +			start_of_ldx = prog;
> +			emit_st_r12(&prog, BPF_SIZE(insn->code), dst_reg, insn->off, insn->imm);
> +			goto populate_extable;
> +
> +		/* LDX: dst_reg = *(u8*)(src_reg + r12 + off) */
> +		case BPF_LDX | BPF_PROBE_MEM32 | BPF_B:
> +		case BPF_LDX | BPF_PROBE_MEM32 | BPF_H:
> +		case BPF_LDX | BPF_PROBE_MEM32 | BPF_W:
> +		case BPF_LDX | BPF_PROBE_MEM32 | BPF_DW:
> +		case BPF_STX | BPF_PROBE_MEM32 | BPF_B:
> +		case BPF_STX | BPF_PROBE_MEM32 | BPF_H:
> +		case BPF_STX | BPF_PROBE_MEM32 | BPF_W:
> +		case BPF_STX | BPF_PROBE_MEM32 | BPF_DW:
> +			start_of_ldx = prog;
> +			if (BPF_CLASS(insn->code) == BPF_LDX)
> +				emit_ldx_r12(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn->off);
> +			else
> +				emit_stx_r12(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn->off);
> +populate_extable:
> +			{
> +				struct exception_table_entry *ex;
> +				u8 *_insn = image + proglen + (start_of_ldx - temp);
> +				s64 delta;
> +
> +				if (!bpf_prog->aux->extable)
> +					break;
> +
> +				ex = &bpf_prog->aux->extable[excnt++];
> +
> +				delta = _insn - (u8 *)&ex->insn;
> +				/* switch ex to rw buffer for writes */
> +				ex = (void *)rw_image + ((void *)ex - (void *)image);
> +
> +				ex->insn = delta;
> +
> +				ex->data = EX_TYPE_BPF;
> +
> +				ex->fixup = (prog - start_of_ldx) |
> +					((BPF_CLASS(insn->code) == BPF_LDX ? reg2pt_regs[dst_reg] : DONT_CLEAR) << 8);
> +			}
> +			break;
> +
>  		/* LDX: dst_reg = *(u8*)(src_reg + off) */
>  		case BPF_LDX | BPF_MEM | BPF_B:
>  		case BPF_LDX | BPF_PROBE_MEM | BPF_B:
> @@ -2036,6 +2215,8 @@ st: if (is_imm8(insn->off))
>  		pop_r12(&prog);
>  	} else {
>  		pop_callee_regs(&prog, callee_regs_used);
> +		if (arena_vm_start)
> +			pop_r12(&prog);
>  	}

... Basically, this if condition needs to be copied to those two other places.

> [...]
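The imbalance described in the review can be illustrated with a toy userspace model that tracks only %r12 and the slots pushed by the prologue. Everything here (`struct toy_prog`, `struct toy_cpu`, `run_chain`, the `pop_r12_on_tail_call` flag) is hypothetical and not kernel code; it assumes, as the review implies, that the tail-call sequence pops the caller's saved registers and that the target then re-runs the part of its prologue that pushes its own.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

/* Toy model (not kernel code): only %r12 and the slots pushed for it. */
struct toy_prog {
	bool has_arena;          /* hypothetical stand-in for arena_vm_start != 0 */
	uint64_t arena_vm_start;
};

struct toy_cpu {
	uint64_t r12;
	uint64_t stack[16];
	int sp;                  /* number of occupied stack slots */
};

static void push(struct toy_cpu *c, uint64_t v) { c->stack[c->sp++] = v; }
static uint64_t pop(struct toy_cpu *c) { return c->stack[--c->sp]; }

/* Prologue: mirrors "if (arena_vm_start) push_r12()" plus the mov into r12. */
static void prologue(struct toy_cpu *c, const struct toy_prog *p)
{
	if (p->has_arena) {
		push(c, c->r12);
		c->r12 = p->arena_vm_start;
	}
}

/* Exit path: mirrors "if (arena_vm_start) pop_r12()". */
static void epilogue(struct toy_cpu *c, const struct toy_prog *p)
{
	if (p->has_arena)
		c->r12 = pop(c);
}

/* Tail call: optionally pop the caller's saved r12 (the suggested fix),
 * then enter the target's prologue. */
static void tail_call(struct toy_cpu *c, const struct toy_prog *from,
		      const struct toy_prog *to, bool pop_r12_on_tail_call)
{
	if (pop_r12_on_tail_call && from->has_arena)
		c->r12 = pop(c);
	prologue(c, to);
}

/* Run a -> tail call -> b -> exit; true if the caller's r12 and the
 * stack depth are both restored. */
static bool run_chain(const struct toy_prog *a, const struct toy_prog *b,
		      bool pop_r12_on_tail_call)
{
	struct toy_cpu c = { .r12 = 0x1234, .sp = 0 };

	prologue(&c, a);
	tail_call(&c, a, b, pop_r12_on_tail_call);
	epilogue(&c, b);
	return c.r12 == 0x1234 && c.sp == 0;
}
```

With the flag off, an arena program that tail-calls another program leaves the caller's r12 unrestored and an extra slot on the stack; with the flag on (the matching pop_r12 the review asks for), both come back balanced.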