Add interpreter/jit support for the new sign-extension load insns,
which use a new mode (BPF_MEMSX).
Also add verifier support to recognize these insns and to
verify them properly. In the verifier, besides deducing
proper bounds for dst_reg, probed memory access is handled
by remembering the insn mode in the insn->imm field so that
the proper jit insns can be emitted later on.
Signed-off-by: Yonghong Song <yhs@xxxxxx>
---
arch/x86/net/bpf_jit_comp.c | 32 ++++++++-
include/uapi/linux/bpf.h | 1 +
kernel/bpf/core.c | 13 ++++
kernel/bpf/verifier.c | 125 +++++++++++++++++++++++++++------
tools/include/uapi/linux/bpf.h | 1 +
5 files changed, 151 insertions(+), 21 deletions(-)
diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
index 438adb695daa..addeea95f397 100644
--- a/arch/x86/net/bpf_jit_comp.c
+++ b/arch/x86/net/bpf_jit_comp.c
@@ -779,6 +779,29 @@ static void emit_ldx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
*pprog = prog;
}
+/* LDSX: dst_reg = *(s8/s16/s32 *)(src_reg + off), sign-extended into dst_reg */
+static void emit_ldsx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
+{
+ u8 *prog = *pprog;
+
+ switch (size) {
+ case BPF_B:
+ /* Emit 'movsx rax, byte ptr [rax + off]' */
+ EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xBE);
+ break;
+ case BPF_H:
+ /* Emit 'movsx rax, word ptr [rax + off]' */
+ EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xBF);
+ break;
+ case BPF_W:
+ /* Emit 'movsxd rax, dword ptr [rax + off]' */
+ EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x63);
+ break;
+ } /* no default: any other size emits no opcode bytes before the suffix */
+ emit_insn_suffix(&prog, src_reg, dst_reg, off);
+ *pprog = prog;
+}
+
/* STX: *(u8*)(dst_reg + off) = src_reg */
static void emit_stx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
{
@@ -1370,6 +1393,9 @@ st: if (is_imm8(insn->off))
case BPF_LDX | BPF_PROBE_MEM | BPF_W:
case BPF_LDX | BPF_MEM | BPF_DW:
case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
+ case BPF_LDX | BPF_MEMSX | BPF_B:
+ case BPF_LDX | BPF_MEMSX | BPF_H:
+ case BPF_LDX | BPF_MEMSX | BPF_W:
insn_off = insn->off;
if (BPF_MODE(insn->code) == BPF_PROBE_MEM) {
@@ -1415,7 +1441,11 @@ st: if (is_imm8(insn->off))
start_of_ldx = prog;
end_of_jmp[-1] = start_of_ldx - end_of_jmp;
}
- emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
+ if ((BPF_MODE(insn->code) == BPF_PROBE_MEM && insn->imm == BPF_MEMSX) ||
+ BPF_MODE(insn->code) == BPF_MEMSX)
+ emit_ldsx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
+ else
+ emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
if (BPF_MODE(insn->code) == BPF_PROBE_MEM) {
struct exception_table_entry *ex;
u8 *_insn = image + proglen + (start_of_ldx - temp);
diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
index 600d0caebbd8..c7196302d1eb 100644
--- a/include/uapi/linux/bpf.h
+++ b/include/uapi/linux/bpf.h
@@ -19,6 +19,7 @@
/* ld/ldx fields */
#define BPF_DW 0x18 /* double word (64-bit) */
+#define BPF_MEMSX 0x80 /* load with sign extension */
#define BPF_ATOMIC 0xc0 /* atomic memory ops - op type in immediate */
#define BPF_XADD 0xc0 /* exclusive add - legacy name */
diff --git a/kernel/bpf/core.c b/kernel/bpf/core.c
index dc85240a0134..8a1cc658789e 100644
--- a/kernel/bpf/core.c
+++ b/kernel/bpf/core.c
@@ -1610,6 +1610,9 @@ EXPORT_SYMBOL_GPL(__bpf_call_base);
INSN_3(LDX, MEM, H), \
INSN_3(LDX, MEM, W), \
INSN_3(LDX, MEM, DW), \
+ INSN_3(LDX, MEMSX, B), \
+ INSN_3(LDX, MEMSX, H), \
+ INSN_3(LDX, MEMSX, W), \
/* Immediate based. */ \
INSN_3(LD, IMM, DW)
@@ -1942,6 +1945,16 @@ static u64 ___bpf_prog_run(u64 *regs, const struct bpf_insn *insn)
LDST(DW, u64)
#undef LDST
+#define LDS(SIZEOP, SIZE) \
+ LDX_MEMSX_##SIZEOP: \
+ DST = *(SIZE *)(unsigned long) (SRC + insn->off); \
+ CONT;
+ /* LDX_MEMSX_{B,H,W} handlers: load through a signed SIZE so the assignment to DST sign-extends (assumes DST is wider than SIZE) */
+ LDS(B, s8)
+ LDS(H, s16)
+ LDS(W, s32)
+#undef LDS
+
#define ATOMIC_ALU_OP(BOP, KOP) \
case BOP: \
if (BPF_SIZE(insn->code) == BPF_W) \
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 81a93eeac7a0..fbe4ca72d4c1 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -5795,6 +5795,77 @@ static void coerce_reg_to_size(struct bpf_reg_state *reg, int size)
__reg_combine_64_into_32(reg);
}
+static void set_sext64_default_val(struct bpf_reg_state *reg, int size)
+{
+ if (size == 1) { /* signed 64- and 32-bit bounds: full range of the size-byte signed type */
+ reg->smin_value = reg->s32_min_value = S8_MIN;
+ reg->smax_value = reg->s32_max_value = S8_MAX;
+ } else if (size == 2) {
+ reg->smin_value = reg->s32_min_value = S16_MIN;
+ reg->smax_value = reg->s32_max_value = S16_MAX;
+ } else {
+ /* size == 4 (any size other than 1 or 2 also lands here) */
+ reg->smin_value = reg->s32_min_value = S32_MIN;
+ reg->smax_value = reg->s32_max_value = S32_MAX;
+ }
+ reg->umin_value = reg->u32_min_value = 0; /* unsigned bounds are left fully open ... */
+ reg->umax_value = U64_MAX;
+ reg->u32_max_value = U32_MAX;
+ reg->var_off = tnum_unknown; /* ... and no individual bits are known */
+}
+
+static void coerce_reg_to_size_sx(struct bpf_reg_state *reg, int size)
+{
+ u64 top_smax_value, top_smin_value;
+ s64 init_s64_max, init_s64_min, s64_max, s64_min;
+ u64 num_bits = size * 8;
+
+ top_smax_value = ((u64)reg->smax_value >> num_bits) << num_bits;
+ top_smin_value = ((u64)reg->smin_value >> num_bits) << num_bits;
+
+ if (top_smax_value != top_smin_value)
+ goto out;
+
+ /* find the s64_max and s64_min after sign extension */
+ if (size == 1) {
+ init_s64_max = (s8)reg->smax_value;
+ init_s64_min = (s8)reg->smin_value;
+ } else if (size == 2) {
+ init_s64_max = (s16)reg->smax_value;
+ init_s64_min = (s16)reg->smin_value;
+ } else {
+ /* size == 4 */
+ init_s64_max = (s32)reg->smax_value;
+ init_s64_min = (s32)reg->smin_value;
+ }
+
+ s64_max = max(init_s64_max, init_s64_min);
+ s64_min = min(init_s64_max, init_s64_min);
+
+ if (s64_max >= 0 && s64_min >= 0) {
+ reg->smin_value = reg->s32_min_value = s64_min;
+ reg->smax_value = reg->s32_max_value = s64_max;
+ reg->umin_value = reg->u32_min_value = s64_min;
+ reg->umax_value = reg->u32_max_value = s64_max;
+ reg->var_off = tnum_range(s64_min, s64_max);
+ return;
+ }
+
+ if (s64_min < 0 && s64_max < 0) {
+ reg->smin_value = reg->s32_min_value = s64_min;
+ reg->smax_value = reg->s32_max_value = s64_max;
+ reg->umin_value = (u64)s64_max;
+ reg->umax_value = (u64)s64_min;