On Thu, 2023-11-02 at 17:08 -0700, Andrii Nakryiko wrote: > Generalize is_branch_taken logic for SCALAR_VALUE register to handle > cases when both registers are not constants. Previously supported > <range> vs <scalar> cases are a natural subset of more generic <range> > vs <range> set of cases. > > Generalized logic relies on straightforward segment intersection checks. > > Signed-off-by: Andrii Nakryiko <andrii@xxxxxxxxxx> Acked-by: Eduard Zingerman <eddyz87@xxxxxxxxx> (With the same nitpick that '<' cases could be converted to '>' cases). > --- > kernel/bpf/verifier.c | 103 ++++++++++++++++++++++++++---------------- > 1 file changed, 63 insertions(+), 40 deletions(-) > > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c > index 52934080042c..2627461164ed 100644 > --- a/kernel/bpf/verifier.c > +++ b/kernel/bpf/verifier.c > @@ -14187,82 +14187,104 @@ static int is_scalar_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_sta > u8 opcode, bool is_jmp32) > { > struct tnum t1 = is_jmp32 ? tnum_subreg(reg1->var_off) : reg1->var_off; > + struct tnum t2 = is_jmp32 ? tnum_subreg(reg2->var_off) : reg2->var_off; > u64 umin1 = is_jmp32 ? (u64)reg1->u32_min_value : reg1->umin_value; > u64 umax1 = is_jmp32 ? (u64)reg1->u32_max_value : reg1->umax_value; > s64 smin1 = is_jmp32 ? (s64)reg1->s32_min_value : reg1->smin_value; > s64 smax1 = is_jmp32 ? (s64)reg1->s32_max_value : reg1->smax_value; > - u64 uval = is_jmp32 ? (u32)tnum_subreg(reg2->var_off).value : reg2->var_off.value; > - s64 sval = is_jmp32 ? (s32)uval : (s64)uval; > + u64 umin2 = is_jmp32 ? (u64)reg2->u32_min_value : reg2->umin_value; > + u64 umax2 = is_jmp32 ? (u64)reg2->u32_max_value : reg2->umax_value; > + s64 smin2 = is_jmp32 ? (s64)reg2->s32_min_value : reg2->smin_value; > + s64 smax2 = is_jmp32 ? (s64)reg2->s32_max_value : reg2->smax_value; > > switch (opcode) { > case BPF_JEQ: > - if (tnum_is_const(t1)) > - return !!tnum_equals_const(t1, uval); > - else if (uval < umin1 || uval > umax1) > + /* constants, umin/umax and smin/smax checks would be > + * redundant in this case because they all should match > + */ > + if (tnum_is_const(t1) && tnum_is_const(t2)) > + return t1.value == t2.value; > + /* const ranges */ > + if (umin1 == umax1 && umin2 == umax2) > + return umin1 == umin2; > + if (smin1 == smax1 && smin2 == smax2) > + return smin1 == smin2; > + /* non-overlapping ranges */ > + if (umin1 > umax2 || umax1 < umin2) > return 0; > - else if (sval < smin1 || sval > smax1) > + if (smin1 > smax2 || smax1 < smin2) > return 0; > break; > case BPF_JNE: > - if (tnum_is_const(t1)) > - return !tnum_equals_const(t1, uval); > - else if (uval < umin1 || uval > umax1) > + /* constants, umin/umax and smin/smax checks would be > + * redundant in this case because they all should match > + */ > + if (tnum_is_const(t1) && tnum_is_const(t2)) > + return t1.value != t2.value; > + /* non-overlapping ranges */ > + if (umin1 > umax2 || umax1 < umin2) > return 1; > - else if (sval < smin1 || sval > smax1) > + if (smin1 > smax2 || smax1 < smin2) > return 1; > break; > case BPF_JSET: > - if ((~t1.mask & t1.value) & uval) > + if (!is_reg_const(reg2, is_jmp32)) { > + swap(reg1, reg2); > + swap(t1, t2); > + } > + if (!is_reg_const(reg2, is_jmp32)) > + return -1; > + if ((~t1.mask & t1.value) & t2.value) > return 1; > - if (!((t1.mask | t1.value) & uval)) > + if (!((t1.mask | t1.value) & t2.value)) > return 0; > break; > case BPF_JGT: > - if (umin1 > uval ) > + if (umin1 > umax2) > return 1; > - else if (umax1 <= uval) > + else if (umax1 <= umin2) > return 0; > break; > case BPF_JSGT: > - if (smin1 > sval) > + if (smin1 > smax2) > return 1; > - else if (smax1 <= sval) > + else if (smax1 <= smin2) > return 0; > break; > case BPF_JLT: > - if (umax1 < uval) > + if (umax1 < umin2) > return 1; > - else if (umin1 >= uval) > + else if (umin1 >= umax2) > return 0; > break; > case BPF_JSLT: > - if (smax1 < sval) > + if (smax1 < smin2) > return 1; > - else if (smin1 >= sval) > + else if (smin1 >= smax2) > return 0; > break; > case BPF_JGE: > - if (umin1 >= uval) > + if (umin1 >= umax2) > return 1; > - else if (umax1 < uval) > + else if (umax1 < umin2) > return 0; > break; > case BPF_JSGE: > - if (smin1 >= sval) > + if (smin1 >= smax2) > return 1; > - else if (smax1 < sval) > + else if (smax1 < smin2) > return 0; > break; > case BPF_JLE: > - if (umax1 <= uval) > + if (umax1 <= umin2) > return 1; > - else if (umin1 > uval) > + else if (umin1 > umax2) > return 0; > break; > case BPF_JSLE: > - if (smax1 <= sval) > + if (smax1 <= smin2) > return 1; > - else if (smin1 > sval) > + else if (smin1 > smax2) > return 0; > break; > } > @@ -14341,28 +14363,28 @@ static int is_pkt_ptr_branch_taken(struct bpf_reg_state *dst_reg, > static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2, > u8 opcode, bool is_jmp32) > { > - u64 val; > - > if (reg_is_pkt_pointer_any(reg1) && reg_is_pkt_pointer_any(reg2) && !is_jmp32) > return is_pkt_ptr_branch_taken(reg1, reg2, opcode); > > - /* try to make sure reg2 is a constant SCALAR_VALUE */ > - if (!is_reg_const(reg2, is_jmp32)) { > - opcode = flip_opcode(opcode); > - swap(reg1, reg2); > - } > - /* for now we expect reg2 to be a constant to make any useful decisions */ > - if (!is_reg_const(reg2, is_jmp32)) > - return -1; > - val = reg_const_value(reg2, is_jmp32); > + if (__is_pointer_value(false, reg1) || __is_pointer_value(false, reg2)) { > + u64 val; > + > + /* arrange that reg2 is a scalar, and reg1 is a pointer */ > + if (!is_reg_const(reg2, is_jmp32)) { > + opcode = flip_opcode(opcode); > + swap(reg1, reg2); > + } > + /* and ensure that reg2 is a constant */ > + if (!is_reg_const(reg2, is_jmp32)) > + return -1; > > - if (__is_pointer_value(false, reg1)) { > if (!reg_not_null(reg1)) > return -1; > > /* If pointer is valid tests against zero will fail so we can > * use this to direct branch taken. > */ > + val = reg_const_value(reg2, is_jmp32); > if (val != 0) > return -1; > > @@ -14376,6 +14398,7 @@ static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg > } > } > > + /* now deal with two scalars, but not necessarily constants */ > return is_scalar_branch_taken(reg1, reg2, opcode, is_jmp32); > } >