Re: [RFC bpf-next v2 2/9] bpf: no_caller_saved_registers attribute for helper calls

On Tue, 2024-07-09 at 16:42 -0700, Andrii Nakryiko wrote:

[...]

> > diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
> > index 2b54e25d2364..735ae0901b3d 100644
> > --- a/include/linux/bpf_verifier.h
> > +++ b/include/linux/bpf_verifier.h
> > @@ -585,6 +585,15 @@ struct bpf_insn_aux_data {
> >          * accepts callback function as a parameter.
> >          */
> >         bool calls_callback;
> > +       /* true if STX or LDX instruction is a part of a spill/fill
> > +        * pattern for a no_caller_saved_registers call.
> > +        */
> > +       u8 nocsr_pattern:1;
> > +       /* for CALL instructions, a number of spill/fill pairs in the
> > +        * no_caller_saved_registers pattern.
> > +        */
> > +       u8 nocsr_spills_num:3;
> 
> despite bitfields this will extend bpf_insn_aux_data by 8 bytes. there
> are 2 bytes of padding after alu_state, let's put this there.
> 
> And let's not add bitfields unless absolutely necessary (this can be
> always done later).

Will remove the bitfields and move the fields.

> 
> > +
> >  };
> > 
> >  #define MAX_USED_MAPS 64 /* max number of maps accessed by one eBPF program */
> > @@ -641,6 +650,11 @@ struct bpf_subprog_info {
> >         u32 linfo_idx; /* The idx to the main_prog->aux->linfo */
> >         u16 stack_depth; /* max. stack depth used by this function */
> >         u16 stack_extra;
> > +       /* stack depth after which slots reserved for
> > +        * no_caller_saved_registers spills/fills start,
> > +        * value <= nocsr_stack_off belongs to the spill/fill area.
> 
> are you sure about <= (not <), it seems like nocsr_stack_off is
> exclusive right bound for nocsr stack region (it would be good to call
> this out explicitly here)

Right, it should be '<', my bad, will update the comment.

> 
> > +        */
> > +       s16 nocsr_stack_off;
> >         bool has_tail_call: 1;
> >         bool tail_call_reachable: 1;
> >         bool has_ld_abs: 1;
> > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> > index 4869f1fb0a42..d16a249b59ad 100644
> > --- a/kernel/bpf/verifier.c
> > +++ b/kernel/bpf/verifier.c
> > @@ -2471,16 +2471,37 @@ static int cmp_subprogs(const void *a, const void *b)
> >                ((struct bpf_subprog_info *)b)->start;
> >  }
> > 
> > -static int find_subprog(struct bpf_verifier_env *env, int off)
> > +/* Find subprogram that contains instruction at 'off' */
> > +static int find_containing_subprog(struct bpf_verifier_env *env, int off)
> >  {
> > -       struct bpf_subprog_info *p;
> > +       struct bpf_subprog_info *vals = env->subprog_info;
> > +       int l, r, m;
> > 
> > -       p = bsearch(&off, env->subprog_info, env->subprog_cnt,
> > -                   sizeof(env->subprog_info[0]), cmp_subprogs);
> > -       if (!p)
> > +       if (off >= env->prog->len || off < 0 || env->subprog_cnt == 0)
> >                 return -ENOENT;
> > -       return p - env->subprog_info;
> > 
> > +       l = 0;
> > +       m = 0;
> 
> no need to initialize m

Ok

> 
> > +       r = env->subprog_cnt - 1;
> > +       while (l < r) {
> > +               m = l + (r - l + 1) / 2;
> > +               if (vals[m].start <= off)
> > +                       l = m;
> > +               else
> > +                       r = m - 1;
> > +       }
> > +       return l;
> > +}
> 
> I love it, looks great :)
>

Agree

[...]

