On Sat, Jun 29, 2024 at 2:48 AM Eduard Zingerman <eddyz87@xxxxxxxxx> wrote:
>
> GCC and LLVM define a no_caller_saved_registers function attribute.
> This attribute means that function scratches only some of
> the caller saved registers defined by ABI.
> For BPF the set of such registers could be defined as follows:
> - R0 is scratched only if function is non-void;
> - R1-R5 are scratched only if corresponding parameter type is defined
>   in the function prototype.
>
> This commit introduces flag bpf_func_prot->nocsr.
> If this flag is set for some helper function, verifier assumes that
> it follows no_caller_saved_registers calling convention.
>
> The contract between kernel and clang allows to simultaneously use
> such functions and maintain backwards compatibility with old
> kernels that don't understand no_caller_saved_registers calls
> (nocsr for short):
>
> - clang generates a simple pattern for nocsr calls, e.g.:
>
>     r1 = 1;
>     r2 = 2;
>     *(u64 *)(r10 - 8) = r1;
>     *(u64 *)(r10 - 16) = r2;
>     call %[to_be_inlined_by_jit]

"to_be_inlined_by_jit" is misleading, the call can be inlined by the BPF
verifier using BPF instructions, not just by the BPF JIT

>     r2 = *(u64 *)(r10 - 16);
>     r1 = *(u64 *)(r10 - 8);
>     r0 = r1;
>     r0 += r2;
>     exit;
>
> - kernel removes unnecessary spills and fills, if called function is
>   inlined by current JIT (with assumption that patch inserted by JIT
>   honors nocsr contract, e.g. does not scratch r3-r5 for the example
>   above), e.g. the code above would be transformed to:
>
>     r1 = 1;
>     r2 = 2;
>     call %[to_be_inlined_by_jit]
>     r0 = r1;
>     r0 += r2;
>     exit;
>
> Technically, the transformation is split into the following phases:
> - during check_cfg() function update_nocsr_pattern_marks() is used to
>   find potential patterns;
> - upon stack read or write access,
>   function check_nocsr_stack_contract() is used to verify if
>   stack offsets, presumably reserved for nocsr patterns, are used
>   only from those patterns;
> - function remove_nocsr_spills_fills(), called from bpf_check(),
>   applies the rewrite for valid patterns.
>
> See comment in match_and_mark_nocsr_pattern() for more details.
>
> Suggested-by: Alexei Starovoitov <alexei.starovoitov@xxxxxxxxx>
> Signed-off-by: Eduard Zingerman <eddyz87@xxxxxxxxx>
> ---
>  include/linux/bpf.h          |   6 +
>  include/linux/bpf_verifier.h |   9 ++
>  kernel/bpf/verifier.c        | 300 ++++++++++++++++++++++++++++++++++-
>  3 files changed, 307 insertions(+), 8 deletions(-)
>
[...]

> -static int find_subprog(struct bpf_verifier_env *env, int off)
> +/* Find subprogram that contains instruction at 'off' */
> +static int find_containing_subprog(struct bpf_verifier_env *env, int off)
>  {
> -	struct bpf_subprog_info *p;
> +	struct bpf_subprog_info *vals = env->subprog_info;
> +	int high = env->subprog_cnt - 1;
> +	int low = 0, ret = -ENOENT;
>
> -	p = bsearch(&off, env->subprog_info, env->subprog_cnt,
> -		    sizeof(env->subprog_info[0]), cmp_subprogs);
> -	if (!p)
> +	if (off >= env->prog->len || off < 0)
>  		return -ENOENT;
> -	return p - env->subprog_info;
>
> +	while (low <= high) {
> +		int mid = (low + high)/2;

styling nit: (...) / 2

> +		struct bpf_subprog_info *val = &vals[mid];
> +		int diff = off - val->start;
> +
> +		if (diff < 0) {

tbh, this hurts my brain. Why not write the human-readable and more
meaningful

  if (off < val->start)

?

> +			high = mid - 1;
> +		} else {
> +			low = mid + 1;
> +			/* remember last time mid.start <= off */
> +			ret = mid;
> +		}

feel free to ignore, but I find this extra `ret = mid` bookkeeping a bit
inelegant. See find_linfo in kernel/bpf/log.c for how a lower_bound-like
binary search can be implemented without it (I mean the pattern where the
invariant keeps low or high always satisfying the condition, while the
other one is adjusted with +1 or -1, depending on the desired logic).
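Something along these lines, maybe (completely untested, just to show the
shape; it relies on off >= 0 and on subprog_info[0].start == 0, so the
invariant below holds from the very first iteration):

static int find_containing_subprog(struct bpf_verifier_env *env, int off)
{
	struct bpf_subprog_info *vals = env->subprog_info;
	int high = env->subprog_cnt - 1;
	int low = 0, mid;

	if (off >= env->prog->len || off < 0)
		return -ENOENT;

	/* invariant: vals[low].start <= off,
	 * we are looking for the rightmost entry satisfying it
	 */
	while (low < high) {
		mid = low + (high - low + 1) / 2;
		if (off >= vals[mid].start)
			low = mid;
		else
			high = mid - 1;
	}
	return low;
}

find_subprog() built on top of it should work exactly as in your patch.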
> +	}
> +	return ret;
> +}
> +
> +/* Find subprogram that starts exactly at 'off' */
> +static int find_subprog(struct bpf_verifier_env *env, int off)
> +{
> +	int idx;
> +
> +	idx = find_containing_subprog(env, off);
> +	if (idx < 0 || env->subprog_info[idx].start != off)
> +		return -ENOENT;
> +	return idx;
>  }
>
[...]

> +static u8 get_helper_reg_mask(const struct bpf_func_proto *fn)
> +{
> +	u8 mask;
> +	int i;
> +
> +	if (!fn->nocsr)
> +		return ALL_CALLER_SAVED_REGS;
> +
> +	mask = 0;
> +	mask |= fn->ret_type == RET_VOID ? 0 : BIT(BPF_REG_0);
> +	for (i = 0; i < ARRAY_SIZE(fn->arg_type); ++i)
> +		mask |= fn->arg_type[i] == ARG_DONTCARE ? 0 : BIT(BPF_REG_1 + i);

again subjective, but

  if (fn->ret_type != RET_VOID)
  	mask |= BIT(BPF_REG_0);

(and similarly for ARG_DONTCARE) seems a bit more readable and not that much
more verbose

> +	return mask;
> +}
> +
> +/* True if do_misc_fixups() replaces calls to helper number 'imm',
> + * replacement patch is presumed to follow no_caller_saved_registers contract
> + * (see match_and_mark_nocsr_pattern() below).
> + */
> +static bool verifier_inlines_helper_call(struct bpf_verifier_env *env, s32 imm)
> +{

note that there is now also bpf_jit_inlines_helper_call()

> +	return false;
> +}
> +
> +/* If 'insn' is a call that follows no_caller_saved_registers contract
> + * and called function is inlined by current jit, return a mask with
> + * 1s corresponding to registers that are scratched by this call
> + * (depends on return type and number of return parameters).
> + * Otherwise return ALL_CALLER_SAVED_REGS mask.
> + */
> +static u32 call_csr_mask(struct bpf_verifier_env *env, struct bpf_insn *insn)
> +{
> +	const struct bpf_func_proto *fn;
> +
> +	if (bpf_helper_call(insn) &&
> +	    verifier_inlines_helper_call(env, insn->imm) &&

strictly speaking, does nocsr have anything to do with inlining, though?
E.g., if we know for sure that the helper implementation doesn't touch
extra registers (how we know that is a separate issue), why do we need
inlining to make use of nocsr?

> +	    get_helper_proto(env, insn->imm, &fn) == 0 &&
> +	    fn->nocsr)
> +		return ~get_helper_reg_mask(fn);
> +
> +	return ALL_CALLER_SAVED_REGS;
> +}
> +
[...]

> + * For example, it is *not* safe to remove spill/fill below:
> + *
> + *   r1 = 1;
> + *   *(u64 *)(r10 - 8) = r1;              r1 = 1;
> + *   call %[to_be_inlined_by_jit]   -->   call %[to_be_inlined_by_jit]
> + *   r1 = *(u64 *)(r10 - 8);              r0 = *(u64 *)(r10 - 8); <---- wrong !!!
> + *   r0 = *(u64 *)(r10 - 8);              r0 += r1;
> + *   r0 += r1;                            exit;
> + *   exit;
> + */
> +static int match_and_mark_nocsr_pattern(struct bpf_verifier_env *env, int t, bool mark)
> +{
> +	struct bpf_insn *insns = env->prog->insnsi, *stx, *ldx;
> +	struct bpf_subprog_info *subprog;
> +	u32 csr_mask = call_csr_mask(env, &insns[t]);
> +	u32 reg_mask = ~csr_mask | ~ALL_CALLER_SAVED_REGS;
> +	int s, i;
> +	s16 off;
> +
> +	if (csr_mask == ALL_CALLER_SAVED_REGS)
> +		return false;

false -> 0 ?

> +
> +	for (i = 1, off = 0; i <= ARRAY_SIZE(caller_saved); ++i, off += BPF_REG_SIZE) {
> +		if (t - i < 0 || t + i >= env->prog->len)
> +			break;
> +		stx = &insns[t - i];
> +		ldx = &insns[t + i];
> +		if (off == 0) {
> +			off = stx->off;
> +			if (off % BPF_REG_SIZE != 0)
> +				break;

just return here?

> +		}
> +		if (/* *(u64 *)(r10 - off) = r[0-5]? */
> +		    stx->code != (BPF_STX | BPF_MEM | BPF_DW) ||
> +		    stx->dst_reg != BPF_REG_10 ||
> +		    /* r[0-5] = *(u64 *)(r10 - off)? */
> +		    ldx->code != (BPF_LDX | BPF_MEM | BPF_DW) ||
> +		    ldx->src_reg != BPF_REG_10 ||
> +		    /* check spill/fill for the same reg and offset */
> +		    stx->src_reg != ldx->dst_reg ||
> +		    stx->off != ldx->off ||
> +		    stx->off != off ||
> +		    /* this should be a previously unseen register */
> +		    BIT(stx->src_reg) & reg_mask)
> +			break;
> +		reg_mask |= BIT(stx->src_reg);
> +		if (mark) {
> +			env->insn_aux_data[t - i].nocsr_pattern = true;
> +			env->insn_aux_data[t + i].nocsr_pattern = true;
> +		}
> +	}
> +	if (i == 1)
> +		return 0;
> +	if (mark) {
> +		s = find_containing_subprog(env, t);
> +		/* can't happen */
> +		if (WARN_ON_ONCE(s < 0))
> +			return 0;
> +		subprog = &env->subprog_info[s];
> +		subprog->nocsr_stack_off = min(subprog->nocsr_stack_off, off);
> +	}

why not split pattern detection and all this other marking logic? You can
return "the size of the csr pattern", meaning how many spills/fills surround
the call, no? Then all the marking can be done (if necessary) by the caller.
The question is what to do with a zero pattern (no spills/fills around a
nocsr call, is that a valid case?)

> +	return i - 1;
> +}
> +
> +/* If instruction 't' is a nocsr call surrounded by spill/fill pairs,
> + * update env->subprog_info[_]->nocsr_stack_off and
> + * env->insn_aux_data[_].nocsr_pattern fields.
> + */
> +static void update_nocsr_pattern_marks(struct bpf_verifier_env *env, int t)
> +{
> +	match_and_mark_nocsr_pattern(env, t, true);
> +}
> +
> +/* If instruction 't' is a nocsr call surrounded by spill/fill pairs,
> + * return the number of such pairs.
> + */
> +static int match_nocsr_pattern(struct bpf_verifier_env *env, int t)
> +{
> +	return match_and_mark_nocsr_pattern(env, t, false);
> +}
> +
>  /* Visits the instruction at index t and returns one of the following:
>   * < 0 - an error occurred
>   * DONE_EXPLORING - the instruction was fully explored
> @@ -16017,6 +16262,8 @@ static int visit_insn(int t, struct bpf_verifier_env *env)
>  				mark_force_checkpoint(env, t);
>  			}
>  		}
> +		if (insn->src_reg == 0)
> +			update_nocsr_pattern_marks(env, t);

as you mentioned, we discussed moving this out of the check_cfg() step, as
it doesn't seem to be coupled with the "graph" part of the algorithm

>  		return visit_func_call_insn(t, insns, env, insn->src_reg == BPF_PSEUDO_CALL);
>
>  	case BPF_JA:
> @@ -19063,15 +19310,16 @@ static int opt_remove_dead_code(struct bpf_verifier_env *env)
>  	return 0;
>  }
>
> +static const struct bpf_insn NOP = BPF_JMP_IMM(BPF_JA, 0, 0, 0);
> +
>  static int opt_remove_nops(struct bpf_verifier_env *env)
>  {
> -	const struct bpf_insn ja = BPF_JMP_IMM(BPF_JA, 0, 0, 0);
>  	struct bpf_insn *insn = env->prog->insnsi;
>  	int insn_cnt = env->prog->len;
>  	int i, err;
>
>  	for (i = 0; i < insn_cnt; i++) {
> -		if (memcmp(&insn[i], &ja, sizeof(ja)))
> +		if (memcmp(&insn[i], &NOP, sizeof(NOP)))
>  			continue;
>
>  		err = verifier_remove_insns(env, i, 1);
> @@ -20801,6 +21049,39 @@ static int optimize_bpf_loop(struct bpf_verifier_env *env)
>  	return 0;
>  }
>
> +/* Remove unnecessary spill/fill pairs, members of nocsr pattern.
> + * Do this as a separate pass to avoid interfering with helper/kfunc
> + * inlining logic in do_misc_fixups().
> + * See comment for match_and_mark_nocsr_pattern().
> + */
> +static int remove_nocsr_spills_fills(struct bpf_verifier_env *env)
> +{
> +	struct bpf_subprog_info *subprogs = env->subprog_info;
> +	int i, j, spills_num, cur_subprog = 0;
> +	struct bpf_insn *insn = env->prog->insnsi;
> +	int insn_cnt = env->prog->len;
> +
> +	for (i = 0; i < insn_cnt; i++, insn++) {
> +		spills_num = match_nocsr_pattern(env, i);

we can probably afford a single-byte field somewhere in bpf_insn_aux_data to
remember the "csr pattern size" instead of just the true/false fact that it
is a nocsr call. Then we wouldn't need to do the pattern matching again
here, we'd already have all the data.
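E.g., roughly (untested fragments, nocsr_pattern_sz is a made-up name; u8 is
enough since there can be at most ARRAY_SIZE(caller_saved) == 6 pairs):

  /* in struct bpf_insn_aux_data */
  u8 nocsr_pattern_sz; /* for a nocsr call: number of spill/fill pairs
			* around it, zero otherwise
			*/

  /* update_nocsr_pattern_marks() would remember the match result for
   * the call instruction:
   */
  env->insn_aux_data[t].nocsr_pattern_sz = match_and_mark_nocsr_pattern(env, t, true);

  /* and the loop here wouldn't need to redo the matching: */
  spills_num = env->insn_aux_data[i].nocsr_pattern_sz;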
> +		if (spills_num == 0)
> +			goto next_insn;
> +		for (j = 1; j <= spills_num; ++j)
> +			if ((insn - j)->off >= subprogs[cur_subprog].nocsr_stack_off ||
> +			    (insn + j)->off >= subprogs[cur_subprog].nocsr_stack_off)
> +				goto next_insn;
> +		/* NOPs are removed by opt_remove_nops() later */
> +		for (j = 1; j <= spills_num; ++j) {
> +			*(insn - j) = NOP;
> +			*(insn + j) = NOP;
> +		}
> +
> +next_insn:
> +		if (subprogs[cur_subprog + 1].start == i + 1)
> +			cur_subprog++;
> +	}
> +	return 0;
> +}
> +
>  static void free_states(struct bpf_verifier_env *env)
>  {
>  	struct bpf_verifier_state_list *sl, *sln;
> @@ -21719,6 +22000,9 @@ int bpf_check(struct bpf_prog **prog, union bpf_attr *attr, bpfptr_t uattr, __u3
>  	if (ret == 0)
>  		ret = optimize_bpf_loop(env);
>
> +	if (ret == 0)
> +		ret = remove_nocsr_spills_fills(env);
> +
>  	if (is_priv) {
>  		if (ret == 0)
>  			opt_hard_wire_dead_code_branches(env);
> --
> 2.45.2
>