On Thu, 2023-11-02 at 17:08 -0700, Andrii Nakryiko wrote: > Generalize bounds adjustment logic of reg_set_min_max() to handle not > just the register vs constant case, but in general any register vs any > register case. For most of the operations it's a trivial extension based > on range vs range comparison logic, we just need to properly pick > min/max of a range to compare against min/max of the other range. > > For BPF_JSET we keep the original capabilities, just make sure JSET is > integrated in the common framework. This is manifested in the > internal-only BPF_JSET + BPF_X "opcode" to allow for simpler and more > uniform rev_opcode() handling. See the code for details. This allows us to > reuse exactly the same code for both TRUE and FALSE branches without > explicitly handling both conditions with custom code. > > Note also that now we don't need special handling of the BPF_JEQ/BPF_JNE > case when none of the registers are constants. This is now just a normal > generic case handled by reg_set_min_max(). > > To make tnum handling cleaner, a tnum_with_subreg() helper is added, as > that's a common operation when dealing with 32-bit subregister bounds. > This keeps the overall logic much less noisy when it comes to tnums. > > Signed-off-by: Andrii Nakryiko <andrii@xxxxxxxxxx> Acked-by: Eduard Zingerman <eddyz87@xxxxxxxxx> (With one bit of bikeshedding below.) 
> --- > include/linux/tnum.h | 4 + > kernel/bpf/tnum.c | 7 +- > kernel/bpf/verifier.c | 327 ++++++++++++++++++++---------------------- > 3 files changed, 165 insertions(+), 173 deletions(-) > > diff --git a/include/linux/tnum.h b/include/linux/tnum.h > index 1c3948a1d6ad..3c13240077b8 100644 > --- a/include/linux/tnum.h > +++ b/include/linux/tnum.h > @@ -106,6 +106,10 @@ int tnum_sbin(char *str, size_t size, struct tnum a); > struct tnum tnum_subreg(struct tnum a); > /* Returns the tnum with the lower 32-bit subreg cleared */ > struct tnum tnum_clear_subreg(struct tnum a); > +/* Returns the tnum with the lower 32-bit subreg in *reg* set to the lower > + * 32-bit subreg in *subreg* > + */ > +struct tnum tnum_with_subreg(struct tnum reg, struct tnum subreg); > /* Returns the tnum with the lower 32-bit subreg set to value */ > struct tnum tnum_const_subreg(struct tnum a, u32 value); > /* Returns true if 32-bit subreg @a is a known constant*/ > diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c > index 3d7127f439a1..f4c91c9b27d7 100644 > --- a/kernel/bpf/tnum.c > +++ b/kernel/bpf/tnum.c > @@ -208,7 +208,12 @@ struct tnum tnum_clear_subreg(struct tnum a) > return tnum_lshift(tnum_rshift(a, 32), 32); > } > > +struct tnum tnum_with_subreg(struct tnum reg, struct tnum subreg) > +{ > + return tnum_or(tnum_clear_subreg(reg), tnum_subreg(subreg)); > +} > + > struct tnum tnum_const_subreg(struct tnum a, u32 value) > { > - return tnum_or(tnum_clear_subreg(a), tnum_const(value)); > + return tnum_with_subreg(a, tnum_const(value)); > } > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c > index 2197385d91dc..52934080042c 100644 > --- a/kernel/bpf/verifier.c > +++ b/kernel/bpf/verifier.c > @@ -14379,218 +14379,211 @@ static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg > return is_scalar_branch_taken(reg1, reg2, opcode, is_jmp32); > } > > -/* Adjusts the register min/max values in the case that the dst_reg and > - * src_reg are both SCALAR_VALUE 
registers (or we are simply doing a BPF_K > - * check, in which case we havea fake SCALAR_VALUE representing insn->imm). > - * Technically we can do similar adjustments for pointers to the same object, > - * but we don't support that right now. > +/* Opcode that corresponds to a *false* branch condition. > + * E.g., if r1 < r2, then reverse (false) condition is r1 >= r2 > */ > -static void reg_set_min_max(struct bpf_reg_state *true_reg1, > - struct bpf_reg_state *true_reg2, > - struct bpf_reg_state *false_reg1, > - struct bpf_reg_state *false_reg2, > - u8 opcode, bool is_jmp32) > +static u8 rev_opcode(u8 opcode) > { > - struct tnum false_32off, false_64off; > - struct tnum true_32off, true_64off; > - u64 uval; > - u32 uval32; > - s64 sval; > - s32 sval32; > - > - /* If either register is a pointer, we can't learn anything about its > - * variable offset from the compare (unless they were a pointer into > - * the same object, but we don't bother with that). > + switch (opcode) { > + case BPF_JEQ: return BPF_JNE; > + case BPF_JNE: return BPF_JEQ; > + /* JSET doesn't have it's reverse opcode in BPF, so add > + * BPF_X flag to denote the reverse of that operation > */ > - if (false_reg1->type != SCALAR_VALUE || false_reg2->type != SCALAR_VALUE) > - return; > - > - /* we expect right-hand registers (src ones) to be constants, for now */ > - if (!is_reg_const(false_reg2, is_jmp32)) { > - opcode = flip_opcode(opcode); > - swap(true_reg1, true_reg2); > - swap(false_reg1, false_reg2); > + case BPF_JSET: return BPF_JSET | BPF_X; > + case BPF_JSET | BPF_X: return BPF_JSET; > + case BPF_JGE: return BPF_JLT; > + case BPF_JGT: return BPF_JLE; > + case BPF_JLE: return BPF_JGT; > + case BPF_JLT: return BPF_JGE; > + case BPF_JSGE: return BPF_JSLT; > + case BPF_JSGT: return BPF_JSLE; > + case BPF_JSLE: return BPF_JSGT; > + case BPF_JSLT: return BPF_JSGE; > + default: return 0; > } > - if (!is_reg_const(false_reg2, is_jmp32)) > - return; > +} > > - false_32off = 
tnum_subreg(false_reg1->var_off); > - false_64off = false_reg1->var_off; > - true_32off = tnum_subreg(true_reg1->var_off); > - true_64off = true_reg1->var_off; > - uval = false_reg2->var_off.value; > - uval32 = (u32)tnum_subreg(false_reg2->var_off).value; > - sval = (s64)uval; > - sval32 = (s32)uval32; > +/* Refine range knowledge for <reg1> <op> <reg>2 conditional operation. */ > +static void regs_refine_cond_op(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2, > + u8 opcode, bool is_jmp32) > +{ > + struct tnum t; > > switch (opcode) { > - /* JEQ/JNE comparison doesn't change the register equivalence. > - * > - * r1 = r2; > - * if (r1 == 42) goto label; > - * ... > - * label: // here both r1 and r2 are known to be 42. > - * > - * Hence when marking register as known preserve it's ID. > - */ > case BPF_JEQ: > if (is_jmp32) { > - __mark_reg32_known(true_reg1, uval32); > - true_32off = tnum_subreg(true_reg1->var_off); > + reg1->u32_min_value = max(reg1->u32_min_value, reg2->u32_min_value); > + reg1->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value); > + reg1->s32_min_value = max(reg1->s32_min_value, reg2->s32_min_value); > + reg1->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value); > + reg2->u32_min_value = reg1->u32_min_value; > + reg2->u32_max_value = reg1->u32_max_value; > + reg2->s32_min_value = reg1->s32_min_value; > + reg2->s32_max_value = reg1->s32_max_value; > + > + t = tnum_intersect(tnum_subreg(reg1->var_off), tnum_subreg(reg2->var_off)); > + reg1->var_off = tnum_with_subreg(reg1->var_off, t); > + reg2->var_off = tnum_with_subreg(reg2->var_off, t); > } else { > - ___mark_reg_known(true_reg1, uval); > - true_64off = true_reg1->var_off; > + reg1->umin_value = max(reg1->umin_value, reg2->umin_value); > + reg1->umax_value = min(reg1->umax_value, reg2->umax_value); > + reg1->smin_value = max(reg1->smin_value, reg2->smin_value); > + reg1->smax_value = min(reg1->smax_value, reg2->smax_value); > + reg2->umin_value = reg1->umin_value; > 
+ reg2->umax_value = reg1->umax_value; > + reg2->smin_value = reg1->smin_value; > + reg2->smax_value = reg1->smax_value; > + > + reg1->var_off = tnum_intersect(reg1->var_off, reg2->var_off); > + reg2->var_off = reg1->var_off; > } > break; > case BPF_JNE: > + /* we don't derive any new information for inequality yet */ > + break; > + case BPF_JSET: > + case BPF_JSET | BPF_X: { /* BPF_JSET and its reverse, see rev_opcode() */ > + u64 val; > + > + if (!is_reg_const(reg2, is_jmp32)) > + swap(reg1, reg2); > + if (!is_reg_const(reg2, is_jmp32)) > + break; > + > + val = reg_const_value(reg2, is_jmp32); > + /* BPF_JSET (i.e., TRUE branch, *not* BPF_JSET | BPF_X) > + * requires single bit to learn something useful. E.g., if we > + * know that `r1 & 0x3` is true, then which bits (0, 1, or both) > + * are actually set? We can learn something definite only if > + * it's a single-bit value to begin with. > + * > + * BPF_JSET | BPF_X (i.e., negation of BPF_JSET) doesn't have > + * this restriction. I.e., !(r1 & 0x3) means neither bit 0 nor > + * bit 1 is set, which we can readily use in adjustments. 
> + */ > + if (!(opcode & BPF_X) && !is_power_of_2(val)) > + break; > + > if (is_jmp32) { > - __mark_reg32_known(false_reg1, uval32); > - false_32off = tnum_subreg(false_reg1->var_off); > + if (opcode & BPF_X) > + t = tnum_and(tnum_subreg(reg1->var_off), tnum_const(~val)); > + else > + t = tnum_or(tnum_subreg(reg1->var_off), tnum_const(val)); > + reg1->var_off = tnum_with_subreg(reg1->var_off, t); > } else { > - ___mark_reg_known(false_reg1, uval); > - false_64off = false_reg1->var_off; > + if (opcode & BPF_X) > + reg1->var_off = tnum_and(reg1->var_off, tnum_const(~val)); > + else > + reg1->var_off = tnum_or(reg1->var_off, tnum_const(val)); > } > break; > - case BPF_JSET: > + } > + case BPF_JGE: > if (is_jmp32) { > - false_32off = tnum_and(false_32off, tnum_const(~uval32)); > - if (is_power_of_2(uval32)) > - true_32off = tnum_or(true_32off, > - tnum_const(uval32)); > + reg1->u32_min_value = max(reg1->u32_min_value, reg2->u32_min_value); > + reg2->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value); > } else { > - false_64off = tnum_and(false_64off, tnum_const(~uval)); > - if (is_power_of_2(uval)) > - true_64off = tnum_or(true_64off, > - tnum_const(uval)); > + reg1->umin_value = max(reg1->umin_value, reg2->umin_value); > + reg2->umax_value = min(reg1->umax_value, reg2->umax_value); > } > break; > - case BPF_JGE: > case BPF_JGT: > - { > if (is_jmp32) { > - u32 false_umax = opcode == BPF_JGT ? uval32 : uval32 - 1; > - u32 true_umin = opcode == BPF_JGT ? uval32 + 1 : uval32; > - > - false_reg1->u32_max_value = min(false_reg1->u32_max_value, > - false_umax); > - true_reg1->u32_min_value = max(true_reg1->u32_min_value, > - true_umin); > + reg1->u32_min_value = max(reg1->u32_min_value, reg2->u32_min_value + 1); > + reg2->u32_max_value = min(reg1->u32_max_value - 1, reg2->u32_max_value); > } else { > - u64 false_umax = opcode == BPF_JGT ? uval : uval - 1; > - u64 true_umin = opcode == BPF_JGT ? 
uval + 1 : uval; > - > - false_reg1->umax_value = min(false_reg1->umax_value, false_umax); > - true_reg1->umin_value = max(true_reg1->umin_value, true_umin); > + reg1->umin_value = max(reg1->umin_value, reg2->umin_value + 1); > + reg2->umax_value = min(reg1->umax_value - 1, reg2->umax_value); > } > break; > - } > case BPF_JSGE: > + if (is_jmp32) { > + reg1->s32_min_value = max(reg1->s32_min_value, reg2->s32_min_value); > + reg2->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value); > + } else { > + reg1->smin_value = max(reg1->smin_value, reg2->smin_value); > + reg2->smax_value = min(reg1->smax_value, reg2->smax_value); > + } > + break; > case BPF_JSGT: Some code could be saved here by handling the "less than" family as the mirrored "greater than" case, i.e. swapping the registers and flipping the opcode: case BPF_JLE: case BPF_JLT: case BPF_JSLE: case BPF_JSLT: return regs_refine_cond_op(reg2, reg1, flip_opcode(opcode), is_jmp32); > - { > if (is_jmp32) { > - s32 false_smax = opcode == BPF_JSGT ? sval32 : sval32 - 1; > - s32 true_smin = opcode == BPF_JSGT ? sval32 + 1 : sval32; > - > - false_reg1->s32_max_value = min(false_reg1->s32_max_value, false_smax); > - true_reg1->s32_min_value = max(true_reg1->s32_min_value, true_smin); > + reg1->s32_min_value = max(reg1->s32_min_value, reg2->s32_min_value + 1); > + reg2->s32_max_value = min(reg1->s32_max_value - 1, reg2->s32_max_value); > } else { > - s64 false_smax = opcode == BPF_JSGT ? sval : sval - 1; > - s64 true_smin = opcode == BPF_JSGT ? 
sval + 1 : sval; > - > - false_reg1->smax_value = min(false_reg1->smax_value, false_smax); > - true_reg1->smin_value = max(true_reg1->smin_value, true_smin); > + reg1->smin_value = max(reg1->smin_value, reg2->smin_value + 1); > + reg2->smax_value = min(reg1->smax_value - 1, reg2->smax_value); > } > break; > - } > case BPF_JLE: > + if (is_jmp32) { > + reg1->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value); > + reg2->u32_min_value = max(reg1->u32_min_value, reg2->u32_min_value); > + } else { > + reg1->umax_value = min(reg1->umax_value, reg2->umax_value); > + reg2->umin_value = max(reg1->umin_value, reg2->umin_value); > + } > + break; > case BPF_JLT: > - { > if (is_jmp32) { > - u32 false_umin = opcode == BPF_JLT ? uval32 : uval32 + 1; > - u32 true_umax = opcode == BPF_JLT ? uval32 - 1 : uval32; > - > - false_reg1->u32_min_value = max(false_reg1->u32_min_value, > - false_umin); > - true_reg1->u32_max_value = min(true_reg1->u32_max_value, > - true_umax); > + reg1->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value - 1); > + reg2->u32_min_value = max(reg1->u32_min_value + 1, reg2->u32_min_value); > } else { > - u64 false_umin = opcode == BPF_JLT ? uval : uval + 1; > - u64 true_umax = opcode == BPF_JLT ? uval - 1 : uval; > - > - false_reg1->umin_value = max(false_reg1->umin_value, false_umin); > - true_reg1->umax_value = min(true_reg1->umax_value, true_umax); > + reg1->umax_value = min(reg1->umax_value, reg2->umax_value - 1); > + reg2->umin_value = max(reg1->umin_value + 1, reg2->umin_value); > } > break; > - } > case BPF_JSLE: > + if (is_jmp32) { > + reg1->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value); > + reg2->s32_min_value = max(reg1->s32_min_value, reg2->s32_min_value); > + } else { > + reg1->smax_value = min(reg1->smax_value, reg2->smax_value); > + reg2->smin_value = max(reg1->smin_value, reg2->smin_value); > + } > + break; > case BPF_JSLT: > - { > if (is_jmp32) { > - s32 false_smin = opcode == BPF_JSLT ? 
sval32 : sval32 + 1; > - s32 true_smax = opcode == BPF_JSLT ? sval32 - 1 : sval32; > - > - false_reg1->s32_min_value = max(false_reg1->s32_min_value, false_smin); > - true_reg1->s32_max_value = min(true_reg1->s32_max_value, true_smax); > + reg1->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value - 1); > + reg2->s32_min_value = max(reg1->s32_min_value + 1, reg2->s32_min_value); > } else { > - s64 false_smin = opcode == BPF_JSLT ? sval : sval + 1; > - s64 true_smax = opcode == BPF_JSLT ? sval - 1 : sval; > - > - false_reg1->smin_value = max(false_reg1->smin_value, false_smin); > - true_reg1->smax_value = min(true_reg1->smax_value, true_smax); > + reg1->smax_value = min(reg1->smax_value, reg2->smax_value - 1); > + reg2->smin_value = max(reg1->smin_value + 1, reg2->smin_value); > } > break; > - } > default: > return; > } > - > - if (is_jmp32) { > - false_reg1->var_off = tnum_or(tnum_clear_subreg(false_64off), > - tnum_subreg(false_32off)); > - true_reg1->var_off = tnum_or(tnum_clear_subreg(true_64off), > - tnum_subreg(true_32off)); > - reg_bounds_sync(false_reg1); > - reg_bounds_sync(true_reg1); > - } else { > - false_reg1->var_off = false_64off; > - true_reg1->var_off = true_64off; > - reg_bounds_sync(false_reg1); > - reg_bounds_sync(true_reg1); > - } > -} > - > -/* Regs are known to be equal, so intersect their min/max/var_off */ > -static void __reg_combine_min_max(struct bpf_reg_state *src_reg, > - struct bpf_reg_state *dst_reg) > -{ > - src_reg->umin_value = dst_reg->umin_value = max(src_reg->umin_value, > - dst_reg->umin_value); > - src_reg->umax_value = dst_reg->umax_value = min(src_reg->umax_value, > - dst_reg->umax_value); > - src_reg->smin_value = dst_reg->smin_value = max(src_reg->smin_value, > - dst_reg->smin_value); > - src_reg->smax_value = dst_reg->smax_value = min(src_reg->smax_value, > - dst_reg->smax_value); > - src_reg->var_off = dst_reg->var_off = tnum_intersect(src_reg->var_off, > - dst_reg->var_off); > - reg_bounds_sync(src_reg); > - 
reg_bounds_sync(dst_reg); > } > > -static void reg_combine_min_max(struct bpf_reg_state *true_src, > - struct bpf_reg_state *true_dst, > - struct bpf_reg_state *false_src, > - struct bpf_reg_state *false_dst, > - u8 opcode) > +/* Adjusts the register min/max values in the case that the dst_reg and > + * src_reg are both SCALAR_VALUE registers (or we are simply doing a BPF_K > + * check, in which case we havea fake SCALAR_VALUE representing insn->imm). > + * Technically we can do similar adjustments for pointers to the same object, > + * but we don't support that right now. > + */ > +static void reg_set_min_max(struct bpf_reg_state *true_reg1, > + struct bpf_reg_state *true_reg2, > + struct bpf_reg_state *false_reg1, > + struct bpf_reg_state *false_reg2, > + u8 opcode, bool is_jmp32) > { > - switch (opcode) { > - case BPF_JEQ: > - __reg_combine_min_max(true_src, true_dst); > - break; > - case BPF_JNE: > - __reg_combine_min_max(false_src, false_dst); > - break; > - } > + /* If either register is a pointer, we can't learn anything about its > + * variable offset from the compare (unless they were a pointer into > + * the same object, but we don't bother with that). 
> + */ > + if (false_reg1->type != SCALAR_VALUE || false_reg2->type != SCALAR_VALUE) > + return; > + > + /* fallthrough (FALSE) branch */ > + regs_refine_cond_op(false_reg1, false_reg2, rev_opcode(opcode), is_jmp32); > + reg_bounds_sync(false_reg1); > + reg_bounds_sync(false_reg2); > + > + /* jump (TRUE) branch */ > + regs_refine_cond_op(true_reg1, true_reg2, opcode, is_jmp32); > + reg_bounds_sync(true_reg1); > + reg_bounds_sync(true_reg2); > } > > static void mark_ptr_or_null_reg(struct bpf_func_state *state, > @@ -14887,22 +14880,12 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env, > reg_set_min_max(&other_branch_regs[insn->dst_reg], > &other_branch_regs[insn->src_reg], > dst_reg, src_reg, opcode, is_jmp32); > - > - if (dst_reg->type == SCALAR_VALUE && > - src_reg->type == SCALAR_VALUE && > - !is_jmp32 && (opcode == BPF_JEQ || opcode == BPF_JNE)) { > - /* Comparing for equality, we can combine knowledge */ > - reg_combine_min_max(&other_branch_regs[insn->src_reg], > - &other_branch_regs[insn->dst_reg], > - src_reg, dst_reg, opcode); > - } > } else /* BPF_SRC(insn->code) == BPF_K */ { > reg_set_min_max(&other_branch_regs[insn->dst_reg], > src_reg /* fake one */, > dst_reg, src_reg /* same fake one */, > opcode, is_jmp32); > } > - > if (BPF_SRC(insn->code) == BPF_X && > src_reg->type == SCALAR_VALUE && src_reg->id && > !WARN_ON_ONCE(src_reg->id != other_branch_regs[insn->src_reg].id)) {