Re: [RFC bpf-next] bpf, verifier: improve signed ranges inference for BPF_AND

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



On Tue, Jul 16, 2024 at 10:52:26PM GMT, Shung-Hsi Yu wrote:
> This commit teach the BPF verifier how to infer signed ranges directly
> from signed ranges of the operands to prevent verifier rejection
...
> ---
>  kernel/bpf/verifier.c | 62 +++++++++++++++++++++++++++++--------------
>  1 file changed, 42 insertions(+), 20 deletions(-)
> 
> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> index 8da132a1ef28..6d4cdf30cd76 100644
> --- a/kernel/bpf/verifier.c
> +++ b/kernel/bpf/verifier.c
> @@ -13466,6 +13466,39 @@ static void scalar_min_max_mul(struct bpf_reg_state *dst_reg,
>  	}
>  }
>  
> +/* Clears all trailing bits after the most significant unset bit.
> + *
> + * Used for estimating the minimum possible value after BPF_AND. This
> + * effectively rounds a negative value down to a negative power-of-2 value
> + * (except for -1, which just return -1) and returning 0 for non-negative
> + * values. E.g. masked32_negative(0xff0ff0ff) == 0xff000000.

s/masked32_negative/negative32_bit_floor/

> + */
> +static inline s32 negative32_bit_floor(s32 v)
> +{
> +	/* XXX: per C standard section 6.5.7 right shift of signed negative
> +	 * value is implementation-defined. Should unsigned type be used here
> +	 * instead?
> +	 */
> +	v &= v >> 1;
> +	v &= v >> 2;
> +	v &= v >> 4;
> +	v &= v >> 8;
> +	v &= v >> 16;
> +	return v;
> +}
> +
> +/* Same as negative32_bit_floor() above, but for 64-bit signed value */
> +static inline s64 negative_bit_floor(s64 v)
> +{
> +	v &= v >> 1;
> +	v &= v >> 2;
> +	v &= v >> 4;
> +	v &= v >> 8;
> +	v &= v >> 16;
> +	v &= v >> 32;
> +	return v;
> +}
> +
>  static void scalar32_min_max_and(struct bpf_reg_state *dst_reg,
>  				 struct bpf_reg_state *src_reg)
>  {
> @@ -13485,16 +13518,10 @@ static void scalar32_min_max_and(struct bpf_reg_state *dst_reg,
>  	dst_reg->u32_min_value = var32_off.value;
>  	dst_reg->u32_max_value = min(dst_reg->u32_max_value, umax_val);
>  
> -	/* Safe to set s32 bounds by casting u32 result into s32 when u32
> -	 * doesn't cross sign boundary. Otherwise set s32 bounds to unbounded.
> -	 */
> -	if ((s32)dst_reg->u32_min_value <= (s32)dst_reg->u32_max_value) {
> -		dst_reg->s32_min_value = dst_reg->u32_min_value;
> -		dst_reg->s32_max_value = dst_reg->u32_max_value;
> -	} else {
> -		dst_reg->s32_min_value = S32_MIN;
> -		dst_reg->s32_max_value = S32_MAX;
> -	}
> +	/* Rough estimate tuned for [-1, 0] & -CONSTANT cases. */
> +	dst_reg->s32_min_value = negative32_bit_floor(min(dst_reg->s32_min_value,
> +							  src_reg->s32_min_value));
> +	dst_reg->s32_max_value = max(dst_reg->s32_max_value, src_reg->s32_max_value);
>  }
>  
>  static void scalar_min_max_and(struct bpf_reg_state *dst_reg,
> @@ -13515,16 +13542,11 @@ static void scalar_min_max_and(struct bpf_reg_state *dst_reg,
>  	dst_reg->umin_value = dst_reg->var_off.value;
>  	dst_reg->umax_value = min(dst_reg->umax_value, umax_val);
>  
> -	/* Safe to set s64 bounds by casting u64 result into s64 when u64
> -	 * doesn't cross sign boundary. Otherwise set s64 bounds to unbounded.
> -	 */
> -	if ((s64)dst_reg->umin_value <= (s64)dst_reg->umax_value) {
> -		dst_reg->smin_value = dst_reg->umin_value;
> -		dst_reg->smax_value = dst_reg->umax_value;
> -	} else {
> -		dst_reg->smin_value = S64_MIN;
> -		dst_reg->smax_value = S64_MAX;
> -	}
> +	/* Rough estimate tuned for [-1, 0] & -CONSTANT cases. */
> +	dst_reg->smin_value = negative_bit_floor(min(dst_reg->smin_value,
> +						     src_reg->smin_value));
> +	dst_reg->smax_value = max(dst_reg->smax_value, src_reg->smax_value);
> +
>  	/* We may learn something more from the var_off */
>  	__update_reg_bounds(dst_reg);
>  }

Checked that this passes BPF CI[0] (except s390x-gcc/test_verifier,
which seems to be stuck), and verified the logic with z3 (see attached
Python script, adapted from [1]); so it seems to work.

Will try running tools/testing/selftests/bpf/prog_tests/reg_bounds.c
against it next.

0: https://github.com/kernel-patches/bpf/actions/runs/9958322024
1: https://github.com/bpfverif/tnums-cgo22/blob/main/verification/tnum.py
#!/usr/bin/env python3
# Need python3-z3/Z3Py to run
from math import floor, log2
from z3 import *

SIZE = 32
# Number of doubling shift-and-AND rounds needed to cover a SIZE-bit word,
# i.e. ceil(log2(SIZE)). For a power-of-two SIZE this equals
# floor(log2(SIZE)), but floor would leave the low bits unprocessed for any
# other width.
SIZE_LOG_2 = (SIZE - 1).bit_length()


class SignedRange:
    """Symbolic signed range [min, max] over SIZE-bit z3 bit-vectors.

    Each bound may be given as None (a fresh free bit-vector named after the
    range), a Python int (a SIZE-bit constant), or an existing term, which is
    stored unchanged. On z3 bit-vectors ``<=`` is the signed comparison.
    """
    name: str
    min: BitVecRef
    max: BitVecRef

    def __init__(self, name, min=None, max=None):
        self.name = name
        self.min = self._bound(min, f'SignedRange({self.name}).min')
        self.max = self._bound(max, f'SignedRange({self.name}).max')

    @staticmethod
    def _bound(value, label):
        """Convert one bound argument into the term stored on the range."""
        if value is None:
            return BitVec(label, bv=SIZE)
        if isinstance(value, int):
            return BitVecVal(value, bv=SIZE)
        return value

    def wellformed(self):
        """Constraint stating the range is non-empty (min <= max)."""
        return self.min <= self.max

    def contains(self, val):
        """Constraint stating that ``val`` (int or term) lies in [min, max]."""
        point = BitVecVal(val, bv=SIZE) if isinstance(val, int) else val
        return And(self.min <= point, point <= self.max)


def negative_bit_floor(x: "BitVecRef", width=None):
    """Clear every bit below the most significant unset bit of ``x``.

    This models negative32_bit_floor()/negative_bit_floor() from the patch. A
    negative value is rounded down to a negative power of two (-1 stays -1),
    and a non-negative value becomes 0.

    Each round ANDs ``x`` with an arithmetic right shift of itself, doubling
    the shift each time. After k rounds every bit is the AND of itself and
    the 2**k - 1 bits above it. Rounds continue while the shift is below the
    operand width, i.e. ceil(log2(width)) rounds. This also covers widths
    that are not powers of two, which a floor(log2(width)) round count would
    not.

    Args:
        x: value to transform. Either a z3 bit-vector expression (``>>`` on a
            z3 BitVecRef is an arithmetic shift) or a Python int holding a
            signed ``width``-bit value.
        width: bit width of ``x``. Defaults to ``x.size()``, the width of a z3
            bit-vector; it must be passed explicitly for a plain Python int.

    Returns:
        The transformed value, of the same kind as ``x``.
    """
    if width is None:
        width = x.size()
    shift_count = 1
    while shift_count < width:
        # Use arithmetic right shift to preserve leading signed bit
        x &= x >> shift_count
        shift_count *= 2
    return x


s = Solver()
premises = []

# Given x that is within a well-formed srange1, and y that is within a
# well-formed srange2
x = BitVec('x', bv=SIZE)
srange1 = SignedRange('srange1')
premises += [
    srange1.wellformed(),
    srange1.contains(x),
]

y = BitVec('y', bv=SIZE)
srange2 = SignedRange('srange2')
premises += [
    srange2.wellformed(),
    srange2.contains(y),
]

# Calculate x & y
actual = x & y

# x & y will always be LESS than or equal to max(srange1.max, srange2.max).
# (On z3 bit-vectors, > and <= are signed comparisons.)
guessed_max = BitVec('guessed_max', bv=SIZE)
premises += [
    guessed_max == If(srange1.max > srange2.max, srange1.max, srange2.max)
]

# x & y will always be GREATER than or equal to
# negative_bit_floor(min(srange1.min, srange2.min))
guessed_min = BitVec('guessed_min', bv=SIZE)
premises += [
    guessed_min == negative_bit_floor(If(srange1.min > srange2.min, srange2.min, srange1.min)),
]

# Ask z3 for a counterexample: an assignment satisfying every premise where
# actual lies outside [guessed_min, guessed_max]. Only unsat (no such
# assignment exists) is a proof.
s.add(Not(
    Implies(
        And(premises),
        And(guessed_min <= actual, actual <= guessed_max))
))
result = s.check()

if result == unsat:
    print('Proved!')
elif result == sat:
    print('Found counter example')
    print(s.model())
    raise SystemExit(1)
else:
    # z3 could not decide the query; that is not a proof either way.
    print(f'Inconclusive: {s.reason_unknown()}')
    raise SystemExit(2)

[Index of Archives]     [Linux Samsung SoC]     [Linux Rockchip SoC]     [Linux Actions SoC]     [Linux for Synopsys ARC Processors]     [Linux NFS]     [Linux NILFS]     [Linux USB Devel]     [Video for Linux]     [Linux Audio Users]     [Yosemite News]     [Linux Kernel]     [Linux SCSI]


  Powered by Linux