[PATCH 01/12] selftests/net: Add TCP-AO library

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



Provide functions to create selftests dedicated to TCP-AO.
They can run in parallel, as they use temporary net namespaces.
They can be very specific to the feature being tested.
This allows creating many TCP-AO tests without complicating one
binary with many --options, and creating scenarios that are hard
to express in a bash script that drives a single binary.

Signed-off-by: Dmitry Safonov <dima@xxxxxxxxxx>
---
 tools/testing/selftests/Makefile                 |   1 +
 tools/testing/selftests/net/tcp_ao/.gitignore    |   2 +
 tools/testing/selftests/net/tcp_ao/Makefile      |  45 ++
 tools/testing/selftests/net/tcp_ao/connect.c     |  90 ++++
 tools/testing/selftests/net/tcp_ao/lib/aolib.h   | 605 +++++++++++++++++++++++
 tools/testing/selftests/net/tcp_ao/lib/kconfig.c | 148 ++++++
 tools/testing/selftests/net/tcp_ao/lib/netlink.c | 415 ++++++++++++++++
 tools/testing/selftests/net/tcp_ao/lib/proc.c    | 273 ++++++++++
 tools/testing/selftests/net/tcp_ao/lib/repair.c  | 254 ++++++++++
 tools/testing/selftests/net/tcp_ao/lib/setup.c   | 342 +++++++++++++
 tools/testing/selftests/net/tcp_ao/lib/sock.c    | 592 ++++++++++++++++++++++
 tools/testing/selftests/net/tcp_ao/lib/utils.c   |  30 ++
 12 files changed, 2797 insertions(+)

diff --git a/tools/testing/selftests/Makefile b/tools/testing/selftests/Makefile
index 3b2061d1c1a5..f0c854d6511c 100644
--- a/tools/testing/selftests/Makefile
+++ b/tools/testing/selftests/Makefile
@@ -58,6 +58,7 @@ TARGETS += net/forwarding
 TARGETS += net/hsr
 TARGETS += net/mptcp
 TARGETS += net/openvswitch
+TARGETS += net/tcp_ao
 TARGETS += netfilter
 TARGETS += nsfs
 TARGETS += perf_events
