Re: [PATCH bpf-next 4/4] selftests/bpf: validate jit behaviour for tail calls

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



On Thu, Aug 8, 2024 at 6:05 PM Eduard Zingerman <eddyz87@xxxxxxxxx> wrote:
>
> A program calling sub-program which does a tail call.
> The idea is to verify instructions generated by jit for tail calls:
> - in program and sub-program prologues;
> - for subprogram call instruction;
> - for tail call itself.
>
> Signed-off-by: Eduard Zingerman <eddyz87@xxxxxxxxx>
> ---
>  .../selftests/bpf/prog_tests/verifier.c       |   2 +
>  .../bpf/progs/verifier_tailcall_jit.c         | 103 ++++++++++++++++++
>  2 files changed, 105 insertions(+)
>  create mode 100644 tools/testing/selftests/bpf/progs/verifier_tailcall_jit.c
>
> diff --git a/tools/testing/selftests/bpf/prog_tests/verifier.c b/tools/testing/selftests/bpf/prog_tests/verifier.c
> index f8f546eba488..cf3662dbd24f 100644
> --- a/tools/testing/selftests/bpf/prog_tests/verifier.c
> +++ b/tools/testing/selftests/bpf/prog_tests/verifier.c
> @@ -75,6 +75,7 @@
>  #include "verifier_stack_ptr.skel.h"
>  #include "verifier_subprog_precision.skel.h"
>  #include "verifier_subreg.skel.h"
> +#include "verifier_tailcall_jit.skel.h"
>  #include "verifier_typedef.skel.h"
>  #include "verifier_uninit.skel.h"
>  #include "verifier_unpriv.skel.h"
> @@ -198,6 +199,7 @@ void test_verifier_spin_lock(void)            { RUN(verifier_spin_lock); }
>  void test_verifier_stack_ptr(void)            { RUN(verifier_stack_ptr); }
>  void test_verifier_subprog_precision(void)    { RUN(verifier_subprog_precision); }
>  void test_verifier_subreg(void)               { RUN(verifier_subreg); }
> +void test_verifier_tailcall_jit(void)         { RUN(verifier_tailcall_jit); }
>  void test_verifier_typedef(void)              { RUN(verifier_typedef); }
>  void test_verifier_uninit(void)               { RUN(verifier_uninit); }
>  void test_verifier_unpriv(void)               { RUN(verifier_unpriv); }
> diff --git a/tools/testing/selftests/bpf/progs/verifier_tailcall_jit.c b/tools/testing/selftests/bpf/progs/verifier_tailcall_jit.c
> new file mode 100644
> index 000000000000..1a09c76d7be0
> --- /dev/null
> +++ b/tools/testing/selftests/bpf/progs/verifier_tailcall_jit.c
> @@ -0,0 +1,103 @@
> +// SPDX-License-Identifier: GPL-2.0
> +#include <linux/bpf.h>
> +#include <bpf/bpf_helpers.h>
> +#include "bpf_misc.h"
> +
> +int main(void);
> +
> +struct {
> +       __uint(type, BPF_MAP_TYPE_PROG_ARRAY);
> +       __uint(max_entries, 1);
> +       __uint(key_size, sizeof(__u32));
> +       __array(values, void (void));
> +} jmp_table SEC(".maps") = {
> +       .values = {
> +               [0] = (void *) &main,
> +       },
> +};
> +
> +__noinline __auxiliary
> +static __naked int sub(void)
> +{
> +       asm volatile (
> +       "r2 = %[jmp_table] ll;"
> +       "r3 = 0;"
> +       "call 12;"
> +       "exit;"
> +       :
> +       : __imm_addr(jmp_table)
> +       : __clobber_all);
> +}
> +
> +__success
> +/* program entry for main(), regular function prologue */
> +__jit_x86("    endbr64")
> +__jit_x86("    nopl    (%rax,%rax)")
> +__jit_x86("    xorq    %rax, %rax")
> +__jit_x86("    pushq   %rbp")
> +__jit_x86("    movq    %rsp, %rbp")

I'm a bit too lazy to fish it out of the code, so I'll just ask:
does matching of __jit_x86() strings behave in the same way as __msg()?
I.e., could there be unexpected lines that get skipped, as long
as we find a match for each __jit_x86() one?


