[PATCH] bpf: Support bpf shadow stack

To support a shadow stack, for each jited program, allocate the
shadow stack at the entry of the bpf program and free it at the
exit of the bpf program.
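
In C terms, each jited program is now bracketed roughly like the
sketch below (illustration only; jited_prog_shape is a made-up name,
and the real sequence is emitted by do_jit()):

  static u64 jited_prog_shape(struct bpf_prog *prog)
  {
  	void *frame = bpf_shadow_stack_alloc(prog);	/* in prologue */
  	u64 ret = 0;

  	/* program body runs here; BPF_REG_FP-relative accesses go
  	 * through 'frame' (cached in r9) instead of rbp
  	 */

  	bpf_shadow_stack_free(prog, frame);		/* in exit paths */
  	return ret;
  }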

This works for all bpf selftests, but it is expensive. To avoid the
runtime kmalloc, we could preallocate some space, e.g., percpu pages,
to be used for the stack. That should work for non-sleepable
programs. For sleepable programs, the current kmalloc/kfree may
still be acceptable since performance is not critical there.
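
For non-sleepable programs, the preallocation could look roughly like
the sketch below. All names here (BPF_SHADOW_STACK_SIZE,
bpf_shadow_stacks, the *_percpu helpers) are invented for
illustration; it assumes the program runs with migration disabled
(as bpf programs do) and that nested programs unwind alloc/free in
LIFO order, so a per-CPU bump allocator is enough:

  #define BPF_SHADOW_STACK_SIZE	(4 * PAGE_SIZE)

  struct bpf_shadow_stack_area {
  	char data[BPF_SHADOW_STACK_SIZE];
  	u32 used;
  };

  static DEFINE_PER_CPU(struct bpf_shadow_stack_area, bpf_shadow_stacks);

  static void *bpf_shadow_stack_alloc_percpu(struct bpf_prog *prog)
  {
  	struct bpf_shadow_stack_area *area = this_cpu_ptr(&bpf_shadow_stacks);
  	u32 size = round_up(prog->aux->stack_depth, 16);

  	if (!size || area->used + size > BPF_SHADOW_STACK_SIZE)
  		return NULL;
  	area->used += size;
  	/* the frame pointer points at the frame end; the bpf stack
  	 * grows down from there
  	 */
  	return area->data + area->used;
  }

  static void bpf_shadow_stack_free_percpu(struct bpf_prog *prog, void *frame)
  {
  	struct bpf_shadow_stack_area *area = this_cpu_ptr(&bpf_shadow_stacks);

  	if (frame)
  		area->used -= round_up(prog->aux->stack_depth, 16);
  }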

Signed-off-by: Yonghong Song <yonghong.song@xxxxxxxxx>
---
 arch/x86/net/bpf_jit_comp.c | 200 ++++++++++++++++++++++++++++++++----
 include/linux/bpf.h         |   3 +
 kernel/bpf/core.c           |  25 +++++
 3 files changed, 209 insertions(+), 19 deletions(-)
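
For review, the entry/exit shape of the jited image with this patch
is roughly the following (simplified; exact encodings and the
tail-call/exception-boundary variants are in the code):

  	push rbp
  	mov  rbp, rsp
  	...				# existing prologue, callee regs
  	push r9				# new: save r9
  	call bpf_shadow_stack_alloc	# r1-r5 preserved around the call
  	mov  r9, rax
  	push r9				# new: shadow frame ptr kept on stack
  	...				# body: insns touching BPF_REG_FP
  					# peek the frame ptr into r9 and
  					# use r9 in place of rbp
  	call bpf_shadow_stack_free	# r0-r5 preserved around the call
  	pop  r9				# drop the shadow frame ptr slot
  	pop  r9				# restore saved r9
  	...				# existing epilogue
  	leave
  	ret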

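The shadow frame returned by bpf_shadow_stack_alloc() points one past
the end of the allocation, since BPF_REG_FP-relative offsets are
negative (sketch):

  	kmalloc'ed area                  returned frame pointer
  	|                                               |
  	v                                               v
  	+-----------------------------------------------+
  	|<--------- round_up(stack_depth, 16) --------->|
  	+-----------------------------------------------+
  	       <-- bpf stack grows down from the frame pointer
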
diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
index 673fdbd765d7..653792af3b11 100644
--- a/arch/x86/net/bpf_jit_comp.c
+++ b/arch/x86/net/bpf_jit_comp.c
@@ -267,7 +267,7 @@ struct jit_context {
 };
 
 /* Maximum number of bytes emitted while JITing one eBPF insn */
-#define BPF_MAX_INSN_SIZE	128
+#define BPF_MAX_INSN_SIZE	160
 #define BPF_INSN_SAFETY		64
 
 /* Number of bytes emit_patch() needs to generate instructions */
@@ -275,6 +275,14 @@ struct jit_context {
 /* Number of bytes that will be skipped on tailcall */
 #define X86_TAIL_CALL_OFFSET	(11 + ENDBR_INSN_SIZE)
 
+static void push_r9(u8 **pprog)
+{
+	u8 *prog = *pprog;
+
+	EMIT2(0x41, 0x51);   /* push r9 */
+	*pprog = prog;
+}
+
 static void push_r12(u8 **pprog)
 {
 	u8 *prog = *pprog;
@@ -298,6 +306,14 @@ static void push_callee_regs(u8 **pprog, bool *callee_regs_used)
 	*pprog = prog;
 }
 
+static void pop_r9(u8 **pprog)
+{
+	u8 *prog = *pprog;
+
+	EMIT2(0x41, 0x59);   /* pop r9 */
+	*pprog = prog;
+}
+
 static void pop_r12(u8 **pprog)
 {
 	u8 *prog = *pprog;
@@ -437,6 +453,7 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
 		 * first restore those callee-saved regs from stack, before
 		 * reusing the stack frame.
 		 */
+		pop_r9(&prog);
 		pop_callee_regs(&prog, all_callee_regs_used);
 		pop_r12(&prog);
 		/* Reset the stack frame. */
@@ -589,6 +606,9 @@ static void emit_return(u8 **pprog, u8 *ip)
 	*pprog = prog;
 }
 
+static int emit_shadow_stack_free(u8 **pprog, struct bpf_prog *bpf_prog,
+				  u8 *ip, u8 *temp);
+
 /*
  * Generate the following code:
  *
@@ -603,14 +623,14 @@ static void emit_return(u8 **pprog, u8 *ip)
  *   goto *(prog->bpf_func + prologue_size);
  * out:
  */
-static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog,
+static int emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog,
 					u8 **pprog, bool *callee_regs_used,
-					u32 stack_depth, u8 *ip,
+					u32 stack_depth, u8 *ip, u8 *temp,
 					struct jit_context *ctx)
 {
 	int tcc_off = -4 - round_up(stack_depth, 8);
 	u8 *prog = *pprog, *start = *pprog;
-	int offset;
+	int err, offset;
 
 	/*
 	 * rdi - pointer to ctx
@@ -626,8 +646,8 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog,
 	EMIT3(0x39, 0x56,                         /* cmp dword ptr [rsi + 16], edx */
 	      offsetof(struct bpf_array, map.max_entries));
 
-	offset = ctx->tail_call_indirect_label - (prog + 2 - start);
-	EMIT2(X86_JBE, offset);                   /* jbe out */
+	offset = ctx->tail_call_indirect_label - (prog + 6 - start);
+	EMIT2_off32(0x0f, 0x86, offset);  /* jbe out */
 
 	/*
 	 * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT)
@@ -654,10 +674,16 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog,
 	offset = ctx->tail_call_indirect_label - (prog + 2 - start);
 	EMIT2(X86_JE, offset);                    /* je out */
 
+	err = emit_shadow_stack_free(&prog, bpf_prog, ip, temp);
+	if (err)
+		return err;
+
+	pop_r9(&prog);
 	if (bpf_prog->aux->exception_boundary) {
 		pop_callee_regs(&prog, all_callee_regs_used);
 		pop_r12(&prog);
 	} else {
+		pop_r9(&prog);
 		pop_callee_regs(&prog, callee_regs_used);
 		if (bpf_arena_get_kern_vm_start(bpf_prog->aux->arena))
 			pop_r12(&prog);
@@ -683,17 +709,18 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog,
 	/* out: */
 	ctx->tail_call_indirect_label = prog - start;
 	*pprog = prog;
+	return 0;
 }
 
-static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog,
+static int emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog,
 				      struct bpf_jit_poke_descriptor *poke,
-				      u8 **pprog, u8 *ip,
+				      u8 **pprog, u8 *ip, u8 *temp,
 				      bool *callee_regs_used, u32 stack_depth,
 				      struct jit_context *ctx)
 {
 	int tcc_off = -4 - round_up(stack_depth, 8);
 	u8 *prog = *pprog, *start = *pprog;
-	int offset;
+	int err, offset;
 
 	/*
 	 * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT)
@@ -715,10 +742,16 @@ static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog,
 	emit_jump(&prog, (u8 *)poke->tailcall_target + X86_PATCH_SIZE,
 		  poke->tailcall_bypass);
 
+	err = emit_shadow_stack_free(&prog, bpf_prog, ip, temp);
+	if (err)
+		return err;
+
+	pop_r9(&prog);
 	if (bpf_prog->aux->exception_boundary) {
 		pop_callee_regs(&prog, all_callee_regs_used);
 		pop_r12(&prog);
 	} else {
+		pop_r9(&prog);
 		pop_callee_regs(&prog, callee_regs_used);
 		if (bpf_arena_get_kern_vm_start(bpf_prog->aux->arena))
 			pop_r12(&prog);
@@ -734,6 +767,7 @@ static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog,
 	ctx->tail_call_direct_label = prog - start;
 
 	*pprog = prog;
+	return 0;
 }
 
 static void bpf_tail_call_direct_fixup(struct bpf_prog *prog)
@@ -1311,6 +1345,103 @@ static void emit_shiftx(u8 **pprog, u32 dst_reg, u8 src_reg, bool is64, u8 op)
 	*pprog = prog;
 }
 
+/* Emit a call to bpf_shadow_stack_alloc(), preserving bpf regs r1-r5. */
+static int emit_shadow_stack_alloc(u8 **pprog, struct bpf_prog *bpf_prog,
+				   u8 *image, u8 *temp)
+{
+	int offs;
+	u8 *func;
+
+	/* save parameters to preserve original bpf arguments. */
+	emit_mov_reg(pprog, true, X86_REG_R9, BPF_REG_1);
+	push_r9(pprog);
+	emit_mov_reg(pprog, true, X86_REG_R9, BPF_REG_2);
+	push_r9(pprog);
+	emit_mov_reg(pprog, true, X86_REG_R9, BPF_REG_3);
+	push_r9(pprog);
+	emit_mov_reg(pprog, true, X86_REG_R9, BPF_REG_4);
+	push_r9(pprog);
+	emit_mov_reg(pprog, true, X86_REG_R9, BPF_REG_5);
+	push_r9(pprog);
+
+	emit_mov_imm64(pprog, BPF_REG_1, (long) bpf_prog >> 32, (u32) (long) bpf_prog);
+	func = (u8 *)bpf_shadow_stack_alloc;
+	offs = *pprog - temp;
+	offs += x86_call_depth_emit_accounting(pprog, func);
+	if (emit_call(pprog, func, image + offs))
+		return -EINVAL;
+
+	pop_r9(pprog);
+	emit_mov_reg(pprog, true, BPF_REG_5, X86_REG_R9);
+	pop_r9(pprog);
+	emit_mov_reg(pprog, true, BPF_REG_4, X86_REG_R9);
+	pop_r9(pprog);
+	emit_mov_reg(pprog, true, BPF_REG_3, X86_REG_R9);
+	pop_r9(pprog);
+	emit_mov_reg(pprog, true, BPF_REG_2, X86_REG_R9);
+	pop_r9(pprog);
+	emit_mov_reg(pprog, true, BPF_REG_1, X86_REG_R9);
+
+	/* Save the shadow frame pointer (returned in bpf r0) to the
+	 * stack so it can be retrieved later.
+	 */
+	emit_mov_reg(pprog, true, X86_REG_R9, BPF_REG_0);
+	push_r9(pprog);
+
+	return 0;
+}
+
+/* Emit a call to bpf_shadow_stack_free(), preserving bpf regs r0-r5. */
+static int emit_shadow_stack_free(u8 **pprog, struct bpf_prog *bpf_prog,
+				  u8 *ip, u8 *temp)
+{
+	int offs;
+	u8 *func;
+
+	/* peek the shadow frame pointer (stack top) into r9 */
+	pop_r9(pprog);
+	push_r9(pprog);
+	emit_mov_reg(pprog, true, AUX_REG, X86_REG_R9);
+
+	/* save bpf regs r0-r5 to preserve their original values */
+	emit_mov_reg(pprog, true, X86_REG_R9, BPF_REG_0);
+	push_r9(pprog);
+	emit_mov_reg(pprog, true, X86_REG_R9, BPF_REG_1);
+	push_r9(pprog);
+	emit_mov_reg(pprog, true, X86_REG_R9, BPF_REG_2);
+	push_r9(pprog);
+	emit_mov_reg(pprog, true, X86_REG_R9, BPF_REG_3);
+	push_r9(pprog);
+	emit_mov_reg(pprog, true, X86_REG_R9, BPF_REG_4);
+	push_r9(pprog);
+	emit_mov_reg(pprog, true, X86_REG_R9, BPF_REG_5);
+	push_r9(pprog);
+
+	emit_mov_imm64(pprog, BPF_REG_1, (long) bpf_prog >> 32, (u32) (long) bpf_prog);
+	emit_mov_reg(pprog, true, BPF_REG_2, AUX_REG);
+	func = (u8 *)bpf_shadow_stack_free;
+	offs = *pprog - temp;
+	offs += x86_call_depth_emit_accounting(pprog, func);
+	if (emit_call(pprog, func, ip + offs))
+		return -EINVAL;
+
+	/* restore bpf regs r0-r5 to their original values */
+	pop_r9(pprog);
+	emit_mov_reg(pprog, true, BPF_REG_5, X86_REG_R9);
+	pop_r9(pprog);
+	emit_mov_reg(pprog, true, BPF_REG_4, X86_REG_R9);
+	pop_r9(pprog);
+	emit_mov_reg(pprog, true, BPF_REG_3, X86_REG_R9);
+	pop_r9(pprog);
+	emit_mov_reg(pprog, true, BPF_REG_2, X86_REG_R9);
+	pop_r9(pprog);
+	emit_mov_reg(pprog, true, BPF_REG_1, X86_REG_R9);
+	pop_r9(pprog);
+	emit_mov_reg(pprog, true, BPF_REG_0, X86_REG_R9);
+
+	return 0;
+}
+
 #define INSN_SZ_DIFF (((addrs[i] - addrs[i - 1]) - (prog - temp)))
 
 /* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */
@@ -1328,11 +1459,14 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image
 	bool seen_exit = false;
 	u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
 	u64 arena_vm_start, user_vm_start;
-	int i, excnt = 0;
+	int i, excnt = 0, stack_depth;
 	int ilen, proglen = 0;
 	u8 *prog = temp;
 	int err;
 
+	/* the bpf stack lives in the shadow frame, so reserve no rbp space */
+	stack_depth = 0;
+
 	arena_vm_start = bpf_arena_get_kern_vm_start(bpf_prog->aux->arena);
 	user_vm_start = bpf_arena_get_user_vm_start(bpf_prog->aux->arena);
 
@@ -1342,7 +1476,7 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image
 	/* tail call's presence in current prog implies it is reachable */
 	tail_call_reachable |= tail_call_seen;
 
-	emit_prologue(&prog, bpf_prog->aux->stack_depth,
+	emit_prologue(&prog, stack_depth,
 		      bpf_prog_was_classic(bpf_prog), tail_call_reachable,
 		      bpf_is_subprog(bpf_prog), bpf_prog->aux->exception_cb);
 	/* Exception callback will clobber callee regs for its own use, and
@@ -1359,11 +1493,17 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image
 		if (arena_vm_start)
 			push_r12(&prog);
 		push_callee_regs(&prog, callee_regs_used);
+		/* save r9 */
+		push_r9(&prog);
 	}
 	if (arena_vm_start)
 		emit_mov_imm64(&prog, X86_REG_R12,
 			       arena_vm_start >> 32, (u32) arena_vm_start);
 
+	err = emit_shadow_stack_alloc(&prog, bpf_prog, image, temp);
+	if (err)
+		return err;
+
 	ilen = prog - temp;
 	if (rw_image)
 		memcpy(rw_image + proglen, temp, ilen);
@@ -1383,6 +1524,18 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image
 		u8 *func;
 		int nops;
 
+		if (src_reg == BPF_REG_FP) {
+			pop_r9(&prog);	/* peek shadow frame ptr into r9 */
+			push_r9(&prog);
+			src_reg = X86_REG_R9;
+		}
+
+		if (dst_reg == BPF_REG_FP) {
+			pop_r9(&prog);
+			push_r9(&prog);
+			dst_reg = X86_REG_R9;
+		}
+
 		switch (insn->code) {
 			/* ALU */
 		case BPF_ALU | BPF_ADD | BPF_X:
@@ -2045,7 +2198,7 @@ st:			if (is_imm8(insn->off))
 
 			func = (u8 *) __bpf_call_base + imm32;
 			if (tail_call_reachable) {
-				RESTORE_TAIL_CALL_CNT(bpf_prog->aux->stack_depth);
+				RESTORE_TAIL_CALL_CNT(stack_depth);
 				if (!imm32)
 					return -EINVAL;
 				offs = 7 + x86_call_depth_emit_accounting(&prog, func);
@@ -2061,19 +2214,21 @@ st:			if (is_imm8(insn->off))
 
 		case BPF_JMP | BPF_TAIL_CALL:
 			if (imm32)
-				emit_bpf_tail_call_direct(bpf_prog,
-							  &bpf_prog->aux->poke_tab[imm32 - 1],
-							  &prog, image + addrs[i - 1],
-							  callee_regs_used,
-							  bpf_prog->aux->stack_depth,
-							  ctx);
+				err = emit_bpf_tail_call_direct(bpf_prog,
+								&bpf_prog->aux->poke_tab[imm32 - 1],
+								&prog, image + addrs[i - 1], temp,
+								callee_regs_used,
+								stack_depth,
+								ctx);
 			else
-				emit_bpf_tail_call_indirect(bpf_prog,
-							    &prog,
-							    callee_regs_used,
-							    bpf_prog->aux->stack_depth,
-							    image + addrs[i - 1],
-							    ctx);
+				err = emit_bpf_tail_call_indirect(bpf_prog,
+								  &prog,
+								  callee_regs_used,
+								  stack_depth,
+								  image + addrs[i - 1], temp,
+								  ctx);
+			if (err)
+				return err;
 			break;
 
 			/* cond jump */
@@ -2322,10 +2477,17 @@ st:			if (is_imm8(insn->off))
 			seen_exit = true;
 			/* Update cleanup_addr */
 			ctx->cleanup_addr = proglen;
+
+			err = emit_shadow_stack_free(&prog, bpf_prog, image + addrs[i - 1], temp);
+			if (err)
+				return err;
+
+			pop_r9(&prog);
 			if (bpf_prog->aux->exception_boundary) {
 				pop_callee_regs(&prog, all_callee_regs_used);
 				pop_r12(&prog);
 			} else {
+				pop_r9(&prog);
 				pop_callee_regs(&prog, callee_regs_used);
 				if (arena_vm_start)
 					pop_r12(&prog);
@@ -2347,7 +2509,7 @@ st:			if (is_imm8(insn->off))
 
 		ilen = prog - temp;
 		if (ilen > BPF_MAX_INSN_SIZE) {
-			pr_err("bpf_jit: fatal insn size error\n");
+			pr_err("bpf_jit: fatal insn size error: %d\n", ilen);
 			return -EFAULT;
 		}
 
diff --git a/include/linux/bpf.h b/include/linux/bpf.h
index 5034c1b4ded7..b0f9ea882253 100644
--- a/include/linux/bpf.h
+++ b/include/linux/bpf.h
@@ -1133,6 +1133,9 @@ typedef void (*bpf_trampoline_exit_t)(struct bpf_prog *prog, u64 start,
 bpf_trampoline_enter_t bpf_trampoline_enter(const struct bpf_prog *prog);
 bpf_trampoline_exit_t bpf_trampoline_exit(const struct bpf_prog *prog);
 
+void * notrace bpf_shadow_stack_alloc(struct bpf_prog *prog);
+void notrace bpf_shadow_stack_free(struct bpf_prog *prog, void *shadow_frame);
+
 struct bpf_ksym {
 	unsigned long		 start;
 	unsigned long		 end;
diff --git a/kernel/bpf/core.c b/kernel/bpf/core.c
index a41718eaeefe..831841b5af7f 100644
--- a/kernel/bpf/core.c
+++ b/kernel/bpf/core.c
@@ -2434,6 +2434,31 @@ struct bpf_prog *bpf_prog_select_runtime(struct bpf_prog *fp, int *err)
 }
 EXPORT_SYMBOL_GPL(bpf_prog_select_runtime);
 
+void * notrace bpf_shadow_stack_alloc(struct bpf_prog *prog)
+{
+	int stack_depth = prog->aux->stack_depth;
+	void *shadow_stack;
+
+	if (!stack_depth)
+		return NULL;
+	shadow_stack = kmalloc(round_up(stack_depth, 16), __GFP_NORETRY);
+	if (!shadow_stack)
+		return NULL;
+	return shadow_stack + round_up(stack_depth, 16);
+}
+
+void notrace bpf_shadow_stack_free(struct bpf_prog *prog, void *shadow_frame)
+{
+	int stack_depth = prog->aux->stack_depth;
+	void *shadow_stack;
+
+	if (!shadow_frame)
+		return;
+
+	shadow_stack = shadow_frame - round_up(stack_depth, 16);
+	kfree(shadow_stack);
+}
+
 static unsigned int __bpf_prog_ret1(const void *ctx,
 				    const struct bpf_insn *insn)
 {
-- 
2.43.0