On Fri, 2023-10-27 at 11:13 -0700, Andrii Nakryiko wrote: > Combine 32-bit and 64-bit is_branch_taken logic for SCALAR_VALUE > registers. It makes it easier to see parallels between two domains > (32-bit and 64-bit), and makes subsequent refactoring more > straightforward. > > No functional changes. > > Signed-off-by: Andrii Nakryiko <andrii@xxxxxxxxxx> Acked-by: Eduard Zingerman <eddyz87@xxxxxxxxx> > --- > kernel/bpf/verifier.c | 154 ++++++++++-------------------------------- > 1 file changed, 36 insertions(+), 118 deletions(-) > > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c > index fedd6d0e76e5..b911d1111fad 100644 > --- a/kernel/bpf/verifier.c > +++ b/kernel/bpf/verifier.c > @@ -14185,166 +14185,86 @@ static u64 reg_const_value(struct bpf_reg_state *reg, bool subreg32) > /* > * <reg1> <op> <reg2>, currently assuming reg2 is a constant > */ > -static int is_branch32_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2, u8 opcode) > +static int is_scalar_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2, > + u8 opcode, bool is_jmp32) > { > - struct tnum subreg = tnum_subreg(reg1->var_off); > - u32 val = (u32)tnum_subreg(reg2->var_off).value; > - s32 sval = (s32)val; > + struct tnum t1 = is_jmp32 ? tnum_subreg(reg1->var_off) : reg1->var_off; > + u64 umin1 = is_jmp32 ? (u64)reg1->u32_min_value : reg1->umin_value; > + u64 umax1 = is_jmp32 ? (u64)reg1->u32_max_value : reg1->umax_value; > + s64 smin1 = is_jmp32 ? (s64)reg1->s32_min_value : reg1->smin_value; > + s64 smax1 = is_jmp32 ? (s64)reg1->s32_max_value : reg1->smax_value; > + u64 val = is_jmp32 ? (u32)tnum_subreg(reg2->var_off).value : reg2->var_off.value; > + s64 sval = is_jmp32 ? (s32)val : (s64)val; > > switch (opcode) { > case BPF_JEQ: > - if (tnum_is_const(subreg)) > - return !!tnum_equals_const(subreg, val); > - else if (val < reg1->u32_min_value || val > reg1->u32_max_value) > + if (tnum_is_const(t1)) > + return !!tnum_equals_const(t1, val); > + else if (val < umin1 || val > umax1) > return 0; > - else if (sval < reg1->s32_min_value || sval > reg1->s32_max_value) > + else if (sval < smin1 || sval > smax1) > return 0; > break; > case BPF_JNE: > - if (tnum_is_const(subreg)) > - return !tnum_equals_const(subreg, val); > - else if (val < reg1->u32_min_value || val > reg1->u32_max_value) > + if (tnum_is_const(t1)) > + return !tnum_equals_const(t1, val); > + else if (val < umin1 || val > umax1) > return 1; > - else if (sval < reg1->s32_min_value || sval > reg1->s32_max_value) > + else if (sval < smin1 || sval > smax1) > return 1; > break; > case BPF_JSET: > - if ((~subreg.mask & subreg.value) & val) > + if ((~t1.mask & t1.value) & val) > return 1; > - if (!((subreg.mask | subreg.value) & val)) > + if (!((t1.mask | t1.value) & val)) > return 0; > break; > case BPF_JGT: > - if (reg1->u32_min_value > val) > + if (umin1 > val ) > return 1; > - else if (reg1->u32_max_value <= val) > + else if (umax1 <= val) > return 0; > break; > case BPF_JSGT: > - if (reg1->s32_min_value > sval) > + if (smin1 > sval) > return 1; > - else if (reg1->s32_max_value <= sval) > + else if (smax1 <= sval) > return 0; > break; > case BPF_JLT: > - if (reg1->u32_max_value < val) > + if (umax1 < val) > return 1; > - else if (reg1->u32_min_value >= val) > + else if (umin1 >= val) > return 0; > break; > case BPF_JSLT: > - if (reg1->s32_max_value < sval) > + if (smax1 < sval) > return 1; > - else if (reg1->s32_min_value >= sval) > + else if (smin1 >= sval) > return 0; > break; > case BPF_JGE: > - if (reg1->u32_min_value >= val) > + if (umin1 >= val) > return 1; > - else if (reg1->u32_max_value < val) > + else if (umax1 < val) > return 0; > break; > case BPF_JSGE: > - if (reg1->s32_min_value >= sval) > + if (smin1 >= sval) > return 1; > - else if (reg1->s32_max_value < sval) > + else if (smax1 < sval) > return 0; > break; > case BPF_JLE: > - if (reg1->u32_max_value <= val) > + if (umax1 <= val) > return 1; > - else if (reg1->u32_min_value > val) > + else if (umin1 > val) > return 0; > break; > case BPF_JSLE: > - if (reg1->s32_max_value <= sval) > + if (smax1 <= sval) > return 1; > - else if (reg1->s32_min_value > sval) > - return 0; > - break; > - } > - > - return -1; > -} > - > - > -/* > - * <reg1> <op> <reg2>, currently assuming reg2 is a constant > - */ > -static int is_branch64_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2, u8 opcode) > -{ > - u64 val = reg2->var_off.value; > - s64 sval = (s64)val; > - > - switch (opcode) { > - case BPF_JEQ: > - if (tnum_is_const(reg1->var_off)) > - return !!tnum_equals_const(reg1->var_off, val); > - else if (val < reg1->umin_value || val > reg1->umax_value) > - return 0; > - else if (sval < reg1->smin_value || sval > reg1->smax_value) > - return 0; > - break; > - case BPF_JNE: > - if (tnum_is_const(reg1->var_off)) > - return !tnum_equals_const(reg1->var_off, val); > - else if (val < reg1->umin_value || val > reg1->umax_value) > - return 1; > - else if (sval < reg1->smin_value || sval > reg1->smax_value) > - return 1; > - break; > - case BPF_JSET: > - if ((~reg1->var_off.mask & reg1->var_off.value) & val) > - return 1; > - if (!((reg1->var_off.mask | reg1->var_off.value) & val)) > - return 0; > - break; > - case BPF_JGT: > - if (reg1->umin_value > val) > - return 1; > - else if (reg1->umax_value <= val) > - return 0; > - break; > - case BPF_JSGT: > - if (reg1->smin_value > sval) > - return 1; > - else if (reg1->smax_value <= sval) > - return 0; > - break; > - case BPF_JLT: > - if (reg1->umax_value < val) > - return 1; > - else if (reg1->umin_value >= val) > - return 0; > - break; > - case BPF_JSLT: > - if (reg1->smax_value < sval) > - return 1; > - else if (reg1->smin_value >= sval) > - return 0; > - break; > - case BPF_JGE: > - if (reg1->umin_value >= val) > - return 1; > - else if (reg1->umax_value < val) > - return 0; > - break; > - case BPF_JSGE: > - if (reg1->smin_value >= sval) > - return 1; > - else if (reg1->smax_value < sval) > - return 0; > - break; > - case BPF_JLE: > - if (reg1->umax_value <= val) > - return 1; > - else if (reg1->umin_value > val) > - return 0; > - break; > - case BPF_JSLE: > - if (reg1->smax_value <= sval) > - return 1; > - else if (reg1->smin_value > sval) > + else if (smin1 > sval) > return 0; > break; > } > @@ -14458,9 +14378,7 @@ static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg > } > } > > - if (is_jmp32) > - return is_branch32_taken(reg1, reg2, opcode); > - return is_branch64_taken(reg1, reg2, opcode); > + return is_scalar_branch_taken(reg1, reg2, opcode, is_jmp32); > } > > /* Adjusts the register min/max values in the case that the dst_reg is the