diff --git a/src/waitress/adjustments.py b/src/waitress/adjustments.py index 28e6df83..91b3548c 100644 --- a/src/waitress/adjustments.py +++ b/src/waitress/adjustments.py @@ -17,7 +17,7 @@ import socket import warnings -from .compat import CPYTHON, HAS_IPV6, LINUX, WIN +from .compat import HAS_IPV6, VSOCK, WIN from .proxy_headers import PROXY_HEADERS truthy = frozenset(("t", "true", "y", "yes", "on", "1")) @@ -81,6 +81,10 @@ def str_iftruthy(s): return str(s) if s else None +def int_iftruthy(s): + return int(s) if s else None + + def as_socket_list(sockets): """Checks if the elements in the list are of type socket and removes them if not.""" @@ -130,7 +134,8 @@ class Adjustments: ("asyncore_use_poll", asbool), ("unix_socket", str), ("unix_socket_perms", asoctal), - ("vsock_socket", str), + ("vsock_socket_cid", int_iftruthy), + ("vsock_socket_port", int_iftruthy), ("sockets", as_socket_list), ("channel_request_lookahead", int), ("server_name", str), @@ -255,8 +260,9 @@ class Adjustments: # Path to a Unix domain socket to use. unix_socket_perms = 0o600 - # Path to a vsock socket to use. - vsock_socket = None + # The CID and port to use for a vsock socket. + vsock_socket_cid = None + vsock_socket_port = None # The socket options to set on receiving a connection. It is a list of # (level, optname, value) tuples. TCP_NODELAY disables the Nagle @@ -306,8 +312,10 @@ def __init__(self, **kw): if "sockets" in kw and "unix_socket" in kw: raise ValueError("unix_socket may not be set if sockets is set") - if "sockets" in kw and "vsock_socket" in kw: - raise ValueError("vsock_socket may not be set if sockets is set") + if "sockets" in kw and ("vsock_socket_cid" in kw or "vsock_socket_port" in kw): + raise ValueError( + "vsock_socket_cid or vsock_socket_port may not be set if sockets is set" + ) if "unix_socket" in kw and ("host" in kw or "port" in kw): raise ValueError("unix_socket may not be set if host or port is set") @@ -315,17 +323,29 @@ def __init__(self, **kw): if "unix_socket" in kw and "listen" in kw: raise ValueError("unix_socket may not be set if listen is set") - if "vsock_socket" in kw and not (LINUX and CPYTHON): - raise ValueError("vsock_socket is not supported on this platform") + if ("vsock_socket_cid" in kw or "vsock_socket_port" in kw) and not VSOCK: + raise ValueError( + "vsock_socket_cid and vsock_socket_port are not supported on this platform" + ) - if "vsock_socket" in kw and ("host" in kw or "port" in kw): - raise ValueError("vsock_socket may not be set if host or port is set") + if ("vsock_socket_cid" in kw or "vsock_socket_port" in kw) and ( + "host" in kw or "port" in kw + ): + raise ValueError( + "vsock_socket_cid or vsock_socket_port may not be set if host or port is set" + ) - if "vsock_socket" in kw and "listen" in kw: - raise ValueError("vsock_socket may not be set if listen is set") + if ("vsock_socket_cid" in kw or "vsock_socket_port" in kw) and "listen" in kw: + raise ValueError( + "vsock_socket_cid or vsock_socket_port may not be set if listen is set" + ) - if "vsock_socket" in kw and "unix_socket" in kw: - raise ValueError("vsock_socket may not be set if unix_socket is set") + if ( + "vsock_socket_cid" in kw or "vsock_socket_port" in kw + ) and "unix_socket" in kw: + raise ValueError( + "vsock_socket_cid or vsock_socket_port may not be set if unix_socket is set" + ) if "send_bytes" in kw: warnings.warn( @@ -372,10 +392,10 @@ def __init__(self, **kw): # Try turning the port into an integer port = int(port) - except Exception: + except Exception as exc: raise ValueError( "Windows does not support service names instead of port numbers" - ) + ) from exc try: if "[" in host and "]" in host: # pragma: nocover @@ -410,20 +430,20 @@ def __init__(self, **kw): wanted_sockets.append((family, socktype, proto, sockaddr)) hp_pairs.append((sockaddr[0].split("%", 1)[0], sockaddr[1])) - except Exception: - raise ValueError("Invalid host/port specified.") + except Exception as exc: + raise ValueError("Invalid host/port specified.") from exc if self.trusted_proxy_count is not None and self.trusted_proxy is None: raise ValueError( - "trusted_proxy_count has no meaning without setting " "trusted_proxy" + "trusted_proxy_count has no meaning without setting trusted_proxy" ) - elif self.trusted_proxy_count is None: + if self.trusted_proxy_count is None: self.trusted_proxy_count = 1 if self.trusted_proxy_headers and self.trusted_proxy is None: raise ValueError( - "trusted_proxy_headers has no meaning without setting " "trusted_proxy" + "trusted_proxy_headers has no meaning without setting trusted_proxy" ) if self.trusted_proxy_headers: @@ -434,9 +454,9 @@ def __init__(self, **kw): unknown_values = self.trusted_proxy_headers - KNOWN_PROXY_HEADERS if unknown_values: raise ValueError( - "Received unknown trusted_proxy_headers value (%s) expected one " - "of %s" - % (", ".join(unknown_values), ", ".join(KNOWN_PROXY_HEADERS)) + "Received unknown trusted_proxy_headers value " + f"({', '.join(unknown_values)}) expected one " + f"of {', '.join(KNOWN_PROXY_HEADERS)}" ) if ( @@ -511,6 +531,7 @@ def check_sockets(cls, sockets): if hasattr(socket, "AF_VSOCK"): supported_families.append(socket.AF_VSOCK) + inet_families = (socket.AF_INET, socket.AF_INET6) family = None for sock in sockets: if sock.type != socket.SOCK_STREAM or sock.family not in supported_families: @@ -519,5 +540,7 @@ def check_sockets(cls, sockets): ) if family is None: family = sock.family + elif family in inet_families and sock.family in inet_families: + pass elif family != sock.family: raise ValueError("All sockets must belong to the same family.") diff --git a/src/waitress/compat.py b/src/waitress/compat.py index fd177fb6..8cb82dfa 100644 --- a/src/waitress/compat.py +++ b/src/waitress/compat.py @@ -8,8 +8,7 @@ # Platform detection. WIN = platform.system() == "Windows" -LINUX = platform.system() == "Linux" -CPYTHON = platform.python_implementation() == "CPython" +VSOCK = hasattr(socket, "AF_VSOCK") MAXINT = sys.maxsize HAS_IPV6 = socket.has_ipv6 diff --git a/src/waitress/server.py b/src/waitress/server.py index 4bdae205..d249ec39 100644 --- a/src/waitress/server.py +++ b/src/waitress/server.py @@ -25,6 +25,7 @@ from waitress.utilities import cleanup_unix_socket from . import wasyncore +from .compat import VSOCK from .proxy_headers import proxy_headers_middleware @@ -68,7 +69,11 @@ def create_server( sockinfo=sockinfo, ) - if adj.vsock_socket and hasattr(socket, "AF_VSOCK"): + if (adj.vsock_socket_cid or adj.vsock_socket_port) and VSOCK: + if not adj.vsock_socket_cid: + adj.vsock_socket_cid = socket.VMADDR_CID_ANY + if not adj.vsock_socket_port: + adj.vsock_socket_port = socket.VMADDR_PORT_ANY sockinfo = (socket.AF_VSOCK, socket.SOCK_STREAM, None, None) return VsockWSGIServer( application, @@ -102,7 +107,7 @@ def create_server( for sock in adj.sockets: sockinfo = (sock.family, sock.type, sock.proto, sock.getsockname()) - if sock.family == socket.AF_INET or sock.family == socket.AF_INET6: + if sock.family in (socket.AF_INET, socket.AF_INET6): last_serv = TcpWSGIServer( application, map, @@ -130,7 +135,7 @@ def create_server( effective_listen.append( (last_serv.effective_host, last_serv.effective_port) ) - elif hasattr(socket, "AF_VSOCK") and sock.family == socket.AF_VSOCK: + elif VSOCK and sock.family == socket.AF_VSOCK: last_serv = VsockWSGIServer( application, map, @@ -402,6 +407,7 @@ def set_socket_options(self, conn): if hasattr(socket, "AF_UNIX"): + class UnixWSGIServer(BaseWSGIServer): def __init__( self, @@ -440,7 +446,9 @@ def getsockname(self): def fix_addr(self, addr): return ("localhost", None) -if hasattr(socket, "AF_VSOCK"): + +if VSOCK: + class VsockWSGIServer(BaseWSGIServer): def __init__( self, @@ -468,7 +476,7 @@ def __init__( ) def bind_server_socket(self): - self.bind(self.adj.vsock_socket) + self.bind((self.adj.vsock_socket_cid, self.adj.vsock_socket_port)) def getsockname(self): return ("vsock", self.socket.getsockname()) diff --git a/tests/test_adjustments.py b/tests/test_adjustments.py index 16d061c2..435109ff 100644 --- a/tests/test_adjustments.py +++ b/tests/test_adjustments.py @@ -1,8 +1,9 @@ +from re import L import socket import unittest import warnings -from waitress.compat import WIN +from waitress.compat import VSOCK, WIN class Test_asbool(unittest.TestCase): @@ -106,35 +107,40 @@ def _makeOne(self, **kw): return Adjustments(**kw) def test_goodvars(self): - inst = self._makeOne( - host="localhost", - port="8080", - threads="5", - trusted_proxy="192.168.1.1", - trusted_proxy_headers={"forwarded"}, - trusted_proxy_count=2, - log_untrusted_proxy_headers=True, - url_scheme="https", - backlog="20", - recv_bytes="200", - send_bytes="300", - outbuf_overflow="400", - inbuf_overflow="500", - connection_limit="1000", - cleanup_interval="1100", - channel_timeout="1200", - log_socket_errors="true", - max_request_header_size="1300", - max_request_body_size="1400", - expose_tracebacks="true", - ident="abc", - asyncore_loop_timeout="5", - asyncore_use_poll=True, - unix_socket_perms="777", - url_prefix="///foo/", - ipv4=True, - ipv6=False, - ) + kw = { + "host": "localhost", + "port": "8080", + "threads": "5", + "trusted_proxy": "192.168.1.1", + "trusted_proxy_headers": {"forwarded"}, + "trusted_proxy_count": 2, + "log_untrusted_proxy_headers": True, + "url_scheme": "https", + "backlog": "20", + "recv_bytes": "200", + "send_bytes": "300", + "outbuf_overflow": "400", + "inbuf_overflow": "500", + "connection_limit": "1000", + "cleanup_interval": "1100", + "channel_timeout": "1200", + "log_socket_errors": "true", + "max_request_header_size": "1300", + "max_request_body_size": 1400, + "expose_tracebacks": "true", + "ident": "abc", + "asyncore_loop_timeout": "5", + "asyncore_use_poll": True, + "unix_socket_perms": "777", + "url_prefix": "///foo/", + "ipv4": True, + "ipv6": False, + } + if VSOCK: + kw["vsock_socket_cid"] = -1 + kw["vsock_socket_port"] = -1 + + inst = self._makeOne(**kw) self.assertEqual(inst.host, "localhost") self.assertEqual(inst.port, 8080) @@ -164,6 +170,10 @@ def test_goodvars(self): self.assertEqual(inst.ipv4, True) self.assertEqual(inst.ipv6, False) + if VSOCK: + self.assertEqual(inst.vsock_socket_cid, -1) + self.assertEqual(inst.vsock_socket_port, -1) + bind_pairs = [ sockaddr[:2] for (family, _, _, sockaddr) in inst.listen @@ -278,7 +288,7 @@ def test_dont_mix_sockets_and_unix_socket(self): sockets[0].close() def test_dont_mix_unix_and_vsock_socket(self): - if not hasattr(socket, "AF_VSOCK"): + if not VSOCK: return sockets = [ socket.socket(socket.AF_UNIX, socket.SOCK_STREAM), @@ -289,7 +299,7 @@ def test_dont_mix_unix_and_vsock_socket(self): sock.close() def test_dont_mix_tcp_and_vsock_socket(self): - if not hasattr(socket, "AF_VSOCK"): + if not VSOCK: return sockets = [ socket.socket(socket.AF_INET, socket.SOCK_STREAM), @@ -508,3 +518,21 @@ def test_dont_mix_internet_and_unix_sockets(self): self.assertRaises(ValueError, self._makeOne, sockets=sockets) sockets[0].close() sockets[1].close() + + +if VSOCK: + + class TestVsockSocket(unittest.TestCase): + def _makeOne(self, **kw): + from waitress.adjustments import Adjustments + + return Adjustments(**kw) + + def test_dont_mix_internet_and_unix_sockets(self): + sockets = [ + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM), + ] + self.assertRaises(ValueError, self._makeOne, sockets=sockets) + sockets[0].close() + sockets[1].close() diff --git a/tests/test_server.py b/tests/test_server.py index 6edc3b24..ca8bff8f 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -2,6 +2,8 @@ import socket import unittest +from waitress.compat import VSOCK + dummy_app = object() @@ -415,6 +417,111 @@ def test_create_with_unix_socket(self): self.assertTrue(isinstance(server[1], UnixWSGIServer)) +if VSOCK: + class TestVsockWSGIServer(unittest.TestCase): + vsock_socket_cid = 2 + vsock_socket_port = -1 + + def _makeOne(self, _start=True, _sock=None): + from waitress.server import create_server + + self.inst = create_server( + dummy_app, + map={}, + _start=_start, + _sock=_sock, + _dispatcher=DummyTaskDispatcher(), + vsock_socket_cid=self.vsock_socket_cid, + vsock_socket_port=self.vsock_socket_port, + ) + return self.inst + + def _makeWithSockets( + self, + application=dummy_app, + _dispatcher=None, + map=None, + _start=True, + _sock=None, + _server=None, + sockets=None, + ): + from waitress.server import create_server + + _sockets = [] + if sockets is not None: + _sockets = sockets + self.inst = create_server( + application, + map=map, + _dispatcher=_dispatcher, + _start=_start, + _sock=_sock, + sockets=_sockets, + ) + return self.inst + + def tearDown(self): + self.inst.close() + + def _makeDummy(self, *args, **kwargs): + sock = DummySock(*args, **kwargs) + sock.family = socket.AF_VSOCK + return sock + + def test_unix(self): + inst = self._makeOne(_start=False) + self.assertEqual(inst.socket.family, socket.AF_VSOCK) + self.assertEqual(inst.socket.getsockname(), self.vsock_socket_cid) + + def test_handle_accept(self): + # Working on the assumption that we only have to test the happy path + # for Unix domain sockets as the other paths should've been covered + # by inet sockets. + client = self._makeDummy() + listen = self._makeDummy(acceptresult=(client, None)) + inst = self._makeOne(_sock=listen) + self.assertEqual(inst.accepting, True) + self.assertEqual(inst.socket.listened, 1024) + L = [] + inst.channel_class = lambda *arg, **kw: L.append(arg) + inst.handle_accept() + self.assertEqual(inst.socket.accepted, True) + self.assertEqual(client.opts, []) + self.assertEqual(L, [(inst, client, ("localhost", None), inst.adj)]) + + def test_creates_new_sockinfo(self): + from waitress.server import VsockWSGIServer + + self.inst = VsockWSGIServer( + dummy_app, + vsock_socket_cid=self.vsock_socket_cid, + vsock_socket_port=self.vsock_socket_port, + ) + + self.assertEqual(self.inst.sockinfo[0], socket.AF_UNIX) + + def test_create_with_unix_socket(self): + from waitress.server import ( + BaseWSGIServer, + MultiSocketServer, + TcpWSGIServer, + VsockWSGIServer, + ) + + sockets = [ + socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM), + socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM), + ] + inst = self._makeWithSockets(sockets=sockets, _start=False) + self.assertTrue(isinstance(inst, MultiSocketServer)) + server = list( + filter(lambda s: isinstance(s, BaseWSGIServer), inst.map.values()) + ) + self.assertTrue(isinstance(server[0], VsockWSGIServer)) + self.assertTrue(isinstance(server[1], VsockWSGIServer)) + + class DummySock(socket.socket): accepted = False blocking = False