[PATCH 5.10 10/95] bpf: Add helper macro bpf_for_each_reg_in_vstate

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



From: Kumar Kartikeya Dwivedi <memxor@xxxxxxxxx>

[ Upstream commit b239da34203f49c40b5d656220c39647c3ff0b3c ]

For a lot of use cases in future patches, we will want to modify the
state of registers part of some same 'group' (e.g. same ref_obj_id). It
won't just be limited to releasing reference state, but setting a type
flag dynamically based on certain actions, etc.

Hence, we need a way to easily pass a callback to the function that
iterates over all registers in current bpf_verifier_state in all frames
upto (and including) the curframe.

While in C++ we would be able to easily use a lambda to pass state and
the callback together, sadly we aren't using C++ in the kernel. The next
best thing to avoid defining a function for each case seems like
statement expressions in GNU C. The kernel already uses them heavily,
hence they can passed to the macro in the style of a lambda. The
statement expression will then be substituted in the for loop bodies.

Variables __state and __reg are set to current bpf_func_state and reg
for each invocation of the expression inside the passed in verifier
state.

Then, convert mark_ptr_or_null_regs, clear_all_pkt_pointers,
release_reference, find_good_pkt_pointers, find_equal_scalars to
use bpf_for_each_reg_in_vstate.

Signed-off-by: Kumar Kartikeya Dwivedi <memxor@xxxxxxxxx>
Link: https://lore.kernel.org/r/20220904204145.3089-16-memxor@xxxxxxxxx
Signed-off-by: Alexei Starovoitov <ast@xxxxxxxxxx>
Stable-dep-of: f1db20814af5 ("bpf: Fix wrong reg type conversion in release_reference()")
Signed-off-by: Sasha Levin <sashal@xxxxxxxxxx>
---
 include/linux/bpf_verifier.h |  21 ++++++
 kernel/bpf/verifier.c        | 135 ++++++++---------------------------
 2 files changed, 49 insertions(+), 107 deletions(-)

diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
index f49165f9229c..4d37c69e76b1 100644
--- a/include/linux/bpf_verifier.h
+++ b/include/linux/bpf_verifier.h
@@ -290,6 +290,27 @@ struct bpf_verifier_state {
 	     iter < frame->allocated_stack / BPF_REG_SIZE;		\
 	     iter++, reg = bpf_get_spilled_reg(iter, frame))
 
