Combine 32-bit and 64-bit is_branch_taken logic for SCALAR_VALUE registers. This makes it easier to see the parallels between the two domains (32-bit and 64-bit) and makes subsequent refactoring more straightforward. No functional changes. Acked-by: Eduard Zingerman <eddyz87@xxxxxxxxx> Signed-off-by: Andrii Nakryiko <andrii@xxxxxxxxxx> --- kernel/bpf/verifier.c | 200 +++++++++++++----------------------------- 1 file changed, 59 insertions(+), 141 deletions(-) diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c index d5213cef5389..b077dd99b159 100644 --- a/kernel/bpf/verifier.c +++ b/kernel/bpf/verifier.c @@ -14183,166 +14183,86 @@ static u64 reg_const_value(struct bpf_reg_state *reg, bool subreg32) /* * <reg1> <op> <reg2>, currently assuming reg2 is a constant */ -static int is_branch32_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2, u8 opcode) +static int is_scalar_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2, + u8 opcode, bool is_jmp32) { - struct tnum subreg = tnum_subreg(reg1->var_off); - u32 val = (u32)tnum_subreg(reg2->var_off).value; - s32 sval = (s32)val; + struct tnum t1 = is_jmp32 ? tnum_subreg(reg1->var_off) : reg1->var_off; + u64 umin1 = is_jmp32 ? (u64)reg1->u32_min_value : reg1->umin_value; + u64 umax1 = is_jmp32 ? (u64)reg1->u32_max_value : reg1->umax_value; + s64 smin1 = is_jmp32 ? (s64)reg1->s32_min_value : reg1->smin_value; + s64 smax1 = is_jmp32 ? (s64)reg1->s32_max_value : reg1->smax_value; + u64 uval = is_jmp32 ? (u32)tnum_subreg(reg2->var_off).value : reg2->var_off.value; + s64 sval = is_jmp32 ? 
(s32)uval : (s64)uval; switch (opcode) { case BPF_JEQ: - if (tnum_is_const(subreg)) - return !!tnum_equals_const(subreg, val); - else if (val < reg1->u32_min_value || val > reg1->u32_max_value) + if (tnum_is_const(t1)) + return !!tnum_equals_const(t1, uval); + else if (uval < umin1 || uval > umax1) return 0; - else if (sval < reg1->s32_min_value || sval > reg1->s32_max_value) + else if (sval < smin1 || sval > smax1) return 0; break; case BPF_JNE: - if (tnum_is_const(subreg)) - return !tnum_equals_const(subreg, val); - else if (val < reg1->u32_min_value || val > reg1->u32_max_value) + if (tnum_is_const(t1)) + return !tnum_equals_const(t1, uval); + else if (uval < umin1 || uval > umax1) return 1; - else if (sval < reg1->s32_min_value || sval > reg1->s32_max_value) + else if (sval < smin1 || sval > smax1) return 1; break; case BPF_JSET: - if ((~subreg.mask & subreg.value) & val) + if ((~t1.mask & t1.value) & uval) return 1; - if (!((subreg.mask | subreg.value) & val)) + if (!((t1.mask | t1.value) & uval)) return 0; break; case BPF_JGT: - if (reg1->u32_min_value > val) + if (umin1 > uval) return 1; - else if (reg1->u32_max_value <= val) + else if (umax1 <= uval) return 0; break; case BPF_JSGT: - if (reg1->s32_min_value > sval) + if (smin1 > sval) return 1; - else if (reg1->s32_max_value <= sval) + else if (smax1 <= sval) return 0; break; case BPF_JLT: - if (reg1->u32_max_value < val) + if (umax1 < uval) return 1; - else if (reg1->u32_min_value >= val) + else if (umin1 >= uval) return 0; break; case BPF_JSLT: - if (reg1->s32_max_value < sval) + if (smax1 < sval) return 1; - else if (reg1->s32_min_value >= sval) + else if (smin1 >= sval) return 0; break; case BPF_JGE: - if (reg1->u32_min_value >= val) + if (umin1 >= uval) return 1; - else if (reg1->u32_max_value < val) + else if (umax1 < uval) return 0; break; case BPF_JSGE: - if (reg1->s32_min_value >= sval) + if (smin1 >= sval) return 1; - else if (reg1->s32_max_value < sval) + else if (smax1 < sval) return 0; break; 
case BPF_JLE: - if (reg1->u32_max_value <= val) + if (umax1 <= uval) return 1; - else if (reg1->u32_min_value > val) + else if (umin1 > uval) return 0; break; case BPF_JSLE: - if (reg1->s32_max_value <= sval) + if (smax1 <= sval) return 1; - else if (reg1->s32_min_value > sval) - return 0; - break; - } - - return -1; -} - - -/* - * <reg1> <op> <reg2>, currently assuming reg2 is a constant - */ -static int is_branch64_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2, u8 opcode) -{ - u64 val = reg2->var_off.value; - s64 sval = (s64)val; - - switch (opcode) { - case BPF_JEQ: - if (tnum_is_const(reg1->var_off)) - return !!tnum_equals_const(reg1->var_off, val); - else if (val < reg1->umin_value || val > reg1->umax_value) - return 0; - else if (sval < reg1->smin_value || sval > reg1->smax_value) - return 0; - break; - case BPF_JNE: - if (tnum_is_const(reg1->var_off)) - return !tnum_equals_const(reg1->var_off, val); - else if (val < reg1->umin_value || val > reg1->umax_value) - return 1; - else if (sval < reg1->smin_value || sval > reg1->smax_value) - return 1; - break; - case BPF_JSET: - if ((~reg1->var_off.mask & reg1->var_off.value) & val) - return 1; - if (!((reg1->var_off.mask | reg1->var_off.value) & val)) - return 0; - break; - case BPF_JGT: - if (reg1->umin_value > val) - return 1; - else if (reg1->umax_value <= val) - return 0; - break; - case BPF_JSGT: - if (reg1->smin_value > sval) - return 1; - else if (reg1->smax_value <= sval) - return 0; - break; - case BPF_JLT: - if (reg1->umax_value < val) - return 1; - else if (reg1->umin_value >= val) - return 0; - break; - case BPF_JSLT: - if (reg1->smax_value < sval) - return 1; - else if (reg1->smin_value >= sval) - return 0; - break; - case BPF_JGE: - if (reg1->umin_value >= val) - return 1; - else if (reg1->umax_value < val) - return 0; - break; - case BPF_JSGE: - if (reg1->smin_value >= sval) - return 1; - else if (reg1->smax_value < sval) - return 0; - break; - case BPF_JLE: - if (reg1->umax_value <= 
val) - return 1; - else if (reg1->umin_value > val) - return 0; - break; - case BPF_JSLE: - if (reg1->smax_value <= sval) - return 1; - else if (reg1->smin_value > sval) + else if (smin1 > sval) return 0; break; } @@ -14456,9 +14376,7 @@ static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg } } - if (is_jmp32) - return is_branch32_taken(reg1, reg2, opcode); - return is_branch64_taken(reg1, reg2, opcode); + return is_scalar_branch_taken(reg1, reg2, opcode, is_jmp32); } /* Adjusts the register min/max values in the case that the dst_reg is the @@ -14468,15 +14386,15 @@ static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg */ static void reg_set_min_max(struct bpf_reg_state *true_reg, struct bpf_reg_state *false_reg, - u64 val, u32 val32, + u64 uval, u32 uval32, u8 opcode, bool is_jmp32) { struct tnum false_32off = tnum_subreg(false_reg->var_off); struct tnum false_64off = false_reg->var_off; struct tnum true_32off = tnum_subreg(true_reg->var_off); struct tnum true_64off = true_reg->var_off; - s64 sval = (s64)val; - s32 sval32 = (s32)val32; + s64 sval = (s64)uval; + s32 sval32 = (s32)uval32; /* If the dst_reg is a pointer, we can't learn anything about its * variable offset from the compare (unless src_reg were a pointer into @@ -14499,49 +14417,49 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg, */ case BPF_JEQ: if (is_jmp32) { - __mark_reg32_known(true_reg, val32); + __mark_reg32_known(true_reg, uval32); true_32off = tnum_subreg(true_reg->var_off); } else { - ___mark_reg_known(true_reg, val); + ___mark_reg_known(true_reg, uval); true_64off = true_reg->var_off; } break; case BPF_JNE: if (is_jmp32) { - __mark_reg32_known(false_reg, val32); + __mark_reg32_known(false_reg, uval32); false_32off = tnum_subreg(false_reg->var_off); } else { - ___mark_reg_known(false_reg, val); + ___mark_reg_known(false_reg, uval); false_64off = false_reg->var_off; } break; case BPF_JSET: if (is_jmp32) { - false_32off = 
tnum_and(false_32off, tnum_const(~val32)); - if (is_power_of_2(val32)) + false_32off = tnum_and(false_32off, tnum_const(~uval32)); + if (is_power_of_2(uval32)) true_32off = tnum_or(true_32off, - tnum_const(val32)); + tnum_const(uval32)); } else { - false_64off = tnum_and(false_64off, tnum_const(~val)); - if (is_power_of_2(val)) + false_64off = tnum_and(false_64off, tnum_const(~uval)); + if (is_power_of_2(uval)) true_64off = tnum_or(true_64off, - tnum_const(val)); + tnum_const(uval)); } break; case BPF_JGE: case BPF_JGT: { if (is_jmp32) { - u32 false_umax = opcode == BPF_JGT ? val32 : val32 - 1; - u32 true_umin = opcode == BPF_JGT ? val32 + 1 : val32; + u32 false_umax = opcode == BPF_JGT ? uval32 : uval32 - 1; + u32 true_umin = opcode == BPF_JGT ? uval32 + 1 : uval32; false_reg->u32_max_value = min(false_reg->u32_max_value, false_umax); true_reg->u32_min_value = max(true_reg->u32_min_value, true_umin); } else { - u64 false_umax = opcode == BPF_JGT ? val : val - 1; - u64 true_umin = opcode == BPF_JGT ? val + 1 : val; + u64 false_umax = opcode == BPF_JGT ? uval : uval - 1; + u64 true_umin = opcode == BPF_JGT ? uval + 1 : uval; false_reg->umax_value = min(false_reg->umax_value, false_umax); true_reg->umin_value = max(true_reg->umin_value, true_umin); @@ -14570,16 +14488,16 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg, case BPF_JLT: { if (is_jmp32) { - u32 false_umin = opcode == BPF_JLT ? val32 : val32 + 1; - u32 true_umax = opcode == BPF_JLT ? val32 - 1 : val32; + u32 false_umin = opcode == BPF_JLT ? uval32 : uval32 + 1; + u32 true_umax = opcode == BPF_JLT ? uval32 - 1 : uval32; false_reg->u32_min_value = max(false_reg->u32_min_value, false_umin); true_reg->u32_max_value = min(true_reg->u32_max_value, true_umax); } else { - u64 false_umin = opcode == BPF_JLT ? val : val + 1; - u64 true_umax = opcode == BPF_JLT ? val - 1 : val; + u64 false_umin = opcode == BPF_JLT ? uval : uval + 1; + u64 true_umax = opcode == BPF_JLT ? 
uval - 1 : uval; false_reg->umin_value = max(false_reg->umin_value, false_umin); true_reg->umax_value = min(true_reg->umax_value, true_umax); @@ -14628,7 +14546,7 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg, */ static void reg_set_min_max_inv(struct bpf_reg_state *true_reg, struct bpf_reg_state *false_reg, - u64 val, u32 val32, + u64 uval, u32 uval32, u8 opcode, bool is_jmp32) { opcode = flip_opcode(opcode); @@ -14636,7 +14554,7 @@ static void reg_set_min_max_inv(struct bpf_reg_state *true_reg, * BPF_JA, can't get here. */ if (opcode) - reg_set_min_max(true_reg, false_reg, val, val32, opcode, is_jmp32); + reg_set_min_max(true_reg, false_reg, uval, uval32, opcode, is_jmp32); } /* Regs are known to be equal, so intersect their min/max/var_off */ -- 2.34.1