diff --git a/docs/arguments.rst b/docs/arguments.rst index 0b6ca458..b8a856aa 100644 --- a/docs/arguments.rst +++ b/docs/arguments.rst @@ -314,3 +314,17 @@ url_prefix be stripped of the prefix. Default: ``''`` + +channel_request_lookahead + Sets the amount of requests we can continue to read from the socket, while + we are processing current requests. The default value won't allow any + lookahead, increase it above ``0`` to enable. + + When enabled this inserts a callable ``waitress.client_disconnected`` into + the environment that allows the task to check if the client disconnected + while waiting for the response at strategic points in the execution and to + cancel the operation. + + Default: ``0`` + + .. versionadded:: 2.0.0 diff --git a/src/waitress/channel.py b/src/waitress/channel.py index 3860ed51..f4d96776 100644 --- a/src/waitress/channel.py +++ b/src/waitress/channel.py @@ -140,7 +140,7 @@ def readable(self): # 1. We're not already about to close the connection. # 2. We're not waiting to flush remaining data before closing the # connection - # 3. There are not too many tasks already queued + # 3. There are not too many tasks already queued (if lookahead is enabled) # 4. There's no data in the output buffer that needs to be sent # before we potentially create a new task. @@ -196,6 +196,15 @@ def received(self, data): return False with self.requests_lock: + # Don't bother processing anymore data if this connection is about + # to close. This may happen if readable() returned True, on the + # main thread before the service thread set the close_when_flushed + # flag, and we read data but our service thread is attempting to + # shut down the connection due to an error. We want to make sure we + # do this while holding the request_lock so that we can't race + if self.will_close or self.close_when_flushed: + return False + while data: if self.request is None: self.request = self.parser_class(self.adj) diff --git a/tests/test_channel.py b/tests/test_channel.py index 8467ae7a..d798091d 100644 --- a/tests/test_channel.py +++ b/tests/test_channel.py @@ -18,7 +18,7 @@ def _makeOneWithMap(self, adj=None): map = {} inst = self._makeOne(sock, "127.0.0.1", adj, map=map) inst.outbuf_lock = DummyLock() - return inst, sock, map + return inst, sock.local(), map def test_ctor(self): inst, _, map = self._makeOneWithMap() @@ -218,7 +218,7 @@ def test_write_soon_nonempty_byte(self): def send(_): return 0 - sock.send = send + sock.remote.send = send wrote = inst.write_soon(b"a") self.assertEqual(wrote, 1) @@ -236,7 +236,7 @@ def test_write_soon_filewrapper(self): def send(_): return 0 - sock.send = send + sock.remote.send = send outbufs = inst.outbufs wrote = inst.write_soon(wrapper) @@ -270,7 +270,7 @@ def test_write_soon_rotates_outbuf_on_overflow(self): def send(_): return 0 - sock.send = send + sock.remote.send = send inst.adj.outbuf_high_watermark = 3 inst.current_outbuf_count = 4 @@ -286,7 +286,7 @@ def test_write_soon_waits_on_backpressure(self): def send(_): return 0 - sock.send = send + sock.remote.send = send inst.adj.outbuf_high_watermark = 3 inst.total_outbufs_len = 4 @@ -315,7 +315,7 @@ def send(_): inst.connected = False raise Exception() - sock.send = send + sock.remote.send = send inst.adj.outbuf_high_watermark = 3 inst.total_outbufs_len = 4 @@ -345,7 +345,7 @@ def send(_): inst.connected = False raise Exception() - sock.send = send + sock.remote.send = send wrote = inst.write_soon(b"xyz") self.assertEqual(wrote, 3) @@ -376,7 +376,7 @@ def test_handle_write_no_notify_after_flush(self): inst.total_outbufs_len = len(inst.outbufs[0]) inst.adj.send_bytes = 1 inst.adj.outbuf_high_watermark = 2 - sock.send = lambda x, do_close=True: False + sock.remote.send = lambda x, do_close=True: False inst.will_close = False inst.last_activity = 0 result = inst.handle_write() @@ -400,7 +400,7 @@ def test__flush_some_full_outbuf_socket_returns_nonzero(self): def test__flush_some_full_outbuf_socket_returns_zero(self): inst, sock, map = self._makeOneWithMap() - sock.send = lambda x: False + sock.remote.send = lambda x: False inst.outbufs[0].append(b"abc") inst.total_outbufs_len = sum(len(x) for x in inst.outbufs) result = inst._flush_some() @@ -805,11 +805,12 @@ def app_check_disconnect(self, environ, start_response): ) return [body] - def _make_app_with_lookahead(self): + def _make_app_with_lookahead(self, recv_bytes=8192): """ Setup a channel with lookahead and store it and the socket in self """ adj = DummyAdjustments() + adj.recv_bytes = recv_bytes adj.channel_request_lookahead = 5 channel, sock, map = self._makeOneWithMap(adj=adj) channel.server.application = self.app_check_disconnect @@ -901,13 +902,66 @@ def test_lookahead_continue(self): self.assertEqual(data.split("\r\n")[-1], "finished") self.assertEqual(self.request_body, b"x") + def test_lookahead_bad_request_drop_extra_data(self): + """ + Send two requests, the first one being bad, split on the recv_bytes + limit, then emulate a race that could happen whereby we read data from + the socket while the service thread is cleaning up due to an error + processing the request. + """ + + invalid_request = [ + "GET / HTTP/1.1", + "Host: localhost:8080", + "Content-length: -1", + "", + ] + + invalid_request_len = len("".join([x + "\r\n" for x in invalid_request])) + + second_request = [ + "POST / HTTP/1.1", + "Host: localhost:8080", + "Content-Length: 1", + "", + "x", + ] + + full_request = invalid_request + second_request + + self._make_app_with_lookahead(recv_bytes=invalid_request_len) + self._send(*full_request) + self.channel.handle_read() + self.assertEqual(len(self.channel.requests), 1) + self.channel.server.tasks[0].service() + self.assertTrue(self.channel.close_when_flushed) + # Read all of the next request + self.channel.handle_read() + self.channel.handle_read() + # Validate that there is no more data to be read + self.assertEqual(self.sock.remote.local_sent, b"") + # Validate that we dropped the data from the second read, and did not + # create a new request + self.assertEqual(len(self.channel.requests), 0) + data = self.sock.recv(256).decode("ascii") + self.assertFalse(self.channel.readable()) + self.assertTrue(self.channel.writable()) + + # Handle the write, which will close the socket + self.channel.handle_write() + self.assertTrue(self.sock.closed) + + data = self.sock.recv(256) + self.assertEqual(len(data), 0) + class DummySock: blocking = False closed = False def __init__(self): - self.sent = b"" + self.local_sent = b"" + self.remote_sent = b"" def setblocking(self, *arg): self.blocking = True @@ -925,14 +979,44 @@ def close(self): self.closed = True def send(self, data): - self.sent += data + self.remote_sent += data return len(data) def recv(self, buffer_size): - result = self.sent[:buffer_size] - self.sent = self.sent[buffer_size:] + result = self.local_sent[:buffer_size] + self.local_sent = self.local_sent[buffer_size:] return result + def local(self): + outer = self + + class LocalDummySock: + def send(self, data): + outer.local_sent += data + return len(data) + + def recv(self, buffer_size): + result = outer.remote_sent[:buffer_size] + outer.remote_sent = outer.remote_sent[buffer_size:] + return result + + def close(self): + outer.closed = True + + @property + def sent(self): + return outer.remote_sent + + @property + def closed(self): + return outer.closed + + @property + def remote(self): + return outer + + return LocalDummySock() + class DummyLock: notified = False