Re: [PATCH bpf-next 01/13] bpf: generalize reg_set_min_max() to handle non-const register comparisons

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



On Thu, 2023-11-02 at 17:08 -0700, Andrii Nakryiko wrote:
> Generalize bounds adjustment logic of reg_set_min_max() to handle not
> just register vs constant case, but in general any register vs any
> register cases. For most of the operations it's trivial extension based
> on range vs range comparison logic, we just need to properly pick
> min/max of a range to compare against min/max of the other range.
> 
> For BPF_JSET we keep the original capabilities, just make sure JSET is
> integrated in the common framework. This is manifested in the
> internal-only BPF_JSET + BPF_X "opcode" to allow for simpler and more
> uniform rev_opcode() handling. See the code for details. This allows to
> reuse the same code exactly both for TRUE and FALSE branches without
> explicitly handling both conditions with custom code.
> 
> Note also that now we don't need a special handling of BPF_JEQ/BPF_JNE
> case when none of the registers are constants. This is now just a normal
> generic case handled by reg_set_min_max().
> 
> To make tnum handling cleaner, tnum_with_subreg() helper is added, as
> that's a common operator when dealing with 32-bit subregister bounds.
> This keeps the overall logic much less noisy when it comes to tnums.
> 
> Signed-off-by: Andrii Nakryiko <andrii@xxxxxxxxxx>

Acked-by: Eduard Zingerman <eddyz87@xxxxxxxxx>

(With one bit of bikeshedding below).

