Re: [PATCH v5 bpf-next 19/23] bpf: generalize is_scalar_branch_taken() logic

On Tue, Oct 31, 2023 at 1:53 PM Andrii Nakryiko
<andrii.nakryiko@xxxxxxxxx> wrote:
>
> On Tue, Oct 31, 2023 at 11:01 AM Andrii Nakryiko
> <andrii.nakryiko@xxxxxxxxx> wrote:
> >
> > On Tue, Oct 31, 2023 at 9:35 AM Alexei Starovoitov
> > <alexei.starovoitov@xxxxxxxxx> wrote:
> > >
> > > On Mon, Oct 30, 2023 at 11:12 PM Andrii Nakryiko
> > > <andrii.nakryiko@xxxxxxxxx> wrote:
> > > >
> > > > On Mon, Oct 30, 2023 at 7:12 PM Alexei Starovoitov
> > > > <alexei.starovoitov@xxxxxxxxx> wrote:
> > > > >
> > > > > On Fri, Oct 27, 2023 at 11:13:42AM -0700, Andrii Nakryiko wrote:
> > > > > > Generalize is_branch_taken logic for SCALAR_VALUE register to handle
> > > > > > cases when both registers are not constants. Previously supported
> > > > > > <range> vs <scalar> cases are a natural subset of the more generic
> > > > > > <range> vs <range> set of cases.
> > > > > >
> > > > > > Generalized logic relies on straightforward segment intersection checks.
> > > > > >
> > > > > > Signed-off-by: Andrii Nakryiko <andrii@xxxxxxxxxx>
> > > > > > ---
> > > > > >  kernel/bpf/verifier.c | 104 ++++++++++++++++++++++++++----------------
> > > > > >  1 file changed, 64 insertions(+), 40 deletions(-)
> > > > > >
> > > > > > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> > > > > > index 4c974296127b..f18a8247e5e2 100644
> > > > > > --- a/kernel/bpf/verifier.c
> > > > > > +++ b/kernel/bpf/verifier.c
> > > > > > @@ -14189,82 +14189,105 @@ static int is_scalar_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_sta
> > > > > >                                 u8 opcode, bool is_jmp32)
> > > > > >  {
> > > > > >       struct tnum t1 = is_jmp32 ? tnum_subreg(reg1->var_off) : reg1->var_off;
> > > > > > +     struct tnum t2 = is_jmp32 ? tnum_subreg(reg2->var_off) : reg2->var_off;
> > > > > >       u64 umin1 = is_jmp32 ? (u64)reg1->u32_min_value : reg1->umin_value;
> > > > > >       u64 umax1 = is_jmp32 ? (u64)reg1->u32_max_value : reg1->umax_value;
> > > > > >       s64 smin1 = is_jmp32 ? (s64)reg1->s32_min_value : reg1->smin_value;
> > > > > >       s64 smax1 = is_jmp32 ? (s64)reg1->s32_max_value : reg1->smax_value;
> > > > > > -     u64 val = is_jmp32 ? (u32)tnum_subreg(reg2->var_off).value : reg2->var_off.value;
> > > > > > -     s64 sval = is_jmp32 ? (s32)val : (s64)val;
> > > > > > +     u64 umin2 = is_jmp32 ? (u64)reg2->u32_min_value : reg2->umin_value;
> > > > > > +     u64 umax2 = is_jmp32 ? (u64)reg2->u32_max_value : reg2->umax_value;
> > > > > > +     s64 smin2 = is_jmp32 ? (s64)reg2->s32_min_value : reg2->smin_value;
> > > > > > +     s64 smax2 = is_jmp32 ? (s64)reg2->s32_max_value : reg2->smax_value;
> > > > > >
> > > > > >       switch (opcode) {
> > > > > >       case BPF_JEQ:
> > > > > > -             if (tnum_is_const(t1))
> > > > > > -                     return !!tnum_equals_const(t1, val);
> > > > > > -             else if (val < umin1 || val > umax1)
> > > > > > +             /* const tnums */
> > > > > > +             if (tnum_is_const(t1) && tnum_is_const(t2))
> > > > > > +                     return t1.value == t2.value;
> > > > > > +             /* const ranges */
> > > > > > +             if (umin1 == umax1 && umin2 == umax2)
> > > > > > +                     return umin1 == umin2;
> > > > >
> > > > > I don't follow this logic.
> > > > > umin1 == umax1 means that it's a single constant and
> > > > > it should have been handled by the earlier tnum_is_const check.
> > > >
> > > > I think you follow the logic; you just think it's redundant. Yes, it's
> > > > basically the same as
> > > >
> > > >           if (tnum_is_const(t1) && tnum_is_const(t2))
> > > >                 return t1.value == t2.value;
> > > >
> > > > but based on ranges. I didn't feel comfortable assuming that if umin1
> > > > == umax1 then tnum_is_const(t1) will always be true. At worst we'll
> > > > perform one redundant check.
> > > >
> > > > In short, I don't trust tnum to be as precise as umin/umax and other ranges.
> > > >
> > > > >
> > > > > > +             if (smin1 == smax1 && smin2 == smax2)
> > > > > > +                     return umin1 == umin2;
> > > > >
> > > > > here it's even more confusing. smin == smax -> single const,
> > > > > but then compare umin1 with umin2?!
> > > >
> > > > Eagle eyes! Typo, sorry :( it should be `smin1 == smin2`, of course.
> > > >
> > > > What saves us is reg_bounds_sync(), and if we have umin1 == umax1 then
> > > > we'll also have smin1 == smax1 == umin1 == umax1 (and the corresponding
> > > > relation for the second register). But I fixed these typos in both BPF_JEQ
> > > > and BPF_JNE branches.
> > >
> > > Not just 'saves us'. The tnum <-> bounds sync is mandatory.
> > > I think we have a test where a function returns [-errno, 0]
> > > and then we do an if (ret < 0) check. At this point the reg has
> > > to be tnum_is_const and zero.
> > > So if smin1 == smax1 == umin1 == umax1 it should be tnum_is_const.
> > > Otherwise it's a bug in the sync logic.
> > > I think instead of doing a redundant and confusing check maybe
> > > add a WARN either here or in the sync logic to make sure it's all good?
> >
> > Ok, let's add it as part of the register state sanity checks we discussed
> > on another patch. I'll drop the checks and re-run all the tests to
> > make sure we are not missing anything.
>
> So I have this as one more patch for the next revision (pending local
> testing). If you hate any part of it, I'd appreciate early feedback :)
> I'll wait for Eduard to finish going through the series (probably
> tomorrow), and then will post the next version based on all the
> feedback I got (and whatever might still come).
>
> Note, in the patch below, I don't output the actual register state on
> violation, which is unfortunate. But to make this happen I need to
> refactor print_verifier_state() to allow me to print register state.
> I've been wanting to move print_verifier_state() into kernel/bpf/log.c
> for a while now, and fix how we print the state of spilled registers
> (and maybe a few more small things), so I'll do that separately, and
> then add register state printing to the sanity check error.
>
>
> Author: Andrii Nakryiko <andrii@xxxxxxxxxx>
> Date:   Tue Oct 31 13:34:33 2023 -0700
>
>     bpf: add register bounds sanity checks
>
>     Add simple sanity checks that validate well-formed ranges (min <= max)
>     across u64, s64, u32, and s32 ranges. Also for cases when the value is
>     constant (either 64-bit or 32-bit), we validate that ranges and tnums
>     are in agreement.
>
>     These bounds checks are performed at the end of BPF_ALU/BPF_ALU64
>     operations, on conditional jumps, and for LDX instructions (where subreg
>     zero/sign extension is probably the most important to check). This
>     covers most of the interesting cases.
>
>     Also, we validate the sanity of the return register when manually
>     adjusting it for some special helpers.
>
>     Signed-off-by: Andrii Nakryiko <andrii@xxxxxxxxxx>
>
> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> index c85d974ba21f..b29c85089bc9 100644
> --- a/kernel/bpf/verifier.c
> +++ b/kernel/bpf/verifier.c
> @@ -2615,6 +2615,46 @@ static void reg_bounds_sync(struct bpf_reg_state *reg)
>         __update_reg_bounds(reg);
>  }
>
> +static int reg_bounds_sanity_check(struct bpf_verifier_env *env, struct bpf_reg_state *reg)
> +{
> +       const char *msg;
> +
> +       if (reg->umin_value > reg->umax_value ||
> +           reg->smin_value > reg->smax_value ||
> +           reg->u32_min_value > reg->u32_max_value ||
> +           reg->s32_min_value > reg->s32_max_value) {
> +               msg = "range bounds violation";
> +               goto out;
> +       }
> +
> +       if (tnum_is_const(reg->var_off)) {
> +               u64 uval = reg->var_off.value;
> +               s64 sval = (s64)uval;
> +
> +               if (reg->umin_value != uval || reg->umax_value != uval ||
> +                   reg->smin_value != sval || reg->smax_value != sval) {
> +                       msg = "const tnum out of sync with range bounds";
> +                       goto out;
> +               }
> +       }
> +
> +       if (tnum_subreg_is_const(reg->var_off)) {
> +               u32 uval32 = tnum_subreg(reg->var_off).value;
> +               s32 sval32 = (s32)uval32;
> +
> +               if (reg->u32_min_value != uval32 || reg->u32_max_value != uval32 ||
> +                   reg->s32_min_value != sval32 || reg->s32_max_value != sval32) {
> +                       msg = "const tnum (subreg) out of sync with range bounds";
> +                       goto out;
> +               }
> +       }
> +
> +       return 0;
> +out:
> +       verbose(env, "%s\n", msg);
> +       return -EFAULT;
> +}
> +
>  static bool __reg32_bound_s64(s32 a)
>  {
>         return a >= 0 && a <= S32_MAX;
> @@ -9928,14 +9968,15 @@ static int prepare_func_exit(struct bpf_verifier_env *env, int *insn_idx)
>         return 0;
>  }
>
> -static void do_refine_retval_range(struct bpf_reg_state *regs, int ret_type,
> -                                  int func_id,
> -                                  struct bpf_call_arg_meta *meta)
> +static int do_refine_retval_range(struct bpf_verifier_env *env,
> +                                 struct bpf_reg_state *regs, int ret_type,
> +                                 int func_id,
> +                                 struct bpf_call_arg_meta *meta)
>  {
>         struct bpf_reg_state *ret_reg = &regs[BPF_REG_0];
>
>         if (ret_type != RET_INTEGER)
> -               return;
> +               return 0;
>
>         switch (func_id) {
>         case BPF_FUNC_get_stack:
> @@ -9961,6 +10002,8 @@ static void do_refine_retval_range(struct bpf_reg_state *regs, int ret_type,
>                 reg_bounds_sync(ret_reg);
>                 break;
>         }
> +
> +       return reg_bounds_sanity_check(env, ret_reg);
>  }
>
>  static int
> @@ -10612,7 +10655,9 @@ static int check_helper_call(struct bpf_verifier_env *env, struct bpf_insn *insn
>                 regs[BPF_REG_0].ref_obj_id = id;
>         }
>
> -       do_refine_retval_range(regs, fn->ret_type, func_id, &meta);
> +       err = do_refine_retval_range(env, regs, fn->ret_type, func_id, &meta);
> +       if (err)
> +               return err;
>
>         err = check_map_func_compatibility(env, meta.map_ptr, func_id);
>         if (err)
> @@ -14079,13 +14124,12 @@ static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)
>
>                 /* check dest operand */
>                 err = check_reg_arg(env, insn->dst_reg, DST_OP_NO_MARK);
> +               err = err ?: adjust_reg_min_max_vals(env, insn);
>                 if (err)
>                         return err;
> -
> -               return adjust_reg_min_max_vals(env, insn);
>         }
>
> -       return 0;
> +       return reg_bounds_sanity_check(env, &regs[insn->dst_reg]);
>  }
>
>  static void find_good_pkt_pointers(struct bpf_verifier_state *vstate,
> @@ -14600,18 +14644,21 @@ static void regs_refine_cond_op(struct bpf_reg_state *reg1, struct bpf_reg_state
>   * Technically we can do similar adjustments for pointers to the same object,
>   * but we don't support that right now.
>   */
> -static void reg_set_min_max(struct bpf_reg_state *true_reg1,
> -                           struct bpf_reg_state *true_reg2,
> -                           struct bpf_reg_state *false_reg1,
> -                           struct bpf_reg_state *false_reg2,
> -                           u8 opcode, bool is_jmp32)
> +static int reg_set_min_max(struct bpf_verifier_env *env,
> +                          struct bpf_reg_state *true_reg1,
> +                          struct bpf_reg_state *true_reg2,
> +                          struct bpf_reg_state *false_reg1,
> +                          struct bpf_reg_state *false_reg2,
> +                          u8 opcode, bool is_jmp32)
>  {
> +       int err;
> +
>         /* If either register is a pointer, we can't learn anything about its
>          * variable offset from the compare (unless they were a pointer into
>          * the same object, but we don't bother with that).
>          */
>         if (false_reg1->type != SCALAR_VALUE || false_reg2->type != SCALAR_VALUE)
> -               return;
> +               return 0;
>
>         /* fallthrough (FALSE) branch */
>         regs_refine_cond_op(false_reg1, false_reg2, rev_opcode(opcode), is_jmp32);
> @@ -14622,6 +14669,12 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg1,
>         regs_refine_cond_op(true_reg1, true_reg2, opcode, is_jmp32);
>         reg_bounds_sync(true_reg1);
>         reg_bounds_sync(true_reg2);
> +
> +       err = reg_bounds_sanity_check(env, true_reg1);
> +       err = err ?: reg_bounds_sanity_check(env, true_reg2);
> +       err = err ?: reg_bounds_sanity_check(env, false_reg1);
> +       err = err ?: reg_bounds_sanity_check(env, false_reg2);
> +       return err;
>  }
>
>  static void mark_ptr_or_null_reg(struct bpf_func_state *state,
> @@ -14915,15 +14968,20 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
>         other_branch_regs = other_branch->frame[other_branch->curframe]->regs;
>
>         if (BPF_SRC(insn->code) == BPF_X) {
> -               reg_set_min_max(&other_branch_regs[insn->dst_reg],
> -                               &other_branch_regs[insn->src_reg],
> -                               dst_reg, src_reg, opcode, is_jmp32);
> +               err = reg_set_min_max(env,
> +                                     &other_branch_regs[insn->dst_reg],
> +                                     &other_branch_regs[insn->src_reg],
> +                                     dst_reg, src_reg, opcode, is_jmp32);
>         } else /* BPF_SRC(insn->code) == BPF_K */ {
> -               reg_set_min_max(&other_branch_regs[insn->dst_reg],
> -                               src_reg /* fake one */,
> -                               dst_reg, src_reg /* same fake one */,
> -                               opcode, is_jmp32);
> +               err = reg_set_min_max(env,
> +                                     &other_branch_regs[insn->dst_reg],
> +                                     src_reg /* fake one */,
> +                                     dst_reg, src_reg /* same fake one */,
> +                                     opcode, is_jmp32);
>         }
> +       if (err)
> +               return err;
> +
>         if (BPF_SRC(insn->code) == BPF_X &&
>             src_reg->type == SCALAR_VALUE && src_reg->id &&
>             !WARN_ON_ONCE(src_reg->id != other_branch_regs[insn->src_reg].id)) {
> @@ -17426,10 +17484,8 @@ static int do_check(struct bpf_verifier_env *env)
>                                                insn->off, BPF_SIZE(insn->code),
>                                                BPF_READ, insn->dst_reg, false,
>                                                BPF_MODE(insn->code) == BPF_MEMSX);
> -                       if (err)
> -                               return err;
> -
> -                       err = save_aux_ptr_type(env, src_reg_type, true);
> +                       err = err ?: save_aux_ptr_type(env, src_reg_type, true);
> +                       err = reg_bounds_sanity_check(env, &regs[insn->dst_reg]);

this should obviously be `err = err ?: reg_bounds_sanity_check(...)`
(somehow it gets obvious in the email, not locally)
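
IOW, the intended shape of that hunk is (sketch of the fixed chaining only,
not the final diff):

  err = check_mem_access(...);            /* same call as quoted above */
  err = err ?: save_aux_ptr_type(env, src_reg_type, true);
  err = err ?: reg_bounds_sanity_check(env, &regs[insn->dst_reg]);
  if (err)
          return err;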

>                         if (err)
>                                 return err;
>                 } else if (class == BPF_STX) {
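
To spell out the <range> vs <range> idea from the original patch description
in one place: for BPF_JEQ the generalized check boils down to a segment
intersection test along these lines (illustrative sketch only; the function
name is made up, tnum and signed/32-bit handling are omitted, and
is_scalar_branch_taken() in the patch is the authoritative version):

  /* Decide a JEQ branch outcome from two unsigned ranges
   * [umin1, umax1] and [umin2, umax2]:
   * 1 = always taken, 0 = never taken, -1 = unknown.
   */
  static int jeq_branch_taken(u64 umin1, u64 umax1, u64 umin2, u64 umax2)
  {
          if (umin1 == umax1 && umin2 == umax2)   /* both constants */
                  return umin1 == umin2;
          if (umax1 < umin2 || umax2 < umin1)     /* ranges don't intersect */
                  return 0;                       /* JEQ can never be true */
          return -1;                              /* overlap -> can't tell */
  }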
