In preparation for adding unix socket support to the bpf cgroup socket address hooks, start tracking the sockaddr length in the sockaddr tests which will be required when adding tests for unix sockets. --- tools/testing/selftests/bpf/test_sock_addr.c | 130 ++++++++++++------- 1 file changed, 85 insertions(+), 45 deletions(-) diff --git a/tools/testing/selftests/bpf/test_sock_addr.c b/tools/testing/selftests/bpf/test_sock_addr.c index 2c89674fc62c..6a618c8f477c 100644 --- a/tools/testing/selftests/bpf/test_sock_addr.c +++ b/tools/testing/selftests/bpf/test_sock_addr.c @@ -604,7 +604,7 @@ static struct sock_addr_test tests[] = { }; static int mk_sockaddr(int domain, const char *ip, unsigned short port, - struct sockaddr *addr, socklen_t addr_len) + struct sockaddr *addr, socklen_t *addr_len) { struct sockaddr_in6 *addr6; struct sockaddr_in *addr4; @@ -614,10 +614,10 @@ static int mk_sockaddr(int domain, const char *ip, unsigned short port, return -1; } - memset(addr, 0, addr_len); + memset(addr, 0, *addr_len); if (domain == AF_INET) { - if (addr_len < sizeof(struct sockaddr_in)) + if (*addr_len < sizeof(struct sockaddr_in)) return -1; addr4 = (struct sockaddr_in *)addr; addr4->sin_family = domain; @@ -626,8 +626,9 @@ static int mk_sockaddr(int domain, const char *ip, unsigned short port, log_err("Invalid IPv4: %s", ip); return -1; } + *addr_len = sizeof(struct sockaddr_in); } else if (domain == AF_INET6) { - if (addr_len < sizeof(struct sockaddr_in6)) + if (*addr_len < sizeof(struct sockaddr_in6)) return -1; addr6 = (struct sockaddr_in6 *)addr; addr6->sin6_family = domain; @@ -636,6 +637,7 @@ static int mk_sockaddr(int domain, const char *ip, unsigned short port, log_err("Invalid IPv6: %s", ip); return -1; } + *addr_len = sizeof(struct sockaddr_in6); } return 0; @@ -749,6 +751,7 @@ static int sendmsg4_rw_asm_prog_load(const struct sock_addr_test *test) { struct sockaddr_in dst4_rw_addr; struct in_addr src4_rw_ip; + socklen_t dst4_rw_addr_len = sizeof(dst4_rw_addr); if (inet_pton(AF_INET, SRC4_REWRITE_IP, (void *)&src4_rw_ip) != 1) { log_err("Invalid IPv4: %s", SRC4_REWRITE_IP); @@ -757,7 +760,7 @@ static int sendmsg4_rw_asm_prog_load(const struct sock_addr_test *test) if (mk_sockaddr(AF_INET, SERV4_REWRITE_IP, SERV4_REWRITE_PORT, (struct sockaddr *)&dst4_rw_addr, - sizeof(dst4_rw_addr)) == -1) + &dst4_rw_addr_len) == -1) return -1; struct bpf_insn insns[] = { @@ -812,6 +815,7 @@ static int sendmsg6_rw_dst_asm_prog_load(const struct sock_addr_test *test, { struct sockaddr_in6 dst6_rw_addr; struct in6_addr src6_rw_ip; + socklen_t dst6_rw_addr_len = sizeof(dst6_rw_addr); if (inet_pton(AF_INET6, SRC6_REWRITE_IP, (void *)&src6_rw_ip) != 1) { log_err("Invalid IPv6: %s", SRC6_REWRITE_IP); @@ -820,7 +824,7 @@ static int sendmsg6_rw_dst_asm_prog_load(const struct sock_addr_test *test, if (mk_sockaddr(AF_INET6, rw_dst_ip, SERV6_REWRITE_PORT, (struct sockaddr *)&dst6_rw_addr, - sizeof(dst6_rw_addr)) == -1) + &dst6_rw_addr_len) == -1) return -1; struct bpf_insn insns[] = { @@ -885,8 +889,9 @@ static int sendmsg6_rw_c_prog_load(const struct sock_addr_test *test) return load_path(test, SENDMSG6_PROG_PATH); } -static int cmp_addr(const struct sockaddr_storage *addr1, - const struct sockaddr_storage *addr2, int cmp_port) +static int cmp_addr(const struct sockaddr_storage *addr1, socklen_t addr1_len, + const struct sockaddr_storage *addr2, socklen_t addr2_len, + int cmp_port) { const struct sockaddr_in *four1, *four2; const struct sockaddr_in6 *six1, *six2; @@ -894,6 +899,9 @@ static int cmp_addr(const struct sockaddr_storage *addr1, if (addr1->ss_family != addr2->ss_family) return -1; + if (addr1_len != addr2_len) + return -1; + if (addr1->ss_family == AF_INET) { four1 = (const struct sockaddr_in *)addr1; four2 = (const struct sockaddr_in *)addr2; @@ -911,7 +919,8 @@ static int cmp_addr(const struct sockaddr_storage *addr1, } static int cmp_sock_addr(info_fn fn, int sock1, - const struct sockaddr_storage *addr2, int cmp_port) + const struct sockaddr_storage *addr2, + socklen_t addr2_len, int cmp_port) { struct sockaddr_storage addr1; socklen_t len1 = sizeof(addr1); @@ -920,22 +929,28 @@ static int cmp_sock_addr(info_fn fn, int sock1, if (fn(sock1, (struct sockaddr *)&addr1, (socklen_t *)&len1) != 0) return -1; - return cmp_addr(&addr1, addr2, cmp_port); + return cmp_addr(&addr1, len1, addr2, addr2_len, cmp_port); } -static int cmp_local_ip(int sock1, const struct sockaddr_storage *addr2) +static int cmp_local_ip(int sock1, const struct sockaddr_storage *addr2, + socklen_t addr2_len) { - return cmp_sock_addr(getsockname, sock1, addr2, /*cmp_port*/ 0); + return cmp_sock_addr(getsockname, sock1, addr2, addr2_len, + /*cmp_port*/ 0); } -static int cmp_local_addr(int sock1, const struct sockaddr_storage *addr2) +static int cmp_local_addr(int sock1, const struct sockaddr_storage *addr2, + socklen_t addr2_len) { - return cmp_sock_addr(getsockname, sock1, addr2, /*cmp_port*/ 1); + return cmp_sock_addr(getsockname, sock1, addr2, addr2_len, + /*cmp_port*/ 1); } -static int cmp_peer_addr(int sock1, const struct sockaddr_storage *addr2) +static int cmp_peer_addr(int sock1, const struct sockaddr_storage *addr2, + socklen_t addr2_len) { - return cmp_sock_addr(getpeername, sock1, addr2, /*cmp_port*/ 1); + return cmp_sock_addr(getpeername, sock1, addr2, addr2_len, + /*cmp_port*/ 1); } static int start_server(int type, const struct sockaddr_storage *addr, @@ -1109,7 +1124,8 @@ static int fastconnect_to_server(const struct sockaddr_storage *addr, MSG_FASTOPEN, &sendmsg_err); } -static int recvmsg_from_client(int sockfd, struct sockaddr_storage *src_addr) +static int recvmsg_from_client(int sockfd, struct sockaddr_storage *src_addr, + socklen_t *src_addr_len) { struct timeval tv; struct msghdr hdr; @@ -1133,31 +1149,39 @@ static int recvmsg_from_client(int sockfd, struct sockaddr_storage *src_addr) memset(&hdr, 0, sizeof(hdr)); hdr.msg_name = src_addr; - hdr.msg_namelen = sizeof(struct sockaddr_storage); + hdr.msg_namelen = *src_addr_len; hdr.msg_iov = &iov; hdr.msg_iovlen = 1; - return recvmsg(sockfd, &hdr, 0); + if (recvmsg(sockfd, &hdr, 0) < 0) + return -1; + + *src_addr_len = hdr.msg_namelen; + return 0; } static int init_addrs(const struct sock_addr_test *test, struct sockaddr_storage *requested_addr, + socklen_t *requested_addr_len, struct sockaddr_storage *expected_addr, - struct sockaddr_storage *expected_src_addr) + socklen_t *expected_addr_len, + struct sockaddr_storage *expected_src_addr, + socklen_t *expected_src_addr_len) { - socklen_t addr_len = sizeof(struct sockaddr_storage); - if (mk_sockaddr(test->domain, test->expected_ip, test->expected_port, - (struct sockaddr *)expected_addr, addr_len) == -1) + (struct sockaddr *)expected_addr, + expected_addr_len) == -1) goto err; if (mk_sockaddr(test->domain, test->requested_ip, test->requested_port, - (struct sockaddr *)requested_addr, addr_len) == -1) + (struct sockaddr *)requested_addr, + requested_addr_len) == -1) goto err; if (test->expected_src_ip && mk_sockaddr(test->domain, test->expected_src_ip, 0, - (struct sockaddr *)expected_src_addr, addr_len) == -1) + (struct sockaddr *)expected_src_addr, + expected_src_addr_len) == -1) goto err; return 0; @@ -1167,25 +1191,28 @@ static int init_addrs(const struct sock_addr_test *test, static int run_bind_test_case(const struct sock_addr_test *test) { - socklen_t addr_len = sizeof(struct sockaddr_storage); struct sockaddr_storage requested_addr; struct sockaddr_storage expected_addr; + socklen_t requested_addr_len = sizeof(struct sockaddr_storage); + socklen_t expected_addr_len = sizeof(struct sockaddr_storage); int clientfd = -1; int servfd = -1; int err = 0; - if (init_addrs(test, &requested_addr, &expected_addr, NULL)) + if (init_addrs(test, &requested_addr, &requested_addr_len, + &expected_addr, &expected_addr_len, NULL, NULL)) goto err; - servfd = start_server(test->type, &requested_addr, addr_len); + servfd = start_server(test->type, &requested_addr, requested_addr_len); if (servfd == -1) goto err; - if (cmp_local_addr(servfd, &expected_addr)) + if (cmp_local_addr(servfd, &expected_addr, expected_addr_len)) goto err; /* Try to connect to server just in case */ - clientfd = connect_to_server(test->type, &expected_addr, addr_len); + clientfd = connect_to_server(test->type, &expected_addr, + expected_addr_len); if (clientfd == -1) goto err; @@ -1204,28 +1231,33 @@ static int run_connect_test_case(const struct sock_addr_test *test) struct sockaddr_storage expected_src_addr; struct sockaddr_storage requested_addr; struct sockaddr_storage expected_addr; + socklen_t expected_src_addr_len = sizeof(struct sockaddr_storage); + socklen_t requested_addr_len = sizeof(struct sockaddr_storage); + socklen_t expected_addr_len = sizeof(struct sockaddr_storage); int clientfd = -1; int servfd = -1; int err = 0; - if (init_addrs(test, &requested_addr, &expected_addr, - &expected_src_addr)) + if (init_addrs(test, &requested_addr, &requested_addr_len, + &expected_addr, &expected_addr_len, &expected_src_addr, + &expected_src_addr_len)) goto err; /* Prepare server to connect to */ - servfd = start_server(test->type, &expected_addr, addr_len); + servfd = start_server(test->type, &expected_addr, expected_addr_len); if (servfd == -1) goto err; - clientfd = connect_to_server(test->type, &requested_addr, addr_len); + clientfd = connect_to_server(test->type, &requested_addr, + requested_addr_len); if (clientfd == -1) goto err; /* Make sure src and dst addrs were overridden properly */ - if (cmp_peer_addr(clientfd, &expected_addr)) + if (cmp_peer_addr(clientfd, &expected_addr, expected_addr_len)) goto err; - if (cmp_local_ip(clientfd, &expected_src_addr)) + if (cmp_local_ip(clientfd, &expected_src_addr, expected_src_addr_len)) goto err; if (test->type == SOCK_STREAM) { @@ -1235,10 +1267,11 @@ static int run_connect_test_case(const struct sock_addr_test *test) goto err; /* Make sure src and dst addrs were overridden properly */ - if (cmp_peer_addr(clientfd, &expected_addr)) + if (cmp_peer_addr(clientfd, &expected_addr, expected_addr_len)) goto err; - if (cmp_local_ip(clientfd, &expected_src_addr)) + if (cmp_local_ip(clientfd, &expected_src_addr, + expected_src_addr_len)) goto err; } @@ -1253,11 +1286,14 @@ static int run_connect_test_case(const struct sock_addr_test *test) static int run_xmsg_test_case(const struct sock_addr_test *test, int max_cmsg) { - socklen_t addr_len = sizeof(struct sockaddr_storage); struct sockaddr_storage expected_addr; struct sockaddr_storage server_addr; struct sockaddr_storage sendmsg_addr; struct sockaddr_storage recvmsg_addr; + socklen_t expected_addr_len = sizeof(struct sockaddr_storage); + socklen_t server_addr_len = sizeof(struct sockaddr_storage); + socklen_t sendmsg_addr_len = sizeof(struct sockaddr_storage); + socklen_t recvmsg_addr_len = sizeof(struct sockaddr_storage); int clientfd = -1; int servfd = -1; int set_cmsg; @@ -1266,11 +1302,12 @@ static int run_xmsg_test_case(const struct sock_addr_test *test, int max_cmsg) if (test->type != SOCK_DGRAM) goto err; - if (init_addrs(test, &sendmsg_addr, &server_addr, &expected_addr)) + if (init_addrs(test, &sendmsg_addr, &sendmsg_addr_len, &server_addr, + &server_addr_len, &expected_addr, &expected_addr_len)) goto err; /* Prepare server to sendmsg to */ - servfd = start_server(test->type, &server_addr, addr_len); + servfd = start_server(test->type, &server_addr, server_addr_len); if (servfd == -1) goto err; @@ -1279,8 +1316,8 @@ static int run_xmsg_test_case(const struct sock_addr_test *test, int max_cmsg) close(clientfd); clientfd = sendmsg_to_server(test->type, &sendmsg_addr, - addr_len, set_cmsg, /*flags*/0, - &err); + sendmsg_addr_len, set_cmsg, + /*flags*/ 0, &err); if (err) goto out; else if (clientfd == -1) @@ -1298,10 +1335,13 @@ static int run_xmsg_test_case(const struct sock_addr_test *test, int max_cmsg) * specific packet may differ from the one used by default and * returned by getsockname(2). */ - if (recvmsg_from_client(servfd, &recvmsg_addr) == -1) + if (recvmsg_from_client(servfd, &recvmsg_addr, + &recvmsg_addr_len) == -1) goto err; - if (cmp_addr(&recvmsg_addr, &expected_addr, /*cmp_port*/0)) + if (cmp_addr(&recvmsg_addr, recvmsg_addr_len, &expected_addr, + expected_addr_len, + /*cmp_port*/ 0)) goto err; } -- 2.38.1