Re: [RFC bpf-next v1 2/8] bpf: no_caller_saved_registers attribute for helper calls

On Sat, Jun 29, 2024 at 2:48 AM Eduard Zingerman <eddyz87@xxxxxxxxx> wrote:
>
> GCC and LLVM define a no_caller_saved_registers function attribute.
> This attribute means that the function scratches only some of
> the caller-saved registers defined by the ABI.
> For BPF the set of such registers could be defined as follows:
> - R0 is scratched only if the function is non-void;
> - R1-R5 are scratched only if the corresponding parameter type is
>   defined in the function prototype.
>
> This commit introduces the flag bpf_func_proto->nocsr.
> If this flag is set for some helper function, the verifier assumes
> that it follows the no_caller_saved_registers calling convention.
>
> The contract between the kernel and clang makes it possible to use
> such functions while maintaining backwards compatibility with old
> kernels that don't understand no_caller_saved_registers calls
> (nocsr for short):
>
> - clang generates a simple pattern for nocsr calls, e.g.:
>
>     r1 = 1;
>     r2 = 2;
>     *(u64 *)(r10 - 8)  = r1;
>     *(u64 *)(r10 - 16) = r2;
>     call %[to_be_inlined_by_jit]

"inline_by_jit" is misleading, it can be inlined by BPF verifier using
BPF instructions, not just by BPF JIT

>     r2 = *(u64 *)(r10 - 16);
>     r1 = *(u64 *)(r10 - 8);
>     r0 = r1;
>     r0 += r2;
>     exit;
>
> - the kernel removes the unnecessary spills and fills if the called
>   function is inlined by the current JIT (under the assumption that
>   the patch inserted by the JIT honors the nocsr contract, e.g. does
>   not scratch r3-r5 for the example above); the code above would be
>   transformed to:
>
>     r1 = 1;
>     r2 = 2;
>     call %[to_be_inlined_by_jit]
>     r0 = r1;
>     r0 += r2;
>     exit;
>
> Technically, the transformation is split into the following phases:
> - during check_cfg(), update_nocsr_pattern_marks() is used to find
>   potential patterns;
> - upon each stack read or write access,
>   check_nocsr_stack_contract() is used to verify that stack offsets
>   presumably reserved for nocsr patterns are accessed only from
>   those patterns;
> - remove_nocsr_spills_fills(), called from bpf_check(),
>   applies the rewrite for valid patterns.
>
> See comment in match_and_mark_nocsr_pattern() for more details.
>
> Suggested-by: Alexei Starovoitov <alexei.starovoitov@xxxxxxxxx>
> Signed-off-by: Eduard Zingerman <eddyz87@xxxxxxxxx>
> ---
>  include/linux/bpf.h          |   6 +
>  include/linux/bpf_verifier.h |   9 ++
>  kernel/bpf/verifier.c        | 300 ++++++++++++++++++++++++++++++++++-
>  3 files changed, 307 insertions(+), 8 deletions(-)
>

[...]

> -static int find_subprog(struct bpf_verifier_env *env, int off)
> +/* Find subprogram that contains instruction at 'off' */
> +static int find_containing_subprog(struct bpf_verifier_env *env, int off)
>  {
> -       struct bpf_subprog_info *p;
> +       struct bpf_subprog_info *vals = env->subprog_info;
> +       int high = env->subprog_cnt - 1;
> +       int low = 0, ret = -ENOENT;
>
> -       p = bsearch(&off, env->subprog_info, env->subprog_cnt,
> -                   sizeof(env->subprog_info[0]), cmp_subprogs);
> -       if (!p)
> +       if (off >= env->prog->len || off < 0)
>                 return -ENOENT;
> -       return p - env->subprog_info;
>
> +       while (low <= high) {
> +               int mid = (low + high)/2;

styling nit: (...) / 2

> +               struct bpf_subprog_info *val = &vals[mid];
> +               int diff = off - val->start;
> +
> +               if (diff < 0) {

tbh, this hurts my brain. Why not write the more human-readable and meaningful

if (off < val->start)

?


> +                       high = mid - 1;
> +               } else {
> +                       low = mid + 1;
> +                       /* remember last time mid.start <= off */
> +                       ret = mid;
> +               }

feel free to ignore, but I find the `ret = mid` bookkeeping a bit
inelegant. See find_linfo in kernel/bpf/log.c for how a
lower_bound-like binary search can be implemented without it (I mean
the pattern where the invariant keeps low or high always satisfying
the condition, and the other bound is adjusted with +1 or -1,
depending on the desired logic).
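
Something in that style, as a rough sketch (not the actual find_linfo
code; it relies on subprog 0 always starting at instruction 0):

    static int find_containing_subprog(struct bpf_verifier_env *env, int off)
    {
            struct bpf_subprog_info *vals = env->subprog_info;
            int low = 0, high = env->subprog_cnt - 1;

            if (off < 0 || off >= env->prog->len)
                    return -ENOENT;
            /* invariant: vals[low].start <= off at all times */
            while (low < high) {
                    int mid = (low + high + 1) / 2;

                    if (vals[mid].start <= off)
                            low = mid;
                    else
                            high = mid - 1;
            }
            return low;
    }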


> +       }
> +       return ret;
> +}
> +
> +/* Find subprogram that starts exactly at 'off' */
> +static int find_subprog(struct bpf_verifier_env *env, int off)
> +{
> +       int idx;
> +
> +       idx = find_containing_subprog(env, off);
> +       if (idx < 0 || env->subprog_info[idx].start != off)
> +               return -ENOENT;
> +       return idx;
>  }
>

[...]

> +static u8 get_helper_reg_mask(const struct bpf_func_proto *fn)
> +{
> +       u8 mask;
> +       int i;
> +
> +       if (!fn->nocsr)
> +               return ALL_CALLER_SAVED_REGS;
> +
> +       mask = 0;
> +       mask |= fn->ret_type == RET_VOID ? 0 : BIT(BPF_REG_0);
> +       for (i = 0; i < ARRAY_SIZE(fn->arg_type); ++i)
> +               mask |= fn->arg_type[i] == ARG_DONTCARE ? 0 : BIT(BPF_REG_1 + i);

again subjective, but

if (fn->ret_type != RET_VOID)
    mask |= BIT(BPF_REG_0);

(and similarly for ARG_DONTCARE)

seems a bit more readable and not that much more verbose
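
i.e., the whole body restated (same logic, just in if-form):

    static u8 get_helper_reg_mask(const struct bpf_func_proto *fn)
    {
            u8 mask = 0;
            int i;

            if (!fn->nocsr)
                    return ALL_CALLER_SAVED_REGS;

            if (fn->ret_type != RET_VOID)
                    mask |= BIT(BPF_REG_0);
            for (i = 0; i < ARRAY_SIZE(fn->arg_type); ++i)
                    if (fn->arg_type[i] != ARG_DONTCARE)
                            mask |= BIT(BPF_REG_1 + i);
            return mask;
    }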

> +       return mask;
> +}
> +
> +/* True if do_misc_fixups() replaces calls to helper number 'imm',
> + * replacement patch is presumed to follow no_caller_saved_registers contract
> + * (see match_and_mark_nocsr_pattern() below).
> + */
> +static bool verifier_inlines_helper_call(struct bpf_verifier_env *env, s32 imm)
> +{

note that there is now also bpf_jit_inlines_helper_call()
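
e.g., a sketch of how the caller's check might account for both
(assuming bpf_jit_inlines_helper_call() keeps a bool-by-imm
signature, and that the JIT-side patch honors the nocsr contract):

    if (verifier_inlines_helper_call(env, insn->imm) ||
        bpf_jit_inlines_helper_call(insn->imm))
            /* treat the call as inlined for nocsr purposes */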


> +       return false;
> +}
> +
> +/* If 'insn' is a call that follows no_caller_saved_registers contract
> + * and called function is inlined by current jit, return a mask with
> + * 1s corresponding to registers that are scratched by this call
> + * (depends on return type and number of return parameters).
> + * Otherwise return ALL_CALLER_SAVED_REGS mask.
> + */
> +static u32 call_csr_mask(struct bpf_verifier_env *env, struct bpf_insn *insn)
> +{
> +       const struct bpf_func_proto *fn;
> +
> +       if (bpf_helper_call(insn) &&
> +           verifier_inlines_helper_call(env, insn->imm) &&

strictly speaking, does nocsr have anything to do with inlining,
though? E.g., if we knew for sure (granted, that's a separate issue)
that the helper implementation doesn't touch extra registers, why
would we need inlining to make use of nocsr?

> +           get_helper_proto(env, insn->imm, &fn) == 0 &&
> +           fn->nocsr)
> +               return ~get_helper_reg_mask(fn);
> +
> +       return ALL_CALLER_SAVED_REGS;
> +}
> +

[...]

> + * For example, it is *not* safe to remove spill/fill below:
> + *
> + *   r1 = 1;
> + *   *(u64 *)(r10 - 8)  = r1;            r1 = 1;
> + *   call %[to_be_inlined_by_jit]  -->   call %[to_be_inlined_by_jit]
> + *   r1 = *(u64 *)(r10 - 8);             r0 = *(u64 *)(r10 - 8);  <---- wrong !!!
> + *   r0 = *(u64 *)(r10 - 8);             r0 += r1;
> + *   r0 += r1;                           exit;
> + *   exit;
> + */
> +static int match_and_mark_nocsr_pattern(struct bpf_verifier_env *env, int t, bool mark)
> +{
> +       struct bpf_insn *insns = env->prog->insnsi, *stx, *ldx;
> +       struct bpf_subprog_info *subprog;
> +       u32 csr_mask = call_csr_mask(env, &insns[t]);
> +       u32 reg_mask = ~csr_mask | ~ALL_CALLER_SAVED_REGS;
> +       int s, i;
> +       s16 off;
> +
> +       if (csr_mask == ALL_CALLER_SAVED_REGS)
> +               return false;

false -> 0 ?

> +
> +       for (i = 1, off = 0; i <= ARRAY_SIZE(caller_saved); ++i, off += BPF_REG_SIZE) {
> +               if (t - i < 0 || t + i >= env->prog->len)
> +                       break;
> +               stx = &insns[t - i];
> +               ldx = &insns[t + i];
> +               if (off == 0) {
> +                       off = stx->off;
> +                       if (off % BPF_REG_SIZE != 0)
> +                               break;

just return here?

> +               }
> +               if (/* *(u64 *)(r10 - off) = r[0-5]? */
> +                   stx->code != (BPF_STX | BPF_MEM | BPF_DW) ||
> +                   stx->dst_reg != BPF_REG_10 ||
> +                   /* r[0-5] = *(u64 *)(r10 - off)? */
> +                   ldx->code != (BPF_LDX | BPF_MEM | BPF_DW) ||
> +                   ldx->src_reg != BPF_REG_10 ||
> +                   /* check spill/fill for the same reg and offset */
> +                   stx->src_reg != ldx->dst_reg ||
> +                   stx->off != ldx->off ||
> +                   stx->off != off ||
> +                   /* this should be a previously unseen register */
> +                   BIT(stx->src_reg) & reg_mask)
> +                       break;
> +               reg_mask |= BIT(stx->src_reg);
> +               if (mark) {
> +                       env->insn_aux_data[t - i].nocsr_pattern = true;
> +                       env->insn_aux_data[t + i].nocsr_pattern = true;
> +               }
> +       }
> +       if (i == 1)
> +               return 0;
> +       if (mark) {
> +               s = find_containing_subprog(env, t);
> +               /* can't happen */
> +               if (WARN_ON_ONCE(s < 0))
> +                       return 0;
> +               subprog = &env->subprog_info[s];
> +               subprog->nocsr_stack_off = min(subprog->nocsr_stack_off, off);
> +       }

why not split the pattern detection from all this marking logic? You
could return "the size of the nocsr pattern", meaning how many
spill/fill pairs surround the call, no? Then all the marking can be
done (if necessary) by the caller.

The question is what to do with a zero pattern (no spills/fills
around a nocsr call; is that a valid case?)
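
E.g., a hypothetical split (the names and the lowest_off out-param
are made up for illustration):

    /* pure detection: number of spill/fill pairs around call at 't',
     * 0 if the call doesn't match the pattern
     */
    static int match_nocsr_pattern(struct bpf_verifier_env *env, int t,
                                   s16 *lowest_off);

    static void update_nocsr_pattern_marks(struct bpf_verifier_env *env, int t)
    {
            int i, s, pairs;
            s16 off;

            pairs = match_nocsr_pattern(env, t, &off);
            if (pairs <= 0)
                    return;
            for (i = 1; i <= pairs; ++i) {
                    env->insn_aux_data[t - i].nocsr_pattern = true;
                    env->insn_aux_data[t + i].nocsr_pattern = true;
            }
            s = find_containing_subprog(env, t);
            if (WARN_ON_ONCE(s < 0))
                    return;
            env->subprog_info[s].nocsr_stack_off =
                    min(env->subprog_info[s].nocsr_stack_off, off);
    }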

> +       return i - 1;
> +}
> +
> +/* If instruction 't' is a nocsr call surrounded by spill/fill pairs,
> + * update env->subprog_info[_]->nocsr_stack_off and
> + * env->insn_aux_data[_].nocsr_pattern fields.
> + */
> +static void update_nocsr_pattern_marks(struct bpf_verifier_env *env, int t)
> +{
> +       match_and_mark_nocsr_pattern(env, t, true);
> +}
> +
> +/* If instruction 't' is a nocsr call surrounded by spill/fill pairs,
> + * return the number of such pairs.
> + */
> +static int match_nocsr_pattern(struct bpf_verifier_env *env, int t)
> +{
> +       return match_and_mark_nocsr_pattern(env, t, false);
> +}
> +
>  /* Visits the instruction at index t and returns one of the following:
>   *  < 0 - an error occurred
>   *  DONE_EXPLORING - the instruction was fully explored
> @@ -16017,6 +16262,8 @@ static int visit_insn(int t, struct bpf_verifier_env *env)
>                                 mark_force_checkpoint(env, t);
>                         }
>                 }
> +               if (insn->src_reg == 0)
> +                       update_nocsr_pattern_marks(env, t);

as you mentioned, we discussed moving this out of the check_cfg()
step, as it doesn't seem to be coupled with the "graph" part of the
algorithm
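
e.g., as its own linear pass next to check_cfg() (just a sketch):

    static void mark_nocsr_patterns(struct bpf_verifier_env *env)
    {
            int i;

            for (i = 0; i < env->prog->len; i++)
                    if (bpf_helper_call(&env->prog->insnsi[i]))
                            update_nocsr_pattern_marks(env, i);
    }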

>                 return visit_func_call_insn(t, insns, env, insn->src_reg == BPF_PSEUDO_CALL);
>
>         case BPF_JA:
> @@ -19063,15 +19310,16 @@ static int opt_remove_dead_code(struct bpf_verifier_env *env)
>         return 0;
>  }
>
> +static const struct bpf_insn NOP = BPF_JMP_IMM(BPF_JA, 0, 0, 0);
> +
>  static int opt_remove_nops(struct bpf_verifier_env *env)
>  {
> -       const struct bpf_insn ja = BPF_JMP_IMM(BPF_JA, 0, 0, 0);
>         struct bpf_insn *insn = env->prog->insnsi;
>         int insn_cnt = env->prog->len;
>         int i, err;
>
>         for (i = 0; i < insn_cnt; i++) {
> -               if (memcmp(&insn[i], &ja, sizeof(ja)))
> +               if (memcmp(&insn[i], &NOP, sizeof(NOP)))
>                         continue;
>
>                 err = verifier_remove_insns(env, i, 1);
> @@ -20801,6 +21049,39 @@ static int optimize_bpf_loop(struct bpf_verifier_env *env)
>         return 0;
>  }
>
> +/* Remove unnecessary spill/fill pairs, members of nocsr pattern.
> + * Do this as a separate pass to avoid interfering with helper/kfunc
> + * inlining logic in do_misc_fixups().
> + * See comment for match_and_mark_nocsr_pattern().
> + */
> +static int remove_nocsr_spills_fills(struct bpf_verifier_env *env)
> +{
> +       struct bpf_subprog_info *subprogs = env->subprog_info;
> +       int i, j, spills_num, cur_subprog = 0;
> +       struct bpf_insn *insn = env->prog->insnsi;
> +       int insn_cnt = env->prog->len;
> +
> +       for (i = 0; i < insn_cnt; i++, insn++) {
> +               spills_num = match_nocsr_pattern(env, i);

we can probably afford a single-byte field somewhere in
bpf_insn_aux_data to remember the "nocsr pattern size" instead of
just the true/false fact that this is a nocsr call. Then we wouldn't
need to redo the pattern matching here; we'd already have all the
data.
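
sketch, with a hypothetical field name:

    /* in struct bpf_insn_aux_data: */
    u8 nocsr_spills_num; /* spill/fill pairs around this nocsr call, 0 if none */

    /* then remove_nocsr_spills_fills() becomes a simple lookup: */
    spills_num = env->insn_aux_data[i].nocsr_spills_num;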


> +               if (spills_num == 0)
> +                       goto next_insn;
> +               for (j = 1; j <= spills_num; ++j)
> +                       if ((insn - j)->off >= subprogs[cur_subprog].nocsr_stack_off ||
> +                           (insn + j)->off >= subprogs[cur_subprog].nocsr_stack_off)
> +                               goto next_insn;
> +               /* NOPs are removed by opt_remove_nops() later */
> +               for (j = 1; j <= spills_num; ++j) {
> +                       *(insn - j) = NOP;
> +                       *(insn + j) = NOP;
> +               }
> +
> +next_insn:
> +               if (subprogs[cur_subprog + 1].start == i + 1)
> +                       cur_subprog++;
> +       }
> +       return 0;
> +}
> +
>  static void free_states(struct bpf_verifier_env *env)
>  {
>         struct bpf_verifier_state_list *sl, *sln;
> @@ -21719,6 +22000,9 @@ int bpf_check(struct bpf_prog **prog, union bpf_attr *attr, bpfptr_t uattr, __u3
>         if (ret == 0)
>                 ret = optimize_bpf_loop(env);
>
> +       if (ret == 0)
> +               ret = remove_nocsr_spills_fills(env);
> +
>         if (is_priv) {
>                 if (ret == 0)
>                         opt_hard_wire_dead_code_branches(env);
> --
> 2.45.2
>




