On Wed, Feb 21, 2024 at 4:50 PM Eduard Zingerman <eddyz87@xxxxxxxxx> wrote: > > Use bpf_verifier_state->jmp_history to track which registers were > updated by find_equal_scalars() when conditional jump was verified. > Use recorded information in backtrack_insn() to propagate precision. > > E.g. for the following program: > > while verifying instructions > r1 = r0 | > if r1 < 8 goto ... | push r0,r1 as equal_scalars in jmp_history > if r0 > 16 goto ... | push r0,r1 as equal_scalars in jmp_history > r2 = r10 | > r2 += r0 v mark_chain_precision(r0) > > while doing mark_chain_precision(r0) > r1 = r0 ^ > if r1 < 8 goto ... | mark r0,r1 as precise > if r0 > 16 goto ... | mark r0,r1 as precise > r2 = r10 | > r2 += r0 | mark r0 precise > > Technically achieve this in following steps: > - Use 10 bits to identify each register that gains range because of > find_equal_scalars(): > - 3 bits for frame number; > - 6 bits for register or stack slot number; > - 1 bit to indicate if register is spilled. > - Use u64 as a vector of 6 such records + 4 bits for vector length. > - Augment struct bpf_jmp_history_entry with field 'equal_scalars' > representing such vector. > - When doing check_cond_jmp_op() remember up to 6 registers that > gain range because of find_equal_scalars() in such a vector. > - Don't propagate range information and reset IDs for registers that > don't fit in 6-value vector. > - Push collected vector to bpf_verifier_state->jmp_history for > instruction index of conditional jump. > - When doing backtrack_insn() for conditional jumps > check if any of recorded equal scalars is currently marked precise, > if so mark all equal recorded scalars as precise. 
> > Fixes: 904e6ddf4133 ("bpf: Use scalar ids in mark_chain_precision()") > Reported-by: Hao Sun <sunhao.th@xxxxxxxxx> > Closes: https://lore.kernel.org/bpf/CAEf4BzZ0xidVCqB47XnkXcNhkPWF6_nTV7yt+_Lf0kcFEut2Mg@xxxxxxxxxxxxxx/ > Suggested-by: Andrii Nakryiko <andrii@xxxxxxxxxx> > Signed-off-by: Eduard Zingerman <eddyz87@xxxxxxxxx> > --- > include/linux/bpf_verifier.h | 1 + > kernel/bpf/verifier.c | 207 ++++++++++++++++-- > .../bpf/progs/verifier_subprog_precision.c | 2 +- > .../testing/selftests/bpf/verifier/precise.c | 2 +- > 4 files changed, 195 insertions(+), 17 deletions(-) > > diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h > index cbfb235984c8..26e32555711c 100644 > --- a/include/linux/bpf_verifier.h > +++ b/include/linux/bpf_verifier.h > @@ -361,6 +361,7 @@ struct bpf_jmp_history_entry { > u32 prev_idx : 22; > /* special flags, e.g., whether insn is doing register stack spill/load */ > u32 flags : 10; > + u64 equal_scalars; nit: should we use a bit more generic name for this concept, such as "linked registers", instead of "equal scalars"? 
> }; > > /* Maximum number of register states that can exist at once */ > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c > index 759ef089b33c..b95b6842703c 100644 > --- a/kernel/bpf/verifier.c > +++ b/kernel/bpf/verifier.c > @@ -3304,6 +3304,76 @@ static bool is_jmp_point(struct bpf_verifier_env *env, int insn_idx) > return env->insn_aux_data[insn_idx].jmp_point; > } > > +#define ES_FRAMENO_BITS 3 > +#define ES_SPI_BITS 6 > +#define ES_ENTRY_BITS (ES_SPI_BITS + ES_FRAMENO_BITS + 1) > +#define ES_SIZE_BITS 4 > +#define ES_FRAMENO_MASK ((1ul << ES_FRAMENO_BITS) - 1) > +#define ES_SPI_MASK ((1ul << ES_SPI_BITS) - 1) > +#define ES_SIZE_MASK ((1ul << ES_SIZE_BITS) - 1) > +#define ES_SPI_OFF ES_FRAMENO_BITS > +#define ES_IS_REG_OFF (ES_SPI_BITS + ES_FRAMENO_BITS) > + > +/* Pack one history entry for equal scalars as 10 bits in the following format: > + * - 3-bits frameno > + * - 6-bits spi_or_reg > + * - 1-bit is_reg > + */ > +static u64 equal_scalars_pack(u32 frameno, u32 spi_or_reg, bool is_reg) > +{ > + u64 val = 0; > + > + val |= frameno & ES_FRAMENO_MASK; > + val |= (spi_or_reg & ES_SPI_MASK) << ES_SPI_OFF; > + val |= (is_reg ? 1 : 0) << ES_IS_REG_OFF; > + return val; > +} > + > +static void equal_scalars_unpack(u64 val, u32 *frameno, u32 *spi_or_reg, bool *is_reg) > +{ > + *frameno = val & ES_FRAMENO_MASK; > + *spi_or_reg = (val >> ES_SPI_OFF) & ES_SPI_MASK; > + *is_reg = (val >> ES_IS_REG_OFF) & 0x1; > +} > + > +static u32 equal_scalars_size(u64 equal_scalars) > +{ > + return equal_scalars & ES_SIZE_MASK; > +} > + > +/* Use u64 as a stack of 6 10-bit values, use first 4-bits to track > + * number of elements currently in stack. 
> + */ > +static bool equal_scalars_push(u64 *equal_scalars, u32 frameno, u32 spi_or_reg, bool is_reg) > +{ > + u32 num; > + > + num = equal_scalars_size(*equal_scalars); > + if (num == 6) > + return false; > + *equal_scalars >>= ES_SIZE_BITS; > + *equal_scalars <<= ES_ENTRY_BITS; > + *equal_scalars |= equal_scalars_pack(frameno, spi_or_reg, is_reg); > + *equal_scalars <<= ES_SIZE_BITS; > + *equal_scalars |= num + 1; > + return true; > +} > + > +static bool equal_scalars_pop(u64 *equal_scalars, u32 *frameno, u32 *spi_or_reg, bool *is_reg) > +{ > + u32 num; > + > + num = equal_scalars_size(*equal_scalars); > + if (num == 0) > + return false; > + *equal_scalars >>= ES_SIZE_BITS; > + equal_scalars_unpack(*equal_scalars, frameno, spi_or_reg, is_reg); > + *equal_scalars >>= ES_ENTRY_BITS; > + *equal_scalars <<= ES_SIZE_BITS; > + *equal_scalars |= num - 1; > + return true; > +} > + I'm wondering if this pop/push set of primitives is the best approach? What if we had pack/unpack operations, where for various checking logic we'd be working with the "unpacked" representation, e.g., something like this: struct linked_reg_set { int cnt; struct { int frameno; union { int spi; int regno; }; bool is_set; bool is_reg; } reg_set[6]; }; bt_set_equal_scalars() could accept `struct linked_reg_set*` instead of the bitmask itself. Same for find_equal_scalars(). I think even the implementation of packing/unpacking would be more straightforward and we wouldn't even need all those ES_xxx consts (or would at least need fewer of them). WDYT? > static struct bpf_jmp_history_entry *get_jmp_hist_entry(struct bpf_verifier_state *st, > u32 hist_end, int insn_idx) > { [...]