diff --git a/daphne/testing.py b/daphne/testing.py index 785edf9d..ab5729e2 100644 --- a/daphne/testing.py +++ b/daphne/testing.py @@ -18,11 +18,21 @@ class BaseDaphneTestingInstance: startup_timeout = 2 def __init__( - self, xff=False, http_timeout=None, request_buffer_size=None, *, application + self, + xff=False, + http_timeout=None, + request_buffer_size=None, + *, + application, + host="127.0.0.1", + unix_socket=None, + file_descriptor=None, ): self.xff = xff self.http_timeout = http_timeout - self.host = "127.0.0.1" + self.host = host + self.unix_socket = unix_socket + self.file_descriptor = file_descriptor self.request_buffer_size = request_buffer_size self.application = application @@ -44,6 +54,8 @@ def __enter__(self): # Start up process self.process = DaphneProcess( host=self.host, + unix_socket=self.unix_socket, + file_descriptor=self.file_descriptor, get_application=self.get_application, kwargs=kwargs, setup=self.process_setup, @@ -126,9 +138,20 @@ class DaphneProcess(multiprocessing.Process): port it ends up listening on back to the parent process. """ - def __init__(self, host, get_application, kwargs=None, setup=None, teardown=None): + def __init__( + self, + get_application, + host=None, + file_descriptor=None, + unix_socket=None, + kwargs=None, + setup=None, + teardown=None, + ): super().__init__() self.host = host + self.file_descriptor = file_descriptor + self.unix_socket = unix_socket self.get_application = get_application self.kwargs = kwargs or {} self.setup = setup @@ -153,12 +176,17 @@ def run(self): try: # Create the server class - endpoints = build_endpoint_description_strings(host=self.host, port=0) + endpoints = build_endpoint_description_strings( + host=self.host, + port=0 if self.host else None, + unix_socket=self.unix_socket, + file_descriptor=self.file_descriptor, + ) self.server = Server( application=application, endpoints=endpoints, signal_handlers=False, - **self.kwargs + **self.kwargs, ) # Set up a poller to look for the port reactor.callLater(0.1, self.resolve_port) @@ -177,11 +205,18 @@ def run(self): def resolve_port(self): from twisted.internet import reactor - if self.server.listening_addresses: - self.port.value = self.server.listening_addresses[0][1] - self.ready.set() + if not all(listener.called for listener in self.server.listeners): + pass + elif self.host: + if self.server.listening_addresses: + self.port.value = self.server.listening_addresses[0][1] + self.ready.set() + return else: - reactor.callLater(0.1, self.resolve_port) + self.port.value = -1 + self.ready.set() + return + reactor.callLater(0.1, self.resolve_port) class TestApplication: diff --git a/setup.cfg b/setup.cfg index cdde7036..c6afb6e4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,9 +24,7 @@ classifiers = Topic :: Internet :: WWW/HTTP [options] -package_dir = - twisted=daphne/twisted -packages = find: +packages = find_namespace: include_package_data = True install_requires = asgiref>=3.5.2,<4 @@ -48,6 +46,11 @@ tests = pytest pytest-asyncio +[options.packages.find] +include= + daphne* + twisted* + [flake8] exclude = venv/*,tox/*,docs/*,testproject/*,js_client/*,.eggs/* extend-ignore = E123, E128, E266, E402, W503, E731, W601 diff --git a/tests/http_base.py b/tests/http_base.py index e5a80c21..8d483b8c 100644 --- a/tests/http_base.py +++ b/tests/http_base.py @@ -17,6 +17,20 @@ class DaphneTestCase(unittest.TestCase): to store/retrieve the request/response messages. """ + _instance_endpoint_args = {} + + @staticmethod + def _get_instance_raw_socket_connection(test_app, *, timeout): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.settimeout(timeout) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.connect((test_app.host, test_app.port)) + return s + + @staticmethod + def _get_instance_http_connection(test_app, *, timeout): + return HTTPConnection(test_app.host, test_app.port, timeout=timeout) + ### Plain HTTP helpers def run_daphne_http( @@ -36,13 +50,15 @@ def run_daphne_http( and response messages. """ with DaphneTestingInstance( - xff=xff, request_buffer_size=request_buffer_size + xff=xff, + request_buffer_size=request_buffer_size, + **self._instance_endpoint_args, ) as test_app: # Add the response messages test_app.add_send_messages(responses) # Send it the request. We have to do this the long way to allow # duplicate headers. - conn = HTTPConnection(test_app.host, test_app.port, timeout=timeout) + conn = self._get_instance_http_connection(test_app, timeout=timeout) if params: path += "?" + parse.urlencode(params, doseq=True) conn.putrequest(method, path, skip_accept_encoding=True, skip_host=True) @@ -74,13 +90,10 @@ def run_daphne_raw(self, data, *, responses=None, timeout=1): Returns what Daphne sends back. """ assert isinstance(data, bytes) - with DaphneTestingInstance() as test_app: + with DaphneTestingInstance(**self._instance_endpoint_args) as test_app: if responses is not None: test_app.add_send_messages(responses) - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.settimeout(timeout) - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - s.connect((test_app.host, test_app.port)) + s = self._get_instance_raw_socket_connection(test_app, timeout=timeout) s.send(data) try: return s.recv(1000000) diff --git a/tests/test_unixsocket.py b/tests/test_unixsocket.py new file mode 100644 index 00000000..46ad4f94 --- /dev/null +++ b/tests/test_unixsocket.py @@ -0,0 +1,56 @@ +import os +import socket +import weakref +from http.client import HTTPConnection +from pathlib import Path +from tempfile import TemporaryDirectory +from unittest import skipUnless + +import test_http_response +from http_base import DaphneTestCase + +__all__ = ["UnixSocketFDDaphneTestCase", "TestInheritedUnixSocket"] + + +class UnixSocketFDDaphneTestCase(DaphneTestCase): + @property + def _instance_endpoint_args(self): + tmp_dir = TemporaryDirectory() + weakref.finalize(self, tmp_dir.cleanup) + sock_path = str(Path(tmp_dir.name, "test.sock")) + listen_sock = socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) + listen_sock.bind(sock_path) + listen_sock.listen() + listen_sock_fileno = os.dup(listen_sock.fileno()) + os.set_inheritable(listen_sock_fileno, True) + listen_sock.close() + return {"host": None, "file_descriptor": listen_sock_fileno} + + @staticmethod + def _get_instance_socket_path(test_app): + with socket.socket(fileno=os.dup(test_app.file_descriptor)) as sock: + return sock.getsockname() + + @classmethod + def _get_instance_raw_socket_connection(cls, test_app, *, timeout): + socket_name = cls._get_instance_socket_path(test_app) + s = socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) + s.settimeout(timeout) + s.connect(socket_name) + return s + + @classmethod + def _get_instance_http_connection(cls, test_app, *, timeout): + def connect(): + conn.sock = cls._get_instance_raw_socket_connection( + test_app, timeout=timeout + ) + + conn = HTTPConnection("", timeout=timeout) + conn.connect = connect + return conn + + +@skipUnless(hasattr(socket, "AF_UNIX"), "AF_UNIX support not present.") +class TestInheritedUnixSocket(UnixSocketFDDaphneTestCase): + test_minimal_response = test_http_response.TestHTTPResponse.test_minimal_response diff --git a/daphne/twisted/plugins/fd_endpoint.py b/twisted/plugins/fd_endpoint.py similarity index 74% rename from daphne/twisted/plugins/fd_endpoint.py rename to twisted/plugins/fd_endpoint.py index 313a3154..ddf3a452 100644 --- a/daphne/twisted/plugins/fd_endpoint.py +++ b/twisted/plugins/fd_endpoint.py @@ -1,3 +1,4 @@ +import os import socket from twisted.internet import endpoints @@ -10,8 +11,13 @@ class _FDParser: prefix = "fd" - def _parseServer(self, reactor, fileno, domain=socket.AF_INET): + def _parseServer(self, reactor, fileno, domain=None): fileno = int(fileno) + if domain: + domain = getattr(socket, f"AF_{domain}") + else: + with socket.socket(fileno=os.dup(fileno)) as sock: + domain = sock.family return endpoints.AdoptedStreamServerEndpoint(reactor, fileno, domain) def parseStreamServer(self, reactor, *args, **kwargs):