Re: [PATCH v5 bpf-next 19/23] bpf: generalize is_scalar_branch_taken() logic

On Tue, Oct 31, 2023 at 11:01 AM Andrii Nakryiko
<andrii.nakryiko@xxxxxxxxx> wrote:
>
> On Tue, Oct 31, 2023 at 9:35 AM Alexei Starovoitov
> <alexei.starovoitov@xxxxxxxxx> wrote:
> >
> > On Mon, Oct 30, 2023 at 11:12 PM Andrii Nakryiko
> > <andrii.nakryiko@xxxxxxxxx> wrote:
> > >
> > > On Mon, Oct 30, 2023 at 7:12 PM Alexei Starovoitov
> > > <alexei.starovoitov@xxxxxxxxx> wrote:
> > > >
> > > > On Fri, Oct 27, 2023 at 11:13:42AM -0700, Andrii Nakryiko wrote:
> > > > > Generalize is_branch_taken logic for SCALAR_VALUE register to handle
> > > > > cases when both registers are not constants. Previously supported
> > > > > <range> vs <scalar> cases are a natural subset of more generic <range>
> > > > > vs <range> set of cases.
> > > > >
> > > > > Generalized logic relies on straightforward segment intersection checks.
> > > > >
> > > > > Signed-off-by: Andrii Nakryiko <andrii@xxxxxxxxxx>
> > > > > ---
> > > > >  kernel/bpf/verifier.c | 104 ++++++++++++++++++++++++++----------------
> > > > >  1 file changed, 64 insertions(+), 40 deletions(-)
> > > > >
> > > > > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> > > > > index 4c974296127b..f18a8247e5e2 100644
> > > > > --- a/kernel/bpf/verifier.c
> > > > > +++ b/kernel/bpf/verifier.c
> > > > > @@ -14189,82 +14189,105 @@ static int is_scalar_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_sta
> > > > >                                 u8 opcode, bool is_jmp32)
> > > > >  {
> > > > >       struct tnum t1 = is_jmp32 ? tnum_subreg(reg1->var_off) : reg1->var_off;
> > > > > +     struct tnum t2 = is_jmp32 ? tnum_subreg(reg2->var_off) : reg2->var_off;
> > > > >       u64 umin1 = is_jmp32 ? (u64)reg1->u32_min_value : reg1->umin_value;
> > > > >       u64 umax1 = is_jmp32 ? (u64)reg1->u32_max_value : reg1->umax_value;
> > > > >       s64 smin1 = is_jmp32 ? (s64)reg1->s32_min_value : reg1->smin_value;
> > > > >       s64 smax1 = is_jmp32 ? (s64)reg1->s32_max_value : reg1->smax_value;
> > > > > -     u64 val = is_jmp32 ? (u32)tnum_subreg(reg2->var_off).value : reg2->var_off.value;
> > > > > -     s64 sval = is_jmp32 ? (s32)val : (s64)val;
> > > > > +     u64 umin2 = is_jmp32 ? (u64)reg2->u32_min_value : reg2->umin_value;
> > > > > +     u64 umax2 = is_jmp32 ? (u64)reg2->u32_max_value : reg2->umax_value;
> > > > > +     s64 smin2 = is_jmp32 ? (s64)reg2->s32_min_value : reg2->smin_value;
> > > > > +     s64 smax2 = is_jmp32 ? (s64)reg2->s32_max_value : reg2->smax_value;
> > > > >
> > > > >       switch (opcode) {
> > > > >       case BPF_JEQ:
> > > > > -             if (tnum_is_const(t1))
> > > > > -                     return !!tnum_equals_const(t1, val);
> > > > > -             else if (val < umin1 || val > umax1)
> > > > > +             /* const tnums */
> > > > > +             if (tnum_is_const(t1) && tnum_is_const(t2))
> > > > > +                     return t1.value == t2.value;
> > > > > +             /* const ranges */
> > > > > +             if (umin1 == umax1 && umin2 == umax2)
> > > > > +                     return umin1 == umin2;
> > > >
> > > > I don't follow this logic.
> > > > umin1 == umax1 means that it's a single constant and
> > > > it should have been handled by earlier tnum_is_const check.
> > >
> > > I think you follow the logic, you just think it's redundant. Yes, it's
> > > basically the same as
> > >
> > >           if (tnum_is_const(t1) && tnum_is_const(t2))
> > >                 return t1.value == t2.value;
> > >
> > > but based on ranges. I didn't feel comfortable assuming that if umin1
> > > == umax1 then tnum_is_const(t1) will always be true. At worst we'll
> > > perform one redundant check.
> > >
> > > In short, I don't trust tnum to be as precise as umin/umax and other ranges.
> > >
> > > >
> > > > > +             if (smin1 == smax1 && smin2 == smax2)
> > > > > +                     return umin1 == umin2;
> > > >
> > > > here it's even more confusing. smin == smax -> single const,
> > > > but then compare umin1 with umin2 ?!
> > >
> > > Eagle eyes! Typo, sorry :( it should be `smin1 == smin2`, of course.
> > >
> > > What saves us is reg_bounds_sync(), and if we have umin1 == umax1 then
> > > we'll have also smin1 == smax1 == umin1 == umax1 (and corresponding
> > > relation for second register). But I fixed these typos in both BPF_JEQ
> > > and BPF_JNE branches.
> >
> > Not just 'saves us'. The tnum <-> bounds sync is mandatory.
> > I think we have a test where a function returns [-errno, 0]
> > and then we do if (ret < 0) check. At this point the reg has
> > to be tnum_is_const and zero.
> > So if smin1 == smax1 == umin1 == umax1 it should be tnum_is_const.
> > Otherwise it's a bug in sync logic.
> > I think instead of doing a redundant and confusing check, maybe
> > add a WARN either here or in the sync logic to make sure it's all good?
>
> Ok, let's add it as part of register state sanity checks we discussed
> on another patch. I'll drop the checks and will re-run all the tests to
> make sure we are not missing anything.

So I have this as one more patch for the next revision (pending local
testing). If you hate any part of it, I'd appreciate early feedback :)
I'll wait for Eduard to finish going through the series (probably
tomorrow), and then will post the next version based on all the
feedback I got (and whatever might still come).

Note that the patch below doesn't print the actual register state on a
violation, which is unfortunate. To make that happen I need to refactor
print_verifier_state() so that it can be reused to print register
state here. I've been wanting to move print_verifier_state() into
kernel/bpf/log.c for a while now, and to fix how we print the state of
spilled registers (and maybe a few more small things), so I'll do that
separately and then add register state printing to the sanity check
error.


Author: Andrii Nakryiko <andrii@xxxxxxxxxx>
Date:   Tue Oct 31 13:34:33 2023 -0700

    bpf: add register bounds sanity checks

    Add simple sanity checks that validate well-formed ranges (min <= max)
    across u64, s64, u32, and s32 ranges. Also for cases when the value is
    constant (either 64-bit or 32-bit), we validate that ranges and tnums
    are in agreement.

    These bounds checks are performed at the end of BPF_ALU/BPF_ALU64
    operations, on conditional jumps, and for LDX instructions (where subreg
    zero/sign extension is probably the most important to check). This
    covers most of the interesting cases.

    Also, we validate the sanity of the return register when manually
    adjusting it for some special helpers.

    Signed-off-by: Andrii Nakryiko <andrii@xxxxxxxxxx>

diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index c85d974ba21f..b29c85089bc9 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -2615,6 +2615,46 @@ static void reg_bounds_sync(struct bpf_reg_state *reg)
        __update_reg_bounds(reg);
 }

