[RFCv3 15/15] selftests: tcp_authopt: Add tests for rollover

From: Leonard Crestez
Date: Tue Aug 24 2021 - 17:36:56 EST


RFC5925 requires that the user can examine or control the keys being
used. This is implemented in Linux via fields on the TCP_AUTHOPT
sockopt.

Add socket-level tests for adjusting keyids on live connections and
checking that they are reflected on the peer.

Also check smooth transitions via rnextkeyid.

Signed-off-by: Leonard Crestez <cdleonard@xxxxxxxxx>
---
.../tcp_authopt_test/linux_tcp_authopt.py | 16 +-
.../tcp_authopt_test/test_rollover.py | 181 ++++++++++++++++++
2 files changed, 194 insertions(+), 3 deletions(-)
create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_rollover.py

diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_authopt.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_authopt.py
index 41374f9851aa..23de148a4078 100644
--- a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_authopt.py
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_authopt.py
@@ -20,10 +20,12 @@ def BIT(x):
TCP_AUTHOPT = 38
TCP_AUTHOPT_KEY = 39

TCP_AUTHOPT_MAXKEYLEN = 80

+# Bits for tcp_authopt.flags. The LOCK_* bits are passed together with
+# send_keyid / send_rnextkeyid; presumably they make the kernel use the
+# caller-supplied keyids instead of choosing its own - TODO confirm.
+TCP_AUTHOPT_FLAG_LOCK_KEYID = BIT(0)
+TCP_AUTHOPT_FLAG_LOCK_RNEXTKEYID = BIT(1)
TCP_AUTHOPT_FLAG_REJECT_UNEXPECTED = BIT(2)

TCP_AUTHOPT_KEY_DEL = BIT(0)
TCP_AUTHOPT_KEY_EXCLUDE_OPTS = BIT(1)
TCP_AUTHOPT_KEY_BIND_ADDR = BIT(2)
@@ -35,24 +37,32 @@ TCP_AUTHOPT_ALG_AES_128_CMAC_96 = 2
@dataclass
class tcp_authopt:
"""Like linux struct tcp_authopt"""

flags: int = 0
- sizeof = 4
+ # Four keyid fields, each packed as one unsigned byte after flags.
+ send_keyid: int = 0
+ send_rnextkeyid: int = 0
+ recv_keyid: int = 0
+ recv_rnextkeyid: int = 0
+ # Packed size of "IBBBB" in native mode: 4-byte unsigned int + 4 bytes.
+ # NOTE(review): hand-maintained; struct.calcsize("IBBBB") would keep it
+ # in sync with the format strings used by pack()/unpack().
+ sizeof = 8

def pack(self) -> bytes:
+ # Argument order must match the "IBBBB" format and the field order above.
return struct.pack(
- "I",
+ "IBBBB",
self.flags,
+ self.send_keyid,
+ self.send_rnextkeyid,
+ self.recv_keyid,
+ self.recv_rnextkeyid,
)

def __bytes__(self):
return self.pack()

@classmethod
def unpack(cls, b: bytes):
- tup = struct.unpack("I", b)
+ # struct.unpack raises struct.error unless len(b) equals the format size.
+ tup = struct.unpack("IBBBB", b)
return cls(*tup)