> ---
>  include/linux/tnum.h  |   4 +
>  kernel/bpf/tnum.c     |   7 +-
>  kernel/bpf/verifier.c | 327 ++++++++++++++++++++----------------------
>  3 files changed, 165 insertions(+), 173 deletions(-)
> 
> diff --git a/include/linux/tnum.h b/include/linux/tnum.h
> index 1c3948a1d6ad..3c13240077b8 100644
> --- a/include/linux/tnum.h
> +++ b/include/linux/tnum.h
> @@ -106,6 +106,10 @@ int tnum_sbin(char *str, size_t size, struct tnum a);
>  struct tnum tnum_subreg(struct tnum a);
>  /* Returns the tnum with the lower 32-bit subreg cleared */
>  struct tnum tnum_clear_subreg(struct tnum a);
> +/* Returns the tnum with the lower 32-bit subreg in *reg* set to the lower
> + * 32-bit subreg in *subreg*
> + */
> +struct tnum tnum_with_subreg(struct tnum reg, struct tnum subreg);
>  /* Returns the tnum with the lower 32-bit subreg set to value */
>  struct tnum tnum_const_subreg(struct tnum a, u32 value);
>  /* Returns true if 32-bit subreg @a is a known constant*/
> diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c
> index 3d7127f439a1..f4c91c9b27d7 100644
> --- a/kernel/bpf/tnum.c
> +++ b/kernel/bpf/tnum.c
> @@ -208,7 +208,12 @@ struct tnum tnum_clear_subreg(struct tnum a)
>  	return tnum_lshift(tnum_rshift(a, 32), 32);
>  }
>  
> +struct tnum tnum_with_subreg(struct tnum reg, struct tnum subreg)
> +{
> +	return tnum_or(tnum_clear_subreg(reg), tnum_subreg(subreg));
> +}
> +
>  struct tnum tnum_const_subreg(struct tnum a, u32 value)
>  {
> -	return tnum_or(tnum_clear_subreg(a), tnum_const(value));
> +	return tnum_with_subreg(a, tnum_const(value));
>  }
> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> index 2197385d91dc..52934080042c 100644
> --- a/kernel/bpf/verifier.c
> +++ b/kernel/bpf/verifier.c
> @@ -14379,218 +14379,211 @@ static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg
>  	return is_scalar_branch_taken(reg1, reg2, opcode, is_jmp32);
>  }
>  
> -/* Adjusts the register min/max values in the case that the dst_reg and
> - * src_reg are both SCALAR_VALUE registers (or we are simply doing a BPF_K
> - * check, in which case we havea fake SCALAR_VALUE representing insn->imm).
> - * Technically we can do similar adjustments for pointers to the same object,
> - * but we don't support that right now.
> +/* Opcode that corresponds to a *false* branch condition.
> + * E.g., if r1 < r2, then reverse (false) condition is r1 >= r2
>   */
> -static void reg_set_min_max(struct bpf_reg_state *true_reg1,
> -			    struct bpf_reg_state *true_reg2,
> -			    struct bpf_reg_state *false_reg1,
> -			    struct bpf_reg_state *false_reg2,
> -			    u8 opcode, bool is_jmp32)
> +static u8 rev_opcode(u8 opcode)
>  {
> -	struct tnum false_32off, false_64off;
> -	struct tnum true_32off, true_64off;
> -	u64 uval;
> -	u32 uval32;
> -	s64 sval;
> -	s32 sval32;
> -
> -	/* If either register is a pointer, we can't learn anything about its
> -	 * variable offset from the compare (unless they were a pointer into
> -	 * the same object, but we don't bother with that).
> +	switch (opcode) {
> +	case BPF_JEQ:		return BPF_JNE;
> +	case BPF_JNE:		return BPF_JEQ;
> +	/* JSET doesn't have its reverse opcode in BPF, so add
> +	 * BPF_X flag to denote the reverse of that operation
>  	 */
> -	if (false_reg1->type != SCALAR_VALUE || false_reg2->type != SCALAR_VALUE)
> -		return;
> -
> -	/* we expect right-hand registers (src ones) to be constants, for now */
> -	if (!is_reg_const(false_reg2, is_jmp32)) {
> -		opcode = flip_opcode(opcode);
> -		swap(true_reg1, true_reg2);
> -		swap(false_reg1, false_reg2);
> +	case BPF_JSET:		return BPF_JSET | BPF_X;
> +	case BPF_JSET | BPF_X:	return BPF_JSET;
> +	case BPF_JGE:		return BPF_JLT;
> +	case BPF_JGT:		return BPF_JLE;
> +	case BPF_JLE:		return BPF_JGT;
> +	case BPF_JLT:		return BPF_JGE;
> +	case BPF_JSGE:		return BPF_JSLT;
> +	case BPF_JSGT:		return BPF_JSLE;
> +	case BPF_JSLE:		return BPF_JSGT;
> +	case BPF_JSLT:		return BPF_JSGE;
> +	default:		return 0;
>  	}
> -	if (!is_reg_const(false_reg2, is_jmp32))
> -		return;
> +}
>  
> -	false_32off = tnum_subreg(false_reg1->var_off);
> -	false_64off = false_reg1->var_off;
> -	true_32off = tnum_subreg(true_reg1->var_off);
> -	true_64off = true_reg1->var_off;
> -	uval = false_reg2->var_off.value;
> -	uval32 = (u32)tnum_subreg(false_reg2->var_off).value;
> -	sval = (s64)uval;
> -	sval32 = (s32)uval32;
> +/* Refine range knowledge for <reg1> <op> <reg2> conditional operation. */
> +static void regs_refine_cond_op(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2,
> +				u8 opcode, bool is_jmp32)
> +{
> +	struct tnum t;
>  
>  	switch (opcode) {
> -	/* JEQ/JNE comparison doesn't change the register equivalence.
> -	 *
> -	 * r1 = r2;
> -	 * if (r1 == 42) goto label;
> -	 * ...
> -	 * label: // here both r1 and r2 are known to be 42.
> -	 *
> -	 * Hence when marking register as known preserve it's ID.
> -	 */
>  	case BPF_JEQ:
>  		if (is_jmp32) {
> -			__mark_reg32_known(true_reg1, uval32);
> -			true_32off = tnum_subreg(true_reg1->var_off);
> +			reg1->u32_min_value = max(reg1->u32_min_value, reg2->u32_min_value);
> +			reg1->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value);
> +			reg1->s32_min_value = max(reg1->s32_min_value, reg2->s32_min_value);
> +			reg1->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value);
> +			reg2->u32_min_value = reg1->u32_min_value;
> +			reg2->u32_max_value = reg1->u32_max_value;
> +			reg2->s32_min_value = reg1->s32_min_value;
> +			reg2->s32_max_value = reg1->s32_max_value;
> +
> +			t = tnum_intersect(tnum_subreg(reg1->var_off), tnum_subreg(reg2->var_off));
> +			reg1->var_off = tnum_with_subreg(reg1->var_off, t);
> +			reg2->var_off = tnum_with_subreg(reg2->var_off, t);
>  		} else {
> -			___mark_reg_known(true_reg1, uval);
> -			true_64off = true_reg1->var_off;
> +			reg1->umin_value = max(reg1->umin_value, reg2->umin_value);
> +			reg1->umax_value = min(reg1->umax_value, reg2->umax_value);
> +			reg1->smin_value = max(reg1->smin_value, reg2->smin_value);
> +			reg1->smax_value = min(reg1->smax_value, reg2->smax_value);
> +			reg2->umin_value = reg1->umin_value;
> +			reg2->umax_value = reg1->umax_value;
> +			reg2->smin_value = reg1->smin_value;
> +			reg2->smax_value = reg1->smax_value;
> +
> +			reg1->var_off = tnum_intersect(reg1->var_off, reg2->var_off);
> +			reg2->var_off = reg1->var_off;
>  		}
>  		break;
>  	case BPF_JNE:
> +		/* we don't derive any new information for inequality yet */
> +		break;
> +	case BPF_JSET:
> +	case BPF_JSET | BPF_X: { /* BPF_JSET and its reverse, see rev_opcode() */
> +		u64 val;
> +
> +		if (!is_reg_const(reg2, is_jmp32))
> +			swap(reg1, reg2);
> +		if (!is_reg_const(reg2, is_jmp32))
> +			break;
> +
> +		val = reg_const_value(reg2, is_jmp32);
> +		/* BPF_JSET (i.e., TRUE branch, *not* BPF_JSET | BPF_X)
> +		 * requires single bit to learn something useful. E.g., if we
> +		 * know that `r1 & 0x3` is true, then which bits (0, 1, or both)
> +		 * are actually set? We can learn something definite only if
> +		 * it's a single-bit value to begin with.
> +		 *
> +		 * BPF_JSET | BPF_X (i.e., negation of BPF_JSET) doesn't have
> +		 * this restriction. I.e., !(r1 & 0x3) means neither bit 0 nor
> +		 * bit 1 is set, which we can readily use in adjustments.
> +		 */
> +		if (!(opcode & BPF_X) && !is_power_of_2(val))
> +			break;
> +
>  		if (is_jmp32) {
> -			__mark_reg32_known(false_reg1, uval32);
> -			false_32off = tnum_subreg(false_reg1->var_off);
> +			if (opcode & BPF_X)
> +				t = tnum_and(tnum_subreg(reg1->var_off), tnum_const(~val));
> +			else
> +				t = tnum_or(tnum_subreg(reg1->var_off), tnum_const(val));
> +			reg1->var_off = tnum_with_subreg(reg1->var_off, t);
>  		} else {
> -			___mark_reg_known(false_reg1, uval);
> -			false_64off = false_reg1->var_off;
> +			if (opcode & BPF_X)
> +				reg1->var_off = tnum_and(reg1->var_off, tnum_const(~val));
> +			else
> +				reg1->var_off = tnum_or(reg1->var_off, tnum_const(val));
>  		}
>  		break;
> -	case BPF_JSET:
> +	}
> +	case BPF_JGE:
>  		if (is_jmp32) {
> -			false_32off = tnum_and(false_32off, tnum_const(~uval32));
> -			if (is_power_of_2(uval32))
> -				true_32off = tnum_or(true_32off,
> -						     tnum_const(uval32));
> +			reg1->u32_min_value = max(reg1->u32_min_value, reg2->u32_min_value);
> +			reg2->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value);
>  		} else {
> -			false_64off = tnum_and(false_64off, tnum_const(~uval));
> -			if (is_power_of_2(uval))
> -				true_64off = tnum_or(true_64off,
> -						     tnum_const(uval));
> +			reg1->umin_value = max(reg1->umin_value, reg2->umin_value);
> +			reg2->umax_value = min(reg1->umax_value, reg2->umax_value);
>  		}
>  		break;
> -	case BPF_JGE:
>  	case BPF_JGT:
> -	{
>  		if (is_jmp32) {
> -			u32 false_umax = opcode == BPF_JGT ? uval32  : uval32 - 1;
> -			u32 true_umin = opcode == BPF_JGT ? uval32 + 1 : uval32;
> -
> -			false_reg1->u32_max_value = min(false_reg1->u32_max_value,
> -						       false_umax);
> -			true_reg1->u32_min_value = max(true_reg1->u32_min_value,
> -						      true_umin);
> +			reg1->u32_min_value = max(reg1->u32_min_value, reg2->u32_min_value + 1);
> +			reg2->u32_max_value = min(reg1->u32_max_value - 1, reg2->u32_max_value);
>  		} else {
> -			u64 false_umax = opcode == BPF_JGT ? uval    : uval - 1;
> -			u64 true_umin = opcode == BPF_JGT ? uval + 1 : uval;
> -
> -			false_reg1->umax_value = min(false_reg1->umax_value, false_umax);
> -			true_reg1->umin_value = max(true_reg1->umin_value, true_umin);
> +			reg1->umin_value = max(reg1->umin_value, reg2->umin_value + 1);
> +			reg2->umax_value = min(reg1->umax_value - 1, reg2->umax_value);
>  		}
>  		break;
> -	}
>  	case BPF_JSGE:
> +		if (is_jmp32) {
> +			reg1->s32_min_value = max(reg1->s32_min_value, reg2->s32_min_value);
> +			reg2->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value);
> +		} else {
> +			reg1->smin_value = max(reg1->smin_value, reg2->smin_value);
> +			reg2->smax_value = min(reg1->smax_value, reg2->smax_value);
> +		}
> +		break;
>  	case BPF_JSGT:

