Converting the socket cookie selftest to sk storage (BPF_MAP_TYPE_SK_STORAGE) lets us test that both BPF_PROG_TYPE_CGROUP_SOCK_ADDR and BPF_PROG_TYPE_SOCK_OPS programs can access the underlying bpf_sock. Cc: Martin Lau <kafai@xxxxxx> Signed-off-by: Stanislav Fomichev <sdf@xxxxxxxxxx> --- .../selftests/bpf/progs/socket_cookie_prog.c | 46 ++++++++++++------- .../selftests/bpf/test_socket_cookie.c | 24 ++++------ 2 files changed, 38 insertions(+), 32 deletions(-) diff --git a/tools/testing/selftests/bpf/progs/socket_cookie_prog.c b/tools/testing/selftests/bpf/progs/socket_cookie_prog.c index 9ff8ac4b0bf6..0db15c3210ad 100644 --- a/tools/testing/selftests/bpf/progs/socket_cookie_prog.c +++ b/tools/testing/selftests/bpf/progs/socket_cookie_prog.c @@ -7,25 +7,35 @@ #include "bpf_helpers.h" #include "bpf_endian.h" +struct socket_cookie { + __u64 cookie_key; + __u32 cookie_value; +}; + struct bpf_map_def SEC("maps") socket_cookies = { - .type = BPF_MAP_TYPE_HASH, - .key_size = sizeof(__u64), - .value_size = sizeof(__u32), - .max_entries = 1 << 8, + .type = BPF_MAP_TYPE_SK_STORAGE, + .key_size = sizeof(int), + .value_size = sizeof(struct socket_cookie), + .map_flags = BPF_F_NO_PREALLOC, }; +BPF_ANNOTATE_KV_PAIR(socket_cookies, int, struct socket_cookie); + SEC("cgroup/connect6") int set_cookie(struct bpf_sock_addr *ctx) { - __u32 cookie_value = 0xFF; - __u64 cookie_key; + struct socket_cookie *p; if (ctx->family != AF_INET6 || ctx->user_family != AF_INET6) return 1; - cookie_key = bpf_get_socket_cookie(ctx); - if (bpf_map_update_elem(&socket_cookies, &cookie_key, &cookie_value, 0)) - return 0; + p = bpf_sk_storage_get(&socket_cookies, ctx->sk, 0, + BPF_SK_STORAGE_GET_F_CREATE); + if (!p) + return 1; + + p->cookie_value = 0xFF; + p->cookie_key = bpf_get_socket_cookie(ctx); return 1; } @@ -33,9 +43,8 @@ int set_cookie(struct bpf_sock_addr *ctx) SEC("sockops") int update_cookie(struct bpf_sock_ops *ctx) { - __u32 new_cookie_value; - __u32 *cookie_value; - __u64 cookie_key; + struct bpf_sock *sk; + struct socket_cookie *p; if (ctx->family != AF_INET6) return 1; 
@@ -43,14 +52,17 @@ int update_cookie(struct bpf_sock_ops *ctx) if (ctx->op != BPF_SOCK_OPS_TCP_CONNECT_CB) return 1; - cookie_key = bpf_get_socket_cookie(ctx); + if (!ctx->sk) + return 1; + + p = bpf_sk_storage_get(&socket_cookies, ctx->sk, 0, 0); + if (!p) + return 1; - cookie_value = bpf_map_lookup_elem(&socket_cookies, &cookie_key); - if (!cookie_value) + if (p->cookie_key != bpf_get_socket_cookie(ctx)) return 1; - new_cookie_value = (ctx->local_port << 8) | *cookie_value; - bpf_map_update_elem(&socket_cookies, &cookie_key, &new_cookie_value, 0); + p->cookie_value = (ctx->local_port << 8) | p->cookie_value; return 1; } diff --git a/tools/testing/selftests/bpf/test_socket_cookie.c b/tools/testing/selftests/bpf/test_socket_cookie.c index cac8ee57a013..15653b0e26eb 100644 --- a/tools/testing/selftests/bpf/test_socket_cookie.c +++ b/tools/testing/selftests/bpf/test_socket_cookie.c @@ -18,6 +18,11 @@ #define CG_PATH "/foo" #define SOCKET_COOKIE_PROG "./socket_cookie_prog.o" +struct socket_cookie { + __u64 cookie_key; + __u32 cookie_value; +}; + static int start_server(void) { struct sockaddr_in6 addr; @@ -89,8 +94,7 @@ static int validate_map(struct bpf_map *map, int client_fd) __u32 cookie_expected_value; struct sockaddr_in6 addr; socklen_t len = sizeof(addr); - __u32 cookie_value; - __u64 cookie_key; + struct socket_cookie val; int err = 0; int map_fd; @@ -101,17 +105,7 @@ static int validate_map(struct bpf_map *map, int client_fd) map_fd = bpf_map__fd(map); - err = bpf_map_get_next_key(map_fd, NULL, &cookie_key); - if (err) { - log_err("Can't get cookie key from map"); - goto out; - } - - err = bpf_map_lookup_elem(map_fd, &cookie_key, &cookie_value); - if (err) { - log_err("Can't get cookie value from map"); - goto out; - } + err = bpf_map_lookup_elem(map_fd, &client_fd, &val); err = getsockname(client_fd, (struct sockaddr *)&addr, &len); if (err) { @@ -120,8 +114,8 @@ static int validate_map(struct bpf_map *map, int client_fd) } cookie_expected_value = 
(ntohs(addr.sin6_port) << 8) | 0xFF; - if (cookie_value != cookie_expected_value) { - log_err("Unexpected value in map: %x != %x", cookie_value, + if (val.cookie_value != cookie_expected_value) { + log_err("Unexpected value in map: %x != %x", val.cookie_value, cookie_expected_value); goto err; } -- 2.22.0.rc2.383.gf4fbbf30c2-goog