+static int reg_bounds_sanity_check(struct bpf_verifier_env *env, struct bpf_reg_state *reg)
+{
+       const char *msg;
+
+       if (reg->umin_value > reg->umax_value ||
+           reg->smin_value > reg->smax_value ||
+           reg->u32_min_value > reg->u32_max_value ||
+           reg->s32_min_value > reg->s32_max_value) {
+                   msg = "range bounds violation";
+                   goto out;
+       }
+
+       if (tnum_is_const(reg->var_off)) {
+               u64 uval = reg->var_off.value;
+               s64 sval = (s64)uval;
+
+               if (reg->umin_value != uval || reg->umax_value != uval ||
+                   reg->smin_value != sval || reg->smax_value != sval) {
+                       msg = "const tnum out of sync with range bounds";
+                       goto out;
+               }
+       }
+
+       if (tnum_subreg_is_const(reg->var_off)) {
+               u32 uval32 = tnum_subreg(reg->var_off).value;
+               s32 sval32 = (s32)uval32;
+
+               if (reg->u32_min_value != uval32 || reg->u32_max_value != uval32 ||
+                   reg->s32_min_value != sval32 || reg->s32_max_value != sval32) {
+                       msg = "const tnum (subreg) out of sync with range bounds";
+                       goto out;
+               }
+       }
+
+       return 0;
+out:
+       verbose(env, "%s\n", msg);
+       return -EFAULT;
+}
+
 static bool __reg32_bound_s64(s32 a)
 {
        return a >= 0 && a <= S32_MAX;
@@ -9928,14 +9968,15 @@ static int prepare_func_exit(struct bpf_verifier_env *env, int *insn_idx)
        return 0;
 }

-static void do_refine_retval_range(struct bpf_reg_state *regs, int ret_type,
-                                  int func_id,
-                                  struct bpf_call_arg_meta *meta)
+static int do_refine_retval_range(struct bpf_verifier_env *env,
+                                 struct bpf_reg_state *regs, int ret_type,
+                                 int func_id,
+                                 struct bpf_call_arg_meta *meta)
 {
        struct bpf_reg_state *ret_reg = &regs[BPF_REG_0];

        if (ret_type != RET_INTEGER)
-               return;
+               return 0;

        switch (func_id) {
        case BPF_FUNC_get_stack:
@@ -9961,6 +10002,8 @@ static void do_refine_retval_range(struct bpf_reg_state *regs, int ret_type,
                reg_bounds_sync(ret_reg);
                break;
        }
+
+       return reg_bounds_sanity_check(env, ret_reg);
 }

 static int
@@ -10612,7 +10655,9 @@ static int check_helper_call(struct bpf_verifier_env *env, struct bpf_insn *insn
                regs[BPF_REG_0].ref_obj_id = id;
        }

-       do_refine_retval_range(regs, fn->ret_type, func_id, &meta);
+       err = do_refine_retval_range(env, regs, fn->ret_type, func_id, &meta);
+       if (err)
+               return err;

        err = check_map_func_compatibility(env, meta.map_ptr, func_id);
        if (err)
@@ -14079,13 +14124,12 @@ static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)

                /* check dest operand */
                err = check_reg_arg(env, insn->dst_reg, DST_OP_NO_MARK);
