Propagate nullness information for branches of register to register
equality compare instructions. The following rules are used:
- suppose register A may be null
- suppose register B is not null
- for JNE A, B, ... - A is not null in the false branch
- for JEQ A, B, ... - A is not null in the true branch

E.g. for a program like the one below:

  r6 = skb->sk;
  r7 = sk_fullsock(r6);
  r0 = sk_fullsock(r6);
  if (r0 == 0) return 0;    (a)
  if (r0 != r7) return 0;   (b)
  *r7->type;                (c)
  return 0;

It is safe to dereference r7 at point (c), because of (a) and (b).

Signed-off-by: Eduard Zingerman <eddyz87@xxxxxxxxx>
---
 kernel/bpf/verifier.c | 39 ++++++++++++++++++++++++++++++++++++++-
 1 file changed, 38 insertions(+), 1 deletion(-)

diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 2c1f8069f7b7..c48d34625bfd 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -472,6 +472,11 @@ static bool type_may_be_null(u32 type)
 	return type & PTR_MAYBE_NULL;
 }
 
+static bool type_is_pointer(enum bpf_reg_type type)
+{
+	return type != NOT_INIT && type != SCALAR_VALUE;
+}
+
 static bool is_acquire_function(enum bpf_func_id func_id,
 				const struct bpf_map *map)
 {
@@ -10046,6 +10051,7 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
 	struct bpf_verifier_state *other_branch;
 	struct bpf_reg_state *regs = this_branch->frame[this_branch->curframe]->regs;
 	struct bpf_reg_state *dst_reg, *other_branch_regs, *src_reg = NULL;
+	struct bpf_reg_state *eq_branch_regs;
 	u8 opcode = BPF_OP(insn->code);
 	bool is_jmp32;
 	int pred = -1;
@@ -10155,7 +10161,7 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
 	/* detect if we are comparing against a constant value so we can adjust
 	 * our min/max values for our dst register.
 	 * this is only legit if both are scalars (or pointers to the same
-	 * object, I suppose, but we don't support that right now), because
+	 * object, I suppose, see the next if block), because
 	 * otherwise the different base pointers mean the offsets aren't
 	 * comparable.
 	 */
@@ -10199,6 +10205,37 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
 					opcode, is_jmp32);
 	}
 
+	/* if one pointer register is compared to another pointer
+	 * register check if PTR_MAYBE_NULL could be lifted.
+	 * E.g. register A - maybe null
+	 *      register B - not null
+	 * for JNE A, B, ... - A is not null in the false branch;
+	 * for JEQ A, B, ... - A is not null in the true branch.
+	 */
+	if (!is_jmp32 &&
+	    BPF_SRC(insn->code) == BPF_X &&
+	    type_is_pointer(src_reg->type) && type_is_pointer(dst_reg->type) &&
+	    type_may_be_null(src_reg->type) != type_may_be_null(dst_reg->type)) {
+		eq_branch_regs = NULL;
+		switch (opcode) {
+		case BPF_JEQ:
+			eq_branch_regs = other_branch_regs;
+			break;
+		case BPF_JNE:
+			eq_branch_regs = regs;
+			break;
+		default:
+			/* do nothing */
+			break;
+		}
+		if (eq_branch_regs) {
+			if (type_may_be_null(src_reg->type))
+				mark_ptr_not_null_reg(&eq_branch_regs[insn->src_reg]);
+			else
+				mark_ptr_not_null_reg(&eq_branch_regs[insn->dst_reg]);
+		}
+	}
+
 	if (dst_reg->type == SCALAR_VALUE && dst_reg->id &&
 	    !WARN_ON_ONCE(dst_reg->id != other_branch_regs[insn->dst_reg].id)) {
 		find_equal_scalars(this_branch, dst_reg);
-- 
2.37.1