def set_tcp_authopt(sock, opt: tcp_authopt):
+ """Apply opt to sock via setsockopt(IPPROTO_TCP, TCP_AUTHOPT), serialized with tcp_authopt.pack()."""
return sock.setsockopt(socket.IPPROTO_TCP, TCP_AUTHOPT, bytes(opt))
diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_rollover.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_rollover.py
new file mode 100644
index 000000000000..68c59c6d1e33
--- /dev/null
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_rollover.py
@@ -0,0 +1,181 @@
+# SPDX-License-Identifier: GPL-2.0
+import typing
+import socket
+from .server import SimpleServerThread
+from .linux_tcp_authopt import (
+ TCP_AUTHOPT_FLAG_LOCK_KEYID,
+ TCP_AUTHOPT_FLAG_LOCK_RNEXTKEYID,
+ set_tcp_authopt_key,
+ tcp_authopt,
+ tcp_authopt_key,
+ set_tcp_authopt,
+ get_tcp_authopt,
+)
+from .utils import DEFAULT_TCP_SERVER_PORT, create_listen_socket, check_socket_echo
+from contextlib import ExitStack, contextmanager
+from .conftest import skipif_missing_tcp_authopt
+
+# pytest applies this mark to every test in the module; presumably it skips
+# them when TCP_AUTHOPT is unavailable (defined in .conftest) - TODO confirm.
+pytestmark = skipif_missing_tcp_authopt
+
+
+@contextmanager
+def make_tcp_authopt_socket_pair(
+ server_addr="127.0.0.1",
+ server_authopt: tcp_authopt = None,
+ server_key_list: typing.Iterable[tcp_authopt_key] = [],
+ client_authopt: tcp_authopt = None,
+ client_key_list: typing.Iterable[tcp_authopt_key] = [],
+) -> typing.Tuple[socket.socket, socket.socket]:
+ """Make a pair for connected sockets for key switching tests
+
+ Server runs in a background thread implementing echo protocol"""
+ with ExitStack() as exit_stack:
+ listen_socket = exit_stack.enter_context(
+ create_listen_socket(bind_addr=server_addr)
+ )
+ server_thread = exit_stack.enter_context(
+ SimpleServerThread(listen_socket, mode="echo")
+ )
+ client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ client_socket.settimeout(1.0)
+
+ if server_authopt:
+ set_tcp_authopt(listen_socket, server_authopt)
+ for k in server_key_list:
+ set_tcp_authopt_key(listen_socket, k)
+ if client_authopt:
+ set_tcp_authopt(client_socket, client_authopt)
+ for k in client_key_list:
+ set_tcp_authopt_key(client_socket, k)
+
+ client_socket.connect((server_addr, DEFAULT_TCP_SERVER_PORT))
+ check_socket_echo(client_socket)
+ server_socket = server_thread.server_socket[0]
+
+ yield client_socket, server_socket
+
+
+def test_get_keyids(exit_stack: ExitStack):
+ """Check reading key ids"""
+ sk1 = tcp_authopt_key(send_id=11, recv_id=12, key="111")
+ sk2 = tcp_authopt_key(send_id=21, recv_id=22, key="222")
+ ck1 = tcp_authopt_key(send_id=12, recv_id=11, key="111")
+ client_socket, server_socket = exit_stack.enter_context(
+ make_tcp_authopt_socket_pair(
+ server_key_list=[sk1, sk2],
+ client_key_list=[ck1],
+ )
+ )
+
+ check_socket_echo(client_socket)
+ client_tcp_authopt = get_tcp_authopt(client_socket)
+ server_tcp_authopt = get_tcp_authopt(server_socket)
+ assert server_tcp_authopt.send_keyid == 11
+ assert server_tcp_authopt.send_rnextkeyid == 12
+ assert server_tcp_authopt.recv_keyid == 12
+ assert server_tcp_authopt.recv_rnextkeyid == 11
+ assert client_tcp_authopt.send_keyid == 12
+ assert client_tcp_authopt.send_rnextkeyid == 11
+ assert client_tcp_authopt.recv_keyid == 11
+ assert client_tcp_authopt.recv_rnextkeyid == 12
+
+
+def test_rollover_send_keyid(exit_stack: ExitStack):
+ """Check reading key ids"""
+ sk1 = tcp_authopt_key(send_id=11, recv_id=12, key="111")
+ sk2 = tcp_authopt_key(send_id=21, recv_id=22, key="222")
+ ck1 = tcp_authopt_key(send_id=12, recv_id=11, key="111")
+ ck2 = tcp_authopt_key(send_id=22, recv_id=21, key="222")
+ client_socket, server_socket = exit_stack.enter_context(
+ make_tcp_authopt_socket_pair(
+ server_key_list=[sk1, sk2],
+ client_key_list=[ck1, ck2],
+ client_authopt=tcp_authopt(
+ send_keyid=12, flags=TCP_AUTHOPT_FLAG_LOCK_KEYID
+ ),
+ )
+ )
+
+ check_socket_echo(client_socket)
+ assert get_tcp_authopt(client_socket).recv_keyid == 11
+ assert get_tcp_authopt(server_socket).recv_keyid == 12
+
+ # Explicit request for key2
+ set_tcp_authopt(
+ client_socket, tcp_authopt(send_keyid=22, flags=TCP_AUTHOPT_FLAG_LOCK_KEYID)
+ )
+ check_socket_echo(client_socket)
+ assert get_tcp_authopt(client_socket).recv_keyid == 21
+ assert get_tcp_authopt(server_socket).recv_keyid == 22
+
+
+def test_rollover_rnextkeyid(exit_stack: ExitStack):
+ """Check reading key ids"""
+ sk1 = tcp_authopt_key(send_id=11, recv_id=12, key="111")
+ sk2 = tcp_authopt_key(send_id=21, recv_id=22, key="222")
+ ck1 = tcp_authopt_key(send_id=12, recv_id=11, key="111")
+ ck2 = tcp_authopt_key(send_id=22, recv_id=21, key="222")
+ client_socket, server_socket = exit_stack.enter_context(
+ make_tcp_authopt_socket_pair(
+ server_key_list=[sk1],
+ client_key_list=[ck1, ck2],
+ client_authopt=tcp_authopt(
+ send_keyid=12, flags=TCP_AUTHOPT_FLAG_LOCK_KEYID
+ ),
+ )
+ )
+
+ check_socket_echo(client_socket)
+ assert get_tcp_authopt(server_socket).recv_rnextkeyid == 11
+
+ # request rnextkeyd=22 but server does not have it
+ set_tcp_authopt(
+ client_socket,
+ tcp_authopt(send_rnextkeyid=21, flags=TCP_AUTHOPT_FLAG_LOCK_RNEXTKEYID),
+ )
+ check_socket_echo(client_socket)
+ check_socket_echo(client_socket)
+ assert get_tcp_authopt(server_socket).recv_rnextkeyid == 21
+ assert get_tcp_authopt(server_socket).send_keyid == 11
+
+ # after adding k2 on server the key is switched
+ set_tcp_authopt_key(server_socket, sk2)
+ check_socket_echo(client_socket)
+ check_socket_echo(client_socket)
+ assert get_tcp_authopt(server_socket).send_keyid == 21
+
+
+def test_rollover_delkey(exit_stack: ExitStack):
+ sk1 = tcp_authopt_key(send_id=11, recv_id=12, key="111")
+ sk2 = tcp_authopt_key(send_id=21, recv_id=22, key="222")
+ ck1 = tcp_authopt_key(send_id=12, recv_id=11, key="111")
+ ck2 = tcp_authopt_key(send_id=22, recv_id=21, key="222")
+ client_socket, server_socket = exit_stack.enter_context(
+ make_tcp_authopt_socket_pair(
+ server_key_list=[sk1, sk2],
+ client_key_list=[ck1, ck2],
+ client_authopt=tcp_authopt(
+ send_keyid=12, flags=TCP_AUTHOPT_FLAG_LOCK_KEYID
+ ),
+ )
+ )
+
+ check_socket_echo(client_socket)
+ assert get_tcp_authopt(server_socket).recv_keyid == 12
+
+ # invalid send_keyid is just ignored
+ set_tcp_authopt(client_socket, tcp_authopt(send_keyid=7))
+ check_socket_echo(client_socket)
+ assert get_tcp_authopt(client_socket).send_keyid == 12
+ assert get_tcp_authopt(server_socket).recv_keyid == 12
+ assert get_tcp_authopt(client_socket).recv_keyid == 11
+
+ # If a key is removed it is replaced by anything that matches
+ ck1.delete_flag = True
+ set_tcp_authopt_key(client_socket, ck1)
+ check_socket_echo(client_socket)
+ check_socket_echo(client_socket)
+ assert get_tcp_authopt(client_socket).send_keyid == 22
+ assert get_tcp_authopt(server_socket).send_keyid == 21
+ assert get_tcp_authopt(server_socket).recv_keyid == 22
+ assert get_tcp_authopt(client_socket).recv_keyid == 21
--
2.25.1