Skip to content

Commit

Permalink
VsockWSGIServer tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ananthb committed Apr 7, 2023
1 parent 66b8600 commit 19f58f0
Show file tree
Hide file tree
Showing 5 changed files with 228 additions and 63 deletions.
71 changes: 47 additions & 24 deletions src/waitress/adjustments.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import socket
import warnings

from .compat import CPYTHON, HAS_IPV6, LINUX, WIN
from .compat import HAS_IPV6, VSOCK, WIN
from .proxy_headers import PROXY_HEADERS

truthy = frozenset(("t", "true", "y", "yes", "on", "1"))
Expand Down Expand Up @@ -81,6 +81,10 @@ def str_iftruthy(s):
return str(s) if s else None


def int_iftruthy(s):
return int(s) if s else None


def as_socket_list(sockets):
"""Checks if the elements in the list are of type socket and
removes them if not."""
Expand Down Expand Up @@ -130,7 +134,8 @@ class Adjustments:
("asyncore_use_poll", asbool),
("unix_socket", str),
("unix_socket_perms", asoctal),
("vsock_socket", str),
("vsock_socket_cid", int_iftruthy),
("vsock_socket_port", int_iftruthy),
("sockets", as_socket_list),
("channel_request_lookahead", int),
("server_name", str),
Expand Down Expand Up @@ -255,8 +260,9 @@ class Adjustments:
# Path to a Unix domain socket to use.
unix_socket_perms = 0o600

# Path to a vsock socket to use.
vsock_socket = None
# The CID and port to use for a vsock socket.
vsock_socket_cid = None
vsock_socket_port = None

# The socket options to set on receiving a connection. It is a list of
# (level, optname, value) tuples. TCP_NODELAY disables the Nagle
Expand Down Expand Up @@ -306,26 +312,40 @@ def __init__(self, **kw):
if "sockets" in kw and "unix_socket" in kw:
raise ValueError("unix_socket may not be set if sockets is set")

if "sockets" in kw and "vsock_socket" in kw:
raise ValueError("vsock_socket may not be set if sockets is set")
if "sockets" in kw and ("vsock_socket_cid" in kw or "vsock_socket_port" in kw):
raise ValueError(
"vsock_socket_cid or vsock_socket_port may not be set if sockets is set"
)

if "unix_socket" in kw and ("host" in kw or "port" in kw):
raise ValueError("unix_socket may not be set if host or port is set")

if "unix_socket" in kw and "listen" in kw:
raise ValueError("unix_socket may not be set if listen is set")

if "vsock_socket" in kw and not (LINUX and CPYTHON):
raise ValueError("vsock_socket is not supported on this platform")
if ("vsock_socket_cid" in kw or "vsock_socket_port" in kw) and not VSOCK:
raise ValueError(
"vsock_socket_cid and vsock_socket_port are not supported on this platform"
)

if "vsock_socket" in kw and ("host" in kw or "port" in kw):
raise ValueError("vsock_socket may not be set if host or port is set")
if ("vsock_socket_cid" in kw or "vsock_socket_port" in kw) and (
"host" in kw or "port" in kw
):
raise ValueError(
"vsock_socket_cid or vsock_socket_port may not be set if host or port is set"
)

if "vsock_socket" in kw and "listen" in kw:
raise ValueError("vsock_socket may not be set if listen is set")
if ("vsock_socket_cid" in kw or "vsock_socket_port" in kw) and "listen" in kw:
raise ValueError(
"vsock_socket_cid or vsock_socket_port may not be set if listen is set"
)

if "vsock_socket" in kw and "unix_socket" in kw:
raise ValueError("vsock_socket may not be set if unix_socket is set")
if (
"vsock_socket_cid" in kw or "vsock_socket_port" in kw
) and "unix_socket" in kw:
raise ValueError(
"vsock_socket_cid or vsock_socket_port may not be set if unix_socket is set"
)

if "send_bytes" in kw:
warnings.warn(
Expand Down Expand Up @@ -372,10 +392,10 @@ def __init__(self, **kw):
# Try turning the port into an integer
port = int(port)

except Exception:
except Exception as exc:
raise ValueError(
"Windows does not support service names instead of port numbers"
)
) from exc

try:
if "[" in host and "]" in host: # pragma: nocover
Expand Down Expand Up @@ -410,20 +430,20 @@ def __init__(self, **kw):
wanted_sockets.append((family, socktype, proto, sockaddr))
hp_pairs.append((sockaddr[0].split("%", 1)[0], sockaddr[1]))

