Skip to content

Commit

Permalink
Merge commit from fork
Browse files Browse the repository at this point in the history
Fix a race condition on recv_bytes boundary when request is invalid
  • Loading branch information
digitalresistor authored Oct 29, 2024
2 parents fdd2ecf + 810a435 commit e435901
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 15 deletions.
14 changes: 14 additions & 0 deletions docs/arguments.rst
Original file line number Diff line number Diff line change
Expand Up @@ -314,3 +314,17 @@ url_prefix
be stripped of the prefix.

Default: ``''``

channel_request_lookahead
Sets the amount of requests we can continue to read from the socket, while
we are processing current requests. The default value won't allow any
lookahead, increase it above ``0`` to enable.

When enabled this inserts a callable ``waitress.client_disconnected`` into
the environment that allows the task to check if the client disconnected
while waiting for the response at strategic points in the execution and to
cancel the operation.

Default: ``0``

.. versionadded:: 2.0.0
11 changes: 10 additions & 1 deletion src/waitress/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def readable(self):
# 1. We're not already about to close the connection.
# 2. We're not waiting to flush remaining data before closing the
# connection
# 3. There are not too many tasks already queued
# 3. There are not too many tasks already queued (if lookahead is enabled)
# 4. There's no data in the output buffer that needs to be sent
# before we potentially create a new task.

Expand Down Expand Up @@ -196,6 +196,15 @@ def received(self, data):
return False

with self.requests_lock:
# Don't bother processing anymore data if this connection is about
# to close. This may happen if readable() returned True, on the
# main thread before the service thread set the close_when_flushed
# flag, and we read data but our service thread is attempting to
# shut down the connection due to an error. We want to make sure we
# do this while holding the request_lock so that we can't race
if self.will_close or self.close_when_flushed:
return False

while data:
if self.request is None:
self.request = self.parser_class(self.adj)
Expand Down
112 changes: 98 additions & 14 deletions tests/test_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def _makeOneWithMap(self, adj=None):
map = {}
inst = self._makeOne(sock, "127.0.0.1", adj, map=map)
inst.outbuf_lock = DummyLock()
return inst, sock, map
return inst, sock.local(), map

def test_ctor(self):
inst, _, map = self._makeOneWithMap()
Expand Down Expand Up @@ -218,7 +218,7 @@ def test_write_soon_nonempty_byte(self):
def send(_):
return 0

sock.send = send
sock.remote.send = send

wrote = inst.write_soon(b"a")
self.assertEqual(wrote, 1)
Expand All @@ -236,7 +236,7 @@ def test_write_soon_filewrapper(self):
def send(_):
return 0

sock.send = send
sock.remote.send = send

outbufs = inst.outbufs
wrote = inst.write_soon(wrapper)
Expand Down Expand Up @@ -270,7 +270,7 @@ def test_write_soon_rotates_outbuf_on_overflow(self):
def send(_):
return 0

sock.send = send
sock.remote.send = send

inst.adj.outbuf_high_watermark = 3
inst.current_outbuf_count = 4
Expand All @@ -286,7 +286,7 @@ def test_write_soon_waits_on_backpressure(self):
def send(_):
return 0

sock.send = send
sock.remote.send = send

inst.adj.outbuf_high_watermark = 3
inst.total_outbufs_len = 4
Expand Down Expand Up @@ -315,7 +315,7 @@ def send(_):
inst.connected = False
raise Exception()

sock.send = send
sock.remote.send = send

inst.adj.outbuf_high_watermark = 3
inst.total_outbufs_len = 4
Expand Down Expand Up @@ -345,7 +345,7 @@ def send(_):
inst.connected = False
raise Exception()

sock.send = send
sock.remote.send = send

wrote = inst.write_soon(b"xyz")
self.assertEqual(wrote, 3)
Expand Down Expand Up @@ -376,7 +376,7 @@ def test_handle_write_no_notify_after_flush(self):
inst.total_outbufs_len = len(inst.outbufs[0])
inst.adj.send_bytes = 1
inst.adj.outbuf_high_watermark = 2
sock.send = lambda x, do_close=True: False
sock.remote.send = lambda x, do_close=True: False
inst.will_close = False
inst.last_activity = 0
result = inst.handle_write()
Expand All @@ -400,7 +400,7 @@ def test__flush_some_full_outbuf_socket_returns_nonzero(self):