Isn't that a bit counter-intuitive and potentially dangerous behavior
for checking disassembly? If my assumption is correct, maybe we should
add some sort of `__jit_x86("...")` placeholder to explicitly mark the
places where we allow some number of lines to be skipped, but otherwise
be strict and require matching line-by-line?

> +/* tail call prologue for program:
> + * - establish memory location for tail call counter at &rbp[-8];
> + * - spill tail_call_cnt_ptr at &rbp[-16];
> + * - expect tail call counter to be passed in rax;
> + * - for entry program rax is a raw counter, value < 33;
> + * - for tail called program rax is tail_call_cnt_ptr (value > 33).
> + */
> +__jit_x86("    endbr64")
> +__jit_x86("    cmpq    $0x21, %rax")
> +__jit_x86("    ja      L0")
> +__jit_x86("    pushq   %rax")
> +__jit_x86("    movq    %rsp, %rax")
> +__jit_x86("    jmp     L1")
> +__jit_x86("L0: pushq   %rax")                  /* rbp[-8]  = rax         */
> +__jit_x86("L1: pushq   %rax")                  /* rbp[-16] = rax         */
> +/* on subprogram call restore rax to be tail_call_cnt_ptr from rbp[-16]
> + * (cause original rax might be clobbered by this point)
> + */
> +__jit_x86("    movq    -0x10(%rbp), %rax")
> +__jit_x86("    callq   0x[0-9a-f]\\+")         /* call to sub()          */
> +__jit_x86("    xorl    %eax, %eax")
> +__jit_x86("    leave")
> +__jit_x86("    retq")
> +/* subprogram entry for sub(), regular function prologue */
> +__jit_x86("    endbr64")
> +__jit_x86("    nopl    (%rax,%rax)")
> +__jit_x86("    nopl    (%rax)")
> +__jit_x86("    pushq   %rbp")
> +__jit_x86("    movq    %rsp, %rbp")
> +/* tail call prologue for subprogram address of tail call counter
> + * stored at rbp[-16].
> + */
> +__jit_x86("    endbr64")
> +__jit_x86("    pushq   %rax")                  /* rbp[-8]  = rax          */
> +__jit_x86("    pushq   %rax")                  /* rbp[-16] = rax          */
> +__jit_x86("    movabsq $-0x[0-9a-f]\\+, %rsi") /* r2 = &jmp_table         */
> +__jit_x86("    xorl    %edx, %edx")            /* r3 = 0                  */
> +/* bpf_tail_call implementation:
> + * - load tail_call_cnt_ptr from rbp[-16];
> + * - if *tail_call_cnt_ptr < 33, increment it and jump to target;
> + * - otherwise do nothing.
> + */
> +__jit_x86("    movq    -0x10(%rbp), %rax")
> +__jit_x86("    cmpq    $0x21, (%rax)")
> +__jit_x86("    jae     L0")
> +__jit_x86("    nopl    (%rax,%rax)")
> +__jit_x86("    addq    $0x1, (%rax)")          /* *tail_call_cnt_ptr += 1 */
> +__jit_x86("    popq    %rax")
> +__jit_x86("    popq    %rax")
> +__jit_x86("    jmp     0x[0-9a-f]\\+")         /* jump to tail call tgt   */
> +__jit_x86("L0: leave")
> +__jit_x86("    retq")
> +SEC("tc")
> +__naked int main(void)
> +{
> +       asm volatile (
> +       "call %[sub];"
> +       "r0 = 0;"
> +       "exit;"
> +       :
> +       : __imm(sub)
> +       : __clobber_all);
> +}
> +
> +char __license[] SEC("license") = "GPL";
> --
> 2.45.2
>





[Index of Archives]     [Linux Samsung SoC]     [Linux Rockchip SoC]     [Linux Actions SoC]     [Linux for Synopsys ARC Processors]     [Linux NFS]     [Linux NILFS]     [Linux USB Devel]     [Video for Linux]     [Linux Audio Users]     [Yosemite News]     [Linux Kernel]     [Linux SCSI]


  Powered by Linux