On Thu, Jul 18, 2024 at 1:24 PM Eduard Zingerman <eddyz87@xxxxxxxxx> wrote:
>
> Use bpf_verifier_state->jmp_history to track which registers were
> updated by find_equal_scalars() (renamed to collect_linked_regs())
> when conditional jump was verified. Use recorded information in
> backtrack_insn() to propagate precision.
>
> E.g. for the following program:
>
>             while verifying instructions
>   1: r1 = r0              |
>   2: if r1 < 8 goto ...   | push r0,r1 as linked registers in jmp_history
>   3: if r0 > 16 goto ...  | push r0,r1 as linked registers in jmp_history
>   4: r2 = r10             |
>   5: r2 += r0             v mark_chain_precision(r0)
>
>             while doing mark_chain_precision(r0)
>   5: r2 += r0             | mark r0 precise
>   4: r2 = r10             |
>   3: if r0 > 16 goto ...  | mark r0,r1 as precise
>   2: if r1 < 8 goto ...   | mark r0,r1 as precise
>   1: r1 = r0              v
>
> Technically, do this as follows:
> - Use 10 bits to identify each register that gains range because of
>   sync_linked_regs():
>   - 3 bits for frame number;
>   - 6 bits for register or stack slot number;
>   - 1 bit to indicate if register is spilled.
> - Use u64 as a vector of 6 such records + 4 bits for vector length.
> - Augment struct bpf_jmp_history_entry with a field 'linked_regs'
>   representing such vector.
> - When doing check_cond_jmp_op() remember up to 6 registers that
>   gain range because of sync_linked_regs() in such a vector.
> - Don't propagate range information and reset IDs for registers that
>   don't fit in 6-value vector.
> - Push a pair {instruction index, linked registers vector}
>   to bpf_verifier_state->jmp_history.
> - When doing backtrack_insn() check if any of recorded linked
>   registers is currently marked precise, if so mark all linked
>   registers as precise.
>
> This also requires fixes for two test_verifier tests:
> - precise: test 1
> - precise: test 2
>
> Both tests contain the following instruction sequence:
>
> 19: (bf) r2 = r9                      ; R2=scalar(id=3) R9=scalar(id=3)
> 20: (a5) if r2 < 0x8 goto pc+1        ; R2=scalar(id=3,umin=8)
> 21: (95) exit
> 22: (07) r2 += 1                      ; R2_w=scalar(id=3+1,...)
> 23: (bf) r1 = r10                     ; R1_w=fp0 R10=fp0
> 24: (07) r1 += -8                     ; R1_w=fp-8
> 25: (b7) r3 = 0                       ; R3_w=0
> 26: (85) call bpf_probe_read_kernel#113
>
> The call to bpf_probe_read_kernel() at (26) forces r2 to be precise.
> Previously, this forced all registers with same id to become precise
> immediately when mark_chain_precision() is called.
> After this change, the precision is propagated to registers sharing
> same id only when 'if' instruction is backtracked.
> Hence verification log for both tests is changed:
> regs=r2,r9 -> regs=r2 for instructions 25..20.
>
> Fixes: 904e6ddf4133 ("bpf: Use scalar ids in mark_chain_precision()")
> Reported-by: Hao Sun <sunhao.th@xxxxxxxxx>
> Closes: https://lore.kernel.org/bpf/CAEf4BzZ0xidVCqB47XnkXcNhkPWF6_nTV7yt+_Lf0kcFEut2Mg@xxxxxxxxxxxxxx/
> Suggested-by: Andrii Nakryiko <andrii@xxxxxxxxxx>
> Signed-off-by: Eduard Zingerman <eddyz87@xxxxxxxxx>
> ---
>  include/linux/bpf_verifier.h                  |   4 +
>  kernel/bpf/verifier.c                         | 248 ++++++++++++++++--
>  .../bpf/progs/verifier_subprog_precision.c    |   2 +-
>  .../testing/selftests/bpf/verifier/precise.c  |  20 +-
>  4 files changed, 242 insertions(+), 32 deletions(-)
>

[...]
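Side note for anyone following along on the list: the bit-packing
described in the commit message above would look roughly like the
sketch below. To be clear, this is only my illustration of the scheme;
the linked_regs_pack name, the MAX_LINKED_REGS macro and the exact bit
placement are guesses, not necessarily what the patch does (the struct
fields match the usage visible in sync_linked_regs() further down):

#define MAX_LINKED_REGS 6

struct linked_reg {
        u8 frameno;
        union {
                u8 spi;
                u8 regno;
        };
        bool is_reg;
};

struct linked_regs {
        int cnt;
        struct linked_reg entries[MAX_LINKED_REGS];
};

/* 10 bits per entry: 3 bits frameno, 6 bits regno/spi, 1 bit is_reg;
 * low 4 bits of the resulting u64 hold the vector length,
 * 6 * 10 + 4 = 64 bits total.
 */
static u64 linked_regs_pack(struct linked_regs *s)
{
        u64 val = 0;
        int i;

        for (i = 0; i < s->cnt; i++) {
                struct linked_reg *e = &s->entries[i];
                u64 tmp = 0;

                tmp |= e->frameno;               /* bits 0..2 */
                tmp |= e->spi << 3;              /* bits 3..8 */
                tmp |= (e->is_reg ? 1 : 0) << 9; /* bit  9    */

                val <<= 10;
                val |= tmp;
        }
        val <<= 4;
        val |= s->cnt;
        return val;
}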
> +/* For all R being scalar registers or spilled scalar registers
> + * in verifier state, save R in linked_regs if R->id == id.
> + * If there are too many Rs sharing same id, reset id for leftover Rs.
> + */
> +static void collect_linked_regs(struct bpf_verifier_state *vstate, u32 id,
> +				struct linked_regs *linked_regs)
> +{
> +	struct bpf_func_state *func;
> +	struct bpf_reg_state *reg;
> +	int i, j;
> +
> +	id = id & ~BPF_ADD_CONST;
> +	for (i = vstate->curframe; i >= 0; i--) {
> +		func = vstate->frame[i];
> +		for (j = 0; j < BPF_REG_FP; j++) {
> +			reg = &func->regs[j];
> +			__collect_linked_regs(linked_regs, reg, id, i, j, true);
> +		}
> +		for (j = 0; j < func->allocated_stack / BPF_REG_SIZE; j++) {
> +			if (!is_spilled_reg(&func->stack[j]))
> +				continue;
> +			reg = &func->stack[j].spilled_ptr;
> +			__collect_linked_regs(linked_regs, reg, id, i, j, false);
> +		}
> +	}
> +
> +	if (linked_regs->cnt == 1)
> +		linked_regs->cnt = 0;

We discussed this rather ugly condition w/ Eduard offline. I agreed to
drop it and change the condition `linked_regs.cnt > 0` below to
`linked_regs.cnt > 1`. It's unfortunate we can have one "self-linked"
register, but it seems like unlinking the last remaining register would
be prohibitively expensive (as we don't track how many linked registers
there are for a given ID).

Anyways, if we ever come to solve this, we can update the `> 1`
condition to a proper `> 0` one. For now they are equivalent, so it
doesn't really matter much.

> +}
> +
> +/* For all R in linked_regs, copy known_reg range into R
> + * if R->id == known_reg->id.
> + */
> +static void sync_linked_regs(struct bpf_verifier_state *vstate, struct bpf_reg_state *known_reg,
> +			     struct linked_regs *linked_regs)
>  {
>  	struct bpf_reg_state fake_reg;
> -	struct bpf_func_state *state;
>  	struct bpf_reg_state *reg;
> +	struct linked_reg *e;
> +	int i;
>
> -	bpf_for_each_reg_in_vstate(vstate, state, reg, ({
> +	for (i = 0; i < linked_regs->cnt; ++i) {
> +		e = &linked_regs->entries[i];
> +		reg = e->is_reg ? &vstate->frame[e->frameno]->regs[e->regno]
> +				: &vstate->frame[e->frameno]->stack[e->spi].spilled_ptr;
>  		if (reg->type != SCALAR_VALUE || reg == known_reg)
>  			continue;
>  		if ((reg->id & ~BPF_ADD_CONST) != (known_reg->id & ~BPF_ADD_CONST))

[...]
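One more illustrative sketch, since __collect_linked_regs() itself is
outside the quoted context: judging by the header comment ("reset id
for leftover Rs") and the call sites above, I'd expect it to be shaped
roughly like this (names, parameter order and MAX_LINKED_REGS are my
guesses, see the actual patch for the real thing):

static void __collect_linked_regs(struct linked_regs *reg_set, struct bpf_reg_state *reg,
                                  u32 id, u32 frameno, u32 spi_or_reg, bool is_reg)
{
        struct linked_reg *e;

        /* only scalars that carry the id being collected are linked */
        if (reg->type != SCALAR_VALUE || (reg->id & ~BPF_ADD_CONST) != id)
                return;

        if (reg_set->cnt < MAX_LINKED_REGS) {
                e = &reg_set->entries[reg_set->cnt++];
                e->frameno = frameno;
                e->is_reg = is_reg;
                e->regno = spi_or_reg; /* union with e->spi */
        } else {
                /* doesn't fit into the 6-entry vector: reset the id so
                 * that no range is propagated to/from this register
                 */
                reg->id = 0;
        }
}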