def test__flush_some_full_outbuf_socket_returns_zero(self):
inst, sock, map = self._makeOneWithMap()
sock.send = lambda x: False
sock.remote.send = lambda x: False
inst.outbufs[0].append(b"abc")
inst.total_outbufs_len = sum(len(x) for x in inst.outbufs)
result = inst._flush_some()
Expand Down Expand Up @@ -805,11 +805,12 @@ def app_check_disconnect(self, environ, start_response):
)
return [body]

def _make_app_with_lookahead(self):
def _make_app_with_lookahead(self, recv_bytes=8192):
"""
Setup a channel with lookahead and store it and the socket in self
"""
adj = DummyAdjustments()
adj.recv_bytes = recv_bytes
adj.channel_request_lookahead = 5
channel, sock, map = self._makeOneWithMap(adj=adj)
channel.server.application = self.app_check_disconnect
Expand Down Expand Up @@ -901,13 +902,66 @@ def test_lookahead_continue(self):
self.assertEqual(data.split("\r\n")[-1], "finished")
self.assertEqual(self.request_body, b"x")

def test_lookahead_bad_request_drop_extra_data(self):
"""
Send two requests, the first one being bad, split on the recv_bytes
limit, then emulate a race that could happen whereby we read data from
the socket while the service thread is cleaning up due to an error
processing the request.
"""

invalid_request = [
"GET / HTTP/1.1",
"Host: localhost:8080",
"Content-length: -1",
"",
]

invalid_request_len = len("".join([x + "\r\n" for x in invalid_request]))

second_request = [
"POST / HTTP/1.1",
"Host: localhost:8080",
"Content-Length: 1",
"",
"x",
]

full_request = invalid_request + second_request

self._make_app_with_lookahead(recv_bytes=invalid_request_len)
self._send(*full_request)
self.channel.handle_read()
self.assertEqual(len(self.channel.requests), 1)
self.channel.server.tasks[0].service()
self.assertTrue(self.channel.close_when_flushed)
# Read all of the next request
self.channel.handle_read()
self.channel.handle_read()
# Validate that there is no more data to be read
self.assertEqual(self.sock.remote.local_sent, b"")
# Validate that we dropped the data from the second read, and did not
# create a new request
self.assertEqual(len(self.channel.requests), 0)
data = self.sock.recv(256).decode("ascii")
self.assertFalse(self.channel.readable())
self.assertTrue(self.channel.writable())

# Handle the write, which will close the socket
self.channel.handle_write()
self.assertTrue(self.sock.closed)

data = self.sock.recv(256)
self.assertEqual(len(data), 0)


class DummySock:
blocking = False
closed = False

def __init__(self):
self.sent = b""
self.local_sent = b""
self.remote_sent = b""

def setblocking(self, *arg):
self.blocking = True
Expand All @@ -925,14 +979,44 @@ def close(self):
self.closed = True

def send(self, data):
self.sent += data
self.remote_sent += data
return len(data)

def recv(self, buffer_size):
result = self.sent[:buffer_size]
self.sent = self.sent[buffer_size:]
result = self.local_sent[:buffer_size]
self.local_sent = self.local_sent[buffer_size:]
return result

def local(self):
outer = self

class LocalDummySock:
def send(self, data):
outer.local_sent += data
return len(data)

def recv(self, buffer_size):
result = outer.remote_sent[:buffer_size]
outer.remote_sent = outer.remote_sent[buffer_size:]
return result

def close(self):
outer.closed = True

@property
def sent(self):
return outer.remote_sent

@property
def closed(self):
return outer.closed

@property
def remote(self):
return outer

return LocalDummySock()


class DummyLock:
notified = False
Expand Down

0 comments on commit e435901

Please sign in to comment.