Re: [PATCH bpf-next 02/13] bpf: generalize is_scalar_branch_taken() logic

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



On Thu, 2023-11-02 at 17:08 -0700, Andrii Nakryiko wrote:
> Generalize is_branch_taken logic for SCALAR_VALUE register to handle
> cases when both registers are not constants. Previously supported
> <range> vs <scalar> cases are a natural subset of more generic <range>
> vs <range> set of cases.
> 
> Generalized logic relies on straightforward segment intersection checks.
> 
> Signed-off-by: Andrii Nakryiko <andrii@xxxxxxxxxx>

Acked-by: Eduard Zingerman <eddyz87@xxxxxxxxx>

(With the same nitpick as before: the '<' cases could be converted to '>' cases.)

> ---
>  kernel/bpf/verifier.c | 103 ++++++++++++++++++++++++++----------------
>  1 file changed, 63 insertions(+), 40 deletions(-)
> 
> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> index 52934080042c..2627461164ed 100644
> --- a/kernel/bpf/verifier.c
> +++ b/kernel/bpf/verifier.c
> @@ -14187,82 +14187,104 @@ static int is_scalar_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_sta
>  				  u8 opcode, bool is_jmp32)
>  {
>  	struct tnum t1 = is_jmp32 ? tnum_subreg(reg1->var_off) : reg1->var_off;
> +	struct tnum t2 = is_jmp32 ? tnum_subreg(reg2->var_off) : reg2->var_off;
>  	u64 umin1 = is_jmp32 ? (u64)reg1->u32_min_value : reg1->umin_value;
>  	u64 umax1 = is_jmp32 ? (u64)reg1->u32_max_value : reg1->umax_value;
>  	s64 smin1 = is_jmp32 ? (s64)reg1->s32_min_value : reg1->smin_value;
>  	s64 smax1 = is_jmp32 ? (s64)reg1->s32_max_value : reg1->smax_value;
> -	u64 uval = is_jmp32 ? (u32)tnum_subreg(reg2->var_off).value : reg2->var_off.value;
> -	s64 sval = is_jmp32 ? (s32)uval : (s64)uval;
> +	u64 umin2 = is_jmp32 ? (u64)reg2->u32_min_value : reg2->umin_value;
> +	u64 umax2 = is_jmp32 ? (u64)reg2->u32_max_value : reg2->umax_value;
> +	s64 smin2 = is_jmp32 ? (s64)reg2->s32_min_value : reg2->smin_value;
> +	s64 smax2 = is_jmp32 ? (s64)reg2->s32_max_value : reg2->smax_value;
>  
>  	switch (opcode) {
>  	case BPF_JEQ:
> -		if (tnum_is_const(t1))
> -			return !!tnum_equals_const(t1, uval);
> -		else if (uval < umin1 || uval > umax1)
> +		/* constants, umin/umax and smin/smax checks would be
> +		 * redundant in this case because they all should match
> +		 */
> +		if (tnum_is_const(t1) && tnum_is_const(t2))
> +			return t1.value == t2.value;
> +		/* const ranges */
> +		if (umin1 == umax1 && umin2 == umax2)
> +			return umin1 == umin2;
> +		if (smin1 == smax1 && smin2 == smax2)
> +			return smin1 == smin2;
> +		/* non-overlapping ranges */
> +		if (umin1 > umax2 || umax1 < umin2)
>  			return 0;
> -		else if (sval < smin1 || sval > smax1)
> +		if (smin1 > smax2 || smax1 < smin2)
>  			return 0;
>  		break;
>  	case BPF_JNE:
> -		if (tnum_is_const(t1))
> -			return !tnum_equals_const(t1, uval);
> -		else if (uval < umin1 || uval > umax1)
> +		/* constants, umin/umax and smin/smax checks would be
> +		 * redundant in this case because they all should match
> +		 */
> +		if (tnum_is_const(t1) && tnum_is_const(t2))
> +			return t1.value != t2.value;
> +		/* non-overlapping ranges */
> +		if (umin1 > umax2 || umax1 < umin2)
>  			return 1;
> -		else if (sval < smin1 || sval > smax1)
> +		if (smin1 > smax2 || smax1 < smin2)
>  			return 1;
>  		break;
>  	case BPF_JSET:
> -		if ((~t1.mask & t1.value) & uval)
> +		if (!is_reg_const(reg2, is_jmp32)) {
> +			swap(reg1, reg2);
> +			swap(t1, t2);
> +		}
> +		if (!is_reg_const(reg2, is_jmp32))
> +			return -1;
> +		if ((~t1.mask & t1.value) & t2.value)
>  			return 1;
> -		if (!((t1.mask | t1.value) & uval))
> +		if (!((t1.mask | t1.value) & t2.value))
>  			return 0;
>  		break;
>  	case BPF_JGT:
> -		if (umin1 > uval )
> +		if (umin1 > umax2)
>  			return 1;
> -		else if (umax1 <= uval)
> +		else if (umax1 <= umin2)
>  			return 0;
>  		break;
>  	case BPF_JSGT:
> -		if (smin1 > sval)
> +		if (smin1 > smax2)
>  			return 1;
> -		else if (smax1 <= sval)
> +		else if (smax1 <= smin2)
>  			return 0;
>  		break;
>  	case BPF_JLT:
> -		if (umax1 < uval)
> +		if (umax1 < umin2)
>  			return 1;
> -		else if (umin1 >= uval)
> +		else if (umin1 >= umax2)
>  			return 0;
>  		break;
>  	case BPF_JSLT:
> -		if (smax1 < sval)
> +		if (smax1 < smin2)
>  			return 1;
> -		else if (smin1 >= sval)
> +		else if (smin1 >= smax2)
>  			return 0;
>  		break;
>  	case BPF_JGE:
> -		if (umin1 >= uval)
> +		if (umin1 >= umax2)
>  			return 1;
> -		else if (umax1 < uval)
> +		else if (umax1 < umin2)
>  			return 0;
>  		break;
>  	case BPF_JSGE:
> -		if (smin1 >= sval)
> +		if (smin1 >= smax2)
>  			return 1;
> -		else if (smax1 < sval)
> +		else if (smax1 < smin2)
>  			return 0;
>  		break;
>  	case BPF_JLE:
> -		if (umax1 <= uval)
> +		if (umax1 <= umin2)
>  			return 1;
> -		else if (umin1 > uval)
> +		else if (umin1 > umax2)
>  			return 0;
>  		break;
>  	case BPF_JSLE:
> -		if (smax1 <= sval)
> +		if (smax1 <= smin2)
>  			return 1;
> -		else if (smin1 > sval)
> +		else if (smin1 > smax2)
>  			return 0;
>  		break;
>  	}
> @@ -14341,28 +14363,28 @@ static int is_pkt_ptr_branch_taken(struct bpf_reg_state *dst_reg,
>  static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2,
>  			   u8 opcode, bool is_jmp32)
>  {
> -	u64 val;
> -
>  	if (reg_is_pkt_pointer_any(reg1) && reg_is_pkt_pointer_any(reg2) && !is_jmp32)
>  		return is_pkt_ptr_branch_taken(reg1, reg2, opcode);
>  
> -	/* try to make sure reg2 is a constant SCALAR_VALUE */
> -	if (!is_reg_const(reg2, is_jmp32)) {
> -		opcode = flip_opcode(opcode);
> -		swap(reg1, reg2);
> -	}
> -	/* for now we expect reg2 to be a constant to make any useful decisions */
> -	if (!is_reg_const(reg2, is_jmp32))
> -		return -1;
> -	val = reg_const_value(reg2, is_jmp32);
> +	if (__is_pointer_value(false, reg1) || __is_pointer_value(false, reg2)) {
> +		u64 val;
> +
> +		/* arrange that reg2 is a scalar, and reg1 is a pointer */
> +		if (!is_reg_const(reg2, is_jmp32)) {
> +			opcode = flip_opcode(opcode);
> +			swap(reg1, reg2);
> +		}
> +		/* and ensure that reg2 is a constant */
> +		if (!is_reg_const(reg2, is_jmp32))
> +			return -1;
>  
> -	if (__is_pointer_value(false, reg1)) {
>  		if (!reg_not_null(reg1))
>  			return -1;
>  
>  		/* If pointer is valid tests against zero will fail so we can
>  		 * use this to direct branch taken.
>  		 */
> +		val = reg_const_value(reg2, is_jmp32);
>  		if (val != 0)
>  			return -1;
>  
> @@ -14376,6 +14398,7 @@ static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg
>  		}
>  	}
>  
> +	/* now deal with two scalars, but not necessarily constants */
>  	return is_scalar_branch_taken(reg1, reg2, opcode, is_jmp32);
>  }
>  






[Index of Archives]     [Linux Samsung SoC]     [Linux Rockchip SoC]     [Linux Actions SoC]     [Linux for Synopsys ARC Processors]     [Linux NFS]     [Linux NILFS]     [Linux USB Devel]     [Video for Linux]     [Linux Audio Users]     [Yosemite News]     [Linux Kernel]     [Linux SCSI]


  Powered by Linux