On Fri, 2023-10-27 at 11:13 -0700, Andrii Nakryiko wrote: > Change reg_set_min_max() to take FALSE/TRUE sets of two registers each, > instead of assuming that we are always comparing to a constant. For now > we still assume that right-hand side registers are constants (and make > sure that's the case by swapping src/dst regs, if necessary), but > subsequent patches will remove this limitation. > > Taking two by two registers allows to further unify and simplify > check_cond_jmp_op() logic. We utilize fake register for BPF_K > conditional jump case, just like with is_branch_taken() part. > > Signed-off-by: Andrii Nakryiko <andrii@xxxxxxxxxx> Acked-by: Eduard Zingerman <eddyz87@xxxxxxxxx> > --- > kernel/bpf/verifier.c | 112 ++++++++++++++++++------------------------ > 1 file changed, 49 insertions(+), 63 deletions(-) > > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c > index dde04b17c3a3..522566699fbe 100644 > --- a/kernel/bpf/verifier.c > +++ b/kernel/bpf/verifier.c > @@ -14387,26 +14387,43 @@ static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg > * In JEQ/JNE cases we also adjust the var_off values. > */ > static void reg_set_min_max(struct bpf_reg_state *true_reg1, > + struct bpf_reg_state *true_reg2, > struct bpf_reg_state *false_reg1, > - u64 val, u32 val32, > + struct bpf_reg_state *false_reg2, > u8 opcode, bool is_jmp32) > { > - struct tnum false_32off = tnum_subreg(false_reg1->var_off); > - struct tnum false_64off = false_reg1->var_off; > - struct tnum true_32off = tnum_subreg(true_reg1->var_off); > - struct tnum true_64off = true_reg1->var_off; > - s64 sval = (s64)val; > - s32 sval32 = (s32)val32; > - > - /* If the dst_reg is a pointer, we can't learn anything about its > - * variable offset from the compare (unless src_reg were a pointer into > - * the same object, but we don't bother with that. 
> - * Since false_reg1 and true_reg1 have the same type by construction, we > - * only need to check one of them for pointerness. > + struct tnum false_32off, false_64off; > + struct tnum true_32off, true_64off; > + u64 val; > + u32 val32; > + s64 sval; > + s32 sval32; > + > + /* If either register is a pointer, we can't learn anything about its > + * variable offset from the compare (unless they were a pointer into > + * the same object, but we don't bother with that). > */ > - if (__is_pointer_value(false, false_reg1)) > + if (false_reg1->type != SCALAR_VALUE || false_reg2->type != SCALAR_VALUE) > + return; > + > + /* we expect right-hand registers (src ones) to be constants, for now */ > + if (!is_reg_const(false_reg2, is_jmp32)) { > + opcode = flip_opcode(opcode); > + swap(true_reg1, true_reg2); > + swap(false_reg1, false_reg2); > + } > + if (!is_reg_const(false_reg2, is_jmp32)) > return; > > + false_32off = tnum_subreg(false_reg1->var_off); > + false_64off = false_reg1->var_off; > + true_32off = tnum_subreg(true_reg1->var_off); > + true_64off = true_reg1->var_off; > + val = false_reg2->var_off.value; > + val32 = (u32)tnum_subreg(false_reg2->var_off).value; > + sval = (s64)val; > + sval32 = (s32)val32; > + > switch (opcode) { > /* JEQ/JNE comparison doesn't change the register equivalence. > * > @@ -14543,22 +14560,6 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg1, > } > } > > -/* Same as above, but for the case that dst_reg holds a constant and src_reg is > - * the variable reg. > - */ > -static void reg_set_min_max_inv(struct bpf_reg_state *true_reg, > - struct bpf_reg_state *false_reg, > - u64 val, u32 val32, > - u8 opcode, bool is_jmp32) > -{ > - opcode = flip_opcode(opcode); > - /* This uses zero as "not present in table"; luckily the zero opcode, > - * BPF_JA, can't get here. 
> - */ > - if (opcode) > - reg_set_min_max(true_reg, false_reg, val, val32, opcode, is_jmp32); > -} > - > /* Regs are known to be equal, so intersect their min/max/var_off */ > static void __reg_combine_min_max(struct bpf_reg_state *src_reg, > struct bpf_reg_state *dst_reg) > @@ -14891,45 +14892,30 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env, > * comparable. > */ > if (BPF_SRC(insn->code) == BPF_X) { > - struct bpf_reg_state *src_reg = &regs[insn->src_reg]; > + reg_set_min_max(&other_branch_regs[insn->dst_reg], > + &other_branch_regs[insn->src_reg], > + dst_reg, src_reg, opcode, is_jmp32); > > if (dst_reg->type == SCALAR_VALUE && > - src_reg->type == SCALAR_VALUE) { > - if (tnum_is_const(src_reg->var_off) || > - (is_jmp32 && > - tnum_is_const(tnum_subreg(src_reg->var_off)))) > - reg_set_min_max(&other_branch_regs[insn->dst_reg], > - dst_reg, > - src_reg->var_off.value, > - tnum_subreg(src_reg->var_off).value, > - opcode, is_jmp32); > - else if (tnum_is_const(dst_reg->var_off) || > - (is_jmp32 && > - tnum_is_const(tnum_subreg(dst_reg->var_off)))) > - reg_set_min_max_inv(&other_branch_regs[insn->src_reg], > - src_reg, > - dst_reg->var_off.value, > - tnum_subreg(dst_reg->var_off).value, > - opcode, is_jmp32); > - else if (!is_jmp32 && > - (opcode == BPF_JEQ || opcode == BPF_JNE)) > - /* Comparing for equality, we can combine knowledge */ > - reg_combine_min_max(&other_branch_regs[insn->src_reg], > - &other_branch_regs[insn->dst_reg], > - src_reg, dst_reg, opcode); > - if (src_reg->id && > - !WARN_ON_ONCE(src_reg->id != other_branch_regs[insn->src_reg].id)) { > - find_equal_scalars(this_branch, src_reg); > - find_equal_scalars(other_branch, &other_branch_regs[insn->src_reg]); > - } > - > + src_reg->type == SCALAR_VALUE && > + !is_jmp32 && (opcode == BPF_JEQ || opcode == BPF_JNE)) { > + /* Comparing for equality, we can combine knowledge */ > + reg_combine_min_max(&other_branch_regs[insn->src_reg], > + &other_branch_regs[insn->dst_reg], > + src_reg, 
dst_reg, opcode); > } > } else if (dst_reg->type == SCALAR_VALUE) { > - reg_set_min_max(&other_branch_regs[insn->dst_reg], > - dst_reg, insn->imm, (u32)insn->imm, > - opcode, is_jmp32); > + reg_set_min_max(&other_branch_regs[insn->dst_reg], src_reg, /* fake one */ > + dst_reg, src_reg /* same fake one */, > + opcode, is_jmp32); > } > > + if (BPF_SRC(insn->code) == BPF_X && > + src_reg->type == SCALAR_VALUE && src_reg->id && > + !WARN_ON_ONCE(src_reg->id != other_branch_regs[insn->src_reg].id)) { > + find_equal_scalars(this_branch, src_reg); > + find_equal_scalars(other_branch, &other_branch_regs[insn->src_reg]); > + } > if (dst_reg->type == SCALAR_VALUE && dst_reg->id && > !WARN_ON_ONCE(dst_reg->id != other_branch_regs[insn->dst_reg].id)) { > find_equal_scalars(this_branch, dst_reg);