By default TCP-AO keys apply to all possible peers but it's possible to
have different keys for different remote hosts.

This patch adds initial tests for the behavior behind the
TCP_AUTHOPT_KEY_BIND_ADDR flag. Server rejection is tested via client
timeout so this can be slightly slow.

Signed-off-by: Leonard Crestez <cdleonard@xxxxxxxxx>
---
 .../tcp_authopt_test/netns_fixture.py         |  63 +++++++
 .../tcp_authopt/tcp_authopt_test/server.py    |  82 ++++++++++
 .../tcp_authopt/tcp_authopt_test/test_bind.py | 143 ++++++++++++++++
 .../tcp_authopt/tcp_authopt_test/utils.py     | 154 ++++++++++++++++++
 4 files changed, 442 insertions(+)
 create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/netns_fixture.py
 create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/server.py
 create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_bind.py
 create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/utils.py

diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/netns_fixture.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/netns_fixture.py
new file mode 100644
index 000000000000..20bb12c2aae2
--- /dev/null
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/netns_fixture.py
@@ -0,0 +1,63 @@
+# SPDX-License-Identifier: GPL-2.0
+import subprocess
+import socket
+from ipaddress import IPv4Address
+from ipaddress import IPv6Address
+
+
+class NamespaceFixture:
+    """Create a pair of namespaces connected by one veth pair
+
+    Each end of the pair has multiple addresses but everything is in the same subnet
+    """
+
+    ns1_name = "tcp_authopt_test_1"
+    ns2_name = "tcp_authopt_test_2"
+
+    @classmethod
+    def get_ipv4_addr(cls, ns=1, index=1) -> IPv4Address:
+        return IPv4Address("10.10.0.0") + (ns << 8) + index
+
+    @classmethod
+    def get_ipv6_addr(cls, ns=1, index=1) -> IPv6Address:
+        return IPv6Address("fd00::") + (ns << 16) + index
+
+    @classmethod
+    def get_addr(cls, address_family=socket.AF_INET, ns=1, index=1):
+        if address_family == socket.AF_INET:
+            return cls.get_ipv4_addr(ns, index)
+        elif address_family == socket.AF_INET6:
+            return cls.get_ipv6_addr(ns, index)
+        else:
+            raise ValueError(f"Bad address_family={address_family}")
+
+    def __init__(self, **kw):
+        for k, v in kw.items():
+            setattr(self, k, v)
+
+    def __enter__(self):
+        script = f"""
+set -e -x
+ip netns del {self.ns1_name} || true
+ip netns del {self.ns2_name} || true
+ip netns add {self.ns1_name}
+ip netns add {self.ns2_name}
+ip link add veth0 netns {self.ns1_name} type veth peer name veth0 netns {self.ns2_name}
+ip netns exec {self.ns1_name} ip link set veth0 up
+ip netns exec {self.ns2_name} ip link set veth0 up
+"""
+        for index in [1, 2, 3]:
+            script += f"ip -n {self.ns1_name} addr add {self.get_ipv4_addr(1, index)}/16 dev veth0\n"
+            script += f"ip -n {self.ns2_name} addr add {self.get_ipv4_addr(2, index)}/16 dev veth0\n"
+            script += f"ip -n {self.ns1_name} addr add {self.get_ipv6_addr(1, index)}/64 dev veth0 nodad\n"
+            script += f"ip -n {self.ns2_name} addr add {self.get_ipv6_addr(2, index)}/64 dev veth0 nodad\n"
+        subprocess.run(script, shell=True, check=True)
+        return self
+
+    def __exit__(self, *a):
+        script = f"""
+set -e -x
+ip netns del {self.ns1_name} || true
+ip netns del {self.ns2_name} || true
+"""
+        subprocess.run(script, shell=True, check=True)
diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/server.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/server.py
new file mode 100644
index 000000000000..c4cce8f5862a
--- /dev/null
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/server.py
@@ -0,0 +1,82 @@
+# SPDX-License-Identifier: GPL-2.0
+import logging
+import os
+import selectors
+from contextlib import ExitStack
+from threading import Thread
+
+logger = logging.getLogger(__name__)
+
+
+class SimpleServerThread(Thread):
+    """Simple server thread for testing TCP sockets
+
+    All data is read in 1000-byte chunks and either echoed back or discarded.
+    """
+
+    def __init__(self, socket, mode="recv"):
+        self.listen_socket = socket
+        self.server_socket = []
+        self.mode = mode
+        super().__init__()
+
+    def _read(self, conn, events):
+        # logger.debug("events=%r", events)
+        data = conn.recv(1000)
+        # logger.debug("len(data)=%r", len(data))
+        if not data:
+            # logger.info("closing %r", conn)
+            self.sel.unregister(conn)  # selectors: unregister before closing
+            conn.close()
+        else:
+            if self.mode == "echo":
+                conn.sendall(data)
+            elif self.mode == "recv":
+                pass
+            else:
+                raise ValueError(f"Unknown mode {self.mode}")
+
+    def _stop_pipe_read(self, conn, events):
+        self.should_loop = False
+
+    def start(self) -> None:
+        self.exit_stack = ExitStack()
+        self._stop_pipe_rfd, self._stop_pipe_wfd = os.pipe()
+        self.exit_stack.callback(lambda: os.close(self._stop_pipe_rfd))
+        self.exit_stack.callback(lambda: os.close(self._stop_pipe_wfd))
+        return super().start()
+
+    def _accept(self, conn, events):
+        assert conn == self.listen_socket
+        conn, _addr = self.listen_socket.accept()
+        conn = self.exit_stack.enter_context(conn)
+        conn.setblocking(False)
+        self.sel.register(conn, selectors.EVENT_READ, self._read)
+        self.server_socket.append(conn)
+
+    def run(self):
+        self.should_loop = True
+        self.sel = self.exit_stack.enter_context(selectors.DefaultSelector())
+        self.sel.register(
+            self._stop_pipe_rfd, selectors.EVENT_READ, self._stop_pipe_read
+        )
+        self.sel.register(self.listen_socket, selectors.EVENT_READ, self._accept)
+        # logger.debug("loop init")
+        while self.should_loop:
+            for key, events in self.sel.select(timeout=1):
+                callback = key.data
+                callback(key.fileobj, events)
+        # logger.debug("loop done")
+
+    def stop(self):
+        """Try to stop nicely"""
+        os.write(self._stop_pipe_wfd, b"Q")
+        self.join()
+        self.exit_stack.close()
+
+    def __enter__(self):
+        self.start()
+        return self
+
+    def __exit__(self, *args):
+        self.stop()
diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_bind.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_bind.py
new file mode 100644
index 000000000000..980954098e97
--- /dev/null
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_bind.py
@@ -0,0 +1,143 @@
+# SPDX-License-Identifier: GPL-2.0
+"""Test TCP-AO keys can be bound to specific remote addresses"""
+from contextlib import ExitStack
+import socket
+import pytest
+from .netns_fixture import NamespaceFixture
+from .utils import create_listen_socket
+from .server import SimpleServerThread
+from . import linux_tcp_authopt
+from .linux_tcp_authopt import (
+    tcp_authopt,
+    set_tcp_authopt,
+    set_tcp_authopt_key,
+    tcp_authopt_key,
+)
+from .utils import netns_context, DEFAULT_TCP_SERVER_PORT, check_socket_echo
+from .conftest import skipif_missing_tcp_authopt
+
+pytestmark = skipif_missing_tcp_authopt
+
+
+@pytest.mark.parametrize("address_family", [socket.AF_INET, socket.AF_INET6])
+def test_addr_server_bind(exit_stack: ExitStack, address_family):
+    """Server only accepts client2; check that client1 fails"""
+    nsfixture = exit_stack.enter_context(NamespaceFixture())
+    server_addr = str(nsfixture.get_addr(address_family, 1, 1))
+    client_addr = str(nsfixture.get_addr(address_family, 2, 1))
+    client_addr2 = str(nsfixture.get_addr(address_family, 2, 2))
+
+    # create server:
+    listen_socket = exit_stack.push(
+        create_listen_socket(family=address_family, ns=nsfixture.ns1_name)
+    )
+    exit_stack.enter_context(SimpleServerThread(listen_socket, mode="echo"))
+
+    # set keys:
+    server_key = tcp_authopt_key(
+        alg=linux_tcp_authopt.TCP_AUTHOPT_ALG_HMAC_SHA_1_96,
+        key="hello",
+        flags=linux_tcp_authopt.TCP_AUTHOPT_KEY_BIND_ADDR,
+        addr=client_addr2,
+    )
+    set_tcp_authopt(
+        listen_socket,
+        tcp_authopt(flags=linux_tcp_authopt.TCP_AUTHOPT_FLAG_REJECT_UNEXPECTED),
+    )
+    set_tcp_authopt_key(listen_socket, server_key)
+
+    # create client socket:
+    def create_client_socket():
+        with netns_context(nsfixture.ns2_name):
+            client_socket = socket.socket(address_family, socket.SOCK_STREAM)
+        client_key = tcp_authopt_key(
+            alg=linux_tcp_authopt.TCP_AUTHOPT_ALG_HMAC_SHA_1_96,
+            key="hello",
+        )
+        set_tcp_authopt_key(client_socket, client_key)
+        return client_socket
+
+    # addr match:
+    # with create_client_socket() as client_socket2:
+    #    client_socket2.bind((client_addr2, 0))
+    #    client_socket2.settimeout(1.0)
+    #    client_socket2.connect((server_addr, DEFAULT_TCP_SERVER_PORT))
+
+    # addr mismatch:
+    with create_client_socket() as client_socket1:
+        client_socket1.bind((client_addr, 0))
+        with pytest.raises(socket.timeout):
+            client_socket1.settimeout(1.0)
+            client_socket1.connect((server_addr, DEFAULT_TCP_SERVER_PORT))
+
+
+@pytest.mark.parametrize("address_family", [socket.AF_INET, socket.AF_INET6])
+def test_addr_client_bind(exit_stack: ExitStack, address_family):
+    """Client configures different keys with same id but different addresses"""
+    nsfixture = exit_stack.enter_context(NamespaceFixture())
+    server_addr1 = str(nsfixture.get_addr(address_family, 1, 1))
+    server_addr2 = str(nsfixture.get_addr(address_family, 1, 2))
+    client_addr = str(nsfixture.get_addr(address_family, 2, 1))
+
+    # create servers:
+    listen_socket1 = exit_stack.enter_context(
+        create_listen_socket(
+            family=address_family, ns=nsfixture.ns1_name, bind_addr=server_addr1
+        )
+    )
+    listen_socket2 = exit_stack.enter_context(
+        create_listen_socket(
+            family=address_family, ns=nsfixture.ns1_name, bind_addr=server_addr2
+        )
+    )
+    exit_stack.enter_context(SimpleServerThread(listen_socket1, mode="echo"))
+    exit_stack.enter_context(SimpleServerThread(listen_socket2, mode="echo"))
+
+    # set keys:
+    set_tcp_authopt_key(
+        listen_socket1,
+        tcp_authopt_key(
+            alg=linux_tcp_authopt.TCP_AUTHOPT_ALG_HMAC_SHA_1_96,
+            key="11111",
+        ),
+    )
+    set_tcp_authopt_key(
+        listen_socket2,
+        tcp_authopt_key(
+            alg=linux_tcp_authopt.TCP_AUTHOPT_ALG_HMAC_SHA_1_96,
+            key="22222",
+        ),
+    )
+
+    # create client socket:
+    def create_client_socket():
+        with netns_context(nsfixture.ns2_name):
+            client_socket = socket.socket(address_family, socket.SOCK_STREAM)
+        set_tcp_authopt_key(
+            client_socket,
+            tcp_authopt_key(
+                alg=linux_tcp_authopt.TCP_AUTHOPT_ALG_HMAC_SHA_1_96,
+                key="11111",
+                flags=linux_tcp_authopt.TCP_AUTHOPT_KEY_BIND_ADDR,
+                addr=server_addr1,
+            ),
+        )
+        set_tcp_authopt_key(
+            client_socket,
+            tcp_authopt_key(
+                alg=linux_tcp_authopt.TCP_AUTHOPT_ALG_HMAC_SHA_1_96,
+                key="22222",
+                flags=linux_tcp_authopt.TCP_AUTHOPT_KEY_BIND_ADDR,
+                addr=server_addr2,
+            ),
+        )
+        client_socket.settimeout(1.0)
+        client_socket.bind((client_addr, 0))
+        return client_socket
+
+    with create_client_socket() as client_socket1:
+        client_socket1.connect((server_addr1, DEFAULT_TCP_SERVER_PORT))
+        check_socket_echo(client_socket1)
+    with create_client_socket() as client_socket2:
+        client_socket2.connect((server_addr2, DEFAULT_TCP_SERVER_PORT))
+        check_socket_echo(client_socket2)
diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/utils.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/utils.py
new file mode 100644
index 000000000000..22bd3f0a142a
--- /dev/null
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/utils.py
@@ -0,0 +1,154 @@
+# SPDX-License-Identifier: GPL-2.0
+import json
+import random
+import subprocess
+import threading
+import typing
+import socket
+from dataclasses import dataclass
+from contextlib import nullcontext
+
+from nsenter import Namespace
+from scapy.sendrecv import AsyncSniffer
+
+
+# TCP port does not impact Authentication Option so define a single default
+DEFAULT_TCP_SERVER_PORT = 17971
+
+
+class SimpleWaitEvent(threading.Event):
+    @property
+    def value(self) -> bool:
+        return self.is_set()
+
+    @value.setter
+    def value(self, value: bool):
+        if value:
+            self.set()
+        else:
+            self.clear()
+
+    def wait(self, timeout=None):
+        """Like Event.wait except raise on timeout"""
+        super().wait(timeout)
+        if not self.is_set():
+            raise TimeoutError(f"Timed out timeout={timeout!r}")
+
+
+def recvall(sock, todo):
+    """Receive exactly todo bytes unless EOF"""
+    data = b""
+    while True:
+        chunk = sock.recv(todo)
+        if not chunk:
+            return data
+        data += chunk
+        todo -= len(chunk)
+        if todo == 0:
+            return data
+        assert todo > 0
+
+
+def randbytes(count) -> bytes:
+    """Return a random byte array"""
+    return bytes([random.randint(0, 255) for _ in range(count)])
+
+
+def check_socket_echo(sock, size=1024):
+    """Send random bytes and check they are received"""
+    send_buf = randbytes(size)
+    sock.sendall(send_buf)
+    recv_buf = recvall(sock, size)
+    assert send_buf == recv_buf
+
+
+def nstat_json(command_prefix: str = ""):
+    """Parse nstat output into a python dict"""
+    runres = subprocess.run(
+        f"{command_prefix}nstat -a --zeros --json",
+        shell=True,
+        check=True,
+        stdout=subprocess.PIPE,
+        encoding="utf-8",
+    )
+    return json.loads(runres.stdout)
+
+
+def netns_context(ns: str = ""):
+    """Create context manager for a certain optional netns
+
+    If the ns argument is empty then just return a `nullcontext`
+    """
+    if ns:
+        return Namespace("/var/run/netns/" + ns, "net")
+    else:
+        return nullcontext()
+
+
+def create_listen_socket(
+    ns: str = "",
+    family=socket.AF_INET,
+    reuseaddr=True,
+    listen_depth=10,
+    bind_addr="",
+    bind_port=DEFAULT_TCP_SERVER_PORT,
+):
+    with netns_context(ns):
+        listen_socket = socket.socket(family, socket.SOCK_STREAM)
+    if reuseaddr:
+        listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+    listen_socket.bind((str(bind_addr), bind_port))
+    listen_socket.listen(listen_depth)
+    return listen_socket
+
+
+@dataclass
+class tcphdr_authopt:
+    """Representation of a TCP auth option as it appears in a TCP packet"""
+
+    keyid: int
+    rnextkeyid: int
+    mac: bytes
+
+    @classmethod
+    def unpack(cls, buf) -> "tcphdr_authopt":
+        return cls(buf[0], buf[1], buf[2:])
+
+    def __repr__(self):
+        return f"tcphdr_authopt({self.keyid}, {self.rnextkeyid}, bytes.fromhex({self.mac.hex(' ')!r}))"
+
+
+def scapy_tcp_get_authopt_val(tcp) -> typing.Optional[tcphdr_authopt]:
+    for optnum, optval in tcp.options:
+        if optnum == 29:
+            return tcphdr_authopt.unpack(optval)
+    return None
+
+
+def scapy_sniffer_start_block(sniffer: AsyncSniffer, timeout=1):
+    """Like AsyncSniffer.start except block until sniffing starts
+
+    This ensures no lost packets and no delays
+    """
+    if sniffer.kwargs.get("started_callback"):
+        raise ValueError("sniffer must not already have a started_callback")
+
+    e = SimpleWaitEvent()
+    sniffer.kwargs["started_callback"] = e.set
+    sniffer.start()
+    e.wait(timeout=timeout)
+
+
+def scapy_sniffer_stop(sniffer: AsyncSniffer):
+    """Like AsyncSniffer.stop except no error is raised if not running"""
+    if sniffer is not None and sniffer.running:
+        sniffer.stop()
+
+
+class AsyncSnifferContext(AsyncSniffer):
+    def __enter__(self):
+        scapy_sniffer_start_block(self)
+        return self
+
+    def __exit__(self, *a):
+        scapy_sniffer_stop(self)
-- 
2.25.1