On Wed, Feb 21, 2024 at 4:50 PM Eduard Zingerman <eddyz87@xxxxxxxxx> wrote:
>
> Use bpf_verifier_state->jmp_history to track which registers were
> updated by find_equal_scalars() when a conditional jump was verified.
> Use the recorded information in backtrack_insn() to propagate precision.
>
> E.g. for the following program:
>
>             while verifying instructions
>   r1 = r0               |
>   if r1 < 8  goto ...   | push r0,r1 as equal_scalars in jmp_history
>   if r0 > 16 goto ...   | push r0,r1 as equal_scalars in jmp_history
>   r2 = r10              |
>   r2 += r0              v mark_chain_precision(r0)
>
>             while doing mark_chain_precision(r0)
>   r1 = r0               ^
>   if r1 < 8  goto ...   | mark r0,r1 as precise
>   if r0 > 16 goto ...   | mark r0,r1 as precise
>   r2 = r10              |
>   r2 += r0              | mark r0 precise
>
> Technically this is achieved in the following steps:
> - Use 10 bits to identify each register that gains range because of
>   find_equal_scalars():
>   - 3 bits for frame number;
>   - 6 bits for register or stack slot number;
>   - 1 bit to indicate if register is spilled.
> - Use u64 as a vector of 6 such records + 4 bits for vector length.
> - Augment struct bpf_jmp_history_entry with field 'equal_scalars'
>   representing such a vector.
> - When doing check_cond_jmp_op() remember up to 6 registers that
>   gain range because of find_equal_scalars() in such a vector.
> - Don't propagate range information and reset IDs for registers that
>   don't fit in the 6-value vector.
> - Push the collected vector to bpf_verifier_state->jmp_history for
>   the instruction index of the conditional jump.
> - When doing backtrack_insn() for conditional jumps
>   check if any of the recorded equal scalars is currently marked precise,
>   if so mark all recorded equal scalars as precise.
>
> Fixes: 904e6ddf4133 ("bpf: Use scalar ids in mark_chain_precision()")
> Reported-by: Hao Sun <sunhao.th@xxxxxxxxx>
> Closes: https://lore.kernel.org/bpf/CAEf4BzZ0xidVCqB47XnkXcNhkPWF6_nTV7yt+_Lf0kcFEut2Mg@xxxxxxxxxxxxxx/
> Suggested-by: Andrii Nakryiko <andrii@xxxxxxxxxx>
> Signed-off-by: Eduard Zingerman <eddyz87@xxxxxxxxx>
> ---
>  include/linux/bpf_verifier.h                  |   1 +
>  kernel/bpf/verifier.c                         | 207 ++++++++++++++++--
>  .../bpf/progs/verifier_subprog_precision.c    |   2 +-
>  .../testing/selftests/bpf/verifier/precise.c  |   2 +-
>  4 files changed, 195 insertions(+), 17 deletions(-)
>
> diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
> index cbfb235984c8..26e32555711c 100644
> --- a/include/linux/bpf_verifier.h
> +++ b/include/linux/bpf_verifier.h
> @@ -361,6 +361,7 @@ struct bpf_jmp_history_entry {
>  	u32 prev_idx : 22;
>  	/* special flags, e.g., whether insn is doing register stack spill/load */
>  	u32 flags : 10;
> +	u64 equal_scalars;
>  };
>

[...]
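
Btw, to make sure I'm reading the 10-bit encoding from the commit log
correctly, I imagine the "record one register" side looking roughly like
the sketch below. The helper name and the exact position of the 4-bit
length field are my guesses, not necessarily what the patch does:

/* One record per register that gained range via find_equal_scalars():
 *   bits 0-2: frame number (3 bits)
 *   bits 3-8: register or stack slot number (6 bits)
 *   bit    9: register is spilled (1 bit)
 * Up to 6 such 10-bit records sit in the low 60 bits of the u64; in this
 * sketch the top 4 bits hold the record count.
 */
static u64 equal_scalars_push(u64 es, u32 frameno, u32 reg_or_slot, bool spilled)
{
	u32 cnt = es >> 60;
	u64 rec = frameno | (reg_or_slot << 3) | ((u64)spilled << 9);

	if (cnt >= 6)
		return es; /* caller would reset the reg's ID instead of recording */
	es &= ~(0xfULL << 60);
	es |= rec << (cnt * 10);
	es |= (u64)(cnt + 1) << 60;
	return es;
}

If that's the idea, a short comment next to the equal_scalars field
spelling out the layout would help, since check_cond_jmp_op() and
backtrack_insn() both have to agree on it.
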
> @@ -3314,7 +3384,7 @@ static struct bpf_jmp_history_entry *get_jmp_hist_entry(struct bpf_verifier_stat
>
>  /* for any branch, call, exit record the history of jmps in the given state */
>  static int push_jmp_history(struct bpf_verifier_env *env, struct bpf_verifier_state *cur,
> -			    int insn_flags)
> +			    int insn_flags, u64 equal_scalars)
>  {
>  	struct bpf_jmp_history_entry *p, *cur_hist_ent;
>  	u32 cnt = cur->jmp_history_cnt;
> @@ -3332,6 +3402,12 @@ static int push_jmp_history(struct bpf_verifier_env *env, struct bpf_verifier_st
>  			  "verifier insn history bug: insn_idx %d cur flags %x new flags %x\n",
>  			  env->insn_idx, cur_hist_ent->flags, insn_flags);
>  		cur_hist_ent->flags |= insn_flags;
> +		if (cur_hist_ent->equal_scalars != 0) {
> +			verbose(env, "verifier bug: insn_idx %d equal_scalars != 0: %#llx\n",
> +				env->insn_idx, cur_hist_ent->equal_scalars);
> +			return -EFAULT;
> +		}

let's do WARN_ONCE(), just like we do for flags? Why deviate?

> +		cur_hist_ent->equal_scalars = equal_scalars;
>  		return 0;
>  	}
>
> @@ -3346,6 +3422,7 @@ static int push_jmp_history(struct bpf_verifier_env *env, struct bpf_verifier_st
>  	p->idx = env->insn_idx;
>  	p->prev_idx = env->prev_insn_idx;
>  	p->flags = insn_flags;
> +	p->equal_scalars = equal_scalars;
>  	cur->jmp_history_cnt = cnt;
>
>  	return 0;

[...]

>  static bool calls_callback(struct bpf_verifier_env *env, int insn_idx);
>
>  /* For given verifier state backtrack_insn() is called from the last insn to
> @@ -3802,6 +3917,7 @@ static int backtrack_insn(struct bpf_verifier_env *env, int idx, int subseq_idx,
>  			 */
>  			return 0;
>  		} else if (BPF_SRC(insn->code) == BPF_X) {
> +			bt_set_equal_scalars(bt, hist);
>  			if (!bt_is_reg_set(bt, dreg) && !bt_is_reg_set(bt, sreg))
>  				return 0;
>  			/* dreg <cond> sreg
>  			 ...
>  			 */
>  			bt_set_reg(bt, dreg);
>  			bt_set_reg(bt, sreg);
> +			bt_set_equal_scalars(bt, hist);
> +		} else if (BPF_SRC(insn->code) == BPF_K) {
> +			bt_set_equal_scalars(bt, hist);

Can you please elaborate on why we are doing bt_set_equal_scalars() in
these three places and not everywhere else? I'm trying to understand
whether we should do it more generically for any instruction, either
before or after all the bt_set_xxx() calls...

>  			/* else dreg <cond> K
>  			 * Only dreg still needs precision before
>  			 * this insn, so for the K-based conditional
> @@ -4579,7 +4698,7 @@ static int check_stack_write_fixed_off(struct bpf_verifier_env *env,
>  	}
>
>  	if (insn_flags)
> -		return push_jmp_history(env, env->cur_state, insn_flags);
> +		return push_jmp_history(env, env->cur_state, insn_flags, 0);
>  	return 0;
>  }
>

[...]
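
Re: the WARN_ONCE() point above, what I have in mind is keeping the merge
path identical to the flags handling, along these lines (untested, just to
illustrate):

		/* same pattern as the flags check above: complain loudly
		 * once, but don't fail verification for what is a
		 * verifier-internal bookkeeping bug
		 */
		WARN_ONCE(cur_hist_ent->equal_scalars != 0,
			  "verifier insn history bug: insn_idx %d equal_scalars %#llx\n",
			  env->insn_idx, cur_hist_ent->equal_scalars);
		cur_hist_ent->equal_scalars = equal_scalars;

That also avoids returning -EFAULT from a branch that otherwise only merges
new data into an existing history entry.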