Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flush data from the application thread #364

Merged
merged 1 commit into from
Jan 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 38 additions & 8 deletions src/waitress/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def check_client_disconnected(self):
may occasionally check if the client has disconnected and interrupt
execution.
"""

return not self.connected

def writable(self):
Expand Down Expand Up @@ -116,23 +117,30 @@ def handle_write(self):
# right now.
flush = None

self._flush_exception(flush)

if self.close_when_flushed and not self.total_outbufs_len:
self.close_when_flushed = False
self.will_close = True

if self.will_close:
self.handle_close()

def _flush_exception(self, flush):
if flush:
try:
flush()
return (flush(), False)
except OSError:
if self.adj.log_socket_errors:
self.logger.exception("Socket error")
self.will_close = True

return (False, True)
except Exception: # pragma: nocover
self.logger.exception("Unexpected exception when flushing")
self.will_close = True

if self.close_when_flushed and not self.total_outbufs_len:
self.close_when_flushed = False
self.will_close = True

if self.will_close:
self.handle_close()
return (False, True)

def readable(self):
# We might want to read more requests. We can only do this if:
Expand Down Expand Up @@ -190,6 +198,7 @@ def received(self, data):
Receives input asynchronously and assigns one or more requests to the
channel.
"""

if not data:
return False

Expand All @@ -201,6 +210,7 @@ def received(self, data):

# if there are requests queued, we can not send the continue
# header yet since the responses need to be kept in order

if (
self.request.expect_continue
and self.request.headers_finished
Expand All @@ -215,6 +225,7 @@ def received(self, data):

if not self.request.empty:
self.requests.append(self.request)

if len(self.requests) == 1:
# self.requests was empty before so the main thread
# is in charge of starting the task. Otherwise,
Expand Down Expand Up @@ -363,7 +374,14 @@ def write_soon(self, data):
self.total_outbufs_len += num_bytes

if self.total_outbufs_len >= self.adj.send_bytes:
self.server.pull_trigger()
(flushed, exception) = self._flush_exception(self._flush_some)

if (
exception
or not flushed
or self.total_outbufs_len >= self.adj.send_bytes
):
self.server.pull_trigger()

return num_bytes

Expand All @@ -374,6 +392,17 @@ def _flush_outbufs_below_high_watermark(self):

if self.total_outbufs_len > self.adj.outbuf_high_watermark:
with self.outbuf_lock:
(_, exception) = self._flush_exception(self._flush_some)

if exception:
# An exception happened while flushing, wake up the main
# thread, then wait for it to decide what to do next
# (probably close the socket, and then just return)
self.server.pull_trigger()
self.outbuf_lock.wait()

return

while (
self.connected
and self.total_outbufs_len > self.adj.outbuf_high_watermark
Expand Down Expand Up @@ -460,6 +489,7 @@ def service(self):
# Add new task to process the next request
with self.requests_lock:
self.requests.pop(0)

if self.connected and self.requests:
self.server.add_task(self)
elif (
Expand Down
95 changes: 84 additions & 11 deletions tests/test_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,13 @@ def test_write_soon_empty_byte(self):

def test_write_soon_nonempty_byte(self):
inst, sock, map = self._makeOneWithMap()

# _flush_some will no longer flush
def send(_):
return 0

sock.send = send

wrote = inst.write_soon(b"a")
self.assertEqual(wrote, 1)
self.assertEqual(len(inst.outbufs[0]), 1)
Expand All @@ -224,14 +231,19 @@ def test_write_soon_filewrapper(self):
wrapper = ReadOnlyFileBasedBuffer(f, 8192)
wrapper.prepare()
inst, sock, map = self._makeOneWithMap()

# _flush_some will no longer flush
def send(_):
return 0

sock.send = send

outbufs = inst.outbufs
orig_outbuf = outbufs[0]
wrote = inst.write_soon(wrapper)
self.assertEqual(wrote, 3)
self.assertEqual(len(outbufs), 3)
self.assertEqual(outbufs[0], orig_outbuf)
self.assertEqual(outbufs[1], wrapper)
self.assertEqual(outbufs[2].__class__.__name__, "OverflowableBuffer")
self.assertEqual(len(outbufs), 2)
self.assertEqual(outbufs[0], wrapper)
self.assertEqual(outbufs[1].__class__.__name__, "OverflowableBuffer")

def test_write_soon_disconnected(self):
from waitress.channel import ClientDisconnected
Expand All @@ -253,16 +265,29 @@ def dummy_flush():

def test_write_soon_rotates_outbuf_on_overflow(self):
inst, sock, map = self._makeOneWithMap()

# _flush_some will no longer flush
def send(_):
return 0

sock.send = send

inst.adj.outbuf_high_watermark = 3
inst.current_outbuf_count = 4
wrote = inst.write_soon(b"xyz")
self.assertEqual(wrote, 3)
self.assertEqual(len(inst.outbufs), 2)
self.assertEqual(inst.outbufs[0].get(), b"")
self.assertEqual(inst.outbufs[1].get(), b"xyz")
self.assertEqual(len(inst.outbufs), 1)
self.assertEqual(inst.outbufs[0].get(), b"xyz")

def test_write_soon_waits_on_backpressure(self):
inst, sock, map = self._makeOneWithMap()

# _flush_some will no longer flush
def send(_):
return 0

sock.send = send

inst.adj.outbuf_high_watermark = 3
inst.total_outbufs_len = 4
inst.current_outbuf_count = 4
Expand All @@ -275,11 +300,59 @@ def wait(self):
inst.outbuf_lock = Lock()
wrote = inst.write_soon(b"xyz")
self.assertEqual(wrote, 3)
self.assertEqual(len(inst.outbufs), 2)
self.assertEqual(inst.outbufs[0].get(), b"")
self.assertEqual(inst.outbufs[1].get(), b"xyz")
self.assertEqual(len(inst.outbufs), 1)
self.assertEqual(inst.outbufs[0].get(), b"xyz")
self.assertTrue(inst.outbuf_lock.waited)

def test_write_soon_attempts_flush_high_water_and_exception(self):
from waitress.channel import ClientDisconnected

inst, sock, map = self._makeOneWithMap()

# _flush_some will no longer flush, it will raise Exception, which
# disconnects the remote end
def send(_):
inst.connected = False
raise Exception()

sock.send = send

inst.adj.outbuf_high_watermark = 3
inst.total_outbufs_len = 4
inst.current_outbuf_count = 4

inst.outbufs[0].append(b"test")

class Lock(DummyLock):
def wait(self):
inst.total_outbufs_len = 0
super().wait()

inst.outbuf_lock = Lock()
self.assertRaises(ClientDisconnected, lambda: inst.write_soon(b"xyz"))

# Validate we woke up the main thread to deal with the exception of
# trying to send
self.assertTrue(inst.outbuf_lock.waited)
self.assertTrue(inst.server.trigger_pulled)

def test_write_soon_flush_and_exception(self):
inst, sock, map = self._makeOneWithMap()

# _flush_some will no longer flush, it will raise Exception, which
# disconnects the remote end
def send(_):
inst.connected = False
raise Exception()

sock.send = send

wrote = inst.write_soon(b"xyz")
self.assertEqual(wrote, 3)
# Validate we woke up the main thread to deal with the exception of
# trying to send
self.assertTrue(inst.server.trigger_pulled)

def test_handle_write_notify_after_flush(self):
inst, sock, map = self._makeOneWithMap()
inst.requests = [True]
Expand Down