[RFCv3 09/15] selftests: tcp_authopt: Test key address binding

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



By default TCP-AO keys apply to all possible peers but it's possible to
have different keys for different remote hosts.

This patch adds initial tests for the behavior behind the
TCP_AUTHOPT_KEY_BIND_ADDR flag. Server rejection is tested via client
timeout so this can be slightly slow.

Signed-off-by: Leonard Crestez <cdleonard@xxxxxxxxx>
---
 .../tcp_authopt_test/netns_fixture.py         |  63 +++++++
 .../tcp_authopt/tcp_authopt_test/server.py    |  82 ++++++++++
 .../tcp_authopt/tcp_authopt_test/test_bind.py | 143 ++++++++++++++++
 .../tcp_authopt/tcp_authopt_test/utils.py     | 154 ++++++++++++++++++
 4 files changed, 442 insertions(+)
 create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/netns_fixture.py
 create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/server.py
 create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_bind.py
 create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/utils.py

diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/netns_fixture.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/netns_fixture.py
new file mode 100644
index 000000000000..20bb12c2aae2
--- /dev/null
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/netns_fixture.py
@@ -0,0 +1,63 @@
+# SPDX-License-Identifier: GPL-2.0
+import subprocess
+import socket
+from ipaddress import IPv4Address
+from ipaddress import IPv6Address
+
+
+class NamespaceFixture:
+    """Create a pair of namespaces connected by one veth pair
+
+    Each end of the pair has multiple addresses but everything is in the same subnet
+    """
+
+    ns1_name = "tcp_authopt_test_1"
+    ns2_name = "tcp_authopt_test_2"
+
+    @classmethod
+    def get_ipv4_addr(cls, ns=1, index=1) -> IPv4Address:
+        return IPv4Address("10.10.0.0") + (ns << 8) + index  # e.g. ns=1, index=2 -> 10.10.1.2
+
+    @classmethod
+    def get_ipv6_addr(cls, ns=1, index=1) -> IPv6Address:
+        return IPv6Address("fd00::") + (ns << 16) + index  # e.g. ns=1, index=2 -> fd00::1:2
+
+    @classmethod
+    def get_addr(cls, address_family=socket.AF_INET, ns=1, index=1):  # dispatch on family
+        if address_family == socket.AF_INET:
+            return cls.get_ipv4_addr(ns, index)
+        elif address_family == socket.AF_INET6:
+            return cls.get_ipv6_addr(ns, index)
+        else:
+            raise ValueError(f"Bad address_family={address_family}")
+
+    def __init__(self, **kw):  # every keyword argument becomes an instance attribute
+        for k, v in kw.items():
+            setattr(self, k, v)
+
+    def __enter__(self):  # (re)creates both namespaces, the veth pair and all addresses
+        script = f"""
+set -e -x
+ip netns del {self.ns1_name} || true
+ip netns del {self.ns2_name} || true
+ip netns add {self.ns1_name}
+ip netns add {self.ns2_name}
+ip link add veth0 netns {self.ns1_name} type veth peer name veth0 netns {self.ns2_name}
+ip netns exec {self.ns1_name} ip link set veth0 up
+ip netns exec {self.ns2_name} ip link set veth0 up
+"""
+        for index in [1, 2, 3]:  # three IPv4 and three IPv6 addresses per namespace
+            script += f"ip -n {self.ns1_name} addr add {self.get_ipv4_addr(1, index)}/16 dev veth0\n"
+            script += f"ip -n {self.ns2_name} addr add {self.get_ipv4_addr(2, index)}/16 dev veth0\n"
+            script += f"ip -n {self.ns1_name} addr add {self.get_ipv6_addr(1, index)}/64 dev veth0 nodad\n"
+            script += f"ip -n {self.ns2_name} addr add {self.get_ipv6_addr(2, index)}/64 dev veth0 nodad\n"
+        subprocess.run(script, shell=True, check=True)  # non-zero exit raises CalledProcessError
+        return self
+
+    def __exit__(self, *a):  # deletes both namespaces; a missing one is not an error
+        script = f"""
+set -e -x
+ip netns del {self.ns1_name} || true
+ip netns del {self.ns2_name} || true
+"""
+        subprocess.run(script, shell=True, check=True)
diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/server.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/server.py
new file mode 100644
index 000000000000..c4cce8f5862a
--- /dev/null
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/server.py
@@ -0,0 +1,82 @@
+# SPDX-License-Identifier: GPL-2.0
+import logging
+import os
+import selectors
+from contextlib import ExitStack
+from threading import Thread
+
+logger = logging.getLogger(__name__)
+
+
+class SimpleServerThread(Thread):
+    """Simple server thread for testing TCP sockets
+
+    All data is read in 1000-byte chunks and either echoed back or discarded.
+    """
+
+    def __init__(self, socket, mode="recv"):
+        self.listen_socket = socket
+        self.server_socket = []  # accepted connections, in accept order
+        self.mode = mode
+        super().__init__()
+
+    def _read(self, conn, events):
+        # Selector callback for an accepted connection that became readable
+        data = conn.recv(1000)
+        # An empty read means the peer closed the connection
+        if len(data) == 0:
+            # selectors requires a file object to be unregistered before it is closed
+            self.sel.unregister(conn)
+            conn.close()
+        else:
+            if self.mode == "echo":
+                conn.sendall(data)
+            elif self.mode == "recv":
+                pass
+            else:
+                raise ValueError(f"Unknown mode {self.mode}")
+
+    def _stop_pipe_read(self, conn, events):  # any byte on the stop pipe ends the loop
+        self.should_loop = False
+
+    def start(self) -> None:
+        self.exit_stack = ExitStack()
+        self._stop_pipe_rfd, self._stop_pipe_wfd = os.pipe()
+        self.exit_stack.callback(lambda: os.close(self._stop_pipe_rfd))
+        self.exit_stack.callback(lambda: os.close(self._stop_pipe_wfd))
+        return super().start()
+
+    def _accept(self, conn, events):
+        assert conn == self.listen_socket
+        conn, _addr = self.listen_socket.accept()
+        conn = self.exit_stack.enter_context(conn)
+        conn.setblocking(False)
+        self.sel.register(conn, selectors.EVENT_READ, self._read)
+        self.server_socket.append(conn)
+
+    def run(self):
+        self.should_loop = True
+        self.sel = self.exit_stack.enter_context(selectors.DefaultSelector())
+        self.sel.register(
+            self._stop_pipe_rfd, selectors.EVENT_READ, self._stop_pipe_read
+        )
+        self.sel.register(self.listen_socket, selectors.EVENT_READ, self._accept)
+        # select() returns after at most one second so should_loop is re-checked
+        while self.should_loop:
+            for key, events in self.sel.select(timeout=1):
+                callback = key.data
+                callback(key.fileobj, events)
+        # The selector, the pipe and accepted sockets are released by stop()
+
+    def stop(self):
+        """Try to stop nicely"""
+        os.write(self._stop_pipe_wfd, b"Q")
+        self.join()
+        self.exit_stack.close()
+
+    def __enter__(self):
+        self.start()
+        return self
+
+    def __exit__(self, *args):
+        self.stop()
diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_bind.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_bind.py
new file mode 100644
index 000000000000..980954098e97
--- /dev/null
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_bind.py
@@ -0,0 +1,143 @@
+# SPDX-License-Identifier: GPL-2.0
+"""Test TCP-AO keys can be bound to specific remote addresses"""
+from contextlib import ExitStack
+import socket
+import pytest
+from .netns_fixture import NamespaceFixture
+from .utils import create_listen_socket
+from .server import SimpleServerThread
+from . import linux_tcp_authopt
+from .linux_tcp_authopt import (
+    tcp_authopt,
+    set_tcp_authopt,
+    set_tcp_authopt_key,
+    tcp_authopt_key,
+)
+from .utils import netns_context, DEFAULT_TCP_SERVER_PORT, check_socket_echo
+from .conftest import skipif_missing_tcp_authopt
+
+pytestmark = skipif_missing_tcp_authopt
+
+
+@pytest.mark.parametrize("address_family", [socket.AF_INET, socket.AF_INET6])
+def test_addr_server_bind(exit_stack: ExitStack, address_family):
+    """Server only accepts client2, check client1 fails"""
+    nsfixture = exit_stack.enter_context(NamespaceFixture())
+    server_addr = str(nsfixture.get_addr(address_family, 1, 1))
+    client_addr = str(nsfixture.get_addr(address_family, 2, 1))
+    client_addr2 = str(nsfixture.get_addr(address_family, 2, 2))
+
+    # create server:
+    listen_socket = exit_stack.push(
+        create_listen_socket(family=address_family, ns=nsfixture.ns1_name)
+    )
+    exit_stack.enter_context(SimpleServerThread(listen_socket, mode="echo"))
+
+    # set keys: the server key is bound to client_addr2 only
+    server_key = tcp_authopt_key(
+        alg=linux_tcp_authopt.TCP_AUTHOPT_ALG_HMAC_SHA_1_96,
+        key="hello",
+        flags=linux_tcp_authopt.TCP_AUTHOPT_KEY_BIND_ADDR,
+        addr=client_addr2,
+    )
+    set_tcp_authopt(
+        listen_socket,
+        tcp_authopt(flags=linux_tcp_authopt.TCP_AUTHOPT_FLAG_REJECT_UNEXPECTED),
+    )
+    set_tcp_authopt_key(listen_socket, server_key)
+
+    # create client socket:
+    def create_client_socket():
+        with netns_context(nsfixture.ns2_name):
+            client_socket = socket.socket(address_family, socket.SOCK_STREAM)
+        client_key = tcp_authopt_key(
+            alg=linux_tcp_authopt.TCP_AUTHOPT_ALG_HMAC_SHA_1_96,
+            key="hello",
+        )
+        set_tcp_authopt_key(client_socket, client_key)
+        return client_socket
+
+    # addr match (commented out):
+    # with create_client_socket() as client_socket2:
+    #    client_socket2.bind((client_addr2, 0))
+    #    client_socket2.settimeout(1.0)
+    #    client_socket2.connect((server_addr, DEFAULT_TCP_SERVER_PORT))
+
+    # addr mismatch:
+    with create_client_socket() as client_socket1:
+        client_socket1.bind((client_addr, 0))
+        with pytest.raises(socket.timeout):  # rejection is observed as a connect timeout
+            client_socket1.settimeout(1.0)
+            client_socket1.connect((server_addr, DEFAULT_TCP_SERVER_PORT))
+
+
+@pytest.mark.parametrize("address_family", [socket.AF_INET, socket.AF_INET6])
+def test_addr_client_bind(exit_stack: ExitStack, address_family):
+    """Client configures different keys with same id but different addresses"""
+    nsfixture = exit_stack.enter_context(NamespaceFixture())
+    server_addr1 = str(nsfixture.get_addr(address_family, 1, 1))
+    server_addr2 = str(nsfixture.get_addr(address_family, 1, 2))
+    client_addr = str(nsfixture.get_addr(address_family, 2, 1))
+
+    # create servers:
+    listen_socket1 = exit_stack.enter_context(
+        create_listen_socket(
+            family=address_family, ns=nsfixture.ns1_name, bind_addr=server_addr1
+        )
+    )
+    listen_socket2 = exit_stack.enter_context(
+        create_listen_socket(
+            family=address_family, ns=nsfixture.ns1_name, bind_addr=server_addr2
+        )
+    )
+    exit_stack.enter_context(SimpleServerThread(listen_socket1, mode="echo"))
+    exit_stack.enter_context(SimpleServerThread(listen_socket2, mode="echo"))
+
+    # set keys: each server gets its own key, not bound to any address
+    set_tcp_authopt_key(
+        listen_socket1,
+        tcp_authopt_key(
+            alg=linux_tcp_authopt.TCP_AUTHOPT_ALG_HMAC_SHA_1_96,
+            key="11111",
+        ),
+    )
+    set_tcp_authopt_key(
+        listen_socket2,
+        tcp_authopt_key(
+            alg=linux_tcp_authopt.TCP_AUTHOPT_ALG_HMAC_SHA_1_96,
+            key="22222",
+        ),
+    )
+
+    # create client socket: both keys are installed, each bound to one server address
+    def create_client_socket():
+        with netns_context(nsfixture.ns2_name):
+            client_socket = socket.socket(address_family, socket.SOCK_STREAM)
+        set_tcp_authopt_key(
+            client_socket,
+            tcp_authopt_key(
+                alg=linux_tcp_authopt.TCP_AUTHOPT_ALG_HMAC_SHA_1_96,
+                key="11111",
+                flags=linux_tcp_authopt.TCP_AUTHOPT_KEY_BIND_ADDR,
+                addr=server_addr1,
+            ),
+        )
+        set_tcp_authopt_key(
+            client_socket,
+            tcp_authopt_key(
+                alg=linux_tcp_authopt.TCP_AUTHOPT_ALG_HMAC_SHA_1_96,
+                key="22222",
+                flags=linux_tcp_authopt.TCP_AUTHOPT_KEY_BIND_ADDR,
+                addr=server_addr2,
+            ),
+        )
+        client_socket.settimeout(1.0)
+        client_socket.bind((client_addr, 0))
+        return client_socket
+
+    with create_client_socket() as client_socket1:  # server 1 was given key "11111"
+        client_socket1.connect((server_addr1, DEFAULT_TCP_SERVER_PORT))
+        check_socket_echo(client_socket1)
+    with create_client_socket() as client_socket2:  # server 2 was given key "22222"
+        client_socket2.connect((server_addr2, DEFAULT_TCP_SERVER_PORT))
+        check_socket_echo(client_socket2)
diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/utils.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/utils.py
new file mode 100644
index 000000000000..22bd3f0a142a
--- /dev/null
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/utils.py
@@ -0,0 +1,154 @@
+# SPDX-License-Identifier: GPL-2.0
+import json
+import random
+import subprocess
+import threading
+import typing
+import socket
+from dataclasses import dataclass
+from contextlib import nullcontext
+
+from nsenter import Namespace
+from scapy.sendrecv import AsyncSniffer
+
+
+# TCP port does not impact Authentication Option so define a single default
+DEFAULT_TCP_SERVER_PORT = 17971
+
+
+class SimpleWaitEvent(threading.Event):
+    """Event with a boolean `value` property whose wait() raises on timeout"""
+
+    @property
+    def value(self) -> bool:
+        return self.is_set()
+
+    @value.setter
+    def value(self, value: bool):
+        (self.set if value else self.clear)()
+
+    def wait(self, timeout=None):
+        """Like Event.wait except raise TimeoutError on timeout"""
+        if not super().wait(timeout):
+            raise TimeoutError(f"Timed out timeout={timeout!r}")
+        # Keep the Event.wait contract of returning True on success
+        return True
+
+
+def recvall(sock, todo):
+    """Read from sock until todo bytes arrived, stopping early only on EOF"""
+    received = b""
+    remaining = todo
+    while True:
+        chunk = sock.recv(remaining)
+        received += chunk
+        remaining -= len(chunk)
+        # An empty chunk means EOF; nothing remaining means we are done
+        if not chunk or remaining == 0:
+            return received
+        assert remaining > 0
+
+
+def randbytes(count) -> bytes:
+    """Return count random bytes, each drawn from the range 0..255"""
+    return bytes(random.randint(0, 255) for _ in range(count))
+
+
+def check_socket_echo(sock, size=1024):
+    """Send size random bytes on sock and assert the same bytes come back"""
+    payload = randbytes(size)
+    sock.sendall(payload)
+    echoed = recvall(sock, size)
+    assert payload == echoed
+
+
+def nstat_json(command_prefix: str = ""):
+    """Run nstat and return its JSON output decoded into a python dict
+
+    command_prefix is prepended verbatim to the nstat shell command line.
+    """
+    command = f"{command_prefix}nstat -a --zeros --json"
+    completed = subprocess.run(
+        command, shell=True, check=True, stdout=subprocess.PIPE, encoding="utf-8"
+    )
+    return json.loads(completed.stdout)
+
+
+def netns_context(ns: str = ""):
+    """Return a context manager for the optional netns named ns
+
+    An empty ns yields a plain `nullcontext` instead.
+    """
+    if not ns:
+        return nullcontext()
+    netns_path = "/var/run/netns/" + ns
+    return Namespace(netns_path, "net")
+
+
+def create_listen_socket(
+    ns: str = "",  # netns name handed to netns_context; empty means no namespace switch
+    family=socket.AF_INET,
+    reuseaddr=True,  # when true, SO_REUSEADDR is set before binding
+    listen_depth=10,  # backlog passed to listen()
+    bind_addr="",  # passed through str(), so non-string address objects are accepted
+    bind_port=DEFAULT_TCP_SERVER_PORT,
+):  # returns the bound, listening socket
+    with netns_context(ns):  # only the socket creation runs inside the netns context
+        listen_socket = socket.socket(family, socket.SOCK_STREAM)
+    if reuseaddr:
+        listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+    listen_socket.bind((str(bind_addr), bind_port))
+    listen_socket.listen(listen_depth)
+    return listen_socket
+
+
+@dataclass
+class tcphdr_authopt:
+    """Representation of a TCP auth option as it appears in a TCP packet"""
+
+    keyid: int
+    rnextkeyid: int
+    mac: bytes
+
+    @classmethod
+    def unpack(cls, buf) -> "tcphdr_authopt":
+        return cls(buf[0], buf[1], buf[2:])  # keyid byte, rnextkeyid byte, rest is the MAC
+
+    def __repr__(self):  # constructor-like form with the MAC shown as space-separated hex
+        return f"tcphdr_authopt({self.keyid}, {self.rnextkeyid}, bytes.fromhex({self.mac.hex(' ')!r}))"
+
+
+def scapy_tcp_get_authopt_val(tcp) -> typing.Optional[tcphdr_authopt]:
+    """Return the first TCP option numbered 29 of tcp unpacked, or None"""
+    # Lazy generator: only the first matching option is ever unpacked
+    found = (tcphdr_authopt.unpack(val) for num, val in tcp.options if num == 29)
+    return next(found, None)
+
+
+def scapy_sniffer_start_block(sniffer: AsyncSniffer, timeout=1):
+    """Start sniffer and return only once its started_callback has fired
+
+    Raises ValueError if a started_callback is already set, TimeoutError on timeout.
+    """
+    options = sniffer.kwargs
+    if options.get("started_callback"):
+        raise ValueError("sniffer must not already have a started_callback")
+    started = SimpleWaitEvent()
+    options["started_callback"] = started.set
+    sniffer.start()
+    started.wait(timeout=timeout)
+
+
+def scapy_sniffer_stop(sniffer: AsyncSniffer):
+    """Like AsyncSniffer.stop except nothing happens if sniffer is None or not running"""
+    if sniffer is not None and sniffer.running:
+        sniffer.stop()
+
+
+class AsyncSnifferContext(AsyncSniffer):  # AsyncSniffer usable in a `with` statement
+    def __enter__(self):
+        scapy_sniffer_start_block(self)  # returns once started, raises on timeout
+        return self
+
+    def __exit__(self, *a):
+        scapy_sniffer_stop(self)  # does nothing when the sniffer is not running
-- 
2.25.1




[Index of Archives]     [Kernel]     [Gnu Classpath]     [Gnu Crypto]     [DM Crypt]     [Netfilter]     [Bugtraq]

  Powered by Linux