+/* Invoke __expr over regsiters in __vst, setting __state and __reg */
+#define bpf_for_each_reg_in_vstate(__vst, __state, __reg, __expr)   \
+	({                                                               \
+		struct bpf_verifier_state *___vstate = __vst;            \
+		int ___i, ___j;                                          \
+		for (___i = 0; ___i <= ___vstate->curframe; ___i++) {    \
+			struct bpf_reg_state *___regs;                   \
+			__state = ___vstate->frame[___i];                \
+			___regs = __state->regs;                         \
+			for (___j = 0; ___j < MAX_BPF_REG; ___j++) {     \
+				__reg = &___regs[___j];                  \
+				(void)(__expr);                          \
+			}                                                \
+			bpf_for_each_spilled_reg(___j, __state, __reg) { \
+				if (!__reg)                              \
+					continue;                        \
+				(void)(__expr);                          \
+			}                                                \
+		}                                                        \
+	})
+
 /* linked list of verifier states used to prune search */
 struct bpf_verifier_state_list {
 	struct bpf_verifier_state state;
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 510a54471f13..3a0f288f538c 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -4993,31 +4993,15 @@ static int check_func_proto(const struct bpf_func_proto *fn, int func_id)
 /* Packet data might have moved, any old PTR_TO_PACKET[_META,_END]
  * are now invalid, so turn them into unknown SCALAR_VALUE.
  */
-static void __clear_all_pkt_pointers(struct bpf_verifier_env *env,
-				     struct bpf_func_state *state)
+static void clear_all_pkt_pointers(struct bpf_verifier_env *env)
 {
-	struct bpf_reg_state *regs = state->regs, *reg;
-	int i;
-
-	for (i = 0; i < MAX_BPF_REG; i++)
-		if (reg_is_pkt_pointer_any(&regs[i]))
-			mark_reg_unknown(env, regs, i);
+	struct bpf_func_state *state;
+	struct bpf_reg_state *reg;
 
-	bpf_for_each_spilled_reg(i, state, reg) {
-		if (!reg)
-			continue;
+	bpf_for_each_reg_in_vstate(env->cur_state, state, reg, ({
 		if (reg_is_pkt_pointer_any(reg))
 			__mark_reg_unknown(env, reg);
-	}
-}
-
-static void clear_all_pkt_pointers(struct bpf_verifier_env *env)
-{
-	struct bpf_verifier_state *vstate = env->cur_state;
-	int i;
-
-	for (i = 0; i <= vstate->curframe; i++)
-		__clear_all_pkt_pointers(env, vstate->frame[i]);
+	}));
 }
 
 enum {
@@ -5046,41 +5030,24 @@ static void mark_pkt_end(struct bpf_verifier_state *vstate, int regn, bool range
 		reg->range = AT_PKT_END;
 }
 
-static void release_reg_references(struct bpf_verifier_env *env,
-				   struct bpf_func_state *state,
-				   int ref_obj_id)
-{
-	struct bpf_reg_state *regs = state->regs, *reg;
-	int i;
-
-	for (i = 0; i < MAX_BPF_REG; i++)
-		if (regs[i].ref_obj_id == ref_obj_id)
-			mark_reg_unknown(env, regs, i);
-
-	bpf_for_each_spilled_reg(i, state, reg) {
-		if (!reg)
-			continue;
-		if (reg->ref_obj_id == ref_obj_id)
-			__mark_reg_unknown(env, reg);
-	}
-}
-
 /* The pointer with the specified id has released its reference to kernel
  * resources. Identify all copies of the same pointer and clear the reference.
  */
 static int release_reference(struct bpf_verifier_env *env,
 			     int ref_obj_id)
 {
-	struct bpf_verifier_state *vstate = env->cur_state;
+	struct bpf_func_state *state;
+	struct bpf_reg_state *reg;
 	int err;
-	int i;
 
 	err = release_reference_state(cur_func(env), ref_obj_id);
 	if (err)
 		return err;
 
-	for (i = 0; i <= vstate->curframe; i++)
-		release_reg_references(env, vstate->frame[i], ref_obj_id);
+	bpf_for_each_reg_in_vstate(env->cur_state, state, reg, ({
+		if (reg->ref_obj_id == ref_obj_id)
+			__mark_reg_unknown(env, reg);
+	}));
 
 	return 0;
 }
@@ -7219,34 +7186,14 @@ static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)
 	return 0;
 }
 