+               err = err ?: adjust_reg_min_max_vals(env, insn);
                if (err)
                        return err;
-
-               return adjust_reg_min_max_vals(env, insn);
        }

-       return 0;
+       return reg_bounds_sanity_check(env, &regs[insn->dst_reg]);
 }

 static void find_good_pkt_pointers(struct bpf_verifier_state *vstate,
@@ -14600,18 +14644,21 @@ static void regs_refine_cond_op(struct bpf_reg_state *reg1, struct bpf_reg_state
  * Technically we can do similar adjustments for pointers to the same object,
  * but we don't support that right now.
  */
-static void reg_set_min_max(struct bpf_reg_state *true_reg1,
-                           struct bpf_reg_state *true_reg2,
-                           struct bpf_reg_state *false_reg1,
-                           struct bpf_reg_state *false_reg2,
-                           u8 opcode, bool is_jmp32)
+static int reg_set_min_max(struct bpf_verifier_env *env,
+                          struct bpf_reg_state *true_reg1,
+                          struct bpf_reg_state *true_reg2,
+                          struct bpf_reg_state *false_reg1,
+                          struct bpf_reg_state *false_reg2,
+                          u8 opcode, bool is_jmp32)
 {
+       int err;
+
        /* If either register is a pointer, we can't learn anything about its
         * variable offset from the compare (unless they were a pointer into
         * the same object, but we don't bother with that).
         */
        if (false_reg1->type != SCALAR_VALUE || false_reg2->type != SCALAR_VALUE)
-               return;
+               return 0;

        /* fallthrough (FALSE) branch */
        regs_refine_cond_op(false_reg1, false_reg2, rev_opcode(opcode), is_jmp32);
@@ -14622,6 +14669,12 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg1,
        regs_refine_cond_op(true_reg1, true_reg2, opcode, is_jmp32);
        reg_bounds_sync(true_reg1);
        reg_bounds_sync(true_reg2);
+
+       err = reg_bounds_sanity_check(env, true_reg1);
+       err = err ?: reg_bounds_sanity_check(env, true_reg2);
+       err = err ?: reg_bounds_sanity_check(env, false_reg1);
+       err = err ?: reg_bounds_sanity_check(env, false_reg2);
+       return err;
 }

 static void mark_ptr_or_null_reg(struct bpf_func_state *state,
@@ -14915,15 +14968,20 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
        other_branch_regs = other_branch->frame[other_branch->curframe]->regs;

        if (BPF_SRC(insn->code) == BPF_X) {
-               reg_set_min_max(&other_branch_regs[insn->dst_reg],
-                               &other_branch_regs[insn->src_reg],
-                               dst_reg, src_reg, opcode, is_jmp32);
+               err = reg_set_min_max(env,
+                                     &other_branch_regs[insn->dst_reg],
+                                     &other_branch_regs[insn->src_reg],
+                                     dst_reg, src_reg, opcode, is_jmp32);
        } else /* BPF_SRC(insn->code) == BPF_K */ {
-               reg_set_min_max(&other_branch_regs[insn->dst_reg],
-                               src_reg /* fake one */,
-                               dst_reg, src_reg /* same fake one */,
-                               opcode, is_jmp32);
+               err = reg_set_min_max(env,
+                                     &other_branch_regs[insn->dst_reg],
+                                     src_reg /* fake one */,
+                                     dst_reg, src_reg /* same fake one */,
+                                     opcode, is_jmp32);
        }
+       if (err)
+               return err;
+
        if (BPF_SRC(insn->code) == BPF_X &&
            src_reg->type == SCALAR_VALUE && src_reg->id &&
            !WARN_ON_ONCE(src_reg->id != other_branch_regs[insn->src_reg].id)) {
@@ -17426,10 +17484,8 @@ static int do_check(struct bpf_verifier_env *env)
                                               insn->off, BPF_SIZE(insn->code),
                                               BPF_READ, insn->dst_reg, false,
                                               BPF_MODE(insn->code) == BPF_MEMSX);
-                       if (err)
-                               return err;
-
-                       err = save_aux_ptr_type(env, src_reg_type, true);
+                       err = err ?: save_aux_ptr_type(env, src_reg_type, true);
+                       err = err ?: reg_bounds_sanity_check(env, &regs[insn->dst_reg]);
                        if (err)
                                return err;
                } else if (class == BPF_STX) {




