Change reg_set_min_max() to take FALSE/TRUE sets of two registers each, instead of assuming that we are always comparing to a constant. For now we still assume that right-hand side registers are constants (and make sure that's the case by swapping src/dst regs, if necessary), but subsequent patches will remove this limitation. Taking two pairs of registers allows us to further unify and simplify the check_cond_jmp_op() logic. We utilize a fake register for the BPF_K conditional jump case, just like in the is_branch_taken() part. Signed-off-by: Andrii Nakryiko <andrii@xxxxxxxxxx> --- kernel/bpf/verifier.c | 112 ++++++++++++++++++------------------------ 1 file changed, 49 insertions(+), 63 deletions(-) diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c index dde04b17c3a3..522566699fbe 100644 --- a/kernel/bpf/verifier.c +++ b/kernel/bpf/verifier.c @@ -14387,26 +14387,43 @@ static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg * In JEQ/JNE cases we also adjust the var_off values. */ static void reg_set_min_max(struct bpf_reg_state *true_reg1, + struct bpf_reg_state *true_reg2, struct bpf_reg_state *false_reg1, - u64 val, u32 val32, + struct bpf_reg_state *false_reg2, u8 opcode, bool is_jmp32) { - struct tnum false_32off = tnum_subreg(false_reg1->var_off); - struct tnum false_64off = false_reg1->var_off; - struct tnum true_32off = tnum_subreg(true_reg1->var_off); - struct tnum true_64off = true_reg1->var_off; - s64 sval = (s64)val; - s32 sval32 = (s32)val32; - - /* If the dst_reg is a pointer, we can't learn anything about its - * variable offset from the compare (unless src_reg were a pointer into - * the same object, but we don't bother with that. - * Since false_reg1 and true_reg1 have the same type by construction, we - * only need to check one of them for pointerness. 
+ struct tnum false_32off, false_64off; + struct tnum true_32off, true_64off; + u64 val; + u32 val32; + s64 sval; + s32 sval32; + + /* If either register is a pointer, we can't learn anything about its + * variable offset from the compare (unless they were a pointer into + * the same object, but we don't bother with that). */ - if (__is_pointer_value(false, false_reg1)) + if (false_reg1->type != SCALAR_VALUE || false_reg2->type != SCALAR_VALUE) + return; + + /* we expect right-hand registers (src ones) to be constants, for now */ + if (!is_reg_const(false_reg2, is_jmp32)) { + opcode = flip_opcode(opcode); + swap(true_reg1, true_reg2); + swap(false_reg1, false_reg2); + } + if (!is_reg_const(false_reg2, is_jmp32)) return; + false_32off = tnum_subreg(false_reg1->var_off); + false_64off = false_reg1->var_off; + true_32off = tnum_subreg(true_reg1->var_off); + true_64off = true_reg1->var_off; + val = false_reg2->var_off.value; + val32 = (u32)tnum_subreg(false_reg2->var_off).value; + sval = (s64)val; + sval32 = (s32)val32; + switch (opcode) { /* JEQ/JNE comparison doesn't change the register equivalence. * @@ -14543,22 +14560,6 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg1, } } -/* Same as above, but for the case that dst_reg holds a constant and src_reg is - * the variable reg. - */ -static void reg_set_min_max_inv(struct bpf_reg_state *true_reg, - struct bpf_reg_state *false_reg, - u64 val, u32 val32, - u8 opcode, bool is_jmp32) -{ - opcode = flip_opcode(opcode); - /* This uses zero as "not present in table"; luckily the zero opcode, - * BPF_JA, can't get here. - */ - if (opcode) - reg_set_min_max(true_reg, false_reg, val, val32, opcode, is_jmp32); -} - /* Regs are known to be equal, so intersect their min/max/var_off */ static void __reg_combine_min_max(struct bpf_reg_state *src_reg, struct bpf_reg_state *dst_reg) @@ -14891,45 +14892,30 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env, * comparable. 
*/ if (BPF_SRC(insn->code) == BPF_X) { - struct bpf_reg_state *src_reg = &regs[insn->src_reg]; + reg_set_min_max(&other_branch_regs[insn->dst_reg], + &other_branch_regs[insn->src_reg], + dst_reg, src_reg, opcode, is_jmp32); if (dst_reg->type == SCALAR_VALUE && - src_reg->type == SCALAR_VALUE) { - if (tnum_is_const(src_reg->var_off) || - (is_jmp32 && - tnum_is_const(tnum_subreg(src_reg->var_off)))) - reg_set_min_max(&other_branch_regs[insn->dst_reg], - dst_reg, - src_reg->var_off.value, - tnum_subreg(src_reg->var_off).value, - opcode, is_jmp32); - else if (tnum_is_const(dst_reg->var_off) || - (is_jmp32 && - tnum_is_const(tnum_subreg(dst_reg->var_off)))) - reg_set_min_max_inv(&other_branch_regs[insn->src_reg], - src_reg, - dst_reg->var_off.value, - tnum_subreg(dst_reg->var_off).value, - opcode, is_jmp32); - else if (!is_jmp32 && - (opcode == BPF_JEQ || opcode == BPF_JNE)) - /* Comparing for equality, we can combine knowledge */ - reg_combine_min_max(&other_branch_regs[insn->src_reg], - &other_branch_regs[insn->dst_reg], - src_reg, dst_reg, opcode); - if (src_reg->id && - !WARN_ON_ONCE(src_reg->id != other_branch_regs[insn->src_reg].id)) { - find_equal_scalars(this_branch, src_reg); - find_equal_scalars(other_branch, &other_branch_regs[insn->src_reg]); - } - + src_reg->type == SCALAR_VALUE && + !is_jmp32 && (opcode == BPF_JEQ || opcode == BPF_JNE)) { + /* Comparing for equality, we can combine knowledge */ + reg_combine_min_max(&other_branch_regs[insn->src_reg], + &other_branch_regs[insn->dst_reg], + src_reg, dst_reg, opcode); } } else if (dst_reg->type == SCALAR_VALUE) { - reg_set_min_max(&other_branch_regs[insn->dst_reg], - dst_reg, insn->imm, (u32)insn->imm, - opcode, is_jmp32); + reg_set_min_max(&other_branch_regs[insn->dst_reg], src_reg, /* fake one */ + dst_reg, src_reg /* same fake one */, + opcode, is_jmp32); } + if (BPF_SRC(insn->code) == BPF_X && + src_reg->type == SCALAR_VALUE && src_reg->id && + !WARN_ON_ONCE(src_reg->id != 
other_branch_regs[insn->src_reg].id)) { + find_equal_scalars(this_branch, src_reg); + find_equal_scalars(other_branch, &other_branch_regs[insn->src_reg]); + } if (dst_reg->type == SCALAR_VALUE && dst_reg->id && !WARN_ON_ONCE(dst_reg->id != other_branch_regs[insn->dst_reg].id)) { find_equal_scalars(this_branch, dst_reg); -- 2.34.1