It is possible to spare some code by swapping arguments here:

	case BPF_JLE:
	case BPF_JLT:
	case BPF_JSLE:
	case BPF_JSLT:
		return regs_refine_cond_op(reg2, reg1, flip_opcode(opcode), is_jmp32);


> -	{
>  		if (is_jmp32) {
> -			s32 false_smax = opcode == BPF_JSGT ? sval32    : sval32 - 1;
> -			s32 true_smin = opcode == BPF_JSGT ? sval32 + 1 : sval32;
> -
> -			false_reg1->s32_max_value = min(false_reg1->s32_max_value, false_smax);
> -			true_reg1->s32_min_value = max(true_reg1->s32_min_value, true_smin);
> +			reg1->s32_min_value = max(reg1->s32_min_value, reg2->s32_min_value + 1);
> +			reg2->s32_max_value = min(reg1->s32_max_value - 1, reg2->s32_max_value);
>  		} else {
> -			s64 false_smax = opcode == BPF_JSGT ? sval    : sval - 1;
> -			s64 true_smin = opcode == BPF_JSGT ? sval + 1 : sval;
> -
> -			false_reg1->smax_value = min(false_reg1->smax_value, false_smax);
> -			true_reg1->smin_value = max(true_reg1->smin_value, true_smin);
> +			reg1->smin_value = max(reg1->smin_value, reg2->smin_value + 1);
> +			reg2->smax_value = min(reg1->smax_value - 1, reg2->smax_value);
>  		}
>  		break;
> -	}
>  	case BPF_JLE:
> +		if (is_jmp32) {
> +			reg1->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value);
> +			reg2->u32_min_value = max(reg1->u32_min_value, reg2->u32_min_value);
> +		} else {
> +			reg1->umax_value = min(reg1->umax_value, reg2->umax_value);
> +			reg2->umin_value = max(reg1->umin_value, reg2->umin_value);
> +		}
> +		break;
>  	case BPF_JLT:
> -	{
>  		if (is_jmp32) {
> -			u32 false_umin = opcode == BPF_JLT ? uval32  : uval32 + 1;
> -			u32 true_umax = opcode == BPF_JLT ? uval32 - 1 : uval32;
> -
> -			false_reg1->u32_min_value = max(false_reg1->u32_min_value,
> -						       false_umin);
> -			true_reg1->u32_max_value = min(true_reg1->u32_max_value,
> -						      true_umax);
> +			reg1->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value - 1);
> +			reg2->u32_min_value = max(reg1->u32_min_value + 1, reg2->u32_min_value);
>  		} else {
> -			u64 false_umin = opcode == BPF_JLT ? uval    : uval + 1;
> -			u64 true_umax = opcode == BPF_JLT ? uval - 1 : uval;
> -
> -			false_reg1->umin_value = max(false_reg1->umin_value, false_umin);
> -			true_reg1->umax_value = min(true_reg1->umax_value, true_umax);
> +			reg1->umax_value = min(reg1->umax_value, reg2->umax_value - 1);
> +			reg2->umin_value = max(reg1->umin_value + 1, reg2->umin_value);
>  		}
>  		break;
> -	}
>  	case BPF_JSLE:
> +		if (is_jmp32) {
> +			reg1->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value);
> +			reg2->s32_min_value = max(reg1->s32_min_value, reg2->s32_min_value);
> +		} else {
> +			reg1->smax_value = min(reg1->smax_value, reg2->smax_value);
> +			reg2->smin_value = max(reg1->smin_value, reg2->smin_value);
> +		}
> +		break;
>  	case BPF_JSLT:
> -	{
>  		if (is_jmp32) {
> -			s32 false_smin = opcode == BPF_JSLT ? sval32    : sval32 + 1;
> -			s32 true_smax = opcode == BPF_JSLT ? sval32 - 1 : sval32;
> -
> -			false_reg1->s32_min_value = max(false_reg1->s32_min_value, false_smin);
> -			true_reg1->s32_max_value = min(true_reg1->s32_max_value, true_smax);
> +			reg1->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value - 1);
> +			reg2->s32_min_value = max(reg1->s32_min_value + 1, reg2->s32_min_value);
>  		} else {
> -			s64 false_smin = opcode == BPF_JSLT ? sval    : sval + 1;
> -			s64 true_smax = opcode == BPF_JSLT ? sval - 1 : sval;
> -
> -			false_reg1->smin_value = max(false_reg1->smin_value, false_smin);
> -			true_reg1->smax_value = min(true_reg1->smax_value, true_smax);
> +			reg1->smax_value = min(reg1->smax_value, reg2->smax_value - 1);
> +			reg2->smin_value = max(reg1->smin_value + 1, reg2->smin_value);
>  		}
>  		break;
> -	}
>  	default:
>  		return;
>  	}
> -
> -	if (is_jmp32) {
> -		false_reg1->var_off = tnum_or(tnum_clear_subreg(false_64off),
> -					     tnum_subreg(false_32off));
> -		true_reg1->var_off = tnum_or(tnum_clear_subreg(true_64off),
> -					    tnum_subreg(true_32off));
> -		reg_bounds_sync(false_reg1);
> -		reg_bounds_sync(true_reg1);
> -	} else {
> -		false_reg1->var_off = false_64off;
> -		true_reg1->var_off = true_64off;
> -		reg_bounds_sync(false_reg1);
> -		reg_bounds_sync(true_reg1);
> -	}
> -}
> -
> -/* Regs are known to be equal, so intersect their min/max/var_off */
> -static void __reg_combine_min_max(struct bpf_reg_state *src_reg,
> -				  struct bpf_reg_state *dst_reg)
> -{
> -	src_reg->umin_value = dst_reg->umin_value = max(src_reg->umin_value,
> -							dst_reg->umin_value);
> -	src_reg->umax_value = dst_reg->umax_value = min(src_reg->umax_value,
> -							dst_reg->umax_value);
> -	src_reg->smin_value = dst_reg->smin_value = max(src_reg->smin_value,
> -							dst_reg->smin_value);
> -	src_reg->smax_value = dst_reg->smax_value = min(src_reg->smax_value,
> -							dst_reg->smax_value);
> -	src_reg->var_off = dst_reg->var_off = tnum_intersect(src_reg->var_off,
> -							     dst_reg->var_off);
> -	reg_bounds_sync(src_reg);
> -	reg_bounds_sync(dst_reg);
>  }
>  
> -static void reg_combine_min_max(struct bpf_reg_state *true_src,
> -				struct bpf_reg_state *true_dst,
> -				struct bpf_reg_state *false_src,
> -				struct bpf_reg_state *false_dst,
> -				u8 opcode)
> +/* Adjusts the register min/max values in the case that the dst_reg and
> + * src_reg are both SCALAR_VALUE registers (or we are simply doing a BPF_K
> + * check, in which case we havea fake SCALAR_VALUE representing insn->imm).
> + * Technically we can do similar adjustments for pointers to the same object,
> + * but we don't support that right now.
> + */
> +static void reg_set_min_max(struct bpf_reg_state *true_reg1,
> +			    struct bpf_reg_state *true_reg2,
> +			    struct bpf_reg_state *false_reg1,
> +			    struct bpf_reg_state *false_reg2,
> +			    u8 opcode, bool is_jmp32)
>  {
> -	switch (opcode) {
> -	case BPF_JEQ:
> -		__reg_combine_min_max(true_src, true_dst);
> -		break;
> -	case BPF_JNE:
> -		__reg_combine_min_max(false_src, false_dst);
> -		break;
> -	}
> +	/* If either register is a pointer, we can't learn anything about its
> +	 * variable offset from the compare (unless they were a pointer into
> +	 * the same object, but we don't bother with that).
> +	 */
> +	if (false_reg1->type != SCALAR_VALUE || false_reg2->type != SCALAR_VALUE)
> +		return;
> +
> +	/* fallthrough (FALSE) branch */
> +	regs_refine_cond_op(false_reg1, false_reg2, rev_opcode(opcode), is_jmp32);
> +	reg_bounds_sync(false_reg1);
> +	reg_bounds_sync(false_reg2);
> +
> +	/* jump (TRUE) branch */
> +	regs_refine_cond_op(true_reg1, true_reg2, opcode, is_jmp32);
> +	reg_bounds_sync(true_reg1);
> +	reg_bounds_sync(true_reg2);
>  }
>  
>  static void mark_ptr_or_null_reg(struct bpf_func_state *state,
> @@ -14887,22 +14880,12 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
>  		reg_set_min_max(&other_branch_regs[insn->dst_reg],
>  				&other_branch_regs[insn->src_reg],
>  				dst_reg, src_reg, opcode, is_jmp32);
> -
> -		if (dst_reg->type == SCALAR_VALUE &&
> -		    src_reg->type == SCALAR_VALUE &&
> -		    !is_jmp32 && (opcode == BPF_JEQ || opcode == BPF_JNE)) {
> -			/* Comparing for equality, we can combine knowledge */
> -			reg_combine_min_max(&other_branch_regs[insn->src_reg],
> -					    &other_branch_regs[insn->dst_reg],
> -					    src_reg, dst_reg, opcode);
> -		}
>  	} else /* BPF_SRC(insn->code) == BPF_K */ {
>  		reg_set_min_max(&other_branch_regs[insn->dst_reg],
>  				src_reg /* fake one */,
>  				dst_reg, src_reg /* same fake one */,
>  				opcode, is_jmp32);
>  	}
> -
>  	if (BPF_SRC(insn->code) == BPF_X &&
>  	    src_reg->type == SCALAR_VALUE && src_reg->id &&
>  	    !WARN_ON_ONCE(src_reg->id != other_branch_regs[insn->src_reg].id)) {






[Index of Archives]     [Linux Samsung SoC]     [Linux Rockchip SoC]     [Linux Actions SoC]     [Linux for Synopsys ARC Processors]     [Linux NFS]     [Linux NILFS]     [Linux USB Devel]     [Video for Linux]     [Linux Audio Users]     [Yosemite News]     [Linux Kernel]     [Linux SCSI]


  Powered by Linux