except Exception:
raise ValueError("Invalid host/port specified.")
except Exception as exc:
raise ValueError("Invalid host/port specified.") from exc

if self.trusted_proxy_count is not None and self.trusted_proxy is None:
raise ValueError(
"trusted_proxy_count has no meaning without setting " "trusted_proxy"
"trusted_proxy_count has no meaning without setting trusted_proxy"
)

elif self.trusted_proxy_count is None:
if self.trusted_proxy_count is None:
self.trusted_proxy_count = 1

if self.trusted_proxy_headers and self.trusted_proxy is None:
raise ValueError(
"trusted_proxy_headers has no meaning without setting " "trusted_proxy"
"trusted_proxy_headers has no meaning without setting trusted_proxy"
)

if self.trusted_proxy_headers:
Expand All @@ -434,9 +454,9 @@ def __init__(self, **kw):
unknown_values = self.trusted_proxy_headers - KNOWN_PROXY_HEADERS
if unknown_values:
raise ValueError(
"Received unknown trusted_proxy_headers value (%s) expected one "
"of %s"
% (", ".join(unknown_values), ", ".join(KNOWN_PROXY_HEADERS))
"Received unknown trusted_proxy_headers value "
f"({', '.join(unknown_values)}) expected one "
f"of {', '.join(KNOWN_PROXY_HEADERS)}"
)

if (
Expand Down Expand Up @@ -511,6 +531,7 @@ def check_sockets(cls, sockets):
if hasattr(socket, "AF_VSOCK"):
supported_families.append(socket.AF_VSOCK)

inet_families = (socket.AF_INET, socket.AF_INET6)
family = None
for sock in sockets:
if sock.type != socket.SOCK_STREAM or sock.family not in supported_families:
Expand All @@ -519,5 +540,7 @@ def check_sockets(cls, sockets):
)
if family is None:
family = sock.family
elif family in inet_families and sock.family in inet_families:
pass
elif family != sock.family:
raise ValueError("All sockets must belong to the same family.")
3 changes: 1 addition & 2 deletions src/waitress/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@

# Platform detection.
WIN = platform.system() == "Windows"
LINUX = platform.system() == "Linux"
CPYTHON = platform.python_implementation() == "CPython"
VSOCK = hasattr(socket, "AF_VSOCK")

MAXINT = sys.maxsize
HAS_IPV6 = socket.has_ipv6
Expand Down
18 changes: 13 additions & 5 deletions src/waitress/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from waitress.utilities import cleanup_unix_socket

from . import wasyncore
from .compat import VSOCK
from .proxy_headers import proxy_headers_middleware


Expand Down Expand Up @@ -68,7 +69,11 @@ def create_server(
sockinfo=sockinfo,
)

