On Wed, 2023-07-19 at 17:01 -0700, Yonghong Song wrote: > Add interpreter/jit support for new sign-extension load insns > which adds a new mode (BPF_MEMSX). > Also add verifier support to recognize these insns and to > do proper verification with new insns. In verifier, besides > to deduce proper bounds for the dst_reg, probed memory access > is also properly handled. > > Signed-off-by: Yonghong Song <yhs@xxxxxx> > --- > arch/x86/net/bpf_jit_comp.c | 42 ++++++++- > include/linux/filter.h | 3 + > include/uapi/linux/bpf.h | 1 + > kernel/bpf/core.c | 21 +++++ > kernel/bpf/verifier.c | 150 +++++++++++++++++++++++++++------ > tools/include/uapi/linux/bpf.h | 1 + > 6 files changed, 191 insertions(+), 27 deletions(-) > > diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c > index 83c4b45dc65f..54478a9c93e1 100644 > --- a/arch/x86/net/bpf_jit_comp.c > +++ b/arch/x86/net/bpf_jit_comp.c > @@ -779,6 +779,29 @@ static void emit_ldx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off) > *pprog = prog; > } > > +/* LDSX: dst_reg = *(s8*)(src_reg + off) */ > +static void emit_ldsx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off) > +{ > + u8 *prog = *pprog; > + > + switch (size) { > + case BPF_B: > + /* Emit 'movsx rax, byte ptr [rax + off]' */ > + EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xBE); > + break; > + case BPF_H: > + /* Emit 'movsx rax, word ptr [rax + off]' */ > + EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xBF); > + break; > + case BPF_W: > + /* Emit 'movsx rax, dword ptr [rax+0x14]' */ > + EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x63); > + break; > + } > + emit_insn_suffix(&prog, src_reg, dst_reg, off); > + *pprog = prog; > +} > + > /* STX: *(u8*)(dst_reg + off) = src_reg */ > static void emit_stx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off) > { > @@ -1370,9 +1393,17 @@ st: if (is_imm8(insn->off)) > case BPF_LDX | BPF_PROBE_MEM | BPF_W: > case BPF_LDX | BPF_MEM | BPF_DW: > case BPF_LDX | BPF_PROBE_MEM | BPF_DW: > + /* 
LDXS: dst_reg = *(s8*)(src_reg + off) */ > + case BPF_LDX | BPF_MEMSX | BPF_B: > + case BPF_LDX | BPF_MEMSX | BPF_H: > + case BPF_LDX | BPF_MEMSX | BPF_W: > + case BPF_LDX | BPF_PROBE_MEMSX | BPF_B: > + case BPF_LDX | BPF_PROBE_MEMSX | BPF_H: > + case BPF_LDX | BPF_PROBE_MEMSX | BPF_W: > insn_off = insn->off; > > - if (BPF_MODE(insn->code) == BPF_PROBE_MEM) { > + if (BPF_MODE(insn->code) == BPF_PROBE_MEM || > + BPF_MODE(insn->code) == BPF_PROBE_MEMSX) { > /* Conservatively check that src_reg + insn->off is a kernel address: > * src_reg + insn->off >= TASK_SIZE_MAX + PAGE_SIZE > * src_reg is used as scratch for src_reg += insn->off and restored > @@ -1415,8 +1446,13 @@ st: if (is_imm8(insn->off)) > start_of_ldx = prog; > end_of_jmp[-1] = start_of_ldx - end_of_jmp; > } > - emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off); > - if (BPF_MODE(insn->code) == BPF_PROBE_MEM) { > + if (BPF_MODE(insn->code) == BPF_PROBE_MEMSX || > + BPF_MODE(insn->code) == BPF_MEMSX) > + emit_ldsx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off); > + else > + emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off); > + if (BPF_MODE(insn->code) == BPF_PROBE_MEM || > + BPF_MODE(insn->code) == BPF_PROBE_MEMSX) { > struct exception_table_entry *ex; > u8 *_insn = image + proglen + (start_of_ldx - temp); > s64 delta; > diff --git a/include/linux/filter.h b/include/linux/filter.h > index f69114083ec7..a93242b5516b 100644 > --- a/include/linux/filter.h > +++ b/include/linux/filter.h > @@ -69,6 +69,9 @@ struct ctl_table_header; > /* unused opcode to mark special load instruction. Same as BPF_ABS */ > #define BPF_PROBE_MEM 0x20 > > +/* unused opcode to mark special ldsx instruction. 
Same as BPF_IND */ > +#define BPF_PROBE_MEMSX 0x40 > + > /* unused opcode to mark call to interpreter with arguments */ > #define BPF_CALL_ARGS 0xe0 > > diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h > index 739c15906a65..651a34511780 100644 > --- a/include/uapi/linux/bpf.h > +++ b/include/uapi/linux/bpf.h > @@ -19,6 +19,7 @@ > > /* ld/ldx fields */ > #define BPF_DW 0x18 /* double word (64-bit) */ > +#define BPF_MEMSX 0x80 /* load with sign extension */ > #define BPF_ATOMIC 0xc0 /* atomic memory ops - op type in immediate */ > #define BPF_XADD 0xc0 /* exclusive add - legacy name */ > > diff --git a/kernel/bpf/core.c b/kernel/bpf/core.c > index dc85240a0134..01b72fc001f6 100644 > --- a/kernel/bpf/core.c > +++ b/kernel/bpf/core.c > @@ -1610,6 +1610,9 @@ EXPORT_SYMBOL_GPL(__bpf_call_base); > INSN_3(LDX, MEM, H), \ > INSN_3(LDX, MEM, W), \ > INSN_3(LDX, MEM, DW), \ > + INSN_3(LDX, MEMSX, B), \ > + INSN_3(LDX, MEMSX, H), \ > + INSN_3(LDX, MEMSX, W), \ > /* Immediate based. */ \ > INSN_3(LD, IMM, DW) > > @@ -1666,6 +1669,9 @@ static u64 ___bpf_prog_run(u64 *regs, const struct bpf_insn *insn) > [BPF_LDX | BPF_PROBE_MEM | BPF_H] = &&LDX_PROBE_MEM_H, > [BPF_LDX | BPF_PROBE_MEM | BPF_W] = &&LDX_PROBE_MEM_W, > [BPF_LDX | BPF_PROBE_MEM | BPF_DW] = &&LDX_PROBE_MEM_DW, > + [BPF_LDX | BPF_PROBE_MEMSX | BPF_B] = &&LDX_PROBE_MEMSX_B, > + [BPF_LDX | BPF_PROBE_MEMSX | BPF_H] = &&LDX_PROBE_MEMSX_H, > + [BPF_LDX | BPF_PROBE_MEMSX | BPF_W] = &&LDX_PROBE_MEMSX_W, > }; > #undef BPF_INSN_3_LBL > #undef BPF_INSN_2_LBL > @@ -1942,6 +1948,21 @@ static u64 ___bpf_prog_run(u64 *regs, const struct bpf_insn *insn) > LDST(DW, u64) > #undef LDST > > +#define LDSX(SIZEOP, SIZE) \ > + LDX_MEMSX_##SIZEOP: \ > + DST = *(SIZE *)(unsigned long) (SRC + insn->off); \ > + CONT; \ > + LDX_PROBE_MEMSX_##SIZEOP: \ > + bpf_probe_read_kernel(&DST, sizeof(SIZE), \ > + (const void *)(long) (SRC + insn->off)); \ > + DST = *((SIZE *)&DST); \ > + CONT; > + > + LDSX(B, s8) > + LDSX(H, s16) > + 
LDSX(W, s32) > +#undef LDSX > + > #define ATOMIC_ALU_OP(BOP, KOP) \ > case BOP: \ > if (BPF_SIZE(insn->code) == BPF_W) \ > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c > index 803b91135ca0..79c0cd50ec59 100644 > --- a/kernel/bpf/verifier.c > +++ b/kernel/bpf/verifier.c > @@ -5809,6 +5809,94 @@ static void coerce_reg_to_size(struct bpf_reg_state *reg, int size) > __reg_combine_64_into_32(reg); > } > > +static void set_sext64_default_val(struct bpf_reg_state *reg, int size) > +{ > + if (size == 1) { > + reg->smin_value = reg->s32_min_value = S8_MIN; > + reg->smax_value = reg->s32_max_value = S8_MAX; > + } else if (size == 2) { > + reg->smin_value = reg->s32_min_value = S16_MIN; > + reg->smax_value = reg->s32_max_value = S16_MAX; > + } else { > + /* size == 4 */ > + reg->smin_value = reg->s32_min_value = S32_MIN; > + reg->smax_value = reg->s32_max_value = S32_MAX; > + } > + reg->umin_value = reg->u32_min_value = 0; > + reg->umax_value = U64_MAX; > + reg->u32_max_value = U32_MAX; > + reg->var_off = tnum_unknown; > +} > + > +static void coerce_reg_to_size_sx(struct bpf_reg_state *reg, int size) > +{ > + s64 init_s64_max, init_s64_min, s64_max, s64_min, u64_cval; > + u64 top_smax_value, top_smin_value; > + u64 num_bits = size * 8; > + > + if (tnum_is_const(reg->var_off)) { > + u64_cval = reg->var_off.value; > + if (size == 1) > + reg->var_off = tnum_const((s8)u64_cval); > + else if (size == 2) > + reg->var_off = tnum_const((s16)u64_cval); > + else > + /* size == 4 */ > + reg->var_off = tnum_const((s32)u64_cval); > + > + u64_cval = reg->var_off.value; > + reg->smax_value = reg->smin_value = u64_cval; > + reg->umax_value = reg->umin_value = u64_cval; > + reg->s32_max_value = reg->s32_min_value = u64_cval; > + reg->u32_max_value = reg->u32_min_value = u64_cval; > + return; > + } > + > + top_smax_value = ((u64)reg->smax_value >> num_bits) << num_bits; > + top_smin_value = ((u64)reg->smin_value >> num_bits) << num_bits; > + > + if (top_smax_value != 
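top_smin_value) > + goto out; > + > + /* find the s64_min and s64_min after sign extension */ > + if (size == 1) { > + init_s64_max = (s8)reg->smax_value; > + init_s64_min = (s8)reg->smin_value; > + } else if (size == 2) { > + init_s64_max = (s16)reg->smax_value; > + init_s64_min = (s16)reg->smin_value; > + } else { > + init_s64_max = (s32)reg->smax_value; > + init_s64_min = (s32)reg->smin_value; > + } > + > + s64_max = max(init_s64_max, init_s64_min); > + s64_min = min(init_s64_max, init_s64_min); > + > + if (s64_max >= 0 && s64_min >= 0) { > + reg->smin_value = reg->s32_min_value = s64_min; > + reg->smax_value = reg->s32_max_value = s64_max; > + reg->umin_value = reg->u32_min_value = s64_min; > + reg->umax_value = reg->u32_max_value = s64_max; > + reg->var_off = tnum_range(s64_min, s64_max); > + return; > + } > + > + if (s64_min < 0 && s64_max < 0) { > + reg->smin_value = reg->s32_min_value = s64_min; > + reg->smax_value = reg->s32_max_value = s64_max; > + reg->umin_value = (u64)s64_min; > + reg->umax_value = (u64)s64_max; > + reg->u32_min_value = (u32)s64_min; > + reg->u32_max_value = (u32)s64_max; > + reg->var_off = tnum_range((u64)s64_min, (u64)s64_max); > + return; > + } I think the bodies for (s64_max >= 0 && s64_min >= 0) and (s64_min < 0 && s64_max < 0) now compute the same thing. For non-negative values the explicit (u64)/(u32) casts in the second body are no-ops, so the second body is correct for both cases. The two branches can therefore be merged into a single `if ((s64_max >= 0) == (s64_min >= 0)) { ... }`. When merging, please keep the second body's separate, explicitly cast assignments rather than the first body's chained form. `reg->umin_value = reg->u32_min_value = s64_min;` stores the already u32-truncated value into umin_value, because a chained assignment yields the value of its left operand after conversion. That is wrong once s64_min is negative: e.g. -1 would become 0xffffffff instead of 0xffffffffffffffff. Nit: the comment above should read "find the s64_min and s64_max after sign extension".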
> + > +out: > + set_sext64_default_val(reg, size); > +} > + > static bool bpf_map_is_rdonly(const struct bpf_map *map) > { > /* A map is considered read-only if the following condition are true: > @@ -5829,7 +5917,8 @@ static bool bpf_map_is_rdonly(const struct bpf_map *map) > !bpf_map_write_active(map); > } > > -static int bpf_map_direct_read(struct bpf_map *map, int off, int size, u64 *val) > +static int bpf_map_direct_read(struct bpf_map *map, int off, int size, u64 *val, > + bool is_ldsx) > { > void *ptr; > u64 addr; > @@ -5842,13 +5931,13 @@ static int bpf_map_direct_read(struct bpf_map *map, int off, int size, u64 *val) > > switch (size) { > case sizeof(u8): > - *val = (u64)*(u8 *)ptr; > + *val = is_ldsx ? (s64)*(s8 *)ptr : (u64)*(u8 *)ptr; > break; > case sizeof(u16): > - *val = (u64)*(u16 *)ptr; > + *val = is_ldsx ? (s64)*(s16 *)ptr : (u64)*(u16 *)ptr; > break; > case sizeof(u32): > - *val = (u64)*(u32 *)ptr; > + *val = is_ldsx ? (s64)*(s32 *)ptr : (u64)*(u32 *)ptr; > break; > case sizeof(u64): > *val = *(u64 *)ptr; > @@ -6267,7 +6356,7 @@ static int check_stack_access_within_bounds( > */ > static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regno, > int off, int bpf_size, enum bpf_access_type t, > - int value_regno, bool strict_alignment_once) > + int value_regno, bool strict_alignment_once, bool is_ldsx) > { > struct bpf_reg_state *regs = cur_regs(env); > struct bpf_reg_state *reg = regs + regno; > @@ -6328,7 +6417,7 @@ static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regn > u64 val = 0; > > err = bpf_map_direct_read(map, map_off, size, > - &val); > + &val, is_ldsx); > if (err) > return err; > > @@ -6498,8 +6587,11 @@ static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regn > > if (!err && size < BPF_REG_SIZE && value_regno >= 0 && t == BPF_READ && > regs[value_regno].type == SCALAR_VALUE) { > - /* b/h/w load zero-extends, mark upper bits as known 0 */ > - 
coerce_reg_to_size(&regs[value_regno], size); > + if (!is_ldsx) > + /* b/h/w load zero-extends, mark upper bits as known 0 */ > + coerce_reg_to_size(&regs[value_regno], size); > + else > + coerce_reg_to_size_sx(&regs[value_regno], size); > } > return err; > } > @@ -6591,17 +6683,17 @@ static int check_atomic(struct bpf_verifier_env *env, int insn_idx, struct bpf_i > * case to simulate the register fill. > */ > err = check_mem_access(env, insn_idx, insn->dst_reg, insn->off, > - BPF_SIZE(insn->code), BPF_READ, -1, true); > + BPF_SIZE(insn->code), BPF_READ, -1, true, false); > if (!err && load_reg >= 0) > err = check_mem_access(env, insn_idx, insn->dst_reg, insn->off, > BPF_SIZE(insn->code), BPF_READ, load_reg, > - true); > + true, false); > if (err) > return err; > > /* Check whether we can write into the same memory. */ > err = check_mem_access(env, insn_idx, insn->dst_reg, insn->off, > - BPF_SIZE(insn->code), BPF_WRITE, -1, true); > + BPF_SIZE(insn->code), BPF_WRITE, -1, true, false); > if (err) > return err; > > @@ -6847,7 +6939,7 @@ static int check_helper_mem_access(struct bpf_verifier_env *env, int regno, > return zero_size_allowed ?
0 : -EACCES; > > return check_mem_access(env, env->insn_idx, regno, offset, BPF_B, > - atype, -1, false); > + atype, -1, false, false); > } > > fallthrough; > @@ -7219,7 +7311,7 @@ static int process_dynptr_func(struct bpf_verifier_env *env, int regno, int insn > /* we write BPF_DW bits (8 bytes) at a time */ > for (i = 0; i < BPF_DYNPTR_SIZE; i += 8) { > err = check_mem_access(env, insn_idx, regno, > - i, BPF_DW, BPF_WRITE, -1, false); > + i, BPF_DW, BPF_WRITE, -1, false, false); > if (err) > return err; > } > @@ -7312,7 +7404,7 @@ static int process_iter_arg(struct bpf_verifier_env *env, int regno, int insn_id > > for (i = 0; i < nr_slots * 8; i += BPF_REG_SIZE) { > err = check_mem_access(env, insn_idx, regno, > - i, BPF_DW, BPF_WRITE, -1, false); > + i, BPF_DW, BPF_WRITE, -1, false, false); > if (err) > return err; > } > @@ -9456,7 +9548,7 @@ static int check_helper_call(struct bpf_verifier_env *env, struct bpf_insn *insn > */ > for (i = 0; i < meta.access_size; i++) { > err = check_mem_access(env, insn_idx, meta.regno, i, BPF_B, > - BPF_WRITE, -1, false); > + BPF_WRITE, -1, false, false); > if (err) > return err; > } > @@ -16184,7 +16276,7 @@ static int save_aux_ptr_type(struct bpf_verifier_env *env, enum bpf_reg_type typ > * Have to support a use case when one path through > * the program yields TRUSTED pointer while another > * is UNTRUSTED. Fallback to UNTRUSTED to generate > - * BPF_PROBE_MEM. > + * BPF_PROBE_MEM/BPF_PROBE_MEMSX. 
> */ > *prev_type = PTR_TO_BTF_ID | PTR_UNTRUSTED; > } else { > @@ -16325,7 +16417,8 @@ static int do_check(struct bpf_verifier_env *env) > */ > err = check_mem_access(env, env->insn_idx, insn->src_reg, > insn->off, BPF_SIZE(insn->code), > - BPF_READ, insn->dst_reg, false); > + BPF_READ, insn->dst_reg, false, > + BPF_MODE(insn->code) == BPF_MEMSX); > if (err) > return err; > > @@ -16362,7 +16455,7 @@ static int do_check(struct bpf_verifier_env *env) > /* check that memory (dst_reg + off) is writeable */ > err = check_mem_access(env, env->insn_idx, insn->dst_reg, > insn->off, BPF_SIZE(insn->code), > - BPF_WRITE, insn->src_reg, false); > + BPF_WRITE, insn->src_reg, false, false); > if (err) > return err; > > @@ -16387,7 +16480,7 @@ static int do_check(struct bpf_verifier_env *env) > /* check that memory (dst_reg + off) is writeable */ > err = check_mem_access(env, env->insn_idx, insn->dst_reg, > insn->off, BPF_SIZE(insn->code), > - BPF_WRITE, -1, false); > + BPF_WRITE, -1, false, false); > if (err) > return err; > > @@ -16815,7 +16908,8 @@ static int resolve_pseudo_ldimm64(struct bpf_verifier_env *env) > > for (i = 0; i < insn_cnt; i++, insn++) { > if (BPF_CLASS(insn->code) == BPF_LDX && > - (BPF_MODE(insn->code) != BPF_MEM || insn->imm != 0)) { > + ((BPF_MODE(insn->code) != BPF_MEM && BPF_MODE(insn->code) != BPF_MEMSX) || > + insn->imm != 0)) { > verbose(env, "BPF_LDX uses reserved fields\n"); > return -EINVAL; > } > @@ -17513,7 +17607,10 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env) > if (insn->code == (BPF_LDX | BPF_MEM | BPF_B) || > insn->code == (BPF_LDX | BPF_MEM | BPF_H) || > insn->code == (BPF_LDX | BPF_MEM | BPF_W) || > - insn->code == (BPF_LDX | BPF_MEM | BPF_DW)) { > + insn->code == (BPF_LDX | BPF_MEM | BPF_DW) || > + insn->code == (BPF_LDX | BPF_MEMSX | BPF_B) || > + insn->code == (BPF_LDX | BPF_MEMSX | BPF_H) || > + insn->code == (BPF_LDX | BPF_MEMSX | BPF_W)) { > type = BPF_READ; > } else if (insn->code == (BPF_STX | BPF_MEM | BPF_B) 
|| > insn->code == (BPF_STX | BPF_MEM | BPF_H) || > @@ -17572,8 +17669,12 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env) > */ > case PTR_TO_BTF_ID | MEM_ALLOC | PTR_UNTRUSTED: > if (type == BPF_READ) { > - insn->code = BPF_LDX | BPF_PROBE_MEM | > - BPF_SIZE((insn)->code); > + if (BPF_MODE(insn->code) == BPF_MEM) > + insn->code = BPF_LDX | BPF_PROBE_MEM | > + BPF_SIZE((insn)->code); > + else > + insn->code = BPF_LDX | BPF_PROBE_MEMSX | > + BPF_SIZE((insn)->code); > env->prog->aux->num_exentries++; > } > continue; > @@ -17761,7 +17862,8 @@ static int jit_subprogs(struct bpf_verifier_env *env) > insn = func[i]->insnsi; > for (j = 0; j < func[i]->len; j++, insn++) { > if (BPF_CLASS(insn->code) == BPF_LDX && > - BPF_MODE(insn->code) == BPF_PROBE_MEM) > + (BPF_MODE(insn->code) == BPF_PROBE_MEM || > + BPF_MODE(insn->code) == BPF_PROBE_MEMSX)) > num_exentries++; > } > func[i]->aux->num_exentries = num_exentries; > diff --git a/tools/include/uapi/linux/bpf.h b/tools/include/uapi/linux/bpf.h > index 739c15906a65..651a34511780 100644 > --- a/tools/include/uapi/linux/bpf.h > +++ b/tools/include/uapi/linux/bpf.h > @@ -19,6 +19,7 @@ > > /* ld/ldx fields */ > #define BPF_DW 0x18 /* double word (64-bit) */ > +#define BPF_MEMSX 0x80 /* load with sign extension */ > #define BPF_ATOMIC 0xc0 /* atomic memory ops - op type in immediate */ > #define BPF_XADD 0xc0 /* exclusive add - legacy name */ >