From 273bd567f16b6d5b4af6baf1151eb17240e4badf Mon Sep 17 00:00:00 2001 From: jessesightler-redhat <168553759+jessesightler-redhat@users.noreply.github.com> Date: Wed, 10 Jul 2024 12:40:01 -0400 Subject: [PATCH] Added support for SNI (#755) feat(core): add support for SNI Provide in `check_hostname` in the client too. --------- Co-authored-by: Jeff Widman --- kazoo/client.py | 7 +++++ kazoo/handlers/utils.py | 15 +++++++--- kazoo/protocol/connection.py | 2 ++ kazoo/tests/test_utils.py | 53 ++++++++++++++++++++++++++++++++++++ 4 files changed, 73 insertions(+), 4 deletions(-) diff --git a/kazoo/client.py b/kazoo/client.py index 27b7c384..7ee4e836 100644 --- a/kazoo/client.py +++ b/kazoo/client.py @@ -120,6 +120,7 @@ def __init__( ca=None, use_ssl=False, verify_certs=True, + check_hostname=False, **kwargs, ): """Create a :class:`KazooClient` instance. All time arguments @@ -182,6 +183,8 @@ def __init__( :param use_ssl: argument to control whether SSL is used or not :param verify_certs: when using SSL, argument to bypass certs verification + :param check_hostname: when using SSL, check the hostname + against the hostname in the cert Basic Example: @@ -237,6 +240,7 @@ def __init__( self.use_ssl = use_ssl self.verify_certs = verify_certs + self.check_hostname = check_hostname self.certfile = certfile self.keyfile = keyfile self.keyfile_password = keyfile_password @@ -758,8 +762,10 @@ def command(self, cmd=b"ruok"): raise ConnectionLoss("No connection to server") peer = self._connection._socket.getpeername()[:2] + peer_host = self._connection._socket.getpeername()[1] sock = self.handler.create_connection( peer, + hostname=peer_host, timeout=self._session_timeout / 1000.0, use_ssl=self.use_ssl, ca=self.ca, @@ -767,6 +773,7 @@ def command(self, cmd=b"ruok"): keyfile=self.keyfile, keyfile_password=self.keyfile_password, verify_certs=self.verify_certs, + check_hostname=self.check_hostname, ) sock.sendall(cmd) result = sock.recv(8192) diff --git a/kazoo/handlers/utils.py b/kazoo/handlers/utils.py index f227baa6..206806f6 100644 --- a/kazoo/handlers/utils.py +++ b/kazoo/handlers/utils.py @@ -196,6 +196,7 @@ def create_tcp_socket(module): def create_tcp_connection( module, address, + hostname=None, timeout=None, use_ssl=False, ca=None, @@ -203,6 +204,7 @@ def create_tcp_connection( keyfile=None, keyfile_password=None, verify_certs=True, + check_hostname=False, options=None, ciphers=None, ): @@ -237,11 +239,14 @@ def create_tcp_connection( # Load default CA certs context.load_default_certs(ssl.Purpose.SERVER_AUTH) + if check_hostname and not verify_certs: + raise ValueError( + "verify_certs must be True when" + + " check_hostname is True" + ) # We must set check_hostname to False prior to setting # verify_mode to CERT_NONE. - # TODO: Make hostname verification configurable as some users may - # elect to use it. - context.check_hostname = False + context.check_hostname = check_hostname context.verify_mode = ( ssl.CERT_REQUIRED if verify_certs else ssl.CERT_NONE ) @@ -258,7 +263,9 @@ def create_tcp_connection( addrs = socket.getaddrinfo( address[0], address[1], 0, socket.SOCK_STREAM ) - conn = context.wrap_socket(module.socket(addrs[0][0])) + conn = context.wrap_socket( + module.socket(addrs[0][0]), server_hostname=hostname + ) conn.settimeout(timeout_at) conn.connect(address) sock = conn diff --git a/kazoo/protocol/connection.py b/kazoo/protocol/connection.py index ad4f3b1f..ba30b84e 100644 --- a/kazoo/protocol/connection.py +++ b/kazoo/protocol/connection.py @@ -703,6 +703,7 @@ def _connect(self, host, hostip, port): with self._socket_error_handling(): self._socket = self.handler.create_connection( address=(hostip, port), + hostname=host, timeout=client._session_timeout / 1000.0, use_ssl=self.client.use_ssl, keyfile=self.client.keyfile, @@ -710,6 +711,7 @@ def _connect(self, host, hostip, port): ca=self.client.ca, keyfile_password=self.client.keyfile_password, verify_certs=self.client.verify_certs, + check_hostname=self.client.check_hostname, ) self._socket.setblocking(0) diff --git a/kazoo/tests/test_utils.py b/kazoo/tests/test_utils.py index b56744cb..96d484a8 100644 --- a/kazoo/tests/test_utils.py +++ b/kazoo/tests/test_utils.py @@ -30,6 +30,59 @@ def test_timeout_arg(self): timeout = call_args[0][1] assert timeout >= 0, "socket timeout must be nonnegative" + def test_ssl_server_hostname(self): + from kazoo.handlers import utils + from kazoo.handlers.utils import create_tcp_connection, socket, ssl + + with patch.object(utils, "_set_default_tcpsock_options"): + with patch.object(ssl.SSLContext, "wrap_socket") as wrap_socket: + create_tcp_connection( + socket, + ("127.0.0.1", 2181), + timeout=1.5, + hostname="fakehostname", + use_ssl=True, + ) + + for call_args in wrap_socket.call_args_list: + server_hostname = call_args[1]["server_hostname"] + assert server_hostname == "fakehostname" + + def test_ssl_server_check_hostname(self): + from kazoo.handlers import utils + from kazoo.handlers.utils import create_tcp_connection, socket, ssl + + with patch.object(utils, "_set_default_tcpsock_options"): + with patch.object( + ssl.SSLContext, "wrap_socket", autospec=True + ) as wrap_socket: + create_tcp_connection( + socket, + ("127.0.0.1", 2181), + timeout=1.5, + hostname="fakehostname", + use_ssl=True, + check_hostname=True, + ) + + for call_args in wrap_socket.call_args_list: + ssl_context = call_args[0][0] + assert ssl_context.check_hostname + + def test_ssl_server_check_hostname_config_validation(self): + from kazoo.handlers.utils import create_tcp_connection, socket + + with pytest.raises(ValueError): + create_tcp_connection( + socket, + ("127.0.0.1", 2181), + timeout=1.5, + hostname="fakehostname", + use_ssl=True, + verify_certs=False, + check_hostname=True, + ) + def test_timeout_arg_eventlet(self): if not EVENTLET_HANDLER_AVAILABLE: pytest.skip("eventlet handler not available.")