if adj.vsock_socket and hasattr(socket, "AF_VSOCK"):
if (adj.vsock_socket_cid or adj.vsock_socket_port) and VSOCK:
if not adj.vsock_socket_cid:
adj.vsock_socket_cid = socket.VMADDR_CID_ANY
if not adj.vsock_socket_port:
adj.vsock_socket_port = socket.VMADDR_PORT_ANY
sockinfo = (socket.AF_VSOCK, socket.SOCK_STREAM, None, None)
return VsockWSGIServer(
application,
Expand Down Expand Up @@ -102,7 +107,7 @@ def create_server(

for sock in adj.sockets:
sockinfo = (sock.family, sock.type, sock.proto, sock.getsockname())
if sock.family == socket.AF_INET or sock.family == socket.AF_INET6:
if sock.family in (socket.AF_INET, socket.AF_INET6):
last_serv = TcpWSGIServer(
application,
map,
Expand Down Expand Up @@ -130,7 +135,7 @@ def create_server(
effective_listen.append(
(last_serv.effective_host, last_serv.effective_port)
)
elif hasattr(socket, "AF_VSOCK") and sock.family == socket.AF_VSOCK:
elif VSOCK and sock.family == socket.AF_VSOCK:
last_serv = VsockWSGIServer(
application,
map,
Expand Down Expand Up @@ -402,6 +407,7 @@ def set_socket_options(self, conn):


if hasattr(socket, "AF_UNIX"):

class UnixWSGIServer(BaseWSGIServer):
def __init__(
self,
Expand Down Expand Up @@ -440,7 +446,9 @@ def getsockname(self):
def fix_addr(self, addr):
return ("localhost", None)

if hasattr(socket, "AF_VSOCK"):

if VSOCK:

class VsockWSGIServer(BaseWSGIServer):
def __init__(
self,
Expand Down Expand Up @@ -468,7 +476,7 @@ def __init__(
)

def bind_server_socket(self):
self.bind(self.adj.vsock_socket)
self.bind((self.adj.vsock_socket_cid, self.adj.vsock_socket_port))

def getsockname(self):
return ("vsock", self.socket.getsockname())
Expand Down
92 changes: 60 additions & 32 deletions tests/test_adjustments.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from re import L
import socket
import unittest
import warnings

from waitress.compat import WIN
from waitress.compat import VSOCK, WIN


class Test_asbool(unittest.TestCase):
Expand Down Expand Up @@ -106,35 +107,40 @@ def _makeOne(self, **kw):
return Adjustments(**kw)

def test_goodvars(self):
inst = self._makeOne(
host="localhost",
port="8080",
threads="5",
trusted_proxy="192.168.1.1",
trusted_proxy_headers={"forwarded"},
trusted_proxy_count=2,
log_untrusted_proxy_headers=True,
url_scheme="https",
backlog="20",
recv_bytes="200",
send_bytes="300",
outbuf_overflow="400",
inbuf_overflow="500",
connection_limit="1000",
cleanup_interval="1100",
channel_timeout="1200",
log_socket_errors="true",
max_request_header_size="1300",
max_request_body_size="1400",
expose_tracebacks="true",
ident="abc",
asyncore_loop_timeout="5",
asyncore_use_poll=True,
unix_socket_perms="777",
url_prefix="///foo/",
ipv4=True,
ipv6=False,
)
kw = {
"host": "localhost",
"port": "8080",
"threads": "5",
"trusted_proxy": "192.168.1.1",
"trusted_proxy_headers": {"forwarded"},
"trusted_proxy_count": 2,
"log_untrusted_proxy_headers": True,
"url_scheme": "https",
"backlog": "20",
"recv_bytes": "200",
"send_bytes": "300",
"outbuf_overflow": "400",
"inbuf_overflow": "500",
"connection_limit": "1000",
"cleanup_interval": "1100",
"channel_timeout": "1200",
"log_socket_errors": "true",
"max_request_header_size": "1300",
"max_request_body_size": 1400,
"expose_tracebacks": "true",
"ident": "abc",
"asyncore_loop_timeout": "5",
"asyncore_use_poll": True,
"unix_socket_perms": "777",
"url_prefix": "///foo/",
"ipv4": True,
"ipv6": False,
}
if VSOCK:
kw["vsock_socket_cid"] = -1
kw["vsock_socket_port"] = -1

inst = self._makeOne(**kw)

self.assertEqual(inst.host, "localhost")
self.assertEqual(inst.port, 8080)
Expand Down Expand Up @@ -164,6 +170,10 @@ def test_goodvars(self):
self.assertEqual(inst.ipv4, True)
self.assertEqual(inst.ipv6, False)

if VSOCK:
self.assertEqual(inst.vsock_socket_cid, -1)
self.assertEqual(inst.vsock_socket_port, -1)

bind_pairs = [
sockaddr[:2]
for (family, _, _, sockaddr) in inst.listen
Expand Down Expand Up @@ -278,7 +288,7 @@ def test_dont_mix_sockets_and_unix_socket(self):
sockets[0].close()

def test_dont_mix_unix_and_vsock_socket(self):
if not hasattr(socket, "AF_VSOCK"):
if not VSOCK:
return
sockets = [
socket.socket(socket.AF_UNIX, socket.SOCK_STREAM),
Expand All @@ -289,7 +299,7 @@ def test_dont_mix_unix_and_vsock_socket(self):
sock.close()

def test_dont_mix_tcp_and_vsock_socket(self):
if not hasattr(socket, "AF_VSOCK"):
if not VSOCK:
return
sockets = [
socket.socket(socket.AF_INET, socket.SOCK_STREAM),
Expand Down Expand Up @@ -508,3 +518,21 @@ def test_dont_mix_internet_and_unix_sockets(self):
self.assertRaises(ValueError, self._makeOne, sockets=sockets)
sockets[0].close()
sockets[1].close()


if VSOCK:

class TestVsockSocket(unittest.TestCase):
def _makeOne(self, **kw):
from waitress.adjustments import Adjustments

return Adjustments(**kw)

def test_dont_mix_internet_and_unix_sockets(self):
sockets = [
socket.socket(socket.AF_INET, socket.SOCK_STREAM),
socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM),
]
self.assertRaises(ValueError, self._makeOne, sockets=sockets)
sockets[0].close()
sockets[1].close()
Loading

0 comments on commit 19f58f0

Please sign in to comment.