-static void __find_good_pkt_pointers(struct bpf_func_state *state,
-				     struct bpf_reg_state *dst_reg,
-				     enum bpf_reg_type type, int new_range)
-{
-	struct bpf_reg_state *reg;
-	int i;
-
-	for (i = 0; i < MAX_BPF_REG; i++) {
-		reg = &state->regs[i];
-		if (reg->type == type && reg->id == dst_reg->id)
-			/* keep the maximum range already checked */
-			reg->range = max(reg->range, new_range);
-	}
-
-	bpf_for_each_spilled_reg(i, state, reg) {
-		if (!reg)
-			continue;
-		if (reg->type == type && reg->id == dst_reg->id)
-			reg->range = max(reg->range, new_range);
-	}
-}
-
 static void find_good_pkt_pointers(struct bpf_verifier_state *vstate,
 				   struct bpf_reg_state *dst_reg,
 				   enum bpf_reg_type type,
 				   bool range_right_open)
 {
-	int new_range, i;
+	struct bpf_func_state *state;
+	struct bpf_reg_state *reg;
+	int new_range;
 
 	if (dst_reg->off < 0 ||
 	    (dst_reg->off == 0 && range_right_open))
@@ -7311,9 +7258,11 @@ static void find_good_pkt_pointers(struct bpf_verifier_state *vstate,
 	 * the range won't allow anything.
 	 * dst_reg->off is known < MAX_PACKET_OFF, therefore it fits in a u16.
 	 */
-	for (i = 0; i <= vstate->curframe; i++)
-		__find_good_pkt_pointers(vstate->frame[i], dst_reg, type,
-					 new_range);
+	bpf_for_each_reg_in_vstate(vstate, state, reg, ({
+		if (reg->type == type && reg->id == dst_reg->id)
+			/* keep the maximum range already checked */
+			reg->range = max(reg->range, new_range);
+	}));
 }
 
 static int is_branch32_taken(struct bpf_reg_state *reg, u32 val, u8 opcode)
@@ -7826,7 +7775,7 @@ static void mark_ptr_or_null_reg(struct bpf_func_state *state,
 			reg->ref_obj_id = 0;
 		} else if (!reg_may_point_to_spin_lock(reg)) {
 			/* For not-NULL ptr, reg->ref_obj_id will be reset
-			 * in release_reg_references().
+			 * in release_reference().
 			 *
 			 * reg->id is still used by spin_lock ptr. Other
 			 * than spin_lock ptr type, reg->id can be reset.
@@ -7836,22 +7785,6 @@ static void mark_ptr_or_null_reg(struct bpf_func_state *state,
 	}
 }
 
-static void __mark_ptr_or_null_regs(struct bpf_func_state *state, u32 id,
-				    bool is_null)
-{
-	struct bpf_reg_state *reg;
-	int i;
-
-	for (i = 0; i < MAX_BPF_REG; i++)
-		mark_ptr_or_null_reg(state, &state->regs[i], id, is_null);
-
-	bpf_for_each_spilled_reg(i, state, reg) {
-		if (!reg)
-			continue;
-		mark_ptr_or_null_reg(state, reg, id, is_null);
-	}
-}
-
 /* The logic is similar to find_good_pkt_pointers(), both could eventually
  * be folded together at some point.
  */
@@ -7859,10 +7792,9 @@ static void mark_ptr_or_null_regs(struct bpf_verifier_state *vstate, u32 regno,
 				  bool is_null)
 {
 	struct bpf_func_state *state = vstate->frame[vstate->curframe];
-	struct bpf_reg_state *regs = state->regs;
+	struct bpf_reg_state *regs = state->regs, *reg;
 	u32 ref_obj_id = regs[regno].ref_obj_id;
 	u32 id = regs[regno].id;
-	int i;
 
 	if (ref_obj_id && ref_obj_id == id && is_null)
 		/* regs[regno] is in the " == NULL" branch.
@@ -7871,8 +7803,9 @@ static void mark_ptr_or_null_regs(struct bpf_verifier_state *vstate, u32 regno,
 		 */
 		WARN_ON_ONCE(release_reference_state(state, id));
 
-	for (i = 0; i <= vstate->curframe; i++)
-		__mark_ptr_or_null_regs(vstate->frame[i], id, is_null);
+	bpf_for_each_reg_in_vstate(vstate, state, reg, ({
+		mark_ptr_or_null_reg(state, reg, id, is_null);
+	}));
 }
 
 static bool try_match_pkt_pointers(const struct bpf_insn *insn,
@@ -7985,23 +7918,11 @@ static void find_equal_scalars(struct bpf_verifier_state *vstate,
 {
 	struct bpf_func_state *state;
 	struct bpf_reg_state *reg;
-	int i, j;
 
-	for (i = 0; i <= vstate->curframe; i++) {
-		state = vstate->frame[i];
-		for (j = 0; j < MAX_BPF_REG; j++) {
-			reg = &state->regs[j];
-			if (reg->type == SCALAR_VALUE && reg->id == known_reg->id)
-				*reg = *known_reg;
-		}
-
-		bpf_for_each_spilled_reg(j, state, reg) {
-			if (!reg)
-				continue;
-			if (reg->type == SCALAR_VALUE && reg->id == known_reg->id)
-				*reg = *known_reg;
-		}
-	}
+	bpf_for_each_reg_in_vstate(vstate, state, reg, ({
+		if (reg->type == SCALAR_VALUE && reg->id == known_reg->id)
+			*reg = *known_reg;
+	}));
 }
 
 static int check_cond_jmp_op(struct bpf_verifier_env *env,
-- 
2.35.1






[Index of Archives]     [Linux Kernel]     [Kernel Development Newbies]     [Linux USB Devel]     [Video for Linux]     [Linux Audio Users]     [Yosemite Hiking]     [Linux Kernel]     [Linux SCSI]

  Powered by Linux