Re: [PATCH bpf-next v3 1/4] bpf: use scalar ids in mark_chain_precision()

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



On Wed, 2023-06-07 at 14:40 -0700, Andrii Nakryiko wrote:
> On Tue, Jun 6, 2023 at 3:24 PM Eduard Zingerman <eddyz87@xxxxxxxxx> wrote:
> > 
> > Change mark_chain_precision() to track precision in situations
> > like below:
> > 
> >     r2 = unknown value
> >     ...
> >   --- state #0 ---
> >     ...
> >     r1 = r2                 // r1 and r2 now share the same ID
> >     ...
> >   --- state #1 {r1.id = A, r2.id = A} ---
> >     ...
> >     if (r2 > 10) goto exit; // find_equal_scalars() assigns range to r1
> >     ...
> >   --- state #2 {r1.id = A, r2.id = A} ---
> >     r3 = r10
> >     r3 += r1                // need to mark both r1 and r2
> > 
> > At the beginning of the processing of each state, ensure that if a
> > register with a scalar ID is marked as precise, all registers sharing
> > this ID are also marked as precise.
> > 
> > This property would be used by a follow-up change in regsafe().
> > 
> > Signed-off-by: Eduard Zingerman <eddyz87@xxxxxxxxx>
> > ---
> >  include/linux/bpf_verifier.h                  |  10 +-
> >  kernel/bpf/verifier.c                         | 114 ++++++++++++++++++
> >  .../testing/selftests/bpf/verifier/precise.c  |   8 +-
> >  3 files changed, 127 insertions(+), 5 deletions(-)
> > 
> > diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
> > index ee4cc7471ed9..3f9856baa542 100644
> > --- a/include/linux/bpf_verifier.h
> > +++ b/include/linux/bpf_verifier.h
> > @@ -559,6 +559,11 @@ struct backtrack_state {
> >         u64 stack_masks[MAX_CALL_FRAMES];
> >  };
> > 
> > +struct reg_id_scratch {
> > +       u32 count;
> > +       u32 ids[BPF_ID_MAP_SIZE];
> > +};
> > +
> >  /* single container for all structs
> >   * one verifier_env per bpf_check() call
> >   */
> > @@ -590,7 +595,10 @@ struct bpf_verifier_env {
> >         const struct bpf_line_info *prev_linfo;
> >         struct bpf_verifier_log log;
> >         struct bpf_subprog_info subprog_info[BPF_MAX_SUBPROGS + 1];
> > -       struct bpf_id_pair idmap_scratch[BPF_ID_MAP_SIZE];
> > +       union {
> > +               struct bpf_id_pair idmap_scratch[BPF_ID_MAP_SIZE];
> > +               struct reg_id_scratch precise_ids_scratch;
> 
> naming nit: "ids_scratch" or "idset_scratch" to stay in line with
> "idmap_scratch"?

Makes sense, will change to "idset_scratch".

> 
> > +       };
> >         struct {
> >                 int *insn_state;
> >                 int *insn_stack;
> > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> > index d117deb03806..2aa60b73f1b5 100644
> > --- a/kernel/bpf/verifier.c
> > +++ b/kernel/bpf/verifier.c
> > @@ -3779,6 +3779,96 @@ static void mark_all_scalars_imprecise(struct bpf_verifier_env *env, struct bpf_
> >         }
> >  }
> > 
> > +static inline bool reg_id_scratch_contains(struct reg_id_scratch *s, u32 id)
> > +{
> > +       u32 i;
> > +
> > +       for (i = 0; i < s->count; ++i)
> > +               if (s->ids[i] == id)
> > +                       return true;
> > +
> > +       return false;
> > +}
> > +
> > +static inline int reg_id_scratch_push(struct reg_id_scratch *s, u32 id)
> > +{
> > +       if (WARN_ON_ONCE(s->count >= ARRAY_SIZE(s->ids)))
> > +               return -1;
> > +       s->ids[s->count++] = id;
> 
> this will allow duplicated IDs to be added? Was it done in the name of speed?

tbh, it's an artifact of the bsearch/sort migration of the series.
While doing test veristat runs I found that the maximal value of s->count
is 5, so it looks like it would be fine the way it is now, and it would also
be fine if a linear scan is added to avoid duplicate ids. I don't have a preference.

> 
> > +       WARN_ONCE(s->count > 64,
> > +                 "reg_id_scratch.count is unreasonably large (%d)", s->count);
> 
> do we need this one? Especially that it's not _ONCE variant? Maybe the
> first WARN_ON_ONCE is enough?

We make an assumption that linear scans of this array are ok, and it
would be scanned often. I'd like to have some indication if this
assumption is broken. The s->ids array is large: (10 regs + 64 spills) * 8 frames.
If you think that this logging is not necessary, I'll remove it.

> 
> > +       return 0;
> > +}
> > +
> > +static inline void reg_id_scratch_reset(struct reg_id_scratch *s)
> 
> we probably don't need "inline" for all these helpers?

Ok, will remove "inline".

> 
> > +{
> > +       s->count = 0;
> > +}
> > +
> > +/* Collect a set of IDs for all registers currently marked as precise in env->bt.
> > + * Mark all registers with these IDs as precise.
> > + */
> > +static void mark_precise_scalar_ids(struct bpf_verifier_env *env, struct bpf_verifier_state *st)
> > +{
> > +       struct reg_id_scratch *precise_ids = &env->precise_ids_scratch;
> > +       struct backtrack_state *bt = &env->bt;
> > +       struct bpf_func_state *func;
> > +       struct bpf_reg_state *reg;
> > +       DECLARE_BITMAP(mask, 64);
> > +       int i, fr;
> > +
> > +       reg_id_scratch_reset(precise_ids);
> > +
> > +       for (fr = bt->frame; fr >= 0; fr--) {
> > +               func = st->frame[fr];
> > +
> > +               bitmap_from_u64(mask, bt_frame_reg_mask(bt, fr));
> > +               for_each_set_bit(i, mask, 32) {
> > +                       reg = &func->regs[i];
> > +                       if (!reg->id || reg->type != SCALAR_VALUE)
> > +                               continue;
> > +                       if (reg_id_scratch_push(precise_ids, reg->id))
> > +                               return;
> > +               }
> > +
> > +               bitmap_from_u64(mask, bt_frame_stack_mask(bt, fr));
> > +               for_each_set_bit(i, mask, 64) {
> > +                       if (i >= func->allocated_stack / BPF_REG_SIZE)
> > +                               break;
> > +                       if (!is_spilled_scalar_reg(&func->stack[i]))
> > +                               continue;
> > +                       reg = &func->stack[i].spilled_ptr;
> > +                       if (!reg->id || reg->type != SCALAR_VALUE)
> 
> is_spilled_scalar_reg() already ensures reg->type is SCALAR_VALUE

Yes, my bad.

> 
> > +                               continue;
> > +                       if (reg_id_scratch_push(precise_ids, reg->id))
> > +                               return;
> 
> if push fails (due to overflow of id set), shouldn't we propagate
> error back and fallback to mark_all_precise?

In theory this push should never fail, as we pre-allocate enough slots
in the scratch. I'll propagate the error to __mark_chain_precision() and
return -EFAULT from there.

> 
> 
> > +               }
> > +       }
> > +
> > +       for (fr = 0; fr <= st->curframe; ++fr) {
> > +               func = st->frame[fr];
> > +
> > +               for (i = BPF_REG_0; i < BPF_REG_10; ++i) {
> > +                       reg = &func->regs[i];
> > +                       if (!reg->id)
> > +                               continue;
> > +                       if (!reg_id_scratch_contains(precise_ids, reg->id))
> > +                               continue;
> > +                       bt_set_frame_reg(bt, fr, i);
> > +               }
> > +               for (i = 0; i < func->allocated_stack / BPF_REG_SIZE; ++i) {
> > +                       if (!is_spilled_scalar_reg(&func->stack[i]))
> > +                               continue;
> > +                       reg = &func->stack[i].spilled_ptr;
> > +                       if (!reg->id)
> > +                               continue;
> > +                       if (!reg_id_scratch_contains(precise_ids, reg->id))
> > +                               continue;
> > +                       bt_set_frame_slot(bt, fr, i);
> > +               }
> > +       }
> > +}
> > +
> >  /*
> >   * __mark_chain_precision() backtracks BPF program instruction sequence and
> >   * chain of verifier states making sure that register *regno* (if regno >= 0)
> > @@ -3910,6 +4000,30 @@ static int __mark_chain_precision(struct bpf_verifier_env *env, int regno)
> >                                 bt->frame, last_idx, first_idx, subseq_idx);
> >                 }
> > 
> > +               /* If some register with scalar ID is marked as precise,
> > +                * make sure that all registers sharing this ID are also precise.
> > +                * This is needed to estimate effect of find_equal_scalars().
> > +                * Do this at the last instruction of each state,
> > +                * bpf_reg_state::id fields are valid for these instructions.
> > +                *
> > +                * Allows to track precision in situation like below:
> > +                *
> > +                *     r2 = unknown value
> > +                *     ...
> > +                *   --- state #0 ---
> > +                *     ...
> > +                *     r1 = r2                 // r1 and r2 now share the same ID
> > +                *     ...
> > +                *   --- state #1 {r1.id = A, r2.id = A} ---
> > +                *     ...
> > +                *     if (r2 > 10) goto exit; // find_equal_scalars() assigns range to r1
> > +                *     ...
> > +                *   --- state #2 {r1.id = A, r2.id = A} ---
> > +                *     r3 = r10
> > +                *     r3 += r1                // need to mark both r1 and r2
> > +                */
> > +               mark_precise_scalar_ids(env, st);
> > +
> >                 if (last_idx < 0) {
> >                         /* we are at the entry into subprog, which
> >                          * is expected for global funcs, but only if
> > diff --git a/tools/testing/selftests/bpf/verifier/precise.c b/tools/testing/selftests/bpf/verifier/precise.c
> > index b8c0aae8e7ec..99272bb890da 100644
> > --- a/tools/testing/selftests/bpf/verifier/precise.c
> > +++ b/tools/testing/selftests/bpf/verifier/precise.c
> > @@ -46,7 +46,7 @@
> >         mark_precise: frame0: regs=r2 stack= before 20\
> >         mark_precise: frame0: parent state regs=r2 stack=:\
> >         mark_precise: frame0: last_idx 19 first_idx 10\
> > -       mark_precise: frame0: regs=r2 stack= before 19\
> > +       mark_precise: frame0: regs=r2,r9 stack= before 19\
> >         mark_precise: frame0: regs=r9 stack= before 18\
> >         mark_precise: frame0: regs=r8,r9 stack= before 17\
> >         mark_precise: frame0: regs=r0,r9 stack= before 15\
> > @@ -106,10 +106,10 @@
> >         mark_precise: frame0: regs=r2 stack= before 22\
> >         mark_precise: frame0: parent state regs=r2 stack=:\
> >         mark_precise: frame0: last_idx 20 first_idx 20\
> > -       mark_precise: frame0: regs=r2 stack= before 20\
> > -       mark_precise: frame0: parent state regs=r2 stack=:\
> > +       mark_precise: frame0: regs=r2,r9 stack= before 20\
> > +       mark_precise: frame0: parent state regs=r2,r9 stack=:\
> >         mark_precise: frame0: last_idx 19 first_idx 17\
> > -       mark_precise: frame0: regs=r2 stack= before 19\
> > +       mark_precise: frame0: regs=r2,r9 stack= before 19\
> >         mark_precise: frame0: regs=r9 stack= before 18\
> >         mark_precise: frame0: regs=r8,r9 stack= before 17\
> >         mark_precise: frame0: parent state regs= stack=:",
> > --
> > 2.40.1
> > 






[Index of Archives]     [Linux Samsung SoC]     [Linux Rockchip SoC]     [Linux Actions SoC]     [Linux for Synopsys ARC Processors]     [Linux NFS]     [Linux NILFS]     [Linux USB Devel]     [Video for Linux]     [Linux Audio Users]     [Yosemite News]     [Linux Kernel]     [Linux SCSI]


  Powered by Linux