> > @@ -4501,6 +4522,23 @@ static int get_reg_width(struct bpf_reg_state *reg)
> >         return fls64(reg->umax_value);
> >  }
> > 
> > +/* See comment for mark_nocsr_pattern_for_call() */
> > +static void check_nocsr_stack_contract(struct bpf_verifier_env *env, struct bpf_func_state *state,
> > +                                      int insn_idx, int off)
> > +{
> > +       struct bpf_subprog_info *subprog = &env->subprog_info[state->subprogno];
> > +       struct bpf_insn_aux_data *aux = &env->insn_aux_data[insn_idx];
> > +
> > +       if (subprog->nocsr_stack_off <= off || aux->nocsr_pattern)
> > +               return;
> 
> can helper call instruction go through this check? E.g., if we do
> bpf_probe_read_kernel() into stack slot, where do we check that that
> slot is not overlapping with nocsr spill/fill region?

In check_helper_call() we do check_mem_access(), which eventually calls
one of the check_stack_{read,write}_{fixed,varying}_off().
The .access_size should be set for bpf_probe_read_kernel(),
because its argument base type is ARG_PTR_TO_MEM.
I will add a test case to double-check this.

[...]

> > @@ -15951,6 +15993,206 @@ static int visit_func_call_insn(int t, struct bpf_insn *insns,
> >         return ret;
> >  }
> > 
> > +/* Bitmask with 1s for all caller saved registers */
> > +#define ALL_CALLER_SAVED_REGS ((1u << CALLER_SAVED_REGS) - 1)
> > +
> > +/* Return a bitmask specifying which caller saved registers are
> > + * modified by a call to a helper.
> > + * (Either as a return value or as scratch registers).
> > + *
> > + * For normal helpers registers R0-R5 are scratched.
> > + * For helpers marked as no_csr:
> > + * - scratch R0 if function is non-void;
> > + * - scratch R1-R5 if corresponding parameter type is set
> > + *   in the function prototype.
> > + */
> > +static u8 get_helper_reg_mask(const struct bpf_func_proto *fn)
> 
> suggestion: to make this less confusing, here we are returning a mask
> of registers that are clobbered by the helper, is that right? so
> get_helper_clobber_mask() maybe?

get_helper_clobber_mask() is a good name, will change.

[...]

> > +/* If 'insn' is a call that follows no_caller_saved_registers contract
> > + * and called function is inlined by current jit or verifier,
> > + * return a mask with 1s corresponding to registers that are scratched
> > + * by this call (depends on return type and number of return parameters).
> 
> return parameters? was it supposed to be "function parameters/arguments"?

My bad.

> 
> > + * Otherwise return ALL_CALLER_SAVED_REGS mask.
> > + */
> > +static u32 call_csr_mask(struct bpf_verifier_env *env, struct bpf_insn *insn)
> 
> you use u8 for get_helper_reg_mask() and u32 here, why not keep them consistent?

Ok

> similar to the naming nit above, I think we should be a bit more
> explicit with what "mask" actually means. Is this also clobber mask?

I mean, there is a comment right above the function.
This function returns a mask of caller-saved registers (csr).
I'll make the name more explicit.

> 
> > +{
> > +       const struct bpf_func_proto *fn;
> > +
> > +       if (bpf_helper_call(insn) &&
> > +           (verifier_inlines_helper_call(env, insn->imm) || bpf_jit_inlines_helper_call(insn->imm)) &&
> > +           get_helper_proto(env, insn->imm, &fn) == 0 &&
> > +           fn->allow_nocsr)
> > +               return ~get_helper_reg_mask(fn);
> 
> hm... I'm a bit confused why we do a negation here? aren't we working
> with clobbering mask... I'll keep reading for now.

Please read the comment before the function.

> 
> > +
> > +       return ALL_CALLER_SAVED_REGS;
> > +}

[...]

> > +static void mark_nocsr_pattern_for_call(struct bpf_verifier_env *env, int t)
> 
> t is insn_idx, let's not carry over old crufty check_cfg naming

Ok

> 
> > +{
> > +       struct bpf_insn *insns = env->prog->insnsi, *stx, *ldx;
> > +       struct bpf_subprog_info *subprog;
> > +       u32 csr_mask = call_csr_mask(env, &insns[t]);
> > +       u32 reg_mask = ~csr_mask | ~ALL_CALLER_SAVED_REGS;
> 
> tbh, I'm lost with all these bitmask and their inversions...
> call_csr_mask()'s result is basically always used inverted, so why not
> return inverted mask in the first place?

The mask is initialized as the set of all registers preserved by this call.
Those that are not in the mask need a spill/fill pair.
I'll toss things around to make this more clear
(naming, comments, maybe move the '| ~ALL_CALLER_SAVED_REGS' into call_csr_mask()).

> 
> > +       int s, i;
> > +       s16 off;
> > +
> > +       if (csr_mask == ALL_CALLER_SAVED_REGS)
> > +               return;
> > +
> > +       for (i = 1, off = 0; i <= ARRAY_SIZE(caller_saved); ++i, off += BPF_REG_SIZE) {
> > +               if (t - i < 0 || t + i >= env->prog->len)
> > +                       break;
> > +               stx = &insns[t - i];
> > +               ldx = &insns[t + i];
> > +               if (off == 0) {
> > +                       off = stx->off;
> > +                       if (off % BPF_REG_SIZE != 0)
> > +                               break;
> 
> kind of ugly that we assume stx before we actually checked that it's
> STX?... maybe split humongous if below into instruction checking
> (with code and src_reg) and then off checking separately?

Don't see anything ugly about this, tbh.
Can split the 'if' statement if you think it's hard to read.

> 
> > +               }
> > +               if (/* *(u64 *)(r10 - off) = r[0-5]? */
> > +                   stx->code != (BPF_STX | BPF_MEM | BPF_DW) ||
> > +                   stx->dst_reg != BPF_REG_10 ||
> > +                   /* r[0-5] = *(u64 *)(r10 - off)? */
> > +                   ldx->code != (BPF_LDX | BPF_MEM | BPF_DW) ||
> > +                   ldx->src_reg != BPF_REG_10 ||
> > +                   /* check spill/fill for the same reg and offset */
> > +                   stx->src_reg != ldx->dst_reg ||
> > +                   stx->off != ldx->off ||
> > +                   stx->off != off ||
> > +                   /* this should be a previously unseen register */
> > +                   BIT(stx->src_reg) & reg_mask)
> 
> () around & operation?

No need, & has higher priority than ||.
You can check the AST using
https://tree-sitter.github.io/tree-sitter/playground.

> 
> > +                       break;
> > +               reg_mask |= BIT(stx->src_reg);
> > +               env->insn_aux_data[t - i].nocsr_pattern = 1;
> > +               env->insn_aux_data[t + i].nocsr_pattern = 1;
> > +       }
> > +       if (i == 1)
> > +               return;
> > +       env->insn_aux_data[t].nocsr_spills_num = i - 1;
> > +       s = find_containing_subprog(env, t);
> > +       /* can't happen */
> 
> then don't check ;) we leave the state partially set for CSR but not
> quite. We either should error out completely or just assume
> correctness of find_containing_subprog, IMO

Ok

> 
> > +       if (WARN_ON_ONCE(s < 0))
> > +               return;
> > +       subprog = &env->subprog_info[s];
> > +       subprog->nocsr_stack_off = min(subprog->nocsr_stack_off, off);
> 
> should this be max()? offsets are negative, right? so if nocsr uses -8
> and -16 as in the example, entire [-16, 0) region is nocsr region

This should be min() exactly because stack offsets are negative.
For the example above, 'off' is initialized to -16 and then
incremented by +8, giving a final value of -8.
And I need to select the minimal value used across several patterns.

> 
> > +}

[...]

> > @@ -20119,6 +20361,48 @@ static int do_misc_fixups(struct bpf_verifier_env *env)
> >                         goto next_insn;
> >                 if (insn->src_reg == BPF_PSEUDO_CALL)
> >                         goto next_insn;
> > +               /* Remove unnecessary spill/fill pairs, members of nocsr pattern */
> > +               if (env->insn_aux_data[i + delta].nocsr_spills_num > 0) {
> > +                       u32 j, spills_num = env->insn_aux_data[i + delta].nocsr_spills_num;
> > +                       int err;
> > +
> > +                       /* don't apply this on a second visit */
> > +                       env->insn_aux_data[i + delta].nocsr_spills_num = 0;
> > +
> > +                       /* check if spill/fill stack access is in expected offset range */
> > +                       for (j = 1; j <= spills_num; ++j) {
> > +                               if ((insn - j)->off >= subprogs[cur_subprog].nocsr_stack_off ||
> > +                                   (insn + j)->off >= subprogs[cur_subprog].nocsr_stack_off) {
> > +                                       /* do a second visit of this instruction,
> > +                                        * so that verifier can inline it
> > +                                        */
> > +                                       i -= 1;
> > +                                       insn -= 1;
> > +                                       goto next_insn;
> > +                               }
> > +                       }
> 
> I don't get this loop, can you elaborate? Why are we double-checking
> anything here, didn't we do this already?

We established probable patterns and a probable minimal offset.
Over the course of program verification we might have invalidated the
.nocsr_stack_off for a particular subprogram, hence the need for this check.

> 
> > +
> > +                       /* apply the rewrite:
> > +                        *   *(u64 *)(r10 - X) = rY ; num-times
> > +                        *   call()                               -> call()
> > +                        *   rY = *(u64 *)(r10 - X) ; num-times
> > +                        */
> > +                       err = verifier_remove_insns(env, i + delta - spills_num, spills_num);
> > +                       if (err)
> > +                               return err;
> > +                       err = verifier_remove_insns(env, i + delta - spills_num + 1, spills_num);
> > +                       if (err)
> > +                               return err;
> 
> why not a single bpf_patch_insn_data()?

bpf_patch_insn_data() assumes that one instruction has to be replaced with many.
Here I need to replace many instructions with a single instruction.
I'd prefer not to tweak bpf_patch_insn_data() for this patch-set.

On the other hand, the do_jit() for x86 removes NOPs (BPF_JA +0),
so I can probably replace spills/fills with NOPs here instead of
calling bpf_patch_insn_data() or bpf_remove_insns().

> > +
> > +                       i += spills_num - 1;
> > +                       /*   ^            ^   do a second visit of this instruction,
> > +                        *   |            '-- so that verifier can inline it
> > +                        *   '--------------- jump over deleted fills
> > +                        */
> > +                       delta -= 2 * spills_num;
> > +                       insn = env->prog->insnsi + i + delta;
> > +                       goto next_insn;
> 
> why not adjust the state and just fall through, what goto next_insn
> does that we can't (and next instruction is misleading, so I'd rather
> fix up and move forward)

I don't like this. The fall-through makes control flow more convoluted.
To understand what would happen next:
- with goto next_insn we just start over;
- with fall-through we need to think about the position of this particular
  'if' statement within the loop.

> 
> > +               }
> >                 if (insn->src_reg == BPF_PSEUDO_KFUNC_CALL) {
> >                         ret = fixup_kfunc_call(env, insn, insn_buf, i + delta, &cnt);
> >                         if (ret)

[...]