diff --git a/tools/testing/selftests/net/tcp_ao/.gitignore b/tools/testing/selftests/net/tcp_ao/.gitignore
new file mode 100644
index 000000000000..e8bb81b715b7
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/.gitignore
@@ -0,0 +1,2 @@
+*_ipv4
+*_ipv6
diff --git a/tools/testing/selftests/net/tcp_ao/Makefile b/tools/testing/selftests/net/tcp_ao/Makefile
new file mode 100644
index 000000000000..62425b9fb73c
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/Makefile
@@ -0,0 +1,45 @@
+# SPDX-License-Identifier: GPL-2.0
+TEST_BOTH_AF := connect
+
+TEST_IPV4_PROGS := $(TEST_BOTH_AF:%=%_ipv4)
+TEST_IPV6_PROGS := $(TEST_BOTH_AF:%=%_ipv6)
+
+TEST_GEN_PROGS := $(TEST_IPV4_PROGS) $(TEST_IPV6_PROGS)
+
+top_srcdir	  := ../../../../..
+KSFT_KHDR_INSTALL := 1
+include ../../lib.mk
+
+HOSTAR ?= ar
+
+# Drop it on port to linux/master with commit 8ce72dc32578
+.DEFAULT_GOAL := all
+
+LIBDIR	:= $(OUTPUT)/lib
+LIB	:= $(LIBDIR)/libaotst.a
+LDLIBS	+= $(LIB) -pthread
+LIBDEPS	:= lib/aolib.h Makefile
+
+CFLAGS	:= -Wall -O2 -g -D_GNU_SOURCE -fno-strict-aliasing
+CFLAGS	+= -I ../../../../../usr/include/ -iquote $(LIBDIR)
+CFLAGS	+= -I ../../../../include/
+
+# Library
+LIBSRC	:= kconfig.c netlink.c proc.c repair.c setup.c sock.c utils.c
+LIBOBJ	:= $(LIBSRC:%.c=$(LIBDIR)/%.o)
+EXTRA_CLEAN += $(LIBOBJ) $(LIB)
+
+$(LIB): $(LIBOBJ)
+	$(HOSTAR) rcs $@ $^
+
+$(LIBDIR)/%.o: ./lib/%.c $(LIBDEPS)
+	$(CC) $< $(CFLAGS) $(CPPFLAGS) -o $@ -c
+
+$(TEST_GEN_PROGS): $(LIB)
+
+$(OUTPUT)/%_ipv4: %.c
+	$(LINK.c) $^ $(LDLIBS) -o $@
+
+$(OUTPUT)/%_ipv6: %.c
+	$(LINK.c) -DIPV6_TEST $^ $(LDLIBS) -o $@
+
diff --git a/tools/testing/selftests/net/tcp_ao/connect.c b/tools/testing/selftests/net/tcp_ao/connect.c
new file mode 100644
index 000000000000..81653b47f303
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/connect.c
@@ -0,0 +1,90 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Author: Dmitry Safonov <dima@xxxxxxxxxx> */
+#include <inttypes.h>
+#include "aolib.h"
+
+static void *server_fn(void *arg)
+{
+	int sk, lsk;
+	ssize_t bytes;
+
+	lsk = test_listen_socket(this_ip_addr, test_server_port, 1);
+
+	if (test_add_key(lsk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
+		test_error("setsockopt(TCP_AO_ADD_KEY)");
+	synchronize_threads();
+
+	if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0))
+		test_error("test_wait_fd()");
+
+	sk = accept(lsk, NULL, NULL);
+	if (sk < 0)
+		test_error("accept()");
+
+	synchronize_threads();
+
+	bytes = test_server_run(sk, 0, 0);
+
+	test_fail("server served: %zd", bytes);
+	return NULL;
+}
+
+static void *client_fn(void *arg)
+{
+	int sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+	uint64_t before_aogood, after_aogood;
+	const size_t nr_packets = 20;
+	struct netstat *ns_before, *ns_after;
+	struct tcp_ao_counters ao1, ao2;
+
+	if (sk < 0)
+		test_error("socket()");
+
+	if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
+		test_error("setsockopt(TCP_AO_ADD_KEY)");
+
+	synchronize_threads();
+	if (test_connect_socket(sk, this_ip_dest, test_server_port) <= 0)
+		test_error("failed to connect()");
+	synchronize_threads();
+
+	ns_before = netstat_read();	/* counter snapshot taken before the traffic */
+	before_aogood = netstat_get(ns_before, "TCPAOGood", NULL);
+	if (test_get_tcp_ao_counters(sk, &ao1))
+		test_error("test_get_tcp_ao_counters()");
+
+	if (test_client_verify(sk, 100, nr_packets, TEST_TIMEOUT_SEC)) {
+		test_fail("verify failed");
+		return NULL;
+	}
+
+	ns_after = netstat_read();	/* counter snapshot taken after the traffic */
+	after_aogood = netstat_get(ns_after, "TCPAOGood", NULL);
+	if (test_get_tcp_ao_counters(sk, &ao2))
+		test_error("test_get_tcp_ao_counters()");
+	netstat_print_diff(ns_before, ns_after);
+	netstat_free(ns_before);
+	netstat_free(ns_after);
+
+	if (nr_packets > (after_aogood - before_aogood)) { /* TCPAOGood must grow by >= nr_packets */
+		test_fail("TCPAOGood counter mismatch: %zu > (%" PRIu64 " - %" PRIu64 ")",
+				nr_packets, after_aogood, before_aogood);
+		return NULL;
+	}
+	if (test_tcp_ao_counters_cmp("connect", &ao1, &ao2, TEST_CNT_GOOD))
+		return NULL;
+
+	test_ok("connect TCPAOGood %" PRIu64 "/%" PRIu64 "/%" PRIu64 " => %" PRIu64 "/%" PRIu64 "/%" PRIu64 ", sent %zu",
+			before_aogood, ao1.ao_info_pkt_good,
+			ao1.key_cnts[0].pkt_good,
+			after_aogood, ao2.ao_info_pkt_good,
+			ao2.key_cnts[0].pkt_good,
+			nr_packets);
+	return NULL;
+}
+
+int main(int argc, char *argv[])
+{
+	test_init(1, server_fn, client_fn);
+	return 0;
+}
diff --git a/tools/testing/selftests/net/tcp_ao/lib/aolib.h b/tools/testing/selftests/net/tcp_ao/lib/aolib.h
new file mode 100644
index 000000000000..fbc7f6111815
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/lib/aolib.h
@@ -0,0 +1,605 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+/*
+ * TCP-AO selftest library. Provides helpers to unshare network
+ * namespaces, create veth, assign ip addresses, set routes,
+ * manipulate socket options, read network counters, etc.
+ * Author: Dmitry Safonov <dima@xxxxxxxxxx>
+ */
+#ifndef _AOLIB_H_
+#define _AOLIB_H_
+
+#include <arpa/inet.h>
+#include <errno.h>
+#include <linux/snmp.h>
+#include <linux/tcp.h>
+#include <netinet/in.h>
+#include <stdarg.h>
+#include <stdbool.h>
+#include <stdlib.h>
+#include <stdio.h>
+#include <string.h>
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#include "../../../../../include/linux/stringify.h"
+#include "../../../../../include/linux/bits.h"
+
+#ifndef SOL_TCP
+/* can't include <netinet/tcp.h> as including <linux/tcp.h> */
+# define SOL_TCP		6	/* TCP level */
+#endif
+
+/* Working around ksft, see the comment in lib/setup.c */
+extern void __test_msg(const char *buf);
+extern void __test_ok(const char *buf);
+extern void __test_fail(const char *buf);
+extern void __test_xfail(const char *buf);
+extern void __test_error(const char *buf);
+extern void __test_skip(const char *buf);
+
+__attribute__((__format__(__printf__, 2, 3)))
+static inline void __test_print(void (*fn)(const char *), const char *fmt, ...)
+{
+#define TEST_MSG_BUFFER_SIZE 4096
+	char buf[TEST_MSG_BUFFER_SIZE];
+	va_list arg;
+
+	va_start(arg, fmt);
+	vsnprintf(buf, sizeof(buf), fmt, arg);
+	va_end(arg);
+	fn(buf);
+}
+
+#define test_print(fmt, ...)						\
+	__test_print(__test_msg, "%ld[%s:%u] " fmt "\n",		\
+		     syscall(SYS_gettid),				\
+		     __FILE__, __LINE__, ##__VA_ARGS__)
+
+#define test_ok(fmt, ...)						\
+	__test_print(__test_ok, fmt "\n", ##__VA_ARGS__)
+#define test_skip(fmt, ...)						\
+	__test_print(__test_skip, fmt "\n", ##__VA_ARGS__)
+#define test_xfail(fmt, ...)						\
+	__test_print(__test_xfail, fmt "\n", ##__VA_ARGS__)
+
+#define test_fail(fmt, ...)						\
+do {									\
+	if (errno)							\
+		__test_print(__test_fail, fmt ": %m\n", ##__VA_ARGS__);	\
+	else								\
+		__test_print(__test_fail, fmt "\n", ##__VA_ARGS__);	\
+	test_failed();							\
+} while (0)
+
+#define KSFT_FAIL  1
+#define test_error(fmt, ...)						\
+do {									\
+	if (errno)							\
+		__test_print(__test_error, "%ld[%s:%u] " fmt ": %m\n",	\
+			     syscall(SYS_gettid), __FILE__, __LINE__,	\
+			     ##__VA_ARGS__);				\
+	else								\
+		__test_print(__test_error, "%ld[%s:%u] " fmt "\n",	\
+			     syscall(SYS_gettid), __FILE__, __LINE__,	\
+			     ##__VA_ARGS__);				\
+	exit(KSFT_FAIL);						\
+} while (0)
+
+enum test_fault {
+	FAULT_TIMEOUT = 1,
+	FAULT_KEYREJECT,
+	FAULT_PREINSTALL_AO,
+	FAULT_PREINSTALL_MD5,
+	FAULT_POSTINSTALL,
+	FAULT_BUSY,
+	FAULT_CURRNEXT,
+	FAULT_FIXME,
+};
+typedef enum test_fault fault_t;
+
+enum test_needs_kconfig {
+	KCONFIG_NET_NS = 0,		/* required */
+	KCONFIG_VETH,			/* required */
+	KCONFIG_TCP_AO,			/* required */
+	KCONFIG_TCP_MD5,		/* optional, for TCP-MD5 features */
+	KCONFIG_NET_VRF,		/* optional, for L3/VRF testing */
+	__KCONFIG_LAST__
+};
+extern bool kernel_config_has(enum test_needs_kconfig k);
+extern const char *tests_skip_reason[__KCONFIG_LAST__];
+static inline bool should_skip_test(const char *tst_name,
+				    enum test_needs_kconfig k)
+{
+	if (kernel_config_has(k))
+		return false;
+	test_skip("%s: %s", tst_name, tests_skip_reason[k]);
+	return true;
+}
+
+union tcp_addr {
+	struct in_addr a4;
+	struct in6_addr a6;
+};
+
+typedef void *(*thread_fn)(void *);
+extern void test_failed(void);
+extern void __test_init(unsigned int ntests, int family, unsigned int prefix,
+			union tcp_addr addr1, union tcp_addr addr2,
+			thread_fn peer1, thread_fn peer2);
+
+static inline void test_init2(unsigned int ntests,
+			      thread_fn peer1, thread_fn peer2,
+			      int family, unsigned int prefix,
+			      const char *addr1, const char *addr2)
+{
+	union tcp_addr taddr1, taddr2;
+
+	if (inet_pton(family, addr1, &taddr1) != 1)
+		test_error("Can't convert ip address %s", addr1);
+	if (inet_pton(family, addr2, &taddr2) != 1)
+		test_error("Can't convert ip address %s", addr2);
+
+	__test_init(ntests, family, prefix, taddr1, taddr2, peer1, peer2);
+}
+extern void test_add_destructor(void (*d)(void));
+
+/* To adjust optmem socket limit, approximately estimate a number,
+ * that is bigger than sizeof(struct tcp_ao_key).
+ */
+#define KERNEL_TCP_AO_KEY_SZ_ROUND_UP	300
+
+extern void test_set_optmem(size_t value);
+extern size_t test_get_optmem(void);
+
+extern const struct sockaddr_in6 addr_any6;
+extern const struct sockaddr_in addr_any4;
+
+#ifdef IPV6_TEST
+# define __TEST_CLIENT_IP(n)	("2001:db8:" __stringify(n) "::1")
+# define TEST_CLIENT_IP	__TEST_CLIENT_IP(1)
+# define TEST_WRONG_IP	"2001:db8:253::1"
+# define TEST_SERVER_IP	"2001:db8:254::1"
+# define TEST_NETWORK	"2001::"
+# define TEST_PREFIX	128
+# define TEST_FAMILY	AF_INET6
+# define SOCKADDR_ANY	addr_any6
+# define sockaddr_af	struct sockaddr_in6
+#else
+# define __TEST_CLIENT_IP(n)	("10.0." __stringify(n) ".1")
+# define TEST_CLIENT_IP	__TEST_CLIENT_IP(1)
+# define TEST_WRONG_IP	"10.0.253.1"
+# define TEST_SERVER_IP	"10.0.254.1"
+# define TEST_NETWORK	"10.0.0.0"
+# define TEST_PREFIX	32
+# define TEST_FAMILY	AF_INET
+# define SOCKADDR_ANY	addr_any4
+# define sockaddr_af	struct sockaddr_in
+#endif
+
+static inline union tcp_addr gen_tcp_addr(union tcp_addr net, size_t n)
+{
+	union tcp_addr ret = net;
+
+#ifdef IPV6_TEST
+	ret.a6.s6_addr32[3] = htonl(n & (BIT(32) - 1));
+	ret.a6.s6_addr32[2] = htonl((n >> 32) & (BIT(32) - 1));
+#else
+	ret.a4.s_addr = htonl(ntohl(net.a4.s_addr) + n);
+#endif
+
+	return ret;
+}
+
+static inline void tcp_addr_to_sockaddr_in(void *dest,
+					   const union tcp_addr *src,
+					   unsigned int port)
+{
+	sockaddr_af *out = dest;
+
+	memset(out, 0, sizeof(*out));
+#ifdef IPV6_TEST
+	out->sin6_family = AF_INET6;
+	out->sin6_port   = port;
+	out->sin6_addr   = src->a6;
+#else
+	out->sin_family  = AF_INET;
+	out->sin_port    = port;
+	out->sin_addr    = src->a4;
+#endif
+}
+
+static inline void test_init(unsigned int ntests,
+			     thread_fn peer1, thread_fn peer2)
+{
+	test_init2(ntests, peer1, peer2, TEST_FAMILY, TEST_PREFIX,
+			TEST_SERVER_IP, TEST_CLIENT_IP);
+}
+extern void synchronize_threads(void);
+extern void switch_ns(int fd);
+
+extern __thread union tcp_addr this_ip_addr;
+extern __thread union tcp_addr this_ip_dest;
+extern int test_family;
+
+extern void randomize_buffer(void *buf, size_t buflen);
+extern int open_netns(void);
+extern int unshare_open_netns(void);
+extern const char veth_name[];
+extern int add_veth(const char *name, int nsfda, int nsfdb);
+extern int add_vrf(const char *name, uint32_t tabid, int ifindex, int nsfd);
+extern int ip_addr_add(const char *intf, int family,
+		       union tcp_addr addr, uint8_t prefix);
+extern int ip_route_add(const char *intf, int family,
+			union tcp_addr src, union tcp_addr dst);
+extern int ip_route_add_vrf(const char *intf, int family,
+			    union tcp_addr src, union tcp_addr dst,
+			    uint8_t vrf);
+extern int link_set_up(const char *intf);
+
+extern const unsigned int test_server_port;
+extern int test_wait_fd(int sk, time_t sec, bool write);
+extern int __test_connect_socket(int sk, const char *device,
+				 void *addr, size_t addr_sz, time_t timeout);
+extern int __test_listen_socket(int backlog, void *addr, size_t addr_sz);
+
+static inline int test_listen_socket(const union tcp_addr taddr,
+				     unsigned int port, int backlog)
+{
+	sockaddr_af addr;
+
+	tcp_addr_to_sockaddr_in(&addr, &taddr, htons(port));
+	return __test_listen_socket(backlog, (void *)&addr, sizeof(addr));
+}
+
+/*
+ * In order for selftests to work under CONFIG_CRYPTO_FIPS=y,
+ * the password should be longer than 14 bytes, see hmac_setkey()
+ */
+#define TEST_TCP_AO_MINKEYLEN	14
+#define DEFAULT_TEST_PASSWORD	"In this hour, I do not believe that any darkness will endure."
+
+#ifndef DEFAULT_TEST_ALGO
+#define DEFAULT_TEST_ALGO	"cmac(aes128)"
+#endif
+
+#ifdef IPV6_TEST
+#define DEFAULT_TEST_PREFIX	128
+#else
+#define DEFAULT_TEST_PREFIX	32
+#endif
+
+/*
+ * Timeout on syscalls where failure is not expected.
+ * You may want to raise it if the test machine is very busy.
+ */
+#ifndef TEST_TIMEOUT_SEC
+#define TEST_TIMEOUT_SEC	5
+#endif
+
+/*
+ * Timeout on connect() where a failure is expected.
+ * If set to 0, the kernel retransmits SYN the number of times set in
+ * /proc/sys/net/ipv4/tcp_syn_retries
+ * By default set to 1 to make tests pass faster on non-busy machine.
+ */
+#ifndef TEST_RETRANSMIT_SEC
+#define TEST_RETRANSMIT_SEC	1
+#endif
+
+static inline int _test_connect_socket(int sk, const union tcp_addr taddr,
+				       unsigned int port, time_t timeout)
+{
+	sockaddr_af addr;
+
+	tcp_addr_to_sockaddr_in(&addr, &taddr, htons(port));
+	return __test_connect_socket(sk, veth_name,
+				     (void *)&addr, sizeof(addr), timeout);
+}
+
+static inline int test_connect_socket(int sk, const union tcp_addr taddr,
+				      unsigned int port)
+{
+	return _test_connect_socket(sk, taddr, port, TEST_TIMEOUT_SEC);
+}
+
+extern int __test_set_md5(int sk, void *addr, size_t addr_sz,
+			  uint8_t prefix, int vrf, const char *password);
+static inline int test_set_md5(int sk, const union tcp_addr in_addr,
+			       uint8_t prefix, int vrf, const char *password)
+{
+	sockaddr_af addr;
+
+	if (prefix > DEFAULT_TEST_PREFIX)
+		prefix = DEFAULT_TEST_PREFIX;
+
+	tcp_addr_to_sockaddr_in(&addr, &in_addr, 0);
+	return __test_set_md5(sk, (void *)&addr, sizeof(addr),
+			prefix, vrf, password);
+}
+
+extern int test_prepare_key_sockaddr(struct tcp_ao_add *ao, const char *alg,
+		void *addr, size_t addr_sz, bool set_current, bool set_rnext,
+		uint8_t prefix, uint8_t vrf,
+		uint8_t sndid, uint8_t rcvid, uint8_t maclen,
+		uint8_t keyflags, uint8_t keylen, const char *key);
+
+static inline int test_prepare_key(struct tcp_ao_add *ao,
+		const char *alg, union tcp_addr taddr,
+		bool set_current, bool set_rnext,
+		uint8_t prefix, uint8_t vrf,
+		uint8_t sndid, uint8_t rcvid, uint8_t maclen,
+		uint8_t keyflags, uint8_t keylen, const char *key)
+{
+	sockaddr_af addr;
+
+	tcp_addr_to_sockaddr_in(&addr, &taddr, 0);
+	return test_prepare_key_sockaddr(ao, alg, (void *)&addr, sizeof(addr),
+			set_current, set_rnext, prefix, vrf, sndid, rcvid,
+			maclen, keyflags, keylen, key);
+}
+
+static inline int test_prepare_def_key(struct tcp_ao_add *ao,
+		const char *key, uint8_t keyflags,
+		union tcp_addr in_addr, uint8_t prefix, uint8_t vrf,
+		uint8_t sndid, uint8_t rcvid)
+{
+	if (prefix > DEFAULT_TEST_PREFIX)
+		prefix = DEFAULT_TEST_PREFIX;
+
+	return test_prepare_key(ao, DEFAULT_TEST_ALGO, in_addr, false, false,
+				prefix, vrf, sndid, rcvid, 0, keyflags,
+				strlen(key), key);
+}
+
+extern int test_get_one_ao(int sk, struct tcp_ao_getsockopt *out,
+			   void *addr, size_t addr_sz,
+			   uint8_t prefix, uint8_t sndid, uint8_t rcvid);
+extern int test_get_ao_info(int sk, struct tcp_ao_info_opt *out);
+extern int test_set_ao_info(int sk, struct tcp_ao_info_opt *in);
+extern int test_cmp_getsockopt_setsockopt(const struct tcp_ao_add *a,
+					  const struct tcp_ao_getsockopt *b);
+extern int test_cmp_getsockopt_setsockopt_ao(const struct tcp_ao_info_opt *a,
+					     const struct tcp_ao_info_opt *b);
+
+static inline int test_verify_socket_key(int sk, struct tcp_ao_add *key)
+{
+	struct tcp_ao_getsockopt key2 = {};
+	int err;
+
+	err = test_get_one_ao(sk, &key2, &key->addr, sizeof(key->addr),
+			      key->prefix, key->sndid, key->rcvid);
+	if (err)
+		return err;
+
+	return test_cmp_getsockopt_setsockopt(key, &key2);
+}
+
+static inline int test_add_key_vrf(int sk,
+				   const char *key, uint8_t keyflags,
+				   union tcp_addr in_addr, uint8_t prefix,
+				   uint8_t vrf, uint8_t sndid, uint8_t rcvid)
+{
+	struct tcp_ao_add tmp = {};
+	int err;
+
+	err = test_prepare_def_key(&tmp, key, keyflags, in_addr, prefix,
+				   vrf, sndid, rcvid);
+	if (err)
+		return err;
+
+	err = setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp));
+	if (err < 0)
+		return -errno;
+
+	return test_verify_socket_key(sk, &tmp);
+}
+
+static inline int test_add_key(int sk, const char *key,
+			      union tcp_addr in_addr, uint8_t prefix,
+			      uint8_t sndid, uint8_t rcvid)
+{
+	return test_add_key_vrf(sk, key, 0, in_addr, prefix, 0, sndid, rcvid);
+}
+
+static inline int test_verify_socket_ao(int sk, struct tcp_ao_info_opt *ao)
+{
+	struct tcp_ao_info_opt ao2 = {};
+	int err;
+
+	err = test_get_ao_info(sk, &ao2);
+	if (err)
+		return err;
+
+	return test_cmp_getsockopt_setsockopt_ao(ao, &ao2);
+}
+
+static inline int test_set_ao_flags(int sk, bool ao_required, bool accept_icmps)
+{
+	struct tcp_ao_info_opt ao = {};
+	int err;
+
+	err = test_get_ao_info(sk, &ao);
+	/* Maybe ao_info wasn't allocated yet */
+	if (err && err != -ENOENT)
+		return err;
+
+	ao.ao_required = !!ao_required;
+	ao.accept_icmps = !!accept_icmps;
+	err = test_set_ao_info(sk, &ao);
+	if (err)
+		return err;
+
+	return test_verify_socket_ao(sk, &ao);
+}
+
+extern ssize_t test_server_run(int sk, ssize_t quota, time_t timeout_sec);
+extern ssize_t test_client_loop(int sk, char *buf, size_t buf_sz,
+				const size_t msg_len, time_t timeout_sec);
+extern int test_client_verify(int sk, const size_t msg_len, const size_t nr,
+			      time_t timeout_sec);
+
+struct tcp_ao_key_counters {
+	uint8_t sndid;
+	uint8_t rcvid;
+	uint64_t pkt_good;
+	uint64_t pkt_bad;
+};
+
+struct tcp_ao_counters {
+	/* per-netns */
+	uint64_t netns_ao_good;
+	uint64_t netns_ao_bad;
+	uint64_t netns_ao_key_not_found;
+	uint64_t netns_ao_required;
+	uint64_t netns_ao_dropped_icmp;
+	/* per-socket */
+	uint64_t ao_info_pkt_good;
+	uint64_t ao_info_pkt_bad;
+	uint64_t ao_info_pkt_key_not_found;
+	uint64_t ao_info_pkt_ao_required;
+	uint64_t ao_info_pkt_dropped_icmp;
+	/* per-key */
+	size_t nr_keys;
+	struct tcp_ao_key_counters *key_cnts;
+};
+extern int test_get_tcp_ao_counters(int sk, struct tcp_ao_counters *out);
+
+#define TEST_CNT_KEY_GOOD		BIT(0)
+#define TEST_CNT_KEY_BAD		BIT(1)
+#define TEST_CNT_SOCK_GOOD		BIT(2)
+#define TEST_CNT_SOCK_BAD		BIT(3)
+#define TEST_CNT_SOCK_KEY_NOT_FOUND	BIT(4)
+#define TEST_CNT_SOCK_AO_REQUIRED	BIT(5)
+#define TEST_CNT_SOCK_DROPPED_ICMP	BIT(6)
+#define TEST_CNT_NS_GOOD		BIT(7)
+#define TEST_CNT_NS_BAD			BIT(8)
+#define TEST_CNT_NS_KEY_NOT_FOUND	BIT(9)
+#define TEST_CNT_NS_AO_REQUIRED		BIT(10)
+#define TEST_CNT_NS_DROPPED_ICMP	BIT(11)
+typedef uint16_t test_cnt;
+
+#define TEST_CNT_AO_GOOD		(TEST_CNT_SOCK_GOOD | TEST_CNT_NS_GOOD)
+#define TEST_CNT_AO_BAD			(TEST_CNT_SOCK_BAD | TEST_CNT_NS_BAD)
+#define TEST_CNT_AO_KEY_NOT_FOUND	(TEST_CNT_SOCK_KEY_NOT_FOUND | \
+					 TEST_CNT_NS_KEY_NOT_FOUND)
+#define TEST_CNT_AO_REQUIRED		(TEST_CNT_SOCK_AO_REQUIRED | \
+					 TEST_CNT_NS_AO_REQUIRED)
+#define TEST_CNT_AO_DROPPED_ICMP	(TEST_CNT_SOCK_DROPPED_ICMP | \
+					 TEST_CNT_NS_DROPPED_ICMP)
+#define TEST_CNT_GOOD			(TEST_CNT_KEY_GOOD | TEST_CNT_AO_GOOD)
+#define TEST_CNT_BAD			(TEST_CNT_KEY_BAD | TEST_CNT_AO_BAD)
+
+extern int __test_tcp_ao_counters_cmp(const char *tst_name,
+		struct tcp_ao_counters *before, struct tcp_ao_counters *after,
+		test_cnt expected);
+extern int test_tcp_ao_key_counters_cmp(const char *tst_name,
+		struct tcp_ao_counters *before, struct tcp_ao_counters *after,
+		test_cnt expected, int sndid, int rcvid);
+extern void test_tcp_ao_counters_free(struct tcp_ao_counters *cnts);
+/*
+ * Frees buffers allocated in test_get_tcp_ao_counters().
+ * The function doesn't expect new keys or keys removed between calls
+ * to test_get_tcp_ao_counters(). Check key counters manually if they
+ * may change.
+ */
+static inline int test_tcp_ao_counters_cmp(const char *tst_name,
+					   struct tcp_ao_counters *before,
+					   struct tcp_ao_counters *after,
+					   test_cnt expected)
+{
+	int ret;
+
+	ret = __test_tcp_ao_counters_cmp(tst_name, before, after, expected);
+	if (ret)
+		goto out;
+	ret = test_tcp_ao_key_counters_cmp(tst_name, before, after,
+					   expected, -1, -1);
+out:
+	test_tcp_ao_counters_free(before);
+	test_tcp_ao_counters_free(after);
+	return ret;
+}
+
+struct netstat;
+extern struct netstat *netstat_read(void);
+extern void netstat_free(struct netstat *ns);
+extern void netstat_print_diff(struct netstat *nsa, struct netstat *nsb);
+extern uint64_t netstat_get(struct netstat *ns,
+			    const char *name, bool *not_found);
+
+static inline uint64_t netstat_get_one(const char *name, bool *not_found)
+{
+	struct netstat *ns = netstat_read();
+	uint64_t ret;
+
+	ret = netstat_get(ns, name, not_found);
+
+	netstat_free(ns);
+	return ret;
+}
+
+struct tcp_sock_queue {
+	uint32_t seq;
+	void *buf;
+};
+
+struct tcp_sock_state {
+	struct tcp_info info;
+	struct tcp_repair_window trw;
+	struct tcp_sock_queue out;
+	int outq_len;		/* output queue size (not sent + not acked) */
+	int outq_nsd_len;	/* output queue size (not sent only) */
+	struct tcp_sock_queue in;
+	int inq_len;
+	int mss;
+	int timestamp;
+};
+
+extern void __test_sock_checkpoint(int sk, struct tcp_sock_state *state,
+				   void *addr, size_t addr_size);
+static inline void test_sock_checkpoint(int sk, struct tcp_sock_state *state,
+					sockaddr_af *saddr)
+{
+	__test_sock_checkpoint(sk, state, saddr, sizeof(*saddr));
+}
+extern void test_ao_checkpoint(int sk, struct tcp_ao_repair *state);
+extern void __test_sock_restore(int sk, const char *device,
+				struct tcp_sock_state *state,
+				void *saddr, void *daddr, size_t addr_size);
+static inline void test_sock_restore(int sk, struct tcp_sock_state *state,
+				     sockaddr_af *saddr,
+				     const union tcp_addr daddr,
+				     unsigned int dport)
+{
+	sockaddr_af addr;
+
+	tcp_addr_to_sockaddr_in(&addr, &daddr, htons(dport));
+	__test_sock_restore(sk, veth_name, state, saddr, &addr, sizeof(addr));
+}
+extern void test_ao_restore(int sk, struct tcp_ao_repair *state);
+extern void test_sock_state_free(struct tcp_sock_state *state);
+extern void test_enable_repair(int sk);
+extern void test_disable_repair(int sk);
+extern void test_kill_sk(int sk);
+static inline int test_add_repaired_key(int sk,
+					const char *key, uint8_t keyflags,
+					union tcp_addr in_addr, uint8_t prefix,
+					uint8_t sndid, uint8_t rcvid)
+{
+	struct tcp_ao_add tmp = {};
+	int err;
+
+	err = test_prepare_def_key(&tmp, key, keyflags, in_addr, prefix,
+				   0, sndid, rcvid);
+	if (err)
+		return err;
+
+	tmp.set_current = 1;
+	tmp.set_rnext = 1;
+	if (setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp)) < 0)
+		return -errno;
+
+	return test_verify_socket_key(sk, &tmp);
+}
+
+#endif /* _AOLIB_H_ */
diff --git a/tools/testing/selftests/net/tcp_ao/lib/kconfig.c b/tools/testing/selftests/net/tcp_ao/lib/kconfig.c
new file mode 100644
index 000000000000..f279ffc3843b
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/lib/kconfig.c
@@ -0,0 +1,148 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Check which features the kernel supports (where the selftest is running).
+ * Somewhat inspired by CRIU kerndat/kdat kernel features detector.
+ */
+#include <pthread.h>
+#include "aolib.h"
+
+struct kconfig_t {
+	int _errno;		/* the returned error if not supported */
+	int (*check_kconfig)(int *error);
+};
+
+static int has_net_ns(int *err)
+{
+	if (access("/proc/self/ns/net", F_OK) < 0) {
+		*err = errno;
+		if (errno == ENOENT)
+			return 0;
+		test_print("Unable to access /proc/self/ns/net: %m");
+		return -errno;
+	}
+	return *err = errno = 0;
+}
+
+static int has_veth(int *err)
+{
+	int orig_netns, ns_a, ns_b;
+
+	orig_netns = open_netns();
+	ns_a = unshare_open_netns();
+	ns_b = unshare_open_netns();
+
+	*err = add_veth("check_veth", ns_a, ns_b);
+
+	switch_ns(orig_netns);
+	close(orig_netns);
+	close(ns_a);
+	close(ns_b);
+	return 0;
+}
+
+static int has_tcp_ao(int *err)
+{
+	struct sockaddr_in addr = {
+		.sin_family = test_family,
+	};
+	struct tcp_ao_add tmp = {};
+	const char *password = DEFAULT_TEST_PASSWORD;
+	int sk, ret = 0;
+
+	sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+	if (sk < 0) {
+		test_print("socket(): %m");
+		return -errno;
+	}
+
+	tmp.sndid = 100;
+	tmp.rcvid = 100;
+	tmp.keylen = strlen(password);
+	memcpy(tmp.key, password, strlen(password));
+	strcpy(tmp.alg_name, "hmac(sha1)");
+	memcpy(&tmp.addr, &addr, sizeof(addr));
+	*err = 0;
+	if (setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp)) < 0) {
+		*err = errno;
+		if (errno != ENOPROTOOPT)
+			ret = -errno;
+	}
+	close(sk);
+	return ret;
+}
+
+static int has_tcp_md5(int *err)
+{
+	union tcp_addr addr_any = {};
+	int sk, ret = 0;
+
+	sk = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+	if (sk < 0) {
+		test_print("socket(): %m");
+		return -errno;
+	}
+
+	/*
+	 * Under CONFIG_CRYPTO_FIPS=y it fails with ENOMEM, rather than with
+	 * anything more descriptive: treat it like ENOPROTOOPT (unsupported).
+	 */
+	*err = 0;
+	if (test_set_md5(sk, addr_any, 0, -1, DEFAULT_TEST_PASSWORD)) {
+		*err = errno;	/* non-zero *err marks the feature as missing */
+		if (errno != ENOPROTOOPT && errno != ENOMEM) {
+			test_print("setsockopt(TCP_MD5SIG_EXT): %m");
+			ret = -errno;	/* unexpected errno: the probe itself failed */
+		}
+	}
+	close(sk);
+	return ret;
+}
+
+static int has_vrfs(int *err)
+{
+	int orig_netns, ns_test, ret = 0;
+
+	orig_netns = open_netns();
+	ns_test = unshare_open_netns();
+
+	*err = add_vrf("ksft-check", 55, 101, ns_test);
+	if (*err && *err != -EOPNOTSUPP) {
+		test_print("Failed to add a VRF: %d", *err);
+		ret = *err;
+	}
+
+	switch_ns(orig_netns);
+	close(orig_netns);
+	close(ns_test);
+	return ret;
+}
+
+static pthread_mutex_t kconfig_lock = PTHREAD_MUTEX_INITIALIZER;
+static struct kconfig_t kconfig[__KCONFIG_LAST__] = {
+	{ -1, has_net_ns },
+	{ -1, has_veth },
+	{ -1, has_tcp_ao },
+	{ -1, has_tcp_md5 },
+	{ -1, has_vrfs },
+};
+
+const char *tests_skip_reason[__KCONFIG_LAST__] = {
+	"Tests require network namespaces support (CONFIG_NET_NS)",
+	"Tests require veth support (CONFIG_VETH)",
+	"Tests require TCP-AO support (CONFIG_TCP_AO)",
+	"setsockopt(TCP_MD5SIG_EXT) is not supported (CONFIG_TCP_MD5)",
+	"VRFs are not supported (CONFIG_NET_VRF)",
+};
+
+bool kernel_config_has(enum test_needs_kconfig k)
+{
+	bool ret;
+
+	pthread_mutex_lock(&kconfig_lock);
+	if (kconfig[k]._errno == -1) {
+		if (kconfig[k].check_kconfig(&kconfig[k]._errno))
+			test_error("Failed to initialize kconfig %u", k);
+	}
+	ret = kconfig[k]._errno == 0;
+	pthread_mutex_unlock(&kconfig_lock);
+	return ret;
+}
diff --git a/tools/testing/selftests/net/tcp_ao/lib/netlink.c b/tools/testing/selftests/net/tcp_ao/lib/netlink.c
new file mode 100644
index 000000000000..b731f2c84083
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/lib/netlink.c
@@ -0,0 +1,415 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Original from tools/testing/selftests/net/ipsec.c */
+#include <linux/netlink.h>
+#include <linux/random.h>
+#include <linux/rtnetlink.h>
+#include <linux/veth.h>
+#include <net/if.h>
+#include <stdint.h>
+#include <string.h>
+#include <sys/socket.h>
+
+#include "aolib.h"
+
+#define MAX_PAYLOAD		2048
+
+/*
+ * Lazily open a netlink socket of protocol @proto.
+ *
+ * If *@sock already refers to an open socket (> 0), only the sequence
+ * number is advanced. Otherwise a new socket is created and *@seq_nr is
+ * seeded via randomize_buffer().
+ *
+ * Returns 0 on success, -1 if socket() failed.
+ */
+static int netlink_sock(int *sock, uint32_t *seq_nr, int proto)
+{
+	if (*sock > 0) {
+		/* Bump the sequence number itself, not the pointer to it */
+		(*seq_nr)++;
+		return 0;
+	}
+
+	*sock = socket(AF_NETLINK, SOCK_RAW | SOCK_CLOEXEC, proto);
+	if (*sock < 0) {
+		test_print("socket(AF_NETLINK)");
+		return -1;
+	}
+
+	randomize_buffer(seq_nr, sizeof(*seq_nr));
+
+	return 0;
+}
+
+/*
+ * Read the kernel's reply to a request sent with NLM_F_ACK.
+ *
+ * Returns 0 for a clean ACK, the (negative) netlink error code carried by
+ * the NLMSG_ERROR message, or -1 when recv() fails or the reply is not an
+ * NLMSG_ERROR message. With @quite set, a non-zero netlink error is
+ * returned without being printed.
+ */
+static int netlink_check_answer(int sock, bool quite)
+{
+	struct nlmsgerror {
+		struct nlmsghdr hdr;
+		int error;
+		struct nlmsghdr orig_msg;
+	} answer;
+
+	if (recv(sock, &answer, sizeof(answer), 0) < 0) {
+		test_print("recv()");
+		return -1;
+	}
+
+	if (answer.hdr.nlmsg_type != NLMSG_ERROR) {
+		test_print("expected NLMSG_ERROR, got %d",
+			   (int)answer.hdr.nlmsg_type);
+		return -1;
+	}
+
+	if (!answer.error)
+		return 0;
+
+	if (!quite)
+		test_print("NLMSG_ERROR: %d: %s",
+			answer.error, strerror(-answer.error));
+	return answer.error;
+}
+
+/* Address just past the (RTA-aligned) current end of message @nh */
+static inline struct rtattr *rtattr_hdr(struct nlmsghdr *nh)
+{
+	char *msg_start = (char *)nh;
+
+	return (struct rtattr *)(msg_start + RTA_ALIGN(nh->nlmsg_len));
+}
+
+/*
+ * Append one rtattr of type @rta_type carrying @size bytes of @payload
+ * to the message headed by @nh.
+ *
+ * @req_sz is the size of the whole request buffer; the attribute is only
+ * written when it fits, and nh->nlmsg_len is grown to cover it.
+ * @payload may be NULL when @size is 0 (see rtattr_begin()).
+ *
+ * Returns 0 on success, -1 if the buffer is too small.
+ */
+static int rtattr_pack(struct nlmsghdr *nh, size_t req_sz,
+		unsigned short rta_type, const void *payload, size_t size)
+{
+	/* NLMSG_ALIGNTO == RTA_ALIGNTO, nlmsg_len already aligned */
+	struct rtattr *attr = rtattr_hdr(nh);
+	size_t nl_size = RTA_ALIGN(nh->nlmsg_len) + RTA_LENGTH(size);
+
+	if (req_sz < nl_size) {
+		test_print("req buf is too small: %zu < %zu", req_sz, nl_size);
+		return -1;
+	}
+	nh->nlmsg_len = nl_size;
+
+	attr->rta_len = RTA_LENGTH(size);
+	attr->rta_type = rta_type;
+	/* memcpy() with a NULL source is undefined even for 0 bytes */
+	if (size)
+		memcpy(RTA_DATA(attr), payload, size);
+
+	return 0;
+}
+
+/*
+ * Start a nested attribute: pack @rta_type with an optional fixed
+ * @payload and return the attribute's address, so that rtattr_end() can
+ * later patch its length. Returns NULL if the attribute doesn't fit.
+ */
+static struct rtattr *_rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
+		unsigned short rta_type, const void *payload, size_t size)
+{
+	/* Must be taken before packing moves the end of the message */
+	struct rtattr *nest = rtattr_hdr(nh);
+	int err = rtattr_pack(nh, req_sz, rta_type, payload, size);
+
+	return err ? NULL : nest;
+}
+
+/* Start a nested attribute of @rta_type that has no payload of its own */
+static inline struct rtattr *rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
+		unsigned short rta_type)
+{
+	return _rtattr_begin(nh, req_sz, rta_type, NULL, 0);
+}
+
+/*
+ * Close a nested attribute opened by rtattr_begin()/_rtattr_begin():
+ * its length becomes everything from @attr up to the current end of @nh.
+ */
+static inline void rtattr_end(struct nlmsghdr *nh, struct rtattr *attr)
+{
+	char *nest_start = (char *)attr;
+	char *msg_end = (char *)nh + nh->nlmsg_len;
+
+	attr->rta_len = msg_end - nest_start;
+}
+
+/*
+ * Pack the VETH_INFO_PEER nest describing the second end of a veth pair:
+ * an ifinfomsg header followed by the peer's name and the fd of the
+ * netns (@ns) it should be created in.
+ * Returns 0 on success, -1 if the request buffer is too small.
+ */
+static int veth_pack_peerb(struct nlmsghdr *nh, size_t req_sz,
+		const char *peer, int ns)
+{
+	struct rtattr *nest;
+	struct ifinfomsg peer_info;
+
+	memset(&peer_info, 0, sizeof(peer_info));
+	peer_info.ifi_family	= AF_UNSPEC;
+	peer_info.ifi_change	= 0xFFFFFFFF;
+
+	nest = _rtattr_begin(nh, req_sz, VETH_INFO_PEER,
+			     &peer_info, sizeof(peer_info));
+	if (nest == NULL)
+		return -1;
+
+	if (rtattr_pack(nh, req_sz, IFLA_IFNAME, peer, strlen(peer)) ||
+	    rtattr_pack(nh, req_sz, IFLA_NET_NS_FD, &ns, sizeof(ns)))
+		return -1;
+
+	rtattr_end(nh, nest);
+	return 0;
+}
+
+/*
+ * Build and send an RTM_NEWLINK request creating a veth pair, then wait
+ * for the kernel's ACK.
+ *
+ * Attribute layout of the request:
+ *   IFLA_IFNAME      = @name
+ *   IFLA_NET_NS_FD   = @ns_a
+ *   IFLA_LINKINFO {
+ *     IFLA_INFO_KIND = "veth"
+ *     IFLA_INFO_DATA { VETH_INFO_PEER { ifinfomsg, name, @ns_b } }
+ *   }
+ * Both ends get the same @name, each in its own netns.
+ *
+ * Returns 0 on success, -1 on a local failure, or the netlink error code.
+ */
+static int __add_veth(int sock, uint32_t seq, const char *name,
+		      int ns_a, int ns_b)
+{
+	uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
+	struct {
+		struct nlmsghdr		nh;
+		struct ifinfomsg	info;
+		char			attrbuf[MAX_PAYLOAD];
+	} req;
+	static const char veth_type[] = "veth";
+	struct rtattr *link_info, *info_data;
+
+	memset(&req, 0, sizeof(req));
+	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
+	req.nh.nlmsg_type	= RTM_NEWLINK;
+	req.nh.nlmsg_flags	= flags;
+	req.nh.nlmsg_seq	= seq;
+	req.info.ifi_family	= AF_UNSPEC;
+	req.info.ifi_change	= 0xFFFFFFFF;
+
+	if (rtattr_pack(&req.nh, sizeof(req), IFLA_IFNAME, name, strlen(name)))
+		return -1;
+
+	if (rtattr_pack(&req.nh, sizeof(req), IFLA_NET_NS_FD, &ns_a, sizeof(ns_a)))
+		return -1;
+
+	/* Outer nest: closed last, after everything inside is packed */
+	link_info = rtattr_begin(&req.nh, sizeof(req), IFLA_LINKINFO);
+	if (!link_info)
+		return -1;
+
+	if (rtattr_pack(&req.nh, sizeof(req), IFLA_INFO_KIND, veth_type, sizeof(veth_type)))
+		return -1;
+
+	info_data = rtattr_begin(&req.nh, sizeof(req), IFLA_INFO_DATA);
+	if (!info_data)
+		return -1;
+
+	if (veth_pack_peerb(&req.nh, sizeof(req), name, ns_b))
+		return -1;
+
+	/* Close the nests inner-first so each length covers its children */
+	rtattr_end(&req.nh, info_data);
+	rtattr_end(&req.nh, link_info);
+
+	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
+		test_print("send()");
+		return -1;
+	}
+	return netlink_check_answer(sock, false);
+}
+
+/*
+ * Create a veth pair called @name with one end in the netns referred to
+ * by @nsfda and the other in @nsfdb. Uses a short-lived NETLINK_ROUTE
+ * socket. Returns __add_veth()'s result.
+ */
+int add_veth(const char *name, int nsfda, int nsfdb)
+{
+	uint32_t seq;
+	int nl_sk = -1;
+	int err;
+
+	if (netlink_sock(&nl_sk, &seq, NETLINK_ROUTE))
+		test_error("Failed to open netlink route socket\n");
+
+	err = __add_veth(nl_sk, seq, name, nsfda, nsfdb);
+	close(nl_sk);
+	return err;
+}
+
+/*
+ * Send an RTM_NEWADDR request assigning @addr/@prefix to interface @intf
+ * (with IFA_F_NODAD) and wait for the ACK.
+ * Netlink errors are returned without being printed.
+ *
+ * Returns 0 on success, -1 on a local failure, or the netlink error code.
+ */
+static int __ip_addr_add(int sock, uint32_t seq, const char *intf,
+			 int family, union tcp_addr addr, uint8_t prefix)
+{
+	struct {
+		struct nlmsghdr		nh;
+		struct ifaddrmsg	info;
+		char			attrbuf[MAX_PAYLOAD];
+	} req;
+	size_t addr_len;
+
+	if (family == AF_INET)
+		addr_len = sizeof(struct in_addr);
+	else
+		addr_len = sizeof(struct in6_addr);
+
+	memset(&req, 0, sizeof(req));
+
+	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
+	req.nh.nlmsg_type	= RTM_NEWADDR;
+	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK |
+				  NLM_F_EXCL | NLM_F_CREATE;
+	req.nh.nlmsg_seq	= seq;
+
+	req.info.ifa_family	= family;
+	req.info.ifa_prefixlen	= prefix;
+	req.info.ifa_flags	= IFA_F_NODAD;
+	req.info.ifa_index	= if_nametoindex(intf);
+
+	if (rtattr_pack(&req.nh, sizeof(req), IFA_LOCAL, &addr, addr_len))
+		return -1;
+
+	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
+		test_print("send()");
+		return -1;
+	}
+
+	return netlink_check_answer(sock, true);
+}
+
+/*
+ * Assign @addr/@prefix of address @family to interface @intf in the
+ * current netns, using a short-lived NETLINK_ROUTE socket.
+ * Returns __ip_addr_add()'s result.
+ */
+int ip_addr_add(const char *intf, int family,
+		union tcp_addr addr, uint8_t prefix)
+{
+	uint32_t seq;
+	int nl_sk = -1;
+	int err;
+
+	if (netlink_sock(&nl_sk, &seq, NETLINK_ROUTE))
+		test_error("Failed to open netlink route socket\n");
+
+	err = __ip_addr_add(nl_sk, seq, intf, family, addr, prefix);
+	close(nl_sk);
+
+	return err;
+}
+
+/*
+ * Send an RTM_NEWROUTE request adding a host route (/32 or /128) to @dst
+ * via interface @intf with preferred source @src, then wait for the ACK.
+ *
+ * @vrf selects the routing table the route is installed into;
+ * ip_route_add() passes RT_TABLE_MAIN.
+ * Netlink errors are returned without being printed.
+ *
+ * Returns 0 on success, -1 on a local failure, or the netlink error code.
+ */
+static int __ip_route_add(int sock, uint32_t seq, const char *intf, int family,
+			  union tcp_addr src, union tcp_addr dst, uint8_t vrf)
+{
+	struct {
+		struct nlmsghdr	nh;
+		struct rtmsg	rt;
+		char		attrbuf[MAX_PAYLOAD];
+	} req;
+	unsigned int index = if_nametoindex(intf);
+	size_t addr_len = (family == AF_INET) ? sizeof(struct in_addr) :
+						sizeof(struct in6_addr);
+
+	memset(&req, 0, sizeof(req));
+	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.rt));
+	req.nh.nlmsg_type	= RTM_NEWROUTE;
+	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE;
+	req.nh.nlmsg_seq	= seq;
+	req.rt.rtm_family	= family;
+	req.rt.rtm_dst_len	= (family == AF_INET) ? 32 : 128;
+	/* Honour the caller's table instead of always using the main one */
+	req.rt.rtm_table	= vrf;
+	req.rt.rtm_protocol	= RTPROT_BOOT;
+	req.rt.rtm_scope	= RT_SCOPE_UNIVERSE;
+	req.rt.rtm_type		= RTN_UNICAST;
+
+	if (rtattr_pack(&req.nh, sizeof(req), RTA_DST, &dst, addr_len))
+		return -1;
+
+	if (rtattr_pack(&req.nh, sizeof(req), RTA_PREFSRC, &src, addr_len))
+		return -1;
+
+	if (rtattr_pack(&req.nh, sizeof(req), RTA_OIF, &index, sizeof(index)))
+		return -1;
+
+	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
+		test_print("send()");
+		return -1;
+	}
+
+	return netlink_check_answer(sock, true);
+}
+
+/*
+ * Add a host route to @dst via @intf with preferred source @src, passing
+ * @vrf down to __ip_route_add(). An already existing route (-EEXIST) is
+ * treated as success. Uses a short-lived NETLINK_ROUTE socket.
+ */
+int ip_route_add_vrf(const char *intf, int family,
+		 union tcp_addr src, union tcp_addr dst, uint8_t vrf)
+{
+	uint32_t seq;
+	int nl_sk = -1;
+	int err;
+
+	if (netlink_sock(&nl_sk, &seq, NETLINK_ROUTE))
+		test_error("Failed to open netlink route socket\n");
+
+	err = __ip_route_add(nl_sk, seq, intf, family, src, dst, vrf);
+	close(nl_sk);
+
+	/* the route being there already is fine */
+	return (err == -EEXIST) ? 0 : err;
+}
+
+/* Same as ip_route_add_vrf(), with RT_TABLE_MAIN passed as the vrf */
+int ip_route_add(const char *intf, int family,
+		 union tcp_addr src, union tcp_addr dst)
+{
+	return ip_route_add_vrf(intf, family, src, dst, RT_TABLE_MAIN);
+}
+
+/*
+ * Send an RTM_NEWLINK request that sets IFF_UP on interface @intf and
+ * wait for the ACK. Only the IFF_UP bit is marked as changed.
+ *
+ * Returns 0 on success, -1 if send()/recv() failed, or the netlink error.
+ */
+static int __link_set_up(int sock, uint32_t seq, const char *intf)
+{
+	struct {
+		struct nlmsghdr		nh;
+		struct ifinfomsg	info;
+		char			attrbuf[MAX_PAYLOAD];
+	} req;
+
+	memset(&req, 0, sizeof(req));
+
+	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
+	req.nh.nlmsg_type	= RTM_NEWLINK;
+	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
+	req.nh.nlmsg_seq	= seq;
+
+	req.info.ifi_family	= AF_UNSPEC;
+	req.info.ifi_index	= if_nametoindex(intf);
+	req.info.ifi_flags	= IFF_UP;
+	/* change mask: touch nothing but the IFF_UP bit */
+	req.info.ifi_change	= IFF_UP;
+
+	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
+		test_print("send()");
+		return -1;
+	}
+
+	return netlink_check_answer(sock, false);
+}
+
+/*
+ * Bring interface @intf up in the current netns, using a short-lived
+ * NETLINK_ROUTE socket. Returns __link_set_up()'s result.
+ */
+int link_set_up(const char *intf)
+{
+	uint32_t seq;
+	int nl_sk = -1;
+	int err;
+
+	if (netlink_sock(&nl_sk, &seq, NETLINK_ROUTE))
+		test_error("Failed to open netlink route socket\n");
+
+	err = __link_set_up(nl_sk, seq, intf);
+	close(nl_sk);
+
+	return err;
+}
+
+/*
+ * Build and send an RTM_NEWLINK request creating a VRF device, then wait
+ * for the kernel's ACK.
+ *
+ * Attribute layout of the request:
+ *   IFLA_IFNAME      = @name
+ *   IFLA_NET_NS_FD   = @nsfd          (only when @nsfd >= 0)
+ *   IFLA_LINKINFO {
+ *     IFLA_INFO_KIND = "vrf"
+ *     IFLA_INFO_DATA { IFLA_VRF_TABLE = @tabid }
+ *   }
+ * @ifindex is placed into the ifinfomsg header.
+ * Netlink errors are returned without being printed.
+ *
+ * Returns 0 on success, -1 on a local failure, or the netlink error code.
+ */
+static int __add_vrf(int sock, uint32_t seq, const char *name,
+		     uint32_t tabid, int ifindex, int nsfd)
+{
+	uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
+	struct {
+		struct nlmsghdr		nh;
+		struct ifinfomsg	info;
+		char			attrbuf[MAX_PAYLOAD];
+	} req;
+	static const char vrf_type[] = "vrf";
+	struct rtattr *link_info, *info_data;
+
+	memset(&req, 0, sizeof(req));
+	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
+	req.nh.nlmsg_type	= RTM_NEWLINK;
+	req.nh.nlmsg_flags	= flags;
+	req.nh.nlmsg_seq	= seq;
+	req.info.ifi_family	= AF_UNSPEC;
+	req.info.ifi_change	= 0xFFFFFFFF;
+	req.info.ifi_index	= ifindex;
+
+	if (rtattr_pack(&req.nh, sizeof(req), IFLA_IFNAME, name, strlen(name)))
+		return -1;
+
+	if (nsfd >= 0)
+		if (rtattr_pack(&req.nh, sizeof(req), IFLA_NET_NS_FD,
+				&nsfd, sizeof(nsfd)))
+			return -1;
+
+	/* Outer nest: closed last, after everything inside is packed */
+	link_info = rtattr_begin(&req.nh, sizeof(req), IFLA_LINKINFO);
+	if (!link_info)
+		return -1;
+
+	if (rtattr_pack(&req.nh, sizeof(req), IFLA_INFO_KIND, vrf_type, sizeof(vrf_type)))
+		return -1;
+
+	info_data = rtattr_begin(&req.nh, sizeof(req), IFLA_INFO_DATA);
+	if (!info_data)
+		return -1;
+
+	if (rtattr_pack(&req.nh, sizeof(req), IFLA_VRF_TABLE,
+			&tabid, sizeof(tabid)))
+		return -1;
+
+	/* Close the nests inner-first so each length covers its children */
+	rtattr_end(&req.nh, info_data);
+	rtattr_end(&req.nh, link_info);
+
+	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
+		test_print("send()");
+		return -1;
+	}
+	return netlink_check_answer(sock, true);
+}
+
+/*
+ * Create a VRF device @name bound to routing table @tabid, optionally in
+ * the netns referred to by @nsfd (skipped when negative). Uses a
+ * short-lived NETLINK_ROUTE socket. Returns __add_vrf()'s result.
+ */
+int add_vrf(const char *name, uint32_t tabid, int ifindex, int nsfd)
+{
+	uint32_t seq;
+	int nl_sk = -1;
+	int err;
+
+	if (netlink_sock(&nl_sk, &seq, NETLINK_ROUTE))
+		test_error("Failed to open netlink route socket\n");
+
+	err = __add_vrf(nl_sk, seq, name, tabid, ifindex, nsfd);
+	close(nl_sk);
+	return err;
+}
diff --git a/tools/testing/selftests/net/tcp_ao/lib/proc.c b/tools/testing/selftests/net/tcp_ao/lib/proc.c
new file mode 100644
index 000000000000..2322f4d4676d
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/lib/proc.c
@@ -0,0 +1,273 @@
+// SPDX-License-Identifier: GPL-2.0
+#include <inttypes.h>
+#include <pthread.h>
+#include <stdio.h>
+#include "../../../../../include/linux/compiler.h"
+#include "../../../../../include/linux/kernel.h"
+#include "aolib.h"
+
+/* A single named counter; @name is heap-allocated (see netstat_free()) */
+struct netstat_counter {
+	uint64_t val;
+	char *name;
+};
+
+/*
+ * Counters of one type (the "Name:" prefix of a line in the proc file).
+ * Types are chained into a singly linked list through @next.
+ */
+struct netstat {
+	char *header_name;
+	struct netstat *next;
+	size_t counters_nr;
+	struct netstat_counter *counters;
+};
+
+/*
+ * Find the list element whose header_name equals the first @len bytes of
+ * @type (@type does not have to be NUL-terminated at @len).
+ * Returns NULL when there is no such element.
+ */
+static struct netstat *lookup_type(struct netstat *ns,
+		const char *type, size_t len)
+{
+	struct netstat *cur;
+
+	for (cur = ns; cur != NULL; cur = cur->next) {
+		size_t name_len = strlen(cur->header_name);
+		/* compare over the longer one to reject mere prefixes */
+		size_t cmp = max(len, name_len);
+
+		if (strncmp(cur->header_name, type, cmp) == 0)
+			return cur;
+	}
+	return NULL;
+}
+
+/*
+ * Return the list element for type @type (@len bytes long).
+ * If it doesn't exist yet, a new empty element is allocated and linked
+ * in front of @ns; the caller has to store the returned pointer as the
+ * new list head in that case.
+ */
+static struct netstat *lookup_get(struct netstat *ns,
+				  const char *type, const size_t len)
+{
+	struct netstat *found = lookup_type(ns, type, len);
+	struct netstat *fresh;
+
+	if (found)
+		return found;
+
+	fresh = malloc(sizeof(struct netstat));
+	if (fresh == NULL)
+		test_error("malloc()");
+
+	fresh->header_name = strndup(type, len);
+	if (!fresh->header_name)
+		test_error("strndup()");
+	fresh->counters = NULL;
+	fresh->counters_nr = 0;
+	fresh->next = ns;
+
+	return fresh;
+}
+
+/*
+ * Get (or create) the list element for the type named before the first
+ * ':' on @line.
+ */
+static struct netstat *lookup_get_column(struct netstat *ns, const char *line)
+{
+	const char *colon = strchr(line, ':');
+
+	if (colon == NULL)
+		test_error("can't parse netstat file");
+
+	return lookup_get(ns, line, colon - line);
+}
+
+/*
+ * Parse one header/values line pair of a netstat-style proc file.
+ *
+ * @line is the header line ("Type: Name1 Name2 ...") already read by the
+ * caller; the values line that follows it is consumed here directly from
+ * @fnetstat. The counters are appended to the element for "Type" and
+ * *@dest is updated to point at that element.
+ */
+static void netstat_read_type(FILE *fnetstat, struct netstat **dest, char *line)
+{
+	struct netstat *type = lookup_get_column(*dest, line);
+	const char *pos = line;
+	size_t i, nr_elems = 0;
+	char tmp;
+
+	/* Every counter name is preceded by a space: count the spaces */
+	while ((pos = strchr(pos, ' '))) {
+		nr_elems++;
+		pos++;
+	}
+
+	*dest = type;
+	type->counters = reallocarray(type->counters,
+				type->counters_nr + nr_elems,
+				sizeof(struct netstat_counter));
+	if (!type->counters)
+		test_error("reallocarray()");
+
+	/* NOTE(review): assumes the header has at least one space - confirm */
+	pos = strchr(line, ' ') + 1;
+
+	/*
+	 * Skip the "Type" prefix of the values line: the header name is
+	 * used as a scanf format made of literal characters only.
+	 * NOTE(review): relies on header names not containing '%'.
+	 */
+	if (fscanf(fnetstat, type->header_name) == EOF)
+		test_error("fscanf(%s)", type->header_name);
+	if (fread(&tmp, 1, 1, fnetstat) != 1 || tmp != ':')
+		test_error("Unexpected netstat format (%c)", tmp);
+
+	/* Walk names in @line and values in the stream in lockstep */
+	for (i = type->counters_nr; i < type->counters_nr + nr_elems; i++) {
+		struct netstat_counter *nc = &type->counters[i];
+		const char *new_pos = strchr(pos, ' ');
+		const char *fmt = " %" PRIu64;
+
+		/* the last name ends with the newline, not a space */
+		if (new_pos == NULL)
+			new_pos = strchr(pos, '\n');
+
+		nc->name = strndup(pos, new_pos - pos);
+		if (nc->name == NULL)
+			test_error("strndup()");
+
+		if (unlikely(!strcmp(nc->name, "MaxConn")))
+			fmt = " %" PRId64; /* MaxConn is signed, RFC 2012 */
+		if (fscanf(fnetstat, fmt, &nc->val) != 1)
+			test_error("fscanf(%s)", nc->name);
+		pos = new_pos + 1;
+	}
+	type->counters_nr += nr_elems;
+
+	/* The values line must end right after the last counter */
+	if (fread(&tmp, 1, 1, fnetstat) != 1 || tmp != '\n')
+		test_error("Unexpected netstat format");
+}
+
+/* All snmp6 counters are filed under this single type name */
+static const char *snmp6_name = "Snmp6";
+/*
+ * Read "name value" pairs from @fnetstat until end of file and append
+ * them to the "Snmp6" element, which *@dest is updated to point at.
+ * Counter names are allocated by fscanf("%ms") and owned by the list.
+ */
+static void snmp6_read(FILE *fnetstat, struct netstat **dest)
+{
+	struct netstat *type = lookup_get(*dest, snmp6_name, strlen(snmp6_name));
+	char *counter_name;
+	size_t i;
+
+	for (i = type->counters_nr;; i++) {
+		struct netstat_counter *nc;
+		uint64_t counter;
+
+		if (fscanf(fnetstat, "%ms", &counter_name) != 1)
+			break;
+		/*
+		 * A matching failure returns 0 rather than EOF and would
+		 * leave @counter uninitialized: require one conversion.
+		 */
+		if (fscanf(fnetstat, "%" PRIu64, &counter) != 1)
+			test_error("Unexpected snmp6 format");
+		type->counters = reallocarray(type->counters, i + 1,
+					sizeof(struct netstat_counter));
+		if (!type->counters)
+			test_error("reallocarray()");
+		nc = &type->counters[i];
+		nc->name = counter_name;
+		nc->val = counter;
+	}
+	type->counters_nr = i;
+	*dest = type;
+}
+
+/*
+ * Take a snapshot of the netstat, snmp and snmp6 counters of the calling
+ * thread's network namespace.
+ * Returns the head of a newly allocated list; release it with
+ * netstat_free().
+ */
+struct netstat *netstat_read(void)
+{
+	struct netstat *stats = NULL;
+	char *linebuf = NULL;
+	size_t linebuf_sz = 0;
+	FILE *f;
+
+	/*
+	 * Opening thread-self instead of /proc/net/... as the latter
+	 * points to /proc/self/net/ which instantiates thread-leader's
+	 * net-ns, see:
+	 * commit 155134fef2b6 ("Revert "proc: Point /proc/{mounts,net} at..")
+	 */
+	errno = 0;
+	f = fopen("/proc/thread-self/net/netstat", "r");
+	if (!f)
+		test_error("failed to open /proc/net/netstat");
+	for (;;) {
+		if (getline(&linebuf, &linebuf_sz, f) == -1)
+			break;
+		netstat_read_type(f, &stats, linebuf);
+	}
+	fclose(f);
+
+	errno = 0;
+	f = fopen("/proc/thread-self/net/snmp", "r");
+	if (!f)
+		test_error("failed to open /proc/net/snmp");
+	for (;;) {
+		if (getline(&linebuf, &linebuf_sz, f) == -1)
+			break;
+		netstat_read_type(f, &stats, linebuf);
+	}
+	fclose(f);
+
+	/* snmp6 uses a different, one-counter-per-line layout */
+	errno = 0;
+	f = fopen("/proc/thread-self/net/snmp6", "r");
+	if (!f)
+		test_error("failed to open /proc/net/snmp6");
+	snmp6_read(f, &stats);
+	fclose(f);
+
+	free(linebuf);
+	return stats;
+}
+
+/* Release a whole list returned by netstat_read(); NULL is accepted */
+void netstat_free(struct netstat *ns)
+{
+	struct netstat *next;
+
+	for (; ns != NULL; ns = next) {
+		size_t idx;
+
+		next = ns->next;
+
+		for (idx = 0; idx < ns->counters_nr; idx++)
+			free(ns->counters[idx].name);
+		free(ns->counters);
+		free(ns->header_name);
+		free(ns);
+	}
+}
+
+/*
+ * Print "old => new" for counter @i of @nsb, @a being the old value.
+ * MaxConn is the one counter that is parsed as signed (see
+ * netstat_read_type()), so it is printed as signed here as well.
+ */
+static inline void
+__netstat_print_diff(uint64_t a, struct netstat *nsb, size_t i)
+{
+	const struct netstat_counter *nc = &nsb->counters[i];
+
+	/* "MaxConn" is a counter name, not a type (header) name */
+	if (unlikely(!strcmp(nc->name, "MaxConn"))) {
+		test_print("%8s %25s: %" PRId64 " => %" PRId64,
+				nsb->header_name, nc->name,
+				(int64_t)a, (int64_t)nc->val);
+		return;
+	}
+
+	test_print("%8s %25s: %" PRIu64 " => %" PRIu64, nsb->header_name,
+			nc->name, a, nc->val);
+}
+
+/*
+ * Print every counter whose value differs between the older snapshot
+ * @nsa and the newer snapshot @nsb.
+ *
+ * Both lists are walked in parallel. Types or counters that exist only
+ * in @nsb are printed as changing from 0. Counters that exist in @nsa
+ * but not in @nsb are reported through test_error().
+ */
+void netstat_print_diff(struct netstat *nsa, struct netstat *nsb)
+{
+	size_t i, j;
+
+	while (nsb != NULL) {
+		/* A type only in @nsb, including after @nsa ran out */
+		if (unlikely(nsa == NULL ||
+			     strcmp(nsb->header_name, nsa->header_name))) {
+			for (i = 0; i < nsb->counters_nr; i++)
+				__netstat_print_diff(0, nsb, i);
+			nsb = nsb->next;
+			continue;
+		}
+
+		if (nsb->counters_nr < nsa->counters_nr)
+			test_error("Unexpected: some counters dissapeared!");
+
+		for (j = 0, i = 0; i < nsb->counters_nr; i++) {
+			/*
+			 * A counter only in @nsb; the bound check keeps
+			 * trailing new counters from reading past the end
+			 * of nsa->counters[].
+			 */
+			if (j >= nsa->counters_nr ||
+			    strcmp(nsb->counters[i].name, nsa->counters[j].name)) {
+				__netstat_print_diff(0, nsb, i);
+				continue;
+			}
+
+			if (nsa->counters[j].val == nsb->counters[i].val) {
+				j++;
+				continue;
+			}
+
+			__netstat_print_diff(nsa->counters[j].val, nsb, i);
+			j++;
+		}
+		if (j != nsa->counters_nr)
+			test_error("Unexpected: some counters dissapeared!");
+
+		nsb = nsb->next;
+		nsa = nsa->next;
+	}
+}
+
+/*
+ * Look up counter @name across all types of snapshot @ns.
+ *
+ * Returns the value of the first match, or 0 if there is none. When
+ * @not_found is non-NULL it tells the two cases apart.
+ */
+uint64_t netstat_get(struct netstat *ns, const char *name, bool *not_found)
+{
+	struct netstat *type;
+	size_t idx;
+
+	for (type = ns; type != NULL; type = type->next) {
+		for (idx = 0; idx < type->counters_nr; idx++) {
+			struct netstat_counter *nc = &type->counters[idx];
+
+			if (strcmp(name, nc->name))
+				continue;
+
+			if (not_found)
+				*not_found = false;
+			return nc->val;
+		}
+	}
+
+	if (not_found)
+		*not_found = true;
+	return 0;
+}
diff --git a/tools/testing/selftests/net/tcp_ao/lib/repair.c b/tools/testing/selftests/net/tcp_ao/lib/repair.c
new file mode 100644
index 000000000000..9893b3ba69f5
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/lib/repair.c
@@ -0,0 +1,254 @@
+// SPDX-License-Identifier: GPL-2.0
+/* This is over-simplified TCP_REPAIR for TCP_ESTABLISHED sockets
+ * It tests that TCP-AO enabled connection can be restored.
+ * For the proper socket repair see:
+ * https://github.com/checkpoint-restore/criu/blob/criu-dev/soccr/soccr.h
+ */
+#include <fcntl.h>
+#include <linux/sockios.h>
+#include <sys/ioctl.h>
+#include "aolib.h"
+
+/*
+ * Fallback definitions of the TCP option kinds that are used below as
+ * tcp_repair_opt.opt_code, in case the included headers lack them.
+ */
+#ifndef TCPOPT_MAXSEG
+# define TCPOPT_MAXSEG		2
+#endif
+#ifndef TCPOPT_WINDOW
+# define TCPOPT_WINDOW		3
+#endif
+#ifndef TCPOPT_SACK_PERMITTED
+# define TCPOPT_SACK_PERMITTED	4
+#endif
+#ifndef TCPOPT_TIMESTAMP
+# define TCPOPT_TIMESTAMP	8
+#endif
+
+/*
+ * TCP socket states, compared against tcp_info.tcpi_state below.
+ * NOTE(review): presumably a copy of the kernel's state numbering -
+ * confirm against include/net/tcp_states.h.
+ */
+enum {
+	TCP_ESTABLISHED = 1,
+	TCP_SYN_SENT,
+	TCP_SYN_RECV,
+	TCP_FIN_WAIT1,
+	TCP_FIN_WAIT2,
+	TCP_TIME_WAIT,
+	TCP_CLOSE,
+	TCP_CLOSE_WAIT,
+	TCP_LAST_ACK,
+	TCP_LISTEN,
+	TCP_CLOSING,	/* Now a valid state */
+	TCP_NEW_SYN_RECV,
+
+	TCP_MAX_STATES	/* Leave at the end! */
+};
+
+/*
+ * Save one queue (@queue: TCP_SEND_QUEUE or TCP_RECV_QUEUE) of socket
+ * @sk into @q: its sequence number and, when @qlen is non-zero, a peeked
+ * copy of its @qlen bytes in a malloc()ed q->buf (NULL otherwise).
+ */
+static void test_sock_checkpoint_queue(int sk, int queue, int qlen,
+				       struct tcp_sock_queue *q)
+{
+	socklen_t optlen = sizeof(q->seq);
+	int nread;
+
+	if (setsockopt(sk, SOL_TCP, TCP_REPAIR_QUEUE, &queue, sizeof(queue)))
+		test_error("setsockopt(TCP_REPAIR_QUEUE)");
+
+	if (getsockopt(sk, SOL_TCP, TCP_QUEUE_SEQ, &q->seq, &optlen) ||
+	    optlen != sizeof(q->seq))
+		test_error("getsockopt(TCP_QUEUE_SEQ): %d", (int)optlen);
+
+	q->buf = NULL;
+	if (qlen == 0)
+		return;
+
+	q->buf = malloc(qlen);
+	if (!q->buf)
+		test_error("malloc()");
+
+	/* MSG_PEEK: copy the data out without dequeuing it */
+	nread = recv(sk, q->buf, qlen, MSG_PEEK | MSG_DONTWAIT);
+	if (nread != qlen)
+		test_error("recv(%d): %d", qlen, nread);
+}
+
+/*
+ * Dump the state of TCP socket @sk into @state and its local address
+ * into @addr (which must be exactly @addr_size bytes long).
+ *
+ * Collected: tcp_info, the repair window, lengths and contents of the
+ * send and receive queues, MSS and the TCP timestamp. Any failure is
+ * reported through test_error().
+ * The queue buffers are allocated here; see test_sock_state_free().
+ * NOTE(review): the TCP_REPAIR_* options presumably require the socket
+ * to be in repair mode already (test_enable_repair()) - confirm.
+ */
+void __test_sock_checkpoint(int sk, struct tcp_sock_state *state,
+			    void *addr, size_t addr_size)
+{
+	socklen_t len = sizeof(state->info);
+	int ret;
+
+	memset(state, 0, sizeof(*state));
+
+	ret = getsockopt(sk, SOL_TCP, TCP_INFO, &state->info, &len);
+	if (ret || len != sizeof(state->info))
+		test_error("getsockopt(TCP_INFO): %d", (int)len);
+
+	len = addr_size;
+	if (getsockname(sk, addr, &len) || len != addr_size)
+		test_error("getsockname(): %d", (int)len);
+
+	len = sizeof(state->trw);
+	ret = getsockopt(sk, SOL_TCP, TCP_REPAIR_WINDOW, &state->trw, &len);
+	if (ret || len != sizeof(state->trw))
+		test_error("getsockopt(TCP_REPAIR_WINDOW): %d", (int)len);
+
+	if (ioctl(sk, SIOCOUTQ, &state->outq_len))
+		test_error("ioctl(SIOCOUTQ)");
+
+	if (ioctl(sk, SIOCOUTQNSD, &state->outq_nsd_len))
+		test_error("ioctl(SIOCOUTQNSD)");
+	test_sock_checkpoint_queue(sk, TCP_SEND_QUEUE, state->outq_len, &state->out);
+
+	if (ioctl(sk, SIOCINQ, &state->inq_len))
+		test_error("ioctl(SIOCINQ)");
+	test_sock_checkpoint_queue(sk, TCP_RECV_QUEUE, state->inq_len, &state->in);
+
+	/* The queue lengths are zeroed only after the data was copied */
+	if (state->info.tcpi_state == TCP_CLOSE)
+		state->outq_len = state->outq_nsd_len = 0;
+
+	len = sizeof(state->mss);
+	ret = getsockopt(sk, SOL_TCP, TCP_MAXSEG, &state->mss, &len);
+	if (ret || len != sizeof(state->mss))
+		test_error("getsockopt(TCP_MAXSEG): %d", (int)len);
+
+	len = sizeof(state->timestamp);
+	ret = getsockopt(sk, SOL_TCP, TCP_TIMESTAMP, &state->timestamp, &len);
+	if (ret || len != sizeof(state->timestamp))
+		test_error("getsockopt(TCP_TIMESTAMP): %d", (int)len);
+}
+
+/*
+ * Read the TCP-AO repair state of @sk into @state via
+ * getsockopt(TCP_AO_REPAIR). A failure or an unexpected length is
+ * reported through test_error().
+ */
+void test_ao_checkpoint(int sk, struct tcp_ao_repair *state)
+{
+	socklen_t optlen = sizeof(*state);
+
+	memset(state, 0, sizeof(*state));
+
+	if (getsockopt(sk, SOL_TCP, TCP_AO_REPAIR, state, &optlen) ||
+	    optlen != sizeof(*state))
+		test_error("getsockopt(TCP_AO_REPAIR): %d", (int)optlen);
+}
+
+/*
+ * Set the sequence number of one queue of @sk: select @queue with
+ * TCP_REPAIR_QUEUE, then write @seq with TCP_QUEUE_SEQ.
+ */
+static void test_sock_restore_seq(int sk, int queue, uint32_t seq)
+{
+	if (setsockopt(sk, SOL_TCP, TCP_REPAIR_QUEUE, &queue, sizeof(queue)))
+		test_error("setsockopt(TCP_REPAIR_QUEUE)");
+
+	if (setsockopt(sk, SOL_TCP, TCP_QUEUE_SEQ, &seq, sizeof(seq)))
+		test_error("setsockopt(TCP_QUEUE_SEQ)");
+}
+
+/*
+ * Refill one queue of @sk with the @len bytes saved in @buf: select
+ * @queue with TCP_REPAIR_QUEUE, then send() the data.
+ *
+ * When send() fails, the chunk size is halved (down to 1024 bytes)
+ * before giving up with test_error(). Every send() is limited to the
+ * bytes still remaining, so it never reads beyond the end of @buf.
+ */
+static void test_sock_restore_queue(int sk, int queue, void *buf, int len)
+{
+	int max_chunk = len;
+	size_t off = 0;
+
+	if (len <= 0)
+		return;
+
+	if (setsockopt(sk, SOL_TCP, TCP_REPAIR_QUEUE, &queue, sizeof(queue)))
+		test_error("setsockopt(TCP_REPAIR_QUEUE)");
+
+	do {
+		/* the tail may be shorter than the current chunk size */
+		int chunk = len > max_chunk ? max_chunk : len;
+		int ret;
+
+		ret = send(sk, (char *)buf + off, chunk, 0);
+		if (ret <= 0) {
+			if (max_chunk > 1024) {
+				max_chunk >>= 1;
+				continue;
+			}
+			test_error("send()");
+		}
+		off += ret;
+		len -= ret;
+	} while (len > 0);
+}
+
+/*
+ * Rebuild a TCP connection on socket @sk from a checkpointed @state.
+ *
+ * @saddr/@daddr are the local and peer addresses, @addr_size bytes each;
+ * @device, when non-NULL, is bound with SO_BINDTODEVICE before
+ * connect(). The socket is switched to non-blocking mode. Restored are:
+ * queue sequence numbers, negotiated options (SACK, window scale,
+ * timestamps, MSS), the timestamp value, queue contents and finally the
+ * repair window. Any failure is reported through test_error().
+ * NOTE(review): presumably @sk has to be in repair mode already
+ * (test_enable_repair()) - confirm with the callers.
+ */
+void __test_sock_restore(int sk, const char *device,
+			 struct tcp_sock_state *state,
+			 void *saddr, void *daddr, size_t addr_size)
+{
+	struct tcp_repair_opt opts[4];
+	unsigned int opt_nr = 0;
+	long flags;
+
+	if (bind(sk, saddr, addr_size))
+		test_error("bind()");
+
+	flags = fcntl(sk, F_GETFL);
+	if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0))
+		test_error("fcntl()");
+
+	/* The saved seq is the queue's end: rewind by the queued bytes */
+	test_sock_restore_seq(sk, TCP_RECV_QUEUE, state->in.seq - state->inq_len);
+	test_sock_restore_seq(sk, TCP_SEND_QUEUE, state->out.seq - state->outq_len);
+
+	if (device != NULL && setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE,
+					 device, strlen(device) + 1))
+		test_error("setsockopt(SO_BINDTODEVICE, %s)", device);
+
+	if (connect(sk, daddr, addr_size))
+		test_error("connect()");
+
+	/* At most 4 options are collected, matching the size of opts[] */
+	if (state->info.tcpi_options & TCPI_OPT_SACK) {
+		opts[opt_nr].opt_code = TCPOPT_SACK_PERMITTED;
+		opts[opt_nr].opt_val = 0;
+		opt_nr++;
+	}
+	if (state->info.tcpi_options & TCPI_OPT_WSCALE) {
+		opts[opt_nr].opt_code = TCPOPT_WINDOW;
+		opts[opt_nr].opt_val = state->info.tcpi_snd_wscale +
+				(state->info.tcpi_rcv_wscale << 16);
+		opt_nr++;
+	}
+	if (state->info.tcpi_options & TCPI_OPT_TIMESTAMPS) {
+		opts[opt_nr].opt_code = TCPOPT_TIMESTAMP;
+		opts[opt_nr].opt_val = 0;
+		opt_nr++;
+	}
+	opts[opt_nr].opt_code = TCPOPT_MAXSEG;
+	opts[opt_nr].opt_val = state->mss;
+	opt_nr++;
+
+	if (setsockopt(sk, SOL_TCP, TCP_REPAIR_OPTIONS, opts, opt_nr * sizeof(opts[0])))
+		test_error("setsockopt(TCP_REPAIR_OPTIONS)");
+
+	if (state->info.tcpi_options & TCPI_OPT_TIMESTAMPS) {
+		/* optlen is the size of the value, not of the opts[] array */
+		if (setsockopt(sk, SOL_TCP, TCP_TIMESTAMP,
+			       &state->timestamp, sizeof(state->timestamp)))
+			test_error("setsockopt(TCP_TIMESTAMP)");
+	}
+	test_sock_restore_queue(sk, TCP_RECV_QUEUE, state->in.buf, state->inq_len);
+	test_sock_restore_queue(sk, TCP_SEND_QUEUE, state->out.buf, state->outq_len);
+	if (setsockopt(sk, SOL_TCP, TCP_REPAIR_WINDOW, &state->trw, sizeof(state->trw)))
+		test_error("setsockopt(TCP_REPAIR_WINDOW)");
+}
+
+/* Write a checkpointed TCP-AO repair @state back with TCP_AO_REPAIR */
+void test_ao_restore(int sk, struct tcp_ao_repair *state)
+{
+	if (setsockopt(sk, SOL_TCP, TCP_AO_REPAIR, state, sizeof(*state)))
+		test_error("setsockopt(TCP_AO_REPAIR)");
+}
+
+/*
+ * Free the queue buffers held by @state (allocated at checkpoint time);
+ * @state itself is not freed.
+ */
+void test_sock_state_free(struct tcp_sock_state *state)
+{
+	free(state->out.buf);
+	free(state->in.buf);
+}
+
+/* Switch @sk into repair mode: setsockopt(TCP_REPAIR, TCP_REPAIR_ON) */
+void test_enable_repair(int sk)
+{
+	int val = TCP_REPAIR_ON;
+
+	if (setsockopt(sk, SOL_TCP, TCP_REPAIR, &val, sizeof(val)))
+		test_error("setsockopt(TCP_REPAIR)");
+}
+
+/*
+ * Leave repair mode with TCP_REPAIR_OFF_NO_WP.
+ * NOTE(review): presumably "NO_WP" means no window probe is sent on
+ * leaving repair mode - confirm against tcp(7) / the kernel sources.
+ */
+void test_disable_repair(int sk)
+{
+	int val = TCP_REPAIR_OFF_NO_WP;
+
+	if (setsockopt(sk, SOL_TCP, TCP_REPAIR, &val, sizeof(val)))
+		test_error("setsockopt(TCP_REPAIR)");
+}
+
+/*
+ * Close @sk while it is in repair mode.
+ * NOTE(review): presumably this drops the socket without notifying the
+ * peer (no FIN/RST) - confirm against the TCP_REPAIR documentation.
+ */
+void test_kill_sk(int sk)
+{
+	test_enable_repair(sk);
+	close(sk);
+}
diff --git a/tools/testing/selftests/net/tcp_ao/lib/setup.c b/tools/testing/selftests/net/tcp_ao/lib/setup.c
new file mode 100644
index 000000000000..374b27c26ebd
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/lib/setup.c
@@ -0,0 +1,342 @@
+// SPDX-License-Identifier: GPL-2.0
+#include <fcntl.h>
+#include <pthread.h>
+#include <sched.h>
+#include <signal.h>
+#include "aolib.h"
+
+/*
+ * Can't be included in the header: it defines static variables which
+ * will be unique to every object. Let's include it only once here.
+ */
+#include "../../../kselftest.h"
+
+/*
+ * Prevent overriding of one thread's output by another: taken around
+ * every ksft_*() printing call made by the __test_*() helpers below.
+ */
+static pthread_mutex_t ksft_print_lock = PTHREAD_MUTEX_INITIALIZER;
+
+/*
+ * Print the preformatted message @buf under the output lock.
+ * @buf is passed as data, never as a format string, so a '%' in the
+ * message text cannot be misread as a conversion.
+ */
+void __test_msg(const char *buf)
+{
+	pthread_mutex_lock(&ksft_print_lock);
+	ksft_print_msg("%s", buf);
+	pthread_mutex_unlock(&ksft_print_lock);
+}
+/*
+ * Report a passed test with the preformatted message @buf, under the
+ * output lock. @buf is passed as data, never as a format string.
+ */
+void __test_ok(const char *buf)
+{
+	pthread_mutex_lock(&ksft_print_lock);
+	ksft_test_result_pass("%s", buf);
+	pthread_mutex_unlock(&ksft_print_lock);
+}
+/*
+ * Report a failed test with the preformatted message @buf, under the
+ * output lock. @buf is passed as data, never as a format string.
+ */
+void __test_fail(const char *buf)
+{
+	pthread_mutex_lock(&ksft_print_lock);
+	ksft_test_result_fail("%s", buf);
+	pthread_mutex_unlock(&ksft_print_lock);
+}
+/*
+ * Report an expected failure with the preformatted message @buf, under
+ * the output lock. @buf is passed as data, never as a format string.
+ */
+void __test_xfail(const char *buf)
+{
+	pthread_mutex_lock(&ksft_print_lock);
+	ksft_test_result_xfail("%s", buf);
+	pthread_mutex_unlock(&ksft_print_lock);
+}
+/*
+ * Report a test error with the preformatted message @buf, under the
+ * output lock. @buf is passed as data, never as a format string.
+ */
+void __test_error(const char *buf)
+{
+	pthread_mutex_lock(&ksft_print_lock);
+	ksft_test_result_error("%s", buf);
+	pthread_mutex_unlock(&ksft_print_lock);
+}
+/*
+ * Report a skipped test with the preformatted message @buf, under the
+ * output lock. @buf is passed as data, never as a format string.
+ */
+void __test_skip(const char *buf)
+{
+	pthread_mutex_lock(&ksft_print_lock);
+	ksft_test_result_skip("%s", buf);
+	pthread_mutex_unlock(&ksft_print_lock);
+}
+
+/*
+ * Overall outcome flags, consulted by test_exit(): @failed is set by
+ * test_failed(), @skipped by __test_skip_all().
+ */
+static volatile int failed;
+static volatile int skipped;
+
+/* Remember that a test has failed, so the process exits with failure */
+void test_failed(void)
+{
+	failed = 1;
+}
+
+/*
+ * Finish the run with the kselftest exit status matching the recorded
+ * outcome: a failure takes precedence over a skip, which takes
+ * precedence over a pass.
+ */
+static void test_exit(void)
+{
+	int has_failed = failed;
+
+	if (has_failed) {
+		ksft_exit_fail();
+		return;
+	}
+
+	if (!skipped) {
+		ksft_exit_pass();
+		return;
+	}
+
+	/* ksft_exit_skip() is different from ksft_exit_*() */
+	ksft_print_cnts();
+	exit(KSFT_SKIP);
+}
+
+/* A node of the list of callbacks that test_destructor() runs */
+struct dlist_t {
+	void (*destruct)(void);
+	struct dlist_t *next;
+};
+/* Head of the list; test_add_destructor() inserts new nodes in front */
+static struct dlist_t *destructors_list;
+
+/*
+ * Register callback @d to be run by test_destructor(). New callbacks go
+ * to the front of the list, so they run in reverse registration order.
+ */
+void test_add_destructor(void (*d)(void))
+{
+	struct dlist_t *node = malloc(sizeof(*node));
+
+	if (!node)
+		test_error("malloc() failed");
+
+	node->destruct = d;
+	node->next = destructors_list;
+	destructors_list = node;
+}
+
+/*
+ * Registered with __attribute__((destructor)): runs and frees every
+ * callback added by test_add_destructor(), then lets test_exit() pick
+ * the exit status.
+ */
+static void test_destructor(void) __attribute__((destructor));
+static void test_destructor(void)
+{
+	struct dlist_t *next;
+
+	for (; destructors_list; destructors_list = next) {
+		next = destructors_list->next;
+
+		destructors_list->destruct();
+		free(destructors_list);
+	}
+	test_exit();
+}
+
+/*
+ * SIGINT handler installed by __test_init().
+ * NOTE(review): test_error() presumably prints and exits, which is not
+ * async-signal-safe - confirm that this is acceptable here.
+ */
+static void sig_int(int signo)
+{
+	test_error("Caught SIGINT - exiting");
+}
+
+/*
+ * Open /proc/self/ns/net read-only and return the file descriptor.
+ * The caller owns the descriptor. A failure to open is reported through
+ * test_error().
+ */
+int open_netns(void)
+{
+	const char *netns_path = "/proc/self/ns/net";
+	int nsfd = open(netns_path, O_RDONLY);
+
+	if (nsfd < 0)
+		test_error("open(%s)", netns_path);
+
+	return nsfd;
+}
+
+/*
+ * Move the caller into a brand new network namespace and return a file
+ * descriptor for it (see open_netns()).
+ */
+int unshare_open_netns(void)
+{
+	if (unshare(CLONE_NEWNET) != 0)
+		test_error("unshare()");
+
+	return open_netns();
+}
+
+/* Enter the network namespace referred to by @fd via setns() */
+void switch_ns(int fd)
+{
+	if (setns(fd, CLONE_NEWNET))
+		test_error("setns()");
+}
+
+/*
+ * Enter the network namespace @new_ns and return a newly opened file
+ * descriptor for the namespace that was current before the switch.
+ * The returned descriptor comes from open_netns() and is owned by the
+ * caller.
+ */
+int switch_save_ns(int new_ns)
+{
+	int ret = open_netns();
+
+	switch_ns(new_ns);
+	return ret;
+}
+
+/*
+ * Namespace descriptors set up by init_namespaces(): the namespace the
+ * test was started in and the two fresh ones used for the two peers.
+ */
+static int nsfd_outside	= -1;
+static int nsfd_parent	= -1;
+static int nsfd_child	= -1;
+/* Name given to both ends of the veth pair connecting the peers */
+const char veth_name[]	= "ktst-veth";
+
+/*
+ * Save the starting netns and create two new ones. Each unshare moves
+ * the caller, so on return the caller is in the "child" namespace.
+ */
+static void init_namespaces(void)
+{
+	nsfd_outside = open_netns();
+	nsfd_parent = unshare_open_netns();
+	nsfd_child = unshare_open_netns();
+}
+
+/*
+ * Configure interface @veth in the current netns: bring it up, assign
+ * @addr/@prefix and add a route to @dest with @addr as the source.
+ * Any failing step is reported through test_error().
+ */
+static void link_init(const char *veth, int family, uint8_t prefix,
+		      union tcp_addr addr, union tcp_addr dest)
+{
+	int err;
+
+	err = link_set_up(veth);
+	if (err)
+		test_error("Failed to set link up");
+
+	err = ip_addr_add(veth, family, addr, prefix);
+	if (err)
+		test_error("Failed to add ip address");
+
+	err = ip_route_add(veth, family, addr, dest);
+	if (err)
+		test_error("Failed to add route");
+}
+
+/* Number of peers to wait for; __test_init() bumps it for peer2 */
+static unsigned int nr_threads = 1;
+
+/* State of the two-stage barrier implemented by synchronize_threads() */
+static pthread_mutex_t sync_lock = PTHREAD_MUTEX_INITIALIZER;
+static pthread_cond_t sync_cond = PTHREAD_COND_INITIALIZER;
+/* Arrival counters, one per stage; stage_nr selects the active one */
+static volatile unsigned int stage_threads[2];
+static volatile unsigned int stage_nr;
+
+/*
+ * synchronize all threads in the same stage
+ *
+ * Barrier: returns once nr_threads threads have called it for the
+ * current stage. Two counters are used alternately, so that a thread
+ * entering the next barrier early doesn't disturb the waiters of the
+ * current one.
+ */
+void synchronize_threads(void)
+{
+	/* NOTE(review): stage_nr is sampled before the lock is taken */
+	unsigned int q = stage_nr;
+
+	pthread_mutex_lock(&sync_lock);
+	stage_threads[q]++;
+	if (stage_threads[q] == nr_threads) {
+		/* Last to arrive: flip the stage, reset the other counter */
+		stage_nr ^= 1;
+		stage_threads[stage_nr] = 0;
+		/*
+		 * NOTE(review): signal wakes a single waiter, which looks
+		 * sufficient only while there are at most two threads.
+		 */
+		pthread_cond_signal(&sync_cond);
+	}
+	while (stage_threads[q] < nr_threads)
+		pthread_cond_wait(&sync_cond, &sync_lock);
+	pthread_mutex_unlock(&sync_lock);
+}
+
+/* Per-thread: the peer's own address and the address of the other peer */
+__thread union tcp_addr this_ip_addr;
+__thread union tcp_addr this_ip_dest;
+/* Address family of the test, set once by __test_init() */
+int test_family;
+
+/* Start-up parameters handed to new_pthread_entry() */
+struct new_pthread_arg {
+	thread_fn	func;
+	union tcp_addr	my_ip;
+	union tcp_addr	dest_ip;
+};
+/*
+ * Thread entry point: copy the addresses from @arg (a struct
+ * new_pthread_arg) into the thread-local variables and run the peer
+ * function. If the peer function ever returns, the whole process exits
+ * with KSFT_FAIL.
+ */
+static void *new_pthread_entry(void *arg)
+{
+	struct new_pthread_arg *p = arg;
+
+	this_ip_addr = p->my_ip;
+	this_ip_dest = p->dest_ip;
+	p->func(NULL); /* shouldn't return */
+	exit(KSFT_FAIL);
+}
+
+/*
+ * Skip the whole run: announce a plan of a single test, report it as
+ * skipped with reason @msg and exit with KSFT_SKIP.
+ */
+static void __test_skip_all(const char *msg)
+{
+	ksft_set_plan(1);
+	ksft_print_header();
+	skipped = 1;
+	test_skip("%s", msg);
+	exit(KSFT_SKIP);
+}
+
+/*
+ * Set up the test environment and run the peers; never returns.
+ *
+ * @ntests:  number of tests announced in the kselftest plan
+ * @family:  address family, stored in test_family
+ * @prefix:  prefix length used for both addresses
+ * @addr1:   address of @peer1, which runs in the calling thread inside
+ *           the "parent" netns
+ * @addr2:   address of @peer2, which (if non-NULL) runs in a new thread
+ *           created while the caller is inside the "child" netns
+ *
+ * The run is skipped as a whole if the kernel lacks network namespaces,
+ * veth or TCP-AO. Otherwise two namespaces connected by a veth pair are
+ * created and configured. The process exits with KSFT_FAIL or KSFT_PASS
+ * once @peer1 returns.
+ */
+void __test_init(unsigned int ntests, int family, unsigned int prefix,
+		 union tcp_addr addr1, union tcp_addr addr2,
+		 thread_fn peer1, thread_fn peer2)
+{
+	struct sigaction sa = {
+		.sa_handler = sig_int,
+		.sa_flags = SA_RESTART,
+	};
+	/*
+	 * Passed to the peer2 thread by pointer, so it has to stay alive
+	 * until that thread has read it. Function scope is enough, as
+	 * this function only ever leaves through exit().
+	 */
+	struct new_pthread_arg targ;
+	time_t seed = time(NULL);
+
+	sigemptyset(&sa.sa_mask);
+	if (sigaction(SIGINT, &sa, NULL))
+		test_error("Can't set SIGINT handler");
+
+	test_family = family;
+	if (!kernel_config_has(KCONFIG_NET_NS))
+		__test_skip_all(tests_skip_reason[KCONFIG_NET_NS]);
+	if (!kernel_config_has(KCONFIG_VETH))
+		__test_skip_all(tests_skip_reason[KCONFIG_VETH]);
+	if (!kernel_config_has(KCONFIG_TCP_AO))
+		__test_skip_all(tests_skip_reason[KCONFIG_TCP_AO]);
+
+	ksft_set_plan(ntests);
+	test_print("rand seed %u", (unsigned int)seed);
+	srand(seed);
+
+	ksft_print_header();
+	init_namespaces();
+
+	if (add_veth(veth_name, nsfd_parent, nsfd_child))
+		test_error("Failed to add veth");
+
+	switch_ns(nsfd_child);
+	link_init(veth_name, family, prefix, addr2, addr1);
+	if (peer2) {
+		pthread_t t;
+
+		targ.my_ip = addr2;
+		targ.dest_ip = addr1;
+		targ.func = peer2;
+		nr_threads++;
+		/* Created while in the child netns, before switching back */
+		if (pthread_create(&t, NULL, new_pthread_entry, &targ))
+			test_error("Failed to create pthread");
+	}
+	switch_ns(nsfd_parent);
+	link_init(veth_name, family, prefix, addr1, addr2);
+
+	this_ip_addr = addr1;
+	this_ip_dest = addr2;
+	peer1(NULL);
+	if (failed)
+		exit(KSFT_FAIL);
+	else
+		exit(KSFT_PASS);
+}
+
+/* /proc/sys/net/core/optmem_max artificially limits the amount of memory
+ * that can be allocated with sock_kmalloc() on each socket in the system.
+ * It is not virtualized, so it has to be written outside test namespaces.
+ * To be nice a test will revert optmem back to the old value.
+ * Keeping it simple without any file lock, which means the tests that
+ * need to set/increase optmem value shouldn't run in parallel.
+ * Also, not re-entrant.
+ */
+static const char *optmem_file = "/proc/sys/net/core/optmem_max";
+/* Value found before the first test_set_optmem(); 0 means nothing saved */
+static size_t saved_optmem;
+
+/*
+ * Read the current optmem_max value.
+ * The file is read from the namespace the test was started in
+ * (nsfd_outside); the caller's namespace is restored before returning.
+ */
+size_t test_get_optmem(void)
+{
+	FILE *foptmem;
+	int old_ns;
+	size_t ret;
+
+	old_ns = switch_save_ns(nsfd_outside);
+	foptmem = fopen(optmem_file, "r");
+	if (!foptmem)
+		test_error("failed to open %s", optmem_file);
+
+	if (fscanf(foptmem, "%zu", &ret) != 1)
+		test_error("can't read from %s", optmem_file);
+	fclose(foptmem);
+	switch_ns(old_ns);
+	/* switch_save_ns() opened this descriptor: don't leak it */
+	close(old_ns);
+	return ret;
+}
+
+/*
+ * Write @new into optmem_max. When @old is non-NULL the previous value
+ * is read and stored there first.
+ * The file is written from the namespace the test was started in
+ * (nsfd_outside); the caller's namespace is restored before returning.
+ */
+static void __test_set_optmem(size_t new, size_t *old)
+{
+	FILE *foptmem;
+	int old_ns;
+
+	if (old != NULL)
+		*old = test_get_optmem();
+
+	old_ns = switch_save_ns(nsfd_outside);
+	foptmem = fopen(optmem_file, "w");
+	if (!foptmem)
+		test_error("failed to open %s", optmem_file);
+
+	if (fprintf(foptmem, "%zu", new) <= 0)
+		test_error("can't write %zu to %s", new, optmem_file);
+	/* stdio buffers the data, so a write error may only surface here */
+	if (fclose(foptmem))
+		test_error("can't write %zu to %s", new, optmem_file);
+	switch_ns(old_ns);
+	/* switch_save_ns() opened this descriptor: don't leak it */
+	close(old_ns);
+}
+
+/*
+ * Destructor callback: write the saved optmem_max value back, if one
+ * was saved by test_set_optmem().
+ */
+static void test_revert_optmem(void)
+{
+	if (saved_optmem == 0)
+		return;
+
+	__test_set_optmem(saved_optmem, NULL);
+}
+
+/*
+ * Set optmem_max to @value. The first call also saves the previous
+ * value and registers test_revert_optmem() as a destructor; later calls
+ * only write the new value.
+ */
+void test_set_optmem(size_t value)
+{
+	if (saved_optmem != 0) {
+		__test_set_optmem(value, NULL);
+		return;
+	}
+
+	__test_set_optmem(value, &saved_optmem);
+	test_add_destructor(test_revert_optmem);
+}
diff --git a/tools/testing/selftests/net/tcp_ao/lib/sock.c b/tools/testing/selftests/net/tcp_ao/lib/sock.c
new file mode 100644
index 000000000000..7f3c31b7d997
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/lib/sock.c
@@ -0,0 +1,592 @@
+// SPDX-License-Identifier: GPL-2.0
+#include <alloca.h>
+#include <fcntl.h>
+#include <inttypes.h>
+#include <string.h>
+#include "../../../../../include/linux/kernel.h"
+#include "../../../../../include/linux/stringify.h"
+#include "aolib.h"
+
+/* TCP port number shared by the tests.
+ * NOTE(review): presumably the port the test servers listen on - confirm
+ * against the users in aolib.h.
+ */
+const unsigned int test_server_port = 7010;
+/* Create a TCP socket of test_family, bind it to the veth_name device and
+ * to @addr, make it non-blocking and start listening with @backlog.
+ * Any failing step is reported with test_error().
+ * Returns the listening socket.
+ */
+int __test_listen_socket(int backlog, void *addr, size_t addr_sz)
+{
+	long fl;
+	int sk;
+
+	sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+	if (sk < 0)
+		test_error("socket()");
+
+	if (setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, veth_name,
+		       strlen(veth_name) + 1) < 0)
+		test_error("setsockopt(SO_BINDTODEVICE)");
+
+	if (bind(sk, (struct sockaddr *)addr, addr_sz) < 0)
+		test_error("bind()");
+
+	fl = fcntl(sk, F_GETFL);
+	if (fl < 0)
+		test_error("fcntl()");
+	if (fcntl(sk, F_SETFL, fl | O_NONBLOCK) < 0)
+		test_error("fcntl()");
+
+	if (listen(sk, backlog) != 0)
+		test_error("listen()");
+
+	return sk;
+}
+
+/* Wait with select() until @sk is readable, or writable if @write is set.
+ * @sec: timeout in seconds; 0 means no timeout is passed to select().
+ *
+ * Returns 0 if the socket is ready and has no pending error, otherwise
+ * a negative error code:
+ *  -ETIMEDOUT (errno is set as well) if the timeout has expired,
+ *  -errno if select() or getsockopt(SO_ERROR) failed,
+ *  the negated pending socket error reported by SO_ERROR.
+ */
+int test_wait_fd(int sk, time_t sec, bool write)
+{
+	struct timeval tv = { .tv_sec = sec };
+	struct timeval *ptv = NULL;
+	fd_set fds, efds;
+	int ret;
+	socklen_t slen = sizeof(ret);
+
+	FD_ZERO(&fds);
+	FD_SET(sk, &fds);
+	FD_ZERO(&efds);
+	FD_SET(sk, &efds);
+
+	if (sec)
+		ptv = &tv;
+
+	errno = 0;
+	if (write)
+		ret = select(sk + 1, NULL, &fds, &efds, ptv);
+	else
+		ret = select(sk + 1, &fds, NULL, &efds, ptv);
+	if (ret < 0)
+		return -errno;
+	if (ret == 0) {
+		errno = ETIMEDOUT;
+		return -ETIMEDOUT;
+	}
+
+	/* If getsockopt() fails, @ret still holds the positive select()
+	 * result, which is not an error code: report errno instead.
+	 */
+	if (getsockopt(sk, SOL_SOCKET, SO_ERROR, &ret, &slen))
+		return -errno;
+	if (ret)
+		return -ret;
+	return 0;
+}
+
+/* Connect @sk to @addr, optionally binding it to @device first.
+ * @timeout == 0: blocking connect(); returns 0 on success.
+ * @timeout != 0: the socket is switched to non-blocking mode;
+ *   @timeout < 0 returns @sk right away while the connect is in progress,
+ *   @timeout > 0 waits with test_wait_fd() for the socket to get writable.
+ *   Returns @sk on success.
+ * On connect/wait failure the socket is closed and a negative error code
+ * is returned. setsockopt()/fcntl() failures are reported with test_error().
+ */
+int __test_connect_socket(int sk, const char *device,
+			  void *addr, size_t addr_sz, time_t timeout)
+{
+	long fl;
+	int ret;
+
+	if (device) {
+		if (setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, device,
+			       strlen(device) + 1) < 0)
+			test_error("setsockopt(SO_BINDTODEVICE, %s)", device);
+	}
+
+	if (timeout == 0) {
+		/* plain blocking connect() */
+		if (connect(sk, addr, addr_sz) == 0)
+			return 0;
+		ret = -errno;
+		goto fail;
+	}
+
+	fl = fcntl(sk, F_GETFL);
+	if (fl < 0 || fcntl(sk, F_SETFL, fl | O_NONBLOCK) < 0)
+		test_error("fcntl()");
+
+	if (connect(sk, addr, addr_sz) >= 0)
+		return sk;
+
+	if (errno != EINPROGRESS) {
+		ret = -errno;
+		goto fail;
+	}
+	/* negative timeout: the caller doesn't want to wait */
+	if (timeout < 0)
+		return sk;
+
+	ret = test_wait_fd(sk, timeout, 1);
+	if (ret == 0)
+		return sk;
+
+fail:
+	close(sk);
+	return ret;
+}
+
+/* Add a TCP-MD5 key for @addr/@prefix with setsockopt(TCP_MD5SIG_EXT).
+ * @vrf: interface index the key is bound to; negative for no binding.
+ * @password: NUL-terminated key.
+ * Returns the result of setsockopt(); errno is zeroed before the call.
+ */
+int __test_set_md5(int sk, void *addr, size_t addr_sz, uint8_t prefix,
+		   int vrf, const char *password)
+{
+	size_t pwd_len = strlen(password);
+	struct tcp_md5sig md5sig = {};
+
+	md5sig.tcpm_keylen = pwd_len;
+	/* Never write past tcpm_key[]: for an over-long password only the
+	 * part that fits is copied, tcpm_keylen still carries the requested
+	 * length and it's up to setsockopt() to accept or refuse it.
+	 */
+	if (pwd_len > sizeof(md5sig.tcpm_key))
+		pwd_len = sizeof(md5sig.tcpm_key);
+	memcpy(md5sig.tcpm_key, password, pwd_len);
+	md5sig.tcpm_flags = TCP_MD5SIG_FLAG_PREFIX;
+	md5sig.tcpm_prefixlen = prefix;
+	if (vrf >= 0) {
+		md5sig.tcpm_flags |= TCP_MD5SIG_FLAG_IFINDEX;
+		/* no narrowing cast: interface indexes can be above 255 */
+		md5sig.tcpm_ifindex = vrf;
+	}
+	memcpy(&md5sig.tcpm_addr, addr, addr_sz);
+
+	errno = 0;
+	return setsockopt(sk, IPPROTO_TCP, TCP_MD5SIG_EXT,
+			&md5sig, sizeof(md5sig));
+}
+
+
+/* Fill @ao, the argument for adding a TCP-AO key, from the parameters.
+ * @ao is zeroed first. @addr_sz bytes of @addr are copied into ao->addr.
+ * At most TCP_AO_MAXKEYLEN bytes of @key are copied, while ao->keylen
+ * is set to @keylen as is.
+ * Returns 0, or -ENOBUFS if @alg doesn't fit into ao->alg_name together
+ * with the terminating NUL.
+ */
+int test_prepare_key_sockaddr(struct tcp_ao_add *ao, const char *alg,
+		void *addr, size_t addr_sz, bool set_current, bool set_rnext,
+		uint8_t prefix, uint8_t vrf, uint8_t sndid, uint8_t rcvid,
+		uint8_t maclen, uint8_t keyflags,
+		uint8_t keylen, const char *key)
+{
+	size_t alg_len = strlen(alg);
+
+	memset(ao, 0, sizeof(struct tcp_ao_add));
+
+	ao->set_current	= !!set_current;
+	ao->set_rnext	= !!set_rnext;
+	ao->prefix	= prefix;
+	ao->sndid	= sndid;
+	ao->rcvid	= rcvid;
+	ao->maclen	= maclen;
+	ao->keyflags	= keyflags;
+	ao->keylen	= keylen;
+	ao->ifindex	= vrf;
+
+	memcpy(&ao->addr, addr, addr_sz);
+
+	/* alg_name is later used as a C string (strcmp()), so leave room
+	 * for the terminator; @ao was zeroed above, which provides it.
+	 */
+	if (alg_len >= sizeof(ao->alg_name))
+		return -ENOBUFS;
+	memcpy(ao->alg_name, alg, alg_len);
+
+	memcpy(ao->key, key,
+	       (keylen > TCP_AO_MAXKEYLEN) ? TCP_AO_MAXKEYLEN : keylen);
+	return 0;
+}
+
+/* Ask getsockopt(TCP_AO_GET_KEYS) for all keys, providing room for one,
+ * and return the nkeys value that comes back.
+ * Returns -errno if getsockopt() fails.
+ */
+static int test_get_ao_keys_nr(int sk)
+{
+	struct tcp_ao_getsockopt req = { .nkeys = 1, .get_all = 1 };
+	socklen_t req_sz = sizeof(req);
+
+	if (getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &req, &req_sz) != 0)
+		return -errno;
+
+	return (int)req.nkeys;
+}
+
+/* Look up a single TCP-AO key on @sk by @addr/@prefix and @sndid/@rcvid
+ * with getsockopt(TCP_AO_GET_KEYS) and copy it to @out.
+ * Returns 0 on success, -errno if getsockopt() fails (the same convention
+ * as the other getsockopt() helpers in this file), -E2BIG if the reported
+ * number of keys isn't exactly one. @out is written only on success.
+ */
+int test_get_one_ao(int sk, struct tcp_ao_getsockopt *out,
+		void *addr, size_t addr_sz, uint8_t prefix,
+		uint8_t sndid, uint8_t rcvid)
+{
+	struct tcp_ao_getsockopt tmp = {};
+	socklen_t tmp_sz = sizeof(tmp);
+	int ret;
+
+	memcpy(&tmp.addr, addr, addr_sz);
+	tmp.prefix = prefix;
+	tmp.sndid  = sndid;
+	tmp.rcvid  = rcvid;
+	tmp.nkeys  = 1;
+
+	ret = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &tmp, &tmp_sz);
+	if (ret)
+		return -errno;
+	if (tmp.nkeys != 1)
+		return -E2BIG;
+	*out = tmp;
+	return 0;
+}
+
+/* Read TCP_AO_INFO of @sk into @out; the reserved fields of @out are
+ * zeroed before the call.
+ * Returns 0, -errno if getsockopt() fails, or -EMSGSIZE if the returned
+ * option length differs from sizeof(*out).
+ */
+int test_get_ao_info(int sk, struct tcp_ao_info_opt *out)
+{
+	socklen_t len = sizeof(*out);
+	int rc;
+
+	out->reserved = 0;
+	out->reserved2 = 0;
+
+	rc = getsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, out, &len);
+	if (rc != 0)
+		return -errno;
+
+	return (len == sizeof(*out)) ? 0 : -EMSGSIZE;
+}
+
+/* Apply @in to @sk with setsockopt(TCP_AO_INFO); the reserved fields
+ * of @in are zeroed first.
+ * Returns 0 on success or -errno.
+ */
+int test_set_ao_info(int sk, struct tcp_ao_info_opt *in)
+{
+	socklen_t len = sizeof(*in);
+	int rc;
+
+	in->reserved = 0;
+	in->reserved2 = 0;
+
+	rc = setsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, in, len);
+	return rc ? -errno : 0;
+}
+
+/* Check that key @b, as returned by getsockopt(), matches key @a, as it
+ * was passed to setsockopt().
+ * Compared: sndid, rcvid, prefix, keyflags, ifindex, maclen (12 is expected
+ * back if a->maclen is 0), keylen, key, address, algorithm name and the
+ * current/rnext flags.
+ * Returns 0 if everything matches; the first mismatch is reported with
+ * test_fail() and -1 is returned.
+ */
+int test_cmp_getsockopt_setsockopt(const struct tcp_ao_add *a,
+				   const struct tcp_ao_getsockopt *b)
+{
+	bool is_kdf_aes_128_cmac = false;
+	bool is_cmac_aes = false;
+
+	/* "cmac(aes128)" is expected to be reported back as "cmac(aes)";
+	 * when its key is not 16 bytes long, a 16 bytes key is expected
+	 * back and the key contents are not compared.
+	 */
+	if (!strcmp("cmac(aes128)", a->alg_name)) {
+		is_kdf_aes_128_cmac = (a->keylen != 16);
+		is_cmac_aes = true;
+	}
+
+#define __cmp_ao(member)						\
+do {									\
+	if (b->member != a->member) {					\
+		test_fail("getsockopt(): " __stringify(member) " %u != %u",	\
+				b->member, a->member);			\
+		return -1;						\
+	}								\
+} while(0)
+	__cmp_ao(sndid);
+	__cmp_ao(rcvid);
+	__cmp_ao(prefix);
+	__cmp_ao(keyflags);
+	__cmp_ao(ifindex);
+	if (a->maclen) {
+		__cmp_ao(maclen);
+	} else if (b->maclen != 12) {
+		test_fail("getsockopt(): expected default maclen 12, but it's %u",
+				b->maclen);
+		return -1;
+	}
+	if (!is_kdf_aes_128_cmac) {
+		__cmp_ao(keylen);
+	} else if (b->keylen != 16) {
+		test_fail("getsockopt(): expected keylen 16 for cmac(aes128), but it's %u",
+				b->keylen);
+		return -1;
+	}
+#undef __cmp_ao
+	if (!is_kdf_aes_128_cmac && memcmp(b->key, a->key, a->keylen)) {
+		/* Keys are byte arrays, not necessarily NUL-terminated:
+		 * limit the output to keylen (equal for both at this point).
+		 */
+		test_fail("getsockopt(): returned key is different `%.*s' != `%.*s'",
+				(int)a->keylen, b->key,
+				(int)a->keylen, a->key);
+		return -1;
+	}
+	if (memcmp(&b->addr, &a->addr, sizeof(b->addr))) {
+		test_fail("getsockopt(): returned address is different");
+		return -1;
+	}
+	if (!is_cmac_aes && strcmp(b->alg_name, a->alg_name)) {
+		test_fail("getsockopt(): returned algorithm %s is different than %s", b->alg_name, a->alg_name);
+		return -1;
+	}
+	if (is_cmac_aes && strcmp(b->alg_name, "cmac(aes)")) {
+		test_fail("getsockopt(): returned algorithm %s is different than cmac(aes)", b->alg_name);
+		return -1;
+	}
+	/* For a established key rotation test don't add a key with
+	 * set_current = 1, as it's likely to change by peer's request;
+	 * rather use setsockopt(TCP_AO_INFO)
+	 */
+	if (a->set_current != b->is_current) {
+		test_fail("getsockopt(): returned key is not Current_key");
+		return -1;
+	}
+	if (a->set_rnext != b->is_rnext) {
+		test_fail("getsockopt(): returned key is not RNext_key");
+		return -1;
+	}
+
+	return 0;
+}
+
+/* Compare two TCP_AO_INFO structures: @a as it was set and @b as it was
+ * read back. ao_required and accept_icmps always have to match; rnext is
+ * compared only if a->set_rnext, the packet counters only if
+ * a->set_counters.
+ * Returns 0 on match; the first difference is reported with test_fail()
+ * and -1 is returned.
+ */
+int test_cmp_getsockopt_setsockopt_ao(const struct tcp_ao_info_opt *a,
+				      const struct tcp_ao_info_opt *b)
+{
+#define __same_cnt(field)						\
+do {									\
+	if (a->field != b->field) {					\
+		test_fail("getsockopt(): " __stringify(field) " %llu != %llu",	\
+				b->field, a->field);			\
+		return -1;						\
+	}								\
+} while (0)
+
+	/* No check for ::current_key, as it may change by the peer */
+	if (b->ao_required != a->ao_required) {
+		test_fail("getsockopt(): returned ao doesn't have ao_required");
+		return -1;
+	}
+	if (b->accept_icmps != a->accept_icmps) {
+		test_fail("getsockopt(): returned ao doesn't accept ICMPs");
+		return -1;
+	}
+	if (a->set_rnext) {
+		if (b->rnext != a->rnext) {
+			test_fail("getsockopt(): RNext KeyID has changed");
+			return -1;
+		}
+	}
+
+	/* counters are of interest only when they were explicitly set */
+	if (!a->set_counters)
+		return 0;
+
+	__same_cnt(pkt_good);
+	__same_cnt(pkt_bad);
+	__same_cnt(pkt_key_not_found);
+	__same_cnt(pkt_ao_required);
+	__same_cnt(pkt_dropped_icmp);
+	return 0;
+#undef __same_cnt
+}
+
+/* Collect all TCP-AO counters visible for @sk into @out:
+ *  - per-netns ones, looked up with netstat_get();
+ *  - per-socket ones from getsockopt(TCP_AO_INFO);
+ *  - per-key ones from getsockopt(TCP_AO_GET_KEYS).
+ * @out is zeroed first. out->key_cnts is allocated here and is released
+ * by test_tcp_ao_counters_free().
+ * Returns 0 on success, -EOPNOTSUPP if netstat_get() flagged any of
+ * the per-netns counters, otherwise a negative error code.
+ */
+int test_get_tcp_ao_counters(int sk, struct tcp_ao_counters *out)
+{
+	struct tcp_ao_getsockopt *key_dump;
+	socklen_t key_dump_sz = sizeof(*key_dump);
+	struct tcp_ao_info_opt info = {};
+	bool c1, c2, c3, c4, c5;
+	struct netstat *ns;
+	int err, nr_keys;
+
+	memset(out, 0, sizeof(*out));
+
+	/* per-netns */
+	ns = netstat_read();
+	out->netns_ao_good = netstat_get(ns, "TCPAOGood", &c1);
+	out->netns_ao_bad = netstat_get(ns, "TCPAOBad", &c2);
+	out->netns_ao_key_not_found = netstat_get(ns, "TCPAOKeyNotFound", &c3);
+	out->netns_ao_required = netstat_get(ns, "TCPAORequired", &c4);
+	out->netns_ao_dropped_icmp = netstat_get(ns, "TCPAODroppedIcmps", &c5);
+	netstat_free(ns);
+	if (c1 || c2 || c3 || c4 || c5)
+		return -EOPNOTSUPP;
+
+	err = test_get_ao_info(sk, &info);
+	if (err)
+		return err;
+
+	/* per-socket */
+	out->ao_info_pkt_good		= info.pkt_good;
+	out->ao_info_pkt_bad		= info.pkt_bad;
+	out->ao_info_pkt_key_not_found	= info.pkt_key_not_found;
+	out->ao_info_pkt_ao_required	= info.pkt_ao_required;
+	out->ao_info_pkt_dropped_icmp	= info.pkt_dropped_icmp;
+
+	/* per-key */
+	nr_keys = test_get_ao_keys_nr(sk);
+	if (nr_keys < 0)
+		return nr_keys;
+	if (nr_keys == 0)
+		test_error("test_get_ao_keys_nr() == 0");
+	out->nr_keys = (size_t)nr_keys;
+	key_dump = calloc(nr_keys, key_dump_sz);
+	if (!key_dump)
+		return -errno;
+
+	/* The request lives in the first element; key_dump_sz stays the size
+	 * of one element, the array length is passed in nkeys.
+	 */
+	key_dump[0].nkeys = nr_keys;
+	key_dump[0].get_all = 1;
+	err = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS,
+			 key_dump, &key_dump_sz);
+	if (err) {
+		/* save errno first: free() is allowed to change it */
+		err = -errno;
+		free(key_dump);
+		return err;
+	}
+
+	out->key_cnts = calloc(nr_keys, sizeof(out->key_cnts[0]));
+	if (!out->key_cnts) {
+		err = -errno;
+		free(key_dump);
+		return err;
+	}
+
+	while (nr_keys--) {
+		out->key_cnts[nr_keys].sndid = key_dump[nr_keys].sndid;
+		out->key_cnts[nr_keys].rcvid = key_dump[nr_keys].rcvid;
+		out->key_cnts[nr_keys].pkt_good = key_dump[nr_keys].pkt_good;
+		out->key_cnts[nr_keys].pkt_bad = key_dump[nr_keys].pkt_bad;
+	}
+	free(key_dump);
+
+	return 0;
+}
+
+/* Compare the per-netns and per-socket counters of two snapshots.
+ * A counter must never decrease, and it must have changed if and only if
+ * its TEST_CNT_* bit is set in @expected.
+ * @tst_name: prefix for the failure message, may be NULL.
+ * errno is zeroed. Returns 0 if all counters behave as expected; the first
+ * violation is reported with test_fail() and -1 is returned.
+ */
+int __test_tcp_ao_counters_cmp(const char *tst_name,
+			       struct tcp_ao_counters *before,
+			       struct tcp_ao_counters *after,
+			       test_cnt expected)
+{
+	const char *name = tst_name ?: "";
+
+#define __check_cnt(cnt, flag)						\
+do {									\
+	bool want_inc = !!(expected & (flag));				\
+									\
+	if (after->cnt < before->cnt) {					\
+		test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64, \
+			  name, before->cnt, after->cnt);		\
+		return -1;						\
+	}								\
+	if ((after->cnt != before->cnt) != want_inc) {			\
+		test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64, \
+			  name, want_inc ? "" : "not ",			\
+			  before->cnt, after->cnt);			\
+		return -1;						\
+	}								\
+} while (0)
+
+	errno = 0;
+
+	/* per-netns */
+	__check_cnt(netns_ao_good, TEST_CNT_NS_GOOD);
+	__check_cnt(netns_ao_bad, TEST_CNT_NS_BAD);
+	__check_cnt(netns_ao_key_not_found, TEST_CNT_NS_KEY_NOT_FOUND);
+	__check_cnt(netns_ao_required, TEST_CNT_NS_AO_REQUIRED);
+	__check_cnt(netns_ao_dropped_icmp, TEST_CNT_NS_DROPPED_ICMP);
+
+	/* per-socket */
+	__check_cnt(ao_info_pkt_good, TEST_CNT_SOCK_GOOD);
+	__check_cnt(ao_info_pkt_bad, TEST_CNT_SOCK_BAD);
+	__check_cnt(ao_info_pkt_key_not_found, TEST_CNT_SOCK_KEY_NOT_FOUND);
+	__check_cnt(ao_info_pkt_ao_required, TEST_CNT_SOCK_AO_REQUIRED);
+	__check_cnt(ao_info_pkt_dropped_icmp, TEST_CNT_SOCK_DROPPED_ICMP);
+
+	return 0;
+#undef __check_cnt
+}
+
+/* Compare the per-key counters of two snapshots.
+ * Keys are paired by their index in key_cnts[], so both snapshots must
+ * hold the same number of keys.
+ * @tst_name: prefix for the failure messages, may be NULL.
+ * @expected: TEST_CNT_KEY_GOOD/TEST_CNT_KEY_BAD select the counters that
+ *            have to change; the other one must stay the same.
+ * @sndid, @rcvid: if not negative, only keys with that id are checked.
+ * Returns 0 if all checked counters behave as expected; the first
+ * violation is reported with test_fail() and -1 is returned.
+ */
+int test_tcp_ao_key_counters_cmp(const char *tst_name,
+				 struct tcp_ao_counters *before,
+				 struct tcp_ao_counters *after,
+				 test_cnt expected,
+				 int sndid, int rcvid)
+{
+	size_t i;
+#define __cmp_ao(i, cnt, expecting_inc)					\
+do {									\
+	if (before->key_cnts[i].cnt > after->key_cnts[i].cnt) {		\
+		test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64 " for key %u:%u", \
+			  tst_name ?: "", before->key_cnts[i].cnt,	\
+			  after->key_cnts[i].cnt,			\
+			  before->key_cnts[i].sndid,			\
+			  before->key_cnts[i].rcvid);			\
+		return -1;						\
+	}								\
+	if ((before->key_cnts[i].cnt != after->key_cnts[i].cnt) != (expecting_inc)) {		\
+		test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64 " for key %u:%u", \
+			  tst_name ?: "", (expecting_inc) ? "" : "not ",\
+			  before->key_cnts[i].cnt,			\
+			  after->key_cnts[i].cnt,			\
+			  before->key_cnts[i].sndid,			\
+			  before->key_cnts[i].rcvid);			\
+		return -1;						\
+	}								\
+} while(0)
+
+	if (before->nr_keys != after->nr_keys) {
+		/* @tst_name may be NULL, don't hand that to %s */
+		test_fail("%s: Keys changed on the socket %zu != %zu",
+			  tst_name ?: "", before->nr_keys, after->nr_keys);
+		return -1;
+	}
+
+	/* per-key */
+	i = before->nr_keys;
+	while (i--) {
+		if (sndid >= 0 && before->key_cnts[i].sndid != sndid)
+			continue;
+		if (rcvid >= 0 && before->key_cnts[i].rcvid != rcvid)
+			continue;
+		__cmp_ao(i, pkt_good, !!(expected & TEST_CNT_KEY_GOOD));
+		__cmp_ao(i, pkt_bad, !!(expected & TEST_CNT_KEY_BAD));
+	}
+	return 0;
+#undef __cmp_ao
+}
+
+/* Release the per-key array allocated by test_get_tcp_ao_counters().
+ * The pointer is reset, so calling it again on the same @cnts is harmless.
+ */
+void test_tcp_ao_counters_free(struct tcp_ao_counters *cnts)
+{
+	free(cnts->key_cnts);
+	cnts->key_cnts = NULL;
+}
+
+#define TEST_BUF_SIZE 4096
+/* Echo loop: receive up to TEST_BUF_SIZE bytes from @sk and send them back.
+ * @quota: stop once at least that many bytes were echoed; 0 for no limit.
+ * @timeout_sec: passed to test_wait_fd() before every recv() and send().
+ * The loop also ends when recv() or send() returns 0.
+ * Returns the number of echoed bytes, or the negative error code of
+ * test_wait_fd(). recv() errors and short sends are reported with
+ * test_error().
+ */
+ssize_t test_server_run(int sk, ssize_t quota, time_t timeout_sec)
+{
+	ssize_t total = 0;
+
+	for (;;) {
+		char buf[TEST_BUF_SIZE];
+		ssize_t rcvd, sent;
+		int err;
+
+		err = test_wait_fd(sk, timeout_sec, 0);
+		if (err)
+			return err;
+
+		rcvd = recv(sk, buf, sizeof(buf), 0);
+		if (rcvd < 0)
+			test_error("recv(): %zd", rcvd);
+		if (rcvd == 0)
+			break;
+
+		err = test_wait_fd(sk, timeout_sec, 1);
+		if (err)
+			return err;
+
+		sent = send(sk, buf, rcvd, 0);
+		if (sent == 0)
+			break;
+		if (sent != rcvd)
+			test_error("send()");
+
+		total += rcvd;
+		if (quota && total >= quota)
+			break;
+	}
+
+	return total;
+}
+
+/* Send @buf over @sk in chunks of at most @msg_len bytes, waiting after
+ * each chunk until the same number of bytes is received back, and verify
+ * that the received data equals the sent chunk. TCP_NODELAY is set on @sk.
+ * @timeout_sec: passed to test_wait_fd() before every send() and recv().
+ * Returns the number of bytes that were sent and verified; the negative
+ * error code of test_wait_fd(); or -1 (after test_fail()) if the received
+ * data differs. send()/recv() failures are reported with test_error().
+ */
+ssize_t test_client_loop(int sk, char *buf, size_t buf_sz,
+			 const size_t msg_len, time_t timeout_sec)
+{
+	/* a zero-length VLA is undefined behaviour: keep at least one byte */
+	char msg[msg_len ? msg_len : 1];
+	int nodelay = 1;
+	size_t i;
+
+	if (setsockopt(sk, IPPROTO_TCP, TCP_NODELAY, &nodelay, sizeof(nodelay)))
+		test_error("setsockopt(TCP_NODELAY)");
+
+	for (i = 0; i < buf_sz; i += min(msg_len, buf_sz - i)) {
+		size_t bytes = min(msg_len, buf_sz - i);
+		size_t rcvd = 0;
+		ssize_t sent;
+		int ret;
+
+		ret = test_wait_fd(sk, timeout_sec, 1);
+		if (ret)
+			return ret;
+
+		/* send() returns ssize_t: keep the sign to tell an error
+		 * from a short write
+		 */
+		sent = send(sk, buf + i, bytes, 0);
+		if (sent == 0)
+			break;
+		if (sent < 0 || (size_t)sent != bytes)
+			test_error("send()");
+
+		/* the whole chunk was sent, expect all of it back */
+		do {
+			ssize_t got;
+
+			ret = test_wait_fd(sk, timeout_sec, 0);
+			if (ret)
+				return ret;
+
+			got = recv(sk, msg + rcvd, sizeof(msg) - rcvd, 0);
+			if (got <= 0)
+				test_error("recv(): %zd", got);
+			rcvd += got;
+		} while (rcvd < bytes);
+		if (rcvd > bytes)
+			test_error("recv(): %zu > %zu", rcvd, bytes);
+		if (memcmp(buf + i, msg, rcvd) != 0) {
+			test_fail("received message differs");
+			return -1;
+		}
+	}
+	return i;
+}
+
+/* Run test_client_loop() over @nr messages of @msg_len random bytes each.
+ * The buffer of @msg_len * @nr bytes is taken from the stack (alloca()).
+ * Returns 0 if the whole buffer was echoed back, -1 otherwise.
+ */
+int test_client_verify(int sk, const size_t msg_len, const size_t nr,
+		       time_t timeout_sec)
+{
+	const size_t total = msg_len * nr;
+	char *payload = alloca(total);
+	ssize_t echoed;
+
+	randomize_buffer(payload, total);
+	echoed = test_client_loop(sk, payload, total, msg_len, timeout_sec);
+
+	/* compared as unsigned: a negative error code never equals @total */
+	return ((size_t)echoed == total) ? 0 : -1;
+}
diff --git a/tools/testing/selftests/net/tcp_ao/lib/utils.c b/tools/testing/selftests/net/tcp_ao/lib/utils.c
new file mode 100644
index 000000000000..372daca525f5
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/lib/utils.c
@@ -0,0 +1,30 @@
+// SPDX-License-Identifier: GPL-2.0
+#include "aolib.h"
+#include <string.h>
+
+/* Fill @buflen bytes at @buf with rand() output: one rand() call per
+ * sizeof(int) bytes plus one more for a trailing partial word.
+ * The values are copied bytewise, so @buf doesn't have to be aligned
+ * for int.
+ */
+void randomize_buffer(void *buf, size_t buflen)
+{
+	unsigned char *p = buf;
+	size_t words = buflen / sizeof(int);
+	size_t leftover = buflen % sizeof(int);
+	int tmp;
+
+	while (words--) {
+		tmp = rand();
+		memcpy(p, &tmp, sizeof(tmp));
+		p += sizeof(tmp);
+	}
+
+	/* @p now points at the last, partial word (if there is one) */
+	if (leftover) {
+		tmp = rand();
+		memcpy(p, &tmp, leftover);
+	}
+}
+
+/* IPv6 socket address with only the family set: the address and port
+ * are left zero-initialized.
+ */
+const struct sockaddr_in6 addr_any6 = {
+	.sin6_family	= AF_INET6,
+};
+
+/* IPv4 counterpart: only the family is set, the rest is zero. */
+const struct sockaddr_in addr_any4 = {
+	.sin_family	= AF_INET,
+};

-- 
2.43.0





[Index of Archives]     [Linux Wireless]     [Linux Kernel]     [ATH6KL]     [Linux Bluetooth]     [Linux Netdev]     [Kernel Newbies]     [Share Photos]     [IDE]     [Security]     [Git]     [Netfilter]     [Bugtraq]     [Yosemite News]     [MIPS Linux]     [ARM Linux]     [Linux Security]     [Linux RAID]     [Linux ATA RAID]     [Samba]     [Device Mapper]

  Powered by Linux