On Wed, 2023-11-01 at 09:35 -0700, Andrii Nakryiko wrote:
> On Tue, Oct 31, 2023 at 4:25 PM Eduard Zingerman <eddyz87@xxxxxxxxx> wrote:
> >
> > On Fri, 2023-10-27 at 11:13 -0700, Andrii Nakryiko wrote:
> > > Generalize bounds adjustment logic of reg_set_min_max() to handle not
> > > just register vs constant case, but in general any register vs any
> > > register cases. For most of the operations it's trivial extension based
> > > on range vs range comparison logic, we just need to properly pick
> > > min/max of a range to compare against min/max of the other range.
> > >
> > > For BPF_JSET we keep the original capabilities, just make sure JSET is
> > > integrated in the common framework. This is manifested in the
> > > internal-only BPF_JSET + BPF_X "opcode" to allow for simpler and more
> > > uniform rev_opcode() handling. See the code for details. This allows to
> > > reuse the same code exactly both for TRUE and FALSE branches without
> > > explicitly handling both conditions with custom code.
> > >
> > > Note also that now we don't need a special handling of BPF_JEQ/BPF_JNE
> > > in case none of the registers are constants. This is now just a normal
> > > generic case handled by reg_set_min_max().
> > >
> > > To make tnum handling cleaner, tnum_with_subreg() helper is added, as
> > > that's a common operator when dealing with 32-bit subregister bounds.
> > > This keeps the overall logic much less noisy when it comes to tnums.
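
A minimal user-space sketch of what tnum_with_subreg() computes, assuming the usual {value, mask} tnum representation (mask bits are unknown, the other bits equal value). The helper bodies below are re-created for illustration with <stdint.h> types. They are not copied from kernel/bpf/tnum.c. The idea is that tnum_clear_subreg(reg) makes the low 32 bits known-zero and tnum_subreg(subreg) makes the high 32 bits known-zero. OR-ing the two therefore splices reg's upper half onto subreg's lower half without mixing any unknown bits.

```c
#include <stdint.h>

/* Tracked number: bits set in mask are unknown, the rest equal value. */
struct tnum {
	uint64_t value;
	uint64_t mask;
};

struct tnum tnum_const(uint64_t value)
{
	return (struct tnum){ .value = value, .mask = 0 };
}

struct tnum tnum_lshift(struct tnum a, unsigned int shift)
{
	return (struct tnum){ a.value << shift, a.mask << shift };
}

struct tnum tnum_rshift(struct tnum a, unsigned int shift)
{
	return (struct tnum){ a.value >> shift, a.mask >> shift };
}

/* Bitwise OR: a known-one bit wins, otherwise an unknown input bit stays unknown. */
struct tnum tnum_or(struct tnum a, struct tnum b)
{
	uint64_t v = a.value | b.value;
	uint64_t mu = a.mask | b.mask;

	return (struct tnum){ v, mu & ~v };
}

/* Keep only the low 32 bits; the upper 32 bits become known zero. */
struct tnum tnum_subreg(struct tnum a)
{
	return (struct tnum){ a.value & 0xffffffffULL, a.mask & 0xffffffffULL };
}

/* The low 32 bits become known zero. */
struct tnum tnum_clear_subreg(struct tnum a)
{
	return tnum_lshift(tnum_rshift(a, 32), 32);
}

/* Upper 32 bits taken from reg, lower 32 bits taken from subreg. */
struct tnum tnum_with_subreg(struct tnum reg, struct tnum subreg)
{
	return tnum_or(tnum_clear_subreg(reg), tnum_subreg(subreg));
}

/* Lower 32 bits replaced by a known constant. */
struct tnum tnum_const_subreg(struct tnum a, uint32_t value)
{
	return tnum_with_subreg(a, tnum_const(value));
}
```

Any upper bits set in *subreg* are discarded by tnum_subreg(). For example, splicing {0xFFFF0000000000A0, 0xF} into {0x1100000022000000, 0x00FF0000000000FF} gives {0x11000000000000A0, 0x00FF00000000000F}.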
> > >
> > > Signed-off-by: Andrii Nakryiko <andrii@xxxxxxxxxx>
> > > ---
> > >  include/linux/tnum.h  |   4 +
> > >  kernel/bpf/tnum.c     |   7 +-
> > >  kernel/bpf/verifier.c | 321 +++++++++++++++++++-----------------------
> > >  3 files changed, 157 insertions(+), 175 deletions(-)
> > >
> > > diff --git a/include/linux/tnum.h b/include/linux/tnum.h
> > > index 1c3948a1d6ad..3c13240077b8 100644
> > > --- a/include/linux/tnum.h
> > > +++ b/include/linux/tnum.h
> > > @@ -106,6 +106,10 @@ int tnum_sbin(char *str, size_t size, struct tnum a);
> > >  struct tnum tnum_subreg(struct tnum a);
> > >  /* Returns the tnum with the lower 32-bit subreg cleared */
> > >  struct tnum tnum_clear_subreg(struct tnum a);
> > > +/* Returns the tnum with the lower 32-bit subreg in *reg* set to the lower
> > > + * 32-bit subreg in *subreg*
> > > + */
> > > +struct tnum tnum_with_subreg(struct tnum reg, struct tnum subreg);
> > >  /* Returns the tnum with the lower 32-bit subreg set to value */
> > >  struct tnum tnum_const_subreg(struct tnum a, u32 value);
> > >  /* Returns true if 32-bit subreg @a is a known constant*/
> > > diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c
> > > index 3d7127f439a1..f4c91c9b27d7 100644
> > > --- a/kernel/bpf/tnum.c
> > > +++ b/kernel/bpf/tnum.c
> > > @@ -208,7 +208,12 @@ struct tnum tnum_clear_subreg(struct tnum a)
> > >  	return tnum_lshift(tnum_rshift(a, 32), 32);
> > >  }
> > >  
> > > +struct tnum tnum_with_subreg(struct tnum reg, struct tnum subreg)
> > > +{
> > > +	return tnum_or(tnum_clear_subreg(reg), tnum_subreg(subreg));
> > > +}
> > > +
> > >  struct tnum tnum_const_subreg(struct tnum a, u32 value)
> > >  {
> > > -	return tnum_or(tnum_clear_subreg(a), tnum_const(value));
> > > +	return tnum_with_subreg(a, tnum_const(value));
> > >  }
> > > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> > > index 522566699fbe..4c974296127b 100644
> > > --- a/kernel/bpf/verifier.c
> > > +++ b/kernel/bpf/verifier.c
> > > @@ -14381,217 +14381,201 @@ static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg
> > >  	return is_scalar_branch_taken(reg1, reg2, opcode, is_jmp32);
> > >  }
> > >  
> > > -/* Adjusts the register min/max values in the case that the dst_reg is the
> > > - * variable register that we are working on, and src_reg is a constant or we're
> > > - * simply doing a BPF_K check.
> > > - * In JEQ/JNE cases we also adjust the var_off values.
> > > +/* Opcode that corresponds to a *false* branch condition.
> > > + * E.g., if r1 < r2, then reverse (false) condition is r1 >= r2
> > >   */
> > > -static void reg_set_min_max(struct bpf_reg_state *true_reg1,
> > > -			    struct bpf_reg_state *true_reg2,
> > > -			    struct bpf_reg_state *false_reg1,
> > > -			    struct bpf_reg_state *false_reg2,
> > > -			    u8 opcode, bool is_jmp32)
> > > +static u8 rev_opcode(u8 opcode)
> >
> > Note: this duplicates flip_opcode() (modulo BPF_JSET).
>
> Not at all! flip_opcode() is for swapping argument order, so JEQ stays
> JEQ, but <= becomes >=. While rev_opcode() is for the true/false
> branch. So JEQ in the true branch becomes JNE in the false branch, <
> is true is complemented by >= in the false branch.

Right, my bad, sorry.

>
> >
> > >  {
> > > -	struct tnum false_32off, false_64off;
> > > -	struct tnum true_32off, true_64off;
> > > -	u64 val;
> > > -	u32 val32;
> > > -	s64 sval;
> > > -	s32 sval32;
> > > -
> > [...]
>
> > > +		/* we don't derive any new information for inequality yet */
> > > +		break;
> > > +	case BPF_JSET:
> > > +	case BPF_JSET | BPF_X: { /* BPF_JSET and its reverse, see rev_opcode() */
> > > +		u64 val;
> > > +
> > > +		if (!is_reg_const(reg2, is_jmp32))
> > > +			swap(reg1, reg2);
> > > +		if (!is_reg_const(reg2, is_jmp32))
> > > +			break;
> > > +
> > > +		val = reg_const_value(reg2, is_jmp32);
> > > +		/* BPF_JSET requires single bit to learn something useful */
> > > +		if (!(opcode & BPF_X) && !is_power_of_2(val))
> >
> > Could you please extend comment a bit, e.g. as follows:
> >
> >   /* For BPF_JSET true branch (!(opcode & BPF_X)) a single bit
> >    * is needed to learn something useful.
> >    */
> >
> > For some reason it took me a while to understand this condition :(
>
> ok, sure
>
> >
> > > +			break;
> > > +
> > [...]
>
> > > -	case BPF_JGE:
> > >  	case BPF_JGT:
> > > -	{
> > >  		if (is_jmp32) {
> > > -			u32 false_umax = opcode == BPF_JGT ? val32 : val32 - 1;
> > > -			u32 true_umin = opcode == BPF_JGT ? val32 + 1 : val32;
> > > -
> > > -			false_reg1->u32_max_value = min(false_reg1->u32_max_value,
> > > -							false_umax);
> > > -			true_reg1->u32_min_value = max(true_reg1->u32_min_value,
> > > -						       true_umin);
> > > +			reg1->u32_min_value = max(reg1->u32_min_value, reg2->u32_min_value + 1);
> >
> > Question: This branch means that reg1 > reg2, right?
> > If so, why not use reg2->u32_MAX_value, e.g.:
> >
> >   reg1->u32_min_value = max(reg1->u32_min_value, reg2->u32_max_value + 1);
> >
> > Do I miss something?
>
> Let's say reg1 can be anything in [10, 20], while reg2 is in [15, 30].
> if reg1 > reg2, then we can only guarantee that reg1 can be [16, 20],
> because worst case reg2 = 15, not 30, right?

Right, thank you.
I should probably refrain from sending comments after midnight.

>
> >
> > > +			reg2->u32_max_value = min(reg1->u32_max_value - 1, reg2->u32_max_value);
> > >  		} else {
> > > -			u64 false_umax = opcode == BPF_JGT ? val : val - 1;
> > > -			u64 true_umin = opcode == BPF_JGT ? val + 1 : val;
> > > -
> > > -			false_reg1->umax_value = min(false_reg1->umax_value, false_umax);
> > > -			true_reg1->umin_value = max(true_reg1->umin_value, true_umin);
> > > +			reg1->umin_value = max(reg1->umin_value, reg2->umin_value + 1);
> > > +			reg2->umax_value = min(reg1->umax_value - 1, reg2->umax_value);
> > >  		}
> > >  		break;
> > [...]
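
A minimal sketch of the range-vs-range refinement discussed above, covering the unsigned 64-bit bounds only. The names struct urange, enum op, refine(), rev_op() and flip_op() are hypothetical stand-ins. They are not the verifier's bpf_reg_state, rev_opcode() or flip_opcode(). The sketch shows why the JGT case must use reg2's *minimum*: the [10, 20] vs [15, 30] example yields reg1 in [16, 20] and reg2 in [15, 19]. It also shows the two different opcode transforms. The false branch reuses the same code through the reversed condition, and the JLT/JLE cases reuse the JGT/JGE code through the flipped condition with swapped operands.

```c
#include <stdint.h>

/* Hypothetical stand-ins, not the verifier's real opcodes or reg state. */
enum op { OP_JGT, OP_JGE, OP_JLT, OP_JLE };

struct urange {
	uint64_t umin;
	uint64_t umax;
};

static uint64_t min_u64(uint64_t a, uint64_t b) { return a < b ? a : b; }
static uint64_t max_u64(uint64_t a, uint64_t b) { return a > b ? a : b; }

/* Condition that holds in the false branch: !(r1 > r2) is (r1 <= r2). */
enum op rev_op(enum op op)
{
	switch (op) {
	case OP_JGT: return OP_JLE;
	case OP_JGE: return OP_JLT;
	case OP_JLT: return OP_JGE;
	case OP_JLE: return OP_JGT;
	}
	return op;
}

/* Same condition with operands swapped: (r1 > r2) is (r2 < r1). */
enum op flip_op(enum op op)
{
	switch (op) {
	case OP_JGT: return OP_JLT;
	case OP_JGE: return OP_JLE;
	case OP_JLT: return OP_JGT;
	case OP_JLE: return OP_JGE;
	}
	return op;
}

/*
 * Narrow both ranges assuming "r1 op r2" holds. Assumes the branch is
 * feasible (e.g. for JGT, r1->umax > r2->umin), so the +1/-1 below
 * cannot wrap around.
 */
void refine(struct urange *r1, struct urange *r2, enum op op)
{
	switch (op) {
	case OP_JGT:
		/* worst case r2 == r2->umin, so r1 >= r2->umin + 1 */
		r1->umin = max_u64(r1->umin, r2->umin + 1);
		/* r2 < r1 <= r1->umax, so r2 <= r1->umax - 1 */
		r2->umax = min_u64(r1->umax - 1, r2->umax);
		break;
	case OP_JGE:
		r1->umin = max_u64(r1->umin, r2->umin);
		r2->umax = min_u64(r1->umax, r2->umax);
		break;
	case OP_JLT:
	case OP_JLE:
		/* r1 < r2 is r2 > r1: reuse the cases above with swapped operands */
		refine(r2, r1, flip_op(op));
		break;
	}
}
```

With r1 = [10, 20] and r2 = [15, 30], refine(&r1, &r2, OP_JGT) leaves r1 = [16, 20] and r2 = [15, 19]. On the false branch, refine(&r1, &r2, rev_op(OP_JGT)) (that is, r1 <= r2) learns nothing new here, and both ranges stay unchanged.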