Re: [PATCH v5 bpf-next 15/23] bpf: unify 32-bit and 64-bit is_branch_taken logic

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



On Fri, 2023-10-27 at 11:13 -0700, Andrii Nakryiko wrote:
> Combine 32-bit and 64-bit is_branch_taken logic for SCALAR_VALUE
> registers. It makes it easier to see parallels between two domains
> (32-bit and 64-bit), and makes subsequent refactoring more
> straightforward.
> 
> No functional changes.
> 
> Signed-off-by: Andrii Nakryiko <andrii@xxxxxxxxxx>

Acked-by: Eduard Zingerman <eddyz87@xxxxxxxxx>

> ---
>  kernel/bpf/verifier.c | 154 ++++++++++--------------------------------
>  1 file changed, 36 insertions(+), 118 deletions(-)
> 
> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> index fedd6d0e76e5..b911d1111fad 100644
> --- a/kernel/bpf/verifier.c
> +++ b/kernel/bpf/verifier.c
> @@ -14185,166 +14185,86 @@ static u64 reg_const_value(struct bpf_reg_state *reg, bool subreg32)
>  /*
>   * <reg1> <op> <reg2>, currently assuming reg2 is a constant
>   */
> -static int is_branch32_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2, u8 opcode)
> +static int is_scalar_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2,
> +				  u8 opcode, bool is_jmp32)
>  {
> -	struct tnum subreg = tnum_subreg(reg1->var_off);
> -	u32 val = (u32)tnum_subreg(reg2->var_off).value;
> -	s32 sval = (s32)val;
> +	struct tnum t1 = is_jmp32 ? tnum_subreg(reg1->var_off) : reg1->var_off;
> +	u64 umin1 = is_jmp32 ? (u64)reg1->u32_min_value : reg1->umin_value;
> +	u64 umax1 = is_jmp32 ? (u64)reg1->u32_max_value : reg1->umax_value;
> +	s64 smin1 = is_jmp32 ? (s64)reg1->s32_min_value : reg1->smin_value;
> +	s64 smax1 = is_jmp32 ? (s64)reg1->s32_max_value : reg1->smax_value;
> +	u64 val = is_jmp32 ? (u32)tnum_subreg(reg2->var_off).value : reg2->var_off.value;
> +	s64 sval = is_jmp32 ? (s32)val : (s64)val;
>  
>  	switch (opcode) {
>  	case BPF_JEQ:
> -		if (tnum_is_const(subreg))
> -			return !!tnum_equals_const(subreg, val);
> -		else if (val < reg1->u32_min_value || val > reg1->u32_max_value)
> +		if (tnum_is_const(t1))
> +			return !!tnum_equals_const(t1, val);
> +		else if (val < umin1 || val > umax1)
>  			return 0;
> -		else if (sval < reg1->s32_min_value || sval > reg1->s32_max_value)
> +		else if (sval < smin1 || sval > smax1)
>  			return 0;
>  		break;
>  	case BPF_JNE:
> -		if (tnum_is_const(subreg))
> -			return !tnum_equals_const(subreg, val);
> -		else if (val < reg1->u32_min_value || val > reg1->u32_max_value)
> +		if (tnum_is_const(t1))
> +			return !tnum_equals_const(t1, val);
> +		else if (val < umin1 || val > umax1)
>  			return 1;
> -		else if (sval < reg1->s32_min_value || sval > reg1->s32_max_value)
> +		else if (sval < smin1 || sval > smax1)
>  			return 1;
>  		break;
>  	case BPF_JSET:
> -		if ((~subreg.mask & subreg.value) & val)
> +		if ((~t1.mask & t1.value) & val)
>  			return 1;
> -		if (!((subreg.mask | subreg.value) & val))
> +		if (!((t1.mask | t1.value) & val))
>  			return 0;
>  		break;
>  	case BPF_JGT:
> -		if (reg1->u32_min_value > val)
> +		if (umin1 > val )
>  			return 1;
> -		else if (reg1->u32_max_value <= val)
> +		else if (umax1 <= val)
>  			return 0;
>  		break;
>  	case BPF_JSGT:
> -		if (reg1->s32_min_value > sval)
> +		if (smin1 > sval)
>  			return 1;
> -		else if (reg1->s32_max_value <= sval)
> +		else if (smax1 <= sval)
>  			return 0;
>  		break;
>  	case BPF_JLT:
> -		if (reg1->u32_max_value < val)
> +		if (umax1 < val)
>  			return 1;
> -		else if (reg1->u32_min_value >= val)
> +		else if (umin1 >= val)
>  			return 0;
>  		break;
>  	case BPF_JSLT:
> -		if (reg1->s32_max_value < sval)
> +		if (smax1 < sval)
>  			return 1;
> -		else if (reg1->s32_min_value >= sval)
> +		else if (smin1 >= sval)
>  			return 0;
>  		break;
>  	case BPF_JGE:
> -		if (reg1->u32_min_value >= val)
> +		if (umin1 >= val)
>  			return 1;
> -		else if (reg1->u32_max_value < val)
> +		else if (umax1 < val)
>  			return 0;
>  		break;
>  	case BPF_JSGE:
> -		if (reg1->s32_min_value >= sval)
> +		if (smin1 >= sval)
>  			return 1;
> -		else if (reg1->s32_max_value < sval)
> +		else if (smax1 < sval)
>  			return 0;
>  		break;
>  	case BPF_JLE:
> -		if (reg1->u32_max_value <= val)
> +		if (umax1 <= val)
>  			return 1;
> -		else if (reg1->u32_min_value > val)
> +		else if (umin1 > val)
>  			return 0;
>  		break;
>  	case BPF_JSLE:
> -		if (reg1->s32_max_value <= sval)
> +		if (smax1 <= sval)
>  			return 1;
> -		else if (reg1->s32_min_value > sval)
> -			return 0;
> -		break;
> -	}
> -
> -	return -1;
> -}
> -
> -
> -/*
> - * <reg1> <op> <reg2>, currently assuming reg2 is a constant
> - */
> -static int is_branch64_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2, u8 opcode)
> -{
> -	u64 val = reg2->var_off.value;
> -	s64 sval = (s64)val;
> -
> -	switch (opcode) {
> -	case BPF_JEQ:
> -		if (tnum_is_const(reg1->var_off))
> -			return !!tnum_equals_const(reg1->var_off, val);
> -		else if (val < reg1->umin_value || val > reg1->umax_value)
> -			return 0;
> -		else if (sval < reg1->smin_value || sval > reg1->smax_value)
> -			return 0;
> -		break;
> -	case BPF_JNE:
> -		if (tnum_is_const(reg1->var_off))
> -			return !tnum_equals_const(reg1->var_off, val);
> -		else if (val < reg1->umin_value || val > reg1->umax_value)
> -			return 1;
> -		else if (sval < reg1->smin_value || sval > reg1->smax_value)
> -			return 1;
> -		break;
> -	case BPF_JSET:
> -		if ((~reg1->var_off.mask & reg1->var_off.value) & val)
> -			return 1;
> -		if (!((reg1->var_off.mask | reg1->var_off.value) & val))
> -			return 0;
> -		break;
> -	case BPF_JGT:
> -		if (reg1->umin_value > val)
> -			return 1;
> -		else if (reg1->umax_value <= val)
> -			return 0;
> -		break;
> -	case BPF_JSGT:
> -		if (reg1->smin_value > sval)
> -			return 1;
> -		else if (reg1->smax_value <= sval)
> -			return 0;
> -		break;
> -	case BPF_JLT:
> -		if (reg1->umax_value < val)
> -			return 1;
> -		else if (reg1->umin_value >= val)
> -			return 0;
> -		break;
> -	case BPF_JSLT:
> -		if (reg1->smax_value < sval)
> -			return 1;
> -		else if (reg1->smin_value >= sval)
> -			return 0;
> -		break;
> -	case BPF_JGE:
> -		if (reg1->umin_value >= val)
> -			return 1;
> -		else if (reg1->umax_value < val)
> -			return 0;
> -		break;
> -	case BPF_JSGE:
> -		if (reg1->smin_value >= sval)
> -			return 1;
> -		else if (reg1->smax_value < sval)
> -			return 0;
> -		break;
> -	case BPF_JLE:
> -		if (reg1->umax_value <= val)
> -			return 1;
> -		else if (reg1->umin_value > val)
> -			return 0;
> -		break;
> -	case BPF_JSLE:
> -		if (reg1->smax_value <= sval)
> -			return 1;
> -		else if (reg1->smin_value > sval)
> +		else if (smin1 > sval)
>  			return 0;
>  		break;
>  	}
> @@ -14458,9 +14378,7 @@ static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg
>  		}
>  	}
>  
> -	if (is_jmp32)
> -		return is_branch32_taken(reg1, reg2, opcode);
> -	return is_branch64_taken(reg1, reg2, opcode);
> +	return is_scalar_branch_taken(reg1, reg2, opcode, is_jmp32);
>  }
>  
>  /* Adjusts the register min/max values in the case that the dst_reg is the






[Index of Archives]     [Linux Samsung SoC]     [Linux Rockchip SoC]     [Linux Actions SoC]     [Linux for Synopsys ARC Processors]     [Linux NFS]     [Linux NILFS]     [Linux USB Devel]     [Video for Linux]     [Linux Audio Users]     [Yosemite News]     [Linux Kernel]     [Linux SCSI]


  Powered by Linux